Feat: GPTAPI supports qwen (#218)
liujiangning30 committed Aug 2, 2024
1 parent d37ff17 commit 81e7ace
Showing 1 changed file with 120 additions and 58 deletions.
178 changes: 120 additions & 58 deletions lagent/llms/openai.py
@@ -62,19 +62,13 @@ def __init__(self,
                  openai_api_base: str = OPENAI_API_BASE,
                  proxies: Optional[Dict] = None,
                  **gen_params):
-        if 'top_k' in gen_params:
-            warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
-                          DeprecationWarning)
-            gen_params.pop('top_k')
-
         super().__init__(
             model_type=model_type,
             meta_template=meta_template,
             query_per_second=query_per_second,
             retry=retry,
             **gen_params)
-        self.gen_params.pop('top_k')
-        if not model_type.lower().startswith('internlm'):
-            self.gen_params.pop('skip_special_tokens')
         self.logger = getLogger(__name__)

         if isinstance(key, str):
@@ -146,8 +140,6 @@ def stream_chat(
             raise NotImplementedError('unsupported parameter: max_tokens')
         gen_params = self.update_gen_params(**gen_params)
         gen_params['stream'] = True
-        if not self.model_type.lower().startswith('internlm'):
-            gen_params.pop('session_id', None)

         resp = ''
         finished = False
@@ -157,7 +149,10 @@
             # mapping to role that openai supports
             messages = self.template_parser._prompt2api(inputs)
             for text in self._stream_chat(messages, **gen_params):
-                resp += text
+                if self.model_type.lower().startswith('qwen'):
+                    resp = text
+                else:
+                    resp += text
                 if not resp:
                     continue
                 # remove stop_words
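
Note on the change above: OpenAI-compatible streams emit incremental deltas, so the client concatenates them, whereas the qwen (DashScope) stream, as handled here, delivers the full text generated so far in every event, so the client keeps only the latest snapshot. A minimal sketch of the two accumulation modes, with made-up chunk values:

    # Delta-style stream (OpenAI): each event carries only the new text.
    delta_chunks = ['Hel', 'lo', ' world']
    resp = ''
    for text in delta_chunks:
        resp += text
    assert resp == 'Hello world'

    # Snapshot-style stream (qwen/DashScope, as handled above): each event
    # carries the full text so far, so the latest event wins.
    snapshot_chunks = ['Hel', 'Hello', 'Hello world']
    resp = ''
    for text in snapshot_chunks:
        resp = text
    assert resp == 'Hello world'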
@@ -182,12 +177,12 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
             str: The generated string.
         """
         assert isinstance(messages, list)
-        gen_params = gen_params.copy()

-        # Hold out 100 tokens due to potential errors in tiktoken calculation
-        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
-        if max_tokens <= 0:
-            return ''
+        header, data = self.generate_request_data(
+            model_type=self.model_type,
+            messages=messages,
+            gen_params=gen_params,
+            json_mode=self.json_mode)

         max_num_retries = 0
         while max_num_retries < self.retry:
@@ -207,11 +202,7 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
                         break

             key = self.keys[self.key_ctr]
-
-            header = {
-                'Authorization': f'Bearer {key}',
-                'content-type': 'application/json',
-            }
+            header['Authorization'] = f'Bearer {key}'

             if self.orgs:
                 with Lock():
@@ -222,18 +213,6 @@ def _chat(self, messages: List[dict], **gen_params) -> str:

             response = dict()
             try:
-                gen_params_new = gen_params.copy()
-                data = dict(
-                    model=self.model_type,
-                    messages=messages,
-                    max_tokens=max_tokens,
-                    n=1,
-                    stop=gen_params_new.pop('stop_words'),
-                    frequency_penalty=gen_params_new.pop('repetition_penalty'),
-                    **gen_params_new,
-                )
-                if self.json_mode:
-                    data['response_format'] = {'type': 'json_object'}
                 raw_response = requests.post(
                     self.url,
                     headers=header,
@@ -285,25 +264,36 @@ def streaming(raw_response):
                     decoded = chunk.decode('utf-8')
                     if decoded == 'data: [DONE]':
                         return
-                    if decoded[:6] == 'data: ':
-                        decoded = decoded[6:]
+                    if decoded[:5] == 'data:':
+                        decoded = decoded[5:]
+                        if decoded[0] == ' ':
+                            decoded = decoded[1:]
+                    else:
+                        print(decoded)
+                        continue
                     response = json.loads(decoded)
                     if 'code' in response and response['code'] == -20003:
                         # Context exceeds maximum length
                         yield ''
                         return
-                    choice = response['choices'][0]
+                    if self.model_type.lower().startswith('qwen'):
+                        choice = response['output']['choices'][0]
+                    else:
+                        choice = response['choices'][0]
                     if choice['finish_reason'] == 'stop':
                         return
-                    yield choice['delta']['content']
+                    if self.model_type.lower().startswith('qwen'):
+                        yield choice['message']['content']
+                    else:
+                        yield choice['delta']['content']

         assert isinstance(messages, list)
-        gen_params = gen_params.copy()

-        # Hold out 100 tokens due to potential errors in tiktoken calculation
-        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
-        if max_tokens <= 0:
-            return ''
+        header, data = self.generate_request_data(
+            model_type=self.model_type,
+            messages=messages,
+            gen_params=gen_params,
+            json_mode=self.json_mode)

         max_num_retries = 0
         while max_num_retries < self.retry:
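
The rewritten parser above now tolerates SSE events both with and without a space after the `data:` prefix, and it dispatches on the payload shape: OpenAI chunks carry `choices[0].delta.content`, while DashScope (qwen) chunks nest choices under `output` and carry whole `message` objects. A self-contained sketch with hypothetical event bodies (the payload values are invented for illustration):

    import json

    openai_chunk = 'data: {"choices": [{"delta": {"content": "Hi"}, "finish_reason": null}]}'
    qwen_chunk = 'data:{"output": {"choices": [{"message": {"content": "Hi"}, "finish_reason": "null"}]}}'

    def extract_content(decoded: str, model_type: str) -> str:
        # Same prefix handling as streaming(): strip 'data:' and one optional space.
        if decoded[:5] == 'data:':
            decoded = decoded[5:]
            if decoded[0] == ' ':
                decoded = decoded[1:]
        response = json.loads(decoded)
        # qwen nests choices under 'output' and streams full messages, not deltas.
        if model_type.lower().startswith('qwen'):
            return response['output']['choices'][0]['message']['content']
        return response['choices'][0]['delta']['content']

    print(extract_content(openai_chunk, 'gpt-4'))   # Hi
    print(extract_content(qwen_chunk, 'qwen-max'))  # Hi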
@@ -320,11 +310,7 @@ def streaming(raw_response):
                         break

             key = self.keys[self.key_ctr]
-
-            header = {
-                'Authorization': f'Bearer {key}',
-                'content-type': 'application/json',
-            }
+            header['Authorization'] = f'Bearer {key}'

             if self.orgs:
                 self.org_ctr += 1
@@ -334,18 +320,6 @@ def streaming(raw_response):

             response = dict()
             try:
-                gen_params_new = gen_params.copy()
-                data = dict(
-                    model=self.model_type,
-                    messages=messages,
-                    max_tokens=max_tokens,
-                    n=1,
-                    stop=gen_params_new.pop('stop_words'),
-                    frequency_penalty=gen_params_new.pop('repetition_penalty'),
-                    **gen_params_new,
-                )
-                if self.json_mode:
-                    data['response_format'] = {'type': 'json_object'}
                 raw_response = requests.post(
                     self.url,
                     headers=header,
@@ -378,6 +352,94 @@ def streaming(raw_response):
                 f'{max_num_retries} times. Check the logs for '
                 'details.')

+    def generate_request_data(self,
+                              model_type,
+                              messages,
+                              gen_params,
+                              json_mode=False):
+        """
+        Generates the request data for different model types.
+
+        Args:
+            model_type (str): The type of the model (e.g., 'gpt', 'internlm', 'qwen').
+            messages (list): The list of messages to be sent to the model.
+            gen_params (dict): The generation parameters.
+            json_mode (bool): Flag to determine if the response format should be JSON.
+
+        Returns:
+            tuple: A tuple containing the header and the request data.
+        """
+        # Copy generation parameters to avoid modifying the original dictionary
+        gen_params = gen_params.copy()
+
+        # Hold out 100 tokens due to potential errors in token calculation
+        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
+        if max_tokens <= 0:
+            return '', ''
+
+        # Initialize the header
+        header = {
+            'content-type': 'application/json',
+        }
+
+        # Common parameters processing
+        gen_params['max_tokens'] = max_tokens
+        if 'stop_words' in gen_params:
+            gen_params['stop'] = gen_params.pop('stop_words')
+        if 'repetition_penalty' in gen_params:
+            gen_params['frequency_penalty'] = gen_params.pop(
+                'repetition_penalty')
+
+        # Model-specific processing
+        data = {}
+        if model_type.lower().startswith('gpt'):
+            if 'top_k' in gen_params:
+                warnings.warn(
+                    '`top_k` parameter is deprecated in OpenAI APIs.',
+                    DeprecationWarning)
+                gen_params.pop('top_k')
+            gen_params.pop('skip_special_tokens', None)
+            gen_params.pop('session_id', None)
+            data = {
+                'model': model_type,
+                'messages': messages,
+                'n': 1,
+                **gen_params
+            }
+            if json_mode:
+                data['response_format'] = {'type': 'json_object'}
+        elif model_type.lower().startswith('internlm'):
+            data = {
+                'model': model_type,
+                'messages': messages,
+                'n': 1,
+                **gen_params
+            }
+            if json_mode:
+                data['response_format'] = {'type': 'json_object'}
+        elif model_type.lower().startswith('qwen'):
+            header['X-DashScope-SSE'] = 'enable'
+            gen_params.pop('skip_special_tokens', None)
+            gen_params.pop('session_id', None)
+            if 'frequency_penalty' in gen_params:
+                gen_params['repetition_penalty'] = gen_params.pop(
+                    'frequency_penalty')
+            gen_params['result_format'] = 'message'
+            data = {
+                'model': model_type,
+                'input': {
+                    'messages': messages
+                },
+                'parameters': {
+                    **gen_params
+                }
+            }
+        else:
+            raise NotImplementedError(
+                f'Model type {model_type} is not supported')
+
+        return header, data
+
     def tokenize(self, prompt: str) -> list:
         """Tokenize the input prompt.
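For reviewers: a minimal usage sketch of the new generate_request_data logic for a qwen model, mirroring the branch added above. The DashScope endpoint URL and API key below are illustrative assumptions, not part of this diff; in lagent the request URL comes from the instance's self.url, and the key is injected per request as in _chat/_stream_chat:

    import json
    import requests

    gen_params = dict(max_new_tokens=512, top_p=0.8,
                      stop_words=['<|im_end|>'], repetition_penalty=1.02)

    # What the qwen branch produces: SSE enabled, messages nested under
    # 'input', and generation parameters (with repetition_penalty restored
    # and result_format='message') nested under 'parameters'.
    header = {
        'content-type': 'application/json',
        'X-DashScope-SSE': 'enable',
        'Authorization': 'Bearer <DASHSCOPE_API_KEY>',  # assumed placeholder
    }
    data = {
        'model': 'qwen-max',
        'input': {'messages': [{'role': 'user', 'content': 'Hello'}]},
        'parameters': {
            'max_tokens': min(gen_params.pop('max_new_tokens'), 4096),
            'stop': gen_params.pop('stop_words'),
            'repetition_penalty': gen_params.pop('repetition_penalty'),
            'result_format': 'message',
            **gen_params,  # whatever remains, e.g. top_p
        },
    }
    raw_response = requests.post(
        'https://dashscope.aliyuncs.com/api/v1/services/aigc/'
        'text-generation/generation',  # assumed DashScope endpoint
        headers=header, data=json.dumps(data), stream=True)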
