libs/core/tests/unit_tests/runnables/test_runnable.py PYTHON 6,005 lines View on github.com → Search inside
File is large — showing lines 1–2,000 of 6,005.
"""Unit tests for `Runnable` implementations in `langchain_core.runnables`."""

import asyncio
import re
import sys
import time
import uuid
import warnings
from collections.abc import (
    AsyncIterator,
    Awaitable,
    Callable,
    Iterator,
    Sequence,
)
from functools import partial
from operator import itemgetter
from typing import Any, cast
from uuid import UUID

import pytest
from freezegun import freeze_time
from packaging import version
from pydantic import BaseModel, Field, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import Field as FieldV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from pytest_mock import MockerFixture
from syrupy.assertion import SnapshotAssertion
from typing_extensions import TypedDict, override

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
    atrace_as_chain_group,
    trace_as_chain_group,
)
from langchain_core.documents import Document
from langchain_core.language_models import (
    FakeListChatModel,
    FakeListLLM,
    FakeStreamingListLLM,
)
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.load import dumpd, dumps
from langchain_core.load.load import loads
from langchain_core.messages import AIMessageChunk, HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.output_parsers import (
    BaseOutputParser,
    CommaSeparatedListOutputParser,
    StrOutputParser,
)
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
    AddableDict,
    ConfigurableField,
    ConfigurableFieldMultiOption,
    ConfigurableFieldSingleOption,
    RouterRunnable,
    Runnable,
    RunnableAssign,
    RunnableBinding,
    RunnableBranch,
    RunnableConfig,
    RunnableGenerator,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
    RunnablePick,
    RunnableSequence,
    add,
    chain,
)
from langchain_core.runnables.base import RunnableMap, RunnableSerializable
from langchain_core.runnables.utils import Input, Output
from langchain_core.tools import BaseTool, tool
from langchain_core.tracers import (
    BaseTracer,
    ConsoleCallbackHandler,
    Run,
    RunLog,
    RunLogPatch,
)
from langchain_core.tracers._compat import pydantic_copy
from langchain_core.tracers.context import collect_runs
from langchain_core.utils.pydantic import (
    PYDANTIC_VERSION,
    TypeBaseModel,
    model_validate,
)
from langchain_core.version import VERSION
from tests.unit_tests.pydantic_utils import (
    _normalize_schema,
    _schema,
    skip_if_no_pydantic_v1,
)
from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk

# Several tests assert the legacy `RunLog` / `RunLogPatch` output produced by
# `astream_log`, which cannot be replaced by `astream` without losing coverage.
pytestmark = pytest.mark.filterwarnings(
    "ignore:astream_log is deprecated. Use astream instead.:"
    "langchain_core._api.deprecation.LangChainDeprecationWarning"
)

# Version gates: several schema assertions below only hold (or are only
# checked) from a given pydantic minor version onward.
PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION
PYDANTIC_VERSION_AT_LEAST_210 = version.parse("2.10") <= PYDANTIC_VERSION


class FakeTracer(BaseTracer):
    """Fake tracer that records LangChain execution.

    It replaces run IDs with deterministic UUIDs for snapshotting. Message ids
    found in run inputs/outputs are overwritten (in place) with ids drawn from
    the same deterministic sequence.
    """

    def __init__(self) -> None:
        """Initialize the tracer."""
        super().__init__()
        # Persisted (already id-rewritten) copies of the runs handed to
        # `_persist_run`.
        self.runs: list[Run] = []
        # Maps each real run id to its deterministic replacement.
        self.uuids_map: dict[UUID, UUID] = {}
        # Deterministic ids 00000000-0000-4000-8000-000000000000,
        # ...-000000000001, ... (zero-padded to 12 digits). The generator is
        # finite: `next()` raises StopIteration after 10000 ids.
        self.uuids_generator = (
            UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
        )

    def _replace_uuid(self, uuid: UUID) -> UUID:
        """Return the deterministic id for `uuid`, allocating one on first use."""
        if uuid not in self.uuids_map:
            self.uuids_map[uuid] = next(self.uuids_generator)
        return self.uuids_map[uuid]

    def _replace_message_id(self, maybe_message: Any) -> Any:
        """Overwrite message ids with fresh deterministic ids.

        Recurses into `LLMResult` generations, dicts and lists. Containers are
        mutated in place and the same object is returned; any other value is
        returned unchanged. Every message id consumes a new id from the
        generator, so the same message seen twice gets two different ids.
        """
        if isinstance(maybe_message, BaseMessage):
            maybe_message.id = str(next(self.uuids_generator))
        if isinstance(maybe_message, ChatGeneration):
            maybe_message.message.id = str(next(self.uuids_generator))
        if isinstance(maybe_message, LLMResult):
            for i, gen_list in enumerate(maybe_message.generations):
                for j, gen in enumerate(gen_list):
                    maybe_message.generations[i][j] = self._replace_message_id(gen)
        if isinstance(maybe_message, dict):
            for k, v in maybe_message.items():
                maybe_message[k] = self._replace_message_id(v)
        if isinstance(maybe_message, list):
            for i, v in enumerate(maybe_message):
                maybe_message[i] = self._replace_message_id(v)

        return maybe_message

    def _copy_run(self, run: Run) -> Run:
        """Return a copy of `run` (and its child runs) with deterministic ids.

        Note that `run.inputs` / `run.outputs` are rewritten in place by
        `_replace_message_id` before being placed on the copy.
        """
        if run.dotted_order:
            # dotted_order is a "."-joined list of "<timestamp>Z<run_id>"
            # segments; rewrite the id part of every segment.
            levels = run.dotted_order.split(".")
            processed_levels = []
            for level in levels:
                timestamp, run_id = level.split("Z")
                new_run_id = self._replace_uuid(UUID(run_id))
                processed_level = f"{timestamp}Z{new_run_id}"
                processed_levels.append(processed_level)
            new_dotted_order = ".".join(processed_levels)
        else:
            new_dotted_order = None
        update_dict = {
            "id": self._replace_uuid(run.id),
            # Looked up directly rather than via `_replace_uuid`, so this
            # raises KeyError if the parent id was never mapped. Presumably
            # the parent is always mapped first — TODO confirm.
            "parent_run_id": (
                self.uuids_map[run.parent_run_id] if run.parent_run_id else None
            ),
            "child_runs": [self._copy_run(child) for child in run.child_runs],
            "trace_id": self._replace_uuid(run.trace_id) if run.trace_id else None,
            "dotted_order": new_dotted_order,
            "inputs": self._replace_message_id(run.inputs),
            "outputs": self._replace_message_id(run.outputs),
        }
        return pydantic_copy(run, update=update_dict)

    def _persist_run(self, run: Run) -> None:
        """Persist a run."""
        self.runs.append(self._copy_run(run))

    def flattened_runs(self) -> list[Run]:
        """Return all recorded runs and their descendants as a flat list.

        Traversal is depth-first via a stack (`pop()` from the end), so the
        last recorded run and its last children come first.
        """
        q = [*self.runs]
        result = []
        while q:
            parent = q.pop()
            result.append(parent)
            if parent.child_runs:
                q.extend(parent.child_runs)
        return result

    @property
    def run_ids(self) -> list[uuid.UUID | None]:
        """Original (pre-replacement) ids of `flattened_runs()`.

        An entry is None when a deterministic id has no recorded original.
        """
        runs = self.flattened_runs()
        # Invert real-id -> fake-id so fake ids can be mapped back.
        uuids_map = {v: k for k, v in self.uuids_map.items()}
        return [uuids_map.get(r.id) for r in runs]


