"""Unit tests for langchain_core runnables.

Covers input/output/config JSON schemas, configurable fields and alternatives,
``RunnablePassthrough`` side-effect callbacks, and config propagation, among
other runnable behavior.
"""

import asyncio
import re
import sys
import time
import uuid
import warnings
from collections.abc import (
    AsyncIterator,
    Awaitable,
    Callable,
    Iterator,
    Sequence,
)
from functools import partial
from operator import itemgetter
from typing import Any, cast
from uuid import UUID

import pytest
from freezegun import freeze_time
from packaging import version
from pydantic import BaseModel, Field, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import Field as FieldV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from pytest_mock import MockerFixture
from syrupy.assertion import SnapshotAssertion
from typing_extensions import TypedDict, override

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
    atrace_as_chain_group,
    trace_as_chain_group,
)
from langchain_core.documents import Document
from langchain_core.language_models import (
    FakeListChatModel,
    FakeListLLM,
    FakeStreamingListLLM,
)
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.load import dumpd, dumps
from langchain_core.load.load import loads
from langchain_core.messages import AIMessageChunk, HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.output_parsers import (
    BaseOutputParser,
    CommaSeparatedListOutputParser,
    StrOutputParser,
)
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
    AddableDict,
    ConfigurableField,
    ConfigurableFieldMultiOption,
    ConfigurableFieldSingleOption,
    RouterRunnable,
    Runnable,
    RunnableAssign,
    RunnableBinding,
    RunnableBranch,
    RunnableConfig,
    RunnableGenerator,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
    RunnablePick,
    RunnableSequence,
    add,
    chain,
)
from langchain_core.runnables.base import RunnableMap, RunnableSerializable
from langchain_core.runnables.utils import Input, Output
from langchain_core.tools import BaseTool, tool
from langchain_core.tracers import (
    BaseTracer,
    ConsoleCallbackHandler,
    Run,
    RunLog,
    RunLogPatch,
)
from langchain_core.tracers._compat import pydantic_copy
from langchain_core.tracers.context import collect_runs
from langchain_core.utils.pydantic import (
    PYDANTIC_VERSION,
    TypeBaseModel,
    model_validate,
)
from langchain_core.version import VERSION
from tests.unit_tests.pydantic_utils import (
    _normalize_schema,
    _schema,
    skip_if_no_pydantic_v1,
)
from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk

# Several tests assert the legacy `RunLog` / `RunLogPatch` output produced by
# `astream_log`, which cannot be replaced by `astream` without losing coverage.
pytestmark = pytest.mark.filterwarnings(
    "ignore:astream_log is deprecated. Use astream instead.:"
    "langchain_core._api.deprecation.LangChainDeprecationWarning"
)

# Version gates for assertions whose generated JSON schema differs between
# pydantic minor releases (snapshots are only compared on newer versions).
PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION
PYDANTIC_VERSION_AT_LEAST_210 = version.parse("2.10") <= PYDANTIC_VERSION


class FakeTracer(BaseTracer):
    """Fake tracer that records LangChain execution.

    It replaces run IDs with deterministic UUIDs for snapshotting.
    """

    def __init__(self) -> None:
        """Initialize the tracer."""
        super().__init__()
        # Copies of every persisted run, with ids rewritten.
        self.runs: list[Run] = []
        # Real UUID -> deterministic replacement UUID.
        self.uuids_map: dict[UUID, UUID] = {}
        # Shared source of deterministic ids, used both for run ids and for
        # message ids, so the ids handed out depend on traversal order.
        self.uuids_generator = (
            UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
        )

    def _replace_uuid(self, uuid: UUID) -> UUID:
        """Return the deterministic replacement for ``uuid``, allocating one if new."""
        if uuid not in self.uuids_map:
            self.uuids_map[uuid] = next(self.uuids_generator)
        return self.uuids_map[uuid]

    def _replace_message_id(self, maybe_message: Any) -> Any:
        """Assign deterministic ids to any messages found in ``maybe_message``.

        Recurses into dicts, lists and ``LLMResult.generations``. Containers are
        mutated in place: message ``id`` attributes, dict values and list items
        are reassigned. Every message encountered consumes the next id from the
        shared generator. Returns the same object it was given; values of other
        types are returned unchanged.
        """
        if isinstance(maybe_message, BaseMessage):
            maybe_message.id = str(next(self.uuids_generator))
        if isinstance(maybe_message, ChatGeneration):
            maybe_message.message.id = str(next(self.uuids_generator))
        if isinstance(maybe_message, LLMResult):
            for i, gen_list in enumerate(maybe_message.generations):
                for j, gen in enumerate(gen_list):
                    maybe_message.generations[i][j] = self._replace_message_id(gen)
        if isinstance(maybe_message, dict):
            # Only values are reassigned (no keys added/removed), so mutating
            # while iterating ``items()`` is safe.
            for k, v in maybe_message.items():
                maybe_message[k] = self._replace_message_id(v)
        if isinstance(maybe_message, list):
            for i, v in enumerate(maybe_message):
                maybe_message[i] = self._replace_message_id(v)

        return maybe_message

    def _copy_run(self, run: Run) -> Run:
        """Return a copy of ``run`` with ids remapped to deterministic UUIDs.

        The run id, parent id, trace id, the run ids embedded in
        ``dotted_order`` and (recursively) all child runs are remapped, and
        message ids inside ``inputs``/``outputs`` are replaced.

        Note that ``inputs``/``outputs`` are rewritten in place on the
        original ``run`` by ``_replace_message_id``; only the top-level
        ``Run`` object is copied here.
        """
        if run.dotted_order:
            # Each "."-separated level has the form "<timestamp>Z<run_id>".
            levels = run.dotted_order.split(".")
            processed_levels = []
            for level in levels:
                timestamp, run_id = level.split("Z")
                new_run_id = self._replace_uuid(UUID(run_id))
                processed_level = f"{timestamp}Z{new_run_id}"
                processed_levels.append(processed_level)
            new_dotted_order = ".".join(processed_levels)
        else:
            new_dotted_order = None
        # Entry order matters: "id" is evaluated before "child_runs", so when a
        # child is copied its parent's id is already in ``uuids_map`` and the
        # direct lookup for "parent_run_id" below succeeds.
        # NOTE(review): a persisted run whose parent id was never mapped would
        # raise KeyError here; presumably only root runs reach _persist_run.
        update_dict = {
            "id": self._replace_uuid(run.id),
            "parent_run_id": (
                self.uuids_map[run.parent_run_id] if run.parent_run_id else None
            ),
            "child_runs": [self._copy_run(child) for child in run.child_runs],
            "trace_id": self._replace_uuid(run.trace_id) if run.trace_id else None,
            "dotted_order": new_dotted_order,
            "inputs": self._replace_message_id(run.inputs),
            "outputs": self._replace_message_id(run.outputs),
        }
        return pydantic_copy(run, update=update_dict)

    def _persist_run(self, run: Run) -> None:
        """Persist a run."""
        self.runs.append(self._copy_run(run))

    def flattened_runs(self) -> list[Run]:
        """Return every recorded run together with all of its descendants.

        ``q`` is used as a LIFO stack, so later roots and later siblings are
        visited before earlier ones.
        """
        q = [*self.runs]
        result = []
        while q:
            parent = q.pop()
            result.append(parent)
            if parent.child_runs:
                q.extend(parent.child_runs)
        return result

    @property
    def run_ids(self) -> list[uuid.UUID | None]:
        """Original (pre-replacement) ids of ``flattened_runs()``, in that order.

        Entries are ``None`` for any run whose id is not in ``uuids_map``.
        """
        runs = self.flattened_runs()
        # Invert the mapping: deterministic UUID -> original UUID.
        uuids_map = {v: k for k, v in self.uuids_map.items()}
        return [uuids_map.get(r.id) for r in runs]


