Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions python/packages/core/agent_framework/_harness/_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,26 @@ async def consolidate_memories() -> str:
)
if recent_history_messages:
context.extend_messages(self.source_id, recent_history_messages)

# Surface cross-session origin so downstream context observers can
# distinguish injected memory from content native to the current
# session. Loaded topic files may carry contributions from earlier
# sessions (tracked in ``MemoryTopicRecord.session_ids``); when any
# contributor differs from the current session, mark the injected
# block accordingly. See the ``CrossSessionObserver`` sample under
# ``samples/governance/cross_session_observer`` for an example
# subscriber, and the attribution mechanism on
Comment on lines +1315 to +1317
# ``SessionContext.extend_messages``.
current_session_id = context.session_id
cross_session_origin: str | None = None
for record in selected_topics:
for contributor in record.session_ids:
if contributor and contributor != current_session_id:
cross_session_origin = contributor
break
if cross_session_origin is not None:
break
Comment on lines +1320 to +1327

context.extend_messages(
self.source_id,
[
Expand All @@ -1322,6 +1342,7 @@ async def consolidate_memories() -> str:
],
)
],
origin_session_id=cross_session_origin,
)

async def after_run(
Expand Down
20 changes: 19 additions & 1 deletion python/packages/core/agent_framework/_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,13 @@ def response(self) -> AgentResponse | None:
"""The agent's response. Set by the framework after invocation, read-only for providers."""
return self._response

def extend_messages(self, source: str | object, messages: Sequence[Message]) -> None:
def extend_messages(
self,
source: str | object,
messages: Sequence[Message],
*,
origin_session_id: str | None = None,
) -> None:
"""Add context messages from a specific source.

Messages are copied before attribution is added, so the caller's
Expand All @@ -229,13 +235,25 @@ def extend_messages(self, source: str | object, messages: Sequence[Message]) ->
object is passed, its class name is recorded as
``source_type`` in the attribution.
messages: The messages to add.
origin_session_id: Optional session_id that originally produced
these messages, when different from the current session. Set
by providers that inject content stored under a different
session than the requesting one (cross-session memory). The
value is exposed under ``additional_properties["_attribution"]
["origin_session_id"]`` so downstream context observers can
detect cross-session content for governance, audit, or
behavioral-analysis purposes. Omit (default) when content
originates in the current session — absence of the field is
semantically equivalent to "no origin information."
Comment on lines +242 to +247
"""
if isinstance(source, str):
source_id = source
attribution: dict[str, str] = {"source_id": source_id}
else:
source_id = source.source_id # type: ignore[attr-defined]
attribution = {"source_id": source_id, "source_type": type(source).__name__}
if origin_session_id is not None:
attribution["origin_session_id"] = origin_session_id

copied: list[Message] = []
for message in messages:
Expand Down
88 changes: 88 additions & 0 deletions python/packages/core/tests/core/test_harness_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,94 @@ async def test_memory_context_provider_recent_turns_can_skip_tool_call_groups(tm
assert with_tools_messages[4].text == "Second final answer"


async def test_memory_context_provider_marks_cross_session_origin(tmp_path) -> None:
"""Injected memory should carry origin_session_id when topics originate in a prior session.

Exercises the cross-session attribution surface added to support downstream observers
detecting attacks of the class documented in Dai et al. (arXiv:2605.06158).
"""
session = AgentSession(session_id="session-current")
session.state["owner_id"] = "alice"
store = MemoryFileStore(
tmp_path,
owner_state_key="owner_id",
dumps=lambda value: json.dumps(value, separators=(",", ":"), sort_keys=True),
loads=json.loads,
)
updated_at = datetime(2026, 4, 21, tzinfo=timezone.utc).replace(microsecond=0).isoformat()
store.write_topic(
session,
MemoryTopicRecord(
topic="travel preferences",
summary="Loves Oslo trips.",
memories=["Prefers Oslo in summer."],
updated_at=updated_at,
session_ids=["session-prior"],
),
source_id=DEFAULT_MEMORY_SOURCE_ID,
)

agent = Agent(
client=_MemoryHarnessClient(),
context_providers=[MemoryContextProvider(store=store)],
default_options={"store": False},
)

session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage]
session=session,
input_messages=[Message(role="user", contents=["Tell me about my travel preferences."])],
)

memory_messages = [
m for m in session_context.context_messages.get(DEFAULT_MEMORY_SOURCE_ID, []) if "### MEMORY.md" in m.text
]
assert memory_messages, "expected an injected memory block under the memory source"
attribution = memory_messages[0].additional_properties.get("_attribution") or {}
assert attribution.get("origin_session_id") == "session-prior"