# Minimal str -> int runnable (returns the input length).
class FakeRunnable(Runnable[str, int]):
    @override
    def invoke(
        self,
        input: str,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> int:
        return len(input)


# Serializable variant of FakeRunnable with one extra field (`hello`).
class FakeRunnableSerializable(RunnableSerializable[str, int]):
    hello: str = ""

    @override
    def invoke(
        self,
        input: str,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> int:
        return len(input)


# Retriever that ignores the query and always returns the same two documents.
class FakeRetriever(BaseRetriever):
    @override
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        return [Document(page_content="foo"), Document(page_content="bar")]

    @override
    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> list[Document]:
        return [Document(page_content="foo"), Document(page_content="bar")]


# NOTE(review): this condition skips the test on pydantic >= 2.10, i.e. on the
# newest versions, while the reason says the test should only run on the most
# recent version. One of the two looks inverted — confirm the intent before
# changing either (the literal schemas below may only match older pydantic).
@pytest.mark.skipif(
    PYDANTIC_VERSION_AT_LEAST_210,
    reason=(
        "Only test with most recent version of pydantic. "
        "Pydantic introduced small fixes to generated JSONSchema on minor versions."
    ),
)
def test_schemas(snapshot: SnapshotAssertion) -> None:
    """Check input/output/config JSON schemas across many runnable kinds."""
    fake = FakeRunnable()  # str -> int

    assert fake.get_input_jsonschema() == {
        "title": "FakeRunnableInput",
        "type": "string",
    }
    assert fake.get_output_jsonschema() == {
        "title": "FakeRunnableOutput",
        "type": "integer",
    }
    assert fake.get_config_jsonschema(include=["tags", "metadata", "run_name"]) == {
        "properties": {
            "metadata": {
                "default": None,
                "title": "Metadata",
                "type": "object",
            },
            "run_name": {"default": None, "title": "Run Name", "type": "string"},
            "tags": {
                "default": None,
                "items": {"type": "string"},
                "title": "Tags",
                "type": "array",
            },
        },
        "title": "FakeRunnableConfig",
        "type": "object",
    }

    # Binding kwargs and adding fallbacks must not change the schemas.
    fake_bound = FakeRunnable().bind(a="b")  # str -> int

    assert fake_bound.get_input_jsonschema() == {
        "title": "FakeRunnableInput",
        "type": "string",
    }
    assert fake_bound.get_output_jsonschema() == {
        "title": "FakeRunnableOutput",
        "type": "integer",
    }

    fake_w_fallbacks = FakeRunnable().with_fallbacks((fake,))  # str -> int

    assert fake_w_fallbacks.get_input_jsonschema() == {
        "title": "FakeRunnableInput",
        "type": "string",
    }
    assert fake_w_fallbacks.get_output_jsonschema() == {
        "title": "FakeRunnableOutput",
        "type": "integer",
    }

    # For RunnableLambda, schema titles are derived from the function name.
    def typed_lambda_impl(x: str) -> int:
        return len(x)

    typed_lambda = RunnableLambda(typed_lambda_impl)  # str -> int

    assert typed_lambda.get_input_jsonschema() == {
        "title": "typed_lambda_impl_input",
        "type": "string",
    }
    assert typed_lambda.get_output_jsonschema() == {
        "title": "typed_lambda_impl_output",
        "type": "integer",
    }

    async def typed_async_lambda_impl(x: str) -> int:
        return len(x)

    typed_async_lambda = RunnableLambda(typed_async_lambda_impl)  # str -> int

    assert typed_async_lambda.get_input_jsonschema() == {
        "title": "typed_async_lambda_impl_input",
        "type": "string",
    }
    assert typed_async_lambda.get_output_jsonschema() == {
        "title": "typed_async_lambda_impl_output",
        "type": "integer",
    }

    fake_ret = FakeRetriever()  # str -> list[Document]

    assert fake_ret.get_input_jsonschema() == {
        "title": "FakeRetrieverInput",
        "type": "string",
    }
    # The Document description below mirrors the Document class docstring.
    assert _normalize_schema(fake_ret.get_output_jsonschema()) == {
        "$defs": {
            "Document": {
                "description": "Class for storing a piece of text and "
                "associated metadata.\n"
                "\n"
                "!!! note\n"
                "\n"
                "    `Document` is for **retrieval workflows**, not chat I/O. For "
                "sending text\n"
                "    to an LLM in a conversation, use message types from "
                "`langchain.messages`.\n"
                "\n"
                "Example:\n"
                "    ```python\n"
                "    from langchain_core.documents import Document\n"
                "\n"
                "    document = Document(\n"
                '        page_content="Hello, world!", '
                'metadata={"source": "https://example.com"}\n'
                "    )\n"
                "    ```",
                "properties": {
                    "id": {
                        "anyOf": [{"type": "string"}, {"type": "null"}],
                        "default": None,
                        "title": "Id",
                    },
                    "metadata": {"title": "Metadata", "type": "object"},
                    "page_content": {"title": "Page Content", "type": "string"},
                    "type": {
                        "const": "Document",
                        "default": "Document",
                        "title": "Type",
                    },
                },
                "required": ["page_content"],
                "title": "Document",
                "type": "object",
            }
        },
        "items": {"$ref": "#/$defs/Document"},
        "title": "FakeRetrieverOutput",
        "type": "array",
    }

    fake_llm = FakeListLLM(responses=["a"])  # str -> list[list[str]]

    assert _schema(fake_llm.input_schema) == snapshot(name="fake_llm_input_schema")
    assert _schema(fake_llm.output_schema) == {
        "title": "FakeListLLMOutput",
        "type": "string",
    }

    fake_chat = FakeListChatModel(responses=["a"])  # str -> list[list[str]]

    assert _schema(fake_chat.input_schema) == snapshot(name="fake_chat_input_schema")
    assert _schema(fake_chat.output_schema) == snapshot(name="fake_chat_output_schema")

    chat_prompt = ChatPromptTemplate.from_messages(
        [
            MessagesPlaceholder(variable_name="history"),
            ("human", "Hello, how are you?"),
        ]
    )

    assert _normalize_schema(chat_prompt.get_input_jsonschema()) == snapshot(
        name="chat_prompt_input_schema"
    )
    assert _normalize_schema(chat_prompt.get_output_jsonschema()) == snapshot(
        name="chat_prompt_output_schema"
    )

    prompt = PromptTemplate.from_template("Hello, {name}!")

    # Template variables become required string properties.
    assert prompt.get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {"name": {"title": "Name", "type": "string"}},
        "required": ["name"],
    }
    assert _schema(prompt.output_schema) == snapshot(name="prompt_output_schema")

    # `.map()` wraps the prompt in a RunnableEach: input becomes an array.
    prompt_mapper = PromptTemplate.from_template("Hello, {name}!").map()

    assert _normalize_schema(prompt_mapper.get_input_jsonschema()) == {
        "$defs": {
            "PromptInput": {
                "properties": {"name": {"title": "Name", "type": "string"}},
                "required": ["name"],
                "title": "PromptInput",
                "type": "object",
            }
        },
        "default": None,
        "items": {"$ref": "#/$defs/PromptInput"},
        "title": "RunnableEach<PromptTemplate>Input",
        "type": "array",
    }
    assert _schema(prompt_mapper.output_schema) == snapshot(
        name="prompt_mapper_output_schema"
    )

    list_parser = CommaSeparatedListOutputParser()

    assert _schema(list_parser.input_schema) == snapshot(
        name="list_parser_input_schema"
    )
    assert _schema(list_parser.output_schema) == {
        "title": "CommaSeparatedListOutputParserOutput",
        "type": "array",
        "items": {"type": "string"},
    }

    # A sequence takes its input schema from the first step and its output
    # schema from the last step.
    seq = prompt | fake_llm | list_parser

    assert seq.get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {"name": {"title": "Name", "type": "string"}},
        "required": ["name"],
    }
    assert seq.get_output_jsonschema() == {
        "type": "array",
        "items": {"type": "string"},
        "title": "CommaSeparatedListOutputParserOutput",
    }

    router: Runnable = RouterRunnable({})

    assert _schema(router.input_schema) == {
        "$ref": "#/definitions/RouterInput",
        "definitions": {
            "RouterInput": {
                "description": "Router input.",
                "properties": {
                    "input": {"title": "Input"},
                    "key": {"title": "Key", "type": "string"},
                },
                "required": ["key", "input"],
                "title": "RouterInput",
                "type": "object",
            }
        },
        "title": "RouterRunnableInput",
    }
    assert router.get_output_jsonschema() == {"title": "RouterRunnableOutput"}

    # A dict literal in a sequence is coerced to RunnableParallel; its output
    # schema has one property per key, typed from each branch's output.
    seq_w_map = (
        prompt
        | fake_llm
        | {
            "original": RunnablePassthrough(input_type=str),
            "as_list": list_parser,
            "length": typed_lambda_impl,
        }
    )

    assert seq_w_map.get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {"name": {"title": "Name", "type": "string"}},
        "required": ["name"],
    }
    assert seq_w_map.get_output_jsonschema() == {
        "title": "RunnableParallel<original,as_list,length>Output",
        "type": "object",
        "properties": {
            "original": {"title": "Original", "type": "string"},
            "length": {"title": "Length", "type": "integer"},
            "as_list": {
                "title": "As List",
                "type": "array",
                "items": {"type": "string"},
            },
        },
        "required": ["original", "as_list", "length"],
    }

    # Schema of a RunnableAssign output: the non-dict upstream output shows up
    # as `root`, next to each assigned key (untyped here).
    def foo(x: int) -> int:
        return x

    foo_ = RunnableLambda(foo)

    assert foo_.assign(bar=lambda _: "foo").get_output_jsonschema() == {
        "properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}},
        "required": ["root", "bar"],
        "title": "RunnableAssignOutput",
        "type": "object",
    }


