Skip to content

Commit

Permalink
Handle batch inference for chat
Browse files Browse the repository at this point in the history
  • Loading branch information
zehuichen123 committed Feb 19, 2024
1 parent 5087e59 commit 352fbe8
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion lagent/llms/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def __init__(self,
**kwargs):
super().__init__(template_parser=template_parser, **kwargs)

def chat(self, inputs: List[dict], do_sample: bool = True, **kwargs):
def chat(self, inputs: Union[List[dict], List[List[dict]]], do_sample: bool = True, **kwargs):
"""Return the chat completions in stream mode.
Args:
Expand All @@ -313,6 +313,12 @@ def chat(self, inputs: List[dict], do_sample: bool = True, **kwargs):
Returns:
the text/chat completion
"""
# handle batch inference with vanilla for loop
if isinstance(inputs[0], list):
resps = []
for input in inputs:
resps.append(self.chat(input, do_sample, **kwargs))
return resps
prompt = self.template_parser(inputs)
query = prompt[-1]['content']
history = prompt[:-1]
Expand Down

0 comments on commit 352fbe8

Please sign in to comment.