Skip to content

Commit

Permalink
Merge pull request #2 from nrbnlulu/support-deffered-types
Browse files Browse the repository at this point in the history
Support deffered types
  • Loading branch information
ThirVondukr committed Jan 27, 2024
2 parents cae83ae + 430c0f6 commit 2654a4e
Show file tree
Hide file tree
Showing 11 changed files with 189 additions and 68 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: Test
on: [push]
on: [push, pull_request]
jobs:
test:
runs-on: ubuntu-latest
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,7 @@ coverage.xml
# virtualenv
.venv
venv*/


# python cached files
*.py[cod]
24 changes: 14 additions & 10 deletions aioinject/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,27 @@ def __init__(self) -> None:
self._sync_exit_stack = contextlib.ExitStack()

def get(self, provider: Provider[T]) -> T | Literal[NotInCache.sentinel]:
return self._cache.get(provider.type_, NotInCache.sentinel)
return self._cache.get(provider.resolve_type(), NotInCache.sentinel)

def add(self, provider: Provider[T], obj: T) -> None:
if provider.lifetime is not DependencyLifetime.transient:
self._cache[provider.type_] = obj
self._cache[provider.resolve_type()] = obj

def lock(
self,
provider: Provider[Any],
) -> AbstractAsyncContextManager[bool]:
return contextlib.nullcontext(provider.type_ not in self._cache)
return contextlib.nullcontext(
provider.resolve_type() not in self._cache,
)

def sync_lock(
self,
provider: Provider[Any],
) -> AbstractContextManager[bool]:
return contextlib.nullcontext(provider.type_ not in self._cache)
return contextlib.nullcontext(
provider.resolve_type() not in self._cache,
)

