Ensure functions have docstrings for documentation.
`async def handle_async_request(...)`
"""SSRF-safe httpx transports with DNS resolution and IP pinning.

Provides async and sync httpx transports that validate each request URL and
every DNS-resolved address against an `SSRFPolicy`, then connect to a pinned,
already-validated IP. Also provides factory functions that build clients
around those transports.
"""

import asyncio
import socket
from collections.abc import Sequence
from typing import Any

import httpx

from langchain_core._security._exceptions import SSRFBlockedError
from langchain_core._security._policy import (
    SSRFPolicy,
    _effective_allowed_hosts,
    validate_resolved_ip,
    validate_url_sync,
)

# Factory kwargs forwarded to the inner `httpx.HTTPTransport` /
# `httpx.AsyncHTTPTransport`; every other kwarg goes to the client.
_TRANSPORT_KWARGS = frozenset(
    {
        "verify",
        "cert",
        "trust_env",
        "http1",
        "http2",
        "limits",
        "retries",
    }
)


# ---------------------------------------------------------------------- #
# Shared validation / pinning helpers
# ---------------------------------------------------------------------- #


def _is_allowed_host(request: httpx.Request, policy: SSRFPolicy) -> bool:
    """Return whether the request host is on the policy's allowed-hosts list.

    The comparison is case-insensitive and uses `request.url.host`. Allowed
    hosts skip DNS/IP validation entirely.
    """
    hostname = (request.url.host or "").lower()
    return hostname in {h.lower() for h in _effective_allowed_hosts(policy)}


def _resolution_target(request: httpx.Request) -> tuple[str, int]:
    """Return the ``(host, port)`` pair to pass to `socket.getaddrinfo`.

    The host is the ASCII/IDNA form from `request.url.raw_host`, because
    `request.url.host` may be Unicode-decoded for internationalized domains.
    When the URL has no explicit port, 443 is used for HTTPS and 80 otherwise.
    """
    url = request.url
    port = url.port or (443 if url.scheme.lower() == "https" else 80)
    return url.raw_host.decode("ascii"), port


def _select_pinned_ip(
    addrinfo: Sequence[tuple[Any, ...]],
    policy: SSRFPolicy,
) -> str:
    """Validate every resolved address and return the first one.

    All addresses are checked, not only the one connected to, so a hostname
    that resolves to any blocked address is rejected outright.

    Args:
        addrinfo: Result of `socket.getaddrinfo` for the request host.
        policy: Policy passed to `validate_resolved_ip` for each address.

    Returns:
        The IP string of the first `addrinfo` entry.

    Raises:
        SSRFBlockedError: If `addrinfo` is empty. Whatever
            `validate_resolved_ip` raises for a blocked address propagates.
    """
    if not addrinfo:
        raise SSRFBlockedError("DNS resolution returned no results")
    for _family, _type, _proto, _canonname, sockaddr in addrinfo:
        validate_resolved_ip(sockaddr[0], policy)
    pinned_ip: str = addrinfo[0][4][0]
    return pinned_ip


def _build_pinned_request(request: httpx.Request, pinned_ip: str) -> httpx.Request:
    """Return a copy of `request` that connects to `pinned_ip`.

    Only the URL host changes. The `Host` header keeps the original name and,
    for HTTPS, the `sni_hostname` extension makes httpcore use the original
    hostname for SNI and certificate verification.

    NOTE(review): the inner connection pool presumably keys connections on
    (scheme, pinned IP, port), so a TLS connection opened for one hostname
    could be reused for another hostname resolving to the same IP. Verify
    this is acceptable.
    """
    url = request.url

    # Copy so the original request's headers are never mutated.
    headers = httpx.Headers(request.headers)
    if "Host" not in headers:
        # `stream=` below disables httpx's header auto-population, so make
        # sure Host is present and names the original host, not the IP.
        headers["Host"] = url.netloc.decode("ascii")

    extensions = dict(request.extensions)
    if url.scheme.lower() == "https":
        # httpcore documents `sni_hostname` as a str; `raw_host` is the
        # ASCII/IDNA form, so IDN hosts do not fail to encode.
        extensions["sni_hostname"] = url.raw_host.decode("ascii")

    return httpx.Request(
        method=request.method,
        url=url.copy_with(host=pinned_ip),
        headers=headers,
        # Reuse the body stream: `request.content` raises
        # `httpx.RequestNotRead` for streaming (iterator/generator) bodies.
        stream=request.stream,
        extensions=extensions,
    )


