brain_dataclasses.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
  3. """
  4. Astroid hook for the dataclasses library
  5. Support built-in dataclasses, pydantic.dataclasses, and marshmallow_dataclass-annotated
  6. dataclasses. References:
  7. - https://docs.python.org/3/library/dataclasses.html
  8. - https://pydantic-docs.helpmanual.io/usage/dataclasses/
  9. - https://lovasoa.github.io/marshmallow_dataclass/
  10. """
  11. import sys
  12. from typing import FrozenSet, Generator, List, Optional, Tuple, Union
  13. from astroid import context, inference_tip
  14. from astroid.builder import parse
  15. from astroid.const import PY37_PLUS, PY39_PLUS
  16. from astroid.exceptions import (
  17. AstroidSyntaxError,
  18. InferenceError,
  19. MroError,
  20. UseInferenceDefault,
  21. )
  22. from astroid.manager import AstroidManager
  23. from astroid.nodes.node_classes import (
  24. AnnAssign,
  25. Assign,
  26. AssignName,
  27. Attribute,
  28. Call,
  29. Name,
  30. NodeNG,
  31. Subscript,
  32. Unknown,
  33. )
  34. from astroid.nodes.scoped_nodes import ClassDef, FunctionDef
  35. from astroid.util import Uninferable
  36. if sys.version_info >= (3, 8):
  37. from typing import Literal
  38. else:
  39. from typing_extensions import Literal
  40. _FieldDefaultReturn = Union[
  41. None, Tuple[Literal["default"], NodeNG], Tuple[Literal["default_factory"], Call]
  42. ]
  43. DATACLASSES_DECORATORS = frozenset(("dataclass",))
  44. FIELD_NAME = "field"
  45. DATACLASS_MODULES = frozenset(
  46. ("dataclasses", "marshmallow_dataclass", "pydantic.dataclasses")
  47. )
  48. DEFAULT_FACTORY = "_HAS_DEFAULT_FACTORY" # based on typing.py
  49. def is_decorated_with_dataclass(node, decorator_names=DATACLASSES_DECORATORS):
  50. """Return True if a decorated node has a `dataclass` decorator applied."""
  51. if not isinstance(node, ClassDef) or not node.decorators:
  52. return False
  53. return any(
  54. _looks_like_dataclass_decorator(decorator_attribute, decorator_names)
  55. for decorator_attribute in node.decorators.nodes
  56. )
  57. def dataclass_transform(node: ClassDef) -> None:
  58. """Rewrite a dataclass to be easily understood by pylint"""
  59. for assign_node in _get_dataclass_attributes(node):
  60. name = assign_node.target.name
  61. rhs_node = Unknown(
  62. lineno=assign_node.lineno,
  63. col_offset=assign_node.col_offset,
  64. parent=assign_node,
  65. )
  66. rhs_node = AstroidManager().visit_transforms(rhs_node)
  67. node.instance_attrs[name] = [rhs_node]
  68. if not _check_generate_dataclass_init(node):
  69. return
  70. try:
  71. reversed_mro = list(reversed(node.mro()))
  72. except MroError:
  73. reversed_mro = [node]
  74. field_assigns = {}
  75. field_order = []
  76. for klass in (k for k in reversed_mro if is_decorated_with_dataclass(k)):
  77. for assign_node in _get_dataclass_attributes(klass, init=True):
  78. name = assign_node.target.name
  79. if name not in field_assigns:
  80. field_order.append(name)
  81. field_assigns[name] = assign_node
  82. init_str = _generate_dataclass_init([field_assigns[name] for name in field_order])
  83. try:
  84. init_node = parse(init_str)["__init__"]
  85. except AstroidSyntaxError:
  86. pass
  87. else:
  88. init_node.parent = node
  89. init_node.lineno, init_node.col_offset = None, None
  90. node.locals["__init__"] = [init_node]
  91. root = node.root()
  92. if DEFAULT_FACTORY not in root.locals:
  93. new_assign = parse(f"{DEFAULT_FACTORY} = object()").body[0]
  94. new_assign.parent = root
  95. root.locals[DEFAULT_FACTORY] = [new_assign.targets[0]]
  96. def _get_dataclass_attributes(node: ClassDef, init: bool = False) -> Generator:
  97. """Yield the AnnAssign nodes of dataclass attributes for the node.
  98. If init is True, also include InitVars, but exclude attributes from calls to
  99. field where init=False.
  100. """
  101. for assign_node in node.body:
  102. if not isinstance(assign_node, AnnAssign) or not isinstance(
  103. assign_node.target, AssignName
  104. ):
  105. continue
  106. if _is_class_var(assign_node.annotation): # type: ignore[arg-type] # annotation is never None
  107. continue
  108. if init:
  109. value = assign_node.value
  110. if (
  111. isinstance(value, Call)
  112. and _looks_like_dataclass_field_call(value, check_scope=False)
  113. and any(
  114. keyword.arg == "init"
  115. and not keyword.value.bool_value() # type: ignore[union-attr] # value is never None
  116. for keyword in value.keywords
  117. )
  118. ):
  119. continue
  120. elif _is_init_var(assign_node.annotation): # type: ignore[arg-type] # annotation is never None
  121. continue
  122. yield assign_node
  123. def _check_generate_dataclass_init(node: ClassDef) -> bool:
  124. """Return True if we should generate an __init__ method for node.
  125. This is True when:
  126. - node doesn't define its own __init__ method
  127. - the dataclass decorator was called *without* the keyword argument init=False
  128. """
  129. if "__init__" in node.locals:
  130. return False
  131. found = None
  132. for decorator_attribute in node.decorators.nodes:
  133. if not isinstance(decorator_attribute, Call):
  134. continue
  135. if _looks_like_dataclass_decorator(decorator_attribute):
  136. found = decorator_attribute
  137. if found is None:
  138. return True
  139. # Check for keyword arguments of the form init=False
  140. return all(
  141. keyword.arg != "init"
  142. and keyword.value.bool_value() # type: ignore[union-attr] # value is never None
  143. for keyword in found.keywords
  144. )
  145. def _generate_dataclass_init(assigns: List[AnnAssign]) -> str:
  146. """Return an init method for a dataclass given the targets."""
  147. target_names = []
  148. params = []
  149. assignments = []
  150. for assign in assigns:
  151. name, annotation, value = assign.target.name, assign.annotation, assign.value
  152. target_names.append(name)
  153. if _is_init_var(annotation): # type: ignore[arg-type] # annotation is never None
  154. init_var = True
  155. if isinstance(annotation, Subscript):
  156. annotation = annotation.slice
  157. else:
  158. # Cannot determine type annotation for parameter from InitVar
  159. annotation = None
  160. assignment_str = ""
  161. else:
  162. init_var = False
  163. assignment_str = f"self.{name} = {name}"
  164. if annotation:
  165. param_str = f"{name}: {annotation.as_string()}"
  166. else:
  167. param_str = name
  168. if value:
  169. if isinstance(value, Call) and _looks_like_dataclass_field_call(
  170. value, check_scope=False
  171. ):
  172. result = _get_field_default(value)
  173. if result:
  174. default_type, default_node = result
  175. if default_type == "default":
  176. param_str += f" = {default_node.as_string()}"
  177. elif default_type == "default_factory":
  178. param_str += f" = {DEFAULT_FACTORY}"
  179. assignment_str = (
  180. f"self.{name} = {default_node.as_string()} "
  181. f"if {name} is {DEFAULT_FACTORY} else {name}"
  182. )
  183. else:
  184. param_str += f" = {value.as_string()}"
  185. params.append(param_str)
  186. if not init_var:
  187. assignments.append(assignment_str)
  188. params_string = ", ".join(["self"] + params)
  189. assignments_string = "\n ".join(assignments) if assignments else "pass"
  190. return f"def __init__({params_string}) -> None:\n {assignments_string}"
  191. def infer_dataclass_attribute(
  192. node: Unknown, ctx: Optional[context.InferenceContext] = None
  193. ) -> Generator:
  194. """Inference tip for an Unknown node that was dynamically generated to
  195. represent a dataclass attribute.
  196. In the case that a default value is provided, that is inferred first.
  197. Then, an Instance of the annotated class is yielded.
  198. """
  199. assign = node.parent
  200. if not isinstance(assign, AnnAssign):
  201. yield Uninferable
  202. return
  203. annotation, value = assign.annotation, assign.value
  204. if value is not None:
  205. yield from value.infer(context=ctx)
  206. if annotation is not None:
  207. yield from _infer_instance_from_annotation(annotation, ctx=ctx)
  208. else:
  209. yield Uninferable
  210. def infer_dataclass_field_call(
  211. node: Call, ctx: Optional[context.InferenceContext] = None
  212. ) -> Generator:
  213. """Inference tip for dataclass field calls."""
  214. if not isinstance(node.parent, (AnnAssign, Assign)):
  215. raise UseInferenceDefault
  216. result = _get_field_default(node)
  217. if not result:
  218. yield Uninferable
  219. else:
  220. default_type, default = result
  221. if default_type == "default":
  222. yield from default.infer(context=ctx)
  223. else:
  224. new_call = parse(default.as_string()).body[0].value
  225. new_call.parent = node.parent
  226. yield from new_call.infer(context=ctx)
  227. def _looks_like_dataclass_decorator(
  228. node: NodeNG, decorator_names: FrozenSet[str] = DATACLASSES_DECORATORS
  229. ) -> bool:
  230. """Return True if node looks like a dataclass decorator.
  231. Uses inference to lookup the value of the node, and if that fails,
  232. matches against specific names.
  233. """
  234. if isinstance(node, Call): # decorator with arguments
  235. node = node.func
  236. try:
  237. inferred = next(node.infer())
  238. except (InferenceError, StopIteration):
  239. inferred = Uninferable
  240. if inferred is Uninferable:
  241. if isinstance(node, Name):
  242. return node.name in decorator_names
  243. if isinstance(node, Attribute):
  244. return node.attrname in decorator_names
  245. return False
  246. return (
  247. isinstance(inferred, FunctionDef)
  248. and inferred.name in decorator_names
  249. and inferred.root().name in DATACLASS_MODULES
  250. )
  251. def _looks_like_dataclass_attribute(node: Unknown) -> bool:
  252. """Return True if node was dynamically generated as the child of an AnnAssign
  253. statement.
  254. """
  255. parent = node.parent
  256. if not parent:
  257. return False
  258. scope = parent.scope()
  259. return (
  260. isinstance(parent, AnnAssign)
  261. and isinstance(scope, ClassDef)
  262. and is_decorated_with_dataclass(scope)
  263. )
  264. def _looks_like_dataclass_field_call(node: Call, check_scope: bool = True) -> bool:
  265. """Return True if node is calling dataclasses field or Field
  266. from an AnnAssign statement directly in the body of a ClassDef.
  267. If check_scope is False, skips checking the statement and body.
  268. """
  269. if check_scope:
  270. stmt = node.statement(future=True)
  271. scope = stmt.scope()
  272. if not (
  273. isinstance(stmt, AnnAssign)
  274. and stmt.value is not None
  275. and isinstance(scope, ClassDef)
  276. and is_decorated_with_dataclass(scope)
  277. ):
  278. return False
  279. try:
  280. inferred = next(node.func.infer())
  281. except (InferenceError, StopIteration):
  282. return False
  283. if not isinstance(inferred, FunctionDef):
  284. return False
  285. return inferred.name == FIELD_NAME and inferred.root().name in DATACLASS_MODULES
  286. def _get_field_default(field_call: Call) -> _FieldDefaultReturn:
  287. """Return a the default value of a field call, and the corresponding keyword argument name.
  288. field(default=...) results in the ... node
  289. field(default_factory=...) results in a Call node with func ... and no arguments
  290. If neither or both arguments are present, return ("", None) instead,
  291. indicating that there is not a valid default value.
  292. """
  293. default, default_factory = None, None
  294. for keyword in field_call.keywords:
  295. if keyword.arg == "default":
  296. default = keyword.value
  297. elif keyword.arg == "default_factory":
  298. default_factory = keyword.value
  299. if default is not None and default_factory is None:
  300. return "default", default
  301. if default is None and default_factory is not None:
  302. new_call = Call(
  303. lineno=field_call.lineno,
  304. col_offset=field_call.col_offset,
  305. parent=field_call.parent,
  306. )
  307. new_call.postinit(func=default_factory)
  308. return "default_factory", new_call
  309. return None
  310. def _is_class_var(node: NodeNG) -> bool:
  311. """Return True if node is a ClassVar, with or without subscripting."""
  312. if PY39_PLUS:
  313. try:
  314. inferred = next(node.infer())
  315. except (InferenceError, StopIteration):
  316. return False
  317. return getattr(inferred, "name", "") == "ClassVar"
  318. # Before Python 3.9, inference returns typing._SpecialForm instead of ClassVar.
  319. # Our backup is to inspect the node's structure.
  320. return isinstance(node, Subscript) and (
  321. isinstance(node.value, Name)
  322. and node.value.name == "ClassVar"
  323. or isinstance(node.value, Attribute)
  324. and node.value.attrname == "ClassVar"
  325. )
  326. def _is_init_var(node: NodeNG) -> bool:
  327. """Return True if node is an InitVar, with or without subscripting."""
  328. try:
  329. inferred = next(node.infer())
  330. except (InferenceError, StopIteration):
  331. return False
  332. return getattr(inferred, "name", "") == "InitVar"
  333. # Allowed typing classes for which we support inferring instances
  334. _INFERABLE_TYPING_TYPES = frozenset(
  335. (
  336. "Dict",
  337. "FrozenSet",
  338. "List",
  339. "Set",
  340. "Tuple",
  341. )
  342. )
  343. def _infer_instance_from_annotation(
  344. node: NodeNG, ctx: Optional[context.InferenceContext] = None
  345. ) -> Generator:
  346. """Infer an instance corresponding to the type annotation represented by node.
  347. Currently has limited support for the typing module.
  348. """
  349. klass = None
  350. try:
  351. klass = next(node.infer(context=ctx))
  352. except (InferenceError, StopIteration):
  353. yield Uninferable
  354. if not isinstance(klass, ClassDef):
  355. yield Uninferable
  356. elif klass.root().name in {
  357. "typing",
  358. "_collections_abc",
  359. "",
  360. }: # "" because of synthetic nodes in brain_typing.py
  361. if klass.name in _INFERABLE_TYPING_TYPES:
  362. yield klass.instantiate_class()
  363. else:
  364. yield Uninferable
  365. else:
  366. yield klass.instantiate_class()
  367. if PY37_PLUS:
  368. AstroidManager().register_transform(
  369. ClassDef, dataclass_transform, is_decorated_with_dataclass
  370. )
  371. AstroidManager().register_transform(
  372. Call,
  373. inference_tip(infer_dataclass_field_call, raise_on_overwrite=True),
  374. _looks_like_dataclass_field_call,
  375. )
  376. AstroidManager().register_transform(
  377. Unknown,
  378. inference_tip(infer_dataclass_attribute, raise_on_overwrite=True),
  379. _looks_like_dataclass_attribute,
  380. )