Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add extension protocol #4

Merged
merged 2 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions aioinject/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,14 @@ class NotInCache(enum.Enum):


class InstanceStore:
def __init__(self) -> None:
def __init__(
self,
exit_stack: contextlib.AsyncExitStack | None = None,
sync_exit_stack: contextlib.ExitStack | None = None,
) -> None:
self._cache: dict[type, Any] = {}
self._exit_stack = contextlib.AsyncExitStack()
self._sync_exit_stack = contextlib.ExitStack()
self._exit_stack = exit_stack or contextlib.AsyncExitStack()
self._sync_exit_stack = sync_exit_stack or contextlib.ExitStack()

def get(self, provider: Provider[T]) -> T | Literal[NotInCache.sentinel]:
return self._cache.get(provider.type_, NotInCache.sentinel)
Expand Down Expand Up @@ -109,8 +113,12 @@ def close(self) -> None:


class SingletonStore(InstanceStore):
def __init__(self) -> None:
super().__init__()
def __init__(
self,
exit_stack: contextlib.AsyncExitStack | None = None,
sync_exit_stack: contextlib.ExitStack | None = None,
) -> None:
super().__init__(exit_stack, sync_exit_stack)
self._locks: dict[type, asyncio.Lock] = collections.defaultdict(
asyncio.Lock,
)
Expand Down
23 changes: 19 additions & 4 deletions aioinject/containers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import contextlib
from collections.abc import Iterator
from collections.abc import Iterator, Sequence
from contextlib import AsyncExitStack
from types import TracebackType
from typing import Any, TypeAlias, TypeVar

from typing_extensions import Self

from aioinject._store import SingletonStore
from aioinject.context import InjectionContext, SyncInjectionContext
from aioinject.extensions import Extension, LifespanExtension, OnInitExtension
from aioinject.providers import Provider


Expand All @@ -15,11 +17,19 @@


class Container:
def __init__(self) -> None:
def __init__(self, extensions: Sequence[Extension] | None = None) -> None:
self._exit_stack = AsyncExitStack()
self._singletons = SingletonStore(exit_stack=self._exit_stack)

self.providers: _Providers[Any] = {}
self._singletons = SingletonStore()
self._unresolved_providers: list[Provider[Any]] = []
self.type_context: dict[str, type[Any]] = {}
self.extensions = extensions or []
self._init_extensions(self.extensions)

def _init_extensions(self, extensions: Sequence[Extension]) -> None:
for extension in extensions:
if isinstance(extension, OnInitExtension):
extension.on_init(self)

