libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py PYTHON 213 lines View on github.com → Search inside
"""Unit tests for verifying event dispatching.

Much of this code is indirectly tested already through many end-to-end tests
that generate traces based on the callbacks. The traces are all verified
via snapshot testing (e.g., see unit tests for runnables).
"""

import contextvars
from contextlib import asynccontextmanager
from typing import Any
from uuid import UUID

from typing_extensions import override

from langchain_core.callbacks import (
    AsyncCallbackHandler,
    AsyncCallbackManager,
    BaseCallbackHandler,
)


async def test_inline_handlers_share_parent_context() -> None:
    """Verify that handlers that are configured to run_inline can update parent context.

    This test was created because some of the inline handlers were getting
    their own context as the handling logic was kicked off using asyncio.gather
    which does not automatically propagate the parent context (by design).

    This issue was affecting only a few specific handlers:

    * on_llm_start
    * on_chat_model_start

    which in some cases were triggered with multiple prompts and as a result
    triggering multiple tasks that were launched in parallel.
    """
    some_var: contextvars.ContextVar[str] = contextvars.ContextVar("some_var")

    class CustomHandler(AsyncCallbackHandler):
        """A handler that sets the context variable.

        The handler sets the context variable to the name of the callback that was
        called.
        """

        def __init__(self, *, run_inline: bool) -> None:
            """Initialize the handler.

            Args:
                run_inline: Value stored on the handler's ``run_inline`` attribute.
            """
            self.run_inline = run_inline

        @override
        async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
            """Set ``some_var`` to the name of this callback."""
            some_var.set("on_llm_start")

    # The manager serves as a callback dispatcher.
    # It's responsible for dispatching callbacks to all registered handlers.
    manager = AsyncCallbackManager(handlers=[CustomHandler(run_inline=True)])

    # Check on_llm_start
    # Start from a known sentinel so the assertion below can only pass if the
    # handler's write is visible in this (the caller's) context.
    some_var.set("unset")
    await manager.on_llm_start({}, ["prompt 1"])
    assert some_var.get() == "on_llm_start"

    # Check what happens when run_inline is False
    # We don't expect the context to be updated
    manager2 = AsyncCallbackManager(
        handlers=[
            CustomHandler(run_inline=False),
        ]
    )

    some_var.set("unset")
    await manager2.on_llm_start({}, ["prompt 1"])
    # Will not be updated because the handler is not inline
    assert some_var.get() == "unset"


async def test_inline_handlers_share_parent_context_multiple() -> None:
    """A slightly more complex variation of the unit test above.

    This unit test verifies that things work correctly when there are multiple prompts,
    and multiple handlers that are configured to run inline.
    """
    counter_var = contextvars.ContextVar("counter", default=0)

    # Collects one entry per handler invocation, in the order the handlers ran.
    shared_stack = []

    @asynccontextmanager
    async def set_counter_var() -> Any:
        """Set ``counter_var`` to 0 for the enclosed block, restoring it on exit."""
        token = counter_var.set(0)
        try:
            yield
        finally:
            # Reset even if the body raised, so the previous value is restored.
            counter_var.reset(token)

    class StatefulAsyncCallbackHandler(AsyncCallbackHandler):
        """Handler whose ``on_llm_start`` behavior is selected by its ``name``."""

        def __init__(self, name: str, *, run_inline: bool = True) -> None:
            """Initialize the handler.

            Args:
                name: Selects the branch taken in ``on_llm_start``
                    (``"StateModifier"``, ``"StateReader"``, or anything else).
                run_inline: Value stored on the handler's ``run_inline`` attribute.
            """
            self.name = name
            self.run_inline = run_inline

        async def on_llm_start(
            self,
            serialized: dict[str, Any],
            prompts: list[str],
            *,
            run_id: UUID,
            parent_run_id: UUID | None = None,
            **kwargs: Any,
        ) -> None:
            """Append a state value to ``shared_stack`` and call the parent method.

            * ``"StateModifier"`` increments ``counter_var`` and records the new value.
            * ``"StateReader"`` records the current value of ``counter_var``.
            * Any other name records ``None``.
            """
            if self.name == "StateModifier":
                current_counter = counter_var.get()
                counter_var.set(current_counter + 1)
                state = counter_var.get()
            elif self.name == "StateReader":
                state = counter_var.get()
            else:
                state = None

            shared_stack.append(state)

            await super().on_llm_start(
                serialized,
                prompts,
                run_id=run_id,
                parent_run_id=parent_run_id,
                **kwargs,
            )

    handlers: list[BaseCallbackHandler] = [
        StatefulAsyncCallbackHandler("StateModifier", run_inline=True),
        StatefulAsyncCallbackHandler("StateReader", run_inline=True),
        StatefulAsyncCallbackHandler("NonInlineHandler", run_inline=False),
    ]

    prompts = ["Prompt1", "Prompt2", "Prompt3"]

    async with set_counter_var():
        shared_stack.clear()
        manager = AsyncCallbackManager(handlers=handlers)
        await manager.on_llm_start({}, prompts)

        # Assert the order of states
        # The non-inline handler always records None, so drop those entries and
        # compare only what the two inline handlers recorded.
        states = [entry for entry in shared_stack if entry is not None]
        # Expected: for each of the three prompts, the modifier records the
        # incremented counter and the reader then records that same value.
        assert states == [
            1,
            1,
            2,
            2,
            3,
            3,
        ]


