libs/partners/mistralai/langchain_mistralai/embeddings.py PYTHON 347 lines View on github.com → Search inside
import asyncio
import logging
import warnings
from collections.abc import Callable, Iterable

import httpx
from httpx import Response
from langchain_core.embeddings import Embeddings
from langchain_core.utils import (
    secret_from_env,
)
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    SecretStr,
    model_validator,
)
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_fixed
from tokenizers import Tokenizer  # type: ignore[import]
from typing_extensions import Self

logger = logging.getLogger(__name__)

MAX_TOKENS = 16_000
"""A batching parameter for the Mistral API. This is NOT the maximum number of tokens
accepted by the embedding model for each document/chunk, but rather the maximum number
of tokens that can be sent in a single request to the Mistral API (across multiple
documents/chunks)"""


def _is_retryable_error(exception: BaseException) -> bool:
    """Determine if an exception should trigger a retry.

    Only retries on:
    - Timeout exceptions
    - 429 (rate limit) errors
    - 5xx (server) errors

    Does NOT retry on 400 (bad request) or other 4xx client errors.

    Args:
        exception: The exception raised by the wrapped request function.

    Returns:
        `True` if the request should be attempted again, `False` otherwise.
    """
    if isinstance(exception, httpx.TimeoutException):
        return True
    if isinstance(exception, httpx.HTTPStatusError):
        status_code = exception.response.status_code
        # Retry on rate limit (429) or server errors (5xx)
        return status_code == 429 or status_code >= 500
    return False


class DummyTokenizer:
    """Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)."""

    @staticmethod
    def encode_batch(texts: list[str]) -> list[list[str]]:
        """Split each text into its characters.

        The result is only used for its length, so every character counts as one
        "token" when estimating batch sizes.

        Args:
            texts: The texts to encode.

        Returns:
            One list of single-character strings per input text.
        """
        return [list(text) for text in texts]


