import asyncio
import re
import sys
import time
import uuid
import warnings
from collections.abc import (
    AsyncIterator,
    Awaitable,
    Callable,
    Iterator,
    Sequence,
)
from functools import partial
from operator import itemgetter
from typing import Any, cast
from uuid import UUID

import pytest
from freezegun import freeze_time
from packaging import version
from pydantic import BaseModel, Field
from pytest_mock import MockerFixture
from syrupy.assertion import SnapshotAssertion
from typing_extensions import TypedDict, override

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
    atrace_as_chain_group,
    trace_as_chain_group,
)
from langchain_core.documents import Document
from langchain_core.language_models import (
    FakeListChatModel,
    FakeListLLM,
    FakeStreamingListLLM,
)
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.load import dumpd, dumps
from langchain_core.load.load import loads
from langchain_core.messages import AIMessageChunk, HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.output_parsers import (
    BaseOutputParser,
    CommaSeparatedListOutputParser,
    StrOutputParser,
)
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
    AddableDict,
    ConfigurableField,
    ConfigurableFieldMultiOption,
    ConfigurableFieldSingleOption,
    RouterRunnable,
    Runnable,
    RunnableAssign,
    RunnableBinding,
    RunnableBranch,
    RunnableConfig,
    RunnableGenerator,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
    RunnablePick,
    RunnableSequence,
    add,
    chain,
)
from langchain_core.runnables.base import RunnableMap, RunnableSerializable
from langchain_core.runnables.utils import Input, Output
from langchain_core.tools import BaseTool, tool
from langchain_core.tracers import (
    BaseTracer,
    ConsoleCallbackHandler,
    Run,
    RunLog,
    RunLogPatch,
)
from langchain_core.tracers._compat import pydantic_copy
from langchain_core.tracers.context import collect_runs
from langchain_core.utils.pydantic import PYDANTIC_VERSION
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk

PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION
PYDANTIC_VERSION_AT_LEAST_210 = version.parse("2.10") <= PYDANTIC_VERSION


class FakeTracer(BaseTracer):
    """Tracer test double that keeps every persisted run in memory.

    Run IDs are swapped for a deterministic UUID sequence so that the
    recorded runs can be compared against snapshots.
    """

    def __init__(self) -> None:
        """Create the empty run store and the deterministic UUID source."""
        super().__init__()
        self.runs: list[Run] = []
        # Maps each original UUID to the deterministic UUID that replaced it.
        self.uuids_map: dict[UUID, UUID] = {}
        self.uuids_generator = (
            UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
        )

    def _replace_uuid(self, uuid: UUID) -> UUID:
        """Return the stand-in for ``uuid``, allocating one on first sight."""
        try:
            return self.uuids_map[uuid]
        except KeyError:
            replacement = next(self.uuids_generator)
            self.uuids_map[uuid] = replacement
            return replacement

    def _replace_message_id(self, maybe_message: Any) -> Any:
        """Overwrite message IDs inside ``maybe_message`` and return it.

        A ``BaseMessage`` (bare, or held by a ``ChatGeneration``) gets the next
        generated UUID as its ``id``. ``LLMResult`` generations, dict values
        and list items are visited recursively. The argument is mutated in
        place.
        """
        if isinstance(maybe_message, BaseMessage):
            maybe_message.id = str(next(self.uuids_generator))
        if isinstance(maybe_message, ChatGeneration):
            maybe_message.message.id = str(next(self.uuids_generator))
        if isinstance(maybe_message, LLMResult):
            for outer, generations in enumerate(maybe_message.generations):
                for inner, generation in enumerate(generations):
                    maybe_message.generations[outer][inner] = (
                        self._replace_message_id(generation)
                    )
        if isinstance(maybe_message, dict):
            for key, value in maybe_message.items():
                maybe_message[key] = self._replace_message_id(value)
        if isinstance(maybe_message, list):
            for position, item in enumerate(maybe_message):
                maybe_message[position] = self._replace_message_id(item)

        return maybe_message

    def _copy_run(self, run: Run) -> Run:
        """Return a copy of ``run`` (and its children) with deterministic IDs."""
        new_dotted_order = None
        if run.dotted_order:
            # Each dotted-order level is "<timestamp>Z<run id>"; only the run
            # id part is rewritten.
            rewritten = []
            for level in run.dotted_order.split("."):
                timestamp, run_id = level.split("Z")
                rewritten.append(f"{timestamp}Z{self._replace_uuid(UUID(run_id))}")
            new_dotted_order = ".".join(rewritten)

        # Entries are evaluated top to bottom; that order decides which
        # deterministic UUID each original ID receives.
        update_dict = {
            "id": self._replace_uuid(run.id),
            "parent_run_id": (
                self.uuids_map[run.parent_run_id] if run.parent_run_id else None
            ),
            "child_runs": [self._copy_run(child) for child in run.child_runs],
            "trace_id": self._replace_uuid(run.trace_id) if run.trace_id else None,
            "dotted_order": new_dotted_order,
            "inputs": self._replace_message_id(run.inputs),
            "outputs": self._replace_message_id(run.outputs),
        }
        return pydantic_copy(run, update=update_dict)

    def _persist_run(self, run: Run) -> None:
        """Store a sanitized copy of ``run``."""
        sanitized = self._copy_run(run)
        self.runs.append(sanitized)

    def flattened_runs(self) -> list[Run]:
        """Return the stored runs and all of their descendants as one list."""
        pending = list(self.runs)
        flat = []
        while pending:
            current = pending.pop()
            flat.append(current)
            if current.child_runs:
                pending.extend(current.child_runs)
        return flat

    @property
    def run_ids(self) -> list[uuid.UUID | None]:
        """Original run IDs of the flattened runs (``None`` when unknown)."""
        original_by_replacement = {
            replacement: original for original, replacement in self.uuids_map.items()
        }
        return [original_by_replacement.get(run.id) for run in self.flattened_runs()]


