Skip to content

Commit

Permalink
Add more tests for event serde
Browse files Browse the repository at this point in the history
This also does some light refactoring.
  • Loading branch information
JordonPhillips committed Sep 11, 2024
1 parent 950d2bf commit ed87484
Show file tree
Hide file tree
Showing 10 changed files with 1,061 additions and 67 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

INITIAL_REQUEST_EVENT_TYPE = "initial-request"
INITIAL_RESPONSE_EVENT_TYPE = "initial-response"
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
from smithy_core.utils import expect_type

from ..events import HEADERS_DICT, Event
from ..exceptions import EventError, UnexpectedEventError
from ..exceptions import EventError, UnmodeledEventError
from . import INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE
from .traits import EVENT_HEADER_TRAIT, EVENT_PAYLOAD_TRAIT

INITIAL_MESSAGE_TYPES = (INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE)


class EventDeserializer(SpecificShapeDeserializer):
def __init__(
Expand All @@ -29,22 +32,31 @@ def read_struct(
) -> None:
event = Event.decode(self._source)
headers = event.message.headers
message_deserializer = EventMessageDeserializer(
headers, self._payload_codec.create_deserializer(event.message.payload)
)

payload_deserializer = None
if event.message.payload:
payload_deserializer = self._payload_codec.create_deserializer(
event.message.payload
)

message_deserializer = EventMessageDeserializer(headers, payload_deserializer)

match headers.get(":message-type"):
case "event":
member_name = expect_type(str, headers[":event-type"])
consumer(schema.members[member_name], message_deserializer)
if member_name in INITIAL_MESSAGE_TYPES:
# If it's an initial message, skip straight to deserialization.
message_deserializer.read_struct(schema, consumer)
else:
consumer(schema.members[member_name], message_deserializer)
case "exception":
member_name = expect_type(str, headers[":exception-type"])
consumer(schema.members[member_name], message_deserializer)
case "error":
# The `application/vnd.amazon.eventstream` format allows for explicitly
# unmodeled exceptions. These exceptions MUST have the `:error-code`
# and `:error-message` headers set, and they MUST be strings.
raise UnexpectedEventError(
raise UnmodeledEventError(
expect_type(str, headers[":error-code"]),
expect_type(str, headers[":error-message"]),
)
Expand All @@ -54,7 +66,7 @@ def read_struct(

class EventMessageDeserializer(SpecificShapeDeserializer):
def __init__(
self, headers: HEADERS_DICT, payload_deserializer: ShapeDeserializer
self, headers: HEADERS_DICT, payload_deserializer: ShapeDeserializer | None
) -> None:
self._headers = headers
self._payload_deserializer = payload_deserializer
Expand All @@ -70,10 +82,11 @@ def read_struct(
if member_schema is not None and EVENT_HEADER_TRAIT in member_schema.traits:
consumer(member_schema, headers_deserializer)

if (payload_member := self._get_payload_member(schema)) is not None:
consumer(payload_member, self._payload_deserializer)
else:
self._payload_deserializer.read_struct(schema, consumer)
if self._payload_deserializer:
if (payload_member := self._get_payload_member(schema)) is not None:
consumer(payload_member, self._payload_deserializer)
else:
self._payload_deserializer.read_struct(schema, consumer)

def _get_payload_member(self, schema: "Schema") -> "Schema | None":
for member in schema.members.values():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,21 @@
ShapeSerializer,
SpecificShapeSerializer,
)
from smithy_core.shapes import ShapeType
from smithy_core.utils import expect_type

from ..events import EventHeaderEncoder, EventMessage
from ..exceptions import InvalidHeaderValue
from .traits import ERROR_TRAIT, EVENT_HEADER_TRAIT, EVENT_PAYLOAD_TRAIT
from . import INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE
from .traits import (
ERROR_TRAIT,
EVENT_HEADER_TRAIT,
EVENT_PAYLOAD_TRAIT,
MEDIA_TYPE_TRAIT,
)

_INITIAL_REQUEST_EVENT_TYPE = "initial-request"
_INITIAL_RESPONSE_EVENT_TYPE = "initial-response"
_DEFAULT_STRING_CONTENT_TYPE = "text/plain"
_DEFAULT_BLOB_CONTENT_TYPE = "application/octet-stream"


class EventSerializer(SpecificShapeSerializer):
Expand All @@ -31,15 +39,29 @@ def __init__(
self._payload_codec = payload_codec
self._result: EventMessage | None = None
if is_client_mode:
self._initial_message_event_type = _INITIAL_REQUEST_EVENT_TYPE
self._initial_message_event_type = INITIAL_REQUEST_EVENT_TYPE
else:
self._initial_message_event_type = _INITIAL_RESPONSE_EVENT_TYPE
self._initial_message_event_type = INITIAL_RESPONSE_EVENT_TYPE

def get_result(self) -> EventMessage | None:
return self._result

@contextmanager
def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]:
# Event stream definitions are unions. Nothing about the union shape actually
# matters for the purposes of event stream serialization though, we just care
# about the specific member we're serializing. So here we yield immediately,
# and the next time this method is called it'll be by the member that we
# actually care about.
#
# Note that if we're serializing an operation input or output, it won't be a
# union at all, so this won't get triggered. Thankfully, that's what we want.
if schema.shape_type is ShapeType.UNION:
try:
yield self
finally:
return

headers_encoder = EventHeaderEncoder()

if ERROR_TRAIT in schema.traits:
Expand All @@ -65,21 +87,40 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]:
)
header_serializer = EventHeaderSerializer(headers_encoder)

if not self._has_payload_member(schema):
media_type = self._payload_codec.media_type

if (payload_member := self._get_payload_member(schema)) is not None:
media_type = self._get_payload_media_type(payload_member, media_type)
yield EventStreamBindingSerializer(header_serializer, payload_serializer)
else:
with payload_serializer.begin_struct(schema) as body_serializer:
yield EventStreamBindingSerializer(header_serializer, body_serializer)
else:
yield EventStreamBindingSerializer(header_serializer, payload_serializer)

payload_bytes = payload.getvalue()
if payload_bytes:
headers_encoder.encode_string(":content-type", media_type)

self._result = EventMessage(
headers_bytes=headers_encoder.get_result(), payload=payload.getvalue()
headers_bytes=headers_encoder.get_result(), payload=payload_bytes
)

def _has_payload_member(self, schema: "Schema") -> bool:
def _get_payload_member(self, schema: Schema) -> Schema | None:
for member in schema.members.values():
if EVENT_PAYLOAD_TRAIT in member.traits:
return True
return False
return schema
return None

def _get_payload_media_type(self, schema: Schema, default: str) -> str:
if (media_type := schema.traits.get(MEDIA_TYPE_TRAIT)) is not None:
return expect_type(str, media_type.value)

match schema.shape_type:
case ShapeType.STRING:
return _DEFAULT_STRING_CONTENT_TYPE
case ShapeType.BLOB:
return _DEFAULT_BLOB_CONTENT_TYPE
case _:
return default


class EventHeaderSerializer(SpecificShapeSerializer):
Expand Down Expand Up @@ -129,6 +170,7 @@ def __init__(
self._payload_serializer = payload_serializer

def before(self, schema: "Schema") -> ShapeSerializer:
print(f"FOUND TRAITS: {schema.traits}")
if EVENT_HEADER_TRAIT in schema.traits:
return self._header_serializer
return self._payload_serializer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
EVENT_HEADER_TRAIT = ShapeID("smithy.api#eventHeader")
EVENT_PAYLOAD_TRAIT = ShapeID("smithy.api#eventPayload")
ERROR_TRAIT = ShapeID("smithy.api#error")
MEDIA_TYPE_TRAIT = ShapeID("smithy.api#mediaType")
Loading

0 comments on commit ed87484

Please sign in to comment.