Skip to content

Commit

Permalink
Fixes for images in Claude
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyaGusev committed Jul 8, 2024
1 parent 7665f7d commit 223123b
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 7 deletions.
5 changes: 3 additions & 2 deletions scripts/anthropic2openai.js
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,13 @@ async function handleRequest(request) {
for (var part of m.content) {
if (part.type == "image_url") {
const image_url = part.image_url.url;
const base64 = image_url.replace("image/jpeg;base64,", "").replace("data:", "");
const base64 = image_url.split(",")[1].replace("data:", "");
const media_type = image_url.split(",")[0].split(";")[0].replace("data:", "");
fixed_parts.unshift({
type: "image",
source: {
type: "base64",
media_type: "image/jpeg",
media_type: media_type,
data: base64
}
})
Expand Down
29 changes: 25 additions & 4 deletions src/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore

from src.provider import LLMProvider
from src.decorators import check_admin
from src.decorators import check_admin, check_creator
from src.localization import Localization
from src.tools import Tool
from src.database import Database
Expand Down Expand Up @@ -161,6 +161,7 @@ def __init__(
("subbuy", self.sub_buy),
("tools", self.toogle_tools),
("history", self.history),
("debug", self.debug),
]
for command, func in commands:
self.dp.message.register(func, Command(command))
Expand Down Expand Up @@ -735,7 +736,7 @@ async def _call_dalle(self, conv_id: str, user_id: int, chat_id: int, placeholde
)
return
image_url: Any = function_response[1]["image_url"]
base64_image = image_url["url"].replace("data:image/jpeg;base64,", "")
base64_image = image_url["url"].split(",")[-1]
image_data = base64.b64decode(base64_image)
input_file = BufferedInputFile(image_data, filename="image.jpeg")
new_message = await self.bot.send_photo(
Expand Down Expand Up @@ -885,6 +886,7 @@ async def generate(self, message: Message) -> None:

history = self._fix_image_roles(history)
history = self._fix_broken_tool_calls(history)
history = self._merge_messages(history)
if tools and "gpt" not in model:
params["tools"] = tools
answer = await self._query_api(provider=provider, messages=history, system_prompt=system_prompt, **params)
Expand Down Expand Up @@ -921,14 +923,33 @@ async def generate(self, message: Message) -> None:
text = self.localization.ERROR.format(admin_username=self.config.admin_user_name, chat_id=chat_id)
await placeholder.edit_text(text)

@check_creator
async def debug(self, message: Message, command: CommandObject) -> None:
assert command.args
conv_id = command.args.strip()
user_id = self.db.get_user_id_by_conv_id(conv_id)
history = self.db.fetch_conversation(conv_id)
model = list({m["model"] for m in history if m["model"]})[0]
provider = self.providers[model]
history = self._prepare_history(history, provider, False)
params = self.db.get_parameters(user_id)
params = provider.params if params is None else params
system_prompt = self.db.get_system_prompt(user_id)
system_prompt = provider.system_prompt if system_prompt is None else system_prompt
placeholder = await message.reply("💬")
history = self._fix_image_roles(history)
history = self._fix_broken_tool_calls(history)
answer = await self._query_api(provider=provider, messages=history, system_prompt=system_prompt, **params)
await placeholder.edit_text(answer[: self.config.output_chunk_size])

@staticmethod
async def _query_api(provider: LLMProvider, messages: ChatMessages, system_prompt: str, **kwargs: Any) -> str:
assert messages
if messages[0]["role"] != "system" and system_prompt.strip():
messages.insert(0, {"role": "system", "content": system_prompt})

print(
provider.model_name,
provider.provider_name,
"####",
len(messages),
"####",
Expand All @@ -943,7 +964,7 @@ async def _query_api(provider: LLMProvider, messages: ChatMessages, system_promp
assert isinstance(chat_completion.choices[0].message.content, str), str(chat_completion)
answer: str = chat_completion.choices[0].message.content
print(
provider.model_name,
provider.provider_name,
"####",
len(messages),
"####",
Expand Down
31 changes: 31 additions & 0 deletions src/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,34 @@ async def wrapped(self: Any, obj: Union[CallbackQuery, Message], *args: Any, **k
return await func(self, obj, *args, **kwargs)

return wrapped


def check_creator(func: Callable[..., Coroutine[Any, Any, Any]]) -> Callable[..., Coroutine[Any, Any, Any]]:
@wraps(func)
async def wrapped(self: Any, obj: Union[CallbackQuery, Message], *args: Any, **kwargs: Any) -> Any:
if isinstance(obj, CallbackQuery):
assert obj.message
assert obj.from_user
chat_id = obj.message.chat.id
user_id = obj.from_user.id
user = obj.from_user
elif isinstance(obj, Message):
assert obj.chat
assert obj.from_user
chat_id = obj.chat.id
user_id = obj.from_user.id
user = obj.from_user
else:
assert False
assert user

if user_id != self.config.admin_user_id:
username = self._get_user_name(user)
await self.bot.send_message(
chat_id=chat_id,
text=self.localization.CREATORS_ONLY.format(username=username),
)
return
return await func(self, obj, *args, **kwargs)

return wrapped
1 change: 1 addition & 0 deletions src/localization.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class Localization:
INCORRECT_EMAIL: str
FILLED_EMAIL: str
ADMINS_ONLY: str
CREATORS_ONLY: str
PAYMENT_URL: str
PAYMENT_CANCEL: str
DALLE_LIMIT: str
Expand Down
1 change: 1 addition & 0 deletions src/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
params: Dict[str, Any] = DEFAULT_PARAMS,
message_count_limit: Dict[str, Any] = DEFAULT_MESSAGE_COUNT_LIMIT,
):
self.provider_name = provider_name
self.model_name = model_name
self.can_handle_images = can_handle_images
self.can_handle_tools = can_handle_tools
Expand Down
2 changes: 1 addition & 1 deletion src/tools/dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def __call__(self, prompt: str, prompt_russian: str) -> Union[str, List[Di
{"type": "text", "text": prompt_russian},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
"image_url": {"url": f"data:image/png;base64,{encoded_image}"},
},
]
return content
Expand Down

0 comments on commit 223123b

Please sign in to comment.