Commit 05e311d: role with name (#215)

liujiangning30 committed Jul 30, 2024
1 parent f0ffbf5

Showing 2 changed files with 4 additions and 3 deletions.
lagent/llms/base_api.py (1 addition, 1 deletion)

@@ -122,7 +122,7 @@ def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:
         if merged_prompt.get('fallback_role'):
             merged_prompt = self.roles[self.roles[
                 merged_prompt['fallback_role']]]
-        res = {}
+        res = role_prompt.copy()
         res['role'] = merged_prompt['api_role']
         res['content'] = merged_prompt.get('begin', '')
         res['content'] += role_prompt.get('content', '')
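This change gives the commit its title: building res from a copy of the incoming message rather than an empty dict means extra keys such as name now survive the role-to-API conversion. A minimal sketch of the effect (the roles mapping below is illustrative, not taken from a lagent config):

# Illustrative roles mapping; the real tables come from lagent's templates.
roles = {'user': {'api_role': 'user', 'begin': ''}}

def role2api_role(role_prompt):
    merged_prompt = roles[role_prompt['role']]
    res = role_prompt.copy()  # new behavior: extra keys survive
    res['role'] = merged_prompt['api_role']
    res['content'] = merged_prompt.get('begin', '') + role_prompt.get('content', '')
    return res

msg = {'role': 'user', 'content': 'hello', 'name': 'alice'}
print(role2api_role(msg))
# -> {'role': 'user', 'content': 'hello', 'name': 'alice'}
# With the old `res = {}`, the 'name' key was silently dropped.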
lagent/llms/openai.py (3 additions, 2 deletions)

@@ -118,7 +118,9 @@ def chat(
         gen_params = {**self.gen_params, **gen_params}
         with ThreadPoolExecutor(max_workers=20) as executor:
             tasks = [
-                executor.submit(self._chat, messages, **gen_params)
+                executor.submit(self._chat,
+                                self.template_parser._prompt2api(messages),
+                                **gen_params)
                 for messages in (
                     [inputs] if isinstance(inputs[0], dict) else inputs)
             ]
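With this hunk, each batch entry passes through template_parser._prompt2api before reaching _chat, so worker threads receive API-format messages (including the name field preserved by the base_api.py change above). A rough sketch of the fan-out pattern, with stand-in _prompt2api and _chat functions; the real implementations live in lagent:

from concurrent.futures import ThreadPoolExecutor

def _prompt2api(messages):
    # Stand-in for template_parser._prompt2api; the real parser maps
    # internal role prompts to API-format messages.
    return messages

def _chat(messages, **gen_params):
    # Stand-in for the per-conversation API call made by each worker.
    return 'reply to: ' + messages[-1]['content']

def chat(inputs, **gen_params):
    # Accept one conversation (a list of dicts) or a batch of conversations,
    # mirroring the isinstance check in the diff.
    batch = [inputs] if isinstance(inputs[0], dict) else inputs
    with ThreadPoolExecutor(max_workers=20) as executor:
        tasks = [
            # _prompt2api is evaluated here, in the calling thread;
            # only _chat runs inside the worker pool.
            executor.submit(_chat, _prompt2api(messages), **gen_params)
            for messages in batch
        ]
    return [task.result() for task in tasks]

print(chat([{'role': 'user', 'content': 'hi'}]))  # -> ['reply to: hi']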
@@ -296,7 +298,6 @@ def streaming(raw_response):
         assert isinstance(messages, list)
         gen_params = gen_params.copy()

-
         # Hold out 100 tokens due to potential errors in tiktoken calculation
         max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
         if max_tokens <= 0:
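The second hunk only removes a stray blank line; the surrounding guard, unchanged, caps the completion length at 4096 tokens and, per its comment, holds back 100 tokens as slack for tiktoken counting errors. A standalone sketch of that kind of budget check, assuming a hypothetical prompt_tokens count (the diff itself shows only the cap and the comment):

CONTEXT_CAP = 4096  # cap applied in the diff
HOLDOUT = 100       # slack for tiktoken miscounts, per the diff's comment

def completion_budget(prompt_tokens, max_new_tokens):
    # Hypothetical illustration: leave room for the prompt plus the holdout.
    return min(max_new_tokens, CONTEXT_CAP - prompt_tokens - HOLDOUT)

print(completion_budget(prompt_tokens=500, max_new_tokens=4096))  # -> 3496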
