Skip to content

Commit

Permalink
Bug fix: Make Acquisition.optimize work with discrete optimizer reg…
Browse files Browse the repository at this point in the history
…ardless of whether `raw_samples` is in `optimizer_options` (#2772)

Summary:
Pull Request resolved: #2772

Context: When `raw_samples` is present in `optimizer_options` and the discrete optimizer is used, `raw_samples` is removed from `optimizer_options` because it isn't appropriate for for the discrete optimizer. However, `raw_samples` can't be removed if it isn't present. Not passing it (which is better) causes an error.

This PR: Doesn't try to remove `raw_samples` when it isn't present.

Reviewed By: saitcakmak

Differential Revision: D63035021

fbshipit-source-id: 7ad947ede4e944ff2dfea816afd67c88e3f3f496
  • Loading branch information
esantorella authored and facebook-github-bot committed Sep 19, 2024
1 parent b1b7678 commit 3de05de
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ax/models/torch/botorch_modular/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def optimize(
# TODO[santorella]: Rather than manually removing it, we should
# ensure that it is never passed.
if optimizer_options is not None:
optimizer_options.pop("raw_samples")
optimizer_options.pop("raw_samples", None)
else:
optimizer = "optimize_acqf_mixed"

Expand Down
16 changes: 14 additions & 2 deletions ax/models/torch/tests/test_acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,11 +430,23 @@ def test_optimize_discrete(self) -> None:
rounding_func=self.rounding_func,
)

acquisition = self.get_acquisition_function()
n = 2

# Also check that it runs when optimizer options are provided, whether
# `raw_samples` are present or not.
for optimizer_options in [None, {"raw_samples": 8}, {}]:
with self.subTest(optimizer_options=optimizer_options):
acquisition.optimize(
n=n,
search_space_digest=ssd1,
rounding_func=self.rounding_func,
optimizer_options=optimizer_options,
)

# check this works without any fixed_feature specified
# 2 candidates have acqf value 8, but [1, 3, 4] is pending and thus should
# not be selected. [2, 3, 4] is the best point, but has already been picked
acquisition = self.get_acquisition_function()
n = 2
with mock.patch(
f"{ACQUISITION_PATH}.optimizer_argparse", wraps=optimizer_argparse
) as mock_optimizer_argparse, mock.patch(
Expand Down

0 comments on commit 3de05de

Please sign in to comment.