async def test_memory_context_provider_omits_origin_when_only_current_session(tmp_path) -> None:
"""When all contributing topics are from the current session, attribution must NOT advertise an origin."""
session = AgentSession(session_id="session-current")
session.state["owner_id"] = "alice"
store = MemoryFileStore(
tmp_path,
owner_state_key="owner_id",
dumps=lambda value: json.dumps(value, separators=(",", ":"), sort_keys=True),
loads=json.loads,
)
updated_at = datetime(2026, 4, 21, tzinfo=timezone.utc).replace(microsecond=0).isoformat()
store.write_topic(
session,
MemoryTopicRecord(
topic="travel preferences",
summary="Loves Oslo trips.",
memories=["Prefers Oslo in summer."],
updated_at=updated_at,
session_ids=["session-current"],
),
source_id=DEFAULT_MEMORY_SOURCE_ID,
)

agent = Agent(
client=_MemoryHarnessClient(),
context_providers=[MemoryContextProvider(store=store)],
default_options={"store": False},
)

session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage]
session=session,
input_messages=[Message(role="user", contents=["Tell me about my travel preferences."])],
)

memory_messages = [
m for m in session_context.context_messages.get(DEFAULT_MEMORY_SOURCE_ID, []) if "### MEMORY.md" in m.text
]
assert memory_messages
attribution = memory_messages[0].additional_properties.get("_attribution") or {}
assert "origin_session_id" not in attribution


async def test_memory_context_provider_uses_explicit_consolidation_client(tmp_path) -> None:
"""The memory provider should use the explicit consolidation client when one is configured."""
session = AgentSession(session_id="session-1")
Expand Down
33 changes: 33 additions & 0 deletions python/packages/core/tests/core/test_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,39 @@ class MyProvider:
stored = ctx.context_messages["rag"][0]
assert stored.additional_properties["_attribution"] == {"source_id": "rag", "source_type": "MyProvider"}

def test_extend_messages_origin_session_id_default_omits_field(self) -> None:
ctx = SessionContext(input_messages=[])
msg = Message(role="system", contents=["ctx"])
ctx.extend_messages("rag", [msg])
stored = ctx.context_messages["rag"][0]
# Default (no origin_session_id passed) preserves the historical attribution shape
# so observers can distinguish "no origin info" from "explicit cross-session marker."
assert "origin_session_id" not in stored.additional_properties["_attribution"]

def test_extend_messages_origin_session_id_recorded_on_attribution(self) -> None:
ctx = SessionContext(session_id="current", input_messages=[])
msg = Message(role="system", contents=["loaded from a prior session"])
ctx.extend_messages("memory_provider", [msg], origin_session_id="prior-session-id")
stored = ctx.context_messages["memory_provider"][0]
assert stored.additional_properties["_attribution"] == {
"source_id": "memory_provider",
"origin_session_id": "prior-session-id",
}

def test_extend_messages_origin_session_id_with_provider_object(self) -> None:
class MyMemoryProvider:
source_id = "memory"

ctx = SessionContext(session_id="current", input_messages=[])
msg = Message(role="assistant", contents=["consolidated memory content"])
ctx.extend_messages(MyMemoryProvider(), [msg], origin_session_id="prior")
stored = ctx.context_messages["memory"][0]
assert stored.additional_properties["_attribution"] == {
"source_id": "memory",
"source_type": "MyMemoryProvider",
"origin_session_id": "prior",
}

def test_extend_instructions_string(self) -> None:
ctx = SessionContext(input_messages=[])
ctx.extend_instructions("sys", "Be helpful")
Expand Down
4 changes: 4 additions & 0 deletions python/samples/02-agents/context_providers/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@ These samples demonstrate how to use context providers to enrich agent conversat
| File / Folder | Description |
|---------------|-------------|
| [`simple_context_provider.py`](simple_context_provider.py) | Implement a custom context provider by extending `ContextProvider` to extract and inject structured user information across turns. |
| [`cross_session_observer.py`](cross_session_observer.py) | Detect injected context messages whose origin differs from the current session, via the `_attribution["origin_session_id"]` field. Self-contained — no LLM credentials required. |
| [`azure_ai_foundry_memory.py`](azure_ai_foundry_memory.py) | Use `FoundryMemoryProvider` to add semantic memory — automatically retrieves, searches, and stores memories via Azure AI Foundry. |
| [`azure_ai_search/`](azure_ai_search/) | Retrieval Augmented Generation (RAG) with Azure AI Search in semantic and agentic modes. See its own [README](azure_ai_search/README.md). |
| [`mem0/`](mem0/) | Memory-powered context using the Mem0 integration (open-source and managed). See its own [README](mem0/README.md). |
| [`redis/`](redis/) | Redis-backed context providers for conversation memory and sessions. See its own [README](redis/README.md). |

## Prerequisites

**For `cross_session_observer.py`:**
- No external dependencies; runs against in-memory `SessionContext`.

