Skip to content

Commit

Permalink
Make objective direction checks more strict (#2382)
Browse files Browse the repository at this point in the history
Summary:

With this change, at least one of `minimize` or `lower_is_better` must be specified. If both are specified, they must match.

Added a json storage helper & updated SQA storage helpers for deserializing previously saved objectives in a backwards compatible manner, resolving the conflicts in favor of `minimize`.

Differential Revision: D56315542
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Apr 19, 2024
1 parent 5b38f2a commit d8c55df
Show file tree
Hide file tree
Showing 20 changed files with 822 additions and 846 deletions.
43 changes: 17 additions & 26 deletions ax/core/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Any, Iterable, List, Optional, Tuple

from ax.core.metric import Metric
from ax.exceptions.core import UserInputError
from ax.utils.common.base import SortableBase
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
Expand All @@ -34,36 +35,27 @@ def __init__(self, metric: Metric, minimize: Optional[bool] = None) -> None:
metric: The metric to be optimized.
minimize: If True, minimize metric. If None, will be set based on the
`lower_is_better` property of the metric (if that is not specified,
will raise a DeprecationWarning).
will raise a `UserInputError`).
"""
lower_is_better = metric.lower_is_better
if minimize is None:
if lower_is_better is None:
warnings.warn(
f"Defaulting to `minimize=False` for metric {metric.name} not "
+ "specifying `lower_is_better` property. This is a wild guess. "
+ "Specify either `lower_is_better` on the metric, or specify "
+ "`minimize` explicitly. This will become an error in the future.",
DeprecationWarning,
raise UserInputError(
f"Metric {metric.name} does not specify `lower_is_better` "
"and `minimize` is not specified. At least one of these "
"must be specified."
)
minimize = False
else:
minimize = lower_is_better
if lower_is_better is not None:
if lower_is_better and not minimize:
warnings.warn(
f"Attempting to maximize metric {metric.name} with property "
"`lower_is_better=True`."
)
elif not lower_is_better and minimize:
warnings.warn(
f"Attempting to minimize metric {metric.name} with property "
"`lower_is_better=False`."
)
self._metric = metric
# pyre-fixme[4]: Attribute must be annotated.
self.minimize = not_none(minimize)
elif lower_is_better is not None and lower_is_better != minimize:
raise UserInputError(
f"Metric {metric.name} specifies {lower_is_better=}, "
"which doesn't match the specified optimization direction "
f"{minimize=}."
)
self._metric: Metric = metric
self.minimize: bool = not_none(minimize)

@property
def metric(self) -> Metric:
Expand Down Expand Up @@ -130,18 +122,17 @@ def __init__(
"as input to `MultiObjective` constructor."
)
metrics = extra_kwargs["metrics"]
minimize = extra_kwargs.get("minimize", False)
minimize = extra_kwargs.get("minimize", None)
warnings.warn(
"Passing `metrics` and `minimize` as input to the `MultiObjective` "
"constructor will soon be deprecated. Instead, pass a list of "
"`objectives`. This will become an error in the future.",
DeprecationWarning,
stacklevel=2,
)
objectives = []
for metric in metrics:
lower_is_better = metric.lower_is_better or False
_minimize = not lower_is_better if minimize else lower_is_better
objectives.append(Objective(metric=metric, minimize=_minimize))
objectives.append(Objective(metric=metric, minimize=minimize))

# pyre-fixme[4]: Attribute must be annotated.
self._objectives = not_none(objectives)
Expand Down
48 changes: 23 additions & 25 deletions ax/core/tests/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@

# pyre-strict

import warnings

from ax.core.metric import Metric
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
from ax.exceptions.core import UserInputError
from ax.utils.common.testutils import TestCase


Expand All @@ -21,7 +20,7 @@ def setUp(self) -> None:
"m3": Metric(name="m3", lower_is_better=False),
}
self.objectives = {
"o1": Objective(metric=self.metrics["m1"]),
"o1": Objective(metric=self.metrics["m1"], minimize=True),
"o2": Objective(metric=self.metrics["m2"], minimize=True),
"o3": Objective(metric=self.metrics["m3"], minimize=False),
}
Expand All @@ -38,6 +37,12 @@ def setUp(self) -> None:
)

def test_Init(self) -> None:
with self.assertRaisesRegex(UserInputError, "does not specify"):
Objective(metric=self.metrics["m1"]),
with self.assertRaisesRegex(
UserInputError, "doesn't match the specified optimization direction"
):
Objective(metric=self.metrics["m2"], minimize=False)
with self.assertRaises(ValueError):
ScalarizedObjective(
metrics=[self.metrics["m1"], self.metrics["m2"]], weights=[1.0]
Expand All @@ -52,20 +57,6 @@ def test_Init(self) -> None:
metrics=[self.metrics["m1"], self.metrics["m2"]],
minimize=False,
)
warnings.resetwarnings()
warnings.simplefilter("always", append=True)
with warnings.catch_warnings(record=True) as ws:
Objective(metric=self.metrics["m1"])
self.assertTrue(any(issubclass(w.category, DeprecationWarning) for w in ws))
self.assertTrue(
any("Defaulting to `minimize=False`" in str(w.message) for w in ws)
)
with warnings.catch_warnings(record=True) as ws:
Objective(Metric(name="m4", lower_is_better=True), minimize=False)
self.assertTrue(any("Attempting to maximize" in str(w.message) for w in ws))
with warnings.catch_warnings(record=True) as ws:
Objective(Metric(name="m4", lower_is_better=False), minimize=True)
self.assertTrue(any("Attempting to minimize" in str(w.message) for w in ws))
self.assertEqual(
self.objective.get_unconstrainable_metrics(), [self.metrics["m1"]]
)
Expand All @@ -77,15 +68,15 @@ def test_MultiObjective(self) -> None:

self.assertEqual(self.multi_objective.metrics, list(self.metrics.values()))
minimizes = [obj.minimize for obj in self.multi_objective.objectives]
self.assertEqual(minimizes, [False, True, False])
self.assertEqual(minimizes, [True, True, False])
weights = [mw[1] for mw in self.multi_objective.objective_weights]
self.assertEqual(weights, [1.0, 1.0, 1.0])
self.assertEqual(self.multi_objective.clone(), self.multi_objective)
self.assertEqual(
str(self.multi_objective),
(
"MultiObjective(objectives="
'[Objective(metric_name="m1", minimize=False), '
'[Objective(metric_name="m1", minimize=True), '
'Objective(metric_name="m2", minimize=True), '
'Objective(metric_name="m3", minimize=False)])'
),
Expand All @@ -96,19 +87,26 @@ def test_MultiObjective(self) -> None:
)

def test_MultiObjectiveBackwardsCompatibility(self) -> None:
multi_objective = MultiObjective(
metrics=[self.metrics["m1"], self.metrics["m2"], self.metrics["m3"]]
)
metrics = [
Metric(name="m1", lower_is_better=False),
self.metrics["m2"],
self.metrics["m3"],
]
multi_objective = MultiObjective(metrics=metrics)
minimizes = [obj.minimize for obj in multi_objective.objectives]
self.assertEqual(multi_objective.metrics, list(self.metrics.values()))
self.assertEqual(multi_objective.metrics, metrics)
self.assertEqual(minimizes, [False, True, False])

multi_objective_min = MultiObjective(
metrics=[self.metrics["m1"], self.metrics["m2"], self.metrics["m3"]],
metrics=[
Metric(name="m1"),
Metric(name="m2"),
Metric(name="m3", lower_is_better=True),
],
minimize=True,
)
minimizes = [obj.minimize for obj in multi_objective_min.objectives]
self.assertEqual(minimizes, [True, False, True])
self.assertEqual(minimizes, [True, True, True])

def test_ScalarizedObjective(self) -> None:
with self.assertRaises(NotImplementedError):
Expand Down
2 changes: 1 addition & 1 deletion ax/core/tests/test_optimization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def setUp(self) -> None:
"o2": Objective(metric=self.metrics["m2"], minimize=False),
"o3": Objective(metric=self.metrics["m3"], minimize=False),
}
self.objective = Objective(metric=self.metrics["m1"], minimize=False)
self.objective = Objective(metric=self.metrics["m1"], minimize=True)
self.multi_objective = MultiObjective(
objectives=[self.objectives["o1"], self.objectives["o2"]]
)
Expand Down
2 changes: 1 addition & 1 deletion ax/core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def setUp(self) -> None:
self.data = Data(df=self.df)

self.optimization_config = OptimizationConfig(
objective=Objective(metric=Metric(name="a")),
objective=Objective(metric=Metric(name="a"), minimize=False),
outcome_constraints=[
OutcomeConstraint(
metric=Metric(name="b"),
Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/tests/test_base_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def test_ModelBridge(
observation_features=[get_observation1trans().features], weights=[2]
),
)
oc = OptimizationConfig(objective=Objective(metric=Metric(name="test_metric")))
oc = get_optimization_config_no_constraints()
modelbridge._set_kwargs_to_save(
model_key="TestModel", model_kwargs={}, bridge_kwargs={}
)
Expand Down Expand Up @@ -322,7 +322,7 @@ def warn_and_return_mock_obs(
fit_tracking_metrics=False,
)
new_oc = OptimizationConfig(
objective=Objective(metric=Metric(name="test_metric2"))
objective=Objective(metric=Metric(name="test_metric2"), minimize=False),
)
with self.assertRaisesRegex(UnsupportedError, "fit_tracking_metrics"):
modelbridge.gen(n=1, optimization_config=new_oc)
Expand Down
11 changes: 8 additions & 3 deletions ax/modelbridge/tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def test_HasGoodOptConfigModelFit(self) -> None:

# Test single objective
optimization_config = OptimizationConfig(
objective=Objective(metric=Metric("a"))
objective=Objective(metric=Metric("a"), minimize=True)
)
has_good_fit = has_good_opt_config_model_fit(
optimization_config=optimization_config,
Expand All @@ -354,7 +354,12 @@ def test_HasGoodOptConfigModelFit(self) -> None:

# Test multi objective
optimization_config = MultiObjectiveOptimizationConfig(
objective=MultiObjective(metrics=[Metric("a"), Metric("b")])
objective=MultiObjective(
objectives=[
Objective(Metric("a"), minimize=False),
Objective(Metric("b"), minimize=False),
]
)
)
has_good_fit = has_good_opt_config_model_fit(
optimization_config=optimization_config,
Expand All @@ -364,7 +369,7 @@ def test_HasGoodOptConfigModelFit(self) -> None:

# Test constraints
optimization_config = OptimizationConfig(
objective=Objective(metric=Metric("a")),
objective=Objective(metric=Metric("a"), minimize=False),
outcome_constraints=[
OutcomeConstraint(metric=Metric("b"), op=ComparisonOp.GEQ, bound=0.1)
],
Expand Down
13 changes: 4 additions & 9 deletions ax/modelbridge/tests/test_torch_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
get_branin_experiment,
get_branin_search_space,
get_experiment_with_observations,
get_optimization_config_no_constraints,
get_search_space_for_range_value,
)
from ax.utils.testing.mock import fast_botorch_optimize
Expand Down Expand Up @@ -363,9 +364,7 @@ def test_evaluate_acquisition_function(self, _, mock_torch_model: Mock) -> None:
observation_features=[
ObservationFeatures(parameters={"x": 1.0, "y": 2.0})
],
optimization_config=OptimizationConfig(
objective=Objective(metric=Metric(name="test_metric"))
),
optimization_config=get_optimization_config_no_constraints(),
)

self.assertEqual(acqf_vals, [5.0])
Expand All @@ -392,9 +391,7 @@ def test_evaluate_acquisition_function(self, _, mock_torch_model: Mock) -> None:
ObservationFeatures(parameters={"x": 1.0, "y": 2.0}),
ObservationFeatures(parameters={"x": 1.0, "y": 2.0}),
],
optimization_config=OptimizationConfig(
objective=Objective(metric=Metric(name="test_metric"))
),
optimization_config=get_optimization_config_no_constraints(),
)
t.transform_observation_features.assert_any_call(
[ObservationFeatures(parameters={"x": 1.0, "y": 2.0})],
Expand All @@ -418,9 +415,7 @@ def test_evaluate_acquisition_function(self, _, mock_torch_model: Mock) -> None:
ObservationFeatures(parameters={"x": 1.0, "y": 2.0}),
]
],
optimization_config=OptimizationConfig(
objective=Objective(metric=Metric(name="test_metric"))
),
optimization_config=get_optimization_config_no_constraints(),
)
t.transform_observation_features.assert_any_call(
[
Expand Down
6 changes: 4 additions & 2 deletions ax/modelbridge/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ def test_extract_outcome_constraints(self) -> None:
def test_extract_objective_thresholds(self) -> None:
outcomes = ["m1", "m2", "m3", "m4"]
objective = MultiObjective(
objectives=[Objective(metric=Metric(name)) for name in outcomes[:3]]
objectives=[
Objective(metric=Metric(name), minimize=False) for name in outcomes[:3]
]
)
objective_thresholds = [
ObjectiveThreshold(
Expand Down Expand Up @@ -159,7 +161,7 @@ def test_extract_objective_thresholds(self) -> None:
self.assertTrue(np.isnan(obj_t[-2:]).all())

# Fails if a threshold does not have a corresponding metric.
objective2 = Objective(Metric("m1"))
objective2 = Objective(Metric("m1"), minimize=False)
with self.assertRaisesRegex(ValueError, "corresponding metrics"):
extract_objective_thresholds(
objective_thresholds=objective_thresholds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_DerelativizeTransform(
)

# Test with no relative constraints
objective = Objective(Metric("c"))
objective = Objective(Metric("c"), minimize=True)
oc = OptimizationConfig(
objective=objective,
outcome_constraints=[
Expand Down Expand Up @@ -300,7 +300,7 @@ def test_Errors(self) -> None:
observations=[],
)
oc = OptimizationConfig(
objective=Objective(Metric("c")),
objective=Objective(Metric("c"), minimize=False),
outcome_constraints=[
OutcomeConstraint(Metric("a"), ComparisonOp.LEQ, bound=2, relative=True)
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def test_relative_constraints(
RangeParameter("y", ParameterType.FLOAT, 0, 20),
]
)
objective = Objective(Metric("c"))
objective = Objective(Metric("c"), minimize=False)

# Test with relative constraint, in-design status quo
oc = OptimizationConfig(
Expand Down
2 changes: 1 addition & 1 deletion ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def setUp(self) -> None:
self.branin_experiment_no_impl_runner_or_metrics = Experiment(
search_space=get_branin_search_space(),
optimization_config=OptimizationConfig(
objective=Objective(metric=Metric(name="branin"))
objective=Objective(metric=Metric(name="branin"), minimize=False)
),
name="branin_experiment_no_impl_runner_or_metrics",
)
Expand Down
14 changes: 7 additions & 7 deletions ax/service/tests/test_report_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,11 +560,11 @@ def test_get_metric_name_pairs(self) -> None:
exp._optimization_config = MultiObjectiveOptimizationConfig(
objective=MultiObjective(
objectives=[
Objective(metric=Metric("m0")),
Objective(metric=Metric("m1")),
Objective(metric=Metric("m2")),
Objective(metric=Metric("m3")),
Objective(metric=Metric("m4")),
Objective(metric=Metric("m0"), minimize=False),
Objective(metric=Metric("m1"), minimize=False),
Objective(metric=Metric("m2"), minimize=False),
Objective(metric=Metric("m3"), minimize=False),
Objective(metric=Metric("m4"), minimize=False),
]
)
)
Expand Down Expand Up @@ -1052,9 +1052,9 @@ def test_compare_to_baseline_moo(self) -> None:
optimization_config = MultiObjectiveOptimizationConfig(
objective=MultiObjective(
objectives=[
Objective(metric=Metric("m0")),
Objective(metric=Metric("m0"), minimize=False),
Objective(metric=Metric("m1"), minimize=True),
Objective(metric=Metric("m3")),
Objective(metric=Metric("m3"), minimize=False),
]
)
)
Expand Down
3 changes: 1 addition & 2 deletions ax/service/utils/report_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
)

import gpytorch

import numpy as np
import pandas as pd
import plotly.graph_objects as go
Expand Down Expand Up @@ -140,7 +139,7 @@ def _get_objective_trace_plot(
plot_objective_value_vs_trial_index(
exp_df=exp_df,
metric_colname=metric_name,
minimize=(
minimize=not_none(
optimization_config.objective.minimize
if optimization_config.objective.metric.name == metric_name
else experiment.metrics[metric_name].lower_is_better
Expand Down
Loading

0 comments on commit d8c55df

Please sign in to comment.