# Minimal non-serializable runnable: str -> int (length of the input).
class FakeRunnable(Runnable[str, int]):
    @override
    def invoke(
        self,
        input: str,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> int:
        return len(input)


# Serializable variant of FakeRunnable. ``hello`` is a model field so it can be
# exposed via ``configurable_fields`` (see test_with_config_metadata_passthrough).
class FakeRunnableSerializable(RunnableSerializable[str, int]):
    hello: str = ""

    @override
    def invoke(
        self,
        input: str,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> int:
        return len(input)


# Retriever that ignores the query and always returns the same two documents
# ("foo", "bar"), via both the sync and the async code paths.
class FakeRetriever(BaseRetriever):
    @override
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        return [Document(page_content="foo"), Document(page_content="bar")]

    @override
    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> list[Document]:
        return [Document(page_content="foo"), Document(page_content="bar")]


# NOTE(review): this skips the test when pydantic >= 2.10, i.e. it only runs on
# older versions, which reads opposite to the reason text below — confirm the
# intended condition.
@pytest.mark.skipif(
    PYDANTIC_VERSION_AT_LEAST_210,
    reason=(
        "Only test with most recent version of pydantic. "
        "Pydantic introduced small fixes to generated JSONSchema on minor versions."
    ),
)
def test_schemas(snapshot: SnapshotAssertion) -> None:
    """Check input/output/config JSON schemas for a range of runnables."""
    fake = FakeRunnable()  # str -> int

    assert fake.get_input_jsonschema() == {
        "title": "FakeRunnableInput",
        "type": "string",
    }
    assert fake.get_output_jsonschema() == {
        "title": "FakeRunnableOutput",
        "type": "integer",
    }
    assert fake.get_config_jsonschema(include=["tags", "metadata", "run_name"]) == {
        "properties": {
            "metadata": {
                "default": None,
                "title": "Metadata",
                "type": "object",
            },
            "run_name": {"default": None, "title": "Run Name", "type": "string"},
            "tags": {
                "default": None,
                "items": {"type": "string"},
                "title": "Tags",
                "type": "array",
            },
        },
        "title": "FakeRunnableConfig",
        "type": "object",
    }

    # Binding kwargs and adding fallbacks must not change the wrapped schemas.
    fake_bound = FakeRunnable().bind(a="b")  # str -> int

    assert fake_bound.get_input_jsonschema() == {
        "title": "FakeRunnableInput",
        "type": "string",
    }
    assert fake_bound.get_output_jsonschema() == {
        "title": "FakeRunnableOutput",
        "type": "integer",
    }

    fake_w_fallbacks = FakeRunnable().with_fallbacks((fake,))  # str -> int

    assert fake_w_fallbacks.get_input_jsonschema() == {
        "title": "FakeRunnableInput",
        "type": "string",
    }
    assert fake_w_fallbacks.get_output_jsonschema() == {
        "title": "FakeRunnableOutput",
        "type": "integer",
    }

    # Schemas of RunnableLambda come from the wrapped function's annotations
    # and are titled after the function name.
    def typed_lambda_impl(x: str) -> int:
        return len(x)

    typed_lambda = RunnableLambda(typed_lambda_impl)  # str -> int

    assert typed_lambda.get_input_jsonschema() == {
        "title": "typed_lambda_impl_input",
        "type": "string",
    }
    assert typed_lambda.get_output_jsonschema() == {
        "title": "typed_lambda_impl_output",
        "type": "integer",
    }

    async def typed_async_lambda_impl(x: str) -> int:
        return len(x)

    typed_async_lambda = RunnableLambda(typed_async_lambda_impl)  # str -> int

    assert typed_async_lambda.get_input_jsonschema() == {
        "title": "typed_async_lambda_impl_input",
        "type": "string",
    }
    assert typed_async_lambda.get_output_jsonschema() == {
        "title": "typed_async_lambda_impl_output",
        "type": "integer",
    }

    fake_ret = FakeRetriever()  # str -> list[Document]

    assert fake_ret.get_input_jsonschema() == {
        "title": "FakeRetrieverInput",
        "type": "string",
    }
    # NOTE(review): the single leading spaces in several of the description
    # fragments below look like collapsed indentation; confirm they match
    # `Document.__doc__` exactly.
    assert _normalize_schema(fake_ret.get_output_jsonschema()) == {
        "$defs": {
            "Document": {
                "description": "Class for storing a piece of text and "
                "associated metadata.\n"
                "\n"
                "!!! note\n"
                "\n"
                " `Document` is for **retrieval workflows**, not chat I/O. For "
                "sending text\n"
                " to an LLM in a conversation, use message types from "
                "`langchain.messages`.\n"
                "\n"
                "Example:\n"
                " ```python\n"
                " from langchain_core.documents import Document\n"
                "\n"
                " document = Document(\n"
                ' page_content="Hello, world!", '
                'metadata={"source": "https://example.com"}\n'
                " )\n"
                " ```",
                "properties": {
                    "id": {
                        "anyOf": [{"type": "string"}, {"type": "null"}],
                        "default": None,
                        "title": "Id",
                    },
                    "metadata": {"title": "Metadata", "type": "object"},
                    "page_content": {"title": "Page Content", "type": "string"},
                    "type": {
                        "const": "Document",
                        "default": "Document",
                        "title": "Type",
                    },
                },
                "required": ["page_content"],
                "title": "Document",
                "type": "object",
            }
        },
        "items": {"$ref": "#/$defs/Document"},
        "title": "FakeRetrieverOutput",
        "type": "array",
    }

    fake_llm = FakeListLLM(responses=["a"])  # output schema asserted below: str

    assert _schema(fake_llm.input_schema) == snapshot(name="fake_llm_input_schema")
    assert _schema(fake_llm.output_schema) == {
        "title": "FakeListLLMOutput",
        "type": "string",
    }

    fake_chat = FakeListChatModel(responses=["a"])  # schemas are snapshotted below

    assert _schema(fake_chat.input_schema) == snapshot(name="fake_chat_input_schema")
    assert _schema(fake_chat.output_schema) == snapshot(name="fake_chat_output_schema")

    chat_prompt = ChatPromptTemplate.from_messages(
        [
            MessagesPlaceholder(variable_name="history"),
            ("human", "Hello, how are you?"),
        ]
    )

    assert _normalize_schema(chat_prompt.get_input_jsonschema()) == snapshot(
        name="chat_prompt_input_schema"
    )
    assert _normalize_schema(chat_prompt.get_output_jsonschema()) == snapshot(
        name="chat_prompt_output_schema"
    )

    prompt = PromptTemplate.from_template("Hello, {name}!")

    # Template variables become required string properties.
    assert prompt.get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {"name": {"title": "Name", "type": "string"}},
        "required": ["name"],
    }
    assert _schema(prompt.output_schema) == snapshot(name="prompt_output_schema")

    # .map() wraps the prompt so its input becomes an array of PromptInput.
    prompt_mapper = PromptTemplate.from_template("Hello, {name}!").map()

    assert _normalize_schema(prompt_mapper.get_input_jsonschema()) == {
        "$defs": {
            "PromptInput": {
                "properties": {"name": {"title": "Name", "type": "string"}},
                "required": ["name"],
                "title": "PromptInput",
                "type": "object",
            }
        },
        "default": None,
        "items": {"$ref": "#/$defs/PromptInput"},
        "title": "RunnableEach<PromptTemplate>Input",
        "type": "array",
    }
    assert _schema(prompt_mapper.output_schema) == snapshot(
        name="prompt_mapper_output_schema"
    )

    list_parser = CommaSeparatedListOutputParser()

    assert _schema(list_parser.input_schema) == snapshot(
        name="list_parser_input_schema"
    )
    assert _schema(list_parser.output_schema) == {
        "title": "CommaSeparatedListOutputParserOutput",
        "type": "array",
        "items": {"type": "string"},
    }

    # A sequence takes its input schema from the first step and its output
    # schema from the last step.
    seq = prompt | fake_llm | list_parser

    assert seq.get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {"name": {"title": "Name", "type": "string"}},
        "required": ["name"],
    }
    assert seq.get_output_jsonschema() == {
        "type": "array",
        "items": {"type": "string"},
        "title": "CommaSeparatedListOutputParserOutput",
    }

    router: Runnable = RouterRunnable({})

    assert _schema(router.input_schema) == {
        "$ref": "#/definitions/RouterInput",
        "definitions": {
            "RouterInput": {
                "description": "Router input.",
                "properties": {
                    "input": {"title": "Input"},
                    "key": {"title": "Key", "type": "string"},
                },
                "required": ["key", "input"],
                "title": "RouterInput",
                "type": "object",
            }
        },
        "title": "RouterRunnableInput",
    }
    assert router.get_output_jsonschema() == {"title": "RouterRunnableOutput"}

    # A dict step becomes a RunnableParallel whose output schema has one
    # property per key, typed from each branch's output.
    seq_w_map = (
        prompt
        | fake_llm
        | {
            "original": RunnablePassthrough(input_type=str),
            "as_list": list_parser,
            "length": typed_lambda_impl,
        }
    )

    assert seq_w_map.get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {"name": {"title": "Name", "type": "string"}},
        "required": ["name"],
    }
    assert seq_w_map.get_output_jsonschema() == {
        "title": "RunnableParallel<original,as_list,length>Output",
        "type": "object",
        "properties": {
            "original": {"title": "Original", "type": "string"},
            "length": {"title": "Length", "type": "integer"},
            "as_list": {
                "title": "As List",
                "type": "array",
                "items": {"type": "string"},
            },
        },
        "required": ["original", "as_list", "length"],
    }

    # Add a test for schema of runnable assign: the assigned key ("bar") is
    # added next to the upstream output, which shows up as "root" here.
    def foo(x: int) -> int:
        return x

    foo_ = RunnableLambda(foo)

    assert foo_.assign(bar=lambda _: "foo").get_output_jsonschema() == {
        "properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}},
        "required": ["root", "bar"],
        "title": "RunnableAssignOutput",
        "type": "object",
    }