def test_passthrough_assign_schema() -> None:
    """Check the input schema inferred for `RunnablePassthrough.assign` chains."""
    retriever = FakeRetriever()  # str -> list[Document]
    prompt = PromptTemplate.from_template("{context} {question}")
    fake_llm = FakeListLLM(responses=["a"])  # str -> list[list[str]]

    # `question` is typed as string here: it is taken from the next step's
    # (the prompt's) input schema, minus the assigned `context` key.
    seq_w_assign = (
        RunnablePassthrough.assign(context=itemgetter("question") | retriever)
        | prompt
        | fake_llm
    )

    assert seq_w_assign.get_input_jsonschema() == {
        "properties": {"question": {"title": "Question", "type": "string"}},
        "title": "RunnableSequenceInput",
        "type": "object",
        "required": ["question"],
    }
    assert seq_w_assign.get_output_jsonschema() == {
        "title": "FakeListLLMOutput",
        "type": "string",
    }

    invalid_seq_w_assign = (
        RunnablePassthrough.assign(context=itemgetter("question") | retriever)
        | fake_llm  # type: ignore[operator]
    )

    # fallback to RunnableAssign.input_schema if next runnable doesn't have
    # expected dict input_schema
    assert invalid_seq_w_assign.get_input_jsonschema() == {
        "properties": {"question": {"title": "Question"}},
        "title": "RunnableParallel<context>Input",
        "type": "object",
        "required": ["question"],
    }


def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
    """Check input schemas inferred from the keys a lambda reads from its input.

    Untyped lambdas/functions get one untyped, required property per key the
    first argument is subscripted with (or `.get`-ed); typed functions use
    their annotations instead.
    """
    first_lambda = lambda x: x["hello"]  # noqa: E731
    assert RunnableLambda(first_lambda).get_input_jsonschema() == {
        "title": "RunnableLambdaInput",
        "type": "object",
        "properties": {"hello": {"title": "Hello"}},
        "required": ["hello"],
    }

    # Only keys read from the first argument (`x`) are reported; `y["bah"]`
    # does not appear in the schema.
    second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"])  # noqa: E731
    assert RunnableLambda(second_lambda).get_input_jsonschema() == {
        "title": "RunnableLambdaInput",
        "type": "object",
        "properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
        "required": ["bye", "hello"],
    }

    def get_value(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
        return value["variable_name"]

    assert RunnableLambda(get_value).get_input_jsonschema() == {
        "title": "get_value_input",
        "type": "object",
        "properties": {"variable_name": {"title": "Variable Name"}},
        "required": ["variable_name"],
    }

    # Keys accessed with `.get(...)` are reported as required too.
    async def aget_value(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
        return (value["variable_name"], value.get("another"))

    assert RunnableLambda(aget_value).get_input_jsonschema() == {
        "title": "aget_value_input",
        "type": "object",
        "properties": {
            "another": {"title": "Another"},
            "variable_name": {"title": "Variable Name"},
        },
        "required": ["another", "variable_name"],
    }

    # Keys read several times appear once.
    async def aget_values(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
        return {
            "hello": value["variable_name"],
            "bye": value["variable_name"],
            "byebye": value["yo"],
        }

    assert RunnableLambda(aget_values).get_input_jsonschema() == {
        "title": "aget_values_input",
        "type": "object",
        "properties": {
            "variable_name": {"title": "Variable Name"},
            "yo": {"title": "Yo"},
        },
        "required": ["variable_name", "yo"],
    }

    class InputType(TypedDict):
        variable_name: str
        yo: int

    class OutputType(TypedDict):
        hello: str
        bye: str
        byebye: int

    async def aget_values_typed(value: InputType) -> OutputType:
        return {
            "hello": value["variable_name"],
            "bye": value["variable_name"],
            "byebye": value["yo"],
        }

    # With a TypedDict annotation the schema references the TypedDict and
    # carries the field types.
    assert _normalize_schema(
        RunnableLambda(aget_values_typed).get_input_jsonschema()
    ) == _normalize_schema(
        {
            "$defs": {
                "InputType": {
                    "properties": {
                        "variable_name": {
                            "title": "Variable Name",
                            "type": "string",
                        },
                        "yo": {"title": "Yo", "type": "integer"},
                    },
                    "required": ["variable_name", "yo"],
                    "title": "InputType",
                    "type": "object",
                }
            },
            "allOf": [{"$ref": "#/$defs/InputType"}],
            "title": "aget_values_typed_input",
        }
    )

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            RunnableLambda(aget_values_typed).get_output_jsonschema()
        ) == snapshot(name="schema8")


def test_with_types_with_type_generics() -> None:
    """Verify that with_types works if we use things like list[int]."""

    def foo(x: int) -> None:
        """Placeholder; never invoked, only wrapped by `with_types`."""
        raise NotImplementedError

    # Passing parameterized generics as input/output types must not raise.
    RunnableLambda(foo).with_types(
        output_type=list[int],  # type: ignore[arg-type]
        input_type=list[int],  # type: ignore[arg-type]
    )
    RunnableLambda(foo).with_types(
        output_type=Sequence[int],  # type: ignore[arg-type]
        input_type=Sequence[int],  # type: ignore[arg-type]
    )


def test_schema_with_itemgetter() -> None:
    """Test runnable with itemgetter."""
    # `itemgetter("hello")` is treated like a lambda reading key "hello".
    foo = RunnableLambda(itemgetter("hello"))
    assert _schema(foo.input_schema) == {
        "properties": {"hello": {"title": "Hello"}},
        "required": ["hello"],
        "title": "RunnableLambdaInput",
        "type": "object",
    }
    prompt = ChatPromptTemplate.from_template("what is {language}?")
    chain = {"language": itemgetter("language")} | prompt
    assert _schema(chain.input_schema) == {
        "properties": {"language": {"title": "Language"}},
        "required": ["language"],
        "title": "RunnableParallel<language>Input",
        "type": "object",
    }


def test_schema_complex_seq() -> None:
    """Check schemas of a nested sequence, with and without `with_types`."""
    prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?")
    prompt2 = ChatPromptTemplate.from_template(
        "what country is the city {city} in? respond in {language}"
    )

    model = FakeListChatModel(responses=[""])

    chain1: Runnable = RunnableSequence(
        prompt1, model, StrOutputParser(), name="city_chain"
    )

    assert chain1.name == "city_chain"

    chain2 = (
        {"city": chain1, "language": itemgetter("language")}
        | prompt2
        | model
        | StrOutputParser()
    )

    # The parallel step's input merges its branches' inputs: `person` (typed,
    # from chain1's prompt) and `language` (untyped, from the itemgetter).
    assert chain2.get_input_jsonschema() == {
        "title": "RunnableParallel<city,language>Input",
        "type": "object",
        "properties": {
            "person": {"title": "Person", "type": "string"},
            "language": {"title": "Language"},
        },
        "required": ["person", "language"],
    }

    assert chain2.get_output_jsonschema() == {
        "title": "StrOutputParserOutput",
        "type": "string",
    }

    # `with_types(input_type=...)` overrides only the input schema.
    assert chain2.with_types(input_type=str).get_input_jsonschema() == {
        "title": "RunnableSequenceInput",
        "type": "string",
    }

    assert chain2.with_types(input_type=int).get_output_jsonschema() == {
        "title": "StrOutputParserOutput",
        "type": "string",
    }

    class InputType(BaseModel):
        person: str

    assert chain2.with_types(input_type=InputType).get_input_jsonschema() == {
        "title": "InputType",
        "type": "object",
        "properties": {"person": {"title": "Person", "type": "string"}},
        "required": ["person"],
    }


def test_configurable_fields(snapshot: SnapshotAssertion) -> None:
    """Check `configurable_fields` on an LLM, a prompt, and chains of them."""
    fake_llm = FakeListLLM(responses=["a"])  # str -> list[list[str]]

    assert fake_llm.invoke("...") == "a"

    fake_llm_configurable = fake_llm.configurable_fields(
        responses=ConfigurableField(
            id="llm_responses",
            name="LLM Responses",
            description="A list of fake responses for this LLM",
        )
    )

    # Without a configurable override the original field value is used.
    assert fake_llm_configurable.invoke("...") == "a"

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            fake_llm_configurable.get_config_jsonschema()
        ) == snapshot(name="schema2")

    fake_llm_configured = fake_llm_configurable.with_config(
        configurable={"llm_responses": ["b"]}
    )

    assert fake_llm_configured.invoke("...") == "b"

    prompt = PromptTemplate.from_template("Hello, {name}!")

    assert prompt.invoke({"name": "John"}) == StringPromptValue(text="Hello, John!")

    prompt_configurable = prompt.configurable_fields(
        template=ConfigurableField(
            id="prompt_template",
            name="Prompt Template",
            description="The prompt template for this chain",
        )
    )

    assert prompt_configurable.invoke({"name": "John"}) == StringPromptValue(
        text="Hello, John!"
    )

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            prompt_configurable.get_config_jsonschema()
        ) == snapshot(name="schema3")

    prompt_configured = prompt_configurable.with_config(
        configurable={"prompt_template": "Hello, {name}! {name}!"}
    )

    assert prompt_configured.invoke({"name": "John"}) == StringPromptValue(
        text="Hello, John! John!"
    )

    # A configured template changes the inferred input schema as well.
    assert prompt_configurable.with_config(
        configurable={"prompt_template": "Hello {name} in {lang}"}
    ).get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {
            "lang": {"title": "Lang", "type": "string"},
            "name": {"title": "Name", "type": "string"},
        },
        "required": ["lang", "name"],
    }

    chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser()

    assert chain_configurable.invoke({"name": "John"}) == "a"

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            chain_configurable.get_config_jsonschema()
        ) == snapshot(name="schema4")

    # Configurable values set on the chain reach every nested step.
    assert (
        chain_configurable.with_config(
            configurable={
                "prompt_template": "A very good morning to you, {name} {lang}!",
                "llm_responses": ["c"],
            }
        ).invoke({"name": "John", "lang": "en"})
        == "c"
    )

    assert chain_configurable.with_config(
        configurable={
            "prompt_template": "A very good morning to you, {name} {lang}!",
            "llm_responses": ["c"],
        }
    ).get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {
            "lang": {"title": "Lang", "type": "string"},
            "name": {"title": "Name", "type": "string"},
        },
        "required": ["lang", "name"],
    }

    # llm1/llm2 share the "llm_responses" id; llm3 is configured separately
    # through "other_responses".
    chain_with_map_configurable = prompt_configurable | {
        "llm1": fake_llm_configurable | StrOutputParser(),
        "llm2": fake_llm_configurable | StrOutputParser(),
        "llm3": fake_llm.configurable_fields(
            responses=ConfigurableField("other_responses")
        )
        | StrOutputParser(),
    }

    assert chain_with_map_configurable.invoke({"name": "John"}) == {
        "llm1": "a",
        "llm2": "a",
        "llm3": "a",
    }

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            chain_with_map_configurable.get_config_jsonschema()
        ) == snapshot(name="schema5")

    assert chain_with_map_configurable.with_config(
        configurable={
            "prompt_template": "A very good morning to you, {name}!",
            "llm_responses": ["c"],
            "other_responses": ["d"],
        }
    ).invoke({"name": "John"}) == {"llm1": "c", "llm2": "c", "llm3": "d"}


