feat(separation): add PixIT task, ToTaToNet model and SpeechSeparation pipeline #1676

Merged
merged 95 commits on May 30, 2024
Commits
95 commits
8365802
Merge pull request #5 from pyannote/develop
joonaskalda Mar 18, 2024
ff2c705
add convnet layers to PyanNet
joonaskalda Mar 10, 2023
28c9057
add stft and free encoders/decoder
joonaskalda Mar 15, 2023
eca0a31
multitask learning first attempt
joonaskalda Mar 15, 2023
1d2d1ab
properly logging mixit loss in train/valid
joonaskalda Mar 15, 2023
4e175b6
add a weight to mixit_loss
joonaskalda Mar 16, 2023
7894bf7
reformulate multitask loss
joonaskalda Mar 27, 2023
9eeb29c
add dprnn
joonaskalda Apr 16, 2023
2fb16e9
pair mixtures from same file with no overlapping speakers
joonaskalda May 12, 2023
1a83fec
fix mixit loss for odd batch size in validation
joonaskalda May 12, 2023
3ebc46a
make the MoM part of the original batch
joonaskalda May 12, 2023
055744f
clean up
joonaskalda May 12, 2023
4af6be1
check that BS is divisible by 3
joonaskalda May 14, 2023
4aa83b3
don't use MoMs with more than 3 speakers
joonaskalda May 14, 2023
ef14f8f
include original mixtures in separation branch training
joonaskalda May 22, 2023
b218473
matching the order of dimensions of branch outputs
joonaskalda May 24, 2023
ebe3471
make n_sources an argument for model constructor
joonaskalda May 25, 2023
df6794b
changing LSTM default num_layers to 4
joonaskalda May 25, 2023
1fe6106
create separate tasks and models
joonaskalda Jun 9, 2023
b3a7821
Changing n_sources to 3
joonaskalda Jun 12, 2023
0c790d5
forcing alignment between separation and diarization
joonaskalda Jun 16, 2023
70adb7e
fixing edge case of 4 speakers in a second chunk
joonaskalda Jun 17, 2023
fcde9b8
adding a VAD-like forced alignment loss
joonaskalda Jun 18, 2023
699a2fb
refactor: remove vad_loss and warm_up, assume powerset everywhere
joonaskalda Jun 20, 2023
065cde0
remove double check of num_speakers
joonaskalda Jun 20, 2023
3ff0ba5
refactor: moved MoM construction
joonaskalda Jun 21, 2023
8bb63d2
remove unused mixit wrapper
joonaskalda Jun 21, 2023
5bd913d
format with black
joonaskalda Jun 21, 2023
40c11db
fix for last batch in validation having size 1
joonaskalda Jun 21, 2023
3ee1484
adding documentation
joonaskalda Jun 21, 2023
71541c2
diarization on sources separately and back to multilabel
joonaskalda Jun 26, 2023
8c66d74
make lstm use optional
joonaskalda Jun 26, 2023
dcab13d
make alignment forcing optional
joonaskalda Jun 26, 2023
954e0f8
bug fix
joonaskalda Jun 27, 2023
7dee18c
rename mixit_loss to separation_loss for clarity
joonaskalda Jul 2, 2023
f792135
add 2 sources for noise and alignment accuracy measure
joonaskalda Jul 25, 2023
9f2cd5b
fix bug regarding specifications being a tuple
joonaskalda Jul 25, 2023
ded4b5e
clean up
joonaskalda Aug 28, 2023
90c9b3a
add avg pooling to diarization branch for smaller kernel sizes
joonaskalda Sep 7, 2023
f97d440
fix validation loss
joonaskalda Sep 9, 2023
2654528
changing to pit_loss
joonaskalda Sep 14, 2023
1ed87e9
changing validation dataloader
joonaskalda Sep 14, 2023
b97dd16
make the additional 2 noise sources optional
joonaskalda Sep 14, 2023
b2baf1d
make aligned training the only supported behavior
joonaskalda Sep 15, 2023
f84c683
clean up
joonaskalda Sep 20, 2023
23f8b7b
change default model parameters
joonaskalda Sep 20, 2023
7ccbee1
3 source mixit
joonaskalda Sep 21, 2023
6917aa0
first commit
joonaskalda Sep 22, 2023
407e1fb
diar branch from masked tf rep instead
joonaskalda Sep 23, 2023
66146fb
add lstm
joonaskalda Sep 23, 2023
adcc594
add lstm back in
joonaskalda Sep 24, 2023
053324e
fix forward
joonaskalda Sep 25, 2023
c1519f4
adding training on single speaker sources
joonaskalda Sep 25, 2023
680e181
check edge case
joonaskalda Sep 25, 2023
031cb9d
clean up and format
joonaskalda Oct 18, 2023
4d20d4a
first commit
joonaskalda Oct 21, 2023
21d7f40
enable finetuning wavlm with separate lr (pytorch-lightning 2.1)
joonaskalda Nov 22, 2023
8081bef
add gradient clipping
joonaskalda Nov 23, 2023
5104521
clean up
joonaskalda Mar 18, 2024
a8c8caa
rename SepDiarNet and include receptive field
joonaskalda Mar 21, 2024
0a9597b
fix rebase mistake in PyanNet
joonaskalda Mar 21, 2024
ec36794
fix rebase mistake in segmentation mixins
joonaskalda Mar 21, 2024
7cfdc75
fix joint task setup
joonaskalda Mar 21, 2024
4b37532
fix joint task init
joonaskalda Mar 21, 2024
c963934
fix data iteration and add docstrings
joonaskalda Mar 21, 2024
658308a
fix docstrings
joonaskalda Mar 21, 2024
b3ba2d4
remove functionality for additional noise sources
joonaskalda Mar 21, 2024
e451a2b
remove functionality for using original sources for separation (when …
joonaskalda Mar 21, 2024
386d750
fix rebase mistake in speaker diarization task
joonaskalda Mar 21, 2024
60210d9
rename joint task to PixIT
joonaskalda Mar 21, 2024
047f741
make wavlm finetuning optional
joonaskalda Mar 21, 2024
110355e
clean up ToTaToNet
joonaskalda Mar 21, 2024
e2cdae4
add joint diarization separation pipeline
joonaskalda Mar 25, 2024
4df4950
fix docstrings and imports
joonaskalda Mar 25, 2024
c848143
fix ToTaToNet behavior when WavLM not used
joonaskalda Mar 25, 2024
ffe583a
update requirements.txt
joonaskalda Mar 25, 2024
9d79069
Merge branch 'develop' into pixit
hbredin Apr 5, 2024
75c3c8e
doc: update changelog
hbredin Apr 5, 2024
4112ef8
Merge branch 'develop' into pixit
hbredin Apr 18, 2024
776349b
Merge branch 'develop' into pixit
hbredin May 24, 2024
e133515
chore: reorganize things a bit
hbredin May 24, 2024
6e56060
fix: fix import
hbredin May 24, 2024
6adcd2d
fix: fix docstring
hbredin May 24, 2024
8be568c
fix: fix docstrings and default values
joonaskalda May 27, 2024
e8a4db0
fix: rename speaker separation to speech separation
joonaskalda May 27, 2024
8f97c6a
fix: renaming and reorganizing separation pipeline hyperparameters
joonaskalda May 28, 2024
284fb83
doc: add hyper-parameter documentation
hbredin May 28, 2024
089504f
feat(setup): make separation dependencies optional
hbredin May 28, 2024
00e4894
chore: remove evaluation script
hbredin May 30, 2024
ee06e79
fix: remove the lstm layers
joonaskalda May 30, 2024
dbf8ce5
chore: applying ruff here and there
hbredin May 30, 2024
64425c5
doc: simplify docstring
hbredin May 30, 2024
c64548b
doc: credit where it's due
hbredin May 30, 2024
60fe1ba
doc: update changelog
hbredin May 30, 2024
a57ee9d
Merge branch 'develop' into pixit
hbredin May 30, 2024
11 changes: 11 additions & 0 deletions CHANGELOG.md
@@ -2,8 +2,19 @@

## develop

### TL;DR

`pyannote.audio` does [speech separation](https://hf.co/pyannote/speech-separation-ami-1.0): multi-speaker audio in, one audio channel per speaker out!

```bash
pip install pyannote.audio[separation]==3.3.0
```
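
For instance, with the checkpoint linked above (a minimal sketch: authentication and device placement are omitted, and the checkpoint may require accepting its user conditions on Hugging Face):

```python
from pyannote.audio import Pipeline

pipeline = Pipeline.from_pretrained("pyannote/speech-separation-ami-1.0")

# one diarization annotation plus one separated waveform per detected speaker
diarization, sources = pipeline("audio.wav")
```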

### New features

- feat(task): add `PixIT` joint speaker diarization and speech separation task (with [@joonaskalda](https://github.com/joonaskalda/))
- feat(model): add `ToTaToNet` joint speaker diarization and speech separation model (with [@joonaskalda](https://github.com/joonaskalda/))
- feat(pipeline): add `SpeechSeparation` pipeline (with [@joonaskalda](https://github.com/joonaskalda/))
- feat(io): add option to select torchaudio `backend`

### Fixes
351 changes: 351 additions & 0 deletions pyannote/audio/models/separation/ToTaToNet.py
@@ -0,0 +1,351 @@
# MIT License
#
# Copyright (c) 2024- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# AUTHOR: Joonas Kalda (github.com/joonaskalda)

from functools import lru_cache
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from asteroid_filterbanks import make_enc_dec
from pyannote.core.utils.generators import pairwise

from pyannote.audio.core.model import Model
from pyannote.audio.core.task import Task
from pyannote.audio.utils.params import merge_dict
from pyannote.audio.utils.receptive_field import (
conv1d_num_frames,
conv1d_receptive_field_center,
conv1d_receptive_field_size,
)

try:
from asteroid.masknn import DPRNN
from asteroid.utils.torch_utils import pad_x_to_y

ASTEROID_IS_AVAILABLE = True
except ImportError:
ASTEROID_IS_AVAILABLE = False


try:
from transformers import AutoModel

TRANSFORMERS_IS_AVAILABLE = True
except ImportError:
TRANSFORMERS_IS_AVAILABLE = False


class ToTaToNet(Model):
"""ToTaToNet joint speaker diarization and speech separation model

/--------------\\
Conv1D Encoder --------+--- DPRNN --X------- Conv1D Decoder
WavLM -- upsampling --/ \\--- Avg pool -- Linear -- Classifier


Parameters
----------
    encoder_decoder : dict, optional
        Keyword arguments used to initialize the encoder and decoder.
        See ToTaToNet.ENCODER_DECODER_DEFAULTS for default values.
    linear : dict, optional
        Keyword arguments used to initialize linear layers.
        See ToTaToNet.LINEAR_DEFAULTS for default values.
    diar : dict, optional
        Keyword arguments used to initialize the average pooling in the diarization branch.
        See ToTaToNet.DIAR_DEFAULTS for default values.
    dprnn : dict, optional
        Keyword arguments used to initialize the DPRNN model.
        See ToTaToNet.DPRNN_DEFAULTS for default values.
    sample_rate : int, optional
        Audio sample rate. Defaults to 16000 (16 kHz).
    num_channels : int, optional
        Number of channels. Defaults to 1 (mono).
task : Task, optional
Task to perform. Defaults to None.
n_sources : int, optional
Number of separated sources. Defaults to 3.
use_wavlm : bool, optional
Whether to use the WavLM large model for feature extraction. Defaults to True.
gradient_clip_val : float, optional
Gradient clipping value. Required when fine-tuning the WavLM model and thus using two different optimizers.
Defaults to 5.0.

References
----------
Joonas Kalda, Clément Pagés, Ricard Marxer, Tanel Alumäe, and Hervé Bredin.
"PixIT: Joint Training of Speaker Diarization and Speech Separation
from Real-world Multi-speaker Recordings"
Odyssey 2024. https://arxiv.org/abs/2403.02288
"""

ENCODER_DECODER_DEFAULTS = {
"fb_name": "free",
"kernel_size": 32,
"n_filters": 64,
"stride": 16,
}
LINEAR_DEFAULTS = {"hidden_size": 64, "num_layers": 2}
DPRNN_DEFAULTS = {
"n_repeats": 6,
"bn_chan": 128,
"hid_size": 128,
"chunk_size": 100,
"norm_type": "gLN",
"mask_act": "relu",
"rnn_type": "LSTM",
}
DIAR_DEFAULTS = {"frames_per_second": 125}

def __init__(
self,
        encoder_decoder: Optional[dict] = None,
linear: Optional[dict] = None,
diar: Optional[dict] = None,
        dprnn: Optional[dict] = None,
sample_rate: int = 16000,
num_channels: int = 1,
task: Optional[Task] = None,
n_sources: int = 3,
use_wavlm: bool = True,
gradient_clip_val: float = 5.0,
):
if not ASTEROID_IS_AVAILABLE:
raise ImportError(
"'asteroid' must be installed to use ToTaToNet separation. "
"`pip install pyannote-audio[separation]` should do the trick."
)

if not TRANSFORMERS_IS_AVAILABLE:
raise ImportError(
"'transformers' must be installed to use ToTaToNet separation. "
"`pip install pyannote-audio[separation]` should do the trick."
)

super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task)

linear = merge_dict(self.LINEAR_DEFAULTS, linear)
dprnn = merge_dict(self.DPRNN_DEFAULTS, dprnn)
encoder_decoder = merge_dict(self.ENCODER_DECODER_DEFAULTS, encoder_decoder)
diar = merge_dict(self.DIAR_DEFAULTS, diar)
self.use_wavlm = use_wavlm
self.save_hyperparameters("encoder_decoder", "linear", "dprnn", "diar")
self.n_sources = n_sources

if encoder_decoder["fb_name"] == "free":
n_feats_out = encoder_decoder["n_filters"]
elif encoder_decoder["fb_name"] == "stft":
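            # STFT filterbank: n_filters/2 + 1 frequency bins, real and imaginary parts each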
n_feats_out = int(2 * (encoder_decoder["n_filters"] / 2 + 1))
else:
raise ValueError("Filterbank type not recognized.")
self.encoder, self.decoder = make_enc_dec(
sample_rate=sample_rate, **self.hparams.encoder_decoder
)

if self.use_wavlm:
self.wavlm = AutoModel.from_pretrained("microsoft/wavlm-large")
downsampling_factor = 1
for conv_layer in self.wavlm.feature_extractor.conv_layers:
if isinstance(conv_layer.conv, nn.Conv1d):
downsampling_factor *= conv_layer.conv.stride[0]
self.wavlm_scaling = int(downsampling_factor / encoder_decoder["stride"])
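            # e.g. WavLM-large's conv stack downsamples by 5*2*2*2*2*2*2 = 320,
            # so with the default encoder stride of 16, each WavLM frame is
            # repeated 320 / 16 = 20 times to match the encoder frame rate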

self.masker = DPRNN(
encoder_decoder["n_filters"]
+ self.wavlm.feature_projection.projection.out_features,
out_chan=encoder_decoder["n_filters"],
n_src=n_sources,
**self.hparams.dprnn,
)
else:
self.masker = DPRNN(
encoder_decoder["n_filters"],
out_chan=encoder_decoder["n_filters"],
n_src=n_sources,
**self.hparams.dprnn,
)

# diarization can use a lower resolution than separation
self.diarization_scaling = int(
sample_rate / diar["frames_per_second"] / encoder_decoder["stride"]
)
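        # with the defaults: int(16000 / 125 / 16) = 8 encoder frames per diarization frame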
self.average_pool = nn.AvgPool1d(
self.diarization_scaling, stride=self.diarization_scaling
)
        linear_input_features = n_feats_out
if linear["num_layers"] > 0:
self.linear = nn.ModuleList(
[
nn.Linear(in_features, out_features)
for in_features, out_features in pairwise(
[
                            linear_input_features,
]
+ [self.hparams.linear["hidden_size"]]
* self.hparams.linear["num_layers"]
)
]
)
self.gradient_clip_val = gradient_clip_val
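        # fine-tuning WavLM uses a second optimizer (see docstring), so optimization
        # is manual and gradients are clipped by hand with gradient_clip_val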
self.automatic_optimization = False

@property
def dimension(self) -> int:
"""Dimension of output"""
return 1

def build(self):
if self.hparams.linear["num_layers"] > 0:
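            # 64 matches LINEAR_DEFAULTS["hidden_size"]; assumes the default hidden size is kept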
self.classifier = nn.Linear(64, self.dimension)
else:
self.classifier = nn.Linear(1, self.dimension)
self.activation = self.default_activation()

@lru_cache
def num_frames(self, num_samples: int) -> int:
"""Compute number of output frames

Parameters
----------
num_samples : int
Number of input samples.

Returns
-------
num_frames : int
Number of output frames.
"""

equivalent_stride = (
self.diarization_scaling * self.hparams.encoder_decoder["stride"]
)
equivalent_kernel_size = (
self.diarization_scaling * self.hparams.encoder_decoder["kernel_size"]
)
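        # with the defaults, diarization_scaling = 8, so diarization frames behave
        # like a conv1d with stride 8 * 16 = 128 and kernel size 8 * 32 = 256 samples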

return conv1d_num_frames(
num_samples, kernel_size=equivalent_kernel_size, stride=equivalent_stride
)

def receptive_field_size(self, num_frames: int = 1) -> int:
"""Compute size of receptive field

Parameters
----------
num_frames : int, optional
Number of frames in the output signal

Returns
-------
receptive_field_size : int
Receptive field size.
"""

equivalent_stride = (
self.diarization_scaling * self.hparams.encoder_decoder["stride"]
)
equivalent_kernel_size = (
self.diarization_scaling * self.hparams.encoder_decoder["kernel_size"]
)

return conv1d_receptive_field_size(
num_frames, kernel_size=equivalent_kernel_size, stride=equivalent_stride
)

def receptive_field_center(self, frame: int = 0) -> int:
"""Compute center of receptive field

Parameters
----------
frame : int, optional
Frame index

Returns
-------
receptive_field_center : int
Index of receptive field center.
"""

equivalent_stride = (
self.diarization_scaling * self.hparams.encoder_decoder["stride"]
)
equivalent_kernel_size = (
self.diarization_scaling * self.hparams.encoder_decoder["kernel_size"]
)

return conv1d_receptive_field_center(
frame, kernel_size=equivalent_kernel_size, stride=equivalent_stride
)

def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
"""Pass forward

Parameters
----------
waveforms : (batch, channel, sample)

Returns
-------
scores : (batch, frame, classes)
sources : (batch, sample, n_sources)
"""
bsz = waveforms.shape[0]
tf_rep = self.encoder(waveforms)
if self.use_wavlm:
wavlm_rep = self.wavlm(waveforms.squeeze(1)).last_hidden_state
wavlm_rep = wavlm_rep.transpose(1, 2)
wavlm_rep = wavlm_rep.repeat_interleave(self.wavlm_scaling, dim=-1)
wavlm_rep = pad_x_to_y(wavlm_rep, tf_rep)
wavlm_rep = torch.cat((tf_rep, wavlm_rep), dim=1)
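            # shape: (batch, n_filters + wavlm_dim, n_frames)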
masks = self.masker(wavlm_rep)
else:
masks = self.masker(tf_rep)
# shape: (batch, nsrc, nfilters, nframes)
masked_tf_rep = masks * tf_rep.unsqueeze(1)
decoded_sources = self.decoder(masked_tf_rep)
decoded_sources = pad_x_to_y(decoded_sources, waveforms)
decoded_sources = decoded_sources.transpose(1, 2)
outputs = torch.flatten(masked_tf_rep, start_dim=0, end_dim=1)
# shape (batch * nsrc, nfilters, nframes)
outputs = self.average_pool(outputs)
outputs = outputs.transpose(1, 2)
# shape (batch, nframes, nfilters)
if self.hparams.linear["num_layers"] > 0:
for linear in self.linear:
outputs = F.leaky_relu(linear(outputs))
if self.hparams.linear["num_layers"] == 0:
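            # no linear layers: fall back to per-frame energy of the masked representation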
outputs = (outputs**2).sum(dim=2).unsqueeze(-1)
outputs = self.classifier(outputs)
outputs = outputs.reshape(bsz, self.n_sources, -1)
outputs = outputs.transpose(1, 2)

return self.activation[0](outputs), decoded_sources
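
For context, a minimal training sketch (assumptions: `PixIT` is exported from `pyannote.audio.tasks` as the changelog suggests, the `database.yml` path and protocol name are placeholders, and Lightning's own `gradient_clip_val` stays unset because the model clips gradients itself under manual optimization):

```python
import pytorch_lightning as pl
from pyannote.database import registry

from pyannote.audio.models.separation import ToTaToNet
from pyannote.audio.tasks import PixIT

# hypothetical database config and protocol name
registry.load_database("database.yml")
protocol = registry.get_protocol("AMI.SpeakerDiarization.only_words")

task = PixIT(protocol)        # joint diarization + separation task added by this PR
model = ToTaToNet(task=task)  # downloads microsoft/wavlm-large by default

# no gradient_clip_val here: manual optimization clips gradients inside the model
trainer = pl.Trainer(devices=1, max_epochs=1)
trainer.fit(model)
```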
25 changes: 25 additions & 0 deletions pyannote/audio/models/separation/__init__.py
@@ -0,0 +1,25 @@
# MIT License
#
# Copyright (c) 2024- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from .ToTaToNet import ToTaToNet

__all__ = ["ToTaToNet"]