Skip to content

Commit

Permalink
Correctly infer types on functions with multiple decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
ThirVondukr committed Mar 31, 2024
1 parent 15df7d4 commit b446776
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 9 deletions.
14 changes: 5 additions & 9 deletions aioinject/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,14 @@ def _get_provider_type_hints(


def _guess_return_type(factory: _FactoryType[_T]) -> type[_T]:
unwrapped = inspect.unwrap(factory)

origin = typing.get_origin(factory)
is_generic = origin and isclass(origin)
if isclass(factory) or is_generic:
return typing.cast(type[_T], factory)

type_hints = _get_type_hints(factory)
type_hints = _get_type_hints(unwrapped)
try:
return_type = type_hints["return"]
except KeyError as e:
Expand All @@ -162,18 +164,12 @@ def _guess_return_type(factory: _FactoryType[_T]) -> type[_T]:
if origin := typing.get_origin(return_type):
args = typing.get_args(return_type)

maybe_wrapped = getattr( # @functools.wraps
factory,
"__wrapped__",
factory,
)

is_async_gen = (
origin in _ASYNC_GENERATORS
and inspect.isasyncgenfunction(maybe_wrapped)
and inspect.isasyncgenfunction(unwrapped)
)
is_sync_gen = origin in _GENERATORS and inspect.isgeneratorfunction(
maybe_wrapped,
unwrapped,
)
if is_async_gen or is_sync_gen:
return_type = args[0]
Expand Down
File renamed without changes.
11 changes: 11 additions & 0 deletions tests/utils/test_guess_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from aioinject.providers import _guess_return_type
from tests.utils_ import dummy_decorator


def test_class() -> None:
Expand Down Expand Up @@ -64,3 +65,13 @@ async def iterable() -> return_type:

assert _guess_return_type(iterable) is int
assert _guess_return_type(contextlib.asynccontextmanager(iterable)) is int # type: ignore[comparison-overlap]


def test_decorated_function() -> None:
def func() -> int:
return 42

func = dummy_decorator(func)
func = dummy_decorator(func)

assert _guess_return_type(func) is int
15 changes: 15 additions & 0 deletions tests/utils_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import functools
from collections.abc import Callable
from typing import ParamSpec, TypeVar


T = TypeVar("T")
P = ParamSpec("P")


def dummy_decorator(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func)
def decorator(*args: P.args, **kwargs: P.kwargs) -> T:
return func(*args, **kwargs)

return decorator

0 comments on commit b446776

Please sign in to comment.