Skip to content

Commit

Permalink
Move multipart parsing to rust
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro committed Sep 22, 2024
1 parent 5028a02 commit 3d64bf2
Show file tree
Hide file tree
Showing 21 changed files with 1,542 additions and 78 deletions.
364 changes: 364 additions & 0 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,20 @@ crate-type = ["cdylib"]
[dependencies]
aes = "=0.8"
anyhow = "=1.0"
buf-read-ext = "=0.4"
cfb8 = "=0.8"
cfb-mode = "=0.8"
ctr = "=0.9"
encoding = "=0.2"
#form_urlencoded = "=1.2"
http = "=1.1"
httparse = "=1.9"
mime = "=0.3"
pyo3 = { version = "=0.22", features = ["anyhow", "extension-module", "generate-import-lib"] }
regex = "=1.10"
ring = "=0.16"
tempfile = "=3.12"
textnonce = "=1.0"

[target.'cfg(any(target_os = "freebsd", target_os = "windows"))'.dependencies]
mimalloc = { version = "0.1.43", default-features = false, features = ["local_dynamic_tls"] }
Expand Down
27 changes: 26 additions & 1 deletion emmett_core/_emmett_core.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Iterator, Optional, Tuple

__version__: str

Expand Down Expand Up @@ -46,3 +46,28 @@ class WSRouter:
def match_route_scheme(self, scheme: str, path: str) -> Tuple[Any, Dict[str, Any]]: ...
def match_route_host(self, host: str, path: str) -> Tuple[Any, Dict[str, Any]]: ...
def match_route_all(self, host: str, scheme: str, path: str) -> Tuple[Any, Dict[str, Any]]: ...

#: http
# Extract the content type from a raw ``Content-Type`` header value.
# The result is optional: callers are expected to supply their own fallback
# when nothing is returned (presumably for an empty or unparsable header --
# TODO confirm against the native implementation).
def get_content_type(header_value: str) -> Optional[str]: ...

#: multipart
class MultiPartReader:
    # Multipart body parser provided by the native extension.
    #
    # Usage, as seen in the request wrapper: build it from the raw
    # ``Content-Type`` header value, feed the body through ``parse`` one
    # chunk at a time, then walk the parsed parts through ``contents``.

    def __init__(self, content_type_header_value: str) -> None: ...

    # Feed one chunk of body bytes to the parser. Callers ignore the result,
    # so it is annotated as ``None`` (confirm against the native signature).
    def parse(self, data: bytes) -> None: ...

    # Return the iterable over the parts collected so far.
    def contents(self) -> MultiPartContentsIter: ...

class MultiPartContentsIter:
    # Iterable returned by ``MultiPartReader.contents()``.
    #
    # Each item is a ``(name, is_file, value)`` tuple. The request wrapper
    # wraps ``value`` in ``FileStorage`` when ``is_file`` is true and calls
    # ``value.decode("utf8")`` otherwise, so ``value`` is presumably a
    # ``FilePartReader`` for files and ``bytes`` for plain fields -- confirm
    # before narrowing ``Any``.
    def __iter__(self) -> Iterator[Tuple[str, bool, Any]]: ...

class FilePartReader:
    # Readable handle over a single uploaded file part.

    # Content type of the part, when one is available.
    content_type: Optional[str]
    # Length of the part's content; surfaced to users as ``FileStorage.size``
    # (presumably a byte count -- confirm against the native implementation).
    content_length: int
    # File name attached to the part, when one is available.
    filename: Optional[str]

    # Read up to ``size`` bytes; with ``size`` omitted/``None`` the intent is
    # presumably "read everything that is left" -- TODO confirm.
    def read(self, size: Optional[int] = None) -> bytes: ...

    # Iterate over the part's content as byte chunks.
    def __iter__(self) -> Iterator[bytes]: ...

# Exception types exported alongside the multipart parser.
# The request wrapper turns MultiPartEncodingError into a 400 "Invalid encoding"
# response, and MultiPartParsingError / MultiPartStateError into a
# 400 "Invalid multipart data" response.
# NOTE(review): MultiPartIOError is not among the exceptions caught there, so it
# would propagate to the caller unchanged -- confirm that this is intended.
class MultiPartEncodingError(UnicodeDecodeError): ...
class MultiPartIOError(IOError): ...
class MultiPartParsingError(ValueError): ...
class MultiPartStateError(RuntimeError): ...
62 changes: 45 additions & 17 deletions emmett_core/http/wrappers/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import BinaryIO, Dict, Iterable, Iterator, MutableMapping, Optional, Tuple, Union
from typing import BinaryIO, Dict, Iterator, MutableMapping, Optional, Tuple, Union

from ..._io import loop_copyfileobj

Expand Down Expand Up @@ -50,35 +50,63 @@ def update(self, data: Dict[str, str]): # type: ignore
self._data.update(data)


# class FileStorage:
# __slots__ = ("stream", "filename", "name", "headers", "content_type")

# def __init__(
# self, stream: BinaryIO, filename: str, name: str = None, content_type: str = None, headers: Dict = None
# ):
# self.stream = stream
# self.filename = filename
# self.name = name
# self.headers = headers or {}
# self.content_type = content_type or self.headers.get("content-type")

# @property
# def content_length(self) -> int:
# return int(self.headers.get("content-length", 0))

# async def save(self, destination: Union[BinaryIO, str], buffer_size: int = 16384):
# close_destination = False
# if isinstance(destination, str):
# destination = open(destination, "wb")
# close_destination = True
# try:
# await loop_copyfileobj(self.stream, destination, buffer_size)
# finally:
# if close_destination:
# destination.close()

# def __iter__(self) -> Iterable[bytes]:
# return iter(self.stream)

# def __repr__(self) -> str:
# return f"<{self.__class__.__name__}: " f"{self.filename} ({self.content_type})"


class FileStorage:
__slots__ = ("stream", "filename", "name", "headers", "content_type")
__slots__ = ["inner"]

def __init__(
self, stream: BinaryIO, filename: str, name: str = None, content_type: str = None, headers: Dict = None
):
self.stream = stream
self.filename = filename
self.name = name
self.headers = headers or {}
self.content_type = content_type or self.headers.get("content-type")
def __init__(self, inner):
self.inner = inner

def __getattr__(self, name):
return getattr(self.inner, name)

@property
def content_length(self) -> int:
return int(self.headers.get("content-length", 0))
def size(self):
return self.inner.content_length

async def save(self, destination: Union[BinaryIO, str], buffer_size: int = 16384):
close_destination = False
if isinstance(destination, str):
destination = open(destination, "wb")
close_destination = True
try:
await loop_copyfileobj(self.stream, destination, buffer_size)
await loop_copyfileobj(self.inner, destination, buffer_size)
finally:
if close_destination:
destination.close()

def __iter__(self) -> Iterable[bytes]:
return iter(self.stream)

def __repr__(self) -> str:
return f"<{self.__class__.__name__}: " f"{self.filename} ({self.content_type})"
return f"<{self.__class__.__name__}: {self.filename} ({self.content_type})>"
74 changes: 37 additions & 37 deletions emmett_core/http/wrappers/request.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import datetime
from abc import abstractmethod
from cgi import FieldStorage, parse_header
from io import BytesIO
from typing import Any
from typing import Any, Optional
from urllib.parse import parse_qs

from ..._emmett_core import (
MultiPartEncodingError,
MultiPartParsingError,
MultiPartReader,
MultiPartStateError,
get_content_type,
)
from ...datastructures import sdict
from ...http.response import HTTPBytesResponse
from ...parsers import Parsers
from ...utils import cachedprop
from . import IngressWrapper
Expand All @@ -27,8 +34,8 @@ def now(self) -> datetime.datetime:
return self._now

@cachedprop
def content_type(self) -> str:
return parse_header(self.headers.get("content-type", ""))[0]
def content_type(self) -> Optional[str]:
return get_content_type(self.headers.get("content-type", "")) or "text/plain"

@cachedprop
def content_length(self) -> int:
Expand All @@ -52,18 +59,19 @@ async def files(self) -> sdict[str, FileStorage]:
_, rv = await self._input_params
return rv

def _load_params_missing(self, data):
async def _load_params_missing(self, body):
return sdict(), sdict()

def _load_params_json(self, data):
async def _load_params_json(self, body):
try:
params = Parsers.get_for("json")(data)
params = Parsers.get_for("json")(await body)
except Exception:
params = {}
return sdict(params), sdict()

def _load_params_form_urlencoded(self, data):
async def _load_params_form_urlencoded(self, body):
rv = sdict()
data = await body
for key, values in parse_qs(data.decode("latin-1"), keep_blank_values=True).items():
if len(values) == 1:
rv[key] = values[0]
Expand All @@ -79,35 +87,27 @@ def _multipart_headers(self):
def _file_param_from_field(field):
return FileStorage(BytesIO(field.file.read()), field.filename, field.name, field.type, field.headers)

def _load_params_form_multipart(self, data):
async def _load_params_form_multipart(self, body):
params, files = sdict(), sdict()
field_storage = FieldStorage(
BytesIO(data),
headers=self._multipart_headers,
environ={"REQUEST_METHOD": self.method},
keep_blank_values=True,
)
for key in field_storage:
field = field_storage[key]
if isinstance(field, list):
if len(field) > 1:
pvalues, fvalues = [], []
for item in field:
if item.filename is not None:
fvalues.append(self._file_param_from_field(item))
else:
pvalues.append(item.value)
if pvalues:
params[key] = pvalues
if fvalues:
files[key] = fvalues
continue
try:
parser = MultiPartReader(self.headers.get("content-type"))
async for chunk in body:
parser.parse(chunk)
for key, is_file, field in parser.contents():
if is_file:
files[key] = data = files[key] or []
data.append(FileStorage(field))
else:
field = field[0]
if field.filename is not None:
files[key] = self._file_param_from_field(field)
else:
params[key] = field.value
params[key] = data = params[key] or []
data.append(field.decode("utf8"))
except MultiPartEncodingError:
raise HTTPBytesResponse(400, "Invalid encoding")
except (MultiPartParsingError, MultiPartStateError):
raise HTTPBytesResponse(400, "Invalid multipart data")
for target in (params, files):
for key, val in target.items():
if len(val) == 1:
target[key] = val[0]
return params, files

