Skip to content

Commit

Permalink
add remaining n input constructor
Browse files Browse the repository at this point in the history
Summary:
This diff adds the input constructor for our default logic for repeat arms. See note in comments about how this new logic differs
Follow up diffs:
- update the transition logic to leverage this
- storage --> let's do this once we all like the 3 input constructors
- update the input constructors to handle the case where n isn't provided as a kwarg
- add tests for real world scenario per Liz suggestion
- add test to enforce signature across input constructors per Liz/Sait/Daniel suggestion

Differential Revision: D62467907
  • Loading branch information
mgarrard authored and facebook-github-bot committed Sep 10, 2024
1 parent 515d305 commit cd76008
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 1 deletion.
31 changes: 30 additions & 1 deletion ax/modelbridge/generation_node_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def repeat_arm_n(
gs_gen_call_kwargs: The kwargs passed to the ``GenerationStrategy``'s
gen call.
Returns:
The number of requested arms from the next node.
The number of requested arms from the next node
"""
if gs_gen_call_kwargs.get("n") is None:
raise NotImplementedError(
Expand All @@ -73,6 +73,34 @@ def repeat_arm_n(
return ceil(total_n / 10)


def remaining_n(
previous_node: Optional[GenerationNode],
next_node: GenerationNode,
gs_gen_call_kwargs: Dict[str, Any],
) -> int:
"""Generate the remaining number of arms requested for this trial in gs.gen().
Args:
previous_node: The previous node in the ``GenerationStrategy``. This is the node
that is being transition away from, and is provided for easy access to
properties of this node.
next_node: The next node in the ``GenerationStrategy``. This is the node that
will leverage the inputs defined by this input constructor.
gs_gen_call_kwargs: The kwargs passed to the ``GenerationStrategy``'s
gen call.
Returns:
The number of requested arms from the next node
"""
if gs_gen_call_kwargs.get("n") is None:
raise NotImplementedError(
"Currently consume_all_n only supports cases where n is specified"
)
grs = gs_gen_call_kwargs.get("grs")
total_n = gs_gen_call_kwargs.get("n")
# if all arms have been generated, return 0
return max(total_n - sum(len(gr.arms) for gr in grs), 0)


@unique
class NodeInputConstructors(Enum):
"""An enum which maps to a callable method for constructing ``GenerationNode``
Expand All @@ -81,6 +109,7 @@ class NodeInputConstructors(Enum):

ALL_N = consume_all_n
REPEAT_N = repeat_arm_n
REMAINING_N = remaining_n

def __call__(
self,
Expand Down
44 changes: 44 additions & 0 deletions ax/modelbridge/tests/test_generation_node_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

# pyre-strict

from ax.core.arm import Arm
from ax.core.generator_run import GeneratorRun
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.generation_node_input_constructors import NodeInputConstructors
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_experiment


class TestGenerationNodeInputConstructors(TestCase):
Expand All @@ -23,6 +26,7 @@ def setUp(self) -> None:
self.sobol_generation_node = GenerationNode(
node_name="test", model_specs=[self.sobol_model_spec]
)
self.experiment = get_branin_experiment(with_completed_trial=True)

def test_consume_all_n_constructor(self) -> None:
"""Test that the consume_all_n_constructor returns full n."""
Expand Down Expand Up @@ -54,6 +58,38 @@ def test_repeat_arm_n_constructor(self) -> None:
self.assertEqual(medium_n, 1)
self.assertEqual(large_n, 2)

def test_remaining_n_constructor(self) -> None:
"""Test that the remaining_n_constructor returns the remaining n."""
# construct a list of grs that will mock a list of grs that would exist during
# a gs.gen call. This list has one single arm GR, and one 3-arm GR.
grs = [
GeneratorRun(arms=[Arm(parameters={"x1": 1, "x2": 5})]),
GeneratorRun(arms=[Arm(parameters={"x1": 1, "x2": y}) for y in range(3)]),
]
# should return 1 becuase 4 arms already exist and 5 are requested
expect_1 = NodeInputConstructors.REMAINING_N(
previous_node=None,
next_node=self.sobol_generation_node,
gs_gen_call_kwargs={"n": 5, "grs": grs},
)
# should return 0 becuase 4 arms already exist and 4 are requested
expect_0 = NodeInputConstructors.REMAINING_N(
previous_node=None,
next_node=self.sobol_generation_node,
gs_gen_call_kwargs={"n": 4, "grs": grs},
)
# should return 0 becuase 4 arms already exist and 3 are requested
# this is a bad state that should never be hit, but ensuring proper
# handling here feels like a valid edge case
expect_0b = NodeInputConstructors.REMAINING_N(
previous_node=None,
next_node=self.sobol_generation_node,
gs_gen_call_kwargs={"n": 3, "grs": grs},
)
self.assertEqual(expect_1, 1)
self.assertEqual(expect_0, 0)
self.assertEqual(expect_0b, 0)

def test_no_n_provided_error(self) -> None:
"""Test raise error if n is not specified."""
with self.assertRaisesRegex(
Expand All @@ -72,3 +108,11 @@ def test_no_n_provided_error(self) -> None:
next_node=self.sobol_generation_node,
gs_gen_call_kwargs={},
)
with self.assertRaisesRegex(
NotImplementedError, "only supports cases where n is specified"
):
_ = NodeInputConstructors.REMAINING_N(
previous_node=None,
next_node=self.sobol_generation_node,
gs_gen_call_kwargs={},
)

0 comments on commit cd76008

Please sign in to comment.