From 6cab8e3dbb5ff7e33463a374cb11070f700ca194 Mon Sep 17 00:00:00 2001 From: David Eriksson Date: Fri, 5 Apr 2024 09:10:23 -0700 Subject: [PATCH] SimplifyParameterConstraints (#2326) Summary: Remove parameter constraints that can be trivially converted into an updated lower/upper bound Reviewed By: SebastianAment Differential Revision: D55718753 --- .../simplify_parameter_constraints.py | 69 ++++++++++ .../test_simplify_parameter_constraints.py | 122 ++++++++++++++++++ ax/storage/transform_registry.py | 4 + sphinx/source/modelbridge.rst | 8 ++ 4 files changed, 203 insertions(+) create mode 100644 ax/modelbridge/transforms/simplify_parameter_constraints.py create mode 100644 ax/modelbridge/transforms/tests/test_simplify_parameter_constraints.py diff --git a/ax/modelbridge/transforms/simplify_parameter_constraints.py b/ax/modelbridge/transforms/simplify_parameter_constraints.py new file mode 100644 index 00000000000..625f74de238 --- /dev/null +++ b/ax/modelbridge/transforms/simplify_parameter_constraints.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import math +from typing import List, TYPE_CHECKING + +from ax.core.parameter import FixedParameter, ParameterType, RangeParameter +from ax.core.parameter_constraint import ParameterConstraint +from ax.core.search_space import SearchSpace +from ax.modelbridge.transforms.base import Transform +from ax.utils.common.typeutils import checked_cast + +if TYPE_CHECKING: + # import as module to make sphinx-autodoc-typehints happy + from ax import modelbridge as modelbridge_module # noqa F401 + + +class SimplifyParameterConstraints(Transform): + """Convert parameter constraints on one parameter to an updated bound. + + This transform converts parameter constraints on only one parameter into an updated + upper or lower bound. Note that this transform will convert parameters that can only + take on one value into a `FixedParameter`. Make sure this transform is applied + before `RemoveFixed` if you want to remove all fixed parameters. + """ + + def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: + # keeps track of the constraints that cannot be converted to bounds + nontrivial_constraints: List[ParameterConstraint] = [] + for pc in search_space.parameter_constraints: + if len(pc.constraint_dict) == 1: + # This can be turned into an updated bound since only one variable is + # involved in the constraint. + [(p_name, weight)] = pc.constraint_dict.items() + # NOTE: We only allow parameter constraints on range parameters + p = checked_cast(RangeParameter, search_space.parameters[p_name]) + lb, ub = p.lower, p.upper + if weight == 0 and pc.bound < 0: # Cannot be satisfied + raise ValueError( + "Parameter constraint cannot be satisfied since the weight " + "is zero and the bound is negative." + ) + elif weight == 0: # Constraint is always satisfied + continue + elif weight > 0: # New upper bound + ub = float(pc.bound) / float(weight) + if p.parameter_type == ParameterType.INT: + ub = math.floor(ub) # Round down + else: # New lower bound + lb = float(pc.bound) / float(weight) + if p.parameter_type == ParameterType.INT: + lb = math.ceil(lb) # Round up + + if lb == ub: # Need to turn this into a fixed parameter + search_space.parameters[p_name] = FixedParameter( + name=p_name, parameter_type=p.parameter_type, value=lb + ) + elif weight > 0: + p._upper = ub + else: + p._lower = lb + else: + nontrivial_constraints.append(pc) + search_space.set_parameter_constraints(nontrivial_constraints) + return search_space diff --git a/ax/modelbridge/transforms/tests/test_simplify_parameter_constraints.py b/ax/modelbridge/transforms/tests/test_simplify_parameter_constraints.py new file mode 100644 index 00000000000..0f5e5401162 --- /dev/null +++ b/ax/modelbridge/transforms/tests/test_simplify_parameter_constraints.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from copy import deepcopy +from typing import List + +from ax.core.observation import ObservationFeatures +from ax.core.parameter import ( + ChoiceParameter, + FixedParameter, + Parameter, + ParameterType, + RangeParameter, +) +from ax.core.parameter_constraint import ParameterConstraint +from ax.core.search_space import SearchSpace +from ax.modelbridge.transforms.simplify_parameter_constraints import ( + SimplifyParameterConstraints, +) +from ax.utils.common.testutils import TestCase + + +class SimplifyParameterConstraintsTest(TestCase): + def setUp(self) -> None: + self.parameters: List[Parameter] = [ + RangeParameter("x", lower=1, upper=3, parameter_type=ParameterType.FLOAT), + RangeParameter("y", lower=2, upper=5, parameter_type=ParameterType.INT), + ChoiceParameter( + "z", parameter_type=ParameterType.STRING, values=["a", "b", "c"] + ), + ] + self.observation_features = [ + ObservationFeatures(parameters={"x": 2, "y": 2, "z": "b"}) + ] + + def test_transform_no_constraints(self) -> None: + t = SimplifyParameterConstraints() + ss = SearchSpace(parameters=self.parameters) + ss_transformed = t.transform_search_space(search_space=ss) + self.assertEqual(ss, ss_transformed) + self.assertEqual( + self.observation_features, + t.transform_observation_features(self.observation_features), + ) + + def test_transform_weight_zero(self) -> None: + t = SimplifyParameterConstraints() + ss = SearchSpace( + parameters=self.parameters, + parameter_constraints=[ + ParameterConstraint(constraint_dict={"x": 0}, bound=1) + ], + ) + ss_transformed = t.transform_search_space(search_space=deepcopy(ss)) + self.assertEqual(ss_transformed.parameter_constraints, []) + self.assertEqual(ss.parameters, ss_transformed.parameters) + ss_raises = SearchSpace( + parameters=self.parameters, + parameter_constraints=[ + ParameterConstraint(constraint_dict={"x": 0}, bound=-1) + ], + ) + with self.assertRaisesRegex( + ValueError, "Parameter constraint cannot be satisfied since the weight" + ): + ss_transformed = t.transform_search_space(search_space=deepcopy(ss_raises)) + + def test_transform_search_space(self) -> None: + t = SimplifyParameterConstraints() + ss = SearchSpace( + parameters=self.parameters, + parameter_constraints=[ + ParameterConstraint(constraint_dict={"x": 1}, bound=2), # x <= 2 + ParameterConstraint(constraint_dict={"y": -1}, bound=-4), # y => 4 + ], + ) + ss_transformed = t.transform_search_space(search_space=deepcopy(ss)) + self.assertEqual( + { + **ss.parameters, + "x": RangeParameter( + "x", parameter_type=ParameterType.FLOAT, lower=1, upper=2 + ), + "y": RangeParameter( + "y", parameter_type=ParameterType.INT, lower=4, upper=5 + ), + }, + ss_transformed.parameters, + ) + self.assertEqual(ss_transformed.parameter_constraints, []) + self.assertEqual( # No-op + self.observation_features, + t.transform_observation_features(self.observation_features), + ) + + def test_transform_to_fixed(self) -> None: + t = SimplifyParameterConstraints() + ss = SearchSpace( + parameters=self.parameters, + parameter_constraints=[ + ParameterConstraint(constraint_dict={"x": 1}, bound=1), # x == 1 + ParameterConstraint(constraint_dict={"y": -1}, bound=-5), # y == 5 + ], + ) + ss_transformed = t.transform_search_space(search_space=deepcopy(ss)) + self.assertEqual( + { + **ss.parameters, + "x": FixedParameter("x", parameter_type=ParameterType.FLOAT, value=1), + "y": FixedParameter("y", parameter_type=ParameterType.INT, value=5), + }, + ss_transformed.parameters, + ) + self.assertEqual(ss_transformed.parameter_constraints, []) + self.assertEqual( # No-op + self.observation_features, + t.transform_observation_features(self.observation_features), + ) diff --git a/ax/storage/transform_registry.py b/ax/storage/transform_registry.py index 4ffe9bb43c4..77a548819ee 100644 --- a/ax/storage/transform_registry.py +++ b/ax/storage/transform_registry.py @@ -29,6 +29,9 @@ ) from ax.modelbridge.transforms.remove_fixed import RemoveFixed from ax.modelbridge.transforms.search_space_to_choice import SearchSpaceToChoice +from ax.modelbridge.transforms.simplify_parameter_constraints import ( + SimplifyParameterConstraints, +) from ax.modelbridge.transforms.standardize_y import StandardizeY from ax.modelbridge.transforms.stratified_standardize_y import StratifiedStandardizeY from ax.modelbridge.transforms.task_encode import TaskEncode @@ -79,6 +82,7 @@ LogY: 23, Relativize: 24, RelativizeWithConstantControl: 25, + SimplifyParameterConstraints: 26, } diff --git a/sphinx/source/modelbridge.rst b/sphinx/source/modelbridge.rst index 32eeacf115c..5a18907d4a9 100644 --- a/sphinx/source/modelbridge.rst +++ b/sphinx/source/modelbridge.rst @@ -338,6 +338,14 @@ Transforms :undoc-members: :show-inheritance: +`ax.modelbridge.transforms.simplify_parameter_constraints` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: ax.modelbridge.transforms.simplify_parameter_constraints + :members: + :undoc-members: + :show-inheritance: + `ax.modelbridge.transforms.standardize\_y` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~