diff --git a/python/packages/core/agent_framework/_harness/_memory.py b/python/packages/core/agent_framework/_harness/_memory.py index 92e060f442..9fdb914917 100644 --- a/python/packages/core/agent_framework/_harness/_memory.py +++ b/python/packages/core/agent_framework/_harness/_memory.py @@ -1306,6 +1306,26 @@ async def consolidate_memories() -> str: ) if recent_history_messages: context.extend_messages(self.source_id, recent_history_messages) + + # Surface cross-session origin so downstream context observers can + # distinguish injected memory from content native to the current + # session. Loaded topic files may carry contributions from earlier + # sessions (tracked in ``MemoryTopicRecord.session_ids``); when any + # contributor differs from the current session, mark the injected + # block accordingly. See the ``CrossSessionObserver`` sample under + # ``samples/02-agents/context_providers/cross_session_observer.py`` for an example + # subscriber, and the attribution mechanism on + # ``SessionContext.extend_messages``. + current_session_id = context.session_id + cross_session_origin: str | None = None + for record in selected_topics: + for contributor in record.session_ids: + if contributor and contributor != current_session_id: + cross_session_origin = contributor + break + if cross_session_origin is not None: + break + context.extend_messages( self.source_id, [ @@ -1322,6 +1342,7 @@ async def consolidate_memories() -> str: ], ) ], + origin_session_id=cross_session_origin, ) async def after_run( diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index be4d4ea285..322cd09f4e 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -214,7 +214,13 @@ def response(self) -> AgentResponse | None: """The agent's response. 
Set by the framework after invocation, read-only for providers.""" return self._response - def extend_messages(self, source: str | object, messages: Sequence[Message]) -> None: + def extend_messages( + self, + source: str | object, + messages: Sequence[Message], + *, + origin_session_id: str | None = None, + ) -> None: """Add context messages from a specific source. Messages are copied before attribution is added, so the caller's @@ -229,6 +235,16 @@ def extend_messages(self, source: str | object, messages: Sequence[Message]) -> object is passed, its class name is recorded as ``source_type`` in the attribution. messages: The messages to add. + origin_session_id: Optional session_id that originally produced + these messages, when different from the current session. Set + by providers that inject content stored under a different + session than the requesting one (cross-session memory). The + value is exposed under ``additional_properties["_attribution"] + ["origin_session_id"]`` so downstream context observers can + detect cross-session content for governance, audit, or + behavioral-analysis purposes. Omit (default) when content + originates in the current session — absence of the field is + semantically equivalent to "no origin information." 
""" if isinstance(source, str): source_id = source @@ -236,6 +252,8 @@ def extend_messages(self, source: str | object, messages: Sequence[Message]) -> else: source_id = source.source_id # type: ignore[attr-defined] attribution = {"source_id": source_id, "source_type": type(source).__name__} + if origin_session_id is not None: + attribution["origin_session_id"] = origin_session_id copied: list[Message] = [] for message in messages: diff --git a/python/packages/core/tests/core/test_harness_memory.py b/python/packages/core/tests/core/test_harness_memory.py index 9d4c7c71d1..4b86467d9f 100644 --- a/python/packages/core/tests/core/test_harness_memory.py +++ b/python/packages/core/tests/core/test_harness_memory.py @@ -500,6 +500,94 @@ async def test_memory_context_provider_recent_turns_can_skip_tool_call_groups(tm assert with_tools_messages[4].text == "Second final answer" +async def test_memory_context_provider_marks_cross_session_origin(tmp_path) -> None: + """Injected memory should carry origin_session_id when topics originate in a prior session. + + Exercises the cross-session attribution surface added to support downstream observers + detecting attacks of the class documented in Dai et al. (arXiv:2605.06158). 
+ """ + session = AgentSession(session_id="session-current") + session.state["owner_id"] = "alice" + store = MemoryFileStore( + tmp_path, + owner_state_key="owner_id", + dumps=lambda value: json.dumps(value, separators=(",", ":"), sort_keys=True), + loads=json.loads, + ) + updated_at = datetime(2026, 4, 21, tzinfo=timezone.utc).replace(microsecond=0).isoformat() + store.write_topic( + session, + MemoryTopicRecord( + topic="travel preferences", + summary="Loves Oslo trips.", + memories=["Prefers Oslo in summer."], + updated_at=updated_at, + session_ids=["session-prior"], + ), + source_id=DEFAULT_MEMORY_SOURCE_ID, + ) + + agent = Agent( + client=_MemoryHarnessClient(), + context_providers=[MemoryContextProvider(store=store)], + default_options={"store": False}, + ) + + session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session=session, + input_messages=[Message(role="user", contents=["Tell me about my travel preferences."])], + ) + + memory_messages = [ + m for m in session_context.context_messages.get(DEFAULT_MEMORY_SOURCE_ID, []) if "### MEMORY.md" in m.text + ] + assert memory_messages, "expected an injected memory block under the memory source" + attribution = memory_messages[0].additional_properties.get("_attribution") or {} + assert attribution.get("origin_session_id") == "session-prior" + + +async def test_memory_context_provider_omits_origin_when_only_current_session(tmp_path) -> None: + """When all contributing topics are from the current session, attribution must NOT advertise an origin.""" + session = AgentSession(session_id="session-current") + session.state["owner_id"] = "alice" + store = MemoryFileStore( + tmp_path, + owner_state_key="owner_id", + dumps=lambda value: json.dumps(value, separators=(",", ":"), sort_keys=True), + loads=json.loads, + ) + updated_at = datetime(2026, 4, 21, tzinfo=timezone.utc).replace(microsecond=0).isoformat() + store.write_topic( + session, + MemoryTopicRecord( + 
topic="travel preferences", + summary="Loves Oslo trips.", + memories=["Prefers Oslo in summer."], + updated_at=updated_at, + session_ids=["session-current"], + ), + source_id=DEFAULT_MEMORY_SOURCE_ID, + ) + + agent = Agent( + client=_MemoryHarnessClient(), + context_providers=[MemoryContextProvider(store=store)], + default_options={"store": False}, + ) + + session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session=session, + input_messages=[Message(role="user", contents=["Tell me about my travel preferences."])], + ) + + memory_messages = [ + m for m in session_context.context_messages.get(DEFAULT_MEMORY_SOURCE_ID, []) if "### MEMORY.md" in m.text + ] + assert memory_messages + attribution = memory_messages[0].additional_properties.get("_attribution") or {} + assert "origin_session_id" not in attribution + + async def test_memory_context_provider_uses_explicit_consolidation_client(tmp_path) -> None: """The memory provider should use the explicit consolidation client when one is configured.""" session = AgentSession(session_id="session-1") diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py index ebb91d0b0d..808eec9aba 100644 --- a/python/packages/core/tests/core/test_sessions.py +++ b/python/packages/core/tests/core/test_sessions.py @@ -107,6 +107,39 @@ class MyProvider: stored = ctx.context_messages["rag"][0] assert stored.additional_properties["_attribution"] == {"source_id": "rag", "source_type": "MyProvider"} + def test_extend_messages_origin_session_id_default_omits_field(self) -> None: + ctx = SessionContext(input_messages=[]) + msg = Message(role="system", contents=["ctx"]) + ctx.extend_messages("rag", [msg]) + stored = ctx.context_messages["rag"][0] + # Default (no origin_session_id passed) preserves the historical attribution shape + # so observers can distinguish "no origin info" from "explicit cross-session marker." 
+ assert "origin_session_id" not in stored.additional_properties["_attribution"] + + def test_extend_messages_origin_session_id_recorded_on_attribution(self) -> None: + ctx = SessionContext(session_id="current", input_messages=[]) + msg = Message(role="system", contents=["loaded from a prior session"]) + ctx.extend_messages("memory_provider", [msg], origin_session_id="prior-session-id") + stored = ctx.context_messages["memory_provider"][0] + assert stored.additional_properties["_attribution"] == { + "source_id": "memory_provider", + "origin_session_id": "prior-session-id", + } + + def test_extend_messages_origin_session_id_with_provider_object(self) -> None: + class MyMemoryProvider: + source_id = "memory" + + ctx = SessionContext(session_id="current", input_messages=[]) + msg = Message(role="assistant", contents=["consolidated memory content"]) + ctx.extend_messages(MyMemoryProvider(), [msg], origin_session_id="prior") + stored = ctx.context_messages["memory"][0] + assert stored.additional_properties["_attribution"] == { + "source_id": "memory", + "source_type": "MyMemoryProvider", + "origin_session_id": "prior", + } + def test_extend_instructions_string(self) -> None: ctx = SessionContext(input_messages=[]) ctx.extend_instructions("sys", "Be helpful") diff --git a/python/samples/02-agents/context_providers/README.md b/python/samples/02-agents/context_providers/README.md index 04f3a1395f..355cf4537d 100644 --- a/python/samples/02-agents/context_providers/README.md +++ b/python/samples/02-agents/context_providers/README.md @@ -7,6 +7,7 @@ These samples demonstrate how to use context providers to enrich agent conversat | File / Folder | Description | |---------------|-------------| | [`simple_context_provider.py`](simple_context_provider.py) | Implement a custom context provider by extending `ContextProvider` to extract and inject structured user information across turns. 
| +| [`cross_session_observer.py`](cross_session_observer.py) | Detect injected context messages whose origin differs from the current session, via the `_attribution["origin_session_id"]` field. Self-contained — no LLM credentials required. | | [`azure_ai_foundry_memory.py`](azure_ai_foundry_memory.py) | Use `FoundryMemoryProvider` to add semantic memory — automatically retrieves, searches, and stores memories via Azure AI Foundry. | | [`azure_ai_search/`](azure_ai_search/) | Retrieval Augmented Generation (RAG) with Azure AI Search in semantic and agentic modes. See its own [README](azure_ai_search/README.md). | | [`mem0/`](mem0/) | Memory-powered context using the Mem0 integration (open-source and managed). See its own [README](mem0/README.md). | @@ -14,6 +15,9 @@ These samples demonstrate how to use context providers to enrich agent conversat ## Prerequisites +**For `cross_session_observer.py`:** +- No environment variables or LLM credentials required; the sample runs against an in-memory `SessionContext`. + **For `simple_context_provider.py`:** - `FOUNDRY_PROJECT_ENDPOINT`: Your Azure AI Foundry project endpoint - `FOUNDRY_MODEL`: Model deployment name diff --git a/python/samples/02-agents/context_providers/cross_session_observer.py b/python/samples/02-agents/context_providers/cross_session_observer.py new file mode 100644 index 0000000000..c502f9300e --- /dev/null +++ b/python/samples/02-agents/context_providers/cross_session_observer.py @@ -0,0 +1,141 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from collections.abc import Mapping +from typing import Any, Callable, cast + +from agent_framework import AgentSession, ContextProvider, Message, SessionContext + +"""This sample demonstrates how to detect cross-session memory injection. 
+ +When a context provider injects messages from a different ``session_id`` than +the requesting one — the legitimate cross-session memory use case (consolidated +memories, Mem0 with default scope, shared knowledge bases) — the framework +records the originating session under +``message.additional_properties["_attribution"]["origin_session_id"]``. + +Downstream context observers can subscribe to this signal for governance, +audit, and behavioral analysis purposes. This is useful for defending against +the stateful-agent-backdoor attack class documented in Dai et al., +arXiv:2605.06158, in which an adversary chains sub-backdoors across sessions +under permission isolation via persisted memory state. + +The sample is self-contained: it constructs ``SessionContext`` directly and +invokes provider lifecycle methods manually, so no LLM credentials are +required to run it. +""" + + +class CrossSessionObserver(ContextProvider): + """Detect injected context messages whose origin differs from the current session. + + Subscribes via the standard ``ContextProvider`` pipeline. In ``before_run``, + walks the accumulated context messages and invokes a user-supplied + callback for each message whose ``_attribution["origin_session_id"]`` + is present and differs from the current ``session_id``. + + The callback receives the source_id that injected the content, the + originating session_id, the current session_id, and the message itself. + Use it to log, alert, increment metrics, or enforce policy — the observer + itself only surfaces the signal, leaving the response policy to the caller. + """ + + DEFAULT_SOURCE_ID = "cross_session_observer" + + def __init__( + self, + on_cross_session_access: Callable[[str, str, str | None, Message], None], + *, + source_id: str = DEFAULT_SOURCE_ID, + ) -> None: + """Initialize the observer. + + Args: + on_cross_session_access: Callback invoked for each detected + cross-session message. 
Signature is + ``(source_id, origin_session_id, current_session_id, message)``. + source_id: Unique identifier for this observer instance. + """ + super().__init__(source_id) + self._on_cross_session_access = on_cross_session_access + + async def before_run( + self, + *, + agent: Any, + session: AgentSession | None, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Inspect accumulated context messages for cross-session origin.""" + current_session_id = context.session_id + for source_id, messages in context.context_messages.items(): + if source_id == self.source_id: + continue + for message in messages: + attribution_raw = message.additional_properties.get("_attribution") + if not isinstance(attribution_raw, Mapping): + continue + attribution = cast(Mapping[str, Any], attribution_raw) + origin = attribution.get("origin_session_id") + if isinstance(origin, str) and origin != current_session_id: + self._on_cross_session_access(source_id, origin, current_session_id, message) + + +def _on_detected(source_id: str, origin: str, current: str | None, message: Message) -> None: + """Sample callback that logs cross-session detections to stdout.""" + preview = " ".join(message.text.split())[:80] + print( + f"[cross-session detected] source={source_id!r} " + f"origin_session={origin!r} current_session={current!r} " + f"preview={preview!r}" + ) + + +async def main() -> None: + """Demonstrate the observer firing on cross-session injection.""" + observer = CrossSessionObserver(_on_detected) + + # --- Case 1: same-session injection (observer should be silent) --- + same_session_context = SessionContext( + session_id="session-A", + input_messages=[Message("user", ["What did we discuss last time?"])], + ) + # Simulate a same-session provider injecting same-session history. Omitting + # origin_session_id is semantically "no origin info" — observers treat as + # equivalent to same-session for backward compatibility. 
+ same_session_context.extend_messages( + "history_provider", + [Message("assistant", ["We talked about Q3 revenue projections."])], + ) + await observer.before_run( + agent=None, + session=None, + context=same_session_context, + state={}, + ) + print("--- Same-session case complete (no detections expected above) ---\n") + + # --- Case 2: cross-session injection (observer should fire) --- + cross_session_context = SessionContext( + session_id="session-B", + input_messages=[Message("user", ["Continue from where we left off."])], + ) + # Simulate a cross-session memory provider injecting content originally + # written in session-A while we're now running in session-B. + cross_session_context.extend_messages( + "memory_provider", + [Message("assistant", ["Remember: API key for prod is sk-7f3a... (from session-A)."])], + origin_session_id="session-A", + ) + await observer.before_run( + agent=None, + session=None, + context=cross_session_context, + state={}, + ) + print("--- Cross-session case complete (one detection expected above) ---") + + +if __name__ == "__main__": + asyncio.run(main())