Skip to content
86 changes: 86 additions & 0 deletions src/mcp/server/_typed_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Typed ``send_request`` for server-to-client requests.

`TypedServerRequestMixin` provides a typed `send_request(req) -> Result` over
the host's raw `Outbound.send_raw_request`. Spec server-to-client request types
have their result type inferred via per-type overloads; custom requests pass
``result_type=`` explicitly.

If the spec's request set grows substantially, consider declaring the result
mapping on the request types themselves (a ``__mcp_result__`` ClassVar read via
a structural protocol) so this overload ladder doesn't need maintaining
per-host-class.
"""

from typing import Any, TypeVar, overload

from pydantic import BaseModel

from mcp.shared.dispatcher import CallOptions, Outbound
from mcp.shared.peer import dump_params
from mcp.types import (
CreateMessageRequest,
CreateMessageResult,
ElicitRequest,
ElicitResult,
EmptyResult,
ListRootsRequest,
ListRootsResult,
PingRequest,
Request,
)

__all__ = ["TypedServerRequestMixin"]

ResultT = TypeVar("ResultT", bound=BaseModel)

_RESULT_FOR: dict[type[Request[Any, Any]], type[BaseModel]] = {
CreateMessageRequest: CreateMessageResult,
ElicitRequest: ElicitResult,
ListRootsRequest: ListRootsResult,
PingRequest: EmptyResult,
}


class TypedServerRequestMixin:
"""Typed ``send_request`` for the server-to-client request set.

Mixed into `Connection` and the server `Context`. Each method constrains
``self`` to `Outbound` so any host with ``send_raw_request`` works.
"""

@overload
async def send_request(
self: Outbound, req: CreateMessageRequest, *, opts: CallOptions | None = None
) -> CreateMessageResult: ...
@overload
async def send_request(self: Outbound, req: ElicitRequest, *, opts: CallOptions | None = None) -> ElicitResult: ...
@overload
async def send_request(
self: Outbound, req: ListRootsRequest, *, opts: CallOptions | None = None
) -> ListRootsResult: ...
@overload
async def send_request(self: Outbound, req: PingRequest, *, opts: CallOptions | None = None) -> EmptyResult: ...
@overload
async def send_request(
self: Outbound, req: Request[Any, Any], *, result_type: type[ResultT], opts: CallOptions | None = None
) -> ResultT: ...
async def send_request(
self: Outbound,
req: Request[Any, Any],
*,
result_type: type[BaseModel] | None = None,
opts: CallOptions | None = None,
) -> BaseModel:
"""Send a typed server-to-client request and return its typed result.

For spec request types the result type is inferred. For custom requests
pass ``result_type=`` explicitly.

Raises:
MCPError: The peer responded with an error.
NoBackChannelError: No back-channel for server-initiated requests.
KeyError: ``result_type`` omitted for a non-spec request type.
"""
raw = await self.send_raw_request(req.method, dump_params(req.params), opts)
cls = result_type if result_type is not None else _RESULT_FOR[type(req)]
return cls.model_validate(raw)
146 changes: 146 additions & 0 deletions src/mcp/server/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""`Connection` — per-client connection state and the standalone outbound channel.

Always present on `Context` (never ``None``), even in stateless deployments.
Holds peer info populated at ``initialize`` time, the per-connection lifespan
output, and an `Outbound` for the standalone stream (the SSE GET stream in
streamable HTTP, or the single duplex stream in stdio).

`notify` is best-effort: it never raises. If there's no standalone channel
(stateless HTTP) or the stream has been dropped, the notification is
debug-logged and silently discarded — server-initiated notifications are
inherently advisory. `send_raw_request` *does* raise `NoBackChannelError` when
there's no channel; `ping` is the only spec-sanctioned standalone request.
"""

import logging
from collections.abc import Mapping
from typing import Any

import anyio

from mcp.server._typed_request import TypedServerRequestMixin
from mcp.shared.dispatcher import CallOptions, Outbound
from mcp.shared.exceptions import NoBackChannelError
from mcp.shared.peer import Meta, dump_params
from mcp.types import ClientCapabilities, Implementation, LoggingLevel

__all__ = ["Connection"]

logger = logging.getLogger(__name__)


def _notification_params(payload: dict[str, Any] | None, meta: Meta | None) -> dict[str, Any] | None:
if not meta:
return payload
out = dict(payload or {})
out["_meta"] = meta
return out