def test_configurable_alts_factory() -> None:
    """An alternative may be given as a factory (`partial`) instead of an instance."""
    fake_llm = FakeListLLM(responses=["a"]).configurable_alternatives(
        ConfigurableField(id="llm", name="LLM"),
        chat=partial(FakeListLLM, responses=["b"]),
    )

    assert fake_llm.invoke("...") == "a"

    assert fake_llm.with_config(configurable={"llm": "chat"}).invoke("...") == "b"


def test_configurable_fields_prefix_keys(snapshot: SnapshotAssertion) -> None:
    """Snapshot the config schema when alternatives use `prefix_keys=True`.

    Both the chat model and the default LLM declare a field with id
    "responses"; the snapshot records how the schema names them.
    """
    fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
        responses=ConfigurableFieldMultiOption(
            id="responses",
            name="Chat Responses",
            options={
                "hello": "A good morning to you!",
                "bye": "See you later!",
                "helpful": "How can I help you?",
            },
            default=["hello", "bye"],
        ),
        # (sleep is a configurable field in FakeListChatModel)
        sleep=ConfigurableField(
            id="chat_sleep",
            is_shared=True,
        ),
    )
    fake_llm = (
        FakeListLLM(responses=["a"])
        .configurable_fields(
            responses=ConfigurableField(
                id="responses",
                name="LLM Responses",
                description="A list of fake responses for this LLM",
            )
        )
        .configurable_alternatives(
            ConfigurableField(id="llm", name="LLM"),
            chat=fake_chat | StrOutputParser(),
            prefix_keys=True,
        )
    )
    prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
        template=ConfigurableFieldSingleOption(
            id="prompt_template",
            name="Prompt Template",
            description="The prompt template for this chain",
            options={
                "hello": "Hello, {name}!",
                "good_morning": "A very good morning to you, {name}!",
            },
            default="hello",
        )
    )

    chain = prompt | fake_llm

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(_schema(chain.config_schema())) == snapshot(
            name="schema6"
        )


def test_configurable_fields_example(snapshot: SnapshotAssertion) -> None:
    """Combine single/multi-option fields with configurable alternatives."""
    fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
        responses=ConfigurableFieldMultiOption(
            id="chat_responses",
            name="Chat Responses",
            options={
                "hello": "A good morning to you!",
                "bye": "See you later!",
                "helpful": "How can I help you?",
            },
            default=["hello", "bye"],
        )
    )
    fake_llm = (
        FakeListLLM(responses=["a"])
        .configurable_fields(
            responses=ConfigurableField(
                id="llm_responses",
                name="LLM Responses",
                description="A list of fake responses for this LLM",
            )
        )
        .configurable_alternatives(
            ConfigurableField(id="llm", name="LLM"),
            chat=fake_chat | StrOutputParser(),
        )
    )

    prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
        template=ConfigurableFieldSingleOption(
            id="prompt_template",
            name="Prompt Template",
            description="The prompt template for this chain",
            options={
                "hello": "Hello, {name}!",
                "good_morning": "A very good morning to you, {name}!",
            },
            default="hello",
        )
    )

    # deduplication of configurable fields
    chain_configurable = prompt | fake_llm | (lambda x: {"name": x}) | prompt | fake_llm

    assert chain_configurable.invoke({"name": "John"}) == "a"

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            chain_configurable.get_config_jsonschema()
        ) == snapshot(name="schema7")

    # Selecting the "chat" alternative returns the first of the default
    # multi-option responses (["hello", "bye"]).
    assert (
        chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
            {"name": "John"}
        )
        == "A good morning to you!"
    )

    assert (
        chain_configurable.with_config(
            configurable={"llm": "chat", "chat_responses": ["helpful"]}
        ).invoke({"name": "John"})
        == "How can I help you?"
    )


