diff --git a/examples/tutorials/00_sync/040_pydantic_ai/.dockerignore b/examples/tutorials/00_sync/040_pydantic_ai/.dockerignore new file mode 100644 index 000000000..c49489471 --- /dev/null +++ b/examples/tutorials/00_sync/040_pydantic_ai/.dockerignore @@ -0,0 +1,43 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Environments +.env** +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Git +.git +.gitignore + +# Misc +.DS_Store diff --git a/examples/tutorials/00_sync/040_pydantic_ai/Dockerfile b/examples/tutorials/00_sync/040_pydantic_ai/Dockerfile new file mode 100644 index 000000000..ba2f17d19 --- /dev/null +++ b/examples/tutorials/00_sync/040_pydantic_ai/Dockerfile @@ -0,0 +1,50 @@ +# syntax=docker/dockerfile:1.3 +FROM python:3.12-slim +COPY --from=ghcr.io/astral-sh/uv:0.6.4 /uv /uvx /bin/ + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + htop \ + vim \ + curl \ + tar \ + python3-dev \ + postgresql-client \ + build-essential \ + libpq-dev \ + gcc \ + cmake \ + netcat-openbsd \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +RUN uv pip install --system --upgrade pip setuptools wheel + +ENV UV_HTTP_TIMEOUT=1000 + +# Copy pyproject.toml and README.md to install dependencies +COPY 00_sync/040_pydantic_ai/pyproject.toml /app/040_pydantic_ai/pyproject.toml +COPY 00_sync/040_pydantic_ai/README.md /app/040_pydantic_ai/README.md + +WORKDIR /app/040_pydantic_ai + +# Copy the project code +COPY 00_sync/040_pydantic_ai/project /app/040_pydantic_ai/project + +# Copy the test files +COPY 00_sync/040_pydantic_ai/tests /app/040_pydantic_ai/tests + +# Copy shared test utilities +COPY test_utils /app/test_utils + +# Install the required Python packages with dev dependencies +RUN uv pip install --system .[dev] + +# Set environment variables +ENV PYTHONPATH=/app + +# Set test environment variables +ENV AGENT_NAME=s040-pydantic-ai + +# Run the agent using uvicorn +CMD ["uvicorn", "project.acp:acp", "--host", "0.0.0.0", "--port", "8000"] diff --git a/examples/tutorials/00_sync/040_pydantic_ai/README.md b/examples/tutorials/00_sync/040_pydantic_ai/README.md new file mode 100644 index 000000000..02c3b57c7 --- /dev/null +++ b/examples/tutorials/00_sync/040_pydantic_ai/README.md @@ -0,0 +1,46 @@ +# Tutorial 040: Sync Pydantic AI Agent + +This tutorial demonstrates how to build a **synchronous** Pydantic AI agent on AgentEx with: +- Tool calling (Pydantic AI handles the tool loop internally) +- Streaming token output (including token-by-token tool-call argument streaming) + +## Key Concepts + +### Sync ACP +The sync ACP model uses HTTP request/response for communication. The `@acp.on_message_send` handler receives a message and yields streaming events back to the client. + +### Pydantic AI Integration +- **Agent**: A single `pydantic_ai.Agent` that owns the model and tools. No graph required — Pydantic AI runs its own tool-call loop until the model is done. +- **`@agent.tool_plain`**: Registers a Python function as a tool. Pydantic AI infers the schema from type hints and docstring. +- **`agent.run_stream_events(...)`**: Yields `AgentStreamEvent`s (PartStartEvent / PartDeltaEvent / PartEndEvent / FunctionToolResultEvent) as the model produces them. + +### Streaming +The agent streams tokens and tool-call arguments as they're generated using `convert_pydantic_ai_to_agentex_events()`, which adapts Pydantic AI's stream into AgentEx `TaskMessageUpdate` events. Notably, **tool-call arguments stream as `ToolRequestDelta` tokens** rather than arriving as a single complete payload — a richer experience than what OpenAI Agents SDK currently exposes. + +## Files + +| File | Description | +|------|-------------| +| `project/acp.py` | ACP server and message handler | +| `project/agent.py` | Pydantic AI agent + tool registration | +| `project/tools.py` | Tool definitions (weather example) | +| `tests/test_agent.py` | Integration tests | +| `manifest.yaml` | Agent configuration | + +## Running Locally + +```bash +# From this directory +agentex agents run +``` + +## Running Tests + +```bash +pytest tests/test_agent.py -v +``` + +## Notes + +- Multi-turn conversation memory is not wired in this tutorial. Pydantic AI does not ship a checkpointer like LangGraph; to add memory, load prior messages via `adk.messages.list(task_id=...)` and pass them to `agent.run_stream_events(..., message_history=...)`. +- Reasoning/thinking tokens are not exercised here because `gpt-4o-mini` does not emit `ThinkingPart`s. Swap to a reasoning-capable model (e.g. `openai:o1-mini` via Pydantic AI's appropriate provider) if you want to test that branch end-to-end. diff --git a/examples/tutorials/00_sync/040_pydantic_ai/manifest.yaml b/examples/tutorials/00_sync/040_pydantic_ai/manifest.yaml new file mode 100644 index 000000000..68d3b4a00 --- /dev/null +++ b/examples/tutorials/00_sync/040_pydantic_ai/manifest.yaml @@ -0,0 +1,58 @@ +build: + context: + root: ../../ + include_paths: + - 00_sync/040_pydantic_ai + - test_utils + dockerfile: 00_sync/040_pydantic_ai/Dockerfile + dockerignore: 00_sync/040_pydantic_ai/.dockerignore + +local_development: + agent: + port: 8000 + host_address: host.docker.internal + paths: + acp: project/acp.py + +agent: + acp_type: sync + name: s040-pydantic-ai + description: A sync Pydantic AI agent with tool calling and streaming + + temporal: + enabled: false + + credentials: + - env_var_name: OPENAI_API_KEY + secret_name: openai-api-key + secret_key: api-key + - env_var_name: REDIS_URL + secret_name: redis-url-secret + secret_key: url + - env_var_name: SGP_API_KEY + secret_name: sgp-api-key + secret_key: api-key + - env_var_name: SGP_ACCOUNT_ID + secret_name: sgp-account-id + secret_key: account-id + - env_var_name: SGP_CLIENT_BASE_URL + secret_name: sgp-client-base-url + secret_key: url + +deployment: + image: + repository: "" + tag: "latest" + + global: + agent: + name: "s040-pydantic-ai" + description: "A sync Pydantic AI agent with tool calling and streaming" + replicaCount: 1 + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" diff --git a/examples/tutorials/00_sync/040_pydantic_ai/project/__init__.py b/examples/tutorials/00_sync/040_pydantic_ai/project/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/tutorials/00_sync/040_pydantic_ai/project/acp.py b/examples/tutorials/00_sync/040_pydantic_ai/project/acp.py new file mode 100644 index 000000000..0c096893f --- /dev/null +++ b/examples/tutorials/00_sync/040_pydantic_ai/project/acp.py @@ -0,0 +1,78 @@ +"""ACP (Agent Communication Protocol) handler for Agentex. + +This is the API layer — it owns the agent lifecycle and streams tokens +and tool calls from the Pydantic AI agent to the Agentex frontend. +""" + +from __future__ import annotations + +import os +from typing import AsyncGenerator + +from dotenv import load_dotenv + +load_dotenv() + +import agentex.lib.adk as adk +from project.agent import create_agent +from agentex.lib.adk import ( + create_pydantic_ai_tracing_handler, + convert_pydantic_ai_to_agentex_events, +) +from agentex.lib.types.acp import SendMessageParams +from agentex.lib.types.tracing import SGPTracingProcessorConfig +from agentex.lib.utils.logging import make_logger +from agentex.lib.sdk.fastacp.fastacp import FastACP +from agentex.types.task_message_update import TaskMessageUpdate +from agentex.types.task_message_content import TaskMessageContent +from agentex.lib.core.tracing.tracing_processor_manager import add_tracing_processor_config + +logger = make_logger(__name__) + +add_tracing_processor_config( + SGPTracingProcessorConfig( + sgp_api_key=os.environ.get("SGP_API_KEY", ""), + sgp_account_id=os.environ.get("SGP_ACCOUNT_ID", ""), + sgp_base_url=os.environ.get("SGP_CLIENT_BASE_URL", ""), + ) +) + +acp = FastACP.create(acp_type="sync") + +_agent = None + + +def get_agent(): + """Get or create the Pydantic AI agent instance.""" + global _agent + if _agent is None: + _agent = create_agent() + return _agent + + +@acp.on_message_send +async def handle_message_send( + params: SendMessageParams, +) -> TaskMessageContent | list[TaskMessageContent] | AsyncGenerator[TaskMessageUpdate, None]: + """Handle incoming messages from Agentex, streaming tokens and tool calls.""" + agent = get_agent() + task_id = params.task.id + + user_message = params.content.content + logger.info(f"Processing message for task {task_id}") + + async with adk.tracing.span( + trace_id=task_id, + task_id=task_id, + name="message", + input={"message": user_message}, + data={"__span_type__": "AGENT_WORKFLOW"}, + ) as turn_span: + tracing_handler = create_pydantic_ai_tracing_handler( + trace_id=task_id, + parent_span_id=turn_span.id if turn_span else None, + task_id=task_id, + ) + async with agent.run_stream_events(user_message) as stream: + async for event in convert_pydantic_ai_to_agentex_events(stream, tracing_handler=tracing_handler): + yield event diff --git a/examples/tutorials/00_sync/040_pydantic_ai/project/agent.py b/examples/tutorials/00_sync/040_pydantic_ai/project/agent.py new file mode 100644 index 000000000..2c0f6f10c --- /dev/null +++ b/examples/tutorials/00_sync/040_pydantic_ai/project/agent.py @@ -0,0 +1,39 @@ +"""Pydantic AI agent definition. + +The Agent is the boundary between this module and the API layer (acp.py). +Pydantic AI handles its own tool-call loop internally — no graph required. +""" + +from __future__ import annotations + +from datetime import datetime + +from pydantic_ai import Agent + +from project.tools import get_weather + +MODEL_NAME = "openai:gpt-4o-mini" +SYSTEM_PROMPT = """You are a helpful AI assistant with access to tools. + +Current date and time: {timestamp} + +Guidelines: +- Be concise and helpful +- Use tools when they would help answer the user's question +- If you're unsure, ask clarifying questions +- Always provide accurate information +""" + + +def create_agent() -> Agent: + """Build and return the Pydantic AI agent with tools registered.""" + agent = Agent( + MODEL_NAME, + system_prompt=SYSTEM_PROMPT.format( + timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ), + ) + + agent.tool_plain(get_weather) + + return agent diff --git a/examples/tutorials/00_sync/040_pydantic_ai/project/tools.py b/examples/tutorials/00_sync/040_pydantic_ai/project/tools.py new file mode 100644 index 000000000..bab87942a --- /dev/null +++ b/examples/tutorials/00_sync/040_pydantic_ai/project/tools.py @@ -0,0 +1,20 @@ +"""Tool definitions for the Pydantic AI agent. + +Pydantic AI tools are registered directly on the Agent via decorators +(see project.agent). This module hosts the bare functions so they're +easy to unit-test in isolation. +""" + +from __future__ import annotations + + +def get_weather(city: str) -> str: + """Get the current weather for a city. + + Args: + city: The name of the city to get weather for. + + Returns: + A string describing the weather conditions. + """ + return f"The weather in {city} is sunny and 72°F" diff --git a/examples/tutorials/00_sync/040_pydantic_ai/pyproject.toml b/examples/tutorials/00_sync/040_pydantic_ai/pyproject.toml new file mode 100644 index 000000000..f1840931a --- /dev/null +++ b/examples/tutorials/00_sync/040_pydantic_ai/pyproject.toml @@ -0,0 +1,39 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "s040-pydantic-ai" +version = "0.1.0" +description = "A sync Pydantic AI agent with tool calling and streaming" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "agentex-sdk", + "scale-gp", + "pydantic-ai-slim[openai]>=1.0,<2", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-asyncio", + "httpx", + "black", + "isort", + "flake8", +] + +[tool.uv.sources] +agentex-sdk = { path = "../../../..", editable = true } + +[tool.hatch.build.targets.wheel] +packages = ["project"] + +[tool.black] +line-length = 88 +target-version = ['py312'] + +[tool.isort] +profile = "black" +line_length = 88 diff --git a/examples/tutorials/00_sync/040_pydantic_ai/tests/test_agent.py b/examples/tutorials/00_sync/040_pydantic_ai/tests/test_agent.py new file mode 100644 index 000000000..d3deed1c7 --- /dev/null +++ b/examples/tutorials/00_sync/040_pydantic_ai/tests/test_agent.py @@ -0,0 +1,135 @@ +"""Tests for the sync Pydantic AI agent. + +This test suite validates: +- Non-streaming message sending with tool-calling Pydantic AI agent +- Streaming message sending with token-by-token output + +To run these tests: +1. Make sure the agent is running (via docker-compose or `agentex agents run`) +2. Set the AGENTEX_API_BASE_URL environment variable if not using default +3. Run: pytest test_agent.py -v + +Configuration: +- AGENTEX_API_BASE_URL: Base URL for the AgentEx server (default: http://localhost:5003) +- AGENT_NAME: Name of the agent to test (default: s040-pydantic-ai) +""" + +import os + +import pytest +from test_utils.sync import validate_text_in_string, collect_streaming_response + +from agentex import Agentex +from agentex.types import TextContentParam +from agentex.types.agent_rpc_params import ParamsSendMessageRequest + +AGENTEX_API_BASE_URL = os.environ.get("AGENTEX_API_BASE_URL", "http://localhost:5003") +AGENT_NAME = os.environ.get("AGENT_NAME", "s040-pydantic-ai") + + +@pytest.fixture +def client(): + """Create an AgentEx client instance for testing.""" + return Agentex(base_url=AGENTEX_API_BASE_URL) + + +@pytest.fixture +def agent_name(): + """Return the agent name for testing.""" + return AGENT_NAME + + +@pytest.fixture +def agent_id(client, agent_name): + """Retrieve the agent ID based on the agent name.""" + agents = client.agents.list() + for agent in agents: + if agent.name == agent_name: + return agent.id + raise ValueError(f"Agent with name {agent_name} not found.") + + +class TestNonStreamingMessages: + """Test non-streaming message sending with Pydantic AI agent.""" + + def test_send_simple_message(self, client: Agentex, agent_name: str): + """Test sending a simple message and receiving a response.""" + response = client.agents.send_message( + agent_name=agent_name, + params=ParamsSendMessageRequest( + content=TextContentParam( + author="user", + content="Hello! What can you help me with?", + type="text", + ) + ), + ) + result = response.result + assert result is not None + assert len(result) >= 1 + + def test_tool_calling(self, client: Agentex, agent_name: str): + """Test that the agent can use tools (e.g., weather tool).""" + response = client.agents.send_message( + agent_name=agent_name, + params=ParamsSendMessageRequest( + content=TextContentParam( + author="user", + content="What's the weather in San Francisco?", + type="text", + ) + ), + ) + result = response.result + assert result is not None + assert len(result) >= 1 + + +class TestStreamingMessages: + """Test streaming message sending with Pydantic AI agent.""" + + def test_stream_simple_message(self, client: Agentex, agent_name: str): + """Test streaming a simple message response.""" + stream = client.agents.send_message_stream( + agent_name=agent_name, + params=ParamsSendMessageRequest( + content=TextContentParam( + author="user", + content="Tell me a short joke.", + type="text", + ) + ), + ) + + aggregated_content, chunks = collect_streaming_response(stream) + + assert aggregated_content is not None + assert len(chunks) > 1, "No chunks received in streaming response." + + def test_stream_tool_calling(self, client: Agentex, agent_name: str): + """Test streaming with tool calls. + + This exercises the headline Pydantic AI converter feature: + tool-call argument tokens streaming through as ToolRequestDelta. + """ + stream = client.agents.send_message_stream( + agent_name=agent_name, + params=ParamsSendMessageRequest( + content=TextContentParam( + author="user", + content="What's the weather in New York? Respond with the temperature.", + type="text", + ) + ), + ) + + aggregated_content, chunks = collect_streaming_response(stream) + + assert aggregated_content is not None + assert len(chunks) > 0, "No chunks received in streaming response." + # The weather tool always returns "72°F", so the agent's reply should mention it. + validate_text_in_string("72", aggregated_content) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/examples/tutorials/10_async/00_base/110_pydantic_ai/.dockerignore b/examples/tutorials/10_async/00_base/110_pydantic_ai/.dockerignore new file mode 100644 index 000000000..c49489471 --- /dev/null +++ b/examples/tutorials/10_async/00_base/110_pydantic_ai/.dockerignore @@ -0,0 +1,43 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Environments +.env** +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Git +.git +.gitignore + +# Misc +.DS_Store diff --git a/examples/tutorials/10_async/00_base/110_pydantic_ai/Dockerfile b/examples/tutorials/10_async/00_base/110_pydantic_ai/Dockerfile new file mode 100644 index 000000000..906d62068 --- /dev/null +++ b/examples/tutorials/10_async/00_base/110_pydantic_ai/Dockerfile @@ -0,0 +1,50 @@ +# syntax=docker/dockerfile:1.3 +FROM python:3.12-slim +COPY --from=ghcr.io/astral-sh/uv:0.6.4 /uv /uvx /bin/ + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + htop \ + vim \ + curl \ + tar \ + python3-dev \ + postgresql-client \ + build-essential \ + libpq-dev \ + gcc \ + cmake \ + netcat-openbsd \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +RUN uv pip install --system --upgrade pip setuptools wheel + +ENV UV_HTTP_TIMEOUT=1000 + +# Copy pyproject.toml and README.md to install dependencies +COPY 10_async/00_base/110_pydantic_ai/pyproject.toml /app/110_pydantic_ai/pyproject.toml +COPY 10_async/00_base/110_pydantic_ai/README.md /app/110_pydantic_ai/README.md + +WORKDIR /app/110_pydantic_ai + +# Copy the project code +COPY 10_async/00_base/110_pydantic_ai/project /app/110_pydantic_ai/project + +# Copy the test files +COPY 10_async/00_base/110_pydantic_ai/tests /app/110_pydantic_ai/tests + +# Copy shared test utilities +COPY test_utils /app/test_utils + +# Install the required Python packages with dev dependencies +RUN uv pip install --system .[dev] pytest-asyncio httpx + +# Set environment variables +ENV PYTHONPATH=/app + +# Set test environment variables +ENV AGENT_NAME=ab110-pydantic-ai + +# Run the agent using uvicorn +CMD ["uvicorn", "project.acp:acp", "--host", "0.0.0.0", "--port", "8000"] diff --git a/examples/tutorials/10_async/00_base/110_pydantic_ai/README.md b/examples/tutorials/10_async/00_base/110_pydantic_ai/README.md new file mode 100644 index 000000000..6046b579a --- /dev/null +++ b/examples/tutorials/10_async/00_base/110_pydantic_ai/README.md @@ -0,0 +1,63 @@ +# Tutorial 110 (async/base): Pydantic AI Agent + +This tutorial demonstrates how to build an **async** Pydantic AI agent on AgentEx with: +- Tool calling (Pydantic AI handles the tool loop internally) +- Streaming token output via Redis (text + reasoning tokens stream as deltas) +- Task lifecycle hooks (create / event-send / cancel) + +This is the async counterpart to the sync tutorial at [`00_sync/040_pydantic_ai`](../../../00_sync/040_pydantic_ai/). + +## Key Concepts + +### Async ACP +Unlike sync ACP (HTTP request/response with chunked streaming back), async ACP uses **Redis** for streaming. The HTTP call returns immediately when an event is acknowledged; the agent then pushes updates to Redis on its own schedule. The UI subscribes to Redis to receive deltas. + +### Pydantic AI Integration +- **Agent**: A single `pydantic_ai.Agent` that owns the model and tools. No graph required. +- **`@agent.tool_plain`**: Registers a Python function as a tool. Pydantic AI infers the schema from type hints and docstring. +- **`agent.run_stream_events(...)`**: Yields `AgentStreamEvent`s (`PartStartEvent` / `PartDeltaEvent` / `PartEndEvent` / `FunctionToolResultEvent`) as the model produces them. + +### Streaming +The helper `stream_pydantic_ai_events(stream, task_id)` consumes the Pydantic AI event stream and writes Agentex updates to Redis via `adk.streaming.streaming_task_message_context(...)`: +- **Text and thinking tokens** stream as Redis deltas inside coalesced contexts. +- **Tool requests and tool responses** are emitted as **discrete full messages** (no token-level arg streaming). To stream tool-call argument tokens, use the sync converter — see [`00_sync/040_pydantic_ai`](../../../00_sync/040_pydantic_ai/). + +## Files + +| File | Description | +|------|-------------| +| `project/acp.py` | Async ACP server with task lifecycle handlers | +| `project/agent.py` | Pydantic AI agent + tool registration | +| `project/tools.py` | Tool definitions (weather example) | +| `tests/test_agent.py` | Integration tests | +| `manifest.yaml` | Agent configuration | + +## Running Locally + +```bash +# From this directory +agentex agents run +``` + +## Running Tests + +```bash +pytest tests/test_agent.py -v +``` + +## Sync vs Async — How the Code Differs + +This tutorial uses the same `project/agent.py` and `project/tools.py` as the sync version. The only meaningful differences live in `project/acp.py`: + +| Concern | Sync (`s040-pydantic-ai`) | Async (`ab110-pydantic-ai`) | +|---|---|---| +| ACP type | `FastACP.create(acp_type="sync")` | `FastACP.create(acp_type="async", config=AsyncACPConfig(type="base"))` | +| Handler hook | `@acp.on_message_send` returns/yields events | `@acp.on_task_event_send` returns nothing | +| Stream output | `yield event` (chunked HTTP) | `await context.stream_update(...)` (Redis) | +| Tool calls | Args stream as `ToolRequestDelta` tokens | Args arrive in one full message | +| Lifecycle | Ephemeral (no task hooks) | `on_task_create` + `on_task_cancel` form a durable task contract | + +## Notes + +- Multi-turn conversation memory is not wired here. Pydantic AI does not ship a checkpointer; to add memory, load prior messages via `adk.messages.list(task_id=...)` and pass them to `agent.run_stream_events(..., message_history=...)`. +- Reasoning/thinking tokens are not exercised by `gpt-4o-mini`. Swap to a reasoning-capable model if you want to test that branch end-to-end. diff --git a/examples/tutorials/10_async/00_base/110_pydantic_ai/manifest.yaml b/examples/tutorials/10_async/00_base/110_pydantic_ai/manifest.yaml new file mode 100644 index 000000000..583b07251 --- /dev/null +++ b/examples/tutorials/10_async/00_base/110_pydantic_ai/manifest.yaml @@ -0,0 +1,58 @@ +build: + context: + root: ../../../ + include_paths: + - 10_async/00_base/110_pydantic_ai + - test_utils + dockerfile: 10_async/00_base/110_pydantic_ai/Dockerfile + dockerignore: 10_async/00_base/110_pydantic_ai/.dockerignore + +local_development: + agent: + port: 8000 + host_address: host.docker.internal + paths: + acp: project/acp.py + +agent: + acp_type: async + name: ab110-pydantic-ai + description: An async Pydantic AI agent with tool calling and Redis streaming + + temporal: + enabled: false + + credentials: + - env_var_name: OPENAI_API_KEY + secret_name: openai-api-key + secret_key: api-key + - env_var_name: REDIS_URL + secret_name: redis-url-secret + secret_key: url + - env_var_name: SGP_API_KEY + secret_name: sgp-api-key + secret_key: api-key + - env_var_name: SGP_ACCOUNT_ID + secret_name: sgp-account-id + secret_key: account-id + - env_var_name: SGP_CLIENT_BASE_URL + secret_name: sgp-client-base-url + secret_key: url + +deployment: + image: + repository: "" + tag: "latest" + + global: + agent: + name: "ab110-pydantic-ai" + description: "An async Pydantic AI agent with tool calling and Redis streaming" + replicaCount: 1 + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" diff --git a/examples/tutorials/10_async/00_base/110_pydantic_ai/project/__init__.py b/examples/tutorials/10_async/00_base/110_pydantic_ai/project/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/tutorials/10_async/00_base/110_pydantic_ai/project/acp.py b/examples/tutorials/10_async/00_base/110_pydantic_ai/project/acp.py new file mode 100644 index 000000000..0fcd36dc7 --- /dev/null +++ b/examples/tutorials/10_async/00_base/110_pydantic_ai/project/acp.py @@ -0,0 +1,92 @@ +"""ACP handler for async Pydantic AI agent. + +Uses the async ACP model with Redis streaming instead of HTTP yields. +Text and reasoning tokens stream as Redis deltas; tool requests and +responses are persisted as discrete full messages. +""" + +from __future__ import annotations + +import os + +from dotenv import load_dotenv + +load_dotenv() + +import agentex.lib.adk as adk +from project.agent import create_agent +from agentex.lib.adk import ( + stream_pydantic_ai_events, + create_pydantic_ai_tracing_handler, +) +from agentex.lib.types.acp import SendEventParams, CancelTaskParams, CreateTaskParams +from agentex.lib.types.fastacp import AsyncACPConfig +from agentex.lib.types.tracing import SGPTracingProcessorConfig +from agentex.lib.utils.logging import make_logger +from agentex.lib.sdk.fastacp.fastacp import FastACP +from agentex.lib.core.tracing.tracing_processor_manager import add_tracing_processor_config + +logger = make_logger(__name__) + +add_tracing_processor_config( + SGPTracingProcessorConfig( + sgp_api_key=os.environ.get("SGP_API_KEY", ""), + sgp_account_id=os.environ.get("SGP_ACCOUNT_ID", ""), + sgp_base_url=os.environ.get("SGP_CLIENT_BASE_URL", ""), + ) +) + +acp = FastACP.create( + acp_type="async", + config=AsyncACPConfig(type="base"), +) + +_agent = None + + +def get_agent(): + global _agent + if _agent is None: + _agent = create_agent() + return _agent + + +@acp.on_task_event_send +async def handle_task_event_send(params: SendEventParams): + """Handle incoming events, streaming tokens and tool calls via Redis.""" + agent = get_agent() + task_id = params.task.id + user_message = params.event.content.content + + logger.info(f"Processing message for thread {task_id}") + + # Echo the user's message into the task history. + await adk.messages.create(task_id=task_id, content=params.event.content) + + async with adk.tracing.span( + trace_id=task_id, + task_id=task_id, + name="message", + input={"message": user_message}, + data={"__span_type__": "AGENT_WORKFLOW"}, + ) as turn_span: + tracing_handler = create_pydantic_ai_tracing_handler( + trace_id=task_id, + parent_span_id=turn_span.id if turn_span else None, + task_id=task_id, + ) + async with agent.run_stream_events(user_message) as stream: + final_output = await stream_pydantic_ai_events(stream, task_id, tracing_handler=tracing_handler) + + if turn_span: + turn_span.output = {"final_output": final_output} + + +@acp.on_task_create +async def handle_task_create(params: CreateTaskParams): + logger.info(f"Task created: {params.task.id}") + + +@acp.on_task_cancel +async def handle_task_canceled(params: CancelTaskParams): + logger.info(f"Task canceled: {params.task.id}") diff --git a/examples/tutorials/10_async/00_base/110_pydantic_ai/project/agent.py b/examples/tutorials/10_async/00_base/110_pydantic_ai/project/agent.py new file mode 100644 index 000000000..2c0f6f10c --- /dev/null +++ b/examples/tutorials/10_async/00_base/110_pydantic_ai/project/agent.py @@ -0,0 +1,39 @@ +"""Pydantic AI agent definition. + +The Agent is the boundary between this module and the API layer (acp.py). +Pydantic AI handles its own tool-call loop internally — no graph required. +""" + +from __future__ import annotations + +from datetime import datetime + +from pydantic_ai import Agent + +from project.tools import get_weather + +MODEL_NAME = "openai:gpt-4o-mini" +SYSTEM_PROMPT = """You are a helpful AI assistant with access to tools. + +Current date and time: {timestamp} + +Guidelines: +- Be concise and helpful +- Use tools when they would help answer the user's question +- If you're unsure, ask clarifying questions +- Always provide accurate information +""" + + +def create_agent() -> Agent: + """Build and return the Pydantic AI agent with tools registered.""" + agent = Agent( + MODEL_NAME, + system_prompt=SYSTEM_PROMPT.format( + timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ), + ) + + agent.tool_plain(get_weather) + + return agent diff --git a/examples/tutorials/10_async/00_base/110_pydantic_ai/project/tools.py b/examples/tutorials/10_async/00_base/110_pydantic_ai/project/tools.py new file mode 100644 index 000000000..98f65d509 --- /dev/null +++ b/examples/tutorials/10_async/00_base/110_pydantic_ai/project/tools.py @@ -0,0 +1,20 @@ +"""Tool definitions for the async Pydantic AI agent. + +Pydantic AI tools are registered directly on the Agent via decorators +(see project.agent). This module hosts the bare functions so they're +easy to unit-test in isolation. +""" + +from __future__ import annotations + + +def get_weather(city: str) -> str: + """Get the current weather for a city. + + Args: + city: The name of the city to get weather for. + + Returns: + A string describing the weather conditions. + """ + return f"The weather in {city} is sunny and 72°F" diff --git a/examples/tutorials/10_async/00_base/110_pydantic_ai/pyproject.toml b/examples/tutorials/10_async/00_base/110_pydantic_ai/pyproject.toml new file mode 100644 index 000000000..c3fbabae1 --- /dev/null +++ b/examples/tutorials/10_async/00_base/110_pydantic_ai/pyproject.toml @@ -0,0 +1,39 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "ab110-pydantic-ai" +version = "0.1.0" +description = "An async Pydantic AI agent with tool calling and Redis streaming" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "agentex-sdk", + "scale-gp", + "pydantic-ai-slim[openai]>=1.0,<2", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-asyncio", + "httpx", + "black", + "isort", + "flake8", +] + +[tool.uv.sources] +agentex-sdk = { path = "../../../../..", editable = true } + +[tool.hatch.build.targets.wheel] +packages = ["project"] + +[tool.black] +line-length = 88 +target-version = ['py312'] + +[tool.isort] +profile = "black" +line_length = 88 diff --git a/examples/tutorials/10_async/00_base/110_pydantic_ai/tests/test_agent.py b/examples/tutorials/10_async/00_base/110_pydantic_ai/tests/test_agent.py new file mode 100644 index 000000000..a31322d30 --- /dev/null +++ b/examples/tutorials/10_async/00_base/110_pydantic_ai/tests/test_agent.py @@ -0,0 +1,121 @@ +"""Tests for the async Pydantic AI agent. + +This test suite validates: +- Non-streaming event sending and polling +- Streaming event sending + +To run these tests: +1. Make sure the agent is running (via docker-compose or `agentex agents run`) +2. Set the AGENTEX_API_BASE_URL environment variable if not using default +3. Run: pytest test_agent.py -v + +Configuration: +- AGENTEX_API_BASE_URL: Base URL for the AgentEx server (default: http://localhost:5003) +- AGENT_NAME: Name of the agent to test (default: ab110-pydantic-ai) +""" + +import os + +import pytest +import pytest_asyncio + +from agentex import AsyncAgentex +from agentex.types import TextContentParam +from agentex.types.agent_rpc_params import ParamsCreateTaskRequest +from agentex.lib.sdk.fastacp.base.base_acp_server import uuid + +AGENTEX_API_BASE_URL = os.environ.get("AGENTEX_API_BASE_URL", "http://localhost:5003") +AGENT_NAME = os.environ.get("AGENT_NAME", "ab110-pydantic-ai") + + +@pytest_asyncio.fixture +async def client(): + """Create an AsyncAgentex client instance for testing.""" + client = AsyncAgentex(base_url=AGENTEX_API_BASE_URL) + yield client + await client.close() + + +@pytest.fixture +def agent_name(): + """Return the agent name for testing.""" + return AGENT_NAME + + +@pytest_asyncio.fixture +async def agent_id(client, agent_name): + """Retrieve the agent ID based on the agent name.""" + agents = await client.agents.list() + for agent in agents: + if agent.name == agent_name: + return agent.id + raise ValueError(f"Agent with name {agent_name} not found.") + + +class TestNonStreamingEvents: + """Test non-streaming event sending and polling.""" + + @pytest.mark.asyncio + async def test_send_event(self, client: AsyncAgentex, agent_id: str): + """Test sending an event to the async Pydantic AI agent.""" + task_response = await client.agents.create_task( + agent_id, params=ParamsCreateTaskRequest(name=uuid.uuid1().hex) + ) + task = task_response.result + assert task is not None + + event_content = TextContentParam( + type="text", + author="user", + content="Hello! What can you help me with?", + ) + await client.agents.send_event( + agent_id=agent_id, + params={"task_id": task.id, "content": event_content}, + ) + + @pytest.mark.asyncio + async def test_tool_calling(self, client: AsyncAgentex, agent_id: str): + """Test that the agent can use tools (e.g., weather tool).""" + task_response = await client.agents.create_task( + agent_id, params=ParamsCreateTaskRequest(name=uuid.uuid1().hex) + ) + task = task_response.result + assert task is not None + + event_content = TextContentParam( + type="text", + author="user", + content="What's the weather in San Francisco?", + ) + await client.agents.send_event( + agent_id=agent_id, + params={"task_id": task.id, "content": event_content}, + ) + + +class TestStreamingEvents: + """Test streaming event sending.""" + + @pytest.mark.asyncio + async def test_send_event_and_stream(self, client: AsyncAgentex, agent_id: str): + """Test sending an event and streaming the response.""" + task_response = await client.agents.create_task( + agent_id, params=ParamsCreateTaskRequest(name=uuid.uuid1().hex) + ) + task = task_response.result + assert task is not None + + event_content = TextContentParam( + type="text", + author="user", + content="Tell me a short joke.", + ) + await client.agents.send_event( + agent_id=agent_id, + params={"task_id": task.id, "content": event_content}, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/examples/tutorials/10_async/10_temporal/110_pydantic_ai/.dockerignore b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/.dockerignore new file mode 100644 index 000000000..c49489471 --- /dev/null +++ b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/.dockerignore @@ -0,0 +1,43 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Environments +.env** +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Git +.git +.gitignore + +# Misc +.DS_Store diff --git a/examples/tutorials/10_async/10_temporal/110_pydantic_ai/Dockerfile b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/Dockerfile new file mode 100644 index 000000000..17b0db8a0 --- /dev/null +++ b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/Dockerfile @@ -0,0 +1,43 @@ +# syntax=docker/dockerfile:1.3 +FROM python:3.12-slim +COPY --from=ghcr.io/astral-sh/uv:0.6.4 /uv /uvx /bin/ + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + htop \ + vim \ + curl \ + tar \ + python3-dev \ + postgresql-client \ + build-essential \ + libpq-dev \ + gcc \ + cmake \ + netcat-openbsd \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +RUN uv pip install --system --upgrade pip setuptools wheel + +ENV UV_HTTP_TIMEOUT=1000 + +COPY 10_async/10_temporal/110_pydantic_ai/pyproject.toml /app/110_pydantic_ai/pyproject.toml +COPY 10_async/10_temporal/110_pydantic_ai/README.md /app/110_pydantic_ai/README.md + +WORKDIR /app/110_pydantic_ai + +COPY 10_async/10_temporal/110_pydantic_ai/project /app/110_pydantic_ai/project +COPY 10_async/10_temporal/110_pydantic_ai/tests /app/110_pydantic_ai/tests +COPY test_utils /app/test_utils + +RUN uv pip install --system .[dev] + +ENV PYTHONPATH=/app + +ENV AGENT_NAME=at110-pydantic-ai + +CMD ["uvicorn", "project.acp:acp", "--host", "0.0.0.0", "--port", "8000"] + +# When we deploy the worker, we will replace the CMD with the following +# CMD ["python", "-m", "run_worker"] diff --git a/examples/tutorials/10_async/10_temporal/110_pydantic_ai/README.md b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/README.md new file mode 100644 index 000000000..b221c1238 --- /dev/null +++ b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/README.md @@ -0,0 +1,153 @@ +# Tutorial 110 (temporal): Pydantic AI Agent + +This tutorial demonstrates a **durable** Pydantic AI agent on AgentEx, backed by Temporal: +- Workflow state survives crashes mid-conversation (Temporal replay) +- Every LLM call and every tool call becomes its own Temporal activity (independent retries + observability) +- Streaming via Redis still works — token-by-token deltas appear in the UI in real time + +This is the Temporal counterpart to the async base tutorial at [`10_async/00_base/110_pydantic_ai/`](../../00_base/110_pydantic_ai/). + +## Why Temporal? Why not just async? + +In async base 110, the agent state lives in memory inside the ACP process. If that process dies mid-LLM-call, the in-flight turn is lost. Temporal fixes this by: + +1. Recording every external interaction (LLM call, tool call) to a durable event log. +2. On worker restart, **replaying** the workflow code, using cached activity results to skip work that already finished. +3. Letting workflows live forever — multi-day conversations or human-in-the-loop flows just work. + +## Architecture at a glance + +Two long-running processes plus shared infrastructure: + +``` +┌──────────────────────────┐ ┌──────────────────────────┐ +│ uvicorn project.acp:acp │ │ python -m run_worker │ +│ (HTTP shim, forwards │ │ (executes workflows + │ +│ signals to Temporal) │ │ activities) │ +└──────────────────────────┘ └──────────────────────────┘ + │ │ + └────► Temporal server ◄───────────┘ + (event log + queue) + + Redis ◄─── activities push deltas + │ + └─── Agentex API tails ──► UI client +``` + +The HTTP server is a thin shim that translates `task/event/send` into Temporal signals. The worker is where your agent code actually runs. Temporal sits in between, recording everything. + +## Key code patterns + +### `project/agent.py` — wrap the base agent in `TemporalAgent` + +```python +base_agent = Agent(MODEL_NAME, deps_type=TaskDeps, system_prompt=...) +base_agent.tool_plain(get_weather) + +temporal_agent = TemporalAgent( + base_agent, + name="at110_pydantic_ai_agent", + event_stream_handler=event_handler, # streams to Redis from inside the model activity +) +``` + +`TemporalAgent` (from `pydantic_ai.durable_exec.temporal`) wraps a normal Pydantic AI Agent so that: +- Each LLM call runs in its own activity +- Each tool call runs in its own activity +- The wrapping is invisible to the workflow code that calls `temporal_agent.run(...)` + +### `project/workflow.py` — declare `__pydantic_ai_agents__` + +```python +@workflow.defn(name=environment_variables.WORKFLOW_NAME) +class At110PydanticAiWorkflow(BaseWorkflow): + __pydantic_ai_agents__ = [temporal_agent] # ← discovered by PydanticAIPlugin + + @workflow.signal(name=SignalName.RECEIVE_EVENT) + async def on_task_event_send(self, params): + await adk.messages.create(task_id=params.task.id, content=params.event.content) + result = await temporal_agent.run( + params.event.content.content, + deps=TaskDeps(task_id=params.task.id), + ) +``` + +The `__pydantic_ai_agents__` attribute is how `PydanticAIPlugin` discovers which activities to register on the worker — no manual activity list needed. + +### `project/acp.py` — no handlers, just plugin wiring + +```python +acp = FastACP.create( + acp_type="async", + config=TemporalACPConfig( + type="temporal", + temporal_address=os.getenv("TEMPORAL_ADDRESS", "localhost:7233"), + plugins=[PydanticAIPlugin()], + ), +) +``` + +When `type="temporal"`, FastACP auto-wires HTTP → workflow signals. You don't define `@acp.on_task_event_send` anywhere — Temporal handles it. + +### `project/run_worker.py` — boot the worker with the plugin + +```python +worker = AgentexWorker( + task_queue=task_queue_name, + plugins=[PydanticAIPlugin()], +) +await worker.run( + activities=get_all_activities(), + workflow=At110PydanticAiWorkflow, +) +``` + +`get_all_activities()` returns the built-in Agentex activities (state, messages, streaming, tracing). Pydantic AI's per-agent activities are auto-added by the plugin. + +## Files + +| File | Purpose | +|------|---------| +| `project/acp.py` | Thin HTTP shim — `FastACP.create(type="temporal", ...)` | +| `project/workflow.py` | `@workflow.defn` class with the signal handler | +| `project/agent.py` | Base Pydantic AI Agent wrapped in `TemporalAgent` | +| `project/tools.py` | Tool functions (must be `async` for Temporal compatibility) | +| `project/run_worker.py` | Worker boot script (separate process) | +| `tests/test_agent.py` | End-to-end test verifying tool round-trips | +| `manifest.yaml` | Sets `temporal.enabled: true` and declares workflow + queue name | + +## Running Locally + +You'll need three terminals open (this is the price of Temporal): + +```bash +# Terminal 1 — backend services (separate repo) +cd ~/scale-agentex/agentex +make dev # brings up Temporal, Redis, Postgres, Agentex API + +# Terminal 2 — this tutorial (ACP server + Temporal worker) +cd ~/scale-agentex-python/examples/tutorials/10_async/10_temporal/110_pydantic_ai +agentex agents run # this also launches the worker process + +# Terminal 3 — tests +cd ~/scale-agentex-python/examples/tutorials/10_async/10_temporal/110_pydantic_ai +uv run pytest tests/test_agent.py -v +``` + +Watch the Temporal UI at http://localhost:8233 — you'll see workflow executions, signal events, and one activity per LLM call + one per tool call. + +## Sync vs Async vs Temporal — How the code differs + +| Concern | Sync (040) | Async base (110) | Temporal (this one) | +|---|---|---|---| +| `project/acp.py` | `@acp.on_message_send` yields events | `@acp.on_task_event_send` pushes to Redis | **No handlers** — `FastACP.create(type="temporal", ...)` | +| Where the agent runs | In the ACP HTTP process | In the ACP HTTP process | In a separate worker process | +| Durability | Ephemeral — request-scoped | Ephemeral — process-scoped | **Durable** — survives worker restarts via Temporal replay | +| Per-call retries | None | None | Each model + tool call automatically retried by Temporal | +| Code we add | — | `acp.py` handler | `workflow.py`, `run_worker.py`, wrap agent in `TemporalAgent` | + +## Notes + +- Multi-turn conversation memory is not wired here. Workflow state (`self._turn_number`) is durable, but message history isn't currently threaded into `temporal_agent.run(..., message_history=...)`. To add: load via `adk.messages.list(task_id=...)` inside the signal handler and pass through. +- Reasoning/thinking tokens are not exercised by `gpt-4o-mini`. Swap to a reasoning-capable model to exercise that branch end-to-end. +- Tools must be `async` (Pydantic AI's Temporal integration requires it — sync tools would run in threads, breaking Temporal's determinism guarantees). diff --git a/examples/tutorials/10_async/10_temporal/110_pydantic_ai/manifest.yaml b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/manifest.yaml new file mode 100644 index 000000000..15d00076f --- /dev/null +++ b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/manifest.yaml @@ -0,0 +1,64 @@ +build: + context: + root: ../../../ + include_paths: + - 10_async/10_temporal/110_pydantic_ai + - test_utils + dockerfile: 10_async/10_temporal/110_pydantic_ai/Dockerfile + dockerignore: 10_async/10_temporal/110_pydantic_ai/.dockerignore + +local_development: + agent: + port: 8000 + host_address: host.docker.internal + paths: + acp: project/acp.py + worker: project/run_worker.py + +agent: + acp_type: async + name: at110-pydantic-ai + description: A Temporal-backed Pydantic AI agent with tool calling and Redis streaming + + temporal: + enabled: true + workflows: + - name: at110-pydantic-ai + queue_name: at110_pydantic_ai_queue + + credentials: + - env_var_name: REDIS_URL + secret_name: redis-url-secret + secret_key: url + - env_var_name: OPENAI_API_KEY + secret_name: openai-api-key + secret_key: api-key + - env_var_name: SGP_API_KEY + secret_name: sgp-api-key + secret_key: api-key + - env_var_name: SGP_ACCOUNT_ID + secret_name: sgp-account-id + secret_key: account-id + - env_var_name: SGP_CLIENT_BASE_URL + secret_name: sgp-client-base-url + secret_key: url + # env: + # OPENAI_BASE_URL: "https://your-litellm-proxy/v1" + +deployment: + image: + repository: "" + tag: "latest" + + global: + agent: + name: "at110-pydantic-ai" + description: "A Temporal-backed Pydantic AI agent" + replicaCount: 1 + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" diff --git a/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/__init__.py b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/acp.py b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/acp.py new file mode 100644 index 000000000..dacb45ad6 --- /dev/null +++ b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/acp.py @@ -0,0 +1,35 @@ +"""ACP server for the Temporal Pydantic AI tutorial. + +This file is intentionally thin. When ``acp_type="async"`` is combined +with ``TemporalACPConfig(type="temporal", ...)``, FastACP auto-wires: + + HTTP task/create → @workflow.run on the workflow class + HTTP task/event/send → @workflow.signal(SignalName.RECEIVE_EVENT) + HTTP task/cancel → workflow cancellation via the Temporal client + +so we don't define any handlers here. The actual agent code lives in +``project/workflow.py`` and is executed by the Temporal worker +(``project/run_worker.py``), not by this HTTP process. +""" + +from __future__ import annotations + +import os + +from dotenv import load_dotenv + +load_dotenv() + +from pydantic_ai.durable_exec.temporal import PydanticAIPlugin + +from agentex.lib.types.fastacp import TemporalACPConfig +from agentex.lib.sdk.fastacp.fastacp import FastACP + +acp = FastACP.create( + acp_type="async", + config=TemporalACPConfig( + type="temporal", + temporal_address=os.getenv("TEMPORAL_ADDRESS", "localhost:7233"), + plugins=[PydanticAIPlugin()], + ), +) diff --git a/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/agent.py b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/agent.py new file mode 100644 index 000000000..a33a317cc --- /dev/null +++ b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/agent.py @@ -0,0 +1,108 @@ +"""Pydantic AI agent definition for the Temporal tutorial. + +This module constructs the base ``pydantic_ai.Agent`` once at import time, +registers tools on it, and wraps it in ``TemporalAgent`` from +``pydantic_ai.durable_exec.temporal``. + +The ``TemporalAgent`` wrapper makes every model call and every tool call +run as a Temporal activity automatically. The workflow code stays +deterministic; the non-deterministic work (LLM HTTP calls, tool execution) +moves into recorded activities. + +Streaming back to Agentex happens via ``event_stream_handler``, which +receives Pydantic AI ``AgentStreamEvent``s from inside the model activity +and forwards them to Redis using our existing ``stream_pydantic_ai_events`` +helper. The ``task_id`` is threaded into the handler via ``deps``. +""" + +from __future__ import annotations + +from datetime import datetime +from collections.abc import AsyncIterable + +from pydantic import BaseModel +from pydantic_ai import Agent, RunContext +from pydantic_ai.messages import AgentStreamEvent +from pydantic_ai.durable_exec.temporal import TemporalAgent + +from project.tools import get_weather +from agentex.lib.adk import ( + stream_pydantic_ai_events, + create_pydantic_ai_tracing_handler, +) + +MODEL_NAME = "openai:gpt-4o-mini" +SYSTEM_PROMPT = """You are a helpful AI assistant with access to tools. + +Current date and time: {timestamp} + +Guidelines: +- Be concise and helpful +- Use tools when they would help answer the user's question +- If you're unsure, ask clarifying questions +- Always provide accurate information +""" + + +class TaskDeps(BaseModel): + """Per-run dependencies passed into the agent via ``deps=``. + + Pydantic AI's ``RunContext.deps`` is the canonical place to thread + request-scoped data (like the Agentex task_id) into tools and + event handlers — including code that runs inside Temporal activities. + """ + + task_id: str + # When set, the event handler nests per-tool-call spans under this + # span. Typically the ID of the per-turn span opened by the workflow. + parent_span_id: str | None = None + + +def _build_base_agent() -> Agent[TaskDeps, str]: + """Build the underlying Pydantic AI agent with tools registered. + + Tools must be registered BEFORE the agent is wrapped in TemporalAgent; + changes to tool registration after wrapping are not reflected. + """ + agent: Agent[TaskDeps, str] = Agent( + MODEL_NAME, + deps_type=TaskDeps, + system_prompt=SYSTEM_PROMPT.format(timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S")), + ) + agent.tool_plain(get_weather) + return agent + + +async def event_handler( + run_context: RunContext[TaskDeps], + events: AsyncIterable[AgentStreamEvent], +) -> None: + """Stream Pydantic AI events to Agentex via Redis from inside the model activity. + + Pydantic AI calls this with the live event stream as soon as the model + activity begins emitting parts. Because the handler runs inside the + activity (not the workflow), it can freely make non-deterministic + Redis writes — including the tracing HTTP calls that record per-tool-call + spans under the workflow's per-turn span (when ``parent_span_id`` is set). + """ + tracing_handler = create_pydantic_ai_tracing_handler( + trace_id=run_context.deps.task_id, + parent_span_id=run_context.deps.parent_span_id, + task_id=run_context.deps.task_id, + ) + await stream_pydantic_ai_events( + events, + run_context.deps.task_id, + tracing_handler=tracing_handler, + ) + + +# Construct the durable agent at module load time so that the +# PydanticAIPlugin can auto-discover its activities via the workflow's +# ``__pydantic_ai_agents__`` attribute. +base_agent = _build_base_agent() +temporal_agent: TemporalAgent[TaskDeps, str] = TemporalAgent( + base_agent, + name="at110_pydantic_ai_agent", + event_stream_handler=event_handler, +) diff --git a/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/run_worker.py b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/run_worker.py new file mode 100644 index 000000000..e54c9d1dc --- /dev/null +++ b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/run_worker.py @@ -0,0 +1,48 @@ +"""Temporal worker for the Pydantic AI tutorial. + +Run as a separate long-lived process alongside the ACP HTTP server. The +worker polls Temporal for workflow + activity tasks and executes them. + +The ``PydanticAIPlugin`` reads ``__pydantic_ai_agents__`` off the workflow +class and registers every model/tool activity the TemporalAgent needs — +so we don't have to enumerate activities by hand here. +""" + +import asyncio + +from pydantic_ai.durable_exec.temporal import PydanticAIPlugin + +from project.workflow import At110PydanticAiWorkflow +from agentex.lib.utils.debug import setup_debug_if_enabled +from agentex.lib.utils.logging import make_logger +from agentex.lib.environment_variables import EnvironmentVariables +from agentex.lib.core.temporal.activities import get_all_activities +from agentex.lib.core.temporal.workers.worker import AgentexWorker + +environment_variables = EnvironmentVariables.refresh() +logger = make_logger(__name__) + + +async def main(): + setup_debug_if_enabled() + + task_queue_name = environment_variables.WORKFLOW_TASK_QUEUE + if task_queue_name is None: + raise ValueError("WORKFLOW_TASK_QUEUE is not set") + + # get_all_activities() returns the built-in Agentex activities (state, + # messages, streaming, tracing). Pydantic AI's TemporalAgent activities + # are auto-registered by PydanticAIPlugin via __pydantic_ai_agents__. + worker = AgentexWorker( + task_queue=task_queue_name, + plugins=[PydanticAIPlugin()], + ) + + await worker.run( + activities=get_all_activities(), + workflow=At110PydanticAiWorkflow, + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/tools.py b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/tools.py new file mode 100644 index 000000000..75640fcb7 --- /dev/null +++ b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/tools.py @@ -0,0 +1,25 @@ +"""Tool definitions for the Temporal Pydantic AI agent. + +These functions are registered on the base Pydantic AI agent. When the agent +is wrapped in ``TemporalAgent``, each tool call becomes its own Temporal +activity automatically — independently retryable and observable in the +Temporal UI. + +Tools must be ``async`` because Pydantic AI's Temporal integration requires +it: non-async tools would run in threads, which is non-deterministic and +unsafe for Temporal replay. +""" + +from __future__ import annotations + + +async def get_weather(city: str) -> str: + """Get the current weather for a city. + + Args: + city: The name of the city to get weather for. + + Returns: + A string describing the weather conditions. + """ + return f"The weather in {city} is sunny and 72°F" diff --git a/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/workflow.py b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/workflow.py new file mode 100644 index 000000000..aff4cbd99 --- /dev/null +++ b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/project/workflow.py @@ -0,0 +1,121 @@ +"""Temporal workflow for the Pydantic AI tutorial. + +The workflow holds task state durably across crashes. Its signal handler +delegates the actual agent run to ``temporal_agent.run(...)`` — which +internally schedules model and tool activities, each independently +durable. The ``event_stream_handler`` registered on ``temporal_agent`` +pushes streaming deltas to Redis while the model activity runs. +""" + +from __future__ import annotations + +import os +import json + +from temporalio import workflow + +from agentex.lib import adk +from project.agent import TaskDeps, temporal_agent +from agentex.lib.types.acp import SendEventParams, CreateTaskParams +from agentex.lib.types.tracing import SGPTracingProcessorConfig +from agentex.lib.utils.logging import make_logger +from agentex.types.text_content import TextContent +from agentex.lib.environment_variables import EnvironmentVariables +from agentex.lib.core.temporal.types.workflow import SignalName +from agentex.lib.core.temporal.workflows.workflow import BaseWorkflow +from agentex.lib.core.tracing.tracing_processor_manager import ( + add_tracing_processor_config, +) + +add_tracing_processor_config( + SGPTracingProcessorConfig( + sgp_api_key=os.environ.get("SGP_API_KEY", ""), + sgp_account_id=os.environ.get("SGP_ACCOUNT_ID", ""), + sgp_base_url=os.environ.get("SGP_CLIENT_BASE_URL", ""), + ) +) + +environment_variables = EnvironmentVariables.refresh() + +if environment_variables.WORKFLOW_NAME is None: + raise ValueError("Environment variable WORKFLOW_NAME is not set") +if environment_variables.AGENT_NAME is None: + raise ValueError("Environment variable AGENT_NAME is not set") + +logger = make_logger(__name__) + + +@workflow.defn(name=environment_variables.WORKFLOW_NAME) +class At110PydanticAiWorkflow(BaseWorkflow): + """Long-running Temporal workflow that delegates each turn to a Pydantic AI TemporalAgent. + + The ``__pydantic_ai_agents__`` attribute is the marker the + ``PydanticAIPlugin`` looks for at worker startup: it pulls + ``temporal_agent.temporal_activities`` off this list and registers them + on the worker automatically — so we don't have to list activities by + hand in ``run_worker.py``. + """ + + __pydantic_ai_agents__ = [temporal_agent] + + def __init__(self): + super().__init__(display_name=environment_variables.AGENT_NAME) + self._complete_task = False + self._turn_number = 0 + + @workflow.signal(name=SignalName.RECEIVE_EVENT) + async def on_task_event_send(self, params: SendEventParams) -> None: + """Handle a new user message: echo it, then run the agent durably.""" + logger.info(f"Received task event: {params.task.id}") + self._turn_number += 1 + + # Echo the user's message so it shows up in the UI as a chat bubble. + await adk.messages.create(task_id=params.task.id, content=params.event.content) + + async with adk.tracing.span( + trace_id=params.task.id, + task_id=params.task.id, + name=f"Turn {self._turn_number}", + input={"message": params.event.content.content}, + ) as span: + # temporal_agent.run() is the magic line. From the outside it + # looks like a regular async call. Internally it schedules: + # 1. A model activity (LLM HTTP call recorded by Temporal) + # 2. For each tool the model invokes, a tool activity + # 3. Each activity is retried, observable, and durable + # While the model activity runs, the event_stream_handler on + # temporal_agent pushes deltas to Redis so the UI sees tokens. + result = await temporal_agent.run( + params.event.content.content, + deps=TaskDeps( + task_id=params.task.id, + parent_span_id=span.id if span else None, + ), + ) + if span: + span.output = {"final_output": result.output} + + @workflow.run + async def on_task_create(self, params: CreateTaskParams) -> str: + """Workflow entry point — keep the conversation alive for incoming signals.""" + logger.info(f"Task created: {params.task.id}") + + await adk.messages.create( + task_id=params.task.id, + content=TextContent( + author="agent", + content=( + f"Task initialized with params:\n{json.dumps(params.params, indent=2)}\n" + f"Send me a message and I'll respond using a Pydantic AI agent backed by Temporal." + ), + ), + ) + + await workflow.wait_condition(lambda: self._complete_task, timeout=None) + return "Task completed" + + @workflow.signal + async def complete_task_signal(self) -> None: + """Graceful workflow shutdown signal.""" + logger.info("Received complete_task signal") + self._complete_task = True diff --git a/examples/tutorials/10_async/10_temporal/110_pydantic_ai/pyproject.toml b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/pyproject.toml new file mode 100644 index 000000000..3f18f8a4f --- /dev/null +++ b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/pyproject.toml @@ -0,0 +1,41 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "at110-pydantic-ai" +version = "0.1.0" +description = "A Temporal-backed Pydantic AI agent with tool calling and Redis streaming" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "agentex-sdk", + "scale-gp", + "temporalio>=1.18.2", + "pydantic-ai-slim[openai]>=1.0,<2", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-asyncio", + "httpx", + "black", + "isort", + "flake8", + "debugpy>=1.8.15", +] + +[tool.uv.sources] +agentex-sdk = { path = "../../../../..", editable = true } + +[tool.hatch.build.targets.wheel] +packages = ["project"] + +[tool.black] +line-length = 88 +target-version = ['py312'] + +[tool.isort] +profile = "black" +line_length = 88 diff --git a/examples/tutorials/10_async/10_temporal/110_pydantic_ai/tests/test_agent.py b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/tests/test_agent.py new file mode 100644 index 000000000..d01276ab8 --- /dev/null +++ b/examples/tutorials/10_async/10_temporal/110_pydantic_ai/tests/test_agent.py @@ -0,0 +1,127 @@ +"""Tests for the Temporal Pydantic AI agent. + +This test suite validates: +- The agent responds to a basic message +- Tool calls are visible in the message history (proving each tool call + ran as its own Temporal activity) + +To run these tests: +1. Make sure the agent is running (worker + ACP server) +2. Set AGENTEX_API_BASE_URL if not using the default +3. Run: pytest tests/test_agent.py -v +""" + +import os +import uuid + +import pytest +import pytest_asyncio +from test_utils.async_utils import ( + poll_messages, + send_event_and_poll_yielding, +) + +from agentex import AsyncAgentex +from agentex.types.task_message import TaskMessage +from agentex.types.agent_rpc_params import ParamsCreateTaskRequest + +AGENTEX_API_BASE_URL = os.environ.get("AGENTEX_API_BASE_URL", "http://localhost:5003") +AGENT_NAME = os.environ.get("AGENT_NAME", "at110-pydantic-ai") + + +@pytest_asyncio.fixture +async def client(): + client = AsyncAgentex(base_url=AGENTEX_API_BASE_URL) + yield client + await client.close() + + +@pytest.fixture +def agent_name(): + return AGENT_NAME + + +@pytest_asyncio.fixture +async def agent_id(client, agent_name): + agents = await client.agents.list() + for agent in agents: + if agent.name == agent_name: + return agent.id + raise ValueError(f"Agent with name {agent_name} not found.") + + +class TestNonStreamingEvents: + """Test that the Temporal-backed Pydantic AI agent responds and uses tools.""" + + @pytest.mark.asyncio + async def test_send_event_and_poll(self, client: AsyncAgentex, agent_id: str): + """Drive a full turn: create task, send a weather question, verify tool round-trip.""" + task_response = await client.agents.create_task( + agent_id, params=ParamsCreateTaskRequest(name=uuid.uuid1().hex) + ) + task = task_response.result + assert task is not None + + # Wait for the welcome message from on_task_create + task_creation_found = False + async for message in poll_messages( + client=client, + task_id=task.id, + timeout=30, + sleep_interval=1.0, + ): + assert isinstance(message, TaskMessage) + if ( + message.content + and message.content.type == "text" + and message.content.author == "agent" + ): + task_creation_found = True + break + assert task_creation_found, "Task creation welcome message not found" + + # Ask about weather — the agent should call get_weather + seen_tool_request = False + seen_tool_response = False + final_message = None + async for message in send_event_and_poll_yielding( + client=client, + agent_id=agent_id, + task_id=task.id, + user_message="What is the weather in San Francisco?", + timeout=60, + sleep_interval=1.0, + ): + assert isinstance(message, TaskMessage) + + if message.content and message.content.type == "tool_request": + seen_tool_request = True + if message.content and message.content.type == "tool_response": + seen_tool_response = True + if final_message and getattr(final_message, "streaming_status", None) == "DONE": + break + + if ( + message.content + and message.content.type == "text" + and message.content.author == "agent" + ): + final_message = message + content_length = len(getattr(message.content, "content", "") or "") + if message.streaming_status == "DONE" and content_length > 0: + if not seen_tool_request or seen_tool_response: + break + + assert seen_tool_request, "Expected a tool_request (agent calling get_weather)" + assert seen_tool_response, "Expected a tool_response (get_weather result)" + assert final_message is not None, "Expected a final agent text message" + final_text = ( + getattr(final_message.content, "content", None) if final_message.content else None + ) + assert isinstance(final_text, str) and len(final_text) > 0 + # The get_weather tool always returns "72°F" — the response should mention it. + assert "72" in final_text, "Expected weather response to mention 72°F" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/pyproject.toml b/pyproject.toml index 547fc9cf9..2627f762d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "mcp[cli]>=1.4.1", "scale-gp>=0.1.0a59", "openai-agents==0.14.1", + "pydantic-ai-slim>=1.0,<2", "tzlocal>=5.3.1", "tzdata>=2025.2", "pytest>=8.4.0", diff --git a/src/agentex/lib/adk/__init__.py b/src/agentex/lib/adk/__init__.py index f177128a3..cbff5a3fe 100644 --- a/src/agentex/lib/adk/__init__.py +++ b/src/agentex/lib/adk/__init__.py @@ -9,6 +9,9 @@ from agentex.lib.adk._modules._langgraph_tracing import create_langgraph_tracing_handler from agentex.lib.adk._modules._langgraph_async import stream_langgraph_events from agentex.lib.adk._modules._langgraph_sync import convert_langgraph_to_agentex_events +from agentex.lib.adk._modules._pydantic_ai_async import stream_pydantic_ai_events +from agentex.lib.adk._modules._pydantic_ai_sync import convert_pydantic_ai_to_agentex_events +from agentex.lib.adk._modules._pydantic_ai_tracing import create_pydantic_ai_tracing_handler from agentex.lib.adk._modules.events import EventsModule from agentex.lib.adk._modules.messages import MessagesModule from agentex.lib.adk._modules.state import StateModule @@ -40,13 +43,15 @@ "tracing", "events", "agent_task_tracker", - # Checkpointing / LangGraph "create_checkpointer", "create_langgraph_tracing_handler", "stream_langgraph_events", "convert_langgraph_to_agentex_events", - + # Pydantic AI + "stream_pydantic_ai_events", + "convert_pydantic_ai_to_agentex_events", + "create_pydantic_ai_tracing_handler", # Providers "providers", # Utils diff --git a/src/agentex/lib/adk/_modules/_pydantic_ai_async.py b/src/agentex/lib/adk/_modules/_pydantic_ai_async.py new file mode 100644 index 000000000..1f7a3cd6c --- /dev/null +++ b/src/agentex/lib/adk/_modules/_pydantic_ai_async.py @@ -0,0 +1,267 @@ +"""Async Pydantic AI streaming helper for Agentex. + +Consumes a Pydantic AI ``agent.run_stream_events(...)`` async iterator and +pushes Agentex streaming updates to Redis via the ``adk.streaming`` +contexts. For use with async ACP agents that stream via Redis rather than +HTTP yields. + +Text and thinking tokens stream as deltas inside coalesced streaming +contexts. Tool requests and tool results are emitted as full +``adk.messages.create(...)`` calls (Option A — matches the async LangGraph +helper's convention). To stream tool-call argument tokens, see the sync +converter at ``agentex.lib.adk._modules._pydantic_ai_sync`` which yields +``ToolRequestDelta`` events. + +Tracing is opt-in via a ``tracing_handler`` parameter — see +``create_pydantic_ai_tracing_handler`` in +``agentex.lib.adk._modules._pydantic_ai_tracing``. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from agentex.lib.adk._modules._pydantic_ai_tracing import ( + AgentexPydanticAITracingHandler, + ) + + +async def stream_pydantic_ai_events( + stream, + task_id: str, + tracing_handler: "AgentexPydanticAITracingHandler | None" = None, +) -> str: + """Stream Pydantic AI events to Agentex via Redis. + + Args: + stream: Async iterator yielded by ``agent.run_stream_events(...)``. + task_id: The Agentex task ID to stream messages to. + tracing_handler: Optional handler from + ``create_pydantic_ai_tracing_handler(...)``. When provided, each + tool call in the run is also recorded as an Agentex child span + beneath the handler's configured ``parent_span_id``. Streaming + behavior is unchanged when omitted. + + Returns: + The accumulated text content of the **last** text part in the run. + Multi-step runs (where the model emits text, then a tool call, then + more text) return only the final text segment, matching the + ``stream_langgraph_events`` convention. + """ + # Lazy imports so pydantic-ai isn't required at module load time. + import json + + from pydantic_ai.messages import ( + TextPart, + PartEndEvent, + ThinkingPart, + ToolCallPart, + TextPartDelta, + PartDeltaEvent, + PartStartEvent, + ThinkingPartDelta, + FunctionToolResultEvent, + ) + + from agentex.lib import adk + from agentex.types.text_content import TextContent + from agentex.types.reasoning_content import ReasoningContent + from agentex.types.task_message_delta import TextDelta + from agentex.types.task_message_update import StreamTaskMessageDelta + from agentex.types.tool_request_content import ToolRequestContent + from agentex.types.tool_response_content import ToolResponseContent + from agentex.types.reasoning_content_delta import ReasoningContentDelta + + text_context = None + reasoning_context = None + final_text = "" + + # Per Pydantic-AI part-index bookkeeping. Part indices restart at 0 on + # each new model response, so we overwrite on PartStartEvent. + part_kind: dict[int, str] = {} + tool_call_info: dict[int, tuple[str, str]] = {} + + async def _close_text(): + nonlocal text_context + if text_context: + await text_context.close() + text_context = None + + async def _close_reasoning(): + nonlocal reasoning_context + if reasoning_context: + await reasoning_context.close() + reasoning_context = None + + try: + async for event in stream: + if isinstance(event, PartStartEvent): + if isinstance(event.part, TextPart): + await _close_reasoning() + await _close_text() + + final_text = "" + text_context = await adk.streaming.streaming_task_message_context( + task_id=task_id, + initial_content=TextContent( + author="agent", + content="", + format="markdown", + ), + ).__aenter__() + part_kind[event.index] = "text" + + # Pydantic AI puts the first streaming chunk in + # PartStartEvent.part.content; surface it as a Delta so it + # actually renders (Start.content is initialization, not body). + if event.part.content: + final_text += event.part.content + await text_context.stream_update( + StreamTaskMessageDelta( + parent_task_message=text_context.task_message, + delta=TextDelta(type="text", text_delta=event.part.content), + type="delta", + ) + ) + + elif isinstance(event.part, ThinkingPart): + await _close_text() + await _close_reasoning() + + reasoning_context = await adk.streaming.streaming_task_message_context( + task_id=task_id, + initial_content=ReasoningContent( + author="agent", + summary=[], + content=[], + type="reasoning", + style="active", + ), + ).__aenter__() + part_kind[event.index] = "reasoning" + + if event.part.content: + await reasoning_context.stream_update( + StreamTaskMessageDelta( + parent_task_message=reasoning_context.task_message, + delta=ReasoningContentDelta( + type="reasoning_content", + content_index=0, + content_delta=event.part.content, + ), + type="delta", + ) + ) + + elif isinstance(event.part, ToolCallPart): + await _close_text() + await _close_reasoning() + tool_call_info[event.index] = ( + event.part.tool_call_id, + event.part.tool_name, + ) + part_kind[event.index] = "tool_call" + + elif isinstance(event, PartDeltaEvent): + kind = part_kind.get(event.index) + if kind == "text" and isinstance(event.delta, TextPartDelta) and text_context: + final_text += event.delta.content_delta + await text_context.stream_update( + StreamTaskMessageDelta( + parent_task_message=text_context.task_message, + delta=TextDelta(type="text", text_delta=event.delta.content_delta), + type="delta", + ) + ) + elif ( + kind == "reasoning" + and isinstance(event.delta, ThinkingPartDelta) + and reasoning_context + and event.delta.content_delta + ): + await reasoning_context.stream_update( + StreamTaskMessageDelta( + parent_task_message=reasoning_context.task_message, + delta=ReasoningContentDelta( + type="reasoning_content", + content_index=0, + content_delta=event.delta.content_delta, + ), + type="delta", + ) + ) + # Tool-call arg deltas: Pydantic AI accumulates them; we + # surface the final args on PartEndEvent below (Option A). + + elif isinstance(event, PartEndEvent): + kind = part_kind.get(event.index) + if kind == "text": + await _close_text() + elif kind == "reasoning": + await _close_reasoning() + elif kind == "tool_call" and isinstance(event.part, ToolCallPart): + tool_call_id, tool_name = tool_call_info.get(event.index, ("", "")) + args = event.part.args + if isinstance(args, str): + try: + args = json.loads(args) if args else {} + except json.JSONDecodeError: + args = {"_raw": args} + elif args is None: + args = {} + await adk.messages.create( + task_id=task_id, + content=ToolRequestContent( + tool_call_id=tool_call_id, + name=tool_name, + arguments=args, + author="agent", + ), + ) + if tracing_handler is not None and tool_call_id: + await tracing_handler.on_tool_start( + tool_call_id=tool_call_id, + tool_name=tool_name, + arguments=args, + ) + + elif isinstance(event, FunctionToolResultEvent): + await _close_text() + await _close_reasoning() + + result = event.part + tool_call_id = result.tool_call_id + tool_name = getattr(result, "tool_name", "") or "" + content = getattr(result, "content", None) + if content is None: + content_str = str(result) + elif isinstance(content, str): + content_str = content + else: + content_str = str(content) + await adk.messages.create( + task_id=task_id, + content=ToolResponseContent( + tool_call_id=tool_call_id, + name=tool_name, + content=content_str, + author="agent", + ), + ) + if tracing_handler is not None and tool_call_id: + await tracing_handler.on_tool_end( + tool_call_id=tool_call_id, + result=content_str, + ) + + # FunctionToolCallEvent / FinalResultEvent / AgentRunResultEvent + # are intentionally ignored — same as the sync converter. + + finally: + if text_context: + await text_context.close() + if reasoning_context: + await reasoning_context.close() + + return final_text diff --git a/src/agentex/lib/adk/_modules/_pydantic_ai_sync.py b/src/agentex/lib/adk/_modules/_pydantic_ai_sync.py new file mode 100644 index 000000000..b13d9b173 --- /dev/null +++ b/src/agentex/lib/adk/_modules/_pydantic_ai_sync.py @@ -0,0 +1,331 @@ +"""Pydantic AI streaming integration for Agentex. + +Converts a Pydantic AI ``AgentStreamEvent`` stream (as yielded by +``agent.run_stream_events(...)`` or via an ``event_stream_handler``) into the +Agentex ``StreamTaskMessage*`` events that the Agentex server understands. + +Typical sync usage: + + from pydantic_ai import Agent + from agentex.lib.adk import convert_pydantic_ai_to_agentex_events + + agent = Agent("openai:gpt-4o", system_prompt="...") + + @acp.on_message_send + async def handle_message_send(params): + async with agent.run_stream_events(params.content.content) as stream: + async for event in convert_pydantic_ai_to_agentex_events(stream): + yield event +""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any, AsyncIterator + +from pydantic_ai.run import AgentRunResultEvent + +if TYPE_CHECKING: + from agentex.lib.adk._modules._pydantic_ai_tracing import ( + AgentexPydanticAITracingHandler, + ) +from pydantic_ai.messages import ( + TextPart, + PartEndEvent, + ThinkingPart, + ToolCallPart, + TextPartDelta, + PartDeltaEvent, + PartStartEvent, + ToolReturnPart, + FinalResultEvent, + ThinkingPartDelta, + ToolCallPartDelta, + FunctionToolCallEvent, + FunctionToolResultEvent, +) + +from agentex.lib.utils.logging import make_logger +from agentex.types.task_message_delta import TextDelta +from agentex.types.tool_request_delta import ToolRequestDelta +from agentex.types.task_message_update import ( + StreamTaskMessageDone, + StreamTaskMessageFull, + StreamTaskMessageDelta, + StreamTaskMessageStart, +) +from agentex.types.task_message_content import TextContent +from agentex.types.tool_request_content import ToolRequestContent +from agentex.types.tool_response_content import ToolResponseContent +from agentex.types.reasoning_content_delta import ReasoningContentDelta + +logger = make_logger(__name__) + + +def _args_delta_to_str(args_delta: str | dict[str, Any] | None) -> str: + """Normalize a Pydantic AI ``ToolCallPartDelta.args_delta`` to a string fragment. + + Pydantic AI emits string fragments for providers that stream JSON tokens + (OpenAI, Anthropic) and dicts for providers that emit one-shot tool calls. + Agentex's ``ToolRequestDelta.arguments_delta`` is concatenated server-side + and parsed as a single JSON object on completion, so we always produce a + string. For dict deltas this is a one-shot dump; subsequent dict deltas + will not compose correctly, but in practice dict deltas arrive as a single + final fragment. + """ + if args_delta is None: + return "" + if isinstance(args_delta, str): + return args_delta + return json.dumps(args_delta) + + +def _tool_return_content(result: ToolReturnPart | Any) -> Any: + """Best-effort extraction of the user-visible content from a tool result. + + ``FunctionToolResultEvent.part`` is ``ToolReturnPart | RetryPromptPart``. + For ``ToolReturnPart`` we surface ``.content`` directly; for ``RetryPromptPart`` + (a retry signal back to the model) we surface a string description so the + UI sees the failure reason. + """ + content = getattr(result, "content", None) + if content is None: + return str(result) + if isinstance(content, (str, int, float, bool, list, dict)): + return content + if hasattr(content, "model_dump"): + try: + return content.model_dump() + except Exception: + return str(content) + return str(content) + + +async def convert_pydantic_ai_to_agentex_events( + stream_response: AsyncIterator[Any], + tracing_handler: "AgentexPydanticAITracingHandler | None" = None, +) -> AsyncIterator[StreamTaskMessageStart | StreamTaskMessageDelta | StreamTaskMessageFull | StreamTaskMessageDone]: + """Convert a Pydantic AI agent event stream into Agentex stream events. + + Mapping: + PartStartEvent(TextPart) -> StreamTaskMessageStart(TextContent) + PartStartEvent(ThinkingPart) -> StreamTaskMessageStart(TextContent) [reasoning channel] + PartStartEvent(ToolCallPart) -> StreamTaskMessageStart(ToolRequestContent) + PartDeltaEvent(TextPartDelta) -> StreamTaskMessageDelta(TextDelta) + PartDeltaEvent(ThinkingPart..) -> StreamTaskMessageDelta(ReasoningContentDelta) + PartDeltaEvent(ToolCallPart..) -> StreamTaskMessageDelta(ToolRequestDelta) + PartEndEvent -> StreamTaskMessageDone + FunctionToolResultEvent -> StreamTaskMessageFull(ToolResponseContent) + FunctionToolCallEvent -> (ignored — already covered by Start/Delta/End) + FinalResultEvent -> (ignored — informational; the run-level + AgentRunResultEvent terminates the stream) + AgentRunResultEvent -> (ignored — Agentex closes the per-message + stream via PartEndEvent already) + + Args: + stream_response: The async iterator yielded by Pydantic AI's + ``agent.run_stream_events(...)`` context manager (or a stream of + ``AgentStreamEvent`` items received in an ``event_stream_handler``). + tracing_handler: Optional handler from + ``create_pydantic_ai_tracing_handler(...)``. When provided, each + tool call in the run is also recorded as an Agentex child span + beneath the handler's configured ``parent_span_id``. Streaming + behavior is unchanged when omitted. + + Yields: + Agentex ``StreamTaskMessage*`` events suitable for forwarding back over + the ACP streaming response. + """ + next_message_index = 0 + # Maps Pydantic AI's per-response part index to our absolute message index. + # Part indices restart at 0 on each new model response in a multi-step run, + # so we always overwrite the entry on PartStartEvent. + part_to_message_index: dict[int, int] = {} + # Tool-call metadata indexed by Pydantic AI part index (so deltas can + # surface the tool_call_id even when ToolCallPartDelta.tool_call_id is None). + tool_call_meta: dict[int, tuple[str, str]] = {} + + async for event in stream_response: + if isinstance(event, PartStartEvent): + message_index = next_message_index + next_message_index += 1 + part_to_message_index[event.index] = message_index + + if isinstance(event.part, TextPart): + yield StreamTaskMessageStart( + type="start", + index=message_index, + content=TextContent( + type="text", + author="agent", + content="", + ), + ) + if event.part.content: + yield StreamTaskMessageDelta( + type="delta", + index=message_index, + delta=TextDelta(type="text", text_delta=event.part.content), + ) + elif isinstance(event.part, ThinkingPart): + yield StreamTaskMessageStart( + type="start", + index=message_index, + content=TextContent( + type="text", + author="agent", + content="", + ), + ) + if event.part.content: + yield StreamTaskMessageDelta( + type="delta", + index=message_index, + delta=ReasoningContentDelta( + type="reasoning_content", + content_index=0, + content_delta=event.part.content, + ), + ) + elif isinstance(event.part, ToolCallPart): + tool_call_meta[event.index] = (event.part.tool_call_id, event.part.tool_name) + # Pydantic AI may already have a fully-formed args dict at start + # when the provider returns the tool call in one shot; surface it + # directly so clients see the complete arguments without waiting + # for deltas. + initial_args: dict[str, Any] = {} + if isinstance(event.part.args, dict): + # dict(...) materializes a fresh dict[str, Any]; pydantic-ai's + # ToolCallPart.args includes TypedDict-style variants that + # pyright doesn't narrow to plain dict[str, Any] via isinstance. + initial_args = dict(event.part.args) + yield StreamTaskMessageStart( + type="start", + index=message_index, + content=ToolRequestContent( + type="tool_request", + author="agent", + tool_call_id=event.part.tool_call_id, + name=event.part.tool_name, + arguments=initial_args, + ), + ) + if isinstance(event.part.args, str) and event.part.args: + yield StreamTaskMessageDelta( + type="delta", + index=message_index, + delta=ToolRequestDelta( + type="tool_request", + tool_call_id=event.part.tool_call_id, + name=event.part.tool_name, + arguments_delta=event.part.args, + ), + ) + else: + logger.debug("Unhandled PartStartEvent part type: %r", type(event.part).__name__) + + elif isinstance(event, PartDeltaEvent): + message_index = part_to_message_index.get(event.index) + if message_index is None: + logger.debug("PartDeltaEvent for unknown part index %s; skipping", event.index) + continue + + if isinstance(event.delta, TextPartDelta): + yield StreamTaskMessageDelta( + type="delta", + index=message_index, + delta=TextDelta(type="text", text_delta=event.delta.content_delta), + ) + elif isinstance(event.delta, ThinkingPartDelta): + if event.delta.content_delta: + yield StreamTaskMessageDelta( + type="delta", + index=message_index, + delta=ReasoningContentDelta( + type="reasoning_content", + content_index=0, + content_delta=event.delta.content_delta, + ), + ) + elif isinstance(event.delta, ToolCallPartDelta): + meta = tool_call_meta.get(event.index) + if meta is None: + # First time we've seen this part; the provider didn't emit + # a PartStartEvent first. Synthesize one from the delta if + # we have enough information. + tool_call_id = event.delta.tool_call_id or "" + tool_name = event.delta.tool_name_delta or "" + tool_call_meta[event.index] = (tool_call_id, tool_name) + else: + tool_call_id, tool_name = meta + yield StreamTaskMessageDelta( + type="delta", + index=message_index, + delta=ToolRequestDelta( + type="tool_request", + tool_call_id=tool_call_id, + name=tool_name, + arguments_delta=_args_delta_to_str(event.delta.args_delta), + ), + ) + else: + logger.debug("Unhandled PartDeltaEvent delta type: %r", type(event.delta).__name__) + + elif isinstance(event, PartEndEvent): + message_index = part_to_message_index.get(event.index) + if message_index is None: + continue + yield StreamTaskMessageDone(type="done", index=message_index) + # Tool-call parts end with the model's full args known. Open a + # tracing child span for the tool execution now; close it when + # FunctionToolResultEvent arrives below. + if tracing_handler is not None and isinstance(event.part, ToolCallPart) and event.part.tool_call_id: + args: dict[str, Any] | str | None + raw_args = event.part.args + if isinstance(raw_args, dict): + args = dict(raw_args) + elif isinstance(raw_args, str): + try: + args = json.loads(raw_args) if raw_args else {} + except json.JSONDecodeError: + args = {"_raw": raw_args} + else: + args = {} + await tracing_handler.on_tool_start( + tool_call_id=event.part.tool_call_id, + tool_name=event.part.tool_name, + arguments=args, + ) + + elif isinstance(event, FunctionToolResultEvent): + result = event.part + tool_call_id = result.tool_call_id + tool_name = getattr(result, "tool_name", "") or "" + message_index = next_message_index + next_message_index += 1 + content_payload = _tool_return_content(result) + yield StreamTaskMessageFull( + type="full", + index=message_index, + content=ToolResponseContent( + type="tool_response", + author="agent", + tool_call_id=tool_call_id, + name=tool_name, + content=content_payload, + ), + ) + if tracing_handler is not None and tool_call_id: + await tracing_handler.on_tool_end( + tool_call_id=tool_call_id, + result=content_payload, + ) + + elif isinstance(event, (FunctionToolCallEvent, FinalResultEvent, AgentRunResultEvent)): + # Already covered by PartStart/PartDelta/PartEnd events above, or + # informational only (FinalResultEvent / AgentRunResultEvent signal + # run-level state, not new content to surface). + continue + + else: + logger.debug("Unhandled Pydantic AI event type: %r", type(event).__name__) diff --git a/src/agentex/lib/adk/_modules/_pydantic_ai_tracing.py b/src/agentex/lib/adk/_modules/_pydantic_ai_tracing.py new file mode 100644 index 000000000..aa9d906eb --- /dev/null +++ b/src/agentex/lib/adk/_modules/_pydantic_ai_tracing.py @@ -0,0 +1,182 @@ +"""Tracing handler that records Agentex spans for tool calls in a pydantic-ai agent run. + +Mirrors the LangGraph tracing handler pattern: the caller creates a handler +bound to a ``trace_id`` and a ``parent_span_id``, then hands it to +``stream_pydantic_ai_events(..., tracing_handler=handler)``. The streamer +calls ``on_tool_start`` / ``on_tool_end`` as it observes the corresponding +events in the agent stream, and the handler records one Agentex child span +per tool call. + +Why a handler-on-the-streamer rather than an OpenTelemetry bridge: +pydantic-ai exposes its stream of ``AgentStreamEvent`` directly, and that +stream already contains every signal we need to record tool spans. Going +through an OTel processor would require setting up an OTel ``TracerProvider`` +plus a bridge processor — that's a much larger investment, and orthogonal +to the streaming path we already own. This handler hooks into the same +event stream the UI-streaming helper consumes, so a single pass over the +events produces both: live deltas on Redis and child spans on the AgentEx +tracing pipeline. + +Why span IDs are derived from ``tool_call_id`` instead of held in a dict: +pydantic-ai's ``TemporalAgent`` splits the agent run across one or more +Temporal activities. The ``event_stream_handler`` is invoked once per +activity, with a fresh handler instance each time. So ``on_tool_start`` +(emitted inside the model activity that issued the tool call) and +``on_tool_end`` (emitted inside the next model activity, after the tool +runs) land in different handler instances — an in-memory dict can't pair +them. Deriving the span ID deterministically from ``(trace_id, +tool_call_id)`` makes the open/close pairing stateless: ``on_tool_end`` +re-derives the same ID and PATCHes the existing span directly. + +Span hierarchy produced:: + + (e.g. "Turn N", created by the caller) + ├── tool: (one child span per tool call) + └── tool: +""" + +from __future__ import annotations + +import uuid +from typing import Any +from datetime import UTC, datetime + +from agentex import AsyncAgentex +from agentex.lib.utils.logging import make_logger +from agentex.lib.adk._modules.tracing import TracingModule +from agentex.lib.adk.utils._modules.client import create_async_agentex_client + +logger = make_logger(__name__) + + +# Stable namespace for deriving tool-call span IDs. The exact UUID value is +# arbitrary; it just needs to be a constant so the same (trace_id, tool_call_id) +# always maps to the same span ID across handler invocations. +_TOOL_SPAN_NAMESPACE = uuid.UUID("8c2f9a2b-3e4d-4b5a-9c1f-0a1b2c3d4e5f") + + +def _tool_span_id(trace_id: str, tool_call_id: str) -> str: + """Deterministic span ID for a given tool call within a trace.""" + return str(uuid.uuid5(_TOOL_SPAN_NAMESPACE, f"{trace_id}:{tool_call_id}")) + + +class AgentexPydanticAITracingHandler: + """Records Agentex tracing spans for tool calls observed in a pydantic-ai event stream. + + Pass an instance to ``stream_pydantic_ai_events(..., tracing_handler=...)`` + or call ``on_tool_start`` / ``on_tool_end`` yourself if you're consuming + the event stream by hand. + """ + + def __init__( + self, + trace_id: str, + parent_span_id: str | None = None, + task_id: str | None = None, + tracing: TracingModule | None = None, + client: AsyncAgentex | None = None, + ) -> None: + self._trace_id = trace_id + self._parent_span_id = parent_span_id + # task_id on the span record (separate from trace_id) is what the + # AgentEx UI's per-task spans dropdown filters by. If you want your + # tool spans visible in that dropdown, set this to the task ID. + self._task_id = task_id + # ``_tracing`` is retained for callers / tests that want to inject a + # mocked TracingModule, even though the on_tool_* methods now go + # direct to the AgentEx client (see module docstring for why). + self._tracing_eager = tracing + self._tracing_lazy: TracingModule | None = None + # Defer client construction until first use so httpx binds to the + # running event loop (matches the TracingModule pattern). + self._client_eager = client + self._client_lazy: AsyncAgentex | None = None + + @property + def _tracing(self) -> TracingModule: + if self._tracing_eager is not None: + return self._tracing_eager + if self._tracing_lazy is None: + self._tracing_lazy = TracingModule() + return self._tracing_lazy + + @property + def _client(self) -> AsyncAgentex: + if self._client_eager is not None: + return self._client_eager + if self._client_lazy is None: + self._client_lazy = create_async_agentex_client() + return self._client_lazy + + async def on_tool_start( + self, + tool_call_id: str, + tool_name: str, + arguments: dict[str, Any] | str | None, + ) -> None: + """Open a child span for a tool call. + + Uses a deterministic span ID derived from ``tool_call_id`` so that + ``on_tool_end`` — which may run inside a different handler instance + when pydantic-ai splits the run across Temporal activities — can + close the same span without needing in-memory state. + """ + span_id = _tool_span_id(self._trace_id, tool_call_id) + await self._client.spans.create( + id=span_id, + trace_id=self._trace_id, + task_id=self._task_id, + parent_id=self._parent_span_id, + name=f"tool:{tool_name}" if tool_name else "tool", + start_time=datetime.now(UTC), + input={"arguments": arguments}, + data={"__span_type__": "CUSTOM"}, + ) + + async def on_tool_end(self, tool_call_id: str, result: Any) -> None: + """Close a child span by PATCHing its end_time and output. + + Re-derives the deterministic span ID from ``tool_call_id`` and updates + the existing span record directly. No in-memory span lookup, so this + works even when ``on_tool_start`` ran inside a different handler + instance (e.g. across pydantic-ai TemporalAgent activity boundaries). + """ + span_id = _tool_span_id(self._trace_id, tool_call_id) + await self._client.spans.update( + span_id, + end_time=datetime.now(UTC), + output={"result": result}, + ) + + async def on_tool_error(self, tool_call_id: str, error: BaseException | str) -> None: + """Close a child span with an error payload as output.""" + span_id = _tool_span_id(self._trace_id, tool_call_id) + await self._client.spans.update( + span_id, + end_time=datetime.now(UTC), + output={"error": str(error)}, + ) + + +def create_pydantic_ai_tracing_handler( + trace_id: str, + parent_span_id: str | None = None, + task_id: str | None = None, +) -> AgentexPydanticAITracingHandler: + """Create a tracing handler that records Agentex spans for pydantic-ai tool calls. + + Args: + trace_id: The trace ID. Typically the Agentex task ID. + parent_span_id: Optional parent span ID to nest tool spans under. If + omitted, the tool spans become trace-root spans. + task_id: Optional task ID stamped onto each span. Required for the + AgentEx UI's per-task spans dropdown to display the spans. + + Returns: + A handler suitable for passing to ``stream_pydantic_ai_events(..., tracing_handler=...)``. + """ + return AgentexPydanticAITracingHandler( + trace_id=trace_id, + parent_span_id=parent_span_id, + task_id=task_id, + ) diff --git a/src/agentex/lib/core/services/adk/streaming.py b/src/agentex/lib/core/services/adk/streaming.py index 7799ea1eb..846305a7d 100644 --- a/src/agentex/lib/core/services/adk/streaming.py +++ b/src/agentex/lib/core/services/adk/streaming.py @@ -2,7 +2,6 @@ import json import asyncio -import contextlib from typing import Literal, Callable, Awaitable from agentex import AsyncAgentex @@ -184,7 +183,7 @@ async def add(self, update: StreamTaskMessageDelta) -> None: async def _run(self) -> None: try: - while not self._closed: + while True: try: await asyncio.wait_for(self._flush_signal.wait(), timeout=self.FLUSH_INTERVAL_S) except asyncio.TimeoutError: @@ -192,29 +191,35 @@ async def _run(self) -> None: async with self._lock: self._flush_signal.clear() drained = self._drain_locked() - for idx, u in enumerate(drained): + for u in drained: try: await self._on_flush(u) - except asyncio.CancelledError: - # Re-enqueue the item being flushed plus any remaining so - # close()'s final drain can recover them. May cause a - # duplicate publish of the in-flight item, which is - # preferable to silent loss for a streaming UX. - async with self._lock: - self._buf = drained[idx:] + self._buf - raise except Exception as e: logger.exception(f"CoalescingBuffer flush failed: {e}") + # Check _closed *after* draining so close() always gets a final + # in-loop flush pass. Exiting here (instead of being cancelled + # mid-flush) guarantees each in-flight item is published exactly + # once — close()'s final drain then only picks up items added + # after the last lock release. + if self._closed: + return except asyncio.CancelledError: pass async def close(self) -> None: + # Signal the ticker to stop and let it exit naturally after its next + # drain. Cancelling mid-flush would risk re-publishing a delta whose + # Redis write already completed but whose await had not yet returned, + # producing the duplicate-tail symptom seen on the UI stream. self._closed = True if self._task is not None: self._flush_signal.set() - self._task.cancel() - with contextlib.suppress(asyncio.CancelledError): + try: await self._task + except asyncio.CancelledError: + # Propagate if our caller is being cancelled; the task itself + # swallows CancelledError so this only fires on outer cancel. + raise self._task = None async with self._lock: drained = self._drain_locked() diff --git a/tests/lib/adk/test_pydantic_ai_async.py b/tests/lib/adk/test_pydantic_ai_async.py new file mode 100644 index 000000000..88210456d --- /dev/null +++ b/tests/lib/adk/test_pydantic_ai_async.py @@ -0,0 +1,834 @@ +"""Tests for the async Pydantic AI -> Agentex streaming helper. + +Unlike the sync converter (which yields ``StreamTaskMessage*`` events for the +caller to forward over HTTP), the async helper publishes deltas to Redis +through ``adk.streaming.streaming_task_message_context`` and full messages +through ``adk.messages.create``. These tests substitute both with in-memory +fakes so we can assert exactly what was published without touching Redis or +the AgentEx server. +""" + +from __future__ import annotations + +from typing import Any, AsyncIterator +from dataclasses import field, dataclass + +import pytest +from pydantic_ai.messages import ( + TextPart, + PartEndEvent, + ThinkingPart, + ToolCallPart, + TextPartDelta, + PartDeltaEvent, + PartStartEvent, + ToolReturnPart, + RetryPromptPart, + ThinkingPartDelta, + FunctionToolResultEvent, +) + +from agentex.types.task_message import TaskMessage +from agentex.types.text_content import TextContent +from agentex.types.reasoning_content import ReasoningContent +from agentex.types.task_message_delta import TextDelta +from agentex.types.task_message_update import StreamTaskMessageDelta +from agentex.types.tool_request_content import ToolRequestContent +from agentex.types.tool_response_content import ToolResponseContent +from agentex.types.reasoning_content_delta import ReasoningContentDelta +from agentex.lib.adk._modules._pydantic_ai_async import stream_pydantic_ai_events + +TASK_ID = "task_test" + + +async def _aiter(events: list[Any]) -> AsyncIterator[Any]: + for e in events: + yield e + + +@dataclass +class FakeContext: + """In-memory stand-in for ``StreamingTaskMessageContext``. + + Records the order of updates and whether ``close()`` was called. The + helper drives this manually via ``__aenter__`` / ``close``, so we don't + use it as an ``async with`` — we just track the calls. + """ + + initial_content: Any + task_message: TaskMessage + closed: bool = False + updates: list[StreamTaskMessageDelta] = field(default_factory=list) + + async def __aenter__(self) -> "FakeContext": + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool: + await self.close() + return False + + async def stream_update(self, update: StreamTaskMessageDelta) -> None: + if self.closed: + raise AssertionError("stream_update called after close — helper closed the wrong context") + self.updates.append(update) + + async def close(self) -> None: + self.closed = True + + +class FakeStreamingModule: + """Records every streaming context the helper opens, in order.""" + + def __init__(self) -> None: + self.contexts: list[FakeContext] = [] + + def streaming_task_message_context(self, *, task_id: str, initial_content: Any) -> FakeContext: + tm = TaskMessage( + id=f"m{len(self.contexts) + 1}", + task_id=task_id, + content=initial_content, + streaming_status="IN_PROGRESS", + ) + ctx = FakeContext(initial_content=initial_content, task_message=tm) + self.contexts.append(ctx) + return ctx + + +class FakeMessagesModule: + """Records every ``adk.messages.create`` call.""" + + def __init__(self) -> None: + self.created: list[dict[str, Any]] = [] + + async def create(self, *, task_id: str, content: Any) -> TaskMessage: + self.created.append({"task_id": task_id, "content": content}) + return TaskMessage( + id=f"created-{len(self.created)}", + task_id=task_id, + content=content, + streaming_status="DONE", + ) + + +@pytest.fixture +def fake_adk(monkeypatch): + """Patches the lazy ``from agentex.lib import adk`` lookup inside the helper. + + Returns ``(streaming, messages)`` for assertions. + """ + from agentex.lib import adk as adk_module + + streaming = FakeStreamingModule() + messages = FakeMessagesModule() + monkeypatch.setattr(adk_module, "streaming", streaming) + monkeypatch.setattr(adk_module, "messages", messages) + return streaming, messages + + +def _text_deltas(ctx: FakeContext) -> list[str]: + out: list[str] = [] + for u in ctx.updates: + if isinstance(u.delta, TextDelta): + out.append(u.delta.text_delta or "") + return out + + +def _reasoning_deltas(ctx: FakeContext) -> list[str]: + out: list[str] = [] + for u in ctx.updates: + if isinstance(u.delta, ReasoningContentDelta): + out.append(u.delta.content_delta or "") + return out + + +class TestTextStreaming: + async def test_plain_text_opens_context_streams_deltas_and_closes( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + streaming, messages = fake_adk + events = [ + PartStartEvent(index=0, part=TextPart(content="")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="Hello")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=", ")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="world!")), + PartEndEvent(index=0, part=TextPart(content="Hello, world!")), + ] + + final = await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + assert len(streaming.contexts) == 1 + ctx = streaming.contexts[0] + assert isinstance(ctx.initial_content, TextContent) + assert ctx.initial_content.content == "" + assert _text_deltas(ctx) == ["Hello", ", ", "world!"] + assert ctx.closed is True, "PartEndEvent must close the streaming context" + assert messages.created == [], "Plain text must not emit standalone messages" + assert final == "Hello, world!" + + async def test_initial_content_in_part_start_is_streamed_as_delta( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + """Pydantic AI sometimes packs the first chunk inside ``PartStartEvent.part.content``. + + Agentex renders only Delta events as the message body, so the helper + must surface that initial chunk as a delta — otherwise the first token + is invisible to the UI. + """ + streaming, _ = fake_adk + events = [ + PartStartEvent(index=0, part=TextPart(content="Already there")), + PartEndEvent(index=0, part=TextPart(content="Already there")), + ] + final = await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + ctx = streaming.contexts[0] + assert _text_deltas(ctx) == ["Already there"] + assert final == "Already there" + + async def test_returns_only_last_text_segment_in_multi_step_run( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + """Matches the documented contract / the LangGraph async helper's behavior.""" + streaming, _ = fake_adk + events = [ + PartStartEvent(index=0, part=TextPart(content="")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="Looking up...")), + PartEndEvent(index=0, part=TextPart(content="Looking up...")), + PartStartEvent(index=0, part=TextPart(content="")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="It's sunny.")), + PartEndEvent(index=0, part=TextPart(content="It's sunny.")), + ] + final = await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + assert len(streaming.contexts) == 2, "Two text parts → two streaming contexts" + assert all(ctx.closed for ctx in streaming.contexts) + assert _text_deltas(streaming.contexts[0]) == ["Looking up..."] + assert _text_deltas(streaming.contexts[1]) == ["It's sunny."] + assert final == "It's sunny." + + +class TestThinkingStreaming: + async def test_thinking_opens_reasoning_context_with_reasoning_deltas( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + streaming, _ = fake_adk + events = [ + PartStartEvent(index=0, part=ThinkingPart(content="")), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta="step 1...")), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=" step 2.")), + PartEndEvent(index=0, part=ThinkingPart(content="step 1... step 2.")), + ] + await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + ctx = streaming.contexts[0] + assert isinstance(ctx.initial_content, ReasoningContent) + assert _reasoning_deltas(ctx) == ["step 1...", " step 2."] + assert ctx.closed is True + + async def test_thinking_initial_content_is_streamed_as_delta( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + streaming, _ = fake_adk + events = [ + PartStartEvent(index=0, part=ThinkingPart(content="seed reasoning")), + PartEndEvent(index=0, part=ThinkingPart(content="seed reasoning")), + ] + await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + ctx = streaming.contexts[0] + assert _reasoning_deltas(ctx) == ["seed reasoning"] + + async def test_empty_thinking_delta_is_skipped( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + streaming, _ = fake_adk + events = [ + PartStartEvent(index=0, part=ThinkingPart(content="")), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=None)), + PartEndEvent(index=0, part=ThinkingPart(content="")), + ] + await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + ctx = streaming.contexts[0] + assert _reasoning_deltas(ctx) == [], "Empty ThinkingPartDelta must not publish a zero-length reasoning delta" + assert ctx.closed is True + + +class TestToolCallEmission: + async def test_tool_call_emits_full_tool_request_message_on_part_end( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + """Async helper uses Option A: tool requests are full messages, not delta streams.""" + streaming, messages = fake_adk + events = [ + PartStartEvent( + index=1, + part=ToolCallPart(tool_name="get_weather", args=None, tool_call_id="c1"), + ), + PartEndEvent( + index=1, + part=ToolCallPart(tool_name="get_weather", args='{"city":"Paris"}', tool_call_id="c1"), + ), + ] + await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + assert streaming.contexts == [], "Tool calls do not open a streaming context" + assert len(messages.created) == 1 + msg = messages.created[0] + assert msg["task_id"] == TASK_ID + content = msg["content"] + assert isinstance(content, ToolRequestContent) + assert content.tool_call_id == "c1" + assert content.name == "get_weather" + assert content.arguments == {"city": "Paris"} + assert content.author == "agent" + + async def test_tool_call_with_dict_args_passes_through( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + _, messages = fake_adk + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="search", args={"q": "weather"}, tool_call_id="c"), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name="search", args={"q": "weather"}, tool_call_id="c"), + ), + ] + await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + assert len(messages.created) == 1 + assert messages.created[0]["content"].arguments == {"q": "weather"} + + async def test_tool_call_with_invalid_json_args_surfaces_raw( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + """Don't drop the tool call when the model emits malformed JSON args. + + The arguments field is preserved under ``_raw`` so the failure is + visible to the UI rather than silently truncated. + """ + _, messages = fake_adk + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="t", args=None, tool_call_id="c"), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name="t", args="not-json{", tool_call_id="c"), + ), + ] + await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + assert len(messages.created) == 1 + assert messages.created[0]["content"].arguments == {"_raw": "not-json{"} + + async def test_tool_call_with_none_args_defaults_to_empty_dict( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + _, messages = fake_adk + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="t", args=None, tool_call_id="c"), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name="t", args=None, tool_call_id="c"), + ), + ] + await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + assert len(messages.created) == 1 + assert messages.created[0]["content"].arguments == {} + + +class TestToolResult: + async def test_tool_return_emits_full_tool_response_message( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + _, messages = fake_adk + events = [ + FunctionToolResultEvent( + part=ToolReturnPart(tool_name="get_weather", content="Sunny, 72F", tool_call_id="c1"), + ), + ] + await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + assert len(messages.created) == 1 + content = messages.created[0]["content"] + assert isinstance(content, ToolResponseContent) + assert content.tool_call_id == "c1" + assert content.name == "get_weather" + assert content.content == "Sunny, 72F" + assert content.author == "agent" + + async def test_tool_return_with_non_string_content_stringifies( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + _, messages = fake_adk + events = [ + FunctionToolResultEvent( + part=ToolReturnPart(tool_name="t", content={"temp": 72, "sky": "clear"}, tool_call_id="c"), + ), + ] + await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + # The content is stringified; we just check the structured payload is + # still readable from the result. + out = messages.created[0]["content"].content + assert "72" in out and "clear" in out + + async def test_retry_prompt_part_surfaces_as_tool_response( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + _, messages = fake_adk + events = [ + FunctionToolResultEvent( + part=RetryPromptPart( + content="bad arguments", + tool_name="get_weather", + tool_call_id="c1", + ), + ), + ] + await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + assert len(messages.created) == 1 + content = messages.created[0]["content"] + assert isinstance(content, ToolResponseContent) + assert content.tool_call_id == "c1" + # RetryPromptPart.content stringifies to the error description + assert "bad arguments" in str(content.content) + + +class TestContextLifecycle: + async def test_text_then_tool_then_text_uses_separate_contexts_in_order( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + """End-to-end multi-step shape: text → tool call → tool result → more text. + + Each text/reasoning segment must get its own streaming context that is + closed before the next one opens, and tool messages must interleave + correctly via ``adk.messages.create``. + """ + streaming, messages = fake_adk + events = [ + # First model response: text + tool call. + PartStartEvent(index=0, part=TextPart(content="")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="Looking up...")), + PartEndEvent(index=0, part=TextPart(content="Looking up...")), + PartStartEvent( + index=1, + part=ToolCallPart(tool_name="get_weather", args=None, tool_call_id="c1"), + ), + PartEndEvent( + index=1, + part=ToolCallPart(tool_name="get_weather", args="{}", tool_call_id="c1"), + ), + FunctionToolResultEvent( + part=ToolReturnPart(tool_name="get_weather", content="Sunny", tool_call_id="c1"), + ), + # Second model response: more text. + PartStartEvent(index=0, part=TextPart(content="")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="It's sunny.")), + PartEndEvent(index=0, part=TextPart(content="It's sunny.")), + ] + final = await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + assert len(streaming.contexts) == 2, "One context per text part — tool calls don't open streaming contexts" + assert all(ctx.closed for ctx in streaming.contexts) + assert _text_deltas(streaming.contexts[0]) == ["Looking up..."] + assert _text_deltas(streaming.contexts[1]) == ["It's sunny."] + + # Two messages: tool request, then tool response — in that order. + assert [type(m["content"]).__name__ for m in messages.created] == [ + "ToolRequestContent", + "ToolResponseContent", + ] + assert messages.created[0]["content"].tool_call_id == "c1" + assert messages.created[1]["content"].tool_call_id == "c1" + assert final == "It's sunny." + + async def test_new_text_part_after_text_closes_previous( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + """Defensive: two text parts in a row (same response) must not bleed deltas across contexts.""" + streaming, _ = fake_adk + events = [ + PartStartEvent(index=0, part=TextPart(content="")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="A")), + PartStartEvent(index=1, part=TextPart(content="")), + PartDeltaEvent(index=1, delta=TextPartDelta(content_delta="B")), + PartEndEvent(index=1, part=TextPart(content="B")), + ] + await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + assert len(streaming.contexts) == 2 + # First context was closed when the second TextPart started. + assert streaming.contexts[0].closed is True + assert _text_deltas(streaming.contexts[0]) == ["A"] + assert _text_deltas(streaming.contexts[1]) == ["B"] + + async def test_reasoning_then_text_closes_reasoning_context( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + """Switching from a thinking part to a text part must close the reasoning context.""" + streaming, _ = fake_adk + events = [ + PartStartEvent(index=0, part=ThinkingPart(content="")), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta="think")), + PartStartEvent(index=1, part=TextPart(content="")), + PartDeltaEvent(index=1, delta=TextPartDelta(content_delta="answer")), + PartEndEvent(index=1, part=TextPart(content="answer")), + ] + await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + assert len(streaming.contexts) == 2 + # Reasoning context closed before text opened. + assert streaming.contexts[0].closed is True + assert isinstance(streaming.contexts[0].initial_content, ReasoningContent) + assert _reasoning_deltas(streaming.contexts[0]) == ["think"] + assert isinstance(streaming.contexts[1].initial_content, TextContent) + assert _text_deltas(streaming.contexts[1]) == ["answer"] + + async def test_tool_result_closes_any_open_streaming_context( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + """A tool result arriving while a text context is open must close that context first.""" + streaming, messages = fake_adk + events = [ + PartStartEvent(index=0, part=TextPart(content="")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="thinking")), + # No PartEndEvent — provider sends the tool result while text is "live". + FunctionToolResultEvent( + part=ToolReturnPart(tool_name="t", content="ok", tool_call_id="c"), + ), + ] + await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + assert streaming.contexts[0].closed is True, ( + "Helper must close any open streaming context before emitting a tool result message" + ) + assert len(messages.created) == 1 + + +class TestDeltaForOrphanIndexIgnored: + async def test_part_delta_without_matching_start_is_ignored( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + """A delta for an index we never saw a Start for must be a no-op, not a crash.""" + streaming, messages = fake_adk + events = [ + PartDeltaEvent(index=99, delta=TextPartDelta(content_delta="orphan")), + ] + final = await stream_pydantic_ai_events(_aiter(events), TASK_ID) + + assert streaming.contexts == [] + assert messages.created == [] + assert final == "" + + +class TestTracingHandler: + """Tracing handler hooks fire alongside streaming for each tool call.""" + + @dataclass + class _RecordingHandler: + starts: list[dict[str, Any]] = field(default_factory=list) + ends: list[dict[str, Any]] = field(default_factory=list) + + async def on_tool_start(self, tool_call_id: str, tool_name: str, arguments: Any) -> None: + self.starts.append({"tool_call_id": tool_call_id, "tool_name": tool_name, "arguments": arguments}) + + async def on_tool_end(self, tool_call_id: str, result: Any) -> None: + self.ends.append({"tool_call_id": tool_call_id, "result": result}) + + async def test_handler_records_start_and_end_for_each_tool_call( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + _, messages = fake_adk + handler = self._RecordingHandler() + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="get_weather", args=None, tool_call_id="c1"), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name="get_weather", args='{"city":"Paris"}', tool_call_id="c1"), + ), + FunctionToolResultEvent( + part=ToolReturnPart(tool_name="get_weather", content="Sunny", tool_call_id="c1"), + ), + ] + await stream_pydantic_ai_events( + _aiter(events), + TASK_ID, + tracing_handler=handler, # type: ignore[arg-type] + ) + + # Streaming side-effects still happen — tracing is additive. + assert [type(m["content"]).__name__ for m in messages.created] == [ + "ToolRequestContent", + "ToolResponseContent", + ] + # And both lifecycle hooks fired exactly once with the right payload. + assert handler.starts == [ + { + "tool_call_id": "c1", + "tool_name": "get_weather", + "arguments": {"city": "Paris"}, + } + ] + assert handler.ends == [{"tool_call_id": "c1", "result": "Sunny"}] + + async def test_handler_not_called_when_no_tool_calls_in_stream( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + handler = self._RecordingHandler() + events = [ + PartStartEvent(index=0, part=TextPart(content="")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="Hello")), + PartEndEvent(index=0, part=TextPart(content="Hello")), + ] + await stream_pydantic_ai_events( + _aiter(events), + TASK_ID, + tracing_handler=handler, # type: ignore[arg-type] + ) + assert handler.starts == [] + assert handler.ends == [] + + async def test_handler_records_each_tool_in_multi_tool_run( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + """A turn with two tool calls must produce two start/end pairs in order.""" + handler = self._RecordingHandler() + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="get_weather", args=None, tool_call_id="c1"), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name="get_weather", args="{}", tool_call_id="c1"), + ), + FunctionToolResultEvent( + part=ToolReturnPart(tool_name="get_weather", content="Sunny", tool_call_id="c1"), + ), + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="lookup_city", args=None, tool_call_id="c2"), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name="lookup_city", args="{}", tool_call_id="c2"), + ), + FunctionToolResultEvent( + part=ToolReturnPart(tool_name="lookup_city", content="Paris, FR", tool_call_id="c2"), + ), + ] + await stream_pydantic_ai_events( + _aiter(events), + TASK_ID, + tracing_handler=handler, # type: ignore[arg-type] + ) + + assert [s["tool_call_id"] for s in handler.starts] == ["c1", "c2"] + assert [e["tool_call_id"] for e in handler.ends] == ["c1", "c2"] + assert handler.starts[0]["tool_name"] == "get_weather" + assert handler.starts[1]["tool_name"] == "lookup_city" + + async def test_omitting_handler_is_a_no_op_for_existing_behavior( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + """Regression: passing no tracing handler preserves the pre-tracing behavior.""" + _, messages = fake_adk + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="get_weather", args=None, tool_call_id="c1"), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name="get_weather", args="{}", tool_call_id="c1"), + ), + FunctionToolResultEvent( + part=ToolReturnPart(tool_name="get_weather", content="Sunny", tool_call_id="c1"), + ), + ] + await stream_pydantic_ai_events(_aiter(events), TASK_ID) + # Exact same shape as before tracing existed. + assert [type(m["content"]).__name__ for m in messages.created] == [ + "ToolRequestContent", + "ToolResponseContent", + ] + + +class TestPydanticAITracingHandlerDeterministicIds: + """Regression coverage for ``AgentexPydanticAITracingHandler``. + + pydantic-ai's ``TemporalAgent`` splits a single agent run across several + Temporal activities. The event_stream_handler is invoked once per + activity, with a fresh handler instance each time. So ``on_tool_start`` + (during the model activity that issued the tool call) and ``on_tool_end`` + (during the next model activity, after the tool ran) end up in DIFFERENT + handler instances — an in-memory dict can't pair them. + + The fix is deterministic span IDs derived from ``(trace_id, tool_call_id)``. + These tests lock that in. + """ + + class _RecordingClient: + """Stand-in for ``AsyncAgentex`` capturing spans.create / spans.update calls.""" + + def __init__(self) -> None: + self.creates: list[dict[str, Any]] = [] + self.updates: list[tuple[str, dict[str, Any]]] = [] + self.spans = self # so .spans.create / .spans.update resolve back here + + async def create(self, **kwargs: Any) -> Any: + self.creates.append(kwargs) + return None + + async def update(self, span_id: str, **kwargs: Any) -> Any: + self.updates.append((span_id, kwargs)) + return None + + async def test_same_tool_call_id_yields_same_span_id_across_handler_instances( + self, + ) -> None: + """The whole point of the design: two handler instances with the same + trace_id and tool_call_id resolve to the same span ID — otherwise + ``on_tool_end`` patches a different (non-existent) record and the span + in the DB never gets ``end_time`` / ``output``.""" + from agentex.lib.adk._modules._pydantic_ai_tracing import ( + AgentexPydanticAITracingHandler, + ) + + client_a = self._RecordingClient() + client_b = self._RecordingClient() + + # Two independent handler instances — simulates the cross-activity + # invocation pattern in TemporalAgent. + handler_a = AgentexPydanticAITracingHandler( + trace_id="trace-1", + parent_span_id="parent-1", + task_id="task-1", + client=client_a, # type: ignore[arg-type] + ) + handler_b = AgentexPydanticAITracingHandler( + trace_id="trace-1", + parent_span_id="parent-1", + task_id="task-1", + client=client_b, # type: ignore[arg-type] + ) + + await handler_a.on_tool_start(tool_call_id="call_abc", tool_name="get_weather", arguments={"city": "Paris"}) + await handler_b.on_tool_end(tool_call_id="call_abc", result="Sunny, 72F") + + assert len(client_a.creates) == 1 + assert len(client_b.updates) == 1 + + created_span_id = client_a.creates[0]["id"] + updated_span_id = client_b.updates[0][0] + assert created_span_id == updated_span_id, ( + "on_tool_start and on_tool_end must address the same span across handler " + "instances; mismatch means tool spans will be left open and the AgentEx UI " + "will hide their trace." + ) + + async def test_different_tool_call_ids_yield_different_span_ids(self) -> None: + from agentex.lib.adk._modules._pydantic_ai_tracing import ( + AgentexPydanticAITracingHandler, + ) + + client = self._RecordingClient() + handler = AgentexPydanticAITracingHandler( + trace_id="trace-1", + client=client, # type: ignore[arg-type] + ) + + await handler.on_tool_start("call_a", "get_weather", {"city": "Paris"}) + await handler.on_tool_start("call_b", "get_weather", {"city": "Tokyo"}) + + ids = {c["id"] for c in client.creates} + assert len(ids) == 2, "Distinct tool_call_ids must map to distinct span IDs" + + async def test_same_tool_call_id_in_different_traces_yields_different_span_ids( + self, + ) -> None: + """Span IDs are namespaced by trace_id so two unrelated runs with the + same provider-issued tool_call_id don't collide.""" + from agentex.lib.adk._modules._pydantic_ai_tracing import ( + AgentexPydanticAITracingHandler, + ) + + client = self._RecordingClient() + handler_t1 = AgentexPydanticAITracingHandler(trace_id="trace-1", client=client) # type: ignore[arg-type] + handler_t2 = AgentexPydanticAITracingHandler(trace_id="trace-2", client=client) # type: ignore[arg-type] + + await handler_t1.on_tool_start("call_abc", "t", None) + await handler_t2.on_tool_start("call_abc", "t", None) + + ids = {c["id"] for c in client.creates} + assert len(ids) == 2 + + async def test_on_tool_end_patches_only_end_time_and_output(self) -> None: + """Don't overwrite start_time, name, parent_id, etc. on close — only patch + the fields we have new values for. Sending start_time again could clobber + what was set at create time.""" + from agentex.lib.adk._modules._pydantic_ai_tracing import ( + AgentexPydanticAITracingHandler, + ) + + client = self._RecordingClient() + handler = AgentexPydanticAITracingHandler(trace_id="trace-1", client=client) # type: ignore[arg-type] + + await handler.on_tool_end("call_abc", "Sunny") + + assert len(client.updates) == 1 + _, patch_kwargs = client.updates[0] + assert set(patch_kwargs.keys()) == {"end_time", "output"}, ( + f"Unexpected fields in tool span PATCH: {set(patch_kwargs.keys())}" + ) + assert patch_kwargs["output"] == {"result": "Sunny"} + + async def test_on_tool_error_patches_error_output(self) -> None: + from agentex.lib.adk._modules._pydantic_ai_tracing import ( + AgentexPydanticAITracingHandler, + ) + + client = self._RecordingClient() + handler = AgentexPydanticAITracingHandler(trace_id="trace-1", client=client) # type: ignore[arg-type] + + await handler.on_tool_error("call_abc", RuntimeError("boom")) + + assert len(client.updates) == 1 + _, patch_kwargs = client.updates[0] + assert "error" in patch_kwargs["output"] + assert "boom" in patch_kwargs["output"]["error"] + + +class TestCleanupOnException: + async def test_open_contexts_are_closed_on_iterator_failure( + self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule] + ) -> None: + """If the upstream Pydantic AI stream raises mid-flight, any open + streaming context must still be closed — otherwise the Agentex + ``messages.update(..., streaming_status="DONE")`` call never runs and + the UI shows a perma-streaming message.""" + streaming, _ = fake_adk + + async def boom() -> AsyncIterator[Any]: + yield PartStartEvent(index=0, part=TextPart(content="")) + yield PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="partial")) + raise RuntimeError("upstream provider exploded") + + with pytest.raises(RuntimeError, match="upstream provider exploded"): + await stream_pydantic_ai_events(boom(), TASK_ID) + + assert streaming.contexts[0].closed is True diff --git a/tests/lib/adk/test_pydantic_ai_sync.py b/tests/lib/adk/test_pydantic_ai_sync.py new file mode 100644 index 000000000..13a895f26 --- /dev/null +++ b/tests/lib/adk/test_pydantic_ai_sync.py @@ -0,0 +1,476 @@ +"""Tests for the Pydantic AI -> Agentex stream event converter.""" + +from __future__ import annotations + +import json +from typing import Any, AsyncIterator + +import pytest +from pydantic_ai.messages import ( + TextPart, + PartEndEvent, + ThinkingPart, + ToolCallPart, + TextPartDelta, + PartDeltaEvent, + PartStartEvent, + ToolReturnPart, + RetryPromptPart, + FinalResultEvent, + ThinkingPartDelta, + ToolCallPartDelta, + FunctionToolCallEvent, + FunctionToolResultEvent, +) + +from agentex.types.task_message_delta import TextDelta +from agentex.types.tool_request_delta import ToolRequestDelta +from agentex.types.task_message_update import ( + StreamTaskMessageDone, + StreamTaskMessageFull, + StreamTaskMessageDelta, + StreamTaskMessageStart, +) +from agentex.types.task_message_content import TextContent +from agentex.types.tool_request_content import ToolRequestContent +from agentex.types.tool_response_content import ToolResponseContent +from agentex.types.reasoning_content_delta import ReasoningContentDelta +from agentex.lib.adk._modules._pydantic_ai_sync import ( + _args_delta_to_str, + convert_pydantic_ai_to_agentex_events, +) + + +async def _aiter(events: list[Any]) -> AsyncIterator[Any]: + for e in events: + yield e + + +async def _collect(stream: AsyncIterator[Any]) -> list[Any]: + return [e async for e in stream] + + +class TestArgsDeltaToStr: + def test_none(self): + assert _args_delta_to_str(None) == "" + + def test_string_passthrough(self): + assert _args_delta_to_str('{"k":') == '{"k":' + + def test_dict_dumps_json(self): + assert json.loads(_args_delta_to_str({"city": "Paris"})) == {"city": "Paris"} + + +class TestTextStreaming: + async def test_plain_text_emits_start_deltas_done(self): + events = [ + PartStartEvent(index=0, part=TextPart(content="")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="Hello")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=", ")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="world!")), + PartEndEvent(index=0, part=TextPart(content="Hello, world!")), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + + assert len(out) == 5 + assert isinstance(out[0], StreamTaskMessageStart) + assert isinstance(out[0].content, TextContent) + assert out[0].content.content == "" + assert out[0].index == 0 + + for i, expected in enumerate(["Hello", ", ", "world!"], start=1): + assert isinstance(out[i], StreamTaskMessageDelta) + assert isinstance(out[i].delta, TextDelta) + assert out[i].delta.text_delta == expected + assert out[i].index == 0 + + assert isinstance(out[4], StreamTaskMessageDone) + assert out[4].index == 0 + + async def test_text_with_initial_content_emits_delta(self): + """Pydantic AI puts the first streaming chunk in PartStartEvent.part.content. + + The Agentex protocol only renders Delta events as the message body, so we + must emit the initial content as a Delta — not in the Start — otherwise + the first chunk disappears from the visible message. + """ + events = [ + PartStartEvent(index=0, part=TextPart(content="Already there")), + PartEndEvent(index=0, part=TextPart(content="Already there")), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert isinstance(out[0], StreamTaskMessageStart) + assert isinstance(out[0].content, TextContent) + assert out[0].content.content == "" + assert isinstance(out[1], StreamTaskMessageDelta) + assert isinstance(out[1].delta, TextDelta) + assert out[1].delta.text_delta == "Already there" + + +class TestThinkingStreaming: + async def test_thinking_emits_reasoning_deltas(self): + events = [ + PartStartEvent(index=0, part=ThinkingPart(content="")), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta="step 1...")), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=" step 2.")), + PartEndEvent(index=0, part=ThinkingPart(content="step 1... step 2.")), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + + assert isinstance(out[0], StreamTaskMessageStart) + assert isinstance(out[1], StreamTaskMessageDelta) + assert isinstance(out[1].delta, ReasoningContentDelta) + assert out[1].delta.content_delta == "step 1..." + assert out[1].delta.content_index == 0 + assert isinstance(out[2].delta, ReasoningContentDelta) + assert out[2].delta.content_delta == " step 2." + assert isinstance(out[3], StreamTaskMessageDone) + + async def test_thinking_with_initial_content_emits_delta(self): + events = [ + PartStartEvent(index=0, part=ThinkingPart(content="seed reasoning")), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert isinstance(out[0], StreamTaskMessageStart) + assert isinstance(out[1], StreamTaskMessageDelta) + assert isinstance(out[1].delta, ReasoningContentDelta) + assert out[1].delta.content_delta == "seed reasoning" + + async def test_thinking_delta_skipped_when_empty(self): + events = [ + PartStartEvent(index=0, part=ThinkingPart(content="")), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=None)), + PartEndEvent(index=0, part=ThinkingPart(content="")), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert len(out) == 2 # Start + Done; no delta for None content + + +class TestToolCallStreaming: + async def test_tool_call_streamed_token_by_token(self): + """The headline use case: tool-call argument tokens streaming through to the client.""" + events = [ + PartStartEvent( + index=1, + part=ToolCallPart(tool_name="get_weather", args=None, tool_call_id="call_abc"), + ), + PartDeltaEvent( + index=1, + delta=ToolCallPartDelta(args_delta='{"city":', tool_call_id="call_abc"), + ), + PartDeltaEvent(index=1, delta=ToolCallPartDelta(args_delta='"Paris"}')), + PartEndEvent( + index=1, + part=ToolCallPart(tool_name="get_weather", args='{"city":"Paris"}', tool_call_id="call_abc"), + ), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + + assert len(out) == 4 + assert isinstance(out[0], StreamTaskMessageStart) + assert isinstance(out[0].content, ToolRequestContent) + assert out[0].content.tool_call_id == "call_abc" + assert out[0].content.name == "get_weather" + assert out[0].content.arguments == {} + + assert isinstance(out[1].delta, ToolRequestDelta) + assert out[1].delta.tool_call_id == "call_abc" + assert out[1].delta.name == "get_weather" + assert out[1].delta.arguments_delta == '{"city":' + + assert isinstance(out[2].delta, ToolRequestDelta) + assert out[2].delta.arguments_delta == '"Paris"}' + # tool_call_id is carried forward from the start even when the delta omits it + assert out[2].delta.tool_call_id == "call_abc" + + assert isinstance(out[3], StreamTaskMessageDone) + + async def test_tool_call_with_full_args_at_start(self): + """Some providers return a tool call in one shot — args dict is set at start.""" + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="search", args={"query": "weather"}, tool_call_id="call_xyz"), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name="search", args={"query": "weather"}, tool_call_id="call_xyz"), + ), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert isinstance(out[0], StreamTaskMessageStart) + assert isinstance(out[0].content, ToolRequestContent) + assert out[0].content.arguments == {"query": "weather"} + # No deltas emitted — args were already complete. + assert len(out) == 2 + assert isinstance(out[1], StreamTaskMessageDone) + + async def test_tool_call_with_full_args_string_at_start(self): + """When args is a complete JSON string at start, surface it as a single delta.""" + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="search", args='{"query":"weather"}', tool_call_id="call_z"), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name="search", args='{"query":"weather"}', tool_call_id="call_z"), + ), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert isinstance(out[0], StreamTaskMessageStart) + assert isinstance(out[0].content, ToolRequestContent) + assert out[0].content.arguments == {} + assert isinstance(out[1], StreamTaskMessageDelta) + assert isinstance(out[1].delta, ToolRequestDelta) + assert out[1].delta.arguments_delta == '{"query":"weather"}' + + async def test_tool_call_dict_args_delta_serialized(self): + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="t", args=None, tool_call_id="cid"), + ), + PartDeltaEvent( + index=0, + delta=ToolCallPartDelta(args_delta={"k": "v"}, tool_call_id="cid"), + ), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert json.loads(out[1].delta.arguments_delta) == {"k": "v"} + + async def test_tool_result_emits_full(self): + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="get_weather", args=None, tool_call_id="call_abc"), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name="get_weather", args="{}", tool_call_id="call_abc"), + ), + FunctionToolResultEvent( + part=ToolReturnPart(tool_name="get_weather", content="Sunny, 72F", tool_call_id="call_abc"), + ), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + + # Last event is the tool result -> Full ToolResponseContent + assert isinstance(out[-1], StreamTaskMessageFull) + assert isinstance(out[-1].content, ToolResponseContent) + assert out[-1].content.tool_call_id == "call_abc" + assert out[-1].content.name == "get_weather" + assert out[-1].content.content == "Sunny, 72F" + + async def test_tool_retry_prompt_surfaces_as_response(self): + events = [ + FunctionToolResultEvent( + part=RetryPromptPart( + content="bad arguments", + tool_name="get_weather", + tool_call_id="call_abc", + ), + ), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert isinstance(out[0], StreamTaskMessageFull) + assert isinstance(out[0].content, ToolResponseContent) + assert out[0].content.tool_call_id == "call_abc" + assert out[0].content.name == "get_weather" + # RetryPromptPart's content is the error message + assert out[0].content.content == "bad arguments" + + +class TestTracingHandlerSync: + """The sync converter has the same opt-in tracing-handler contract as the + async streamer: pass a handler and the converter calls ``on_tool_start`` / + ``on_tool_end`` for each tool call. Streaming yields are unchanged when + omitted.""" + + class _RecordingHandler: + def __init__(self) -> None: + self.starts: list[dict[str, Any]] = [] + self.ends: list[dict[str, Any]] = [] + + async def on_tool_start(self, tool_call_id: str, tool_name: str, arguments: Any) -> None: + self.starts.append({"tool_call_id": tool_call_id, "tool_name": tool_name, "arguments": arguments}) + + async def on_tool_end(self, tool_call_id: str, result: Any) -> None: + self.ends.append({"tool_call_id": tool_call_id, "result": result}) + + async def test_handler_records_start_and_end_for_a_tool_call(self): + handler = self._RecordingHandler() + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="get_weather", args=None, tool_call_id="c1"), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name="get_weather", args='{"city":"Paris"}', tool_call_id="c1"), + ), + FunctionToolResultEvent( + part=ToolReturnPart(tool_name="get_weather", content="Sunny", tool_call_id="c1"), + ), + ] + out = await _collect( + convert_pydantic_ai_to_agentex_events(_aiter(events), tracing_handler=handler) # type: ignore[arg-type] + ) + + # Streaming output is unchanged. + assert any(isinstance(e, StreamTaskMessageStart) for e in out) + assert any(isinstance(e, StreamTaskMessageFull) for e in out) + + assert handler.starts == [ + { + "tool_call_id": "c1", + "tool_name": "get_weather", + "arguments": {"city": "Paris"}, + } + ] + assert handler.ends == [{"tool_call_id": "c1", "result": "Sunny"}] + + async def test_handler_not_called_when_no_tool_calls(self): + handler = self._RecordingHandler() + events = [ + PartStartEvent(index=0, part=TextPart(content="")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="hi")), + PartEndEvent(index=0, part=TextPart(content="hi")), + ] + await _collect( + convert_pydantic_ai_to_agentex_events(_aiter(events), tracing_handler=handler) # type: ignore[arg-type] + ) + assert handler.starts == [] + assert handler.ends == [] + + async def test_omitting_handler_preserves_pre_tracing_behavior(self): + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="t", args=None, tool_call_id="c"), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name="t", args="{}", tool_call_id="c"), + ), + FunctionToolResultEvent( + part=ToolReturnPart(tool_name="t", content="ok", tool_call_id="c"), + ), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + # Same emit shape as before: Start, Done, Full + types = [type(e).__name__ for e in out] + assert "StreamTaskMessageStart" in types + assert "StreamTaskMessageDone" in types + assert "StreamTaskMessageFull" in types + + +class TestMultiStepRun: + async def test_text_then_tool_then_text_assigns_distinct_indices(self): + """A multi-step run: model emits text + tool call → tool runs → model emits more text. + + Pydantic AI restarts part indices at 0 for each new model response, so + the converter must assign fresh Agentex message indices. + """ + events = [ + # First model response: text at index 0, tool call at index 1 + PartStartEvent(index=0, part=TextPart(content="")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="Looking up...")), + PartEndEvent(index=0, part=TextPart(content="Looking up...")), + PartStartEvent( + index=1, + part=ToolCallPart(tool_name="get_weather", args=None, tool_call_id="c1"), + ), + PartDeltaEvent(index=1, delta=ToolCallPartDelta(args_delta="{}")), + PartEndEvent(index=1, part=ToolCallPart(tool_name="get_weather", args="{}", tool_call_id="c1")), + FunctionToolResultEvent( + part=ToolReturnPart(tool_name="get_weather", content="Sunny", tool_call_id="c1"), + ), + # Second model response: text restarts at index 0 + PartStartEvent(index=0, part=TextPart(content="")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="It's sunny.")), + PartEndEvent(index=0, part=TextPart(content="It's sunny.")), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + + # Pull every Start/Full event and check their assigned message indices + anchors = [e for e in out if isinstance(e, (StreamTaskMessageStart, StreamTaskMessageFull))] + indices = [e.index for e in anchors] + assert indices == [0, 1, 2, 3], ( + f"Expected 4 distinct, monotonic message indices for: text1, tool_call, tool_result, text2 — got {indices}" + ) + + # And the second text's deltas should target the second text's message index. + text2_start = anchors[3] + text2_deltas = [ + e + for e in out + if isinstance(e, StreamTaskMessageDelta) and isinstance(e.delta, TextDelta) and e.index == text2_start.index + ] + assert len(text2_deltas) == 1 + text2_delta = text2_deltas[0].delta + assert isinstance(text2_delta, TextDelta) + assert text2_delta.text_delta == "It's sunny." + + +class TestIgnoredEvents: + async def test_function_tool_call_event_is_ignored(self): + """FunctionToolCallEvent is redundant with PartStart+Delta+End and should be skipped.""" + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="t", args=None, tool_call_id="c"), + ), + FunctionToolCallEvent( + part=ToolCallPart(tool_name="t", args="{}", tool_call_id="c"), + ), + PartEndEvent(index=0, part=ToolCallPart(tool_name="t", args="{}", tool_call_id="c")), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + # Start + Done only — no event from FunctionToolCallEvent + assert len(out) == 2 + assert isinstance(out[0], StreamTaskMessageStart) + assert isinstance(out[1], StreamTaskMessageDone) + + async def test_final_result_event_ignored(self): + events = [ + FinalResultEvent(tool_name=None, tool_call_id=None), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert out == [] + + async def test_unknown_part_index_delta_skipped(self): + events = [ + PartDeltaEvent(index=99, delta=TextPartDelta(content_delta="orphan")), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert out == [] + + +class TestStartingTextMatchesAuthor: + """Sanity check that all emitted content is authored by the agent.""" + + @pytest.mark.parametrize( + "events", + [ + [PartStartEvent(index=0, part=TextPart(content=""))], + [PartStartEvent(index=0, part=ThinkingPart(content=""))], + [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="t", args=None, tool_call_id="c"), + ) + ], + [ + FunctionToolResultEvent( + part=ToolReturnPart(tool_name="t", content="ok", tool_call_id="c"), + ) + ], + ], + ) + async def test_author_is_agent(self, events: list[Any]): + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + for e in out: + content = getattr(e, "content", None) + if content is not None and hasattr(content, "author"): + assert content.author == "agent" diff --git a/tests/lib/core/services/adk/test_streaming.py b/tests/lib/core/services/adk/test_streaming.py index 8b5fe9a35..a828df224 100644 --- a/tests/lib/core/services/adk/test_streaming.py +++ b/tests/lib/core/services/adk/test_streaming.py @@ -393,15 +393,20 @@ async def on_flush(u: StreamTaskMessageDelta) -> None: assert flushed == [] -class TestCoalescingBufferCancelDuringFlush: +class TestCoalescingBufferCloseDuringFlush: @pytest.mark.asyncio - async def test_cancel_during_flush_recovers_remaining_items( + async def test_close_during_flush_is_exactly_once( self, task_message: TaskMessage ) -> None: - """Regression: when ``close()`` cancels the ticker mid-flush, items in - the local ``drained`` list must be re-enqueued so the final drain in - ``close()`` can recover them. Otherwise the last coalesced batch is - silently dropped — visible to consumers as a truncated stream. + """Regression: ``close()`` while the ticker is mid-flush must publish + each delta exactly once — no loss, no duplicate. + + The earlier implementation cancelled the ticker task during ``close()`` + and re-enqueued the in-flight item to avoid silent loss; that produced + a duplicated tail on the Redis stream when the Redis write had in fact + completed before the cancellation landed. The current implementation + signals the ticker to exit naturally after its next drain pass, which + gives exactly-once delivery without the duplication. """ flushed: list[StreamTaskMessageDelta] = [] first_started = asyncio.Event() @@ -411,8 +416,8 @@ async def slow_flush(u: StreamTaskMessageDelta) -> None: flushed.append(u) if len(flushed) == 1: first_started.set() - # Block the first publish until the test releases it. This - # guarantees the cancellation lands inside the flush loop. + # Block the first publish until the test releases it; this + # parks close() inside the ticker's flush loop. await first_continue.wait() buf = CoalescingBuffer(on_flush=slow_flush) @@ -425,22 +430,22 @@ async def slow_flush(u: StreamTaskMessageDelta) -> None: await asyncio.wait_for(first_started.wait(), timeout=2.0) # Trigger close() while the first flush is blocked, then release it. close_task = asyncio.create_task(buf.close()) + # Give close() a tick to set _closed and start awaiting the ticker. + await asyncio.sleep(0) first_continue.set() await close_task - # All five chunks must appear at least once across all publishes. - # (The first-flushed item may duplicate; that's the documented - # trade-off — duplicate > silent loss.) full = "".join( u.delta.text_delta or "" for u in flushed if isinstance(u.delta, TextDelta) ) - for i in range(5): - assert f"chunk{i}" in full, ( - f"chunk{i} missing — silent data loss across cancel-during-flush boundary. " - f"flushed payloads: {[u.delta.text_delta for u in flushed if isinstance(u.delta, TextDelta)]}" - ) + # Exactly the five chunks, in order, with no duplication of any + # chunk's tail. + assert full == "chunk0chunk1chunk2chunk3chunk4", ( + f"expected exactly-once delivery; got: {full!r} " + f"(payloads: {[u.delta.text_delta for u in flushed if isinstance(u.delta, TextDelta)]})" + ) class TestStreamingTaskMessageContextModes: