Skip to content

Commit

Permalink
Remove gen_unlimited trials (#2347)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2347

This diff removes the gen_unlimited_trials property from GenerationNode based on a recent discussion with the team about not needing this property anymore. It is pretty duplicative of the transition_criteria, and adds extra complexity in the code that isn't outweighing the verboseness of the property.

Reviewed By: lena-kashtelyan

Differential Revision: D55001716

fbshipit-source-id: 41f8f2511923c4f2114fd9630c5a0515963fb1d5
  • Loading branch information
mgarrard authored and facebook-github-bot committed Apr 12, 2024
1 parent 9c0bd3f commit 8401990
Show file tree
Hide file tree
Showing 9 changed files with 6 additions and 92 deletions.
5 changes: 0 additions & 5 deletions ax/modelbridge/external_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def __init__(
node_name: str,
should_deduplicate: bool = True,
transition_criteria: Optional[Sequence[TransitionCriterion]] = None,
gen_unlimited_trials: bool = True,
) -> None:
"""Initialize an external generation node.
Expand All @@ -73,9 +72,6 @@ def __init__(
transition_criteria: Criteria for determining whether to move to the next
node in the generation strategy. This is an advanced option that is
only relevant if the generation strategy consists of multiple nodes.
gen_unlimited_trials: Whether to generate unlimited trials from this node.
This should only be False if the generation strategy will transition to
another node after generating a limited number of trials from this node.
"""
t_init_start = time.monotonic()
super().__init__(
Expand All @@ -84,7 +80,6 @@ def __init__(
best_model_selector=None,
should_deduplicate=should_deduplicate,
transition_criteria=transition_criteria,
gen_unlimited_trials=gen_unlimited_trials,
)
self.fit_time_since_gen: float = time.monotonic() - t_init_start

Expand Down
29 changes: 3 additions & 26 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,6 @@ class GenerationNode(SerializationMixin, SortableBase):
condition that must be met before completing a GenerationNode. All `is_met`
must evaluateTrue for the GenerationStrategy to move on to the next
GenerationNode.
gen_unlimited_trials: If True the number of trials that can be generated from
this GenerationNode is unlimited.
Note for developers: by "model" here we really mean an Ax ModelBridge object, which
contains an Ax Model under the hood. We call it "model" here to simplify and focus
Expand All @@ -93,7 +91,6 @@ class GenerationNode(SerializationMixin, SortableBase):
# TODO: Move `should_deduplicate` to `ModelSpec` if possible, and make optional
should_deduplicate: bool
_node_name: str
_gen_unlimited_trials: bool = True

# Optional specifications
_model_spec_to_gen_from: Optional[ModelSpec] = None
Expand All @@ -113,7 +110,6 @@ def __init__(
best_model_selector: Optional[BestModelSelector] = None,
should_deduplicate: bool = False,
transition_criteria: Optional[Sequence[TransitionCriterion]] = None,
gen_unlimited_trials: bool = True,
) -> None:
self._node_name = node_name
# While `GenerationNode` only handles a single `ModelSpec` in the `gen`
Expand All @@ -125,7 +121,6 @@ def __init__(
self.best_model_selector = best_model_selector
self.should_deduplicate = should_deduplicate
self._transition_criteria = transition_criteria
self._gen_unlimited_trials = gen_unlimited_trials

@property
def node_name(self) -> str:
Expand Down Expand Up @@ -174,11 +169,6 @@ def experiment(self) -> Experiment:
"""Returns the experiment associated with this GenerationStrategy"""
return self.generation_strategy.experiment

@property
def gen_unlimited_trials(self) -> bool:
"""If True, this GenerationNode can generate unlimited trials."""
return self._gen_unlimited_trials

@property
def is_completed(self) -> bool:
"""Returns True if this GenerationNode is complete and should transition to
Expand Down Expand Up @@ -405,7 +395,8 @@ def should_transition_to_next_node(
Returns:
bool: Whether we should transition to the next node.
"""
if self.gen_unlimited_trials and len(self.transition_criteria) == 0:
# if no transition criteria are defined, this node can generate unlimited trials
if len(self.transition_criteria) == 0:
return False, None

transition_blocking = [
Expand Down Expand Up @@ -501,12 +492,6 @@ def generator_run_limit(self, supress_generation_errors: bool = True) -> int:
trials_from_node=self.trials_from_node,
)
if len(gen_blocking_criterion_delta_from_threshold) == 0:
if not self.gen_unlimited_trials:
logger.warning(
"Even though this node is not flagged for generation of unlimited "
"trials, there are no generation blocking criterion, therefore, "
"unlimited trials will be generated."
)
return -1
return min(gen_blocking_criterion_delta_from_threshold)

Expand All @@ -517,9 +502,7 @@ def __repr__(self) -> str:
model_spec_str = str(self.model_specs).replace("\n", " ").replace("\t", "")
str_rep += model_spec_str

# add node name, gen_unlimited_trials, and transition_criteria
str_rep += f", node_name={self.node_name}"
str_rep += f", gen_unlimited_trials={str(self.gen_unlimited_trials)}"
str_rep += f", transition_criteria={str(self.transition_criteria)}"

return f"{str_rep})"
Expand Down Expand Up @@ -667,7 +650,6 @@ def __post_init__(self) -> None:
# because only then is the order of the generation steps actually known.
transition_criteria = []
if self.num_trials != -1:
gen_unlimited_trials = False
transition_criteria.append(
MaxTrials(
threshold=self.num_trials,
Expand All @@ -676,8 +658,7 @@ def __post_init__(self) -> None:
block_transition_if_unmet=True,
)
)
else:
gen_unlimited_trials = True

if self.min_trials_observed > 0:
transition_criteria.append(
MinTrials(
Expand All @@ -701,16 +682,12 @@ def __post_init__(self) -> None:
)
)

if len(self.completion_criteria) > 0:
gen_unlimited_trials = False

transition_criteria += self.completion_criteria
super().__init__(
node_name=f"GenerationStep_{str(self.index)}",
model_specs=[model_spec],
should_deduplicate=self.should_deduplicate,
transition_criteria=transition_criteria,
gen_unlimited_trials=gen_unlimited_trials,
)

@property
Expand Down
13 changes: 1 addition & 12 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,7 @@
extend_pending_observations,
get_pending_observation_features_based_on_trial_status,
)
from ax.exceptions.core import (
AxError,
DataRequiredError,
UnsupportedError,
UserInputError,
)
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
from ax.exceptions.generation_strategy import (
GenerationStrategyCompleted,
GenerationStrategyMisconfiguredException,
Expand Down Expand Up @@ -750,12 +745,6 @@ def _maybe_move_to_next_step(self, raise_data_required_error: bool = True) -> bo
raise_data_required_error=raise_data_required_error
)
if move_to_next_node:
if self._curr.gen_unlimited_trials:
raise AxError(
"The generation strategy is attempting to transition to next node "
"despite the current node being configured to generate unlimited "
"trials. This should not happen."
)
if self.optimization_complete:
raise GenerationStrategyCompleted(
f"Generation strategy {self} generated all the trials as "
Expand Down
36 changes: 1 addition & 35 deletions ax/modelbridge/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

# pyre-strict

import logging
from logging import Logger
from unittest.mock import patch, PropertyMock

Expand Down Expand Up @@ -143,7 +142,6 @@ def test_properties(self) -> None:
node.model_spec_to_gen_from.diagnostics, node.model_specs[0].diagnostics
)
self.assertEqual(node.node_name, "test")
self.assertEqual(node.gen_unlimited_trials, True)
self.assertEqual(node._unique_id, "test")

def test_node_string_representation(self) -> None:
Expand All @@ -156,7 +154,6 @@ def test_node_string_representation(self) -> None:
model_gen_kwargs={},
),
],
gen_unlimited_trials=False,
transition_criteria=[
MaxTrials(threshold=5, only_in_statuses=[TrialStatus.RUNNING])
],
Expand All @@ -167,7 +164,7 @@ def test_node_string_representation(self) -> None:
(
"GenerationNode(model_specs=[ModelSpec(model_enum=GPEI,"
" model_kwargs={}, model_gen_kwargs={}, model_cv_kwargs={},"
" )], node_name=test, gen_unlimited_trials=False, "
" )], node_name=test, "
"transition_criteria=[MaxTrials({'threshold': 5, "
"'only_in_statuses': [<enum 'TrialStatus'>.RUNNING], "
"'not_in_statuses': None, 'transition_to': None, "
Expand Down Expand Up @@ -195,37 +192,6 @@ def test_single_fixed_features(self) -> None:
ObservationFeatures(parameters={"x": 0}),
)

