Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: DTO factory narrowed with a generic alias. #2791

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion litestar/dto/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)

from msgspec import UNSET, Struct, UnsetType, convert, defstruct, field
from typing_extensions import get_origin

from litestar.dto._types import (
CollectionType,
Expand Down Expand Up @@ -110,7 +111,7 @@ def __init__(
rename_fields=self.dto_factory.config.rename_fields,
)
self.transfer_model_type = self.create_transfer_model_type(
model_name=model_type.__name__, field_definitions=self.parsed_field_definitions
model_name=(get_origin(model_type) or model_type).__name__, field_definitions=self.parsed_field_definitions
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this feature, anywhere that we could previously assume that the type that narrowed the dto was just a regular class, we now have to account for the fact that it could be an instance of _GenericAlias.

In the my first pass of this PR I've taken the quickest and dirtiest approach to get things passing, but we might need some abstraction over the type that handles the differences.

)
self.dto_data_type: type[DTOData] | None = None

Expand Down
7 changes: 5 additions & 2 deletions litestar/dto/base_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from inspect import getmodule
from typing import TYPE_CHECKING, Collection, Generic, TypeVar

from typing_extensions import NotRequired, TypedDict, get_type_hints
from typing_extensions import NotRequired, TypedDict

from litestar.dto._backend import DTOBackend
from litestar.dto._codegen_backend import DTOCodegenBackend
Expand All @@ -17,6 +17,7 @@
from litestar.types.builtin_types import NoneType
from litestar.types.composite_types import TypeEncodersMap
from litestar.typing import FieldDefinition
from litestar.utils.typing import get_type_hints_with_generics_resolved

if TYPE_CHECKING:
from typing import Any, ClassVar, Generator
Expand Down Expand Up @@ -267,7 +268,9 @@ def get_model_type_hints(

return {
k: FieldDefinition.from_kwarg(annotation=v, name=k)
for k, v in get_type_hints(model_type, localns=namespace, include_extras=True).items()
for k, v in get_type_hints_with_generics_resolved(
model_type, localns=namespace, include_extras=True
).items()
}

@staticmethod
Expand Down
7 changes: 5 additions & 2 deletions litestar/dto/dataclass_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import MISSING, fields, replace
from typing import TYPE_CHECKING, Generic, TypeVar

from typing_extensions import get_origin

from litestar.dto.base_dto import AbstractDTO
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.dto.field import DTO_FIELD_META_KEY, DTOField
Expand All @@ -29,7 +31,8 @@ class DataclassDTO(AbstractDTO[T], Generic[T]):
def generate_field_definitions(
cls, model_type: type[DataclassProtocol]
) -> Generator[DTOFieldDefinition, None, None]:
dc_fields = {f.name: f for f in fields(model_type)}
model_origin = get_origin(model_type) or model_type
dc_fields = {f.name: f for f in fields(model_origin)}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll have to test this against all of the other dto factory types that support generics too, b/c there is bound to be cases where the _GenericAlias instances break things in there too.

for key, field_definition in cls.get_model_type_hints(model_type).items():
if not (dc_field := dc_fields.get(key)):
continue
Expand All @@ -41,7 +44,7 @@ def generate_field_definitions(
field_definition=field_definition,
default_factory=default_factory,
dto_field=dc_field.metadata.get(DTO_FIELD_META_KEY, DTOField()),
model_name=model_type.__name__,
model_name=model_origin.__name__,
),
name=key,
default=default,
Expand Down
29 changes: 28 additions & 1 deletion litestar/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,21 @@
from copy import deepcopy
from dataclasses import dataclass, is_dataclass, replace
from inspect import Parameter, Signature
from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, Protocol, Sequence, TypeVar, cast
from typing import ( # type: ignore[attr-defined]
Any,
AnyStr,
Callable,
ClassVar,
Collection,
ForwardRef,
Literal,
Mapping,
Protocol,
Sequence,
TypeVar,
_GenericAlias, # pyright: ignore
cast,
)

from msgspec import UnsetType
from typing_extensions import NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict
Expand Down Expand Up @@ -442,6 +456,19 @@ def is_subclass_of(self, cl: type[Any] | tuple[type[Any], ...]) -> bool:
if self.origin in UnionTypes:
return all(t.is_subclass_of(cl) for t in self.inner_types)

if isinstance(self.annotation, _GenericAlias) and self.origin not in (ClassVar, Literal):
cl_args = get_args(cl)
cl_origin = get_origin(cl) or cl
return (
issubclass(self.origin, cl_origin)
and (len(cl_args) == len(self.args) if cl_args else True)
and (
all(t.is_subclass_of(cl_arg) for t, cl_arg in zip(self.inner_types, cl_args))
if cl_args
else True
)
)

Comment on lines +459 to +471
Copy link
Contributor Author

@peterschutt peterschutt Nov 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to determine when a _GenericAlias type should be considered a sub-type of another type.

This says that when A is a _GenericAlias instance, then it is a subclass of B when:

  1. if B is parametrized, then As origin is a subtype of Bs origin, otherwise A's origin is a subtype of B.
  2. if B is parametrized, then A and B must have the same number of type parameters declared
  3. if B is parametrized, then As type params are pairwise subtypes of Bs type params

With 2 and 3, if B is not parametrized, and origin of A is a subtype of B then I think we should treat that as B[Any, ...] so the type parameters only come into it if both types have args declared.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll have to add a lot of tests for this.

return self.origin not in UnionTypes and is_class_and_subclass(self.origin, cl)

if self.annotation is AnyStr:
Expand Down
4 changes: 3 additions & 1 deletion litestar/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,9 @@ def get_type_hints_with_generics_resolved(
if origin is None:
# Implies the generic types have not been specified in the annotation
type_hints = get_type_hints(annotation, globalns=globalns, localns=localns, include_extras=include_extras)
typevar_map = {p: p for p in annotation.__parameters__}
if not (parameters := getattr(annotation, "__parameters__", None)):
return type_hints
typevar_map = {p: p for p in parameters}
Comment on lines +252 to +254
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to support using this function without knowing up front if the type we are passing in is a generic type or not. If not, it won't have the parameters attribute, and so we just return the type hints without any post processing.

else:
type_hints = get_type_hints(origin, globalns=globalns, localns=localns, include_extras=include_extras)
# the __parameters__ is only available on the origin itself and not the annotation
Expand Down
37 changes: 32 additions & 5 deletions tests/unit/test_dto/test_factory/test_base_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Tuple, TypeVar, Union
from typing import TYPE_CHECKING, Generic, Tuple, TypeVar, Union

import pytest
from typing_extensions import Annotated

from litestar import Request
from litestar.dto import DataclassDTO, DTOConfig
from litestar.exceptions.dto_exceptions import InvalidAnnotationException
from litestar.types.empty import Empty
from litestar.typing import FieldDefinition

from . import Model
Expand All @@ -19,7 +20,8 @@

from litestar.dto._backend import DTOBackend

T = TypeVar("T", bound=Model)
T = TypeVar("T")
ModelT = TypeVar("ModelT", bound=Model)


def get_backend(dto_type: type[DataclassDTO[Any]]) -> DTOBackend:
Expand Down Expand Up @@ -77,7 +79,7 @@ def test_extra_annotated_metadata_ignored() -> None:

def test_overwrite_config() -> None:
first = DTOConfig(exclude={"a"})
generic_dto = DataclassDTO[Annotated[T, first]] # pyright: ignore
generic_dto = DataclassDTO[Annotated[ModelT, first]] # pyright: ignore
second = DTOConfig(exclude={"b"})
dto = generic_dto[Annotated[Model, second]] # pyright: ignore
assert dto.config is second
Expand All @@ -86,13 +88,13 @@ def test_overwrite_config() -> None:
def test_existing_config_not_overwritten() -> None:
assert getattr(DataclassDTO, "_config", None) is None
first = DTOConfig(exclude={"a"})
generic_dto = DataclassDTO[Annotated[T, first]] # pyright: ignore
generic_dto = DataclassDTO[Annotated[ModelT, first]] # pyright: ignore
dto = generic_dto[Model] # pyright: ignore
assert dto.config is first


def test_config_assigned_via_subclassing() -> None:
class CustomGenericDTO(DataclassDTO[T]):
class CustomGenericDTO(DataclassDTO[ModelT]):
config = DTOConfig(exclude={"a"})

concrete_dto = CustomGenericDTO[Model]
Expand Down Expand Up @@ -161,3 +163,28 @@ class SubType(Model):
assert (
dto_type._dto_backends["handler_id"]["data_backend"].parsed_field_definitions[-1].name == "c" # pyright: ignore
)


def test_type_narrowing_with_generic_type() -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should name it test_get_model_type_hints_with_generic_type()

@dataclass
class Foo(Generic[T]):
foo: T

hints = DataclassDTO.get_model_type_hints(Foo[int])
assert hints == {
"foo": FieldDefinition(
raw=int,
annotation=int,
type_wrappers=(),
origin=None,
args=(),
metadata=(),
instantiable_origin=None,
safe_generic_origin=None,
inner_types=(),
default=Empty,
extra={},
kwarg_definition=None,
name="foo",
)
}
27 changes: 25 additions & 2 deletions tests/unit/test_dto/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from __future__ import annotations

from typing import Dict
from dataclasses import dataclass
from typing import Dict, Generic, TypeVar
from unittest.mock import MagicMock

import pytest

from litestar import Controller, Litestar, Router, post
from litestar.config.app import ExperimentalFeatures
from litestar.dto import AbstractDTO, DTOConfig
from litestar.dto import AbstractDTO, DataclassDTO, DTOConfig, DTOData
from litestar.dto._backend import DTOBackend
from litestar.dto._codegen_backend import DTOCodegenBackend
from litestar.testing import create_test_client

from . import Model

T = TypeVar("T")


@pytest.fixture()
def experimental_features(use_experimental_dto_backend: bool) -> list[ExperimentalFeatures] | None:
Expand Down Expand Up @@ -153,3 +156,23 @@ def handler(data: Model) -> Model:

backend = handler.resolve_data_dto()._dto_backends[handler.handler_id]["data_backend"] # type: ignore[union-attr]
assert isinstance(backend, DTOBackend)


def test_dto_for_generic_model() -> None:
@dataclass
class Foo(Generic[T]):
foo: T

FooDTO = DataclassDTO[Foo[int]]

@post("/foo", dto=FooDTO, signature_types=[Foo])
async def foo_handler(data: DTOData[Foo[int]]) -> Foo[int]:
return data.create_instance()

with create_test_client(route_handlers=foo_handler) as client:
response = client.post("/foo", json={"foo": 1})
assert response.status_code == 201
assert response.json() == {"foo": 1}
response = client.post("/foo", json={"foo": "1"})
assert response.status_code == 400
assert response.json() == {"status_code": 400, "detail": "Expected `int`, got `str` - at `$.foo`"}