Commit 176bafa: Merge branch 'OpenNMT-master'

Raghavendra Sugeeth P S committed Sep 3, 2020
2 parents: 43c3869 + d9faa98

Showing 6 changed files with 50 additions and 30 deletions.
onmt/inputters/image_dataset.py: 3 changes (2 additions & 1 deletion)
@@ -75,7 +75,8 @@ def read(self, images, side, img_dir=None):
                 img = transforms.ToTensor()(
                     Image.fromarray(cv2.imread(img_path, 0)))
             else:
-                img = transforms.ToTensor()(Image.open(img_path))
+                img = Image.open(img_path).convert('RGB')
+                img = transforms.ToTensor()(img)
             if self.truncate and self.truncate != (0, 0):
                 if not (img.size(1) <= self.truncate[0]
                         and img.size(2) <= self.truncate[1]):
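For context on the image_dataset.py change: Image.open can return grayscale ('L'), palette ('P'), or RGBA images, and ToTensor then yields 1- or 4-channel tensors that break a model expecting 3 channels; convert('RGB') normalizes the channel count first. A standalone sketch of the difference (the file name is hypothetical; assumes Pillow and torchvision are installed):

    from PIL import Image
    from torchvision import transforms

    raw = Image.open('sample.png')            # mode may be 'L', 'P', or 'RGBA'
    as_is = transforms.ToTensor()(raw)        # channel count follows the mode
    rgb = Image.open('sample.png').convert('RGB')
    fixed = transforms.ToTensor()(rgb)        # always shape (3, H, W)
    print(as_is.shape, fixed.shape)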
onmt/tests/test_beam_search.py: 16 changes (7 additions & 9 deletions)
@@ -61,21 +61,18 @@ def test_advance_with_all_repeats_gets_blocked(self):
                 # (but it's still the best score, thus we have
                 # [BLOCKED_SCORE, -inf, -inf, -inf, -inf]
                 expected_scores = torch.tensor(
-                    [0] + [-float('inf')] * (beam_sz - 1))\
-                    .repeat(batch_sz, 1)
-                expected_scores[:, 0] = self.BLOCKED_SCORE
+                    [self.BLOCKED_SCORE] + [-float('inf')] * (beam_sz - 1)
+                ).repeat(batch_sz, 1)
                 self.assertTrue(beam.topk_log_probs.equal(expected_scores))
             else:
                 # repetitions keeps maximizing score
                 # index 0 has been blocked, so repeating=>+0.0 score
                 # other indexes are -inf so repeating=>BLOCKED_SCORE
                 # which is higher
                 expected_scores = torch.tensor(
-                    [0] + [-float('inf')] * (beam_sz - 1))\
-                    .repeat(batch_sz, 1)
-                expected_scores[:, :] = self.BLOCKED_SCORE
-                expected_scores = torch.tensor(
-                    self.BLOCKED_SCORE).repeat(batch_sz, beam_sz)
+                    [self.BLOCKED_SCORE] + [-float('inf')] * (beam_sz - 1)
+                ).repeat(batch_sz, 1)
+                self.assertTrue(beam.topk_log_probs.equal(expected_scores))

     def test_advance_with_some_repeats_gets_blocked(self):
         # beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)

@@ -137,7 +134,8 @@ def test_advance_with_some_repeats_gets_blocked(self):

                 expected = torch.full([batch_sz, beam_sz], float("-inf"))
                 expected[:, 0] = no_repeat_score
-                expected[:, 1:] = self.BLOCKED_SCORE
+                expected[:, 1:3] = self.BLOCKED_SCORE
+                expected[:, 3:] = float("-inf")
                 self.assertTrue(
                     beam.topk_log_probs.equal(expected))

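Both tests above exercise ngram repeat blocking: a beam path that keeps emitting a banned ngram gets its log-probability overwritten with self.BLOCKED_SCORE, a large negative constant, so it can never win the beam again. A minimal sketch of that masking idea (the constant's value here is illustrative, and the real BeamSearch bookkeeping is more involved):

    import torch

    BLOCKED_SCORE = -1e20  # illustrative; stands in for the tests' constant

    def block_repeating_paths(topk_log_probs, repeats_mask):
        # topk_log_probs: (batch, beam) scores; repeats_mask: bool, same shape
        topk_log_probs[repeats_mask] = BLOCKED_SCORE
        return topk_log_probs

    scores = torch.tensor([[0.0, -float('inf'), -float('inf')]])
    mask = torch.tensor([[True, False, False]])
    print(block_repeating_paths(scores, mask))
    # row 0 is forced to BLOCKED_SCORE, matching the tests' expected tensors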
onmt/trainer.py: 25 changes (16 additions & 9 deletions)
@@ -204,7 +204,7 @@ def _update_average(self, step):
                 self.moving_average = copy_params
             else:
                 average_decay = max(self.average_decay,
-                                    1 - (step + 1)/(step + 10))
+                                    1 - (step + 1) / (step + 10))
                 for (i, avg), cpt in zip(enumerate(self.moving_average),
                                          self.model.parameters()):
                     self.moving_average[i] = \
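The line fixed above (whitespace only) belongs to a warm-up schedule for the parameter moving average: assuming the truncated loop below it weights the current parameters by this decay (the rest of the assignment is cut off in the hunk), the effective decay starts at 0.9 and falls until it reaches the configured self.average_decay floor. A quick sketch with a hypothetical floor of 0.1:

    def effective_decay(step, average_decay):
        # same arithmetic as the line touched above
        return max(average_decay, 1 - (step + 1) / (step + 10))

    for step in (0, 10, 100, 1000):
        print(step, round(effective_decay(step, average_decay=0.1), 4))
    # 0 -> 0.9, 10 -> 0.45, 100 -> 0.1, 1000 -> 0.1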
@@ -306,10 +306,9 @@ def train(self,
                     break

                 if (self.model_saver is not None
-                    and (save_checkpoint_steps != 0
-                         and step % save_checkpoint_steps == 0)):
-                    self.model_saver.save(step, is_best,
-                                          moving_average=self.moving_average)
+                        and (save_checkpoint_steps != 0
+                             and step % save_checkpoint_steps == 0)):
+                    self.model_saver.save(step, moving_average=self.moving_average)

                 if train_steps > 0 and step >= train_steps:
                     break
@@ -344,7 +343,7 @@ def validate(self, valid_iter, moving_average=None):

             for batch in valid_iter:
                 src, src_lengths = batch.src if isinstance(batch.src, tuple) \
-                    else (batch.src, None)
+                                   else (batch.src, None)
                 tgt = batch.tgt

                 with torch.cuda.amp.autocast(enabled=self.optim.amp):
@@ -390,7 +389,7 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
             tgt_outer = batch.tgt

             bptt = False
-            for j in range(0, target_size-1, trunc_size):
+            for j in range(0, target_size - 1, trunc_size):
                 # 1. Create truncated target.
                 tgt = tgt_outer[j: j + trunc_size]

@@ -488,7 +487,12 @@ def _maybe_report_training(self, step, num_steps, learning_rate,
         """
         if self.report_manager is not None:
             return self.report_manager.report_training(
-                step, num_steps, learning_rate, report_stats,
+                step,
+                num_steps,
+                learning_rate,
+                None if self.earlystopper is None
+                else self.earlystopper.current_tolerance,
+                report_stats,
                 multigpu=self.n_gpu > 1)

     def _report_step(self, learning_rate, step, train_stats=None,

@@ -499,7 +503,10 @@ def _report_step(self, learning_rate, step, train_stats=None,
         """
         if self.report_manager is not None:
             return self.report_manager.report_step(
-                learning_rate, step, train_stats=train_stats,
+                learning_rate,
+                None if self.earlystopper is None
+                else self.earlystopper.current_tolerance,
+                step, train_stats=train_stats,
                 valid_stats=valid_stats)

     def maybe_noise_source(self, batch):
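Both reporting hooks above now thread the early-stopping tolerance through as a patience value without assuming an earlystopper exists. A compact sketch of that None-guard pattern (the stub class is hypothetical; it only mirrors the current_tolerance attribute used above):

    class EarlyStopperStub:
        # hypothetical stand-in exposing current_tolerance
        def __init__(self, tolerance):
            self.current_tolerance = tolerance

    for earlystopper in (None, EarlyStopperStub(tolerance=4)):
        patience = (None if earlystopper is None
                    else earlystopper.current_tolerance)
        print(patience)  # None without an earlystopper, else the tolerance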
onmt/translate/decode_strategy.py: 3 changes (2 additions & 1 deletion)
@@ -1,4 +1,5 @@
 import torch
+from copy import deepcopy


 class DecodeStrategy(object):
@@ -184,7 +185,7 @@ def maybe_update_forbidden_tokens(self):
             # Reordering forbidden_tokens following beam selection
             # We rebuild a dict to ensure we get the value and not the pointer
             forbidden_tokens.append(
-                dict(self.forbidden_tokens[path_idx]))
+                deepcopy(self.forbidden_tokens[path_idx]))

             # Grabing the newly selected tokens and associated ngram
             current_ngram = tuple(seq[-n:].tolist())
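The decode_strategy.py change matters because forbidden_tokens maps ngram prefixes to sets of banned tokens (as the surrounding code suggests): dict() copies only the outer mapping, so duplicated beam paths would keep sharing, and mutating, the same inner sets. A small sketch of the failure mode (the literal keys and values are made up):

    from copy import deepcopy

    original = {('the', 'cat'): {3, 7}}
    shallow = dict(original)   # new outer dict, same inner set object
    deep = deepcopy(original)  # fully independent copy

    shallow[('the', 'cat')].add(42)
    print(original[('the', 'cat')])  # now contains 42: the shallow copy leaked
    print(deep[('the', 'cat')])      # still {3, 7}: unaffected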
onmt/utils/report_manager.py: 29 changes (20 additions & 9 deletions)
@@ -49,7 +49,7 @@ def start(self):
     def log(self, *args, **kwargs):
         logger.info(*args, **kwargs)

-    def report_training(self, step, num_steps, learning_rate,
+    def report_training(self, step, num_steps, learning_rate, patience,
                         report_stats, multigpu=False):
         """
         This is the user-defined batch-level traing progress
@@ -72,7 +72,7 @@ def report_training(self, step, num_steps, learning_rate,
             report_stats = \
                 onmt.utils.Statistics.all_gather_stats(report_stats)
             self._report_training(
-                step, num_steps, learning_rate, report_stats)
+                step, num_steps, learning_rate, patience, report_stats)
             return onmt.utils.Statistics()
         else:
             return report_stats
@@ -81,17 +81,22 @@ def _report_training(self, *args, **kwargs):
         """ To be overridden """
         raise NotImplementedError()

-    def report_step(self, lr, step, train_stats=None, valid_stats=None):
+    def report_step(self, lr, patience, step, train_stats=None,
+                    valid_stats=None):
         """
         Report stats of a step
         Args:
+            lr(float): current learning rate
+            patience(int): current patience
             step(int): current step
             train_stats(Statistics): training stats
             valid_stats(Statistics): validation stats
-            lr(float): current learning rate
         """
         self._report_step(
-            lr, step, train_stats=train_stats, valid_stats=valid_stats)
+            lr, patience, step,
+            train_stats=train_stats,
+            valid_stats=valid_stats)

     def _report_step(self, *args, **kwargs):
         raise NotImplementedError()
@@ -111,12 +116,13 @@ def __init__(self, report_every, start_time=-1., tensorboard_writer=None):
         super(ReportMgr, self).__init__(report_every, start_time)
         self.tensorboard_writer = tensorboard_writer

-    def maybe_log_tensorboard(self, stats, prefix, learning_rate, step):
+    def maybe_log_tensorboard(self, stats, prefix, learning_rate,
+                              patience, step):
         if self.tensorboard_writer is not None:
             stats.log_tensorboard(
-                prefix, self.tensorboard_writer, learning_rate, step)
+                prefix, self.tensorboard_writer, learning_rate, patience, step)

-    def _report_training(self, step, num_steps, learning_rate,
+    def _report_training(self, step, num_steps, learning_rate, patience,
                          report_stats):
         """
         See base class method `ReportMgrBase.report_training`.
@@ -127,12 +133,15 @@ def _report_training(self, step, num_steps, learning_rate,
             self.maybe_log_tensorboard(report_stats,
                                        "progress",
                                        learning_rate,
+                                       patience,
                                        step)
             report_stats = onmt.utils.Statistics()

         return report_stats

-    def _report_step(self, lr, step, train_stats=None, valid_stats=None):
+    def _report_step(self, lr, patience, step,
+                     train_stats=None,
+                     valid_stats=None):
         """
         See base class method `ReportMgrBase.report_step`.
         """
@@ -143,6 +152,7 @@ def _report_step(self, lr, step, train_stats=None, valid_stats=None):
             self.maybe_log_tensorboard(train_stats,
                                        "train",
                                        lr,
+                                       patience,
                                        step)

if valid_stats is not None:
Expand All @@ -152,4 +162,5 @@ def _report_step(self, lr, step, train_stats=None, valid_stats=None):
self.maybe_log_tensorboard(valid_stats,
"valid",
lr,
patience,
step)
onmt/utils/statistics.py: 4 changes (3 additions & 1 deletion)
@@ -126,11 +126,13 @@ def output(self, step, num_steps, learning_rate, start):
                    time.time() - start))
         sys.stdout.flush()

-    def log_tensorboard(self, prefix, writer, learning_rate, step):
+    def log_tensorboard(self, prefix, writer, learning_rate, patience, step):
         """ display statistics to tensorboard """
         t = self.elapsed_time()
         writer.add_scalar(prefix + "/xent", self.xent(), step)
         writer.add_scalar(prefix + "/ppl", self.ppl(), step)
         writer.add_scalar(prefix + "/accuracy", self.accuracy(), step)
         writer.add_scalar(prefix + "/tgtper", self.n_words / t, step)
         writer.add_scalar(prefix + "/lr", learning_rate, step)
+        if patience is not None:
+            writer.add_scalar(prefix + "/patience", patience, step)
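End to end, the patience value now flows from the trainer through the report manager into Statistics.log_tensorboard. A short sketch of the final hop, mirroring the guarded add_scalar call above with the standard torch.utils.tensorboard writer (the log directory is made up):

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('runs/patience-demo')  # hypothetical log dir

    def log_patience(prefix, writer, patience, step):
        # same guard as above: skip the scalar when no earlystopper is set
        if patience is not None:
            writer.add_scalar(prefix + "/patience", patience, step)

    log_patience("valid", writer, patience=3, step=500)
    log_patience("valid", writer, patience=None, step=600)  # logs nothing
    writer.close()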
