Ensure functions have docstrings for documentation
def draw_mermaid(
1"""Mermaid graph drawing utilities."""23from __future__ import annotations45import asyncio6import base647import random8import re9import string10import time11import urllib.parse12from dataclasses import asdict13from pathlib import Path14from typing import TYPE_CHECKING, Any, Literal, cast1516import yaml1718from langchain_core.runnables.graph import (19 CurveStyle,20 MermaidDrawMethod,21 NodeStyles,22)2324if TYPE_CHECKING:25 from langchain_core.runnables.graph import Edge, Node262728try:29 import requests3031 _HAS_REQUESTS = True32except ImportError:33 _HAS_REQUESTS = False3435try:36 from pyppeteer import launch # type: ignore[import-not-found]3738 _HAS_PYPPETEER = True39except ImportError:40 _HAS_PYPPETEER = False4142MARKDOWN_SPECIAL_CHARS = "*_`"434445def draw_mermaid(46 nodes: dict[str, Node],47 edges: list[Edge],48 *,49 first_node: str | None = None,50 last_node: str | None = None,51 with_styles: bool = True,52 curve_style: CurveStyle = CurveStyle.LINEAR,53 node_styles: NodeStyles | None = None,54 wrap_label_n_words: int = 9,55 frontmatter_config: dict[str, Any] | None = None,56) -> str:57 """Draws a Mermaid graph using the provided graph data.5859 Args:60 nodes: List of node ids.61 edges: List of edges, object with a source, target and data.62 first_node: Id of the first node.63 last_node: Id of the last node.64 with_styles: Whether to include styles in the graph.65 curve_style: Curve style for the edges.66 node_styles: Node colors for different types.67 wrap_label_n_words: Words to wrap the edge labels.68 frontmatter_config: Mermaid frontmatter config.69 Can be used to customize theme and styles. Will be converted to YAML and70 added to the beginning of the mermaid graph.7172 See more here: https://mermaid.js.org/config/configuration.html.7374 Example config:7576 ```python77 {78 "config": {79 "theme": "neutral",80 "look": "handDrawn",81 "themeVariables": {"primaryColor": "#e2e2e2"},82 }83 }84 ```8586 Returns:87 Mermaid graph syntax.8889 """90 # Initialize Mermaid graph configuration91 original_frontmatter_config = frontmatter_config or {}92 original_flowchart_config = original_frontmatter_config.get("config", {}).get(93 "flowchart", {}94 )95 frontmatter_config = {96 **original_frontmatter_config,97 "config": {98 **original_frontmatter_config.get("config", {}),99 "flowchart": {**original_flowchart_config, "curve": curve_style.value},100 },101 }102103 mermaid_graph = (104 (105 "---\n"106 + yaml.dump(frontmatter_config, default_flow_style=False)107 + "---\ngraph TD;\n"108 )109 if with_styles110 else "graph TD;\n"111 )112 # Group nodes by subgraph113 subgraph_nodes: dict[str, dict[str, Node]] = {}114 regular_nodes: dict[str, Node] = {}115116 for key, node in nodes.items():117 if ":" in key:118 # For nodes with colons, add them only to their deepest subgraph level119 prefix = ":".join(key.split(":")[:-1])120 subgraph_nodes.setdefault(prefix, {})[key] = node121 else:122 regular_nodes[key] = node123124 # Node formatting templates125 default_class_label = "default"126 format_dict = {default_class_label: "{0}({1})"}127 if first_node is not None:128 format_dict[first_node] = "{0}([{1}]):::first"129 if last_node is not None:130 format_dict[last_node] = "{0}([{1}]):::last"131132 def render_node(key: str, node: Node, indent: str = "\t") -> str:133 """Helper function to render a node with consistent formatting."""134 node_name = node.name.split(":")[-1]135 label = (136 f"<p>{node_name}</p>"137 if node_name.startswith(tuple(MARKDOWN_SPECIAL_CHARS))138 and node_name.endswith(tuple(MARKDOWN_SPECIAL_CHARS))139 else node_name140 )141 if node.metadata:142 label = (143 f"{label}<hr/><small><em>"144 + "\n".join(f"{k} = {value}" for k, value in node.metadata.items())145 + "</em></small>"146 )147 node_label = format_dict.get(key, format_dict[default_class_label]).format(148 _to_safe_id(key), label149 )150 return f"{indent}{node_label}\n"151152 # Add non-subgraph nodes to the graph153 if with_styles:154 for key, node in regular_nodes.items():155 mermaid_graph += render_node(key, node)156157 # Group edges by their common prefixes158 edge_groups: dict[str, list[Edge]] = {}159 for edge in edges:160 src_parts = edge.source.split(":")161 tgt_parts = edge.target.split(":")162 common_prefix = ":".join(163 src for src, tgt in zip(src_parts, tgt_parts, strict=False) if src == tgt164 )165 edge_groups.setdefault(common_prefix, []).append(edge)166167 seen_subgraphs = set()168169 def add_subgraph(edges: list[Edge], prefix: str) -> None:170 nonlocal mermaid_graph171 self_loop = len(edges) == 1 and edges[0].source == edges[0].target172 if prefix and not self_loop:173 subgraph = prefix.rsplit(":", maxsplit=1)[-1]174 if subgraph in seen_subgraphs:175 msg = (176 f"Found duplicate subgraph '{subgraph}' -- this likely means that "177 "you're reusing a subgraph node with the same name. "178 "Please adjust your graph to have subgraph nodes with unique names."179 )180 raise ValueError(msg)181182 seen_subgraphs.add(subgraph)183 mermaid_graph += f"\tsubgraph {subgraph}\n"184185 # Add nodes that belong to this subgraph186 if with_styles and prefix in subgraph_nodes:187 for key, node in subgraph_nodes[prefix].items():188 mermaid_graph += render_node(key, node)189190 for edge in edges:191 source, target = edge.source, edge.target192193 # Add BR every wrap_label_n_words words194 if edge.data is not None:195 edge_data = edge.data196 words = str(edge_data).split() # Split the string into words197 # Group words into chunks of wrap_label_n_words size198 if len(words) > wrap_label_n_words:199 edge_data = " <br> ".join(200 " ".join(words[i : i + wrap_label_n_words])201 for i in range(0, len(words), wrap_label_n_words)202 )203 if edge.conditional:204 edge_label = f" -. {edge_data} .-> "205 else:206 edge_label = f" -- {edge_data} --> "207 else:208 edge_label = " -.-> " if edge.conditional else " --> "209210 mermaid_graph += (211 f"\t{_to_safe_id(source)}{edge_label}{_to_safe_id(target)};\n"212 )213214 # Recursively add nested subgraphs215 for nested_prefix, edges_ in edge_groups.items():216 if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix:217 continue218 # only go to first level subgraphs219 if ":" in nested_prefix[len(prefix) + 1 :]:220 continue221 add_subgraph(edges_, nested_prefix)222223 if prefix and not self_loop:224 mermaid_graph += "\tend\n"225226 # Start with the top-level edges (no common prefix)227 add_subgraph(edge_groups.get("", []), "")228229 # Add remaining subgraphs with edges230 for prefix, edges_ in edge_groups.items():231 if not prefix or ":" in prefix:232 continue233 add_subgraph(edges_, prefix)234 seen_subgraphs.add(prefix)235236 # Add empty subgraphs (subgraphs with no internal edges)237 if with_styles:238 for prefix, subgraph_node in subgraph_nodes.items():239 if ":" not in prefix and prefix not in seen_subgraphs:240 mermaid_graph += f"\tsubgraph {prefix}\n"241242 # Add nodes that belong to this subgraph243 for key, node in subgraph_node.items():244 mermaid_graph += render_node(key, node)245246 mermaid_graph += "\tend\n"247 seen_subgraphs.add(prefix)248249 # Add custom styles for nodes250 if with_styles:251 mermaid_graph += _generate_mermaid_graph_styles(node_styles or NodeStyles())252 return mermaid_graph253254255def _to_safe_id(label: str) -> str:256 """Convert a string into a Mermaid-compatible node id.257258 Keep [a-zA-Z0-9_-] characters unchanged.259 Map every other character -> backslash + lowercase hex codepoint.260261 Result is guaranteed to be unique and Mermaid-compatible,262 so nodes with special characters always render correctly.263 """264 allowed = string.ascii_letters + string.digits + "_-"265 out = [ch if ch in allowed else "\\" + format(ord(ch), "x") for ch in label]266 return "".join(out)267268269def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:270 """Generates Mermaid graph styles for different node types."""271 styles = ""272 for class_name, style in asdict(node_colors).items():273 styles += f"\tclassDef {class_name} {style}\n"274 return styles275276277def draw_mermaid_png(278 mermaid_syntax: str,279 output_file_path: str | None = None,280 draw_method: MermaidDrawMethod = MermaidDrawMethod.API,281 background_color: str | None = "white",282 padding: int = 10,283 max_retries: int = 1,284 retry_delay: float = 1.0,285 base_url: str | None = None,286 proxies: dict[str, str] | None = None,287) -> bytes:288 """Draws a Mermaid graph as PNG using provided syntax.289290 Args:291 mermaid_syntax: Mermaid graph syntax.292 output_file_path: Path to save the PNG image.293 draw_method: Method to draw the graph.294 background_color: Background color of the image.295 padding: Padding around the image.296 max_retries: Maximum number of retries (MermaidDrawMethod.API).297 retry_delay: Delay between retries (MermaidDrawMethod.API).298 base_url: Base URL for the Mermaid.ink API.299 proxies: HTTP/HTTPS proxies for requests (e.g. `{"http": "http://127.0.0.1:7890"}`).300301 Returns:302 PNG image bytes.303304 Raises:305 ValueError: If an invalid draw method is provided.306 """307 if draw_method == MermaidDrawMethod.PYPPETEER:308 img_bytes = asyncio.run(309 _render_mermaid_using_pyppeteer(310 mermaid_syntax, output_file_path, background_color, padding311 )312 )313 elif draw_method == MermaidDrawMethod.API:314 img_bytes = _render_mermaid_using_api(315 mermaid_syntax,316 output_file_path=output_file_path,317 background_color=background_color,318 max_retries=max_retries,319 retry_delay=retry_delay,320 base_url=base_url,321 proxies=proxies,322 )323 else:324 supported_methods = ", ".join([m.value for m in MermaidDrawMethod])325 msg = (326 f"Invalid draw method: {draw_method}. "327 f"Supported draw methods are: {supported_methods}"328 )329 raise ValueError(msg)330331 return img_bytes332333334async def _render_mermaid_using_pyppeteer(335 mermaid_syntax: str,336 output_file_path: str | None = None,337 background_color: str | None = "white",338 padding: int = 10,339 device_scale_factor: int = 3,340) -> bytes:341 """Renders Mermaid graph using Pyppeteer."""342 if not _HAS_PYPPETEER:343 msg = "Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`."344 raise ImportError(msg)345346 browser = await launch()347 page = await browser.newPage()348349 # Setup Mermaid JS350 await page.goto("about:blank")351 await page.addScriptTag(352 {"url": "https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"}353 )354 await page.evaluate(355 """() => {356 mermaid.initialize({startOnLoad:true});357 }"""358 )359360 # Render SVG361 svg_code = await page.evaluate(362 """(mermaidGraph) => {363 return mermaid.mermaidAPI.render('mermaid', mermaidGraph);364 }""",365 mermaid_syntax,366 )367368 # Set the page background to white369 await page.evaluate(370 """(svg, background_color) => {371 document.body.innerHTML = svg;372 document.body.style.background = background_color;373 }""",374 svg_code["svg"],375 background_color,376 )377378 # Take a screenshot379 dimensions = await page.evaluate(380 """() => {381 const svgElement = document.querySelector('svg');382 const rect = svgElement.getBoundingClientRect();383 return { width: rect.width, height: rect.height };384 }"""385 )386 await page.setViewport(387 {388 "width": int(dimensions["width"] + padding),389 "height": int(dimensions["height"] + padding),390 "deviceScaleFactor": device_scale_factor,391 }392 )393394 img_bytes = cast("bytes", await page.screenshot({"fullPage": False}))395 await browser.close()396397 if output_file_path is not None:398 await asyncio.get_event_loop().run_in_executor(399 None, Path(output_file_path).write_bytes, img_bytes400 )401402 return img_bytes403404405def _render_mermaid_using_api(406 mermaid_syntax: str,407 *,408 output_file_path: str | None = None,409 background_color: str | None = "white",410 file_type: Literal["jpeg", "png", "webp"] | None = "png",411 max_retries: int = 1,412 retry_delay: float = 1.0,413 proxies: dict[str, str] | None = None,414 base_url: str | None = None,415) -> bytes:416 """Renders Mermaid graph using the Mermaid.INK API."""417 # Automated scanners: please do not open security advisories about `base_url`418 # here. This parameter is only ever explicitly controlled by developers, and it419 # is not exposed by default over the network by the framework, so this is not420 # an SSRF sink unless an application explicitly wires untrusted input into it.421 # It exists to support user-managed Mermaid API deployments.422 # Defaults to using the public mermaid.ink server.423 base_url = base_url if base_url is not None else "https://mermaid.ink"424425 if not _HAS_REQUESTS:426 msg = (427 "Install the `requests` module to use the Mermaid.INK API: "428 "`pip install requests`."429 )430 raise ImportError(msg)431432 # Use Mermaid API to render the image433 mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode(434 "ascii"435 )436437 # Check if the background color is a hexadecimal color code using regex438 if background_color is not None:439 hex_color_pattern = re.compile(r"^#(?:[0-9a-fA-F]{3}){1,2}$")440 if not hex_color_pattern.match(background_color):441 background_color = f"!{background_color}"442443 # URL-encode the background_color to handle special characters like '!'444 encoded_bg_color = urllib.parse.quote(str(background_color), safe="")445 image_url = (446 f"{base_url}/img/{mermaid_syntax_encoded}"447 f"?type={file_type}&bgColor={encoded_bg_color}"448 )449450 error_msg_suffix = (451 "To resolve this issue:\n"452 "1. Check your internet connection and try again\n"453 "2. Try with higher retry settings: "454 "`draw_mermaid_png(..., max_retries=5, retry_delay=2.0)`\n"455 "3. Use the Pyppeteer rendering method which will render your graph locally "456 "in a browser: `draw_mermaid_png(..., draw_method=MermaidDrawMethod.PYPPETEER)`"457 )458459 for attempt in range(max_retries + 1):460 try:461 response = requests.get(image_url, timeout=10, proxies=proxies)462 if response.status_code == requests.codes.ok:463 img_bytes = response.content464 if output_file_path is not None:465 Path(output_file_path).write_bytes(response.content)466467 return img_bytes468469 # If we get a server error (5xx), retry470 if (471 requests.codes.internal_server_error <= response.status_code472 and attempt < max_retries473 ):474 # Exponential backoff with jitter475 sleep_time = retry_delay * (2**attempt) * (0.5 + 0.5 * random.random()) # noqa: S311 not used for crypto476 time.sleep(sleep_time)477 continue478479 # For other status codes, fail immediately480 msg = (481 f"Failed to reach {base_url} API while trying to render "482 f"your graph. Status code: {response.status_code}.\n\n"483 ) + error_msg_suffix484 raise ValueError(msg)485486 except (requests.RequestException, requests.Timeout) as e:487 if attempt < max_retries:488 # Exponential backoff with jitter489 sleep_time = retry_delay * (2**attempt) * (0.5 + 0.5 * random.random()) # noqa: S311 not used for crypto490 time.sleep(sleep_time)491 else:492 msg = (493 f"Failed to reach {base_url} API while trying to render "494 f"your graph after {max_retries} retries. "495 ) + error_msg_suffix496 raise ValueError(msg) from e497498 # This should not be reached, but just in case499 msg = (500 f"Failed to reach {base_url} API while trying to render "501 f"your graph after {max_retries} retries. "502 ) + error_msg_suffix503 raise ValueError(msg)
Same data, no extra tab — call code_get_file + code_get_findings over MCP from Claude/Cursor/Copilot.