# Ensure functions have docstrings for documentation.
async def acompletion_with_retry(
1from __future__ import annotations23import hashlib4import json5import logging6import os7import re8import ssl9import uuid10from collections.abc import Callable, Sequence # noqa: TC00311from operator import itemgetter12from typing import (13 TYPE_CHECKING,14 Any,15 Literal,16 cast,17)1819import certifi20import httpx21from httpx_sse import EventSource, aconnect_sse, connect_sse22from langchain_core.callbacks import (23 AsyncCallbackManagerForLLMRun,24 CallbackManagerForLLMRun,25)26from langchain_core.language_models import (27 LanguageModelInput,28 ModelProfile,29 ModelProfileRegistry,30)31from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams32from langchain_core.language_models.llms import create_base_retry_decorator33from langchain_core.messages import (34 AIMessage,35 AIMessageChunk,36 BaseMessage,37 BaseMessageChunk,38 ChatMessage,39 ChatMessageChunk,40 HumanMessage,41 HumanMessageChunk,42 InvalidToolCall,43 SystemMessage,44 SystemMessageChunk,45 ToolCall,46 ToolMessage,47 is_data_content_block,48)49from langchain_core.messages.block_translators.openai import (50 convert_to_openai_data_block,51)52from langchain_core.messages.tool import tool_call_chunk53from langchain_core.output_parsers import (54 JsonOutputParser,55 PydanticOutputParser,56)57from langchain_core.output_parsers.base import OutputParserLike58from langchain_core.output_parsers.openai_tools import (59 JsonOutputKeyToolsParser,60 PydanticToolsParser,61 make_invalid_tool_call,62 parse_tool_call,63)64from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult65from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough66from langchain_core.tools import BaseTool67from langchain_core.utils import get_pydantic_field_names, secret_from_env68from langchain_core.utils.function_calling import convert_to_openai_tool69from langchain_core.utils.pydantic import is_basemodel_subclass70from langchain_core.utils.utils import 
_build_model_kwargs71from pydantic import (72 BaseModel,73 ConfigDict,74 Field,75 SecretStr,76 model_validator,77)78from typing_extensions import Self7980from langchain_mistralai._compat import _convert_from_v1_to_mistral81from langchain_mistralai.data._profiles import _PROFILES8283if TYPE_CHECKING:84 from collections.abc import AsyncIterator, Iterator85 from contextlib import AbstractAsyncContextManager8687logger = logging.getLogger(__name__)8889# Mistral enforces a specific pattern for tool call IDs90TOOL_CALL_ID_PATTERN = re.compile(r"^[a-zA-Z0-9]{9}$")919293# This SSL context is equivalent to the default `verify=True`.94# https://www.python-httpx.org/advanced/ssl/#configuring-client-instances95global_ssl_context = ssl.create_default_context(cafile=certifi.where())969798_MODEL_PROFILES = cast("ModelProfileRegistry", _PROFILES)99100101def _get_default_model_profile(model_name: str) -> ModelProfile:102 default = _MODEL_PROFILES.get(model_name) or {}103 return default.copy()104105106def _create_retry_decorator(107 llm: ChatMistralAI,108 run_manager: AsyncCallbackManagerForLLMRun | CallbackManagerForLLMRun | None = None,109) -> Callable[[Any], Any]:110 """Return a tenacity retry decorator, preconfigured to handle exceptions."""111 errors = [httpx.RequestError, httpx.StreamError]112 return create_base_retry_decorator(113 error_types=errors, max_retries=llm.max_retries, run_manager=run_manager114 )115116117def _is_valid_mistral_tool_call_id(tool_call_id: str) -> bool:118 """Check if tool call ID is nine character string consisting of a-z, A-Z, 0-9."""119 return bool(TOOL_CALL_ID_PATTERN.match(tool_call_id))120121122def _base62_encode(num: int) -> str:123 """Encode a number in base62 and ensures result is of a specified length."""124 base62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"125 if num == 0:126 return base62[0]127 arr = []128 base = len(base62)129 while num:130 num, rem = divmod(num, base)131 arr.append(base62[rem])132 arr.reverse()133 
return "".join(arr)134135136def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:137 """Convert a tool call ID to a Mistral-compatible format."""138 if _is_valid_mistral_tool_call_id(tool_call_id):139 return tool_call_id140 hash_bytes = hashlib.sha256(tool_call_id.encode()).digest()141 hash_int = int.from_bytes(hash_bytes, byteorder="big")142 base62_str = _base62_encode(hash_int)143 if len(base62_str) >= 9:144 return base62_str[:9]145 return base62_str.rjust(9, "0")146147148def _convert_mistral_chat_message_to_message(149 _message: dict,150) -> BaseMessage:151 role = _message["role"]152 if role != "assistant":153 msg = f"Expected role to be 'assistant', got {role}"154 raise ValueError(msg)155 # Mistral returns None for tool invocations156 content = _message.get("content", "") or ""157158 additional_kwargs: dict = {}159 tool_calls = []160 invalid_tool_calls = []161 if raw_tool_calls := _message.get("tool_calls"):162 additional_kwargs["tool_calls"] = raw_tool_calls163 for raw_tool_call in raw_tool_calls:164 try:165 parsed: dict = cast(166 "dict", parse_tool_call(raw_tool_call, return_id=True)167 )168 if not parsed["id"]:169 parsed["id"] = uuid.uuid4().hex[:]170 tool_calls.append(parsed)171 except Exception as e:172 invalid_tool_calls.append(make_invalid_tool_call(raw_tool_call, str(e)))173 return AIMessage(174 content=content,175 additional_kwargs=additional_kwargs,176 tool_calls=tool_calls,177 invalid_tool_calls=invalid_tool_calls,178 response_metadata={"model_provider": "mistralai"},179 )180181182def _raise_on_error(response: httpx.Response) -> None:183 """Raise an error if the response is an error."""184 if httpx.codes.is_error(response.status_code):185 error_message = response.read().decode("utf-8")186 msg = (187 f"Error response {response.status_code} "188 f"while fetching {response.url}: {error_message}"189 )190 raise httpx.HTTPStatusError(191 msg,192 request=response.request,193 response=response,194 )195196197async def 
_araise_on_error(response: httpx.Response) -> None:198 """Raise an error if the response is an error."""199 if httpx.codes.is_error(response.status_code):200 error_message = (await response.aread()).decode("utf-8")201 msg = (202 f"Error response {response.status_code} "203 f"while fetching {response.url}: {error_message}"204 )205 raise httpx.HTTPStatusError(206 msg,207 request=response.request,208 response=response,209 )210211212async def _aiter_sse(213 event_source_mgr: AbstractAsyncContextManager[EventSource],214) -> AsyncIterator[dict]:215 """Iterate over the server-sent events."""216 async with event_source_mgr as event_source:217 await _araise_on_error(event_source.response)218 async for event in event_source.aiter_sse():219 if event.data == "[DONE]":220 return221 yield event.json()222223224async def acompletion_with_retry(225 llm: ChatMistralAI,226 run_manager: AsyncCallbackManagerForLLMRun | None = None,227 **kwargs: Any,228) -> Any:229 """Use tenacity to retry the async completion call."""230 retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)231232 @retry_decorator233 async def _completion_with_retry(**kwargs: Any) -> Any:234 if "stream" not in kwargs:235 kwargs["stream"] = False236 stream = kwargs["stream"]237 if stream:238 event_source = aconnect_sse(239 llm.async_client, "POST", "/chat/completions", json=kwargs240 )241 return _aiter_sse(event_source)242 response = await llm.async_client.post(url="/chat/completions", json=kwargs)243 await _araise_on_error(response)244 return response.json()245246 return await _completion_with_retry(**kwargs)247248249def _convert_chunk_to_message_chunk(250 chunk: dict,251 default_class: type[BaseMessageChunk],252 index: int,253 index_type: str,254 output_version: str | None,255) -> tuple[BaseMessageChunk, int, str]:256 _choice = chunk["choices"][0]257 _delta = _choice["delta"]258 role = _delta.get("role")259 content = _delta.get("content") or ""260 if output_version == "v1" and isinstance(content, 
str):261 content = [{"type": "text", "text": content}]262 if isinstance(content, list):263 for block in content:264 if isinstance(block, dict):265 if "type" in block and block["type"] != index_type:266 index_type = block["type"]267 index = index + 1268 if "index" not in block:269 block["index"] = index270 if block.get("type") == "thinking" and isinstance(271 block.get("thinking"), list272 ):273 for sub_block in block["thinking"]:274 if isinstance(sub_block, dict) and "index" not in sub_block:275 sub_block["index"] = 0276 if role == "user" or default_class == HumanMessageChunk:277 return HumanMessageChunk(content=content), index, index_type278 if role == "assistant" or default_class == AIMessageChunk:279 additional_kwargs: dict = {}280 response_metadata = {}281 if raw_tool_calls := _delta.get("tool_calls"):282 additional_kwargs["tool_calls"] = raw_tool_calls283 try:284 tool_call_chunks = []285 for raw_tool_call in raw_tool_calls:286 if not raw_tool_call.get("index") and not raw_tool_call.get("id"):287 tool_call_id = uuid.uuid4().hex[:]288 else:289 tool_call_id = raw_tool_call.get("id")290 tool_call_chunks.append(291 tool_call_chunk(292 name=raw_tool_call["function"].get("name"),293 args=raw_tool_call["function"].get("arguments"),294 id=tool_call_id,295 index=raw_tool_call.get("index"),296 )297 )298 except KeyError:299 pass300 else:301 tool_call_chunks = []302 if token_usage := chunk.get("usage"):303 usage_metadata = {304 "input_tokens": token_usage.get("prompt_tokens", 0),305 "output_tokens": token_usage.get("completion_tokens", 0),306 "total_tokens": token_usage.get("total_tokens", 0),307 }308 else:309 usage_metadata = None310 if _choice.get("finish_reason") is not None and isinstance(311 chunk.get("model"), str312 ):313 response_metadata["model_name"] = chunk["model"]314 response_metadata["finish_reason"] = _choice["finish_reason"]315 return (316 AIMessageChunk(317 content=content,318 additional_kwargs=additional_kwargs,319 tool_call_chunks=tool_call_chunks, # 
type: ignore[arg-type]320 usage_metadata=usage_metadata, # type: ignore[arg-type]321 response_metadata={"model_provider": "mistralai", **response_metadata},322 ),323 index,324 index_type,325 )326 if role == "system" or default_class == SystemMessageChunk:327 return SystemMessageChunk(content=content), index, index_type328 if role or default_class == ChatMessageChunk:329 return ChatMessageChunk(content=content, role=role), index, index_type330 return default_class(content=content), index, index_type # type: ignore[call-arg]331332333def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:334 """Format LangChain ToolCall to dict expected by Mistral."""335 result: dict[str, Any] = {336 "function": {337 "name": tool_call["name"],338 "arguments": json.dumps(tool_call["args"], ensure_ascii=False),339 }340 }341 if _id := tool_call.get("id"):342 result["id"] = _convert_tool_call_id_to_mistral_compatible(_id)343344 return result345346347def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> dict:348 """Format LangChain InvalidToolCall to dict expected by Mistral."""349 result: dict[str, Any] = {350 "function": {351 "name": invalid_tool_call["name"],352 "arguments": invalid_tool_call["args"],353 }354 }355 if _id := invalid_tool_call.get("id"):356 result["id"] = _convert_tool_call_id_to_mistral_compatible(_id)357358 return result359360361def _clean_block(block: dict) -> dict:362 # Remove "index" key added for message aggregation in langchain-core363 new_block = {k: v for k, v in block.items() if k != "index"}364 if block.get("type") == "thinking" and isinstance(block.get("thinking"), list):365 new_block["thinking"] = [366 (367 {k: v for k, v in sb.items() if k != "index"}368 if isinstance(sb, dict) and "index" in sb369 else sb370 )371 for sb in block["thinking"]372 ]373 return new_block374375376def _sanitize_chat_completions_content(content: Any) -> Any:377 """Strip non-wire keys from text content blocks.378379 Mistral's chat completions 
endpoint rejects unknown fields on tool380 message content blocks (e.g. the `id` that LangChain auto-generates on381 `TextContentBlock`). For list content, keep only `type` and `text` on382 text blocks; pass other blocks and non-list content through unchanged.383 """384 if not isinstance(content, list):385 return content386 sanitized: list[Any] = []387 for block in content:388 if isinstance(block, dict) and block.get("type") == "text" and "text" in block:389 sanitized.append({"type": "text", "text": block["text"]})390 else:391 sanitized.append(block)392 return sanitized393394395def _format_message_content(content: Any) -> Any:396 """Format message content for the Mistral chat completions wire format.397398 Walks list content and translates LangChain canonical v0/v1 multimodal399 data blocks (e.g. `ImageContentBlock` with `url`, `base64`, or400 `file_id`) into the OpenAI-compatible shape that Mistral accepts:401 `{"type": "image_url", "image_url": {"url": "..."}}`. Strings and any402 other dict blocks are returned unchanged so that already-translated wire403 blocks (e.g. `text`, `image_url`) and Mistral-specific blocks404 (`document_url`, `input_audio`) pass through; the API surfaces an error405 for anything it doesn't understand.406407 Args:408 content: The message content. Strings and non-list values pass409 through unchanged; lists are walked block by block.410411 Returns:412 The formatted content. 
List inputs return a new list with canonical413 data-block translations applied; other inputs are returned as-is.414 """415 if not isinstance(content, list):416 return content417 formatted: list[Any] = []418 for block in content:419 if isinstance(block, dict) and is_data_content_block(block):420 formatted.append(421 convert_to_openai_data_block(block, api="chat/completions")422 )423 continue424 formatted.append(block)425 return formatted426427428def _convert_message_to_mistral_chat_message(429 message: BaseMessage,430) -> dict:431 if isinstance(message, ChatMessage):432 return {"role": message.role, "content": message.content}433 if isinstance(message, HumanMessage):434 return {"role": "user", "content": _format_message_content(message.content)}435 if isinstance(message, AIMessage):436 message_dict: dict[str, Any] = {"role": "assistant"}437 tool_calls: list = []438 if message.tool_calls or message.invalid_tool_calls:439 if message.tool_calls:440 tool_calls.extend(441 _format_tool_call_for_mistral(tool_call)442 for tool_call in message.tool_calls443 )444 if message.invalid_tool_calls:445 tool_calls.extend(446 _format_invalid_tool_call_for_mistral(invalid_tool_call)447 for invalid_tool_call in message.invalid_tool_calls448 )449 elif "tool_calls" in message.additional_kwargs:450 for tc in message.additional_kwargs["tool_calls"]:451 chunk = {452 "function": {453 "name": tc["function"]["name"],454 "arguments": tc["function"]["arguments"],455 }456 }457 if _id := tc.get("id"):458 chunk["id"] = _id459 tool_calls.append(chunk)460 else:461 pass462 if tool_calls: # do not populate empty list tool_calls463 message_dict["tool_calls"] = tool_calls464465 # Message content466 # Translate v1 content467 if message.response_metadata.get("output_version") == "v1":468 content = _convert_from_v1_to_mistral(469 message.content_blocks, message.response_metadata.get("model_provider")470 )471 else:472 content = message.content473474 if tool_calls and content:475 # Assistant message must 
have either content or tool_calls, but not both.476 # Some providers may not support tool_calls in the same message as content.477 # This is done to ensure compatibility with messages from other providers.478 content = ""479480 elif isinstance(content, list):481 content = [482 _clean_block(block)483 if isinstance(block, dict) and "index" in block484 else block485 for block in content486 ]487 else:488 content = message.content489490 # if any blocks are dicts, cast strings to text blocks491 if any(isinstance(block, dict) for block in content):492 content = [493 block if isinstance(block, dict) else {"type": "text", "text": block}494 for block in content495 ]496 message_dict["content"] = content497498 if "prefix" in message.additional_kwargs:499 message_dict["prefix"] = message.additional_kwargs["prefix"]500 return message_dict501 if isinstance(message, SystemMessage):502 return {"role": "system", "content": message.content}503 if isinstance(message, ToolMessage):504 return {505 "role": "tool",506 "content": _sanitize_chat_completions_content(message.content),507 "name": message.name,508 "tool_call_id": _convert_tool_call_id_to_mistral_compatible(509 message.tool_call_id510 ),511 }512 msg = f"Got unknown type {message}"513 raise ValueError(msg)514515516class ChatMistralAI(BaseChatModel):517 """A chat model that uses the Mistral AI API."""518519 # The type for client and async_client is ignored because the type is not520 # an Optional after the model is initialized and the model_validator521 # is run.522 client: httpx.Client = Field( # type: ignore[assignment] # : meta private:523 default=None, exclude=True524 )525526 async_client: httpx.AsyncClient = Field( # type: ignore[assignment] # : meta private:527 default=None, exclude=True528 )529530 mistral_api_key: SecretStr | None = Field(531 alias="api_key",532 default_factory=secret_from_env("MISTRAL_API_KEY", default=None),533 )534535 endpoint: str | None = Field(default=None, alias="base_url")536537 max_retries: int = 
5538539 timeout: int = 120540541 max_concurrent_requests: int = 64542543 model: str = Field(default="mistral-small", alias="model_name")544545 temperature: float = 0.7546547 max_tokens: int | None = None548549 top_p: float = 1550 """Decode using nucleus sampling: consider the smallest set of tokens whose551 probability sum is at least `top_p`. Must be in the closed interval552 `[0.0, 1.0]`."""553554 random_seed: int | None = None555556 safe_mode: bool | None = None557558 streaming: bool = False559560 model_kwargs: dict[str, Any] = Field(default_factory=dict)561 """Holds any invocation parameters not explicitly specified."""562563 model_config = ConfigDict(564 populate_by_name=True,565 arbitrary_types_allowed=True,566 )567568 @model_validator(mode="before")569 @classmethod570 def build_extra(cls, values: dict[str, Any]) -> Any:571 """Build extra kwargs from additional params that were passed in."""572 all_required_field_names = get_pydantic_field_names(cls)573 return _build_model_kwargs(values, all_required_field_names)574575 @property576 def _default_params(self) -> dict[str, Any]:577 """Get the default parameters for calling the API."""578 defaults = {579 "model": self.model,580 "temperature": self.temperature,581 "max_tokens": self.max_tokens,582 "top_p": self.top_p,583 "random_seed": self.random_seed,584 "safe_prompt": self.safe_mode,585 **self.model_kwargs,586 }587 return {k: v for k, v in defaults.items() if v is not None}588589 def _get_ls_params(590 self, stop: list[str] | None = None, **kwargs: Any591 ) -> LangSmithParams:592 """Get standard params for tracing."""593 params = self._get_invocation_params(stop=stop, **kwargs)594 ls_params = LangSmithParams(595 ls_provider="mistral",596 ls_model_name=params.get("model", self.model),597 ls_model_type="chat",598 ls_temperature=params.get("temperature", self.temperature),599 )600 if ls_max_tokens := params.get("max_tokens", self.max_tokens):601 ls_params["ls_max_tokens"] = ls_max_tokens602 if ls_stop := stop or 
params.get("stop", None):603 ls_params["ls_stop"] = ls_stop604 return ls_params605606 @property607 def _client_params(self) -> dict[str, Any]:608 """Get the parameters used for the client."""609 return self._default_params610611 def completion_with_retry(612 self, run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any613 ) -> Any:614 """Use tenacity to retry the completion call."""615 retry_decorator = _create_retry_decorator(self, run_manager=run_manager)616617 @retry_decorator618 def _completion_with_retry(**kwargs: Any) -> Any:619 if "stream" not in kwargs:620 kwargs["stream"] = False621 stream = kwargs["stream"]622 if stream:623624 def iter_sse() -> Iterator[dict]:625 with connect_sse(626 self.client, "POST", "/chat/completions", json=kwargs627 ) as event_source:628 _raise_on_error(event_source.response)629 for event in event_source.iter_sse():630 if event.data == "[DONE]":631 return632 yield event.json()633634 return iter_sse()635 response = self.client.post(url="/chat/completions", json=kwargs)636 _raise_on_error(response)637 return response.json()638639 return _completion_with_retry(**kwargs)640641 def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict:642 overall_token_usage: dict = {}643 for output in llm_outputs:644 if output is None:645 # Happens in streaming646 continue647 token_usage = output["token_usage"]648 if token_usage is not None:649 for k, v in token_usage.items():650 if k in overall_token_usage:651 overall_token_usage[k] += v652 else:653 overall_token_usage[k] = v654 return {"token_usage": overall_token_usage, "model_name": self.model}655656 @model_validator(mode="after")657 def validate_environment(self) -> Self:658 """Validate api key, python package exists, temperature, and top_p."""659 if isinstance(self.mistral_api_key, SecretStr):660 api_key_str: str | None = self.mistral_api_key.get_secret_value()661 else:662 api_key_str = self.mistral_api_key663664 # TODO: handle retries665 base_url_str = (666 
self.endpoint667 or os.environ.get("MISTRAL_BASE_URL")668 or "https://api.mistral.ai/v1"669 )670 self.endpoint = base_url_str671 if not self.client:672 self.client = httpx.Client(673 base_url=base_url_str,674 headers={675 "Content-Type": "application/json",676 "Accept": "application/json",677 "Authorization": f"Bearer {api_key_str}",678 },679 timeout=self.timeout,680 verify=global_ssl_context,681 )682 # TODO: handle retries and max_concurrency683 if not self.async_client:684 self.async_client = httpx.AsyncClient(685 base_url=base_url_str,686 headers={687 "Content-Type": "application/json",688 "Accept": "application/json",689 "Authorization": f"Bearer {api_key_str}",690 },691 timeout=self.timeout,692 verify=global_ssl_context,693 )694695 if self.temperature is not None and not 0 <= self.temperature <= 1:696 msg = "temperature must be in the range [0.0, 1.0]"697 raise ValueError(msg)698699 if self.top_p is not None and not 0 <= self.top_p <= 1:700 msg = "top_p must be in the range [0.0, 1.0]"701 raise ValueError(msg)702703 return self704705 def _resolve_model_profile(self) -> ModelProfile | None:706 return _get_default_model_profile(self.model) or None707708 def _generate(709 self,710 messages: list[BaseMessage],711 stop: list[str] | None = None,712 run_manager: CallbackManagerForLLMRun | None = None,713 stream: bool | None = None, # noqa: FBT001714 **kwargs: Any,715 ) -> ChatResult:716 message_dicts, params = self._create_message_dicts(messages, stop)717 params = {**params, **kwargs}718 response = self.completion_with_retry(719 messages=message_dicts, run_manager=run_manager, **params720 )721 return self._create_chat_result(response)722723 def _create_chat_result(self, response: dict) -> ChatResult:724 generations = []725 token_usage = response.get("usage", {})726 for res in response["choices"]:727 finish_reason = res.get("finish_reason")728 message = _convert_mistral_chat_message_to_message(res["message"])729 if token_usage and isinstance(message, AIMessage):730 
message.usage_metadata = {731 "input_tokens": token_usage.get("prompt_tokens", 0),732 "output_tokens": token_usage.get("completion_tokens", 0),733 "total_tokens": token_usage.get("total_tokens", 0),734 }735 gen = ChatGeneration(736 message=message,737 generation_info={"finish_reason": finish_reason},738 )739 generations.append(gen)740741 llm_output = {742 "token_usage": token_usage,743 "model_name": self.model,744 "model": self.model, # Backwards compatibility745 }746 return ChatResult(generations=generations, llm_output=llm_output)747748 def _create_message_dicts(749 self, messages: list[BaseMessage], stop: list[str] | None750 ) -> tuple[list[dict], dict[str, Any]]:751 params = self._client_params752 if stop is not None or "stop" in params:753 if "stop" in params:754 params.pop("stop")755 logger.warning(756 "Parameter `stop` not yet supported (https://docs.mistral.ai/api)"757 )758 message_dicts = [_convert_message_to_mistral_chat_message(m) for m in messages]759 return message_dicts, params760761 def _stream(762 self,763 messages: list[BaseMessage],764 stop: list[str] | None = None,765 run_manager: CallbackManagerForLLMRun | None = None,766 **kwargs: Any,767 ) -> Iterator[ChatGenerationChunk]:768 message_dicts, params = self._create_message_dicts(messages, stop)769 params = {**params, **kwargs, "stream": True}770771 default_chunk_class: type[BaseMessageChunk] = AIMessageChunk772 index = -1773 index_type = ""774 for chunk in self.completion_with_retry(775 messages=message_dicts, run_manager=run_manager, **params776 ):777 if len(chunk.get("choices", [])) == 0:778 continue779 new_chunk, index, index_type = _convert_chunk_to_message_chunk(780 chunk, default_chunk_class, index, index_type, self.output_version781 )782 # make future chunks same type as first chunk783 default_chunk_class = new_chunk.__class__784 gen_chunk = ChatGenerationChunk(message=new_chunk)785 if run_manager:786 run_manager.on_llm_new_token(787 token=cast("str", new_chunk.content), chunk=gen_chunk788 
)789 yield gen_chunk790791 async def _astream(792 self,793 messages: list[BaseMessage],794 stop: list[str] | None = None,795 run_manager: AsyncCallbackManagerForLLMRun | None = None,796 **kwargs: Any,797 ) -> AsyncIterator[ChatGenerationChunk]:798 message_dicts, params = self._create_message_dicts(messages, stop)799 params = {**params, **kwargs, "stream": True}800801 default_chunk_class: type[BaseMessageChunk] = AIMessageChunk802 index = -1803 index_type = ""804 async for chunk in await acompletion_with_retry(805 self, messages=message_dicts, run_manager=run_manager, **params806 ):807 if len(chunk.get("choices", [])) == 0:808 continue809 new_chunk, index, index_type = _convert_chunk_to_message_chunk(810 chunk, default_chunk_class, index, index_type, self.output_version811 )812 # make future chunks same type as first chunk813 default_chunk_class = new_chunk.__class__814 gen_chunk = ChatGenerationChunk(message=new_chunk)815 if run_manager:816 await run_manager.on_llm_new_token(817 token=cast("str", new_chunk.content), chunk=gen_chunk818 )819 yield gen_chunk820821 async def _agenerate(822 self,823 messages: list[BaseMessage],824 stop: list[str] | None = None,825 run_manager: AsyncCallbackManagerForLLMRun | None = None,826 stream: bool | None = None, # noqa: FBT001827 **kwargs: Any,828 ) -> ChatResult:829 message_dicts, params = self._create_message_dicts(messages, stop)830 params = {**params, **kwargs}831 response = await acompletion_with_retry(832 self, messages=message_dicts, run_manager=run_manager, **params833 )834 return self._create_chat_result(response)835836 def bind_tools(837 self,838 tools: Sequence[dict[str, Any] | type | Callable | BaseTool],839 tool_choice: dict | str | Literal["auto", "any"] | None = None, # noqa: PYI051840 **kwargs: Any,841 ) -> Runnable[LanguageModelInput, AIMessage]:842 """Bind tool-like objects to this chat model.843844 Assumes model is compatible with OpenAI tool-calling API.845846 Args:847 tools: A list of tool definitions to bind 
to this chat model.848849 Supports any tool definition handled by [`convert_to_openai_tool`][langchain_core.utils.function_calling.convert_to_openai_tool].850 tool_choice: Which tool to require the model to call.851 Must be the name of the single provided function or852 `'auto'` to automatically determine which function to call853 (if any), or a dict of the form:854 {"type": "function", "function": {"name": <<tool_name>>}}.855 kwargs: Any additional parameters are passed directly to856 `self.bind(**kwargs)`.857 """ # noqa: E501858 formatted_tools = [convert_to_openai_tool(tool) for tool in tools]859 if tool_choice:860 tool_names = []861 for tool in formatted_tools:862 if ("function" in tool and (name := tool["function"].get("name"))) or (863 name := tool.get("name")864 ):865 tool_names.append(name)866 else:867 pass868 if tool_choice in tool_names:869 kwargs["tool_choice"] = {870 "type": "function",871 "function": {"name": tool_choice},872 }873 else:874 kwargs["tool_choice"] = tool_choice875 return super().bind(tools=formatted_tools, **kwargs)876877 def with_structured_output(878 self,879 schema: dict | type | None = None,880 *,881 method: Literal[882 "function_calling", "json_mode", "json_schema"883 ] = "function_calling",884 include_raw: bool = False,885 **kwargs: Any,886 ) -> Runnable[LanguageModelInput, dict | BaseModel]:887 r"""Model wrapper that returns outputs formatted to match the given schema.888889 Args:890 schema: The output schema. Can be passed in as:891892 - An OpenAI function/tool schema,893 - A JSON Schema,894 - A `TypedDict` class,895 - Or a Pydantic class.896897 If `schema` is a Pydantic class then the model output will be a898 Pydantic instance of that class, and the model-generated fields will be899 validated by the Pydantic class. 
Otherwise the model output will be a900 dict and will not be validated.901902 See `langchain_core.utils.function_calling.convert_to_openai_tool` for903 more on how to properly specify types and descriptions of schema fields904 when specifying a Pydantic or `TypedDict` class.905906 method: The method for steering model generation, one of:907908 - `'function_calling'`:909 Uses Mistral's910 [function-calling feature](https://docs.mistral.ai/capabilities/function_calling/).911 - `'json_schema'`:912 Uses Mistral's913 [structured output feature](https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/).914 - `'json_mode'`:915 Uses Mistral's916 [JSON mode](https://docs.mistral.ai/capabilities/structured-output/json_mode/).917 Note that if using JSON mode then you918 must include instructions for formatting the output into the919 desired schema into the model call.920921 !!! warning "Behavior changed in `langchain-mistralai` 0.2.5"922923 Added method="json_schema"924925 include_raw:926 If `False` then only the parsed structured output is returned.927928 If an error occurs during model output parsing it will be raised.929930 If `True` then both the raw model response (a `BaseMessage`) and the931 parsed model response will be returned.932933 If an error occurs during output parsing it will be caught and returned934 as well.935936 The final output is always a `dict` with keys `'raw'`, `'parsed'`, and937 `'parsing_error'`.938939 kwargs: Any additional parameters are passed directly to940 `self.bind(**kwargs)`. This is useful for passing in941 parameters such as `tool_choice` or `tools` to control942 which tool the model should call, or to pass in parameters such as943 `stop` to control when the model should stop generating output.944945 Returns:946 A `Runnable` that takes same inputs as a947 `langchain_core.language_models.chat.BaseChatModel`. 
If `include_raw` is948 `False` and `schema` is a Pydantic class, `Runnable` outputs an instance949 of `schema` (i.e., a Pydantic object). Otherwise, if `include_raw` is950 `False` then `Runnable` outputs a `dict`.951952 If `include_raw` is `True`, then `Runnable` outputs a `dict` with keys:953954 - `'raw'`: `BaseMessage`955 - `'parsed'`: `None` if there was a parsing error, otherwise the type956 depends on the `schema` as described above.957 - `'parsing_error'`: `BaseException | None`958959 Example: schema=Pydantic class, method="function_calling", include_raw=False:960961 ```python962 from typing import Optional963964 from langchain_mistralai import ChatMistralAI965 from pydantic import BaseModel, Field966967968 class AnswerWithJustification(BaseModel):969 '''An answer to the user question along with justification for the answer.'''970971 answer: str972 # If we provide default values and/or descriptions for fields, these will be passed973 # to the model. This is an important part of improving a model's ability to974 # correctly return structured outputs.975 justification: str | None = Field(976 default=None, description="A justification for the answer."977 )978979980 model = ChatMistralAI(model="mistral-large-latest", temperature=0)981 structured_model = model.with_structured_output(AnswerWithJustification)982983 structured_model.invoke(984 "What weighs more a pound of bricks or a pound of feathers"985 )986987 # -> AnswerWithJustification(988 # answer='They weigh the same',989 # justification='Both a pound of bricks and a pound of feathers weigh one pound. 
The weight is the same, but the volume or density of the objects may differ.'990 # )991 ```992993 Example: schema=Pydantic class, method="function_calling", include_raw=True:994995 ```python996 from langchain_mistralai import ChatMistralAI997 from pydantic import BaseModel9989991000 class AnswerWithJustification(BaseModel):1001 '''An answer to the user question along with justification for the answer.'''10021003 answer: str1004 justification: str100510061007 model = ChatMistralAI(model="mistral-large-latest", temperature=0)1008 structured_model = model.with_structured_output(1009 AnswerWithJustification, include_raw=True1010 )10111012 structured_model.invoke(1013 "What weighs more a pound of bricks or a pound of feathers"1014 )1015 # -> {1016 # 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),1017 # 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. 
The weight is the same, but the volume or density of the objects may differ.'),1018 # 'parsing_error': None1019 # }1020 ```10211022 Example: schema=TypedDict class, method="function_calling", include_raw=False:10231024 ```python1025 from typing_extensions import Annotated, TypedDict10261027 from langchain_mistralai import ChatMistralAI102810291030 class AnswerWithJustification(TypedDict):1031 '''An answer to the user question along with justification for the answer.'''10321033 answer: str1034 justification: Annotated[1035 str | None, None, "A justification for the answer."1036 ]103710381039 model = ChatMistralAI(model="mistral-large-latest", temperature=0)1040 structured_model = model.with_structured_output(AnswerWithJustification)10411042 structured_model.invoke(1043 "What weighs more a pound of bricks or a pound of feathers"1044 )1045 # -> {1046 # 'answer': 'They weigh the same',1047 # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'1048 # }1049 ```10501051 Example: schema=OpenAI function schema, method="function_calling", include_raw=False:10521053 ```python1054 from langchain_mistralai import ChatMistralAI10551056 oai_schema = {1057 'name': 'AnswerWithJustification',1058 'description': 'An answer to the user question along with justification for the answer.',1059 'parameters': {1060 'type': 'object',1061 'properties': {1062 'answer': {'type': 'string'},1063 'justification': {'description': 'A justification for the answer.', 'type': 'string'}1064 },1065 'required': ['answer']1066 }10671068 model = ChatMistralAI(model="mistral-large-latest", temperature=0)1069 structured_model = model.with_structured_output(oai_schema)10701071 structured_model.invoke(1072 "What weighs more a pound of bricks or a pound of feathers"1073 )1074 # -> {1075 # 'answer': 'They weigh the same',1076 # 'justification': 'Both a pound of bricks and a pound of feathers weigh one 
pound. The weight is the same, but the volume and density of the two substances differ.'1077 # }1078 ```10791080 Example: schema=Pydantic class, method="json_mode", include_raw=True:10811082 ```python1083 from langchain_mistralai import ChatMistralAI1084 from pydantic import BaseModel108510861087 class AnswerWithJustification(BaseModel):1088 answer: str1089 justification: str109010911092 model = ChatMistralAI(model="mistral-large-latest", temperature=0)1093 structured_model = model.with_structured_output(1094 AnswerWithJustification, method="json_mode", include_raw=True1095 )10961097 structured_model.invoke(1098 "Answer the following question. "1099 "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"1100 "What's heavier a pound of bricks or a pound of feathers?"1101 )1102 # -> {1103 # 'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),1104 # 'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),1105 # 'parsing_error': None1106 # }1107 ```11081109 Example: schema=None, method="json_mode", include_raw=True:11101111 ```python1112 structured_model = model.with_structured_output(1113 method="json_mode", include_raw=True1114 )11151116 structured_model.invoke(1117 "Answer the following question. "1118 "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"1119 "What's heavier a pound of bricks or a pound of feathers?"1120 )1121 # -> {1122 # 'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. 
The difference lies in the volume and density of the materials, not the weight." \\n}'),1123 # 'parsed': {1124 # 'answer': 'They are both the same weight.',1125 # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'1126 # },1127 # 'parsing_error': None1128 # }1129 ```1130 """ # noqa: E5011131 _ = kwargs.pop("strict", None)1132 if kwargs:1133 msg = f"Received unsupported arguments {kwargs}"1134 raise ValueError(msg)1135 is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)1136 if method == "function_calling":1137 if schema is None:1138 msg = (1139 "schema must be specified when method is 'function_calling'. "1140 "Received None."1141 )1142 raise ValueError(msg)1143 # TODO: Update to pass in tool name as tool_choice if/when Mistral supports1144 # specifying a tool.1145 llm = self.bind_tools(1146 [schema],1147 tool_choice="any",1148 ls_structured_output_format={1149 "kwargs": {"method": "function_calling"},1150 "schema": schema,1151 },1152 )1153 if is_pydantic_schema:1154 output_parser: OutputParserLike = PydanticToolsParser(1155 tools=[schema], # type: ignore[list-item]1156 first_tool_only=True, # type: ignore[list-item]1157 )1158 else:1159 key_name = convert_to_openai_tool(schema)["function"]["name"]1160 output_parser = JsonOutputKeyToolsParser(1161 key_name=key_name, first_tool_only=True1162 )1163 elif method == "json_mode":1164 llm = self.bind(1165 response_format={"type": "json_object"},1166 ls_structured_output_format={1167 "kwargs": {1168 # this is correct - name difference with mistral api1169 "method": "json_mode"1170 },1171 "schema": schema,1172 },1173 )1174 output_parser = (1175 PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]1176 if is_pydantic_schema1177 else JsonOutputParser()1178 )1179 elif method == "json_schema":1180 if schema is None:1181 msg = (1182 "schema must be specified when 
method is 'json_schema'. "1183 "Received None."1184 )1185 raise ValueError(msg)1186 response_format = _convert_to_openai_response_format(schema, strict=True)1187 llm = self.bind(1188 response_format=response_format,1189 ls_structured_output_format={1190 "kwargs": {"method": "json_schema"},1191 "schema": schema,1192 },1193 )11941195 output_parser = (1196 PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]1197 if is_pydantic_schema1198 else JsonOutputParser()1199 )1200 if include_raw:1201 parser_assign = RunnablePassthrough.assign(1202 parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None1203 )1204 parser_none = RunnablePassthrough.assign(parsed=lambda _: None)1205 parser_with_fallback = parser_assign.with_fallbacks(1206 [parser_none], exception_key="parsing_error"1207 )1208 return RunnableMap(raw=llm) | parser_with_fallback1209 return llm | output_parser12101211 @property1212 def _identifying_params(self) -> dict[str, Any]:1213 """Get the identifying parameters."""1214 return self._default_params12151216 @property1217 def _llm_type(self) -> str:1218 """Return type of chat model."""1219 return "mistralai-chat"12201221 @property1222 def lc_secrets(self) -> dict[str, str]:1223 return {"mistral_api_key": "MISTRAL_API_KEY"}12241225 @classmethod1226 def is_lc_serializable(cls) -> bool:1227 """Return whether this model can be serialized by LangChain."""1228 return True12291230 @classmethod1231 def get_lc_namespace(cls) -> list[str]:1232 """Get the namespace of the LangChain object.12331234 Returns:1235 `["langchain", "chat_models", "mistralai"]`1236 """1237 return ["langchain", "chat_models", "mistralai"]123812391240def _convert_to_openai_response_format(1241 schema: dict[str, Any] | type, *, strict: bool | None = None1242) -> dict:1243 """Perform same op as in ChatOpenAI, but do not pass through Pydantic BaseModels."""1244 if (1245 isinstance(schema, dict)1246 and "json_schema" in schema1247 and schema.get("type") == "json_schema"1248 
):1249 response_format = schema1250 elif isinstance(schema, dict) and "name" in schema and "schema" in schema:1251 response_format = {"type": "json_schema", "json_schema": schema}1252 else:1253 if strict is None:1254 if isinstance(schema, dict) and isinstance(schema.get("strict"), bool):1255 strict = schema["strict"]1256 else:1257 strict = False1258 function = convert_to_openai_tool(schema, strict=strict)["function"]1259 function["schema"] = function.pop("parameters")1260 response_format = {"type": "json_schema", "json_schema": function}12611262 if (1263 strict is not None1264 and strict is not response_format["json_schema"].get("strict")1265 and isinstance(schema, dict)1266 ):1267 msg = (1268 f"Output schema already has 'strict' value set to "1269 f"{schema['json_schema']['strict']} but 'strict' also passed in to "1270 f"with_structured_output as {strict}. Please make sure that "1271 f"'strict' is only specified in one place."1272 )1273 raise ValueError(msg)1274 return response_format
Same data, no extra tab — call code_get_file + code_get_findings over MCP from Claude/Cursor/Copilot.