# Minimal runnable used by the tests: maps a string to its length.
class FakeRunnable(Runnable[str, int]):
    @override
    def invoke(
        self,
        input: str,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> int:
        return len(input)
# Serializable counterpart of the fake runnable: maps a string to its length
# and declares a single field, ``hello``.
class FakeRunnableSerializable(RunnableSerializable[str, int]):
    hello: str = ""

    @override
    def invoke(
        self,
        input: str,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> int:
        return len(input)


# Retriever stub: always returns the same two documents, sync and async.
class FakeRetriever(BaseRetriever):
    @override
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        return [Document(page_content="foo"), Document(page_content="bar")]

    @override
    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> list[Document]:
        return [Document(page_content="foo"), Document(page_content="bar")]


# Skip on pydantic older than 2.10: as the reason states, the expected schemas
# below are only meant to be checked against the most recent pydantic.
@pytest.mark.skipif(
    not PYDANTIC_VERSION_AT_LEAST_210,
    reason=(
        "Only test with most recent version of pydantic. "
        "Pydantic introduced small fixes to generated JSONSchema on minor versions."
    ),
)
def test_schemas(snapshot: SnapshotAssertion) -> None:
    """Check input/output/config JSON schemas of the basic runnable kinds."""
    fake = FakeRunnable()  # str -> int

    assert fake.get_input_jsonschema() == {
        "title": "FakeRunnableInput",
        "type": "string",
    }
    assert fake.get_output_jsonschema() == {
        "title": "FakeRunnableOutput",
        "type": "integer",
    }
    assert fake.get_config_jsonschema(include=["tags", "metadata", "run_name"]) == {
        "properties": {
            "metadata": {
                "default": None,
                "title": "Metadata",
                "type": "object",
            },
            "run_name": {"default": None, "title": "Run Name", "type": "string"},
            "tags": {
                "default": None,
                "items": {"type": "string"},
                "title": "Tags",
                "type": "array",
            },
        },
        "title": "FakeRunnableConfig",
        "type": "object",
    }

    fake_bound = FakeRunnable().bind(a="b")  # str -> int

    assert fake_bound.get_input_jsonschema() == {
        "title": "FakeRunnableInput",
        "type": "string",
    }
    assert fake_bound.get_output_jsonschema() == {
        "title": "FakeRunnableOutput",
        "type": "integer",
    }

    fake_w_fallbacks = FakeRunnable().with_fallbacks((fake,))  # str -> int

    assert fake_w_fallbacks.get_input_jsonschema() == {
        "title": "FakeRunnableInput",
        "type": "string",
    }
    assert fake_w_fallbacks.get_output_jsonschema() == {
        "title": "FakeRunnableOutput",
        "type": "integer",
    }

    def typed_lambda_impl(x: str) -> int:
        return len(x)

    typed_lambda = RunnableLambda(typed_lambda_impl)  # str -> int

    assert typed_lambda.get_input_jsonschema() == {
        "title": "typed_lambda_impl_input",
        "type": "string",
    }
    assert typed_lambda.get_output_jsonschema() == {
        "title": "typed_lambda_impl_output",
        "type": "integer",
    }

    async def typed_async_lambda_impl(x: str) -> int:
        return len(x)

    typed_async_lambda = RunnableLambda(typed_async_lambda_impl)  # str -> int

    assert typed_async_lambda.get_input_jsonschema() == {
        "title": "typed_async_lambda_impl_input",
        "type": "string",
    }
    assert typed_async_lambda.get_output_jsonschema() == {
        "title": "typed_async_lambda_impl_output",
        "type": "integer",
    }

    fake_ret = FakeRetriever()  # str -> list[Document]

    assert fake_ret.get_input_jsonschema() == {
        "title": "FakeRetrieverInput",
        "type": "string",
    }
    assert _normalize_schema(fake_ret.get_output_jsonschema()) == {
        "$defs": {
            "Document": {
                "description": "Class for storing a piece of text and "
                "associated metadata.\n"
                "\n"
                "!!! note\n"
                "\n"
                " `Document` is for **retrieval workflows**, not chat I/O. For "
                "sending text\n"
                " to an LLM in a conversation, use message types from "
                "`langchain.messages`.\n"
                "\n"
                "Example:\n"
                " ```python\n"
                " from langchain_core.documents import Document\n"
                "\n"
                " document = Document(\n"
                ' page_content="Hello, world!", '
                'metadata={"source": "https://example.com"}\n'
                " )\n"
                " ```",
                "properties": {
                    "id": {
                        "anyOf": [{"type": "string"}, {"type": "null"}],
                        "default": None,
                        "title": "Id",
                    },
                    "metadata": {"title": "Metadata", "type": "object"},
                    "page_content": {"title": "Page Content", "type": "string"},
                    "type": {
                        "const": "Document",
                        "default": "Document",
                        "title": "Type",
                    },
                },
                "required": ["page_content"],
                "title": "Document",
                "type": "object",
            }
        },
        "items": {"$ref": "#/$defs/Document"},
        "title": "FakeRetrieverOutput",
        "type": "array",
    }

    fake_llm = FakeListLLM(responses=["a"])  # str -> list[list[str]]

    assert _schema(fake_llm.input_schema) == snapshot(name="fake_llm_input_schema")
    assert _schema(fake_llm.output_schema) == {
        "title": "FakeListLLMOutput",
        "type": "string",
    }

    fake_chat = FakeListChatModel(responses=["a"])  # str -> list[list[str]]

    assert _schema(fake_chat.input_schema) == snapshot(name="fake_chat_input_schema")
    assert _schema(fake_chat.output_schema) == snapshot(name="fake_chat_output_schema")

    chat_prompt = ChatPromptTemplate.from_messages(
        [
            MessagesPlaceholder(variable_name="history"),
            ("human", "Hello, how are you?"),
        ]
    )

    assert _normalize_schema(chat_prompt.get_input_jsonschema()) == snapshot(
        name="chat_prompt_input_schema"
    )
    assert _normalize_schema(chat_prompt.get_output_jsonschema()) == snapshot(
        name="chat_prompt_output_schema"
    )

    prompt = PromptTemplate.from_template("Hello, {name}!")

    assert prompt.get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {"name": {"title": "Name", "type": "string"}},
        "required": ["name"],
    }
    assert _schema(prompt.output_schema) == snapshot(name="prompt_output_schema")

    prompt_mapper = PromptTemplate.from_template("Hello, {name}!").map()

    assert _normalize_schema(prompt_mapper.get_input_jsonschema()) == {
        "$defs": {
            "PromptInput": {
                "properties": {"name": {"title": "Name", "type": "string"}},
                "required": ["name"],
                "title": "PromptInput",
                "type": "object",
            }
        },
        "default": None,
        "items": {"$ref": "#/$defs/PromptInput"},
        "title": "RunnableEach<PromptTemplate>Input",
        "type": "array",
    }
    assert _schema(prompt_mapper.output_schema) == snapshot(
        name="prompt_mapper_output_schema"
    )

    list_parser = CommaSeparatedListOutputParser()

    assert _schema(list_parser.input_schema) == snapshot(
        name="list_parser_input_schema"
    )
    assert _schema(list_parser.output_schema) == {
        "title": "CommaSeparatedListOutputParserOutput",
        "type": "array",
        "items": {"type": "string"},
    }

    seq = prompt | fake_llm | list_parser

    assert seq.get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {"name": {"title": "Name", "type": "string"}},
        "required": ["name"],
    }
    assert seq.get_output_jsonschema() == {
        "type": "array",
        "items": {"type": "string"},
        "title": "CommaSeparatedListOutputParserOutput",
    }

    router: Runnable = RouterRunnable({})

    assert _schema(router.input_schema) == {
        "$ref": "#/definitions/RouterInput",
        "definitions": {
            "RouterInput": {
                "description": "Router input.",
                "properties": {
                    "input": {"title": "Input"},
                    "key": {"title": "Key", "type": "string"},
                },
                "required": ["key", "input"],
                "title": "RouterInput",
                "type": "object",
            }
        },
        "title": "RouterRunnableInput",
    }
    assert router.get_output_jsonschema() == {"title": "RouterRunnableOutput"}

    seq_w_map: Runnable = (
        prompt
        | fake_llm
        | {
            "original": RunnablePassthrough(input_type=str),
            "as_list": list_parser,
            "length": typed_lambda_impl,
        }
    )

    assert seq_w_map.get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {"name": {"title": "Name", "type": "string"}},
        "required": ["name"],
    }
    assert seq_w_map.get_output_jsonschema() == {
        "title": "RunnableParallel<original,as_list,length>Output",
        "type": "object",
        "properties": {
            "original": {"title": "Original", "type": "string"},
            "length": {"title": "Length", "type": "integer"},
            "as_list": {
                "title": "As List",
                "type": "array",
                "items": {"type": "string"},
            },
        },
        "required": ["original", "as_list", "length"],
    }

    # Add a test for schema of runnable assign
    def foo(x: int) -> int:
        return x

    foo_ = RunnableLambda(foo)

    assert foo_.assign(bar=lambda _: "foo").get_output_schema().model_json_schema() == {
        "properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}},
        "required": ["root", "bar"],
        "title": "RunnableAssignOutput",
        "type": "object",
    }


