Skip to content

Commit

Permalink
Add support for generics
Browse files Browse the repository at this point in the history
  • Loading branch information
ThirVondukr committed Mar 20, 2024
1 parent 46ab58c commit 713b5ca
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 3 deletions.
30 changes: 27 additions & 3 deletions aioinject/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,33 @@ def collect_dependencies(
)


def _typevar_map(
source: type[Any],
) -> tuple[type, Mapping[object, object]]:
origin = typing.get_origin(source)
if not isclass(source) and not origin:
return source, {}

resolved_source = origin or source
typevar_map: dict[object, object] = {}
for base in (source, *getattr(source, "__orig_bases__", [])):
origin = typing.get_origin(base)
if not origin:
continue

params = origin.__parameters__
args = typing.get_args(base)
typevar_map |= dict(zip(params, args, strict=True))

return resolved_source, typevar_map


def _get_provider_type_hints(
provider: Provider[Any],
context: dict[str, Any] | None = None,
) -> dict[str, Any]:
source = provider.impl
source, typevar_map = _typevar_map(source=provider.impl)

if inspect.isclass(source):
source = source.__init__

Expand All @@ -100,7 +122,7 @@ def _get_provider_type_hints(
if isinstance(arg, Inject):
break
else:
type_hints[key] = Annotated[value, Inject]
type_hints[key] = Annotated[typevar_map.get(value, value), Inject]

return type_hints

Expand All @@ -125,7 +147,9 @@ def _get_provider_type_hints(


def _guess_return_type(factory: _FactoryType[_T]) -> type[_T]:
if isclass(factory):
origin = typing.get_origin(factory)
is_generic = origin and isclass(origin)
if isclass(factory) or is_generic:
return typing.cast(type[_T], factory)

type_hints = _get_type_hints(factory)
Expand Down
Empty file added tests/features/__init__.py
Empty file.
53 changes: 53 additions & 0 deletions tests/features/test_generics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Generic, TypeVar

from aioinject import Container, Object, Scoped
from aioinject.providers import Dependency


T = TypeVar("T")


class GenericService(Generic[T]):
def __init__(self, dependency: str) -> None:
self.dependency = dependency


class WithGenericDependency(Generic[T]):
def __init__(self, dependency: T) -> None:
self.dependency = dependency


class ConstrainedGenericDependency(WithGenericDependency[int]):
pass


async def test_generic_dependency() -> None:
assert Scoped(GenericService[int]).resolve_dependencies() == (
Dependency(
name="dependency",
type_=str,
),
)

assert Scoped(WithGenericDependency[int]).resolve_dependencies() == (
Dependency(
name="dependency",
type_=int,
),
)
assert Scoped(ConstrainedGenericDependency).resolve_dependencies() == (
Dependency(
name="dependency",
type_=int,
),
)


async def test_resolve_generics() -> None:
container = Container()
container.register(Scoped(WithGenericDependency[int]))
container.register(Object(42))

async with container.context() as ctx:
instance = await ctx.resolve(WithGenericDependency[int])
assert isinstance(instance, WithGenericDependency)

0 comments on commit 713b5ca

Please sign in to comment.