Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 72 additions & 20 deletions python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import re
import sys
from abc import abstractmethod
from collections.abc import Callable, Collection, Coroutine, Sequence
from collections.abc import Callable, Collection, Coroutine, Mapping, Sequence
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
from datetime import timedelta
from functools import partial
Expand Down Expand Up @@ -142,6 +142,13 @@ def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str,
return meta


def _url_origin(url: Any) -> tuple[str, str, int | None]:
port = url.port
if port is None:
port = 443 if url.scheme == "https" else 80 if url.scheme == "http" else None
return (url.scheme, url.host or "", port)


def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]:
"""Lazily import the MCP streamable HTTP transport."""
try:
Expand Down Expand Up @@ -255,6 +262,7 @@ def __init__(
self._exit_stack = AsyncExitStack()
self._lifecycle_lock = asyncio.Lock()
self._lifecycle_request_lock = asyncio.Lock()
self._function_load_lock = asyncio.Lock()
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, bool, asyncio.Future[None]]] | None = None
self._lifecycle_owner_task: asyncio.Task[None] | None = None
self.session = session
Expand Down Expand Up @@ -655,6 +663,11 @@ async def _safe_close_exit_stack(self) -> None:
raise
except asyncio.CancelledError:
logger.warning("Could not cleanly close MCP exit stack because the lifecycle owner task was cancelled.")
except Exception as e:
if type(e).__name__ == "ExceptionGroup":
logger.warning("Could not cleanly close MCP exit stack due to cleanup error group. Error: %s", e)
else:
raise

