From 3d64bf2c5d6521d2ff9cfb9babad07293f147f07 Mon Sep 17 00:00:00 2001 From: Giovanni Barillari Date: Sun, 22 Sep 2024 21:47:19 +0200 Subject: [PATCH] Move multipart parsing to rust --- Cargo.lock | 364 ++++++++++++++++++ Cargo.toml | 8 + emmett_core/_emmett_core.pyi | 27 +- emmett_core/http/wrappers/helpers.py | 62 ++- emmett_core/http/wrappers/request.py | 74 ++-- emmett_core/protocols/rsgi/helpers.py | 27 ++ .../protocols/rsgi/test_client/client.py | 23 +- .../protocols/rsgi/test_client/helpers.py | 7 +- .../protocols/rsgi/test_client/scope.py | 14 +- emmett_core/protocols/rsgi/wrappers.py | 10 +- src/http/headers.rs | 8 + src/http/mod.rs | 14 + src/lib.rs | 6 +- src/multipart/errors.rs | 48 +++ src/multipart/mod.rs | 27 ++ src/multipart/parse.rs | 329 ++++++++++++++++ src/multipart/parts.rs | 180 +++++++++ src/multipart/utils.rs | 45 +++ tests/multipart/conftest.py | 43 +++ tests/multipart/test_multipart.py | 303 +++++++++++++++ tests/routing/test_router.py | 1 - 21 files changed, 1542 insertions(+), 78 deletions(-) create mode 100644 src/http/headers.rs create mode 100644 src/http/mod.rs create mode 100644 src/multipart/errors.rs create mode 100644 src/multipart/mod.rs create mode 100644 src/multipart/parse.rs create mode 100644 src/multipart/parts.rs create mode 100644 src/multipart/utils.rs create mode 100644 tests/multipart/conftest.py create mode 100644 tests/multipart/test_multipart.py diff --git a/Cargo.lock b/Cargo.lock index 8e2980b..83e957c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -34,12 +34,42 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +[[package]] +name = "base64" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3441f0f7b02788e948e47f457ca01f1d7e6d92c693bc132c22b087d3141c03ff" + +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + +[[package]] +name = "buf-read-ext" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e2c71c44e5bbc64de4ecfac946e05f9bba5cc296ea7bab4d3eda242a3ffa73c" + [[package]] name = "bumpalo" version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "bytes" +version = "1.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" + [[package]] name = "cc" version = "1.1.18" @@ -117,16 +147,109 @@ version = "0.1.0" dependencies = [ "aes", "anyhow", + "buf-read-ext", "cfb-mode", "cfb8", "ctr", + "encoding", + "http", + "httparse", "mimalloc", + "mime", "pyo3", "regex", "ring", + "tempfile", + "textnonce", "tikv-jemallocator", ] +[[package]] +name = "encoding" +version = "0.2.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b0d943856b990d12d3b55b359144ff341533e516d94098b1d3fc1ac666d36ec" +dependencies = [ + "encoding-index-japanese", + "encoding-index-korean", + "encoding-index-simpchinese", + "encoding-index-singlebyte", + "encoding-index-tradchinese", +] + +[[package]] +name = "encoding-index-japanese" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04e8b2ff42e9a05335dbf8b5c6f7567e5591d0d916ccef4e0b1710d32a0d0c91" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-korean" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dc33fb8e6bcba213fe2f14275f0963fd16f0a02c878e3095ecfdf5bee529d81" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-simpchinese" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d87a7194909b9118fc707194baa434a4e3b0fb6a5a757c73c3adb07aa25031f7" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-singlebyte" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3351d5acffb224af9ca265f435b859c7c01537c0849754d3db3fdf2bfe2ae84a" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-tradchinese" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd0e20d5688ce3cab59eb3ef3a2083a5c77bf496cb798dc6fcdb75f323890c18" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding_index_tests" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a246d82be1c9d791c5dfde9a2bd045fc3cbba3fa2b11ad558f27d01712f00569" + +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "fastrand" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "generic-array" version = "0.14.7" @@ -137,12 +260,40 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "heck" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "httparse" +version = "1.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fcc0b4a115bf80b728eb8ea024ad5bd707b615bfed49e0665b6e0f86fd082d9" + [[package]] name = "indoc" version = "2.0.5" @@ -158,6 +309,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + [[package]] name = "js-sys" version = "0.3.70" @@ -183,6 +340,12 @@ dependencies = [ "libc", ] +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + [[package]] name = "log" version = "0.4.22" @@ -213,6 +376,12 @@ dependencies = [ "libmimalloc-sys", ] +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "once_cell" version = "1.19.0" @@ -225,6 +394,15 @@ version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + [[package]] name = "proc-macro2" version = "1.0.86" @@ -317,6 +495,47 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" +dependencies = [ + "getrandom", + "libc", + "rand_chacha", + "rand_core", + "rand_hc", +] + +[[package]] +name = "rand_chacha" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_hc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" +dependencies = [ + "rand_core", +] + [[package]] name = "regex" version = "1.10.6" @@ -361,6 +580,19 @@ dependencies = [ "winapi", ] +[[package]] +name = "rustix" +version = "0.38.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.52.0", +] + [[package]] name = "shlex" version = "1.3.0" @@ -390,6 +622,29 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "tempfile" +version = "3.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" +dependencies = [ + "cfg-if", + "fastrand", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + +[[package]] +name = "textnonce" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7743f8d70cd784ed1dc33106a18998d77758d281dc40dc3e6d050cf0f5286683" +dependencies = [ + "base64", + "rand", +] + [[package]] name = "tikv-jemalloc-sys" version = "0.6.0+5.3.0-1-ge13ca993e8ccb9ba9847cc330696e02839f328f7" @@ -440,6 +695,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "wasi" +version = "0.9.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" + [[package]] name = "wasm-bindgen" version = "0.2.93" @@ -526,3 +787,106 @@ name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index 9401edd..d9e341a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,12 +33,20 @@ crate-type = ["cdylib"] [dependencies] aes = "=0.8" anyhow = "=1.0" +buf-read-ext = "=0.4" cfb8 = "=0.8" cfb-mode = "=0.8" ctr = "=0.9" +encoding = "=0.2" +#form_urlencoded = "=1.2" +http = "=1.1" +httparse = "=1.9" +mime = "=0.3" pyo3 = { version = "=0.22", features = ["anyhow", "extension-module", "generate-import-lib"] } regex = "=1.10" ring = "=0.16" +tempfile = "=3.12" +textnonce = "=1.0" [target.'cfg(any(target_os = "freebsd", target_os = "windows"))'.dependencies] mimalloc = { version = "0.1.43", default-features = false, features = ["local_dynamic_tls"] } diff --git a/emmett_core/_emmett_core.pyi b/emmett_core/_emmett_core.pyi index 5720c2c..54d7fac 100644 --- a/emmett_core/_emmett_core.pyi +++ b/emmett_core/_emmett_core.pyi @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Iterator, Optional, Tuple __version__: str @@ -46,3 +46,28 @@ class WSRouter: def match_route_scheme(self, scheme: str, path: str) -> Tuple[Any, Dict[str, Any]]: ... def match_route_host(self, host: str, path: str) -> Tuple[Any, Dict[str, Any]]: ... def match_route_all(self, host: str, scheme: str, path: str) -> Tuple[Any, Dict[str, Any]]: ... + +#: http +def get_content_type(header_value: str) -> Optional[str]: ... + +#: multipart +class MultiPartReader: + def __init__(self, content_type_header_value: str): ... + def parse(self, data: bytes): ... + def contents(self) -> MultiPartContentsIter: ... + +class MultiPartContentsIter: + def __iter__(self) -> Iterator[Tuple[str, bool, Any]]: ... + +class FilePartReader: + content_type: Optional[str] + content_length: int + filename: Optional[str] + + def read(self, size: Optional[int] = None) -> bytes: ... + def __iter__(self) -> Iterator[bytes]: ... + +class MultiPartEncodingError(UnicodeDecodeError): ... +class MultiPartIOError(IOError): ... +class MultiPartParsingError(ValueError): ... +class MultiPartStateError(RuntimeError): ... diff --git a/emmett_core/http/wrappers/helpers.py b/emmett_core/http/wrappers/helpers.py index 2143d8d..fdc64f7 100644 --- a/emmett_core/http/wrappers/helpers.py +++ b/emmett_core/http/wrappers/helpers.py @@ -1,5 +1,5 @@ import re -from typing import BinaryIO, Dict, Iterable, Iterator, MutableMapping, Optional, Tuple, Union +from typing import BinaryIO, Dict, Iterator, MutableMapping, Optional, Tuple, Union from ..._io import loop_copyfileobj @@ -50,21 +50,52 @@ def update(self, data: Dict[str, str]): # type: ignore self._data.update(data) +# class FileStorage: +# __slots__ = ("stream", "filename", "name", "headers", "content_type") + +# def __init__( +# self, stream: BinaryIO, filename: str, name: str = None, content_type: str = None, headers: Dict = None +# ): +# self.stream = stream +# self.filename = filename +# self.name = name +# self.headers = headers or {} +# self.content_type = content_type or self.headers.get("content-type") + +# @property +# def content_length(self) -> int: +# return int(self.headers.get("content-length", 0)) + +# async def save(self, destination: Union[BinaryIO, str], buffer_size: int = 16384): +# close_destination = False +# if isinstance(destination, str): +# destination = open(destination, "wb") +# close_destination = True +# try: +# await loop_copyfileobj(self.stream, destination, buffer_size) +# finally: +# if close_destination: +# destination.close() + +# def __iter__(self) -> Iterable[bytes]: +# return iter(self.stream) + +# def __repr__(self) -> str: +# return f"<{self.__class__.__name__}: " f"{self.filename} ({self.content_type})" + + class FileStorage: - __slots__ = ("stream", "filename", "name", "headers", "content_type") + __slots__ = ["inner"] - def __init__( - self, stream: BinaryIO, filename: str, name: str = None, content_type: str = None, headers: Dict = None - ): - self.stream = stream - self.filename = filename - self.name = name - self.headers = headers or {} - self.content_type = content_type or self.headers.get("content-type") + def __init__(self, inner): + self.inner = inner + + def __getattr__(self, name): + return getattr(self.inner, name) @property - def content_length(self) -> int: - return int(self.headers.get("content-length", 0)) + def size(self): + return self.inner.content_length async def save(self, destination: Union[BinaryIO, str], buffer_size: int = 16384): close_destination = False @@ -72,13 +103,10 @@ async def save(self, destination: Union[BinaryIO, str], buffer_size: int = 16384 destination = open(destination, "wb") close_destination = True try: - await loop_copyfileobj(self.stream, destination, buffer_size) + await loop_copyfileobj(self.inner, destination, buffer_size) finally: if close_destination: destination.close() - def __iter__(self) -> Iterable[bytes]: - return iter(self.stream) - def __repr__(self) -> str: - return f"<{self.__class__.__name__}: " f"{self.filename} ({self.content_type})" + return f"<{self.__class__.__name__}: {self.filename} ({self.content_type})>" diff --git a/emmett_core/http/wrappers/request.py b/emmett_core/http/wrappers/request.py index c2b0b9c..8632c27 100644 --- a/emmett_core/http/wrappers/request.py +++ b/emmett_core/http/wrappers/request.py @@ -1,11 +1,18 @@ import datetime from abc import abstractmethod -from cgi import FieldStorage, parse_header from io import BytesIO -from typing import Any +from typing import Any, Optional from urllib.parse import parse_qs +from ..._emmett_core import ( + MultiPartEncodingError, + MultiPartParsingError, + MultiPartReader, + MultiPartStateError, + get_content_type, +) from ...datastructures import sdict +from ...http.response import HTTPBytesResponse from ...parsers import Parsers from ...utils import cachedprop from . import IngressWrapper @@ -27,8 +34,8 @@ def now(self) -> datetime.datetime: return self._now @cachedprop - def content_type(self) -> str: - return parse_header(self.headers.get("content-type", ""))[0] + def content_type(self) -> Optional[str]: + return get_content_type(self.headers.get("content-type", "")) or "text/plain" @cachedprop def content_length(self) -> int: @@ -52,18 +59,19 @@ async def files(self) -> sdict[str, FileStorage]: _, rv = await self._input_params return rv - def _load_params_missing(self, data): + async def _load_params_missing(self, body): return sdict(), sdict() - def _load_params_json(self, data): + async def _load_params_json(self, body): try: - params = Parsers.get_for("json")(data) + params = Parsers.get_for("json")(await body) except Exception: params = {} return sdict(params), sdict() - def _load_params_form_urlencoded(self, data): + async def _load_params_form_urlencoded(self, body): rv = sdict() + data = await body for key, values in parse_qs(data.decode("latin-1"), keep_blank_values=True).items(): if len(values) == 1: rv[key] = values[0] @@ -79,35 +87,27 @@ def _multipart_headers(self): def _file_param_from_field(field): return FileStorage(BytesIO(field.file.read()), field.filename, field.name, field.type, field.headers) - def _load_params_form_multipart(self, data): + async def _load_params_form_multipart(self, body): params, files = sdict(), sdict() - field_storage = FieldStorage( - BytesIO(data), - headers=self._multipart_headers, - environ={"REQUEST_METHOD": self.method}, - keep_blank_values=True, - ) - for key in field_storage: - field = field_storage[key] - if isinstance(field, list): - if len(field) > 1: - pvalues, fvalues = [], [] - for item in field: - if item.filename is not None: - fvalues.append(self._file_param_from_field(item)) - else: - pvalues.append(item.value) - if pvalues: - params[key] = pvalues - if fvalues: - files[key] = fvalues - continue + try: + parser = MultiPartReader(self.headers.get("content-type")) + async for chunk in body: + parser.parse(chunk) + for key, is_file, field in parser.contents(): + if is_file: + files[key] = data = files[key] or [] + data.append(FileStorage(field)) else: - field = field[0] - if field.filename is not None: - files[key] = self._file_param_from_field(field) - else: - params[key] = field.value + params[key] = data = params[key] or [] + data.append(field.decode("utf8")) + except MultiPartEncodingError: + raise HTTPBytesResponse(400, "Invalid encoding") + except (MultiPartParsingError, MultiPartStateError): + raise HTTPBytesResponse(400, "Invalid multipart data") + for target in (params, files): + for key, val in target.items(): + if len(val) == 1: + target[key] = val[0] return params, files _params_loaders = { @@ -116,9 +116,9 @@ def _load_params_form_multipart(self, data): "multipart/form-data": _load_params_form_multipart, } - async def _load_params(self): + def _load_params(self): loader = self._params_loaders.get(self.content_type, self._load_params_missing) - return loader(self, await self.body) + return loader(self, self.body) @abstractmethod async def push_promise(self, path: str): ... diff --git a/emmett_core/protocols/rsgi/helpers.py b/emmett_core/protocols/rsgi/helpers.py index 6637f6c..6f8f6c3 100644 --- a/emmett_core/protocols/rsgi/helpers.py +++ b/emmett_core/protocols/rsgi/helpers.py @@ -1,4 +1,31 @@ import asyncio +from typing import AsyncGenerator + +from ...http.response import HTTPBytesResponse + + +class BodyWrapper: + __slots__ = ["proto", "timeout"] + + def __init__(self, proto, timeout): + self.proto = proto + self.timeout = timeout + + def __await__(self): + if self.timeout: + return self._await_with_timeout().__await__() + return self.proto().__await__() + + async def _await_with_timeout(self): + try: + rv = await asyncio.wait_for(self.proto(), timeout=self.timeout) + except asyncio.TimeoutError: + raise HTTPBytesResponse(408, b"Request timeout") + return rv + + async def __aiter__(self) -> AsyncGenerator[bytes, None]: + async for chunk in self.proto: + yield chunk class WSTransport: diff --git a/emmett_core/protocols/rsgi/test_client/client.py b/emmett_core/protocols/rsgi/test_client/client.py index 5efe860..de9c120 100644 --- a/emmett_core/protocols/rsgi/test_client/client.py +++ b/emmett_core/protocols/rsgi/test_client/client.py @@ -7,6 +7,7 @@ from ....ctx import Current, RequestContext from ....http.response import HTTPResponse, HTTPStringResponse from ....http.wrappers.response import Response +from ....parsers import Parsers from ....utils import cachedprop from ..handlers import HTTPHandler from ..wrappers import Request @@ -75,7 +76,7 @@ async def dynamic_handler(self, scope, protocol, path): class ClientHTTPHandler(ClientHTTPHandlerMixin, HTTPHandler): ... -class ClientResponse(object): +class ClientResponse: def __init__(self, ctx, raw, status, headers): self.context = ctx self.raw = raw @@ -91,10 +92,15 @@ def __exit__(self, exc_type, exc_value, tb): @cachedprop def data(self): - return self.raw.decode("utf8") + if isinstance(self.raw, bytes): + return self.raw.decode("utf8") + return self.raw + def json(self): + return Parsers.get_for("json")(self.data) -class EmmettTestClient(object): + +class EmmettTestClient: _current: Current _handler_cls = ClientHTTPHandler @@ -245,16 +251,19 @@ def __init__(self, body): self.response_headers = [] self.response_body = b"" self.response_body_stream = None + self.consumed_input = False async def __call__(self): + self.consumed_input = True return self.input def __aiter__(self): - async def inner(): - return self.input + return self - for _ in range(1): - yield inner + async def __anext__(self): + if self.consumed_input: + raise StopAsyncIteration + return await self() def response_empty(self, status, headers): self.response_status = status diff --git a/emmett_core/protocols/rsgi/test_client/helpers.py b/emmett_core/protocols/rsgi/test_client/helpers.py index ec3a20c..5ccf690 100644 --- a/emmett_core/protocols/rsgi/test_client/helpers.py +++ b/emmett_core/protocols/rsgi/test_client/helpers.py @@ -182,7 +182,9 @@ def add_file(self, name, file, filename=None, content_type=None): if filename and content_type is None: content_type = mimetypes.guess_type(filename)[0] or "application/octet-stream" value = _FileHandler(file, filename, name, content_type) - self[name] = value + if name not in self: + self[name] = [] + self[name].append(value) def get_current_url(scope, root_only=False, strip_querystring=False, host_only=False): @@ -291,8 +293,7 @@ def write(string): else: if not isinstance(value, str): value = str(value) - else: - value = to_bytes(value, charset) + value = to_bytes(value, charset) write("\r\n\r\n") write_binary(value) write("\r\n") diff --git a/emmett_core/protocols/rsgi/test_client/scope.py b/emmett_core/protocols/rsgi/test_client/scope.py index 5b3949a..66b825f 100644 --- a/emmett_core/protocols/rsgi/test_client/scope.py +++ b/emmett_core/protocols/rsgi/test_client/scope.py @@ -1,6 +1,6 @@ -import cgi import sys from io import BytesIO +from urllib.parse import parse_qs from ....datastructures import sdict from .helpers import Headers, filesdict, stream_encode_multipart @@ -92,6 +92,7 @@ def __init__( headers=None, data=None, charset="utf-8", + boundary=None, ): if query_string is None and "?" in path: path, query_string = path.split("?", 1) @@ -121,6 +122,7 @@ def __init__( self.errors_stream = errors_stream self.input_stream = input_stream self.content_length = content_length + self.boundary = boundary self.closed = False if data: @@ -146,7 +148,7 @@ def __init__( @staticmethod def _parse_querystring(query_string): - dget = cgi.parse_qs(query_string, keep_blank_values=1) + dget = parse_qs(query_string, keep_blank_values=1) params = sdict(dget) for key, value in params.items(): if isinstance(value, list) and len(value) == 1: @@ -321,8 +323,12 @@ def get_data(self): values = sdict() for d in [self.files, self.form]: for key, val in d.items(): - values[key] = val - input_stream, content_length, boundary = stream_encode_multipart(values, charset=self.charset) + if key not in values: + values[key] = [] + values[key].extend(val) + input_stream, content_length, boundary = stream_encode_multipart( + values, charset=self.charset, boundary=self.boundary + ) content_type += '; boundary="%s"' % boundary elif content_type == "application/x-www-form-urlencoded": values = url_encode(self.form, charset=self.charset) diff --git a/emmett_core/protocols/rsgi/wrappers.py b/emmett_core/protocols/rsgi/wrappers.py index 145cba3..c47b6f5 100644 --- a/emmett_core/protocols/rsgi/wrappers.py +++ b/emmett_core/protocols/rsgi/wrappers.py @@ -1,4 +1,3 @@ -import asyncio from datetime import datetime from typing import Any, Dict, List, Optional, Union from urllib.parse import parse_qs @@ -9,6 +8,7 @@ from ...http.wrappers.request import Request as _Request from ...http.wrappers.websocket import Websocket as _Websocket from ...utils import cachedprop +from .helpers import BodyWrapper class RSGIIngressMixin: @@ -66,14 +66,10 @@ def _multipart_headers(self): return dict(self.headers.items()) @cachedprop - async def body(self) -> bytes: + def body(self) -> BodyWrapper: if self.max_content_length and self.content_length > self.max_content_length: raise HTTPBytesResponse(413, b"Request entity too large") - try: - rv = await asyncio.wait_for(self._proto(), timeout=self.body_timeout) - except asyncio.TimeoutError: - raise HTTPBytesResponse(408, b"Request timeout") - return rv + return BodyWrapper(self._proto, self.body_timeout) async def push_promise(self, path: str): raise NotImplementedError("RSGI protocol doesn't support HTTP2 push.") diff --git a/src/http/headers.rs b/src/http/headers.rs new file mode 100644 index 0000000..3dda1b8 --- /dev/null +++ b/src/http/headers.rs @@ -0,0 +1,8 @@ +use anyhow::Result; +use mime::Mime; + +#[inline] +pub(crate) fn get_content_type(header_value: &str) -> Result { + let mime: Mime = header_value.parse()?; + Ok(mime.essence_str().to_owned()) +} diff --git a/src/http/mod.rs b/src/http/mod.rs new file mode 100644 index 0000000..40ab78c --- /dev/null +++ b/src/http/mod.rs @@ -0,0 +1,14 @@ +use pyo3::prelude::*; + +pub(crate) mod headers; + +#[pyfunction] +fn get_content_type(header_value: &str) -> Option { + headers::get_content_type(header_value).ok() +} + +pub(crate) fn init_pymodule(module: &Bound) -> PyResult<()> { + module.add_function(wrap_pyfunction!(get_content_type, module)?)?; + + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index f9b0e21..3c93f3d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,8 @@ use pyo3::prelude::*; use std::sync::OnceLock; mod cryptography; +mod http; +mod multipart; mod routing; pub fn get_lib_version() -> &'static str { @@ -22,10 +24,12 @@ pub fn get_lib_version() -> &'static str { } #[pymodule] -fn _emmett_core(_py: Python, module: &Bound) -> PyResult<()> { +fn _emmett_core(py: Python, module: &Bound) -> PyResult<()> { module.add("__version__", get_lib_version())?; cryptography::init_pymodule(module)?; + http::init_pymodule(module)?; + multipart::init_pymodule(py, module)?; routing::init_pymodule(module)?; Ok(()) } diff --git a/src/multipart/errors.rs b/src/multipart/errors.rs new file mode 100644 index 0000000..b57ac1c --- /dev/null +++ b/src/multipart/errors.rs @@ -0,0 +1,48 @@ +use pyo3::{ + create_exception, + exceptions::{PyIOError, PyRuntimeError, PyUnicodeDecodeError, PyValueError}, +}; + +create_exception!( + _emmett_core, + MultiPartEncodingError, + PyUnicodeDecodeError, + "MultiPartEncodingError" +); +create_exception!( + _emmett_core, + MultiPartParsingError, + PyValueError, + "MultiPartParsingError" +); +create_exception!(_emmett_core, MultiPartStateError, PyRuntimeError, "MultiPartStateError"); +create_exception!(_emmett_core, MultiPartIOError, PyIOError, "MultiPartIOError"); + +macro_rules! error_encoding { + () => { + super::errors::MultiPartEncodingError::new_err("multipart encoding error").into() + }; +} + +macro_rules! error_io { + () => { + super::errors::MultiPartIOError::new_err("cannot open fd").into() + }; +} + +macro_rules! error_parsing { + ($msg:tt) => { + super::errors::MultiPartParsingError::new_err($msg).into() + }; +} + +macro_rules! error_state { + () => { + super::errors::MultiPartStateError::new_err("parsing incomplete").into() + }; +} + +pub(crate) use error_encoding; +pub(crate) use error_io; +pub(crate) use error_parsing; +pub(crate) use error_state; diff --git a/src/multipart/mod.rs b/src/multipart/mod.rs new file mode 100644 index 0000000..b8183cf --- /dev/null +++ b/src/multipart/mod.rs @@ -0,0 +1,27 @@ +use pyo3::prelude::*; + +mod errors; +mod parse; +mod parts; +mod utils; + +pub(crate) fn init_pymodule(py: Python, module: &Bound) -> PyResult<()> { + module.add( + "MultiPartEncodingError", + py.get_type_bound::(), + )?; + module.add("MultiPartIOError", py.get_type_bound::())?; + module.add( + "MultiPartParsingError", + py.get_type_bound::(), + )?; + module.add( + "MultiPartStateError", + py.get_type_bound::(), + )?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + + Ok(()) +} diff --git a/src/multipart/parse.rs b/src/multipart/parse.rs new file mode 100644 index 0000000..86fb7bc --- /dev/null +++ b/src/multipart/parse.rs @@ -0,0 +1,329 @@ +use anyhow::Result; +use buf_read_ext::BufReadExt; +use http::{ + header::{self, HeaderMap}, + HeaderName, HeaderValue, +}; +use mime::{self, Mime}; +use pyo3::{exceptions::PyStopIteration, prelude::*, types::PyBytes}; +use std::{ + borrow::Cow, + collections::VecDeque, + io::{BufRead, Cursor}, + mem, +}; + +use super::{ + errors::{error_parsing, error_state}, + parts::{FilePart, FilePartReader, Node, Part}, + utils::charset_decode, +}; + +enum MultiPartParserState { + Clean, + Termination, + Headers, + Value(Part), + File(FilePart), + Nested(Box), +} + +impl Default for MultiPartParserState { + fn default() -> Self { + Self::Clean + } +} + +struct MultiPartParser { + boundaries: (Vec, Vec, Vec), + encoding: String, + state: MultiPartParserState, + buffer: Vec, + pub consumed: bool, + stack: VecDeque, +} + +impl MultiPartParser { + fn new(boundaries: (Vec, Vec, Vec), encoding: String) -> Self { + Self { + boundaries, + encoding, + state: MultiPartParserState::Clean, + buffer: Vec::new(), + consumed: false, + stack: VecDeque::new(), + } + } + + fn take(self) -> VecDeque { + self.stack + } + + fn parse_chunk(&mut self, reader: &mut Cursor) -> Result<()> + where + T: AsRef<[u8]>, + { + let (lt, ltlt, lt_boundary) = &self.boundaries; + + loop { + let peeker = reader.fill_buf()?; + + if let MultiPartParserState::Clean = self.state { + // If the last chunk is empty and we're in clean state there's nothing to do. + if peeker.is_empty() { + return Ok(()); + } + + // If the next two lookahead characters are '--', parsing is finished. + if peeker.len() >= 2 && &peeker[..2] == b"--" { + self.consumed = true; + return Ok(()); + } + + self.state = MultiPartParserState::Termination; + } + + if let MultiPartParserState::Termination = self.state { + // Read the line terminator after the boundary + let (_, found) = reader.stream_until_token(lt, &mut self.buffer)?; + if !found { + return Ok(()); + } + + self.buffer.truncate(0); + self.state = MultiPartParserState::Headers; + } + + if let MultiPartParserState::Headers = self.state { + // Read the headers (which end in 2 line terminators) + let (_, found) = reader.stream_until_token(ltlt, &mut self.buffer)?; + if !found { + return Ok(()); + } + + // Keep the 2 line terminators as httparse will expect it + self.buffer.extend(ltlt.iter().copied()); + + let part_headers = { + let mut header_memory = [httparse::EMPTY_HEADER; 4]; + match httparse::parse_headers(&self.buffer, &mut header_memory) { + Ok(httparse::Status::Complete((_, raw_headers))) => { + let mut headers = HeaderMap::new(); + for header in raw_headers { + let name = HeaderName::try_from(header.name)?; + let value = HeaderValue::from_bytes(header.value)?; + headers.insert(name, value); + } + Ok::(headers) + } + Ok(httparse::Status::Partial) => Err(error_parsing!("incomplete headers")), + Err(_) => Err(error_parsing!("bad headers")), + }? + }; + + // Check for a nested multipart + let mut nested = false; + if let Some(ct) = part_headers.get(header::CONTENT_TYPE) { + let mime: Mime = ct.to_str()?.parse()?; + if mime.type_() == "multipart" { + nested = true; + } + } + if nested { + let inner = MultiPartParser::new(self.boundaries.clone(), self.encoding.clone()); + self.state = MultiPartParserState::Nested(Box::new(inner)); + continue; + } + + let mut is_file = false; + if let Some(cd) = part_headers.get(header::CONTENT_DISPOSITION) { + let cds = charset_decode(&self.encoding, cd.as_bytes())?; + let cd_type = cds.split(';').next().unwrap_or(""); + if cd_type == "attachment" { + is_file = true; + } else { + let cd_params = cds.split_once(';').unwrap_or(("", "")).1; + let mime: Mime = match format!("*/*; {cd_params}").parse() { + Ok(v) => v, + _ => Err::(error_parsing!("foo"))?, + }; + is_file = mime.get_param("filename").is_some(); + } + }; + + match is_file { + true => { + let filepart = FilePart::new(part_headers, &self.encoding)?; + self.state = MultiPartParserState::File(filepart); + } + false => { + let part = Part::new(part_headers, &self.encoding)?; + self.state = MultiPartParserState::Value(part); + } + } + } + + if let MultiPartParserState::Nested(nested) = &mut self.state { + let nres = nested.parse_chunk(reader); + if nres.is_err() || !nested.consumed { + return nres; + } + + let state = mem::take(&mut self.state); + match state { + MultiPartParserState::Nested(nested) => { + let nodes = nested.take(); + self.stack.extend(nodes); + } + _ => unreachable!(), + } + } + + if let MultiPartParserState::Value(part) = &mut self.state { + let (_, found) = reader.stream_until_token(lt_boundary, &mut part.value)?; + if !found { + return Ok(()); + } + + let state = mem::take(&mut self.state); + match state { + MultiPartParserState::Value(part) => { + self.stack.push_back(Node::Part(part)); + } + _ => unreachable!(), + } + } + + if let MultiPartParserState::File(filepart) = &mut self.state { + let (read, found) = reader.stream_until_token( + lt_boundary, + &mut filepart.file.as_mut().expect("uninitialized file part"), + )?; + let size = filepart.size.unwrap_or(0); + filepart.size = Some(size + read); + + if !found { + return Ok(()); + } + + let state = mem::take(&mut self.state); + match state { + MultiPartParserState::File(part) => { + self.stack.push_back(Node::File(part)); + } + _ => unreachable!(), + } + } + } + } +} + +#[pyclass(module = "emmett_core._emmett_core")] +pub(super) struct MultiPartReader { + boundary: Vec, + encoding: String, + inner: Option, +} + +#[pymethods] +impl MultiPartReader { + #[new] + fn new(content_type_header_value: &str) -> Result { + let (boundary, charset) = get_multipart_params(content_type_header_value)?; + Ok(Self { + boundary, + encoding: charset, + inner: None, + }) + } + + fn parse(&mut self, data: Cow<[u8]>) -> Result<()> { + if let Some(inner) = &mut self.inner { + let mut reader = Cursor::new(data); + return inner.parse_chunk(&mut reader); + } + + let mut buf = Vec::new(); + let mut reader = Cursor::new(data); + let (_, found) = reader.stream_until_token(&self.boundary, &mut buf)?; + if !found { + return Err(error_parsing!("EOF before first boundary")); + } + + let read_boundaries = { + let peeker = reader.fill_buf()?; + if peeker.len() > 1 && &peeker[..2] == b"\r\n" { + let mut output = Vec::with_capacity(2 + self.boundary.len()); + output.push(b'\r'); + output.push(b'\n'); + output.extend(self.boundary.clone()); + (vec![b'\r', b'\n'], vec![b'\r', b'\n', b'\r', b'\n'], output) + } else if !peeker.is_empty() && peeker[0] == b'\n' { + let mut output = Vec::with_capacity(1 + self.boundary.len()); + output.push(b'\n'); + output.extend(self.boundary.clone()); + (vec![b'\n'], vec![b'\n', b'\n'], output) + } else { + return Err(error_parsing!("no CrLf after boundary")); + } + }; + self.inner = Some(MultiPartParser::new(read_boundaries, self.encoding.clone())); + self.inner.as_mut().unwrap().parse_chunk(&mut reader) + } + + fn contents(&mut self, py: Python) -> Result> { + if let Some(mut inner) = self.inner.take() { + if !inner.consumed { + return Err(error_state!()); + } + let nodes = mem::take(&mut inner.stack); + return Ok(Py::new(py, MultiPartContentsIter { inner: nodes })?); + } + Err(error_state!()) + } +} + +#[pyclass(module = "emmett_core._emmett_core")] +pub(super) struct MultiPartContentsIter { + inner: VecDeque, +} + +#[pymethods] +impl MultiPartContentsIter { + fn __iter__(pyself: PyRef) -> PyRef { + pyself + } + + fn __next__(&mut self, py: Python) -> PyResult<(String, bool, PyObject)> { + if let Some(item) = self.inner.pop_front() { + return match item { + Node::Part(node) => Ok((node.name, false, PyBytes::new_bound(py, &node.value[..]).into_py(py))), + Node::File(node) => Ok(( + node.name.clone(), + true, + Py::new(py, FilePartReader::new(node)?)?.into_py(py), + )), + }; + } + Err(PyStopIteration::new_err(py.None())) + } +} + +fn get_multipart_params(content_type_header_value: &str) -> Result<(Vec, String)> { + let mime: mime::Mime = content_type_header_value.parse()?; + if mime.type_() != mime::MULTIPART { + return Err(error_parsing!("not multipart")); + } + + if let Some(raw_boundary) = mime.get_param(mime::BOUNDARY) { + let rbs = raw_boundary.as_str(); + let mut boundary = Vec::with_capacity(2 + rbs.len()); + boundary.extend(b"--".iter().copied()); + boundary.extend(rbs.as_bytes()); + + let charset = mime.get_param(mime::CHARSET).map_or("utf-8", |v| v.as_str()); + return Ok((boundary, charset.to_owned())); + } + + Err(error_parsing!("boundary not specified")) +} diff --git a/src/multipart/parts.rs b/src/multipart/parts.rs new file mode 100644 index 0000000..0f2aa79 --- /dev/null +++ b/src/multipart/parts.rs @@ -0,0 +1,180 @@ +use anyhow::Result; +use http::header::{self, HeaderMap}; +use pyo3::{exceptions::PyStopIteration, prelude::*, types::PyBytes}; +use std::{ + fs::File, + io::{BufRead, BufReader, Read}, + os::unix::fs::MetadataExt, + path::PathBuf, +}; +use textnonce::TextNonce; + +use super::{ + errors::{error_io, error_parsing}, + utils::get_mime_param_encoded, +}; +use crate::http::headers::get_content_type; + +pub(super) enum Node { + Part(Part), + File(FilePart), +} + +pub(super) struct Part { + pub name: String, + pub value: Vec, +} +impl Part { + pub fn new(headers: HeaderMap, encoding: &str) -> Result { + let cd = headers + .get(header::CONTENT_DISPOSITION) + .ok_or::(error_parsing!("missing content disposition"))?; + let name = get_mime_param_encoded(cd, "name", encoding)? + .ok_or::(error_parsing!("missing name field"))?; + + Ok(Self { + name, + value: Vec::new(), + }) + } +} + +pub(super) struct FilePart { + pub headers: HeaderMap, + pub name: String, + filename: Option, + path: PathBuf, + pub file: Option, + pub size: Option, + tempdir: Option, +} + +impl FilePart { + pub fn new(headers: HeaderMap, encoding: &str) -> Result { + let cd = headers + .get(header::CONTENT_DISPOSITION) + .ok_or::(error_parsing!("missing content disposition"))?; + let name = get_mime_param_encoded(cd, "name", encoding)? + .ok_or::(error_parsing!("missing name field"))?; + let filename = get_mime_param_encoded(cd, "filename", encoding)?; + let mut path = tempfile::Builder::new().prefix("mime_multipart").tempdir()?.into_path(); + let tempdir = Some(path.clone()); + path.push(TextNonce::sized_urlsafe(32).unwrap().into_string()); + + let file = File::create(path.clone())?; + + Ok(FilePart { + headers, + name, + filename, + path, + file: Some(file), + size: None, + tempdir, + }) + } + + #[inline] + pub fn filename(&self) -> Option { + self.filename.clone() + } + + #[inline] + pub fn content_type(&self) -> Result> { + if let Some(cd) = self.headers.get(header::CONTENT_TYPE) { + return Ok(Some(get_content_type(cd.to_str()?)?)); + } + Ok(None) + } + + #[inline] + pub fn content_length(&self) -> Option { + if let Some(cl) = self.headers.get(header::CONTENT_LENGTH) { + return cl.to_str().unwrap_or("0").parse::().map(Some).unwrap_or(None); + } + None + } +} + +impl Drop for FilePart { + fn drop(&mut self) { + if self.tempdir.is_some() { + let _ = std::fs::remove_file(&self.path); + let _ = std::fs::remove_dir(self.tempdir.as_ref().unwrap()); + } + } +} + +#[pyclass(module = "emmett_core._emmett_core")] +pub(super) struct FilePartReader { + inner: FilePart, + reader: BufReader, + size: u64, +} + +impl FilePartReader { + pub fn new(mut inner: FilePart) -> Result { + drop(inner.file.take().expect("uninitialized file part")); + let file = File::open(inner.path.clone()).map_err::(|_| error_io!())?; + let size = file.metadata().unwrap().len(); + let reader = BufReader::with_capacity(4096, file); + Ok(Self { inner, reader, size }) + } + + #[inline] + fn read_chunk(&mut self, size: usize) -> Result> { + self.reader.fill_buf()?; + let mut buf = Vec::with_capacity(size); + self.reader.read_exact(&mut buf)?; + Ok(buf) + } + + fn read_all(&mut self) -> Result> { + self.reader.fill_buf()?; + let mut buf = Vec::new(); + self.reader.read_to_end(&mut buf)?; + Ok(buf) + } +} + +#[pymethods] +impl FilePartReader { + #[getter(content_type)] + fn get_content_type(&self) -> Option { + self.inner.content_type().unwrap_or(None) + } + + #[getter(content_length)] + fn get_content_length(&self) -> u64 { + if let Some(v) = self.inner.content_length() { + return v; + } + self.size + } + + #[getter(filename)] + fn get_filename(&self) -> Option { + self.inner.filename() + } + + #[pyo3(signature = (size = None))] + fn read<'p>(&mut self, py: Python<'p>, size: Option) -> Result> { + let buf = match size { + Some(size) => self.read_chunk(size), + None => self.read_all(), + }?; + Ok(PyBytes::new_bound(py, &buf[..])) + } + + fn __iter__(pyself: PyRef) -> PyRef { + pyself + } + + fn __next__<'p>(&mut self, py: Python<'p>) -> Result> { + let buf = self.read_chunk(4096)?; + if buf.is_empty() { + return Err(PyStopIteration::new_err(py.None()).into()); + } + Ok(PyBytes::new_bound(py, &buf[..])) + } +} diff --git a/src/multipart/utils.rs b/src/multipart/utils.rs new file mode 100644 index 0000000..4e1d345 --- /dev/null +++ b/src/multipart/utils.rs @@ -0,0 +1,45 @@ +use anyhow::Result; +use encoding::{all as encoders, DecoderTrap, Encoding}; +use http::HeaderValue; +use mime::Mime; + +use super::errors::error_encoding; + +#[allow(dead_code)] +#[inline] +pub(super) fn get_mime_param(hv: &HeaderValue, param: &str) -> Result> { + let hvs = hv.to_str().unwrap_or(""); + let (_ty, params) = hvs.split_once(';').unwrap_or(("", "")); + let mime: Mime = format!("*/*; {params}").parse()?; + Ok(mime.get_param(param).map(|v| v.to_string())) +} + +#[inline] +pub(super) fn get_mime_param_encoded(hv: &HeaderValue, param: &str, encoding: &str) -> Result> { + let hvs = charset_decode(encoding, hv.as_bytes()).unwrap_or_default(); + let (_ty, params) = hvs.split_once(';').unwrap_or(("", "")); + let mime: Mime = format!("*/*; {params}").parse()?; + Ok(mime.get_param(param).map(|v| v.to_string())) +} + +pub(super) fn charset_decode(charset: &str, bytes: &[u8]) -> Result { + match charset { + "us-ascii" => encoders::ASCII.decode(bytes, DecoderTrap::Strict), + "iso-8859-1" => encoders::ISO_8859_1.decode(bytes, DecoderTrap::Strict), + "iso-8859-2" => encoders::ISO_8859_2.decode(bytes, DecoderTrap::Strict), + "iso-8859-3" => encoders::ISO_8859_3.decode(bytes, DecoderTrap::Strict), + "iso-8859-4" => encoders::ISO_8859_4.decode(bytes, DecoderTrap::Strict), + "iso-8859-5" => encoders::ISO_8859_5.decode(bytes, DecoderTrap::Strict), + "iso-8859-6" => encoders::ISO_8859_6.decode(bytes, DecoderTrap::Strict), + "iso-8859-7" => encoders::ISO_8859_7.decode(bytes, DecoderTrap::Strict), + "iso-8859-8" => encoders::ISO_8859_8.decode(bytes, DecoderTrap::Strict), + "iso-8859-10" => encoders::ISO_8859_10.decode(bytes, DecoderTrap::Strict), + "euc-jp" => encoders::EUC_JP.decode(bytes, DecoderTrap::Strict), + "iso-2022-jp" => encoders::ISO_2022_JP.decode(bytes, DecoderTrap::Strict), + "big5" => encoders::BIG5_2003.decode(bytes, DecoderTrap::Strict), + "koi8-r" => encoders::KOI8_R.decode(bytes, DecoderTrap::Strict), + "utf-8" => encoders::UTF_8.decode(bytes, DecoderTrap::Strict), + _ => Err("no encoder".into()), + } + .map_err(|_err| error_encoding!()) +} diff --git a/tests/multipart/conftest.py b/tests/multipart/conftest.py new file mode 100644 index 0000000..067c2b3 --- /dev/null +++ b/tests/multipart/conftest.py @@ -0,0 +1,43 @@ +import pytest + +from emmett_core.app import App as _App +from emmett_core.ctx import Current +from emmett_core.protocols.rsgi.test_client.client import EmmettTestClient +from emmett_core.routing.router import HTTPRouter, WebsocketRouter + + +class App(_App): + def _init_routers(self, url_prefix): + pass + + def _init_handlers(self): + pass + + def _register_with_ctx(self): + pass + + def _init_with_test_env(self, current): + self._router_http = HTTPRouter(self, current) + self._router_ws = WebsocketRouter(self, current) + current.app = self + + +@pytest.fixture(scope="function") +def current(): + return Current() + + +@pytest.fixture(scope="function") +def app(current): + class TestClient(EmmettTestClient): + _current = current + + App.test_client_class = TestClient + rv = App(__name__) + rv._init_with_test_env(current) + return rv + + +@pytest.fixture(scope="function") +def client(app): + return app.test_client() diff --git a/tests/multipart/test_multipart.py b/tests/multipart/test_multipart.py new file mode 100644 index 0000000..3fda85e --- /dev/null +++ b/tests/multipart/test_multipart.py @@ -0,0 +1,303 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from emmett_core.serializers import Serializers + + +@pytest.fixture(scope="function") +def multipart_client(current, client): + app = client.application + json_dump = Serializers.get_for("json") + + @app.route("/", output="str") + async def multipart(): + rv = {"params": {}, "files": {}} + params = await current.request.body_params + files = await current.request.files + for key, val in params.items(): + accum = [] + if not isinstance(val, list): + val = [val] + for item in val: + accum.append(item) + rv["params"][key] = accum + for key, val in files.items(): + accum = [] + if not isinstance(val, list): + val = [val] + for item in val: + accum.append( + { + "filename": item.filename, + "size": item.size, + "content": item.read().decode("utf8"), + "content_type": item.content_type, + } + ) + rv["files"][key] = accum + return json_dump(rv) + + return client + + +def test_multipart_request_data(multipart_client): + response = multipart_client.post("/", data={"some": "data"}, content_type="multipart/form-data") + assert response.json() == {"params": {"some": ["data"]}, "files": {}} + + +def test_multipart_request_files(tmpdir: Path, multipart_client): + path = tmpdir / "test.txt" + with path.open("wb") as file: + file.write(b"") + + with path.open("rb") as f: + response = multipart_client.post("/", data={"test": f}) + assert response.json() == { + "params": {}, + "files": { + "test": [ + { + "filename": str(path), + "size": 14, + "content": "", + "content_type": "text/plain", + } + ] + }, + } + + +def test_multipart_request_files_with_content_type(tmpdir: Path, multipart_client): + path = tmpdir / "test.txt" + with path.open("wb") as file: + file.write(b"") + + with path.open("rb") as f: + response = multipart_client.post("/", data={"test": (f, "test.txt", "text/plain")}) + assert response.json() == { + "params": {}, + "files": { + "test": [ + { + "filename": "test.txt", + "size": 14, + "content": "", + "content_type": "text/plain", + } + ] + }, + } + + +def test_multipart_request_multiple_files(tmpdir: Path, multipart_client): + path1 = tmpdir / "test1.txt" + with path1.open("wb") as file: + file.write(b"") + + path2 = tmpdir / "test2.txt" + with path2.open("wb") as file: + file.write(b"") + + with path1.open("rb") as f1, path2.open("rb") as f2: + response = multipart_client.post( + "/", data={"test1": (f1, "test1.txt", "text/plain"), "test2": (f2, "test2.txt", "text/plain")} + ) + assert response.json() == { + "params": {}, + "files": { + "test1": [ + { + "filename": "test1.txt", + "size": 15, + "content": "", + "content_type": "text/plain", + } + ], + "test2": [ + { + "filename": "test2.txt", + "size": 15, + "content": "", + "content_type": "text/plain", + } + ], + }, + } + + +def test_multipart_multi_items(tmpdir: Path, multipart_client): + path1 = tmpdir / "test1.txt" + with path1.open("wb") as file: + file.write(b"") + + path2 = tmpdir / "test2.txt" + with path2.open("wb") as file: + file.write(b"") + + with path1.open("rb") as f1, path2.open("rb") as f2: + response = multipart_client.post( + "/", data={"test1": ["abc", (f1, "test1.txt", "text/plain"), (f2, "test2.txt", "text/plain")]} + ) + assert response.json() == { + "params": {"test1": ["abc"]}, + "files": { + "test1": [ + { + "filename": "test1.txt", + "size": 15, + "content": "", + "content_type": "text/plain", + }, + { + "filename": "test2.txt", + "size": 15, + "content": "", + "content_type": "text/plain", + }, + ] + }, + } + + +def test_multipart_request_mixed_files_and_data(multipart_client): + response = multipart_client.post( + "/", + data=( + # data + b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" + b'Content-Disposition: form-data; name="field0"\r\n\r\n' + b"value0\r\n" + # file + b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" + b'Content-Disposition: form-data; name="file"; filename="file.txt"\r\n' + b"Content-Type: text/plain\r\n\r\n" + b"\r\n" + # data + b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" + b'Content-Disposition: form-data; name="field1"\r\n\r\n' + b"value1\r\n" + b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n" + ), + headers=[("content-type", "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")], + ) + assert response.json() == { + "files": { + "file": [ + { + "filename": "file.txt", + "size": 14, + "content": "", + "content_type": "text/plain", + } + ], + }, + "params": { + "field0": ["value0"], + "field1": ["value1"], + }, + } + + +def test_multipart_request_with_charset_for_filename(multipart_client): + response = multipart_client.post( + "/", + data=( + # file + b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" + b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' + b"Content-Type: text/plain\r\n\r\n" + b"\r\n" + b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n" + ), + headers=[("content-type", "multipart/form-data; charset=utf-8; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")], + ) + assert response.json() == { + "params": {}, + "files": { + "file": [ + { + "filename": "文書.txt", + "size": 14, + "content": "", + "content_type": "text/plain", + } + ] + }, + } + + +def test_multipart_request_without_charset_for_filename(multipart_client): + response = multipart_client.post( + "/", + data=( + # file + b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" + b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n' + b"Content-Type: image/jpeg\r\n\r\n" + b"\r\n" + b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n" + ), + headers=[("content-type", "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")], + ) + assert response.json() == { + "params": {}, + "files": { + "file": [ + { + "filename": "画像.jpg", + "size": 14, + "content": "", + "content_type": "image/jpeg", + } + ], + }, + } + + +def test_multipart_request_with_encoded_value(multipart_client): + response = multipart_client.post( + "/", + data=( + b"--20b303e711c4ab8c443184ac833ab00f\r\n" + b"Content-Disposition: form-data; " + b'name="value"\r\n\r\n' + b"Transf\xc3\xa9rer\r\n" + b"--20b303e711c4ab8c443184ac833ab00f--\r\n" + ), + headers=[("content-type", "multipart/form-data; charset=utf-8; boundary=20b303e711c4ab8c443184ac833ab00f")], + ) + assert response.json() == {"params": {"value": ["Transférer"]}, "files": {}} + + +def test_missing_boundary_parameter(multipart_client): + res = multipart_client.post( + "/", + data=( + # file + b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' + b"Content-Type: text/plain\r\n\r\n" + b"\r\n" + ), + headers=[("content-type", "multipart/form-data; charset=utf-8")], + ) + assert res.status == 400 + assert res.data == "Invalid multipart data" + + +def test_missing_name_parameter_on_content_disposition(multipart_client): + res = multipart_client.post( + "/", + data=( + # data + b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" b'Content-Disposition: form-data; ="field0"\r\n\r\n' b"value0\r\n" + ), + headers=[("content-type", "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")], + ) + assert res.status == 400 + assert res.data == "Invalid multipart data" + + +# TODO: save diff --git a/tests/routing/test_router.py b/tests/routing/test_router.py index 45d45a0..f20be33 100644 --- a/tests/routing/test_router.py +++ b/tests/routing/test_router.py @@ -341,7 +341,6 @@ def test_routing_with_scheme(routing_ctx_scheme): assert route.name == "test_router.test_route" with routing_ctx_scheme.ctx("/test2/1/test") as ctx: - print(ctx.wrapper.scheme) route, _ = routing_ctx_scheme.router.match(ctx.wrapper) assert not route