def test_passthrough_tap(mocker: MockerFixture) -> None:
    """RunnablePassthrough(func) calls func as a side effect and passes input on.

    Extra kwargs given to invoke/batch/stream reach only the first
    passthrough; the trailing one is called with the upstream output alone.
    """
    fake = FakeRunnable()
    mock = mocker.Mock()

    seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock)

    assert seq.invoke("hello", my_kwarg="value") == 5
    assert mock.call_args_list == [
        mocker.call("hello", my_kwarg="value"),
        mocker.call(5),
    ]
    mock.reset_mock()

    # Batch call order is not asserted, only membership and count.
    assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6]
    assert len(mock.call_args_list) == 4
    for call in [
        mocker.call("hello", my_kwarg="value"),
        mocker.call("byebye", my_kwarg="value"),
        mocker.call(5),
        mocker.call(6),
    ]:
        assert call in mock.call_args_list
    mock.reset_mock()

    assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [
        5,
        6,
    ]
    assert len(mock.call_args_list) == 4
    for call in [
        mocker.call("hello", my_kwarg="value"),
        mocker.call("byebye", my_kwarg="value"),
        mocker.call(5),
        mocker.call(6),
    ]:
        assert call in mock.call_args_list
    mock.reset_mock()

    # batch_as_completed yields (index, output) pairs in completion order,
    # hence the sort.
    assert sorted(
        a
        for a in seq.batch_as_completed(
            ["hello", "byebye"], my_kwarg="value", return_exceptions=True
        )
    ) == [
        (0, 5),
        (1, 6),
    ]
    assert len(mock.call_args_list) == 4
    for call in [
        mocker.call("hello", my_kwarg="value"),
        mocker.call("byebye", my_kwarg="value"),
        mocker.call(5),
        mocker.call(6),
    ]:
        assert call in mock.call_args_list
    mock.reset_mock()

    # The second positional argument to stream is the config; it is not
    # forwarded to the tap function.
    assert list(
        seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value")
    ) == [5]
    assert mock.call_args_list == [
        mocker.call("hello", my_kwarg="value"),
        mocker.call(5),
    ]
    mock.reset_mock()


