util.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. import re
  2. from typing import Any
  3. from typing import Iterable
  4. from typing import Iterator
  5. from typing import List
  6. from typing import Optional
  7. from typing import overload
  8. from typing import Tuple
  9. from typing import Type as TypingType
  10. from typing import TypeVar
  11. from typing import Union
  12. from mypy.nodes import ARG_POS
  13. from mypy.nodes import CallExpr
  14. from mypy.nodes import ClassDef
  15. from mypy.nodes import CLASSDEF_NO_INFO
  16. from mypy.nodes import Context
  17. from mypy.nodes import Expression
  18. from mypy.nodes import IfStmt
  19. from mypy.nodes import JsonDict
  20. from mypy.nodes import MemberExpr
  21. from mypy.nodes import NameExpr
  22. from mypy.nodes import Statement
  23. from mypy.nodes import SymbolTableNode
  24. from mypy.nodes import TypeInfo
  25. from mypy.plugin import ClassDefContext
  26. from mypy.plugin import DynamicClassDefContext
  27. from mypy.plugin import SemanticAnalyzerPluginInterface
  28. from mypy.plugins.common import deserialize_and_fixup_type
  29. from mypy.typeops import map_type_from_supertype
  30. from mypy.types import Instance
  31. from mypy.types import NoneType
  32. from mypy.types import Type
  33. from mypy.types import TypeVarType
  34. from mypy.types import UnboundType
  35. from mypy.types import UnionType
