Commit 3605cb8
fix dataset_sample & deploy stop_words (#1385)
Jintao-Huang committed Jul 13, 2024
1 parent f8faae8 commit 3605cb8
Showing 3 changed files with 17 additions and 12 deletions.
swift/llm/deploy.py (12 additions, 5 deletions)

@@ -184,14 +184,15 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionRequest], raw_request: Request):
     request_id = request_info['request_id']

     kwargs = {'max_new_tokens': request.max_tokens}
-    for key in ['n', 'stop', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty', 'num_beams']:
+    for key in ['n', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty', 'num_beams']:
         kwargs[key] = getattr(request, key)
     for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
         new_value = getattr(request, key)
         if new_value is None:
             kwargs[key] = getattr(llm_engine.generation_config, key)
         else:
             kwargs[key] = new_value
+    kwargs['stop'] = (llm_engine.generation_config.stop or []) + (getattr(request, 'stop') or [])

     generation_config = VllmGenerationConfig(**kwargs)
     if generation_config.use_beam_search and request.stream:
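Note on the added merge: either operand can be None (no deploy-time stop words, or no per-request stop list), so each side is coerced with `or []` before concatenation. A minimal standalone sketch of the pattern, using a hypothetical helper name:

    def merge_stop_words(config_stop, request_stop):
        # `or []` turns a missing (None) list into an empty one, so the
        # concatenation never raises TypeError.
        return (config_stop or []) + (request_stop or [])

    assert merge_stop_words(None, None) == []
    assert merge_stop_words(['Observation:'], None) == ['Observation:']
    assert merge_stop_words(['Observation:'], ['</s>']) == ['Observation:', '</s>']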
@@ -343,7 +344,7 @@ def __repr__(self) -> str:

 @torch.inference_mode()
 async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionRequest], raw_request: Request):
-    global model, template
+    global model, template, _args
     result = await _prepare_request(request)
     if isinstance(result, JSONResponse):
         return result
@@ -359,8 +360,13 @@ async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionRequest], raw_request: Request):
         new_value = getattr(request, key)
         if new_value is None:
             kwargs[key] = getattr(model.generation_config, key)
+            if key == 'temperature':
+                do_sample = getattr(model.generation_config, 'do_sample')
+                if not do_sample:
+                    kwargs[key] = 0
         else:
             kwargs[key] = new_value
+
     if kwargs['temperature'] == 0:
         kwargs['do_sample'] = False
         kwargs['temperature'] = 1
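Note on this hunk: when the request omits temperature and the model's generation_config has do_sample=False, the inherited temperature is now forced to 0; the existing normalization below it then translates temperature == 0 into do_sample=False with a neutral temperature of 1, which is what transformers expects for greedy decoding. A condensed sketch of the resolution logic (hypothetical helper, not part of the diff):

    def resolve_temperature(request_temperature, config_temperature, config_do_sample):
        # Inherit from the model's generation_config when the request is silent;
        # a non-sampling config is represented by the sentinel value 0.
        t = request_temperature
        if t is None:
            t = config_temperature if config_do_sample else 0
        if t == 0:
            # Greedy decoding: temperature is unused, so reset it to 1.
            return {'do_sample': False, 'temperature': 1}
        return {'temperature': t}

    assert resolve_temperature(None, 0.3, False) == {'do_sample': False, 'temperature': 1}
    assert resolve_temperature(0.7, 0.3, False) == {'temperature': 0.7}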
@@ -374,7 +380,8 @@ async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionRequest], raw_request: Request):
     set_generation_config(model, generation_config)  # inplace
     model.generation_config = _old_generation_config
     request_info['generation_config'] = generation_config
-    request_info.update({'seed': request.seed, 'stop': request.stop, 'stream': request.stream})
+    stop = (_args.stop_words or []) + (getattr(request, 'stop') or [])
+    request_info.update({'seed': request.seed, 'stop': stop, 'stream': request.stream})
     logger.info(request_info)

     created_time = int(time.time())
@@ -397,7 +404,7 @@ async def _generate_full():
             model,
             template,
             **example,
-            stop_words=request.stop,
+            stop_words=stop,
             generation_config=generation_config,
             generation_info=generation_info,
             **adapter_kwargs)
@@ -441,7 +448,7 @@ def _generate_stream():
             model,
             template,
             **example,
-            stop_words=request.stop,
+            stop_words=stop,
             generation_config=generation_config,
             generation_info=generation_info,
             **adapter_kwargs)
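Taken together, the deploy.py changes mean stop words configured at deploy time (--stop_words) are now honored by both the vLLM and PyTorch backends and merged with any per-request stop list, rather than the request list silently replacing them. A hedged client-side sketch against the OpenAI-compatible endpoint (URL and model id are placeholders for whatever the server was deployed with):

    import requests

    # Hypothetical request; the 'stop' list here is appended to the
    # server-side stop words configured at deploy time.
    resp = requests.post(
        'http://127.0.0.1:8000/v1/chat/completions',
        json={
            'model': 'qwen2-7b-instruct',  # placeholder model id
            'messages': [{'role': 'user', 'content': 'Hello!'}],
            'stop': ['Observation:'],
            'max_tokens': 64,
        })
    print(resp.json())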
swift/llm/utils/argument.py (1 addition, 1 deletion)

@@ -1117,7 +1117,7 @@ class InferArguments(ArgumentsBase):
     top_p: float = 0.7
     repetition_penalty: float = 1.
     num_beams: int = 1
-    stop_words: List[str] = None
+    stop_words: List[str] = field(default_factory=list)

     # rope-scaling
     rope_scaling: Literal['linear', 'dynamic'] = None
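This change replaces a None default with field(default_factory=list): dataclasses reject a bare mutable default such as = [], and default_factory gives each instance its own fresh list, so downstream code like deploy.py can rely on stop_words always being a list. A minimal sketch of the behavior:

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class Args:
        # stop_words: List[str] = []  # would raise: mutable default not allowed
        stop_words: List[str] = field(default_factory=list)

    a, b = Args(), Args()
    a.stop_words.append('Observation:')
    assert a.stop_words == ['Observation:']
    assert b.stop_words == []  # each instance gets an independent list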
swift/llm/utils/dataset.py (4 additions, 6 deletions)

@@ -341,12 +341,10 @@ def sample_dataset(dataset: HfDataset, dataset_sample: int, random_state: Optional[RandomState] = None) -> HfDataset:
         return dataset
     if random_state is None:
         random_state = RandomState()
-    # Sample the part that exceeds the length of the dataset.
-    idx = random_state.permutation(len(dataset))[:dataset_sample]
-    dataset_sample -= len(idx)
-    if dataset_sample > 0:
-        idx2 = random_state.choice(len(dataset), dataset_sample)
-        idx = np.concatenate([idx, idx2], axis=0)
+
+    idx_repeat = np.tile(range(len(dataset)), dataset_sample // len(dataset))
+    idx_random = random_state.permutation(len(dataset))[:dataset_sample % len(dataset)]
+    idx = np.concatenate([idx_repeat, idx_random])
     dataset = dataset.select(idx)
     return dataset

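The rewritten sample_dataset fixes oversampling: when dataset_sample exceeds len(dataset), the old code took at most one full permutation and filled the remainder by sampling with replacement, which duplicates rows unevenly. The new code repeats the full index range dataset_sample // len(dataset) times and draws the remaining dataset_sample % len(dataset) indices without replacement, so every row appears either floor or ceil of dataset_sample / len(dataset) times. A standalone sketch of the same index construction (hypothetical helper, no HfDataset required):

    import numpy as np
    from numpy.random import RandomState

    def sample_indices(dataset_len, dataset_sample, random_state):
        # Whole copies of the dataset, in order...
        idx_repeat = np.tile(range(dataset_len), dataset_sample // dataset_len)
        # ...plus a without-replacement draw for the remainder.
        idx_random = random_state.permutation(dataset_len)[:dataset_sample % dataset_len]
        return np.concatenate([idx_repeat, idx_random])

    idx = sample_indices(4, 10, RandomState(42))
    assert len(idx) == 10
    counts = np.bincount(idx, minlength=4)
    assert sorted(counts.tolist()) == [2, 2, 3, 3]  # each row appears 2 or 3 times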