From 3605cb8defd37b4c26ed50287996be1fd3871f86 Mon Sep 17 00:00:00 2001
From: Jintao
Date: Sat, 13 Jul 2024 21:38:40 +0800
Subject: [PATCH] fix dataset_sample & deploy stop_words (#1385)

---
 swift/llm/deploy.py         | 17 ++++++++++++-----
 swift/llm/utils/argument.py |  2 +-
 swift/llm/utils/dataset.py  | 10 ++++------
 3 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/swift/llm/deploy.py b/swift/llm/deploy.py
index d3d3ec3eb..5456f403e 100644
--- a/swift/llm/deploy.py
+++ b/swift/llm/deploy.py
@@ -184,7 +184,7 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionR
     request_id = request_info['request_id']
 
     kwargs = {'max_new_tokens': request.max_tokens}
-    for key in ['n', 'stop', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty', 'num_beams']:
+    for key in ['n', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty', 'num_beams']:
         kwargs[key] = getattr(request, key)
     for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
         new_value = getattr(request, key)
@@ -192,6 +192,7 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionR
             kwargs[key] = getattr(llm_engine.generation_config, key)
         else:
             kwargs[key] = new_value
+    kwargs['stop'] = (llm_engine.generation_config.stop or []) + (getattr(request, 'stop') or [])
 
     generation_config = VllmGenerationConfig(**kwargs)
     if generation_config.use_beam_search and request.stream:
@@ -343,7 +344,7 @@ def __repr__(self) -> str:
 
 @torch.inference_mode()
 async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionRequest], raw_request: Request):
-    global model, template
+    global model, template, _args
     result = await _prepare_request(request)
     if isinstance(result, JSONResponse):
         return result
@@ -359,8 +360,13 @@ async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionReq
         new_value = getattr(request, key)
         if new_value is None:
             kwargs[key] = getattr(model.generation_config, key)
+            if key == 'temperature':
+                do_sample = getattr(model.generation_config, 'do_sample')
+                if not do_sample:
+                    kwargs[key] = 0
         else:
             kwargs[key] = new_value
+
     if kwargs['temperature'] == 0:
         kwargs['do_sample'] = False
         kwargs['temperature'] = 1
@@ -374,7 +380,8 @@ async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionReq
     set_generation_config(model, generation_config)  # inplace
     model.generation_config = _old_generation_config
     request_info['generation_config'] = generation_config
-    request_info.update({'seed': request.seed, 'stop': request.stop, 'stream': request.stream})
+    stop = (_args.stop_words or []) + (getattr(request, 'stop') or [])
+    request_info.update({'seed': request.seed, 'stop': stop, 'stream': request.stream})
     logger.info(request_info)
 
     created_time = int(time.time())
@@ -397,7 +404,7 @@ async def _generate_full():
             model,
             template,
             **example,
-            stop_words=request.stop,
+            stop_words=stop,
             generation_config=generation_config,
             generation_info=generation_info,
             **adapter_kwargs)
@@ -441,7 +448,7 @@ def _generate_stream():
             model,
             template,
             **example,
-            stop_words=request.stop,
+            stop_words=stop,
             generation_config=generation_config,
             generation_info=generation_info,
             **adapter_kwargs)
diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py
index 9053a38fc..7acd45559 100644
--- a/swift/llm/utils/argument.py
+++ b/swift/llm/utils/argument.py
@@ -1117,7 +1117,7 @@ class InferArguments(ArgumentsBase):
     top_p: float = 0.7
     repetition_penalty: float = 1.
     num_beams: int = 1
-    stop_words: List[str] = None
+    stop_words: List[str] = field(default_factory=list)
 
     # rope-scaling
     rope_scaling: Literal['linear', 'dynamic'] = None
diff --git a/swift/llm/utils/dataset.py b/swift/llm/utils/dataset.py
index 7cc5f21c5..39015a5d4 100644
--- a/swift/llm/utils/dataset.py
+++ b/swift/llm/utils/dataset.py
@@ -341,12 +341,10 @@ def sample_dataset(dataset: HfDataset, dataset_sample: int, random_state: Option
         return dataset
     if random_state is None:
         random_state = RandomState()
-    # Sample the part that exceeds the length of the dataset.
-    idx = random_state.permutation(len(dataset))[:dataset_sample]
-    dataset_sample -= len(idx)
-    if dataset_sample > 0:
-        idx2 = random_state.choice(len(dataset), dataset_sample)
-        idx = np.concatenate([idx, idx2], axis=0)
+
+    idx_repeat = np.tile(range(len(dataset)), dataset_sample // len(dataset))
+    idx_random = random_state.permutation(len(dataset))[:dataset_sample % len(dataset)]
+    idx = np.concatenate([idx_repeat, idx_random])
     dataset = dataset.select(idx)
     return dataset
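
Notes on the two behavior changes (reviewer sketches, not part of the patch):

1. Stop-word merging. Both serving paths now concatenate the server-side stop
   words (llm_engine.generation_config.stop for vLLM, _args.stop_words for the
   PyTorch engine) with the per-request `stop` field instead of using the
   request value alone. A minimal sketch of the merge semantics as I read the
   diff; merge_stop_words is a hypothetical helper name, not from the patch:

       def merge_stop_words(server_stop, request_stop):
           # Either side may be None (field unset); `or []` normalizes each to
           # a list before concatenation, so the result is always a list.
           return (server_stop or []) + (request_stop or [])

       assert merge_stop_words(None, None) == []
       assert merge_stop_words(['Observation:'], None) == ['Observation:']
       assert merge_stop_words(['Observation:'], ['</s>']) == ['Observation:', '</s>']

   This pairs with the argument.py change: field(default_factory=list) gives
   each InferArguments instance its own fresh list, so stop_words is a list
   rather than None by default.

2. Oversampling without replacement. The old sample_dataset drew the overflow
   indices with random_state.choice (sampling with replacement), so with
   dataset_sample > len(dataset) some rows could repeat many times while
   others appeared only once per extra pass. The new logic tiles whole
   in-order passes and draws only the remainder without replacement. A
   standalone sketch under that reading; sample_idx and dataset_len are
   illustrative names, not from the patch:

       import numpy as np
       from numpy.random import RandomState

       def sample_idx(dataset_len: int, dataset_sample: int,
                      random_state: RandomState) -> np.ndarray:
           # Whole passes over the dataset, in order: each row appears exactly
           # dataset_sample // dataset_len times.
           idx_repeat = np.tile(range(dataset_len), dataset_sample // dataset_len)
           # Remainder drawn without replacement via a permutation prefix.
           idx_random = random_state.permutation(dataset_len)[:dataset_sample % dataset_len]
           return np.concatenate([idx_repeat, idx_random])

       idx = sample_idx(4, 10, RandomState(0))
       # -> 10 indices: [0 1 2 3 0 1 2 3] plus 2 distinct random indices.
       idx = sample_idx(4, 3, RandomState(0))
       # -> dataset_sample < dataset_len degenerates to a permutation prefix:
       #    3 distinct indices, no duplicates.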