# Expression-node type that a caller of ``get_callexpr_kwarg`` selects through
# its ``expr_types`` argument; bounded to the two node kinds that function
# accepts when ``expr_types`` is not given.
_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
  37. class SQLAlchemyAttribute:
  38. def __init__(
  39. self,
  40. name: str,
  41. line: int,
  42. column: int,
  43. typ: Optional[Type],
  44. info: TypeInfo,
  45. ) -> None:
  46. self.name = name
  47. self.line = line
  48. self.column = column
  49. self.type = typ
  50. self.info = info
  51. def serialize(self) -> JsonDict:
  52. assert self.type
  53. return {
  54. "name": self.name,
  55. "line": self.line,
  56. "column": self.column,
  57. "type": self.type.serialize(),
  58. }
  59. def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
  60. """Expands type vars in the context of a subtype when an attribute is inherited
  61. from a generic super type."""
  62. if not isinstance(self.type, TypeVarType):
  63. return
  64. self.type = map_type_from_supertype(self.type, sub_type, self.info)
  65. @classmethod
  66. def deserialize(
  67. cls,
  68. info: TypeInfo,
  69. data: JsonDict,
  70. api: SemanticAnalyzerPluginInterface,
  71. ) -> "SQLAlchemyAttribute":
  72. data = data.copy()
  73. typ = deserialize_and_fixup_type(data.pop("type"), api)
  74. return cls(typ=typ, info=info, **data)
  75. def name_is_dunder(name):
  76. return bool(re.match(r"^__.+?__$", name))
  77. def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None:
  78. info.metadata.setdefault("sqlalchemy", {})[key] = data
  79. def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]:
  80. return info.metadata.get("sqlalchemy", {}).get(key, None)
  81. def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]:
  82. if info.mro:
  83. for base in info.mro:
  84. metadata = _get_info_metadata(base, key)
  85. if metadata is not None:
  86. return metadata
  87. return None
  88. def establish_as_sqlalchemy(info: TypeInfo) -> None:
  89. info.metadata.setdefault("sqlalchemy", {})
  90. def set_is_base(info: TypeInfo) -> None:
  91. _set_info_metadata(info, "is_base", True)
  92. def get_is_base(info: TypeInfo) -> bool:
  93. is_base = _get_info_metadata(info, "is_base")
  94. return is_base is True
  95. def has_declarative_base(info: TypeInfo) -> bool:
  96. is_base = _get_info_mro_metadata(info, "is_base")
  97. return is_base is True
  98. def set_has_table(info: TypeInfo) -> None:
  99. _set_info_metadata(info, "has_table", True)
  100. def get_has_table(info: TypeInfo) -> bool:
  101. is_base = _get_info_metadata(info, "has_table")
  102. return is_base is True
  103. def get_mapped_attributes(
  104. info: TypeInfo, api: SemanticAnalyzerPluginInterface
  105. ) -> Optional[List[SQLAlchemyAttribute]]:
  106. mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata(
  107. info, "mapped_attributes"
  108. )
  109. if mapped_attributes is None:
  110. return None
  111. attributes: List[SQLAlchemyAttribute] = []
  112. for data in mapped_attributes:
  113. attr = SQLAlchemyAttribute.deserialize(info, data, api)
  114. attr.expand_typevar_from_subtype(info)
  115. attributes.append(attr)
  116. return attributes
  117. def set_mapped_attributes(
  118. info: TypeInfo, attributes: List[SQLAlchemyAttribute]
  119. ) -> None:
  120. _set_info_metadata(
  121. info,
  122. "mapped_attributes",
  123. [attribute.serialize() for attribute in attributes],
  124. )
  125. def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
  126. msg = "[SQLAlchemy Mypy plugin] %s" % msg
  127. return api.fail(msg, ctx)
  128. def add_global(
  129. ctx: Union[ClassDefContext, DynamicClassDefContext],
  130. module: str,
  131. symbol_name: str,
  132. asname: str,
  133. ) -> None:
  134. module_globals = ctx.api.modules[ctx.api.cur_mod_id].names
  135. if asname not in module_globals:
  136. lookup_sym: SymbolTableNode = ctx.api.modules[module].names[
  137. symbol_name
  138. ]
  139. module_globals[asname] = lookup_sym
  140. @overload
  141. def get_callexpr_kwarg(
  142. callexpr: CallExpr, name: str, *, expr_types: None = ...
  143. ) -> Optional[Union[CallExpr, NameExpr]]:
  144. ...
  145. @overload
  146. def get_callexpr_kwarg(
  147. callexpr: CallExpr,
  148. name: str,
  149. *,
  150. expr_types: Tuple[TypingType[_TArgType], ...]
  151. ) -> Optional[_TArgType]:
  152. ...
  153. def get_callexpr_kwarg(
  154. callexpr: CallExpr,
  155. name: str,
  156. *,
  157. expr_types: Optional[Tuple[TypingType[Any], ...]] = None
  158. ) -> Optional[Any]:
  159. try:
  160. arg_idx = callexpr.arg_names.index(name)
  161. except ValueError:
  162. return None
  163. kwarg = callexpr.args[arg_idx]
  164. if isinstance(
  165. kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr)
  166. ):
  167. return kwarg
  168. return None
  169. def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
  170. for stmt in stmts:
  171. if (
  172. isinstance(stmt, IfStmt)
  173. and isinstance(stmt.expr[0], NameExpr)
  174. and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
  175. ):
  176. for substmt in stmt.body[0].body:
  177. yield substmt
  178. else:
  179. yield stmt
  180. def unbound_to_instance(
  181. api: SemanticAnalyzerPluginInterface, typ: Type
  182. ) -> Type:
  183. """Take the UnboundType that we seem to get as the ret_type from a FuncDef
  184. and convert it into an Instance/TypeInfo kind of structure that seems
  185. to work as the left-hand type of an AssignmentStatement.
  186. """
  187. if not isinstance(typ, UnboundType):
  188. return typ
  189. # TODO: figure out a more robust way to check this. The node is some
  190. # kind of _SpecialForm, there's a typing.Optional that's _SpecialForm,
  191. # but I cant figure out how to get them to match up
  192. if typ.name == "Optional":
  193. # convert from "Optional?" to the more familiar
  194. # UnionType[..., NoneType()]
  195. return unbound_to_instance(
  196. api,
  197. UnionType(
  198. [unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
  199. + [NoneType()]
  200. ),
  201. )
  202. node = api.lookup_qualified(typ.name, typ)
  203. if (
  204. node is not None
  205. and isinstance(node, SymbolTableNode)
  206. and isinstance(node.node, TypeInfo)
  207. ):
  208. bound_type = node.node
  209. return Instance(
  210. bound_type,
  211. [
  212. unbound_to_instance(api, arg)
  213. if isinstance(arg, UnboundType)
  214. else arg
  215. for arg in typ.args
  216. ],
  217. )
  218. else:
  219. return typ
  220. def info_for_cls(
  221. cls: ClassDef, api: SemanticAnalyzerPluginInterface
  222. ) -> Optional[TypeInfo]:
  223. if cls.info is CLASSDEF_NO_INFO:
  224. sym = api.lookup_qualified(cls.name, cls)
  225. if sym is None:
  226. return None
  227. assert sym and isinstance(sym.node, TypeInfo)
  228. return sym.node
  229. return cls.info
  230. def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
  231. column_descriptor = NameExpr("__sa_Mapped")
  232. column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
  233. member_expr = MemberExpr(column_descriptor, "_empty_constructor")
  234. return CallExpr(
  235. member_expr,
  236. [expr],
  237. [ARG_POS],
  238. ["arg1"],
  239. )