diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 93722a8987..dc8a38ded8 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -237,6 +237,14 @@ def _annotation_includes_function_invocation_context(annotation: Any) -> bool: ClassT = TypeVar("ClassT", bound="SerializationMixin") +def _dump_tool_arguments(model: BaseModel, *, include_none_from: Iterable[str] | None = None) -> dict[str, Any]: + arguments = model.model_dump(exclude_none=True) + for name in include_none_from if include_none_from is not None else model.model_fields_set: + if getattr(model, name, None) is None: + arguments[name] = None + return arguments + + class FunctionTool(SerializationMixin): """A tool that wraps a Python function to make it callable by AI models. @@ -635,8 +643,9 @@ async def invoke( if isinstance(arguments, Mapping): parsed_arguments = dict(arguments) if self.input_model is not None and not self._schema_supplied: - parsed_arguments = self.input_model.model_validate(parsed_arguments).model_dump( - exclude_none=True + parsed_arguments = _dump_tool_arguments( + self.input_model.model_validate(parsed_arguments), + include_none_from=parsed_arguments, ) elif isinstance(arguments, BaseModel): if ( @@ -645,7 +654,7 @@ async def invoke( and not isinstance(arguments, self.input_model) ): raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}") - parsed_arguments = arguments.model_dump(exclude_none=True) + parsed_arguments = _dump_tool_arguments(arguments) else: raise TypeError( f"Expected mapping-like arguments for tool '{self.name}', got {type(arguments).__name__}" @@ -1492,7 +1501,7 @@ async def _auto_invoke_function( runtime_kwargs["session"] = invocation_session try: if not cast(bool, getattr(tool, "_schema_supplied", False)) and tool.input_model is not None: - args = tool.input_model.model_validate(parsed_args).model_dump(exclude_none=True) + args = _dump_tool_arguments(tool.input_model.model_validate(parsed_args), include_none_from=parsed_args) else: args = dict(parsed_args) args = _validate_arguments_against_schema( diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 3d20a26080..774f41fef9 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -87,6 +87,40 @@ def ai_func(arg1: str) -> str: assert response.messages[2].text == "done" +async def test_base_client_with_function_calling_preserves_null_arguments(chat_client_base: SupportsChatGetResponse): + seen_unit = "not-called" + + @tool(name="get_weather", approval_mode="never_require") + def get_weather(location: str, unit: str | None) -> str: + nonlocal seen_unit + seen_unit = unit + return f"{location}:{unit}" + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="1", + name="get_weather", + arguments='{"location": "Seattle", "unit": null}', + ) + ], + ) + ), + ChatResponse(messages=Message(role="assistant", contents=["done"])), + ] + + response = await chat_client_base.get_response( + [Message(role="user", contents=["weather?"])], options={"tool_choice": "auto", "tools": [get_weather]} + ) + + assert seen_unit is None + assert response.messages[1].contents[0].type == "function_result" + assert response.messages[1].contents[0].result == "Seattle:None" + + async def test_base_client_with_function_calling_string_input(chat_client_base: SupportsChatGetResponse): exec_counter = 0 diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index b3762bf4ef..3d62e76a93 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -183,6 +183,17 @@ def search(query: str, max_results: int = 10) -> str: await search.invoke(arguments={"query": "hello", "max_results": "three"}) +async def test_tool_invoke_preserves_required_nullable_argument() -> None: + @tool + def get_weather(location: str, unit: str | None) -> str: + return f"{location}:{unit}" + + result = await get_weather.invoke(arguments={"location": "Seattle", "unit": None}) + + assert isinstance(result, list) + assert result[0].text == "Seattle:None" + + def test_tool_decorator_with_json_schema_preserves_custom_properties(): """Test schema passthrough keeps custom JSON schema properties."""