def test_passthrough_assign_schema() -> None:
    """Check schemas of sequences that start with ``RunnablePassthrough.assign``."""
    retriever = FakeRetriever()  # str -> list[Document]
    prompt = PromptTemplate.from_template("{context} {question}")
    fake_llm = FakeListLLM(responses=["a"])  # output schema asserted below: str

    seq_w_assign = (
        RunnablePassthrough.assign(context=itemgetter("question") | retriever)
        | prompt
        | fake_llm
    )

    # "context" is produced by assign, so only "question" remains as input; its
    # type comes from the prompt that follows.
    assert seq_w_assign.get_input_jsonschema() == {
        "properties": {"question": {"title": "Question", "type": "string"}},
        "title": "RunnableSequenceInput",
        "type": "object",
        "required": ["question"],
    }
    assert seq_w_assign.get_output_jsonschema() == {
        "title": "FakeListLLMOutput",
        "type": "string",
    }

    invalid_seq_w_assign = (
        RunnablePassthrough.assign(context=itemgetter("question") | retriever)
        | fake_llm  # type: ignore[operator]
    )

    # fallback to RunnableAssign.input_schema if next runnable doesn't have
    # expected dict input_schema
    assert invalid_seq_w_assign.get_input_jsonschema() == {
        "properties": {"question": {"title": "Question"}},
        "title": "RunnableParallel<context>Input",
        "type": "object",
        "required": ["question"],
    }


def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
    """Check input schemas inferred for ``RunnableLambda``.

    Untyped functions get an object schema listing the dict keys they read;
    typed functions get their schema from the annotations.
    """
    first_lambda = lambda x: x["hello"]  # noqa: E731
    assert RunnableLambda(first_lambda).get_input_jsonschema() == {
        "title": "RunnableLambdaInput",
        "type": "object",
        "properties": {"hello": {"title": "Hello"}},
        "required": ["hello"],
    }

    # Only keys read from the first argument ("x") are listed; y["bah"] is not.
    second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"])  # noqa: E731
    assert RunnableLambda(second_lambda).get_input_jsonschema() == {
        "title": "RunnableLambdaInput",
        "type": "object",
        "properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
        "required": ["bye", "hello"],
    }

    def get_value(value):  # type: ignore[no-untyped-def]  # noqa: ANN001,ANN202
        return value["variable_name"]

    assert RunnableLambda(get_value).get_input_jsonschema() == {
        "title": "get_value_input",
        "type": "object",
        "properties": {"variable_name": {"title": "Variable Name"}},
        "required": ["variable_name"],
    }

    async def aget_value(value):  # type: ignore[no-untyped-def]  # noqa: ANN001,ANN202
        return (value["variable_name"], value.get("another"))

    # A key read via .get() ("another") is still reported as required.
    assert RunnableLambda(aget_value).get_input_jsonschema() == {
        "title": "aget_value_input",
        "type": "object",
        "properties": {
            "another": {"title": "Another"},
            "variable_name": {"title": "Variable Name"},
        },
        "required": ["another", "variable_name"],
    }

    async def aget_values(value):  # type: ignore[no-untyped-def]  # noqa: ANN001,ANN202
        return {
            "hello": value["variable_name"],
            "bye": value["variable_name"],
            "byebye": value["yo"],
        }

    # A key read more than once ("variable_name") is listed once.
    assert RunnableLambda(aget_values).get_input_jsonschema() == {
        "title": "aget_values_input",
        "type": "object",
        "properties": {
            "variable_name": {"title": "Variable Name"},
            "yo": {"title": "Yo"},
        },
        "required": ["variable_name", "yo"],
    }

    # No docstrings on these TypedDicts: they would appear in the schemas as
    # "description" and break the expected dicts below.
    class InputType(TypedDict):
        variable_name: str
        yo: int

    class OutputType(TypedDict):
        hello: str
        bye: str
        byebye: int

    async def aget_values_typed(value: InputType) -> OutputType:
        return {
            "hello": value["variable_name"],
            "bye": value["variable_name"],
            "byebye": value["yo"],
        }

    assert _normalize_schema(
        RunnableLambda(aget_values_typed).get_input_jsonschema()
    ) == _normalize_schema(
        {
            "$defs": {
                "InputType": {
                    "properties": {
                        "variable_name": {
                            "title": "Variable Name",
                            "type": "string",
                        },
                        "yo": {"title": "Yo", "type": "integer"},
                    },
                    "required": ["variable_name", "yo"],
                    "title": "InputType",
                    "type": "object",
                }
            },
            "allOf": [{"$ref": "#/$defs/InputType"}],
            "title": "aget_values_typed_input",
        }
    )

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            RunnableLambda(aget_values_typed).get_output_jsonschema()
        ) == snapshot(name="schema8")


def test_with_types_with_type_generics() -> None:
    """Verify that with_types works if we use things like list[int]."""

    def foo(x: int) -> None:
        """Never called; only used as the function wrapped by RunnableLambda."""
        raise NotImplementedError

    # Try specifying some generic aliases; the test passes as long as these
    # calls do not raise (there are no assertions).
    RunnableLambda(foo).with_types(
        output_type=list[int],  # type: ignore[arg-type]
        input_type=list[int],  # type: ignore[arg-type]
    )
    RunnableLambda(foo).with_types(
        output_type=Sequence[int],  # type: ignore[arg-type]
        input_type=Sequence[int],  # type: ignore[arg-type]
    )


def test_schema_with_itemgetter() -> None:
    """Test runnable with itemgetter."""
    foo = RunnableLambda(itemgetter("hello"))
    assert _schema(foo.input_schema) == {
        "properties": {"hello": {"title": "Hello"}},
        "required": ["hello"],
        "title": "RunnableLambdaInput",
        "type": "object",
    }
    prompt = ChatPromptTemplate.from_template("what is {language}?")
    # Local name shadows the module-level `chain` import within this test.
    chain = {"language": itemgetter("language")} | prompt
    assert _schema(chain.input_schema) == {
        "properties": {"language": {"title": "Language"}},
        "required": ["language"],
        "title": "RunnableParallel<language>Input",
        "type": "object",
    }


