Skip to content

Commit

Permalink
Fix: update gen_config in LMDeployPipeline from the input gen_params (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiangning30 committed Feb 5, 2024
1 parent 6a54476 commit 90ef521
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion lagent/llms/lmdepoly_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,19 @@ def generate(self,
Returns:
(a list of/batched) text/chat completion
"""
from lmdeploy.messages import GenerationConfig

batched = True
if isinstance(inputs, str):
inputs = [inputs]
batched = False
prompt = inputs
gen_params = self.update_gen_params(**kwargs)
max_tokens = gen_params.pop('max_tokens')
gen_config = GenerationConfig(**gen_params)
gen_config.max_new_tokens = max_tokens
response = self.model.batch_infer(
prompt, do_preprocess=do_preprocess, **gen_params)
prompt, gen_config=gen_config, do_preprocess=do_preprocess)
response = [resp.text for resp in response]
# remove stop_words
response = filter_suffix(response, self.gen_params.get('stop_words'))
Expand Down

0 comments on commit 90ef521

Please sign in to comment.