Small fixes for Gemma
IlyaGusev committed Sep 20, 2024
1 parent 735ca2b commit e7b37dd
Showing 3 changed files with 53 additions and 18 deletions.
53 changes: 39 additions & 14 deletions src/bot.py
@@ -640,8 +640,8 @@ async def yookassa_check_payments(self) -> None:
await self.bot.send_message(chat_id=payment.chat_id, text=self.localization.SUB_SUCCESS)
self.db.set_payment_status(payment.payment_id, status=status.value, internal_status="completed")
elif status == YookassaStatus.CANCELED:
- await self.bot.send_message(chat_id=payment.chat_id, text=self.localization.PAYMENT_CANCEL)
self.db.set_payment_status(payment.payment_id, status=status.value, internal_status="completed")
+ await self.bot.send_message(chat_id=payment.chat_id, text=self.localization.PAYMENT_CANCEL)

async def set_email(self, message: Message) -> None:
assert message.text
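
Reordering set_payment_status ahead of send_message presumably makes the cancel notification at-most-once: with the old order, a failed Telegram send raised before the status write, so the payment poller would pick the same canceled payment up again on the next pass and try to re-send the message. Writing internal_status="completed" first trades that for, at worst, one lost notification.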
@@ -671,7 +671,7 @@ async def set_temperature_button_handler(self, callback: CallbackQuery) -> None:
provider = self.providers[model]
temperature = float(callback.data.split(":")[1])
params = self.db.get_parameters(chat_id)
- params = provider.params if params is None else params
+ params = copy.deepcopy(provider.params) if params is None else params
params["temperature"] = temperature
self.db.set_parameters(chat_id, **params)
assert isinstance(callback.message, Message)
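
The deepcopy matters here: db.get_parameters returns None for chats that never set parameters, and the old fallback aliased provider.params, so writing "temperature" into the per-chat dict silently mutated the provider-wide defaults for every chat. A minimal standalone sketch of the pitfall (names are illustrative, not from this repo):

    import copy

    provider_params = {"temperature": 0.6, "top_p": 0.9}  # shared defaults

    params = provider_params                 # old code: two names, one dict
    params["temperature"] = 1.2
    print(provider_params["temperature"])    # 1.2 -- shared defaults corrupted

    provider_params = {"temperature": 0.6, "top_p": 0.9}
    params = copy.deepcopy(provider_params)  # new code: independent copy
    params["temperature"] = 1.2
    print(provider_params["temperature"])    # 0.6 -- defaults intact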
@@ -1012,11 +1012,22 @@ async def debug(self, message: Message, command: CommandObject) -> None:
await placeholder.edit_text(answer[: self.config.output_chunk_size])

@staticmethod
- async def _query_api(provider: LLMProvider, messages: ChatMessages, system_prompt: str, **kwargs: Any) -> str:
+ async def _query_api(
+     provider: LLMProvider,
+     messages: ChatMessages,
+     system_prompt: str,
+     num_retries: int = 2,
+     **kwargs: Any
+ ) -> str:
assert messages
if messages[0]["role"] != "system" and system_prompt.strip():
messages.insert(0, {"role": "system", "content": system_prompt})

+ if provider.merge_system and messages[0]["role"] == "system":
+     system_message = messages[0]["content"]
+     messages = messages[1:]
+     messages[0]["content"] = system_message + "\n\n" + messages[0]["content"]

print(
provider.provider_name,
"####",
@@ -1025,13 +1036,23 @@ async def _query_api(provider: LLMProvider, messages: ChatMessages, system_prompt: str, **kwargs: Any) -> str:
_crop_content(messages[-1]["content"]),
)
casted_messages = [cast(ChatCompletionMessageParam, message) for message in messages]
- chat_completion = await provider.api.chat.completions.create(
-     model=provider.model_name, messages=casted_messages, **kwargs
- )
- assert chat_completion.choices, str(chat_completion)
- assert chat_completion.choices[0].message.content, str(chat_completion)
- assert isinstance(chat_completion.choices[0].message.content, str), str(chat_completion)
- answer: str = chat_completion.choices[0].message.content
+ answer: Optional[str] = None
+ for _ in range(num_retries):
+     try:
+         chat_completion = await provider.api.chat.completions.create(
+             model=provider.model_name, messages=casted_messages, **kwargs
+         )
+         assert chat_completion.choices, str(chat_completion)
+         assert chat_completion.choices[0].message.content, str(chat_completion)
+         assert isinstance(chat_completion.choices[0].message.content, str), str(chat_completion)
+         answer = chat_completion.choices[0].message.content
+         break
+     except Exception:
+         traceback.print_exc()
+         continue
+ assert answer
+ if provider.merge_spaces:
+     answer = answer.replace("  ", " ")
print(
provider.provider_name,
"####",
@@ -1137,17 +1158,21 @@ async def save_feedback_handler(self, callback: CallbackQuery) -> None:
)

def _count_tokens(self, messages: ChatMessages, provider: LLMProvider) -> int:
- model_name = provider.model_name
+ tokenizer_name = provider.tokenizer_name
+ if not tokenizer_name:
+     tokenizer_name = provider.model_name
url = str(provider.api.base_url)
tokens_count = 0

if "api.openai.com" in url:
- encoding = tiktoken.encoding_for_model(model_name)
+ if "o1" in tokenizer_name:
+     return 0
+ encoding = tiktoken.encoding_for_model(tokenizer_name)
for m in messages:
if isinstance(m["content"], str):
tokens_count += len(encoding.encode(m["content"]))
elif self._is_image_content(m["content"]):
- tokens_count += 1000
+ tokens_count += 2000
return tokens_count

if "anthropic" in url:
@@ -1156,7 +1181,7 @@ def _count_tokens(self, messages: ChatMessages, provider: LLMProvider) -> int:
tokens_count += len(m["content"]) // 2
return tokens_count

- tokenizer = Tokenizers.get(model_name)
+ tokenizer = Tokenizers.get(tokenizer_name)
tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
tokens_count = len(tokens)
return tokens_count
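
tokenizer_name decouples token counting from the served model name, falling back to model_name when unset. The "o1" early return presumably exists because tiktoken.encoding_for_model raises KeyError for model names it has no mapping for, which was the case for o1 models on tiktoken releases of that era; returning 0 skips counting rather than crashing. A more general guard along the same lines (a sketch, not what the repo does):

    import tiktoken

    def safe_token_count(text: str, tokenizer_name: str) -> int:
        try:
            encoding = tiktoken.encoding_for_model(tokenizer_name)
        except KeyError:
            return 0  # unknown to tiktoken; give up instead of crashing
        return len(encoding.encode(text))

    print(safe_token_count("Hello world", "gpt-4o"))        # a real count
    print(safe_token_count("Hello world", "no-such-model")) # 0

The per-image placeholder also doubles from 1000 to 2000 tokens, apparently a more conservative budget estimate.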
7 changes: 5 additions & 2 deletions src/database.py
@@ -460,9 +460,12 @@ def subscribe_user(self, user_id: int, duration: int) -> None:
session.add(new_subscription)
session.commit()

- def get_all_conv_ids(self) -> List[str]:
+ def get_all_conv_ids(self, min_timestamp: Optional[int] = None) -> List[str]:
with self.Session() as session:
- conversations = session.query(Conversation).all()
+ if min_timestamp is None:
+     conversations = session.query(Conversation).all()
+ else:
+     conversations = session.query(Conversation).filter(Conversation.timestamp >= min_timestamp).all()
return [conv.conv_id for conv in conversations]

def _serialize_content(self, content: Union[None, str, List[Dict[str, Any]]]) -> str:
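
get_all_conv_ids can now prefilter by timestamp instead of loading every conversation id. A usage sketch, assuming Conversation.timestamp holds Unix seconds and db is an instance of this database class (neither is shown in the diff):

    import time

    all_ids = db.get_all_conv_ids()  # old behaviour, still the default

    day_ago = int(time.time()) - 24 * 60 * 60
    recent_ids = db.get_all_conv_ids(min_timestamp=day_ago)  # last 24h only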
11 changes: 9 additions & 2 deletions src/provider.py
@@ -1,4 +1,5 @@
- from typing import Dict, Any
+ import copy
+ from typing import Dict, Any, Optional

from openai import AsyncOpenAI

@@ -27,15 +28,21 @@ def __init__(
history_max_tokens: int = DEFAULT_HISTORY_MAX_TOKENS,
params: Dict[str, Any] = DEFAULT_PARAMS,
message_count_limit: Dict[str, Any] = DEFAULT_MESSAGE_COUNT_LIMIT,
+ tokenizer_name: Optional[str] = None,
+ merge_system: bool = False,
+ merge_spaces: bool = False,
):
self.provider_name = provider_name
self.model_name = model_name
self.can_handle_images = can_handle_images
self.can_handle_tools = can_handle_tools
self.system_prompt = system_prompt
self.history_max_tokens = history_max_tokens
- self.params = params
+ self.params = copy.deepcopy(params)
self.limits = message_count_limit
+ self.tokenizer_name = tokenizer_name
+ self.merge_system = merge_system
+ self.merge_spaces = merge_spaces
assert "standard" in self.limits
assert "subscribed" in self.limits
self.api = AsyncOpenAI(base_url=base_url, api_key=api_key)
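
copy.deepcopy(params) closes the same aliasing hole as the bot.py fix above: DEFAULT_PARAMS is a module-level dict evaluated once at import, so without the copy every provider constructed with the default params would share one mutable object, and editing one provider's params would edit them all. The three new constructor arguments default to None/off, so existing provider configs keep their current behavior.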
