From cfe3ae4921671ae08c9a6d159da1c4d8b1df6b8a Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 11:56:49 +0000 Subject: [PATCH 01/52] feat: add Dispatcher Protocol and DirectDispatcher Introduces the Dispatcher abstraction that decouples MCP request/response handling from JSON-RPC framing. A Dispatcher exposes call/notify for outbound messages and run(on_call, on_notify) for inbound dispatch, with no knowledge of MCP types or wire encoding. - shared/dispatcher.py: Dispatcher, DispatchContext, RequestSender Protocols; CallOptions, OnCall/OnNotify, ProgressFnT, DispatchMiddleware - shared/transport_context.py: TransportContext base dataclass - shared/direct_dispatcher.py: in-memory Dispatcher impl that wires two peers with no transport; serves as a fast test substrate and second-impl proof - shared/exceptions.py: NoBackChannelError(MCPError) for transports without a server-to-client request channel - types: REQUEST_CANCELLED SDK error code The JSON-RPC implementation and ServerRunner that consume this Protocol land in follow-up PRs. --- src/mcp/shared/direct_dispatcher.py | 173 +++++++++++++++++++ src/mcp/shared/dispatcher.py | 167 ++++++++++++++++++ src/mcp/shared/exceptions.py | 21 ++- src/mcp/shared/transport_context.py | 30 ++++ src/mcp/types/__init__.py | 2 + src/mcp/types/jsonrpc.py | 1 + tests/shared/test_dispatcher.py | 253 ++++++++++++++++++++++++++++ 7 files changed, 646 insertions(+), 1 deletion(-) create mode 100644 src/mcp/shared/direct_dispatcher.py create mode 100644 src/mcp/shared/dispatcher.py create mode 100644 src/mcp/shared/transport_context.py create mode 100644 tests/shared/test_dispatcher.py diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py new file mode 100644 index 0000000000..4650619428 --- /dev/null +++ b/src/mcp/shared/direct_dispatcher.py @@ -0,0 +1,173 @@ +"""In-memory `Dispatcher` that wires two peers together with no transport. + +`DirectDispatcher` is the simplest possible `Dispatcher` implementation: a call +on one side directly invokes the other side's `on_call`. There is no +serialization, no JSON-RPC framing, and no streams. It exists to: + +* prove the `Dispatcher` Protocol is implementable without JSON-RPC +* provide a fast substrate for testing the layers above the dispatcher + (`ServerRunner`, `Context`, `Connection`) without wire-level moving parts +* embed a server in-process when the JSON-RPC overhead is unnecessary + +Unlike `JSONRPCDispatcher`, exceptions raised in a handler propagate directly +to the caller — there is no exception-to-`ErrorData` boundary here. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass, field +from typing import Any + +import anyio + +from mcp.shared.dispatcher import CallOptions, OnCall, OnNotify, ProgressFnT +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.transport_context import TransportContext +from mcp.types import INTERNAL_ERROR, REQUEST_TIMEOUT + +__all__ = ["DirectDispatcher", "create_direct_dispatcher_pair"] + +DIRECT_TRANSPORT_KIND = "direct" + + +_Call = Callable[[str, Mapping[str, Any] | None, CallOptions | None], Awaitable[dict[str, Any]]] +_Notify = Callable[[str, Mapping[str, Any] | None], Awaitable[None]] + + +@dataclass +class _DirectDispatchContext: + """`DispatchContext` for an inbound call on a `DirectDispatcher`. + + The back-channel callables target the *originating* side, so a handler's + `send_request` reaches the peer that made the inbound call. + """ + + transport: TransportContext + _back_call: _Call + _back_notify: _Notify + _on_progress: ProgressFnT | None = None + cancel_requested: anyio.Event = field(default_factory=anyio.Event) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._back_notify(method, params) + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if not self.transport.can_send_request: + raise NoBackChannelError(method) + return await self._back_call(method, params, opts) + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + if self._on_progress is not None: + await self._on_progress(progress, total, message) + + +class DirectDispatcher: + """A `Dispatcher` that calls a peer's handlers directly, in-process. + + Two instances are wired together with `create_direct_dispatcher_pair`; each + holds a reference to the other. `call` on one awaits the peer's `on_call`. + `run` parks until `close` is called. + """ + + def __init__(self, transport_ctx: TransportContext): + self._transport_ctx = transport_ctx + self._peer: DirectDispatcher | None = None + self._on_call: OnCall | None = None + self._on_notify: OnNotify | None = None + self._ready = anyio.Event() + self._closed = anyio.Event() + + def connect_to(self, peer: DirectDispatcher) -> None: + self._peer = peer + + async def call( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if self._peer is None: + raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") + return await self._peer._dispatch_call(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + if self._peer is None: + raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") + await self._peer._dispatch_notify(method, params) + + async def run(self, on_call: OnCall, on_notify: OnNotify) -> None: + self._on_call = on_call + self._on_notify = on_notify + self._ready.set() + await self._closed.wait() + + def close(self) -> None: + self._closed.set() + + def _make_context(self, on_progress: ProgressFnT | None = None) -> _DirectDispatchContext: + assert self._peer is not None + peer = self._peer + return _DirectDispatchContext( + transport=self._transport_ctx, + _back_call=lambda m, p, o: peer._dispatch_call(m, p, o), + _back_notify=lambda m, p: peer._dispatch_notify(m, p), + _on_progress=on_progress, + ) + + async def _dispatch_call( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None, + ) -> dict[str, Any]: + await self._ready.wait() + assert self._on_call is not None + opts = opts or {} + dctx = self._make_context(on_progress=opts.get("on_progress")) + try: + with anyio.fail_after(opts.get("timeout")): + try: + return await self._on_call(dctx, method, params) + except MCPError: + raise + except Exception as e: + raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e + except TimeoutError: + raise MCPError( + code=REQUEST_TIMEOUT, + message=f"Timed out after {opts.get('timeout')}s waiting for {method!r}", + ) from None + + async def _dispatch_notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._ready.wait() + assert self._on_notify is not None + dctx = self._make_context() + await self._on_notify(dctx, method, params) + + +def create_direct_dispatcher_pair( + *, + can_send_request: bool = True, +) -> tuple[DirectDispatcher, DirectDispatcher]: + """Create two `DirectDispatcher` instances wired to each other. + + Args: + can_send_request: Sets `TransportContext.can_send_request` on both + sides. Pass ``False`` to simulate a transport with no back-channel. + + Returns: + A ``(left, right)`` pair. Conventionally ``left`` is the client side + and ``right`` is the server side, but the wiring is symmetric. + """ + ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request) + left = DirectDispatcher(ctx) + right = DirectDispatcher(ctx) + left.connect_to(right) + right.connect_to(left) + return left, right diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py new file mode 100644 index 0000000000..09e5e87bb6 --- /dev/null +++ b/src/mcp/shared/dispatcher.py @@ -0,0 +1,167 @@ +"""Dispatcher Protocol — the call/return boundary between transports and handlers. + +A Dispatcher turns a duplex message channel into two things: + +* an outbound API: ``call(method, params)`` and ``notify(method, params)`` +* an inbound pump: ``run(on_call, on_notify)`` that drives the receive loop and + invokes the supplied handlers for each incoming request/notification + +It is deliberately *not* MCP-aware. Method names are strings, params and +results are ``dict[str, Any]``. The MCP type layer (request/result models, +capability negotiation, ``Context``) sits above this; the wire encoding +(JSON-RPC, gRPC, in-process direct calls) sits below it. + +See ``JSONRPCDispatcher`` for the production implementation and +``DirectDispatcher`` for an in-memory implementation used in tests and for +embedding a server in-process. +""" + +from collections.abc import Awaitable, Callable, Mapping +from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable + +import anyio + +from mcp.shared.transport_context import TransportContext + +__all__ = [ + "CallOptions", + "DispatchContext", + "DispatchMiddleware", + "Dispatcher", + "OnCall", + "OnNotify", + "ProgressFnT", + "RequestSender", +] + +TransportT_co = TypeVar("TransportT_co", bound=TransportContext, covariant=True) + + +class ProgressFnT(Protocol): + """Callback invoked when a progress notification arrives for a pending call.""" + + async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ... + + +class CallOptions(TypedDict, total=False): + """Per-call options for `RequestSender.send_request` / `Dispatcher.call`. + + All keys are optional. Dispatchers ignore keys they do not understand. + """ + + timeout: float + """Seconds to wait for a result before raising and sending ``notifications/cancelled``.""" + + on_progress: ProgressFnT + """Receive ``notifications/progress`` updates for this call.""" + + resumption_token: str + """Opaque token to resume a previously interrupted call (transport-dependent).""" + + on_resumption_token: Callable[[str], Awaitable[None]] + """Receive a resumption token when the transport issues one.""" + + +@runtime_checkable +class RequestSender(Protocol): + """Anything that can send a request and await its result. + + Both `Dispatcher` (for top-level outbound calls) and `DispatchContext` + (for server-to-client calls made *during* an inbound request) satisfy this. + """ + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: ... + + +class DispatchContext(Protocol[TransportT_co]): + """Per-request context handed to ``on_call`` / ``on_notify``. + + Carries the transport metadata for the inbound message and provides the + back-channel for sending requests/notifications to the peer while handling + it. + """ + + @property + def transport(self) -> TransportT_co: + """Transport-specific metadata for this inbound message.""" + ... + + @property + def cancel_requested(self) -> anyio.Event: + """Set when the peer sends ``notifications/cancelled`` for this request.""" + ... + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a notification to the peer.""" + ... + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a request to the peer on the back-channel and await its result. + + Raises: + NoBackChannelError: if ``transport.can_send_request`` is ``False``. + """ + ... + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + """Report progress for the inbound request, if the peer supplied a progress token. + + A no-op when no token was supplied. + """ + ... + + +OnCall = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[dict[str, Any]]] +"""Handler for inbound requests: ``(ctx, method, params) -> result``. Raise ``MCPError`` to send an error response.""" + +OnNotify = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[None]] +"""Handler for inbound notifications: ``(ctx, method, params)``.""" + +DispatchMiddleware = Callable[[OnCall], OnCall] +"""Wraps an ``OnCall`` to produce another ``OnCall``. Applied outermost-first.""" + + +class Dispatcher(Protocol[TransportT_co]): + """A duplex request/notification channel with call-return semantics. + + Implementations own correlation of outbound calls to inbound results, the + receive loop, per-request concurrency, and cancellation/progress wiring. + """ + + async def call( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a request and await its result. + + Raises: + MCPError: If the peer responded with an error, or the handler + raised. Implementations normalize all handler exceptions to + `MCPError` so callers see a single exception type. + """ + ... + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a fire-and-forget notification.""" + ... + + async def run(self, on_call: OnCall, on_notify: OnNotify) -> None: + """Drive the receive loop until the underlying channel closes. + + Each inbound request is dispatched to ``on_call`` in its own task; the + returned dict (or raised ``MCPError``) is sent back as the response. + Inbound notifications go to ``on_notify``. + """ + ... diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index f153ea319d..e9dd2c843e 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -2,7 +2,7 @@ from typing import Any, cast -from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError +from mcp.types import INVALID_REQUEST, URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError class MCPError(Exception): @@ -41,6 +41,25 @@ def __str__(self) -> str: return self.message +class NoBackChannelError(MCPError): + """Raised when sending a server-initiated request over a transport that cannot deliver it. + + Stateless HTTP and JSON-response-mode HTTP have no channel for the server to + push requests (sampling, elicitation, roots/list) to the client. This is + raised by `DispatchContext.send_request` when `transport.can_send_request` + is ``False``, and serializes to an ``INVALID_REQUEST`` error response. + """ + + def __init__(self, method: str): + super().__init__( + code=INVALID_REQUEST, + message=( + f"Cannot send {method!r}: this transport context has no back-channel for server-initiated requests." + ), + ) + self.method = method + + class StatelessModeNotSupported(RuntimeError): """Raised when attempting to use a method that is not supported in stateless mode. diff --git a/src/mcp/shared/transport_context.py b/src/mcp/shared/transport_context.py new file mode 100644 index 0000000000..31230fda92 --- /dev/null +++ b/src/mcp/shared/transport_context.py @@ -0,0 +1,30 @@ +"""Transport-specific metadata attached to each inbound message. + +`TransportContext` is the base; each transport defines its own subclass with +whatever fields make sense (HTTP request id, ASGI scope, stdio process handle, +etc.). The dispatcher passes it through opaquely; only the layers above the +dispatcher (`ServerRunner`, `Context`, user handlers) read its concrete fields. +""" + +from dataclasses import dataclass + +__all__ = ["TransportContext"] + + +@dataclass(kw_only=True, frozen=True) +class TransportContext: + """Base transport metadata for an inbound message. + + Subclass per transport and add fields as needed. Instances are immutable. + """ + + kind: str + """Short identifier for the transport (e.g. ``"stdio"``, ``"streamable-http"``).""" + + can_send_request: bool + """Whether the transport can deliver server-initiated requests to the peer. + + ``False`` for stateless HTTP and HTTP with JSON response mode; ``True`` for + stdio, SSE, and stateful streamable HTTP. When ``False``, + `DispatchContext.send_request` raises `NoBackChannelError`. + """ diff --git a/src/mcp/types/__init__.py b/src/mcp/types/__init__.py index b442303937..ca1c328939 100644 --- a/src/mcp/types/__init__.py +++ b/src/mcp/types/__init__.py @@ -192,6 +192,7 @@ INVALID_REQUEST, METHOD_NOT_FOUND, PARSE_ERROR, + REQUEST_CANCELLED, REQUEST_TIMEOUT, URL_ELICITATION_REQUIRED, ErrorData, @@ -401,6 +402,7 @@ "INVALID_REQUEST", "METHOD_NOT_FOUND", "PARSE_ERROR", + "REQUEST_CANCELLED", "REQUEST_TIMEOUT", "URL_ELICITATION_REQUIRED", "ErrorData", diff --git a/src/mcp/types/jsonrpc.py b/src/mcp/types/jsonrpc.py index 84304a37c1..14743c33b0 100644 --- a/src/mcp/types/jsonrpc.py +++ b/src/mcp/types/jsonrpc.py @@ -43,6 +43,7 @@ class JSONRPCResponse(BaseModel): # SDK error codes CONNECTION_CLOSED = -32000 REQUEST_TIMEOUT = -32001 +REQUEST_CANCELLED = -32002 # Standard JSON-RPC error codes PARSE_ERROR = -32700 diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py new file mode 100644 index 0000000000..dd8d40721a --- /dev/null +++ b/tests/shared/test_dispatcher.py @@ -0,0 +1,253 @@ +"""Behavioral tests for the Dispatcher Protocol via DirectDispatcher. + +These exercise the `Dispatcher` / `DispatchContext` contract end-to-end using +the in-memory `DirectDispatcher`. JSON-RPC framing is covered separately in +``test_jsonrpc_dispatcher.py``. +""" + +from collections.abc import AsyncIterator, Mapping +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any + +import anyio +import pytest + +from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair +from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnCall, OnNotify +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.transport_context import TransportContext +from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT + + +class Recorder: + def __init__(self) -> None: + self.calls: list[tuple[str, Mapping[str, Any] | None]] = [] + self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] + self.contexts: list[DispatchContext[TransportContext]] = [] + self.notified = anyio.Event() + + +def echo_handlers(recorder: Recorder) -> tuple[OnCall, OnNotify]: + async def on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + recorder.calls.append((method, params)) + recorder.contexts.append(ctx) + return {"echoed": method, "params": dict(params or {})} + + async def on_notify(ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None) -> None: + recorder.notifications.append((method, params)) + recorder.notified.set() + + return on_call, on_notify + + +@asynccontextmanager +async def running_pair( + *, + server_on_call: OnCall | None = None, + server_on_notify: OnNotify | None = None, + client_on_call: OnCall | None = None, + client_on_notify: OnNotify | None = None, + can_send_request: bool = True, +) -> AsyncIterator[tuple[DirectDispatcher, DirectDispatcher, Recorder, Recorder]]: + """Yield ``(client, server, client_recorder, server_recorder)`` with both ``run()`` loops live.""" + client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) + client_rec, server_rec = Recorder(), Recorder() + c_call, c_notify = echo_handlers(client_rec) + s_call, s_notify = echo_handlers(server_rec) + async with anyio.create_task_group() as tg: + tg.start_soon(client.run, client_on_call or c_call, client_on_notify or c_notify) + tg.start_soon(server.run, server_on_call or s_call, server_on_notify or s_notify) + try: + yield client, server, client_rec, server_rec + finally: + client.close() + server.close() + + +@pytest.mark.anyio +async def test_call_returns_result_from_peer_on_call(): + async with running_pair() as (client, _server, _crec, srec): + with anyio.fail_after(5): + result = await client.call("tools/list", {"cursor": "abc"}) + assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} + assert srec.calls == [("tools/list", {"cursor": "abc"})] + + +@pytest.mark.anyio +async def test_call_reraises_mcperror_from_handler_unchanged(): + async def on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise MCPError(code=INVALID_PARAMS, message="bad cursor") + + async with running_pair(server_on_call=on_call) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.call("tools/list", {}) + assert exc.value.error.code == INVALID_PARAMS + assert exc.value.error.message == "bad cursor" + + +@pytest.mark.anyio +async def test_call_wraps_non_mcperror_exception_as_internal_error(): + async def on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise ValueError("oops") + + async with running_pair(server_on_call=on_call) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.call("tools/list", {}) + assert exc.value.error.code == INTERNAL_ERROR + assert isinstance(exc.value.__cause__, ValueError) + + +@pytest.mark.anyio +async def test_call_with_timeout_raises_mcperror_request_timeout(): + async def on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await anyio.sleep_forever() + return {} + + async with running_pair(server_on_call=on_call) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.call("slow", None, {"timeout": 0}) + assert exc.value.error.code == REQUEST_TIMEOUT + + +@pytest.mark.anyio +async def test_notify_invokes_peer_on_notify(): + async with running_pair() as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/initialized", {"v": 1}) + await srec.notified.wait() + assert srec.notifications == [("notifications/initialized", {"v": 1})] + + +@pytest.mark.anyio +async def test_ctx_send_request_round_trips_to_calling_side(): + """A handler's ctx.send_request reaches the side that made the inbound call.""" + + async def server_on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + sample = await ctx.send_request("sampling/createMessage", {"prompt": "hi"}) + return {"sampled": sample} + + async with running_pair(server_on_call=server_on_call) as (client, _server, crec, _srec): + with anyio.fail_after(5): + result = await client.call("tools/call", None) + assert crec.calls == [("sampling/createMessage", {"prompt": "hi"})] + assert result == {"sampled": {"echoed": "sampling/createMessage", "params": {"prompt": "hi"}}} + + +@pytest.mark.anyio +async def test_ctx_send_request_raises_nobackchannelerror_when_transport_disallows(): + async def server_on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.send_request("sampling/createMessage", None) + return {} + + async with running_pair(server_on_call=server_on_call, can_send_request=False) as (client, *_): + with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: + await client.call("tools/call", None) + assert exc.value.method == "sampling/createMessage" + assert exc.value.error.code == INVALID_REQUEST + + +@pytest.mark.anyio +async def test_ctx_notify_invokes_calling_side_on_notify(): + async def server_on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.notify("notifications/message", {"level": "info"}) + return {} + + async with running_pair(server_on_call=server_on_call) as (client, _server, crec, _srec): + with anyio.fail_after(5): + await client.call("tools/call", None) + await crec.notified.wait() + assert crec.notifications == [("notifications/message", {"level": "info"})] + + +@pytest.mark.anyio +async def test_ctx_progress_invokes_caller_on_progress_callback(): + async def server_on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5, total=1.0, message="halfway") + return {} + + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async with running_pair(server_on_call=server_on_call) as (client, *_): + with anyio.fail_after(5): + await client.call("tools/call", None, {"on_progress": on_progress}) + assert received == [(0.5, 1.0, "halfway")] + + +@pytest.mark.anyio +async def test_call_issued_before_peer_run_blocks_until_peer_ready(): + client, server = create_direct_dispatcher_pair() + s_call, s_notify = echo_handlers(Recorder()) + c_call, c_notify = echo_handlers(Recorder()) + + async def late_start(): + await anyio.sleep(0) + await server.run(s_call, s_notify) + + async with anyio.create_task_group() as tg: + tg.start_soon(client.run, c_call, c_notify) + tg.start_soon(late_start) + with anyio.fail_after(5): + result = await client.call("ping", None) + assert result == {"echoed": "ping", "params": {}} + client.close() + server.close() + + +@pytest.mark.anyio +async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(): + async def server_on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5) + return {"ok": True} + + async with running_pair(server_on_call=server_on_call) as (client, *_): + with anyio.fail_after(5): + result = await client.call("tools/call", None) + assert result == {"ok": True} + + +@pytest.mark.anyio +async def test_call_and_notify_raise_runtimeerror_when_no_peer_connected(): + d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) + with pytest.raises(RuntimeError, match="no peer"): + await d.call("ping", None) + with pytest.raises(RuntimeError, match="no peer"): + await d.notify("ping", None) + + +@pytest.mark.anyio +async def test_close_makes_run_return(): + client, server = create_direct_dispatcher_pair() + on_call, on_notify = echo_handlers(Recorder()) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(server.run, on_call, on_notify) + tg.start_soon(client.run, on_call, on_notify) + client.close() + server.close() + + +if TYPE_CHECKING: + _dispatcher_check: Dispatcher[TransportContext] = DirectDispatcher( + TransportContext(kind="direct", can_send_request=True) + ) From f53b056c4c2cff734244873da4b8daf74fd41144 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 12:03:48 +0000 Subject: [PATCH 02/52] fix: address coverage gaps and stale RequestSender docstring - tests: replace unreachable 'return {}' with 'raise NotImplementedError' (already in coverage exclude_also) and collapse send_request+return into one statement - dispatcher: RequestSender docstring no longer claims Dispatcher satisfies it (Dispatcher exposes call(), not send_request()) --- src/mcp/shared/dispatcher.py | 4 ++-- tests/shared/test_dispatcher.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index 09e5e87bb6..b63c00c0bf 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -66,8 +66,8 @@ class CallOptions(TypedDict, total=False): class RequestSender(Protocol): """Anything that can send a request and await its result. - Both `Dispatcher` (for top-level outbound calls) and `DispatchContext` - (for server-to-client calls made *during* an inbound request) satisfy this. + `DispatchContext` satisfies this; `PeerMixin` (and `Connection`/`Peer`) wrap + a `RequestSender` to provide typed request methods. """ async def send_request( diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index dd8d40721a..ddfe1f798f 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -109,7 +109,7 @@ async def on_call( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await anyio.sleep_forever() - return {} + raise NotImplementedError async with running_pair(server_on_call=on_call) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: @@ -148,8 +148,7 @@ async def test_ctx_send_request_raises_nobackchannelerror_when_transport_disallo async def server_on_call( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: - await ctx.send_request("sampling/createMessage", None) - return {} + return await ctx.send_request("sampling/createMessage", None) async with running_pair(server_on_call=server_on_call, can_send_request=False) as (client, *_): with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: From 2e2b2d7e2985b2be38f47fd4de74ae4e8df8a595 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 12:52:58 +0000 Subject: [PATCH 03/52] refactor: rename Dispatcher.call to send_request, replace RequestSender with Outbound MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The design doc's `send_request = call` alias only makes the concrete class satisfy RequestSender, not the abstract Dispatcher Protocol — so any consumer typed against `Dispatcher[TT]` (Connection, ServerRunner) couldn't pass it to something expecting a RequestSender without a cast or hand-written bridge. RequestSender was also half a contract: every implementor (Dispatcher, DispatchContext, Connection, Context) has `notify` too, and PeerMixin needs both for its typed sugar (elicit/sample are requests, log is a notification). Outbound(Protocol) declares both methods; Dispatcher and DispatchContext extend it. PeerMixin will wrap an Outbound. One verb everywhere, no aliases, no extra Protocols. - Dispatcher.call -> send_request - OnCall -> OnRequest, on_call -> on_request - RequestSender -> Outbound (now also declares notify) - Dispatcher(Outbound, Protocol[TT]), DispatchContext(Outbound, Protocol[TT]) --- src/mcp/shared/direct_dispatcher.py | 38 ++++----- src/mcp/shared/dispatcher.py | 100 ++++++++++-------------- tests/shared/test_dispatcher.py | 115 ++++++++++++++-------------- 3 files changed, 115 insertions(+), 138 deletions(-) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index 4650619428..79b68d0547 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -1,7 +1,7 @@ """In-memory `Dispatcher` that wires two peers together with no transport. -`DirectDispatcher` is the simplest possible `Dispatcher` implementation: a call -on one side directly invokes the other side's `on_call`. There is no +`DirectDispatcher` is the simplest possible `Dispatcher` implementation: a +request on one side directly invokes the other side's `on_request`. There is no serialization, no JSON-RPC framing, and no streams. It exists to: * prove the `Dispatcher` Protocol is implementable without JSON-RPC @@ -21,7 +21,7 @@ import anyio -from mcp.shared.dispatcher import CallOptions, OnCall, OnNotify, ProgressFnT +from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.transport_context import TransportContext from mcp.types import INTERNAL_ERROR, REQUEST_TIMEOUT @@ -31,20 +31,20 @@ DIRECT_TRANSPORT_KIND = "direct" -_Call = Callable[[str, Mapping[str, Any] | None, CallOptions | None], Awaitable[dict[str, Any]]] +_Request = Callable[[str, Mapping[str, Any] | None, CallOptions | None], Awaitable[dict[str, Any]]] _Notify = Callable[[str, Mapping[str, Any] | None], Awaitable[None]] @dataclass class _DirectDispatchContext: - """`DispatchContext` for an inbound call on a `DirectDispatcher`. + """`DispatchContext` for an inbound request on a `DirectDispatcher`. The back-channel callables target the *originating* side, so a handler's - `send_request` reaches the peer that made the inbound call. + `send_request` reaches the peer that made the inbound request. """ transport: TransportContext - _back_call: _Call + _back_request: _Request _back_notify: _Notify _on_progress: ProgressFnT | None = None cancel_requested: anyio.Event = field(default_factory=anyio.Event) @@ -60,7 +60,7 @@ async def send_request( ) -> dict[str, Any]: if not self.transport.can_send_request: raise NoBackChannelError(method) - return await self._back_call(method, params, opts) + return await self._back_request(method, params, opts) async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: if self._on_progress is not None: @@ -71,14 +71,14 @@ class DirectDispatcher: """A `Dispatcher` that calls a peer's handlers directly, in-process. Two instances are wired together with `create_direct_dispatcher_pair`; each - holds a reference to the other. `call` on one awaits the peer's `on_call`. - `run` parks until `close` is called. + holds a reference to the other. `send_request` on one awaits the peer's + `on_request`. `run` parks until `close` is called. """ def __init__(self, transport_ctx: TransportContext): self._transport_ctx = transport_ctx self._peer: DirectDispatcher | None = None - self._on_call: OnCall | None = None + self._on_request: OnRequest | None = None self._on_notify: OnNotify | None = None self._ready = anyio.Event() self._closed = anyio.Event() @@ -86,7 +86,7 @@ def __init__(self, transport_ctx: TransportContext): def connect_to(self, peer: DirectDispatcher) -> None: self._peer = peer - async def call( + async def send_request( self, method: str, params: Mapping[str, Any] | None, @@ -94,15 +94,15 @@ async def call( ) -> dict[str, Any]: if self._peer is None: raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") - return await self._peer._dispatch_call(method, params, opts) + return await self._peer._dispatch_request(method, params, opts) async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: if self._peer is None: raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") await self._peer._dispatch_notify(method, params) - async def run(self, on_call: OnCall, on_notify: OnNotify) -> None: - self._on_call = on_call + async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: + self._on_request = on_request self._on_notify = on_notify self._ready.set() await self._closed.wait() @@ -115,25 +115,25 @@ def _make_context(self, on_progress: ProgressFnT | None = None) -> _DirectDispat peer = self._peer return _DirectDispatchContext( transport=self._transport_ctx, - _back_call=lambda m, p, o: peer._dispatch_call(m, p, o), + _back_request=lambda m, p, o: peer._dispatch_request(m, p, o), _back_notify=lambda m, p: peer._dispatch_notify(m, p), _on_progress=on_progress, ) - async def _dispatch_call( + async def _dispatch_request( self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None, ) -> dict[str, Any]: await self._ready.wait() - assert self._on_call is not None + assert self._on_request is not None opts = opts or {} dctx = self._make_context(on_progress=opts.get("on_progress")) try: with anyio.fail_after(opts.get("timeout")): try: - return await self._on_call(dctx, method, params) + return await self._on_request(dctx, method, params) except MCPError: raise except Exception as e: diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index b63c00c0bf..872fb01eaa 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -2,9 +2,9 @@ A Dispatcher turns a duplex message channel into two things: -* an outbound API: ``call(method, params)`` and ``notify(method, params)`` -* an inbound pump: ``run(on_call, on_notify)`` that drives the receive loop and - invokes the supplied handlers for each incoming request/notification +* an outbound API: ``send_request(method, params)`` and ``notify(method, params)`` +* an inbound pump: ``run(on_request, on_notify)`` that drives the receive loop + and invokes the supplied handlers for each incoming request/notification It is deliberately *not* MCP-aware. Method names are strings, params and results are ``dict[str, Any]``. The MCP type layer (request/result models, @@ -28,23 +28,23 @@ "DispatchContext", "DispatchMiddleware", "Dispatcher", - "OnCall", "OnNotify", + "OnRequest", + "Outbound", "ProgressFnT", - "RequestSender", ] TransportT_co = TypeVar("TransportT_co", bound=TransportContext, covariant=True) class ProgressFnT(Protocol): - """Callback invoked when a progress notification arrives for a pending call.""" + """Callback invoked when a progress notification arrives for a pending request.""" async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ... class CallOptions(TypedDict, total=False): - """Per-call options for `RequestSender.send_request` / `Dispatcher.call`. + """Per-call options for `Outbound.send_request`. All keys are optional. Dispatchers ignore keys they do not understand. """ @@ -53,21 +53,22 @@ class CallOptions(TypedDict, total=False): """Seconds to wait for a result before raising and sending ``notifications/cancelled``.""" on_progress: ProgressFnT - """Receive ``notifications/progress`` updates for this call.""" + """Receive ``notifications/progress`` updates for this request.""" resumption_token: str - """Opaque token to resume a previously interrupted call (transport-dependent).""" + """Opaque token to resume a previously interrupted request (transport-dependent).""" on_resumption_token: Callable[[str], Awaitable[None]] """Receive a resumption token when the transport issues one.""" @runtime_checkable -class RequestSender(Protocol): - """Anything that can send a request and await its result. +class Outbound(Protocol): + """Anything that can send requests and notifications to the peer. - `DispatchContext` satisfies this; `PeerMixin` (and `Connection`/`Peer`) wrap - a `RequestSender` to provide typed request methods. + Both `Dispatcher` (top-level outbound) and `DispatchContext` (back-channel + during an inbound request) extend this. `PeerMixin` wraps an `Outbound` to + provide typed MCP request/notification methods. """ async def send_request( @@ -75,15 +76,28 @@ async def send_request( method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None, - ) -> dict[str, Any]: ... + ) -> dict[str, Any]: + """Send a request and await its result. + + Raises: + MCPError: If the peer responded with an error, or the handler + raised. Implementations normalize all handler exceptions to + `MCPError` so callers see a single exception type. + """ + ... + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a fire-and-forget notification.""" + ... -class DispatchContext(Protocol[TransportT_co]): - """Per-request context handed to ``on_call`` / ``on_notify``. + +class DispatchContext(Outbound, Protocol[TransportT_co]): + """Per-request context handed to ``on_request`` / ``on_notify``. Carries the transport metadata for the inbound message and provides the back-channel for sending requests/notifications to the peer while handling - it. + it. `send_request` raises `NoBackChannelError` if + ``transport.can_send_request`` is ``False``. """ @property @@ -96,23 +110,6 @@ def cancel_requested(self) -> anyio.Event: """Set when the peer sends ``notifications/cancelled`` for this request.""" ... - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: - """Send a notification to the peer.""" - ... - - async def send_request( - self, - method: str, - params: Mapping[str, Any] | None, - opts: CallOptions | None = None, - ) -> dict[str, Any]: - """Send a request to the peer on the back-channel and await its result. - - Raises: - NoBackChannelError: if ``transport.can_send_request`` is ``False``. - """ - ... - async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: """Report progress for the inbound request, if the peer supplied a progress token. @@ -121,47 +118,28 @@ async def progress(self, progress: float, total: float | None = None, message: s ... -OnCall = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[dict[str, Any]]] +OnRequest = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[dict[str, Any]]] """Handler for inbound requests: ``(ctx, method, params) -> result``. Raise ``MCPError`` to send an error response.""" OnNotify = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[None]] """Handler for inbound notifications: ``(ctx, method, params)``.""" -DispatchMiddleware = Callable[[OnCall], OnCall] -"""Wraps an ``OnCall`` to produce another ``OnCall``. Applied outermost-first.""" +DispatchMiddleware = Callable[[OnRequest], OnRequest] +"""Wraps an ``OnRequest`` to produce another ``OnRequest``. Applied outermost-first.""" -class Dispatcher(Protocol[TransportT_co]): +class Dispatcher(Outbound, Protocol[TransportT_co]): """A duplex request/notification channel with call-return semantics. - Implementations own correlation of outbound calls to inbound results, the + Implementations own correlation of outbound requests to inbound results, the receive loop, per-request concurrency, and cancellation/progress wiring. """ - async def call( - self, - method: str, - params: Mapping[str, Any] | None, - opts: CallOptions | None = None, - ) -> dict[str, Any]: - """Send a request and await its result. - - Raises: - MCPError: If the peer responded with an error, or the handler - raised. Implementations normalize all handler exceptions to - `MCPError` so callers see a single exception type. - """ - ... - - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: - """Send a fire-and-forget notification.""" - ... - - async def run(self, on_call: OnCall, on_notify: OnNotify) -> None: + async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: """Drive the receive loop until the underlying channel closes. - Each inbound request is dispatched to ``on_call`` in its own task; the - returned dict (or raised ``MCPError``) is sent back as the response. + Each inbound request is dispatched to ``on_request`` in its own task; + the returned dict (or raised ``MCPError``) is sent back as the response. Inbound notifications go to ``on_notify``. """ ... diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index ddfe1f798f..44ab622ad6 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -13,7 +13,7 @@ import pytest from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair -from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnCall, OnNotify +from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnRequest, Outbound from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.transport_context import TransportContext from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT @@ -21,17 +21,17 @@ class Recorder: def __init__(self) -> None: - self.calls: list[tuple[str, Mapping[str, Any] | None]] = [] + self.requests: list[tuple[str, Mapping[str, Any] | None]] = [] self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] self.contexts: list[DispatchContext[TransportContext]] = [] self.notified = anyio.Event() -def echo_handlers(recorder: Recorder) -> tuple[OnCall, OnNotify]: - async def on_call( +def echo_handlers(recorder: Recorder) -> tuple[OnRequest, OnNotify]: + async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: - recorder.calls.append((method, params)) + recorder.requests.append((method, params)) recorder.contexts.append(ctx) return {"echoed": method, "params": dict(params or {})} @@ -39,26 +39,26 @@ async def on_notify(ctx: DispatchContext[TransportContext], method: str, params: recorder.notifications.append((method, params)) recorder.notified.set() - return on_call, on_notify + return on_request, on_notify @asynccontextmanager async def running_pair( *, - server_on_call: OnCall | None = None, + server_on_request: OnRequest | None = None, server_on_notify: OnNotify | None = None, - client_on_call: OnCall | None = None, + client_on_request: OnRequest | None = None, client_on_notify: OnNotify | None = None, can_send_request: bool = True, ) -> AsyncIterator[tuple[DirectDispatcher, DirectDispatcher, Recorder, Recorder]]: """Yield ``(client, server, client_recorder, server_recorder)`` with both ``run()`` loops live.""" client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) client_rec, server_rec = Recorder(), Recorder() - c_call, c_notify = echo_handlers(client_rec) - s_call, s_notify = echo_handlers(server_rec) + c_req, c_notify = echo_handlers(client_rec) + s_req, s_notify = echo_handlers(server_rec) async with anyio.create_task_group() as tg: - tg.start_soon(client.run, client_on_call or c_call, client_on_notify or c_notify) - tg.start_soon(server.run, server_on_call or s_call, server_on_notify or s_notify) + tg.start_soon(client.run, client_on_request or c_req, client_on_notify or c_notify) + tg.start_soon(server.run, server_on_request or s_req, server_on_notify or s_notify) try: yield client, server, client_rec, server_rec finally: @@ -67,53 +67,53 @@ async def running_pair( @pytest.mark.anyio -async def test_call_returns_result_from_peer_on_call(): +async def test_send_request_returns_result_from_peer_on_request(): async with running_pair() as (client, _server, _crec, srec): with anyio.fail_after(5): - result = await client.call("tools/list", {"cursor": "abc"}) + result = await client.send_request("tools/list", {"cursor": "abc"}) assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} - assert srec.calls == [("tools/list", {"cursor": "abc"})] + assert srec.requests == [("tools/list", {"cursor": "abc"})] @pytest.mark.anyio -async def test_call_reraises_mcperror_from_handler_unchanged(): - async def on_call( +async def test_send_request_reraises_mcperror_from_handler_unchanged(): + async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: raise MCPError(code=INVALID_PARAMS, message="bad cursor") - async with running_pair(server_on_call=on_call) as (client, *_): + async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.call("tools/list", {}) + await client.send_request("tools/list", {}) assert exc.value.error.code == INVALID_PARAMS assert exc.value.error.message == "bad cursor" @pytest.mark.anyio -async def test_call_wraps_non_mcperror_exception_as_internal_error(): - async def on_call( +async def test_send_request_wraps_non_mcperror_exception_as_internal_error(): + async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: raise ValueError("oops") - async with running_pair(server_on_call=on_call) as (client, *_): + async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.call("tools/list", {}) + await client.send_request("tools/list", {}) assert exc.value.error.code == INTERNAL_ERROR assert isinstance(exc.value.__cause__, ValueError) @pytest.mark.anyio -async def test_call_with_timeout_raises_mcperror_request_timeout(): - async def on_call( +async def test_send_request_with_timeout_raises_mcperror_request_timeout(): + async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await anyio.sleep_forever() raise NotImplementedError - async with running_pair(server_on_call=on_call) as (client, *_): + async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.call("slow", None, {"timeout": 0}) + await client.send_request("slow", None, {"timeout": 0}) assert exc.value.error.code == REQUEST_TIMEOUT @@ -128,53 +128,53 @@ async def test_notify_invokes_peer_on_notify(): @pytest.mark.anyio async def test_ctx_send_request_round_trips_to_calling_side(): - """A handler's ctx.send_request reaches the side that made the inbound call.""" + """A handler's ctx.send_request reaches the side that made the inbound request.""" - async def server_on_call( + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: sample = await ctx.send_request("sampling/createMessage", {"prompt": "hi"}) return {"sampled": sample} - async with running_pair(server_on_call=server_on_call) as (client, _server, crec, _srec): + async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): - result = await client.call("tools/call", None) - assert crec.calls == [("sampling/createMessage", {"prompt": "hi"})] + result = await client.send_request("tools/call", None) + assert crec.requests == [("sampling/createMessage", {"prompt": "hi"})] assert result == {"sampled": {"echoed": "sampling/createMessage", "params": {"prompt": "hi"}}} @pytest.mark.anyio async def test_ctx_send_request_raises_nobackchannelerror_when_transport_disallows(): - async def server_on_call( + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: return await ctx.send_request("sampling/createMessage", None) - async with running_pair(server_on_call=server_on_call, can_send_request=False) as (client, *_): + async with running_pair(server_on_request=server_on_request, can_send_request=False) as (client, *_): with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: - await client.call("tools/call", None) + await client.send_request("tools/call", None) assert exc.value.method == "sampling/createMessage" assert exc.value.error.code == INVALID_REQUEST @pytest.mark.anyio async def test_ctx_notify_invokes_calling_side_on_notify(): - async def server_on_call( + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await ctx.notify("notifications/message", {"level": "info"}) return {} - async with running_pair(server_on_call=server_on_call) as (client, _server, crec, _srec): + async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): - await client.call("tools/call", None) + await client.send_request("tools/call", None) await crec.notified.wait() assert crec.notifications == [("notifications/message", {"level": "info"})] @pytest.mark.anyio async def test_ctx_progress_invokes_caller_on_progress_callback(): - async def server_on_call( + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await ctx.progress(0.5, total=1.0, message="halfway") @@ -185,27 +185,27 @@ async def server_on_call( async def on_progress(progress: float, total: float | None, message: str | None) -> None: received.append((progress, total, message)) - async with running_pair(server_on_call=server_on_call) as (client, *_): + async with running_pair(server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - await client.call("tools/call", None, {"on_progress": on_progress}) + await client.send_request("tools/call", None, {"on_progress": on_progress}) assert received == [(0.5, 1.0, "halfway")] @pytest.mark.anyio -async def test_call_issued_before_peer_run_blocks_until_peer_ready(): +async def test_send_request_issued_before_peer_run_blocks_until_peer_ready(): client, server = create_direct_dispatcher_pair() - s_call, s_notify = echo_handlers(Recorder()) - c_call, c_notify = echo_handlers(Recorder()) + s_req, s_notify = echo_handlers(Recorder()) + c_req, c_notify = echo_handlers(Recorder()) async def late_start(): await anyio.sleep(0) - await server.run(s_call, s_notify) + await server.run(s_req, s_notify) async with anyio.create_task_group() as tg: - tg.start_soon(client.run, c_call, c_notify) + tg.start_soon(client.run, c_req, c_notify) tg.start_soon(late_start) with anyio.fail_after(5): - result = await client.call("ping", None) + result = await client.send_request("ping", None) assert result == {"echoed": "ping", "params": {}} client.close() server.close() @@ -213,23 +213,23 @@ async def late_start(): @pytest.mark.anyio async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(): - async def server_on_call( + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await ctx.progress(0.5) return {"ok": True} - async with running_pair(server_on_call=server_on_call) as (client, *_): + async with running_pair(server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - result = await client.call("tools/call", None) + result = await client.send_request("tools/call", None) assert result == {"ok": True} @pytest.mark.anyio -async def test_call_and_notify_raise_runtimeerror_when_no_peer_connected(): +async def test_send_request_and_notify_raise_runtimeerror_when_no_peer_connected(): d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) with pytest.raises(RuntimeError, match="no peer"): - await d.call("ping", None) + await d.send_request("ping", None) with pytest.raises(RuntimeError, match="no peer"): await d.notify("ping", None) @@ -237,16 +237,15 @@ async def test_call_and_notify_raise_runtimeerror_when_no_peer_connected(): @pytest.mark.anyio async def test_close_makes_run_return(): client, server = create_direct_dispatcher_pair() - on_call, on_notify = echo_handlers(Recorder()) + on_request, on_notify = echo_handlers(Recorder()) with anyio.fail_after(5): async with anyio.create_task_group() as tg: - tg.start_soon(server.run, on_call, on_notify) - tg.start_soon(client.run, on_call, on_notify) + tg.start_soon(server.run, on_request, on_notify) + tg.start_soon(client.run, on_request, on_notify) client.close() server.close() if TYPE_CHECKING: - _dispatcher_check: Dispatcher[TransportContext] = DirectDispatcher( - TransportContext(kind="direct", can_send_request=True) - ) + _d: Dispatcher[TransportContext] = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) + _o: Outbound = _d From b5cf7560728ad15181ac4cd4ab18c566cbdcf12e Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 21:38:32 +0000 Subject: [PATCH 04/52] refactor: rename Outbound.send_request to send_raw_request MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The dispatcher-layer raw channel is now `send_raw_request(method, params) -> dict`. This frees the `send_request` name for the typed surface (`send_request(req: Request) -> Result`) that Connection/Context/Client add in later PRs. Mechanical rename across Outbound, Dispatcher, DispatchContext, DirectDispatcher, _DirectDispatchContext, and all tests. `can_send_request` (the transport capability flag) is unchanged — it names the capability, not the method. --- src/mcp/shared/direct_dispatcher.py | 8 +++--- src/mcp/shared/dispatcher.py | 15 +++++----- src/mcp/shared/exceptions.py | 2 +- src/mcp/shared/transport_context.py | 2 +- tests/shared/test_dispatcher.py | 44 ++++++++++++++--------------- 5 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index 79b68d0547..bb5639a136 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -40,7 +40,7 @@ class _DirectDispatchContext: """`DispatchContext` for an inbound request on a `DirectDispatcher`. The back-channel callables target the *originating* side, so a handler's - `send_request` reaches the peer that made the inbound request. + `send_raw_request` reaches the peer that made the inbound request. """ transport: TransportContext @@ -52,7 +52,7 @@ class _DirectDispatchContext: async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: await self._back_notify(method, params) - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, @@ -71,7 +71,7 @@ class DirectDispatcher: """A `Dispatcher` that calls a peer's handlers directly, in-process. Two instances are wired together with `create_direct_dispatcher_pair`; each - holds a reference to the other. `send_request` on one awaits the peer's + holds a reference to the other. `send_raw_request` on one awaits the peer's `on_request`. `run` parks until `close` is called. """ @@ -86,7 +86,7 @@ def __init__(self, transport_ctx: TransportContext): def connect_to(self, peer: DirectDispatcher) -> None: self._peer = peer - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index 872fb01eaa..ee02e23896 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -2,7 +2,7 @@ A Dispatcher turns a duplex message channel into two things: -* an outbound API: ``send_request(method, params)`` and ``notify(method, params)`` +* an outbound API: ``send_raw_request(method, params)`` and ``notify(method, params)`` * an inbound pump: ``run(on_request, on_notify)`` that drives the receive loop and invokes the supplied handlers for each incoming request/notification @@ -44,7 +44,7 @@ async def __call__(self, progress: float, total: float | None, message: str | No class CallOptions(TypedDict, total=False): - """Per-call options for `Outbound.send_request`. + """Per-call options for `Outbound.send_raw_request`. All keys are optional. Dispatchers ignore keys they do not understand. """ @@ -67,17 +67,18 @@ class Outbound(Protocol): """Anything that can send requests and notifications to the peer. Both `Dispatcher` (top-level outbound) and `DispatchContext` (back-channel - during an inbound request) extend this. `PeerMixin` wraps an `Outbound` to - provide typed MCP request/notification methods. + during an inbound request) extend this. The MCP type layer (`PeerMixin`, + `Connection`, `Context`) builds typed ``send_request`` / convenience methods + on top of this raw channel. """ - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None, ) -> dict[str, Any]: - """Send a request and await its result. + """Send a request and await its raw result dict. Raises: MCPError: If the peer responded with an error, or the handler @@ -96,7 +97,7 @@ class DispatchContext(Outbound, Protocol[TransportT_co]): Carries the transport metadata for the inbound message and provides the back-channel for sending requests/notifications to the peer while handling - it. `send_request` raises `NoBackChannelError` if + it. `send_raw_request` raises `NoBackChannelError` if ``transport.can_send_request`` is ``False``. """ diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index e9dd2c843e..b62629b6c8 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -46,7 +46,7 @@ class NoBackChannelError(MCPError): Stateless HTTP and JSON-response-mode HTTP have no channel for the server to push requests (sampling, elicitation, roots/list) to the client. This is - raised by `DispatchContext.send_request` when `transport.can_send_request` + raised by `DispatchContext.send_raw_request` when `transport.can_send_request` is ``False``, and serializes to an ``INVALID_REQUEST`` error response. """ diff --git a/src/mcp/shared/transport_context.py b/src/mcp/shared/transport_context.py index 31230fda92..832cead515 100644 --- a/src/mcp/shared/transport_context.py +++ b/src/mcp/shared/transport_context.py @@ -26,5 +26,5 @@ class TransportContext: ``False`` for stateless HTTP and HTTP with JSON response mode; ``True`` for stdio, SSE, and stateful streamable HTTP. When ``False``, - `DispatchContext.send_request` raises `NoBackChannelError`. + `DispatchContext.send_raw_request` raises `NoBackChannelError`. """ diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index 44ab622ad6..784ef6698f 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -67,16 +67,16 @@ async def running_pair( @pytest.mark.anyio -async def test_send_request_returns_result_from_peer_on_request(): +async def test_send_raw_request_returns_result_from_peer_on_request(): async with running_pair() as (client, _server, _crec, srec): with anyio.fail_after(5): - result = await client.send_request("tools/list", {"cursor": "abc"}) + result = await client.send_raw_request("tools/list", {"cursor": "abc"}) assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} assert srec.requests == [("tools/list", {"cursor": "abc"})] @pytest.mark.anyio -async def test_send_request_reraises_mcperror_from_handler_unchanged(): +async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -84,13 +84,13 @@ async def on_request( async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_request("tools/list", {}) + await client.send_raw_request("tools/list", {}) assert exc.value.error.code == INVALID_PARAMS assert exc.value.error.message == "bad cursor" @pytest.mark.anyio -async def test_send_request_wraps_non_mcperror_exception_as_internal_error(): +async def test_send_raw_request_wraps_non_mcperror_exception_as_internal_error(): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -98,13 +98,13 @@ async def on_request( async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_request("tools/list", {}) + await client.send_raw_request("tools/list", {}) assert exc.value.error.code == INTERNAL_ERROR assert isinstance(exc.value.__cause__, ValueError) @pytest.mark.anyio -async def test_send_request_with_timeout_raises_mcperror_request_timeout(): +async def test_send_raw_request_with_timeout_raises_mcperror_request_timeout(): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -113,7 +113,7 @@ async def on_request( async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_request("slow", None, {"timeout": 0}) + await client.send_raw_request("slow", None, {"timeout": 0}) assert exc.value.error.code == REQUEST_TIMEOUT @@ -127,32 +127,32 @@ async def test_notify_invokes_peer_on_notify(): @pytest.mark.anyio -async def test_ctx_send_request_round_trips_to_calling_side(): - """A handler's ctx.send_request reaches the side that made the inbound request.""" +async def test_ctx_send_raw_request_round_trips_to_calling_side(): + """A handler's ctx.send_raw_request reaches the side that made the inbound request.""" async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: - sample = await ctx.send_request("sampling/createMessage", {"prompt": "hi"}) + sample = await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) return {"sampled": sample} async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): - result = await client.send_request("tools/call", None) + result = await client.send_raw_request("tools/call", None) assert crec.requests == [("sampling/createMessage", {"prompt": "hi"})] assert result == {"sampled": {"echoed": "sampling/createMessage", "params": {"prompt": "hi"}}} @pytest.mark.anyio -async def test_ctx_send_request_raises_nobackchannelerror_when_transport_disallows(): +async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: - return await ctx.send_request("sampling/createMessage", None) + return await ctx.send_raw_request("sampling/createMessage", None) async with running_pair(server_on_request=server_on_request, can_send_request=False) as (client, *_): with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: - await client.send_request("tools/call", None) + await client.send_raw_request("tools/call", None) assert exc.value.method == "sampling/createMessage" assert exc.value.error.code == INVALID_REQUEST @@ -167,7 +167,7 @@ async def server_on_request( async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): - await client.send_request("tools/call", None) + await client.send_raw_request("tools/call", None) await crec.notified.wait() assert crec.notifications == [("notifications/message", {"level": "info"})] @@ -187,12 +187,12 @@ async def on_progress(progress: float, total: float | None, message: str | None) async with running_pair(server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - await client.send_request("tools/call", None, {"on_progress": on_progress}) + await client.send_raw_request("tools/call", None, {"on_progress": on_progress}) assert received == [(0.5, 1.0, "halfway")] @pytest.mark.anyio -async def test_send_request_issued_before_peer_run_blocks_until_peer_ready(): +async def test_send_raw_request_issued_before_peer_run_blocks_until_peer_ready(): client, server = create_direct_dispatcher_pair() s_req, s_notify = echo_handlers(Recorder()) c_req, c_notify = echo_handlers(Recorder()) @@ -205,7 +205,7 @@ async def late_start(): tg.start_soon(client.run, c_req, c_notify) tg.start_soon(late_start) with anyio.fail_after(5): - result = await client.send_request("ping", None) + result = await client.send_raw_request("ping", None) assert result == {"echoed": "ping", "params": {}} client.close() server.close() @@ -221,15 +221,15 @@ async def server_on_request( async with running_pair(server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - result = await client.send_request("tools/call", None) + result = await client.send_raw_request("tools/call", None) assert result == {"ok": True} @pytest.mark.anyio -async def test_send_request_and_notify_raise_runtimeerror_when_no_peer_connected(): +async def test_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_connected(): d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) with pytest.raises(RuntimeError, match="no peer"): - await d.send_request("ping", None) + await d.send_raw_request("ping", None) with pytest.raises(RuntimeError, match="no peer"): await d.notify("ping", None) From 163d38f9d9592000b98125a1d6c5e5e074c4c511 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 13:26:22 +0000 Subject: [PATCH 05/52] feat: JSONRPCDispatcher outbound side + parametrized contract tests Chunk (a) of JSONRPCDispatcher: constructor, _Pending/_InFlight/_JSONRPCDispatchContext, send_request/notify and helpers. run() is stubbed. The Dispatcher contract tests are now parametrized over a pair_factory fixture (direct + jsonrpc). The 9 jsonrpc cases are strict-xfail until run()/ _handle_request land in the next commits; once those pass, strict xfail flips to XPASS and forces removal of the marker. Factories return (client, server, close) so running_pair can shut down any implementation uniformly. --- src/mcp/shared/jsonrpc_dispatcher.py | 283 +++++++++++++++++++++++++++ tests/shared/conftest.py | 67 +++++++ tests/shared/test_dispatcher.py | 135 +++++++------ 3 files changed, 421 insertions(+), 64 deletions(-) create mode 100644 src/mcp/shared/jsonrpc_dispatcher.py create mode 100644 tests/shared/conftest.py diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py new file mode 100644 index 0000000000..2a6e0951b8 --- /dev/null +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -0,0 +1,283 @@ +"""JSON-RPC `Dispatcher` implementation. + +Consumes the existing `SessionMessage`-based stream contract that all current +transports (stdio, SSE, streamable HTTP) speak. Owns request-id correlation, +the receive loop, per-request task isolation, cancellation/progress wiring, and +the single exception-to-wire boundary. + +The MCP type layer (`ServerRunner`, `Context`, `Client`) sits above this and +sees only `(ctx, method, params) -> dict`. Transports sit below and see only +`SessionMessage` reads/writes. +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable, Mapping +from dataclasses import dataclass, field +from typing import Any, Generic, Literal, TypeVar, overload + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +from mcp.shared._stream_protocols import ReadStream, WriteStream +from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.message import ( + ClientMessageMetadata, + MessageMetadata, + ServerMessageMetadata, + SessionMessage, +) +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + REQUEST_TIMEOUT, + ErrorData, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + ProgressToken, + RequestId, +) + +__all__ = ["JSONRPCDispatcher"] + +logger = logging.getLogger(__name__) + +TransportT = TypeVar("TransportT", bound=TransportContext) + +PeerCancelMode = Literal["interrupt", "signal"] +"""How inbound ``notifications/cancelled`` is applied to a running handler. + +``"interrupt"`` (default) cancels the handler's scope. ``"signal"`` only sets +``ctx.cancel_requested`` and lets the handler observe it cooperatively. +""" + +TransportBuilder = Callable[[RequestId | None, MessageMetadata], TransportContext] +"""Builds the per-message `TransportContext` from the inbound JSON-RPC id and +the `SessionMessage.metadata` the transport attached. Defaults to a plain +`TransportContext(kind="jsonrpc", can_send_request=True)` when not supplied.""" + + +@dataclass(slots=True) +class _Pending: + """An outbound request awaiting its response.""" + + send: MemoryObjectSendStream[dict[str, Any] | ErrorData] + receive: MemoryObjectReceiveStream[dict[str, Any] | ErrorData] + on_progress: ProgressFnT | None = None + + +@dataclass(slots=True) +class _InFlight(Generic[TransportT]): + """An inbound request currently being handled.""" + + scope: anyio.CancelScope + dctx: _JSONRPCDispatchContext[TransportT] + cancelled_by_peer: bool = False + + +@dataclass +class _JSONRPCDispatchContext(Generic[TransportT]): + """Concrete `DispatchContext` produced for each inbound JSON-RPC message.""" + + transport: TransportT + _dispatcher: JSONRPCDispatcher[TransportT] + _request_id: RequestId | None + _progress_token: ProgressToken | None = None + _closed: bool = False + cancel_requested: anyio.Event = field(default_factory=anyio.Event) + + @property + def can_send_request(self) -> bool: + return self.transport.can_send_request and not self._closed + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._dispatcher.notify(method, params, _related_request_id=self._request_id) + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if not self.can_send_request: + raise NoBackChannelError(method) + return await self._dispatcher.send_request(method, params, opts, _related_request_id=self._request_id) + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + if self._progress_token is None: + return + params: dict[str, Any] = {"progressToken": self._progress_token, "progress": progress} + if total is not None: + params["total"] = total + if message is not None: + params["message"] = message + await self.notify("notifications/progress", params) + + def close(self) -> None: + self._closed = True + + +def _default_transport_builder(_request_id: RequestId | None, _meta: MessageMetadata) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=True) + + +def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | None) -> MessageMetadata: + """Choose the `SessionMessage.metadata` for an outgoing request/notification. + + `ServerMessageMetadata` tags a server-to-client message with the inbound + request it belongs to (so streamable-HTTP can route it onto that request's + SSE stream). `ClientMessageMetadata` carries resumption hints to the + client transport. ``None`` is the common case. + """ + if related_request_id is not None: + return ServerMessageMetadata(related_request_id=related_request_id) + if opts: + token = opts.get("resumption_token") + on_token = opts.get("on_resumption_token") + if token is not None or on_token is not None: + return ClientMessageMetadata(resumption_token=token, on_resumption_token_update=on_token) + return None + + +class JSONRPCDispatcher(Generic[TransportT]): + """`Dispatcher` over the existing `SessionMessage` stream contract.""" + + @overload + def __init__( + self: JSONRPCDispatcher[TransportContext], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + ) -> None: ... + @overload + def __init__( + self, + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + *, + transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT], + peer_cancel_mode: PeerCancelMode = "interrupt", + raise_handler_exceptions: bool = False, + ) -> None: ... + def __init__( + self, + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + *, + transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT] | None = None, + peer_cancel_mode: PeerCancelMode = "interrupt", + raise_handler_exceptions: bool = False, + ) -> None: + self._read_stream = read_stream + self._write_stream = write_stream + self._transport_builder = transport_builder or _default_transport_builder + self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode + self._raise_handler_exceptions = raise_handler_exceptions + + self._next_id = 0 + self._pending: dict[RequestId, _Pending] = {} + self._in_flight: dict[RequestId, _InFlight[TransportT]] = {} + self._running = False + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + *, + _related_request_id: RequestId | None = None, + ) -> dict[str, Any]: + """Send a JSON-RPC request and await its response. + + ``_related_request_id`` is set only by `_JSONRPCDispatchContext` when a + handler makes a server-to-client request mid-flight; it routes the + outgoing message onto the correct per-request SSE stream (SHTTP) via + `ServerMessageMetadata`. Top-level callers leave it ``None``. + + Raises: + MCPError: The peer responded with a JSON-RPC error; or + ``REQUEST_TIMEOUT`` if ``opts["timeout"]`` elapsed; or + ``CONNECTION_CLOSED`` if the dispatcher shut down while + awaiting the response. + RuntimeError: Called before ``run()`` has started or after it has + finished. + """ + if not self._running: + raise RuntimeError("JSONRPCDispatcher.send_request called before run() / after close") + opts = opts or {} + request_id = self._allocate_id() + out_params = dict(params) if params is not None else None + on_progress = opts.get("on_progress") + if on_progress is not None: + # The caller wants progress updates. The spec mechanism is: include + # `_meta.progressToken` on the request; the peer echoes that token on + # any `notifications/progress` it sends. We use the request id as the + # token so the receive loop can find this `_Pending.on_progress` by + # `_pending[token]` without a second lookup table. + meta = dict((out_params or {}).get("_meta") or {}) + meta["progressToken"] = request_id + out_params = {**(out_params or {}), "_meta": meta} + + send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + pending = _Pending(send=send, receive=receive, on_progress=on_progress) + self._pending[request_id] = pending + + metadata = _outbound_metadata(_related_request_id, opts) + msg = JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=out_params) + try: + await self._write(msg, metadata) + with anyio.fail_after(opts.get("timeout")): + outcome = await receive.receive() + except TimeoutError: + # Spec-recommended courtesy: tell the peer we've given up so it can + # stop work and free resources. v1's BaseSession.send_request does + # NOT do this; it's new behaviour. + await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s") + raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None + except anyio.get_cancelled_exc_class(): + # Our caller's scope was cancelled. We're already inside a cancelled + # scope, so any bare `await` here re-raises immediately — shield to + # let the courtesy cancel notification go out before we propagate. + with anyio.CancelScope(shield=True): + await self._cancel_outbound(request_id, "caller cancelled") + raise + finally: + # Always remove the waiter, even on cancel/timeout, so a late + # response from the peer (race) hits a closed stream and is dropped + # in `_dispatch` rather than leaking. + self._pending.pop(request_id, None) + send.close() + receive.close() + + if isinstance(outcome, ErrorData): + raise MCPError(code=outcome.code, message=outcome.message, data=outcome.data) + return outcome + + async def notify( + self, + method: str, + params: Mapping[str, Any] | None, + *, + _related_request_id: RequestId | None = None, + ) -> None: + msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params) if params is not None else None) + await self._write(msg, _outbound_metadata(_related_request_id, None)) + + async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: + raise NotImplementedError # chunk (b) + + def _allocate_id(self) -> int: + self._next_id += 1 + return self._next_id + + async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None) -> None: + await self._write_stream.send(SessionMessage(message=message, metadata=metadata)) + + async def _cancel_outbound(self, request_id: RequestId, reason: str) -> None: + try: + await self.notify("notifications/cancelled", {"requestId": request_id, "reason": reason}) + except anyio.BrokenResourceError: + pass + except anyio.ClosedResourceError: + pass diff --git a/tests/shared/conftest.py b/tests/shared/conftest.py new file mode 100644 index 0000000000..ffa254804e --- /dev/null +++ b/tests/shared/conftest.py @@ -0,0 +1,67 @@ +"""Shared fixtures for `Dispatcher` contract tests. + +The `pair_factory` fixture parametrizes contract tests over every `Dispatcher` +implementation, so the same behavioral assertions run against `DirectDispatcher` +(in-memory) and `JSONRPCDispatcher` (over crossed anyio memory streams). +""" + +from collections.abc import Callable + +import anyio +import pytest + +from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +from mcp.shared.dispatcher import Dispatcher +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import SessionMessage +from mcp.shared.transport_context import TransportContext + +DispatcherTriple = tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Callable[[], None]] +PairFactory = Callable[..., DispatcherTriple] + + +def direct_pair(*, can_send_request: bool = True) -> DispatcherTriple: + client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) + + def close() -> None: + client.close() + server.close() + + return client, server, close + + +def jsonrpc_pair(*, can_send_request: bool = True) -> DispatcherTriple: + """Two `JSONRPCDispatcher`s wired over crossed in-memory streams.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + + def builder(_rid: object, _meta: object) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=can_send_request) + + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send, transport_builder=builder) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send, transport_builder=builder) + + def close() -> None: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + return client, server, close + + +_JSONRPC_XFAIL = pytest.mark.xfail( + strict=True, + reason="JSONRPCDispatcher.run() not yet implemented (PR2 chunks b/c)", +) + + +@pytest.fixture( + params=[ + pytest.param(direct_pair, id="direct"), + pytest.param(jsonrpc_pair, id="jsonrpc", marks=_JSONRPC_XFAIL), + ] +) +def pair_factory(request: pytest.FixtureRequest) -> PairFactory: + return request.param + + +__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair"] diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index 784ef6698f..31fba3dd5d 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -1,8 +1,9 @@ -"""Behavioral tests for the Dispatcher Protocol via DirectDispatcher. +"""Behavioral tests for the Dispatcher Protocol. -These exercise the `Dispatcher` / `DispatchContext` contract end-to-end using -the in-memory `DirectDispatcher`. JSON-RPC framing is covered separately in -``test_jsonrpc_dispatcher.py``. +The contract tests are parametrized over every `Dispatcher` implementation via +the `pair_factory` fixture (see ``conftest.py``); they must pass for both +`DirectDispatcher` and `JSONRPCDispatcher`. Implementation-specific tests pass +a concrete factory directly. """ from collections.abc import AsyncIterator, Mapping @@ -14,10 +15,12 @@ from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnRequest, Outbound -from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT +from .conftest import PairFactory, direct_pair + class Recorder: def __init__(self) -> None: @@ -44,31 +47,34 @@ async def on_notify(ctx: DispatchContext[TransportContext], method: str, params: @asynccontextmanager async def running_pair( + factory: PairFactory, *, server_on_request: OnRequest | None = None, server_on_notify: OnNotify | None = None, client_on_request: OnRequest | None = None, client_on_notify: OnNotify | None = None, can_send_request: bool = True, -) -> AsyncIterator[tuple[DirectDispatcher, DirectDispatcher, Recorder, Recorder]]: +) -> AsyncIterator[tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Recorder, Recorder]]: """Yield ``(client, server, client_recorder, server_recorder)`` with both ``run()`` loops live.""" - client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) + client, server, close = factory(can_send_request=can_send_request) client_rec, server_rec = Recorder(), Recorder() c_req, c_notify = echo_handlers(client_rec) s_req, s_notify = echo_handlers(server_rec) - async with anyio.create_task_group() as tg: - tg.start_soon(client.run, client_on_request or c_req, client_on_notify or c_notify) - tg.start_soon(server.run, server_on_request or s_req, server_on_notify or s_notify) - try: - yield client, server, client_rec, server_rec - finally: - client.close() - server.close() + try: + async with anyio.create_task_group() as tg: + tg.start_soon(client.run, client_on_request or c_req, client_on_notify or c_notify) + tg.start_soon(server.run, server_on_request or s_req, server_on_notify or s_notify) + try: + yield client, server, client_rec, server_rec + finally: + tg.cancel_scope.cancel() + finally: + close() @pytest.mark.anyio -async def test_send_raw_request_returns_result_from_peer_on_request(): - async with running_pair() as (client, _server, _crec, srec): +async def test_send_raw_request_returns_result_from_peer_on_request(pair_factory: PairFactory): + async with running_pair(pair_factory) as (client, _server, _crec, srec): with anyio.fail_after(5): result = await client.send_raw_request("tools/list", {"cursor": "abc"}) assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} @@ -76,13 +82,13 @@ async def test_send_raw_request_returns_result_from_peer_on_request(): @pytest.mark.anyio -async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(): +async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(pair_factory: PairFactory): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: raise MCPError(code=INVALID_PARAMS, message="bad cursor") - async with running_pair(server_on_request=on_request) as (client, *_): + async with running_pair(pair_factory, server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", {}) assert exc.value.error.code == INVALID_PARAMS @@ -90,36 +96,22 @@ async def on_request( @pytest.mark.anyio -async def test_send_raw_request_wraps_non_mcperror_exception_as_internal_error(): - async def on_request( - ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None - ) -> dict[str, Any]: - raise ValueError("oops") - - async with running_pair(server_on_request=on_request) as (client, *_): - with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_raw_request("tools/list", {}) - assert exc.value.error.code == INTERNAL_ERROR - assert isinstance(exc.value.__cause__, ValueError) - - -@pytest.mark.anyio -async def test_send_raw_request_with_timeout_raises_mcperror_request_timeout(): +async def test_send_raw_request_with_timeout_raises_mcperror_request_timeout(pair_factory: PairFactory): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await anyio.sleep_forever() raise NotImplementedError - async with running_pair(server_on_request=on_request) as (client, *_): + async with running_pair(pair_factory, server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("slow", None, {"timeout": 0}) assert exc.value.error.code == REQUEST_TIMEOUT @pytest.mark.anyio -async def test_notify_invokes_peer_on_notify(): - async with running_pair() as (client, _server, _crec, srec): +async def test_notify_invokes_peer_on_notify(pair_factory: PairFactory): + async with running_pair(pair_factory) as (client, _server, _crec, srec): with anyio.fail_after(5): await client.notify("notifications/initialized", {"v": 1}) await srec.notified.wait() @@ -127,7 +119,7 @@ async def test_notify_invokes_peer_on_notify(): @pytest.mark.anyio -async def test_ctx_send_raw_request_round_trips_to_calling_side(): +async def test_ctx_send_raw_request_round_trips_to_calling_side(pair_factory: PairFactory): """A handler's ctx.send_raw_request reaches the side that made the inbound request.""" async def server_on_request( @@ -136,7 +128,7 @@ async def server_on_request( sample = await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) return {"sampled": sample} - async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): result = await client.send_raw_request("tools/call", None) assert crec.requests == [("sampling/createMessage", {"prompt": "hi"})] @@ -144,28 +136,27 @@ async def server_on_request( @pytest.mark.anyio -async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(): +async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(pair_factory: PairFactory): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: return await ctx.send_raw_request("sampling/createMessage", None) - async with running_pair(server_on_request=server_on_request, can_send_request=False) as (client, *_): - with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: + async with running_pair(pair_factory, server_on_request=server_on_request, can_send_request=False) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("tools/call", None) - assert exc.value.method == "sampling/createMessage" assert exc.value.error.code == INVALID_REQUEST @pytest.mark.anyio -async def test_ctx_notify_invokes_calling_side_on_notify(): +async def test_ctx_notify_invokes_calling_side_on_notify(pair_factory: PairFactory): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await ctx.notify("notifications/message", {"level": "info"}) return {} - async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): await client.send_raw_request("tools/call", None) await crec.notified.wait() @@ -173,7 +164,7 @@ async def server_on_request( @pytest.mark.anyio -async def test_ctx_progress_invokes_caller_on_progress_callback(): +async def test_ctx_progress_invokes_caller_on_progress_callback(pair_factory: PairFactory): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -185,14 +176,44 @@ async def server_on_request( async def on_progress(progress: float, total: float | None, message: str | None) -> None: received.append((progress, total, message)) - async with running_pair(server_on_request=server_on_request) as (client, *_): + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): await client.send_raw_request("tools/call", None, {"on_progress": on_progress}) assert received == [(0.5, 1.0, "halfway")] @pytest.mark.anyio -async def test_send_raw_request_issued_before_peer_run_blocks_until_peer_ready(): +async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(pair_factory: PairFactory): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5) + return {"ok": True} + + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/call", None) + assert result == {"ok": True} + + +@pytest.mark.anyio +async def test_direct_send_raw_request_wraps_non_mcperror_exception_as_internal_error_with_cause(): + """DirectDispatcher-specific: the original exception is chained via __cause__.""" + + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise ValueError("oops") + + async with running_pair(direct_pair, server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", {}) + assert exc.value.error.code == INTERNAL_ERROR + assert isinstance(exc.value.__cause__, ValueError) + + +@pytest.mark.anyio +async def test_direct_send_raw_request_issued_before_peer_run_blocks_until_peer_ready(): client, server = create_direct_dispatcher_pair() s_req, s_notify = echo_handlers(Recorder()) c_req, c_notify = echo_handlers(Recorder()) @@ -212,21 +233,7 @@ async def late_start(): @pytest.mark.anyio -async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(): - async def server_on_request( - ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None - ) -> dict[str, Any]: - await ctx.progress(0.5) - return {"ok": True} - - async with running_pair(server_on_request=server_on_request) as (client, *_): - with anyio.fail_after(5): - result = await client.send_raw_request("tools/call", None) - assert result == {"ok": True} - - -@pytest.mark.anyio -async def test_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_connected(): +async def test_direct_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_connected(): d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) with pytest.raises(RuntimeError, match="no peer"): await d.send_raw_request("ping", None) @@ -235,7 +242,7 @@ async def test_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_conne @pytest.mark.anyio -async def test_close_makes_run_return(): +async def test_direct_close_makes_run_return(): client, server = create_direct_dispatcher_pair() on_request, on_notify = echo_handlers(Recorder()) with anyio.fail_after(5): From 4739744686b183e880ea680a72f69728b5e79b3d Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 14:59:11 +0000 Subject: [PATCH 06/52] feat: JSONRPCDispatcher receive loop and dispatch (chunk b) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit run() drives the receive loop in a per-request task group; task_status.started() fires once send_request is usable. _dispatch routes each inbound message synchronously (no awaits — send_nowait/_spawn only) to avoid head-of-line blocking. _spawn propagates the sender's contextvars via Context.run(tg.start_soon, ...) so auth/OTel set by ASGI middleware survive. _fan_out_closed wakes pending send_request waiters with CONNECTION_CLOSED on shutdown (called both post-EOF and in finally; idempotent). Wire-param extraction (progressToken, cancelled.requestId, progress fields) uses structural match patterns — runtime narrowing, no casts, no mcp.types model coupling; malformed input fails to match and the correlation is skipped. _handle_request is happy-path only here (run on_request, write response); the exception-to-wire boundary lands in the next commit. Dispatcher.run() Protocol gained a task_status kwarg (it's a contract-level guarantee). DirectDispatcher.run() updated to match. running_pair now uses tg.start so the test body runs only once the dispatcher is ready. 20 contract tests pass; the 2 needing the exception boundary are strict-xfail. --- src/mcp/shared/direct_dispatcher.py | 10 +- src/mcp/shared/dispatcher.py | 13 +- src/mcp/shared/jsonrpc_dispatcher.py | 234 ++++++++++++++++++++++++++- tests/shared/conftest.py | 22 ++- tests/shared/test_dispatcher.py | 18 ++- 5 files changed, 274 insertions(+), 23 deletions(-) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index bb5639a136..27443ec874 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -20,6 +20,7 @@ from typing import Any import anyio +import anyio.abc from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError @@ -101,10 +102,17 @@ async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") await self._peer._dispatch_notify(method, params) - async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: self._on_request = on_request self._on_notify = on_notify self._ready.set() + task_status.started() await self._closed.wait() def close(self) -> None: diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index ee02e23896..20c090323b 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -20,6 +20,7 @@ from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable import anyio +import anyio.abc from mcp.shared.transport_context import TransportContext @@ -136,11 +137,21 @@ class Dispatcher(Outbound, Protocol[TransportT_co]): receive loop, per-request concurrency, and cancellation/progress wiring. """ - async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: """Drive the receive loop until the underlying channel closes. Each inbound request is dispatched to ``on_request`` in its own task; the returned dict (or raised ``MCPError``) is sent back as the response. Inbound notifications go to ``on_notify``. + + ``task_status.started()`` is called once the dispatcher is ready to + accept ``send_request``/``notify`` calls, so callers can use + ``await tg.start(dispatcher.run, on_request, on_notify)``. """ ... diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 2a6e0951b8..6bf957c19a 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -8,20 +8,30 @@ The MCP type layer (`ServerRunner`, `Context`, `Client`) sits above this and sees only `(ctx, method, params) -> dict`. Transports sit below and see only `SessionMessage` reads/writes. + +The dispatcher is *mostly* MCP-agnostic — methods/params are opaque strings and +dicts — but it intercepts ``notifications/cancelled`` and +``notifications/progress`` because request correlation, cancellation and +progress are exactly the wiring this layer exists to provide. Those few wire +shapes are extracted with structural ``match`` patterns (no casts, no +``mcp.types`` model coupling); a malformed payload simply fails to match and +the correlation is skipped. """ from __future__ import annotations +import contextvars import logging -from collections.abc import Callable, Mapping +from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass, field -from typing import Any, Generic, Literal, TypeVar, overload +from typing import Any, Generic, Literal, TypeVar, cast, overload import anyio +import anyio.abc from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.shared._stream_protocols import ReadStream, WriteStream -from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT +from mcp.shared.dispatcher import CallOptions, Dispatcher, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.message import ( ClientMessageMetadata, @@ -31,11 +41,14 @@ ) from mcp.shared.transport_context import TransportContext from mcp.types import ( + CONNECTION_CLOSED, REQUEST_TIMEOUT, ErrorData, + JSONRPCError, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, + JSONRPCResponse, ProgressToken, RequestId, ) @@ -141,8 +154,12 @@ def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | return None -class JSONRPCDispatcher(Generic[TransportT]): - """`Dispatcher` over the existing `SessionMessage` stream contract.""" +class JSONRPCDispatcher(Dispatcher[TransportT]): + """`Dispatcher` over the existing `SessionMessage` stream contract. + + Inherits the `Dispatcher` Protocol explicitly so pyright checks + conformance at the class definition rather than at first use. + """ @overload def __init__( @@ -171,13 +188,20 @@ def __init__( ) -> None: self._read_stream = read_stream self._write_stream = write_stream - self._transport_builder = transport_builder or _default_transport_builder + # The overloads guarantee that when `transport_builder` is omitted, + # `TransportT` is `TransportContext`, so the default is type-correct; + # pyright can't see across overloads, hence the cast. + self._transport_builder = cast( + "Callable[[RequestId | None, MessageMetadata], TransportT]", + transport_builder or _default_transport_builder, + ) self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode self._raise_handler_exceptions = raise_handler_exceptions self._next_id = 0 self._pending: dict[RequestId, _Pending] = {} self._in_flight: dict[RequestId, _InFlight[TransportT]] = {} + self._tg: anyio.abc.TaskGroup | None = None self._running = False async def send_request( @@ -219,6 +243,11 @@ async def send_request( meta["progressToken"] = request_id out_params = {**(out_params or {}), "_meta": meta} + # buffer=1: at most one outcome is ever delivered. A `WouldBlock` from + # `_resolve_pending`/`_fan_out_closed` means the waiter already has an + # outcome and dropping the late/redundant signal is correct. buffer=0 + # is unsafe — there's a window between registering `_pending[id]` and + # parking in `receive()` where a close signal would be lost. send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) pending = _Pending(send=send, receive=receive, on_progress=on_progress) self._pending[request_id] = pending @@ -264,8 +293,197 @@ async def notify( msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params) if params is not None else None) await self._write(msg, _outbound_metadata(_related_request_id, None)) - async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: - raise NotImplementedError # chunk (b) + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + """Drive the receive loop until the read stream closes. + + Each inbound request is handled in its own task in an internal task + group; ``task_status.started()`` fires once that group is open, so + ``await tg.start(dispatcher.run, ...)`` resumes when ``send_request`` + is usable. + """ + try: + async with anyio.create_task_group() as tg: + self._tg = tg + self._running = True + task_status.started() + async with self._read_stream: + async for item in self._read_stream: + # Duck-typed: `_context_streams.ContextReceiveStream` + # exposes `.last_context` (the sender's contextvars + # snapshot per message). Plain memory streams don't. + sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None) + self._dispatch(item, on_request, on_notify, sender_ctx) + # Read stream EOF: wake any blocked `send_request` waiters now, + # *before* the task group joins, so handlers parked in + # `dctx.send_request()` can unwind and the join doesn't deadlock. + self._running = False + self._fan_out_closed() + finally: + # Covers the cancel/crash paths where the inline fan-out above is + # never reached. Idempotent. + self._running = False + self._tg = None + self._fan_out_closed() + + def _dispatch( + self, + item: SessionMessage | Exception, + on_request: OnRequest, + on_notify: OnNotify, + sender_ctx: contextvars.Context | None, + ) -> None: + """Route one inbound item. Synchronous: never awaits. + + Everything here is `send_nowait` or `_spawn`. An `await` would let one + slow message head-of-line block the entire read loop. + """ + if isinstance(item, Exception): + logger.debug("transport yielded exception: %r", item) + return + metadata = item.metadata + msg = item.message + match msg: + case JSONRPCRequest(): + self._dispatch_request(msg, metadata, on_request, sender_ctx) + case JSONRPCNotification(): + self._dispatch_notification(msg, metadata, on_notify, sender_ctx) + case JSONRPCResponse(): + self._resolve_pending(msg.id, msg.result) + case JSONRPCError(): + # `id` may be None per JSON-RPC (parse error before id known). + self._resolve_pending(msg.id, msg.error) + + def _dispatch_request( + self, + req: JSONRPCRequest, + metadata: MessageMetadata, + on_request: OnRequest, + sender_ctx: contextvars.Context | None, + ) -> None: + progress_token: ProgressToken | None + match req.params: + case {"_meta": {"progressToken": str() | int() as progress_token}}: + pass + case _: + progress_token = None + transport_ctx = self._transport_builder(req.id, metadata) + dctx = _JSONRPCDispatchContext( + transport=transport_ctx, + _dispatcher=self, + _request_id=req.id, + _progress_token=progress_token, + ) + scope = anyio.CancelScope() + self._in_flight[req.id] = _InFlight(scope=scope, dctx=dctx) + self._spawn(self._handle_request, req, dctx, scope, on_request, sender_ctx=sender_ctx) + + def _dispatch_notification( + self, + msg: JSONRPCNotification, + metadata: MessageMetadata, + on_notify: OnNotify, + sender_ctx: contextvars.Context | None, + ) -> None: + if msg.method == "notifications/cancelled": + match msg.params: + case {"requestId": str() | int() as rid} if (in_flight := self._in_flight.get(rid)) is not None: + in_flight.cancelled_by_peer = True + in_flight.dctx.cancel_requested.set() + if self._peer_cancel_mode == "interrupt": + in_flight.scope.cancel() + case _: + pass + return + if msg.method == "notifications/progress": + match msg.params: + case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if ( + pending := self._pending.get(token) + ) is not None and pending.on_progress is not None: + total = msg.params.get("total") + message = msg.params.get("message") + self._spawn( + pending.on_progress, + float(progress), + float(total) if isinstance(total, int | float) else None, + message if isinstance(message, str) else None, + sender_ctx=sender_ctx, + ) + case _: + pass + # fall through: progress is also teed to on_notify + transport_ctx = self._transport_builder(None, metadata) + dctx = _JSONRPCDispatchContext(transport=transport_ctx, _dispatcher=self, _request_id=None) + self._spawn(on_notify, dctx, msg.method, msg.params, sender_ctx=sender_ctx) + + def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None: + pending = self._pending.get(request_id) if request_id is not None else None + if pending is None: + logger.debug("dropping response for unknown/late request id %r", request_id) + return + try: + pending.send.send_nowait(outcome) + except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("waiter for request id %r already gone", request_id) + + def _spawn( + self, + fn: Callable[..., Awaitable[Any]], + *args: object, + sender_ctx: contextvars.Context | None, + ) -> None: + """Schedule ``fn(*args)`` in the run() task group, propagating the sender's contextvars. + + ASGI middleware (auth, OTel) sets contextvars on the request task that + wrote into the read stream. ``Context.run(tg.start_soon, ...)`` makes + the spawned handler inherit *that* context instead of the receive + loop's, so ``auth_context_var`` and OTel spans survive. + """ + assert self._tg is not None + if sender_ctx is not None: + sender_ctx.run(self._tg.start_soon, fn, *args) + else: + self._tg.start_soon(fn, *args) + + def _fan_out_closed(self) -> None: + """Wake every pending ``send_request`` waiter with ``CONNECTION_CLOSED``. + + Synchronous (uses ``send_nowait``) because it's called from ``finally`` + which may be inside a cancelled scope. Idempotent. + """ + closed = ErrorData(code=CONNECTION_CLOSED, message="connection closed") + for pending in self._pending.values(): + try: + pending.send.send_nowait(closed) + except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError): + pass + self._pending.clear() + + async def _handle_request( + self, + req: JSONRPCRequest, + dctx: _JSONRPCDispatchContext[TransportT], + scope: anyio.CancelScope, + on_request: OnRequest, + ) -> None: + """Run ``on_request`` for one inbound request and write its response. + + Chunk (b): happy-path only. The full exception-to-wire boundary + (MCPError, ValidationError, INTERNAL_ERROR scrubbing, peer-cancel + no-response) lands in chunk (c). + """ + try: + with scope: + result = await on_request(dctx, req.method, req.params) + await self._write(JSONRPCResponse(jsonrpc="2.0", id=req.id, result=result)) + finally: + self._in_flight.pop(req.id, None) + dctx.close() def _allocate_id(self) -> int: self._next_id += 1 diff --git a/tests/shared/conftest.py b/tests/shared/conftest.py index ffa254804e..b7049493a2 100644 --- a/tests/shared/conftest.py +++ b/tests/shared/conftest.py @@ -48,20 +48,26 @@ def close() -> None: return client, server, close -_JSONRPC_XFAIL = pytest.mark.xfail( - strict=True, - reason="JSONRPCDispatcher.run() not yet implemented (PR2 chunks b/c)", -) - - @pytest.fixture( params=[ pytest.param(direct_pair, id="direct"), - pytest.param(jsonrpc_pair, id="jsonrpc", marks=_JSONRPC_XFAIL), + pytest.param(jsonrpc_pair, id="jsonrpc"), ] ) def pair_factory(request: pytest.FixtureRequest) -> PairFactory: return request.param -__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair"] +def xfail_jsonrpc_chunk_c(request: pytest.FixtureRequest, factory: PairFactory) -> None: + """Apply a strict xfail when running against the JSON-RPC dispatcher. + + Use for contract tests that require `_handle_request`'s exception boundary + (PR2 chunk c). Remove once that lands. + """ + if factory is jsonrpc_pair: + request.applymarker( + pytest.mark.xfail(strict=True, reason="needs JSONRPCDispatcher._handle_request exception boundary") + ) + + +__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair", "xfail_jsonrpc_chunk_c"] diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index 31fba3dd5d..aef6b60bcb 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -19,7 +19,7 @@ from mcp.shared.transport_context import TransportContext from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT -from .conftest import PairFactory, direct_pair +from .conftest import PairFactory, direct_pair, xfail_jsonrpc_chunk_c class Recorder: @@ -62,8 +62,8 @@ async def running_pair( s_req, s_notify = echo_handlers(server_rec) try: async with anyio.create_task_group() as tg: - tg.start_soon(client.run, client_on_request or c_req, client_on_notify or c_notify) - tg.start_soon(server.run, server_on_request or s_req, server_on_notify or s_notify) + await tg.start(client.run, client_on_request or c_req, client_on_notify or c_notify) + await tg.start(server.run, server_on_request or s_req, server_on_notify or s_notify) try: yield client, server, client_rec, server_rec finally: @@ -82,7 +82,11 @@ async def test_send_raw_request_returns_result_from_peer_on_request(pair_factory @pytest.mark.anyio -async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(pair_factory: PairFactory): +async def test_send_raw_request_reraises_mcperror_from_handler_unchanged( + pair_factory: PairFactory, request: pytest.FixtureRequest +): + xfail_jsonrpc_chunk_c(request, pair_factory) + async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -136,7 +140,11 @@ async def server_on_request( @pytest.mark.anyio -async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(pair_factory: PairFactory): +async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows( + pair_factory: PairFactory, request: pytest.FixtureRequest +): + xfail_jsonrpc_chunk_c(request, pair_factory) + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: From 62a555b1680f8935204df9fdede29cff799c8435 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 15:20:43 +0000 Subject: [PATCH 07/52] feat: JSONRPCDispatcher exception boundary (chunk c) _handle_request is now the single exception-to-wire boundary: - MCPError -> JSONRPCError(e.error) - pydantic ValidationError -> INVALID_PARAMS - Exception -> INTERNAL_ERROR(str(e)), logged, optionally re-raised - outer-cancel (run() TG shutdown) -> shielded REQUEST_CANCELLED write, re-raise - peer-cancel (notifications/cancelled) -> scope swallows, no response written dctx.close() runs in an inner finally so the back-channel shuts the moment the handler exits. _write_result/_write_error swallow Broken/ClosedResourceError so a dropped connection during the response write doesn't crash the dispatcher. All 22 contract tests now pass against both DirectDispatcher and JSONRPCDispatcher; chunk-c xfail markers removed. --- src/mcp/shared/jsonrpc_dispatcher.py | 54 ++++++++++++++++++++++++---- tests/shared/conftest.py | 14 +------- tests/shared/test_dispatcher.py | 14 ++------ 3 files changed, 52 insertions(+), 30 deletions(-) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 6bf957c19a..f35b37cf9d 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -29,6 +29,7 @@ import anyio import anyio.abc from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import ValidationError from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.dispatcher import CallOptions, Dispatcher, OnNotify, OnRequest, ProgressFnT @@ -42,6 +43,9 @@ from mcp.shared.transport_context import TransportContext from mcp.types import ( CONNECTION_CLOSED, + INTERNAL_ERROR, + INVALID_PARAMS, + REQUEST_CANCELLED, REQUEST_TIMEOUT, ErrorData, JSONRPCError, @@ -473,17 +477,43 @@ async def _handle_request( ) -> None: """Run ``on_request`` for one inbound request and write its response. - Chunk (b): happy-path only. The full exception-to-wire boundary - (MCPError, ValidationError, INTERNAL_ERROR scrubbing, peer-cancel - no-response) lands in chunk (c). + This is the single exception-to-wire boundary: handler exceptions are + caught here and serialized to ``JSONRPCError``. Nothing above this in + the stack constructs wire errors. """ try: with scope: - result = await on_request(dctx, req.method, req.params) - await self._write(JSONRPCResponse(jsonrpc="2.0", id=req.id, result=result)) + try: + result = await on_request(dctx, req.method, req.params) + finally: + # Close the back-channel the moment the handler exits + # (success or raise), before the response write — a handler + # spawning detached work that later calls + # `dctx.send_request()` should see `NoBackChannelError`. + dctx.close() + await self._write_result(req.id, result) + # Peer-cancel: `_dispatch_notification` cancelled this scope. anyio + # swallows a scope's *own* cancel at __exit__, so the result write + # (or the handler) is interrupted and execution lands here without + # reaching the `except cancelled` arm below. Spec SHOULD: send no + # response — fall through to `finally`. + except anyio.get_cancelled_exc_class(): + # Outer-cancel: run()'s task group is shutting down. Any bare + # `await` here re-raises immediately, so shield the courtesy write. + with anyio.CancelScope(shield=True): + await self._write_error(req.id, ErrorData(code=REQUEST_CANCELLED, message="Request cancelled")) + raise + except MCPError as e: + await self._write_error(req.id, e.error) + except ValidationError as e: + await self._write_error(req.id, ErrorData(code=INVALID_PARAMS, message=str(e))) + except Exception as e: + logger.exception("handler for %r raised", req.method) + await self._write_error(req.id, ErrorData(code=INTERNAL_ERROR, message=str(e))) + if self._raise_handler_exceptions: + raise finally: self._in_flight.pop(req.id, None) - dctx.close() def _allocate_id(self) -> int: self._next_id += 1 @@ -492,6 +522,18 @@ def _allocate_id(self) -> int: async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None) -> None: await self._write_stream.send(SessionMessage(message=message, metadata=metadata)) + async def _write_result(self, request_id: RequestId, result: dict[str, Any]) -> None: + try: + await self._write(JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result)) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped result for %r: write stream closed", request_id) + + async def _write_error(self, request_id: RequestId, error: ErrorData) -> None: + try: + await self._write(JSONRPCError(jsonrpc="2.0", id=request_id, error=error)) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped error for %r: write stream closed", request_id) + async def _cancel_outbound(self, request_id: RequestId, reason: str) -> None: try: await self.notify("notifications/cancelled", {"requestId": request_id, "reason": reason}) diff --git a/tests/shared/conftest.py b/tests/shared/conftest.py index b7049493a2..1222c05aba 100644 --- a/tests/shared/conftest.py +++ b/tests/shared/conftest.py @@ -58,16 +58,4 @@ def pair_factory(request: pytest.FixtureRequest) -> PairFactory: return request.param -def xfail_jsonrpc_chunk_c(request: pytest.FixtureRequest, factory: PairFactory) -> None: - """Apply a strict xfail when running against the JSON-RPC dispatcher. - - Use for contract tests that require `_handle_request`'s exception boundary - (PR2 chunk c). Remove once that lands. - """ - if factory is jsonrpc_pair: - request.applymarker( - pytest.mark.xfail(strict=True, reason="needs JSONRPCDispatcher._handle_request exception boundary") - ) - - -__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair", "xfail_jsonrpc_chunk_c"] +__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair"] diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index aef6b60bcb..fc967c1299 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -19,7 +19,7 @@ from mcp.shared.transport_context import TransportContext from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT -from .conftest import PairFactory, direct_pair, xfail_jsonrpc_chunk_c +from .conftest import PairFactory, direct_pair class Recorder: @@ -82,11 +82,7 @@ async def test_send_raw_request_returns_result_from_peer_on_request(pair_factory @pytest.mark.anyio -async def test_send_raw_request_reraises_mcperror_from_handler_unchanged( - pair_factory: PairFactory, request: pytest.FixtureRequest -): - xfail_jsonrpc_chunk_c(request, pair_factory) - +async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(pair_factory: PairFactory): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -140,11 +136,7 @@ async def server_on_request( @pytest.mark.anyio -async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows( - pair_factory: PairFactory, request: pytest.FixtureRequest -): - xfail_jsonrpc_chunk_c(request, pair_factory) - +async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(pair_factory: PairFactory): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: From 72913ad6bacc0020df331be0ca3671c9b2831722 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 15:46:17 +0000 Subject: [PATCH 08/52] test: JSON-RPC-specific dispatcher tests + coverage to 100% Covers behaviors with no DirectDispatcher analog: out-of-order response correlation, INTERNAL_ERROR over the wire, peer-cancel in interrupt and signal modes, CONNECTION_CLOSED on stream EOF mid-await, late-response drop, raise_handler_exceptions propagation, ServerMessageMetadata tagging on ctx.send_request, null-id JSONRPCError drop, ValidationError->INVALID_PARAMS, contextvar propagation via _spawn, and the defensive Broken/Closed/WouldBlock catches. Two small src tweaks for coverage: - _cancel_outbound: combine the two except arms into one tuple - _dispatch: pragma no-branch on the final case (match is exhaustive over JSONRPCMessage; the no-match arc is unreachable) 43 tests, 100% coverage on all PR2 modules, 0.15s wall-clock. --- src/mcp/shared/jsonrpc_dispatcher.py | 8 +- tests/shared/test_jsonrpc_dispatcher.py | 531 ++++++++++++++++++++++++ 2 files changed, 535 insertions(+), 4 deletions(-) create mode 100644 tests/shared/test_jsonrpc_dispatcher.py diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index f35b37cf9d..bbf5666069 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -359,8 +359,10 @@ def _dispatch( self._dispatch_notification(msg, metadata, on_notify, sender_ctx) case JSONRPCResponse(): self._resolve_pending(msg.id, msg.result) - case JSONRPCError(): + case JSONRPCError(): # pragma: no branch # `id` may be None per JSON-RPC (parse error before id known). + # The match is exhaustive over JSONRPCMessage; the no-match arc + # on this final case is unreachable. self._resolve_pending(msg.id, msg.error) def _dispatch_request( @@ -537,7 +539,5 @@ async def _write_error(self, request_id: RequestId, error: ErrorData) -> None: async def _cancel_outbound(self, request_id: RequestId, reason: str) -> None: try: await self.notify("notifications/cancelled", {"requestId": request_id, "reason": reason}) - except anyio.BrokenResourceError: - pass - except anyio.ClosedResourceError: + except (anyio.BrokenResourceError, anyio.ClosedResourceError): pass diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py new file mode 100644 index 0000000000..ff24ef4c6b --- /dev/null +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -0,0 +1,531 @@ +"""JSON-RPC-specific Dispatcher tests. + +Behaviors with no `DirectDispatcher` analog: request-id correlation, the +exception-to-wire boundary, peer-cancel handling, and shutdown fan-out. +The contract tests shared with `DirectDispatcher` live in +``test_dispatcher.py``. +""" + +import contextvars +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.exceptions import MCPError +from mcp.shared.jsonrpc_dispatcher import ( # pyright: ignore[reportPrivateUsage] + JSONRPCDispatcher, + _outbound_metadata, + _Pending, +) +from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + CONNECTION_CLOSED, + INTERNAL_ERROR, + INVALID_PARAMS, + ErrorData, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + Tool, +) + +from .conftest import jsonrpc_pair +from .test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@pytest.mark.anyio +async def test_concurrent_send_requests_correlate_by_id_when_responses_arrive_out_of_order(): + release_first = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + if method == "first": + await release_first.wait() + return {"m": method} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + results: dict[str, dict[str, Any]] = {} + + async def call(method: str) -> None: + results[method] = await client.send_request(method, None) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(call, "first") + await anyio.sleep(0) + tg.start_soon(call, "second") + await anyio.sleep(0) + # second resolves while first is still parked + assert "first" not in results + release_first.set() + assert results == {"first": {"m": "first"}, "second": {"m": "second"}} + + +@pytest.mark.anyio +async def test_handler_raising_exception_sends_internal_error_with_str_message(): + """Per design: INTERNAL_ERROR carries str(e), not a scrubbed message.""" + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + raise RuntimeError("kaboom") + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_request("tools/list", None) + assert exc.value.error.code == INTERNAL_ERROR + assert exc.value.error.message == "kaboom" + assert exc.value.__cause__ is None # cause does not survive the wire + + +@pytest.mark.anyio +async def test_peer_cancel_interrupt_mode_sets_cancel_requested_and_sends_no_response(): + handler_started = anyio.Event() + handler_exited = anyio.Event() + seen_ctx: list[DCtx] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen_ctx.append(ctx) + handler_started.set() + try: + await anyio.sleep_forever() + finally: + handler_exited.set() + raise NotImplementedError + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def call_then_record() -> None: + with pytest.raises(MCPError): # we'll cancel via tg below + await client.send_request("slow", None) + + tg.start_soon(call_then_record) + await handler_started.wait() + # cancel just the handler (peer-cancel), not our caller + await client.notify("notifications/cancelled", {"requestId": 1}) + await handler_exited.wait() + # Handler torn down, no response was written; caller is still parked. + # Cancel the caller's task to end the test. + tg.cancel_scope.cancel() + assert seen_ctx[0].cancel_requested.is_set() + + +@pytest.mark.anyio +async def test_peer_cancel_signal_mode_sets_event_but_handler_runs_to_completion(): + handler_started = anyio.Event() + cancel_seen = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await ctx.cancel_requested.wait() + cancel_seen.set() + return {"finished": True} + + def factory(*, can_send_request: bool = True): + client, server, close = jsonrpc_pair(can_send_request=can_send_request) + # Reach in to set signal mode on the server side. + assert isinstance(server, JSONRPCDispatcher) + server._peer_cancel_mode = "signal" # pyright: ignore[reportPrivateUsage] + return client, server, close + + result_box: list[dict[str, Any]] = [] + async with running_pair(factory, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def call() -> None: + result_box.append(await client.send_request("slow", None)) + + tg.start_soon(call) + await handler_started.wait() + await client.notify("notifications/cancelled", {"requestId": 1}) + await cancel_seen.wait() + assert result_box == [{"finished": True}] + + +@pytest.mark.anyio +async def test_send_request_raises_connection_closed_when_read_stream_eofs_mid_await(): + """A blocked send_request is woken with CONNECTION_CLOSED when run() exits.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + + async def caller() -> None: + with pytest.raises(MCPError) as exc: + await client.send_request("ping", None) + assert exc.value.error.code == CONNECTION_CLOSED + + tg.start_soon(caller) + await anyio.sleep(0) + # No server: simulate the peer dropping by closing the read side. + s2c_send.close() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_late_response_after_timeout_is_dropped_without_crashing(): + handler_started = anyio.Event() + proceed = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await proceed.wait() + return {"late": True} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + with pytest.raises(MCPError): # REQUEST_TIMEOUT + await client.send_request("slow", None, {"timeout": 0}) + # The server handler is still running; let it finish and write a + # response for an id the client has already discarded. + await handler_started.wait() + proceed.set() + # One more round-trip proves the dispatcher is still healthy. + assert await client.send_request("ping", None) == {"late": True} + + +@pytest.mark.anyio +async def test_raise_handler_exceptions_true_propagates_out_of_run(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + + def builder(_rid: object, _meta: object) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=True) + + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( + c2s_recv, s2c_send, transport_builder=builder, raise_handler_exceptions=True + ) + + async def boom(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + raise RuntimeError("propagate me") + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + with pytest.raises(BaseException) as exc: + async with anyio.create_task_group() as tg: + await tg.start(server.run, boom, on_notify) + # Inject a request directly onto the server's read stream. + await c2s_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="x", params=None)) + ) + assert exc.group_contains(RuntimeError, match="propagate me") + # The error response was still written before re-raising. + sent = s2c_recv.receive_nowait() + assert isinstance(sent, SessionMessage) + assert isinstance(sent.message, JSONRPCError) + assert sent.message.error.code == INTERNAL_ERROR + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_ctx_send_request_tags_outbound_with_server_message_metadata(): + """Server-to-client requests carry related_request_id for SHTTP routing.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + return await ctx.send_request("sampling/createMessage", {"prompt": "hi"}) + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + # Kick the server with an inbound request id=7. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="t", params=None))) + with anyio.fail_after(5): + outbound = await s2c_recv.receive() + assert isinstance(outbound, SessionMessage) + assert isinstance(outbound.message, JSONRPCRequest) + assert isinstance(outbound.metadata, ServerMessageMetadata) + assert outbound.metadata.related_request_id == 7 + # Reply so the handler completes cleanly. + await c2s_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=outbound.message.id, result={"ok": True})) + ) + with anyio.fail_after(5): + final = await s2c_recv.receive() + assert isinstance(final, SessionMessage) + assert isinstance(final.message, JSONRPCResponse) + assert final.message.id == 7 + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_ctx_progress_with_only_progress_value_omits_total_and_message(): + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + await ctx.progress(0.25) + return {} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_request("t", None, {"on_progress": on_progress}) + assert received == [(0.25, None, None)] + + +@pytest.mark.anyio +async def test_handler_raising_validation_error_sends_invalid_params(): + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + Tool.model_validate({"name": 123}) # raises ValidationError + raise NotImplementedError + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_request("t", None) + assert exc.value.error.code == INVALID_PARAMS + + +@pytest.mark.anyio +async def test_send_request_before_run_raises_runtimeerror(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + try: + with pytest.raises(RuntimeError, match="before run"): + await d.send_request("ping", None) + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_transport_exception_in_read_stream_is_logged_and_dropped(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send(ValueError("transport hiccup")) + # Dispatcher must remain healthy after the dropped exception. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) + with anyio.fail_after(5): + resp = await s2c_recv.receive() + assert isinstance(resp, SessionMessage) + assert isinstance(resp.message, JSONRPCResponse) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_progress_notification_for_unknown_token_falls_through_to_on_notify(): + async with running_pair(jsonrpc_pair) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/progress", {"progressToken": 999, "progress": 0.5}) + await srec.notified.wait() + assert srec.notifications == [("notifications/progress", {"progressToken": 999, "progress": 0.5})] + + +@pytest.mark.anyio +async def test_cancelled_notification_for_unknown_request_id_is_noop(): + async with running_pair(jsonrpc_pair) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/cancelled", {"requestId": 999}) + # No effect; dispatcher remains healthy. + assert await client.send_request("t", None) == {"echoed": "t", "params": {}} + assert srec.notifications == [] # cancelled is fully consumed, never teed + + +_probe: contextvars.ContextVar[str] = contextvars.ContextVar("probe", default="unset") + + +@pytest.mark.anyio +async def test_handler_inherits_sender_contextvars_via_spawn(): + """The handler task sees contextvars set by the task that wrote into the read stream.""" + raw_send, raw_recv = anyio.create_memory_object_stream[tuple[contextvars.Context, SessionMessage | Exception]](4) + read_stream = ContextReceiveStream[SessionMessage | Exception](raw_recv) + write_send = ContextSendStream[SessionMessage | Exception](raw_send) + out_send, out_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(read_stream, out_send) + + seen: list[str] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen.append(_probe.get()) + return {} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + + async def sender() -> None: + _probe.set("from-sender") + await write_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None)) + ) + + tg.start_soon(sender) + with anyio.fail_after(5): + resp = await out_recv.receive() + assert isinstance(resp, SessionMessage) + tg.cancel_scope.cancel() + finally: + for s in (raw_send, raw_recv, out_send, out_recv): + s.close() + assert seen == ["from-sender"] + + +@pytest.mark.anyio +async def test_response_write_after_peer_drop_is_swallowed(): + """Handler completes after the write stream is closed; the dropped write doesn't crash run().""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + proceed = anyio.Event() + handlers_done = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + await proceed.wait() + if method == "raise": + handlers_done.set() + raise MCPError(code=INTERNAL_ERROR, message="x") + return {"ok": True} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="ok", params=None))) + await c2s_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=2, method="raise", params=None)) + ) + await anyio.sleep(0) + # Peer drops: close the receive end so the server's writes hit BrokenResourceError. + s2c_recv.close() + proceed.set() + with anyio.fail_after(5): + await handlers_done.wait() + # run() must still be healthy — close the read side to let it exit cleanly. + c2s_send.close() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_cancel_outbound_after_write_stream_closed_is_swallowed(): + """Courtesy-cancel write hits a closed stream; the error is swallowed and cancellation propagates.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + caller_done = anyio.Event() + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + caller_scope = anyio.CancelScope() + + async def caller() -> None: + with caller_scope: + await client.send_request("slow", None) + caller_done.set() + + tg.start_soon(caller) + # Deterministic proof the request write completed: pull it off the wire. + with anyio.fail_after(5): + sent = await c2s_recv.receive() + assert isinstance(sent, SessionMessage) + assert isinstance(sent.message, JSONRPCRequest) + # Now safe: close the client's write end, then cancel the caller. The + # shielded `_cancel_outbound` write hits ClosedResourceError and is + # swallowed; cancellation propagates cleanly. + c2s_send.close() + caller_scope.cancel() + with anyio.fail_after(5): + await caller_done.wait() + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +def test_resolve_pending_drops_outcome_when_waiter_stream_already_closed(): + """White-box: a response for an id still in _pending but whose waiter has gone.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] + recv.close() # waiter gone — send_nowait will raise BrokenResourceError + d._resolve_pending(1, {"late": True}) # pyright: ignore[reportPrivateUsage] + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv, send): + s.close() + + +def test_fan_out_closed_drops_signal_when_waiter_already_has_outcome(): + """White-box: the buffer=1 invariant — WouldBlock means waiter already has an outcome.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + # Register a fake pending and pre-fill its single buffer slot. + d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] + send.send_nowait({"real": "result"}) + d._fan_out_closed() # pyright: ignore[reportPrivateUsage] + # The real result is still there; the close signal was dropped. + assert recv.receive_nowait() == {"real": "result"} + assert d._pending == {} # pyright: ignore[reportPrivateUsage] + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv, send, recv): + s.close() + + +def test_outbound_metadata_with_resumption_token_returns_client_metadata(): + md = _outbound_metadata(None, {"resumption_token": "abc"}) + assert isinstance(md, ClientMessageMetadata) + assert md.resumption_token == "abc" + assert _outbound_metadata(None, None) is None + assert _outbound_metadata(None, {}) is None + + +@pytest.mark.anyio +async def test_jsonrpc_error_response_with_null_id_is_dropped(): + """Parse-error responses (id=null) have no waiter; they're logged and dropped.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + await s2c_send.send( + SessionMessage(message=JSONRPCError(jsonrpc="2.0", id=None, error=ErrorData(code=-32700, message="x"))) + ) + await anyio.sleep(0) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() From e0ca9bc5b0db5f01a7a3863f1c70a59c6bf95a3e Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 17:18:48 +0000 Subject: [PATCH 09/52] ci: run full matrix on PRs targeting any branch The pull_request branch filter meant the test/lint/coverage matrix only ran on PRs targeting main or v1.x. Stacked PRs (targeting feature branches) only got the conformance checks, which are continue-on-error and don't exercise unit tests. Removing the filter so the full matrix runs on every PR. --- .github/workflows/main.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d34e438fc9..341df0abb8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -5,7 +5,6 @@ on: branches: ["main", "v1.x"] tags: ["v*.*.*"] pull_request: - branches: ["main", "v1.x"] permissions: contents: read From afc1789004941aaff29fa8bab71c940ff497889c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 17:29:41 +0000 Subject: [PATCH 10/52] test: address 3.11/3.14 coverage instrumentation quirks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 3.14: nested async-with arc misreporting on three create_task_group lines (the documented AGENTS.md case) — pragma: no branch. 3.11: lines after async-CM exit with pytest.raises mis-traced in one test — moved the asserts inside the context manager. --- tests/shared/test_dispatcher.py | 4 ++-- tests/shared/test_jsonrpc_dispatcher.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index fc967c1299..bdadd4cdae 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -208,8 +208,8 @@ async def on_request( async with running_pair(direct_pair, server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", {}) - assert exc.value.error.code == INTERNAL_ERROR - assert isinstance(exc.value.__cause__, ValueError) + assert exc.value.error.code == INTERNAL_ERROR + assert isinstance(exc.value.__cause__, ValueError) @pytest.mark.anyio diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index ff24ef4c6b..be6386d090 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -56,7 +56,7 @@ async def call(method: str) -> None: results[method] = await client.send_request(method, None) with anyio.fail_after(5): - async with anyio.create_task_group() as tg: + async with anyio.create_task_group() as tg: # pragma: no branch tg.start_soon(call, "first") await anyio.sleep(0) tg.start_soon(call, "second") @@ -99,7 +99,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - async with anyio.create_task_group() as tg: + async with anyio.create_task_group() as tg: # pragma: no branch async def call_then_record() -> None: with pytest.raises(MCPError): # we'll cancel via tg below @@ -137,7 +137,7 @@ def factory(*, can_send_request: bool = True): result_box: list[dict[str, Any]] = [] async with running_pair(factory, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - async with anyio.create_task_group() as tg: + async with anyio.create_task_group() as tg: # pragma: no branch async def call() -> None: result_box.append(await client.send_request("slow", None)) From f8f350e33d169d0e7271f3ba383056006b19815e Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 21:49:25 +0000 Subject: [PATCH 11/52] refactor: rename send_request to send_raw_request in JSONRPCDispatcher Follows the Outbound Protocol rename in the previous commit. Mechanical rename across JSONRPCDispatcher, _JSONRPCDispatchContext, and tests. --- src/mcp/shared/jsonrpc_dispatcher.py | 18 ++++++------- tests/shared/test_jsonrpc_dispatcher.py | 36 ++++++++++++------------- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index bbf5666069..f1e7b3675e 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -112,7 +112,7 @@ def can_send_request(self) -> bool: async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: await self._dispatcher.notify(method, params, _related_request_id=self._request_id) - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, @@ -120,7 +120,7 @@ async def send_request( ) -> dict[str, Any]: if not self.can_send_request: raise NoBackChannelError(method) - return await self._dispatcher.send_request(method, params, opts, _related_request_id=self._request_id) + return await self._dispatcher.send_raw_request(method, params, opts, _related_request_id=self._request_id) async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: if self._progress_token is None: @@ -208,7 +208,7 @@ def __init__( self._tg: anyio.abc.TaskGroup | None = None self._running = False - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, @@ -232,7 +232,7 @@ async def send_request( finished. """ if not self._running: - raise RuntimeError("JSONRPCDispatcher.send_request called before run() / after close") + raise RuntimeError("JSONRPCDispatcher.send_raw_request called before run() / after close") opts = opts or {} request_id = self._allocate_id() out_params = dict(params) if params is not None else None @@ -308,7 +308,7 @@ async def run( Each inbound request is handled in its own task in an internal task group; ``task_status.started()`` fires once that group is open, so - ``await tg.start(dispatcher.run, ...)`` resumes when ``send_request`` + ``await tg.start(dispatcher.run, ...)`` resumes when ``send_raw_request`` is usable. """ try: @@ -323,9 +323,9 @@ async def run( # snapshot per message). Plain memory streams don't. sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None) self._dispatch(item, on_request, on_notify, sender_ctx) - # Read stream EOF: wake any blocked `send_request` waiters now, + # Read stream EOF: wake any blocked `send_raw_request` waiters now, # *before* the task group joins, so handlers parked in - # `dctx.send_request()` can unwind and the join doesn't deadlock. + # `dctx.send_raw_request()` can unwind and the join doesn't deadlock. self._running = False self._fan_out_closed() finally: @@ -457,7 +457,7 @@ def _spawn( self._tg.start_soon(fn, *args) def _fan_out_closed(self) -> None: - """Wake every pending ``send_request`` waiter with ``CONNECTION_CLOSED``. + """Wake every pending ``send_raw_request`` waiter with ``CONNECTION_CLOSED``. Synchronous (uses ``send_nowait``) because it's called from ``finally`` which may be inside a cancelled scope. Idempotent. @@ -491,7 +491,7 @@ async def _handle_request( # Close the back-channel the moment the handler exits # (success or raise), before the response write — a handler # spawning detached work that later calls - # `dctx.send_request()` should see `NoBackChannelError`. + # `dctx.send_raw_request()` should see `NoBackChannelError`. dctx.close() await self._write_result(req.id, result) # Peer-cancel: `_dispatch_notification` cancelled this scope. anyio diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index be6386d090..7f9f11718b 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -41,7 +41,7 @@ @pytest.mark.anyio -async def test_concurrent_send_requests_correlate_by_id_when_responses_arrive_out_of_order(): +async def test_concurrent_send_raw_requests_correlate_by_id_when_responses_arrive_out_of_order(): release_first = anyio.Event() async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: @@ -53,7 +53,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | results: dict[str, dict[str, Any]] = {} async def call(method: str) -> None: - results[method] = await client.send_request(method, None) + results[method] = await client.send_raw_request(method, None) with anyio.fail_after(5): async with anyio.create_task_group() as tg: # pragma: no branch @@ -76,7 +76,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_request("tools/list", None) + await client.send_raw_request("tools/list", None) assert exc.value.error.code == INTERNAL_ERROR assert exc.value.error.message == "kaboom" assert exc.value.__cause__ is None # cause does not survive the wire @@ -103,7 +103,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async def call_then_record() -> None: with pytest.raises(MCPError): # we'll cancel via tg below - await client.send_request("slow", None) + await client.send_raw_request("slow", None) tg.start_soon(call_then_record) await handler_started.wait() @@ -140,7 +140,7 @@ def factory(*, can_send_request: bool = True): async with anyio.create_task_group() as tg: # pragma: no branch async def call() -> None: - result_box.append(await client.send_request("slow", None)) + result_box.append(await client.send_raw_request("slow", None)) tg.start_soon(call) await handler_started.wait() @@ -150,8 +150,8 @@ async def call() -> None: @pytest.mark.anyio -async def test_send_request_raises_connection_closed_when_read_stream_eofs_mid_await(): - """A blocked send_request is woken with CONNECTION_CLOSED when run() exits.""" +async def test_send_raw_request_raises_connection_closed_when_read_stream_eofs_mid_await(): + """A blocked send_raw_request is woken with CONNECTION_CLOSED when run() exits.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) @@ -162,7 +162,7 @@ async def test_send_request_raises_connection_closed_when_read_stream_eofs_mid_a async def caller() -> None: with pytest.raises(MCPError) as exc: - await client.send_request("ping", None) + await client.send_raw_request("ping", None) assert exc.value.error.code == CONNECTION_CLOSED tg.start_soon(caller) @@ -187,13 +187,13 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): with pytest.raises(MCPError): # REQUEST_TIMEOUT - await client.send_request("slow", None, {"timeout": 0}) + await client.send_raw_request("slow", None, {"timeout": 0}) # The server handler is still running; let it finish and write a # response for an id the client has already discarded. await handler_started.wait() proceed.set() # One more round-trip proves the dispatcher is still healthy. - assert await client.send_request("ping", None) == {"late": True} + assert await client.send_raw_request("ping", None) == {"late": True} @pytest.mark.anyio @@ -234,14 +234,14 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio -async def test_ctx_send_request_tags_outbound_with_server_message_metadata(): +async def test_ctx_send_raw_request_tags_outbound_with_server_message_metadata(): """Server-to-client requests carry related_request_id for SHTTP routing.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - return await ctx.send_request("sampling/createMessage", {"prompt": "hi"}) + return await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: raise NotImplementedError @@ -285,7 +285,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - await client.send_request("t", None, {"on_progress": on_progress}) + await client.send_raw_request("t", None, {"on_progress": on_progress}) assert received == [(0.25, None, None)] @@ -297,18 +297,18 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_request("t", None) + await client.send_raw_request("t", None) assert exc.value.error.code == INVALID_PARAMS @pytest.mark.anyio -async def test_send_request_before_run_raises_runtimeerror(): +async def test_send_raw_request_before_run_raises_runtimeerror(): c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) try: with pytest.raises(RuntimeError, match="before run"): - await d.send_request("ping", None) + await d.send_raw_request("ping", None) finally: for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): s.close() @@ -351,7 +351,7 @@ async def test_cancelled_notification_for_unknown_request_id_is_noop(): with anyio.fail_after(5): await client.notify("notifications/cancelled", {"requestId": 999}) # No effect; dispatcher remains healthy. - assert await client.send_request("t", None) == {"echoed": "t", "params": {}} + assert await client.send_raw_request("t", None) == {"echoed": "t", "params": {}} assert srec.notifications == [] # cancelled is fully consumed, never teed @@ -451,7 +451,7 @@ async def test_cancel_outbound_after_write_stream_closed_is_swallowed(): async def caller() -> None: with caller_scope: - await client.send_request("slow", None) + await client.send_raw_request("slow", None) caller_done.set() tg.start_soon(caller) From 689b78486daebd74b82af45b1390b4f8229065a2 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 19:21:27 +0000 Subject: [PATCH 12/52] feat: PeerMixin and Peer wrapper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PeerMixin defines the typed server-to-client request methods (sample with overloads, elicit_form, elicit_url, list_roots, ping) once. Each method constrains `self: Outbound` so any class with send_request/notify can mix it in — pyright checks the host structurally at the call site. The mixin does no capability gating; that's the host's send_request's job. Peer is a trivial standalone wrapper for when you have a bare Outbound (e.g. a dispatcher) and want the typed sugar without writing your own host class. 6 tests over DirectDispatcher, 0.03s. --- src/mcp/shared/peer.py | 194 ++++++++++++++++++++++++++++++++++++++ tests/shared/test_peer.py | 128 +++++++++++++++++++++++++ 2 files changed, 322 insertions(+) create mode 100644 src/mcp/shared/peer.py create mode 100644 tests/shared/test_peer.py diff --git a/src/mcp/shared/peer.py b/src/mcp/shared/peer.py new file mode 100644 index 0000000000..b5d4b960ed --- /dev/null +++ b/src/mcp/shared/peer.py @@ -0,0 +1,194 @@ +"""Typed MCP request sugar over an `Outbound`. + +`PeerMixin` defines the server-to-client request methods (sampling, elicitation, +roots, ping) once. Any class that satisfies `Outbound` (i.e. has `send_request` +and `notify`) can mix it in and get the typed methods for free — `Context`, +`Connection`, `Client`, or the bare `Peer` wrapper below. + +The mixin does no capability gating: it builds the params, calls +``self.send_request(method, params)``, and parses the result into the typed +model. Gating (and `NoBackChannelError`) is the host's `send_request`'s job. +""" + +from collections.abc import Mapping +from typing import Any, overload + +from pydantic import BaseModel + +from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + CreateMessageResultWithTools, + ElicitRequestedSchema, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + IncludeContext, + ListRootsResult, + ModelPreferences, + SamplingMessage, + Tool, + ToolChoice, +) + +__all__ = ["Peer", "PeerMixin"] + + +def _dump(model: BaseModel) -> dict[str, Any]: + return model.model_dump(by_alias=True, mode="json", exclude_none=True) + + +class PeerMixin: + """Typed server-to-client request methods. + + Each method constrains ``self`` to `Outbound` so the mixin can be applied + to anything with ``send_request``/``notify`` — pyright checks the host + class structurally at the call site. + """ + + @overload + async def sample( + self: Outbound, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: None = None, + tool_choice: ToolChoice | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResult: ... + @overload + async def sample( + self: Outbound, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: list[Tool], + tool_choice: ToolChoice | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResultWithTools: ... + async def sample( + self: Outbound, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: list[Tool] | None = None, + tool_choice: ToolChoice | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResult | CreateMessageResultWithTools: + """Send a ``sampling/createMessage`` request to the peer. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: The host's transport context has no + back-channel for server-initiated requests. + """ + params = CreateMessageRequestParams( + messages=messages, + system_prompt=system_prompt, + include_context=include_context, + temperature=temperature, + max_tokens=max_tokens, + stop_sequences=stop_sequences, + metadata=metadata, + model_preferences=model_preferences, + tools=tools, + tool_choice=tool_choice, + ) + result = await self.send_request("sampling/createMessage", _dump(params), opts) + if tools is not None: + return CreateMessageResultWithTools.model_validate(result) + return CreateMessageResult.model_validate(result) + + async def elicit_form( + self: Outbound, + message: str, + requested_schema: ElicitRequestedSchema, + opts: CallOptions | None = None, + ) -> ElicitResult: + """Send a form-mode ``elicitation/create`` request. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + params = ElicitRequestFormParams(message=message, requested_schema=requested_schema) + result = await self.send_request("elicitation/create", _dump(params), opts) + return ElicitResult.model_validate(result) + + async def elicit_url( + self: Outbound, + message: str, + url: str, + elicitation_id: str, + opts: CallOptions | None = None, + ) -> ElicitResult: + """Send a URL-mode ``elicitation/create`` request. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + params = ElicitRequestURLParams(message=message, url=url, elicitation_id=elicitation_id) + result = await self.send_request("elicitation/create", _dump(params), opts) + return ElicitResult.model_validate(result) + + async def list_roots(self: Outbound, opts: CallOptions | None = None) -> ListRootsResult: + """Send a ``roots/list`` request. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + result = await self.send_request("roots/list", None, opts) + return ListRootsResult.model_validate(result) + + async def ping(self: Outbound, opts: CallOptions | None = None) -> None: + """Send a ``ping`` request and ignore the result. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + await self.send_request("ping", None, opts) + + +class Peer(PeerMixin): + """Standalone wrapper that gives any `Outbound` the `PeerMixin` sugar. + + `Context` and `Connection` mix `PeerMixin` in directly; use `Peer` when + you have a bare dispatcher (or any `Outbound`) and want the typed methods + without writing your own host class. + """ + + def __init__(self, outbound: Outbound) -> None: + self._outbound = outbound + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + return await self._outbound.send_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._outbound.notify(method, params) diff --git a/tests/shared/test_peer.py b/tests/shared/test_peer.py new file mode 100644 index 0000000000..43d49252cb --- /dev/null +++ b/tests/shared/test_peer.py @@ -0,0 +1,128 @@ +"""Tests for `PeerMixin` and `Peer`. + +Each PeerMixin method is tested by wrapping a `DirectDispatcher` in `Peer`, +calling the typed method, and asserting (a) the right method+params went out +and (b) the return value is the typed result model. +""" + +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.peer import Peer +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + CreateMessageResult, + CreateMessageResultWithTools, + ElicitResult, + ListRootsResult, + SamplingMessage, + TextContent, + Tool, +) + +from .conftest import direct_pair +from .test_dispatcher import running_pair + +DCtx = DispatchContext[TransportContext] + + +class _Recorder: + def __init__(self, result: dict[str, Any]) -> None: + self.result = result + self.seen: list[tuple[str, Mapping[str, Any] | None]] = [] + + async def on_request(self, ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + self.seen.append((method, params)) + return self.result + + +@pytest.mark.anyio +async def test_peer_sample_sends_create_message_and_returns_typed_result(): + rec = _Recorder({"role": "assistant", "content": {"type": "text", "text": "hi"}, "model": "m"}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="hello"))], + max_tokens=10, + ) + method, params = rec.seen[0] + assert method == "sampling/createMessage" + assert params is not None and params["maxTokens"] == 10 + assert isinstance(result, CreateMessageResult) + assert result.model == "m" + + +@pytest.mark.anyio +async def test_peer_sample_with_tools_returns_with_tools_result(): + rec = _Recorder({"role": "assistant", "content": [{"type": "text", "text": "x"}], "model": "m"}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="q"))], + max_tokens=5, + tools=[Tool(name="t", input_schema={"type": "object"})], + ) + method, params = rec.seen[0] + assert method == "sampling/createMessage" + assert params is not None and params["tools"][0]["name"] == "t" + assert isinstance(result, CreateMessageResultWithTools) + + +@pytest.mark.anyio +async def test_peer_elicit_form_sends_elicitation_create_with_form_params(): + rec = _Recorder({"action": "accept", "content": {"name": "Max"}}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.elicit_form("Your name?", requested_schema={"type": "object", "properties": {}}) + method, params = rec.seen[0] + assert method == "elicitation/create" + assert params is not None and params["mode"] == "form" + assert params["message"] == "Your name?" + assert isinstance(result, ElicitResult) + + +@pytest.mark.anyio +async def test_peer_elicit_url_sends_elicitation_create_with_url_params(): + rec = _Recorder({"action": "accept"}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.elicit_url("Auth needed", url="https://example.com/auth", elicitation_id="e1") + method, params = rec.seen[0] + assert method == "elicitation/create" + assert params is not None and params["mode"] == "url" + assert params["url"] == "https://example.com/auth" + assert isinstance(result, ElicitResult) + + +@pytest.mark.anyio +async def test_peer_list_roots_sends_roots_list_and_returns_typed_result(): + rec = _Recorder({"roots": [{"uri": "file:///workspace"}]}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.list_roots() + method, _ = rec.seen[0] + assert method == "roots/list" + assert isinstance(result, ListRootsResult) + assert len(result.roots) == 1 + assert str(result.roots[0].uri) == "file:///workspace" + + +@pytest.mark.anyio +async def test_peer_ping_sends_ping_and_returns_none(): + rec = _Recorder({}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.ping() + method, _ = rec.seen[0] + assert method == "ping" + assert result is None From 1096712bd9bc1dfa6c2c2ead55cfc93e5488f604 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 19:42:46 +0000 Subject: [PATCH 13/52] feat: BaseContext Composition over a DispatchContext: forwards transport/cancel_requested/ send_request/notify/progress and adds meta. Satisfies Outbound so PeerMixin works on it (proven by Peer(bctx).ping() round-tripping). The server Context (next commit) extends this with lifespan/connection; ClientContext will be an alias once ClientSession is reworked. --- src/mcp/shared/context.py | 82 +++++++++++++++++++++++++ tests/shared/test_context.py | 115 +++++++++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 src/mcp/shared/context.py create mode 100644 tests/shared/test_context.py diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py new file mode 100644 index 0000000000..f6a33d719a --- /dev/null +++ b/src/mcp/shared/context.py @@ -0,0 +1,82 @@ +"""`BaseContext` — the user-facing per-request context. + +Composition over a `DispatchContext`: forwards the transport metadata, the +back-channel (`send_request`/`notify`), progress reporting, and the cancel +event. Adds `meta` (the inbound request's `_meta` field). + +Satisfies `Outbound`, so `PeerMixin` works on it (the server-side `Context` +mixes that in directly). Shared between client and server: the server's +`Context` extends this with `lifespan`/`connection`; `ClientContext` is just an +alias. +""" + +from collections.abc import Mapping +from typing import Any, Generic + +import anyio +from typing_extensions import TypeVar + +from mcp.shared.dispatcher import CallOptions, DispatchContext +from mcp.shared.transport_context import TransportContext +from mcp.types import RequestParamsMeta + +__all__ = ["BaseContext"] + +TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext) + + +class BaseContext(Generic[TransportT]): + """Per-request context wrapping a `DispatchContext`. + + `ServerRunner` (PR4) constructs one per inbound request and passes it to + the user's handler. + """ + + def __init__(self, dctx: DispatchContext[TransportT], meta: RequestParamsMeta | None = None) -> None: + self._dctx = dctx + self._meta = meta + + @property + def transport(self) -> TransportT: + """Transport-specific metadata for this inbound request.""" + return self._dctx.transport + + @property + def cancel_requested(self) -> anyio.Event: + """Set when the peer sends ``notifications/cancelled`` for this request.""" + return self._dctx.cancel_requested + + @property + def can_send_request(self) -> bool: + """Whether the back-channel can deliver server-initiated requests.""" + return self._dctx.transport.can_send_request + + @property + def meta(self) -> RequestParamsMeta | None: + """The inbound request's ``_meta`` field, if present.""" + return self._meta + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a request to the peer on the back-channel. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: ``can_send_request`` is ``False``. + """ + return await self._dctx.send_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a notification to the peer on the back-channel.""" + await self._dctx.notify(method, params) + + async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + """Report progress for this request, if the peer supplied a progress token. + + A no-op when no token was supplied. + """ + await self._dctx.progress(progress, total, message) diff --git a/tests/shared/test_context.py b/tests/shared/test_context.py new file mode 100644 index 0000000000..5d93768433 --- /dev/null +++ b/tests/shared/test_context.py @@ -0,0 +1,115 @@ +"""Tests for `BaseContext`. + +`BaseContext` is composition over a `DispatchContext` — it forwards +``transport``/``cancel_requested``/``send_request``/``notify``/``progress`` +and adds ``meta``. It must satisfy `Outbound` so `PeerMixin` works on it. +""" + +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.shared.context import BaseContext +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.peer import Peer +from mcp.shared.transport_context import TransportContext + +from .conftest import direct_pair +from .test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@pytest.mark.anyio +async def test_base_context_forwards_transport_and_cancel_requested(): + captured: list[BaseContext[TransportContext]] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + captured.append(bctx) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_request("t", None) + bctx = captured[0] + assert bctx.transport.kind == "direct" + assert isinstance(bctx.cancel_requested, anyio.Event) + assert bctx.can_send_request is True + assert bctx.meta is None + + +@pytest.mark.anyio +async def test_base_context_send_request_and_notify_forward_to_dispatch_context(): + crec = Recorder() + c_req, c_notify = echo_handlers(crec) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + sample = await bctx.send_request("sampling/createMessage", {"x": 1}) + await bctx.notify("notifications/message", {"level": "info"}) + return {"sample": sample} + + async with running_pair( + direct_pair, + server_on_request=server_on_request, + client_on_request=c_req, + client_on_notify=c_notify, + ) as (client, *_): + with anyio.fail_after(5): + result = await client.send_request("tools/call", None) + await crec.notified.wait() + assert crec.requests == [("sampling/createMessage", {"x": 1})] + assert crec.notifications == [("notifications/message", {"level": "info"})] + assert result["sample"] == {"echoed": "sampling/createMessage", "params": {"x": 1}} + + +@pytest.mark.anyio +async def test_base_context_report_progress_invokes_caller_on_progress(): + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + await bctx.report_progress(0.5, total=1.0, message="halfway") + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_request("t", None, {"on_progress": on_progress}) + assert received == [(0.5, 1.0, "halfway")] + + +@pytest.mark.anyio +async def test_base_context_satisfies_outbound_so_peer_mixin_works(): + """Wrapping a BaseContext in Peer proves it satisfies Outbound structurally.""" + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + await Peer(bctx).ping() + return {} + + crec = Recorder() + c_req, c_notify = echo_handlers(crec) + async with running_pair( + direct_pair, server_on_request=server_on_request, client_on_request=c_req, client_on_notify=c_notify + ) as (client, *_): + with anyio.fail_after(5): + await client.send_request("t", None) + assert crec.requests == [("ping", None)] + + +@pytest.mark.anyio +async def test_base_context_meta_holds_supplied_request_params_meta(): + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx, meta={"progressToken": "abc"}) + assert bctx.meta is not None and bctx.meta.get("progressToken") == "abc" + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_request("t", None) From efd7df7b188f1d817f2224771944cede742bf018 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 22:02:36 +0000 Subject: [PATCH 14/52] refactor: follow Outbound.send_raw_request rename in PeerMixin/BaseContext PeerMixin methods and Peer/BaseContext now call/expose send_raw_request. The typed send_request lands on Connection/Context in the next commit. --- src/mcp/shared/context.py | 6 +++--- src/mcp/shared/peer.py | 22 +++++++++++----------- tests/shared/test_context.py | 16 ++++++++-------- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index f6a33d719a..68f439b738 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,7 +1,7 @@ """`BaseContext` — the user-facing per-request context. Composition over a `DispatchContext`: forwards the transport metadata, the -back-channel (`send_request`/`notify`), progress reporting, and the cancel +back-channel (`send_raw_request`/`notify`), progress reporting, and the cancel event. Adds `meta` (the inbound request's `_meta` field). Satisfies `Outbound`, so `PeerMixin` works on it (the server-side `Context` @@ -56,7 +56,7 @@ def meta(self) -> RequestParamsMeta | None: """The inbound request's ``_meta`` field, if present.""" return self._meta - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, @@ -68,7 +68,7 @@ async def send_request( MCPError: The peer responded with an error. NoBackChannelError: ``can_send_request`` is ``False``. """ - return await self._dctx.send_request(method, params, opts) + return await self._dctx.send_raw_request(method, params, opts) async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: """Send a notification to the peer on the back-channel.""" diff --git a/src/mcp/shared/peer.py b/src/mcp/shared/peer.py index b5d4b960ed..9951081104 100644 --- a/src/mcp/shared/peer.py +++ b/src/mcp/shared/peer.py @@ -1,13 +1,13 @@ """Typed MCP request sugar over an `Outbound`. `PeerMixin` defines the server-to-client request methods (sampling, elicitation, -roots, ping) once. Any class that satisfies `Outbound` (i.e. has `send_request` +roots, ping) once. Any class that satisfies `Outbound` (i.e. has `send_raw_request` and `notify`) can mix it in and get the typed methods for free — `Context`, `Connection`, `Client`, or the bare `Peer` wrapper below. The mixin does no capability gating: it builds the params, calls -``self.send_request(method, params)``, and parses the result into the typed -model. Gating (and `NoBackChannelError`) is the host's `send_request`'s job. +``self.send_raw_request(method, params)``, and parses the result into the typed +model. Gating (and `NoBackChannelError`) is the host's `send_raw_request`'s job. """ from collections.abc import Mapping @@ -43,7 +43,7 @@ class PeerMixin: """Typed server-to-client request methods. Each method constrains ``self`` to `Outbound` so the mixin can be applied - to anything with ``send_request``/``notify`` — pyright checks the host + to anything with ``send_raw_request``/``notify`` — pyright checks the host class structurally at the call site. """ @@ -113,7 +113,7 @@ async def sample( tools=tools, tool_choice=tool_choice, ) - result = await self.send_request("sampling/createMessage", _dump(params), opts) + result = await self.send_raw_request("sampling/createMessage", _dump(params), opts) if tools is not None: return CreateMessageResultWithTools.model_validate(result) return CreateMessageResult.model_validate(result) @@ -131,7 +131,7 @@ async def elicit_form( NoBackChannelError: No back-channel for server-initiated requests. """ params = ElicitRequestFormParams(message=message, requested_schema=requested_schema) - result = await self.send_request("elicitation/create", _dump(params), opts) + result = await self.send_raw_request("elicitation/create", _dump(params), opts) return ElicitResult.model_validate(result) async def elicit_url( @@ -148,7 +148,7 @@ async def elicit_url( NoBackChannelError: No back-channel for server-initiated requests. """ params = ElicitRequestURLParams(message=message, url=url, elicitation_id=elicitation_id) - result = await self.send_request("elicitation/create", _dump(params), opts) + result = await self.send_raw_request("elicitation/create", _dump(params), opts) return ElicitResult.model_validate(result) async def list_roots(self: Outbound, opts: CallOptions | None = None) -> ListRootsResult: @@ -158,7 +158,7 @@ async def list_roots(self: Outbound, opts: CallOptions | None = None) -> ListRoo MCPError: The peer responded with an error. NoBackChannelError: No back-channel for server-initiated requests. """ - result = await self.send_request("roots/list", None, opts) + result = await self.send_raw_request("roots/list", None, opts) return ListRootsResult.model_validate(result) async def ping(self: Outbound, opts: CallOptions | None = None) -> None: @@ -168,7 +168,7 @@ async def ping(self: Outbound, opts: CallOptions | None = None) -> None: MCPError: The peer responded with an error. NoBackChannelError: No back-channel for server-initiated requests. """ - await self.send_request("ping", None, opts) + await self.send_raw_request("ping", None, opts) class Peer(PeerMixin): @@ -182,13 +182,13 @@ class Peer(PeerMixin): def __init__(self, outbound: Outbound) -> None: self._outbound = outbound - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None, ) -> dict[str, Any]: - return await self._outbound.send_request(method, params, opts) + return await self._outbound.send_raw_request(method, params, opts) async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: await self._outbound.notify(method, params) diff --git a/tests/shared/test_context.py b/tests/shared/test_context.py index 5d93768433..951690028f 100644 --- a/tests/shared/test_context.py +++ b/tests/shared/test_context.py @@ -1,7 +1,7 @@ """Tests for `BaseContext`. `BaseContext` is composition over a `DispatchContext` — it forwards -``transport``/``cancel_requested``/``send_request``/``notify``/``progress`` +``transport``/``cancel_requested``/``send_raw_request``/``notify``/``progress`` and adds ``meta``. It must satisfy `Outbound` so `PeerMixin` works on it. """ @@ -33,7 +33,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - await client.send_request("t", None) + await client.send_raw_request("t", None) bctx = captured[0] assert bctx.transport.kind == "direct" assert isinstance(bctx.cancel_requested, anyio.Event) @@ -42,13 +42,13 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | @pytest.mark.anyio -async def test_base_context_send_request_and_notify_forward_to_dispatch_context(): +async def test_base_context_send_raw_request_and_notify_forward_to_dispatch_context(): crec = Recorder() c_req, c_notify = echo_handlers(crec) async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: bctx = BaseContext(ctx) - sample = await bctx.send_request("sampling/createMessage", {"x": 1}) + sample = await bctx.send_raw_request("sampling/createMessage", {"x": 1}) await bctx.notify("notifications/message", {"level": "info"}) return {"sample": sample} @@ -59,7 +59,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | client_on_notify=c_notify, ) as (client, *_): with anyio.fail_after(5): - result = await client.send_request("tools/call", None) + result = await client.send_raw_request("tools/call", None) await crec.notified.wait() assert crec.requests == [("sampling/createMessage", {"x": 1})] assert crec.notifications == [("notifications/message", {"level": "info"})] @@ -80,7 +80,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - await client.send_request("t", None, {"on_progress": on_progress}) + await client.send_raw_request("t", None, {"on_progress": on_progress}) assert received == [(0.5, 1.0, "halfway")] @@ -99,7 +99,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | direct_pair, server_on_request=server_on_request, client_on_request=c_req, client_on_notify=c_notify ) as (client, *_): with anyio.fail_after(5): - await client.send_request("t", None) + await client.send_raw_request("t", None) assert crec.requests == [("ping", None)] @@ -112,4 +112,4 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - await client.send_request("t", None) + await client.send_raw_request("t", None) From 7d18a7daa21e6aa3d1fb5bce342f4a1c94e7b807 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 22:21:02 +0000 Subject: [PATCH 15/52] feat: Connection, server Context, typed send_request, meta kwarg TypedServerRequestMixin (server/_typed_request.py) provides shape-2 typed send_request: per-spec overloads (CreateMessage/Elicit/ListRoots/Ping) infer the result type; custom requests pass result_type explicitly. Mixed into both Connection and the server Context. Connection (server/connection.py) wraps an Outbound for the standalone stream. notify is best-effort (never raises); send_raw_request gated on has_standalone_channel; check_capability mirrors v1 for now (FOLLOWUP). Holds peer info populated at initialize time and the per-connection lifespan state. Context (server/context.py, alongside v1's ServerRequestContext) composes BaseContext + PeerMixin + TypedServerRequestMixin and adds lifespan/connection. Request-scoped log() rides the request's back-channel; ctx.connection.log() uses the standalone stream. dump_params(model, meta) merges user-supplied meta into _meta; threaded through every PeerMixin and Connection convenience method. 31 tests, 0.06s. --- src/mcp/server/_typed_request.py | 85 +++++++++++++ src/mcp/server/connection.py | 146 ++++++++++++++++++++++ src/mcp/server/context.py | 60 +++++++++ src/mcp/shared/peer.py | 48 ++++++-- tests/server/test_connection.py | 184 ++++++++++++++++++++++++++++ tests/server/test_server_context.py | 131 ++++++++++++++++++++ tests/shared/test_peer.py | 21 +++- 7 files changed, 661 insertions(+), 14 deletions(-) create mode 100644 src/mcp/server/_typed_request.py create mode 100644 src/mcp/server/connection.py create mode 100644 tests/server/test_connection.py create mode 100644 tests/server/test_server_context.py diff --git a/src/mcp/server/_typed_request.py b/src/mcp/server/_typed_request.py new file mode 100644 index 0000000000..50cae159d1 --- /dev/null +++ b/src/mcp/server/_typed_request.py @@ -0,0 +1,85 @@ +"""Shape-2 typed ``send_request`` for server-to-client requests. + +`TypedServerRequestMixin` provides a typed `send_request(req) -> Result` over +the host's raw `Outbound.send_raw_request`. Spec server-to-client request types +have their result type inferred via per-type overloads; custom requests pass +``result_type=`` explicitly. + +A `HasResult[R]` protocol (one generic signature, mapping declared on the +request type) is the cleaner long-term shape — see FOLLOWUPS.md. This per-spec +overload set is used for now to avoid touching `mcp.types`. +""" + +from typing import Any, TypeVar, overload + +from pydantic import BaseModel + +from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.shared.peer import dump_params +from mcp.types import ( + CreateMessageRequest, + CreateMessageResult, + ElicitRequest, + ElicitResult, + EmptyResult, + ListRootsRequest, + ListRootsResult, + PingRequest, + Request, +) + +__all__ = ["TypedServerRequestMixin"] + +ResultT = TypeVar("ResultT", bound=BaseModel) + +_RESULT_FOR: dict[type[Request[Any, Any]], type[BaseModel]] = { + CreateMessageRequest: CreateMessageResult, + ElicitRequest: ElicitResult, + ListRootsRequest: ListRootsResult, + PingRequest: EmptyResult, +} + + +class TypedServerRequestMixin: + """Typed ``send_request`` for the server-to-client request set. + + Mixed into `Connection` and the server `Context`. Each method constrains + ``self`` to `Outbound` so any host with ``send_raw_request`` works. + """ + + @overload + async def send_request( + self: Outbound, req: CreateMessageRequest, *, opts: CallOptions | None = None + ) -> CreateMessageResult: ... + @overload + async def send_request(self: Outbound, req: ElicitRequest, *, opts: CallOptions | None = None) -> ElicitResult: ... + @overload + async def send_request( + self: Outbound, req: ListRootsRequest, *, opts: CallOptions | None = None + ) -> ListRootsResult: ... + @overload + async def send_request(self: Outbound, req: PingRequest, *, opts: CallOptions | None = None) -> EmptyResult: ... + @overload + async def send_request( + self: Outbound, req: Request[Any, Any], *, result_type: type[ResultT], opts: CallOptions | None = None + ) -> ResultT: ... + async def send_request( + self: Outbound, + req: Request[Any, Any], + *, + result_type: type[BaseModel] | None = None, + opts: CallOptions | None = None, + ) -> BaseModel: + """Send a typed server-to-client request and return its typed result. + + For spec request types the result type is inferred. For custom requests + pass ``result_type=`` explicitly. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + KeyError: ``result_type`` omitted for a non-spec request type. + """ + raw = await self.send_raw_request(req.method, dump_params(req.params), opts) + cls = result_type if result_type is not None else _RESULT_FOR[type(req)] + return cls.model_validate(raw) diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py new file mode 100644 index 0000000000..72c4ed062f --- /dev/null +++ b/src/mcp/server/connection.py @@ -0,0 +1,146 @@ +"""`Connection` — per-client connection state and the standalone outbound channel. + +Always present on `Context` (never ``None``), even in stateless deployments. +Holds peer info populated at ``initialize`` time, the per-connection lifespan +output, and an `Outbound` for the standalone stream (the SSE GET stream in +streamable HTTP, or the single duplex stream in stdio). + +`notify` is best-effort: it never raises. If there's no standalone channel +(stateless HTTP) or the stream has been dropped, the notification is +debug-logged and silently discarded — server-initiated notifications are +inherently advisory. `send_raw_request` *does* raise `NoBackChannelError` when +there's no channel; `ping` is the only spec-sanctioned standalone request. +""" + +import logging +from collections.abc import Mapping +from typing import Any + +import anyio + +from mcp.server._typed_request import TypedServerRequestMixin +from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.shared.exceptions import NoBackChannelError +from mcp.shared.peer import Meta, dump_params +from mcp.types import ClientCapabilities, Implementation, LoggingLevel + +__all__ = ["Connection"] + +logger = logging.getLogger(__name__) + + +def _notification_params(payload: dict[str, Any] | None, meta: Meta | None) -> dict[str, Any] | None: + if not meta: + return payload + out = dict(payload or {}) + out["_meta"] = meta + return out + + +class Connection(TypedServerRequestMixin): + """Per-client connection state and standalone-stream `Outbound`. + + Constructed by `ServerRunner` once per connection. The peer-info fields are + ``None`` until ``initialize`` completes; ``initialized`` is set then. + """ + + def __init__(self, outbound: Outbound, *, has_standalone_channel: bool) -> None: + self._outbound = outbound + self.has_standalone_channel = has_standalone_channel + + self.client_info: Implementation | None = None + self.client_capabilities: ClientCapabilities | None = None + self.protocol_version: str | None = None + self.initialized: anyio.Event = anyio.Event() + # TODO: make this generic (Connection[StateT]) once connection_lifespan + # wiring lands in ServerRunner — see FOLLOWUPS.md. + self.state: Any = None + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a raw request on the standalone stream. + + Low-level `Outbound` channel. Prefer the typed ``send_request`` (from + `TypedServerRequestMixin`) or the convenience methods below; use this + directly only for off-spec messages. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: ``has_standalone_channel`` is ``False``. + """ + if not self.has_standalone_channel: + raise NoBackChannelError(method) + return await self._outbound.send_raw_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a best-effort notification on the standalone stream. + + Never raises. If there's no standalone channel or the stream is broken, + the notification is dropped and debug-logged. + """ + if not self.has_standalone_channel: + logger.debug("dropped %s: no standalone channel", method) + return + try: + await self._outbound.notify(method, params) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped %s: standalone stream closed", method) + + async def ping(self, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None: + """Send a ``ping`` request on the standalone stream. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: ``has_standalone_channel`` is ``False``. + """ + await self.send_raw_request("ping", dump_params(None, meta), opts) + + async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: + """Send a ``notifications/message`` log entry on the standalone stream. Best-effort.""" + params: dict[str, Any] = {"level": level, "data": data} + if logger is not None: + params["logger"] = logger + await self.notify("notifications/message", _notification_params(params, meta)) + + async def send_tool_list_changed(self, *, meta: Meta | None = None) -> None: + await self.notify("notifications/tools/list_changed", _notification_params(None, meta)) + + async def send_prompt_list_changed(self, *, meta: Meta | None = None) -> None: + await self.notify("notifications/prompts/list_changed", _notification_params(None, meta)) + + async def send_resource_list_changed(self, *, meta: Meta | None = None) -> None: + await self.notify("notifications/resources/list_changed", _notification_params(None, meta)) + + async def send_resource_updated(self, uri: str, *, meta: Meta | None = None) -> None: + await self.notify("notifications/resources/updated", _notification_params({"uri": uri}, meta)) + + def check_capability(self, capability: ClientCapabilities) -> bool: + """Return whether the connected client declared the given capability. + + Returns ``False`` if ``initialize`` hasn't completed yet. + """ + # TODO: redesign — mirrors v1 ServerSession.check_client_capability + # verbatim for parity. See FOLLOWUPS.md. + if self.client_capabilities is None: + return False + have = self.client_capabilities + if capability.roots is not None: + if have.roots is None: + return False + if capability.roots.list_changed and not have.roots.list_changed: + return False + if capability.sampling is not None and have.sampling is None: + return False + if capability.elicitation is not None and have.elicitation is None: + return False + if capability.experimental is not None: + if have.experimental is None: + return False + for k in capability.experimental: + if k not in have.experimental: + return False + return True diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index d8e11d78b2..b7b97acf8b 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -5,10 +5,17 @@ from typing_extensions import TypeVar +from mcp.server._typed_request import TypedServerRequestMixin +from mcp.server.connection import Connection from mcp.server.experimental.request_context import Experimental from mcp.server.session import ServerSession from mcp.shared._context import RequestContext +from mcp.shared.context import BaseContext +from mcp.shared.dispatcher import DispatchContext from mcp.shared.message import CloseSSEStreamCallback +from mcp.shared.peer import Meta, PeerMixin +from mcp.shared.transport_context import TransportContext +from mcp.types import LoggingLevel, RequestParamsMeta LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any]) RequestT = TypeVar("RequestT", default=Any) @@ -21,3 +28,56 @@ class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContex request: RequestT | None = None close_sse_stream: CloseSSEStreamCallback | None = None close_standalone_sse_stream: CloseSSEStreamCallback | None = None + + +LifespanT = TypeVar("LifespanT", default=Any) +TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext) + + +class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Generic[LifespanT, TransportT]): + """Server-side per-request context. + + Composes `BaseContext` (forwards to `DispatchContext`, satisfies `Outbound`), + `PeerMixin` (kwarg-style ``sample``/``elicit_*``/``list_roots``/``ping``), + and `TypedServerRequestMixin` (typed ``send_request(req) -> Result``). Adds + ``lifespan`` and ``connection``. + + Constructed by `ServerRunner` (PR4) per inbound request and handed to the + user's handler. + """ + + def __init__( + self, + dctx: DispatchContext[TransportT], + *, + lifespan: LifespanT, + connection: Connection, + meta: RequestParamsMeta | None = None, + ) -> None: + super().__init__(dctx, meta=meta) + self._lifespan = lifespan + self._connection = connection + + @property + def lifespan(self) -> LifespanT: + """The server-wide lifespan output (what `Server(..., lifespan=...)` yielded).""" + return self._lifespan + + @property + def connection(self) -> Connection: + """The per-client `Connection` for this request's connection.""" + return self._connection + + async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: + """Send a request-scoped ``notifications/message`` log entry. + + Uses this request's back-channel (so the entry rides the request's SSE + stream in streamable HTTP), not the standalone stream — use + ``ctx.connection.log(...)`` for that. + """ + params: dict[str, Any] = {"level": level, "data": data} + if logger is not None: + params["logger"] = logger + if meta: + params["_meta"] = meta + await self.notify("notifications/message", params) diff --git a/src/mcp/shared/peer.py b/src/mcp/shared/peer.py index 9951081104..47b64c7769 100644 --- a/src/mcp/shared/peer.py +++ b/src/mcp/shared/peer.py @@ -1,9 +1,9 @@ """Typed MCP request sugar over an `Outbound`. `PeerMixin` defines the server-to-client request methods (sampling, elicitation, -roots, ping) once. Any class that satisfies `Outbound` (i.e. has `send_raw_request` -and `notify`) can mix it in and get the typed methods for free — `Context`, -`Connection`, `Client`, or the bare `Peer` wrapper below. +roots, ping) once. Any class that satisfies `Outbound` (i.e. has +``send_raw_request`` and ``notify``) can mix it in and get the typed methods for +free — `Context`, `Connection`, `Client`, or the bare `Peer` wrapper below. The mixin does no capability gating: it builds the params, calls ``self.send_raw_request(method, params)``, and parses the result into the typed @@ -32,11 +32,24 @@ ToolChoice, ) -__all__ = ["Peer", "PeerMixin"] +__all__ = ["Meta", "Peer", "PeerMixin", "dump_params"] +Meta = dict[str, Any] +"""Type alias for the ``_meta`` field carried on request/notification params.""" -def _dump(model: BaseModel) -> dict[str, Any]: - return model.model_dump(by_alias=True, mode="json", exclude_none=True) + +def dump_params(model: BaseModel | None, meta: Meta | None = None) -> dict[str, Any] | None: + """Serialize a params model to a wire dict, merging ``meta`` into ``_meta``. + + Shared by `PeerMixin`, `Connection`, and `TypedServerRequestMixin` so every + typed convenience method gets the same `_meta` handling. ``meta`` keys take + precedence over any ``_meta`` already present on the model. + """ + out = model.model_dump(by_alias=True, mode="json", exclude_none=True) if model is not None else None + if meta: + out = dict(out or {}) + out["_meta"] = {**out.get("_meta", {}), **meta} + return out class PeerMixin: @@ -61,6 +74,7 @@ async def sample( model_preferences: ModelPreferences | None = None, tools: None = None, tool_choice: ToolChoice | None = None, + meta: Meta | None = None, opts: CallOptions | None = None, ) -> CreateMessageResult: ... @overload @@ -77,6 +91,7 @@ async def sample( model_preferences: ModelPreferences | None = None, tools: list[Tool], tool_choice: ToolChoice | None = None, + meta: Meta | None = None, opts: CallOptions | None = None, ) -> CreateMessageResultWithTools: ... async def sample( @@ -92,6 +107,7 @@ async def sample( model_preferences: ModelPreferences | None = None, tools: list[Tool] | None = None, tool_choice: ToolChoice | None = None, + meta: Meta | None = None, opts: CallOptions | None = None, ) -> CreateMessageResult | CreateMessageResultWithTools: """Send a ``sampling/createMessage`` request to the peer. @@ -113,7 +129,7 @@ async def sample( tools=tools, tool_choice=tool_choice, ) - result = await self.send_raw_request("sampling/createMessage", _dump(params), opts) + result = await self.send_raw_request("sampling/createMessage", dump_params(params, meta), opts) if tools is not None: return CreateMessageResultWithTools.model_validate(result) return CreateMessageResult.model_validate(result) @@ -122,6 +138,8 @@ async def elicit_form( self: Outbound, message: str, requested_schema: ElicitRequestedSchema, + *, + meta: Meta | None = None, opts: CallOptions | None = None, ) -> ElicitResult: """Send a form-mode ``elicitation/create`` request. @@ -131,7 +149,7 @@ async def elicit_form( NoBackChannelError: No back-channel for server-initiated requests. """ params = ElicitRequestFormParams(message=message, requested_schema=requested_schema) - result = await self.send_raw_request("elicitation/create", _dump(params), opts) + result = await self.send_raw_request("elicitation/create", dump_params(params, meta), opts) return ElicitResult.model_validate(result) async def elicit_url( @@ -139,6 +157,8 @@ async def elicit_url( message: str, url: str, elicitation_id: str, + *, + meta: Meta | None = None, opts: CallOptions | None = None, ) -> ElicitResult: """Send a URL-mode ``elicitation/create`` request. @@ -148,27 +168,29 @@ async def elicit_url( NoBackChannelError: No back-channel for server-initiated requests. """ params = ElicitRequestURLParams(message=message, url=url, elicitation_id=elicitation_id) - result = await self.send_raw_request("elicitation/create", _dump(params), opts) + result = await self.send_raw_request("elicitation/create", dump_params(params, meta), opts) return ElicitResult.model_validate(result) - async def list_roots(self: Outbound, opts: CallOptions | None = None) -> ListRootsResult: + async def list_roots( + self: Outbound, *, meta: Meta | None = None, opts: CallOptions | None = None + ) -> ListRootsResult: """Send a ``roots/list`` request. Raises: MCPError: The peer responded with an error. NoBackChannelError: No back-channel for server-initiated requests. """ - result = await self.send_raw_request("roots/list", None, opts) + result = await self.send_raw_request("roots/list", dump_params(None, meta), opts) return ListRootsResult.model_validate(result) - async def ping(self: Outbound, opts: CallOptions | None = None) -> None: + async def ping(self: Outbound, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None: """Send a ``ping`` request and ignore the result. Raises: MCPError: The peer responded with an error. NoBackChannelError: No back-channel for server-initiated requests. """ - await self.send_raw_request("ping", None, opts) + await self.send_raw_request("ping", dump_params(None, meta), opts) class Peer(PeerMixin): diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py new file mode 100644 index 0000000000..eb3440085a --- /dev/null +++ b/tests/server/test_connection.py @@ -0,0 +1,184 @@ +"""Tests for `Connection`. + +`Connection` wraps an `Outbound` (the standalone stream). Its `notify` is +best-effort (never raises); `send_raw_request` is gated on +``has_standalone_channel``. Tested with a stub `Outbound` so we can assert wire +shape and inject failures. +""" + +import logging +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.server.connection import Connection +from mcp.shared.dispatcher import CallOptions +from mcp.shared.exceptions import NoBackChannelError +from mcp.types import ( + ClientCapabilities, + ElicitationCapability, + EmptyResult, + ListRootsRequest, + ListRootsResult, + PingRequest, + RootsCapability, + SamplingCapability, +) + + +class StubOutbound: + def __init__( + self, *, result: dict[str, Any] | None = None, raise_on_send: type[BaseException] | None = None + ) -> None: + self.requests: list[tuple[str, Mapping[str, Any] | None]] = [] + self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] + self._result = result if result is not None else {} + self._raise_on_send = raise_on_send + + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None + ) -> dict[str, Any]: + self.requests.append((method, params)) + return self._result + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + if self._raise_on_send is not None: + raise self._raise_on_send() + self.notifications.append((method, params)) + + +@pytest.mark.anyio +async def test_connection_notify_forwards_to_outbound(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.notify("notifications/message", {"level": "info", "data": "hi"}) + assert out.notifications == [("notifications/message", {"level": "info", "data": "hi"})] + + +@pytest.mark.anyio +async def test_connection_notify_swallows_broken_stream_and_debug_logs(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG, logger="mcp.server.connection") + out = StubOutbound(raise_on_send=anyio.BrokenResourceError) + conn = Connection(out, has_standalone_channel=True) + await conn.notify("notifications/message", {"data": "x"}) # must not raise + assert "stream closed" in caplog.text.lower() + + +@pytest.mark.anyio +async def test_connection_notify_drops_when_no_standalone_channel(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG, logger="mcp.server.connection") + out = StubOutbound() + conn = Connection(out, has_standalone_channel=False) + await conn.notify("notifications/message", {"data": "x"}) # must not raise + assert out.notifications == [] + assert "no standalone channel" in caplog.text.lower() + + +@pytest.mark.anyio +async def test_connection_send_raw_request_raises_nobackchannel_when_no_standalone_channel(): + conn = Connection(StubOutbound(), has_standalone_channel=False) + with pytest.raises(NoBackChannelError): + await conn.send_raw_request("ping", None) + + +@pytest.mark.anyio +async def test_connection_send_raw_request_forwards_when_standalone_channel_present(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + result = await conn.send_raw_request("ping", None) + assert out.requests == [("ping", None)] + assert result == {} + + +@pytest.mark.anyio +async def test_connection_send_request_with_spec_type_infers_result_type(): + out = StubOutbound(result={"roots": [{"uri": "file:///ws"}]}) + conn = Connection(out, has_standalone_channel=True) + result = await conn.send_request(ListRootsRequest()) + method, _ = out.requests[0] + assert method == "roots/list" + assert isinstance(result, ListRootsResult) + assert str(result.roots[0].uri) == "file:///ws" + + +@pytest.mark.anyio +async def test_connection_send_request_with_result_type_kwarg_validates_custom_type(): + out = StubOutbound(result={}) + conn = Connection(out, has_standalone_channel=True) + result = await conn.send_request(PingRequest(), result_type=EmptyResult) + assert isinstance(result, EmptyResult) + + +@pytest.mark.anyio +async def test_connection_ping_sends_ping_on_standalone(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.ping() + assert out.requests == [("ping", None)] + + +@pytest.mark.anyio +async def test_connection_log_sends_logging_message_notification(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.log("info", {"k": "v"}, logger="my.logger") + method, params = out.notifications[0] + assert method == "notifications/message" + assert params is not None + assert params["level"] == "info" + assert params["data"] == {"k": "v"} + assert params["logger"] == "my.logger" + + +@pytest.mark.anyio +async def test_connection_log_with_meta_includes_meta_in_params(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.log("info", "x", meta={"traceId": "abc"}) + _, params = out.notifications[0] + assert params is not None + assert params["_meta"] == {"traceId": "abc"} + + +@pytest.mark.anyio +async def test_connection_list_changed_notifications_send_correct_methods(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.send_tool_list_changed() + await conn.send_prompt_list_changed() + await conn.send_resource_list_changed() + await conn.send_resource_updated("file:///workspace/a.txt") + methods = [m for m, _ in out.notifications] + assert methods == [ + "notifications/tools/list_changed", + "notifications/prompts/list_changed", + "notifications/resources/list_changed", + "notifications/resources/updated", + ] + assert out.notifications[-1][1] == {"uri": "file:///workspace/a.txt"} + + +@pytest.mark.anyio +async def test_connection_send_tool_list_changed_with_meta_includes_meta_only_params(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.send_tool_list_changed(meta={"k": 1}) + assert out.notifications == [("notifications/tools/list_changed", {"_meta": {"k": 1}})] + + +def test_connection_check_capability_false_before_initialized(): + conn = Connection(StubOutbound(), has_standalone_channel=True) + assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is False + + +def test_connection_check_capability_true_when_client_declares_it(): + conn = Connection(StubOutbound(), has_standalone_channel=True) + conn.client_capabilities = ClientCapabilities( + sampling=SamplingCapability(), roots=RootsCapability(list_changed=True) + ) + conn.initialized.set() + assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is True + assert conn.check_capability(ClientCapabilities(roots=RootsCapability(list_changed=True))) is True + assert conn.check_capability(ClientCapabilities(elicitation=ElicitationCapability())) is False diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py new file mode 100644 index 0000000000..65db51c4a5 --- /dev/null +++ b/tests/server/test_server_context.py @@ -0,0 +1,131 @@ +"""Tests for the server-side `Context`. + +`Context` composes `BaseContext` (forwarding to a `DispatchContext`) with +`PeerMixin` (typed sample/elicit/roots/ping) plus `lifespan` and `connection`. +End-to-end tested over `DirectDispatcher`. +""" + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +import anyio +import pytest + +from mcp.server.connection import Connection +from mcp.server.context import Context +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.transport_context import TransportContext +from mcp.types import CreateMessageResult, ListRootsRequest, ListRootsResult, SamplingMessage, TextContent + +from ..shared.conftest import direct_pair +from ..shared.test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@dataclass +class _Lifespan: + name: str + + +@pytest.mark.anyio +async def test_context_exposes_lifespan_and_connection_and_forwards_base_context(): + captured: list[Context[_Lifespan, TransportContext]] = [] + conn = Connection.__new__(Connection) # placeholder until running_pair gives us the dispatcher + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan, TransportContext] = Context(dctx, lifespan=_Lifespan("app"), connection=conn) + captured.append(ctx) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, server, *_): + # Now we have the server dispatcher; build the real Connection bound to it. + conn.__init__(server, has_standalone_channel=True) + with anyio.fail_after(5): + await client.send_raw_request("t", None) + ctx = captured[0] + assert ctx.lifespan.name == "app" + assert ctx.connection is conn + assert ctx.transport.kind == "direct" + assert ctx.can_send_request is True + + +@pytest.mark.anyio +async def test_context_sample_round_trips_via_peer_mixin_on_base_context_outbound(): + crec = Recorder() + + async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + crec.requests.append((method, params)) + return {"role": "assistant", "content": {"type": "text", "text": "ok"}, "model": "m"} + + results: list[CreateMessageResult] = [] + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan, TransportContext] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + results.append( + await ctx.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="hi"))], + max_tokens=5, + ) + ) + return {} + + async with running_pair( + direct_pair, + server_on_request=server_on_request, + client_on_request=client_on_request, + ) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None) + assert crec.requests[0][0] == "sampling/createMessage" + assert isinstance(results[0], CreateMessageResult) + + +@pytest.mark.anyio +async def test_context_send_request_with_spec_type_infers_result_via_typed_mixin(): + async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + return {"roots": []} + + results: list[ListRootsResult] = [] + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan, TransportContext] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + results.append(await ctx.send_request(ListRootsRequest())) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request, client_on_request=client_on_request) as ( + client, + *_, + ): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + assert isinstance(results[0], ListRootsResult) + + +@pytest.mark.anyio +async def test_context_log_sends_request_scoped_message_notification(): + crec = Recorder() + _, c_notify = echo_handlers(crec) + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan, TransportContext] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + await ctx.log("debug", "hello") + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request, client_on_notify=c_notify) as ( + client, + *_, + ): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + await crec.notified.wait() + method, params = crec.notifications[0] + assert method == "notifications/message" + assert params is not None and params["level"] == "debug" and params["data"] == "hello" diff --git a/tests/shared/test_peer.py b/tests/shared/test_peer.py index 43d49252cb..0d7d9e9bae 100644 --- a/tests/shared/test_peer.py +++ b/tests/shared/test_peer.py @@ -12,7 +12,7 @@ import pytest from mcp.shared.dispatcher import DispatchContext -from mcp.shared.peer import Peer +from mcp.shared.peer import Peer, dump_params from mcp.shared.transport_context import TransportContext from mcp.types import ( CreateMessageResult, @@ -116,6 +116,25 @@ async def test_peer_list_roots_sends_roots_list_and_returns_typed_result(): assert str(result.roots[0].uri) == "file:///workspace" +@pytest.mark.anyio +async def test_peer_list_roots_with_meta_sends_meta_in_params(): + rec = _Recorder({"roots": []}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + await peer.list_roots(meta={"traceId": "t1"}) + method, params = rec.seen[0] + assert method == "roots/list" + assert params == {"_meta": {"traceId": "t1"}} + + +def test_dump_params_merges_meta_over_model_meta(): + out = dump_params(None, None) + assert out is None + out = dump_params(None, {"k": 1}) + assert out == {"_meta": {"k": 1}} + + @pytest.mark.anyio async def test_peer_ping_sends_ping_and_returns_none(): rec = _Recorder({}) From 551cacb5e210cf9a1b3780bbe0017862f9c28531 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 22:44:41 +0000 Subject: [PATCH 16/52] test: close PR3 coverage gaps to 100% - Connection.check_capability per-field branches (parametrized) - Context.log with logger and meta supplied - Peer.notify forwards to wrapped Outbound --- tests/server/test_connection.py | 21 +++++++++++++++++++++ tests/server/test_server_context.py | 25 +++++++++++++++++++++++++ tests/shared/test_peer.py | 17 +++++++++++++++++ 3 files changed, 63 insertions(+) diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py index eb3440085a..ded9dfd6ac 100644 --- a/tests/server/test_connection.py +++ b/tests/server/test_connection.py @@ -173,6 +173,27 @@ def test_connection_check_capability_false_before_initialized(): assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is False +@pytest.mark.parametrize( + ("have", "want", "expected"), + [ + (ClientCapabilities(roots=None), ClientCapabilities(roots=RootsCapability()), False), + ( + ClientCapabilities(roots=RootsCapability(list_changed=False)), + ClientCapabilities(roots=RootsCapability(list_changed=True)), + False, + ), + (ClientCapabilities(sampling=None), ClientCapabilities(sampling=SamplingCapability()), False), + (ClientCapabilities(experimental=None), ClientCapabilities(experimental={"a": {}}), False), + (ClientCapabilities(experimental={"a": {}}), ClientCapabilities(experimental={"b": {}}), False), + (ClientCapabilities(experimental={"a": {}}), ClientCapabilities(experimental={"a": {}}), True), + ], +) +def test_check_capability_per_field_branches(have: ClientCapabilities, want: ClientCapabilities, expected: bool): + conn = Connection(StubOutbound(), has_standalone_channel=True) + conn.client_capabilities = have + assert conn.check_capability(want) is expected + + def test_connection_check_capability_true_when_client_declares_it(): conn = Connection(StubOutbound(), has_standalone_channel=True) conn.client_capabilities = ClientCapabilities( diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py index 65db51c4a5..eb2df9b649 100644 --- a/tests/server/test_server_context.py +++ b/tests/server/test_server_context.py @@ -129,3 +129,28 @@ async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | method, params = crec.notifications[0] assert method == "notifications/message" assert params is not None and params["level"] == "debug" and params["data"] == "hello" + + +@pytest.mark.anyio +async def test_context_log_includes_logger_and_meta_when_supplied(): + crec = Recorder() + _, c_notify = echo_handlers(crec) + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan, TransportContext] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + await ctx.log("info", "x", logger="my.log", meta={"traceId": "t"}) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request, client_on_notify=c_notify) as ( + client, + *_, + ): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + await crec.notified.wait() + _, params = crec.notifications[0] + assert params is not None + assert params["logger"] == "my.log" + assert params["_meta"] == {"traceId": "t"} diff --git a/tests/shared/test_peer.py b/tests/shared/test_peer.py index 0d7d9e9bae..589994c818 100644 --- a/tests/shared/test_peer.py +++ b/tests/shared/test_peer.py @@ -135,6 +135,23 @@ def test_dump_params_merges_meta_over_model_meta(): assert out == {"_meta": {"k": 1}} +@pytest.mark.anyio +async def test_peer_notify_forwards_to_wrapped_outbound(): + sent: list[tuple[str, Mapping[str, Any] | None]] = [] + + class _Out: + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: Any = None + ) -> dict[str, Any]: + raise NotImplementedError + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + sent.append((method, params)) + + await Peer(_Out()).notify("n", {"x": 1}) + assert sent == [("n", {"x": 1})] + + @pytest.mark.anyio async def test_peer_ping_sends_ping_and_returns_none(): rec = _Recorder({}) From 445f99aeeaa8ecb2c17ffbeafccad2eb1cfecd08 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 20 Apr 2026 16:58:42 +0000 Subject: [PATCH 17/52] test: move asserts inside async-with for 3.11 coverage instrumentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit coverage.py on Python 3.11 doesn't record statements after an 'async with running_pair(...)' exit when there's a nested 'with anyio.fail_after()' inside. Same workaround as 0a8f0f4 in PR2 — move the asserts inside the async-with block. --- tests/server/test_server_context.py | 30 +++++++-------- tests/shared/test_context.py | 20 +++++----- tests/shared/test_peer.py | 60 ++++++++++++++--------------- 3 files changed, 55 insertions(+), 55 deletions(-) diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py index eb2df9b649..e01de34d33 100644 --- a/tests/server/test_server_context.py +++ b/tests/server/test_server_context.py @@ -44,11 +44,11 @@ async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | conn.__init__(server, has_standalone_channel=True) with anyio.fail_after(5): await client.send_raw_request("t", None) - ctx = captured[0] - assert ctx.lifespan.name == "app" - assert ctx.connection is conn - assert ctx.transport.kind == "direct" - assert ctx.can_send_request is True + ctx = captured[0] + assert ctx.lifespan.name == "app" + assert ctx.connection is conn + assert ctx.transport.kind == "direct" + assert ctx.can_send_request is True @pytest.mark.anyio @@ -80,8 +80,8 @@ async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | ) as (client, *_): with anyio.fail_after(5): await client.send_raw_request("tools/call", None) - assert crec.requests[0][0] == "sampling/createMessage" - assert isinstance(results[0], CreateMessageResult) + assert crec.requests[0][0] == "sampling/createMessage" + assert isinstance(results[0], CreateMessageResult) @pytest.mark.anyio @@ -104,7 +104,7 @@ async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | ): with anyio.fail_after(5): await client.send_raw_request("t", None) - assert isinstance(results[0], ListRootsResult) + assert isinstance(results[0], ListRootsResult) @pytest.mark.anyio @@ -126,9 +126,9 @@ async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | with anyio.fail_after(5): await client.send_raw_request("t", None) await crec.notified.wait() - method, params = crec.notifications[0] - assert method == "notifications/message" - assert params is not None and params["level"] == "debug" and params["data"] == "hello" + method, params = crec.notifications[0] + assert method == "notifications/message" + assert params is not None and params["level"] == "debug" and params["data"] == "hello" @pytest.mark.anyio @@ -150,7 +150,7 @@ async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | with anyio.fail_after(5): await client.send_raw_request("t", None) await crec.notified.wait() - _, params = crec.notifications[0] - assert params is not None - assert params["logger"] == "my.log" - assert params["_meta"] == {"traceId": "t"} + _, params = crec.notifications[0] + assert params is not None + assert params["logger"] == "my.log" + assert params["_meta"] == {"traceId": "t"} diff --git a/tests/shared/test_context.py b/tests/shared/test_context.py index 951690028f..882f90bfab 100644 --- a/tests/shared/test_context.py +++ b/tests/shared/test_context.py @@ -34,11 +34,11 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): await client.send_raw_request("t", None) - bctx = captured[0] - assert bctx.transport.kind == "direct" - assert isinstance(bctx.cancel_requested, anyio.Event) - assert bctx.can_send_request is True - assert bctx.meta is None + bctx = captured[0] + assert bctx.transport.kind == "direct" + assert isinstance(bctx.cancel_requested, anyio.Event) + assert bctx.can_send_request is True + assert bctx.meta is None @pytest.mark.anyio @@ -61,9 +61,9 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | with anyio.fail_after(5): result = await client.send_raw_request("tools/call", None) await crec.notified.wait() - assert crec.requests == [("sampling/createMessage", {"x": 1})] - assert crec.notifications == [("notifications/message", {"level": "info"})] - assert result["sample"] == {"echoed": "sampling/createMessage", "params": {"x": 1}} + assert crec.requests == [("sampling/createMessage", {"x": 1})] + assert crec.notifications == [("notifications/message", {"level": "info"})] + assert result["sample"] == {"echoed": "sampling/createMessage", "params": {"x": 1}} @pytest.mark.anyio @@ -81,7 +81,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): await client.send_raw_request("t", None, {"on_progress": on_progress}) - assert received == [(0.5, 1.0, "halfway")] + assert received == [(0.5, 1.0, "halfway")] @pytest.mark.anyio @@ -100,7 +100,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | ) as (client, *_): with anyio.fail_after(5): await client.send_raw_request("t", None) - assert crec.requests == [("ping", None)] + assert crec.requests == [("ping", None)] @pytest.mark.anyio diff --git a/tests/shared/test_peer.py b/tests/shared/test_peer.py index 589994c818..0be4225818 100644 --- a/tests/shared/test_peer.py +++ b/tests/shared/test_peer.py @@ -50,11 +50,11 @@ async def test_peer_sample_sends_create_message_and_returns_typed_result(): [SamplingMessage(role="user", content=TextContent(type="text", text="hello"))], max_tokens=10, ) - method, params = rec.seen[0] - assert method == "sampling/createMessage" - assert params is not None and params["maxTokens"] == 10 - assert isinstance(result, CreateMessageResult) - assert result.model == "m" + method, params = rec.seen[0] + assert method == "sampling/createMessage" + assert params is not None and params["maxTokens"] == 10 + assert isinstance(result, CreateMessageResult) + assert result.model == "m" @pytest.mark.anyio @@ -68,10 +68,10 @@ async def test_peer_sample_with_tools_returns_with_tools_result(): max_tokens=5, tools=[Tool(name="t", input_schema={"type": "object"})], ) - method, params = rec.seen[0] - assert method == "sampling/createMessage" - assert params is not None and params["tools"][0]["name"] == "t" - assert isinstance(result, CreateMessageResultWithTools) + method, params = rec.seen[0] + assert method == "sampling/createMessage" + assert params is not None and params["tools"][0]["name"] == "t" + assert isinstance(result, CreateMessageResultWithTools) @pytest.mark.anyio @@ -81,11 +81,11 @@ async def test_peer_elicit_form_sends_elicitation_create_with_form_params(): peer = Peer(client) with anyio.fail_after(5): result = await peer.elicit_form("Your name?", requested_schema={"type": "object", "properties": {}}) - method, params = rec.seen[0] - assert method == "elicitation/create" - assert params is not None and params["mode"] == "form" - assert params["message"] == "Your name?" - assert isinstance(result, ElicitResult) + method, params = rec.seen[0] + assert method == "elicitation/create" + assert params is not None and params["mode"] == "form" + assert params["message"] == "Your name?" + assert isinstance(result, ElicitResult) @pytest.mark.anyio @@ -95,11 +95,11 @@ async def test_peer_elicit_url_sends_elicitation_create_with_url_params(): peer = Peer(client) with anyio.fail_after(5): result = await peer.elicit_url("Auth needed", url="https://example.com/auth", elicitation_id="e1") - method, params = rec.seen[0] - assert method == "elicitation/create" - assert params is not None and params["mode"] == "url" - assert params["url"] == "https://example.com/auth" - assert isinstance(result, ElicitResult) + method, params = rec.seen[0] + assert method == "elicitation/create" + assert params is not None and params["mode"] == "url" + assert params["url"] == "https://example.com/auth" + assert isinstance(result, ElicitResult) @pytest.mark.anyio @@ -109,11 +109,11 @@ async def test_peer_list_roots_sends_roots_list_and_returns_typed_result(): peer = Peer(client) with anyio.fail_after(5): result = await peer.list_roots() - method, _ = rec.seen[0] - assert method == "roots/list" - assert isinstance(result, ListRootsResult) - assert len(result.roots) == 1 - assert str(result.roots[0].uri) == "file:///workspace" + method, _ = rec.seen[0] + assert method == "roots/list" + assert isinstance(result, ListRootsResult) + assert len(result.roots) == 1 + assert str(result.roots[0].uri) == "file:///workspace" @pytest.mark.anyio @@ -123,9 +123,9 @@ async def test_peer_list_roots_with_meta_sends_meta_in_params(): peer = Peer(client) with anyio.fail_after(5): await peer.list_roots(meta={"traceId": "t1"}) - method, params = rec.seen[0] - assert method == "roots/list" - assert params == {"_meta": {"traceId": "t1"}} + method, params = rec.seen[0] + assert method == "roots/list" + assert params == {"_meta": {"traceId": "t1"}} def test_dump_params_merges_meta_over_model_meta(): @@ -159,6 +159,6 @@ async def test_peer_ping_sends_ping_and_returns_none(): peer = Peer(client) with anyio.fail_after(5): result = await peer.ping() - method, _ = rec.seen[0] - assert method == "ping" - assert result is None + method, _ = rec.seen[0] + assert method == "ping" + assert result is None From 63be3e474244ceb3493bc2705c4066d95ce18f12 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 20 Apr 2026 17:04:07 +0000 Subject: [PATCH 18/52] docs: drop development-journal language from docstrings/comments Remove references to PR numbers, internal scratch notes, and design-spike shorthand that won't make sense to a fresh reader of the codebase. --- src/mcp/server/_typed_request.py | 9 +++++---- src/mcp/server/connection.py | 4 ++-- src/mcp/server/context.py | 4 ++-- src/mcp/shared/context.py | 4 ++-- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/mcp/server/_typed_request.py b/src/mcp/server/_typed_request.py index 50cae159d1..4334b20a94 100644 --- a/src/mcp/server/_typed_request.py +++ b/src/mcp/server/_typed_request.py @@ -1,13 +1,14 @@ -"""Shape-2 typed ``send_request`` for server-to-client requests. +"""Typed ``send_request`` for server-to-client requests. `TypedServerRequestMixin` provides a typed `send_request(req) -> Result` over the host's raw `Outbound.send_raw_request`. Spec server-to-client request types have their result type inferred via per-type overloads; custom requests pass ``result_type=`` explicitly. -A `HasResult[R]` protocol (one generic signature, mapping declared on the -request type) is the cleaner long-term shape — see FOLLOWUPS.md. This per-spec -overload set is used for now to avoid touching `mcp.types`. +If the spec's request set grows substantially, consider declaring the result +mapping on the request types themselves (a ``__mcp_result__`` ClassVar read via +a structural protocol) so this overload ladder doesn't need maintaining +per-host-class. """ from typing import Any, TypeVar, overload diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py index 72c4ed062f..df3652ce0e 100644 --- a/src/mcp/server/connection.py +++ b/src/mcp/server/connection.py @@ -53,7 +53,7 @@ def __init__(self, outbound: Outbound, *, has_standalone_channel: bool) -> None: self.protocol_version: str | None = None self.initialized: anyio.Event = anyio.Event() # TODO: make this generic (Connection[StateT]) once connection_lifespan - # wiring lands in ServerRunner — see FOLLOWUPS.md. + # wiring lands in ServerRunner. self.state: Any = None async def send_raw_request( @@ -124,7 +124,7 @@ def check_capability(self, capability: ClientCapabilities) -> bool: Returns ``False`` if ``initialize`` hasn't completed yet. """ # TODO: redesign — mirrors v1 ServerSession.check_client_capability - # verbatim for parity. See FOLLOWUPS.md. + # verbatim for parity. if self.client_capabilities is None: return False have = self.client_capabilities diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index b7b97acf8b..4f0cffd9ad 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -42,8 +42,8 @@ class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Gener and `TypedServerRequestMixin` (typed ``send_request(req) -> Result``). Adds ``lifespan`` and ``connection``. - Constructed by `ServerRunner` (PR4) per inbound request and handed to the - user's handler. + Constructed by `ServerRunner` per inbound request and handed to the user's + handler. """ def __init__( diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 68f439b738..38ca8bd9b4 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -28,8 +28,8 @@ class BaseContext(Generic[TransportT]): """Per-request context wrapping a `DispatchContext`. - `ServerRunner` (PR4) constructs one per inbound request and passes it to - the user's handler. + `ServerRunner` constructs one per inbound request and passes it to the + user's handler. """ def __init__(self, dctx: DispatchContext[TransportT], meta: RequestParamsMeta | None = None) -> None: From 786bc55a83a47f04f79cc8c9562d1aad6580bfbd Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 20 Apr 2026 19:01:09 +0000 Subject: [PATCH 19/52] refactor: make BaseContext/Context covariant in their type params MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit LifespanT and TransportT are only exposed via read-only properties (lifespan, transport), so covariance is sound. This lets a Context[AppState, HttpTC] be passed where a Context[object, TransportContext] is expected — needed for ServerRunner's middleware chain to compose without casts, and for reusable middleware to be typed Context[object, TransportContext] instead of relying on Any-slack. --- src/mcp/server/context.py | 4 ++-- src/mcp/shared/context.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 4f0cffd9ad..4d35f8a902 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -30,8 +30,8 @@ class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContex close_standalone_sse_stream: CloseSSEStreamCallback | None = None -LifespanT = TypeVar("LifespanT", default=Any) -TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext) +LifespanT = TypeVar("LifespanT", default=Any, covariant=True) +TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True) class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Generic[LifespanT, TransportT]): diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 38ca8bd9b4..ff69c48401 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -22,7 +22,7 @@ __all__ = ["BaseContext"] -TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext) +TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True) class BaseContext(Generic[TransportT]): From 958bdd71a1b31055db42610d8f86c695a717e06e Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 20 Apr 2026 18:18:20 +0000 Subject: [PATCH 20/52] =?UTF-8?q?feat:=20ServerRunner=20skeleton=20?= =?UTF-8?q?=E2=80=94=20=5Fon=5Frequest,=20initialize,=20init-gate?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ServerRunner is the per-connection orchestrator over a Dispatcher. This commit lands the skeleton: ServerRegistry Protocol, _on_request (lookup → validate → build Context → call handler → dump), _handle_initialize (populates Connection, opens the init-gate), and a basic _on_notify. Additive methods on lowlevel Server (get_request_handler / get_notification_handler / middleware / connection_lifespan) so it satisfies ServerRegistry without touching the existing run() path. _PARAMS_FOR_METHOD is scaffolding (marked TODO) until the registry stores params types directly. 5 tests over DirectDispatcher + a real lowlevel Server. --- src/mcp/server/lowlevel/server.py | 20 +++ src/mcp/server/runner.py | 218 ++++++++++++++++++++++++++++++ tests/server/test_runner.py | 154 +++++++++++++++++++++ 3 files changed, 392 insertions(+) create mode 100644 src/mcp/server/runner.py create mode 100644 tests/server/test_runner.py diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 5e4e2e6f5b..de12832dc5 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -246,6 +246,26 @@ def _has_handler(self, method: str) -> bool: """Check if a handler is registered for the given method.""" return method in self._request_handlers or method in self._notification_handlers + # --- ServerRegistry protocol (consumed by ServerRunner) ------------------ + + def get_request_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: + """Return the handler for a request method, or ``None``.""" + return self._request_handlers.get(method) + + def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: + """Return the handler for a notification method, or ``None``.""" + return self._notification_handlers.get(method) + + @property + def middleware(self) -> list[Any]: + """Context-tier middleware. Empty until the registry refactor adds registration.""" + return [] + + @property + def connection_lifespan(self) -> None: + """Per-connection lifespan. ``None`` until the registry refactor adds it.""" + return None + # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py new file mode 100644 index 0000000000..66093b250a --- /dev/null +++ b/src/mcp/server/runner.py @@ -0,0 +1,218 @@ +"""`ServerRunner` — per-connection orchestrator over a `Dispatcher`. + +`ServerRunner` is the bridge between the dispatcher layer (`on_request` / +`on_notify`, untyped dicts) and the user's handler layer (typed `Context`, +typed params). One instance per client connection. It: + +* handles the ``initialize`` handshake and populates `Connection` +* gates requests until initialized (``ping`` exempt) +* looks up the handler in the server's registry, validates params, builds + `Context`, runs the middleware chain, returns the result dict +* drives ``dispatcher.run()`` and the per-connection lifespan + +`ServerRunner` consumes any `ServerRegistry` — the lowlevel `Server` satisfies +it via additive methods so the existing ``Server.run()`` path is unaffected. +""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass, field +from typing import Any, Generic, Protocol, cast + +from pydantic import BaseModel +from typing_extensions import TypeVar + +from mcp.server.connection import Connection +from mcp.server.context import Context +from mcp.server.lowlevel.server import NotificationOptions +from mcp.shared.dispatcher import DispatchContext, Dispatcher +from mcp.shared.exceptions import MCPError +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + INVALID_REQUEST, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + CallToolRequestParams, + CompleteRequestParams, + GetPromptRequestParams, + Implementation, + InitializeRequestParams, + InitializeResult, + NotificationParams, + PaginatedRequestParams, + ProgressNotificationParams, + ReadResourceRequestParams, + RequestParams, + ServerCapabilities, + SetLevelRequestParams, + SubscribeRequestParams, + UnsubscribeRequestParams, +) + +__all__ = ["ServerRegistry", "ServerRunner"] + +logger = logging.getLogger(__name__) + +LifespanT = TypeVar("LifespanT", default=Any) +ServerTransportT = TypeVar("ServerTransportT", bound=TransportContext, default=TransportContext) + +Handler = Callable[..., Awaitable[Any]] +"""A request/notification handler: ``(ctx, params) -> result``. Typed loosely +so the existing `ServerRequestContext`-based handlers and the new +`Context`-based handlers both fit during the transition. +""" + +_INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) + +# TODO: remove this lookup once `Server` stores (params_type, handler) in its +# registry directly. This is scaffolding so ServerRunner can validate params +# without changing the existing `_request_handlers` dict shape. +_PARAMS_FOR_METHOD: dict[str, type[BaseModel]] = { + "ping": RequestParams, + "tools/list": PaginatedRequestParams, + "tools/call": CallToolRequestParams, + "prompts/list": PaginatedRequestParams, + "prompts/get": GetPromptRequestParams, + "resources/list": PaginatedRequestParams, + "resources/templates/list": PaginatedRequestParams, + "resources/read": ReadResourceRequestParams, + "resources/subscribe": SubscribeRequestParams, + "resources/unsubscribe": UnsubscribeRequestParams, + "logging/setLevel": SetLevelRequestParams, + "completion/complete": CompleteRequestParams, +} +"""Spec method → params model. Scaffolding while the lowlevel `Server`'s +`_request_handlers` stores handler-only; the registry refactor should make this +the registry's responsibility (or store params types alongside handlers).""" + +_PARAMS_FOR_NOTIFICATION: dict[str, type[BaseModel]] = { + "notifications/initialized": NotificationParams, + "notifications/roots/list_changed": NotificationParams, + "notifications/progress": ProgressNotificationParams, +} + + +class ServerRegistry(Protocol): + """The handler registry `ServerRunner` consumes. + + The lowlevel `Server` satisfies this via additive methods. + """ + + @property + def name(self) -> str: ... + @property + def version(self) -> str | None: ... + + def get_request_handler(self, method: str) -> Handler | None: ... + def get_notification_handler(self, method: str) -> Handler | None: ... + def get_capabilities( + self, notification_options: Any, experimental_capabilities: dict[str, dict[str, Any]] + ) -> ServerCapabilities: ... + + +def _dump_result(result: Any) -> dict[str, Any]: + if result is None: + return {} + if isinstance(result, BaseModel): + return result.model_dump(by_alias=True, mode="json", exclude_none=True) + if isinstance(result, dict): + return cast(dict[str, Any], result) + raise TypeError(f"handler returned {type(result).__name__}; expected BaseModel, dict, or None") + + +@dataclass +class ServerRunner(Generic[LifespanT, ServerTransportT]): + """Per-connection orchestrator. One instance per client connection.""" + + server: ServerRegistry + dispatcher: Dispatcher[ServerTransportT] + lifespan_state: LifespanT + has_standalone_channel: bool + stateless: bool = False + + connection: Connection = field(init=False) + _initialized: bool = field(init=False) + + def __post_init__(self) -> None: + self._initialized = self.stateless + self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel) + + async def _on_request( + self, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, + ) -> dict[str, Any]: + if method == "initialize": + return self._handle_initialize(params) + if not self._initialized and method not in _INIT_EXEMPT: + raise MCPError( + code=INVALID_REQUEST, + message=f"Received {method!r} before initialization was complete", + ) + handler = self.server.get_request_handler(method) + if handler is None: + raise MCPError(code=METHOD_NOT_FOUND, message=f"Method not found: {method}") + # TODO: scaffolding — params_type comes from a static lookup until the + # registry stores it alongside the handler. + params_type = _PARAMS_FOR_METHOD.get(method, RequestParams) + # ValidationError propagates; the dispatcher's exception boundary maps + # it to INVALID_PARAMS. + typed_params = params_type.model_validate(params or {}) + ctx = self._make_context(dctx, typed_params) + result = await handler(ctx, typed_params) + return _dump_result(result) + + async def _on_notify( + self, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, + ) -> None: + if method == "notifications/initialized": + self._initialized = True + self.connection.initialized.set() + return + if not self._initialized: + logger.debug("dropped %s: received before initialization", method) + return + handler = self.server.get_notification_handler(method) + if handler is None: + logger.debug("no handler for notification %s", method) + return + params_type = _PARAMS_FOR_NOTIFICATION.get(method, NotificationParams) + typed_params = params_type.model_validate(params or {}) + ctx = self._make_context(dctx, typed_params) + await handler(ctx, typed_params) + + def _make_context( + self, dctx: DispatchContext[TransportContext], typed_params: BaseModel + ) -> Context[LifespanT, ServerTransportT]: + # `OnRequest` delivers `DispatchContext[TransportContext]`; this + # ServerRunner instance was constructed for a specific + # `ServerTransportT`, so the narrow is safe by construction. + narrowed = cast(DispatchContext[ServerTransportT], dctx) + meta = getattr(typed_params, "meta", None) + return Context(narrowed, lifespan=self.lifespan_state, connection=self.connection, meta=meta) + + def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]: + init = InitializeRequestParams.model_validate(params or {}) + self.connection.client_info = init.client_info + self.connection.client_capabilities = init.capabilities + # TODO: real version negotiation. This always responds with LATEST, + # which is wrong — the server should pick the highest version both + # sides support and compute a per-connection feature set from it. + # See FOLLOWUPS: "Consolidate per-connection mode/negotiation". + self.connection.protocol_version = ( + init.protocol_version if init.protocol_version in {LATEST_PROTOCOL_VERSION} else LATEST_PROTOCOL_VERSION + ) + self._initialized = True + self.connection.initialized.set() + result = InitializeResult( + protocol_version=self.connection.protocol_version, + capabilities=self.server.get_capabilities(NotificationOptions(), {}), + server_info=Implementation(name=self.server.name, version=self.server.version or "0.0.0"), + ) + return _dump_result(result) diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py new file mode 100644 index 0000000000..5bff4b2888 --- /dev/null +++ b/tests/server/test_runner.py @@ -0,0 +1,154 @@ +"""Tests for `ServerRunner`. + +End-to-end over `DirectDispatcher` with a real lowlevel `Server` as the +registry. Covers `_on_request` routing, the initialize handshake, the +init-gate, and that handlers receive a fully-built `Context`. +""" + +from typing import Any + +import anyio +import pytest + +from mcp.server.connection import Connection +from mcp.server.context import Context +from mcp.server.lowlevel.server import Server +from mcp.server.runner import ServerRunner +from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +from mcp.shared.exceptions import MCPError +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + INVALID_REQUEST, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + ClientCapabilities, + Implementation, + InitializeRequestParams, + Tool, +) + +from ..shared.test_dispatcher import Recorder, echo_handlers + + +def _initialize_params() -> dict[str, Any]: + return InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test-client", version="1.0"), + ).model_dump(by_alias=True, exclude_none=True) + + +_seen_ctx: list[Context[Any, TransportContext]] = [] +SrvT = Server[dict[str, Any]] + + +@pytest.fixture +def server() -> SrvT: + """A lowlevel Server with one tools/list handler registered.""" + _seen_ctx.clear() + + async def list_tools(ctx: Any, params: Any) -> Any: + # ctx is typed `Any` because Server's on_list_tools kwarg expects the + # legacy ServerRequestContext shape; ServerRunner passes the new + # `Context`. The transition is intentional — Handler is loosely typed. + _seen_ctx.append(ctx) + return {"tools": [Tool(name="t", input_schema={"type": "object"}).model_dump(by_alias=True)]} + + return Server(name="test-server", version="0.0.1", on_list_tools=list_tools) + + +@pytest.mark.anyio +async def test_runner_handles_initialize_and_populates_connection(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=True, + ) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + result = await client.send_raw_request("initialize", _initialize_params()) + assert result["serverInfo"]["name"] == "test-server" + assert "tools" in result["capabilities"] + assert runner.connection.client_info is not None + assert runner.connection.client_info.name == "test-client" + assert runner.connection.protocol_version == LATEST_PROTOCOL_VERSION + assert runner._initialized is True + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_gates_requests_before_initialize(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INVALID_REQUEST + # ping is exempt + assert await client.send_raw_request("ping", None) == {} + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_routes_to_handler_after_initialize_and_builds_context(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + ctx = _seen_ctx[0] + assert isinstance(ctx, Context) + assert ctx.lifespan is None + assert isinstance(ctx.connection, Connection) + assert ctx.transport.kind == "direct" + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_unknown_method_raises_method_not_found(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + runner._initialized = True # bypass gate for this test + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("nonexistent/method", None) + assert exc.value.error.code == METHOD_NOT_FOUND + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_stateless_skips_init_gate(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=False, + stateless=True, + ) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + tg.cancel_scope.cancel() From fb8105621596ce2712fe4abf1a2ba86517debb57 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 20 Apr 2026 19:02:57 +0000 Subject: [PATCH 21/52] feat: ServerRunner middleware (two-tier) + _on_notify ContextMiddleware is a Protocol[L] (contravariant) so Server[L].middleware: list[ContextMiddleware[L]] is properly typed. App-specific middleware sees ctx.lifespan: L; reusable middleware typed ContextMiddleware[object] registers on any Server via contravariance. Context's covariance (previous PR3 commit) makes Context[L, ST] <: Context[L, TransportContext] so the chain composes without casts. dispatch_middleware (DispatchMiddleware list on ServerRunner) wraps the raw _on_request and sees everything including initialize/METHOD_NOT_FOUND. server.middleware (ContextMiddleware) runs inside _on_request after validation/ctx-build and wraps registered handlers only. _on_notify routes notifications/initialized (sets the flag), drops before-init and unknown methods, otherwise builds Context and calls the registered handler. 11 tests over DirectDispatcher + a real lowlevel Server. --- src/mcp/server/context.py | 36 +++++++- src/mcp/server/lowlevel/server.py | 10 +-- src/mcp/server/runner.py | 30 +++++-- tests/server/test_runner.py | 138 ++++++++++++++++++++++++++++++ 4 files changed, 201 insertions(+), 13 deletions(-) diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 4d35f8a902..1c855ae48a 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -1,8 +1,10 @@ from __future__ import annotations +from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import Any, Generic +from typing import Any, Generic, Protocol +from pydantic import BaseModel from typing_extensions import TypeVar from mcp.server._typed_request import TypedServerRequestMixin @@ -81,3 +83,35 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, * if meta: params["_meta"] = meta await self.notify("notifications/message", params) + + +HandlerResult = BaseModel | dict[str, Any] | None +"""What a request handler (or middleware) may return. `ServerRunner` serializes +all three to a result dict.""" + +CallNext = Callable[[], Awaitable[HandlerResult]] + +_MwLifespanT = TypeVar("_MwLifespanT", contravariant=True) + + +class ContextMiddleware(Protocol[_MwLifespanT]): + """Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``. + + Runs *inside* `ServerRunner._on_request` after params validation and + `Context` construction. Wraps registered handlers (including ``ping``) but + not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed + outermost-first on `Server.middleware`. + + `Server[L].middleware` holds `ContextMiddleware[L]`, so an app-specific + middleware sees `ctx.lifespan: L`. A reusable middleware (no app-specific + types) can be typed `ContextMiddleware[object]` — `Context` is covariant in + `LifespanT`, so it registers on any `Server[L]`. + """ + + async def __call__( + self, + ctx: Context[_MwLifespanT, TransportContext], + method: str, + params: BaseModel, + call_next: CallNext, + ) -> HandlerResult: ... diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index de12832dc5..466c158bd4 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -58,7 +58,7 @@ async def main(): from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings -from mcp.server.context import ServerRequestContext +from mcp.server.context import ContextMiddleware, ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.models import InitializationOptions @@ -199,6 +199,9 @@ def __init__( ] = {} self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None + # Context-tier middleware consumed by `ServerRunner`. Additive; the + # existing `run()` path ignores it. + self.middleware: list[ContextMiddleware[LifespanResultT]] = [] logger.debug("Initializing server %r", name) # Populate internal handler dicts from on_* kwargs @@ -256,11 +259,6 @@ def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] """Return the handler for a notification method, or ``None``.""" return self._notification_handlers.get(method) - @property - def middleware(self) -> list[Any]: - """Context-tier middleware. Empty until the registry refactor adds registration.""" - return [] - @property def connection_lifespan(self) -> None: """Per-connection lifespan. ``None`` until the registry refactor adds it.""" diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 66093b250a..a7dae289f7 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -17,17 +17,18 @@ from __future__ import annotations import logging -from collections.abc import Awaitable, Callable, Mapping +from collections.abc import Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass, field +from functools import partial, reduce from typing import Any, Generic, Protocol, cast from pydantic import BaseModel from typing_extensions import TypeVar from mcp.server.connection import Connection -from mcp.server.context import Context +from mcp.server.context import CallNext, Context, ContextMiddleware from mcp.server.lowlevel.server import NotificationOptions -from mcp.shared.dispatcher import DispatchContext, Dispatcher +from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext from mcp.types import ( @@ -51,7 +52,7 @@ UnsubscribeRequestParams, ) -__all__ = ["ServerRegistry", "ServerRunner"] +__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner"] logger = logging.getLogger(__name__) @@ -64,6 +65,7 @@ `Context`-based handlers both fit during the transition. """ + _INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) # TODO: remove this lookup once `Server` stores (params_type, handler) in its @@ -105,6 +107,9 @@ def name(self) -> str: ... @property def version(self) -> str | None: ... + @property + def middleware(self) -> Sequence[ContextMiddleware[Any]]: ... + def get_request_handler(self, method: str) -> Handler | None: ... def get_notification_handler(self, method: str) -> Handler | None: ... def get_capabilities( @@ -131,6 +136,7 @@ class ServerRunner(Generic[LifespanT, ServerTransportT]): lifespan_state: LifespanT has_standalone_channel: bool stateless: bool = False + dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware]) connection: Connection = field(init=False) _initialized: bool = field(init=False) @@ -139,6 +145,16 @@ def __post_init__(self) -> None: self._initialized = self.stateless self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel) + def _compose_on_request(self) -> OnRequest: + """Wrap `_on_request` in `dispatch_middleware`, outermost-first. + + Dispatch-tier middleware sees raw ``(dctx, method, params) -> dict`` + and wraps everything — initialize, METHOD_NOT_FOUND, validation + failures included. `run()` calls this once and hands the result to + `dispatcher.run()`. + """ + return reduce(lambda h, mw: mw(h), reversed(self.dispatch_middleware), self._on_request) + async def _on_request( self, dctx: DispatchContext[TransportContext], @@ -162,8 +178,10 @@ async def _on_request( # it to INVALID_PARAMS. typed_params = params_type.model_validate(params or {}) ctx = self._make_context(dctx, typed_params) - result = await handler(ctx, typed_params) - return _dump_result(result) + call: CallNext = partial(handler, ctx, typed_params) + for mw in reversed(self.server.middleware): + call = partial(mw, ctx, method, typed_params, call) + return _dump_result(await call()) async def _on_notify( self, diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 5bff4b2888..eca10497c5 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -8,6 +8,7 @@ from typing import Any import anyio +import anyio.lowlevel import pytest from mcp.server.connection import Connection @@ -134,6 +135,143 @@ async def test_runner_unknown_method_raises_method_not_found(server: SrvT): tg.cancel_scope.cancel() +@pytest.mark.anyio +async def test_runner_on_notify_initialized_sets_flag_and_connection_event(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.notify("notifications/initialized", None) + await runner.connection.initialized.wait() + assert runner._initialized is True + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_on_notify_routes_to_registered_handler(server: SrvT): + seen: list[tuple[Any, Any]] = [] + + async def on_roots_changed(ctx: Any, params: Any) -> None: + seen.append((ctx, params)) + + server._notification_handlers["notifications/roots/list_changed"] = on_roots_changed + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + runner._initialized = True + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.notify("notifications/roots/list_changed", None) + # DirectDispatcher delivers synchronously; one yield is enough. + await anyio.lowlevel.checkpoint() + assert len(seen) == 1 + assert isinstance(seen[0][0], Context) + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_on_notify_drops_before_init_and_unknown_methods(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.notify("notifications/roots/list_changed", None) # before init: dropped + await client.notify("notifications/initialized", None) + await client.notify("notifications/unknown", None) # no handler: dropped + # No exception raised; both drops are silent. + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_dispatch_middleware_wraps_everything_including_initialize(server: SrvT): + seen_methods: list[str] = [] + + def trace_mw(next_on_request: Any) -> Any: + async def wrapped(dctx: Any, method: str, params: Any) -> Any: + seen_methods.append(method) + return await next_on_request(dctx, method, params) + + return wrapped + + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=True, + dispatch_middleware=[trace_mw], + ) + c_req, c_notify = echo_handlers(Recorder()) + on_req = runner._compose_on_request() + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, on_req, runner._on_notify) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + await client.send_raw_request("tools/list", None) + assert seen_methods == ["initialize", "tools/list"] + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_server_middleware_wraps_handlers_but_not_initialize(server: SrvT): + seen_methods: list[str] = [] + + async def ctx_mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: + seen_methods.append(method) + return await call_next() + + server.middleware.append(ctx_mw) + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + await client.send_raw_request("ping", None) + await client.send_raw_request("tools/list", None) + # initialize NOT wrapped; ping and tools/list ARE wrapped. + assert seen_methods == ["ping", "tools/list"] + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_server_middleware_runs_outermost_first(server: SrvT): + order: list[str] = [] + + def make_mw(tag: str) -> Any: + async def mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: + order.append(f"{tag}-in") + result = await call_next() + order.append(f"{tag}-out") + return result + + return mw + + server.middleware.extend([make_mw("a"), make_mw("b")]) + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + runner._initialized = True + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.send_raw_request("tools/list", None) + assert order == ["a-in", "b-in", "b-out", "a-out"] + tg.cancel_scope.cancel() + + @pytest.mark.anyio async def test_runner_stateless_skips_init_gate(server: SrvT): client, server_d = create_direct_dispatcher_pair() From 954874b57f20eb928b4a705ba40b1fd091fa4c61 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 22 Apr 2026 01:28:51 +0000 Subject: [PATCH 22/52] feat: ServerRunner.run() and otel_middleware run() composes dispatch_middleware over _on_request and forwards task_status to dispatcher.run() so callers can 'await tg.start(runner.run)'. otel_middleware is a DispatchMiddleware that wraps each request in a span, mirroring the existing Server._handle_request span shape: name 'MCP handle []', mcp.method.name attribute, W3C trace context extracted from params._meta (SEP-414), and ERROR status if the handler raises. connection_lifespan plumbing (the enter-late dance) is deferred to a separate commit since Server.connection_lifespan is None today. --- src/mcp/server/runner.py | 53 ++++++++++++++++++++++++++- tests/server/test_runner.py | 71 ++++++++++++++++++++++++++++++++++++- 2 files changed, 122 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index a7dae289f7..79dfc23e0e 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -22,12 +22,15 @@ from functools import partial, reduce from typing import Any, Generic, Protocol, cast +import anyio.abc +from opentelemetry.trace import SpanKind, StatusCode from pydantic import BaseModel from typing_extensions import TypeVar from mcp.server.connection import Connection from mcp.server.context import CallNext, Context, ContextMiddleware from mcp.server.lowlevel.server import NotificationOptions +from mcp.shared._otel import extract_trace_context, otel_span from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext @@ -52,7 +55,7 @@ UnsubscribeRequestParams, ) -__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner"] +__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner", "otel_middleware"] logger = logging.getLogger(__name__) @@ -117,6 +120,44 @@ def get_capabilities( ) -> ServerCapabilities: ... +def otel_middleware(next_on_request: OnRequest) -> OnRequest: + """Dispatch-tier middleware that wraps each request in an OpenTelemetry span. + + Mirrors the span shape of the existing `Server._handle_request`: span name + ``"MCP handle []"``, ``mcp.method.name`` attribute, W3C + trace context extracted from ``params._meta`` (SEP-414), and an ERROR + status if the handler raises. + """ + + async def wrapped( + dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + target: str | None + match params: + case {"name": str() as target}: + pass + case _: + target = None + parent: Any | None + match params: + case {"_meta": {**meta}}: + parent = extract_trace_context(meta) + case _: + parent = None + span_name = f"MCP handle {method}{f' {target}' if target else ''}" + with otel_span(span_name, kind=SpanKind.SERVER, attributes={"mcp.method.name": method}, context=parent) as span: + try: + return await next_on_request(dctx, method, params) + except MCPError as e: + span.set_status(StatusCode.ERROR, e.error.message) + raise + except Exception as e: + span.set_status(StatusCode.ERROR, str(e)) + raise + + return wrapped + + def _dump_result(result: Any) -> dict[str, Any]: if result is None: return {} @@ -145,6 +186,16 @@ def __post_init__(self) -> None: self._initialized = self.stateless self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel) + async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: + """Drive the dispatcher until the underlying channel closes. + + Composes `dispatch_middleware` over `_on_request` and hands the result + to `dispatcher.run()`. ``task_status.started()`` is forwarded so callers + can ``await tg.start(runner.run)`` and resume once the dispatcher is + ready to accept requests. + """ + await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) + def _compose_on_request(self) -> OnRequest: """Wrap `_on_request` in `dispatch_middleware`, outermost-first. diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index eca10497c5..3d2fd84c0c 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -14,7 +14,7 @@ from mcp.server.connection import Connection from mcp.server.context import Context from mcp.server.lowlevel.server import Server -from mcp.server.runner import ServerRunner +from mcp.server.runner import ServerRunner, otel_middleware from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext @@ -272,6 +272,75 @@ async def mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: tg.cancel_scope.cancel() +@pytest.mark.anyio +async def test_runner_run_drives_dispatcher_end_to_end(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(runner.run) + with anyio.fail_after(5): + init = await client.send_raw_request("initialize", _initialize_params()) + tools = await client.send_raw_request("tools/list", None) + assert init["serverInfo"]["name"] == "test-server" + assert tools["tools"][0]["name"] == "t" + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_run_applies_dispatch_middleware(server: SrvT): + seen: list[str] = [] + + def trace_mw(next_on_request: Any) -> Any: + async def wrapped(dctx: Any, method: str, params: Any) -> Any: + seen.append(method) + return await next_on_request(dctx, method, params) + + return wrapped + + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=True, + dispatch_middleware=[trace_mw], + ) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(runner.run) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + await client.send_raw_request("ping", None) + assert seen == ["initialize", "ping"] + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_otel_middleware_passes_through_result_and_survives_handler_error(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=True, + dispatch_middleware=[otel_middleware], + ) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(runner.run) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + tools = await client.send_raw_request("tools/list", None) + assert tools["tools"][0]["name"] == "t" + with pytest.raises(MCPError): + await client.send_raw_request("nonexistent/method", None) + tg.cancel_scope.cancel() + + @pytest.mark.anyio async def test_runner_stateless_skips_init_gate(server: SrvT): client, server_d = create_direct_dispatcher_pair() From 87579dac9dc18d2e163f8d8d3643e6aad146ed6e Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Sat, 25 Apr 2026 21:40:05 +0000 Subject: [PATCH 23/52] =?UTF-8?q?test:=20ServerRunner=20coverage=20to=2010?= =?UTF-8?q?0%=20=E2=80=94=20otel=20span=20assertions=20+=20connected=5Frun?= =?UTF-8?q?ner=20harness?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add opentelemetry-sdk as a dev dep and a tests/server/conftest.py 'spans' fixture (TracerProvider + InMemorySpanExporter) so otel_middleware's span contract is observable. - Replace the otel pass-through test with four span-asserting tests (name + target, _meta traceparent → parent, MCPError → ERROR status without traceback, unexpected exception → ERROR status + exception event). These surfaced that start_as_current_span's default set_status_on_exception / record_exception was overwriting the middleware's explicit set_status and attaching tracebacks to protocol-level MCPErrors — now disabled and handled explicitly. - Add handler-return contract tests (None → {}, unsupported → INTERNAL_ERROR). - Introduce connected_runner async-contextmanager test harness and retrofit all tests through runner.run(); drop two tests made redundant by that. Harness closes dispatchers gracefully and re-raises body exceptions outside the task group so failures aren't ExceptionGroup-wrapped (and to avoid a coverage.py trace-loss false-negative on cancel-during-aexit). - Remove the unused Server.connection_lifespan placeholder; it lands with its consumer. --- pyproject.toml | 1 + src/mcp/server/lowlevel/server.py | 5 - src/mcp/server/runner.py | 10 +- src/mcp/shared/_otel.py | 11 +- tests/server/conftest.py | 34 +++ tests/server/test_runner.py | 401 ++++++++++++++---------------- uv.lock | 2 + 7 files changed, 246 insertions(+), 218 deletions(-) create mode 100644 tests/server/conftest.py diff --git a/pyproject.toml b/pyproject.toml index 6d2319621a..5f51fa9b85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ dev = [ "pillow>=12.0", "strict-no-cover", "logfire>=3.0.0", + "opentelemetry-sdk>=1.39.1", ] docs = [ "mkdocs>=1.6.1", diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 466c158bd4..a863246a18 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -259,11 +259,6 @@ def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] """Return the handler for a notification method, or ``None``.""" return self._notification_handlers.get(method) - @property - def connection_lifespan(self) -> None: - """Per-connection lifespan. ``None`` until the registry refactor adds it.""" - return None - # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 79dfc23e0e..bb3af04435 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -145,13 +145,21 @@ async def wrapped( case _: parent = None span_name = f"MCP handle {method}{f' {target}' if target else ''}" - with otel_span(span_name, kind=SpanKind.SERVER, attributes={"mcp.method.name": method}, context=parent) as span: + with otel_span( + span_name, + kind=SpanKind.SERVER, + attributes={"mcp.method.name": method}, + context=parent, + record_exception=False, + set_status_on_exception=False, + ) as span: try: return await next_on_request(dctx, method, params) except MCPError as e: span.set_status(StatusCode.ERROR, e.error.message) raise except Exception as e: + span.record_exception(e) span.set_status(StatusCode.ERROR, str(e)) raise diff --git a/src/mcp/shared/_otel.py b/src/mcp/shared/_otel.py index 170e873a0f..553b8a0bce 100644 --- a/src/mcp/shared/_otel.py +++ b/src/mcp/shared/_otel.py @@ -20,9 +20,18 @@ def otel_span( kind: SpanKind, attributes: dict[str, Any] | None = None, context: Context | None = None, + record_exception: bool = True, + set_status_on_exception: bool = True, ) -> Iterator[Any]: """Create an OTel span.""" - with _tracer.start_as_current_span(name, kind=kind, attributes=attributes, context=context) as span: + with _tracer.start_as_current_span( + name, + kind=kind, + attributes=attributes, + context=context, + record_exception=record_exception, + set_status_on_exception=set_status_on_exception, + ) as span: yield span diff --git a/tests/server/conftest.py b/tests/server/conftest.py new file mode 100644 index 0000000000..37202f529e --- /dev/null +++ b/tests/server/conftest.py @@ -0,0 +1,34 @@ +"""Shared fixtures for server-side tests.""" + +from collections.abc import Iterator + +import pytest +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + +_span_exporter = InMemorySpanExporter() + + +@pytest.fixture(scope="session") +def _tracer_provider() -> TracerProvider: + """Install a real OTel SDK tracer provider once per test session. + + The runtime dependency is ``opentelemetry-api`` only, which yields no-op + ``NonRecordingSpan`` objects. Tests that need to assert on emitted spans + request the `spans` fixture, which depends on this one to make the global + tracer record into an in-memory exporter. + """ + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(_span_exporter)) + trace.set_tracer_provider(provider) + return provider + + +@pytest.fixture +def spans(_tracer_provider: TracerProvider) -> Iterator[InMemorySpanExporter]: + """In-memory OTel span exporter, cleared before and after each test.""" + _span_exporter.clear() + yield _span_exporter + _span_exporter.clear() diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 3d2fd84c0c..2006bf6486 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -1,24 +1,31 @@ """Tests for `ServerRunner`. End-to-end over `DirectDispatcher` with a real lowlevel `Server` as the -registry. Covers `_on_request` routing, the initialize handshake, the -init-gate, and that handlers receive a fully-built `Context`. +registry. The `connected_runner` helper starts both sides and (by default) +performs the initialize handshake, so each test exercises only the behaviour +under test. """ +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from typing import Any import anyio import anyio.lowlevel import pytest +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import SpanKind, StatusCode from mcp.server.connection import Connection from mcp.server.context import Context from mcp.server.lowlevel.server import Server from mcp.server.runner import ServerRunner, otel_middleware -from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair +from mcp.shared.dispatcher import DispatchMiddleware from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext from mcp.types import ( + INTERNAL_ERROR, INVALID_REQUEST, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, @@ -58,96 +65,107 @@ async def list_tools(ctx: Any, params: Any) -> Any: return Server(name="test-server", version="0.0.1", on_list_tools=list_tools) -@pytest.mark.anyio -async def test_runner_handles_initialize_and_populates_connection(server: SrvT): +@asynccontextmanager +async def connected_runner( + server: SrvT, + *, + initialized: bool = True, + stateless: bool = False, + has_standalone_channel: bool = True, + dispatch_middleware: list[DispatchMiddleware] | None = None, +) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[None, TransportContext]]]: + """Yield ``(client, runner)`` running over an in-memory dispatcher pair. + + Starts the client (echo handlers) and `runner.run()` in a task group, wraps + the body in ``anyio.fail_after(5)``, and cancels on exit. When + ``initialized`` is true the helper performs the real ``initialize`` request + before yielding, so tests start past the init-gate via the public path. + """ client, server_d = create_direct_dispatcher_pair() runner = ServerRunner( server=server, dispatcher=server_d, lifespan_state=None, - has_standalone_channel=True, + has_standalone_channel=has_standalone_channel, + stateless=stateless, + dispatch_middleware=dispatch_middleware or [], ) c_req, c_notify = echo_handlers(Recorder()) + body_exc: BaseException | None = None async with anyio.create_task_group() as tg: await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - result = await client.send_raw_request("initialize", _initialize_params()) - assert result["serverInfo"]["name"] == "test-server" - assert "tools" in result["capabilities"] - assert runner.connection.client_info is not None - assert runner.connection.client_info.name == "test-client" - assert runner.connection.protocol_version == LATEST_PROTOCOL_VERSION - assert runner._initialized is True - tg.cancel_scope.cancel() + await tg.start(runner.run) + try: + with anyio.fail_after(5): + if initialized: + await client.send_raw_request("initialize", _initialize_params()) + yield client, runner + except BaseException as e: + # Capture and re-raise outside the task group so test failures + # surface as the original exception, not an ExceptionGroup wrapper. + body_exc = e + client.close() + server_d.close() + if body_exc is not None: + raise body_exc + + +@pytest.mark.anyio +async def test_connected_runner_propagates_body_exception_unwrapped(server: SrvT): + """The harness re-raises body exceptions as-is, not as ``ExceptionGroup``.""" + with pytest.raises(RuntimeError, match="boom"): + async with connected_runner(server): + raise RuntimeError("boom") + + +@pytest.mark.anyio +async def test_runner_handles_initialize_and_populates_connection(server: SrvT): + async with connected_runner(server, initialized=False) as (client, runner): + result = await client.send_raw_request("initialize", _initialize_params()) + assert result["serverInfo"]["name"] == "test-server" + assert "tools" in result["capabilities"] + assert runner.connection.client_info is not None + assert runner.connection.client_info.name == "test-client" + assert runner.connection.protocol_version == LATEST_PROTOCOL_VERSION + assert runner._initialized is True @pytest.mark.anyio async def test_runner_gates_requests_before_initialize(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - with pytest.raises(MCPError) as exc: - await client.send_raw_request("tools/list", None) - assert exc.value.error.code == INVALID_REQUEST - # ping is exempt - assert await client.send_raw_request("ping", None) == {} - tg.cancel_scope.cancel() + async with connected_runner(server, initialized=False) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INVALID_REQUEST + # ping is exempt from the gate + assert await client.send_raw_request("ping", None) == {} @pytest.mark.anyio -async def test_runner_routes_to_handler_after_initialize_and_builds_context(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - await client.send_raw_request("initialize", _initialize_params()) - result = await client.send_raw_request("tools/list", None) - assert result["tools"][0]["name"] == "t" - ctx = _seen_ctx[0] - assert isinstance(ctx, Context) - assert ctx.lifespan is None - assert isinstance(ctx.connection, Connection) - assert ctx.transport.kind == "direct" - tg.cancel_scope.cancel() +async def test_runner_routes_to_handler_and_builds_context(server: SrvT): + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + ctx = _seen_ctx[0] + assert isinstance(ctx, Context) + assert ctx.lifespan is None + assert isinstance(ctx.connection, Connection) + assert ctx.transport.kind == "direct" @pytest.mark.anyio async def test_runner_unknown_method_raises_method_not_found(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - runner._initialized = True # bypass gate for this test - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - with pytest.raises(MCPError) as exc: - await client.send_raw_request("nonexistent/method", None) - assert exc.value.error.code == METHOD_NOT_FOUND - tg.cancel_scope.cancel() + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("nonexistent/method", None) + assert exc.value.error.code == METHOD_NOT_FOUND @pytest.mark.anyio async def test_runner_on_notify_initialized_sets_flag_and_connection_event(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - await client.notify("notifications/initialized", None) - await runner.connection.initialized.wait() - assert runner._initialized is True - tg.cancel_scope.cancel() + async with connected_runner(server, initialized=False) as (client, runner): + await client.notify("notifications/initialized", None) + await runner.connection.initialized.wait() + assert runner._initialized is True @pytest.mark.anyio @@ -158,36 +176,21 @@ async def on_roots_changed(ctx: Any, params: Any) -> None: seen.append((ctx, params)) server._notification_handlers["notifications/roots/list_changed"] = on_roots_changed - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - runner._initialized = True - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - await client.notify("notifications/roots/list_changed", None) - # DirectDispatcher delivers synchronously; one yield is enough. - await anyio.lowlevel.checkpoint() - assert len(seen) == 1 - assert isinstance(seen[0][0], Context) - tg.cancel_scope.cancel() + async with connected_runner(server) as (client, _): + await client.notify("notifications/roots/list_changed", None) + # DirectDispatcher delivers synchronously; one yield is enough. + await anyio.lowlevel.checkpoint() + assert len(seen) == 1 + assert isinstance(seen[0][0], Context) @pytest.mark.anyio async def test_runner_on_notify_drops_before_init_and_unknown_methods(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - await client.notify("notifications/roots/list_changed", None) # before init: dropped - await client.notify("notifications/initialized", None) - await client.notify("notifications/unknown", None) # no handler: dropped - # No exception raised; both drops are silent. - tg.cancel_scope.cancel() + async with connected_runner(server, initialized=False) as (client, _): + await client.notify("notifications/roots/list_changed", None) # before init: dropped + await client.notify("notifications/initialized", None) + await client.notify("notifications/unknown", None) # no handler: dropped + # No exception raised; both drops are silent. @pytest.mark.anyio @@ -201,24 +204,9 @@ async def wrapped(dctx: Any, method: str, params: Any) -> Any: return wrapped - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner( - server=server, - dispatcher=server_d, - lifespan_state=None, - has_standalone_channel=True, - dispatch_middleware=[trace_mw], - ) - c_req, c_notify = echo_handlers(Recorder()) - on_req = runner._compose_on_request() - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, on_req, runner._on_notify) - with anyio.fail_after(5): - await client.send_raw_request("initialize", _initialize_params()) - await client.send_raw_request("tools/list", None) - assert seen_methods == ["initialize", "tools/list"] - tg.cancel_scope.cancel() + async with connected_runner(server, dispatch_middleware=[trace_mw]) as (client, _): + await client.send_raw_request("tools/list", None) + assert seen_methods == ["initialize", "tools/list"] @pytest.mark.anyio @@ -230,19 +218,11 @@ async def ctx_mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: return await call_next() server.middleware.append(ctx_mw) - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - await client.send_raw_request("initialize", _initialize_params()) - await client.send_raw_request("ping", None) - await client.send_raw_request("tools/list", None) - # initialize NOT wrapped; ping and tools/list ARE wrapped. - assert seen_methods == ["ping", "tools/list"] - tg.cancel_scope.cancel() + async with connected_runner(server) as (client, _): + await client.send_raw_request("ping", None) + await client.send_raw_request("tools/list", None) + # initialize (sent by the helper) NOT wrapped; ping and tools/list ARE. + assert seen_methods == ["ping", "tools/list"] @pytest.mark.anyio @@ -259,103 +239,102 @@ async def mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: return mw server.middleware.extend([make_mw("a"), make_mw("b")]) - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - runner._initialized = True - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - await client.send_raw_request("tools/list", None) - assert order == ["a-in", "b-in", "b-out", "a-out"] - tg.cancel_scope.cancel() + async with connected_runner(server) as (client, _): + await client.send_raw_request("tools/list", None) + assert order == ["a-in", "b-in", "b-out", "a-out"] @pytest.mark.anyio -async def test_runner_run_drives_dispatcher_end_to_end(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(runner.run) - with anyio.fail_after(5): - init = await client.send_raw_request("initialize", _initialize_params()) - tools = await client.send_raw_request("tools/list", None) - assert init["serverInfo"]["name"] == "test-server" - assert tools["tools"][0]["name"] == "t" - tg.cancel_scope.cancel() +async def test_runner_handler_returning_none_yields_empty_result(server: SrvT): + async def set_level(ctx: Any, params: Any) -> None: + return None + + server._request_handlers["logging/setLevel"] = set_level + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("logging/setLevel", {"level": "info"}) + assert result == {} @pytest.mark.anyio -async def test_runner_run_applies_dispatch_middleware(server: SrvT): - seen: list[str] = [] +async def test_runner_handler_returning_unsupported_type_surfaces_as_internal_error(server: SrvT): + async def bad_return(ctx: Any, params: Any) -> int: + return 42 - def trace_mw(next_on_request: Any) -> Any: - async def wrapped(dctx: Any, method: str, params: Any) -> Any: - seen.append(method) - return await next_on_request(dctx, method, params) + server._request_handlers["tools/list"] = bad_return + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INTERNAL_ERROR + assert "int" in exc.value.error.message - return wrapped - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner( - server=server, - dispatcher=server_d, - lifespan_state=None, - has_standalone_channel=True, - dispatch_middleware=[trace_mw], - ) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(runner.run) - with anyio.fail_after(5): - await client.send_raw_request("initialize", _initialize_params()) - await client.send_raw_request("ping", None) - assert seen == ["initialize", "ping"] - tg.cancel_scope.cancel() +@pytest.mark.anyio +async def test_runner_stateless_skips_init_gate(server: SrvT): + async with connected_runner(server, initialized=False, stateless=True, has_standalone_channel=False) as (client, _): + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" @pytest.mark.anyio -async def test_otel_middleware_passes_through_result_and_survives_handler_error(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner( - server=server, - dispatcher=server_d, - lifespan_state=None, - has_standalone_channel=True, - dispatch_middleware=[otel_middleware], - ) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(runner.run) - with anyio.fail_after(5): - await client.send_raw_request("initialize", _initialize_params()) - tools = await client.send_raw_request("tools/list", None) - assert tools["tools"][0]["name"] == "t" - with pytest.raises(MCPError): - await client.send_raw_request("nonexistent/method", None) - tg.cancel_scope.cancel() +async def test_otel_middleware_emits_server_span_with_method_and_target(server: SrvT, spans: InMemorySpanExporter): + async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: + return {"content": [], "isError": False} + + server._request_handlers["tools/call"] = call_tool + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + result = await client.send_raw_request("tools/call", {"name": "mytool", "arguments": {}}) + assert result == {"content": [], "isError": False} + [span] = spans.get_finished_spans() + assert span.name == "MCP handle tools/call mytool" + assert span.kind == SpanKind.SERVER + assert span.attributes is not None + assert span.attributes["mcp.method.name"] == "tools/call" + assert span.status.status_code == StatusCode.UNSET @pytest.mark.anyio -async def test_runner_stateless_skips_init_gate(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner( - server=server, - dispatcher=server_d, - lifespan_state=None, - has_standalone_channel=False, - stateless=True, - ) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - result = await client.send_raw_request("tools/list", None) - assert result["tools"][0]["name"] == "t" - tg.cancel_scope.cancel() +async def test_otel_middleware_extracts_parent_context_from_meta(server: SrvT, spans: InMemorySpanExporter): + parent_span_id = "b7ad6b7169203331" + traceparent = f"00-0af7651916cd43dd8448eb211c80319c-{parent_span_id}-01" + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + await client.send_raw_request("tools/list", {"_meta": {"traceparent": traceparent}}) + [span] = spans.get_finished_spans() + assert span.parent is not None + assert format(span.parent.span_id, "016x") == parent_span_id + assert span.context is not None + assert format(span.context.trace_id, "032x") == "0af7651916cd43dd8448eb211c80319c" + + +@pytest.mark.anyio +async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, spans: InMemorySpanExporter): + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + with pytest.raises(MCPError) as exc: + await client.send_raw_request("nonexistent/method", None) + assert exc.value.error.code == METHOD_NOT_FOUND + [span] = spans.get_finished_spans() + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "Method not found: nonexistent/method" + # MCPError is a protocol-level response, not a crash — no traceback event. + assert not [e for e in span.events if e.name == "exception"] + + +@pytest.mark.anyio +async def test_otel_middleware_records_error_status_on_handler_exception(server: SrvT, spans: InMemorySpanExporter): + async def failing(ctx: Any, params: Any) -> Any: + raise ValueError("handler blew up") + + server._request_handlers["tools/list"] = failing + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INTERNAL_ERROR + [span] = spans.get_finished_spans() + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "handler blew up" + [event] = [e for e in span.events if e.name == "exception"] + assert event.attributes is not None + assert event.attributes["exception.type"] == "ValueError" diff --git a/uv.lock b/uv.lock index 5b72e97fce..86ea2f5fb7 100644 --- a/uv.lock +++ b/uv.lock @@ -885,6 +885,7 @@ dev = [ { name = "inline-snapshot" }, { name = "logfire" }, { name = "mcp", extra = ["cli", "ws"] }, + { name = "opentelemetry-sdk" }, { name = "pillow" }, { name = "pyright" }, { name = "pytest" }, @@ -937,6 +938,7 @@ dev = [ { name = "inline-snapshot", specifier = ">=0.23.0" }, { name = "logfire", specifier = ">=3.0.0" }, { name = "mcp", extras = ["cli", "ws"], editable = "." }, + { name = "opentelemetry-sdk", specifier = ">=1.39.1" }, { name = "pillow", specifier = ">=12.0" }, { name = "pyright", specifier = ">=1.1.400" }, { name = "pytest", specifier = ">=8.4.0" }, From a19735b521749d850c382fa348be2cf8c8ffa7a5 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Sat, 25 Apr 2026 22:14:12 +0000 Subject: [PATCH 24/52] test: converge span capture on capfire to fix xdist order-dependence MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous tests/server/conftest.py called trace.set_tracer_provider() directly, which is set-once per process and raced against logfire's capfire fixture (tests/shared/test_otel.py) under xdist — whichever ran first in a worker won, the other's tests broke. Converge on capfire as the single span-capture owner since logfire.configure() already handles repeat calls by swapping span processors instead of re-setting the provider: - tests/conftest.py: set LOGFIRE_DISTRIBUTED_TRACING=true so propagation tests don't trip logfire's 'found propagated trace context' RuntimeWarning. - tests/server/conftest.py: SpanCapture adapter over capfire.exporter — filters to the mcp-python-sdk instrumentation scope and excludes logfire's pending_span markers, so tests assert on raw ReadableSpan without importing logfire types. - tests/shared/test_otel.py: drop the now-unneeded filterwarnings decorator. --- tests/conftest.py | 11 ++++++++ tests/server/conftest.py | 55 ++++++++++++++++++++++--------------- tests/server/test_runner.py | 18 ++++++------ tests/shared/test_otel.py | 3 -- 4 files changed, 53 insertions(+), 34 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index af7e479932..b83c472135 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,16 @@ +import os + import pytest +# OpenTelemetry's `set_tracer_provider` is set-once per process, so the suite +# uses a single span-capture mechanism: logfire's `capfire` fixture (its +# `configure()` swaps span processors on repeat calls rather than re-setting +# the provider). Logfire's default `distributed_tracing=None` emits a +# RuntimeWarning + diagnostic span when incoming W3C trace context is +# extracted; several tests exercise that propagation deliberately, so opt in +# suite-wide. Set before logfire is imported anywhere. +os.environ.setdefault("LOGFIRE_DISTRIBUTED_TRACING", "true") + @pytest.fixture def anyio_backend(): diff --git a/tests/server/conftest.py b/tests/server/conftest.py index 37202f529e..290ccc957a 100644 --- a/tests/server/conftest.py +++ b/tests/server/conftest.py @@ -3,32 +3,43 @@ from collections.abc import Iterator import pytest -from opentelemetry import trace -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import SimpleSpanProcessor -from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from logfire.testing import CaptureLogfire, TestExporter +from opentelemetry.sdk.trace import ReadableSpan -_span_exporter = InMemorySpanExporter() +class SpanCapture: + """Thin adapter over logfire's `TestExporter` for asserting on MCP spans. -@pytest.fixture(scope="session") -def _tracer_provider() -> TracerProvider: - """Install a real OTel SDK tracer provider once per test session. - - The runtime dependency is ``opentelemetry-api`` only, which yields no-op - ``NonRecordingSpan`` objects. Tests that need to assert on emitted spans - request the `spans` fixture, which depends on this one to make the global - tracer record into an in-memory exporter. + `finished()` returns the raw `ReadableSpan` objects emitted by the + ``mcp-python-sdk`` instrumentation scope, filtered to exclude logfire's + synthetic ``pending_span`` markers, so tests can assert directly on + `.name`, `.kind`, `.status`, `.attributes`, `.parent`, `.events`. """ - provider = TracerProvider() - provider.add_span_processor(SimpleSpanProcessor(_span_exporter)) - trace.set_tracer_provider(provider) - return provider + + def __init__(self, exporter: TestExporter) -> None: + self._exporter = exporter + + def clear(self) -> None: + self._exporter.clear() + + def finished(self) -> list[ReadableSpan]: + return [ + s + for s in self._exporter.exported_spans + if s.instrumentation_scope is not None + and s.instrumentation_scope.name == "mcp-python-sdk" + and not (s.attributes and s.attributes.get("logfire.span_type") == "pending_span") + ] @pytest.fixture -def spans(_tracer_provider: TracerProvider) -> Iterator[InMemorySpanExporter]: - """In-memory OTel span exporter, cleared before and after each test.""" - _span_exporter.clear() - yield _span_exporter - _span_exporter.clear() +def spans(capfire: CaptureLogfire) -> Iterator[SpanCapture]: + """In-memory MCP span capture, cleared before and after each test. + + Backed by the project-level `capfire` override (see ``tests/conftest.py``) + so there is a single global tracer provider for the suite. + """ + capture = SpanCapture(capfire.exporter) + capture.clear() + yield capture + capture.clear() diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 2006bf6486..843b0ae8b9 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -13,7 +13,6 @@ import anyio import anyio.lowlevel import pytest -from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import SpanKind, StatusCode from mcp.server.connection import Connection @@ -36,6 +35,7 @@ ) from ..shared.test_dispatcher import Recorder, echo_handlers +from .conftest import SpanCapture def _initialize_params() -> dict[str, Any]: @@ -276,7 +276,7 @@ async def test_runner_stateless_skips_init_gate(server: SrvT): @pytest.mark.anyio -async def test_otel_middleware_emits_server_span_with_method_and_target(server: SrvT, spans: InMemorySpanExporter): +async def test_otel_middleware_emits_server_span_with_method_and_target(server: SrvT, spans: SpanCapture): async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: return {"content": [], "isError": False} @@ -285,7 +285,7 @@ async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: spans.clear() result = await client.send_raw_request("tools/call", {"name": "mytool", "arguments": {}}) assert result == {"content": [], "isError": False} - [span] = spans.get_finished_spans() + [span] = spans.finished() assert span.name == "MCP handle tools/call mytool" assert span.kind == SpanKind.SERVER assert span.attributes is not None @@ -294,13 +294,13 @@ async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: @pytest.mark.anyio -async def test_otel_middleware_extracts_parent_context_from_meta(server: SrvT, spans: InMemorySpanExporter): +async def test_otel_middleware_extracts_parent_context_from_meta(server: SrvT, spans: SpanCapture): parent_span_id = "b7ad6b7169203331" traceparent = f"00-0af7651916cd43dd8448eb211c80319c-{parent_span_id}-01" async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): spans.clear() await client.send_raw_request("tools/list", {"_meta": {"traceparent": traceparent}}) - [span] = spans.get_finished_spans() + [span] = spans.finished() assert span.parent is not None assert format(span.parent.span_id, "016x") == parent_span_id assert span.context is not None @@ -308,13 +308,13 @@ async def test_otel_middleware_extracts_parent_context_from_meta(server: SrvT, s @pytest.mark.anyio -async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, spans: InMemorySpanExporter): +async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, spans: SpanCapture): async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): spans.clear() with pytest.raises(MCPError) as exc: await client.send_raw_request("nonexistent/method", None) assert exc.value.error.code == METHOD_NOT_FOUND - [span] = spans.get_finished_spans() + [span] = spans.finished() assert span.status.status_code == StatusCode.ERROR assert span.status.description == "Method not found: nonexistent/method" # MCPError is a protocol-level response, not a crash — no traceback event. @@ -322,7 +322,7 @@ async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, s @pytest.mark.anyio -async def test_otel_middleware_records_error_status_on_handler_exception(server: SrvT, spans: InMemorySpanExporter): +async def test_otel_middleware_records_error_status_on_handler_exception(server: SrvT, spans: SpanCapture): async def failing(ctx: Any, params: Any) -> Any: raise ValueError("handler blew up") @@ -332,7 +332,7 @@ async def failing(ctx: Any, params: Any) -> Any: with pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", None) assert exc.value.error.code == INTERNAL_ERROR - [span] = spans.get_finished_spans() + [span] = spans.finished() assert span.status.status_code == StatusCode.ERROR assert span.status.description == "handler blew up" [event] = [e for e in span.events if e.name == "exception"] diff --git a/tests/shared/test_otel.py b/tests/shared/test_otel.py index ec7ff78cc1..a7df4c4294 100644 --- a/tests/shared/test_otel.py +++ b/tests/shared/test_otel.py @@ -10,9 +10,6 @@ pytestmark = pytest.mark.anyio -# Logfire warns about propagated trace context by default (distributed_tracing=None). -# This is expected here since we're testing cross-boundary context propagation. -@pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test_client_and_server_spans(capfire: CaptureLogfire): """Verify that calling a tool produces client and server spans with correct attributes.""" server = MCPServer("test") From 7123dd98e813e627d75a23071aa45e791ceee12b Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 30 Apr 2026 17:29:55 +0000 Subject: [PATCH 25/52] feat: Server registry stores HandlerEntry; ServerRunner consumes Server[L] directly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Server is generic in LifespanResultT only — no TransportContextT. Spike (scratch/spike-tt-on-server) found a third generic breaks bare-Server plumbing helpers via invariance and only buys one None-check; it remains additive later via PEP 696 default if demand materialises. TT stays on the transport layer (Dispatcher/DispatchContext/BaseContext in mcp.shared); the server layer (Server/Context/ServerRunner/ServerMiddleware) consumes base TransportContext. - HandlerEntry[L] frozen dataclass (params_type, handler) replaces bare callables in the registry; params type erased to Any in storage, correlated at add_request_handler[P] - Public add_request_handler/add_notification_handler; capabilities() zero-arg (notification_options/experimental_capabilities now ctor kwargs) - ServerRunner drops the ServerRegistry Protocol scaffold and reads Server[L] directly; _make_context no longer narrows dctx - ServerMiddleware[L] (one contravariant param) - Context[L] (BaseContext[TransportContext] fixed) --- src/mcp/server/context.py | 17 ++-- src/mcp/server/lowlevel/server.py | 147 +++++++++++++++++++--------- src/mcp/server/runner.py | 122 +++++------------------ tests/server/test_runner.py | 76 ++++++++++---- tests/server/test_server_context.py | 12 +-- 5 files changed, 197 insertions(+), 177 deletions(-) diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 1c855ae48a..1cf2be1899 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -33,10 +33,9 @@ class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContex LifespanT = TypeVar("LifespanT", default=Any, covariant=True) -TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True) -class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Generic[LifespanT, TransportT]): +class Context(BaseContext[TransportContext], PeerMixin, TypedServerRequestMixin, Generic[LifespanT]): """Server-side per-request context. Composes `BaseContext` (forwards to `DispatchContext`, satisfies `Outbound`), @@ -50,7 +49,7 @@ class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Gener def __init__( self, - dctx: DispatchContext[TransportT], + dctx: DispatchContext[TransportContext], *, lifespan: LifespanT, connection: Connection, @@ -94,7 +93,7 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, * _MwLifespanT = TypeVar("_MwLifespanT", contravariant=True) -class ContextMiddleware(Protocol[_MwLifespanT]): +class ServerMiddleware(Protocol[_MwLifespanT]): """Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``. Runs *inside* `ServerRunner._on_request` after params validation and @@ -102,15 +101,15 @@ class ContextMiddleware(Protocol[_MwLifespanT]): not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed outermost-first on `Server.middleware`. - `Server[L].middleware` holds `ContextMiddleware[L]`, so an app-specific - middleware sees `ctx.lifespan: L`. A reusable middleware (no app-specific - types) can be typed `ContextMiddleware[object]` — `Context` is covariant in - `LifespanT`, so it registers on any `Server[L]`. + `Server[L].middleware` holds `ServerMiddleware[L]`, so an app-specific + middleware sees `ctx.lifespan: L`. A reusable middleware can be typed + `ServerMiddleware[object]` — `Context` is covariant in `LifespanT`, so it + registers on any `Server[L]`. """ async def __call__( self, - ctx: Context[_MwLifespanT, TransportContext], + ctx: Context[_MwLifespanT], method: str, params: BaseModel, call_next: CallNext, diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index a863246a18..375ca94c0d 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -41,11 +41,13 @@ async def main(): import warnings from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager +from dataclasses import dataclass from importlib.metadata import version as importlib_version from typing import Any, Generic, cast import anyio from opentelemetry.trace import SpanKind, StatusCode +from pydantic import BaseModel from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -58,7 +60,7 @@ async def main(): from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings -from mcp.server.context import ContextMiddleware, ServerRequestContext +from mcp.server.context import HandlerResult, ServerMiddleware, ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.models import InitializationOptions @@ -76,6 +78,30 @@ async def main(): LifespanResultT = TypeVar("LifespanResultT", default=Any) +_ParamsT = TypeVar("_ParamsT", bound=BaseModel, default=BaseModel) + +RequestHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[HandlerResult]] +"""A registered request handler: ``(ctx, params) -> result``.""" + +NotificationHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[None]] +"""A registered notification handler: ``(ctx, params) -> None``.""" + + +@dataclass(frozen=True, slots=True) +class HandlerEntry(Generic[LifespanResultT]): + """A registered handler and the params model to validate incoming params against. + + Stored in `Server._request_handlers` / `_notification_handlers` and consumed + by `ServerRunner` to validate, build `Context`, and invoke. The handler's + second-argument type is erased to ``Any`` in storage (each entry has a + different concrete params type and `Callable` parameters are contravariant); + the precise type is recoverable via `params_type`. The correlation is + enforced at registration time by `Server.add_request_handler`. + """ + + params_type: type[BaseModel] + handler: RequestHandler[LifespanResultT, Any] + class NotificationOptions: def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False): @@ -85,7 +111,7 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals @asynccontextmanager -async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]: +async def lifespan(_: Server[Any]) -> AsyncIterator[dict[str, Any]]: """Default lifespan context manager that does nothing. Returns: @@ -109,6 +135,8 @@ def __init__( instructions: str | None = None, website_url: str | None = None, icons: list[types.Icon] | None = None, + notification_options: NotificationOptions | None = None, + experimental_capabilities: dict[str, dict[str, Any]] | None = None, lifespan: Callable[ [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT], @@ -193,57 +221,77 @@ def __init__( self.website_url = website_url self.icons = icons self.lifespan = lifespan - self._request_handlers: dict[str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]] = {} - self._notification_handlers: dict[ - str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]] - ] = {} + self._notification_options = notification_options or NotificationOptions() + self._experimental_capabilities = experimental_capabilities or {} + self._request_handlers: dict[str, HandlerEntry[LifespanResultT]] = {} + self._notification_handlers: dict[str, HandlerEntry[LifespanResultT]] = {} self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None # Context-tier middleware consumed by `ServerRunner`. Additive; the # existing `run()` path ignores it. - self.middleware: list[ContextMiddleware[LifespanResultT]] = [] + self.middleware: list[ServerMiddleware[LifespanResultT]] = [] logger.debug("Initializing server %r", name) - # Populate internal handler dicts from on_* kwargs - self._request_handlers.update( - { - method: handler - for method, handler in { - "ping": on_ping, - "prompts/list": on_list_prompts, - "prompts/get": on_get_prompt, - "resources/list": on_list_resources, - "resources/templates/list": on_list_resource_templates, - "resources/read": on_read_resource, - "resources/subscribe": on_subscribe_resource, - "resources/unsubscribe": on_unsubscribe_resource, - "tools/list": on_list_tools, - "tools/call": on_call_tool, - "logging/setLevel": on_set_logging_level, - "completion/complete": on_completion, - }.items() - if handler is not None - } - ) + _spec_requests: list[tuple[str, type[BaseModel], RequestHandler[LifespanResultT, Any] | None]] = [ + ("ping", types.RequestParams, on_ping), + ("prompts/list", types.PaginatedRequestParams, on_list_prompts), + ("prompts/get", types.GetPromptRequestParams, on_get_prompt), + ("resources/list", types.PaginatedRequestParams, on_list_resources), + ("resources/templates/list", types.PaginatedRequestParams, on_list_resource_templates), + ("resources/read", types.ReadResourceRequestParams, on_read_resource), + ("resources/subscribe", types.SubscribeRequestParams, on_subscribe_resource), + ("resources/unsubscribe", types.UnsubscribeRequestParams, on_unsubscribe_resource), + ("tools/list", types.PaginatedRequestParams, on_list_tools), + ("tools/call", types.CallToolRequestParams, on_call_tool), + ("logging/setLevel", types.SetLevelRequestParams, on_set_logging_level), + ("completion/complete", types.CompleteRequestParams, on_completion), + ] + self._request_handlers.update({m: HandlerEntry(pt, h) for m, pt, h in _spec_requests if h is not None}) + _spec_notifications: list[tuple[str, type[BaseModel], NotificationHandler[LifespanResultT, Any] | None]] = [ + ("notifications/roots/list_changed", types.NotificationParams, on_roots_list_changed), + ("notifications/progress", types.ProgressNotificationParams, on_progress), + ] self._notification_handlers.update( - { - method: handler - for method, handler in { - "notifications/roots/list_changed": on_roots_list_changed, - "notifications/progress": on_progress, - }.items() - if handler is not None - } + {m: HandlerEntry(pt, h) for m, pt, h in _spec_notifications if h is not None} ) + def add_request_handler( + self, + method: str, + params_type: type[_ParamsT], + handler: RequestHandler[LifespanResultT, _ParamsT], + ) -> None: + """Register a request handler for ``method``. + + ``params_type`` is the model incoming params are validated against + before the handler is invoked. It should subclass `RequestParams` so + ``_meta`` parses uniformly. Replaces any existing handler for the same + method (no collision guard against spec methods). + """ + self._request_handlers[method] = HandlerEntry(params_type, handler) + + def add_notification_handler( + self, + method: str, + params_type: type[_ParamsT], + handler: NotificationHandler[LifespanResultT, _ParamsT], + ) -> None: + """Register a notification handler for ``method``. + + ``params_type`` should subclass `NotificationParams` so ``_meta`` + parses uniformly. Replaces any existing handler. + """ + self._notification_handlers[method] = HandlerEntry(params_type, handler) + def _add_request_handler( self, method: str, - handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]], + handler: RequestHandler[LifespanResultT, Any], ) -> None: - """Add a request handler, silently replacing any existing handler for the same method.""" - self._request_handlers[method] = handler + # TODO: remove once experimental tasks plumbing and remaining callers + # migrate to `add_request_handler` with an explicit params_type. + self.add_request_handler(method, types.RequestParams, handler) def _has_handler(self, method: str) -> bool: """Check if a handler is registered for the given method.""" @@ -251,14 +299,18 @@ def _has_handler(self, method: str) -> bool: # --- ServerRegistry protocol (consumed by ServerRunner) ------------------ - def get_request_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: - """Return the handler for a request method, or ``None``.""" + def get_request_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None: + """Return the registered entry for a request method, or ``None``.""" return self._request_handlers.get(method) - def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: - """Return the handler for a notification method, or ``None``.""" + def get_notification_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None: + """Return the registered entry for a notification method, or ``None``.""" return self._notification_handlers.get(method) + def capabilities(self) -> types.ServerCapabilities: + """Derive `ServerCapabilities` from registered handlers and constructor options.""" + return self.get_capabilities(self._notification_options, self._experimental_capabilities) + # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities @@ -474,7 +526,8 @@ async def _handle_request( attributes={"mcp.method.name": req.method, "jsonrpc.request.id": message.request_id}, context=parent_context, ) as span: - if handler := self._request_handlers.get(req.method): + if entry := self._request_handlers.get(req.method): + handler = entry.handler logger.debug("Dispatching request of type %s", type(req).__name__) try: @@ -533,7 +586,8 @@ async def _handle_request( span.set_status(StatusCode.ERROR, response.message) try: - await message.respond(response) + # TODO: cast goes away when `_handle_request` is deleted. + await message.respond(cast(types.ServerResult | types.ErrorData, response)) except (anyio.BrokenResourceError, anyio.ClosedResourceError): # Transport closed between handler unblocking and respond. Happens # when _receive_loop's finally wakes a handler blocked on @@ -552,7 +606,8 @@ async def _handle_notification( session: ServerSession, lifespan_context: LifespanResultT, ) -> None: - if handler := self._notification_handlers.get(notify.method): + if entry := self._notification_handlers.get(notify.method): + handler = entry.handler logger.debug("Dispatching notification of type %s", type(notify).__name__) try: diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index bb3af04435..1ef711d020 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -10,17 +10,16 @@ `Context`, runs the middleware chain, returns the result dict * drives ``dispatcher.run()`` and the per-connection lifespan -`ServerRunner` consumes any `ServerRegistry` — the lowlevel `Server` satisfies -it via additive methods so the existing ``Server.run()`` path is unaffected. +`ServerRunner` holds a `Server` directly — `Server` is the registry. """ from __future__ import annotations import logging -from collections.abc import Awaitable, Callable, Mapping, Sequence +from collections.abc import Mapping from dataclasses import dataclass, field from functools import partial, reduce -from typing import Any, Generic, Protocol, cast +from typing import Any, Generic, cast import anyio.abc from opentelemetry.trace import SpanKind, StatusCode @@ -28,8 +27,8 @@ from typing_extensions import TypeVar from mcp.server.connection import Connection -from mcp.server.context import CallNext, Context, ContextMiddleware -from mcp.server.lowlevel.server import NotificationOptions +from mcp.server.context import CallNext, Context, ServerMiddleware +from mcp.server.lowlevel.server import Server from mcp.shared._otel import extract_trace_context, otel_span from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest from mcp.shared.exceptions import MCPError @@ -38,87 +37,20 @@ INVALID_REQUEST, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, - CallToolRequestParams, - CompleteRequestParams, - GetPromptRequestParams, Implementation, InitializeRequestParams, InitializeResult, - NotificationParams, - PaginatedRequestParams, - ProgressNotificationParams, - ReadResourceRequestParams, - RequestParams, - ServerCapabilities, - SetLevelRequestParams, - SubscribeRequestParams, - UnsubscribeRequestParams, ) -__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner", "otel_middleware"] +__all__ = ["CallNext", "ServerMiddleware", "ServerRunner", "otel_middleware"] logger = logging.getLogger(__name__) LifespanT = TypeVar("LifespanT", default=Any) -ServerTransportT = TypeVar("ServerTransportT", bound=TransportContext, default=TransportContext) - -Handler = Callable[..., Awaitable[Any]] -"""A request/notification handler: ``(ctx, params) -> result``. Typed loosely -so the existing `ServerRequestContext`-based handlers and the new -`Context`-based handlers both fit during the transition. -""" _INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) -# TODO: remove this lookup once `Server` stores (params_type, handler) in its -# registry directly. This is scaffolding so ServerRunner can validate params -# without changing the existing `_request_handlers` dict shape. -_PARAMS_FOR_METHOD: dict[str, type[BaseModel]] = { - "ping": RequestParams, - "tools/list": PaginatedRequestParams, - "tools/call": CallToolRequestParams, - "prompts/list": PaginatedRequestParams, - "prompts/get": GetPromptRequestParams, - "resources/list": PaginatedRequestParams, - "resources/templates/list": PaginatedRequestParams, - "resources/read": ReadResourceRequestParams, - "resources/subscribe": SubscribeRequestParams, - "resources/unsubscribe": UnsubscribeRequestParams, - "logging/setLevel": SetLevelRequestParams, - "completion/complete": CompleteRequestParams, -} -"""Spec method → params model. Scaffolding while the lowlevel `Server`'s -`_request_handlers` stores handler-only; the registry refactor should make this -the registry's responsibility (or store params types alongside handlers).""" - -_PARAMS_FOR_NOTIFICATION: dict[str, type[BaseModel]] = { - "notifications/initialized": NotificationParams, - "notifications/roots/list_changed": NotificationParams, - "notifications/progress": ProgressNotificationParams, -} - - -class ServerRegistry(Protocol): - """The handler registry `ServerRunner` consumes. - - The lowlevel `Server` satisfies this via additive methods. - """ - - @property - def name(self) -> str: ... - @property - def version(self) -> str | None: ... - - @property - def middleware(self) -> Sequence[ContextMiddleware[Any]]: ... - - def get_request_handler(self, method: str) -> Handler | None: ... - def get_notification_handler(self, method: str) -> Handler | None: ... - def get_capabilities( - self, notification_options: Any, experimental_capabilities: dict[str, dict[str, Any]] - ) -> ServerCapabilities: ... - def otel_middleware(next_on_request: OnRequest) -> OnRequest: """Dispatch-tier middleware that wraps each request in an OpenTelemetry span. @@ -177,11 +109,11 @@ def _dump_result(result: Any) -> dict[str, Any]: @dataclass -class ServerRunner(Generic[LifespanT, ServerTransportT]): +class ServerRunner(Generic[LifespanT]): """Per-connection orchestrator. One instance per client connection.""" - server: ServerRegistry - dispatcher: Dispatcher[ServerTransportT] + server: Server[LifespanT] + dispatcher: Dispatcher[TransportContext] lifespan_state: LifespanT has_standalone_channel: bool stateless: bool = False @@ -227,17 +159,15 @@ async def _on_request( code=INVALID_REQUEST, message=f"Received {method!r} before initialization was complete", ) - handler = self.server.get_request_handler(method) - if handler is None: + entry = self.server.get_request_handler(method) + if entry is None: raise MCPError(code=METHOD_NOT_FOUND, message=f"Method not found: {method}") - # TODO: scaffolding — params_type comes from a static lookup until the - # registry stores it alongside the handler. - params_type = _PARAMS_FOR_METHOD.get(method, RequestParams) # ValidationError propagates; the dispatcher's exception boundary maps # it to INVALID_PARAMS. - typed_params = params_type.model_validate(params or {}) + typed_params = entry.params_type.model_validate(params or {}) ctx = self._make_context(dctx, typed_params) - call: CallNext = partial(handler, ctx, typed_params) + # TODO: cast goes away when `ServerRequestContext = Context` lands. + call: CallNext = partial(cast(Any, entry.handler), ctx, typed_params) for mw in reversed(self.server.middleware): call = partial(mw, ctx, method, typed_params, call) return _dump_result(await call()) @@ -255,24 +185,18 @@ async def _on_notify( if not self._initialized: logger.debug("dropped %s: received before initialization", method) return - handler = self.server.get_notification_handler(method) - if handler is None: + entry = self.server.get_notification_handler(method) + if entry is None: logger.debug("no handler for notification %s", method) return - params_type = _PARAMS_FOR_NOTIFICATION.get(method, NotificationParams) - typed_params = params_type.model_validate(params or {}) + typed_params = entry.params_type.model_validate(params or {}) ctx = self._make_context(dctx, typed_params) - await handler(ctx, typed_params) - - def _make_context( - self, dctx: DispatchContext[TransportContext], typed_params: BaseModel - ) -> Context[LifespanT, ServerTransportT]: - # `OnRequest` delivers `DispatchContext[TransportContext]`; this - # ServerRunner instance was constructed for a specific - # `ServerTransportT`, so the narrow is safe by construction. - narrowed = cast(DispatchContext[ServerTransportT], dctx) + # TODO: cast goes away when `ServerRequestContext = Context` lands. + await cast(Any, entry.handler)(ctx, typed_params) + + def _make_context(self, dctx: DispatchContext[TransportContext], typed_params: BaseModel) -> Context[LifespanT]: meta = getattr(typed_params, "meta", None) - return Context(narrowed, lifespan=self.lifespan_state, connection=self.connection, meta=meta) + return Context(dctx, lifespan=self.lifespan_state, connection=self.connection, meta=meta) def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]: init = InitializeRequestParams.model_validate(params or {}) @@ -289,7 +213,7 @@ def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any] self.connection.initialized.set() result = InitializeResult( protocol_version=self.connection.protocol_version, - capabilities=self.server.get_capabilities(NotificationOptions(), {}), + capabilities=self.server.capabilities(), server_info=Implementation(name=self.server.name, version=self.server.version or "0.0.0"), ) return _dump_result(result) diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 843b0ae8b9..5ece8b9cfb 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -8,7 +8,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any +from typing import Any, cast import anyio import anyio.lowlevel @@ -17,20 +17,25 @@ from mcp.server.connection import Connection from mcp.server.context import Context -from mcp.server.lowlevel.server import Server +from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.runner import ServerRunner, otel_middleware from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair from mcp.shared.dispatcher import DispatchMiddleware from mcp.shared.exceptions import MCPError -from mcp.shared.transport_context import TransportContext from mcp.types import ( INTERNAL_ERROR, INVALID_REQUEST, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, + CallToolRequestParams, ClientCapabilities, Implementation, InitializeRequestParams, + ListToolsResult, + NotificationParams, + PaginatedRequestParams, + RequestParams, + SetLevelRequestParams, Tool, ) @@ -46,7 +51,7 @@ def _initialize_params() -> dict[str, Any]: ).model_dump(by_alias=True, exclude_none=True) -_seen_ctx: list[Context[Any, TransportContext]] = [] +_seen_ctx: list[Context[Any]] = [] SrvT = Server[dict[str, Any]] @@ -55,12 +60,11 @@ def server() -> SrvT: """A lowlevel Server with one tools/list handler registered.""" _seen_ctx.clear() - async def list_tools(ctx: Any, params: Any) -> Any: - # ctx is typed `Any` because Server's on_list_tools kwarg expects the - # legacy ServerRequestContext shape; ServerRunner passes the new - # `Context`. The transition is intentional — Handler is loosely typed. + async def list_tools(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + # ctx is `Any` while `on_*` kwargs are typed against `ServerRequestContext` + # but `ServerRunner` passes the new `Context`; tightens once the alias lands. _seen_ctx.append(ctx) - return {"tools": [Tool(name="t", input_schema={"type": "object"}).model_dump(by_alias=True)]} + return ListToolsResult(tools=[Tool(name="t", input_schema={"type": "object"})]) return Server(name="test-server", version="0.0.1", on_list_tools=list_tools) @@ -73,7 +77,7 @@ async def connected_runner( stateless: bool = False, has_standalone_channel: bool = True, dispatch_middleware: list[DispatchMiddleware] | None = None, -) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[None, TransportContext]]]: +) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[dict[str, Any]]]]: """Yield ``(client, runner)`` running over an in-memory dispatcher pair. Starts the client (echo handlers) and `runner.run()` in a task group, wraps @@ -85,7 +89,7 @@ async def connected_runner( runner = ServerRunner( server=server, dispatcher=server_d, - lifespan_state=None, + lifespan_state={}, has_standalone_channel=has_standalone_channel, stateless=stateless, dispatch_middleware=dispatch_middleware or [], @@ -147,7 +151,7 @@ async def test_runner_routes_to_handler_and_builds_context(server: SrvT): assert result["tools"][0]["name"] == "t" ctx = _seen_ctx[0] assert isinstance(ctx, Context) - assert ctx.lifespan is None + assert ctx.lifespan == {} assert isinstance(ctx.connection, Connection) assert ctx.transport.kind == "direct" @@ -175,7 +179,7 @@ async def test_runner_on_notify_routes_to_registered_handler(server: SrvT): async def on_roots_changed(ctx: Any, params: Any) -> None: seen.append((ctx, params)) - server._notification_handlers["notifications/roots/list_changed"] = on_roots_changed + server.add_notification_handler("notifications/roots/list_changed", NotificationParams, on_roots_changed) async with connected_runner(server) as (client, _): await client.notify("notifications/roots/list_changed", None) # DirectDispatcher delivers synchronously; one yield is enough. @@ -249,7 +253,7 @@ async def test_runner_handler_returning_none_yields_empty_result(server: SrvT): async def set_level(ctx: Any, params: Any) -> None: return None - server._request_handlers["logging/setLevel"] = set_level + server.add_request_handler("logging/setLevel", SetLevelRequestParams, set_level) async with connected_runner(server) as (client, _): result = await client.send_raw_request("logging/setLevel", {"level": "info"}) assert result == {} @@ -260,7 +264,9 @@ async def test_runner_handler_returning_unsupported_type_surfaces_as_internal_er async def bad_return(ctx: Any, params: Any) -> int: return 42 - server._request_handlers["tools/list"] = bad_return + # cast: deliberately registering a handler with a bad return type to + # exercise the runtime check; pyright would (correctly) reject it otherwise. + server.add_request_handler("tools/list", PaginatedRequestParams, cast(Any, bad_return)) async with connected_runner(server) as (client, _): with pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", None) @@ -275,12 +281,48 @@ async def test_runner_stateless_skips_init_gate(server: SrvT): assert result["tools"][0]["name"] == "t" +@pytest.mark.anyio +async def test_server_add_request_handler_routes_custom_method_with_validated_params(server: SrvT): + class GreetParams(RequestParams): + name: str + + received: list[GreetParams] = [] + + async def greet(ctx: Any, params: GreetParams) -> dict[str, Any]: + received.append(params) + return {"greeting": f"hello {params.name}"} + + server.add_request_handler("custom/greet", GreetParams, greet) + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("custom/greet", {"name": "world"}) + assert result == {"greeting": "hello world"} + assert isinstance(received[0], GreetParams) + assert received[0].name == "world" + + +@pytest.mark.anyio +async def test_server_capabilities_reflects_ctor_options_in_initialize_result(): + async def list_tools(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + server: SrvT = Server( + name="caps-test", + on_list_tools=list_tools, + notification_options=NotificationOptions(tools_changed=True), + experimental_capabilities={"ext": {"k": "v"}}, + ) + async with connected_runner(server, initialized=False) as (client, _): + result = await client.send_raw_request("initialize", _initialize_params()) + assert result["capabilities"]["tools"]["listChanged"] is True + assert result["capabilities"]["experimental"] == {"ext": {"k": "v"}} + + @pytest.mark.anyio async def test_otel_middleware_emits_server_span_with_method_and_target(server: SrvT, spans: SpanCapture): async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: return {"content": [], "isError": False} - server._request_handlers["tools/call"] = call_tool + server.add_request_handler("tools/call", CallToolRequestParams, call_tool) async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): spans.clear() result = await client.send_raw_request("tools/call", {"name": "mytool", "arguments": {}}) @@ -326,7 +368,7 @@ async def test_otel_middleware_records_error_status_on_handler_exception(server: async def failing(ctx: Any, params: Any) -> Any: raise ValueError("handler blew up") - server._request_handlers["tools/list"] = failing + server.add_request_handler("tools/list", PaginatedRequestParams, failing) async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): spans.clear() with pytest.raises(MCPError) as exc: diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py index e01de34d33..43c2069a87 100644 --- a/tests/server/test_server_context.py +++ b/tests/server/test_server_context.py @@ -31,11 +31,11 @@ class _Lifespan: @pytest.mark.anyio async def test_context_exposes_lifespan_and_connection_and_forwards_base_context(): - captured: list[Context[_Lifespan, TransportContext]] = [] + captured: list[Context[_Lifespan]] = [] conn = Connection.__new__(Connection) # placeholder until running_pair gives us the dispatcher async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan, TransportContext] = Context(dctx, lifespan=_Lifespan("app"), connection=conn) + ctx: Context[_Lifespan] = Context(dctx, lifespan=_Lifespan("app"), connection=conn) captured.append(ctx) return {} @@ -62,7 +62,7 @@ async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | results: list[CreateMessageResult] = [] async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan, TransportContext] = Context( + ctx: Context[_Lifespan] = Context( dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) ) results.append( @@ -92,7 +92,7 @@ async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | results: list[ListRootsResult] = [] async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan, TransportContext] = Context( + ctx: Context[_Lifespan] = Context( dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) ) results.append(await ctx.send_request(ListRootsRequest())) @@ -113,7 +113,7 @@ async def test_context_log_sends_request_scoped_message_notification(): _, c_notify = echo_handlers(crec) async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan, TransportContext] = Context( + ctx: Context[_Lifespan] = Context( dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) ) await ctx.log("debug", "hello") @@ -137,7 +137,7 @@ async def test_context_log_includes_logger_and_meta_when_supplied(): _, c_notify = echo_handlers(crec) async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan, TransportContext] = Context( + ctx: Context[_Lifespan] = Context( dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) ) await ctx.log("info", "x", logger="my.log", meta={"traceId": "t"}) From c46b52985fba6bc950d1e32d8d73ca1e72ba780f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 7 May 2026 19:19:51 +0000 Subject: [PATCH 26/52] feat: Connection.state + exit_stack; ctx.session_id/headers; TransportContext.headers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per-connection state without a connection_lifespan CM or a second Server generic. Stateless is the default deployment, where a per-connection lifespan would wrap a single request; the enter-late mechanics it would need (race init vs dispatcher-done, ready-gate) were more machinery than the use case warrants. - Connection.session_id: str | None — set by the mount via ServerRunner(session_id=...); per-connection, not per-message - Connection.state: dict[str, Any] — scratch that persists across requests; handlers/middleware read and write freely - Connection.exit_stack: AsyncExitStack — handlers/middleware push CMs or callbacks for per-connection teardown; ServerRunner.run() unwinds it (shielded) in a finally after dispatcher.run() returns - TransportContext.headers: Mapping[str, str] | None on the base — populated by HTTP transports, None on stdio - Context.session_id / Context.headers convenience properties - create_direct_dispatcher_pair(headers=...) and connected_runner(session_id=..., headers=...) for tests --- src/mcp/server/connection.py | 26 +++++-- src/mcp/server/context.py | 19 ++++- src/mcp/server/runner.py | 15 +++- src/mcp/shared/direct_dispatcher.py | 4 +- src/mcp/shared/transport_context.py | 8 +++ tests/server/test_runner.py | 106 +++++++++++++++++++++++++++- 6 files changed, 163 insertions(+), 15 deletions(-) diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py index df3652ce0e..5991715a44 100644 --- a/src/mcp/server/connection.py +++ b/src/mcp/server/connection.py @@ -1,9 +1,10 @@ """`Connection` — per-client connection state and the standalone outbound channel. Always present on `Context` (never ``None``), even in stateless deployments. -Holds peer info populated at ``initialize`` time, the per-connection lifespan -output, and an `Outbound` for the standalone stream (the SSE GET stream in -streamable HTTP, or the single duplex stream in stdio). +Holds peer info populated at ``initialize`` time, per-connection scratch +``state`` and an ``exit_stack`` for teardown, and an `Outbound` for the +standalone stream (the SSE GET stream in streamable HTTP, or the single duplex +stream in stdio). `notify` is best-effort: it never raises. If there's no standalone channel (stateless HTTP) or the stream has been dropped, the notification is @@ -14,6 +15,7 @@ import logging from collections.abc import Mapping +from contextlib import AsyncExitStack from typing import Any import anyio @@ -44,17 +46,27 @@ class Connection(TypedServerRequestMixin): ``None`` until ``initialize`` completes; ``initialized`` is set then. """ - def __init__(self, outbound: Outbound, *, has_standalone_channel: bool) -> None: + def __init__(self, outbound: Outbound, *, has_standalone_channel: bool, session_id: str | None = None) -> None: self._outbound = outbound self.has_standalone_channel = has_standalone_channel + self.session_id: str | None = session_id self.client_info: Implementation | None = None self.client_capabilities: ClientCapabilities | None = None self.protocol_version: str | None = None self.initialized: anyio.Event = anyio.Event() - # TODO: make this generic (Connection[StateT]) once connection_lifespan - # wiring lands in ServerRunner. - self.state: Any = None + + self.state: dict[str, Any] = {} + """Per-connection scratch state. Handlers and middleware may read and + write freely; persists across requests on this connection.""" + + self.exit_stack: AsyncExitStack = AsyncExitStack() + """Cleanup stack unwound by `ServerRunner` when the connection closes. + + Push context managers (``await exit_stack.enter_async_context(...)``) + or callbacks (``exit_stack.push_async_callback(...)``) from handlers or + middleware to register per-connection teardown. Unwound LIFO after + `dispatcher.run()` returns, shielded from cancellation.""" async def send_raw_request( self, diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 1cf2be1899..d1514a9add 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass from typing import Any, Generic, Protocol @@ -69,6 +69,23 @@ def connection(self) -> Connection: """The per-client `Connection` for this request's connection.""" return self._connection + @property + def session_id(self) -> str | None: + """The transport's session id for this connection, when one exists. + + Convenience for ``ctx.connection.session_id``. ``None`` on stdio and + stateless HTTP. + """ + return self._connection.session_id + + @property + def headers(self) -> Mapping[str, str] | None: + """Request headers carried by this message, when the transport has them. + + Convenience for ``ctx.transport.headers``. ``None`` on stdio. + """ + return self.transport.headers + async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: """Send a request-scoped ``notifications/message`` log entry. diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 1ef711d020..1ba732ec4b 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -116,6 +116,7 @@ class ServerRunner(Generic[LifespanT]): dispatcher: Dispatcher[TransportContext] lifespan_state: LifespanT has_standalone_channel: bool + session_id: str | None = None stateless: bool = False dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware]) @@ -124,7 +125,9 @@ class ServerRunner(Generic[LifespanT]): def __post_init__(self) -> None: self._initialized = self.stateless - self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel) + self.connection = Connection( + self.dispatcher, has_standalone_channel=self.has_standalone_channel, session_id=self.session_id + ) async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: """Drive the dispatcher until the underlying channel closes. @@ -132,9 +135,15 @@ async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STAT Composes `dispatch_middleware` over `_on_request` and hands the result to `dispatcher.run()`. ``task_status.started()`` is forwarded so callers can ``await tg.start(runner.run)`` and resume once the dispatcher is - ready to accept requests. + ready to accept requests. Once the dispatcher exits, + `connection.exit_stack` is unwound (shielded) so any per-connection + cleanup registered by handlers or middleware runs to completion. """ - await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) + try: + await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) + finally: + with anyio.CancelScope(shield=True): + await self.connection.exit_stack.aclose() def _compose_on_request(self) -> OnRequest: """Wrap `_on_request` in `dispatch_middleware`, outermost-first. diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index 27443ec874..1842cf8abc 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -162,18 +162,20 @@ async def _dispatch_notify(self, method: str, params: Mapping[str, Any] | None) def create_direct_dispatcher_pair( *, can_send_request: bool = True, + headers: Mapping[str, str] | None = None, ) -> tuple[DirectDispatcher, DirectDispatcher]: """Create two `DirectDispatcher` instances wired to each other. Args: can_send_request: Sets `TransportContext.can_send_request` on both sides. Pass ``False`` to simulate a transport with no back-channel. + headers: Sets `TransportContext.headers` on both sides. Returns: A ``(left, right)`` pair. Conventionally ``left`` is the client side and ``right`` is the server side, but the wiring is symmetric. """ - ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request) + ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request, headers=headers) left = DirectDispatcher(ctx) right = DirectDispatcher(ctx) left.connect_to(right) diff --git a/src/mcp/shared/transport_context.py b/src/mcp/shared/transport_context.py index 832cead515..9346116707 100644 --- a/src/mcp/shared/transport_context.py +++ b/src/mcp/shared/transport_context.py @@ -6,6 +6,7 @@ dispatcher (`ServerRunner`, `Context`, user handlers) read its concrete fields. """ +from collections.abc import Mapping from dataclasses import dataclass __all__ = ["TransportContext"] @@ -28,3 +29,10 @@ class TransportContext: stdio, SSE, and stateful streamable HTTP. When ``False``, `DispatchContext.send_raw_request` raises `NoBackChannelError`. """ + + headers: Mapping[str, str] | None = None + """Request headers carried by this message, when the transport has them. + + Populated by HTTP-based transports; ``None`` on stdio. Handlers should + None-check before use. + """ diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 5ece8b9cfb..33df234dbe 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -6,8 +6,8 @@ under test. """ -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Mapping +from contextlib import AbstractAsyncContextManager, asynccontextmanager from typing import Any, cast import anyio @@ -76,6 +76,8 @@ async def connected_runner( initialized: bool = True, stateless: bool = False, has_standalone_channel: bool = True, + session_id: str | None = None, + headers: Mapping[str, str] | None = None, dispatch_middleware: list[DispatchMiddleware] | None = None, ) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[dict[str, Any]]]]: """Yield ``(client, runner)`` running over an in-memory dispatcher pair. @@ -85,12 +87,13 @@ async def connected_runner( ``initialized`` is true the helper performs the real ``initialize`` request before yielding, so tests start past the init-gate via the public path. """ - client, server_d = create_direct_dispatcher_pair() + client, server_d = create_direct_dispatcher_pair(headers=headers) runner = ServerRunner( server=server, dispatcher=server_d, lifespan_state={}, has_standalone_channel=has_standalone_channel, + session_id=session_id, stateless=stateless, dispatch_middleware=dispatch_middleware or [], ) @@ -380,3 +383,100 @@ async def failing(ctx: Any, params: Any) -> Any: [event] = [e for e in span.events if e.name == "exception"] assert event.attributes is not None assert event.attributes["exception.type"] == "ValueError" + + +@pytest.mark.anyio +async def test_connection_state_persists_across_requests_on_same_connection(server: SrvT) -> None: + async def count(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + ctx.connection.state["n"] = ctx.connection.state.get("n", 0) + 1 + return ListToolsResult(tools=[]) + + server.add_request_handler("tools/list", PaginatedRequestParams, count) + async with connected_runner(server) as (client, runner): + await client.send_raw_request("tools/list", None) + await client.send_raw_request("tools/list", None) + assert runner.connection.state == {"n": 2} + + +@pytest.mark.anyio +async def test_connection_exit_stack_runs_pushed_callback_after_close(server: SrvT) -> None: + cleaned: list[str] = [] + + async def push(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + async def _cleanup() -> None: + cleaned.append("done") + + ctx.connection.exit_stack.push_async_callback(_cleanup) + return ListToolsResult(tools=[]) + + server.add_request_handler("tools/list", PaginatedRequestParams, push) + async with connected_runner(server) as (client, _runner): + await client.send_raw_request("tools/list", None) + assert cleaned == [] + assert cleaned == ["done"] + + +@pytest.mark.anyio +async def test_connection_exit_stack_unwinds_entered_context_manager_after_close(server: SrvT) -> None: + events: list[str] = [] + + class _Tracker(AbstractAsyncContextManager[str]): + async def __aenter__(self) -> str: + events.append("enter") + return "resource" + + async def __aexit__(self, *exc: object) -> None: + events.append("exit") + + async def acquire(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + res = await ctx.connection.exit_stack.enter_async_context(_Tracker()) + ctx.connection.state["res"] = res + return ListToolsResult(tools=[]) + + server.add_request_handler("tools/list", PaginatedRequestParams, acquire) + async with connected_runner(server) as (client, runner): + await client.send_raw_request("tools/list", None) + assert events == ["enter"] + assert runner.connection.state["res"] == "resource" + assert events == ["enter", "exit"] + + +@pytest.mark.anyio +async def test_connection_exit_stack_runs_callbacks_lifo_after_handler_error(server: SrvT) -> None: + cleaned: list[int] = [] + + async def push_then_fail(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + for i in (1, 2, 3): + ctx.connection.exit_stack.push_async_callback(_append, i) + raise RuntimeError("boom") + + async def _append(i: int) -> None: + cleaned.append(i) + + server.add_request_handler("tools/list", PaginatedRequestParams, push_then_fail) + async with connected_runner(server) as (client, _runner): + with pytest.raises(MCPError) as ei: + await client.send_raw_request("tools/list", None) + assert ei.value.error.code == INTERNAL_ERROR + assert cleaned == [] + assert cleaned == [3, 2, 1] + + +@pytest.mark.anyio +async def test_context_session_id_and_headers_expose_connection_and_transport(server: SrvT) -> None: + async with connected_runner(server, session_id="sess-abc", headers={"authorization": "Bearer t"}) as (client, _r): + await client.send_raw_request("tools/list", None) + [ctx] = _seen_ctx + assert ctx.session_id == "sess-abc" + assert ctx.session_id == ctx.connection.session_id + assert ctx.headers == {"authorization": "Bearer t"} + assert ctx.headers is ctx.transport.headers + + +@pytest.mark.anyio +async def test_context_session_id_and_headers_default_none(server: SrvT) -> None: + async with connected_runner(server) as (client, _r): + await client.send_raw_request("tools/list", None) + [ctx] = _seen_ctx + assert ctx.session_id is None + assert ctx.headers is None From 47989e7e03fc2dc8e000f0bfa6b5f9bdf906b3a8 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 8 May 2026 14:23:21 +0000 Subject: [PATCH 27/52] fix: JSONRPCDispatcher coerces string response/progress IDs to int for correlation Matches BaseSession._normalize_request_id and the TypeScript SDK: a peer that echoes the request ID as a JSON string still resolves the waiter. Applied at both lookup sites (_resolve_pending and the progress-token match). Parity prep for the PR6 e2e suite. --- src/mcp/shared/jsonrpc_dispatcher.py | 19 +++++- tests/shared/test_jsonrpc_dispatcher.py | 83 +++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 2 deletions(-) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index f1e7b3675e..b450bb66d5 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -76,6 +76,21 @@ `TransportContext(kind="jsonrpc", can_send_request=True)` when not supplied.""" +def _coerce_id(request_id: RequestId) -> RequestId: + """Coerce a string request ID to int when it's a valid int literal. + + `_allocate_id` only ever produces ``int`` keys for ``_pending``, but a peer + may echo the ID back as a JSON string. The TypeScript SDK and `BaseSession` + both perform this coercion at lookup time so the response still correlates. + """ + if isinstance(request_id, str): + try: + return int(request_id) + except ValueError: + pass + return request_id + + @dataclass(slots=True) class _Pending: """An outbound request awaiting its response.""" @@ -409,7 +424,7 @@ def _dispatch_notification( if msg.method == "notifications/progress": match msg.params: case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if ( - pending := self._pending.get(token) + pending := self._pending.get(_coerce_id(token)) ) is not None and pending.on_progress is not None: total = msg.params.get("total") message = msg.params.get("message") @@ -428,7 +443,7 @@ def _dispatch_notification( self._spawn(on_notify, dctx, msg.method, msg.params, sender_ctx=sender_ctx) def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None: - pending = self._pending.get(request_id) if request_id is not None else None + pending = self._pending.get(_coerce_id(request_id)) if request_id is not None else None if pending is None: logger.debug("dropping response for unknown/late request id %r", request_id) return diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 7f9f11718b..5755b55d15 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -18,6 +18,7 @@ from mcp.shared.exceptions import MCPError from mcp.shared.jsonrpc_dispatcher import ( # pyright: ignore[reportPrivateUsage] JSONRPCDispatcher, + _coerce_id, _outbound_metadata, _Pending, ) @@ -29,6 +30,7 @@ INVALID_PARAMS, ErrorData, JSONRPCError, + JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, Tool, @@ -511,6 +513,87 @@ def test_outbound_metadata_with_resumption_token_returns_client_metadata(): assert _outbound_metadata(None, {}) is None +@pytest.mark.anyio +async def test_response_with_string_id_correlates_to_int_keyed_pending_request(): + """A peer that echoes the request ID as a JSON string still resolves the waiter.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + + async def respond_stringly() -> None: + out = await c2s_recv.receive() + assert isinstance(out, SessionMessage) + assert isinstance(out.message, JSONRPCRequest) + rid = out.message.id + assert isinstance(rid, int) + await s2c_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=str(rid), result={"ok": True})) + ) + + tg.start_soon(respond_stringly) + result = await client.send_raw_request("ping", None) + assert result == {"ok": True} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_progress_with_string_token_reaches_callback_for_int_keyed_request(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + seen: list[float] = [] + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + + async def respond_with_string_token_progress() -> None: + out = await c2s_recv.receive() + assert isinstance(out, SessionMessage) + assert isinstance(out.message, JSONRPCRequest) + rid = out.message.id + assert isinstance(rid, int) + await s2c_send.send( + SessionMessage( + message=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/progress", + params={"progressToken": str(rid), "progress": 0.5}, + ) + ) + ) + await s2c_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=rid, result={"ok": True})) + ) + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + seen.append(progress) + + tg.start_soon(respond_with_string_token_progress) + result = await client.send_raw_request("ping", None, {"on_progress": on_progress}) + assert result == {"ok": True} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert seen == [0.5] + + +def test_coerce_id_passes_through_non_numeric_string_and_int(): + assert _coerce_id("7") == 7 + assert _coerce_id("not-an-int") == "not-an-int" + assert _coerce_id(42) == 42 + + @pytest.mark.anyio async def test_jsonrpc_error_response_with_null_id_is_dropped(): """Parse-error responses (id=null) have no waiter; they're logged and dropped.""" From 7234f0eeee91e343dc11d24f410065dab63416bc Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 12:07:20 +0000 Subject: [PATCH 28/52] feat: DispatchContext.message_metadata passes SessionMessage.metadata through verbatim MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Scaffolding for the swap: ServerRunner._make_context will read this to populate ServerRequestContext.request / close_sse_stream / etc. the same way the current Server._handle_request does. Marked TODO(maxisbey): remove for Context rework — the redesign replaces this with the per-transport context shape. --- src/mcp/shared/direct_dispatcher.py | 3 + src/mcp/shared/dispatcher.py | 14 +++++ src/mcp/shared/jsonrpc_dispatcher.py | 12 +++- tests/shared/test_dispatcher.py | 10 ++++ tests/shared/test_jsonrpc_dispatcher.py | 73 ++++++++++++++++++++++++- 5 files changed, 110 insertions(+), 2 deletions(-) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index 1842cf8abc..aec76b7057 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -24,6 +24,7 @@ from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.message import MessageMetadata from mcp.shared.transport_context import TransportContext from mcp.types import INTERNAL_ERROR, REQUEST_TIMEOUT @@ -47,6 +48,8 @@ class _DirectDispatchContext: transport: TransportContext _back_request: _Request _back_notify: _Notify + message_metadata: MessageMetadata = None # TODO(maxisbey): remove for Context rework + """Always ``None``: in-memory dispatch attaches no transport metadata.""" _on_progress: ProgressFnT | None = None cancel_requested: anyio.Event = field(default_factory=anyio.Event) diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index 20c090323b..994fc076ba 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -22,6 +22,7 @@ import anyio import anyio.abc +from mcp.shared.message import MessageMetadata from mcp.shared.transport_context import TransportContext __all__ = [ @@ -107,6 +108,19 @@ def transport(self) -> TransportT_co: """Transport-specific metadata for this inbound message.""" ... + @property + def message_metadata(self) -> MessageMetadata: + """The metadata the transport attached to this inbound message, if any. + + This is `SessionMessage.metadata` passed through verbatim: HTTP + transports attach `ServerMessageMetadata` (the HTTP request, SSE + stream-close callbacks); stdio and in-memory dispatch attach nothing. + Tied to the `SessionMessage` wire format — goes away when transports + stop delivering messages that way. + """ + # TODO(maxisbey): remove for context rework + ... + @property def cancel_requested(self) -> anyio.Event: """Set when the peer sends ``notifications/cancelled`` for this request.""" diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index b450bb66d5..836beb39cf 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -116,6 +116,13 @@ class _JSONRPCDispatchContext(Generic[TransportT]): transport: TransportT _dispatcher: JSONRPCDispatcher[TransportT] _request_id: RequestId | None + message_metadata: MessageMetadata = None # TODO(maxisbey): remove for Context rework + """The transport-attached `SessionMessage.metadata` for this inbound message. + + Carries `ServerMessageMetadata` (HTTP request, SSE stream-close callbacks) + that the server lifts onto its request context. ``None`` for transports + that attach nothing. + """ _progress_token: ProgressToken | None = None _closed: bool = False cancel_requested: anyio.Event = field(default_factory=anyio.Event) @@ -398,6 +405,7 @@ def _dispatch_request( transport=transport_ctx, _dispatcher=self, _request_id=req.id, + message_metadata=metadata, _progress_token=progress_token, ) scope = anyio.CancelScope() @@ -439,7 +447,9 @@ def _dispatch_notification( pass # fall through: progress is also teed to on_notify transport_ctx = self._transport_builder(None, metadata) - dctx = _JSONRPCDispatchContext(transport=transport_ctx, _dispatcher=self, _request_id=None) + dctx = _JSONRPCDispatchContext( + transport=transport_ctx, _dispatcher=self, _request_id=None, message_metadata=metadata + ) self._spawn(on_notify, dctx, msg.method, msg.params, sender_ctx=sender_ctx) def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None: diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index bdadd4cdae..e1bc368df1 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -196,6 +196,16 @@ async def server_on_request( assert result == {"ok": True} +@pytest.mark.anyio +async def test_ctx_message_metadata_is_none_when_transport_attaches_nothing(pair_factory: PairFactory): + """Plain requests carry no transport metadata, so handlers see ``None``.""" + async with running_pair(pair_factory) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None) + assert len(srec.contexts) == 1 + assert srec.contexts[0].message_metadata is None + + @pytest.mark.anyio async def test_direct_send_raw_request_wraps_non_mcperror_exception_as_internal_error_with_cause(): """DirectDispatcher-specific: the original exception is chained via __cause__.""" diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 5755b55d15..d2d30834ac 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -22,7 +22,7 @@ _outbound_metadata, _Pending, ) -from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.message import ClientMessageMetadata, MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.transport_context import TransportContext from mcp.types import ( CONNECTION_CLOSED, @@ -274,6 +274,77 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> s.close() +@pytest.mark.anyio +async def test_ctx_message_metadata_carries_inbound_request_metadata(): + """Transport-attached metadata (HTTP request, SSE close hooks) is readable off the dispatch context.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + metadata = ServerMessageMetadata(request_context="request-scoped-data") + seen: list[MessageMetadata] = [] + + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen.append(ctx.message_metadata) + return {} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params=None), + metadata=metadata, + ) + ) + with anyio.fail_after(5): + await s2c_recv.receive() # response sent ⇒ the handler has run + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert len(seen) == 1 + assert seen[0] is metadata # the exact object, passed through verbatim + + +@pytest.mark.anyio +async def test_ctx_message_metadata_carries_inbound_notification_metadata(): + """Notifications get the same metadata pass-through as requests.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + metadata = ServerMessageMetadata(request_context="request-scoped-data") + seen: list[MessageMetadata] = [] + notified = anyio.Event() + + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + raise NotImplementedError + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + seen.append(ctx.message_metadata) + notified.set() + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send( + SessionMessage( + message=JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized", params=None), + metadata=metadata, + ) + ) + with anyio.fail_after(5): + await notified.wait() + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert len(seen) == 1 + assert seen[0] is metadata + + @pytest.mark.anyio async def test_ctx_progress_with_only_progress_value_omits_total_and_message(): received: list[tuple[float, float | None, str | None]] = [] From b87060de13b49644e632f6eb0c4703ecd3a37890 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 12:19:14 +0000 Subject: [PATCH 29/52] feat: DispatchContext.request_id; drop RequestId arg from transport_builder request_id is the wire-format correlation id (JSON-RPC message id; None for notifications and for dispatchers without one). Lives on DispatchContext because it's wire-format-shaped (dispatcher domain), not transport-shaped. ServerRunner._make_context will read it to populate ServerRequestContext.request_id. transport_builder no longer takes RequestId: that arg existed so the builder could put the id on a TransportContext subclass, which is now redundant with dctx.request_id. Nothing read it. --- src/mcp/shared/direct_dispatcher.py | 4 +++- src/mcp/shared/dispatcher.py | 11 +++++++++++ src/mcp/shared/jsonrpc_dispatcher.py | 16 ++++++++++------ tests/shared/conftest.py | 2 +- tests/shared/test_dispatcher.py | 11 +++++++++++ tests/shared/test_jsonrpc_dispatcher.py | 2 +- 6 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index aec76b7057..b688d63afa 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -26,7 +26,7 @@ from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.message import MessageMetadata from mcp.shared.transport_context import TransportContext -from mcp.types import INTERNAL_ERROR, REQUEST_TIMEOUT +from mcp.types import INTERNAL_ERROR, REQUEST_TIMEOUT, RequestId __all__ = ["DirectDispatcher", "create_direct_dispatcher_pair"] @@ -48,6 +48,8 @@ class _DirectDispatchContext: transport: TransportContext _back_request: _Request _back_notify: _Notify + request_id: RequestId | None = None + """Always ``None``: direct dispatch has no wire-level request id.""" message_metadata: MessageMetadata = None # TODO(maxisbey): remove for Context rework """Always ``None``: in-memory dispatch attaches no transport metadata.""" _on_progress: ProgressFnT | None = None diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index 994fc076ba..f01eaebe87 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -24,6 +24,7 @@ from mcp.shared.message import MessageMetadata from mcp.shared.transport_context import TransportContext +from mcp.types import RequestId __all__ = [ "CallOptions", @@ -108,6 +109,16 @@ def transport(self) -> TransportT_co: """Transport-specific metadata for this inbound message.""" ... + @property + def request_id(self) -> RequestId | None: + """The id of the inbound request, or ``None`` for a notification. + + For JSON-RPC this is the wire ``id`` field. Handlers thread it through + as ``related_request_id`` on outbound notifications so HTTP transports + can route them onto the originating request's response stream. + """ + ... + @property def message_metadata(self) -> MessageMetadata: """The metadata the transport attached to this inbound message, if any. diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 836beb39cf..5ea5ea88f1 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -127,6 +127,10 @@ class _JSONRPCDispatchContext(Generic[TransportT]): _closed: bool = False cancel_requested: anyio.Event = field(default_factory=anyio.Event) + @property + def request_id(self) -> RequestId | None: + return self._request_id + @property def can_send_request(self) -> bool: return self.transport.can_send_request and not self._closed @@ -158,7 +162,7 @@ def close(self) -> None: self._closed = True -def _default_transport_builder(_request_id: RequestId | None, _meta: MessageMetadata) -> TransportContext: +def _default_transport_builder(_meta: MessageMetadata) -> TransportContext: return TransportContext(kind="jsonrpc", can_send_request=True) @@ -199,7 +203,7 @@ def __init__( read_stream: ReadStream[SessionMessage | Exception], write_stream: WriteStream[SessionMessage], *, - transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT], + transport_builder: Callable[[MessageMetadata], TransportT], peer_cancel_mode: PeerCancelMode = "interrupt", raise_handler_exceptions: bool = False, ) -> None: ... @@ -208,7 +212,7 @@ def __init__( read_stream: ReadStream[SessionMessage | Exception], write_stream: WriteStream[SessionMessage], *, - transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT] | None = None, + transport_builder: Callable[[MessageMetadata], TransportT] | None = None, peer_cancel_mode: PeerCancelMode = "interrupt", raise_handler_exceptions: bool = False, ) -> None: @@ -218,7 +222,7 @@ def __init__( # `TransportT` is `TransportContext`, so the default is type-correct; # pyright can't see across overloads, hence the cast. self._transport_builder = cast( - "Callable[[RequestId | None, MessageMetadata], TransportT]", + "Callable[[MessageMetadata], TransportT]", transport_builder or _default_transport_builder, ) self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode @@ -400,7 +404,7 @@ def _dispatch_request( pass case _: progress_token = None - transport_ctx = self._transport_builder(req.id, metadata) + transport_ctx = self._transport_builder(metadata) dctx = _JSONRPCDispatchContext( transport=transport_ctx, _dispatcher=self, @@ -446,7 +450,7 @@ def _dispatch_notification( case _: pass # fall through: progress is also teed to on_notify - transport_ctx = self._transport_builder(None, metadata) + transport_ctx = self._transport_builder(metadata) dctx = _JSONRPCDispatchContext( transport=transport_ctx, _dispatcher=self, _request_id=None, message_metadata=metadata ) diff --git a/tests/shared/conftest.py b/tests/shared/conftest.py index 1222c05aba..7b53b42654 100644 --- a/tests/shared/conftest.py +++ b/tests/shared/conftest.py @@ -35,7 +35,7 @@ def jsonrpc_pair(*, can_send_request: bool = True) -> DispatcherTriple: c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - def builder(_rid: object, _meta: object) -> TransportContext: + def builder(_meta: object) -> TransportContext: return TransportContext(kind="jsonrpc", can_send_request=can_send_request) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send, transport_builder=builder) diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index e1bc368df1..a9f35e5c2e 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -206,6 +206,17 @@ async def test_ctx_message_metadata_is_none_when_transport_attaches_nothing(pair assert srec.contexts[0].message_metadata is None +@pytest.mark.anyio +async def test_ctx_request_id_exposes_inbound_id(pair_factory: PairFactory): + """JSON-RPC carries the wire id through; direct dispatch has none.""" + async with running_pair(pair_factory) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None) + await client.send_raw_request("tools/call", None) + a, b = (ctx.request_id for ctx in srec.contexts) + assert (a is None and b is None) or (isinstance(a, int) and isinstance(b, int) and a != b) + + @pytest.mark.anyio async def test_direct_send_raw_request_wraps_non_mcperror_exception_as_internal_error_with_cause(): """DirectDispatcher-specific: the original exception is chained via __cause__.""" diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index d2d30834ac..de1b6c4600 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -203,7 +203,7 @@ async def test_raise_handler_exceptions_true_propagates_out_of_run(): c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - def builder(_rid: object, _meta: object) -> TransportContext: + def builder(_meta: object) -> TransportContext: return TransportContext(kind="jsonrpc", can_send_request=True) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( From caa7ca9e1e9c0e0ff32d52fddf4fd9f6283b99cb Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 12:25:37 +0000 Subject: [PATCH 30/52] style: ASCII-only docstrings/comments in v2 dispatcher files Replace double-backticks with single, em-dashes with hyphens, and the one fat-arrow with => across files this branch touches. Also drop the unused TransportBuilder type alias (stale after the request_id arg was removed from the builder signature). --- src/mcp/server/_typed_request.py | 14 +++--- src/mcp/server/connection.py | 30 ++++++------ src/mcp/server/context.py | 24 +++++----- src/mcp/server/lowlevel/server.py | 20 ++++---- src/mcp/server/runner.py | 24 +++++----- src/mcp/shared/context.py | 8 ++-- src/mcp/shared/direct_dispatcher.py | 12 ++--- src/mcp/shared/dispatcher.py | 52 ++++++++++---------- src/mcp/shared/exceptions.py | 2 +- src/mcp/shared/jsonrpc_dispatcher.py | 63 ++++++++++++------------- src/mcp/shared/peer.py | 28 +++++------ src/mcp/shared/transport_context.py | 8 ++-- tests/server/conftest.py | 6 +-- tests/server/test_connection.py | 2 +- tests/server/test_runner.py | 10 ++-- tests/shared/test_context.py | 6 +-- tests/shared/test_dispatcher.py | 6 +-- tests/shared/test_jsonrpc_dispatcher.py | 10 ++-- 18 files changed, 160 insertions(+), 165 deletions(-) diff --git a/src/mcp/server/_typed_request.py b/src/mcp/server/_typed_request.py index 4334b20a94..ab3ac11803 100644 --- a/src/mcp/server/_typed_request.py +++ b/src/mcp/server/_typed_request.py @@ -1,12 +1,12 @@ -"""Typed ``send_request`` for server-to-client requests. +"""Typed `send_request` for server-to-client requests. `TypedServerRequestMixin` provides a typed `send_request(req) -> Result` over the host's raw `Outbound.send_raw_request`. Spec server-to-client request types have their result type inferred via per-type overloads; custom requests pass -``result_type=`` explicitly. +`result_type=` explicitly. If the spec's request set grows substantially, consider declaring the result -mapping on the request types themselves (a ``__mcp_result__`` ClassVar read via +mapping on the request types themselves (a `__mcp_result__` ClassVar read via a structural protocol) so this overload ladder doesn't need maintaining per-host-class. """ @@ -42,10 +42,10 @@ class TypedServerRequestMixin: - """Typed ``send_request`` for the server-to-client request set. + """Typed `send_request` for the server-to-client request set. Mixed into `Connection` and the server `Context`. Each method constrains - ``self`` to `Outbound` so any host with ``send_raw_request`` works. + `self` to `Outbound` so any host with `send_raw_request` works. """ @overload @@ -74,12 +74,12 @@ async def send_request( """Send a typed server-to-client request and return its typed result. For spec request types the result type is inferred. For custom requests - pass ``result_type=`` explicitly. + pass `result_type=` explicitly. Raises: MCPError: The peer responded with an error. NoBackChannelError: No back-channel for server-initiated requests. - KeyError: ``result_type`` omitted for a non-spec request type. + KeyError: `result_type` omitted for a non-spec request type. """ raw = await self.send_raw_request(req.method, dump_params(req.params), opts) cls = result_type if result_type is not None else _RESULT_FOR[type(req)] diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py index 5991715a44..65267e34ca 100644 --- a/src/mcp/server/connection.py +++ b/src/mcp/server/connection.py @@ -1,14 +1,14 @@ -"""`Connection` — per-client connection state and the standalone outbound channel. +"""`Connection` - per-client connection state and the standalone outbound channel. -Always present on `Context` (never ``None``), even in stateless deployments. -Holds peer info populated at ``initialize`` time, per-connection scratch -``state`` and an ``exit_stack`` for teardown, and an `Outbound` for the +Always present on `Context` (never `None`), even in stateless deployments. +Holds peer info populated at `initialize` time, per-connection scratch +`state` and an `exit_stack` for teardown, and an `Outbound` for the standalone stream (the SSE GET stream in streamable HTTP, or the single duplex stream in stdio). `notify` is best-effort: it never raises. If there's no standalone channel (stateless HTTP) or the stream has been dropped, the notification is -debug-logged and silently discarded — server-initiated notifications are +debug-logged and silently discarded - server-initiated notifications are inherently advisory. `send_raw_request` *does* raise `NoBackChannelError` when there's no channel; `ping` is the only spec-sanctioned standalone request. """ @@ -43,7 +43,7 @@ class Connection(TypedServerRequestMixin): """Per-client connection state and standalone-stream `Outbound`. Constructed by `ServerRunner` once per connection. The peer-info fields are - ``None`` until ``initialize`` completes; ``initialized`` is set then. + `None` until `initialize` completes; `initialized` is set then. """ def __init__(self, outbound: Outbound, *, has_standalone_channel: bool, session_id: str | None = None) -> None: @@ -63,8 +63,8 @@ def __init__(self, outbound: Outbound, *, has_standalone_channel: bool, session_ self.exit_stack: AsyncExitStack = AsyncExitStack() """Cleanup stack unwound by `ServerRunner` when the connection closes. - Push context managers (``await exit_stack.enter_async_context(...)``) - or callbacks (``exit_stack.push_async_callback(...)``) from handlers or + Push context managers (`await exit_stack.enter_async_context(...)`) + or callbacks (`exit_stack.push_async_callback(...)`) from handlers or middleware to register per-connection teardown. Unwound LIFO after `dispatcher.run()` returns, shielded from cancellation.""" @@ -76,13 +76,13 @@ async def send_raw_request( ) -> dict[str, Any]: """Send a raw request on the standalone stream. - Low-level `Outbound` channel. Prefer the typed ``send_request`` (from + Low-level `Outbound` channel. Prefer the typed `send_request` (from `TypedServerRequestMixin`) or the convenience methods below; use this directly only for off-spec messages. Raises: MCPError: The peer responded with an error. - NoBackChannelError: ``has_standalone_channel`` is ``False``. + NoBackChannelError: `has_standalone_channel` is `False`. """ if not self.has_standalone_channel: raise NoBackChannelError(method) @@ -103,16 +103,16 @@ async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: logger.debug("dropped %s: standalone stream closed", method) async def ping(self, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None: - """Send a ``ping`` request on the standalone stream. + """Send a `ping` request on the standalone stream. Raises: MCPError: The peer responded with an error. - NoBackChannelError: ``has_standalone_channel`` is ``False``. + NoBackChannelError: `has_standalone_channel` is `False`. """ await self.send_raw_request("ping", dump_params(None, meta), opts) async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: - """Send a ``notifications/message`` log entry on the standalone stream. Best-effort.""" + """Send a `notifications/message` log entry on the standalone stream. Best-effort.""" params: dict[str, Any] = {"level": level, "data": data} if logger is not None: params["logger"] = logger @@ -133,9 +133,9 @@ async def send_resource_updated(self, uri: str, *, meta: Meta | None = None) -> def check_capability(self, capability: ClientCapabilities) -> bool: """Return whether the connected client declared the given capability. - Returns ``False`` if ``initialize`` hasn't completed yet. + Returns `False` if `initialize` hasn't completed yet. """ - # TODO: redesign — mirrors v1 ServerSession.check_client_capability + # TODO: redesign - mirrors v1 ServerSession.check_client_capability # verbatim for parity. if self.client_capabilities is None: return False diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index d1514a9add..fd5c20f264 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -39,9 +39,9 @@ class Context(BaseContext[TransportContext], PeerMixin, TypedServerRequestMixin, """Server-side per-request context. Composes `BaseContext` (forwards to `DispatchContext`, satisfies `Outbound`), - `PeerMixin` (kwarg-style ``sample``/``elicit_*``/``list_roots``/``ping``), - and `TypedServerRequestMixin` (typed ``send_request(req) -> Result``). Adds - ``lifespan`` and ``connection``. + `PeerMixin` (kwarg-style `sample`/`elicit_*`/`list_roots`/`ping`), + and `TypedServerRequestMixin` (typed `send_request(req) -> Result`). Adds + `lifespan` and `connection`. Constructed by `ServerRunner` per inbound request and handed to the user's handler. @@ -73,7 +73,7 @@ def connection(self) -> Connection: def session_id(self) -> str | None: """The transport's session id for this connection, when one exists. - Convenience for ``ctx.connection.session_id``. ``None`` on stdio and + Convenience for `ctx.connection.session_id`. `None` on stdio and stateless HTTP. """ return self._connection.session_id @@ -82,16 +82,16 @@ def session_id(self) -> str | None: def headers(self) -> Mapping[str, str] | None: """Request headers carried by this message, when the transport has them. - Convenience for ``ctx.transport.headers``. ``None`` on stdio. + Convenience for `ctx.transport.headers`. `None` on stdio. """ return self.transport.headers async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: - """Send a request-scoped ``notifications/message`` log entry. + """Send a request-scoped `notifications/message` log entry. Uses this request's back-channel (so the entry rides the request's SSE - stream in streamable HTTP), not the standalone stream — use - ``ctx.connection.log(...)`` for that. + stream in streamable HTTP), not the standalone stream - use + `ctx.connection.log(...)` for that. """ params: dict[str, Any] = {"level": level, "data": data} if logger is not None: @@ -111,16 +111,16 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, * class ServerMiddleware(Protocol[_MwLifespanT]): - """Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``. + """Context-tier middleware: `(ctx, method, typed_params, call_next) -> result`. Runs *inside* `ServerRunner._on_request` after params validation and - `Context` construction. Wraps registered handlers (including ``ping``) but - not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed + `Context` construction. Wraps registered handlers (including `ping`) but + not `initialize`, `METHOD_NOT_FOUND`, or validation failures. Listed outermost-first on `Server.middleware`. `Server[L].middleware` holds `ServerMiddleware[L]`, so an app-specific middleware sees `ctx.lifespan: L`. A reusable middleware can be typed - `ServerMiddleware[object]` — `Context` is covariant in `LifespanT`, so it + `ServerMiddleware[object]` - `Context` is covariant in `LifespanT`, so it registers on any `Server[L]`. """ diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 375ca94c0d..afa2f90109 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -81,10 +81,10 @@ async def main(): _ParamsT = TypeVar("_ParamsT", bound=BaseModel, default=BaseModel) RequestHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[HandlerResult]] -"""A registered request handler: ``(ctx, params) -> result``.""" +"""A registered request handler: `(ctx, params) -> result`.""" NotificationHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[None]] -"""A registered notification handler: ``(ctx, params) -> None``.""" +"""A registered notification handler: `(ctx, params) -> None`.""" @dataclass(frozen=True, slots=True) @@ -93,7 +93,7 @@ class HandlerEntry(Generic[LifespanResultT]): Stored in `Server._request_handlers` / `_notification_handlers` and consumed by `ServerRunner` to validate, build `Context`, and invoke. The handler's - second-argument type is erased to ``Any`` in storage (each entry has a + second-argument type is erased to `Any` in storage (each entry has a different concrete params type and `Callable` parameters are contravariant); the precise type is recoverable via `params_type`. The correlation is enforced at registration time by `Server.add_request_handler`. @@ -262,11 +262,11 @@ def add_request_handler( params_type: type[_ParamsT], handler: RequestHandler[LifespanResultT, _ParamsT], ) -> None: - """Register a request handler for ``method``. + """Register a request handler for `method`. - ``params_type`` is the model incoming params are validated against + `params_type` is the model incoming params are validated against before the handler is invoked. It should subclass `RequestParams` so - ``_meta`` parses uniformly. Replaces any existing handler for the same + `_meta` parses uniformly. Replaces any existing handler for the same method (no collision guard against spec methods). """ self._request_handlers[method] = HandlerEntry(params_type, handler) @@ -277,9 +277,9 @@ def add_notification_handler( params_type: type[_ParamsT], handler: NotificationHandler[LifespanResultT, _ParamsT], ) -> None: - """Register a notification handler for ``method``. + """Register a notification handler for `method`. - ``params_type`` should subclass `NotificationParams` so ``_meta`` + `params_type` should subclass `NotificationParams` so `_meta` parses uniformly. Replaces any existing handler. """ self._notification_handlers[method] = HandlerEntry(params_type, handler) @@ -300,11 +300,11 @@ def _has_handler(self, method: str) -> bool: # --- ServerRegistry protocol (consumed by ServerRunner) ------------------ def get_request_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None: - """Return the registered entry for a request method, or ``None``.""" + """Return the registered entry for a request method, or `None`.""" return self._request_handlers.get(method) def get_notification_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None: - """Return the registered entry for a notification method, or ``None``.""" + """Return the registered entry for a notification method, or `None`.""" return self._notification_handlers.get(method) def capabilities(self) -> types.ServerCapabilities: diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 1ba732ec4b..662827e592 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -1,16 +1,16 @@ -"""`ServerRunner` — per-connection orchestrator over a `Dispatcher`. +"""`ServerRunner` - per-connection orchestrator over a `Dispatcher`. `ServerRunner` is the bridge between the dispatcher layer (`on_request` / `on_notify`, untyped dicts) and the user's handler layer (typed `Context`, typed params). One instance per client connection. It: -* handles the ``initialize`` handshake and populates `Connection` -* gates requests until initialized (``ping`` exempt) +* handles the `initialize` handshake and populates `Connection` +* gates requests until initialized (`ping` exempt) * looks up the handler in the server's registry, validates params, builds `Context`, runs the middleware chain, returns the result dict -* drives ``dispatcher.run()`` and the per-connection lifespan +* drives `dispatcher.run()` and the per-connection lifespan -`ServerRunner` holds a `Server` directly — `Server` is the registry. +`ServerRunner` holds a `Server` directly - `Server` is the registry. """ from __future__ import annotations @@ -56,8 +56,8 @@ def otel_middleware(next_on_request: OnRequest) -> OnRequest: """Dispatch-tier middleware that wraps each request in an OpenTelemetry span. Mirrors the span shape of the existing `Server._handle_request`: span name - ``"MCP handle []"``, ``mcp.method.name`` attribute, W3C - trace context extracted from ``params._meta`` (SEP-414), and an ERROR + `"MCP handle []"`, `mcp.method.name` attribute, W3C + trace context extracted from `params._meta` (SEP-414), and an ERROR status if the handler raises. """ @@ -133,8 +133,8 @@ async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STAT """Drive the dispatcher until the underlying channel closes. Composes `dispatch_middleware` over `_on_request` and hands the result - to `dispatcher.run()`. ``task_status.started()`` is forwarded so callers - can ``await tg.start(runner.run)`` and resume once the dispatcher is + to `dispatcher.run()`. `task_status.started()` is forwarded so callers + can `await tg.start(runner.run)` and resume once the dispatcher is ready to accept requests. Once the dispatcher exits, `connection.exit_stack` is unwound (shielded) so any per-connection cleanup registered by handlers or middleware runs to completion. @@ -148,8 +148,8 @@ async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STAT def _compose_on_request(self) -> OnRequest: """Wrap `_on_request` in `dispatch_middleware`, outermost-first. - Dispatch-tier middleware sees raw ``(dctx, method, params) -> dict`` - and wraps everything — initialize, METHOD_NOT_FOUND, validation + Dispatch-tier middleware sees raw `(dctx, method, params) -> dict` + and wraps everything - initialize, METHOD_NOT_FOUND, validation failures included. `run()` calls this once and hands the result to `dispatcher.run()`. """ @@ -212,7 +212,7 @@ def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any] self.connection.client_info = init.client_info self.connection.client_capabilities = init.capabilities # TODO: real version negotiation. This always responds with LATEST, - # which is wrong — the server should pick the highest version both + # which is wrong - the server should pick the highest version both # sides support and compute a per-connection feature set from it. # See FOLLOWUPS: "Consolidate per-connection mode/negotiation". self.connection.protocol_version = ( diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index ff69c48401..437f821a81 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,4 +1,4 @@ -"""`BaseContext` — the user-facing per-request context. +"""`BaseContext` - the user-facing per-request context. Composition over a `DispatchContext`: forwards the transport metadata, the back-channel (`send_raw_request`/`notify`), progress reporting, and the cancel @@ -43,7 +43,7 @@ def transport(self) -> TransportT: @property def cancel_requested(self) -> anyio.Event: - """Set when the peer sends ``notifications/cancelled`` for this request.""" + """Set when the peer sends `notifications/cancelled` for this request.""" return self._dctx.cancel_requested @property @@ -53,7 +53,7 @@ def can_send_request(self) -> bool: @property def meta(self) -> RequestParamsMeta | None: - """The inbound request's ``_meta`` field, if present.""" + """The inbound request's `_meta` field, if present.""" return self._meta async def send_raw_request( @@ -66,7 +66,7 @@ async def send_raw_request( Raises: MCPError: The peer responded with an error. - NoBackChannelError: ``can_send_request`` is ``False``. + NoBackChannelError: `can_send_request` is `False`. """ return await self._dctx.send_raw_request(method, params, opts) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index b688d63afa..51dddf1e79 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -10,7 +10,7 @@ * embed a server in-process when the JSON-RPC overhead is unnecessary Unlike `JSONRPCDispatcher`, exceptions raised in a handler propagate directly -to the caller — there is no exception-to-`ErrorData` boundary here. +to the caller - there is no exception-to-`ErrorData` boundary here. """ from __future__ import annotations @@ -49,9 +49,9 @@ class _DirectDispatchContext: _back_request: _Request _back_notify: _Notify request_id: RequestId | None = None - """Always ``None``: direct dispatch has no wire-level request id.""" + """Always `None`: direct dispatch has no wire-level request id.""" message_metadata: MessageMetadata = None # TODO(maxisbey): remove for Context rework - """Always ``None``: in-memory dispatch attaches no transport metadata.""" + """Always `None`: in-memory dispatch attaches no transport metadata.""" _on_progress: ProgressFnT | None = None cancel_requested: anyio.Event = field(default_factory=anyio.Event) @@ -173,12 +173,12 @@ def create_direct_dispatcher_pair( Args: can_send_request: Sets `TransportContext.can_send_request` on both - sides. Pass ``False`` to simulate a transport with no back-channel. + sides. Pass `False` to simulate a transport with no back-channel. headers: Sets `TransportContext.headers` on both sides. Returns: - A ``(left, right)`` pair. Conventionally ``left`` is the client side - and ``right`` is the server side, but the wiring is symmetric. + A `(left, right)` pair. Conventionally `left` is the client side + and `right` is the server side, but the wiring is symmetric. """ ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request, headers=headers) left = DirectDispatcher(ctx) diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index f01eaebe87..aca96231f3 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -1,18 +1,18 @@ -"""Dispatcher Protocol — the call/return boundary between transports and handlers. +"""Dispatcher Protocol - the call/return boundary between transports and handlers. A Dispatcher turns a duplex message channel into two things: -* an outbound API: ``send_raw_request(method, params)`` and ``notify(method, params)`` -* an inbound pump: ``run(on_request, on_notify)`` that drives the receive loop +* an outbound API: `send_raw_request(method, params)` and `notify(method, params)` +* an inbound pump: `run(on_request, on_notify)` that drives the receive loop and invokes the supplied handlers for each incoming request/notification It is deliberately *not* MCP-aware. Method names are strings, params and -results are ``dict[str, Any]``. The MCP type layer (request/result models, -capability negotiation, ``Context``) sits above this; the wire encoding +results are `dict[str, Any]`. The MCP type layer (request/result models, +capability negotiation, `Context`) sits above this; the wire encoding (JSON-RPC, gRPC, in-process direct calls) sits below it. -See ``JSONRPCDispatcher`` for the production implementation and -``DirectDispatcher`` for an in-memory implementation used in tests and for +See `JSONRPCDispatcher` for the production implementation and +`DirectDispatcher` for an in-memory implementation used in tests and for embedding a server in-process. """ @@ -53,10 +53,10 @@ class CallOptions(TypedDict, total=False): """ timeout: float - """Seconds to wait for a result before raising and sending ``notifications/cancelled``.""" + """Seconds to wait for a result before raising and sending `notifications/cancelled`.""" on_progress: ProgressFnT - """Receive ``notifications/progress`` updates for this request.""" + """Receive `notifications/progress` updates for this request.""" resumption_token: str """Opaque token to resume a previously interrupted request (transport-dependent).""" @@ -71,7 +71,7 @@ class Outbound(Protocol): Both `Dispatcher` (top-level outbound) and `DispatchContext` (back-channel during an inbound request) extend this. The MCP type layer (`PeerMixin`, - `Connection`, `Context`) builds typed ``send_request`` / convenience methods + `Connection`, `Context`) builds typed `send_request` / convenience methods on top of this raw channel. """ @@ -96,12 +96,12 @@ async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: class DispatchContext(Outbound, Protocol[TransportT_co]): - """Per-request context handed to ``on_request`` / ``on_notify``. + """Per-request context handed to `on_request` / `on_notify`. Carries the transport metadata for the inbound message and provides the back-channel for sending requests/notifications to the peer while handling it. `send_raw_request` raises `NoBackChannelError` if - ``transport.can_send_request`` is ``False``. + `transport.can_send_request` is `False`. """ @property @@ -111,10 +111,10 @@ def transport(self) -> TransportT_co: @property def request_id(self) -> RequestId | None: - """The id of the inbound request, or ``None`` for a notification. + """The id of the inbound request, or `None` for a notification. - For JSON-RPC this is the wire ``id`` field. Handlers thread it through - as ``related_request_id`` on outbound notifications so HTTP transports + For JSON-RPC this is the wire `id` field. Handlers thread it through + as `related_request_id` on outbound notifications so HTTP transports can route them onto the originating request's response stream. """ ... @@ -126,7 +126,7 @@ def message_metadata(self) -> MessageMetadata: This is `SessionMessage.metadata` passed through verbatim: HTTP transports attach `ServerMessageMetadata` (the HTTP request, SSE stream-close callbacks); stdio and in-memory dispatch attach nothing. - Tied to the `SessionMessage` wire format — goes away when transports + Tied to the `SessionMessage` wire format - goes away when transports stop delivering messages that way. """ # TODO(maxisbey): remove for context rework @@ -134,7 +134,7 @@ def message_metadata(self) -> MessageMetadata: @property def cancel_requested(self) -> anyio.Event: - """Set when the peer sends ``notifications/cancelled`` for this request.""" + """Set when the peer sends `notifications/cancelled` for this request.""" ... async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: @@ -146,13 +146,13 @@ async def progress(self, progress: float, total: float | None = None, message: s OnRequest = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[dict[str, Any]]] -"""Handler for inbound requests: ``(ctx, method, params) -> result``. Raise ``MCPError`` to send an error response.""" +"""Handler for inbound requests: `(ctx, method, params) -> result`. Raise `MCPError` to send an error response.""" OnNotify = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[None]] -"""Handler for inbound notifications: ``(ctx, method, params)``.""" +"""Handler for inbound notifications: `(ctx, method, params)`.""" DispatchMiddleware = Callable[[OnRequest], OnRequest] -"""Wraps an ``OnRequest`` to produce another ``OnRequest``. Applied outermost-first.""" +"""Wraps an `OnRequest` to produce another `OnRequest`. Applied outermost-first.""" class Dispatcher(Outbound, Protocol[TransportT_co]): @@ -171,12 +171,12 @@ async def run( ) -> None: """Drive the receive loop until the underlying channel closes. - Each inbound request is dispatched to ``on_request`` in its own task; - the returned dict (or raised ``MCPError``) is sent back as the response. - Inbound notifications go to ``on_notify``. + Each inbound request is dispatched to `on_request` in its own task; + the returned dict (or raised `MCPError`) is sent back as the response. + Inbound notifications go to `on_notify`. - ``task_status.started()`` is called once the dispatcher is ready to - accept ``send_request``/``notify`` calls, so callers can use - ``await tg.start(dispatcher.run, on_request, on_notify)``. + `task_status.started()` is called once the dispatcher is ready to + accept `send_request`/`notify` calls, so callers can use + `await tg.start(dispatcher.run, on_request, on_notify)`. """ ... diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index b62629b6c8..f81b737cc1 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -47,7 +47,7 @@ class NoBackChannelError(MCPError): Stateless HTTP and JSON-response-mode HTTP have no channel for the server to push requests (sampling, elicitation, roots/list) to the client. This is raised by `DispatchContext.send_raw_request` when `transport.can_send_request` - is ``False``, and serializes to an ``INVALID_REQUEST`` error response. + is `False`, and serializes to an `INVALID_REQUEST` error response. """ def __init__(self, method: str): diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 5ea5ea88f1..684807fc15 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -9,12 +9,12 @@ sees only `(ctx, method, params) -> dict`. Transports sit below and see only `SessionMessage` reads/writes. -The dispatcher is *mostly* MCP-agnostic — methods/params are opaque strings and -dicts — but it intercepts ``notifications/cancelled`` and -``notifications/progress`` because request correlation, cancellation and +The dispatcher is *mostly* MCP-agnostic - methods/params are opaque strings and +dicts - but it intercepts `notifications/cancelled` and +`notifications/progress` because request correlation, cancellation and progress are exactly the wiring this layer exists to provide. Those few wire -shapes are extracted with structural ``match`` patterns (no casts, no -``mcp.types`` model coupling); a malformed payload simply fails to match and +shapes are extracted with structural `match` patterns (no casts, no +`mcp.types` model coupling); a malformed payload simply fails to match and the correlation is skipped. """ @@ -64,22 +64,17 @@ TransportT = TypeVar("TransportT", bound=TransportContext) PeerCancelMode = Literal["interrupt", "signal"] -"""How inbound ``notifications/cancelled`` is applied to a running handler. +"""How inbound `notifications/cancelled` is applied to a running handler. -``"interrupt"`` (default) cancels the handler's scope. ``"signal"`` only sets -``ctx.cancel_requested`` and lets the handler observe it cooperatively. +`"interrupt"` (default) cancels the handler's scope. `"signal"` only sets +`ctx.cancel_requested` and lets the handler observe it cooperatively. """ -TransportBuilder = Callable[[RequestId | None, MessageMetadata], TransportContext] -"""Builds the per-message `TransportContext` from the inbound JSON-RPC id and -the `SessionMessage.metadata` the transport attached. Defaults to a plain -`TransportContext(kind="jsonrpc", can_send_request=True)` when not supplied.""" - def _coerce_id(request_id: RequestId) -> RequestId: """Coerce a string request ID to int when it's a valid int literal. - `_allocate_id` only ever produces ``int`` keys for ``_pending``, but a peer + `_allocate_id` only ever produces `int` keys for `_pending`, but a peer may echo the ID back as a JSON string. The TypeScript SDK and `BaseSession` both perform this coercion at lookup time so the response still correlates. """ @@ -120,7 +115,7 @@ class _JSONRPCDispatchContext(Generic[TransportT]): """The transport-attached `SessionMessage.metadata` for this inbound message. Carries `ServerMessageMetadata` (HTTP request, SSE stream-close callbacks) - that the server lifts onto its request context. ``None`` for transports + that the server lifts onto its request context. `None` for transports that attach nothing. """ _progress_token: ProgressToken | None = None @@ -172,7 +167,7 @@ def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | `ServerMessageMetadata` tags a server-to-client message with the inbound request it belongs to (so streamable-HTTP can route it onto that request's SSE stream). `ClientMessageMetadata` carries resumption hints to the - client transport. ``None`` is the common case. + client transport. `None` is the common case. """ if related_request_id is not None: return ServerMessageMetadata(related_request_id=related_request_id) @@ -244,17 +239,17 @@ async def send_raw_request( ) -> dict[str, Any]: """Send a JSON-RPC request and await its response. - ``_related_request_id`` is set only by `_JSONRPCDispatchContext` when a + `_related_request_id` is set only by `_JSONRPCDispatchContext` when a handler makes a server-to-client request mid-flight; it routes the outgoing message onto the correct per-request SSE stream (SHTTP) via - `ServerMessageMetadata`. Top-level callers leave it ``None``. + `ServerMessageMetadata`. Top-level callers leave it `None`. Raises: MCPError: The peer responded with a JSON-RPC error; or - ``REQUEST_TIMEOUT`` if ``opts["timeout"]`` elapsed; or - ``CONNECTION_CLOSED`` if the dispatcher shut down while + `REQUEST_TIMEOUT` if `opts["timeout"]` elapsed; or + `CONNECTION_CLOSED` if the dispatcher shut down while awaiting the response. - RuntimeError: Called before ``run()`` has started or after it has + RuntimeError: Called before `run()` has started or after it has finished. """ if not self._running: @@ -276,7 +271,7 @@ async def send_raw_request( # buffer=1: at most one outcome is ever delivered. A `WouldBlock` from # `_resolve_pending`/`_fan_out_closed` means the waiter already has an # outcome and dropping the late/redundant signal is correct. buffer=0 - # is unsafe — there's a window between registering `_pending[id]` and + # is unsafe - there's a window between registering `_pending[id]` and # parking in `receive()` where a close signal would be lost. send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) pending = _Pending(send=send, receive=receive, on_progress=on_progress) @@ -296,7 +291,7 @@ async def send_raw_request( raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None except anyio.get_cancelled_exc_class(): # Our caller's scope was cancelled. We're already inside a cancelled - # scope, so any bare `await` here re-raises immediately — shield to + # scope, so any bare `await` here re-raises immediately - shield to # let the courtesy cancel notification go out before we propagate. with anyio.CancelScope(shield=True): await self._cancel_outbound(request_id, "caller cancelled") @@ -333,8 +328,8 @@ async def run( """Drive the receive loop until the read stream closes. Each inbound request is handled in its own task in an internal task - group; ``task_status.started()`` fires once that group is open, so - ``await tg.start(dispatcher.run, ...)`` resumes when ``send_raw_request`` + group; `task_status.started()` fires once that group is open, so + `await tg.start(dispatcher.run, ...)` resumes when `send_raw_request` is usable. """ try: @@ -472,12 +467,12 @@ def _spawn( *args: object, sender_ctx: contextvars.Context | None, ) -> None: - """Schedule ``fn(*args)`` in the run() task group, propagating the sender's contextvars. + """Schedule `fn(*args)` in the run() task group, propagating the sender's contextvars. ASGI middleware (auth, OTel) sets contextvars on the request task that - wrote into the read stream. ``Context.run(tg.start_soon, ...)`` makes + wrote into the read stream. `Context.run(tg.start_soon, ...)` makes the spawned handler inherit *that* context instead of the receive - loop's, so ``auth_context_var`` and OTel spans survive. + loop's, so `auth_context_var` and OTel spans survive. """ assert self._tg is not None if sender_ctx is not None: @@ -486,9 +481,9 @@ def _spawn( self._tg.start_soon(fn, *args) def _fan_out_closed(self) -> None: - """Wake every pending ``send_raw_request`` waiter with ``CONNECTION_CLOSED``. + """Wake every pending `send_raw_request` waiter with `CONNECTION_CLOSED`. - Synchronous (uses ``send_nowait``) because it's called from ``finally`` + Synchronous (uses `send_nowait`) because it's called from `finally` which may be inside a cancelled scope. Idempotent. """ closed = ErrorData(code=CONNECTION_CLOSED, message="connection closed") @@ -506,10 +501,10 @@ async def _handle_request( scope: anyio.CancelScope, on_request: OnRequest, ) -> None: - """Run ``on_request`` for one inbound request and write its response. + """Run `on_request` for one inbound request and write its response. This is the single exception-to-wire boundary: handler exceptions are - caught here and serialized to ``JSONRPCError``. Nothing above this in + caught here and serialized to `JSONRPCError`. Nothing above this in the stack constructs wire errors. """ try: @@ -518,7 +513,7 @@ async def _handle_request( result = await on_request(dctx, req.method, req.params) finally: # Close the back-channel the moment the handler exits - # (success or raise), before the response write — a handler + # (success or raise), before the response write - a handler # spawning detached work that later calls # `dctx.send_raw_request()` should see `NoBackChannelError`. dctx.close() @@ -527,7 +522,7 @@ async def _handle_request( # swallows a scope's *own* cancel at __exit__, so the result write # (or the handler) is interrupted and execution lands here without # reaching the `except cancelled` arm below. Spec SHOULD: send no - # response — fall through to `finally`. + # response - fall through to `finally`. except anyio.get_cancelled_exc_class(): # Outer-cancel: run()'s task group is shutting down. Any bare # `await` here re-raises immediately, so shield the courtesy write. diff --git a/src/mcp/shared/peer.py b/src/mcp/shared/peer.py index 47b64c7769..a7347e30cc 100644 --- a/src/mcp/shared/peer.py +++ b/src/mcp/shared/peer.py @@ -2,11 +2,11 @@ `PeerMixin` defines the server-to-client request methods (sampling, elicitation, roots, ping) once. Any class that satisfies `Outbound` (i.e. has -``send_raw_request`` and ``notify``) can mix it in and get the typed methods for -free — `Context`, `Connection`, `Client`, or the bare `Peer` wrapper below. +`send_raw_request` and `notify`) can mix it in and get the typed methods for +free - `Context`, `Connection`, `Client`, or the bare `Peer` wrapper below. The mixin does no capability gating: it builds the params, calls -``self.send_raw_request(method, params)``, and parses the result into the typed +`self.send_raw_request(method, params)`, and parses the result into the typed model. Gating (and `NoBackChannelError`) is the host's `send_raw_request`'s job. """ @@ -35,15 +35,15 @@ __all__ = ["Meta", "Peer", "PeerMixin", "dump_params"] Meta = dict[str, Any] -"""Type alias for the ``_meta`` field carried on request/notification params.""" +"""Type alias for the `_meta` field carried on request/notification params.""" def dump_params(model: BaseModel | None, meta: Meta | None = None) -> dict[str, Any] | None: - """Serialize a params model to a wire dict, merging ``meta`` into ``_meta``. + """Serialize a params model to a wire dict, merging `meta` into `_meta`. Shared by `PeerMixin`, `Connection`, and `TypedServerRequestMixin` so every - typed convenience method gets the same `_meta` handling. ``meta`` keys take - precedence over any ``_meta`` already present on the model. + typed convenience method gets the same `_meta` handling. `meta` keys take + precedence over any `_meta` already present on the model. """ out = model.model_dump(by_alias=True, mode="json", exclude_none=True) if model is not None else None if meta: @@ -55,8 +55,8 @@ def dump_params(model: BaseModel | None, meta: Meta | None = None) -> dict[str, class PeerMixin: """Typed server-to-client request methods. - Each method constrains ``self`` to `Outbound` so the mixin can be applied - to anything with ``send_raw_request``/``notify`` — pyright checks the host + Each method constrains `self` to `Outbound` so the mixin can be applied + to anything with `send_raw_request`/`notify` - pyright checks the host class structurally at the call site. """ @@ -110,7 +110,7 @@ async def sample( meta: Meta | None = None, opts: CallOptions | None = None, ) -> CreateMessageResult | CreateMessageResultWithTools: - """Send a ``sampling/createMessage`` request to the peer. + """Send a `sampling/createMessage` request to the peer. Raises: MCPError: The peer responded with an error. @@ -142,7 +142,7 @@ async def elicit_form( meta: Meta | None = None, opts: CallOptions | None = None, ) -> ElicitResult: - """Send a form-mode ``elicitation/create`` request. + """Send a form-mode `elicitation/create` request. Raises: MCPError: The peer responded with an error. @@ -161,7 +161,7 @@ async def elicit_url( meta: Meta | None = None, opts: CallOptions | None = None, ) -> ElicitResult: - """Send a URL-mode ``elicitation/create`` request. + """Send a URL-mode `elicitation/create` request. Raises: MCPError: The peer responded with an error. @@ -174,7 +174,7 @@ async def elicit_url( async def list_roots( self: Outbound, *, meta: Meta | None = None, opts: CallOptions | None = None ) -> ListRootsResult: - """Send a ``roots/list`` request. + """Send a `roots/list` request. Raises: MCPError: The peer responded with an error. @@ -184,7 +184,7 @@ async def list_roots( return ListRootsResult.model_validate(result) async def ping(self: Outbound, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None: - """Send a ``ping`` request and ignore the result. + """Send a `ping` request and ignore the result. Raises: MCPError: The peer responded with an error. diff --git a/src/mcp/shared/transport_context.py b/src/mcp/shared/transport_context.py index 9346116707..55e5f6bc5f 100644 --- a/src/mcp/shared/transport_context.py +++ b/src/mcp/shared/transport_context.py @@ -20,19 +20,19 @@ class TransportContext: """ kind: str - """Short identifier for the transport (e.g. ``"stdio"``, ``"streamable-http"``).""" + """Short identifier for the transport (e.g. `"stdio"`, `"streamable-http"`).""" can_send_request: bool """Whether the transport can deliver server-initiated requests to the peer. - ``False`` for stateless HTTP and HTTP with JSON response mode; ``True`` for - stdio, SSE, and stateful streamable HTTP. When ``False``, + `False` for stateless HTTP and HTTP with JSON response mode; `True` for + stdio, SSE, and stateful streamable HTTP. When `False`, `DispatchContext.send_raw_request` raises `NoBackChannelError`. """ headers: Mapping[str, str] | None = None """Request headers carried by this message, when the transport has them. - Populated by HTTP-based transports; ``None`` on stdio. Handlers should + Populated by HTTP-based transports; `None` on stdio. Handlers should None-check before use. """ diff --git a/tests/server/conftest.py b/tests/server/conftest.py index 290ccc957a..d70dda6526 100644 --- a/tests/server/conftest.py +++ b/tests/server/conftest.py @@ -11,8 +11,8 @@ class SpanCapture: """Thin adapter over logfire's `TestExporter` for asserting on MCP spans. `finished()` returns the raw `ReadableSpan` objects emitted by the - ``mcp-python-sdk`` instrumentation scope, filtered to exclude logfire's - synthetic ``pending_span`` markers, so tests can assert directly on + `mcp-python-sdk` instrumentation scope, filtered to exclude logfire's + synthetic `pending_span` markers, so tests can assert directly on `.name`, `.kind`, `.status`, `.attributes`, `.parent`, `.events`. """ @@ -36,7 +36,7 @@ def finished(self) -> list[ReadableSpan]: def spans(capfire: CaptureLogfire) -> Iterator[SpanCapture]: """In-memory MCP span capture, cleared before and after each test. - Backed by the project-level `capfire` override (see ``tests/conftest.py``) + Backed by the project-level `capfire` override (see `tests/conftest.py`) so there is a single global tracer provider for the suite. """ capture = SpanCapture(capfire.exporter) diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py index ded9dfd6ac..be588f7ff7 100644 --- a/tests/server/test_connection.py +++ b/tests/server/test_connection.py @@ -2,7 +2,7 @@ `Connection` wraps an `Outbound` (the standalone stream). Its `notify` is best-effort (never raises); `send_raw_request` is gated on -``has_standalone_channel``. Tested with a stub `Outbound` so we can assert wire +`has_standalone_channel`. Tested with a stub `Outbound` so we can assert wire shape and inject failures. """ diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 33df234dbe..01afd078a6 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -80,11 +80,11 @@ async def connected_runner( headers: Mapping[str, str] | None = None, dispatch_middleware: list[DispatchMiddleware] | None = None, ) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[dict[str, Any]]]]: - """Yield ``(client, runner)`` running over an in-memory dispatcher pair. + """Yield `(client, runner)` running over an in-memory dispatcher pair. Starts the client (echo handlers) and `runner.run()` in a task group, wraps - the body in ``anyio.fail_after(5)``, and cancels on exit. When - ``initialized`` is true the helper performs the real ``initialize`` request + the body in `anyio.fail_after(5)`, and cancels on exit. When + `initialized` is true the helper performs the real `initialize` request before yielding, so tests start past the init-gate via the public path. """ client, server_d = create_direct_dispatcher_pair(headers=headers) @@ -119,7 +119,7 @@ async def connected_runner( @pytest.mark.anyio async def test_connected_runner_propagates_body_exception_unwrapped(server: SrvT): - """The harness re-raises body exceptions as-is, not as ``ExceptionGroup``.""" + """The harness re-raises body exceptions as-is, not as `ExceptionGroup`.""" with pytest.raises(RuntimeError, match="boom"): async with connected_runner(server): raise RuntimeError("boom") @@ -362,7 +362,7 @@ async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, s [span] = spans.finished() assert span.status.status_code == StatusCode.ERROR assert span.status.description == "Method not found: nonexistent/method" - # MCPError is a protocol-level response, not a crash — no traceback event. + # MCPError is a protocol-level response, not a crash - no traceback event. assert not [e for e in span.events if e.name == "exception"] diff --git a/tests/shared/test_context.py b/tests/shared/test_context.py index 882f90bfab..c260c9be80 100644 --- a/tests/shared/test_context.py +++ b/tests/shared/test_context.py @@ -1,8 +1,8 @@ """Tests for `BaseContext`. -`BaseContext` is composition over a `DispatchContext` — it forwards -``transport``/``cancel_requested``/``send_raw_request``/``notify``/``progress`` -and adds ``meta``. It must satisfy `Outbound` so `PeerMixin` works on it. +`BaseContext` is composition over a `DispatchContext` - it forwards +`transport`/`cancel_requested`/`send_raw_request`/`notify`/`progress` +and adds `meta`. It must satisfy `Outbound` so `PeerMixin` works on it. """ from collections.abc import Mapping diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index a9f35e5c2e..c52701bcd9 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -1,7 +1,7 @@ """Behavioral tests for the Dispatcher Protocol. The contract tests are parametrized over every `Dispatcher` implementation via -the `pair_factory` fixture (see ``conftest.py``); they must pass for both +the `pair_factory` fixture (see `conftest.py`); they must pass for both `DirectDispatcher` and `JSONRPCDispatcher`. Implementation-specific tests pass a concrete factory directly. """ @@ -55,7 +55,7 @@ async def running_pair( client_on_notify: OnNotify | None = None, can_send_request: bool = True, ) -> AsyncIterator[tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Recorder, Recorder]]: - """Yield ``(client, server, client_recorder, server_recorder)`` with both ``run()`` loops live.""" + """Yield `(client, server, client_recorder, server_recorder)` with both `run()` loops live.""" client, server, close = factory(can_send_request=can_send_request) client_rec, server_rec = Recorder(), Recorder() c_req, c_notify = echo_handlers(client_rec) @@ -198,7 +198,7 @@ async def server_on_request( @pytest.mark.anyio async def test_ctx_message_metadata_is_none_when_transport_attaches_nothing(pair_factory: PairFactory): - """Plain requests carry no transport metadata, so handlers see ``None``.""" + """Plain requests carry no transport metadata, so handlers see `None`.""" async with running_pair(pair_factory) as (client, _server, _crec, srec): with anyio.fail_after(5): await client.send_raw_request("tools/call", None) diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index de1b6c4600..103fae61bf 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -3,7 +3,7 @@ Behaviors with no `DirectDispatcher` analog: request-id correlation, the exception-to-wire boundary, peer-cancel handling, and shutdown fan-out. The contract tests shared with `DirectDispatcher` live in -``test_dispatcher.py``. +`test_dispatcher.py`. """ import contextvars @@ -300,7 +300,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> ) ) with anyio.fail_after(5): - await s2c_recv.receive() # response sent ⇒ the handler has run + await s2c_recv.receive() # response sent => the handler has run tg.cancel_scope.cancel() finally: for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): @@ -502,7 +502,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> proceed.set() with anyio.fail_after(5): await handlers_done.wait() - # run() must still be healthy — close the read side to let it exit cleanly. + # run() must still be healthy - close the read side to let it exit cleanly. c2s_send.close() finally: for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): @@ -553,14 +553,14 @@ def test_resolve_pending_drops_outcome_when_waiter_stream_already_closed(): d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] - recv.close() # waiter gone — send_nowait will raise BrokenResourceError + recv.close() # waiter gone - send_nowait will raise BrokenResourceError d._resolve_pending(1, {"late": True}) # pyright: ignore[reportPrivateUsage] for s in (c2s_send, c2s_recv, s2c_send, s2c_recv, send): s.close() def test_fan_out_closed_drops_signal_when_waiter_already_has_outcome(): - """White-box: the buffer=1 invariant — WouldBlock means waiter already has an outcome.""" + """White-box: the buffer=1 invariant - WouldBlock means waiter already has an outcome.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) From 013d406c39217c27d1e117318802d074ae37e19b Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 12:33:56 +0000 Subject: [PATCH 31/52] fix: JSONRPCDispatcher.run enters both streams; ClosedResourceError is EOF run() now enters write_stream alongside read_stream so the write end is released when the read loop exits (BaseSession did this; without it every [sse] interaction leg leaks a MemoryObjectSendStream and fails under filterwarnings=error). ClosedResourceError from the read iterator is caught and treated as clean EOF. Stateless SHTTP teardown closes the dispatcher's receive end after the request is handled; the next __anext__ call on the now-closed stream raises, which previously surfaced as 'Stateless session crashed'. --- src/mcp/shared/jsonrpc_dispatcher.py | 20 ++++++++----- tests/shared/test_jsonrpc_dispatcher.py | 37 +++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 684807fc15..8776d91d86 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -337,13 +337,19 @@ async def run( self._tg = tg self._running = True task_status.started() - async with self._read_stream: - async for item in self._read_stream: - # Duck-typed: `_context_streams.ContextReceiveStream` - # exposes `.last_context` (the sender's contextvars - # snapshot per message). Plain memory streams don't. - sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None) - self._dispatch(item, on_request, on_notify, sender_ctx) + async with self._read_stream, self._write_stream: + try: + async for item in self._read_stream: + # Duck-typed: `_context_streams.ContextReceiveStream` + # exposes `.last_context` (the sender's contextvars + # snapshot per message). Plain memory streams don't. + sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None) + self._dispatch(item, on_request, on_notify, sender_ctx) + except anyio.ClosedResourceError: + # The transport closed our receive end and we looped back + # to `__anext__` on the now-closed stream (stateless SHTTP + # teardown). Same as EOF. + logger.debug("read stream closed by transport; treating as EOF") # Read stream EOF: wake any blocked `send_raw_request` waiters now, # *before* the task group joins, so handlers parked in # `dctx.send_raw_request()` can unwind and the join doesn't deadlock. diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 103fae61bf..ab7ced8cf2 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -176,6 +176,43 @@ async def caller() -> None: s.close() +@pytest.mark.anyio +async def test_run_returns_cleanly_when_read_stream_receive_end_is_closed(): + """Iterating a closed receive end raises ClosedResourceError; run() treats it as EOF. + + Stateless SHTTP teardown closes the dispatcher's receive end after the + request is handled; the next loop iteration must not surface as a crash. + """ + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + on_request, on_notify = echo_handlers(Recorder()) + # Close the dispatcher's own receive end (not the send end) before run() + # iterates it: __anext__ on a closed stream raises ClosedResourceError. + c2s_recv.close() + with anyio.fail_after(5): + await server.run(on_request, on_notify) + for s in (c2s_send, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_run_closes_write_stream_on_exit(): + """run() enters both streams; the write end is released on EOF.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + on_request, on_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + c2s_send.close() # EOF the read side; run() exits + with anyio.fail_after(5): + # Write end was entered and released by run(); peer's receive sees EOF. + with pytest.raises(anyio.EndOfStream): + await s2c_recv.receive() + s2c_recv.close() + + @pytest.mark.anyio async def test_late_response_after_timeout_is_dropped_without_crashing(): handler_started = anyio.Event() From 44574ad7fe4916639883c3686940516c96a3817d Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 12:46:40 +0000 Subject: [PATCH 32/52] fix: align JSONRPCDispatcher error shapes with existing server Three pinned wire shapes the interaction suite locks in: - Peer-cancelled requests are answered with ErrorData(code=0, message="Request cancelled"). Spec says SHOULD NOT respond, but the existing server always has. - Unhandled handler exceptions become code=0 (not INTERNAL_ERROR), message=str(e). Matches Server._handle_request. - ValidationError becomes the fixed "Invalid request parameters" / data="" shape rather than leaking the pydantic error text. All three carry TODO(maxisbey) markers; they're compat with current behavior, not the intended end state. --- src/mcp/shared/jsonrpc_dispatcher.py | 30 +++++++++++++++++-------- tests/shared/test_jsonrpc_dispatcher.py | 22 +++++++++--------- 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 8776d91d86..079637744a 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -43,7 +43,6 @@ from mcp.shared.transport_context import TransportContext from mcp.types import ( CONNECTION_CLOSED, - INTERNAL_ERROR, INVALID_PARAMS, REQUEST_CANCELLED, REQUEST_TIMEOUT, @@ -524,11 +523,15 @@ async def _handle_request( # `dctx.send_raw_request()` should see `NoBackChannelError`. dctx.close() await self._write_result(req.id, result) - # Peer-cancel: `_dispatch_notification` cancelled this scope. anyio - # swallows a scope's *own* cancel at __exit__, so the result write - # (or the handler) is interrupted and execution lands here without - # reaching the `except cancelled` arm below. Spec SHOULD: send no - # response - fall through to `finally`. + if scope.cancel_called: + # Peer-cancel: `_dispatch_notification` cancelled this scope. + # anyio swallows a scope's *own* cancel at __exit__, so the + # result write (or the handler) is interrupted and execution + # lands here rather than the `except cancelled` arm below. + # TODO(maxisbey): spec says SHOULD NOT respond after cancel. + # The existing server always has and the interaction suite pins + # that; revisit once the suite's divergence entry is resolved. + await self._write_error(req.id, ErrorData(code=0, message="Request cancelled")) except anyio.get_cancelled_exc_class(): # Outer-cancel: run()'s task group is shutting down. Any bare # `await` here re-raises immediately, so shield the courtesy write. @@ -537,11 +540,20 @@ async def _handle_request( raise except MCPError as e: await self._write_error(req.id, e.error) - except ValidationError as e: - await self._write_error(req.id, ErrorData(code=INVALID_PARAMS, message=str(e))) + except ValidationError: + # TODO(maxisbey): data="" is pinned compat with the existing + # server (which never leaked pydantic error text onto the wire). + # Consider putting the validation detail in `data` once the + # interaction suite's divergence entry is resolved. + await self._write_error( + req.id, ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + ) except Exception as e: logger.exception("handler for %r raised", req.method) - await self._write_error(req.id, ErrorData(code=INTERNAL_ERROR, message=str(e))) + # TODO(maxisbey): code=0 is pinned compat with the existing + # server's `_handle_request`. JSON-RPC says INTERNAL_ERROR + # (-32603); revisit once the suite's divergence entry is resolved. + await self._write_error(req.id, ErrorData(code=0, message=str(e))) if self._raise_handler_exceptions: raise finally: diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index ab7ced8cf2..73157cd9ce 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -70,8 +70,8 @@ async def call(method: str) -> None: @pytest.mark.anyio -async def test_handler_raising_exception_sends_internal_error_with_str_message(): - """Per design: INTERNAL_ERROR carries str(e), not a scrubbed message.""" +async def test_handler_raising_exception_sends_code_zero_with_str_message(): + """Matches the existing server's `_handle_request`: code=0, message=str(e).""" async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: raise RuntimeError("kaboom") @@ -79,13 +79,14 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", None) - assert exc.value.error.code == INTERNAL_ERROR + assert exc.value.error.code == 0 assert exc.value.error.message == "kaboom" assert exc.value.__cause__ is None # cause does not survive the wire @pytest.mark.anyio -async def test_peer_cancel_interrupt_mode_sets_cancel_requested_and_sends_no_response(): +async def test_peer_cancel_interrupt_mode_writes_cancelled_error_response(): + """Matches the existing server: a peer-cancelled request is answered with code=0.""" handler_started = anyio.Event() handler_exited = anyio.Event() seen_ctx: list[DCtx] = [] @@ -99,23 +100,22 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | handler_exited.set() raise NotImplementedError + seen_error: list[ErrorData] = [] async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): async with anyio.create_task_group() as tg: # pragma: no branch async def call_then_record() -> None: - with pytest.raises(MCPError): # we'll cancel via tg below + with pytest.raises(MCPError) as exc: await client.send_raw_request("slow", None) + seen_error.append(exc.value.error) tg.start_soon(call_then_record) await handler_started.wait() - # cancel just the handler (peer-cancel), not our caller await client.notify("notifications/cancelled", {"requestId": 1}) await handler_exited.wait() - # Handler torn down, no response was written; caller is still parked. - # Cancel the caller's task to end the test. - tg.cancel_scope.cancel() assert seen_ctx[0].cancel_requested.is_set() + assert seen_error == [ErrorData(code=0, message="Request cancelled")] @pytest.mark.anyio @@ -266,7 +266,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> sent = s2c_recv.receive_nowait() assert isinstance(sent, SessionMessage) assert isinstance(sent.message, JSONRPCError) - assert sent.message.error.code == INTERNAL_ERROR + assert sent.message.error.code == 0 finally: for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): s.close() @@ -408,7 +408,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("t", None) - assert exc.value.error.code == INVALID_PARAMS + assert exc.value.error == ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") @pytest.mark.anyio From e9ee4b40c7763c009b210e9d181f0c109836369c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 13:04:56 +0000 Subject: [PATCH 33/52] feat: JSONRPCDispatcher.send_raw_request emits CLIENT span and injects _meta Mirrors BaseSession.send_request: outbound requests are wrapped in an otel CLIENT span and W3C trace context is injected into params._meta (SEP-414). A side effect is that _meta is always present on the wire (empty under a no-op tracer), which the interaction suite's sampling/elicitation snapshots pin. The contract-test echo recorder now strips _meta so JSON-RPC and direct dispatch parametrizations record identically. TODO(maxisbey) marker added: this belongs in an outbound middleware once that seam exists; the dispatcher should not own otel. --- src/mcp/shared/jsonrpc_dispatcher.py | 33 +++++++++++++++++++------ tests/shared/test_dispatcher.py | 7 ++++-- tests/shared/test_jsonrpc_dispatcher.py | 29 +++++++++++++++++++++- 3 files changed, 58 insertions(+), 11 deletions(-) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 079637744a..27498e9dc0 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -29,8 +29,10 @@ import anyio import anyio.abc from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from opentelemetry.trace import SpanKind from pydantic import ValidationError +from mcp.shared._otel import inject_trace_context, otel_span from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.dispatcher import CallOptions, Dispatcher, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError @@ -255,7 +257,8 @@ async def send_raw_request( raise RuntimeError("JSONRPCDispatcher.send_raw_request called before run() / after close") opts = opts or {} request_id = self._allocate_id() - out_params = dict(params) if params is not None else None + out_params = dict(params) if params is not None else {} + out_meta = dict(out_params.get("_meta") or {}) on_progress = opts.get("on_progress") if on_progress is not None: # The caller wants progress updates. The spec mechanism is: include @@ -263,9 +266,8 @@ async def send_raw_request( # any `notifications/progress` it sends. We use the request id as the # token so the receive loop can find this `_Pending.on_progress` by # `_pending[token]` without a second lookup table. - meta = dict((out_params or {}).get("_meta") or {}) - meta["progressToken"] = request_id - out_params = {**(out_params or {}), "_meta": meta} + out_meta["progressToken"] = request_id + out_params["_meta"] = out_meta # buffer=1: at most one outcome is ever delivered. A `WouldBlock` from # `_resolve_pending`/`_fan_out_closed` means the waiter already has an @@ -277,11 +279,26 @@ async def send_raw_request( self._pending[request_id] = pending metadata = _outbound_metadata(_related_request_id, opts) - msg = JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=out_params) + target = out_params.get("name") + span_name = f"MCP send {method}{f' {target}' if isinstance(target, str) else ''}" + # TODO(maxisbey): the otel span + inject below mirror + # BaseSession.send_request for parity. They belong in an outbound + # middleware (symmetric with otel_middleware on the inbound side) once + # that seam exists; the dispatcher should not own otel. try: - await self._write(msg, metadata) - with anyio.fail_after(opts.get("timeout")): - outcome = await receive.receive() + with otel_span( + span_name, + kind=SpanKind.CLIENT, + attributes={"mcp.method.name": method, "jsonrpc.request.id": request_id}, + ): + # Inject W3C trace context into _meta (SEP-414). With a no-op + # tracer this writes nothing, but `_meta` itself is still + # present on the wire (and the interaction suite pins that). + inject_trace_context(out_meta) + msg = JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=out_params) + await self._write(msg, metadata) + with anyio.fail_after(opts.get("timeout")): + outcome = await receive.receive() except TimeoutError: # Spec-recommended courtesy: tell the peer we've given up so it can # stop work and free resources. v1's BaseSession.send_request does diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index c52701bcd9..d71b013573 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -34,9 +34,12 @@ def echo_handlers(recorder: Recorder) -> tuple[OnRequest, OnNotify]: async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: - recorder.requests.append((method, params)) + # Strip `_meta` so JSON-RPC and direct dispatch record identically: + # the JSON-RPC outbound path always attaches `_meta` (otel injection). + recorded = {k: v for k, v in (params or {}).items() if k != "_meta"} if params is not None else None + recorder.requests.append((method, recorded)) recorder.contexts.append(ctx) - return {"echoed": method, "params": dict(params or {})} + return {"echoed": method, "params": recorded or {}} async def on_notify(ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None) -> None: recorder.notifications.append((method, params)) diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 73157cd9ce..ec5b50d637 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -14,7 +14,7 @@ import pytest from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream -from mcp.shared.dispatcher import DispatchContext +from mcp.shared.dispatcher import CallOptions, DispatchContext from mcp.shared.exceptions import MCPError from mcp.shared.jsonrpc_dispatcher import ( # pyright: ignore[reportPrivateUsage] JSONRPCDispatcher, @@ -399,6 +399,33 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | assert received == [(0.25, None, None)] +@pytest.mark.anyio +async def test_send_raw_request_always_carries_meta_on_the_wire(): + """Outbound requests always include `params._meta` (otel injection per SEP-414). + + Caller-supplied `_meta` keys are preserved; the progress token is merged in. + """ + seen: list[Mapping[str, Any] | None] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen.append(params) + return {} + + async def noop_progress(progress: float, total: float | None, message: str | None) -> None: + raise NotImplementedError + + opts: CallOptions = {"on_progress": noop_progress} + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("a", None) + await client.send_raw_request("b", {"x": 1, "_meta": {"k": "v"}}, opts) + assert seen[0] == {"_meta": {}} + assert seen[1] is not None + assert seen[1]["x"] == 1 + assert seen[1]["_meta"]["k"] == "v" + assert "progressToken" in seen[1]["_meta"] + + @pytest.mark.anyio async def test_handler_raising_validation_error_sends_invalid_params(): async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: From 130e160ba5828539b40f1cbd410873cd0a9626d8 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 13:29:59 +0000 Subject: [PATCH 34/52] fix: ServerRunner error shapes match existing server (METHOD_NOT_FOUND, pre-init) METHOD_NOT_FOUND message is now bare "Method not found" (no method suffix); the interaction suite pins that. The pre-init gate now returns the generic INVALID_PARAMS / "Invalid request parameters" / data="" shape. The existing server has no dedicated pre-init check; the request dies in ClientRequest validation, so clients see this shape. TODO(maxisbey) marked. Also: loosen the gap-8 _meta wire test to be tracer-agnostic (it was order-dependent on the SpanCapture fixture). --- src/mcp/server/runner.py | 12 ++++++------ tests/server/test_runner.py | 7 ++++--- tests/shared/test_jsonrpc_dispatcher.py | 12 ++++++++---- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 662827e592..65563349af 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -34,7 +34,7 @@ from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext from mcp.types import ( - INVALID_REQUEST, + INVALID_PARAMS, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, Implementation, @@ -164,13 +164,13 @@ async def _on_request( if method == "initialize": return self._handle_initialize(params) if not self._initialized and method not in _INIT_EXEMPT: - raise MCPError( - code=INVALID_REQUEST, - message=f"Received {method!r} before initialization was complete", - ) + # TODO(maxisbey): pinned compat. The existing server has no + # dedicated pre-init check; the request dies in ClientRequest + # validation, so the client sees the generic invalid-params shape. + raise MCPError(code=INVALID_PARAMS, message="Invalid request parameters", data="") entry = self.server.get_request_handler(method) if entry is None: - raise MCPError(code=METHOD_NOT_FOUND, message=f"Method not found: {method}") + raise MCPError(code=METHOD_NOT_FOUND, message="Method not found") # ValidationError propagates; the dispatcher's exception boundary maps # it to INVALID_PARAMS. typed_params = entry.params_type.model_validate(params or {}) diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 01afd078a6..990dfee01c 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -24,11 +24,12 @@ from mcp.shared.exceptions import MCPError from mcp.types import ( INTERNAL_ERROR, - INVALID_REQUEST, + INVALID_PARAMS, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, CallToolRequestParams, ClientCapabilities, + ErrorData, Implementation, InitializeRequestParams, ListToolsResult, @@ -142,7 +143,7 @@ async def test_runner_gates_requests_before_initialize(server: SrvT): async with connected_runner(server, initialized=False) as (client, _): with pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", None) - assert exc.value.error.code == INVALID_REQUEST + assert exc.value.error == ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") # ping is exempt from the gate assert await client.send_raw_request("ping", None) == {} @@ -361,7 +362,7 @@ async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, s assert exc.value.error.code == METHOD_NOT_FOUND [span] = spans.finished() assert span.status.status_code == StatusCode.ERROR - assert span.status.description == "Method not found: nonexistent/method" + assert span.status.description == "Method not found" # MCPError is a protocol-level response, not a crash - no traceback event. assert not [e for e in span.events if e.name == "exception"] diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index ec5b50d637..cb0b6513db 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -419,11 +419,15 @@ async def noop_progress(progress: float, total: float | None, message: str | Non with anyio.fail_after(5): await client.send_raw_request("a", None) await client.send_raw_request("b", {"x": 1, "_meta": {"k": "v"}}, opts) - assert seen[0] == {"_meta": {}} - assert seen[1] is not None - assert seen[1]["x"] == 1 + # `_meta` is always present. Its contents depend on the active otel + # tracer (traceparent/tracestate may be injected), so assert presence + # and that anything beyond W3C keys is exactly what we expect. + w3c = {"traceparent", "tracestate"} + assert seen[0] is not None and seen[0].keys() == {"_meta"} + assert set(seen[0]["_meta"].keys()) <= w3c + assert seen[1] is not None and seen[1]["x"] == 1 + assert set(seen[1]["_meta"].keys()) - w3c == {"k", "progressToken"} assert seen[1]["_meta"]["k"] == "v" - assert "progressToken" in seen[1]["_meta"] @pytest.mark.anyio From 9bdd153c997ad19b97ebf342720374e9d4d6f9b1 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 13:42:01 +0000 Subject: [PATCH 35/52] fix: ServerRunner validates spec methods against ClientRequest before lookup Parity with BaseSession._receive_loop: a spec method with malformed params surfaces as INVALID_PARAMS via the dispatcher's ValidationError boundary even when no handler is registered (the existing server validates against the discriminated union before any handler lookup). Gated on the set of spec method names (derived from the ClientRequest union discriminator) so custom methods registered via add_request_handler still route. The existing server rejects those too, but nothing pins that and routing them is strictly better. DirectDispatcher gains the same ValidationError -> INVALID_PARAMS mapping JSONRPCDispatcher has, so runner-over-direct unit tests see the same shape. --- src/mcp/server/runner.py | 25 ++++++++++++++++++++++++- src/mcp/shared/direct_dispatcher.py | 7 ++++++- tests/server/test_runner.py | 26 ++++++++++++++++++++++++-- 3 files changed, 54 insertions(+), 4 deletions(-) diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 65563349af..10b0f60e31 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -19,7 +19,7 @@ from collections.abc import Mapping from dataclasses import dataclass, field from functools import partial, reduce -from typing import Any, Generic, cast +from typing import Any, Generic, cast, get_args import anyio.abc from opentelemetry.trace import SpanKind, StatusCode @@ -37,9 +37,11 @@ INVALID_PARAMS, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, + ClientRequest, Implementation, InitializeRequestParams, InitializeResult, + client_request_adapter, ) __all__ = ["CallNext", "ServerMiddleware", "ServerRunner", "otel_middleware"] @@ -51,6 +53,13 @@ _INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) +_SPEC_CLIENT_METHODS: frozenset[str] = frozenset( + cast(type[BaseModel], arm).model_fields["method"].default for arm in get_args(ClientRequest) +) +"""Method names in the spec `ClientRequest` union, derived from the +discriminator literal on each arm. Used to gate upfront validation so custom +methods registered via `add_request_handler` are not rejected.""" + def otel_middleware(next_on_request: OnRequest) -> OnRequest: """Dispatch-tier middleware that wraps each request in an OpenTelemetry span. @@ -161,6 +170,20 @@ async def _on_request( method: str, params: Mapping[str, Any] | None, ) -> dict[str, Any]: + # TODO(maxisbey): pinned compat. `BaseSession._receive_loop` validates + # every inbound request against the spec `ClientRequest` discriminated + # union *before* handler lookup, so a spec method with malformed params + # surfaces as INVALID_PARAMS via the dispatcher's ValidationError + # boundary even when no handler is registered. v2 wanted to decouple + # the runner from the spec union; revisit once the suite's divergence + # entry is resolved. Gated on spec methods so custom methods registered + # via `add_request_handler` still route (the existing server rejects + # those too, but nothing pins that and routing them is strictly better). + if method in _SPEC_CLIENT_METHODS: + payload: dict[str, Any] = {"method": method} + if params is not None: + payload["params"] = dict(params) + client_request_adapter.validate_python(payload) if method == "initialize": return self._handle_initialize(params) if not self._initialized and method not in _INIT_EXEMPT: diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index 51dddf1e79..f252dfad30 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -21,12 +21,13 @@ import anyio import anyio.abc +from pydantic import ValidationError from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.message import MessageMetadata from mcp.shared.transport_context import TransportContext -from mcp.types import INTERNAL_ERROR, REQUEST_TIMEOUT, RequestId +from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, REQUEST_TIMEOUT, RequestId __all__ = ["DirectDispatcher", "create_direct_dispatcher_pair"] @@ -149,6 +150,10 @@ async def _dispatch_request( return await self._on_request(dctx, method, params) except MCPError: raise + except ValidationError as e: + # Same shape JSONRPCDispatcher writes, so runner-over-direct + # tests see what runner-over-JSONRPC would. + raise MCPError(code=INVALID_PARAMS, message="Invalid request parameters", data="") from e except Exception as e: raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e except TimeoutError: diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 990dfee01c..609c488ffe 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -161,13 +161,32 @@ async def test_runner_routes_to_handler_and_builds_context(server: SrvT): @pytest.mark.anyio -async def test_runner_unknown_method_raises_method_not_found(server: SrvT): +async def test_runner_spec_method_with_no_handler_raises_method_not_found(server: SrvT): + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("resources/list", None) + assert exc.value.error.code == METHOD_NOT_FOUND + + +@pytest.mark.anyio +async def test_runner_non_spec_method_with_no_handler_raises_method_not_found(server: SrvT): + """Upfront validation is gated to spec methods, so a non-spec method + skips it and reaches handler lookup.""" async with connected_runner(server) as (client, _): with pytest.raises(MCPError) as exc: await client.send_raw_request("nonexistent/method", None) assert exc.value.error.code == METHOD_NOT_FOUND +@pytest.mark.anyio +async def test_runner_malformed_params_for_unregistered_spec_method_raises_invalid_params(server: SrvT): + """A spec method with malformed params is INVALID_PARAMS even with no handler.""" + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/call", {"name": 123}) + assert exc.value.error == ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + + @pytest.mark.anyio async def test_runner_on_notify_initialized_sets_flag_and_connection_event(server: SrvT): async with connected_runner(server, initialized=False) as (client, runner): @@ -287,6 +306,9 @@ async def test_runner_stateless_skips_init_gate(server: SrvT): @pytest.mark.anyio async def test_server_add_request_handler_routes_custom_method_with_validated_params(server: SrvT): + """Custom methods outside the spec `ClientRequest` union skip upfront + validation and route to the registered handler.""" + class GreetParams(RequestParams): name: str @@ -358,7 +380,7 @@ async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, s async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): spans.clear() with pytest.raises(MCPError) as exc: - await client.send_raw_request("nonexistent/method", None) + await client.send_raw_request("resources/list", None) assert exc.value.error.code == METHOD_NOT_FOUND [span] = spans.finished() assert span.status.status_code == StatusCode.ERROR From 87e0dbc7e6714a7a5ea63addb92793012216232c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 13:47:30 +0000 Subject: [PATCH 36/52] feat: ServerRunner builds InitializeResult from InitializationOptions ServerRunner gains init_options (defaults to server.create_initialization_options()). _handle_initialize builds the full InitializeResult from it (name/title/description/version/website_url/icons/ instructions) and negotiates requested-if-in-SUPPORTED_PROTOCOL_VERSIONS- else-LATEST, matching ServerSession._received_request. --- src/mcp/server/runner.py | 32 ++++++++++++++++++++++---------- tests/server/test_runner.py | 31 +++++++++++++++++++++++-------- 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 10b0f60e31..5653b03653 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -29,10 +29,12 @@ from mcp.server.connection import Connection from mcp.server.context import CallNext, Context, ServerMiddleware from mcp.server.lowlevel.server import Server +from mcp.server.models import InitializationOptions from mcp.shared._otel import extract_trace_context, otel_span from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext +from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( INVALID_PARAMS, LATEST_PROTOCOL_VERSION, @@ -125,6 +127,8 @@ class ServerRunner(Generic[LifespanT]): dispatcher: Dispatcher[TransportContext] lifespan_state: LifespanT has_standalone_channel: bool + init_options: InitializationOptions | None = None + """`InitializeResult` payload. Defaults to `server.create_initialization_options()`.""" session_id: str | None = None stateless: bool = False dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware]) @@ -134,6 +138,8 @@ class ServerRunner(Generic[LifespanT]): def __post_init__(self) -> None: self._initialized = self.stateless + if self.init_options is None: + self.init_options = self.server.create_initialization_options() self.connection = Connection( self.dispatcher, has_standalone_channel=self.has_standalone_channel, session_id=self.session_id ) @@ -234,18 +240,24 @@ def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any] init = InitializeRequestParams.model_validate(params or {}) self.connection.client_info = init.client_info self.connection.client_capabilities = init.capabilities - # TODO: real version negotiation. This always responds with LATEST, - # which is wrong - the server should pick the highest version both - # sides support and compute a per-connection feature set from it. - # See FOLLOWUPS: "Consolidate per-connection mode/negotiation". - self.connection.protocol_version = ( - init.protocol_version if init.protocol_version in {LATEST_PROTOCOL_VERSION} else LATEST_PROTOCOL_VERSION - ) + requested = init.protocol_version + negotiated = requested if requested in SUPPORTED_PROTOCOL_VERSIONS else LATEST_PROTOCOL_VERSION + self.connection.protocol_version = negotiated self._initialized = True self.connection.initialized.set() + assert self.init_options is not None + opts = self.init_options result = InitializeResult( - protocol_version=self.connection.protocol_version, - capabilities=self.server.capabilities(), - server_info=Implementation(name=self.server.name, version=self.server.version or "0.0.0"), + protocol_version=negotiated, + capabilities=opts.capabilities, + server_info=Implementation( + name=opts.server_name, + title=opts.title, + description=opts.description, + version=opts.server_version, + website_url=opts.website_url, + icons=opts.icons, + ), + instructions=opts.instructions, ) return _dump_result(result) diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 609c488ffe..522a80bb56 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -18,10 +18,12 @@ from mcp.server.connection import Connection from mcp.server.context import Context from mcp.server.lowlevel.server import NotificationOptions, Server +from mcp.server.models import InitializationOptions from mcp.server.runner import ServerRunner, otel_middleware from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair from mcp.shared.dispatcher import DispatchMiddleware from mcp.shared.exceptions import MCPError +from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( INTERNAL_ERROR, INVALID_PARAMS, @@ -77,6 +79,7 @@ async def connected_runner( initialized: bool = True, stateless: bool = False, has_standalone_channel: bool = True, + init_options: InitializationOptions | None = None, session_id: str | None = None, headers: Mapping[str, str] | None = None, dispatch_middleware: list[DispatchMiddleware] | None = None, @@ -94,6 +97,7 @@ async def connected_runner( dispatcher=server_d, lifespan_state={}, has_standalone_channel=has_standalone_channel, + init_options=init_options, session_id=session_id, stateless=stateless, dispatch_middleware=dispatch_middleware or [], @@ -327,20 +331,31 @@ async def greet(ctx: Any, params: GreetParams) -> dict[str, Any]: @pytest.mark.anyio -async def test_server_capabilities_reflects_ctor_options_in_initialize_result(): +async def test_runner_initialize_result_reflects_init_options(): async def list_tools(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: raise NotImplementedError - server: SrvT = Server( - name="caps-test", - on_list_tools=list_tools, - notification_options=NotificationOptions(tools_changed=True), - experimental_capabilities={"ext": {"k": "v"}}, - ) - async with connected_runner(server, initialized=False) as (client, _): + server: SrvT = Server(name="caps-test", on_list_tools=list_tools, instructions="be nice") + init_options = server.create_initialization_options(NotificationOptions(tools_changed=True), {"ext": {"k": "v"}}) + async with connected_runner(server, initialized=False, init_options=init_options) as (client, _): result = await client.send_raw_request("initialize", _initialize_params()) assert result["capabilities"]["tools"]["listChanged"] is True assert result["capabilities"]["experimental"] == {"ext": {"k": "v"}} + assert result["serverInfo"]["name"] == "caps-test" + assert result["instructions"] == "be nice" + + +@pytest.mark.anyio +async def test_runner_initialize_echoes_supported_version_and_falls_back_to_latest(server: SrvT): + oldest = SUPPORTED_PROTOCOL_VERSIONS[0] + async with connected_runner(server, initialized=False) as (client, _): + params = {**_initialize_params(), "protocolVersion": oldest} + result = await client.send_raw_request("initialize", params) + assert result["protocolVersion"] == oldest + async with connected_runner(server, initialized=False) as (client, _): + params = {**_initialize_params(), "protocolVersion": "1999-01-01"} + result = await client.send_raw_request("initialize", params) + assert result["protocolVersion"] == LATEST_PROTOCOL_VERSION @pytest.mark.anyio From 9d121cfeeadbcc78841240bf8e82630ba3eca520 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 13:49:54 +0000 Subject: [PATCH 37/52] fix: ServerRunner passes None to notification handlers when params absent Matches Server._handle_notification: when the wire omits params, the handler receives None, not an empty model. _make_context now accepts typed_params=None. --- src/mcp/server/runner.py | 10 +++++++--- tests/server/test_runner.py | 6 +++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 5653b03653..b8a9353cdd 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -227,13 +227,17 @@ async def _on_notify( if entry is None: logger.debug("no handler for notification %s", method) return - typed_params = entry.params_type.model_validate(params or {}) + # Absent wire params reach the handler as `None`, not an empty model + # (matches the existing `Server._handle_notification`). + typed_params = entry.params_type.model_validate(params) if params is not None else None ctx = self._make_context(dctx, typed_params) # TODO: cast goes away when `ServerRequestContext = Context` lands. await cast(Any, entry.handler)(ctx, typed_params) - def _make_context(self, dctx: DispatchContext[TransportContext], typed_params: BaseModel) -> Context[LifespanT]: - meta = getattr(typed_params, "meta", None) + def _make_context( + self, dctx: DispatchContext[TransportContext], typed_params: BaseModel | None + ) -> Context[LifespanT]: + meta = getattr(typed_params, "meta", None) if typed_params is not None else None return Context(dctx, lifespan=self.lifespan_state, connection=self.connection, meta=meta) def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]: diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 522a80bb56..8b1f8c343e 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -209,10 +209,14 @@ async def on_roots_changed(ctx: Any, params: Any) -> None: server.add_notification_handler("notifications/roots/list_changed", NotificationParams, on_roots_changed) async with connected_runner(server) as (client, _): await client.notify("notifications/roots/list_changed", None) + await client.notify("notifications/roots/list_changed", {}) # DirectDispatcher delivers synchronously; one yield is enough. await anyio.lowlevel.checkpoint() - assert len(seen) == 1 + assert len(seen) == 2 assert isinstance(seen[0][0], Context) + # Absent wire params reach the handler as None; present-but-empty validates. + assert seen[0][1] is None + assert isinstance(seen[1][1], NotificationParams) @pytest.mark.anyio From ca0c67b359f4d75ebc81eb1f3e8bf1c4671a2163 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 14:14:59 +0000 Subject: [PATCH 38/52] fix: JSONRPCDispatcher no-builder __init__ overload accepts all kwargs The first overload (no transport_builder) was missing peer_cancel_mode and raise_handler_exceptions, so callers couldn't pass them without also supplying a builder. Pure typing fix; the impl already handled them. --- src/mcp/shared/jsonrpc_dispatcher.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 27498e9dc0..c2f2c218b7 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -192,6 +192,9 @@ def __init__( self: JSONRPCDispatcher[TransportContext], read_stream: ReadStream[SessionMessage | Exception], write_stream: WriteStream[SessionMessage], + *, + peer_cancel_mode: PeerCancelMode = "interrupt", + raise_handler_exceptions: bool = False, ) -> None: ... @overload def __init__( From 16fbef6b9c17817433238f8c35332a16b775fb9b Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 14:20:00 +0000 Subject: [PATCH 39/52] fix: JSONRPCDispatcher.run cancels in-flight handlers on read-stream EOF Restores the Server.run() behaviour the dispatcher rework dropped: at read-stream EOF the task group cancels in-flight handler tasks instead of joining on them. Without this, a handler that outlives its caller (its request timed out client-side, or the client disconnected mid-call) keeps run() from returning forever, leaking the handler task and over SSE the GET request that hosts the session. Regression test parks a handler in sleep_forever(), EOFs the read stream, asserts run() returns within fail_after(5). Confirmed to hang on the unpatched code. --- src/mcp/shared/jsonrpc_dispatcher.py | 46 +++++++++++++++---------- tests/shared/test_jsonrpc_dispatcher.py | 42 ++++++++++++++++++++++ 2 files changed, 70 insertions(+), 18 deletions(-) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index c2f2c218b7..557b598d25 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -356,24 +356,34 @@ async def run( self._tg = tg self._running = True task_status.started() - async with self._read_stream, self._write_stream: - try: - async for item in self._read_stream: - # Duck-typed: `_context_streams.ContextReceiveStream` - # exposes `.last_context` (the sender's contextvars - # snapshot per message). Plain memory streams don't. - sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None) - self._dispatch(item, on_request, on_notify, sender_ctx) - except anyio.ClosedResourceError: - # The transport closed our receive end and we looped back - # to `__anext__` on the now-closed stream (stateless SHTTP - # teardown). Same as EOF. - logger.debug("read stream closed by transport; treating as EOF") - # Read stream EOF: wake any blocked `send_raw_request` waiters now, - # *before* the task group joins, so handlers parked in - # `dctx.send_raw_request()` can unwind and the join doesn't deadlock. - self._running = False - self._fan_out_closed() + try: + async with self._read_stream, self._write_stream: + try: + async for item in self._read_stream: + # Duck-typed: `_context_streams.ContextReceiveStream` + # exposes `.last_context` (the sender's contextvars + # snapshot per message). Plain memory streams don't. + sender_ctx: contextvars.Context | None = getattr( + self._read_stream, "last_context", None + ) + self._dispatch(item, on_request, on_notify, sender_ctx) + except anyio.ClosedResourceError: + # The transport closed our receive end and we looped + # back to `__anext__` on the now-closed stream + # (stateless SHTTP teardown). Same as EOF. + logger.debug("read stream closed by transport; treating as EOF") + # Read stream EOF: wake any blocked `send_raw_request` waiters + # (callers outside this task group) with CONNECTION_CLOSED. + self._running = False + self._fan_out_closed() + finally: + # Transport closed: cancel in-flight handlers. Without this + # the task-group join waits for them, and a handler that + # outlives its caller (its request timed out client-side, or + # the client disconnected mid-call) would keep `run()` from + # returning forever. Same behaviour as `Server.run()` before + # the dispatcher rework. + tg.cancel_scope.cancel() finally: # Covers the cancel/crash paths where the inline fan-out above is # never reached. Idempotent. diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index cb0b6513db..035f943de1 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -196,6 +196,48 @@ async def test_run_returns_cleanly_when_read_stream_receive_end_is_closed(): s.close() +@pytest.mark.anyio +async def test_run_cancels_in_flight_handlers_when_read_stream_eofs(): + """A handler that outlives its caller must not keep run() from returning. + + Without the cancel-at-EOF, the task-group join would wait on this handler + forever (over SSE that leaks the handler task and the GET request hosting + the session). + """ + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + handler_started = anyio.Event() + handler_cancelled = anyio.Event() + + async def park(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + try: + await anyio.sleep_forever() + finally: + handler_cancelled.set() + raise NotImplementedError + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + run_returned = anyio.Event() + + async def drive() -> None: + await server.run(park, on_notify) + run_returned.set() + + async with anyio.create_task_group() as tg: + tg.start_soon(drive) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="x", params=None))) + with anyio.fail_after(5): + await handler_started.wait() + c2s_send.close() # EOF the read side; run() must cancel the parked handler + await run_returned.wait() + assert handler_cancelled.is_set() + s2c_recv.close() + + @pytest.mark.anyio async def test_run_closes_write_stream_on_exit(): """run() enters both streams; the write end is released on EOF.""" From 94b3ce9accfaa14204f6b7a25aef34ef535e2553 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 14:24:25 +0000 Subject: [PATCH 40/52] chore: drop unused Server.capabilities() helper Added earlier in this branch for ServerRunner._handle_initialize, which now reads from InitializationOptions instead. No callers remain. --- src/mcp/server/lowlevel/server.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index afa2f90109..0292bf8c8d 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -307,10 +307,6 @@ def get_notification_handler(self, method: str) -> HandlerEntry[LifespanResultT] """Return the registered entry for a notification method, or `None`.""" return self._notification_handlers.get(method) - def capabilities(self) -> types.ServerCapabilities: - """Derive `ServerCapabilities` from registered handlers and constructor options.""" - return self.get_capabilities(self._notification_options, self._experimental_capabilities) - # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities From ac5ab572dbddb8c63f8a9f9b62bed93b8f5d765b Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 17:13:23 +0000 Subject: [PATCH 41/52] feat: Server.run drives JSONRPCDispatcher + ServerRunner; ServerSession is a dispatcher proxy The swap. Production server traffic now flows through the dispatcher/runner path; BaseSession is no longer reached on the server side. ServerRequestContext: standalone dataclass (drops RequestContext base, inlines session/request_id/meta). ServerMiddleware retyped to take it; _MwLifespanT no longer contravariant while the ctx is the invariant mutable dataclass (TODO marked). Connection: client_params holds the full InitializeRequestParams; client_info / client_capabilities are read-through properties. ServerSession: rewritten as a connection-scoped proxy over JSONRPCDispatcher + Connection. send_request / send_notification model-dump and forward to dispatcher.send_raw_request / notify, threading related_request_id so SHTTP routing is unchanged. The typed helpers (create_message, elicit_*, send_log_message, send_*_list_changed, list_roots, send_ping, send_progress_notification, send_elicit_complete, check_client_capability) are kept verbatim. Deleted: the BaseSession-derived receive loop, _received_*, incoming_messages, InitializationState, ServerRequestResponder, send_message, the tasks-only _build_* helpers. ServerRunner: dispatcher is JSONRPCDispatcher concretely (the ServerSession shim needs its _related_request_id kwarg). __post_init__ builds the connection-scoped session. _make_context builds ServerRequestContext from dctx.request_id and dctx.message_metadata (the same isinstance(ServerMessageMetadata) narrow the previous Server._handle_request did). _handle_initialize sets connection.client_params. Both cast(Any, entry.handler) and the getattr(typed_params, 'meta', ...) are gone (meta read via isinstance(typed_params, RequestParams)). Server import is under TYPE_CHECKING to break the cycle with lowlevel/server. Server.run(): builds JSONRPCDispatcher(read, write, raise_handler_exceptions=...) and ServerRunner(..., dispatch_middleware=[otel_middleware]) inside the lifespan, then awaits runner.run(). _handle_message / _handle_request / _handle_notification deleted. --- src/mcp/server/connection.py | 16 +- src/mcp/server/context.py | 36 ++- src/mcp/server/lowlevel/server.py | 223 ++-------------- src/mcp/server/runner.py | 52 +++- src/mcp/server/session.py | 415 ++++++------------------------ 5 files changed, 173 insertions(+), 569 deletions(-) diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py index 65267e34ca..1c7ee67412 100644 --- a/src/mcp/server/connection.py +++ b/src/mcp/server/connection.py @@ -24,7 +24,7 @@ from mcp.shared.dispatcher import CallOptions, Outbound from mcp.shared.exceptions import NoBackChannelError from mcp.shared.peer import Meta, dump_params -from mcp.types import ClientCapabilities, Implementation, LoggingLevel +from mcp.types import ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel __all__ = ["Connection"] @@ -51,8 +51,8 @@ def __init__(self, outbound: Outbound, *, has_standalone_channel: bool, session_ self.has_standalone_channel = has_standalone_channel self.session_id: str | None = session_id - self.client_info: Implementation | None = None - self.client_capabilities: ClientCapabilities | None = None + self.client_params: InitializeRequestParams | None = None + """The full `initialize` request params; `None` before initialization.""" self.protocol_version: str | None = None self.initialized: anyio.Event = anyio.Event() @@ -68,6 +68,16 @@ def __init__(self, outbound: Outbound, *, has_standalone_channel: bool, session_ middleware to register per-connection teardown. Unwound LIFO after `dispatcher.run()` returns, shielded from cancellation.""" + @property + def client_info(self) -> Implementation | None: + """The client's `Implementation` from `initialize`; `None` before initialization.""" + return self.client_params.client_info if self.client_params is not None else None + + @property + def client_capabilities(self) -> ClientCapabilities | None: + """The client's `ClientCapabilities` from `initialize`; `None` before initialization.""" + return self.client_params.capabilities if self.client_params is not None else None + async def send_raw_request( self, method: str, diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index fd5c20f264..d6a4aef663 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -9,24 +9,32 @@ from mcp.server._typed_request import TypedServerRequestMixin from mcp.server.connection import Connection -from mcp.server.experimental.request_context import Experimental from mcp.server.session import ServerSession -from mcp.shared._context import RequestContext from mcp.shared.context import BaseContext from mcp.shared.dispatcher import DispatchContext from mcp.shared.message import CloseSSEStreamCallback from mcp.shared.peer import Meta, PeerMixin from mcp.shared.transport_context import TransportContext -from mcp.types import LoggingLevel, RequestParamsMeta +from mcp.types import LoggingLevel, RequestId, RequestParamsMeta LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any]) RequestT = TypeVar("RequestT", default=Any) @dataclass(kw_only=True) -class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContextT, RequestT]): +class ServerRequestContext(Generic[LifespanContextT, RequestT]): + """Per-request context handed to lowlevel request and notification handlers. + + Built by `ServerRunner._make_context` for each inbound message. Carries the + connection-scoped `ServerSession` (server-to-client requests and + notifications), per-request metadata, and any per-message data the + transport attached (the HTTP request, SSE stream-close callbacks). + """ + + session: ServerSession lifespan_context: LifespanContextT - experimental: Experimental + request_id: RequestId | None = None + meta: RequestParamsMeta | None = None request: RequestT | None = None close_sse_stream: CloseSSEStreamCallback | None = None close_standalone_sse_stream: CloseSSEStreamCallback | None = None @@ -107,26 +115,32 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, * CallNext = Callable[[], Awaitable[HandlerResult]] -_MwLifespanT = TypeVar("_MwLifespanT", contravariant=True) +_MwLifespanT = TypeVar("_MwLifespanT") class ServerMiddleware(Protocol[_MwLifespanT]): """Context-tier middleware: `(ctx, method, typed_params, call_next) -> result`. Runs *inside* `ServerRunner._on_request` after params validation and - `Context` construction. Wraps registered handlers (including `ping`) but + context construction. Wraps registered handlers (including `ping`) but not `initialize`, `METHOD_NOT_FOUND`, or validation failures. Listed outermost-first on `Server.middleware`. `Server[L].middleware` holds `ServerMiddleware[L]`, so an app-specific - middleware sees `ctx.lifespan: L`. A reusable middleware can be typed - `ServerMiddleware[object]` - `Context` is covariant in `LifespanT`, so it - registers on any `Server[L]`. + middleware sees `ctx.lifespan_context: L`. While the context is the + mutable `ServerRequestContext` dataclass it is invariant in `L`, so a + reusable middleware should be typed `ServerMiddleware[Any]` to register on + any `Server[L]`. """ + # TODO(maxisbey): once `_make_context` returns the (covariant) `Context[L]` + # again, restore `_MwLifespanT` to `contravariant=True` and retype `ctx` + # below to `Context[_MwLifespanT]` so reusable middleware can be + # `ServerMiddleware[object]` instead of `ServerMiddleware[Any]`. + async def __call__( self, - ctx: Context[_MwLifespanT], + ctx: ServerRequestContext[_MwLifespanT, Any], method: str, params: BaseModel, call_next: CallNext, diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 0292bf8c8d..186915412d 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -36,17 +36,13 @@ async def main(): from __future__ import annotations -import contextvars import logging -import warnings from collections.abc import AsyncIterator, Awaitable, Callable -from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager +from contextlib import AbstractAsyncContextManager, asynccontextmanager from dataclasses import dataclass from importlib.metadata import version as importlib_version -from typing import Any, Generic, cast +from typing import Any, Generic -import anyio -from opentelemetry.trace import SpanKind, StatusCode from pydantic import BaseModel from starlette.applications import Starlette from starlette.middleware import Middleware @@ -61,18 +57,16 @@ async def main(): from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings from mcp.server.context import HandlerResult, ServerMiddleware, ServerRequestContext -from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession +from mcp.server.runner import ServerRunner, otel_middleware from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared._otel import extract_trace_context, otel_span from mcp.shared._stream_protocols import ReadStream, WriteStream -from mcp.shared.exceptions import MCPError -from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.shared.session import RequestResponder +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import SessionMessage +from mcp.shared.transport_context import TransportContext logger = logging.getLogger(__name__) @@ -432,196 +426,23 @@ async def run( # the initialization lifecycle, but can do so with any available node # rather than requiring initialization for each connection. stateless: bool = False, - ): - async with AsyncExitStack() as stack: - lifespan_context = await stack.enter_async_context(self.lifespan(self)) - session = await stack.enter_async_context( - ServerSession( - read_stream, - write_stream, - initialization_options, - stateless=stateless, - ) - ) - - # Configure task support for this session if enabled - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - if task_support is not None: - task_support.configure_session(session) - await stack.enter_async_context(task_support.run()) - - async with anyio.create_task_group() as tg: - try: - async for message in session.incoming_messages: - logger.debug("Received message: %s", message) - - if isinstance(message, RequestResponder) and message.context is not None: - context = message.context - else: - context = contextvars.copy_context() - - context.run( - tg.start_soon, - self._handle_message, - message, - session, - lifespan_context, - raise_exceptions, - ) - finally: - # Transport closed: cancel in-flight handlers. Without this the - # TG join waits for them, and when they eventually try to - # respond they hit a closed write stream (the session's - # _receive_loop closed it when the read stream ended). - tg.cancel_scope.cancel() - - async def _handle_message( - self, - message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception, - session: ServerSession, - lifespan_context: LifespanResultT, - raise_exceptions: bool = False, - ): - with warnings.catch_warnings(record=True) as w: - match message: - case RequestResponder() as responder: - with responder: - await self._handle_request( - message, responder.request, session, lifespan_context, raise_exceptions - ) - case Exception(): - logger.error(f"Received exception from stream: {message}") - if raise_exceptions: - raise message - case _: - await self._handle_notification(message, session, lifespan_context) - - for warning in w: # pragma: lax no cover - logger.info("Warning: %s: %s", warning.category.__name__, warning.message) - - async def _handle_request( - self, - message: RequestResponder[types.ClientRequest, types.ServerResult], - req: types.ClientRequest, - session: ServerSession, - lifespan_context: LifespanResultT, - raise_exceptions: bool, - ): - logger.info("Processing request of type %s", type(req).__name__) - - target = getattr(req.params, "name", None) if req.params else None - span_name = f"MCP handle {req.method} {target}" if target else f"MCP handle {req.method}" - - # Extract W3C trace context from _meta (SEP-414). - meta = cast(dict[str, Any] | None, getattr(req.params, "meta", None)) if req.params else None - parent_context = extract_trace_context(meta) if meta is not None else None - - with otel_span( - span_name, - kind=SpanKind.SERVER, - attributes={"mcp.method.name": req.method, "jsonrpc.request.id": message.request_id}, - context=parent_context, - ) as span: - if entry := self._request_handlers.get(req.method): - handler = entry.handler - logger.debug("Dispatching request of type %s", type(req).__name__) - - try: - # Extract request context and close_sse_stream from message metadata - request_data = None - close_sse_stream_cb = None - close_standalone_sse_stream_cb = None - if message.message_metadata is not None and isinstance( - message.message_metadata, ServerMessageMetadata - ): - request_data = message.message_metadata.request_context - close_sse_stream_cb = message.message_metadata.close_sse_stream - close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream - - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - # Get task metadata from request params if present - task_metadata = None - if hasattr(req, "params") and req.params is not None: # pragma: no branch - task_metadata = getattr(req.params, "task", None) - ctx = ServerRequestContext( - request_id=message.request_id, - meta=message.request_meta, - session=session, - lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=task_metadata, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), - request=request_data, - close_sse_stream=close_sse_stream_cb, - close_standalone_sse_stream=close_standalone_sse_stream_cb, - ) - response = await handler(ctx, req.params) - except MCPError as err: - response = err.error - except anyio.get_cancelled_exc_class(): - if message.cancelled: - # Client sent CancelledNotification; responder.cancel() already - # sent an error response, so skip the duplicate. - logger.info("Request %s cancelled - duplicate response suppressed", message.request_id) - return - # Transport-close cancellation from the TG in run(); re-raise so the - # TG swallows its own cancellation. - raise - except Exception as err: - if raise_exceptions: # pragma: no cover - raise err - response = types.ErrorData(code=0, message=str(err)) - else: - response = types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found") - - if isinstance(response, types.ErrorData) and span is not None: - span.set_status(StatusCode.ERROR, response.message) - - try: - # TODO: cast goes away when `_handle_request` is deleted. - await message.respond(cast(types.ServerResult | types.ErrorData, response)) - except (anyio.BrokenResourceError, anyio.ClosedResourceError): - # Transport closed between handler unblocking and respond. Happens - # when _receive_loop's finally wakes a handler blocked on - # send_request: the handler runs to respond() before run()'s TG - # cancel fires, but after the write stream closed. Closed if our - # end closed (_receive_loop's async-with exit); Broken if the peer - # end closed first (streamable_http terminate()). - logger.debug("Response for %s dropped - transport closed", message.request_id) - return - - logger.debug("Response sent") - - async def _handle_notification( - self, - notify: types.ClientNotification, - session: ServerSession, - lifespan_context: LifespanResultT, ) -> None: - if entry := self._notification_handlers.get(notify.method): - handler = entry.handler - logger.debug("Dispatching notification of type %s", type(notify).__name__) - - try: - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - ctx = ServerRequestContext( - session=session, - lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=None, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), - ) - await handler(ctx, notify.params) - except Exception: # pragma: no cover - logger.exception("Uncaught exception in notification handler") + async with self.lifespan(self) as lifespan_context: + dispatcher: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( + read_stream, + write_stream, + raise_handler_exceptions=raise_exceptions, + ) + runner = ServerRunner( + server=self, + dispatcher=dispatcher, + lifespan_state=lifespan_context, + init_options=initialization_options, + has_standalone_channel=True, + stateless=stateless, + dispatch_middleware=[otel_middleware], + ) + await runner.run() def streamable_http_app( self, diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index b8a9353cdd..2ed0381535 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -19,7 +19,7 @@ from collections.abc import Mapping from dataclasses import dataclass, field from functools import partial, reduce -from typing import Any, Generic, cast, get_args +from typing import TYPE_CHECKING, Any, Generic, cast, get_args import anyio.abc from opentelemetry.trace import SpanKind, StatusCode @@ -27,12 +27,14 @@ from typing_extensions import TypeVar from mcp.server.connection import Connection -from mcp.server.context import CallNext, Context, ServerMiddleware -from mcp.server.lowlevel.server import Server +from mcp.server.context import CallNext, ServerMiddleware, ServerRequestContext from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession from mcp.shared._otel import extract_trace_context, otel_span -from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest +from mcp.shared.dispatcher import DispatchContext, DispatchMiddleware, OnRequest from mcp.shared.exceptions import MCPError +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import ServerMessageMetadata from mcp.shared.transport_context import TransportContext from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( @@ -43,9 +45,13 @@ Implementation, InitializeRequestParams, InitializeResult, + RequestParams, client_request_adapter, ) +if TYPE_CHECKING: + from mcp.server.lowlevel.server import Server + __all__ = ["CallNext", "ServerMiddleware", "ServerRunner", "otel_middleware"] logger = logging.getLogger(__name__) @@ -124,7 +130,7 @@ class ServerRunner(Generic[LifespanT]): """Per-connection orchestrator. One instance per client connection.""" server: Server[LifespanT] - dispatcher: Dispatcher[TransportContext] + dispatcher: JSONRPCDispatcher[Any] lifespan_state: LifespanT has_standalone_channel: bool init_options: InitializationOptions | None = None @@ -134,6 +140,8 @@ class ServerRunner(Generic[LifespanT]): dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware]) connection: Connection = field(init=False) + session: ServerSession = field(init=False) + """Connection-scoped: the same instance reaches every request as `ctx.session`.""" _initialized: bool = field(init=False) def __post_init__(self) -> None: @@ -143,6 +151,7 @@ def __post_init__(self) -> None: self.connection = Connection( self.dispatcher, has_standalone_channel=self.has_standalone_channel, session_id=self.session_id ) + self.session = ServerSession(self.dispatcher, self.connection, stateless=self.stateless) async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: """Drive the dispatcher until the underlying channel closes. @@ -204,8 +213,7 @@ async def _on_request( # it to INVALID_PARAMS. typed_params = entry.params_type.model_validate(params or {}) ctx = self._make_context(dctx, typed_params) - # TODO: cast goes away when `ServerRequestContext = Context` lands. - call: CallNext = partial(cast(Any, entry.handler), ctx, typed_params) + call: CallNext = partial(entry.handler, ctx, typed_params) for mw in reversed(self.server.middleware): call = partial(mw, ctx, method, typed_params, call) return _dump_result(await call()) @@ -231,19 +239,35 @@ async def _on_notify( # (matches the existing `Server._handle_notification`). typed_params = entry.params_type.model_validate(params) if params is not None else None ctx = self._make_context(dctx, typed_params) - # TODO: cast goes away when `ServerRequestContext = Context` lands. - await cast(Any, entry.handler)(ctx, typed_params) + await entry.handler(ctx, typed_params) def _make_context( self, dctx: DispatchContext[TransportContext], typed_params: BaseModel | None - ) -> Context[LifespanT]: - meta = getattr(typed_params, "meta", None) if typed_params is not None else None - return Context(dctx, lifespan=self.lifespan_state, connection=self.connection, meta=meta) + ) -> ServerRequestContext[LifespanT, Any]: + meta = typed_params.meta if isinstance(typed_params, RequestParams) else None + # TODO(maxisbey): remove for Context rework. Reads the SHTTP per-request + # data off the raw `dctx.message_metadata` carrier; replace with the + # per-transport context once that lands. + md = dctx.message_metadata + if isinstance(md, ServerMessageMetadata): + request = md.request_context + close_sse_stream = md.close_sse_stream + close_standalone_sse_stream = md.close_standalone_sse_stream + else: + request = close_sse_stream = close_standalone_sse_stream = None + return ServerRequestContext( + session=self.session, + lifespan_context=self.lifespan_state, + request_id=dctx.request_id, + meta=meta, + request=request, + close_sse_stream=close_sse_stream, + close_standalone_sse_stream=close_standalone_sse_stream, + ) def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]: init = InitializeRequestParams.model_validate(params or {}) - self.connection.client_info = init.client_info - self.connection.client_capabilities = init.capabilities + self.connection.client_params = init requested = init.protocol_version negotiated = requested if requested in SUPPORTED_PROTOCOL_VERSIONS else LATEST_PROTOCOL_VERSION self.connection.protocol_version = negotiated diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index fc2f97a9cb..64e98ec46f 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -1,130 +1,100 @@ -"""ServerSession Module - -This module provides the ServerSession class, which manages communication between the -server and client in the MCP (Model Context Protocol) framework. It is most commonly -used in MCP servers to interact with the client. - -Common usage pattern: -``` - async def handle_call_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: - # Check client capabilities before proceeding - if ctx.session.check_client_capability( - types.ClientCapabilities(experimental={"advanced_tools": dict()}) - ): - result = await perform_advanced_tool_operation(params.arguments) - else: - result = await perform_basic_tool_operation(params.arguments) - return result - - async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: - if ctx.session.client_params: - return ListPromptsResult(prompts=generate_custom_prompts(ctx.session.client_params)) - return ListPromptsResult(prompts=default_prompts) - - server = Server(name, on_call_tool=handle_call_tool, on_list_prompts=handle_list_prompts) -``` - -The ServerSession class is typically used internally by the Server class and should not -be instantiated directly by users of the MCP framework. +"""`ServerSession`: server-to-client requests and notifications. + +A thin proxy over `JSONRPCDispatcher` and `Connection`. One instance per +client connection (built by `ServerRunner`). Handlers reach it as +`ctx.session` and use the typed helpers (`create_message`, `elicit_form`, +`send_log_message`, ...) to call back to the client. + +The receive-loop, initialize handling, and per-request task isolation that +used to live here are now owned by `JSONRPCDispatcher` and `ServerRunner`. """ -from enum import Enum -from typing import Any, TypeVar, overload +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, TypeVar, overload -import anyio -import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream -from pydantic import AnyUrl, TypeAdapter +from pydantic import AnyUrl, BaseModel from mcp import types -from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures -from mcp.server.models import InitializationOptions from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages -from mcp.shared._stream_protocols import ReadStream, WriteStream +from mcp.shared.dispatcher import CallOptions, ProgressFnT from mcp.shared.exceptions import StatelessModeNotSupported -from mcp.shared.experimental.tasks.capabilities import check_tasks_capability -from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY -from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.shared.session import ( - BaseSession, - RequestResponder, -) -from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS - - -class InitializationState(Enum): - NotInitialized = 1 - Initializing = 2 - Initialized = 3 - - -ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") - -ServerRequestResponder = ( - RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception -) - - -class ServerSession( - BaseSession[ - types.ServerRequest, - types.ServerNotification, - types.ServerResult, - types.ClientRequest, - types.ClientNotification, - ] -): - _initialized: InitializationState = InitializationState.NotInitialized - _client_params: types.InitializeRequestParams | None = None - _experimental_features: ExperimentalServerSessionFeatures | None = None +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import MessageMetadata, ServerMessageMetadata + +if TYPE_CHECKING: + from mcp.server.connection import Connection + +__all__ = ["ServerSession"] + +ResultT = TypeVar("ResultT", bound=BaseModel) + + +class ServerSession: + """Connection-scoped proxy for server-to-client requests and notifications. + + `send_request` / `send_notification` model-dump their argument and forward + to the dispatcher; the typed helpers below are unchanged from the previous + implementation and only call those two methods. + """ def __init__( self, - read_stream: ReadStream[SessionMessage | Exception], - write_stream: WriteStream[SessionMessage], - init_options: InitializationOptions, + dispatcher: JSONRPCDispatcher[Any], + connection: Connection, + *, stateless: bool = False, ) -> None: - super().__init__(read_stream, write_stream) + self._dispatcher = dispatcher + self._connection = connection self._stateless = stateless - self._initialization_state = ( - InitializationState.Initialized if stateless else InitializationState.NotInitialized - ) - - self._init_options = init_options - self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[ - ServerRequestResponder - ](0) - self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose()) - - @property - def _receive_request_adapter(self) -> TypeAdapter[types.ClientRequest]: - return types.client_request_adapter - - @property - def _receive_notification_adapter(self) -> TypeAdapter[types.ClientNotification]: - return types.client_notification_adapter @property def client_params(self) -> types.InitializeRequestParams | None: - return self._client_params - - @property - def experimental(self) -> ExperimentalServerSessionFeatures: - """Experimental APIs for server→client task operations. + """The client's `initialize` request params; `None` before initialization.""" + return self._connection.client_params - WARNING: These APIs are experimental and may change without notice. + async def send_request( + self, + request: types.ServerRequest, + result_type: type[ResultT], + request_read_timeout_seconds: float | None = None, + metadata: MessageMetadata = None, + progress_callback: ProgressFnT | None = None, + ) -> ResultT: + """Send a typed server-to-client request and validate the result. + + `metadata.related_request_id` (when supplied) routes the outgoing + message onto the originating request's response stream over + streamable HTTP. """ - if self._experimental_features is None: - self._experimental_features = ExperimentalServerSessionFeatures(self) - return self._experimental_features + data = request.model_dump(by_alias=True, mode="json", exclude_none=True) + opts: CallOptions = {} + if request_read_timeout_seconds is not None: + opts["timeout"] = request_read_timeout_seconds + if progress_callback is not None: + opts["on_progress"] = progress_callback + related = metadata.related_request_id if isinstance(metadata, ServerMessageMetadata) else None + result = await self._dispatcher.send_raw_request( + data["method"], data.get("params"), opts or None, _related_request_id=related + ) + return result_type.model_validate(result) + + async def send_notification( + self, + notification: types.ServerNotification, + related_request_id: types.RequestId | None = None, + ) -> None: + """Send a typed server-to-client notification.""" + data = notification.model_dump(by_alias=True, mode="json", exclude_none=True) + await self._dispatcher.notify(data["method"], data.get("params"), _related_request_id=related_request_id) def check_client_capability(self, capability: types.ClientCapabilities) -> bool: """Check if the client supports a specific capability.""" - if self._client_params is None: # pragma: lax no cover + if self.client_params is None: # pragma: lax no cover return False - client_caps = self._client_params.capabilities + client_caps = self.client_params.capabilities if capability.roots is not None: # pragma: lax no cover if client_caps.roots is None: @@ -150,60 +120,8 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value: return False - if capability.tasks is not None: # pragma: lax no cover - if client_caps.tasks is None: - return False - if not check_tasks_capability(capability.tasks, client_caps.tasks): - return False - return True - async def _receive_loop(self) -> None: - async with self._incoming_message_stream_writer: - await super()._receive_loop() - - async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]): - match responder.request: - case types.InitializeRequest(params=params): - requested_version = params.protocol_version - self._initialization_state = InitializationState.Initializing - self._client_params = params - with responder: - await responder.respond( - types.InitializeResult( - protocol_version=requested_version - if requested_version in SUPPORTED_PROTOCOL_VERSIONS - else types.LATEST_PROTOCOL_VERSION, - capabilities=self._init_options.capabilities, - server_info=types.Implementation( - name=self._init_options.server_name, - title=self._init_options.title, - description=self._init_options.description, - version=self._init_options.server_version, - website_url=self._init_options.website_url, - icons=self._init_options.icons, - ), - instructions=self._init_options.instructions, - ) - ) - self._initialization_state = InitializationState.Initialized - case types.PingRequest(): - # Ping requests are allowed at any time - pass - case _: - if self._initialization_state != InitializationState.Initialized: - raise RuntimeError("Received request before initialization was complete") - - async def _received_notification(self, notification: types.ClientNotification) -> None: - # Need this to avoid ASYNC910 - await anyio.lowlevel.checkpoint() - match notification: - case types.InitializedNotification(): - self._initialization_state = InitializationState.Initialized - case _: - if self._initialization_state != InitializationState.Initialized: # pragma: no cover - raise RuntimeError("Received notification before initialization was complete") - async def send_log_message( self, level: types.LoggingLevel, @@ -313,7 +231,7 @@ async def create_message( """ if self._stateless: raise StatelessModeNotSupported(method="sampling") - client_caps = self._client_params.capabilities if self._client_params else None + client_caps = self.client_params.capabilities if self.client_params else None validate_sampling_tools(client_caps, tools, tool_choice) validate_tool_use_result_messages(messages) @@ -333,7 +251,6 @@ async def create_message( ) metadata_obj = ServerMessageMetadata(related_request_id=related_request_id) - # Use different result types based on whether tools are provided if tools is not None: return await self.send_request( request=request, @@ -508,185 +425,3 @@ async def send_elicit_complete( ), related_request_id, ) - - def _build_elicit_form_request( - self, - message: str, - requested_schema: types.ElicitRequestedSchema, - related_task_id: str | None = None, - task: types.TaskMetadata | None = None, - ) -> types.JSONRPCRequest: - """Build a form mode elicitation request without sending it. - - Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata - task: If provided, makes this a task-augmented request - - Returns: - A JSONRPCRequest ready to be sent or queued - """ - params = types.ElicitRequestFormParams( - message=message, - requested_schema=requested_schema, - task=task, - ) - params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) - - # Add related-task metadata if associated with a parent task - if related_task_id is not None: - # Defensive: model_dump() never includes _meta, but guard against future changes - if "_meta" not in params_data: # pragma: no branch - params_data["_meta"] = {} - params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( - task_id=related_task_id - ).model_dump(by_alias=True) - - request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id - if related_task_id is None: - self._request_id += 1 - - return types.JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - method="elicitation/create", - params=params_data, - ) - - def _build_elicit_url_request( - self, - message: str, - url: str, - elicitation_id: str, - related_task_id: str | None = None, - ) -> types.JSONRPCRequest: - """Build a URL mode elicitation request without sending it. - - Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata - - Returns: - A JSONRPCRequest ready to be sent or queued - """ - params = types.ElicitRequestURLParams( - message=message, - url=url, - elicitation_id=elicitation_id, - ) - params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) - - # Add related-task metadata if associated with a parent task - if related_task_id is not None: - # Defensive: model_dump() never includes _meta, but guard against future changes - if "_meta" not in params_data: # pragma: no branch - params_data["_meta"] = {} - params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( - task_id=related_task_id - ).model_dump(by_alias=True) - - request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id - if related_task_id is None: - self._request_id += 1 - - return types.JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - method="elicitation/create", - params=params_data, - ) - - def _build_create_message_request( - self, - messages: list[types.SamplingMessage], - *, - max_tokens: int, - system_prompt: str | None = None, - include_context: types.IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: types.ModelPreferences | None = None, - tools: list[types.Tool] | None = None, - tool_choice: types.ToolChoice | None = None, - related_task_id: str | None = None, - task: types.TaskMetadata | None = None, - ) -> types.JSONRPCRequest: - """Build a sampling/createMessage request without sending it. - - Args: - messages: The conversation messages to send - max_tokens: Maximum number of tokens to generate - system_prompt: Optional system prompt - include_context: Optional context inclusion setting - temperature: Optional sampling temperature - stop_sequences: Optional stop sequences - metadata: Optional metadata to pass through to the LLM provider - model_preferences: Optional model selection preferences - tools: Optional list of tools the LLM can use during sampling - tool_choice: Optional control over tool usage behavior - related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata - task: If provided, makes this a task-augmented request - - Returns: - A JSONRPCRequest ready to be sent or queued - """ - params = types.CreateMessageRequestParams( - messages=messages, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - max_tokens=max_tokens, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - task=task, - ) - params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) - - # Add related-task metadata if associated with a parent task - if related_task_id is not None: - # Defensive: model_dump() never includes _meta, but guard against future changes - if "_meta" not in params_data: # pragma: no branch - params_data["_meta"] = {} - params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( - task_id=related_task_id - ).model_dump(by_alias=True) - - request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id - if related_task_id is None: - self._request_id += 1 - - return types.JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - method="sampling/createMessage", - params=params_data, - ) - - async def send_message(self, message: SessionMessage) -> None: - """Send a raw session message. - - This is primarily used by TaskResultHandler to deliver queued messages - (elicitation/sampling requests) to the client during task execution. - - WARNING: This is a low-level experimental method that may change without - notice. Prefer using higher-level methods like send_notification() or - send_request() for normal operations. - - Args: - message: The session message to send - """ - await self._write_stream.send(message) - - async def _handle_incoming(self, req: ServerRequestResponder) -> None: - await self._incoming_message_stream_writer.send(req) - - @property - def incoming_messages(self) -> MemoryObjectReceiveStream[ServerRequestResponder]: - return self._incoming_message_stream_reader From c1c851f910532ec491a05b1d3b63bef142e745b0 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 18:12:26 +0000 Subject: [PATCH 42/52] test: rebaseline legacy unit tests for the dispatcher swap; rewrite __main__ server/__main__.py: use Server.run() instead of manual ServerSession + incoming_messages. 49 -> 24 lines. tests/server/test_runner.py: connected_runner now drives a JSONRPCDispatcher pair (was DirectDispatcher); ctx is ServerRequestContext; dropped 6 tests of dormant Context features that return in the Context rework. tests/server/test_connection.py: set client_params instead of the now read-only client_info/client_capabilities properties. tests/server/test_stateless_mode.py: ServerSession(dispatcher, connection) fixtures. tests/conftest.py: capfire override resets _otel._tracer to NoOpTracer on teardown so tests after a span-capture test don't see traceparent injected into _meta (pre-existing order-dep, surfaced once both code paths inject). Deleted (covered by tests/interaction/ or test_runner.py): test_session.py, test_session_race_condition.py, test_progress_notifications.py, test_malformed_input.py. test_lowlevel_exception_handling.py: dropped 3 tests of the deleted _handle_message; kept the real-stream Server.run() regression test. --- src/mcp/server/__main__.py | 33 +- tests/conftest.py | 27 + tests/issues/test_malformed_input.py | 151 ----- tests/server/conftest.py | 7 +- tests/server/test_connection.py | 27 +- .../test_lowlevel_exception_handling.py | 82 +-- tests/server/test_runner.py | 199 ++----- tests/server/test_session.py | 535 ------------------ tests/server/test_session_race_condition.py | 132 ----- tests/server/test_stateless_mode.py | 67 +-- tests/shared/test_progress_notifications.py | 264 --------- 11 files changed, 147 insertions(+), 1377 deletions(-) delete mode 100644 tests/issues/test_malformed_input.py delete mode 100644 tests/server/test_session.py delete mode 100644 tests/server/test_session_race_condition.py delete mode 100644 tests/shared/test_progress_notifications.py diff --git a/src/mcp/server/__main__.py b/src/mcp/server/__main__.py index dbc50b8a79..4305b87e22 100644 --- a/src/mcp/server/__main__.py +++ b/src/mcp/server/__main__.py @@ -1,14 +1,11 @@ -import importlib.metadata import logging import sys import warnings import anyio -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession +from mcp.server.lowlevel.server import Server from mcp.server.stdio import stdio_server -from mcp.types import ServerCapabilities if not sys.warnoptions: warnings.simplefilter("ignore") @@ -17,32 +14,10 @@ logger = logging.getLogger("server") -async def receive_loop(session: ServerSession): - logger.info("Starting receive loop") - async for message in session.incoming_messages: - if isinstance(message, Exception): - logger.error("Error: %s", message) - continue - - logger.info("Received message from client: %s", message) - - -async def main(): - version = importlib.metadata.version("mcp") +async def main() -> None: + server: Server[dict[str, object]] = Server("mcp") async with stdio_server() as (read_stream, write_stream): - async with ( - ServerSession( - read_stream, - write_stream, - InitializationOptions( - server_name="mcp", - server_version=version, - capabilities=ServerCapabilities(), - ), - ) as session, - write_stream, - ): - await receive_loop(session) + await server.run(read_stream, write_stream, server.create_initialization_options()) if __name__ == "__main__": diff --git a/tests/conftest.py b/tests/conftest.py index b83c472135..2278c9939e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +from collections.abc import Iterator import pytest @@ -11,7 +12,33 @@ # suite-wide. Set before logfire is imported anywhere. os.environ.setdefault("LOGFIRE_DISTRIBUTED_TRACING", "true") +import opentelemetry.trace # noqa: E402 (env var must be set before logfire import below) +from logfire.testing import CaptureLogfire # noqa: E402 + +import mcp.shared._otel # noqa: E402 + @pytest.fixture def anyio_backend(): return "asyncio" + + +@pytest.fixture(name="capfire") +def _capfire_isolated(capfire: CaptureLogfire) -> Iterator[CaptureLogfire]: + """Override of logfire's `capfire` that scopes the MCP tracer to the test. + + `capfire` installs a real tracer provider, and logfire's proxy machinery + mutates the cached `mcp.shared._otel._tracer` to delegate to it for the + rest of the process. Without isolation, every subsequent test in the same + worker would emit real spans, and `send_raw_request` would inject a real + `traceparent` into outbound `_meta`, breaking the interaction-suite + snapshots that pin `_meta={}` under a no-op tracer. + + Setup points `_tracer` at the now-live provider so MCP spans record; + teardown replaces it with a `NoOpTracer`. + """ + mcp.shared._otel._tracer = opentelemetry.trace.get_tracer_provider().get_tracer("mcp-python-sdk") + try: + yield capfire + finally: + mcp.shared._otel._tracer = opentelemetry.trace.NoOpTracer() diff --git a/tests/issues/test_malformed_input.py b/tests/issues/test_malformed_input.py deleted file mode 100644 index da586f3098..0000000000 --- a/tests/issues/test_malformed_input.py +++ /dev/null @@ -1,151 +0,0 @@ -# Claude Debug -"""Test for HackerOne vulnerability report #3156202 - malformed input DOS.""" - -import anyio -import pytest - -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.types import INVALID_PARAMS, JSONRPCError, JSONRPCMessage, JSONRPCRequest, ServerCapabilities - - -@pytest.mark.anyio -async def test_malformed_initialize_request_does_not_crash_server(): - """Test that malformed initialize requests return proper error responses - instead of crashing the server (HackerOne #3156202). - """ - # Create in-memory streams for testing - read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) - write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](10) - - try: - # Create a malformed initialize request (missing required params field) - malformed_request = JSONRPCRequest( - jsonrpc="2.0", - id="f20fe86132ed4cd197f89a7134de5685", - method="initialize", - # params=None # Missing required params field - ) - - # Wrap in session message - request_message = SessionMessage(message=malformed_request) - - # Start a server session - async with ServerSession( - read_stream=read_receive_stream, - write_stream=write_send_stream, - init_options=InitializationOptions( - server_name="test_server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ): - # Send the malformed request - await read_send_stream.send(request_message) - - # Give the session time to process the request - await anyio.sleep(0.1) - - # Check that we received an error response instead of a crash - try: - response_message = write_receive_stream.receive_nowait() - response = response_message.message - - # Verify it's a proper JSON-RPC error response - assert isinstance(response, JSONRPCError) - assert response.jsonrpc == "2.0" - assert response.id == "f20fe86132ed4cd197f89a7134de5685" - assert response.error.code == INVALID_PARAMS - assert "Invalid request parameters" in response.error.message - - # Verify the session is still alive and can handle more requests - # Send another malformed request to confirm server stability - another_malformed_request = JSONRPCRequest( - jsonrpc="2.0", - id="test_id_2", - method="tools/call", - # params=None # Missing required params - ) - another_request_message = SessionMessage(message=another_malformed_request) - - await read_send_stream.send(another_request_message) - await anyio.sleep(0.1) - - # Should get another error response, not a crash - second_response_message = write_receive_stream.receive_nowait() - second_response = second_response_message.message - - assert isinstance(second_response, JSONRPCError) - assert second_response.id == "test_id_2" - assert second_response.error.code == INVALID_PARAMS - - except anyio.WouldBlock: # pragma: no cover - pytest.fail("No response received - server likely crashed") - finally: # pragma: lax no cover - # Close all streams to ensure proper cleanup - await read_send_stream.aclose() - await write_send_stream.aclose() - await read_receive_stream.aclose() - await write_receive_stream.aclose() - - -@pytest.mark.anyio -async def test_multiple_concurrent_malformed_requests(): - """Test that multiple concurrent malformed requests don't crash the server.""" - # Create in-memory streams for testing - read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](100) - write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](100) - - try: - # Start a server session - async with ServerSession( - read_stream=read_receive_stream, - write_stream=write_send_stream, - init_options=InitializationOptions( - server_name="test_server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ): - # Send multiple malformed requests concurrently - malformed_requests: list[SessionMessage] = [] - for i in range(10): - malformed_request = JSONRPCRequest( - jsonrpc="2.0", - id=f"malformed_{i}", - method="initialize", - # params=None # Missing required params - ) - request_message = SessionMessage(message=malformed_request) - malformed_requests.append(request_message) - - # Send all requests - for request in malformed_requests: - await read_send_stream.send(request) - - # Give time to process - await anyio.sleep(0.2) - - # Verify we get error responses for all requests - error_responses: list[JSONRPCMessage] = [] - try: - while True: - response_message = write_receive_stream.receive_nowait() - error_responses.append(response_message.message) - except anyio.WouldBlock: - pass # No more messages - - # Should have received 10 error responses - assert len(error_responses) == 10 - - for i, response in enumerate(error_responses): - assert isinstance(response, JSONRPCError) - assert response.id == f"malformed_{i}" - assert response.error.code == INVALID_PARAMS - finally: # pragma: lax no cover - # Close all streams to ensure proper cleanup - await read_send_stream.aclose() - await write_send_stream.aclose() - await read_receive_stream.aclose() - await write_receive_stream.aclose() diff --git a/tests/server/conftest.py b/tests/server/conftest.py index d70dda6526..9114f0348d 100644 --- a/tests/server/conftest.py +++ b/tests/server/conftest.py @@ -36,10 +36,13 @@ def finished(self) -> list[ReadableSpan]: def spans(capfire: CaptureLogfire) -> Iterator[SpanCapture]: """In-memory MCP span capture, cleared before and after each test. - Backed by the project-level `capfire` override (see `tests/conftest.py`) - so there is a single global tracer provider for the suite. + Backed by the project-level `capfire` override (see `tests/conftest.py`), + which scopes `mcp.shared._otel._tracer` to the test so the real tracer + doesn't leak into later tests in the same worker. """ capture = SpanCapture(capfire.exporter) capture.clear() yield capture capture.clear() + + diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py index be588f7ff7..7d1c5dc04a 100644 --- a/tests/server/test_connection.py +++ b/tests/server/test_connection.py @@ -17,9 +17,12 @@ from mcp.shared.dispatcher import CallOptions from mcp.shared.exceptions import NoBackChannelError from mcp.types import ( + LATEST_PROTOCOL_VERSION, ClientCapabilities, ElicitationCapability, EmptyResult, + Implementation, + InitializeRequestParams, ListRootsRequest, ListRootsResult, PingRequest, @@ -28,6 +31,14 @@ ) +def _client_params(capabilities: ClientCapabilities) -> InitializeRequestParams: + return InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=capabilities, + client_info=Implementation(name="t", version="0"), + ) + + class StubOutbound: def __init__( self, *, result: dict[str, Any] | None = None, raise_on_send: type[BaseException] | None = None @@ -190,14 +201,24 @@ def test_connection_check_capability_false_before_initialized(): ) def test_check_capability_per_field_branches(have: ClientCapabilities, want: ClientCapabilities, expected: bool): conn = Connection(StubOutbound(), has_standalone_channel=True) - conn.client_capabilities = have + conn.client_params = _client_params(have) assert conn.check_capability(want) is expected +def test_connection_client_info_and_capabilities_derive_from_client_params(): + conn = Connection(StubOutbound(), has_standalone_channel=True) + assert conn.client_info is None + assert conn.client_capabilities is None + caps = ClientCapabilities(sampling=SamplingCapability()) + conn.client_params = _client_params(caps) + assert conn.client_info is not None and conn.client_info.name == "t" + assert conn.client_capabilities == caps + + def test_connection_check_capability_true_when_client_declares_it(): conn = Connection(StubOutbound(), has_standalone_channel=True) - conn.client_capabilities = ClientCapabilities( - sampling=SamplingCapability(), roots=RootsCapability(list_changed=True) + conn.client_params = _client_params( + ClientCapabilities(sampling=SamplingCapability(), roots=RootsCapability(list_changed=True)) ) conn.initialized.set() assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is True diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py index 46925916d9..015a5cbafa 100644 --- a/tests/server/test_lowlevel_exception_handling.py +++ b/tests/server/test_lowlevel_exception_handling.py @@ -1,64 +1,8 @@ -from unittest.mock import AsyncMock, Mock - import anyio import pytest -from mcp import types from mcp.server.lowlevel.server import Server -from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder - - -@pytest.mark.anyio -async def test_exception_handling_with_raise_exceptions_true(): - """Transport exceptions are re-raised when raise_exceptions=True.""" - server = Server("test-server") - session = Mock(spec=ServerSession) - - test_exception = RuntimeError("Test error") - - with pytest.raises(RuntimeError, match="Test error"): - await server._handle_message(test_exception, session, {}, raise_exceptions=True) - - -@pytest.mark.anyio -async def test_exception_handling_with_raise_exceptions_false(): - """Transport exceptions are logged locally but not sent to the client. - - The transport that reported the error is likely broken; writing back - through it races with stream closure (#1967, #2064). The TypeScript, - Go, and C# SDKs all log locally only. - """ - server = Server("test-server") - session = Mock(spec=ServerSession) - session.send_log_message = AsyncMock() - - await server._handle_message(RuntimeError("Test error"), session, {}, raise_exceptions=False) - - session.send_log_message.assert_not_called() - - -@pytest.mark.anyio -async def test_normal_message_handling_not_affected(): - """Test that normal messages still work correctly""" - server = Server("test-server") - session = Mock(spec=ServerSession) - - # Create a mock RequestResponder - responder = Mock(spec=RequestResponder) - responder.request = types.PingRequest(method="ping") - responder.__enter__ = Mock(return_value=responder) - responder.__exit__ = Mock(return_value=None) - - # Mock the _handle_request method to avoid complex setup - server._handle_request = AsyncMock() - - # Should handle normally without any exception handling - await server._handle_message(responder, session, {}, raise_exceptions=False) - - # Verify _handle_request was called - server._handle_request.assert_called_once() @pytest.mark.anyio @@ -71,23 +15,21 @@ async def test_server_run_exits_cleanly_when_transport_yields_exception_then_clo 1. Transport yields an Exception into the read stream (streamable_http.py does this in its broad POST-handler except). 2. Transport closes the read stream (terminate() in stateless mode). - 3. _receive_loop exits its `async with read_stream, write_stream:` block, - closing the write stream. - 4. Meanwhile _handle_message(exc) was spawned via tg.start_soon and runs - after the write stream is closed. + 3. The read loop exits and closes the write stream. - Before the fix, _handle_message tried to send_log_message through the - closed write stream, raising ClosedResourceError inside the TaskGroup - and crashing server.run(). After the fix, it only logs locally. + Before the fix, the message handler tried to send_log_message through the + closed write stream, raising ClosedResourceError and crashing server.run(). + After the fix (and now in the dispatcher), the exception is only logged + locally. """ server = Server("test-server") read_send, read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) # Zero-buffer on the write stream forces send() to block until received. - # With no receiver, a send() sits blocked until _receive_loop exits its - # `async with self._read_stream, self._write_stream:` block and closes the - # stream, at which point the blocked send raises ClosedResourceError. - # This deterministically reproduces the race without sleeps. + # With no receiver, a send() sits blocked until the read loop exits its + # `async with read_stream, write_stream:` block and closes the stream, at + # which point the blocked send raises ClosedResourceError. This + # deterministically reproduces the race without sleeps. write_send, write_recv = anyio.create_memory_object_stream[SessionMessage](0) # What the streamable HTTP transport does: push the exception, then close. @@ -96,11 +38,11 @@ async def test_server_run_exits_cleanly_when_transport_yields_exception_then_clo with anyio.fail_after(5): # stateless=True so server.run doesn't wait for initialize handshake. - # Before this fix, this raised ExceptionGroup(ClosedResourceError). + # Before the fix, this raised ExceptionGroup(ClosedResourceError). await server.run(read_recv, write_send, server.create_initialization_options(), stateless=True) - # write_send was closed inside _receive_loop's `async with`; receive_nowait - # raises EndOfStream iff the buffer is empty (i.e., server wrote nothing). + # write_send was closed inside run's `async with`; receive_nowait raises + # EndOfStream iff the buffer is empty (i.e., server wrote nothing). with pytest.raises(anyio.EndOfStream): write_recv.receive_nowait() write_recv.close() diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 8b1f8c343e..2199673673 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -1,31 +1,30 @@ """Tests for `ServerRunner`. -End-to-end over `DirectDispatcher` with a real lowlevel `Server` as the +End-to-end over `JSONRPCDispatcher` with a real lowlevel `Server` as the registry. The `connected_runner` helper starts both sides and (by default) performs the initialize handshake, so each test exercises only the behaviour under test. """ -from collections.abc import AsyncIterator, Mapping -from contextlib import AbstractAsyncContextManager, asynccontextmanager +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from typing import Any, cast import anyio -import anyio.lowlevel import pytest from opentelemetry.trace import SpanKind, StatusCode -from mcp.server.connection import Connection -from mcp.server.context import Context +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.models import InitializationOptions from mcp.server.runner import ServerRunner, otel_middleware -from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair +from mcp.server.session import ServerSession from mcp.shared.dispatcher import DispatchMiddleware from mcp.shared.exceptions import MCPError +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.transport_context import TransportContext from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( - INTERNAL_ERROR, INVALID_PARAMS, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, @@ -42,9 +41,12 @@ Tool, ) +from ..shared.conftest import jsonrpc_pair from ..shared.test_dispatcher import Recorder, echo_handlers from .conftest import SpanCapture +Ctx = ServerRequestContext[dict[str, Any], Any] + def _initialize_params() -> dict[str, Any]: return InitializeRequestParams( @@ -54,7 +56,7 @@ def _initialize_params() -> dict[str, Any]: ).model_dump(by_alias=True, exclude_none=True) -_seen_ctx: list[Context[Any]] = [] +_seen_ctx: list[Ctx] = [] SrvT = Server[dict[str, Any]] @@ -63,9 +65,7 @@ def server() -> SrvT: """A lowlevel Server with one tools/list handler registered.""" _seen_ctx.clear() - async def list_tools(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: - # ctx is `Any` while `on_*` kwargs are typed against `ServerRequestContext` - # but `ServerRunner` passes the new `Context`; tightens once the alias lands. + async def list_tools(ctx: Ctx, params: PaginatedRequestParams | None) -> ListToolsResult: _seen_ctx.append(ctx) return ListToolsResult(tools=[Tool(name="t", input_schema={"type": "object"})]) @@ -81,17 +81,17 @@ async def connected_runner( has_standalone_channel: bool = True, init_options: InitializationOptions | None = None, session_id: str | None = None, - headers: Mapping[str, str] | None = None, dispatch_middleware: list[DispatchMiddleware] | None = None, -) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[dict[str, Any]]]]: - """Yield `(client, runner)` running over an in-memory dispatcher pair. +) -> AsyncIterator[tuple[JSONRPCDispatcher[TransportContext], ServerRunner[dict[str, Any]]]]: + """Yield `(client, runner)` running over an in-memory JSON-RPC dispatcher pair. Starts the client (echo handlers) and `runner.run()` in a task group, wraps the body in `anyio.fail_after(5)`, and cancels on exit. When `initialized` is true the helper performs the real `initialize` request before yielding, so tests start past the init-gate via the public path. """ - client, server_d = create_direct_dispatcher_pair(headers=headers) + client, server_d, close = jsonrpc_pair() + assert isinstance(client, JSONRPCDispatcher) and isinstance(server_d, JSONRPCDispatcher) runner = ServerRunner( server=server, dispatcher=server_d, @@ -116,8 +116,7 @@ async def connected_runner( # Capture and re-raise outside the task group so test failures # surface as the original exception, not an ExceptionGroup wrapper. body_exc = e - client.close() - server_d.close() + close() if body_exc is not None: raise body_exc @@ -154,14 +153,15 @@ async def test_runner_gates_requests_before_initialize(server: SrvT): @pytest.mark.anyio async def test_runner_routes_to_handler_and_builds_context(server: SrvT): - async with connected_runner(server) as (client, _): + async with connected_runner(server) as (client, runner): result = await client.send_raw_request("tools/list", None) assert result["tools"][0]["name"] == "t" ctx = _seen_ctx[0] - assert isinstance(ctx, Context) - assert ctx.lifespan == {} - assert isinstance(ctx.connection, Connection) - assert ctx.transport.kind == "direct" + assert isinstance(ctx, ServerRequestContext) + assert ctx.lifespan_context == {} + assert isinstance(ctx.session, ServerSession) + assert ctx.session is runner.session + assert ctx.request_id is not None @pytest.mark.anyio @@ -202,18 +202,19 @@ async def test_runner_on_notify_initialized_sets_flag_and_connection_event(serve @pytest.mark.anyio async def test_runner_on_notify_routes_to_registered_handler(server: SrvT): seen: list[tuple[Any, Any]] = [] + delivered = anyio.Event() - async def on_roots_changed(ctx: Any, params: Any) -> None: + async def on_roots_changed(ctx: Ctx, params: NotificationParams | None) -> None: seen.append((ctx, params)) + if len(seen) == 2: + delivered.set() server.add_notification_handler("notifications/roots/list_changed", NotificationParams, on_roots_changed) async with connected_runner(server) as (client, _): await client.notify("notifications/roots/list_changed", None) await client.notify("notifications/roots/list_changed", {}) - # DirectDispatcher delivers synchronously; one yield is enough. - await anyio.lowlevel.checkpoint() - assert len(seen) == 2 - assert isinstance(seen[0][0], Context) + await delivered.wait() + assert isinstance(seen[0][0], ServerRequestContext) # Absent wire params reach the handler as None; present-but-empty validates. assert seen[0][1] is None assert isinstance(seen[1][1], NotificationParams) @@ -248,7 +249,7 @@ async def wrapped(dctx: Any, method: str, params: Any) -> Any: async def test_runner_server_middleware_wraps_handlers_but_not_initialize(server: SrvT): seen_methods: list[str] = [] - async def ctx_mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: + async def ctx_mw(ctx: Ctx, method: str, params: Any, call_next: Any) -> Any: seen_methods.append(method) return await call_next() @@ -265,7 +266,7 @@ async def test_runner_server_middleware_runs_outermost_first(server: SrvT): order: list[str] = [] def make_mw(tag: str) -> Any: - async def mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: + async def mw(ctx: Ctx, method: str, params: Any, call_next: Any) -> Any: order.append(f"{tag}-in") result = await call_next() order.append(f"{tag}-out") @@ -281,7 +282,7 @@ async def mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: @pytest.mark.anyio async def test_runner_handler_returning_none_yields_empty_result(server: SrvT): - async def set_level(ctx: Any, params: Any) -> None: + async def set_level(ctx: Ctx, params: SetLevelRequestParams) -> None: return None server.add_request_handler("logging/setLevel", SetLevelRequestParams, set_level) @@ -291,8 +292,8 @@ async def set_level(ctx: Any, params: Any) -> None: @pytest.mark.anyio -async def test_runner_handler_returning_unsupported_type_surfaces_as_internal_error(server: SrvT): - async def bad_return(ctx: Any, params: Any) -> int: +async def test_runner_handler_returning_unsupported_type_surfaces_as_error(server: SrvT): + async def bad_return(ctx: Ctx, params: PaginatedRequestParams | None) -> int: return 42 # cast: deliberately registering a handler with a bad return type to @@ -301,7 +302,7 @@ async def bad_return(ctx: Any, params: Any) -> int: async with connected_runner(server) as (client, _): with pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", None) - assert exc.value.error.code == INTERNAL_ERROR + assert exc.value.error.code == 0 assert "int" in exc.value.error.message @@ -322,7 +323,7 @@ class GreetParams(RequestParams): received: list[GreetParams] = [] - async def greet(ctx: Any, params: GreetParams) -> dict[str, Any]: + async def greet(ctx: Ctx, params: GreetParams) -> dict[str, Any]: received.append(params) return {"greeting": f"hello {params.name}"} @@ -336,7 +337,7 @@ async def greet(ctx: Any, params: GreetParams) -> dict[str, Any]: @pytest.mark.anyio async def test_runner_initialize_result_reflects_init_options(): - async def list_tools(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + async def list_tools(ctx: Ctx, params: PaginatedRequestParams | None) -> ListToolsResult: raise NotImplementedError server: SrvT = Server(name="caps-test", on_list_tools=list_tools, instructions="be nice") @@ -364,7 +365,7 @@ async def test_runner_initialize_echoes_supported_version_and_falls_back_to_late @pytest.mark.anyio async def test_otel_middleware_emits_server_span_with_method_and_target(server: SrvT, spans: SpanCapture): - async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: + async def call_tool(ctx: Ctx, params: CallToolRequestParams) -> dict[str, Any]: return {"content": [], "isError": False} server.add_request_handler("tools/call", CallToolRequestParams, call_tool) @@ -372,26 +373,27 @@ async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: spans.clear() result = await client.send_raw_request("tools/call", {"name": "mytool", "arguments": {}}) assert result == {"content": [], "isError": False} - [span] = spans.finished() + finished = [s for s in spans.finished() if s.kind == SpanKind.SERVER] + [span] = finished assert span.name == "MCP handle tools/call mytool" - assert span.kind == SpanKind.SERVER assert span.attributes is not None assert span.attributes["mcp.method.name"] == "tools/call" assert span.status.status_code == StatusCode.UNSET @pytest.mark.anyio -async def test_otel_middleware_extracts_parent_context_from_meta(server: SrvT, spans: SpanCapture): - parent_span_id = "b7ad6b7169203331" - traceparent = f"00-0af7651916cd43dd8448eb211c80319c-{parent_span_id}-01" +async def test_otel_trace_context_propagates_client_to_server(server: SrvT, spans: SpanCapture): + """The client dispatcher injects traceparent into `_meta`; the server's + `otel_middleware` extracts it, so client and server spans share a trace.""" async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): spans.clear() - await client.send_raw_request("tools/list", {"_meta": {"traceparent": traceparent}}) - [span] = spans.finished() - assert span.parent is not None - assert format(span.parent.span_id, "016x") == parent_span_id - assert span.context is not None - assert format(span.context.trace_id, "032x") == "0af7651916cd43dd8448eb211c80319c" + await client.send_raw_request("tools/list", None) + [client_span] = [s for s in spans.finished() if s.kind == SpanKind.CLIENT] + [server_span] = [s for s in spans.finished() if s.kind == SpanKind.SERVER] + assert server_span.parent is not None + assert client_span.context is not None and server_span.context is not None + assert server_span.parent.span_id == client_span.context.span_id + assert server_span.context.trace_id == client_span.context.trace_id @pytest.mark.anyio @@ -401,7 +403,7 @@ async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, s with pytest.raises(MCPError) as exc: await client.send_raw_request("resources/list", None) assert exc.value.error.code == METHOD_NOT_FOUND - [span] = spans.finished() + [span] = [s for s in spans.finished() if s.kind == SpanKind.SERVER] assert span.status.status_code == StatusCode.ERROR assert span.status.description == "Method not found" # MCPError is a protocol-level response, not a crash - no traceback event. @@ -410,7 +412,7 @@ async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, s @pytest.mark.anyio async def test_otel_middleware_records_error_status_on_handler_exception(server: SrvT, spans: SpanCapture): - async def failing(ctx: Any, params: Any) -> Any: + async def failing(ctx: Ctx, params: PaginatedRequestParams | None) -> Any: raise ValueError("handler blew up") server.add_request_handler("tools/list", PaginatedRequestParams, failing) @@ -418,8 +420,8 @@ async def failing(ctx: Any, params: Any) -> Any: spans.clear() with pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", None) - assert exc.value.error.code == INTERNAL_ERROR - [span] = spans.finished() + assert exc.value.error.code == 0 + [span] = [s for s in spans.finished() if s.kind == SpanKind.SERVER] assert span.status.status_code == StatusCode.ERROR assert span.status.description == "handler blew up" [event] = [e for e in span.events if e.name == "exception"] @@ -428,97 +430,16 @@ async def failing(ctx: Any, params: Any) -> Any: @pytest.mark.anyio -async def test_connection_state_persists_across_requests_on_same_connection(server: SrvT) -> None: - async def count(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: - ctx.connection.state["n"] = ctx.connection.state.get("n", 0) + 1 - return ListToolsResult(tools=[]) - - server.add_request_handler("tools/list", PaginatedRequestParams, count) - async with connected_runner(server) as (client, runner): - await client.send_raw_request("tools/list", None) - await client.send_raw_request("tools/list", None) - assert runner.connection.state == {"n": 2} - - -@pytest.mark.anyio -async def test_connection_exit_stack_runs_pushed_callback_after_close(server: SrvT) -> None: - cleaned: list[str] = [] - - async def push(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: - async def _cleanup() -> None: - cleaned.append("done") - - ctx.connection.exit_stack.push_async_callback(_cleanup) - return ListToolsResult(tools=[]) - - server.add_request_handler("tools/list", PaginatedRequestParams, push) - async with connected_runner(server) as (client, _runner): - await client.send_raw_request("tools/list", None) - assert cleaned == [] - assert cleaned == ["done"] - - -@pytest.mark.anyio -async def test_connection_exit_stack_unwinds_entered_context_manager_after_close(server: SrvT) -> None: - events: list[str] = [] - - class _Tracker(AbstractAsyncContextManager[str]): - async def __aenter__(self) -> str: - events.append("enter") - return "resource" - - async def __aexit__(self, *exc: object) -> None: - events.append("exit") - - async def acquire(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: - res = await ctx.connection.exit_stack.enter_async_context(_Tracker()) - ctx.connection.state["res"] = res - return ListToolsResult(tools=[]) - - server.add_request_handler("tools/list", PaginatedRequestParams, acquire) - async with connected_runner(server) as (client, runner): - await client.send_raw_request("tools/list", None) - assert events == ["enter"] - assert runner.connection.state["res"] == "resource" - assert events == ["enter", "exit"] - - -@pytest.mark.anyio -async def test_connection_exit_stack_runs_callbacks_lifo_after_handler_error(server: SrvT) -> None: +async def test_runner_connection_exit_stack_unwinds_after_run_returns(server: SrvT) -> None: + """`runner.connection.exit_stack` is closed when the dispatcher loop ends.""" cleaned: list[int] = [] - async def push_then_fail(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: - for i in (1, 2, 3): - ctx.connection.exit_stack.push_async_callback(_append, i) - raise RuntimeError("boom") - async def _append(i: int) -> None: cleaned.append(i) - server.add_request_handler("tools/list", PaginatedRequestParams, push_then_fail) - async with connected_runner(server) as (client, _runner): - with pytest.raises(MCPError) as ei: - await client.send_raw_request("tools/list", None) - assert ei.value.error.code == INTERNAL_ERROR + async with connected_runner(server) as (client, runner): + for i in (1, 2, 3): + runner.connection.exit_stack.push_async_callback(_append, i) + await client.send_raw_request("tools/list", None) assert cleaned == [] assert cleaned == [3, 2, 1] - - -@pytest.mark.anyio -async def test_context_session_id_and_headers_expose_connection_and_transport(server: SrvT) -> None: - async with connected_runner(server, session_id="sess-abc", headers={"authorization": "Bearer t"}) as (client, _r): - await client.send_raw_request("tools/list", None) - [ctx] = _seen_ctx - assert ctx.session_id == "sess-abc" - assert ctx.session_id == ctx.connection.session_id - assert ctx.headers == {"authorization": "Bearer t"} - assert ctx.headers is ctx.transport.headers - - -@pytest.mark.anyio -async def test_context_session_id_and_headers_default_none(server: SrvT) -> None: - async with connected_runner(server) as (client, _r): - await client.send_raw_request("tools/list", None) - [ctx] = _seen_ctx - assert ctx.session_id is None - assert ctx.headers is None diff --git a/tests/server/test_session.py b/tests/server/test_session.py deleted file mode 100644 index 6116a7c7f5..0000000000 --- a/tests/server/test_session.py +++ /dev/null @@ -1,535 +0,0 @@ -from typing import Any - -import anyio -import pytest - -from mcp import types -from mcp.client.session import ClientSession -from mcp.server import Server, ServerRequestContext -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.exceptions import MCPError -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import ( - ClientNotification, - CompletionsCapability, - InitializedNotification, - PromptsCapability, - ResourcesCapability, - ServerCapabilities, -) - - -@pytest.mark.anyio -async def test_server_session_initialize(): - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - - # Create a message handler to catch exceptions - async def message_handler( # pragma: no cover - message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): - raise message - - received_initialized = False - - async def run_server(): - nonlocal received_initialized - - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="mcp", - server_version="0.1.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - async for message in server_session.incoming_messages: # pragma: no branch - if isinstance(message, Exception): # pragma: no cover - raise message - - if isinstance(message, ClientNotification) and isinstance( - message, InitializedNotification - ): # pragma: no branch - received_initialized = True - return - - try: - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server) - - await client_session.initialize() - except anyio.ClosedResourceError: # pragma: no cover - pass - - assert received_initialized - - -@pytest.mark.anyio -async def test_check_client_capability(): - """check_client_capability reflects the capabilities sent by the client at initialize.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - - initialized = anyio.Event() - - async def list_roots_callback(context: Any) -> types.ListRootsResult: # pragma: no cover - return types.ListRootsResult(roots=[]) - - async def run_server(server_session: ServerSession): - async for message in server_session.incoming_messages: # pragma: no branch - if isinstance(message, ClientNotification) and isinstance( - message, InitializedNotification - ): # pragma: no branch - initialized.set() - return - - async with ( - ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions(server_name="mcp", server_version="0.1.0", capabilities=ServerCapabilities()), - ) as server_session, - ClientSession( - server_to_client_receive, - client_to_server_send, - list_roots_callback=list_roots_callback, - ) as client_session, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server, server_session) - await client_session.initialize() - with anyio.fail_after(5): - await initialized.wait() - - # ClientSession advertises roots when a list_roots_callback is provided. - assert server_session.check_client_capability(types.ClientCapabilities(roots=types.RootsCapability())) - # ClientSession does not advertise sampling without a sampling_callback. - assert not server_session.check_client_capability(types.ClientCapabilities(sampling=types.SamplingCapability())) - - -@pytest.mark.anyio -async def test_server_capabilities(): - notification_options = NotificationOptions() - experimental_capabilities: dict[str, Any] = {} - - async def noop_list_prompts( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListPromptsResult: - raise NotImplementedError - - async def noop_list_resources( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListResourcesResult: - raise NotImplementedError - - async def noop_completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: - raise NotImplementedError - - # No capabilities - server = Server("test") - caps = server.get_capabilities(notification_options, experimental_capabilities) - assert caps.prompts is None - assert caps.resources is None - assert caps.completions is None - - # With prompts handler - server = Server("test", on_list_prompts=noop_list_prompts) - caps = server.get_capabilities(notification_options, experimental_capabilities) - assert caps.prompts == PromptsCapability(list_changed=False) - assert caps.resources is None - assert caps.completions is None - - # With prompts + resources handlers - server = Server("test", on_list_prompts=noop_list_prompts, on_list_resources=noop_list_resources) - caps = server.get_capabilities(notification_options, experimental_capabilities) - assert caps.prompts == PromptsCapability(list_changed=False) - assert caps.resources == ResourcesCapability(subscribe=False, list_changed=False) - assert caps.completions is None - - # With prompts + resources + completion handlers - server = Server( - "test", - on_list_prompts=noop_list_prompts, - on_list_resources=noop_list_resources, - on_completion=noop_completion, - ) - caps = server.get_capabilities(notification_options, experimental_capabilities) - assert caps.prompts == PromptsCapability(list_changed=False) - assert caps.resources == ResourcesCapability(subscribe=False, list_changed=False) - assert caps.completions == CompletionsCapability() - - -@pytest.mark.anyio -async def test_server_session_initialize_with_older_protocol_version(): - """Test that server accepts and responds with older protocol (2024-11-05).""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - received_initialized = False - received_protocol_version = None - - async def run_server(): - nonlocal received_initialized - - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="mcp", - server_version="0.1.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - async for message in server_session.incoming_messages: # pragma: no branch - if isinstance(message, Exception): # pragma: no cover - raise message - - if isinstance(message, types.ClientNotification) and isinstance( - message, InitializedNotification - ): # pragma: no branch - received_initialized = True - return - - async def mock_client(): - nonlocal received_protocol_version - - # Send initialization request with older protocol version (2024-11-05) - await client_to_server_send.send( - SessionMessage( - types.JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=types.InitializeRequestParams( - protocol_version="2024-11-05", - capabilities=types.ClientCapabilities(), - client_info=types.Implementation(name="test-client", version="1.0.0"), - ).model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - - # Wait for the initialize response - init_response_message = await server_to_client_receive.receive() - assert isinstance(init_response_message.message, types.JSONRPCResponse) - result_data = init_response_message.message.result - init_result = types.InitializeResult.model_validate(result_data) - - # Check that the server responded with the requested protocol version - received_protocol_version = init_result.protocol_version - assert received_protocol_version == "2024-11-05" - - # Send initialized notification - await client_to_server_send.send( - SessionMessage(types.JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) - ) - - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server) - tg.start_soon(mock_client) - - assert received_initialized - assert received_protocol_version == "2024-11-05" - - -@pytest.mark.anyio -async def test_ping_request_before_initialization(): - """Test that ping requests are allowed before initialization is complete.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - ping_response_received = False - ping_response_id = None - - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="mcp", - server_version="0.1.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - async for message in server_session.incoming_messages: # pragma: no branch - if isinstance(message, Exception): # pragma: no cover - raise message - - # We should receive a ping request before initialization - if isinstance(message, RequestResponder) and isinstance( - message.request, types.PingRequest - ): # pragma: no branch - # Respond to the ping - with message: - await message.respond(types.EmptyResult()) - return - - async def mock_client(): - nonlocal ping_response_received, ping_response_id - - # Send ping request before any initialization - await client_to_server_send.send(SessionMessage(types.JSONRPCRequest(jsonrpc="2.0", id=42, method="ping"))) - - # Wait for the ping response - ping_response_message = await server_to_client_receive.receive() - assert isinstance(ping_response_message.message, types.JSONRPCResponse) - - ping_response_received = True - ping_response_id = ping_response_message.message.id - - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server) - tg.start_soon(mock_client) - - assert ping_response_received - assert ping_response_id == 42 - - -@pytest.mark.anyio -async def test_create_message_tool_result_validation(): - """Test tool_use/tool_result validation in create_message.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test", - server_version="0.1.0", - capabilities=ServerCapabilities(), - ), - ) as session: - # Set up client params with sampling.tools capability for the test - session._client_params = types.InitializeRequestParams( - protocol_version=types.LATEST_PROTOCOL_VERSION, - capabilities=types.ClientCapabilities( - sampling=types.SamplingCapability(tools=types.SamplingToolsCapability()) - ), - client_info=types.Implementation(name="test", version="1.0"), - ) - - tool = types.Tool(name="test_tool", input_schema={"type": "object"}) - text = types.TextContent(type="text", text="hello") - tool_use = types.ToolUseContent(type="tool_use", id="call_1", name="test_tool", input={}) - tool_result = types.ToolResultContent(type="tool_result", tool_use_id="call_1", content=[]) - - # Case 1: tool_result mixed with other content - with pytest.raises(ValueError, match="only tool_result content"): - await session.create_message( - messages=[ - types.SamplingMessage(role="user", content=text), - types.SamplingMessage(role="assistant", content=tool_use), - types.SamplingMessage(role="user", content=[tool_result, text]), # mixed! - ], - max_tokens=100, - tools=[tool], - ) - - # Case 2: tool_result without previous message - with pytest.raises(ValueError, match="requires a previous message"): - await session.create_message( - messages=[types.SamplingMessage(role="user", content=tool_result)], - max_tokens=100, - tools=[tool], - ) - - # Case 3: tool_result without previous tool_use - with pytest.raises(ValueError, match="do not match any tool_use"): - await session.create_message( - messages=[ - types.SamplingMessage(role="user", content=text), - types.SamplingMessage(role="user", content=tool_result), - ], - max_tokens=100, - tools=[tool], - ) - - # Case 4: mismatched tool IDs - with pytest.raises(ValueError, match="ids of tool_result blocks and tool_use blocks"): - await session.create_message( - messages=[ - types.SamplingMessage(role="user", content=text), - types.SamplingMessage(role="assistant", content=tool_use), - types.SamplingMessage( - role="user", - content=types.ToolResultContent(type="tool_result", tool_use_id="wrong_id", content=[]), - ), - ], - max_tokens=100, - tools=[tool], - ) - - # Case 5: text-only message with tools (no tool_results) - passes validation - # Covers has_tool_results=False branch. - # We use move_on_after because validation happens synchronously before - # send_request, which would block indefinitely waiting for a response. - # The timeout lets validation pass, then cancels the blocked send. - with anyio.move_on_after(0.01): - await session.create_message( - messages=[types.SamplingMessage(role="user", content=text)], - max_tokens=100, - tools=[tool], - ) - - # Case 6: valid matching tool_result/tool_use IDs - passes validation - # Covers tool_use_ids == tool_result_ids branch. - # (see Case 5 comment for move_on_after explanation) - with anyio.move_on_after(0.01): - await session.create_message( - messages=[ - types.SamplingMessage(role="user", content=text), - types.SamplingMessage(role="assistant", content=tool_use), - types.SamplingMessage(role="user", content=tool_result), - ], - max_tokens=100, - tools=[tool], - ) - - # Case 7: validation runs even without `tools` parameter - # (tool loop continuation may omit tools while containing tool_result) - with pytest.raises(ValueError, match="do not match any tool_use"): - await session.create_message( - messages=[ - types.SamplingMessage(role="user", content=text), - types.SamplingMessage(role="user", content=tool_result), - ], - max_tokens=100, - # Note: no tools parameter - ) - - # Case 8: empty messages list - skips validation entirely - # Covers the `if messages:` branch (line 280->302) - with anyio.move_on_after(0.01): # pragma: no branch - await session.create_message(messages=[], max_tokens=100) - - -@pytest.mark.anyio -async def test_create_message_without_tools_capability(): - """Test that create_message raises MCPError when tools are provided without capability.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test", - server_version="0.1.0", - capabilities=ServerCapabilities(), - ), - ) as session: - # Set up client params WITHOUT sampling.tools capability - session._client_params = types.InitializeRequestParams( - protocol_version=types.LATEST_PROTOCOL_VERSION, - capabilities=types.ClientCapabilities(sampling=types.SamplingCapability()), - client_info=types.Implementation(name="test", version="1.0"), - ) - - tool = types.Tool(name="test_tool", input_schema={"type": "object"}) - text = types.TextContent(type="text", text="hello") - - # Should raise MCPError when tools are provided but client lacks capability - with pytest.raises(MCPError) as exc_info: - await session.create_message( - messages=[types.SamplingMessage(role="user", content=text)], - max_tokens=100, - tools=[tool], - ) - assert "does not support sampling tools capability" in exc_info.value.error.message - - # Should also raise MCPError when tool_choice is provided - with pytest.raises(MCPError) as exc_info: - await session.create_message( - messages=[types.SamplingMessage(role="user", content=text)], - max_tokens=100, - tool_choice=types.ToolChoice(mode="auto"), - ) - assert "does not support sampling tools capability" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_other_requests_blocked_before_initialization(): - """Test that non-ping requests are still blocked before initialization.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - error_response_received = False - error_code = None - - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="mcp", - server_version="0.1.0", - capabilities=ServerCapabilities(), - ), - ): - # Server should handle the request and send an error response - # No need to process incoming_messages since the error is handled automatically - await anyio.sleep(0.1) # Give time for the request to be processed - - async def mock_client(): - nonlocal error_response_received, error_code - - # Try to send a non-ping request before initialization - await client_to_server_send.send( - SessionMessage(types.JSONRPCRequest(jsonrpc="2.0", id=1, method="prompts/list")) - ) - - # Wait for the error response - error_message = await server_to_client_receive.receive() - if isinstance(error_message.message, types.JSONRPCError): # pragma: no branch - error_response_received = True - error_code = error_message.message.error.code - - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server) - tg.start_soon(mock_client) - - assert error_response_received - assert error_code == types.INVALID_PARAMS diff --git a/tests/server/test_session_race_condition.py b/tests/server/test_session_race_condition.py deleted file mode 100644 index 81041152bc..0000000000 --- a/tests/server/test_session_race_condition.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Test for race condition fix in initialization flow. - -This test verifies that requests can be processed immediately after -responding to InitializeRequest, without waiting for InitializedNotification. - -This is critical for HTTP transport where requests can arrive in any order. -""" - -import anyio -import pytest - -from mcp import types -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import ServerCapabilities, Tool - - -@pytest.mark.anyio -async def test_request_immediately_after_initialize_response(): - """Test that requests are accepted immediately after initialize response. - - This reproduces the race condition in stateful HTTP mode where: - 1. Client sends InitializeRequest - 2. Server responds with InitializeResult - 3. Client immediately sends tools/list (before server receives InitializedNotification) - 4. Without fix: Server rejects with "Received request before initialization was complete" - 5. With fix: Server accepts and processes the request - - This test simulates the HTTP transport behavior where InitializedNotification - may arrive in a separate POST request after other requests. - """ - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](10) - - tools_list_success = False - error_received = None - - async def run_server(): - nonlocal tools_list_success - - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities( - tools=types.ToolsCapability(list_changed=False), - ), - ), - ) as server_session: - async for message in server_session.incoming_messages: # pragma: no branch - if isinstance(message, Exception): # pragma: no cover - raise message - - # Handle tools/list request - if isinstance(message, RequestResponder): - if isinstance(message.request, types.ListToolsRequest): # pragma: no branch - tools_list_success = True - # Respond with a tool list - with message: - await message.respond( - types.ListToolsResult( - tools=[ - Tool( - name="example_tool", - description="An example tool", - input_schema={"type": "object", "properties": {}}, - ) - ] - ) - ) - - # Handle InitializedNotification - if isinstance(message, types.ClientNotification): - if isinstance(message, types.InitializedNotification): # pragma: no branch - # Done - exit gracefully - return - - async def mock_client(): - nonlocal error_received - - # Step 1: Send InitializeRequest - await client_to_server_send.send( - SessionMessage( - types.JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=types.InitializeRequestParams( - protocol_version=types.LATEST_PROTOCOL_VERSION, - capabilities=types.ClientCapabilities(), - client_info=types.Implementation(name="test-client", version="1.0.0"), - ).model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - - # Step 2: Wait for InitializeResult - init_msg = await server_to_client_receive.receive() - assert isinstance(init_msg.message, types.JSONRPCResponse) - - # Step 3: Immediately send tools/list BEFORE InitializedNotification - # This is the race condition scenario - await client_to_server_send.send(SessionMessage(types.JSONRPCRequest(jsonrpc="2.0", id=2, method="tools/list"))) - - # Step 4: Check the response - tools_msg = await server_to_client_receive.receive() - if isinstance(tools_msg.message, types.JSONRPCError): # pragma: no cover - error_received = tools_msg.message.error.message - - # Step 5: Send InitializedNotification - await client_to_server_send.send( - SessionMessage(types.JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) - ) - - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server) - tg.start_soon(mock_client) - - # With the PR fix: tools_list_success should be True, error_received should be None - # Without the fix: error_received would contain "Received request before initialization was complete" - assert tools_list_success, f"tools/list should have succeeded. Error received: {error_received}" - assert error_received is None, f"Expected no error, but got: {error_received}" diff --git a/tests/server/test_stateless_mode.py b/tests/server/test_stateless_mode.py index 3bfc6e674c..abe92062c5 100644 --- a/tests/server/test_stateless_mode.py +++ b/tests/server/test_stateless_mode.py @@ -7,45 +7,30 @@ See: https://github.com/modelcontextprotocol/python-sdk/issues/1097 """ -from collections.abc import AsyncGenerator from typing import Any +from unittest.mock import Mock -import anyio import pytest from mcp import types -from mcp.server.models import InitializationOptions +from mcp.server.connection import Connection from mcp.server.session import ServerSession from mcp.shared.exceptions import StatelessModeNotSupported -from mcp.shared.message import SessionMessage -from mcp.types import ServerCapabilities +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher -@pytest.fixture -async def stateless_session() -> AsyncGenerator[ServerSession, None]: - """Create a stateless ServerSession for testing.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - init_options = InitializationOptions( - server_name="test", - server_version="0.1.0", - capabilities=ServerCapabilities(), +def _make_session(*, stateless: bool) -> ServerSession: + """A `ServerSession` with a mock dispatcher; the stateless guard fires before any send.""" + return ServerSession( + Mock(spec=JSONRPCDispatcher), + Connection(Mock(), has_standalone_channel=False), + stateless=stateless, ) - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - init_options, - stateless=True, - ) as session: - yield session + +@pytest.fixture +def stateless_session() -> ServerSession: + return _make_session(stateless=True) @pytest.mark.anyio @@ -126,30 +111,8 @@ async def test_exception_has_method_attribute(stateless_session: ServerSession): @pytest.fixture -async def stateful_session() -> AsyncGenerator[ServerSession, None]: - """Create a stateful ServerSession for testing.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - init_options = InitializationOptions( - server_name="test", - server_version="0.1.0", - capabilities=ServerCapabilities(), - ) - - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - init_options, - stateless=False, - ) as session: - yield session +def stateful_session() -> ServerSession: + return _make_session(stateless=False) @pytest.mark.anyio diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py deleted file mode 100644 index aad9e5d439..0000000000 --- a/tests/shared/test_progress_notifications.py +++ /dev/null @@ -1,264 +0,0 @@ -from typing import Any -from unittest.mock import patch - -import anyio -import pytest - -from mcp import Client, types -from mcp.client.session import ClientSession -from mcp.server import Server, ServerRequestContext -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder - - -@pytest.mark.anyio -async def test_bidirectional_progress_notifications(): - """Test that both client and server can send progress notifications.""" - # Create memory streams for client/server - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) - - # Run a server session so we can send progress updates in tool - async def run_server(): - # Create a server session - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="ProgressTestServer", - server_version="0.1.0", - capabilities=server.get_capabilities(NotificationOptions(), {}), - ), - ) as server_session: - async for message in server_session.incoming_messages: - try: - await server._handle_message(message, server_session, {}) - except Exception as e: # pragma: no cover - raise e - - # Track progress updates - server_progress_updates: list[dict[str, Any]] = [] - client_progress_updates: list[dict[str, Any]] = [] - - # Progress tokens - server_progress_token = "server_token_123" - client_progress_token = "client_token_456" - - # Register progress handler - async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None: - server_progress_updates.append( - { - "token": params.progress_token, - "progress": params.progress, - "total": params.total, - "message": params.message, - } - ) - - # Register list tool handler - async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - types.Tool( - name="test_tool", - description="A tool that sends progress notifications types.CallToolResult: - # Make sure we received a progress token - if params.name == "test_tool": - assert params.meta is not None - progress_token = params.meta.get("progress_token") - assert progress_token is not None - assert progress_token == client_progress_token - - # Send progress notifications using ctx.session - await ctx.session.send_progress_notification( - progress_token=progress_token, - progress=0.25, - total=1.0, - message="Server progress 25%", - ) - - await ctx.session.send_progress_notification( - progress_token=progress_token, - progress=0.5, - total=1.0, - message="Server progress 50%", - ) - - await ctx.session.send_progress_notification( - progress_token=progress_token, - progress=1.0, - total=1.0, - message="Server progress 100%", - ) - - return types.CallToolResult(content=[types.TextContent(type="text", text="Tool executed successfully")]) - - raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover - - # Create a server with progress capability - server = Server( - name="ProgressTestServer", - on_progress=handle_progress, - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - - # Client message handler to store progress notifications - async def handle_client_message( - message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): # pragma: no cover - raise message - - if isinstance(message, types.ServerNotification): # pragma: no branch - if isinstance(message, types.ProgressNotification): # pragma: no branch - params = message.params - client_progress_updates.append( - { - "token": params.progress_token, - "progress": params.progress, - "total": params.total, - "message": params.message, - } - ) - - # Test using client - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=handle_client_message, - ) as client_session, - anyio.create_task_group() as tg, - ): - # Start the server in a background task - tg.start_soon(run_server) - - # Initialize the client connection - await client_session.initialize() - - # Call list_tools with progress token - await client_session.list_tools() - - # Call test_tool with progress token - await client_session.call_tool("test_tool", meta={"progress_token": client_progress_token}) - - # Send progress notifications from client to server - await client_session.send_progress_notification( - progress_token=server_progress_token, - progress=0.33, - total=1.0, - message="Client progress 33%", - ) - - await client_session.send_progress_notification( - progress_token=server_progress_token, - progress=0.66, - total=1.0, - message="Client progress 66%", - ) - - await client_session.send_progress_notification( - progress_token=server_progress_token, - progress=1.0, - total=1.0, - message="Client progress 100%", - ) - - # Wait and exit - await anyio.sleep(0.5) - tg.cancel_scope.cancel() - - # Verify client received progress updates from server - assert len(client_progress_updates) == 3 - assert client_progress_updates[0]["token"] == client_progress_token - assert client_progress_updates[0]["progress"] == 0.25 - assert client_progress_updates[0]["message"] == "Server progress 25%" - assert client_progress_updates[2]["progress"] == 1.0 - - # Verify server received progress updates from client - assert len(server_progress_updates) == 3 - assert server_progress_updates[0]["token"] == server_progress_token - assert server_progress_updates[0]["progress"] == 0.33 - assert server_progress_updates[0]["message"] == "Client progress 33%" - assert server_progress_updates[2]["progress"] == 1.0 - - -@pytest.mark.anyio -async def test_progress_callback_exception_logging(): - """Test that exceptions in progress callbacks are logged and \ - don't crash the session.""" - # Track logged warnings - logged_errors: list[str] = [] - - def mock_log_exception(msg: str, *args: Any, **kwargs: Any) -> None: - logged_errors.append(msg % args if args else msg) - - # Create a progress callback that raises an exception - async def failing_progress_callback(progress: float, total: float | None, message: str | None) -> None: - raise ValueError("Progress callback failed!") - - # Create a server with a tool that sends progress notifications - async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: - if params.name == "progress_tool": - assert ctx.request_id is not None - # Send a progress notification - await ctx.session.send_progress_notification( - progress_token=ctx.request_id, - progress=50.0, - total=100.0, - message="Halfway done", - ) - return types.CallToolResult(content=[types.TextContent(type="text", text="progress_result")]) - raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover - - async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - types.Tool( - name="progress_tool", - description="A tool that sends progress notifications", - input_schema={}, - ) - ] - ) - - server = Server( - name="TestProgressServer", - on_call_tool=handle_call_tool, - on_list_tools=handle_list_tools, - ) - - # Test with mocked logging - with patch("mcp.shared.session.logging.exception", side_effect=mock_log_exception): - async with Client(server) as client: - # Call tool with a failing progress callback - result = await client.call_tool( - "progress_tool", - arguments={}, - progress_callback=failing_progress_callback, - ) - - # Verify the request completed successfully despite the callback failure - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, types.TextContent) - assert content.text == "progress_result" - - # Check that a warning was logged for the progress callback exception - assert len(logged_errors) > 0 - assert any("Progress callback raised an exception" in warning for warning in logged_errors) From fa203522898f318285e688cd68d9617cce801cc2 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 18:23:43 +0000 Subject: [PATCH 43/52] fix: notification handlers and progress callbacks cannot crash the connection Both run as bare tasks in the dispatcher's task group; an uncaught exception cancels every sibling (read loop + in-flight requests) and tears down run(). The previous receive-loop swallowed and logged both. ServerRunner._on_notify: ValidationError -> warning + drop; handler Exception -> logger.exception + swallow. JSONRPCDispatcher: user on_progress callbacks are wrapped so a raise is logged and swallowed instead of cascading. Regression tests confirmed to fail before the fix. --- src/mcp/server/runner.py | 17 ++++++++++--- src/mcp/shared/jsonrpc_dispatcher.py | 20 ++++++++++++++- tests/server/conftest.py | 2 -- tests/server/test_runner.py | 33 +++++++++++++++++++++++++ tests/shared/test_jsonrpc_dispatcher.py | 20 +++++++++++++++ 5 files changed, 86 insertions(+), 6 deletions(-) diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 2ed0381535..a2f5eb6846 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -23,7 +23,7 @@ import anyio.abc from opentelemetry.trace import SpanKind, StatusCode -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from typing_extensions import TypeVar from mcp.server.connection import Connection @@ -237,9 +237,20 @@ async def _on_notify( return # Absent wire params reach the handler as `None`, not an empty model # (matches the existing `Server._handle_notification`). - typed_params = entry.params_type.model_validate(params) if params is not None else None + try: + typed_params = entry.params_type.model_validate(params) if params is not None else None + except ValidationError: + logger.warning("dropped %r: malformed params", method) + return ctx = self._make_context(dctx, typed_params) - await entry.handler(ctx, typed_params) + try: + await entry.handler(ctx, typed_params) + except Exception: + # Top-level boundary: a notification handler crashing must not + # tear down the connection (it runs as a bare task in the + # dispatcher's task group; an uncaught exception would cancel + # every sibling, including the read loop and in-flight requests). + logger.exception("notification handler for %r raised", method) def _make_context( self, dctx: DispatchContext[TransportContext], typed_params: BaseModel | None diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 557b598d25..67a6a0bdaa 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -162,6 +162,24 @@ def _default_transport_builder(_meta: MessageMetadata) -> TransportContext: return TransportContext(kind="jsonrpc", can_send_request=True) +def _shielded_progress(fn: ProgressFnT) -> ProgressFnT: + """Wrap a user progress callback so it can't crash the dispatcher. + + The callback runs as a bare task in the dispatcher's task group; an + uncaught exception would cancel every sibling (the read loop and all + in-flight requests). Swallow and log instead, matching the previous + receive-loop's behavior. + """ + + async def _wrapped(progress: float, total: float | None, message: str | None) -> None: + try: + await fn(progress, total, message) + except Exception: + logger.exception("progress callback raised") + + return _wrapped + + def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | None) -> MessageMetadata: """Choose the `SessionMessage.metadata` for an outgoing request/notification. @@ -471,7 +489,7 @@ def _dispatch_notification( total = msg.params.get("total") message = msg.params.get("message") self._spawn( - pending.on_progress, + _shielded_progress(pending.on_progress), float(progress), float(total) if isinstance(total, int | float) else None, message if isinstance(message, str) else None, diff --git a/tests/server/conftest.py b/tests/server/conftest.py index 9114f0348d..e0fa8ee9b0 100644 --- a/tests/server/conftest.py +++ b/tests/server/conftest.py @@ -44,5 +44,3 @@ def spans(capfire: CaptureLogfire) -> Iterator[SpanCapture]: capture.clear() yield capture capture.clear() - - diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 2199673673..069ffed994 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -220,6 +220,39 @@ async def on_roots_changed(ctx: Ctx, params: NotificationParams | None) -> None: assert isinstance(seen[1][1], NotificationParams) +@pytest.mark.anyio +async def test_runner_on_notify_handler_exception_is_swallowed_and_logged( + server: SrvT, caplog: pytest.LogCaptureFixture +): + """A notification handler crashing must not tear down the connection.""" + + async def boom(ctx: Ctx, params: NotificationParams | None) -> None: + raise RuntimeError("notification handler boom") + + server.add_notification_handler("notifications/roots/list_changed", NotificationParams, boom) + async with connected_runner(server) as (client, _): + await client.notify("notifications/roots/list_changed", None) + # Connection still alive: a request after the crashing handler succeeds. + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + assert "notification handler for 'notifications/roots/list_changed' raised" in caplog.text + + +@pytest.mark.anyio +async def test_runner_on_notify_drops_malformed_params(server: SrvT, caplog: pytest.LogCaptureFixture): + """Malformed notification params are logged and dropped, not raised.""" + + async def on_level(ctx: Ctx, params: SetLevelRequestParams) -> None: + raise NotImplementedError + + server.add_notification_handler("notifications/roots/list_changed", SetLevelRequestParams, on_level) + async with connected_runner(server) as (client, _): + await client.notify("notifications/roots/list_changed", {"level": "not-a-level"}) + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + assert "dropped 'notifications/roots/list_changed': malformed params" in caplog.text + + @pytest.mark.anyio async def test_runner_on_notify_drops_before_init_and_unknown_methods(server: SrvT): async with connected_runner(server, initialized=False) as (client, _): diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 035f943de1..91b7d63dc9 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -441,6 +441,26 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | assert received == [(0.25, None, None)] +@pytest.mark.anyio +async def test_progress_callback_exception_is_swallowed_and_logged(caplog: pytest.LogCaptureFixture): + """A user progress callback raising must not crash the dispatcher.""" + + async def boom(progress: float, total: float | None, message: str | None) -> None: + raise RuntimeError("progress callback boom") + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + await ctx.progress(0.5) + return {"ok": True} + + opts: CallOptions = {"on_progress": boom} + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + result = await client.send_raw_request("t", None, opts) + # Request still completes; the callback's crash was swallowed. + assert result == {"ok": True} + assert "progress callback raised" in caplog.text + + @pytest.mark.anyio async def test_send_raw_request_always_carries_meta_on_the_wire(): """Outbound requests always include `params._meta` (otel injection per SEP-414). From 07d94ee8f5ca65e631b73a47d8921ace624d33a6 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 18:28:48 +0000 Subject: [PATCH 44/52] fix: late peer-cancel cannot double-respond; mcpserver completion gets typed params JSONRPCDispatcher._handle_request: pop from _in_flight in the inner finally (right after the handler returns, before _write_result). No checkpoint between handler return and pop, so a late notifications/cancelled finds nothing and is a no-op; scope.cancel_called is only true if the cancel landed during the handler. Previously a cancel arriving during _write_result's checkpoint after the result was buffered would send both the result and a code=0 "Request cancelled" for the same id. mcpserver/server.py: register completion/complete via add_request_handler with CompleteRequestParams (was the legacy _add_request_handler which defaulted params_type=RequestParams, so the handler got base RequestParams and params.ref AttributeErrored). Delete _add_request_handler (last caller). --- src/mcp/server/lowlevel/server.py | 9 --------- src/mcp/server/mcpserver/server.py | 6 +----- src/mcp/shared/jsonrpc_dispatcher.py | 23 +++++++++++++---------- 3 files changed, 14 insertions(+), 24 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 507707c254..5ab68349dd 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -272,15 +272,6 @@ def add_notification_handler( """ self._notification_handlers[method] = HandlerEntry(params_type, handler) - def _add_request_handler( - self, - method: str, - handler: RequestHandler[LifespanResultT, Any], - ) -> None: - # TODO: remove once experimental tasks plumbing and remaining callers - # migrate to `add_request_handler` with an explicit params_type. - self.add_request_handler(method, types.RequestParams, handler) - # --- ServerRegistry protocol (consumed by ServerRunner) ------------------ def get_request_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None: diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index ec2365810e..fdb69571d8 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -607,11 +607,7 @@ async def handler( completion=result if result is not None else Completion(values=[], total=None, has_more=None), ) - # TODO(maxisbey): remove private access — completion needs post-construction - # handler registration, find a better pattern for this - self._lowlevel_server._add_request_handler( # pyright: ignore[reportPrivateUsage] - "completion/complete", handler - ) + self._lowlevel_server.add_request_handler("completion/complete", CompleteRequestParams, handler) return func return decorator diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 67a6a0bdaa..5f3e801eca 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -565,20 +565,23 @@ async def _handle_request( try: result = await on_request(dctx, req.method, req.params) finally: - # Close the back-channel the moment the handler exits - # (success or raise), before the response write - a handler - # spawning detached work that later calls - # `dctx.send_raw_request()` should see `NoBackChannelError`. + # Handler done: close the back-channel (detached work that + # later calls `dctx.send_raw_request()` should see + # `NoBackChannelError`) and drop from `_in_flight` so a + # late `notifications/cancelled` is a no-op rather than + # racing the result write below. No checkpoint between + # handler return and the pop, so the cancel can't + # interleave there. dctx.close() + self._in_flight.pop(req.id, None) await self._write_result(req.id, result) if scope.cancel_called: - # Peer-cancel: `_dispatch_notification` cancelled this scope. - # anyio swallows a scope's *own* cancel at __exit__, so the - # result write (or the handler) is interrupted and execution - # lands here rather than the `except cancelled` arm below. + # Peer-cancel: `_dispatch_notification` cancelled this scope + # while the handler was running. anyio swallows a scope's *own* + # cancel at __exit__, so execution lands here rather than the + # `except cancelled` arm below. # TODO(maxisbey): spec says SHOULD NOT respond after cancel. - # The existing server always has and the interaction suite pins - # that; revisit once the suite's divergence entry is resolved. + # The existing server always has, so match that for now. await self._write_error(req.id, ErrorData(code=0, message="Request cancelled")) except anyio.get_cancelled_exc_class(): # Outer-cancel: run()'s task group is shutting down. Any bare From 535d62104b2e3f8b5b35dc72bcf756e773532e74 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 18:56:15 +0000 Subject: [PATCH 45/52] feat: JSONRPCDispatcher.inline_methods; otel_middleware sets jsonrpc.request.id inline_methods: request methods in this set are awaited inline in the read loop instead of spawned, so their side effects are visible to the next dequeued message. Server.run() passes {"initialize"} so a client that pipelines initialize + the next request without awaiting InitializeResult (spec says SHOULD NOT, not MUST NOT) sees the initialized state instead of failing the init-gate. Matches the go-sdk's explicit carve-out. _dispatch / _dispatch_request are now async (only await for inline methods). otel_middleware: restore the jsonrpc.request.id span attribute that the previous Server._handle_request set. --- src/mcp/server/lowlevel/server.py | 4 +++ src/mcp/server/runner.py | 5 +++- src/mcp/shared/jsonrpc_dispatcher.py | 31 ++++++++++++++++------ tests/server/test_runner.py | 1 + tests/shared/test_jsonrpc_dispatcher.py | 34 +++++++++++++++++++++++++ 5 files changed, 66 insertions(+), 9 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 5ab68349dd..91d2caab6d 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -396,6 +396,10 @@ async def run( read_stream, write_stream, raise_handler_exceptions=raise_exceptions, + # Handle `initialize` inline so a client that pipelines it with + # the next request (spec says SHOULD NOT, not MUST NOT) sees + # the initialized state instead of failing the init-gate. + inline_methods=frozenset({"initialize"}), ) runner = ServerRunner( server=self, diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index a2f5eb6846..6bb0738c68 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -94,10 +94,13 @@ async def wrapped( case _: parent = None span_name = f"MCP handle {method}{f' {target}' if target else ''}" + attributes: dict[str, str | int] = {"mcp.method.name": method} + if dctx.request_id is not None: + attributes["jsonrpc.request.id"] = dctx.request_id with otel_span( span_name, kind=SpanKind.SERVER, - attributes={"mcp.method.name": method}, + attributes=attributes, context=parent, record_exception=False, set_status_on_exception=False, diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 5f3e801eca..ad405b486c 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -213,6 +213,7 @@ def __init__( *, peer_cancel_mode: PeerCancelMode = "interrupt", raise_handler_exceptions: bool = False, + inline_methods: frozenset[str] = frozenset(), ) -> None: ... @overload def __init__( @@ -223,6 +224,7 @@ def __init__( transport_builder: Callable[[MessageMetadata], TransportT], peer_cancel_mode: PeerCancelMode = "interrupt", raise_handler_exceptions: bool = False, + inline_methods: frozenset[str] = frozenset(), ) -> None: ... def __init__( self, @@ -232,6 +234,7 @@ def __init__( transport_builder: Callable[[MessageMetadata], TransportT] | None = None, peer_cancel_mode: PeerCancelMode = "interrupt", raise_handler_exceptions: bool = False, + inline_methods: frozenset[str] = frozenset(), ) -> None: self._read_stream = read_stream self._write_stream = write_stream @@ -244,6 +247,13 @@ def __init__( ) self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode self._raise_handler_exceptions = raise_handler_exceptions + self._inline_methods = inline_methods + """Request methods handled inline in the read loop (awaited before the + next message is dequeued) instead of spawned concurrently. Use for + methods whose side effects must be observable to the next message, + e.g. `initialize`, so a pipelined follow-up sees the initialized state. + Only suitable for handlers that complete quickly, since inline handling + blocks dequeuing.""" self._next_id = 0 self._pending: dict[RequestId, _Pending] = {} @@ -384,7 +394,7 @@ async def run( sender_ctx: contextvars.Context | None = getattr( self._read_stream, "last_context", None ) - self._dispatch(item, on_request, on_notify, sender_ctx) + await self._dispatch(item, on_request, on_notify, sender_ctx) except anyio.ClosedResourceError: # The transport closed our receive end and we looped # back to `__anext__` on the now-closed stream @@ -409,17 +419,19 @@ async def run( self._tg = None self._fan_out_closed() - def _dispatch( + async def _dispatch( self, item: SessionMessage | Exception, on_request: OnRequest, on_notify: OnNotify, sender_ctx: contextvars.Context | None, ) -> None: - """Route one inbound item. Synchronous: never awaits. + """Route one inbound item. - Everything here is `send_nowait` or `_spawn`. An `await` would let one - slow message head-of-line block the entire read loop. + Everything here is `send_nowait` or `_spawn`; the only `await` is for + `inline_methods` requests, which deliberately block dequeuing until + handled. Any other `await` would let one slow message head-of-line + block the entire read loop. """ if isinstance(item, Exception): logger.debug("transport yielded exception: %r", item) @@ -428,7 +440,7 @@ def _dispatch( msg = item.message match msg: case JSONRPCRequest(): - self._dispatch_request(msg, metadata, on_request, sender_ctx) + await self._dispatch_request(msg, metadata, on_request, sender_ctx) case JSONRPCNotification(): self._dispatch_notification(msg, metadata, on_notify, sender_ctx) case JSONRPCResponse(): @@ -439,7 +451,7 @@ def _dispatch( # on this final case is unreachable. self._resolve_pending(msg.id, msg.error) - def _dispatch_request( + async def _dispatch_request( self, req: JSONRPCRequest, metadata: MessageMetadata, @@ -462,7 +474,10 @@ def _dispatch_request( ) scope = anyio.CancelScope() self._in_flight[req.id] = _InFlight(scope=scope, dctx=dctx) - self._spawn(self._handle_request, req, dctx, scope, on_request, sender_ctx=sender_ctx) + if req.method in self._inline_methods: + await self._handle_request(req, dctx, scope, on_request) + else: + self._spawn(self._handle_request, req, dctx, scope, on_request, sender_ctx=sender_ctx) def _dispatch_notification( self, diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 069ffed994..21154fdcbd 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -411,6 +411,7 @@ async def call_tool(ctx: Ctx, params: CallToolRequestParams) -> dict[str, Any]: assert span.name == "MCP handle tools/call mytool" assert span.attributes is not None assert span.attributes["mcp.method.name"] == "tools/call" + assert isinstance(span.attributes["jsonrpc.request.id"], int) assert span.status.status_code == StatusCode.UNSET diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 91b7d63dc9..baea7f4b9a 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -11,6 +11,7 @@ from typing import Any import anyio +import anyio.lowlevel import pytest from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream @@ -461,6 +462,39 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | assert "progress callback raised" in caplog.text +@pytest.mark.anyio +async def test_inline_methods_are_handled_before_next_message_is_dequeued(): + """A method in `inline_methods` runs to completion before subsequent + messages are dispatched, so its side effects are visible to them.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( + c2s_recv, s2c_send, inline_methods=frozenset({"first"}) + ) + state = {"initialized": False} + seen: list[bool] = [] + + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + if method == "first": + await anyio.lowlevel.checkpoint() + state["initialized"] = True + else: + seen.append(state["initialized"]) + return {} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + # Buffer both requests before run() reads anything. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="first", params=None))) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=2, method="second", params=None))) + c2s_send.close() + with anyio.fail_after(5): + await server.run(on_request, on_notify) + assert seen == [True] + s2c_recv.close() + + @pytest.mark.anyio async def test_send_raw_request_always_carries_meta_on_the_wire(): """Outbound requests always include `params._meta` (otel injection per SEP-414). From 9f63603584d102659d4b5f4180b1b47d3b2c482a Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 3 Jun 2026 10:18:35 +0000 Subject: [PATCH 46/52] fix: InMemoryTransport EOFs the server instead of cancelling; coverage cleanup src/mcp/client/_memory.py: replace tg.cancel_scope.cancel() with stream aclose(). Cancelling here throws CancelledError into the host (test's) task; on CPython 3.11 (gh-106749) coro.throw() drops 'call' trace events for the outer await chain, underflowing coverage's CTracer past the test frame. Post-swap the dispatcher's empty-at-EOF inner task group takes a one-tick fast path, so the join no longer suspends a second time to heal via .send(). EOF teardown is equivalent (the dispatcher's run() cancels its own in-flight handlers on read-stream EOF). Also bundled (coverage cleanup): - shared/session.py: delete server-only BaseSession code now unreachable after the swap (RequestResponder.cancel/_cancel_scope/_on_complete, CancelledNotification handling, _in_flight dict, deferred-respond arm). ClientSession is the only remaining subclass. - tests/client/test_session.py: add tests for the client-reachable BaseSession paths (malformed inbound request -> INVALID_PARAMS, sampling callback raises -> INVALID_PARAMS, progress callback exception swallowed, malformed notification dropped, transport exception forwarded to message_handler). - tests/shared/test_session.py: drop test_in_flight_requests_cleared (asserts the deleted _in_flight dict). - tests/shared/test_dispatcher.py: add contract test for ValidationError -> INVALID_PARAMS (covers DirectDispatcher's arm). - tests/server/test_validation.py: cover the previous-message-has-no- tool-use branch. - examples/everything-server: _add_request_handler (deleted) -> add_request_handler with explicit params type. --- .../mcp_everything_server/server.py | 12 +- src/mcp/client/_memory.py | 10 +- src/mcp/shared/session.py | 100 ++++---------- tests/client/test_session.py | 123 ++++++++++++++++++ tests/server/test_validation.py | 10 ++ tests/shared/test_dispatcher.py | 19 ++- tests/shared/test_session.py | 13 -- 7 files changed, 192 insertions(+), 95 deletions(-) diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py index a0620b9c1d..b37ff3e950 100644 --- a/examples/servers/everything-server/mcp_everything_server/server.py +++ b/examples/servers/everything-server/mcp_everything_server/server.py @@ -417,9 +417,15 @@ async def handle_unsubscribe(ctx: ServerRequestContext, params: UnsubscribeReque return EmptyResult() -mcp._lowlevel_server._add_request_handler("logging/setLevel", handle_set_logging_level) # pyright: ignore[reportPrivateUsage] -mcp._lowlevel_server._add_request_handler("resources/subscribe", handle_subscribe) # pyright: ignore[reportPrivateUsage] -mcp._lowlevel_server._add_request_handler("resources/unsubscribe", handle_unsubscribe) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server.add_request_handler( # pyright: ignore[reportPrivateUsage] + "logging/setLevel", SetLevelRequestParams, handle_set_logging_level +) +mcp._lowlevel_server.add_request_handler( # pyright: ignore[reportPrivateUsage] + "resources/subscribe", SubscribeRequestParams, handle_subscribe +) +mcp._lowlevel_server.add_request_handler( # pyright: ignore[reportPrivateUsage] + "resources/unsubscribe", UnsubscribeRequestParams, handle_unsubscribe +) @mcp.completion() diff --git a/src/mcp/client/_memory.py b/src/mcp/client/_memory.py index e6e9386731..3813c9e335 100644 --- a/src/mcp/client/_memory.py +++ b/src/mcp/client/_memory.py @@ -62,7 +62,15 @@ async def _connect(self) -> AsyncIterator[TransportStreams]: try: yield client_read, client_write finally: - tg.cancel_scope.cancel() + # EOF the server (and our own read side) instead of + # cancelling. The dispatcher's run() cancels its own + # in-flight handlers on read-stream EOF, so cleanup is + # equivalent. Cancelling here would `coro.throw()` into the + # host task, which on CPython 3.11 (gh-106749) drops + # `'call'` trace events for the outer await chain and + # desyncs coverage's CTracer past the test frame. + await client_write.aclose() + await server_write.aclose() async def __aenter__(self) -> TransportStreams: """Connect to the server and return streams for communication.""" diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index ea5d8833bd..376906b49c 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -2,7 +2,6 @@ import contextvars import logging -from collections.abc import Callable from contextlib import AsyncExitStack from types import TracebackType from typing import Any, Generic, Protocol, TypeVar @@ -21,7 +20,6 @@ CONNECTION_CLOSED, INVALID_PARAMS, REQUEST_TIMEOUT, - CancelledNotification, ClientNotification, ClientRequest, ClientResult, @@ -80,7 +78,6 @@ def __init__( request_meta: RequestParamsMeta | None, request: ReceiveRequestT, session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], - on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any], message_metadata: MessageMetadata = None, context: contextvars.Context | None = None, ) -> None: @@ -91,15 +88,10 @@ def __init__( self.context = context self._session = session self._completed = False - self._cancel_scope = anyio.CancelScope() - self._on_complete = on_complete self._entered = False # Track if we're in a context manager def __enter__(self) -> RequestResponder[ReceiveRequestT, SendResultT]: - """Enter the context manager, enabling request cancellation tracking.""" self._entered = True - self._cancel_scope = anyio.CancelScope() - self._cancel_scope.__enter__() return self def __exit__( @@ -108,15 +100,7 @@ def __exit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - """Exit the context manager, performing cleanup and notifying completion.""" - try: - if self._completed: - self._on_complete(self) - finally: - self._entered = False - if not self._cancel_scope: # pragma: no cover - raise RuntimeError("No active cancel scope") - self._cancel_scope.__exit__(exc_type, exc_val, exc_tb) + self._entered = False async def respond(self, response: SendResultT | ErrorData) -> None: """Send a response for this request. @@ -130,37 +114,11 @@ async def respond(self, response: SendResultT | ErrorData) -> None: if not self._entered: # pragma: no cover raise RuntimeError("RequestResponder must be used as a context manager") assert not self._completed, "Request already responded to" - - if not self.cancelled: # pragma: no branch - self._completed = True - - await self._session._send_response( # type: ignore[reportPrivateUsage] - request_id=self.request_id, response=response - ) - - async def cancel(self) -> None: - """Cancel this request and mark it as completed.""" - if not self._entered: # pragma: no cover - raise RuntimeError("RequestResponder must be used as a context manager") - if not self._cancel_scope: # pragma: no cover - raise RuntimeError("No active cancel scope") - - self._cancel_scope.cancel() - self._completed = True # Mark as completed so it's removed from in_flight - # Send an error response to indicate cancellation + self._completed = True await self._session._send_response( # type: ignore[reportPrivateUsage] - request_id=self.request_id, - response=ErrorData(code=0, message="Request cancelled"), + request_id=self.request_id, response=response ) - @property - def in_flight(self) -> bool: # pragma: no cover - return not self._completed and not self.cancelled - - @property - def cancelled(self) -> bool: - return self._cancel_scope.cancel_called - class BaseSession( Generic[ @@ -180,7 +138,6 @@ class BaseSession( _response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]] _request_id: int - _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] def __init__( @@ -195,7 +152,6 @@ def __init__( self._response_streams = {} self._request_id = 0 self._session_read_timeout_seconds = read_timeout_seconds - self._in_flight = {} self._progress_callbacks = {} self._exit_stack = AsyncExitStack() @@ -347,15 +303,10 @@ async def _handle_session_message(message: SessionMessage) -> None: request_meta=validated_request.params.meta if validated_request.params else None, request=validated_request, session=self, - on_complete=lambda r: self._in_flight.pop(r.request_id, None), message_metadata=message.metadata, context=sender_context, ) - self._in_flight[responder.request_id] = responder await self._received_request(responder) - - if not responder._completed: # type: ignore[reportPrivateUsage] - await self._handle_incoming(responder) except Exception: # For request validation errors, send a proper JSON-RPC error # response instead of crashing the server @@ -375,33 +326,28 @@ async def _handle_session_message(message: SessionMessage) -> None: message.message.model_dump(by_alias=True, mode="json", exclude_none=True), by_name=False, ) - # Handle cancellation notifications - if isinstance(notification, CancelledNotification): - cancelled_id = notification.params.request_id - if cancelled_id in self._in_flight: # pragma: no branch - await self._in_flight[cancelled_id].cancel() - else: - # Handle progress notifications callback - if isinstance(notification, ProgressNotification): - progress_token = notification.params.progress_token - # If there is a progress callback for this token, - # call it with the progress information - if progress_token in self._progress_callbacks: - callback = self._progress_callbacks[progress_token] - try: - await callback( - notification.params.progress, - notification.params.total, - notification.params.message, - ) - except Exception: - logging.exception("Progress callback raised an exception") - await self._received_notification(notification) - await self._handle_incoming(notification) + # Handle progress notifications callback + if isinstance(notification, ProgressNotification): + progress_token = notification.params.progress_token + # If there is a progress callback for this token, + # call it with the progress information + if progress_token in self._progress_callbacks: + callback = self._progress_callbacks[progress_token] + try: + await callback( + notification.params.progress, + notification.params.total, + notification.params.message, + ) + except Exception: + logging.exception("Progress callback raised an exception") + await self._received_notification(notification) + await self._handle_incoming(notification) except Exception: # For other validation errors, log and continue - logging.warning( # pragma: no cover - f"Failed to validate notification:. Message was: {message.message}", + logging.warning( + "Failed to validate notification: %s", + message.message, exc_info=True, ) else: # Response or error diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f25c964f03..281c568b96 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1,6 +1,11 @@ from __future__ import annotations +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + import anyio +import anyio.streams.memory import pytest from mcp import types @@ -10,12 +15,14 @@ from mcp.shared.session import RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( + INVALID_PARAMS, LATEST_PROTOCOL_VERSION, CallToolResult, Implementation, InitializedNotification, InitializeRequest, InitializeResult, + JSONRPCError, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, @@ -26,6 +33,29 @@ client_request_adapter, ) +_SendToClient = anyio.streams.memory.MemoryObjectSendStream[SessionMessage | Exception] +_RecvFromClient = anyio.streams.memory.MemoryObjectReceiveStream[SessionMessage] + + +@asynccontextmanager +async def raw_client_session( + **kwargs: Any, +) -> AsyncIterator[tuple[ClientSession, _SendToClient, _RecvFromClient]]: + """Yield `(session, send_to_client, recv_from_client)` with the receive loop running. + + `send_to_client` accepts `SessionMessage | Exception` so tests can inject + transport-level exceptions. No initialize handshake is performed. + """ + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage](32) + async with ClientSession(s2c_recv, c2s_send, **kwargs) as session: + try: + with anyio.fail_after(5): + yield session, s2c_send, c2s_recv + finally: + s2c_send.close() + c2s_recv.close() + @pytest.mark.anyio async def test_client_session_initialize(): @@ -705,3 +735,96 @@ async def mock_server(): await session.initialize() await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta) + + +@pytest.mark.anyio +async def test_receive_loop_answers_malformed_inbound_request_with_invalid_params(): + """A request that fails ServerRequest validation gets an INVALID_PARAMS error response.""" + async with raw_client_session() as (_session, to_client, from_client): + await to_client.send( + SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=7, method="sampling/createMessage", params={"broken": 1})) + ) + out = await from_client.receive() + assert isinstance(out.message, JSONRPCError) + assert out.message.id == 7 + assert out.message.error.code == INVALID_PARAMS + + +@pytest.mark.anyio +async def test_receive_loop_answers_invalid_params_when_sampling_callback_raises(): + """Same boundary catches exceptions from the request handler itself.""" + + async def boom(ctx: object, params: object) -> types.CreateMessageResult: + raise RuntimeError("sampling boom") + + params = types.CreateMessageRequestParams( + messages=[types.SamplingMessage(role="user", content=types.TextContent(type="text", text="hi"))], + max_tokens=10, + ).model_dump(by_alias=True, mode="json", exclude_none=True) + async with raw_client_session(sampling_callback=boom) as (_session, to_client, from_client): + await to_client.send( + SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=8, method="sampling/createMessage", params=params)) + ) + out = await from_client.receive() + assert isinstance(out.message, JSONRPCError) + assert out.message.error.code == INVALID_PARAMS + + +@pytest.mark.anyio +async def test_receive_loop_logs_and_drops_malformed_notification(caplog: pytest.LogCaptureFixture): + """A notification that fails ServerNotification validation is logged and dropped.""" + seen: list[object] = [] + delivered = anyio.Event() + + async def handler(msg: object) -> None: + seen.append(msg) + delivered.set() + + async with raw_client_session(message_handler=handler) as (_session, to_client, _): + await to_client.send(SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="not/a/spec/notification"))) + # Follow with a valid notification so we know the loop is still alive. + await to_client.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/tools/list_changed")) + ) + await delivered.wait() + assert isinstance(seen[0], types.ToolListChangedNotification) + assert "Failed to validate notification" in caplog.text + + +@pytest.mark.anyio +async def test_receive_loop_forwards_transport_exception_to_message_handler(): + seen: list[object] = [] + delivered = anyio.Event() + + async def handler(msg: object) -> None: + seen.append(msg) + delivered.set() + + async with raw_client_session(message_handler=handler) as (_session, to_client, _): + exc = ValueError("bad bytes") + await to_client.send(exc) + await delivered.wait() + assert seen == [exc] + + +@pytest.mark.anyio +async def test_receive_loop_swallows_progress_callback_exception(caplog: pytest.LogCaptureFixture): + delivered = anyio.Event() + + async def boom(progress: float, total: float | None, message: str | None) -> None: + raise RuntimeError("progress boom") + + async def handler(msg: object) -> None: + delivered.set() + + async with raw_client_session(message_handler=handler) as (session, to_client, _): + # Register the callback under a known token without sending a request. + session._progress_callbacks[42] = boom # pyright: ignore[reportPrivateUsage] + params = {"progressToken": 42, "progress": 0.5} + await to_client.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/progress", params=params)) + ) + # The progress notification also reaches the message handler after the + # callback runs, so this fires once the callback's exception is handled. + await delivered.wait() + assert "Progress callback raised an exception" in caplog.text diff --git a/tests/server/test_validation.py b/tests/server/test_validation.py index ad97dd3fd6..19f4eb1088 100644 --- a/tests/server/test_validation.py +++ b/tests/server/test_validation.py @@ -120,6 +120,16 @@ def test_validate_tool_use_result_messages_raises_when_tool_result_without_previ validate_tool_use_result_messages(messages) +def test_validate_tool_use_result_messages_raises_when_previous_message_has_no_tool_use() -> None: + """Raises when tool_result follows a message that has content but no tool_use.""" + messages = [ + SamplingMessage(role="assistant", content=TextContent(type="text", text="just text")), + SamplingMessage(role="user", content=ToolResultContent(type="tool_result", tool_use_id="tool-1")), + ] + with pytest.raises(ValueError, match="do not match any tool_use in the previous message"): + validate_tool_use_result_messages(messages) + + def test_validate_tool_use_result_messages_raises_when_tool_result_ids_dont_match_tool_use() -> None: """Raises when tool_result IDs don't match tool_use IDs.""" messages = [ diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index d71b013573..745f4b3875 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -17,7 +17,7 @@ from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnRequest, Outbound from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext -from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT +from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT, ErrorData, Tool from .conftest import PairFactory, direct_pair @@ -98,6 +98,23 @@ async def on_request( assert exc.value.error.message == "bad cursor" +@pytest.mark.anyio +async def test_send_raw_request_maps_validation_error_to_invalid_params(pair_factory: PairFactory): + """A pydantic `ValidationError` from the handler surfaces as the + normalized INVALID_PARAMS shape on every dispatcher.""" + + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + Tool.model_validate({"name": 123}) # raises ValidationError + raise NotImplementedError + + async with running_pair(pair_factory, server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error == ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + + @pytest.mark.anyio async def test_send_raw_request_with_timeout_raises_mcperror_request_timeout(pair_factory: PairFactory): async def on_request( diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index d7c6cc3b5f..8a53b0819d 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -23,19 +23,6 @@ ) -@pytest.mark.anyio -async def test_in_flight_requests_cleared_after_completion(): - """Verify that _in_flight is empty after all requests complete.""" - server = Server(name="test server") - async with Client(server) as client: - # Send a request and wait for response - response = await client.send_ping() - assert isinstance(response, EmptyResult) - - # Verify _in_flight is empty - assert len(client.session._in_flight) == 0 - - @pytest.mark.anyio async def test_request_cancellation(): """Test that requests can be cancelled while in-flight.""" From 50134cb29a6bae4dc97563e233f634f1e87b3f3d Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 3 Jun 2026 12:39:11 +0000 Subject: [PATCH 47/52] test: close remaining server-side coverage gaps; ServerSession.check_client_capability delegates server/session.py: check_client_capability now delegates to Connection.check_capability instead of duplicating it. Connection's version gains the sampling.context/sampling.tools sub-checks and experimental value-equality so the delegation is complete. server/runner.py: otel_middleware sets jsonrpc.request.id unconditionally (DispatchMiddleware wraps on_request only; JSONRPCRequest.id is required, so the None guard was dead). tests/server/test_session.py: re-created for the new ServerSession(dispatcher, connection) shape - covers send_request timeout/progress_callback opts paths and the create_message tools branch. tests/server/test_server_context.py: assert Context.session_id and Context.headers. --- src/mcp/server/connection.py | 13 +++- src/mcp/server/runner.py | 5 +- src/mcp/server/session.py | 31 +------- tests/server/test_connection.py | 18 +++++ tests/server/test_runner.py | 2 +- tests/server/test_server_context.py | 4 +- tests/server/test_session.py | 117 ++++++++++++++++++++++++++++ 7 files changed, 151 insertions(+), 39 deletions(-) create mode 100644 tests/server/test_session.py diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py index 1c7ee67412..bca9f2e875 100644 --- a/src/mcp/server/connection.py +++ b/src/mcp/server/connection.py @@ -155,14 +155,19 @@ def check_capability(self, capability: ClientCapabilities) -> bool: return False if capability.roots.list_changed and not have.roots.list_changed: return False - if capability.sampling is not None and have.sampling is None: - return False + if capability.sampling is not None: + if have.sampling is None: + return False + if capability.sampling.context is not None and have.sampling.context is None: + return False + if capability.sampling.tools is not None and have.sampling.tools is None: + return False if capability.elicitation is not None and have.elicitation is None: return False if capability.experimental is not None: if have.experimental is None: return False - for k in capability.experimental: - if k not in have.experimental: + for k, v in capability.experimental.items(): + if k not in have.experimental or have.experimental[k] != v: return False return True diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 6bb0738c68..539e9078c5 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -94,9 +94,8 @@ async def wrapped( case _: parent = None span_name = f"MCP handle {method}{f' {target}' if target else ''}" - attributes: dict[str, str | int] = {"mcp.method.name": method} - if dctx.request_id is not None: - attributes["jsonrpc.request.id"] = dctx.request_id + # `otel_middleware` wraps `on_request` only, so `request_id` is always set. + attributes = {"mcp.method.name": method, "jsonrpc.request.id": str(dctx.request_id)} with otel_span( span_name, kind=SpanKind.SERVER, diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 64e98ec46f..03ed8ef888 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -91,36 +91,7 @@ async def send_notification( def check_client_capability(self, capability: types.ClientCapabilities) -> bool: """Check if the client supports a specific capability.""" - if self.client_params is None: # pragma: lax no cover - return False - - client_caps = self.client_params.capabilities - - if capability.roots is not None: # pragma: lax no cover - if client_caps.roots is None: - return False - if capability.roots.list_changed and not client_caps.roots.list_changed: - return False - - if capability.sampling is not None: # pragma: lax no cover - if client_caps.sampling is None: - return False - if capability.sampling.context is not None and client_caps.sampling.context is None: - return False - if capability.sampling.tools is not None and client_caps.sampling.tools is None: - return False - - if capability.elicitation is not None and client_caps.elicitation is None: # pragma: lax no cover - return False - - if capability.experimental is not None: # pragma: lax no cover - if client_caps.experimental is None: - return False - for exp_key, exp_value in capability.experimental.items(): - if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value: - return False - - return True + return self._connection.check_capability(capability) async def send_log_message( self, diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py index 7d1c5dc04a..19f55ce6a0 100644 --- a/tests/server/test_connection.py +++ b/tests/server/test_connection.py @@ -28,6 +28,8 @@ PingRequest, RootsCapability, SamplingCapability, + SamplingContextCapability, + SamplingToolsCapability, ) @@ -194,8 +196,24 @@ def test_connection_check_capability_false_before_initialized(): False, ), (ClientCapabilities(sampling=None), ClientCapabilities(sampling=SamplingCapability()), False), + ( + ClientCapabilities(sampling=SamplingCapability()), + ClientCapabilities(sampling=SamplingCapability(context=SamplingContextCapability())), + False, + ), + ( + ClientCapabilities(sampling=SamplingCapability()), + ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())), + False, + ), + ( + ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())), + ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())), + True, + ), (ClientCapabilities(experimental=None), ClientCapabilities(experimental={"a": {}}), False), (ClientCapabilities(experimental={"a": {}}), ClientCapabilities(experimental={"b": {}}), False), + (ClientCapabilities(experimental={"a": {"x": 1}}), ClientCapabilities(experimental={"a": {"x": 2}}), False), (ClientCapabilities(experimental={"a": {}}), ClientCapabilities(experimental={"a": {}}), True), ], ) diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 21154fdcbd..4508b55c43 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -411,7 +411,7 @@ async def call_tool(ctx: Ctx, params: CallToolRequestParams) -> dict[str, Any]: assert span.name == "MCP handle tools/call mytool" assert span.attributes is not None assert span.attributes["mcp.method.name"] == "tools/call" - assert isinstance(span.attributes["jsonrpc.request.id"], int) + assert isinstance(span.attributes["jsonrpc.request.id"], str) assert span.status.status_code == StatusCode.UNSET diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py index 43c2069a87..8971d3d52f 100644 --- a/tests/server/test_server_context.py +++ b/tests/server/test_server_context.py @@ -41,7 +41,7 @@ async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(direct_pair, server_on_request=server_on_request) as (client, server, *_): # Now we have the server dispatcher; build the real Connection bound to it. - conn.__init__(server, has_standalone_channel=True) + conn.__init__(server, has_standalone_channel=True, session_id="sess-1") with anyio.fail_after(5): await client.send_raw_request("t", None) ctx = captured[0] @@ -49,6 +49,8 @@ async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | assert ctx.connection is conn assert ctx.transport.kind == "direct" assert ctx.can_send_request is True + assert ctx.session_id == "sess-1" + assert ctx.headers is None @pytest.mark.anyio diff --git a/tests/server/test_session.py b/tests/server/test_session.py new file mode 100644 index 0000000000..f4d91ee254 --- /dev/null +++ b/tests/server/test_session.py @@ -0,0 +1,117 @@ +"""Tests for `ServerSession`. + +`ServerSession` is a thin proxy over a dispatcher and a `Connection`. Tested +with a stub dispatcher so we can assert what reaches the wire (method, params, +`CallOptions`, related-request-id) without standing up a full transport. +""" + +from collections.abc import Mapping +from typing import Any, cast + +import pytest + +from mcp import types +from mcp.server.connection import Connection +from mcp.server.session import ServerSession +from mcp.shared.dispatcher import CallOptions +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import ServerMessageMetadata +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + ClientCapabilities, + Implementation, + InitializeRequestParams, + SamplingCapability, + SamplingToolsCapability, +) + + +class StubDispatcher: + """Records `send_raw_request` / `notify` calls and returns a canned result.""" + + def __init__(self, result: dict[str, Any] | None = None) -> None: + self.requests: list[tuple[str, Mapping[str, Any] | None, CallOptions | None, Any]] = [] + self.result = result if result is not None else {} + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + *, + _related_request_id: Any = None, + ) -> dict[str, Any]: + self.requests.append((method, params, opts, _related_request_id)) + return self.result + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + +def _make_session(dispatcher: StubDispatcher, *, capabilities: ClientCapabilities | None = None) -> ServerSession: + conn = Connection(dispatcher, has_standalone_channel=True) + if capabilities is not None: + conn.client_params = InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=capabilities, + client_info=Implementation(name="c", version="0"), + ) + # cast: `ServerSession` is typed to take `JSONRPCDispatcher` but only ever + # calls `send_raw_request` / `notify`, so the stub is structurally sufficient. + return ServerSession(cast("JSONRPCDispatcher[Any]", dispatcher), conn) + + +@pytest.mark.anyio +async def test_send_request_forwards_timeout_and_progress_callback_as_call_options(): + dispatcher = StubDispatcher(result={"roots": []}) + session = _make_session(dispatcher) + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + raise NotImplementedError + + result = await session.send_request( + types.ListRootsRequest(), + types.ListRootsResult, + request_read_timeout_seconds=2.5, + metadata=ServerMessageMetadata(related_request_id=7), + progress_callback=on_progress, + ) + assert isinstance(result, types.ListRootsResult) + method, _params, opts, related = dispatcher.requests[0] + assert method == "roots/list" + assert opts == {"timeout": 2.5, "on_progress": on_progress} + assert related == 7 + + +@pytest.mark.anyio +async def test_send_request_omits_call_options_when_none_given(): + dispatcher = StubDispatcher(result={"roots": []}) + session = _make_session(dispatcher) + await session.send_request(types.ListRootsRequest(), types.ListRootsResult) + _method, _params, opts, related = dispatcher.requests[0] + assert opts is None + assert related is None + + +@pytest.mark.anyio +async def test_create_message_with_tools_returns_with_tools_result(): + dispatcher = StubDispatcher(result={"role": "assistant", "content": [{"type": "text", "text": "ok"}], "model": "m"}) + session = _make_session( + dispatcher, capabilities=ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())) + ) + result = await session.create_message( + messages=[types.SamplingMessage(role="user", content=types.TextContent(type="text", text="hi"))], + max_tokens=10, + tools=[types.Tool(name="t", input_schema={"type": "object"})], + ) + assert isinstance(result, types.CreateMessageResultWithTools) + method, params, _opts, _related = dispatcher.requests[0] + assert method == "sampling/createMessage" + assert params is not None and params["tools"][0]["name"] == "t" + + +def test_check_client_capability_delegates_to_connection(): + dispatcher = StubDispatcher() + session = _make_session(dispatcher, capabilities=ClientCapabilities(sampling=SamplingCapability())) + assert session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())) is True + assert session.check_client_capability(ClientCapabilities(experimental={"x": {}})) is False From 513ccc35735fa3e4a03e1a41b67aea7b0c7f6748 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 3 Jun 2026 12:41:43 +0000 Subject: [PATCH 48/52] chore: drop write-only _InFlight.cancelled_by_peer; add assertion to notify-drop test cancelled_by_peer was set by _dispatch_notification but never read; the peer-vs-outer-cancel distinction in _handle_request relies on scope.cancel_called alone (and works because nothing else cancels the per-request scope). test_runner_on_notify_drops_before_init_and_unknown_methods now registers a handler and asserts only the post-init notification reaches it (was assertionless before). --- src/mcp/shared/jsonrpc_dispatcher.py | 2 -- tests/server/test_runner.py | 10 +++++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index ad405b486c..b5b713cd1c 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -102,7 +102,6 @@ class _InFlight(Generic[TransportT]): scope: anyio.CancelScope dctx: _JSONRPCDispatchContext[TransportT] - cancelled_by_peer: bool = False @dataclass @@ -489,7 +488,6 @@ def _dispatch_notification( if msg.method == "notifications/cancelled": match msg.params: case {"requestId": str() | int() as rid} if (in_flight := self._in_flight.get(rid)) is not None: - in_flight.cancelled_by_peer = True in_flight.dctx.cancel_requested.set() if self._peer_cancel_mode == "interrupt": in_flight.scope.cancel() diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 4508b55c43..ab65b8ec87 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -255,11 +255,19 @@ async def on_level(ctx: Ctx, params: SetLevelRequestParams) -> None: @pytest.mark.anyio async def test_runner_on_notify_drops_before_init_and_unknown_methods(server: SrvT): + seen: list[Any] = [] + + async def on_roots(ctx: Ctx, params: NotificationParams | None) -> None: + seen.append(params) + + server.add_notification_handler("notifications/roots/list_changed", NotificationParams, on_roots) async with connected_runner(server, initialized=False) as (client, _): await client.notify("notifications/roots/list_changed", None) # before init: dropped await client.notify("notifications/initialized", None) await client.notify("notifications/unknown", None) # no handler: dropped - # No exception raised; both drops are silent. + await client.notify("notifications/roots/list_changed", None) # post-init: delivered + await anyio.wait_all_tasks_blocked() + assert seen == [None] # only the post-init one reached the handler @pytest.mark.anyio From 5b126bfc5550bb5a85994a3040faa02a93b95c5f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 3 Jun 2026 12:44:14 +0000 Subject: [PATCH 49/52] docs: migration.md entries for the dispatcher swap ServerSession proxy shape, lowlevel _handle_* removal, add_request_handler going public with params_type, raise_exceptions semantics narrowing, BaseSession/RequestResponder server-side cancellation tracking removal. --- docs/migration.md | 92 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 88 insertions(+), 4 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 9850f74cd4..3ba27cf826 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -442,7 +442,7 @@ async def handle_set_logging_level(level: str) -> None: mcp._mcp_server.subscribe_resource()(handle_subscribe) # pyright: ignore[reportPrivateUsage] ``` -In v2, the lowlevel `Server` no longer has decorator methods (handlers are constructor-only), so the equivalent workaround is `_add_request_handler`: +In v2, the lowlevel `Server` no longer has decorator methods (handlers are constructor-only), so the equivalent workaround is `add_request_handler`: **After (v2):** @@ -461,11 +461,11 @@ async def handle_subscribe(ctx: ServerRequestContext, params: SubscribeRequestPa return EmptyResult() -mcp._lowlevel_server._add_request_handler("logging/setLevel", handle_set_logging_level) # pyright: ignore[reportPrivateUsage] -mcp._lowlevel_server._add_request_handler("resources/subscribe", handle_subscribe) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server.add_request_handler("logging/setLevel", SetLevelRequestParams, handle_set_logging_level) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server.add_request_handler("resources/subscribe", SubscribeRequestParams, handle_subscribe) # pyright: ignore[reportPrivateUsage] ``` -This is a private API and may change. A public way to register these handlers on `MCPServer` is planned; until then, use this workaround or use the lowlevel `Server` directly. +`_lowlevel_server` is private and may change. A public way to register these handlers on `MCPServer` is planned; until then, use this workaround or use the lowlevel `Server` directly. ### `MCPServer`'s `Context` logging: `message` renamed to `data`, `extra` removed @@ -620,6 +620,8 @@ ctx: ClientRequestContext server_ctx: ServerRequestContext[LifespanContextT, RequestT] ``` +`ServerRequestContext` is now a standalone dataclass — it no longer subclasses `RequestContext[ServerSession]`. It carries the same fields (`session`, `request_id`, `meta`, `lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`), so handler code is unaffected, but `isinstance(ctx, RequestContext)` checks and `RequestContext[ServerSession]` annotations need updating to `ServerRequestContext`. + The high-level `Context` class (injected into `@mcp.tool()` etc.) similarly dropped its `ServerSessionT` parameter: `Context[ServerSessionT, LifespanContextT, RequestT]` → `Context[LifespanContextT, RequestT]`. Both remaining parameters have defaults, so bare `Context` is usually sufficient: **Before (v1):** @@ -813,6 +815,55 @@ server = Server("my-server", on_list_tools=handle_list_tools) If you need to check whether a handler is registered, track this yourself — there is currently no public introspection API. +### Lowlevel `Server`: `add_request_handler` is now public and takes `params_type` + +The private `_add_request_handler(method, handler)` escape hatch is now the public `add_request_handler(method, params_type, handler)`, alongside a matching `add_notification_handler`. Each takes a `params_type` model that incoming params are validated against before the handler runs. + +```python +# Before (v1 / earlier v2 prereleases) +server._add_request_handler("custom/method", my_handler) + +# After (v2) +server.add_request_handler("custom/method", MyParams, my_handler) +server.add_notification_handler("notifications/custom", MyNotifyParams, my_notify_handler) +``` + +### Lowlevel `Server`: private `_handle_*` dispatch methods removed + +`Server._handle_message`, `_handle_request`, and `_handle_notification` have been removed. The receive loop and per-message dispatch now live in `JSONRPCDispatcher` and `ServerRunner`, which `Server.run()` drives internally. + +These were private, but some users subclassed `Server` and overrode them to intercept requests. Use middleware instead: + +```python +from typing import Any + +from pydantic import BaseModel + +from mcp.server import Server, ServerRequestContext +from mcp.server.context import CallNext, HandlerResult + + +async def logging_middleware( + ctx: ServerRequestContext[Any, Any], method: str, params: BaseModel, call_next: CallNext +) -> HandlerResult: + print(f"handling {method}") + result = await call_next() + print(f"done {method}") + return result + + +server = Server("my-server", on_call_tool=...) +server.middleware.append(logging_middleware) +``` + +For lower-level interception (raw method/params before validation, including unknown methods), use `DispatchMiddleware` from `mcp.shared.dispatcher`. + +### Lowlevel `Server.run(raise_exceptions=True)`: transport errors no longer re-raised + +`raise_exceptions=True` now only governs handler exceptions: an exception raised by an `on_*` handler propagates out of `run()` instead of being converted to a JSON-RPC error response. + +Previously it also re-raised exceptions yielded by the transport onto the read stream (e.g. JSON parse errors). Those are now debug-logged and dropped regardless of `raise_exceptions`. If you relied on `run()` exiting on a transport-level parse error, that no longer happens. + ### Lowlevel `Server`: decorator-based handlers replaced with constructor `on_*` params The lowlevel `Server` class no longer uses decorator methods for handler registration. Instead, handlers are passed as `on_*` keyword arguments to the constructor. @@ -1039,6 +1090,39 @@ from mcp.server import ServerRequestContext # but None in notification handlers ``` +### `ServerSession` is now a thin proxy (no longer a `BaseSession`) + +`ServerSession` no longer subclasses `BaseSession`. It is now a small connection-scoped proxy that exposes `send_request`, `send_notification`, the typed convenience helpers (`create_message`, `elicit_form`, `send_log_message`, `send_tool_list_changed`, ...), `client_params`, and `check_client_capability`. The receive loop, `initialize` handling, and per-request task isolation that previously lived in `ServerSession` have moved to `JSONRPCDispatcher` and `ServerRunner`. + +`ServerSession` is normally constructed for you by `Server.run()` and reached via `ctx.session` in handlers, so most servers are unaffected. If you were constructing or subclassing it directly: + +**Constructor change:** + +```python +# Before (v1) +session = ServerSession(read_stream, write_stream, init_options, stateless=False) + +# After (v2) +session = ServerSession(dispatcher, connection, stateless=False) +# where `dispatcher` is a JSONRPCDispatcher and `connection` is a Connection +``` + +In practice, replace direct `ServerSession` use with `Server.run(read_stream, write_stream, init_options)` and let the framework wire it up. + +**Removed from `mcp.server.session`:** + +- `InitializationState` enum and `ServerSession._initialization_state` — initialization tracking is now on `Connection` (`connection.initialized` is an `anyio.Event`, `connection.client_params` holds the init params). +- `ServerRequestResponder` type alias. +- `ServerSession.incoming_messages` stream — there is no longer a public stream of inbound messages to iterate. Register handlers via the `on_*` constructor params (or `add_request_handler`) and use `Server.middleware` to observe every request. +- `ServerSession.__aenter__` / `__aexit__` — `ServerSession` is no longer an async context manager. +- The private `_receive_loop`, `_received_request`, `_received_notification`, and `_handle_incoming` overrides — there is nothing to override on `ServerSession` anymore. To intercept inbound messages, use `Server.middleware` or `DispatchMiddleware` (see the `_handle_*` removal section above). + +### `BaseSession` / `RequestResponder`: server-side cancellation tracking removed + +`BaseSession._in_flight` and the `RequestResponder` members that supported it (`cancel()`, the `cancelled` and `in_flight` properties, the `on_complete` constructor argument, and the internal `CancelScope`) have been removed. These existed to let `ServerSession` cancel a handler when a `CancelledNotification` arrived; `ServerSession` no longer drives a receive loop, so they were dead code. Inbound-cancellation handling for the server now lives in `JSONRPCDispatcher`. + +`BaseSession` is still used by `ClientSession`, which never relied on these members. `RequestResponder.respond()` is unchanged. + ### Experimental Tasks support removed Tasks (SEP-1686) have been removed from the MCP specification and are no longer part of this SDK. The `mcp.client.experimental`, `mcp.server.experimental`, `mcp.shared.experimental`, and `mcp.server.lowlevel.experimental` modules have been removed, along with all `Task*` types, the `tasks` capability fields, `Tool.execution`, and the `experimental` properties on `ClientSession`, `ServerSession`, `Server`, and `ServerRequestContext`. From 67d61ab6eebdba6e3b9a73adb70a5967e6b0f939 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 3 Jun 2026 19:31:47 +0000 Subject: [PATCH 50/52] fix: address 11 bughunter findings (parity, latent wiring, lifecycle) ServerRunner._dump_result: ErrorData returned by a handler now raises MCPError so it reaches the wire as a JSON-RPC error (was serialized as a success result; the previous Server._handle_request supported this). Validation: by_name=False at every wire-validation boundary (runner, session, peer, _typed_request) so snake_case wire keys are rejected as before. Absent params on requests reach handlers as None (matching the | None annotations) after a required-field check, not as an empty model. Server.run: has_standalone_channel=not stateless (was hardcoded True; made the new NoBackChannelError path inert in stateless SHTTP). ServerRunner: connection.initialized event is set on construction in stateless mode (was never set). exit_stack.aclose() is wrapped so a raising user cleanup callback is logged, not propagated. InMemoryTransport: bounded fallback cancel after EOF aclose(). If user teardown (lifespan __aexit__, exit_stack callbacks) doesn't complete in SERVER_SHUTDOWN_GRACE seconds the task group is cancelled. Healthy path still avoids the gh-106749 throw(). BaseSession._receive_loop: server->client notifications/cancelled is silently consumed again (the deletion let it reach message_handler). JSONRPCDispatcher: courtesy notifications/cancelled on timeout/cancel is tagged with related_request_id so SHTTP routes it onto the per-request stream. The inline_methods branch now spawns via _spawn (so sender contextvars apply) and awaits an Event to preserve ordering. _JSONRPCDispatchContext.notify drops when closed. DispatchContext Protocol: gains can_send_request (predicts whether send_raw_request will raise NoBackChannelError); BaseContext delegates to it. --- src/mcp/client/_memory.py | 43 ++++-- src/mcp/server/_typed_request.py | 2 +- src/mcp/server/connection.py | 8 +- src/mcp/server/context.py | 2 +- src/mcp/server/lowlevel/server.py | 6 +- src/mcp/server/runner.py | 50 +++++-- src/mcp/server/session.py | 2 +- src/mcp/shared/context.py | 8 +- src/mcp/shared/direct_dispatcher.py | 6 +- src/mcp/shared/dispatcher.py | 14 +- src/mcp/shared/exceptions.py | 4 +- src/mcp/shared/jsonrpc_dispatcher.py | 39 +++++- src/mcp/shared/peer.py | 10 +- src/mcp/shared/session.py | 9 ++ tests/client/test_session.py | 33 +++++ tests/client/transports/test_memory.py | 51 +++++++ tests/server/test_connection.py | 12 ++ tests/server/test_runner.py | 170 +++++++++++++++++++++++- tests/server/test_session.py | 11 ++ tests/server/test_stateless_mode.py | 49 ++++++- tests/shared/test_context.py | 20 ++- tests/shared/test_jsonrpc_dispatcher.py | 93 ++++++++++++- tests/shared/test_peer.py | 16 +++ 23 files changed, 608 insertions(+), 50 deletions(-) diff --git a/src/mcp/client/_memory.py b/src/mcp/client/_memory.py index 3813c9e335..05736acbaf 100644 --- a/src/mcp/client/_memory.py +++ b/src/mcp/client/_memory.py @@ -14,6 +14,9 @@ from mcp.server.mcpserver import MCPServer from mcp.shared.memory import create_client_server_memory_streams +SERVER_SHUTDOWN_GRACE = 2.0 +"""Seconds to wait for the in-process server to exit on EOF before cancelling.""" + class InMemoryTransport: """In-memory transport for testing MCP servers without network overhead. @@ -48,29 +51,47 @@ async def _connect(self) -> AsyncIterator[TransportStreams]: client_read, client_write = client_streams server_read, server_write = server_streams - async with anyio.create_task_group() as tg: - # Start server in background - tg.start_soon( - lambda: actual_server.run( + server_done = anyio.Event() + + async def _run_server() -> None: + try: + await actual_server.run( server_read, server_write, actual_server.create_initialization_options(), raise_exceptions=self._raise_exceptions, ) - ) + finally: + server_done.set() + + async with anyio.create_task_group() as tg: + tg.start_soon(_run_server) try: yield client_read, client_write finally: # EOF the server (and our own read side) instead of - # cancelling. The dispatcher's run() cancels its own - # in-flight handlers on read-stream EOF, so cleanup is - # equivalent. Cancelling here would `coro.throw()` into the - # host task, which on CPython 3.11 (gh-106749) drops - # `'call'` trace events for the outer await chain and - # desyncs coverage's CTracer past the test frame. + # cancelling outright. The dispatcher's run() cancels its + # own in-flight handlers on read-stream EOF, so for a + # well-behaved server the task exits naturally and the + # task-group join below is immediate. Cancelling here + # unconditionally would `coro.throw()` into this task, + # which on CPython 3.11 (gh-106749) drops `'call'` trace + # events for the outer await chain and desyncs coverage's + # CTracer past the test frame. await client_write.aclose() await server_write.aclose() + # Backstop: the dispatcher exits on EOF, but the server's + # own teardown (lifespan __aexit__, connection.exit_stack + # callbacks) runs after that and is user code. If it never + # completes the join would hang forever, so bound the wait + # and fall back to cancelling. The healthy path returns + # from wait() without the timeout firing, so the cancel is + # never reached and gh-106749 stays avoided. + with anyio.move_on_after(SERVER_SHUTDOWN_GRACE): + await server_done.wait() + if not server_done.is_set(): + tg.cancel_scope.cancel() async def __aenter__(self) -> TransportStreams: """Connect to the server and return streams for communication.""" diff --git a/src/mcp/server/_typed_request.py b/src/mcp/server/_typed_request.py index ab3ac11803..64b8b8119a 100644 --- a/src/mcp/server/_typed_request.py +++ b/src/mcp/server/_typed_request.py @@ -83,4 +83,4 @@ async def send_request( """ raw = await self.send_raw_request(req.method, dump_params(req.params), opts) cls = result_type if result_type is not None else _RESULT_FOR[type(req)] - return cls.model_validate(raw) + return cls.model_validate(raw, by_name=False) diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py index bca9f2e875..849a74b28b 100644 --- a/src/mcp/server/connection.py +++ b/src/mcp/server/connection.py @@ -43,7 +43,9 @@ class Connection(TypedServerRequestMixin): """Per-client connection state and standalone-stream `Outbound`. Constructed by `ServerRunner` once per connection. The peer-info fields are - `None` until `initialize` completes; `initialized` is set then. + `None` until `initialize` completes; `initialized` is set then. In + stateless deployments the runner sets `initialized` immediately and + peer-info remains `None` (no handshake reaches a stateless connection). """ def __init__(self, outbound: Outbound, *, has_standalone_channel: bool, session_id: str | None = None) -> None: @@ -66,7 +68,9 @@ def __init__(self, outbound: Outbound, *, has_standalone_channel: bool, session_ Push context managers (`await exit_stack.enter_async_context(...)`) or callbacks (`exit_stack.push_async_callback(...)`) from handlers or middleware to register per-connection teardown. Unwound LIFO after - `dispatcher.run()` returns, shielded from cancellation.""" + `dispatcher.run()` returns, shielded from cancellation. Exceptions + raised by callbacks are logged and swallowed; they never propagate + out of `ServerRunner.run()`.""" @property def client_info(self) -> Implementation | None: diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index d6a4aef663..7ca4c63c53 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -142,6 +142,6 @@ async def __call__( self, ctx: ServerRequestContext[_MwLifespanT, Any], method: str, - params: BaseModel, + params: BaseModel | None, call_next: CallNext, ) -> HandlerResult: ... diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 91d2caab6d..5650db7199 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -406,7 +406,11 @@ async def run( dispatcher=dispatcher, lifespan_state=lifespan_context, init_options=initialization_options, - has_standalone_channel=True, + # Stateless HTTP has no standalone GET stream, so server-initiated + # requests on `runner.connection` must fail fast with + # `NoBackChannelError` rather than write to a channel that will + # never deliver a response. + has_standalone_channel=not stateless, stateless=stateless, dispatch_middleware=[otel_middleware], ) diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 539e9078c5..2ec5173e84 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -42,6 +42,7 @@ LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, ClientRequest, + ErrorData, Implementation, InitializeRequestParams, InitializeResult, @@ -120,6 +121,11 @@ async def wrapped( def _dump_result(result: Any) -> dict[str, Any]: if result is None: return {} + if isinstance(result, ErrorData): + # The existing `BaseSession._send_response` treats a handler-returned + # `ErrorData` as a JSON-RPC error, not a success result. Re-raise as + # `MCPError` so the dispatcher's exception boundary emits `JSONRPCError`. + raise MCPError(code=result.code, message=result.message, data=result.data) if isinstance(result, BaseModel): return result.model_dump(by_alias=True, mode="json", exclude_none=True) if isinstance(result, dict): @@ -153,6 +159,11 @@ def __post_init__(self) -> None: self.connection = Connection( self.dispatcher, has_standalone_channel=self.has_standalone_channel, session_id=self.session_id ) + if self.stateless: + # Keep the public event in lockstep with the gate flag so a handler + # awaiting `connection.initialized` does not hang on a stateless + # connection (where no `initialize` exchange ever arrives). + self.connection.initialized.set() self.session = ServerSession(self.dispatcher, self.connection, stateless=self.stateless) async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: @@ -169,7 +180,15 @@ async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STAT await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) finally: with anyio.CancelScope(shield=True): - await self.connection.exit_stack.aclose() + try: + await self.connection.exit_stack.aclose() + except Exception: + # Top-level boundary: a cleanup callback raising must not + # escape `run()` - it would crash stdio servers on a normal + # disconnect and, via raise-in-finally, mask the original + # exception from `dispatcher.run()` (including the + # CancelledError that SHTTP idle-timeout teardown checks). + logger.exception("connection exit_stack cleanup raised") def _compose_on_request(self) -> OnRequest: """Wrap `_on_request` in `dispatch_middleware`, outermost-first. @@ -200,7 +219,7 @@ async def _on_request( payload: dict[str, Any] = {"method": method} if params is not None: payload["params"] = dict(params) - client_request_adapter.validate_python(payload) + client_request_adapter.validate_python(payload, by_name=False) if method == "initialize": return self._handle_initialize(params) if not self._initialized and method not in _INIT_EXEMPT: @@ -212,8 +231,16 @@ async def _on_request( if entry is None: raise MCPError(code=METHOD_NOT_FOUND, message="Method not found") # ValidationError propagates; the dispatcher's exception boundary maps - # it to INVALID_PARAMS. - typed_params = entry.params_type.model_validate(params or {}) + # it to INVALID_PARAMS. Absent wire params reach the handler as None + # (matches the existing `Server._handle_request`, where `req.params` + # is None for optional-params requests like tools/list); the empty-dict + # validate is a required-field check so a required-params model still + # surfaces as INVALID_PARAMS rather than reaching the handler as None. + if params is None: + entry.params_type.model_validate({}, by_name=False) + typed_params = None + else: + typed_params = entry.params_type.model_validate(params, by_name=False) ctx = self._make_context(dctx, typed_params) call: CallNext = partial(entry.handler, ctx, typed_params) for mw in reversed(self.server.middleware): @@ -237,10 +264,17 @@ async def _on_notify( if entry is None: logger.debug("no handler for notification %s", method) return - # Absent wire params reach the handler as `None`, not an empty model - # (matches the existing `Server._handle_notification`). + # Absent wire params reach the handler as None, not an empty model + # (matches the existing `Server._handle_notification`). The empty-dict + # validate is a required-field check: a required-params model (e.g. + # ProgressNotificationParams) takes the malformed-params drop path + # instead of reaching a non-Optional handler as None. try: - typed_params = entry.params_type.model_validate(params) if params is not None else None + if params is None: + entry.params_type.model_validate({}, by_name=False) + typed_params = None + else: + typed_params = entry.params_type.model_validate(params, by_name=False) except ValidationError: logger.warning("dropped %r: malformed params", method) return @@ -279,7 +313,7 @@ def _make_context( ) def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]: - init = InitializeRequestParams.model_validate(params or {}) + init = InitializeRequestParams.model_validate(params or {}, by_name=False) self.connection.client_params = init requested = init.protocol_version negotiated = requested if requested in SUPPORTED_PROTOCOL_VERSIONS else LATEST_PROTOCOL_VERSION diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 03ed8ef888..9016f05a0c 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -78,7 +78,7 @@ async def send_request( result = await self._dispatcher.send_raw_request( data["method"], data.get("params"), opts or None, _related_request_id=related ) - return result_type.model_validate(result) + return result_type.model_validate(result, by_name=False) async def send_notification( self, diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 437f821a81..849054dda0 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -48,8 +48,12 @@ def cancel_requested(self) -> anyio.Event: @property def can_send_request(self) -> bool: - """Whether the back-channel can deliver server-initiated requests.""" - return self._dctx.transport.can_send_request + """Whether the back-channel can currently deliver server-initiated requests. + + `False` when the transport has no back-channel, or when the underlying + dispatch context has been closed because the inbound request finished. + """ + return self._dctx.can_send_request @property def meta(self) -> RequestParamsMeta | None: diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index f252dfad30..1b07b87d77 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -56,6 +56,10 @@ class _DirectDispatchContext: _on_progress: ProgressFnT | None = None cancel_requested: anyio.Event = field(default_factory=anyio.Event) + @property + def can_send_request(self) -> bool: + return self.transport.can_send_request + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: await self._back_notify(method, params) @@ -65,7 +69,7 @@ async def send_raw_request( params: Mapping[str, Any] | None, opts: CallOptions | None = None, ) -> dict[str, Any]: - if not self.transport.can_send_request: + if not self.can_send_request: raise NoBackChannelError(method) return await self._back_request(method, params, opts) diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index aca96231f3..9dfb24940a 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -100,8 +100,8 @@ class DispatchContext(Outbound, Protocol[TransportT_co]): Carries the transport metadata for the inbound message and provides the back-channel for sending requests/notifications to the peer while handling - it. `send_raw_request` raises `NoBackChannelError` if - `transport.can_send_request` is `False`. + it. `send_raw_request` raises `NoBackChannelError` if `can_send_request` + is `False`. """ @property @@ -109,6 +109,16 @@ def transport(self) -> TransportT_co: """Transport-specific metadata for this inbound message.""" ... + @property + def can_send_request(self) -> bool: + """Whether the back-channel can currently deliver server-initiated requests. + + `False` when the transport has no back-channel, or when this context has + been closed (the inbound request finished). `send_raw_request` raises + `NoBackChannelError` exactly when this is `False`. + """ + ... + @property def request_id(self) -> RequestId | None: """The id of the inbound request, or `None` for a notification. diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index f81b737cc1..bb4cfc0d00 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -46,8 +46,8 @@ class NoBackChannelError(MCPError): Stateless HTTP and JSON-response-mode HTTP have no channel for the server to push requests (sampling, elicitation, roots/list) to the client. This is - raised by `DispatchContext.send_raw_request` when `transport.can_send_request` - is `False`, and serializes to an `INVALID_REQUEST` error response. + raised by `DispatchContext.send_raw_request` when `can_send_request` is + `False`, and serializes to an `INVALID_REQUEST` error response. """ def __init__(self, method: str): diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index b5b713cd1c..55eba7486a 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -131,6 +131,9 @@ def can_send_request(self) -> bool: return self.transport.can_send_request and not self._closed async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + if self._closed: + logger.debug("dropped %s: dispatch context closed", method) + return await self._dispatcher.notify(method, params, _related_request_id=self._request_id) async def send_raw_request( @@ -252,7 +255,9 @@ def __init__( methods whose side effects must be observable to the next message, e.g. `initialize`, so a pipelined follow-up sees the initialized state. Only suitable for handlers that complete quickly, since inline handling - blocks dequeuing.""" + blocks dequeuing; a handler that awaits the peer (`send_raw_request`) + while inline will deadlock because the parked read loop cannot dequeue + the response.""" self._next_id = 0 self._pending: dict[RequestId, _Pending] = {} @@ -333,14 +338,14 @@ async def send_raw_request( # Spec-recommended courtesy: tell the peer we've given up so it can # stop work and free resources. v1's BaseSession.send_request does # NOT do this; it's new behaviour. - await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s") + await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s", _related_request_id) raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None except anyio.get_cancelled_exc_class(): # Our caller's scope was cancelled. We're already inside a cancelled # scope, so any bare `await` here re-raises immediately - shield to # let the courtesy cancel notification go out before we propagate. with anyio.CancelScope(shield=True): - await self._cancel_outbound(request_id, "caller cancelled") + await self._cancel_outbound(request_id, "caller cancelled", _related_request_id) raise finally: # Always remove the waiter, even on cancel/timeout, so a late @@ -474,7 +479,21 @@ async def _dispatch_request( scope = anyio.CancelScope() self._in_flight[req.id] = _InFlight(scope=scope, dctx=dctx) if req.method in self._inline_methods: - await self._handle_request(req, dctx, scope, on_request) + # Spawn (so `sender_ctx` applies, matching the concurrent path) but + # park the read loop until the handler returns; that's the inline + # ordering guarantee. Because the read loop is parked, a handler + # that awaits the peer here (e.g. `dctx.send_raw_request`) will + # deadlock: the response can never be dequeued. + done = anyio.Event() + + async def _run_inline() -> None: + try: + await self._handle_request(req, dctx, scope, on_request) + finally: + done.set() + + self._spawn(_run_inline, sender_ctx=sender_ctx) + await done.wait() else: self._spawn(self._handle_request, req, dctx, scope, on_request, sender_ctx=sender_ctx) @@ -642,8 +661,16 @@ async def _write_error(self, request_id: RequestId, error: ErrorData) -> None: except (anyio.BrokenResourceError, anyio.ClosedResourceError): logger.debug("dropped error for %r: write stream closed", request_id) - async def _cancel_outbound(self, request_id: RequestId, reason: str) -> None: + async def _cancel_outbound(self, request_id: RequestId, reason: str, related_request_id: RequestId | None) -> None: + # Thread `related_request_id` so streamable-HTTP routes the cancel onto + # the same per-request SSE stream as the request it cancels; without it + # the notification falls through to the standalone GET stream and is + # dropped when no GET stream is open. try: - await self.notify("notifications/cancelled", {"requestId": request_id, "reason": reason}) + await self.notify( + "notifications/cancelled", + {"requestId": request_id, "reason": reason}, + _related_request_id=related_request_id, + ) except (anyio.BrokenResourceError, anyio.ClosedResourceError): pass diff --git a/src/mcp/shared/peer.py b/src/mcp/shared/peer.py index a7347e30cc..25ec112b02 100644 --- a/src/mcp/shared/peer.py +++ b/src/mcp/shared/peer.py @@ -131,8 +131,8 @@ async def sample( ) result = await self.send_raw_request("sampling/createMessage", dump_params(params, meta), opts) if tools is not None: - return CreateMessageResultWithTools.model_validate(result) - return CreateMessageResult.model_validate(result) + return CreateMessageResultWithTools.model_validate(result, by_name=False) + return CreateMessageResult.model_validate(result, by_name=False) async def elicit_form( self: Outbound, @@ -150,7 +150,7 @@ async def elicit_form( """ params = ElicitRequestFormParams(message=message, requested_schema=requested_schema) result = await self.send_raw_request("elicitation/create", dump_params(params, meta), opts) - return ElicitResult.model_validate(result) + return ElicitResult.model_validate(result, by_name=False) async def elicit_url( self: Outbound, @@ -169,7 +169,7 @@ async def elicit_url( """ params = ElicitRequestURLParams(message=message, url=url, elicitation_id=elicitation_id) result = await self.send_raw_request("elicitation/create", dump_params(params, meta), opts) - return ElicitResult.model_validate(result) + return ElicitResult.model_validate(result, by_name=False) async def list_roots( self: Outbound, *, meta: Meta | None = None, opts: CallOptions | None = None @@ -181,7 +181,7 @@ async def list_roots( NoBackChannelError: No back-channel for server-initiated requests. """ result = await self.send_raw_request("roots/list", dump_params(None, meta), opts) - return ListRootsResult.model_validate(result) + return ListRootsResult.model_validate(result, by_name=False) async def ping(self: Outbound, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None: """Send a `ping` request and ignore the result. diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 376906b49c..63498ca338 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -20,6 +20,7 @@ CONNECTION_CLOSED, INVALID_PARAMS, REQUEST_TIMEOUT, + CancelledNotification, ClientNotification, ClientRequest, ClientResult, @@ -326,6 +327,14 @@ async def _handle_session_message(message: SessionMessage) -> None: message.message.model_dump(by_alias=True, mode="json", exclude_none=True), by_name=False, ) + if isinstance(notification, CancelledNotification): + # ClientSession runs server-initiated requests + # inline in this loop, so by the time a peer + # cancellation is read there is nothing left to + # cancel. Consume it here so message_handler + # keeps the contract it had before the + # dispatcher swap removed _in_flight. + return # Handle progress notifications callback if isinstance(notification, ProgressNotification): progress_token = notification.params.progress_token diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 281c568b96..28d212d007 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -807,6 +807,39 @@ async def handler(msg: object) -> None: assert seen == [exc] +@pytest.mark.anyio +async def test_receive_loop_consumes_server_cancelled_without_reaching_message_handler(): + """A server-sent notifications/cancelled is swallowed, matching the pre-swap contract. + + The server dispatcher now emits this on sampling/elicitation timeout, but + ClientSession has no in-flight tracking to act on it, so surfacing it would + only break user handlers that exhaustively match ServerNotification. + """ + seen: list[object] = [] + delivered = anyio.Event() + + async def handler(msg: object) -> None: + seen.append(msg) + delivered.set() + + async with raw_client_session(message_handler=handler) as (_session, to_client, _): + await to_client.send( + SessionMessage( + JSONRPCNotification( + jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 1, "reason": "timed out"} + ) + ) + ) + # Follow with a notification that does reach the handler so we can + # assert ordering deterministically. + await to_client.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/tools/list_changed")) + ) + await delivered.wait() + assert len(seen) == 1 + assert isinstance(seen[0], types.ToolListChangedNotification) + + @pytest.mark.anyio async def test_receive_loop_swallows_progress_callback_exception(caplog: pytest.LogCaptureFixture): delivered = anyio.Event() diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index c8fc41fd5d..8baee128b5 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -1,8 +1,15 @@ """Tests for InMemoryTransport.""" +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import anyio +import anyio.lowlevel import pytest from mcp import Client, types +from mcp.client import _memory from mcp.client._memory import InMemoryTransport from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer @@ -95,3 +102,47 @@ async def test_raise_exceptions(mcpserver_server: MCPServer): transport = InMemoryTransport(mcpserver_server, raise_exceptions=True) async with transport as (read_stream, _write_stream): assert read_stream is not None + + +async def test_aexit_with_well_behaved_lifespan_runs_teardown_without_cancel(): + """A lifespan that finishes promptly on EOF should run to completion. + + The transport closes the streams first and waits for the server to exit + naturally, so teardown observes no cancellation. + """ + teardown_ran = anyio.Event() + + @asynccontextmanager + async def lifespan(_: Server[Any]) -> AsyncIterator[dict[str, Any]]: + yield {} + await anyio.lowlevel.checkpoint() + teardown_ran.set() + + server = Server(name="test_server", lifespan=lifespan) + with anyio.fail_after(5): + async with InMemoryTransport(server): + pass + assert teardown_ran.is_set() + + +async def test_aexit_with_blocking_lifespan_is_bounded(monkeypatch: pytest.MonkeyPatch): + """A lifespan that never returns must not hang `__aexit__` forever. + + After EOFing the server the transport waits `SERVER_SHUTDOWN_GRACE` for a + natural exit, then cancels the server task as a backstop so the + task-group join completes. + """ + monkeypatch.setattr(_memory, "SERVER_SHUTDOWN_GRACE", 0.05) + teardown_started = anyio.Event() + + @asynccontextmanager + async def blocking_lifespan(_: Server[Any]) -> AsyncIterator[dict[str, Any]]: + yield {} + teardown_started.set() + await anyio.Event().wait() + + server = Server(name="test_server", lifespan=blocking_lifespan) + with anyio.fail_after(5): + async with InMemoryTransport(server): + pass + assert teardown_started.is_set() diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py index 19f55ce6a0..b8378574ef 100644 --- a/tests/server/test_connection.py +++ b/tests/server/test_connection.py @@ -19,6 +19,8 @@ from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientCapabilities, + CreateMessageRequest, + CreateMessageRequestParams, ElicitationCapability, EmptyResult, Implementation, @@ -116,6 +118,16 @@ async def test_connection_send_request_with_spec_type_infers_result_type(): assert str(result.roots[0].uri) == "file:///ws" +@pytest.mark.anyio +async def test_connection_send_request_validates_result_alias_only(): + """Peer results validate alias-only; a snake_case key from the wire is + ignored as extra, not populated by Python field name.""" + snake = {"role": "assistant", "content": {"type": "text", "text": "x"}, "model": "m", "stop_reason": "endTurn"} + conn = Connection(StubOutbound(result=snake), has_standalone_channel=True) + result = await conn.send_request(CreateMessageRequest(params=CreateMessageRequestParams(messages=[], max_tokens=1))) + assert result.stop_reason is None + + @pytest.mark.anyio async def test_connection_send_request_with_result_type_kwarg_validates_custom_type(): out = StubOutbound(result={}) diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index ab65b8ec87..403cf3d15a 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -19,7 +19,7 @@ from mcp.server.models import InitializationOptions from mcp.server.runner import ServerRunner, otel_middleware from mcp.server.session import ServerSession -from mcp.shared.dispatcher import DispatchMiddleware +from mcp.shared.dispatcher import DispatchContext, DispatchMiddleware, OnRequest from mcp.shared.exceptions import MCPError from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher from mcp.shared.transport_context import TransportContext @@ -36,6 +36,7 @@ ListToolsResult, NotificationParams, PaginatedRequestParams, + ProgressNotificationParams, RequestParams, SetLevelRequestParams, Tool, @@ -191,6 +192,52 @@ async def test_runner_malformed_params_for_unregistered_spec_method_raises_inval assert exc.value.error == ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") +@pytest.mark.anyio +async def test_runner_rejects_snake_case_initialize_params(server: SrvT): + """Inbound wire payloads validate alias-only; Python field names are not + accepted (`protocol_version` must arrive as `protocolVersion`).""" + snake = { + "protocol_version": LATEST_PROTOCOL_VERSION, + "capabilities": {}, + "client_info": {"name": "c", "version": "0"}, + } + async with connected_runner(server, initialized=False) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("initialize", snake) + assert exc.value.error.code == INVALID_PARAMS + + +@pytest.mark.anyio +async def test_runner_rejects_snake_case_params_for_custom_handler(server: SrvT): + """Custom-method handlers (which skip the spec-method gate) still validate + alias-only at the per-handler boundary.""" + + async def handler(ctx: Ctx, params: ProgressNotificationParams) -> dict[str, Any]: + return {"ok": True} + + server.add_request_handler("custom/progress", ProgressNotificationParams, handler) + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("custom/progress", {"progress_token": 1, "progress": 0.5}) + assert exc.value.error.code == INVALID_PARAMS + result = await client.send_raw_request("custom/progress", {"progressToken": 1, "progress": 0.5}) + assert result == {"ok": True} + + +@pytest.mark.anyio +async def test_runner_on_notify_drops_snake_case_params(server: SrvT, caplog: pytest.LogCaptureFixture): + """Notification params validate alias-only; snake_case is dropped as malformed.""" + + async def handler(ctx: Ctx, params: ProgressNotificationParams) -> None: + raise NotImplementedError + + server.add_notification_handler("notifications/roots/list_changed", ProgressNotificationParams, handler) + async with connected_runner(server) as (client, _): + await client.notify("notifications/roots/list_changed", {"progress_token": 1, "progress": 0.5}) + await client.send_raw_request("tools/list", None) + assert "dropped 'notifications/roots/list_changed': malformed params" in caplog.text + + @pytest.mark.anyio async def test_runner_on_notify_initialized_sets_flag_and_connection_event(server: SrvT): async with connected_runner(server, initialized=False) as (client, runner): @@ -253,6 +300,80 @@ async def on_level(ctx: Ctx, params: SetLevelRequestParams) -> None: assert "dropped 'notifications/roots/list_changed': malformed params" in caplog.text +@pytest.mark.anyio +async def test_runner_on_notify_drops_absent_params_when_model_requires_them( + server: SrvT, caplog: pytest.LogCaptureFixture +): + """A params-less progress notification is dropped, not delivered as None. + + `on_progress` is typed to receive a non-Optional `ProgressNotificationParams`; + the previous server validated the full notification union and dropped this + as malformed before dispatch. + """ + + async def on_progress(ctx: Ctx, params: ProgressNotificationParams) -> None: + raise NotImplementedError + + server.add_notification_handler("notifications/progress", ProgressNotificationParams, on_progress) + async with connected_runner(server) as (client, _): + await client.notify("notifications/progress", None) + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + assert "dropped 'notifications/progress': malformed params" in caplog.text + assert "notification handler for" not in caplog.text + + +@pytest.mark.anyio +async def test_runner_absent_wire_params_reaches_request_handler_as_none(): + """A request with no `params` member on the wire reaches the handler as + `None`, matching the previous server and the `| None` handler annotation. + + The in-SDK client always attaches `_meta`, so a dispatch middleware + forwards `params=None` to model what an external client sends. + """ + seen: list[PaginatedRequestParams | None] = [] + + async def list_tools(ctx: Ctx, params: PaginatedRequestParams | None) -> ListToolsResult: + seen.append(params) + return ListToolsResult(tools=[]) + + def drop_params(next_on_request: OnRequest) -> OnRequest: + async def wrapped(dctx: DispatchContext[Any], method: str, params: Any) -> dict[str, Any]: + return await next_on_request(dctx, method, None if method == "tools/list" else params) + + return wrapped + + server: SrvT = Server(name="s", on_list_tools=list_tools) + async with connected_runner(server, dispatch_middleware=[drop_params]) as (client, _): + await client.send_raw_request("tools/list", None) + assert seen == [None] + + +@pytest.mark.anyio +async def test_runner_absent_wire_params_for_required_params_custom_method_is_invalid_params(): + """A custom method whose `params_type` has required fields rejects absent + wire params as INVALID_PARAMS rather than invoking the handler with None.""" + + class GreetParams(RequestParams): + name: str + + async def greet(ctx: Ctx, params: GreetParams) -> dict[str, Any]: + raise NotImplementedError + + def drop_params(next_on_request: OnRequest) -> OnRequest: + async def wrapped(dctx: DispatchContext[Any], method: str, params: Any) -> dict[str, Any]: + return await next_on_request(dctx, method, None if method == "custom/greet" else params) + + return wrapped + + server: SrvT = Server(name="s") + server.add_request_handler("custom/greet", GreetParams, greet) + async with connected_runner(server, dispatch_middleware=[drop_params]) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("custom/greet", {"name": "x"}) + assert exc.value.error.code == INVALID_PARAMS + + @pytest.mark.anyio async def test_runner_on_notify_drops_before_init_and_unknown_methods(server: SrvT): seen: list[Any] = [] @@ -332,6 +453,21 @@ async def set_level(ctx: Ctx, params: SetLevelRequestParams) -> None: assert result == {} +@pytest.mark.anyio +async def test_runner_handler_returning_error_data_produces_jsonrpc_error(server: SrvT): + """A handler returning `ErrorData` reaches the client as a JSON-RPC error, + not a success result, matching `BaseSession._send_response`.""" + + async def set_level(ctx: Ctx, params: SetLevelRequestParams) -> ErrorData: + return ErrorData(code=INVALID_PARAMS, message="bad level", data={"got": params.level}) + + server.add_request_handler("logging/setLevel", SetLevelRequestParams, set_level) + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("logging/setLevel", {"level": "info"}) + assert exc.value.error == ErrorData(code=INVALID_PARAMS, message="bad level", data={"got": "info"}) + + @pytest.mark.anyio async def test_runner_handler_returning_unsupported_type_surfaces_as_error(server: SrvT): async def bad_return(ctx: Ctx, params: PaginatedRequestParams | None) -> int: @@ -354,6 +490,17 @@ async def test_runner_stateless_skips_init_gate(server: SrvT): assert result["tools"][0]["name"] == "t" +@pytest.mark.anyio +async def test_runner_stateless_connection_initialized_event_set_on_construction(server: SrvT): + """`connection.initialized` mirrors the gate flag in stateless mode so + `await connection.initialized.wait()` does not hang when no handshake + arrives.""" + async with connected_runner(server, initialized=False, stateless=True, has_standalone_channel=False) as (_, runner): + assert runner._initialized is True + assert runner.connection.initialized.is_set() + await runner.connection.initialized.wait() + + @pytest.mark.anyio async def test_server_add_request_handler_routes_custom_method_with_validated_params(server: SrvT): """Custom methods outside the spec `ClientRequest` union skip upfront @@ -485,3 +632,24 @@ async def _append(i: int) -> None: await client.send_raw_request("tools/list", None) assert cleaned == [] assert cleaned == [3, 2, 1] + + +@pytest.mark.anyio +async def test_runner_exit_stack_cleanup_exception_is_logged_not_propagated( + server: SrvT, caplog: pytest.LogCaptureFixture +) -> None: + """A raising cleanup callback is caught and logged; `run()` exits cleanly.""" + cleaned: list[str] = [] + + async def _ok() -> None: + cleaned.append("ok") + + async def _boom() -> None: + raise RuntimeError("cleanup failed") + + async with connected_runner(server) as (client, runner): + runner.connection.exit_stack.push_async_callback(_ok) + runner.connection.exit_stack.push_async_callback(_boom) + await client.send_raw_request("tools/list", None) + assert cleaned == ["ok"] + assert "connection exit_stack cleanup raised" in caplog.text diff --git a/tests/server/test_session.py b/tests/server/test_session.py index f4d91ee254..c77ac8a42c 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -93,6 +93,17 @@ async def test_send_request_omits_call_options_when_none_given(): assert related is None +@pytest.mark.anyio +async def test_send_request_validates_result_alias_only(): + """Peer results validate alias-only; a snake_case key from the wire is + ignored as extra, not populated by Python field name.""" + snake = {"role": "assistant", "content": {"type": "text", "text": "x"}, "model": "m", "stop_reason": "endTurn"} + session = _make_session(StubDispatcher(result=snake)) + request = types.CreateMessageRequest(params=types.CreateMessageRequestParams(messages=[], max_tokens=1)) + result = await session.send_request(request, types.CreateMessageResult) + assert result.stop_reason is None + + @pytest.mark.anyio async def test_create_message_with_tools_returns_with_tools_result(): dispatcher = StubDispatcher(result={"role": "assistant", "content": [{"type": "text", "text": "ok"}], "model": "m"}) diff --git a/tests/server/test_stateless_mode.py b/tests/server/test_stateless_mode.py index abe92062c5..1b628e2388 100644 --- a/tests/server/test_stateless_mode.py +++ b/tests/server/test_stateless_mode.py @@ -10,13 +10,18 @@ from typing import Any from unittest.mock import Mock +import anyio import pytest from mcp import types from mcp.server.connection import Connection +from mcp.server.context import ServerRequestContext +from mcp.server.lowlevel.server import Server from mcp.server.session import ServerSession -from mcp.shared.exceptions import StatelessModeNotSupported +from mcp.shared.exceptions import NoBackChannelError, StatelessModeNotSupported from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import SessionMessage +from mcp.types import JSONRPCRequest, JSONRPCResponse, ListToolsResult, PaginatedRequestParams def _make_session(*, stateless: bool) -> ServerSession: @@ -138,3 +143,45 @@ async def mock_send_request(*_: Any, **__: Any) -> types.ListRootsResult: assert send_request_called assert isinstance(result, types.ListRootsResult) + + +@pytest.mark.anyio +async def test_server_run_stateless_wires_no_standalone_channel(): + """`Server.run(stateless=True)` must wire `Connection.has_standalone_channel=False`. + + Stateless HTTP has no standalone GET stream, so server-initiated requests on + the connection must fail fast with `NoBackChannelError` rather than write to + a channel that will never deliver a response. The `ServerSession` typed + helpers carry their own stateless guard (tested above); this pins the + `Connection` wiring that `Server.run` produces. + """ + captured: list[Connection] = [] + + async def list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + # `ServerRequestContext` doesn't expose `connection` directly yet (it + # will after the Context rework); reach it via the session for now. + captured.append(ctx.session._connection) # pyright: ignore[reportPrivateUsage] + return ListToolsResult(tools=[]) + + server: Server[Any] = Server("test", on_list_tools=list_tools) + + to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) + server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run(server_read, server_write, server.create_initialization_options(), stateless=True) + + async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server: + tg.start_soon(run_server) + # stateless=True skips the init gate, so tools/list routes immediately. + await to_server.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/list"))) + with anyio.fail_after(5): + response = (await from_server.receive()).message + assert isinstance(response, JSONRPCResponse) + tg.cancel_scope.cancel() + + assert len(captured) == 1 + conn = captured[0] + assert conn.has_standalone_channel is False + with pytest.raises(NoBackChannelError): + await conn.ping() diff --git a/tests/shared/test_context.py b/tests/shared/test_context.py index c260c9be80..68057a9e10 100644 --- a/tests/shared/test_context.py +++ b/tests/shared/test_context.py @@ -16,7 +16,7 @@ from mcp.shared.peer import Peer from mcp.shared.transport_context import TransportContext -from .conftest import direct_pair +from .conftest import direct_pair, jsonrpc_pair from .test_dispatcher import Recorder, echo_handlers, running_pair DCtx = DispatchContext[TransportContext] @@ -41,6 +41,24 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | assert bctx.meta is None +@pytest.mark.anyio +async def test_base_context_can_send_request_reflects_dispatch_context_closed_state(): + """`can_send_request` must track the dctx, not the static transport flag, + so it agrees with whether `send_raw_request` would raise.""" + captured: list[BaseContext[TransportContext]] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + captured.append(BaseContext(ctx)) + return {} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + bctx = captured[0] + assert bctx.transport.can_send_request is True + assert bctx.can_send_request is False + + @pytest.mark.anyio async def test_base_context_send_raw_request_and_notify_forward_to_dispatch_context(): crec = Recorder() diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index baea7f4b9a..68491c2620 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -16,7 +16,7 @@ from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream from mcp.shared.dispatcher import CallOptions, DispatchContext -from mcp.shared.exceptions import MCPError +from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.jsonrpc_dispatcher import ( # pyright: ignore[reportPrivateUsage] JSONRPCDispatcher, _coerce_id, @@ -354,6 +354,56 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> s.close() +@pytest.mark.anyio +async def test_courtesy_cancel_on_timeout_tags_outbound_with_server_message_metadata(): + """The timeout-path `notifications/cancelled` carries the originating request id. + + Streamable-HTTP's `message_router` keys on `ServerMessageMetadata.related_request_id`; + a cancel without it would fall through to the standalone GET stream and be dropped + when no GET stream is open, so the client never learns to stop work. + """ + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + with pytest.raises(MCPError): # REQUEST_TIMEOUT + await ctx.send_raw_request("sampling/createMessage", None, {"timeout": 0}) + return {"gave_up": True} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="t", params=None))) + with anyio.fail_after(5): + outbound = await s2c_recv.receive() + assert isinstance(outbound, SessionMessage) + assert isinstance(outbound.message, JSONRPCRequest) + assert outbound.message.method == "sampling/createMessage" + sampling_id = outbound.message.id + # Don't respond; let the timeout fire. Next on the wire is the courtesy cancel. + with anyio.fail_after(5): + cancel = await s2c_recv.receive() + assert isinstance(cancel, SessionMessage) + assert isinstance(cancel.message, JSONRPCNotification) + assert cancel.message.method == "notifications/cancelled" + assert cancel.message.params == {"requestId": sampling_id, "reason": "timed out after 0s"} + assert isinstance(cancel.metadata, ServerMessageMetadata) + assert cancel.metadata.related_request_id == 7 + with anyio.fail_after(5): + final = await s2c_recv.receive() + assert isinstance(final, SessionMessage) + assert isinstance(final.message, JSONRPCResponse) + assert final.message.result == {"gave_up": True} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + @pytest.mark.anyio async def test_ctx_message_metadata_carries_inbound_request_metadata(): """Transport-attached metadata (HTTP request, SSE close hooks) is readable off the dispatch context.""" @@ -442,6 +492,39 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | assert received == [(0.25, None, None)] +@pytest.mark.anyio +async def test_ctx_after_handler_return_reports_closed_and_drops_backchannel_traffic(): + """Once `_handle_request` closes the dctx, the back-channel guard and ops agree. + + Detached work that outlives the handler must see `can_send_request == False`, + get `NoBackChannelError` from `send_raw_request`, and have `notify`/`progress` + silently dropped rather than emitted with a stale `related_request_id`. + """ + captured: list[DCtx] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + captured.append(ctx) + assert ctx.can_send_request is True + return {} + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + raise NotImplementedError + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, _server, crec, _srec): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None, {"on_progress": on_progress}) + dctx = captured[0] + assert dctx.can_send_request is False + with pytest.raises(NoBackChannelError): + await dctx.send_raw_request("sampling/createMessage", None) + await dctx.notify("notifications/message", {"level": "info"}) + await dctx.progress(0.9) + # A second round-trip flushes any notification the server might have + # written, so an empty client recorder afterwards proves the drop. + await client.send_raw_request("ping", None) + assert crec.notifications == [] + + @pytest.mark.anyio async def test_progress_callback_exception_is_swallowed_and_logged(caplog: pytest.LogCaptureFixture): """A user progress callback raising must not crash the dispatcher.""" @@ -596,13 +679,15 @@ async def test_cancelled_notification_for_unknown_request_id_is_noop(): @pytest.mark.anyio -async def test_handler_inherits_sender_contextvars_via_spawn(): - """The handler task sees contextvars set by the task that wrote into the read stream.""" +@pytest.mark.parametrize("inline", [frozenset[str](), frozenset({"t"})], ids=["spawned", "inline"]) +async def test_handler_inherits_sender_contextvars(inline: frozenset[str]): + """The handler task sees contextvars set by the task that wrote into the + read stream, on both the spawned and the inline-method dispatch paths.""" raw_send, raw_recv = anyio.create_memory_object_stream[tuple[contextvars.Context, SessionMessage | Exception]](4) read_stream = ContextReceiveStream[SessionMessage | Exception](raw_recv) write_send = ContextSendStream[SessionMessage | Exception](raw_send) out_send, out_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) - server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(read_stream, out_send) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(read_stream, out_send, inline_methods=inline) seen: list[str] = [] diff --git a/tests/shared/test_peer.py b/tests/shared/test_peer.py index 0be4225818..47277ec88e 100644 --- a/tests/shared/test_peer.py +++ b/tests/shared/test_peer.py @@ -57,6 +57,22 @@ async def test_peer_sample_sends_create_message_and_returns_typed_result(): assert result.model == "m" +@pytest.mark.anyio +async def test_peer_sample_validates_result_alias_only(): + """Peer results validate alias-only; a snake_case key from the wire is + ignored as extra, not populated by Python field name.""" + snake = {"role": "assistant", "content": {"type": "text", "text": "x"}, "model": "m", "stop_reason": "endTurn"} + rec = _Recorder(snake) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="q"))], max_tokens=1 + ) + assert isinstance(result, CreateMessageResult) + assert result.stop_reason is None + + @pytest.mark.anyio async def test_peer_sample_with_tools_returns_with_tools_result(): rec = _Recorder({"role": "assistant", "content": [{"type": "text", "text": "x"}], "model": "m"}) From 6c518929c894bd0e173b0d4b9c119866dc68e17c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 3 Jun 2026 21:05:29 +0000 Subject: [PATCH 51/52] fix: BaseSession.__aexit__ healing checkpoint for gh-106749 (CI 3.11/3.14) The second gh-106749 site (the first being _memory.py): BaseSession's task-group cancel is also delivered via coro.throw() into the host task, desyncing CTracer past the caller's frame on 3.11. A cancel_shielded_checkpoint() after the join resumes via .send() and re-stamps the missing 'call' events. Shielded so a pending outer cancel isn't re-delivered here. The added tick shifts whether tests reach streamable_http.py's outer POST-error handler (it was timing-dependent before too); marked lax no cover. Also: collapse a nested with in test_jsonrpc_dispatcher.py and pragma: no branch it (3.14 misreports the 254->255 arc on the nested-with bytecode shape). --- src/mcp/server/streamable_http.py | 7 +++++-- src/mcp/shared/session.py | 11 ++++++++++- tests/shared/test_jsonrpc_dispatcher.py | 5 ++--- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f2f4407cea..217444793a 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -635,7 +635,10 @@ async def sse_writer(): finally: await sse_stream_reader.aclose() - except Exception as err: + except Exception as err: # pragma: lax no cover + # Reached only when something raises during POST handling outside + # the per-SSE-stream guard above; whether tests reach this depends + # on client teardown timing. logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", @@ -643,7 +646,7 @@ async def sse_writer(): INTERNAL_ERROR, ) await response(scope, receive, send) - if writer: # pragma: no cover + if writer: await writer.send(Exception(err)) return # pragma: no cover diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 63498ca338..afed6d54f1 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -7,6 +7,7 @@ from typing import Any, Generic, Protocol, TypeVar import anyio +import anyio.lowlevel from anyio.streams.memory import MemoryObjectSendStream from opentelemetry.trace import SpanKind from pydantic import BaseModel, TypeAdapter @@ -173,7 +174,15 @@ async def __aexit__( # would be very surprising behavior), so make sure to cancel the tasks # in the task group. self._task_group.cancel_scope.cancel() - return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + # The cancel above is delivered via `coro.throw()` into this task; on + # CPython 3.11 (gh-106749) that drops `'call'` trace events for the + # outer await chain and desyncs coverage's CTracer past the caller's + # frame. Yielding once here resumes via `.send()`, which re-stamps the + # missing `'call'` events and resyncs the tracer. Shielded so a pending + # outer cancel is not re-delivered at this point. + await anyio.lowlevel.cancel_shielded_checkpoint() + return result async def send_request( self, diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 68491c2620..da8f8272f8 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -249,10 +249,9 @@ async def test_run_closes_write_stream_on_exit(): async with anyio.create_task_group() as tg: await tg.start(server.run, on_request, on_notify) c2s_send.close() # EOF the read side; run() exits - with anyio.fail_after(5): + with anyio.fail_after(5), pytest.raises(anyio.EndOfStream): # pragma: no branch # Write end was entered and released by run(); peer's receive sees EOF. - with pytest.raises(anyio.EndOfStream): - await s2c_recv.receive() + await s2c_recv.receive() s2c_recv.close() From 48e4c236a557de812cec984b74d170e7197e7f8c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 3 Jun 2026 22:54:35 +0000 Subject: [PATCH 52/52] fix: gh-106749 healing checkpoints at remaining throw sites; anyio>=4.10 on 3.14 The BaseSession.__aexit__ checkpoint only heals throws at or before session exit. Under xdist per-worker test ordering, the unmasked victims sit after later cancel-scope sites: client/streamable_http.py, client/sse.py, client/websocket.py, server/streamable_http_manager.py (finally-cancel after task-group join), and shared/memory.py:create_client_server_memory_streams (heals caller-driven cancels). Same shielded-checkpoint pattern at each. Also updated the _memory.py comment to reference the new memory.py heal. 3.14 lowest-direct: anyio 4.9.0 from_thread.py has return-in-finally which Python 3.14 (PEP 765) warns about at compile time; the warning lands in the stdio test child stderr. Fixed in anyio 4.10 (agronholm/anyio#816); marker-split the floor (locked already has 4.10). --- pyproject.toml | 6 +++++- src/mcp/client/_memory.py | 4 +++- src/mcp/client/sse.py | 8 ++++++++ src/mcp/client/streamable_http.py | 8 ++++++++ src/mcp/client/websocket.py | 8 ++++++++ src/mcp/server/streamable_http_manager.py | 8 ++++++++ src/mcp/shared/memory.py | 10 ++++++++++ uv.lock | 3 ++- 8 files changed, 52 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5f51fa9b85..c3b2bd92b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,11 @@ classifiers = [ "Programming Language :: Python :: 3.14", ] dependencies = [ - "anyio>=4.9", + # anyio < 4.10 triggers a compile-time SyntaxWarning on Python 3.14 (PEP 765, + # "'return' in a 'finally' block"); for stdio servers it lands on the child's + # stderr (agronholm/anyio#816, fixed in 4.10). + "anyio>=4.10; python_version >= '3.14'", + "anyio>=4.9; python_version < '3.14'", "httpx>=0.27.1,<1.0.0", "httpx-sse>=0.4", "pydantic>=2.12.0", diff --git a/src/mcp/client/_memory.py b/src/mcp/client/_memory.py index 05736acbaf..187131e380 100644 --- a/src/mcp/client/_memory.py +++ b/src/mcp/client/_memory.py @@ -87,7 +87,9 @@ async def _run_server() -> None: # completes the join would hang forever, so bound the wait # and fall back to cancelling. The healthy path returns # from wait() without the timeout firing, so the cancel is - # never reached and gh-106749 stays avoided. + # never reached and gh-106749 stays avoided. If the cancel + # does fire, the checkpoint at the end of + # `create_client_server_memory_streams` resyncs the tracer. with anyio.move_on_after(SERVER_SHUTDOWN_GRACE): await server_done.wait() if not server_done.is_set(): diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 74e5ba8062..d217ae42ff 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -5,6 +5,7 @@ from urllib.parse import parse_qs, urljoin, urlparse import anyio +import anyio.lowlevel import httpx from anyio.abc import TaskStatus from httpx_sse import SSEError, aconnect_sse @@ -157,3 +158,10 @@ async def _send_message(session_message: SessionMessage) -> None: yield read_stream, write_stream tg.cancel_scope.cancel() + # The cancel above is delivered via `coro.throw()` into this task at + # the task-group join; on CPython 3.11 (gh-106749) that drops `'call'` + # trace events for the outer await chain and desyncs coverage's CTracer + # past the caller's frame. Yielding once here resumes via `.send()`, + # which re-stamps the missing `'call'` events and resyncs the tracer. + # Shielded so a pending outer cancel is not re-delivered at this point. + await anyio.lowlevel.cancel_shielded_checkpoint() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9cdf717c73..78130f2f8e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -9,6 +9,7 @@ from dataclasses import dataclass import anyio +import anyio.lowlevel import httpx from anyio.abc import TaskGroup from httpx_sse import EventSource, ServerSentEvent, aconnect_sse @@ -586,3 +587,10 @@ def start_get_stream() -> None: if transport.session_id and terminate_on_close: await transport.terminate_session(client) tg.cancel_scope.cancel() + # The cancel above is delivered via `coro.throw()` into this task at + # the task-group join; on CPython 3.11 (gh-106749) that drops `'call'` + # trace events for the outer await chain and desyncs coverage's CTracer + # past the caller's frame. Yielding once here resumes via `.send()`, + # which re-stamps the missing `'call'` events and resyncs the tracer. + # Shielded so a pending outer cancel is not re-delivered at this point. + await anyio.lowlevel.cancel_shielded_checkpoint() diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index de473f36d3..c3423c3c98 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -3,6 +3,7 @@ from contextlib import asynccontextmanager import anyio +import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import ValidationError from websockets.asyncio.client import connect as ws_connect @@ -83,3 +84,10 @@ async def ws_writer(): # Once the caller's 'async with' block exits, we shut down tg.cancel_scope.cancel() + # The cancel above is delivered via `coro.throw()` into this task at + # the task-group join; on CPython 3.11 (gh-106749) that drops `'call'` + # trace events for the outer await chain and desyncs coverage's CTracer + # past the caller's frame. Yielding once here resumes via `.send()`, + # which re-stamps the missing `'call'` events and resyncs the tracer. + # Shielded so a pending outer cancel is not re-delivered at this point. + await anyio.lowlevel.cancel_shielded_checkpoint() diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 81350a8f24..9bcf3cb883 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -9,6 +9,7 @@ from uuid import uuid4 import anyio +import anyio.lowlevel from anyio.abc import TaskStatus from starlette.requests import Request from starlette.responses import Response @@ -139,6 +140,13 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: # Clear any remaining server instances self._server_instances.clear() self._session_owners.clear() + # The cancel above is delivered via `coro.throw()` into this task at + # the task-group join; on CPython 3.11 (gh-106749) that drops `'call'` + # trace events for the outer await chain and desyncs coverage's CTracer + # past the caller's frame. Yielding once here resumes via `.send()`, + # which re-stamps the missing `'call'` events and resyncs the tracer. + # Shielded so a pending outer cancel is not re-delivered at this point. + await anyio.lowlevel.cancel_shielded_checkpoint() async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Process ASGI request with proper session handling and transport setup. diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 468590d095..b20bfa793e 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -5,6 +5,8 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +import anyio.lowlevel + from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared.message import SessionMessage @@ -28,3 +30,11 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageS async with server_to_client_receive, client_to_server_send, client_to_server_receive, server_to_client_send: yield client_streams, server_streams + # Callers routinely cancel a task group wrapped around these streams just + # before this context exits; that cancel is delivered via `coro.throw()`, + # which on CPython 3.11 (gh-106749) drops `'call'` trace events for the + # outer await chain and desyncs coverage's CTracer past the caller's frame. + # Closing memory streams never suspends, so this is the last chance to + # resync: yielding once resumes via `.send()`, which re-stamps the missing + # `'call'` events. Shielded so a pending outer cancel is not re-delivered. + await anyio.lowlevel.cancel_shielded_checkpoint() diff --git a/uv.lock b/uv.lock index e40de57792..1a0ea56b45 100644 --- a/uv.lock +++ b/uv.lock @@ -904,7 +904,8 @@ docs = [ [package.metadata] requires-dist = [ - { name = "anyio", specifier = ">=4.9" }, + { name = "anyio", marker = "python_full_version < '3.14'", specifier = ">=4.9" }, + { name = "anyio", marker = "python_full_version >= '3.14'", specifier = ">=4.10" }, { name = "httpx", specifier = ">=0.27.1,<1.0.0" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "jsonschema", specifier = ">=4.20.0" },