Skip to content
Open
117 changes: 116 additions & 1 deletion python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from opentelemetry import propagate

from ._middleware import FunctionInvocationContext
from ._tools import FunctionTool
from ._types import (
ChatOptions,
Expand All @@ -39,7 +40,6 @@
from mcp.shared.session import RequestResponder

from ._clients import SupportsChatGetResponse
from ._middleware import FunctionInvocationContext


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1292,6 +1292,121 @@ async def get_prompt(self, prompt_name: str, **kwargs: Any) -> str:
raise ToolExecutionException(f"Failed to call prompt '{prompt_name}'.", inner_exception=ex) from ex
raise ToolExecutionException(f"Failed to get prompt '{prompt_name}' after retries.")

def as_progressive_tools(
self,
list_tool_name: str = "list_mcp_tools",
call_tool_name: str = "call_mcp",
) -> list[FunctionTool]:
"""Expose this MCP server in a progressive discovery mode.

Instead of exposing every remote tool schema upfront, the model receives a small
stable surface:
- A discovery tool to list available tools and their schemas.
- A dispatch tool to call a specific tool by name.

This is useful for large MCP servers where exposing all tool schemas upfront
would add significant token overhead. The SDK still owns connection lifecycle,
allowed_tools filtering, result parsing, exceptions, and OTel propagation.

Args:
list_tool_name: Name for the discovery tool. Defaults to "list_mcp_tools".
call_tool_name: Name for the dispatch tool. Defaults to "call_mcp".

Returns:
A list of exactly two FunctionTools to pass to an Agent.
"""

async def _list_tools(server: str | None = None) -> str:
"""List available tools on this MCP server.

Args:
server: The name of the server to list tools for. Must match this server's name if provided.
"""
if server and server != self.name:
return json.dumps([])

tool_list = []
for func in self.functions:
tool_list.append(
{
"name": func.name,
"description": func.description,
"parameters": func.parameters(),
}
)
return json.dumps(tool_list, separators=(",", ":"))

async def _call_tool(
server: str,
tool: str,
arguments: dict[str, Any] | None = None,
context: FunctionInvocationContext | None = None,
) -> Any:
"""Call a specific tool on this MCP server.

Note:
Any approval_mode or middleware configured on the underlying target tool
are enforced at the call_mcp wrapper tool level, as call_mcp is the
actual FunctionTool that traverses the agent execution pipeline.

Args:
server: The name of the server. Must match this server's name.
tool: The name of the tool to call.
arguments: The arguments to pass to the tool.
context: The framework function invocation context.
"""
if server != self.name:
raise ToolExecutionException(f"Unknown server '{server}'. This dispatcher is for server '{self.name}'.")

target_func: FunctionTool | None = None
for func in self.functions:
props = func.additional_properties or {}
if (
func.name == tool
or props.get(_MCP_NORMALIZED_NAME_KEY) == tool
or props.get(_MCP_REMOTE_NAME_KEY) == tool
):
target_func = func
break

if not target_func:
raise ToolExecutionException(f"Tool '{tool}' not found or not allowed on server '{self.name}'.")

# Create a fresh context for the target tool so that FunctionTool.invoke's
# in-place mutations (context.function, context.arguments, context.kwargs)
# do not corrupt the wrapper call_mcp's context that middleware may still
# be observing after call_next() returns.
target_context = FunctionInvocationContext(
function=target_func,
arguments=arguments or {},
session=context.session if context is not None else None,
kwargs=context.kwargs if context is not None else None,
)
Comment thread
karthik-0306 marked this conversation as resolved.
return await target_func.invoke(arguments=arguments or {}, context=target_context)

list_tool = FunctionTool(
name=list_tool_name,
description=f"List available tools on the {self.name} MCP server.",
func=_list_tools,
approval_mode="never_require",
)

# When approval_mode is a dict (MCPSpecificApproval with per-tool allow/deny
# lists), the framework's _try_execute_function_calls cannot interpret it as
# a wrapper-level policy string and will silently bypass approval. Normalise
# any dict value to "always_require" so the dispatch wrapper is always gated
# conservatively.
wrapper_approval_mode = "always_require" if isinstance(self.approval_mode, dict) else self.approval_mode

call_tool = FunctionTool(
name=call_tool_name,
description=f"Call a specific tool on the {self.name} MCP server.",
func=_call_tool,
Comment thread
karthik-0306 marked this conversation as resolved.
approval_mode=wrapper_approval_mode,
)

return [list_tool, call_tool]

async def __aenter__(self) -> Self:
"""Enter the async context manager.

Expand Down
Loading
Loading