Skip to content

Commit

Permalink
Remove unused "run trials status" (#2668)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2668

Removing unused functionality

Reviewed By: Balandat

Differential Revision: D54885248

fbshipit-source-id: 4b8002e628a271afc541778a8cc12fe3934eea2d
  • Loading branch information
Lena Kashtelyan authored and facebook-github-bot committed Sep 5, 2024
1 parent 0fe2fd5 commit eaeefc6
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 119 deletions.
87 changes: 9 additions & 78 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from collections.abc import Generator, Iterable
from copy import deepcopy
from datetime import datetime
from enum import Enum
from logging import LoggerAdapter
from time import sleep
from typing import Any, Callable, cast, NamedTuple, Optional
Expand Down Expand Up @@ -52,7 +51,7 @@
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.service.utils.scheduler_options import SchedulerOptions, TrialType
from ax.service.utils.with_db_settings_base import DBSettings, WithDBSettingsBase
from ax.utils.common.constants import Keys
from ax.utils.common.constants import Keys, TS_FMT
from ax.utils.common.docutils import copy_doc
from ax.utils.common.executils import retry_on_exception
from ax.utils.common.logger import (
Expand Down Expand Up @@ -113,32 +112,6 @@ class FailureRateExceededError(AxError):
)


class ExperimentStatusProperties(str, Enum):
"""Enum for keys in experiment properties that represent the status of an
optimization run through the scheduler."""

# Number of trials run in each call to
# `Scheduler.run_trials_and_yield_results`.
NUM_TRIALS_RUN_PER_CALL = "num_trials_run_per_call"
# Status of each run of `Scheduler.run_trials_and_yield_results`.
# Recorded twice in a successful/aborted run: first "started" is recorded,
# then "success" or "aborted". If no second status is recorded, the run
# must have encountered an exception.
# NOTE: the stored key is "run_trials_success", not "run_trials_status".
RUN_TRIALS_STATUS = "run_trials_success"
# Timestamps of when the experiment was resumed from storage.
RESUMED_FROM_STORAGE_TIMESTAMPS = "resumed_from_storage_timestamps"


class RunTrialsStatus(str, Enum):
"""Possible statuses for each call to
``Scheduler.run_trials_and_yield_results``, used in recording experiment
status.
"""

# Recorded in `run_trials_and_yield_results` before the scheduling loop.
STARTED = "started"
# Recorded by `_complete_optimization` after the failure-rate check passes.
SUCCESS = "success"
# Recorded by `_abort_optimization`.
ABORTED = "aborted"


class Scheduler(WithDBSettingsBase, BestPointMixin):
"""Closed-loop manager class for Ax optimization.
Expand Down Expand Up @@ -233,7 +206,6 @@ def __init__(
self._validate_remaining_trials(experiment=experiment)
if self.options.enforce_immutable_search_space_and_opt_config:
self._enforce_immutable_search_space_and_opt_config()
self._initialize_experiment_status_properties()

if self.db_settings_set and not _skip_experiment_save:
self._maybe_save_experiment_and_generation_strategy(
Expand Down Expand Up @@ -331,12 +303,7 @@ def from_stored_experiment(
# provided to this function.
**kwargs,
)
ts = datetime.strftime(datetime.now(), "%Y-%m-%d %H:%M:%S.%f")
scheduler._append_to_experiment_properties(
to_append={
ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS: ts,
}
)
scheduler._record_experiment_resumption_from_storage()
return scheduler

@property
Expand Down Expand Up @@ -934,10 +901,6 @@ def run_trials_and_yield_results(
# trials are pre-existing only if they do not still require running
n_existing = len(self.experiment.trials) - n_initial_candidate_trials

self._record_run_trials_status(
num_preexisting_trials=None, status=RunTrialsStatus.STARTED
)

# Until completion criterion is reached or `max_trials` is scheduled,
# schedule new trials and poll existing ones in a loop.
self._num_remaining_requested_trials = max_trials
Expand Down Expand Up @@ -1547,10 +1510,6 @@ def _abort_optimization(self, num_preexisting_trials: int) -> dict[str, Any]:
return results so far via `report_results`.
"""
self._record_optimization_complete_message()
self._record_run_trials_status(
num_preexisting_trials=num_preexisting_trials,
status=RunTrialsStatus.ABORTED,
)
return self.report_results(force_refit=True)

def _complete_optimization(
Expand All @@ -1569,10 +1528,6 @@ def _complete_optimization(
# Raise an error if the failure rate exceeds tolerance at the
# end of the optimization.
self.error_if_failure_rate_exceeded(force_check=True)
self._record_run_trials_status(
num_preexisting_trials=num_preexisting_trials,
status=RunTrialsStatus.SUCCESS,
)
self.warn_if_non_terminal_trials()
return res

Expand Down Expand Up @@ -1979,29 +1934,6 @@ def _enforce_immutable_search_space_and_opt_config(self) -> None:
True
)

def _initialize_experiment_status_properties(self) -> None:
    """Make sure every status-tracking key exists in the experiment properties.

    Each ``ExperimentStatusProperties`` key that is not yet present is given
    an empty list; keys that already exist are left untouched. The lists are
    appended to in ``run_trials_and_yield_results``.
    """
    properties = self.experiment._properties
    for status_key in ExperimentStatusProperties:
        # Members mix in ``str``, so a lookup by ``.value`` matches the same
        # entries as a lookup by the member itself.
        properties.setdefault(status_key.value, [])

def _record_run_trials_status(
    self, num_preexisting_trials: Optional[int], status: RunTrialsStatus
) -> None:
    """Append the status of a ``Scheduler.run_trials_and_yield_results`` call
    to this experiment's properties, for monitoring of experiment success.

    Args:
        num_preexisting_trials: Number of trials that existed before the
            call. When given, the number of trials added since then is
            recorded as well; when ``None``, only the status is recorded.
        status: Status value to record for this call.
    """
    status_key = ExperimentStatusProperties.RUN_TRIALS_STATUS.value
    updates: dict[str, Any] = {status_key: status.value}
    if num_preexisting_trials is not None:
        count_key = ExperimentStatusProperties.NUM_TRIALS_RUN_PER_CALL.value
        # Trials currently on the experiment minus those present beforehand.
        updates[count_key] = len(self.experiment.trials) - num_preexisting_trials
    self._append_to_experiment_properties(to_append=updates)

def _record_optimization_complete_message(self) -> None:
"""Adds a simple optimization completion message to this scheduler's markdown
messages.
Expand All @@ -2017,15 +1949,14 @@ def _record_optimization_complete_message(self) -> None:
else:
self.markdown_messages["Optimization complete"] = completion_msg

def _append_to_experiment_properties(self, to_append: dict[str, Any]) -> None:
"""Appends to list fields in experiment properties based on ``to_append``
input dict of form {property_name: value_to_append}.
def _record_experiment_resumption_from_storage(self) -> None:
"""Adds a timestamp for resumption-from-storage, to the experiment properties.
Useful for debugging purposes and for keeping track of resumption events.
"""
for prop, val_to_append in to_append.items():
if prop in self.experiment._properties:
self.experiment._properties[prop].append(val_to_append)
else:
self.experiment._properties[prop] = [val_to_append]
resumption_timestamps = self.experiment._properties.setdefault(
Keys.RESUMED_FROM_STORAGE_TS.value, []
)
resumption_timestamps.append(datetime.strftime(datetime.now(), TS_FMT))
self._update_experiment_properties_in_db(
experiment_with_updated_properties=self.experiment
)
Expand Down
38 changes: 7 additions & 31 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data

from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.metric import Metric
Expand All @@ -43,7 +42,6 @@
from ax.runners.single_running_trial_mixin import SingleRunningTrialMixin
from ax.runners.synthetic import SyntheticRunner
from ax.service.scheduler import (
ExperimentStatusProperties,
FailureRateExceededError,
get_fitted_model_bridge,
OptimizationResult,
Expand All @@ -62,6 +60,7 @@
from ax.storage.sqa_store.save import save_experiment
from ax.storage.sqa_store.sqa_config import SQAConfig
from ax.storage.sqa_store.structs import DBSettings

from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
from ax.utils.common.timeutils import current_timestamp_in_millis
Expand Down Expand Up @@ -471,8 +470,6 @@ def test_init_with_branin_experiment(self) -> None:
self.assertEqual(scheduler.options.tolerated_trial_failure_rate, 0.2)
self.assertEqual(scheduler.options.init_seconds_between_polls, 10)
self.assertIsNone(scheduler._latest_optimization_start_timestamp)
for status_prop in ExperimentStatusProperties:
self.assertEqual(scheduler.experiment._properties[status_prop.value], [])
scheduler.run_all_trials() # Runs no trials since total trials is 0.
# `_latest_optimization_start_timestamp` should be set now.
self.assertLessEqual(
Expand Down Expand Up @@ -589,23 +586,9 @@ def test_run_all_trials_using_runner_and_metrics(self) -> None:
# experiment.
dat = scheduler.experiment.fetch_data().df
self.assertEqual(set(dat["trial_index"].values), set(range(8)))
self.assertEqual(
scheduler.experiment._properties[
ExperimentStatusProperties.RUN_TRIALS_STATUS
],
["started", "success"],
)
self.assertEqual(
scheduler.experiment._properties[
ExperimentStatusProperties.NUM_TRIALS_RUN_PER_CALL
],
[8],
)
self.assertEqual(
scheduler.experiment._properties[
ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS
],
[],
self.assertNotIn(
Keys.RESUMED_FROM_STORAGE_TS.value,
scheduler.experiment._properties,
)

def test_run_all_trials_callback(self) -> None:
Expand Down Expand Up @@ -958,7 +941,6 @@ def test_timeout(self) -> None:
)
scheduler.run_all_trials(timeout_hours=0) # Forcing optimization to time out.
self.assertEqual(len(scheduler.experiment.trials), 0)
self.assertIn("aborted", scheduler.experiment._properties["run_trials_success"])

def test_logging(self) -> None:
gs = self._get_generation_strategy_strategy_for_test(
Expand Down Expand Up @@ -1234,22 +1216,16 @@ def test_sqa_storage_with_experiment_name(self) -> None:
)
# Hack "resumed from storage timestamp" into `exp` to make sure all other fields
# are equal, since difference in resumed from storage timestamps is expected.
exp._properties[ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS] = (
new_scheduler.experiment._properties[
ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS
]
exp._properties[Keys.RESUMED_FROM_STORAGE_TS] = (
new_scheduler.experiment._properties[Keys.RESUMED_FROM_STORAGE_TS]
)
self.assertEqual(new_scheduler.experiment, exp)
self.assertLessEqual(
len(gs._generator_runs),
len(new_scheduler.generation_strategy._generator_runs),
)
self.assertEqual(
len(
new_scheduler.experiment._properties[
ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS
]
),
len(new_scheduler.experiment._properties[Keys.RESUMED_FROM_STORAGE_TS]),
1,
)

Expand Down
25 changes: 15 additions & 10 deletions ax/utils/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,19 @@

from enum import Enum, unique

# ------------------------- Miscellaneous -------------------------

# -------------------------- Warnings --------------------------

TS_FMT = "%Y-%m-%d %H:%M:%S.%f"

DEFAULT_WINSORIZATION_LIMITS_MINIMIZATION: tuple[float, float] = (0.0, 0.2)
DEFAULT_WINSORIZATION_LIMITS_MAXIMIZATION: tuple[float, float] = (0.2, 0.0)

TESTENV_ENV_KEY = "TESTENV"
TESTENV_ENV_VAL = "True"


# --------------------------- Warnings ---------------------------


EXPERIMENT_IS_TEST_WARNING = (
Expand All @@ -19,7 +30,7 @@
)


# -------------------------- Error messages --------------------------
# ------------------------ Error messages ------------------------


UNEXPECTED_METRIC_COMBINATION = """\
Expand All @@ -28,7 +39,7 @@
"""


# --------------------------- Reserved keys ---------------------------
# ------------------------- Reserved keys -------------------------


@unique
Expand Down Expand Up @@ -72,6 +83,7 @@ class Keys(str, Enum):
QMC = "qmc"
RAW_INNER_SAMPLES = "raw_inner_samples"
RAW_SAMPLES = "raw_samples"
RESUMED_FROM_STORAGE_TS = "resumed_from_storage_timestamps"
SAMPLER = "sampler"
SEED_INNER = "seed_inner"
SEQUENTIAL = "sequential"
Expand All @@ -82,10 +94,3 @@ class Keys(str, Enum):
WARM_START_REFITTING = "warm_start_refitting"
X_BASELINE = "X_baseline"
PAIRWISE_PREFERENCE_QUERY = "pairwise_pref_query"


DEFAULT_WINSORIZATION_LIMITS_MINIMIZATION: tuple[float, float] = (0.0, 0.2)
DEFAULT_WINSORIZATION_LIMITS_MAXIMIZATION: tuple[float, float] = (0.2, 0.0)

TESTENV_ENV_KEY = "TESTENV"
TESTENV_ENV_VAL = "True"

0 comments on commit eaeefc6

Please sign in to comment.