async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
    fake = FakeRunnable()
    mock = mocker.Mock()

    seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock)

    assert await seq.ainvoke("hello", my_kwarg="value") == 5
    assert mock.call_args_list == [
        mocker.call("hello", my_kwarg="value"),
        mocker.call(5),
    ]
    mock.reset_mock()

    assert await seq.abatch(["hello", "byebye"], my_kwarg="value") == [5, 6]
    assert len(mock.call_args_list) == 4
    for call in [
        mocker.call("hello", my_kwarg="value"),
        mocker.call("byebye", my_kwarg="value"),
        mocker.call(5),
mocker.call(6),1109    ]:1110        assert call in mock.call_args_list1111    mock.reset_mock()11121113    assert await seq.abatch(1114        ["hello", "byebye"], my_kwarg="value", return_exceptions=True1115    ) == [1116        5,1117        6,1118    ]1119    assert len(mock.call_args_list) == 41120    for call in [1121        mocker.call("hello", my_kwarg="value"),1122        mocker.call("byebye", my_kwarg="value"),1123        mocker.call(5),1124        mocker.call(6),1125    ]:1126        assert call in mock.call_args_list1127    mock.reset_mock()11281129    assert sorted(1130        [1131            a1132            async for a in seq.abatch_as_completed(1133                ["hello", "byebye"], my_kwarg="value", return_exceptions=True1134            )1135        ]1136    ) == [1137        (0, 5),1138        (1, 6),1139    ]1140    assert len(mock.call_args_list) == 41141    for call in [1142        mocker.call("hello", my_kwarg="value"),1143        mocker.call("byebye", my_kwarg="value"),1144        mocker.call(5),1145        mocker.call(6),1146    ]:1147        assert call in mock.call_args_list1148    mock.reset_mock()11491150    assert [1151        part1152        async for part in seq.astream(1153            "hello", {"metadata": {"key": "value"}}, my_kwarg="value"1154        )1155    ] == [5]1156    assert mock.call_args_list == [1157        mocker.call("hello", my_kwarg="value"),1158        mocker.call(5),1159    ]116011611162async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:1163    fake = FakeRunnableSerializable()1164    spy = mocker.spy(fake.__class__, "invoke")1165    fakew = fake.configurable_fields(hello=ConfigurableField(id="hello", name="Hello"))11661167    assert (1168        fakew.with_config(tags=["a-tag"]).invoke(1169            "hello",1170            {1171                "configurable": {"hello": "there", "__secret_key": "nahnah"},1172                "metadata": {"bye": "now"},1173            },1174        
)1175        == 51176    )1177    assert spy.call_args_list[0].args[1:] == (1178        "hello",1179        {1180            "tags": ["a-tag"],1181            "callbacks": None,1182            "recursion_limit": 25,1183            "configurable": {"hello": "there", "__secret_key": "nahnah"},1184            "metadata": {"bye": "now"},1185        },1186    )1187    spy.reset_mock()118811891190def test_with_config(mocker: MockerFixture) -> None:1191    fake = FakeRunnable()1192    spy = mocker.spy(fake, "invoke")11931194    assert fake.with_config(tags=["a-tag"]).invoke("hello") == 51195    assert spy.call_args_list == [1196        mocker.call(1197            "hello",1198            {"tags": ["a-tag"], "metadata": {}, "configurable": {}},1199        ),1200    ]1201    spy.reset_mock()12021203    fake_1 = RunnablePassthrough[Any]()1204    fake_2 = RunnablePassthrough[Any]()1205    spy_seq_step = mocker.spy(fake_1.__class__, "invoke")12061207    sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config(1208        tags=["b-tag"], max_concurrency=51209    )1210    assert sequence.invoke("hello") == "hello"1211    assert len(spy_seq_step.call_args_list) == 21212    for i, call in enumerate(spy_seq_step.call_args_list):1213        assert call.args[1] == "hello"1214        if i == 0:1215            assert call.args[2].get("tags") == ["a-tag"]1216            assert call.args[2].get("max_concurrency") is None1217        else:1218            assert call.args[2].get("tags") == ["b-tag"]1219            assert call.args[2].get("max_concurrency") == 51220    mocker.stop(spy_seq_step)12211222    assert [1223        *fake.with_config(tags=["a-tag"]).stream(1224            "hello", {"metadata": {"key": "value"}}1225        )1226    ] == [5]1227    assert spy.call_args_list == [1228        mocker.call(1229            "hello",1230            {"tags": ["a-tag"], "metadata": {"key": "value"}, "configurable": {}},1231        ),1232    ]1233    spy.reset_mock()12341235    assert 
fake.with_config(recursion_limit=5).batch(1236        ["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}]1237    ) == [5, 7]12381239    assert len(spy.call_args_list) == 21240    for i, call in enumerate(1241        sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1)1242    ):1243        assert call.args[0] == ("hello" if i == 0 else "wooorld")1244        if i == 0:1245            assert call.args[1].get("recursion_limit") == 51246            assert call.args[1].get("tags") == ["a-tag"]1247            assert call.args[1].get("metadata") == {}1248        else:1249            assert call.args[1].get("recursion_limit") == 51250            assert call.args[1].get("tags") == []1251            assert call.args[1].get("metadata") == {"key": "value"}12521253    spy.reset_mock()12541255    assert sorted(1256        c1257        for c in fake.with_config(recursion_limit=5).batch_as_completed(1258            ["hello", "wooorld"],1259            [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}],1260        )1261    ) == [(0, 5), (1, 7)]12621263    assert len(spy.call_args_list) == 21264    for i, call in enumerate(1265        sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1)1266    ):1267        assert call.args[0] == ("hello" if i == 0 else "wooorld")1268        if i == 0:1269            assert call.args[1].get("recursion_limit") == 51270            assert call.args[1].get("tags") == ["a-tag"]1271            assert call.args[1].get("metadata") == {}1272        else:1273            assert call.args[1].get("recursion_limit") == 51274            assert call.args[1].get("tags") == []1275            assert call.args[1].get("metadata") == {"key": "value"}12761277    spy.reset_mock()12781279    assert fake.with_config(metadata={"a": "b"}).batch(1280        ["hello", "wooorld"], {"tags": ["a-tag"]}1281    ) == [5, 7]1282    assert len(spy.call_args_list) == 21283    for i, call in 
enumerate(spy.call_args_list):1284        assert call.args[0] == ("hello" if i == 0 else "wooorld")1285        assert call.args[1].get("tags") == ["a-tag"]1286        assert call.args[1].get("metadata") == {"a": "b"}1287    spy.reset_mock()12881289    assert sorted(1290        c for c in fake.batch_as_completed(["hello", "wooorld"], {"tags": ["a-tag"]})1291    ) == [(0, 5), (1, 7)]1292    assert len(spy.call_args_list) == 21293    for i, call in enumerate(spy.call_args_list):1294        assert call.args[0] == ("hello" if i == 0 else "wooorld")1295        assert call.args[1].get("tags") == ["a-tag"]129612971298async def test_with_config_async(mocker: MockerFixture) -> None:1299    fake = FakeRunnable()1300    spy = mocker.spy(fake, "invoke")13011302    handler = ConsoleCallbackHandler()1303    assert (1304        await fake.with_config(metadata={"a": "b"}).ainvoke(1305            "hello", config={"callbacks": [handler]}1306        )1307        == 51308    )1309    assert spy.call_args_list == [1310        mocker.call(1311            "hello",1312            {1313                "callbacks": [handler],1314                "metadata": {"a": "b"},1315                "configurable": {},1316                "tags": [],1317            },1318        ),1319    ]1320    spy.reset_mock()13211322    assert [1323        part async for part in fake.with_config(metadata={"a": "b"}).astream("hello")1324    ] == [5]1325    assert spy.call_args_list == [1326        mocker.call("hello", {"metadata": {"a": "b"}, "tags": [], "configurable": {}}),1327    ]1328    spy.reset_mock()13291330    assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch(1331        ["hello", "wooorld"], {"metadata": {"key": "value"}}1332    ) == [1333        5,1334        7,1335    ]1336    assert sorted(spy.call_args_list) == [1337        mocker.call(1338            "hello",1339            {1340                "metadata": {"key": "value"},1341                "tags": ["c"],1342                
"callbacks": None,1343                "recursion_limit": 5,1344                "configurable": {},1345            },1346        ),1347        mocker.call(1348            "wooorld",1349            {1350                "metadata": {"key": "value"},1351                "tags": ["c"],1352                "callbacks": None,1353                "recursion_limit": 5,1354                "configurable": {},1355            },1356        ),1357    ]1358    spy.reset_mock()13591360    assert sorted(1361        [1362            c1363            async for c in fake.with_config(1364                recursion_limit=5, tags=["c"]1365            ).abatch_as_completed(["hello", "wooorld"], {"metadata": {"key": "value"}})1366        ]1367    ) == [1368        (0, 5),1369        (1, 7),1370    ]1371    assert len(spy.call_args_list) == 21372    first_call = next(call for call in spy.call_args_list if call.args[0] == "hello")1373    assert first_call == mocker.call(1374        "hello",1375        {1376            "metadata": {"key": "value"},1377            "tags": ["c"],1378            "callbacks": None,1379            "recursion_limit": 5,1380            "configurable": {},1381        },1382    )1383    second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")1384    assert second_call == mocker.call(1385        "wooorld",1386        {1387            "metadata": {"key": "value"},1388            "tags": ["c"],1389            "callbacks": None,1390            "recursion_limit": 5,1391            "configurable": {},1392        },1393    )139413951396def test_default_method_implementations(mocker: MockerFixture) -> None:1397    fake = FakeRunnable()1398    spy = mocker.spy(fake, "invoke")13991400    assert fake.invoke("hello", {"tags": ["a-tag"]}) == 51401    assert spy.call_args_list == [1402        mocker.call("hello", {"tags": ["a-tag"]}),1403    ]1404    spy.reset_mock()14051406    assert [*fake.stream("hello", {"metadata": {"key": "value"}})] == [5]1407    
assert spy.call_args_list == [1408        mocker.call("hello", {"metadata": {"key": "value"}}),1409    ]1410    spy.reset_mock()14111412    assert fake.batch(1413        ["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}]1414    ) == [5, 7]14151416    assert len(spy.call_args_list) == 21417    for call in spy.call_args_list:1418        call_arg = call.args[0]14191420        if call_arg == "hello":1421            assert call_arg == "hello"1422            assert call.args[1].get("tags") == ["a-tag"]1423            assert call.args[1].get("metadata") == {}1424        else:1425            assert call_arg == "wooorld"1426            assert call.args[1].get("tags") == []1427            assert call.args[1].get("metadata") == {"key": "value"}14281429    spy.reset_mock()14301431    assert fake.batch(["hello", "wooorld"], {"tags": ["a-tag"]}) == [5, 7]1432    assert len(spy.call_args_list) == 21433    assert {call.args[0] for call in spy.call_args_list} == {"hello", "wooorld"}1434    for call in spy.call_args_list:1435        assert call.args[1].get("tags") == ["a-tag"]1436        assert call.args[1].get("metadata") == {}143714381439async def test_default_method_implementations_async(mocker: MockerFixture) -> None:1440    fake = FakeRunnable()1441    spy = mocker.spy(fake, "invoke")14421443    assert await fake.ainvoke("hello", config={"callbacks": []}) == 51444    assert spy.call_args_list == [1445        mocker.call("hello", {"callbacks": []}),1446    ]1447    spy.reset_mock()14481449    assert [part async for part in fake.astream("hello")] == [5]1450    assert spy.call_args_list == [1451        mocker.call("hello", None),1452    ]1453    spy.reset_mock()14541455    assert await fake.abatch(["hello", "wooorld"], {"metadata": {"key": "value"}}) == [1456        5,1457        7,1458    ]1459    assert {call.args[0] for call in spy.call_args_list} == {"hello", "wooorld"}1460    for call in spy.call_args_list:1461        assert call.args[1] == {1462      
      "metadata": {"key": "value"},1463            "tags": [],1464            "callbacks": None,1465            "recursion_limit": 25,1466            "configurable": {},1467        }146814691470def test_prompt() -> None:1471    prompt = ChatPromptTemplate.from_messages(1472        messages=[1473            SystemMessage(content="You are a nice assistant."),1474            HumanMessagePromptTemplate.from_template("{question}"),1475        ]1476    )1477    expected = ChatPromptValue(1478        messages=[1479            SystemMessage(content="You are a nice assistant."),1480            HumanMessage(content="What is your name?"),1481        ]1482    )14831484    assert prompt.invoke({"question": "What is your name?"}) == expected14851486    assert prompt.batch(1487        [1488            {"question": "What is your name?"},1489            {"question": "What is your favorite color?"},1490        ]1491    ) == [1492        expected,1493        ChatPromptValue(1494            messages=[1495                SystemMessage(content="You are a nice assistant."),1496                HumanMessage(content="What is your favorite color?"),1497            ]1498        ),1499    ]15001501    assert [*prompt.stream({"question": "What is your name?"})] == [expected]150215031504async def test_prompt_async() -> None:1505    prompt = ChatPromptTemplate.from_messages(1506        messages=[1507            SystemMessage(content="You are a nice assistant."),1508            HumanMessagePromptTemplate.from_template("{question}"),1509        ]1510    )1511    expected = ChatPromptValue(1512        messages=[1513            SystemMessage(content="You are a nice assistant."),1514            HumanMessage(content="What is your name?"),1515        ]1516    )15171518    assert await prompt.ainvoke({"question": "What is your name?"}) == expected15191520    assert await prompt.abatch(1521        [1522            {"question": "What is your name?"},1523            {"question": "What is your favorite 
color?"},1524        ]1525    ) == [1526        expected,1527        ChatPromptValue(1528            messages=[1529                SystemMessage(content="You are a nice assistant."),1530                HumanMessage(content="What is your favorite color?"),1531            ]1532        ),1533    ]15341535    assert [1536        part async for part in prompt.astream({"question": "What is your name?"})1537    ] == [expected]15381539    stream_log = [1540        part async for part in prompt.astream_log({"question": "What is your name?"})1541    ]15421543    assert len(stream_log[0].ops) == 11544    assert stream_log[0].ops[0]["op"] == "replace"1545    assert stream_log[0].ops[0]["path"] == ""1546    assert stream_log[0].ops[0]["value"]["logs"] == {}1547    assert stream_log[0].ops[0]["value"]["final_output"] is None1548    assert stream_log[0].ops[0]["value"]["streamed_output"] == []1549    assert isinstance(stream_log[0].ops[0]["value"]["id"], str)15501551    assert stream_log[1:] == [1552        RunLogPatch(1553            {"op": "add", "path": "/streamed_output/-", "value": expected},1554            {1555                "op": "replace",1556                "path": "/final_output",1557                "value": ChatPromptValue(1558                    messages=[1559                        SystemMessage(content="You are a nice assistant."),1560                        HumanMessage(content="What is your name?"),1561                    ]1562                ),1563            },1564        ),1565    ]15661567    stream_log_state = [1568        part1569        async for part in prompt.astream_log(1570            {"question": "What is your name?"}, diff=False1571        )1572    ]15731574    # remove random id1575    stream_log[0].ops[0]["value"]["id"] = "00000000-0000-0000-0000-000000000000"1576    stream_log_state[-1].ops[0]["value"]["id"] = "00000000-0000-0000-0000-000000000000"1577    stream_log_state[-1].state["id"] = "00000000-0000-0000-0000-000000000000"15781579    # 
assert output with diff=False matches output with diff=True1580    assert stream_log_state[-1].ops == [op for chunk in stream_log for op in chunk.ops]1581    assert stream_log_state[-1] == RunLog(1582        *[op for chunk in stream_log for op in chunk.ops],1583        state={1584            "final_output": ChatPromptValue(1585                messages=[1586                    SystemMessage(content="You are a nice assistant."),1587                    HumanMessage(content="What is your name?"),1588                ]1589            ),1590            "id": "00000000-0000-0000-0000-000000000000",1591            "logs": {},1592            "streamed_output": [1593                ChatPromptValue(1594                    messages=[1595                        SystemMessage(content="You are a nice assistant."),1596                        HumanMessage(content="What is your name?"),1597                    ]1598                )1599            ],1600            "type": "prompt",1601            "name": "ChatPromptTemplate",1602        },1603    )16041605    # nested inside trace_with_chain_group16061607    async with atrace_as_chain_group("a_group") as manager:1608        stream_log_nested = [1609            part1610            async for part in prompt.astream_log(1611                {"question": "What is your name?"}, config={"callbacks": manager}1612            )1613        ]16141615    assert len(stream_log_nested[0].ops) == 11616    assert stream_log_nested[0].ops[0]["op"] == "replace"1617    assert stream_log_nested[0].ops[0]["path"] == ""1618    assert stream_log_nested[0].ops[0]["value"]["logs"] == {}1619    assert stream_log_nested[0].ops[0]["value"]["final_output"] is None1620    assert stream_log_nested[0].ops[0]["value"]["streamed_output"] == []1621    assert isinstance(stream_log_nested[0].ops[0]["value"]["id"], str)16221623    assert stream_log_nested[1:] == [1624        RunLogPatch(1625            {"op": "add", "path": "/streamed_output/-", "value": expected},1626     
       {1627                "op": "replace",1628                "path": "/final_output",1629                "value": ChatPromptValue(1630                    messages=[1631                        SystemMessage(content="You are a nice assistant."),1632                        HumanMessage(content="What is your name?"),1633                    ]1634                ),1635            },1636        ),1637    ]163816391640def test_prompt_template_params() -> None:1641    prompt = ChatPromptTemplate.from_template(1642        "Respond to the following question: {question}"1643    )1644    result = prompt.invoke(1645        {1646            "question": "test",1647            "topic": "test",1648        }1649    )1650    assert result == ChatPromptValue(1651        messages=[HumanMessage(content="Respond to the following question: test")]1652    )16531654    with pytest.raises(KeyError):1655        prompt.invoke({})165616571658def test_with_listeners(mocker: MockerFixture) -> None:1659    prompt = (1660        SystemMessagePromptTemplate.from_template("You are a nice assistant.")1661        + "{question}"1662    )1663    chat = FakeListChatModel(responses=["foo"])16641665    chain = prompt | chat16661667    mock_start = mocker.Mock()1668    mock_end = mocker.Mock()16691670    chain.with_listeners(on_start=mock_start, on_end=mock_end).invoke(1671        {"question": "Who are you?"}1672    )16731674    assert mock_start.call_count == 11675    assert mock_start.call_args[0][0].name == "RunnableSequence"1676    assert mock_end.call_count == 116771678    mock_start.reset_mock()1679    mock_end.reset_mock()16801681    with trace_as_chain_group("hello") as manager:1682        chain.with_listeners(on_start=mock_start, on_end=mock_end).invoke(1683            {"question": "Who are you?"}, {"callbacks": manager}1684        )16851686    assert mock_start.call_count == 11687    assert mock_start.call_args[0][0].name == "RunnableSequence"1688    assert mock_end.call_count == 
1168916901691async def test_with_listeners_async(mocker: MockerFixture) -> None:1692    prompt = (1693        SystemMessagePromptTemplate.from_template("You are a nice assistant.")1694        + "{question}"1695    )1696    chat = FakeListChatModel(responses=["foo"])16971698    chain = prompt | chat16991700    mock_start = mocker.Mock()1701    mock_end = mocker.Mock()17021703    await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(1704        {"question": "Who are you?"}1705    )17061707    assert mock_start.call_count == 11708    assert mock_start.call_args[0][0].name == "RunnableSequence"1709    assert mock_end.call_count == 117101711    mock_start.reset_mock()1712    mock_end.reset_mock()17131714    async with atrace_as_chain_group("hello") as manager:1715        await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(1716            {"question": "Who are you?"}, {"callbacks": manager}1717        )17181719    assert mock_start.call_count == 11720    assert mock_start.call_args[0][0].name == "RunnableSequence"1721    assert mock_end.call_count == 1172217231724def test_with_listener_propagation(mocker: MockerFixture) -> None:1725    prompt = (1726        SystemMessagePromptTemplate.from_template("You are a nice assistant.")1727        + "{question}"1728    )1729    chat = FakeListChatModel(responses=["foo"])1730    chain = prompt | chat1731    mock_start = mocker.Mock()1732    mock_end = mocker.Mock()1733    chain_with_listeners = chain.with_listeners(on_start=mock_start, on_end=mock_end)17341735    chain_with_listeners.with_retry().invoke({"question": "Who are you?"})17361737    assert mock_start.call_count == 11738    assert mock_start.call_args[0][0].name == "RunnableSequence"1739    assert mock_end.call_count == 117401741    mock_start.reset_mock()1742    mock_end.reset_mock()17431744    chain_with_listeners.invoke({"question": "Who are you?"})17451746    assert mock_start.call_count == 11747    assert 
mock_start.call_args[0][0].name == "RunnableSequence"1748    assert mock_end.call_count == 117491750    mock_start.reset_mock()1751    mock_end.reset_mock()17521753    chain_with_listeners.with_config({"tags": ["foo"]}).invoke(1754        {"question": "Who are you?"}1755    )17561757    assert mock_start.call_count == 11758    assert mock_start.call_args[0][0].name == "RunnableSequence"1759    assert mock_end.call_count == 117601761    mock_start.reset_mock()1762    mock_end.reset_mock()17631764    chain_with_listeners.bind(stop=["foo"]).invoke({"question": "Who are you?"})17651766    assert mock_start.call_count == 11767    assert mock_start.call_args[0][0].name == "RunnableSequence"1768    assert mock_end.call_count == 117691770    mock_start.reset_mock()1771    mock_end.reset_mock()17721773    mock_start_inner = mocker.Mock()1774    mock_end_inner = mocker.Mock()17751776    chain_with_listeners.with_listeners(1777        on_start=mock_start_inner, on_end=mock_end_inner1778    ).invoke({"question": "Who are you?"})17791780    assert mock_start.call_count == 11781    assert mock_start.call_args[0][0].name == "RunnableSequence"1782    assert mock_end.call_count == 11783    assert mock_start_inner.call_count == 11784    assert mock_start_inner.call_args[0][0].name == "RunnableSequence"1785    assert mock_end_inner.call_count == 1178617871788@freeze_time("2023-01-01")1789@pytest.mark.usefixtures("deterministic_uuids")1790def test_prompt_with_chat_model(1791    mocker: MockerFixture,1792    snapshot: SnapshotAssertion,1793) -> None:1794    prompt = (1795        SystemMessagePromptTemplate.from_template("You are a nice assistant.")1796        + "{question}"1797    )1798    chat = FakeListChatModel(responses=["foo"])17991800    chain = prompt | chat18011802    assert repr(chain) == snapshot1803    assert isinstance(chain, RunnableSequence)1804    assert chain.first == prompt1805    assert chain.middle == []1806    assert chain.last == chat1807    assert dumps(chain, 
pretty=True) == snapshot18081809    # Test invoke1810    prompt_spy = mocker.spy(prompt.__class__, "invoke")1811    chat_spy = mocker.spy(chat.__class__, "invoke")1812    tracer = FakeTracer()1813    assert chain.invoke(1814        {"question": "What is your name?"}, {"callbacks": [tracer]}1815    ) == _any_id_ai_message(content="foo")1816    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1817    assert chat_spy.call_args.args[1] == ChatPromptValue(1818        messages=[1819            SystemMessage(content="You are a nice assistant."),1820            HumanMessage(content="What is your name?"),1821        ]1822    )18231824    assert tracer.runs == snapshot18251826    mocker.stop(prompt_spy)1827    mocker.stop(chat_spy)18281829    # Test batch1830    prompt_spy = mocker.spy(prompt.__class__, "batch")1831    chat_spy = mocker.spy(chat.__class__, "batch")1832    tracer = FakeTracer()1833    assert chain.batch(1834        [1835            {"question": "What is your name?"},1836            {"question": "What is your favorite color?"},1837        ],1838        {"callbacks": [tracer]},1839    ) == [1840        _any_id_ai_message(content="foo"),1841        _any_id_ai_message(content="foo"),1842    ]1843    assert prompt_spy.call_args.args[1] == [1844        {"question": "What is your name?"},1845        {"question": "What is your favorite color?"},1846    ]1847    assert chat_spy.call_args.args[1] == [1848        ChatPromptValue(1849            messages=[1850                SystemMessage(content="You are a nice assistant."),1851                HumanMessage(content="What is your name?"),1852            ]1853        ),1854        ChatPromptValue(1855            messages=[1856                SystemMessage(content="You are a nice assistant."),1857                HumanMessage(content="What is your favorite color?"),1858            ]1859        ),1860    ]1861    assert (1862        len(1863            [1864                r1865                for r in 
tracer.runs1866                if r.parent_run_id is None and len(r.child_runs) == 21867            ]1868        )1869        == 21870    ), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"1871    mocker.stop(prompt_spy)1872    mocker.stop(chat_spy)18731874    # Test stream1875    prompt_spy = mocker.spy(prompt.__class__, "invoke")1876    chat_spy = mocker.spy(chat.__class__, "stream")1877    tracer = FakeTracer()1878    assert [1879        *chain.stream({"question": "What is your name?"}, {"callbacks": [tracer]})1880    ] == [1881        _any_id_ai_message_chunk(content="f"),1882        _any_id_ai_message_chunk(content="o"),1883        _any_id_ai_message_chunk(content="o", chunk_position="last"),1884    ]1885    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1886    assert chat_spy.call_args.args[1] == ChatPromptValue(1887        messages=[1888            SystemMessage(content="You are a nice assistant."),1889            HumanMessage(content="What is your name?"),1890        ]1891    )189218931894@freeze_time("2023-01-01")1895@pytest.mark.usefixtures("deterministic_uuids")1896async def test_prompt_with_chat_model_async(1897    mocker: MockerFixture,1898    snapshot: SnapshotAssertion,1899) -> None:1900    prompt = (1901        SystemMessagePromptTemplate.from_template("You are a nice assistant.")1902        + "{question}"1903    )1904    chat = FakeListChatModel(responses=["foo"])19051906    chain = prompt | chat19071908    assert repr(chain) == snapshot1909    assert isinstance(chain, RunnableSequence)1910    assert chain.first == prompt1911    assert chain.middle == []1912    assert chain.last == chat1913    assert dumps(chain, pretty=True) == snapshot19141915    # Test invoke1916    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")1917    chat_spy = mocker.spy(chat.__class__, "ainvoke")1918    tracer = FakeTracer()1919    assert await chain.ainvoke(1920        {"question": "What is your name?"}, 
{"callbacks": [tracer]}1921    ) == _any_id_ai_message(content="foo")1922    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1923    assert chat_spy.call_args.args[1] == ChatPromptValue(1924        messages=[1925            SystemMessage(content="You are a nice assistant."),1926            HumanMessage(content="What is your name?"),1927        ]1928    )19291930    assert tracer.runs == snapshot19311932    mocker.stop(prompt_spy)1933    mocker.stop(chat_spy)19341935    # Test batch1936    prompt_spy = mocker.spy(prompt.__class__, "abatch")1937    chat_spy = mocker.spy(chat.__class__, "abatch")1938    tracer = FakeTracer()1939    assert await chain.abatch(1940        [1941            {"question": "What is your name?"},1942            {"question": "What is your favorite color?"},1943        ],1944        {"callbacks": [tracer]},1945    ) == [1946        _any_id_ai_message(content="foo"),1947        _any_id_ai_message(content="foo"),1948    ]1949    assert prompt_spy.call_args.args[1] == [1950        {"question": "What is your name?"},1951        {"question": "What is your favorite color?"},1952    ]1953    assert chat_spy.call_args.args[1] == [1954        ChatPromptValue(1955            messages=[1956                SystemMessage(content="You are a nice assistant."),1957                HumanMessage(content="What is your name?"),1958            ]1959        ),1960        ChatPromptValue(1961            messages=[1962                SystemMessage(content="You are a nice assistant."),1963                HumanMessage(content="What is your favorite color?"),1964            ]1965        ),1966    ]1967    assert (1968        len(1969            [1970                r1971                for r in tracer.runs1972                if r.parent_run_id is None and len(r.child_runs) == 21973            ]1974        )1975        == 21976    ), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"1977    mocker.stop(prompt_spy)1978    
mocker.stop(chat_spy)19791980    # Test stream1981    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")1982    chat_spy = mocker.spy(chat.__class__, "astream")1983    tracer = FakeTracer()1984    assert [1985        a1986        async for a in chain.astream(1987            {"question": "What is your name?"}, {"callbacks": [tracer]}1988        )1989    ] == [1990        _any_id_ai_message_chunk(content="f"),1991        _any_id_ai_message_chunk(content="o"),1992        _any_id_ai_message_chunk(content="o", chunk_position="last"),1993    ]1994    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1995    assert chat_spy.call_args.args[1] == ChatPromptValue(1996        messages=[1997            SystemMessage(content="You are a nice assistant."),1998            HumanMessage(content="What is your name?"),1999        ]2000    )

Findings

✓ No findings reported for this file.

Get this view in your editor

Same data, no extra tab: call `code_get_file` and `code_get_findings` over MCP from Claude, Cursor, or Copilot.