Skip to content

Commit

Permalink
fixed return
Browse files Browse the repository at this point in the history
  • Loading branch information
Löffler, Hannes committed Mar 20, 2024
1 parent 49f4af9 commit 4913d33
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion reinvent/runmodes/RL/reports/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def write_report(reporter, data: TBData) -> None:

sample_size = ROWS * COLUMNS

image_tensor = make_grid_image(data.smilies, results.total_scores, "score", sample_size, ROWS)
image_tensor, _ = make_grid_image(data.smilies, results.total_scores, "score", sample_size, ROWS)

if image_tensor is not None:
reporter.add_image(
Expand Down
4 changes: 2 additions & 2 deletions reinvent/runmodes/samplers/reports/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setup_TBData(sampled: SampleBatch, time: int, **kwargs):
counter = collections.Counter(sampled.smilies)
top_sampled = counter.most_common(ROWS * COLS)
top_sampled_smiles, top_sampled_freq = zip(*top_sampled)
image_tensor = make_grid_image(top_sampled_smiles, top_sampled_freq, "Times sampled", ROWS * COLS, ROWS)
image_tensor, _ = make_grid_image(top_sampled_smiles, top_sampled_freq, "Times sampled", ROWS * COLS, ROWS)

return TBData(fraction_valid_smiles,
fraction_unique_molecules,
Expand Down Expand Up @@ -94,4 +94,4 @@ def _report_scatter(data: TBData, reporter):
if x_key in data.additional_report.keys() and y_key in data.additional_report.keys():
x, y = data.additional_report[x_key], data.additional_report[y_key]
figure = plot_scatter(x, y, xlabel=xlabel, ylabel=ylabel, title=f'{len(x)} Unique')
reporter.add_figure(f'{xlabel}_{ylabel}', figure)
reporter.add_figure(f'{xlabel}_{ylabel}', figure)

0 comments on commit 4913d33

Please sign in to comment.