diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py
index 345338f..c035e9e 100755
--- a/whisper_timestamped/transcribe.py
+++ b/whisper_timestamped/transcribe.py
@@ -229,6 +229,10 @@ def transcribe_timestamped(
     if fp16 is None:
         fp16 = model.device != torch.device("cpu")
 
+    # TODO: implement efficient approach with transformers
+    if is_transformer_model(model):
+        naive_approach = True
+
     # Safety check
     input_stride = N_FRAMES // model.dims.n_audio_ctx
     time_precision = input_stride * HOP_LENGTH / SAMPLE_RATE
@@ -290,7 +294,8 @@ def transcribe_timestamped(
         transcription, words = remove_last_null_duration_words(transcription, words, recompute_text=True)
 
     # Refine word positions
-    ensure_increasing_positions(words, min_duration=min_word_duration if trust_whisper_timestamps else 0)
+    if not is_transformer_model(model):
+        ensure_increasing_positions(words, min_duration=min_word_duration if trust_whisper_timestamps else 0)
 
     # Combine words and segments
     whisper_segments = transcription["segments"]
@@ -1035,7 +1040,7 @@ def hook_output_logits(layer, ins, outs):
 
     n_mels = model.dims.n_mels if hasattr(model.dims, "n_mels") else 80
 
-    attention_weights = [[] for _ in range(min(word_alignement_most_top_layers,len(model.decoder.blocks)))]
+    attention_weights = [[] for _ in range(min(word_alignement_most_top_layers, len(model.decoder.blocks)))]
 
     try:
 
@@ -1047,9 +1052,15 @@ def hook_output_logits(layer, ins, outs):
         for i, block in enumerate(model.decoder.blocks):
             if i < nblocks - word_alignement_most_top_layers:
                 continue
+            def hook(layer, ins, outs, index=j):
+                if is_transformer_model(model):
+                    attention_weights[index] = outs[1].log()
+                else:
+                    attention_weights[index] = outs[1]
             all_hooks.append(
                 block.cross_attn.register_forward_hook(
-                    lambda layer, ins, outs, index=j: attention_weights.__setitem__(index, outs[-1])
+                    hook
+                    # lambda layer, ins, outs, index=j: attention_weights.__setitem__(index, outs[1])
                 )
             )
             j += 1
@@ -1159,12 +1170,20 @@ def hook_output_logits(layer, ins, outs):
             last_token_check = tokens[-1]
             tokens = tokens[:-1]
 
+        sot_sequence = tokenizer.sot_sequence
+        if language:
+            assert len(sot_sequence) == 3
+            sot_sequence = (
+                sot_sequence[0],
+                tokenizer.to_language_token(language),
+                sot_sequence[2],
+            )
         tokens = [
-            *tokenizer.sot_sequence,
+            *sot_sequence,
             tokenizer.timestamp_begin,
         ] + tokens
 
-        i_start = len(tokenizer.sot_sequence)
+        i_start = len(sot_sequence)
 
         with torch.no_grad():
             logprobs = model(mfcc, torch.Tensor(tokens).int().to(model.device).unsqueeze(0))
@@ -1234,8 +1253,10 @@ def hook_output_logits(layer, ins, outs):
             segment_tokens_check.append(last_token_check)
         if trust_whisper_timestamps:
             if segment_tokens_check != segment["tokens"]:
-                assert len(segment_tokens_check) < len(segment["tokens"]) and segment_tokens_check[:-1] == segment["tokens"][:len(segment_tokens_check)-1], \
-                    f"Got inconsistent tokens: {tokenizer.decode(segment_tokens_check)} != {tokenizer.decode(segment['tokens'])}"
+                assert len(segment_tokens_check) < len(segment["tokens"]), \
+                    f"First should be longer by one token: '{tokenizer.decode_with_timestamps(segment_tokens_check)}' should include '{tokenizer.decode_with_timestamps(segment['tokens'])}'"
+                assert segment_tokens_check[:-1] == segment["tokens"][:len(segment_tokens_check)-1], \
+                    f"Got inconsistent tokens: {tokenizer.decode_with_timestamps(segment_tokens_check)} != {tokenizer.decode_with_timestamps(segment['tokens'])}"
                 segment["tokens"] = segment_tokens_check
                 segment["text"] = tokenizer.decode(segment["tokens"])
             # else: TODO
@@ -1293,6 +1314,10 @@ def print_timestamped(w):
 
 
 def get_logit_filters(model, whisper_options, prompt = None):
+    if is_transformer_model(model):
+        # import transformers
+        # transformers.WhisperTimeStampLogitsProcessor
+        raise NotImplementedError("TODO")
     decoding_options = get_decoding_options(whisper_options)
     if "initial_prompt" in decoding_options:
         prompt0 = decoding_options.pop("initial_prompt")
@@ -1324,6 +1349,15 @@ def get_decoding_options(whisper_options):
     ])
 
 def get_tokenizer(model, task="transcribe", language="en"):
+    if is_transformer_model(model):
+        tokenizer = model.tokenizer
+        tokenizer.sot_sequence = (
+            tokenizer.sot,
+            tokenizer.to_language_token(language or "en"),
+            tokenizer.to_task_token(task),
+        )
+        tokenizer.sot_sequence
+        return model.tokenizer
     try:
         return whisper.tokenizer.get_tokenizer(
             model.is_multilingual,
@@ -2260,7 +2294,7 @@ def _get_alignment_heads(model_name, num_layers, num_heads):
 def _get_number_of_parameters(model):
     num_parameters = 0
     for name, p in model.named_parameters():
-        if name in ["decoder.proj_out.weight"]:
+        if name in ["decoder.proj_out.weight", "model.encoder.embed_positions.weight"]:
             continue
         num_parameters += p.numel()
     return num_parameters
@@ -2271,7 +2305,26 @@ def load_model(
     device: Optional[Union[str, torch.device]] = None,
     download_root: str = None,
     in_memory: bool = False,
+    backend: str = "openai-whisper",
 ):
+    if backend == "transformers":
+        try:
+            import transformers
+        except ImportError:
+            raise ImportError("If you want to use the transformers backend, please install the transformers library first")
+        if name in whisper.available_models():
+            name = f"openai/whisper-{name}"
+        # TODO: use download_root
+        # TODO: does in_memory make sense?
+        generation_config = transformers.GenerationConfig.from_pretrained(name)
+        processor = transformers.WhisperProcessor.from_pretrained(name)
+        model = transformers.WhisperForConditionalGeneration.from_pretrained(name)
+        model = model.to(device)
+        return TransformerWhisperAsOpenAIWhisper(model, processor, generation_config)
+
+    elif backend not in ["openai", "openai-whisper"]:
+        raise ValueError(f"Got unexpected backend {backend}")
+
     extension = os.path.splitext(name)[-1] if os.path.isfile(name) else None
 
     if name in whisper.available_models() or extension == ".pt":
@@ -2359,7 +2412,190 @@ def torch_load(model_path):
     hf_state_dict = torch.load(model_path, map_location="cpu")
     return hf_state_dict
 
 
+# Some helpers to manage transformers/openai-whisper model
+
+class TransformerWhisperAsOpenAIWhisper:
+    """
+    Wrapper to use a transformers model as a whisper model (at least in whisper-timestamped)
+    """
+
+    def __init__(self, model, processor, generation_config):
+
+        self.model = model # transformers.WhisperForConditionalGeneration
+        self.processor = processor # transformers.WhisperProcessor
+        self.generation_config = generation_config # transformers.GenerationConfig
+        self.device = model.device
+
+        # Dimensions
+        self.dims = whisper.model.ModelDimensions(
+            n_mels = model.get_encoder().get_input_embeddings().in_channels,
+            n_audio_ctx = 1500,
+            n_audio_state = model.get_encoder().get_input_embeddings().out_channels,
+            n_audio_head = model.get_encoder().layers[0].self_attn.num_heads,
+            n_audio_layer = len(model.get_encoder().layers),
+            n_vocab = model.get_decoder().get_input_embeddings().num_embeddings,
+            n_text_ctx = 448,
+            n_text_state = model.get_decoder().get_input_embeddings().embedding_dim,
+            n_text_head = model.get_decoder().layers[0].self_attn.num_heads,
+            n_text_layer = len(model.get_decoder().layers),
+        )
+
+        # Tokenization
+        self.tokenizer = processor.tokenizer
+        (
+            self.tokenizer.sot,
+            self.tokenizer.eot,
+            self.tokenizer.timestamp_begin,
+            self.tokenizer.no_speech,
+        ) = self.tokenizer.convert_tokens_to_ids([
+            "<|startoftranscript|>",
+            "<|endoftext|>",
+            "<|0.00|>",
+            "<|nospeech|>",
+        ])
+        self.tokenizer.all_language_tokens = self.tokenizer.convert_tokens_to_ids([
+            t for t in self.tokenizer.additional_special_tokens if len(t) in [6,7]
+        ])
+        self.tokenizer.to_language_token = lambda language: self.generation_config.lang_to_id["<|" + language + "|>"]
+        self.tokenizer.to_task_token = lambda task: self.generation_config.task_to_id[task]
+        self.tokenizer.to_timestamp_token = lambda t: self.tokenizer.encode(f"<|{t:0.2f}|>", add_special_tokens=False)[0]
+        self.tokenizer.decode_with_timestamps = lambda tokens: self.tokenizer.decode(tokens, decode_with_timestamps=True)
+
+        # Access to layers (renamed attributes)
+        self.decoder = self.model.get_decoder()
+        self.decoder.ln = self.decoder.layer_norm
+        self.decoder.token_embedding = self.decoder.embed_tokens
+        self.decoder.blocks = self.decoder.layers
+        for block in self.decoder.blocks:
+            block.cross_attn = block.encoder_attn
+
+        # From the config
+        self.is_multilingual = generation_config.is_multilingual # (self.tokenizer.sot != 50257)
+
+        # Alignment heads
+        if hasattr(generation_config, "alignment_heads"):
+            a = generation_config.alignment_heads
+            self.alignment_heads = torch.sparse_coo_tensor(np.array(a).transpose(), [True]*len(a)).coalesce().to(self.device)
+
+    def named_parameters(self):
+        return self.model.named_parameters()
+
+    def transcribe(self, audio, **kwargs):
+
+        features = self.processor(
+            audio,
+            return_tensors="pt",
+            sampling_rate=16_000,
+            truncation=False,
+        ).input_features.to(self.device)
+
+        # TODO: double check that this setup is correct
+        generation_config = self.generation_config
+        generation_config.num_beams = kwargs.get("beam_size", None)
+        temperature = kwargs.get("temperature", 0.0)
+        if isinstance(temperature, (list, tuple)):
+            # Not supported with transformers
+            temperature = min(temperature)
+        if temperature != 0.0:
+            generation_config.do_sample = True
+            generation_config.temperature = temperature
+            generation_config.top_k = kwargs.get("best_of", None)
+
+        output = self.model.generate(
+            features,
+            return_dict_in_generate = True,
+            return_segments = True,
+            return_timestamps = True,
+            return_token_timestamps = False, # Note: token timestamps computed by transformers would be a concurrent approach (not used here)
+            max_length = self.dims.n_text_ctx,
+            is_multilingual = self.is_multilingual,
+            task = kwargs.get("task", "transcribe"),
+            language = kwargs.get("language"),
+            prompt_ids = kwargs.get("initial_prompt"),
+            generation_config = generation_config,
+        )
+
+        output_dict = {}
+
+        if "segments" not in output:
+            # WTF: the output format is different when there is only one segment (e.g. audio duration < 30 seconds)
+            tokens = output.sequences[0]
+            output = {
+                "segments": [[{
+                    "tokens": tokens[1:],
+                    "start": torch.tensor(0.0),
+                    "result": {
+                        "sequences": output.sequences,
+                        "past_key_values": output.past_key_values,
+                    }
+                }]]
+            }
+
+        full_text = ""
+        segments = []
+        id = -1
+        for sub_segments in output["segments"]:
+            for segment in sub_segments:
+                id += 1
+                tokens = segment["tokens"]
+                text = self.tokenizer.decode(tokens, skip_special_tokens=True)
+                token_timestamps = [(i, t.item()) for i, t in enumerate(tokens) if t >= self.tokenizer.timestamp_begin]
+                if len(token_timestamps):
+                    assert len(token_timestamps) == 2, f"Got unexpected number of timestamps: {token_timestamps}"
+                    i_start, token_start = token_timestamps[0]
+                    i_end, token_end = token_timestamps[1]
+                    tokens = tokens[i_start+1:i_end]
+                    relative_start = (token_start - self.tokenizer.timestamp_begin) * AUDIO_TIME_PER_TOKEN
+                    relative_end = (token_end - self.tokenizer.timestamp_begin) * AUDIO_TIME_PER_TOKEN
+                    start = max(0, segment["start"].item())
+                    relative_start = round(relative_start, 2)
+                    start = round(start, 2)
+                    end = round(start + relative_end - relative_start, 2)
+                    offset = start - relative_start
+                    assert offset >= 0, f"Got negative offset ({offset}) with {start=} and {relative_start=}"
+                else:
+                    raise NotImplementedError("Unexpected situation: no timestamp found")
+                    # start = max(0, segment["start"].item())
+                    # end = segment["end"].item()
+                    # offset = 0
+                    # assert end >= start, f"Got end < start ({end} < {start})"
+                    # token_start = self.tokenizer.to_timestamp_token(start)
+                    # token_end = self.tokenizer.to_timestamp_token(end)
+
+                # Accumulate
+                segments.append({
+                    "id": id,
+                    "seek": round(offset * SAMPLE_RATE / HOP_LENGTH),
+                    "start": start,
+                    "end": end,
+                    "text": text,
+                    "tokens": [token_start] + tokens.tolist() + [token_end],
+                    "temperature": temperature,
+                    # "avg_logprob": -0.6982866287231445,
+                    # "compression_ratio": 0.5294117647058824,
+                    # "no_speech_prob": 0.019023602828383446
+                })
+                full_text += text
+
+        output_dict = {
+            "text": full_text,
+            "segments": segments,
+        }
+        if not kwargs.get("language"):
+            language_detected = self.tokenizer.decode(output["segments"][0][0]["tokens"][0])
+            assert len(language_detected) in [6,7], f"Unexpected language detected: {language_detected}"
+            language_detected = language_detected[2:-2]
+            output_dict["language"] = language_detected
+
+        return output_dict
+
+    def __call__(self, mfcc, tokens):
+        output = self.model(mfcc, decoder_input_ids=tokens, output_attentions=True)
+        return output.logits
+
+def is_transformer_model(model):
+    return isinstance(model, TransformerWhisperAsOpenAIWhisper)
 
 
 # Credit: https://github.com/openai/whisper/discussions/830
@@ -2500,6 +2736,7 @@ def get_do_write(output_format):
     parser.add_argument('--model', help=f"name of the Whisper model to use. Examples: {', '.join(whisper.available_models())}", default="small")
     parser.add_argument("--model_dir", default=None, help="the path to save model files; uses ~/.cache/whisper by default", type=str)
     parser.add_argument("--device", default=get_default_device(), help="device to use for PyTorch inference")
+    parser.add_argument("--backend", default="openai-whisper", help="which backend to use to run the Whisper model", choices=["openai-whisper", "transformers"], type=str)
     parser.add_argument("--output_dir", "-o", default=None, help="directory to save the outputs", type=str)
     valid_formats = ["txt", "vtt", "srt", "tsv", "csv", "json"]
     def str2output_formats(string):
@@ -2590,8 +2827,9 @@ def __call__(self, parser, namespace, values, option_string=None):
         force_cudnn_initialization(device)
 
     output_format = args.pop("output_format")
+    backend = args.pop("backend")
 
-    model = load_model(model, device=device, download_root=model_dir)
+    model = load_model(model, device=device, download_root=model_dir, backend=backend)
 
     plot_word_alignment = args.pop("plot")
 
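Usage sketch (not part of the patch): assuming the public API shown in the whisper_timestamped README (`load_audio`, `load_model`, `transcribe_timestamped`) and an installed `transformers` package, the new `backend` option added above could be exercised as follows; the audio file name is a placeholder.

    import whisper_timestamped as whisper_ts

    # Hypothetical input file
    audio = whisper_ts.load_audio("audio.wav")

    # backend="transformers" routes loading through the TransformerWhisperAsOpenAIWhisper wrapper
    model = whisper_ts.load_model("small", device="cpu", backend="transformers")

    result = whisper_ts.transcribe_timestamped(model, audio, language="en")
    print(result["text"])

The same choice is exposed on the command line by the new flag, e.g. `whisper_timestamped audio.wav --model small --backend transformers`.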