Skip to content

Commit

Permalink
Don't test with a test stub when there's a function that does the same thing (#2735)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: #2735

Context:
* There is a `get_sobol_benchmark_method` test stub, which is not needed when a Sobol benchmark method is provided, also called `get_sobol_benchmark_method`. It is better to test that function.
* The latter `get_sobol_benchmark_method` requires an argument `distribute_replications`. (Making this mandatory was an intentional choice, because it is easy to forget it.)

This diff:
* Gets rid of the test stub and uses the non-stub version instead
* Adds `distribute_replications` in a bunch of places. I chose `False` arbitrarily since the argument will have no effect here.

Reviewed By: saitcakmak

Differential Revision: D62157106

fbshipit-source-id: 9d6ef4e609502fc94d09be31aa31b1dd7325b111
  • Loading branch information
esantorella authored and facebook-github-bot committed Sep 4, 2024
1 parent 6b1b0a4 commit ac3a7ec
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 31 deletions.
31 changes: 13 additions & 18 deletions ax/benchmark/tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ax.benchmark.benchmark_problem import create_problem_from_botorch
from ax.benchmark.benchmark_result import BenchmarkResult
from ax.benchmark.methods.modular_botorch import get_sobol_botorch_modular_acquisition
from ax.benchmark.methods.sobol import get_sobol_benchmark_method
from ax.benchmark.problems.registry import get_problem
from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
Expand All @@ -35,7 +36,6 @@
get_moo_surrogate,
get_multi_objective_benchmark_problem,
get_single_objective_benchmark_problem,
get_sobol_benchmark_method,
get_soo_surrogate,
TestDataset,
)
Expand Down Expand Up @@ -89,7 +89,9 @@ def test_batch(self) -> None:
def test_storage(self) -> None:
problem = get_single_objective_benchmark_problem()
res = benchmark_replication(
problem=problem, method=get_sobol_benchmark_method(), seed=0
problem=problem,
method=get_sobol_benchmark_method(distribute_replications=False),
seed=0,
)
# Experiment is not in storage yet
self.assertTrue(res.experiment is not None)
Expand Down Expand Up @@ -184,26 +186,20 @@ def test_create_benchmark_experiment(self) -> None:
self.assertEqual(experiment.runner, problem.runner)

def test_replication_sobol_synthetic(self) -> None:
method = get_sobol_benchmark_method()
method = get_sobol_benchmark_method(distribute_replications=False)
problems = [
get_single_objective_benchmark_problem(),
get_problem("jenatton", num_trials=6),
]
for problem in problems:
res = benchmark_replication(problem=problem, method=method, seed=0)

self.assertEqual(
min(
problem.num_trials, not_none(method.scheduler_options.total_trials)
),
len(not_none(res.experiment).trials),
)

self.assertEqual(problem.num_trials, len(not_none(res.experiment).trials))
self.assertTrue(np.isfinite(res.score_trace).all())
self.assertTrue(np.all(res.score_trace <= 100))

def test_replication_sobol_surrogate(self) -> None:
method = get_sobol_benchmark_method()
method = get_sobol_benchmark_method(distribute_replications=False)

# This is kind of a weird setup - these are "surrogates" that use a Branin
# synthetic function. The idea here is to test the machinery around the
Expand All @@ -217,10 +213,7 @@ def test_replication_sobol_surrogate(self) -> None:
res = benchmark_replication(problem=problem, method=method, seed=0)

self.assertEqual(
min(
problem.num_trials,
not_none(method.scheduler_options.total_trials),
),
problem.num_trials,
len(not_none(res.experiment).trials),
)

Expand Down Expand Up @@ -313,7 +306,9 @@ def test_replication_moo_sobol(self) -> None:
problem = get_multi_objective_benchmark_problem()

res = benchmark_replication(
problem=problem, method=get_sobol_benchmark_method(), seed=0
problem=problem,
method=get_sobol_benchmark_method(distribute_replications=False),
seed=0,
)

self.assertEqual(
Expand All @@ -331,7 +326,7 @@ def test_benchmark_one_method_problem(self) -> None:
problem = get_single_objective_benchmark_problem()
agg = benchmark_one_method_problem(
problem=problem,
method=get_sobol_benchmark_method(),
method=get_sobol_benchmark_method(distribute_replications=False),
seeds=(0, 1),
)

Expand All @@ -352,7 +347,7 @@ def test_benchmark_multiple_problems_methods(self) -> None:
aggs = benchmark_multiple_problems_methods(
problems=[get_single_objective_benchmark_problem(num_trials=6)],
methods=[
get_sobol_benchmark_method(),
get_sobol_benchmark_method(distribute_replications=False),
get_sobol_botorch_modular_acquisition(
model_cls=SingleTaskGP,
acquisition_cls=qLogNoisyExpectedImprovement,
Expand Down
13 changes: 0 additions & 13 deletions ax/utils/testing/benchmark_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,6 @@ def get_multi_objective_benchmark_problem(
)


def get_sobol_benchmark_method() -> BenchmarkMethod:
return BenchmarkMethod(
name="SOBOL",
generation_strategy=GenerationStrategy(
steps=[GenerationStep(model=Models.SOBOL, num_trials=-1)],
name="SOBOL",
),
scheduler_options=SchedulerOptions(
total_trials=4, init_seconds_between_polls=0
),
)


def get_soo_surrogate() -> SurrogateBenchmarkProblem:
experiment = get_branin_experiment(with_completed_trial=True)
surrogate = TorchModelBridge(
Expand Down

0 comments on commit ac3a7ec

Please sign in to comment.