libs/partners/mistralai/langchain_mistralai/chat_models.py PYTHON 1,275 lines View on github.com → Search inside
1from __future__ import annotations23import hashlib4import json5import logging6import os7import re8import ssl9import uuid10from collections.abc import Callable, Sequence  # noqa: TC00311from operator import itemgetter12from typing import (13    TYPE_CHECKING,14    Any,15    Literal,16    cast,17)1819import certifi20import httpx21from httpx_sse import EventSource, aconnect_sse, connect_sse22from langchain_core.callbacks import (23    AsyncCallbackManagerForLLMRun,24    CallbackManagerForLLMRun,25)26from langchain_core.language_models import (27    LanguageModelInput,28    ModelProfile,29    ModelProfileRegistry,30)31from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams32from langchain_core.language_models.llms import create_base_retry_decorator33from langchain_core.messages import (34    AIMessage,35    AIMessageChunk,36    BaseMessage,37    BaseMessageChunk,38    ChatMessage,39    ChatMessageChunk,40    HumanMessage,41    HumanMessageChunk,42    InvalidToolCall,43    SystemMessage,44    SystemMessageChunk,45    ToolCall,46    ToolMessage,47    is_data_content_block,48)49from langchain_core.messages.block_translators.openai import (50    convert_to_openai_data_block,51)52from langchain_core.messages.tool import tool_call_chunk53from langchain_core.output_parsers import (54    JsonOutputParser,55    PydanticOutputParser,56)57from langchain_core.output_parsers.base import OutputParserLike58from langchain_core.output_parsers.openai_tools import (59    JsonOutputKeyToolsParser,60    PydanticToolsParser,61    make_invalid_tool_call,62    parse_tool_call,63)64from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult65from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough66from langchain_core.tools import BaseTool67from langchain_core.utils import get_pydantic_field_names, secret_from_env68from langchain_core.utils.function_calling import convert_to_openai_tool69from 
langchain_core.utils.pydantic import is_basemodel_subclass70from langchain_core.utils.utils import _build_model_kwargs71from pydantic import (72    BaseModel,73    ConfigDict,74    Field,75    SecretStr,76    model_validator,77)78from typing_extensions import Self7980from langchain_mistralai._compat import _convert_from_v1_to_mistral81from langchain_mistralai.data._profiles import _PROFILES8283if TYPE_CHECKING:84    from collections.abc import AsyncIterator, Iterator85    from contextlib import AbstractAsyncContextManager8687logger = logging.getLogger(__name__)8889# Mistral enforces a specific pattern for tool call IDs90TOOL_CALL_ID_PATTERN = re.compile(r"^[a-zA-Z0-9]{9}$")919293# This SSL context is equivalent to the default `verify=True`.94# https://www.python-httpx.org/advanced/ssl/#configuring-client-instances95global_ssl_context = ssl.create_default_context(cafile=certifi.where())969798_MODEL_PROFILES = cast("ModelProfileRegistry", _PROFILES)99100101def _get_default_model_profile(model_name: str) -> ModelProfile:102    default = _MODEL_PROFILES.get(model_name) or {}103    return default.copy()104105106def _create_retry_decorator(107    llm: ChatMistralAI,108    run_manager: AsyncCallbackManagerForLLMRun | CallbackManagerForLLMRun | None = None,109) -> Callable[[Any], Any]:110    """Return a tenacity retry decorator, preconfigured to handle exceptions."""111    errors = [httpx.RequestError, httpx.StreamError]112    return create_base_retry_decorator(113        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager114    )115116117def _is_valid_mistral_tool_call_id(tool_call_id: str) -> bool:118    """Check if tool call ID is nine character string consisting of a-z, A-Z, 0-9."""119    return bool(TOOL_CALL_ID_PATTERN.match(tool_call_id))120121122def _base62_encode(num: int) -> str:123    """Encode a number in base62 and ensures result is of a specified length."""124    base62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"125 
   if num == 0:126        return base62[0]127    arr = []128    base = len(base62)129    while num:130        num, rem = divmod(num, base)131        arr.append(base62[rem])132    arr.reverse()133    return "".join(arr)134135136def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:137    """Convert a tool call ID to a Mistral-compatible format."""138    if _is_valid_mistral_tool_call_id(tool_call_id):139        return tool_call_id140    hash_bytes = hashlib.sha256(tool_call_id.encode()).digest()141    hash_int = int.from_bytes(hash_bytes, byteorder="big")142    base62_str = _base62_encode(hash_int)143    if len(base62_str) >= 9:144        return base62_str[:9]145    return base62_str.rjust(9, "0")146147148def _convert_mistral_chat_message_to_message(149    _message: dict,150) -> BaseMessage:151    role = _message["role"]152    if role != "assistant":153        msg = f"Expected role to be 'assistant', got {role}"154        raise ValueError(msg)155    # Mistral returns None for tool invocations156    content = _message.get("content", "") or ""157158    additional_kwargs: dict = {}159    tool_calls = []160    invalid_tool_calls = []161    if raw_tool_calls := _message.get("tool_calls"):162        additional_kwargs["tool_calls"] = raw_tool_calls163        for raw_tool_call in raw_tool_calls:164            try:165                parsed: dict = cast(166                    "dict", parse_tool_call(raw_tool_call, return_id=True)167                )168                if not parsed["id"]:169                    parsed["id"] = uuid.uuid4().hex[:]170                tool_calls.append(parsed)171            except Exception as e:172                invalid_tool_calls.append(make_invalid_tool_call(raw_tool_call, str(e)))173    return AIMessage(174        content=content,175        additional_kwargs=additional_kwargs,176        tool_calls=tool_calls,177        invalid_tool_calls=invalid_tool_calls,178        response_metadata={"model_provider": "mistralai"},179    
)180181182def _raise_on_error(response: httpx.Response) -> None:183    """Raise an error if the response is an error."""184    if httpx.codes.is_error(response.status_code):185        error_message = response.read().decode("utf-8")186        msg = (187            f"Error response {response.status_code} "188            f"while fetching {response.url}: {error_message}"189        )190        raise httpx.HTTPStatusError(191            msg,192            request=response.request,193            response=response,194        )195196197async def _araise_on_error(response: httpx.Response) -> None:198    """Raise an error if the response is an error."""199    if httpx.codes.is_error(response.status_code):200        error_message = (await response.aread()).decode("utf-8")201        msg = (202            f"Error response {response.status_code} "203            f"while fetching {response.url}: {error_message}"204        )205        raise httpx.HTTPStatusError(206            msg,207            request=response.request,208            response=response,209        )210211212async def _aiter_sse(213    event_source_mgr: AbstractAsyncContextManager[EventSource],214) -> AsyncIterator[dict]:215    """Iterate over the server-sent events."""216    async with event_source_mgr as event_source:217        await _araise_on_error(event_source.response)218        async for event in event_source.aiter_sse():219            if event.data == "[DONE]":220                return221            yield event.json()222223224async def acompletion_with_retry(225    llm: ChatMistralAI,226    run_manager: AsyncCallbackManagerForLLMRun | None = None,227    **kwargs: Any,228) -> Any:229    """Use tenacity to retry the async completion call."""230    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)231232    @retry_decorator233    async def _completion_with_retry(**kwargs: Any) -> Any:234        if "stream" not in kwargs:235            kwargs["stream"] = False236        stream = 
kwargs["stream"]237        if stream:238            event_source = aconnect_sse(239                llm.async_client, "POST", "/chat/completions", json=kwargs240            )241            return _aiter_sse(event_source)242        response = await llm.async_client.post(url="/chat/completions", json=kwargs)243        await _araise_on_error(response)244        return response.json()245246    return await _completion_with_retry(**kwargs)247248249def _convert_chunk_to_message_chunk(250    chunk: dict,251    default_class: type[BaseMessageChunk],252    index: int,253    index_type: str,254    output_version: str | None,255) -> tuple[BaseMessageChunk, int, str]:256    _choice = chunk["choices"][0]257    _delta = _choice["delta"]258    role = _delta.get("role")259    content = _delta.get("content") or ""260    if output_version == "v1" and isinstance(content, str):261        content = [{"type": "text", "text": content}]262    if isinstance(content, list):263        for block in content:264            if isinstance(block, dict):265                if "type" in block and block["type"] != index_type:266                    index_type = block["type"]267                    index = index + 1268                if "index" not in block:269                    block["index"] = index270                if block.get("type") == "thinking" and isinstance(271                    block.get("thinking"), list272                ):273                    for sub_block in block["thinking"]:274                        if isinstance(sub_block, dict) and "index" not in sub_block:275                            sub_block["index"] = 0276    if role == "user" or default_class == HumanMessageChunk:277        return HumanMessageChunk(content=content), index, index_type278    if role == "assistant" or default_class == AIMessageChunk:279        additional_kwargs: dict = {}280        response_metadata = {}281        if raw_tool_calls := _delta.get("tool_calls"):282            additional_kwargs["tool_calls"] = 
raw_tool_calls283            try:284                tool_call_chunks = []285                for raw_tool_call in raw_tool_calls:286                    if not raw_tool_call.get("index") and not raw_tool_call.get("id"):287                        tool_call_id = uuid.uuid4().hex[:]288                    else:289                        tool_call_id = raw_tool_call.get("id")290                    tool_call_chunks.append(291                        tool_call_chunk(292                            name=raw_tool_call["function"].get("name"),293                            args=raw_tool_call["function"].get("arguments"),294                            id=tool_call_id,295                            index=raw_tool_call.get("index"),296                        )297                    )298            except KeyError:299                pass300        else:301            tool_call_chunks = []302        if token_usage := chunk.get("usage"):303            usage_metadata = {304                "input_tokens": token_usage.get("prompt_tokens", 0),305                "output_tokens": token_usage.get("completion_tokens", 0),306                "total_tokens": token_usage.get("total_tokens", 0),307            }308        else:309            usage_metadata = None310        if _choice.get("finish_reason") is not None and isinstance(311            chunk.get("model"), str312        ):313            response_metadata["model_name"] = chunk["model"]314            response_metadata["finish_reason"] = _choice["finish_reason"]315        return (316            AIMessageChunk(317                content=content,318                additional_kwargs=additional_kwargs,319                tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]320                usage_metadata=usage_metadata,  # type: ignore[arg-type]321                response_metadata={"model_provider": "mistralai", **response_metadata},322            ),323            index,324            index_type,325        )326    if role == "system" or 
default_class == SystemMessageChunk:327        return SystemMessageChunk(content=content), index, index_type328    if role or default_class == ChatMessageChunk:329        return ChatMessageChunk(content=content, role=role), index, index_type330    return default_class(content=content), index, index_type  # type: ignore[call-arg]331332333def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:334    """Format LangChain ToolCall to dict expected by Mistral."""335    result: dict[str, Any] = {336        "function": {337            "name": tool_call["name"],338            "arguments": json.dumps(tool_call["args"], ensure_ascii=False),339        }340    }341    if _id := tool_call.get("id"):342        result["id"] = _convert_tool_call_id_to_mistral_compatible(_id)343344    return result345346347def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> dict:348    """Format LangChain InvalidToolCall to dict expected by Mistral."""349    result: dict[str, Any] = {350        "function": {351            "name": invalid_tool_call["name"],352            "arguments": invalid_tool_call["args"],353        }354    }355    if _id := invalid_tool_call.get("id"):356        result["id"] = _convert_tool_call_id_to_mistral_compatible(_id)357358    return result359360361def _clean_block(block: dict) -> dict:362    # Remove "index" key added for message aggregation in langchain-core363    new_block = {k: v for k, v in block.items() if k != "index"}364    if block.get("type") == "thinking" and isinstance(block.get("thinking"), list):365        new_block["thinking"] = [366            (367                {k: v for k, v in sb.items() if k != "index"}368                if isinstance(sb, dict) and "index" in sb369                else sb370            )371            for sb in block["thinking"]372        ]373    return new_block374375376def _sanitize_chat_completions_content(content: Any) -> Any:377    """Strip non-wire keys from text content blocks.378379    
Mistral's chat completions endpoint rejects unknown fields on tool380    message content blocks (e.g. the `id` that LangChain auto-generates on381    `TextContentBlock`). For list content, keep only `type` and `text` on382    text blocks; pass other blocks and non-list content through unchanged.383    """384    if not isinstance(content, list):385        return content386    sanitized: list[Any] = []387    for block in content:388        if isinstance(block, dict) and block.get("type") == "text" and "text" in block:389            sanitized.append({"type": "text", "text": block["text"]})390        else:391            sanitized.append(block)392    return sanitized393394395def _format_message_content(content: Any) -> Any:396    """Format message content for the Mistral chat completions wire format.397398    Walks list content and translates LangChain canonical v0/v1 multimodal399    data blocks (e.g. `ImageContentBlock` with `url`, `base64`, or400    `file_id`) into the OpenAI-compatible shape that Mistral accepts:401    `{"type": "image_url", "image_url": {"url": "..."}}`. Strings and any402    other dict blocks are returned unchanged so that already-translated wire403    blocks (e.g. `text`, `image_url`) and Mistral-specific blocks404    (`document_url`, `input_audio`) pass through; the API surfaces an error405    for anything it doesn't understand.406407    Args:408        content: The message content. Strings and non-list values pass409            through unchanged; lists are walked block by block.410411    Returns:412        The formatted content. 
List inputs return a new list with canonical413        data-block translations applied; other inputs are returned as-is.414    """415    if not isinstance(content, list):416        return content417    formatted: list[Any] = []418    for block in content:419        if isinstance(block, dict) and is_data_content_block(block):420            formatted.append(421                convert_to_openai_data_block(block, api="chat/completions")422            )423            continue424        formatted.append(block)425    return formatted426427428def _convert_message_to_mistral_chat_message(429    message: BaseMessage,430) -> dict:431    if isinstance(message, ChatMessage):432        return {"role": message.role, "content": message.content}433    if isinstance(message, HumanMessage):434        return {"role": "user", "content": _format_message_content(message.content)}435    if isinstance(message, AIMessage):436        message_dict: dict[str, Any] = {"role": "assistant"}437        tool_calls: list = []438        if message.tool_calls or message.invalid_tool_calls:439            if message.tool_calls:440                tool_calls.extend(441                    _format_tool_call_for_mistral(tool_call)442                    for tool_call in message.tool_calls443                )444            if message.invalid_tool_calls:445                tool_calls.extend(446                    _format_invalid_tool_call_for_mistral(invalid_tool_call)447                    for invalid_tool_call in message.invalid_tool_calls448                )449        elif "tool_calls" in message.additional_kwargs:450            for tc in message.additional_kwargs["tool_calls"]:451                chunk = {452                    "function": {453                        "name": tc["function"]["name"],454                        "arguments": tc["function"]["arguments"],455                    }456                }457                if _id := tc.get("id"):458                    chunk["id"] = _id459                
tool_calls.append(chunk)460        else:461            pass462        if tool_calls:  # do not populate empty list tool_calls463            message_dict["tool_calls"] = tool_calls464465        # Message content466        # Translate v1 content467        if message.response_metadata.get("output_version") == "v1":468            content = _convert_from_v1_to_mistral(469                message.content_blocks, message.response_metadata.get("model_provider")470            )471        else:472            content = message.content473474        if tool_calls and content:475            # Assistant message must have either content or tool_calls, but not both.476            # Some providers may not support tool_calls in the same message as content.477            # This is done to ensure compatibility with messages from other providers.478            content = ""479480        elif isinstance(content, list):481            content = [482                _clean_block(block)483                if isinstance(block, dict) and "index" in block484                else block485                for block in content486            ]487        else:488            content = message.content489490        # if any blocks are dicts, cast strings to text blocks491        if any(isinstance(block, dict) for block in content):492            content = [493                block if isinstance(block, dict) else {"type": "text", "text": block}494                for block in content495            ]496        message_dict["content"] = content497498        if "prefix" in message.additional_kwargs:499            message_dict["prefix"] = message.additional_kwargs["prefix"]500        return message_dict501    if isinstance(message, SystemMessage):502        return {"role": "system", "content": message.content}503    if isinstance(message, ToolMessage):504        return {505            "role": "tool",506            "content": _sanitize_chat_completions_content(message.content),507            "name": 
message.name,508            "tool_call_id": _convert_tool_call_id_to_mistral_compatible(509                message.tool_call_id510            ),511        }512    msg = f"Got unknown type {message}"513    raise ValueError(msg)514515516class ChatMistralAI(BaseChatModel):517    """A chat model that uses the Mistral AI API."""518519    # The type for client and async_client is ignored because the type is not520    # an Optional after the model is initialized and the model_validator521    # is run.522    client: httpx.Client = Field(  # type: ignore[assignment] # : meta private:523        default=None, exclude=True524    )525526    async_client: httpx.AsyncClient = Field(  # type: ignore[assignment] # : meta private:527        default=None, exclude=True528    )529530    mistral_api_key: SecretStr | None = Field(531        alias="api_key",532        default_factory=secret_from_env("MISTRAL_API_KEY", default=None),533    )534535    endpoint: str | None = Field(default=None, alias="base_url")536537    max_retries: int = 5538539    timeout: int = 120540541    max_concurrent_requests: int = 64542543    model: str = Field(default="mistral-small", alias="model_name")544545    temperature: float = 0.7546547    max_tokens: int | None = None548549    top_p: float = 1550    """Decode using nucleus sampling: consider the smallest set of tokens whose551    probability sum is at least `top_p`. 
Must be in the closed interval552    `[0.0, 1.0]`."""553554    random_seed: int | None = None555556    safe_mode: bool | None = None557558    streaming: bool = False559560    model_kwargs: dict[str, Any] = Field(default_factory=dict)561    """Holds any invocation parameters not explicitly specified."""562563    model_config = ConfigDict(564        populate_by_name=True,565        arbitrary_types_allowed=True,566    )567568    @model_validator(mode="before")569    @classmethod570    def build_extra(cls, values: dict[str, Any]) -> Any:571        """Build extra kwargs from additional params that were passed in."""572        all_required_field_names = get_pydantic_field_names(cls)573        return _build_model_kwargs(values, all_required_field_names)574575    @property576    def _default_params(self) -> dict[str, Any]:577        """Get the default parameters for calling the API."""578        defaults = {579            "model": self.model,580            "temperature": self.temperature,581            "max_tokens": self.max_tokens,582            "top_p": self.top_p,583            "random_seed": self.random_seed,584            "safe_prompt": self.safe_mode,585            **self.model_kwargs,586        }587        return {k: v for k, v in defaults.items() if v is not None}588589    def _get_ls_params(590        self, stop: list[str] | None = None, **kwargs: Any591    ) -> LangSmithParams:592        """Get standard params for tracing."""593        params = self._get_invocation_params(stop=stop, **kwargs)594        ls_params = LangSmithParams(595            ls_provider="mistral",596            ls_model_name=params.get("model", self.model),597            ls_model_type="chat",598            ls_temperature=params.get("temperature", self.temperature),599        )600        if ls_max_tokens := params.get("max_tokens", self.max_tokens):601            ls_params["ls_max_tokens"] = ls_max_tokens602        if ls_stop := stop or params.get("stop", None):603            
ls_params["ls_stop"] = ls_stop604        return ls_params605606    @property607    def _client_params(self) -> dict[str, Any]:608        """Get the parameters used for the client."""609        return self._default_params610611    def completion_with_retry(612        self, run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any613    ) -> Any:614        """Use tenacity to retry the completion call."""615        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)616617        @retry_decorator618        def _completion_with_retry(**kwargs: Any) -> Any:619            if "stream" not in kwargs:620                kwargs["stream"] = False621            stream = kwargs["stream"]622            if stream:623624                def iter_sse() -> Iterator[dict]:625                    with connect_sse(626                        self.client, "POST", "/chat/completions", json=kwargs627                    ) as event_source:628                        _raise_on_error(event_source.response)629                        for event in event_source.iter_sse():630                            if event.data == "[DONE]":631                                return632                            yield event.json()633634                return iter_sse()635            response = self.client.post(url="/chat/completions", json=kwargs)636            _raise_on_error(response)637            return response.json()638639        return _completion_with_retry(**kwargs)640641    def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict:642        overall_token_usage: dict = {}643        for output in llm_outputs:644            if output is None:645                # Happens in streaming646                continue647            token_usage = output["token_usage"]648            if token_usage is not None:649                for k, v in token_usage.items():650                    if k in overall_token_usage:651                        overall_token_usage[k] += v652            
        else:653                        overall_token_usage[k] = v654        return {"token_usage": overall_token_usage, "model_name": self.model}655656    @model_validator(mode="after")657    def validate_environment(self) -> Self:658        """Validate api key, python package exists, temperature, and top_p."""659        if isinstance(self.mistral_api_key, SecretStr):660            api_key_str: str | None = self.mistral_api_key.get_secret_value()661        else:662            api_key_str = self.mistral_api_key663664        # TODO: handle retries665        base_url_str = (666            self.endpoint667            or os.environ.get("MISTRAL_BASE_URL")668            or "https://api.mistral.ai/v1"669        )670        self.endpoint = base_url_str671        if not self.client:672            self.client = httpx.Client(673                base_url=base_url_str,674                headers={675                    "Content-Type": "application/json",676                    "Accept": "application/json",677                    "Authorization": f"Bearer {api_key_str}",678                },679                timeout=self.timeout,680                verify=global_ssl_context,681            )682        # TODO: handle retries and max_concurrency683        if not self.async_client:684            self.async_client = httpx.AsyncClient(685                base_url=base_url_str,686                headers={687                    "Content-Type": "application/json",688                    "Accept": "application/json",689                    "Authorization": f"Bearer {api_key_str}",690                },691                timeout=self.timeout,692                verify=global_ssl_context,693            )694695        if self.temperature is not None and not 0 <= self.temperature <= 1:696            msg = "temperature must be in the range [0.0, 1.0]"697            raise ValueError(msg)698699        if self.top_p is not None and not 0 <= self.top_p <= 1:700            msg = "top_p must be in the range 
[0.0, 1.0]"701            raise ValueError(msg)702703        return self704705    def _resolve_model_profile(self) -> ModelProfile | None:706        return _get_default_model_profile(self.model) or None707708    def _generate(709        self,710        messages: list[BaseMessage],711        stop: list[str] | None = None,712        run_manager: CallbackManagerForLLMRun | None = None,713        stream: bool | None = None,  # noqa: FBT001714        **kwargs: Any,715    ) -> ChatResult:716        message_dicts, params = self._create_message_dicts(messages, stop)717        params = {**params, **kwargs}718        response = self.completion_with_retry(719            messages=message_dicts, run_manager=run_manager, **params720        )721        return self._create_chat_result(response)722723    def _create_chat_result(self, response: dict) -> ChatResult:724        generations = []725        token_usage = response.get("usage", {})726        for res in response["choices"]:727            finish_reason = res.get("finish_reason")728            message = _convert_mistral_chat_message_to_message(res["message"])729            if token_usage and isinstance(message, AIMessage):730                message.usage_metadata = {731                    "input_tokens": token_usage.get("prompt_tokens", 0),732                    "output_tokens": token_usage.get("completion_tokens", 0),733                    "total_tokens": token_usage.get("total_tokens", 0),734                }735            gen = ChatGeneration(736                message=message,737                generation_info={"finish_reason": finish_reason},738            )739            generations.append(gen)740741        llm_output = {742            "token_usage": token_usage,743            "model_name": self.model,744            "model": self.model,  # Backwards compatibility745        }746        return ChatResult(generations=generations, llm_output=llm_output)747748    def _create_message_dicts(749        self, messages: 
list[BaseMessage], stop: list[str] | None750    ) -> tuple[list[dict], dict[str, Any]]:751        params = self._client_params752        if stop is not None or "stop" in params:753            if "stop" in params:754                params.pop("stop")755            logger.warning(756                "Parameter `stop` not yet supported (https://docs.mistral.ai/api)"757            )758        message_dicts = [_convert_message_to_mistral_chat_message(m) for m in messages]759        return message_dicts, params760761    def _stream(762        self,763        messages: list[BaseMessage],764        stop: list[str] | None = None,765        run_manager: CallbackManagerForLLMRun | None = None,766        **kwargs: Any,767    ) -> Iterator[ChatGenerationChunk]:768        message_dicts, params = self._create_message_dicts(messages, stop)769        params = {**params, **kwargs, "stream": True}770771        default_chunk_class: type[BaseMessageChunk] = AIMessageChunk772        index = -1773        index_type = ""774        for chunk in self.completion_with_retry(775            messages=message_dicts, run_manager=run_manager, **params776        ):777            if len(chunk.get("choices", [])) == 0:778                continue779            new_chunk, index, index_type = _convert_chunk_to_message_chunk(780                chunk, default_chunk_class, index, index_type, self.output_version781            )782            # make future chunks same type as first chunk783            default_chunk_class = new_chunk.__class__784            gen_chunk = ChatGenerationChunk(message=new_chunk)785            if run_manager:786                run_manager.on_llm_new_token(787                    token=cast("str", new_chunk.content), chunk=gen_chunk788                )789            yield gen_chunk790791    async def _astream(792        self,793        messages: list[BaseMessage],794        stop: list[str] | None = None,795        run_manager: AsyncCallbackManagerForLLMRun | None = None,796        
**kwargs: Any,797    ) -> AsyncIterator[ChatGenerationChunk]:798        message_dicts, params = self._create_message_dicts(messages, stop)799        params = {**params, **kwargs, "stream": True}800801        default_chunk_class: type[BaseMessageChunk] = AIMessageChunk802        index = -1803        index_type = ""804        async for chunk in await acompletion_with_retry(805            self, messages=message_dicts, run_manager=run_manager, **params806        ):807            if len(chunk.get("choices", [])) == 0:808                continue809            new_chunk, index, index_type = _convert_chunk_to_message_chunk(810                chunk, default_chunk_class, index, index_type, self.output_version811            )812            # make future chunks same type as first chunk813            default_chunk_class = new_chunk.__class__814            gen_chunk = ChatGenerationChunk(message=new_chunk)815            if run_manager:816                await run_manager.on_llm_new_token(817                    token=cast("str", new_chunk.content), chunk=gen_chunk818                )819            yield gen_chunk820821    async def _agenerate(822        self,823        messages: list[BaseMessage],824        stop: list[str] | None = None,825        run_manager: AsyncCallbackManagerForLLMRun | None = None,826        stream: bool | None = None,  # noqa: FBT001827        **kwargs: Any,828    ) -> ChatResult:829        message_dicts, params = self._create_message_dicts(messages, stop)830        params = {**params, **kwargs}831        response = await acompletion_with_retry(832            self, messages=message_dicts, run_manager=run_manager, **params833        )834        return self._create_chat_result(response)835836    def bind_tools(837        self,838        tools: Sequence[dict[str, Any] | type | Callable | BaseTool],839        tool_choice: dict | str | Literal["auto", "any"] | None = None,  # noqa: PYI051840        **kwargs: Any,841    ) -> Runnable[LanguageModelInput, 
AIMessage]:842        """Bind tool-like objects to this chat model.843844        Assumes model is compatible with OpenAI tool-calling API.845846        Args:847            tools: A list of tool definitions to bind to this chat model.848849                Supports any tool definition handled by [`convert_to_openai_tool`][langchain_core.utils.function_calling.convert_to_openai_tool].850            tool_choice: Which tool to require the model to call.851                Must be the name of the single provided function or852                `'auto'` to automatically determine which function to call853                (if any), or a dict of the form:854                {"type": "function", "function": {"name": <<tool_name>>}}.855            kwargs: Any additional parameters are passed directly to856                `self.bind(**kwargs)`.857        """  # noqa: E501858        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]859        if tool_choice:860            tool_names = []861            for tool in formatted_tools:862                if ("function" in tool and (name := tool["function"].get("name"))) or (863                    name := tool.get("name")864                ):865                    tool_names.append(name)866                else:867                    pass868            if tool_choice in tool_names:869                kwargs["tool_choice"] = {870                    "type": "function",871                    "function": {"name": tool_choice},872                }873            else:874                kwargs["tool_choice"] = tool_choice875        return super().bind(tools=formatted_tools, **kwargs)876877    def with_structured_output(878        self,879        schema: dict | type | None = None,880        *,881        method: Literal[882            "function_calling", "json_mode", "json_schema"883        ] = "function_calling",884        include_raw: bool = False,885        **kwargs: Any,886    ) -> Runnable[LanguageModelInput, dict | BaseModel]:887 
       r"""Model wrapper that returns outputs formatted to match the given schema.888889        Args:890            schema: The output schema. Can be passed in as:891892                - An OpenAI function/tool schema,893                - A JSON Schema,894                - A `TypedDict` class,895                - Or a Pydantic class.896897                If `schema` is a Pydantic class then the model output will be a898                Pydantic instance of that class, and the model-generated fields will be899                validated by the Pydantic class. Otherwise the model output will be a900                dict and will not be validated.901902                See `langchain_core.utils.function_calling.convert_to_openai_tool` for903                more on how to properly specify types and descriptions of schema fields904                when specifying a Pydantic or `TypedDict` class.905906            method: The method for steering model generation, one of:907908                - `'function_calling'`:909                    Uses Mistral's910                    [function-calling feature](https://docs.mistral.ai/capabilities/function_calling/).911                - `'json_schema'`:912                    Uses Mistral's913                    [structured output feature](https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/).914                - `'json_mode'`:915                    Uses Mistral's916                    [JSON mode](https://docs.mistral.ai/capabilities/structured-output/json_mode/).917                    Note that if using JSON mode then you918                    must include instructions for formatting the output into the919                    desired schema into the model call.920921                !!! 
warning "Behavior changed in `langchain-mistralai` 0.2.5"922923                    Added method="json_schema"924925            include_raw:926                If `False` then only the parsed structured output is returned.927928                If an error occurs during model output parsing it will be raised.929930                If `True` then both the raw model response (a `BaseMessage`) and the931                parsed model response will be returned.932933                If an error occurs during output parsing it will be caught and returned934                as well.935936                The final output is always a `dict` with keys `'raw'`, `'parsed'`, and937                `'parsing_error'`.938939            kwargs: Any additional parameters are passed directly to940                `self.bind(**kwargs)`. This is useful for passing in941                parameters such as `tool_choice` or `tools` to control942                which tool the model should call, or to pass in parameters such as943                `stop` to control when the model should stop generating output.944945        Returns:946            A `Runnable` that takes same inputs as a947                `langchain_core.language_models.chat.BaseChatModel`. If `include_raw` is948                `False` and `schema` is a Pydantic class, `Runnable` outputs an instance949                of `schema` (i.e., a Pydantic object). 
Otherwise, if `include_raw` is950                `False` then `Runnable` outputs a `dict`.951952                If `include_raw` is `True`, then `Runnable` outputs a `dict` with keys:953954                - `'raw'`: `BaseMessage`955                - `'parsed'`: `None` if there was a parsing error, otherwise the type956                    depends on the `schema` as described above.957                - `'parsing_error'`: `BaseException | None`958959        Example: schema=Pydantic class, method="function_calling", include_raw=False:960961        ```python962        from typing import Optional963964        from langchain_mistralai import ChatMistralAI965        from pydantic import BaseModel, Field966967968        class AnswerWithJustification(BaseModel):969            '''An answer to the user question along with justification for the answer.'''970971            answer: str972            # If we provide default values and/or descriptions for fields, these will be passed973            # to the model. This is an important part of improving a model's ability to974            # correctly return structured outputs.975            justification: str | None = Field(976                default=None, description="A justification for the answer."977            )978979980        model = ChatMistralAI(model="mistral-large-latest", temperature=0)981        structured_model = model.with_structured_output(AnswerWithJustification)982983        structured_model.invoke(984            "What weighs more a pound of bricks or a pound of feathers"985        )986987        # -> AnswerWithJustification(988        #     answer='They weigh the same',989        #     justification='Both a pound of bricks and a pound of feathers weigh one pound. 
The weight is the same, but the volume or density of the objects may differ.'990        # )991        ```992993        Example: schema=Pydantic class, method="function_calling", include_raw=True:994995        ```python996        from langchain_mistralai import ChatMistralAI997        from pydantic import BaseModel9989991000        class AnswerWithJustification(BaseModel):1001            '''An answer to the user question along with justification for the answer.'''10021003            answer: str1004            justification: str100510061007        model = ChatMistralAI(model="mistral-large-latest", temperature=0)1008        structured_model = model.with_structured_output(1009            AnswerWithJustification, include_raw=True1010        )10111012        structured_model.invoke(1013            "What weighs more a pound of bricks or a pound of feathers"1014        )1015        # -> {1016        #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),1017        #     'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. 
The weight is the same, but the volume or density of the objects may differ.'),1018        #     'parsing_error': None1019        # }1020        ```10211022        Example: schema=TypedDict class, method="function_calling", include_raw=False:10231024        ```python1025        from typing_extensions import Annotated, TypedDict10261027        from langchain_mistralai import ChatMistralAI102810291030        class AnswerWithJustification(TypedDict):1031            '''An answer to the user question along with justification for the answer.'''10321033            answer: str1034            justification: Annotated[1035                str | None, None, "A justification for the answer."1036            ]103710381039        model = ChatMistralAI(model="mistral-large-latest", temperature=0)1040        structured_model = model.with_structured_output(AnswerWithJustification)10411042        structured_model.invoke(1043            "What weighs more a pound of bricks or a pound of feathers"1044        )1045        # -> {1046        #     'answer': 'They weigh the same',1047        #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. 
The weight is the same, but the volume and density of the two substances differ.'1048        # }1049        ```10501051        Example: schema=OpenAI function schema, method="function_calling", include_raw=False:10521053        ```python1054        from langchain_mistralai import ChatMistralAI10551056        oai_schema = {1057            'name': 'AnswerWithJustification',1058            'description': 'An answer to the user question along with justification for the answer.',1059            'parameters': {1060                'type': 'object',1061                'properties': {1062                    'answer': {'type': 'string'},1063                    'justification': {'description': 'A justification for the answer.', 'type': 'string'}1064                },1065                'required': ['answer']1066            }10671068            model = ChatMistralAI(model="mistral-large-latest", temperature=0)1069            structured_model = model.with_structured_output(oai_schema)10701071            structured_model.invoke(1072                "What weighs more a pound of bricks or a pound of feathers"1073            )1074            # -> {1075            #     'answer': 'They weigh the same',1076            #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. 
The weight is the same, but the volume and density of the two substances differ.'1077            # }1078        ```10791080        Example: schema=Pydantic class, method="json_mode", include_raw=True:10811082        ```python1083        from langchain_mistralai import ChatMistralAI1084        from pydantic import BaseModel108510861087        class AnswerWithJustification(BaseModel):1088            answer: str1089            justification: str109010911092        model = ChatMistralAI(model="mistral-large-latest", temperature=0)1093        structured_model = model.with_structured_output(1094            AnswerWithJustification, method="json_mode", include_raw=True1095        )10961097        structured_model.invoke(1098            "Answer the following question. "1099            "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"1100            "What's heavier a pound of bricks or a pound of feathers?"1101        )1102        # -> {1103        #     'raw': AIMessage(content='{\\n    "answer": "They are both the same weight.",\\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),1104        #     'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),1105        #     'parsing_error': None1106        # }1107        ```11081109        Example: schema=None, method="json_mode", include_raw=True:11101111        ```python1112        structured_model = model.with_structured_output(1113            method="json_mode", include_raw=True1114        )11151116        structured_model.invoke(1117            "Answer the following question. 
"1118            "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"1119            "What's heavier a pound of bricks or a pound of feathers?"1120        )1121        # -> {1122        #     'raw': AIMessage(content='{\\n    "answer": "They are both the same weight.",\\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),1123        #     'parsed': {1124        #         'answer': 'They are both the same weight.',1125        #         'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'1126        #     },1127        #     'parsing_error': None1128        # }1129        ```1130        """  # noqa: E5011131        _ = kwargs.pop("strict", None)1132        if kwargs:1133            msg = f"Received unsupported arguments {kwargs}"1134            raise ValueError(msg)1135        is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)1136        if method == "function_calling":1137            if schema is None:1138                msg = (1139                    "schema must be specified when method is 'function_calling'. 
"1140                    "Received None."1141                )1142                raise ValueError(msg)1143            # TODO: Update to pass in tool name as tool_choice if/when Mistral supports1144            # specifying a tool.1145            llm = self.bind_tools(1146                [schema],1147                tool_choice="any",1148                ls_structured_output_format={1149                    "kwargs": {"method": "function_calling"},1150                    "schema": schema,1151                },1152            )1153            if is_pydantic_schema:1154                output_parser: OutputParserLike = PydanticToolsParser(1155                    tools=[schema],  # type: ignore[list-item]1156                    first_tool_only=True,  # type: ignore[list-item]1157                )1158            else:1159                key_name = convert_to_openai_tool(schema)["function"]["name"]1160                output_parser = JsonOutputKeyToolsParser(1161                    key_name=key_name, first_tool_only=True1162                )1163        elif method == "json_mode":1164            llm = self.bind(1165                response_format={"type": "json_object"},1166                ls_structured_output_format={1167                    "kwargs": {1168                        # this is correct - name difference with mistral api1169                        "method": "json_mode"1170                    },1171                    "schema": schema,1172                },1173            )1174            output_parser = (1175                PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]1176                if is_pydantic_schema1177                else JsonOutputParser()1178            )1179        elif method == "json_schema":1180            if schema is None:1181                msg = (1182                    "schema must be specified when method is 'json_schema'. 
"1183                    "Received None."1184                )1185                raise ValueError(msg)1186            response_format = _convert_to_openai_response_format(schema, strict=True)1187            llm = self.bind(1188                response_format=response_format,1189                ls_structured_output_format={1190                    "kwargs": {"method": "json_schema"},1191                    "schema": schema,1192                },1193            )11941195            output_parser = (1196                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]1197                if is_pydantic_schema1198                else JsonOutputParser()1199            )1200        if include_raw:1201            parser_assign = RunnablePassthrough.assign(1202                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None1203            )1204            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)1205            parser_with_fallback = parser_assign.with_fallbacks(1206                [parser_none], exception_key="parsing_error"1207            )1208            return RunnableMap(raw=llm) | parser_with_fallback1209        return llm | output_parser12101211    @property1212    def _identifying_params(self) -> dict[str, Any]:1213        """Get the identifying parameters."""1214        return self._default_params12151216    @property1217    def _llm_type(self) -> str:1218        """Return type of chat model."""1219        return "mistralai-chat"12201221    @property1222    def lc_secrets(self) -> dict[str, str]:1223        return {"mistral_api_key": "MISTRAL_API_KEY"}12241225    @classmethod1226    def is_lc_serializable(cls) -> bool:1227        """Return whether this model can be serialized by LangChain."""1228        return True12291230    @classmethod1231    def get_lc_namespace(cls) -> list[str]:1232        """Get the namespace of the LangChain object.12331234        Returns:1235            `["langchain", 
"chat_models", "mistralai"]`1236        """1237        return ["langchain", "chat_models", "mistralai"]123812391240def _convert_to_openai_response_format(1241    schema: dict[str, Any] | type, *, strict: bool | None = None1242) -> dict:1243    """Perform same op as in ChatOpenAI, but do not pass through Pydantic BaseModels."""1244    if (1245        isinstance(schema, dict)1246        and "json_schema" in schema1247        and schema.get("type") == "json_schema"1248    ):1249        response_format = schema1250    elif isinstance(schema, dict) and "name" in schema and "schema" in schema:1251        response_format = {"type": "json_schema", "json_schema": schema}1252    else:1253        if strict is None:1254            if isinstance(schema, dict) and isinstance(schema.get("strict"), bool):1255                strict = schema["strict"]1256            else:1257                strict = False1258        function = convert_to_openai_tool(schema, strict=strict)["function"]1259        function["schema"] = function.pop("parameters")1260        response_format = {"type": "json_schema", "json_schema": function}12611262    if (1263        strict is not None1264        and strict is not response_format["json_schema"].get("strict")1265        and isinstance(schema, dict)1266    ):1267        msg = (1268            f"Output schema already has 'strict' value set to "1269            f"{schema['json_schema']['strict']} but 'strict' also passed in to "1270            f"with_structured_output as {strict}. Please make sure that "1271            f"'strict' is only specified in one place."1272        )1273        raise ValueError(msg)1274    return response_format

Code quality findings (34)

Ensure functions have docstrings for documentation
missing-docstring
async def acompletion_with_retry(
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if output_version == "v1" and isinstance(content, str):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(content, list):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(block, dict):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if block.get("type") == "thinking" and isinstance(
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(sub_block, dict) and "index" not in sub_block:
Ensure try blocks have corresponding except or finally blocks
try-without-except
try:
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if _choice.get("finish_reason") is not None and isinstance(
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if block.get("type") == "thinking" and isinstance(block.get("thinking"), list):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(sb, dict) and "index" in sb
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if not isinstance(content, list):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(block, dict) and block.get("type") == "text" and "text" in block:
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if not isinstance(content, list):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(block, dict) and is_data_content_block(block):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(message, ChatMessage):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(message, HumanMessage):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(message, AIMessage):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
elif isinstance(content, list):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(block, dict) and "index" in block
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
block if isinstance(block, dict) else {"type": "text", "text": block}
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(message, SystemMessage):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(message, ToolMessage):
Ensure functions have docstrings for documentation
missing-docstring
def completion_with_retry(
Ensure functions have docstrings for documentation
missing-docstring
def iter_sse() -> Iterator[dict]:
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(self.mistral_api_key, SecretStr):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if token_usage and isinstance(message, AIMessage):
Ensure functions have docstrings for documentation
missing-docstring
def bind_tools(
Ensure functions have docstrings for documentation
missing-docstring
def with_structured_output(
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
Ensure functions have docstrings for documentation
missing-docstring
def lc_secrets(self) -> dict[str, str]:
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
isinstance(schema, dict)
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
elif isinstance(schema, dict) and "name" in schema and "schema" in schema:
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(schema, dict) and isinstance(schema.get("strict"), bool):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
and isinstance(schema, dict)

Get this view in your editor

Same data, no extra tab — call code_get_file + code_get_findings over MCP from Claude/Cursor/Copilot.