libs/core/langchain_core/_security/_transport.py PYTHON 253 lines View on github.com → Search inside
"""SSRF-safe httpx transports with DNS resolution and IP pinning.

`SSRFSafeTransport` (async) and `SSRFSafeSyncTransport` (sync) validate each
outgoing request against an `SSRFPolicy`, resolve the hostname themselves,
reject the request if any resolved address is blocked, and then connect to a
vetted IP instead of letting the connection layer resolve the name again.
`ssrf_safe_async_client` / `ssrf_safe_client` build httpx clients wired to
these transports.
"""

import asyncio
import socket
from collections.abc import Sequence
from typing import Any

import httpx

from langchain_core._security._exceptions import SSRFBlockedError
from langchain_core._security._policy import (
    SSRFPolicy,
    _effective_allowed_hosts,
    validate_resolved_ip,
    validate_url_sync,
)

# Keys that `httpx.HTTPTransport` / `httpx.AsyncHTTPTransport` accept; the
# factories forward these to the inner transport instead of the client.
_TRANSPORT_KWARGS = frozenset(
    {
        "verify",
        "cert",
        "trust_env",
        "http1",
        "http2",
        "limits",
        "retries",
    }
)

# Client kwargs that would send requests through a transport other than the
# SSRF-safe one. httpx routes URLs matching `mounts` (and the mount it creates
# for `proxy` / legacy `proxies`) before falling back to `transport`, so
# accepting them would silently disable the protection. `transport` itself
# would collide with the one the factories install.
_BYPASS_CLIENT_KWARGS = frozenset({"transport", "mounts", "proxy", "proxies"})


# ---------------------------------------------------------------------- #
# Shared helpers
# ---------------------------------------------------------------------- #


def _is_allowed_host(hostname: str, policy: SSRFPolicy) -> bool:
    """Return True if *hostname* is in the policy's allowed hosts (case-insensitive)."""
    allowed = {h.lower() for h in _effective_allowed_hosts(policy)}
    return hostname.lower() in allowed


def _lookup_target(request: httpx.Request) -> tuple[str, int]:
    """Return the ASCII hostname and port that must be resolved for *request*."""
    # `raw_host` is the IDNA-encoded ASCII form httpx itself connects to.
    # `url.host` is the decoded Unicode form: it cannot be ASCII-encoded for
    # internationalised domain names, and Python's own IDNA-2003 codec may map
    # it to a different name than httpx's IDNA-2008 encoding.
    host = request.url.raw_host.decode("ascii")
    default_port = 443 if request.url.scheme.lower() == "https" else 80
    return host, request.url.port or default_port


def _select_pinned_ip(
    addrinfo: Sequence[tuple[Any, ...]],
    policy: SSRFPolicy,
) -> str:
    """Validate every resolved address against *policy* and return the first.

    Args:
        addrinfo: Result of `socket.getaddrinfo` for the request's host.
        policy: Policy passed to `validate_resolved_ip` for each address.

    Returns:
        The IP string of the first `addrinfo` entry, used for the connection.

    Raises:
        SSRFBlockedError: If `addrinfo` is empty. Whatever
            `validate_resolved_ip` raises for a blocked address also
            propagates (presumably `SSRFBlockedError` — confirm in _policy).
    """
    if not addrinfo:
        raise SSRFBlockedError("DNS resolution returned no results")
    # Validate ALL addresses: a name that resolves to a mix of public and
    # internal IPs is rejected outright, whichever one would be picked.
    for *_, sockaddr in addrinfo:
        validate_resolved_ip(str(sockaddr[0]), policy)
    return str(addrinfo[0][4][0])


def _pin_request(request: httpx.Request, pinned_ip: str) -> httpx.Request:
    """Return a copy of *request* that connects to *pinned_ip*.

    The original `Host` header is preserved. For HTTPS, the original hostname
    is passed as the `sni_hostname` extension, so TLS SNI and certificate
    validation use the hostname rather than the IP.
    """
    headers = httpx.Headers(request.headers)
    if "host" not in headers:
        # A request built with `stream=` has no auto-populated Host header.
        # Add the one httpx would generate for the ORIGINAL URL, so the
        # server never sees the pinned IP (or no Host at all).
        headers["Host"] = request.url.netloc.decode("ascii")

    extensions = dict(request.extensions)
    if request.url.scheme.lower() == "https":
        extensions["sni_hostname"] = request.url.raw_host.decode("ascii")

    return httpx.Request(
        method=request.method,
        url=request.url.copy_with(host=pinned_ip),
        headers=headers,
        # Reuse the body stream rather than `request.content`: `.content`
        # raises `httpx.RequestNotRead` for streaming / multipart bodies
        # that have not been buffered yet.
        stream=request.stream,
        extensions=extensions,
    )