def test_passthrough_assign_schema() -> None:
    """Check the schemas of sequences that start with ``RunnablePassthrough.assign``."""
    retriever = FakeRetriever()  # str -> list[Document]
    prompt = PromptTemplate.from_template("{context} {question}")
    fake_llm = FakeListLLM(responses=["a"])  # str -> list[list[str]]

    seq_w_assign = (
        RunnablePassthrough.assign(context=itemgetter("question") | retriever)
        | prompt
        | fake_llm
    )

    assert seq_w_assign.get_input_jsonschema() == {
        "properties": {"question": {"title": "Question", "type": "string"}},
        "title": "RunnableSequenceInput",
        "type": "object",
        "required": ["question"],
    }
    assert seq_w_assign.get_output_jsonschema() == {
        "title": "FakeListLLMOutput",
        "type": "string",
    }

    invalid_seq_w_assign = (
        RunnablePassthrough.assign(context=itemgetter("question") | retriever)
        | fake_llm
    )

    # fallback to RunnableAssign.input_schema if next runnable doesn't have
    # expected dict input_schema
    assert invalid_seq_w_assign.get_input_jsonschema() == {
        "properties": {"question": {"title": "Question"}},
        "title": "RunnableParallel<context>Input",
        "type": "object",
        "required": ["question"],
    }
def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
    """Check the input schemas ``RunnableLambda`` reports for various callables."""
    # The lambdas and inner functions are kept exactly as written: the
    # expected schemas below are derived from their names and bodies.
    first_lambda = lambda x: x["hello"]  # noqa: E731
    assert RunnableLambda(first_lambda).get_input_jsonschema() == {
        "title": "RunnableLambdaInput",
        "type": "object",
        "properties": {"hello": {"title": "Hello"}},
        "required": ["hello"],
    }

    second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"])  # noqa: E731
    assert RunnableLambda(second_lambda).get_input_jsonschema() == {
        "title": "RunnableLambdaInput",
        "type": "object",
        "properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
        "required": ["bye", "hello"],
    }

    def get_value(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
        return value["variable_name"]

    assert RunnableLambda(get_value).get_input_jsonschema() == {
        "title": "get_value_input",
        "type": "object",
        "properties": {"variable_name": {"title": "Variable Name"}},
        "required": ["variable_name"],
    }

    async def aget_value(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
        return (value["variable_name"], value.get("another"))

    assert RunnableLambda(aget_value).get_input_jsonschema() == {
        "title": "aget_value_input",
        "type": "object",
        "properties": {
            "another": {"title": "Another"},
            "variable_name": {"title": "Variable Name"},
        },
        "required": ["another", "variable_name"],
    }

    async def aget_values(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
        return {
            "hello": value["variable_name"],
            "bye": value["variable_name"],
            "byebye": value["yo"],
        }

    assert RunnableLambda(aget_values).get_input_jsonschema() == {
        "title": "aget_values_input",
        "type": "object",
        "properties": {
            "variable_name": {"title": "Variable Name"},
            "yo": {"title": "Yo"},
        },
        "required": ["variable_name", "yo"],
    }

    class InputType(TypedDict):
        variable_name: str
        yo: int

    class OutputType(TypedDict):
        hello: str
        bye: str
        byebye: int

    async def aget_values_typed(value: InputType) -> OutputType:
        return {
            "hello": value["variable_name"],
            "bye": value["variable_name"],
            "byebye": value["yo"],
        }

    assert _normalize_schema(
        RunnableLambda(aget_values_typed).get_input_jsonschema()
    ) == _normalize_schema(
        {
            "$defs": {
                "InputType": {
                    "properties": {
                        "variable_name": {
                            "title": "Variable Name",
                            "type": "string",
                        },
                        "yo": {"title": "Yo", "type": "integer"},
                    },
                    "required": ["variable_name", "yo"],
                    "title": "InputType",
                    "type": "object",
                }
            },
            "allOf": [{"$ref": "#/$defs/InputType"}],
            "title": "aget_values_typed_input",
        }
    )

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            RunnableLambda(aget_values_typed).get_output_jsonschema()
        ) == snapshot(name="schema8")


def test_with_types_with_type_generics() -> None:
    """Verify that with_types works if we use things like list[int]."""

    def foo(x: int) -> None:
        """Add one to the input."""
        raise NotImplementedError

    # Try specifying some
    RunnableLambda(foo).with_types(
        output_type=list[int],  # type: ignore[arg-type]
        input_type=list[int],  # type: ignore[arg-type]
    )
    RunnableLambda(foo).with_types(
        output_type=Sequence[int],  # type: ignore[arg-type]
        input_type=Sequence[int],  # type: ignore[arg-type]
    )


def test_schema_with_itemgetter() -> None:
    """Test runnable with itemgetter."""
    getter = RunnableLambda(itemgetter("hello"))
    assert _schema(getter.input_schema) == {
        "properties": {"hello": {"title": "Hello"}},
        "required": ["hello"],
        "title": "RunnableLambdaInput",
        "type": "object",
    }
    prompt = ChatPromptTemplate.from_template("what is {language}?")
    pipeline: Runnable = {"language": itemgetter("language")} | prompt
    assert _schema(pipeline.input_schema) == {
        "properties": {"language": {"title": "Language"}},
        "required": ["language"],
        "title": "RunnableParallel<language>Input",
        "type": "object",
    }


def test_schema_complex_seq() -> None:
    """Check schemas of a sequence nested inside a parallel step, and ``with_types``."""
    prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?")
    prompt2 = ChatPromptTemplate.from_template(
        "what country is the city {city} in? respond in {language}"
    )

    model = FakeListChatModel(responses=[""])

    chain1: Runnable = RunnableSequence(
        prompt1, model, StrOutputParser(), name="city_chain"
    )

    assert chain1.name == "city_chain"

    chain2: Runnable = (
        {"city": chain1, "language": itemgetter("language")}
        | prompt2
        | model
        | StrOutputParser()
    )

    assert chain2.get_input_jsonschema() == {
        "title": "RunnableParallel<city,language>Input",
        "type": "object",
        "properties": {
            "person": {"title": "Person", "type": "string"},
            "language": {"title": "Language"},
        },
        "required": ["person", "language"],
    }

    assert chain2.get_output_jsonschema() == {
        "title": "StrOutputParserOutput",
        "type": "string",
    }

    assert chain2.with_types(input_type=str).get_input_jsonschema() == {
        "title": "RunnableSequenceInput",
        "type": "string",
    }

    assert chain2.with_types(input_type=int).get_output_jsonschema() == {
        "title": "StrOutputParserOutput",
        "type": "string",
    }

    class InputType(BaseModel):
        person: str

    assert chain2.with_types(input_type=InputType).get_input_jsonschema() == {
        "title": "InputType",
        "type": "object",
        "properties": {"person": {"title": "Person", "type": "string"}},
        "required": ["person"],
    }


def test_configurable_fields(snapshot: SnapshotAssertion) -> None:
    """Exercise ``configurable_fields`` on an LLM, a prompt, and chains of both."""
    fake_llm = FakeListLLM(responses=["a"])  # str -> list[list[str]]

    assert fake_llm.invoke("...") == "a"

    fake_llm_configurable = fake_llm.configurable_fields(
        responses=ConfigurableField(
            id="llm_responses",
            name="LLM Responses",
            description="A list of fake responses for this LLM",
        )
    )

    assert fake_llm_configurable.invoke("...") == "a"

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            fake_llm_configurable.get_config_jsonschema()
        ) == snapshot(name="schema2")

    fake_llm_configured = fake_llm_configurable.with_config(
        configurable={"llm_responses": ["b"]}
    )

    assert fake_llm_configured.invoke("...") == "b"

    prompt = PromptTemplate.from_template("Hello, {name}!")

    assert prompt.invoke({"name": "John"}) == StringPromptValue(text="Hello, John!")

    prompt_configurable = prompt.configurable_fields(
        template=ConfigurableField(
            id="prompt_template",
            name="Prompt Template",
            description="The prompt template for this chain",
        )
    )

    assert prompt_configurable.invoke({"name": "John"}) == StringPromptValue(
        text="Hello, John!"
    )

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            prompt_configurable.get_config_jsonschema()
        ) == snapshot(name="schema3")

    prompt_configured = prompt_configurable.with_config(
        configurable={"prompt_template": "Hello, {name}! {name}!"}
    )

    assert prompt_configured.invoke({"name": "John"}) == StringPromptValue(
        text="Hello, John! John!"
    )

    assert prompt_configurable.with_config(
        configurable={"prompt_template": "Hello {name} in {lang}"}
    ).get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {
            "lang": {"title": "Lang", "type": "string"},
            "name": {"title": "Name", "type": "string"},
        },
        "required": ["lang", "name"],
    }

    chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser()

    assert chain_configurable.invoke({"name": "John"}) == "a"

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            chain_configurable.get_config_jsonschema()
        ) == snapshot(name="schema4")

    assert (
        chain_configurable.with_config(
            configurable={
                "prompt_template": "A very good morning to you, {name} {lang}!",
                "llm_responses": ["c"],
            }
        ).invoke({"name": "John", "lang": "en"})
        == "c"
    )

    assert chain_configurable.with_config(
        configurable={
            "prompt_template": "A very good morning to you, {name} {lang}!",
            "llm_responses": ["c"],
        }
    ).get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {
            "lang": {"title": "Lang", "type": "string"},
            "name": {"title": "Name", "type": "string"},
        },
        "required": ["lang", "name"],
    }

    chain_with_map_configurable: Runnable = prompt_configurable | {
        "llm1": fake_llm_configurable | StrOutputParser(),
        "llm2": fake_llm_configurable | StrOutputParser(),
        "llm3": fake_llm.configurable_fields(
            responses=ConfigurableField("other_responses")
        )
        | StrOutputParser(),
    }

    assert chain_with_map_configurable.invoke({"name": "John"}) == {
        "llm1": "a",
        "llm2": "a",
        "llm3": "a",
    }

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            chain_with_map_configurable.get_config_jsonschema()
        ) == snapshot(name="schema5")

    assert chain_with_map_configurable.with_config(
        configurable={
            "prompt_template": "A very good morning to you, {name}!",
            "llm_responses": ["c"],
            "other_responses": ["d"],
        }
    ).invoke({"name": "John"}) == {"llm1": "c", "llm2": "c", "llm3": "d"}