class Connection(TypedServerRequestMixin):
"""Per-client connection state and standalone-stream `Outbound`.

Constructed by `ServerRunner` once per connection. The peer-info fields are
``None`` until ``initialize`` completes; ``initialized`` is set then.
"""

def __init__(self, outbound: Outbound, *, has_standalone_channel: bool) -> None:
self._outbound = outbound
self.has_standalone_channel = has_standalone_channel

self.client_info: Implementation | None = None
self.client_capabilities: ClientCapabilities | None = None
self.protocol_version: str | None = None
self.initialized: anyio.Event = anyio.Event()
# TODO: make this generic (Connection[StateT]) once connection_lifespan
# wiring lands in ServerRunner.
self.state: Any = None

async def send_raw_request(
self,
method: str,
params: Mapping[str, Any] | None,
opts: CallOptions | None = None,
) -> dict[str, Any]:
"""Send a raw request on the standalone stream.

Low-level `Outbound` channel. Prefer the typed ``send_request`` (from
`TypedServerRequestMixin`) or the convenience methods below; use this
directly only for off-spec messages.

Raises:
MCPError: The peer responded with an error.
NoBackChannelError: ``has_standalone_channel`` is ``False``.
"""
if not self.has_standalone_channel:
raise NoBackChannelError(method)
return await self._outbound.send_raw_request(method, params, opts)

async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
"""Send a best-effort notification on the standalone stream.

Never raises. If there's no standalone channel or the stream is broken,
the notification is dropped and debug-logged.
"""
if not self.has_standalone_channel:
logger.debug("dropped %s: no standalone channel", method)
return
try:
await self._outbound.notify(method, params)
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
logger.debug("dropped %s: standalone stream closed", method)

async def ping(self, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None:
"""Send a ``ping`` request on the standalone stream.

Raises:
MCPError: The peer responded with an error.
NoBackChannelError: ``has_standalone_channel`` is ``False``.
"""
await self.send_raw_request("ping", dump_params(None, meta), opts)

async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None:
"""Send a ``notifications/message`` log entry on the standalone stream. Best-effort."""
params: dict[str, Any] = {"level": level, "data": data}
if logger is not None:
params["logger"] = logger
await self.notify("notifications/message", _notification_params(params, meta))

async def send_tool_list_changed(self, *, meta: Meta | None = None) -> None:
await self.notify("notifications/tools/list_changed", _notification_params(None, meta))

async def send_prompt_list_changed(self, *, meta: Meta | None = None) -> None:
await self.notify("notifications/prompts/list_changed", _notification_params(None, meta))

async def send_resource_list_changed(self, *, meta: Meta | None = None) -> None:
await self.notify("notifications/resources/list_changed", _notification_params(None, meta))

async def send_resource_updated(self, uri: str, *, meta: Meta | None = None) -> None:
await self.notify("notifications/resources/updated", _notification_params({"uri": uri}, meta))

def check_capability(self, capability: ClientCapabilities) -> bool:
"""Return whether the connected client declared the given capability.

Returns ``False`` if ``initialize`` hasn't completed yet.
"""
# TODO: redesign — mirrors v1 ServerSession.check_client_capability
# verbatim for parity.
if self.client_capabilities is None:
return False
have = self.client_capabilities
if capability.roots is not None:
if have.roots is None:
return False
if capability.roots.list_changed and not have.roots.list_changed:
return False
if capability.sampling is not None and have.sampling is None:
return False
if capability.elicitation is not None and have.elicitation is None:
return False
if capability.experimental is not None:
if have.experimental is None:
return False
for k in capability.experimental:
if k not in have.experimental:
return False
return True
60 changes: 60 additions & 0 deletions src/mcp/server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,17 @@

from typing_extensions import TypeVar

from mcp.server._typed_request import TypedServerRequestMixin
from mcp.server.connection import Connection
from mcp.server.experimental.request_context import Experimental
from mcp.server.session import ServerSession
from mcp.shared._context import RequestContext
from mcp.shared.context import BaseContext
from mcp.shared.dispatcher import DispatchContext
from mcp.shared.message import CloseSSEStreamCallback
from mcp.shared.peer import Meta, PeerMixin
from mcp.shared.transport_context import TransportContext
from mcp.types import LoggingLevel, RequestParamsMeta

LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any])
RequestT = TypeVar("RequestT", default=Any)
Expand All @@ -21,3 +28,56 @@ class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContex
request: RequestT | None = None
close_sse_stream: CloseSSEStreamCallback | None = None
close_standalone_sse_stream: CloseSSEStreamCallback | None = None