def test_generator_run_limit_unlimited_without_flag(self) -> None:
"""This tests checks that when the `gen_unlimited_trials` flag is false
but there are no generation blocking criteria, then the generator run limit
is set to -1 and a warning is logged.
"""
node = GenerationNode(
node_name="test",
model_specs=[
ModelSpec(
model_enum=Models.GPEI,
model_kwargs={},
model_gen_kwargs={
"n": -1,
"fixed_features": ObservationFeatures(parameters={"x": 0}),
},
),
],
gen_unlimited_trials=False,
)
warning_msg = (
"Even though this node is not flagged for generation of unlimited "
"trials, there are no generation blocking criterion, therefore, "
"unlimited trials will be generated."
)
with self.assertLogs(GenerationNode.__module__, logging.WARNING) as logger:
self.assertEqual(node.generator_run_limit(), -1)
self.assertTrue(
any(warning_msg in output for output in logger.output),
logger.output,
)


class TestGenerationStep(TestCase):
def setUp(self) -> None:
Expand Down
10 changes: 1 addition & 9 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,11 @@ def setUp(self) -> None:
node_name="sobol_node",
transition_criteria=self.sobol_criterion,
model_specs=[self.sobol_model_spec],
gen_unlimited_trials=False,
)
self.gpei_node = GenerationNode(
node_name="GPEI_node",
transition_criteria=self.gpei_criterion,
model_specs=[self.gpei_model_spec],
gen_unlimited_trials=False,
)

