diff --git a/aidial_adapter_bedrock/dial_api/request.py b/aidial_adapter_bedrock/dial_api/request.py index 3e5738c..5b7f1b8 100644 --- a/aidial_adapter_bedrock/dial_api/request.py +++ b/aidial_adapter_bedrock/dial_api/request.py @@ -3,7 +3,10 @@ from aidial_sdk.chat_completion.request import ChatCompletionRequest from pydantic import BaseModel -from aidial_adapter_bedrock.llm.tools.tool_config import ToolConfig +from aidial_adapter_bedrock.llm.tools.tools_config import ( + ToolsConfig, + validate_messages, +) class ModelParameters(BaseModel): @@ -14,7 +17,7 @@ class ModelParameters(BaseModel): max_tokens: Optional[int] = None max_prompt_tokens: Optional[int] = None stream: bool = False - tool_config: Optional[ToolConfig] = None + tool_config: Optional[ToolsConfig] = None @classmethod def create(cls, request: ChatCompletionRequest) -> "ModelParameters": @@ -26,6 +29,8 @@ def create(cls, request: ChatCompletionRequest) -> "ModelParameters": else request.stop ) + validate_messages(request) + return cls( temperature=request.temperature, top_p=request.top_p, @@ -34,7 +39,7 @@ def create(cls, request: ChatCompletionRequest) -> "ModelParameters": max_tokens=request.max_tokens, max_prompt_tokens=request.max_prompt_tokens, stream=request.stream, - tool_config=ToolConfig.from_request(request), + tool_config=ToolsConfig.from_request(request), ) def add_stop_sequences(self, stop: List[str]) -> "ModelParameters": diff --git a/aidial_adapter_bedrock/llm/chat_model.py b/aidial_adapter_bedrock/llm/chat_model.py index c6ca819..5bca56c 100644 --- a/aidial_adapter_bedrock/llm/chat_model.py +++ b/aidial_adapter_bedrock/llm/chat_model.py @@ -11,7 +11,7 @@ from aidial_adapter_bedrock.llm.errors import ValidationError from aidial_adapter_bedrock.llm.message import BaseMessage, SystemMessage from aidial_adapter_bedrock.llm.tools.emulator import ToolsEmulator -from aidial_adapter_bedrock.llm.tools.tool_config import ToolConfig +from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig from aidial_adapter_bedrock.llm.truncate_prompt import ( TruncatePromptError, truncate_prompt, @@ -30,7 +30,7 @@ def _is_empty_system_message(msg: Message) -> bool: class ChatCompletionAdapter(ABC, BaseModel): - tools_emulator: Callable[[Optional[ToolConfig]], ToolsEmulator] + tools_emulator: Callable[[Optional[ToolsConfig]], ToolsEmulator] class Config: arbitrary_types_allowed = True diff --git a/aidial_adapter_bedrock/llm/tools/claude_emulator.py b/aidial_adapter_bedrock/llm/tools/claude_emulator.py index 3b4f41d..83abf49 100644 --- a/aidial_adapter_bedrock/llm/tools/claude_emulator.py +++ b/aidial_adapter_bedrock/llm/tools/claude_emulator.py @@ -3,8 +3,9 @@ https://docs.anthropic.com/claude/docs/legacy-tool-use """ -from typing import Dict, List, Optional +from typing import List, Optional +from aidial_adapter_bedrock.llm.errors import ValidationError from aidial_adapter_bedrock.llm.message import ( AIFunctionCallMessage, AIRegularMessage, @@ -28,51 +29,33 @@ print_tool_declarations, ) from aidial_adapter_bedrock.llm.tools.emulator import ToolsEmulator -from aidial_adapter_bedrock.llm.tools.tool_config import ToolConfig, ToolsMode -from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log +from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig def convert_to_base_message( - tool_config: Optional[ToolConfig], - id_to_name: Dict[str, str], - msg: ToolMessage, + tool_config: Optional[ToolsConfig], msg: ToolMessage ) -> BaseMessage: - mode: Optional[ToolsMode] = tool_config and tool_config.mode + match msg: case HumanToolResultMessage(id=id, content=content): - assert ( - mode is None or mode == ToolsMode.TOOLS - ), f"Received tool result in '{mode.value}' mode" - name = id_to_name.get(id) - if name is None: - name = "_unknown_" - log.warning( - f"Unable to find tool name for id '{id}', assuming '{name}' name" + if tool_config is None: + raise ValidationError( + "Tool message is used, but tools are not declared" ) - + name = tool_config.get_tool_name(id) return HumanRegularMessage( content=print_function_call_result(name=name, content=content) ) case HumanFunctionResultMessage(name=name, content=content): - assert ( - mode is None or mode == ToolsMode.FUNCTIONS - ), f"Received function result in '{mode.value}' mode" return HumanRegularMessage( content=print_function_call_result(name=name, content=content) ) case AIToolCallMessage(calls=calls): - assert ( - mode is None or mode == ToolsMode.TOOLS - ), f"Received tool call in '{mode.value}' mode" - for call in calls: - id_to_name[call.id] = call.function.name return AIRegularMessage(content=print_tool_calls(calls)) + case AIFunctionCallMessage(call=call): - assert ( - mode is None or mode == ToolsMode.FUNCTIONS - ), f"Received function call in '{mode.value}' mode" return AIRegularMessage(content=print_function_call(call)) @@ -85,7 +68,7 @@ class Config: @property def _tool_declarations(self) -> Optional[str]: return self.tool_config and print_tool_declarations( - self.tool_config.tools + self.tool_config.functions ) def add_tool_declarations( @@ -109,14 +92,11 @@ def get_stop_sequences(self) -> List[str]: def convert_to_base_messages( self, messages: List[BaseMessage | ToolMessage] ) -> List[BaseMessage]: - id_to_name: Dict[str, str] = {} return [ ( message if isinstance(message, BaseMessage) - else convert_to_base_message( - self.tool_config, id_to_name, message - ) + else convert_to_base_message(self.tool_config, message) ) for message in messages ] @@ -132,7 +112,7 @@ def recognize_call( def legacy_tools_emulator( - tool_config: Optional[ToolConfig], + tool_config: Optional[ToolsConfig], ) -> ToolsEmulator: return Claude2_1_ToolsEmulator( tool_config=tool_config, diff --git a/aidial_adapter_bedrock/llm/tools/claude_protocol.py b/aidial_adapter_bedrock/llm/tools/claude_protocol.py index 291e66d..4c44b39 100644 --- a/aidial_adapter_bedrock/llm/tools/claude_protocol.py +++ b/aidial_adapter_bedrock/llm/tools/claude_protocol.py @@ -1,14 +1,14 @@ import json from typing import Dict, List, Literal, Optional -from aidial_sdk.chat_completion import FunctionCall, Tool, ToolCall +from aidial_sdk.chat_completion import Function, FunctionCall, ToolCall from pydantic import BaseModel from aidial_adapter_bedrock.llm.message import ( AIFunctionCallMessage, AIToolCallMessage, ) -from aidial_adapter_bedrock.llm.tools.tool_config import ToolConfig, ToolsMode +from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig from aidial_adapter_bedrock.utils.pydnatic import ExtraForbidModel from aidial_adapter_bedrock.utils.xml import parse_xml, tag, tag_nl @@ -91,21 +91,23 @@ def _print_tool_parameters(parameters: ToolParameters) -> str: ) -def _print_tool_declaration(tool: Tool) -> str: +def _print_tool_declaration(function: Function) -> str: return tag_nl( "tool_description", [ - tag("tool_name", tool.function.name), - tag("description", tool.function.description), + tag("tool_name", function.name), + tag("description", function.description), _print_tool_parameters( - ToolParameters.parse_obj(tool.function.parameters) + ToolParameters.parse_obj(function.parameters) ), ], ) -def print_tool_declarations(tools: List[Tool]) -> str: - return tag_nl("tools", [_print_tool_declaration(tool) for tool in tools]) +def print_tool_declarations(functions: List[Function]) -> str: + return tag_nl( + "tools", [_print_tool_declaration(function) for function in functions] + ) def _print_function_call_parameters(parameters: dict) -> str: @@ -146,13 +148,12 @@ def _parse_function_call(text: str) -> FunctionCall: start_index = text.find(FUNC_START_TAG) if start_index == -1: raise Exception( - f"Unable to parse function call, missing '{FUNC_TAG_NAME}' tag" + f"Unable to parse function call, missing {FUNC_TAG_NAME!r} tag" ) try: dict = parse_xml(text[start_index:]) invocation = dict[FUNC_TAG_NAME]["invoke"] - tool_name = invocation["tool_name"] parameters = invocation["parameters"] except Exception: @@ -162,23 +163,18 @@ def _parse_function_call(text: str) -> FunctionCall: def parse_call( - config: Optional[ToolConfig], text: str + config: Optional[ToolsConfig], text: str ) -> AIToolCallMessage | AIFunctionCallMessage | None: if config is None: return None call = _parse_function_call(text) - match config.mode: - case ToolsMode.TOOLS: - return AIToolCallMessage( - calls=[ - ToolCall( - index=0, id=call.name, type="function", function=call - ) - ] - ) - case ToolsMode.FUNCTIONS: - return AIFunctionCallMessage(call=call) + if config.is_tool: + id = config.create_fresh_tool_call_id(call.name) + tool_call = ToolCall(index=0, id=id, type="function", function=call) + return AIToolCallMessage(calls=[tool_call]) + else: + return AIFunctionCallMessage(call=call) def print_function_call_result(name: str, content: str) -> str: diff --git a/aidial_adapter_bedrock/llm/tools/default_emulator.py b/aidial_adapter_bedrock/llm/tools/default_emulator.py index d11bf66..399947e 100644 --- a/aidial_adapter_bedrock/llm/tools/default_emulator.py +++ b/aidial_adapter_bedrock/llm/tools/default_emulator.py @@ -7,7 +7,7 @@ ToolMessage, ) from aidial_adapter_bedrock.llm.tools.emulator import ToolsEmulator -from aidial_adapter_bedrock.llm.tools.tool_config import ToolConfig +from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log @@ -45,5 +45,5 @@ def recognize_call( return content -def default_tools_emulator(tool_config: Optional[ToolConfig]) -> ToolsEmulator: +def default_tools_emulator(tool_config: Optional[ToolsConfig]) -> ToolsEmulator: return DefaultToolsEmulator(tool_config=tool_config) diff --git a/aidial_adapter_bedrock/llm/tools/emulator.py b/aidial_adapter_bedrock/llm/tools/emulator.py index 8fbdf81..d81fabc 100644 --- a/aidial_adapter_bedrock/llm/tools/emulator.py +++ b/aidial_adapter_bedrock/llm/tools/emulator.py @@ -11,11 +11,11 @@ ToolMessage, parse_dial_message, ) -from aidial_adapter_bedrock.llm.tools.tool_config import ToolConfig +from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig class ToolsEmulator(ABC, BaseModel): - tool_config: Optional[ToolConfig] + tool_config: Optional[ToolsConfig] @abstractmethod def add_tool_declarations( diff --git a/aidial_adapter_bedrock/llm/tools/tool_config.py b/aidial_adapter_bedrock/llm/tools/tool_config.py deleted file mode 100644 index 984f86d..0000000 --- a/aidial_adapter_bedrock/llm/tools/tool_config.py +++ /dev/null @@ -1,73 +0,0 @@ -from enum import Enum -from typing import List, Optional - -from aidial_sdk.chat_completion import ( - Function, - FunctionChoice, - Tool, - ToolChoice, -) -from aidial_sdk.chat_completion.request import ChatCompletionRequest -from pydantic import BaseModel - -from aidial_adapter_bedrock.llm.errors import ValidationError - - -class ToolsMode(str, Enum): - TOOLS = "tools" - FUNCTIONS = "functions" - - -def _fun_to_tool(fun: Function) -> Tool: - return Tool(type="function", function=fun) - - -class ToolConfig(BaseModel): - mode: ToolsMode - tools: List[Tool] - - @classmethod - def from_request( - cls, request: ChatCompletionRequest - ) -> Optional["ToolConfig"]: - mode: ToolsMode = ToolsMode.TOOLS - tools: List[Tool] = [] - selected_function: Optional[str] = None - - if request.functions is not None and len(request.functions) > 0: - choice = request.function_call - if choice == "none": - return None - - if isinstance(choice, FunctionChoice): - selected_function = choice.name - - mode = ToolsMode.FUNCTIONS - tools = [_fun_to_tool(fun) for fun in request.functions] - - elif request.tools is not None and len(request.tools) > 0: - choice = request.tool_choice - if choice == "none": - return None - - if isinstance(choice, ToolChoice): - selected_function = choice.function.name - - mode = ToolsMode.TOOLS - tools = request.tools - else: - return None - - if selected_function is not None: - tools = [ - tool - for tool in tools - if tool.function.name == selected_function - ] - - if len(tools) == 0: - raise ValidationError( - f"Unable to find tool with name '{selected_function}'" - ) - - return cls(mode=mode, tools=tools) diff --git a/aidial_adapter_bedrock/llm/tools/tools_config.py b/aidial_adapter_bedrock/llm/tools/tools_config.py new file mode 100644 index 0000000..5daf220 --- /dev/null +++ b/aidial_adapter_bedrock/llm/tools/tools_config.py @@ -0,0 +1,173 @@ +from typing import Dict, List, Literal, Self, Tuple, assert_never + +from aidial_sdk.chat_completion import ( + Function, + FunctionChoice, + Message, + Role, + ToolChoice, +) +from aidial_sdk.chat_completion.request import AzureChatCompletionRequest +from pydantic import BaseModel + +from aidial_adapter_bedrock.llm.errors import ValidationError + + +class ToolsConfig(BaseModel): + functions: List[Function] + """ + List of functions/tools. + """ + + required: bool + """ + True forces the model to call one of the available functions. + False allows the model to pick between generating a message or + calling one or more tools/functions. + """ + + tool_ids: Dict[str, str] | None + """ + Mapping from tool call IDs to corresponding tool names. + None means that functions are used, not tools. + """ + + @property + def is_tool(self) -> bool: + return self.tool_ids is not None + + def not_supported(self) -> None: + if self.functions: + if self.is_tool: + raise ValidationError("The tools aren't supported") + else: + raise ValidationError("The functions aren't supported") + + def create_fresh_tool_call_id(self, tool_name: str) -> str: + if self.tool_ids is None: + raise ValidationError("Function are used, but requested tool id") + + idx = 1 + while True: + id = f"{tool_name}_{idx}" + if id not in self.tool_ids: + self.tool_ids[id] = tool_name + return id + idx += 1 + + def get_tool_name(self, tool_call_id: str) -> str: + if self.tool_ids is None: + raise ValidationError("Function are used, but requested tool name") + + tool_name = self.tool_ids.get(tool_call_id) + if tool_name is None: + raise ValidationError(f"Tool call ID not found: {self.tool_ids}") + return tool_name + + @staticmethod + def filter_functions( + function_call: Literal["auto", "none"] | FunctionChoice, + functions: List[Function], + ) -> Tuple[bool, List[Function]]: + match function_call: + case "none": + return False, [] + case "auto": + return False, functions + case FunctionChoice(name=name): + new_functions = [ + func for func in functions if func.name == name + ] + if not new_functions: + raise ValidationError( + f"Function {name!r} is not on the list of available functions" + ) + return True, new_functions + case _: + assert_never(function_call) + + @staticmethod + def tool_choice_to_function_call( + tool_choice: Literal["auto", "none"] | ToolChoice | None, + ) -> Literal["auto", "none"] | FunctionChoice | None: + match tool_choice: + case ToolChoice(function=FunctionChoice(name=name)): + return FunctionChoice(name=name) + case _: + return tool_choice + + @classmethod + def from_request(cls, request: AzureChatCompletionRequest) -> Self | None: + validate_messages(request) + + if request.functions is not None: + functions = request.functions + function_call = request.function_call + tool_ids = None + + elif request.tools is not None: + functions = [tool.function for tool in request.tools] + function_call = ToolsConfig.tool_choice_to_function_call( + request.tool_choice + ) + tool_ids = collect_tool_ids(request.messages) + + else: + functions = [] + function_call = None + tool_ids = None + + if function_call is None: + function_call = "auto" if functions else "none" + + required, selected = ToolsConfig.filter_functions( + function_call, functions + ) + + if selected == []: + return None + + return cls(functions=selected, required=required, tool_ids=tool_ids) + + +def validate_messages(request: AzureChatCompletionRequest) -> None: + decl_tools = request.tools is not None + decl_functions = request.functions is not None + + if decl_functions and decl_tools: + raise ValidationError("Both functions and tools are not allowed") + + for message in request.messages: + if message.role == Role.ASSISTANT: + use_tools = message.tool_calls is not None + if use_tools and not decl_tools: + raise ValidationError( + "Assistant message uses tools, but tools are not declared" + ) + + use_functions = message.function_call is not None + if use_functions and not decl_functions: + raise ValidationError( + "Assistant message uses functions, but functions are not declared" + ) + if message.role == Role.FUNCTION: + if not decl_functions: + raise ValidationError( + "Function message is used, but functions are not declared" + ) + if message.role == Role.TOOL: + if not decl_tools: + raise ValidationError( + "Tool message is used, but tools are not declared" + ) + + +def collect_tool_ids(messages: List[Message]) -> Dict[str, str]: + ret: Dict[str, str] = {} + + for message in messages: + if message.role == Role.ASSISTANT and message.tool_calls is not None: + for tool_call in message.tool_calls: + ret[tool_call.id] = tool_call.function.name + + return ret diff --git a/pyproject.toml b/pyproject.toml index d1eac9d..534e43f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,10 @@ nox = "^2023.4.22" [tool.pytest.ini_options] addopts = "--doctest-modules" +# muting warnings coming from opentelemetry package +filterwarnings = [ + "ignore::DeprecationWarning:opentelemetry.instrumentation.dependencies" +] [tool.pyright] typeCheckingMode = "basic" diff --git a/tests/integration_tests/test_chat_completion.py b/tests/integration_tests/test_chat_completion.py index 9c3ef5c..6d06fb7 100644 --- a/tests/integration_tests/test_chat_completion.py +++ b/tests/integration_tests/test_chat_completion.py @@ -16,7 +16,10 @@ from aidial_adapter_bedrock.deployments import BedrockDeployment from tests.conftest import TEST_SERVER_URL +from tests.utils.json import match_objects from tests.utils.openai import ( + GET_WEATHER_FUNCTION, + GET_WEATHER_TOOL, ChatCompletionResult, ai, ai_function, @@ -24,7 +27,6 @@ chat_completion, function_request, function_response, - function_to_tool, get_client, sanitize_test_name, sys, @@ -33,51 +35,12 @@ user, ) -get_weather_function: Function = { - "name": "get_current_weather", - "description": "Get the current weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "format": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "The temperature unit to use. Infer this from the users location.", - }, - }, - "required": ["location", "format"], - }, -} - - -def match_objects(expected: Any, actual: Any) -> None: - if isinstance(expected, dict): - assert list(sorted(expected.keys())) == list(sorted(actual.keys())) - for k, v in expected.items(): - match_objects(v, actual[k]) - elif isinstance(expected, tuple): - assert len(expected) == len(actual) - for i in range(len(expected)): - match_objects(expected[i], actual[i]) - elif isinstance(expected, list): - assert len(expected) == len(actual) - for i in range(len(expected)): - match_objects(expected[i], actual[i]) - elif callable(expected): - assert expected(actual) - else: - assert expected == actual - def is_valid_function_call( - call: FunctionCall | None, function: Function, expected_args: Any + call: FunctionCall | None, expected_name: str, expected_args: Any ) -> bool: assert call is not None - assert call.name == function["name"] + assert call.name == expected_name obj = json.loads(call.arguments) match_objects(expected_args, obj) return True @@ -85,7 +48,8 @@ def is_valid_function_call( def is_valid_tool_calls( calls: List[ChatCompletionMessageToolCall] | None, - function: Function, + expected_id: str, + expected_name: str, expected_args: Any, ) -> bool: assert calls is not None @@ -93,8 +57,8 @@ def is_valid_tool_calls( call = calls[0] function_call = call.function - assert call.id == function["name"] - assert function_call.name == function["name"] + assert call.id == expected_id + assert function_call.name == expected_name obj = json.loads(function_call.arguments) match_objects(expected_args, obj) @@ -126,15 +90,18 @@ class TestCase: max_tokens: int | None stop: List[str] | None + n: int | None + functions: List[Function] | None tools: List[ChatCompletionToolParam] | None def get_id(self): max_tokens_str = f"maxt={self.max_tokens}" if self.max_tokens else "" stop_sequence_str = f"stop={self.stop}" if self.stop else "" + n_str = f"n={self.n}" if self.n else "" return sanitize_test_name( f"{self.deployment.value} {self.streaming} {max_tokens_str} " - f"{stop_sequence_str} {self.name}" + f"{stop_sequence_str} {n_str} {self.name}" ) @@ -198,6 +165,7 @@ def test_case( expected: ( Callable[[ChatCompletionResult], bool] | ExpectedException ) = expected_success, + n: int | None = None, max_tokens: int | None = None, stop: List[str] | None = None, functions: List[Function] | None = None, @@ -212,6 +180,7 @@ def test_case( expected, max_tokens, stop, + n, functions, tools, ) @@ -342,15 +311,15 @@ def test_case( function_args = {"location": "Glasgow", "format": "celsius"} - name = get_weather_function["name"] + name = GET_WEATHER_FUNCTION["name"] # Functions test_case( name="weather function", messages=[user(query)], - functions=[get_weather_function], + functions=[GET_WEATHER_FUNCTION], expected=lambda s: is_valid_function_call( - s.function_call, get_weather_function, function_args_checker + s.function_call, name, function_args_checker ), ) @@ -360,27 +329,28 @@ def test_case( test_case( name="weather function followup", messages=[user(query), function_req, function_resp], - functions=[get_weather_function], + functions=[GET_WEATHER_FUNCTION], expected=lambda s: "15" in s.content.lower(), ) # Tools + tool_call_id = f"{name}_1" test_case( name="weather tool", messages=[user(query)], - tools=[function_to_tool(get_weather_function)], + tools=[GET_WEATHER_TOOL], expected=lambda s: is_valid_tool_calls( - s.tool_calls, get_weather_function, function_args_checker + s.tool_calls, tool_call_id, name, function_args_checker ), ) - tool_req = ai_tools([tool_request(name, function_args)]) - tool_resp = tool_response(name, "15 celsius") + tool_req = ai_tools([tool_request(tool_call_id, name, function_args)]) + tool_resp = tool_response(tool_call_id, "15 celsius") test_case( name="weather tool followup", messages=[user(query), tool_req, tool_resp], - tools=[function_to_tool(get_weather_function)], + tools=[GET_WEATHER_TOOL], expected=lambda s: "15" in s.content.lower(), ) @@ -408,6 +378,7 @@ async def run_chat_completion() -> ChatCompletionResult: test.streaming, test.stop, test.max_tokens, + test.n, test.functions, test.tools, ) diff --git a/tests/unit_tests/tools_emulation/test_claude.py b/tests/unit_tests/tools_emulation/test_claude.py index a4bd2af..cda0d3b 100644 --- a/tests/unit_tests/tools_emulation/test_claude.py +++ b/tests/unit_tests/tools_emulation/test_claude.py @@ -1,6 +1,6 @@ import json -from aidial_sdk.chat_completion import Function, FunctionCall, Tool +from aidial_sdk.chat_completion import Function, FunctionCall from aidial_adapter_bedrock.llm.tools.claude_protocol import ( _parse_function_call, @@ -9,62 +9,50 @@ ) from aidial_adapter_bedrock.utils.xml import parse_xml -TOOL_ARITY_2 = Tool( - type="function", - function=Function( - name="func_arity_2", - description="desc", - parameters={ - "type": "object", - "properties": { - "param1": {"type": "type1", "description": "desc1"}, - "param2": {"type": "type2"}, - }, +TOOL_ARITY_2 = Function( + name="func_arity_2", + description="desc", + parameters={ + "type": "object", + "properties": { + "param1": {"type": "type1", "description": "desc1"}, + "param2": {"type": "type2"}, }, - ), + }, ) -TOOL_ARITY_0 = Tool( - type="function", - function=Function( - name="func_arity_0", - description="desc", - parameters={"type": "object", "properties": {}}, - ), +TOOL_ARITY_0 = Function( + name="func_arity_0", + description="desc", + parameters={"type": "object", "properties": {}}, ) -TOOL_ENUM_PARAM = Tool( - type="function", - function=Function( - name="func_enum_param", - description="tool with enum parameter", - parameters={ - "type": "object", - "properties": { - "param1": { - "type": "string", - "enum": ["value1", "value2", "value3"], - } - }, +TOOL_ENUM_PARAM = Function( + name="func_enum_param", + description="tool with enum parameter", + parameters={ + "type": "object", + "properties": { + "param1": { + "type": "string", + "enum": ["value1", "value2", "value3"], + } }, - ), + }, ) -TOOL_ARRAY_PARAM = Tool( - type="function", - function=Function( - name="func_array_param", - description="tool with array parameter", - parameters={ - "type": "object", - "properties": { - "param1": { - "type": "array", - "items": {"type": "string"}, - } - }, +TOOL_ARRAY_PARAM = Function( + name="func_array_param", + description="tool with array parameter", + parameters={ + "type": "object", + "properties": { + "param1": { + "type": "array", + "items": {"type": "string"}, + } }, - ), + }, ) diff --git a/tests/utils/json.py b/tests/utils/json.py new file mode 100644 index 0000000..fa6d630 --- /dev/null +++ b/tests/utils/json.py @@ -0,0 +1,20 @@ +from typing import Any + + +def match_objects(expected: Any, actual: Any) -> None: + if isinstance(expected, dict): + assert list(sorted(expected.keys())) == list(sorted(actual.keys())) + for k, v in expected.items(): + match_objects(v, actual[k]) + elif isinstance(expected, tuple): + assert len(expected) == len(actual) + for i in range(len(expected)): + match_objects(expected[i], actual[i]) + elif isinstance(expected, list): + assert len(expected) == len(actual) + for i in range(len(expected)): + match_objects(expected[i], actual[i]) + elif callable(expected): + assert expected(actual) + else: + assert expected == actual diff --git a/tests/utils/openai.py b/tests/utils/openai.py index dfa05b6..f962fd8 100644 --- a/tests/utils/openai.py +++ b/tests/utils/openai.py @@ -60,9 +60,11 @@ def function_request(name: str, args: Any) -> ToolFunction: return {"name": name, "arguments": json.dumps(args)} -def tool_request(name: str, args: Any) -> ChatCompletionMessageToolCallParam: +def tool_request( + id: str, name: str, args: Any +) -> ChatCompletionMessageToolCallParam: return { - "id": name, + "id": id, "type": "function", "function": function_request(name, args), } @@ -74,8 +76,8 @@ def function_response( return {"role": "function", "name": name, "content": content} -def tool_response(name: str, content: str) -> ChatCompletionToolMessageParam: - return {"role": "tool", "tool_call_id": name, "content": content} +def tool_response(id: str, content: str) -> ChatCompletionToolMessageParam: + return {"role": "tool", "tool_call_id": id, "content": content} def function_to_tool(function: FunctionDefinition) -> ChatCompletionToolParam: @@ -98,6 +100,12 @@ def message(self) -> ChatCompletionMessage: def content(self) -> str: return self.message.content or "" + @property + def contents(self) -> List[str]: + return [ + choice.message.content or "" for choice in self.response.choices + ] + @property def usage(self) -> CompletionUsage | None: return self.response.usage @@ -117,6 +125,7 @@ async def chat_completion( stream: bool, stop: Optional[List[str]], max_tokens: Optional[int], + n: Optional[int], functions: List[Function] | None, tools: List[ChatCompletionToolParam] | None, ) -> ChatCompletionResult: @@ -128,6 +137,7 @@ async def get_response() -> ChatCompletion: stop=stop, max_tokens=max_tokens, temperature=0.0, + n=n, function_call="auto" if functions is not None else NOT_GIVEN, functions=functions or NOT_GIVEN, tool_choice="auto" if tools is not None else NOT_GIVEN, @@ -161,3 +171,28 @@ def get_client(base_url: str, model_id: str) -> AsyncAzureOpenAI: max_retries=0, timeout=30, ) + + +GET_WEATHER_FUNCTION: Function = { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, +} + +GET_WEATHER_TOOL: ChatCompletionToolParam = function_to_tool( + GET_WEATHER_FUNCTION +)