diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index aa3e50e07e..b7fa507802 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -240,7 +240,10 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: event_source.response.raise_for_status() logger.debug("Resumption GET SSE connection established") + response_complete = False async for sse in event_source.aiter_sse(): # pragma: no branch + if response_complete: + continue is_complete = await self._handle_sse_event( sse, ctx.read_stream_writer, @@ -248,8 +251,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: ctx.metadata.on_resumption_token_update if ctx.metadata else None, ) if is_complete: - await event_source.response.aclose() - break + response_complete = True async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" @@ -342,6 +344,7 @@ async def _handle_sse_response( try: event_source = EventSource(response) + response_complete = False async for sse in event_source.aiter_sse(): # pragma: no branch # Track last event ID for potential reconnection if sse.id: @@ -351,6 +354,9 @@ async def _handle_sse_response( if sse.retry is not None: retry_interval_ms = sse.retry + if response_complete: + continue + is_complete = await self._handle_sse_event( sse, ctx.read_stream_writer, @@ -359,10 +365,11 @@ async def _handle_sse_response( is_initialization=is_initialization, ) # If the SSE event indicates completion, like returning response/error - # break the loop + # keep draining the stream so the underlying HTTP connection remains reusable. if is_complete: - await response.aclose() - return # Normal completion, no reconnect needed + response_complete = True + if response_complete: + return # Normal completion, no reconnect needed except Exception: logger.debug("SSE stream ended", exc_info=True) # pragma: no cover @@ -404,6 +411,7 @@ async def _handle_reconnection( # Track for potential further reconnection reconnect_last_event_id: str = last_event_id reconnect_retry_ms = retry_interval_ms + response_complete = False async for sse in event_source.aiter_sse(): if sse.id: # pragma: no branch @@ -411,6 +419,9 @@ async def _handle_reconnection( if sse.retry is not None: reconnect_retry_ms = sse.retry + if response_complete: + continue + is_complete = await self._handle_sse_event( sse, ctx.read_stream_writer, @@ -418,13 +429,14 @@ async def _handle_reconnection( ctx.metadata.on_resumption_token_update if ctx.metadata else None, ) if is_complete: - await event_source.response.aclose() - return + response_complete = True + if response_complete: + return # Stream ended again without response - reconnect again (reset attempt counter) logger.info("SSE stream disconnected, reconnecting...") await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0) - except Exception as e: # pragma: no cover + except Exception as e: logger.debug(f"Reconnection failed: {e}") # Try to reconnect again if we still have an event ID await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1) diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py index c7945d56c3..94ed169730 100644 --- a/tests/interaction/transports/test_hosting_resume.py +++ b/tests/interaction/transports/test_hosting_resume.py @@ -357,8 +357,8 @@ async def collect(params: LoggingMessageNotificationParams) -> None: http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION tg.cancel_scope.cancel() - with anyio.fail_after(5): # pragma: no branch - release.set() # pragma: lax no cover — python/cpython#106749: 3.11 drops this line event + with anyio.fail_after(5): # pragma: lax no cover — python/cpython#106749: 3.11 drops this line event + release.set() # init priming + init response + call priming + "first" + "second" + result = 6 stored events. await store.wait_until_stored(6) async with ( # pragma: no branch diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3d5770fb61..36590ee5ce 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -27,9 +27,16 @@ from starlette.requests import Request from starlette.routing import Mount +import mcp.client.streamable_http as streamable_http_module from mcp import MCPError, types from mcp.client.session import ClientSession -from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client +from mcp.client.streamable_http import ( + RequestContext as StreamableHTTPClientRequestContext, +) +from mcp.client.streamable_http import ( + StreamableHTTPTransport, + streamable_http_client, +) from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, @@ -45,7 +52,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._context import RequestContext -from mcp.shared._context_streams import create_context_streams +from mcp.shared._context_streams import ContextSendStream, create_context_streams from mcp.shared._httpx_utils import ( MCP_DEFAULT_SSE_READ_TIMEOUT, MCP_DEFAULT_TIMEOUT, @@ -1803,6 +1810,150 @@ async def test_handle_sse_event_skips_empty_data(): await read_stream.aclose() +class _FakeStreamResponse(httpx.Response): + def __init__(self) -> None: + super().__init__(200, request=httpx.Request("POST", "http://localhost:8000/mcp")) + self.closed_by_transport = False + + async def aclose(self) -> None: # pragma: no cover + self.closed_by_transport = True + await super().aclose() + + +def _response_sse(request_id: int | str) -> ServerSentEvent: + return ServerSentEvent( + event="message", + data=json.dumps({"jsonrpc": "2.0", "id": request_id, "result": {}}), + id="response-event", + ) + + +def _make_streamable_http_request_context( + request_id: int | str, + client: httpx.AsyncClient, + write_stream: ContextSendStream[SessionMessage | Exception], +) -> StreamableHTTPClientRequestContext: + return StreamableHTTPClientRequestContext( + client=client, + session_id=None, + session_message=SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=request_id, method="tools/list")), + metadata=None, + read_stream_writer=write_stream, + ) + + +@pytest.mark.anyio +async def test_sse_response_drains_after_terminal_response(monkeypatch: pytest.MonkeyPatch): + """Terminal POST SSE responses are drained instead of force-closed.""" + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") + response = _FakeStreamResponse() + + class FakeEventSource: + def __init__(self, response: _FakeStreamResponse) -> None: + self.response = response + + async def aiter_sse(self): + yield _response_sse(1) + yield ServerSentEvent(event="message", data="", id="drained-event") + + async def fail_reconnect(*args: Any, **kwargs: Any) -> None: # pragma: no cover + raise AssertionError("terminal responses should not reconnect after draining") + + monkeypatch.setattr(streamable_http_module, "EventSource", FakeEventSource) + monkeypatch.setattr(transport, "_handle_reconnection", fail_reconnect) + + write_stream, read_stream = create_context_streams[SessionMessage | Exception](2) + async with httpx.AsyncClient() as client: + try: + ctx = _make_streamable_http_request_context(1, client, write_stream) + await transport._handle_sse_response(response, ctx) + + assert response.closed_by_transport is False + message = await read_stream.receive() + assert isinstance(message, SessionMessage) + assert isinstance(message.message, types.JSONRPCResponse) + assert message.message.id == 1 + finally: + await write_stream.aclose() + await read_stream.aclose() + + +@pytest.mark.anyio +async def test_reconnection_drains_after_terminal_response(monkeypatch: pytest.MonkeyPatch): + """Resumed GET responses use EOF draining instead of response.aclose().""" + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") + response = _FakeStreamResponse() + + class FakeReconnectionEventSource: + def __init__(self, response: _FakeStreamResponse) -> None: + self.response = response + + async def aiter_sse(self): + yield _response_sse("abc") + yield ServerSentEvent(event="message", data="", id="drained-event") + + @asynccontextmanager + async def fake_aconnect_sse(*args: Any, **kwargs: Any): + yield FakeReconnectionEventSource(response) + + monkeypatch.setattr(streamable_http_module, "aconnect_sse", fake_aconnect_sse) + + write_stream, read_stream = create_context_streams[SessionMessage | Exception](2) + async with httpx.AsyncClient() as client: + try: + ctx = _make_streamable_http_request_context("abc", client, write_stream) + await transport._handle_reconnection(ctx, "previous-event", retry_interval_ms=0) + + assert response.closed_by_transport is False + message = await read_stream.receive() + assert isinstance(message, SessionMessage) + assert isinstance(message.message, types.JSONRPCResponse) + assert message.message.id == "abc" + finally: + await write_stream.aclose() + await read_stream.aclose() + + +@pytest.mark.anyio +async def test_reconnection_retries_after_failed_resume(monkeypatch: pytest.MonkeyPatch): + """A failed resume attempt falls back to the next reconnection attempt.""" + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") + response = _FakeStreamResponse() + attempts = 0 + + class FakeReconnectionEventSource: + def __init__(self, response: _FakeStreamResponse) -> None: + self.response = response + + async def aiter_sse(self): + yield _response_sse("abc") + + @asynccontextmanager + async def fake_aconnect_sse(*args: Any, **kwargs: Any): + nonlocal attempts + attempts += 1 + if attempts == 1: + raise RuntimeError("resume failed") + yield FakeReconnectionEventSource(response) + + monkeypatch.setattr(streamable_http_module, "aconnect_sse", fake_aconnect_sse) + + write_stream, read_stream = create_context_streams[SessionMessage | Exception](2) + async with httpx.AsyncClient() as client: + try: + ctx = _make_streamable_http_request_context("abc", client, write_stream) + await transport._handle_reconnection(ctx, "previous-event", retry_interval_ms=0) + + assert attempts == 2 + message = await read_stream.receive() + assert isinstance(message, SessionMessage) + assert isinstance(message.message, types.JSONRPCResponse) + assert message.message.id == "abc" + finally: + await write_stream.aclose() + await read_stream.aclose() + + @pytest.mark.anyio async def test_priming_event_not_sent_for_old_protocol_version(): """Test that _maybe_send_priming_event skips for old protocol versions (backwards compat)."""