fix: renamed truncate_prompt to compute_discarded_messages; introduces DiscardedMessages type alias; got rid of Set[int] for the discarded messages
adubovik committed Sep 20, 2024
1 parent d4ca3e3 commit fea64c5
Showing 7 changed files with 70 additions and 54 deletions.
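
In short: the per-model truncate_prompt method becomes compute_discarded_messages, its result gains a DiscardedMessages = List[int] alias, and the internal Set[int] results are replaced by ordered lists. A condensed sketch of the new surface (abridged from the diffs below, not a verbatim copy):

# New alias in aidial_adapter_bedrock/llm/truncate_prompt.py:
DiscardedMessages = List[int]

# Renamed abstract method in aidial_adapter_bedrock/llm/chat_model.py:
@not_implemented
async def compute_discarded_messages(
    self, params: ModelParameters, messages: List[Message]
) -> List[int] | None: ...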
4 changes: 2 additions & 2 deletions aidial_adapter_bedrock/chat_completion.py
@@ -150,7 +150,7 @@ async def truncate_prompt(
) -> TruncatePromptResponse:
model = await self._get_model(request)

if not is_implemented(model.truncate_prompt):
if not is_implemented(model.compute_discarded_messages):
raise ResourceNotFoundError("The endpoint is not implemented")

outputs: List[TruncatePromptResult] = []
@@ -167,7 +167,7 @@ async def _truncate_prompt_request(
if params.max_prompt_tokens is None:
raise ValidationError("max_prompt_tokens is required")

discarded_messages = await model.truncate_prompt(
discarded_messages = await model.compute_discarded_messages(
params, request.messages
)

14 changes: 11 additions & 3 deletions aidial_adapter_bedrock/llm/chat_model.py
@@ -48,9 +48,17 @@ async def count_prompt_tokens(
async def count_completion_tokens(self, string: str) -> int: ...

@not_implemented
async def truncate_prompt(
async def compute_discarded_messages(
self, params: ModelParameters, messages: List[Message]
) -> List[int] | None: ...
) -> List[int] | None:
"""
Computes the messages which should be discarded so that the rest
of the messages fit into the token limit set in `params.max_prompt_tokens`.
If the limit isn't provided, returns None.
Otherwise, returns the indices of the messages which should be
removed from the list to make the rest fit into the token limit.
"""


class TextCompletionPrompt(BaseModel):
@@ -120,7 +128,7 @@ async def chat(

await self.predict(consumer, params, prompt.text)

async def truncate_prompt(
async def compute_discarded_messages(
self, params: ModelParameters, messages: List[Message]
) -> List[int] | None:
prompt = await self.get_text_completion_prompt(params, messages)
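
For illustration, a minimal caller-side sketch of the contract documented above (the names model, params, and messages are hypothetical placeholders, not repository code):

discarded = await model.compute_discarded_messages(params, messages)

if discarded is None:
    # params.max_prompt_tokens was not set, so nothing has to be discarded.
    kept = messages
else:
    # `discarded` holds the indices of the messages to drop so that
    # the remaining ones fit into params.max_prompt_tokens.
    kept = [m for i, m in enumerate(messages) if i not in set(discarded)]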
2 changes: 1 addition & 1 deletion aidial_adapter_bedrock/llm/model/claude/v3/adapter.py
@@ -200,7 +200,7 @@ async def count_prompt_tokens(
async def count_completion_tokens(self, string: str) -> int:
return tokenize_text(string)

async def truncate_prompt(
async def compute_discarded_messages(
self, params: DialParameters, messages: List[DialMessage]
) -> List[int] | None:
request = await self._prepare_claude_request(params, messages)
15 changes: 8 additions & 7 deletions aidial_adapter_bedrock/llm/truncate_prompt.py
@@ -67,6 +67,7 @@ def _partition_indexer(chunks: List[int]) -> Callable[[int], List[int]]:


_T = TypeVar("_T")
DiscardedMessages = List[int]


async def truncate_prompt(
@@ -76,12 +77,12 @@ async def truncate_prompt(
partition_messages: Callable[[List[_T]], List[int]],
model_limit: Optional[int],
user_limit: Optional[int],
) -> Tuple[List[int], List[_T]]:
) -> Tuple[DiscardedMessages, List[_T]]:
"""
Returns a list of indices of discarded messages and a list of preserved messages
"""

result = await truncate_prompt_(
result = await compute_discarded_messages(
messages,
tokenize_messages,
keep_message,
@@ -96,14 +97,14 @@ async def truncate_prompt(
return (list(result), omit_by_indices(messages, result))


async def truncate_prompt_(
async def compute_discarded_messages(
messages: List[_T],
tokenize_messages: Callable[[List[_T]], Awaitable[int]],
keep_message: Callable[[List[_T], int], bool],
partition_messages: Callable[[List[_T]], List[int]],
model_limit: Optional[int],
user_limit: Optional[int],
) -> Set[int] | TruncatePromptError:
) -> DiscardedMessages | TruncatePromptError:
if (
user_limit is not None
and model_limit is not None
@@ -115,11 +116,11 @@ async def truncate_prompt_(

if user_limit is None:
if model_limit is None:
return set()
return []

token_count = await tokenize_messages(messages)
if token_count <= model_limit:
return set()
return []

return ModelLimitOverflow(
model_limit=model_limit, token_count=token_count
@@ -162,4 +163,4 @@ async def _tokenize_selected(indices: Set[int]) -> int:
kept_indices.update(chunk_indices)

all_indices = set(range(n))
return all_indices - kept_indices
return sorted(list(all_indices - kept_indices))
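
The helper is generic over the message type; the unit tests below drive it with a word-count tokenizer. A self-contained sketch in the same spirit (the keep_message and partition_messages callbacks are simplified stand-ins rather than the repository's keep_last_and_system_messages and partitioning helpers, and the exact indices returned depend on those callbacks):

import asyncio
from typing import List

from aidial_adapter_bedrock.llm.truncate_prompt import compute_discarded_messages


async def main() -> None:
    messages: List[str] = ["be brief", "hello there", "hi", "what is the weather today"]

    async def tokenize_by_words(msgs: List[str]) -> int:
        # Approximate the token budget by word count, as the unit tests do.
        return sum(len(m.split()) for m in msgs)

    result = await compute_discarded_messages(
        messages=messages,
        tokenize_messages=tokenize_by_words,
        # Simplified stand-in: always keep the first and the last message.
        keep_message=lambda msgs, idx: idx in (0, len(msgs) - 1),
        # Simplified stand-in: every message forms its own partition chunk.
        partition_messages=lambda msgs: [1] * len(msgs),
        model_limit=None,
        user_limit=8,
    )

    # `result` is either a DiscardedMessages list (indices to drop)
    # or a TruncatePromptError describing why truncation is impossible.
    print(result)


asyncio.run(main())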
33 changes: 18 additions & 15 deletions tests/unit_tests/chat_emulation/test_llama2_chat.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Set
from typing import List, Optional

import pytest

@@ -7,8 +7,9 @@
from aidial_adapter_bedrock.llm.message import BaseMessage
from aidial_adapter_bedrock.llm.model.llama.v2 import llama2_config
from aidial_adapter_bedrock.llm.truncate_prompt import (
DiscardedMessages,
TruncatePromptError,
truncate_prompt_,
compute_discarded_messages,
)
from tests.utils.messages import ai, sys, user

@@ -20,11 +21,11 @@ async def truncate_prompt_by_words(
messages: List[BaseMessage],
user_limit: int,
model_limit: Optional[int] = None,
) -> Set[int] | TruncatePromptError:
) -> DiscardedMessages | TruncatePromptError:
async def _tokenize_by_words(messages: List[BaseMessage]) -> int:
return sum(len(msg.content.split()) for msg in messages)

return await truncate_prompt_(
return await compute_discarded_messages(
messages=messages,
tokenize_messages=_tokenize_by_words,
keep_message=keep_last_and_system_messages,
@@ -137,20 +138,22 @@ def test_invalid_last_message():
"However, the system messages and the last user message resulted in 2 tokens. "
"Please reduce the length of the messages or increase the maximum prompt tokens.",
),
(turns_sys, 2, {1, 2, 3, 4}),
(turns_sys, 3, {1, 2, 3, 4}),
(turns_sys, 4, {1, 2}),
(turns_sys, 5, {1, 2}),
(turns_sys, 6, set()),
(turns_no_sys, 1, {0, 1, 2, 3}),
(turns_no_sys, 2, {0, 1, 2, 3}),
(turns_no_sys, 3, {0, 1}),
(turns_no_sys, 4, {0, 1}),
(turns_no_sys, 5, set()),
(turns_sys, 2, [1, 2, 3, 4]),
(turns_sys, 3, [1, 2, 3, 4]),
(turns_sys, 4, [1, 2]),
(turns_sys, 5, [1, 2]),
(turns_sys, 6, []),
(turns_no_sys, 1, [0, 1, 2, 3]),
(turns_no_sys, 2, [0, 1, 2, 3]),
(turns_no_sys, 3, [0, 1]),
(turns_no_sys, 4, [0, 1]),
(turns_no_sys, 5, []),
],
)
async def test_multi_turn_dialogue(
messages: List[BaseMessage], user_limit: int, expected: Set[int] | str
messages: List[BaseMessage],
user_limit: int,
expected: DiscardedMessages | str,
):
discarded_messages = await truncate_prompt_by_words(
messages=messages, user_limit=user_limit
39 changes: 21 additions & 18 deletions tests/unit_tests/test_claude3_truncate_prompt.py
@@ -1,5 +1,5 @@
import math
from typing import List, Set
from typing import List
from unittest.mock import patch

import pytest
@@ -14,6 +14,7 @@
Adapter as Claude_V3,
)
from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig
from aidial_adapter_bedrock.llm.truncate_prompt import DiscardedMessages
from tests.utils.messages import ai, sys, to_sdk_messages, user, user_with_image

_DEPLOYMENT = ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_OPUS
@@ -29,17 +30,17 @@ async def tokenize(
return await _MODEL.count_prompt_tokens(params, messages)


async def truncate_prompt(
async def compute_discarded_messages(
messages: List[Message],
max_prompt_tokens: int | None,
tool_config: ToolsConfig | None = None,
) -> Set[int] | str:
) -> DiscardedMessages | str:
params = ModelParameters(
max_prompt_tokens=max_prompt_tokens, tool_config=tool_config
)

try:
return set(await _MODEL.truncate_prompt(params, messages) or [])
return await _MODEL.compute_discarded_messages(params, messages) or []
except DialException as e:
return e.message

@@ -89,9 +90,11 @@ async def test_one_turn_no_truncation(mock_tokenize_text):

assert await tokenize(messages) == expected_tokens

discarded_messages = await truncate_prompt(messages, expected_tokens)
discarded_messages = await compute_discarded_messages(
messages, expected_tokens
)

assert discarded_messages == set()
assert discarded_messages == []


@pytest.mark.asyncio
Expand All @@ -107,11 +110,11 @@ async def test_one_turn_with_image(mock_tokenize_text):

assert await tokenize(messages) == expected_tokens

truncation = await truncate_prompt(messages, expected_tokens)
truncation = await compute_discarded_messages(messages, expected_tokens)

assert truncation == set()
assert truncation == []

truncation = await truncate_prompt(messages, expected_tokens - 1)
truncation = await compute_discarded_messages(messages, expected_tokens - 1)

assert (
truncation
@@ -135,7 +138,7 @@ async def test_one_turn_with_tools(mock_tokenize_text):

assert await tokenize(messages, _TOOL_CONFIG) == expected_tokens

discarded_messages = await truncate_prompt(
discarded_messages = await compute_discarded_messages(
messages, expected_tokens - 1, _TOOL_CONFIG
)

@@ -159,7 +162,7 @@ async def test_one_turn_overflow(mock_tokenize_text):
11 + (22 + _PER_MESSAGE_TOKENS) + (33 + _PER_MESSAGE_TOKENS)
)

truncation_error = await truncate_prompt(messages, 1)
truncation_error = await compute_discarded_messages(messages, 1)

assert (
truncation_error
@@ -178,7 +181,7 @@ async def test_multiple_system_messages(mock_tokenize_text):
)

with pytest.raises(ValidationError) as exc_info:
await truncate_prompt(messages, 3)
await compute_discarded_messages(messages, 3)

assert exc_info.value.message == (
"System message is only allowed as the first message"
@@ -205,11 +208,11 @@ async def test_truncate_first_turn(mock_tokenize_text):

assert await tokenize(messages) == expected_tokens

discarded_messages = await truncate_prompt(
discarded_messages = await compute_discarded_messages(
messages, (_PER_MESSAGE_TOKENS + 33) + (_PER_MESSAGE_TOKENS + 44)
)

assert discarded_messages == {0, 1}
assert discarded_messages == [0, 1]


@pytest.mark.asyncio
Expand All @@ -224,11 +227,11 @@ async def test_truncate_first_turn_with_system(mock_tokenize_text):
]
)

discarded_messages = await truncate_prompt(
discarded_messages = await compute_discarded_messages(
messages, 11 + (_PER_MESSAGE_TOKENS + 44) + (_PER_MESSAGE_TOKENS + 55)
)

assert discarded_messages == {1, 2}
assert discarded_messages == [1, 2]


@pytest.mark.asyncio
Expand All @@ -242,7 +245,7 @@ async def test_zero_turn_overflow(mock_tokenize_text):

expected_tokens = 11 + (22 + _PER_MESSAGE_TOKENS)

truncation_error = await truncate_prompt(messages, 3)
truncation_error = await compute_discarded_messages(messages, 3)

assert (
truncation_error
@@ -263,7 +266,7 @@ async def test_chat_history_overflow(mock_tokenize_text):

min_possible_tokens = 11 + (44 + _PER_MESSAGE_TOKENS)

truncation_error = await truncate_prompt(messages, 1)
truncation_error = await compute_discarded_messages(messages, 1)

assert (
truncation_error
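
The local compute_discarded_messages wrapper above flattens DialException failures into their message string, so each assertion can match either outcome. A small sketch of consuming that union (hypothetical code mirroring the assertions in this test module):

result = await compute_discarded_messages(messages, max_prompt_tokens=1)

if isinstance(result, str):
    # Truncation was impossible; `result` carries the DialException message.
    print("error:", result)
else:
    # Truncation succeeded; `result` lists the indices of the discarded messages.
    print("discarded:", result)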
17 changes: 9 additions & 8 deletions tests/unit_tests/test_truncate_prompt.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Set
from typing import List, Optional

import pytest

@@ -8,8 +8,9 @@
)
from aidial_adapter_bedrock.llm.message import BaseMessage
from aidial_adapter_bedrock.llm.truncate_prompt import (
DiscardedMessages,
TruncatePromptError,
truncate_prompt_,
compute_discarded_messages,
)
from tests.utils.messages import ai, sys, user

@@ -18,11 +19,11 @@ async def truncate_prompt_by_words(
messages: List[BaseMessage],
user_limit: int,
model_limit: Optional[int] = None,
) -> Set[int] | TruncatePromptError:
) -> DiscardedMessages | TruncatePromptError:
async def _tokenize_by_words(messages: List[BaseMessage]) -> int:
return sum(len(msg.content.split()) for msg in messages)

return await truncate_prompt_(
return await compute_discarded_messages(
messages=messages,
tokenize_messages=_tokenize_by_words,
keep_message=keep_last_and_system_messages,
@@ -44,7 +45,7 @@ async def test_no_truncation():
messages=messages, user_limit=3
)

assert isinstance(discarded_messages, set) and discarded_messages == set()
assert discarded_messages == []


@pytest.mark.asyncio
@@ -61,7 +62,7 @@ async def test_truncation():
messages=messages, user_limit=3
)

assert discarded_messages == {1, 3}
assert discarded_messages == [1, 3]


@pytest.mark.asyncio
@@ -75,7 +76,7 @@ async def test_truncation_with_one_message_left():
messages=messages, user_limit=1
)

assert discarded_messages == {0}
assert discarded_messages == [0]


@pytest.mark.asyncio
@@ -89,7 +90,7 @@ async def test_truncation_with_one_message_accepted_after_second_check():
messages=messages, user_limit=1
)

assert discarded_messages == {0}
assert discarded_messages == [0]


@pytest.mark.asyncio