def test_configurable_alts_factory() -> None:
    """Check that a configurable alternative may be given as a factory."""
    fake_llm = FakeListLLM(responses=["a"]).configurable_alternatives(
        ConfigurableField(id="llm", name="LLM"),
        chat=partial(FakeListLLM, responses=["b"]),
    )

    assert fake_llm.invoke("...") == "a"

    assert fake_llm.with_config(configurable={"llm": "chat"}).invoke("...") == "b"


def test_configurable_fields_prefix_keys(snapshot: SnapshotAssertion) -> None:
    """Snapshot the config schema when alternatives use ``prefix_keys=True``."""
    fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
        responses=ConfigurableFieldMultiOption(
            id="responses",
            name="Chat Responses",
            options={
                "hello": "A good morning to you!",
                "bye": "See you later!",
                "helpful": "How can I help you?",
            },
            default=["hello", "bye"],
        ),
        # (sleep is a configurable field in FakeListChatModel)
        sleep=ConfigurableField(
            id="chat_sleep",
            is_shared=True,
        ),
    )
    fake_llm = (
        FakeListLLM(responses=["a"])
        .configurable_fields(
            responses=ConfigurableField(
                id="responses",
                name="LLM Responses",
                description="A list of fake responses for this LLM",
            )
        )
        .configurable_alternatives(
            ConfigurableField(id="llm", name="LLM"),
            chat=fake_chat | StrOutputParser(),
            prefix_keys=True,
        )
    )
    prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
        template=ConfigurableFieldSingleOption(
            id="prompt_template",
            name="Prompt Template",
            description="The prompt template for this chain",
            options={
                "hello": "Hello, {name}!",
                "good_morning": "A very good morning to you, {name}!",
            },
            default="hello",
        )
    )

    pipeline = prompt | fake_llm

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(_schema(pipeline.config_schema())) == snapshot(
            name="schema6"
        )