async def test_shielded_callback_context_preservation() -> None:
    """Verify that shielded callbacks preserve context variables.

    This test specifically addresses the issue where async callbacks decorated
    with @shielded do not properly preserve context variables, breaking
    instrumentation and other context-dependent functionality.

    The issue manifests in callbacks that use the @shielded decorator:
    * on_llm_end
    * on_llm_error
    * on_chain_end
    * on_chain_error
    * And other shielded callback methods
    """
    context_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_context")

    class ContextTestHandler(AsyncCallbackHandler):
        """Handler that reads context variables in shielded callbacks."""

        def __init__(self) -> None:
            """Initialize a non-inline handler with no recorded context values."""
            self.run_inline = False
            self.context_values: list[str] = []

        @override
        async def on_llm_end(self, response: Any, **kwargs: Any) -> None:
            """This method is decorated with @shielded in the run manager."""
            # This should preserve the context variable value
            # ("not_found" is recorded instead if the variable is unset here.)
            self.context_values.append(context_var.get("not_found"))

        @override
        async def on_chain_end(self, outputs: Any, **kwargs: Any) -> None:
            """This method is decorated with @shielded in the run manager."""
            # This should preserve the context variable value
            # ("not_found" is recorded instead if the variable is unset here.)
            self.context_values.append(context_var.get("not_found"))

    # Set up the test context
    context_var.set("test_value")
    handler = ContextTestHandler()
    manager = AsyncCallbackManager(handlers=[handler])

    # Create run managers that have the shielded methods
    # One prompt is passed, so only the first returned run manager is used.
    llm_managers = await manager.on_llm_start({}, ["test prompt"])
    llm_run_manager = llm_managers[0]

    chain_run_manager = await manager.on_chain_start({}, {"test": "input"})

    # Test LLM end callback (which is shielded)
    await llm_run_manager.on_llm_end({"response": "test"})  # type: ignore[arg-type]

    # Test Chain end callback (which is shielded)
    await chain_run_manager.on_chain_end({"output": "test"})

    # The context should be preserved in shielded callbacks
    # This was the main issue - shielded decorators were not preserving context
    assert handler.context_values == ["test_value", "test_value"], (
        f"Expected context values ['test_value', 'test_value'], "
        f"but got {handler.context_values}. "
        f"This indicates the shielded decorator is not preserving context variables."
    )

Code quality findings 2

Ensure functions have docstrings for documentation
missing-docstring
async def set_counter_var() -> Any:
Ensure functions have docstrings for documentation
missing-docstring
async def on_llm_start(

Get this view in your editor

Same data, no extra tab — call code_get_file + code_get_findings over MCP from Claude/Cursor/Copilot.