@typing.overload
async def enter_context(
Expand Down Expand Up @@ -124,9 +128,9 @@ def __init__(self) -> None:

@contextlib.asynccontextmanager
async def lock(self, provider: Provider[Any]) -> AsyncIterator[bool]:
if provider.type_ not in self._cache:
async with self._locks[provider.type_]:
yield provider.type_ not in self._cache
if provider.resolve_type() not in self._cache:
async with self._locks[provider.resolve_type()]:
yield provider.resolve_type() not in self._cache
return
yield False

Expand All @@ -135,8 +139,8 @@ def sync_lock(
self,
provider: Provider[Any],
) -> Iterator[bool]:
if provider.type_ not in self._cache:
with self._sync_locks[provider.type_]:
yield provider.type_ not in self._cache
if provider.resolve_type() not in self._cache:
with self._sync_locks[provider.resolve_type()]:
yield provider.resolve_type() not in self._cache
return
yield False
35 changes: 26 additions & 9 deletions aioinject/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,33 @@ class Container:
def __init__(self) -> None:
self.providers: _Providers[Any] = {}
self._singletons = SingletonStore()
self._unresolved_providers: list[Provider[Any]] = []
self.type_context: dict[str, type[Any]] = {}

def _resolve_unresolved_provider(self) -> None:
for provider in self._unresolved_providers:
with contextlib.suppress(NameError):
self._register_impl(provider)
self._unresolved_providers.remove(provider)

def _register_impl(self, provider: Provider[Any]) -> None:
provider_type = provider.resolve_type(self.type_context)
if provider_type in self.providers:
msg = f"Provider for type {provider_type} is already registered"
raise ValueError(msg)
self.providers[provider_type] = provider
if klass_name := getattr(provider_type, "__name__", None):
self.type_context[klass_name] = provider_type

def register(
self,
provider: Provider[Any],
) -> None:
if provider.type_ in self.providers:
msg = f"Provider for type {provider.type_} is already registered"
raise ValueError(msg)

self.providers[provider.type_] = provider
try:
self._register_impl(provider)
except NameError:
self._unresolved_providers.append(provider)
self._resolve_unresolved_provider()

def get_provider(self, type_: type[_T]) -> Provider[_T]:
try:
Expand All @@ -53,14 +70,14 @@ def override(
self,
provider: Provider[Any],
) -> Iterator[None]:
previous = self.providers.get(provider.type_)
self.providers[provider.type_] = provider
previous = self.providers.get(provider.resolve_type())
self.providers[provider.resolve_type()] = provider

yield

del self.providers[provider.type_]
del self.providers[provider.resolve_type()]
if previous is not None:
self.providers[provider.type_] = previous
self.providers[provider.resolve_type()] = previous

async def __aenter__(self) -> Self:
return self
Expand Down
6 changes: 4 additions & 2 deletions aioinject/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ async def resolve(
return cached

dependencies = {}
for dependency in provider.dependencies:
for dependency in provider.resolve_dependencies(
self._container.type_context,
):
dependencies[dependency.name] = await self.resolve(
type_=dependency.type_,
)
Expand Down Expand Up @@ -151,7 +153,7 @@ def resolve(
return cached

dependencies = {}
for dependency in provider.dependencies:
for dependency in provider.resolve_dependencies():
dependencies[dependency.name] = self.resolve(
type_=dependency.type_,
)
Expand Down
84 changes: 65 additions & 19 deletions aioinject/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from aioinject.markers import Inject
from aioinject.utils import (
_get_type_hints,
is_context_manager_function,
remove_annotation,
)
Expand Down Expand Up @@ -62,13 +63,13 @@ def _find_inject_marker_in_annotation_args(

def collect_dependencies(
dependant: typing.Callable[..., object] | dict[str, Any],
ctx: dict[str, type[Any]] | None = None,
) -> typing.Iterable[Dependency[object]]:
if not isinstance(dependant, dict):
with remove_annotation(dependant.__annotations__, "return"):
type_hints = typing.get_type_hints(dependant, include_extras=True)
type_hints = _get_type_hints(dependant, context=ctx)
else:
type_hints = dependant

for name, hint in type_hints.items():
dep_type, args = _get_annotation_args(hint)
inject_marker = _find_inject_marker_in_annotation_args(args)
Expand All @@ -81,15 +82,18 @@ def collect_dependencies(
)


def _get_provider_type_hints(provider: Provider[Any]) -> dict[str, Any]:
def _get_provider_type_hints(
provider: Provider[Any],
context: dict[str, Any] | None = None,
) -> dict[str, Any]:
source = provider.impl
if inspect.isclass(source):
source = source.__init__

if isinstance(source, functools.partial):
return {}

type_hints = typing.get_type_hints(source, include_extras=True)
type_hints = _get_type_hints(source, context=context)
for key, value in type_hints.items():
_, args = _get_annotation_args(value)
for arg in args:
Expand Down Expand Up @@ -120,11 +124,14 @@ def _get_provider_type_hints(provider: Provider[Any]) -> dict[str, Any]:
)


def _guess_return_type(factory: _FactoryType[_T]) -> type[_T]:
def _guess_return_type(
factory: _FactoryType[_T],
context: dict[str, type[Any]] | None = None,
) -> type[_T]:
if isclass(factory):
return typing.cast(type[_T], factory)

type_hints = typing.get_type_hints(factory)
type_hints = _get_type_hints(factory, context=context)
try:
return_type = type_hints["return"]
except KeyError as e:
Expand Down Expand Up @@ -159,34 +166,59 @@ class DependencyLifetime(enum.Enum):

@runtime_checkable
class Provider(Protocol[_T]):
type_: type[_T]
impl: Any
lifetime: DependencyLifetime
_cached_dependencies: tuple[Dependency[object], ...]
_cached_type: type[_T] # TODO: I think it is redundant.

async def provide(self, kwargs: Mapping[str, Any]) -> _T:
...

def provide_sync(self, kwargs: Mapping[str, Any]) -> _T:
...

@property
def type_hints(self) -> dict[str, Any]:
def _resolve_type_impl(
self,
context: dict[str, Any] | None = None,
) -> type[_T]:
...

def resolve_type(self, context: dict[str, Any] | None = None) -> type[_T]:
try:
return self._cached_type
except AttributeError:
self._cached_type = self._resolve_type_impl(context)
return self._cached_type

def type_hints(self, context: dict[str, Any] | None) -> dict[str, Any]:
...

@property
def is_async(self) -> bool:
...

@functools.cached_property
def dependencies(self) -> tuple[Dependency[object], ...]:
return tuple(collect_dependencies(self.type_hints))
def resolve_dependencies(
self,
context: dict[str, Any] | None = None,
) -> tuple[Dependency[object], ...]:
try:
return self._cached_dependencies
except AttributeError:
self._cached_dependencies = tuple(
collect_dependencies(self.type_hints(context), ctx=context),
)
return self._cached_dependencies

@functools.cached_property
def is_generator(self) -> bool:
return is_context_manager_function(self.impl)

def __repr__(self) -> str:
return f"{self.__class__.__qualname__}(type={self.type_}, implementation={self.impl})"
def __repr__(self) -> str: # pragma: no cover
try:
type_ = repr(self.resolve_type())
except NameError:
type_ = "UNKNOWN"
return f"{self.__class__.__qualname__}(type={type_}, implementation={self.impl})"


class Scoped(Provider[_T]):
Expand All @@ -197,9 +229,15 @@ def __init__(
factory: _FactoryType[_T],
type_: type[_T] | None = None,
) -> None:
self.type_ = type_ or _guess_return_type(factory)
self.type_ = type_
self.impl = factory

def _resolve_type_impl(
self,
context: dict[str, Any] | None = None,
) -> type[_T]:
return self.type_ or _guess_return_type(self.impl, context=context)

def provide_sync(self, kwargs: Mapping[str, Any]) -> _T:
return self.impl(**kwargs) # type: ignore[return-value]

Expand All @@ -209,9 +247,11 @@ async def provide(self, kwargs: Mapping[str, Any]) -> _T:

return self.provide_sync(kwargs)

@functools.cached_property
def type_hints(self) -> dict[str, Any]:
type_hints = _get_provider_type_hints(self)
def type_hints(
self,
context: dict[str, Any] | None = None,
) -> dict[str, Any]:
type_hints = _get_provider_type_hints(self, context=context)
if "return" in type_hints:
del type_hints["return"]
return type_hints
Expand Down Expand Up @@ -244,7 +284,7 @@ class Transient(Scoped[_T]):


class Object(Provider[_T]):
type_hints: ClassVar[dict[str, Any]] = {}
_type_hints: ClassVar[dict[str, Any]] = {}
is_async = False
impl: _T
lifetime = DependencyLifetime.scoped # It's ok to cache it
Expand All @@ -257,8 +297,14 @@ def __init__(
self.type_ = type_ or type(object_)
self.impl = object_

def _resolve_type_impl(self, _: dict[str, Any] | None = None) -> type[_T]:
return self.type_

def provide_sync(self, kwargs: Mapping[str, Any]) -> _T: # noqa: ARG002
return self.impl

async def provide(self, kwargs: Mapping[str, Any]) -> _T: # noqa: ARG002
return self.impl

def type_hints(self, _: dict[str, Any] | None = None) -> dict[str, Any]:
return self._type_hints
9 changes: 9 additions & 0 deletions aioinject/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,12 @@ def remove_annotation(
yield
if annotation is not sentinel:
annotations[name] = annotation


def _get_type_hints(
obj: Any,
context: dict[str, type[Any]] | None = None,
) -> dict[str, Any]:
if not context:
context = {}
return typing.get_type_hints(obj, include_extras=True, localns=context)
18 changes: 12 additions & 6 deletions aioinject/validation/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ def all_dependencies_are_present(
) -> Sequence[ContainerValidationError]:
errors = []
for provider in container.providers.values():
for dependency in provider.dependencies:
if dependency.type_ not in container.providers:
for dependency in provider.resolve_dependencies(
container.type_context,
):
dep_type = dependency.type_
if dep_type not in container.providers:
error = DependencyNotFoundError(
message=f"Provider for type {dependency.type_} not found",
dependency=dependency.type_,
message=f"Provider for type {dep_type} not found",
dependency=dep_type,
)
errors.append(error)

Expand All @@ -44,9 +47,12 @@ def __call__(
if not self.dependant(provider):
continue

for dependency in provider.dependencies:
for dependency in provider.resolve_dependencies(
container.type_context,
):
dep_type = dependency.type_
dependency_provider = container.get_provider(
type_=dependency.type_,
type_=dep_type,
)
if self.dependency(dependency_provider):
msg = f"Provider {provider!r} cannot depend on {dependency_provider!r}"
Expand Down
Loading

0 comments on commit 2654a4e

Please sign in to comment.