From 0a2afc1ddbf4d690c77c7d78604fafec6467ee3f Mon Sep 17 00:00:00 2001
From: pme0
Date: Tue, 20 Jun 2023 20:10:24 +0100
Subject: [PATCH] add logmel spec

---
 deeplightning/model/dcgan.py        | 20 +++++++++++++++++++-
 deeplightning/viz/audio/spectrum.py | 23 +++++++++++++----------
 deeplightning/viz/audio/wave.py     |  9 +++++++--
 3 files changed, 39 insertions(+), 13 deletions(-)

diff --git a/deeplightning/model/dcgan.py b/deeplightning/model/dcgan.py
index 969214a..0a78f48 100644
--- a/deeplightning/model/dcgan.py
+++ b/deeplightning/model/dcgan.py
@@ -44,6 +44,15 @@ def __init__(self, num_channels, latent_dim):
             nn.ConvTranspose2d(in_channels=32, out_channels=num_channels, kernel_size=4, stride=2, padding=1, bias=False),
             nn.Tanh(),
         )
+        """ Parametrize layer creation (sketch; assumes a `num_layers` argument):
+        last_channels = 32
+        channels = [latent_dim] + [last_channels * (2 ** i) for i in range(num_layers-1)][::-1] + [num_channels]
+        layers = []
+        for i in range(num_layers):
+            layers.append(nn.ConvTranspose2d(in_channels=channels[i], out_channels=channels[i+1], kernel_size=4, stride=2, padding=0 if i == 0 else 1, bias=False))
+            layers.append(nn.BatchNorm2d(channels[i+1]) if i < num_layers-1 else nn.Identity())
+            layers.append(nn.ReLU() if i < num_layers-1 else nn.Tanh())
+        """

     def forward(self, x):
         return self.generator(x)
@@ -89,7 +98,16 @@ def tensor_size_debugger(self):

 class DCGAN(nn.Module):
     def __init__(self, batch_size: int, sample_size: int, image_size: int, latent_dim: int):
-        """
+        """DCGAN: Deep Convolutional Generative Adversarial Network.
+
+        Parameters
+        ----------
+        ???
+
+        References
+        ----------
+        > A Radford, L Metz, S Chintala (2015) "Unsupervised representation learning with
+          deep convolutional generative adversarial networks", arXiv:1511.06434
         """
         super(DCGAN, self).__init__()

diff --git a/deeplightning/viz/audio/spectrum.py b/deeplightning/viz/audio/spectrum.py
index d7c559a..0668eae 100644
--- a/deeplightning/viz/audio/spectrum.py
+++ b/deeplightning/viz/audio/spectrum.py
@@ -70,30 +70,26 @@ def spectrogram(
     scale: str = None,
     n_fft: int = 2048,
     hop_length: int = 512,
-    figsize: tuple = (5,5),
+    figsize: tuple = (8,3),
     save_plot: str = None,
     show_plot: bool = True,
+    x_axis: str = "time",
 ):
     """Display Mel Frequency Cepstral Coefficients.

     Parameters
     ----------
     path : path to the audio file.
-
     mode : the type of features to be shown in the spectrogram.
         Can be Short Time Fourier Transform (STFT) amplitude
-        (`stft_ampl`) or decibels (`stft_db`); or Mel Frequency
-        Cepstral Coefficients (MFCC)
-
+        (`stft_ampl`) or decibels (`stft_db`); Log Mel spectrogram
+        (`logmel`); Mel Frequency Cepstral Coefficients (`mfcc`)
     scale : the y-axis scale. Can be `linear` or `log`.
        For `mode == "mfcc"` the scale is chosen automatically
-
     n_fft :
-
     hop_length :
-
     figsize : the figure size
-
+    x_axis : the x-axis type passed to librosa's `specshow` (defaults to "time")
     """

     assert scale is None or scale in ["linear", "log"]
@@ -121,12 +117,19 @@ def spectrogram(
         frequency_type = 'MFCC'
         colorbar_label = 'coefficients'
         scale_type = ''
+    elif mode == "logmel":
+        S = librosa.feature.melspectrogram(y=signal, sr=sample_rate, n_fft=n_fft, hop_length=hop_length)
+        S = librosa.power_to_db(S, ref=np.max)
+        specshow(data=S, x_axis=x_axis, y_axis='mel', sr=sample_rate, hop_length=hop_length)
+        frequency_type = 'Log Mel'
+        colorbar_label = 'decibel (dB)'
+        scale_type = ''
     else:
         raise NotImplementedError

     plt.xlabel("Time")
     plt.ylabel(f"Frequency (Hz)")
-    plt.colorbar(label=colorbar_label, orientation="horizontal")
+    plt.colorbar(label=colorbar_label, orientation="vertical", pad=0.02)
     plt.title("Spectrogram ({}{})".format(
         frequency_type if frequency_type is not None else '',
         scale_type if scale_type is not None else '',
diff --git a/deeplightning/viz/audio/wave.py b/deeplightning/viz/audio/wave.py
index 9f7d24d..3240cbd 100644
--- a/deeplightning/viz/audio/wave.py
+++ b/deeplightning/viz/audio/wave.py
@@ -1,3 +1,4 @@
+from typing import Union, Tuple
 import matplotlib.pyplot as plt
 import numpy as np
 import librosa
@@ -9,7 +10,9 @@ def waveplot(
     x_axis: str = "time",
     save_plot: str = None,
     show_plot: bool = True,
-    figsize=(8,4)
+    figsize: Tuple[int,int] = (8,3),
+    title: Union[str,None] = None,
+
 ):
     """Display waveform in the time domain.

@@ -33,7 +36,9 @@ def waveplot(
     # plot
     plt.figure(figsize=figsize)
     waveshow(signal, sr=sample_rate, x_axis=x_axis)
-    plt.title("Audio waveplot")
+    plt.ylabel("Amplitude")
+    if title:
+        plt.title(title)
     if save_plot is not None:
         plt.savefig(save_plot, bbox_inches='tight')
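
Usage sketch (illustrative, not part of the patch): a minimal example of calling the new `logmel` mode of `spectrogram` and the updated `waveplot`. The module paths and function names come from the files touched above; the `path` keyword for `waveplot` and the file name `sample.wav` are assumptions for illustration.

from deeplightning.viz.audio.spectrum import spectrogram
from deeplightning.viz.audio.wave import waveplot

# Log Mel spectrogram: mel power spectrogram converted to decibels (new `logmel` mode)
spectrogram(
    path="sample.wav",   # hypothetical audio file
    mode="logmel",
    n_fft=2048,
    hop_length=512,
    figsize=(8, 3),
    x_axis="time",
    show_plot=True,
)

# Waveform plot using the new optional `title` argument
waveplot(
    path="sample.wav",   # hypothetical audio file; `path` assumed to be the first parameter
    x_axis="time",
    figsize=(8, 3),
    title="Audio waveform",
    show_plot=True,
)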