Skip to content

Commit

Permalink
Added ResponseFormat to ChatCompletionRequest (#81)
Browse files Browse the repository at this point in the history
* Added ResponseFormat to ChatCompletionRequest
* Added Test for json mode
* Corrected failing integration test
  • Loading branch information
jodendaal committed Nov 10, 2023
1 parent ef69f0b commit 87e4ab3
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 27 deletions.
37 changes: 33 additions & 4 deletions src/OpenAI.Net.Integration.Tests/ChatCompletionService_Create.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using OpenAI.Net.Models;
using Newtonsoft.Json;
using OpenAI.Net.Models;
using OpenAI.Net.Models.Requests;
using System.Net;

using System.Text.Json;

namespace OpenAI.Net.Integration.Tests
{
public class ChatCompletionService_Create : BaseTest
Expand Down Expand Up @@ -29,6 +31,33 @@ public async Task Get(string model, bool isSuccess, HttpStatusCode statusCode)
{
Assert.That(response.Result!.Choices.FirstOrDefault()!.Message.Content.ToLowerInvariant(), Contains.Substring("this is a test"), "Choices are not mapped correctly");
}
}

[TestCase(ModelTypes.Gpt3_5Turbo1106, true, HttpStatusCode.OK, TestName = "Get_When_Success_Json_Object")]
[TestCase("invalid_model", false, HttpStatusCode.NotFound, TestName = "Get_When_Fail_Json_Object")]
public async Task Get_Json_Object(string model, bool isSuccess, HttpStatusCode statusCode)
{
var messages = new List<Message>
{
Message.Create(ChatRoleType.User, "Say this is a test, return response in json")
};

var request = new ChatCompletionRequest(messages)
{
Model = model,
ResponseFormat = ChatResponseFormat.Json
};

var response = await OpenAIService.Chat.Get(request);

Assert.That(response.IsSuccess, Is.EqualTo(isSuccess), "Request failed");
Assert.That(response.StatusCode, Is.EqualTo(statusCode));
Assert.That(response.Result?.Choices?.Count() == 1, Is.EqualTo(isSuccess), "Choices are not mapped correctly");
if (isSuccess)
{
Assert.That(response.Result!.Choices.FirstOrDefault()!.Message.Content.ToLowerInvariant(), Contains.Substring("this is a test"), "Choices are not mapped correctly");
JsonDocument.Parse(response.Result!.Choices.FirstOrDefault()!.Message.Content);
}
}

[TestCase(ModelTypes.GPT35Turbo, true, HttpStatusCode.OK, TestName = "GetWithListExtension_When_Success")]
Expand All @@ -52,8 +81,8 @@ public async Task GetWithListExtension(string model, bool isSuccess, HttpStatusC
Assert.That(response.StatusCode, Is.EqualTo(statusCode));
Assert.That(response.Result?.Choices?.Count() == 1, Is.EqualTo(isSuccess), "Choices are not mapped correctly");
if (isSuccess)
{
Assert.That(response.Result?.Choices?.FirstOrDefault()?.Message.Content.Contains("Globe Life Field",StringComparison.InvariantCultureIgnoreCase), Is.EqualTo(true), $"Incorrect answer {response.Result?.Choices?.FirstOrDefault()?.Message.Content}");
{
Assert.That(response.Result?.Choices?.FirstOrDefault()?.Message.Content.ToLowerInvariant(), expression: Contains.Substring("globe life field").Or.Contains("texas"), $"Incorrect answer {response.Result?.Choices?.FirstOrDefault()?.Message.Content}");
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/OpenAI.Net.Integration.Tests/ImageService_Variation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public async Task Variation_FileInfoFileContent(bool isSuccess, HttpStatusCode s
Assert.That(response.StatusCode, Is.EqualTo(statusCode));
Assert.That(response.Result?.Data?.Count() == 1, Is.EqualTo(isSuccess), "Data is not mapped correctly");
Assert.That(response.Result?.Data?[0].Url?.Contains("https://"), isSuccess ? Is.EqualTo(isSuccess) : Is.EqualTo(null), "Choice text not set");
Assert.That(response.ErrorResponse?.Error?.Message?.Contains("is not one of ['256x256', '512x512', '1024x1024']"), isSuccess ? Is.EqualTo(null) : Is.EqualTo(true), "Error message not returned");
Assert.That(response.ErrorResponse?.Error?.Message, isSuccess ? Is.EqualTo(null) : Contains.Substring("is not one of ['256x256', '512x512', '1024x1024']"), "Error message not returned");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -189,5 +189,12 @@ public void CreateMessage(bool isSuccess, string role)
Assert.That($"Role must be one of the following ${string.Join(",", validTypes)} (Parameter 'role')", Is.EqualTo(exception.Message));
}
}

[Test]
public void ChatResponseFormatTests()
{
Assert.That(ChatResponseFormat.Text.Type, Is.EqualTo("text" ));
Assert.That(ChatResponseFormat.Json.Type, Is.EqualTo("json_object"));
}
}
}
8 changes: 8 additions & 0 deletions src/OpenAI.Net/Models/ChatResponseFormat.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
namespace OpenAI.Net
{
public static class ChatResponseFormat
{
public static readonly ChatResponseFormatType Text = new ChatResponseFormatType() { Type = "text" };
public static readonly ChatResponseFormatType Json = new ChatResponseFormatType() { Type = "json_object" };
}
}
40 changes: 23 additions & 17 deletions src/OpenAI.Net/Models/ModelTypes.cs
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
namespace OpenAI.Net
{
public static class ModelTypes
public class ModelTypes
{
public const string Ada = "ada";
public const string AdaCodeSearchCode = "ada-code-search-code";
public const string AdaCodeSearchText = "ada-code-search-text";
public const string AdaSearchDocument = "ada-search-document";
public const string AdaSearchQuery = "ada-search-query";
public const string AdaSimilarity = "ada-similarity";
public const string Ada_20200503 = "ada:2020-05-03";
public const string Babbage = "babbage";
public const string Babbage002 = "babbage-002";
public const string BabbageCodeSearchCode = "babbage-code-search-code";
public const string BabbageCodeSearchText = "babbage-code-search-text";
public const string BabbageSearchDocument = "babbage-search-document";
public const string BabbageSearchQuery = "babbage-search-query";
public const string BabbageSimilarity = "babbage-similarity";
public const string Babbage_20200503 = "babbage:2020-05-03";
public const string CanaryTts = "canary-tts";
public const string CodeDavinciEdit001 = "code-davinci-edit-001";
public const string CodeSearchAdaCode001 = "code-search-ada-code-001";
public const string CodeSearchAdaText001 = "code-search-ada-text-001";
Expand All @@ -26,30 +26,34 @@ public static class ModelTypes
public const string CurieSearchDocument = "curie-search-document";
public const string CurieSearchQuery = "curie-search-query";
public const string CurieSimilarity = "curie-similarity";
public const string Curie_20200503 = "curie:2020-05-03";
public const string Cushman_20200503 = "cushman:2020-05-03";
public const string DallE2 = "dall-e-2";
public const string DallE3 = "dall-e-3";
public const string Davinci = "davinci";
public const string DavinciIf_3_0_0 = "davinci-if:3.0.0";
public const string Davinci002 = "davinci-002";
public const string DavinciInstructBeta = "davinci-instruct-beta";
public const string DavinciInstructBeta_2_0_0 = "davinci-instruct-beta:2.0.0";
public const string DavinciSearchDocument = "davinci-search-document";
public const string DavinciSearchQuery = "davinci-search-query";
public const string DavinciSimilarity = "davinci-similarity";
public const string Davinci_20200503 = "davinci:2020-05-03";
public const string IfCurieV2 = "if-curie-v2";
public const string IfDavinciV2 = "if-davinci-v2";
public const string IfDavinci_3_0_0 = "if-davinci:3.0.0";
public const string GPT35Turbo = "gpt-3.5-turbo";
public const string Gpt3_5Turbo0301 = "gpt-3.5-turbo-0301";
public const string Gpt3_5Turbo0613 = "gpt-3.5-turbo-0613";
public const string Gpt3_5Turbo1106 = "gpt-3.5-turbo-1106";
public const string Gpt3_5Turbo16k = "gpt-3.5-turbo-16k";
public const string Gpt3_5Turbo16k0613 = "gpt-3.5-turbo-16k-0613";
public const string Gpt3_5TurboInstruct = "gpt-3.5-turbo-instruct";
public const string Gpt3_5TurboInstruct0914 = "gpt-3.5-turbo-instruct-0914";
public const string Gpt4 = "gpt-4";
public const string Gpt40314 = "gpt-4-0314";
public const string Gpt40613 = "gpt-4-0613";
public const string Gpt41106Preview = "gpt-4-1106-preview";
public const string Gpt4VisionPreview = "gpt-4-vision-preview";
public const string TextAda001 = "text-ada-001";
public const string TextAda_001 = "text-ada:001";
public const string TextBabbage001 = "text-babbage-001";
public const string TextBabbage_001 = "text-babbage:001";
public const string TextCurie001 = "text-curie-001";
public const string TextCurie_001 = "text-curie:001";
public const string TextDavinci001 = "text-davinci-001";
public const string TextDavinci002 = "text-davinci-002";
public const string TextDavinci003 = "text-davinci-003";
public const string TextDavinciEdit001 = "text-davinci-edit-001";
public const string TextDavinci_001 = "text-davinci:001";
public const string TextEmbeddingAda002 = "text-embedding-ada-002";
public const string TextSearchAdaDoc001 = "text-search-ada-doc-001";
public const string TextSearchAdaQuery001 = "text-search-ada-query-001";
Expand All @@ -63,8 +67,10 @@ public static class ModelTypes
public const string TextSimilarityBabbage001 = "text-similarity-babbage-001";
public const string TextSimilarityCurie001 = "text-similarity-curie-001";
public const string TextSimilarityDavinci001 = "text-similarity-davinci-001";
public const string GPT35Turbo = "gpt-3.5-turbo";
public const string Tts1 = "tts-1";
public const string Tts11106 = "tts-1-1106";
public const string Tts1Hd = "tts-1-hd";
public const string Tts1Hd1106 = "tts-1-hd-1106";
public const string Whisper1 = "whisper-1";

}
}
21 changes: 16 additions & 5 deletions src/OpenAI.Net/Models/Requests/ChatCompletionRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,31 +124,37 @@ public ChatCompletionRequest(string model, Message message) : this(model,message
/// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. <a href="https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids">Learn more</a>.
/// </summary>
public string User { get; set; }
}

/// <summary>
/// An object specifying the format that the model must output. <br/>
/// Setting to { "type": "json_object" }
/// enables JSON mode, which guarantees the message the model generates is valid JSON. <br/>
/// Important: when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user message.Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request.Also note that the message content may be partially cut off if finish_reason= "length", which indicates the generation exceeded max_tokens or the conversation exceeded the max context length. <br/>
/// <see href="https://platform.openai.com/docs/api-reference/completions/create#completions/create-logit_bias" />
/// </summary>
[JsonPropertyName("response_format")]
public ChatResponseFormatType? ResponseFormat { get; set; }
}
}

namespace OpenAI.Net
{
public class Message
{
private static readonly string[] _validRoles = new string[] { ChatRoleType.User, ChatRoleType.System, ChatRoleType.Assistant };

private Message(string role, string content)
{
Role = role;
Content = content;
}

public string Role { get; init; }
public string Content { get; init; }

public static Message Create(string role, string content)
{
if (!_validRoles.Contains(role))
{
throw new ArgumentException($"Role must be one of the following ${string.Join(",", _validRoles)}", nameof(role));
}

return new Message(role, content);
}
}
Expand All @@ -159,4 +165,9 @@ public class ChatRoleType
public const string System = "system";
public const string Assistant = "assistant";
}

public class ChatResponseFormatType
{
public string Type { get; set; }
}
}

0 comments on commit 87e4ab3

Please sign in to comment.