diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/__init__.py b/python-packages/aws-event-stream/aws_event_stream/_private/__init__.py index e69de29b..75f7768a 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/__init__.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/__init__.py @@ -0,0 +1,5 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +INITIAL_REQUEST_EVENT_TYPE = "initial-request" +INITIAL_RESPONSE_EVENT_TYPE = "initial-response" diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py index 1012e4e4..1f4ca8a0 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py @@ -10,9 +10,12 @@ from smithy_core.utils import expect_type from ..events import HEADERS_DICT, Event -from ..exceptions import EventError, UnexpectedEventError +from ..exceptions import EventError, UnmodeledEventError +from . import INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE from .traits import EVENT_HEADER_TRAIT, EVENT_PAYLOAD_TRAIT +INITIAL_MESSAGE_TYPES = (INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE) + class EventDeserializer(SpecificShapeDeserializer): def __init__( @@ -29,14 +32,23 @@ def read_struct( ) -> None: event = Event.decode(self._source) headers = event.message.headers - message_deserializer = EventMessageDeserializer( - headers, self._payload_codec.create_deserializer(event.message.payload) - ) + + payload_deserializer = None + if event.message.payload: + payload_deserializer = self._payload_codec.create_deserializer( + event.message.payload + ) + + message_deserializer = EventMessageDeserializer(headers, payload_deserializer) match headers.get(":message-type"): case "event": member_name = expect_type(str, headers[":event-type"]) - consumer(schema.members[member_name], message_deserializer) + if member_name in INITIAL_MESSAGE_TYPES: + # If it's an initial message, skip straight to deserialization. + message_deserializer.read_struct(schema, consumer) + else: + consumer(schema.members[member_name], message_deserializer) case "exception": member_name = expect_type(str, headers[":exception-type"]) consumer(schema.members[member_name], message_deserializer) @@ -44,7 +56,7 @@ def read_struct( # The `application/vnd.amazon.eventstream` format allows for explicitly # unmodeled exceptions. These exceptions MUST have the `:error-code` # and `:error-message` headers set, and they MUST be strings. - raise UnexpectedEventError( + raise UnmodeledEventError( expect_type(str, headers[":error-code"]), expect_type(str, headers[":error-message"]), ) @@ -54,7 +66,7 @@ def read_struct( class EventMessageDeserializer(SpecificShapeDeserializer): def __init__( - self, headers: HEADERS_DICT, payload_deserializer: ShapeDeserializer + self, headers: HEADERS_DICT, payload_deserializer: ShapeDeserializer | None ) -> None: self._headers = headers self._payload_deserializer = payload_deserializer @@ -70,10 +82,11 @@ def read_struct( if member_schema is not None and EVENT_HEADER_TRAIT in member_schema.traits: consumer(member_schema, headers_deserializer) - if (payload_member := self._get_payload_member(schema)) is not None: - consumer(payload_member, self._payload_deserializer) - else: - self._payload_deserializer.read_struct(schema, consumer) + if self._payload_deserializer: + if (payload_member := self._get_payload_member(schema)) is not None: + consumer(payload_member, self._payload_deserializer) + else: + self._payload_deserializer.read_struct(schema, consumer) def _get_payload_member(self, schema: "Schema") -> "Schema | None": for member in schema.members.values(): diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py b/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py index 06612169..12d919e5 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py @@ -13,13 +13,21 @@ ShapeSerializer, SpecificShapeSerializer, ) +from smithy_core.shapes import ShapeType +from smithy_core.utils import expect_type from ..events import EventHeaderEncoder, EventMessage from ..exceptions import InvalidHeaderValue -from .traits import ERROR_TRAIT, EVENT_HEADER_TRAIT, EVENT_PAYLOAD_TRAIT +from . import INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE +from .traits import ( + ERROR_TRAIT, + EVENT_HEADER_TRAIT, + EVENT_PAYLOAD_TRAIT, + MEDIA_TYPE_TRAIT, +) -_INITIAL_REQUEST_EVENT_TYPE = "initial-request" -_INITIAL_RESPONSE_EVENT_TYPE = "initial-response" +_DEFAULT_STRING_CONTENT_TYPE = "text/plain" +_DEFAULT_BLOB_CONTENT_TYPE = "application/octet-stream" class EventSerializer(SpecificShapeSerializer): @@ -31,15 +39,29 @@ def __init__( self._payload_codec = payload_codec self._result: EventMessage | None = None if is_client_mode: - self._initial_message_event_type = _INITIAL_REQUEST_EVENT_TYPE + self._initial_message_event_type = INITIAL_REQUEST_EVENT_TYPE else: - self._initial_message_event_type = _INITIAL_RESPONSE_EVENT_TYPE + self._initial_message_event_type = INITIAL_RESPONSE_EVENT_TYPE def get_result(self) -> EventMessage | None: return self._result @contextmanager def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]: + # Event stream definitions are unions. Nothing about the union shape actually + # matters for the purposes of event stream serialization though, we just care + # about the specific member we're serializing. So here we yield immediately, + # and the next time this method is called it'll be by the member that we + # actually care about. + # + # Note that if we're serializing an operation input or output, it won't be a + # union at all, so this won't get triggered. Thankfully, that's what we want. + if schema.shape_type is ShapeType.UNION: + try: + yield self + finally: + return + headers_encoder = EventHeaderEncoder() if ERROR_TRAIT in schema.traits: @@ -65,21 +87,40 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]: ) header_serializer = EventHeaderSerializer(headers_encoder) - if not self._has_payload_member(schema): + media_type = self._payload_codec.media_type + + if (payload_member := self._get_payload_member(schema)) is not None: + media_type = self._get_payload_media_type(payload_member, media_type) + yield EventStreamBindingSerializer(header_serializer, payload_serializer) + else: with payload_serializer.begin_struct(schema) as body_serializer: yield EventStreamBindingSerializer(header_serializer, body_serializer) - else: - yield EventStreamBindingSerializer(header_serializer, payload_serializer) + + payload_bytes = payload.getvalue() + if payload_bytes: + headers_encoder.encode_string(":content-type", media_type) self._result = EventMessage( - headers_bytes=headers_encoder.get_result(), payload=payload.getvalue() + headers_bytes=headers_encoder.get_result(), payload=payload_bytes ) - def _has_payload_member(self, schema: "Schema") -> bool: + def _get_payload_member(self, schema: Schema) -> Schema | None: for member in schema.members.values(): if EVENT_PAYLOAD_TRAIT in member.traits: - return True - return False + return schema + return None + + def _get_payload_media_type(self, schema: Schema, default: str) -> str: + if (media_type := schema.traits.get(MEDIA_TYPE_TRAIT)) is not None: + return expect_type(str, media_type.value) + + match schema.shape_type: + case ShapeType.STRING: + return _DEFAULT_STRING_CONTENT_TYPE + case ShapeType.BLOB: + return _DEFAULT_BLOB_CONTENT_TYPE + case _: + return default class EventHeaderSerializer(SpecificShapeSerializer): @@ -129,6 +170,7 @@ def __init__( self._payload_serializer = payload_serializer def before(self, schema: "Schema") -> ShapeSerializer: + print(f"FOUND TRAITS: {schema.traits}") if EVENT_HEADER_TRAIT in schema.traits: return self._header_serializer return self._payload_serializer diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/traits.py b/python-packages/aws-event-stream/aws_event_stream/_private/traits.py index 0acfd969..73812628 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/traits.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/traits.py @@ -6,3 +6,4 @@ EVENT_HEADER_TRAIT = ShapeID("smithy.api#eventHeader") EVENT_PAYLOAD_TRAIT = ShapeID("smithy.api#eventPayload") ERROR_TRAIT = ShapeID("smithy.api#error") +MEDIA_TYPE_TRAIT = ShapeID("smithy.api#mediaType") diff --git a/python-packages/aws-event-stream/aws_event_stream/events.py b/python-packages/aws-event-stream/aws_event_stream/events.py index 0120fa25..50da440b 100644 --- a/python-packages/aws-event-stream/aws_event_stream/events.py +++ b/python-packages/aws-event-stream/aws_event_stream/events.py @@ -8,6 +8,7 @@ """ import datetime +import struct import uuid from binascii import crc32 from collections.abc import Callable, Iterator, Mapping @@ -25,6 +26,7 @@ InvalidHeadersLength, InvalidHeaderValue, InvalidHeaderValueLength, + InvalidIntegerValue, InvalidPayloadLength, ) @@ -32,6 +34,10 @@ MAX_HEADER_VALUE_BYTE_LENGTH = 32 * 1024 - 1 # 32Kb MAX_PAYLOAD_LENGTH = 16 * 1024**2 # 16 Mb +# In addition to the header length and payload length, the total length of the +# message includes 12 bytes for the prelude and 4 bytes for the trailing crc. +_MESSAGE_METADATA_SIZE = 16 + class Byte(int): """An 8-bit integer. @@ -186,16 +192,16 @@ def __init__( ) -> None: """Initialize an EventMessage. - :param headers: The headers present in the event message. This parameter or - `headers_bytes` MUST be specified. If this parameter is unspecified, the - `headers_bytes` parameter will be decoded. + :param headers: The headers present in the event message. If this parameter is + unspecified, the default value will be the decoded value of the + `headers_bytes` parameter. Sized integer values may be indicated for the purpose of serialization using the `Byte`, `Short`, or `Long` types. int values of unspecified size will be assumed to be 32-bit. :param headers_bytes: The serialized bytes of the headers present in the event - message. This parameter or `headers` MUST be specified. + message. :param payload: The serialized bytes of the message payload. """ @@ -207,7 +213,7 @@ def __init__( if headers_bytes is None: if headers is None: - raise ValueError("One of headers or headers_bytes must be set.") + headers = {} elif headers is None: headers = EventHeaderDecoder(headers_bytes).decode_headers() @@ -239,7 +245,7 @@ def _get_headers_bytes(self) -> bytes: return self._headers_bytes def encode(self) -> bytes: - return EventEncoder().encode_bytes( + return _EventEncoder().encode_bytes( headers=self._get_headers_bytes(), payload=self._payload ) @@ -290,7 +296,7 @@ def decode(cls, source: BytesReader) -> Self: prelude_bytes = source.read(8) prelude_crc_bytes = source.read(4) - prelude_crc: int = DecodeUtils.unpack_uint32(prelude_crc_bytes)[0] + prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0] total_length, headers_length = unpack("!II", prelude_bytes) _validate_checksum(prelude_bytes, prelude_crc) @@ -298,9 +304,8 @@ def decode(cls, source: BytesReader) -> Self: total_length=total_length, headers_length=headers_length, crc=prelude_crc ) - print(f"TOTAL LENGTH: {total_length}\nMESSAGE LENGTH: {total_length - 16}\nMAX {MAX_PAYLOAD_LENGTH}") - message_bytes = source.read(total_length - 16) - crc: int = DecodeUtils.unpack_uint32(source.read(4))[0] + message_bytes = source.read(total_length - _MESSAGE_METADATA_SIZE) + crc: int = _DecodeUtils.unpack_uint32(source.read(4))[0] _validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc) message = EventMessage( @@ -310,10 +315,10 @@ def decode(cls, source: BytesReader) -> Self: return cls(prelude, message, crc) -class EventEncoder: +class _EventEncoder: """A utility class that encodes message bytes into binary events.""" - def encode_bytes(self, headers: bytes, payload: bytes) -> bytes: + def encode_bytes(self, *, headers: bytes = b"", payload: bytes = b"") -> bytes: """Encode message bytes into a binary event. :param headers: The bytes representing the event headers. @@ -336,10 +341,8 @@ def encode_bytes(self, headers: bytes, payload: bytes) -> bytes: return prelude_bytes + message_bytes + final_crc_bytes def _encode_prelude_bytes(self, headers: bytes, payload: bytes) -> bytes: - # In addition to the header length and payload length, the total length of the - # message includes 12 bytes for the prelude and 4 bytes for the trailing crc. header_length = len(headers) - total_length = header_length + len(payload) + 16 + total_length = header_length + len(payload) + _MESSAGE_METADATA_SIZE return pack("!II", total_length, header_length) def _calculate_checksum(self, data: bytes, crc: int = 0) -> int: @@ -425,7 +428,10 @@ def encode_byte(self, key: str, value: int) -> None: """ self._encode_key(key) self._buffer.write(b"\x02") - self._buffer.write(pack("!b", value)) + try: + self._buffer.write(pack("!b", value)) + except struct.error as e: + raise InvalidIntegerValue("byte", value) from e def encode_short(self, key: str, value: int) -> None: """Encode a 16-bit int header. @@ -435,7 +441,10 @@ def encode_short(self, key: str, value: int) -> None: """ self._encode_key(key) self._buffer.write(b"\x03") - self._buffer.write(pack("!h", value)) + try: + self._buffer.write(pack("!h", value)) + except struct.error as e: + raise InvalidIntegerValue("short", value) from e def encode_integer(self, key: str, value: int) -> None: """Encode a 32-bit int header. @@ -445,7 +454,10 @@ def encode_integer(self, key: str, value: int) -> None: """ self._encode_key(key) self._buffer.write(b"\x04") - self._buffer.write(pack("!i", value)) + try: + self._buffer.write(pack("!i", value)) + except struct.error as e: + raise InvalidIntegerValue("integer", value) from e def encode_long(self, key: str, value: int) -> None: """Encode a 64-bit int header. @@ -455,7 +467,10 @@ def encode_long(self, key: str, value: int) -> None: """ self._encode_key(key) self._buffer.write(b"\x05") - self._buffer.write(pack("!q", value)) + try: + self._buffer.write(pack("!q", value)) + except struct.error as e: + raise InvalidIntegerValue("long", value) from e def encode_blob(self, key: str, value: bytes) -> None: """Encode a binary header. @@ -519,7 +534,7 @@ def encode_uuid(self, key: str, value: uuid.UUID): _ArraySize = Literal[1] | Literal[2] | Literal[4] -class DecodeUtils: +class _DecodeUtils: """Unpacking utility functions used in the decoder. All methods on this class take raw bytes and return a tuple containing the value @@ -548,7 +563,7 @@ def unpack_uint8(data: BytesLike) -> tuple[int, int]: :param data: The bytes to parse from. :returns: A tuple containing the (parsed integer value, bytes consumed) """ - value = unpack(DecodeUtils.UINT8_BYTE_FORMAT, data[:1])[0] + value = unpack(_DecodeUtils.UINT8_BYTE_FORMAT, data[:1])[0] return value, 1 @staticmethod @@ -558,7 +573,7 @@ def unpack_uint32(data: BytesLike) -> tuple[int, int]: :param data: The bytes to parse from. :returns: A tuple containing the (parsed integer value, bytes consumed) """ - value = unpack(DecodeUtils.UINT32_BYTE_FORMAT, data[:4])[0] + value = unpack(_DecodeUtils.UINT32_BYTE_FORMAT, data[:4])[0] return value, 4 @staticmethod @@ -568,7 +583,7 @@ def unpack_int8(data: BytesLike): :param data: The bytes to parse from. :returns: A tuple containing the (parsed integer value, bytes consumed) """ - value = unpack(DecodeUtils.INT8_BYTE_FORMAT, data[:1])[0] + value = unpack(_DecodeUtils.INT8_BYTE_FORMAT, data[:1])[0] return value, 1 @staticmethod @@ -578,7 +593,7 @@ def unpack_int16(data: BytesLike) -> tuple[int, int]: :param data: The bytes to parse from. :returns: A tuple containing the (parsed integer value, bytes consumed) """ - value = unpack(DecodeUtils.INT16_BYTE_FORMAT, data[:2])[0] + value = unpack(_DecodeUtils.INT16_BYTE_FORMAT, data[:2])[0] return value, 2 @staticmethod @@ -588,7 +603,7 @@ def unpack_int32(data: BytesLike) -> tuple[int, int]: :param data: The bytes to parse from. :returns: A tuple containing the (parsed integer value, bytes consumed) """ - value = unpack(DecodeUtils.INT32_BYTE_FORMAT, data[:4])[0] + value = unpack(_DecodeUtils.INT32_BYTE_FORMAT, data[:4])[0] return value, 4 @staticmethod @@ -598,7 +613,7 @@ def unpack_int64(data: BytesLike) -> tuple[int, int]: :param data: The bytes to parse from. :returns: A tuple containing the (parsed integer value, bytes consumed) """ - value = unpack(DecodeUtils.INT64_BYTE_FORMAT, data[:8])[0] + value = unpack(_DecodeUtils.INT64_BYTE_FORMAT, data[:8])[0] return value, 8 @staticmethod @@ -617,9 +632,8 @@ def unpack_byte_array( represents the length of the array. Supported values are 1, 2, and 4. :returns: A tuple containing the (parsed bytes, bytes consumed) """ - uint_byte_format = DecodeUtils.UINT_BYTE_FORMAT[length_byte_size] + uint_byte_format = _DecodeUtils.UINT_BYTE_FORMAT[length_byte_size] length = unpack(uint_byte_format, data[:length_byte_size])[0] - print(f"HEADER LENGTH: {length}") if length > MAX_HEADER_VALUE_BYTE_LENGTH: raise InvalidHeaderValueLength(length) bytes_end = length + length_byte_size @@ -643,7 +657,7 @@ def unpack_utf8_string( represents the length of the array. Supported values are 1, 2, and 4. :returns: A tuple containing the (parsed string, bytes consumed) """ - array_bytes, consumed = DecodeUtils.unpack_byte_array(data, length_byte_size) + array_bytes, consumed = _DecodeUtils.unpack_byte_array(data, length_byte_size) return array_bytes.decode("utf-8"), consumed @staticmethod @@ -656,7 +670,7 @@ def unpack_timestamp(data: BytesLike) -> tuple[datetime.datetime, int]: :param data: The bytes to parse from. :returns: A tuple containing the (datetime.datetime, bytes consumed). """ - int_value, consumed = DecodeUtils.unpack_int64(data) + int_value, consumed = _DecodeUtils.unpack_int64(data) timestamp_value = TimestampFormat.EPOCH_SECONDS.deserialize(int_value / 1000) return timestamp_value, consumed @@ -680,14 +694,14 @@ class EventHeaderDecoder(Iterator[tuple[str, HEADER_VALUE]]): # can just return static values. 0: lambda b: (True, 0), # boolean_true 1: lambda b: (False, 0), # boolean_false - 2: DecodeUtils.unpack_int8, # byte - 3: DecodeUtils.unpack_int16, # short - 4: DecodeUtils.unpack_int32, # integer - 5: DecodeUtils.unpack_int64, # long - 6: DecodeUtils.unpack_byte_array, # byte_array - 7: DecodeUtils.unpack_utf8_string, # string - 8: DecodeUtils.unpack_timestamp, # timestamp - 9: DecodeUtils.unpack_uuid, # uuid + 2: _DecodeUtils.unpack_int8, # byte + 3: _DecodeUtils.unpack_int16, # short + 4: _DecodeUtils.unpack_int32, # integer + 5: _DecodeUtils.unpack_int64, # long + 6: _DecodeUtils.unpack_byte_array, # byte_array + 7: _DecodeUtils.unpack_utf8_string, # string + 8: _DecodeUtils.unpack_timestamp, # timestamp + 9: _DecodeUtils.unpack_uuid, # uuid } def __init__(self, header_bytes: BytesLike) -> None: @@ -719,10 +733,10 @@ def decode_header(self) -> tuple[str, HEADER_VALUE]: :returns: A single key-value pair read from the source. """ - key, consumed = DecodeUtils.unpack_utf8_string(self._data, 1) + key, consumed = _DecodeUtils.unpack_utf8_string(self._data, 1) self._advance_data(consumed) - type, consumed = DecodeUtils.unpack_uint8(self._data) + type, consumed = _DecodeUtils.unpack_uint8(self._data) self._advance_data(consumed) value_unpacker = self._HEADER_TYPE_MAP[type] diff --git a/python-packages/aws-event-stream/aws_event_stream/exceptions.py b/python-packages/aws-event-stream/aws_event_stream/exceptions.py index 969baff1..8c5dc083 100644 --- a/python-packages/aws-event-stream/aws_event_stream/exceptions.py +++ b/python-packages/aws-event-stream/aws_event_stream/exceptions.py @@ -1,6 +1,5 @@ """Binary Event Stream support for the application/vnd.amazon.eventstream format.""" -from dataclasses import dataclass from typing import Any from smithy_core.exceptions import SmithyException @@ -10,13 +9,26 @@ class EventError(SmithyException): - pass + """Base error for all errors thrown during event stream handling.""" -@dataclass -class UnexpectedEventError(EventError): +class UnmodeledEventError(EventError): + """Unmodeled event error was read from the event stream. + + These classes of errors tend to be internal server errors or other unexpected errors + on the service side. + """ + code: str + """A code identifying the class of error.""" + message: str + """The explanation of the error sent over the event stream.""" + + def __init__(self, code: str, message: str) -> None: + self.code = code + self.message = message + super().__init__(f"Received unmodeled event error: {code} - {message}") class DuplicateHeader(EventError): @@ -69,3 +81,13 @@ class ChecksumMismatch(EventError): def __init__(self, expected: int, calculated: int): message = f"Checksum mismatch: expected 0x{expected:08x}, calculated 0x{calculated:08x}" super().__init__(message) + + +class InvalidIntegerValue(EventError): + def __init__(self, size: str, value: int): + message = ( + f"Invalid {size} value: {value}. The Byte, Short, and Long types may be " + f"used to specify the size of the int. Unspecified ints are assumed to " + f"be 32-bit." + ) + super().__init__(message) diff --git a/python-packages/aws-event-stream/tests/unit/_private/__init__.py b/python-packages/aws-event-stream/tests/unit/_private/__init__.py new file mode 100644 index 00000000..2cd05b57 --- /dev/null +++ b/python-packages/aws-event-stream/tests/unit/_private/__init__.py @@ -0,0 +1,571 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +import datetime +from typing import Any, Self, ClassVar, Literal + +from aws_event_stream.events import EventMessage, Short, Byte, Long + +from smithy_core.serializers import ShapeSerializer +from smithy_core.deserializers import ShapeDeserializer +from smithy_core.schemas import Schema +from smithy_core.exceptions import SmithyException +from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.traits import Trait +from smithy_core.prelude import ( + BOOLEAN, + BYTE, + SHORT, + INTEGER, + LONG, + BLOB, + STRING, + TIMESTAMP, +) + + +EVENT_HEADER_TRAIT = Trait(id=ShapeID("smithy.api#eventHeader")) +EVENT_PAYLOAD_TRAIT = Trait(id=ShapeID("smithy.api#eventPayload")) +ERROR_TRAIT = Trait(id=ShapeID("smithy.api#error"), value="client") +REQUIRED_TRAIT = Trait(id=ShapeID("smithy.api#required")) +STREAMING_TRAIT = Trait(id=ShapeID("smith.api#streaming")) + + +SCHEMA_MESSAGE_EVENT = Schema.collection( + id=ShapeID("smithy.example#MessageEvent"), + members={ + "boolHeader": {"index": 0, "target": BOOLEAN, "traits": [EVENT_HEADER_TRAIT]}, + "byteHeader": {"index": 1, "target": BYTE, "traits": [EVENT_HEADER_TRAIT]}, + "shortHeader": {"index": 2, "target": SHORT, "traits": [EVENT_HEADER_TRAIT]}, + "intHeader": {"index": 3, "target": INTEGER, "traits": [EVENT_HEADER_TRAIT]}, + "longHeader": {"index": 4, "target": LONG, "traits": [EVENT_HEADER_TRAIT]}, + "blobHeader": {"index": 5, "target": BLOB, "traits": [EVENT_HEADER_TRAIT]}, + "stringHeader": {"index": 6, "target": STRING, "traits": [EVENT_HEADER_TRAIT]}, + "timestampHeader": { + "index": 7, + "target": TIMESTAMP, + "traits": [EVENT_HEADER_TRAIT], + }, + "bodyMember": {"index": 8, "target": STRING}, + }, +) + +SCHEMA_PAYLOAD_EVENT = Schema.collection( + id=ShapeID("smithy.example#PayloadEvent"), + members={ + "header": { + "index": 0, + "target": STRING, + "traits": [EVENT_HEADER_TRAIT, REQUIRED_TRAIT], + }, + "payload": { + "index": 1, + "target": STRING, + "traits": [EVENT_PAYLOAD_TRAIT, REQUIRED_TRAIT], + }, + }, +) + +SCHEMA_ERROR_EVENT = Schema.collection( + id=ShapeID("smithy.example#ErrorEvent"), + members={"message": {"index": 0, "target": STRING, "traits": [REQUIRED_TRAIT]}}, + traits=[ERROR_TRAIT], +) + +SCHEMA_EVENT_STREAM = Schema.collection( + id=ShapeID("smithy.example#EventStream"), + shape_type=ShapeType.UNION, + traits=[STREAMING_TRAIT], + members={ + "message": {"index": 0, "target": SCHEMA_MESSAGE_EVENT}, + "payload": {"index": 1, "target": SCHEMA_PAYLOAD_EVENT}, + "error": {"index": 2, "target": SCHEMA_ERROR_EVENT}, + }, +) + +SCHEMA_INITIAL_MESSAGE = Schema.collection( + id=ShapeID("smithy.example#EventStreamOperationInputOutput"), + members={ + "message": {"index": 0, "target": STRING}, + # Event stream members will not be part of the operation input / output + # shape schemas. + # "stream": { + # "index": 1, + # "target": SCHEMA_EVENT_STREAM + # }, + }, +) + + +@dataclass +class MessageEvent: + bool_header: bool | None = None + byte_header: int | None = None + short_header: int | None = None + int_header: int | None = None + long_header: int | None = None + blob_header: bytes | None = None + string_header: str | None = None + timestamp_header: datetime.datetime | None = None + body_member: str | None = None + + def serialize(self, serializer: ShapeSerializer): + with serializer.begin_struct(SCHEMA_MESSAGE_EVENT) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + if self.bool_header is not None: + serializer.write_boolean( + SCHEMA_MESSAGE_EVENT.members["boolHeader"], self.bool_header + ) + + if self.byte_header is not None: + serializer.write_byte( + SCHEMA_MESSAGE_EVENT.members["byteHeader"], self.byte_header + ) + + if self.short_header is not None: + serializer.write_short( + SCHEMA_MESSAGE_EVENT.members["shortHeader"], self.short_header + ) + + if self.int_header is not None: + serializer.write_integer( + SCHEMA_MESSAGE_EVENT.members["intHeader"], self.int_header + ) + + if self.long_header is not None: + serializer.write_long( + SCHEMA_MESSAGE_EVENT.members["longHeader"], self.long_header + ) + + if self.blob_header is not None: + serializer.write_blob( + SCHEMA_MESSAGE_EVENT.members["blobHeader"], self.blob_header + ) + + if self.string_header is not None: + serializer.write_string( + SCHEMA_MESSAGE_EVENT.members["stringHeader"], self.string_header + ) + + if self.timestamp_header is not None: + serializer.write_timestamp( + SCHEMA_MESSAGE_EVENT.members["timestampHeader"], self.timestamp_header + ) + + if self.body_member is not None: + serializer.write_string( + SCHEMA_MESSAGE_EVENT.members["bodyMember"], self.body_member + ) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["bool_header"] = de.read_boolean( + SCHEMA_MESSAGE_EVENT.members["boolHeader"] + ) + + case 1: + kwargs["byte_header"] = de.read_byte( + SCHEMA_MESSAGE_EVENT.members["byteHeader"] + ) + + case 2: + kwargs["short_header"] = de.read_short( + SCHEMA_MESSAGE_EVENT.members["shortHeader"] + ) + + case 3: + kwargs["int_header"] = de.read_integer( + SCHEMA_MESSAGE_EVENT.members["intHeader"] + ) + + case 4: + kwargs["long_header"] = de.read_long( + SCHEMA_MESSAGE_EVENT.members["longHeader"] + ) + + case 5: + kwargs["blob_header"] = de.read_blob( + SCHEMA_MESSAGE_EVENT.members["blobHeader"] + ) + + case 6: + kwargs["string_header"] = de.read_string( + SCHEMA_MESSAGE_EVENT.members["stringHeader"] + ) + + case 7: + kwargs["timestamp_header"] = de.read_timestamp( + SCHEMA_MESSAGE_EVENT.members["timestampHeader"] + ) + + case 8: + kwargs["body_member"] = de.read_string( + SCHEMA_MESSAGE_EVENT.members["bodyMember"] + ) + + case _: + raise SmithyException(f"Unexpected member schema: {schema}") + + deserializer.read_struct(schema=SCHEMA_MESSAGE_EVENT, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class EventStreamMessageEvent: + value: MessageEvent + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(SCHEMA_EVENT_STREAM, self) + + def serialize_members(self, serializer: ShapeSerializer): + serializer.write_struct(SCHEMA_EVENT_STREAM.members["message"], self.value) + + +@dataclass +class PayloadEvent: + header: str + payload: str + + def serialize(self, serializer: ShapeSerializer): + with serializer.begin_struct(SCHEMA_PAYLOAD_EVENT) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_string(SCHEMA_PAYLOAD_EVENT.members["header"], self.header) + serializer.write_string(SCHEMA_PAYLOAD_EVENT.members["payload"], self.payload) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["header"] = de.read_string( + SCHEMA_PAYLOAD_EVENT.members["header"] + ) + case 1: + kwargs["payload"] = de.read_string( + SCHEMA_PAYLOAD_EVENT.members["payload"] + ) + case _: + raise SmithyException(f"Unexpected member schema: {schema}") + + deserializer.read_struct(schema=SCHEMA_PAYLOAD_EVENT, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class EventStreamPayloadEvent: + value: PayloadEvent + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(SCHEMA_EVENT_STREAM, self) + + def serialize_members(self, serializer: ShapeSerializer): + serializer.write_struct(SCHEMA_EVENT_STREAM.members["payload"], self.value) + + +@dataclass +class ErrorEvent: + code: ClassVar[str] = "NoSuchResource" + fault: ClassVar[Literal["client", "server"]] = "client" + + message: str + + def serialize(self, serializer: ShapeSerializer): + with serializer.begin_struct(SCHEMA_ERROR_EVENT) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_string(SCHEMA_ERROR_EVENT.members["message"], self.message) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["message"] = de.read_string( + SCHEMA_ERROR_EVENT.members["message"] + ) + case _: + raise SmithyException(f"Unexpected member schema: {schema}") + + deserializer.read_struct(schema=SCHEMA_ERROR_EVENT, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class EventStreamErrorEvent: + value: ErrorEvent + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(SCHEMA_EVENT_STREAM, self) + + def serialize_members(self, serializer: ShapeSerializer): + serializer.write_struct(SCHEMA_EVENT_STREAM.members["error"], self.value) + + +@dataclass +class EventStreamUnknownEvent: + tag: str + + def serialize(self, serializer: ShapeSerializer): + raise SmithyException("Unknown union variants may not be serialized.") + + def serialize_members(self, serializer: ShapeSerializer): + raise SmithyException("Unknown union variants may not be serialized.") + + +type EventStream = EventStreamMessageEvent | EventStreamPayloadEvent | EventStreamErrorEvent | EventStreamUnknownEvent + + +class EventStreamDeserializer: + _result: EventStream | None = None + + def deserialize(self, deserializer: ShapeDeserializer) -> EventStream: + self._result = None + deserializer.read_struct(SCHEMA_EVENT_STREAM, self._consumer) + + if self._result is None: + raise SmithyException("Unions must have exactly one value, but found none.") + + return self._result + + def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + self._set_result(EventStreamMessageEvent(MessageEvent.deserialize(de))) + + case 1: + self._set_result(EventStreamPayloadEvent(PayloadEvent.deserialize(de))) + + case 2: + self._set_result(EventStreamErrorEvent(ErrorEvent.deserialize(de))) + + case _: + raise SmithyException(f"Unexpected member schema: {schema}") + + def _set_result(self, value: EventStream) -> None: + if self._result is not None: + raise SmithyException( + "Unions must have exactly one value, but found more than one." + ) + self._result = value + + +@dataclass +class EventStreamOperationInputOutput: + message: str + + def serialize(self, serializer: ShapeSerializer): + with serializer.begin_struct(SCHEMA_INITIAL_MESSAGE) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_string(SCHEMA_INITIAL_MESSAGE.members["message"], self.message) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["message"] = de.read_string( + SCHEMA_INITIAL_MESSAGE.members["message"] + ) + case _: + raise SmithyException(f"Unexpected member schema: {schema}") + + deserializer.read_struct(schema=SCHEMA_INITIAL_MESSAGE, consumer=_consumer) + return cls(**kwargs) + + +EVENT_STREAM_SERDE_CASES = [ + ( + EventStreamMessageEvent(MessageEvent(bool_header=True)), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "message", + "boolHeader": True, + ":content-type": "application/json", + }, + payload=b"{}", + ), + ), + ( + EventStreamMessageEvent(MessageEvent(bool_header=False)), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "message", + "boolHeader": False, + ":content-type": "application/json", + }, + payload=b"{}", + ), + ), + ( + EventStreamMessageEvent(MessageEvent(byte_header=1)), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "message", + "byteHeader": Byte(1), + ":content-type": "application/json", + }, + payload=b"{}", + ), + ), + ( + EventStreamMessageEvent(MessageEvent(short_header=1)), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "message", + "shortHeader": Short(1), + ":content-type": "application/json", + }, + payload=b"{}", + ), + ), + ( + EventStreamMessageEvent(MessageEvent(int_header=1)), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "message", + "intHeader": 1, + ":content-type": "application/json", + }, + payload=b"{}", + ), + ), + ( + EventStreamMessageEvent(MessageEvent(long_header=1)), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "message", + "longHeader": Long(1), + ":content-type": "application/json", + }, + payload=b"{}", + ), + ), + ( + EventStreamMessageEvent(MessageEvent(blob_header=b"blob")), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "message", + "blobHeader": b"blob", + ":content-type": "application/json", + }, + payload=b"{}", + ), + ), + ( + EventStreamMessageEvent(MessageEvent(string_header="string")), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "message", + "stringHeader": "string", + ":content-type": "application/json", + }, + payload=b"{}", + ), + ), + ( + EventStreamMessageEvent( + MessageEvent( + timestamp_header=datetime.datetime( + 1970, 1, 1, 0, 0, 0, 8000, tzinfo=datetime.UTC + ) + ) + ), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "message", + "timestampHeader": datetime.datetime( + 1970, 1, 1, 0, 0, 0, 8000, tzinfo=datetime.UTC + ), + ":content-type": "application/json", + }, + payload=b"{}", + ), + ), + ( + EventStreamMessageEvent(MessageEvent(body_member="body")), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "message", + ":content-type": "application/json", + }, + payload=b'{"bodyMember":"body"}', + ), + ), + ( + EventStreamPayloadEvent(PayloadEvent(header="header", payload="payload")), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "payload", + "header": "header", + ":content-type": "application/json", + }, + payload=b'"payload"', + ), + ), + ( + EventStreamErrorEvent(ErrorEvent(message="error message")), + EventMessage( + headers={ + ":message-type": "exception", + ":exception-type": "error", + ":content-type": "application/json", + }, + payload=b'{"message":"error message"}', + ), + ), +] + + +INITIAL_REQUEST_CASE = ( + EventStreamOperationInputOutput(message="The initial request!"), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "initial-request", + ":content-type": "application/json", + }, + payload=b'{"message":"The initial request!"}', + ), +) + + +INITIAL_RESPONSE_CASE = ( + EventStreamOperationInputOutput(message="The initial response!"), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "initial-response", + ":content-type": "application/json", + }, + payload=b'{"message":"The initial response!"}', + ), +) diff --git a/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py b/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py new file mode 100644 index 00000000..9dfa84df --- /dev/null +++ b/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py @@ -0,0 +1,58 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from io import BytesIO + +import pytest +from smithy_core.deserializers import DeserializeableShape +from smithy_json import JSONCodec + +from aws_event_stream._private.deserializers import EventDeserializer +from aws_event_stream.events import EventMessage +from aws_event_stream.exceptions import UnmodeledEventError + +from . import ( + EVENT_STREAM_SERDE_CASES, + INITIAL_REQUEST_CASE, + INITIAL_RESPONSE_CASE, + EventStreamDeserializer, + EventStreamOperationInputOutput, +) + + +@pytest.mark.parametrize("expected,given", EVENT_STREAM_SERDE_CASES) +def test_event_deserializer(expected: DeserializeableShape, given: EventMessage): + source = BytesIO(given.encode()) + deserializer = EventDeserializer(source=source, payload_codec=JSONCodec()) + result = EventStreamDeserializer().deserialize(deserializer) + assert result == expected + + +def test_deserialize_initial_request(): + expected, given = INITIAL_REQUEST_CASE + source = BytesIO(given.encode()) + deserializer = EventDeserializer(source=source, payload_codec=JSONCodec()) + result = EventStreamOperationInputOutput.deserialize(deserializer) + assert result == expected + + +def test_deserialize_initial_response(): + expected, given = INITIAL_RESPONSE_CASE + source = BytesIO(given.encode()) + deserializer = EventDeserializer(source=source, payload_codec=JSONCodec()) + result = EventStreamOperationInputOutput.deserialize(deserializer) + assert result == expected + + +def test_deserialize_unmodeled_error(): + message = EventMessage( + headers={ + ":message-type": "error", + ":error-code": "InternalError", + ":error-message": "An internal server error occurred.", + } + ) + source = BytesIO(message.encode()) + deserializer = EventDeserializer(source=source, payload_codec=JSONCodec()) + + with pytest.raises(UnmodeledEventError, match="InternalError"): + EventStreamOperationInputOutput.deserialize(deserializer) diff --git a/python-packages/aws-event-stream/tests/unit/_private/test_serializers.py b/python-packages/aws-event-stream/tests/unit/_private/test_serializers.py new file mode 100644 index 00000000..c3f16277 --- /dev/null +++ b/python-packages/aws-event-stream/tests/unit/_private/test_serializers.py @@ -0,0 +1,38 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import pytest +from smithy_core.serializers import SerializeableShape +from smithy_json import JSONCodec + +from aws_event_stream._private.serializers import EventSerializer +from aws_event_stream.events import EventMessage + +from . import EVENT_STREAM_SERDE_CASES, INITIAL_REQUEST_CASE, INITIAL_RESPONSE_CASE + + +@pytest.mark.parametrize("given,expected", EVENT_STREAM_SERDE_CASES) +def test_event_serializer_client_mode( + given: SerializeableShape, expected: EventMessage +): + serializer = EventSerializer(payload_codec=JSONCodec(), is_client_mode=True) + given.serialize(serializer) + actual = serializer.get_result() + assert actual == expected + + +@pytest.mark.parametrize("given,expected", EVENT_STREAM_SERDE_CASES) +def test_event_serializer_server_mode( + given: SerializeableShape, expected: EventMessage +): + serializer = EventSerializer(payload_codec=JSONCodec(), is_client_mode=False) + given.serialize(serializer) + actual = serializer.get_result() + assert actual == expected + + +def test_serialize_initial_request(): + test_event_serializer_client_mode(*INITIAL_REQUEST_CASE) + + +def test_serialize_initial_response(): + test_event_serializer_server_mode(*INITIAL_RESPONSE_CASE) diff --git a/python-packages/aws-event-stream/tests/unit/test_events.py b/python-packages/aws-event-stream/tests/unit/test_events.py index 962d9122..555e3e3d 100644 --- a/python-packages/aws-event-stream/tests/unit/test_events.py +++ b/python-packages/aws-event-stream/tests/unit/test_events.py @@ -1,5 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +# pyright: reportPrivateUsage=false import datetime import uuid from io import BytesIO @@ -8,19 +9,24 @@ from aws_event_stream.events import ( MAX_HEADER_VALUE_BYTE_LENGTH, + MAX_HEADERS_LENGTH, MAX_PAYLOAD_LENGTH, Byte, Event, + EventHeaderDecoder, + EventHeaderEncoder, EventMessage, EventPrelude, Long, Short, ) +from aws_event_stream.events import _EventEncoder as EventEncoder from aws_event_stream.exceptions import ( ChecksumMismatch, DuplicateHeader, InvalidHeadersLength, InvalidHeaderValueLength, + InvalidIntegerValue, InvalidPayloadLength, ) @@ -393,6 +399,7 @@ ERROR_EVENT_MESSAGE, ] +# TODO: botocore isn't passing this, fix there CORRUPTED_HEADERS_LENGTH = ( ( b"\x00\x00\x00\x3d" # total length @@ -417,6 +424,7 @@ ChecksumMismatch, ) +# TODO: botocore isn't passing this, fix there CORRUPTED_LENGTH = ( ( b"\x01\x00\x00\x1d" # total length @@ -522,3 +530,225 @@ def test_encode(expected: bytes, event: Event): def test_negative_cases(encoded: bytes, expected: type[Exception]): with pytest.raises(expected): Event.decode(BytesIO(encoded)) + + +def test_event_prelude_rejects_long_headers(): + headers_length = MAX_HEADERS_LENGTH + 1 + total_length = headers_length + 16 + with pytest.raises(InvalidHeadersLength): + EventPrelude(total_length=total_length, headers_length=headers_length, crc=1) + + +def test_event_prelude_rejects_long_payload(): + total_length = MAX_PAYLOAD_LENGTH + 17 + with pytest.raises(InvalidPayloadLength): + EventPrelude(total_length=total_length, headers_length=0, crc=1) + + +def test_event_message_rejects_long_payload(): + payload = b"0" * (MAX_PAYLOAD_LENGTH + 1) + with pytest.raises(InvalidPayloadLength): + EventMessage(payload=payload) + + +def test_event_message_rejects_long_header_value(): + headers = {"foo": b"0" * (MAX_HEADER_VALUE_BYTE_LENGTH + 1)} + with pytest.raises(InvalidHeaderValueLength): + EventMessage(headers=headers).encode() + + +def test_event_message_rejects_long_headers(): + # 5 of these is more than enough to overcome the header size limit. + long_value = b"0" * (MAX_HEADER_VALUE_BYTE_LENGTH - 1) + headers = { + "1": long_value, + "2": long_value, + "3": long_value, + "4": long_value, + "5": long_value, + } + with pytest.raises(InvalidHeadersLength): + EventMessage(headers=headers).encode() + + # These are correctly encoded, and individually valid, but collectively too long. + long_headers = b"" + for i in range(5): + long_headers += b"\x01" + str(i).encode("utf-8") + b"\x06\x7f\xfe" + long_value + + with pytest.raises(InvalidHeadersLength): + EventMessage(headers_bytes=long_headers) + + +def test_event_message_decodes_headers(): + message = EventMessage(headers_bytes=b"\x04true\x00") + assert message.headers == {"true": True} + + +def test_event_encoder_rejects_long_headers(): + long_value = b"0" * (MAX_HEADER_VALUE_BYTE_LENGTH - 1) + long_headers = b"" + for i in range(5): + long_headers += b"\x01" + str(i).encode("utf-8") + b"\x06\x7f\xfe" + long_value + + with pytest.raises(InvalidHeadersLength): + EventEncoder().encode_bytes(headers=long_headers) + + +def test_event_encoder_rejects_long_payload(): + payload = b"0" * (MAX_PAYLOAD_LENGTH + 1) + with pytest.raises(InvalidPayloadLength): + EventEncoder().encode_bytes(payload=payload) + + +def test_event_encoder_encodes_bytes(): + expected = ( + b"\x00\x00\x00\x3d" # total length + b"\x00\x00\x00\x20" # headers length + b"\x07\xfd\x83\x96" # prelude crc + b"\x0ccontent-type\x07\x00\x10application/json" # headers + b"{'foo':'bar'}" # payload + b"\x8d\x9c\x08\xb1" # message crc + ) + headers = b"\x0ccontent-type\x07\x00\x10application/json" + payload = b"{'foo':'bar'}" + actual = EventEncoder().encode_bytes(headers=headers, payload=payload) + assert actual == expected + + +def test_encode_boolean_header(): + encoder = EventHeaderEncoder() + encoder.encode_boolean("foo", True) + assert encoder.get_result() == b"\x03foo\x00" + + encoder.clear() + encoder.encode_boolean("foo", False) + assert encoder.get_result() == b"\x03foo\x01" + + +def test_encode_byte_header(): + encoder = EventHeaderEncoder() + encoder.encode_byte("foo", 1) + assert encoder.get_result() == b"\x03foo\x02\x01" + + +def test_encode_too_long_byte_header(): + encoder = EventHeaderEncoder() + with pytest.raises(InvalidIntegerValue): + encoder.encode_byte("foo", 2**7) + + +def test_encode_short_header(): + encoder = EventHeaderEncoder() + encoder.encode_short("foo", 1) + assert encoder.get_result() == b"\x03foo\x03\x00\x01" + + +def test_encode_too_long_short_header(): + encoder = EventHeaderEncoder() + with pytest.raises(InvalidIntegerValue): + encoder.encode_short("foo", 2**15) + + +def test_encode_int_header(): + encoder = EventHeaderEncoder() + encoder.encode_integer("foo", 1) + assert encoder.get_result() == b"\x03foo\x04\x00\x00\x00\x01" + + +def test_encode_too_long_int_header(): + encoder = EventHeaderEncoder() + with pytest.raises(InvalidIntegerValue): + encoder.encode_integer("foo", 2**31) + + +def test_encode_long_header(): + encoder = EventHeaderEncoder() + encoder.encode_long("foo", 1) + assert encoder.get_result() == b"\x03foo\x05\x00\x00\x00\x00\x00\x00\x00\x01" + + +def test_encode_too_long_long_header(): + encoder = EventHeaderEncoder() + with pytest.raises(InvalidIntegerValue): + encoder.encode_long("foo", 2**63) + + +def test_encode_blob_header(): + encoder = EventHeaderEncoder() + encoder.encode_blob("foo", b"bytes") + assert encoder.get_result() == b"\x03foo\x06\x00\x05bytes" + + +def test_encode_string_header(): + encoder = EventHeaderEncoder() + encoder.encode_string("foo", "string") + assert encoder.get_result() == b"\x03foo\x07\x00\x06string" + + +def test_encode_timestamp_header(): + encoder = EventHeaderEncoder() + encoder.encode_timestamp( + "foo", datetime.datetime(1970, 1, 1, 0, 0, 0, 8000, tzinfo=datetime.UTC) + ) + assert encoder.get_result() == b"\x03foo\x08\x00\x00\x00\x00\x00\x00\x00\x08" + + +def test_encode_uuid_header(): + encoder = EventHeaderEncoder() + encoder.encode_uuid("foo", uuid.UUID(bytes=b"0123456789abcdef")) + assert encoder.get_result() == b"\x03foo\x090123456789abcdef" + + +def test_decode_bool_header(): + actual = EventHeaderDecoder(b"\x03foo\x00").decode_header() + assert actual == ("foo", True) + + actual = EventHeaderDecoder(b"\x03foo\x01").decode_header() + assert actual == ("foo", False) + + +def test_decode_byte_header(): + actual = EventHeaderDecoder(b"\x03foo\x02\x01").decode_header() + assert actual == ("foo", 1) + + +def test_decode_short_header(): + actual = EventHeaderDecoder(b"\x03foo\x03\x00\x01").decode_header() + assert actual == ("foo", 1) + + +def test_decode_integer_header(): + actual = EventHeaderDecoder(b"\x03foo\x04\x00\x00\x00\x01").decode_header() + assert actual == ("foo", 1) + + +def test_decode_long_header(): + actual = EventHeaderDecoder( + b"\x03foo\x05\x00\x00\x00\x00\x00\x00\x00\x01" + ).decode_header() + assert actual == ("foo", 1) + + +def test_decode_blob_header(): + actual = EventHeaderDecoder(b"\x03foo\x06\x00\x05bytes").decode_header() + assert actual == ("foo", b"bytes") + + +def test_decode_string_header(): + actual = EventHeaderDecoder(b"\x03foo\x07\x00\x06string").decode_header() + assert actual == ("foo", "string") + + +def test_decode_timestamp_header(): + actual = EventHeaderDecoder( + b"\x03foo\x08\x00\x00\x00\x00\x00\x00\x00\x08" + ).decode_header() + assert actual == ( + "foo", + datetime.datetime(1970, 1, 1, 0, 0, 0, 8000, tzinfo=datetime.UTC), + ) + + +def test_decode_uuid_header(): + actual = EventHeaderDecoder(b"\x03foo\x090123456789abcdef").decode_header() + assert actual == ("foo", uuid.UUID(bytes=b"0123456789abcdef"))