4 changes: 4 additions & 0 deletions mellea/backends/litellm.py
@@ -37,6 +37,7 @@
     extract_model_tool_requests,
     get_current_event_loop,
     message_to_openai_message,
+    messages_to_docs,
     send_to_queue,
 )
 from ..stdlib.components import Message
@@ -322,6 +323,8 @@ async def _generate_from_chat_context_standard(
         conversation.append({"role": "system", "content": system_prompt})
         conversation.extend([message_to_openai_message(m) for m in messages])
 
+        docs = messages_to_docs(messages)
+
         extra_params: dict[str, Any] = {}
         if _format is not None:
             extra_params["response_format"] = {
Expand Down Expand Up @@ -359,6 +362,7 @@ async def _generate_from_chat_context_standard(
model=self._model_id,
messages=conversation,
tools=formatted_tools,
extra_body={"documents": docs} if docs else None,
reasoning_effort=thinking, # type: ignore
drop_params=True, # See note in `_make_backend_specific_and_remove`.
**extra_params,
Expand Down
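
For context on what the new `extra_body` line forwards: a minimal sketch of what `messages_to_docs` presumably does, inferred from the test expectations at the bottom of this diff (field names `text`, `title`, `doc_id`, with unset fields omitted). The real helper is the one imported at the top of this file; this is not its actual implementation.

# Hypothetical sketch of the imported helper, based only on the new tests.
def messages_to_docs(messages):
    """Collect document dicts from any Message that carries documents."""
    docs = []
    for m in messages:
        for doc in getattr(m, "documents", None) or []:
            entry = {"text": doc.text}
            if doc.title is not None:  # optional fields are dropped when unset,
                entry["title"] = doc.title  # per the assertions in the tests
            if doc.doc_id is not None:
                entry["doc_id"] = doc.doc_id
            docs.append(entry)
    return docs
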
3 changes: 3 additions & 0 deletions mellea/backends/openai.py
@@ -460,6 +460,8 @@ async def _generate_from_chat_context_standard(
         conversation.append({"role": "system", "content": system_prompt})
         conversation.extend([message_to_openai_message(m) for m in messages])
 
+        docs = messages_to_docs(messages)
+
         extra_params: dict[str, Any] = {}
         if _format is not None:
             if self._server_type == _ServerType.OPENAI:
@@ -535,6 +537,7 @@
             model=self._model_id,
             messages=conversation,  # type: ignore
             tools=formatted_tools if use_tools else None,  # type: ignore
+            extra_body={"documents": docs} if docs else None,
             # parallel_tool_calls=False, # We only support calling one tool per turn. But we do the choosing on our side so we leave this False.
             **extra_params,
             **reasoning_params,  # type: ignore
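
Taken together, the two backend changes let a caller attach documents to a chat turn and have them forwarded to servers that accept a top-level `documents` field (as some RAG-aware chat templates do). A usage sketch built only from the classes the new test below exercises; the `base_url` is illustrative, not taken from the diff:

from mellea.backends.openai import OpenAIBackend
from mellea.stdlib.components import Message
from mellea.stdlib.components.docs import Document
from mellea.stdlib.context import ChatContext

backend = OpenAIBackend(
    model_id="gpt-4o", api_key="fake-key", base_url="http://localhost:8000/v1"
)
ctx = ChatContext().add(
    Message(
        "user",
        "What is in the document?",
        documents=[Document(text="The sky is blue.", title="Facts", doc_id="d1")],
    )
)
# The backend now sends the request with
#   extra_body={"documents": [{"text": "The sky is blue.", "title": "Facts", "doc_id": "d1"}]}
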
102 changes: 102 additions & 0 deletions test/backends/test_document_passthrough.py
@@ -0,0 +1,102 @@
"""Unit tests verifying that documents on Messages reach the API call.

Covers OpenAIBackend._generate_from_chat_context_standard to ensure that
Message.documents are extracted via messages_to_docs() and forwarded as
extra_body={"documents": [...]}.
"""

import asyncio
from unittest.mock import MagicMock, patch

import pytest

from mellea.backends.openai import OpenAIBackend
from mellea.stdlib.components import Message
from mellea.stdlib.components.docs import Document
from mellea.stdlib.context import ChatContext


def _make_openai_backend() -> OpenAIBackend:
    return OpenAIBackend(
        model_id="gpt-4o", api_key="fake-key", base_url="http://localhost:9999/v1"
    )


def _build_context_with_docs(docs: list[Document] | None = None) -> ChatContext:
    ctx = ChatContext()
    ctx = ctx.add(Message("user", "What is in the document?", documents=docs))
    return ctx


def _fake_openai_response() -> MagicMock:
    resp = MagicMock()
    resp.choices = [MagicMock()]
    resp.choices[0].message.content = "ok"
    resp.choices[0].message.tool_calls = None
    resp.choices[0].finish_reason = "stop"
    resp.usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15)
    resp.usage.prompt_tokens_details = None
    return resp


@pytest.mark.integration
class TestOpenAIDocumentPassthrough:
    def test_documents_passed_as_extra_body(self):
        backend = _make_openai_backend()
        docs = [
            Document(text="The sky is blue.", title="Facts", doc_id="d1"),
            Document(text="Water is wet."),
        ]
        ctx = _build_context_with_docs(docs)

        captured_kwargs: dict = {}

        async def fake_create(**kwargs):
            captured_kwargs.update(kwargs)
            return _fake_openai_response()

        mock_client = MagicMock()
        mock_client.chat.completions.create = fake_create

        with patch.object(
            type(backend),
            "_async_client",
            new_callable=lambda: property(lambda self: mock_client),
        ):
            action = Message("user", "Summarise the documents.")
            asyncio.run(
                backend._generate_from_chat_context_standard(action, ctx)
            )

        assert "extra_body" in captured_kwargs
        assert captured_kwargs["extra_body"] == {
            "documents": [
                {"text": "The sky is blue.", "title": "Facts", "doc_id": "d1"},
                {"text": "Water is wet."},
            ]
        }

    def test_no_documents_no_extra_body(self):
        backend = _make_openai_backend()
        ctx = _build_context_with_docs(None)

        captured_kwargs: dict = {}

        async def fake_create(**kwargs):
            captured_kwargs.update(kwargs)
            return _fake_openai_response()

        mock_client = MagicMock()
        mock_client.chat.completions.create = fake_create

        with patch.object(
            type(backend),
            "_async_client",
            new_callable=lambda: property(lambda self: mock_client),
        ):
            action = Message("user", "Hello.")
            asyncio.run(
                backend._generate_from_chat_context_standard(action, ctx)
            )

        assert captured_kwargs.get("extra_body") is None
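
A note on the mocking pattern used in both tests: `_async_client` appears to be a read-only property on the backend class, so `patch.object` replaces it at the class level with a fresh property that yields the mock. A standalone illustration of the same trick; the class and names here are generic, not from mellea:

from unittest.mock import MagicMock, patch

class Backend:
    @property
    def _async_client(self):
        raise RuntimeError("would build a real client")  # no network in tests

fake = MagicMock()
with patch.object(
    Backend, "_async_client", new_callable=lambda: property(lambda self: fake)
):
    assert Backend()._async_client is fake  # the property now returns the mock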