apply.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. # ext/mypy/apply.py
  2. # Copyright (C) 2021 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. from typing import List
  8. from typing import Optional
  9. from typing import Union
  10. from mypy.nodes import ARG_NAMED_OPT
  11. from mypy.nodes import Argument
  12. from mypy.nodes import AssignmentStmt
  13. from mypy.nodes import CallExpr
  14. from mypy.nodes import ClassDef
  15. from mypy.nodes import MDEF
  16. from mypy.nodes import MemberExpr
  17. from mypy.nodes import NameExpr
  18. from mypy.nodes import RefExpr
  19. from mypy.nodes import StrExpr
  20. from mypy.nodes import SymbolTableNode
  21. from mypy.nodes import TempNode
  22. from mypy.nodes import TypeInfo
  23. from mypy.nodes import Var
  24. from mypy.plugin import SemanticAnalyzerPluginInterface
  25. from mypy.plugins.common import add_method_to_class
  26. from mypy.types import AnyType
  27. from mypy.types import get_proper_type
  28. from mypy.types import Instance
  29. from mypy.types import NoneTyp
  30. from mypy.types import ProperType
  31. from mypy.types import TypeOfAny
  32. from mypy.types import UnboundType
  33. from mypy.types import UnionType
  34. from . import infer
  35. from . import util
  36. from .names import NAMED_TYPE_SQLA_MAPPED
  37. def apply_mypy_mapped_attr(
  38. cls: ClassDef,
  39. api: SemanticAnalyzerPluginInterface,
  40. item: Union[NameExpr, StrExpr],
  41. attributes: List[util.SQLAlchemyAttribute],
  42. ) -> None:
  43. if isinstance(item, NameExpr):
  44. name = item.name
  45. elif isinstance(item, StrExpr):
  46. name = item.value
  47. else:
  48. return None
  49. for stmt in cls.defs.body:
  50. if (
  51. isinstance(stmt, AssignmentStmt)
  52. and isinstance(stmt.lvalues[0], NameExpr)
  53. and stmt.lvalues[0].name == name
  54. ):
  55. break
  56. else:
  57. util.fail(api, "Can't find mapped attribute {}".format(name), cls)
  58. return None
  59. if stmt.type is None:
  60. util.fail(
  61. api,
  62. "Statement linked from _mypy_mapped_attrs has no "
  63. "typing information",
  64. stmt,
  65. )
  66. return None
  67. left_hand_explicit_type = get_proper_type(stmt.type)
  68. assert isinstance(
  69. left_hand_explicit_type, (Instance, UnionType, UnboundType)
  70. )
  71. attributes.append(
  72. util.SQLAlchemyAttribute(
  73. name=name,
  74. line=item.line,
  75. column=item.column,
  76. typ=left_hand_explicit_type,
  77. info=cls.info,
  78. )
  79. )
  80. apply_type_to_mapped_statement(
  81. api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
  82. )
  83. def re_apply_declarative_assignments(
  84. cls: ClassDef,
  85. api: SemanticAnalyzerPluginInterface,
  86. attributes: List[util.SQLAlchemyAttribute],
  87. ) -> None:
  88. """For multiple class passes, re-apply our left-hand side types as mypy
  89. seems to reset them in place.
  90. """
  91. mapped_attr_lookup = {attr.name: attr for attr in attributes}
  92. update_cls_metadata = False
  93. for stmt in cls.defs.body:
  94. # for a re-apply, all of our statements are AssignmentStmt;
  95. # @declared_attr calls will have been converted and this
  96. # currently seems to be preserved by mypy (but who knows if this
  97. # will change).
  98. if (
  99. isinstance(stmt, AssignmentStmt)
  100. and isinstance(stmt.lvalues[0], NameExpr)
  101. and stmt.lvalues[0].name in mapped_attr_lookup
  102. and isinstance(stmt.lvalues[0].node, Var)
  103. ):
  104. left_node = stmt.lvalues[0].node
  105. python_type_for_type = mapped_attr_lookup[
  106. stmt.lvalues[0].name
  107. ].type
  108. left_node_proper_type = get_proper_type(left_node.type)
  109. # if we have scanned an UnboundType and now there's a more
  110. # specific type than UnboundType, call the re-scan so we
  111. # can get that set up correctly
  112. if (
  113. isinstance(python_type_for_type, UnboundType)
  114. and not isinstance(left_node_proper_type, UnboundType)
  115. and (
  116. isinstance(stmt.rvalue, CallExpr)
  117. and isinstance(stmt.rvalue.callee, MemberExpr)
  118. and isinstance(stmt.rvalue.callee.expr, NameExpr)
  119. and stmt.rvalue.callee.expr.node is not None
  120. and stmt.rvalue.callee.expr.node.fullname
  121. == NAMED_TYPE_SQLA_MAPPED
  122. and stmt.rvalue.callee.name == "_empty_constructor"
  123. and isinstance(stmt.rvalue.args[0], CallExpr)
  124. and isinstance(stmt.rvalue.args[0].callee, RefExpr)
  125. )
  126. ):
  127. python_type_for_type = (
  128. infer.infer_type_from_right_hand_nameexpr(
  129. api,
  130. stmt,
  131. left_node,
  132. left_node_proper_type,
  133. stmt.rvalue.args[0].callee,
  134. )
  135. )
  136. if python_type_for_type is None or isinstance(
  137. python_type_for_type, UnboundType
  138. ):
  139. continue
  140. # update the SQLAlchemyAttribute with the better information
  141. mapped_attr_lookup[
  142. stmt.lvalues[0].name
  143. ].type = python_type_for_type
  144. update_cls_metadata = True
  145. if python_type_for_type is not None:
  146. left_node.type = api.named_type(
  147. NAMED_TYPE_SQLA_MAPPED, [python_type_for_type]
  148. )
  149. if update_cls_metadata:
  150. util.set_mapped_attributes(cls.info, attributes)
  151. def apply_type_to_mapped_statement(
  152. api: SemanticAnalyzerPluginInterface,
  153. stmt: AssignmentStmt,
  154. lvalue: NameExpr,
  155. left_hand_explicit_type: Optional[ProperType],
  156. python_type_for_type: Optional[ProperType],
  157. ) -> None:
  158. """Apply the Mapped[<type>] annotation and right hand object to a
  159. declarative assignment statement.
  160. This converts a Python declarative class statement such as::
  161. class User(Base):
  162. # ...
  163. attrname = Column(Integer)
  164. To one that describes the final Python behavior to Mypy::
  165. class User(Base):
  166. # ...
  167. attrname : Mapped[Optional[int]] = <meaningless temp node>
  168. """
  169. left_node = lvalue.node
  170. assert isinstance(left_node, Var)
  171. if left_hand_explicit_type is not None:
  172. left_node.type = api.named_type(
  173. NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
  174. )
  175. else:
  176. lvalue.is_inferred_def = False
  177. left_node.type = api.named_type(
  178. NAMED_TYPE_SQLA_MAPPED,
  179. [] if python_type_for_type is None else [python_type_for_type],
  180. )
  181. # so to have it skip the right side totally, we can do this:
  182. # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
  183. # however, if we instead manufacture a new node that uses the old
  184. # one, then we can still get type checking for the call itself,
  185. # e.g. the Column, relationship() call, etc.
  186. # rewrite the node as:
  187. # <attr> : Mapped[<typ>] =
  188. # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
  189. # the original right-hand side is maintained so it gets type checked
  190. # internally
  191. stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue)
  192. def add_additional_orm_attributes(
  193. cls: ClassDef,
  194. api: SemanticAnalyzerPluginInterface,
  195. attributes: List[util.SQLAlchemyAttribute],
  196. ) -> None:
  197. """Apply __init__, __table__ and other attributes to the mapped class."""
  198. info = util.info_for_cls(cls, api)
  199. if info is None:
  200. return
  201. is_base = util.get_is_base(info)
  202. if "__init__" not in info.names and not is_base:
  203. mapped_attr_names = {attr.name: attr.type for attr in attributes}
  204. for base in info.mro[1:-1]:
  205. if "sqlalchemy" not in info.metadata:
  206. continue
  207. base_cls_attributes = util.get_mapped_attributes(base, api)
  208. if base_cls_attributes is None:
  209. continue
  210. for attr in base_cls_attributes:
  211. mapped_attr_names.setdefault(attr.name, attr.type)
  212. arguments = []
  213. for name, typ in mapped_attr_names.items():
  214. if typ is None:
  215. typ = AnyType(TypeOfAny.special_form)
  216. arguments.append(
  217. Argument(
  218. variable=Var(name, typ),
  219. type_annotation=typ,
  220. initializer=TempNode(typ),
  221. kind=ARG_NAMED_OPT,
  222. )
  223. )
  224. add_method_to_class(api, cls, "__init__", arguments, NoneTyp())
  225. if "__table__" not in info.names and util.get_has_table(info):
  226. _apply_placeholder_attr_to_class(
  227. api, cls, "sqlalchemy.sql.schema.Table", "__table__"
  228. )
  229. if not is_base:
  230. _apply_placeholder_attr_to_class(
  231. api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
  232. )
  233. def _apply_placeholder_attr_to_class(
  234. api: SemanticAnalyzerPluginInterface,
  235. cls: ClassDef,
  236. qualified_name: str,
  237. attrname: str,
  238. ) -> None:
  239. sym = api.lookup_fully_qualified_or_none(qualified_name)
  240. if sym:
  241. assert isinstance(sym.node, TypeInfo)
  242. type_: ProperType = Instance(sym.node, [])
  243. else:
  244. type_ = AnyType(TypeOfAny.special_form)
  245. var = Var(attrname)
  246. var._fullname = cls.fullname + "." + attrname
  247. var.info = cls.info
  248. var.type = type_
  249. cls.info.names[attrname] = SymbolTableNode(MDEF, var)