Ensure functions have docstrings for documentation
def get_openapi_security_definitions(
import copy
import http.client
import inspect
import warnings
from collections.abc import Sequence
from typing import Any, Literal, cast

from fastapi import routing
from fastapi._compat import (
    ModelField,
    get_definitions,
    get_flat_models_from_fields,
    get_model_name_map,
    get_schema_from_model_field,
    lenient_issubclass,
)
from fastapi.datastructures import DefaultPlaceholder, _Unset
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import (
    _get_flat_fields_from_params,
    get_flat_dependant,
    get_flat_params,
    get_validation_alias,
)
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import FastAPIDeprecationWarning
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body, ParamTypes
from fastapi.responses import Response
from fastapi.sse import _SSE_EVENT_SCHEMA
from fastapi.types import ModelNameMap
from fastapi.utils import (
    deep_dict_update,
    generate_operation_id_for_path,
    is_body_allowed_for_status_code,
)
from pydantic import BaseModel
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute

# Schema registered as "ValidationError" in the component definitions whenever
# an operation gets the automatic 422 response (see get_openapi_path).
validation_error_definition = {
    "title": "ValidationError",
    "type": "object",
    "properties": {
        "loc": {
            "title": "Location",
            "type": "array",
            "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
        },
        "msg": {"title": "Message", "type": "string"},
        "type": {"title": "Error Type", "type": "string"},
        "input": {"title": "Input"},
        "ctx": {"title": "Context", "type": "object"},
    },
    "required": ["loc", "msg", "type"],
}

# Schema registered as "HTTPValidationError": an object whose "detail" is an
# array of references to the "ValidationError" schema above.
validation_error_response_definition = {
    "title": "HTTPValidationError",
    "type": "object",
    "properties": {
        "detail": {
            "title": "Detail",
            "type": "array",
            "items": {"$ref": REF_PREFIX + "ValidationError"},
        }
    },
}

# Fallback descriptions for additional responses keyed by a status code range
# (or "DEFAULT"); keys are looked up with the upper-cased response key.
status_code_ranges: dict[str, str] = {
    "1XX": "Information",
    "2XX": "Success",
    "3XX": "Redirection",
    "4XX": "Client Error",
    "5XX": "Server Error",
    "DEFAULT": "Default Response",
}


def get_openapi_security_definitions(
    flat_dependant: Dependant,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Collect the security schemes and requirements of a flattened dependant.

    Args:
        flat_dependant: Dependant whose ``_security_dependencies`` are read.

    Returns:
        A ``(security_definitions, operation_security)`` tuple.
        ``security_definitions`` maps each scheme name to its JSON-encoded
        scheme model. ``operation_security`` is a list with one
        ``{scheme_name: scopes}`` dict per scheme name, in first-seen order,
        where ``scopes`` holds the de-duplicated OAuth scopes of every
        dependency using that scheme name.
    """
    security_definitions = {}
    # Use a dict to merge scopes for same security scheme
    operation_security_dict: dict[str, list[str]] = {}
    for security_dependency in flat_dependant._security_dependencies:
        security_definition = jsonable_encoder(
            security_dependency._security_scheme.model,
            by_alias=True,
            exclude_none=True,
        )
        security_name = security_dependency._security_scheme.scheme_name
        # A later dependency with the same scheme name overwrites the
        # definition stored by an earlier one.
        security_definitions[security_name] = security_definition
        # Merge scopes for the same security scheme
        if security_name not in operation_security_dict:
            operation_security_dict[security_name] = []
        for scope in security_dependency.oauth_scopes or []:
            if scope not in operation_security_dict[security_name]:
                operation_security_dict[security_name].append(scope)
    operation_security = [
        {name: scopes} for name, scopes in operation_security_dict.items()
    ]
    return security_definitions, operation_security


def _get_openapi_operation_parameters(
    *,
    dependant: Dependant,
    model_name_map: ModelNameMap,
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
    ],
    separate_input_output_schemas: bool = True,
) -> list[dict[str, Any]]:
    """Build the OpenAPI parameter objects for a dependant.

    Path, query, header and cookie parameters of the flattened dependant are
    processed in that order. Parameters whose field info has a falsy
    ``include_in_schema`` are skipped.

    Args:
        dependant: Dependant whose parameters are documented.
        model_name_map: Forwarded to ``get_schema_from_model_field``.
        field_mapping: Forwarded to ``get_schema_from_model_field``.
        separate_input_output_schemas: Forwarded to
            ``get_schema_from_model_field``.

    Returns:
        A list of dicts with the keys ``name``, ``in``, ``required`` and
        ``schema``, plus ``description``, ``examples``/``example`` and
        ``deprecated`` when the field info provides them.
    """
    parameters = []
    flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
    path_params = _get_flat_fields_from_params(flat_dependant.path_params)
    query_params = _get_flat_fields_from_params(flat_dependant.query_params)
    header_params = _get_flat_fields_from_params(flat_dependant.header_params)
    cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
    parameter_groups = [
        (ParamTypes.path, path_params),
        (ParamTypes.query, query_params),
        (ParamTypes.header, header_params),
        (ParamTypes.cookie, cookie_params),
    ]
    default_convert_underscores = True
    # When the only header parameter is annotated with a BaseModel subclass,
    # its "convert_underscores" setting becomes the default for every
    # parameter that does not define the attribute itself.
    if len(flat_dependant.header_params) == 1:
        first_field = flat_dependant.header_params[0]
        if lenient_issubclass(first_field.field_info.annotation, BaseModel):
            default_convert_underscores = getattr(
                first_field.field_info, "convert_underscores", True
            )
    for param_type, param_group in parameter_groups:
        for param in param_group:
            field_info = param.field_info
            # field_info = cast(Param, field_info)
            if not getattr(field_info, "include_in_schema", True):
                continue
            param_schema = get_schema_from_model_field(
                field=param,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            name = get_validation_alias(param)
            convert_underscores = getattr(
                param.field_info,
                "convert_underscores",
                default_convert_underscores,
            )
            # Only header names that equal the plain parameter name get their
            # underscores replaced by hyphens.
            if (
                param_type == ParamTypes.header
                and name == param.name
                and convert_underscores
            ):
                name = param.name.replace("_", "-")

            parameter = {
                "name": name,
                "in": param_type.value,
                "required": param.field_info.is_required(),
                "schema": param_schema,
            }
            if field_info.description:
                parameter["description"] = field_info.description
            openapi_examples = getattr(field_info, "openapi_examples", None)
            example = getattr(field_info, "example", None)
            # "examples" takes precedence over the single "example".
            if openapi_examples:
                parameter["examples"] = jsonable_encoder(openapi_examples)
            elif example is not _Unset:
                parameter["example"] = jsonable_encoder(example)
            if getattr(field_info, "deprecated", None):
                parameter["deprecated"] = True
            parameters.append(parameter)
    return parameters


def get_openapi_operation_request_body(
    *,
    body_field: ModelField | None,
    model_name_map: ModelNameMap,
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
    ],
    separate_input_output_schemas: bool = True,
) -> dict[str, Any] | None:
    """Build the OpenAPI ``requestBody`` object for a body field.

    Args:
        body_field: Field describing the request body, or ``None``.
        model_name_map: Forwarded to ``get_schema_from_model_field``.
        field_mapping: Forwarded to ``get_schema_from_model_field``.
        separate_input_output_schemas: Forwarded to
            ``get_schema_from_model_field``.

    Returns:
        ``None`` when ``body_field`` is falsy. Otherwise a dict with a
        ``content`` entry keyed by the field's media type, and
        ``"required": True`` only when the field is required.
    """
    if not body_field:
        return None
    assert isinstance(body_field, ModelField)
    body_schema = get_schema_from_model_field(
        field=body_field,
        model_name_map=model_name_map,
        field_mapping=field_mapping,
        separate_input_output_schemas=separate_input_output_schemas,
    )
    field_info = cast(Body, body_field.field_info)
    request_media_type = field_info.media_type
    required = body_field.field_info.is_required()
    request_body_oai: dict[str, Any] = {}
    if required:
        request_body_oai["required"] = required
    request_media_content: dict[str, Any] = {"schema": body_schema}
    # "examples" takes precedence over the single "example".
    if field_info.openapi_examples:
        request_media_content["examples"] = jsonable_encoder(
            field_info.openapi_examples
        )
    elif field_info.example is not _Unset:
        request_media_content["example"] = jsonable_encoder(field_info.example)
    request_body_oai["content"] = {request_media_type: request_media_content}
    return request_body_oai


def generate_operation_id(
    *, route: routing.APIRoute, method: str
) -> str:  # pragma: nocover
    """Return an operation ID for a route and method (deprecated).

    Always emits a ``FastAPIDeprecationWarning``. Returns
    ``route.operation_id`` when it is set, otherwise the result of
    ``generate_operation_id_for_path`` for the route name, path format and
    method.
    """
    warnings.warn(
        message="fastapi.openapi.utils.generate_operation_id() was deprecated, "
        "it is not used internally, and will be removed soon",
        category=FastAPIDeprecationWarning,
        stacklevel=2,
    )
    if route.operation_id:
        return route.operation_id
    path: str = route.path_format
    return generate_operation_id_for_path(name=route.name, path=path, method=method)


def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
    """Return the summary for an operation.

    Uses ``route.summary`` when it is set; otherwise derives a title from the
    route name by replacing underscores with spaces and title-casing it.
    ``method`` is accepted but not used.
    """
    if route.summary:
        return route.summary
    return route.name.replace("_", " ").title()


def get_openapi_operation_metadata(
    *, route: routing.APIRoute, method: str, operation_ids: set[str]
) -> dict[str, Any]:
    """Build the metadata part of an OpenAPI operation object.

    Args:
        route: Route supplying tags, summary, description, operation ID and
            the deprecated flag.
        method: Passed through to ``generate_operation_summary``.
        operation_ids: Operation IDs seen so far. The ID of this operation is
            added to the set (the argument is mutated).

    Returns:
        A dict with ``summary`` and ``operationId``, plus ``tags``,
        ``description`` and ``deprecated`` when the route sets them.

    Warns:
        Emits a warning when the operation ID is already in ``operation_ids``.
    """
    operation: dict[str, Any] = {}
    if route.tags:
        operation["tags"] = route.tags
    operation["summary"] = generate_operation_summary(route=route, method=method)
    if route.description:
        operation["description"] = route.description
    operation_id = route.operation_id or route.unique_id
    if operation_id in operation_ids:
        endpoint_name = getattr(route.endpoint, "__name__", "<unnamed_endpoint>")
        message = f"Duplicate Operation ID {operation_id} for function {endpoint_name}"
        file_name = getattr(route.endpoint, "__globals__", {}).get("__file__")
        if file_name:
            message += f" at {file_name}"
        warnings.warn(message, stacklevel=1)
    operation_ids.add(operation_id)
    operation["operationId"] = operation_id
    if route.deprecated:
        operation["deprecated"] = route.deprecated
    return operation


def get_openapi_path(
    *,
    route: routing.APIRoute,
    operation_ids: set[str],
    model_name_map: ModelNameMap,
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
    ],
    separate_input_output_schemas: bool = True,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
    """Build the OpenAPI path item for a route.

    One operation is generated per method in ``route.methods``. Each operation
    combines metadata, security requirements, parameters, the request body,
    callbacks, the default response, additional responses and, where the
    route takes parameters or a body, an automatic 422 response. Finally
    ``route.openapi_extra`` is deep-merged into the operation.

    Args:
        route: Route to document. Nothing is generated when its
            ``include_in_schema`` is falsy.
        operation_ids: Operation IDs seen so far; mutated by
            ``get_openapi_operation_metadata``.
        model_name_map: Forwarded to ``get_schema_from_model_field``.
        field_mapping: Forwarded to ``get_schema_from_model_field``.
        separate_input_output_schemas: Forwarded to
            ``get_schema_from_model_field``.

    Returns:
        A ``(path, security_schemes, definitions)`` tuple: operations keyed by
        lower-cased method, security scheme definitions keyed by scheme name,
        and the validation error schemas added for the automatic 422 response.

    Raises:
        AssertionError: If ``route.methods`` is ``None``, no response class is
            available, or an additional response is not a dict.
    """
    path = {}
    security_schemes: dict[str, Any] = {}
    definitions: dict[str, Any] = {}
    assert route.methods is not None, "Methods must be a list"
    if isinstance(route.response_class, DefaultPlaceholder):
        current_response_class: type[Response] = route.response_class.value
    else:
        current_response_class = route.response_class
    assert current_response_class, "A response class is needed to generate OpenAPI"
    route_response_media_type: str | None = current_response_class.media_type
    if route.include_in_schema:
        for method in route.methods:
            operation = get_openapi_operation_metadata(
                route=route, method=method, operation_ids=operation_ids
            )
            parameters: list[dict[str, Any]] = []
            flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
            security_definitions, operation_security = get_openapi_security_definitions(
                flat_dependant=flat_dependant
            )
            if operation_security:
                operation.setdefault("security", []).extend(operation_security)
            if security_definitions:
                security_schemes.update(security_definitions)
            operation_parameters = _get_openapi_operation_parameters(
                dependant=route.dependant,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            parameters.extend(operation_parameters)
            if parameters:
                # Keying by (location, name) collapses duplicate declarations
                # of the same parameter into a single entry.
                all_parameters = {
                    (param["in"], param["name"]): param for param in parameters
                }
                required_parameters = {
                    (param["in"], param["name"]): param
                    for param in parameters
                    if param.get("required")
                }
                # Make sure required definitions of the same parameter take precedence
                # over non-required definitions
                all_parameters.update(required_parameters)
                operation["parameters"] = list(all_parameters.values())
            if method in METHODS_WITH_BODY:
                request_body_oai = get_openapi_operation_request_body(
                    body_field=route.body_field,
                    model_name_map=model_name_map,
                    field_mapping=field_mapping,
                    separate_input_output_schemas=separate_input_output_schemas,
                )
                if request_body_oai:
                    operation["requestBody"] = request_body_oai
            if route.callbacks:
                callbacks = {}
                for callback in route.callbacks:
                    if isinstance(callback, routing.APIRoute):
                        # NOTE(review): only cb_path is used; the callback's
                        # security schemes and definitions are discarded here.
                        # Confirm that this is intended.
                        (
                            cb_path,
                            cb_security_schemes,
                            cb_definitions,
                        ) = get_openapi_path(
                            route=callback,
                            operation_ids=operation_ids,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        callbacks[callback.name] = {callback.path: cb_path}
                operation["callbacks"] = callbacks
            if route.status_code is not None:
                status_code = str(route.status_code)
            else:
                # It would probably make more sense for all response classes to have an
                # explicit default status_code, and to extract it from them, instead of
                # doing this inspection tricks, that would probably be in the future
                # TODO: probably make status_code a default class attribute for all
                # responses in Starlette
                response_signature = inspect.signature(current_response_class.__init__)
                status_code_param = response_signature.parameters.get("status_code")
                # NOTE(review): status_code is only assigned when the response
                # class __init__ has a "status_code" parameter with an int
                # default; otherwise the name is unbound when used below.
                if status_code_param is not None:
                    if isinstance(status_code_param.default, int):
                        status_code = str(status_code_param.default)
            operation.setdefault("responses", {}).setdefault(status_code, {})[
                "description"
            ] = route.response_description
            if is_body_allowed_for_status_code(route.status_code):
                # Check for JSONL streaming (generator endpoints)
                if route.is_json_stream:
                    jsonl_content: dict[str, Any] = {}
                    if route.stream_item_field:
                        item_schema = get_schema_from_model_field(
                            field=route.stream_item_field,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        jsonl_content["itemSchema"] = item_schema
                    else:
                        jsonl_content["itemSchema"] = {}
                    operation.setdefault("responses", {}).setdefault(
                        status_code, {}
                    ).setdefault("content", {})["application/jsonl"] = jsonl_content
                elif route.is_sse_stream:
                    sse_content: dict[str, Any] = {}
                    # Deep copy so the shared event schema is not mutated.
                    item_schema = copy.deepcopy(_SSE_EVENT_SCHEMA)
                    if route.stream_item_field:
                        content_schema = get_schema_from_model_field(
                            field=route.stream_item_field,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        item_schema["required"] = ["data"]
                        item_schema["properties"]["data"] = {
                            "type": "string",
                            "contentMediaType": "application/json",
                            "contentSchema": content_schema,
                        }
                    sse_content["itemSchema"] = item_schema
                    operation.setdefault("responses", {}).setdefault(
                        status_code, {}
                    ).setdefault("content", {})["text/event-stream"] = sse_content
                elif route_response_media_type:
                    response_schema = {"type": "string"}
                    if lenient_issubclass(current_response_class, JSONResponse):
                        if route.response_field:
                            response_schema = get_schema_from_model_field(
                                field=route.response_field,
                                model_name_map=model_name_map,
                                field_mapping=field_mapping,
                                separate_input_output_schemas=separate_input_output_schemas,
                            )
                        else:
                            response_schema = {}
                    operation.setdefault("responses", {}).setdefault(
                        status_code, {}
                    ).setdefault("content", {}).setdefault(
                        route_response_media_type, {}
                    )["schema"] = response_schema
            if route.responses:
                operation_responses = operation.setdefault("responses", {})
                for (
                    additional_status_code,
                    additional_response,
                ) in route.responses.items():
                    # Work on a copy so route.responses itself is not mutated.
                    process_response = copy.deepcopy(additional_response)
                    process_response.pop("model", None)
                    status_code_key = str(additional_status_code).upper()
                    if status_code_key == "DEFAULT":
                        status_code_key = "default"
                    openapi_response = operation_responses.setdefault(
                        status_code_key, {}
                    )
                    assert isinstance(process_response, dict), (
                        "An additional response must be a dict"
                    )
                    field = route.response_fields.get(additional_status_code)
                    additional_field_schema: dict[str, Any] | None = None
                    if field:
                        additional_field_schema = get_schema_from_model_field(
                            field=field,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        media_type = route_response_media_type or "application/json"
                        additional_schema = (
                            process_response.setdefault("content", {})
                            .setdefault(media_type, {})
                            .setdefault("schema", {})
                        )
                        deep_dict_update(additional_schema, additional_field_schema)
                    # NOTE(review): int() is only reached when the key is not
                    # in status_code_ranges; a non-numeric key outside that
                    # table would raise ValueError here.
                    status_text: str | None = status_code_ranges.get(
                        str(additional_status_code).upper()
                    ) or http.client.responses.get(int(additional_status_code))
                    description = (
                        process_response.get("description")
                        or openapi_response.get("description")
                        or status_text
                        or "Additional Response"
                    )
                    deep_dict_update(openapi_response, process_response)
                    openapi_response["description"] = description
            http422 = "422"
            all_route_params = get_flat_params(route.dependant)
            # Add the automatic validation error response unless the operation
            # already documents "422", "4XX" or "default".
            if (all_route_params or route.body_field) and not any(
                status in operation["responses"]
                for status in [http422, "4XX", "default"]
            ):
                operation["responses"][http422] = {
                    "description": "Validation Error",
                    "content": {
                        "application/json": {
                            "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                        }
                    },
                }
                if "ValidationError" not in definitions:
                    definitions.update(
                        {
                            "ValidationError": validation_error_definition,
                            "HTTPValidationError": validation_error_response_definition,
                        }
                    )
            if route.openapi_extra:
                deep_dict_update(operation, route.openapi_extra)
            path[method.lower()] = operation
    return path, security_schemes, definitions


def get_fields_from_routes(
    routes: Sequence[BaseRoute],
) -> list[ModelField]:
    """Collect the model fields used by the routes included in the schema.

    Routes that are not ``routing.APIRoute`` instances, or whose
    ``include_in_schema`` is falsy, contribute nothing. Callback routes are
    processed recursively.

    Args:
        routes: Routes to inspect.

    Returns:
        A list ordered as: fields from callbacks, body fields, response fields
        (default response, additional responses and stream items), then
        request parameter fields.
    """
    body_fields_from_routes: list[ModelField] = []
    responses_from_routes: list[ModelField] = []
    request_fields_from_routes: list[ModelField] = []
    callback_flat_models: list[ModelField] = []
    for route in routes:
        if not isinstance(route, routing.APIRoute):
            continue
        if route.include_in_schema:
            if route.body_field:
                assert isinstance(route.body_field, ModelField), (
                    "A request body must be a Pydantic Field"
                )
                body_fields_from_routes.append(route.body_field)
            if route.response_field:
                responses_from_routes.append(route.response_field)
            if route.response_fields:
                responses_from_routes.extend(route.response_fields.values())
            if route.stream_item_field:
                responses_from_routes.append(route.stream_item_field)
            if route.callbacks:
                callback_flat_models.extend(get_fields_from_routes(route.callbacks))
            params = get_flat_params(route.dependant)
            request_fields_from_routes.extend(params)

    flat_models = callback_flat_models + list(
        body_fields_from_routes + responses_from_routes + request_fields_from_routes
    )
    return flat_models


def get_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.1.0",
    summary: str | None = None,
    description: str | None = None,
    routes: Sequence[BaseRoute],
    webhooks: Sequence[BaseRoute] | None = None,
    tags: list[dict[str, Any]] | None = None,
    servers: list[dict[str, str | Any]] | None = None,
    terms_of_service: str | None = None,
    contact: dict[str, str | Any] | None = None,
    license_info: dict[str, str | Any] | None = None,
    separate_input_output_schemas: bool = True,
    external_docs: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Generate the OpenAPI document for a set of routes.

    Args:
        title: Stored as ``info.title``.
        version: Stored as ``info.version``.
        openapi_version: Stored as the top-level ``openapi`` value.
        summary: Stored as ``info.summary`` when truthy.
        description: Stored as ``info.description`` when truthy.
        routes: Routes documented under ``paths``; only ``routing.APIRoute``
            instances are processed.
        webhooks: Routes documented under ``webhooks``; only
            ``routing.APIRoute`` instances are processed.
        tags: Stored as the top-level ``tags`` when truthy.
        servers: Stored as the top-level ``servers`` when truthy.
        terms_of_service: Stored as ``info.termsOfService`` when truthy.
        contact: Stored as ``info.contact`` when truthy.
        license_info: Stored as ``info.license`` when truthy.
        separate_input_output_schemas: Forwarded to ``get_definitions`` and
            ``get_openapi_path``.
        external_docs: Stored as the top-level ``externalDocs`` when truthy.

    Returns:
        The document as a dict: the assembled data is passed through the
        ``OpenAPI`` model and encoded with ``jsonable_encoder`` using aliases
        and excluding ``None`` values.
    """
    info: dict[str, Any] = {"title": title, "version": version}
    if summary:
        info["summary"] = summary
    if description:
        info["description"] = description
    if terms_of_service:
        info["termsOfService"] = terms_of_service
    if contact:
        info["contact"] = contact
    if license_info:
        info["license"] = license_info
    output: dict[str, Any] = {"openapi": openapi_version, "info": info}
    if servers:
        output["servers"] = servers
    components: dict[str, dict[str, Any]] = {}
    paths: dict[str, dict[str, Any]] = {}
    webhook_paths: dict[str, dict[str, Any]] = {}
    # Shared across routes and webhooks so duplicate operation IDs are
    # detected over the whole document.
    operation_ids: set[str] = set()
    all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
    flat_models = get_flat_models_from_fields(all_fields, known_models=set())
    model_name_map = get_model_name_map(flat_models)
    field_mapping, definitions = get_definitions(
        fields=all_fields,
        model_name_map=model_name_map,
        separate_input_output_schemas=separate_input_output_schemas,
    )
    for route in routes or []:
        if isinstance(route, routing.APIRoute):
            result = get_openapi_path(
                route=route,
                operation_ids=operation_ids,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            if result:
                path, security_schemes, path_definitions = result
                if path:
                    # Routes sharing a path format are merged into one path item.
                    paths.setdefault(route.path_format, {}).update(path)
                if security_schemes:
                    components.setdefault("securitySchemes", {}).update(
                        security_schemes
                    )
                if path_definitions:
                    definitions.update(path_definitions)
    for webhook in webhooks or []:
        if isinstance(webhook, routing.APIRoute):
            result = get_openapi_path(
                route=webhook,
                operation_ids=operation_ids,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            if result:
                path, security_schemes, path_definitions = result
                if path:
                    webhook_paths.setdefault(webhook.path_format, {}).update(path)
                if security_schemes:
                    components.setdefault("securitySchemes", {}).update(
                        security_schemes
                    )
                if path_definitions:
                    definitions.update(path_definitions)
    if definitions:
        # Sort schema names so the output order is stable.
        components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
    if components:
        output["components"] = components
    output["paths"] = paths
    if webhook_paths:
        output["webhooks"] = webhook_paths
    if tags:
        output["tags"] = tags
    if external_docs:
        output["externalDocs"] = external_docs
    return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)  # type: ignore[no-any-return]
Same data, no extra tab — call code_get_file + code_get_findings over MCP from Claude/Cursor/Copilot.