123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299 |
- # ext/mypy/apply.py
- # Copyright (C) 2021 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: https://www.opensource.org/licenses/mit-license.php
- from typing import List
- from typing import Optional
- from typing import Union
- from mypy.nodes import ARG_NAMED_OPT
- from mypy.nodes import Argument
- from mypy.nodes import AssignmentStmt
- from mypy.nodes import CallExpr
- from mypy.nodes import ClassDef
- from mypy.nodes import MDEF
- from mypy.nodes import MemberExpr
- from mypy.nodes import NameExpr
- from mypy.nodes import RefExpr
- from mypy.nodes import StrExpr
- from mypy.nodes import SymbolTableNode
- from mypy.nodes import TempNode
- from mypy.nodes import TypeInfo
- from mypy.nodes import Var
- from mypy.plugin import SemanticAnalyzerPluginInterface
- from mypy.plugins.common import add_method_to_class
- from mypy.types import AnyType
- from mypy.types import get_proper_type
- from mypy.types import Instance
- from mypy.types import NoneTyp
- from mypy.types import ProperType
- from mypy.types import TypeOfAny
- from mypy.types import UnboundType
- from mypy.types import UnionType
- from . import infer
- from . import util
- from .names import NAMED_TYPE_SQLA_MAPPED
def apply_mypy_mapped_attr(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    item: Union[NameExpr, StrExpr],
    attributes: List[util.SQLAlchemyAttribute],
) -> None:
    """Record and rewrite an attribute named in ``_mypy_mapped_attrs``.

    ``item`` names the attribute either as a bare name (``NameExpr``) or as
    a string literal (``StrExpr``); any other expression is ignored.

    The first assignment in the class body whose first target is a plain
    name equal to that attribute name is processed:

    * its declared type (``stmt.type``) is appended to ``attributes`` as a
      ``util.SQLAlchemyAttribute`` positioned at ``item``;
    * the statement is passed to ``apply_type_to_mapped_statement`` with
      that declared type as the explicit left-hand type.

    An error is reported through ``util.fail`` (and nothing is recorded)
    when no matching assignment exists, or when the matching assignment
    has no declared type.
    """
    if isinstance(item, StrExpr):
        name = item.value
    elif isinstance(item, NameExpr):
        name = item.name
    else:
        return None

    for candidate in cls.defs.body:
        if not (
            isinstance(candidate, AssignmentStmt)
            and isinstance(candidate.lvalues[0], NameExpr)
            and candidate.lvalues[0].name == name
        ):
            continue

        # only the first matching assignment is considered; we always
        # return from inside this loop once one is found
        if candidate.type is None:
            util.fail(
                api,
                "Statement linked from _mypy_mapped_attrs has no "
                "typing information",
                candidate,
            )
            return None

        declared_type = get_proper_type(candidate.type)
        assert isinstance(
            declared_type, (Instance, UnionType, UnboundType)
        )

        attributes.append(
            util.SQLAlchemyAttribute(
                name=name,
                line=item.line,
                column=item.column,
                typ=declared_type,
                info=cls.info,
            )
        )

        apply_type_to_mapped_statement(
            api, candidate, candidate.lvalues[0], declared_type, None
        )
        return None

    # fell off the loop: no assignment in the class body targets this name
    util.fail(api, "Can't find mapped attribute {}".format(name), cls)
    return None
