Ensure functions have docstrings for documentation
"""Unit tests for verifying event dispatching.

Much of this code is indirectly tested already through many end-to-end tests
that generate traces based on the callbacks. The traces are all verified
via snapshot testing (e.g., see unit tests for runnables).
"""

import contextvars
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any
from uuid import UUID

from typing_extensions import override

from langchain_core.callbacks import (
    AsyncCallbackHandler,
    AsyncCallbackManager,
    BaseCallbackHandler,
)


async def test_inline_handlers_share_parent_context() -> None:
    """Verify that handlers that are configured to run_inline can update parent context.

    This test was created because some of the inline handlers were getting
    their own context as the handling logic was kicked off using asyncio.gather
    which does not automatically propagate the parent context (by design).

    This issue was affecting only a few specific handlers:

    * on_llm_start
    * on_chat_model_start

    which in some cases were triggered with multiple prompts and as a result
    triggering multiple tasks that were launched in parallel.
    """
    some_var: contextvars.ContextVar[str] = contextvars.ContextVar("some_var")

    class CustomHandler(AsyncCallbackHandler):
        """A handler that sets the context variable.

        The handler sets the context variable to the name of the callback that was
        called.
        """

        def __init__(self, *, run_inline: bool) -> None:
            """Initialize the handler.

            Args:
                run_inline: Whether the manager should run this handler inline.
            """
            self.run_inline = run_inline

        @override
        async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
            """Set ``some_var`` to the name of this callback."""
            some_var.set("on_llm_start")

    # The manager serves as a callback dispatcher.
    # It's responsible for dispatching callbacks to all registered handlers.
    manager = AsyncCallbackManager(handlers=[CustomHandler(run_inline=True)])

    # Check on_llm_start
    some_var.set("unset")
    await manager.on_llm_start({}, ["prompt 1"])
    assert some_var.get() == "on_llm_start"

    # Check what happens when run_inline is False
    # We don't expect the context to be updated
    manager2 = AsyncCallbackManager(
        handlers=[
            CustomHandler(run_inline=False),
        ]
    )

    some_var.set("unset")
    await manager2.on_llm_start({}, ["prompt 1"])
    # Will not be updated because the handler is not inline
    assert some_var.get() == "unset"


async def test_inline_handlers_share_parent_context_multiple() -> None:
    """A slightly more complex variation of the unit test above.

    This unit test verifies that things work correctly when there are multiple prompts,
    and multiple handlers that are configured to run inline.
    """
    counter_var: contextvars.ContextVar[int] = contextvars.ContextVar(
        "counter", default=0
    )

    # Records, in dispatch order, the counter value each handler observed
    # (None for handlers that do not touch the counter).
    shared_stack: list[int | None] = []

    @asynccontextmanager
    async def set_counter_var() -> AsyncIterator[None]:
        """Set ``counter_var`` to 0 for the duration of the block, then restore it."""
        token = counter_var.set(0)
        try:
            yield
        finally:
            counter_var.reset(token)

    class StatefulAsyncCallbackHandler(AsyncCallbackHandler):
        """Handler that records the state of ``counter_var`` into ``shared_stack``.

        A handler named ``StateModifier`` increments the counter before recording
        it, ``StateReader`` records it unchanged, and any other name records None.
        """

        def __init__(self, name: str, *, run_inline: bool = True) -> None:
            """Initialize the handler.

            Args:
                name: Selects the behavior of ``on_llm_start`` (see class docstring).
                run_inline: Whether the manager should run this handler inline.
            """
            self.name = name
            self.run_inline = run_inline

        @override
        async def on_llm_start(
            self,
            serialized: dict[str, Any],
            prompts: list[str],
            *,
            run_id: UUID,
            parent_run_id: UUID | None = None,
            **kwargs: Any,
        ) -> None:
            """Update and/or read ``counter_var`` per ``self.name``, then record it."""
            state: int | None
            if self.name == "StateModifier":
                current_counter = counter_var.get()
                counter_var.set(current_counter + 1)
                state = counter_var.get()
            elif self.name == "StateReader":
                state = counter_var.get()
            else:
                state = None

            shared_stack.append(state)

            await super().on_llm_start(
                serialized,
                prompts,
                run_id=run_id,
                parent_run_id=parent_run_id,
                **kwargs,
            )

    handlers: list[BaseCallbackHandler] = [
        StatefulAsyncCallbackHandler("StateModifier", run_inline=True),
        StatefulAsyncCallbackHandler("StateReader", run_inline=True),
        StatefulAsyncCallbackHandler("NonInlineHandler", run_inline=False),
    ]

    prompts = ["Prompt1", "Prompt2", "Prompt3"]

    async with set_counter_var():
        shared_stack.clear()
        manager = AsyncCallbackManager(handlers=handlers)
        await manager.on_llm_start({}, prompts)

    # Assert the order of states: the modifier's increment for each prompt must be
    # visible to the reader and to the next prompt, i.e. inline handlers share the
    # caller's context rather than each getting a private copy.
    states = [entry for entry in shared_stack if entry is not None]
    assert states == [
        1,
        1,
        2,
        2,
        3,
        3,
    ]


async def test_shielded_callback_context_preservation() -> None:
    """Verify that shielded callbacks preserve context variables.

    This test specifically addresses the issue where async callbacks decorated
    with @shielded do not properly preserve context variables, breaking
    instrumentation and other context-dependent functionality.

    The issue manifests in callbacks that use the @shielded decorator:
    * on_llm_end
    * on_llm_error
    * on_chain_end
    * on_chain_error
    * And other shielded callback methods
    """
    context_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_context")

    class ContextTestHandler(AsyncCallbackHandler):
        """Handler that reads context variables in shielded callbacks."""

        def __init__(self) -> None:
            """Initialize a non-inline handler with an empty list of observed values."""
            self.run_inline = False
            self.context_values: list[str] = []

        @override
        async def on_llm_end(self, response: Any, **kwargs: Any) -> None:
            """Record the ``context_var`` value seen via the shielded on_llm_end path."""
            # Falls back to "not_found" if the caller's context was not propagated.
            self.context_values.append(context_var.get("not_found"))

        @override
        async def on_chain_end(self, outputs: Any, **kwargs: Any) -> None:
            """Record the ``context_var`` value seen via the shielded on_chain_end path."""
            # Falls back to "not_found" if the caller's context was not propagated.
            self.context_values.append(context_var.get("not_found"))

    # Set up the test context
    context_var.set("test_value")
    handler = ContextTestHandler()
    manager = AsyncCallbackManager(handlers=[handler])

    # Create run managers that have the shielded methods
    llm_managers = await manager.on_llm_start({}, ["test prompt"])
    llm_run_manager = llm_managers[0]

    chain_run_manager = await manager.on_chain_start({}, {"test": "input"})

    # Test LLM end callback (which is shielded)
    await llm_run_manager.on_llm_end({"response": "test"})  # type: ignore[arg-type]

    # Test Chain end callback (which is shielded)
    await chain_run_manager.on_chain_end({"output": "test"})

    # The context should be preserved in shielded callbacks
    # This was the main issue - shielded decorators were not preserving context
    assert handler.context_values == ["test_value", "test_value"], (
        "Expected context values ['test_value', 'test_value'], "
        f"but got {handler.context_values}. "
        "This indicates the shielded decorator is not preserving context variables."
    )
Same data, no extra tab — call code_get_file + code_get_findings over MCP from Claude/Cursor/Copilot.