Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use input constructors during GS #2761

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 43 additions & 6 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from logging import Logger
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

# Module-level import to avoid circular dependency b/w this file and
# generation_strategy.py
Expand All @@ -29,6 +29,7 @@
from ax.exceptions.generation_strategy import GenerationStrategyRepeatedPoints
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.best_model_selector import BestModelSelector

from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec
from ax.modelbridge.registry import _extract_model_state_after_gen, ModelRegistryBase
from ax.modelbridge.transition_criterion import (
Expand Down Expand Up @@ -84,6 +85,14 @@ class GenerationNode(SerializationMixin, SortableBase):
condition that must be met before completing a GenerationNode. All `is_met`
must evaluate True for the GenerationStrategy to move on to the next
GenerationNode.
input_constructors: A dictionary mapping input constructor purpose enum to the
input constructor enum. Each input constructor maps to a method which
encodes the logic for determining dynamic inputs to the ``GenerationNode``
previous_node: The previous ``GenerationNode`` name in the
``GenerationStrategy``, if any. Initialized to None for all nodes, and is
set during transition from one ``GenerationNode`` to the next. Can be
overwritten if multiple transitions occur between nodes, and will always
store the most recent previous ``GenerationNode`` name.

Note for developers: by "model" here we really mean an Ax ModelBridge object, which
contains an Ax Model under the hood. We call it "model" here to simplify and focus
Expand All @@ -100,6 +109,13 @@ class GenerationNode(SerializationMixin, SortableBase):
_model_spec_to_gen_from: Optional[ModelSpec] = None
# TODO: @mgarrard should this be a dict criterion_class name -> criterion mapping?
_transition_criteria: Optional[Sequence[TransitionCriterion]]
_input_constructors: Optional[
Dict[
modelbridge.generation_node_input_constructors.InputConstructorPurpose,
modelbridge.generation_node_input_constructors.NodeInputConstructors,
]
]
_previous_node: Optional[str] = None

# [TODO] Handle experiment passing more eloquently by enforcing experiment
# attribute is set in generation strategies class
Expand All @@ -114,6 +130,13 @@ def __init__(
best_model_selector: Optional[BestModelSelector] = None,
should_deduplicate: bool = False,
transition_criteria: Optional[Sequence[TransitionCriterion]] = None,
input_constructors: Optional[
Dict[
modelbridge.generation_node_input_constructors.InputConstructorPurpose,
modelbridge.generation_node_input_constructors.NodeInputConstructors,
]
] = None,
previous_node: Optional[str] = None,
) -> None:
self._node_name = node_name
# Check that the model specs have unique model keys.
Expand All @@ -128,6 +151,8 @@ def __init__(
self.best_model_selector = best_model_selector
self.should_deduplicate = should_deduplicate
self._transition_criteria = transition_criteria
self._input_constructors = input_constructors
self._previous_node = previous_node

@property
def node_name(self) -> str:
Expand Down Expand Up @@ -174,6 +199,18 @@ def transition_criteria(self) -> Sequence[TransitionCriterion]:
"""
return [] if self._transition_criteria is None else self._transition_criteria

@property
def input_constructors(
    self,
) -> Dict[
    modelbridge.generation_node_input_constructors.InputConstructorPurpose,
    modelbridge.generation_node_input_constructors.NodeInputConstructors,
]:
    """Input constructors used to determine any dynamic inputs to this
    ``GenerationNode``. An empty dict is returned when none were provided.
    """
    constructors = self._input_constructors
    if constructors is None:
        # Nothing was configured; hand back an empty mapping instead of None.
        return {}
    return constructors

@property
def experiment(self) -> Experiment:
"""Returns the experiment associated with this GenerationStrategy"""
Expand Down Expand Up @@ -479,7 +516,7 @@ def transition_edges(self) -> dict[str, list[TransitionCriterion]]:

def should_transition_to_next_node(
self, raise_data_required_error: bool = True
) -> tuple[bool, Optional[str]]:
) -> tuple[bool, str]:
"""Checks whether we should transition to the next node based on this node's
TransitionCriterion.

Expand All @@ -493,12 +530,12 @@ def should_transition_to_next_node(
check how many generator runs (to be made into trials) can be produced,
but not actually producing them yet.
Returns:
Tuple[bool, Optional[str]]: Whether we should transition to the next node
and the name of the next node.
Tuple[bool, str]: Whether we should transition to the next node
and the name of the node to gen from (either the current or next node)
"""
# if no transition criteria are defined, this node can generate unlimited trials
if len(self.transition_criteria) == 0:
return False, None
return False, self.node_name

# for each edge in node DAG, check if the transition criterion are met, if so
# transition to the next node defined by that edge.
Expand Down Expand Up @@ -542,7 +579,7 @@ def should_transition_to_next_node(
if len(transition_blocking) > 0 and transition_blocking_met:
return True, next_node

return False, None
return False, self.node_name

def generator_run_limit(self, raise_generation_errors: bool = False) -> int:
"""How many generator runs can this generation strategy generate right now,
Expand Down
138 changes: 138 additions & 0 deletions ax/modelbridge/generation_node_input_constructors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from enum import Enum, unique
from math import ceil
from typing import Any, Dict, Optional

from ax.modelbridge.generation_node import GenerationNode


def consume_all_n(
    previous_node: Optional[GenerationNode],
    next_node: GenerationNode,
    gs_gen_call_kwargs: Dict[str, Any],
) -> int:
    """Request the full number of arms asked for in the gen call from the
    next node.

    Example: Initial exploration with Sobol will generate all arms from a
    single sobol node.

    Args:
        previous_node: The previous node in the ``GenerationStrategy``. This is
            the node that is being transitioned away from, and is provided for
            easy access to properties of this node. Not read by this function.
        next_node: The next node in the ``GenerationStrategy``. This is the node
            that will leverage the inputs defined by this input constructor.
            Not read by this function.
        gs_gen_call_kwargs: The kwargs passed to the ``GenerationStrategy``'s
            gen call. Only the ``"n"`` entry is read.

    Returns:
        The value stored under ``"n"`` in ``gs_gen_call_kwargs``.

    Raises:
        NotImplementedError: If ``"n"`` is absent or ``None``.
    """
    # TODO: @mgarrard handle case where n isn't specified
    requested_n = gs_gen_call_kwargs.get("n")
    if requested_n is not None:
        return requested_n
    raise NotImplementedError(
        "Currently consume_all_n only supports cases where n is specified"
    )


def repeat_arm_n(
    previous_node: Optional[GenerationNode],
    next_node: GenerationNode,
    gs_gen_call_kwargs: Dict[str, Any],
) -> int:
    """Generate a small percentage of arms requested to be used for repeat arms
    in the next trial.

    Args:
        previous_node: The previous node in the ``GenerationStrategy``. This is
            the node that is being transitioned away from, and is provided for
            easy access to properties of this node. Not read by this function.
        next_node: The next node in the ``GenerationStrategy``. This is the node
            that will leverage the inputs defined by this input constructor.
            Not read by this function.
        gs_gen_call_kwargs: The kwargs passed to the ``GenerationStrategy``'s
            gen call. Only the ``"n"`` entry is read.

    Returns:
        The number of requested arms from the next node: ``0`` when fewer than
        6 arms were requested, otherwise ``ceil(n / 10)`` (which is ``1`` for
        6 through 10 requested arms).

    Raises:
        NotImplementedError: If ``"n"`` is absent or ``None``.
    """
    total_n = gs_gen_call_kwargs.get("n")
    if total_n is None:
        raise NotImplementedError(
            "Currently repeat_arm_n only supports cases where n is specified"
        )
    if total_n < 6:
        # if the next trial is small, we don't want to waste allocation on repeat
        # arms; users can still manually add repeat arms if they want before
        # allocation
        return 0
    # One repeat arm per started group of 10 requested arms. For 6 <= n <= 10
    # this evaluates to 1, so no separate branch is needed for that range.
    return ceil(total_n / 10)


def remaining_n(
    previous_node: Optional[GenerationNode],
    next_node: GenerationNode,
    gs_gen_call_kwargs: Dict[str, Any],
) -> int:
    """Generate the remaining number of arms requested for this trial in
    gs.gen().

    Args:
        previous_node: The previous node in the ``GenerationStrategy``. This is
            the node that is being transitioned away from, and is provided for
            easy access to properties of this node. Not read by this function.
        next_node: The next node in the ``GenerationStrategy``. This is the node
            that will leverage the inputs defined by this input constructor.
            Not read by this function.
        gs_gen_call_kwargs: The kwargs passed to the ``GenerationStrategy``'s
            gen call. Reads ``"n"`` (total arms requested) and ``"grs"`` (the
            generator runs produced so far, each exposing an ``arms``
            collection). A missing or ``None`` ``"grs"`` is treated as no
            generator runs having been produced yet.

    Returns:
        The number of requested arms from the next node: ``n`` minus the arms
        already present in ``grs``, floored at 0.

    Raises:
        NotImplementedError: If ``"n"`` is absent or ``None``.
    """
    total_n = gs_gen_call_kwargs.get("n")
    if total_n is None:
        raise NotImplementedError(
            "Currently remaining_n only supports cases where n is specified"
        )
    # Tolerate callers that have not populated "grs" yet instead of failing
    # with a TypeError when iterating over None.
    grs = gs_gen_call_kwargs.get("grs") or []
    already_generated = sum(len(gr.arms) for gr in grs)
    # if all arms have been generated (or over-generated), return 0
    return max(total_n - already_generated, 0)


@unique
class NodeInputConstructors(Enum):
    """An enum which maps to a callable method for constructing
    ``GenerationNode`` inputs.
    """

    # NOTE(review): plain functions are descriptors, and the ``enum`` machinery
    # does not turn descriptors into members. As written, these three names
    # resolve to the underlying functions themselves, so e.g.
    # ``NodeInputConstructors.ALL_N(...)`` calls ``consume_all_n`` directly.
    # Confirm whether real Enum members are wanted here.
    ALL_N = consume_all_n
    REPEAT_N = repeat_arm_n
    REMAINING_N = remaining_n

    def __call__(
        self,
        previous_node: Optional[GenerationNode],
        next_node: GenerationNode,
        gs_gen_call_kwargs: Dict[str, Any],
    ) -> int:
        """Invoke the input constructor wrapped by this enum member.

        Args:
            previous_node: The node being transitioned away from, if any.
            next_node: The node that will use the constructed input.
            gs_gen_call_kwargs: The kwargs passed to the
                ``GenerationStrategy``'s gen call.

        Returns:
            Whatever the wrapped input constructor returns (an arm count for
            the constructors defined in this enum).
        """
        # Dispatch to the wrapped callable. Calling ``self(...)`` here would
        # re-enter ``__call__`` and recurse without bound.
        return self.value(
            previous_node=previous_node,
            next_node=next_node,
            gs_gen_call_kwargs=gs_gen_call_kwargs,
        )


@unique
class InputConstructorPurpose(Enum):
    """Identifies which ``GenerationNode`` input an input constructor computes.

    Purposes:
        N: The input constructor decides how many arms to generate from the
            next ``GenerationNode``, given the total number of arms expected
            in this trial.
    """

    N = "n"
52 changes: 43 additions & 9 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_node import GenerationNode, GenerationStep
from ax.modelbridge.generation_node_input_constructors import InputConstructorPurpose
from ax.modelbridge.model_spec import FactoryFunctionModelSpec
from ax.modelbridge.modelbridge_utils import get_fixed_features_from_experiment
from ax.modelbridge.transition_criterion import TrialBasedCriterion
Expand Down Expand Up @@ -384,6 +385,7 @@ def gen_with_multiple_nodes(
data: Optional[Data] = None,
pending_observations: Optional[dict[str, list[ObservationFeatures]]] = None,
arms_per_node: Optional[dict[str, int]] = None,
n: int = 1,
) -> list[GeneratorRun]:
"""Produces a List of GeneratorRuns for a single trial, either ``Trial`` or
``BatchTrial``, and if producing a ``BatchTrial`` allows for multiple
Expand Down Expand Up @@ -427,25 +429,57 @@ def gen_with_multiple_nodes(
the spelling.
"""
)
if arms_per_node is None:
arms_per_node = {node_name: 1 for node_name in node_names}
grs = []
continue_gen_for_trial = True

# TODO: @mgarrard update this when gen methods are merged
gen_kwargs = {
"experiment": experiment,
"data": data,
"pending_observations": pending_observations,
"grs": grs,
"n": n,
}
while continue_gen_for_trial:
next_node_name = self.current_node_name
should_transition, next_node = self._curr.should_transition_to_next_node(
raise_data_required_error=False
gen_kwargs["grs"] = grs
should_transition, node_to_gen_from_name = (
self._curr.should_transition_to_next_node(
raise_data_required_error=False
)
)
node_to_gen_from = self._nodes[node_names.index(node_to_gen_from_name)]
if should_transition:
assert next_node is not None
next_node_name = next_node
node_to_gen_from._previous_node = node_to_gen_from_name

# calculate the number of arms to generate from the node that will be used
# during generation.
if arms_per_node is not None:
# arms_per_node provides a way to manually override input
# constructors. This should be used with caution, and only
# if you really know what you're doing. :)
arms_from_node = arms_per_node[node_to_gen_from_name]
elif InputConstructorPurpose.N not in node_to_gen_from.input_constructors:
# if the node does not have an input constructor for N, then we
# assume a default of using all n arms
arms_from_node = n
else:
previous_node = (
self._nodes[node_names.index(node_to_gen_from._previous_node)]
if node_to_gen_from._previous_node is not None
else None
)
arms_from_node = node_to_gen_from.input_constructors[
InputConstructorPurpose.N
](
previous_node=previous_node,
next_node=node_to_gen_from,
gs_gen_call_kwargs=gen_kwargs,
)
grs.extend(
self._gen_multiple(
experiment=experiment,
num_generator_runs=1,
data=data,
n=arms_per_node[next_node_name],
n=arms_from_node,
pending_observations=pending_observations,
)
)
Expand Down
23 changes: 23 additions & 0 deletions ax/modelbridge/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
GenerationStep,
MISSING_MODEL_SELECTOR_MESSAGE,
)
from ax.modelbridge.generation_node_input_constructors import (
InputConstructorPurpose,
NodeInputConstructors,
)
from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec
from ax.modelbridge.registry import Models
from ax.modelbridge.transition_criterion import MaxTrials
Expand Down Expand Up @@ -77,6 +81,25 @@ def test_init(self) -> None:
self.assertEqual(node.model_specs, mbm_specs)
self.assertIs(node.best_model_selector, model_selector)

def test_input_constructor_none(self) -> None:
    """A node built without input constructors stores ``None`` privately but
    exposes an empty dict through the public property.
    """
    node = self.sobol_generation_node
    self.assertIsNone(node._input_constructors)
    self.assertEqual(node.input_constructors, {})

def test_input_constructor(self) -> None:
    """Input constructors passed at construction are returned unchanged by both
    the public property and the private attribute.
    """
    expected = {InputConstructorPurpose.N: NodeInputConstructors.ALL_N}
    node = GenerationNode(
        node_name="test",
        model_specs=[self.sobol_model_spec],
        # Pass a separate dict so the comparisons below are by equality, not
        # by identity with ``expected``.
        input_constructors=dict(expected),
    )
    self.assertEqual(node.input_constructors, expected)
    self.assertEqual(node._input_constructors, expected)

def test_fit(self) -> None:
dat = self.branin_experiment.lookup_data()
with patch.object(
Expand Down
Loading
Loading