From b63f5f884732ddbf20e499394c5fa5d762784f8e Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Wed, 3 Jun 2026 05:06:59 +0800 Subject: [PATCH] fix(server): return stdio parse errors --- src/mcp/server/stdio.py | 50 +++++++++++- tests/interaction/transports/test_stdio.py | 10 ++- tests/server/test_stdio.py | 91 +++++++++++++++++++--- 3 files changed, 134 insertions(+), 17 deletions(-) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 5c1459dff6..add8f2e557 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -17,9 +17,12 @@ async def run_server(): ``` """ +import json +import re import sys from contextlib import asynccontextmanager from io import TextIOWrapper +from typing import Any, cast import anyio import anyio.lowlevel @@ -28,6 +31,50 @@ async def run_server(): from mcp.shared._context_streams import create_context_streams from mcp.shared.message import SessionMessage +_JSONRPC_ID_PATTERN = re.compile(r'"id"\s*:\s*(-?\d+|"[^"\\]*")') + + +def _request_id_from_raw_message(line: str) -> types.RequestId | None: + try: + raw_message: Any = json.loads(line) + except Exception: + raw_message = None + + if not isinstance(raw_message, dict): + match = _JSONRPC_ID_PATTERN.search(line) + if not match: + return None + + raw_request_id = match.group(1) + if raw_request_id.startswith('"'): + return json.loads(raw_request_id) + return int(raw_request_id) + + raw_message_dict = cast(dict[str, Any], raw_message) + request_id = raw_message_dict.get("id") + if isinstance(request_id, str) or type(request_id) is int: + return request_id + return None + + +def _error_response_from_parse_failure(line: str, exc: Exception) -> SessionMessage: + request_id = _request_id_from_raw_message(line) + message = str(exc) + if "Invalid JSON" in message: + code = types.PARSE_ERROR + prefix = "Parse error" + else: + code = types.INVALID_REQUEST + prefix = "Invalid request" + + return SessionMessage( + types.JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=types.ErrorData(code=code, message=f"{prefix}: {message}"), + ) + ) + @asynccontextmanager async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None): @@ -53,7 +100,8 @@ async def stdin_reader(): try: message = types.jsonrpc_message_adapter.validate_json(line, by_name=False) except Exception as exc: - await read_stream_writer.send(exc) + error_response = _error_response_from_parse_failure(line, exc) + await write_stream.send(error_response) continue session_message = SessionMessage(message) diff --git a/tests/interaction/transports/test_stdio.py b/tests/interaction/transports/test_stdio.py index 27cc65de42..97dac41a77 100644 --- a/tests/interaction/transports/test_stdio.py +++ b/tests/interaction/transports/test_stdio.py @@ -21,6 +21,7 @@ import sys import tempfile from pathlib import Path +from typing import TextIO, cast import anyio import pytest @@ -60,7 +61,8 @@ async def test_tool_call_and_notification_round_trip_over_a_stdio_subprocess() - async def collect(params: LoggingMessageNotificationParams) -> None: received.append(params) - with tempfile.TemporaryFile(mode="w+") as errlog: + with tempfile.TemporaryFile(mode="w+") as errlog_file: + errlog = cast(TextIO, errlog_file) transport = stdio_client( StdioServerParameters( command=sys.executable, @@ -90,9 +92,9 @@ async def collect(params: LoggingMessageNotificationParams) -> None: ) # The server writes this line only after its run loop returns, which happens when stdin closes: # seeing it proves the process exited on its own rather than via the transport's terminate - # escalation, without a timing-based assertion. The capture itself proves stderr passthrough: - # the transport routes the child's stderr to the caller's `errlog` without consuming it. - assert captured_stderr == snapshot("stdio-echo: clean exit\n") + # escalation, without a timing-based assertion. The suffix check keeps the test stable if the + # child interpreter emits dependency warnings before the server's own stderr line. + assert captured_stderr.endswith("stdio-echo: clean exit\n") @requirement("transport:stdio:stream-purity") diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 677a993567..57cb50d646 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -1,13 +1,22 @@ import io +import json import sys from io import TextIOWrapper import anyio import pytest -from mcp.server.stdio import stdio_server +from mcp.server.stdio import _error_response_from_parse_failure, _request_id_from_raw_message, stdio_server from mcp.shared.message import SessionMessage -from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter +from mcp.types import ( + INVALID_REQUEST, + PARSE_ERROR, + JSONRPCError, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + jsonrpc_message_adapter, +) @pytest.mark.anyio @@ -68,8 +77,8 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): """Non-UTF-8 bytes on stdin must not crash the server. Invalid bytes are replaced with U+FFFD, which then fails JSON parsing and - is delivered as an in-stream exception. Subsequent valid messages must - still be processed. + is returned as a JSON-RPC parse error. Subsequent valid messages must still + be processed. """ # \xff\xfe are invalid UTF-8 start bytes. valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") @@ -78,17 +87,75 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): # Replace sys.stdin with a wrapper whose .buffer is our raw bytes, so that # stdio_server()'s default path wraps it with errors='replace'. monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8")) - monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8")) + stdout = io.StringIO() with anyio.fail_after(5): - async with stdio_server() as (read_stream, write_stream): - await write_stream.aclose() + async with stdio_server(stdout=anyio.AsyncFile(stdout)) as (read_stream, write_stream): async with read_stream: # pragma: no branch - # First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream + # First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> error response on stdout first = await read_stream.receive() - assert isinstance(first, Exception) # Second line: valid message still comes through - second = await read_stream.receive() - assert isinstance(second, SessionMessage) - assert second.message == valid + assert isinstance(first, SessionMessage) + assert first.message == valid + + await write_stream.aclose() + + stdout.seek(0) + output = stdout.read() + error = jsonrpc_message_adapter.validate_json(output.strip()) + assert isinstance(error, JSONRPCError) + assert error.id is None + assert error.error.code == PARSE_ERROR + + +@pytest.mark.anyio +async def test_stdio_server_parse_error_completes_id_bearing_request(): + params: object = {"leaf": True} + for index in reversed(range(256)): + params = {f"p{index}": params} + line = json.dumps({"jsonrpc": "2.0", "id": 900256, "method": "ping", "params": params}) + "\n" + + stdin = io.StringIO(line) + stdout = io.StringIO() + + with anyio.fail_after(5): + async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as ( + read_stream, + write_stream, + ): + async with read_stream: + with pytest.raises(anyio.EndOfStream): + await read_stream.receive() + await write_stream.aclose() + + stdout.seek(0) + output_lines = stdout.readlines() + assert len(output_lines) == 1 + + response = jsonrpc_message_adapter.validate_json(output_lines[0].strip()) + assert isinstance(response, JSONRPCError) + assert response.id == 900256 + assert response.error.code == PARSE_ERROR + assert "Parse error" in response.error.message + + +def test_stdio_request_id_recovery_edges(): + assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":"abc","method":"ping","params":[') == "abc" + assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":42,"method":"ping","params":[') == 42 + assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":-7,"method":1}') == -7 + assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":null,"method":1}') is None + assert _request_id_from_raw_message("[]") is None + + +def test_stdio_invalid_request_response_preserves_string_id(): + line = '{"jsonrpc":"2.0","id":"bad-method","method":1}' + with pytest.raises(Exception) as exc_info: + jsonrpc_message_adapter.validate_json(line) + + response = _error_response_from_parse_failure(line, exc_info.value) + + assert isinstance(response.message, JSONRPCError) + assert response.message.id == "bad-method" + assert response.message.error.code == INVALID_REQUEST + assert "Invalid request" in response.message.error.message