def _split_factory_kwargs(
    kwargs: dict[str, object],
) -> tuple[dict[str, object], dict[str, object]]:
    """Split factory kwargs into `(transport_kwargs, client_kwargs)`.

    Keys in `_TRANSPORT_KWARGS` go to the inner transport; everything else
    goes to the client. The client defaults `follow_redirects=True` and
    `max_redirects=10` apply only when the caller did not set them.

    Raises:
        TypeError: If a kwarg would route requests around the SSRF transport
            (`transport`, `mounts`, `proxy`, `proxies`).
    """
    bypass = sorted(_BYPASS_CLIENT_KWARGS.intersection(kwargs))
    if bypass:
        msg = (
            f"Unsupported keyword argument(s) {', '.join(bypass)}: they would "
            "route requests around the SSRF-safe transport."
        )
        raise TypeError(msg)

    transport_kwargs = {k: v for k, v in kwargs.items() if k in _TRANSPORT_KWARGS}
    client_kwargs = {k: v for k, v in kwargs.items() if k not in _TRANSPORT_KWARGS}
    client_kwargs.setdefault("follow_redirects", True)
    client_kwargs.setdefault("max_redirects", 10)
    return transport_kwargs, client_kwargs


# ---------------------------------------------------------------------- #
# Async transport
# ---------------------------------------------------------------------- #


class SSRFSafeTransport(httpx.AsyncBaseTransport):
    """httpx async transport that validates DNS results against an SSRF policy.

    For every outgoing request the transport:
    1. Checks the URL scheme against `policy.allowed_schemes`.
    2. Validates the hostname against blocked patterns.
    3. Resolves DNS and validates **all** returned IPs.
    4. Rewrites the request to connect to the first valid IP while
       preserving the original `Host` header and TLS SNI hostname.

    Hosts listed in the policy's allowed hosts skip steps 3-4 and are sent
    unmodified.

    Redirects are re-validated on each hop because `follow_redirects`
    is set on the *client*, causing `handle_async_request` to be called
    again for each redirect target.
    """

    def __init__(
        self,
        policy: SSRFPolicy = SSRFPolicy(),
        **transport_kwargs: object,
    ) -> None:
        """Create the transport.

        Args:
            policy: SSRF policy every request is validated against.
            **transport_kwargs: Forwarded to the wrapped
                `httpx.AsyncHTTPTransport`.
        """
        self._policy = policy
        self._inner = httpx.AsyncHTTPTransport(**transport_kwargs)  # type: ignore[arg-type]

    async def handle_async_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        """Validate *request*, then send it pinned to a vetted IP address.

        Raises:
            SSRFBlockedError: If the URL fails `validate_url_sync`, or DNS
                resolution fails or returns no results. Errors raised by
                `validate_resolved_ip` for a blocked IP also propagate.
        """
        # 1-3. Scheme, hostname, and pattern checks (reuse sync validator).
        validate_url_sync(str(request.url), self._policy)

        # Allowed-hosts bypass - skip DNS/IP validation entirely.
        if _is_allowed_host(request.url.host or "", self._policy):
            return await self._inner.handle_async_request(request)

        # 4. DNS resolution; getaddrinfo blocks, so run it off the event loop.
        host, port = _lookup_target(request)
        try:
            addrinfo = await asyncio.to_thread(
                socket.getaddrinfo,
                host,
                port,
                type=socket.SOCK_STREAM,
            )
        except socket.gaierror as exc:
            raise SSRFBlockedError("DNS resolution failed") from exc

        # 5-6. Validate ALL resolved IPs, then pin to the first.
        pinned_ip = _select_pinned_ip(addrinfo, self._policy)

        # 7. Connect to the pinned IP, keeping Host header and SNI.
        return await self._inner.handle_async_request(
            _pin_request(request, pinned_ip)
        )

    async def aclose(self) -> None:
        """Close the wrapped `httpx.AsyncHTTPTransport`."""
        await self._inner.aclose()


# ---------------------------------------------------------------------- #
# Sync transport
# ---------------------------------------------------------------------- #


