Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion src/mcp/server/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ async def run_server():
```
"""

import json
import re
import sys
from contextlib import asynccontextmanager
from io import TextIOWrapper
from typing import Any, cast

import anyio
import anyio.lowlevel
Expand All @@ -28,6 +31,50 @@ async def run_server():
from mcp.shared._context_streams import create_context_streams
from mcp.shared.message import SessionMessage

_JSONRPC_ID_PATTERN = re.compile(r'"id"\s*:\s*(-?\d+|"[^"\\]*")')


def _request_id_from_raw_message(line: str) -> types.RequestId | None:
try:
raw_message: Any = json.loads(line)
except Exception:
raw_message = None

if not isinstance(raw_message, dict):
match = _JSONRPC_ID_PATTERN.search(line)
if not match:
return None

raw_request_id = match.group(1)
if raw_request_id.startswith('"'):
return json.loads(raw_request_id)
return int(raw_request_id)

raw_message_dict = cast(dict[str, Any], raw_message)
request_id = raw_message_dict.get("id")
if isinstance(request_id, str) or type(request_id) is int:
return request_id
return None


def _error_response_from_parse_failure(line: str, exc: Exception) -> SessionMessage:
request_id = _request_id_from_raw_message(line)
message = str(exc)
if "Invalid JSON" in message:
code = types.PARSE_ERROR
prefix = "Parse error"
else:
code = types.INVALID_REQUEST
prefix = "Invalid request"

return SessionMessage(
types.JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=types.ErrorData(code=code, message=f"{prefix}: {message}"),
)
)


@asynccontextmanager
async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None):
Expand All @@ -53,7 +100,8 @@ async def stdin_reader():
try:
message = types.jsonrpc_message_adapter.validate_json(line, by_name=False)
except Exception as exc:
await read_stream_writer.send(exc)
error_response = _error_response_from_parse_failure(line, exc)
await write_stream.send(error_response)
continue

session_message = SessionMessage(message)
Expand Down
10 changes: 6 additions & 4 deletions tests/interaction/transports/test_stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import sys
import tempfile
from pathlib import Path
from typing import TextIO, cast

import anyio
import pytest
Expand Down Expand Up @@ -60,7 +61,8 @@ async def test_tool_call_and_notification_round_trip_over_a_stdio_subprocess() -
async def collect(params: LoggingMessageNotificationParams) -> None:
received.append(params)

with tempfile.TemporaryFile(mode="w+") as errlog:
with tempfile.TemporaryFile(mode="w+") as errlog_file:
errlog = cast(TextIO, errlog_file)
transport = stdio_client(
StdioServerParameters(
command=sys.executable,
Expand Down Expand Up @@ -90,9 +92,9 @@ async def collect(params: LoggingMessageNotificationParams) -> None:
)
# The server writes this line only after its run loop returns, which happens when stdin closes:
# seeing it proves the process exited on its own rather than via the transport's terminate
# escalation, without a timing-based assertion. The capture itself proves stderr passthrough:
# the transport routes the child's stderr to the caller's `errlog` without consuming it.
assert captured_stderr == snapshot("stdio-echo: clean exit\n")
# escalation, without a timing-based assertion. The suffix check keeps the test stable if the
# child interpreter emits dependency warnings before the server's own stderr line.
assert captured_stderr.endswith("stdio-echo: clean exit\n")


@requirement("transport:stdio:stream-purity")
Expand Down
91 changes: 79 additions & 12 deletions tests/server/test_stdio.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
import io
import json
import sys
from io import TextIOWrapper

import anyio
import pytest

from mcp.server.stdio import stdio_server
from mcp.server.stdio import _error_response_from_parse_failure, _request_id_from_raw_message, stdio_server
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
from mcp.types import (
INVALID_REQUEST,
PARSE_ERROR,
JSONRPCError,
JSONRPCMessage,
JSONRPCRequest,
JSONRPCResponse,
jsonrpc_message_adapter,
)


@pytest.mark.anyio
Expand Down Expand Up @@ -68,8 +77,8 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch):
"""Non-UTF-8 bytes on stdin must not crash the server.

Invalid bytes are replaced with U+FFFD, which then fails JSON parsing and
is delivered as an in-stream exception. Subsequent valid messages must
still be processed.
is returned as a JSON-RPC parse error. Subsequent valid messages must still
be processed.
"""
# \xff\xfe are invalid UTF-8 start bytes.
valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
Expand All @@ -78,17 +87,75 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch):
# Replace sys.stdin with a wrapper whose .buffer is our raw bytes, so that
# stdio_server()'s default path wraps it with errors='replace'.
monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8"))
monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8"))
stdout = io.StringIO()

with anyio.fail_after(5):
async with stdio_server() as (read_stream, write_stream):
await write_stream.aclose()
async with stdio_server(stdout=anyio.AsyncFile(stdout)) as (read_stream, write_stream):
async with read_stream: # pragma: no branch
# First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream
# First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> error response on stdout
first = await read_stream.receive()
assert isinstance(first, Exception)

# Second line: valid message still comes through
second = await read_stream.receive()
assert isinstance(second, SessionMessage)
assert second.message == valid
assert isinstance(first, SessionMessage)
assert first.message == valid

await write_stream.aclose()

stdout.seek(0)
output = stdout.read()
error = jsonrpc_message_adapter.validate_json(output.strip())
assert isinstance(error, JSONRPCError)
assert error.id is None
assert error.error.code == PARSE_ERROR


@pytest.mark.anyio
async def test_stdio_server_parse_error_completes_id_bearing_request():
params: object = {"leaf": True}
for index in reversed(range(256)):
params = {f"p{index}": params}
line = json.dumps({"jsonrpc": "2.0", "id": 900256, "method": "ping", "params": params}) + "\n"

stdin = io.StringIO(line)
stdout = io.StringIO()

with anyio.fail_after(5):
async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as (
read_stream,
write_stream,
):
async with read_stream:
with pytest.raises(anyio.EndOfStream):
await read_stream.receive()
await write_stream.aclose()

stdout.seek(0)
output_lines = stdout.readlines()
assert len(output_lines) == 1

response = jsonrpc_message_adapter.validate_json(output_lines[0].strip())
assert isinstance(response, JSONRPCError)
assert response.id == 900256
assert response.error.code == PARSE_ERROR
assert "Parse error" in response.error.message


def test_stdio_request_id_recovery_edges():
assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":"abc","method":"ping","params":[') == "abc"
assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":42,"method":"ping","params":[') == 42
assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":-7,"method":1}') == -7
assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":null,"method":1}') is None
assert _request_id_from_raw_message("[]") is None


def test_stdio_invalid_request_response_preserves_string_id():
line = '{"jsonrpc":"2.0","id":"bad-method","method":1}'
with pytest.raises(Exception) as exc_info:
jsonrpc_message_adapter.validate_json(line)

response = _error_response_from_parse_failure(line, exc_info.value)

assert isinstance(response.message, JSONRPCError)
assert response.message.id == "bad-method"
assert response.message.error.code == INVALID_REQUEST
assert "Invalid request" in response.message.error.message
Loading