Skip to content

Commit

Permalink
fix bug of TritonClient (#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiangning30 committed Feb 1, 2024
1 parent 42c6d26 commit eea6e1c
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions lagent/llms/lmdepoly_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,17 @@ def generate(self,
for status, res, _ in self.chatbot._stream_infer(
self.chatbot._session, prompt, max_tokens, sequence_start,
sequence_end):
if status.value < 0:
break
if status.value == 0:
self.chatbot._session.histories = (
self.chatbot._session.histories +
self.chatbot._session.prompt + self.chatbot._session.response)
# remove stop_words
res = filter_suffix(res, self.gen_params.get('stop_words'))
return res
else:
return ''
status = self.state_map.get(status)
if status < ModelStatusCode.END:
return ''
elif status == ModelStatusCode.END:
self.chatbot._session.histories = (
self.chatbot._session.histories +
self.chatbot._session.prompt +
self.chatbot._session.response)
# remove stop_words
res = filter_suffix(res, self.gen_params.get('stop_words'))
return res

def stream_chat(self,
inputs: List[dict],
Expand All @@ -130,7 +130,7 @@ def stream_chat(self,
tuple(Status, str, int): status, text/chat completion,
generated token number
"""
from lmdeploy.serve.turbomind.chatbot import Session, StatusCode, get_logger
from lmdeploy.serve.turbomind.chatbot import Session, get_logger
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'

Expand Down Expand Up @@ -158,19 +158,19 @@ def stream_chat(self,
for status, res, _ in self.chatbot._stream_infer(
self.chatbot._session, prompt, max_tokens, sequence_start,
sequence_end):
if status == StatusCode.TRITON_STREAM_END: # remove stop_words
status = self.state_map.get(status)
if status < ModelStatusCode.END:
return status, res, _
elif status == ModelStatusCode.END: # remove stop_words
res = filter_suffix(res, self.gen_params.get('stop_words'))
if status.value < 0:
self.chatbot._session.histories = (
self.chatbot._session.histories +
self.chatbot._session.prompt +
self.chatbot._session.response)
yield status, res, _
break
else:
yield self.state_map.get(status), res, _
if status.value == 0:
self.chatbot._session.histories = (
self.chatbot._session.histories +
self.chatbot._session.prompt + self.chatbot._session.response)
yield self.state_map.get(status), res, _
else:
return self.state_map.get(status), res, _
yield status, res, _

def _update_gen_params(self, **kwargs):
import mmengine
Expand Down

0 comments on commit eea6e1c

Please sign in to comment.