def register(
self,
Expand Down Expand Up @@ -68,6 +78,11 @@ def override(self, *providers: Provider[Any]) -> Iterator[None]:
self.providers[provider.type_] = prev

async def __aenter__(self) -> Self:
for extension in self.extensions:
if isinstance(extension, LifespanExtension):
await self._exit_stack.enter_async_context(
extension.lifespan(self),
)
return self

async def __aexit__(
Expand Down
27 changes: 27 additions & 0 deletions aioinject/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from __future__ import annotations

from contextlib import AbstractAsyncContextManager
from typing import TYPE_CHECKING, Protocol, runtime_checkable


if TYPE_CHECKING:
from aioinject import Container


@runtime_checkable
class LifespanExtension(Protocol):
def lifespan(
self,
container: Container,
) -> AbstractAsyncContextManager[None]: ...


@runtime_checkable
class OnInitExtension(Protocol):
def on_init(
self,
container: Container,
) -> None: ...


Extension = LifespanExtension | OnInitExtension
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import cast

import pytest
from _pytest.fixtures import SubRequest


pytest_plugins = [
"anyio",
]


@pytest.fixture(scope="session", autouse=True, params=["asyncio", "trio"])
def anyio_backend(request: SubRequest) -> str:
return cast(str, request.param)
2 changes: 0 additions & 2 deletions tests/container/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def test_missing_provider() -> None:
assert str(exc_info.value) == msg


@pytest.mark.anyio
async def test_should_close_singletons() -> None:
shutdown = False

Expand Down Expand Up @@ -112,7 +111,6 @@ def dependency() -> Iterator[int]:
assert shutdown is True


@pytest.mark.anyio
async def test_deffered_dependecies() -> None:
if TYPE_CHECKING:
from decimal import Decimal
Expand Down
3 changes: 0 additions & 3 deletions tests/context/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
from aioinject.markers import Inject


pytestmark = [pytest.mark.anyio]


class _TestError(Exception):
pass

Expand Down
3 changes: 0 additions & 3 deletions tests/context/test_context_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def test_should_not_use_resolved_class_as_context_manager() -> None:
resolved.mock.close.assert_not_called()


@pytest.mark.anyio
async def test_async_context_manager() -> None:
mock = MagicMock()

Expand All @@ -90,7 +89,6 @@ async def get_session() -> AsyncIterator[_Session]:
mock.close.assert_called_once()


@pytest.mark.anyio
async def test_async_context_would_use_sync_context_managers() -> None:
mock = MagicMock()

Expand All @@ -110,7 +108,6 @@ def get_session() -> Generator[_Session, None, None]:
mock.close.assert_called_once()


@pytest.mark.anyio
async def test_should_not_use_resolved_class_as_async_context_manager() -> (
None
):
Expand Down
6 changes: 0 additions & 6 deletions tests/context/test_execute.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import functools
from typing import Annotated

import pytest

import aioinject
from aioinject import Container, Inject
from aioinject.providers import collect_dependencies
Expand Down Expand Up @@ -40,7 +38,6 @@ def test_execute_sync_with_kwargs(container: Container) -> None:
assert isinstance(c, _C)


@pytest.mark.anyio
async def test_execute_async(container: Container) -> None:
dependencies = list(collect_dependencies(_dependant))
async with container.context() as ctx:
Expand All @@ -49,7 +46,6 @@ async def test_execute_async(container: Container) -> None:
assert isinstance(c, _C)


@pytest.mark.anyio
async def test_execute_async_with_kwargs(container: Container) -> None:
dependencies = list(collect_dependencies(_dependant))
provided_a = _A()
Expand All @@ -59,7 +55,6 @@ async def test_execute_async_with_kwargs(container: Container) -> None:
assert isinstance(c, _C)


@pytest.mark.anyio
async def test_execute_async_coroutine(container: Container) -> None:
dependencies = list(collect_dependencies(_async_dependant))
async with container.context() as ctx:
Expand All @@ -68,7 +63,6 @@ async def test_execute_async_coroutine(container: Container) -> None:
assert isinstance(c, _C)


@pytest.mark.anyio
async def test_provide_functools_partial() -> None:
container = Container()
container.register(
Expand Down
2 changes: 0 additions & 2 deletions tests/ext/fastapi/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ async def route(request: SubRequest) -> str:
return request.param


@pytest.mark.anyio
async def test_function_route(
http_client: httpx.AsyncClient,
provided_value: int,
Expand All @@ -23,7 +22,6 @@ async def test_function_route(
assert response.json() == {"value": provided_value}


@pytest.mark.anyio
async def test_function_route_override(
http_client: httpx.AsyncClient,
container: aioinject.Container,
Expand Down
2 changes: 0 additions & 2 deletions tests/ext/litestar/test_lifespan.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import contextlib
from collections.abc import AsyncIterator

import pytest
from litestar import Litestar

import aioinject
from aioinject import Container
from aioinject.ext.litestar import AioInjectPlugin


@pytest.mark.anyio
async def test_lifespan() -> None:
number = 42

Expand Down
3 changes: 0 additions & 3 deletions tests/ext/litestar/test_litestar.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import uuid

import httpx
import pytest

import aioinject


@pytest.mark.anyio
async def test_function_route(
http_client: httpx.AsyncClient,
provided_value: int,
Expand All @@ -16,7 +14,6 @@ async def test_function_route(
assert response.json() == {"value": provided_value}


@pytest.mark.anyio
async def test_function_route_override(
http_client: httpx.AsyncClient,
container: aioinject.Container,
Expand Down
2 changes: 1 addition & 1 deletion tests/ext/strawberry/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tests.ext.strawberry.app import StrawberryApp, _Query


@pytest.fixture
@pytest.fixture(autouse=True)
def anyio_backend() -> str:
return "asyncio"

Expand Down
3 changes: 0 additions & 3 deletions tests/ext/strawberry/test_strawberry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
import pytest


pytestmark = [pytest.mark.anyio]


@pytest.mark.parametrize("resolver_name", ["helloWorld", "helloWorldSync"])
async def test_async_resolver(
http_client: httpx.AsyncClient,
Expand Down
1 change: 0 additions & 1 deletion tests/providers/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from aioinject import Inject, Object


@pytest.mark.anyio
async def test_would_provide_same_object() -> None:
obj = object()
provider = Object(object_=obj)
Expand Down
1 change: 0 additions & 1 deletion tests/providers/test_scoped.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def test_would_return_factory_result(provider: Provider[_Test]) -> None:
assert provider.provide_sync({}) is instance


@pytest.mark.anyio
async def test_provide_async() -> None:
return_value = 42

Expand Down
2 changes: 0 additions & 2 deletions tests/providers/test_singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def test_identity(container: Container) -> None:
assert instance is ctx.resolve(_Test)


@pytest.mark.anyio
async def test_identity_async(container: Container) -> None:
async with container.context() as ctx:
instance = await ctx.resolve(_Test)
Expand All @@ -31,7 +30,6 @@ async def test_identity_async(container: Container) -> None:
assert instance is await ctx.resolve(_Test)


@pytest.mark.anyio
async def test_should_not_execute_twice() -> None:
count = 0

Expand Down
3 changes: 0 additions & 3 deletions tests/providers/test_transient.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,16 @@ def container() -> Container:
return container


@pytest.mark.anyio
def test_identity(container: Container) -> None:
with container.sync_context() as ctx:
assert ctx.resolve(_Test) is not ctx.resolve(_Test)


@pytest.mark.anyio
async def test_identity_async(container: Container) -> None:
async with container.context() as ctx:
assert await ctx.resolve(_Test) is not await ctx.resolve(_Test)


@pytest.mark.anyio
async def test_should_close_transient_dependencies() -> None:
count = 0

Expand Down
2 changes: 0 additions & 2 deletions tests/stores/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from aioinject._store import InstanceStore, SingletonStore


pytestmark = [pytest.mark.anyio]

_NUMBER = 42


Expand Down
43 changes: 43 additions & 0 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import contextlib
from collections.abc import AsyncIterator

from aioinject import Container
from aioinject.extensions import LifespanExtension, OnInitExtension


async def test_lifespan_extension() -> None:
class TestExtension(LifespanExtension):
def __init__(self) -> None:
self.open = False
self.closed = False

@contextlib.asynccontextmanager
async def lifespan(
self,
_: Container,
) -> AsyncIterator[None]:
self.open = True
yield
self.closed = True

extension = TestExtension()
container = Container(extensions=[extension])
assert not extension.closed
async with container:
assert extension.open
assert not extension.closed
assert extension.closed


def test_on_mount_extension() -> None:
class TestExtension(OnInitExtension):
def __init__(self) -> None:
self.mounted = False

def on_init(self, _: Container) -> None:
self.mounted = True

extension = TestExtension()
assert not extension.mounted
Container(extensions=[extension])
assert extension.mounted
Loading