diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 25e3aeb9a11..38b61afa7a1 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -571,16 +571,24 @@ def __init__( ) values = list(dict_values) + if is_ordered is False and len(values) == 2: + is_ordered = True + warn( + f"Changing `is_ordered` to `True` for `ChoiceParameter` '{name}' since " + "there are only two possible values.", + AxParameterWarning, + stacklevel=3, + ) self._is_ordered: bool = ( is_ordered if is_ordered is not None - else self._get_default_bool_and_warn(param_string="is_ordered") + else self._get_default_is_ordered_and_warn(num_choices=len(values)) ) # sort_values defaults to True if the parameter is not a string self._sort_values: bool = ( sort_values if sort_values is not None - else self._get_default_bool_and_warn(param_string="sort_values") + else self._get_default_sort_values_and_warn() ) if self.sort_values: values = cast(List[TParamValue], sorted([not_none(v) for v in values])) @@ -597,14 +605,33 @@ def __init__( # that is done in `HierarchicalSearchSpace` constructor. self._dependents = dependents - def _get_default_bool_and_warn(self, param_string: str) -> bool: + def _get_default_is_ordered_and_warn(self, num_choices: int) -> bool: + default_bool = self._parameter_type != ParameterType.STRING or num_choices == 2 + if self._parameter_type == ParameterType.STRING and num_choices > 2: + motivation = " since the parameter is a string with more than 2 choices." + elif num_choices == 2: + motivation = " since there are exactly two choices." + else: + motivation = " since the parameter is not of type string." + warn( + f'`is_ordered` is not specified for `ChoiceParameter` "{self._name}". ' + f"Defaulting to `{default_bool}` {motivation}. To override this behavior " + f"(or avoid this warning), specify `is_ordered` during `ChoiceParameter` " + "construction. Note that choice parameters with exactly 2 choices are " + "always considered ordered and that the user-supplied `is_ordered` has no " + "effect in this particular case.", + AxParameterWarning, + stacklevel=3, + ) + return default_bool + + def _get_default_sort_values_and_warn(self) -> bool: default_bool = self._parameter_type != ParameterType.STRING warn( - f'`{param_string}` is not specified for `ChoiceParameter` "{self._name}". ' + f'`sort_values` is not specified for `ChoiceParameter` "{self._name}". ' f"Defaulting to `{default_bool}` for parameters of `ParameterType` " f"{self.parameter_type.name}. To override this behavior (or avoid this " - f"warning), specify `{param_string}` during `ChoiceParameter` " - "construction.", + f"warning), specify `sort_values` during `ChoiceParameter` construction.", AxParameterWarning, stacklevel=3, ) diff --git a/ax/core/tests/test_parameter.py b/ax/core/tests/test_parameter.py index a5236f296a4..dbc5b5825e3 100644 --- a/ax/core/tests/test_parameter.py +++ b/ax/core/tests/test_parameter.py @@ -16,7 +16,7 @@ ParameterType, RangeParameter, ) -from ax.exceptions.core import AxWarning, UserInputError +from ax.exceptions.core import AxParameterWarning, AxWarning, UserInputError from ax.utils.common.testutils import TestCase from ax.utils.common.typeutils import not_none @@ -242,7 +242,7 @@ def setUp(self) -> None: ) self.param3_repr = ( "ChoiceParameter(name='x', parameter_type=STRING, " - "values=['foo', 'bar'], is_fidelity=True, is_ordered=False, " + "values=['foo', 'bar'], is_fidelity=True, is_ordered=True, " "sort_values=False, target_value='bar')" ) self.param4 = ChoiceParameter( @@ -485,7 +485,7 @@ def test_summary_dict(self) -> None: "type": "Choice", "domain": "values=['foo', 'bar']", "parameter_type": "string", - "flags": "fidelity, unordered, unsorted", + "flags": "fidelity, ordered, unsorted", "target_value": "bar", }, ) @@ -509,6 +509,48 @@ def test_duplicate_values(self) -> None: ) self.assertEqual(p.values, ["foo", "bar"]) + def test_two_values_is_ordered(self) -> None: + parameter_types = ( + ParameterType.INT, + ParameterType.FLOAT, + ParameterType.BOOL, + ParameterType.STRING, + ) + parameter_values = ([0, 4], [0, 1.234], [False, True], ["foo", "bar"]) + for parameter_type, values in zip(parameter_types, parameter_values): + p = ChoiceParameter( + name="x", + parameter_type=parameter_type, + values=values, # pyre-ignore + ) + self.assertEqual(p._is_ordered, True) + + # Change `is_ordered` to True and warn + with self.assertWarnsRegex( + AxParameterWarning, + "Changing `is_ordered` to `True` for `ChoiceParameter` 'x' since " + "there are only two possible values", + ): + p = ChoiceParameter( + name="x", + parameter_type=parameter_type, + values=values, # pyre-ignore + is_ordered=False, + ) + self.assertEqual(p._is_ordered, True) + + # Set to True if `is_ordered` is not specified + with self.assertWarnsRegex( + AxParameterWarning, "since there are exactly two choices" + ): + p = ChoiceParameter( + name="x", + parameter_type=parameter_type, + values=values, # pyre-ignore + sort_values=False, + ) + self.assertEqual(p._is_ordered, True) + class FixedParameterTest(TestCase): def setUp(self) -> None: diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index 1197058ac56..f7c0ec88ab3 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -850,10 +850,6 @@ def test_hierarchical_search_space(self) -> None: all_parameter_names = checked_cast( HierarchicalSearchSpace, experiment.search_space )._all_parameter_names.copy() - # One of the parameter names is modified by transforms (because it's - # one-hot encoded). - all_parameter_names.remove("model") - all_parameter_names.add("model_OH_PARAM_") for obs in observations: for p_name in all_parameter_names: self.assertIn(p_name, obs.features.parameters) diff --git a/ax/modelbridge/transforms/tests/test_one_hot_transform.py b/ax/modelbridge/transforms/tests/test_one_hot_transform.py index 50baf7a0c7e..18fd799eb08 100644 --- a/ax/modelbridge/transforms/tests/test_one_hot_transform.py +++ b/ax/modelbridge/transforms/tests/test_one_hot_transform.py @@ -36,7 +36,6 @@ def setUp(self) -> None: "c", parameter_type=ParameterType.BOOL, values=[True, False], - is_ordered=False, ), ChoiceParameter( "d", @@ -66,8 +65,7 @@ def setUp(self) -> None: "b" + OH_PARAM_INFIX + "_0": 0, "b" + OH_PARAM_INFIX + "_1": 1, "b" + OH_PARAM_INFIX + "_2": 0, - # Only two choices => one parameter. - "c" + OH_PARAM_INFIX: 0, + "c": False, "d": 10.0, } ) @@ -76,8 +74,8 @@ def setUp(self) -> None: ) def test_Init(self) -> None: - self.assertEqual(list(self.t.encoded_parameters.keys()), ["b", "c"]) - self.assertEqual(list(self.t2.encoded_parameters.keys()), ["b", "c"]) + self.assertEqual(list(self.t.encoded_parameters.keys()), ["b"]) + self.assertEqual(list(self.t2.encoded_parameters.keys()), ["b"]) def test_TransformObservationFeatures(self) -> None: observation_features = [self.observation_features] @@ -126,9 +124,7 @@ def test_TransformSearchSpace(self) -> None: ss2.parameters["b" + OH_PARAM_INFIX + "_1"].parameter_type, ParameterType.FLOAT, ) - self.assertEqual( - ss2.parameters["c" + OH_PARAM_INFIX].parameter_type, ParameterType.FLOAT - ) + self.assertEqual(ss2.parameters["c"].parameter_type, ParameterType.BOOL) self.assertEqual(ss2.parameters["d"].parameter_type, ParameterType.FLOAT) # Parameter range fixed to [0,1]. @@ -136,8 +132,7 @@ def test_TransformSearchSpace(self) -> None: self.assertEqual(ss2.parameters["b" + OH_PARAM_INFIX + "_0"].lower, 0.0) # pyre-fixme[16]: `Parameter` has no attribute `upper`. self.assertEqual(ss2.parameters["b" + OH_PARAM_INFIX + "_1"].upper, 1.0) - self.assertEqual(ss2.parameters["c" + OH_PARAM_INFIX].lower, 0.0) - self.assertEqual(ss2.parameters["c" + OH_PARAM_INFIX].upper, 1.0) + self.assertEqual(ss2.parameters["c"].parameter_type, ParameterType.BOOL) # Ensure we error if we try to transform a fidelity parameter ss3 = SearchSpace( @@ -158,14 +153,11 @@ def test_TransformSearchSpace(self) -> None: def test_w_parameter_distributions(self) -> None: rss = get_robust_search_space() # Transform a non-distributional parameter. - t = OneHot( - search_space=rss, - observations=[], - ) + t = OneHot(search_space=rss, observations=[]) rss_new = t.transform_search_space(rss) # Make sure that the return value is still a RobustSearchSpace. self.assertIsInstance(rss_new, RobustSearchSpace) - self.assertEqual(len(rss_new.parameters.keys()), 4) + self.assertEqual(len(rss_new.parameters.keys()), 6) # pyre-fixme[16]: `SearchSpace` has no attribute `parameter_distributions`. self.assertEqual(rss.parameter_distributions, rss_new.parameter_distributions) self.assertNotIn("c", rss_new.parameters) @@ -183,7 +175,7 @@ def test_w_parameter_distributions(self) -> None: ) rss_new = t.transform_search_space(rss) self.assertIsInstance(rss_new, RobustSearchSpace) - self.assertEqual(len(rss_new.parameters.keys()), 4) + self.assertEqual(len(rss_new.parameters.keys()), 6) self.assertEqual(rss.parameter_distributions, rss_new.parameter_distributions) # pyre-fixme[16]: `SearchSpace` has no attribute `_environmental_variables`. self.assertEqual(rss._environmental_variables, rss_new._environmental_variables) diff --git a/ax/modelbridge/transforms/tests/test_trial_as_task_transform.py b/ax/modelbridge/transforms/tests/test_trial_as_task_transform.py index be83e230319..4fdd3232a5a 100644 --- a/ax/modelbridge/transforms/tests/test_trial_as_task_transform.py +++ b/ax/modelbridge/transforms/tests/test_trial_as_task_transform.py @@ -160,7 +160,7 @@ def test_TransformSearchSpace(self) -> None: self.assertEqual(p.parameter_type, ParameterType.STRING) self.assertEqual(set(p.values), {"u1", "u2"}) self.assertTrue(p.is_task) - self.assertFalse(p.is_ordered) + self.assertTrue(p.is_ordered) # 2 choices so always ordered self.assertEqual(p.target_value, "u1") t = TrialAsTask( search_space=self.search_space, diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index c44bfccfd3d..3c645ff1327 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -1112,7 +1112,7 @@ def get_robust_search_space( RangeParameter("x", ParameterType.FLOAT, lb, ub), RangeParameter("y", ParameterType.FLOAT, lb, ub), RangeParameter("z", ParameterType.INT, lb, ub), - ChoiceParameter("c", ParameterType.STRING, ["red", "panda"]), + ChoiceParameter("c", ParameterType.STRING, ["red", "blue", "green"]), ] if multivariate: if use_discrete: