Skip to content

Commit

Permalink
fix(h2_connection_reset): add stream reset when client cancel the req…
Browse files Browse the repository at this point in the history
…uest
  • Loading branch information
MarkLux committed Sep 4, 2024
1 parent 7d87c9d commit 5997901
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
14 changes: 12 additions & 2 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import enum
import logging
import time
Expand Down Expand Up @@ -401,6 +402,9 @@ async def _receive_remote_settings_change(self, event: h2.events.Event) -> None:
await self._max_streams_semaphore.acquire()
self._max_streams -= 1

async def _reset_steam(self, stream_id: int, error_code: int) -> None:
self._h2_state.reset_stream(stream_id=stream_id, error_code=error_code)

async def _response_closed(self, stream_id: int) -> None:
await self._max_streams_semaphore.release()
del self._events[stream_id]
Expand Down Expand Up @@ -578,12 +582,18 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
# we want to close the response (and possibly the connection)
# before raising that exception.
with AsyncShieldCancellation():
await self.aclose()
# close the stream with cancel
await self.aclose(cancel_stream=isinstance(exc, asyncio.exceptions.CancelledError))
raise exc

async def aclose(self) -> None:
async def aclose(self, cancel_stream: bool = False) -> None:
if not self._closed:
self._closed = True
kwargs = {"stream_id": self._stream_id}
async with Trace("response_closed", logger, self._request, kwargs):
if cancel_stream:
await self._connection._reset_steam(
stream_id=self._stream_id,
error_code=h2.settings.ErrorCodes.CANCEL,
)
await self._connection._response_closed(stream_id=self._stream_id)
13 changes: 11 additions & 2 deletions httpcore/_sync/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,9 @@ def _receive_remote_settings_change(self, event: h2.events.Event) -> None:
self._max_streams_semaphore.acquire()
self._max_streams -= 1

def _reset_steam(self, stream_id: int, error_code: int) -> None:
self._h2_state.reset_stream(stream_id=stream_id, error_code=error_code)

def _response_closed(self, stream_id: int) -> None:
self._max_streams_semaphore.release()
del self._events[stream_id]
Expand Down Expand Up @@ -578,12 +581,18 @@ def __iter__(self) -> typing.Iterator[bytes]:
# we want to close the response (and possibly the connection)
# before raising that exception.
with ShieldCancellation():
self.close()
# close the stream with cancel
self.close(cancel_stream=True)
raise exc

def close(self) -> None:
def close(self, cancel_stream: bool = False) -> None:
if not self._closed:
self._closed = True
kwargs = {"stream_id": self._stream_id}
with Trace("response_closed", logger, self._request, kwargs):
if cancel_stream:
self._connection._reset_steam(
stream_id=self._stream_id,
error_code=h2.settings.ErrorCodes.CANCEL,
)
self._connection._response_closed(stream_id=self._stream_id)

0 comments on commit 5997901

Please sign in to comment.