diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py index 6233257..266e64c 100755 --- a/whisper_timestamped/transcribe.py +++ b/whisper_timestamped/transcribe.py @@ -76,7 +76,7 @@ def transcribe_timestamped( refine_whisper_precision=0.5, min_word_duration=0.02, # Was 0.04 before 1.11 plot_word_alignment=False, - word_alignement_most_top_layers=None, # Was 6 before 1.9 + word_alignment_most_top_layers=None, # Was 6 before 1.9 remove_empty_words=False, use_backend_timestamps=False, @@ -214,7 +214,7 @@ def transcribe_timestamped( assert refine_whisper_precision >= 0 and refine_whisper_precision / AUDIO_TIME_PER_TOKEN == round(refine_whisper_precision / AUDIO_TIME_PER_TOKEN), f"refine_whisper_precision must be a positive multiple of {AUDIO_TIME_PER_TOKEN}" refine_whisper_precision_nframes = round(refine_whisper_precision / AUDIO_TIME_PER_TOKEN) assert min_word_duration >= 0, f"min_word_duration must be a positive number" - assert word_alignement_most_top_layers is None or word_alignement_most_top_layers > 0, f"word_alignement_most_top_layers must be a strictly positive number" + assert word_alignment_most_top_layers is None or word_alignment_most_top_layers > 0, f"word_alignment_most_top_layers must be a strictly positive number" if isinstance(temperature, (list, tuple)) and len(temperature) == 1: temperature = temperature[0] @@ -242,9 +242,9 @@ def transcribe_timestamped( time_precision = input_stride * HOP_LENGTH / SAMPLE_RATE assert time_precision == AUDIO_TIME_PER_TOKEN - alignment_heads = get_alignment_heads(model) if word_alignement_most_top_layers is None else None - if alignment_heads is None and word_alignement_most_top_layers is None: - word_alignement_most_top_layers = 6 + alignment_heads = get_alignment_heads(model) if word_alignment_most_top_layers is None else None + if alignment_heads is None and word_alignment_most_top_layers is None: + word_alignment_most_top_layers = 6 alignment_options = dict( remove_punctuation_from_words=remove_punctuation_from_words, @@ -253,7 +253,7 @@ def transcribe_timestamped( detect_disfluencies=detect_disfluencies, refine_whisper_precision_nframes=refine_whisper_precision_nframes, plot_word_alignment=plot_word_alignment, - word_alignement_most_top_layers=word_alignement_most_top_layers, + word_alignment_most_top_layers=word_alignment_most_top_layers, alignment_heads=alignment_heads, ) whisper_options = dict( @@ -351,7 +351,7 @@ def _transcribe_timestamped_efficient( refine_whisper_precision_nframes, alignment_heads, plot_word_alignment, - word_alignement_most_top_layers, + word_alignment_most_top_layers, detect_disfluencies, trust_whisper_timestamps, use_timestamps_for_alignment = True, @@ -378,13 +378,13 @@ def _transcribe_timestamped_efficient( debug = logger.getEffectiveLevel() >= logging.DEBUG - word_alignement_most_top_layers = float("inf") if word_alignement_most_top_layers is None else word_alignement_most_top_layers + word_alignment_most_top_layers = float("inf") if word_alignment_most_top_layers is None else word_alignment_most_top_layers # The main outcome timestamped_word_segments = [] # list of timestamped word segments that have been collected so far # Main variables to be accumulated segment_tokens = [[]] # list of lists of token indices that have been collected so far (one list per segment) - segment_attweights = [[] for _ in range(min(word_alignement_most_top_layers, len(model.decoder.blocks)))] + segment_attweights = [[] for _ in range(min(word_alignment_most_top_layers, len(model.decoder.blocks)))] # attention weights on the last segments segment_avglogprobs = [] # average log probability for each segment (actually of the corresponding chunk, as computed by whisper) segment_logprobs = [] # token log probabilities for each segment @@ -875,7 +875,7 @@ def hook_output_logits(layer, ins, outs): nblocks = len(model.decoder.blocks) j = 0 for i, block in enumerate(model.decoder.blocks): - if i < nblocks - word_alignement_most_top_layers: + if i < nblocks - word_alignment_most_top_layers: continue all_hooks.append( block.cross_attn.register_forward_hook( @@ -995,7 +995,7 @@ def _transcribe_timestamped_naive( use_backend_timestamps, alignment_heads, plot_word_alignment, - word_alignement_most_top_layers, + word_alignment_most_top_layers, detect_disfluencies, trust_whisper_timestamps, min_word_duration, @@ -1006,7 +1006,7 @@ def _transcribe_timestamped_naive( language = whisper_options["language"] refine_whisper_precision_sec = refine_whisper_precision_nframes * AUDIO_TIME_PER_TOKEN - word_alignement_most_top_layers = float("inf") if word_alignement_most_top_layers is None else word_alignement_most_top_layers + word_alignment_most_top_layers = float("inf") if word_alignment_most_top_layers is None else word_alignment_most_top_layers audio = get_audio_tensor(audio) audio_duration = audio.shape[-1] / SAMPLE_RATE @@ -1077,7 +1077,7 @@ def hook_output_logits(layer, ins, outs): n_mels = model.dims.n_mels if hasattr(model.dims, "n_mels") else 80 - attention_weights = [[] for _ in range(min(word_alignement_most_top_layers, len(model.decoder.blocks)))] + attention_weights = [[] for _ in range(min(word_alignment_most_top_layers, len(model.decoder.blocks)))] try: @@ -1087,7 +1087,7 @@ def hook_output_logits(layer, ins, outs): nblocks = len(model.decoder.blocks) j = 0 for i, block in enumerate(model.decoder.blocks): - if i < nblocks - word_alignement_most_top_layers: + if i < nblocks - word_alignment_most_top_layers: continue def hook(layer, ins, outs, index=j): if is_transformer_model(model):