None:938 fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(939 responses=ConfigurableFieldMultiOption(940 id="chat_responses",941 name="Chat Responses",942 options={943 "hello": "A good morning to you!",944 "bye": "See you later!",945 "helpful": "How can I help you?",946 },947 default=["hello", "bye"],948 )949 )950 fake_llm = (951 FakeListLLM(responses=["a"])952 .configurable_fields(953 responses=ConfigurableField(954 id="llm_responses",955 name="LLM Responses",956 description="A list of fake responses for this LLM",957 )958 )959 .configurable_alternatives(960 ConfigurableField(id="llm", name="LLM"),961 chat=fake_chat | StrOutputParser(),962 )963 )964965 prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(966 template=ConfigurableFieldSingleOption(967 id="prompt_template",968 name="Prompt Template",969 description="The prompt template for this chain",970 options={971 "hello": "Hello, {name}!",972 "good_morning": "A very good morning to you, {name}!",973 },974 default="hello",975 )976 )977978 # deduplication of configurable fields979 chain_configurable = prompt | fake_llm | (lambda x: {"name": x}) | prompt | fake_llm980981 assert chain_configurable.invoke({"name": "John"}) == "a"982983 if PYDANTIC_VERSION_AT_LEAST_29:984 assert _normalize_schema(985 chain_configurable.get_config_jsonschema()986 ) == snapshot(name="schema7")987988 assert (989 chain_configurable.with_config(configurable={"llm": "chat"}).invoke(990 {"name": "John"}991 )992 == "A good morning to you!"993 )994995 assert (996 chain_configurable.with_config(997 configurable={"llm": "chat", "chat_responses": ["helpful"]}998 ).invoke({"name": "John"})999 == "How can I help you?"1000 )100110021003def test_passthrough_tap(mocker: MockerFixture) -> None:1004 fake = FakeRunnable()1005 mock = mocker.Mock()10061007 seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock)10081009 assert seq.invoke("hello", my_kwarg="value") == 51010 assert 
mock.call_args_list == [1011 mocker.call("hello", my_kwarg="value"),1012 mocker.call(5),1013 ]1014 mock.reset_mock()10151016 assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6]1017 assert len(mock.call_args_list) == 41018 for call in [1019 mocker.call("hello", my_kwarg="value"),1020 mocker.call("byebye", my_kwarg="value"),1021 mocker.call(5),1022 mocker.call(6),1023 ]:1024 assert call in mock.call_args_list1025 mock.reset_mock()10261027 assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [1028 5,1029 6,1030 ]1031 assert len(mock.call_args_list) == 41032 for call in [1033 mocker.call("hello", my_kwarg="value"),1034 mocker.call("byebye", my_kwarg="value"),1035 mocker.call(5),1036 mocker.call(6),1037 ]:1038 assert call in mock.call_args_list1039 mock.reset_mock()10401041 assert sorted(1042 a1043 for a in seq.batch_as_completed(1044 ["hello", "byebye"], my_kwarg="value", return_exceptions=True1045 )1046 ) == [1047 (0, 5),1048 (1, 6),1049 ]1050 assert len(mock.call_args_list) == 41051 for call in [1052 mocker.call("hello", my_kwarg="value"),1053 mocker.call("byebye", my_kwarg="value"),1054 mocker.call(5),1055 mocker.call(6),1056 ]:1057 assert call in mock.call_args_list1058 mock.reset_mock()10591060 assert list(1061 seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value")1062 ) == [5]1063 assert mock.call_args_list == [1064 mocker.call("hello", my_kwarg="value"),1065 mocker.call(5),1066 ]1067 mock.reset_mock()106810691070async def test_passthrough_tap_async(mocker: MockerFixture) -> None:1071 fake = FakeRunnable()1072 mock = mocker.Mock()10731074 seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock)10751076 assert await seq.ainvoke("hello", my_kwarg="value") == 51077 assert mock.call_args_list == [1078 mocker.call("hello", my_kwarg="value"),1079 mocker.call(5),1080 ]1081 mock.reset_mock()10821083 assert await seq.abatch(["hello", "byebye"], my_kwarg="value") == [5, 6]1084 assert 
len(mock.call_args_list) == 41085 for call in [1086 mocker.call("hello", my_kwarg="value"),1087 mocker.call("byebye", my_kwarg="value"),1088 mocker.call(5),1089 mocker.call(6),1090 ]:1091 assert call in mock.call_args_list1092 mock.reset_mock()10931094 assert await seq.abatch(1095 ["hello", "byebye"], my_kwarg="value", return_exceptions=True1096 ) == [1097 5,1098 6,1099 ]1100 assert len(mock.call_args_list) == 41101 for call in [1102 mocker.call("hello", my_kwarg="value"),1103 mocker.call("byebye", my_kwarg="value"),1104 mocker.call(5),1105 mocker.call(6),1106 ]:1107 assert call in mock.call_args_list1108 mock.reset_mock()11091110 assert sorted(1111 [1112 a1113 async for a in seq.abatch_as_completed(1114 ["hello", "byebye"], my_kwarg="value", return_exceptions=True1115 )1116 ]1117 ) == [1118 (0, 5),1119 (1, 6),1120 ]1121 assert len(mock.call_args_list) == 41122 for call in [1123 mocker.call("hello", my_kwarg="value"),1124 mocker.call("byebye", my_kwarg="value"),1125 mocker.call(5),1126 mocker.call(6),1127 ]:1128 assert call in mock.call_args_list1129 mock.reset_mock()11301131 assert [1132 part1133 async for part in seq.astream(1134 "hello", {"metadata": {"key": "value"}}, my_kwarg="value"1135 )1136 ] == [5]1137 assert mock.call_args_list == [1138 mocker.call("hello", my_kwarg="value"),1139 mocker.call(5),1140 ]114111421143async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:1144 fake = FakeRunnableSerializable()1145 spy = mocker.spy(fake.__class__, "invoke")1146 fakew = fake.configurable_fields(hello=ConfigurableField(id="hello", name="Hello"))11471148 assert (1149 fakew.with_config(tags=["a-tag"]).invoke(1150 "hello",1151 {1152 "configurable": {"hello": "there", "__secret_key": "nahnah"},1153 "metadata": {"bye": "now"},1154 },1155 )1156 == 51157 )1158 assert spy.call_args_list[0].args[1:] == (1159 "hello",1160 {1161 "tags": ["a-tag"],1162 "callbacks": None,1163 "recursion_limit": 25,1164 "configurable": {"hello": "there", "__secret_key": 
"nahnah"},1165 "metadata": {"bye": "now"},1166 },1167 )1168 spy.reset_mock()116911701171def test_with_config(mocker: MockerFixture) -> None:1172 fake = FakeRunnable()1173 spy = mocker.spy(fake, "invoke")11741175 assert fake.with_config(tags=["a-tag"]).invoke("hello") == 51176 assert spy.call_args_list == [1177 mocker.call(1178 "hello",1179 {"tags": ["a-tag"], "metadata": {}, "configurable": {}},1180 ),1181 ]1182 spy.reset_mock()11831184 fake_1 = RunnablePassthrough[Any]()1185 fake_2 = RunnablePassthrough[Any]()1186 spy_seq_step = mocker.spy(fake_1.__class__, "invoke")11871188 sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config(1189 tags=["b-tag"], max_concurrency=51190 )1191 assert sequence.invoke("hello") == "hello"1192 assert len(spy_seq_step.call_args_list) == 21193 for i, call in enumerate(spy_seq_step.call_args_list):1194 assert call.args[1] == "hello"1195 if i == 0:1196 assert call.args[2].get("tags") == ["a-tag"]1197 assert call.args[2].get("max_concurrency") is None1198 else:1199 assert call.args[2].get("tags") == ["b-tag"]1200 assert call.args[2].get("max_concurrency") == 51201 mocker.stop(spy_seq_step)12021203 assert [1204 *fake.with_config(tags=["a-tag"]).stream(1205 "hello", {"metadata": {"key": "value"}}1206 )1207 ] == [5]1208 assert spy.call_args_list == [1209 mocker.call(1210 "hello",1211 {"tags": ["a-tag"], "metadata": {"key": "value"}, "configurable": {}},1212 ),1213 ]1214 spy.reset_mock()12151216 assert fake.with_config(recursion_limit=5).batch(1217 ["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}]1218 ) == [5, 7]12191220 assert len(spy.call_args_list) == 21221 for i, call in enumerate(1222 sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1)1223 ):1224 assert call.args[0] == ("hello" if i == 0 else "wooorld")1225 if i == 0:1226 assert call.args[1].get("recursion_limit") == 51227 assert call.args[1].get("tags") == ["a-tag"]1228 assert call.args[1].get("metadata") == {}1229 else:1230 
assert call.args[1].get("recursion_limit") == 51231 assert call.args[1].get("tags") == []1232 assert call.args[1].get("metadata") == {"key": "value"}12331234 spy.reset_mock()12351236 assert sorted(1237 c1238 for c in fake.with_config(recursion_limit=5).batch_as_completed(1239 ["hello", "wooorld"],1240 [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}],1241 )1242 ) == [(0, 5), (1, 7)]12431244 assert len(spy.call_args_list) == 21245 for i, call in enumerate(1246 sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1)1247 ):1248 assert call.args[0] == ("hello" if i == 0 else "wooorld")1249 if i == 0:1250 assert call.args[1].get("recursion_limit") == 51251 assert call.args[1].get("tags") == ["a-tag"]1252 assert call.args[1].get("metadata") == {}1253 else:1254 assert call.args[1].get("recursion_limit") == 51255 assert call.args[1].get("tags") == []1256 assert call.args[1].get("metadata") == {"key": "value"}12571258 spy.reset_mock()12591260 assert fake.with_config(metadata={"a": "b"}).batch(1261 ["hello", "wooorld"], {"tags": ["a-tag"]}1262 ) == [5, 7]1263 assert len(spy.call_args_list) == 21264 for i, call in enumerate(spy.call_args_list):1265 assert call.args[0] == ("hello" if i == 0 else "wooorld")1266 assert call.args[1].get("tags") == ["a-tag"]1267 assert call.args[1].get("metadata") == {"a": "b"}1268 spy.reset_mock()12691270 assert sorted(1271 c for c in fake.batch_as_completed(["hello", "wooorld"], {"tags": ["a-tag"]})1272 ) == [(0, 5), (1, 7)]1273 assert len(spy.call_args_list) == 21274 for i, call in enumerate(spy.call_args_list):1275 assert call.args[0] == ("hello" if i == 0 else "wooorld")1276 assert call.args[1].get("tags") == ["a-tag"]127712781279async def test_with_config_async(mocker: MockerFixture) -> None:1280 fake = FakeRunnable()1281 spy = mocker.spy(fake, "invoke")12821283 handler = ConsoleCallbackHandler()1284 assert (1285 await fake.with_config(metadata={"a": "b"}).ainvoke(1286 "hello", config={"callbacks": [handler]}1287 )1288 
== 51289 )1290 assert spy.call_args_list == [1291 mocker.call(1292 "hello",1293 {1294 "callbacks": [handler],1295 "metadata": {"a": "b"},1296 "configurable": {},1297 "tags": [],1298 },1299 ),1300 ]1301 spy.reset_mock()13021303 assert [1304 part async for part in fake.with_config(metadata={"a": "b"}).astream("hello")1305 ] == [5]1306 assert spy.call_args_list == [1307 mocker.call("hello", {"metadata": {"a": "b"}, "tags": [], "configurable": {}}),1308 ]1309 spy.reset_mock()13101311 assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch(1312 ["hello", "wooorld"], {"metadata": {"key": "value"}}1313 ) == [1314 5,1315 7,1316 ]1317 assert sorted(spy.call_args_list) == [1318 mocker.call(1319 "hello",1320 {1321 "metadata": {"key": "value"},1322 "tags": ["c"],1323 "callbacks": None,1324 "recursion_limit": 5,1325 "configurable": {},1326 },1327 ),1328 mocker.call(1329 "wooorld",1330 {1331 "metadata": {"key": "value"},1332 "tags": ["c"],1333 "callbacks": None,1334 "recursion_limit": 5,1335 "configurable": {},1336 },1337 ),1338 ]1339 spy.reset_mock()13401341 assert sorted(1342 [1343 c1344 async for c in fake.with_config(1345 recursion_limit=5, tags=["c"]1346 ).abatch_as_completed(["hello", "wooorld"], {"metadata": {"key": "value"}})1347 ]1348 ) == [1349 (0, 5),1350 (1, 7),1351 ]1352 assert len(spy.call_args_list) == 21353 first_call = next(call for call in spy.call_args_list if call.args[0] == "hello")1354 assert first_call == mocker.call(1355 "hello",1356 {1357 "metadata": {"key": "value"},1358 "tags": ["c"],1359 "callbacks": None,1360 "recursion_limit": 5,1361 "configurable": {},1362 },1363 )1364 second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")1365 assert second_call == mocker.call(1366 "wooorld",1367 {1368 "metadata": {"key": "value"},1369 "tags": ["c"],1370 "callbacks": None,1371 "recursion_limit": 5,1372 "configurable": {},1373 },1374 )137513761377def test_default_method_implementations(mocker: MockerFixture) -> None:1378 
fake = FakeRunnable()1379 spy = mocker.spy(fake, "invoke")13801381 assert fake.invoke("hello", {"tags": ["a-tag"]}) == 51382 assert spy.call_args_list == [1383 mocker.call("hello", {"tags": ["a-tag"]}),1384 ]1385 spy.reset_mock()13861387 assert [*fake.stream("hello", {"metadata": {"key": "value"}})] == [5]1388 assert spy.call_args_list == [1389 mocker.call("hello", {"metadata": {"key": "value"}}),1390 ]1391 spy.reset_mock()13921393 assert fake.batch(1394 ["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}]1395 ) == [5, 7]13961397 assert len(spy.call_args_list) == 21398 for call in spy.call_args_list:1399 call_arg = call.args[0]14001401 if call_arg == "hello":1402 assert call_arg == "hello"1403 assert call.args[1].get("tags") == ["a-tag"]1404 assert call.args[1].get("metadata") == {}1405 else:1406 assert call_arg == "wooorld"1407 assert call.args[1].get("tags") == []1408 assert call.args[1].get("metadata") == {"key": "value"}14091410 spy.reset_mock()14111412 assert fake.batch(["hello", "wooorld"], {"tags": ["a-tag"]}) == [5, 7]1413 assert len(spy.call_args_list) == 21414 assert {call.args[0] for call in spy.call_args_list} == {"hello", "wooorld"}1415 for call in spy.call_args_list:1416 assert call.args[1].get("tags") == ["a-tag"]1417 assert call.args[1].get("metadata") == {}141814191420async def test_default_method_implementations_async(mocker: MockerFixture) -> None:1421 fake = FakeRunnable()1422 spy = mocker.spy(fake, "invoke")14231424 assert await fake.ainvoke("hello", config={"callbacks": []}) == 51425 assert spy.call_args_list == [1426 mocker.call("hello", {"callbacks": []}),1427 ]1428 spy.reset_mock()14291430 assert [part async for part in fake.astream("hello")] == [5]1431 assert spy.call_args_list == [1432 mocker.call("hello", None),1433 ]1434 spy.reset_mock()14351436 assert await fake.abatch(["hello", "wooorld"], {"metadata": {"key": "value"}}) == [1437 5,1438 7,1439 ]1440 assert {call.args[0] for call in spy.call_args_list} == 
{"hello", "wooorld"}1441 for call in spy.call_args_list:1442 assert call.args[1] == {1443 "metadata": {"key": "value"},1444 "tags": [],1445 "callbacks": None,1446 "recursion_limit": 25,1447 "configurable": {},1448 }144914501451def test_prompt() -> None:1452 prompt = ChatPromptTemplate.from_messages(1453 messages=[1454 SystemMessage(content="You are a nice assistant."),1455 HumanMessagePromptTemplate.from_template("{question}"),1456 ]1457 )1458 expected = ChatPromptValue(1459 messages=[1460 SystemMessage(content="You are a nice assistant."),1461 HumanMessage(content="What is your name?"),1462 ]1463 )14641465 assert prompt.invoke({"question": "What is your name?"}) == expected14661467 assert prompt.batch(1468 [1469 {"question": "What is your name?"},1470 {"question": "What is your favorite color?"},1471 ]1472 ) == [1473 expected,1474 ChatPromptValue(1475 messages=[1476 SystemMessage(content="You are a nice assistant."),1477 HumanMessage(content="What is your favorite color?"),1478 ]1479 ),1480 ]14811482 assert [*prompt.stream({"question": "What is your name?"})] == [expected]148314841485async def test_prompt_async() -> None:1486 prompt = ChatPromptTemplate.from_messages(1487 messages=[1488 SystemMessage(content="You are a nice assistant."),1489 HumanMessagePromptTemplate.from_template("{question}"),1490 ]1491 )1492 expected = ChatPromptValue(1493 messages=[1494 SystemMessage(content="You are a nice assistant."),1495 HumanMessage(content="What is your name?"),1496 ]1497 )14981499 assert await prompt.ainvoke({"question": "What is your name?"}) == expected15001501 assert await prompt.abatch(1502 [1503 {"question": "What is your name?"},1504 {"question": "What is your favorite color?"},1505 ]1506 ) == [1507 expected,1508 ChatPromptValue(1509 messages=[1510 SystemMessage(content="You are a nice assistant."),1511 HumanMessage(content="What is your favorite color?"),1512 ]1513 ),1514 ]15151516 assert [1517 part async for part in prompt.astream({"question": "What is your 
name?"})1518 ] == [expected]15191520 stream_log = [1521 part async for part in prompt.astream_log({"question": "What is your name?"})1522 ]15231524 assert len(stream_log[0].ops) == 11525 assert stream_log[0].ops[0]["op"] == "replace"1526 assert stream_log[0].ops[0]["path"] == ""1527 assert stream_log[0].ops[0]["value"]["logs"] == {}1528 assert stream_log[0].ops[0]["value"]["final_output"] is None1529 assert stream_log[0].ops[0]["value"]["streamed_output"] == []1530 assert isinstance(stream_log[0].ops[0]["value"]["id"], str)15311532 assert stream_log[1:] == [1533 RunLogPatch(1534 {"op": "add", "path": "/streamed_output/-", "value": expected},1535 {1536 "op": "replace",1537 "path": "/final_output",1538 "value": ChatPromptValue(1539 messages=[1540 SystemMessage(content="You are a nice assistant."),1541 HumanMessage(content="What is your name?"),1542 ]1543 ),1544 },1545 ),1546 ]15471548 stream_log_state = [1549 part1550 async for part in prompt.astream_log(1551 {"question": "What is your name?"}, diff=False1552 )1553 ]15541555 # remove random id1556 stream_log[0].ops[0]["value"]["id"] = "00000000-0000-0000-0000-000000000000"1557 stream_log_state[-1].ops[0]["value"]["id"] = "00000000-0000-0000-0000-000000000000"1558 stream_log_state[-1].state["id"] = "00000000-0000-0000-0000-000000000000"15591560 # assert output with diff=False matches output with diff=True1561 assert stream_log_state[-1].ops == [op for chunk in stream_log for op in chunk.ops]1562 assert stream_log_state[-1] == RunLog(1563 *[op for chunk in stream_log for op in chunk.ops],1564 state={1565 "final_output": ChatPromptValue(1566 messages=[1567 SystemMessage(content="You are a nice assistant."),1568 HumanMessage(content="What is your name?"),1569 ]1570 ),1571 "id": "00000000-0000-0000-0000-000000000000",1572 "logs": {},1573 "streamed_output": [1574 ChatPromptValue(1575 messages=[1576 SystemMessage(content="You are a nice assistant."),1577 HumanMessage(content="What is your name?"),1578 ]1579 )1580 ],1581 
"type": "prompt",1582 "name": "ChatPromptTemplate",1583 },1584 )15851586 # nested inside trace_with_chain_group15871588 async with atrace_as_chain_group("a_group") as manager:1589 stream_log_nested = [1590 part1591 async for part in prompt.astream_log(1592 {"question": "What is your name?"}, config={"callbacks": manager}1593 )1594 ]15951596 assert len(stream_log_nested[0].ops) == 11597 assert stream_log_nested[0].ops[0]["op"] == "replace"1598 assert stream_log_nested[0].ops[0]["path"] == ""1599 assert stream_log_nested[0].ops[0]["value"]["logs"] == {}1600 assert stream_log_nested[0].ops[0]["value"]["final_output"] is None1601 assert stream_log_nested[0].ops[0]["value"]["streamed_output"] == []1602 assert isinstance(stream_log_nested[0].ops[0]["value"]["id"], str)16031604 assert stream_log_nested[1:] == [1605 RunLogPatch(1606 {"op": "add", "path": "/streamed_output/-", "value": expected},1607 {1608 "op": "replace",1609 "path": "/final_output",1610 "value": ChatPromptValue(1611 messages=[1612 SystemMessage(content="You are a nice assistant."),1613 HumanMessage(content="What is your name?"),1614 ]1615 ),1616 },1617 ),1618 ]161916201621def test_prompt_template_params() -> None:1622 prompt = ChatPromptTemplate.from_template(1623 "Respond to the following question: {question}"1624 )1625 result = prompt.invoke(1626 {1627 "question": "test",1628 "topic": "test",1629 }1630 )1631 assert result == ChatPromptValue(1632 messages=[HumanMessage(content="Respond to the following question: test")]1633 )16341635 with pytest.raises(KeyError):1636 prompt.invoke({})163716381639def test_with_listeners(mocker: MockerFixture) -> None:1640 prompt = (1641 SystemMessagePromptTemplate.from_template("You are a nice assistant.")1642 + "{question}"1643 )1644 chat = FakeListChatModel(responses=["foo"])16451646 chain = prompt | chat16471648 mock_start = mocker.Mock()1649 mock_end = mocker.Mock()16501651 chain.with_listeners(on_start=mock_start, on_end=mock_end).invoke(1652 {"question": "Who are 
you?"}1653 )16541655 assert mock_start.call_count == 11656 assert mock_start.call_args[0][0].name == "RunnableSequence"1657 assert mock_end.call_count == 116581659 mock_start.reset_mock()1660 mock_end.reset_mock()16611662 with trace_as_chain_group("hello") as manager:1663 chain.with_listeners(on_start=mock_start, on_end=mock_end).invoke(1664 {"question": "Who are you?"}, {"callbacks": manager}1665 )16661667 assert mock_start.call_count == 11668 assert mock_start.call_args[0][0].name == "RunnableSequence"1669 assert mock_end.call_count == 1167016711672async def test_with_listeners_async(mocker: MockerFixture) -> None:1673 prompt = (1674 SystemMessagePromptTemplate.from_template("You are a nice assistant.")1675 + "{question}"1676 )1677 chat = FakeListChatModel(responses=["foo"])16781679 chain = prompt | chat16801681 mock_start = mocker.Mock()1682 mock_end = mocker.Mock()16831684 await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(1685 {"question": "Who are you?"}1686 )16871688 assert mock_start.call_count == 11689 assert mock_start.call_args[0][0].name == "RunnableSequence"1690 assert mock_end.call_count == 116911692 mock_start.reset_mock()1693 mock_end.reset_mock()16941695 async with atrace_as_chain_group("hello") as manager:1696 await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(1697 {"question": "Who are you?"}, {"callbacks": manager}1698 )16991700 assert mock_start.call_count == 11701 assert mock_start.call_args[0][0].name == "RunnableSequence"1702 assert mock_end.call_count == 1170317041705def test_with_listener_propagation(mocker: MockerFixture) -> None:1706 prompt = (1707 SystemMessagePromptTemplate.from_template("You are a nice assistant.")1708 + "{question}"1709 )1710 chat = FakeListChatModel(responses=["foo"])1711 chain: Runnable = prompt | chat1712 mock_start = mocker.Mock()1713 mock_end = mocker.Mock()1714 chain_with_listeners = chain.with_listeners(on_start=mock_start, on_end=mock_end)17151716 
chain_with_listeners.with_retry().invoke({"question": "Who are you?"})17171718 assert mock_start.call_count == 11719 assert mock_start.call_args[0][0].name == "RunnableSequence"1720 assert mock_end.call_count == 117211722 mock_start.reset_mock()1723 mock_end.reset_mock()17241725 chain_with_listeners.with_types(output_type=str).invoke(1726 {"question": "Who are you?"}1727 )17281729 assert mock_start.call_count == 11730 assert mock_start.call_args[0][0].name == "RunnableSequence"1731 assert mock_end.call_count == 117321733 mock_start.reset_mock()1734 mock_end.reset_mock()17351736 chain_with_listeners.with_config({"tags": ["foo"]}).invoke(1737 {"question": "Who are you?"}1738 )17391740 assert mock_start.call_count == 11741 assert mock_start.call_args[0][0].name == "RunnableSequence"1742 assert mock_end.call_count == 117431744 mock_start.reset_mock()1745 mock_end.reset_mock()17461747 chain_with_listeners.bind(stop=["foo"]).invoke({"question": "Who are you?"})17481749 assert mock_start.call_count == 11750 assert mock_start.call_args[0][0].name == "RunnableSequence"1751 assert mock_end.call_count == 117521753 mock_start.reset_mock()1754 mock_end.reset_mock()17551756 mock_start_inner = mocker.Mock()1757 mock_end_inner = mocker.Mock()17581759 chain_with_listeners.with_listeners(1760 on_start=mock_start_inner, on_end=mock_end_inner1761 ).invoke({"question": "Who are you?"})17621763 assert mock_start.call_count == 11764 assert mock_start.call_args[0][0].name == "RunnableSequence"1765 assert mock_end.call_count == 11766 assert mock_start_inner.call_count == 11767 assert mock_start_inner.call_args[0][0].name == "RunnableSequence"1768 assert mock_end_inner.call_count == 1176917701771@freeze_time("2023-01-01")1772@pytest.mark.usefixtures("deterministic_uuids")1773def test_prompt_with_chat_model(1774 mocker: MockerFixture,1775 snapshot: SnapshotAssertion,1776) -> None:1777 prompt = (1778 SystemMessagePromptTemplate.from_template("You are a nice assistant.")1779 + "{question}"1780 
)1781 chat = FakeListChatModel(responses=["foo"])17821783 chain = prompt | chat17841785 assert repr(chain) == snapshot1786 assert isinstance(chain, RunnableSequence)1787 assert chain.first == prompt1788 assert chain.middle == []1789 assert chain.last == chat1790 assert dumps(chain, pretty=True) == snapshot17911792 # Test invoke1793 prompt_spy = mocker.spy(prompt.__class__, "invoke")1794 chat_spy = mocker.spy(chat.__class__, "invoke")1795 tracer = FakeTracer()1796 assert chain.invoke(1797 {"question": "What is your name?"}, {"callbacks": [tracer]}1798 ) == _any_id_ai_message(content="foo")1799 assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1800 assert chat_spy.call_args.args[1] == ChatPromptValue(1801 messages=[1802 SystemMessage(content="You are a nice assistant."),1803 HumanMessage(content="What is your name?"),1804 ]1805 )18061807 assert tracer.runs == snapshot18081809 mocker.stop(prompt_spy)1810 mocker.stop(chat_spy)18111812 # Test batch1813 prompt_spy = mocker.spy(prompt.__class__, "batch")1814 chat_spy = mocker.spy(chat.__class__, "batch")1815 tracer = FakeTracer()1816 assert chain.batch(1817 [1818 {"question": "What is your name?"},1819 {"question": "What is your favorite color?"},1820 ],1821 {"callbacks": [tracer]},1822 ) == [1823 _any_id_ai_message(content="foo"),1824 _any_id_ai_message(content="foo"),1825 ]1826 assert prompt_spy.call_args.args[1] == [1827 {"question": "What is your name?"},1828 {"question": "What is your favorite color?"},1829 ]1830 assert chat_spy.call_args.args[1] == [1831 ChatPromptValue(1832 messages=[1833 SystemMessage(content="You are a nice assistant."),1834 HumanMessage(content="What is your name?"),1835 ]1836 ),1837 ChatPromptValue(1838 messages=[1839 SystemMessage(content="You are a nice assistant."),1840 HumanMessage(content="What is your favorite color?"),1841 ]1842 ),1843 ]1844 assert (1845 len(1846 [1847 r1848 for r in tracer.runs1849 if r.parent_run_id is None and len(r.child_runs) == 21850 ]1851 
)1852 == 21853 ), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"1854 mocker.stop(prompt_spy)1855 mocker.stop(chat_spy)18561857 # Test stream1858 prompt_spy = mocker.spy(prompt.__class__, "invoke")1859 chat_spy = mocker.spy(chat.__class__, "stream")1860 tracer = FakeTracer()1861 assert [1862 *chain.stream({"question": "What is your name?"}, {"callbacks": [tracer]})1863 ] == [1864 _any_id_ai_message_chunk(content="f"),1865 _any_id_ai_message_chunk(content="o"),1866 _any_id_ai_message_chunk(content="o", chunk_position="last"),1867 ]1868 assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1869 assert chat_spy.call_args.args[1] == ChatPromptValue(1870 messages=[1871 SystemMessage(content="You are a nice assistant."),1872 HumanMessage(content="What is your name?"),1873 ]1874 )187518761877@freeze_time("2023-01-01")1878@pytest.mark.usefixtures("deterministic_uuids")1879async def test_prompt_with_chat_model_async(1880 mocker: MockerFixture,1881 snapshot: SnapshotAssertion,1882) -> None:1883 prompt = (1884 SystemMessagePromptTemplate.from_template("You are a nice assistant.")1885 + "{question}"1886 )1887 chat = FakeListChatModel(responses=["foo"])18881889 chain = prompt | chat18901891 assert repr(chain) == snapshot1892 assert isinstance(chain, RunnableSequence)1893 assert chain.first == prompt1894 assert chain.middle == []1895 assert chain.last == chat1896 assert dumps(chain, pretty=True) == snapshot18971898 # Test invoke1899 prompt_spy = mocker.spy(prompt.__class__, "ainvoke")1900 chat_spy = mocker.spy(chat.__class__, "ainvoke")1901 tracer = FakeTracer()1902 assert await chain.ainvoke(1903 {"question": "What is your name?"}, {"callbacks": [tracer]}1904 ) == _any_id_ai_message(content="foo")1905 assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1906 assert chat_spy.call_args.args[1] == ChatPromptValue(1907 messages=[1908 SystemMessage(content="You are a nice assistant."),1909 HumanMessage(content="What is 
your name?"),1910 ]1911 )19121913 assert tracer.runs == snapshot19141915 mocker.stop(prompt_spy)1916 mocker.stop(chat_spy)19171918 # Test batch1919 prompt_spy = mocker.spy(prompt.__class__, "abatch")1920 chat_spy = mocker.spy(chat.__class__, "abatch")1921 tracer = FakeTracer()1922 assert await chain.abatch(1923 [1924 {"question": "What is your name?"},1925 {"question": "What is your favorite color?"},1926 ],1927 {"callbacks": [tracer]},1928 ) == [1929 _any_id_ai_message(content="foo"),1930 _any_id_ai_message(content="foo"),1931 ]1932 assert prompt_spy.call_args.args[1] == [1933 {"question": "What is your name?"},1934 {"question": "What is your favorite color?"},1935 ]1936 assert chat_spy.call_args.args[1] == [1937 ChatPromptValue(1938 messages=[1939 SystemMessage(content="You are a nice assistant."),1940 HumanMessage(content="What is your name?"),1941 ]1942 ),1943 ChatPromptValue(1944 messages=[1945 SystemMessage(content="You are a nice assistant."),1946 HumanMessage(content="What is your favorite color?"),1947 ]1948 ),1949 ]1950 assert (1951 len(1952 [1953 r1954 for r in tracer.runs1955 if r.parent_run_id is None and len(r.child_runs) == 21956 ]1957 )1958 == 21959 ), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"1960 mocker.stop(prompt_spy)1961 mocker.stop(chat_spy)19621963 # Test stream1964 prompt_spy = mocker.spy(prompt.__class__, "ainvoke")1965 chat_spy = mocker.spy(chat.__class__, "astream")1966 tracer = FakeTracer()1967 assert [1968 a1969 async for a in chain.astream(1970 {"question": "What is your name?"}, {"callbacks": [tracer]}1971 )1972 ] == [1973 _any_id_ai_message_chunk(content="f"),1974 _any_id_ai_message_chunk(content="o"),1975 _any_id_ai_message_chunk(content="o", chunk_position="last"),1976 ]1977 assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1978 assert chat_spy.call_args.args[1] == ChatPromptValue(1979 messages=[1980 SystemMessage(content="You are a nice assistant."),1981 
HumanMessage(content="What is your name?"),1982 ]1983 )198419851986@pytest.mark.skipif(1987 condition=sys.version_info[1] == 13,1988 reason=(1989 "temporary, py3.13 exposes some invalid assumptions about order of batch async "1990 "executions."1991 ),1992)1993@freeze_time("2023-01-01")1994async def test_prompt_with_llm(1995 mocker: MockerFixture, snapshot: SnapshotAssertion1996) -> None:1997 prompt = (1998 SystemMessagePromptTemplate.from_template("You are a nice assistant.")1999 + "{question}"2000 )
Findings
✓ No findings reported for this file.