Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,16 +240,18 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")

response_complete = False
async for sse in event_source.aiter_sse(): # pragma: no branch
if response_complete:
continue
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
await event_source.response.aclose()
break
response_complete = True

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
Expand Down Expand Up @@ -342,6 +344,7 @@ async def _handle_sse_response(

try:
event_source = EventSource(response)
response_complete = False
async for sse in event_source.aiter_sse(): # pragma: no branch
# Track last event ID for potential reconnection
if sse.id:
Expand All @@ -351,6 +354,9 @@ async def _handle_sse_response(
if sse.retry is not None:
retry_interval_ms = sse.retry

if response_complete:
continue

is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
Expand All @@ -359,10 +365,11 @@ async def _handle_sse_response(
is_initialization=is_initialization,
)
# If the SSE event indicates completion, like returning response/error
# break the loop
# keep draining the stream so the underlying HTTP connection remains reusable.
if is_complete:
await response.aclose()
return # Normal completion, no reconnect needed
response_complete = True
if response_complete:
return # Normal completion, no reconnect needed
except Exception:
logger.debug("SSE stream ended", exc_info=True) # pragma: no cover

Expand Down Expand Up @@ -404,27 +411,32 @@ async def _handle_reconnection(
# Track for potential further reconnection
reconnect_last_event_id: str = last_event_id
reconnect_retry_ms = retry_interval_ms
response_complete = False

async for sse in event_source.aiter_sse():
if sse.id: # pragma: no branch
reconnect_last_event_id = sse.id
if sse.retry is not None:
reconnect_retry_ms = sse.retry

if response_complete:
continue

is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
await event_source.response.aclose()
return
response_complete = True
if response_complete:
return

# Stream ended again without response - reconnect again (reset attempt counter)
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0)
except Exception as e: # pragma: no cover
except Exception as e:
logger.debug(f"Reconnection failed: {e}")
# Try to reconnect again if we still have an event ID
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)
Expand Down
4 changes: 2 additions & 2 deletions tests/interaction/transports/test_hosting_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,8 @@ async def collect(params: LoggingMessageNotificationParams) -> None:
http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION
tg.cancel_scope.cancel()

with anyio.fail_after(5): # pragma: no branch
release.set() # pragma: lax no cover — python/cpython#106749: 3.11 drops this line event
with anyio.fail_after(5): # pragma: lax no cover — python/cpython#106749: 3.11 drops this line event
release.set()
# init priming + init response + call priming + "first" + "second" + result = 6 stored events.
await store.wait_until_stored(6)
async with ( # pragma: no branch
Expand Down
155 changes: 153 additions & 2 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,16 @@
from starlette.requests import Request
from starlette.routing import Mount

import mcp.client.streamable_http as streamable_http_module
from mcp import MCPError, types
from mcp.client.session import ClientSession
from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client
from mcp.client.streamable_http import (
RequestContext as StreamableHTTPClientRequestContext,
)
from mcp.client.streamable_http import (
StreamableHTTPTransport,
streamable_http_client,
)
from mcp.server import Server, ServerRequestContext
from mcp.server.streamable_http import (
MCP_PROTOCOL_VERSION_HEADER,
Expand All @@ -45,7 +52,7 @@
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared._context import RequestContext
from mcp.shared._context_streams import create_context_streams
from mcp.shared._context_streams import ContextSendStream, create_context_streams
from mcp.shared._httpx_utils import (
MCP_DEFAULT_SSE_READ_TIMEOUT,
MCP_DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -1803,6 +1810,150 @@ async def test_handle_sse_event_skips_empty_data():
await read_stream.aclose()


class _FakeStreamResponse(httpx.Response):
def __init__(self) -> None:
super().__init__(200, request=httpx.Request("POST", "http://localhost:8000/mcp"))
self.closed_by_transport = False

async def aclose(self) -> None: # pragma: no cover
self.closed_by_transport = True
await super().aclose()


def _response_sse(request_id: int | str) -> ServerSentEvent:
return ServerSentEvent(
event="message",
data=json.dumps({"jsonrpc": "2.0", "id": request_id, "result": {}}),
id="response-event",
)


def _make_streamable_http_request_context(
request_id: int | str,
client: httpx.AsyncClient,
write_stream: ContextSendStream[SessionMessage | Exception],
) -> StreamableHTTPClientRequestContext:
return StreamableHTTPClientRequestContext(
client=client,
session_id=None,
session_message=SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=request_id, method="tools/list")),
metadata=None,
read_stream_writer=write_stream,
)


@pytest.mark.anyio
async def test_sse_response_drains_after_terminal_response(monkeypatch: pytest.MonkeyPatch):
"""Terminal POST SSE responses are drained instead of force-closed."""
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
response = _FakeStreamResponse()

class FakeEventSource:
def __init__(self, response: _FakeStreamResponse) -> None:
self.response = response

async def aiter_sse(self):
yield _response_sse(1)
yield ServerSentEvent(event="message", data="", id="drained-event")

async def fail_reconnect(*args: Any, **kwargs: Any) -> None: # pragma: no cover
raise AssertionError("terminal responses should not reconnect after draining")

monkeypatch.setattr(streamable_http_module, "EventSource", FakeEventSource)
monkeypatch.setattr(transport, "_handle_reconnection", fail_reconnect)

write_stream, read_stream = create_context_streams[SessionMessage | Exception](2)
async with httpx.AsyncClient() as client:
try:
ctx = _make_streamable_http_request_context(1, client, write_stream)
await transport._handle_sse_response(response, ctx)

assert response.closed_by_transport is False
message = await read_stream.receive()
assert isinstance(message, SessionMessage)
assert isinstance(message.message, types.JSONRPCResponse)
assert message.message.id == 1
finally:
await write_stream.aclose()
await read_stream.aclose()


@pytest.mark.anyio
async def test_reconnection_drains_after_terminal_response(monkeypatch: pytest.MonkeyPatch):
"""Resumed GET responses use EOF draining instead of response.aclose()."""
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
response = _FakeStreamResponse()

class FakeReconnectionEventSource:
def __init__(self, response: _FakeStreamResponse) -> None:
self.response = response

async def aiter_sse(self):
yield _response_sse("abc")
yield ServerSentEvent(event="message", data="", id="drained-event")

@asynccontextmanager
async def fake_aconnect_sse(*args: Any, **kwargs: Any):
yield FakeReconnectionEventSource(response)

monkeypatch.setattr(streamable_http_module, "aconnect_sse", fake_aconnect_sse)

write_stream, read_stream = create_context_streams[SessionMessage | Exception](2)
async with httpx.AsyncClient() as client:
try:
ctx = _make_streamable_http_request_context("abc", client, write_stream)
await transport._handle_reconnection(ctx, "previous-event", retry_interval_ms=0)

assert response.closed_by_transport is False
message = await read_stream.receive()
assert isinstance(message, SessionMessage)
assert isinstance(message.message, types.JSONRPCResponse)
assert message.message.id == "abc"
finally:
await write_stream.aclose()
await read_stream.aclose()


@pytest.mark.anyio
async def test_reconnection_retries_after_failed_resume(monkeypatch: pytest.MonkeyPatch):
"""A failed resume attempt falls back to the next reconnection attempt."""
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
response = _FakeStreamResponse()
attempts = 0

class FakeReconnectionEventSource:
def __init__(self, response: _FakeStreamResponse) -> None:
self.response = response

async def aiter_sse(self):
yield _response_sse("abc")

@asynccontextmanager
async def fake_aconnect_sse(*args: Any, **kwargs: Any):
nonlocal attempts
attempts += 1
if attempts == 1:
raise RuntimeError("resume failed")
yield FakeReconnectionEventSource(response)

monkeypatch.setattr(streamable_http_module, "aconnect_sse", fake_aconnect_sse)

write_stream, read_stream = create_context_streams[SessionMessage | Exception](2)
async with httpx.AsyncClient() as client:
try:
ctx = _make_streamable_http_request_context("abc", client, write_stream)
await transport._handle_reconnection(ctx, "previous-event", retry_interval_ms=0)

assert attempts == 2
message = await read_stream.receive()
assert isinstance(message, SessionMessage)
assert isinstance(message.message, types.JSONRPCResponse)
assert message.message.id == "abc"
finally:
await write_stream.aclose()
await read_stream.aclose()


@pytest.mark.anyio
async def test_priming_event_not_sent_for_old_protocol_version():
"""Test that _maybe_send_priming_event skips for old protocol versions (backwards compat)."""
Expand Down
Loading