class SSRFSafeSyncTransport(httpx.BaseTransport):
    """httpx sync transport that validates DNS results against an SSRF policy.

    Sync mirror of `SSRFSafeTransport`. Each request gets URL/scheme/pattern
    validation, and DNS resolution with every resolved IP validated. The
    request is then pinned to the first IP, keeping the original `Host`
    header and, for HTTPS, the original hostname as TLS SNI. Allowed hosts
    skip DNS/IP validation. Redirects followed by the client pass through
    `handle_request` again and are re-validated.
    """

    def __init__(
        self,
        policy: SSRFPolicy = SSRFPolicy(),
        **transport_kwargs: object,
    ) -> None:
        """Create the transport.

        Args:
            policy: SSRF policy every request is validated against.
            **transport_kwargs: Forwarded to the wrapped `httpx.HTTPTransport`.
        """
        self._policy = policy
        self._inner = httpx.HTTPTransport(**transport_kwargs)  # type: ignore[arg-type]

    def handle_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        """Validate *request*, then send it pinned to a vetted IP address.

        Raises:
            SSRFBlockedError: If the URL fails `validate_url_sync`, or DNS
                resolution fails or returns no results. Errors raised by
                `validate_resolved_ip` for a blocked IP also propagate.
        """
        validate_url_sync(str(request.url), self._policy)

        if _is_allowed_host(request.url.host or "", self._policy):
            return self._inner.handle_request(request)

        host, port = _lookup_target(request)
        try:
            addrinfo = socket.getaddrinfo(host, port, type=socket.SOCK_STREAM)
        except socket.gaierror as exc:
            raise SSRFBlockedError("DNS resolution failed") from exc

        pinned_ip = _select_pinned_ip(addrinfo, self._policy)
        return self._inner.handle_request(_pin_request(request, pinned_ip))

    def close(self) -> None:
        """Close the wrapped `httpx.HTTPTransport`."""
        self._inner.close()


# ---------------------------------------------------------------------- #
# Factories
# ---------------------------------------------------------------------- #


def ssrf_safe_client(
    policy: SSRFPolicy = SSRFPolicy(),
    **kwargs: object,
) -> httpx.Client:
    """Create an `httpx.Client` with SSRF protection.

    Drop-in replacement for `httpx.Client(...)`. Transport-specific kwargs
    (`verify`, `cert`, `retries`, etc.) are forwarded to the inner
    `HTTPTransport`; everything else goes to the `Client`. Redirects are
    followed by default (`follow_redirects=True`, `max_redirects=10`)
    unless the caller overrides them.

    Raises:
        TypeError: If `transport`, `mounts`, `proxy` or `proxies` is passed,
            since those would bypass the SSRF-safe transport.
    """
    transport_kwargs, client_kwargs = _split_factory_kwargs(kwargs)
    transport = SSRFSafeSyncTransport(policy=policy, **transport_kwargs)
    return httpx.Client(
        transport=transport,
        **client_kwargs,  # type: ignore[arg-type]
    )


def ssrf_safe_async_client(
    policy: SSRFPolicy = SSRFPolicy(),
    **kwargs: object,
) -> httpx.AsyncClient:
    """Create an `httpx.AsyncClient` with SSRF protection.

    Drop-in replacement for `httpx.AsyncClient(...)` - callers just swap
    the constructor call.  Transport-specific kwargs (`verify`, `cert`,
    `retries`, etc.) are forwarded to the inner `AsyncHTTPTransport`;
    everything else goes to the `AsyncClient`. Redirects are followed by
    default (`follow_redirects=True`, `max_redirects=10`) unless the caller
    overrides them.

    Raises:
        TypeError: If `transport`, `mounts`, `proxy` or `proxies` is passed,
            since those would bypass the SSRF-safe transport.
    """
    transport_kwargs, client_kwargs = _split_factory_kwargs(kwargs)
    transport = SSRFSafeTransport(policy=policy, **transport_kwargs)
    return httpx.AsyncClient(
        transport=transport,
        **client_kwargs,  # type: ignore[arg-type]
    )

Code quality findings 6

Ensure functions have docstrings for documentation
missing-docstring
async def handle_async_request(
Ensure functions have docstrings for documentation
missing-docstring
async def aclose(self) -> None:
Ensure functions have docstrings for documentation
missing-docstring
def handle_request(
Ensure functions have docstrings for documentation
missing-docstring
def close(self) -> None:
Ensure functions have docstrings for documentation
missing-docstring
def ssrf_safe_client(
Ensure functions have docstrings for documentation
missing-docstring
def ssrf_safe_async_client(

Get this view in your editor

Same data, no extra tab — call code_get_file + code_get_findings over MCP from Claude/Cursor/Copilot.