Ensure functions have docstrings for documentation
def coro_with_context(
1"""Utility code for `Runnable` objects."""23from __future__ import annotations45import ast6import asyncio7import inspect8import sys9import textwrap1011# Cannot move to TYPE_CHECKING as Mapping and Sequence are needed at runtime by12# RunnableConfigurableFields.13from collections.abc import Mapping, Sequence # noqa: TC00314from functools import lru_cache15from inspect import signature16from itertools import groupby17from typing import (18 TYPE_CHECKING,19 Any,20 NamedTuple,21 Protocol,22 TypeGuard,23 TypeVar,24)2526from typing_extensions import override2728# Re-export create-model for backwards compatibility29from langchain_core.utils.pydantic import create_model # noqa: F4013031if TYPE_CHECKING:32 from collections.abc import (33 AsyncIterable,34 AsyncIterator,35 Awaitable,36 Callable,37 Coroutine,38 Iterable,39 )40 from contextvars import Context4142 from langchain_core.runnables.schema import StreamEvent4344Input = TypeVar("Input", contravariant=True) # noqa: PLC010545# Output type should implement __concat__, as eg str, list, dict do46Output = TypeVar("Output", covariant=True) # noqa: PLC0105474849async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:50 """Run a coroutine with a semaphore.5152 Args:53 semaphore: The semaphore to use.54 coro: The coroutine to run.5556 Returns:57 The result of the coroutine.58 """59 async with semaphore:60 return await coro616263async def gather_with_concurrency(n: int | None, *coros: Coroutine) -> list:64 """Gather coroutines with a limit on the number of concurrent coroutines.6566 Args:67 n: The number of coroutines to run concurrently.68 *coros: The coroutines to run.6970 Returns:71 The results of the coroutines.72 """73 if n is None:74 return await asyncio.gather(*coros)7576 semaphore = asyncio.Semaphore(n)7778 return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))798081def accepts_run_manager(callable: Callable[..., Any]) -> bool: # noqa: A00282 """Check if a callable accepts a run_manager argument.8384 Args:85 callable: The callable to check.8687 Returns:88 `True` if the callable accepts a run_manager argument, `False` otherwise.89 """90 try:91 return signature(callable).parameters.get("run_manager") is not None92 except ValueError:93 return False949596def accepts_config(callable: Callable[..., Any]) -> bool: # noqa: A00297 """Check if a callable accepts a config argument.9899 Args:100 callable: The callable to check.101102 Returns:103 `True` if the callable accepts a config argument, `False` otherwise.104 """105 try:106 return signature(callable).parameters.get("config") is not None107 except ValueError:108 return False109110111def accepts_context(callable: Callable[..., Any]) -> bool: # noqa: A002112 """Check if a callable accepts a context argument.113114 Args:115 callable: The callable to check.116117 Returns:118 `True` if the callable accepts a context argument, `False` otherwise.119 """120 try:121 return signature(callable).parameters.get("context") is not None122 except ValueError:123 return False124125126def asyncio_accepts_context() -> bool:127 """Check if asyncio.create_task accepts a `context` arg.128129 Returns:130 True if `asyncio.create_task` accepts a context argument, `False` otherwise.131 """132 return sys.version_info >= (3, 11)133134135_T = TypeVar("_T")136137138def coro_with_context(139 coro: Awaitable[_T], context: Context, *, create_task: bool = False140) -> Awaitable[_T]:141 """Await a coroutine with a context.142143 Args:144 coro: The coroutine to await.145 context: The context to use.146 create_task: Whether to create a task.147148 Returns:149 The coroutine with the context.150 """151 if asyncio_accepts_context():152 return asyncio.create_task(coro, context=context) # type: ignore[arg-type,call-arg,unused-ignore]153 if create_task:154 return asyncio.create_task(coro) # type: ignore[arg-type]155 return coro156157158class IsLocalDict(ast.NodeVisitor):159 """Check if a name is a local dict."""160161 def __init__(self, name: str, keys: set[str]) -> None:162 """Initialize the visitor.163164 Args:165 name: The name to check.166 keys: The keys to populate.167 """168 self.name = name169 self.keys = keys170171 @override172 def visit_Subscript(self, node: ast.Subscript) -> None:173 """Visit a subscript node.174175 Args:176 node: The node to visit.177 """178 if (179 isinstance(node.ctx, ast.Load)180 and isinstance(node.value, ast.Name)181 and node.value.id == self.name182 and isinstance(node.slice, ast.Constant)183 and isinstance(node.slice.value, str)184 ):185 # we've found a subscript access on the name we're looking for186 self.keys.add(node.slice.value)187188 @override189 def visit_Call(self, node: ast.Call) -> None:190 """Visit a call node.191192 Args:193 node: The node to visit.194 """195 if (196 isinstance(node.func, ast.Attribute)197 and isinstance(node.func.value, ast.Name)198 and node.func.value.id == self.name199 and node.func.attr == "get"200 and len(node.args) in {1, 2}201 and isinstance(node.args[0], ast.Constant)202 and isinstance(node.args[0].value, str)203 ):204 # we've found a .get() call on the name we're looking for205 self.keys.add(node.args[0].value)206207208class IsFunctionArgDict(ast.NodeVisitor):209 """Check if the first argument of a function is a dict."""210211 def __init__(self) -> None:212 """Create a IsFunctionArgDict visitor."""213 self.keys: set[str] = set()214215 @override216 def visit_Lambda(self, node: ast.Lambda) -> None:217 """Visit a lambda function.218219 Args:220 node: The node to visit.221 """222 if not node.args.args:223 return224 input_arg_name = node.args.args[0].arg225 IsLocalDict(input_arg_name, self.keys).visit(node.body)226227 @override228 def visit_FunctionDef(self, node: ast.FunctionDef) -> None:229 """Visit a function definition.230231 Args:232 node: The node to visit.233 """234 if not node.args.args:235 return236 input_arg_name = node.args.args[0].arg237 IsLocalDict(input_arg_name, self.keys).visit(node)238239 @override240 def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:241 """Visit an async function definition.242243 Args:244 node: The node to visit.245 """246 if not node.args.args:247 return248 input_arg_name = node.args.args[0].arg249 IsLocalDict(input_arg_name, self.keys).visit(node)250251252class NonLocals(ast.NodeVisitor):253 """Get nonlocal variables accessed."""254255 def __init__(self) -> None:256 """Create a NonLocals visitor."""257 self.loads: set[str] = set()258 self.stores: set[str] = set()259260 @override261 def visit_Name(self, node: ast.Name) -> None:262 """Visit a name node.263264 Args:265 node: The node to visit.266 """267 if isinstance(node.ctx, ast.Load):268 self.loads.add(node.id)269 elif isinstance(node.ctx, ast.Store):270 self.stores.add(node.id)271272 @override273 def visit_Attribute(self, node: ast.Attribute) -> None:274 """Visit an attribute node.275276 Args:277 node: The node to visit.278 """279 if isinstance(node.ctx, ast.Load):280 parent = node.value281 attr_expr = node.attr282 while isinstance(parent, ast.Attribute):283 attr_expr = parent.attr + "." + attr_expr284 parent = parent.value285 if isinstance(parent, ast.Name):286 self.loads.add(parent.id + "." + attr_expr)287 self.loads.discard(parent.id)288 elif isinstance(parent, ast.Call):289 if isinstance(parent.func, ast.Name):290 self.loads.add(parent.func.id)291 else:292 parent = parent.func293 attr_expr = ""294 while isinstance(parent, ast.Attribute):295 if attr_expr:296 attr_expr = parent.attr + "." + attr_expr297 else:298 attr_expr = parent.attr299 parent = parent.value300 if isinstance(parent, ast.Name):301 self.loads.add(parent.id + "." + attr_expr)302303304class FunctionNonLocals(ast.NodeVisitor):305 """Get the nonlocal variables accessed of a function."""306307 def __init__(self) -> None:308 """Create a FunctionNonLocals visitor."""309 self.nonlocals: set[str] = set()310311 @override312 def visit_FunctionDef(self, node: ast.FunctionDef) -> None:313 """Visit a function definition.314315 Args:316 node: The node to visit.317 """318 visitor = NonLocals()319 visitor.visit(node)320 self.nonlocals.update(visitor.loads - visitor.stores)321322 @override323 def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:324 """Visit an async function definition.325326 Args:327 node: The node to visit.328 """329 visitor = NonLocals()330 visitor.visit(node)331 self.nonlocals.update(visitor.loads - visitor.stores)332333 @override334 def visit_Lambda(self, node: ast.Lambda) -> None:335 """Visit a lambda function.336337 Args:338 node: The node to visit.339 """340 visitor = NonLocals()341 visitor.visit(node)342 self.nonlocals.update(visitor.loads - visitor.stores)343344345class GetLambdaSource(ast.NodeVisitor):346 """Get the source code of a lambda function."""347348 def __init__(self) -> None:349 """Initialize the visitor."""350 self.source: str | None = None351 self.count = 0352353 @override354 def visit_Lambda(self, node: ast.Lambda) -> None:355 """Visit a lambda function.356357 Args:358 node: The node to visit.359 """360 self.count += 1361 if hasattr(ast, "unparse"):362 self.source = ast.unparse(node)363364365def get_function_first_arg_dict_keys(func: Callable) -> list[str] | None:366 """Get the keys of the first argument of a function if it is a dict.367368 Args:369 func: The function to check.370371 Returns:372 The keys of the first argument if it is a dict, None otherwise.373 """374 try:375 code = inspect.getsource(func)376 tree = ast.parse(textwrap.dedent(code))377 visitor = IsFunctionArgDict()378 visitor.visit(tree)379 return sorted(visitor.keys) if visitor.keys else None380 except (SyntaxError, TypeError, OSError, SystemError):381 return None382383384def get_lambda_source(func: Callable) -> str | None:385 """Get the source code of a lambda function.386387 Args:388 func: a Callable that can be a lambda function.389390 Returns:391 the source code of the lambda function.392 """393 try:394 name = func.__name__ if func.__name__ != "<lambda>" else None395 except AttributeError:396 name = None397 try:398 code = inspect.getsource(func)399 tree = ast.parse(textwrap.dedent(code))400 visitor = GetLambdaSource()401 visitor.visit(tree)402 except (SyntaxError, TypeError, OSError, SystemError):403 return name404 return visitor.source if visitor.count == 1 else name405406407@lru_cache(maxsize=256)408def get_function_nonlocals(func: Callable) -> list[Any]:409 """Get the nonlocal variables accessed by a function.410411 Args:412 func: The function to check.413414 Returns:415 The nonlocal variables accessed by the function.416 """417 try:418 code = inspect.getsource(func)419 tree = ast.parse(textwrap.dedent(code))420 visitor = FunctionNonLocals()421 visitor.visit(tree)422 values: list[Any] = []423 closure = (424 inspect.getclosurevars(func.__wrapped__)425 if hasattr(func, "__wrapped__") and callable(func.__wrapped__)426 else inspect.getclosurevars(func)427 )428 candidates = {**closure.globals, **closure.nonlocals}429 for k, v in candidates.items():430 if k in visitor.nonlocals:431 values.append(v)432 for kk in visitor.nonlocals:433 if "." in kk and kk.startswith(k):434 vv = v435 for part in kk.split(".")[1:]:436 if vv is None:437 break438 try:439 vv = getattr(vv, part)440 except AttributeError:441 break442 else:443 values.append(vv)444 except (SyntaxError, TypeError, OSError, SystemError):445 return []446447 return values448449450def indent_lines_after_first(text: str, prefix: str) -> str:451 """Indent all lines of text after the first line.452453 Args:454 text: The text to indent.455 prefix: Used to determine the number of spaces to indent.456457 Returns:458 The indented text.459 """460 n_spaces = len(prefix)461 spaces = " " * n_spaces462 lines = text.splitlines()463 return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])464465466class AddableDict(dict[str, Any]):467 """Dictionary that can be added to another dictionary."""468469 def __add__(self, other: AddableDict) -> AddableDict:470 """Add a dictionary to this dictionary.471472 Args:473 other: The other dictionary to add.474475 Returns:476 A dictionary that is the result of adding the two dictionaries.477 """478 chunk = AddableDict(self)479 for key in other:480 if key not in chunk or chunk[key] is None:481 chunk[key] = other[key]482 elif other[key] is not None:483 try:484 added = chunk[key] + other[key]485 except TypeError:486 added = other[key]487 chunk[key] = added488 return chunk489490 def __radd__(self, other: AddableDict) -> AddableDict:491 """Add this dictionary to another dictionary.492493 Args:494 other: The other dictionary to be added to.495496 Returns:497 A dictionary that is the result of adding the two dictionaries.498 """499 chunk = AddableDict(other)500 for key in self:501 if key not in chunk or chunk[key] is None:502 chunk[key] = self[key]503 elif self[key] is not None:504 try:505 added = chunk[key] + self[key]506 except TypeError:507 added = self[key]508 chunk[key] = added509 return chunk510511512_T_co = TypeVar("_T_co", covariant=True)513_T_contra = TypeVar("_T_contra", contravariant=True)514515516class SupportsAdd(Protocol[_T_contra, _T_co]):517 """Protocol for objects that support addition."""518519 def __add__(self, x: _T_contra, /) -> _T_co:520 """Add the object to another object."""521522523Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])524525526def add(addables: Iterable[Addable]) -> Addable | None:527 """Add a sequence of addable objects together.528529 Args:530 addables: The addable objects to add.531532 Returns:533 The result of adding the addable objects.534 """535 final: Addable | None = None536 for chunk in addables:537 final = chunk if final is None else final + chunk538 return final539540541async def aadd(addables: AsyncIterable[Addable]) -> Addable | None:542 """Asynchronously add a sequence of addable objects together.543544 Args:545 addables: The addable objects to add.546547 Returns:548 The result of adding the addable objects.549 """550 final: Addable | None = None551 async for chunk in addables:552 final = chunk if final is None else final + chunk553 return final554555556class ConfigurableField(NamedTuple):557 """Field that can be configured by the user."""558559 id: str560 """The unique identifier of the field."""561562 name: str | None = None563 """The name of the field. """564565 description: str | None = None566 """The description of the field. """567568 annotation: Any | None = None569 """The annotation of the field. """570571 is_shared: bool = False572 """Whether the field is shared."""573574 @override575 def __hash__(self) -> int:576 return hash((self.id, self.annotation))577578579class ConfigurableFieldSingleOption(NamedTuple):580 """Field that can be configured by the user with a default value."""581582 id: str583 """The unique identifier of the field."""584585 options: Mapping[str, Any]586 """The options for the field."""587588 default: str589 """The default value for the field."""590591 name: str | None = None592 """The name of the field. """593594 description: str | None = None595 """The description of the field. """596597 is_shared: bool = False598 """Whether the field is shared."""599600 @override601 def __hash__(self) -> int:602 return hash((self.id, tuple(self.options.keys()), self.default))603604605class ConfigurableFieldMultiOption(NamedTuple):606 """Field that can be configured by the user with multiple default values."""607608 id: str609 """The unique identifier of the field."""610611 options: Mapping[str, Any]612 """The options for the field."""613614 default: Sequence[str]615 """The default values for the field."""616617 name: str | None = None618 """The name of the field. """619620 description: str | None = None621 """The description of the field. """622623 is_shared: bool = False624 """Whether the field is shared."""625626 @override627 def __hash__(self) -> int:628 return hash((self.id, tuple(self.options.keys()), tuple(self.default)))629630631AnyConfigurableField = (632 ConfigurableField | ConfigurableFieldSingleOption | ConfigurableFieldMultiOption633)634635636class ConfigurableFieldSpec(NamedTuple):637 """Field that can be configured by the user. It is a specification of a field."""638639 id: str640 """The unique identifier of the field."""641642 annotation: Any643 """The annotation of the field."""644645 name: str | None = None646 """The name of the field. """647648 description: str | None = None649 """The description of the field. """650651 default: Any = None652 """The default value for the field. """653654 is_shared: bool = False655 """Whether the field is shared."""656657 dependencies: list[str] | None = None658 """The dependencies of the field. """659660661def get_unique_config_specs(662 specs: Iterable[ConfigurableFieldSpec],663) -> list[ConfigurableFieldSpec]:664 """Get the unique config specs from a sequence of config specs.665666 Args:667 specs: The config specs.668669 Returns:670 The unique config specs.671672 Raises:673 ValueError: If the runnable sequence contains conflicting config specs.674 """675 grouped = groupby(676 sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id677 )678 unique: list[ConfigurableFieldSpec] = []679 for spec_id, dupes in grouped:680 first = next(dupes)681 others = list(dupes)682 if len(others) == 0 or all(o == first for o in others):683 unique.append(first)684 else:685 msg = (686 "RunnableSequence contains conflicting config specs"687 f"for {spec_id}: {[first, *others]}"688 )689 raise ValueError(msg)690 return unique691692693class _RootEventFilter:694 def __init__(695 self,696 *,697 include_names: Sequence[str] | None = None,698 include_types: Sequence[str] | None = None,699 include_tags: Sequence[str] | None = None,700 exclude_names: Sequence[str] | None = None,701 exclude_types: Sequence[str] | None = None,702 exclude_tags: Sequence[str] | None = None,703 ) -> None:704 """Utility to filter the root event in the astream_events implementation.705706 This is simply binding the arguments to the namespace to make save on707 a bit of typing in the astream_events implementation.708 """709 self.include_names = include_names710 self.include_types = include_types711 self.include_tags = include_tags712 self.exclude_names = exclude_names713 self.exclude_types = exclude_types714 self.exclude_tags = exclude_tags715716 def include_event(self, event: StreamEvent, root_type: str) -> bool:717 """Determine whether to include an event."""718 if (719 self.include_names is None720 and self.include_types is None721 and self.include_tags is None722 ):723 include = True724 else:725 include = False726727 event_tags = event.get("tags") or []728729 if self.include_names is not None:730 include = include or event["name"] in self.include_names731 if self.include_types is not None:732 include = include or root_type in self.include_types733 if self.include_tags is not None:734 include = include or any(tag in self.include_tags for tag in event_tags)735736 if self.exclude_names is not None:737 include = include and event["name"] not in self.exclude_names738 if self.exclude_types is not None:739 include = include and root_type not in self.exclude_types740 if self.exclude_tags is not None:741 include = include and all(742 tag not in self.exclude_tags for tag in event_tags743 )744745 return include746747748def is_async_generator(749 func: Any,750) -> TypeGuard[Callable[..., AsyncIterator]]:751 """Check if a function is an async generator.752753 Args:754 func: The function to check.755756 Returns:757 `True` if the function is an async generator, `False` otherwise.758 """759 return inspect.isasyncgenfunction(func) or (760 hasattr(func, "__call__") # noqa: B004761 and inspect.isasyncgenfunction(func.__call__)762 )763764765def is_async_callable(766 func: Any,767) -> TypeGuard[Callable[..., Awaitable]]:768 """Check if a function is async.769770 Args:771 func: The function to check.772773 Returns:774 `True` if the function is async, `False` otherwise.775 """776 return asyncio.iscoroutinefunction(func) or (777 hasattr(func, "__call__") # noqa: B004778 and asyncio.iscoroutinefunction(func.__call__)779 )
Same data, no extra tab — call code_get_file + code_get_findings over MCP from Claude/Cursor/Copilot.