diff --git a/ax/modelbridge/generation_node_input_constructors.py b/ax/modelbridge/generation_node_input_constructors.py index d654df3c06e..7fc8938ab91 100644 --- a/ax/modelbridge/generation_node_input_constructors.py +++ b/ax/modelbridge/generation_node_input_constructors.py @@ -57,7 +57,7 @@ def repeat_arm_n( gs_gen_call_kwargs: The kwargs passed to the ``GenerationStrategy``'s gen call. Returns: - The number of requested arms from the next node. + The number of requested arms from the next node """ if gs_gen_call_kwargs.get("n") is None: raise NotImplementedError( @@ -73,6 +73,34 @@ def repeat_arm_n( return ceil(total_n / 10) +def remaining_n( + previous_node: Optional[GenerationNode], + next_node: GenerationNode, + gs_gen_call_kwargs: Dict[str, Any], +) -> int: + """Generate the remaining number of arms requested for this trial in gs.gen(). + + Args: + previous_node: The previous node in the ``GenerationStrategy``. This is the node + that is being transition away from, and is provided for easy access to + properties of this node. + next_node: The next node in the ``GenerationStrategy``. This is the node that + will leverage the inputs defined by this input constructor. + gs_gen_call_kwargs: The kwargs passed to the ``GenerationStrategy``'s + gen call. + Returns: + The number of requested arms from the next node + """ + if gs_gen_call_kwargs.get("n") is None: + raise NotImplementedError( + "Currently consume_all_n only supports cases where n is specified" + ) + grs = gs_gen_call_kwargs.get("grs") + total_n = gs_gen_call_kwargs.get("n") + # if all arms have been generated, return 0 + return max(total_n - sum(len(gr.arms) for gr in grs), 0) + + @unique class NodeInputConstructors(Enum): """An enum which maps to a callable method for constructing ``GenerationNode`` @@ -81,6 +109,7 @@ class NodeInputConstructors(Enum): ALL_N = consume_all_n REPEAT_N = repeat_arm_n + REMAINING_N = remaining_n def __call__( self, diff --git a/ax/modelbridge/tests/test_generation_node_input_constructors.py b/ax/modelbridge/tests/test_generation_node_input_constructors.py index 51be970476a..382222c58f6 100644 --- a/ax/modelbridge/tests/test_generation_node_input_constructors.py +++ b/ax/modelbridge/tests/test_generation_node_input_constructors.py @@ -5,11 +5,14 @@ # pyre-strict +from ax.core.arm import Arm +from ax.core.generator_run import GeneratorRun from ax.modelbridge.generation_node import GenerationNode from ax.modelbridge.generation_node_input_constructors import NodeInputConstructors from ax.modelbridge.model_spec import ModelSpec from ax.modelbridge.registry import Models from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_branin_experiment class TestGenerationNodeInputConstructors(TestCase): @@ -23,6 +26,7 @@ def setUp(self) -> None: self.sobol_generation_node = GenerationNode( node_name="test", model_specs=[self.sobol_model_spec] ) + self.experiment = get_branin_experiment(with_completed_trial=True) def test_consume_all_n_constructor(self) -> None: """Test that the consume_all_n_constructor returns full n.""" @@ -54,6 +58,38 @@ def test_repeat_arm_n_constructor(self) -> None: self.assertEqual(medium_n, 1) self.assertEqual(large_n, 2) + def test_remaining_n_constructor(self) -> None: + """Test that the remaining_n_constructor returns the remaining n.""" + # construct a list of grs that will mock a list of grs that would exist during + # a gs.gen call. This list has one single arm GR, and one 3-arm GR. + grs = [ + GeneratorRun(arms=[Arm(parameters={"x1": 1, "x2": 5})]), + GeneratorRun(arms=[Arm(parameters={"x1": 1, "x2": y}) for y in range(3)]), + ] + # should return 1 becuase 4 arms already exist and 5 are requested + expect_1 = NodeInputConstructors.REMAINING_N( + previous_node=None, + next_node=self.sobol_generation_node, + gs_gen_call_kwargs={"n": 5, "grs": grs}, + ) + # should return 0 becuase 4 arms already exist and 4 are requested + expect_0 = NodeInputConstructors.REMAINING_N( + previous_node=None, + next_node=self.sobol_generation_node, + gs_gen_call_kwargs={"n": 4, "grs": grs}, + ) + # should return 0 becuase 4 arms already exist and 3 are requested + # this is a bad state that should never be hit, but ensuring proper + # handling here feels like a valid edge case + expect_0b = NodeInputConstructors.REMAINING_N( + previous_node=None, + next_node=self.sobol_generation_node, + gs_gen_call_kwargs={"n": 3, "grs": grs}, + ) + self.assertEqual(expect_1, 1) + self.assertEqual(expect_0, 0) + self.assertEqual(expect_0b, 0) + def test_no_n_provided_error(self) -> None: """Test raise error if n is not specified.""" with self.assertRaisesRegex( @@ -72,3 +108,11 @@ def test_no_n_provided_error(self) -> None: next_node=self.sobol_generation_node, gs_gen_call_kwargs={}, ) + with self.assertRaisesRegex( + NotImplementedError, "only supports cases where n is specified" + ): + _ = NodeInputConstructors.REMAINING_N( + previous_node=None, + next_node=self.sobol_generation_node, + gs_gen_call_kwargs={}, + )