diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 35ccb1d58a..b72a83fd81 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -642,6 +642,11 @@ async def _safe_close_exit_stack(self) -> None: ) else: raise + except Exception as e: + if type(e).__name__ == "ExceptionGroup": + logger.warning("Could not cleanly close MCP exit stack due to cleanup error group. Error: %s", e) + else: + raise except asyncio.CancelledError: logger.warning("Could not cleanly close MCP exit stack because the lifecycle owner task was cancelled.") diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 0fc5867d79..b050b9c51d 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -3891,6 +3891,33 @@ async def test_mcp_tool_safe_close_handles_cancelled_error(): mock_exit_stack.aclose.assert_called_once() +async def test_mcp_tool_safe_close_handles_cleanup_exception_group(): + """Cleanup task groups should not hide the original connect failure.""" + import builtins + from contextlib import AsyncExitStack + + exception_group_type = getattr(builtins, "ExceptionGroup", None) + if exception_group_type is None: + pytest.skip("ExceptionGroup is not available on this Python version") + + tool = MCPStreamableHTTPTool( + name="test", + url="http://example.com/mcp", + load_tools=False, + load_prompts=False, + ) + + mock_exit_stack = AsyncMock(spec=AsyncExitStack) + mock_exit_stack.enter_async_context = AsyncMock(side_effect=ConnectionRefusedError("down")) + mock_exit_stack.aclose = AsyncMock(side_effect=exception_group_type("cleanup failed", [RuntimeError("reader")])) + tool._exit_stack = mock_exit_stack + + with pytest.raises(ToolException, match="Failed to connect to MCP server"): + await tool.__aenter__() + + mock_exit_stack.aclose.assert_called_once() + + async def test_connect_sets_logging_level_when_logger_level_is_set(): """Test that connect() sets the MCP server logging level when the logger level is not NOTSET."""