libs/core/tests/unit_tests/test_tools.py PYTHON 3,744 lines View on github.com → Search inside
File is large — showing lines 1–2,000 of 3,744.
1"""Test the base tool implementation."""23import inspect4import json5import logging6import sys7import textwrap8import threading9from collections.abc import Callable10from dataclasses import dataclass11from datetime import datetime12from enum import Enum13from functools import partial14from typing import (15    Annotated,16    Any,17    Generic,18    Literal,19    TypeVar,20    cast,21    get_type_hints,22)2324import pytest25from pydantic import BaseModel, ConfigDict, Field, ValidationError26from pydantic.v1 import BaseModel as BaseModelV127from pydantic.v1 import ValidationError as ValidationErrorV128from typing_extensions import TypedDict, override2930from langchain_core import tools31from langchain_core.callbacks import (32    AsyncCallbackManagerForToolRun,33    CallbackManagerForToolRun,34)35from langchain_core.callbacks.manager import (36    CallbackManagerForRetrieverRun,37)38from langchain_core.documents import Document39from langchain_core.messages import ToolCall, ToolMessage40from langchain_core.messages.tool import ToolOutputMixin41from langchain_core.retrievers import BaseRetriever42from langchain_core.runnables import (43    RunnableConfig,44    RunnableLambda,45    ensure_config,46)47from langchain_core.tools import (48    BaseTool,49    StructuredTool,50    Tool,51    ToolException,52    tool,53)54from langchain_core.tools.base import (55    TOOL_MESSAGE_BLOCK_TYPES,56    ArgsSchema,57    InjectedToolArg,58    InjectedToolCallId,59    SchemaAnnotationError,60    _DirectlyInjectedToolArg,61    _format_output,62    _is_message_content_block,63    _is_message_content_type,64    get_all_basemodel_annotations,65)66from langchain_core.utils.function_calling import (67    convert_to_openai_function,68    convert_to_openai_tool,69)70from langchain_core.utils.pydantic import (71    _create_subset_model,72    create_model_v2,73)74from tests.unit_tests.fake.callbacks import FakeCallbackHandler75from tests.unit_tests.pydantic_utils import _normalize_schema, _schema7677try:78    from langgraph.prebuilt import ToolRuntime  # type: ignore[import-not-found]7980    HAS_LANGGRAPH = True81except ImportError:82    HAS_LANGGRAPH = False838485def _get_tool_call_json_schema(tool: BaseTool) -> dict[str, Any]:86    tool_schema = tool.tool_call_schema87    if isinstance(tool_schema, dict):88        return tool_schema8990    if issubclass(tool_schema, BaseModel):91        return tool_schema.model_json_schema()92    if issubclass(tool_schema, BaseModelV1):93        return tool_schema.schema()94    return {}959697def test_unnamed_decorator() -> None:98    """Test functionality with unnamed decorator."""99100    @tool101    def search_api(query: str) -> str:102        """Search the API for the query."""103        return "API result"104105    assert isinstance(search_api, BaseTool)106    assert search_api.name == "search_api"107    assert not search_api.return_direct108    assert search_api.invoke("test") == "API result"109110111class _MockSchema(BaseModel):112    """Return the arguments directly."""113114    arg1: int115    arg2: bool116    arg3: dict | None = None117118119class _MockStructuredTool(BaseTool):120    name: str = "structured_api"121    args_schema: type[BaseModel] = _MockSchema122    description: str = "A Structured Tool"123124    @override125    def _run(self, *, arg1: int, arg2: bool, arg3: dict | None = None) -> str:126        return f"{arg1} {arg2} {arg3}"127128    async def _arun(self, *, arg1: int, arg2: bool, arg3: dict | None = None) -> str:129        raise NotImplementedError130131132class _FakeOutput(ToolOutputMixin):133    """Minimal ToolOutputMixin subclass used only in tests."""134135    def __init__(self, value: int) -> None:136        self.value = value137138    def __eq__(self, other: object) -> bool:139        return isinstance(other, _FakeOutput) and self.value == other.value140141    def __hash__(self) -> int:142        return hash(self.value)143144    def __repr__(self) -> str:145        return f"_FakeOutput({self.value})"146147148def test_structured_args() -> None:149    """Test functionality with structured arguments."""150    structured_api = _MockStructuredTool()151    assert isinstance(structured_api, BaseTool)152    assert structured_api.name == "structured_api"153    expected_result = "1 True {'foo': 'bar'}"154    args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}155    assert structured_api.run(args) == expected_result156157158def test_misannotated_base_tool_raises_error() -> None:159    """Test that a BaseTool with the incorrect typehint raises an exception."""160    with pytest.raises(SchemaAnnotationError):161162        class _MisAnnotatedTool(BaseTool):163            name: str = "structured_api"164            # This would silently be ignored without the custom metaclass165            args_schema: BaseModel = _MockSchema  # type: ignore[assignment]166            description: str = "A Structured Tool"167168            @override169            def _run(self, *, arg1: int, arg2: bool, arg3: dict | None = None) -> str:170                return f"{arg1} {arg2} {arg3}"171172            async def _arun(173                self, *, arg1: int, arg2: bool, arg3: dict | None = None174            ) -> str:175                raise NotImplementedError176177178def test_forward_ref_annotated_base_tool_accepted() -> None:179    """Test that a using forward ref annotation syntax is accepted."""180181    class _ForwardRefAnnotatedTool(BaseTool):182        name: str = "structured_api"183        args_schema: "type[BaseModel]" = _MockSchema184        description: str = "A Structured Tool"185186        @override187        def _run(self, *, arg1: int, arg2: bool, arg3: dict | None = None) -> str:188            return f"{arg1} {arg2} {arg3}"189190        async def _arun(191            self, *, arg1: int, arg2: bool, arg3: dict | None = None192        ) -> str:193            raise NotImplementedError194195196def test_subclass_annotated_base_tool_accepted() -> None:197    """Test BaseTool child w/ custom schema isn't overwritten."""198199    class _ForwardRefAnnotatedTool(BaseTool):200        name: str = "structured_api"201        args_schema: type[_MockSchema] = _MockSchema202        description: str = "A Structured Tool"203204        @override205        def _run(self, *, arg1: int, arg2: bool, arg3: dict | None = None) -> str:206            return f"{arg1} {arg2} {arg3}"207208        async def _arun(209            self, *, arg1: int, arg2: bool, arg3: dict | None = None210        ) -> str:211            raise NotImplementedError212213    assert issubclass(_ForwardRefAnnotatedTool, BaseTool)214    tool = _ForwardRefAnnotatedTool()215    assert tool.args_schema == _MockSchema216217218def test_decorator_with_specified_schema() -> None:219    """Test that manually specified schemata are passed through to the tool."""220221    @tool(args_schema=_MockSchema)222    def tool_func(*, arg1: int, arg2: bool, arg3: dict | None = None) -> str:223        return f"{arg1} {arg2} {arg3}"224225    assert isinstance(tool_func, BaseTool)226    assert tool_func.args_schema == _MockSchema227228229@pytest.mark.skipif(230    sys.version_info >= (3, 14),231    reason="pydantic.v1 namespace not supported with Python 3.14+",232)233def test_decorator_with_specified_schema_pydantic_v1() -> None:234    """Test that manually specified schemata are passed through to the tool."""235236    class _MockSchemaV1(BaseModelV1):237        """Return the arguments directly."""238239        arg1: int240        arg2: bool241        arg3: dict | None = None242243    @tool(args_schema=cast("ArgsSchema", _MockSchemaV1))244    def tool_func_v1(*, arg1: int, arg2: bool, arg3: dict | None = None) -> str:245        return f"{arg1} {arg2} {arg3}"246247    assert isinstance(tool_func_v1, BaseTool)248    assert tool_func_v1.args_schema == cast("ArgsSchema", _MockSchemaV1)249250251def test_decorated_function_schema_equivalent() -> None:252    """Test that a BaseTool without a schema meets expectations."""253254    @tool255    def structured_tool_input(256        *, arg1: int, arg2: bool, arg3: dict | None = None257    ) -> str:258        """Return the arguments directly."""259        return f"{arg1} {arg2} {arg3}"260261    assert isinstance(structured_tool_input, BaseTool)262    assert structured_tool_input.args_schema is not None263    assert (264        _schema(structured_tool_input.args_schema)["properties"]265        == _schema(_MockSchema)["properties"]266        == _normalize_schema(structured_tool_input.args)267    )268269270def test_args_kwargs_filtered() -> None:271    class _SingleArgToolWithKwargs(BaseTool):272        name: str = "single_arg_tool"273        description: str = "A  single arged tool with kwargs"274275        @override276        def _run(277            self,278            some_arg: str,279            run_manager: CallbackManagerForToolRun | None = None,280            **kwargs: Any,281        ) -> str:282            return "foo"283284        async def _arun(285            self,286            some_arg: str,287            run_manager: AsyncCallbackManagerForToolRun | None = None,288            **kwargs: Any,289        ) -> str:290            raise NotImplementedError291292    tool = _SingleArgToolWithKwargs()293    assert tool.is_single_input294295    class _VarArgToolWithKwargs(BaseTool):296        name: str = "single_arg_tool"297        description: str = "A single arged tool with kwargs"298299        @override300        def _run(301            self,302            *args: Any,303            run_manager: CallbackManagerForToolRun | None = None,304            **kwargs: Any,305        ) -> str:306            return "foo"307308        async def _arun(309            self,310            *args: Any,311            run_manager: AsyncCallbackManagerForToolRun | None = None,312            **kwargs: Any,313        ) -> str:314            raise NotImplementedError315316    tool2 = _VarArgToolWithKwargs()317    assert tool2.is_single_input318319320def test_structured_args_decorator_no_infer_schema() -> None:321    """Test functionality with structured arguments parsed as a decorator."""322323    @tool(infer_schema=False)324    def structured_tool_input(325        arg1: int, arg2: float | datetime, opt_arg: dict | None = None326    ) -> str:327        """Return the arguments directly."""328        return f"{arg1}, {arg2}, {opt_arg}"329330    assert isinstance(structured_tool_input, BaseTool)331    assert structured_tool_input.name == "structured_tool_input"332    args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}333    with pytest.raises(ToolException):334        assert structured_tool_input.run(args)335336337def test_structured_single_str_decorator_no_infer_schema() -> None:338    """Test functionality with structured arguments parsed as a decorator."""339340    @tool(infer_schema=False)341    def unstructured_tool_input(tool_input: str) -> str:342        """Return the arguments directly."""343        assert isinstance(tool_input, str)344        return f"{tool_input}"345346    assert isinstance(unstructured_tool_input, BaseTool)347    assert unstructured_tool_input.args_schema is None348    assert unstructured_tool_input.run("foo") == "foo"349350351def test_structured_tool_types_parsed() -> None:352    """Test the non-primitive types are correctly passed to structured tools."""353354    class SomeEnum(Enum):355        A = "a"356        B = "b"357358    class SomeBaseModel(BaseModel):359        foo: str360361    @tool362    def structured_tool(363        some_enum: SomeEnum,364        some_base_model: SomeBaseModel,365    ) -> dict:366        """Return the arguments directly."""367        return {368            "some_enum": some_enum,369            "some_base_model": some_base_model,370        }371372    assert isinstance(structured_tool, StructuredTool)373    args = {374        "some_enum": SomeEnum.A.value,375        "some_base_model": SomeBaseModel(foo="bar").model_dump(),376    }377    result = structured_tool.run(json.loads(json.dumps(args)))378    expected = {379        "some_enum": SomeEnum.A,380        "some_base_model": SomeBaseModel(foo="bar"),381    }382    assert result == expected383384385@pytest.mark.skipif(386    sys.version_info >= (3, 14),387    reason="pydantic.v1 namespace not supported with Python 3.14+",388)389def test_structured_tool_types_parsed_pydantic_v1() -> None:390    """Test the non-primitive types are correctly passed to structured tools."""391392    class SomeBaseModel(BaseModelV1):393        foo: str394395    class AnotherBaseModel(BaseModelV1):396        bar: str397398    @tool399    def structured_tool(some_base_model: SomeBaseModel) -> AnotherBaseModel:400        """Return the arguments directly."""401        return AnotherBaseModel(bar=some_base_model.foo)402403    assert isinstance(structured_tool, StructuredTool)404405    expected = AnotherBaseModel(bar="baz")406    for arg in [407        SomeBaseModel(foo="baz"),408        SomeBaseModel(foo="baz").dict(),409    ]:410        args = {"some_base_model": arg}411        result = structured_tool.run(args)412        assert result == expected413414415def test_structured_tool_types_parsed_pydantic_mixed() -> None:416    """Test handling of tool with mixed Pydantic version arguments."""417418    class SomeBaseModel(BaseModelV1):419        foo: str420421    class AnotherBaseModel(BaseModel):422        bar: str423424    with pytest.raises(NotImplementedError):425426        @tool427        def structured_tool(428            some_base_model: SomeBaseModel, another_base_model: AnotherBaseModel429        ) -> None:430            """Return the arguments directly."""431432433def test_base_tool_inheritance_base_schema() -> None:434    """Test schema is correctly inferred when inheriting from BaseTool."""435436    class _MockSimpleTool(BaseTool):437        name: str = "simple_tool"438        description: str = "A Simple Tool"439440        @override441        def _run(self, tool_input: str) -> str:442            return f"{tool_input}"443444        @override445        async def _arun(self, tool_input: str) -> str:446            raise NotImplementedError447448    simple_tool = _MockSimpleTool()449    assert simple_tool.args_schema is None450    expected_args = {"tool_input": {"title": "Tool Input", "type": "string"}}451    assert simple_tool.args == expected_args452453454def test_tool_lambda_args_schema() -> None:455    """Test args schema inference when the tool argument is a lambda function."""456    tool = Tool(457        name="tool",458        description="A tool",459        func=lambda tool_input: tool_input,460    )461    assert tool.args_schema is None462    expected_args = {"tool_input": {"type": "string"}}463    assert tool.args == expected_args464465466def test_structured_tool_from_function_docstring() -> None:467    """Test that structured tools can be created from functions."""468469    def foo(bar: int, baz: str) -> str:470        """Docstring.471472        Args:473            bar: the bar value474            baz: the baz value475        """476        raise NotImplementedError477478    structured_tool = StructuredTool.from_function(foo)479    assert structured_tool.name == "foo"480    assert structured_tool.args == {481        "bar": {"title": "Bar", "type": "integer"},482        "baz": {"title": "Baz", "type": "string"},483    }484485    assert _schema(structured_tool.args_schema) == {486        "properties": {487            "bar": {"title": "Bar", "type": "integer"},488            "baz": {"title": "Baz", "type": "string"},489        },490        "description": inspect.getdoc(foo),491        "title": "foo",492        "type": "object",493        "required": ["bar", "baz"],494    }495496    assert foo.__doc__ is not None497    assert structured_tool.description == textwrap.dedent(foo.__doc__.strip())498499500def test_structured_tool_from_function_docstring_complex_args() -> None:501    """Test that structured tools can be created from functions."""502503    def foo(bar: int, baz: list[str]) -> str:504        """Docstring.505506        Args:507            bar: int508            baz: list[str]509        """510        raise NotImplementedError511512    structured_tool = StructuredTool.from_function(foo)513    assert structured_tool.name == "foo"514    assert structured_tool.args == {515        "bar": {"title": "Bar", "type": "integer"},516        "baz": {517            "title": "Baz",518            "type": "array",519            "items": {"type": "string"},520        },521    }522523    assert _schema(structured_tool.args_schema) == {524        "properties": {525            "bar": {"title": "Bar", "type": "integer"},526            "baz": {527                "title": "Baz",528                "type": "array",529                "items": {"type": "string"},530            },531        },532        "description": inspect.getdoc(foo),533        "title": "foo",534        "type": "object",535        "required": ["bar", "baz"],536    }537538    assert foo.__doc__ is not None539    assert structured_tool.description == textwrap.dedent(foo.__doc__).strip()540541542def test_structured_tool_lambda_multi_args_schema() -> None:543    """Test args schema inference when the tool argument is a lambda function."""544    tool = StructuredTool.from_function(545        name="tool",546        description="A tool",547        func=lambda tool_input, other_arg: f"{tool_input}{other_arg}",548    )549    assert tool.args_schema is not None550    expected_args = {551        "tool_input": {"title": "Tool Input"},552        "other_arg": {"title": "Other Arg"},553    }554    assert tool.args == expected_args555556557def test_tool_partial_function_args_schema() -> None:558    """Test args schema inference when the tool argument is a partial function."""559560    def func(tool_input: str, other_arg: str) -> str:561        assert isinstance(tool_input, str)562        assert isinstance(other_arg, str)563        return tool_input + other_arg564565    tool = Tool(566        name="tool",567        description="A tool",568        func=partial(func, other_arg="foo"),569    )570    assert tool.run("bar") == "barfoo"571572573def test_empty_args_decorator() -> None:574    """Test inferred schema of decorated fn with no args."""575576    @tool577    def empty_tool_input() -> str:578        """Return a constant."""579        return "the empty result"580581    assert isinstance(empty_tool_input, BaseTool)582    assert empty_tool_input.name == "empty_tool_input"583    assert empty_tool_input.args == {}584    assert empty_tool_input.run({}) == "the empty result"585586587def test_tool_from_function_with_run_manager() -> None:588    """Test run of tool when using run_manager."""589590    def foo(bar: str, callbacks: CallbackManagerForToolRun | None = None) -> str:  # noqa: D417591        """Docstring.592593        Args:594            bar: str.595        """596        assert callbacks is not None597        return "foo" + bar598599    handler = FakeCallbackHandler()600    tool = Tool.from_function(foo, name="foo", description="Docstring")601602    assert tool.run(tool_input={"bar": "bar"}, run_manager=[handler]) == "foobar"603    assert tool.run("baz", run_manager=[handler]) == "foobaz"604605606def test_structured_tool_from_function_with_run_manager() -> None:607    """Test args and schema of structured tool when using callbacks."""608609    def foo(  # noqa: D417610        bar: int, baz: str, callbacks: CallbackManagerForToolRun | None = None611    ) -> str:612        """Docstring.613614        Args:615            bar: int616            baz: str617        """618        assert callbacks is not None619        return str(bar) + baz620621    handler = FakeCallbackHandler()622    structured_tool = StructuredTool.from_function(foo)623624    assert structured_tool.args == {625        "bar": {"title": "Bar", "type": "integer"},626        "baz": {"title": "Baz", "type": "string"},627    }628629    assert _schema(structured_tool.args_schema) == {630        "properties": {631            "bar": {"title": "Bar", "type": "integer"},632            "baz": {"title": "Baz", "type": "string"},633        },634        "description": inspect.getdoc(foo),635        "title": "foo",636        "type": "object",637        "required": ["bar", "baz"],638    }639640    assert (641        structured_tool.run(642            tool_input={"bar": "10", "baz": "baz"}, run_manger=[handler]643        )644        == "10baz"645    )646647648def test_structured_tool_from_parameterless_function() -> None:649    """Test parameterless function of structured tool."""650651    def foo() -> str:652        """Docstring."""653        return "invoke foo"654655    structured_tool = StructuredTool.from_function(foo)656657    assert structured_tool.run({}) == "invoke foo"658    assert structured_tool.run("") == "invoke foo"659660661def test_named_tool_decorator() -> None:662    """Test functionality when arguments are provided as input to decorator."""663664    @tool("search")665    def search_api(query: str) -> str:666        """Search the API for the query."""667        assert isinstance(query, str)668        return f"API result - {query}"669670    assert isinstance(search_api, BaseTool)671    assert search_api.name == "search"672    assert not search_api.return_direct673    assert search_api.run({"query": "foo"}) == "API result - foo"674675676def test_named_tool_decorator_return_direct() -> None:677    """Test functionality when arguments and return direct are provided as input."""678679    @tool("search", return_direct=True)680    def search_api(query: str, *args: Any) -> str:681        """Search the API for the query."""682        return "API result"683684    assert isinstance(search_api, BaseTool)685    assert search_api.name == "search"686    assert search_api.return_direct687    assert search_api.run({"query": "foo"}) == "API result"688689690def test_unnamed_tool_decorator_return_direct() -> None:691    """Test functionality when only return direct is provided."""692693    @tool(return_direct=True)694    def search_api(query: str) -> str:695        """Search the API for the query."""696        assert isinstance(query, str)697        return "API result"698699    assert isinstance(search_api, BaseTool)700    assert search_api.name == "search_api"701    assert search_api.return_direct702    assert search_api.run({"query": "foo"}) == "API result"703704705def test_tool_with_kwargs() -> None:706    """Test functionality when only return direct is provided."""707708    @tool(return_direct=True)709    def search_api(710        arg_0: str,711        arg_1: float = 4.3,712        ping: str = "hi",713    ) -> str:714        """Search the API for the query."""715        return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}"716717    assert isinstance(search_api, BaseTool)718    result = search_api.run(719        tool_input={720            "arg_0": "foo",721            "arg_1": 3.2,722            "ping": "pong",723        }724    )725    assert result == "arg_0=foo, arg_1=3.2, ping=pong"726727    result = search_api.run(728        tool_input={729            "arg_0": "foo",730        }731    )732    assert result == "arg_0=foo, arg_1=4.3, ping=hi"733    # For backwards compatibility, we still accept a single str arg734    result = search_api.run("foobar")735    assert result == "arg_0=foobar, arg_1=4.3, ping=hi"736737738def test_missing_docstring() -> None:739    """Test error is raised when docstring is missing."""740    # expect to throw a value error if there's no docstring741    with pytest.raises(ValueError, match="Function must have a docstring"):742743        @tool744        def search_api(query: str) -> str:745            return "API result"746747    @tool748    class MyTool(BaseModel):749        foo: str750751    assert not MyTool.description  # type: ignore[attr-defined]752753754def test_create_tool_positional_args() -> None:755    """Test that positional arguments are allowed."""756    test_tool = Tool("test_name", lambda x: x, "test_description")757    assert test_tool.invoke("foo") == "foo"758    assert test_tool.name == "test_name"759    assert test_tool.description == "test_description"760    assert test_tool.is_single_input761762763def test_create_tool_keyword_args() -> None:764    """Test that keyword arguments are allowed."""765    test_tool = Tool(name="test_name", func=lambda x: x, description="test_description")766    assert test_tool.is_single_input767    assert test_tool.invoke("foo") == "foo"768    assert test_tool.name == "test_name"769    assert test_tool.description == "test_description"770771772async def test_create_async_tool() -> None:773    """Test that async tools are allowed."""774775    async def _test_func(x: str) -> str:776        return x777778    test_tool = Tool(779        name="test_name",780        func=lambda x: x,781        description="test_description",782        coroutine=_test_func,783    )784    assert test_tool.is_single_input785    assert test_tool.invoke("foo") == "foo"786    assert test_tool.name == "test_name"787    assert test_tool.description == "test_description"788    assert test_tool.coroutine is not None789    assert await test_tool.arun("foo") == "foo"790791792class _FakeExceptionTool(BaseTool):793    name: str = "exception"794    description: str = "an exception-throwing tool"795    exception: Exception = ToolException()796797    def _run(self) -> str:798        raise self.exception799800    async def _arun(self) -> str:801        raise self.exception802803804def test_exception_handling_bool() -> None:805    tool_ = _FakeExceptionTool(handle_tool_error=True)806    expected = "Tool execution error"807    actual = tool_.run({})808    assert expected == actual809810811def test_exception_handling_str() -> None:812    expected = "foo bar"813    tool_ = _FakeExceptionTool(handle_tool_error=expected)814    actual = tool_.run({})815    assert expected == actual816817818def test_exception_handling_callable() -> None:819    expected = "foo bar"820821    def handling(e: ToolException) -> str:822        return expected823824    tool_ = _FakeExceptionTool(handle_tool_error=handling)825    actual = tool_.run({})826    assert expected == actual827828829def test_exception_handling_non_tool_exception() -> None:830    tool_ = _FakeExceptionTool(exception=ValueError("some error"))831    with pytest.raises(ValueError, match="some error"):832        tool_.run({})833834835async def test_async_exception_handling_bool() -> None:836    tool_ = _FakeExceptionTool(handle_tool_error=True)837    expected = "Tool execution error"838    actual = await tool_.arun({})839    assert expected == actual840841842async def test_async_exception_handling_str() -> None:843    expected = "foo bar"844    tool_ = _FakeExceptionTool(handle_tool_error=expected)845    actual = await tool_.arun({})846    assert expected == actual847848849async def test_async_exception_handling_callable() -> None:850    expected = "foo bar"851852    def handling(e: ToolException) -> str:853        return expected854855    tool_ = _FakeExceptionTool(handle_tool_error=handling)856    actual = await tool_.arun({})857    assert expected == actual858859860async def test_async_exception_handling_non_tool_exception() -> None:861    tool_ = _FakeExceptionTool(exception=ValueError("some error"))862    with pytest.raises(ValueError, match="some error"):863        await tool_.arun({})864865866def test_structured_tool_from_function() -> None:867    """Test that structured tools can be created from functions."""868869    def foo(bar: int, baz: str) -> str:870        """Docstring thing.871872        Args:873            bar: the bar value874            baz: the baz value875        """876        raise NotImplementedError877878    structured_tool = StructuredTool.from_function(foo)879    assert structured_tool.name == "foo"880    assert structured_tool.args == {881        "bar": {"title": "Bar", "type": "integer"},882        "baz": {"title": "Baz", "type": "string"},883    }884885    assert _schema(structured_tool.args_schema) == {886        "title": "foo",887        "type": "object",888        "description": inspect.getdoc(foo),889        "properties": {890            "bar": {"title": "Bar", "type": "integer"},891            "baz": {"title": "Baz", "type": "string"},892        },893        "required": ["bar", "baz"],894    }895896    assert foo.__doc__ is not None897    assert structured_tool.description == textwrap.dedent(foo.__doc__.strip())898899900def test_validation_error_handling_bool() -> None:901    """Test that validation errors are handled correctly."""902    expected = "Tool input validation error"903    tool_ = _MockStructuredTool(handle_validation_error=True)904    actual = tool_.run({})905    assert expected == actual906907908def test_validation_error_handling_str() -> None:909    """Test that validation errors are handled correctly."""910    expected = "foo bar"911    tool_ = _MockStructuredTool(handle_validation_error=expected)912    actual = tool_.run({})913    assert expected == actual914915916def test_validation_error_handling_callable() -> None:917    """Test that validation errors are handled correctly."""918    expected = "foo bar"919920    def handling(e: ValidationError | ValidationErrorV1) -> str:921        return expected922923    tool_ = _MockStructuredTool(handle_validation_error=handling)924    actual = tool_.run({})925    assert expected == actual926927928@pytest.mark.parametrize(929    "handler",930    [931        True,932        "foo bar",933        lambda _: "foo bar",934    ],935)936def test_validation_error_handling_non_validation_error(937    *,938    handler: bool | str | Callable[[ValidationError | ValidationErrorV1], str],939) -> None:940    """Test that validation errors are handled correctly."""941942    class _RaiseNonValidationErrorTool(BaseTool):943        name: str = "raise_non_validation_error_tool"944        description: str = "A tool that raises a non-validation error"945946        def _parse_input(947            self,948            tool_input: str | dict,949            tool_call_id: str | None,950        ) -> str | dict[str, Any]:951            raise NotImplementedError952953        @override954        def _run(self) -> str:955            return "dummy"956957        @override958        async def _arun(self) -> str:959            return "dummy"960961    tool_ = _RaiseNonValidationErrorTool(handle_validation_error=handler)962    with pytest.raises(NotImplementedError):963        tool_.run({})964965966async def test_async_validation_error_handling_bool() -> None:967    """Test that validation errors are handled correctly."""968    expected = "Tool input validation error"969    tool_ = _MockStructuredTool(handle_validation_error=True)970    actual = await tool_.arun({})971    assert expected == actual972973974async def test_async_validation_error_handling_str() -> None:975    """Test that validation errors are handled correctly."""976    expected = "foo bar"977    tool_ = _MockStructuredTool(handle_validation_error=expected)978    actual = await tool_.arun({})979    assert expected == actual980981982async def test_async_validation_error_handling_callable() -> None:983    """Test that validation errors are handled correctly."""984    expected = "foo bar"985986    def handling(e: ValidationError | ValidationErrorV1) -> str:987        return expected988989    tool_ = _MockStructuredTool(handle_validation_error=handling)990    actual = await tool_.arun({})991    assert expected == actual992993994@pytest.mark.parametrize(995    "handler",996    [997        True,998        "foo bar",999        lambda _: "foo bar",1000    ],1001)1002async def test_async_validation_error_handling_non_validation_error(1003    *,1004    handler: bool | str | Callable[[ValidationError | ValidationErrorV1], str],1005) -> None:1006    """Test that validation errors are handled correctly."""10071008    class _RaiseNonValidationErrorTool(BaseTool):1009        name: str = "raise_non_validation_error_tool"1010        description: str = "A tool that raises a non-validation error"10111012        def _parse_input(1013            self,1014            tool_input: str | dict,1015            tool_call_id: str | None,1016        ) -> str | dict[str, Any]:1017            raise NotImplementedError10181019        @override1020        def _run(self) -> str:1021            return "dummy"10221023        @override1024        async def _arun(self) -> str:1025            return "dummy"10261027    tool_ = _RaiseNonValidationErrorTool(handle_validation_error=handler)1028    with pytest.raises(NotImplementedError):1029        await tool_.arun({})103010311032def test_optional_subset_model_rewrite() -> None:1033    class MyModel(BaseModel):1034        a: str | None = None1035        b: str1036        c: list[str | None] | None = None10371038    model2 = _create_subset_model("model2", MyModel, ["a", "b", "c"])10391040    assert set(_schema(model2)["required"]) == {"b"}104110421043@pytest.mark.parametrize(1044    ("inputs", "expected"),1045    [1046        # Check not required1047        ({"bar": "bar"}, {"bar": "bar", "baz": 3, "buzz": "buzz"}),1048        # Check overwritten1049        (1050            {"bar": "bar", "baz": 4, "buzz": "not-buzz"},1051            {"bar": "bar", "baz": 4, "buzz": "not-buzz"},1052        ),1053        # Check validation error when missing1054        ({}, None),1055        # Check validation error when wrong type1056        ({"bar": "bar", "baz": "not-an-int"}, None),1057        # Check OK when None explicitly passed1058        ({"bar": "bar", "baz": None}, {"bar": "bar", "baz": None, "buzz": "buzz"}),1059    ],1060)1061def test_tool_invoke_optional_args(inputs: dict, expected: dict | None) -> None:1062    @tool1063    def foo(bar: str, baz: int | None = 3, buzz: str | None = "buzz") -> dict:1064        """The foo."""1065        return {1066            "bar": bar,1067            "baz": baz,1068            "buzz": buzz,1069        }10701071    if expected is not None:1072        assert foo.invoke(inputs) == expected1073    else:1074        with pytest.raises(ValidationError):1075            foo.invoke(inputs)107610771078def test_tool_pass_context() -> None:1079    @tool1080    def foo(bar: str) -> str:1081        """The foo."""1082        config = ensure_config()1083        assert config["configurable"]["foo"] == "not-bar"1084        assert bar == "baz"1085        return bar10861087    assert foo.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"108810891090@pytest.mark.skipif(1091    sys.version_info < (3, 11),1092    reason="requires python3.11 or higher",1093)1094async def test_async_tool_pass_context() -> None:1095    @tool1096    async def foo(bar: str) -> str:1097        """The foo."""1098        config = ensure_config()1099        assert config["configurable"]["foo"] == "not-bar"1100        assert bar == "baz"1101        return bar11021103    assert (1104        await foo.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"1105    )110611071108def assert_bar(bar: Any, bar_config: RunnableConfig) -> Any:1109    assert bar_config["configurable"]["foo"] == "not-bar"1110    assert bar == "baz"1111    return bar111211131114@tool1115def foo(bar: Any, bar_config: RunnableConfig) -> Any:1116    """The foo."""1117    return assert_bar(bar, bar_config)111811191120@tool1121async def afoo(bar: Any, bar_config: RunnableConfig) -> Any:1122    """The foo."""1123    return assert_bar(bar, bar_config)112411251126@tool(infer_schema=False)1127def simple_foo(bar: Any, bar_config: RunnableConfig) -> Any:1128    """The foo."""1129    return assert_bar(bar, bar_config)113011311132@tool(infer_schema=False)1133async def asimple_foo(bar: Any, bar_config: RunnableConfig) -> Any:1134    """The foo."""1135    return assert_bar(bar, bar_config)113611371138class FooBase(BaseTool):1139    name: str = "Foo"1140    description: str = "Foo"11411142    @override1143    def _run(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:1144        return assert_bar(bar, bar_config)114511461147class AFooBase(FooBase):1148    @override1149    async def _arun(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:1150        return assert_bar(bar, bar_config)115111521153@pytest.mark.parametrize("tool", [foo, simple_foo, FooBase(), AFooBase()])1154def test_tool_pass_config(tool: BaseTool) -> None:1155    assert tool.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"11561157    # Test we don't mutate tool calls1158    tool_call = {1159        "name": tool.name,1160        "args": {"bar": "baz"},1161        "id": "abc123",1162        "type": "tool_call",1163    }1164    _ = tool.invoke(tool_call, {"configurable": {"foo": "not-bar"}})1165    assert tool_call["args"] == {"bar": "baz"}116611671168class FooBaseNonPickleable(FooBase):1169    @override1170    def _run(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:1171        return True117211731174def test_tool_pass_config_non_pickleable() -> None:1175    tool = FooBaseNonPickleable()11761177    args = {"bar": threading.Lock()}1178    tool_call = {1179        "name": tool.name,1180        "args": args,1181        "id": "abc123",1182        "type": "tool_call",1183    }1184    _ = tool.invoke(tool_call, {"configurable": {"foo": "not-bar"}})1185    assert tool_call["args"] == args118611871188@pytest.mark.parametrize(1189    "tool", [foo, afoo, simple_foo, asimple_foo, FooBase(), AFooBase()]1190)1191async def test_async_tool_pass_config(tool: BaseTool) -> None:1192    assert (1193        await tool.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}})1194        == "baz"1195    )119611971198def test_tool_description() -> None:1199    def foo(bar: str) -> str:1200        """The foo."""1201        return bar12021203    foo1 = tool(foo)1204    assert foo1.description == "The foo."12051206    foo2 = StructuredTool.from_function(foo)1207    assert foo2.description == "The foo."120812091210def test_tool_arg_descriptions() -> None:1211    def foo(bar: str, baz: int) -> str:1212        """The foo.12131214        Args:1215            bar: The bar.1216            baz: The baz.1217        """1218        return bar12191220    foo1 = tool(foo)1221    args_schema = _schema(foo1.args_schema)1222    assert args_schema == {1223        "title": "foo",1224        "type": "object",1225        "description": inspect.getdoc(foo),1226        "properties": {1227            "bar": {"title": "Bar", "type": "string"},1228            "baz": {"title": "Baz", "type": "integer"},1229        },1230        "required": ["bar", "baz"],1231    }12321233    # Test parses docstring1234    foo2 = tool(foo, parse_docstring=True)1235    args_schema = _schema(foo2.args_schema)1236    expected = {1237        "title": "foo",1238        "description": "The foo.",1239        "type": "object",1240        "properties": {1241            "bar": {"title": "Bar", "description": "The bar.", "type": "string"},1242            "baz": {"title": "Baz", "description": "The baz.", "type": "integer"},1243        },1244        "required": ["bar", "baz"],1245    }1246    assert args_schema == expected12471248    # Test parsing with run_manager does not raise error1249    def foo3(  # noqa: D4171250        bar: str, baz: int, run_manager: CallbackManagerForToolRun | None = None1251    ) -> str:1252        """The foo.12531254        Args:1255            bar: The bar.1256            baz: The baz.1257        """1258        return bar12591260    as_tool = tool(foo3, parse_docstring=True)1261    args_schema = _schema(as_tool.args_schema)1262    assert args_schema["description"] == expected["description"]1263    assert args_schema["properties"] == expected["properties"]12641265    # Test parsing with runtime does not raise error1266    def foo3_runtime(bar: str, baz: int, runtime: Any) -> str:  # noqa: D4171267        """The foo.12681269        Args:1270            bar: The bar.1271            baz: The baz.1272        """1273        return bar12741275    _ = tool(foo3_runtime, parse_docstring=True)12761277    # Test parameterless tool does not raise error for missing Args section1278    # in docstring.1279    def foo4() -> str:1280        """The foo."""1281        return "bar"12821283    as_tool = tool(foo4, parse_docstring=True)1284    args_schema = _schema(as_tool.args_schema)1285    assert args_schema["description"] == expected["description"]12861287    def foo5(run_manager: CallbackManagerForToolRun | None = None) -> str:1288        """The foo."""1289        return "bar"12901291    as_tool = tool(foo5, parse_docstring=True)1292    args_schema = _schema(as_tool.args_schema)1293    assert args_schema["description"] == expected["description"]129412951296def test_docstring_parsing() -> None:1297    expected = {1298        "title": "foo",1299        "description": "The foo.",1300        "type": "object",1301        "properties": {1302            "bar": {"title": "Bar", "description": "The bar.", "type": "string"},1303            "baz": {"title": "Baz", "description": "The baz.", "type": "integer"},1304        },1305        "required": ["bar", "baz"],1306    }13071308    # Simple case1309    def foo(bar: str, baz: int) -> str:1310        """The foo.13111312        Args:1313            bar: The bar.1314            baz: The baz.1315        """1316        return bar13171318    as_tool = tool(foo, parse_docstring=True)1319    args_schema = _schema(as_tool.args_schema)1320    assert args_schema["description"] == "The foo."1321    assert args_schema["properties"] == expected["properties"]13221323    # Multi-line description1324    def foo2(bar: str, baz: int) -> str:1325        """The foo.13261327        Additional description here.13281329        Args:1330            bar: The bar.1331            baz: The baz.1332        """1333        return bar13341335    as_tool = tool(foo2, parse_docstring=True)1336    args_schema2 = _schema(as_tool.args_schema)1337    assert args_schema2["description"] == "The foo. Additional description here."1338    assert args_schema2["properties"] == expected["properties"]13391340    # Multi-line with Returns block1341    def foo3(bar: str, baz: int) -> str:1342        """The foo.13431344        Additional description here.13451346        Args:1347            bar: The bar.1348            baz: The baz.13491350        Returns:1351            description of returned value.1352        """1353        return bar13541355    as_tool = tool(foo3, parse_docstring=True)1356    args_schema3 = _schema(as_tool.args_schema)1357    args_schema3["title"] = "foo2"1358    assert args_schema2 == args_schema313591360    # Single argument1361    def foo4(bar: str) -> str:1362        """The foo.13631364        Args:1365            bar: The bar.1366        """1367        return bar13681369    as_tool = tool(foo4, parse_docstring=True)1370    args_schema4 = _schema(as_tool.args_schema)1371    assert args_schema4["description"] == "The foo."1372    assert args_schema4["properties"] == {1373        "bar": {"description": "The bar.", "title": "Bar", "type": "string"}1374    }137513761377def test_tool_invalid_docstrings() -> None:1378    """Test invalid docstrings."""13791380    def foo3(bar: str, baz: int) -> str:1381        """The foo."""1382        return bar13831384    def foo4(bar: str, baz: int) -> str:1385        """The foo.1386        Args:1387            bar: The bar.1388            baz: The baz.1389        """  # noqa: D205,D411  # We're intentionally testing bad formatting.1390        return bar13911392    for func in {foo3, foo4}:1393        with pytest.raises(ValueError, match="Found invalid Google-Style docstring"):1394            _ = tool(func, parse_docstring=True)13951396    def foo5(bar: str, baz: int) -> str:  # noqa: D4171397        """The foo.13981399        Args:1400            banana: The bar.1401            monkey: The baz.1402        """1403        return bar14041405    with pytest.raises(1406        ValueError, match="Arg banana in docstring not found in function signature"1407    ):1408        _ = tool(foo5, parse_docstring=True)140914101411def test_tool_annotated_descriptions() -> None:1412    def foo(1413        bar: Annotated[str, "this is the bar"], baz: Annotated[int, "this is the baz"]1414    ) -> str:1415        """The foo.14161417        Returns:1418            The bar only.1419        """1420        return bar14211422    foo1 = tool(foo)1423    args_schema = _schema(foo1.args_schema)1424    assert args_schema == {1425        "title": "foo",1426        "type": "object",1427        "description": inspect.getdoc(foo),1428        "properties": {1429            "bar": {"title": "Bar", "type": "string", "description": "this is the bar"},1430            "baz": {1431                "title": "Baz",1432                "type": "integer",1433                "description": "this is the baz",1434            },1435        },1436        "required": ["bar", "baz"],1437    }143814391440def test_tool_field_description_preserved() -> None:1441    """Test that `Field(description=...)` is preserved in `@tool` decorator."""14421443    @tool1444    def my_tool(1445        topic: Annotated[str, Field(description="The research topic")],1446        depth: Annotated[int, Field(description="Search depth level")] = 3,1447    ) -> str:1448        """A tool for research."""1449        return f"{topic} at depth {depth}"14501451    args_schema = _schema(my_tool.args_schema)1452    assert args_schema == {1453        "title": "my_tool",1454        "type": "object",1455        "description": "A tool for research.",1456        "properties": {1457            "topic": {1458                "title": "Topic",1459                "type": "string",1460                "description": "The research topic",1461            },1462            "depth": {1463                "title": "Depth",1464                "type": "integer",1465                "description": "Search depth level",1466                "default": 3,1467            },1468        },1469        "required": ["topic"],1470    }147114721473def test_tool_call_input_tool_message_output() -> None:1474    tool_call = {1475        "name": "structured_api",1476        "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},1477        "id": "123",1478        "type": "tool_call",1479    }1480    tool = _MockStructuredTool()1481    expected = ToolMessage(1482        "1 True {'img': 'base64string...'}", tool_call_id="123", name="structured_api"1483    )1484    actual = tool.invoke(tool_call)1485    assert actual == expected14861487    tool_call.pop("type")1488    with pytest.raises(ValidationError):1489        tool.invoke(tool_call)149014911492@pytest.mark.parametrize("block_type", [*TOOL_MESSAGE_BLOCK_TYPES, "bad"])1493def test_tool_content_block_output(block_type: str) -> None:1494    @tool1495    def my_tool(query: str) -> list[dict[str, Any]]:1496        """Test tool."""1497        return [{"type": block_type, "foo": "bar"}]14981499    tool_call = {1500        "type": "tool_call",1501        "name": "my_tool",1502        "args": {"query": "baz"},1503        "id": "call_abc123",1504    }15051506    result = my_tool.invoke(tool_call)1507    assert isinstance(result, ToolMessage)15081509    if block_type in TOOL_MESSAGE_BLOCK_TYPES:1510        assert result.content == [{"type": block_type, "foo": "bar"}]1511    else:1512        assert result.content == '[{"type": "bad", "foo": "bar"}]'151315141515class _MockStructuredToolWithRawOutput(BaseTool):1516    name: str = "structured_api"1517    args_schema: type[BaseModel] = _MockSchema1518    description: str = "A Structured Tool"1519    response_format: Literal["content_and_artifact"] = "content_and_artifact"15201521    @override1522    def _run(1523        self,1524        arg1: int,1525        arg2: bool,1526        arg3: dict[str, Any] | None = None,1527    ) -> tuple[str, dict[str, Any]]:1528        return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}152915301531@tool("structured_api", response_format="content_and_artifact")1532def _mock_structured_tool_with_artifact(1533    *, arg1: int, arg2: bool, arg3: dict[str, str] | None = None1534) -> tuple[str, dict[str, Any]]:1535    """A Structured Tool."""1536    return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}153715381539@pytest.mark.parametrize(1540    "tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_artifact]1541)1542def test_tool_call_input_tool_message_with_artifact(tool: BaseTool) -> None:1543    tool_call: dict[str, Any] = {1544        "name": "structured_api",1545        "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},1546        "id": "123",1547        "type": "tool_call",1548    }1549    expected = ToolMessage(1550        "1 True", artifact=tool_call["args"], tool_call_id="123", name="structured_api"1551    )1552    actual = tool.invoke(tool_call)1553    assert actual == expected15541555    tool_call.pop("type")1556    with pytest.raises(ValidationError):1557        tool.invoke(tool_call)15581559    actual_content = tool.invoke(tool_call["args"])1560    assert actual_content == expected.content156115621563def test_convert_from_runnable_dict() -> None:1564    # Test with typed dict input1565    class Args(TypedDict):1566        a: int1567        b: list[int]15681569    def f(x: Args) -> str:1570        return str(x["a"] * max(x["b"]))15711572    runnable = RunnableLambda(f)1573    as_tool = runnable.as_tool()1574    args_schema = as_tool.args_schema1575    assert args_schema is not None1576    assert _schema(args_schema) == {1577        "title": "f",1578        "type": "object",1579        "properties": {1580            "a": {"title": "A", "type": "integer"},1581            "b": {"title": "B", "type": "array", "items": {"type": "integer"}},1582        },1583        "required": ["a", "b"],1584    }1585    assert as_tool.description1586    result = as_tool.invoke({"a": 3, "b": [1, 2]})1587    assert result == "6"15881589    as_tool = runnable.as_tool(name="my tool", description="test description")1590    assert as_tool.name == "my tool"1591    assert as_tool.description == "test description"15921593    # Dict without typed input-- must supply schema1594    def g(x: dict[str, Any]) -> str:1595        return str(x["a"] * max(x["b"]))15961597    # Specify via args_schema:1598    class GSchema(BaseModel):1599        """Apply a function to an integer and list of integers."""16001601        a: int = Field(..., description="Integer")1602        b: list[int] = Field(..., description="List of ints")16031604    runnable2 = RunnableLambda(g)1605    as_tool2 = runnable2.as_tool(GSchema)1606    as_tool2.invoke({"a": 3, "b": [1, 2]})16071608    # Specify via arg_types:1609    runnable3 = RunnableLambda(g)1610    as_tool3 = runnable3.as_tool(arg_types={"a": int, "b": list[int]})1611    result = as_tool3.invoke({"a": 3, "b": [1, 2]})1612    assert result == "6"16131614    # Test with config1615    def h(x: dict[str, Any]) -> str:1616        config = ensure_config()1617        assert config["configurable"]["foo"] == "not-bar"1618        return str(x["a"] * max(x["b"]))16191620    runnable4 = RunnableLambda(h)1621    as_tool4 = runnable4.as_tool(arg_types={"a": int, "b": list[int]})1622    result = as_tool4.invoke(1623        {"a": 3, "b": [1, 2]}, config={"configurable": {"foo": "not-bar"}}1624    )1625    assert result == "6"162616271628def test_convert_from_runnable_other() -> None:1629    # String input1630    def f(x: str) -> str:1631        return x + "a"16321633    def g(x: str) -> str:1634        return x + "z"16351636    runnable = RunnableLambda(f) | g1637    as_tool = runnable.as_tool()1638    args_schema = as_tool.args_schema1639    assert args_schema is None1640    assert as_tool.description16411642    result = as_tool.invoke("b")1643    assert result == "baz"16441645    # Test with config1646    def h(x: str) -> str:1647        config = ensure_config()1648        assert config["configurable"]["foo"] == "not-bar"1649        return x + "a"16501651    runnable2 = RunnableLambda(h)1652    as_tool2 = runnable2.as_tool()1653    result2 = as_tool2.invoke("b", config={"configurable": {"foo": "not-bar"}})1654    assert result2 == "ba"165516561657@tool("foo", parse_docstring=True)1658def injected_tool(x: int, y: Annotated[str, InjectedToolArg]) -> str:1659    """Foo.16601661    Args:1662        x: abc1663        y: 1231664    """1665    return y166616671668class InjectedTool(BaseTool):1669    name: str = "foo"1670    description: str = "foo."16711672    @override1673    def _run(self, x: int, y: Annotated[str, InjectedToolArg]) -> Any:1674        """Foo.16751676        Args:1677            x: abc1678            y: 1231679        """1680        return y168116821683class fooSchema(BaseModel):  # noqa: N8011684    """foo."""16851686    x: int = Field(..., description="abc")1687    y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(1688        ..., description="123"1689    )169016911692class InjectedToolWithSchema(BaseTool):1693    name: str = "foo"1694    description: str = "foo."1695    args_schema: type[BaseModel] = fooSchema16961697    @override1698    def _run(self, x: int, y: str) -> Any:1699        return y170017011702@tool("foo", args_schema=fooSchema)1703def injected_tool_with_schema(x: int, y: str) -> str:1704    return y170517061707@pytest.mark.parametrize("tool_", [InjectedTool()])1708def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:1709    assert _schema(tool_.get_input_schema()) == {1710        "title": "foo",1711        "description": "Foo.\n\nArgs:\n    x: abc\n    y: 123",1712        "type": "object",1713        "properties": {1714            "x": {"title": "X", "type": "integer"},1715            "y": {"title": "Y", "type": "string"},1716        },1717        "required": ["x", "y"],1718    }1719    assert _schema(tool_.tool_call_schema) == {1720        "title": "foo",1721        "description": "foo.",1722        "type": "object",1723        "properties": {"x": {"title": "X", "type": "integer"}},1724        "required": ["x"],1725    }1726    assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"1727    assert tool_.invoke(1728        {1729            "name": "foo",1730            "args": {"x": 5, "y": "bar"},1731            "id": "123",1732            "type": "tool_call",1733        }1734    ) == ToolMessage("bar", tool_call_id="123", name="foo")1735    expected_error = (1736        ValidationError if not isinstance(tool_, InjectedTool) else TypeError1737    )1738    with pytest.raises(expected_error):1739        tool_.invoke({"x": 5})17401741    assert convert_to_openai_function(tool_) == {1742        "name": "foo",1743        "description": "foo.",1744        "parameters": {1745            "type": "object",1746            "properties": {"x": {"type": "integer"}},1747            "required": ["x"],1748        },1749    }175017511752@pytest.mark.parametrize(1753    "tool_",1754    [injected_tool_with_schema, InjectedToolWithSchema()],1755)1756def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:1757    assert _schema(tool_.get_input_schema()) == {1758        "title": "fooSchema",1759        "description": "foo.",1760        "type": "object",1761        "properties": {1762            "x": {"description": "abc", "title": "X", "type": "integer"},1763            "y": {"description": "123", "title": "Y", "type": "string"},1764        },1765        "required": ["x", "y"],1766    }1767    assert _schema(tool_.tool_call_schema) == {1768        "title": "foo",1769        "description": "foo.",1770        "type": "object",1771        "properties": {"x": {"description": "abc", "title": "X", "type": "integer"}},1772        "required": ["x"],1773    }1774    assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"1775    assert tool_.invoke(1776        {1777            "name": "foo",1778            "args": {"x": 5, "y": "bar"},1779            "id": "123",1780            "type": "tool_call",1781        }1782    ) == ToolMessage("bar", tool_call_id="123", name="foo")1783    expected_error = (1784        ValidationError if not isinstance(tool_, InjectedTool) else TypeError1785    )1786    with pytest.raises(expected_error):1787        tool_.invoke({"x": 5})17881789    assert convert_to_openai_function(tool_) == {1790        "name": "foo",1791        "description": "foo.",1792        "parameters": {1793            "type": "object",1794            "properties": {"x": {"type": "integer", "description": "abc"}},1795            "required": ["x"],1796        },1797    }179817991800def test_tool_injected_arg() -> None:1801    tool_ = injected_tool1802    assert _schema(tool_.get_input_schema()) == {1803        "title": "foo",1804        "description": "Foo.",1805        "type": "object",1806        "properties": {1807            "x": {"description": "abc", "title": "X", "type": "integer"},1808            "y": {"description": "123", "title": "Y", "type": "string"},1809        },1810        "required": ["x", "y"],1811    }1812    assert _schema(tool_.tool_call_schema) == {1813        "title": "foo",1814        "description": "Foo.",1815        "type": "object",1816        "properties": {"x": {"description": "abc", "title": "X", "type": "integer"}},1817        "required": ["x"],1818    }1819    assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"1820    assert tool_.invoke(1821        {1822            "name": "foo",1823            "args": {"x": 5, "y": "bar"},1824            "id": "123",1825            "type": "tool_call",1826        }1827    ) == ToolMessage("bar", tool_call_id="123", name="foo")1828    expected_error = (1829        ValidationError if not isinstance(tool_, InjectedTool) else TypeError1830    )1831    with pytest.raises(expected_error):1832        tool_.invoke({"x": 5})18331834    assert convert_to_openai_function(tool_) == {1835        "name": "foo",1836        "description": "Foo.",1837        "parameters": {1838            "type": "object",1839            "properties": {"x": {"type": "integer", "description": "abc"}},1840            "required": ["x"],1841        },1842    }184318441845def test_tool_inherited_injected_arg() -> None:1846    class BarSchema(BaseModel):1847        """bar."""18481849        y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(1850            ..., description="123"1851        )18521853    class FooSchema(BarSchema):1854        """foo."""18551856        x: int = Field(..., description="abc")18571858    class InheritedInjectedArgTool(BaseTool):1859        name: str = "foo"1860        description: str = "foo."1861        args_schema: type[BaseModel] = FooSchema18621863        @override1864        def _run(self, x: int, y: str) -> Any:1865            return y18661867    tool_ = InheritedInjectedArgTool()1868    assert tool_.get_input_schema().model_json_schema() == {1869        "title": "FooSchema",  # Matches the title from the provided schema1870        "description": "foo.",1871        "type": "object",1872        "properties": {1873            "x": {"description": "abc", "title": "X", "type": "integer"},1874            "y": {"description": "123", "title": "Y", "type": "string"},1875        },1876        "required": ["y", "x"],1877    }1878    # Should not include `y` since it's annotated as an injected tool arg1879    assert _get_tool_call_json_schema(tool_) == {1880        "title": "foo",1881        "description": "foo.",1882        "type": "object",1883        "properties": {"x": {"description": "abc", "title": "X", "type": "integer"}},1884        "required": ["x"],1885    }1886    assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"1887    assert tool_.invoke(1888        {1889            "name": "foo",1890            "args": {"x": 5, "y": "bar"},1891            "id": "123",1892            "type": "tool_call",1893        }1894    ) == ToolMessage("bar", tool_call_id="123", name="foo")1895    with pytest.raises(ValidationError):1896        tool_.invoke({"x": 5})18971898    assert convert_to_openai_function(tool_) == {1899        "name": "foo",1900        "description": "foo.",1901        "parameters": {1902            "type": "object",1903            "properties": {"x": {"type": "integer", "description": "abc"}},1904            "required": ["x"],1905        },1906    }190719081909def _get_parametrized_tools() -> list[Callable[..., Any]]:1910    def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str:1911        """my_tool."""1912        return "my_tool"19131914    async def my_async_tool(1915        x: int, y: str, *, some_tool: Annotated[Any, InjectedToolArg]1916    ) -> str:1917        """my_tool."""1918        return "my_tool"19191920    return [my_tool, my_async_tool]192119221923@pytest.mark.parametrize("tool_", _get_parametrized_tools())1924def test_fn_injected_arg_with_schema(tool_: Callable[..., Any]) -> None:1925    assert convert_to_openai_function(tool_) == {1926        "name": tool_.__name__,1927        "description": "my_tool.",1928        "parameters": {1929            "type": "object",1930            "properties": {1931                "x": {"type": "integer"},1932                "y": {"type": "string"},1933            },1934            "required": ["x", "y"],1935        },1936    }193719381939def generate_models() -> list[Any]:1940    """Generate a list of base models depending on the pydantic version."""19411942    class FooProper(BaseModel):1943        a: int1944        b: str19451946    return [FooProper]194719481949def generate_backwards_compatible_v1() -> list[Any]:1950    """Generate a model with pydantic 2 from the v1 namespace."""19511952    class FooV1Namespace(BaseModelV1):1953        a: int1954        b: str19551956    return [FooV1Namespace]195719581959# This generates a list of models that can be used for testing that our APIs1960# behave well with either pydantic 1 proper,1961# pydantic v1 from pydantic 2,1962# or pydantic 2 proper.1963TEST_MODELS = generate_models()19641965if sys.version_info < (3, 14):1966    TEST_MODELS += generate_backwards_compatible_v1()196719681969@pytest.mark.parametrize("pydantic_model", TEST_MODELS)1970def test_args_schema_as_pydantic(pydantic_model: Any) -> None:1971    class SomeTool(BaseTool):1972        args_schema: type[pydantic_model] = pydantic_model19731974        @override1975        def _run(self, *args: Any, **kwargs: Any) -> str:1976            return "foo"19771978    tool = SomeTool(1979        name="some_tool", description="some description", args_schema=pydantic_model1980    )19811982    assert tool.args == {1983        "a": {"title": "A", "type": "integer"},1984        "b": {"title": "B", "type": "string"},1985    }19861987    input_schema = tool.get_input_schema()1988    if issubclass(input_schema, BaseModel):1989        input_json_schema = input_schema.model_json_schema()1990    elif issubclass(input_schema, BaseModelV1):1991        input_json_schema = input_schema.schema()1992    else:1993        msg = "Unknown input schema type"1994        raise TypeError(msg)19951996    assert input_json_schema == {1997        "properties": {1998            "a": {"title": "A", "type": "integer"},1999            "b": {"title": "B", "type": "string"},2000        },

Findings

✓ No findings reported for this file.

Get this view in your editor

Same data, no extra tab — call code_get_file + code_get_findings over MCP from Claude/Cursor/Copilot.