From f8740f1aa4b74919e20acdfed9701d9255425742 Mon Sep 17 00:00:00 2001
From: Lina Khodja <57141057+l-k-11235@users.noreply.github.com>
Date: Fri, 23 Jun 2023 16:22:54 +0200
Subject: [PATCH] Refactor dynamic scoring (#2416)

* The _translate method of the ScoringPreparator class relies on a dynamic
  dataset iterator.
---
 .github/workflows/push.yml     |  43 +------
 onmt/opts.py                   |  14 ---
 onmt/tests/pull_request_chk.sh |  56 +++------
 onmt/tests/test_events.py      |  15 +--
 onmt/trainer.py                |  92 ++++----------
 onmt/translate/translator.py   |   4 +-
 onmt/utils/scoring_utils.py    | 211 +++++++++++----------------------
 7 files changed, 117 insertions(+), 318 deletions(-)

diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml
index de5ccb75fe..db6c5e1430 100644
--- a/.github/workflows/push.yml
+++ b/.github/workflows/push.yml
@@ -101,9 +101,10 @@ jobs:
           -hidden_size 10 \
           -train_steps 10 -valid_steps 5 \
           -tensorboard "true" \
-          -tensorboard_log_dir /tmp/logs_train_valid \
+          -tensorboard_log_dir /tmp/logs_train_and_valid \
           -copy_attn
-        python onmt/tests/test_events.py --logdir /tmp/logs_train_valid -tensorboard_checks train_valid
+        python onmt/tests/test_events.py --logdir /tmp/logs_train_and_valid -tensorboard_checks train
+        python onmt/tests/test_events.py --logdir /tmp/logs_train_and_valid -tensorboard_checks valid
     - name: Test RNN training with coverage
       run: |
         python train.py \
@@ -141,35 +142,6 @@ jobs:
           -attention_dropout 0.2 0.1 0.1 \
           -report_every 5 \
           -train_steps 10
-    - name : Test Transformer training with dynamic scoring
-      run: |
-        python3 train.py \
-          -config data/data.yaml \
-          -src_vocab /tmp/onmt.vocab.src \
-          -tgt_vocab /tmp/onmt.vocab.tgt \
-          -src_vocab_size 1000 \
-          -tgt_vocab_size 1000 \
-          -encoder_type transformer \
-          -decoder_type transformer \
-          -layers 4 \
-          -word_vec_size 16 \
-          -hidden_size 16 \
-          -num_workers 0 -bucket_size 1024 \
-          -heads 2 \
-          -transformer_ff 64 \
-          -num_workers 0 -bucket_size 1024 \
-          -accum_count 2 4 8 \
-          -accum_steps 0 15000 30000 \
-          -save_model /tmp/onmt.model \
-          -train_steps 20 \
-          -report_every 5 \
-          -train_eval_steps 10 \
-          -train_metrics "BLEU" "TER" \
-          -tensorboard "true" \
-          -scoring_debug "true" \
-          -tensorboard_log_dir /tmp/logs_train_metrics \
-          -dump_preds /tmp/dump_preds
-        python onmt/tests/test_events.py --logdir /tmp/logs_train_metrics -tensorboard_checks train_metrics
     - name : Test Transformer training and validation with dynamic scoring and copy
       run: |
         python3 train.py \
           -config data/data.yaml \
           -src_vocab /tmp/onmt.vocab.src \
           -tgt_vocab /tmp/onmt.vocab.tgt \
           -src_vocab_size 1000 \
           -tgt_vocab_size 1000 \
           -encoder_type transformer \
           -decoder_type transformer \
           -layers 4 \
           -word_vec_size 16 \
           -hidden_size 16 \
           -num_workers 0 -bucket_size 1024 \
           -heads 2 \
           -transformer_ff 64 \
           -num_workers 0 -bucket_size 1024 \
           -accum_count 2 4 8 \
           -accum_steps 0 15000 30000 \
@@ -192,15 +164,13 @@ jobs:
           -save_model /tmp/onmt.model \
           -train_steps 10 -valid_steps 5 \
           -report_every 2 \
-          -train_eval_steps 8 \
-          -train_metrics "BLEU" "TER" \
           -valid_metrics "BLEU" "TER" \
           -tensorboard "true" \
           -scoring_debug "true" \
-          -tensorboard_log_dir /tmp/logs_train_valid_metrics \
+          -tensorboard_log_dir /tmp/logs_dynamic-scoring_and_copy \
           -dump_preds /tmp/dump_preds \
           -copy_attn
-        python onmt/tests/test_events.py --logdir /tmp/logs_train_valid_metrics -tensorboard_checks train_valid_metrics
+        python onmt/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_copy -tensorboard_checks valid_metrics
     - name: Test LM training
       run: |
         python train.py \
@@ -279,8 +249,7 @@ jobs:
           -hidden_size 2 -batch_size 10 \
           -word_vec_size 5 -hidden_size 10 \
           -num_workers 0 -bucket_size 1024 \
-          -report_every 5 -train_steps 10 \
-          -train_metrics "BLEU" "TER" \
+          -report_every 5 -train_steps 10 -valid_steps 5 \
           -valid_metrics "BLEU" "TER" \
           -save_model /tmp/onmt.model \
           -save_checkpoint_steps 10
diff --git a/onmt/opts.py b/onmt/opts.py
index e1309faf06..6502ca8a71 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -55,20 +55,6 @@ def _add_logging_opts(parser, is_train=True):
     )
 
     if is_train:
-        group.add(
-            "--train_eval_steps",
-            "-train_eval_steps",
-            type=int,
-            default=200,
-            help="calculate training metrics at this interval",
-        )
-        group.add(
-            "--train_metrics",
-            "-train_metrics",
-            default=[],
-            nargs="+",
-            help="List of names of additional training metrics",
-        )
         group.add(
             "--valid_metrics",
             "-valid_metrics",
diff --git a/onmt/tests/pull_request_chk.sh b/onmt/tests/pull_request_chk.sh
index 7f9427e3bd..97beafc7ba 100755
--- a/onmt/tests/pull_request_chk.sh
+++ b/onmt/tests/pull_request_chk.sh
@@ -64,9 +64,9 @@ ${PYTHON} -m unittest discover >> ${LOG_FILE} 2>&1
 echo "Succeeded" | tee -a ${LOG_FILE}
 
-#
+
 # Get Vocabulary test
-#
+
 echo -n "[+] Testing vocabulary building..."
 PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH} ${PYTHON} onmt/bin/build_vocab.py \
             -config ${DATA_DIR}/data.yaml \
@@ -138,12 +138,13 @@ ${PYTHON} onmt/bin/train.py \
             -word_vec_size 5 -report_every 2 \
             -hidden_size 10 -train_steps 10 -valid_steps 5 \
             -tensorboard "true" \
-            -tensorboard_log_dir $TMP_OUT_DIR/logs_train_valid \
+            -tensorboard_log_dir $TMP_OUT_DIR/logs_train_and_valid \
             -copy_attn >> ${LOG_FILE} 2>&1
-${PYTHON} onmt/tests/test_events.py --logdir $TMP_OUT_DIR/logs_train_valid -tensorboard_checks train_valid
+${PYTHON} onmt/tests/test_events.py --logdir $TMP_OUT_DIR/logs_train_and_valid -tensorboard_checks train
+${PYTHON} onmt/tests/test_events.py --logdir $TMP_OUT_DIR/logs_train_and_valid -tensorboard_checks valid
 [ "$?" -eq 0 ] || error_exit
 echo "Succeeded" | tee -a ${LOG_FILE}
-rm -r $TMP_OUT_DIR/logs_train_valid
+rm -r $TMP_OUT_DIR/logs_train_and_valid
 
 echo -n " [+] Testing NMT training w/ align..."
 ${PYTHON} onmt/bin/train.py \
@@ -175,35 +176,8 @@ ${PYTHON} onmt/bin/train.py \
 [ "$?" -eq 0 ] || error_exit
 echo "Succeeded" | tee -a ${LOG_FILE}
 
-echo -n " [+] Testing NMT training w/ dynamic scoring..."
-${PYTHON} onmt/bin/train.py \
-            -config ${DATA_DIR}/data.yaml \
-            -src_vocab $TMP_OUT_DIR/onmt.vocab.src \
-            -tgt_vocab $TMP_OUT_DIR/onmt.vocab.tgt \
-            -src_vocab_size 1000 \
-            -tgt_vocab_size 1000 \
-            -encoder_type transformer \
-            -decoder_type transformer \
-            -layers 4 \
-            -word_vec_size 16 \
-            -hidden_size 16 \
-            -heads 2 \
-            -transformer_ff 64 \
-            -num_workers 0 -bucket_size 1024 \
-            -train_steps 20 \
-            -report_every 5 \
-            -train_eval_steps 10 \
-            -train_metrics "BLEU" "TER" \
-            -tensorboard "true" \
-            -scoring_debug "true" \
-            -tensorboard_log_dir $TMP_OUT_DIR/logs_train_metrics \
-            -dump_preds $TMP_OUT_DIR/dump_pred >> ${LOG_FILE} 2>&1
-${PYTHON} onmt/tests/test_events.py --logdir $TMP_OUT_DIR/logs_train_metrics -tensorboard_checks train_metrics
-[ "$?" -eq 0 ] || error_exit
-echo "Succeeded" | tee -a ${LOG_FILE}
-rm -r $TMP_OUT_DIR/logs_train_metrics
 
-echo -n " [+] Testing NMT training w/ dynamic scoring with validation and copy ..."
+echo -n " [+] Testing NMT training w/ validation with dynamic scoring and copy ..."
 ${PYTHON} onmt/bin/train.py \
             -config ${DATA_DIR}/data.yaml \
             -src_vocab $TMP_OUT_DIR/onmt.vocab.src \
@@ -221,18 +195,18 @@ ${PYTHON} onmt/bin/train.py \
             -bucket_size 1024 \
             -train_steps 10 \
             -report_every 2 \
-            -train_eval_steps 8 -valid_steps 5 \
-            -train_metrics "BLEU" "TER" \
+            -valid_steps 5 \
             -valid_metrics "BLEU" "TER" \
             -tensorboard "true" \
             -scoring_debug "true" \
-            -dump_preds $TMP_OUT_DIR/dump_pred \
             -copy_attn \
-            -tensorboard_log_dir $TMP_OUT_DIR/logs_train_valid_metrics >> ${LOG_FILE} 2>&1
-${PYTHON} onmt/tests/test_events.py --logdir $TMP_OUT_DIR/logs_train_valid_metrics -tensorboard_checks train_valid_metrics
+            -dump_preds $TMP_OUT_DIR/dump_pred \
+            -tensorboard_log_dir $TMP_OUT_DIR/logs_dynamic-scoring_and_copy >> ${LOG_FILE} 2>&1
+
+${PYTHON} onmt/tests/test_events.py --logdir $TMP_OUT_DIR/logs_dynamic-scoring_and_copy -tensorboard_checks valid_metrics
 [ "$?" -eq 0 ] || error_exit
 echo "Succeeded" | tee -a ${LOG_FILE}
-rm -r $TMP_OUT_DIR/logs_train_valid_metrics
+rm -r $TMP_OUT_DIR/logs_dynamic-scoring_and_copy
 
 echo -n " [+] Testing LM training..."
 ${PYTHON} onmt/bin/train.py \
@@ -374,11 +348,11 @@ ${PYTHON} onmt/bin/train.py \
             -batch_size 10 \
             -word_vec_size 5 -hidden_size 10 \
             -num_workers 0 -bucket_size 1024 \
-            -report_every 5 -train_steps 10 \
-            -train_metrics "BLEU" "TER" \
+            -report_every 5 -train_steps 10 -valid_steps 5 \
             -valid_metrics "BLEU" "TER" \
             -save_model $TMP_OUT_DIR/onmt.features.model \
             -save_checkpoint_steps 10 >> ${LOG_FILE} 2>&1
+
 [ "$?" -eq 0 ] || error_exit
 echo "Succeeded" | tee -a ${LOG_FILE}
 rm -f $TMP_OUT_DIR/onmt.vocab*
diff --git a/onmt/tests/test_events.py b/onmt/tests/test_events.py
index e6db1c0f0e..219d82a53a 100644
--- a/onmt/tests/test_events.py
+++ b/onmt/tests/test_events.py
@@ -9,17 +9,10 @@ def __init__(self):
         metrics = ["BLEU", "TER"]
         self.scalars = {}
         self.scalars["train"] = [("progress/" + stat) for stat in stats]
-        self.scalars["train_valid"] = self.scalars["train"] + [
-            ("valid/" + stat) for stat in stats
+        self.scalars["valid"] = [("valid/" + stat) for stat in stats]
+        self.scalars["valid_metrics"] = self.scalars["valid"] + [
+            ("valid/" + metric) for metric in metrics
         ]
-        self.scalars["train_metrics"] = self.scalars["train"] + [
-            ("progress/" + metric) for metric in metrics
-        ]
-        self.scalars["train_valid_metrics"] = (
-            self.scalars["train_metrics"]
-            + [("valid/" + stat) for stat in stats]
-            + [("valid/" + metric) for metric in metrics]
-        )
 
     def reload_events(self, path):
         ea = event_accumulator.EventAccumulator(
@@ -49,7 +42,7 @@ def check_scalars(self, scalars, logdir):
         "--tensorboard_checks",
         type=str,
         required=True,
-        choices=["train", "train_metrics", "train_valid", "train_valid_metrics"],
+        choices=["train", "valid", "valid_metrics"],
     )
     args = parser.parse_args()
     test_event = TestEvents()
diff --git a/onmt/trainer.py b/onmt/trainer.py
index f5edc3f0ff..a69b908470 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -41,8 +41,6 @@ def build_trainer(opt, device_id, model, vocabs, optim, model_saver=None):
     validset_transforms = opt.data.get("valid", {}).get("transforms", None)
     if validset_transforms:
         scoring_preparator.warm_up(validset_transforms)
-    scorers_cls = get_scorers_cls(opt.train_metrics)
-    train_scorers = build_scorers(opt, scorers_cls)
     scorers_cls = get_scorers_cls(opt.valid_metrics)
     valid_scorers = build_scorers(opt, scorers_cls)
 
@@ -76,7 +74,6 @@ def build_trainer(opt, device_id, model, vocabs, optim, model_saver=None):
         train_loss,
         valid_loss,
         scoring_preparator,
-        train_scorers,
         valid_scorers,
         optim,
         trunc_size,
@@ -85,7 +82,6 @@ def build_trainer(opt, device_id, model, vocabs, optim, model_saver=None):
         accum_steps,
         n_gpu,
         gpu_rank,
-        opt.train_eval_steps,
         report_manager,
         with_align=True if opt.lambda_align > 0 else False,
         model_saver=model_saver if gpu_rank <= 0 else None,
@@ -111,9 +107,7 @@ class Trainer(object):
             training loss computation
         scoring_preparator(:obj:`onmt.translate.utils.ScoringPreparator`):
             preparator for the calculation of metrics via the
-            training_eval_handler method
-        train_scorers (dict): keeps in memory the current values
-            of the training metrics
+            _eval_handler method
         valid_scorers (dict): keeps in memory the current values
             of the validation metrics
         optim(:obj:`onmt.utils.optimizers.Optimizer`):
@@ -124,7 +118,6 @@
         accum_steps(list): steps for accum gradients changes.
         n_gpu (int): number of gpu.
         gpu_rank (int): ordinal rank of the gpu in the list.
-        train_eval_steps (int): process a validation every x steps.
        report_manager(:obj:`onmt.utils.ReportMgrBase`):
            the object that creates reports, or None
        with_align (bool): whether to jointly learn alignment
@@ -147,7 +140,6 @@ def __init__(
         train_loss,
         valid_loss,
         scoring_preparator,
-        train_scorers,
         valid_scorers,
         optim,
         trunc_size=0,
@@ -156,7 +148,6 @@
         accum_steps=[0],
         n_gpu=1,
         gpu_rank=1,
-        train_eval_steps=200,
         report_manager=None,
         with_align=False,
         model_saver=None,
@@ -175,7 +166,6 @@
         self.valid_loss = valid_loss
         self.scoring_preparator = scoring_preparator
-        self.train_scorers = train_scorers
         self.valid_scorers = valid_scorers
         self.optim = optim
         self.trunc_size = trunc_size
@@ -186,7 +176,6 @@
         self.n_gpu = n_gpu
         self.gpu_rank = gpu_rank
         self.report_manager = report_manager
-        self.train_eval_steps = train_eval_steps
         self.with_align = with_align
         self.model_saver = model_saver
         self.average_decay = average_decay
@@ -204,7 +193,7 @@
         # Set model in training mode.
         self.model.train()
 
-    def _training_eval_handler(self, scorer, preds, texts_ref):
+    def _eval_handler(self, scorer, preds, texts_ref):
        """Trigger metrics calculations
 
        Args:
@@ -386,7 +375,8 @@ def validate(self, valid_iter, moving_average=None):
         # Set model in validating mode.
         valid_model.eval()
-        transformed_batches = []
         with torch.no_grad():
             stats = onmt.utils.Statistics()
             start = time.time()
@@ -394,11 +384,7 @@
                 src = batch["src"]
                 src_len = batch["srclen"]
                 tgt = batch["tgt"]
-                if self.valid_scorers:
-                    transformed_batch = self.scoring_preparator.ids_to_tokens_batch(
-                        batch
-                    )
-                    transformed_batches.append(transformed_batch)
+
                 with torch.cuda.amp.autocast(enabled=self.optim.amp):
                     # F-prop through the model.
                    model_out, attns = valid_model(
@@ -410,7 +396,7 @@
                 stats.update(batch_stats)
 
             logger.info(
-                """valid stats calculation and sentences rebuilding
+                """valid stats calculation
                    took: {} s.""".format(
                     time.time() - start
                 )
             )
@@ -422,32 +408,32 @@
             start = time.time()
             preds, texts_ref = self.scoring_preparator.translate(
                 model=self.model,
-                transformed_batches=transformed_batches,
                 gpu_rank=self.gpu_rank,
                 step=self.optim.training_step,
-                mode="valid",
             )
             logger.info(
-                """The translation of the valid dataset
+                """The translation of the valid dataset for dynamic scoring
                    took: {} s.""".format(
                     time.time() - start
                 )
             )
-            for i, metric in enumerate(self.valid_scorers):
-                logger.info("UPDATING VALIDATION {}".format(metric))
-                self.valid_scorers[metric]["value"] = self._training_eval_handler(
-                    scorer=self.valid_scorers[metric]["scorer"],
-                    preds=preds,
-                    texts_ref=texts_ref,
-                )
-                computed_metrics[metric] = self.valid_scorers[metric]["value"]
-                logger.info(
-                    "validation {}: {}".format(
-                        metric, self.valid_scorers[metric]["value"]
+            for metric in self.valid_scorers:
+                logger.info("UPDATING VALIDATION {}".format(metric))
+                self.valid_scorers[metric]["value"] = self._eval_handler(
+                    scorer=self.valid_scorers[metric]["scorer"],
+                    preds=preds,
+                    texts_ref=texts_ref,
+                )
+                computed_metrics[metric] = self.valid_scorers[metric]["value"]
+                logger.info(
+                    "validation {}: {}".format(
+                        metric, self.valid_scorers[metric]["value"]
+                    )
+                )
+            # Compute stats
+            metric_stats = onmt.utils.Statistics(
+                0, 0, 0, 0, 0, computed_metrics
                     )
-                )
-            # Compute stats
-            metric_stats = onmt.utils.Statistics(0, 0, 0, 0, 0, computed_metrics)
 
             # Update statistics.
             stats.update(metric_stats)
@@ -514,38 +500,6 @@ def _gradient_accumulation(
                         trunc_size=trunc_size,
                     )
 
-                    step = self.optim.training_step
-                    if self.train_scorers != {} and step % self.train_eval_steps == 0:
-                        # Compute and save stats
-                        computed_metrics = {}
-                        transformed_batch = self.scoring_preparator.ids_to_tokens_batch(
-                            batch
-                        )
-                        preds, texts_ref = self.scoring_preparator.translate(
-                            model=self.model,
-                            transformed_batches=[transformed_batch],
-                            gpu_rank=self.gpu_rank,
-                            step=self.optim.training_step,
-                            mode="train",
-                        )
-                        for i, metric in enumerate(self.train_scorers):
-                            logger.info("UPDATING TRAINING {}".format(metric))
-                            self.train_scorers[metric][
-                                "value"
-                            ] = self._training_eval_handler(
-                                scorer=self.train_scorers[metric]["scorer"],
-                                preds=preds,
-                                texts_ref=texts_ref,
-                            )
-                            logger.info(
-                                "training {}: {}".format(
-                                    metric, self.train_scorers[metric]["value"]
-                                )
-                            )
-                            computed_metrics[metric] = self.train_scorers[metric][
-                                "value"
-                            ]
-                        batch_stats.computed_metrics = computed_metrics
                     if loss is not None:
                         loss /= normalization
                         self.optim.backward(loss)
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index fa601ee45f..fc891c6c8d 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -612,8 +612,8 @@ def _decode_and_generate(
                 batch_dim=0,
                 batch_offset=batch_offset,
             )
-            scores = scores.view(decoder_in.size(1), -1, scores.size(-1))
-            log_probs = scores.squeeze(0).log()
+            scores = scores.view(-1, decoder_in.size(1), scores.size(-1))
+            log_probs = scores.squeeze(1).log()
             # returns [(batch_size x beam_size) , vocab ] when 1 step
             # or [batch_size, tgt_len, vocab ] when full sentence
         return log_probs, attn
diff --git a/onmt/utils/scoring_utils.py b/onmt/utils/scoring_utils.py
index 9b924f28b1..01aa59c8c2 100644
--- a/onmt/utils/scoring_utils.py
+++ b/onmt/utils/scoring_utils.py
@@ -3,11 +3,10 @@
 from onmt.utils.parse import ArgumentParser
 from onmt.translate import GNMTGlobalScorer, Translator
 from onmt.opts import translate_opts
-from onmt.constants import DefaultTokens
-from onmt.inputters.text_utils import _addcopykeys, tensorify, text_sort_key
+from onmt.constants import CorpusTask
+from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
 from onmt.inputters.inputter import IterOnDevice
 from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe
-from itertools import repeat
 
 
 class ScoringPreparator:
@@ -32,87 +31,7 @@ def warm_up(self, transforms):
         transforms = make_transforms(self.opt, transforms_cls, self.vocabs)
         self.transform = TransformPipe.build_from(transforms.values())
 
-    def ids_to_tokens_batch_side(self, batch, side):
-        """Convert a batch into a list of transformed sentences
-        Args:
-            batch: batch yielded from `DynamicDatasetIter` object
-            side (string): 'src' or 'tgt'.
-        Returns
-            transformed_sentences (list): List of lists of tokens.
-                Each list is a transformed sentence.
-        """
-        vocab = self.vocabs[side]
-        batch_side = batch[side]
-        nb_sentences = batch_side.shape[0]
-        nb_tokens_per_sentence = batch_side.shape[1]
-        nb_feats = batch_side.shape[2] - 1
-
-        indices_to_remove = [
-            vocab.lookup_token(token)
-            for token in [DefaultTokens.PAD, DefaultTokens.EOS, DefaultTokens.BOS]
-        ]
-        transformed_sentences = []
-        for i in range(nb_sentences):
-            tokens = [
-                vocab.lookup_index(batch_side[i, t, 0])
-                for t in range(nb_tokens_per_sentence)
-                if batch_side[i, t, 0] not in indices_to_remove
-            ]
-            transformed_sentences.append(tokens)
-
-        if nb_feats > 0:
-            transformed_feats = []
-            for i_feat in range(nb_feats):
-                fv = self.vocabs["src_feats"][i_feat]
-                indices_to_remove = [
-                    fv.lookup_token(token)
-                    for token in [
-                        DefaultTokens.PAD,
-                        DefaultTokens.EOS,
-                        DefaultTokens.BOS,
-                    ]
-                ]
-                transformed_feat = []
-                for i in range(nb_sentences):
-                    tokens = [
-                        fv.lookup_index(batch_side[i, t, i_feat + 1])
-                        for t in range(nb_tokens_per_sentence)
-                        if batch_side[i, t, i_feat + 1] not in indices_to_remove
-                    ]
-                    transformed_feat.append(tokens)
-                transformed_feats.append(transformed_feat)
-        else:
-            transformed_feats = [repeat(None)]
-
-        return transformed_sentences, transformed_feats
-
-    def ids_to_tokens_batch(self, batch):
-        """Reconstruct transformed source and reference
-        sentences from a batch.
-        Args:
-            batch: batch yielded from `DynamicDatasetIter` object
-        Returns:
-            transformed_batch(list): A list of examples
-                with the fields "src" and "tgt"
-        """
-
-        transformed_srcs, transformed_src_feats = self.ids_to_tokens_batch_side(
-            batch, "src"
-        )
-        transformed_tgts, _ = self.ids_to_tokens_batch_side(batch, "tgt")
-
-        transformed_batch = []
-        for src, tgt, *src_feats in zip(
-            transformed_srcs, transformed_tgts, *transformed_src_feats
-        ):
-            ex = {"src": src, "tgt": tgt}
-            if src_feats[0] is not None:
-                ex["src_feats"] = src_feats
-            transformed_batch.append(ex)
-
-        return transformed_batch
-
-    def translate(self, model, transformed_batches, gpu_rank, step, mode):
+    def translate(self, model, gpu_rank, step):
        """Compute and save the sentences predicted by the
        current model's state on the validation dataset.
@@ -127,17 +46,24 @@ def translate(self, model, transformed_batches, gpu_rank, step, mode):
             preds (list): Detokenized predictions
             texts_ref (list): Detokenized target sentences
         """
-        model_opt = self.opt
+        # ########## #
+        # Translator #
+        # ########## #
+
+        # Set translation options
         parser = ArgumentParser()
         translate_opts(parser)
         base_args = ["-model", "dummy"] + ["-src", "dummy"]
         opt = parser.parse_args(base_args)
         opt.gpu = gpu_rank
         ArgumentParser.validate_translate_opts(opt)
-        ArgumentParser.update_model_opts(model_opt)
-        ArgumentParser.validate_model_opts(model_opt)
+
+        # Build translator from options
         scorer = GNMTGlobalScorer.from_opt(opt)
         out_file = codecs.open(os.devnull, "w", "utf-8")
+        model_opt = self.opt
+        ArgumentParser.update_model_opts(model_opt)
+        ArgumentParser.validate_model_opts(model_opt)
         translator = Translator.from_opt(
             model,
             self.vocabs,
@@ -149,64 +75,61 @@
             report_score=False,
             logger=None,
         )
-        # translate
-        preds = []
-        raw_sources = []
-        raw_refs = []
-        for batch in transformed_batches:
-            # for validation we build an infer_iter per batch
-            # in order to avoid oom issues because there is no
-            # batching strategy in `textbatch_to_tensor`
-            numeric = []
-            for i, ex in enumerate(batch):
-                if ex is not None:
-                    raw_sources.append(ex["src"])
-                    raw_refs.append(ex["tgt"])
-                    if isinstance(ex["src"], bytes):
-                        ex["src"] = ex["src"].decode("utf-8")
-                    idxs = translator.vocabs["src"](ex["src"])
-                    num_ex = {
-                        "src": {"src": " ".join(ex["src"]), "src_ids": idxs},
-                        "srclen": len(ex["src"]),
-                        "tgt": None,
-                        "indices": i,
-                        "align": None,
-                    }
-                    if "src_feats" in ex:
-                        fs_idxs = [
-                            fv(f)
-                            for fv, f in zip(
-                                translator.vocabs["src_feats"], ex["src_feats"]
-                            )
-                        ]
-                        num_ex["src"]["feats"] = fs_idxs
-                    num_ex = _addcopykeys(translator.vocabs["src"], num_ex)
-                    num_ex["src"]["src"] = ex["src"]
-                    numeric.append(num_ex)
-            numeric.sort(key=text_sort_key, reverse=True)
-            infer_iter = [tensorify(self.vocabs, numeric)]
-            infer_iter = IterOnDevice(infer_iter, opt.gpu)
-            _, preds_ = translator._translate(infer_iter, transform=self.transform)
-            preds += preds_
-
-        # apply_reverse refs
-        if self.transforms:
-            texts_ref = self.transform.batch_apply_reverse(raw_refs)
-
-            # flatten preds
-            preds = [item for preds_ in preds for item in preds_]
-        else:
-            texts_ref = [" ".join(raw_ref) for raw_ref in raw_refs]
-            preds = [" ".join(preds_) for preds_ in preds]
-
-        # save results
+
+        # ################### #
+        # Validation iterator #
+        # ################### #
+
+        # Reinstantiate the validation iterator
+
+        transforms_cls = get_transforms_cls(model_opt._all_transform)
+        model_opt.num_workers = 0
+        model_opt.tgt = None
+
+        valid_iter = build_dynamic_dataset_iter(
+            model_opt,
+            transforms_cls,
+            translator.vocabs,
+            task=CorpusTask.VALID,
+            copy=model_opt.copy_attn,
+        )
+
+        # Retrieve raw references and sources
+        with codecs.open(
+            valid_iter.corpora_info["valid"]["path_tgt"], "r", encoding="utf-8"
+        ) as f:
+            raw_refs = [line.strip("\n") for line in f if line.strip("\n")]
+        with codecs.open(
+            valid_iter.corpora_info["valid"]["path_src"], "r", encoding="utf-8"
+        ) as f:
+            raw_srcs = [line.strip("\n") for line in f if line.strip("\n")]
+
+        valid_iter = IterOnDevice(valid_iter, opt.gpu)
+
+        # ########### #
+        # Predictions #
+        # ########### #
+
+        _, preds = translator._translate(
+            valid_iter,
+            transform=valid_iter.transform,
+            attn_debug=opt.attn_debug,
+            align_debug=opt.align_debug,
+        )
+
+        # ####### #
+        # Outputs #
+        # ####### #
+
+        # Flatten predictions
+        preds = [x.lstrip() for sublist in preds for x in sublist]
+
+        # Save results
         if len(preds) > 0 and self.opt.scoring_debug:
-            path = os.path.join(
-                self.opt.dump_preds, "preds.{}_step_{}.{}".format(mode, step, "txt")
-            )
+            path = os.path.join(self.opt.dump_preds, f"preds.valid_step_{step}.txt")
             with open(path, "a") as file:
                 for i in range(len(preds)):
-                    file.write("SOURCE: {}\n".format(raw_sources[i]))
-                    file.write("REF: {}\n".format(texts_ref[i]))
+                    file.write("SOURCE: {}\n".format(raw_srcs[i]))
+                    file.write("REF: {}\n".format(raw_refs[i]))
                     file.write("PRED: {}\n\n".format(preds[i]))
-        return preds, texts_ref
+        return preds, raw_refs
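
For reviewers, here is a condensed sketch of the control flow that the patched `ScoringPreparator.translate()` now follows. It only rearranges calls that this patch itself introduces in `onmt/utils/scoring_utils.py` (`build_dynamic_dataset_iter`, `CorpusTask.VALID`, `IterOnDevice`, `corpora_info`); the `build_valid_iter` helper name is hypothetical, and running it assumes an OpenNMT-py checkout with a configured `valid` corpus.

```python
# Hypothetical helper mirroring the patched ScoringPreparator.translate();
# the imports below are the ones this patch adds to scoring_utils.py.
from onmt.constants import CorpusTask
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.inputters.inputter import IterOnDevice
from onmt.transforms import get_transforms_cls


def build_valid_iter(model_opt, vocabs, gpu_rank):
    """Rebuild the validation iterator the way the patched translate() does."""
    transforms_cls = get_transforms_cls(model_opt._all_transform)
    model_opt.num_workers = 0  # no worker processes inside the training loop
    model_opt.tgt = None       # predictions are scored against the raw refs
    valid_iter = build_dynamic_dataset_iter(
        model_opt,
        transforms_cls,
        vocabs,
        task=CorpusTask.VALID,
        copy=model_opt.copy_attn,
    )
    # Raw sources/references come straight from the corpus files instead of
    # being rebuilt token-by-token from tensor batches, as before this patch.
    src_path = valid_iter.corpora_info["valid"]["path_src"]
    tgt_path = valid_iter.corpora_info["valid"]["path_tgt"]
    return IterOnDevice(valid_iter, gpu_rank), src_path, tgt_path
```

The design point is that validation scoring reuses the same dynamic-iterator machinery as training and inference, which is why the per-batch `ids_to_tokens_batch*` reconstruction code and the `train_metrics`/`train_eval_steps` options can be deleted.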
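Downstream, `Trainer.validate()` feeds the returned `(preds, raw_refs)` pair through `_eval_handler` for each entry of `valid_scorers`. A minimal standalone approximation of that loop is sketched below, calling sacrebleu directly; the inline lambdas are stand-ins for onmt's scorer objects, not the actual `onmt.scorers` API, and the toy sentences are illustrative only.

```python
# Standalone approximation of the valid_scorers update loop in
# Trainer.validate(); sacrebleu is used as a stand-in for onmt's
# BLEU/TER scorer classes.
import sacrebleu


def eval_handler(scorer, preds, texts_ref):
    # Mirrors Trainer._eval_handler: one scorer, hypotheses, references.
    return scorer(preds, texts_ref)


valid_scorers = {
    "BLEU": {"scorer": lambda h, r: sacrebleu.corpus_bleu(h, [r]).score},
    "TER": {"scorer": lambda h, r: sacrebleu.corpus_ter(h, [r]).score},
}

preds = ["the cat sat on a mat"]       # toy predictions
raw_refs = ["the cat sat on the mat"]  # toy references

computed_metrics = {}
for metric in valid_scorers:
    valid_scorers[metric]["value"] = eval_handler(
        valid_scorers[metric]["scorer"], preds, raw_refs
    )
    computed_metrics[metric] = valid_scorers[metric]["value"]

print(computed_metrics)  # e.g. {'BLEU': ..., 'TER': ...}
```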