libs/core/langchain_core/runnables/utils.py PYTHON 780 lines View on github.com → Search inside
"""Utility code for `Runnable` objects."""

from __future__ import annotations

import ast
import asyncio
import inspect
import sys
import textwrap

# Cannot move to TYPE_CHECKING as Mapping and Sequence are needed at runtime by
# RunnableConfigurableFields.
from collections.abc import Mapping, Sequence  # noqa: TC003
from functools import lru_cache
from inspect import signature
from itertools import groupby
from typing import (
    TYPE_CHECKING,
    Any,
    NamedTuple,
    Protocol,
    TypeGuard,
    TypeVar,
)

from typing_extensions import override

# Re-export create-model for backwards compatibility
from langchain_core.utils.pydantic import create_model  # noqa: F401

if TYPE_CHECKING:
    from collections.abc import (
        AsyncIterable,
        AsyncIterator,
        Awaitable,
        Callable,
        Coroutine,
        Iterable,
    )
    from contextvars import Context

    from langchain_core.runnables.schema import StreamEvent

Input = TypeVar("Input", contravariant=True)  # noqa: PLC0105
# Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output", covariant=True)  # noqa: PLC0105


async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
    """Run a coroutine with a semaphore.

    Args:
        semaphore: The semaphore to use.
        coro: The coroutine to run.

    Returns:
        The result of the coroutine.
    """
    async with semaphore:
        return await coro


async def gather_with_concurrency(n: int | None, *coros: Coroutine) -> list:
    """Gather coroutines with a limit on the number of concurrent coroutines.

    Args:
        n: The number of coroutines to run concurrently. `None` means no limit.
        *coros: The coroutines to run.

    Returns:
        The results of the coroutines, in the order the coroutines were given.
    """
    if n is None:
        return await asyncio.gather(*coros)

    semaphore = asyncio.Semaphore(n)

    return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))


def accepts_run_manager(callable: Callable[..., Any]) -> bool:  # noqa: A002
    """Check if a callable accepts a run_manager argument.

    Args:
        callable: The callable to check.

    Returns:
        `True` if the callable accepts a run_manager argument, `False` otherwise.
    """
    try:
        return signature(callable).parameters.get("run_manager") is not None
    except ValueError:
        # signature() raises ValueError when no signature can be provided.
        return False


def accepts_config(callable: Callable[..., Any]) -> bool:  # noqa: A002
    """Check if a callable accepts a config argument.

    Args:
        callable: The callable to check.

    Returns:
        `True` if the callable accepts a config argument, `False` otherwise.
    """
    try:
        return signature(callable).parameters.get("config") is not None
    except ValueError:
        return False


def accepts_context(callable: Callable[..., Any]) -> bool:  # noqa: A002
    """Check if a callable accepts a context argument.

    Args:
        callable: The callable to check.

    Returns:
        `True` if the callable accepts a context argument, `False` otherwise.
    """
    try:
        return signature(callable).parameters.get("context") is not None
    except ValueError:
        return False


def asyncio_accepts_context() -> bool:
    """Check if asyncio.create_task accepts a `context` arg.

    Returns:
        True if `asyncio.create_task` accepts a context argument, `False` otherwise.
    """
    return sys.version_info >= (3, 11)


_T = TypeVar("_T")


def coro_with_context(
    coro: Awaitable[_T], context: Context, *, create_task: bool = False
) -> Awaitable[_T]:
    """Await a coroutine with a context.

    When `asyncio.create_task` supports the `context` argument, the coroutine is
    always wrapped in a task bound to `context`. Otherwise the context cannot be
    applied; a plain task is created if `create_task` is set, and the bare
    coroutine is returned if not.

    Args:
        coro: The coroutine to await.
        context: The context to use.
        create_task: Whether to create a task.

    Returns:
        The coroutine with the context.
    """
    if asyncio_accepts_context():
        return asyncio.create_task(coro, context=context)  # type: ignore[arg-type,call-arg,unused-ignore]
    if create_task:
        return asyncio.create_task(coro)  # type: ignore[arg-type]
    return coro


class IsLocalDict(ast.NodeVisitor):
    """Collect the string keys read from a given name via `[...]` or `.get()`."""

    def __init__(self, name: str, keys: set[str]) -> None:
        """Initialize the visitor.

        Args:
            name: The name to check.
            keys: The keys to populate (mutated in place).
        """
        self.name = name
        self.keys = keys

    @override
    def visit_Subscript(self, node: ast.Subscript) -> None:
        """Visit a subscript node.

        Records `name["key"]` loads, then keeps walking so accesses nested in
        the subscript's children (e.g. `name["a"]["b"]`) are also seen.

        Args:
            node: The node to visit.
        """
        if (
            isinstance(node.ctx, ast.Load)
            and isinstance(node.value, ast.Name)
            and node.value.id == self.name
            and isinstance(node.slice, ast.Constant)
            and isinstance(node.slice.value, str)
        ):
            # we've found a subscript access on the name we're looking for
            self.keys.add(node.slice.value)
        self.generic_visit(node)

    @override
    def visit_Call(self, node: ast.Call) -> None:
        """Visit a call node.

        Records `name.get("key")` / `name.get("key", default)` calls, then keeps
        walking so accesses inside call arguments or the callee expression
        (e.g. `str(name["a"])`, `name["a"].upper()`) are also seen.

        Args:
            node: The node to visit.
        """
        if (
            isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == self.name
            and node.func.attr == "get"
            and len(node.args) in {1, 2}
            and isinstance(node.args[0], ast.Constant)
            and isinstance(node.args[0].value, str)
        ):
            # we've found a .get() call on the name we're looking for
            self.keys.add(node.args[0].value)
        self.generic_visit(node)


class IsFunctionArgDict(ast.NodeVisitor):
    """Collect the dict keys read from the first argument of a function."""

    def __init__(self) -> None:
        """Create a IsFunctionArgDict visitor."""
        self.keys: set[str] = set()

    @override
    def visit_Lambda(self, node: ast.Lambda) -> None:
        """Visit a lambda function.

        Args:
            node: The node to visit.
        """
        if not node.args.args:
            return
        input_arg_name = node.args.args[0].arg
        IsLocalDict(input_arg_name, self.keys).visit(node.body)

    @override
    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Visit a function definition.

        Args:
            node: The node to visit.
        """
        if not node.args.args:
            return
        input_arg_name = node.args.args[0].arg
        IsLocalDict(input_arg_name, self.keys).visit(node)

    @override
    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Visit an async function definition.

        Args:
            node: The node to visit.
        """
        if not node.args.args:
            return
        input_arg_name = node.args.args[0].arg
        IsLocalDict(input_arg_name, self.keys).visit(node)


