1"""Test the base tool implementation."""23import inspect4import json5import logging6import sys7import textwrap8import threading9from collections.abc import Callable10from dataclasses import dataclass11from datetime import datetime12from enum import Enum13from functools import partial14from typing import (15 Annotated,16 Any,17 Generic,18 Literal,19 TypeVar,20 cast,21 get_type_hints,22)2324import pytest25from pydantic import BaseModel, ConfigDict, Field, ValidationError26from pydantic.v1 import BaseModel as BaseModelV127from pydantic.v1 import ValidationError as ValidationErrorV128from typing_extensions import TypedDict, override2930from langchain_core import tools31from langchain_core.callbacks import (32 AsyncCallbackManagerForToolRun,33 CallbackManagerForToolRun,34)35from langchain_core.callbacks.manager import (36 CallbackManagerForRetrieverRun,37)38from langchain_core.documents import Document39from langchain_core.messages import ToolCall, ToolMessage40from langchain_core.messages.tool import ToolOutputMixin41from langchain_core.retrievers import BaseRetriever42from langchain_core.runnables import (43 RunnableConfig,44 RunnableLambda,45 ensure_config,46)47from langchain_core.tools import (48 BaseTool,49 StructuredTool,50 Tool,51 ToolException,52 tool,53)54from langchain_core.tools.base import (55 TOOL_MESSAGE_BLOCK_TYPES,56 ArgsSchema,57 InjectedToolArg,58 InjectedToolCallId,59 SchemaAnnotationError,60 _DirectlyInjectedToolArg,61 _format_output,62 _is_message_content_block,63 _is_message_content_type,64 get_all_basemodel_annotations,65)66from langchain_core.utils.function_calling import (67 convert_to_openai_function,68 convert_to_openai_tool,69)70from langchain_core.utils.pydantic import (71 _create_subset_model,72 create_model_v2,73)74from tests.unit_tests.fake.callbacks import FakeCallbackHandler75from tests.unit_tests.pydantic_utils import _normalize_schema, _schema7677try:78 from langgraph.prebuilt import ToolRuntime # type: ignore[import-not-found]7980 HAS_LANGGRAPH = True81except ImportError:82 HAS_LANGGRAPH = False838485def _get_tool_call_json_schema(tool: BaseTool) -> dict[str, Any]:86 tool_schema = tool.tool_call_schema87 if isinstance(tool_schema, dict):88 return tool_schema8990 if issubclass(tool_schema, BaseModel):91 return tool_schema.model_json_schema()92 if issubclass(tool_schema, BaseModelV1):93 return tool_schema.schema()94 return {}959697def test_unnamed_decorator() -> None:98 """Test functionality with unnamed decorator."""99100 @tool101 def search_api(query: str) -> str:102 """Search the API for the query."""103 return "API result"104105 assert isinstance(search_api, BaseTool)106 assert search_api.name == "search_api"107 assert not search_api.return_direct108 assert search_api.invoke("test") == "API result"109110111class _MockSchema(BaseModel):112 """Return the arguments directly."""113114 arg1: int115 arg2: bool116 arg3: dict | None = None117118119class _MockStructuredTool(BaseTool):120 name: str = "structured_api"121 args_schema: type[BaseModel] = _MockSchema122 description: str = "A Structured Tool"123124 @override125 def _run(self, *, arg1: int, arg2: bool, arg3: dict | None = None) -> str:126 return f"{arg1} {arg2} {arg3}"127128 async def _arun(self, *, arg1: int, arg2: bool, arg3: dict | None = None) -> str:129 raise NotImplementedError130131132class _FakeOutput(ToolOutputMixin):133 """Minimal ToolOutputMixin subclass used only in tests."""134135 def __init__(self, value: int) -> None:136 self.value = value137138 def __eq__(self, other: object) -> bool:139 return isinstance(other, _FakeOutput) and self.value == other.value140141 def __hash__(self) -> int:142 return hash(self.value)143144 def __repr__(self) -> str:145 return f"_FakeOutput({self.value})"146147148def test_structured_args() -> None:149 """Test functionality with structured arguments."""150 structured_api = _MockStructuredTool()151 assert isinstance(structured_api, BaseTool)152 assert structured_api.name == "structured_api"153 expected_result = "1 True {'foo': 'bar'}"154 args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}155 assert structured_api.run(args) == expected_result156157158def test_misannotated_base_tool_raises_error() -> None:159 """Test that a BaseTool with the incorrect typehint raises an exception."""160 with pytest.raises(SchemaAnnotationError):161162 class _MisAnnotatedTool(BaseTool):163 name: str = "structured_api"164 # This would silently be ignored without the custom metaclass165 args_schema: BaseModel = _MockSchema # type: ignore[assignment]166 description: str = "A Structured Tool"167168 @override169 def _run(self, *, arg1: int, arg2: bool, arg3: dict | None = None) -> str:170 return f"{arg1} {arg2} {arg3}"171172 async def _arun(173 self, *, arg1: int, arg2: bool, arg3: dict | None = None174 ) -> str:175 raise NotImplementedError176177178def test_forward_ref_annotated_base_tool_accepted() -> None:179 """Test that a using forward ref annotation syntax is accepted."""180181 class _ForwardRefAnnotatedTool(BaseTool):182 name: str = "structured_api"183 args_schema: "type[BaseModel]" = _MockSchema184 description: str = "A Structured Tool"185186 @override187 def _run(self, *, arg1: int, arg2: bool, arg3: dict | None = None) -> str:188 return f"{arg1} {arg2} {arg3}"189190 async def _arun(191 self, *, arg1: int, arg2: bool, arg3: dict | None = None192 ) -> str:193 raise NotImplementedError194195196def test_subclass_annotated_base_tool_accepted() -> None:197 """Test BaseTool child w/ custom schema isn't overwritten."""198199 class _ForwardRefAnnotatedTool(BaseTool):200 name: str = "structured_api"201 args_schema: type[_MockSchema] = _MockSchema202 description: str = "A Structured Tool"203204 @override205 def _run(self, *, arg1: int, arg2: bool, arg3: dict | None = None) -> str:206 return f"{arg1} {arg2} {arg3}"207208 async def _arun(209 self, *, arg1: int, arg2: bool, arg3: dict | None = None210 ) -> str:211 raise NotImplementedError212213 assert issubclass(_ForwardRefAnnotatedTool, BaseTool)214 tool = _ForwardRefAnnotatedTool()215 assert tool.args_schema == _MockSchema216217218def test_decorator_with_specified_schema() -> None:219 """Test that manually specified schemata are passed through to the tool."""220221 @tool(args_schema=_MockSchema)222 def tool_func(*, arg1: int, arg2: bool, arg3: dict | None = None) -> str:223 return f"{arg1} {arg2} {arg3}"224225 assert isinstance(tool_func, BaseTool)226 assert tool_func.args_schema == _MockSchema227228229@pytest.mark.skipif(230 sys.version_info >= (3, 14),231 reason="pydantic.v1 namespace not supported with Python 3.14+",232)233def test_decorator_with_specified_schema_pydantic_v1() -> None:234 """Test that manually specified schemata are passed through to the tool."""235236 class _MockSchemaV1(BaseModelV1):237 """Return the arguments directly."""238239 arg1: int240 arg2: bool241 arg3: dict | None = None242243 @tool(args_schema=cast("ArgsSchema", _MockSchemaV1))244 def tool_func_v1(*, arg1: int, arg2: bool, arg3: dict | None = None) -> str:245 return f"{arg1} {arg2} {arg3}"246247 assert isinstance(tool_func_v1, BaseTool)248 assert tool_func_v1.args_schema == cast("ArgsSchema", _MockSchemaV1)249250251def test_decorated_function_schema_equivalent() -> None:252 """Test that a BaseTool without a schema meets expectations."""253254 @tool255 def structured_tool_input(256 *, arg1: int, arg2: bool, arg3: dict | None = None257 ) -> str:258 """Return the arguments directly."""259 return f"{arg1} {arg2} {arg3}"260261 assert isinstance(structured_tool_input, BaseTool)262 assert structured_tool_input.args_schema is not None263 assert (264 _schema(structured_tool_input.args_schema)["properties"]265 == _schema(_MockSchema)["properties"]266 == _normalize_schema(structured_tool_input.args)267 )268269270def test_args_kwargs_filtered() -> None:271 class _SingleArgToolWithKwargs(BaseTool):272 name: str = "single_arg_tool"273 description: str = "A single arged tool with kwargs"274275 @override276 def _run(277 self,278 some_arg: str,279 run_manager: CallbackManagerForToolRun | None = None,280 **kwargs: Any,281 ) -> str:282 return "foo"283284 async def _arun(285 self,286 some_arg: str,287 run_manager: AsyncCallbackManagerForToolRun | None = None,288 **kwargs: Any,289 ) -> str:290 raise NotImplementedError291292 tool = _SingleArgToolWithKwargs()293 assert tool.is_single_input294295 class _VarArgToolWithKwargs(BaseTool):296 name: str = "single_arg_tool"297 description: str = "A single arged tool with kwargs"298299 @override300 def _run(301 self,302 *args: Any,303 run_manager: CallbackManagerForToolRun | None = None,304 **kwargs: Any,305 ) -> str:306 return "foo"307308 async def _arun(309 self,310 *args: Any,311 run_manager: AsyncCallbackManagerForToolRun | None = None,312 **kwargs: Any,313 ) -> str:314 raise NotImplementedError315316 tool2 = _VarArgToolWithKwargs()317 assert tool2.is_single_input318319320def test_structured_args_decorator_no_infer_schema() -> None:321 """Test functionality with structured arguments parsed as a decorator."""322323 @tool(infer_schema=False)324 def structured_tool_input(325 arg1: int, arg2: float | datetime, opt_arg: dict | None = None326 ) -> str:327 """Return the arguments directly."""328 return f"{arg1}, {arg2}, {opt_arg}"329330 assert isinstance(structured_tool_input, BaseTool)331 assert structured_tool_input.name == "structured_tool_input"332 args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}333 with pytest.raises(ToolException):334 assert structured_tool_input.run(args)335336337def test_structured_single_str_decorator_no_infer_schema() -> None:338 """Test functionality with structured arguments parsed as a decorator."""339340 @tool(infer_schema=False)341 def unstructured_tool_input(tool_input: str) -> str:342 """Return the arguments directly."""343 assert isinstance(tool_input, str)344 return f"{tool_input}"345346 assert isinstance(unstructured_tool_input, BaseTool)347 assert unstructured_tool_input.args_schema is None348 assert unstructured_tool_input.run("foo") == "foo"349350351def test_structured_tool_types_parsed() -> None:352 """Test the non-primitive types are correctly passed to structured tools."""353354 class SomeEnum(Enum):355 A = "a"356 B = "b"357358 class SomeBaseModel(BaseModel):359 foo: str360361 @tool362 def structured_tool(363 some_enum: SomeEnum,364 some_base_model: SomeBaseModel,365 ) -> dict:366 """Return the arguments directly."""367 return {368 "some_enum": some_enum,369 "some_base_model": some_base_model,370 }371372 assert isinstance(structured_tool, StructuredTool)373 args = {374 "some_enum": SomeEnum.A.value,375 "some_base_model": SomeBaseModel(foo="bar").model_dump(),376 }377 result = structured_tool.run(json.loads(json.dumps(args)))378 expected = {379 "some_enum": SomeEnum.A,380 "some_base_model": SomeBaseModel(foo="bar"),381 }382 assert result == expected383384385@pytest.mark.skipif(386 sys.version_info >= (3, 14),387 reason="pydantic.v1 namespace not supported with Python 3.14+",388)389def test_structured_tool_types_parsed_pydantic_v1() -> None:390 """Test the non-primitive types are correctly passed to structured tools."""391392 class SomeBaseModel(BaseModelV1):393 foo: str394395 class AnotherBaseModel(BaseModelV1):396 bar: str397398 @tool399 def structured_tool(some_base_model: SomeBaseModel) -> AnotherBaseModel:400 """Return the arguments directly."""401 return AnotherBaseModel(bar=some_base_model.foo)402403 assert isinstance(structured_tool, StructuredTool)404405 expected = AnotherBaseModel(bar="baz")406 for arg in [407 SomeBaseModel(foo="baz"),408 SomeBaseModel(foo="baz").dict(),409 ]:410 args = {"some_base_model": arg}411 result = structured_tool.run(args)412 assert result == expected413414415def test_structured_tool_types_parsed_pydantic_mixed() -> None:416 """Test handling of tool with mixed Pydantic version arguments."""417418 class SomeBaseModel(BaseModelV1):419 foo: str420421 class AnotherBaseModel(BaseModel):422 bar: str423424 with pytest.raises(NotImplementedError):425426 @tool427 def structured_tool(428 some_base_model: SomeBaseModel, another_base_model: AnotherBaseModel429 ) -> None:430 """Return the arguments directly."""431432433def test_base_tool_inheritance_base_schema() -> None:434 """Test schema is correctly inferred when inheriting from BaseTool."""435436 class _MockSimpleTool(BaseTool):437 name: str = "simple_tool"438 description: str = "A Simple Tool"439440 @override441 def _run(self, tool_input: str) -> str:442 return f"{tool_input}"443444 @override445 async def _arun(self, tool_input: str) -> str:446 raise NotImplementedError447448 simple_tool = _MockSimpleTool()449 assert simple_tool.args_schema is None450 expected_args = {"tool_input": {"title": "Tool Input", "type": "string"}}451 assert simple_tool.args == expected_args452453454def test_tool_lambda_args_schema() -> None:455 """Test args schema inference when the tool argument is a lambda function."""456 tool = Tool(457 name="tool",458 description="A tool",459 func=lambda tool_input: tool_input,460 )461 assert tool.args_schema is None462 expected_args = {"tool_input": {"type": "string"}}463 assert tool.args == expected_args464465466def test_structured_tool_from_function_docstring() -> None:467 """Test that structured tools can be created from functions."""468469 def foo(bar: int, baz: str) -> str:470 """Docstring.471472 Args:473 bar: the bar value474 baz: the baz value475 """476 raise NotImplementedError477478 structured_tool = StructuredTool.from_function(foo)479 assert structured_tool.name == "foo"480 assert structured_tool.args == {481 "bar": {"title": "Bar", "type": "integer"},482 "baz": {"title": "Baz", "type": "string"},483 }484485 assert _schema(structured_tool.args_schema) == {486 "properties": {487 "bar": {"title": "Bar", "type": "integer"},488 "baz": {"title": "Baz", "type": "string"},489 },490 "description": inspect.getdoc(foo),491 "title": "foo",492 "type": "object",493 "required": ["bar", "baz"],494 }495496 assert foo.__doc__ is not None497 assert structured_tool.description == textwrap.dedent(foo.__doc__.strip())498499500def test_structured_tool_from_function_docstring_complex_args() -> None:501 """Test that structured tools can be created from functions."""502503 def foo(bar: int, baz: list[str]) -> str:504 """Docstring.505506 Args:507 bar: int508 baz: list[str]509 """510 raise NotImplementedError511512 structured_tool = StructuredTool.from_function(foo)513 assert structured_tool.name == "foo"514 assert structured_tool.args == {515 "bar": {"title": "Bar", "type": "integer"},516 "baz": {517 "title": "Baz",518 "type": "array",519 "items": {"type": "string"},520 },521 }522523 assert _schema(structured_tool.args_schema) == {524 "properties": {525 "bar": {"title": "Bar", "type": "integer"},526 "baz": {527 "title": "Baz",528 "type": "array",529 "items": {"type": "string"},530 },531 },532 "description": inspect.getdoc(foo),533 "title": "foo",534 "type": "object",535 "required": ["bar", "baz"],536 }537538 assert foo.__doc__ is not None539 assert structured_tool.description == textwrap.dedent(foo.__doc__).strip()540541542def test_structured_tool_lambda_multi_args_schema() -> None:543 """Test args schema inference when the tool argument is a lambda function."""544 tool = StructuredTool.from_function(545 name="tool",546 description="A tool",547 func=lambda tool_input, other_arg: f"{tool_input}{other_arg}",548 )549 assert tool.args_schema is not None550 expected_args = {551 "tool_input": {"title": "Tool Input"},552 "other_arg": {"title": "Other Arg"},553 }554 assert tool.args == expected_args555556557def test_tool_partial_function_args_schema() -> None:558 """Test args schema inference when the tool argument is a partial function."""559560 def func(tool_input: str, other_arg: str) -> str:561 assert isinstance(tool_input, str)562 assert isinstance(other_arg, str)563 return tool_input + other_arg564565 tool = Tool(566 name="tool",567 description="A tool",568 func=partial(func, other_arg="foo"),569 )570 assert tool.run("bar") == "barfoo"571572573def test_empty_args_decorator() -> None:574 """Test inferred schema of decorated fn with no args."""575576 @tool577 def empty_tool_input() -> str:578 """Return a constant."""579 return "the empty result"580581 assert isinstance(empty_tool_input, BaseTool)582 assert empty_tool_input.name == "empty_tool_input"583 assert empty_tool_input.args == {}584 assert empty_tool_input.run({}) == "the empty result"585586587def test_tool_from_function_with_run_manager() -> None:588 """Test run of tool when using run_manager."""589590 def foo(bar: str, callbacks: CallbackManagerForToolRun | None = None) -> str: # noqa: D417591 """Docstring.592593 Args:594 bar: str.595 """596 assert callbacks is not None597 return "foo" + bar598599 handler = FakeCallbackHandler()600 tool = Tool.from_function(foo, name="foo", description="Docstring")601602 assert tool.run(tool_input={"bar": "bar"}, run_manager=[handler]) == "foobar"603 assert tool.run("baz", run_manager=[handler]) == "foobaz"604605606def test_structured_tool_from_function_with_run_manager() -> None:607 """Test args and schema of structured tool when using callbacks."""608609 def foo( # noqa: D417610 bar: int, baz: str, callbacks: CallbackManagerForToolRun | None = None611 ) -> str:612 """Docstring.613614 Args:615 bar: int616 baz: str617 """618 assert callbacks is not None619 return str(bar) + baz620621 handler = FakeCallbackHandler()622 structured_tool = StructuredTool.from_function(foo)623624 assert structured_tool.args == {625 "bar": {"title": "Bar", "type": "integer"},626 "baz": {"title": "Baz", "type": "string"},627 }628629 assert _schema(structured_tool.args_schema) == {630 "properties": {631 "bar": {"title": "Bar", "type": "integer"},632 "baz": {"title": "Baz", "type": "string"},633 },634 "description": inspect.getdoc(foo),635 "title": "foo",636 "type": "object",637 "required": ["bar", "baz"],638 }639640 assert (641 structured_tool.run(642 tool_input={"bar": "10", "baz": "baz"}, run_manger=[handler]643 )644 == "10baz"645 )646647648def test_structured_tool_from_parameterless_function() -> None:649 """Test parameterless function of structured tool."""650651 def foo() -> str:652 """Docstring."""653 return "invoke foo"654655 structured_tool = StructuredTool.from_function(foo)656657 assert structured_tool.run({}) == "invoke foo"658 assert structured_tool.run("") == "invoke foo"659660661def test_named_tool_decorator() -> None:662 """Test functionality when arguments are provided as input to decorator."""663664 @tool("search")665 def search_api(query: str) -> str:666 """Search the API for the query."""667 assert isinstance(query, str)668 return f"API result - {query}"669670 assert isinstance(search_api, BaseTool)671 assert search_api.name == "search"672 assert not search_api.return_direct673 assert search_api.run({"query": "foo"}) == "API result - foo"674675676def test_named_tool_decorator_return_direct() -> None:677 """Test functionality when arguments and return direct are provided as input."""678679 @tool("search", return_direct=True)680 def search_api(query: str, *args: Any) -> str:681 """Search the API for the query."""682 return "API result"683684 assert isinstance(search_api, BaseTool)685 assert search_api.name == "search"686 assert search_api.return_direct687 assert search_api.run({"query": "foo"}) == "API result"688689690def test_unnamed_tool_decorator_return_direct() -> None:691 """Test functionality when only return direct is provided."""692693 @tool(return_direct=True)694 def search_api(query: str) -> str:695 """Search the API for the query."""696 assert isinstance(query, str)697 return "API result"698699 assert isinstance(search_api, BaseTool)700 assert search_api.name == "search_api"701 assert search_api.return_direct702 assert search_api.run({"query": "foo"}) == "API result"703704705def test_tool_with_kwargs() -> None:706 """Test functionality when only return direct is provided."""707708 @tool(return_direct=True)709 def search_api(710 arg_0: str,711 arg_1: float = 4.3,712 ping: str = "hi",713 ) -> str:714 """Search the API for the query."""715 return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}"716717 assert isinstance(search_api, BaseTool)718 result = search_api.run(719 tool_input={720 "arg_0": "foo",721 "arg_1": 3.2,722 "ping": "pong",723 }724 )725 assert result == "arg_0=foo, arg_1=3.2, ping=pong"726727 result = search_api.run(728 tool_input={729 "arg_0": "foo",730 }731 )732 assert result == "arg_0=foo, arg_1=4.3, ping=hi"733 # For backwards compatibility, we still accept a single str arg734 result = search_api.run("foobar")735 assert result == "arg_0=foobar, arg_1=4.3, ping=hi"736737738def test_missing_docstring() -> None:739 """Test error is raised when docstring is missing."""740 # expect to throw a value error if there's no docstring741 with pytest.raises(ValueError, match="Function must have a docstring"):742743 @tool744 def search_api(query: str) -> str:745 return "API result"746747 @tool748 class MyTool(BaseModel):749 foo: str750751 assert not MyTool.description # type: ignore[attr-defined]752753754def test_create_tool_positional_args() -> None:755 """Test that positional arguments are allowed."""756 test_tool = Tool("test_name", lambda x: x, "test_description")757 assert test_tool.invoke("foo") == "foo"758 assert test_tool.name == "test_name"759 assert test_tool.description == "test_description"760 assert test_tool.is_single_input761762763def test_create_tool_keyword_args() -> None:764 """Test that keyword arguments are allowed."""765 test_tool = Tool(name="test_name", func=lambda x: x, description="test_description")766 assert test_tool.is_single_input767 assert test_tool.invoke("foo") == "foo"768 assert test_tool.name == "test_name"769 assert test_tool.description == "test_description"770771772async def test_create_async_tool() -> None:773 """Test that async tools are allowed."""774775 async def _test_func(x: str) -> str:776 return x777778 test_tool = Tool(779 name="test_name",780 func=lambda x: x,781 description="test_description",782 coroutine=_test_func,783 )784 assert test_tool.is_single_input785 assert test_tool.invoke("foo") == "foo"786 assert test_tool.name == "test_name"787 assert test_tool.description == "test_description"788 assert test_tool.coroutine is not None789 assert await test_tool.arun("foo") == "foo"790791792class _FakeExceptionTool(BaseTool):793 name: str = "exception"794 description: str = "an exception-throwing tool"795 exception: Exception = ToolException()796797 def _run(self) -> str:798 raise self.exception799800 async def _arun(self) -> str:801 raise self.exception802803804def test_exception_handling_bool() -> None:805 tool_ = _FakeExceptionTool(handle_tool_error=True)806 expected = "Tool execution error"807 actual = tool_.run({})808 assert expected == actual809810811def test_exception_handling_str() -> None:812 expected = "foo bar"813 tool_ = _FakeExceptionTool(handle_tool_error=expected)814 actual = tool_.run({})815 assert expected == actual816817818def test_exception_handling_callable() -> None:819 expected = "foo bar"820821 def handling(e: ToolException) -> str:822 return expected823824 tool_ = _FakeExceptionTool(handle_tool_error=handling)825 actual = tool_.run({})826 assert expected == actual827828829def test_exception_handling_non_tool_exception() -> None:830 tool_ = _FakeExceptionTool(exception=ValueError("some error"))831 with pytest.raises(ValueError, match="some error"):832 tool_.run({})833834835async def test_async_exception_handling_bool() -> None:836 tool_ = _FakeExceptionTool(handle_tool_error=True)837 expected = "Tool execution error"838 actual = await tool_.arun({})839 assert expected == actual840841842async def test_async_exception_handling_str() -> None:843 expected = "foo bar"844 tool_ = _FakeExceptionTool(handle_tool_error=expected)845 actual = await tool_.arun({})846 assert expected == actual847848849async def test_async_exception_handling_callable() -> None:850 expected = "foo bar"851852 def handling(e: ToolException) -> str:853 return expected854855 tool_ = _FakeExceptionTool(handle_tool_error=handling)856 actual = await tool_.arun({})857 assert expected == actual858859860async def test_async_exception_handling_non_tool_exception() -> None:861 tool_ = _FakeExceptionTool(exception=ValueError("some error"))862 with pytest.raises(ValueError, match="some error"):863 await tool_.arun({})864865866def test_structured_tool_from_function() -> None:867 """Test that structured tools can be created from functions."""868869 def foo(bar: int, baz: str) -> str:870 """Docstring thing.871872 Args:873 bar: the bar value874 baz: the baz value875 """876 raise NotImplementedError877878 structured_tool = StructuredTool.from_function(foo)879 assert structured_tool.name == "foo"880 assert structured_tool.args == {881 "bar": {"title": "Bar", "type": "integer"},882 "baz": {"title": "Baz", "type": "string"},883 }884885 assert _schema(structured_tool.args_schema) == {886 "title": "foo",887 "type": "object",888 "description": inspect.getdoc(foo),889 "properties": {890 "bar": {"title": "Bar", "type": "integer"},891 "baz": {"title": "Baz", "type": "string"},892 },893 "required": ["bar", "baz"],894 }895896 assert foo.__doc__ is not None897 assert structured_tool.description == textwrap.dedent(foo.__doc__.strip())898899900def test_validation_error_handling_bool() -> None:901 """Test that validation errors are handled correctly."""902 expected = "Tool input validation error"903 tool_ = _MockStructuredTool(handle_validation_error=True)904 actual = tool_.run({})905 assert expected == actual906907908def test_validation_error_handling_str() -> None:909 """Test that validation errors are handled correctly."""910 expected = "foo bar"911 tool_ = _MockStructuredTool(handle_validation_error=expected)912 actual = tool_.run({})913 assert expected == actual914915916def test_validation_error_handling_callable() -> None:917 """Test that validation errors are handled correctly."""918 expected = "foo bar"919920 def handling(e: ValidationError | ValidationErrorV1) -> str:921 return expected922923 tool_ = _MockStructuredTool(handle_validation_error=handling)924 actual = tool_.run({})925 assert expected == actual926927928@pytest.mark.parametrize(929 "handler",930 [931 True,932 "foo bar",933 lambda _: "foo bar",934 ],935)936def test_validation_error_handling_non_validation_error(937 *,938 handler: bool | str | Callable[[ValidationError | ValidationErrorV1], str],939) -> None:940 """Test that validation errors are handled correctly."""941942 class _RaiseNonValidationErrorTool(BaseTool):943 name: str = "raise_non_validation_error_tool"944 description: str = "A tool that raises a non-validation error"945946 def _parse_input(947 self,948 tool_input: str | dict,949 tool_call_id: str | None,950 ) -> str | dict[str, Any]:951 raise NotImplementedError952953 @override954 def _run(self) -> str:955 return "dummy"956957 @override958 async def _arun(self) -> str:959 return "dummy"960961 tool_ = _RaiseNonValidationErrorTool(handle_validation_error=handler)962 with pytest.raises(NotImplementedError):963 tool_.run({})964965966async def test_async_validation_error_handling_bool() -> None:967 """Test that validation errors are handled correctly."""968 expected = "Tool input validation error"969 tool_ = _MockStructuredTool(handle_validation_error=True)970 actual = await tool_.arun({})971 assert expected == actual972973974async def test_async_validation_error_handling_str() -> None:975 """Test that validation errors are handled correctly."""976 expected = "foo bar"977 tool_ = _MockStructuredTool(handle_validation_error=expected)978 actual = await tool_.arun({})979 assert expected == actual980981982async def test_async_validation_error_handling_callable() -> None:983 """Test that validation errors are handled correctly."""984 expected = "foo bar"985986 def handling(e: ValidationError | ValidationErrorV1) -> str:987 return expected988989 tool_ = _MockStructuredTool(handle_validation_error=handling)990 actual = await tool_.arun({})991 assert expected == actual992993994@pytest.mark.parametrize(995 "handler",996 [997 True,998 "foo bar",999 lambda _: "foo bar",1000 ],1001)1002async def test_async_validation_error_handling_non_validation_error(1003 *,1004 handler: bool | str | Callable[[ValidationError | ValidationErrorV1], str],1005) -> None:1006 """Test that validation errors are handled correctly."""10071008 class _RaiseNonValidationErrorTool(BaseTool):1009 name: str = "raise_non_validation_error_tool"1010 description: str = "A tool that raises a non-validation error"10111012 def _parse_input(1013 self,1014 tool_input: str | dict,1015 tool_call_id: str | None,1016 ) -> str | dict[str, Any]:1017 raise NotImplementedError10181019 @override1020 def _run(self) -> str:1021 return "dummy"10221023 @override1024 async def _arun(self) -> str:1025 return "dummy"10261027 tool_ = _RaiseNonValidationErrorTool(handle_validation_error=handler)1028 with pytest.raises(NotImplementedError):1029 await tool_.arun({})103010311032def test_optional_subset_model_rewrite() -> None:1033 class MyModel(BaseModel):1034 a: str | None = None1035 b: str1036 c: list[str | None] | None = None10371038 model2 = _create_subset_model("model2", MyModel, ["a", "b", "c"])10391040 assert set(_schema(model2)["required"]) == {"b"}104110421043@pytest.mark.parametrize(1044 ("inputs", "expected"),1045 [1046 # Check not required1047 ({"bar": "bar"}, {"bar": "bar", "baz": 3, "buzz": "buzz"}),1048 # Check overwritten1049 (1050 {"bar": "bar", "baz": 4, "buzz": "not-buzz"},1051 {"bar": "bar", "baz": 4, "buzz": "not-buzz"},1052 ),1053 # Check validation error when missing1054 ({}, None),1055 # Check validation error when wrong type1056 ({"bar": "bar", "baz": "not-an-int"}, None),1057 # Check OK when None explicitly passed1058 ({"bar": "bar", "baz": None}, {"bar": "bar", "baz": None, "buzz": "buzz"}),1059 ],1060)1061def test_tool_invoke_optional_args(inputs: dict, expected: dict | None) -> None:1062 @tool1063 def foo(bar: str, baz: int | None = 3, buzz: str | None = "buzz") -> dict:1064 """The foo."""1065 return {1066 "bar": bar,1067 "baz": baz,1068 "buzz": buzz,1069 }10701071 if expected is not None:1072 assert foo.invoke(inputs) == expected1073 else:1074 with pytest.raises(ValidationError):1075 foo.invoke(inputs)107610771078def test_tool_pass_context() -> None:1079 @tool1080 def foo(bar: str) -> str:1081 """The foo."""1082 config = ensure_config()1083 assert config["configurable"]["foo"] == "not-bar"1084 assert bar == "baz"1085 return bar10861087 assert foo.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"108810891090@pytest.mark.skipif(1091 sys.version_info < (3, 11),1092 reason="requires python3.11 or higher",1093)1094async def test_async_tool_pass_context() -> None:1095 @tool1096 async def foo(bar: str) -> str:1097 """The foo."""1098 config = ensure_config()1099 assert config["configurable"]["foo"] == "not-bar"1100 assert bar == "baz"1101 return bar11021103 assert (1104 await foo.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"1105 )110611071108def assert_bar(bar: Any, bar_config: RunnableConfig) -> Any:1109 assert bar_config["configurable"]["foo"] == "not-bar"1110 assert bar == "baz"1111 return bar111211131114@tool1115def foo(bar: Any, bar_config: RunnableConfig) -> Any:1116 """The foo."""1117 return assert_bar(bar, bar_config)111811191120@tool1121async def afoo(bar: Any, bar_config: RunnableConfig) -> Any:1122 """The foo."""1123 return assert_bar(bar, bar_config)112411251126@tool(infer_schema=False)1127def simple_foo(bar: Any, bar_config: RunnableConfig) -> Any:1128 """The foo."""1129 return assert_bar(bar, bar_config)113011311132@tool(infer_schema=False)1133async def asimple_foo(bar: Any, bar_config: RunnableConfig) -> Any:1134 """The foo."""1135 return assert_bar(bar, bar_config)113611371138class FooBase(BaseTool):1139 name: str = "Foo"1140 description: str = "Foo"11411142 @override1143 def _run(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:1144 return assert_bar(bar, bar_config)114511461147class AFooBase(FooBase):1148 @override1149 async def _arun(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:1150 return assert_bar(bar, bar_config)115111521153@pytest.mark.parametrize("tool", [foo, simple_foo, FooBase(), AFooBase()])1154def test_tool_pass_config(tool: BaseTool) -> None:1155 assert tool.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"11561157 # Test we don't mutate tool calls1158 tool_call = {1159 "name": tool.name,1160 "args": {"bar": "baz"},1161 "id": "abc123",1162 "type": "tool_call",1163 }1164 _ = tool.invoke(tool_call, {"configurable": {"foo": "not-bar"}})1165 assert tool_call["args"] == {"bar": "baz"}116611671168class FooBaseNonPickleable(FooBase):1169 @override1170 def _run(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:1171 return True117211731174def test_tool_pass_config_non_pickleable() -> None:1175 tool = FooBaseNonPickleable()11761177 args = {"bar": threading.Lock()}1178 tool_call = {1179 "name": tool.name,1180 "args": args,1181 "id": "abc123",1182 "type": "tool_call",1183 }1184 _ = tool.invoke(tool_call, {"configurable": {"foo": "not-bar"}})1185 assert tool_call["args"] == args118611871188@pytest.mark.parametrize(1189 "tool", [foo, afoo, simple_foo, asimple_foo, FooBase(), AFooBase()]1190)1191async def test_async_tool_pass_config(tool: BaseTool) -> None:1192 assert (1193 await tool.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}})1194 == "baz"1195 )119611971198def test_tool_description() -> None:1199 def foo(bar: str) -> str:1200 """The foo."""1201 return bar12021203 foo1 = tool(foo)1204 assert foo1.description == "The foo."12051206 foo2 = StructuredTool.from_function(foo)1207 assert foo2.description == "The foo."120812091210def test_tool_arg_descriptions() -> None:1211 def foo(bar: str, baz: int) -> str:1212 """The foo.12131214 Args:1215 bar: The bar.1216 baz: The baz.1217 """1218 return bar12191220 foo1 = tool(foo)1221 args_schema = _schema(foo1.args_schema)1222 assert args_schema == {1223 "title": "foo",1224 "type": "object",1225 "description": inspect.getdoc(foo),1226 "properties": {1227 "bar": {"title": "Bar", "type": "string"},1228 "baz": {"title": "Baz", "type": "integer"},1229 },1230 "required": ["bar", "baz"],1231 }12321233 # Test parses docstring1234 foo2 = tool(foo, parse_docstring=True)1235 args_schema = _schema(foo2.args_schema)1236 expected = {1237 "title": "foo",1238 "description": "The foo.",1239 "type": "object",1240 "properties": {1241 "bar": {"title": "Bar", "description": "The bar.", "type": "string"},1242 "baz": {"title": "Baz", "description": "The baz.", "type": "integer"},1243 },1244 "required": ["bar", "baz"],1245 }1246 assert args_schema == expected12471248 # Test parsing with run_manager does not raise error1249 def foo3( # noqa: D4171250 bar: str, baz: int, run_manager: CallbackManagerForToolRun | None = None1251 ) -> str:1252 """The foo.12531254 Args:1255 bar: The bar.1256 baz: The baz.1257 """1258 return bar12591260 as_tool = tool(foo3, parse_docstring=True)1261 args_schema = _schema(as_tool.args_schema)1262 assert args_schema["description"] == expected["description"]1263 assert args_schema["properties"] == expected["properties"]12641265 # Test parsing with runtime does not raise error1266 def foo3_runtime(bar: str, baz: int, runtime: Any) -> str: # noqa: D4171267 """The foo.12681269 Args:1270 bar: The bar.1271 baz: The baz.1272 """1273 return bar12741275 _ = tool(foo3_runtime, parse_docstring=True)12761277 # Test parameterless tool does not raise error for missing Args section1278 # in docstring.1279 def foo4() -> str:1280 """The foo."""1281 return "bar"12821283 as_tool = tool(foo4, parse_docstring=True)1284 args_schema = _schema(as_tool.args_schema)1285 assert args_schema["description"] == expected["description"]12861287 def foo5(run_manager: CallbackManagerForToolRun | None = None) -> str:1288 """The foo."""1289 return "bar"12901291 as_tool = tool(foo5, parse_docstring=True)1292 args_schema = _schema(as_tool.args_schema)1293 assert args_schema["description"] == expected["description"]129412951296def test_docstring_parsing() -> None:1297 expected = {1298 "title": "foo",1299 "description": "The foo.",1300 "type": "object",1301 "properties": {1302 "bar": {"title": "Bar", "description": "The bar.", "type": "string"},1303 "baz": {"title": "Baz", "description": "The baz.", "type": "integer"},1304 },1305 "required": ["bar", "baz"],1306 }13071308 # Simple case1309 def foo(bar: str, baz: int) -> str:1310 """The foo.13111312 Args:1313 bar: The bar.1314 baz: The baz.1315 """1316 return bar13171318 as_tool = tool(foo, parse_docstring=True)1319 args_schema = _schema(as_tool.args_schema)1320 assert args_schema["description"] == "The foo."1321 assert args_schema["properties"] == expected["properties"]13221323 # Multi-line description1324 def foo2(bar: str, baz: int) -> str:1325 """The foo.13261327 Additional description here.13281329 Args:1330 bar: The bar.1331 baz: The baz.1332 """1333 return bar13341335 as_tool = tool(foo2, parse_docstring=True)1336 args_schema2 = _schema(as_tool.args_schema)1337 assert args_schema2["description"] == "The foo. Additional description here."1338 assert args_schema2["properties"] == expected["properties"]13391340 # Multi-line with Returns block1341 def foo3(bar: str, baz: int) -> str:1342 """The foo.13431344 Additional description here.13451346 Args:1347 bar: The bar.1348 baz: The baz.13491350 Returns:1351 description of returned value.1352 """1353 return bar13541355 as_tool = tool(foo3, parse_docstring=True)1356 args_schema3 = _schema(as_tool.args_schema)1357 args_schema3["title"] = "foo2"1358 assert args_schema2 == args_schema313591360 # Single argument1361 def foo4(bar: str) -> str:1362 """The foo.13631364 Args:1365 bar: The bar.1366 """1367 return bar13681369 as_tool = tool(foo4, parse_docstring=True)1370 args_schema4 = _schema(as_tool.args_schema)1371 assert args_schema4["description"] == "The foo."1372 assert args_schema4["properties"] == {1373 "bar": {"description": "The bar.", "title": "Bar", "type": "string"}1374 }137513761377def test_tool_invalid_docstrings() -> None:1378 """Test invalid docstrings."""13791380 def foo3(bar: str, baz: int) -> str:1381 """The foo."""1382 return bar13831384 def foo4(bar: str, baz: int) -> str:1385 """The foo.1386 Args:1387 bar: The bar.1388 baz: The baz.1389 """ # noqa: D205,D411 # We're intentionally testing bad formatting.1390 return bar13911392 for func in {foo3, foo4}:1393 with pytest.raises(ValueError, match="Found invalid Google-Style docstring"):1394 _ = tool(func, parse_docstring=True)13951396 def foo5(bar: str, baz: int) -> str: # noqa: D4171397 """The foo.13981399 Args:1400 banana: The bar.1401 monkey: The baz.1402 """1403 return bar14041405 with pytest.raises(1406 ValueError, match="Arg banana in docstring not found in function signature"1407 ):1408 _ = tool(foo5, parse_docstring=True)140914101411def test_tool_annotated_descriptions() -> None:1412 def foo(1413 bar: Annotated[str, "this is the bar"], baz: Annotated[int, "this is the baz"]1414 ) -> str:1415 """The foo.14161417 Returns:1418 The bar only.1419 """1420 return bar14211422 foo1 = tool(foo)1423 args_schema = _schema(foo1.args_schema)1424 assert args_schema == {1425 "title": "foo",1426 "type": "object",1427 "description": inspect.getdoc(foo),1428 "properties": {1429 "bar": {"title": "Bar", "type": "string", "description": "this is the bar"},1430 "baz": {1431 "title": "Baz",1432 "type": "integer",1433 "description": "this is the baz",1434 },1435 },1436 "required": ["bar", "baz"],1437 }143814391440def test_tool_field_description_preserved() -> None:1441 """Test that `Field(description=...)` is preserved in `@tool` decorator."""14421443 @tool1444 def my_tool(1445 topic: Annotated[str, Field(description="The research topic")],1446 depth: Annotated[int, Field(description="Search depth level")] = 3,1447 ) -> str:1448 """A tool for research."""1449 return f"{topic} at depth {depth}"14501451 args_schema = _schema(my_tool.args_schema)1452 assert args_schema == {1453 "title": "my_tool",1454 "type": "object",1455 "description": "A tool for research.",1456 "properties": {1457 "topic": {1458 "title": "Topic",1459 "type": "string",1460 "description": "The research topic",1461 },1462 "depth": {1463 "title": "Depth",1464 "type": "integer",1465 "description": "Search depth level",1466 "default": 3,1467 },1468 },1469 "required": ["topic"],1470 }147114721473def test_tool_call_input_tool_message_output() -> None:1474 tool_call = {1475 "name": "structured_api",1476 "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},1477 "id": "123",1478 "type": "tool_call",1479 }1480 tool = _MockStructuredTool()1481 expected = ToolMessage(1482 "1 True {'img': 'base64string...'}", tool_call_id="123", name="structured_api"1483 )1484 actual = tool.invoke(tool_call)1485 assert actual == expected14861487 tool_call.pop("type")1488 with pytest.raises(ValidationError):1489 tool.invoke(tool_call)149014911492@pytest.mark.parametrize("block_type", [*TOOL_MESSAGE_BLOCK_TYPES, "bad"])1493def test_tool_content_block_output(block_type: str) -> None:1494 @tool1495 def my_tool(query: str) -> list[dict[str, Any]]:1496 """Test tool."""1497 return [{"type": block_type, "foo": "bar"}]14981499 tool_call = {1500 "type": "tool_call",1501 "name": "my_tool",1502 "args": {"query": "baz"},1503 "id": "call_abc123",1504 }15051506 result = my_tool.invoke(tool_call)1507 assert isinstance(result, ToolMessage)15081509 if block_type in TOOL_MESSAGE_BLOCK_TYPES:1510 assert result.content == [{"type": block_type, "foo": "bar"}]1511 else:1512 assert result.content == '[{"type": "bad", "foo": "bar"}]'151315141515class _MockStructuredToolWithRawOutput(BaseTool):1516 name: str = "structured_api"1517 args_schema: type[BaseModel] = _MockSchema1518 description: str = "A Structured Tool"1519 response_format: Literal["content_and_artifact"] = "content_and_artifact"15201521 @override1522 def _run(1523 self,1524 arg1: int,1525 arg2: bool,1526 arg3: dict[str, Any] | None = None,1527 ) -> tuple[str, dict[str, Any]]:1528 return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}152915301531@tool("structured_api", response_format="content_and_artifact")1532def _mock_structured_tool_with_artifact(1533 *, arg1: int, arg2: bool, arg3: dict[str, str] | None = None1534) -> tuple[str, dict[str, Any]]:1535 """A Structured Tool."""1536 return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}153715381539@pytest.mark.parametrize(1540 "tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_artifact]1541)1542def test_tool_call_input_tool_message_with_artifact(tool: BaseTool) -> None:1543 tool_call: dict[str, Any] = {1544 "name": "structured_api",1545 "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},1546 "id": "123",1547 "type": "tool_call",1548 }1549 expected = ToolMessage(1550 "1 True", artifact=tool_call["args"], tool_call_id="123", name="structured_api"1551 )1552 actual = tool.invoke(tool_call)1553 assert actual == expected15541555 tool_call.pop("type")1556 with pytest.raises(ValidationError):1557 tool.invoke(tool_call)15581559 actual_content = tool.invoke(tool_call["args"])1560 assert actual_content == expected.content156115621563def test_convert_from_runnable_dict() -> None:1564 # Test with typed dict input1565 class Args(TypedDict):1566 a: int1567 b: list[int]15681569 def f(x: Args) -> str:1570 return str(x["a"] * max(x["b"]))15711572 runnable = RunnableLambda(f)1573 as_tool = runnable.as_tool()1574 args_schema = as_tool.args_schema1575 assert args_schema is not None1576 assert _schema(args_schema) == {1577 "title": "f",1578 "type": "object",1579 "properties": {1580 "a": {"title": "A", "type": "integer"},1581 "b": {"title": "B", "type": "array", "items": {"type": "integer"}},1582 },1583 "required": ["a", "b"],1584 }1585 assert as_tool.description1586 result = as_tool.invoke({"a": 3, "b": [1, 2]})1587 assert result == "6"15881589 as_tool = runnable.as_tool(name="my tool", description="test description")1590 assert as_tool.name == "my tool"1591 assert as_tool.description == "test description"15921593 # Dict without typed input-- must supply schema1594 def g(x: dict[str, Any]) -> str:1595 return str(x["a"] * max(x["b"]))15961597 # Specify via args_schema:1598 class GSchema(BaseModel):1599 """Apply a function to an integer and list of integers."""16001601 a: int = Field(..., description="Integer")1602 b: list[int] = Field(..., description="List of ints")16031604 runnable2 = RunnableLambda(g)1605 as_tool2 = runnable2.as_tool(GSchema)1606 as_tool2.invoke({"a": 3, "b": [1, 2]})16071608 # Specify via arg_types:1609 runnable3 = RunnableLambda(g)1610 as_tool3 = runnable3.as_tool(arg_types={"a": int, "b": list[int]})1611 result = as_tool3.invoke({"a": 3, "b": [1, 2]})1612 assert result == "6"16131614 # Test with config1615 def h(x: dict[str, Any]) -> str:1616 config = ensure_config()1617 assert config["configurable"]["foo"] == "not-bar"1618 return str(x["a"] * max(x["b"]))16191620 runnable4 = RunnableLambda(h)1621 as_tool4 = runnable4.as_tool(arg_types={"a": int, "b": list[int]})1622 result = as_tool4.invoke(1623 {"a": 3, "b": [1, 2]}, config={"configurable": {"foo": "not-bar"}}1624 )1625 assert result == "6"162616271628def test_convert_from_runnable_other() -> None:1629 # String input1630 def f(x: str) -> str:1631 return x + "a"16321633 def g(x: str) -> str:1634 return x + "z"16351636 runnable = RunnableLambda(f) | g1637 as_tool = runnable.as_tool()1638 args_schema = as_tool.args_schema1639 assert args_schema is None1640 assert as_tool.description16411642 result = as_tool.invoke("b")1643 assert result == "baz"16441645 # Test with config1646 def h(x: str) -> str:1647 config = ensure_config()1648 assert config["configurable"]["foo"] == "not-bar"1649 return x + "a"16501651 runnable2 = RunnableLambda(h)1652 as_tool2 = runnable2.as_tool()1653 result2 = as_tool2.invoke("b", config={"configurable": {"foo": "not-bar"}})1654 assert result2 == "ba"165516561657@tool("foo", parse_docstring=True)1658def injected_tool(x: int, y: Annotated[str, InjectedToolArg]) -> str:1659 """Foo.16601661 Args:1662 x: abc1663 y: 1231664 """1665 return y166616671668class InjectedTool(BaseTool):1669 name: str = "foo"1670 description: str = "foo."16711672 @override1673 def _run(self, x: int, y: Annotated[str, InjectedToolArg]) -> Any:1674 """Foo.16751676 Args:1677 x: abc1678 y: 1231679 """1680 return y168116821683class fooSchema(BaseModel): # noqa: N8011684 """foo."""16851686 x: int = Field(..., description="abc")1687 y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(1688 ..., description="123"1689 )169016911692class InjectedToolWithSchema(BaseTool):1693 name: str = "foo"1694 description: str = "foo."1695 args_schema: type[BaseModel] = fooSchema16961697 @override1698 def _run(self, x: int, y: str) -> Any:1699 return y170017011702@tool("foo", args_schema=fooSchema)1703def injected_tool_with_schema(x: int, y: str) -> str:1704 return y170517061707@pytest.mark.parametrize("tool_", [InjectedTool()])1708def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:1709 assert _schema(tool_.get_input_schema()) == {1710 "title": "foo",1711 "description": "Foo.\n\nArgs:\n x: abc\n y: 123",1712 "type": "object",1713 "properties": {1714 "x": {"title": "X", "type": "integer"},1715 "y": {"title": "Y", "type": "string"},1716 },1717 "required": ["x", "y"],1718 }1719 assert _schema(tool_.tool_call_schema) == {1720 "title": "foo",1721 "description": "foo.",1722 "type": "object",1723 "properties": {"x": {"title": "X", "type": "integer"}},1724 "required": ["x"],1725 }1726 assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"1727 assert tool_.invoke(1728 {1729 "name": "foo",1730 "args": {"x": 5, "y": "bar"},1731 "id": "123",1732 "type": "tool_call",1733 }1734 ) == ToolMessage("bar", tool_call_id="123", name="foo")1735 expected_error = (1736 ValidationError if not isinstance(tool_, InjectedTool) else TypeError1737 )1738 with pytest.raises(expected_error):1739 tool_.invoke({"x": 5})17401741 assert convert_to_openai_function(tool_) == {1742 "name": "foo",1743 "description": "foo.",1744 "parameters": {1745 "type": "object",1746 "properties": {"x": {"type": "integer"}},1747 "required": ["x"],1748 },1749 }175017511752@pytest.mark.parametrize(1753 "tool_",1754 [injected_tool_with_schema, InjectedToolWithSchema()],1755)1756def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:1757 assert _schema(tool_.get_input_schema()) == {1758 "title": "fooSchema",1759 "description": "foo.",1760 "type": "object",1761 "properties": {1762 "x": {"description": "abc", "title": "X", "type": "integer"},1763 "y": {"description": "123", "title": "Y", "type": "string"},1764 },1765 "required": ["x", "y"],1766 }1767 assert _schema(tool_.tool_call_schema) == {1768 "title": "foo",1769 "description": "foo.",1770 "type": "object",1771 "properties": {"x": {"description": "abc", "title": "X", "type": "integer"}},1772 "required": ["x"],1773 }1774 assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"1775 assert tool_.invoke(1776 {1777 "name": "foo",1778 "args": {"x": 5, "y": "bar"},1779 "id": "123",1780 "type": "tool_call",1781 }1782 ) == ToolMessage("bar", tool_call_id="123", name="foo")1783 expected_error = (1784 ValidationError if not isinstance(tool_, InjectedTool) else TypeError1785 )1786 with pytest.raises(expected_error):1787 tool_.invoke({"x": 5})17881789 assert convert_to_openai_function(tool_) == {1790 "name": "foo",1791 "description": "foo.",1792 "parameters": {1793 "type": "object",1794 "properties": {"x": {"type": "integer", "description": "abc"}},1795 "required": ["x"],1796 },1797 }179817991800def test_tool_injected_arg() -> None:1801 tool_ = injected_tool1802 assert _schema(tool_.get_input_schema()) == {1803 "title": "foo",1804 "description": "Foo.",1805 "type": "object",1806 "properties": {1807 "x": {"description": "abc", "title": "X", "type": "integer"},1808 "y": {"description": "123", "title": "Y", "type": "string"},1809 },1810 "required": ["x", "y"],1811 }1812 assert _schema(tool_.tool_call_schema) == {1813 "title": "foo",1814 "description": "Foo.",1815 "type": "object",1816 "properties": {"x": {"description": "abc", "title": "X", "type": "integer"}},1817 "required": ["x"],1818 }1819 assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"1820 assert tool_.invoke(1821 {1822 "name": "foo",1823 "args": {"x": 5, "y": "bar"},1824 "id": "123",1825 "type": "tool_call",1826 }1827 ) == ToolMessage("bar", tool_call_id="123", name="foo")1828 expected_error = (1829 ValidationError if not isinstance(tool_, InjectedTool) else TypeError1830 )1831 with pytest.raises(expected_error):1832 tool_.invoke({"x": 5})18331834 assert convert_to_openai_function(tool_) == {1835 "name": "foo",1836 "description": "Foo.",1837 "parameters": {1838 "type": "object",1839 "properties": {"x": {"type": "integer", "description": "abc"}},1840 "required": ["x"],1841 },1842 }184318441845def test_tool_inherited_injected_arg() -> None:1846 class BarSchema(BaseModel):1847 """bar."""18481849 y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(1850 ..., description="123"1851 )18521853 class FooSchema(BarSchema):1854 """foo."""18551856 x: int = Field(..., description="abc")18571858 class InheritedInjectedArgTool(BaseTool):1859 name: str = "foo"1860 description: str = "foo."1861 args_schema: type[BaseModel] = FooSchema18621863 @override1864 def _run(self, x: int, y: str) -> Any:1865 return y18661867 tool_ = InheritedInjectedArgTool()1868 assert tool_.get_input_schema().model_json_schema() == {1869 "title": "FooSchema", # Matches the title from the provided schema1870 "description": "foo.",1871 "type": "object",1872 "properties": {1873 "x": {"description": "abc", "title": "X", "type": "integer"},1874 "y": {"description": "123", "title": "Y", "type": "string"},1875 },1876 "required": ["y", "x"],1877 }1878 # Should not include `y` since it's annotated as an injected tool arg1879 assert _get_tool_call_json_schema(tool_) == {1880 "title": "foo",1881 "description": "foo.",1882 "type": "object",1883 "properties": {"x": {"description": "abc", "title": "X", "type": "integer"}},1884 "required": ["x"],1885 }1886 assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"1887 assert tool_.invoke(1888 {1889 "name": "foo",1890 "args": {"x": 5, "y": "bar"},1891 "id": "123",1892 "type": "tool_call",1893 }1894 ) == ToolMessage("bar", tool_call_id="123", name="foo")1895 with pytest.raises(ValidationError):1896 tool_.invoke({"x": 5})18971898 assert convert_to_openai_function(tool_) == {1899 "name": "foo",1900 "description": "foo.",1901 "parameters": {1902 "type": "object",1903 "properties": {"x": {"type": "integer", "description": "abc"}},1904 "required": ["x"],1905 },1906 }190719081909def _get_parametrized_tools() -> list[Callable[..., Any]]:1910 def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str:1911 """my_tool."""1912 return "my_tool"19131914 async def my_async_tool(1915 x: int, y: str, *, some_tool: Annotated[Any, InjectedToolArg]1916 ) -> str:1917 """my_tool."""1918 return "my_tool"19191920 return [my_tool, my_async_tool]192119221923@pytest.mark.parametrize("tool_", _get_parametrized_tools())1924def test_fn_injected_arg_with_schema(tool_: Callable[..., Any]) -> None:1925 assert convert_to_openai_function(tool_) == {1926 "name": tool_.__name__,1927 "description": "my_tool.",1928 "parameters": {1929 "type": "object",1930 "properties": {1931 "x": {"type": "integer"},1932 "y": {"type": "string"},1933 },1934 "required": ["x", "y"],1935 },1936 }193719381939def generate_models() -> list[Any]:1940 """Generate a list of base models depending on the pydantic version."""19411942 class FooProper(BaseModel):1943 a: int1944 b: str19451946 return [FooProper]194719481949def generate_backwards_compatible_v1() -> list[Any]:1950 """Generate a model with pydantic 2 from the v1 namespace."""19511952 class FooV1Namespace(BaseModelV1):1953 a: int1954 b: str19551956 return [FooV1Namespace]195719581959# This generates a list of models that can be used for testing that our APIs1960# behave well with either pydantic 1 proper,1961# pydantic v1 from pydantic 2,1962# or pydantic 2 proper.1963TEST_MODELS = generate_models()19641965if sys.version_info < (3, 14):1966 TEST_MODELS += generate_backwards_compatible_v1()196719681969@pytest.mark.parametrize("pydantic_model", TEST_MODELS)1970def test_args_schema_as_pydantic(pydantic_model: Any) -> None:1971 class SomeTool(BaseTool):1972 args_schema: type[pydantic_model] = pydantic_model19731974 @override1975 def _run(self, *args: Any, **kwargs: Any) -> str:1976 return "foo"19771978 tool = SomeTool(1979 name="some_tool", description="some description", args_schema=pydantic_model1980 )19811982 assert tool.args == {1983 "a": {"title": "A", "type": "integer"},1984 "b": {"title": "B", "type": "string"},1985 }19861987 input_schema = tool.get_input_schema()1988 if issubclass(input_schema, BaseModel):1989 input_json_schema = input_schema.model_json_schema()1990 elif issubclass(input_schema, BaseModelV1):1991 input_json_schema = input_schema.schema()1992 else:1993 msg = "Unknown input schema type"1994 raise TypeError(msg)19951996 assert input_json_schema == {1997 "properties": {1998 "a": {"title": "A", "type": "integer"},1999 "b": {"title": "B", "type": "string"},2000 },
Findings
✓ No findings reported for this file.