Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added emotion and asr #11

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 270 additions & 0 deletions benchmarks/DASB/IEMOCAP/ecapa_tdnn/discrete_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
#!/usr/bin/env python3
"""Recipe for training an emotion recognition system from speech data only using IEMOCAP.
The system classifies 4 emotions (anger, happiness, sadness, neutrality) starting from an SSL encoder.
The probing head is ECAPA-TDNN.

Authors
* Salah Zaiem 2023
* Youcef Kemiche 2023
"""

import os
import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
import torchaudio
import torch


class EmoIdBrain(sb.Brain):
    """Brain subclass for 4-class emotion recognition on IEMOCAP.

    Pipeline: raw waveform -> discrete (codec) model producing tokens and
    continuous features -> ECAPA-TDNN embedding model -> classifier ->
    log-softmax posteriors over the emotion labels.
    """

    def compute_forward(self, batch, stage):
        """Computation pipeline based on a encoder + emotion classifier.

        Arguments
        ---------
        batch : PaddedBatch
            Batch providing ``sig`` (padded waveforms, relative lengths).
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.

        Returns
        -------
        torch.Tensor
            Log-posteriors over the emotion classes.
        """
        batch = batch.to(self.device)
        wavs, wav_lens = batch.sig
        # NOTE(review): this second device move looks redundant —
        # `batch.to(self.device)` above should already have moved the
        # tensors; harmless, but worth confirming and removing.
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        tokens, feats = self.modules.discrete_model(wavs, wav_lens)
        # Flatten all trailing feature dims into one: (B, T, *) -> (B, T, D).
        feats = torch.reshape(feats, (feats.shape[0], feats.shape[1], -1))
        embeddings = self.modules.embedding_model(feats, wav_lens)
        outputs = self.modules.classifier(embeddings)
        outputs = self.hparams.log_softmax(outputs)
        return outputs

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss using speaker-id as label."""
        emoid, _ = batch.emo_encoded
        loss = self.hparams.compute_cost(predictions, emoid)
        # Track per-utterance errors only outside training (see on_stage_start).
        if stage != sb.Stage.TRAIN:
            self.error_metrics.append(batch.id, predictions, emoid)
        return loss

    def fit_batch(self, batch):
        """Trains the parameters given a single batch in input"""
        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
        loss.backward()
        # Step only when gradients are finite (check_gradients also clips);
        # gradients are always zeroed so a skipped step does not accumulate.
        if self.check_gradients(loss):
            self.model_optimizer.step()
        self.model_optimizer.zero_grad()
        return loss.detach()

    def on_stage_start(self, stage, epoch=None):
        """Gets called at the beginning of each epoch.
        Arguments
        ---------
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.
        epoch : int
            The currently-starting epoch. This is passed
            `None` during the test stage.
        """

        # Set up statistics trackers for this stage
        # NOTE(review): loss_metric is created but never appended to in this
        # recipe — consider appending in compute_objectives or removing it.
        self.loss_metric = sb.utils.metric_stats.MetricStats(
            metric=sb.nnet.losses.nll_loss
        )

        # Set up evaluation-only statistics trackers
        if stage != sb.Stage.TRAIN:
            self.error_metrics = self.hparams.error_stats()

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Gets called at the end of an epoch.
        Arguments
        ---------
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST
        stage_loss : float
            The average loss for all of the data processed in this stage.
        epoch : int
            The currently-starting epoch. This is passed
            `None` during the test stage.
        """

        # Store the train loss until the validation stage.
        if stage == sb.Stage.TRAIN:
            self.train_loss = stage_loss

        # Summarize the statistics from the stage for record-keeping.
        else:
            stats = {
                "loss": stage_loss,
                "error_rate": self.error_metrics.summarize("average"),
            }

        # At the end of validation...
        if stage == sb.Stage.VALID:
            # Anneal the learning rate based on the validation error rate.
            old_lr, new_lr = self.hparams.lr_annealing_model(
                stats["error_rate"]
            )
            sb.nnet.schedulers.update_learning_rate(
                self.model_optimizer, new_lr
            )

            # The train_logger writes a summary to stdout and to the logfile.
            self.hparams.train_logger.log_stats(
                {"Epoch": epoch, "lr": old_lr},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )

            # Save the current checkpoint and delete previous checkpoints,
            # keeping only the best (lowest error_rate) one.
            self.checkpointer.save_and_keep_only(
                meta=stats, min_keys=["error_rate"]
            )

        # We also write statistics about test data to stdout and to logfile.
        if stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )

    def init_optimizers(self):
        "Initializes the encoder2 optimizer and model optimizer"

        self.model_optimizer = self.hparams.model_opt_class(
            self.hparams.model.parameters()
        )

        # Register the optimizer so its state is saved/restored with checkpoints.
        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("modelopt", self.model_optimizer)