class NonLocals(ast.NodeVisitor):
    """Get nonlocal variables accessed."""

    def __init__(self) -> None:
        """Create a NonLocals visitor."""
        self.loads: set[str] = set()
        self.stores: set[str] = set()

    @override
    def visit_Name(self, node: ast.Name) -> None:
        """Visit a name node.

        Args:
            node: The node to visit.
        """
        if isinstance(node.ctx, ast.Load):
            self.loads.add(node.id)
        elif isinstance(node.ctx, ast.Store):
            self.stores.add(node.id)

    @override
    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Visit an attribute node.

        A load of `a.b.c` is recorded as the dotted string "a.b.c", and the bare
        root name "a" is discarded from the loads. For `f(...).x` only "f" is
        recorded; for `a.m(...).x` the call target "a.m" is recorded.

        Args:
            node: The node to visit.
        """
        # NOTE(review): children are not visited further, so names used only as
        # call arguments inside an attribute chain (e.g. `a.m(b).x`) are not
        # recorded.
        if isinstance(node.ctx, ast.Load):
            parent = node.value
            attr_expr = node.attr
            # Walk down `a.b.c` to its root, building the dotted suffix.
            while isinstance(parent, ast.Attribute):
                attr_expr = parent.attr + "." + attr_expr
                parent = parent.value
            if isinstance(parent, ast.Name):
                self.loads.add(parent.id + "." + attr_expr)
                self.loads.discard(parent.id)
            elif isinstance(parent, ast.Call):
                if isinstance(parent.func, ast.Name):
                    self.loads.add(parent.func.id)
                else:
                    # Record the dotted path of the called object instead.
                    parent = parent.func
                    attr_expr = ""
                    while isinstance(parent, ast.Attribute):
                        if attr_expr:
                            attr_expr = parent.attr + "." + attr_expr
                        else:
                            attr_expr = parent.attr
                        parent = parent.value
                    if isinstance(parent, ast.Name):
                        self.loads.add(parent.id + "." + attr_expr)


class FunctionNonLocals(ast.NodeVisitor):
    """Get the nonlocal variables accessed of a function."""

    def __init__(self) -> None:
        """Create a FunctionNonLocals visitor."""
        self.nonlocals: set[str] = set()

    @override
    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Visit a function definition.

        Args:
            node: The node to visit.
        """
        visitor = NonLocals()
        visitor.visit(node)
        # Names read but never assigned inside the function come from outside.
        self.nonlocals.update(visitor.loads - visitor.stores)

    @override
    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Visit an async function definition.

        Args:
            node: The node to visit.
        """
        visitor = NonLocals()
        visitor.visit(node)
        self.nonlocals.update(visitor.loads - visitor.stores)

    @override
    def visit_Lambda(self, node: ast.Lambda) -> None:
        """Visit a lambda function.

        Args:
            node: The node to visit.
        """
        visitor = NonLocals()
        visitor.visit(node)
        self.nonlocals.update(visitor.loads - visitor.stores)


class GetLambdaSource(ast.NodeVisitor):
    """Get the source code of a lambda function."""

    def __init__(self) -> None:
        """Initialize the visitor."""
        self.source: str | None = None
        self.count = 0

    @override
    def visit_Lambda(self, node: ast.Lambda) -> None:
        """Visit a lambda function.

        Counts every lambda seen and keeps the unparsed source of the last one.

        Args:
            node: The node to visit.
        """
        self.count += 1
        if hasattr(ast, "unparse"):
            self.source = ast.unparse(node)


def get_function_first_arg_dict_keys(func: Callable) -> list[str] | None:
    """Get the keys of the first argument of a function if it is a dict.

    Args:
        func: The function to check.

    Returns:
        The sorted keys read from the first argument, or None if none were found
        or the source could not be retrieved or parsed.
    """
    try:
        code = inspect.getsource(func)
        tree = ast.parse(textwrap.dedent(code))
        visitor = IsFunctionArgDict()
        visitor.visit(tree)
        return sorted(visitor.keys) if visitor.keys else None
    except (SyntaxError, TypeError, OSError, SystemError):
        return None


def get_lambda_source(func: Callable) -> str | None:
    """Get the source code of a lambda function.

    Args:
        func: a Callable that can be a lambda function.

    Returns:
        The source code of the lambda function if exactly one lambda is found in
        its source; otherwise the callable's `__name__` (None for `<lambda>` or
        when there is no `__name__`).
    """
    try:
        name = func.__name__ if func.__name__ != "<lambda>" else None
    except AttributeError:
        name = None
    try:
        code = inspect.getsource(func)
        tree = ast.parse(textwrap.dedent(code))
        visitor = GetLambdaSource()
        visitor.visit(tree)
    except (SyntaxError, TypeError, OSError, SystemError):
        return name
    # Several lambdas on the same source line make the match ambiguous.
    return visitor.source if visitor.count == 1 else name


@lru_cache(maxsize=256)
def get_function_nonlocals(func: Callable) -> list[Any]:
    """Get the nonlocal variables accessed by a function.

    Plain names (`foo`) and dotted attribute paths (`foo.bar`) read by the
    function are resolved against its closure variables and globals.

    Args:
        func: The function to check.

    Returns:
        The nonlocal variables accessed by the function, or an empty list if the
        source cannot be retrieved or parsed.
    """
    try:
        code = inspect.getsource(func)
        tree = ast.parse(textwrap.dedent(code))
        visitor = FunctionNonLocals()
        visitor.visit(tree)
        values: list[Any] = []
        closure = (
            inspect.getclosurevars(func.__wrapped__)
            if hasattr(func, "__wrapped__") and callable(func.__wrapped__)
            else inspect.getclosurevars(func)
        )
        candidates = {**closure.globals, **closure.nonlocals}
        for k, v in candidates.items():
            if k in visitor.nonlocals:
                values.append(v)
            for kk in visitor.nonlocals:
                # Only dotted paths rooted exactly at `k`. A bare
                # `kk.startswith(k)` would also match "foobar.x" for k="foo"
                # and resolve `.x` on the wrong object.
                if kk.startswith(k + "."):
                    vv = v
                    for part in kk.split(".")[1:]:
                        if vv is None:
                            break
                        try:
                            vv = getattr(vv, part)
                        except AttributeError:
                            break
                    else:
                        # Only append when the whole attribute path resolved.
                        values.append(vv)
    except (SyntaxError, TypeError, OSError, SystemError):
        return []

    return values


def indent_lines_after_first(text: str, prefix: str) -> str:
    """Indent all lines of text after the first line.

    Args:
        text: The text to indent.
        prefix: Used to determine the number of spaces to indent.

    Returns:
        The indented text. Empty text is returned unchanged.
    """
    n_spaces = len(prefix)
    spaces = " " * n_spaces
    lines = text.splitlines()
    if not lines:
        # "".splitlines() is [], so there is no first line to keep.
        return text
    return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])


class AddableDict(dict[str, Any]):
    """Dictionary that can be added to another dictionary."""

    def __add__(self, other: AddableDict) -> AddableDict:
        """Add a dictionary to this dictionary.

        Missing or `None` values are filled from `other`. Where both have a
        non-`None` value they are combined with `+`, falling back to `other`'s
        value if the two cannot be added.

        Args:
            other: The other dictionary to add.

        Returns:
            A dictionary that is the result of adding the two dictionaries.
        """
        chunk = AddableDict(self)
        for key in other:
            if key not in chunk or chunk[key] is None:
                chunk[key] = other[key]
            elif other[key] is not None:
                try:
                    added = chunk[key] + other[key]
                except TypeError:
                    added = other[key]
                chunk[key] = added
        return chunk

    def __radd__(self, other: AddableDict) -> AddableDict:
        """Add this dictionary to another dictionary.

        Args:
            other: The other dictionary to be added to.

        Returns:
            A dictionary that is the result of adding the two dictionaries.
        """
        chunk = AddableDict(other)
        for key in self:
            if key not in chunk or chunk[key] is None:
                chunk[key] = self[key]
            elif self[key] is not None:
                try:
                    added = chunk[key] + self[key]
                except TypeError:
                    added = self[key]
                chunk[key] = added
        return chunk


_T_co = TypeVar("_T_co", covariant=True)
_T_contra = TypeVar("_T_contra", contravariant=True)


class SupportsAdd(Protocol[_T_contra, _T_co]):
    """Protocol for objects that support addition."""

    def __add__(self, x: _T_contra, /) -> _T_co:
        """Add the object to another object."""


Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])


def add(addables: Iterable[Addable]) -> Addable | None:
    """Add a sequence of addable objects together.

    Args:
        addables: The addable objects to add.

    Returns:
        The result of adding the addable objects, or None if there were none.
    """
    final: Addable | None = None
    for chunk in addables:
        final = chunk if final is None else final + chunk
    return final


async def aadd(addables: AsyncIterable[Addable]) -> Addable | None:
    """Asynchronously add a sequence of addable objects together.

    Args:
        addables: The addable objects to add.

    Returns:
        The result of adding the addable objects, or None if there were none.
    """
    final: Addable | None = None
    async for chunk in addables:
        final = chunk if final is None else final + chunk
    return final


class ConfigurableField(NamedTuple):
    """Field that can be configured by the user."""

    id: str
    """The unique identifier of the field."""

    name: str | None = None
    """The name of the field. """

    description: str | None = None
    """The description of the field. """

    annotation: Any | None = None
    """The annotation of the field. """

    is_shared: bool = False
    """Whether the field is shared."""

    @override
    def __hash__(self) -> int:
        """Hash on the field id and annotation."""
        return hash((self.id, self.annotation))


class ConfigurableFieldSingleOption(NamedTuple):
    """Field that can be configured by the user with a default value."""

    id: str
    """The unique identifier of the field."""

    options: Mapping[str, Any]
    """The options for the field."""

    default: str
    """The default value for the field."""

    name: str | None = None
    """The name of the field. """

    description: str | None = None
    """The description of the field. """

    is_shared: bool = False
    """Whether the field is shared."""

    @override
    def __hash__(self) -> int:
        """Hash on the id, option keys, and default (the mapping is unhashable)."""
        return hash((self.id, tuple(self.options.keys()), self.default))


class ConfigurableFieldMultiOption(NamedTuple):
    """Field that can be configured by the user with multiple default values."""

    id: str
    """The unique identifier of the field."""

    options: Mapping[str, Any]
    """The options for the field."""

    default: Sequence[str]
    """The default values for the field."""

    name: str | None = None
    """The name of the field. """

    description: str | None = None
    """The description of the field. """

    is_shared: bool = False
    """Whether the field is shared."""

    @override
    def __hash__(self) -> int:
        """Hash on the id, option keys, and defaults converted to a tuple."""
        return hash((self.id, tuple(self.options.keys()), tuple(self.default)))


AnyConfigurableField = (
    ConfigurableField | ConfigurableFieldSingleOption | ConfigurableFieldMultiOption
)


class ConfigurableFieldSpec(NamedTuple):
    """Field that can be configured by the user. It is a specification of a field."""

    id: str
    """The unique identifier of the field."""

    annotation: Any
    """The annotation of the field."""

    name: str | None = None
    """The name of the field. """

    description: str | None = None
    """The description of the field. """

    default: Any = None
    """The default value for the field. """

    is_shared: bool = False
    """Whether the field is shared."""

    dependencies: list[str] | None = None
    """The dependencies of the field. """


def get_unique_config_specs(
    specs: Iterable[ConfigurableFieldSpec],
) -> list[ConfigurableFieldSpec]:
    """Get the unique config specs from a sequence of config specs.

    Specs sharing an id are collapsed to one if they are all equal.

    Args:
        specs: The config specs.

    Returns:
        The unique config specs, ordered by id.

    Raises:
        ValueError: If the runnable sequence contains conflicting config specs.
    """
    # groupby only merges adjacent items, so sort by id first.
    grouped = groupby(
        sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
    )
    unique: list[ConfigurableFieldSpec] = []
    for spec_id, dupes in grouped:
        first = next(dupes)
        others = list(dupes)
        if len(others) == 0 or all(o == first for o in others):
            unique.append(first)
        else:
            msg = (
                "RunnableSequence contains conflicting config specs "
                f"for {spec_id}: {[first, *others]}"
            )
            raise ValueError(msg)
    return unique


class _RootEventFilter:
    """Include/exclude filter applied to the root event in astream_events."""

    def __init__(
        self,
        *,
        include_names: Sequence[str] | None = None,
        include_types: Sequence[str] | None = None,
        include_tags: Sequence[str] | None = None,
        exclude_names: Sequence[str] | None = None,
        exclude_types: Sequence[str] | None = None,
        exclude_tags: Sequence[str] | None = None,
    ) -> None:
        """Utility to filter the root event in the astream_events implementation.

        This simply binds the arguments to the instance namespace to save a bit
        of typing in the astream_events implementation.
        """
        self.include_names = include_names
        self.include_types = include_types
        self.include_tags = include_tags
        self.exclude_names = exclude_names
        self.exclude_types = exclude_types
        self.exclude_tags = exclude_tags

    def include_event(self, event: StreamEvent, root_type: str) -> bool:
        """Determine whether to include an event.

        With no include filters every event starts included; otherwise an event
        is included if it matches any include filter. Exclude filters are then
        applied and always take precedence.
        """
        if (
            self.include_names is None
            and self.include_types is None
            and self.include_tags is None
        ):
            include = True
        else:
            include = False

        event_tags = event.get("tags") or []

        if self.include_names is not None:
            include = include or event["name"] in self.include_names
        if self.include_types is not None:
            include = include or root_type in self.include_types
        if self.include_tags is not None:
            include = include or any(tag in self.include_tags for tag in event_tags)

        if self.exclude_names is not None:
            include = include and event["name"] not in self.exclude_names
        if self.exclude_types is not None:
            include = include and root_type not in self.exclude_types
        if self.exclude_tags is not None:
            include = include and all(
                tag not in self.exclude_tags for tag in event_tags
            )

        return include


def is_async_generator(
    func: Any,
) -> TypeGuard[Callable[..., AsyncIterator]]:
    """Check if a function is an async generator.

    Also returns `True` for objects whose `__call__` is an async generator.

    Args:
        func: The function to check.

    Returns:
        `True` if the function is an async generator, `False` otherwise.
    """
    return inspect.isasyncgenfunction(func) or (
        hasattr(func, "__call__")  # noqa: B004
        and inspect.isasyncgenfunction(func.__call__)
    )


def is_async_callable(
    func: Any,
) -> TypeGuard[Callable[..., Awaitable]]:
    """Check if a function is async.

    Also returns `True` for objects whose `__call__` is a coroutine function.

    Args:
        func: The function to check.

    Returns:
        `True` if the function is async, `False` otherwise.
    """
    return asyncio.iscoroutinefunction(func) or (
        hasattr(func, "__call__")  # noqa: B004
        and asyncio.iscoroutinefunction(func.__call__)
    )

Code quality findings 24

Ensure functions have docstrings for documentation
missing-docstring
def coro_with_context(
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
isinstance(node.ctx, ast.Load)
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
and isinstance(node.value, ast.Name)
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
and isinstance(node.slice, ast.Constant)
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
and isinstance(node.slice.value, str)
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
isinstance(node.func, ast.Attribute)
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
and isinstance(node.func.value, ast.Name)
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
and isinstance(node.args[0], ast.Constant)
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
and isinstance(node.args[0].value, str)
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(node.ctx, ast.Load):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
elif isinstance(node.ctx, ast.Store):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(node.ctx, ast.Load):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
while isinstance(parent, ast.Attribute):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(parent, ast.Name):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
elif isinstance(parent, ast.Call):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(parent.func, ast.Name):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
while isinstance(parent, ast.Attribute):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(parent, ast.Name):
Ensure try blocks have corresponding except or finally blocks
try-without-except
try:
Ensure functions have docstrings for documentation
missing-docstring
def get_unique_config_specs(
Avoid unnecessary list conversions; use generators where possible
unnecessary-list
others = list(dupes)
Ensure functions have docstrings for documentation
missing-docstring
def is_async_generator(
Ensure functions have docstrings for documentation
missing-docstring
def is_async_callable(
Avoid complex 'lambda' functions; prefer named functions for clarity and debugging
info maintainability complex-lambda
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id

Get this view in your editor

Same data, no extra tab — call code_get_file + code_get_findings over MCP from Claude/Cursor/Copilot.