diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index b2942de2a0..2184d7c6e8 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -1718,7 +1718,7 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: Returns: An async context manager for the streamable HTTP client transport. """ - from httpx import AsyncClient, Request, Timeout + from httpx import URL, AsyncClient, Request, Timeout http_client = self._httpx_client if self._header_provider is not None: @@ -1730,8 +1730,11 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: self._httpx_client = http_client if not hasattr(self, "_inject_headers_hook"): + target_origin = _url_origin(URL(self.url)) async def _inject_headers(request: Request) -> None: # noqa: RUF029 + if _url_origin(request.url) != target_origin: + return headers = _mcp_call_headers.get({}) for key, value in headers.items(): request.headers[key] = value @@ -1772,6 +1775,13 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]: return await super().call_tool(tool_name, **kwargs) +def _url_origin(url: Any) -> tuple[str, str, int | None]: + port = url.port + if port is None: + port = 443 if url.scheme == "https" else 80 if url.scheme == "http" else None + return (url.scheme, url.host or "", port) + + class MCPWebsocketTool(MCPTool): """MCP tool for connecting to WebSocket-based MCP servers. diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 6273eb76e6..7725f1f8a8 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -4641,6 +4641,42 @@ async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook(): await tool._httpx_client.aclose() +async def test_mcp_streamable_http_tool_header_provider_skips_cross_origin_redirect(): + """The request hook must not re-add caller headers after a cross-origin redirect.""" + import httpx + + from agent_framework._mcp import _mcp_call_headers + + tool = MCPStreamableHTTPTool( + name="test", + url="https://example.com/mcp", + header_provider=lambda kw: {"Authorization": f"Bearer {kw.get('token', '')}"}, + ) + + try: + with patch("agent_framework._mcp.streamable_http_client"): + tool.get_mcp_client() + + assert tool._httpx_client is not None + hooks = tool._httpx_client.event_hooks.get("request", []) + assert len(hooks) == 1 + + token = _mcp_call_headers.set({"Authorization": "Bearer secret"}) + try: + same_origin = httpx.Request("POST", "https://example.com/redirected") + await hooks[0](same_origin) + assert same_origin.headers.get("Authorization") == "Bearer secret" + + cross_origin = httpx.Request("POST", "https://attacker.example/capture") + await hooks[0](cross_origin) + assert "Authorization" not in cross_origin.headers + finally: + _mcp_call_headers.reset(token) + finally: + if getattr(tool, "_httpx_client", None) is not None: + await tool._httpx_client.aclose() + + async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client(): """Test that header_provider works when the user provides their own httpx client.""" import httpx