From a0f8bef072b8ced027a28d6e6d925a0f3d9a7257 Mon Sep 17 00:00:00 2001 From: kiwigitops Date: Thu, 28 May 2026 14:41:56 -0400 Subject: [PATCH] Fix MCP header provider across transport tasks Signed-off-by: kiwigitops --- python/packages/core/agent_framework/_mcp.py | 24 ++-- python/packages/core/tests/core/test_mcp.py | 120 +++++++++---------- 2 files changed, 72 insertions(+), 72 deletions(-) diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index d872b2b92de..28cbbddc621 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -4,7 +4,6 @@ import asyncio import base64 -import contextvars import json import logging import re @@ -61,7 +60,6 @@ class MCPSpecificApproval(TypedDict, total=False): _MCP_REMOTE_NAME_KEY = "_mcp_remote_name" _MCP_NORMALIZED_NAME_KEY = "_mcp_normalized_name" -_mcp_call_headers: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar("_mcp_call_headers") MCP_DEFAULT_TIMEOUT = 30 MCP_DEFAULT_SSE_READ_TIMEOUT = 60 * 5 @@ -1760,6 +1758,8 @@ def __init__( self.terminate_on_close = terminate_on_close self._httpx_client: AsyncClient | None = http_client self._header_provider = header_provider + self._active_call_headers: dict[str, str] = {} + self._header_provider_call_lock = asyncio.Lock() def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: """Get an MCP streamable HTTP client. @@ -1784,8 +1784,7 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: async def _inject_headers(request: Request) -> None: # noqa: RUF029 if _url_origin(request.url) != target_origin: return - headers = _mcp_call_headers.get({}) - for key, value in headers.items(): + for key, value in self._active_call_headers.items(): request.headers[key] = value self._inject_headers_hook = _inject_headers # type: ignore[attr-defined] @@ -1802,8 +1801,8 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]: When a ``header_provider`` was supplied at construction time, the runtime *kwargs* (originating from ``FunctionInvocationContext.kwargs``) are passed - to the provider. The returned headers are attached to every HTTP request - made during this tool call via a ``contextvars.ContextVar``. + to the provider. The returned headers are attached to every HTTP request + made during this tool call. Args: tool_name: The name of the tool to call. @@ -1815,12 +1814,13 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]: A list of Content items representing the tool output. """ if self._header_provider is not None: - headers = self._header_provider(kwargs) - token = _mcp_call_headers.set(headers) - try: - return await super().call_tool(tool_name, **kwargs) - finally: - _mcp_call_headers.reset(token) + async with self._header_provider_call_lock: + previous_headers = self._active_call_headers + self._active_call_headers = dict(self._header_provider(kwargs)) + try: + return await super().call_tool(tool_name, **kwargs) + finally: + self._active_call_headers = previous_headers return await super().call_tool(tool_name, **kwargs) diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 519d8e5db31..8cedb540551 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -4639,19 +4639,13 @@ def provider(kwargs): server.session.call_tool.assert_called_once() -async def test_mcp_streamable_http_tool_header_provider_sets_contextvar(): - """Test that call_tool sets the contextvar with headers from header_provider.""" - from agent_framework._mcp import _mcp_call_headers - +async def test_mcp_streamable_http_tool_header_provider_sets_active_headers(): + """Test that call_tool exposes headers from header_provider during the call.""" observed_headers: list[dict[str, str]] = [] original_call_tool = MCPTool.call_tool async def spy_call_tool(self, tool_name, **kwargs): - # Capture the contextvar value during the super call - try: - observed_headers.append(_mcp_call_headers.get()) - except LookupError: - observed_headers.append({}) + observed_headers.append(dict(self._active_call_headers)) return await original_call_tool(self, tool_name, **kwargs) class _TestServer(MCPStreamableHTTPTool): @@ -4692,10 +4686,8 @@ def get_mcp_client(self): assert observed_headers[0] == {"X-Auth": "bearer-xyz"} -async def test_mcp_streamable_http_tool_header_provider_contextvar_reset_after_call(): - """Test that the contextvar is properly reset after call_tool completes.""" - from agent_framework._mcp import _mcp_call_headers - +async def test_mcp_streamable_http_tool_header_provider_active_headers_reset_after_call(): + """Test that active headers are reset after call_tool completes.""" class _TestServer(MCPStreamableHTTPTool): async def connect(self): self.session = Mock(spec=ClientSession) @@ -4728,9 +4720,7 @@ def get_mcp_client(self): await server.load_tools() await server.call_tool("greet", name="Alice", token="secret") - # After call_tool, the contextvar should be unset (reset to no value) - with pytest.raises(LookupError): - _mcp_call_headers.get() + assert server._active_call_headers == {} async def test_mcp_streamable_http_tool_without_header_provider(): @@ -4773,10 +4763,10 @@ def get_mcp_client(self): async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook(): - """Test that the httpx event hook injects headers from the contextvar.""" + """Test that the httpx event hook injects the tool's active headers.""" import httpx - from agent_framework._mcp import MCP_DEFAULT_SSE_READ_TIMEOUT, MCP_DEFAULT_TIMEOUT, _mcp_call_headers + from agent_framework._mcp import MCP_DEFAULT_SSE_READ_TIMEOUT, MCP_DEFAULT_TIMEOUT tool = MCPStreamableHTTPTool( name="test", @@ -4797,14 +4787,10 @@ async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook(): hooks = tool._httpx_client.event_hooks.get("request", []) assert len(hooks) == 1, "Expected one request event hook" - # Simulate what happens during a call_tool: contextvar is set - token = _mcp_call_headers.set({"X-Custom": "test-value"}) - try: - request = httpx.Request("POST", "http://example.com/mcp") - await hooks[0](request) - assert request.headers.get("X-Custom") == "test-value" - finally: - _mcp_call_headers.reset(token) + tool._active_call_headers = {"X-Custom": "test-value"} + request = httpx.Request("POST", "http://example.com/mcp") + await hooks[0](request) + assert request.headers.get("X-Custom") == "test-value" finally: # Ensure any created httpx client is properly closed if getattr(tool, "_httpx_client", None) is not None: @@ -4815,8 +4801,6 @@ async def test_mcp_streamable_http_tool_header_provider_skips_cross_origin_redir """The request hook must not re-add caller headers after a cross-origin redirect.""" import httpx - from agent_framework._mcp import _mcp_call_headers - tool = MCPStreamableHTTPTool( name="test", url="http://example.com/mcp", @@ -4831,17 +4815,45 @@ async def test_mcp_streamable_http_tool_header_provider_skips_cross_origin_redir hooks = tool._httpx_client.event_hooks.get("request", []) assert len(hooks) == 1 - token = _mcp_call_headers.set({"Authorization": "Bearer secret"}) - try: - same_origin = httpx.Request("POST", "http://example.com/redirected") - await hooks[0](same_origin) - assert same_origin.headers.get("Authorization") == "Bearer secret" - - cross_origin = httpx.Request("POST", "http://attacker.example/capture") - await hooks[0](cross_origin) - assert "Authorization" not in cross_origin.headers - finally: - _mcp_call_headers.reset(token) + tool._active_call_headers = {"Authorization": "Bearer secret"} + + same_origin = httpx.Request("POST", "http://example.com/redirected") + await hooks[0](same_origin) + assert same_origin.headers.get("Authorization") == "Bearer secret" + + cross_origin = httpx.Request("POST", "http://attacker.example/capture") + await hooks[0](cross_origin) + assert "Authorization" not in cross_origin.headers + finally: + if getattr(tool, "_httpx_client", None) is not None: + await tool._httpx_client.aclose() + + +async def test_mcp_streamable_http_tool_header_provider_hook_reads_headers_from_transport_task(): + """Test that request hooks can read updated headers from another task.""" + import httpx + + tool = MCPStreamableHTTPTool( + name="test", + url="http://example.com/mcp", + header_provider=lambda kw: {"X-Custom": kw.get("custom", "")}, + ) + + try: + with patch("agent_framework._mcp.streamable_http_client"): + tool.get_mcp_client() + + assert tool._httpx_client is not None + hooks = tool._httpx_client.event_hooks.get("request", []) + assert len(hooks) == 1 + + async def run_hook_in_transport_task() -> str | None: + request = httpx.Request("POST", "http://example.com/mcp") + await hooks[0](request) + return request.headers.get("X-Custom") + + tool._active_call_headers = {"X-Custom": "test-value"} + assert await asyncio.create_task(run_hook_in_transport_task()) == "test-value" finally: if getattr(tool, "_httpx_client", None) is not None: await tool._httpx_client.aclose() @@ -4851,8 +4863,6 @@ async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client() """Test that header_provider works when the user provides their own httpx client.""" import httpx - from agent_framework._mcp import _mcp_call_headers - user_client = httpx.AsyncClient(headers={"X-Base": "static"}) tool = MCPStreamableHTTPTool( @@ -4870,14 +4880,10 @@ async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client() hooks = user_client.event_hooks.get("request", []) assert len(hooks) == 1 - # Verify the hook injects headers - token = _mcp_call_headers.set({"X-Dynamic": "per-request"}) - try: - request = httpx.Request("POST", "http://example.com/mcp") - await hooks[0](request) - assert request.headers.get("X-Dynamic") == "per-request" - finally: - _mcp_call_headers.reset(token) + tool._active_call_headers = {"X-Dynamic": "per-request"} + request = httpx.Request("POST", "http://example.com/mcp") + await hooks[0](request) + assert request.headers.get("X-Dynamic") == "per-request" await user_client.aclose() @@ -4888,19 +4894,12 @@ async def test_mcp_streamable_http_tool_header_provider_via_invoke_with_context( This exercises the full pipeline: FunctionInvocationContext.kwargs -> FunctionTool.invoke -> MCPStreamableHTTPTool.call_tool -> header_provider. """ - from agent_framework._mcp import _mcp_call_headers - observed_headers: list[dict[str, str]] = [] - original_call_tool = MCPStreamableHTTPTool.call_tool + original_call_tool = MCPTool.call_tool async def spy_call_tool(self, tool_name, **kwargs): - # Capture the contextvar value set by call_tool before delegating - result = await original_call_tool(self, tool_name, **kwargs) - try: - observed_headers.append(_mcp_call_headers.get()) - except LookupError: - observed_headers.append({}) - return result + observed_headers.append(dict(self._active_call_headers)) + return await original_call_tool(self, tool_name, **kwargs) class _TestServer(MCPStreamableHTTPTool): async def connect(self): @@ -4951,7 +4950,7 @@ def provider(kwargs): kwargs={"some_token": "my-secret"}, ) - with patch.object(MCPStreamableHTTPTool, "call_tool", spy_call_tool): + with patch.object(MCPTool, "call_tool", spy_call_tool): result = await func.invoke(arguments={"name": "Alice"}, context=context) # Verify the invoke produced a result @@ -4961,6 +4960,7 @@ def provider(kwargs): # Verify header_provider was called with the runtime kwargs assert len(provider_received) == 1 assert provider_received[0]["some_token"] == "my-secret" + assert observed_headers == [{"X-Some-Token": "my-secret"}] # Verify session.call_tool was called with the tool arguments (not the runtime kwargs) server.session.call_tool.assert_called_once()