-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
253284d
commit d423c2e
Showing
56 changed files
with
9,611 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# ################################ | ||
# ################################ | ||
|
||
# Seed needs to be set at top of yaml, before objects with parameters are made | ||
seed: 1986 | ||
__set_seed: !apply:torch.manual_seed [!ref <seed>] | ||
output_folder: !ref results/<seed> | ||
wer_file: !ref <output_folder>/wer.txt | ||
save_folder: !ref <output_folder>/save | ||
train_log: !ref <output_folder>/train_log.txt | ||
|
||
|
||
# Data files | ||
skip_prep: False | ||
csv_folder: !PLACEHOLDER # Path to the folder containing Buckeye training files | ||
ckpt_interval_minutes: 25 # save checkpoint every N min | ||
train_csv: !ref <csv_folder>/train.csv | ||
valid_csv: !ref <csv_folder>/dev.csv | ||
test_csv: | ||
- !ref <csv_folder>/test.csv | ||
|
||
num_layers_ssl: 13 #Number of layers in the SSL model (should be 25 for large) | ||
ssl_hub: facebook/wav2vec2-base | ||
encoder_dim: 768 | ||
|
||
# Training parameters | ||
number_of_epochs: 20 | ||
lr: 0.0002 | ||
lr_weights: 0.01 | ||
sorting: ascending | ||
auto_mix_prec: False | ||
sample_rate: 16000 | ||
language_modelling: False | ||
ngram_lm_path: !PLACEHOLDER # Path to the librispeech .arpa mù file | ||
|
||
# With data_parallel batch_size is split into N jobs | ||
# With DDP batch_size is multiplied by N jobs | ||
# Must be 3 per GPU to fit 32GB of VRAM | ||
batch_size: 4 | ||
test_batch_size: 4 | ||
|
||
# Dataloader options | ||
train_dataloader_opts: | ||
batch_size: !ref <batch_size> | ||
|
||
valid_dataloader_opts: | ||
batch_size: !ref <batch_size> | ||
|
||
test_dataloader_opts: | ||
batch_size: !ref <test_batch_size> | ||
|
||
# Model parameters | ||
activation: !name:torch.nn.Sigmoid | ||
dnn_layers: 1 | ||
dnn_neurons: 768 | ||
freeze_encoder: True | ||
|
||
# Outputs | ||
output_neurons: 30 # BPE size, index(blank/eos/bos) = 0 | ||
|
||
|
||
# Functions and classes | ||
# | ||
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter | ||
limit: !ref <number_of_epochs> | ||
|
||
augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment | ||
sample_rate: !ref <sample_rate> | ||
speeds: [95, 100, 105] | ||
|
||
weighted_ssl_model: !new:speechbrain.lobes.models.huggingface_wav2vec.WeightedSSLModel # yamllint disable-line rule:line-length | ||
hub: !ref <ssl_hub> | ||
num_layers: !ref <num_layers_ssl> | ||
|
||
enc: !new:speechbrain.nnet.RNN.LSTM | ||
input_shape: [Null, Null, !ref <encoder_dim>] | ||
num_layers: 2 | ||
bidirectional: True | ||
dropout: 0.2 | ||
hidden_size: 1024 | ||
|
||
ctc_lin: !new:speechbrain.nnet.linear.Linear | ||
input_size: 2048 | ||
n_neurons: !ref <output_neurons> | ||
|
||
log_softmax: !new:speechbrain.nnet.activations.Softmax | ||
apply_log: True | ||
|
||
ctc_cost: !name:speechbrain.nnet.losses.ctc_loss | ||
blank_index: !ref <blank_index> | ||
|
||
modules: | ||
enc: !ref <enc> | ||
ctc_lin: !ref <ctc_lin> | ||
weighted_ssl_model: !ref <weighted_ssl_model> | ||
|
||
model: !new:torch.nn.ModuleList | ||
- [!ref <enc>, !ref <ctc_lin>] | ||
|
||
model_opt_class: !name:torch.optim.Adam | ||
lr: !ref <lr> | ||
|
||
weights_opt_class: !name:torch.optim.Adam | ||
lr: !ref <lr_weights> | ||
|
||
lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler | ||
initial_value: !ref <lr> | ||
improvement_threshold: 0.0025 | ||
annealing_factor: 0.8 | ||
patient: 0 | ||
|
||
lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler | ||
initial_value: !ref <lr_weights> | ||
improvement_threshold: 0.0025 | ||
annealing_factor: 0.9 | ||
patient: 0 | ||
|
||
label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder | ||
|
||
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer | ||
checkpoints_dir: !ref <save_folder> | ||
recoverables: | ||
model: !ref <model> | ||
ssl_model: !ref <weighted_ssl_model> | ||
scheduler_model: !ref <lr_annealing_model> | ||
scheduler_encoder: !ref <lr_annealing_weights> | ||
counter: !ref <epoch_counter> | ||
tokenizer: !ref <label_encoder> | ||
|
||
blank_index: 0 | ||
unk_index: 1 | ||
|
||
|
||
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger | ||
save_file: !ref <train_log> | ||
|
||
error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats | ||
|
||
cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats | ||
split_tokens: True |
Oops, something went wrong.