From de72011630b807d9ed4f9609c3a5aeb309e41e2e Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 20:50:08 +0000 Subject: [PATCH 1/2] Run StreamableHTTP transport tests in process instead of over sockets Final installment of the in-process test migration: this was the last file spawning uvicorn subprocesses on bind-then-close ports with readiness polling, which races under pytest-xdist when two workers pick the same ephemeral port. Two tests in this file have flaked exactly that way under parallel load. All four subprocess servers (basic, JSON-response, event-store, context-aware) become in-process apps served through the interaction suite's StreamingASGITransport, held open by the session manager's run() context. Raw `requests` calls become httpx calls against the bridge client; the sync request-validation tests become anyio tests. The second-GET-409 test now holds the first stream open by construction, where the subprocess version noted it "might fail if the first stream fully closed before this runs". Assertions are unchanged, with documented exceptions now that the server handlers run as traced in-process code: - The long_running_with_checkpoints tool and the slow:// resource branch had no callers and are removed, so the expected tools/list count drops from 10 to 9 in five tests. - Dead defensive arms become asserts (sampling non-text fallback, close_sse_stream truthiness checks, the context server's unknown-tool fallthrough and request checks), and the event store's replay-from-unknown-event arm becomes a lookup that requires a stored event, since unreachable branches now fail branch coverage instead of hiding in an untraced subprocess. - test_client_crash_handled no longer sleeps between crashing clients; the bridge drains each client's teardown before the next connects. Three pragmas in src/mcp/server/streamable_http.py covered only by the formerly untraced subprocess (close_standalone_sse_stream, its session message callback, and the JSON-mode Accept rejection) are now executed by traced tests and removed. With the last wait_for_server user migrated, the helper is deleted from tests/test_helpers.py; run_uvicorn_in_thread stays for the websocket smoke test. --- src/mcp/server/streamable_http.py | 8 +- tests/shared/test_streamable_http.py | 2044 ++++++++++++-------------- tests/test_helpers.py | 28 - 3 files changed, 937 insertions(+), 1143 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 98948ff999..2cb4c0748e 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -207,7 +207,7 @@ def close_sse_stream(self, request_id: RequestId) -> None: send_stream.close() receive_stream.close() - def close_standalone_sse_stream(self) -> None: # pragma: no cover + def close_standalone_sse_stream(self) -> None: """Close the standalone GET SSE stream, triggering client reconnection. This method closes the HTTP connection for the standalone GET stream used @@ -221,8 +221,6 @@ def close_standalone_sse_stream(self) -> None: # pragma: no cover This is a no-op if there is no active standalone SSE stream. Requires event_store to be configured for events to be stored during the disconnect. - Currently, client reconnection for standalone GET streams is NOT - implemented - this is a known gap (see test_standalone_get_stream_reconnection). """ self.close_sse_stream(GET_STREAM_KEY) @@ -245,7 +243,7 @@ def _create_session_message( async def close_stream_callback() -> None: self.close_sse_stream(request_id) - async def close_standalone_stream_callback() -> None: # pragma: no cover + async def close_standalone_stream_callback() -> None: self.close_standalone_sse_stream() metadata = ServerMessageMetadata( @@ -421,7 +419,7 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se has_json, has_sse = self._check_accept_headers(request) if self.is_json_response_enabled: # For JSON-only responses, only require application/json - if not has_json: # pragma: no cover + if not has_json: response = self._create_error_response( "Not Acceptable: Client must accept application/json", HTTPStatus.NOT_ACCEPTABLE, diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3d5770fb61..9ed209b890 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1,16 +1,14 @@ """Tests for the StreamableHTTP server and client transport. -Contains tests for both server and client sides of the StreamableHTTP transport. +Contains tests for both server and client sides of the StreamableHTTP transport, driven +entirely in process. """ from __future__ import annotations as _annotations import json -import multiprocessing -import socket import time -import traceback -from collections.abc import AsyncIterator, Generator +from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass, field from typing import Any @@ -20,8 +18,6 @@ import anyio import httpx import pytest -import requests -import uvicorn from httpx_sse import ServerSentEvent from starlette.applications import Starlette from starlette.requests import Request @@ -46,11 +42,6 @@ from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._context import RequestContext from mcp.shared._context_streams import create_context_streams -from mcp.shared._httpx_utils import ( - MCP_DEFAULT_SSE_READ_TIMEOUT, - MCP_DEFAULT_TIMEOUT, - create_mcp_http_client, -) from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ( @@ -66,11 +57,10 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport # Test constants SERVER_NAME = "test_streamable_http_server" -TEST_SESSION_ID = "test-session-id-12345" INIT_REQUEST = { "jsonrpc": "2.0", "method": "initialize", @@ -82,9 +72,12 @@ "id": "init-1", } +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + # Helper functions -def extract_protocol_version_from_sse(response: requests.Response) -> str: +def extract_protocol_version_from_sse(response: httpx.Response) -> str: """Extract the negotiated protocol version from an SSE initialization response.""" assert response.headers.get("Content-Type") == "text/event-stream" for line in response.text.splitlines(): @@ -109,32 +102,23 @@ async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | self._events.append((stream_id, event_id, message)) return event_id - async def replay_events_after( # pragma: no cover + async def replay_events_after( self, last_event_id: EventId, send_callback: EventCallback, ) -> StreamId | None: """Replay events after the specified ID.""" - # Find the stream ID of the last event - target_stream_id = None - for stream_id, event_id, _ in self._events: - if event_id == last_event_id: - target_stream_id = stream_id - break - - if target_stream_id is None: - # If event ID not found, return None - return None + # Find the stream ID of the last event; clients always resume from a stored event. + target_stream_id = next(stream_id for stream_id, event_id, _ in self._events if event_id == last_event_id) # Convert last_event_id to int for comparison last_event_id_int = int(last_event_id) - # Replay only events from the same stream with ID > last_event_id + # Replay only events from the same stream with ID > last_event_id, skipping priming + # events (None message). for stream_id, event_id, message in self._events: - if stream_id == target_stream_id and int(event_id) > last_event_id_int: - # Skip priming events (None message) - if message is not None: - await send_callback(EventMessage(message, event_id)) + if stream_id == target_stream_id and message is not None and int(event_id) > last_event_id_int: + await send_callback(EventMessage(message, event_id)) return target_stream_id @@ -145,26 +129,23 @@ class ServerState: @asynccontextmanager -async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: # pragma: no cover +async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: yield ServerState() -async def _handle_read_resource( # pragma: no cover +async def _handle_read_resource( ctx: ServerRequestContext[ServerState], params: ReadResourceRequestParams ) -> ReadResourceResult: uri = str(params.uri) parsed = urlparse(uri) if parsed.scheme == "foobar": - text = f"Read {parsed.netloc}" - elif parsed.scheme == "slow": - await anyio.sleep(2.0) - text = f"Slow response from {parsed.netloc}" - else: - raise ValueError(f"Unknown resource: {uri}") - return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) + return ReadResourceResult( + contents=[TextResourceContents(uri=uri, text=f"Read {parsed.netloc}", mime_type="text/plain")] + ) + raise ValueError(f"Unknown resource: {uri}") -async def _handle_list_tools( # pragma: no cover +async def _handle_list_tools( ctx: ServerRequestContext[ServerState], params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -179,11 +160,6 @@ async def _handle_list_tools( # pragma: no cover description="A test tool that sends a notification", input_schema={"type": "object", "properties": {}}, ), - Tool( - name="long_running_with_checkpoints", - description="A long-running tool that sends periodic notifications", - input_schema={"type": "object", "properties": {}}, - ), Tool( name="test_sampling_tool", description="A tool that triggers server-side sampling", @@ -229,9 +205,7 @@ async def _handle_list_tools( # pragma: no cover ) -async def _handle_call_tool( # pragma: no cover - ctx: ServerRequestContext[ServerState], params: CallToolRequestParams -) -> CallToolResult: +async def _handle_call_tool(ctx: ServerRequestContext[ServerState], params: CallToolRequestParams) -> CallToolResult: name = params.name args = params.arguments or {} @@ -240,25 +214,6 @@ async def _handle_call_tool( # pragma: no cover await ctx.session.send_resource_updated(uri="http://test_resource") return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) - elif name == "long_running_with_checkpoints": - await ctx.session.send_log_message( - level="info", - data="Tool started", - logger="tool", - related_request_id=ctx.request_id, - ) - - await anyio.sleep(0.1) - - await ctx.session.send_log_message( - level="info", - data="Tool is almost done", - logger="tool", - related_request_id=ctx.request_id, - ) - - return CallToolResult(content=[TextContent(type="text", text="Completed!")]) - elif name == "test_sampling_tool": sampling_result = await ctx.session.create_message( messages=[ @@ -271,15 +226,12 @@ async def _handle_call_tool( # pragma: no cover related_request_id=ctx.request_id, ) - if sampling_result.content.type == "text": - response = sampling_result.content.text - else: - response = str(sampling_result.content) + assert sampling_result.content.type == "text" return CallToolResult( content=[ TextContent( type="text", - text=f"Response from sampling: {response}", + text=f"Response from sampling: {sampling_result.content.text}", ) ] ) @@ -361,8 +313,8 @@ async def _handle_call_tool( # pragma: no cover related_request_id=ctx.request_id, ) - if ctx.close_sse_stream: - await ctx.close_sse_stream() + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() await anyio.sleep(sleep_time) @@ -372,8 +324,8 @@ async def _handle_call_tool( # pragma: no cover await ctx.session.send_resource_updated(uri="http://notification_1") await anyio.sleep(0.1) - if ctx.close_standalone_sse_stream: - await ctx.close_standalone_sse_stream() + assert ctx.close_standalone_sse_stream is not None + await ctx.close_standalone_sse_stream() await anyio.sleep(1.5) await ctx.session.send_resource_updated(uri="http://notification_2") @@ -383,7 +335,7 @@ async def _handle_call_tool( # pragma: no cover return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) -def _create_server() -> Server[ServerState]: # pragma: no cover +def _create_server() -> Server[ServerState]: return Server( SERVER_NAME, lifespan=_server_lifespan, @@ -393,113 +345,58 @@ def _create_server() -> Server[ServerState]: # pragma: no cover ) -def create_app( +@asynccontextmanager +async def running_app( is_json_response_enabled: bool = False, event_store: EventStore | None = None, retry_interval: int | None = None, -) -> Starlette: # pragma: no cover - """Create a Starlette application for testing using the session manager. +) -> AsyncIterator[Starlette]: + """Serve the test server's streamable HTTP app in process for the duration. Args: is_json_response_enabled: If True, use JSON responses instead of SSE streams. event_store: Optional event store for testing resumability. retry_interval: Retry interval in milliseconds for SSE polling. """ - # Create server instance - server = _create_server() - - # Create the session manager - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) + # DNS-rebinding protection validates Host/Origin headers against a network attack that cannot + # exist for an in-process app; the protection itself is pinned by + # tests/server/test_streamable_http_security.py. session_manager = StreamableHTTPSessionManager( - app=server, + app=_create_server(), event_store=event_store, json_response=is_json_response_enabled, - security_settings=security_settings, + security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), retry_interval=retry_interval, ) + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + async with session_manager.run(): + yield app - # Create an ASGI application that uses the session manager - app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], - lifespan=lambda app: session_manager.run(), - ) - - return app +def make_client(app: Starlette, headers: dict[str, str] | None = None) -> httpx.AsyncClient: + """An httpx client served in process by `app`, with create_mcp_http_client's redirect default. -def run_server( - port: int, - is_json_response_enabled: bool = False, - event_store: EventStore | None = None, - retry_interval: int | None = None, -) -> None: # pragma: no cover - """Run the test server. - - Args: - port: Port to listen on. - is_json_response_enabled: If True, use JSON responses instead of SSE streams. - event_store: Optional event store for testing resumability. - retry_interval: Retry interval in milliseconds for SSE polling. + (Starlette's Mount 307-redirects the bare /mcp path to /mcp/, which the SDK's own client + factory follows.) """ - - app = create_app(is_json_response_enabled, event_store, retry_interval) - # Configure server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="info", - limit_concurrency=10, - timeout_keep_alive=5, - access_log=False, + return httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, headers=headers, follow_redirects=True ) - # Start the server - server = uvicorn.Server(config=config) - - # This is important to catch exceptions and prevent test hangs - try: - server.run() - except Exception: - traceback.print_exc() - - -# Test fixtures - using same approach as SSE tests -@pytest.fixture -def basic_server_port() -> int: - """Find an available port for the basic server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - +# Test fixtures @pytest.fixture -def json_server_port() -> int: - """Find an available port for the JSON response server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +async def basic_app() -> AsyncIterator[Starlette]: + """The test server's app with SSE response mode.""" + async with running_app() as app: + yield app @pytest.fixture -def basic_server(basic_server_port: int) -> Generator[None, None, None]: - """Start a basic server.""" - proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) +async def json_app() -> AsyncIterator[Starlette]: + """The test server's app with JSON response mode.""" + async with running_app(is_json_response_enabled=True) as app: + yield app @pytest.fixture @@ -509,82 +406,29 @@ def event_store() -> SimpleEventStore: @pytest.fixture -def event_server_port() -> int: - """Find an available port for the event store server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def event_server( - event_server_port: int, event_store: SimpleEventStore -) -> Generator[tuple[SimpleEventStore, str], None, None]: - """Start a server with event store and retry_interval enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": event_server_port, "event_store": event_store, "retry_interval": 500}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(event_server_port) - - yield event_store, f"http://127.0.0.1:{event_server_port}" - - # Clean up - proc.kill() - proc.join(timeout=2) - - -@pytest.fixture -def json_response_server(json_server_port: int) -> Generator[None, None, None]: - """Start a server with JSON response enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": json_server_port, "is_json_response_enabled": True}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(json_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) - - -@pytest.fixture -def basic_server_url(basic_server_port: int) -> str: - """Get the URL for the basic test server.""" - return f"http://127.0.0.1:{basic_server_port}" - - -@pytest.fixture -def json_server_url(json_server_port: int) -> str: - """Get the URL for the JSON response test server.""" - return f"http://127.0.0.1:{json_server_port}" +async def event_app(event_store: SimpleEventStore) -> AsyncIterator[tuple[SimpleEventStore, Starlette]]: + """The test server's app with an event store and retry_interval enabled.""" + async with running_app(event_store=event_store, retry_interval=500) as app: + yield event_store, app # Basic request validation tests -def test_accept_header_validation(basic_server: None, basic_server_url: str): - """Test that Accept header is properly validated.""" - # Test without Accept header (suppress requests library default Accept: */*) - session = requests.Session() - session.headers.pop("Accept") - response = session.post( - f"{basic_server_url}/mcp", - headers={"Content-Type": "application/json"}, - json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text +@pytest.mark.anyio +async def test_accept_header_validation(basic_app: Starlette) -> None: + """A POST without an Accept header is rejected with 406.""" + async with make_client(basic_app) as client: + # Suppress the httpx client default Accept: */* header + del client.headers["accept"] + response = await client.post( + "/mcp", + headers={"Content-Type": "application/json"}, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text +@pytest.mark.anyio @pytest.mark.parametrize( "accept_header", [ @@ -596,19 +440,21 @@ def test_accept_header_validation(basic_server: None, basic_server_url: str): "application/*;q=0.9, text/*;q=0.8", ], ) -def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str): - """Test that wildcard Accept headers are accepted per RFC 7231.""" - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": accept_header, - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 +async def test_accept_header_wildcard(basic_app: Starlette, accept_header: str) -> None: + """Wildcard Accept headers are accepted per RFC 7231.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 +@pytest.mark.anyio @pytest.mark.parametrize( "accept_header", [ @@ -617,100 +463,104 @@ def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accep "text/*", ], ) -def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str): - """Test that incompatible Accept headers are rejected for SSE mode.""" - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": accept_header, - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - -def test_content_type_validation(basic_server: None, basic_server_url: str): - """Test that Content-Type header is properly validated.""" - # Test with incorrect Content-Type - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "text/plain", - }, - data="This is not JSON", - ) +async def test_accept_header_incompatible(basic_app: Starlette, accept_header: str) -> None: + """Accept headers that cannot cover both response representations are rejected for SSE mode.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text - assert response.status_code == 400 - assert "Invalid Content-Type" in response.text +@pytest.mark.anyio +async def test_content_type_validation(basic_app: Starlette) -> None: + """A POST whose Content-Type is not application/json is rejected with 400.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "text/plain", + }, + content="This is not JSON", + ) + + assert response.status_code == 400 + assert "Invalid Content-Type" in response.text + + +@pytest.mark.anyio +async def test_json_validation(basic_app: Starlette) -> None: + """A POST body that is not valid JSON is rejected with a parse error.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + content="this is not valid json", + ) + assert response.status_code == 400 + assert "Parse error" in response.text + + +@pytest.mark.anyio +async def test_json_parsing(basic_app: Starlette) -> None: + """Valid JSON that is not a JSON-RPC message is rejected with a validation error.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"foo": "bar"}, + ) + assert response.status_code == 400 + assert "Validation error" in response.text + + +@pytest.mark.anyio +async def test_method_not_allowed(basic_app: Starlette) -> None: + """Unsupported HTTP methods are rejected with 405.""" + async with make_client(basic_app) as client: + response = await client.put( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 405 + assert "Method Not Allowed" in response.text -def test_json_validation(basic_server: None, basic_server_url: str): - """Test that JSON content is properly validated.""" - # Test with invalid JSON - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - data="this is not valid json", - ) - assert response.status_code == 400 - assert "Parse error" in response.text - - -def test_json_parsing(basic_server: None, basic_server_url: str): - """Test that JSON content is properly parse.""" - # Test with valid JSON but invalid JSON-RPC - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json={"foo": "bar"}, - ) - assert response.status_code == 400 - assert "Validation error" in response.text - - -def test_method_not_allowed(basic_server: None, basic_server_url: str): - """Test that unsupported HTTP methods are rejected.""" - # Test with unsupported method (PUT) - response = requests.put( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, - ) - assert response.status_code == 405 - assert "Method Not Allowed" in response.text - - -def test_session_validation(basic_server: None, basic_server_url: str): - """Test session ID validation.""" - # session_id not used directly in this test - - # Test without session ID - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, - ) - assert response.status_code == 400 - assert "Missing session ID" in response.text +@pytest.mark.anyio +async def test_session_validation(basic_app: Starlette) -> None: + """A non-initialize request without a session ID is rejected with 400.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, + ) + assert response.status_code == 400 + assert "Missing session ID" in response.text -def test_session_id_pattern(): - """Test that SESSION_ID_PATTERN correctly validates session IDs.""" + +def test_session_id_pattern() -> None: + """SESSION_ID_PATTERN accepts visible ASCII (0x21-0x7E) and rejects everything else.""" # Valid session IDs (visible ASCII characters from 0x21 to 0x7E) valid_session_ids = [ "test-session-id", @@ -744,8 +594,8 @@ def test_session_id_pattern(): assert SESSION_ID_PATTERN.fullmatch(session_id) is None -def test_streamable_http_transport_init_validation(): - """Test that StreamableHTTPServerTransport validates session ID on init.""" +def test_streamable_http_transport_init_validation() -> None: + """StreamableHTTPServerTransport accepts valid or absent session IDs and rejects invalid ones.""" # Valid session ID should initialize without errors valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id") assert valid_transport.mcp_session_id == "valid-id" @@ -767,144 +617,153 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_session_termination(basic_server: None, basic_server_url: str): - """Test session termination via DELETE and subsequent request handling.""" - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 +@pytest.mark.anyio +async def test_session_termination(basic_app: Starlette) -> None: + """DELETE terminates the session, after which requests for it return 404.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 - # Extract negotiated protocol version from SSE response - negotiated_version = extract_protocol_version_from_sse(response) + # Extract negotiated protocol version from SSE response + negotiated_version = extract_protocol_version_from_sse(response) - # Now terminate the session - session_id = response.headers.get(MCP_SESSION_ID_HEADER) - response = requests.delete( - f"{basic_server_url}/mcp", - headers={ - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - ) - assert response.status_code == 200 - - # Try to use the terminated session - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - }, - json={"jsonrpc": "2.0", "method": "ping", "id": 2}, - ) - assert response.status_code == 404 - assert "Session has been terminated" in response.text - - -def test_response(basic_server: None, basic_server_url: str): - """Test response handling for a valid request.""" - mcp_url = f"{basic_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 + # Now terminate the session + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + response = await client.delete( + "/mcp", + headers={ + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + ) + assert response.status_code == 200 + + # Try to use the terminated session + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "ping", "id": 2}, + ) + assert response.status_code == 404 + assert "Session has been terminated" in response.text - # Extract negotiated protocol version from SSE response - negotiated_version = extract_protocol_version_from_sse(response) - # Now get the session ID - session_id = response.headers.get(MCP_SESSION_ID_HEADER) +@pytest.mark.anyio +async def test_response(basic_app: Starlette) -> None: + """A request on an initialized session is answered on a text/event-stream response.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 - # Try to use the session with proper headers - tools_response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, - stream=True, - ) - assert tools_response.status_code == 200 - assert tools_response.headers.get("Content-Type") == "text/event-stream" - - -def test_json_response(json_response_server: None, json_server_url: str): - """Test response handling when is_json_response_enabled is True.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" - - -def test_json_response_accept_json_only(json_response_server: None, json_server_url: str): - """Test that json_response servers only require application/json in Accept header.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" - - -def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): - """Test that json_response servers reject requests without Accept header.""" - mcp_url = f"{json_server_url}/mcp" - # Suppress requests library default Accept: */* header - session = requests.Session() - session.headers.pop("Accept") - response = session.post( - mcp_url, - headers={ - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text + # Extract negotiated protocol version from SSE response + negotiated_version = extract_protocol_version_from_sse(response) + # Now get the session ID + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Try to use the session with proper headers + async with client.stream( + "POST", + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, + ) as tools_response: + assert tools_response.status_code == 200 + assert tools_response.headers.get("Content-Type") == "text/event-stream" + + +@pytest.mark.anyio +async def test_json_response(json_app: Starlette) -> None: + """With JSON response mode enabled, requests are answered with application/json bodies.""" + async with make_client(json_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" + + +@pytest.mark.anyio +async def test_json_response_accept_json_only(json_app: Starlette) -> None: + """JSON response mode only requires application/json in the Accept header.""" + async with make_client(json_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" -def test_json_response_incorrect_accept_header(json_response_server: None, json_server_url: str): - """Test that json_response servers reject requests with incorrect Accept header.""" - mcp_url = f"{json_server_url}/mcp" - # Test with only text/event-stream (wrong for JSON server) - response = requests.post( - mcp_url, - headers={ - "Accept": "text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text +@pytest.mark.anyio +async def test_json_response_missing_accept_header(json_app: Starlette) -> None: + """JSON response mode still rejects requests without an Accept header.""" + async with make_client(json_app) as client: + # Suppress the httpx client default Accept: */* header + del client.headers["accept"] + response = await client.post( + "/mcp", + headers={ + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + +@pytest.mark.anyio +async def test_json_response_incorrect_accept_header(json_app: Starlette) -> None: + """JSON response mode rejects an Accept header that does not cover application/json.""" + async with make_client(json_app) as client: + # Test with only text/event-stream (wrong for JSON server) + response = await client.post( + "/mcp", + headers={ + "Accept": "text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + +@pytest.mark.anyio @pytest.mark.parametrize( "accept_header", [ @@ -913,167 +772,134 @@ def test_json_response_incorrect_accept_header(json_response_server: None, json_ "application/*;q=0.9", ], ) -def test_json_response_wildcard_accept_header(json_response_server: None, json_server_url: str, accept_header: str): - """Test that json_response servers accept wildcard Accept headers per RFC 7231.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": accept_header, - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" - - -def test_get_sse_stream(basic_server: None, basic_server_url: str): - """Test establishing an SSE stream via GET request.""" - # First, we need to initialize a session - mcp_url = f"{basic_server_url}/mcp" - init_response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 +async def test_json_response_wildcard_accept_header(json_app: Starlette, accept_header: str) -> None: + """JSON response mode accepts wildcard Accept headers per RFC 7231.""" + async with make_client(json_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" - # Get the session ID - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - assert session_id is not None - # Extract negotiated protocol version from SSE response - init_data = None - assert init_response.headers.get("Content-Type") == "text/event-stream" - for line in init_response.text.splitlines(): # pragma: no branch - if line.startswith("data: "): - init_data = json.loads(line[6:]) - break - assert init_data is not None - negotiated_version = init_data["result"]["protocolVersion"] - - # Now attempt to establish an SSE stream via GET - get_response = requests.get( - mcp_url, - headers={ - "Accept": "text/event-stream", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - stream=True, - ) +@pytest.mark.anyio +async def test_get_sse_stream(basic_app: Starlette) -> None: + """GET establishes the standalone SSE stream, and a second GET is rejected with 409.""" + async with make_client(basic_app) as client: + # First, we need to initialize a session + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 - # Verify we got a successful response with the right content type - assert get_response.status_code == 200 - assert get_response.headers.get("Content-Type") == "text/event-stream" + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + negotiated_version = extract_protocol_version_from_sse(init_response) - # Test that a second GET request gets rejected (only one stream allowed) - second_get = requests.get( - mcp_url, - headers={ + # Now attempt to establish an SSE stream via GET + get_headers = { "Accept": "text/event-stream", MCP_SESSION_ID_HEADER: session_id, MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - stream=True, - ) + } + # The streams enter in order, so the second GET arrives while the first is held open. + async with ( + client.stream("GET", "/mcp", headers=get_headers) as get_response, + client.stream("GET", "/mcp", headers=get_headers) as second_get, + ): + # Verify we got a successful response with the right content type + assert get_response.status_code == 200 + assert get_response.headers.get("Content-Type") == "text/event-stream" - # Should get CONFLICT (409) since there's already a stream - # Note: This might fail if the first stream fully closed before this runs, - # but generally it should work in the test environment where it runs quickly - assert second_get.status_code == 409 - - -def test_get_validation(basic_server: None, basic_server_url: str): - """Test validation for GET requests.""" - # First, we need to initialize a session - mcp_url = f"{basic_server_url}/mcp" - init_response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 + # The second GET gets CONFLICT (409): only one standalone stream is allowed per session. + assert second_get.status_code == 409 - # Get the session ID - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - assert session_id is not None - # Extract negotiated protocol version from SSE response - init_data = None - assert init_response.headers.get("Content-Type") == "text/event-stream" - for line in init_response.text.splitlines(): # pragma: no branch - if line.startswith("data: "): - init_data = json.loads(line[6:]) - break - assert init_data is not None - negotiated_version = init_data["result"]["protocolVersion"] - - # Test without Accept header (suppress requests library default Accept: */*) - session = requests.Session() - session.headers.pop("Accept") - response = session.get( - mcp_url, - headers={ - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - stream=True, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - # Test with wrong Accept header - response = requests.get( - mcp_url, - headers={ - "Accept": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text +@pytest.mark.anyio +async def test_get_validation(basic_app: Starlette) -> None: + """A GET without an Accept header covering text/event-stream is rejected with 406.""" + async with make_client(basic_app) as client: + # First, we need to initialize a session + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + negotiated_version = extract_protocol_version_from_sse(init_response) + + # Test without Accept header (suppress the httpx client default Accept: */*) + del client.headers["accept"] + response = await client.get( + "/mcp", + headers={ + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + # Test with wrong Accept header + response = await client.get( + "/mcp", + headers={ + "Accept": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text # Client-specific fixtures @pytest.fixture -async def http_client(basic_server: None, basic_server_url: str): # pragma: no cover - """Create test client matching the SSE test pattern.""" - async with httpx.AsyncClient(base_url=basic_server_url) as client: - yield client - - -@pytest.fixture -async def initialized_client_session(basic_server: None, basic_server_url: str): +async def initialized_client_session(basic_app: Starlette) -> AsyncIterator[ClientSession]: """Create initialized StreamableHTTP client session.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - yield session + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + yield session @pytest.mark.anyio -async def test_streamable_http_client_basic_connection(basic_server: None, basic_server_url: str): - """Test basic client connection with initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == SERVER_NAME +async def test_streamable_http_client_basic_connection(basic_app: Starlette) -> None: + """A client initializes against a server over the StreamableHTTP transport.""" + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.server_info.name == SERVER_NAME @pytest.mark.anyio -async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession): - """Test client resource read functionality.""" +async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession) -> None: + """A resource read round-trips its arguments and the handler's content.""" response = await initialized_client_session.read_resource(uri="foobar://test-resource") assert len(response.contents) == 1 assert response.contents[0].uri == "foobar://test-resource" @@ -1082,11 +908,11 @@ async def test_streamable_http_client_resource_read(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession): - """Test client tool invocation.""" +async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession) -> None: + """A tool call reaches the handler and returns its content.""" # First list tools tools = await initialized_client_session.list_tools() - assert len(tools.tools) == 10 + assert len(tools.tools) == 9 assert tools.tools[0].name == "test_tool" # Call the tool @@ -1097,8 +923,8 @@ async def test_streamable_http_client_tool_invocation(initialized_client_session @pytest.mark.anyio -async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession): - """Test error handling in client.""" +async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession) -> None: + """A server-side error reaches the client as an MCPError with the handler's message.""" with pytest.raises(MCPError) as exc_info: await initialized_client_session.read_resource(uri="unknown://test-error") assert exc_info.value.error.code == 0 @@ -1106,50 +932,56 @@ async def test_streamable_http_client_error_handling(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str): - """Test that session ID persists across requests.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - - # Make multiple requests to verify session persistence - tools = await session.list_tools() - assert len(tools.tools) == 10 - - # Read a resource - resource = await session.read_resource(uri="foobar://test-persist") - assert isinstance(resource.contents[0], TextResourceContents) is True - content = resource.contents[0] - assert isinstance(content, TextResourceContents) - assert content.text == "Read test-persist" +async def test_streamable_http_client_session_persistence(basic_app: Starlette) -> None: + """The session persists across multiple requests on one connection.""" + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make multiple requests to verify session persistence + tools = await session.list_tools() + assert len(tools.tools) == 9 + + # Read a resource + resource = await session.read_resource(uri="foobar://test-persist") + assert isinstance(resource.contents[0], TextResourceContents) is True + content = resource.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Read test-persist" @pytest.mark.anyio -async def test_streamable_http_client_json_response(json_response_server: None, json_server_url: str): - """Test client with JSON response mode.""" - async with streamable_http_client(f"{json_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == SERVER_NAME - - # Check tool listing - tools = await session.list_tools() - assert len(tools.tools) == 10 - - # Call a tool and verify JSON response handling - result = await session.call_tool("test_tool", {}) - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert result.content[0].text == "Called test_tool" +async def test_streamable_http_client_json_response(json_app: Starlette) -> None: + """The client works identically against a server in JSON response mode.""" + async with ( + make_client(json_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.server_info.name == SERVER_NAME + + # Check tool listing + tools = await session.list_tools() + assert len(tools.tools) == 9 + + # Call a tool and verify JSON response handling + result = await session.call_tool("test_tool", {}) + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "Called test_tool" @pytest.mark.anyio -async def test_streamable_http_client_get_stream(basic_server: None, basic_server_url: str): - """Test GET stream functionality for server-initiated messages.""" +async def test_streamable_http_client_get_stream(basic_app: Starlette) -> None: + """A server-initiated notification reaches the client on the standalone GET stream.""" notifications_received: list[types.ServerNotification] = [] # Define message handler to capture notifications @@ -1159,30 +991,33 @@ async def message_handler( # pragma: no branch if isinstance(message, types.ServerNotification): # pragma: no branch notifications_received.append(message) - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Initialize the session - this triggers the GET stream setup - result = await session.initialize() - assert isinstance(result, InitializeResult) + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + # Initialize the session - this triggers the GET stream setup + result = await session.initialize() + assert isinstance(result, InitializeResult) - # Call the special tool that sends a notification - await session.call_tool("test_tool_with_standalone_notification", {}) + # Call the special tool that sends a notification + await session.call_tool("test_tool_with_standalone_notification", {}) - # Verify we received the notification - assert len(notifications_received) > 0 + # Verify we received the notification + assert len(notifications_received) > 0 - # Verify the notification is a ResourceUpdatedNotification - resource_update_found = False - for notif in notifications_received: - if isinstance(notif, types.ResourceUpdatedNotification): # pragma: no branch - assert str(notif.params.uri) == "http://test_resource" - resource_update_found = True + # Verify the notification is a ResourceUpdatedNotification + resource_update_found = False + for notif in notifications_received: + if isinstance(notif, types.ResourceUpdatedNotification): # pragma: no branch + assert str(notif.params.uri) == "http://test_resource" + resource_update_found = True - assert resource_update_found, "ResourceUpdatedNotification not received via GET stream" + assert resource_update_found, "ResourceUpdatedNotification not received via GET stream" -def create_session_id_capturing_client() -> tuple[httpx.AsyncClient, list[str]]: - """Create an httpx client that captures the session ID from responses.""" +def create_session_id_capturing_client(app: Starlette) -> tuple[httpx.AsyncClient, list[str]]: + """Create an in-process httpx client that captures the session ID from responses.""" captured_ids: list[str] = [] async def capture_session_id(response: httpx.Response) -> None: @@ -1191,21 +1026,22 @@ async def capture_session_id(response: httpx.Response) -> None: captured_ids.append(session_id) client = httpx.AsyncClient( + transport=StreamingASGITransport(app), + base_url=BASE_URL, follow_redirects=True, - timeout=httpx.Timeout(MCP_DEFAULT_TIMEOUT, read=MCP_DEFAULT_SSE_READ_TIMEOUT), event_hooks={"response": [capture_session_id]}, ) return client, captured_ids @pytest.mark.anyio -async def test_streamable_http_client_session_termination(basic_server: None, basic_server_url: str): - """Test client session termination functionality.""" +async def test_streamable_http_client_session_termination(basic_app: Starlette) -> None: + """After the client terminates its session on close, a new connection with that session ID fails.""" # Use httpx client with event hooks to capture session ID - httpx_client, captured_ids = create_session_id_capturing_client() + httpx_client, captured_ids = create_session_id_capturing_client(basic_app) async with httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1220,10 +1056,10 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 10 + assert len(tools.tools) == 9 - async with create_mcp_http_client(headers=headers) as httpx_client2: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as ( + async with make_client(basic_app, headers=headers) as httpx_client2: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( read_stream, write_stream, ): @@ -1235,9 +1071,9 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba @pytest.mark.anyio async def test_streamable_http_client_session_termination_204( - basic_server: None, basic_server_url: str, monkeypatch: pytest.MonkeyPatch -): - """Test client session termination functionality with a 204 response. + basic_app: Starlette, monkeypatch: pytest.MonkeyPatch +) -> None: + """Session termination also succeeds when the server answers the DELETE with 204. This test patches the httpx client to return a 204 response for DELETEs. """ @@ -1263,10 +1099,10 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt monkeypatch.setattr(httpx.AsyncClient, "delete", mock_delete) # Use httpx client with event hooks to capture session ID - httpx_client, captured_ids = create_session_id_capturing_client() + httpx_client, captured_ids = create_session_id_capturing_client(basic_app) async with httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1281,10 +1117,10 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 10 + assert len(tools.tools) == 9 - async with create_mcp_http_client(headers=headers) as httpx_client2: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as ( + async with make_client(basic_app, headers=headers) as httpx_client2: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( read_stream, write_stream, ): @@ -1295,14 +1131,15 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt @pytest.mark.anyio -async def test_streamable_http_client_resumption(event_server: tuple[SimpleEventStore, str]): - """Test client session resumption using sync primitives for reliable coordination.""" - _, server_url = event_server +async def test_streamable_http_client_resumption(event_app: tuple[SimpleEventStore, Starlette]) -> None: + """A second client resumes an interrupted request with a resumption token and receives the rest.""" + _, app = event_app # Variables to track the state captured_resumption_token: str | None = None captured_notifications: list[types.ServerNotification] = [] - first_notification_received = False + first_notification_received = anyio.Event() + resumption_token_received = anyio.Event() async def message_handler( # pragma: no branch message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, @@ -1312,19 +1149,19 @@ async def message_handler( # pragma: no branch # Look for our first notification if isinstance(message, types.LoggingMessageNotification): # pragma: no branch if message.params.data == "First notification before lock": - nonlocal first_notification_received - first_notification_received = True + first_notification_received.set() async def on_resumption_token_update(token: str) -> None: nonlocal captured_resumption_token captured_resumption_token = token + resumption_token_received.set() # Use httpx client with event hooks to capture session ID - httpx_client, captured_ids = create_session_id_capturing_client() + httpx_client, captured_ids = create_session_id_capturing_client(app) # First, start the client session and begin the tool that waits on lock async with httpx_client: - async with streamable_http_client(f"{server_url}/mcp", terminate_on_close=False, http_client=httpx_client) as ( + async with streamable_http_client(f"{BASE_URL}/mcp", terminate_on_close=False, http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1363,13 +1200,13 @@ async def run_tool(): tg.start_soon(run_tool) # Wait for the first notification and resumption token - while not first_notification_received or not captured_resumption_token: - await anyio.sleep(0.1) + with anyio.fail_after(5): + await first_notification_received.wait() + await resumption_token_received.wait() - # The while loop only exits after first_notification_received=True, - # which is set by message_handler immediately after appending to - # captured_notifications. The server tool is blocked on its lock, - # so nothing else can arrive before we cancel. + # first_notification_received is set by message_handler immediately + # after appending to captured_notifications. The server tool is + # blocked on its lock, so nothing else can arrive before we cancel. assert len(captured_notifications) == 1 assert isinstance(captured_notifications[0], types.LoggingMessageNotification) assert captured_notifications[0].params.data == "First notification before lock" @@ -1379,8 +1216,8 @@ async def run_tool(): # Kill the client session while tool is waiting on lock tg.cancel_scope.cancel() - async with create_mcp_http_client(headers=headers) as httpx_client2: - async with streamable_http_client(f"{server_url}/mcp", http_client=httpx_client2) as ( + async with make_client(app, headers=headers) as httpx_client2: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( read_stream, write_stream, ): @@ -1413,8 +1250,8 @@ async def run_tool(): @pytest.mark.anyio -async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str): - """Test server-initiated sampling request through streamable HTTP transport.""" +async def test_streamablehttp_server_sampling(basic_app: Starlette) -> None: + """A server-initiated sampling request reaches the client callback and its result the tool.""" # Variable to track if sampling callback was invoked sampling_callback_invoked = False captured_message_params = None @@ -1441,29 +1278,32 @@ async def sampling_callback( ) # Create client with sampling callback - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - - # Call the tool that triggers server-side sampling - tool_result = await session.call_tool("test_sampling_tool", {}) - - # Verify the tool result contains the expected content - assert len(tool_result.content) == 1 - assert tool_result.content[0].type == "text" - assert "Response from sampling: Received message from server" in tool_result.content[0].text - - # Verify sampling callback was invoked - assert sampling_callback_invoked - assert captured_message_params is not None - assert len(captured_message_params.messages) == 1 - assert captured_message_params.messages[0].content.text == "Server needs client sampling" + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session, + ): + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Call the tool that triggers server-side sampling + tool_result = await session.call_tool("test_sampling_tool", {}) + + # Verify the tool result contains the expected content + assert len(tool_result.content) == 1 + assert tool_result.content[0].type == "text" + assert "Response from sampling: Received message from server" in tool_result.content[0].text + + # Verify sampling callback was invoked + assert sampling_callback_invoked + assert captured_message_params is not None + assert len(captured_message_params.messages) == 1 + assert captured_message_params.messages[0].content.text == "Server needs client sampling" # Context-aware server implementation for testing request context propagation -async def _handle_context_list_tools( # pragma: no cover +async def _handle_context_list_tools( ctx: ServerRequestContext, params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -1488,97 +1328,51 @@ async def _handle_context_list_tools( # pragma: no cover ) -async def _handle_context_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - name = params.name - args = params.arguments or {} +async def _handle_context_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name in ("echo_headers", "echo_context") + assert isinstance(ctx.request, Request) - if name == "echo_headers": - headers_info: dict[str, Any] = {} - if ctx.request and isinstance(ctx.request, Request): - headers_info = dict(ctx.request.headers) - return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) - - elif name == "echo_context": - context_data: dict[str, Any] = { - "request_id": args.get("request_id"), - "headers": {}, - "method": None, - "path": None, - } - if ctx.request and isinstance(ctx.request, Request): - request = ctx.request - context_data["headers"] = dict(request.headers) - context_data["method"] = request.method - context_data["path"] = request.url.path - return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) + if params.name == "echo_headers": + return CallToolResult(content=[TextContent(type="text", text=json.dumps(dict(ctx.request.headers)))]) - return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) + assert params.arguments is not None + context_data: dict[str, Any] = { + "request_id": params.arguments.get("request_id"), + "headers": dict(ctx.request.headers), + "method": ctx.request.method, + "path": ctx.request.url.path, + } + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) -# Server runner for context-aware testing -def run_context_aware_server(port: int): # pragma: no cover - """Run the context-aware test server.""" +@pytest.fixture +async def context_app() -> AsyncIterator[Starlette]: + """An app whose server echoes request context, served in process.""" server = Server( "ContextAwareServer", on_list_tools=_handle_context_list_tools, on_call_tool=_handle_context_call_tool, ) - session_manager = StreamableHTTPSessionManager( app=server, - event_store=None, - json_response=False, - ) - - app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], - lifespan=lambda app: session_manager.run(), - ) - - server_instance = uvicorn.Server( - config=uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) + security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), ) - server_instance.run() - - -@pytest.fixture -def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: - """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Context-aware server process failed to terminate") + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + async with session_manager.run(): + yield app @pytest.mark.anyio -async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None: - """Test that request context is properly propagated through StreamableHTTP.""" +async def test_streamablehttp_request_context_propagation(context_app: Starlette) -> None: + """Custom HTTP headers on the connection are visible to server handlers via ctx.request.""" custom_headers = { "Authorization": "Bearer test-token", "X-Custom-Header": "test-value", "X-Trace-Id": "trace-123", } - async with create_mcp_http_client(headers=custom_headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with make_client(context_app, headers=custom_headers) as httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1602,11 +1396,11 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: @pytest.mark.anyio -async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: - """Test that request contexts are isolated between StreamableHTTP clients.""" +async def test_streamablehttp_request_context_isolation(context_app: Starlette) -> None: + """Each connection's handlers see only that connection's request headers.""" contexts: list[dict[str, Any]] = [] - # Create multiple clients with different headers + # Connect three clients in turn, each with its own headers. for i in range(3): headers = { "X-Request-Id": f"request-{i}", @@ -1614,8 +1408,8 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No "Authorization": f"Bearer token-{i}", } - async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with make_client(context_app, headers=headers) as httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1640,145 +1434,160 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No @pytest.mark.anyio -async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): - """Test that client includes mcp-protocol-version header after initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Initialize and get the negotiated version - init_result = await session.initialize() - negotiated_version = init_result.protocol_version - - # Call a tool that echoes headers to verify the header is present - tool_result = await session.call_tool("echo_headers", {}) - - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - headers_data = json.loads(tool_result.content[0].text) - - # Verify protocol version header is present - assert "mcp-protocol-version" in headers_data - assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version - - -def test_server_validates_protocol_version_header(basic_server: None, basic_server_url: str): - """Test that server returns 400 Bad Request version if header unsupported or invalid.""" - # First initialize a session to get a valid session ID - init_response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - - # Test request with invalid protocol version (should fail) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: "invalid-version", - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"}, - ) - assert response.status_code == 400 - assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() - - # Test request with unsupported protocol version (should fail) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"}, - ) - assert response.status_code == 400 - assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() +async def test_client_includes_protocol_version_header_after_init(context_app: Starlette) -> None: + """After initialization, every client request carries the negotiated protocol version header.""" + async with ( + make_client(context_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + # Initialize and get the negotiated version + init_result = await session.initialize() + negotiated_version = init_result.protocol_version - # Test request with valid protocol version (should succeed) - negotiated_version = extract_protocol_version_from_sse(init_response) + # Call a tool that echoes headers to verify the header is present + tool_result = await session.call_tool("echo_headers", {}) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"}, - ) - assert response.status_code == 200 - - -def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): - """Test server accepts requests without protocol version header.""" - # First initialize a session to get a valid session ID - init_response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - - # Test request without mcp-protocol-version header (backwards compatibility) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-backwards-compat"}, - stream=True, - ) - assert response.status_code == 200 # Should succeed for backwards compatibility - assert response.headers.get("Content-Type") == "text/event-stream" + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + headers_data = json.loads(tool_result.content[0].text) + + # Verify protocol version header is present + assert "mcp-protocol-version" in headers_data + assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version + + +@pytest.mark.anyio +async def test_server_validates_protocol_version_header(basic_app: Starlette) -> None: + """An invalid or unsupported protocol version header is rejected with 400; the negotiated one passes.""" + async with make_client(basic_app) as client: + # First initialize a session to get a valid session ID + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Test request with invalid protocol version (should fail) + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: "invalid-version", + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"}, + ) + assert response.status_code == 400 + assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + + # Test request with unsupported protocol version (should fail) + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"}, + ) + assert response.status_code == 400 + assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + + # Test request with valid protocol version (should succeed) + negotiated_version = extract_protocol_version_from_sse(init_response) + + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"}, + ) + assert response.status_code == 200 + + +@pytest.mark.anyio +async def test_server_backwards_compatibility_no_protocol_version(basic_app: Starlette) -> None: + """A request without a protocol version header is accepted for backwards compatibility.""" + async with make_client(basic_app) as client: + # First initialize a session to get a valid session ID + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Test request without mcp-protocol-version header (backwards compatibility) + async with client.stream( + "POST", + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-backwards-compat"}, + ) as response: + assert response.status_code == 200 # Should succeed for backwards compatibility + assert response.headers.get("Content-Type") == "text/event-stream" @pytest.mark.anyio -async def test_client_crash_handled(basic_server: None, basic_server_url: str): - """Test that cases where the client crashes are handled gracefully.""" +async def test_client_crash_handled(basic_app: Starlette) -> None: + """A client crashing mid-session does not prevent later clients from connecting.""" # Simulate bad client that crashes after init async def bad_client(): """Client that triggers ClosedResourceError""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - raise Exception("client crash") + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + raise Exception("client crash") - # Run bad client a few times to trigger the crash + # Run bad client a few times to trigger the crash. The crash surfaces wrapped in exception + # groups whose exact shape is not the subject here — what matters is that the server survives. for _ in range(3): try: await bad_client() except Exception: pass - await anyio.sleep(0.1) # Try a good client, it should still be able to connect and list tools - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - result = await session.initialize() - assert isinstance(result, InitializeResult) - tools = await session.list_tools() - assert tools.tools + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + result = await session.initialize() + assert isinstance(result, InitializeResult) + tools = await session.list_tools() + assert tools.tools @pytest.mark.anyio -async def test_handle_sse_event_skips_empty_data(): - """Test that _handle_sse_event skips empty SSE data (keep-alive pings).""" +async def test_handle_sse_event_skips_empty_data() -> None: + """_handle_sse_event skips empty SSE data (keep-alive pings) without writing to the stream.""" transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") # Create a mock SSE event with empty data (keep-alive ping) @@ -1804,8 +1613,8 @@ async def test_handle_sse_event_skips_empty_data(): @pytest.mark.anyio -async def test_priming_event_not_sent_for_old_protocol_version(): - """Test that _maybe_send_priming_event skips for old protocol versions (backwards compat).""" +async def test_priming_event_not_sent_for_old_protocol_version() -> None: + """_maybe_send_priming_event skips for old protocol versions (backwards compat).""" # Create a transport with an event store transport = StreamableHTTPServerTransport( "/mcp", @@ -1833,8 +1642,8 @@ async def test_priming_event_not_sent_for_old_protocol_version(): @pytest.mark.anyio -async def test_priming_event_not_sent_without_event_store(): - """Test that _maybe_send_priming_event returns early when no event_store is configured.""" +async def test_priming_event_not_sent_without_event_store() -> None: + """_maybe_send_priming_event returns early when no event_store is configured.""" # Create a transport WITHOUT an event store transport = StreamableHTTPServerTransport("/mcp") @@ -1853,8 +1662,8 @@ async def test_priming_event_not_sent_without_event_store(): @pytest.mark.anyio -async def test_priming_event_includes_retry_interval(): - """Test that _maybe_send_priming_event includes retry field when retry_interval is set.""" +async def test_priming_event_includes_retry_interval() -> None: + """_maybe_send_priming_event includes the retry field when retry_interval is set.""" # Create a transport with an event store AND retry_interval transport = StreamableHTTPServerTransport( "/mcp", @@ -1882,8 +1691,8 @@ async def test_priming_event_includes_retry_interval(): @pytest.mark.anyio -async def test_close_sse_stream_callback_not_provided_for_old_protocol_version(): - """Test that close_sse_stream callbacks are NOT provided for old protocol versions.""" +async def test_close_sse_stream_callback_not_provided_for_old_protocol_version() -> None: + """close_sse_stream callbacks are only provided for protocol versions that support polling.""" # Create a transport with an event store transport = StreamableHTTPServerTransport( "/mcp", @@ -1915,71 +1724,76 @@ async def test_close_sse_stream_callback_not_provided_for_old_protocol_version() @pytest.mark.anyio async def test_streamable_http_client_receives_priming_event( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Client should receive priming event (resumption token update) on POST SSE stream.""" - _, server_url = event_server + _, app = event_app captured_resumption_tokens: list[str] = [] async def on_resumption_token_update(token: str) -> None: captured_resumption_tokens.append(token) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() - # Call tool with resumption token callback via send_request - metadata = ClientMessageMetadata( - on_resumption_token_update=on_resumption_token_update, - ) - result = await session.send_request( - types.CallToolRequest(params=types.CallToolRequestParams(name="test_tool", arguments={})), - types.CallToolResult, - metadata=metadata, - ) - assert result is not None - - # Should have received priming event token BEFORE response data - # Priming event = 1 token (empty data, id only) - # Response = 1 token (actual JSON-RPC response) - # Total = 2 tokens minimum - assert len(captured_resumption_tokens) >= 2, ( - f"Server must send priming event before response. " - f"Expected >= 2 tokens (priming + response), got {len(captured_resumption_tokens)}" - ) - assert captured_resumption_tokens[0] is not None + # Call tool with resumption token callback via send_request + metadata = ClientMessageMetadata( + on_resumption_token_update=on_resumption_token_update, + ) + result = await session.send_request( + types.CallToolRequest(params=types.CallToolRequestParams(name="test_tool", arguments={})), + types.CallToolResult, + metadata=metadata, + ) + assert result is not None + + # Should have received priming event token BEFORE response data + # Priming event = 1 token (empty data, id only) + # Response = 1 token (actual JSON-RPC response) + # Total = 2 tokens minimum + assert len(captured_resumption_tokens) >= 2, ( + f"Server must send priming event before response. " + f"Expected >= 2 tokens (priming + response), got {len(captured_resumption_tokens)}" + ) + assert captured_resumption_tokens[0] is not None @pytest.mark.anyio async def test_server_close_sse_stream_via_context( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Server tool can call ctx.close_sse_stream() to close connection.""" - _, server_url = event_server + _, app = event_app - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() - # Call tool that closes stream mid-operation - # This should NOT raise NotImplementedError when fully implemented - result = await session.call_tool("tool_with_stream_close", {}) + # Call tool that closes stream mid-operation + result = await session.call_tool("tool_with_stream_close", {}) - # Client should still receive complete response (via auto-reconnect) - assert result is not None - assert len(result.content) > 0 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + # Client should still receive complete response (via auto-reconnect) + assert result is not None + assert len(result.content) > 0 + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" @pytest.mark.anyio async def test_streamable_http_client_auto_reconnects( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Client should auto-reconnect with Last-Event-ID when server closes after priming event.""" - _, server_url = event_server + _, app = event_app captured_notifications: list[str] = [] async def message_handler( @@ -1991,59 +1805,63 @@ async def message_handler( if isinstance(message, types.LoggingMessageNotification): # pragma: no branch captured_notifications.append(str(message.params.data)) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - await session.initialize() - - # Call tool that: - # 1. Sends notification - # 2. Closes SSE stream - # 3. Sends more notifications (stored in event_store) - # 4. Returns response - result = await session.call_tool("tool_with_stream_close", {}) - - # Client should have auto-reconnected and received ALL notifications - assert len(captured_notifications) >= 2, ( - "Client should auto-reconnect and receive notifications sent both before and after stream close" - ) - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + await session.initialize() + + # Call tool that: + # 1. Sends notification + # 2. Closes SSE stream + # 3. Sends more notifications (stored in event_store) + # 4. Returns response + result = await session.call_tool("tool_with_stream_close", {}) + + # Client should have auto-reconnected and received ALL notifications + assert len(captured_notifications) >= 2, ( + "Client should auto-reconnect and receive notifications sent both before and after stream close" + ) + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" @pytest.mark.anyio async def test_streamable_http_client_respects_retry_interval( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Client MUST respect retry field, waiting specified ms before reconnecting.""" - _, server_url = event_server + _, app = event_app - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() - start_time = time.monotonic() - result = await session.call_tool("tool_with_stream_close", {}) - elapsed = time.monotonic() - start_time + start_time = time.monotonic() + result = await session.call_tool("tool_with_stream_close", {}) + elapsed = time.monotonic() - start_time - # Verify result was received - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + # Verify result was received + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" - # The elapsed time should include at least the retry interval - # if reconnection occurred. This test may be flaky depending on - # implementation details, but demonstrates the expected behavior. - # Note: This assertion may need adjustment based on actual implementation - assert elapsed >= 0.4, f"Client should wait ~500ms before reconnecting, but elapsed time was {elapsed:.3f}s" + # The elapsed time should include at least the retry interval (500ms) before + # the client reconnected; the tool's own work only accounts for ~100ms. + assert elapsed >= 0.4, f"Client should wait ~500ms before reconnecting, but elapsed time was {elapsed:.3f}s" @pytest.mark.anyio async def test_streamable_http_sse_polling_full_cycle( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """End-to-end test: server closes stream, client reconnects, receives all events.""" - _, server_url = event_server + _, app = event_app all_notifications: list[str] = [] async def message_handler( @@ -2055,35 +1873,38 @@ async def message_handler( if isinstance(message, types.LoggingMessageNotification): # pragma: no branch all_notifications.append(str(message.params.data)) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - await session.initialize() - - # Call tool that simulates polling pattern: - # 1. Server sends priming event - # 2. Server sends "Before close" notification - # 3. Server closes stream (calls close_sse_stream) - # 4. (client reconnects automatically) - # 5. Server sends "After close" notification - # 6. Server sends final response - result = await session.call_tool("tool_with_stream_close", {}) - - # Verify all notifications received in order - assert "Before close" in all_notifications, "Should receive notification sent before stream close" - assert "After close" in all_notifications, ( - "Should receive notification sent after stream close (via auto-reconnect)" - ) - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + await session.initialize() + + # Call tool that simulates polling pattern: + # 1. Server sends priming event + # 2. Server sends "Before close" notification + # 3. Server closes stream (calls close_sse_stream) + # 4. (client reconnects automatically) + # 5. Server sends "After close" notification + # 6. Server sends final response + result = await session.call_tool("tool_with_stream_close", {}) + + # Verify all notifications received in order + assert "Before close" in all_notifications, "Should receive notification sent before stream close" + assert "After close" in all_notifications, ( + "Should receive notification sent after stream close (via auto-reconnect)" + ) + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" @pytest.mark.anyio async def test_streamable_http_events_replayed_after_disconnect( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Events sent while client is disconnected should be replayed on reconnect.""" - _, server_url = event_server + _, app = event_app notification_data: list[str] = [] async def message_handler( @@ -2095,33 +1916,36 @@ async def message_handler( if isinstance(message, types.LoggingMessageNotification): # pragma: no branch notification_data.append(str(message.params.data)) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - await session.initialize() + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + await session.initialize() - # Tool sends: notification1, close_stream, notification2, notification3, response - # Client should receive all notifications even though 2&3 were sent during disconnect - result = await session.call_tool("tool_with_multiple_notifications_and_close", {}) + # Tool sends: notification1, close_stream, notification2, notification3, response + # Client should receive all notifications even though 2&3 were sent during disconnect + result = await session.call_tool("tool_with_multiple_notifications_and_close", {}) - assert "notification1" in notification_data, "Should receive notification1 (sent before close)" - assert "notification2" in notification_data, "Should receive notification2 (sent after close, replayed)" - assert "notification3" in notification_data, "Should receive notification3 (sent after close, replayed)" + assert "notification1" in notification_data, "Should receive notification1 (sent before close)" + assert "notification2" in notification_data, "Should receive notification2 (sent after close, replayed)" + assert "notification3" in notification_data, "Should receive notification3 (sent after close, replayed)" - # Verify order: notification1 should come before notification2 and notification3 - idx1 = notification_data.index("notification1") - idx2 = notification_data.index("notification2") - idx3 = notification_data.index("notification3") - assert idx1 < idx2 < idx3, "Notifications should be received in order" + # Verify order: notification1 should come before notification2 and notification3 + idx1 = notification_data.index("notification1") + idx2 = notification_data.index("notification2") + idx3 = notification_data.index("notification3") + assert idx1 < idx2 < idx3, "Notifications should be received in order" - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "All notifications sent" + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "All notifications sent" @pytest.mark.anyio async def test_streamable_http_multiple_reconnections( - event_server: tuple[SimpleEventStore, str], -): + event_app: tuple[SimpleEventStore, Starlette], +) -> None: """Verify multiple close_sse_stream() calls each trigger a client reconnect. Server uses retry_interval=500ms, tool sleeps 600ms after each close to ensure @@ -2133,45 +1957,48 @@ async def test_streamable_http_multiple_reconnections( - 3 priming (one per reconnect after each close) - 1 response """ - _, server_url = event_server + _, app = event_app resumption_tokens: list[str] = [] async def on_resumption_token(token: str) -> None: resumption_tokens.append(token) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Use send_request with metadata to track resumption tokens - metadata = ClientMessageMetadata(on_resumption_token_update=on_resumption_token) - result = await session.send_request( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams( - name="tool_with_multiple_stream_closes", - # retry_interval=500ms, so sleep 600ms to ensure reconnect completes - arguments={"checkpoints": 3, "sleep_time": 0.6}, - ), + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + + # Use send_request with metadata to track resumption tokens + metadata = ClientMessageMetadata(on_resumption_token_update=on_resumption_token) + result = await session.send_request( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name="tool_with_multiple_stream_closes", + # retry_interval=500ms, so sleep 600ms to ensure reconnect completes + arguments={"checkpoints": 3, "sleep_time": 0.6}, ), - types.CallToolResult, - metadata=metadata, - ) + ), + types.CallToolResult, + metadata=metadata, + ) - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Completed 3 checkpoints" in result.content[0].text + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert "Completed 3 checkpoints" in result.content[0].text - # 4 priming + 3 notifications + 1 response = 8 tokens. All tokens are - # captured before send_request returns, so this is safe to check here. - assert len(resumption_tokens) == 8, ( - f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " - f"got {len(resumption_tokens)}: {resumption_tokens}" - ) + # 4 priming + 3 notifications + 1 response = 8 tokens. All tokens are + # captured before send_request returns, so this is safe to check here. + assert len(resumption_tokens) == 8, ( + f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " + f"got {len(resumption_tokens)}: {resumption_tokens}" + ) @pytest.mark.anyio -async def test_standalone_get_stream_reconnection(event_server: tuple[SimpleEventStore, str]) -> None: +async def test_standalone_get_stream_reconnection(event_app: tuple[SimpleEventStore, Starlette]) -> None: """Test that standalone GET stream automatically reconnects after server closes it. Verifies: @@ -2180,10 +2007,10 @@ async def test_standalone_get_stream_reconnection(event_server: tuple[SimpleEven 3. Client reconnects with Last-Event-ID 4. Client receives notification 2 on new connection - Note: Requires event_server fixture (with event store) because close_standalone_sse_stream + Note: Requires the event store app because close_standalone_sse_stream callback is only provided when event_store is configured and protocol version >= 2025-11-25. """ - _, server_url = event_server + _, app = event_app received_notifications: list[str] = [] async def message_handler( @@ -2195,45 +2022,46 @@ async def message_handler( if isinstance(message, types.ResourceUpdatedNotification): # pragma: no branch received_notifications.append(str(message.params.uri)) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - await session.initialize() - - # Call tool that: - # 1. Sends notification_1 via GET stream - # 2. Closes standalone GET stream - # 3. Sends notification_2 (stored in event_store) - # 4. Returns response - result = await session.call_tool("tool_with_standalone_stream_close", {}) - - # Verify the tool completed - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Standalone stream close test done" - - # Verify both notifications were received - assert "http://notification_1" in received_notifications, ( - f"Should receive notification 1 (sent before GET stream close), got: {received_notifications}" - ) - assert "http://notification_2" in received_notifications, ( - f"Should receive notification 2 after reconnect, got: {received_notifications}" - ) + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + await session.initialize() + + # Call tool that: + # 1. Sends notification_1 via GET stream + # 2. Closes standalone GET stream + # 3. Sends notification_2 (stored in event_store) + # 4. Returns response + result = await session.call_tool("tool_with_standalone_stream_close", {}) + + # Verify the tool completed + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Standalone stream close test done" + + # Verify both notifications were received + assert "http://notification_1" in received_notifications, ( + f"Should receive notification 1 (sent before GET stream close), got: {received_notifications}" + ) + assert "http://notification_2" in received_notifications, ( + f"Should receive notification 2 after reconnect, got: {received_notifications}" + ) @pytest.mark.anyio -async def test_streamable_http_client_does_not_mutate_provided_client( - basic_server: None, basic_server_url: str -) -> None: - """Test that streamable_http_client does not mutate the provided httpx client's headers.""" +async def test_streamable_http_client_does_not_mutate_provided_client(basic_app: Starlette) -> None: + """streamable_http_client does not mutate the provided httpx client's headers.""" # Create a client with custom headers original_headers = { "X-Custom-Header": "custom-value", "Authorization": "Bearer test-token", } - async with httpx.AsyncClient(headers=original_headers, follow_redirects=True) as custom_client: + async with make_client(basic_app, headers=original_headers) as custom_client: # Use the client with streamable_http_client - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=custom_client) as ( + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=custom_client) as ( read_stream, write_stream, ): @@ -2254,18 +2082,16 @@ async def test_streamable_http_client_does_not_mutate_provided_client( @pytest.mark.anyio -async def test_streamable_http_client_mcp_headers_override_defaults( - context_aware_server: None, basic_server_url: str -) -> None: - """Test that MCP protocol headers override httpx.AsyncClient default headers.""" +async def test_streamable_http_client_mcp_headers_override_defaults(context_app: Starlette) -> None: + """MCP protocol headers override the httpx client's default headers in actual requests.""" # httpx.AsyncClient has default "accept: */*" header # We need to verify that our MCP accept header overrides it in actual requests - async with httpx.AsyncClient(follow_redirects=True) as client: + async with make_client(context_app) as client: # Verify client has default accept header assert client.headers.get("accept") == "*/*" - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=client) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() @@ -2285,18 +2111,16 @@ async def test_streamable_http_client_mcp_headers_override_defaults( @pytest.mark.anyio -async def test_streamable_http_client_preserves_custom_with_mcp_headers( - context_aware_server: None, basic_server_url: str -) -> None: - """Test that both custom headers and MCP protocol headers are sent in requests.""" +async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_app: Starlette) -> None: + """Custom client headers and MCP protocol headers are both sent in requests.""" custom_headers = { "X-Custom-Header": "custom-value", "X-Request-Id": "req-123", "Authorization": "Bearer test-token", } - async with httpx.AsyncClient(headers=custom_headers, follow_redirects=True) as client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with make_client(context_app, headers=custom_headers) as client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=client) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 810c72820b..0038b18905 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -2,7 +2,6 @@ import socket import threading -import time from collections.abc import Generator from contextlib import contextmanager from typing import Any @@ -56,30 +55,3 @@ def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None finally: server.should_exit = True thread.join(timeout=_SERVER_SHUTDOWN_TIMEOUT_S) - - -def wait_for_server(port: int, timeout: float = 20.0) -> None: - """Wait for server to be ready to accept connections. - - Polls the server port until it accepts connections or timeout is reached. - This eliminates race conditions without arbitrary sleeps. - - Args: - port: The port number to check - timeout: Maximum time to wait in seconds (default 5.0) - - Raises: - TimeoutError: If server doesn't start within the timeout period - """ - start_time = time.time() - while time.time() - start_time < timeout: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.settimeout(0.1) - s.connect(("127.0.0.1", port)) - # Server is ready - return - except (ConnectionRefusedError, OSError): - # Server not ready yet, retry quickly - time.sleep(0.01) - raise TimeoutError(f"Server on port {port} did not start within {timeout} seconds") # pragma: no cover From f18d65bcc78d6bc34ec1e95ae561fd0331e896ee Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 3 Jun 2026 10:17:41 +0000 Subject: [PATCH 2/2] Gate the multiple-reconnections test on observed reconnects, not sleeps The tool slept a fixed duration after each close_sse_stream() so the client's retry-interval reconnect could land before the next close. That is a timing margin, not synchronization: a reconnect delayed past the margin made the close a silent no-op, two cycles merged, and the exact resumption-token count failed. The tool now waits, bounded by fail_after(5), for the client-side resumption-token callback to observe each cycle's two new tokens (the checkpoint and the new connection's priming event). The priming event is sent only after the server has re-registered the resumed stream, so once the client holds its token the next close is guaranteed to sever a live connection. The token count becomes a consequence of causality rather than margins, no sleep remains, the retry interval drops from 500ms to 50ms, and a genuinely failed reconnect now fails loudly at the timeout instead of silently merging cycles. The test now defines its own server with the gated tool inline, since the gate closes over per-test state; the shared server's tool_with_multiple_stream_closes had no other users and is removed, which moves the tools/list count assertions from 9 to 8. --- tests/shared/test_streamable_http.py | 96 ++++++++++++++-------------- 1 file changed, 47 insertions(+), 49 deletions(-) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 9ed209b890..b43a3361c9 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -185,17 +185,6 @@ async def _handle_list_tools( description="Tool that sends notification1, closes stream, sends notification2, notification3", input_schema={"type": "object", "properties": {}}, ), - Tool( - name="tool_with_multiple_stream_closes", - description="Tool that closes SSE stream multiple times during execution", - input_schema={ - "type": "object", - "properties": { - "checkpoints": {"type": "integer", "default": 3}, - "sleep_time": {"type": "number", "default": 0.2}, - }, - }, - ), Tool( name="tool_with_standalone_stream_close", description="Tool that closes standalone GET stream mid-operation", @@ -207,7 +196,6 @@ async def _handle_list_tools( async def _handle_call_tool(ctx: ServerRequestContext[ServerState], params: CallToolRequestParams) -> CallToolResult: name = params.name - args = params.arguments or {} # When the tool is called, send a notification to test GET stream if name == "test_tool_with_standalone_notification": @@ -301,25 +289,6 @@ async def _handle_call_tool(ctx: ServerRequestContext[ServerState], params: Call ) return CallToolResult(content=[TextContent(type="text", text="All notifications sent")]) - elif name == "tool_with_multiple_stream_closes": - num_checkpoints = args.get("checkpoints", 3) - sleep_time = args.get("sleep_time", 0.2) - - for i in range(num_checkpoints): - await ctx.session.send_log_message( - level="info", - data=f"checkpoint_{i}", - logger="multi_close_tool", - related_request_id=ctx.request_id, - ) - - assert ctx.close_sse_stream is not None - await ctx.close_sse_stream() - - await anyio.sleep(sleep_time) - - return CallToolResult(content=[TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")]) - elif name == "tool_with_standalone_stream_close": await ctx.session.send_resource_updated(uri="http://notification_1") await anyio.sleep(0.1) @@ -350,6 +319,7 @@ async def running_app( is_json_response_enabled: bool = False, event_store: EventStore | None = None, retry_interval: int | None = None, + server: Server[Any] | None = None, ) -> AsyncIterator[Starlette]: """Serve the test server's streamable HTTP app in process for the duration. @@ -357,12 +327,13 @@ async def running_app( is_json_response_enabled: If True, use JSON responses instead of SSE streams. event_store: Optional event store for testing resumability. retry_interval: Retry interval in milliseconds for SSE polling. + server: Server to mount; defaults to the file's shared test server. """ # DNS-rebinding protection validates Host/Origin headers against a network attack that cannot # exist for an in-process app; the protection itself is pinned by # tests/server/test_streamable_http_security.py. session_manager = StreamableHTTPSessionManager( - app=_create_server(), + app=server if server is not None else _create_server(), event_store=event_store, json_response=is_json_response_enabled, security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), @@ -912,7 +883,7 @@ async def test_streamable_http_client_tool_invocation(initialized_client_session """A tool call reaches the handler and returns its content.""" # First list tools tools = await initialized_client_session.list_tools() - assert len(tools.tools) == 9 + assert len(tools.tools) == 8 assert tools.tools[0].name == "test_tool" # Call the tool @@ -945,7 +916,7 @@ async def test_streamable_http_client_session_persistence(basic_app: Starlette) # Make multiple requests to verify session persistence tools = await session.list_tools() - assert len(tools.tools) == 9 + assert len(tools.tools) == 8 # Read a resource resource = await session.read_resource(uri="foobar://test-persist") @@ -970,7 +941,7 @@ async def test_streamable_http_client_json_response(json_app: Starlette) -> None # Check tool listing tools = await session.list_tools() - assert len(tools.tools) == 9 + assert len(tools.tools) == 8 # Call a tool and verify JSON response handling result = await session.call_tool("test_tool", {}) @@ -1056,7 +1027,7 @@ async def test_streamable_http_client_session_termination(basic_app: Starlette) # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 9 + assert len(tools.tools) == 8 async with make_client(basic_app, headers=headers) as httpx_client2: async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( @@ -1117,7 +1088,7 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 9 + assert len(tools.tools) == 8 async with make_client(basic_app, headers=headers) as httpx_client2: async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( @@ -1943,13 +1914,16 @@ async def message_handler( @pytest.mark.anyio -async def test_streamable_http_multiple_reconnections( - event_app: tuple[SimpleEventStore, Starlette], -) -> None: - """Verify multiple close_sse_stream() calls each trigger a client reconnect. +async def test_streamable_http_multiple_reconnections() -> None: + """Every close_sse_stream() severs a live connection and triggers its own client reconnect. - Server uses retry_interval=500ms, tool sleeps 600ms after each close to ensure - client has time to reconnect before the next checkpoint. + The tool closes its SSE stream three times; before each next cycle it waits until the + client has observed the previous cycle's two new resumption tokens (the checkpoint and the + new connection's priming event). The priming event is sent only after the server has + re-registered the resumed stream, so once the client holds its token the next close is + guaranteed to sever a live connection rather than silently no-op — making the exact token + count below a consequence of causality, not timing margins. This pins reconnect-per-close + accounting; reconnect *latency* is pinned by test_streamable_http_client_respects_retry_interval. With 3 checkpoints, we expect 8 resumption tokens: - 1 priming (initial POST connection) @@ -1957,13 +1931,41 @@ async def test_streamable_http_multiple_reconnections( - 3 priming (one per reconnect after each close) - 1 response """ - _, app = event_app resumption_tokens: list[str] = [] + # milestones[n] fires when the client has observed n tokens. After the initial priming + # (token 1), each completed cycle i contributes exactly two tokens — checkpoint_i and the + # reconnect's priming, in either order — so cycle i is complete at 3 + 2i tokens. + milestones = {3: anyio.Event(), 5: anyio.Event(), 7: anyio.Event()} async def on_resumption_token(token: str) -> None: resumption_tokens.append(token) + milestone = milestones.get(len(resumption_tokens)) + if milestone is not None: + milestone.set() + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "multi_close_tool" + for i, milestone in enumerate(milestones.values()): + await ctx.session.send_log_message( + level="info", + data=f"checkpoint_{i}", + logger="multi_close_tool", + related_request_id=ctx.request_id, + ) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + # Client and server share one event loop, so the tool can wait directly on the + # client-side callback observing the reconnect. + with anyio.fail_after(5): + await milestone.wait() + return CallToolResult(content=[TextContent(type="text", text="Completed 3 checkpoints")]) + + server = Server("multi_reconnect_server", on_call_tool=handle_call_tool) async with ( + # retry_interval is small to keep the test fast, but nonzero so each dying connection + # finishes unwinding before its replacement registers. + running_app(event_store=SimpleEventStore(), retry_interval=50, server=server) as app, make_client(app) as http_client, streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), ClientSession(read_stream, write_stream) as session, @@ -1975,11 +1977,7 @@ async def on_resumption_token(token: str) -> None: result = await session.send_request( types.CallToolRequest( method="tools/call", - params=types.CallToolRequestParams( - name="tool_with_multiple_stream_closes", - # retry_interval=500ms, so sleep 600ms to ensure reconnect completes - arguments={"checkpoints": 3, "sleep_time": 0.6}, - ), + params=types.CallToolRequestParams(name="multi_close_tool", arguments={}), ), types.CallToolResult, metadata=metadata,