Skip to content

Commit

Permalink
Merge pull request #4 from lucadellalib/whisper-backward-compatibility
Browse files Browse the repository at this point in the history
Fix Whisper backward compatibility
  • Loading branch information
mravanelli committed Jul 12, 2023
2 parents 0496ecd + 2918948 commit 5e57b10
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions benchmarks/CL_MASR/whisper/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,18 @@ def __init__(
timestamp_tokens = [f"<|{ts:.2f}|>" for ts in timestamps]
self.tokenizer.add_tokens(timestamp_tokens)

# The following breaking changes were introduced in transformers>=4.29:
# 1) mel_filters.shape = (..., feature_extractor.feature_size) instead of (feature_extractor.feature_size, ...)
# 2) mel_filters.dtype = float64 instead of float32
# The code below transposes the filters if needed and casts them to float32, in a backward-compatible way
if self._mel_filters.shape[0] == self._n_fft / 2 + 1:
mel_filters = self._mel_filters.T
self.register_buffer(
"_mel_filters",
torch.as_tensor(mel_filters, dtype=torch.float32),
)
#################################################################

# override
def _log_mel_spectrogram(self, audio):
"""Compute the Mel spectrogram of a batch of input waveforms.
Expand Down Expand Up @@ -162,9 +174,6 @@ def _log_mel_spectrogram(self, audio):
magnitudes = stft[..., :-1].abs() ** 2

filters = self._mel_filters
# Fix dependency issues with transformers>=4.29 in a backward compatible way
if filters.shape[-1] != magnitudes.shape[-2]:
filters = filters.T.to(dtype=magnitudes.dtype)
mel_spec = filters @ magnitudes

log_spec = torch.clamp(mel_spec, min=1e-10).log10()
Expand Down

0 comments on commit 5e57b10

Please sign in to comment.