LifespanT = TypeVar("LifespanT", default=Any, covariant=True)
TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True)


class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Generic[LifespanT, TransportT]):
"""Server-side per-request context.

Composes `BaseContext` (forwards to `DispatchContext`, satisfies `Outbound`),
`PeerMixin` (kwarg-style ``sample``/``elicit_*``/``list_roots``/``ping``),
and `TypedServerRequestMixin` (typed ``send_request(req) -> Result``). Adds
``lifespan`` and ``connection``.

Constructed by `ServerRunner` per inbound request and handed to the user's
handler.
"""

def __init__(
self,
dctx: DispatchContext[TransportT],
*,
lifespan: LifespanT,
connection: Connection,
meta: RequestParamsMeta | None = None,
) -> None:
super().__init__(dctx, meta=meta)
self._lifespan = lifespan
self._connection = connection

@property
def lifespan(self) -> LifespanT:
"""The server-wide lifespan output (what `Server(..., lifespan=...)` yielded)."""
return self._lifespan

@property
def connection(self) -> Connection:
"""The per-client `Connection` for this request's connection."""
return self._connection

async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None:
"""Send a request-scoped ``notifications/message`` log entry.

Uses this request's back-channel (so the entry rides the request's SSE
stream in streamable HTTP), not the standalone stream — use
``ctx.connection.log(...)`` for that.
"""
params: dict[str, Any] = {"level": level, "data": data}
if logger is not None:
params["logger"] = logger
if meta:
params["_meta"] = meta
await self.notify("notifications/message", params)
82 changes: 82 additions & 0 deletions src/mcp/shared/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""`BaseContext` — the user-facing per-request context.

Composition over a `DispatchContext`: forwards the transport metadata, the
back-channel (`send_raw_request`/`notify`), progress reporting, and the cancel
event. Adds `meta` (the inbound request's `_meta` field).

Satisfies `Outbound`, so `PeerMixin` works on it (the server-side `Context`
mixes that in directly). Shared between client and server: the server's
`Context` extends this with `lifespan`/`connection`; `ClientContext` is just an
alias.
"""

from collections.abc import Mapping
from typing import Any, Generic

import anyio
from typing_extensions import TypeVar

from mcp.shared.dispatcher import CallOptions, DispatchContext
from mcp.shared.transport_context import TransportContext
from mcp.types import RequestParamsMeta

__all__ = ["BaseContext"]

TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True)


class BaseContext(Generic[TransportT]):
"""Per-request context wrapping a `DispatchContext`.

`ServerRunner` constructs one per inbound request and passes it to the
user's handler.
"""

def __init__(self, dctx: DispatchContext[TransportT], meta: RequestParamsMeta | None = None) -> None:
self._dctx = dctx
self._meta = meta

@property
def transport(self) -> TransportT:
"""Transport-specific metadata for this inbound request."""
return self._dctx.transport

@property
def cancel_requested(self) -> anyio.Event:
"""Set when the peer sends ``notifications/cancelled`` for this request."""
return self._dctx.cancel_requested

@property
def can_send_request(self) -> bool:
"""Whether the back-channel can deliver server-initiated requests."""
return self._dctx.transport.can_send_request

@property
def meta(self) -> RequestParamsMeta | None:
"""The inbound request's ``_meta`` field, if present."""
return self._meta

async def send_raw_request(
self,
method: str,
params: Mapping[str, Any] | None,
opts: CallOptions | None = None,
) -> dict[str, Any]:
"""Send a request to the peer on the back-channel.

Raises:
MCPError: The peer responded with an error.
NoBackChannelError: ``can_send_request`` is ``False``.
"""
return await self._dctx.send_raw_request(method, params, opts)

async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
"""Send a notification to the peer on the back-channel."""
await self._dctx.notify(method, params)

async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None:
"""Report progress for this request, if the peer supplied a progress token.

A no-op when no token was supplied.
"""
await self._dctx.progress(progress, total, message)
Loading
Loading