PageRenderTime 27ms CodeModel.GetById 13ms RepoModel.GetById 0ms app.codeStats 0ms

/lib/sqlalchemy/ext/mypy/util.py

https://bitbucket.org/zzzeek/sqlalchemy
Python | 299 lines | 225 code | 62 blank | 12 comment | 32 complexity | ef14f57f4c1e34f4c3b02fac8c7a9bb2 MD5 | raw file
  1. from typing import Any
  2. from typing import Iterable
  3. from typing import Iterator
  4. from typing import List
  5. from typing import Optional
  6. from typing import overload
  7. from typing import Tuple
  8. from typing import Type as TypingType
  9. from typing import TypeVar
  10. from typing import Union
  11. from mypy.nodes import ARG_POS
  12. from mypy.nodes import CallExpr
  13. from mypy.nodes import ClassDef
  14. from mypy.nodes import CLASSDEF_NO_INFO
  15. from mypy.nodes import Context
  16. from mypy.nodes import Expression
  17. from mypy.nodes import IfStmt
  18. from mypy.nodes import JsonDict
  19. from mypy.nodes import MemberExpr
  20. from mypy.nodes import NameExpr
  21. from mypy.nodes import Statement
  22. from mypy.nodes import SymbolTableNode
  23. from mypy.nodes import TypeInfo
  24. from mypy.plugin import ClassDefContext
  25. from mypy.plugin import DynamicClassDefContext
  26. from mypy.plugin import SemanticAnalyzerPluginInterface
  27. from mypy.plugins.common import deserialize_and_fixup_type
  28. from mypy.typeops import map_type_from_supertype
  29. from mypy.types import Instance
  30. from mypy.types import NoneType
  31. from mypy.types import Type
  32. from mypy.types import TypeVarType
  33. from mypy.types import UnboundType
  34. from mypy.types import UnionType
  # Type variable used by the typed overload of ``get_callexpr_kwarg``; its
  # bound restricts it to CallExpr / NameExpr (or subclasses of either).
  _TArgType = TypeVar(
      "_TArgType",
      bound=Union[CallExpr, NameExpr],
  )
  class SQLAlchemyAttribute:
      """Record of one attribute tracked by the plugin.

      Holds the attribute's ``name``, its ``line`` and ``column``, its
      (optional) mypy ``type`` and the ``TypeInfo`` it is associated with.
      Instances can be turned into a JSON-compatible dict with
      :meth:`serialize` and rebuilt with :meth:`deserialize`.
      """

      def __init__(
          self,
          name: str,
          line: int,
          column: int,
          typ: Optional[Type],
          info: TypeInfo,
      ) -> None:
          """Store the given values; ``typ`` is exposed as ``self.type``."""
          self.info = info
          self.type = typ
          self.name = name
          self.line = line
          self.column = column

      def serialize(self) -> JsonDict:
          """Return a dict with the keys ``name``, ``line``, ``column`` and
          ``type``.

          ``info`` is not part of the result; it has to be passed again to
          :meth:`deserialize`.  The attribute must have a type set.
          """
          assert self.type
          payload: JsonDict = dict(
              name=self.name,
              line=self.line,
              column=self.column,
          )
          payload["type"] = self.type.serialize()
          return payload

      def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
          """Resolve a type-variable type against ``sub_type``.

          When this attribute's type is a ``TypeVarType`` (the attribute was
          inherited from a generic super type), replace it with the result
          of mapping it from ``self.info`` onto ``sub_type``.  Any other
          type is left untouched.
          """
          if isinstance(self.type, TypeVarType):
              self.type = map_type_from_supertype(
                  self.type, sub_type, self.info
              )

      @classmethod
      def deserialize(
          cls,
          info: TypeInfo,
          data: JsonDict,
          api: SemanticAnalyzerPluginInterface,
      ) -> "SQLAlchemyAttribute":
          """Rebuild an attribute from a dict produced by :meth:`serialize`.

          ``data`` itself is not modified: a copy is taken before the
          ``"type"`` entry is removed and converted back into a mypy type.
          The remaining entries are passed to the constructor as keyword
          arguments.
          """
          fields = data.copy()
          serialized_type = fields.pop("type")
          restored_type = deserialize_and_fixup_type(serialized_type, api)
          return cls(typ=restored_type, info=info, **fields)
  def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None:
      """Store ``data`` under ``key`` in the ``"sqlalchemy"`` section of
      ``info.metadata``, creating that section if it does not exist yet."""
      plugin_section = info.metadata.setdefault("sqlalchemy", {})
      plugin_section[key] = data
  def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]:
      """Return the value stored under ``key`` in the ``"sqlalchemy"``
      section of ``info.metadata``, or ``None`` when the section or the key
      is missing.  ``info.metadata`` is not modified."""
      plugin_section = info.metadata.get("sqlalchemy", {})
      return plugin_section.get(key)
  def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]:
      """Look ``key`` up along ``info.mro`` and return the first value that
      is not ``None``.

      Returns ``None`` when ``info.mro`` is empty/unset or no entry of the
      MRO has a non-``None`` value for ``key``.
      """
      # Lazy generator: the lookup stops at the first entry that has a value.
      candidates = (
          _get_info_metadata(base, key) for base in (info.mro or ())
      )
      return next(
          (value for value in candidates if value is not None),
          None,
      )
  def establish_as_sqlalchemy(info: TypeInfo) -> None:
      """Make sure ``info.metadata`` has a ``"sqlalchemy"`` section.

      An existing section is kept as it is; otherwise an empty dict is
      stored.
      """
      if "sqlalchemy" not in info.metadata:
          info.metadata["sqlalchemy"] = {}
  def set_is_base(info: TypeInfo) -> None:
      """Record ``is_base = True`` in the plugin metadata of ``info``."""
      _set_info_metadata(info, key="is_base", data=True)
  def get_is_base(info: TypeInfo) -> bool:
      """Return ``True`` only if ``info`` itself has ``is_base`` stored as
      ``True`` in its plugin metadata (the MRO is not consulted)."""
      return _get_info_metadata(info, "is_base") is True
  def has_declarative_base(info: TypeInfo) -> bool:
      """Return ``True`` if the first non-``None`` ``is_base`` value found
      along ``info.mro`` is ``True``."""
      return _get_info_mro_metadata(info, "is_base") is True
  def set_has_table(info: TypeInfo) -> None:
      """Record ``has_table = True`` in the plugin metadata of ``info``."""
      _set_info_metadata(info, key="has_table", data=True)
  def get_has_table(info: TypeInfo) -> bool:
      """Return ``True`` only if ``info`` itself has ``has_table`` stored as
      ``True`` in its plugin metadata (the MRO is not consulted)."""
      return _get_info_metadata(info, "has_table") is True
  def get_mapped_attributes(
      info: TypeInfo, api: SemanticAnalyzerPluginInterface
  ) -> Optional[List[SQLAlchemyAttribute]]:
      """Load the attributes stored under ``"mapped_attributes"`` for
      ``info``.

      Returns ``None`` when nothing has been stored for ``info``.  Otherwise
      each stored dict is deserialized into a ``SQLAlchemyAttribute`` and
      any type-variable type is expanded against ``info`` before the
      attribute is added to the returned list.
      """
      stored: Optional[List[JsonDict]] = _get_info_metadata(
          info, "mapped_attributes"
      )
      if stored is None:
          return None

      def _restore(entry: JsonDict) -> SQLAlchemyAttribute:
          # Deserialize, then expand, one entry at a time so the per-item
          # order of operations stays the same.
          restored = SQLAlchemyAttribute.deserialize(info, entry, api)
          restored.expand_typevar_from_subtype(info)
          return restored

      return [_restore(entry) for entry in stored]
  def set_mapped_attributes(
      info: TypeInfo, attributes: List[SQLAlchemyAttribute]
  ) -> None:
      """Serialize ``attributes`` and store the resulting list of dicts
      under ``"mapped_attributes"`` in the plugin metadata of ``info``."""
      serialized = [entry.serialize() for entry in attributes]
      _set_info_metadata(info, "mapped_attributes", serialized)
  def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
      """Report ``msg`` through ``api.fail`` for ``ctx``, prefixed with
      ``[SQLAlchemy Mypy plugin]``."""
      return api.fail("[SQLAlchemy Mypy plugin] %s" % msg, ctx)
  def add_global(
      ctx: Union[ClassDefContext, DynamicClassDefContext],
      module: str,
      symbol_name: str,
      asname: str,
  ) -> None:
      """Expose ``module.symbol_name`` as ``asname`` in the current module.

      The symbol-table node for ``symbol_name`` is looked up in ``module``
      and stored under ``asname`` in the names of the module identified by
      ``ctx.api.cur_mod_id``.  Nothing happens if ``asname`` is already
      present there.
      """
      modules = ctx.api.modules
      current_names = modules[ctx.api.cur_mod_id].names
      if asname in current_names:
          return

      source_sym: SymbolTableNode = modules[module].names[symbol_name]
      current_names[asname] = source_sym
  @overload
  def get_callexpr_kwarg(
      callexpr: CallExpr, name: str, *, expr_types: None = ...
  ) -> Optional[Union[CallExpr, NameExpr]]:
      ...


  @overload
  def get_callexpr_kwarg(
      callexpr: CallExpr,
      name: str,
      *,
      expr_types: Tuple[TypingType[_TArgType], ...]
  ) -> Optional[_TArgType]:
      ...


  def get_callexpr_kwarg(
      callexpr: CallExpr,
      name: str,
      *,
      expr_types: Optional[Tuple[TypingType[Any], ...]] = None
  ) -> Optional[Any]:
      """Return the argument passed to ``callexpr`` under the name ``name``.

      :param callexpr: the call expression to inspect.
      :param name: argument name searched for in ``callexpr.arg_names``.
      :param expr_types: node classes the argument must be an instance of;
       defaults to ``(NameExpr, CallExpr)`` when omitted.
      :return: the argument expression, or ``None`` if no argument has that
       name or the argument is not an instance of the accepted classes.
      """
      arg_names = callexpr.arg_names
      if name not in arg_names:
          return None

      candidate = callexpr.args[arg_names.index(name)]
      accepted = (NameExpr, CallExpr) if expr_types is None else expr_types
      return candidate if isinstance(candidate, accepted) else None
  def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
      """Yield ``stmts`` with ``if TYPE_CHECKING:`` blocks unwrapped.

      An ``IfStmt`` whose first condition is a ``NameExpr`` with the full
      name ``typing.TYPE_CHECKING`` is replaced by the statements of its
      first body.  Every other statement is yielded unchanged.
      """
      for stmt in stmts:
          is_type_checking_block = (
              isinstance(stmt, IfStmt)
              and isinstance(stmt.expr[0], NameExpr)
              and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
          )
          if not is_type_checking_block:
              yield stmt
              continue
          yield from stmt.body[0].body
  def unbound_to_instance(
      api: SemanticAnalyzerPluginInterface, typ: Type
  ) -> Type:
      """Take the UnboundType that we seem to get as the ret_type from a
      FuncDef and convert it into an Instance/TypeInfo kind of structure
      that seems to work as the left-hand type of an AssignmentStatement.

      Anything that is not an ``UnboundType`` is returned unchanged, as is
      an ``UnboundType`` whose name does not resolve to a ``TypeInfo``.
      """
      if not isinstance(typ, UnboundType):
          return typ

      # TODO: figure out a more robust way to check this.  The node is some
      # kind of _SpecialForm, there's a typing.Optional that's _SpecialForm,
      # but I can't figure out how to get them to match up
      if typ.name == "Optional":
          # convert from "Optional?" to the more familiar
          # UnionType[..., NoneType()]
          members = [unbound_to_instance(api, member) for member in typ.args]
          members.append(NoneType())
          return unbound_to_instance(api, UnionType(members))

      symbol = api.lookup_qualified(typ.name, typ)
      # isinstance() is already False for None, so no separate None test.
      resolves_to_class = isinstance(symbol, SymbolTableNode) and isinstance(
          symbol.node, TypeInfo
      )
      if not resolves_to_class:
          return typ

      # Non-UnboundType arguments come back unchanged from the recursive
      # call, so every argument can be passed through it.
      converted_args = [unbound_to_instance(api, arg) for arg in typ.args]
      return Instance(symbol.node, converted_args)
  def info_for_cls(
      cls: ClassDef, api: SemanticAnalyzerPluginInterface
  ) -> Optional[TypeInfo]:
      """Return the ``TypeInfo`` for the class definition ``cls``.

      ``cls.info`` is used when it is set.  When it is the
      ``CLASSDEF_NO_INFO`` placeholder, the class is looked up by name via
      ``api.lookup_qualified``; ``None`` is returned if that lookup finds
      nothing.
      """
      if cls.info is not CLASSDEF_NO_INFO:
          return cls.info

      sym = api.lookup_qualified(cls.name, cls)
      if sym is None:
          return None
      assert sym and isinstance(sym.node, TypeInfo)
      return sym.node
  def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
      """Wrap ``expr`` in a call to ``__sa_Mapped._empty_constructor``.

      The callee is a ``NameExpr`` named ``__sa_Mapped`` whose ``fullname``
      is set to ``sqlalchemy.orm.attributes.Mapped``.  ``expr`` is passed as
      the single ``ARG_POS`` argument, with the argument name ``arg1``.
      """
      mapped_name = NameExpr("__sa_Mapped")
      mapped_name.fullname = "sqlalchemy.orm.attributes.Mapped"
      callee = MemberExpr(mapped_name, "_empty_constructor")

      call_args = [expr]
      call_arg_kinds = [ARG_POS]
      call_arg_names = ["arg1"]
      return CallExpr(callee, call_args, call_arg_kinds, call_arg_names)