def test_schema_complex_seq() -> None:
    """Check schemas of a sequence whose first step is a dict of runnables."""
    prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?")
    prompt2 = ChatPromptTemplate.from_template(
        "what country is the city {city} in? respond in {language}"
    )

    model = FakeListChatModel(responses=[""])

    chain1: Runnable = RunnableSequence(
        prompt1, model, StrOutputParser(), name="city_chain"
    )

    assert chain1.name == "city_chain"

    chain2 = (
        {"city": chain1, "language": itemgetter("language")}
        | prompt2
        | model
        | StrOutputParser()
    )

    # The parallel step's input merges the branch inputs: "person" (typed by
    # prompt1 inside chain1) and "language" (untyped, from itemgetter).
    assert chain2.get_input_jsonschema() == {
        "title": "RunnableParallel<city,language>Input",
        "type": "object",
        "properties": {
            "person": {"title": "Person", "type": "string"},
            "language": {"title": "Language"},
        },
        "required": ["person", "language"],
    }

    assert chain2.get_output_jsonschema() == {
        "title": "StrOutputParserOutput",
        "type": "string",
    }

    # with_types(input_type=...) overrides the input schema ...
    assert chain2.with_types(input_type=str).get_input_jsonschema() == {
        "title": "RunnableSequenceInput",
        "type": "string",
    }

    # ... and leaves the output schema untouched.
    assert chain2.with_types(input_type=int).get_output_jsonschema() == {
        "title": "StrOutputParserOutput",
        "type": "string",
    }

    # No docstring here: it would appear in the schema as "description".
    class InputType(BaseModel):
        person: str

    assert chain2.with_types(input_type=InputType).get_input_jsonschema() == {
        "title": "InputType",
        "type": "object",
        "properties": {"person": {"title": "Person", "type": "string"}},
        "required": ["person"],
    }


def test_configurable_fields(snapshot: SnapshotAssertion) -> None:
    """Check ``configurable_fields`` on an LLM, a prompt and chains built from them."""
    fake_llm = FakeListLLM(responses=["a"])  # invoke returns a str, e.g. "a"

    assert fake_llm.invoke("...") == "a"

    fake_llm_configurable = fake_llm.configurable_fields(
        responses=ConfigurableField(
            id="llm_responses",
            name="LLM Responses",
            description="A list of fake responses for this LLM",
        )
    )

    # Without configuration the original field value is used.
    assert fake_llm_configurable.invoke("...") == "a"

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            fake_llm_configurable.get_config_jsonschema()
        ) == snapshot(name="schema2")

    fake_llm_configured = fake_llm_configurable.with_config(
        configurable={"llm_responses": ["b"]}
    )

    assert fake_llm_configured.invoke("...") == "b"

    prompt = PromptTemplate.from_template("Hello, {name}!")

    assert prompt.invoke({"name": "John"}) == StringPromptValue(text="Hello, John!")

    prompt_configurable = prompt.configurable_fields(
        template=ConfigurableField(
            id="prompt_template",
            name="Prompt Template",
            description="The prompt template for this chain",
        )
    )

    assert prompt_configurable.invoke({"name": "John"}) == StringPromptValue(
        text="Hello, John!"
    )

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            prompt_configurable.get_config_jsonschema()
        ) == snapshot(name="schema3")

    prompt_configured = prompt_configurable.with_config(
        configurable={"prompt_template": "Hello, {name}! {name}!"}
    )

    assert prompt_configured.invoke({"name": "John"}) == StringPromptValue(
        text="Hello, John! John!"
    )

    # The input schema follows the configured template's variables.
    assert prompt_configurable.with_config(
        configurable={"prompt_template": "Hello {name} in {lang}"}
    ).get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {
            "lang": {"title": "Lang", "type": "string"},
            "name": {"title": "Name", "type": "string"},
        },
        "required": ["lang", "name"],
    }

    chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser()

    assert chain_configurable.invoke({"name": "John"}) == "a"

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            chain_configurable.get_config_jsonschema()
        ) == snapshot(name="schema4")

    assert (
        chain_configurable.with_config(
            configurable={
                "prompt_template": "A very good morning to you, {name} {lang}!",
                "llm_responses": ["c"],
            }
        ).invoke({"name": "John", "lang": "en"})
        == "c"
    )

    assert chain_configurable.with_config(
        configurable={
            "prompt_template": "A very good morning to you, {name} {lang}!",
            "llm_responses": ["c"],
        }
    ).get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {
            "lang": {"title": "Lang", "type": "string"},
            "name": {"title": "Name", "type": "string"},
        },
        "required": ["lang", "name"],
    }

    # llm1 and llm2 share the "llm_responses" field; llm3 has its own
    # "other_responses" field on a separate configurable copy of fake_llm.
    chain_with_map_configurable = prompt_configurable | {
        "llm1": fake_llm_configurable | StrOutputParser(),
        "llm2": fake_llm_configurable | StrOutputParser(),
        "llm3": fake_llm.configurable_fields(
            responses=ConfigurableField("other_responses")
        )
        | StrOutputParser(),
    }

    assert chain_with_map_configurable.invoke({"name": "John"}) == {
        "llm1": "a",
        "llm2": "a",
        "llm3": "a",
    }

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            chain_with_map_configurable.get_config_jsonschema()
        ) == snapshot(name="schema5")

    assert chain_with_map_configurable.with_config(
        configurable={
            "prompt_template": "A very good morning to you, {name}!",
            "llm_responses": ["c"],
            "other_responses": ["d"],
        }
    ).invoke({"name": "John"}) == {"llm1": "c", "llm2": "c", "llm3": "d"}


def test_configurable_alts_factory() -> None:
    """Check that a configurable alternative may be supplied as a factory."""
    # "chat" is given as a callable (functools.partial) rather than an instance.
    fake_llm = FakeListLLM(responses=["a"]).configurable_alternatives(
        ConfigurableField(id="llm", name="LLM"),
        chat=partial(FakeListLLM, responses=["b"]),
    )

    assert fake_llm.invoke("...") == "a"

    assert fake_llm.with_config(configurable={"llm": "chat"}).invoke("...") == "b"


def test_configurable_fields_prefix_keys(snapshot: SnapshotAssertion) -> None:
    """Snapshot the config schema when alternatives use ``prefix_keys=True``."""
    fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
        responses=ConfigurableFieldMultiOption(
            id="responses",
            name="Chat Responses",
            options={
                "hello": "A good morning to you!",
                "bye": "See you later!",
                "helpful": "How can I help you?",
            },
            default=["hello", "bye"],
        ),
        # (sleep is a configurable field in FakeListChatModel)
        sleep=ConfigurableField(
            id="chat_sleep",
            is_shared=True,
        ),
    )
    # Both fake_chat and fake_llm use the field id "responses"; presumably
    # prefix_keys=True keeps the alternatives' keys distinct in the schema
    # (see snapshot "schema6").
    fake_llm = (
        FakeListLLM(responses=["a"])
        .configurable_fields(
            responses=ConfigurableField(
                id="responses",
                name="LLM Responses",
                description="A list of fake responses for this LLM",
            )
        )
        .configurable_alternatives(
            ConfigurableField(id="llm", name="LLM"),
            chat=fake_chat | StrOutputParser(),
            prefix_keys=True,
        )
    )
    prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
        template=ConfigurableFieldSingleOption(
            id="prompt_template",
            name="Prompt Template",
            description="The prompt template for this chain",
            options={
                "hello": "Hello, {name}!",
                "good_morning": "A very good morning to you, {name}!",
            },
            default="hello",
        )
    )

    chain = prompt | fake_llm

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(_schema(chain.config_schema())) == snapshot(
            name="schema6"
        )


def test_configurable_fields_example(snapshot: SnapshotAssertion) -> None:
    """Check configurable fields and alternatives combined in one chain."""
    fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
        responses=ConfigurableFieldMultiOption(
            id="chat_responses",
            name="Chat Responses",
            options={
                "hello": "A good morning to you!",
                "bye": "See you later!",
                "helpful": "How can I help you?",
            },
            default=["hello", "bye"],
        )
    )
    fake_llm = (
        FakeListLLM(responses=["a"])
        .configurable_fields(
            responses=ConfigurableField(
                id="llm_responses",
                name="LLM Responses",
                description="A list of fake responses for this LLM",
            )
        )
        .configurable_alternatives(
            ConfigurableField(id="llm", name="LLM"),
            chat=fake_chat | StrOutputParser(),
        )
    )

    prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
        template=ConfigurableFieldSingleOption(
            id="prompt_template",
            name="Prompt Template",
            description="The prompt template for this chain",
            options={
                "hello": "Hello, {name}!",
                "good_morning": "A very good morning to you, {name}!",
            },
            default="hello",
        )
    )

    # deduplication of configurable fields: `prompt` and `fake_llm` each appear
    # twice in this chain; snapshot "schema7" pins the resulting config schema.
    chain_configurable = prompt | fake_llm | (lambda x: {"name": x}) | prompt | fake_llm

    assert chain_configurable.invoke({"name": "John"}) == "a"

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            chain_configurable.get_config_jsonschema()
        ) == snapshot(name="schema7")

    # The "chat" alternative answers with the first default option ("hello").
    assert (
        chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
            {"name": "John"}
        )
        == "A good morning to you!"
    )

    assert (
        chain_configurable.with_config(
            configurable={"llm": "chat", "chat_responses": ["helpful"]}
        ).invoke({"name": "John"})
        == "How can I help you?"
    )


