libs/core/tests/unit_tests/runnables/test_runnable.py PYTHON 5,770 lines View on github.com → Search inside
File is large — showing lines 1–2,000 of 5,770.
1import asyncio2import re3import sys4import time5import uuid6import warnings7from collections.abc import (8    AsyncIterator,9    Awaitable,10    Callable,11    Iterator,12    Sequence,13)14from functools import partial15from operator import itemgetter16from typing import Any, cast17from uuid import UUID1819import pytest20from freezegun import freeze_time21from packaging import version22from pydantic import BaseModel, Field23from pytest_mock import MockerFixture24from syrupy.assertion import SnapshotAssertion25from typing_extensions import TypedDict, override2627from langchain_core.callbacks import BaseCallbackHandler28from langchain_core.callbacks.manager import (29    AsyncCallbackManagerForRetrieverRun,30    CallbackManagerForRetrieverRun,31    atrace_as_chain_group,32    trace_as_chain_group,33)34from langchain_core.documents import Document35from langchain_core.language_models import (36    FakeListChatModel,37    FakeListLLM,38    FakeStreamingListLLM,39)40from langchain_core.language_models.fake_chat_models import GenericFakeChatModel41from langchain_core.load import dumpd, dumps42from langchain_core.load.load import loads43from langchain_core.messages import AIMessageChunk, HumanMessage, SystemMessage44from langchain_core.messages.base import BaseMessage45from langchain_core.output_parsers import (46    BaseOutputParser,47    CommaSeparatedListOutputParser,48    StrOutputParser,49)50from langchain_core.outputs.chat_generation import ChatGeneration51from langchain_core.outputs.llm_result import LLMResult52from langchain_core.prompt_values import ChatPromptValue, StringPromptValue53from langchain_core.prompts import (54    ChatPromptTemplate,55    HumanMessagePromptTemplate,56    MessagesPlaceholder,57    PromptTemplate,58    SystemMessagePromptTemplate,59)60from langchain_core.retrievers import BaseRetriever61from langchain_core.runnables import (62    AddableDict,63    ConfigurableField,64    ConfigurableFieldMultiOption,65    ConfigurableFieldSingleOption,66    RouterRunnable,67    Runnable,68    RunnableAssign,69    RunnableBinding,70    RunnableBranch,71    RunnableConfig,72    RunnableGenerator,73    RunnableLambda,74    RunnableParallel,75    RunnablePassthrough,76    RunnablePick,77    RunnableSequence,78    add,79    chain,80)81from langchain_core.runnables.base import RunnableMap, RunnableSerializable82from langchain_core.runnables.utils import Input, Output83from langchain_core.tools import BaseTool, tool84from langchain_core.tracers import (85    BaseTracer,86    ConsoleCallbackHandler,87    Run,88    RunLog,89    RunLogPatch,90)91from langchain_core.tracers._compat import pydantic_copy92from langchain_core.tracers.context import collect_runs93from langchain_core.utils.pydantic import PYDANTIC_VERSION94from tests.unit_tests.pydantic_utils import _normalize_schema, _schema95from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk9697PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION98PYDANTIC_VERSION_AT_LEAST_210 = version.parse("2.10") <= PYDANTIC_VERSION99100101class FakeTracer(BaseTracer):102    """Fake tracer that records LangChain execution.103104    It replaces run IDs with deterministic UUIDs for snapshotting.105    """106107    def __init__(self) -> None:108        """Initialize the tracer."""109        super().__init__()110        self.runs: list[Run] = []111        self.uuids_map: dict[UUID, UUID] = {}112        self.uuids_generator = (113            UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)114        )115116    def _replace_uuid(self, uuid: UUID) -> UUID:117        if uuid not in self.uuids_map:118            self.uuids_map[uuid] = next(self.uuids_generator)119        return self.uuids_map[uuid]120121    def _replace_message_id(self, maybe_message: Any) -> Any:122        if isinstance(maybe_message, BaseMessage):123            maybe_message.id = str(next(self.uuids_generator))124        if isinstance(maybe_message, ChatGeneration):125            maybe_message.message.id = str(next(self.uuids_generator))126        if isinstance(maybe_message, LLMResult):127            for i, gen_list in enumerate(maybe_message.generations):128                for j, gen in enumerate(gen_list):129                    maybe_message.generations[i][j] = self._replace_message_id(gen)130        if isinstance(maybe_message, dict):131            for k, v in maybe_message.items():132                maybe_message[k] = self._replace_message_id(v)133        if isinstance(maybe_message, list):134            for i, v in enumerate(maybe_message):135                maybe_message[i] = self._replace_message_id(v)136137        return maybe_message138139    def _copy_run(self, run: Run) -> Run:140        if run.dotted_order:141            levels = run.dotted_order.split(".")142            processed_levels = []143            for level in levels:144                timestamp, run_id = level.split("Z")145                new_run_id = self._replace_uuid(UUID(run_id))146                processed_level = f"{timestamp}Z{new_run_id}"147                processed_levels.append(processed_level)148            new_dotted_order = ".".join(processed_levels)149        else:150            new_dotted_order = None151        update_dict = {152            "id": self._replace_uuid(run.id),153            "parent_run_id": (154                self.uuids_map[run.parent_run_id] if run.parent_run_id else None155            ),156            "child_runs": [self._copy_run(child) for child in run.child_runs],157            "trace_id": self._replace_uuid(run.trace_id) if run.trace_id else None,158            "dotted_order": new_dotted_order,159            "inputs": self._replace_message_id(run.inputs),160            "outputs": self._replace_message_id(run.outputs),161        }162        return pydantic_copy(run, update=update_dict)163164    def _persist_run(self, run: Run) -> None:165        """Persist a run."""166        self.runs.append(self._copy_run(run))167168    def flattened_runs(self) -> list[Run]:169        q = [*self.runs]170        result = []171        while q:172            parent = q.pop()173            result.append(parent)174            if parent.child_runs:175                q.extend(parent.child_runs)176        return result177178    @property179    def run_ids(self) -> list[uuid.UUID | None]:180        runs = self.flattened_runs()181        uuids_map = {v: k for k, v in self.uuids_map.items()}182        return [uuids_map.get(r.id) for r in runs]183184185class FakeRunnable(Runnable[str, int]):186    @override187    def invoke(188        self,189        input: str,190        config: RunnableConfig | None = None,191        **kwargs: Any,192    ) -> int:193        return len(input)194195196class FakeRunnableSerializable(RunnableSerializable[str, int]):197    hello: str = ""198199    @override200    def invoke(201        self,202        input: str,203        config: RunnableConfig | None = None,204        **kwargs: Any,205    ) -> int:206        return len(input)207208209class FakeRetriever(BaseRetriever):210    @override211    def _get_relevant_documents(212        self, query: str, *, run_manager: CallbackManagerForRetrieverRun213    ) -> list[Document]:214        return [Document(page_content="foo"), Document(page_content="bar")]215216    @override217    async def _aget_relevant_documents(218        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun219    ) -> list[Document]:220        return [Document(page_content="foo"), Document(page_content="bar")]221222223@pytest.mark.skipif(224    PYDANTIC_VERSION_AT_LEAST_210,225    reason=(226        "Only test with most recent version of pydantic. "227        "Pydantic introduced small fixes to generated JSONSchema on minor versions."228    ),229)230def test_schemas(snapshot: SnapshotAssertion) -> None:231    fake = FakeRunnable()  # str -> int232233    assert fake.get_input_jsonschema() == {234        "title": "FakeRunnableInput",235        "type": "string",236    }237    assert fake.get_output_jsonschema() == {238        "title": "FakeRunnableOutput",239        "type": "integer",240    }241    assert fake.get_config_jsonschema(include=["tags", "metadata", "run_name"]) == {242        "properties": {243            "metadata": {244                "default": None,245                "title": "Metadata",246                "type": "object",247            },248            "run_name": {"default": None, "title": "Run Name", "type": "string"},249            "tags": {250                "default": None,251                "items": {"type": "string"},252                "title": "Tags",253                "type": "array",254            },255        },256        "title": "FakeRunnableConfig",257        "type": "object",258    }259260    fake_bound = FakeRunnable().bind(a="b")  # str -> int261262    assert fake_bound.get_input_jsonschema() == {263        "title": "FakeRunnableInput",264        "type": "string",265    }266    assert fake_bound.get_output_jsonschema() == {267        "title": "FakeRunnableOutput",268        "type": "integer",269    }270271    fake_w_fallbacks = FakeRunnable().with_fallbacks((fake,))  # str -> int272273    assert fake_w_fallbacks.get_input_jsonschema() == {274        "title": "FakeRunnableInput",275        "type": "string",276    }277    assert fake_w_fallbacks.get_output_jsonschema() == {278        "title": "FakeRunnableOutput",279        "type": "integer",280    }281282    def typed_lambda_impl(x: str) -> int:283        return len(x)284285    typed_lambda = RunnableLambda(typed_lambda_impl)  # str -> int286287    assert typed_lambda.get_input_jsonschema() == {288        "title": "typed_lambda_impl_input",289        "type": "string",290    }291    assert typed_lambda.get_output_jsonschema() == {292        "title": "typed_lambda_impl_output",293        "type": "integer",294    }295296    async def typed_async_lambda_impl(x: str) -> int:297        return len(x)298299    typed_async_lambda = RunnableLambda(typed_async_lambda_impl)  # str -> int300301    assert typed_async_lambda.get_input_jsonschema() == {302        "title": "typed_async_lambda_impl_input",303        "type": "string",304    }305    assert typed_async_lambda.get_output_jsonschema() == {306        "title": "typed_async_lambda_impl_output",307        "type": "integer",308    }309310    fake_ret = FakeRetriever()  # str -> list[Document]311312    assert fake_ret.get_input_jsonschema() == {313        "title": "FakeRetrieverInput",314        "type": "string",315    }316    assert _normalize_schema(fake_ret.get_output_jsonschema()) == {317        "$defs": {318            "Document": {319                "description": "Class for storing a piece of text and "320                "associated metadata.\n"321                "\n"322                "!!! note\n"323                "\n"324                "    `Document` is for **retrieval workflows**, not chat I/O. For "325                "sending text\n"326                "    to an LLM in a conversation, use message types from "327                "`langchain.messages`.\n"328                "\n"329                "Example:\n"330                "    ```python\n"331                "    from langchain_core.documents import Document\n"332                "\n"333                "    document = Document(\n"334                '        page_content="Hello, world!", '335                'metadata={"source": "https://example.com"}\n'336                "    )\n"337                "    ```",338                "properties": {339                    "id": {340                        "anyOf": [{"type": "string"}, {"type": "null"}],341                        "default": None,342                        "title": "Id",343                    },344                    "metadata": {"title": "Metadata", "type": "object"},345                    "page_content": {"title": "Page Content", "type": "string"},346                    "type": {347                        "const": "Document",348                        "default": "Document",349                        "title": "Type",350                    },351                },352                "required": ["page_content"],353                "title": "Document",354                "type": "object",355            }356        },357        "items": {"$ref": "#/$defs/Document"},358        "title": "FakeRetrieverOutput",359        "type": "array",360    }361362    fake_llm = FakeListLLM(responses=["a"])  # str -> list[list[str]]363364    assert _schema(fake_llm.input_schema) == snapshot(name="fake_llm_input_schema")365    assert _schema(fake_llm.output_schema) == {366        "title": "FakeListLLMOutput",367        "type": "string",368    }369370    fake_chat = FakeListChatModel(responses=["a"])  # str -> list[list[str]]371372    assert _schema(fake_chat.input_schema) == snapshot(name="fake_chat_input_schema")373    assert _schema(fake_chat.output_schema) == snapshot(name="fake_chat_output_schema")374375    chat_prompt = ChatPromptTemplate.from_messages(376        [377            MessagesPlaceholder(variable_name="history"),378            ("human", "Hello, how are you?"),379        ]380    )381382    assert _normalize_schema(chat_prompt.get_input_jsonschema()) == snapshot(383        name="chat_prompt_input_schema"384    )385    assert _normalize_schema(chat_prompt.get_output_jsonschema()) == snapshot(386        name="chat_prompt_output_schema"387    )388389    prompt = PromptTemplate.from_template("Hello, {name}!")390391    assert prompt.get_input_jsonschema() == {392        "title": "PromptInput",393        "type": "object",394        "properties": {"name": {"title": "Name", "type": "string"}},395        "required": ["name"],396    }397    assert _schema(prompt.output_schema) == snapshot(name="prompt_output_schema")398399    prompt_mapper = PromptTemplate.from_template("Hello, {name}!").map()400401    assert _normalize_schema(prompt_mapper.get_input_jsonschema()) == {402        "$defs": {403            "PromptInput": {404                "properties": {"name": {"title": "Name", "type": "string"}},405                "required": ["name"],406                "title": "PromptInput",407                "type": "object",408            }409        },410        "default": None,411        "items": {"$ref": "#/$defs/PromptInput"},412        "title": "RunnableEach<PromptTemplate>Input",413        "type": "array",414    }415    assert _schema(prompt_mapper.output_schema) == snapshot(416        name="prompt_mapper_output_schema"417    )418419    list_parser = CommaSeparatedListOutputParser()420421    assert _schema(list_parser.input_schema) == snapshot(422        name="list_parser_input_schema"423    )424    assert _schema(list_parser.output_schema) == {425        "title": "CommaSeparatedListOutputParserOutput",426        "type": "array",427        "items": {"type": "string"},428    }429430    seq = prompt | fake_llm | list_parser431432    assert seq.get_input_jsonschema() == {433        "title": "PromptInput",434        "type": "object",435        "properties": {"name": {"title": "Name", "type": "string"}},436        "required": ["name"],437    }438    assert seq.get_output_jsonschema() == {439        "type": "array",440        "items": {"type": "string"},441        "title": "CommaSeparatedListOutputParserOutput",442    }443444    router: Runnable = RouterRunnable({})445446    assert _schema(router.input_schema) == {447        "$ref": "#/definitions/RouterInput",448        "definitions": {449            "RouterInput": {450                "description": "Router input.",451                "properties": {452                    "input": {"title": "Input"},453                    "key": {"title": "Key", "type": "string"},454                },455                "required": ["key", "input"],456                "title": "RouterInput",457                "type": "object",458            }459        },460        "title": "RouterRunnableInput",461    }462    assert router.get_output_jsonschema() == {"title": "RouterRunnableOutput"}463464    seq_w_map: Runnable = (465        prompt466        | fake_llm467        | {468            "original": RunnablePassthrough(input_type=str),469            "as_list": list_parser,470            "length": typed_lambda_impl,471        }472    )473474    assert seq_w_map.get_input_jsonschema() == {475        "title": "PromptInput",476        "type": "object",477        "properties": {"name": {"title": "Name", "type": "string"}},478        "required": ["name"],479    }480    assert seq_w_map.get_output_jsonschema() == {481        "title": "RunnableParallel<original,as_list,length>Output",482        "type": "object",483        "properties": {484            "original": {"title": "Original", "type": "string"},485            "length": {"title": "Length", "type": "integer"},486            "as_list": {487                "title": "As List",488                "type": "array",489                "items": {"type": "string"},490            },491        },492        "required": ["original", "as_list", "length"],493    }494495    # Add a test for schema of runnable assign496    def foo(x: int) -> int:497        return x498499    foo_ = RunnableLambda(foo)500501    assert foo_.assign(bar=lambda _: "foo").get_output_schema().model_json_schema() == {502        "properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}},503        "required": ["root", "bar"],504        "title": "RunnableAssignOutput",505        "type": "object",506    }507508509def test_passthrough_assign_schema() -> None:510    retriever = FakeRetriever()  # str -> list[Document]511    prompt = PromptTemplate.from_template("{context} {question}")512    fake_llm = FakeListLLM(responses=["a"])  # str -> list[list[str]]513514    seq_w_assign = (515        RunnablePassthrough.assign(context=itemgetter("question") | retriever)516        | prompt517        | fake_llm518    )519520    assert seq_w_assign.get_input_jsonschema() == {521        "properties": {"question": {"title": "Question", "type": "string"}},522        "title": "RunnableSequenceInput",523        "type": "object",524        "required": ["question"],525    }526    assert seq_w_assign.get_output_jsonschema() == {527        "title": "FakeListLLMOutput",528        "type": "string",529    }530531    invalid_seq_w_assign = (532        RunnablePassthrough.assign(context=itemgetter("question") | retriever)533        | fake_llm534    )535536    # fallback to RunnableAssign.input_schema if next runnable doesn't have537    # expected dict input_schema538    assert invalid_seq_w_assign.get_input_jsonschema() == {539        "properties": {"question": {"title": "Question"}},540        "title": "RunnableParallel<context>Input",541        "type": "object",542        "required": ["question"],543    }544545546def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:547    first_lambda = lambda x: x["hello"]  # noqa: E731548    assert RunnableLambda(first_lambda).get_input_jsonschema() == {549        "title": "RunnableLambdaInput",550        "type": "object",551        "properties": {"hello": {"title": "Hello"}},552        "required": ["hello"],553    }554555    second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"])  # noqa: E731556    assert RunnableLambda(second_lambda).get_input_jsonschema() == {557        "title": "RunnableLambdaInput",558        "type": "object",559        "properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},560        "required": ["bye", "hello"],561    }562563    def get_value(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202564        return value["variable_name"]565566    assert RunnableLambda(get_value).get_input_jsonschema() == {567        "title": "get_value_input",568        "type": "object",569        "properties": {"variable_name": {"title": "Variable Name"}},570        "required": ["variable_name"],571    }572573    async def aget_value(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202574        return (value["variable_name"], value.get("another"))575576    assert RunnableLambda(aget_value).get_input_jsonschema() == {577        "title": "aget_value_input",578        "type": "object",579        "properties": {580            "another": {"title": "Another"},581            "variable_name": {"title": "Variable Name"},582        },583        "required": ["another", "variable_name"],584    }585586    async def aget_values(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202587        return {588            "hello": value["variable_name"],589            "bye": value["variable_name"],590            "byebye": value["yo"],591        }592593    assert RunnableLambda(aget_values).get_input_jsonschema() == {594        "title": "aget_values_input",595        "type": "object",596        "properties": {597            "variable_name": {"title": "Variable Name"},598            "yo": {"title": "Yo"},599        },600        "required": ["variable_name", "yo"],601    }602603    class InputType(TypedDict):604        variable_name: str605        yo: int606607    class OutputType(TypedDict):608        hello: str609        bye: str610        byebye: int611612    async def aget_values_typed(value: InputType) -> OutputType:613        return {614            "hello": value["variable_name"],615            "bye": value["variable_name"],616            "byebye": value["yo"],617        }618619    assert _normalize_schema(620        RunnableLambda(aget_values_typed).get_input_jsonschema()621    ) == _normalize_schema(622        {623            "$defs": {624                "InputType": {625                    "properties": {626                        "variable_name": {627                            "title": "Variable Name",628                            "type": "string",629                        },630                        "yo": {"title": "Yo", "type": "integer"},631                    },632                    "required": ["variable_name", "yo"],633                    "title": "InputType",634                    "type": "object",635                }636            },637            "allOf": [{"$ref": "#/$defs/InputType"}],638            "title": "aget_values_typed_input",639        }640    )641642    if PYDANTIC_VERSION_AT_LEAST_29:643        assert _normalize_schema(644            RunnableLambda(aget_values_typed).get_output_jsonschema()645        ) == snapshot(name="schema8")646647648def test_with_types_with_type_generics() -> None:649    """Verify that with_types works if we use things like list[int]."""650651    def foo(x: int) -> None:652        """Add one to the input."""653        raise NotImplementedError654655    # Try specifying some656    RunnableLambda(foo).with_types(657        output_type=list[int],  # type: ignore[arg-type]658        input_type=list[int],  # type: ignore[arg-type]659    )660    RunnableLambda(foo).with_types(661        output_type=Sequence[int],  # type: ignore[arg-type]662        input_type=Sequence[int],  # type: ignore[arg-type]663    )664665666def test_schema_with_itemgetter() -> None:667    """Test runnable with itemgetter."""668    foo = RunnableLambda(itemgetter("hello"))669    assert _schema(foo.input_schema) == {670        "properties": {"hello": {"title": "Hello"}},671        "required": ["hello"],672        "title": "RunnableLambdaInput",673        "type": "object",674    }675    prompt = ChatPromptTemplate.from_template("what is {language}?")676    chain: Runnable = {"language": itemgetter("language")} | prompt677    assert _schema(chain.input_schema) == {678        "properties": {"language": {"title": "Language"}},679        "required": ["language"],680        "title": "RunnableParallel<language>Input",681        "type": "object",682    }683684685def test_schema_complex_seq() -> None:686    prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?")687    prompt2 = ChatPromptTemplate.from_template(688        "what country is the city {city} in? respond in {language}"689    )690691    model = FakeListChatModel(responses=[""])692693    chain1: Runnable = RunnableSequence(694        prompt1, model, StrOutputParser(), name="city_chain"695    )696697    assert chain1.name == "city_chain"698699    chain2: Runnable = (700        {"city": chain1, "language": itemgetter("language")}701        | prompt2702        | model703        | StrOutputParser()704    )705706    assert chain2.get_input_jsonschema() == {707        "title": "RunnableParallel<city,language>Input",708        "type": "object",709        "properties": {710            "person": {"title": "Person", "type": "string"},711            "language": {"title": "Language"},712        },713        "required": ["person", "language"],714    }715716    assert chain2.get_output_jsonschema() == {717        "title": "StrOutputParserOutput",718        "type": "string",719    }720721    assert chain2.with_types(input_type=str).get_input_jsonschema() == {722        "title": "RunnableSequenceInput",723        "type": "string",724    }725726    assert chain2.with_types(input_type=int).get_output_jsonschema() == {727        "title": "StrOutputParserOutput",728        "type": "string",729    }730731    class InputType(BaseModel):732        person: str733734    assert chain2.with_types(input_type=InputType).get_input_jsonschema() == {735        "title": "InputType",736        "type": "object",737        "properties": {"person": {"title": "Person", "type": "string"}},738        "required": ["person"],739    }740741742def test_configurable_fields(snapshot: SnapshotAssertion) -> None:743    fake_llm = FakeListLLM(responses=["a"])  # str -> list[list[str]]744745    assert fake_llm.invoke("...") == "a"746747    fake_llm_configurable = fake_llm.configurable_fields(748        responses=ConfigurableField(749            id="llm_responses",750            name="LLM Responses",751            description="A list of fake responses for this LLM",752        )753    )754755    assert fake_llm_configurable.invoke("...") == "a"756757    if PYDANTIC_VERSION_AT_LEAST_29:758        assert _normalize_schema(759            fake_llm_configurable.get_config_jsonschema()760        ) == snapshot(name="schema2")761762    fake_llm_configured = fake_llm_configurable.with_config(763        configurable={"llm_responses": ["b"]}764    )765766    assert fake_llm_configured.invoke("...") == "b"767768    prompt = PromptTemplate.from_template("Hello, {name}!")769770    assert prompt.invoke({"name": "John"}) == StringPromptValue(text="Hello, John!")771772    prompt_configurable = prompt.configurable_fields(773        template=ConfigurableField(774            id="prompt_template",775            name="Prompt Template",776            description="The prompt template for this chain",777        )778    )779780    assert prompt_configurable.invoke({"name": "John"}) == StringPromptValue(781        text="Hello, John!"782    )783784    if PYDANTIC_VERSION_AT_LEAST_29:785        assert _normalize_schema(786            prompt_configurable.get_config_jsonschema()787        ) == snapshot(name="schema3")788789    prompt_configured = prompt_configurable.with_config(790        configurable={"prompt_template": "Hello, {name}! {name}!"}791    )792793    assert prompt_configured.invoke({"name": "John"}) == StringPromptValue(794        text="Hello, John! John!"795    )796797    assert prompt_configurable.with_config(798        configurable={"prompt_template": "Hello {name} in {lang}"}799    ).get_input_jsonschema() == {800        "title": "PromptInput",801        "type": "object",802        "properties": {803            "lang": {"title": "Lang", "type": "string"},804            "name": {"title": "Name", "type": "string"},805        },806        "required": ["lang", "name"],807    }808809    chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser()810811    assert chain_configurable.invoke({"name": "John"}) == "a"812813    if PYDANTIC_VERSION_AT_LEAST_29:814        assert _normalize_schema(815            chain_configurable.get_config_jsonschema()816        ) == snapshot(name="schema4")817818    assert (819        chain_configurable.with_config(820            configurable={821                "prompt_template": "A very good morning to you, {name} {lang}!",822                "llm_responses": ["c"],823            }824        ).invoke({"name": "John", "lang": "en"})825        == "c"826    )827828    assert chain_configurable.with_config(829        configurable={830            "prompt_template": "A very good morning to you, {name} {lang}!",831            "llm_responses": ["c"],832        }833    ).get_input_jsonschema() == {834        "title": "PromptInput",835        "type": "object",836        "properties": {837            "lang": {"title": "Lang", "type": "string"},838            "name": {"title": "Name", "type": "string"},839        },840        "required": ["lang", "name"],841    }842843    chain_with_map_configurable: Runnable = prompt_configurable | {844        "llm1": fake_llm_configurable | StrOutputParser(),845        "llm2": fake_llm_configurable | StrOutputParser(),846        "llm3": fake_llm.configurable_fields(847            responses=ConfigurableField("other_responses")848        )849        | StrOutputParser(),850    }851852    assert chain_with_map_configurable.invoke({"name": "John"}) == {853        "llm1": "a",854        "llm2": "a",855        "llm3": "a",856    }857858    if PYDANTIC_VERSION_AT_LEAST_29:859        assert _normalize_schema(860            chain_with_map_configurable.get_config_jsonschema()861        ) == snapshot(name="schema5")862863    assert chain_with_map_configurable.with_config(864        configurable={865            "prompt_template": "A very good morning to you, {name}!",866            "llm_responses": ["c"],867            "other_responses": ["d"],868        }869    ).invoke({"name": "John"}) == {"llm1": "c", "llm2": "c", "llm3": "d"}870871872def test_configurable_alts_factory() -> None:873    fake_llm = FakeListLLM(responses=["a"]).configurable_alternatives(874        ConfigurableField(id="llm", name="LLM"),875        chat=partial(FakeListLLM, responses=["b"]),876    )877878    assert fake_llm.invoke("...") == "a"879880    assert fake_llm.with_config(configurable={"llm": "chat"}).invoke("...") == "b"881882883def test_configurable_fields_prefix_keys(snapshot: SnapshotAssertion) -> None:884    fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(885        responses=ConfigurableFieldMultiOption(886            id="responses",887            name="Chat Responses",888            options={889                "hello": "A good morning to you!",890                "bye": "See you later!",891                "helpful": "How can I help you?",892            },893            default=["hello", "bye"],894        ),895        # (sleep is a configurable field in FakeListChatModel)896        sleep=ConfigurableField(897            id="chat_sleep",898            is_shared=True,899        ),900    )901    fake_llm = (902        FakeListLLM(responses=["a"])903        .configurable_fields(904            responses=ConfigurableField(905                id="responses",906                name="LLM Responses",907                description="A list of fake responses for this LLM",908            )909        )910        .configurable_alternatives(911            ConfigurableField(id="llm", name="LLM"),912            chat=fake_chat | StrOutputParser(),913            prefix_keys=True,914        )915    )916    prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(917        template=ConfigurableFieldSingleOption(918            id="prompt_template",919            name="Prompt Template",920            description="The prompt template for this chain",921            options={922                "hello": "Hello, {name}!",923                "good_morning": "A very good morning to you, {name}!",924            },925            default="hello",926        )927    )928929    chain = prompt | fake_llm930931    if PYDANTIC_VERSION_AT_LEAST_29:932        assert _normalize_schema(_schema(chain.config_schema())) == snapshot(933            name="schema6"934        )935936937def test_configurable_fields_example(snapshot: SnapshotAssertion) -> None:938    fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(939        responses=ConfigurableFieldMultiOption(940            id="chat_responses",941            name="Chat Responses",942            options={943                "hello": "A good morning to you!",944                "bye": "See you later!",945                "helpful": "How can I help you?",946            },947            default=["hello", "bye"],948        )949    )950    fake_llm = (951        FakeListLLM(responses=["a"])952        .configurable_fields(953            responses=ConfigurableField(954                id="llm_responses",955                name="LLM Responses",956                description="A list of fake responses for this LLM",957            )958        )959        .configurable_alternatives(960            ConfigurableField(id="llm", name="LLM"),961            chat=fake_chat | StrOutputParser(),962        )963    )964965    prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(966        template=ConfigurableFieldSingleOption(967            id="prompt_template",968            name="Prompt Template",969            description="The prompt template for this chain",970            options={971                "hello": "Hello, {name}!",972                "good_morning": "A very good morning to you, {name}!",973            },974            default="hello",975        )976    )977978    # deduplication of configurable fields979    chain_configurable = prompt | fake_llm | (lambda x: {"name": x}) | prompt | fake_llm980981    assert chain_configurable.invoke({"name": "John"}) == "a"982983    if PYDANTIC_VERSION_AT_LEAST_29:984        assert _normalize_schema(985            chain_configurable.get_config_jsonschema()986        ) == snapshot(name="schema7")987988    assert (989        chain_configurable.with_config(configurable={"llm": "chat"}).invoke(990            {"name": "John"}991        )992        == "A good morning to you!"993    )994995    assert (996        chain_configurable.with_config(997            configurable={"llm": "chat", "chat_responses": ["helpful"]}998        ).invoke({"name": "John"})999        == "How can I help you?"1000    )100110021003def test_passthrough_tap(mocker: MockerFixture) -> None:1004    fake = FakeRunnable()1005    mock = mocker.Mock()10061007    seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock)10081009    assert seq.invoke("hello", my_kwarg="value") == 51010    assert mock.call_args_list == [1011        mocker.call("hello", my_kwarg="value"),1012        mocker.call(5),1013    ]1014    mock.reset_mock()10151016    assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6]1017    assert len(mock.call_args_list) == 41018    for call in [1019        mocker.call("hello", my_kwarg="value"),1020        mocker.call("byebye", my_kwarg="value"),1021        mocker.call(5),1022        mocker.call(6),1023    ]:1024        assert call in mock.call_args_list1025    mock.reset_mock()10261027    assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [1028        5,1029        6,1030    ]1031    assert len(mock.call_args_list) == 41032    for call in [1033        mocker.call("hello", my_kwarg="value"),1034        mocker.call("byebye", my_kwarg="value"),1035        mocker.call(5),1036        mocker.call(6),1037    ]:1038        assert call in mock.call_args_list1039    mock.reset_mock()10401041    assert sorted(1042        a1043        for a in seq.batch_as_completed(1044            ["hello", "byebye"], my_kwarg="value", return_exceptions=True1045        )1046    ) == [1047        (0, 5),1048        (1, 6),1049    ]1050    assert len(mock.call_args_list) == 41051    for call in [1052        mocker.call("hello", my_kwarg="value"),1053        mocker.call("byebye", my_kwarg="value"),1054        mocker.call(5),1055        mocker.call(6),1056    ]:1057        assert call in mock.call_args_list1058    mock.reset_mock()10591060    assert list(1061        seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value")1062    ) == [5]1063    assert mock.call_args_list == [1064        mocker.call("hello", my_kwarg="value"),1065        mocker.call(5),1066    ]1067    mock.reset_mock()106810691070async def test_passthrough_tap_async(mocker: MockerFixture) -> None:1071    fake = FakeRunnable()1072    mock = mocker.Mock()10731074    seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock)10751076    assert await seq.ainvoke("hello", my_kwarg="value") == 51077    assert mock.call_args_list == [1078        mocker.call("hello", my_kwarg="value"),1079        mocker.call(5),1080    ]1081    mock.reset_mock()10821083    assert await seq.abatch(["hello", "byebye"], my_kwarg="value") == [5, 6]1084    assert len(mock.call_args_list) == 41085    for call in [1086        mocker.call("hello", my_kwarg="value"),1087        mocker.call("byebye", my_kwarg="value"),1088        mocker.call(5),1089        mocker.call(6),1090    ]:1091        assert call in mock.call_args_list1092    mock.reset_mock()10931094    assert await seq.abatch(1095        ["hello", "byebye"], my_kwarg="value", return_exceptions=True1096    ) == [1097        5,1098        6,1099    ]1100    assert len(mock.call_args_list) == 41101    for call in [1102        mocker.call("hello", my_kwarg="value"),1103        mocker.call("byebye", my_kwarg="value"),1104        mocker.call(5),1105        mocker.call(6),1106    ]:1107        assert call in mock.call_args_list1108    mock.reset_mock()11091110    assert sorted(1111        [1112            a1113            async for a in seq.abatch_as_completed(1114                ["hello", "byebye"], my_kwarg="value", return_exceptions=True1115            )1116        ]1117    ) == [1118        (0, 5),1119        (1, 6),1120    ]1121    assert len(mock.call_args_list) == 41122    for call in [1123        mocker.call("hello", my_kwarg="value"),1124        mocker.call("byebye", my_kwarg="value"),1125        mocker.call(5),1126        mocker.call(6),1127    ]:1128        assert call in mock.call_args_list1129    mock.reset_mock()11301131    assert [1132        part1133        async for part in seq.astream(1134            "hello", {"metadata": {"key": "value"}}, my_kwarg="value"1135        )1136    ] == [5]1137    assert mock.call_args_list == [1138        mocker.call("hello", my_kwarg="value"),1139        mocker.call(5),1140    ]114111421143async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:1144    fake = FakeRunnableSerializable()1145    spy = mocker.spy(fake.__class__, "invoke")1146    fakew = fake.configurable_fields(hello=ConfigurableField(id="hello", name="Hello"))11471148    assert (1149        fakew.with_config(tags=["a-tag"]).invoke(1150            "hello",1151            {1152                "configurable": {"hello": "there", "__secret_key": "nahnah"},1153                "metadata": {"bye": "now"},1154            },1155        )1156        == 51157    )1158    assert spy.call_args_list[0].args[1:] == (1159        "hello",1160        {1161            "tags": ["a-tag"],1162            "callbacks": None,1163            "recursion_limit": 25,1164            "configurable": {"hello": "there", "__secret_key": "nahnah"},1165            "metadata": {"bye": "now"},1166        },1167    )1168    spy.reset_mock()116911701171def test_with_config(mocker: MockerFixture) -> None:1172    fake = FakeRunnable()1173    spy = mocker.spy(fake, "invoke")11741175    assert fake.with_config(tags=["a-tag"]).invoke("hello") == 51176    assert spy.call_args_list == [1177        mocker.call(1178            "hello",1179            {"tags": ["a-tag"], "metadata": {}, "configurable": {}},1180        ),1181    ]1182    spy.reset_mock()11831184    fake_1 = RunnablePassthrough[Any]()1185    fake_2 = RunnablePassthrough[Any]()1186    spy_seq_step = mocker.spy(fake_1.__class__, "invoke")11871188    sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config(1189        tags=["b-tag"], max_concurrency=51190    )1191    assert sequence.invoke("hello") == "hello"1192    assert len(spy_seq_step.call_args_list) == 21193    for i, call in enumerate(spy_seq_step.call_args_list):1194        assert call.args[1] == "hello"1195        if i == 0:1196            assert call.args[2].get("tags") == ["a-tag"]1197            assert call.args[2].get("max_concurrency") is None1198        else:1199            assert call.args[2].get("tags") == ["b-tag"]1200            assert call.args[2].get("max_concurrency") == 51201    mocker.stop(spy_seq_step)12021203    assert [1204        *fake.with_config(tags=["a-tag"]).stream(1205            "hello", {"metadata": {"key": "value"}}1206        )1207    ] == [5]1208    assert spy.call_args_list == [1209        mocker.call(1210            "hello",1211            {"tags": ["a-tag"], "metadata": {"key": "value"}, "configurable": {}},1212        ),1213    ]1214    spy.reset_mock()12151216    assert fake.with_config(recursion_limit=5).batch(1217        ["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}]1218    ) == [5, 7]12191220    assert len(spy.call_args_list) == 21221    for i, call in enumerate(1222        sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1)1223    ):1224        assert call.args[0] == ("hello" if i == 0 else "wooorld")1225        if i == 0:1226            assert call.args[1].get("recursion_limit") == 51227            assert call.args[1].get("tags") == ["a-tag"]1228            assert call.args[1].get("metadata") == {}1229        else:1230            assert call.args[1].get("recursion_limit") == 51231            assert call.args[1].get("tags") == []1232            assert call.args[1].get("metadata") == {"key": "value"}12331234    spy.reset_mock()12351236    assert sorted(1237        c1238        for c in fake.with_config(recursion_limit=5).batch_as_completed(1239            ["hello", "wooorld"],1240            [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}],1241        )1242    ) == [(0, 5), (1, 7)]12431244    assert len(spy.call_args_list) == 21245    for i, call in enumerate(1246        sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1)1247    ):1248        assert call.args[0] == ("hello" if i == 0 else "wooorld")1249        if i == 0:1250            assert call.args[1].get("recursion_limit") == 51251            assert call.args[1].get("tags") == ["a-tag"]1252            assert call.args[1].get("metadata") == {}1253        else:1254            assert call.args[1].get("recursion_limit") == 51255            assert call.args[1].get("tags") == []1256            assert call.args[1].get("metadata") == {"key": "value"}12571258    spy.reset_mock()12591260    assert fake.with_config(metadata={"a": "b"}).batch(1261        ["hello", "wooorld"], {"tags": ["a-tag"]}1262    ) == [5, 7]1263    assert len(spy.call_args_list) == 21264    for i, call in enumerate(spy.call_args_list):1265        assert call.args[0] == ("hello" if i == 0 else "wooorld")1266        assert call.args[1].get("tags") == ["a-tag"]1267        assert call.args[1].get("metadata") == {"a": "b"}1268    spy.reset_mock()12691270    assert sorted(1271        c for c in fake.batch_as_completed(["hello", "wooorld"], {"tags": ["a-tag"]})1272    ) == [(0, 5), (1, 7)]1273    assert len(spy.call_args_list) == 21274    for i, call in enumerate(spy.call_args_list):1275        assert call.args[0] == ("hello" if i == 0 else "wooorld")1276        assert call.args[1].get("tags") == ["a-tag"]127712781279async def test_with_config_async(mocker: MockerFixture) -> None:1280    fake = FakeRunnable()1281    spy = mocker.spy(fake, "invoke")12821283    handler = ConsoleCallbackHandler()1284    assert (1285        await fake.with_config(metadata={"a": "b"}).ainvoke(1286            "hello", config={"callbacks": [handler]}1287        )1288        == 51289    )1290    assert spy.call_args_list == [1291        mocker.call(1292            "hello",1293            {1294                "callbacks": [handler],1295                "metadata": {"a": "b"},1296                "configurable": {},1297                "tags": [],1298            },1299        ),1300    ]1301    spy.reset_mock()13021303    assert [1304        part async for part in fake.with_config(metadata={"a": "b"}).astream("hello")1305    ] == [5]1306    assert spy.call_args_list == [1307        mocker.call("hello", {"metadata": {"a": "b"}, "tags": [], "configurable": {}}),1308    ]1309    spy.reset_mock()13101311    assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch(1312        ["hello", "wooorld"], {"metadata": {"key": "value"}}1313    ) == [1314        5,1315        7,1316    ]1317    assert sorted(spy.call_args_list) == [1318        mocker.call(1319            "hello",1320            {1321                "metadata": {"key": "value"},1322                "tags": ["c"],1323                "callbacks": None,1324                "recursion_limit": 5,1325                "configurable": {},1326            },1327        ),1328        mocker.call(1329            "wooorld",1330            {1331                "metadata": {"key": "value"},1332                "tags": ["c"],1333                "callbacks": None,1334                "recursion_limit": 5,1335                "configurable": {},1336            },1337        ),1338    ]1339    spy.reset_mock()13401341    assert sorted(1342        [1343            c1344            async for c in fake.with_config(1345                recursion_limit=5, tags=["c"]1346            ).abatch_as_completed(["hello", "wooorld"], {"metadata": {"key": "value"}})1347        ]1348    ) == [1349        (0, 5),1350        (1, 7),1351    ]1352    assert len(spy.call_args_list) == 21353    first_call = next(call for call in spy.call_args_list if call.args[0] == "hello")1354    assert first_call == mocker.call(1355        "hello",1356        {1357            "metadata": {"key": "value"},1358            "tags": ["c"],1359            "callbacks": None,1360            "recursion_limit": 5,1361            "configurable": {},1362        },1363    )1364    second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")1365    assert second_call == mocker.call(1366        "wooorld",1367        {1368            "metadata": {"key": "value"},1369            "tags": ["c"],1370            "callbacks": None,1371            "recursion_limit": 5,1372            "configurable": {},1373        },1374    )137513761377def test_default_method_implementations(mocker: MockerFixture) -> None:1378    fake = FakeRunnable()1379    spy = mocker.spy(fake, "invoke")13801381    assert fake.invoke("hello", {"tags": ["a-tag"]}) == 51382    assert spy.call_args_list == [1383        mocker.call("hello", {"tags": ["a-tag"]}),1384    ]1385    spy.reset_mock()13861387    assert [*fake.stream("hello", {"metadata": {"key": "value"}})] == [5]1388    assert spy.call_args_list == [1389        mocker.call("hello", {"metadata": {"key": "value"}}),1390    ]1391    spy.reset_mock()13921393    assert fake.batch(1394        ["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}]1395    ) == [5, 7]13961397    assert len(spy.call_args_list) == 21398    for call in spy.call_args_list:1399        call_arg = call.args[0]14001401        if call_arg == "hello":1402            assert call_arg == "hello"1403            assert call.args[1].get("tags") == ["a-tag"]1404            assert call.args[1].get("metadata") == {}1405        else:1406            assert call_arg == "wooorld"1407            assert call.args[1].get("tags") == []1408            assert call.args[1].get("metadata") == {"key": "value"}14091410    spy.reset_mock()14111412    assert fake.batch(["hello", "wooorld"], {"tags": ["a-tag"]}) == [5, 7]1413    assert len(spy.call_args_list) == 21414    assert {call.args[0] for call in spy.call_args_list} == {"hello", "wooorld"}1415    for call in spy.call_args_list:1416        assert call.args[1].get("tags") == ["a-tag"]1417        assert call.args[1].get("metadata") == {}141814191420async def test_default_method_implementations_async(mocker: MockerFixture) -> None:1421    fake = FakeRunnable()1422    spy = mocker.spy(fake, "invoke")14231424    assert await fake.ainvoke("hello", config={"callbacks": []}) == 51425    assert spy.call_args_list == [1426        mocker.call("hello", {"callbacks": []}),1427    ]1428    spy.reset_mock()14291430    assert [part async for part in fake.astream("hello")] == [5]1431    assert spy.call_args_list == [1432        mocker.call("hello", None),1433    ]1434    spy.reset_mock()14351436    assert await fake.abatch(["hello", "wooorld"], {"metadata": {"key": "value"}}) == [1437        5,1438        7,1439    ]1440    assert {call.args[0] for call in spy.call_args_list} == {"hello", "wooorld"}1441    for call in spy.call_args_list:1442        assert call.args[1] == {1443            "metadata": {"key": "value"},1444            "tags": [],1445            "callbacks": None,1446            "recursion_limit": 25,1447            "configurable": {},1448        }144914501451def test_prompt() -> None:1452    prompt = ChatPromptTemplate.from_messages(1453        messages=[1454            SystemMessage(content="You are a nice assistant."),1455            HumanMessagePromptTemplate.from_template("{question}"),1456        ]1457    )1458    expected = ChatPromptValue(1459        messages=[1460            SystemMessage(content="You are a nice assistant."),1461            HumanMessage(content="What is your name?"),1462        ]1463    )14641465    assert prompt.invoke({"question": "What is your name?"}) == expected14661467    assert prompt.batch(1468        [1469            {"question": "What is your name?"},1470            {"question": "What is your favorite color?"},1471        ]1472    ) == [1473        expected,1474        ChatPromptValue(1475            messages=[1476                SystemMessage(content="You are a nice assistant."),1477                HumanMessage(content="What is your favorite color?"),1478            ]1479        ),1480    ]14811482    assert [*prompt.stream({"question": "What is your name?"})] == [expected]148314841485async def test_prompt_async() -> None:1486    prompt = ChatPromptTemplate.from_messages(1487        messages=[1488            SystemMessage(content="You are a nice assistant."),1489            HumanMessagePromptTemplate.from_template("{question}"),1490        ]1491    )1492    expected = ChatPromptValue(1493        messages=[1494            SystemMessage(content="You are a nice assistant."),1495            HumanMessage(content="What is your name?"),1496        ]1497    )14981499    assert await prompt.ainvoke({"question": "What is your name?"}) == expected15001501    assert await prompt.abatch(1502        [1503            {"question": "What is your name?"},1504            {"question": "What is your favorite color?"},1505        ]1506    ) == [1507        expected,1508        ChatPromptValue(1509            messages=[1510                SystemMessage(content="You are a nice assistant."),1511                HumanMessage(content="What is your favorite color?"),1512            ]1513        ),1514    ]15151516    assert [1517        part async for part in prompt.astream({"question": "What is your name?"})1518    ] == [expected]15191520    stream_log = [1521        part async for part in prompt.astream_log({"question": "What is your name?"})1522    ]15231524    assert len(stream_log[0].ops) == 11525    assert stream_log[0].ops[0]["op"] == "replace"1526    assert stream_log[0].ops[0]["path"] == ""1527    assert stream_log[0].ops[0]["value"]["logs"] == {}1528    assert stream_log[0].ops[0]["value"]["final_output"] is None1529    assert stream_log[0].ops[0]["value"]["streamed_output"] == []1530    assert isinstance(stream_log[0].ops[0]["value"]["id"], str)15311532    assert stream_log[1:] == [1533        RunLogPatch(1534            {"op": "add", "path": "/streamed_output/-", "value": expected},1535            {1536                "op": "replace",1537                "path": "/final_output",1538                "value": ChatPromptValue(1539                    messages=[1540                        SystemMessage(content="You are a nice assistant."),1541                        HumanMessage(content="What is your name?"),1542                    ]1543                ),1544            },1545        ),1546    ]15471548    stream_log_state = [1549        part1550        async for part in prompt.astream_log(1551            {"question": "What is your name?"}, diff=False1552        )1553    ]15541555    # remove random id1556    stream_log[0].ops[0]["value"]["id"] = "00000000-0000-0000-0000-000000000000"1557    stream_log_state[-1].ops[0]["value"]["id"] = "00000000-0000-0000-0000-000000000000"1558    stream_log_state[-1].state["id"] = "00000000-0000-0000-0000-000000000000"15591560    # assert output with diff=False matches output with diff=True1561    assert stream_log_state[-1].ops == [op for chunk in stream_log for op in chunk.ops]1562    assert stream_log_state[-1] == RunLog(1563        *[op for chunk in stream_log for op in chunk.ops],1564        state={1565            "final_output": ChatPromptValue(1566                messages=[1567                    SystemMessage(content="You are a nice assistant."),1568                    HumanMessage(content="What is your name?"),1569                ]1570            ),1571            "id": "00000000-0000-0000-0000-000000000000",1572            "logs": {},1573            "streamed_output": [1574                ChatPromptValue(1575                    messages=[1576                        SystemMessage(content="You are a nice assistant."),1577                        HumanMessage(content="What is your name?"),1578                    ]1579                )1580            ],1581            "type": "prompt",1582            "name": "ChatPromptTemplate",1583        },1584    )15851586    # nested inside trace_with_chain_group15871588    async with atrace_as_chain_group("a_group") as manager:1589        stream_log_nested = [1590            part1591            async for part in prompt.astream_log(1592                {"question": "What is your name?"}, config={"callbacks": manager}1593            )1594        ]15951596    assert len(stream_log_nested[0].ops) == 11597    assert stream_log_nested[0].ops[0]["op"] == "replace"1598    assert stream_log_nested[0].ops[0]["path"] == ""1599    assert stream_log_nested[0].ops[0]["value"]["logs"] == {}1600    assert stream_log_nested[0].ops[0]["value"]["final_output"] is None1601    assert stream_log_nested[0].ops[0]["value"]["streamed_output"] == []1602    assert isinstance(stream_log_nested[0].ops[0]["value"]["id"], str)16031604    assert stream_log_nested[1:] == [1605        RunLogPatch(1606            {"op": "add", "path": "/streamed_output/-", "value": expected},1607            {1608                "op": "replace",1609                "path": "/final_output",1610                "value": ChatPromptValue(1611                    messages=[1612                        SystemMessage(content="You are a nice assistant."),1613                        HumanMessage(content="What is your name?"),1614                    ]1615                ),1616            },1617        ),1618    ]161916201621def test_prompt_template_params() -> None:1622    prompt = ChatPromptTemplate.from_template(1623        "Respond to the following question: {question}"1624    )1625    result = prompt.invoke(1626        {1627            "question": "test",1628            "topic": "test",1629        }1630    )1631    assert result == ChatPromptValue(1632        messages=[HumanMessage(content="Respond to the following question: test")]1633    )16341635    with pytest.raises(KeyError):1636        prompt.invoke({})163716381639def test_with_listeners(mocker: MockerFixture) -> None:1640    prompt = (1641        SystemMessagePromptTemplate.from_template("You are a nice assistant.")1642        + "{question}"1643    )1644    chat = FakeListChatModel(responses=["foo"])16451646    chain = prompt | chat16471648    mock_start = mocker.Mock()1649    mock_end = mocker.Mock()16501651    chain.with_listeners(on_start=mock_start, on_end=mock_end).invoke(1652        {"question": "Who are you?"}1653    )16541655    assert mock_start.call_count == 11656    assert mock_start.call_args[0][0].name == "RunnableSequence"1657    assert mock_end.call_count == 116581659    mock_start.reset_mock()1660    mock_end.reset_mock()16611662    with trace_as_chain_group("hello") as manager:1663        chain.with_listeners(on_start=mock_start, on_end=mock_end).invoke(1664            {"question": "Who are you?"}, {"callbacks": manager}1665        )16661667    assert mock_start.call_count == 11668    assert mock_start.call_args[0][0].name == "RunnableSequence"1669    assert mock_end.call_count == 1167016711672async def test_with_listeners_async(mocker: MockerFixture) -> None:1673    prompt = (1674        SystemMessagePromptTemplate.from_template("You are a nice assistant.")1675        + "{question}"1676    )1677    chat = FakeListChatModel(responses=["foo"])16781679    chain = prompt | chat16801681    mock_start = mocker.Mock()1682    mock_end = mocker.Mock()16831684    await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(1685        {"question": "Who are you?"}1686    )16871688    assert mock_start.call_count == 11689    assert mock_start.call_args[0][0].name == "RunnableSequence"1690    assert mock_end.call_count == 116911692    mock_start.reset_mock()1693    mock_end.reset_mock()16941695    async with atrace_as_chain_group("hello") as manager:1696        await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(1697            {"question": "Who are you?"}, {"callbacks": manager}1698        )16991700    assert mock_start.call_count == 11701    assert mock_start.call_args[0][0].name == "RunnableSequence"1702    assert mock_end.call_count == 1170317041705def test_with_listener_propagation(mocker: MockerFixture) -> None:1706    prompt = (1707        SystemMessagePromptTemplate.from_template("You are a nice assistant.")1708        + "{question}"1709    )1710    chat = FakeListChatModel(responses=["foo"])1711    chain: Runnable = prompt | chat1712    mock_start = mocker.Mock()1713    mock_end = mocker.Mock()1714    chain_with_listeners = chain.with_listeners(on_start=mock_start, on_end=mock_end)17151716    chain_with_listeners.with_retry().invoke({"question": "Who are you?"})17171718    assert mock_start.call_count == 11719    assert mock_start.call_args[0][0].name == "RunnableSequence"1720    assert mock_end.call_count == 117211722    mock_start.reset_mock()1723    mock_end.reset_mock()17241725    chain_with_listeners.with_types(output_type=str).invoke(1726        {"question": "Who are you?"}1727    )17281729    assert mock_start.call_count == 11730    assert mock_start.call_args[0][0].name == "RunnableSequence"1731    assert mock_end.call_count == 117321733    mock_start.reset_mock()1734    mock_end.reset_mock()17351736    chain_with_listeners.with_config({"tags": ["foo"]}).invoke(1737        {"question": "Who are you?"}1738    )17391740    assert mock_start.call_count == 11741    assert mock_start.call_args[0][0].name == "RunnableSequence"1742    assert mock_end.call_count == 117431744    mock_start.reset_mock()1745    mock_end.reset_mock()17461747    chain_with_listeners.bind(stop=["foo"]).invoke({"question": "Who are you?"})17481749    assert mock_start.call_count == 11750    assert mock_start.call_args[0][0].name == "RunnableSequence"1751    assert mock_end.call_count == 117521753    mock_start.reset_mock()1754    mock_end.reset_mock()17551756    mock_start_inner = mocker.Mock()1757    mock_end_inner = mocker.Mock()17581759    chain_with_listeners.with_listeners(1760        on_start=mock_start_inner, on_end=mock_end_inner1761    ).invoke({"question": "Who are you?"})17621763    assert mock_start.call_count == 11764    assert mock_start.call_args[0][0].name == "RunnableSequence"1765    assert mock_end.call_count == 11766    assert mock_start_inner.call_count == 11767    assert mock_start_inner.call_args[0][0].name == "RunnableSequence"1768    assert mock_end_inner.call_count == 1176917701771@freeze_time("2023-01-01")1772@pytest.mark.usefixtures("deterministic_uuids")1773def test_prompt_with_chat_model(1774    mocker: MockerFixture,1775    snapshot: SnapshotAssertion,1776) -> None:1777    prompt = (1778        SystemMessagePromptTemplate.from_template("You are a nice assistant.")1779        + "{question}"1780    )1781    chat = FakeListChatModel(responses=["foo"])17821783    chain = prompt | chat17841785    assert repr(chain) == snapshot1786    assert isinstance(chain, RunnableSequence)1787    assert chain.first == prompt1788    assert chain.middle == []1789    assert chain.last == chat1790    assert dumps(chain, pretty=True) == snapshot17911792    # Test invoke1793    prompt_spy = mocker.spy(prompt.__class__, "invoke")1794    chat_spy = mocker.spy(chat.__class__, "invoke")1795    tracer = FakeTracer()1796    assert chain.invoke(1797        {"question": "What is your name?"}, {"callbacks": [tracer]}1798    ) == _any_id_ai_message(content="foo")1799    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1800    assert chat_spy.call_args.args[1] == ChatPromptValue(1801        messages=[1802            SystemMessage(content="You are a nice assistant."),1803            HumanMessage(content="What is your name?"),1804        ]1805    )18061807    assert tracer.runs == snapshot18081809    mocker.stop(prompt_spy)1810    mocker.stop(chat_spy)18111812    # Test batch1813    prompt_spy = mocker.spy(prompt.__class__, "batch")1814    chat_spy = mocker.spy(chat.__class__, "batch")1815    tracer = FakeTracer()1816    assert chain.batch(1817        [1818            {"question": "What is your name?"},1819            {"question": "What is your favorite color?"},1820        ],1821        {"callbacks": [tracer]},1822    ) == [1823        _any_id_ai_message(content="foo"),1824        _any_id_ai_message(content="foo"),1825    ]1826    assert prompt_spy.call_args.args[1] == [1827        {"question": "What is your name?"},1828        {"question": "What is your favorite color?"},1829    ]1830    assert chat_spy.call_args.args[1] == [1831        ChatPromptValue(1832            messages=[1833                SystemMessage(content="You are a nice assistant."),1834                HumanMessage(content="What is your name?"),1835            ]1836        ),1837        ChatPromptValue(1838            messages=[1839                SystemMessage(content="You are a nice assistant."),1840                HumanMessage(content="What is your favorite color?"),1841            ]1842        ),1843    ]1844    assert (1845        len(1846            [1847                r1848                for r in tracer.runs1849                if r.parent_run_id is None and len(r.child_runs) == 21850            ]1851        )1852        == 21853    ), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"1854    mocker.stop(prompt_spy)1855    mocker.stop(chat_spy)18561857    # Test stream1858    prompt_spy = mocker.spy(prompt.__class__, "invoke")1859    chat_spy = mocker.spy(chat.__class__, "stream")1860    tracer = FakeTracer()1861    assert [1862        *chain.stream({"question": "What is your name?"}, {"callbacks": [tracer]})1863    ] == [1864        _any_id_ai_message_chunk(content="f"),1865        _any_id_ai_message_chunk(content="o"),1866        _any_id_ai_message_chunk(content="o", chunk_position="last"),1867    ]1868    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1869    assert chat_spy.call_args.args[1] == ChatPromptValue(1870        messages=[1871            SystemMessage(content="You are a nice assistant."),1872            HumanMessage(content="What is your name?"),1873        ]1874    )187518761877@freeze_time("2023-01-01")1878@pytest.mark.usefixtures("deterministic_uuids")1879async def test_prompt_with_chat_model_async(1880    mocker: MockerFixture,1881    snapshot: SnapshotAssertion,1882) -> None:1883    prompt = (1884        SystemMessagePromptTemplate.from_template("You are a nice assistant.")1885        + "{question}"1886    )1887    chat = FakeListChatModel(responses=["foo"])18881889    chain = prompt | chat18901891    assert repr(chain) == snapshot1892    assert isinstance(chain, RunnableSequence)1893    assert chain.first == prompt1894    assert chain.middle == []1895    assert chain.last == chat1896    assert dumps(chain, pretty=True) == snapshot18971898    # Test invoke1899    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")1900    chat_spy = mocker.spy(chat.__class__, "ainvoke")1901    tracer = FakeTracer()1902    assert await chain.ainvoke(1903        {"question": "What is your name?"}, {"callbacks": [tracer]}1904    ) == _any_id_ai_message(content="foo")1905    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1906    assert chat_spy.call_args.args[1] == ChatPromptValue(1907        messages=[1908            SystemMessage(content="You are a nice assistant."),1909            HumanMessage(content="What is your name?"),1910        ]1911    )19121913    assert tracer.runs == snapshot19141915    mocker.stop(prompt_spy)1916    mocker.stop(chat_spy)19171918    # Test batch1919    prompt_spy = mocker.spy(prompt.__class__, "abatch")1920    chat_spy = mocker.spy(chat.__class__, "abatch")1921    tracer = FakeTracer()1922    assert await chain.abatch(1923        [1924            {"question": "What is your name?"},1925            {"question": "What is your favorite color?"},1926        ],1927        {"callbacks": [tracer]},1928    ) == [1929        _any_id_ai_message(content="foo"),1930        _any_id_ai_message(content="foo"),1931    ]1932    assert prompt_spy.call_args.args[1] == [1933        {"question": "What is your name?"},1934        {"question": "What is your favorite color?"},1935    ]1936    assert chat_spy.call_args.args[1] == [1937        ChatPromptValue(1938            messages=[1939                SystemMessage(content="You are a nice assistant."),1940                HumanMessage(content="What is your name?"),1941            ]1942        ),1943        ChatPromptValue(1944            messages=[1945                SystemMessage(content="You are a nice assistant."),1946                HumanMessage(content="What is your favorite color?"),1947            ]1948        ),1949    ]1950    assert (1951        len(1952            [1953                r1954                for r in tracer.runs1955                if r.parent_run_id is None and len(r.child_runs) == 21956            ]1957        )1958        == 21959    ), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"1960    mocker.stop(prompt_spy)1961    mocker.stop(chat_spy)19621963    # Test stream1964    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")1965    chat_spy = mocker.spy(chat.__class__, "astream")1966    tracer = FakeTracer()1967    assert [1968        a1969        async for a in chain.astream(1970            {"question": "What is your name?"}, {"callbacks": [tracer]}1971        )1972    ] == [1973        _any_id_ai_message_chunk(content="f"),1974        _any_id_ai_message_chunk(content="o"),1975        _any_id_ai_message_chunk(content="o", chunk_position="last"),1976    ]1977    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1978    assert chat_spy.call_args.args[1] == ChatPromptValue(1979        messages=[1980            SystemMessage(content="You are a nice assistant."),1981            HumanMessage(content="What is your name?"),1982        ]1983    )198419851986@pytest.mark.skipif(1987    condition=sys.version_info[1] == 13,1988    reason=(1989        "temporary, py3.13 exposes some invalid assumptions about order of batch async "1990        "executions."1991    ),1992)1993@freeze_time("2023-01-01")1994async def test_prompt_with_llm(1995    mocker: MockerFixture, snapshot: SnapshotAssertion1996) -> None:1997    prompt = (1998        SystemMessagePromptTemplate.from_template("You are a nice assistant.")1999        + "{question}"2000    )

Findings

✓ No findings reported for this file.

Get this view in your editor

Same data, no extra tab — call code_get_file + code_get_findings over MCP from Claude/Cursor/Copilot.