From 0436cdd0915534f84d3e8783f3e4193c64eb44d9 Mon Sep 17 00:00:00 2001
From: Vincent Nguyen
Date: Tue, 26 Dec 2023 11:08:55 +0100
Subject: [PATCH] use flash_attn_with_kvcache for faster inference (#2539)

* use flash_attn_with_kvcache

* patch rmsnorm for multiexperts

* rope theta as an option
---
 onmt/decoders/transformer.py      |  20 +++--
 onmt/encoders/transformer.py      |   6 ++
 onmt/modules/moe.py               |   4 +
 onmt/modules/multi_headed_attn.py | 136 +++++++++++++++++++++++++-----
 onmt/modules/position_ffn.py      |   6 +-
 onmt/modules/rmsnorm.py           |  31 +++++--
 onmt/opts.py                      |   7 ++
 onmt/translate/translator.py      |   3 +
 onmt/utils/distributed.py         |   2 +-
 tools/convert_HF_llamalike.py     |   7 ++
 10 files changed, 183 insertions(+), 39 deletions(-)

diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py
index 7c56a2ce3a..557c566e20 100644
--- a/onmt/decoders/transformer.py
+++ b/onmt/decoders/transformer.py
@@ -11,11 +11,7 @@
 from onmt.modules.position_ffn import ActivationFunction
 from onmt.modules.moe import MoE
 from onmt.utils.misc import sequence_mask
-
-try:
-    from apex.normalization import FusedRMSNorm as RMSNorm
-except ImportError:
-    from onmt.modules.rmsnorm import RMSNorm
+from onmt.modules.rmsnorm import RMSNorm
 
 
 class TransformerDecoderLayerBase(nn.Module):
@@ -44,6 +40,7 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        rotary_theta=1e4,
         num_experts=0,
         num_experts_per_tok=2,
     ):
@@ -89,6 +86,7 @@ def __init__(
             sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
             rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
+            rotary_theta (int): rotary base theta
         """
         super(TransformerDecoderLayerBase, self).__init__()
@@ -100,6 +98,7 @@ def __init__(
             max_relative_positions=max_relative_positions,
             relative_positions_buckets=relative_positions_buckets,
             rotary_interleave=rotary_interleave,
+            rotary_theta=rotary_theta,
             attn_type="self",
             self_attn_type=self_attn_type,
             add_qkvbias=add_qkvbias,
@@ -280,6 +279,7 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        rotary_theta=1e4,
         num_experts=0,
         num_experts_per_tok=2,
     ):
@@ -311,6 +311,7 @@ def __init__(
             parallel_gpu=parallel_gpu,
             sliding_window=sliding_window,
             rotary_interleave=rotary_interleave,
+            rotary_theta=rotary_theta,
             num_experts=num_experts,
             num_experts_per_tok=num_experts_per_tok,
         )
@@ -473,6 +474,7 @@ def from_opt(cls, opt, embeddings):
             else 1,
             sliding_window=opt.sliding_window,
             rotary_interleave=opt.rotary_interleave,
+            rotary_theta=opt.rotary_theta,
             num_experts=opt.num_experts,
             num_experts_per_tok=opt.num_experts_per_tok,
         )
@@ -563,6 +565,7 @@ class TransformerDecoder(TransformerDecoderBase):
         parallel_gpu (int): Number of gpu for tensor parallelism
         sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
         rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
+        rotary_theta (int): rotary base theta
     """

     def __init__(
@@ -594,6 +597,7 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        rotary_theta=1e4,
         num_experts=0,
         num_experts_per_tok=2,
     ):
@@ -627,6 +631,7 @@ def __init__(
             parallel_gpu=parallel_gpu,
             sliding_window=sliding_window,
             rotary_interleave=rotary_interleave,
+            rotary_theta=rotary_theta,
             num_experts=num_experts,
             num_experts_per_tok=num_experts_per_tok,
         )
@@ -834,6 +839,7 @@ class TransformerLMDecoder(TransformerDecoderBase):
         parallel_gpu (int): Number of gpu for tensor parallelism
         sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
         rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
+        rotary_theta (int): rotary base theta
     """

     def __init__(
@@ -865,6 +871,7 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        rotary_theta=1e4,
         num_experts=0,
         num_experts_per_tok=2,
     ):
@@ -897,6 +904,7 @@ def __init__(
             parallel_gpu=parallel_gpu,
             sliding_window=sliding_window,
             rotary_interleave=rotary_interleave,
+            rotary_theta=rotary_theta,
             num_experts=num_experts,
             num_experts_per_tok=num_experts_per_tok,
         )
@@ -976,3 +984,5 @@ def _init_cache(self, tgt=None):
                 )
                 if hasattr(layer.self_attn, "rope"):
                     layer.self_attn.rope = layer.self_attn.rope.to(tgt.device)
+                    layer.self_attn.cos = layer.self_attn.cos.to(tgt.device)
+                    layer.self_attn.sin = layer.self_attn.sin.to(tgt.device)
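The new `rotary_theta` option is threaded from `from_opt` through every decoder layer down to its self-attention, so the rotary tables can be built with a model-specific base (1e4 for Llama-2/Mistral, 1e6 for Mixtral). A minimal sketch of how that base shapes the rotary frequency table; the helper below is illustrative, not the repo's `rotaryembeddings`:

```python
import torch


def build_rope_table(dim_per_head: int, maxseqlen: int = 2048, base: float = 1e4):
    # Inverse frequencies flatten out as `base` grows, so a larger theta
    # (e.g. 1e6 for Mixtral) stretches the positions the embedding can resolve.
    inv_freq = 1.0 / (
        base ** (torch.arange(0, dim_per_head, 2).float() / dim_per_head)
    )
    angles = torch.outer(torch.arange(maxseqlen).float(), inv_freq)
    # complex "cis" table of shape (maxseqlen, dim_per_head // 2)
    return torch.polar(torch.ones_like(angles), angles)


rope_llama2 = build_rope_table(128, base=1e4)
rope_mixtral = build_rope_table(128, base=1e6)
print(rope_llama2.shape, rope_mixtral.shape)  # same shape, different rotation speeds
```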
diff --git a/onmt/encoders/transformer.py b/onmt/encoders/transformer.py
index 12998957dc..35d32ce709 100644
--- a/onmt/encoders/transformer.py
+++ b/onmt/encoders/transformer.py
@@ -40,6 +40,7 @@ class TransformerEncoderLayer(nn.Module):
         parallel_gpu (int): Number of gpu for tensor parallelism
         rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
+        rotary_theta (int): rotary base theta
     """

     def __init__(
@@ -61,6 +62,7 @@ def __init__(
         use_ckpting=[],
         parallel_gpu=1,
         rotary_interleave=True,
+        rotary_theta=1e4,
     ):
         super(TransformerEncoderLayer, self).__init__()

@@ -72,6 +74,7 @@ def __init__(
             max_relative_positions=max_relative_positions,
             relative_positions_buckets=relative_positions_buckets,
             rotary_interleave=rotary_interleave,
+            rotary_theta=rotary_theta,
             attn_type="self",
             add_qkvbias=add_qkvbias,
             num_kv=num_kv,
@@ -177,6 +180,7 @@ def __init__(
         use_ckpting=[],
         parallel_gpu=1,
         rotary_interleave=True,
+        rotary_theta=1e4,
     ):
         super(TransformerEncoder, self).__init__()

@@ -201,6 +205,7 @@ def __init__(
                     use_ckpting=use_ckpting,
                     parallel_gpu=parallel_gpu,
                     rotary_interleave=rotary_interleave,
+                    rotary_theta=rotary_theta,
                 )
                 for i in range(num_layers)
             ]
@@ -239,6 +244,7 @@ def from_opt(cls, opt, embeddings):
             if opt.parallel_mode == "tensor_parallel"
             else 1,
             rotary_interleave=opt.rotary_interleave,
+            rotary_theta=opt.rotary_theta,
         )

     def forward(self, src, src_len=None):
diff --git a/onmt/modules/moe.py b/onmt/modules/moe.py
index 2e1c959636..f356130d97 100644
--- a/onmt/modules/moe.py
+++ b/onmt/modules/moe.py
@@ -2,6 +2,7 @@
 import torch
 import torch.nn as nn
 from onmt.modules.position_ffn import PositionwiseFeedForward
+from torch.distributed import all_reduce


 class MoE(nn.Module):
@@ -40,12 +41,15 @@ def __init__(
         )
         self.gate = nn.Linear(d_model, num_experts, bias=False)
         self.num_experts_per_tok = num_experts_per_tok
+        self.parallel_gpu = parallel_gpu

     def forward(self, x):
         orig_shape = x.shape
         x = x.view(-1, x.shape[-1])

         scores = self.gate(x)
+        if self.parallel_gpu > 1:
+            all_reduce(scores)
         expert_weights, expert_indices = torch.topk(
             scores, self.num_experts_per_tok, dim=-1
         )
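Under tensor parallelism the gate scores are now summed across ranks with `all_reduce` before the top-k expert selection, so every rank routes tokens identically. A minimal sketch of the same pattern, assuming an already-initialized process group; the class name and `world_size` plumbing are illustrative:

```python
import torch
import torch.nn as nn
import torch.distributed as dist


class ShardedTopKGate(nn.Module):
    """Toy top-k gate whose (possibly partial) scores are summed across ranks."""

    def __init__(self, d_model, num_experts, num_experts_per_tok, world_size=1):
        super().__init__()
        self.gate = nn.Linear(d_model, num_experts, bias=False)
        self.num_experts_per_tok = num_experts_per_tok
        self.world_size = world_size

    def forward(self, x):
        scores = self.gate(x.view(-1, x.size(-1)))
        if self.world_size > 1:
            # sum the per-rank scores so the routing decision is the same everywhere
            dist.all_reduce(scores)
        weights, indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)
        return weights.softmax(dim=-1), indices


gate = ShardedTopKGate(d_model=16, num_experts=8, num_experts_per_tok=2)
w, idx = gate(torch.randn(4, 5, 16))
print(w.shape, idx.shape)  # (20, 2) twice: one expert pair per token
```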
diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py
index 8306e524ea..d71ffbc460 100644
--- a/onmt/modules/multi_headed_attn.py
+++ b/onmt/modules/multi_headed_attn.py
@@ -11,7 +11,6 @@
 from torch.distributed import all_reduce
 from importlib import import_module

-
 # Help functions for Rotary Embeddings
 # https://arxiv.org/pdf/2104.09864.pdf
 # too convoluted to make maxseqlen a parameter.
@@ -258,6 +257,7 @@ def __init__(
         max_relative_positions: int = 0,
         relative_positions_buckets: int = 0,
         rotary_interleave: bool = True,
+        rotary_theta: int = 1e4,
         attn_type: str = None,
         self_attn_type: str = None,
         add_qkvbias=False,
@@ -352,9 +352,19 @@ def __init__(
             self.relative_attention_bias = None

         if max_relative_positions == -1:  # rotary embeddings
-            self.rope = rotaryembeddings(self.dim_per_head)
+            self.rope = rotaryembeddings(self.dim_per_head, base=rotary_theta)
+            self.cos = (
+                self.rope[:, : self.rope.size(1) // 2].real.contiguous().half()
+            )
+            self.sin = (
+                self.rope[:, : self.rope.size(1) // 2].imag.contiguous().half()
+            )
             self.rotary_interleave = rotary_interleave
-
+            self.rotary_theta = rotary_theta
+        else:
+            self.cos = None
+            self.sin = None
+            self.rotary_interleave = None
         if max_relative_positions == -2:  # alibi positional bias
             self.alibi = AlibiPositionalBias(head_count)
@@ -367,6 +377,9 @@ def __init__(
                 and torch.cuda.get_device_capability()[0] >= 8
             ):
                 self.flash_attn_func = getattr(flash_pack, "flash_attn_func")
+                self.flash_attn_with_kvcache = getattr(
+                    flash_pack, "flash_attn_with_kvcache"
+                )
                 self.flash2 = True
         except ImportError:
             self.flash2 = False
@@ -420,27 +433,104 @@ def forward(
             key = shape(key, self.dim_per_head)
             value = shape(value, self.dim_per_head)

-            if self.max_relative_positions == -1:  # Rotary Embeddings
-                start_pos = step
-                seqlen = query.size(2)
-                if seqlen > self.rope.size(0):
-                    self.rope = rotaryembeddings(
-                        self.dim_per_head, maxseqlen=(seqlen + 2048)
-                    ).to(self.rope.device)
-                rope = self.rope[start_pos : start_pos + seqlen]
-                query, key = apply_rotary_emb(
-                    query, key, rope, interleave=self.rotary_interleave
-                )
+            start_pos = step
+            seqlen = query.size(2)
+
+            if (
+                step == 0
+                or not self.flash2
+                or self.max_relative_positions not in [0, -1]
+                or query.size(0) > 128
+                or query.dtype != torch.float16
+            ):
+                if self.max_relative_positions == -1:  # Rotary Embeddings
+                    if seqlen > self.rope.size(0):
+                        self.rope = rotaryembeddings(
+                            self.dim_per_head,
+                            maxseqlen=(seqlen + 2048),
+                            base=self.rotary_theta,
+                        ).to(self.rope.device)
+                    rope = self.rope[start_pos : start_pos + seqlen]
+                    query, key = apply_rotary_emb(
+                        query, key, rope, interleave=self.rotary_interleave
+                    )
+
+                if self.layer_cache[1]["keys"].numel() != 0:
+                    key = torch.cat((self.layer_cache[1]["keys"], key), dim=2)
+                    value = torch.cat((self.layer_cache[1]["values"], value), dim=2)
+                    if sliding_window > 0 and key.size(2) > sliding_window:
+                        key = key[:, :, 1:, :]
+                        value = value[:, :, 1:, :]
+
+                self.layer_cache[1]["keys"] = key
+                self.layer_cache[1]["values"] = value

-            if self.layer_cache[1]["keys"].numel() != 0:
-                key = torch.cat((self.layer_cache[1]["keys"], key), dim=2)
-                value = torch.cat((self.layer_cache[1]["values"], value), dim=2)
+            else:
+                if self.max_relative_positions == -1:  # Rotary Embeddings
+                    if seqlen > self.rope.size(0):
+                        self.rope = rotaryembeddings(
+                            self.dim_per_head,
+                            maxseqlen=(seqlen + 2048),
+                            base=self.rotary_theta,
+                        ).to(self.rope.device)
+                        self.cos = (
+                            self.rope[:, : self.rope.size(1) // 2]
+                            .real.contiguous()
+                            .half()
+                        )
+                        self.sin = (
+                            self.rope[:, : self.rope.size(1) // 2]
+                            .imag.contiguous()
+                            .half()
+                        )
+                if start_pos >= self.layer_cache[1]["keys"].size(2):
+                    self.layer_cache[1]["keys"] = torch.cat(
+                        [
+                            self.layer_cache[1]["keys"],
+                            torch.zeros(
+                                self.layer_cache[1]["keys"].shape[:-2]
+                                + (32,)
+                                + self.layer_cache[1]["keys"].shape[-1:],
+                                device=query.device,
+                            ).half(),
+                        ],
+                        dim=-2,
+                    )
+                    self.layer_cache[1]["values"] = torch.cat(
+                        [
+                            self.layer_cache[1]["values"],
+                            torch.zeros(
+                                self.layer_cache[1]["values"].shape[:-2]
+                                + (32,)
+                                + self.layer_cache[1]["values"].shape[-1:],
+                                device=query.device,
+                            ).half(),
+                        ],
+                        dim=-2,
+                    )
                 if sliding_window > 0 and key.size(2) > sliding_window:
-                    key = key[:, :, 1:, :]
-                    value = value[:, :, 1:, :]
+                    self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][
+                        :, :, 1:, :
+                    ]
+                    self.layer_cache[1]["values"] = self.layer_cache[1]["values"][
+                        :, :, 1:, :
+                    ]
+                context = self.flash_attn_with_kvcache(
+                    query.transpose(1, 2),
+                    self.layer_cache[1]["keys"].transpose(1, 2),
+                    self.layer_cache[1]["values"].transpose(1, 2),
+                    key.transpose(1, 2),
+                    value.transpose(1, 2),
+                    rotary_cos=self.cos,
+                    rotary_sin=self.sin,
+                    cache_seqlens=step,
+                    rotary_interleaved=self.rotary_interleave,
+                ).transpose(1, 2)
+                attn_output = self.final_linear(unshape(context))
+                if self.parallel_gpu > 1:
+                    all_reduce(attn_output)
+                return attn_output, None

-            self.layer_cache[1]["keys"] = key
-            self.layer_cache[1]["values"] = value
         elif self.attn_type == "context":
             query = self.linear_query(query)
             query = shape(query, self.dim_per_head)
@@ -484,7 +574,9 @@ def forward(
                 seqlen = query.size(2)
                 if seqlen > self.rope.size(0):
                     self.rope = rotaryembeddings(
-                        self.dim_per_head, maxseqlen=(seqlen + 2048)
+                        self.dim_per_head,
+                        maxseqlen=(seqlen + 2048),
+                        base=self.rotary_theta,
                     ).to(self.rope.device)
                 rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
                 query, key = apply_rotary_emb(
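For incremental decoding (step > 0, fp16, self-attention with absolute or rotary positions, batch of at most 128) the layer now bypasses the Python-side `torch.cat` bookkeeping and calls flash-attn's `flash_attn_with_kvcache`: the KV cache is preallocated and grown in blocks of 32 positions, the new key/value are written into it by the kernel, and rotary embeddings are applied in-kernel from the precomputed fp16 `cos`/`sin` tables (the real and imaginary halves of the complex rope table). A self-contained sketch of such a call, assuming flash-attn >= 2.3 on an Ampere-or-newer GPU; the sizes and the rotary table construction are illustrative:

```python
import torch
from flash_attn import flash_attn_with_kvcache  # requires flash-attn >= 2.3

batch, heads, dim_per_head, max_cache = 2, 8, 64, 256
device = "cuda"

# preallocated KV cache, laid out (batch, seqlen, heads, head_dim) as the kernel expects
k_cache = torch.zeros(
    batch, max_cache, heads, dim_per_head, device=device, dtype=torch.float16
)
v_cache = torch.zeros_like(k_cache)

step = 5  # number of tokens already written to the cache
q = torch.randn(batch, 1, heads, dim_per_head, device=device, dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# rotary tables of shape (maxseqlen, dim_per_head // 2), half precision
theta = 1e4
inv_freq = 1.0 / (
    theta ** (torch.arange(0, dim_per_head, 2, device=device).float() / dim_per_head)
)
angles = torch.outer(torch.arange(max_cache, device=device).float(), inv_freq)
cos, sin = angles.cos().half(), angles.sin().half()

# the kernel appends k/v at position `cache_seqlens` and applies RoPE to q/k itself,
# so no separate apply_rotary_emb pass or cache concatenation is needed
out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k,
    v,
    rotary_cos=cos,
    rotary_sin=sin,
    cache_seqlens=step,
    rotary_interleaved=True,
)
print(out.shape)  # (batch, 1, heads, dim_per_head)
```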
diff --git a/onmt/modules/position_ffn.py b/onmt/modules/position_ffn.py
index 7ef207228d..4fa85a6bb9 100644
--- a/onmt/modules/position_ffn.py
+++ b/onmt/modules/position_ffn.py
@@ -3,11 +3,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
-
-try:
-    from apex.normalization import FusedRMSNorm as RMSNorm
-except ImportError:
-    from onmt.modules.rmsnorm import RMSNorm
+from onmt.modules.rmsnorm import RMSNorm
 from torch.nn.utils import skip_init
 from torch.distributed import all_reduce

diff --git a/onmt/modules/rmsnorm.py b/onmt/modules/rmsnorm.py
index 50d05529a4..3d8515aa5f 100644
--- a/onmt/modules/rmsnorm.py
+++ b/onmt/modules/rmsnorm.py
@@ -3,11 +3,19 @@
 import torch
 import torch.nn as nn

+try:
+    import awq_inference_engine
+
+    AWQ_INFERENCE_ENGINE = True
+except ImportError:
+    AWQ_INFERENCE_ENGINE = False
+

 class RMSNorm(torch.nn.Module):
     """RMSNorm: https://arxiv.org/abs/1910.07467

     Args:
-        hidden_size (int): layer hidden_sizeension.
+        hidden_size (int): layer hidden_size dimension.
+        eps: variance epsilon.
     """

     def __init__(self, hidden_size: int, eps: float = 1e-6):
@@ -16,8 +24,19 @@ def __init__(self, hidden_size: int, eps: float = 1e-6):
         self.weight = nn.Parameter(torch.ones(hidden_size))

     def forward(self, hidden_states):
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-        hidden_states = hidden_states.to(self.weight.dtype)
-        return hidden_states * self.weight
+        if AWQ_INFERENCE_ENGINE:
+            output = torch.empty_like(hidden_states)
+            if hidden_states.dim() == 2:  # patch for multi experts
+                hidden_states = hidden_states.unsqueeze(0)
+            awq_inference_engine.layernorm_forward_cuda(
+                hidden_states, self.weight, output, self.eps
+            )
+            if hidden_states.dim() == 2:  # patch for multi experts
+                output = output.unsqueeze(0)
+            return output
+        else:
+            hidden_states = hidden_states.to(torch.float32)
+            variance = hidden_states.pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+            hidden_states = hidden_states.to(self.weight.dtype)
+            return hidden_states * self.weight
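When the `awq_inference_engine` extension is importable, RMSNorm dispatches to its fused CUDA kernel, first unsqueezing 2-D inputs (the per-expert activations in the MoE path) to 3-D; otherwise it falls back to the reference float32 computation. A toy check of that reference formula; tensor sizes are arbitrary:

```python
import torch
import torch.nn as nn


def rmsnorm_reference(x, weight, eps=1e-6):
    # upcast, divide by the root-mean-square, then rescale with the learned gain
    x32 = x.to(torch.float32)
    variance = x32.pow(2).mean(-1, keepdim=True)
    return (x32 * torch.rsqrt(variance + eps)).to(weight.dtype) * weight


weight = nn.Parameter(torch.ones(8))
x = torch.randn(2, 3, 8)
y = rmsnorm_reference(x, weight)
print(y.pow(2).mean(-1).sqrt())  # ~1.0 everywhere: each position is RMS-normalized
```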
""" def __init__(self, hidden_size: int, eps: float = 1e-6): @@ -16,8 +24,19 @@ def __init__(self, hidden_size: int, eps: float = 1e-6): self.weight = nn.Parameter(torch.ones(hidden_size)) def forward(self, hidden_states): - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - hidden_states = hidden_states.to(self.weight.dtype) - return hidden_states * self.weight + if AWQ_INFERENCE_ENGINE: + output = torch.empty_like(hidden_states) + if hidden_states.dim() == 2: # patch for multi experts + hidden_states = hidden_states.unsqueeze(0) + awq_inference_engine.layernorm_forward_cuda( + hidden_states, self.weight, output, self.eps + ) + if hidden_states.dim() == 2: # patch for multi experts + output = output.unsqueeze(0) + return output + else: + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + hidden_states = hidden_states.to(self.weight.dtype) + return hidden_states * self.weight diff --git a/onmt/opts.py b/onmt/opts.py index 1fa3305fe0..c0237f377f 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -880,6 +880,13 @@ def model_opts(parser): "True = default Llama from Meta (original)" "False = used by all Hugging face models", ) + group.add( + "--rotary_theta", + "-rotary_theta", + type=int, + default=10000, + help="Rotary theta base length" "1e4 for Llama2.Mistral" "1e6 for Mixtral", + ) group.add( "--heads", "-heads", diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index 45cefbd17b..c576d7cda9 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -1102,6 +1102,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True): ) # (4) Begin decoding step by step: + beg_time = time() for step in range(decode_strategy.max_length): decoder_input = ( src if step == 0 else decode_strategy.current_predictions.view(-1, 1, 1) @@ -1139,6 +1140,8 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True): if parallel_paths > 1 or any_finished: # select indexes in model state/cache self.model.decoder.map_state(lambda state, dim: state[select_indices]) + if step == 0: + print("step0 time: ", time() - beg_time) return self.report_results( gold_score, diff --git a/onmt/utils/distributed.py b/onmt/utils/distributed.py index e6779c397f..cb0f55c4f9 100644 --- a/onmt/utils/distributed.py +++ b/onmt/utils/distributed.py @@ -212,7 +212,7 @@ def spawned_infer(opt, device_id, error_queue, queue_instruct, queue_result): device_id=device_id, ) scores, preds = translator._translate( - infer_iter, infer_iter.transform, opt.attn_debug, opt.align_debug + infer_iter, infer_iter.transforms, opt.attn_debug, opt.align_debug ) queue_result.put(scores) queue_result.put(preds) diff --git a/tools/convert_HF_llamalike.py b/tools/convert_HF_llamalike.py index 8fe9d4efc6..32ace36aee 100755 --- a/tools/convert_HF_llamalike.py +++ b/tools/convert_HF_llamalike.py @@ -223,8 +223,14 @@ def __init__(self, model_path: str): norm_eps = config["layer_norm_epsilon"] else: norm_eps = 1e-6 + if "rope_theta" in config.keys(): + rope_theta = config["rope_theta"] + else: + rope_theta = 1e4 if "sliding_window" in config.keys(): sliding_window = config["sliding_window"] + if sliding_window is None: + sliding_window = 4096 else: sliding_window = 0 @@ -633,6 +639,7 @@ def get_weight(checkpoint, tensor_name): 
self_attn_type="scaled-dot", max_relative_positions=-1, rotary_interleave=False, + rotary_theta=rope_theta, heads=heads, sliding_window=sliding_window, transformer_ff=transformer_ff,