Skip to content

Commit

Permalink
Merge pull request #7 from trollfot/main
Browse files Browse the repository at this point in the history
Add sync on_resolve extension.
  • Loading branch information
ThirVondukr committed Mar 28, 2024
2 parents 1f41018 + 4f7e360 commit 9d1b78e
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 5 deletions.
7 changes: 5 additions & 2 deletions aioinject/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,14 @@ def context(
extensions=extensions,
)

def sync_context(self) -> SyncInjectionContext:
def sync_context(
self,
extensions: Sequence[ContextExtension] = (),
) -> SyncInjectionContext:
return SyncInjectionContext(
container=self,
singletons=self._singletons,
extensions=(),
extensions=extensions,
)

@contextlib.contextmanager
Expand Down
10 changes: 9 additions & 1 deletion aioinject/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

from aioinject._store import InstanceStore, NotInCache
from aioinject._types import AnyCtx, T
from aioinject.extensions import ContextExtension, OnResolveExtension
from aioinject.extensions import (
ContextExtension, OnResolveExtension, SyncOnResolveExtension
)
from aioinject.providers import Dependency, DependencyLifetime


Expand Down Expand Up @@ -193,6 +195,7 @@ def _resolve(
if provider.is_generator:
resolved = store.enter_sync_context(resolved)
store.add(provider, resolved)
self._on_resolve(provider=provider, instance=resolved)
return resolved

def execute(
Expand All @@ -209,6 +212,11 @@ def execute(
resolved[dependency.name] = self.resolve(type_=dependency.type_)
return function(*args, **kwargs, **resolved)

def _on_resolve(self, provider: Provider[T], instance: T) -> None:
for extension in self._extensions:
if isinstance(extension, SyncOnResolveExtension):
extension.on_resolve(self, provider, instance)

def __enter__(self) -> Self:
self._token = context_var.set(self)
return self
Expand Down
15 changes: 13 additions & 2 deletions aioinject/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
if TYPE_CHECKING:
from contextlib import AbstractAsyncContextManager

from aioinject import Container, InjectionContext, Provider
from aioinject import Container, Provider
from aioinject import InjectionContext, SyncInjectionContext
from aioinject._types import T


Expand Down Expand Up @@ -36,5 +37,15 @@ async def on_resolve(
) -> None: ...


@runtime_checkable
class SyncOnResolveExtension(Protocol):
def on_resolve(
self,
context: SyncInjectionContext,
provider: Provider[T],
instance: T,
) -> None: ...


Extension = LifespanExtension | OnInitExtension
ContextExtension = OnResolveExtension
ContextExtension = OnResolveExtension | SyncOnResolveExtension
53 changes: 53 additions & 0 deletions tests/extensions/test_on_resolve_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import collections

import pytest

from aioinject import (
Container,
InjectionContext,
Object,
Provider,
Scoped,
Singleton,
Transient,
)
from aioinject._types import T
from aioinject.extensions import SyncOnResolveExtension


class _TestExtension(SyncOnResolveExtension):
def __init__(self) -> None:
self.type_counter: dict[type[object], int] = collections.defaultdict(
int,
)

def on_resolve(
self,
context: InjectionContext, # noqa: ARG002
provider: Provider[T],
instance: T, # noqa: ARG002
) -> None:
self.type_counter[provider.type_] += 1


@pytest.mark.parametrize(
"provider",
[
Object(0),
Scoped(int),
Transient(int),
Singleton(int),
],
)
def test_on_resolve(provider: Provider[int]) -> None:
container = Container()
container.register(provider)

extension = _TestExtension()
with container.sync_context(extensions=(extension,)) as ctx:
for i in range(1, 10 + 1):
number = ctx.resolve(int)
assert number == 0
assert extension.type_counter[int] == (
i if isinstance(provider, Transient) else 1
)

0 comments on commit 9d1b78e

Please sign in to comment.