Skip to content

Commit

Permalink
fix #2329 (#2544)
Browse files Browse the repository at this point in the history
  • Loading branch information
vince62s committed Jan 3, 2024
1 parent acb76c3 commit 8bbd18a
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 13 deletions.
2 changes: 2 additions & 0 deletions onmt/decoders/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,8 @@ def _init_cache(self, enc_out):
)
if hasattr(layer.self_attn, "rope"):
layer.self_attn.rope = layer.self_attn.rope.to(enc_out.device)
layer.self_attn.cos = layer.self_attn.cos.to(enc_out.device)
layer.self_attn.sin = layer.self_attn.sin.to(enc_out.device)


class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
Expand Down
2 changes: 1 addition & 1 deletion onmt/inputters/text_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def get_corpora(opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None):
corpora_dict[CorpusName.VALID] = ParallelCorpus(
CorpusName.VALID,
opts.data[CorpusName.VALID]["path_src"],
opts.data[CorpusName.VALID]["path_tgt"],
opts.data[CorpusName.VALID]["path_tgt"] if tgt is None else None,
opts.data[CorpusName.VALID]["path_align"],
n_src_feats=opts.n_src_feats,
src_feats_defaults=opts.src_feats_defaults,
Expand Down
1 change: 1 addition & 0 deletions onmt/train_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def prepare_transforms_vocabs(opt, transforms_cls):
opt.transforms = validset_transforms
if opt.data.get("valid", {}).get("tgt_prefix", None):
opt.tgt_prefix = opt.data.get("valid", {}).get("tgt_prefix", None)
opt.tgt_file_prefix = True
if opt.data.get("valid", {}).get("src_prefix", None):
opt.src_prefix = opt.data.get("valid", {}).get("src_prefix", None)
if opt.data.get("valid", {}).get("tgt_suffix", None):
Expand Down
29 changes: 17 additions & 12 deletions onmt/utils/scoring_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from onmt.utils.parse import ArgumentParser
from onmt.translate import GNMTGlobalScorer, Translator
from onmt.opts import translate_opts
from onmt.opts import config_opts, translate_opts
from onmt.constants import CorpusTask
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe
Expand Down Expand Up @@ -49,12 +49,16 @@ def translate(self, model, gpu_rank, step):
# Translator #
# ########## #

# Set translation options
# Set "default" translation options on empty cfgfile
parser = ArgumentParser()
config_opts(parser)
translate_opts(parser)
base_args = ["-model", "dummy"] + ["-src", "dummy"]
opt = parser.parse_args(base_args)
opt.gpu = gpu_rank
if hasattr(self.opt, "tgt_file_prefix"):
opt.tgt_file_prefix = self.opt.tgt_file_prefix
opt.beam_size = 1 # prevent OOM when GPU is almost full at training
ArgumentParser.validate_translate_opts(opt)

# Build translator from options
Expand Down Expand Up @@ -85,25 +89,26 @@ def translate(self, model, gpu_rank, step):
model_opt.num_workers = 0
model_opt.tgt = None

# Retrieve raw references and sources
with codecs.open(
model_opt.data["valid"]["path_tgt"], "r", encoding="utf-8"
) as f:
raw_refs = [line.strip("\n") for line in f if line.strip("\n")]
with codecs.open(
model_opt.data["valid"]["path_src"], "r", encoding="utf-8"
) as f:
raw_srcs = [line.strip("\n") for line in f if line.strip("\n")]

valid_iter = build_dynamic_dataset_iter(
model_opt,
transforms_cls,
translator.vocabs,
task=CorpusTask.VALID,
tgt="", # This forces the target side to be cleared (needed when using tgt_file_prefix)
copy=model_opt.copy_attn,
device_id=opt.gpu,
)

# Retrieve raw references and sources
with codecs.open(
valid_iter.corpora_info["valid"]["path_tgt"], "r", encoding="utf-8"
) as f:
raw_refs = [line.strip("\n") for line in f if line.strip("\n")]
with codecs.open(
valid_iter.corpora_info["valid"]["path_src"], "r", encoding="utf-8"
) as f:
raw_srcs = [line.strip("\n") for line in f if line.strip("\n")]

# ########### #
# Predictions #
# ########### #
Expand Down

0 comments on commit 8bbd18a

Please sign in to comment.