From 3de05deb2faefd38d62e99d65d69aeebb63a10a3 Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Thu, 19 Sep 2024 15:52:36 -0700 Subject: [PATCH] Bug fix: Make `Acquisition.optimize` work with discrete optimizer regardless of whether `raw_samples` is in `optimizer_options` (#2772) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2772 Context: When `raw_samples` is present in `optimizer_options` and the discrete optimizer is used, `raw_samples` is removed from `optimizer_options` because it isn't appropriate for for the discrete optimizer. However, `raw_samples` can't be removed if it isn't present. Not passing it (which is better) causes an error. This PR: Doesn't try to remove `raw_samples` when it isn't present. Reviewed By: saitcakmak Differential Revision: D63035021 fbshipit-source-id: 7ad947ede4e944ff2dfea816afd67c88e3f3f496 --- ax/models/torch/botorch_modular/acquisition.py | 2 +- ax/models/torch/tests/test_acquisition.py | 16 ++++++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/ax/models/torch/botorch_modular/acquisition.py b/ax/models/torch/botorch_modular/acquisition.py index 2ed51810000..c96801528e6 100644 --- a/ax/models/torch/botorch_modular/acquisition.py +++ b/ax/models/torch/botorch_modular/acquisition.py @@ -452,7 +452,7 @@ def optimize( # TODO[santorella]: Rather than manually removing it, we should # ensure that it is never passed. if optimizer_options is not None: - optimizer_options.pop("raw_samples") + optimizer_options.pop("raw_samples", None) else: optimizer = "optimize_acqf_mixed" diff --git a/ax/models/torch/tests/test_acquisition.py b/ax/models/torch/tests/test_acquisition.py index 46e97d35801..aa385916b43 100644 --- a/ax/models/torch/tests/test_acquisition.py +++ b/ax/models/torch/tests/test_acquisition.py @@ -430,11 +430,23 @@ def test_optimize_discrete(self) -> None: rounding_func=self.rounding_func, ) + acquisition = self.get_acquisition_function() + n = 2 + + # Also check that it runs when optimizer options are provided, whether + # `raw_samples` are present or not. + for optimizer_options in [None, {"raw_samples": 8}, {}]: + with self.subTest(optimizer_options=optimizer_options): + acquisition.optimize( + n=n, + search_space_digest=ssd1, + rounding_func=self.rounding_func, + optimizer_options=optimizer_options, + ) + # check this works without any fixed_feature specified # 2 candidates have acqf value 8, but [1, 3, 4] is pending and thus should # not be selected. [2, 3, 4] is the best point, but has already been picked - acquisition = self.get_acquisition_function() - n = 2 with mock.patch( f"{ACQUISITION_PATH}.optimizer_argparse", wraps=optimizer_argparse ) as mock_optimizer_argparse, mock.patch(