
Commit ed506ab

Remove BenchmarkProblemWithKnownOptimum and BenchmarkProblemProtocol type annotations (#2602)

Summary:
Pull Request resolved: #2602

These are no longer necessary. See previous PR for context. This change reaps them and updates type annotations.

Reviewed By: saitcakmak

Differential Revision: D60146081

fbshipit-source-id: c9710f76fc92a06bcf72d2822fc4f546f75ca363
esantorella authored and facebook-github-bot committed Jul 28, 2024
1 parent 1bcaacd commit ed506ab
Showing 3 changed files with 13 additions and 55 deletions.
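
For orientation, here is the signature-level effect of the change, condensed from the diff below (function body elided; the imports are those visible in the updated benchmark.py). The only change at each touched signature is that the structural BenchmarkProblemProtocol annotation is replaced by the concrete BenchmarkProblem class.

```python
from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.benchmark.benchmark_problem import BenchmarkProblem
from ax.benchmark.benchmark_result import BenchmarkResult


def benchmark_replication(
    problem: BenchmarkProblem,  # previously: BenchmarkProblemProtocol
    method: BenchmarkMethod,
    seed: int,
) -> BenchmarkResult:
    ...
```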
21 changes: 8 additions & 13 deletions ax/benchmark/benchmark.py
@@ -27,10 +27,7 @@
 import numpy as np
 
 from ax.benchmark.benchmark_method import BenchmarkMethod
-from ax.benchmark.benchmark_problem import (
-    BenchmarkProblemProtocol,
-    BenchmarkProblemWithKnownOptimum,
-)
+from ax.benchmark.benchmark_problem import BenchmarkProblem
 from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult
 from ax.benchmark.metrics.base import BenchmarkMetricBase, GroundTruthMetricMixin
 from ax.core.experiment import Experiment
@@ -53,16 +50,14 @@
 def compute_score_trace(
     optimization_trace: np.ndarray,
     num_baseline_trials: int,
-    problem: BenchmarkProblemProtocol,
+    problem: BenchmarkProblem,
 ) -> np.ndarray:
     """Computes a score trace from the optimization trace."""
 
     # Use the first GenerationStep's best found point as baseline. Sometimes (ex. in
     # a timeout) the first GenerationStep will not have not completed and we will not
     # have enough trials; in this case we do not score.
-    if (len(optimization_trace) <= num_baseline_trials) or not isinstance(
-        problem, BenchmarkProblemWithKnownOptimum
-    ):
+    if len(optimization_trace) <= num_baseline_trials:
         return np.full(len(optimization_trace), np.nan)
     optimum = problem.optimal_value
     baseline = optimization_trace[num_baseline_trials - 1]
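
The rest of this function is collapsed in the hunk above. Below is a minimal sketch of the simplified scoring logic after this change; the normalization formula is an assumption (baseline scores 0, the known optimum scores 100) rather than a verbatim copy of the remaining Ax source.

```python
import numpy as np


def compute_score_trace_sketch(
    optimization_trace: np.ndarray,
    num_baseline_trials: int,
    optimal_value: float,
) -> np.ndarray:
    """Hedged sketch of the simplified scoring; not the verbatim Ax implementation."""
    # With the BenchmarkProblemWithKnownOptimum isinstance check gone, every
    # problem is assumed to expose an optimal_value.
    if len(optimization_trace) <= num_baseline_trials:
        return np.full(len(optimization_trace), np.nan)
    baseline = optimization_trace[num_baseline_trials - 1]
    # Assumed normalization: the baseline maps to 0, the known optimum maps to 100.
    return 100.0 * (
        1.0 - (optimization_trace - optimal_value) / (baseline - optimal_value)
    )
```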
@@ -77,7 +72,7 @@ def compute_score_trace(
 
 
 def _create_benchmark_experiment(
-    problem: BenchmarkProblemProtocol, method_name: str
+    problem: BenchmarkProblem, method_name: str
 ) -> Experiment:
     """Creates an empty experiment for the given problem and method.
@@ -117,7 +112,7 @@ def _create_benchmark_experiment(
 
 
 def benchmark_replication(
-    problem: BenchmarkProblemProtocol,
+    problem: BenchmarkProblem,
     method: BenchmarkMethod,
     seed: int,
 ) -> BenchmarkResult:
@@ -192,7 +187,7 @@ def benchmark_replication(
 
 
 def benchmark_one_method_problem(
-    problem: BenchmarkProblemProtocol,
+    problem: BenchmarkProblem,
     method: BenchmarkMethod,
     seeds: Iterable[int],
 ) -> AggregatedBenchmarkResult:
@@ -205,7 +200,7 @@
 
 
 def benchmark_multiple_problems_methods(
-    problems: Iterable[BenchmarkProblemProtocol],
+    problems: Iterable[BenchmarkProblem],
     methods: Iterable[BenchmarkMethod],
     seeds: Iterable[int],
 ) -> List[AggregatedBenchmarkResult]:
@@ -222,7 +217,7 @@
 
 
 def make_ground_truth_metrics(
-    problem: BenchmarkProblemProtocol,
+    problem: BenchmarkProblem,
     include_tracking_metrics: bool = True,
 ) -> Dict[str, Metric]:
     """Makes a ground truth version for each metric defined on the problem.
39 changes: 1 addition & 38 deletions ax/benchmark/benchmark_problem.py
@@ -11,17 +11,7 @@
 # in the UI.
 
 from dataclasses import dataclass, field
-from typing import (
-    Any,
-    Dict,
-    List,
-    Optional,
-    Protocol,
-    runtime_checkable,
-    Type,
-    TypeVar,
-    Union,
-)
+from typing import Any, Dict, List, Optional, Type, TypeVar, Union
 
 from ax.benchmark.metrics.base import BenchmarkMetricBase
 
@@ -72,33 +62,6 @@ def _get_name(
     return f"{base_name}{observed_noise}{dim_str}"
 
 
-@runtime_checkable
-class BenchmarkProblemProtocol(Protocol):
-    """
-    Specifies the interface any benchmark problem must adhere to.
-
-    Classes implementing this interface include BenchmarkProblem,
-    SurrogateBenchmarkProblem, and MOOSurrogateBenchmarkProblem.
-    """
-
-    name: str
-    search_space: SearchSpace
-    optimization_config: OptimizationConfig
-    num_trials: int
-    tracking_metrics: List[BenchmarkMetricBase]
-    is_noiseless: bool  # If True, evaluations are deterministic
-    observe_noise_stds: Union[
-        bool, Dict[str, bool]
-    ]  # Whether we observe the observation noise level
-    has_ground_truth: bool  # if True, evals (w/o synthetic noise) are determinstic
-    runner: Runner
-
-
-@runtime_checkable
-class BenchmarkProblemWithKnownOptimum(Protocol):
-    optimal_value: float
-
-
 @dataclass(kw_only=True, repr=True)
 class BenchmarkProblem(Base):
     """
8 changes: 4 additions & 4 deletions ax/benchmark/tests/problems/test_surrogate_problems.py
@@ -8,7 +8,7 @@
 
 import numpy as np
 from ax.benchmark.benchmark import compute_score_trace
-from ax.benchmark.benchmark_problem import BenchmarkProblemProtocol
+from ax.benchmark.benchmark_problem import BenchmarkProblem
 from ax.utils.common.testutils import TestCase
 from ax.utils.testing.benchmark_stubs import get_moo_surrogate, get_soo_surrogate
 
@@ -19,12 +19,12 @@ def setUp(self) -> None:
         # print max output so errors in 'repr' can be fully shown
         self.maxDiff = None
 
-    def test_conforms_to_protocol(self) -> None:
+    def test_conforms_to_api(self) -> None:
         sbp = get_soo_surrogate()
-        self.assertIsInstance(sbp, BenchmarkProblemProtocol)
+        self.assertIsInstance(sbp, BenchmarkProblem)
 
         mbp = get_moo_surrogate()
-        self.assertIsInstance(mbp, BenchmarkProblemProtocol)
+        self.assertIsInstance(mbp, BenchmarkProblem)
 
     def test_repr(self) -> None:

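As a side note on why asserting against the concrete class is the stronger check: the toy illustration below (unrelated to the Ax classes) shows how a runtime_checkable Protocol matches any object with the right attribute names, while an isinstance check against a concrete class requires actual inheritance.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class HasOptimum(Protocol):
    optimal_value: float


class Unrelated:
    optimal_value = 0.0  # structurally matches the Protocol, but is not a benchmark problem


print(isinstance(Unrelated(), HasOptimum))  # True: only attribute names are checked
```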
