Commit
Enable transformers as a backend
Jeronymous committed Jan 26, 2024
1 parent 713626e commit ff8d897
Showing 1 changed file with 247 additions and 9 deletions.
256 changes: 247 additions & 9 deletions whisper_timestamped/transcribe.py
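Below is a minimal usage sketch of the new option, based on the load_model and transcribe_timestamped signatures touched in this diff (the package-level import and the audio file path are assumptions for illustration, not part of this commit):

import whisper_timestamped

# Load the Hugging Face transformers implementation of Whisper instead of openai-whisper
model = whisper_timestamped.load_model("small", device="cpu", backend="transformers")

# Transcribe as usual; the wrapper exposes the transformers model through the openai-whisper interface
result = whisper_timestamped.transcribe_timestamped(model, "audio.wav")
print(result["text"])

The same choice is exposed on the command line through the new --backend argument (default: openai-whisper).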
@@ -229,6 +229,10 @@ def transcribe_timestamped(
if fp16 is None:
fp16 = model.device != torch.device("cpu")

# TODO: implement efficient approach with transformers
if is_transformer_model(model):
naive_approach = True

# Safety check
input_stride = N_FRAMES // model.dims.n_audio_ctx
time_precision = input_stride * HOP_LENGTH / SAMPLE_RATE
@@ -290,7 +294,8 @@ def transcribe_timestamped(
transcription, words = remove_last_null_duration_words(transcription, words, recompute_text=True)

# Refine word positions
ensure_increasing_positions(words, min_duration=min_word_duration if trust_whisper_timestamps else 0)
if not is_transformer_model(model):
ensure_increasing_positions(words, min_duration=min_word_duration if trust_whisper_timestamps else 0)

# Combine words and segments
whisper_segments = transcription["segments"]
@@ -1035,7 +1040,7 @@ def hook_output_logits(layer, ins, outs):

n_mels = model.dims.n_mels if hasattr(model.dims, "n_mels") else 80

attention_weights = [[] for _ in range(min(word_alignement_most_top_layers,len(model.decoder.blocks)))]
attention_weights = [[] for _ in range(min(word_alignement_most_top_layers, len(model.decoder.blocks)))]

try:

@@ -1047,9 +1052,15 @@ def hook_output_logits(layer, ins, outs):
for i, block in enumerate(model.decoder.blocks):
if i < nblocks - word_alignement_most_top_layers:
continue
def hook(layer, ins, outs, index=j):
if is_transformer_model(model):
attention_weights[index] = outs[1].log()
else:
attention_weights[index] = outs[1]
all_hooks.append(
block.cross_attn.register_forward_hook(
lambda layer, ins, outs, index=j: attention_weights.__setitem__(index, outs[-1])
hook
# lambda layer, ins, outs, index=j: attention_weights.__setitem__(index, outs[1])
)
)
j += 1
@@ -1159,12 +1170,20 @@ def hook_output_logits(layer, ins, outs):
last_token_check = tokens[-1]
tokens = tokens[:-1]

sot_sequence = tokenizer.sot_sequence
if language:
assert len(sot_sequence) == 3
sot_sequence = (
sot_sequence[0],
tokenizer.to_language_token(language),
sot_sequence[2],
)
tokens = [
*tokenizer.sot_sequence,
*sot_sequence,
tokenizer.timestamp_begin,
] + tokens

i_start = len(tokenizer.sot_sequence)
i_start = len(sot_sequence)

with torch.no_grad():
logprobs = model(mfcc, torch.Tensor(tokens).int().to(model.device).unsqueeze(0))
@@ -1234,8 +1253,10 @@ def hook_output_logits(layer, ins, outs):
segment_tokens_check.append(last_token_check)
if trust_whisper_timestamps:
if segment_tokens_check != segment["tokens"]:
assert len(segment_tokens_check) < len(segment["tokens"]) and segment_tokens_check[:-1] == segment["tokens"][:len(segment_tokens_check)-1], \
f"Got inconsistent tokens: {tokenizer.decode(segment_tokens_check)} != {tokenizer.decode(segment['tokens'])}"
assert len(segment_tokens_check) < len(segment["tokens"]), \
f"First should be longer by one token: '{tokenizer.decode_with_timestamps(segment_tokens_check)}' should include '{tokenizer.decode_with_timestamps(segment['tokens'])}'"
assert segment_tokens_check[:-1] == segment["tokens"][:len(segment_tokens_check)-1], \
f"Got inconsistent tokens: {tokenizer.decode_with_timestamps(segment_tokens_check)} != {tokenizer.decode_with_timestamps(segment['tokens'])}"
segment["tokens"] = segment_tokens_check
segment["text"] = tokenizer.decode(segment["tokens"])
# else: TODO
@@ -1293,6 +1314,10 @@ def print_timestamped(w):


def get_logit_filters(model, whisper_options, prompt = None):
if is_transformer_model(model):
# import transformers
# transformers.WhisperTimeStampLogitsProcessor
raise NotImplementedError("TODO")
decoding_options = get_decoding_options(whisper_options)
if "initial_prompt" in decoding_options:
prompt0 = decoding_options.pop("initial_prompt")
@@ -1324,6 +1349,15 @@ def get_decoding_options(whisper_options):
])

def get_tokenizer(model, task="transcribe", language="en"):
if is_transformer_model(model):
tokenizer = model.tokenizer
tokenizer.sot_sequence = (
tokenizer.sot,
tokenizer.to_language_token(language or "en"),
tokenizer.to_task_token(task),
)
return model.tokenizer
try:
return whisper.tokenizer.get_tokenizer(
model.is_multilingual,
@@ -2260,7 +2294,7 @@ def _get_alignment_heads(model_name, num_layers, num_heads):
def _get_number_of_parameters(model):
num_parameters = 0
for name, p in model.named_parameters():
if name in ["decoder.proj_out.weight"]:
if name in ["decoder.proj_out.weight", "model.encoder.embed_positions.weight"]:
continue
num_parameters += p.numel()
return num_parameters
@@ -2271,7 +2305,26 @@ def load_model(
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
backend: str = "openai-whisper",
):
if backend == "transformers":
try:
import transformers
except ImportError:
raise ImportError(f"If you want to use transformers backend, please install first the transformers library")
if name in whisper.available_models():
name = f"openai/whisper-{name}"
# TODO: use download_root
# TODO: does in_memory make sense?
generation_config = transformers.GenerationConfig.from_pretrained(name)
processor = transformers.WhisperProcessor.from_pretrained(name)
model = transformers.WhisperForConditionalGeneration.from_pretrained(name)
model = model.to(device)
return TransformerWhisperAsOpenAIWhisper(model, processor, generation_config)

elif backend not in ["openai", "openai-whisper"]:
raise ValueError(f"Got unexpected backend {backend}")

extension = os.path.splitext(name)[-1] if os.path.isfile(name) else None

if name in whisper.available_models() or extension == ".pt":
@@ -2359,7 +2412,190 @@ def torch_load(model_path):
hf_state_dict = torch.load(model_path, map_location="cpu")
return hf_state_dict

# Helpers to use a transformers Whisper model in place of an openai-whisper model

class TransformerWhisperAsOpenAIWhisper:
"""
Wrapper to use a transformers model as a whisper model (at least in whisper-timestamped)
"""

def __init__(self, model, processor, generation_config):

self.model = model # transformers.WhisperForConditionalGeneration
self.processor = processor # transformers.WhisperProcessor
self.generation_config = generation_config # transformers.GenerationConfig
self.device = model.device

# Dimensions
self.dims = whisper.model.ModelDimensions(
n_mels = model.get_encoder().get_input_embeddings().in_channels,
n_audio_ctx = 1500,
n_audio_state = model.get_encoder().get_input_embeddings().out_channels,
n_audio_head = model.get_encoder().layers[0].self_attn.num_heads,
n_audio_layer = len(model.get_encoder().layers),
n_vocab = model.get_decoder().get_input_embeddings().num_embeddings,
n_text_ctx = 448,
n_text_state = model.get_decoder().get_input_embeddings().embedding_dim,
n_text_head = model.get_decoder().layers[0].self_attn.num_heads,
n_text_layer = len(model.get_decoder().layers),
)

# Tokenization
self.tokenizer = processor.tokenizer
(
self.tokenizer.sot,
self.tokenizer.eot,
self.tokenizer.timestamp_begin,
self.tokenizer.no_speech,
) = self.tokenizer.convert_tokens_to_ids([
"<|startoftranscript|>",
"<|endoftext|>",
"<|0.00|>",
"<|nospeech|>",
])
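# Language tokens (e.g. "<|en|>" or "<|yue|>") are the additional special tokens of length 6 or 7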
self.tokenizer.all_language_tokens = self.tokenizer.convert_tokens_to_ids([
t for t in self.tokenizer.additional_special_tokens if len(t) in [6,7]
])
self.tokenizer.to_language_token = lambda language: self.generation_config.lang_to_id["<|" + language + "|>"]
self.tokenizer.to_task_token = lambda task: self.generation_config.task_to_id[task]
self.tokenizer.to_timestamp_token = lambda t: self.tokenizer.encode(f"<|{t:0.2f}|>", add_special_tokens=False)[0]
self.tokenizer.decode_with_timestamps = lambda tokens: self.tokenizer.decode(tokens, decode_with_timestamps=True)

# Access to layers (renamed attributes)
self.decoder = self.model.get_decoder()
self.decoder.ln = self.decoder.layer_norm
self.decoder.token_embedding = self.decoder.embed_tokens
self.decoder.blocks = self.decoder.layers
for block in self.decoder.blocks:
block.cross_attn = block.encoder_attn

# From the config
self.is_multilingual = generation_config.is_multilingual # (self.tokenizer.sot != 50257)

# Alignment heads
if hasattr(generation_config, "alignment_heads"):
a = generation_config.alignment_heads
self.alignment_heads = torch.sparse_coo_tensor(np.array(a).transpose(), [True]*len(a)).coalesce().to(self.device)

def named_parameters(self):
return self.model.named_parameters()

def transcribe(self, audio, **kwargs):

features = self.processor(
audio,
return_tensors="pt",
sampling_rate=16_000,
truncation=False,
).input_features.to(self.device)

# TODO: double check that this setup is correct
generation_config = self.generation_config
generation_config.num_beams = kwargs.get("beam_size", None)
temperature = kwargs.get("temperature", 0.0)
if isinstance(temperature, (list, tuple)):
# Not supported with transformers
temperature = min(temperature)
if temperature != 0.0:
generation_config.do_sample = True
generation_config.temperature = temperature
generation_config.top_k = kwargs.get("best_of", None)

output = self.model.generate(
features,
return_dict_in_generate = True,
return_segments = True,
return_timestamps = True,
return_token_timestamps = False, # Note: transformers has its own (concurrent) token-timestamp computation, which is not used here
max_length = self.dims.n_text_ctx,
is_multilingual = self.is_multilingual,
task = kwargs.get("task", "transcribe"),
language = kwargs.get("language"),
prompt_ids = kwargs.get("initial_prompt"),
generation_config = generation_config,
)

output_dict = {}

if "segments" not in output:
# Note: the output format is different when there is only one segment (e.g. audio duration < 30 seconds)
tokens = output.sequences[0]
output = {
"segments": [[{
"tokens": tokens[1:],
"start": torch.tensor(0.0),
"result": {
"sequences": output.sequences,
"past_key_values": output.past_key_values,
}
}]]
}

full_text = ""
segments = []
id = -1
for sub_segments in output["segments"]:
for segment in sub_segments:
id += 1
tokens = segment["tokens"]
text = self.tokenizer.decode(tokens, skip_special_tokens=True)
token_timestamps = [(i, t.item()) for i, t in enumerate(tokens) if t >= self.tokenizer.timestamp_begin]
if len(token_timestamps):
assert len(token_timestamps) == 2, f"Got unexpected number of timestamps: {token_timestamps}"
i_start, token_start = token_timestamps[0]
i_end, token_end = token_timestamps[1]
tokens = tokens[i_start+1:i_end]
relative_start = (token_start - self.tokenizer.timestamp_begin) * AUDIO_TIME_PER_TOKEN
relative_end = (token_end - self.tokenizer.timestamp_begin) * AUDIO_TIME_PER_TOKEN
start = max(0, segment["start"].item())
relative_start = round(relative_start, 2)
start = round(start, 2)
end = round(start + relative_end - relative_start, 2)
offset = start - relative_start
assert offset >= 0, f"Got negative offset ({offset}) with {start=} and {relative_start=}"
else:
raise NotImplementedError("Unexpected situation: no timestamp found")
# start = max(0, segment["start"].item())
# end = segment["end"].item()
# offset = 0
# assert end >= start, f"Got end < start ({end} < {start})"
# token_start = self.tokenizer.to_timestamp_token(start)
# token_end = self.tokenizer.to_timestamp_token(end)

# Accumulate
segments.append({
"id": id,
"seek": round(offset * SAMPLE_RATE / HOP_LENGTH),
"start": start,
"end": end,
"text": text,
"tokens": [token_start] + tokens.tolist() + [token_end],
"temperature": temperature,
# "avg_logprob": -0.6982866287231445,
# "compression_ratio": 0.5294117647058824,
# "no_speech_prob": 0.019023602828383446
})
full_text += text

output_dict = {
"text": full_text,
"segments": segments,
}
if not kwargs.get("language"):
language_detected = self.tokenizer.decode(output["segments"][0][0]["tokens"][0])
assert len(language_detected) in [6,7], f"Unexpected language detected: {language_detected}"
language_detected = language_detected[2:-2]
output_dict["language"] = language_detected

return output_dict

def __call__(self, mfcc, tokens):
output = self.model(mfcc, decoder_input_ids=tokens, output_attentions=True)
return output.logits


def is_transformer_model(model):
return isinstance(model, TransformerWhisperAsOpenAIWhisper)


# Credit: https://github.com/openai/whisper/discussions/830
@@ -2500,6 +2736,7 @@ def get_do_write(output_format):
parser.add_argument('--model', help=f"name of the Whisper model to use. Examples: {', '.join(whisper.available_models())}", default="small")
parser.add_argument("--model_dir", default=None, help="the path to save model files; uses ~/.cache/whisper by default", type=str)
parser.add_argument("--device", default=get_default_device(), help="device to use for PyTorch inference")
parser.add_argument("--backend", default="openai-whisper", help="Which backend to use", choices=["openai-whisper", "transformers"], type=str)
parser.add_argument("--output_dir", "-o", default=None, help="directory to save the outputs", type=str)
valid_formats = ["txt", "vtt", "srt", "tsv", "csv", "json"]
def str2output_formats(string):
@@ -2590,8 +2827,9 @@ def __call__(self, parser, namespace, values, option_string=None):
force_cudnn_initialization(device)

output_format = args.pop("output_format")
backend = args.pop("backend")

model = load_model(model, device=device, download_root=model_dir)
model = load_model(model, device=device, download_root=model_dir, backend=backend)

plot_word_alignment = args.pop("plot")