def test_passthrough_tap(mocker: MockerFixture) -> None:
    """Check ``RunnablePassthrough`` callbacks across invoke/batch/stream."""
    fake = FakeRunnable()
    mock = mocker.Mock()

    seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock)

    # Extra kwargs reach only the first step's callback; the second callback
    # sees just FakeRunnable's output.
    assert seq.invoke("hello", my_kwarg="value") == 5
    assert mock.call_args_list == [
        mocker.call("hello", my_kwarg="value"),
        mocker.call(5),
    ]
    mock.reset_mock()

    # For the batch variants, calls are checked by membership rather than
    # order, presumably because inputs may be processed concurrently.
    assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6]
    assert len(mock.call_args_list) == 4
    for call in [
        mocker.call("hello", my_kwarg="value"),
        mocker.call("byebye", my_kwarg="value"),
        mocker.call(5),
        mocker.call(6),
    ]:
        assert call in mock.call_args_list
    mock.reset_mock()

    assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [
        5,
        6,
    ]
    assert len(mock.call_args_list) == 4
    for call in [
        mocker.call("hello", my_kwarg="value"),
        mocker.call("byebye", my_kwarg="value"),
        mocker.call(5),
        mocker.call(6),
    ]:
        assert call in mock.call_args_list
    mock.reset_mock()

    # batch_as_completed yields (index, output) pairs; sort to ignore order.
    assert sorted(
        a
        for a in seq.batch_as_completed(
            ["hello", "byebye"], my_kwarg="value", return_exceptions=True
        )
    ) == [
        (0, 5),
        (1, 6),
    ]
    assert len(mock.call_args_list) == 4
    for call in [
        mocker.call("hello", my_kwarg="value"),
        mocker.call("byebye", my_kwarg="value"),
        mocker.call(5),
        mocker.call(6),
    ]:
        assert call in mock.call_args_list
    mock.reset_mock()

    # The config passed to stream() is not forwarded to the callbacks; only
    # the input and kwargs are.
    assert list(
        seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value")
    ) == [5]
    assert mock.call_args_list == [
        mocker.call("hello", my_kwarg="value"),
        mocker.call(5),
    ]
    mock.reset_mock()


async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
    """Async counterpart of ``test_passthrough_tap``."""
    fake = FakeRunnable()
    mock = mocker.Mock()

    seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock)

    assert await seq.ainvoke("hello", my_kwarg="value") == 5
    assert mock.call_args_list == [
        mocker.call("hello", my_kwarg="value"),
        mocker.call(5),
    ]
    mock.reset_mock()

    assert await seq.abatch(["hello", "byebye"], my_kwarg="value") == [5, 6]
    assert len(mock.call_args_list) == 4
    for call in [
        mocker.call("hello", my_kwarg="value"),
        mocker.call("byebye", my_kwarg="value"),
        mocker.call(5),
        mocker.call(6),
    ]:
        assert call in mock.call_args_list
    mock.reset_mock()

    assert await seq.abatch(
        ["hello", "byebye"], my_kwarg="value", return_exceptions=True
    ) == [
        5,
        6,
    ]
    assert len(mock.call_args_list) == 4
    for call in [
        mocker.call("hello", my_kwarg="value"),
        mocker.call("byebye", my_kwarg="value"),
        mocker.call(5),
        mocker.call(6),
    ]:
        assert call in mock.call_args_list
    mock.reset_mock()

    assert sorted(
        [
            a
            async for a in seq.abatch_as_completed(
                ["hello", "byebye"], my_kwarg="value", return_exceptions=True
            )
        ]
    ) == [
        (0, 5),
        (1, 6),
    ]
    assert len(mock.call_args_list) == 4
    for call in [
        mocker.call("hello", my_kwarg="value"),
        mocker.call("byebye", my_kwarg="value"),
        mocker.call(5),
        mocker.call(6),
    ]:
        assert call in mock.call_args_list
    mock.reset_mock()

    assert [
        part
        async for part in seq.astream(
            "hello", {"metadata": {"key": "value"}}, my_kwarg="value"
        )
    ] == [5]
    assert mock.call_args_list == [
        mocker.call("hello", my_kwarg="value"),
        mocker.call(5),
    ]


# NOTE(review): declared `async def` but awaits nothing; the invoke below is the
# synchronous one.
async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:
    """Check that configurable values and metadata reach ``invoke`` unchanged."""
    fake = FakeRunnableSerializable()
    # Spy on the class attribute, so each recorded call's args[0] is `self`
    # (hence the `.args[1:]` slice below).
    spy = mocker.spy(fake.__class__, "invoke")
    fakew = fake.configurable_fields(hello=ConfigurableField(id="hello", name="Hello"))

    assert (
        fakew.with_config(tags=["a-tag"]).invoke(
            "hello",
            {
                "configurable": {"hello": "there", "__secret_key": "nahnah"},
                "metadata": {"bye": "now"},
            },
        )
        == 5
    )
    # The merged config includes the bound tags plus callbacks/recursion_limit
    # entries (presumably defaults filled in during config merging).
    assert spy.call_args_list[0].args[1:] == (
        "hello",
        {
            "tags": ["a-tag"],
            "callbacks": None,
            "recursion_limit": 25,
            "configurable": {"hello": "there", "__secret_key": "nahnah"},
            "metadata": {"bye": "now"},
        },
    )
    spy.reset_mock()


def test_with_config(mocker: MockerFixture) -> None:
    fake = FakeRunnable()
    spy = mocker.spy(fake, "invoke")

    assert fake.with_config(tags=["a-tag"]).invoke("hello") == 5
    assert spy.call_args_list == [
        mocker.call(
            "hello",
            {"tags": ["a-tag"], "metadata": {}, "configurable": {}},
        ),
    ]
    spy.reset_mock()

    fake_1 = RunnablePassthrough[Any]()
    fake_2 = RunnablePassthrough[Any]()
    spy_seq_step = mocker.spy(fake_1.__class__, "invoke")

    sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config(
        tags=["b-tag"], max_concurrency=5
    )
    assert sequence.invoke("hello") == "hello"
    assert len(spy_seq_step.call_args_list) == 2
    for i, call in enumerate(spy_seq_step.call_args_list):
        assert call.args[1] == "hello"
        if i == 0:
            assert call.args[2].get("tags") == ["a-tag"]
            assert call.args[2].get("max_concurrency") is None
        else:
            assert call.args[2].get("tags") == ["b-tag"]
