Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,14 @@ def _annotation_includes_function_invocation_context(annotation: Any) -> bool:
ClassT = TypeVar("ClassT", bound="SerializationMixin")


def _dump_tool_arguments(model: BaseModel, *, include_none_from: Iterable[str] | None = None) -> dict[str, Any]:
arguments = model.model_dump(exclude_none=True)
for name in include_none_from if include_none_from is not None else model.model_fields_set:
if getattr(model, name, None) is None:
Comment on lines +242 to +243
arguments[name] = None
return arguments


class FunctionTool(SerializationMixin):
"""A tool that wraps a Python function to make it callable by AI models.

Expand Down Expand Up @@ -635,8 +643,9 @@ async def invoke(
if isinstance(arguments, Mapping):
parsed_arguments = dict(arguments)
if self.input_model is not None and not self._schema_supplied:
parsed_arguments = self.input_model.model_validate(parsed_arguments).model_dump(
exclude_none=True
parsed_arguments = _dump_tool_arguments(
self.input_model.model_validate(parsed_arguments),
include_none_from=parsed_arguments,
)
elif isinstance(arguments, BaseModel):
if (
Expand All @@ -645,7 +654,7 @@ async def invoke(
and not isinstance(arguments, self.input_model)
):
raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}")
parsed_arguments = arguments.model_dump(exclude_none=True)
parsed_arguments = _dump_tool_arguments(arguments)
else:
raise TypeError(
f"Expected mapping-like arguments for tool '{self.name}', got {type(arguments).__name__}"
Expand Down Expand Up @@ -1492,7 +1501,7 @@ async def _auto_invoke_function(
runtime_kwargs["session"] = invocation_session
try:
if not cast(bool, getattr(tool, "_schema_supplied", False)) and tool.input_model is not None:
args = tool.input_model.model_validate(parsed_args).model_dump(exclude_none=True)
args = _dump_tool_arguments(tool.input_model.model_validate(parsed_args), include_none_from=parsed_args)
else:
args = dict(parsed_args)
args = _validate_arguments_against_schema(
Expand Down
34 changes: 34 additions & 0 deletions python/packages/core/tests/core/test_function_invocation_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,40 @@ def ai_func(arg1: str) -> str:
assert response.messages[2].text == "done"


async def test_base_client_with_function_calling_preserves_null_arguments(chat_client_base: SupportsChatGetResponse):
seen_unit = "not-called"

@tool(name="get_weather", approval_mode="never_require")
def get_weather(location: str, unit: str | None) -> str:
nonlocal seen_unit
seen_unit = unit
return f"{location}:{unit}"

chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="1",
name="get_weather",
arguments='{"location": "Seattle", "unit": null}',
)
],
)
),
ChatResponse(messages=Message(role="assistant", contents=["done"])),
]

response = await chat_client_base.get_response(
[Message(role="user", contents=["weather?"])], options={"tool_choice": "auto", "tools": [get_weather]}
)

assert seen_unit is None
assert response.messages[1].contents[0].type == "function_result"
assert response.messages[1].contents[0].result == "Seattle:None"


async def test_base_client_with_function_calling_string_input(chat_client_base: SupportsChatGetResponse):
exec_counter = 0

Expand Down
11 changes: 11 additions & 0 deletions python/packages/core/tests/core/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,17 @@ def search(query: str, max_results: int = 10) -> str:
await search.invoke(arguments={"query": "hello", "max_results": "three"})


async def test_tool_invoke_preserves_required_nullable_argument() -> None:
@tool
def get_weather(location: str, unit: str | None) -> str:
return f"{location}:{unit}"

result = await get_weather.invoke(arguments={"location": "Seattle", "unit": None})

assert isinstance(result, list)
assert result[0].text == "Seattle:None"


def test_tool_decorator_with_json_schema_preserves_custom_properties():
"""Test schema passthrough keeps custom JSON schema properties."""

Expand Down
Loading