# ---------------------------------------------------------------------- #
# Async transport
# ---------------------------------------------------------------------- #


class SSRFSafeTransport(httpx.AsyncBaseTransport):
    """httpx async transport that validates DNS results against an SSRF policy.

    For every outgoing request the transport:
    1. Checks the URL scheme against `policy.allowed_schemes`.
    2. Validates the hostname against blocked patterns.
    3. Resolves DNS and validates **all** returned IPs.
    4. Rewrites the request to connect to the first valid IP while
       preserving the original `Host` header and TLS SNI hostname.

    Hosts on the policy's allowed-hosts list skip steps 3-4 and are sent
    unchanged.

    Redirects are re-validated on each hop because `follow_redirects`
    is set on the *client*, causing `handle_async_request` to be called
    again for each redirect target.
    """

    def __init__(
        self,
        policy: SSRFPolicy = SSRFPolicy(),
        **transport_kwargs: object,
    ) -> None:
        """Create the transport.

        Args:
            policy: SSRF policy applied to every request. The default is a
                single `SSRFPolicy()` instance created at import time.
            **transport_kwargs: Forwarded to the wrapped
                `httpx.AsyncHTTPTransport` (e.g. `verify`, `http2`, `limits`).
        """
        self._policy = policy
        self._inner = httpx.AsyncHTTPTransport(**transport_kwargs)  # type: ignore[arg-type]

    async def handle_async_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        """Validate `request` against the policy, then send it to a pinned IP.

        Args:
            request: The outgoing request (one call per redirect hop).

        Returns:
            The response from the wrapped `httpx.AsyncHTTPTransport`.

        Raises:
            SSRFBlockedError: If DNS resolution fails or returns nothing.
                Errors raised by `validate_url_sync` (URL checks) and
                `validate_resolved_ip` (per-address checks) propagate.
        """
        # 1-3. Scheme, hostname, and pattern checks (reuse sync validator).
        validate_url_sync(str(request.url), self._policy)

        # Allowed-hosts bypass - skip DNS/IP validation entirely.
        if _is_allowed_host(request, self._policy):
            return await self._inner.handle_async_request(request)

        # 4. DNS resolution; getaddrinfo blocks, so run it off the event loop.
        host, port = _resolution_target(request)
        try:
            addrinfo = await asyncio.to_thread(
                socket.getaddrinfo,
                host,
                port,
                type=socket.SOCK_STREAM,
            )
        except socket.gaierror as exc:
            raise SSRFBlockedError("DNS resolution failed") from exc

        # 5-6. Validate ALL resolved IPs, then pin to the first one.
        pinned_ip = _select_pinned_ip(addrinfo, self._policy)

        # 7. Rewrite URL to the pinned IP, preserving Host header and SNI.
        return await self._inner.handle_async_request(
            _build_pinned_request(request, pinned_ip)
        )

    async def aclose(self) -> None:
        """Close the wrapped `httpx.AsyncHTTPTransport`."""
        await self._inner.aclose()


# ---------------------------------------------------------------------- #
# Sync transport
# ---------------------------------------------------------------------- #


