Skip to content

Commit

Permalink
Merge commit '3605cb8defd37b4c26ed50287996be1fd3871f86' into release/2.2
Browse files Browse the repository at this point in the history
* commit '3605cb8defd37b4c26ed50287996be1fd3871f86':
  fix dataset_sample & deploy stop_words (#1385)
  update discord and fix documentation link
  update discord and fix documentation link error
  Update discord url which will never expire in README.md
  • Loading branch information
tastelikefeet committed Jul 13, 2024
2 parents 2c8753a + 3605cb8 commit 7163356
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 16 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ SWIFT has rich documentations for users, please feel free to check our documenta
You can contact us and communicate with us by adding our group:


[Discord Group](https://discord.gg/qQXTzNUp) | 微信群
[Discord Group](https://discord.com/invite/D27yfEFVz5) | 微信群
:-------------------------:|:-------------------------:
<img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">

Expand Down Expand Up @@ -226,7 +226,7 @@ docker pull registry.us-west-1.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.

## 🚀 Getting Started

This section introduces basic usage, see the [Documentation](#-documentation) section for more ways to use.
This section introduces basic usage, see the [Documentation](https://swift.readthedocs.io/en/latest/) section for more ways to use.

### Web-UI

Expand Down
4 changes: 2 additions & 2 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ SWIFT具有丰富全面的文档,请查看我们的文档网站:

请扫描下面的二维码来加入我们的交流群:

[Discord Group](https://discord.gg/qQXTzNUp) | 微信群
[Discord Group](https://discord.com/invite/D27yfEFVz5) | 微信群
:-------------------------:|:-------------------------:
<img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">

Expand Down Expand Up @@ -228,7 +228,7 @@ docker pull registry.us-west-1.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.

## 🚀 快速开始

本章节介绍基本使用,更丰富的使用方式请查看[文档部分](#-文档)
本章节介绍基本使用,更丰富的使用方式请查看[文档部分](https://swift.readthedocs.io/zh-cn/latest/)

### Web-UI

Expand Down
Binary file modified asset/discord_qr.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 12 additions & 5 deletions swift/llm/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,15 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionR
request_id = request_info['request_id']

kwargs = {'max_new_tokens': request.max_tokens}
for key in ['n', 'stop', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty', 'num_beams']:
for key in ['n', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty', 'num_beams']:
kwargs[key] = getattr(request, key)
for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
new_value = getattr(request, key)
if new_value is None:
kwargs[key] = getattr(llm_engine.generation_config, key)
else:
kwargs[key] = new_value
kwargs['stop'] = (llm_engine.generation_config.stop or []) + (getattr(request, 'stop') or [])

generation_config = VllmGenerationConfig(**kwargs)
if generation_config.use_beam_search and request.stream:
Expand Down Expand Up @@ -343,7 +344,7 @@ def __repr__(self) -> str:

@torch.inference_mode()
async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionRequest], raw_request: Request):
global model, template
global model, template, _args
result = await _prepare_request(request)
if isinstance(result, JSONResponse):
return result
Expand All @@ -359,8 +360,13 @@ async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionReq
new_value = getattr(request, key)
if new_value is None:
kwargs[key] = getattr(model.generation_config, key)
if key == 'temperature':
do_sample = getattr(model.generation_config, 'do_sample')
if not do_sample:
kwargs[key] = 0
else:
kwargs[key] = new_value

if kwargs['temperature'] == 0:
kwargs['do_sample'] = False
kwargs['temperature'] = 1
Expand All @@ -374,7 +380,8 @@ async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionReq
set_generation_config(model, generation_config) # inplace
model.generation_config = _old_generation_config
request_info['generation_config'] = generation_config
request_info.update({'seed': request.seed, 'stop': request.stop, 'stream': request.stream})
stop = (_args.stop_words or []) + (getattr(request, 'stop') or [])
request_info.update({'seed': request.seed, 'stop': stop, 'stream': request.stream})
logger.info(request_info)

created_time = int(time.time())
Expand All @@ -397,7 +404,7 @@ async def _generate_full():
model,
template,
**example,
stop_words=request.stop,
stop_words=stop,
generation_config=generation_config,
generation_info=generation_info,
**adapter_kwargs)
Expand Down Expand Up @@ -441,7 +448,7 @@ def _generate_stream():
model,
template,
**example,
stop_words=request.stop,
stop_words=stop,
generation_config=generation_config,
generation_info=generation_info,
**adapter_kwargs)
Expand Down
2 changes: 1 addition & 1 deletion swift/llm/utils/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ class InferArguments(ArgumentsBase):
top_p: float = 0.7
repetition_penalty: float = 1.
num_beams: int = 1
stop_words: List[str] = None
stop_words: List[str] = field(default_factory=list)

# rope-scaling
rope_scaling: Literal['linear', 'dynamic'] = None
Expand Down
10 changes: 4 additions & 6 deletions swift/llm/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,12 +341,10 @@ def sample_dataset(dataset: HfDataset, dataset_sample: int, random_state: Option
return dataset
if random_state is None:
random_state = RandomState()
# Sample the part that exceeds the length of the dataset.
idx = random_state.permutation(len(dataset))[:dataset_sample]
dataset_sample -= len(idx)
if dataset_sample > 0:
idx2 = random_state.choice(len(dataset), dataset_sample)
idx = np.concatenate([idx, idx2], axis=0)

idx_repeat = np.tile(range(len(dataset)), dataset_sample // len(dataset))
idx_random = random_state.permutation(len(dataset))[:dataset_sample % len(dataset)]
idx = np.concatenate([idx_repeat, idx_random])
dataset = dataset.select(idx)
return dataset

Expand Down

0 comments on commit 7163356

Please sign in to comment.