def dataio_prep(hparams):
    """This function prepares the datasets to be used in the brain class.
    It also defines the data processing pipeline through user-defined
    functions. We expect `prepare_data` (iemocap_prepare) to have been called
    before this, so that the `train.json`, `valid.json`, and `test.json`
    manifest files are available.
    Arguments
    ---------
    hparams : dict
        This dictionary is loaded from the `train.yaml` file, and it includes
        all the hyperparameters needed for dataset construction and loading.
    Returns
    -------
    datasets : dict
        Contains three keys, "train", "valid", and "test", that correspond
        to the appropriate DynamicItemDataset object.
    """

    # Build the resampler once: constructing a Resample transform per
    # utterance (as inside the pipeline) would rebuild its filter kernel
    # on every call.
    resampler = torchaudio.transforms.Resample(
        16000, hparams["codec_sampling_rate"],
    )

    # Define audio pipeline
    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav):
        """Load the signal and resample it to the codec sampling rate.
        This is done on the CPU in the `collate_fn`."""
        sig = sb.dataio.dataio.read_audio(wav)
        return resampler(sig)

    # Initialization of the label encoder. The label encoder assignes to each
    # of the observed label a unique index (e.g, 'spk01': 0, 'spk02': 1, ..)
    label_encoder = sb.dataio.encoder.CategoricalEncoder()

    # Define label pipeline:
    @sb.utils.data_pipeline.takes("emo")
    @sb.utils.data_pipeline.provides("emo", "emo_encoded")
    def label_pipeline(emo):
        # First yield the raw label (needed by load_or_create below),
        # then the encoded tensor used as the training target.
        yield emo
        emo_encoded = label_encoder.encode_label_torch(emo)
        yield emo_encoded

    # Define datasets. We also connect the dataset with the data processing
    # functions defined above.
    datasets = {}
    data_info = {
        "train": hparams["train_annotation"],
        "valid": hparams["valid_annotation"],
        "test": hparams["test_annotation"],
    }
    for dataset in data_info:
        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=data_info[dataset],
            replacements={"data_root": hparams["data_folder"]},
            dynamic_items=[audio_pipeline, label_pipeline],
            output_keys=["id", "sig", "emo_encoded"],
        )
    # Load or compute the label encoder (with multi-GPU DDP support)
    # Please, take a look into the lab_enc_file to see the label to index
    # mapping.

    lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
    label_encoder.load_or_create(
        path=lab_enc_file,
        from_didatasets=[datasets["train"]],
        output_key="emo",
    )

    return datasets


# RECIPE BEGINS!
if __name__ == "__main__":
    # Reading command line arguments.
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

    # Initialize ddp (useful only for multi-GPU DDP training).
    sb.utils.distributed.ddp_init_group(run_opts)

    # Load hyperparameters file with command-line overrides.
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Imported here so it resolves relative to the recipe directory.
    from iemocap_prepare import prepare_data  # noqa E402

    # Data preparation (manifest creation), to be run on only one process.
    sb.utils.distributed.run_on_main(
        prepare_data,
        kwargs={
            "data_original": hparams["data_folder"],
            "save_json_train": hparams["train_annotation"],
            "save_json_valid": hparams["valid_annotation"],
            "save_json_test": hparams["test_annotation"],
            "split_ratio": [80, 10, 10],  # train/valid/test percentages
            "different_speakers": hparams["different_speakers"],
            "test_spk_id": hparams["test_spk_id"],
            "seed": hparams["seed"],
        },
    )

    # Create dataset objects "train", "valid", and "test".
    datasets = dataio_prep(hparams)

    # Initialize the Brain object to prepare for mask training.
    emo_id_brain = EmoIdBrain(
        modules=hparams["modules"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    # The `fit()` method iterates the training loop, calling the methods
    # necessary to update the parameters of the model. Since all objects
    # with changing state are managed by the Checkpointer, training can be
    # stopped at any point, and will be resumed on next call.

    emo_id_brain.fit(
        epoch_counter=emo_id_brain.hparams.epoch_counter,
        train_set=datasets["train"],
        valid_set=datasets["valid"],
        train_loader_kwargs=hparams["train_dataloader_opts"],
        valid_loader_kwargs=hparams["valid_dataloader_opts"],
    )

    # Load the best checkpoint (lowest validation error_rate) for evaluation.
    test_stats = emo_id_brain.evaluate(
        test_set=datasets["test"],
        min_key="error_rate",
        test_loader_kwargs=hparams["test_dataloader_opts"],
    )
Loading
Loading