diff --git a/ax/modelbridge/generation_node_input_constructors.py b/ax/modelbridge/generation_node_input_constructors.py new file mode 100644 index 00000000000..45fb0be9346 --- /dev/null +++ b/ax/modelbridge/generation_node_input_constructors.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from enum import Enum, unique +from typing import Any, Dict, Optional + +from ax.modelbridge.generation_node import GenerationNode + + +def consume_all_n( + previous_node: Optional[GenerationNode], + next_node: GenerationNode, + gs_gen_call_kwargs: Dict[str, Any], +) -> int: + """Generate total requested number of arms from the next node. + + Example: Initial exploration with Sobol will generate all arms from a + single sobol node. + + Args: + previous_node: The previous node in the ``GenerationStrategy``. This is the node + that is being transition away from, and is provided for easy access to + properties of this node. + next_node: The next node in the ``GenerationStrategy``. This is the node that + will leverage the inputs defined by this input constructor. + gs_gen_call_kwargs: The kwargs passed to the ``GenerationStrategy``'s + gen call. + """ + # TODO: @mgarrard handle case where n isn't specified + if gs_gen_call_kwargs.get("n") is None: + raise NotImplementedError( + "Currently consume_all_n only supports cases where n is specified" + ) + return gs_gen_call_kwargs.get("n") + + +@unique +class NodeInputConstructors(Enum): + """An enum which maps to a callable method for constructing ``GenerationNode`` + inputs. + """ + + ALL_N = consume_all_n + + def __call__( + self, + previous_node: Optional[GenerationNode], + next_node: GenerationNode, + gs_gen_call_kwargs: Dict[str, Any], + ) -> int: + """Defines a callable method for the Enum as all values are methods""" + return self( + previous_node=previous_node, + next_node=next_node, + gs_gen_call_kwargs=gs_gen_call_kwargs, + ) + + +@unique +class InputConstructorPurpose(Enum): + """A simple enum to indicate the purpose of the input constructor. + + Explanation of the different purposes: + N: Defines the logic to determine the number of arms to generate from the + next ``GenerationNode`` given the total number of arms expected in + this trial. + """ + + N = "n" diff --git a/ax/modelbridge/tests/test_generation_node_input_constructors.py b/ax/modelbridge/tests/test_generation_node_input_constructors.py new file mode 100644 index 00000000000..0d7d7237431 --- /dev/null +++ b/ax/modelbridge/tests/test_generation_node_input_constructors.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from ax.modelbridge.generation_node import GenerationNode +from ax.modelbridge.generation_node_input_constructors import NodeInputConstructors +from ax.modelbridge.model_spec import ModelSpec +from ax.modelbridge.registry import Models +from ax.utils.common.testutils import TestCase + + +class TestGenerationNodeInputConstructors(TestCase): + def setUp(self) -> None: + super().setUp() + self.sobol_model_spec = ModelSpec( + model_enum=Models.SOBOL, + model_kwargs={"init_position": 3}, + model_gen_kwargs={"some_gen_kwarg": "some_value"}, + ) + self.sobol_generation_node = GenerationNode( + node_name="test", model_specs=[self.sobol_model_spec] + ) + + def test_consume_all_n_constructor(self) -> None: + """Test that the consume_all_n_constructor returns full n.""" + num_to_gen = NodeInputConstructors.ALL_N( + previous_node=None, + next_node=self.sobol_generation_node, + gs_gen_call_kwargs={"n": 5}, + ) + + self.assertEqual(num_to_gen, 5) + + def test_consume_all_n_constructor_no_n(self) -> None: + """Test raise error if n is not specified.""" + with self.assertRaisesRegex( + NotImplementedError, "only supports cases where n is specified" + ): + _ = NodeInputConstructors.ALL_N( + previous_node=None, + next_node=self.sobol_generation_node, + gs_gen_call_kwargs={}, + ) diff --git a/sphinx/source/modelbridge.rst b/sphinx/source/modelbridge.rst index d825f4f3989..4190a0de0e5 100644 --- a/sphinx/source/modelbridge.rst +++ b/sphinx/source/modelbridge.rst @@ -38,6 +38,12 @@ Transition Criterion :undoc-members: :show-inheritance: +Generation Node Input Constructors +.. automodule:: ax.modelbridge.generation_node_input_constructors + :members: + :undoc-members: + :show-inheritance: + Registry ~~~~~~~~ .. automodule:: ax.modelbridge.registry