Overuse may indicate design issues; consider polymorphism
if isinstance(message, ChatMessage):
1"""Hugging Face Chat Wrapper."""23from __future__ import annotations45import contextlib6import json7from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence8from dataclasses import dataclass9from operator import itemgetter10from typing import TYPE_CHECKING, Any, Literal, cast1112if TYPE_CHECKING:13 from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint14 from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline1516from langchain_core.callbacks.manager import (17 AsyncCallbackManagerForLLMRun,18 CallbackManagerForLLMRun,19)20from langchain_core.language_models import (21 LanguageModelInput,22 ModelProfile,23 ModelProfileRegistry,24)25from langchain_core.language_models.chat_models import (26 BaseChatModel,27 agenerate_from_stream,28 generate_from_stream,29)30from langchain_core.messages import (31 AIMessage,32 AIMessageChunk,33 BaseMessage,34 BaseMessageChunk,35 ChatMessage,36 ChatMessageChunk,37 FunctionMessage,38 FunctionMessageChunk,39 HumanMessage,40 HumanMessageChunk,41 InvalidToolCall,42 SystemMessage,43 SystemMessageChunk,44 ToolCall,45 ToolMessage,46 ToolMessageChunk,47)48from langchain_core.messages.tool import ToolCallChunk49from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk50from langchain_core.output_parsers import JsonOutputParser51from langchain_core.output_parsers.openai_tools import (52 JsonOutputKeyToolsParser,53 make_invalid_tool_call,54 parse_tool_call,55)56from langchain_core.outputs import (57 ChatGeneration,58 ChatGenerationChunk,59 ChatResult,60 LLMResult,61)62from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough63from langchain_core.tools import BaseTool64from langchain_core.utils.function_calling import (65 convert_to_json_schema,66 convert_to_openai_tool,67)68from langchain_core.utils.pydantic import is_basemodel_subclass69from pydantic import BaseModel, Field, model_validator70from typing_extensions import Self7172from langchain_huggingface.data._profiles import _PROFILES73from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint74from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline7576_MODEL_PROFILES = cast("ModelProfileRegistry", _PROFILES)777879def _get_default_model_profile(model_name: str) -> ModelProfile:80 default = _MODEL_PROFILES.get(model_name) or {}81 return default.copy()828384@dataclass85class TGI_RESPONSE:86 """Response from the TextGenInference API."""8788 choices: list[Any]89 usage: dict909192@dataclass93class TGI_MESSAGE:94 """Message to send to the TextGenInference API."""9596 role: str97 content: str98 tool_calls: list[dict]99100101def _lc_tool_call_to_hf_tool_call(tool_call: ToolCall) -> dict:102 return {103 "type": "function",104 "id": tool_call["id"],105 "function": {106 "name": tool_call["name"],107 "arguments": json.dumps(tool_call["args"], ensure_ascii=False),108 },109 }110111112def _lc_invalid_tool_call_to_hf_tool_call(113 invalid_tool_call: InvalidToolCall,114) -> dict:115 return {116 "type": "function",117 "id": invalid_tool_call["id"],118 "function": {119 "name": invalid_tool_call["name"],120 "arguments": invalid_tool_call["args"],121 },122 }123124125def _convert_message_to_dict(message: BaseMessage) -> dict:126 """Convert a LangChain message to a dictionary.127128 Args:129 message: The LangChain message.130131 Returns:132 The dictionary.133134 """135 message_dict: dict[str, Any]136 if isinstance(message, ChatMessage):137 message_dict = {"role": message.role, "content": message.content}138 elif isinstance(message, HumanMessage):139 message_dict = {"role": "user", "content": message.content}140 elif isinstance(message, AIMessage):141 message_dict = {"role": "assistant", "content": message.content}142 if "function_call" in message.additional_kwargs:143 message_dict["function_call"] = message.additional_kwargs["function_call"]144 # If function call only, content is None not empty string145 if message_dict["content"] == "":146 message_dict["content"] = None147 if message.tool_calls or message.invalid_tool_calls:148 message_dict["tool_calls"] = [149 _lc_tool_call_to_hf_tool_call(tc) for tc in message.tool_calls150 ] + [151 _lc_invalid_tool_call_to_hf_tool_call(tc)152 for tc in message.invalid_tool_calls153 ]154 elif "tool_calls" in message.additional_kwargs:155 message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]156 # If tool calls only, content is None not empty string157 if "tool_calls" in message_dict and message_dict["content"] == "":158 message_dict["content"] = None159 else:160 pass161 elif isinstance(message, SystemMessage):162 message_dict = {"role": "system", "content": message.content}163 elif isinstance(message, FunctionMessage):164 message_dict = {165 "role": "function",166 "content": message.content,167 "name": message.name,168 }169 elif isinstance(message, ToolMessage):170 message_dict = {171 "role": "tool",172 "content": message.content,173 "tool_call_id": message.tool_call_id,174 }175 else:176 msg = f"Got unknown type {message}"177 raise TypeError(msg)178 if "name" in message.additional_kwargs:179 message_dict["name"] = message.additional_kwargs["name"]180 return message_dict181182183def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:184 """Convert a dictionary to a LangChain message.185186 Args:187 _dict: The dictionary.188189 Returns:190 The LangChain message.191192 """193 role = _dict.get("role")194 if role == "user":195 return HumanMessage(content=_dict.get("content", ""))196 if role == "assistant":197 content = _dict.get("content", "") or ""198 additional_kwargs: dict = {}199 if function_call := _dict.get("function_call"):200 additional_kwargs["function_call"] = dict(function_call)201 tool_calls = []202 invalid_tool_calls = []203 if raw_tool_calls := _dict.get("tool_calls"):204 additional_kwargs["tool_calls"] = raw_tool_calls205 for raw_tool_call in raw_tool_calls:206 try:207 tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))208 except Exception as e:209 invalid_tool_calls.append(210 dict(make_invalid_tool_call(raw_tool_call, str(e)))211 )212 return AIMessage(213 content=content,214 additional_kwargs=additional_kwargs,215 tool_calls=tool_calls,216 invalid_tool_calls=invalid_tool_calls,217 )218 if role == "system":219 return SystemMessage(content=_dict.get("content", ""))220 if role == "function":221 return FunctionMessage(222 content=_dict.get("content", ""), name=_dict.get("name", "")223 )224 if role == "tool":225 additional_kwargs = {}226 if "name" in _dict:227 additional_kwargs["name"] = _dict["name"]228 return ToolMessage(229 content=_dict.get("content", ""),230 tool_call_id=_dict.get("tool_call_id", ""),231 additional_kwargs=additional_kwargs,232 )233 return ChatMessage(content=_dict.get("content", ""), role=role or "")234235236def _is_huggingface_hub(llm: Any) -> bool:237 try:238 from langchain_community.llms.huggingface_hub import (239 HuggingFaceHub, # type: ignore[import-not-found]240 )241242 return isinstance(llm, HuggingFaceHub)243 except ImportError:244 # if no langchain community, it is not a HuggingFaceHub245 return False246247248def _convert_chunk_to_message_chunk(249 chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]250) -> BaseMessageChunk:251 choice = chunk["choices"][0]252 _dict = choice["delta"]253 role = cast(str, _dict.get("role"))254 content = cast(str, _dict.get("content") or "")255 additional_kwargs: dict = {}256 tool_call_chunks: list[ToolCallChunk] = []257 if _dict.get("function_call"):258 function_call = dict(_dict["function_call"])259 if "name" in function_call and function_call["name"] is None:260 function_call["name"] = ""261 additional_kwargs["function_call"] = function_call262 if raw_tool_calls := _dict.get("tool_calls"):263 additional_kwargs["tool_calls"] = raw_tool_calls264 for rtc in raw_tool_calls:265 with contextlib.suppress(KeyError):266 tool_call_chunks.append(267 create_tool_call_chunk(268 name=rtc["function"].get("name"),269 args=rtc["function"].get("arguments"),270 id=rtc.get("id"),271 index=rtc.get("index"),272 )273 )274 if role == "user" or default_class == HumanMessageChunk:275 return HumanMessageChunk(content=content)276 if role == "assistant" or default_class == AIMessageChunk:277 if usage := chunk.get("usage"):278 input_tokens = usage.get("prompt_tokens", 0)279 output_tokens = usage.get("completion_tokens", 0)280 usage_metadata = {281 "input_tokens": input_tokens,282 "output_tokens": output_tokens,283 "total_tokens": usage.get("total_tokens", input_tokens + output_tokens),284 }285 else:286 usage_metadata = None287 return AIMessageChunk(288 content=content,289 additional_kwargs=additional_kwargs,290 tool_call_chunks=tool_call_chunks,291 usage_metadata=usage_metadata, # type: ignore[arg-type]292 )293 if role == "system" or default_class == SystemMessageChunk:294 return SystemMessageChunk(content=content)295 if role == "function" or default_class == FunctionMessageChunk:296 return FunctionMessageChunk(content=content, name=_dict["name"])297 if role == "tool" or default_class == ToolMessageChunk:298 return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])299 if role or default_class == ChatMessageChunk:300 return ChatMessageChunk(content=content, role=role)301 return default_class(content=content) # type: ignore[call-arg]302303304def _is_huggingface_textgen_inference(llm: Any) -> bool:305 try:306 from langchain_community.llms.huggingface_text_gen_inference import (307 HuggingFaceTextGenInference, # type: ignore[import-not-found]308 )309310 return isinstance(llm, HuggingFaceTextGenInference)311 except ImportError:312 # if no langchain community, it is not a HuggingFaceTextGenInference313 return False314315316def _is_huggingface_endpoint(llm: Any) -> bool:317 return isinstance(llm, HuggingFaceEndpoint)318319320def _is_huggingface_pipeline(llm: Any) -> bool:321 return isinstance(llm, HuggingFacePipeline)322323324class ChatHuggingFace(BaseChatModel):325 r"""Hugging Face LLM's as ChatModels.326327 Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,328 `HuggingFaceHub`, and `HuggingFacePipeline` LLMs.329330 Upon instantiating this class, the model_id is resolved from the url331 provided to the LLM, and the appropriate tokenizer is loaded from332 the HuggingFace Hub.333334 Setup:335 Install `langchain-huggingface` and ensure your Hugging Face token336 is saved.337338 ```bash339 pip install langchain-huggingface340 ```341342 ```python343 from huggingface_hub import login344345 login() # You will be prompted for your HF key, which will then be saved locally346 ```347348 Key init args — completion params:349 llm:350 LLM to be used.351352 Key init args — client params:353 custom_get_token_ids:354 Optional encoder to use for counting tokens.355 metadata:356 Metadata to add to the run trace.357 tags:358 Tags to add to the run trace.359 verbose:360 Whether to print out response text.361362 See full list of supported init args and their descriptions in the params363 section.364365 Instantiate:366 ```python367 from langchain_huggingface import HuggingFaceEndpoint,368 ChatHuggingFace369370 model = HuggingFaceEndpoint(371 repo_id="microsoft/Phi-3-mini-4k-instruct",372 task="text-generation",373 max_new_tokens=512,374 do_sample=False,375 repetition_penalty=1.03,376 )377378 chat = ChatHuggingFace(llm=model, verbose=True)379 ```380381 Invoke:382 ```python383 messages = [384 ("system", "You are a helpful translator. Translate the user385 sentence to French."),386 ("human", "I love programming."),387 ]388389 chat(...).invoke(messages)390 ```391392 ```python393 AIMessage(content='Je ai une passion pour le programme.\n\nIn394 French, we use "ai" for masculine subjects and "a" for feminine395 subjects. Since "programming" is gender-neutral in English, we396 will go with the masculine "programme".\n\nConfirmation: "J\'aime397 le programme." is more commonly used. The sentence above is398 technically accurate, but less commonly used in spoken French as399 "ai" is used less frequently in everyday speech.',400 response_metadata={'token_usage': ChatCompletionOutputUsage401 (completion_tokens=100, prompt_tokens=55, total_tokens=155),402 'model': '', 'finish_reason': 'length'},403 id='run-874c24b7-0272-4c99-b259-5d6d7facbc56-0')404 ```405406 Stream:407 ```python408 for chunk in chat.stream(messages):409 print(chunk)410 ```411412 ```python413 content='Je ai une passion pour le programme.\n\nIn French, we use414 "ai" for masculine subjects and "a" for feminine subjects.415 Since "programming" is gender-neutral in English,416 we will go with the masculine "programme".\n\nConfirmation:417 "J\'aime le programme." is more commonly used. The sentence418 above is technically accurate, but less commonly used in spoken419 French as "ai" is used less frequently in everyday speech.'420 response_metadata={'token_usage': ChatCompletionOutputUsage421 (completion_tokens=100, prompt_tokens=55, total_tokens=155),422 'model': '', 'finish_reason': 'length'}423 id='run-7d7b1967-9612-4f9a-911a-b2b5ca85046a-0'424 ```425426 Async:427 ```python428 await chat.ainvoke(messages)429 ```430431 ```python432 AIMessage(content='Je déaime le programming.\n\nLittérale : Je433 (j\'aime) déaime (le) programming.\n\nNote: "Programming" in434 French is "programmation". But here, I used "programming" instead435 of "programmation" because the user said "I love programming"436 instead of "I love programming (in French)", which would be437 "J\'aime la programmation". By translating the sentence438 literally, I preserved the original meaning of the user\'s439 sentence.', id='run-fd850318-e299-4735-b4c6-3496dc930b1d-0')440 ```441442 Tool calling:443 ```python444 from pydantic import BaseModel, Field445446 class GetWeather(BaseModel):447 '''Get the current weather in a given location'''448449 location: str = Field(..., description="The city and state,450 e.g. San Francisco, CA")451452 class GetPopulation(BaseModel):453 '''Get the current population in a given location'''454455 location: str = Field(..., description="The city and state,456 e.g. San Francisco, CA")457458 chat_with_tools = chat.bind_tools([GetWeather, GetPopulation])459 ai_msg = chat_with_tools.invoke("Which city is hotter today and460 which is bigger: LA or NY?")461 ai_msg.tool_calls462 ```463464 ```python465 [466 {467 "name": "GetPopulation",468 "args": {"location": "Los Angeles, CA"},469 "id": "0",470 }471 ]472 ```473474 Response metadata475 ```python476 ai_msg = chat.invoke(messages)477 ai_msg.response_metadata478 ```479480 ```python481 {482 "token_usage": ChatCompletionOutputUsage(483 completion_tokens=100, prompt_tokens=8, total_tokens=108484 ),485 "model": "",486 "finish_reason": "length",487 }488 ```489 """ # noqa: E501490491 llm: Any492 """LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint,493 HuggingFaceHub, or HuggingFacePipeline."""494 tokenizer: Any = None495 """Tokenizer for the model. Only used for HuggingFacePipeline."""496 model_id: str | None = None497 """Model ID for the model. Only used for HuggingFaceEndpoint."""498 temperature: float | None = None499 """What sampling temperature to use."""500 stop: str | list[str] | None = Field(default=None, alias="stop_sequences")501 """Default stop sequences."""502 presence_penalty: float | None = None503 """Penalizes repeated tokens."""504 frequency_penalty: float | None = None505 """Penalizes repeated tokens according to frequency."""506 seed: int | None = None507 """Seed for generation"""508 logprobs: bool | None = None509 """Whether to return logprobs."""510 top_logprobs: int | None = None511 """Number of most likely tokens to return at each token position, each with512 an associated log probability. `logprobs` must be set to true513 if this parameter is used."""514 logit_bias: dict[int, int] | None = None515 """Modify the likelihood of specified tokens appearing in the completion."""516 streaming: bool = False517 """Whether to stream the results or not."""518 stream_usage: bool | None = None519 """Whether to include usage metadata in streaming output. If True, an additional520 message chunk will be generated during the stream including usage metadata."""521 n: int | None = None522 """Number of chat completions to generate for each prompt."""523 top_p: float | None = None524 """Total probability mass of tokens to consider at each step."""525 max_tokens: int | None = None526 """Maximum number of tokens to generate."""527 model_kwargs: dict[str, Any] = Field(default_factory=dict)528 """Holds any model parameters valid for `create` call not explicitly specified."""529530 def __init__(self, **kwargs: Any):531 super().__init__(**kwargs)532533 # Inherit properties from the LLM if they weren't explicitly set534 self._inherit_llm_properties()535536 self._resolve_model_id()537538 def _inherit_llm_properties(self) -> None:539 """Inherit properties from the wrapped LLM instance if not explicitly set."""540 if not hasattr(self, "llm") or self.llm is None:541 return542543 # Map of ChatHuggingFace properties to LLM properties544 property_mappings = {545 "temperature": "temperature",546 "max_tokens": "max_new_tokens", # Different naming convention547 "top_p": "top_p",548 "seed": "seed",549 "streaming": "streaming",550 "stop": "stop_sequences",551 }552553 # Inherit properties from LLM and not explicitly set here554 for chat_prop, llm_prop in property_mappings.items():555 if hasattr(self.llm, llm_prop):556 llm_value = getattr(self.llm, llm_prop)557 chat_value = getattr(self, chat_prop, None)558 if not chat_value and llm_value:559 setattr(self, chat_prop, llm_value)560561 # Handle special cases for HuggingFaceEndpoint562 if _is_huggingface_endpoint(self.llm):563 # Inherit additional HuggingFaceEndpoint specific properties564 endpoint_mappings = {565 "frequency_penalty": "repetition_penalty",566 }567568 for chat_prop, llm_prop in endpoint_mappings.items():569 if hasattr(self.llm, llm_prop):570 llm_value = getattr(self.llm, llm_prop)571 chat_value = getattr(self, chat_prop, None)572 if chat_value is None and llm_value is not None:573 setattr(self, chat_prop, llm_value)574575 # Inherit model_kwargs if not explicitly set576 if (577 not self.model_kwargs578 and hasattr(self.llm, "model_kwargs")579 and isinstance(self.llm.model_kwargs, dict)580 ):581 self.model_kwargs = self.llm.model_kwargs.copy()582583 @model_validator(mode="after")584 def validate_llm(self) -> Self:585 if (586 not _is_huggingface_hub(self.llm)587 and not _is_huggingface_textgen_inference(self.llm)588 and not _is_huggingface_endpoint(self.llm)589 and not _is_huggingface_pipeline(self.llm)590 ):591 msg = (592 "Expected llm to be one of HuggingFaceTextGenInference, "593 "HuggingFaceEndpoint, HuggingFaceHub, HuggingFacePipeline "594 f"received {type(self.llm)}"595 )596 raise TypeError(msg)597 return self598599 def _resolve_model_profile(self) -> ModelProfile | None:600 if self.model_id:601 return _get_default_model_profile(self.model_id) or None602 return None603604 @classmethod605 def from_model_id(606 cls,607 model_id: str,608 task: str | None = None,609 backend: Literal["pipeline", "endpoint", "text-gen"] = "pipeline",610 **kwargs: Any,611 ) -> ChatHuggingFace:612 """Construct a ChatHuggingFace model from a model_id.613614 Args:615 model_id: The model ID of the Hugging Face model.616 task: The task to perform (e.g., "text-generation").617 backend: The backend to use. One of "pipeline", "endpoint", "text-gen".618 **kwargs: Additional arguments to pass to the backend or ChatHuggingFace.619 """620 llm: (621 Any # HuggingFacePipeline, HuggingFaceEndpoint, HuggingFaceTextGenInference622 )623 if backend == "pipeline":624 from langchain_huggingface.llms.huggingface_pipeline import (625 HuggingFacePipeline,626 )627628 task = task if task is not None else "text-generation"629630 # Separate pipeline-specific kwargs from ChatHuggingFace kwargs631 # Parameters that should go to HuggingFacePipeline.from_model_id632 pipeline_specific_kwargs = {}633634 # Extract pipeline-specific parameters635 pipeline_keys = [636 "backend",637 "device",638 "device_map",639 "model_kwargs",640 "pipeline_kwargs",641 "batch_size",642 ]643 for key in pipeline_keys:644 if key in kwargs:645 pipeline_specific_kwargs[key] = kwargs.pop(key)646647 # Remaining kwargs (temperature, max_tokens, etc.) should go to648 # pipeline_kwargs for generation parameters, which ChatHuggingFace649 # will inherit from the LLM650 if "pipeline_kwargs" not in pipeline_specific_kwargs:651 pipeline_specific_kwargs["pipeline_kwargs"] = {}652653 # Add generation parameters to pipeline_kwargs654 # Map max_tokens to max_new_tokens for HuggingFace pipeline655 generation_params = {}656 for k, v in list(kwargs.items()):657 if k == "max_tokens":658 generation_params["max_new_tokens"] = v659 kwargs.pop(k)660 elif k in (661 "temperature",662 "max_new_tokens",663 "top_p",664 "top_k",665 "repetition_penalty",666 "do_sample",667 ):668 generation_params[k] = v669 kwargs.pop(k)670671 pipeline_specific_kwargs["pipeline_kwargs"].update(generation_params)672673 # Create the HuggingFacePipeline674 llm = HuggingFacePipeline.from_model_id(675 model_id=model_id, task=task, **pipeline_specific_kwargs676 )677 elif backend == "endpoint":678 from langchain_huggingface.llms.huggingface_endpoint import (679 HuggingFaceEndpoint,680 )681682 llm = HuggingFaceEndpoint(repo_id=model_id, task=task, **kwargs)683 elif backend == "text-gen":684 from langchain_community.llms.huggingface_text_gen_inference import ( # type: ignore[import-not-found]685 HuggingFaceTextGenInference,686 )687688 llm = HuggingFaceTextGenInference(inference_server_url=model_id, **kwargs)689 else:690 msg = f"Unknown backend: {backend}"691 raise ValueError(msg)692693 return cls(llm=llm, **kwargs)694695 def _create_chat_result(self, response: dict) -> ChatResult:696 generations = []697 token_usage = response.get("usage", {})698 for res in response["choices"]:699 message = _convert_dict_to_message(res["message"])700 if token_usage and isinstance(message, AIMessage):701 message.usage_metadata = {702 "input_tokens": token_usage.get("prompt_tokens", 0),703 "output_tokens": token_usage.get("completion_tokens", 0),704 "total_tokens": token_usage.get("total_tokens", 0),705 }706 generation_info = {"finish_reason": res.get("finish_reason")}707 if "logprobs" in res:708 generation_info["logprobs"] = res["logprobs"]709 gen = ChatGeneration(710 message=message,711 generation_info=generation_info,712 )713 generations.append(gen)714 llm_output = {715 "token_usage": token_usage,716 "model_name": self.model_id,717 "system_fingerprint": response.get("system_fingerprint", ""),718 }719 return ChatResult(generations=generations, llm_output=llm_output)720721 def _generate(722 self,723 messages: list[BaseMessage],724 stop: list[str] | None = None,725 run_manager: CallbackManagerForLLMRun | None = None,726 stream: bool | None = None, # noqa: FBT001727 **kwargs: Any,728 ) -> ChatResult:729 should_stream = stream if stream is not None else self.streaming730731 if _is_huggingface_textgen_inference(self.llm):732 message_dicts, params = self._create_message_dicts(messages, stop)733 answer = self.llm.client.chat(messages=message_dicts, **kwargs)734 return self._create_chat_result(answer)735 if _is_huggingface_endpoint(self.llm):736 if should_stream:737 stream_iter = self._stream(738 messages, stop=stop, run_manager=run_manager, **kwargs739 )740 return generate_from_stream(stream_iter)741 message_dicts, params = self._create_message_dicts(messages, stop)742 params = {743 "stop": stop,744 **params,745 **({"stream": stream} if stream is not None else {}),746 **kwargs,747 }748 answer = self.llm.client.chat_completion(messages=message_dicts, **params)749 return self._create_chat_result(answer)750 llm_input = self._to_chat_prompt(messages)751752 if should_stream:753 stream_iter = self.llm._stream(754 llm_input, stop=stop, run_manager=run_manager, **kwargs755 )756 return generate_from_stream(stream_iter)757 llm_result = self.llm._generate(758 prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs759 )760 return self._to_chat_result(llm_result)761762 async def _agenerate(763 self,764 messages: list[BaseMessage],765 stop: list[str] | None = None,766 run_manager: AsyncCallbackManagerForLLMRun | None = None,767 stream: bool | None = None, # noqa: FBT001768 **kwargs: Any,769 ) -> ChatResult:770 if _is_huggingface_textgen_inference(self.llm):771 message_dicts, params = self._create_message_dicts(messages, stop)772 answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs)773 return self._create_chat_result(answer)774 if _is_huggingface_endpoint(self.llm):775 should_stream = stream if stream is not None else self.streaming776 if should_stream:777 stream_iter = self._astream(778 messages, stop=stop, run_manager=run_manager, **kwargs779 )780 return await agenerate_from_stream(stream_iter)781 message_dicts, params = self._create_message_dicts(messages, stop)782 params = {783 **params,784 **({"stream": stream} if stream is not None else {}),785 **kwargs,786 }787788 answer = await self.llm.async_client.chat_completion(789 messages=message_dicts, **params790 )791 return self._create_chat_result(answer)792 if _is_huggingface_pipeline(self.llm):793 msg = "async generation is not supported with HuggingFacePipeline"794 raise NotImplementedError(msg)795 llm_input = self._to_chat_prompt(messages)796 llm_result = await self.llm._agenerate(797 prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs798 )799 return self._to_chat_result(llm_result)800801 def _should_stream_usage(802 self, *, stream_usage: bool | None = None, **kwargs: Any803 ) -> bool | None:804 """Determine whether to include usage metadata in streaming output.805806 For backwards compatibility, we check for `stream_options` passed807 explicitly to kwargs or in the model_kwargs and override self.stream_usage.808 """809 stream_usage_sources = [ # order of precedence810 stream_usage,811 kwargs.get("stream_options", {}).get("include_usage"),812 self.model_kwargs.get("stream_options", {}).get("include_usage"),813 self.stream_usage,814 ]815 for source in stream_usage_sources:816 if isinstance(source, bool):817 return source818 return self.stream_usage819820 def _stream(821 self,822 messages: list[BaseMessage],823 stop: list[str] | None = None,824 run_manager: CallbackManagerForLLMRun | None = None,825 *,826 stream_usage: bool | None = None,827 **kwargs: Any,828 ) -> Iterator[ChatGenerationChunk]:829 if _is_huggingface_endpoint(self.llm):830 stream_usage = self._should_stream_usage(831 stream_usage=stream_usage, **kwargs832 )833 if stream_usage:834 kwargs["stream_options"] = {"include_usage": stream_usage}835 message_dicts, params = self._create_message_dicts(messages, stop)836 params = {**params, **kwargs, "stream": True}837838 default_chunk_class: type[BaseMessageChunk] = AIMessageChunk839 for chunk in self.llm.client.chat_completion(840 messages=message_dicts, **params841 ):842 if len(chunk["choices"]) == 0:843 if usage := chunk.get("usage"):844 usage_msg = AIMessageChunk(845 content="",846 additional_kwargs={},847 response_metadata={},848 usage_metadata={849 "input_tokens": usage.get("prompt_tokens", 0),850 "output_tokens": usage.get("completion_tokens", 0),851 "total_tokens": usage.get("total_tokens", 0),852 },853 )854 yield ChatGenerationChunk(message=usage_msg)855 continue856857 choice = chunk["choices"][0]858 message_chunk = _convert_chunk_to_message_chunk(859 chunk, default_chunk_class860 )861 generation_info = {}862 if finish_reason := choice.get("finish_reason"):863 generation_info["finish_reason"] = finish_reason864 generation_info["model_name"] = self.model_id865 logprobs = choice.get("logprobs")866 if logprobs:867 generation_info["logprobs"] = logprobs868 default_chunk_class = message_chunk.__class__869 generation_chunk = ChatGenerationChunk(870 message=message_chunk, generation_info=generation_info or None871 )872 if run_manager:873 run_manager.on_llm_new_token(874 generation_chunk.text, chunk=generation_chunk, logprobs=logprobs875 )876 yield generation_chunk877 else:878 llm_input = self._to_chat_prompt(messages)879 stream_iter = self.llm._stream(880 llm_input, stop=stop, run_manager=run_manager, **kwargs881 )882 for chunk in stream_iter: # chunk is a GenerationChunk883 chat_chunk = ChatGenerationChunk(884 message=AIMessageChunk(content=chunk.text),885 generation_info=chunk.generation_info,886 )887 yield chat_chunk888889 async def _astream(890 self,891 messages: list[BaseMessage],892 stop: list[str] | None = None,893 run_manager: AsyncCallbackManagerForLLMRun | None = None,894 *,895 stream_usage: bool | None = None,896 **kwargs: Any,897 ) -> AsyncIterator[ChatGenerationChunk]:898 stream_usage = self._should_stream_usage(stream_usage=stream_usage, **kwargs)899 if stream_usage:900 kwargs["stream_options"] = {"include_usage": stream_usage}901 message_dicts, params = self._create_message_dicts(messages, stop)902 params = {**params, **kwargs, "stream": True}903904 default_chunk_class: type[BaseMessageChunk] = AIMessageChunk905906 async for chunk in await self.llm.async_client.chat_completion(907 messages=message_dicts, **params908 ):909 if len(chunk["choices"]) == 0:910 if usage := chunk.get("usage"):911 usage_msg = AIMessageChunk(912 content="",913 additional_kwargs={},914 response_metadata={},915 usage_metadata={916 "input_tokens": usage.get("prompt_tokens", 0),917 "output_tokens": usage.get("completion_tokens", 0),918 "total_tokens": usage.get("total_tokens", 0),919 },920 )921 yield ChatGenerationChunk(message=usage_msg)922 continue923924 choice = chunk["choices"][0]925 message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)926 generation_info = {}927 if finish_reason := choice.get("finish_reason"):928 generation_info["finish_reason"] = finish_reason929 generation_info["model_name"] = self.model_id930 logprobs = choice.get("logprobs")931 if logprobs:932 generation_info["logprobs"] = logprobs933 default_chunk_class = message_chunk.__class__934 generation_chunk = ChatGenerationChunk(935 message=message_chunk, generation_info=generation_info or None936 )937 if run_manager:938 await run_manager.on_llm_new_token(939 token=generation_chunk.text,940 chunk=generation_chunk,941 logprobs=logprobs,942 )943 yield generation_chunk944945 def _to_chat_prompt(946 self,947 messages: list[BaseMessage],948 ) -> str:949 """Convert a list of messages into a prompt format expected by wrapped LLM."""950 if not messages:951 msg = "At least one HumanMessage must be provided!"952 raise ValueError(msg)953954 if not isinstance(messages[-1], HumanMessage):955 msg = "Last message must be a HumanMessage!"956 raise ValueError(msg)957958 messages_dicts = [self._to_chatml_format(m) for m in messages]959960 return self.tokenizer.apply_chat_template(961 messages_dicts, tokenize=False, add_generation_prompt=True962 )963964 def _to_chatml_format(self, message: BaseMessage) -> dict:965 """Convert LangChain message to ChatML format."""966 if isinstance(message, SystemMessage):967 role = "system"968 elif isinstance(message, AIMessage):969 role = "assistant"970 elif isinstance(message, HumanMessage):971 role = "user"972 else:973 msg = f"Unknown message type: {type(message)}"974 raise ValueError(msg)975976 return {"role": role, "content": message.content}977978 @staticmethod979 def _to_chat_result(llm_result: LLMResult) -> ChatResult:980 chat_generations = []981982 for g in llm_result.generations[0]:983 chat_generation = ChatGeneration(984 message=AIMessage(content=g.text), generation_info=g.generation_info985 )986 chat_generations.append(chat_generation)987988 return ChatResult(989 generations=chat_generations, llm_output=llm_result.llm_output990 )991992 def _resolve_model_id(self) -> None:993 """Resolve the model_id from the LLM's inference_server_url."""994 from huggingface_hub import list_inference_endpoints # type: ignore[import]995996 if _is_huggingface_hub(self.llm) or (997 hasattr(self.llm, "repo_id") and self.llm.repo_id998 ):999 self.model_id = self.llm.repo_id1000 return1001 if _is_huggingface_textgen_inference(self.llm):1002 endpoint_url: str | None = self.llm.inference_server_url1003 if _is_huggingface_pipeline(self.llm):1004 from transformers import AutoTokenizer # type: ignore[import]10051006 self.model_id = self.model_id or self.llm.model_id1007 self.tokenizer = (1008 AutoTokenizer.from_pretrained(self.model_id)1009 if self.tokenizer is None1010 else self.tokenizer1011 )1012 return1013 if _is_huggingface_endpoint(self.llm):1014 self.model_id = self.llm.repo_id or self.llm.model1015 return1016 endpoint_url = self.llm.endpoint_url1017 available_endpoints = list_inference_endpoints("*")1018 for endpoint in available_endpoints:1019 if endpoint.url == endpoint_url:1020 self.model_id = endpoint.repository10211022 if not self.model_id:1023 msg = (1024 "Failed to resolve model_id:"1025 f"Could not find model id for inference server: {endpoint_url}"1026 "Make sure that your Hugging Face token has access to the endpoint."1027 )1028 raise ValueError(msg)10291030 def bind_tools(1031 self,1032 tools: Sequence[dict[str, Any] | type | Callable | BaseTool],1033 *,1034 tool_choice: dict | str | bool | None = None,1035 **kwargs: Any,1036 ) -> Runnable[LanguageModelInput, AIMessage]:1037 """Bind tool-like objects to this chat model.10381039 Assumes model is compatible with OpenAI tool-calling API.10401041 Args:1042 tools: A list of tool definitions to bind to this chat model.10431044 Supports any tool definition handled by [`convert_to_openai_tool`][langchain_core.utils.function_calling.convert_to_openai_tool].1045 tool_choice: Which tool to require the model to call.1046 Must be the name of the single provided function or1047 `'auto'` to automatically determine which function to call1048 (if any), or a dict of the form:1049 {"type": "function", "function": {"name": <<tool_name>>}}.1050 **kwargs: Any additional parameters to pass to the1051 `langchain.runnable.Runnable` constructor.1052 """ # noqa: E5011053 formatted_tools = [convert_to_openai_tool(tool) for tool in tools]1054 if tool_choice is not None and tool_choice:1055 if len(formatted_tools) != 1:1056 msg = (1057 "When specifying `tool_choice`, you must provide exactly one "1058 f"tool. Received {len(formatted_tools)} tools."1059 )1060 raise ValueError(msg)1061 if isinstance(tool_choice, str):1062 if tool_choice not in ("auto", "none", "required"):1063 tool_choice = {1064 "type": "function",1065 "function": {"name": tool_choice},1066 }1067 elif isinstance(tool_choice, bool):1068 tool_choice = formatted_tools[0]1069 elif isinstance(tool_choice, dict):1070 if (1071 formatted_tools[0]["function"]["name"]1072 != tool_choice["function"]["name"]1073 ):1074 msg = (1075 f"Tool choice {tool_choice} was specified, but the only "1076 f"provided tool was {formatted_tools[0]['function']['name']}."1077 )1078 raise ValueError(msg)1079 else:1080 msg = (1081 f"Unrecognized tool_choice type. Expected str, bool or dict. "1082 f"Received: {tool_choice}"1083 )1084 raise ValueError(msg)1085 kwargs["tool_choice"] = tool_choice1086 return super().bind(tools=formatted_tools, **kwargs)10871088 def with_structured_output(1089 self,1090 schema: dict | type[BaseModel] | None = None,1091 *,1092 method: Literal[1093 "function_calling", "json_mode", "json_schema"1094 ] = "function_calling",1095 include_raw: bool = False,1096 **kwargs: Any,1097 ) -> Runnable[LanguageModelInput, dict | BaseModel]:1098 """Model wrapper that returns outputs formatted to match the given schema.10991100 Args:1101 schema: The output schema. Can be passed in as:11021103 - An OpenAI function/tool schema,1104 - A JSON Schema,1105 - A `TypedDict` class11061107 Pydantic class is currently supported.11081109 method: The method for steering model generation, one of:11101111 - `'function_calling'`: uses tool-calling features.1112 - `'json_schema'`: uses dedicated structured output features.1113 - `'json_mode'`: uses JSON mode.11141115 include_raw:1116 If `False` then only the parsed structured output is returned.11171118 If an error occurs during model output parsing it will be raised.11191120 If `True` then both the raw model response (a `BaseMessage`) and the1121 parsed model response will be returned.11221123 If an error occurs during output parsing it will be caught and returned1124 as well.11251126 The final output is always a `dict` with keys `'raw'`, `'parsed'`, and1127 `'parsing_error'`.11281129 kwargs:1130 Additional parameters to pass to the underlying LLM's1131 `langchain_core.language_models.chat.BaseChatModel.bind`1132 method, such as `response_format` or `ls_structured_output_format`.11331134 Returns:1135 A `Runnable` that takes same inputs as a1136 `langchain_core.language_models.chat.BaseChatModel`. If `include_raw` is1137 `False` and `schema` is a Pydantic class, `Runnable` outputs an instance1138 of `schema` (i.e., a Pydantic object). Otherwise, if `include_raw` is1139 `False` then `Runnable` outputs a `dict`.11401141 If `include_raw` is `True`, then `Runnable` outputs a `dict` with keys:11421143 - `'raw'`: `BaseMessage`1144 - `'parsed'`: `None` if there was a parsing error, otherwise the type1145 depends on the `schema` as described above.1146 - `'parsing_error'`: `BaseException | None`1147 """1148 _ = kwargs.pop("strict", None)1149 if kwargs:1150 msg = f"Received unsupported arguments {kwargs}"1151 raise ValueError(msg)1152 is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)1153 if method == "function_calling":1154 if schema is None:1155 msg = (1156 "schema must be specified when method is 'function_calling'. "1157 "Received None."1158 )1159 raise ValueError(msg)1160 formatted_tool = convert_to_openai_tool(schema)1161 tool_name = formatted_tool["function"]["name"]1162 llm = self.bind_tools(1163 [schema],1164 tool_choice=tool_name,1165 ls_structured_output_format={1166 "kwargs": {"method": "function_calling"},1167 "schema": formatted_tool,1168 },1169 )1170 if is_pydantic_schema:1171 msg = "Pydantic schema is not supported for function calling"1172 raise NotImplementedError(msg)1173 output_parser: JsonOutputKeyToolsParser | JsonOutputParser = (1174 JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)1175 )1176 elif method == "json_schema":1177 if schema is None:1178 msg = (1179 "schema must be specified when method is 'json_schema'. "1180 "Received None."1181 )1182 raise ValueError(msg)1183 formatted_schema = convert_to_json_schema(schema)1184 llm = self.bind(1185 response_format={"type": "json_object", "schema": formatted_schema},1186 ls_structured_output_format={1187 "kwargs": {"method": "json_schema"},1188 "schema": schema,1189 },1190 )1191 output_parser = JsonOutputParser() # type: ignore[arg-type]1192 elif method == "json_mode":1193 llm = self.bind(1194 response_format={"type": "json_object"},1195 ls_structured_output_format={1196 "kwargs": {"method": "json_mode"},1197 "schema": schema,1198 },1199 )1200 output_parser = JsonOutputParser() # type: ignore[arg-type]1201 else:1202 msg = (1203 f"Unrecognized method argument. Expected one of 'function_calling' or "1204 f"'json_mode'. Received: '{method}'"1205 )1206 raise ValueError(msg)12071208 if include_raw:1209 parser_assign = RunnablePassthrough.assign(1210 parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None1211 )1212 parser_none = RunnablePassthrough.assign(parsed=lambda _: None)1213 parser_with_fallback = parser_assign.with_fallbacks(1214 [parser_none], exception_key="parsing_error"1215 )1216 return RunnableMap(raw=llm) | parser_with_fallback1217 return llm | output_parser12181219 def _create_message_dicts(1220 self, messages: list[BaseMessage], stop: list[str] | None1221 ) -> tuple[list[dict[str, Any]], dict[str, Any]]:1222 params = self._default_params1223 if stop is not None:1224 params["stop"] = stop1225 message_dicts = [_convert_message_to_dict(m) for m in messages]1226 return message_dicts, params12271228 @property1229 def _default_params(self) -> dict[str, Any]:1230 """Get default parameters for calling Hugging Face Inference Providers API."""1231 params = {1232 "model": self.model_id,1233 "stream": self.streaming,1234 "n": self.n,1235 "temperature": self.temperature,1236 "stop": self.stop,1237 **(self.model_kwargs if self.model_kwargs else {}),1238 }1239 if self.max_tokens is not None:1240 params["max_tokens"] = self.max_tokens1241 return params12421243 @property1244 def _llm_type(self) -> str:1245 return "huggingface-chat-wrapper"
Same data, no extra tab — call code_get_file + code_get_findings over MCP from Claude/Cursor/Copilot.