Skip to content

Commit

Permalink
fix: makings tool call ids unique (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed May 14, 2024
1 parent e7bf33d commit 9ea03d1
Show file tree
Hide file tree
Showing 13 changed files with 343 additions and 244 deletions.
11 changes: 8 additions & 3 deletions aidial_adapter_bedrock/dial_api/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from aidial_sdk.chat_completion.request import ChatCompletionRequest
from pydantic import BaseModel

from aidial_adapter_bedrock.llm.tools.tool_config import ToolConfig
from aidial_adapter_bedrock.llm.tools.tools_config import (
ToolsConfig,
validate_messages,
)


class ModelParameters(BaseModel):
Expand All @@ -14,7 +17,7 @@ class ModelParameters(BaseModel):
max_tokens: Optional[int] = None
max_prompt_tokens: Optional[int] = None
stream: bool = False
tool_config: Optional[ToolConfig] = None
tool_config: Optional[ToolsConfig] = None

@classmethod
def create(cls, request: ChatCompletionRequest) -> "ModelParameters":
Expand All @@ -26,6 +29,8 @@ def create(cls, request: ChatCompletionRequest) -> "ModelParameters":
else request.stop
)

validate_messages(request)

return cls(
temperature=request.temperature,
top_p=request.top_p,
Expand All @@ -34,7 +39,7 @@ def create(cls, request: ChatCompletionRequest) -> "ModelParameters":
max_tokens=request.max_tokens,
max_prompt_tokens=request.max_prompt_tokens,
stream=request.stream,
tool_config=ToolConfig.from_request(request),
tool_config=ToolsConfig.from_request(request),
)

