Skip to content

Commit

Permalink
Remove validate_kwarg_typing and typeguard dependence (#2673)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2673

Proximate motivation: test_validate_kwarg_typing is failing on Github Actions but cannot be reproduced on internal setup

Reviewed By: esantorella

Differential Revision: D61489346

fbshipit-source-id: 753ea5e8ff0186c6c36ba053e661d61c8c400ac2
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Aug 19, 2024
1 parent 8f33bcc commit e157b1f
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 176 deletions.
33 changes: 26 additions & 7 deletions ax/modelbridge/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
consolidate_kwargs,
get_function_argument_names,
get_function_default_arguments,
validate_kwarg_typing,
)
from ax.utils.common.logger import get_logger
from ax.utils.common.serialization import callable_from_reference, callable_to_reference
Expand Down Expand Up @@ -289,14 +288,34 @@ def __call__(
model_setup_info = MODEL_KEY_TO_MODEL_SETUP[self.value]
model_class = model_setup_info.model_class
bridge_class = model_setup_info.bridge_class

if not silently_filter_kwargs:
validate_kwarg_typing(
typed_callables=[model_class, bridge_class],
search_space=search_space,
experiment=experiment,
data=data,
# Check correct kwargs are present
callables = (model_class, bridge_class)
kwargs_to_check = {
"search_space": search_space,
"experiment": experiment,
"data": data,
**kwargs,
)
}
checked_kwargs = set()
for fn in callables:
params = signature(fn).parameters
for kw in params.keys():
if kw in kwargs_to_check:
if kw in checked_kwargs:
logger.debug(
f"`{callables}` have duplicate keyword argument: {kw}."
)
else:
checked_kwargs.add(kw)

# Check if kwargs contains keywords not exist in any callables
extra_keywords = [kw for kw in kwargs.keys() if kw not in checked_kwargs]
if len(extra_keywords) != 0:
raise ValueError(
f"Arguments {extra_keywords} are not expected by any of {callables}"
)

# Create model with consolidated arguments: defaults + passed in kwargs.
model_kwargs = consolidate_kwargs(
Expand Down
42 changes: 0 additions & 42 deletions ax/utils/common/kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from typing import Any, Callable, Optional

from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils_nonnative import version_safe_check_type

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -61,47 +60,6 @@ def get_function_default_arguments(function: Callable) -> dict[str, Any]:
}


# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def validate_kwarg_typing(typed_callables: list[Callable], **kwargs: Any) -> None:
"""Check if keywords in kwargs exist in any of the typed_callables and
if the type of each keyword value matches the type of corresponding arg in one of
the callables
Note: this function expects the typed callables to have unique keywords for
the arguments and will raise an error if repeat keywords are found.
"""
checked_kwargs = set()
for typed_callable in typed_callables:
params = signature(typed_callable).parameters
for kw, param in params.items():
if kw in kwargs:
if kw in checked_kwargs:
logger.debug(
f"`{typed_callables}` have duplicate keyword argument: {kw}."
)
else:
checked_kwargs.add(kw)
kw_val = kwargs.get(kw)
# if the keyword is a callable, we only do shallow checks
if not (callable(kw_val) and callable(param.annotation)):
try:
version_safe_check_type(kw, kw_val, param.annotation)
except TypeError:
message = (
f"`{typed_callable}` expected argument `{kw}` to be of"
f" type {param.annotation}. Got {kw_val}"
f" (type: {type(kw_val)})."
)
logger.warning(message)

# check if kwargs contains keywords not exist in any callables
extra_keywords = [kw for kw in kwargs.keys() if kw not in checked_kwargs]
if len(extra_keywords) != 0:
raise ValueError(
f"Arguments {extra_keywords} are not expected by any of {typed_callables}."
)


# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def warn_on_kwargs(callable_with_kwargs: Callable, **kwargs: Any) -> None:
"""Log a warning when a decoder function receives unexpected kwargs.
Expand Down
104 changes: 1 addition & 103 deletions ax/utils/common/tests/test_kwargutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,117 +8,15 @@


from logging import Logger
from typing import Callable, Optional
from unittest.mock import patch

from ax.utils.common.kwargs import validate_kwarg_typing, warn_on_kwargs
from ax.utils.common.kwargs import warn_on_kwargs
from ax.utils.common.logger import get_logger
from ax.utils.common.testutils import TestCase

logger: Logger = get_logger("ax.utils.common.kwargs")


class TestKwargUtils(TestCase):
def test_validate_kwarg_typing(self) -> None:
def typed_callable(arg1: int, arg2: Optional[str] = None) -> None:
pass

def typed_callable_with_dict(arg3: int, arg4: dict[str, int]) -> None:
pass

def typed_callable_valid(arg3: int, arg4: Optional[str] = None) -> None:
pass

def typed_callable_dup_keyword(arg2: int, arg4: Optional[str] = None) -> None:
pass

def typed_callable_with_callable(
arg1: int, arg2: Callable[[int], dict[str, int]]
) -> None:
pass

def typed_callable_extra_arg(arg1: int, arg2: str, arg3: bool) -> None:
pass

# pass
try:
kwargs = {"arg1": 1, "arg2": "test", "arg3": 2}
validate_kwarg_typing([typed_callable, typed_callable_valid], **kwargs)
except Exception:
self.assertTrue(False, "Exception raised on valid kwargs")

# pass with complex data structure
try:
kwargs = {"arg1": 1, "arg2": "test", "arg3": 2, "arg4": {"k1": 1}}
validate_kwarg_typing([typed_callable, typed_callable_with_dict], **kwargs)
except Exception:
self.assertTrue(False, "Exception raised on valid kwargs")

# callable as arg (same arg count but diff type)
try:
kwargs = {"arg1": 1, "arg2": typed_callable}
validate_kwarg_typing([typed_callable_with_callable], **kwargs)
except Exception:
self.assertTrue(False, "Exception raised on valid kwargs")

# callable as arg (diff arg count)
try:
kwargs = {"arg1": 1, "arg2": typed_callable_extra_arg}
validate_kwarg_typing([typed_callable_with_callable], **kwargs)
except Exception:
self.assertTrue(False, "Exception raised on valid kwargs")

# kwargs contains extra keywords
with self.assertRaises(ValueError):
kwargs = {"arg1": 1, "arg2": "test", "arg3": 3, "arg5": 4}
typed_callables = [typed_callable, typed_callable_valid]
validate_kwarg_typing(typed_callables, **kwargs)

# callables have duplicate keywords
with patch.object(logger, "debug") as mock_debug:
kwargs = {"arg1": 1, "arg2": "test", "arg4": "test_again"}
typed_callables = [typed_callable, typed_callable_dup_keyword]
validate_kwarg_typing(typed_callables, **kwargs)
mock_debug.assert_called_once_with(
f"`{typed_callables}` have duplicate keyword argument: arg2."
)

# mismatch types
with patch.object(logger, "warning") as mock_warning:
kwargs = {"arg1": 1, "arg2": "test", "arg3": "test_again"}
typed_callables = [typed_callable, typed_callable_valid]
validate_kwarg_typing(typed_callables, **kwargs)
expected_message = (
f"`{typed_callable_valid}` expected argument `arg3` to be of type"
f" {type(1)}. Got test_again (type: {type('test_again')})."
)
mock_warning.assert_called_once_with(expected_message)

# mismatch types with Dict
with patch.object(logger, "warning") as mock_warning:
str_dic = {"k1": "test"}
kwargs = {"arg1": 1, "arg2": "test", "arg3": 2, "arg4": str_dic}
typed_callables = [typed_callable, typed_callable_with_dict]
validate_kwarg_typing(typed_callables, **kwargs)
expected_message = (
f"`{typed_callable_with_dict}` expected argument `arg4` to be of type"
f" dict[str, int]. Got {str_dic} (type: {type(str_dic)})."
)
mock_warning.assert_called_once_with(expected_message)

# mismatch types with callable as arg
with patch.object(logger, "warning") as mock_warning:
kwargs = {"arg1": 1, "arg2": "test_again"}
typed_callables = [typed_callable_with_callable]
validate_kwarg_typing(typed_callables, **kwargs)
expected_message = (
f"`{typed_callable_with_callable}` expected argument `arg2` to be of"
f" type typing.Callable[[int], dict[str, int]]. "
f"Got test_again (type: {type('test_again')})."
)
mock_warning.assert_called_once_with(expected_message)


class TestWarnOnKwargs(TestCase):
def test_it_warns_if_kwargs_are_passed(self) -> None:
with patch.object(logger, "warning") as mock_warning:
Expand Down
24 changes: 1 addition & 23 deletions ax/utils/common/typeutils_nonnative.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,9 @@

# pyre-strict

from inspect import signature
from typing import Any, TypeVar
from typing import Any

import numpy as np
from typeguard import check_type

T = TypeVar("T")
V = TypeVar("V")
K = TypeVar("K")
X = TypeVar("X")
Y = TypeVar("Y")


def version_safe_check_type(argname: str, value: T, expected_type: type[T]) -> None:
"""Excecute the check_type function if it has the expected signature, otherwise
warn. This is done to support newer versions of typeguard with minimal loss
of functionality for users that have dependency conflicts"""
# Get the signature of the check_type function
sig = signature(check_type)
# Get the parameters of the check_type function
params = sig.parameters
# Check if the check_type function has the expected signature
params = set(params.keys())
if all(arg in params for arg in ["argname", "value", "expected_type"]):
check_type(argname, value, expected_type)


# pyre-fixme[3]: Return annotation cannot be `Any`.
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
"ipywidgets",
# Needed for compatibility with ipywidgets >= 8.0.0
"plotly>=5.12.0",
"typeguard",
"pyre-extensions",
]

Expand Down

0 comments on commit e157b1f

Please sign in to comment.