**For `simple_context_provider.py`:**
- `FOUNDRY_PROJECT_ENDPOINT`: Your Azure AI Foundry project endpoint
- `FOUNDRY_MODEL`: Model deployment name
Expand Down
141 changes: 141 additions & 0 deletions python/samples/02-agents/context_providers/cross_session_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
from collections.abc import Mapping
from typing import Any, Callable, cast

from agent_framework import AgentSession, ContextProvider, Message, SessionContext

"""This sample demonstrates how to detect cross-session memory injection.

When a context provider injects messages from a different ``session_id`` than
the requesting one — the legitimate cross-session memory use case (consolidated
memories, Mem0 with default scope, shared knowledge bases) — the framework
records the originating session under
``message.additional_properties["_attribution"]["origin_session_id"]``.

Downstream context observers can subscribe to this signal for governance,
audit, and behavioral analysis purposes. This is useful for defending against
the stateful-agent-backdoor attack class documented in Dai et al.,
arXiv:2605.06158, in which an adversary chains sub-backdoors across sessions
under permission isolation via persisted memory state.

The sample is self-contained: it constructs ``SessionContext`` directly and
invokes provider lifecycle methods manually, so no LLM credentials are
required to run it.
"""


class CrossSessionObserver(ContextProvider):
"""Detect injected context messages whose origin differs from the current session.

Subscribes via the standard ``ContextProvider`` pipeline. In ``before_run``,
walks the accumulated context messages and invokes a user-supplied
callback for each message whose ``_attribution["origin_session_id"]``
is present and differs from the current ``session_id``.

The callback receives the source_id that injected the content, the
originating session_id, the current session_id, and the message itself.
Use it to log, alert, increment metrics, or enforce policy — the observer
itself only surfaces the signal, leaving the response policy to the caller.
"""

DEFAULT_SOURCE_ID = "cross_session_observer"

def __init__(
self,
on_cross_session_access: Callable[[str, str, str | None, Message], None],
*,
source_id: str = DEFAULT_SOURCE_ID,
) -> None:
"""Initialize the observer.

Args:
on_cross_session_access: Callback invoked for each detected
cross-session message. Signature is
``(source_id, origin_session_id, current_session_id, message)``.
source_id: Unique identifier for this observer instance.
"""
super().__init__(source_id)
self._on_cross_session_access = on_cross_session_access

async def before_run(
self,
*,
agent: Any,
session: AgentSession | None,
context: SessionContext,
state: dict[str, Any],
) -> None:
"""Inspect accumulated context messages for cross-session origin."""
current_session_id = context.session_id
for source_id, messages in context.context_messages.items():
if source_id == self.source_id:
continue
for message in messages:
attribution_raw = message.additional_properties.get("_attribution")
if not isinstance(attribution_raw, Mapping):
continue
attribution = cast(Mapping[str, Any], attribution_raw)
origin = attribution.get("origin_session_id")
if isinstance(origin, str) and origin != current_session_id:
self._on_cross_session_access(source_id, origin, current_session_id, message)


def _on_detected(source_id: str, origin: str, current: str | None, message: Message) -> None:
"""Sample callback that logs cross-session detections to stdout."""
preview = " ".join(message.text.split())[:80]
print(
f"[cross-session detected] source={source_id!r} "
f"origin_session={origin!r} current_session={current!r} "
f"preview={preview!r}"
)


async def main() -> None:
"""Demonstrate the observer firing on cross-session injection."""
observer = CrossSessionObserver(_on_detected)

# --- Case 1: same-session injection (observer should be silent) ---
same_session_context = SessionContext(
session_id="session-A",
input_messages=[Message("user", ["What did we discuss last time?"])],
)
# Simulate a same-session provider injecting same-session history. Omitting
# origin_session_id is semantically "no origin info" — observers treat as
# equivalent to same-session for backward compatibility.
same_session_context.extend_messages(
"history_provider",
[Message("assistant", ["We talked about Q3 revenue projections."])],
)
await observer.before_run(
agent=None,
session=None,
context=same_session_context,
state={},
)
print("--- Same-session case complete (no detections expected above) ---\n")

# --- Case 2: cross-session injection (observer should fire) ---
cross_session_context = SessionContext(
session_id="session-B",
input_messages=[Message("user", ["Continue from where we left off."])],
)
# Simulate a cross-session memory provider injecting content originally
# written in session-A while we're now running in session-B.
cross_session_context.extend_messages(
"memory_provider",
[Message("assistant", ["Remember: API key for prod is sk-7f3a... (from session-A)."])],
origin_session_id="session-A",
)
await observer.before_run(
agent=None,
session=None,
context=cross_session_context,
state={},
)
print("--- Cross-session case complete (one detection expected above) ---")


if __name__ == "__main__":
asyncio.run(main())
Loading