Skip to content

Commit

Permalink
Add AWS event stream serde support
Browse files Browse the repository at this point in the history
This adds support for serializing and deserializing the
`application/vnd.amazon.eventstream` format for event stream framing.
This adds both high-level support useable in any context and codec
serde support.
  • Loading branch information
JordonPhillips committed Sep 13, 2024
1 parent 8ae4a72 commit 0f38929
Show file tree
Hide file tree
Showing 24 changed files with 1,877 additions and 2 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ lint-py: pants
./pants fix lint python-packages/smithy-aws-core::
./pants fix lint python-packages/smithy-json::
./pants fix lint python-packages/smithy-event-stream::
./pants fix lint python-packages/aws-event-stream::


## Runs checkers for the python packages.
Expand All @@ -64,6 +65,7 @@ check-py: pants
./pants check python-packages/smithy-aws-core::
./pants check python-packages/smithy-json::
./pants check python-packages/smithy-event-stream::
./pants check python-packages/aws-event-stream::


## Runs tests for the python packages.
Expand All @@ -73,6 +75,7 @@ test-py: pants
./pants test python-packages/smithy-aws-core::
./pants test python-packages/smithy-json::
./pants test python-packages/smithy-event-stream::
./pants test python-packages/aws-event-stream::


## Runs formatters/fixers/linters/checkers/tests for the python packages.
Expand Down
28 changes: 28 additions & 0 deletions python-packages/aws-event-stream/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

resource(name="pyproject", source="pyproject.toml")
resource(name="readme", source="README.md")
resource(name="notice", source="NOTICE")

python_distribution(
name="dist",
dependencies=[
":pyproject",
":readme",
":notice",
"python-packages/aws-event-stream/aws_event_stream:source",
],
provides=python_artifact(
name="aws_event_stream",
version="0.0.1",
),
)

# We shouldn't need this, but pants will assume that smithy_core is an external
# dependency since it's in pyproject.toml and there's no way to exclude it, so
# for now we need to duplicate things.
python_requirements(
name="requirements",
source="requirements.txt",
)
1 change: 1 addition & 0 deletions python-packages/aws-event-stream/MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include aws_event_stream/py.typed
1 change: 1 addition & 0 deletions python-packages/aws-event-stream/NOTICE
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Empty file.
13 changes: 13 additions & 0 deletions python-packages/aws-event-stream/aws_event_stream/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

resource(name="pytyped", source="py.typed")

python_sources(
name="source",
dependencies=[
":pytyped",
"python-packages/aws-event-stream:requirements",
],
sources=["**/*.py"],
)
2 changes: 2 additions & 0 deletions python-packages/aws-event-stream/aws_event_stream/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import datetime
from collections.abc import Callable

from smithy_core.codecs import Codec
from smithy_core.deserializers import ShapeDeserializer, SpecificShapeDeserializer
from smithy_core.interfaces import BytesReader
from smithy_core.schemas import Schema
from smithy_core.utils import expect_type

from ..events import HEADERS_DICT, Event
from ..exceptions import EventError, UnexpectedEventError
from .traits import EVENT_HEADER_TRAIT, EVENT_PAYLOAD_TRAIT


class EventDeserializer(SpecificShapeDeserializer):
def __init__(
self, source: BytesReader, payload_codec: Codec, is_client_mode: bool = True
) -> None:
self._source = source
self._payload_codec = payload_codec
self._is_client_mode = is_client_mode

def read_struct(
self,
schema: Schema,
consumer: Callable[[Schema, ShapeDeserializer], None],
) -> None:
event = Event.decode(self._source)
headers = event.message.headers
message_deserializer = EventMessageDeserializer(
headers, self._payload_codec.create_deserializer(event.message.payload)
)

match headers.get(":message-type"):
case "event":
member_name = expect_type(str, headers[":event-type"])
consumer(schema.members[member_name], message_deserializer)
case "exception":
member_name = expect_type(str, headers[":exception-type"])
consumer(schema.members[member_name], message_deserializer)
case "error":
# The `application/vnd.amazon.eventstream` format allows for explicitly
# unmodeled exceptions. These exceptions MUST have the `:error-code`
# and `:error-message` headers set, and they MUST be strings.
raise UnexpectedEventError(
expect_type(str, headers[":error-code"]),
expect_type(str, headers[":error-message"]),
)
case _:
raise EventError(f"Unknown event structure: {event}")


class EventMessageDeserializer(SpecificShapeDeserializer):
def __init__(
self, headers: HEADERS_DICT, payload_deserializer: ShapeDeserializer
) -> None:
self._headers = headers
self._payload_deserializer = payload_deserializer

def read_struct(
self,
schema: Schema,
consumer: Callable[[Schema, ShapeDeserializer], None],
) -> None:
headers_deserializer = EventHeaderDeserializer(self._headers)
for key in self._headers.keys():
member_schema = schema.members.get(key)
if member_schema is not None and EVENT_HEADER_TRAIT in member_schema.traits:
consumer(member_schema, headers_deserializer)

if (payload_member := self._get_payload_member(schema)) is not None:
consumer(payload_member, self._payload_deserializer)
else:
self._payload_deserializer.read_struct(schema, consumer)

def _get_payload_member(self, schema: "Schema") -> "Schema | None":
for member in schema.members.values():
if EVENT_PAYLOAD_TRAIT in member.traits:
return member
return None


class EventHeaderDeserializer(SpecificShapeDeserializer):
def __init__(self, headers: HEADERS_DICT) -> None:
self._headers = headers

def read_boolean(self, schema: "Schema") -> bool:
return expect_type(bool, self._headers[schema.expect_member_name()])

def read_blob(self, schema: "Schema") -> bytes:
return expect_type(bytes, self._headers[schema.expect_member_name()])

def read_byte(self, schema: "Schema") -> int:
return expect_type(int, self._headers[schema.expect_member_name()])

def read_short(self, schema: "Schema") -> int:
return expect_type(int, self._headers[schema.expect_member_name()])

def read_integer(self, schema: "Schema") -> int:
return expect_type(int, self._headers[schema.expect_member_name()])

def read_long(self, schema: "Schema") -> int:
return expect_type(int, self._headers[schema.expect_member_name()])

def read_string(self, schema: "Schema") -> str:
return expect_type(str, self._headers[schema.expect_member_name()])

def read_timestamp(self, schema: "Schema") -> datetime.datetime:
# TODO: do we support timestamp format here? One would assume not since the
# format has a specific timestamp type.
return expect_type(
datetime.datetime, self._headers[schema.expect_member_name()]
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import datetime
from collections.abc import Iterator
from contextlib import contextmanager
from io import BytesIO
from typing import Never

from smithy_core.codecs import Codec
from smithy_core.schemas import Schema
from smithy_core.serializers import (
InterceptingSerializer,
ShapeSerializer,
SpecificShapeSerializer,
)

from ..events import EventHeaderEncoder, EventMessage
from ..exceptions import InvalidHeaderValue
from .traits import ERROR_TRAIT, EVENT_HEADER_TRAIT, EVENT_PAYLOAD_TRAIT

_INITIAL_REQUEST_EVENT_TYPE = "initial-request"
_INITIAL_RESPONSE_EVENT_TYPE = "initial-response"


class EventSerializer(SpecificShapeSerializer):
def __init__(
self,
payload_codec: Codec,
is_client_mode: bool = True,
) -> None:
self._payload_codec = payload_codec
self._result: EventMessage | None = None
if is_client_mode:
self._initial_message_event_type = _INITIAL_REQUEST_EVENT_TYPE
else:
self._initial_message_event_type = _INITIAL_RESPONSE_EVENT_TYPE

def get_result(self) -> EventMessage | None:
return self._result

@contextmanager
def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]:
headers_encoder = EventHeaderEncoder()

if ERROR_TRAIT in schema.traits:
headers_encoder.encode_string(":message-type", "exception")
headers_encoder.encode_string(
":exception-type", schema.expect_member_name()
)
else:
headers_encoder.encode_string(":message-type", "event")
if schema.member_name is None:
# If there's no member name, that must mean that the structure is
# either an input or output structure, and so this represents the
# initial message.
headers_encoder.encode_string(
":event-type", self._initial_message_event_type
)
else:
headers_encoder.encode_string(":event-type", schema.member_name)

payload = BytesIO()
payload_serializer: ShapeSerializer = self._payload_codec.create_serializer(
payload
)
header_serializer = EventHeaderSerializer(headers_encoder)

if not self._has_payload_member(schema):
with payload_serializer.begin_struct(schema) as body_serializer:
yield EventStreamBindingSerializer(header_serializer, body_serializer)
else:
yield EventStreamBindingSerializer(header_serializer, payload_serializer)

self._result = EventMessage(
headers_bytes=headers_encoder.get_result(), payload=payload.getvalue()
)

def _has_payload_member(self, schema: "Schema") -> bool:
for member in schema.members.values():
if EVENT_PAYLOAD_TRAIT in member.traits:
return True
return False


class EventHeaderSerializer(SpecificShapeSerializer):

def __init__(self, encoder: EventHeaderEncoder) -> None:
self._encoder = encoder

def _invalid_state(
self, schema: "Schema | None" = None, message: str | None = None
) -> Never:
if message is None:
message = f"Invalid header value type: {schema}"
raise InvalidHeaderValue(message)

def write_boolean(self, schema: "Schema", value: bool) -> None:
self._encoder.encode_boolean(schema.expect_member_name(), value)

def write_byte(self, schema: "Schema", value: int) -> None:
self._encoder.encode_byte(schema.expect_member_name(), value)

def write_short(self, schema: "Schema", value: int) -> None:
self._encoder.encode_short(schema.expect_member_name(), value)

def write_integer(self, schema: "Schema", value: int) -> None:
self._encoder.encode_integer(schema.expect_member_name(), value)

def write_long(self, schema: "Schema", value: int) -> None:
self._encoder.encode_long(schema.expect_member_name(), value)

def write_string(self, schema: "Schema", value: str) -> None:
self._encoder.encode_string(schema.expect_member_name(), value)

def write_blob(self, schema: "Schema", value: bytes) -> None:
self._encoder.encode_blob(schema.expect_member_name(), value)

def write_timestamp(self, schema: "Schema", value: datetime.datetime) -> None:
self._encoder.encode_timestamp(schema.expect_member_name(), value)


class EventStreamBindingSerializer(InterceptingSerializer):
def __init__(
self,
header_serializer: EventHeaderSerializer,
payload_serializer: ShapeSerializer,
) -> None:
self._header_serializer = header_serializer
self._payload_serializer = payload_serializer

def before(self, schema: "Schema") -> ShapeSerializer:
if EVENT_HEADER_TRAIT in schema.traits:
return self._header_serializer
return self._payload_serializer

def after(self, schema: "Schema") -> None:
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from smithy_core.shapes import ShapeID

EVENT_HEADER_TRAIT = ShapeID("smithy.api#eventHeader")
EVENT_PAYLOAD_TRAIT = ShapeID("smithy.api#eventPayload")
ERROR_TRAIT = ShapeID("smithy.api#error")
Loading

0 comments on commit 0f38929

Please sign in to comment.