Skip to content

Commit

Permalink
More precise type annotations for MultiObjectiveOptimizationConfig (#…
Browse files Browse the repository at this point in the history
…2620)

Summary:
Pull Request resolved: #2620

Context: Type annotations imply that any `Objective` works with a `MultiObjectiveOptimizationConfig`, but the code makes clear that only a `MultiObjective` or `ScalarizedObjective` works, and even has tests for this. I was misled myself into trying to privde an `Objective` to a `MultiObjectiveOptimizationConfig`.

This PR:
* Changes annotations from `Objective` to `Union[MultiObjective, ScalarizedObjective]`
* Adds a pyre-fixme: Inconsistent override. The indirect cause of why this is needed is that `Objective.clone_with_args` returns an `Objective` type even in subclasses unless the method is overriden, rather than a self type.
* Added a couple pyre-fixmes in unit tests that were deliberately testing inappropriate types.

Reviewed By: saitcakmak, mpolson64

Differential Revision: D60476566

fbshipit-source-id: 03afbfdca624026e5ff8d5da3d224f6b1d676032
  • Loading branch information
esantorella authored and facebook-github-bot committed Aug 1, 2024
1 parent e9bd020 commit f25b4ce
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 14 deletions.
4 changes: 2 additions & 2 deletions ax/core/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def objective_weights(self) -> Iterable[Tuple[Objective, float]]:
"""Get the objectives and weights."""
return zip(self.objectives, self.weights)

def clone(self) -> Objective:
def clone(self) -> MultiObjective:
"""Create a copy of the objective."""
return MultiObjective(objectives=[o.clone() for o in self.objectives])

Expand Down Expand Up @@ -235,7 +235,7 @@ def metric_weights(self) -> Iterable[Tuple[Metric, float]]:
"""Get the metrics and weights."""
return zip(self.metrics, self.weights)

def clone(self) -> Objective:
def clone(self) -> ScalarizedObjective:
"""Create a copy of the objective."""
return ScalarizedObjective(
metrics=[m.clone() for m in self.metrics],
Expand Down
13 changes: 7 additions & 6 deletions ax/core/optimization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from itertools import groupby
from logging import Logger
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

from ax.core.metric import Metric
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
Expand Down Expand Up @@ -266,7 +266,7 @@ class MultiObjectiveOptimizationConfig(OptimizationConfig):

def __init__(
self,
objective: Objective,
objective: Union[MultiObjective, ScalarizedObjective],
outcome_constraints: Optional[List[OutcomeConstraint]] = None,
objective_thresholds: Optional[List[ObjectiveThreshold]] = None,
risk_measure: Optional[RiskMeasure] = None,
Expand All @@ -293,14 +293,15 @@ def __init__(
objective_thresholds=objective_thresholds,
risk_measure=risk_measure,
)
self._objective: Objective = objective
self._objective: Union[MultiObjective, ScalarizedObjective] = objective
self._outcome_constraints: List[OutcomeConstraint] = constraints
self._objective_thresholds: List[ObjectiveThreshold] = objective_thresholds
self.risk_measure: Optional[RiskMeasure] = risk_measure

# pyre-fixme[14]: Inconsistent override.
def clone_with_args(
self,
objective: Optional[Objective] = None,
objective: Optional[Union[MultiObjective, ScalarizedObjective]] = None,
outcome_constraints: Optional[
List[OutcomeConstraint]
] = _NO_OUTCOME_CONSTRAINTS,
Expand Down Expand Up @@ -333,12 +334,12 @@ def clone_with_args(
)

@property
def objective(self) -> Objective:
def objective(self) -> Union[MultiObjective, ScalarizedObjective]:
"""Get objective."""
return self._objective

@objective.setter
def objective(self, objective: Objective) -> None:
def objective(self, objective: Union[MultiObjective, ScalarizedObjective]) -> None:
"""Set objective if not present in outcome constraints."""
self._validate_optimization_config(
objective=objective,
Expand Down
14 changes: 12 additions & 2 deletions ax/core/tests/test_optimization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,12 @@ def test_Init(self) -> None:
objective=self.multi_objective, outcome_constraints=self.outcome_constraints
)
self.assertEqual(str(config1), MOOC_STR)
with self.assertRaises(TypeError):
with self.assertRaisesRegex(
TypeError,
"`MultiObjectiveOptimizationConfig` requires an objective of type "
"`MultiObjective` or `ScalarizedObjective`.",
):
# pyre-fixme [8]: Incompatible attribute type
config1.objective = self.objective # Wrong objective type
# updating constraints is fine.
config1.outcome_constraints = [self.outcome_constraint]
Expand Down Expand Up @@ -428,7 +433,12 @@ def test_Eq(self) -> None:

def test_ConstraintValidation(self) -> None:
# Cannot build with non-MultiObjective
with self.assertRaises(TypeError):
with self.assertRaisesRegex(
TypeError,
"`MultiObjectiveOptimizationConfig` requires an objective of type "
"`MultiObjective` or `ScalarizedObjective`.",
):
# pyre-fixme [6]: Incompatible parameter type
MultiObjectiveOptimizationConfig(objective=self.objective)

# Using an outcome constraint for an objective should raise
Expand Down
5 changes: 4 additions & 1 deletion ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
from pandas import read_json
from pyre_extensions import assert_is_instance
from sqlalchemy.orm.exc import DetachedInstanceError

logger: Logger = get_logger(__name__)
Expand Down Expand Up @@ -585,7 +586,9 @@ def opt_config_and_tracking_metrics_from_sqa(

if objective_thresholds or type(objective) is MultiObjective:
optimization_config = MultiObjectiveOptimizationConfig(
objective=objective,
objective=assert_is_instance(
objective, Union[MultiObjective, ScalarizedObjective]
),
outcome_constraints=outcome_constraints,
objective_thresholds=objective_thresholds,
risk_measure=risk_measure,
Expand Down
6 changes: 3 additions & 3 deletions ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,7 +1583,7 @@ def get_map_objective(minimize: bool = False) -> Objective:
return Objective(metric=MapMetric(name="m1"), minimize=minimize)


def get_multi_objective() -> Objective:
def get_multi_objective() -> MultiObjective:
return MultiObjective(
objectives=[
Objective(metric=Metric(name="m1"), minimize=False),
Expand All @@ -1592,7 +1592,7 @@ def get_multi_objective() -> Objective:
)


def get_custom_multi_objective() -> Objective:
def get_custom_multi_objective() -> MultiObjective:
return MultiObjective(
objectives=[
Objective(
Expand Down Expand Up @@ -1633,7 +1633,7 @@ def get_branin_objective(name: str = "branin", minimize: bool = False) -> Object
)


def get_branin_multi_objective(num_objectives: int = 2) -> Objective:
def get_branin_multi_objective(num_objectives: int = 2) -> MultiObjective:
_validate_num_objectives(num_objectives=num_objectives)
objectives = [
Objective(metric=get_branin_metric(name="branin_a"), minimize=True),
Expand Down

0 comments on commit f25b4ce

Please sign in to comment.