feat: added conservative tokenizer for Claude 3 #146

Open

wants to merge 20 commits into base: development

Commits (20)
476ae14
feat: added conservative tokenizer for Claude 3
adubovik Sep 6, 2024
b5eb9d1
Merge branch 'development' into feat/support-prompt-truncation-for-cl…
adubovik Sep 6, 2024
ee40a84
feat: add tool choice to Claude 3
adubovik Sep 6, 2024
13d3dc9
feat: supported prompt truncation for claude 3
adubovik Sep 9, 2024
5b22aaa
feat: made claude 3 deployment typed literal
adubovik Sep 9, 2024
43b4613
feat: supported tokenize/ and truncate_prompt/ endpoints for Claude 3
adubovik Sep 9, 2024
04234cb
fix: Claude 3: switching to a byte counting tokenization approximation
adubovik Sep 9, 2024
b7dfdb8
chore: updated README
adubovik Sep 10, 2024
f6bce34
feat: added test for Claude 3 truncate prompt
adubovik Sep 10, 2024
3a6514d
feat: added ClaudeRequest data class
adubovik Sep 10, 2024
0523508
feat: added test with a function
adubovik Sep 10, 2024
8f9ea11
feat: added test with an image
adubovik Sep 10, 2024
4cbbacf
chore: fixed README
adubovik Sep 10, 2024
c680102
Merge branch 'development' into feat/support-prompt-truncation-for-cl…
adubovik Sep 10, 2024
ca99288
feat: improved readability of tests with mock_tokenize_text
adubovik Sep 10, 2024
5eb1ff0
fix: simplified to_dial_tool_call method
adubovik Sep 12, 2024
2049084
chore: updated comment
adubovik Sep 12, 2024
d4ca3e3
Merge branch 'development' into feat/support-prompt-truncation-for-cl…
adubovik Sep 20, 2024
da2fcda
fix: renamed truncate_prompt to compute_discarded_messages; introduce…
adubovik Sep 20, 2024
86638a2
chore: renamings
adubovik Sep 20, 2024
62 changes: 34 additions & 28 deletions README.md
@@ -6,34 +6,40 @@ The project implements [AI DIAL API](https://epam-rail.com/dial_api) for languag

### Chat completion models

The following models support `POST SERVER_URL/openai/deployments/DEPLOYMENT_NAME/chat/completions` endpoint along with optional support of `/tokenize` and `/truncate_prompt` endpoints:

|Vendor|Model|Deployment name|Modality|`/tokenize`|`/truncate_prompt`|tools/functions support|precise tokenization|
|---|---|---|---|---|---|---|---|
|Anthropic|Claude 3.5 Sonnet|anthropic.claude-3-5-sonnet-20240620-v1:0|text-to-text, image-to-text|❌|❌|✅|❌|
|Anthropic|Claude 3 Sonnet|anthropic.claude-3-sonnet-20240229-v1:0|text-to-text, image-to-text|❌|❌|✅|❌|
|Anthropic|Claude 3 Haiku|anthropic.claude-3-haiku-20240307-v1:0|text-to-text, image-to-text|❌|❌|✅|❌|
|Anthropic|Claude 3 Opus|anthropic.claude-3-opus-20240229-v1:0|text-to-text, image-to-text|❌|❌|✅|❌|
|Anthropic|Claude 2.1|anthropic.claude-v2:1|text-to-text|✅|✅|✅|✅|
|Anthropic|Claude 2|anthropic.claude-v2|text-to-text|✅|✅|❌|✅|
|Anthropic|Claude Instant 1.2|anthropic.claude-instant-v1|text-to-text|✅|✅|❌|❌|
|Meta|Llama 3.1 405B Instruct|meta.llama3-1-405b-instruct-v1:0|text-to-text|✅|✅|❌|❌|
|Meta|Llama 3.1 70B Instruct|meta.llama3-1-70b-instruct-v1:0|text-to-text|✅|✅|❌|❌|
|Meta|Llama 3.1 8B Instruct|meta.llama3-1-8b-instruct-v1:0|text-to-text|✅|✅|❌|❌|
|Meta|Llama 3 Chat 70B Instruct|meta.llama3-70b-instruct-v1:0|text-to-text|✅|✅|❌|❌|
|Meta|Llama 3 Chat 8B Instruct|meta.llama3-8b-instruct-v1:0|text-to-text|✅|✅|❌|❌|
|Meta|Llama 2 Chat 70B|meta.llama2-70b-chat-v1|text-to-text|✅|✅|❌|❌|
|Meta|Llama 2 Chat 13B|meta.llama2-13b-chat-v1|text-to-text|✅|✅|❌|❌|
|Stability AI|SDXL 1.0|stability.stable-diffusion-xl-v1|text-to-image|❌|✅|❌|❌|
|Amazon|Titan Text G1 - Express|amazon.titan-tg1-large|text-to-text|✅|✅|❌|❌|
|AI21 Labs|Jurassic-2 Ultra|ai21.j2-jumbo-instruct|text-to-text|✅|✅|❌|❌|
|AI21 Labs|Jurassic-2 Mid|ai21.j2-grande-instruct|text-to-text|✅|✅|❌|❌|
|Cohere|Command|cohere.command-text-v14|text-to-text|✅|✅|❌|❌|
|Cohere|Command Light|cohere.command-light-text-v14|text-to-text|✅|✅|❌|❌|

The models that support `/truncate_prompt` also support the `max_prompt_tokens` request parameter.

Certain models do not support precise tokenization, because the tokenization algorithm is not known. Instead, an approximate tokenization algorithm is used. It conservatively counts every byte in the UTF-8 encoding of a string as a single token.
The following models support the `POST SERVER_URL/openai/deployments/DEPLOYMENT_NAME/chat/completions` endpoint along with optional support of the `POST /tokenize` and `POST /truncate_prompt` endpoints:

Note that a model supports the `/truncate_prompt` endpoint if and only if it supports the `max_prompt_tokens` request parameter.

|Vendor|Model|Deployment name|Modality|`/tokenize`|`/truncate_prompt`, `max_prompt_tokens`|tools/functions|
|---|---|---|---|---|---|---|
|Anthropic|Claude 3.5 Sonnet|anthropic.claude-3-5-sonnet-20240620-v1:0|text-to-text, image-to-text|🟡|🟡|✅|
|Anthropic|Claude 3 Sonnet|anthropic.claude-3-sonnet-20240229-v1:0|text-to-text, image-to-text|🟡|🟡|✅|
|Anthropic|Claude 3 Haiku|anthropic.claude-3-haiku-20240307-v1:0|text-to-text, image-to-text|🟡|🟡|✅|
|Anthropic|Claude 3 Opus|anthropic.claude-3-opus-20240229-v1:0|text-to-text, image-to-text|🟡|🟡|✅|
|Anthropic|Claude 2.1|anthropic.claude-v2:1|text-to-text|✅|✅|✅|
|Anthropic|Claude 2|anthropic.claude-v2|text-to-text|✅|✅|❌|
|Anthropic|Claude Instant 1.2|anthropic.claude-instant-v1|text-to-text|🟡|🟡|❌|
|Meta|Llama 3.1 405B Instruct|meta.llama3-1-405b-instruct-v1:0|text-to-text|🟡|🟡|❌|
|Meta|Llama 3.1 70B Instruct|meta.llama3-1-70b-instruct-v1:0|text-to-text|🟡|🟡|❌|
|Meta|Llama 3.1 8B Instruct|meta.llama3-1-8b-instruct-v1:0|text-to-text|🟡|🟡|❌|
|Meta|Llama 3 Chat 70B Instruct|meta.llama3-70b-instruct-v1:0|text-to-text|🟡|🟡|❌|
|Meta|Llama 3 Chat 8B Instruct|meta.llama3-8b-instruct-v1:0|text-to-text|🟡|🟡|❌|
|Meta|Llama 2 Chat 70B|meta.llama2-70b-chat-v1|text-to-text|🟡|🟡|❌|
|Meta|Llama 2 Chat 13B|meta.llama2-13b-chat-v1|text-to-text|🟡|🟡|❌|
|Stability AI|SDXL 1.0|stability.stable-diffusion-xl-v1|text-to-image|❌|🟡|❌|
|Amazon|Titan Text G1 - Express|amazon.titan-tg1-large|text-to-text|🟡|🟡|❌|
|AI21 Labs|Jurassic-2 Ultra|ai21.j2-jumbo-instruct|text-to-text|🟡|🟡|❌|
|AI21 Labs|Jurassic-2 Mid|ai21.j2-grande-instruct|text-to-text|🟡|🟡|❌|
|Cohere|Command|cohere.command-text-v14|text-to-text|🟡|🟡|❌|
|Cohere|Command Light|cohere.command-light-text-v14|text-to-text|🟡|🟡|❌|

✅, 🟡, and ❌ denote the degree of support for the given feature:

||`/tokenize`, `/truncate_prompt`, `max_prompt_tokens`|tools/functions|
|---|---|---|
|✅|Fully supported via an official tokenization algorithm|Fully supported via the native tools API or official prompts that enable tools|
|🟡|Partially supported, because the tokenization algorithm wasn't made public by the model vendor.<br>An approximate tokenization algorithm is used instead.<br>It conservatively counts **every byte in the UTF-8 encoding of a string as a single token**.|Partially supported, because the model doesn't support tools natively.<br>Prompt engineering is used instead to emulate tools, which may be unreliable.|
|❌|Not supported|Not supported|
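
Since a byte-level tokenizer emits at most one token per UTF-8 byte, counting bytes yields a safe upper bound on the true token count. A minimal sketch of the approximation (an illustration under that assumption, not the adapter's actual code):

```python
def approximate_token_count(text: str) -> int:
    """Conservative estimate: one token per UTF-8 byte.

    Tokenizers produce at most one token per byte, so this bound
    never undercounts and is safe to use for prompt truncation.
    """
    return len(text.encode("utf-8"))


assert approximate_token_count("hello") == 5  # 5 ASCII bytes
assert approximate_token_count("héllo") == 6  # "é" takes 2 bytes in UTF-8
```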

### Embedding models

10 changes: 2 additions & 8 deletions aidial_adapter_bedrock/app.py
@@ -38,13 +38,7 @@ async def models():


for deployment in ChatCompletionDeployment:
app.add_chat_completion(
deployment.deployment_id,
BedrockChatCompletion(),
)
app.add_chat_completion(deployment.deployment_id, BedrockChatCompletion())

for deployment in EmbeddingsDeployment:
app.add_embeddings(
deployment.deployment_id,
BedrockEmbeddings(),
)
app.add_embeddings(deployment.deployment_id, BedrockEmbeddings())
12 changes: 8 additions & 4 deletions aidial_adapter_bedrock/chat_completion.py
@@ -34,6 +34,7 @@
from aidial_adapter_bedrock.llm.consumer import ChoiceConsumer
from aidial_adapter_bedrock.llm.errors import UserError, ValidationError
from aidial_adapter_bedrock.llm.model.adapter import get_bedrock_adapter
from aidial_adapter_bedrock.llm.truncate_prompt import DiscardedMessages
from aidial_adapter_bedrock.server.exceptions import dial_exception_decorator
from aidial_adapter_bedrock.utils.log_config import app_logger as log
from aidial_adapter_bedrock.utils.not_implemented import is_implemented
@@ -62,7 +63,7 @@ async def chat_completion(self, request: Request, response: Response):
model = await self._get_model(request)
params = ModelParameters.create(request)

discarded_messages: Optional[List[int]] = None
discarded_messages: Optional[DiscardedMessages] = None

async def generate_response(usage: TokenUsage) -> None:
nonlocal discarded_messages
@@ -150,7 +151,7 @@ async def truncate_prompt(
) -> TruncatePromptResponse:
model = await self._get_model(request)

if not is_implemented(model.truncate_prompt):
if not is_implemented(model.compute_discarded_messages):
raise ResourceNotFoundError("The endpoint is not implemented")

outputs: List[TruncatePromptResult] = []
@@ -167,9 +168,12 @@ async def _truncate_prompt_request(
if params.max_prompt_tokens is None:
raise ValidationError("max_prompt_tokens is required")

discarded_messages = await model.truncate_prompt(
discarded_messages = await model.compute_discarded_messages(
params, request.messages
)
return TruncatePromptSuccess(discarded_messages=discarded_messages)

return TruncatePromptSuccess(
discarded_messages=discarded_messages or []
)
except Exception as e:
return TruncatePromptError(error=str(e))
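
For context, a call to the `/truncate_prompt` endpoint implemented above might look as follows. This is a sketch: the URL follows the `SERVER_URL/openai/deployments/DEPLOYMENT_NAME/...` convention from the README, while the request and response field names are guesses based on `outputs: List[TruncatePromptResult]` and `TruncatePromptSuccess` in this diff, so treat both as assumptions.

```python
# Hypothetical client call; field names and URL are assumptions, not a
# documented contract of this adapter.
import requests

deployment = "anthropic.claude-3-haiku-20240307-v1:0"
response = requests.post(
    f"http://localhost:5000/openai/deployments/{deployment}/truncate_prompt",
    json={
        "inputs": [
            {
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "An old question."},
                    {"role": "assistant", "content": "An old answer."},
                    {"role": "user", "content": "The latest question."},
                ],
                "max_prompt_tokens": 50,
            }
        ]
    },
)
# Expected shape (assumption): indices of the messages that must be dropped,
# e.g. {"outputs": [{"status": "success", "discarded_messages": [1, 2]}]}
print(response.json())
```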
16 changes: 16 additions & 0 deletions aidial_adapter_bedrock/deployments.py
@@ -1,4 +1,5 @@
from enum import Enum
from typing import Literal


class ChatCompletionDeployment(str, Enum):
@@ -65,6 +66,21 @@ def from_deployment_id(
return cls(deployment_id)


Claude3Deployment = Literal[
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_SONNET,
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_SONNET_US,
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_SONNET_EU,
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET,
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_US,
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_EU,
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_HAIKU,
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_HAIKU_US,
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_HAIKU_EU,
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_OPUS,
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_OPUS_US,
]


class EmbeddingsDeployment(str, Enum):
AMAZON_TITAN_EMBED_TEXT_V1 = "amazon.titan-embed-text-v1"
AMAZON_TITAN_EMBED_TEXT_V2 = "amazon.titan-embed-text-v2:0"
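
Expressing the Claude 3 family as a `Literal` alias allows both static narrowing and runtime membership checks. A small sketch (the helper below is hypothetical, not part of this diff):

```python
# Hypothetical helper: runtime membership test against the Literal alias.
from typing import get_args


def is_claude3(deployment: ChatCompletionDeployment) -> bool:
    return deployment in get_args(Claude3Deployment)
```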
7 changes: 7 additions & 0 deletions aidial_adapter_bedrock/dial_api/request.py
@@ -5,6 +5,7 @@

from aidial_adapter_bedrock.llm.tools.tools_config import (
ToolsConfig,
ToolsMode,
validate_messages,
)

@@ -44,3 +45,9 @@ def create(cls, request: ChatCompletionRequest) -> "ModelParameters":

def add_stop_sequences(self, stop: List[str]) -> "ModelParameters":
return self.copy(update={"stop": [*self.stop, *stop]})

@property
def tools_mode(self) -> ToolsMode | None:
if self.tool_config is not None:
return self.tool_config.tools_mode
return None
87 changes: 48 additions & 39 deletions aidial_adapter_bedrock/llm/chat_model.py
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from typing import AsyncIterator, Callable, List, Optional
from typing import Any, AsyncIterator, Callable, List, Optional

from aidial_sdk.chat_completion import Message, Role
from pydantic import BaseModel
from typing_extensions import override

import aidial_adapter_bedrock.utils.stream as stream_utils
from aidial_adapter_bedrock.dial_api.request import ModelParameters
@@ -13,10 +14,9 @@
from aidial_adapter_bedrock.llm.tools.emulator import ToolsEmulator
from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig
from aidial_adapter_bedrock.llm.truncate_prompt import (
TruncatePromptError,
DiscardedMessages,
truncate_prompt,
)
from aidial_adapter_bedrock.utils.list import omit_by_indices
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log
from aidial_adapter_bedrock.utils.not_implemented import not_implemented

@@ -51,15 +51,23 @@ async def count_prompt_tokens(
async def count_completion_tokens(self, string: str) -> int: ...

@not_implemented
async def truncate_prompt(
async def compute_discarded_messages(
self, params: ModelParameters, messages: List[Message]
) -> List[int]: ...
) -> DiscardedMessages | None:
"""
The method truncates the list of messages to fit
into the token limit set in `params.max_prompt_tokens`.

If the limit isn't provided, then it returns None.
Otherwise, returns the indices of _discarded_ messages which should be
removed from the list to make the rest fit into the token limit.
"""


class TextCompletionPrompt(BaseModel):
text: str
stop_sequences: List[str]
discarded_messages: Optional[List[int]] = None
discarded_messages: Optional[DiscardedMessages] = None


class TextCompletionAdapter(ChatCompletionAdapter):
@@ -72,7 +80,7 @@ async def predict(
pass

@abstractmethod
def truncate_and_linearize_messages(
async def truncate_and_linearize_messages(
self, messages: List[BaseMessage], max_prompt_tokens: Optional[int]
) -> TextCompletionPrompt:
pass
@@ -88,7 +96,7 @@ def preprocess_messages(self, messages: List[Message]) -> List[Message]:

return messages

def get_text_completion_prompt(
async def get_text_completion_prompt(
self, params: ModelParameters, messages: List[Message]
) -> TextCompletionPrompt:

@@ -97,7 +105,7 @@ def get_text_completion_prompt(
base_messages = tools_emulator.parse_dial_messages(messages)
tool_stop_sequences = tools_emulator.get_stop_sequences()

prompt = self.truncate_and_linearize_messages(
prompt = await self.truncate_and_linearize_messages(
base_messages, params.max_prompt_tokens
)

@@ -113,37 +121,45 @@ async def chat(
messages: List[Message],
) -> None:

prompt = self.get_text_completion_prompt(params, messages)
prompt = await self.get_text_completion_prompt(params, messages)
params.stop = prompt.stop_sequences

if prompt.discarded_messages is not None:
consumer.set_discarded_messages(prompt.discarded_messages)
consumer.set_discarded_messages(prompt.discarded_messages)

log.debug(f"model parameters: {params.json(exclude_none=True)}")
log.debug(f"prompt: {prompt.text!r}")

await self.predict(consumer, params, prompt.text)

async def truncate_prompt(
async def compute_discarded_messages(
self, params: ModelParameters, messages: List[Message]
) -> List[int]:
prompt = self.get_text_completion_prompt(params, messages)
return prompt.discarded_messages or []
) -> DiscardedMessages | None:
prompt = await self.get_text_completion_prompt(params, messages)
return prompt.discarded_messages


def keep_last(messages: List[Any], idx: int) -> bool:
return idx == len(messages) - 1


def default_keep_message(messages: List[BaseMessage], idx: int) -> bool:
"""Keep system messages and the last message."""
return isinstance(messages[idx], SystemMessage) or idx == len(messages) - 1
def keep_last_and_system_messages(
messages: List[BaseMessage], idx: int
) -> bool:
return isinstance(messages[idx], SystemMessage) or keep_last(messages, idx)


def default_partitioner(messages: List[BaseMessage]) -> List[int]:
def trivial_partitioner(messages: List[Any]) -> List[int]:
return [1] * len(messages)


def turn_based_partitioner(messages: List[Any]) -> List[int]:
n = len(messages)
return [2] * (n // 2) + [1] * (n % 2)


class PseudoChatModel(TextCompletionAdapter):
chat_emulator: ChatEmulator
tokenize_string: Callable[[str], int]
chat_emulator: ChatEmulator
partitioner: Callable[[List[BaseMessage]], List[int]]

async def count_prompt_tokens(
@@ -152,43 +168,36 @@ async def count_prompt_tokens(
messages = self.preprocess_messages(messages)
tools_emulator = self.tools_emulator(params.tool_config)
base_messages = tools_emulator.parse_dial_messages(messages)
return self.tokenize_messages(base_messages)
return await self.tokenize_messages(base_messages)

async def count_completion_tokens(self, string: str) -> int:
return self.tokenize_string(string)

def tokenize_messages(self, messages: List[BaseMessage]) -> int:
async def tokenize_messages(self, messages: List[BaseMessage]) -> int:
return self.tokenize_string(self.chat_emulator.display(messages)[0])

def truncate_and_linearize_messages(
@override
async def truncate_and_linearize_messages(
self, messages: List[BaseMessage], max_prompt_tokens: Optional[int]
) -> TextCompletionPrompt:
truncate_result = truncate_prompt(
discarded_messages, messages = await truncate_prompt(
messages=messages,
tokenize_messages=self.tokenize_messages,
keep_message=default_keep_message,
partition_messages=self.partitioner,
tokenizer=self.tokenize_messages,
keep_message=keep_last_and_system_messages,
partitioner=self.partitioner,
model_limit=None,
user_limit=max_prompt_tokens,
)

if isinstance(truncate_result, TruncatePromptError):
raise truncate_result.to_dial_exception()

discarded_messages: set[int] = truncate_result

messages = omit_by_indices(messages, truncate_result)

text, stop_sequences = self.chat_emulator.display(messages)

discarded_messages_list = (
None if max_prompt_tokens is None else list(discarded_messages)
)
if max_prompt_tokens is None:
discarded_messages = None

return TextCompletionPrompt(
text=text,
stop_sequences=stop_sequences,
discarded_messages=discarded_messages_list,
discarded_messages=discarded_messages,
)

@staticmethod
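
To make the truncation machinery concrete: the partitioner splits the message list into groups that are kept or discarded as a unit, `keep_last_and_system_messages` pins the messages that may never be dropped, and the indices of the dropped messages become the `DiscardedMessages` returned by `compute_discarded_messages`. A runnable sketch of just the partitioning semantics, mirroring the helpers in this file with plain strings standing in for messages:

```python
from typing import Any, List


def trivial_partitioner(messages: List[Any]) -> List[int]:
    # Every message forms its own group of size 1.
    return [1] * len(messages)


def turn_based_partitioner(messages: List[Any]) -> List[int]:
    # Pair adjacent messages so user/assistant turns are dropped together.
    n = len(messages)
    return [2] * (n // 2) + [1] * (n % 2)


chat = ["user 1", "assistant 1", "user 2", "assistant 2", "user 3"]
print(trivial_partitioner(chat))     # [1, 1, 1, 1, 1]
print(turn_based_partitioner(chat))  # [2, 2, 1]
```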