assert call.args[2].get("max_concurrency") == 51220 mocker.stop(spy_seq_step)12211222 assert [1223 *fake.with_config(tags=["a-tag"]).stream(1224 "hello", {"metadata": {"key": "value"}}1225 )1226 ] == [5]1227 assert spy.call_args_list == [1228 mocker.call(1229 "hello",1230 {"tags": ["a-tag"], "metadata": {"key": "value"}, "configurable": {}},1231 ),1232 ]1233 spy.reset_mock()12341235 assert fake.with_config(recursion_limit=5).batch(1236 ["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}]1237 ) == [5, 7]12381239 assert len(spy.call_args_list) == 21240 for i, call in enumerate(1241 sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1)1242 ):1243 assert call.args[0] == ("hello" if i == 0 else "wooorld")1244 if i == 0:1245 assert call.args[1].get("recursion_limit") == 51246 assert call.args[1].get("tags") == ["a-tag"]1247 assert call.args[1].get("metadata") == {}1248 else:1249 assert call.args[1].get("recursion_limit") == 51250 assert call.args[1].get("tags") == []1251 assert call.args[1].get("metadata") == {"key": "value"}12521253 spy.reset_mock()12541255 assert sorted(1256 c1257 for c in fake.with_config(recursion_limit=5).batch_as_completed(1258 ["hello", "wooorld"],1259 [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}],1260 )1261 ) == [(0, 5), (1, 7)]12621263 assert len(spy.call_args_list) == 21264 for i, call in enumerate(1265 sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1)1266 ):1267 assert call.args[0] == ("hello" if i == 0 else "wooorld")1268 if i == 0:1269 assert call.args[1].get("recursion_limit") == 51270 assert call.args[1].get("tags") == ["a-tag"]1271 assert call.args[1].get("metadata") == {}1272 else:1273 assert call.args[1].get("recursion_limit") == 51274 assert call.args[1].get("tags") == []1275 assert call.args[1].get("metadata") == {"key": "value"}12761277 spy.reset_mock()12781279 assert fake.with_config(metadata={"a": "b"}).batch(1280 ["hello", "wooorld"], {"tags": 
["a-tag"]}1281 ) == [5, 7]1282 assert len(spy.call_args_list) == 21283 for i, call in enumerate(spy.call_args_list):1284 assert call.args[0] == ("hello" if i == 0 else "wooorld")1285 assert call.args[1].get("tags") == ["a-tag"]1286 assert call.args[1].get("metadata") == {"a": "b"}1287 spy.reset_mock()12881289 assert sorted(1290 c for c in fake.batch_as_completed(["hello", "wooorld"], {"tags": ["a-tag"]})1291 ) == [(0, 5), (1, 7)]1292 assert len(spy.call_args_list) == 21293 for i, call in enumerate(spy.call_args_list):1294 assert call.args[0] == ("hello" if i == 0 else "wooorld")1295 assert call.args[1].get("tags") == ["a-tag"]129612971298async def test_with_config_async(mocker: MockerFixture) -> None:1299 fake = FakeRunnable()1300 spy = mocker.spy(fake, "invoke")13011302 handler = ConsoleCallbackHandler()1303 assert (1304 await fake.with_config(metadata={"a": "b"}).ainvoke(1305 "hello", config={"callbacks": [handler]}1306 )1307 == 51308 )1309 assert spy.call_args_list == [1310 mocker.call(1311 "hello",1312 {1313 "callbacks": [handler],1314 "metadata": {"a": "b"},1315 "configurable": {},1316 "tags": [],1317 },1318 ),1319 ]1320 spy.reset_mock()13211322 assert [1323 part async for part in fake.with_config(metadata={"a": "b"}).astream("hello")1324 ] == [5]1325 assert spy.call_args_list == [1326 mocker.call("hello", {"metadata": {"a": "b"}, "tags": [], "configurable": {}}),1327 ]1328 spy.reset_mock()13291330 assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch(1331 ["hello", "wooorld"], {"metadata": {"key": "value"}}1332 ) == [1333 5,1334 7,1335 ]1336 assert sorted(spy.call_args_list) == [1337 mocker.call(1338 "hello",1339 {1340 "metadata": {"key": "value"},1341 "tags": ["c"],1342 "callbacks": None,1343 "recursion_limit": 5,1344 "configurable": {},1345 },1346 ),1347 mocker.call(1348 "wooorld",1349 {1350 "metadata": {"key": "value"},1351 "tags": ["c"],1352 "callbacks": None,1353 "recursion_limit": 5,1354 "configurable": {},1355 },1356 ),1357 ]1358 
spy.reset_mock()13591360 assert sorted(1361 [1362 c1363 async for c in fake.with_config(1364 recursion_limit=5, tags=["c"]1365 ).abatch_as_completed(["hello", "wooorld"], {"metadata": {"key": "value"}})1366 ]1367 ) == [1368 (0, 5),1369 (1, 7),1370 ]1371 assert len(spy.call_args_list) == 21372 first_call = next(call for call in spy.call_args_list if call.args[0] == "hello")1373 assert first_call == mocker.call(1374 "hello",1375 {1376 "metadata": {"key": "value"},1377 "tags": ["c"],1378 "callbacks": None,1379 "recursion_limit": 5,1380 "configurable": {},1381 },1382 )1383 second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")1384 assert second_call == mocker.call(1385 "wooorld",1386 {1387 "metadata": {"key": "value"},1388 "tags": ["c"],1389 "callbacks": None,1390 "recursion_limit": 5,1391 "configurable": {},1392 },1393 )139413951396def test_default_method_implementations(mocker: MockerFixture) -> None:1397 fake = FakeRunnable()1398 spy = mocker.spy(fake, "invoke")13991400 assert fake.invoke("hello", {"tags": ["a-tag"]}) == 51401 assert spy.call_args_list == [1402 mocker.call("hello", {"tags": ["a-tag"]}),1403 ]1404 spy.reset_mock()14051406 assert [*fake.stream("hello", {"metadata": {"key": "value"}})] == [5]1407 assert spy.call_args_list == [1408 mocker.call("hello", {"metadata": {"key": "value"}}),1409 ]1410 spy.reset_mock()14111412 assert fake.batch(1413 ["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}]1414 ) == [5, 7]14151416 assert len(spy.call_args_list) == 21417 for call in spy.call_args_list:1418 call_arg = call.args[0]14191420 if call_arg == "hello":1421 assert call_arg == "hello"1422 assert call.args[1].get("tags") == ["a-tag"]1423 assert call.args[1].get("metadata") == {}1424 else:1425 assert call_arg == "wooorld"1426 assert call.args[1].get("tags") == []1427 assert call.args[1].get("metadata") == {"key": "value"}14281429 spy.reset_mock()14301431 assert fake.batch(["hello", "wooorld"], {"tags": 
["a-tag"]}) == [5, 7]1432 assert len(spy.call_args_list) == 21433 assert {call.args[0] for call in spy.call_args_list} == {"hello", "wooorld"}1434 for call in spy.call_args_list:1435 assert call.args[1].get("tags") == ["a-tag"]1436 assert call.args[1].get("metadata") == {}143714381439async def test_default_method_implementations_async(mocker: MockerFixture) -> None:1440 fake = FakeRunnable()1441 spy = mocker.spy(fake, "invoke")14421443 assert await fake.ainvoke("hello", config={"callbacks": []}) == 51444 assert spy.call_args_list == [1445 mocker.call("hello", {"callbacks": []}),1446 ]1447 spy.reset_mock()14481449 assert [part async for part in fake.astream("hello")] == [5]1450 assert spy.call_args_list == [1451 mocker.call("hello", None),1452 ]1453 spy.reset_mock()14541455 assert await fake.abatch(["hello", "wooorld"], {"metadata": {"key": "value"}}) == [1456 5,1457 7,1458 ]1459 assert {call.args[0] for call in spy.call_args_list} == {"hello", "wooorld"}1460 for call in spy.call_args_list:1461 assert call.args[1] == {1462 "metadata": {"key": "value"},1463 "tags": [],1464 "callbacks": None,1465 "recursion_limit": 25,1466 "configurable": {},1467 }146814691470def test_prompt() -> None:1471 prompt = ChatPromptTemplate.from_messages(1472 messages=[1473 SystemMessage(content="You are a nice assistant."),1474 HumanMessagePromptTemplate.from_template("{question}"),1475 ]1476 )1477 expected = ChatPromptValue(1478 messages=[1479 SystemMessage(content="You are a nice assistant."),1480 HumanMessage(content="What is your name?"),1481 ]1482 )14831484 assert prompt.invoke({"question": "What is your name?"}) == expected14851486 assert prompt.batch(1487 [1488 {"question": "What is your name?"},1489 {"question": "What is your favorite color?"},1490 ]1491 ) == [1492 expected,1493 ChatPromptValue(1494 messages=[1495 SystemMessage(content="You are a nice assistant."),1496 HumanMessage(content="What is your favorite color?"),1497 ]1498 ),1499 ]15001501 assert 
[*prompt.stream({"question": "What is your name?"})] == [expected]150215031504async def test_prompt_async() -> None:1505 prompt = ChatPromptTemplate.from_messages(1506 messages=[1507 SystemMessage(content="You are a nice assistant."),1508 HumanMessagePromptTemplate.from_template("{question}"),1509 ]1510 )1511 expected = ChatPromptValue(1512 messages=[1513 SystemMessage(content="You are a nice assistant."),1514 HumanMessage(content="What is your name?"),1515 ]1516 )15171518 assert await prompt.ainvoke({"question": "What is your name?"}) == expected15191520 assert await prompt.abatch(1521 [1522 {"question": "What is your name?"},1523 {"question": "What is your favorite color?"},1524 ]1525 ) == [1526 expected,1527 ChatPromptValue(1528 messages=[1529 SystemMessage(content="You are a nice assistant."),1530 HumanMessage(content="What is your favorite color?"),1531 ]1532 ),1533 ]15341535 assert [1536 part async for part in prompt.astream({"question": "What is your name?"})1537 ] == [expected]15381539 stream_log = [1540 part async for part in prompt.astream_log({"question": "What is your name?"})1541 ]15421543 assert len(stream_log[0].ops) == 11544 assert stream_log[0].ops[0]["op"] == "replace"1545 assert stream_log[0].ops[0]["path"] == ""1546 assert stream_log[0].ops[0]["value"]["logs"] == {}1547 assert stream_log[0].ops[0]["value"]["final_output"] is None1548 assert stream_log[0].ops[0]["value"]["streamed_output"] == []1549 assert isinstance(stream_log[0].ops[0]["value"]["id"], str)15501551 assert stream_log[1:] == [1552 RunLogPatch(1553 {"op": "add", "path": "/streamed_output/-", "value": expected},1554 {1555 "op": "replace",1556 "path": "/final_output",1557 "value": ChatPromptValue(1558 messages=[1559 SystemMessage(content="You are a nice assistant."),1560 HumanMessage(content="What is your name?"),1561 ]1562 ),1563 },1564 ),1565 ]15661567 stream_log_state = [1568 part1569 async for part in prompt.astream_log(1570 {"question": "What is your name?"}, diff=False1571 
)1572 ]15731574 # remove random id1575 stream_log[0].ops[0]["value"]["id"] = "00000000-0000-0000-0000-000000000000"1576 stream_log_state[-1].ops[0]["value"]["id"] = "00000000-0000-0000-0000-000000000000"1577 stream_log_state[-1].state["id"] = "00000000-0000-0000-0000-000000000000"15781579 # assert output with diff=False matches output with diff=True1580 assert stream_log_state[-1].ops == [op for chunk in stream_log for op in chunk.ops]1581 assert stream_log_state[-1] == RunLog(1582 *[op for chunk in stream_log for op in chunk.ops],1583 state={1584 "final_output": ChatPromptValue(1585 messages=[1586 SystemMessage(content="You are a nice assistant."),1587 HumanMessage(content="What is your name?"),1588 ]1589 ),1590 "id": "00000000-0000-0000-0000-000000000000",1591 "logs": {},1592 "streamed_output": [1593 ChatPromptValue(1594 messages=[1595 SystemMessage(content="You are a nice assistant."),1596 HumanMessage(content="What is your name?"),1597 ]1598 )1599 ],1600 "type": "prompt",1601 "name": "ChatPromptTemplate",1602 },1603 )16041605 # nested inside trace_with_chain_group16061607 async with atrace_as_chain_group("a_group") as manager:1608 stream_log_nested = [1609 part1610 async for part in prompt.astream_log(1611 {"question": "What is your name?"}, config={"callbacks": manager}1612 )1613 ]16141615 assert len(stream_log_nested[0].ops) == 11616 assert stream_log_nested[0].ops[0]["op"] == "replace"1617 assert stream_log_nested[0].ops[0]["path"] == ""1618 assert stream_log_nested[0].ops[0]["value"]["logs"] == {}1619 assert stream_log_nested[0].ops[0]["value"]["final_output"] is None1620 assert stream_log_nested[0].ops[0]["value"]["streamed_output"] == []1621 assert isinstance(stream_log_nested[0].ops[0]["value"]["id"], str)16221623 assert stream_log_nested[1:] == [1624 RunLogPatch(1625 {"op": "add", "path": "/streamed_output/-", "value": expected},1626 {1627 "op": "replace",1628 "path": "/final_output",1629 "value": ChatPromptValue(1630 messages=[1631 
SystemMessage(content="You are a nice assistant."),1632 HumanMessage(content="What is your name?"),1633 ]1634 ),1635 },1636 ),1637 ]163816391640def test_prompt_template_params() -> None:1641 prompt = ChatPromptTemplate.from_template(1642 "Respond to the following question: {question}"1643 )1644 result = prompt.invoke(1645 {1646 "question": "test",1647 "topic": "test",1648 }1649 )1650 assert result == ChatPromptValue(1651 messages=[HumanMessage(content="Respond to the following question: test")]1652 )16531654 with pytest.raises(KeyError):1655 prompt.invoke({})165616571658def test_with_listeners(mocker: MockerFixture) -> None:1659 prompt = (1660 SystemMessagePromptTemplate.from_template("You are a nice assistant.")1661 + "{question}"1662 )1663 chat = FakeListChatModel(responses=["foo"])16641665 chain = prompt | chat16661667 mock_start = mocker.Mock()1668 mock_end = mocker.Mock()16691670 chain.with_listeners(on_start=mock_start, on_end=mock_end).invoke(1671 {"question": "Who are you?"}1672 )16731674 assert mock_start.call_count == 11675 assert mock_start.call_args[0][0].name == "RunnableSequence"1676 assert mock_end.call_count == 116771678 mock_start.reset_mock()1679 mock_end.reset_mock()16801681 with trace_as_chain_group("hello") as manager:1682 chain.with_listeners(on_start=mock_start, on_end=mock_end).invoke(1683 {"question": "Who are you?"}, {"callbacks": manager}1684 )16851686 assert mock_start.call_count == 11687 assert mock_start.call_args[0][0].name == "RunnableSequence"1688 assert mock_end.call_count == 1168916901691async def test_with_listeners_async(mocker: MockerFixture) -> None:1692 prompt = (1693 SystemMessagePromptTemplate.from_template("You are a nice assistant.")1694 + "{question}"1695 )1696 chat = FakeListChatModel(responses=["foo"])16971698 chain = prompt | chat16991700 mock_start = mocker.Mock()1701 mock_end = mocker.Mock()17021703 await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(1704 {"question": "Who are you?"}1705 
)17061707 assert mock_start.call_count == 11708 assert mock_start.call_args[0][0].name == "RunnableSequence"1709 assert mock_end.call_count == 117101711 mock_start.reset_mock()1712 mock_end.reset_mock()17131714 async with atrace_as_chain_group("hello") as manager:1715 await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(1716 {"question": "Who are you?"}, {"callbacks": manager}1717 )17181719 assert mock_start.call_count == 11720 assert mock_start.call_args[0][0].name == "RunnableSequence"1721 assert mock_end.call_count == 1172217231724def test_with_listener_propagation(mocker: MockerFixture) -> None:1725 prompt = (1726 SystemMessagePromptTemplate.from_template("You are a nice assistant.")1727 + "{question}"1728 )1729 chat = FakeListChatModel(responses=["foo"])1730 chain = prompt | chat1731 mock_start = mocker.Mock()1732 mock_end = mocker.Mock()1733 chain_with_listeners = chain.with_listeners(on_start=mock_start, on_end=mock_end)17341735 chain_with_listeners.with_retry().invoke({"question": "Who are you?"})17361737 assert mock_start.call_count == 11738 assert mock_start.call_args[0][0].name == "RunnableSequence"1739 assert mock_end.call_count == 117401741 mock_start.reset_mock()1742 mock_end.reset_mock()17431744 chain_with_listeners.invoke({"question": "Who are you?"})17451746 assert mock_start.call_count == 11747 assert mock_start.call_args[0][0].name == "RunnableSequence"1748 assert mock_end.call_count == 117491750 mock_start.reset_mock()1751 mock_end.reset_mock()17521753 chain_with_listeners.with_config({"tags": ["foo"]}).invoke(1754 {"question": "Who are you?"}1755 )17561757 assert mock_start.call_count == 11758 assert mock_start.call_args[0][0].name == "RunnableSequence"1759 assert mock_end.call_count == 117601761 mock_start.reset_mock()1762 mock_end.reset_mock()17631764 chain_with_listeners.bind(stop=["foo"]).invoke({"question": "Who are you?"})17651766 assert mock_start.call_count == 11767 assert mock_start.call_args[0][0].name == 
"RunnableSequence"1768 assert mock_end.call_count == 117691770 mock_start.reset_mock()1771 mock_end.reset_mock()17721773 mock_start_inner = mocker.Mock()1774 mock_end_inner = mocker.Mock()17751776 chain_with_listeners.with_listeners(1777 on_start=mock_start_inner, on_end=mock_end_inner1778 ).invoke({"question": "Who are you?"})17791780 assert mock_start.call_count == 11781 assert mock_start.call_args[0][0].name == "RunnableSequence"1782 assert mock_end.call_count == 11783 assert mock_start_inner.call_count == 11784 assert mock_start_inner.call_args[0][0].name == "RunnableSequence"1785 assert mock_end_inner.call_count == 1178617871788@freeze_time("2023-01-01")1789@pytest.mark.usefixtures("deterministic_uuids")1790def test_prompt_with_chat_model(1791 mocker: MockerFixture,1792 snapshot: SnapshotAssertion,1793) -> None:1794 prompt = (1795 SystemMessagePromptTemplate.from_template("You are a nice assistant.")1796 + "{question}"1797 )1798 chat = FakeListChatModel(responses=["foo"])17991800 chain = prompt | chat18011802 assert repr(chain) == snapshot1803 assert isinstance(chain, RunnableSequence)1804 assert chain.first == prompt1805 assert chain.middle == []1806 assert chain.last == chat1807 assert dumps(chain, pretty=True) == snapshot18081809 # Test invoke1810 prompt_spy = mocker.spy(prompt.__class__, "invoke")1811 chat_spy = mocker.spy(chat.__class__, "invoke")1812 tracer = FakeTracer()1813 assert chain.invoke(1814 {"question": "What is your name?"}, {"callbacks": [tracer]}1815 ) == _any_id_ai_message(content="foo")1816 assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1817 assert chat_spy.call_args.args[1] == ChatPromptValue(1818 messages=[1819 SystemMessage(content="You are a nice assistant."),1820 HumanMessage(content="What is your name?"),1821 ]1822 )18231824 assert tracer.runs == snapshot18251826 mocker.stop(prompt_spy)1827 mocker.stop(chat_spy)18281829 # Test batch1830 prompt_spy = mocker.spy(prompt.__class__, "batch")1831 chat_spy = 
mocker.spy(chat.__class__, "batch")1832 tracer = FakeTracer()1833 assert chain.batch(1834 [1835 {"question": "What is your name?"},1836 {"question": "What is your favorite color?"},1837 ],1838 {"callbacks": [tracer]},1839 ) == [1840 _any_id_ai_message(content="foo"),1841 _any_id_ai_message(content="foo"),1842 ]1843 assert prompt_spy.call_args.args[1] == [1844 {"question": "What is your name?"},1845 {"question": "What is your favorite color?"},1846 ]1847 assert chat_spy.call_args.args[1] == [1848 ChatPromptValue(1849 messages=[1850 SystemMessage(content="You are a nice assistant."),1851 HumanMessage(content="What is your name?"),1852 ]1853 ),1854 ChatPromptValue(1855 messages=[1856 SystemMessage(content="You are a nice assistant."),1857 HumanMessage(content="What is your favorite color?"),1858 ]1859 ),1860 ]1861 assert (1862 len(1863 [1864 r1865 for r in tracer.runs1866 if r.parent_run_id is None and len(r.child_runs) == 21867 ]1868 )1869 == 21870 ), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"1871 mocker.stop(prompt_spy)1872 mocker.stop(chat_spy)18731874 # Test stream1875 prompt_spy = mocker.spy(prompt.__class__, "invoke")1876 chat_spy = mocker.spy(chat.__class__, "stream")1877 tracer = FakeTracer()1878 assert [1879 *chain.stream({"question": "What is your name?"}, {"callbacks": [tracer]})1880 ] == [1881 _any_id_ai_message_chunk(content="f"),1882 _any_id_ai_message_chunk(content="o"),1883 _any_id_ai_message_chunk(content="o", chunk_position="last"),1884 ]1885 assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1886 assert chat_spy.call_args.args[1] == ChatPromptValue(1887 messages=[1888 SystemMessage(content="You are a nice assistant."),1889 HumanMessage(content="What is your name?"),1890 ]1891 )189218931894@freeze_time("2023-01-01")1895@pytest.mark.usefixtures("deterministic_uuids")1896async def test_prompt_with_chat_model_async(1897 mocker: MockerFixture,1898 snapshot: SnapshotAssertion,1899) -> None:1900 prompt = 
(1901 SystemMessagePromptTemplate.from_template("You are a nice assistant.")1902 + "{question}"1903 )1904 chat = FakeListChatModel(responses=["foo"])19051906 chain = prompt | chat19071908 assert repr(chain) == snapshot1909 assert isinstance(chain, RunnableSequence)1910 assert chain.first == prompt1911 assert chain.middle == []1912 assert chain.last == chat1913 assert dumps(chain, pretty=True) == snapshot19141915 # Test invoke1916 prompt_spy = mocker.spy(prompt.__class__, "ainvoke")1917 chat_spy = mocker.spy(chat.__class__, "ainvoke")1918 tracer = FakeTracer()1919 assert await chain.ainvoke(1920 {"question": "What is your name?"}, {"callbacks": [tracer]}1921 ) == _any_id_ai_message(content="foo")1922 assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1923 assert chat_spy.call_args.args[1] == ChatPromptValue(1924 messages=[1925 SystemMessage(content="You are a nice assistant."),1926 HumanMessage(content="What is your name?"),1927 ]1928 )19291930 assert tracer.runs == snapshot19311932 mocker.stop(prompt_spy)1933 mocker.stop(chat_spy)19341935 # Test batch1936 prompt_spy = mocker.spy(prompt.__class__, "abatch")1937 chat_spy = mocker.spy(chat.__class__, "abatch")1938 tracer = FakeTracer()1939 assert await chain.abatch(1940 [1941 {"question": "What is your name?"},1942 {"question": "What is your favorite color?"},1943 ],1944 {"callbacks": [tracer]},1945 ) == [1946 _any_id_ai_message(content="foo"),1947 _any_id_ai_message(content="foo"),1948 ]1949 assert prompt_spy.call_args.args[1] == [1950 {"question": "What is your name?"},1951 {"question": "What is your favorite color?"},1952 ]1953 assert chat_spy.call_args.args[1] == [1954 ChatPromptValue(1955 messages=[1956 SystemMessage(content="You are a nice assistant."),1957 HumanMessage(content="What is your name?"),1958 ]1959 ),1960 ChatPromptValue(1961 messages=[1962 SystemMessage(content="You are a nice assistant."),1963 HumanMessage(content="What is your favorite color?"),1964 ]1965 ),1966 ]1967 assert 
(1968 len(1969 [1970 r1971 for r in tracer.runs1972 if r.parent_run_id is None and len(r.child_runs) == 21973 ]1974 )1975 == 21976 ), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"1977 mocker.stop(prompt_spy)1978 mocker.stop(chat_spy)19791980 # Test stream1981 prompt_spy = mocker.spy(prompt.__class__, "ainvoke")1982 chat_spy = mocker.spy(chat.__class__, "astream")1983 tracer = FakeTracer()1984 assert [1985 a1986 async for a in chain.astream(1987 {"question": "What is your name?"}, {"callbacks": [tracer]}1988 )1989 ] == [1990 _any_id_ai_message_chunk(content="f"),1991 _any_id_ai_message_chunk(content="o"),1992 _any_id_ai_message_chunk(content="o", chunk_position="last"),1993 ]1994 assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}1995 assert chat_spy.call_args.args[1] == ChatPromptValue(1996 messages=[1997 SystemMessage(content="You are a nice assistant."),1998 HumanMessage(content="What is your name?"),1999 ]2000 )
Findings
✓ No findings reported for this file.