def re_apply_declarative_assignments(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    attributes: List[util.SQLAlchemyAttribute],
) -> None:
    """For multiple class passes, re-apply our left-hand side types as mypy
    seems to reset them in place.

    Each assignment in the class body is checked against ``attributes``.
    A matching assignment has a plain-name first target whose name is in
    ``attributes`` and whose symbol node is a ``Var``. When a type is
    recorded for that attribute, the target ``Var`` is retyped as
    ``Mapped[<recorded type>]``.

    Special case: suppose an attribute was recorded with an
    ``UnboundType``, the ``Var``'s current type is no longer an
    ``UnboundType``, and the right-hand side has the shape
    ``<Mapped>._empty_constructor(<call>)``. Then the type is re-inferred
    from the inner call's callee via
    ``infer.infer_type_from_right_hand_nameexpr``:

    * if inference yields a non-``None``, non-``UnboundType`` result, the
      recorded ``SQLAlchemyAttribute`` is updated in place, the target is
      retyped with the new type, and the class metadata is rewritten with
      ``util.set_mapped_attributes`` after the loop;
    * otherwise that statement is skipped and its target type is left
      unchanged.
    """
    # values are the same SQLAlchemyAttribute objects held in ``attributes``,
    # so assigning ``.type`` through this dict also updates the list that
    # is later written back by util.set_mapped_attributes
    mapped_attr_lookup = {attr.name: attr for attr in attributes}
    # set when at least one recorded attribute type was improved
    update_cls_metadata = False

    for stmt in cls.defs.body:
        # for a re-apply, all of our statements are AssignmentStmt;
        # @declared_attr calls will have been converted and this
        # currently seems to be preserved by mypy (but who knows if this
        # will change).
        if (
            isinstance(stmt, AssignmentStmt)
            and isinstance(stmt.lvalues[0], NameExpr)
            and stmt.lvalues[0].name in mapped_attr_lookup
            and isinstance(stmt.lvalues[0].node, Var)
        ):

            left_node = stmt.lvalues[0].node
            # the type recorded for this attribute on an earlier pass
            python_type_for_type = mapped_attr_lookup[
                stmt.lvalues[0].name
            ].type

            left_node_proper_type = get_proper_type(left_node.type)

            # if we have scanned an UnboundType and now there's a more
            # specific type than UnboundType, call the re-scan so we
            # can get that set up correctly
            if (
                isinstance(python_type_for_type, UnboundType)
                and not isinstance(left_node_proper_type, UnboundType)
                and (
                    # right-hand side must look like
                    # <NAMED_TYPE_SQLA_MAPPED>._empty_constructor(<call>)
                    isinstance(stmt.rvalue, CallExpr)
                    and isinstance(stmt.rvalue.callee, MemberExpr)
                    and isinstance(stmt.rvalue.callee.expr, NameExpr)
                    and stmt.rvalue.callee.expr.node is not None
                    and stmt.rvalue.callee.expr.node.fullname
                    == NAMED_TYPE_SQLA_MAPPED
                    and stmt.rvalue.callee.name == "_empty_constructor"
                    # NOTE(review): args[0] is indexed without a length
                    # check; presumably safe because the wrapper call is
                    # generated with exactly one argument (the original
                    # right-hand side) - confirm against
                    # util.expr_to_mapped_constructor
                    and isinstance(stmt.rvalue.args[0], CallExpr)
                    and isinstance(stmt.rvalue.args[0].callee, RefExpr)
                )
            ):

                python_type_for_type = (
                    infer.infer_type_from_right_hand_nameexpr(
                        api,
                        stmt,
                        left_node,
                        left_node_proper_type,
                        stmt.rvalue.args[0].callee,
                    )
                )

                # inference gave nothing better: skip this statement
                # entirely, leaving left_node.type as it currently is
                if python_type_for_type is None or isinstance(
                    python_type_for_type, UnboundType
                ):
                    continue

                # update the SQLAlchemyAttribute with the better information
                mapped_attr_lookup[
                    stmt.lvalues[0].name
                ].type = python_type_for_type

                update_cls_metadata = True

            # re-apply Mapped[<type>] to the target whenever a type is known
            if python_type_for_type is not None:
                left_node.type = api.named_type(
                    NAMED_TYPE_SQLA_MAPPED, [python_type_for_type]
                )

    # persist improved attribute types on the class's metadata
    if update_cls_metadata:
        util.set_mapped_attributes(cls.info, attributes)
def apply_type_to_mapped_statement(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    lvalue: NameExpr,
    left_hand_explicit_type: Optional[ProperType],
    python_type_for_type: Optional[ProperType],
) -> None:
    """Rewrite a declarative assignment so mypy sees a ``Mapped[...]`` type.

    A class-level statement such as::

        class User(Base):
            # ...

            attrname = Column(Integer)

    is presented to mypy as::

        class User(Base):
            # ...

            attrname : Mapped[Optional[int]] = <meaningless temp node>

    The type assigned to the ``Var`` behind ``lvalue`` depends on
    ``left_hand_explicit_type``:

    * if it is given, the ``Var`` becomes ``Mapped[left_hand_explicit_type]``;
    * otherwise ``lvalue`` is marked as not an inferred definition, and the
      ``Var`` becomes ``Mapped[python_type_for_type]``. When
      ``python_type_for_type`` is ``None``, ``named_type`` is called with
      no type arguments.

    In both cases the right-hand side is wrapped by
    ``util.expr_to_mapped_constructor``.
    """
    target_var = lvalue.node
    assert isinstance(target_var, Var)

    if left_hand_explicit_type is None:
        # no user-supplied annotation: the Mapped[] type comes from what
        # was inferred from the right-hand side (possibly nothing)
        lvalue.is_inferred_def = False
        mapped_type = api.named_type(
            NAMED_TYPE_SQLA_MAPPED,
            [python_type_for_type] if python_type_for_type is not None else [],
        )
    else:
        mapped_type = api.named_type(
            NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
        )
    target_var.type = mapped_type

    # Swapping the rvalue for a bare TempNode would make mypy ignore the
    # right-hand side completely. Instead the original expression is kept
    # inside
    #     _sa_Mapped._empty_constructor(<original rvalue>)
    # so the Column(), relationship(), etc. call itself is still
    # type checked.
    stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue)
