diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f2f4407ce..98948ff99 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -669,7 +669,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: await response(request.scope, request.receive, send) return - if not await self._validate_request_headers(request, send): # pragma: no cover + if not await self._validate_request_headers(request, send): return # Handle resumability: check for Last-Event-ID header diff --git a/tests/interaction/transports/__init__.py b/tests/interaction/transports/__init__.py index e69de29bb..b5bbb633c 100644 --- a/tests/interaction/transports/__init__.py +++ b/tests/interaction/transports/__init__.py @@ -0,0 +1,9 @@ +"""Transport-specific interaction tests, and the in-process streaming bridge they are built on. + +`StreamingASGITransport` is re-exported here as the sanctioned import point for test code +outside this suite (the bridge module itself is suite-private). +""" + +from tests.interaction.transports._bridge import StreamingASGITransport + +__all__ = ["StreamingASGITransport"] diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index e95dc51b3..e77bd5e2c 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,15 +1,12 @@ """Tests for SSE server request validation.""" import logging -import multiprocessing import re -import socket import anyio import httpx import pytest import sse_starlette.sse -import uvicorn from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -23,12 +20,16 @@ from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._stream_protocols import WriteStream from mcp.shared.message import SessionMessage -from mcp.types import JSONRPCRequest, JSONRPCResponse, Tool -from tests.test_helpers import wait_for_server +from mcp.types import JSONRPCRequest, JSONRPCResponse +from tests.interaction.transports import StreamingASGITransport logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" +# The in-process app is mounted at this origin purely so URLs are well-formed and the default +# Host header is a localhost form; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + @pytest.fixture(autouse=True) def reset_sse_starlette_exit_event() -> None: @@ -39,275 +40,161 @@ def reset_sse_starlette_exit_event() -> None: app_status.should_exit_event = None -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" - - -class SecurityTestServer(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - - async def on_list_tools(self) -> list[Tool]: - return [] - - -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the SSE server with specified security settings.""" - app = SecurityTestServer() +def sse_security_client(security_settings: TransportSecuritySettings | None = None) -> httpx.AsyncClient: + """An httpx client whose requests are served in process by an SSE app with the given settings.""" + server = Server(SERVER_NAME) sse_transport = SseServerTransport("/messages/", security_settings) - async def handle_sse(request: Request): + async def handle_sse(request: Request) -> Response: try: - async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams: - if streams: - await app.run(streams[0], streams[1], app.create_initialization_options()) + async with sse_transport.connect_sse(request.scope, request.receive, request._send) as (read, write): + await server.run(read, write, server.create_initialization_options()) except ValueError as e: - # Validation error was already handled inside connect_sse + # Validation error was already handled inside connect_sse, which sent the rejection + # response itself; its non-empty body checkpoints, so the test reads the rejection + # status before the trailing Response() below sends a second response start. logger.debug(f"SSE connection failed validation: {e}") return Response() - routes = [ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse_transport.handle_post_message), - ] - - starlette_app = Starlette(routes=routes) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") - - -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse_transport.handle_post_message), + ] + ) + # The SSE GET runs until it observes a disconnect, so the bridge must let the application + # drain on close rather than cancelling it. + transport = StreamingASGITransport(app, cancel_on_close=False) + return httpx.AsyncClient(transport=transport, base_url=BASE_URL) @pytest.mark.anyio -async def test_sse_security_default_settings(server_port: int): - """Test SSE with default security settings (protection disabled).""" - process = start_server_process(server_port) +async def test_sse_security_default_settings() -> None: + """With default security settings (protection disabled), any Host and Origin connect.""" + headers = {"Host": "evil.com", "Origin": "http://evil.com"} - try: - headers = {"Host": "evil.com", "Origin": "http://evil.com"} - - async with httpx.AsyncClient(timeout=5.0) as client: - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - assert response.status_code == 200 - finally: - process.terminate() - process.join() + async with sse_security_client() as client: + async with client.stream("GET", "/sse", headers=headers) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def test_sse_security_invalid_host_header(server_port: int): - """Test SSE with invalid Host header.""" - # Enable security by providing settings with an empty allowed_hosts list +async def test_sse_security_invalid_host_header() -> None: + """A Host header outside allowed_hosts is rejected with 421.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]) - process = start_server_process(server_port, security_settings) - try: - # Test with invalid host header - headers = {"Host": "evil.com"} - - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + async with sse_security_client(security_settings) as client: + response = await client.get("/sse", headers={"Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_sse_security_invalid_origin_header(server_port: int): - """Test SSE with invalid Origin header.""" - # Configure security to allow the host but restrict origins +async def test_sse_security_invalid_origin_header() -> None: + """An Origin header outside allowed_origins is rejected with 403.""" security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"] ) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = {"Origin": "http://evil.com"} - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 403 - assert response.text == "Invalid Origin header" - - finally: - process.terminate() - process.join() + async with sse_security_client(security_settings) as client: + response = await client.get("/sse", headers={"Origin": "http://evil.com"}) + assert response.status_code == 403 + assert response.text == "Invalid Origin header" @pytest.mark.anyio -async def test_sse_security_post_invalid_content_type(server_port: int): - """Test POST endpoint with invalid Content-Type header.""" - # Configure security to allow the host +async def test_sse_security_post_invalid_content_type() -> None: + """A POST whose Content-Type is not application/json (or is missing) is rejected with 400.""" security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: - async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type - fake_session_id = "12345678123456781234567812345678" - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", - headers={"Content-Type": "text/plain"}, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" + fake_session_id = "12345678123456781234567812345678" - # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test" - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" + async with sse_security_client(security_settings) as client: + response = await client.post( + f"/messages/?session_id={fake_session_id}", + headers={"Content-Type": "text/plain"}, + content="test", + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" - finally: - process.terminate() - process.join() + response = await client.post(f"/messages/?session_id={fake_session_id}", content="test") + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" @pytest.mark.anyio -async def test_sse_security_disabled(server_port: int): - """Test SSE with security disabled.""" +async def test_sse_security_disabled() -> None: + """With protection explicitly disabled, a disallowed Host still connects.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = {"Host": "evil.com"} - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully even with invalid host - assert response.status_code == 200 - - finally: - process.terminate() - process.join() + async with sse_security_client(settings) as client: + async with client.stream("GET", "/sse", headers={"Host": "evil.com"}) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def test_sse_security_custom_allowed_hosts(server_port: int): - """Test SSE with custom allowed hosts.""" +async def test_sse_security_custom_allowed_hosts() -> None: + """A custom entry in allowed_hosts connects; hosts outside the list are still rejected.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: - # Test with custom allowed host - headers = {"Host": "custom.host"} - - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with custom host - assert response.status_code == 200 - # Test with non-allowed host - headers = {"Host": "evil.com"} + async with sse_security_client(settings) as client: + async with client.stream("GET", "/sse", headers={"Host": "custom.host"}) as response: + assert response.status_code == 200 - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + response = await client.get("/sse", headers={"Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_sse_security_wildcard_ports(server_port: int): - """Test SSE with wildcard port patterns.""" +async def test_sse_security_wildcard_ports() -> None: + """A `host:*` pattern accepts that host with any port, for Host and Origin alike.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost:*", "127.0.0.1:*"], allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], ) - process = start_server_process(server_port, settings) - try: - # Test with various port numbers + async with sse_security_client(settings) as client: for test_port in [8080, 3000, 9999]: - headers = {"Host": f"localhost:{test_port}"} - - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port - assert response.status_code == 200 - - headers = {"Origin": f"http://localhost:{test_port}"} - - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port - assert response.status_code == 200 + async with client.stream("GET", "/sse", headers={"Host": f"localhost:{test_port}"}) as response: + assert response.status_code == 200 - finally: - process.terminate() - process.join() + async with client.stream("GET", "/sse", headers={"Origin": f"http://localhost:{test_port}"}) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def test_sse_security_post_valid_content_type(server_port: int): - """Test POST endpoint with valid Content-Type headers.""" - # Configure security to allow the host +async def test_sse_security_post_valid_content_type() -> None: + """Every application/json Content-Type variant passes validation (reaching the session lookup).""" security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: - async with httpx.AsyncClient() as client: - # Test with various valid content types - valid_content_types = [ - "application/json", - "application/json; charset=utf-8", - "application/json;charset=utf-8", - "APPLICATION/JSON", # Case insensitive - ] - - for content_type in valid_content_types: - # Use a valid UUID format (even though session won't exist) - fake_session_id = "12345678123456781234567812345678" - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", - headers={"Content-Type": content_type}, - json={"test": "data"}, - ) - # Will get 404 because session doesn't exist, but that's OK - # We're testing that it passes the content-type check - assert response.status_code == 404 - assert response.text == "Could not find session" - - finally: - process.terminate() - process.join() + valid_content_types = [ + "application/json", + "application/json; charset=utf-8", + "application/json;charset=utf-8", + "APPLICATION/JSON", # Case insensitive + ] + # A well-formed session ID that no live session owns. + fake_session_id = "12345678123456781234567812345678" + + async with sse_security_client(security_settings) as client: + for content_type in valid_content_types: + response = await client.post( + f"/messages/?session_id={fake_session_id}", + headers={"Content-Type": content_type}, + json={"test": "data"}, + ) + # 404 proves the request passed the content-type check and reached the session lookup. + assert response.status_code == 404 + assert response.text == "Could not find session" def _authenticated_user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index ba7554796..f02e520ee 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -340,12 +340,33 @@ async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestP await client.list_tools() +class _IdleTimeoutObserver(logging.Handler): + """Resolves `reaped` when the manager logs that a session's idle timeout fired.""" + + def __init__(self) -> None: + super().__init__() + self.reaped = anyio.Event() + + def emit(self, record: logging.LogRecord) -> None: + if "idle timeout" in record.getMessage(): + self.reaped.set() + + @pytest.mark.anyio -async def test_idle_session_is_reaped(): +async def test_idle_session_is_reaped(caplog: pytest.LogCaptureFixture, request: pytest.FixtureRequest): """After idle timeout fires, the session returns 404.""" app = Server("test-idle-reap") manager = StreamableHTTPSessionManager(app=app, session_idle_timeout=0.05) + # The reap is observed through the manager's own "idle timeout" log record: the manager pops + # the session synchronously after emitting it, before its next await, so a waiter woken by + # the record always finds the session gone. caplog.set_level enables INFO so it is created. + observer = _IdleTimeoutObserver() + manager_logger = logging.getLogger(streamable_http_manager.__name__) + manager_logger.addHandler(observer) + request.addfinalizer(lambda: manager_logger.removeHandler(observer)) + caplog.set_level(logging.INFO, logger=streamable_http_manager.__name__) + async with manager.run(): sent_messages: list[Message] = [] @@ -376,8 +397,10 @@ async def mock_receive(): # pragma: no cover assert session_id is not None, "Session ID not found in response headers" - # Wait for the 50ms idle timeout to fire and cleanup to complete - await anyio.sleep(0.1) + # Wait for the 50ms idle timeout to fire and the session to be unregistered. Re-requesting + # the session to poll for the 404 would push its idle deadline forward and keep it alive. + with anyio.fail_after(5): + await observer.reaped.wait() # Verify via public API: old session ID now returns 404 response_messages: list[Message] = [] diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index 897555353..f13bb4a9b 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -1,291 +1,130 @@ """Tests for StreamableHTTP server DNS rebinding protection.""" -import multiprocessing -import socket -from collections.abc import AsyncGenerator +from collections.abc import AsyncIterator from contextlib import asynccontextmanager import httpx import pytest -import uvicorn from starlette.applications import Starlette from starlette.routing import Mount -from starlette.types import Receive, Scope, Send from mcp.server import Server from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Tool -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport SERVER_NAME = "test_streamable_http_security_server" +# The in-process app is mounted at this origin purely so URLs are well-formed and the default +# Host header is a localhost form; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +@asynccontextmanager +async def streamable_http_security_client( + security_settings: TransportSecuritySettings | None = None, +) -> AsyncIterator[httpx.AsyncClient]: + """Yield an httpx client served in process by a StreamableHTTP app with the given settings.""" + session_manager = StreamableHTTPSessionManager(app=Server(SERVER_NAME), security_settings=security_settings) + app = Starlette(routes=[Mount("/", app=session_manager.handle_request)]) -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" + async with session_manager.run(): + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as client: + yield client -class SecurityTestServer(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) +def _base_headers() -> dict[str, str]: + """Headers every well-formed request carries, so each test varies only the header under test.""" + return {"Accept": "application/json, text/event-stream", "Content-Type": "application/json"} - async def on_list_tools(self) -> list[Tool]: - return [] - -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the StreamableHTTP server with specified security settings.""" - app = SecurityTestServer() - - # Create session manager with security settings - session_manager = StreamableHTTPSessionManager( - app=app, - json_response=False, - stateless=False, - security_settings=security_settings, - ) - - # Create the ASGI handler - async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: - await session_manager.handle_request(scope, receive, send) - - # Create Starlette app with lifespan - @asynccontextmanager - async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: - async with session_manager.run(): - yield - - routes = [ - Mount("/", app=handle_streamable_http), - ] - - starlette_app = Starlette(routes=routes, lifespan=lifespan) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") - - -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process +def _initialize_body() -> dict[str, object]: + """A minimal initialize POST body; these tests assert header validation, not the handshake.""" + return {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} @pytest.mark.anyio -async def test_streamable_http_security_default_settings(server_port: int): - """Test StreamableHTTP with default security settings (protection enabled).""" - process = start_server_process(server_port) - - try: - # Test with valid localhost headers - async with httpx.AsyncClient(timeout=5.0) as client: - # POST request to initialize session - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - ) - assert response.status_code == 200 - assert "mcp-session-id" in response.headers - - finally: - process.terminate() - process.join() +async def test_streamable_http_security_default_settings() -> None: + """With default security settings, a request with localhost headers is served.""" + async with streamable_http_security_client() as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers()) + assert response.status_code == 200 + assert "mcp-session-id" in response.headers @pytest.mark.anyio -async def test_streamable_http_security_invalid_host_header(server_port: int): - """Test StreamableHTTP with invalid Host header.""" +async def test_streamable_http_security_invalid_host_header() -> None: + """A Host header outside allowed_hosts is rejected with 421.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid host header - headers = { - "Host": "evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(security_settings) as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_streamable_http_security_invalid_origin_header(server_port: int): - """Test StreamableHTTP with invalid Origin header.""" +async def test_streamable_http_security_invalid_origin_header() -> None: + """An Origin header outside allowed_origins is rejected with 403.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"]) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = { - "Origin": "http://evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - assert response.status_code == 403 - assert response.text == "Invalid Origin header" - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(security_settings) as client: + response = await client.post( + "/", json=_initialize_body(), headers=_base_headers() | {"Origin": "http://evil.com"} + ) + assert response.status_code == 403 + assert response.text == "Invalid Origin header" @pytest.mark.anyio -async def test_streamable_http_security_invalid_content_type(server_port: int): - """Test StreamableHTTP POST with invalid Content-Type header.""" - process = start_server_process(server_port) - - try: - async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type - response = await client.post( - f"http://127.0.0.1:{server_port}/", - headers={ - "Content-Type": "text/plain", - "Accept": "application/json, text/event-stream", - }, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/", - headers={"Accept": "application/json, text/event-stream"}, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - finally: - process.terminate() - process.join() +async def test_streamable_http_security_invalid_content_type() -> None: + """A POST whose Content-Type is not application/json (or is missing) is rejected with 400.""" + async with streamable_http_security_client() as client: + response = await client.post("/", headers=_base_headers() | {"Content-Type": "text/plain"}, content="test") + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" + + response = await client.post("/", headers={"Accept": "application/json, text/event-stream"}, content="test") + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" @pytest.mark.anyio -async def test_streamable_http_security_disabled(server_port: int): - """Test StreamableHTTP with security disabled.""" +async def test_streamable_http_security_disabled() -> None: + """With protection explicitly disabled, a disallowed Host is still served.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = { - "Host": "evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - # Should connect successfully even with invalid host - assert response.status_code == 200 - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(settings) as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "evil.com"}) + assert response.status_code == 200 @pytest.mark.anyio -async def test_streamable_http_security_custom_allowed_hosts(server_port: int): - """Test StreamableHTTP with custom allowed hosts.""" +async def test_streamable_http_security_custom_allowed_hosts() -> None: + """A custom entry in allowed_hosts is served.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: - # Test with custom allowed host - headers = { - "Host": "custom.host", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - # Should connect successfully with custom host - assert response.status_code == 200 - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(settings) as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "custom.host"}) + assert response.status_code == 200 @pytest.mark.anyio -async def test_streamable_http_security_get_request(server_port: int): - """Test StreamableHTTP GET request with security.""" +async def test_streamable_http_security_get_request() -> None: + """GET requests pass the same Host validation before any session handling.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1"]) - process = start_server_process(server_port, security_settings) - - try: - # Test GET request with invalid host header - headers = { - "Host": "evil.com", - "Accept": "text/event-stream", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - # Test GET request with valid host header - headers = { - "Host": "127.0.0.1", - "Accept": "text/event-stream", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - # GET requests need a session ID in StreamableHTTP - # So it will fail with "Missing session ID" not security error - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) - # This should pass security but fail on session validation - assert response.status_code == 400 - body = response.json() - assert "Missing session ID" in body["error"]["message"] - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(security_settings) as client: + response = await client.get("/", headers={"Accept": "text/event-stream", "Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" + + response = await client.get("/", headers={"Accept": "text/event-stream", "Host": "127.0.0.1"}) + # An allowed host passes security and fails on session validation instead. + assert response.status_code == 400 + body = response.json() + assert "Missing session ID" in body["error"]["message"]