Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
zehuichen123 committed Feb 19, 2024
1 parent 4713c99 commit 5087e59
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
1 change: 0 additions & 1 deletion lagent/llms/base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def _prompt2api(self, prompts: Union[List, str]) -> Tuple[str, bool]:
return res

def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:
# merged_prompt = self.roles[self.roles[role_prompt['role']]]
merged_prompt = self.roles[role_prompt['role']]
if merged_prompt.get('fallback_role'):
merged_prompt = self.roles[self.roles[
Expand Down
8 changes: 2 additions & 6 deletions lagent/llms/base_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ def __init__(self,
top_k: float = None,
temperature: float = 0.8,
repetition_penalty: float = 1.0,
stop_words: Union[List[str], str] = None,
stop_words_id: Union[List[int], int] = None):
stop_words: Union[List[str], str] = None):
self.path = path
self.tokenizer_only = tokenizer_only
# meta template
Expand All @@ -133,16 +132,13 @@ def __init__(self,

if isinstance(stop_words, str):
stop_words = [stop_words]
if isinstance(stop_words_id, int):
stop_words_id = [stop_words_id]
self.gen_params = dict(
max_new_tokens=max_new_tokens,
top_p=top_p,
top_k=top_k,
temperature=temperature,
repetition_penalty=repetition_penalty,
stop_words=stop_words,
stop_words_id=stop_words_id)
stop_words=stop_words)

def generate(self, inputs: Union[str, List[str]], **gen_params) -> str:
"""Generate results given a str (or list of) inputs.
Expand Down
7 changes: 6 additions & 1 deletion lagent/llms/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@ def __init__(self,
tokenizer_only: bool = False,
model_kwargs: dict = dict(device_map='auto'),
meta_template: Optional[Dict] = None,
stop_words_id: Union[List[int], int] = None,
**kwargs):
super().__init__(
path=path,
tokenizer_only=tokenizer_only,
meta_template=meta_template,
**kwargs)
if isinstance(stop_words_id, int):
stop_words_id = [stop_words_id]
self.gen_params.update(stop_words_id=stop_words_id)

self._load_tokenizer(
path=path,
Expand Down Expand Up @@ -316,7 +320,8 @@ def chat(self, inputs: List[dict], do_sample: bool = True, **kwargs):
response, history = self.model.chat(self.tokenizer,
query,
history=history)
except:
except Exception as e:
# handle over-length input error
response = ""
return response

0 comments on commit 5087e59

Please sign in to comment.