diff --git a/examples/pytorch/baichuan/default_offload_opt_param.json b/examples/pytorch/baichuan/default_offload_opt_param.json new file mode 100644 index 000000000..b6f982e5a --- /dev/null +++ b/examples/pytorch/baichuan/default_offload_opt_param.json @@ -0,0 +1,44 @@ +{ + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "total_num_steps": "auto", + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "offload_param": { + "device": "cpu", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto" + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 10000000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/examples/pytorch/baichuan/finetune_baichuan.py b/examples/pytorch/baichuan/finetune_baichuan.py new file mode 100644 index 000000000..c7a240f98 --- /dev/null +++ b/examples/pytorch/baichuan/finetune_baichuan.py @@ -0,0 +1,235 @@ +import os +import sys +from dataclasses import dataclass, field + +from transformers import AutoModelForCausalLM, AutoTokenizer + +from modelscope import (EpochBasedTrainer, MsDataset, TrainingArgs, + build_dataset_from_file, snapshot_download) +from modelscope.metainfo import Trainers +from modelscope.preprocessors import TextGenerationTransformersPreprocessor +from modelscope.swift import Swift +from modelscope.swift.lora import LoRAConfig +from modelscope.trainers import build_trainer + +DEFAULT_PAD_TOKEN = '[PAD]' +DEFAULT_EOS_TOKEN = '' +DEFAULT_BOS_TOKEN = '' +DEFAULT_UNK_TOKEN = '' + + +@dataclass(init=False) +class TextGenerationArguments(TrainingArgs): + + trainer: str = field( + default=Trainers.default, metadata={ + 'help': 'The trainer used', + }) + + src_txt: str = field( + default=None, + metadata={ + 'help': 'The source text key of preprocessor', + 'cfg_node': 'preprocessor.src_txt' + }) + + tgt_txt: str = field( + default=None, + metadata={ + 'help': 'The target text key of preprocessor', + 'cfg_node': 'preprocessor.tgt_txt' + }) + + preprocessor: str = field( + default=None, + metadata={ + 'help': 'The preprocessor type', + 'cfg_node': 'preprocessor.type' + }) + + lr_scheduler: str = field( + default=None, + metadata={ + 'help': 'The lr scheduler type', + 'cfg_node': 'train.lr_scheduler.type' + }) + + world_size: int = field( + default=None, + metadata={ + 'help': 'The parallel world size', + 'cfg_node': 'megatron.world_size' + }) + + tensor_model_parallel_size: int = field( + default=None, + metadata={ + 'help': 'The tensor model parallel size', + 'cfg_node': 'megatron.tensor_model_parallel_size' + }) + + use_megatron: bool = field( + default=None, metadata={ + 'help': 'Whether to use MegatronHook', + }) + + bf16: bool = field( + default=False, + metadata={ + 'help': 'Whether to use bf16', + 'cfg_node': 'train.bf16' + }) + + deepspeed: str = field( + default=None, + metadata={ + 'help': 'The location of DeepSpeed json config file.', + }) + + T_max: int = field( + default=None, + metadata={ + 'help': 'The T_max for CosineAnnealingLR', + 'cfg_node': 'train.lr_scheduler.T_max' + }) + + use_lora: int = field( + default=0, + metadata={'help': 'Whether to use lora to train the model.'}, + ) + + lora_rank: int = field( + default=32, + metadata={'help': 'The lora rank'}, + ) + + lora_alpha: int = field( + default=32, + metadata={'help': 'The lora alpha'}, + ) + + lora_dropout: float = field( + default=0.05, + metadata={'help': 'The lora dropout'}, + ) + + device_map: str = field( + default=None, + metadata={ + 'help': 'A map that specifies where each submodule should go.' + }) + + +def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, + model): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +config, args = TextGenerationArguments().parse_cli().to_config() +print(config, args) + + +def cfg_modify_fn(cfg): + if args.use_model_config: + cfg.merge_from_dict(config) + else: + cfg = config + if 'hooks' not in cfg.train: + cfg.train['hooks'] = [] + if args.use_megatron: + cfg.train.hooks.append({'type': 'MegatronHook'}) + if args.deepspeed: + cfg.train.hooks.append({ + 'type': 'DeepspeedHook', + 'config': args.deepspeed, + 'save_zero_checkpoint': True, + 'with_mpu': False, + }) + + return cfg + + +if args.dataset_json_file is None: + train_dataset = MsDataset.load( + args.train_dataset_name, + subset_name=args.train_subset_name, + split=args.train_split, + namespace=args.train_dataset_namespace) + validation_dataset = MsDataset.load( + args.val_dataset_name, + subset_name=args.val_subset_name, + split=args.val_split, + namespace=args.val_dataset_namespace) +else: + train_dataset, validation_dataset = build_dataset_from_file( + args.dataset_json_file) + +model_dir = snapshot_download(args.model) +sys.path.append(model_dir) +model = AutoModelForCausalLM.from_pretrained( + model_dir, trust_remote_code=True, device_map=args.device_map) +cfg_file = os.path.join(model_dir, 'configuration.json') +tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + +special_tokens_dict = dict() +if tokenizer.pad_token is None or tokenizer.pad_token == '': + special_tokens_dict['pad_token'] = DEFAULT_PAD_TOKEN +if tokenizer.eos_token is None or tokenizer.eos_token == '': + special_tokens_dict['eos_token'] = DEFAULT_EOS_TOKEN +if tokenizer.bos_token is None or tokenizer.bos_token == '': + special_tokens_dict['bos_token'] = DEFAULT_BOS_TOKEN +if tokenizer.unk_token is None or tokenizer.unk_token == '': + special_tokens_dict['unk_token'] = DEFAULT_UNK_TOKEN + +smart_tokenizer_and_embedding_resize( + special_tokens_dict=special_tokens_dict, + tokenizer=tokenizer, + model=model, +) + +preprocessor = TextGenerationTransformersPreprocessor( + model_dir, + tokenizer=tokenizer, + src_txt=config.preprocessor.src_txt, + tgt_txt=config.preprocessor.tgt_txt) + +if args.use_lora != 0: + lora_config = LoRAConfig( + replace_modules=['pack'], + rank=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout) + model = model.bfloat16() + Swift.prepare_model(model, lora_config) + +kwargs = dict( + model=model, + cfg_file=cfg_file, + preprocessor=preprocessor, + train_dataset=train_dataset, + eval_dataset=validation_dataset, + seed=args.seed, + cfg_modify_fn=cfg_modify_fn, + # No placement for model, leave the model to `device_map` + device='cpu') + +trainer: EpochBasedTrainer = build_trainer( + name=args.trainer, default_args=kwargs) +trainer.train() diff --git a/examples/pytorch/baichuan/run_train_lora.sh b/examples/pytorch/baichuan/run_train_lora.sh new file mode 100644 index 000000000..2237299df --- /dev/null +++ b/examples/pytorch/baichuan/run_train_lora.sh @@ -0,0 +1,21 @@ +export PYTHONPATH=$PYTHONPATH:./ +torchrun examples/pytorch/baichuan/finetune_baichuan.py \ + --trainer 'text-generation-trainer' \ + --work_dir './tmp' \ + --model 'baichuan-inc/baichuan-7B' \ + --train_dataset_name 'chinese-poetry-collection' \ + --val_dataset_name 'chinese-poetry-collection' \ + --train_split 'train' \ + --val_split 'test' \ + --src_txt 'text1' \ + --tgt_txt 'text2' \ + --max_epochs 1 \ + --per_device_train_batch_size 8 \ + --lr 2e-5 \ + --lr_scheduler 'CosineAnnealingLR' \ + --eval_strategy 'no' \ + --bf16 1 \ + --use_lora 1 \ + --eval_metrics 'text-gen-metric' \ + --T_max 1 \ + --device_map 'auto' \ diff --git a/modelscope/tuners/__init__.py b/examples/pytorch/chatglm6b/__init__.py similarity index 100% rename from modelscope/tuners/__init__.py rename to examples/pytorch/chatglm6b/__init__.py diff --git a/examples/pytorch/chatglm6b/chatglm_trainer.py b/examples/pytorch/chatglm6b/chatglm_trainer.py new file mode 100644 index 000000000..b34563bd0 --- /dev/null +++ b/examples/pytorch/chatglm6b/chatglm_trainer.py @@ -0,0 +1,118 @@ +from typing import Any, Dict, Union + +import numpy as np +import torch +from transformers.deepspeed import is_deepspeed_zero3_enabled + +from modelscope import EpochBasedTrainer, get_logger + +logger = get_logger(__name__) + + +class Seq2SeqTrainer(EpochBasedTrainer): + + def _decode(self, tokens, ignore_pad_token_for_loss=False): + tokens = tokens.cpu().numpy() + if ignore_pad_token_for_loss: + tokens = np.where(tokens != -100, tokens, + self.tokenizer.pad_token_id) + return [ + t for t in self.tokenizer.batch_decode( + tokens, skip_special_tokens=True) if t != '' + ] + + def evaluation_step( + self, + inputs: Dict[str, Union[torch.Tensor, Any]], + ): + has_labels = 'labels' in inputs + # XXX: adapt synced_gpus for fairscale as well + gen_kwargs = self.cfg['gen_kwargs'] + if gen_kwargs.get('max_length') is None and gen_kwargs.get( + 'max_new_tokens') is None: + gen_kwargs['max_length'] = self.model.config.max_length + gen_kwargs['num_beams'] = ( + gen_kwargs['num_beams'] if gen_kwargs.get('num_beams') is not None + else self.model.config.num_beams) + default_synced_gpus = True if is_deepspeed_zero3_enabled() else False + gen_kwargs['synced_gpus'] = ( + gen_kwargs['synced_gpus'] if gen_kwargs.get('synced_gpus') + is not None else default_synced_gpus) + + if 'attention_mask' in inputs: + gen_kwargs['attention_mask'] = inputs.get('attention_mask', None) + if 'position_ids' in inputs: + gen_kwargs['position_ids'] = inputs.get('position_ids', None) + if 'global_attention_mask' in inputs: + gen_kwargs['global_attention_mask'] = inputs.get( + 'global_attention_mask', None) + + # prepare generation inputs + # some encoder-decoder models can have varying encoder's and thus + # varying model input names + if hasattr( + self.model, 'encoder' + ) and self.model.encoder.main_input_name != self.model.main_input_name: + generation_inputs = inputs[self.model.encoder.main_input_name] + else: + generation_inputs = inputs[self.model.main_input_name] + + gen_kwargs['input_ids'] = generation_inputs + gen_kwargs['pad_token_id'] = self.tokenizer.pad_token_id + generated_tokens = self.model.generate(**gen_kwargs) + generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:] + + # in case the batch is shorter than max length, the output should be padded + if gen_kwargs.get('max_length') is not None and generated_tokens.shape[ + -1] < gen_kwargs['max_length']: + generated_tokens = self._pad_tensors_to_max_len( + generated_tokens, gen_kwargs['max_length']) + elif gen_kwargs.get('max_new_tokens' + ) is not None and generated_tokens.shape[-1] < ( + gen_kwargs['max_new_tokens'] + 1): + generated_tokens = self._pad_tensors_to_max_len( + generated_tokens, gen_kwargs['max_new_tokens'] + 1) + + if has_labels: + labels = inputs['labels'] + if gen_kwargs.get('max_length') is not None and labels.shape[ + -1] < gen_kwargs['max_length']: + labels = self._pad_tensors_to_max_len(labels, + gen_kwargs['max_length']) + elif gen_kwargs.get( + 'max_new_tokens') is not None and labels.shape[-1] < ( + gen_kwargs['max_new_tokens'] + 1): + labels = self._pad_tensors_to_max_len( + labels, (gen_kwargs['max_new_tokens'] + 1)) + else: + labels = None + + generated_tokens = [ + ''.join(self._decode(seq, False)) for seq in generated_tokens + ] + inputs['tgts'] = [''.join(self._decode(seq, True)) for seq in labels] + return { + 'preds': generated_tokens, + } + + def _pad_tensors_to_max_len(self, tensor, max_length): + if self.tokenizer is not None and hasattr(self.tokenizer, + 'pad_token_id'): + # If PAD token is not defined at least EOS token has to be defined + pad_token_id = ( + self.tokenizer.pad_token_id if self.tokenizer.pad_token_id + is not None else self.tokenizer.eos_token_id) + else: + if self.model.config.pad_token_id is not None: + pad_token_id = self.model.config.pad_token_id + else: + raise ValueError( + 'Pad_token_id must be set in the configuration of the model, in order to pad tensors' + ) + + padded_tensor = pad_token_id * torch.ones( + (tensor.shape[0], max_length), + dtype=tensor.dtype, + device=tensor.device) + padded_tensor[:, :tensor.shape[-1]] = tensor + return padded_tensor diff --git a/examples/pytorch/chatglm6b/finetune.py b/examples/pytorch/chatglm6b/finetune.py new file mode 100644 index 000000000..3fa73ba01 --- /dev/null +++ b/examples/pytorch/chatglm6b/finetune.py @@ -0,0 +1,385 @@ +import os +from dataclasses import dataclass, field + +import numpy as np +import torch +from chatglm_trainer import Seq2SeqTrainer +from text_generation_metric import TextGenerationMetric +from transformers import DataCollatorForSeq2Seq + +from modelscope import snapshot_download +from modelscope.metainfo import Models +from modelscope.models import Model +from modelscope.msdatasets import MsDataset +from modelscope.swift import Swift +from modelscope.swift.lora import LoRAConfig +from modelscope.trainers.training_args import TrainingArgs +from modelscope.utils.config import ConfigDict +from modelscope.utils.hub import read_config + + +@dataclass(init=False) +class Chatglm6bArguments(TrainingArgs): + ptuning_checkpoint: str = field( + default=None, + metadata={ + 'help': 'The p-tuning checkpoint previously trained.', + }) + + pre_seq_len: int = field( + default=None, metadata={ + 'help': 'The p-tuning sequence length', + }) + + prefix_projection: bool = field( + default=False, metadata={ + 'help': '', + }) + + quantization_bit: int = field( + default=None, metadata={ + 'help': 'Quantized bit', + }) + + prompt_column: str = field( + default=None, + metadata={ + 'help': + 'The name of the column in the datasets containing the full texts (for summarization).' + }, + ) + + response_column: str = field( + default=None, + metadata={ + 'help': + 'The name of the column in the datasets containing the summaries (for summarization).' + }, + ) + + history_column: str = field( + default=None, + metadata={ + 'help': + 'The name of the column in the datasets containing the history of chat.' + }, + ) + + source_prefix: str = field( + default='', + metadata={ + 'help': + 'A prefix to add before every source text (useful for T5 models).' + }) + + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + 'help': + 'Whether to ignore the tokens corresponding to padded labels in the loss computation or not.' + }, + ) + + max_source_length: int = field( + default=1024, + metadata={ + 'help': + ('The maximum total input sequence length after tokenization. Sequences longer ' + 'than this will be truncated, sequences shorter will be padded.') + }, + ) + + max_target_length: int = field( + default=128, + metadata={ + 'help': + ('The maximum total sequence length for target text after tokenization. Sequences longer ' + 'than this will be truncated, sequences shorter will be padded.') + }, + ) + + max_train_samples: int = field( + default=None, + metadata={ + 'help': + ('For debugging purposes or quicker training, truncate the number of training examples to this ' + 'value if set.') + }, + ) + + max_eval_samples: int = field( + default=None, + metadata={ + 'help': + ('For debugging purposes or quicker training, truncate the number of evaluation examples to this ' + 'value if set.') + }, + ) + + preprocessing_num_workers: int = field( + default=None, + metadata={ + 'help': 'The number of processes to use for the preprocessing.' + }, + ) + + use_lora: int = field( + default=0, + metadata={'help': 'Whether to use lora to train the model.'}, + ) + + lora_rank: int = field( + default=32, + metadata={'help': 'The lora rank'}, + ) + + lora_alpha: int = field( + default=32, + metadata={'help': 'The lora alpha'}, + ) + + lora_dropout: float = field( + default=0.05, + metadata={'help': 'The lora alpha'}, + ) + + +args = Chatglm6bArguments(eval_metrics='chatglm').parse_cli() +print(args) +config, _ = args.to_config(ignore_default_config=args.use_model_config) +config.dump('./configuration.json') + +if config['model']['type'] == 'chatglm6b': + from modelscope.models.nlp import ChatGLMTokenizer +else: + from modelscope.models.nlp import ChatGLM2Tokenizer as ChatGLMTokenizer + + +def cfg_modify_fn(cfg): + if args.use_model_config: + cfg.merge_from_dict(config) + else: + cfg = config + if cfg.train.lr_scheduler.type == 'LinearLR': + cfg.train.lr_scheduler['total_iters'] = \ + int(len(train_dataset) / cfg.train.dataloader.batch_size_per_gpu) * cfg.train.max_epochs + cfg['gen_kwargs'] = { + 'do_sample': True, + 'top_p': 0.7, + 'max_length': 512, + 'temperature': 0.95 + } + return cfg + + +train_dataset = MsDataset.load( + args.train_dataset_name, + subset_name=args.train_subset_name, + split=args.train_split) +validation_dataset = MsDataset.load( + args.val_dataset_name, + subset_name=args.val_subset_name, + split=args.val_split) + +model_dir = snapshot_download(args.model) +model_config = read_config(model_dir) +model_config['model'] = ConfigDict({ + 'type': config['model']['type'], +}) + +if config['model']['type'] == 'chatglm6b': + model_config['model']['pre_seq_len'] = args.pre_seq_len + model_config['model']['prefix_projection'] = args.prefix_projection + +tokenizer = ChatGLMTokenizer.from_pretrained(model_dir, trust_remote_code=True) +model = Model.from_pretrained( + model_dir, cfg_dict=model_config, device_map='auto') + +if args.ptuning_checkpoint is not None: + # Evaluation + # Loading extra state dict of prefix encoder + + prefix_state_dict = torch.load( + os.path.join(args.ptuning_checkpoint, 'pytorch_model.bin')) + new_prefix_state_dict = {} + for k, v in prefix_state_dict.items(): + if k.startswith('transformer.prefix_encoder.'): + new_prefix_state_dict[k[len('transformer.prefix_encoder.'):]] = v + model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) + +if args.quantization_bit is not None: + print(f'Quantized to {args.quantization_bit} bit') + model = model.quantize(args.quantization_bit) +if args.pre_seq_len is not None: + # P-tuning v2 + model = model.half() + model.transformer.prefix_encoder.float() +elif not args.use_lora: + # Finetune + model = model.float() + +if args.use_lora != 0: + lora_config = LoRAConfig( + replace_modules=['attention.query_key_value'], + rank=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout) + model = model.bfloat16() + Swift.prepare_model(model, lora_config) + +prefix = args.source_prefix if args.source_prefix is not None else '' + +# Get the column names for input/target. +prompt_column = args.prompt_column +response_column = args.response_column +history_column = args.history_column + +# Temporarily set max_target_length for training. +max_target_length = args.max_target_length + +model_parameters = filter(lambda p: p.requires_grad, model.parameters()) +trainable_params = sum([np.prod(p.size()) for p in model_parameters]) + +model_parameters = filter(lambda p: not p.requires_grad, model.parameters()) +non_trainable_params = sum([np.prod(p.size()) for p in model_parameters]) + +print('trainable_params:{} ({:.2f}%), non_trainable_params:{}'.format( + trainable_params, trainable_params / non_trainable_params * 100, + non_trainable_params)) + + +def preprocess_function_eval(examples): + inputs, targets = [], [] + for i in range(len(examples[prompt_column])): + if examples[prompt_column][i] and examples[response_column][i]: + query = examples[prompt_column][i] + if history_column is None or len(examples[history_column][i]) == 0: + prompt = query + else: + prompt = '' + history = examples[history_column][i] + for turn_idx, (old_query, response) in enumerate(history): + prompt += '[Round {}]\n问:{}\n答:{}\n'.format( + turn_idx, old_query, response) + prompt += '[Round {}]\n问:{}\n答:'.format(len(history), query) + inputs.append(prompt) + targets.append(examples[response_column][i]) + + inputs = [prefix + inp for inp in inputs] + model_inputs = tokenizer( + inputs, + max_length=args.max_source_length, + truncation=True, + padding=True) + labels = tokenizer( + text_target=targets, max_length=max_target_length, truncation=True) + + if args.ignore_pad_token_for_loss: + labels['input_ids'] = [[(lb if lb != tokenizer.pad_token_id else -100) + for lb in label] + for label in labels['input_ids']] + model_inputs['labels'] = labels['input_ids'] + + return model_inputs + + +def preprocess_function_train(examples): + max_seq_length = args.max_source_length + args.max_target_length + + model_inputs = { + 'input_ids': [], + 'labels': [], + } + for i in range(len(examples[prompt_column])): + if examples[prompt_column][i] and examples[response_column][i]: + query, answer = examples[prompt_column][i], examples[ + response_column][i] + + if history_column is None: + prompt = query + else: + prompt = '' + history = examples[history_column][i] + for turn_idx, (old_query, response) in enumerate(history): + prompt += '[Round {}]\n问:{}\n答:{}\n'.format( + turn_idx, old_query, response) + prompt += '[Round {}]\n问:{}\n答:'.format(len(history), query) + + prompt = prefix + prompt + a_ids = tokenizer.encode(text=prompt, add_special_tokens=False) + b_ids = tokenizer.encode(text=answer, add_special_tokens=False) + + if len(a_ids) > args.max_source_length - 1: + a_ids = a_ids[:args.max_source_length - 1] + + if len(b_ids) > args.max_target_length - 2: + b_ids = b_ids[:args.max_target_length - 2] + + input_ids = tokenizer.build_inputs_with_special_tokens( + a_ids, b_ids) + + if config['model']['type'] == 'chatglm6b': + context_length = input_ids.index(tokenizer.bos_token_id) + else: + context_length = len(a_ids) + 2 + mask_position = context_length - 1 + labels = [-100] * context_length + input_ids[mask_position + 1:] + + pad_len = max_seq_length - len(input_ids) + input_ids = input_ids + [tokenizer.pad_token_id] * pad_len + if config['model']['type'] == 'chatglm6b': + labels = labels + [tokenizer.pad_token_id] * pad_len + if args.ignore_pad_token_for_loss: + labels = [(lb if lb != tokenizer.pad_token_id else -100) + for lb in labels] + else: + labels = labels + [-100] * pad_len + + model_inputs['input_ids'].append(input_ids) + model_inputs['labels'].append(labels) + + return model_inputs + + +train_dataset = train_dataset.to_hf_dataset().map( + preprocess_function_train, + batched=True, + num_proc=args.preprocessing_num_workers, + desc='Running tokenizer on train dataset', +) + +validation_dataset = validation_dataset.to_hf_dataset().map( + preprocess_function_eval, + batched=True, + num_proc=args.preprocessing_num_workers, + desc='Running tokenizer on eval dataset', +) + +# Data collator +label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id +data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=None, + padding=False) + +model.gradient_checkpointing_enable() +if config['model']['type'] == 'chatglm6b': + model.enable_input_require_grads() + +# import torch +# model = torch.nn.DataParallel(model).cuda() +trainer = Seq2SeqTrainer( + model=model, + cfg_file='./configuration.json', + train_dataset=train_dataset, + eval_dataset=validation_dataset, + seed=args.seed, + data_collator=data_collator, + remove_unused_data=True, + # No placement for model, leave the model to `device_map` + device='cpu', + cfg_modify_fn=cfg_modify_fn) +trainer.tokenizer = tokenizer +trainer.train() diff --git a/examples/pytorch/chatglm6b/lora_inference.py b/examples/pytorch/chatglm6b/lora_inference.py new file mode 100644 index 000000000..d31ee50df --- /dev/null +++ b/examples/pytorch/chatglm6b/lora_inference.py @@ -0,0 +1,31 @@ +from modelscope import Model, pipeline, read_config +from modelscope.metainfo import Models +from modelscope.swift import Swift +from modelscope.swift.lora import LoRAConfig +from modelscope.utils.config import ConfigDict + +lora_config = LoRAConfig( + replace_modules=['attention.query_key_value'], + rank=32, + lora_alpha=32, + lora_dropout=0.05, + pretrained_weights='./lora_dureader_target/iter_600.pth') + +model_dir = 'ZhipuAI/ChatGLM-6B' +model_config = read_config(model_dir) +model_config['model'] = ConfigDict({ + 'type': Models.chatglm_6b, +}) + +model = Model.from_pretrained(model_dir, cfg_dict=model_config) +model = model.bfloat16() +Swift.prepare_model(model, lora_config) + +pipe = pipeline('chat', model, pipeline_name='chatglm6b-text-generation') + +print( + pipe({ + 'text': + '纵使进入21世纪后,我国教育水平有了明显进步,高考的难度却依旧不容小觑,高考被中国学生和家长定义为改变命运、改写人生脑重要考试,为了这场考试,学生和家长都付出了很多。', + 'history': [] + })) diff --git a/examples/pytorch/chatglm6b/lora_inference_v2.py b/examples/pytorch/chatglm6b/lora_inference_v2.py new file mode 100644 index 000000000..aa86e8904 --- /dev/null +++ b/examples/pytorch/chatglm6b/lora_inference_v2.py @@ -0,0 +1,31 @@ +from modelscope import Model, pipeline, read_config +from modelscope.metainfo import Models +from modelscope.swift import Swift +from modelscope.swift.lora import LoRAConfig +from modelscope.utils.config import ConfigDict + +lora_config = LoRAConfig( + replace_modules=['attention.query_key_value'], + rank=32, + lora_alpha=32, + lora_dropout=0.05, + pretrained_weights='./lora_dureader_target/iter_600.pth') + +model_dir = 'ZhipuAI/chatglm2-6b' +model_config = read_config(model_dir) +model_config['model'] = ConfigDict({ + 'type': Models.chatglm2_6b, +}) + +model = Model.from_pretrained(model_dir, cfg_dict=model_config) +model = model.bfloat16() +Swift.prepare_model(model, lora_config) + +pipe = pipeline('chat', model, pipeline_name='chatglm2_6b-text-generation') + +print( + pipe({ + 'text': + '纵使进入21世纪后,我国教育水平有了明显进步,高考的难度却依旧不容小觑,高考被中国学生和家长定义为改变命运、改写人生脑重要考试,为了这场考试,学生和家长都付出了很多。', + 'history': [] + })) diff --git a/examples/pytorch/chatglm6b/ptuning_inference.py b/examples/pytorch/chatglm6b/ptuning_inference.py new file mode 100644 index 000000000..ab32bec08 --- /dev/null +++ b/examples/pytorch/chatglm6b/ptuning_inference.py @@ -0,0 +1,34 @@ +import torch + +from modelscope import Model, pipeline, read_config +from modelscope.metainfo import Models +from modelscope.utils.config import ConfigDict + +model_dir = 'ZhipuAI/ChatGLM-6B' +model_config = read_config(model_dir) +model_config['model'] = ConfigDict({ + 'type': Models.chatglm_6b, + 'pre_seq_len': 128, + 'prefix_projection': False, +}) + +model = Model.from_pretrained(model_dir, cfg_dict=model_config) +model = model.half() +model.transformer.prefix_encoder.float() +prefix_state_dict = torch.load('./ptuning_dureader_target/iter_900.pth') +new_prefix_state_dict = {} +for k, v in prefix_state_dict.items(): + if k.startswith('transformer.prefix_encoder.'): + new_prefix_state_dict[k[len('transformer.prefix_encoder.'):]] = v +model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) + +pipe = pipeline('chat', model) + +print( + pipe({ + 'text': + '维生素C也叫抗坏血酸,所以它最重要的一个作用是预防坏血病。另外,维生素C在控制感染和愈合伤口方面发挥作用,是一种强大的抗氧化剂,' + '可以中和有害的自由基。维生素C还是合成胶原蛋白的重要营养成分,胶原蛋白是结缔组织中的一种纤维蛋白,它存在于身体的各个系统中:' + '神经系统、免疫系统、骨骼系统、软骨系统、血液系统和其他系统。维生素C有助于产生作用于大脑和神经的多种激素和化学信使。', + 'history': [] + })) diff --git a/examples/pytorch/chatglm6b/run_train_chatglm2_lora_dureader_v2.sh b/examples/pytorch/chatglm6b/run_train_chatglm2_lora_dureader_v2.sh new file mode 100644 index 000000000..d24494cc1 --- /dev/null +++ b/examples/pytorch/chatglm6b/run_train_chatglm2_lora_dureader_v2.sh @@ -0,0 +1,28 @@ +LR=5e-5 + +PYTHONPATH=. python examples/pytorch/chatglm6b/finetune.py \ + --train_dataset_name modelscope/DuReader_robust-QG \ + --val_dataset_name modelscope/DuReader_robust-QG \ + --train_subset_name default \ + --val_subset_name default \ + --train_split train \ + --val_split validation \ + --prompt_column text1 \ + --response_column text2 \ + --model "ZhipuAI/chatglm2-6b" \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 1 \ + --train.optimizer.options.cumulative_iters 1 \ + --max_epochs 2 \ + --save_strategy 'by_step' \ + --save_interval 300 \ + --lr $LR \ + --eval_strategy "by_step" \ + --eval_interval 300 \ + --lr_strategy 'by_step' \ + --task 'chat' \ + --model.type 'chatglm2-6b' \ + --use_lora 1 \ + --work_dir lora_dureader_target \ diff --git a/examples/pytorch/chatglm6b/run_train_lora_adv.sh b/examples/pytorch/chatglm6b/run_train_lora_adv.sh new file mode 100644 index 000000000..cb6a7856f --- /dev/null +++ b/examples/pytorch/chatglm6b/run_train_lora_adv.sh @@ -0,0 +1,24 @@ +LR=5e-5 + +PYTHONPATH=. python examples/pytorch/chatglm6b/finetune.py \ + --train_dataset_name AdvertiseGen/train.json \ + --val_dataset_name AdvertiseGen/dev.json \ + --prompt_column content \ + --response_column summary \ + --model "ZhipuAI/ChatGLM-6B" \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 1 \ + --train.optimizer.options.cumulative_iters 1 \ + --max_epochs 1 \ + --save_strategy 'by_step' \ + --save_interval 1000 \ + --lr $LR \ + --eval_strategy "by_step" \ + --eval_interval 1000 \ + --lr_strategy 'by_step' \ + --task 'chat' \ + --model.type 'chatglm6b' \ + --use_lora 1 \ + --work_dir lora_adv_target \ diff --git a/examples/pytorch/chatglm6b/run_train_lora_dureader.sh b/examples/pytorch/chatglm6b/run_train_lora_dureader.sh new file mode 100644 index 000000000..26cbce15b --- /dev/null +++ b/examples/pytorch/chatglm6b/run_train_lora_dureader.sh @@ -0,0 +1,28 @@ +LR=5e-5 + +PYTHONPATH=. python examples/pytorch/chatglm6b/finetune.py \ + --train_dataset_name modelscope/DuReader_robust-QG \ + --val_dataset_name modelscope/DuReader_robust-QG \ + --train_subset_name default \ + --val_subset_name default \ + --train_split train \ + --val_split validation \ + --prompt_column text1 \ + --response_column text2 \ + --model "ZhipuAI/ChatGLM-6B" \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 1 \ + --train.optimizer.options.cumulative_iters 1 \ + --max_epochs 2 \ + --save_strategy 'by_step' \ + --save_interval 300 \ + --lr $LR \ + --eval_strategy "by_step" \ + --eval_interval 300 \ + --lr_strategy 'by_step' \ + --task 'chat' \ + --model.type 'chatglm6b' \ + --use_lora 1 \ + --work_dir lora_dureader_target \ diff --git a/examples/pytorch/chatglm6b/run_train_ptuning_adv.sh b/examples/pytorch/chatglm6b/run_train_ptuning_adv.sh new file mode 100644 index 000000000..667c0c96d --- /dev/null +++ b/examples/pytorch/chatglm6b/run_train_ptuning_adv.sh @@ -0,0 +1,26 @@ +PRE_SEQ_LEN=128 +LR=2e-2 + +PYTHONPATH=. python examples/pytorch/chatglm6b/finetune.py \ + --train_dataset_name AdvertiseGen/train.json \ + --val_dataset_name AdvertiseGen/dev.json \ + --prompt_column content \ + --response_column summary \ + --model "ZhipuAI/ChatGLM-6B" \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 1 \ + --train.optimizer.options.cumulative_iters 1 \ + --max_epochs 1 \ + --save_strategy 'by_step' \ + --save_interval 1000 \ + --lr $LR \ + --eval_strategy "by_step" \ + --eval_interval 1000 \ + --lr_strategy 'by_step' \ + --task 'chat' \ + --model.type 'chatglm6b' \ + --pre_seq_len $PRE_SEQ_LEN \ + --quantization_bit 4 \ + --work_dir ptuning_adv_target \ diff --git a/examples/pytorch/chatglm6b/run_train_ptuning_dureader.sh b/examples/pytorch/chatglm6b/run_train_ptuning_dureader.sh new file mode 100644 index 000000000..d36ad50ab --- /dev/null +++ b/examples/pytorch/chatglm6b/run_train_ptuning_dureader.sh @@ -0,0 +1,30 @@ +PRE_SEQ_LEN=128 +LR=2e-2 + +PYTHONPATH=. python examples/pytorch/chatglm6b/finetune.py \ + --train_dataset_name modelscope/DuReader_robust-QG \ + --val_dataset_name modelscope/DuReader_robust-QG \ + --train_subset_name default \ + --val_subset_name default \ + --train_split train \ + --val_split validation \ + --prompt_column text1 \ + --response_column text2 \ + --model "ZhipuAI/ChatGLM-6B" \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 1 \ + --train.optimizer.options.cumulative_iters 1 \ + --max_epochs 3 \ + --save_strategy 'by_step' \ + --save_interval 300 \ + --lr $LR \ + --eval_strategy "by_step" \ + --eval_interval 300 \ + --lr_strategy 'by_step' \ + --task 'chat' \ + --model.type 'chatglm6b' \ + --pre_seq_len $PRE_SEQ_LEN \ + --quantization_bit 4 \ + --work_dir ptuning_dureader_target \ diff --git a/examples/pytorch/chatglm6b/text_generation_metric.py b/examples/pytorch/chatglm6b/text_generation_metric.py new file mode 100644 index 000000000..2083453a7 --- /dev/null +++ b/examples/pytorch/chatglm6b/text_generation_metric.py @@ -0,0 +1,85 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Dict, Iterable, List + +import jieba +import numpy as np +from nltk.translate.bleu_score import (SmoothingFunction, corpus_bleu, + sentence_bleu) +from rouge import Rouge + +from modelscope.metainfo import Metrics +from modelscope.metrics.base import Metric +from modelscope.metrics.builder import METRICS, MetricKeys +from modelscope.utils.chinese_utils import rebuild_chinese_str +from modelscope.utils.registry import default_group + + +@METRICS.register_module(group_key=default_group, module_name='chatglm') +class TextGenerationMetric(Metric): + + def __init__(self, target_text='tgts', pred_text='preds'): + self.preds: List[str] = [] + self.tgts: List[str] = [] + self.rouge = Rouge() + self.target_text = target_text + self.pred_text = pred_text + + def add(self, outputs: Dict[str, List[str]], inputs: Dict[str, List[str]]): + ground_truths = inputs[self.target_text] + eval_results = outputs[self.pred_text] + for truth in ground_truths: + self.tgts.append(truth) + for result in eval_results: + self.preds.append(result) + + def _check(self, pred: str, tgt: str) -> bool: + + def remove_useless(string: str) -> str: + return string.replace(' ', '').replace('.', '') + + return len(remove_useless(pred)) != 0 and len(remove_useless(tgt)) != 0 + + def evaluate(self): + preds, labels = self.preds, self.tgts + if isinstance(preds, tuple): + preds = preds[0] + + score_dict = { + 'rouge-1': [], + 'rouge-2': [], + 'rouge-l': [], + 'bleu-4': [] + } + for pred, label in zip(preds, labels): + hypothesis = list(jieba.cut(pred)) + if len(hypothesis) == 0: + hypothesis = [''] + reference = list(jieba.cut(label)) + rouge = Rouge() + scores = rouge.get_scores(' '.join(hypothesis), + ' '.join(reference)) + result = scores[0] + + for k, v in result.items(): + score_dict[k].append(round(v['f'] * 100, 4)) + bleu_score = sentence_bleu( + [list(label)], + list(pred), + smoothing_function=SmoothingFunction().method3) + score_dict['bleu-4'].append(round(bleu_score * 100, 4)) + + for k, v in score_dict.items(): + score_dict[k] = float(np.mean(v)) + return score_dict + + def merge(self, other: 'TextGenerationMetric'): + self.preds.extend(other.preds) + self.tgts.extend(other.tgts) + + def __getstate__(self): + return self.preds, self.tgts + + def __setstate__(self, state): + self.__init__() + self.preds, self.tgts = state diff --git a/examples/pytorch/llama/finetune_llama.py b/examples/pytorch/llama/finetune_llama.py index 88975e664..f3fda96e8 100644 --- a/examples/pytorch/llama/finetune_llama.py +++ b/examples/pytorch/llama/finetune_llama.py @@ -4,14 +4,10 @@ import copy import logging import os -import shutil -import tempfile -import unittest from dataclasses import dataclass, field import json import torch -import utils from modelscope import TrainingArgs from modelscope.hub.snapshot_download import snapshot_download @@ -19,6 +15,8 @@ from modelscope.models.nlp.llama import LlamaForTextGeneration, LlamaTokenizer from modelscope.msdatasets.dataset_cls.custom_datasets.torch_custom_dataset import \ TorchCustomDataset +from modelscope.swift import Swift +from modelscope.swift.lora import LoRAConfig from modelscope.trainers import build_trainer IGNORE_INDEX = -100 @@ -54,11 +52,35 @@ class TextGenerationArguments(TrainingArgs): 'help': 'The location of DeepSpeed json config file.', }) - work_dir: str = field( - default=None, metadata={ - 'help': 'The location of work dir', + use_lora: int = field( + default=0, + metadata={'help': 'Whether to use lora to train the model.'}, + ) + + lora_rank: int = field( + default=32, + metadata={'help': 'The lora rank'}, + ) + + lora_alpha: int = field( + default=32, + metadata={'help': 'The lora alpha'}, + ) + + lora_dropout: float = field( + default=0.05, + metadata={'help': 'The lora dropout'}, + ) + + device_map: str = field( + default=None, + metadata={ + 'help': 'A map that specifies where each submodule should go.' }) + zero_stage: int = field( + default=None, metadata={'help': 'The stage of zero_optimization'}) + def _tokenize_fn(strings, tokenizer): """Tokenize a list of strings.""" @@ -211,12 +233,15 @@ def cfg_modify_fn(cfg): cfg.train.dataloader = {'batch_size_per_gpu': 4, 'workers_per_gpu': 1} if 'hooks' not in cfg.train: cfg.train['hooks'] = [] - cfg.train.hooks.append({ - 'type': 'DeepspeedHook', - 'config': args.deepspeed, - 'save_zero_checkpoint': True, - 'with_mpu': False, - }) + if args.deepspeed is not None: + cfg.train.hooks.append({ + 'type': 'DeepspeedHook', + 'config': args.deepspeed, + 'save_zero_checkpoint': True, + 'with_mpu': False, + }) + if args.zero_stage is not None: + cfg.train.hooks[-1]['zero_stage'] = args.zero_stage cfg.preprocessor.sequence_length = 512 return cfg @@ -225,7 +250,17 @@ def cfg_modify_fn(cfg): args.model) else snapshot_download(args.model) data_path = args.src_txt if args.src_txt else os.path.join( model_path, 'alpaca_data.json') - model = LlamaForTextGeneration.from_pretrained(model_path) + model = LlamaForTextGeneration.from_pretrained( + model_path, device_map=args.device_map) + + if args.use_lora != 0: + lora_config = LoRAConfig( + replace_modules=['q_proj', 'k_proj', 'v_proj'], + rank=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout) + model = model.bfloat16() + Swift.prepare_model(model, lora_config) tokenizer = LlamaTokenizer.from_pretrained( model_path, @@ -234,10 +269,14 @@ def cfg_modify_fn(cfg): ) special_tokens_dict = dict() - special_tokens_dict['pad_token'] = DEFAULT_PAD_TOKEN - special_tokens_dict['eos_token'] = DEFAULT_EOS_TOKEN - special_tokens_dict['bos_token'] = DEFAULT_BOS_TOKEN - special_tokens_dict['unk_token'] = DEFAULT_UNK_TOKEN + if tokenizer.pad_token is None or tokenizer.pad_token == '': + special_tokens_dict['pad_token'] = DEFAULT_PAD_TOKEN + if tokenizer.eos_token is None or tokenizer.eos_token == '': + special_tokens_dict['eos_token'] = DEFAULT_EOS_TOKEN + if tokenizer.bos_token is None or tokenizer.bos_token == '': + special_tokens_dict['bos_token'] = DEFAULT_BOS_TOKEN + if tokenizer.unk_token is None or tokenizer.unk_token == '': + special_tokens_dict['unk_token'] = DEFAULT_UNK_TOKEN smart_tokenizer_and_embedding_resize( special_tokens_dict=special_tokens_dict, @@ -261,3 +300,19 @@ def cfg_modify_fn(cfg): trainer = build_trainer( name=Trainers.text_generation_trainer, default_args=kwargs) trainer.train() + + # prepare for inference + if args.deepspeed and int(os.environ.get('LOCAL_RANK', 0)) == 0: + tokenizer.save_pretrained(os.path.join(args.work_dir, 'output')) + os.system(f'rm {args.work_dir}/output/pytorch_model*') + os.system( + f'python3 {args.work_dir}/zero_to_fp32.py {args.work_dir} {args.work_dir}/output/pytorch_model.bin' + ) + os.system( + f'cp {model_path}/configuration.json {args.work_dir}/output/configuration.json' + ) + with open(f'{model_path}/config.json', 'r') as f: + config = json.load(f) + config['vocab_size'] = len(tokenizer) + with open(f'{args.work_dir}/output/config.json', 'w') as f: + json.dump(config, f) diff --git a/examples/pytorch/llama/run_train_lora.sh b/examples/pytorch/llama/run_train_lora.sh new file mode 100644 index 000000000..264243251 --- /dev/null +++ b/examples/pytorch/llama/run_train_lora.sh @@ -0,0 +1,11 @@ +DATA_PARALLEL_SIZE=1 + + +export PYTHONPATH=$PYTHONPATH:./ +torchrun --nproc_per_node $DATA_PARALLEL_SIZE examples/pytorch/llama/finetune_llama.py \ + --work_dir './tmp' \ + --model 'skyline2006/llama-7b' \ + --deepspeed 'default_offload_opt_param.json' \ + --eval_interval 100 \ + --use_lora 1 \ + --zero_stage 2 \ diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index 2f14976da..a17a05380 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -44,8 +44,8 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, DEFAULT_MODEL_REVISION, DEFAULT_REPOSITORY_REVISION, - MASTER_MODEL_BRANCH, DatasetFormations, - DatasetMetaFormats, + MASTER_MODEL_BRANCH, META_FILES_FORMAT, + DatasetFormations, DatasetMetaFormats, DatasetVisibilityMap, DownloadChannel, DownloadMode, ModelFile, VirgoDatasetConfig) @@ -643,42 +643,57 @@ def get_dataset_meta_files_local_paths(self, dataset_name: str, return local_paths, dataset_formation @staticmethod - def fetch_csv_from_url(url, out_path, chunk_size=100000, mode=DownloadMode.REUSE_DATASET_IF_EXISTS): - from io import StringIO + def fetch_meta_files_from_url(url, out_path, chunk_size=1024, mode=DownloadMode.REUSE_DATASET_IF_EXISTS): + """ + Fetch the meta-data files from the url, e.g. csv/jsonl files. + """ import hashlib + import json + from tqdm import tqdm out_path = os.path.join(out_path, hashlib.md5(url.encode(encoding='UTF-8')).hexdigest()) if mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists(out_path): os.remove(out_path) if os.path.exists(out_path): - logger.info(f'Reusing cached meta-csv file: {out_path}') + logger.info(f'Reusing cached meta-data file: {out_path}') return out_path cookies = ModelScopeConfig.get_cookies() # Make the request and get the response content as TextIO - logger.info('Loading meta-csv file ...') - - response = requests.get(url, cookies=cookies) - data = StringIO(response.text) - - # Use read_csv with the TextIO object - csv_file_reader = pd.read_csv(data, iterator=True, dtype=str, delimiter=None) + logger.info('Loading meta-data file ...') + response = requests.get(url, cookies=cookies, stream=True) + total_size = int(response.headers.get('content-length', 0)) + progress = tqdm(total=total_size, dynamic_ncols=True) + + def get_chunk(resp): + chunk_data = [] + for data in resp.iter_lines(): + data = data.decode('utf-8') + chunk_data.append(data) + if len(chunk_data) >= chunk_size: + yield chunk_data + chunk_data = [] + yield chunk_data - loop = True iter_num = 0 - while loop: - try: - chunk = csv_file_reader.get_chunk(size=chunk_size) - logger.info(f'Receiving chunk {iter_num}, shape: {chunk.shape}') - if iter_num == 0: - with_header = True + with open(out_path, 'a') as f: + for chunk in get_chunk(response): + progress.update(len(chunk)) + if url.endswith('jsonl'): + chunk = [json.loads(line) for line in chunk if line.strip()] + if len(chunk) == 0: + continue + if iter_num == 0: + with_header = True + else: + with_header = False + chunk_df = pd.DataFrame(chunk) + chunk_df.to_csv(f, index=False, header=with_header) + iter_num += 1 else: - with_header = False - - chunk.to_csv(out_path, mode='a', index=False, header=with_header) - iter_num += 1 - except StopIteration: - loop = False - logger.info('stop chunk iteration') + # csv or others + for line in chunk: + f.write(line + '\n') + progress.close() return out_path @@ -688,7 +703,7 @@ def get_dataset_file_url( dataset_name: str, namespace: str, revision: Optional[str] = DEFAULT_DATASET_REVISION): - if file_name.endswith('.csv'): + if file_name and os.path.splitext(file_name)[-1] in META_FILES_FORMAT: file_name = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ f'Revision={revision}&FilePath={file_name}' return file_name diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index d3365b7c0..2b5bca26e 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -165,6 +165,8 @@ class Models(object): doc2bot = 'doc2bot' peer = 'peer' llama = 'llama' + chatglm_6b = 'chatglm6b' + chatglm2_6b = 'chatglm2-6b' # audio models sambert_hifigan = 'sambert-hifigan' @@ -183,7 +185,9 @@ class Models(object): ecapa_tdnn_sv = 'ecapa-tdnn-sv' campplus_sv = 'cam++-sv' eres2net_sv = 'eres2net-sv' + eres2net_aug_sv = 'eres2net-aug-sv' scl_sd = 'scl-sd' + cluster_backend = 'cluster-backend' rdino_tdnn_sv = 'rdino_ecapa-tdnn-sv' generic_lm = 'generic-lm' @@ -487,6 +491,7 @@ class Pipelines(object): speaker_verification_rdino = 'speaker-verification-rdino' speaker_verification_eres2net = 'speaker-verification-eres2net' speaker_change_locating = 'speaker-change-locating' + segmentation_clustering = 'segmentation-clustering' lm_inference = 'language-score-prediction' speech_timestamp_inference = 'speech-timestamp-inference' diff --git a/modelscope/metrics/text_generation_metric.py b/modelscope/metrics/text_generation_metric.py index 95947d3e5..8ad65eafe 100644 --- a/modelscope/metrics/text_generation_metric.py +++ b/modelscope/metrics/text_generation_metric.py @@ -1,6 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Dict, Iterable, List +import sys +from contextlib import contextmanager +from typing import Dict, Iterable, List, Tuple from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu from rouge import Rouge @@ -55,7 +57,8 @@ def evaluate(self): def mean(iter: Iterable) -> float: return sum(iter) / len(self.preds) - rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts) + with extend_recursion_limit(preds, tgts): + rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts) rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores)) rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores)) @@ -87,3 +90,14 @@ def __getstate__(self): def __setstate__(self, state): self.__init__() self.preds, self.tgts = state + + +@contextmanager +def extend_recursion_limit(preds: Tuple[str], tgts: Tuple[str]): + origin_limit = sys.getrecursionlimit() + new_limit = max(len(pred) + for pred in preds) * max(len(tgt) for tgt in tgts) + if new_limit > origin_limit: + sys.setrecursionlimit(new_limit) + yield + sys.setrecursionlimit(origin_limit) diff --git a/modelscope/models/audio/sv/DTDNN.py b/modelscope/models/audio/sv/DTDNN.py index d86d67994..4fc7fedc8 100644 --- a/modelscope/models/audio/sv/DTDNN.py +++ b/modelscope/models/audio/sv/DTDNN.py @@ -4,6 +4,7 @@ from collections import OrderedDict from typing import Any, Dict, Union +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -17,6 +18,7 @@ TDNNLayer, TransitLayer, get_nonlinear) from modelscope.utils.constant import Tasks +from modelscope.utils.device import create_device class FCM(nn.Module): @@ -162,34 +164,41 @@ def __init__(self, model_dir, model_config: Dict[str, Any], *args, self.feature_dim = self.model_config['fbank_dim'] self.emb_size = self.model_config['emb_size'] + self.device = create_device(self.other_config['device']) self.embedding_model = CAMPPlus(self.feature_dim, self.emb_size) - pretrained_model_name = kwargs['pretrained_model'] self.__load_check_point(pretrained_model_name) + self.embedding_model.to(self.device) self.embedding_model.eval() def forward(self, audio): - assert len(audio.shape) == 2 and audio.shape[ - 0] == 1, 'modelscope error: the shape of input audio to model needs to be [1, T]' - # audio shape: [1, T] + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio) + if len(audio.shape) == 1: + audio = audio.unsqueeze(0) + assert len( + audio.shape + ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]' + # audio shape: [N, T] feature = self.__extract_feature(audio) - embedding = self.embedding_model(feature) - - return embedding + embedding = self.embedding_model(feature.to(self.device)) + return embedding.detach().cpu() def __extract_feature(self, audio): - feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim) - feature = feature - feature.mean(dim=0, keepdim=True) - feature = feature.unsqueeze(0) - return feature - - def __load_check_point(self, pretrained_model_name, device=None): - if not device: - device = torch.device('cpu') + features = [] + for au in audio: + feature = Kaldi.fbank( + au.unsqueeze(0), num_mel_bins=self.feature_dim) + feature = feature - feature.mean(dim=0, keepdim=True) + features.append(feature.unsqueeze(0)) + features = torch.cat(features) + return features + + def __load_check_point(self, pretrained_model_name): self.embedding_model.load_state_dict( torch.load( os.path.join(self.model_dir, pretrained_model_name), - map_location=device), + map_location=torch.device('cpu')), strict=True) diff --git a/modelscope/models/audio/sv/ERes2Net_aug.py b/modelscope/models/audio/sv/ERes2Net_aug.py new file mode 100644 index 000000000..08f1a8a1c --- /dev/null +++ b/modelscope/models/audio/sv/ERes2Net_aug.py @@ -0,0 +1,340 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker. + ERes2Net_aug incorporates both local and global feature fusion techniques + to improve the performance. The training code is located on the following + GitHub repository: https://github.com/alibaba-damo-academy/3D-Speaker. +""" +import math +import os +from typing import Any, Dict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio.compliance.kaldi as Kaldi + +import modelscope.models.audio.sv.pooling_layers as pooling_layers +from modelscope.metainfo import Models +from modelscope.models import MODELS, TorchModel +from modelscope.models.audio.sv.fusion import AFF +from modelscope.utils.constant import Tasks + + +class ReLU(nn.Hardtanh): + + def __init__(self, inplace=False): + super(ReLU, self).__init__(0, 20, inplace) + + def __repr__(self): + inplace_str = 'inplace' if self.inplace else '' + return self.__class__.__name__ + ' (' \ + + inplace_str + ')' + + +def conv1x1(in_planes, out_planes, stride=1): + '1x1 convolution without padding' + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=1, + stride=stride, + padding=0, + bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + '3x3 convolution with padding' + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +class BasicBlockERes2Net(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3): + super(BasicBlockERes2Net, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = conv1x1(in_planes, width * scale, stride) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale + + convs = [] + bns = [] + for i in range(self.nums): + convs.append(conv3x3(width, width)) + bns.append(nn.BatchNorm2d(width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.relu = ReLU(inplace=True) + + self.conv3 = conv1x1(width * scale, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm2d(self.expansion * planes)) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class BasicBlockERes2Net_diff_AFF(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3): + super(BasicBlockERes2Net_diff_AFF, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = conv1x1(in_planes, width * scale, stride) + self.bn1 = nn.BatchNorm2d(width * scale) + + self.nums = scale + + convs = [] + fuse_models = [] + bns = [] + for i in range(self.nums): + convs.append(conv3x3(width, width)) + bns.append(nn.BatchNorm2d(width)) + # Add different fuse_model parameters + for j in range(self.nums - 1): + fuse_models.append(AFF(channels=width)) + + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.fuse_models = nn.ModuleList(fuse_models) + self.relu = ReLU(inplace=True) + + self.conv3 = conv1x1(width * scale, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm2d(self.expansion * planes)) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = self.fuse_models[i - 1](sp, spx[i]) + + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class ERes2Net_aug(nn.Module): + + def __init__(self, + block=BasicBlockERes2Net, + block_fuse=BasicBlockERes2Net_diff_AFF, + num_blocks=[3, 4, 6, 3], + m_channels=64, + feat_dim=80, + embedding_size=192, + pooling_func='TSTP', + two_emb_layer=False): + super(ERes2Net_aug, self).__init__() + self.in_planes = m_channels + self.feat_dim = feat_dim + self.embedding_size = embedding_size + self.stats_dim = int(feat_dim / 8) * m_channels * 8 + self.two_emb_layer = two_emb_layer + + self.conv1 = nn.Conv2d( + 1, m_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(m_channels) + self.layer1 = self._make_layer( + block, m_channels, num_blocks[0], stride=1) + self.layer2 = self._make_layer( + block, m_channels * 2, num_blocks[1], stride=2) + self.layer3 = self._make_layer( + block_fuse, m_channels * 4, num_blocks[2], stride=2) + self.layer4 = self._make_layer( + block_fuse, m_channels * 8, num_blocks[3], stride=2) + + self.layer1_downsample = nn.Conv2d( + m_channels * 4, + m_channels * 8, + kernel_size=3, + padding=1, + stride=2, + bias=False) + self.layer2_downsample = nn.Conv2d( + m_channels * 8, + m_channels * 16, + kernel_size=3, + padding=1, + stride=2, + bias=False) + self.layer3_downsample = nn.Conv2d( + m_channels * 16, + m_channels * 32, + kernel_size=3, + padding=1, + stride=2, + bias=False) + + self.fuse_mode12 = AFF(channels=m_channels * 8) + self.fuse_mode123 = AFF(channels=m_channels * 16) + self.fuse_mode1234 = AFF(channels=m_channels * 32) + + self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == 'TSDP' else 2 + self.pool = getattr(pooling_layers, pooling_func)( + in_dim=self.stats_dim * block.expansion) + self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, + embedding_size) + if self.two_emb_layer: + self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False) + self.seg_2 = nn.Linear(embedding_size, embedding_size) + else: + self.seg_bn_1 = nn.Identity() + self.seg_2 = nn.Identity() + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out1 = self.layer1(out) + out2 = self.layer2(out1) + out1_downsample = self.layer1_downsample(out1) + fuse_out12 = self.fuse_mode12(out2, out1_downsample) + out3 = self.layer3(out2) + fuse_out12_downsample = self.layer2_downsample(fuse_out12) + fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample) + out4 = self.layer4(out3) + fuse_out123_downsample = self.layer3_downsample(fuse_out123) + fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample) + stats = self.pool(fuse_out1234) + + embed_a = self.seg_1(stats) + if self.two_emb_layer: + out = F.relu(embed_a) + out = self.seg_bn_1(out) + embed_b = self.seg_2(out) + return embed_b + else: + return embed_a + + +@MODELS.register_module( + Tasks.speaker_verification, module_name=Models.eres2net_aug_sv) +class SpeakerVerificationERes2Net(TorchModel): + r"""Enhanced Res2Net_aug architecture with local and global feature fusion. + ERes2Net_aug is an upgraded version of ERes2Net that uses a larger number of + parameters to achieve better recognition performance. + Args: + model_dir: A model dir. + model_config: The model config. + """ + + def __init__(self, model_dir, model_config: Dict[str, Any], *args, + **kwargs): + super().__init__(model_dir, model_config, *args, **kwargs) + self.model_config = model_config + self.other_config = kwargs + self.feature_dim = 80 + + self.embedding_model = ERes2Net_aug() + + pretrained_model_name = kwargs['pretrained_model'] + self.__load_check_point(pretrained_model_name) + + self.embedding_model.eval() + + def forward(self, audio): + assert len(audio.shape) == 2 and audio.shape[ + 0] == 1, 'modelscope error: the shape of input audio to model needs to be [1, T]' + # audio shape: [1, T] + feature = self.__extract_feature(audio) + embedding = self.embedding_model(feature) + + return embedding + + def __extract_feature(self, audio): + feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim) + feature = feature - feature.mean(dim=0, keepdim=True) + feature = feature.unsqueeze(0) + return feature + + def __load_check_point(self, pretrained_model_name, device=None): + if not device: + device = torch.device('cpu') + self.embedding_model.load_state_dict( + torch.load( + os.path.join(self.model_dir, pretrained_model_name), + map_location=device), + strict=True) diff --git a/modelscope/models/audio/sv/cluster_backend.py b/modelscope/models/audio/sv/cluster_backend.py new file mode 100644 index 000000000..ee8751fc5 --- /dev/null +++ b/modelscope/models/audio/sv/cluster_backend.py @@ -0,0 +1,164 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Union + +import numpy as np +import scipy +import sklearn +from sklearn.cluster._kmeans import k_means + +from modelscope.metainfo import Models +from modelscope.models import MODELS, TorchModel +from modelscope.utils.constant import Tasks + + +class SpectralCluster: + r"""A spectral clustering mehtod using unnormalized Laplacian of affinity matrix. + This implementation is adapted from https://github.com/speechbrain/speechbrain. + """ + + def __init__(self, min_num_spks=0, max_num_spks=30): + self.min_num_spks = min_num_spks + self.max_num_spks = max_num_spks + + def __call__(self, X, pval, oracle_num=None): + # Similarity matrix computation + sim_mat = self.get_sim_mat(X) + + # Refining similarity matrix with pval + prunned_sim_mat = self.p_pruning(sim_mat, pval) + + # Symmetrization + sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T) + + # Laplacian calculation + laplacian = self.get_laplacian(sym_prund_sim_mat) + + # Get Spectral Embeddings + emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num) + + # Perform clustering + labels = self.cluster_embs(emb, num_of_spk) + + return labels + + def get_sim_mat(self, X): + # Cosine similarities + M = sklearn.metrics.pairwise.cosine_similarity(X, X) + return M + + def p_pruning(self, A, pval): + n_elems = int((1 - pval) * A.shape[0]) + + # For each row in a affinity matrix + for i in range(A.shape[0]): + low_indexes = np.argsort(A[i, :]) + low_indexes = low_indexes[0:n_elems] + + # Replace smaller similarity values by 0s + A[i, low_indexes] = 0 + return A + + def get_laplacian(self, M): + M[np.diag_indices(M.shape[0])] = 0 + D = np.sum(np.abs(M), axis=1) + D = np.diag(D) + L = D - M + return L + + def get_spec_embs(self, L, k_oracle=4): + lambdas, eig_vecs = scipy.linalg.eigh(L) + + if k_oracle is not None: + num_of_spk = k_oracle + else: + lambda_gap_list = self.getEigenGaps( + lambdas[self.min_num_spks - 1:self.max_num_spks - 1]) + num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks + + emb = eig_vecs[:, :num_of_spk] + return emb, num_of_spk + + def cluster_embs(self, emb, k): + _, labels, _ = k_means(emb, k) + return labels + + def getEigenGaps(self, eig_vals): + eig_vals_gap_list = [] + for i in range(len(eig_vals) - 1): + gap = float(eig_vals[i + 1]) - float(eig_vals[i]) + eig_vals_gap_list.append(gap) + return eig_vals_gap_list + + +@MODELS.register_module( + Tasks.speaker_diarization, module_name=Models.cluster_backend) +class ClusterBackend(TorchModel): + r"""Perfom clustering for input embeddings and output the labels. + Args: + model_dir: A model dir. + model_config: The model config. + """ + + def __init__(self, model_dir, model_config: Dict[str, Any], *args, + **kwargs): + super().__init__(model_dir, model_config, *args, **kwargs) + self.model_config = model_config + self.other_config = kwargs + + if self.model_config['cluster_type'] == 'spectral': + self.cluster = SpectralCluster(self.model_config['min_num_spks'], + self.model_config['max_num_spks']) + else: + raise ValueError( + 'modelscope error: Only spectral clustering is currently supported.' + ) + + def forward(self, X, **params): + # clustering and return the labels + k = params['oracle_num'] if 'oracle_num' in params else None + pval = params['pval'] if 'pval' in params else self.model_config['pval'] + assert len( + X.shape + ) == 2, 'modelscope error: the shape of input should be [N, C]' + if self.model_config['cluster_type'] == 'spectral': + if X.shape[0] * pval < 6: + pval = 6. / X.shape[0] + labels = self.cluster(X, pval, k) + else: + raise ValueError( + 'modelscope error: Only spectral clustering is currently supported.' + ) + + if k is None and 'merge_thr' in self.model_config: + labels = self.merge_by_cos(labels, X, + self.model_config['merge_thr']) + + return labels + + def merge_by_cos(self, labels, embs, cos_thr): + # merge the similar speakers by cosine similarity + assert cos_thr > 0 and cos_thr <= 1 + while True: + spk_num = labels.max() + 1 + if spk_num == 1: + break + spk_center = [] + for i in range(spk_num): + spk_emb = embs[labels == i].mean(0) + spk_center.append(spk_emb) + assert len(spk_center) > 0 + spk_center = np.stack(spk_center, axis=0) + norm_spk_center = spk_center / np.linalg.norm( + spk_center, axis=1, keepdims=True) + affinity = np.matmul(norm_spk_center, norm_spk_center.T) + affinity = np.triu(affinity, 1) + spks = np.unravel_index(np.argmax(affinity), affinity.shape) + if affinity[spks] < cos_thr: + break + for i in range(len(labels)): + if labels[i] == spks[1]: + labels[i] = spks[0] + elif labels[i] > merge_spks[1]: + labels[i] -= 1 + return labels diff --git a/modelscope/models/audio/sv/ecapa_tdnn.py b/modelscope/models/audio/sv/ecapa_tdnn.py index 0b6558168..a068efa27 100644 --- a/modelscope/models/audio/sv/ecapa_tdnn.py +++ b/modelscope/models/audio/sv/ecapa_tdnn.py @@ -5,6 +5,7 @@ import os from typing import Any, Dict, Union +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -13,6 +14,7 @@ from modelscope.metainfo import Models from modelscope.models import MODELS, TorchModel from modelscope.utils.constant import Tasks +from modelscope.utils.device import create_device def length_to_mask(length, max_len=None, dtype=None, device=None): @@ -470,35 +472,44 @@ def __init__(self, model_dir, model_config: Dict[str, Any], *args, self.feature_dim = 80 channels_config = [1024, 1024, 1024, 1024, 3072] + self.device = create_device(self.other_config['device']) + print(self.device) self.embedding_model = ECAPA_TDNN( self.feature_dim, channels=channels_config) - pretrained_model_name = kwargs['pretrained_model'] self.__load_check_point(pretrained_model_name) + self.embedding_model.to(self.device) self.embedding_model.eval() def forward(self, audio): - assert len(audio.shape) == 2 and audio.shape[ - 0] == 1, 'modelscope error: the shape of input audio to model needs to be [1, T]' - # audio shape: [1, T] + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio) + if len(audio.shape) == 1: + audio = audio.unsqueeze(0) + assert len( + audio.shape + ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]' + # audio shape: [N, T] feature = self.__extract_feature(audio) - embedding = self.embedding_model(feature) + embedding = self.embedding_model(feature.to(self.device)) - return embedding + return embedding.detach().cpu() def __extract_feature(self, audio): - feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim) - feature = feature - feature.mean(dim=0, keepdim=True) - feature = feature.unsqueeze(0) - return feature - - def __load_check_point(self, pretrained_model_name, device=None): - if not device: - device = torch.device('cpu') + features = [] + for au in audio: + feature = Kaldi.fbank( + au.unsqueeze(0), num_mel_bins=self.feature_dim) + feature = feature - feature.mean(dim=0, keepdim=True) + features.append(feature.unsqueeze(0)) + features = torch.cat(features) + return features + + def __load_check_point(self, pretrained_model_name): self.embedding_model.load_state_dict( torch.load( os.path.join(self.model_dir, pretrained_model_name), - map_location=device), + map_location=torch.device('cpu')), strict=True) diff --git a/modelscope/models/audio/sv/speaker_change_locator.py b/modelscope/models/audio/sv/speaker_change_locator.py index c22e4c1bf..4926196e6 100644 --- a/modelscope/models/audio/sv/speaker_change_locator.py +++ b/modelscope/models/audio/sv/speaker_change_locator.py @@ -14,6 +14,7 @@ from modelscope.models import MODELS, TorchModel from modelscope.models.audio.sv.DTDNN import CAMPPlus from modelscope.utils.constant import Tasks +from modelscope.utils.device import create_device class MultiHeadSelfAttention(nn.Module): @@ -83,6 +84,7 @@ def forward(self, input_len): for len in input_len ]) + input_pos = input_pos.to(list(self.pos_enc.parameters())[0].device) return self.pos_enc(input_pos) @@ -265,6 +267,7 @@ def __init__(self, model_dir, model_config: Dict[str, Any], *args, self.feature_dim = self.model_config['fbank_dim'] frame_size = self.model_config['frame_size'] anchor_size = self.model_config['anchor_size'] + self.device = create_device(kwargs['device']) self.encoder = CAMPPlus(self.feature_dim, output_level='frame') self.backend = TransformerDetector( @@ -275,10 +278,16 @@ def __init__(self, model_dir, model_config: Dict[str, Any], *args, self.__load_check_point(pretrained_encoder, pretrained_backend) + self.encoder.to(self.device) + self.backend.to(self.device) self.encoder.eval() self.backend.eval() def forward(self, audio, anchors): + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio) + if isinstance(anchors, np.ndarray): + anchors = torch.from_numpy(anchors) assert len(audio.shape) == 2 and audio.shape[ 0] == 1, 'modelscope error: the shape of input audio to model needs to be [1, T]' assert len( @@ -287,8 +296,8 @@ def forward(self, audio, anchors): 1] == 2, 'modelscope error: the shape of input anchors to model needs to be [1, 2, D]' # audio shape: [1, T] feature = self.__extract_feature(audio) - frame_state = self.encoder(feature) - output = self.backend(frame_state, anchors) + frame_state = self.encoder(feature.to(self.device)) + output = self.backend(frame_state, anchors.to(self.device)) output = output.squeeze(0).detach().cpu().sigmoid() time_scale_factor = int(np.ceil(feature.shape[1] / output.shape[0])) @@ -302,18 +311,17 @@ def __extract_feature(self, audio): feature = feature.unsqueeze(0) return feature - def __load_check_point(self, - pretrained_encoder, - pretrained_backend, - device=None): - if not device: - device = torch.device('cpu') + def __load_check_point( + self, + pretrained_encoder, + pretrained_backend, + ): self.encoder.load_state_dict( torch.load( os.path.join(self.model_dir, pretrained_encoder), - map_location=device)) + map_location=torch.device('cpu'))) self.backend.load_state_dict( torch.load( os.path.join(self.model_dir, pretrained_backend), - map_location=device)) + map_location=torch.device('cpu'))) diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/bbox_nms.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/bbox_nms.py index 70d2e1da7..7f1a2fe95 100644 --- a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/bbox_nms.py +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/bbox_nms.py @@ -54,7 +54,7 @@ def multiclass_nms(multi_bboxes, if score_factors is not None: scores = scores * score_factors[:, None] - labels = torch.arange(num_classes, dtype=torch.long) + labels = torch.arange(num_classes, dtype=torch.long, device=scores.device) labels = labels.view(1, -1).expand_as(scores) bboxes = bboxes.reshape(-1, 4) diff --git a/modelscope/models/cv/vision_efficient_tuning/backbone.py b/modelscope/models/cv/vision_efficient_tuning/backbone.py index 691e44405..e83fb958a 100644 --- a/modelscope/models/cv/vision_efficient_tuning/backbone.py +++ b/modelscope/models/cv/vision_efficient_tuning/backbone.py @@ -191,7 +191,7 @@ def __init__( self.prompt = None def forward(self, x): - if self.prompt is not None: + if self.prompt is not None and self.prompt_length and self.prompt_length > 0: x = self.prompt(x) x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) diff --git a/modelscope/models/multi_modal/diffusion/model.py b/modelscope/models/multi_modal/diffusion/model.py index d979cc7fd..34d8d3422 100644 --- a/modelscope/models/multi_modal/diffusion/model.py +++ b/modelscope/models/multi_modal/diffusion/model.py @@ -114,9 +114,9 @@ def forward(self, noise, timesteps, input_ids, token_type_ids, Tasks.text_to_image_synthesis, module_name=Models.diffusion) class DiffusionForTextToImageSynthesis(Model): - def __init__(self, model_dir, device='gpu'): + def __init__(self, model_dir, device='gpu', **kwargs): device = 'gpu' if torch.cuda.is_available() else 'cpu' - super().__init__(model_dir=model_dir, device=device) + super().__init__(model_dir=model_dir, device=device, **kwargs) diffusion_model = DiffusionModel(model_dir=model_dir) pretrained_params = torch.load( osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'cpu') diff --git a/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py b/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py index 747aecd85..3f616297d 100644 --- a/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py +++ b/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py @@ -19,8 +19,8 @@ from modelscope.models import TorchModel from modelscope.models.builder import MODELS from modelscope.outputs import OutputKeys -from modelscope.tuners.control_sd_lora import ControlLoRATuner -from modelscope.tuners.sd_lora import LoRATuner +from modelscope.swift.control_sd_lora import ControlLoRATuner +from modelscope.swift.sd_lora import LoRATuner from modelscope.utils.checkpoint import save_checkpoint, save_configuration from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks diff --git a/modelscope/models/nlp/T5/text2text_generation.py b/modelscope/models/nlp/T5/text2text_generation.py index 4c6fd295d..aa2826a8b 100644 --- a/modelscope/models/nlp/T5/text2text_generation.py +++ b/modelscope/models/nlp/T5/text2text_generation.py @@ -57,7 +57,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): r'decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight', ] - def __init__(self, config: T5Config, **kwargs): + def __init__(self, config: T5Config, device_map=None, **kwargs): super().__init__(config) self.model_dim = config.d_model @@ -82,7 +82,8 @@ def __init__(self, config: T5Config, **kwargs): # Model parallel self.model_parallel = False - self.device_map = None + if device_map == 'auto': + self.parallelize() def parallelize(self, device_map=None): self.device_map = ( diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index ebe081a92..c99f04ecb 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -22,6 +22,8 @@ from .csanmt import CsanmtForTranslation from .canmt import CanmtForTranslation from .deberta_v2 import DebertaV2ForMaskedLM, DebertaV2Model + from .chatglm import ChatGLMForConditionalGeneration, ChatGLMTokenizer, ChatGLMConfig + from .chatglm2 import ChatGLM2ForConditionalGeneration, ChatGLM2Tokenizer, ChatGLM2Config from .gpt_neo import GPTNeoModel from .gpt2 import GPT2Model from .gpt3 import GPT3ForTextGeneration, DistributedGPT3 @@ -95,6 +97,14 @@ ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'], 'glm_130b': ['GLM130bForTextGeneration'], 'deberta_v2': ['DebertaV2ForMaskedLM', 'DebertaV2Model'], + 'chatglm': [ + 'ChatGLMForConditionalGeneration', 'ChatGLMTokenizer', + 'ChatGLMConfig' + ], + 'chatglm2': [ + 'ChatGLM2ForConditionalGeneration', 'ChatGLM2Tokenizer', + 'ChatGLM2Config' + ], 'heads': ['TextClassificationHead'], 'hf_transformers': ['TransformersModel'], 'gpt2': ['GPT2Model'], diff --git a/modelscope/models/nlp/chatglm/__init__.py b/modelscope/models/nlp/chatglm/__init__.py new file mode 100644 index 000000000..2a0a073fe --- /dev/null +++ b/modelscope/models/nlp/chatglm/__init__.py @@ -0,0 +1,46 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .configuration import ChatGLMConfig + from .tokenization import ChatGLMTokenizer + from .text_generation import ChatGLMForConditionalGeneration + from .quantization import ( + quantize, ) + +else: + _import_structure = { + 'configuration': ['ChatGLMConfig'], + 'text_generation': ['ChatGLMForConditionalGeneration'], + 'quantization': ['quantize'], + 'tokenization': [ + 'ChatGLMTokenizer', + ], + } + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__) diff --git a/modelscope/models/nlp/chatglm/configuration.py b/modelscope/models/nlp/chatglm/configuration.py new file mode 100644 index 000000000..18fdca0f1 --- /dev/null +++ b/modelscope/models/nlp/chatglm/configuration.py @@ -0,0 +1,101 @@ +""" ChatGLM model configuration """ + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class ChatGLMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~ChatGLMModel`]. + It is used to instantiate an ChatGLM model according to the specified arguments, defining the model + architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of + the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used + to control the model outputs. Read the documentation from [`PretrainedConfig`] + for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 150528): + Vocabulary size of the ChatGLM-6B model. + Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~ChatGLMModel`] or + [`~TFChatGLMModel`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + inner_hidden_size (`int`, *optional*, defaults to 16384): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + max_sequence_length (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case (e.g., 512 or 1024 or 2048). + layernorm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether the model should return the last key/values attentions (not used by all models). + Example: + + ```python + >>> from modelscope.models.nlp.chatglm.configuration import ChatGLMConfig + >>> from modelscope.models.nlp.chatglm.text_generation import ChatGLMModel + + >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration + >>> configuration = ChatGLMConfig() + + >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration + >>> model = ChatGLMModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` +""" + model_type = 'chatglm' + + def __init__(self, + vocab_size=150528, + hidden_size=4096, + num_layers=28, + num_attention_heads=32, + layernorm_epsilon=1e-5, + use_cache=False, + bos_token_id=150004, + eos_token_id=150005, + mask_token_id=150000, + gmask_token_id=150001, + pad_token_id=0, + max_sequence_length=2048, + inner_hidden_size=16384, + position_encoding_2d=True, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs): + self.num_layers = num_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.max_sequence_length = max_sequence_length + self.layernorm_epsilon = layernorm_epsilon + self.inner_hidden_size = inner_hidden_size + self.use_cache = use_cache + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.mask_token_id = mask_token_id + self.gmask_token_id = gmask_token_id + self.position_encoding_2d = position_encoding_2d + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs) diff --git a/modelscope/models/nlp/chatglm/quantization.py b/modelscope/models/nlp/chatglm/quantization.py new file mode 100644 index 000000000..9994d9c48 --- /dev/null +++ b/modelscope/models/nlp/chatglm/quantization.py @@ -0,0 +1,234 @@ +import base64 +import bz2 +import ctypes +from typing import List + +import torch +from torch.nn import Linear +from torch.nn.parameter import Parameter +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +try: + from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up + + class Kernel: + + def __init__(self, code: bytes, function_names: List[str]): + self.code = code + self._function_names = function_names + self._cmodule = LazyKernelCModule(self.code) + + for name in self._function_names: + setattr(self, name, KernelFunction(self._cmodule, name)) + + quantization_code = '$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ' # noqa + + kernels = Kernel( + bz2.decompress(base64.b64decode(quantization_code)), + [ + 'int4WeightCompression', + 'int4WeightExtractionFloat', + 'int4WeightExtractionHalf', + 'int8WeightExtractionFloat', + 'int8WeightExtractionHalf', + ], + ) +except Exception as exception: + kernels = None + logger.warning('Failed to load cpm_kernels:' + str(exception)) + + +class W8A16Linear(torch.autograd.Function): + + @staticmethod + def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, + scale_w: torch.Tensor, weight_bit_width): + ctx.inp_shape = inp.size() + ctx.weight_bit_width = weight_bit_width + out_features = quant_w.size(0) + inp = inp.contiguous().view(-1, inp.size(-1)) + weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) + ctx.weight_shape = weight.size() + output = inp.mm(weight.t()) + ctx.save_for_backward(inp, quant_w, scale_w) + return output.view(*(ctx.inp_shape[:-1] + (out_features, ))) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + inp, quant_w, scale_w = ctx.saved_tensors + weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) + grad_output = grad_output.contiguous().view(-1, weight.size(0)) + grad_input = grad_output.mm(weight) + grad_weight = grad_output.t().mm(inp) + return grad_input.view(ctx.inp_shape), grad_weight.view( + ctx.weight_shape), None, None + + +def compress_int4_weight(weight: torch.Tensor): # (n, m) + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + assert m % 2 == 0 + m = m // 2 + out = torch.empty(n, m, dtype=torch.int8, device='cuda') + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + if kernels is None: + raise RuntimeError( + 'kernels is None, please check whether it is correctly initialized.' + ) + kernels.int4WeightCompression( + gridDim, + blockDim, + 0, + stream, + [ + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int32(n), + ctypes.c_int32(m) + ], + ) + return out + + +def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, + source_bit_width: int): + if kernels is None: + raise RuntimeError( + 'kernels is None, please check whether it is correctly initialized.' + ) + if source_bit_width == 8: + func = kernels.int8WeightExtractionHalf + elif source_bit_width == 4: + func = kernels.int4WeightExtractionHalf + else: + assert False, 'Unsupported bit-width' + + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + out = torch.empty( + n, m * (8 // source_bit_width), dtype=torch.half, device='cuda') + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + func( + gridDim, + blockDim, + 0, + stream, + [ + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(scale_list.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int32(n), + ctypes.c_int32(m), + ], + ) + return out + + +class QuantizedLinear(Linear): + + def __init__(self, + weight_bit_width: int, + weight_tensor=None, + bias_tensor=None, + empty_init=False, + *args, + **kwargs): + super(QuantizedLinear, self).__init__(*args, **kwargs) + self.weight_bit_width = weight_bit_width + + shape = self.weight.shape + del self.weight + + if weight_tensor is None or empty_init: + self.weight = torch.empty( + shape[0], + shape[1] * weight_bit_width // 8, + dtype=torch.int8, + device=kwargs['device']) + self.weight_scale = torch.empty( + shape[0], dtype=kwargs['dtype'], device=kwargs['device']) + else: + self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ( + (2**(weight_bit_width - 1)) - 1)).half() # noqa + self.weight = torch.round( + weight_tensor / self.weight_scale[:, None]).to(torch.int8) + if weight_bit_width == 4: + self.weight = compress_int4_weight(self.weight) + + self.weight = Parameter( + self.weight.to(kwargs['device']), requires_grad=False) + self.weight_scale = Parameter( + self.weight_scale.to(kwargs['device']), requires_grad=False) + if bias_tensor is not None: + self.bias = Parameter( + bias_tensor.to(kwargs['device']), requires_grad=False) + else: + self.bias = None + + def forward(self, input): + output = W8A16Linear.apply(input, self.weight, self.weight_scale, + self.weight_bit_width) + if self.bias is not None: + output = output + self.bias + return output + + +def quantize(model, weight_bit_width, empty_init=False, **kwargs): + """Replace fp16 linear with quantized linear""" + + for layer in model.layers: + layer.attention.query_key_value = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.attention.query_key_value.weight.to( + torch.cuda.current_device()), + bias_tensor=layer.attention.query_key_value.bias, + in_features=layer.attention.query_key_value.in_features, + out_features=layer.attention.query_key_value.out_features, + bias=True, + dtype=torch.half, + device=layer.attention.query_key_value.weight.device, + empty_init=empty_init) + layer.attention.dense = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.attention.dense.weight.to( + torch.cuda.current_device()), + bias_tensor=layer.attention.dense.bias, + in_features=layer.attention.dense.in_features, + out_features=layer.attention.dense.out_features, + bias=True, + dtype=torch.half, + device=layer.attention.dense.weight.device, + empty_init=empty_init) + layer.mlp.dense_h_to_4h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.mlp.dense_h_to_4h.weight.to( + torch.cuda.current_device()), + bias_tensor=layer.mlp.dense_h_to_4h.bias, + in_features=layer.mlp.dense_h_to_4h.in_features, + out_features=layer.mlp.dense_h_to_4h.out_features, + bias=True, + dtype=torch.half, + device=layer.mlp.dense_h_to_4h.weight.device, + empty_init=empty_init) + layer.mlp.dense_4h_to_h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.mlp.dense_4h_to_h.weight.to( + torch.cuda.current_device()), + bias_tensor=layer.mlp.dense_4h_to_h.bias, + in_features=layer.mlp.dense_4h_to_h.in_features, + out_features=layer.mlp.dense_4h_to_h.out_features, + bias=True, + dtype=torch.half, + device=layer.mlp.dense_4h_to_h.weight.device, + empty_init=empty_init) + return model diff --git a/modelscope/models/nlp/chatglm/text_generation.py b/modelscope/models/nlp/chatglm/text_generation.py new file mode 100644 index 000000000..53f8f1058 --- /dev/null +++ b/modelscope/models/nlp/chatglm/text_generation.py @@ -0,0 +1,1572 @@ +""" PyTorch ChatGLM model. """ + +import copy +import math +import os +import re +import sys +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import (GenerationConfig, + LogitsProcessorList, ModelOutput, + StoppingCriteriaList) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import (add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, logging) + +from modelscope.metainfo import Models +from modelscope.models import MODELS, Model, TorchModel +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks +from .configuration import ChatGLMConfig +from .tokenization import ChatGLMTokenizer + +# flags required to enable jit fusion kernels + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = 'THUDM/ChatGLM-6B' +_CONFIG_FOR_DOC = 'ChatGLM6BConfig' + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + 'THUDM/chatglm-6b', + # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm +] + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + + def __call__(self, input_ids: torch.LongTensor, + scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + 'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see ' + 'https://www.tensorflow.org/install/ for installation instructions.' + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f'Converting TensorFlow checkpoint from {tf_path}') + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f'Loading TF weight {name} with shape {shape}') + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split('/') + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in [ + 'adam_v', 'adam_m', 'AdamWeightDecayOptimizer', + 'AdamWeightDecayOptimizer_1', 'global_step' + ] for n in name): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + scope_names = re.split(r'_(\d+)', m_name) + else: + scope_names = [m_name] + if scope_names[0] == 'kernel' or scope_names[0] == 'gamma': + pointer = getattr(pointer, 'weight') + elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta': + pointer = getattr(pointer, 'bias') + elif scope_names[0] == 'output_weights': + pointer = getattr(pointer, 'weight') + elif scope_names[0] == 'squad': + pointer = getattr(pointer, 'classifier') + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == '_embeddings': + pointer = getattr(pointer, 'weight') + elif m_name == 'kernel': + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched' + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f'Initialize PyTorch weight {name}') + pointer.data = torch.from_numpy(array) + return model + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + self.embedding = torch.nn.Embedding(config.pre_seq_len, + config.hidden_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(config.hidden_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, + config.num_layers * config.hidden_size * 2)) + else: + self.embedding = torch.nn.Embedding( + config.pre_seq_len, config.num_layers * config.hidden_size * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * ( + 1.0 + torch.tanh(0.7978845608028654 * x * # noqa + (1.0 + 0.044715 * x * x))) # noqa + + +def gelu(x): + return gelu_impl(x) + + +class RotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, base=10000, precision=torch.half, learnable=False): + super().__init__() + inv_freq = 1. / (base**(torch.arange(0, dim, 2).float() / dim)) + inv_freq = inv_freq.half() + self.learnable = learnable + if learnable: + self.inv_freq = torch.nn.Parameter(inv_freq) + self.max_seq_len_cached = None + else: + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + self.precision = precision + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + pass + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if self.max_seq_len_cached is None or ( + seq_len > self.max_seq_len_cached): # noqa + self.max_seq_len_cached = None if self.learnable else seq_len + t = torch.arange( + seq_len, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + + # [sx, 1 (b * np), hn] + cos_cached = emb.cos()[:, None, :] + sin_cached = emb.sin()[:, None, :] + if self.precision == torch.bfloat16: + cos_cached = cos_cached.bfloat16() + sin_cached = sin_cached.bfloat16() + if self.learnable: + return cos_cached, sin_cached + self.cos_cached, self.sin_cached = cos_cached, sin_cached + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + def _apply(self, fn): + if self.cos_cached is not None: + self.cos_cached = fn(self.cos_cached) + if self.sin_cached is not None: + self.sin_cached = fn(self.sin_cached) + return super()._apply(fn) + + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat( + (-x2, x1), + dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): + # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + ( + rotate_half(k) * sin) + return q, k + + +def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + hidden_size_per_partition, + layer_id, + layer_past=None, + scaling_attention_score=True, + use_cache=False, +): + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + key_layer = torch.cat((past_key, key_layer), dim=0) + value_layer = torch.cat((past_value, value_layer), dim=0) + + # seqlen, batch, num_attention_heads, hidden_size_per_attention_head + seq_len, b, nh, hidden_size = key_layer.shape + + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + query_key_layer_scaling_coeff = float(layer_id + 1) + if scaling_attention_score: + query_layer = query_layer / ( + math.sqrt(hidden_size) * query_key_layer_scaling_coeff) + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), + query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], + output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], + -1) + + matmul_result = torch.zeros( + 1, + 1, + 1, + dtype=query_layer.dtype, + device=query_layer.device, + ) + + matmul_result = torch.baddbmm( + matmul_result, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=1.0, + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + if self.scale_mask_softmax: + self.scale_mask_softmax.scale = query_key_layer_scaling_coeff + attention_probs = self.scale_mask_softmax(attention_scores, + attention_mask.contiguous()) + else: + if not (attention_mask == 0).all(): + # if auto-regressive, skip + attention_scores.masked_fill_(attention_mask, -10000.0) + dtype = attention_scores.dtype + attention_scores = attention_scores.float() + attention_scores = attention_scores * query_key_layer_scaling_coeff + + attention_probs = F.softmax(attention_scores, dim=-1) + + attention_probs = attention_probs.type(dtype) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), + query_layer.size(0), value_layer.size(3)) + + # change view [sk, b * np, hn] + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], + output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + ( + hidden_size_per_partition, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, present, attention_probs) + + return outputs + + +class SelfAttention(torch.nn.Module): + + def __init__(self, + hidden_size, + num_attention_heads, + layer_id, + hidden_size_per_attention_head=None, + bias=True, + params_dtype=torch.float, + position_encoding_2d=True): + super(SelfAttention, self).__init__() + + self.layer_id = layer_id + self.hidden_size = hidden_size + self.hidden_size_per_partition = hidden_size + self.num_attention_heads = num_attention_heads + self.num_attention_heads_per_partition = num_attention_heads + self.position_encoding_2d = position_encoding_2d + self.rotary_emb = RotaryEmbedding( # noqa + self.hidden_size // # noqa + (self.num_attention_heads * 2) if position_encoding_2d else # noqa + self.hidden_size // self.num_attention_heads, # noqa + base=10000, # noqa + precision=torch.half, # noqa + learnable=False, # noqa + ) # noqa + + self.scale_mask_softmax = None + + if hidden_size_per_attention_head is None: + self.hidden_size_per_attention_head = hidden_size // num_attention_heads + else: + self.hidden_size_per_attention_head = hidden_size_per_attention_head + + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + + # Strided linear layer. + self.query_key_value = skip_init( + torch.nn.Linear, + hidden_size, + 3 * self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + + self.dense = skip_init( + torch.nn.Linear, + self.inner_hidden_size, + hidden_size, + bias=bias, + dtype=params_dtype, + ) + + @staticmethod + def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + def split_tensor_along_last_dim(self, + tensor, + num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [seq_len, batch, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # [seq_len, batch, 3 * hidden_size] + mixed_raw_layer = self.query_key_value(hidden_states) + + # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, + # 3 * hidden_size_per_attention_head] + new_tensor_shape = mixed_raw_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) + + # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] + (query_layer, key_layer, + value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) + + if self.position_encoding_2d: + q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) + k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) + cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) + position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \ + position_ids[:, 1, :].transpose(0, 1).contiguous() + q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) + q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, + block_position_ids) + query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) + key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) + else: + position_ids = position_ids.transpose(0, 1) + cos, sin = self.rotary_emb( + value_layer, seq_len=position_ids.max() + 1) + # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] + query_layer, key_layer = apply_rotary_pos_emb_index( + query_layer, key_layer, cos, sin, position_ids) + + # [seq_len, batch, hidden_size] + context_layer, present, attention_probs = attention_fn( + self=self, + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + hidden_size_per_partition=self.hidden_size_per_partition, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache) + + output = self.dense(context_layer) + + outputs = (output, present) + + if output_attentions: + outputs += (attention_probs, ) + + return outputs # output, present, attention_probs + + +class GEGLU(torch.nn.Module): + + def __init__(self): + super().__init__() + self.activation_fn = F.gelu + + def forward(self, x): + # dim=-1 breaks in jit for pt<1.10 + x1, x2 = x.chunk(2, dim=(x.ndim - 1)) + return x1 * self.activation_fn(x2) + + +class GLU(torch.nn.Module): + + def __init__(self, + hidden_size, + inner_hidden_size=None, + layer_id=None, + bias=True, + activation_func=gelu, + params_dtype=torch.float): + super(GLU, self).__init__() + self.layer_id = layer_id + self.activation_func = activation_func + + # Project to 4h. + self.hidden_size = hidden_size + if inner_hidden_size is None: + inner_hidden_size = 4 * hidden_size + self.inner_hidden_size = inner_hidden_size + self.dense_h_to_4h = skip_init( + torch.nn.Linear, + self.hidden_size, + self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + # Project back to h. + self.dense_4h_to_h = skip_init( + torch.nn.Linear, + self.inner_hidden_size, + self.hidden_size, + bias=bias, + dtype=params_dtype, + ) + + def forward(self, hidden_states): + """ + hidden_states: [seq_len, batch, hidden_size] + """ + + # [seq_len, batch, inner_hidden_size] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + + intermediate_parallel = self.activation_func(intermediate_parallel) + + output = self.dense_4h_to_h(intermediate_parallel) + + return output + + +class GLMBlock(torch.nn.Module): + + def __init__(self, + hidden_size, + num_attention_heads, + layernorm_epsilon, + layer_id, + inner_hidden_size=None, + hidden_size_per_attention_head=None, + layernorm=LayerNorm, + use_bias=True, + params_dtype=torch.float, + num_layers=28, + position_encoding_2d=True): + super(GLMBlock, self).__init__() + # Set output layer initialization if not provided. + + self.layer_id = layer_id + + # Layernorm on the input data. + self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) + + self.position_encoding_2d = position_encoding_2d + + # Self attention. + self.attention = SelfAttention( + hidden_size, + num_attention_heads, + layer_id, + hidden_size_per_attention_head=hidden_size_per_attention_head, + bias=use_bias, + params_dtype=params_dtype, + position_encoding_2d=self.position_encoding_2d) + + # Layernorm on the input data. + self.post_attention_layernorm = layernorm( + hidden_size, eps=layernorm_epsilon) + + self.num_layers = num_layers + + # GLU + self.mlp = GLU( + hidden_size, + inner_hidden_size=inner_hidden_size, + bias=use_bias, + layer_id=layer_id, + params_dtype=params_dtype, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [seq_len, batch, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # Layer norm at the begining of the transformer layer. + # [seq_len, batch, hidden_size] + attention_input = self.input_layernorm(hidden_states) + + # Self attention. + attention_outputs = self.attention( + attention_input, + position_ids, + attention_mask=attention_mask, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions) + + attention_output = attention_outputs[0] + + outputs = attention_outputs[1:] + + # Residual connection. + alpha = (2 * self.num_layers)**0.5 + hidden_states = attention_input * alpha + attention_output + + mlp_input = self.post_attention_layernorm(hidden_states) + + # MLP. + mlp_output = self.mlp(mlp_input) + + # Second residual connection. + output = mlp_input * alpha + mlp_output + + if use_cache: + outputs = (output, ) + outputs + else: + outputs = (output, ) + outputs[1:] + + return outputs # hidden_states, present, attentions + + +class ChatGLMPreTrainedModel(TorchModel, PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = 'transformer' + _no_split_modules = ['GLMBlock'] + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, device): + batch_size, seq_length = input_ids.shape + context_lengths = [ + seq.tolist().index(self.config.bos_token_id) for seq in input_ids + ] + attention_mask = torch.ones((batch_size, seq_length, seq_length), + device=device) + attention_mask.tril_() + for i, context_length in enumerate(context_lengths): + attention_mask[i, :, :context_length] = 1 + attention_mask.unsqueeze_(1) + attention_mask = (attention_mask < 0.5).bool() + + return attention_mask + + def get_position_ids(self, input_ids, mask_positions, device, gmask=False): + batch_size, seq_length = input_ids.shape + context_lengths = [ + seq.tolist().index(self.config.bos_token_id) for seq in input_ids + ] + if self.position_encoding_2d: + position_ids = torch.arange( + seq_length, dtype=torch.long, + device=device).unsqueeze(0).repeat(batch_size, 1) + for i, context_length in enumerate(context_lengths): + position_ids[i, context_length:] = mask_positions[i] + block_position_ids = [ + torch.cat(( + torch.zeros( # noqa + context_length, + dtype=torch.long, + device=device), # noqa + torch.arange( # noqa + seq_length - context_length, # noqa + dtype=torch.long, # noqa + device=device) + 1)) # noqa + for context_length in context_lengths + ] + block_position_ids = torch.stack(block_position_ids, dim=0) + position_ids = torch.stack((position_ids, block_position_ids), + dim=1) + else: + position_ids = torch.arange( + seq_length, dtype=torch.long, + device=device).unsqueeze(0).repeat(batch_size, 1) + if not gmask: + for i, context_length in enumerate(context_lengths): + position_ids[context_length:] = mask_positions[i] + + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ChatGLMModel): + module.gradient_checkpointing = value + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.pop('model_dir', None) + kwargs.pop('cfg', None) + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **kwargs) + model.model_dir = model_dir + return model + + +CHATGLM_6B_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CHATGLM_6B_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`ChatGLM6BTokenizer`]. + See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range `[0, config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert *input_ids* indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + 'The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.', + CHATGLM_6B_START_DOCSTRING, +) +class ChatGLMModel(ChatGLMPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well + as a decoder, in which case a layer of cross-attention is added between + the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the + `is_decoder` argument of the configuration set to `True`. + To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` + argument and `add_cross_attention` set to `True`; an + `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__(config) + + # recording parameters + self.max_sequence_length = config.max_sequence_length + self.hidden_size = config.hidden_size + self.params_dtype = torch.half + self.num_attention_heads = config.num_attention_heads + self.vocab_size = config.vocab_size + self.num_layers = config.num_layers + self.layernorm_epsilon = config.layernorm_epsilon + self.inner_hidden_size = config.inner_hidden_size + self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads + self.position_encoding_2d = config.position_encoding_2d + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + + self.word_embeddings = skip_init( + torch.nn.Embedding, + num_embeddings=self.vocab_size, + embedding_dim=self.hidden_size, + dtype=self.params_dtype) + self.gradient_checkpointing = False + + def get_layer(layer_id): + return GLMBlock( + self.hidden_size, + self.num_attention_heads, + self.layernorm_epsilon, + layer_id, + inner_hidden_size=self.inner_hidden_size, + hidden_size_per_attention_head=self. + hidden_size_per_attention_head, + layernorm=LayerNorm, + use_bias=True, + params_dtype=self.params_dtype, + position_encoding_2d=self.position_encoding_2d, + ) + + self.layers = torch.nn.ModuleList( + [get_layer(layer_id) for layer_id in range(self.num_layers)]) + + # Final layer norm before output. + self.final_layernorm = LayerNorm( + self.hidden_size, eps=self.layernorm_epsilon) + + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + # total_params = sum(p.numel() for p in self.parameters()) + # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, new_embeddings: torch.Tensor): + self.word_embeddings = new_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, + -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, self.pre_seq_len, self.num_layers * 2, + self.num_attention_heads, + self.hidden_size // self.num_attention_heads) + # seq_len, b, nh, hidden_size + print('#########################:', past_key_values.device) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + # past_key_values = [(v[0], v[1]) for v in past_key_values] + return past_key_values + + @add_start_docstrings_to_model_forward( + CHATGLM_6B_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], + ...]] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training: + if use_cache: + # logger.warning_once( + # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + # ) + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape[:2] + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds') + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if past_key_values is None: + if self.pre_seq_len is not None: + past_key_values = self.get_prompt( + batch_size=input_ids.shape[0], + device=input_ids.device, + dtype=inputs_embeds.dtype) + else: + past_key_values = tuple([None] * len(self.layers)) + + if attention_mask is None: + attention_mask = self.get_masks( + input_ids, device=input_ids.device) + + if position_ids is None: + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + mask_token = gMASK if gMASK in input_ids else MASK + use_gmask = True if gMASK in input_ids else False + + mask_positions = [ + seq.tolist().index(mask_token) for seq in input_ids + ] + position_ids = self.get_position_ids( + input_ids, + mask_positions=mask_positions, + device=input_ids.device, + gmask=use_gmask) + + if self.pre_seq_len is not None and attention_mask is not None: + prefix_attention_mask = torch.ones( + batch_size, 1, input_ids.size(-1), + self.pre_seq_len).to(attention_mask.device) + prefix_attention_mask = (prefix_attention_mask < 0.5).bool() + attention_mask = torch.cat((prefix_attention_mask, attention_mask), + dim=3) + + # [seq_len, batch, hidden_size] + hidden_states = inputs_embeds.transpose(0, 1) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if attention_mask is None: + attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() + + else: + attention_mask = attention_mask.to(input_ids.device) + + for i, layer in enumerate(self.layers): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + layer_past = past_key_values[i] + + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, hidden_states, position_ids, attention_mask, + torch.tensor(i), layer_past, use_cache, output_attentions) + else: + layer_ret = layer( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + layer_id=torch.tensor(i), + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions) + + hidden_states = layer_ret[0] + + if use_cache: + presents = presents + (layer_ret[1], ) + + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_ret[2 if use_cache else 1], ) + + # Final layer norm. + hidden_states = self.final_layernorm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, presents, all_hidden_states, all_self_attentions + ] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@MODELS.register_module(Tasks.chat, module_name=Models.chatglm_6b) +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + + def __init__(self, config: ChatGLMConfig): + super().__init__(config) + + # self.hidden_size = config.hidden_size + # self.params_dtype = torch.half + # self.vocab_size = config.vocab_size + self.max_sequence_length = config.max_sequence_length + + self.position_encoding_2d = config.position_encoding_2d + + self.transformer = ChatGLMModel(config) + + self.lm_head = skip_init( + nn.Linear, + config.hidden_size, + config.vocab_size, + bias=False, + dtype=torch.half) + + self.config = config + + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + # loading tokenizer + self.tokenizer = ChatGLMTokenizer.from_pretrained(config.name_or_path) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs['past_key_values'] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format) + + # update attention mask + if 'attention_mask' in model_kwargs: + attention_mask = model_kwargs['attention_mask'] + if attention_mask is not None and attention_mask.dtype == torch.bool: + attention_mask = torch.cat([ + attention_mask, + attention_mask.new_ones((*attention_mask.shape[:3], 1)) + ], + dim=3) # noqa + new_attention_mask = attention_mask[:, :, -1:].clone() + new_attention_mask[..., -1] = False + model_kwargs['attention_mask'] = torch.cat( + [attention_mask, new_attention_mask], dim=2) + + # update position ids + if 'position_ids' in model_kwargs: + position_ids = model_kwargs['position_ids'] + new_position_id = position_ids[..., -1:].clone() + new_position_id[:, 1, :] += 1 + model_kwargs['position_ids'] = torch.cat( + [position_ids, new_position_id], dim=-1) + + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs) -> dict: + batch_size, seq_length = input_ids.shape + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + mask_token = gMASK if gMASK in input_ids else MASK + use_gmask = True if gMASK in input_ids else False + seqs = input_ids.tolist() + mask_positions = [seq.index(mask_token) for seq in seqs] + + # only last token for input_ids if past is not None + if past is not None or past_key_values is not None: + last_token = input_ids[:, -1].unsqueeze(-1) + if attention_mask is not None and attention_mask.dtype == torch.bool: + attention_mask = attention_mask[:, :, -1:] + else: + attention_mask = None + if position_ids is not None: + position_ids = position_ids[..., -1:] + else: + context_lengths = [ + seq.index(self.config.bos_token_id) for seq in seqs + ] + if self.position_encoding_2d: + position_ids = torch.tensor( + [[mask_position, seq_length - context_length] + for mask_position, context_length in zip( + mask_positions, context_lengths)], + dtype=torch.long, + device=input_ids.device).unsqueeze(-1) + else: + position_ids = torch.tensor( + [mask_position for mask_position in mask_positions], + dtype=torch.long, + device=input_ids.device).unsqueeze(-1) + + if past is None: + past = past_key_values + return { + 'input_ids': last_token, + 'past_key_values': past, + 'position_ids': position_ids, + 'attention_mask': attention_mask + } + else: + if attention_mask is not None and attention_mask.dtype != torch.bool: + # logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") + attention_mask = None + if attention_mask is None: + attention_mask = self.get_masks( + input_ids, device=input_ids.device) + if position_ids is None: + position_ids = self.get_position_ids( + input_ids, + device=input_ids.device, + mask_positions=mask_positions, + gmask=use_gmask) + + return { + 'input_ids': input_ids, + 'past_key_values': past, + 'position_ids': position_ids, + 'attention_mask': attention_mask + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits, ) + transformer_outputs[1:] + return ((loss, ) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], + ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple(( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) for layer_past in past) + + def process_response(self, response): + response = response.strip() + response = response.replace('[[训练时间]]', '2023年') + punkts = [ + [',', ','], + ['!', '!'], + [':', ':'], + [';', ';'], + ['\?', '?'], # noqa + ] + for item in punkts: + response = re.sub(r'([\u4e00-\u9fff])%s' % item[0], + r'\1%s' % item[1], response) + response = re.sub(r'%s([\u4e00-\u9fff])' % item[0], + r'%s\1' % item[1], response) + return response + + @torch.no_grad() + def _chat(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 2048, + num_beams=1, + do_sample=True, + top_p=0.7, + temperature=0.95, + logits_processor=None, + **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + 'max_length': max_length, + 'num_beams': num_beams, + 'do_sample': do_sample, + 'top_p': top_p, + 'temperature': temperature, + 'logits_processor': logits_processor, + **kwargs + } + if not history: + prompt = query + else: + prompt = '' + for i, (old_query, response) in enumerate(history): + prompt += '[Round {}]\n问:{}\n答:{}\n'.format( + i, old_query, response) + prompt += '[Round {}]\n问:{}\n答:'.format(len(history), query) + inputs = tokenizer([prompt], return_tensors='pt') + inputs = inputs.to(self.device) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 2048, + do_sample=True, + top_p=0.7, + temperature=0.95, + logits_processor=None, + **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + 'max_length': max_length, + 'do_sample': do_sample, + 'top_p': top_p, + 'temperature': temperature, + 'logits_processor': logits_processor, + **kwargs + } + if not history: + prompt = query + else: + prompt = '' + for i, (old_query, response) in enumerate(history): + prompt += '[Round {}]\n问:{}\n答:{}\n'.format( + i, old_query, response) + prompt += '[Round {}]\n问:{}\n答:'.format(len(history), query) + inputs = tokenizer([prompt], return_tensors='pt') + inputs = inputs.to(self.device) + for outputs in self.stream_generate(**inputs, **gen_kwargs): + outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + new_history = history + [(query, response)] + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], + List[int]]] = None, + **kwargs, + ): + _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[ + -1] # noqa + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + _, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = kwargs.get( + 'max_length') is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + 'This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we' + ' recommend using `max_new_tokens` to control the maximum length of the generation.', + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f'Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=' + f'{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. ' + 'Please refer to the documentation for more information. ' + '(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)', + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = 'decoder_input_ids' if self.config.is_encoder_decoder else 'input_ids' + logger.warning( + f'Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to' + f' {generation_config.max_length}. This can lead to unexpected behavior. You should consider' + ' increasing `max_new_tokens`.') + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList( + ) + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList( + ) + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, + stopping_criteria=stopping_criteria) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation( + input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial( + probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder) + unfinished_sequences = unfinished_sequences.mul( + (sum(next_tokens != i for i in eos_token_id)).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria( + input_ids, scores): + break + yield input_ids + + def quantize(self, bits: int, empty_init=False, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info('Already quantized.') + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer = quantize( + self.transformer, bits, empty_init=empty_init, **kwargs) + return self + + def chat(self, input: Dict) -> Dict: + text = input['text'] + history = input['history'] + # args + if 'max_length' in input: + max_length = input['max_length'] + else: + max_length = 2048 + + if 'temperature' in input: + temperature = input['temperature'] + else: + temperature = 0.95 + + if 'num_beams' in input: + num_beams = input['num_beams'] + else: + num_beams = 1 + + if 'do_sample' in input: + do_sample = input['do_sample'] + else: + do_sample = True + + if type(history) == torch.Tensor: + history = history.tolist() + response, history = self._chat( + self.tokenizer, + text, + history, + max_length=max_length, + temperature=temperature, + num_beams=num_beams, + do_sample=do_sample) + logger.info('Generation finished.') + return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history} diff --git a/modelscope/models/nlp/chatglm/tokenization.py b/modelscope/models/nlp/chatglm/tokenization.py new file mode 100644 index 000000000..77bcde55a --- /dev/null +++ b/modelscope/models/nlp/chatglm/tokenization.py @@ -0,0 +1,463 @@ +"""Tokenization classes for ChatGLM.""" +import os +from typing import Dict, List, Optional, Union + +import numpy as np +import sentencepiece as spm +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.tokenization_utils_base import BatchEncoding, EncodedInput +from transformers.utils import PaddingStrategy, logging + +logger = logging.get_logger(__name__) + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + 'THUDM/chatglm-6b': 2048, +} + + +class TextTokenizer: + + def __init__(self, model_path): + self.sp = spm.SentencePieceProcessor() + self.sp.Load(model_path) + self.num_tokens = self.sp.vocab_size() + + def encode(self, text): + return self.sp.EncodeAsIds(text) + + def decode(self, ids: List[int]): + return self.sp.DecodeIds(ids) + + def tokenize(self, text): + return self.sp.EncodeAsPieces(text) + + def convert_tokens_to_ids(self, tokens): + return [self.sp.PieceToId(token) for token in tokens] + + def convert_token_to_id(self, token): + return self.sp.PieceToId(token) + + def convert_id_to_token(self, idx): + return self.sp.IdToPiece(idx) + + def __len__(self): + return self.num_tokens + + +class SPTokenizer: + + def __init__( + self, + vocab_file, + num_image_tokens=20000, + max_blank_length=80, + byte_fallback=True, + ): + assert vocab_file is not None + self.vocab_file = vocab_file + self.num_image_tokens = num_image_tokens + self.special_tokens = [ + '[MASK]', '[gMASK]', '[sMASK]', '', '', '', + '', '' + ] + self.max_blank_length = max_blank_length + self.byte_fallback = byte_fallback + self.text_tokenizer = TextTokenizer(vocab_file) + + def _get_text_tokenizer(self): + return self.text_tokenizer + + @staticmethod + def get_blank_token(length: int): + assert length >= 2 + return f'<|blank_{length}|>' + + @staticmethod + def get_tab_token(): + return '<|tab|>' + + @property + def num_text_tokens(self): + return self.text_tokenizer.num_tokens + + @property + def num_tokens(self): + return self.num_image_tokens + self.num_text_tokens + + @staticmethod + def _encode_whitespaces(text: str, max_len: int = 80): + text = text.replace('\t', SPTokenizer.get_tab_token()) + for i in range(max_len, 1, -1): + text = text.replace(' ' * i, SPTokenizer.get_blank_token(i)) + return text + + def _preprocess(self, text: str, linebreak=True, whitespaces=True): + if linebreak: + text = text.replace('\n', '') + if whitespaces: + text = self._encode_whitespaces( + text, max_len=self.max_blank_length) + return text + + def encode(self, + text: str, + linebreak=True, + whitespaces=True, + add_dummy_prefix=True) -> List[int]: + """ + @param text: Text to encode. + @param linebreak: Whether to encode newline (\n) in text. + @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. + @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. + @param add_dummy_prefix: Whether to add dummy blank space in the beginning. + """ + text = self._preprocess(text, linebreak, whitespaces) + if not add_dummy_prefix: + text = '' + text + tmp = self._get_text_tokenizer().encode(text) + tokens = [x + self.num_image_tokens for x in tmp] + return tokens if add_dummy_prefix else tokens[2:] + + def decode(self, text_ids: List[int]) -> str: + ids = [int(_id) - self.num_image_tokens for _id in text_ids] + ids = [_id for _id in ids if _id >= 0] + text = self._get_text_tokenizer().decode(ids) + text = text.replace('', '\n') + text = text.replace(SPTokenizer.get_tab_token(), '\t') + for i in range(2, self.max_blank_length + 1): + text = text.replace(self.get_blank_token(i), ' ' * i) + return text + + def tokenize(self, + text: str, + linebreak=True, + whitespaces=True, + add_dummy_prefix=True) -> List[str]: + """ + @param text: Text to encode. + @param linebreak: Whether to encode newline (\n) in text. + @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. + @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. + @param add_dummy_prefix: Whether to add dummy blank space in the beginning. + """ + text = self._preprocess(text, linebreak, whitespaces) + if not add_dummy_prefix: + text = '' + text + tokens = self._get_text_tokenizer().tokenize(text) + return tokens if add_dummy_prefix else tokens[2:] + + def __getitem__(self, x: Union[int, str]): + if isinstance(x, int): + if x < self.num_image_tokens: + return ''.format(x) + else: + return self.text_tokenizer.convert_id_to_token( + x - self.num_image_tokens) + elif isinstance(x, str): + if x.startswith('') and x[7:-1].isdigit(): + return int(x[7:-1]) + else: + return self.text_tokenizer.convert_token_to_id( + x) + self.num_image_tokens + else: + raise ValueError('The key should be str or int.') + + +class ChatGLMTokenizer(PreTrainedTokenizer): + """ + Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding. + + Args: + vocab_file: Path to the vocabulary file. + do_lower_case: Use lower case letters. + remove_space: Remove spaces. + bos_token: The bos token + eos_token: The Eos Token + end_token: The end token + mask_token: The mask token + gmask_token: The gmask token + padding_side: The padding side + num_image_tokens: The `num_image_tokens` in `SPTokenizer` + """ + + vocab_files_names = {'vocab_file': 'ice_text.model'} + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ['input_ids', 'attention_mask', 'position_ids'] + + def __init__(self, + vocab_file, + do_lower_case=False, + remove_space=False, + bos_token='', + eos_token='', + end_token='', + mask_token='[MASK]', + gmask_token='[gMASK]', + padding_side='left', + num_image_tokens=20000, + **kwargs) -> None: + super().__init__( + do_lower_case=do_lower_case, + remove_space=remove_space, + padding_side=padding_side, + bos_token=bos_token, + eos_token=eos_token, + end_token=end_token, + mask_token=mask_token, + gmask_token=gmask_token, + num_image_tokens=num_image_tokens, + **kwargs) + + self.do_lower_case = do_lower_case + self.remove_space = remove_space + self.vocab_file = vocab_file + + self.bos_token = bos_token + self.eos_token = eos_token + self.end_token = end_token + self.mask_token = mask_token + self.gmask_token = gmask_token + + self.sp_tokenizer = SPTokenizer( + vocab_file, num_image_tokens=num_image_tokens) + """ Initialisation """ + + @property + def gmask_token_id(self) -> Optional[int]: + if self.gmask_token is None: + return None + return self.convert_tokens_to_ids(self.gmask_token) + + @property + def end_token_id(self) -> Optional[int]: + """ + `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been + set. + """ + if self.end_token is None: + return None + return self.convert_tokens_to_ids(self.end_token) + + @property + def vocab_size(self): + """ Returns vocab size """ + return self.sp_tokenizer.num_tokens + + def get_vocab(self): + """ Returns vocab as a dict """ + vocab = { + self._convert_id_to_token(i): i + for i in range(self.vocab_size) + } + vocab.update(self.added_tokens_encoder) + return vocab + + def preprocess_text(self, inputs): + if self.remove_space: + outputs = ' '.join(inputs.strip().split()) + else: + outputs = inputs + + if self.do_lower_case: + outputs = outputs.lower() + + return outputs + + def _tokenize(self, text, **kwargs): + """ Returns a tokenized string. """ + text = self.preprocess_text(text) + + seq = self.sp_tokenizer.tokenize(text) + + return seq + + def _decode(self, + token_ids: Union[int, List[int]], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + **kwargs) -> str: + if isinstance(token_ids, int): + token_ids = [token_ids] + if len(token_ids) == 0: + return '' + if self.pad_token_id in token_ids: # remove pad + token_ids = list(filter((self.pad_token_id).__ne__, token_ids)) + return self.sp_tokenizer.decode(token_ids) + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.sp_tokenizer[token] + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.sp_tokenizer[index] + + def save_vocabulary(self, save_directory, filename_prefix=None): + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + filename_prefix (`str`, *optional*): + An optional prefix to add to the named of the saved files. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if os.path.isdir(save_directory): + vocab_file = os.path.join(save_directory, + self.vocab_files_names['vocab_file']) + else: + vocab_file = save_directory + + with open(self.vocab_file, 'rb') as fin: + proto_str = fin.read() + + with open(vocab_file, 'wb') as writer: + writer.write(proto_str) + + return (vocab_file, ) + + def build_inputs_with_special_tokens( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + mask_ids = self.sp_tokenizer[self.mask_token] + gmask_ids = self.sp_tokenizer[self.gmask_token] + eos_id = self.sp_tokenizer[self.eos_token] + if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0: + token_ids_0 += [gmask_ids] + + if token_ids_0[-1] != mask_ids and token_ids_0[-1] != gmask_ids: + token_ids_0 += [self.sp_tokenizer[self.end_token]] + + token_ids_0 += [self.sp_tokenizer[self.bos_token]] + + if token_ids_1 is not None: + if not token_ids_1 or token_ids_1[-1] != eos_id: + token_ids_1 += [eos_id] + token_ids_0 += token_ids_1 + + return token_ids_0 + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + bos_token_id = self.sp_tokenizer[self.bos_token] + mask_token_id = self.sp_tokenizer[self.mask_token] + gmask_token_id = self.sp_tokenizer[self.gmask_token] + assert self.padding_side == 'left' + + required_input = encoded_inputs[self.model_input_names[0]] + seq_length = len(required_input) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and ( + max_length % pad_to_multiple_of != 0): + max_length = ( + (max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len( + required_input) != max_length + + # Initialize attention mask if not present. + if max_length is not None: + if 'attention_mask' not in encoded_inputs: + if bos_token_id in required_input: + context_length = required_input.index(bos_token_id) + else: + context_length = seq_length + attention_mask = np.ones((1, seq_length, seq_length)) + attention_mask = np.tril(attention_mask) + attention_mask[:, :, :context_length] = 1 + attention_mask = np.bool_(attention_mask < 0.5) + encoded_inputs['attention_mask'] = attention_mask + + if 'position_ids' not in encoded_inputs: + position_ids = np.arange(seq_length, dtype=np.int64) + mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id + if mask_token in required_input: + mask_position = required_input.index(mask_token) + position_ids[context_length:] = mask_position + block_position_ids = np.concatenate([ + np.zeros(context_length, dtype=np.int64), + np.arange( + 1, seq_length - context_length + 1, dtype=np.int64) + ]) + encoded_inputs['position_ids'] = np.stack( + [position_ids, block_position_ids], axis=0) + + if needs_to_be_padded: + difference = max_length - len(required_input) + + if 'attention_mask' in encoded_inputs: + encoded_inputs['attention_mask'] = np.pad( + encoded_inputs['attention_mask'], + pad_width=[(0, 0), (difference, 0), (difference, 0)], + mode='constant', + constant_values=True) + if 'token_type_ids' in encoded_inputs: + encoded_inputs['token_type_ids'] = [ + self.pad_token_type_id + ] * difference + encoded_inputs['token_type_ids'] + if 'special_tokens_mask' in encoded_inputs: + encoded_inputs['special_tokens_mask'] = [ + 1 + ] * difference + encoded_inputs['special_tokens_mask'] + if 'position_ids' in encoded_inputs: + encoded_inputs['position_ids'] = np.pad( + encoded_inputs['position_ids'], + pad_width=[(0, 0), (difference, 0)]) + encoded_inputs[self.model_input_names[ + 0]] = [self.pad_token_id] * difference + required_input + + return encoded_inputs diff --git a/modelscope/models/nlp/chatglm2/__init__.py b/modelscope/models/nlp/chatglm2/__init__.py new file mode 100644 index 000000000..a2b5bfea2 --- /dev/null +++ b/modelscope/models/nlp/chatglm2/__init__.py @@ -0,0 +1,46 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .configuration import ChatGLM2Config + from .tokenization import ChatGLM2Tokenizer + from .text_generation import ChatGLM2ForConditionalGeneration + from .quantization import ( + quantize, ) + +else: + _import_structure = { + 'configuration': ['ChatGLM2Config'], + 'text_generation': ['ChatGLM2ForConditionalGeneration'], + 'quantization': ['quantize'], + 'tokenization': [ + 'ChatGLM2Tokenizer', + ], + } + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__) diff --git a/modelscope/models/nlp/chatglm2/configuration.py b/modelscope/models/nlp/chatglm2/configuration.py new file mode 100644 index 000000000..b10db8704 --- /dev/null +++ b/modelscope/models/nlp/chatglm2/configuration.py @@ -0,0 +1,58 @@ +""" ChatGLM model configuration """ + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class ChatGLM2Config(PretrainedConfig): + + def __init__(self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + interleaved_qkv=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + **kwargs): + self.num_layers = num_layers + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + super().__init__(**kwargs) diff --git a/modelscope/models/nlp/chatglm2/quantization.py b/modelscope/models/nlp/chatglm2/quantization.py new file mode 100644 index 000000000..612c9e4ba --- /dev/null +++ b/modelscope/models/nlp/chatglm2/quantization.py @@ -0,0 +1,223 @@ +import base64 +import bz2 +import ctypes +from functools import partial +from typing import List + +import torch +from torch.nn import Linear +from torch.nn.parameter import Parameter +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +try: + from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up + + class Kernel: + + def __init__(self, code: bytes, function_names: List[str]): + self.code = code + self._function_names = function_names + self._cmodule = LazyKernelCModule(self.code) + + for name in self._function_names: + setattr(self, name, KernelFunction(self._cmodule, name)) + + quantization_code = '$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ' # noqa + + kernels = Kernel( + bz2.decompress(base64.b64decode(quantization_code)), + [ + 'int4WeightCompression', + 'int4WeightExtractionFloat', + 'int4WeightExtractionHalf', + 'int8WeightExtractionFloat', + 'int8WeightExtractionHalf', + ], + ) +except Exception as exception: + kernels = None + logger.warning('Failed to load cpm_kernels:' + str(exception)) + + +class W8A16Linear(torch.autograd.Function): + + @staticmethod + def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, + scale_w: torch.Tensor, weight_bit_width): + ctx.inp_shape = inp.size() + ctx.weight_bit_width = weight_bit_width + out_features = quant_w.size(0) + inp = inp.contiguous().view(-1, inp.size(-1)) + weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) + ctx.weight_shape = weight.size() + output = inp.mm(weight.t()) + ctx.save_for_backward(inp, quant_w, scale_w) + return output.view(*(ctx.inp_shape[:-1] + (out_features, ))) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + inp, quant_w, scale_w = ctx.saved_tensors + weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) + grad_output = grad_output.contiguous().view(-1, weight.size(0)) + grad_input = grad_output.mm(weight) + grad_weight = grad_output.t().mm(inp) + return grad_input.view(ctx.inp_shape), grad_weight.view( + ctx.weight_shape), None, None + + +def compress_int4_weight(weight: torch.Tensor): # (n, m) + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + assert m % 2 == 0 + m = m // 2 + out = torch.empty(n, m, dtype=torch.int8, device='cuda') + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + kernels.int4WeightCompression( + gridDim, + blockDim, + 0, + stream, + [ + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int32(n), + ctypes.c_int32(m) + ], + ) + return out + + +def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, + source_bit_width: int): + assert scale_list.dtype in [torch.half, torch.bfloat16] + assert weight.dtype in [torch.int8] + if source_bit_width == 8: + return weight.to(scale_list.dtype) * scale_list[:, None] + elif source_bit_width == 4: + func = ( + kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half + else kernels.int4WeightExtractionBFloat16) + else: + assert False, 'Unsupported bit-width' + + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + out = torch.empty( + n, + m * (8 // source_bit_width), + dtype=scale_list.dtype, + device='cuda') + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + func( + gridDim, + blockDim, + 0, + stream, + [ + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(scale_list.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int32(n), + ctypes.c_int32(m), + ], + ) + return out + + +class QuantizedLinear(torch.nn.Module): + + def __init__(self, + weight_bit_width: int, + weight, + bias=None, + device='cpu', + dtype=None, + empty_init=False, + *args, + **kwargs): + super().__init__() + self.weight_bit_width = weight_bit_width + + shape = weight.shape + + if weight is None or empty_init: + self.weight = torch.empty( + shape[0], + shape[1] * weight_bit_width // 8, + dtype=torch.int8, + device=device) + self.weight_scale = torch.empty( + shape[0], dtype=dtype, device=device) + else: + self.weight_scale = weight.abs().max(dim=-1).values / ( + (2**(weight_bit_width - 1)) - 1) + self.weight = torch.round(weight / self.weight_scale[:, None]).to( + torch.int8) + if weight_bit_width == 4: + self.weight = compress_int4_weight(self.weight) + + self.weight = Parameter(self.weight.to(device), requires_grad=False) + self.weight_scale = Parameter( + self.weight_scale.to(device), requires_grad=False) + self.bias = Parameter( + bias.to(device), requires_grad=False) if bias is not None else None + + def forward(self, input): + output = W8A16Linear.apply(input, self.weight, self.weight_scale, + self.weight_bit_width) + if self.bias is not None: + output = output + self.bias + return output + + +def quantize(model, weight_bit_width, empty_init=False, device=None): + """Replace fp16 linear with quantized linear""" + for layer in model.layers: + layer.self_attention.query_key_value = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.self_attention.query_key_value.weight.to( + torch.cuda.current_device()), + bias=layer.self_attention.query_key_value.bias, + dtype=layer.self_attention.query_key_value.weight.dtype, + device=layer.self_attention.query_key_value.weight.device + if device is None else device, + empty_init=empty_init) + layer.self_attention.dense = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.self_attention.dense.weight.to( + torch.cuda.current_device()), + bias=layer.self_attention.dense.bias, + dtype=layer.self_attention.dense.weight.dtype, + device=layer.self_attention.dense.weight.device + if device is None else device, + empty_init=empty_init) + layer.mlp.dense_h_to_4h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.mlp.dense_h_to_4h.weight.to( + torch.cuda.current_device()), + bias=layer.mlp.dense_h_to_4h.bias, + dtype=layer.mlp.dense_h_to_4h.weight.dtype, + device=layer.mlp.dense_h_to_4h.weight.device + if device is None else device, + empty_init=empty_init) + layer.mlp.dense_4h_to_h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.mlp.dense_4h_to_h.weight.to( + torch.cuda.current_device()), + bias=layer.mlp.dense_4h_to_h.bias, + dtype=layer.mlp.dense_4h_to_h.weight.dtype, + device=layer.mlp.dense_4h_to_h.weight.device + if device is None else device, + empty_init=empty_init) + + return model diff --git a/modelscope/models/nlp/chatglm2/text_generation.py b/modelscope/models/nlp/chatglm2/text_generation.py new file mode 100644 index 000000000..3c510c7a5 --- /dev/null +++ b/modelscope/models/nlp/chatglm2/text_generation.py @@ -0,0 +1,1300 @@ +""" PyTorch ChatGLM model. """ + +import copy +import math +import re +import sys +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import (GenerationConfig, + LogitsProcessorList, ModelOutput, + StoppingCriteriaList) +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from modelscope.metainfo import Models +from modelscope.models import MODELS, Model, TorchModel +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks +from .configuration import ChatGLM2Config + +# flags required to enable jit fusion kernels + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = 'THUDM/ChatGLM2-6B' +_CONFIG_FOR_DOC = 'ChatGLM6BConfig' + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + 'THUDM/chatglm2-6b', + # See all ChatGLM models at https://huggingface.co/models?filter=chatglm +] + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + + def __call__(self, input_ids: torch.LongTensor, + scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class RotaryEmbedding(nn.Module): + + def __init__(self, dim, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / (10000**( + torch.arange(0, dim, 2, device=device, dtype=dtype) / dim)) + self.register_buffer('inv_freq', inv_freq) + self.dim = dim + self.original_impl = original_impl + + def forward_impl(self, + seq_len: int, + n_elem: int, + dtype: torch.dtype, + device: torch.device, + base: int = 10000): + """Enhanced Transformer with Rotary Position Embedding. + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / ( + base**(torch.arange(0, n_elem, 2, dtype=dtype, device=device) + / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=dtype, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack( + [torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16( + ) if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + return self.forward_impl( + max_seq_len, + self.dim, + dtype=self.inv_freq.dtype, + device=self.inv_freq.device) + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, + rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, _, np, _ = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] + - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +class RMSNorm(torch.nn.Module): + + def __init__(self, + normalized_shape, + eps=1e-5, + device=None, + dtype=None, + **kwargs): + super().__init__() + self.weight = torch.nn.Parameter( + torch.empty(normalized_shape, device=device, dtype=dtype)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean( + -1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + return (self.weight * hidden_states).to(input_dtype) + + +class CoreAttention(torch.nn.Module): + + def __init__(self, config: ChatGLM2Config, layer_number): + super(CoreAttention, self).__init__() + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split('.')[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [ + k.permute(1, 2, 0, 3) + for k in [query_layer, key_layer, value_layer] + ] + if attention_mask is None and query_layer.shape[ + 2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, is_causal=True) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_partition, ) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), + query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], + output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], + output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=query_layer.device) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if attention_mask is None and attention_scores.shape[ + 2] == attention_scores.shape[3]: + attention_mask = torch.ones( + output_size[0], + 1, + output_size[2], + output_size[3], + device=attention_scores.device, + dtype=torch.bool) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill( + attention_mask, float('-inf')) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), + query_layer.size(0), value_layer.size(3)) + # change view [sk, b * np, hn] + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view( + output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, + value_layer.transpose(0, 1)) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_partition, ) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config: ChatGLM2Config, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + + self.projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = ( + self.projection_size + 2 * self.hidden_size_per_attention_head + * config.multi_query_group_num) + self.query_key_value = nn.Linear( + config.hidden_size, + self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, + **_config_to_kwargs(config)) + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. + self.dense = nn.Linear( + self.projection_size, + config.hidden_size, + bias=config.add_bias_linear, + device=device, + **_config_to_kwargs(config)) + + def _allocate_memory(self, + inference_max_sequence_len, + batch_size, + device=None, + dtype=None): + if self.multi_query_attention: + num_attention_heads = self.num_multi_query_groups_per_partition + else: + num_attention_heads = self.num_attention_heads_per_partition + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=dtype, + device=device, + ) + + def forward(self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition + * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition + * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition + * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view(query_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) + key_layer = key_layer.view(key_layer.size()[:-1] + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head)) + value_layer = value_layer.view(value_layer.size()[:-1] + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head)) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, # noqa + 3 * self.hidden_size_per_attention_head) # noqa + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, + value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if use_cache: + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition + // self.num_multi_query_groups_per_partition, -1) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition + // self.num_multi_query_groups_per_partition, -1) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, + value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + +def _config_to_kwargs(args): + common_kwargs = { + 'dtype': args.torch_dtype, + } + return common_kwargs + + +class MLP(torch.nn.Module): + """MLP. + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLM2Config, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config)) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config)) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config: ChatGLM2Config, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype) + + # Self attention. + self.self_attention = SelfAttention( + config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype) + + # MLP + self.mlp = MLP(config, device=device) + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout( + attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout( + mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + """Transformer class.""" + + def __init__(self, config: ChatGLM2Config, device=None): + super(GLMTransformer, self).__init__() + + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + def build_layer(layer_number): + return GLMBlock(config, layer_number, device=device) + + self.layers = torch.nn.ModuleList( + [build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer = self._get_layer(index) + + hidden_states, kv_cache = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache) + if use_cache: + presents = presents + (kv_cache, ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + +class ChatGLMPreTrainedModel(TorchModel, PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLM2Config + base_model_prefix = 'transformer' + _no_split_modules = ['GLMBlock'] + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, past_key_values, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones( + batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat( + ( + torch.ones( # noqa + batch_size, + seq_length, + past_length, # noqa + device=input_ids.device), + full_attention_mask), # noqa + dim=-1) # noqa + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze( + 1) # noqa + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = torch.arange( + seq_length, dtype=torch.long, + device=device).unsqueeze(0).repeat(batch_size, 1) + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ChatGLMModel): + module.gradient_checkpointing = value + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.pop('model_dir', None) + kwargs.pop('cfg', None) + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **kwargs) + model.model_dir = model_dir + return model + + +class Embedding(torch.nn.Module): + """Language model embeddings.""" + + def __init__(self, config: ChatGLM2Config, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device) + self.fp32_residual_connection = config.fp32_residual_connection + + def forward(self, input_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +class ChatGLMModel(ChatGLMPreTrainedModel): + + def __init__(self, config: ChatGLM2Config, device=None, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + init_kwargs = {} + if device is not None: + init_kwargs['device'] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = ( + config.hidden_size // config.num_attention_heads + if config.kv_channels is None else config.kv_channels) + + self.rotary_pos_emb = RotaryEmbedding( + rotary_dim // 2, + original_impl=config.original_rope, + device=device, + dtype=config.torch_dtype) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method( + nn.Linear, + config.hidden_size, + config.padded_vocab_size, + bias=False, + dtype=config.torch_dtype, + **init_kwargs) + self.gradient_checkpointing = False + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], + ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if full_attention_mask is None: + if (attention_mask is not None + and not attention_mask.all()) or (past_key_values + and seq_length != 1): + full_attention_mask = self.get_masks( + input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states) + + if not return_dict: + return tuple(v for v in [ + hidden_states, presents, all_hidden_states, all_self_attentions + ] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def quantize(self, weight_bit_width: int): + from .quantization import quantize + quantize(self.encoder, weight_bit_width) + return self + + +@MODELS.register_module(Tasks.chat, module_name=Models.chatglm2_6b) +class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel): + + def __init__(self, config: ChatGLM2Config, empty_init=True, device=None): + super().__init__(config) + + self.max_sequence_length = config.max_length + self.transformer = ChatGLMModel( + config, empty_init=empty_init, device=device) + self.config = config + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs['past_key_values'] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format) + + # update attention mask + if 'attention_mask' in model_kwargs: + attention_mask = model_kwargs['attention_mask'] + model_kwargs['attention_mask'] = torch.cat( + [ # noqa + attention_mask, # noqa + attention_mask.new_ones( + (attention_mask.shape[0], 1)) # noqa + ], + dim=-1) # noqa + + # update position ids + if 'position_ids' in model_kwargs: + position_ids = model_kwargs['position_ids'] + new_position_id = position_ids[..., -1:].clone() + new_position_id += 1 + model_kwargs['position_ids'] = torch.cat( + [position_ids, new_position_id], dim=-1) + + model_kwargs['is_first_forward'] = False + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + is_first_forward: bool = True, + **kwargs) -> dict: + # only last token for input_ids if past is not None + if position_ids is None: + position_ids = self.get_position_ids( + input_ids, device=input_ids.device) + if not is_first_forward: + position_ids = position_ids[..., -1:] + input_ids = input_ids[:, -1:] + return { + 'input_ids': input_ids, + 'past_key_values': past_key_values, + 'position_ids': position_ids, + 'attention_mask': attention_mask, + 'return_last_logit': True + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits, ) + transformer_outputs[1:] + return ((loss, ) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], + ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + Output shares the same memory storage as `past`. + """ + return tuple(( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) for layer_past in past) + + def process_response(self, response): + response = response.strip() + response = response.replace('[[训练时间]]', '2023年') + return response + + def build_inputs(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None): + prompt = '' + for i, (old_query, response) in enumerate(history): + prompt += '[Round {}]\n\n问:{}\n\n答:{}\n\n'.format( + i + 1, old_query, response) + prompt += '[Round {}]\n\n问:{}\n\n答:'.format(len(history) + 1, query) + inputs = tokenizer([prompt], return_tensors='pt') + inputs = inputs.to(self.device) + return inputs + + def build_stream_inputs(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None): + if history: + prompt = '\n\n[Round {}]\n\n问:{}\n\n答:'.format( + len(history) + 1, query) + input_ids = tokenizer.encode(prompt, add_special_tokens=False) + input_ids = input_ids[1:] + inputs = tokenizer.batch_encode_plus([(input_ids, None)], + return_tensors='pt', + add_special_tokens=False) + else: + prompt = '[Round {}]\n\n问:{}\n\n答:'.format(len(history) + 1, query) + inputs = tokenizer([prompt], return_tensors='pt') + inputs = inputs.to(self.device) + return inputs + + @torch.no_grad() + def chat(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 2048, + num_beams=1, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + 'max_length': max_length, + 'num_beams': num_beams, + 'do_sample': do_sample, + 'top_p': top_p, + 'temperature': temperature, + 'logits_processor': logits_processor, + **kwargs + } + inputs = self.build_inputs(tokenizer, query, history=history) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + history = history + [(query, response)] + return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history} + + @torch.no_grad() + def stream_chat(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + past_key_values=None, + max_length: int = 2048, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + return_past_key_values=False, + **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + 'max_length': max_length, + 'do_sample': do_sample, + 'top_p': top_p, + 'temperature': temperature, + 'logits_processor': logits_processor, + **kwargs + } + if past_key_values is None and not return_past_key_values: + inputs = self.build_inputs(tokenizer, query, history=history) + else: + inputs = self.build_stream_inputs( + tokenizer, query, history=history) + if past_key_values is not None: + past_length = past_key_values[0][0].shape[0] + inputs.position_ids += past_length + attention_mask = inputs.attention_mask + attention_mask = torch.cat( + (attention_mask.new_ones(1, past_length), attention_mask), + dim=1) + inputs['attention_mask'] = attention_mask + for outputs in self.stream_generate( + **inputs, + past_key_values=past_key_values, + return_past_key_values=return_past_key_values, + **gen_kwargs): + if return_past_key_values: + outputs, past_key_values = outputs + outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + new_history = history + [(query, response)] + if return_past_key_values: + yield response, new_history, past_key_values + else: + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], + List[int]]] = None, + return_past_key_values=False, + **kwargs, + ): + _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + _, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = kwargs.get( + 'max_length') is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + 'This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we' + ' recommend using `max_new_tokens` to control the maximum length of the generation.', + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f'Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=' + f'{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. ' + 'Please refer to the documentation for more information. ' + '(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)', + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = 'decoder_input_ids' if self.config.is_encoder_decoder else 'input_ids' + logger.warning( + f'Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to' + f' {generation_config.max_length}. This can lead to unexpected behavior. You should consider' + ' increasing `max_new_tokens`.') + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList( + ) + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList( + ) + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, + stopping_criteria=stopping_criteria) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation( + input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial( + probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder) + unfinished_sequences = unfinished_sequences.mul( + (sum(next_tokens != i for i in eos_token_id)).long()) + if return_past_key_values: + yield input_ids, outputs.past_key_values + else: + yield input_ids + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria( + input_ids, scores): + break + + def quantize(self, bits: int, empty_init=False, device=None, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info('Already quantized.') + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer.encoder = quantize( + self.transformer.encoder, + bits, + empty_init=empty_init, + device=device, + **kwargs) + return self diff --git a/modelscope/models/nlp/chatglm2/tokenization.py b/modelscope/models/nlp/chatglm2/tokenization.py new file mode 100644 index 000000000..5036d881a --- /dev/null +++ b/modelscope/models/nlp/chatglm2/tokenization.py @@ -0,0 +1,251 @@ +"""Tokenization classes for ChatGLM.""" +import os +from typing import Dict, List, Optional, Union + +from sentencepiece import SentencePieceProcessor +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.tokenization_utils_base import BatchEncoding, EncodedInput +from transformers.utils import PaddingStrategy, logging + +logger = logging.get_logger(__name__) + + +class SPTokenizer: + + def __init__(self, model_path: str): + # reload tokenizer + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + + # BOS / EOS token IDs + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.eos_id() + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + special_tokens = ['[MASK]', '[gMASK]', '[sMASK]', 'sop', 'eop'] + self.special_tokens = {} + self.index_special_tokens = {} + for token in special_tokens: + self.special_tokens[token] = self.n_words + self.index_special_tokens[self.n_words] = token + self.n_words += 1 + + def tokenize(self, s: str): + return self.sp_model.EncodeAsPieces(s) + + def encode(self, + s: str, + bos: bool = False, + eos: bool = False) -> List[int]: + assert type(s) is str + t = self.sp_model.encode(s) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: List[int]) -> str: + return self.sp_model.decode(t) + + def decode_tokens(self, tokens: List[str]) -> str: + text = self.sp_model.DecodePieces(tokens) + return text + + def convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + if token in self.special_tokens: + return self.special_tokens[token] + return self.sp_model.PieceToId(token) + + def convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index in self.index_special_tokens: + return '' + return self.sp_model.IdToPiece(index) + + +class ChatGLM2Tokenizer(PreTrainedTokenizer): + vocab_files_names = {'vocab_file': 'tokenizer.model'} + + model_input_names = ['input_ids', 'attention_mask', 'position_ids'] + + def __init__(self, vocab_file, padding_side='left', **kwargs): + super().__init__(padding_side=padding_side, **kwargs) + self.name = 'GLMTokenizer' + + self.tokenizer = SPTokenizer(vocab_file) + self.special_tokens = { + '': self.tokenizer.bos_id, + '': self.tokenizer.eos_id, + '': self.tokenizer.pad_id + } + + def get_command(self, token): + if token in self.special_tokens: + return self.special_tokens[token] + assert token in self.tokenizer.special_tokens, f'{token} is not a special token for {self.name}' + return self.tokenizer.special_tokens[token] + + @property + def pad_token(self) -> str: + return '' + + @property + def pad_token_id(self): + return self.get_command('') + + @property + def eos_token_id(self): + return self.get_command('') + + @property + def vocab_size(self): + return self.tokenizer.n_words + + def get_vocab(self): + """ Returns vocab as a dict """ + vocab = { + self._convert_id_to_token(i): i + for i in range(self.vocab_size) + } + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text, **kwargs): + return self.tokenizer.tokenize(text) + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.tokenizer.convert_token_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.tokenizer.convert_id_to_token(index) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + return self.tokenizer.decode_tokens(tokens) + + def save_vocabulary(self, save_directory, filename_prefix=None): + """ + Save the vocabulary and special tokens file to a directory. + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + filename_prefix (`str`, *optional*): + An optional prefix to add to the named of the saved files. + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if os.path.isdir(save_directory): + vocab_file = os.path.join(save_directory, + self.vocab_files_names['vocab_file']) + else: + vocab_file = save_directory + + with open(self.vocab_file, 'rb') as fin: + proto_str = fin.read() + + with open(vocab_file, 'wb') as writer: + writer.write(proto_str) + + return (vocab_file, ) + + def get_prefix_tokens(self): + prefix_tokens = [self.get_command('[gMASK]'), self.get_command('sop')] + return prefix_tokens + + def build_inputs_with_special_tokens( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + prefix_tokens = self.get_prefix_tokens() + token_ids_0 = prefix_tokens + token_ids_0 + if token_ids_1 is not None: + token_ids_0 = token_ids_0 + token_ids_1 + [ + self.get_command('') + ] + return token_ids_0 + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + assert self.padding_side == 'left' + + required_input = encoded_inputs[self.model_input_names[0]] + seq_length = len(required_input) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and ( + max_length % pad_to_multiple_of != 0): + max_length = ( + (max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len( + required_input) != max_length + + # Initialize attention mask if not present. + if 'attention_mask' not in encoded_inputs: + encoded_inputs['attention_mask'] = [1] * seq_length + + if 'position_ids' not in encoded_inputs: + encoded_inputs['position_ids'] = list(range(seq_length)) + + if needs_to_be_padded: + difference = max_length - len(required_input) + + if 'attention_mask' in encoded_inputs: + encoded_inputs['attention_mask'] = [ + 0 + ] * difference + encoded_inputs['attention_mask'] + if 'position_ids' in encoded_inputs: + encoded_inputs['position_ids'] = [ + 0 + ] * difference + encoded_inputs['position_ids'] + encoded_inputs[self.model_input_names[ + 0]] = [self.pad_token_id] * difference + required_input + + return encoded_inputs diff --git a/modelscope/models/nlp/use/user_satisfaction_estimation.py b/modelscope/models/nlp/use/user_satisfaction_estimation.py index f006f5445..77b124142 100644 --- a/modelscope/models/nlp/use/user_satisfaction_estimation.py +++ b/modelscope/models/nlp/use/user_satisfaction_estimation.py @@ -27,7 +27,8 @@ class UserSatisfactionEstimation(TorchModel): def __init__(self, model_dir: str, bert_name: str = None, - device: str = None): + device: str = None, + **kwargs): """initialize the user satisfaction estimation model from the `model_dir` path. The default preprocessor for this task is DialogueClassificationUsePreprocessor. @@ -36,7 +37,7 @@ def __init__(self, bert_name: The pretrained model, default bert-base-chinese device: The device of running model, default cpu """ - super().__init__(model_dir) + super().__init__(model_dir, **kwargs) self.model_dir = model_dir self.bert_name = bert_name if bert_name is not None else 'bert-base-chinese' self.device = 'cpu' diff --git a/modelscope/msdatasets/context/dataset_context_config.py b/modelscope/msdatasets/context/dataset_context_config.py index 4007d60c4..48124d786 100644 --- a/modelscope/msdatasets/context/dataset_context_config.py +++ b/modelscope/msdatasets/context/dataset_context_config.py @@ -17,7 +17,7 @@ def __init__(self, dataset_name: Union[str, list], namespace: str, data_files: Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]], download_mode: DownloadMode, cache_root_dir: str, - use_streaming: bool, **kwargs): + use_streaming: bool, stream_batch_size: int, **kwargs): self._download_config = None self._data_meta_config = None @@ -42,6 +42,7 @@ def __init__(self, dataset_name: Union[str, list], namespace: str, self.data_files = data_files self.cache_root_dir = cache_root_dir self.use_streaming = use_streaming + self.stream_batch_size = stream_batch_size self.download_virgo_files: bool = False @property diff --git a/modelscope/msdatasets/data_files/data_files_manager.py b/modelscope/msdatasets/data_files/data_files_manager.py index d5f43533b..206a4d74b 100644 --- a/modelscope/msdatasets/data_files/data_files_manager.py +++ b/modelscope/msdatasets/data_files/data_files_manager.py @@ -13,8 +13,8 @@ from modelscope.msdatasets.download.download_config import DataDownloadConfig from modelscope.msdatasets.download.download_manager import ( DataDownloadManager, DataStreamingDownloadManager) -from modelscope.utils.constant import (DatasetPathName, DownloadMode, - MetaDataFields) +from modelscope.utils.constant import (META_FILES_FORMAT, DatasetPathName, + DownloadMode, MetaDataFields) class DataFilesManager(object): @@ -85,7 +85,8 @@ def get_data_files_builder(self) -> Union[DatasetBuilder, None]: builder = TaskSpecificDatasetBuilder( dataset_context_config=self.dataset_context_config) - elif meta_data_file.endswith('.csv'): + elif meta_data_file and os.path.splitext( + meta_data_file)[-1] in META_FILES_FORMAT: builder = CsvDatasetBuilder( dataset_context_config=self.dataset_context_config) else: diff --git a/modelscope/msdatasets/dataset_cls/dataset.py b/modelscope/msdatasets/dataset_cls/dataset.py index cebfcfba2..48a5ab516 100644 --- a/modelscope/msdatasets/dataset_cls/dataset.py +++ b/modelscope/msdatasets/dataset_cls/dataset.py @@ -91,41 +91,37 @@ def __iter__(self): class NativeIterableDataset(IterableDataset): """The modelscope iterable dataset class.""" - def __init__(self, ex_iterable, info, split): + def __init__(self, ex_iterable, info, split, stream_batch_size=1): super().__init__(ex_iterable=ex_iterable, info=info, split=split) + self.stream_batch_size = stream_batch_size def __iter__(self): - for key, entity in tqdm( - self._iter(), + for item in tqdm( + self.iter( + batch_size=self.stream_batch_size, drop_last_batch=False), desc='Overall progress', total=self.n_shards, dynamic_ncols=True): - if isinstance(entity, dict): - ret = {} + ret = {} + if isinstance(item, dict): try: - for k, v in entity.items(): + for k, v in item.items(): ret[k] = v if k.endswith(':FILE'): dl_manager = self._ex_iterable.kwargs.get( 'dl_manager') ex_cache_path = dl_manager.download_and_extract(v) + if isinstance(ex_cache_path, str): + ex_cache_path = [ex_cache_path] ret[k] = ex_cache_path - if k.endswith('Image:FILE'): - from PIL import Image - ret[k - + ':Object'] = Image.open(fp=ex_cache_path) - if k.endswith('Audio:FILE'): - import torchaudio - waveform_and_rate = torchaudio.load( - ex_cache_path) - ret[k + ':Object'] = waveform_and_rate + except Exception as e: logger.error(e) - ret = {} - - entity = ret + ret = item + else: + ret = item - yield entity + yield ret def __len__(self): return self.n_shards @@ -138,9 +134,19 @@ def head(self, n=5): n (int): Number of rows to return. Returns: - Dict[str, list]: e.g. {'col1': [val11, val12, ...], 'col2': [val21, val22, ...]} + list: The list of results, e.g. [{'id': 'abc123', 'text': 'hello world'}, ...] """ - return self._head(n=n) + # return self._head(n=n) + res = [] + if n <= 0: + return res + iter_num = 0 + for item in self.__iter__(): + if iter_num >= n: + break + res.append(item) + iter_num += 1 + return res class VirgoDataset(object): diff --git a/modelscope/msdatasets/download/dataset_builder.py b/modelscope/msdatasets/download/dataset_builder.py index 796e3d83d..140503f05 100644 --- a/modelscope/msdatasets/download/dataset_builder.py +++ b/modelscope/msdatasets/download/dataset_builder.py @@ -206,7 +206,7 @@ def _download_and_prepare(self, dl_manager, download_mode): os.makedirs(target_cache_dir, exist_ok=True) self.local_meta_csv_paths = { - k: HubApi.fetch_csv_from_url(v, target_cache_dir) + k: HubApi.fetch_meta_files_from_url(v, target_cache_dir) for k, v in self.meta_data_files.items() } @@ -301,6 +301,7 @@ def __init__(self, dataset_context_config: DatasetContextConfig): self.meta_data_files = dataset_context_config.data_meta_config.meta_data_files self.zip_data_files = dataset_context_config.data_meta_config.zip_data_files self.input_config_kwargs = dataset_context_config.config_kwargs + self.stream_batch_size = dataset_context_config.stream_batch_size self.cache_build_dir = os.path.join(self.cache_root_dir, self.namespace, self.dataset_name, @@ -433,7 +434,10 @@ def _as_streaming_dataset_single( ex_iterable = self._get_examples_iterable_for_split(splits_generator) return NativeIterableDataset( - ex_iterable, info=self.info, split=splits_generator.name) + ex_iterable, + info=self.info, + split=splits_generator.name, + stream_batch_size=self.stream_batch_size) def _generate_tables(self, **gen_kwargs): @@ -477,8 +481,8 @@ def _generate_tables(self, **gen_kwargs): raise f'Neither column meta nor data file found in {self.dataset_name}.json .' def _get_meta_csv_df(self, meta_file_url: str) -> None: - if not self.meta_csv_df: - meta_csv_file_path = HubApi.fetch_csv_from_url( + if self.meta_csv_df is None or self.meta_csv_df.empty: + meta_csv_file_path = HubApi.fetch_meta_files_from_url( meta_file_url, self.meta_cache_dir) self.meta_csv_df = pd.read_csv( meta_csv_file_path, diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py index b9ff9971b..b720ada62 100644 --- a/modelscope/msdatasets/ms_dataset.py +++ b/modelscope/msdatasets/ms_dataset.py @@ -82,12 +82,6 @@ def __getitem__(self, key): return self._hf_ds[key] def __len__(self): - if isinstance(self._hf_ds, IterableDataset) or isinstance( - self._hf_ds, NativeIterableDataset): - logger.warning( - f'object of type `{self._hf_ds.__class__.__name__}` has default length 1' - ) - return 1 return len(self._hf_ds) @property @@ -170,6 +164,7 @@ def load( REUSE_DATASET_IF_EXISTS, cache_dir: Optional[str] = MS_DATASETS_CACHE, use_streaming: Optional[bool] = False, + stream_batch_size: Optional[int] = 1, custom_cfg: Optional[Config] = Config(), token: Optional[str] = None, **config_kwargs, @@ -197,6 +192,7 @@ def load( use_streaming (bool, Optional): If set to True, no need to download all data files. Instead, it streams the data progressively, and returns NativeIterableDataset or a dict of NativeIterableDataset. + stream_batch_size (int, Optional): The batch size of the streaming data. custom_cfg (str, Optional): Model configuration, this can be used for custom datasets. see https://modelscope.cn/docs/Configuration%E8%AF%A6%E8%A7%A3 token (str, Optional): SDK token of ModelScope. @@ -251,6 +247,7 @@ def load( download_mode=download_mode, cache_root_dir=cache_dir, use_streaming=use_streaming, + stream_batch_size=stream_batch_size, **config_kwargs) # Load from local disk @@ -273,11 +270,12 @@ def load( dataset_context_config).load_dataset( RemoteDataLoaderType.HF_DATA_LOADER) dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target) - dataset_inst._dataset_context_config = dataset_context_config - if custom_cfg: - dataset_inst.to_custom_dataset( - custom_cfg=custom_cfg, **config_kwargs) - dataset_inst.is_custom = True + if isinstance(dataset_inst, MsDataset): + dataset_inst._dataset_context_config = dataset_context_config + if custom_cfg: + dataset_inst.to_custom_dataset( + custom_cfg=custom_cfg, **config_kwargs) + dataset_inst.is_custom = True return dataset_inst # Load from the modelscope hub elif hub == Hubs.modelscope: diff --git a/modelscope/msdatasets/utils/dataset_utils.py b/modelscope/msdatasets/utils/dataset_utils.py index dc7f9b38d..b40915eb8 100644 --- a/modelscope/msdatasets/utils/dataset_utils.py +++ b/modelscope/msdatasets/utils/dataset_utils.py @@ -207,7 +207,7 @@ def get_dataset_files(subset_split_into: dict, if args_dict and args_dict.get(MetaDataFields.ARGS_BIG_DATA): meta_csv_file_url = meta_map[split] - meta_csv_file_path = HubApi.fetch_csv_from_url( + meta_csv_file_path = HubApi.fetch_meta_files_from_url( meta_csv_file_url, meta_cache_dir) csv_delimiter = context_config.config_kwargs.get('delimiter', ',') diff --git a/modelscope/pipelines/audio/ans_dfsmn_pipeline.py b/modelscope/pipelines/audio/ans_dfsmn_pipeline.py index fad77091a..d2cacc5a0 100644 --- a/modelscope/pipelines/audio/ans_dfsmn_pipeline.py +++ b/modelscope/pipelines/audio/ans_dfsmn_pipeline.py @@ -70,7 +70,8 @@ def stft(x): HOP_LENGTH, STFT_WIN_LEN, center=False, - window=window) + window=window, + return_complex=False) def istft(x, slen): return librosa.istft( diff --git a/modelscope/pipelines/audio/linear_aec_pipeline.py b/modelscope/pipelines/audio/linear_aec_pipeline.py index f07b77a53..aa7a1cd6c 100644 --- a/modelscope/pipelines/audio/linear_aec_pipeline.py +++ b/modelscope/pipelines/audio/linear_aec_pipeline.py @@ -82,18 +82,19 @@ def __init__(self, model, **kwargs): window = torch.hamming_window(winlen, periodic=False) def stft(x): - return torch.stft( - x, - n_fft, - hop_length, - winlen, - center=False, - window=window.to(x.device), - return_complex=False) + return torch.view_as_real( + torch.stft( + x, + n_fft, + hop_length, + winlen, + center=False, + window=window.to(x.device), + return_complex=True)) def istft(x, slen): return torch.istft( - x, + torch.view_as_complex(x), n_fft, hop_length, winlen, diff --git a/modelscope/pipelines/audio/segmentation_clustering_pipeline.py b/modelscope/pipelines/audio/segmentation_clustering_pipeline.py new file mode 100644 index 000000000..020f6268b --- /dev/null +++ b/modelscope/pipelines/audio/segmentation_clustering_pipeline.py @@ -0,0 +1,325 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import io +from typing import Any, Dict, List, Union + +import numpy as np +import soundfile as sf +import torch +import torchaudio + +from modelscope.fileio import File +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import InputModel, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['SegmentationClusteringPipeline'] + + +@PIPELINES.register_module( + Tasks.speaker_diarization, module_name=Pipelines.segmentation_clustering) +class SegmentationClusteringPipeline(Pipeline): + """Segmentation and Clustering Pipeline + use `model` to create a Segmentation and Clustering Pipeline. + + Args: + model (SegmentationClusteringPipeline): A model instance, or a model local dir, or a model id in the model hub. + kwargs (dict, `optional`): + Extra kwargs passed into the pipeline's constructor. + Example: + >>> from modelscope.pipelines import pipeline + >>> from modelscope.utils.constant import Tasks + >>> p = pipeline( + >>> task=Tasks.speaker_diarization, model='damo/speech_campplus_speaker-diarization_common') + >>> print(p(audio)) + + """ + + def __init__(self, model: InputModel, **kwargs): + """use `model` to create a speaker diarization pipeline for prediction + Args: + model (str): a valid offical model id + """ + super().__init__(model=model, **kwargs) + self.config = self.model.other_config + config = { + 'seg_dur': 1.5, + 'seg_shift': 0.75, + 'batch_size': 128, + } + self.config.update(config) + self.fs = self.config['sample_rate'] + self.sv_pipeline = pipeline( + task='speaker-verification', model=self.config['speaker_model']) + + def __call__(self, + audio: Union[str, np.ndarray, list], + output_res=False, + **params) -> Dict[str, Any]: + """ extract the speaker embeddings of input audio and do cluster + Args: + audio (str, np.ndarray, list): If it is represented as a str or a np.ndarray, it + should be a complete speech signal and requires VAD preprocessing. If the audio + is represented as a list, it should contain only the effective speech segments + obtained through VAD preprocessing. The list should be formatted as [[0(s),3.2, + np.ndarray], [5.3,9.1, np.ndarray], ...]. Each element is a sublist that contains + the start time, end time, and the numpy array of the speech segment respectively. + """ + self.config.update(params) + # vad + logger.info('Doing VAD...') + vad_segments = self.preprocess(audio) + # check input data + self.check_audio_list(vad_segments) + # segmentation + logger.info('Doing segmentation...') + segments = self.chunk(vad_segments) + # embedding + logger.info('Extracting embeddings...') + embeddings = self.forward(segments) + # clustering + logger.info('Clustering...') + labels = self.clustering(embeddings) + # post processing + logger.info('Post processing...') + output = self.postprocess(segments, vad_segments, labels, embeddings) + return {OutputKeys.TEXT: output} + + def forward(self, input: list) -> np.ndarray: + bs = self.config['batch_size'] + x = [] + embeddings = [] + for i, s in enumerate(input): + x.append(s[2]) + if len(x) >= bs: + x = np.stack(x) + _, embs = self.sv_pipeline(x, output_emb=True) + embeddings.append(embs) + x = [] + if len(x) > 0: + x = np.stack(x) + _, embs = self.sv_pipeline(x, output_emb=True) + embeddings.append(embs) + x = [] + embeddings = np.concatenate(embeddings) + return embeddings + + def clustering(self, embeddings: np.ndarray) -> np.ndarray: + labels = self.model(embeddings, **self.config) + return labels + + def postprocess(self, segments: list, vad_segments: list, + labels: np.ndarray, embeddings: np.ndarray) -> list: + assert len(segments) == len(labels) + labels = self.correct_labels(labels) + distribute_res = [] + for i in range(len(segments)): + distribute_res.append([segments[i][0], segments[i][1], labels[i]]) + # merge the same speakers chronologically + distribute_res = self.merge_seque(distribute_res) + + # accquire speaker center + spk_embs = [] + for i in range(labels.max() + 1): + spk_emb = embeddings[labels == i].mean(0) + spk_embs.append(spk_emb) + spk_embs = np.stack(spk_embs) + + def is_overlapped(t1, t2): + if t1 > t2 + 1e-4: + return True + return False + + # distribute the overlap region + for i in range(1, len(distribute_res)): + if is_overlapped(distribute_res[i - 1][1], distribute_res[i][0]): + p = (distribute_res[i][0] + distribute_res[i - 1][1]) / 2 + if 'change_locator' in self.config: + if not hasattr(self, 'change_locator_pipeline'): + self.change_locator_pipeline = pipeline( + task=Tasks.speaker_diarization, + model=self.config['change_locator']) + short_utt_st = max(p - 1.5, distribute_res[i - 1][0]) + short_utt_ed = min(p + 1.5, distribute_res[i][1]) + if short_utt_ed - short_utt_st > 1: + audio_data = self.cut_audio(short_utt_st, short_utt_ed, + vad_segments) + spk1 = distribute_res[i - 1][2] + spk2 = distribute_res[i][2] + _, ct = self.change_locator_pipeline( + audio_data, [spk_embs[spk1], spk_embs[spk2]], + output_res=True) + if ct is not None: + p = short_utt_st + ct + distribute_res[i][0] = p + distribute_res[i - 1][1] = p + + # smooth the result + distribute_res = self.smooth(distribute_res) + + return distribute_res + + def preprocess(self, audio: Union[str, np.ndarray, list]) -> list: + if isinstance(audio, list): + audio.sort(key=lambda x: x[0]) + return audio + elif isinstance(audio, str): + file_bytes = File.read(audio) + audio, fs = sf.read(io.BytesIO(file_bytes), dtype='float32') + if len(audio.shape) == 2: + audio = audio[:, 0] + if fs != self.fs: + logger.info( + f'[WARNING]: The sample rate of audio is not {self.fs}, resample it.' + ) + audio, fs = torchaudio.sox_effects.apply_effects_tensor( + torch.from_numpy(audio).unsqueeze(0), + fs, + effects=[['rate', str(self.fs)]]) + audio = audio.squeeze(0).numpy() + assert len(audio.shape) == 1, 'modelscope error: Wrong audio format.' + if audio.dtype in ['int16', 'int32', 'int64']: + audio = (audio / (1 << 15)).astype('float32') + if not hasattr(self, 'vad_pipeline'): + self.vad_pipeline = pipeline( + task=Tasks.voice_activity_detection, + model=self.config['vad_model']) + vad_time = self.vad_pipeline(audio, audio_fs=self.fs) + vad_segments = [] + for t in vad_time['text']: + st = t[0] / 1000 + ed = t[1] / 1000 + vad_segments.append( + [st, ed, audio[int(st * self.fs):int(ed * self.fs)]]) + + return vad_segments + + def check_audio_list(self, audio: list): + audio_dur = 0 + for i in range(len(audio)): + seg = audio[i] + assert seg[1] >= seg[0], 'modelscope error: Wrong time stamps.' + assert isinstance(seg[2], + np.ndarray), 'modelscope error: Wrong data type.' + assert int(seg[1] * self.fs) - int( + seg[0] * self.fs + ) == seg[2].shape[ + 0], 'modelscope error: audio data in list is inconsistent with time length.' + if i > 0: + assert seg[0] >= audio[ + i - 1][1], 'modelscope error: Wrong time stamps.' + audio_dur += seg[1] - seg[0] + if audio[i][2].dtype in ['int16', 'int32', 'int64']: + audio[i][2] = (audio[i][2] / (1 << 15)).astype('float32') + assert audio_dur > 10, 'modelscope error: The effective audio duration is too short.' + + def chunk(self, vad_segments: list) -> list: + + def seg_chunk(seg_data): + seg_st = seg_data[0] + data = seg_data[2] + chunk_len = int(self.config['seg_dur'] * self.fs) + chunk_shift = int(self.config['seg_shift'] * self.fs) + last_chunk_ed = 0 + seg_res = [] + for chunk_st in range(0, data.shape[0], chunk_shift): + chunk_ed = min(chunk_st + chunk_len, data.shape[0]) + if chunk_ed <= last_chunk_ed: + break + last_chunk_ed = chunk_ed + chunk_st = max(0, chunk_ed - chunk_len) + chunk_data = data[chunk_st:chunk_ed] + if chunk_data.shape[0] < chunk_len: + chunk_data = np.pad(chunk_data, + (0, chunk_len - chunk_data.shape[0]), + 'constant') + seg_res.append([ + chunk_st / self.fs + seg_st, chunk_ed / self.fs + seg_st, + chunk_data + ]) + return seg_res + + segs = [] + for i, s in enumerate(vad_segments): + segs.extend(seg_chunk(s)) + + return segs + + def cut_audio(self, cut_st: float, cut_ed: float, + audio: Union[np.ndarray, list]) -> np.ndarray: + # collect audio data given the start and end time. + if isinstance(audio, np.ndarray): + return audio[int(cut_st * self.fs):int(cut_ed * self.fs)] + elif isinstance(audio, list): + for i in range(len(audio)): + if i == 0: + if cut_st < audio[i][1]: + st_i = i + else: + if cut_st >= audio[i - 1][1] and cut_st < audio[i][1]: + st_i = i + + if i == len(audio) - 1: + if cut_ed > audio[i][0]: + ed_i = i + else: + if cut_ed > audio[i][0] and cut_ed <= audio[i + 1][0]: + ed_i = i + audio_segs = audio[st_i:ed_i + 1] + cut_data = [] + for i in range(len(audio_segs)): + s_st, s_ed, data = audio_segs[i] + cut_data.append( + data[int((max(cut_st, s_st) - s_st) + * self.fs):int((min(cut_ed, s_ed) - s_st) + * self.fs)]) + cut_data = np.concatenate(cut_data) + return cut_data + else: + raise ValueError('modelscope error: Wrong audio format.') + + def correct_labels(self, labels): + labels_id = 0 + id2id = {} + new_labels = [] + for i in labels: + if i not in id2id: + id2id[i] = labels_id + labels_id += 1 + new_labels.append(id2id[i]) + return np.array(new_labels) + + def merge_seque(self, distribute_res): + res = [distribute_res[0]] + for i in range(1, len(distribute_res)): + if distribute_res[i][2] != res[-1][2] or distribute_res[i][ + 0] > res[-1][1]: + res.append(distribute_res[i]) + else: + res[-1][1] = distribute_res[i][1] + return res + + def smooth(self, res, mindur=1): + # short segments are assigned to nearest speakers. + for i in range(len(res)): + res[i][0] = round(res[i][0], 2) + res[i][1] = round(res[i][1], 2) + if res[i][1] - res[i][0] < mindur: + if i == 0: + res[i][2] = res[i + 1][2] + elif i == len(res) - 1: + res[i][2] = res[i - 1][2] + elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]: + res[i][2] = res[i - 1][2] + else: + res[i][2] = res[i + 1][2] + # merge the speakers + res = self.merge_seque(res) + + return res diff --git a/modelscope/pipelines/audio/speaker_change_locating_pipeline.py b/modelscope/pipelines/audio/speaker_change_locating_pipeline.py index 0bab08ace..a50a8f52f 100644 --- a/modelscope/pipelines/audio/speaker_change_locating_pipeline.py +++ b/modelscope/pipelines/audio/speaker_change_locating_pipeline.py @@ -6,6 +6,7 @@ import numpy as np import soundfile as sf import torch +import torchaudio from modelscope.fileio import File from modelscope.metainfo import Pipelines @@ -46,10 +47,14 @@ def __init__(self, model: InputModel, **kwargs): """ super().__init__(model=model, **kwargs) self.model_config = self.model.model_config - self.config = self.model.model_config - self.anchor_size = self.config['anchor_size'] - - def __call__(self, audio: str, embds: List = None) -> Dict[str, Any]: + self.anchor_size = self.model_config['anchor_size'] + + def __call__( + self, + audio: Union[str, np.ndarray], + embds: Union[list, np.ndarray] = None, + output_res=False, + ): if embds is not None: assert len(embds) == 2 assert isinstance(embds[0], np.ndarray) and isinstance( @@ -65,41 +70,58 @@ def __call__(self, audio: str, embds: List = None) -> Dict[str, Any]: np.stack([embd1, embd2], axis=1).flatten(), np.stack([embd3, embd4], axis=1).flatten(), ] - anchors = torch.from_numpy(np.stack(embds, - axis=0)).float().unsqueeze(0) + if isinstance(embds, list): + anchors = np.stack(embds, axis=0) + anchors = torch.from_numpy(anchors).unsqueeze(0).float() output = self.preprocess(audio) output = self.forward(output, anchors) - output = self.postprocess(output) + output, p = self.postprocess(output) - return output + if output_res: + return output, p + else: + return output def forward(self, input: torch.Tensor, anchors: torch.Tensor): output = self.model(input, anchors) return output - def postprocess(self, input: torch.Tensor) -> Dict[str, Any]: + def postprocess(self, input: torch.Tensor): predict = np.where(np.diff(input.argmax(-1).numpy())) try: predict = predict[0][0] * 0.01 + 0.02 predict = round(predict, 2) - return {OutputKeys.TEXT: f'The change point is at {predict}s.'} + return { + OutputKeys.TEXT: f'The change point is at {predict}s.' + }, predict except Exception: - return {OutputKeys.TEXT: 'No change point is found.'} + return {OutputKeys.TEXT: 'No change point is found.'}, None - def preprocess(self, input: str) -> torch.Tensor: + def preprocess(self, input: Union[str, np.ndarray]) -> torch.Tensor: if isinstance(input, str): file_bytes = File.read(input) data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32') if len(data.shape) == 2: data = data[:, 0] - if fs != self.model_config['sample_rate']: - raise ValueError( - 'modelscope error: Only support %d sample rate files' - % self.model_cfg['sample_rate']) data = torch.from_numpy(data).unsqueeze(0) + if fs != self.model_config['sample_rate']: + logger.warning( + 'The sample rate of audio is not %d, resample it.' + % self.model_config['sample_rate']) + data, fs = torchaudio.sox_effects.apply_effects_tensor( + data, + fs, + effects=[['rate', + str(self.model_config['sample_rate'])]]) + elif isinstance(input, np.ndarray): + if input.dtype in ['int16', 'int32', 'int64']: + input = (input / (1 << 15)).astype('float32') + data = torch.from_numpy(input) + if len(data.shape) == 1: + data = data.unsqueeze(0) else: raise ValueError( - 'modelscope error: The input type is restricted to audio file address' - % i) + 'modelscope error: The input type is restricted to audio file address and numpy array.' + ) return data diff --git a/modelscope/pipelines/audio/speaker_verification_light_pipeline.py b/modelscope/pipelines/audio/speaker_verification_light_pipeline.py index 5cff800ac..8c6212fd1 100644 --- a/modelscope/pipelines/audio/speaker_verification_light_pipeline.py +++ b/modelscope/pipelines/audio/speaker_verification_light_pipeline.py @@ -1,10 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import io +import os from typing import Any, Dict, List, Union +import numpy as np import soundfile as sf import torch +import torchaudio from modelscope.fileio import File from modelscope.metainfo import Pipelines @@ -49,62 +52,115 @@ def __init__(self, model: InputModel, **kwargs): self.thr = self.config['yesOrno_thr'] def __call__(self, - in_audios: List[str], - thr: float = None) -> Dict[str, Any]: + in_audios: Union[np.ndarray, list], + save_dir: str = None, + output_emb: bool = False, + thr: float = None): if thr is not None: self.thr = thr if self.thr < -1 or self.thr > 1: raise ValueError( 'modelscope error: the thr value should be in [-1, 1], but found to be %f.' % self.thr) - outputs = self.preprocess(in_audios) - outputs = self.forward(outputs) - outputs = self.postprocess(outputs) - - return outputs - - def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - emb1 = self.model(inputs['data1']) - emb2 = self.model(inputs['data2']) - - return {'emb1': emb1, 'emb2': emb2} + wavs = self.preprocess(in_audios) + embs = self.forward(wavs) + outputs = self.postprocess(embs, in_audios, save_dir) + if output_emb: + return outputs, embs.numpy() + else: + return outputs + + def forward(self, inputs: Union[torch.Tensor, list]): + if isinstance(inputs, list): + embs = [] + for x in inputs: + embs.append(self.model(x)) + embs = torch.cat(embs) + else: + embs = self.model(inputs) + return embs + + def postprocess(self, + inputs: torch.Tensor, + in_audios: Union[np.ndarray, list], + save_dir=None): + if isinstance(in_audios[0], str): + if save_dir is not None: + # save the embeddings + os.makedirs(save_dir, exist_ok=True) + for i, p in enumerate(in_audios): + save_path = os.path.join( + save_dir, '%s.npy' % + (os.path.basename(p).rsplit('.', 1)[0])) + np.save(save_path, inputs[i].numpy()) + + if len(in_audios) == 2: + # compute the score + score = self.compute_cos_similarity(inputs[0], inputs[1]) + score = round(score, 5) + if score >= self.thr: + ans = 'yes' + else: + ans = 'no' + output = {OutputKeys.SCORE: score, OutputKeys.TEXT: ans} + else: + output = {OutputKeys.TEXT: 'No similarity score output'} - def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - score = self.compute_cos_similarity(inputs['emb1'], inputs['emb2']) - score = round(score, 5) - if score >= self.thr: - ans = 'yes' else: - ans = 'no' + output = {OutputKeys.TEXT: 'No similarity score output'} - return {OutputKeys.SCORE: score, OutputKeys.TEXT: ans} + return output - def preprocess(self, inputs: List[str], - **preprocess_params) -> Dict[str, Any]: - if len(inputs) != 2: - raise ValueError( - 'modelscope error: Two input audio files are required.') - output = {} + def preprocess(self, inputs: Union[np.ndarray, list], **preprocess_params): + output = [] for i in range(len(inputs)): if isinstance(inputs[i], str): file_bytes = File.read(inputs[i]) data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32') if len(data.shape) == 2: data = data[:, 0] + data = torch.from_numpy(data).unsqueeze(0) if fs != self.model_config['sample_rate']: - raise ValueError( - 'modelscope error: Only support %d sample rate files' - % self.model_cfg['sample_rate']) - output['data%d' % - (i + 1)] = torch.from_numpy(data).unsqueeze(0) + logger.warning( + 'The sample rate of audio is not %d, resample it.' + % self.model_config['sample_rate']) + data, fs = torchaudio.sox_effects.apply_effects_tensor( + data, + fs, + effects=[[ + 'rate', + str(self.model_config['sample_rate']) + ]]) + data = data.squeeze(0) + elif isinstance(inputs[i], np.ndarray): + assert len( + inputs[i].shape + ) == 1, 'modelscope error: Input array should be [N, T]' + data = inputs[i] + if data.dtype in ['int16', 'int32', 'int64']: + data = (data / (1 << 15)).astype('float32') + data = torch.from_numpy(data) else: raise ValueError( - 'modelscope error: The input type is temporarily restricted to audio file address' + 'modelscope error: The input type is restricted to audio address and nump array.' % i) + output.append(data) + try: + output = torch.stack(output) + except RuntimeError: + pass return output - def compute_cos_similarity(self, emb1: torch.Tensor, - emb2: torch.Tensor) -> float: + def compute_cos_similarity(self, emb1: Union[np.ndarray, torch.Tensor], + emb2: Union[np.ndarray, torch.Tensor]) -> float: + if isinstance(emb1, np.ndarray): + emb1 = torch.from_numpy(emb1) + if isinstance(emb2, np.ndarray): + emb2 = torch.from_numpy(emb2) + if len(emb1.shape): + emb1 = emb1.unsqueeze(0) + if len(emb2.shape): + emb2 = emb2.unsqueeze(0) assert len(emb1.shape) == 2 and len(emb2.shape) == 2 cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6) cosine = cos(emb1, emb2) diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index ac244421a..48d328d97 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -54,7 +54,8 @@ def initiate_single_model(self, model): model, device=self.device_name, model_prefetched=True, - invoked_by=Invoke.PIPELINE) if is_model(model) else model + invoked_by=Invoke.PIPELINE, + device_map=self.device_map) if is_model(model) else model else: return model @@ -70,6 +71,7 @@ def __init__(self, preprocessor: Union[Preprocessor, List[Preprocessor]] = None, device: str = 'gpu', auto_collate=True, + device_map=None, **kwargs): """ Base class for pipeline. @@ -87,6 +89,9 @@ def __init__(self, compile_options (dict, optional): The compile options if compile=True, default None to use the default params of 'TorchModel.compile'. """ + if device_map is not None: + assert device == 'gpu', '`device` and `device_map` cannot be input at the same time!' + self.device_map = device_map verify_device(device) self.device_name = device @@ -133,13 +138,14 @@ def prepare_model(self): self._model_prepare_lock.acquire(timeout=600) def _prepare_single(model): - if isinstance(model, torch.nn.Module): + if not isinstance(model, torch.nn.Module) and hasattr( + model, 'model'): + model = model.model + if not isinstance(model, torch.nn.Module): + return + model.eval() + if self.device_map is None: model.to(self.device) - model.eval() - elif hasattr(model, 'model') and isinstance( - model.model, torch.nn.Module): - model.model.to(self.device) - model.model.eval() if not self._model_prepare: # prepare model for pytorch diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index e9d7a7858..97905fa7f 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -106,7 +106,7 @@ from .controllable_image_generation_pipeline import ControllableImageGenerationPipeline from .image_bts_depth_estimation_pipeline import ImageBTSDepthEstimationPipeline from .pedestrian_attribute_recognition_pipeline import PedestrainAttributeRecognitionPipeline - + from .image_panoptic_segmentation_pipeline import ImagePanopticSegmentationPipeline else: _import_structure = { 'action_recognition_pipeline': ['ActionRecognitionPipeline'], @@ -261,7 +261,10 @@ ], 'pedestrian_attribute_recognition_pipeline': [ 'PedestrainAttributeRecognitionPipeline' - ] + ], + 'image_panoptic_segmentation_pipeline': [ + 'ImagePanopticSegmentationPipeline', + ], } import sys diff --git a/modelscope/pipelines/cv/face_reconstruction_pipeline.py b/modelscope/pipelines/cv/face_reconstruction_pipeline.py index b9a8e3204..28f2227da 100644 --- a/modelscope/pipelines/cv/face_reconstruction_pipeline.py +++ b/modelscope/pipelines/cv/face_reconstruction_pipeline.py @@ -103,7 +103,7 @@ def __init__(self, model: str, device: str): os.path.join(model_root, 'face_alignment', 'depth-6c4283c0e0.zip'), save_ckpt_dir) self.lm_sess = face_alignment.FaceAlignment( - face_alignment.LandmarksType._3D, flip_input=False) + face_alignment.LandmarksType.THREE_D, flip_input=False) config = tf.ConfigProto(allow_soft_placement=True) config.gpu_options.per_process_gpu_memory_fraction = 0.2 diff --git a/modelscope/pipelines/cv/image_panoptic_segmentation_pipeline.py b/modelscope/pipelines/cv/image_panoptic_segmentation_pipeline.py new file mode 100644 index 000000000..a566fe8cf --- /dev/null +++ b/modelscope/pipelines/cv/image_panoptic_segmentation_pipeline.py @@ -0,0 +1,92 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import load_image +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_segmentation, + module_name=Pipelines.image_panoptic_segmentation) +class ImagePanopticSegmentationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a image panoptic segmentation pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + logger.info('panoptic segmentation model, pipeline init') + + def preprocess(self, input: Input) -> Dict[str, Any]: + from mmdet.datasets.pipelines import Compose + from mmcv.parallel import collate, scatter + from mmdet.datasets import replace_ImageToTensor + cfg = self.model.cfg + # build the data pipeline + if isinstance(input, str): + cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' + img = np.array(load_image(input)) + img = img[:, :, ::-1] # convert to bgr + elif isinstance(input, PIL.Image.Image): + cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' + img = np.array(input.convert('RGB')) + elif isinstance(input, np.ndarray): + cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' + if len(input.shape) == 2: + img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) + else: + img = input + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + # collect data + data = dict(img=img) + cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline) + test_pipeline = Compose(cfg.data.test.pipeline) + data = test_pipeline(data) + # copy from mmdet_model collect data + data = collate([data], samples_per_gpu=1) + data['img_metas'] = [ + img_metas.data[0] for img_metas in data['img_metas'] + ] + data['img'] = [img.data[0] for img in data['img']] + if next(self.model.parameters()).is_cuda: + # scatter to specified GPU + data = scatter(data, [next(self.model.parameters()).device])[0] + return data + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + results = self.model.inference(input) + return results + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + # bz=1, tcguo + pan_results = inputs[0]['pan_results'] + INSTANCE_OFFSET = 1000 + ids = np.unique(pan_results)[::-1] + legal_indices = ids != self.model.num_classes # for VOID label + ids = ids[legal_indices] + labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64) + segms = (pan_results[None] == ids[:, None, None]) + masks = [it.astype(np.int) for it in segms] + labels_txt = np.array(self.model.CLASSES)[labels].tolist() + outputs = { + OutputKeys.MASKS: masks, + OutputKeys.LABELS: labels_txt, + OutputKeys.SCORES: [0.999 for _ in range(len(labels_txt))] + } + return outputs diff --git a/modelscope/pipelines/multi_modal/diffusers_wrapped/stable_diffusion/stable_diffusion_pipeline.py b/modelscope/pipelines/multi_modal/diffusers_wrapped/stable_diffusion/stable_diffusion_pipeline.py index c4f8bda81..f09d459d2 100644 --- a/modelscope/pipelines/multi_modal/diffusers_wrapped/stable_diffusion/stable_diffusion_pipeline.py +++ b/modelscope/pipelines/multi_modal/diffusers_wrapped/stable_diffusion/stable_diffusion_pipeline.py @@ -34,8 +34,9 @@ def __init__(self, model: str, lora_dir: str = None, **kwargs): self.device = 'cuda' if torch.cuda.is_available() else 'cpu' # load pipeline + torch_type = torch.float16 if self.device == 'cuda' else torch.float32 self.pipeline = DiffuserStableDiffusionPipeline.from_pretrained( - model, torch_dtype=torch.float16) + model, torch_dtype=torch_type) self.pipeline = self.pipeline.to(self.device) # load lora moudle to unet if lora_dir is not None: diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py index cfc3645d7..a0e8a0eeb 100644 --- a/modelscope/pipelines/nlp/text_generation_pipeline.py +++ b/modelscope/pipelines/nlp/text_generation_pipeline.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright (c) 2022 Zhipu.AI import os from typing import Any, Dict, Optional, Union @@ -17,7 +18,10 @@ from modelscope.utils.hub import Config, read_config from modelscope.utils.streaming_output import PipelineStreamingOutputMixin -__all__ = ['TextGenerationPipeline', 'TextGenerationT5Pipeline'] +__all__ = [ + 'TextGenerationPipeline', 'TextGenerationT5Pipeline', + 'ChatGLM6bTextGenerationPipeline', 'ChatGLM6bV2TextGenerationPipeline' +] @PIPELINES.register_module( @@ -177,3 +181,71 @@ def forward(self, inputs: Dict[str, Any], with torch.no_grad(): return self.model.generate(**inputs, **forward_params) + + +@PIPELINES.register_module( + group_key=Tasks.chat, module_name='chatglm6b-text-generation') +class ChatGLM6bTextGenerationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + quantization_bit=None, + use_bf16=False, + **kwargs): + from modelscope.models.nlp.chatglm.text_generation import ChatGLMForConditionalGeneration + model = ChatGLMForConditionalGeneration(model) if isinstance( + model, str) else model + if quantization_bit is not None: + model = model.quantize(quantization_bit) + if use_bf16: + model = model.bfloat16() + self.model = model + self.model.eval() + + super().__init__(model=model, **kwargs) + + def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]: + return inputs + + # define the forward pass + def forward(self, inputs: Dict, **forward_params) -> Dict[str, Any]: + return self.model.chat(inputs) + + # format the outputs from pipeline + def postprocess(self, input, **kwargs) -> Dict[str, Any]: + return input + + +@PIPELINES.register_module( + group_key=Tasks.chat, module_name='chatglm2_6b-text-generation') +class ChatGLM6bV2TextGenerationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + quantization_bit=None, + use_bf16=False, + **kwargs): + from modelscope.models.nlp import ChatGLM2ForConditionalGeneration, ChatGLM2Tokenizer + model = ChatGLM2ForConditionalGeneration(model) if isinstance( + model, str) else model + if quantization_bit is not None: + model = model.quantize(quantization_bit) + if use_bf16: + model = model.bfloat16() + self.model = model + self.model.eval() + self.tokenizer = ChatGLM2Tokenizer.from_pretrained( + self.model.model_dir) + + super().__init__(model=model, **kwargs) + + def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]: + return inputs + + # define the forward pass + def forward(self, inputs: Dict, **forward_params) -> Dict[str, Any]: + return self.model.chat(self.tokenizer, inputs['text']) + + # format the outputs from pipeline + def postprocess(self, input, **kwargs) -> Dict[str, Any]: + return input diff --git a/modelscope/preprocessors/nlp/text_generation_preprocessor.py b/modelscope/preprocessors/nlp/text_generation_preprocessor.py index 734ddbc29..418f8106b 100644 --- a/modelscope/preprocessors/nlp/text_generation_preprocessor.py +++ b/modelscope/preprocessors/nlp/text_generation_preprocessor.py @@ -139,8 +139,15 @@ def __init__(self, model_type = None if model_dir is not None: model_type = get_model_type(model_dir) - self.nlp_tokenizer = NLPTokenizerForRoberta( - model_dir, model_type, use_fast=use_fast, tokenize_kwargs=kwargs) + if tokenizer is not None: + self.nlp_tokenizer = NLPTokenizer(tokenize_kwargs=kwargs) + self.nlp_tokenizer._tokenizer = tokenizer + else: + self.nlp_tokenizer = NLPTokenizerForRoberta( + model_dir, + model_type, + use_fast=use_fast, + tokenize_kwargs=kwargs) def decode(self, tokens, **kwargs): """Decode the tokens to real text. diff --git a/modelscope/swift/__init__.py b/modelscope/swift/__init__.py new file mode 100644 index 000000000..bd8ea75eb --- /dev/null +++ b/modelscope/swift/__init__.py @@ -0,0 +1,38 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .optimizers.child_tuning_adamw_optimizer import calculate_fisher, ChildTuningAdamW + from .adapter import Adapter, AdapterConfig, AdapterModule + from .lora import LoRA, LoRAConfig, Linear, MergedLinear, Embedding, Conv2d + from .prompt import Prompt, PromptConfig, PromptModule + from .control_sd_lora import ControlLoRACrossAttnProcessor, ControlLoRACrossAttnProcessorV2, ControlLoRATuner + from .base import SwiftConfig, Swift +else: + _import_structure = { + 'optimizers.child_tuning_adamw_optimizer': + ['calculate_fisher', 'ChildTuningAdamW'], + 'adapter': ['Adapter', 'AdapterConfig', 'AdapterModule'], + 'lora': [ + 'LoRA', 'LoRAConfig', 'Linear', 'MergedLinear', 'Embedding', + 'Conv2d' + ], + 'prompt': ['Prompt', 'PromptConfig', 'PromptModule'], + 'control_sd_lora': [ + 'ControlLoRACrossAttnProcessor', 'ControlLoRACrossAttnProcessorV2', + 'ControlLoRATuner' + ], + 'base': ['SwiftConfig', 'Swift'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/swift/adapter.py b/modelscope/swift/adapter.py new file mode 100644 index 000000000..d7366119a --- /dev/null +++ b/modelscope/swift/adapter.py @@ -0,0 +1,195 @@ +import inspect +import os +import re +import types +from dataclasses import dataclass, field +from typing import Union + +import torch +from torch import nn + +from modelscope import snapshot_download +from modelscope.utils.constant import ModelFile +from .base import SwiftConfig + + +@dataclass +class AdapterConfig(SwiftConfig): + """ + The configuration class for the adapter module. + + Adapters project input tokens by an MLP layer. + 'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019) + See http://arxiv.org/abs/1902.00751 + + Args: + dim: The dimension of the hidden states + module_name: The feedforward module to be replaced, in regex format + hidden_pos: The position of the hidden state to passed into the adapter, can be int (args) or str (kwargs) + method_name: The method to be replaced, default to replace the forward method + adapter_length: The length of the adapter length (intermediate length) + act_layer: The activation layer of the adapter + only_adapter_trainable: Whether to train only adapters + pretrained_weights: The pretrained adapter weights. + Can be a local dir, local file, or a model id from modelscope + """ + + dim: int = field(metadata={'help': 'The dimension of the hidden states'}) + + module_name: str = field( + metadata={ + 'help': 'The feedforward module to be replaced, in regex format' + }) + + hidden_pos: Union[str, int] = field( + metadata={ + 'help': + 'The position of the hidden state to passed into the adapter, can be int (args) or str (kwargs)' + }) + + method_name: str = field( + default='forward', + metadata={ + 'help': + 'The method to be replaced, default to replace the forward method' + }) + + adapter_length: int = field( + default=128, + metadata={ + 'help': 'The length of the adapter length (intermediate length)' + }) + + act_layer: nn.Module = field( + default=nn.GELU, + metadata={'help': 'The activation layer of the adapter'}) + + only_adapter_trainable: bool = field( + default=True, metadata={'help': 'Whether to train only adapters'}) + + pretrained_weights: str = field( + default=None, + metadata={ + 'help': + 'The pretrained adapter weights. Can be a local dir, local file, or a model id from modelscope' + }) + + +class Adapter: + + @staticmethod + def prepare_model(model: nn.Module, config: AdapterConfig): + module_keys = [key for key, _ in model.named_modules()] + + for module_key in module_keys: + if re.fullmatch(config.module_name, module_key): # noqa + module = model.get_submodule(module_key) + + def _forward(self, *args, **kwargs): + args = self.forward_origin(*args, **kwargs) + if isinstance(args, (tuple, list, dict)): + if isinstance(config.hidden_pos, int): + return args[0:config.hidden_pos] + args[ + config.hidden_pos] + getattr(self, 'adapter')(args[config.hidden_pos]) \ + + args[config.hidden_pos + 1:] # noqa + else: + kwargs[config.hidden_pos] = args[ + config.hidden_pos] + getattr(self, 'adapter')( + args[config.hidden_pos]) + elif isinstance(args, torch.Tensor): + args = getattr(self, 'adapter')(args) + return args + + def _feed_forward_chunk(self, attention_output): + return _forward(self, attention_output) + + module.forward_origin = getattr(module, config.method_name) + num_args_in_forward_chunk_fn = len( + inspect.signature(module.forward_origin).parameters) + if config.method_name == 'feed_forward_chunk' and num_args_in_forward_chunk_fn == 1: + setattr(module, config.method_name, + types.MethodType(_feed_forward_chunk, module)) + else: + setattr(module, config.method_name, + types.MethodType(_forward, module)) + adapter_module = AdapterModule(config.dim, + config.adapter_length, + config.act_layer) + setattr(module, 'adapter', adapter_module) + + if config.only_adapter_trainable: + for n, p in model.named_parameters(): + if 'adapter' not in n: + p.requires_grad = False + + def state_dict_hook(module, destination, prefix, local_metadata): + return { + key: value + for key, value in destination.items() if 'adapter' in key + } + + model.state_dict_hook_handle = model._register_state_dict_hook( + state_dict_hook) + + def load_state_dict(self, state_dict, strict=True): + return self.load_state_dict_origin(state_dict, False) + + model.load_state_dict_origin = model.load_state_dict + model.load_state_dict = types.MethodType(load_state_dict, model) + + if config.pretrained_weights is not None: + if not os.path.exists(config.pretrained_weights): + model_dir = snapshot_download(config.pretrained_weights) + pretrained_weights = os.path.join( + model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + elif os.path.isfile(config.pretrained_weights): + pretrained_weights = config.pretrained_weights + else: + pretrained_weights = os.path.join( + config.pretrained_weights, ModelFile.TORCH_MODEL_BIN_FILE) + model.load_state_dict(torch.load(pretrained_weights)) + return model + + +class AdapterModule(nn.Module): + """The implementation of adapter tuning method. + + Adapters project input tokens by an MLP layer. + 'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019) + See http://arxiv.org/abs/1902.00751 + + Attributes: + dim: An integer indicating the embedding dimension. + adapter_length: An integer indicating the length of adapter tuning. + """ + + def __init__( + self, + dim, + adapter_length=None, + act_layer=nn.GELU, + ): + super(AdapterModule, self).__init__() + self.dim = dim + self.adapter_length = adapter_length + # self.adapter_type = adapter_type + self.ln1 = nn.Linear(dim, adapter_length) + self.activate = act_layer() + self.ln2 = nn.Linear(adapter_length, dim) + self.init_weights() + + def init_weights(self): + + def _init_weights(m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + self.apply(_init_weights) + + def forward(self, x, identity=None): + out = self.ln2(self.activate(self.ln1(x))) + if identity is None: + identity = x + out = identity + out + return out diff --git a/modelscope/swift/base.py b/modelscope/swift/base.py new file mode 100644 index 000000000..441521ca2 --- /dev/null +++ b/modelscope/swift/base.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass + + +@dataclass +class SwiftConfig: + pass + + +class Swift: + + @staticmethod + def prepare_model(model, config: SwiftConfig): + """Prepare the module and returns the new module. + + Args: + model: The model to tune. + config: The config of the tuner. + + Returns: + The tuned model. + """ + from .lora import LoRA, LoRAConfig + from .adapter import Adapter, AdapterConfig + from .prompt import Prompt, PromptConfig + if isinstance(config, LoRAConfig): + return LoRA.prepare_model(model, config) + elif isinstance(config, AdapterConfig): + return Adapter.prepare_model(model, config) + elif isinstance(config, PromptConfig): + return Prompt.prepare_model(model, config) + return None diff --git a/modelscope/tuners/control_sd_lora.py b/modelscope/swift/control_sd_lora.py similarity index 100% rename from modelscope/tuners/control_sd_lora.py rename to modelscope/swift/control_sd_lora.py diff --git a/modelscope/tuners/lora.py b/modelscope/swift/lora.py similarity index 76% rename from modelscope/tuners/lora.py rename to modelscope/swift/lora.py index ba1e92e12..3c0be6ba9 100644 --- a/modelscope/tuners/lora.py +++ b/modelscope/swift/lora.py @@ -4,93 +4,148 @@ import logging import math import os.path +import re import types +from dataclasses import dataclass, field from typing import Dict, List import torch import torch.nn as nn import torch.nn.functional as F +from modelscope import snapshot_download +from modelscope.utils.constant import ModelFile +from .base import SwiftConfig + logger = logging.getLogger(__name__) -class LoRATuner: +@dataclass +class LoRAConfig(SwiftConfig): + """ + The configuration class for the loRA module. + + Args: + rank: The rank of the LoRA module + replace_modules: The modules to be replaced by LoRA, can be the end of the module name or a regex string + lora_alpha: The factor to add the lora weights + lora_dropout: The dropout rate of the lora module + merge_weights: Whether to merge weights when validating + use_merged_linear: Whether to replace with merged linear layer + enable_lora: The modules need to be turned on when using the merged linear layer + fan_in_fan_out: Set this to True if the layer to replace stores weight like (fan_in, fan_out) + bias: Bias type. Values ca be "none", "all" or "lora_only" + only_lora_trainable: Whether to train only lora + pretrained_weights: The pretrained lora weights. + Can be a local dir, local file, or a model id from modelscope + """ + + rank: int = field( + default=6, metadata={'help': 'The rank of the LoRA module'}) + replace_modules: List = field( + default=None, + metadata={ + 'help': + 'The modules to be replaced by LoRA, can be the end of the module name or a regex string' + }) + lora_alpha: float = field( + default=1., metadata={'help': 'The factor to add the lora weights'}) + lora_dropout: float = field( + default=0., metadata={'help': 'The dropout rate of the lora module'}) + merge_weights: bool = field( + default=True, + metadata={'help': 'Whether to merge weights when validating'}) + use_merged_linear: bool = field( + default=False, + metadata={'help': 'Whether to replace with merged linear layer'}) + enable_lora: List = field( + default=None, + metadata={ + 'help': + 'The modules need to be turned on when using the merged linear layer' + }) + fan_in_fan_out: bool = field( + default=False, + metadata={ + 'help': + 'Set this to True if the layer to replace stores weight like (fan_in, fan_out)' + }) + bias: str = field( + default='none', + metadata={ + 'help': 'Bias type. Values ca be "none", "all" or "lora_only"' + }) + only_lora_trainable: bool = field( + default=True, metadata={'help': 'Whether to train only lora'}) + pretrained_weights: str = field( + default=None, + metadata={ + 'help': + 'The pretrained lora weights. Can be a local dir, local file, or a model id from modelscope' + }) + + +class LoRA: @staticmethod - def tune(model: nn.Module, - rank=6, - replace_modules=None, - lora_alpha=1., - lora_dropout=0., - merge_weights=True, - fan_in_fan_out=False, - bias='none', - pretrained_tuner=None): - """Tune a model with lora. + def prepare_model(model: nn.Module, config: LoRAConfig): + """Tune a model with LoRA. Args: - model: The torch.nn.Module containing the target module to be patched. - rank: The lora rank. - replace_modules: The module names to be replaced, the replacing strategy is `end with`. - lora_alpha: The alpha value for lora module. - lora_dropout: The dropout value for lora module. - merge_weights: If merge_weights set to True, when the module turns to `eval`, the lora weights - will be added into the origin weight to reduce calculation. - fan_in_fan_out: Set this to True if the layer to replace stores weight like (fan_in, fan_out). - bias: The grad strategy for bias, can be `none`, 'all' or 'lora_only'. - pretrained_tuner: The pretrained file of lora. + config: The LoRAConfig instance. Returns: The lora modules """ - modules = LoRATuner._dynamic_patch_lora( + LoRA._dynamic_patch_lora( model, - replace_modules=replace_modules, - r=rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - merge_weights=merge_weights, - fan_in_fan_out=fan_in_fan_out) - - mark_only_lora_as_trainable(model, bias) + replace_modules=config.replace_modules, + r=config.rank, + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + merge_weights=config.merge_weights, + use_merged_linear=config.use_merged_linear, + enable_lora=config.enable_lora, + fan_in_fan_out=config.fan_in_fan_out) + + if config.only_lora_trainable: + mark_only_lora_as_trainable(model, config.bias) def state_dict_hook(module, destination, prefix, local_metadata): - return lora_state_dict(destination, bias) + return lora_state_dict(destination, config.bias) model.state_dict_hook_handle = model._register_state_dict_hook( state_dict_hook) - def warning_hook(module, incompatible_keys): - logger.info( - f'The {module.__class__.__name__} module has unmatched keys: {incompatible_keys},' - f'this is converted to a notice with respect to LoRA') - for ik in incompatible_keys: - ik.clear() - - if hasattr(model, 'register_load_state_dict_post_hook'): - model.load_state_dict_hook_handle = model.register_load_state_dict_post_hook( - warning_hook) - else: + def load_state_dict(self, state_dict, strict=True): + return self.load_state_dict_origin(state_dict, False) - def load_state_dict(self, state_dict, strict=True): - return self.load_state_dict_origin(state_dict, False) + model.load_state_dict_origin = model.load_state_dict + model.load_state_dict = types.MethodType(load_state_dict, model) - model.load_state_dict_origin = model.load_state_dict - model.load_state_dict = types.MethodType(load_state_dict, model) + if config.pretrained_weights is not None: + if not os.path.exists(config.pretrained_weights): + model_dir = snapshot_download(config.pretrained_weights) + pretrained_weights = os.path.join( + model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + elif os.path.isfile(config.pretrained_weights): + pretrained_weights = config.pretrained_weights + else: + pretrained_weights = os.path.join( + config.pretrained_weights, ModelFile.TORCH_MODEL_BIN_FILE) + model.load_state_dict(torch.load(pretrained_weights)) - if pretrained_tuner is not None and os.path.isfile(pretrained_tuner): - logger.info(f'Loading LoRA weights from file: {pretrained_tuner}') - model.load_state_dict(torch.load(pretrained_tuner)) - - return modules + return model @staticmethod - def _dynamic_patch_lora(model, replace_modules, **kwargs): + def _dynamic_patch_lora(model, replace_modules, use_merged_linear, + **kwargs): """Dynamic patch lora to model Args: model: The torch.nn.Module containing the target module to be patched. replace_modules: The module names to be replaced, the replacing strategy is `end with`. + use_merged_linear: Whether to replace with merged linear layer **kwargs: The arguments passed from `tune` which are needed by lora. Returns: @@ -103,8 +158,13 @@ def _dynamic_patch_lora(model, replace_modules, **kwargs): replace_modules = [replace_modules] for module_key in module_keys: - if any([module_key.endswith(name) - for name in replace_modules]): # noqa + if isinstance(replace_modules, str): + target_module_found = re.fullmatch(replace_modules, module_key) + else: + target_module_found = any( + module_key.endswith(target_key) + for target_key in replace_modules) + if target_module_found: # noqa parts = module_key.split('.') module = model.get_submodule('.'.join(parts[:-1])) sub_module = model.get_submodule(module_key) @@ -112,11 +172,19 @@ def _dynamic_patch_lora(model, replace_modules, **kwargs): lora_module = None if isinstance(sub_module, torch.nn.Linear): - lora_module = Linear( - sub_module.in_features, - sub_module.out_features, - bias=sub_module.bias is not None, - **kwargs) + if use_merged_linear: + lora_module = MergedLinear( + sub_module.in_features, + sub_module.out_features, + bias=sub_module.bias is not None, + **kwargs) + else: + kwargs.pop('enable_lora', None) + lora_module = Linear( + sub_module.in_features, + sub_module.out_features, + bias=sub_module.bias is not None, + **kwargs) elif isinstance(sub_module, torch.nn.Conv2d): kwargs.pop('fan_in_fan_out', None) lora_module = Conv2d( @@ -140,9 +208,13 @@ def _dynamic_patch_lora(model, replace_modules, **kwargs): return modules @staticmethod - def unpatch_lora(model, replace_modules): + def unpatch_lora(model, config: LoRAConfig): """Unpatch lora modules and merge the weights to original modules. + LoRA constructs an additional layer with low-rank decomposition matrices of the weights in the network. + 'LoRA: Low-Rank Adaptation of Large Language Models' by Hu et al.(2021) + See https://arxiv.org/abs/2106.09685 + Args: model: The model called with `tune` function. replace_modules: The module names to be replaced, the replacing strategy is `end with`. @@ -152,13 +224,17 @@ def unpatch_lora(model, replace_modules): """ modules = [] module_keys = [key for key, _ in model.named_modules()] - assert isinstance(replace_modules, (str, list)) - if isinstance(replace_modules, str): - replace_modules = [replace_modules] + assert isinstance(config.replace_modules, (str, list)) + replace_modules = config.replace_modules for module_key in module_keys: - if any([module_key.endswith(name) - for name in replace_modules]): # noqa + if isinstance(replace_modules, str): + target_module_found = re.fullmatch(replace_modules, module_key) + else: + target_module_found = any( + module_key.endswith(target_key) + for target_key in replace_modules) + if target_module_found: # noqa parts = module_key.split('.') module = model.get_submodule('.'.join(parts[:-1])) sub_module = model.get_submodule(module_key) diff --git a/modelscope/swift/optimizers/__init__.py b/modelscope/swift/optimizers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py b/modelscope/swift/optimizers/child_tuning_adamw_optimizer.py similarity index 97% rename from modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py rename to modelscope/swift/optimizers/child_tuning_adamw_optimizer.py index 742158016..02b459fa3 100644 --- a/modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py +++ b/modelscope/swift/optimizers/child_tuning_adamw_optimizer.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -import types from typing import Callable, Iterable, Tuple import numpy as np @@ -22,7 +21,6 @@ from torch.optim import Optimizer from modelscope.utils.logger import get_logger -from .builder import OPTIMIZERS, default_group logger = get_logger() @@ -72,8 +70,6 @@ def calculate_fisher(model: torch.nn.Module, return gradient_mask -@OPTIMIZERS.register_module( - group_key=default_group, module_name='ChildTuningAdamW') class ChildTuningAdamW(Optimizer): def __init__(self, diff --git a/modelscope/swift/prompt.py b/modelscope/swift/prompt.py new file mode 100644 index 000000000..715f69427 --- /dev/null +++ b/modelscope/swift/prompt.py @@ -0,0 +1,214 @@ +import os +import re +import types +from dataclasses import dataclass, field +from typing import Union + +import torch +from torch import nn + +from modelscope import snapshot_download +from modelscope.utils.constant import ModelFile +from .base import SwiftConfig + + +@dataclass +class PromptConfig(SwiftConfig): + """ + The configuration class for the prompt module. + + Visual prompt tuning (VPT) is proposed to initialize tunable prompt tokens + and prepend to the original tokens in the first layer or multiple layers. + 'Visual Prompt Tuning' by Jia et al.(2022) + See https://arxiv.org/abs/2203.12119 + + Here we apply the VPT to other fields. + + Args: + dim: The dimension of the hidden states + module_layer_name: The layer module to be replaced, in regex format + embedding_pos: The position of the embedding tensor + attention_mask_pos: The position of the attention mask + attention_mask_value: The value to pad to the attention mask + prompt_length: The length of the prompt tokens + only_prompt_trainable: Whether to train only prompt + attach_front: When set to True, prompt is attached in front of the embedding + pretrained_weights: The pretrained prompt weights. Can be a local dir, local file, + or a model id from modelscope + """ + + dim: int = field(metadata={'help': 'The dimension of the hidden states'}) + + module_layer_name: str = field( + metadata={'help': 'The layer module to be replaced, in regex format'}) + + embedding_pos: Union[str, int] = field( + metadata={'help': 'The position of the embedding tensor'}) + + attention_mask_pos: Union[str, int] = field( + default=None, metadata={'help': 'The position of the attention mask'}) + + attention_mask_value: Union[float, int, bool] = field( + default=0., + metadata={'help': 'The value to pad to the attention mask'}) + + prompt_length: int = field( + default=16, metadata={'help': 'The length of the prompt tokens'}) + + only_prompt_trainable: bool = field( + default=True, metadata={'help': 'Whether to train only prompt'}) + + attach_front: bool = field( + default=True, + metadata={ + 'help': + 'When set to True, prompt is attached in front of the embedding' + }) + + pretrained_weights: str = field( + default=None, + metadata={ + 'help': + 'The pretrained prompt weights. Can be a local dir, local file, or a model id from modelscope' + }) + + +class Prompt: + + @staticmethod + def prepare_model(model: nn.Module, config: PromptConfig): + module_keys = [key for key, _ in model.named_modules()] + for module_key in module_keys: + if re.fullmatch(config.module_layer_name, module_key): # noqa + module = model.get_submodule(module_key) + + def _forward(self, *args, **kwargs): + if isinstance(config.embedding_pos, int): + input_embedding = args[config.embedding_pos] + else: + input_embedding = kwargs[config.embedding_pos] + + input_embedding = getattr( + self, 'prompt').forward(input_embedding) + if isinstance(config.embedding_pos, int): + args = type(args)( + args[0:config.embedding_pos] + (input_embedding, ) + + args[config.embedding_pos + 1:]) + else: + kwargs[config.embedding_pos] = input_embedding + + if config.attention_mask_pos: + attention_mask = None + if isinstance(config.attention_mask_pos, int): + attention_mask = args[config.attention_mask_pos] + elif isinstance(config.attention_mask_pos, str): + attention_mask = kwargs[config.attention_mask_pos] + + if attention_mask is not None: + attention_mask = getattr( + self, + 'prompt').patch_attention_mask(attention_mask) + if isinstance(config.attention_mask_pos, int): + args = type(args)( + args[0:config.attention_mask_pos] + + (attention_mask, ) + + args[config.attention_mask_pos + 1:]) + else: + kwargs[config.attention_mask_pos] = attention_mask + + return self.forward_origin(*args, **kwargs) + + module.forward_origin = module.forward + module.forward = types.MethodType(_forward, module) + prompt_module = PromptModule(config.dim, + int(module_key.rsplit('.')[-1]), + config.prompt_length, + config.attention_mask_value, + config.attach_front) + setattr(module, 'prompt', prompt_module) + + if config.only_prompt_trainable: + for n, p in model.named_parameters(): + if 'prompt' not in n: + p.requires_grad = False + + def state_dict_hook(module, destination, prefix, local_metadata): + return { + key: value + for key, value in destination.items() if 'prompt' in key + } + + model.state_dict_hook_handle = model._register_state_dict_hook( + state_dict_hook) + + def load_state_dict(self, state_dict, strict=True): + return self.load_state_dict_origin(state_dict, False) + + model.load_state_dict_origin = model.load_state_dict + model.load_state_dict = types.MethodType(load_state_dict, model) + + if config.pretrained_weights is not None: + if not os.path.exists(config.pretrained_weights): + model_dir = snapshot_download(config.pretrained_weights) + pretrained_weights = os.path.join( + model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + elif os.path.isfile(config.pretrained_weights): + pretrained_weights = config.pretrained_weights + else: + pretrained_weights = os.path.join( + config.pretrained_weights, ModelFile.TORCH_MODEL_BIN_FILE) + model.load_state_dict(torch.load(pretrained_weights)) + return model + + +class PromptModule(nn.Module): + """The implementation of vision prompt tuning method. + + Visual prompt tuning (VPT) is proposed to initialize tunable prompt tokens + and prepend to the original tokens in the first layer or multiple layers. + 'Visual Prompt Tuning' by Jia et al.(2022) + See https://arxiv.org/abs/2203.12119 + + Attributes: + dim: An integer indicating the embedding dimension. + layer_num: An integer indicating number of layers. + prompt_length: An integer indicating the length of vision prompt tuning. + """ + + def __init__(self, + dim, + layer_num, + prompt_length=None, + mask_values=0., + attach_front=True): + super(PromptModule, self).__init__() + self.dim = dim + self.layer_num = layer_num + self.prompt_length = prompt_length + self.mask_values = mask_values + self.attach_front = attach_front + + self.prompt_token = nn.Parameter(torch.zeros(1, prompt_length, dim)) + nn.init.xavier_uniform_(self.prompt_token) + + def forward(self, x): + prompt_token = self.prompt_token.expand(x.shape[0], -1, -1) + + if self.layer_num == 0: + if self.attach_front: + x = torch.cat((prompt_token, x), dim=1) + else: + x = torch.cat((x, prompt_token), dim=1) + else: + if self.attach_front: + x = torch.cat((prompt_token, x[:, self.prompt_length:, :]), + dim=1) + else: + x = torch.cat((x[:, :-self.prompt_length, :], prompt_token), + dim=1) + return x + + def patch_attention_mask(self, m): + prefix_attention_mask = torch.full((*m.shape[:-1], self.prompt_length), + self.mask_values).to(m.device) + return torch.cat((prefix_attention_mask, m), dim=-1) diff --git a/modelscope/swift/sd_lora.py b/modelscope/swift/sd_lora.py new file mode 100644 index 000000000..feff05f4b --- /dev/null +++ b/modelscope/swift/sd_lora.py @@ -0,0 +1,218 @@ +# Copyright 2023-2024 The Alibaba Fundamental Vision Team Authors. All rights reserved. +# The implementation is adopted from HighCWu, +# made pubicly available under the Apache License 2.0 License at https://github.com/HighCWu/ControlLoRA +import os +from dataclasses import dataclass +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.cross_attention import CrossAttention, LoRALinearLayer +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.outputs import BaseOutput + + +@dataclass +class TunerOutput(BaseOutput): + lora_states: Tuple[torch.FloatTensor] + + +class LoRACrossAttnProcessor(nn.Module): + """ The implementation of lora attention module. + """ + + def __init__(self, + hidden_size, + cross_attention_dim=None, + rank=4, + post_add=False, + key_states_skipped=False, + value_states_skipped=False, + output_states_skipped=False): + """ Initialize a lora attn instance. + Args: + hidden_size (`int`): The number of channels in embedding. + cross_attention_dim (`int`, *optional*): + The number of channels in the hidden_states. If not given, defaults to `hidden_size`. + rank (`int`, *optional*, defaults to 4): The number of rank of lora. + post_add (`bool`, *optional*, defaults to False): Set to `True`, conduct weighted + adding operation after lora. + key_states_skipped (`bool`, *optional*, defaults to False): + Set to `True` for skip to perform lora on key value. + value_states_skipped (`bool`, *optional*, defaults to False): + Set to `True` for skip to perform lora on value. + output_states_skipped (`bool`, *optional*, defaults to False): + Set to `True` for skip to perform lora on output value. + """ + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + self.post_add = post_add + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + if not key_states_skipped: + self.to_k_lora = LoRALinearLayer( + hidden_size if post_add else + (cross_attention_dim or hidden_size), hidden_size, rank) + if not value_states_skipped: + self.to_v_lora = LoRALinearLayer( + hidden_size if post_add else + (cross_attention_dim or hidden_size), hidden_size, rank) + if not output_states_skipped: + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + + self.key_states_skipped: bool = key_states_skipped + self.value_states_skipped: bool = value_states_skipped + self.output_states_skipped: bool = output_states_skipped + + def skip_key_states(self, is_skipped: bool = True): + if not is_skipped: + assert hasattr(self, 'to_k_lora') + self.key_states_skipped = is_skipped + + def skip_value_states(self, is_skipped: bool = True): + if not is_skipped: + assert hasattr(self, 'to_q_lora') + self.value_states_skipped = is_skipped + + def skip_output_states(self, is_skipped: bool = True): + if not is_skipped: + assert hasattr(self, 'to_out_lora') + self.output_states_skipped = is_skipped + + def __call__(self, + attn: CrossAttention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + scale=1.0): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask( + attention_mask=attention_mask, + target_length=sequence_length, + batch_size=batch_size) + + query = attn.to_q(hidden_states) + query = query + scale * self.to_q_lora( + query if self.post_add else hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + + key = attn.to_k(encoder_hidden_states) + if not self.key_states_skipped: + key = key + scale * self.to_k_lora( + key if self.post_add else encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + if not self.value_states_skipped: + value = value + scale * self.to_v_lora( + value if self.post_add else encoder_hidden_states) + + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + out = attn.to_out[0](hidden_states) + if not self.output_states_skipped: + out = out + scale * self.to_out_lora( + out if self.post_add else hidden_states) + hidden_states = out + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class LoRATuner(ModelMixin, ConfigMixin): + + @staticmethod + def tune( + model: nn.Module, + tuner_config=None, + pretrained_tuner=None, + ): + tuner = LoRATuner.from_config(tuner_config) + if pretrained_tuner is not None and os.path.exists(pretrained_tuner): + tuner.load_state_dict( + torch.load(pretrained_tuner, map_location='cpu'), strict=True) + tune_layers_list = list( + [list(layer_list) for layer_list in tuner.lora_layers]) + assert hasattr(model, 'unet') + unet = model.unet + tuner.to(unet.device) + tune_attn_procs = tuner.set_tune_layers(unet, tune_layers_list) + unet.set_attn_processor(tune_attn_procs) + return tuner + + def set_tune_layers(self, unet, tune_layers_list): + n_ch = len(unet.config.block_out_channels) + control_ids = [i for i in range(n_ch)] + tune_attn_procs = {} + + for name in unet.attn_processors.keys(): + if name.startswith('mid_block'): + control_id = control_ids[-1] + elif name.startswith('up_blocks'): + block_id = int(name[len('up_blocks.')]) + control_id = list(reversed(control_ids))[block_id] + elif name.startswith('down_blocks'): + block_id = int(name[len('down_blocks.')]) + control_id = control_ids[block_id] + + tune_layers = tune_layers_list[control_id] + if len(tune_layers) != 0: + tune_layer = tune_layers.pop(0) + tune_attn_procs[name] = tune_layer + return tune_attn_procs + + @register_to_config + def __init__( + self, + lora_block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + lora_cross_attention_dims: Tuple[List[int]] = ([ + None, 768, None, 768, None, 768, None, 768, None, 768 + ], [None, 768, None, 768, None, 768, None, 768, None, + 768], [None, 768, None, 768, None, 768, None, 768, None, + 768], [None, 768]), + lora_rank: int = 4, + lora_post_add: bool = False, + lora_key_states_skipped: bool = False, + lora_value_states_skipped: bool = False, + lora_output_states_skipped: bool = False, + ): + super().__init__() + + lora_cls = LoRACrossAttnProcessor + + self.lora_layers = nn.ModuleList([]) + + for i, lora_cross_attention_dim in enumerate( + lora_cross_attention_dims): + self.lora_layers.append( + nn.ModuleList([ + lora_cls( + lora_block_out_channels[i], + cross_attention_dim=cross_attention_dim, + rank=lora_rank, + post_add=lora_post_add, + key_states_skipped=lora_key_states_skipped, + value_states_skipped=lora_value_states_skipped, + output_states_skipped=lora_output_states_skipped) + for cross_attention_dim in lora_cross_attention_dim + ])) + + def forward(self) -> Union[TunerOutput, Tuple]: + lora_states_list = [] + tune_layers_list = list( + [list(layer_list) for layer_list in self.lora_layers]) + for tune_list in tune_layers_list: + for tune_layer in tune_list: + lora_states_list.append(tune_layer.to_q_lora.down.weight) + return TunerOutput(lora_states=tuple(lora_states_list)) diff --git a/modelscope/trainers/hooks/checkpoint/checkpoint_hook.py b/modelscope/trainers/hooks/checkpoint/checkpoint_hook.py index b8be06823..9bea8aaf5 100644 --- a/modelscope/trainers/hooks/checkpoint/checkpoint_hook.py +++ b/modelscope/trainers/hooks/checkpoint/checkpoint_hook.py @@ -187,7 +187,8 @@ def after_run(self, trainer): strategy=self.upload_strategy, done=True) wait_for_done(self.PUSH_TO_HUB_QUEUE_NAME) - self.logger.info('Uploading models done.') + if self.push_to_hub: + self.logger.info('Uploading models done.') def _push_to_hub(self, trainer, prefix, output_dir, delete_dir=False): if self.is_model_id is None: diff --git a/modelscope/trainers/hooks/checkpoint/checkpoint_processor.py b/modelscope/trainers/hooks/checkpoint/checkpoint_processor.py index f28fc3977..4693968ae 100644 --- a/modelscope/trainers/hooks/checkpoint/checkpoint_processor.py +++ b/modelscope/trainers/hooks/checkpoint/checkpoint_processor.py @@ -26,7 +26,6 @@ def prepare_output(self, trainer, output_dir): trainer: The trainer instance. output_dir: The target folder used in inference. """ - model = trainer.unwrap_module(trainer.model) config = trainer.cfg # override pipeline by tasks name after finetune done, @@ -38,8 +37,7 @@ def prepare_output(self, trainer, output_dir): # TODO a temp fix to avoid pipeline_name and task mismatch config['pipeline'] = {'type': config['task']} - self.copy_files_and_dump_config(trainer, output_dir, config, - self._bin_file(model)) + self.copy_files_and_dump_config(trainer, output_dir, config, '*.bin') @staticmethod def copy_files_and_dump_config(trainer, output_dir, config, bin_file): diff --git a/modelscope/trainers/hooks/distributed/deepspeed_hook.py b/modelscope/trainers/hooks/distributed/deepspeed_hook.py new file mode 100644 index 000000000..28d5d79b6 --- /dev/null +++ b/modelscope/trainers/hooks/distributed/deepspeed_hook.py @@ -0,0 +1,403 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2020 The HuggingFace Team. All rights reserved. +import math +import os +import shutil +from functools import partialmethod + +import deepspeed +import torch +from deepspeed import DeepSpeedEngine +from megatron_util import mpu, print_rank_0 +from transformers.deepspeed import HfTrainerDeepSpeedConfig + +from modelscope.metainfo import Hooks +from modelscope.trainers.hooks import LoadCheckpointHook +from modelscope.trainers.hooks.builder import HOOKS +from modelscope.trainers.hooks.checkpoint.checkpoint_hook import ( + BestCkptSaverHook, CheckpointHook) +from modelscope.trainers.hooks.checkpoint.checkpoint_processor import \ + CheckpointProcessor +from modelscope.trainers.hooks.hook import Hook +from modelscope.trainers.hooks.lr_scheduler_hook import (LrSchedulerHook, + LrSchedulerProcessor) +from modelscope.trainers.hooks.optimizer.base import (OptimizerHook, + OptimizerProcessor) +from modelscope.trainers.hooks.priority import Priority +from modelscope.utils.checkpoint import save_checkpoint +from modelscope.utils.constant import DistributedParallelType +from modelscope.utils.device import create_device +from modelscope.utils.logger import get_logger +from modelscope.utils.torch_utils import (get_dist_info, get_local_rank, + init_dist) + + +class DeepSpeedConfig(HfTrainerDeepSpeedConfig): + """ + The `DeepSpeedConfig` object is meant to be created during `TrainingArguments` object creation and has the + same lifespan as the latter. + """ + + def is_auto(self, ds_key_long): + val = self.get_value(ds_key_long) + if val is None: + return False + else: + return val == 'auto' + + def trainer_config_finalize(self, args, model, num_training_steps): + """ + This stage runs after we have the model and know num_training_steps. + + Now we can complete the configuration process. + """ + # zero + + # deal with config keys that use `auto` value and rely on model's hidden_size + hidden_size_based_keys = [ + 'zero_optimization.reduce_bucket_size', + 'zero_optimization.stage3_prefetch_bucket_size', + 'zero_optimization.stage3_param_persistence_threshold', + ] + hidden_size_auto_keys = [ + x for x in hidden_size_based_keys if self.is_auto(x) + ] + + if len(hidden_size_auto_keys) > 0: + if hasattr(model.config, 'hidden_size'): + hidden_size = model.config.hidden_size + elif hasattr(model.config, 'hidden_sizes'): + # if there are many hidden sizes pick the largest one + hidden_size = max(model.config.hidden_sizes) + else: + raise ValueError( + "The model's config file has neither `hidden_size` nor `hidden_sizes` entry, " + "therefore it's not possible to automatically fill out the following `auto` entries " + f'in the DeepSpeed config file: {hidden_size_auto_keys}. You can fix that by replacing ' + '`auto` values for these keys with an integer value of your choice.' + ) + + self.fill_only('zero_optimization.reduce_bucket_size', + hidden_size * hidden_size) + if self.is_zero3(): + # automatically assign the optimal config values based on model config + self.fill_only('zero_optimization.stage3_prefetch_bucket_size', + 0.9 * hidden_size * hidden_size) + self.fill_only( + 'zero_optimization.stage3_param_persistence_threshold', + 10 * hidden_size) + + # scheduler + options = args.train.optimizer.get('options', {}) + warmup = options.get('warmup', {}) + warmup_steps = warmup.get('warmup_steps', 0) + warmup_ratio = warmup.get('warmup_ratio', 0.0) + warmup_steps = warmup_steps if warmup_steps > 0 else math.ceil( + num_training_steps * warmup_ratio) + self.fill_match('scheduler.params.total_num_steps', num_training_steps) + self.fill_match('scheduler.params.warmup_num_steps', warmup_steps) + + if len(self.mismatches) > 0: + mismatches = '\n'.join(self.mismatches) + raise ValueError( + 'Please correct the following DeepSpeed config values that mismatch TrainingArguments' + f" values:\n{mismatches}\nThe easiest method is to set these DeepSpeed config values to 'auto'." + ) + + +def deepspeed_optim_sched(trainer, hf_deepspeed_config, num_training_steps): + config = hf_deepspeed_config.config + optimizer = None + if 'optimizer' not in config: + if hf_deepspeed_config.is_offload(): + logger.info( + 'Detected ZeRO Offload and non-DeepSpeed optimizers: This combination should work as long as the' + ' custom optimizer has both CPU and GPU implementation (except LAMB)' + ) + + # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch. + # But trainer uses AdamW by default. + optimizer = trainer.optimizer + # To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer` + config['zero_allow_untested_optimizer'] = True + + lr_scheduler = None + if 'scheduler' not in config: + lr_scheduler = trainer.scheduler + + return optimizer, lr_scheduler + + +class DeepspeedProcessor(CheckpointProcessor, LrSchedulerProcessor, + OptimizerProcessor): + + _BIN_FILE_DIR = 'model' + + def rank_name(self): + # TODO + try: + tp_world_size = mpu.get_tensor_model_parallel_world_size() + if tp_world_size == 1: + return '' + mp_rank = mpu.get_tensor_model_parallel_rank() + return '_mp_rank_{:02d}'.format(mp_rank) + except (ImportError, AssertionError): + return '' + + def get_bin_filename(self, with_mpu=True): + if not with_mpu: + return 'pytorch_model.bin' + else: + mp_rank = mpu.get_tensor_model_parallel_rank() + rank = '{:02d}'.format(mp_rank) + return f'mp_rank_{rank}_model_states.pt' + + def save_checkpoints(self, + trainer, + checkpoint_path_prefix, + output_dir, + meta=None): + model = trainer.unwrap_module(trainer.model) + _train_state_file = checkpoint_path_prefix + self.rank_name( + ) + CheckpointProcessor.TRAINER_STATE_SUFFIX + # Save pth file without model state_dict + save_checkpoint( + model, _train_state_file, None, None, meta=meta, with_model=False) + + save_dir = os.path.dirname(checkpoint_path_prefix) + prefix = os.path.basename(checkpoint_path_prefix) + with_mpu = not mpu.is_unitialized() + bin_file = self.get_bin_filename(with_mpu) + src_file = os.path.join(checkpoint_path_prefix, bin_file) + if self.zero_stage == 3 or with_mpu: + trainer.model.save_checkpoint(save_dir, prefix) + else: + save_checkpoint( + model, src_file, None, None, meta=None, with_meta=False) + + if self.zero_stage == 3: + return + if with_mpu: + dest_file = os.path.join(output_dir, self._BIN_FILE_DIR, bin_file) + else: + dest_file = os.path.join(output_dir, bin_file) + if os.path.isfile(dest_file): + os.unlink(dest_file) + + try: + os.link(src_file, dest_file) + except OSError as e: + get_logger().error( + f'Link {src_file} to {dest_file} error: {e}, ' + 'changing to copy the bin file, this may case more space usage.' + ) + shutil.copyfile(src_file, dest_file) + + def remove_checkpoints(self, trainer, checkpoint_path_prefix): + _train_state_file = checkpoint_path_prefix + self.rank_name( + ) + CheckpointProcessor.TRAINER_STATE_SUFFIX + if os.path.isfile(_train_state_file): + os.remove(_train_state_file) + + shutil.rmtree(checkpoint_path_prefix, ignore_errors=True) + + def load_checkpoints(self, checkpoint_path_prefix, trainer, load_all_state, + strict): + assert os.path.isdir(checkpoint_path_prefix) + path = os.path.dirname(checkpoint_path_prefix) + tag = os.path.basename(checkpoint_path_prefix) + + meta = {} + _train_state_file = checkpoint_path_prefix + self.rank_name( + ) + CheckpointProcessor.TRAINER_STATE_SUFFIX + if os.path.isfile(_train_state_file): + meta = self.load_trainer_state(trainer, _train_state_file, + load_all_state) + + if isinstance(trainer.model, DeepSpeedEngine): + # DeepSpeedEngine is initialized + trainer.model.load_checkpoint( + path, + tag, + load_module_strict=strict, + load_module_only=not load_all_state, + ) + else: + # in eval or prediction + save_dir = checkpoint_path_prefix + bin_file = self.get_bin_filename() + model_file = os.path.join(save_dir, bin_file) + checkpoint = torch.load( + model_file, map_location=lambda storage, loc: storage) + checkpoint = checkpoint['module'] + model_dict = trainer.unwrap_module(trainer.model).state_dict() + for key in checkpoint: + if key not in model_dict.keys(): + print_rank_0('Skip key: ' + key) + else: + print_rank_0('Loading key: ' + key) + trainer.unwrap_module(trainer.model).load_state_dict( + checkpoint, strict=strict) + return meta + + def backward(self, trainer, loss_keys, cumulative_iters, grad_clip): + # assert cumulative_iters == 1, 'DeepSpeed only support cumulative_iters=1' + # The `trainer.model` here is actually a deepspeed engine object. + # backward step + for k in loss_keys: + loss = trainer.train_outputs[k] + trainer.model.backward(loss) + + # update parameters + # Optimizer step for deepspeed must be called on every step regardless of + # the value of gradient accumulation iters + trainer.model.step() + + def initialize_optimizer(self, trainer): + pass + + def step(self, trainer): + pass + + def should_save_on_rank(self, trainer): + return True + + def get_current_lr(self, trainer): + if isinstance(trainer.optimizer, torch.optim.Optimizer) or isinstance( + trainer.optimizer, deepspeed.DeepSpeedOptimizer): + lr = [group['lr'] for group in trainer.optimizer.param_groups] + elif isinstance(trainer.optimizer, dict): + lr = dict() + for name, optim in trainer.optimizer.items(): + lr[name] = [group['lr'] for group in optim.param_groups] + else: + raise RuntimeError( + 'lr is not applicable because optimizer does not exist.') + return lr + + +@HOOKS.register_module(module_name=Hooks.DeepspeedHook) +class DeepspeedHook(Hook): + PRIORITY = Priority.VERY_HIGH + + def __init__(self, + config=None, + deepspeed_activation_checkpointing=True, + save_zero_checkpoint=False, + with_mpu=True, + zero_stage=None): + self.save_zero_checkpoint = save_zero_checkpoint + self.deepspeed_activation_checkpointing = deepspeed_activation_checkpointing + self.with_mpu = with_mpu + self.deepspeed_config = config + if zero_stage is not None: + assert zero_stage in (0, 1, 2, + 3), 'zero_stage must in (0, 1, 2, 3)!' + self.zero_stage = zero_stage + + def register_processor(self, trainer): + processor = DeepspeedProcessor() + optimizer_hook = trainer.get_hook(OptimizerHook) + if len(optimizer_hook) > 0 and not isinstance( + optimizer_hook[0].processor, DeepspeedProcessor): + optimizer_hook[0].set_processor(processor) + ckpt_hook = trainer.get_hook(CheckpointHook) + if len(ckpt_hook) > 0 and not isinstance(ckpt_hook[0].processor, + DeepspeedProcessor): + ckpt_hook[0].set_processor(processor) + + best_ckpt_hook = trainer.get_hook(BestCkptSaverHook) + if len(best_ckpt_hook) > 0 and not isinstance( + best_ckpt_hook[0].processor, DeepspeedProcessor): + best_ckpt_hook[0].set_processor(processor) + + load_ckpt_hook = trainer.get_hook(LoadCheckpointHook) + if len(load_ckpt_hook) > 0 and not isinstance( + load_ckpt_hook[0].processor, DeepspeedProcessor): + load_ckpt_hook[0].set_processor(processor) + + lr_scheduler_hook = trainer.get_hook(LrSchedulerHook) + if len(lr_scheduler_hook) > 0 and not isinstance( + lr_scheduler_hook[0].processor, DeepspeedProcessor): + lr_scheduler_hook[0].set_processor(processor) + self.processor = processor + + def prepare_args(self, args): + args.per_device_train_batch_size = args.train.dataloader.get( + 'batch_size_per_gpu', 4) + args.max_grad_norm = args.train.get('clip_grad', 1.0) + args.learning_rate = args.train.optimizer.get('lr', 2e-5) + args.adam_beta1 = args.train.optimizer.get('adam_beta1', 0.9) + args.adam_beta2 = args.train.optimizer.get('adam_beta2', 0.999) + args.adam_epsilon = args.train.optimizer.get('adam_epsilon', 1e-8) + args.weight_decay = args.train.optimizer.get('weight_decay', 0.0) + args.fp16 = args.train.get('use_fp16', False) + args.fp16_full_eval = args.train.get('use_fp16', False) + args.fp16_backend = args.train.get('fp16_backend', 'amp') + args.save_on_each_node = args.train.get('save_on_each_node', False) + args.fp16_opt_level = args.train.get('fp16_opt_level', None) + args.fp16_opt_level = next((item.get('opt_level', args.fp16_opt_level) + for item in args.train.hooks + if item['type'] == 'ApexAMPOptimizerHook'), + args.fp16_opt_level) + if not args.fp16_opt_level: + args.fp16_opt_level = 'O1' + args.bf16 = args.train.get('bf16', False) + + def get_deepspeed_config(self, trainer, args, max_steps): + _, args.world_size = get_dist_info() + self.prepare_args(args) + if os.path.exists(self.deepspeed_config): + deepspeed_config = self.deepspeed_config + else: + deepspeed_config = os.path.join(trainer.model_dir, + self.deepspeed_config) + if not os.path.exists(deepspeed_config): + raise RuntimeError( + f'No such DeepSpeed json config file: {self.deepspeed_config}.' + ) + self.logger.info(f'Loading deepspeed config from {deepspeed_config}') + + ds_config = DeepSpeedConfig(deepspeed_config) + ds_config.trainer_config_process(args) + + ds_config.trainer_config_finalize(args, trainer.model, max_steps) + return ds_config + + def after_init(self, trainer): + init_dist('pytorch') + local_rank = get_local_rank() + trainer.device = create_device(f'cuda:{local_rank}') + trainer.model.to(trainer.device) + trainer.parallel_groups[DistributedParallelType.DP] = None + + def before_val(self, trainer): + pass + + def before_run(self, trainer): + if not hasattr(trainer, 'logger'): + self.logger = get_logger() + else: + self.logger = trainer.logger + + # deepspeed init + args = trainer.cfg + args.gradient_accumulation_steps = args.train.optimizer.get( + 'options', {}).get('cumulative_iters', 1) + num_update_steps_per_epoch = trainer.iters_per_epoch // args.gradient_accumulation_steps + max_steps = math.ceil(trainer._max_epochs * num_update_steps_per_epoch) + + ds_config = self.get_deepspeed_config(trainer, args, max_steps) + + optimizer, lr_scheduler = deepspeed_optim_sched( + trainer, ds_config, max_steps) + config = ds_config.config + if self.zero_stage is not None: + config['zero_optimization']['stage'] = self.zero_stage + self.processor.zero_stage = config['zero_optimization'].get('stage', 0) + + trainer.model, trainer.optimizer, _, trainer.lr_scheduler = deepspeed.initialize( + model=trainer.model, + optimizer=optimizer, + config=config, + lr_scheduler=lr_scheduler) diff --git a/modelscope/trainers/hooks/distributed/megatron_hook.py b/modelscope/trainers/hooks/distributed/megatron_hook.py index c4aeaf19a..66c857df9 100644 --- a/modelscope/trainers/hooks/distributed/megatron_hook.py +++ b/modelscope/trainers/hooks/distributed/megatron_hook.py @@ -35,7 +35,7 @@ def rank_name(self): except (ImportError, AssertionError): return '' - def get_bin_file(self): + def get_bin_filename(self): mp_rank = mpu.get_tensor_model_parallel_rank() rank = '{:02d}'.format(mp_rank) return f'mp_rank_{rank}_model_states.pt' @@ -72,7 +72,7 @@ def save_checkpoints(self, save_dir = os.path.dirname(checkpoint_path_prefix) prefix = os.path.basename(checkpoint_path_prefix) - bin_file = self.get_bin_file() + bin_file = self.get_bin_filename() prefix_bin_file = os.path.join(save_dir, prefix + '_' + bin_file) save_checkpoint(model, prefix_bin_file, with_meta=False) @@ -98,7 +98,7 @@ def remove_checkpoints(self, trainer, checkpoint_path_prefix): save_dir = os.path.dirname(checkpoint_path_prefix) prefix = os.path.basename(checkpoint_path_prefix) - bin_file = self.get_bin_file() + bin_file = self.get_bin_filename() absolute_file = os.path.join(save_dir, prefix + '_' + bin_file) if os.path.isfile(absolute_file): os.remove(absolute_file) @@ -108,7 +108,7 @@ def load_checkpoints(self, checkpoint_path_prefix, trainer, load_all_state, model = trainer.unwrap_module(trainer.model) if os.path.isdir(checkpoint_path_prefix): save_dir = checkpoint_path_prefix - bin_file = self.get_bin_file() + bin_file = self.get_bin_filename() model_file = os.path.join(save_dir, bin_file) load_checkpoint(model_file, model, None, None) else: @@ -119,7 +119,7 @@ def load_checkpoints(self, checkpoint_path_prefix, trainer, load_all_state, save_dir = os.path.dirname(checkpoint_path_prefix) prefix = os.path.basename(checkpoint_path_prefix) - bin_file = self.get_bin_file() + bin_file = self.get_bin_filename() model_file = os.path.join(save_dir, prefix + '_' + bin_file) load_checkpoint(model_file, model, None, None) diff --git a/modelscope/trainers/hooks/lr_scheduler_hook.py b/modelscope/trainers/hooks/lr_scheduler_hook.py index facf51559..acaa7ebf8 100644 --- a/modelscope/trainers/hooks/lr_scheduler_hook.py +++ b/modelscope/trainers/hooks/lr_scheduler_hook.py @@ -112,6 +112,9 @@ def after_train_epoch(self, trainer): self.processor.step(trainer) def _get_log_lr(self, trainer): + # forward compatibility with AddLrLogHook in EasyCV + if not hasattr(self, 'processor'): + self.processor = LrSchedulerProcessor() cur_lr = self.processor.get_current_lr(trainer) # only record lr of the first param group if isinstance(cur_lr, list): diff --git a/modelscope/trainers/optimizer/__init__.py b/modelscope/trainers/optimizer/__init__.py index 9962c2c2a..cd59c072d 100644 --- a/modelscope/trainers/optimizer/__init__.py +++ b/modelscope/trainers/optimizer/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.swift import ChildTuningAdamW from .builder import OPTIMIZERS, build_optimizer -from .child_tuning_adamw_optimizer import ChildTuningAdamW __all__ = ['OPTIMIZERS', 'build_optimizer', 'ChildTuningAdamW'] diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index fd0fafb8b..2e62be892 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -44,6 +44,7 @@ from modelscope.utils.torch_utils import (compile_model, get_dist_info, get_local_rank, init_dist, is_dist, is_master, set_random_seed) +from ..swift import Swift from .base import BaseTrainer from .builder import TRAINERS from .default_config import merge_cfg, merge_hooks, update_cfg @@ -231,7 +232,7 @@ def __init__( # A logic to fit the current code # Put a DDPHook in if launcher is provided. if 'hooks' not in self.cfg.train: - self.cfg.train['hooks'] = ConfigDict([]) + self.cfg.train['hooks'] = [] self.cfg.train['hooks'].append({ 'type': 'DDPHook', 'launcher': self.launcher @@ -264,10 +265,7 @@ def __init__( def tune_module(self, efficient_tuners): if efficient_tuners is not None: for tuner in efficient_tuners: - type = tuner.pop('type') - if type == 'lora': - from modelscope.tuners.lora import LoRATuner - LoRATuner.tune(self.model, **tuner) + self.model = Swift.prepare_model(self.model, tuner) def place_model(self): """Place model to device, or to DDP diff --git a/modelscope/utils/__init__.py b/modelscope/utils/__init__.py index e69de29bb..7486e137b 100644 --- a/modelscope/utils/__init__.py +++ b/modelscope/utils/__init__.py @@ -0,0 +1 @@ +from .hub import create_model_if_not_exist, read_config diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 1f44fc014..32b1e2daa 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -517,6 +517,8 @@ class Devices: 'txt': 'text' } +META_FILES_FORMAT = ('.csv', '.jsonl') + class DatasetPathName: META_NAME = 'meta' diff --git a/requirements/audio/audio_asr.txt b/requirements/audio/audio_asr.txt index d136ce55e..f7b1eaea9 100644 --- a/requirements/audio/audio_asr.txt +++ b/requirements/audio/audio_asr.txt @@ -1 +1 @@ -funasr>=0.6.0 +funasr>=0.6.5 diff --git a/requirements/cv.txt b/requirements/cv.txt index c48d82e2d..14515fdb1 100644 --- a/requirements/cv.txt +++ b/requirements/cv.txt @@ -6,7 +6,7 @@ chumpy clip>=1.0 control_ldm ddpm_guided_diffusion -diffusers>=0.13.1,<0.15.0 +diffusers>=0.13.1,<=0.15.0 easydict easyrobust edit_distance diff --git a/requirements/framework.txt b/requirements/framework.txt index 3ba659fce..69fa1f409 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -1,6 +1,6 @@ addict attrs -datasets>=2.7.0,<=2.8.0 +datasets einops filelock>=3.3.0 gast>=0.2.2 diff --git a/tests/cli/test_plugins_cmd.py b/tests/cli/test_plugins_cmd.py index b11c67ab6..6962b3f6f 100644 --- a/tests/cli/test_plugins_cmd.py +++ b/tests/cli/test_plugins_cmd.py @@ -4,6 +4,7 @@ from modelscope.utils.plugins import PluginsManager +@unittest.skipUnless(False, reason='For it modify torch version') class PluginsCMDTest(unittest.TestCase): def setUp(self): diff --git a/tests/msdatasets/test_load_meta_jsonl.py b/tests/msdatasets/test_load_meta_jsonl.py new file mode 100644 index 000000000..641985665 --- /dev/null +++ b/tests/msdatasets/test_load_meta_jsonl.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.msdatasets import MsDataset +from modelscope.utils import logger as logging +from modelscope.utils.constant import DownloadMode +from modelscope.utils.test_utils import test_level + +logger = logging.get_logger() + + +class TestLoadMetaJsonl(unittest.TestCase): + + def setUp(self): + self.dataset_id = 'modelscope/ms_ds_meta_jsonlines' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_load_jsonl_in_meta(self): + ds = MsDataset.load( + self.dataset_id, + split='test', + download_mode=DownloadMode.FORCE_REDOWNLOAD) + ds_one = next(iter(ds)) + logger.info(next(iter(ds))) + assert ds_one['text'] + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_panoptic_segmentation.py b/tests/pipelines/test_image_panoptic_segmentation.py index 38c660556..3c635d655 100644 --- a/tests/pipelines/test_image_panoptic_segmentation.py +++ b/tests/pipelines/test_image_panoptic_segmentation.py @@ -18,7 +18,7 @@ def setUp(self) -> None: self.task = Tasks.image_segmentation self.model_id = 'damo/cv_swinL_panoptic-segmentation_cocopan' - @unittest.skip('skip test in current test level: no pipeline implemented') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_image_panoptic_segmentation(self): input_location = 'data/test/images/image_panoptic_segmentation.jpg' pan_segmentor = pipeline(Tasks.image_segmentation, model=self.model_id) @@ -28,7 +28,7 @@ def test_image_panoptic_segmentation(self): cv2.imwrite('result.jpg', draw_img) print('print test_image_panoptic_segmentation return success') - @unittest.skip('skip test in current test level: no pipeline implemented') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_image_panoptic_segmentation_from_PIL(self): input_location = 'data/test/images/image_panoptic_segmentation.jpg' pan_segmentor = pipeline(Tasks.image_segmentation, model=self.model_id) diff --git a/tests/pipelines/test_speaker_verification.py b/tests/pipelines/test_speaker_verification.py index 2b90c66ed..b772ed406 100644 --- a/tests/pipelines/test_speaker_verification.py +++ b/tests/pipelines/test_speaker_verification.py @@ -15,6 +15,7 @@ SPEAKER1_B_EN_16K_WAV = 'data/test/audios/speaker1_b_en_16k.wav' SPEAKER2_A_EN_16K_WAV = 'data/test/audios/speaker2_a_en_16k.wav' SCL_EXAMPLE_WAV = 'data/test/audios/scl_example1.wav' +SD_EXAMPLE_WAV = 'data/test/audios/2speakers_example.wav' class SpeakerVerificationTest(unittest.TestCase): @@ -23,6 +24,8 @@ class SpeakerVerificationTest(unittest.TestCase): rdino_voxceleb_16k_model_id = 'damo/speech_rdino_ecapa_tdnn_sv_en_voxceleb_16k' speaker_change_locating_cn_model_id = 'damo/speech_campplus-transformer_scl_zh-cn_16k-common' eres2net_voxceleb_16k_model_id = 'damo/speech_eres2net_sv_en_voxceleb_16k' + speaker_diarization_model_id = 'damo/speech_campplus_speaker-diarization_common' + eres2net_aug_zh_cn_16k_common_model_id = 'damo/speech_eres2net_sv_zh-cn_16k-common' def setUp(self) -> None: self.task = Tasks.speaker_verification @@ -91,6 +94,27 @@ def test_run_with_speaker_verification_eres2net_voxceleb_16k(self): print(result) self.assertTrue(OutputKeys.SCORE in result) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_speaker_verification_eres2net_aug_zh_cn_common_16k(self): + logger.info('Run speaker verification for eres2net_voxceleb_16k model') + result = self.run_pipeline( + model_id=self.eres2net_aug_zh_cn_16k_common_model_id, + audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER1_B_EN_16K_WAV], + model_revision='v1.0.1') + print(result) + self.assertTrue(OutputKeys.SCORE in result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_speaker_diarization_common(self): + logger.info( + 'Run speaker change locating for campplus-transformer model') + result = self.run_pipeline( + model_id=self.speaker_diarization_model_id, + task=Tasks.speaker_diarization, + audios=SD_EXAMPLE_WAV) + print(result) + self.assertTrue(OutputKeys.TEXT in result) + if __name__ == '__main__': unittest.main() diff --git a/tests/trainers/test_dialog_modeling_trainer.py b/tests/trainers/test_dialog_modeling_trainer.py index 900bf904a..3378506c1 100644 --- a/tests/trainers/test_dialog_modeling_trainer.py +++ b/tests/trainers/test_dialog_modeling_trainer.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +import shutil import unittest import torch @@ -17,6 +18,11 @@ class TestDialogModelingTrainer(unittest.TestCase): model_id = 'damo/nlp_space_pretrained-dialog-model' output_dir = './dialog_fintune_result' + def tearDown(self): + if os.path.exists(self.output_dir): + shutil.rmtree(self.output_dir) + super().tearDown() + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_trainer_with_model_and_args(self): # download data set diff --git a/tests/trainers/test_document_grounded_dialog_rerank_trainer.py b/tests/trainers/test_document_grounded_dialog_rerank_trainer.py index fad0b55e1..168659dbc 100644 --- a/tests/trainers/test_document_grounded_dialog_rerank_trainer.py +++ b/tests/trainers/test_document_grounded_dialog_rerank_trainer.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +import shutil import unittest import json @@ -19,6 +20,10 @@ class TestDialogIntentTrainer(unittest.TestCase): def setUp(self): self.model_id = 'DAMO_ConvAI/nlp_convai_ranking_pretrain' + def tearDown(self): + shutil.rmtree('./model') + super().tearDown() + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer_with_model_and_args(self): args = { diff --git a/tests/trainers/test_finetune_sequence_classification.py b/tests/trainers/test_finetune_sequence_classification.py index 797351aa6..ef0ccae7f 100644 --- a/tests/trainers/test_finetune_sequence_classification.py +++ b/tests/trainers/test_finetune_sequence_classification.py @@ -8,18 +8,17 @@ from modelscope.models import Model from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline +from modelscope.swift.optimizers.child_tuning_adamw_optimizer import \ + calculate_fisher from modelscope.trainers import build_trainer from modelscope.trainers.hooks import Hook from modelscope.trainers.nlp_trainer import (EpochBasedTrainer, NlpEpochBasedTrainer) -from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \ - calculate_fisher from modelscope.trainers.training_args import TrainingArgs from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.data_utils import to_device from modelscope.utils.regress_test_utils import (MsRegressTool, compare_arguments_nested) -from modelscope.utils.test_utils import test_level class TestFinetuneSequenceClassification(unittest.TestCase): diff --git a/tests/trainers/test_finetune_vision_efficient_tuning_swift.py b/tests/trainers/test_finetune_vision_efficient_tuning_swift.py new file mode 100644 index 000000000..d8733024a --- /dev/null +++ b/tests/trainers/test_finetune_vision_efficient_tuning_swift.py @@ -0,0 +1,164 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import os +import shutil +import tempfile +import unittest + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.swift import Swift +from modelscope.swift.adapter import AdapterConfig +from modelscope.swift.lora import LoRAConfig +from modelscope.swift.prompt import PromptConfig +from modelscope.trainers import build_trainer +from modelscope.utils.test_utils import test_level + + +class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + self.train_dataset = MsDataset.load( + 'foundation_model_evaluation_benchmark', + namespace='damo', + subset_name='OxfordFlowers', + split='train') + + self.eval_dataset = MsDataset.load( + 'foundation_model_evaluation_benchmark', + namespace='damo', + subset_name='OxfordFlowers', + split='eval') + + self.max_epochs = 1 + self.num_classes = 102 + self.tune_length = 10 + + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_swift_lora_train(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora' + + def cfg_modify_fn(cfg): + cfg.model.head.num_classes = self.num_classes + cfg.model.finetune = True + cfg.train.max_epochs = self.max_epochs + cfg.train.lr_scheduler.T_max = self.max_epochs + cfg.model.backbone.lora_length = 0 + return cfg + + lora_config = LoRAConfig( + rank=self.tune_length, + replace_modules=['qkv'], + merge_weights=False, + only_lora_trainable=False, + use_merged_linear=True, + enable_lora=[True]) + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn, + efficient_tuners=[lora_config]) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + trainer.train() + result = trainer.evaluate() + print(f'Vision-efficient-tuning-lora train output: {result}.') + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_swift_adapter_train(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter' + + def cfg_modify_fn(cfg): + cfg.model.head.num_classes = self.num_classes + cfg.model.finetune = True + cfg.train.max_epochs = self.max_epochs + cfg.train.lr_scheduler.T_max = self.max_epochs + cfg.model.backbone.adapter_length = 0 + return cfg + + adapter_config = AdapterConfig( + dim=768, + hidden_pos=0, + module_name=r'.*blocks\.\d+\.mlp$', + adapter_length=self.tune_length, + only_adapter_trainable=False) + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn, + efficient_tuners=[adapter_config]) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + trainer.train() + result = trainer.evaluate() + print(f'Vision-efficient-tuning-adapter train output: {result}.') + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_swift_prompt_train(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt' + + def cfg_modify_fn(cfg): + cfg.model.head.num_classes = self.num_classes + cfg.model.finetune = True + cfg.train.max_epochs = self.max_epochs + cfg.train.lr_scheduler.T_max = self.max_epochs + cfg.model.backbone.prompt_length = 0 + return cfg + + prompt_config = PromptConfig( + dim=768, + module_layer_name=r'.*blocks\.\d+$', + embedding_pos=0, + prompt_length=self.tune_length, + only_prompt_trainable=False, + attach_front=False) + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn, + efficient_tuners=[prompt_config]) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + trainer.train() + result = trainer.evaluate() + print(f'Vision-efficient-tuning-prompt train output: {result}.') + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_image_portrait_stylization_trainer.py b/tests/trainers/test_image_portrait_stylization_trainer.py index 6a3c41fa8..37b42de68 100644 --- a/tests/trainers/test_image_portrait_stylization_trainer.py +++ b/tests/trainers/test_image_portrait_stylization_trainer.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +import shutil import unittest import cv2 @@ -20,6 +21,10 @@ def setUp(self) -> None: self.task = Tasks.image_portrait_stylization self.test_image = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_cartoon.png' + def tearDown(self): + shutil.rmtree('exp_localtoon') + super().tearDown() + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name(self): model_id = 'damo/cv_unet_person-image-cartoon_compound-models' diff --git a/tests/trainers/test_nerf_recon_acc_trainer.py b/tests/trainers/test_nerf_recon_acc_trainer.py index 4b6c80912..579901afd 100644 --- a/tests/trainers/test_nerf_recon_acc_trainer.py +++ b/tests/trainers/test_nerf_recon_acc_trainer.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +import shutil import unittest from modelscope.msdatasets import MsDataset @@ -10,6 +11,10 @@ class TestNeRFReconAccTrainer(unittest.TestCase): + def tearDown(self): + shutil.rmtree('exp_nerf') + super().tearDown() + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name(self): model_id = 'damo/cv_nerf-3d-reconstruction-accelerate_damo' diff --git a/tests/trainers/test_ocr_detection_db_trainer.py b/tests/trainers/test_ocr_detection_db_trainer.py index 10097feaa..136ae4107 100644 --- a/tests/trainers/test_ocr_detection_db_trainer.py +++ b/tests/trainers/test_ocr_detection_db_trainer.py @@ -36,6 +36,10 @@ def setUp(self): self.saved_infer_model = os.path.join(self.saved_dir, 'pytorch_model.pt') + def tearDown(self): + shutil.rmtree(self.saved_dir) + super().tearDown() + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer_finetune_singleGPU(self): diff --git a/tests/trainers/test_ofa_mmspeech_trainer.py b/tests/trainers/test_ofa_mmspeech_trainer.py index 2c4f63072..6ab4ac765 100644 --- a/tests/trainers/test_ofa_mmspeech_trainer.py +++ b/tests/trainers/test_ofa_mmspeech_trainer.py @@ -69,11 +69,17 @@ def setUp(self) -> None: 'metrics': [{'type': 'accuracy'}]}, 'preprocessor': []} + self.WORKSPACE = './workspace/ckpts/asr_recognition' + + def tearDown(self) -> None: + if os.path.exists(self.WORKSPACE): + shutil.rmtree(self.WORKSPACE, ignore_errors=True) + super().tearDown() + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer_std(self): - WORKSPACE = './workspace/ckpts/asr_recognition' - os.makedirs(WORKSPACE, exist_ok=True) - config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) + os.makedirs(self.WORKSPACE, exist_ok=True) + config_file = os.path.join(self.WORKSPACE, ModelFile.CONFIGURATION) with open(config_file, 'w') as writer: json.dump(self.finetune_cfg, writer) @@ -81,7 +87,7 @@ def test_trainer_std(self): args = dict( model=pretrained_model, - work_dir=WORKSPACE, + work_dir=self.WORKSPACE, train_dataset=MsDataset.load( 'aishell1_subset', subset_name='default', @@ -100,8 +106,8 @@ def test_trainer_std(self): self.assertIn( ModelFile.TORCH_MODEL_BIN_FILE, - os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR))) - shutil.rmtree(WORKSPACE) + os.listdir( + os.path.join(self.WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR))) if __name__ == '__main__': diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py index f4ca7bcb3..cb4807442 100644 --- a/tests/trainers/test_ofa_trainer.py +++ b/tests/trainers/test_ofa_trainer.py @@ -69,19 +69,24 @@ def setUp(self) -> None: 'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0}, 'metrics': [{'type': 'accuracy'}]}, 'preprocessor': []} + self.WORKSPACE = './workspace/ckpts/recognition' + + def tearDown(self) -> None: + if os.path.exists(self.WORKSPACE): + shutil.rmtree(self.WORKSPACE, ignore_errors=True) + super().tearDown() @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer_std(self): - WORKSPACE = './workspace/ckpts/recognition' - os.makedirs(WORKSPACE, exist_ok=True) - config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) + os.makedirs(self.WORKSPACE, exist_ok=True) + config_file = os.path.join(self.WORKSPACE, ModelFile.CONFIGURATION) with open(config_file, 'w') as writer: json.dump(self.finetune_cfg, writer, indent=4) pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh' args = dict( model=pretrained_model, - work_dir=WORKSPACE, + work_dir=self.WORKSPACE, train_dataset=MsDataset.load( 'ocr_fudanvi_zh', subset_name='scene', @@ -100,8 +105,8 @@ def test_trainer_std(self): self.assertIn( ModelFile.TORCH_MODEL_BIN_FILE, - os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR))) - shutil.rmtree(WORKSPACE) + os.listdir( + os.path.join(self.WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR))) if __name__ == '__main__': diff --git a/tests/trainers/test_plug_finetune_text_generation.py b/tests/trainers/test_plug_finetune_text_generation.py index 6d9e0740b..4e3523b07 100644 --- a/tests/trainers/test_plug_finetune_text_generation.py +++ b/tests/trainers/test_plug_finetune_text_generation.py @@ -45,6 +45,7 @@ def concat_answer_context(dataset): trainer = build_trainer( name=Trainers.nlp_plug_trainer, default_args=kwargs) trainer.train() + shutil.rmtree(tmp_dir) if __name__ == '__main__': diff --git a/tests/trainers/test_referring_video_object_segmentation_trainer.py b/tests/trainers/test_referring_video_object_segmentation_trainer.py index fb1529545..53ab52365 100644 --- a/tests/trainers/test_referring_video_object_segmentation_trainer.py +++ b/tests/trainers/test_referring_video_object_segmentation_trainer.py @@ -60,6 +60,10 @@ def setUp(self): self.max_epochs = max_epochs + def tearDown(self): + shutil.rmtree('./work_dir') + super().tearDown() + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer(self): kwargs = dict( diff --git a/tests/trainers/test_team_transfer_trainer.py b/tests/trainers/test_team_transfer_trainer.py index 0f6b88bb3..6012da088 100644 --- a/tests/trainers/test_team_transfer_trainer.py +++ b/tests/trainers/test_team_transfer_trainer.py @@ -1,4 +1,5 @@ import os +import shutil import unittest import json @@ -81,6 +82,10 @@ def train_worker(device_id): class TEAMTransferTrainerTest(unittest.TestCase): + def tearDown(self) -> None: + super().tearDown() + shutil.rmtree('./ckpt') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer(self): if torch.cuda.device_count() > 0: diff --git a/tests/trainers/test_tinynas_damoyolo_trainer.py b/tests/trainers/test_tinynas_damoyolo_trainer.py index 5dd9e9288..d2070bec2 100644 --- a/tests/trainers/test_tinynas_damoyolo_trainer.py +++ b/tests/trainers/test_tinynas_damoyolo_trainer.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +import shutil import unittest from modelscope.hub.snapshot_download import snapshot_download @@ -21,6 +22,10 @@ def setUp(self): self.model_id = 'damo/cv_tinynas_object-detection_damoyolo' self.cache_path = _setup() + def tearDown(self) -> None: + super().tearDown() + shutil.rmtree('./workdirs') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer_from_scratch_singleGPU(self): kwargs = dict( diff --git a/tests/trainers/test_translation_evaluation_trainer.py b/tests/trainers/test_translation_evaluation_trainer.py index 139427da8..a070c8126 100644 --- a/tests/trainers/test_translation_evaluation_trainer.py +++ b/tests/trainers/test_translation_evaluation_trainer.py @@ -1,8 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import os.path +import shutil import unittest from modelscope.metainfo import Trainers from modelscope.trainers import build_trainer +from modelscope.utils.hub import read_config from modelscope.utils.test_utils import test_level @@ -13,6 +16,15 @@ def setUp(self) -> None: self.model_id_large = 'damo/nlp_unite_mup_translation_evaluation_multilingual_large' self.model_id_base = 'damo/nlp_unite_mup_translation_evaluation_multilingual_base' + def tearDown(self) -> None: + cfg_base = read_config(self.model_id_base) + if os.path.exists(cfg_base.train.work_dir): + shutil.rmtree(cfg_base.train.work_dir, ignore_errors=True) + cfg_large = read_config(self.model_id_large) + if os.path.exists(cfg_large.train.work_dir): + shutil.rmtree(cfg_large.train.work_dir, ignore_errors=True) + super().tearDown() + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_unite_mup_large(self) -> None: default_args = {'model': self.model_id_large} diff --git a/tests/tuners/test_adapter.py b/tests/tuners/test_adapter.py new file mode 100644 index 000000000..a110591a8 --- /dev/null +++ b/tests/tuners/test_adapter.py @@ -0,0 +1,81 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +import numpy as np +import torch + +from modelscope import read_config +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models.base import Model +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.swift import Swift +from modelscope.swift.adapter import AdapterConfig +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.test_utils import test_level + + +class TestAdapter(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip in this level') + def test_adapter_smoke_test(self): + dataset = MsDataset.load( + 'clue', subset_name='afqmc', + split='train').to_hf_dataset().select(range(2)) + + model_dir = snapshot_download( + 'damo/nlp_structbert_sentence-similarity_chinese-tiny') + model = Model.from_pretrained(model_dir, adv_grad_factor=None) + + cfg_file = os.path.join(model_dir, 'configuration.json') + + model_cfg = os.path.join(model_dir, 'config.json') + model_cfg = read_config(model_cfg) + + adapter_config = AdapterConfig( + dim=model_cfg.hidden_size, + module_name=r'.*layer\.\d+$', + method_name='feed_forward_chunk', + hidden_pos=0) + model = Swift.prepare_model(model, adapter_config) + kwargs = dict( + model=model, + cfg_file=cfg_file, + train_dataset=dataset, + eval_dataset=dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) + + def pipeline_sentence_similarity(model_dir): + model = Model.from_pretrained(model_dir) + adapter_config.pretrained_weights = output_dir + Swift.prepare_model(model, adapter_config) + model.eval() + pipeline_ins = pipeline( + task=Tasks.sentence_similarity, model=model) + return pipeline_ins(input=('test', 'this is a test')) + + output1 = pipeline_sentence_similarity( + 'damo/nlp_structbert_sentence-similarity_chinese-tiny') + print(output1) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tuners/test_lora.py b/tests/tuners/test_lora.py index 2f52a4d37..b3238dad7 100644 --- a/tests/tuners/test_lora.py +++ b/tests/tuners/test_lora.py @@ -11,9 +11,10 @@ from modelscope.models.base import Model from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline +from modelscope.swift import Swift +from modelscope.swift.lora import (Linear, LoRA, LoRAConfig, + mark_only_lora_as_trainable) from modelscope.trainers import build_trainer -from modelscope.tuners.lora import (Linear, LoRATuner, - mark_only_lora_as_trainable) from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.test_utils import test_level @@ -66,22 +67,18 @@ def test_lora_smoke_test(self): model_dir = snapshot_download( 'damo/nlp_structbert_sentence-similarity_chinese-tiny') - model = Model.from_pretrained( - 'damo/nlp_structbert_sentence-similarity_chinese-tiny', - adv_grad_factor=None) + model = Model.from_pretrained(model_dir, adv_grad_factor=None) cfg_file = os.path.join(model_dir, 'configuration.json') + lora_config = LoRAConfig(replace_modules=['query', 'key', 'value']) + model = Swift.prepare_model(model, lora_config) kwargs = dict( model=model, cfg_file=cfg_file, train_dataset=dataset, eval_dataset=dataset, - work_dir=self.tmp_dir, - efficient_tuners=[{ - 'type': 'lora', - 'replace_modules': ['query', 'key', 'value'] - }]) + work_dir=self.tmp_dir) trainer = build_trainer(default_args=kwargs) trainer.train() @@ -89,7 +86,8 @@ def test_lora_smoke_test(self): def pipeline_sentence_similarity(model_dir): model = Model.from_pretrained(model_dir) - LoRATuner.tune(model, replace_modules=['query', 'key', 'value']) + lora_config.pretrained_weights = output_dir + Swift.prepare_model(model, lora_config) model.load_state_dict( torch.load(os.path.join(output_dir, 'pytorch_model.bin'))) model.eval() @@ -100,7 +98,7 @@ def pipeline_sentence_similarity(model_dir): output1 = pipeline_sentence_similarity( 'damo/nlp_structbert_sentence-similarity_chinese-tiny') - LoRATuner.unpatch_lora(model, ['query', 'key', 'value']) + LoRA.unpatch_lora(model, lora_config) model.save_pretrained( output_dir, save_checkpoint_names='pytorch_model.bin') diff --git a/tests/tuners/test_prompt.py b/tests/tuners/test_prompt.py new file mode 100644 index 000000000..c338162f6 --- /dev/null +++ b/tests/tuners/test_prompt.py @@ -0,0 +1,83 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +import numpy as np +import torch + +from modelscope import read_config +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models.base import Model +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.swift import Swift +from modelscope.swift.adapter import AdapterConfig +from modelscope.swift.prompt import PromptConfig +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.test_utils import test_level + + +class TestPrompt(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip in this level') + def test_prompt_smoke_test(self): + dataset = MsDataset.load( + 'clue', subset_name='afqmc', + split='train').to_hf_dataset().select(range(2)) + + model_dir = snapshot_download( + 'damo/nlp_structbert_sentence-similarity_chinese-tiny') + model = Model.from_pretrained(model_dir, adv_grad_factor=None) + + cfg_file = os.path.join(model_dir, 'configuration.json') + model_cfg = os.path.join(model_dir, 'config.json') + model_cfg = read_config(model_cfg) + + prompt_config = PromptConfig( + dim=model_cfg.hidden_size, + module_layer_name=r'.*layer\.\d+$', + embedding_pos=0, + attention_mask_pos=1) + + model = Swift.prepare_model(model, prompt_config) + + kwargs = dict( + model=model, + cfg_file=cfg_file, + train_dataset=dataset, + eval_dataset=dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) + + def pipeline_sentence_similarity(model_dir): + model = Model.from_pretrained(model_dir) + prompt_config.pretrained_weights = output_dir + Swift.prepare_model(model, prompt_config) + model.eval() + pipeline_ins = pipeline( + task=Tasks.sentence_similarity, model=model) + return pipeline_ins(input=('test', 'this is a test')) + + output1 = pipeline_sentence_similarity( + 'damo/nlp_structbert_sentence-similarity_chinese-tiny') + print(output1) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils/case_file_analyzer.py b/tests/utils/case_file_analyzer.py index 82ae55630..f0445954a 100644 --- a/tests/utils/case_file_analyzer.py +++ b/tests/utils/case_file_analyzer.py @@ -395,8 +395,12 @@ def get_pipelines_trainers_test_info(register_modules): if os.path.splitext(f)[1] == '.py' ] for test_file in all_pipeline_cases: - tested_pipelines = analysis_pipeline_test_suite( - test_file, register_modules) + try: + tested_pipelines = analysis_pipeline_test_suite( + test_file, register_modules) + except Exception: + logger.warn('test_suite: %s analysis failed, skipt it' % test_file) + continue if len(tested_pipelines) == 0: logger.warn('test_suite: %s has no pipeline task' % test_file) else: diff --git a/tests/utils/test_ast.py b/tests/utils/test_ast.py index 288c076a2..544e75b69 100644 --- a/tests/utils/test_ast.py +++ b/tests/utils/test_ast.py @@ -35,7 +35,7 @@ def tearDown(self): def test_ast_scaning_class(self): astScaner = AstScanning() pipeline_file = os.path.join(MODELSCOPE_PATH, 'pipelines', 'nlp', - 'text_generation_pipeline.py') + 'fill_mask_pipeline.py') output = astScaner.generate_ast(pipeline_file) self.assertTrue(output['imports'] is not None) self.assertTrue(output['from_imports'] is not None) @@ -45,24 +45,19 @@ def test_ast_scaning_class(self): self.assertIsInstance(imports, dict) self.assertIsInstance(from_imports, dict) self.assertIsInstance(decorators, list) - self.assertListEqual( - list(set(imports.keys()) - set(['torch', 'os'])), []) - self.assertEqual(len(from_imports.keys()), 11) + self.assertListEqual(list(set(imports.keys()) - set(['numpy'])), []) + self.assertEqual(len(from_imports.keys()), 8) self.assertTrue(from_imports['modelscope.metainfo'] is not None) self.assertEqual(from_imports['modelscope.metainfo'], ['Pipelines']) - self.assertEqual( - decorators, - [('PIPELINES', 'text-generation', 'text-generation'), - ('PIPELINES', 'text2text-generation', 'translation_en_to_de'), - ('PIPELINES', 'text2text-generation', 'translation_en_to_ro'), - ('PIPELINES', 'text2text-generation', 'translation_en_to_fr'), - ('PIPELINES', 'text2text-generation', 'text2text-generation')]) + self.assertEqual(decorators, + [('PIPELINES', 'fill-mask', 'fill-mask'), + ('PIPELINES', 'fill-mask', 'fill-mask-ponet')]) def test_files_scaning_method(self): fileScaner = FilesAstScanning() # case of pass in files directly pipeline_file = os.path.join(MODELSCOPE_PATH, 'pipelines', 'nlp', - 'text_generation_pipeline.py') + 'fill_mask_pipeline.py') file_list = [pipeline_file] output = fileScaner.get_files_scan_results(file_list) self.assertTrue(output[INDEX_KEY] is not None) diff --git a/tools/weight_diff.py b/tools/weight_diff.py new file mode 100644 index 000000000..8cfa1d505 --- /dev/null +++ b/tools/weight_diff.py @@ -0,0 +1,170 @@ +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# Copyright (c) Alibaba, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from typing import Dict, Optional + +import torch +import tqdm +import transformers + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.utils.checkpoint import save_pretrained +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def smart_tokenizer_and_embedding_resize(special_tokens_dict: Dict, tokenizer, + model): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def make_same_shape(model_raw: Model, model_convert: Model, tokenizer_raw, + tokenizer_convert): + if model_raw.__class__ != model_convert.__class__: + logger.error( + f'weight diff: These two models should be of the same class. model_raw:' + f'{model_raw.__class__} vs model_convert: {model_convert.__class__}.' + ) + + special_tokens = {} + for k, v in tokenizer_convert.special_tokens_map_extended.items(): + if k not in tokenizer_raw.special_tokens_map_extended: + special_tokens[k] = v + + smart_tokenizer_and_embedding_resize( + special_tokens_dict=special_tokens, + model=model_raw, + tokenizer=tokenizer_raw, + ) + + state_dict_tuned = model_convert.state_dict() + state_dict_raw = model_raw.state_dict() + for key in tqdm.tqdm(state_dict_tuned): + if state_dict_tuned[key].shape != state_dict_raw[key].shape: + logger.error( + f'weight diff: shape mismatch. {key}, model_raw shape: {state_dict_raw[key].shape}' + f' vs model_convert shape: {state_dict_tuned[key].shape}.') + + +def _weight_diff(model_raw, + model_convert, + tokenizer_raw, + tokenizer_convert, + path_to_save=None, + make_diff_or_recover='diff'): + make_same_shape(model_raw, model_convert, tokenizer_raw, tokenizer_convert) + + state_dict_raw = model_raw.state_dict() + state_dict_convert = model_convert.state_dict() + if make_diff_or_recover == 'diff': + for key in tqdm.tqdm(state_dict_convert): + state_dict_convert[key].add_(-state_dict_raw[key]) + elif make_diff_or_recover == 'recover': + for key in tqdm.tqdm(state_dict_convert): + state_dict_convert[key].add_(state_dict_raw[key]) + + if path_to_save: + model_convert.save_pretrained(path_to_save, 'pytorch_model.bin') + tokenizer_convert.save_pretrained(path_to_save) + + return model_convert, tokenizer_convert + + +@torch.inference_mode() +def weight_diff(path_raw: str, + path_convert: str, + path_to_save: str, + make_diff_or_recover, + device='cpu'): + """Make the weight diff. + + This function is given to present full transparency of how the weight diff was created. + """ + if not os.path.exists(path_raw): + logger.info( + f'Path `{path_raw}` not found. Try to load from cache or remote.') + path_raw = snapshot_download(path_raw) + if not os.path.exists(path_convert): + logger.info( + f'Path `{path_convert}` not found. Try to load from cache or remote.' + ) + path_convert = snapshot_download(path_convert) + + model_raw = Model.from_pretrained(path_raw, device=device) + model_convert = Model.from_pretrained(path_convert, device=device) + + tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained( + path_raw) + tokenizer_convert: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained( + path_convert) + + return _weight_diff( + model_raw, + model_convert, + tokenizer_raw, + tokenizer_convert, + path_to_save=path_to_save, + make_diff_or_recover=make_diff_or_recover) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description= + 'Make the weight diff between the raw model and tuned model, or recover tuned weights from the ' + 'released weight diff.') + + parser.add_argument( + 'make_diff_or_recover', + choices=['diff', 'recover'], + help= + 'model selection, make weight diff or recover weights from the weight diff.' + ) + parser.add_argument( + 'path_raw', type=str, help='path to the raw pretrained model.') + parser.add_argument( + 'path_convert', + type=str, + help= + 'path to the tuned model in mode `diff`, or path to the diff model in mode `recover`.' + ) + parser.add_argument( + 'path_to_save', + type=str, + help='path to save the diff or recover output files.') + args = parser.parse_args() + + weight_diff(args.path_raw, args.path_convert, args.path_to_save, + args.make_diff_or_recover)