Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get rid of surrogate problem subclasses #2607

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions ax/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@
import numpy as np

from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.benchmark.benchmark_problem import (
BenchmarkProblemProtocol,
BenchmarkProblemWithKnownOptimum,
)
from ax.benchmark.benchmark_problem import BenchmarkProblem
from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult
from ax.benchmark.metrics.base import BenchmarkMetricBase, GroundTruthMetricMixin
from ax.core.experiment import Experiment
Expand All @@ -53,16 +50,14 @@
def compute_score_trace(
optimization_trace: np.ndarray,
num_baseline_trials: int,
problem: BenchmarkProblemProtocol,
problem: BenchmarkProblem,
) -> np.ndarray:
"""Computes a score trace from the optimization trace."""

# Use the first GenerationStep's best found point as baseline. Sometimes (ex. in
# a timeout) the first GenerationStep will not have not completed and we will not
# have enough trials; in this case we do not score.
if (len(optimization_trace) <= num_baseline_trials) or not isinstance(
problem, BenchmarkProblemWithKnownOptimum
):
if len(optimization_trace) <= num_baseline_trials:
return np.full(len(optimization_trace), np.nan)
optimum = problem.optimal_value
baseline = optimization_trace[num_baseline_trials - 1]
Expand All @@ -77,7 +72,7 @@ def compute_score_trace(


def _create_benchmark_experiment(
problem: BenchmarkProblemProtocol, method_name: str
problem: BenchmarkProblem, method_name: str
) -> Experiment:
"""Creates an empty experiment for the given problem and method.

Expand Down Expand Up @@ -117,7 +112,7 @@ def _create_benchmark_experiment(


def benchmark_replication(
problem: BenchmarkProblemProtocol,
problem: BenchmarkProblem,
method: BenchmarkMethod,
seed: int,
) -> BenchmarkResult:
Expand Down Expand Up @@ -192,7 +187,7 @@ def benchmark_replication(


def benchmark_one_method_problem(
problem: BenchmarkProblemProtocol,
problem: BenchmarkProblem,
method: BenchmarkMethod,
seeds: Iterable[int],
) -> AggregatedBenchmarkResult:
Expand All @@ -205,7 +200,7 @@ def benchmark_one_method_problem(


def benchmark_multiple_problems_methods(
problems: Iterable[BenchmarkProblemProtocol],
problems: Iterable[BenchmarkProblem],
methods: Iterable[BenchmarkMethod],
seeds: Iterable[int],
) -> List[AggregatedBenchmarkResult]:
Expand All @@ -222,7 +217,7 @@ def benchmark_multiple_problems_methods(


def make_ground_truth_metrics(
problem: BenchmarkProblemProtocol,
problem: BenchmarkProblem,
include_tracking_metrics: bool = True,
) -> Dict[str, Metric]:
"""Makes a ground truth version for each metric defined on the problem.
Expand Down
Loading
Loading