class SSRFSafeSyncTransport(httpx.BaseTransport):
    """httpx sync transport that validates DNS results against an SSRF policy.

    Sync mirror of `SSRFSafeTransport`, with the same per-request steps:
    URL checks, DNS resolution, validation of every resolved IP, and pinning
    to the first one. Hosts on the policy's allowed-hosts list bypass
    DNS/IP validation.
    """

    def __init__(
        self,
        policy: SSRFPolicy = SSRFPolicy(),
        **transport_kwargs: object,
    ) -> None:
        """Create the transport.

        Args:
            policy: SSRF policy applied to every request. The default is a
                single `SSRFPolicy()` instance created at import time.
            **transport_kwargs: Forwarded to the wrapped
                `httpx.HTTPTransport` (e.g. `verify`, `http2`, `limits`).
        """
        self._policy = policy
        self._inner = httpx.HTTPTransport(**transport_kwargs)  # type: ignore[arg-type]

    def handle_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        """Validate `request` against the policy, then send it to a pinned IP.

        Args:
            request: The outgoing request (one call per redirect hop).

        Returns:
            The response from the wrapped `httpx.HTTPTransport`.

        Raises:
            SSRFBlockedError: If DNS resolution fails or returns nothing.
                Errors raised by `validate_url_sync` (URL checks) and
                `validate_resolved_ip` (per-address checks) propagate.
        """
        validate_url_sync(str(request.url), self._policy)

        if _is_allowed_host(request, self._policy):
            return self._inner.handle_request(request)

        host, port = _resolution_target(request)
        try:
            addrinfo = socket.getaddrinfo(
                host,
                port,
                type=socket.SOCK_STREAM,
            )
        except socket.gaierror as exc:
            raise SSRFBlockedError("DNS resolution failed") from exc

        pinned_ip = _select_pinned_ip(addrinfo, self._policy)
        return self._inner.handle_request(_build_pinned_request(request, pinned_ip))

    def close(self) -> None:
        """Close the wrapped `httpx.HTTPTransport`."""
        self._inner.close()


# ---------------------------------------------------------------------- #
# Factories
# ---------------------------------------------------------------------- #


def _split_factory_kwargs(
    kwargs: dict[str, object],
) -> tuple[dict[str, object], dict[str, object]]:
    """Split factory kwargs into ``(transport_kwargs, client_kwargs)``.

    Keys in `_TRANSPORT_KWARGS` go to the transport and everything else goes
    to the client. The client kwargs also get ``follow_redirects=True`` and
    ``max_redirects=10`` unless the caller supplied those keys.
    """
    transport_kwargs = {k: v for k, v in kwargs.items() if k in _TRANSPORT_KWARGS}
    client_kwargs = {k: v for k, v in kwargs.items() if k not in _TRANSPORT_KWARGS}
    client_kwargs.setdefault("follow_redirects", True)
    client_kwargs.setdefault("max_redirects", 10)
    return transport_kwargs, client_kwargs


def ssrf_safe_client(
    policy: SSRFPolicy = SSRFPolicy(),
    **kwargs: object,
) -> httpx.Client:
    """Create an `httpx.Client` with SSRF protection.

    Drop-in replacement for `httpx.Client(...)`: callers just swap the
    constructor call. Transport-specific kwargs (`verify`, `cert`,
    `retries`, etc.) are forwarded to the inner `HTTPTransport`; everything
    else goes to the `Client`. Redirects are followed by default
    (`follow_redirects=True`, `max_redirects=10`), and the transport
    re-validates each hop.

    NOTE(review): client kwargs such as `mounts` or `proxy` presumably route
    matching requests through other transports that skip SSRF validation.
    Confirm whether they should be rejected here.
    """
    transport_kwargs, client_kwargs = _split_factory_kwargs(kwargs)
    return httpx.Client(
        transport=SSRFSafeSyncTransport(policy=policy, **transport_kwargs),
        **client_kwargs,  # type: ignore[arg-type]
    )


def ssrf_safe_async_client(
    policy: SSRFPolicy = SSRFPolicy(),
    **kwargs: object,
) -> httpx.AsyncClient:
    """Create an `httpx.AsyncClient` with SSRF protection.

    Drop-in replacement for `httpx.AsyncClient(...)`: callers just swap
    the constructor call. Transport-specific kwargs (`verify`, `cert`,
    `retries`, etc.) are forwarded to the inner `AsyncHTTPTransport`;
    everything else goes to the `AsyncClient`. Redirects are followed by
    default (`follow_redirects=True`, `max_redirects=10`), and the transport
    re-validates each hop.

    NOTE(review): client kwargs such as `mounts` or `proxy` presumably route
    matching requests through other transports that skip SSRF validation.
    Confirm whether they should be rejected here.
    """
    transport_kwargs, client_kwargs = _split_factory_kwargs(kwargs)
    return httpx.AsyncClient(
        transport=SSRFSafeTransport(policy=policy, **transport_kwargs),
        **client_kwargs,  # type: ignore[arg-type]
    )
Same data, no extra tab: call `code_get_file` and `code_get_findings` over MCP from Claude, Cursor, or Copilot.