self.sobol_GPEI_GS_nodes = GenerationStrategy(
Expand Down Expand Up @@ -304,7 +302,7 @@ def test_string_representation(self) -> None:
"GenerationStrategy(name='Sobol', nodes=[GenerationNode("
"model_specs=[ModelSpec(model_enum=Sobol, "
"model_kwargs={}, model_gen_kwargs={}, model_cv_kwargs={},"
" )], node_name=test, gen_unlimited_trials=True, "
" )], node_name=test, "
"transition_criteria=[])])",
)

Expand Down Expand Up @@ -637,7 +635,6 @@ def test_trials_as_df(self) -> None:
transition_to="sobol_3_trial",
)
],
gen_unlimited_trials=False,
),
GenerationNode(
node_name="sobol_3_trial",
Expand All @@ -651,7 +648,6 @@ def test_trials_as_df(self) -> None:
transition_to=None,
)
],
gen_unlimited_trials=False,
),
]
)
Expand Down Expand Up @@ -1296,9 +1292,6 @@ def test_gs_with_nodes_and_blocking_criteria(self) -> None:
block_transition_if_unmet=True,
),
],
# If we remove this, the test will fail.
# This behavior needs to be improved.
gen_unlimited_trials=False,
)
gpei_model_spec = ModelSpec(
model_enum=Models.GPEI,
Expand All @@ -1308,7 +1301,6 @@ def test_gs_with_nodes_and_blocking_criteria(self) -> None:
gpei_node = GenerationNode(
node_name="GPEI_node",
model_specs=[gpei_model_spec],
gen_unlimited_trials=True,
)
gs = GenerationStrategy(
name="Sobol+GPEI_Nodes",
Expand Down
1 change: 0 additions & 1 deletion ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,6 @@ def generation_node_from_json(
# isn't implemented
best_model_selector=generation_node_json.pop("best_model_selector", None),
should_deduplicate=generation_node_json.pop("should_deduplicate", False),
gen_unlimited_trials=generation_node_json.pop("gen_unlimited_trials", True),
transition_criteria=(
object_from_json(
generation_node_json.pop("transition_criteria"),
Expand Down
1 change: 0 additions & 1 deletion ax/storage/json_store/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,6 @@ def generation_node_to_dict(generation_node: GenerationNode) -> Dict[str, Any]:
"model_specs": generation_node.model_specs,
"should_deduplicate": generation_node.should_deduplicate,
"node_name": generation_node.node_name,
"gen_unlimited_trials": generation_node.gen_unlimited_trials,
"model_spec_to_gen_from": generation_node._model_spec_to_gen_from,
"transition_criteria": generation_node.transition_criteria,
}
Expand Down
2 changes: 0 additions & 2 deletions ax/utils/testing/modeling_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,11 @@ def sobol_gpei_generation_node_gs() -> GenerationStrategy:
node_name="sobol_node",
transition_criteria=sobol_criterion,
model_specs=[sobol_model_spec],
gen_unlimited_trials=False,
)
gpei_node = GenerationNode(
node_name="GPEI_node",
transition_criteria=gpei_criterion,
model_specs=[gpei_model_spec],
gen_unlimited_trials=False,
)

sobol_GPEI_GS_nodes = GenerationStrategy(
Expand Down
1 change: 0 additions & 1 deletion tutorials/external_generation_node.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@
" block_transition_if_unmet=True,\n",
" )\n",
" ],\n",
" gen_unlimited_trials=False,\n",
" ),\n",
" RandomForestGenerationNode(num_samples=128, regressor_options={}),\n",
" ],\n",
Expand Down

0 comments on commit 8401990

Please sign in to comment.