class MistralAIEmbeddings(BaseModel, Embeddings):
    """MistralAI embedding model integration.

    Setup:
        Install `langchain_mistralai` and set environment variable
        `MISTRAL_API_KEY`.

        ```bash
        pip install -U langchain_mistralai
        export MISTRAL_API_KEY="your-api-key"
        ```

    Key init args — completion params:
        model:
            Name of `MistralAI` model to use.

    Key init args — client params:
        api_key:
            The API key for the MistralAI API. If not provided, it will be read from the
            environment variable `MISTRAL_API_KEY`.
        max_retries:
            The number of times to attempt a request before giving up. `None`
            disables retrying.
        timeout:
            The number of seconds to wait for a response before timing out.
        wait_time:
            The number of seconds to wait between attempts after a retryable error
            (timeout, 429 or 5xx). `None` disables retrying.
        max_concurrent_requests:
            The maximum number of concurrent requests to make to the Mistral API.

    See full list of supported init args and their descriptions in the params section.

    Instantiate:

        ```python
        from langchain_mistralai.embeddings import MistralAIEmbeddings

        embed = MistralAIEmbeddings(
            model="mistral-embed",
            # api_key="...",
            # other params...
        )
        ```

    Embed single text:

        ```python
        input_text = "The meaning of life is 42"
        vector = embed.embed_query(input_text)
        print(vector[:3])
        ```
        ```python
        [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
        ```

    Embed multiple text:

        ```python
        input_texts = ["Document 1...", "Document 2..."]
        vectors = embed.embed_documents(input_texts)
        print(len(vectors))
        # The first 3 coordinates for the first vector
        print(vectors[0][:3])
        ```
        ```python
        2
        [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
        ```

    Async:

        ```python
        vector = await embed.aembed_query(input_text)
        print(vector[:3])

        # multiple:
        # await embed.aembed_documents(input_texts)
        ```
        ```python
        [-0.009100092574954033, 0.005071679595857859, -0.0029193938244134188]
        ```
    """

    # The type for client and async_client is ignored because the type is not
    # an Optional after the model is initialized and the model_validator
    # is run.
    client: httpx.Client = Field(default=None)  # type: ignore[assignment]

    async_client: httpx.AsyncClient = Field(  # type: ignore[assignment]
        default=None
    )

    mistral_api_key: SecretStr = Field(
        alias="api_key",
        default_factory=secret_from_env("MISTRAL_API_KEY", default=""),
    )

    # Base URL used for both the sync and the async HTTP client.
    endpoint: str = "https://api.mistral.ai/v1/"

    # Passed to tenacity's `stop_after_attempt`; `None` disables retrying.
    max_retries: int | None = 5

    # Request timeout in seconds, handed to the httpx clients.
    timeout: int = 120

    # Fixed delay in seconds between attempts; `None` disables retrying.
    wait_time: int | None = 30

    # Upper bound on in-flight requests issued by `aembed_documents`.
    max_concurrent_requests: int = 64

    # Only used to estimate token counts when splitting texts into batches.
    tokenizer: Tokenizer = Field(default=None)

    model: str = "mistral-embed"

    model_config = ConfigDict(
        extra="forbid",
        arbitrary_types_allowed=True,
        populate_by_name=True,
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate configuration.

        Creates the sync and async HTTP clients and the tokenizer when the caller
        did not supply them.

        Returns:
            The validated model instance.
        """
        api_key_str = self.mistral_api_key.get_secret_value()
        # Retries are applied per request by `_retry`, not by the HTTP clients.
        if not self.client:
            self.client = httpx.Client(
                base_url=self.endpoint,
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                    "Authorization": f"Bearer {api_key_str}",
                },
                timeout=self.timeout,
            )
        # Concurrency is bounded in `aembed_documents`, not by the client itself.
        if not self.async_client:
            self.async_client = httpx.AsyncClient(
                base_url=self.endpoint,
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                    "Authorization": f"Bearer {api_key_str}",
                },
                timeout=self.timeout,
            )
        if self.tokenizer is None:
            try:
                self.tokenizer = Tokenizer.from_pretrained(
                    "mistralai/Mixtral-8x7B-v0.1"
                )
            except OSError:  # huggingface_hub GatedRepoError
                warnings.warn(
                    "Could not download mistral tokenizer from Huggingface for "
                    "calculating batch sizes. Set a Huggingface token via the "
                    "HF_TOKEN environment variable to download the real tokenizer. "
                    "Falling back to a dummy tokenizer that uses `len()`.",
                    stacklevel=2,
                )
                self.tokenizer = DummyTokenizer()
        return self

    def _get_batches(self, texts: list[str]) -> Iterable[list[str]]:
        """Split list of texts into batches of less than 16k tokens for Mistral API.

        A single text whose token count alone exceeds `MAX_TOKENS` is still
        yielded, in a batch of its own.

        Args:
            texts: The texts to group, in the order they should be embedded.

        Yields:
            Consecutive, non-empty batches that preserve the input order.
        """
        batch: list[str] = []
        batch_tokens = 0

        text_token_lengths = [
            len(encoded) for encoded in self.tokenizer.encode_batch(texts)
        ]

        for text, text_tokens in zip(texts, text_token_lengths, strict=False):
            if batch_tokens + text_tokens > MAX_TOKENS:
                if len(batch) > 0:
                    # edge case where first batch exceeds max tokens
                    # should not yield an empty batch.
                    yield batch
                batch = [text]
                batch_tokens = text_tokens
            else:
                batch.append(text)
                batch_tokens += text_tokens
        if batch:
            yield batch

    def _retry(self, func: Callable) -> Callable:
        """Wrap `func` with the configured retry policy.

        Args:
            func: The sync or async request function to wrap.

        Returns:
            `func` unchanged when `max_retries` or `wait_time` is `None`,
            otherwise `func` decorated with a tenacity retry that waits
            `wait_time` seconds between attempts and stops after `max_retries`
            attempts, retrying only errors accepted by `_is_retryable_error`.
        """
        if self.max_retries is None or self.wait_time is None:
            return func

        return retry(
            retry=retry_if_exception(_is_retryable_error),
            wait=wait_fixed(self.wait_time),
            stop=stop_after_attempt(self.max_retries),
        )(func)

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of document texts.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.

        """
        try:

            @self._retry
            def _embed_batch(batch: list[str]) -> Response:
                response = self.client.post(
                    url="/embeddings",
                    json={
                        "model": self.model,
                        "input": batch,
                    },
                )
                response.raise_for_status()
                return response

            batch_responses = [
                _embed_batch(batch) for batch in self._get_batches(texts)
            ]
            return [
                list(map(float, embedding_obj["embedding"]))
                for response in batch_responses
                for embedding_obj in response.json()["data"]
            ]
        except Exception:
            logger.exception("An error occurred with MistralAI")
            raise

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of document texts.

        Batches are sent concurrently, with at most `max_concurrent_requests`
        requests in flight at any time.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        try:
            # A limit below 1 would block forever (0) or raise (negative), so
            # clamp it to one request at a time.
            semaphore = asyncio.Semaphore(max(1, self.max_concurrent_requests))

            @self._retry
            async def _aembed_batch(batch: list[str]) -> Response:
                # Hold a slot only while the request is in flight so that the
                # wait between retry attempts does not block other batches.
                async with semaphore:
                    response = await self.async_client.post(
                        url="/embeddings",
                        json={
                            "model": self.model,
                            "input": batch,
                        },
                    )
                response.raise_for_status()
                return response

            # `gather` returns results in submission order, so the embeddings
            # line up with the input texts.
            batch_responses = await asyncio.gather(
                *[_aembed_batch(batch) for batch in self._get_batches(texts)]
            )
            return [
                list(map(float, embedding_obj["embedding"]))
                for response in batch_responses
                for embedding_obj in response.json()["data"]
            ]
        except Exception:
            logger.exception("An error occurred with MistralAI")
            raise

    def embed_query(self, text: str) -> list[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.

        """
        return self.embed_documents([text])[0]

    async def aembed_query(self, text: str) -> list[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.

        """
        return (await self.aembed_documents([text]))[0]

Code quality findings 13

Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(exception, httpx.TimeoutException):
Overuse may indicate design issues; consider polymorphism
isinstance-overuse
if isinstance(exception, httpx.HTTPStatusError):
Ensure functions have docstrings for documentation
missing-docstring
def encode_batch(texts: list[str]) -> list[list[str]]:
Use logging module for better control and configurability
print-statement
print(vector[:3])
Use logging module for better control and configurability
print-statement
print(len(vectors))
Use logging module for better control and configurability
print-statement
print(vectors[0][:3])
Use logging module for better control and configurability
print-statement
print(vector[:3])
Ensure try blocks have corresponding except or finally blocks
try-without-except
try:
Avoid unnecessary list conversions; use generators where possible
unnecessary-list
list(map(float, embedding_obj["embedding"]))
Catch specific exceptions instead of Exception to avoid masking bugs
broad-except
except Exception:
Ensure try blocks have corresponding except or finally blocks
try-without-except
try:
Avoid unnecessary list conversions; use generators where possible
unnecessary-list
list(map(float, embedding_obj["embedding"]))
Catch specific exceptions instead of Exception to avoid masking bugs
broad-except
except Exception:

Get this view in your editor

Same data, no extra tab — call code_get_file + code_get_findings over MCP from Claude/Cursor/Copilot.