Skip to content

Commit

Permalink
feat: adapt pydantic v2
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Mar 7, 2024
1 parent c6f4903 commit ecdf2b6
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 56 deletions.
6 changes: 6 additions & 0 deletions jina/_docarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,9 @@
from docarray import Document, DocumentArray

docarray_v2 = False


import pydantic

is_pydantic_v2 = pydantic.__version__.startswith('2.')

14 changes: 9 additions & 5 deletions jina/serve/runtimes/gateway/graph/topology_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import grpc.aio

from jina._docarray import DocumentArray, docarray_v2
from jina._docarray import DocumentArray, docarray_v2, is_pydantic_v2
from jina.constants import __default_endpoint__
from jina.excepts import InternalNetworkError
from jina.logging.logger import JinaLogger
Expand All @@ -20,7 +20,11 @@
from docarray import DocList
from docarray.documents.legacy import LegacyDocument

from jina.serve.runtimes.helper import _create_pydantic_model_from_schema
if not is_pydantic_v2:
from jina.serve.runtimes.helper import _create_pydantic_model_from_schema as create_base_doc_from_schema

Check warning on line 24 in jina/serve/runtimes/gateway/graph/topology_graph.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/gateway/graph/topology_graph.py#L23-L24

Added lines #L23 - L24 were not covered by tests
else:
from docarray.utils.create_dynamic_doc_class import create_base_doc_from_schema

Check warning on line 26 in jina/serve/runtimes/gateway/graph/topology_graph.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/gateway/graph/topology_graph.py#L26

Added line #L26 was not covered by tests


legacy_doc_schema = LegacyDocument.schema()

