Skip to content

Commit

Permalink
Fix Pyre errors in sqa_classes (#2659)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2659

Use `annotations` to properly annotate Ax SQA class `Column` attributes

Reviewed By: danielcohenlive

Differential Revision: D60612786

fbshipit-source-id: 51f38bafea1bb9f921fe4c49ea6b9b198f20ad20
  • Loading branch information
Cesar-Cardoso authored and facebook-github-bot committed Aug 19, 2024
1 parent 885288e commit 8f33bcc
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 346 deletions.
38 changes: 23 additions & 15 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,8 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter:
parameter = RangeParameter(
name=parameter_sqa.name,
parameter_type=parameter_sqa.parameter_type,
lower=parameter_sqa.lower,
upper=parameter_sqa.upper,
lower=float(not_none(parameter_sqa.lower)),
upper=float(not_none(parameter_sqa.upper)),
log_scale=parameter_sqa.log_scale or False,
digits=parameter_sqa.digits,
is_fidelity=parameter_sqa.is_fidelity or False,
Expand Down Expand Up @@ -427,7 +427,7 @@ def parameter_constraint_from_sqa(
)
a = a_values[0]
is_upper_bound = a == 1
bound = parameter_constraint_sqa.bound * a
bound = float(parameter_constraint_sqa.bound) * a
constraint = SumConstraint(
parameters=constraint_parameters,
is_upper_bound=is_upper_bound,
Expand All @@ -436,7 +436,7 @@ def parameter_constraint_from_sqa(
else:
constraint = ParameterConstraint(
constraint_dict=dict(parameter_constraint_sqa.constraint_dict),
bound=parameter_constraint_sqa.bound,
bound=float(parameter_constraint_sqa.bound),
)

constraint.db_id = parameter_constraint_sqa.id
Expand Down Expand Up @@ -476,8 +476,8 @@ def environmental_variable_from_sqa(self, parameter_sqa: SQAParameter) -> Parame
parameter = RangeParameter(
name=parameter_sqa.name,
parameter_type=parameter_sqa.parameter_type,
lower=parameter_sqa.lower,
upper=parameter_sqa.upper,
lower=float(not_none(parameter_sqa.lower)),
upper=float(not_none(parameter_sqa.upper)),
log_scale=parameter_sqa.log_scale or False,
digits=parameter_sqa.digits,
is_fidelity=parameter_sqa.is_fidelity or False,
Expand Down Expand Up @@ -710,8 +710,16 @@ def generator_run_from_sqa(
weights=weights,
optimization_config=opt_config,
search_space=search_space,
fit_time=generator_run_sqa.fit_time,
gen_time=generator_run_sqa.gen_time,
fit_time=(
None
if generator_run_sqa.fit_time is None
else float(generator_run_sqa.fit_time)
),
gen_time=(
None
if generator_run_sqa.gen_time is None
else float(generator_run_sqa.gen_time)
),
best_arm_predictions=best_arm_predictions, # pyre-ignore[6]
# pyre-fixme[6]: Expected `Optional[Tuple[typing.Dict[str, List[float]],
# typing.Dict[str, typing.Dict[str, List[float]]]]]` for 8th param but got
Expand Down Expand Up @@ -900,7 +908,7 @@ def trial_from_sqa(
reduced_state=reduced_state,
immutable_search_space_and_opt_config=immutable_ss_and_oc,
),
weight=generator_run_sqa.weight or 1.0,
weight=float(generator_run_sqa.weight or 1.0),
)
for generator_run_sqa in trial_sqa.generator_runs
]
Expand Down Expand Up @@ -1169,9 +1177,9 @@ def _outcome_constraint_from_sqa(
)
return OutcomeConstraint(
metric=metric,
bound=metric_sqa.bound,
op=metric_sqa.op,
relative=metric_sqa.relative,
bound=float(not_none(metric_sqa.bound)),
op=not_none(metric_sqa.op),
relative=not_none(metric_sqa.relative),
)

def _scalarized_outcome_constraint_from_sqa(
Expand Down Expand Up @@ -1215,7 +1223,7 @@ def _scalarized_outcome_constraint_from_sqa(
scalarized_outcome_constraint = ScalarizedOutcomeConstraint(
metrics=list(metrics),
weights=list(weights),
bound=not_none(metric_sqa.bound),
bound=float(not_none(metric_sqa.bound)),
op=not_none(metric_sqa.op),
relative=not_none(metric_sqa.relative),
)
Expand All @@ -1232,8 +1240,8 @@ def _objective_threshold_from_sqa(
)
ot = ObjectiveThreshold(
metric=metric,
bound=metric_sqa.bound,
relative=metric_sqa.relative,
bound=float(not_none(metric_sqa.bound)),
relative=not_none(metric_sqa.relative),
op=metric_sqa.op,
)
# ObjectiveThreshold constructor clones the passed-in metric, which means
Expand Down
2 changes: 1 addition & 1 deletion ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _get_trials_sqa(
mini_batch_db_ids = trial_db_ids[batch_size * i : batch_size * (i + 1)]
with session_scope() as session:
query = session.query(trial_sqa_class).filter(
trial_sqa_class.id.in_(mini_batch_db_ids) # pyre-ignore[16]
trial_sqa_class.id.in_(mini_batch_db_ids)
)

if trials_query_options is not None:
Expand Down
14 changes: 5 additions & 9 deletions ax/storage/sqa_store/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,9 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial:
# saved data if it's no longer on the experiment.
with session_scope() as session:
session.query(data_sqa_class).filter_by(experiment_id=experiment_id).filter(
data_sqa_class.trial_index.isnot(None) # pyre-ignore[16]
).filter(
data_sqa_class.trial_index.in_(trial_idcs) # pyre-ignore[16]
).filter(
data_sqa_class.id.not_in(datas_to_keep) # pyre-ignore[16]
data_sqa_class.trial_index.isnot(None)
).filter(data_sqa_class.trial_index.in_(trial_idcs)).filter(
data_sqa_class.id not in datas_to_keep
).delete()

_bulk_merge_into_session(
Expand Down Expand Up @@ -422,7 +420,7 @@ def update_outcome_constraint_on_experiment(
) -> None:
oc_sqa_class = encoder.config.class_to_sqa_class[Metric]

exp_id = experiment.db_id
exp_id: Optional[int] = experiment.db_id
if exp_id is None:
raise UserInputError("Experiment must be saved before being updated.")
oc_id = outcome_constraint.db_id
Expand All @@ -432,9 +430,7 @@ def update_outcome_constraint_on_experiment(
id=oc_id
).delete()

# pyre-fixme[53]: Captured variable `exp_id` is not annotated.
# pyre-fixme[3]: Return type must be annotated.
def add_experiment_id(sqa: SQAMetric):
def add_experiment_id(sqa: SQAMetric) -> None:
sqa.experiment_id = exp_id

encode_func = (
Expand Down
Loading

0 comments on commit 8f33bcc

Please sign in to comment.