diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 98948ff99..2cb4c0748 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -207,7 +207,7 @@ def close_sse_stream(self, request_id: RequestId) -> None: send_stream.close() receive_stream.close() - def close_standalone_sse_stream(self) -> None: # pragma: no cover + def close_standalone_sse_stream(self) -> None: """Close the standalone GET SSE stream, triggering client reconnection. This method closes the HTTP connection for the standalone GET stream used @@ -221,8 +221,6 @@ def close_standalone_sse_stream(self) -> None: # pragma: no cover This is a no-op if there is no active standalone SSE stream. Requires event_store to be configured for events to be stored during the disconnect. - Currently, client reconnection for standalone GET streams is NOT - implemented - this is a known gap (see test_standalone_get_stream_reconnection). """ self.close_sse_stream(GET_STREAM_KEY) @@ -245,7 +243,7 @@ def _create_session_message( async def close_stream_callback() -> None: self.close_sse_stream(request_id) - async def close_standalone_stream_callback() -> None: # pragma: no cover + async def close_standalone_stream_callback() -> None: self.close_standalone_sse_stream() metadata = ServerMessageMetadata( @@ -421,7 +419,7 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se has_json, has_sse = self._check_accept_headers(request) if self.is_json_response_enabled: # For JSON-only responses, only require application/json - if not has_json: # pragma: no cover + if not has_json: response = self._create_error_response( "Not Acceptable: Client must accept application/json", HTTPStatus.NOT_ACCEPTABLE, diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3d5770fb6..b43a3361c 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1,16 +1,14 @@ """Tests for the StreamableHTTP server and client transport. -Contains tests for both server and client sides of the StreamableHTTP transport. +Contains tests for both server and client sides of the StreamableHTTP transport, driven +entirely in process. """ from __future__ import annotations as _annotations import json -import multiprocessing -import socket import time -import traceback -from collections.abc import AsyncIterator, Generator +from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass, field from typing import Any @@ -20,8 +18,6 @@ import anyio import httpx import pytest -import requests -import uvicorn from httpx_sse import ServerSentEvent from starlette.applications import Starlette from starlette.requests import Request @@ -46,11 +42,6 @@ from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._context import RequestContext from mcp.shared._context_streams import create_context_streams -from mcp.shared._httpx_utils import ( - MCP_DEFAULT_SSE_READ_TIMEOUT, - MCP_DEFAULT_TIMEOUT, - create_mcp_http_client, -) from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ( @@ -66,11 +57,10 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport # Test constants SERVER_NAME = "test_streamable_http_server" -TEST_SESSION_ID = "test-session-id-12345" INIT_REQUEST = { "jsonrpc": "2.0", "method": "initialize", @@ -82,9 +72,12 @@ "id": "init-1", } +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + # Helper functions -def extract_protocol_version_from_sse(response: requests.Response) -> str: +def extract_protocol_version_from_sse(response: httpx.Response) -> str: """Extract the negotiated protocol version from an SSE initialization response.""" assert response.headers.get("Content-Type") == "text/event-stream" for line in response.text.splitlines(): @@ -109,32 +102,23 @@ async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | self._events.append((stream_id, event_id, message)) return event_id - async def replay_events_after( # pragma: no cover + async def replay_events_after( self, last_event_id: EventId, send_callback: EventCallback, ) -> StreamId | None: """Replay events after the specified ID.""" - # Find the stream ID of the last event - target_stream_id = None - for stream_id, event_id, _ in self._events: - if event_id == last_event_id: - target_stream_id = stream_id - break - - if target_stream_id is None: - # If event ID not found, return None - return None + # Find the stream ID of the last event; clients always resume from a stored event. + target_stream_id = next(stream_id for stream_id, event_id, _ in self._events if event_id == last_event_id) # Convert last_event_id to int for comparison last_event_id_int = int(last_event_id) - # Replay only events from the same stream with ID > last_event_id + # Replay only events from the same stream with ID > last_event_id, skipping priming + # events (None message). for stream_id, event_id, message in self._events: - if stream_id == target_stream_id and int(event_id) > last_event_id_int: - # Skip priming events (None message) - if message is not None: - await send_callback(EventMessage(message, event_id)) + if stream_id == target_stream_id and message is not None and int(event_id) > last_event_id_int: + await send_callback(EventMessage(message, event_id)) return target_stream_id @@ -145,26 +129,23 @@ class ServerState: @asynccontextmanager -async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: # pragma: no cover +async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: yield ServerState() -async def _handle_read_resource( # pragma: no cover +async def _handle_read_resource( ctx: ServerRequestContext[ServerState], params: ReadResourceRequestParams ) -> ReadResourceResult: uri = str(params.uri) parsed = urlparse(uri) if parsed.scheme == "foobar": - text = f"Read {parsed.netloc}" - elif parsed.scheme == "slow": - await anyio.sleep(2.0) - text = f"Slow response from {parsed.netloc}" - else: - raise ValueError(f"Unknown resource: {uri}") - return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) + return ReadResourceResult( + contents=[TextResourceContents(uri=uri, text=f"Read {parsed.netloc}", mime_type="text/plain")] + ) + raise ValueError(f"Unknown resource: {uri}") -async def _handle_list_tools( # pragma: no cover +async def _handle_list_tools( ctx: ServerRequestContext[ServerState], params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -179,11 +160,6 @@ async def _handle_list_tools( # pragma: no cover description="A test tool that sends a notification", input_schema={"type": "object", "properties": {}}, ), - Tool( - name="long_running_with_checkpoints", - description="A long-running tool that sends periodic notifications", - input_schema={"type": "object", "properties": {}}, - ), Tool( name="test_sampling_tool", description="A tool that triggers server-side sampling", @@ -209,17 +185,6 @@ async def _handle_list_tools( # pragma: no cover description="Tool that sends notification1, closes stream, sends notification2, notification3", input_schema={"type": "object", "properties": {}}, ), - Tool( - name="tool_with_multiple_stream_closes", - description="Tool that closes SSE stream multiple times during execution", - input_schema={ - "type": "object", - "properties": { - "checkpoints": {"type": "integer", "default": 3}, - "sleep_time": {"type": "number", "default": 0.2}, - }, - }, - ), Tool( name="tool_with_standalone_stream_close", description="Tool that closes standalone GET stream mid-operation", @@ -229,36 +194,14 @@ async def _handle_list_tools( # pragma: no cover ) -async def _handle_call_tool( # pragma: no cover - ctx: ServerRequestContext[ServerState], params: CallToolRequestParams -) -> CallToolResult: +async def _handle_call_tool(ctx: ServerRequestContext[ServerState], params: CallToolRequestParams) -> CallToolResult: name = params.name - args = params.arguments or {} # When the tool is called, send a notification to test GET stream if name == "test_tool_with_standalone_notification": await ctx.session.send_resource_updated(uri="http://test_resource") return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) - elif name == "long_running_with_checkpoints": - await ctx.session.send_log_message( - level="info", - data="Tool started", - logger="tool", - related_request_id=ctx.request_id, - ) - - await anyio.sleep(0.1) - - await ctx.session.send_log_message( - level="info", - data="Tool is almost done", - logger="tool", - related_request_id=ctx.request_id, - ) - - return CallToolResult(content=[TextContent(type="text", text="Completed!")]) - elif name == "test_sampling_tool": sampling_result = await ctx.session.create_message( messages=[ @@ -271,15 +214,12 @@ async def _handle_call_tool( # pragma: no cover related_request_id=ctx.request_id, ) - if sampling_result.content.type == "text": - response = sampling_result.content.text - else: - response = str(sampling_result.content) + assert sampling_result.content.type == "text" return CallToolResult( content=[ TextContent( type="text", - text=f"Response from sampling: {response}", + text=f"Response from sampling: {sampling_result.content.text}", ) ] ) @@ -349,31 +289,12 @@ async def _handle_call_tool( # pragma: no cover ) return CallToolResult(content=[TextContent(type="text", text="All notifications sent")]) - elif name == "tool_with_multiple_stream_closes": - num_checkpoints = args.get("checkpoints", 3) - sleep_time = args.get("sleep_time", 0.2) - - for i in range(num_checkpoints): - await ctx.session.send_log_message( - level="info", - data=f"checkpoint_{i}", - logger="multi_close_tool", - related_request_id=ctx.request_id, - ) - - if ctx.close_sse_stream: - await ctx.close_sse_stream() - - await anyio.sleep(sleep_time) - - return CallToolResult(content=[TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")]) - elif name == "tool_with_standalone_stream_close": await ctx.session.send_resource_updated(uri="http://notification_1") await anyio.sleep(0.1) - if ctx.close_standalone_sse_stream: - await ctx.close_standalone_sse_stream() + assert ctx.close_standalone_sse_stream is not None + await ctx.close_standalone_sse_stream() await anyio.sleep(1.5) await ctx.session.send_resource_updated(uri="http://notification_2") @@ -383,7 +304,7 @@ async def _handle_call_tool( # pragma: no cover return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) -def _create_server() -> Server[ServerState]: # pragma: no cover +def _create_server() -> Server[ServerState]: return Server( SERVER_NAME, lifespan=_server_lifespan, @@ -393,113 +314,60 @@ def _create_server() -> Server[ServerState]: # pragma: no cover ) -def create_app( +@asynccontextmanager +async def running_app( is_json_response_enabled: bool = False, event_store: EventStore | None = None, retry_interval: int | None = None, -) -> Starlette: # pragma: no cover - """Create a Starlette application for testing using the session manager. + server: Server[Any] | None = None, +) -> AsyncIterator[Starlette]: + """Serve the test server's streamable HTTP app in process for the duration. Args: is_json_response_enabled: If True, use JSON responses instead of SSE streams. event_store: Optional event store for testing resumability. retry_interval: Retry interval in milliseconds for SSE polling. + server: Server to mount; defaults to the file's shared test server. """ - # Create server instance - server = _create_server() - - # Create the session manager - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) + # DNS-rebinding protection validates Host/Origin headers against a network attack that cannot + # exist for an in-process app; the protection itself is pinned by + # tests/server/test_streamable_http_security.py. session_manager = StreamableHTTPSessionManager( - app=server, + app=server if server is not None else _create_server(), event_store=event_store, json_response=is_json_response_enabled, - security_settings=security_settings, + security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), retry_interval=retry_interval, ) + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + async with session_manager.run(): + yield app - # Create an ASGI application that uses the session manager - app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], - lifespan=lambda app: session_manager.run(), - ) - - return app +def make_client(app: Starlette, headers: dict[str, str] | None = None) -> httpx.AsyncClient: + """An httpx client served in process by `app`, with create_mcp_http_client's redirect default. -def run_server( - port: int, - is_json_response_enabled: bool = False, - event_store: EventStore | None = None, - retry_interval: int | None = None, -) -> None: # pragma: no cover - """Run the test server. - - Args: - port: Port to listen on. - is_json_response_enabled: If True, use JSON responses instead of SSE streams. - event_store: Optional event store for testing resumability. - retry_interval: Retry interval in milliseconds for SSE polling. + (Starlette's Mount 307-redirects the bare /mcp path to /mcp/, which the SDK's own client + factory follows.) """ - - app = create_app(is_json_response_enabled, event_store, retry_interval) - # Configure server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="info", - limit_concurrency=10, - timeout_keep_alive=5, - access_log=False, + return httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, headers=headers, follow_redirects=True ) - # Start the server - server = uvicorn.Server(config=config) - - # This is important to catch exceptions and prevent test hangs - try: - server.run() - except Exception: - traceback.print_exc() - -# Test fixtures - using same approach as SSE tests +# Test fixtures @pytest.fixture -def basic_server_port() -> int: - """Find an available port for the basic server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +async def basic_app() -> AsyncIterator[Starlette]: + """The test server's app with SSE response mode.""" + async with running_app() as app: + yield app @pytest.fixture -def json_server_port() -> int: - """Find an available port for the JSON response server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def basic_server(basic_server_port: int) -> Generator[None, None, None]: - """Start a basic server.""" - proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) +async def json_app() -> AsyncIterator[Starlette]: + """The test server's app with JSON response mode.""" + async with running_app(is_json_response_enabled=True) as app: + yield app @pytest.fixture @@ -509,82 +377,29 @@ def event_store() -> SimpleEventStore: @pytest.fixture -def event_server_port() -> int: - """Find an available port for the event store server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def event_server( - event_server_port: int, event_store: SimpleEventStore -) -> Generator[tuple[SimpleEventStore, str], None, None]: - """Start a server with event store and retry_interval enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": event_server_port, "event_store": event_store, "retry_interval": 500}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(event_server_port) - - yield event_store, f"http://127.0.0.1:{event_server_port}" - - # Clean up - proc.kill() - proc.join(timeout=2) - - -@pytest.fixture -def json_response_server(json_server_port: int) -> Generator[None, None, None]: - """Start a server with JSON response enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": json_server_port, "is_json_response_enabled": True}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(json_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) - - -@pytest.fixture -def basic_server_url(basic_server_port: int) -> str: - """Get the URL for the basic test server.""" - return f"http://127.0.0.1:{basic_server_port}" - - -@pytest.fixture -def json_server_url(json_server_port: int) -> str: - """Get the URL for the JSON response test server.""" - return f"http://127.0.0.1:{json_server_port}" +async def event_app(event_store: SimpleEventStore) -> AsyncIterator[tuple[SimpleEventStore, Starlette]]: + """The test server's app with an event store and retry_interval enabled.""" + async with running_app(event_store=event_store, retry_interval=500) as app: + yield event_store, app # Basic request validation tests -def test_accept_header_validation(basic_server: None, basic_server_url: str): - """Test that Accept header is properly validated.""" - # Test without Accept header (suppress requests library default Accept: */*) - session = requests.Session() - session.headers.pop("Accept") - response = session.post( - f"{basic_server_url}/mcp", - headers={"Content-Type": "application/json"}, - json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text +@pytest.mark.anyio +async def test_accept_header_validation(basic_app: Starlette) -> None: + """A POST without an Accept header is rejected with 406.""" + async with make_client(basic_app) as client: + # Suppress the httpx client default Accept: */* header + del client.headers["accept"] + response = await client.post( + "/mcp", + headers={"Content-Type": "application/json"}, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text +@pytest.mark.anyio @pytest.mark.parametrize( "accept_header", [ @@ -596,19 +411,21 @@ def test_accept_header_validation(basic_server: None, basic_server_url: str): "application/*;q=0.9, text/*;q=0.8", ], ) -def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str): - """Test that wildcard Accept headers are accepted per RFC 7231.""" - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": accept_header, - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 +async def test_accept_header_wildcard(basic_app: Starlette, accept_header: str) -> None: + """Wildcard Accept headers are accepted per RFC 7231.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 +@pytest.mark.anyio @pytest.mark.parametrize( "accept_header", [ @@ -617,100 +434,104 @@ def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accep "text/*", ], ) -def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str): - """Test that incompatible Accept headers are rejected for SSE mode.""" - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": accept_header, - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - -def test_content_type_validation(basic_server: None, basic_server_url: str): - """Test that Content-Type header is properly validated.""" - # Test with incorrect Content-Type - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "text/plain", - }, - data="This is not JSON", - ) +async def test_accept_header_incompatible(basic_app: Starlette, accept_header: str) -> None: + """Accept headers that cannot cover both response representations are rejected for SSE mode.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text - assert response.status_code == 400 - assert "Invalid Content-Type" in response.text +@pytest.mark.anyio +async def test_content_type_validation(basic_app: Starlette) -> None: + """A POST whose Content-Type is not application/json is rejected with 400.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "text/plain", + }, + content="This is not JSON", + ) + + assert response.status_code == 400 + assert "Invalid Content-Type" in response.text + + +@pytest.mark.anyio +async def test_json_validation(basic_app: Starlette) -> None: + """A POST body that is not valid JSON is rejected with a parse error.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + content="this is not valid json", + ) + assert response.status_code == 400 + assert "Parse error" in response.text -def test_json_validation(basic_server: None, basic_server_url: str): - """Test that JSON content is properly validated.""" - # Test with invalid JSON - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - data="this is not valid json", - ) - assert response.status_code == 400 - assert "Parse error" in response.text - - -def test_json_parsing(basic_server: None, basic_server_url: str): - """Test that JSON content is properly parse.""" - # Test with valid JSON but invalid JSON-RPC - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json={"foo": "bar"}, - ) - assert response.status_code == 400 - assert "Validation error" in response.text - - -def test_method_not_allowed(basic_server: None, basic_server_url: str): - """Test that unsupported HTTP methods are rejected.""" - # Test with unsupported method (PUT) - response = requests.put( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, - ) - assert response.status_code == 405 - assert "Method Not Allowed" in response.text - - -def test_session_validation(basic_server: None, basic_server_url: str): - """Test session ID validation.""" - # session_id not used directly in this test - - # Test without session ID - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, - ) - assert response.status_code == 400 - assert "Missing session ID" in response.text + +@pytest.mark.anyio +async def test_json_parsing(basic_app: Starlette) -> None: + """Valid JSON that is not a JSON-RPC message is rejected with a validation error.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"foo": "bar"}, + ) + assert response.status_code == 400 + assert "Validation error" in response.text -def test_session_id_pattern(): - """Test that SESSION_ID_PATTERN correctly validates session IDs.""" +@pytest.mark.anyio +async def test_method_not_allowed(basic_app: Starlette) -> None: + """Unsupported HTTP methods are rejected with 405.""" + async with make_client(basic_app) as client: + response = await client.put( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 405 + assert "Method Not Allowed" in response.text + + +@pytest.mark.anyio +async def test_session_validation(basic_app: Starlette) -> None: + """A non-initialize request without a session ID is rejected with 400.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, + ) + assert response.status_code == 400 + assert "Missing session ID" in response.text + + +def test_session_id_pattern() -> None: + """SESSION_ID_PATTERN accepts visible ASCII (0x21-0x7E) and rejects everything else.""" # Valid session IDs (visible ASCII characters from 0x21 to 0x7E) valid_session_ids = [ "test-session-id", @@ -744,8 +565,8 @@ def test_session_id_pattern(): assert SESSION_ID_PATTERN.fullmatch(session_id) is None -def test_streamable_http_transport_init_validation(): - """Test that StreamableHTTPServerTransport validates session ID on init.""" +def test_streamable_http_transport_init_validation() -> None: + """StreamableHTTPServerTransport accepts valid or absent session IDs and rejects invalid ones.""" # Valid session ID should initialize without errors valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id") assert valid_transport.mcp_session_id == "valid-id" @@ -767,144 +588,153 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_session_termination(basic_server: None, basic_server_url: str): - """Test session termination via DELETE and subsequent request handling.""" - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 +@pytest.mark.anyio +async def test_session_termination(basic_app: Starlette) -> None: + """DELETE terminates the session, after which requests for it return 404.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 - # Extract negotiated protocol version from SSE response - negotiated_version = extract_protocol_version_from_sse(response) + # Extract negotiated protocol version from SSE response + negotiated_version = extract_protocol_version_from_sse(response) - # Now terminate the session - session_id = response.headers.get(MCP_SESSION_ID_HEADER) - response = requests.delete( - f"{basic_server_url}/mcp", - headers={ - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - ) - assert response.status_code == 200 - - # Try to use the terminated session - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - }, - json={"jsonrpc": "2.0", "method": "ping", "id": 2}, - ) - assert response.status_code == 404 - assert "Session has been terminated" in response.text - - -def test_response(basic_server: None, basic_server_url: str): - """Test response handling for a valid request.""" - mcp_url = f"{basic_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 + # Now terminate the session + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + response = await client.delete( + "/mcp", + headers={ + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + ) + assert response.status_code == 200 + + # Try to use the terminated session + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "ping", "id": 2}, + ) + assert response.status_code == 404 + assert "Session has been terminated" in response.text - # Extract negotiated protocol version from SSE response - negotiated_version = extract_protocol_version_from_sse(response) - # Now get the session ID - session_id = response.headers.get(MCP_SESSION_ID_HEADER) +@pytest.mark.anyio +async def test_response(basic_app: Starlette) -> None: + """A request on an initialized session is answered on a text/event-stream response.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 - # Try to use the session with proper headers - tools_response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, - stream=True, - ) - assert tools_response.status_code == 200 - assert tools_response.headers.get("Content-Type") == "text/event-stream" - - -def test_json_response(json_response_server: None, json_server_url: str): - """Test response handling when is_json_response_enabled is True.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" - - -def test_json_response_accept_json_only(json_response_server: None, json_server_url: str): - """Test that json_response servers only require application/json in Accept header.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" - - -def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): - """Test that json_response servers reject requests without Accept header.""" - mcp_url = f"{json_server_url}/mcp" - # Suppress requests library default Accept: */* header - session = requests.Session() - session.headers.pop("Accept") - response = session.post( - mcp_url, - headers={ - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text + # Extract negotiated protocol version from SSE response + negotiated_version = extract_protocol_version_from_sse(response) + # Now get the session ID + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Try to use the session with proper headers + async with client.stream( + "POST", + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, + ) as tools_response: + assert tools_response.status_code == 200 + assert tools_response.headers.get("Content-Type") == "text/event-stream" -def test_json_response_incorrect_accept_header(json_response_server: None, json_server_url: str): - """Test that json_response servers reject requests with incorrect Accept header.""" - mcp_url = f"{json_server_url}/mcp" - # Test with only text/event-stream (wrong for JSON server) - response = requests.post( - mcp_url, - headers={ - "Accept": "text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text +@pytest.mark.anyio +async def test_json_response(json_app: Starlette) -> None: + """With JSON response mode enabled, requests are answered with application/json bodies.""" + async with make_client(json_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" + + +@pytest.mark.anyio +async def test_json_response_accept_json_only(json_app: Starlette) -> None: + """JSON response mode only requires application/json in the Accept header.""" + async with make_client(json_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" + + +@pytest.mark.anyio +async def test_json_response_missing_accept_header(json_app: Starlette) -> None: + """JSON response mode still rejects requests without an Accept header.""" + async with make_client(json_app) as client: + # Suppress the httpx client default Accept: */* header + del client.headers["accept"] + response = await client.post( + "/mcp", + headers={ + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + +@pytest.mark.anyio +async def test_json_response_incorrect_accept_header(json_app: Starlette) -> None: + """JSON response mode rejects an Accept header that does not cover application/json.""" + async with make_client(json_app) as client: + # Test with only text/event-stream (wrong for JSON server) + response = await client.post( + "/mcp", + headers={ + "Accept": "text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + +@pytest.mark.anyio @pytest.mark.parametrize( "accept_header", [ @@ -913,167 +743,134 @@ def test_json_response_incorrect_accept_header(json_response_server: None, json_ "application/*;q=0.9", ], ) -def test_json_response_wildcard_accept_header(json_response_server: None, json_server_url: str, accept_header: str): - """Test that json_response servers accept wildcard Accept headers per RFC 7231.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": accept_header, - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" - - -def test_get_sse_stream(basic_server: None, basic_server_url: str): - """Test establishing an SSE stream via GET request.""" - # First, we need to initialize a session - mcp_url = f"{basic_server_url}/mcp" - init_response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 +async def test_json_response_wildcard_accept_header(json_app: Starlette, accept_header: str) -> None: + """JSON response mode accepts wildcard Accept headers per RFC 7231.""" + async with make_client(json_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" - # Get the session ID - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - assert session_id is not None - # Extract negotiated protocol version from SSE response - init_data = None - assert init_response.headers.get("Content-Type") == "text/event-stream" - for line in init_response.text.splitlines(): # pragma: no branch - if line.startswith("data: "): - init_data = json.loads(line[6:]) - break - assert init_data is not None - negotiated_version = init_data["result"]["protocolVersion"] - - # Now attempt to establish an SSE stream via GET - get_response = requests.get( - mcp_url, - headers={ - "Accept": "text/event-stream", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - stream=True, - ) +@pytest.mark.anyio +async def test_get_sse_stream(basic_app: Starlette) -> None: + """GET establishes the standalone SSE stream, and a second GET is rejected with 409.""" + async with make_client(basic_app) as client: + # First, we need to initialize a session + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 - # Verify we got a successful response with the right content type - assert get_response.status_code == 200 - assert get_response.headers.get("Content-Type") == "text/event-stream" + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + negotiated_version = extract_protocol_version_from_sse(init_response) - # Test that a second GET request gets rejected (only one stream allowed) - second_get = requests.get( - mcp_url, - headers={ + # Now attempt to establish an SSE stream via GET + get_headers = { "Accept": "text/event-stream", MCP_SESSION_ID_HEADER: session_id, MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - stream=True, - ) + } + # The streams enter in order, so the second GET arrives while the first is held open. + async with ( + client.stream("GET", "/mcp", headers=get_headers) as get_response, + client.stream("GET", "/mcp", headers=get_headers) as second_get, + ): + # Verify we got a successful response with the right content type + assert get_response.status_code == 200 + assert get_response.headers.get("Content-Type") == "text/event-stream" - # Should get CONFLICT (409) since there's already a stream - # Note: This might fail if the first stream fully closed before this runs, - # but generally it should work in the test environment where it runs quickly - assert second_get.status_code == 409 - - -def test_get_validation(basic_server: None, basic_server_url: str): - """Test validation for GET requests.""" - # First, we need to initialize a session - mcp_url = f"{basic_server_url}/mcp" - init_response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 + # The second GET gets CONFLICT (409): only one standalone stream is allowed per session. + assert second_get.status_code == 409 - # Get the session ID - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - assert session_id is not None - # Extract negotiated protocol version from SSE response - init_data = None - assert init_response.headers.get("Content-Type") == "text/event-stream" - for line in init_response.text.splitlines(): # pragma: no branch - if line.startswith("data: "): - init_data = json.loads(line[6:]) - break - assert init_data is not None - negotiated_version = init_data["result"]["protocolVersion"] - - # Test without Accept header (suppress requests library default Accept: */*) - session = requests.Session() - session.headers.pop("Accept") - response = session.get( - mcp_url, - headers={ - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - stream=True, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - # Test with wrong Accept header - response = requests.get( - mcp_url, - headers={ - "Accept": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text +@pytest.mark.anyio +async def test_get_validation(basic_app: Starlette) -> None: + """A GET without an Accept header covering text/event-stream is rejected with 406.""" + async with make_client(basic_app) as client: + # First, we need to initialize a session + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + negotiated_version = extract_protocol_version_from_sse(init_response) + + # Test without Accept header (suppress the httpx client default Accept: */*) + del client.headers["accept"] + response = await client.get( + "/mcp", + headers={ + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + # Test with wrong Accept header + response = await client.get( + "/mcp", + headers={ + "Accept": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text # Client-specific fixtures @pytest.fixture -async def http_client(basic_server: None, basic_server_url: str): # pragma: no cover - """Create test client matching the SSE test pattern.""" - async with httpx.AsyncClient(base_url=basic_server_url) as client: - yield client - - -@pytest.fixture -async def initialized_client_session(basic_server: None, basic_server_url: str): +async def initialized_client_session(basic_app: Starlette) -> AsyncIterator[ClientSession]: """Create initialized StreamableHTTP client session.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - yield session + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + yield session @pytest.mark.anyio -async def test_streamable_http_client_basic_connection(basic_server: None, basic_server_url: str): - """Test basic client connection with initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == SERVER_NAME +async def test_streamable_http_client_basic_connection(basic_app: Starlette) -> None: + """A client initializes against a server over the StreamableHTTP transport.""" + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.server_info.name == SERVER_NAME @pytest.mark.anyio -async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession): - """Test client resource read functionality.""" +async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession) -> None: + """A resource read round-trips its arguments and the handler's content.""" response = await initialized_client_session.read_resource(uri="foobar://test-resource") assert len(response.contents) == 1 assert response.contents[0].uri == "foobar://test-resource" @@ -1082,11 +879,11 @@ async def test_streamable_http_client_resource_read(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession): - """Test client tool invocation.""" +async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession) -> None: + """A tool call reaches the handler and returns its content.""" # First list tools tools = await initialized_client_session.list_tools() - assert len(tools.tools) == 10 + assert len(tools.tools) == 8 assert tools.tools[0].name == "test_tool" # Call the tool @@ -1097,8 +894,8 @@ async def test_streamable_http_client_tool_invocation(initialized_client_session @pytest.mark.anyio -async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession): - """Test error handling in client.""" +async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession) -> None: + """A server-side error reaches the client as an MCPError with the handler's message.""" with pytest.raises(MCPError) as exc_info: await initialized_client_session.read_resource(uri="unknown://test-error") assert exc_info.value.error.code == 0 @@ -1106,50 +903,56 @@ async def test_streamable_http_client_error_handling(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str): - """Test that session ID persists across requests.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - - # Make multiple requests to verify session persistence - tools = await session.list_tools() - assert len(tools.tools) == 10 - - # Read a resource - resource = await session.read_resource(uri="foobar://test-persist") - assert isinstance(resource.contents[0], TextResourceContents) is True - content = resource.contents[0] - assert isinstance(content, TextResourceContents) - assert content.text == "Read test-persist" +async def test_streamable_http_client_session_persistence(basic_app: Starlette) -> None: + """The session persists across multiple requests on one connection.""" + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make multiple requests to verify session persistence + tools = await session.list_tools() + assert len(tools.tools) == 8 + + # Read a resource + resource = await session.read_resource(uri="foobar://test-persist") + assert isinstance(resource.contents[0], TextResourceContents) is True + content = resource.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Read test-persist" @pytest.mark.anyio -async def test_streamable_http_client_json_response(json_response_server: None, json_server_url: str): - """Test client with JSON response mode.""" - async with streamable_http_client(f"{json_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == SERVER_NAME - - # Check tool listing - tools = await session.list_tools() - assert len(tools.tools) == 10 - - # Call a tool and verify JSON response handling - result = await session.call_tool("test_tool", {}) - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert result.content[0].text == "Called test_tool" +async def test_streamable_http_client_json_response(json_app: Starlette) -> None: + """The client works identically against a server in JSON response mode.""" + async with ( + make_client(json_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.server_info.name == SERVER_NAME + + # Check tool listing + tools = await session.list_tools() + assert len(tools.tools) == 8 + + # Call a tool and verify JSON response handling + result = await session.call_tool("test_tool", {}) + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "Called test_tool" @pytest.mark.anyio -async def test_streamable_http_client_get_stream(basic_server: None, basic_server_url: str): - """Test GET stream functionality for server-initiated messages.""" +async def test_streamable_http_client_get_stream(basic_app: Starlette) -> None: + """A server-initiated notification reaches the client on the standalone GET stream.""" notifications_received: list[types.ServerNotification] = [] # Define message handler to capture notifications @@ -1159,30 +962,33 @@ async def message_handler( # pragma: no branch if isinstance(message, types.ServerNotification): # pragma: no branch notifications_received.append(message) - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Initialize the session - this triggers the GET stream setup - result = await session.initialize() - assert isinstance(result, InitializeResult) + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + # Initialize the session - this triggers the GET stream setup + result = await session.initialize() + assert isinstance(result, InitializeResult) - # Call the special tool that sends a notification - await session.call_tool("test_tool_with_standalone_notification", {}) + # Call the special tool that sends a notification + await session.call_tool("test_tool_with_standalone_notification", {}) - # Verify we received the notification - assert len(notifications_received) > 0 + # Verify we received the notification + assert len(notifications_received) > 0 - # Verify the notification is a ResourceUpdatedNotification - resource_update_found = False - for notif in notifications_received: - if isinstance(notif, types.ResourceUpdatedNotification): # pragma: no branch - assert str(notif.params.uri) == "http://test_resource" - resource_update_found = True + # Verify the notification is a ResourceUpdatedNotification + resource_update_found = False + for notif in notifications_received: + if isinstance(notif, types.ResourceUpdatedNotification): # pragma: no branch + assert str(notif.params.uri) == "http://test_resource" + resource_update_found = True - assert resource_update_found, "ResourceUpdatedNotification not received via GET stream" + assert resource_update_found, "ResourceUpdatedNotification not received via GET stream" -def create_session_id_capturing_client() -> tuple[httpx.AsyncClient, list[str]]: - """Create an httpx client that captures the session ID from responses.""" +def create_session_id_capturing_client(app: Starlette) -> tuple[httpx.AsyncClient, list[str]]: + """Create an in-process httpx client that captures the session ID from responses.""" captured_ids: list[str] = [] async def capture_session_id(response: httpx.Response) -> None: @@ -1191,21 +997,22 @@ async def capture_session_id(response: httpx.Response) -> None: captured_ids.append(session_id) client = httpx.AsyncClient( + transport=StreamingASGITransport(app), + base_url=BASE_URL, follow_redirects=True, - timeout=httpx.Timeout(MCP_DEFAULT_TIMEOUT, read=MCP_DEFAULT_SSE_READ_TIMEOUT), event_hooks={"response": [capture_session_id]}, ) return client, captured_ids @pytest.mark.anyio -async def test_streamable_http_client_session_termination(basic_server: None, basic_server_url: str): - """Test client session termination functionality.""" +async def test_streamable_http_client_session_termination(basic_app: Starlette) -> None: + """After the client terminates its session on close, a new connection with that session ID fails.""" # Use httpx client with event hooks to capture session ID - httpx_client, captured_ids = create_session_id_capturing_client() + httpx_client, captured_ids = create_session_id_capturing_client(basic_app) async with httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1220,10 +1027,10 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 10 + assert len(tools.tools) == 8 - async with create_mcp_http_client(headers=headers) as httpx_client2: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as ( + async with make_client(basic_app, headers=headers) as httpx_client2: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( read_stream, write_stream, ): @@ -1235,9 +1042,9 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba @pytest.mark.anyio async def test_streamable_http_client_session_termination_204( - basic_server: None, basic_server_url: str, monkeypatch: pytest.MonkeyPatch -): - """Test client session termination functionality with a 204 response. + basic_app: Starlette, monkeypatch: pytest.MonkeyPatch +) -> None: + """Session termination also succeeds when the server answers the DELETE with 204. This test patches the httpx client to return a 204 response for DELETEs. """ @@ -1263,10 +1070,10 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt monkeypatch.setattr(httpx.AsyncClient, "delete", mock_delete) # Use httpx client with event hooks to capture session ID - httpx_client, captured_ids = create_session_id_capturing_client() + httpx_client, captured_ids = create_session_id_capturing_client(basic_app) async with httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1281,10 +1088,10 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 10 + assert len(tools.tools) == 8 - async with create_mcp_http_client(headers=headers) as httpx_client2: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as ( + async with make_client(basic_app, headers=headers) as httpx_client2: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( read_stream, write_stream, ): @@ -1295,14 +1102,15 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt @pytest.mark.anyio -async def test_streamable_http_client_resumption(event_server: tuple[SimpleEventStore, str]): - """Test client session resumption using sync primitives for reliable coordination.""" - _, server_url = event_server +async def test_streamable_http_client_resumption(event_app: tuple[SimpleEventStore, Starlette]) -> None: + """A second client resumes an interrupted request with a resumption token and receives the rest.""" + _, app = event_app # Variables to track the state captured_resumption_token: str | None = None captured_notifications: list[types.ServerNotification] = [] - first_notification_received = False + first_notification_received = anyio.Event() + resumption_token_received = anyio.Event() async def message_handler( # pragma: no branch message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, @@ -1312,19 +1120,19 @@ async def message_handler( # pragma: no branch # Look for our first notification if isinstance(message, types.LoggingMessageNotification): # pragma: no branch if message.params.data == "First notification before lock": - nonlocal first_notification_received - first_notification_received = True + first_notification_received.set() async def on_resumption_token_update(token: str) -> None: nonlocal captured_resumption_token captured_resumption_token = token + resumption_token_received.set() # Use httpx client with event hooks to capture session ID - httpx_client, captured_ids = create_session_id_capturing_client() + httpx_client, captured_ids = create_session_id_capturing_client(app) # First, start the client session and begin the tool that waits on lock async with httpx_client: - async with streamable_http_client(f"{server_url}/mcp", terminate_on_close=False, http_client=httpx_client) as ( + async with streamable_http_client(f"{BASE_URL}/mcp", terminate_on_close=False, http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1363,13 +1171,13 @@ async def run_tool(): tg.start_soon(run_tool) # Wait for the first notification and resumption token - while not first_notification_received or not captured_resumption_token: - await anyio.sleep(0.1) + with anyio.fail_after(5): + await first_notification_received.wait() + await resumption_token_received.wait() - # The while loop only exits after first_notification_received=True, - # which is set by message_handler immediately after appending to - # captured_notifications. The server tool is blocked on its lock, - # so nothing else can arrive before we cancel. + # first_notification_received is set by message_handler immediately + # after appending to captured_notifications. The server tool is + # blocked on its lock, so nothing else can arrive before we cancel. assert len(captured_notifications) == 1 assert isinstance(captured_notifications[0], types.LoggingMessageNotification) assert captured_notifications[0].params.data == "First notification before lock" @@ -1379,8 +1187,8 @@ async def run_tool(): # Kill the client session while tool is waiting on lock tg.cancel_scope.cancel() - async with create_mcp_http_client(headers=headers) as httpx_client2: - async with streamable_http_client(f"{server_url}/mcp", http_client=httpx_client2) as ( + async with make_client(app, headers=headers) as httpx_client2: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( read_stream, write_stream, ): @@ -1413,8 +1221,8 @@ async def run_tool(): @pytest.mark.anyio -async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str): - """Test server-initiated sampling request through streamable HTTP transport.""" +async def test_streamablehttp_server_sampling(basic_app: Starlette) -> None: + """A server-initiated sampling request reaches the client callback and its result the tool.""" # Variable to track if sampling callback was invoked sampling_callback_invoked = False captured_message_params = None @@ -1441,29 +1249,32 @@ async def sampling_callback( ) # Create client with sampling callback - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - - # Call the tool that triggers server-side sampling - tool_result = await session.call_tool("test_sampling_tool", {}) - - # Verify the tool result contains the expected content - assert len(tool_result.content) == 1 - assert tool_result.content[0].type == "text" - assert "Response from sampling: Received message from server" in tool_result.content[0].text - - # Verify sampling callback was invoked - assert sampling_callback_invoked - assert captured_message_params is not None - assert len(captured_message_params.messages) == 1 - assert captured_message_params.messages[0].content.text == "Server needs client sampling" + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session, + ): + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Call the tool that triggers server-side sampling + tool_result = await session.call_tool("test_sampling_tool", {}) + + # Verify the tool result contains the expected content + assert len(tool_result.content) == 1 + assert tool_result.content[0].type == "text" + assert "Response from sampling: Received message from server" in tool_result.content[0].text + + # Verify sampling callback was invoked + assert sampling_callback_invoked + assert captured_message_params is not None + assert len(captured_message_params.messages) == 1 + assert captured_message_params.messages[0].content.text == "Server needs client sampling" # Context-aware server implementation for testing request context propagation -async def _handle_context_list_tools( # pragma: no cover +async def _handle_context_list_tools( ctx: ServerRequestContext, params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -1488,97 +1299,51 @@ async def _handle_context_list_tools( # pragma: no cover ) -async def _handle_context_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - name = params.name - args = params.arguments or {} - - if name == "echo_headers": - headers_info: dict[str, Any] = {} - if ctx.request and isinstance(ctx.request, Request): - headers_info = dict(ctx.request.headers) - return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) - - elif name == "echo_context": - context_data: dict[str, Any] = { - "request_id": args.get("request_id"), - "headers": {}, - "method": None, - "path": None, - } - if ctx.request and isinstance(ctx.request, Request): - request = ctx.request - context_data["headers"] = dict(request.headers) - context_data["method"] = request.method - context_data["path"] = request.url.path - return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) +async def _handle_context_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name in ("echo_headers", "echo_context") + assert isinstance(ctx.request, Request) + + if params.name == "echo_headers": + return CallToolResult(content=[TextContent(type="text", text=json.dumps(dict(ctx.request.headers)))]) - return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) + assert params.arguments is not None + context_data: dict[str, Any] = { + "request_id": params.arguments.get("request_id"), + "headers": dict(ctx.request.headers), + "method": ctx.request.method, + "path": ctx.request.url.path, + } + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) -# Server runner for context-aware testing -def run_context_aware_server(port: int): # pragma: no cover - """Run the context-aware test server.""" +@pytest.fixture +async def context_app() -> AsyncIterator[Starlette]: + """An app whose server echoes request context, served in process.""" server = Server( "ContextAwareServer", on_list_tools=_handle_context_list_tools, on_call_tool=_handle_context_call_tool, ) - session_manager = StreamableHTTPSessionManager( app=server, - event_store=None, - json_response=False, + security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), ) - - app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], - lifespan=lambda app: session_manager.run(), - ) - - server_instance = uvicorn.Server( - config=uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - ) - server_instance.run() - - -@pytest.fixture -def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: - """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Context-aware server process failed to terminate") + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + async with session_manager.run(): + yield app @pytest.mark.anyio -async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None: - """Test that request context is properly propagated through StreamableHTTP.""" +async def test_streamablehttp_request_context_propagation(context_app: Starlette) -> None: + """Custom HTTP headers on the connection are visible to server handlers via ctx.request.""" custom_headers = { "Authorization": "Bearer test-token", "X-Custom-Header": "test-value", "X-Trace-Id": "trace-123", } - async with create_mcp_http_client(headers=custom_headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with make_client(context_app, headers=custom_headers) as httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1602,11 +1367,11 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: @pytest.mark.anyio -async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: - """Test that request contexts are isolated between StreamableHTTP clients.""" +async def test_streamablehttp_request_context_isolation(context_app: Starlette) -> None: + """Each connection's handlers see only that connection's request headers.""" contexts: list[dict[str, Any]] = [] - # Create multiple clients with different headers + # Connect three clients in turn, each with its own headers. for i in range(3): headers = { "X-Request-Id": f"request-{i}", @@ -1614,8 +1379,8 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No "Authorization": f"Bearer token-{i}", } - async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with make_client(context_app, headers=headers) as httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1640,145 +1405,160 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No @pytest.mark.anyio -async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): - """Test that client includes mcp-protocol-version header after initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Initialize and get the negotiated version - init_result = await session.initialize() - negotiated_version = init_result.protocol_version - - # Call a tool that echoes headers to verify the header is present - tool_result = await session.call_tool("echo_headers", {}) - - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - headers_data = json.loads(tool_result.content[0].text) - - # Verify protocol version header is present - assert "mcp-protocol-version" in headers_data - assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version - - -def test_server_validates_protocol_version_header(basic_server: None, basic_server_url: str): - """Test that server returns 400 Bad Request version if header unsupported or invalid.""" - # First initialize a session to get a valid session ID - init_response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - - # Test request with invalid protocol version (should fail) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: "invalid-version", - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"}, - ) - assert response.status_code == 400 - assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() - - # Test request with unsupported protocol version (should fail) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"}, - ) - assert response.status_code == 400 - assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() +async def test_client_includes_protocol_version_header_after_init(context_app: Starlette) -> None: + """After initialization, every client request carries the negotiated protocol version header.""" + async with ( + make_client(context_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + # Initialize and get the negotiated version + init_result = await session.initialize() + negotiated_version = init_result.protocol_version - # Test request with valid protocol version (should succeed) - negotiated_version = extract_protocol_version_from_sse(init_response) + # Call a tool that echoes headers to verify the header is present + tool_result = await session.call_tool("echo_headers", {}) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"}, - ) - assert response.status_code == 200 - - -def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): - """Test server accepts requests without protocol version header.""" - # First initialize a session to get a valid session ID - init_response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - - # Test request without mcp-protocol-version header (backwards compatibility) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-backwards-compat"}, - stream=True, - ) - assert response.status_code == 200 # Should succeed for backwards compatibility - assert response.headers.get("Content-Type") == "text/event-stream" + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + headers_data = json.loads(tool_result.content[0].text) + + # Verify protocol version header is present + assert "mcp-protocol-version" in headers_data + assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version + + +@pytest.mark.anyio +async def test_server_validates_protocol_version_header(basic_app: Starlette) -> None: + """An invalid or unsupported protocol version header is rejected with 400; the negotiated one passes.""" + async with make_client(basic_app) as client: + # First initialize a session to get a valid session ID + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Test request with invalid protocol version (should fail) + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: "invalid-version", + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"}, + ) + assert response.status_code == 400 + assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + + # Test request with unsupported protocol version (should fail) + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"}, + ) + assert response.status_code == 400 + assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + + # Test request with valid protocol version (should succeed) + negotiated_version = extract_protocol_version_from_sse(init_response) + + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"}, + ) + assert response.status_code == 200 @pytest.mark.anyio -async def test_client_crash_handled(basic_server: None, basic_server_url: str): - """Test that cases where the client crashes are handled gracefully.""" +async def test_server_backwards_compatibility_no_protocol_version(basic_app: Starlette) -> None: + """A request without a protocol version header is accepted for backwards compatibility.""" + async with make_client(basic_app) as client: + # First initialize a session to get a valid session ID + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Test request without mcp-protocol-version header (backwards compatibility) + async with client.stream( + "POST", + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-backwards-compat"}, + ) as response: + assert response.status_code == 200 # Should succeed for backwards compatibility + assert response.headers.get("Content-Type") == "text/event-stream" + + +@pytest.mark.anyio +async def test_client_crash_handled(basic_app: Starlette) -> None: + """A client crashing mid-session does not prevent later clients from connecting.""" # Simulate bad client that crashes after init async def bad_client(): """Client that triggers ClosedResourceError""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - raise Exception("client crash") + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + raise Exception("client crash") - # Run bad client a few times to trigger the crash + # Run bad client a few times to trigger the crash. The crash surfaces wrapped in exception + # groups whose exact shape is not the subject here — what matters is that the server survives. for _ in range(3): try: await bad_client() except Exception: pass - await anyio.sleep(0.1) # Try a good client, it should still be able to connect and list tools - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - result = await session.initialize() - assert isinstance(result, InitializeResult) - tools = await session.list_tools() - assert tools.tools + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + result = await session.initialize() + assert isinstance(result, InitializeResult) + tools = await session.list_tools() + assert tools.tools @pytest.mark.anyio -async def test_handle_sse_event_skips_empty_data(): - """Test that _handle_sse_event skips empty SSE data (keep-alive pings).""" +async def test_handle_sse_event_skips_empty_data() -> None: + """_handle_sse_event skips empty SSE data (keep-alive pings) without writing to the stream.""" transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") # Create a mock SSE event with empty data (keep-alive ping) @@ -1804,8 +1584,8 @@ async def test_handle_sse_event_skips_empty_data(): @pytest.mark.anyio -async def test_priming_event_not_sent_for_old_protocol_version(): - """Test that _maybe_send_priming_event skips for old protocol versions (backwards compat).""" +async def test_priming_event_not_sent_for_old_protocol_version() -> None: + """_maybe_send_priming_event skips for old protocol versions (backwards compat).""" # Create a transport with an event store transport = StreamableHTTPServerTransport( "/mcp", @@ -1833,8 +1613,8 @@ async def test_priming_event_not_sent_for_old_protocol_version(): @pytest.mark.anyio -async def test_priming_event_not_sent_without_event_store(): - """Test that _maybe_send_priming_event returns early when no event_store is configured.""" +async def test_priming_event_not_sent_without_event_store() -> None: + """_maybe_send_priming_event returns early when no event_store is configured.""" # Create a transport WITHOUT an event store transport = StreamableHTTPServerTransport("/mcp") @@ -1853,8 +1633,8 @@ async def test_priming_event_not_sent_without_event_store(): @pytest.mark.anyio -async def test_priming_event_includes_retry_interval(): - """Test that _maybe_send_priming_event includes retry field when retry_interval is set.""" +async def test_priming_event_includes_retry_interval() -> None: + """_maybe_send_priming_event includes the retry field when retry_interval is set.""" # Create a transport with an event store AND retry_interval transport = StreamableHTTPServerTransport( "/mcp", @@ -1882,8 +1662,8 @@ async def test_priming_event_includes_retry_interval(): @pytest.mark.anyio -async def test_close_sse_stream_callback_not_provided_for_old_protocol_version(): - """Test that close_sse_stream callbacks are NOT provided for old protocol versions.""" +async def test_close_sse_stream_callback_not_provided_for_old_protocol_version() -> None: + """close_sse_stream callbacks are only provided for protocol versions that support polling.""" # Create a transport with an event store transport = StreamableHTTPServerTransport( "/mcp", @@ -1915,71 +1695,76 @@ async def test_close_sse_stream_callback_not_provided_for_old_protocol_version() @pytest.mark.anyio async def test_streamable_http_client_receives_priming_event( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Client should receive priming event (resumption token update) on POST SSE stream.""" - _, server_url = event_server + _, app = event_app captured_resumption_tokens: list[str] = [] async def on_resumption_token_update(token: str) -> None: captured_resumption_tokens.append(token) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() - # Call tool with resumption token callback via send_request - metadata = ClientMessageMetadata( - on_resumption_token_update=on_resumption_token_update, - ) - result = await session.send_request( - types.CallToolRequest(params=types.CallToolRequestParams(name="test_tool", arguments={})), - types.CallToolResult, - metadata=metadata, - ) - assert result is not None - - # Should have received priming event token BEFORE response data - # Priming event = 1 token (empty data, id only) - # Response = 1 token (actual JSON-RPC response) - # Total = 2 tokens minimum - assert len(captured_resumption_tokens) >= 2, ( - f"Server must send priming event before response. " - f"Expected >= 2 tokens (priming + response), got {len(captured_resumption_tokens)}" - ) - assert captured_resumption_tokens[0] is not None + # Call tool with resumption token callback via send_request + metadata = ClientMessageMetadata( + on_resumption_token_update=on_resumption_token_update, + ) + result = await session.send_request( + types.CallToolRequest(params=types.CallToolRequestParams(name="test_tool", arguments={})), + types.CallToolResult, + metadata=metadata, + ) + assert result is not None + + # Should have received priming event token BEFORE response data + # Priming event = 1 token (empty data, id only) + # Response = 1 token (actual JSON-RPC response) + # Total = 2 tokens minimum + assert len(captured_resumption_tokens) >= 2, ( + f"Server must send priming event before response. " + f"Expected >= 2 tokens (priming + response), got {len(captured_resumption_tokens)}" + ) + assert captured_resumption_tokens[0] is not None @pytest.mark.anyio async def test_server_close_sse_stream_via_context( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Server tool can call ctx.close_sse_stream() to close connection.""" - _, server_url = event_server + _, app = event_app - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() - # Call tool that closes stream mid-operation - # This should NOT raise NotImplementedError when fully implemented - result = await session.call_tool("tool_with_stream_close", {}) + # Call tool that closes stream mid-operation + result = await session.call_tool("tool_with_stream_close", {}) - # Client should still receive complete response (via auto-reconnect) - assert result is not None - assert len(result.content) > 0 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + # Client should still receive complete response (via auto-reconnect) + assert result is not None + assert len(result.content) > 0 + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" @pytest.mark.anyio async def test_streamable_http_client_auto_reconnects( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Client should auto-reconnect with Last-Event-ID when server closes after priming event.""" - _, server_url = event_server + _, app = event_app captured_notifications: list[str] = [] async def message_handler( @@ -1991,59 +1776,63 @@ async def message_handler( if isinstance(message, types.LoggingMessageNotification): # pragma: no branch captured_notifications.append(str(message.params.data)) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - await session.initialize() - - # Call tool that: - # 1. Sends notification - # 2. Closes SSE stream - # 3. Sends more notifications (stored in event_store) - # 4. Returns response - result = await session.call_tool("tool_with_stream_close", {}) - - # Client should have auto-reconnected and received ALL notifications - assert len(captured_notifications) >= 2, ( - "Client should auto-reconnect and receive notifications sent both before and after stream close" - ) - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + await session.initialize() + + # Call tool that: + # 1. Sends notification + # 2. Closes SSE stream + # 3. Sends more notifications (stored in event_store) + # 4. Returns response + result = await session.call_tool("tool_with_stream_close", {}) + + # Client should have auto-reconnected and received ALL notifications + assert len(captured_notifications) >= 2, ( + "Client should auto-reconnect and receive notifications sent both before and after stream close" + ) + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" @pytest.mark.anyio async def test_streamable_http_client_respects_retry_interval( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Client MUST respect retry field, waiting specified ms before reconnecting.""" - _, server_url = event_server + _, app = event_app - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() - start_time = time.monotonic() - result = await session.call_tool("tool_with_stream_close", {}) - elapsed = time.monotonic() - start_time + start_time = time.monotonic() + result = await session.call_tool("tool_with_stream_close", {}) + elapsed = time.monotonic() - start_time - # Verify result was received - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + # Verify result was received + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" - # The elapsed time should include at least the retry interval - # if reconnection occurred. This test may be flaky depending on - # implementation details, but demonstrates the expected behavior. - # Note: This assertion may need adjustment based on actual implementation - assert elapsed >= 0.4, f"Client should wait ~500ms before reconnecting, but elapsed time was {elapsed:.3f}s" + # The elapsed time should include at least the retry interval (500ms) before + # the client reconnected; the tool's own work only accounts for ~100ms. + assert elapsed >= 0.4, f"Client should wait ~500ms before reconnecting, but elapsed time was {elapsed:.3f}s" @pytest.mark.anyio async def test_streamable_http_sse_polling_full_cycle( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """End-to-end test: server closes stream, client reconnects, receives all events.""" - _, server_url = event_server + _, app = event_app all_notifications: list[str] = [] async def message_handler( @@ -2055,35 +1844,38 @@ async def message_handler( if isinstance(message, types.LoggingMessageNotification): # pragma: no branch all_notifications.append(str(message.params.data)) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - await session.initialize() - - # Call tool that simulates polling pattern: - # 1. Server sends priming event - # 2. Server sends "Before close" notification - # 3. Server closes stream (calls close_sse_stream) - # 4. (client reconnects automatically) - # 5. Server sends "After close" notification - # 6. Server sends final response - result = await session.call_tool("tool_with_stream_close", {}) - - # Verify all notifications received in order - assert "Before close" in all_notifications, "Should receive notification sent before stream close" - assert "After close" in all_notifications, ( - "Should receive notification sent after stream close (via auto-reconnect)" - ) - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + await session.initialize() + + # Call tool that simulates polling pattern: + # 1. Server sends priming event + # 2. Server sends "Before close" notification + # 3. Server closes stream (calls close_sse_stream) + # 4. (client reconnects automatically) + # 5. Server sends "After close" notification + # 6. Server sends final response + result = await session.call_tool("tool_with_stream_close", {}) + + # Verify all notifications received in order + assert "Before close" in all_notifications, "Should receive notification sent before stream close" + assert "After close" in all_notifications, ( + "Should receive notification sent after stream close (via auto-reconnect)" + ) + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" @pytest.mark.anyio async def test_streamable_http_events_replayed_after_disconnect( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Events sent while client is disconnected should be replayed on reconnect.""" - _, server_url = event_server + _, app = event_app notification_data: list[str] = [] async def message_handler( @@ -2095,37 +1887,43 @@ async def message_handler( if isinstance(message, types.LoggingMessageNotification): # pragma: no branch notification_data.append(str(message.params.data)) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - await session.initialize() + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + await session.initialize() - # Tool sends: notification1, close_stream, notification2, notification3, response - # Client should receive all notifications even though 2&3 were sent during disconnect - result = await session.call_tool("tool_with_multiple_notifications_and_close", {}) + # Tool sends: notification1, close_stream, notification2, notification3, response + # Client should receive all notifications even though 2&3 were sent during disconnect + result = await session.call_tool("tool_with_multiple_notifications_and_close", {}) - assert "notification1" in notification_data, "Should receive notification1 (sent before close)" - assert "notification2" in notification_data, "Should receive notification2 (sent after close, replayed)" - assert "notification3" in notification_data, "Should receive notification3 (sent after close, replayed)" + assert "notification1" in notification_data, "Should receive notification1 (sent before close)" + assert "notification2" in notification_data, "Should receive notification2 (sent after close, replayed)" + assert "notification3" in notification_data, "Should receive notification3 (sent after close, replayed)" - # Verify order: notification1 should come before notification2 and notification3 - idx1 = notification_data.index("notification1") - idx2 = notification_data.index("notification2") - idx3 = notification_data.index("notification3") - assert idx1 < idx2 < idx3, "Notifications should be received in order" + # Verify order: notification1 should come before notification2 and notification3 + idx1 = notification_data.index("notification1") + idx2 = notification_data.index("notification2") + idx3 = notification_data.index("notification3") + assert idx1 < idx2 < idx3, "Notifications should be received in order" - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "All notifications sent" + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "All notifications sent" @pytest.mark.anyio -async def test_streamable_http_multiple_reconnections( - event_server: tuple[SimpleEventStore, str], -): - """Verify multiple close_sse_stream() calls each trigger a client reconnect. +async def test_streamable_http_multiple_reconnections() -> None: + """Every close_sse_stream() severs a live connection and triggers its own client reconnect. - Server uses retry_interval=500ms, tool sleeps 600ms after each close to ensure - client has time to reconnect before the next checkpoint. + The tool closes its SSE stream three times; before each next cycle it waits until the + client has observed the previous cycle's two new resumption tokens (the checkpoint and the + new connection's priming event). The priming event is sent only after the server has + re-registered the resumed stream, so once the client holds its token the next close is + guaranteed to sever a live connection rather than silently no-op — making the exact token + count below a consequence of causality, not timing margins. This pins reconnect-per-close + accounting; reconnect *latency* is pinned by test_streamable_http_client_respects_retry_interval. With 3 checkpoints, we expect 8 resumption tokens: - 1 priming (initial POST connection) @@ -2133,45 +1931,72 @@ async def test_streamable_http_multiple_reconnections( - 3 priming (one per reconnect after each close) - 1 response """ - _, server_url = event_server resumption_tokens: list[str] = [] + # milestones[n] fires when the client has observed n tokens. After the initial priming + # (token 1), each completed cycle i contributes exactly two tokens — checkpoint_i and the + # reconnect's priming, in either order — so cycle i is complete at 3 + 2i tokens. + milestones = {3: anyio.Event(), 5: anyio.Event(), 7: anyio.Event()} async def on_resumption_token(token: str) -> None: resumption_tokens.append(token) + milestone = milestones.get(len(resumption_tokens)) + if milestone is not None: + milestone.set() - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Use send_request with metadata to track resumption tokens - metadata = ClientMessageMetadata(on_resumption_token_update=on_resumption_token) - result = await session.send_request( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams( - name="tool_with_multiple_stream_closes", - # retry_interval=500ms, so sleep 600ms to ensure reconnect completes - arguments={"checkpoints": 3, "sleep_time": 0.6}, - ), - ), - types.CallToolResult, - metadata=metadata, + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "multi_close_tool" + for i, milestone in enumerate(milestones.values()): + await ctx.session.send_log_message( + level="info", + data=f"checkpoint_{i}", + logger="multi_close_tool", + related_request_id=ctx.request_id, ) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + # Client and server share one event loop, so the tool can wait directly on the + # client-side callback observing the reconnect. + with anyio.fail_after(5): + await milestone.wait() + return CallToolResult(content=[TextContent(type="text", text="Completed 3 checkpoints")]) + + server = Server("multi_reconnect_server", on_call_tool=handle_call_tool) + + async with ( + # retry_interval is small to keep the test fast, but nonzero so each dying connection + # finishes unwinding before its replacement registers. + running_app(event_store=SimpleEventStore(), retry_interval=50, server=server) as app, + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + + # Use send_request with metadata to track resumption tokens + metadata = ClientMessageMetadata(on_resumption_token_update=on_resumption_token) + result = await session.send_request( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams(name="multi_close_tool", arguments={}), + ), + types.CallToolResult, + metadata=metadata, + ) - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Completed 3 checkpoints" in result.content[0].text + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert "Completed 3 checkpoints" in result.content[0].text - # 4 priming + 3 notifications + 1 response = 8 tokens. All tokens are - # captured before send_request returns, so this is safe to check here. - assert len(resumption_tokens) == 8, ( - f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " - f"got {len(resumption_tokens)}: {resumption_tokens}" - ) + # 4 priming + 3 notifications + 1 response = 8 tokens. All tokens are + # captured before send_request returns, so this is safe to check here. + assert len(resumption_tokens) == 8, ( + f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " + f"got {len(resumption_tokens)}: {resumption_tokens}" + ) @pytest.mark.anyio -async def test_standalone_get_stream_reconnection(event_server: tuple[SimpleEventStore, str]) -> None: +async def test_standalone_get_stream_reconnection(event_app: tuple[SimpleEventStore, Starlette]) -> None: """Test that standalone GET stream automatically reconnects after server closes it. Verifies: @@ -2180,10 +2005,10 @@ async def test_standalone_get_stream_reconnection(event_server: tuple[SimpleEven 3. Client reconnects with Last-Event-ID 4. Client receives notification 2 on new connection - Note: Requires event_server fixture (with event store) because close_standalone_sse_stream + Note: Requires the event store app because close_standalone_sse_stream callback is only provided when event_store is configured and protocol version >= 2025-11-25. """ - _, server_url = event_server + _, app = event_app received_notifications: list[str] = [] async def message_handler( @@ -2195,45 +2020,46 @@ async def message_handler( if isinstance(message, types.ResourceUpdatedNotification): # pragma: no branch received_notifications.append(str(message.params.uri)) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - await session.initialize() - - # Call tool that: - # 1. Sends notification_1 via GET stream - # 2. Closes standalone GET stream - # 3. Sends notification_2 (stored in event_store) - # 4. Returns response - result = await session.call_tool("tool_with_standalone_stream_close", {}) - - # Verify the tool completed - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Standalone stream close test done" - - # Verify both notifications were received - assert "http://notification_1" in received_notifications, ( - f"Should receive notification 1 (sent before GET stream close), got: {received_notifications}" - ) - assert "http://notification_2" in received_notifications, ( - f"Should receive notification 2 after reconnect, got: {received_notifications}" - ) + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + await session.initialize() + + # Call tool that: + # 1. Sends notification_1 via GET stream + # 2. Closes standalone GET stream + # 3. Sends notification_2 (stored in event_store) + # 4. Returns response + result = await session.call_tool("tool_with_standalone_stream_close", {}) + + # Verify the tool completed + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Standalone stream close test done" + + # Verify both notifications were received + assert "http://notification_1" in received_notifications, ( + f"Should receive notification 1 (sent before GET stream close), got: {received_notifications}" + ) + assert "http://notification_2" in received_notifications, ( + f"Should receive notification 2 after reconnect, got: {received_notifications}" + ) @pytest.mark.anyio -async def test_streamable_http_client_does_not_mutate_provided_client( - basic_server: None, basic_server_url: str -) -> None: - """Test that streamable_http_client does not mutate the provided httpx client's headers.""" +async def test_streamable_http_client_does_not_mutate_provided_client(basic_app: Starlette) -> None: + """streamable_http_client does not mutate the provided httpx client's headers.""" # Create a client with custom headers original_headers = { "X-Custom-Header": "custom-value", "Authorization": "Bearer test-token", } - async with httpx.AsyncClient(headers=original_headers, follow_redirects=True) as custom_client: + async with make_client(basic_app, headers=original_headers) as custom_client: # Use the client with streamable_http_client - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=custom_client) as ( + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=custom_client) as ( read_stream, write_stream, ): @@ -2254,18 +2080,16 @@ async def test_streamable_http_client_does_not_mutate_provided_client( @pytest.mark.anyio -async def test_streamable_http_client_mcp_headers_override_defaults( - context_aware_server: None, basic_server_url: str -) -> None: - """Test that MCP protocol headers override httpx.AsyncClient default headers.""" +async def test_streamable_http_client_mcp_headers_override_defaults(context_app: Starlette) -> None: + """MCP protocol headers override the httpx client's default headers in actual requests.""" # httpx.AsyncClient has default "accept: */*" header # We need to verify that our MCP accept header overrides it in actual requests - async with httpx.AsyncClient(follow_redirects=True) as client: + async with make_client(context_app) as client: # Verify client has default accept header assert client.headers.get("accept") == "*/*" - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=client) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() @@ -2285,18 +2109,16 @@ async def test_streamable_http_client_mcp_headers_override_defaults( @pytest.mark.anyio -async def test_streamable_http_client_preserves_custom_with_mcp_headers( - context_aware_server: None, basic_server_url: str -) -> None: - """Test that both custom headers and MCP protocol headers are sent in requests.""" +async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_app: Starlette) -> None: + """Custom client headers and MCP protocol headers are both sent in requests.""" custom_headers = { "X-Custom-Header": "custom-value", "X-Request-Id": "req-123", "Authorization": "Bearer test-token", } - async with httpx.AsyncClient(headers=custom_headers, follow_redirects=True) as client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with make_client(context_app, headers=custom_headers) as client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=client) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 810c72820..0038b1890 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -2,7 +2,6 @@ import socket import threading -import time from collections.abc import Generator from contextlib import contextmanager from typing import Any @@ -56,30 +55,3 @@ def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None finally: server.should_exit = True thread.join(timeout=_SERVER_SHUTDOWN_TIMEOUT_S) - - -def wait_for_server(port: int, timeout: float = 20.0) -> None: - """Wait for server to be ready to accept connections. - - Polls the server port until it accepts connections or timeout is reached. - This eliminates race conditions without arbitrary sleeps. - - Args: - port: The port number to check - timeout: Maximum time to wait in seconds (default 5.0) - - Raises: - TimeoutError: If server doesn't start within the timeout period - """ - start_time = time.time() - while time.time() - start_time < timeout: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.settimeout(0.1) - s.connect(("127.0.0.1", port)) - # Server is ready - return - except (ConnectionRefusedError, OSError): - # Server not ready yet, retry quickly - time.sleep(0.01) - raise TimeoutError(f"Server on port {port} did not start within {timeout} seconds") # pragma: no cover