def add_additional_orm_attributes(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    attributes: List[util.SQLAlchemyAttribute],
) -> None:
    """Apply __init__, __table__ and other attributes to the mapped class.

    * ``__init__``: added only when the class defines none and is not the
      declarative base. The synthesized ``__init__`` returns ``None`` and
      takes every mapped attribute as an optional keyword-only argument
      (``ARG_NAMED_OPT``). The attributes passed in ``attributes`` come
      first. Attributes recorded on mapped base classes are then added in
      MRO order unless a same-named attribute is already present.
      Attributes with no known type are typed ``Any``.
    * ``__table__``: when the class has no ``__table__`` and
      ``util.get_has_table`` is true for it, a placeholder attribute typed
      ``sqlalchemy.sql.schema.Table`` is added.
    * ``__mapper__``: for non-base classes, a placeholder attribute typed
      ``sqlalchemy.orm.mapper.Mapper`` is added.

    Nothing is done if ``util.info_for_cls`` returns ``None``.
    """
    info = util.info_for_cls(cls, api)
    if info is None:
        return

    is_base = util.get_is_base(info)

    if "__init__" not in info.names and not is_base:
        mapped_attr_names = {attr.name: attr.type for attr in attributes}

        # mro[0] is this class itself; the final entry is excluded as well
        # (presumably ``builtins.object``, which has no mapped attributes)
        for base in info.mro[1:-1]:
            # Only bases the plugin has recorded metadata on can contribute
            # mapped attributes. The check must be against ``base``:
            # ``info`` is the same object on every iteration, so testing it
            # would either filter nothing or skip every base at once.
            if "sqlalchemy" not in base.metadata:
                continue

            base_cls_attributes = util.get_mapped_attributes(base, api)
            if base_cls_attributes is None:
                continue

            for attr in base_cls_attributes:
                # setdefault: a name already collected (from this class or
                # from a base earlier in the MRO) takes precedence
                mapped_attr_names.setdefault(attr.name, attr.type)

        arguments = []
        for name, typ in mapped_attr_names.items():
            if typ is None:
                typ = AnyType(TypeOfAny.special_form)
            arguments.append(
                Argument(
                    variable=Var(name, typ),
                    type_annotation=typ,
                    # placeholder default expression for the optional arg
                    initializer=TempNode(typ),
                    kind=ARG_NAMED_OPT,
                )
            )

        add_method_to_class(api, cls, "__init__", arguments, NoneTyp())

    if "__table__" not in info.names and util.get_has_table(info):
        _apply_placeholder_attr_to_class(
            api, cls, "sqlalchemy.sql.schema.Table", "__table__"
        )
    if not is_base:
        _apply_placeholder_attr_to_class(
            api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
        )
def _apply_placeholder_attr_to_class(
    api: SemanticAnalyzerPluginInterface,
    cls: ClassDef,
    qualified_name: str,
    attrname: str,
) -> None:
    """Declare a class-level attribute ``attrname`` on ``cls``.

    If ``qualified_name`` resolves via ``lookup_fully_qualified_or_none``,
    the attribute is typed as a plain instance of that class. Otherwise it
    is typed ``Any``. The new ``Var`` is registered directly in
    ``cls.info.names`` as an ``MDEF`` symbol. Any existing entry with the
    same name is overwritten.
    """
    placeholder_type: ProperType
    looked_up = api.lookup_fully_qualified_or_none(qualified_name)
    if not looked_up:
        # target class not available to mypy: fall back to Any
        placeholder_type = AnyType(TypeOfAny.special_form)
    else:
        assert isinstance(looked_up.node, TypeInfo)
        placeholder_type = Instance(looked_up.node, [])

    member = Var(attrname)
    member._fullname = ".".join((cls.fullname, attrname))
    member.info = cls.info
    member.type = placeholder_type
    cls.info.names[attrname] = SymbolTableNode(MDEF, member)
|