_params_loaders = {
Expand All @@ -116,9 +116,9 @@ def _load_params_form_multipart(self, data):
"multipart/form-data": _load_params_form_multipart,
}

async def _load_params(self):
def _load_params(self):
loader = self._params_loaders.get(self.content_type, self._load_params_missing)
return loader(self, await self.body)
return loader(self, self.body)

@abstractmethod
async def push_promise(self, path: str): ...
27 changes: 27 additions & 0 deletions emmett_core/protocols/rsgi/helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,31 @@
import asyncio
from typing import AsyncGenerator

from ...http.response import HTTPBytesResponse


class BodyWrapper:
    """Awaitable and async-iterable wrapper around a protocol's request body.

    ``await wrapper`` resolves to whatever awaiting ``proto()`` yields,
    bounded by ``timeout`` when one is set. ``async for chunk in wrapper``
    relays the chunks produced by async-iterating ``proto`` itself.
    """

    __slots__ = ["proto", "timeout"]

    def __init__(self, proto, timeout):
        """Store the protocol object and the optional read timeout.

        ``proto`` must be callable (returning an awaitable) and
        async-iterable. ``timeout`` is handed to ``asyncio.wait_for``; any
        falsy value (``None``, ``0``) disables the timeout.
        """
        self.proto = proto
        self.timeout = timeout

    def __await__(self):
        # Only pay for the wait_for wrapper when a timeout is configured.
        if self.timeout:
            return self._await_with_timeout().__await__()
        return self.proto().__await__()

    async def _await_with_timeout(self):
        """Await the whole body, raising a 408 response if it takes too long."""
        try:
            return await asyncio.wait_for(self.proto(), timeout=self.timeout)
        except asyncio.TimeoutError:
            # The HTTP response is the real outcome; do not chain the
            # internal TimeoutError onto it as exception context.
            raise HTTPBytesResponse(408, b"Request timeout") from None

    async def __aiter__(self) -> AsyncGenerator[bytes, None]:
        # NOTE(review): ``timeout`` is not applied on this streaming path;
        # only ``await wrapper`` is bounded. Confirm that this is intended.
        async for chunk in self.proto:
            yield chunk


class WSTransport:
Expand Down
23 changes: 16 additions & 7 deletions emmett_core/protocols/rsgi/test_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ....ctx import Current, RequestContext
from ....http.response import HTTPResponse, HTTPStringResponse
from ....http.wrappers.response import Response
from ....parsers import Parsers
from ....utils import cachedprop
from ..handlers import HTTPHandler
from ..wrappers import Request
Expand Down Expand Up @@ -75,7 +76,7 @@ async def dynamic_handler(self, scope, protocol, path):
class ClientHTTPHandler(ClientHTTPHandlerMixin, HTTPHandler): ...


class ClientResponse(object):
class ClientResponse:
def __init__(self, ctx, raw, status, headers):
self.context = ctx
self.raw = raw
Expand All @@ -91,10 +92,15 @@ def __exit__(self, exc_type, exc_value, tb):

@cachedprop
def data(self):
return self.raw.decode("utf8")
if isinstance(self.raw, bytes):
return self.raw.decode("utf8")
return self.raw

def json(self):
return Parsers.get_for("json")(self.data)

class EmmettTestClient(object):

class EmmettTestClient:
_current: Current
_handler_cls = ClientHTTPHandler

Expand Down Expand Up @@ -245,16 +251,19 @@ def __init__(self, body):
self.response_headers = []
self.response_body = b""
self.response_body_stream = None
self.consumed_input = False

async def __call__(self):
self.consumed_input = True
return self.input

def __aiter__(self):
async def inner():
return self.input
return self

for _ in range(1):
yield inner
async def __anext__(self):
if self.consumed_input:
raise StopAsyncIteration
return await self()

def response_empty(self, status, headers):
self.response_status = status
Expand Down
7 changes: 4 additions & 3 deletions emmett_core/protocols/rsgi/test_client/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ def add_file(self, name, file, filename=None, content_type=None):
if filename and content_type is None:
content_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
value = _FileHandler(file, filename, name, content_type)
self[name] = value
if name not in self:
self[name] = []
self[name].append(value)


def get_current_url(scope, root_only=False, strip_querystring=False, host_only=False):
Expand Down Expand Up @@ -291,8 +293,7 @@ def write(string):
else:
if not isinstance(value, str):
value = str(value)
else:
value = to_bytes(value, charset)
value = to_bytes(value, charset)
write("\r\n\r\n")
write_binary(value)
write("\r\n")
Expand Down
Loading

0 comments on commit 3d64bf2

Please sign in to comment.