def add_stop_sequences(self, stop: List[str]) -> "ModelParameters":
Expand Down
4 changes: 2 additions & 2 deletions aidial_adapter_bedrock/llm/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.llm.message import BaseMessage, SystemMessage
from aidial_adapter_bedrock.llm.tools.emulator import ToolsEmulator
from aidial_adapter_bedrock.llm.tools.tool_config import ToolConfig
from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig
from aidial_adapter_bedrock.llm.truncate_prompt import (
TruncatePromptError,
truncate_prompt,
Expand All @@ -30,7 +30,7 @@ def _is_empty_system_message(msg: Message) -> bool:


class ChatCompletionAdapter(ABC, BaseModel):
tools_emulator: Callable[[Optional[ToolConfig]], ToolsEmulator]
tools_emulator: Callable[[Optional[ToolsConfig]], ToolsEmulator]

class Config:
arbitrary_types_allowed = True
Expand Down
46 changes: 13 additions & 33 deletions aidial_adapter_bedrock/llm/tools/claude_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
https://docs.anthropic.com/claude/docs/legacy-tool-use
"""

from typing import Dict, List, Optional
from typing import List, Optional

from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.llm.message import (
AIFunctionCallMessage,
AIRegularMessage,
Expand All @@ -28,51 +29,33 @@
print_tool_declarations,
)
from aidial_adapter_bedrock.llm.tools.emulator import ToolsEmulator
from aidial_adapter_bedrock.llm.tools.tool_config import ToolConfig, ToolsMode
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log
from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig


def convert_to_base_message(
tool_config: Optional[ToolConfig],
id_to_name: Dict[str, str],
msg: ToolMessage,
tool_config: Optional[ToolsConfig], msg: ToolMessage
) -> BaseMessage:
mode: Optional[ToolsMode] = tool_config and tool_config.mode

match msg:
case HumanToolResultMessage(id=id, content=content):
assert (
mode is None or mode == ToolsMode.TOOLS
), f"Received tool result in '{mode.value}' mode"
name = id_to_name.get(id)
if name is None:
name = "_unknown_"
log.warning(
f"Unable to find tool name for id '{id}', assuming '{name}' name"
if tool_config is None:
raise ValidationError(
"Tool message is used, but tools are not declared"
)

name = tool_config.get_tool_name(id)
return HumanRegularMessage(
content=print_function_call_result(name=name, content=content)
)

case HumanFunctionResultMessage(name=name, content=content):
assert (
mode is None or mode == ToolsMode.FUNCTIONS
), f"Received function result in '{mode.value}' mode"
return HumanRegularMessage(
content=print_function_call_result(name=name, content=content)
)

case AIToolCallMessage(calls=calls):
assert (
mode is None or mode == ToolsMode.TOOLS
), f"Received tool call in '{mode.value}' mode"
for call in calls:
id_to_name[call.id] = call.function.name
return AIRegularMessage(content=print_tool_calls(calls))

case AIFunctionCallMessage(call=call):
assert (
mode is None or mode == ToolsMode.FUNCTIONS
), f"Received function call in '{mode.value}' mode"
return AIRegularMessage(content=print_function_call(call))


Expand All @@ -85,7 +68,7 @@ class Config:
@property
def _tool_declarations(self) -> Optional[str]:
return self.tool_config and print_tool_declarations(
self.tool_config.tools
self.tool_config.functions
)

def add_tool_declarations(
Expand All @@ -109,14 +92,11 @@ def get_stop_sequences(self) -> List[str]:
def convert_to_base_messages(
self, messages: List[BaseMessage | ToolMessage]
) -> List[BaseMessage]:
id_to_name: Dict[str, str] = {}
return [
(
message
if isinstance(message, BaseMessage)
else convert_to_base_message(
self.tool_config, id_to_name, message
)
else convert_to_base_message(self.tool_config, message)
)
for message in messages
]
Expand All @@ -132,7 +112,7 @@ def recognize_call(


def legacy_tools_emulator(
tool_config: Optional[ToolConfig],
tool_config: Optional[ToolsConfig],
) -> ToolsEmulator:
return Claude2_1_ToolsEmulator(
tool_config=tool_config,
Expand Down
40 changes: 18 additions & 22 deletions aidial_adapter_bedrock/llm/tools/claude_protocol.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import json
from typing import Dict, List, Literal, Optional

from aidial_sdk.chat_completion import FunctionCall, Tool, ToolCall
from aidial_sdk.chat_completion import Function, FunctionCall, ToolCall
from pydantic import BaseModel

from aidial_adapter_bedrock.llm.message import (
AIFunctionCallMessage,
AIToolCallMessage,
)
from aidial_adapter_bedrock.llm.tools.tool_config import ToolConfig, ToolsMode
from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig
from aidial_adapter_bedrock.utils.pydnatic import ExtraForbidModel
from aidial_adapter_bedrock.utils.xml import parse_xml, tag, tag_nl

Expand Down Expand Up @@ -91,21 +91,23 @@ def _print_tool_parameters(parameters: ToolParameters) -> str:
)


def _print_tool_declaration(tool: Tool) -> str:
def _print_tool_declaration(function: Function) -> str:
return tag_nl(
"tool_description",
[
tag("tool_name", tool.function.name),
tag("description", tool.function.description),
tag("tool_name", function.name),
tag("description", function.description),
_print_tool_parameters(
ToolParameters.parse_obj(tool.function.parameters)
ToolParameters.parse_obj(function.parameters)
),
],
)


def print_tool_declarations(tools: List[Tool]) -> str:
return tag_nl("tools", [_print_tool_declaration(tool) for tool in tools])
def print_tool_declarations(functions: List[Function]) -> str:
return tag_nl(
"tools", [_print_tool_declaration(function) for function in functions]
)


def _print_function_call_parameters(parameters: dict) -> str:
Expand Down Expand Up @@ -146,13 +148,12 @@ def _parse_function_call(text: str) -> FunctionCall:
start_index = text.find(FUNC_START_TAG)
if start_index == -1:
raise Exception(
f"Unable to parse function call, missing '{FUNC_TAG_NAME}' tag"
f"Unable to parse function call, missing {FUNC_TAG_NAME!r} tag"
)

try:
dict = parse_xml(text[start_index:])
invocation = dict[FUNC_TAG_NAME]["invoke"]

tool_name = invocation["tool_name"]
parameters = invocation["parameters"]
except Exception:
Expand All @@ -162,23 +163,18 @@ def _parse_function_call(text: str) -> FunctionCall:


def parse_call(
config: Optional[ToolConfig], text: str
config: Optional[ToolsConfig], text: str
) -> AIToolCallMessage | AIFunctionCallMessage | None:
if config is None:
return None

call = _parse_function_call(text)
match config.mode:
case ToolsMode.TOOLS:
return AIToolCallMessage(
calls=[
ToolCall(
index=0, id=call.name, type="function", function=call
)
]
)
case ToolsMode.FUNCTIONS:
return AIFunctionCallMessage(call=call)
if config.is_tool:
id = config.create_fresh_tool_call_id(call.name)
tool_call = ToolCall(index=0, id=id, type="function", function=call)
return AIToolCallMessage(calls=[tool_call])
else:
return AIFunctionCallMessage(call=call)


def print_function_call_result(name: str, content: str) -> str:
Expand Down
4 changes: 2 additions & 2 deletions aidial_adapter_bedrock/llm/tools/default_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ToolMessage,
)
from aidial_adapter_bedrock.llm.tools.emulator import ToolsEmulator
from aidial_adapter_bedrock.llm.tools.tool_config import ToolConfig
from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


Expand Down Expand Up @@ -45,5 +45,5 @@ def recognize_call(
return content


def default_tools_emulator(tool_config: Optional[ToolConfig]) -> ToolsEmulator:
def default_tools_emulator(tool_config: Optional[ToolsConfig]) -> ToolsEmulator:
return DefaultToolsEmulator(tool_config=tool_config)
4 changes: 2 additions & 2 deletions aidial_adapter_bedrock/llm/tools/emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
ToolMessage,
parse_dial_message,
)
from aidial_adapter_bedrock.llm.tools.tool_config import ToolConfig
from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig


class ToolsEmulator(ABC, BaseModel):
tool_config: Optional[ToolConfig]
tool_config: Optional[ToolsConfig]

@abstractmethod
def add_tool_declarations(
Expand Down
73 changes: 0 additions & 73 deletions aidial_adapter_bedrock/llm/tools/tool_config.py

This file was deleted.

Loading

0 comments on commit 9ea03d1

Please sign in to comment.