Expand Down Expand Up @@ -239,7 +243,7 @@ async def task():
input_model = LegacyDocument
else:
input_model = (
_create_pydantic_model_from_schema(
create_base_doc_from_schema(
input_model_schema,
input_model_name,
models_created_by_name,
Expand Down Expand Up @@ -269,7 +273,7 @@ async def task():
output_model = LegacyDocument
else:
output_model = (
_create_pydantic_model_from_schema(
create_base_doc_from_schema(
output_model_schema,
output_model_name,
models_created_by_name,
Expand Down Expand Up @@ -306,7 +310,7 @@ async def task():
from pydantic import BaseModel

parameters_model = (
_create_pydantic_model_from_schema(
create_base_doc_from_schema(
parameters_model_schema,
parameters_model_name,
models_created_by_name,
Expand Down
11 changes: 7 additions & 4 deletions jina/serve/runtimes/head/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
from jina.serve.runtimes.monitoring import MonitoringRequestMixin
from jina.serve.runtimes.worker.request_handling import WorkerRequestHandler
from jina.types.request.data import DataRequest, Response
from jina._docarray import docarray_v2
from jina._docarray import docarray_v2, is_pydantic_v2

if docarray_v2:
from jina.serve.runtimes.helper import _create_pydantic_model_from_schema
if not is_pydantic_v2:
from jina.serve.runtimes.helper import _create_pydantic_model_from_schema as create_base_doc_from_schema

Check warning on line 23 in jina/serve/runtimes/head/request_handling.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/head/request_handling.py#L22-L23

Added lines #L22 - L23 were not covered by tests
else:
from docarray.utils.create_dynamic_doc_class import create_base_doc_from_schema

Check warning on line 25 in jina/serve/runtimes/head/request_handling.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/head/request_handling.py#L25

Added line #L25 was not covered by tests
from docarray import DocList
from docarray.base_doc.any_doc import AnyDoc

Expand Down Expand Up @@ -359,7 +362,7 @@ async def task():
LegacyDocument
)
elif input_model_name not in models_created_by_name:
input_model = _create_pydantic_model_from_schema(
input_model = create_base_doc_from_schema(

Check warning on line 365 in jina/serve/runtimes/head/request_handling.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/head/request_handling.py#L365

Added line #L365 was not covered by tests
input_model_schema, input_model_name, {}
)
models_created_by_name[input_model_name] = input_model
Expand All @@ -369,7 +372,7 @@ async def task():
LegacyDocument
)
elif output_model_name not in models_created_by_name:
output_model = _create_pydantic_model_from_schema(
output_model = create_base_doc_from_schema(

Check warning on line 375 in jina/serve/runtimes/head/request_handling.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/head/request_handling.py#L375

Added line #L375 was not covered by tests
output_model_schema, output_model_name, {}
)
models_created_by_name[output_model_name] = output_model
Expand Down
4 changes: 2 additions & 2 deletions jina/serve/runtimes/helper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
from typing import Any, Dict, List, Optional, Tuple, Union

from jina._docarray import docarray_v2
from jina._docarray import docarray_v2, is_pydantic_v2

_SPECIFIC_EXECUTOR_SEPARATOR = '__'

Expand Down Expand Up @@ -79,7 +79,7 @@ def _parse_specific_params(parameters: Dict, executor_name: str):
return parsed_params


if docarray_v2:
if docarray_v2 and not is_pydantic_v2:
from docarray import BaseDoc, DocList
from docarray.typing import AnyTensor
from pydantic import create_model
Expand Down
11 changes: 7 additions & 4 deletions jina/serve/runtimes/worker/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from google.protobuf.struct_pb2 import Struct

from jina._docarray import DocumentArray, docarray_v2
from jina._docarray import DocumentArray, docarray_v2, is_pydantic_v2
from jina.constants import __default_endpoint__
from jina.excepts import BadConfigSource, RuntimeTerminated
from jina.helper import get_full_version
Expand Down Expand Up @@ -1013,21 +1013,24 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto:
if docarray_v2:
from docarray.documents.legacy import LegacyDocument

from jina.serve.runtimes.helper import _create_aux_model_doc_list_to_list
if not is_pydantic_v2:
from jina.serve.runtimes.helper import _create_aux_model_doc_list_to_list as create_pure_python_type_model

Check warning on line 1017 in jina/serve/runtimes/worker/request_handling.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/worker/request_handling.py#L1016-L1017

Added lines #L1016 - L1017 were not covered by tests
else:
from docarray.utils.create_dynamic_doc_class import create_pure_python_type_model

Check warning on line 1019 in jina/serve/runtimes/worker/request_handling.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/worker/request_handling.py#L1019

Added line #L1019 was not covered by tests

legacy_doc_schema = LegacyDocument.schema()
for endpoint_name, inner_dict in schemas.items():
if inner_dict['input']['model'].schema() == legacy_doc_schema:
inner_dict['input']['model'] = legacy_doc_schema
else:
inner_dict['input']['model'] = _create_aux_model_doc_list_to_list(
inner_dict['input']['model'] = create_pure_python_type_model(

Check warning on line 1026 in jina/serve/runtimes/worker/request_handling.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/worker/request_handling.py#L1026

Added line #L1026 was not covered by tests
inner_dict['input']['model']
).schema()

if inner_dict['output']['model'].schema() == legacy_doc_schema:
inner_dict['output']['model'] = legacy_doc_schema
else:
inner_dict['output']['model'] = _create_aux_model_doc_list_to_list(
inner_dict['output']['model'] = create_pure_python_type_model(

Check warning on line 1033 in jina/serve/runtimes/worker/request_handling.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/worker/request_handling.py#L1033

Added line #L1033 was not covered by tests
inner_dict['output']['model']
).schema()

Expand Down
35 changes: 19 additions & 16 deletions tests/integration/docarray_v2/test_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np
import pytest
from jina._docarray import is_pydantic_v2
from docarray import BaseDoc, DocList
from docarray.documents import ImageDoc, TextDoc
from docarray.documents.legacy import LegacyDocument
Expand Down Expand Up @@ -302,10 +303,11 @@ def bar(self, docs: DocList[Output1], **kwargs) -> DocList[Output2]:
from jina.proto import jina_pb2
from jina.proto.jina_pb2_grpc import JinaDiscoverEndpointsRPCStub
from jina.serve.executors import __dry_run_endpoint__
from jina.serve.runtimes.helper import (
_create_aux_model_doc_list_to_list,
_create_pydantic_model_from_schema,
)
if not is_pydantic_v2:
from jina.serve.runtimes.helper import _create_aux_model_doc_list_to_list as create_pure_python_type_model
from jina.serve.runtimes.helper import _create_pydantic_model_from_schema as create_base_doc_from_schema
else:
from docarray.utils.create_dynamic_doc_class import create_pure_python_type_model, create_base_doc_from_schema

channel = grpc.insecure_channel(f'0.0.0.0:{ports[0]}')
stub = JinaDiscoverEndpointsRPCStub(channel)
Expand All @@ -320,16 +322,16 @@ def bar(self, docs: DocList[Output1], **kwargs) -> DocList[Output2]:
v = schema_map['/bar']
assert (
v['input']
== _create_pydantic_model_from_schema(
_create_aux_model_doc_list_to_list(Input1).schema(),
== create_base_doc_from_schema(
create_pure_python_type_model(Input1).schema(),
'Input1',
{},
).schema()
)
assert (
v['output']
== _create_pydantic_model_from_schema(
_create_aux_model_doc_list_to_list(Output2).schema(),
== create_base_doc_from_schema(
create_pure_python_type_model(Output2).schema(),
'Output2',
{},
).schema()
Expand Down Expand Up @@ -390,10 +392,11 @@ def bar(self, docs: DocList[Output1], **kwargs) -> DocList[Output2]:
from jina.proto import jina_pb2
from jina.proto.jina_pb2_grpc import JinaDiscoverEndpointsRPCStub
from jina.serve.executors import __default_endpoint__, __dry_run_endpoint__
from jina.serve.runtimes.helper import (
_create_aux_model_doc_list_to_list,
_create_pydantic_model_from_schema,
)
if not is_pydantic_v2:
from jina.serve.runtimes.helper import _create_aux_model_doc_list_to_list as create_pure_python_type_model
from jina.serve.runtimes.helper import _create_pydantic_model_from_schema as create_base_doc_from_schema
else:
from docarray.utils.create_dynamic_doc_class import create_pure_python_type_model, create_base_doc_from_schema

channel = grpc.insecure_channel(f'0.0.0.0:{ports[0]}')
stub = JinaDiscoverEndpointsRPCStub(channel)
Expand All @@ -411,14 +414,14 @@ def bar(self, docs: DocList[Output1], **kwargs) -> DocList[Output2]:
v = schema_map[__default_endpoint__]
assert (
v['input']
== _create_pydantic_model_from_schema(
_create_aux_model_doc_list_to_list(Input1).schema(), 'Input1', {}
== create_base_doc_from_schema(
create_pure_python_type_model(Input1).schema(), 'Input1', {}
).schema()
)
assert (
v['output']
== _create_pydantic_model_from_schema(
_create_aux_model_doc_list_to_list(Output2).schema(), 'Output2', {}
== create_base_doc_from_schema(
create_pure_python_type_model(Output2).schema(), 'Output2', {}
).schema()
)

Expand Down
55 changes: 30 additions & 25 deletions tests/unit/serve/runtimes/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from jina._docarray import docarray_v2
from jina._docarray import docarray_v2, is_pydantic_v2
from jina.serve.helper import get_default_grpc_options
from jina.serve.runtimes.helper import (
_get_name_from_replicas_name,
Expand Down Expand Up @@ -96,10 +96,11 @@ def test_create_pydantic_model_from_schema(transformation):
from docarray.documents import TextDoc
from docarray.typing import AnyTensor, ImageUrl

from jina.serve.runtimes.helper import (
_create_aux_model_doc_list_to_list,
_create_pydantic_model_from_schema,
)
if not is_pydantic_v2:
from jina.serve.runtimes.helper import _create_aux_model_doc_list_to_list as create_pure_python_type_model
from jina.serve.runtimes.helper import _create_pydantic_model_from_schema as create_base_doc_from_schema
else:
from docarray.utils.create_dynamic_doc_class import create_pure_python_type_model, create_base_doc_from_schema

class Nested2Doc(BaseDoc):
value: str
Expand All @@ -124,8 +125,8 @@ class CustomDoc(BaseDoc):
nested: Nested1Doc
classvar: ClassVar[str] = 'classvar'

CustomDocCopy = _create_aux_model_doc_list_to_list(CustomDoc)
new_custom_doc_model = _create_pydantic_model_from_schema(
CustomDocCopy = create_pure_python_type_model(CustomDoc)
new_custom_doc_model = create_base_doc_from_schema(
CustomDocCopy.schema(), 'CustomDoc', {}
)

Expand Down Expand Up @@ -199,8 +200,8 @@ class CustomDoc(BaseDoc):
class TextDocWithId(BaseDoc):
ia: str

TextDocWithIdCopy = _create_aux_model_doc_list_to_list(TextDocWithId)
new_textdoc_with_id_model = _create_pydantic_model_from_schema(
TextDocWithIdCopy = create_pure_python_type_model(TextDocWithId)
new_textdoc_with_id_model = create_base_doc_from_schema(
TextDocWithIdCopy.schema(), 'TextDocWithId', {}
)

Expand Down Expand Up @@ -229,8 +230,8 @@ class TextDocWithId(BaseDoc):
class ResultTestDoc(BaseDoc):
matches: DocList[TextDocWithId]

ResultTestDocCopy = _create_aux_model_doc_list_to_list(ResultTestDoc)
new_result_test_doc_with_id_model = _create_pydantic_model_from_schema(
ResultTestDocCopy = create_pure_python_type_model(ResultTestDoc)
new_result_test_doc_with_id_model = create_base_doc_from_schema(
ResultTestDocCopy.schema(), 'ResultTestDoc', {}
)
result_test_docs = DocList[ResultTestDoc](
Expand Down Expand Up @@ -268,10 +269,11 @@ def test_create_empty_doc_list_from_schema(transformation):
from docarray.documents import TextDoc
from docarray.typing import AnyTensor, ImageUrl

from jina.serve.runtimes.helper import (
_create_aux_model_doc_list_to_list,
_create_pydantic_model_from_schema,
)
if not is_pydantic_v2:
from jina.serve.runtimes.helper import _create_aux_model_doc_list_to_list as create_pure_python_type_model
from jina.serve.runtimes.helper import _create_pydantic_model_from_schema as create_base_doc_from_schema
else:
from docarray.utils.create_dynamic_doc_class import create_pure_python_type_model, create_base_doc_from_schema

class CustomDoc(BaseDoc):
tensor: Optional[AnyTensor]
Expand All @@ -288,8 +290,8 @@ class CustomDoc(BaseDoc):
tags: Optional[Dict[str, Any]] = None
lf: List[float] = [3.0, 4.1]

CustomDocCopy = _create_aux_model_doc_list_to_list(CustomDoc)
new_custom_doc_model = _create_pydantic_model_from_schema(
CustomDocCopy = create_pure_python_type_model(CustomDoc)
new_custom_doc_model = create_base_doc_from_schema(
CustomDocCopy.schema(), 'CustomDoc', {}
)

Expand All @@ -313,8 +315,8 @@ class CustomDoc(BaseDoc):
class TextDocWithId(BaseDoc):
ia: str

TextDocWithIdCopy = _create_aux_model_doc_list_to_list(TextDocWithId)
new_textdoc_with_id_model = _create_pydantic_model_from_schema(
TextDocWithIdCopy = create_pure_python_type_model(TextDocWithId)
new_textdoc_with_id_model = create_base_doc_from_schema(
TextDocWithIdCopy.schema(), 'TextDocWithId', {}
)

Expand All @@ -336,8 +338,8 @@ class TextDocWithId(BaseDoc):
class ResultTestDoc(BaseDoc):
matches: DocList[TextDocWithId]

ResultTestDocCopy = _create_aux_model_doc_list_to_list(ResultTestDoc)
new_result_test_doc_with_id_model = _create_pydantic_model_from_schema(
ResultTestDocCopy = create_pure_python_type_model(ResultTestDoc)
new_result_test_doc_with_id_model = create_base_doc_from_schema(
ResultTestDocCopy.schema(), 'ResultTestDoc', {}
)
result_test_docs = DocList[ResultTestDoc]()
Expand All @@ -360,8 +362,11 @@ class ResultTestDoc(BaseDoc):
@pytest.mark.skipif(not docarray_v2, reason='Test only working with docarray v2')
def test_dynamic_class_creation_multiple_doclist_nested():
from docarray import BaseDoc, DocList
from jina.serve.runtimes.helper import _create_aux_model_doc_list_to_list
from jina.serve.runtimes.helper import _create_pydantic_model_from_schema
if not is_pydantic_v2:
from jina.serve.runtimes.helper import _create_aux_model_doc_list_to_list as create_pure_python_type_model
from jina.serve.runtimes.helper import _create_pydantic_model_from_schema as create_base_doc_from_schema
else:
from docarray.utils.create_dynamic_doc_class import create_pure_python_type_model, create_base_doc_from_schema

class MyTextDoc(BaseDoc):
text: str
Expand All @@ -374,8 +379,8 @@ class SearchResult(BaseDoc):

textlist = DocList[MyTextDoc]([MyTextDoc(text='hey')])
models_created_by_name = {}
SearchResult_aux = _create_aux_model_doc_list_to_list(SearchResult)
_ = _create_pydantic_model_from_schema(
SearchResult_aux = create_pure_python_type_model(SearchResult)
_ = create_base_doc_from_schema(
SearchResult_aux.schema(), 'SearchResult', models_created_by_name
)
QuoteFile_reconstructed_in_gateway_from_Search_results = models_created_by_name[
Expand Down

0 comments on commit ecdf2b6

Please sign in to comment.