async def _close_and_check_cancelled(self, ex: BaseException) -> bool:
"""Close the exit stack and return True if *ex* is a genuine task cancellation.
Expand Down Expand Up @@ -1018,6 +1031,10 @@ async def load_prompts(self) -> None:
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
async with self._function_load_lock:
await self._load_prompts_locked()

async def _load_prompts_locked(self) -> None:
from anyio import ClosedResourceError
from mcp import types

Expand Down Expand Up @@ -1100,6 +1117,10 @@ async def load_tools(self) -> None:
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
async with self._function_load_lock:
await self._load_tools_locked()

async def _load_tools_locked(self) -> None:
from anyio import ClosedResourceError
from mcp import types

Expand All @@ -1109,7 +1130,7 @@ async def load_tools(self) -> None:

# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}
self._tool_call_meta_by_name.clear()
tool_call_meta_by_name: dict[str, dict[str, Any]] = {}

params: types.PaginatedRequestParams | None = None
while True:
Expand Down Expand Up @@ -1145,7 +1166,7 @@ async def load_tools(self) -> None:

for tool in tool_list.tools:
if tool.meta is not None:
self._tool_call_meta_by_name[tool.name] = dict(tool.meta)
tool_call_meta_by_name[tool.name] = dict(tool.meta)

normalized_name = _normalize_mcp_name(tool.name)
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
Expand Down Expand Up @@ -1194,6 +1215,8 @@ async def _call_tool_with_runtime_kwargs(
break
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)

self._tool_call_meta_by_name = tool_call_meta_by_name

async def _close_on_owner(self) -> None:
# Cancel any pending reload tasks before tearing down the session.
tasks = list(self._pending_reload_tasks)
Expand Down Expand Up @@ -1276,7 +1299,11 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
tool_name: The name of the tool to call.

Keyword Args:
kwargs: Arguments to pass to the tool.
_meta: Optional ``dict[str, Any]`` of MCP request metadata. This reserved key is passed as the
``meta`` parameter of the underlying ``session.call_tool`` call rather than as a tool argument.
User-supplied keys override metadata from ``tools/list``; OpenTelemetry propagation fills in
non-conflicting keys.
kwargs: Remaining arguments to pass to the tool.

Returns:
A list of Content items representing the tool output. The default
Expand All @@ -1294,6 +1321,19 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
raise ToolExecutionException(
"Tools are not loaded for this server, please set load_tools=True in the constructor."
)

raw_user_meta: object | None = kwargs.get("_meta")
user_meta: dict[str, Any] | None = None
if raw_user_meta is not None and not isinstance(raw_user_meta, dict):
raise ToolExecutionException("MCP tool metadata provided via _meta must be a dict.")
if isinstance(raw_user_meta, dict):
raw_user_meta_dict = cast(Mapping[object, object], raw_user_meta)
user_meta = {}
for key, value in raw_user_meta_dict.items():
if not isinstance(key, str):
raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.")
user_meta[key] = value

# Filter out framework kwargs that cannot be serialized by the MCP SDK.
# These are internal objects passed through the function invocation pipeline
# that should not be forwarded to external MCP servers.
Expand All @@ -1313,12 +1353,16 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
"conversation_id",
"options",
"response_format",
"_meta",
}
}

# Some MCP proxies require their tools/list metadata to be echoed on tools/call.
tool_meta = self._tool_call_meta_by_name.get(tool_name)
meta = _inject_otel_into_mcp_meta(dict(tool_meta) if tool_meta is not None else None)
request_meta = dict(tool_meta) if tool_meta is not None else None
if user_meta is not None:
request_meta = {**(request_meta or {}), **user_meta}
meta = _inject_otel_into_mcp_meta(request_meta)

parser = self.parse_tool_results or self._parse_tool_result_from_mcp
# Try the operation, reconnecting once if the connection is closed
Expand All @@ -1336,28 +1380,33 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
return parser(result)
except ToolExecutionException:
raise
except ClosedResourceError as cl_ex:
except (ClosedResourceError, McpError) as call_ex:
is_session_terminated = (
isinstance(call_ex, McpError) and "session terminated" in call_ex.error.message.lower()
)
is_connection_lost = isinstance(call_ex, ClosedResourceError) or is_session_terminated
if not is_connection_lost:
error_message = call_ex.error.message if isinstance(call_ex, McpError) else str(call_ex)
raise ToolExecutionException(error_message, inner_exception=call_ex) from call_ex

if attempt == 0:
# First attempt failed, try reconnecting
logger.info("MCP connection closed unexpectedly. Reconnecting...")
# First attempt failed, try reconnecting.
logger.info("MCP connection closed or terminated unexpectedly. Reconnecting...")
try:
await self.connect(reset=True)
continue # Retry the operation
continue
except Exception as reconn_ex:
raise ToolExecutionException(
"Failed to reconnect to MCP server.",
inner_exception=reconn_ex,
) from reconn_ex
else:
# Second attempt also failed, give up
logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}")
raise ToolExecutionException(
f"Failed to call tool '{tool_name}' - connection lost.",
inner_exception=cl_ex,
) from cl_ex
except McpError as mcp_exc:
error_message = mcp_exc.error.message
raise ToolExecutionException(error_message, inner_exception=mcp_exc) from mcp_exc

# Second attempt also failed, give up.
logger.error("MCP connection closed unexpectedly after reconnection: %s", call_ex)
raise ToolExecutionException(
f"Failed to call tool '{tool_name}' - connection lost.",
inner_exception=call_ex,
) from call_ex
except Exception as ex:
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
Expand Down Expand Up @@ -1718,10 +1767,11 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
Returns:
An async context manager for the streamable HTTP client transport.
"""
from httpx import AsyncClient, Request, Timeout
from httpx import URL, AsyncClient, Request, Timeout

http_client = self._httpx_client
if self._header_provider is not None:
target_origin = _url_origin(URL(self.url))
if http_client is None:
http_client = AsyncClient(
follow_redirects=True,
Expand All @@ -1732,6 +1782,8 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
if not hasattr(self, "_inject_headers_hook"):

async def _inject_headers(request: Request) -> None: # noqa: RUF029
if _url_origin(request.url) != target_origin:
return
headers = _mcp_call_headers.get({})
for key, value in headers.items():
request.headers[key] = value
Expand Down
Loading
Loading