brain_namedtuple_enum.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  1. # Copyright (c) 2012-2015 LOGILAB S.A. (Paris, FRANCE) <contact@logilab.fr>
  2. # Copyright (c) 2013-2014 Google, Inc.
  3. # Copyright (c) 2014-2020 Claudiu Popa <pcmanticore@gmail.com>
  4. # Copyright (c) 2014 Eevee (Alex Munroe) <amunroe@yelp.com>
  5. # Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
  6. # Copyright (c) 2015 Dmitry Pribysh <dmand@yandex.ru>
  7. # Copyright (c) 2015 David Shea <dshea@redhat.com>
  8. # Copyright (c) 2015 Philip Lorenz <philip@bithub.de>
  9. # Copyright (c) 2016 Jakub Wilk <jwilk@jwilk.net>
  10. # Copyright (c) 2016 Mateusz Bysiek <mb@mbdev.pl>
  11. # Copyright (c) 2017 Hugo <hugovk@users.noreply.github.com>
  12. # Copyright (c) 2017 Łukasz Rogalski <rogalski.91@gmail.com>
  13. # Copyright (c) 2018 Ville Skyttä <ville.skytta@iki.fi>
  14. # Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
  15. # Copyright (c) 2020 hippo91 <guillaume.peillex@gmail.com>
  16. # Copyright (c) 2020 Ram Rachum <ram@rachum.com>
  17. # Copyright (c) 2021 Pierre Sassoulas <pierre.sassoulas@gmail.com>
  18. # Copyright (c) 2021 Daniël van Noord <13665637+DanielNoord@users.noreply.github.com>
  19. # Copyright (c) 2021 Dimitri Prybysh <dmand@yandex.ru>
  20. # Copyright (c) 2021 David Liu <david@cs.toronto.edu>
  21. # Copyright (c) 2021 pre-commit-ci[bot] <bot@noreply.github.com>
  22. # Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com>
  23. # Copyright (c) 2021 Andrew Haigh <hello@nelf.in>
  24. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  25. # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
  26. """Astroid hooks for the Python standard library."""
  27. import functools
  28. import keyword
  29. from textwrap import dedent
  30. import astroid
  31. from astroid import arguments, inference_tip, nodes, util
  32. from astroid.builder import AstroidBuilder, extract_node
  33. from astroid.exceptions import (
  34. AstroidTypeError,
  35. AstroidValueError,
  36. InferenceError,
  37. MroError,
  38. UseInferenceDefault,
  39. )
  40. from astroid.manager import AstroidManager
  41. TYPING_NAMEDTUPLE_BASENAMES = {"NamedTuple", "typing.NamedTuple"}
  42. ENUM_BASE_NAMES = {
  43. "Enum",
  44. "IntEnum",
  45. "enum.Enum",
  46. "enum.IntEnum",
  47. "IntFlag",
  48. "enum.IntFlag",
  49. }
  50. def _infer_first(node, context):
  51. if node is util.Uninferable:
  52. raise UseInferenceDefault
  53. try:
  54. value = next(node.infer(context=context))
  55. except StopIteration as exc:
  56. raise InferenceError from exc
  57. if value is util.Uninferable:
  58. raise UseInferenceDefault()
  59. return value
  60. def _find_func_form_arguments(node, context):
  61. def _extract_namedtuple_arg_or_keyword( # pylint: disable=inconsistent-return-statements
  62. position, key_name=None
  63. ):
  64. if len(args) > position:
  65. return _infer_first(args[position], context)
  66. if key_name and key_name in found_keywords:
  67. return _infer_first(found_keywords[key_name], context)
  68. args = node.args
  69. keywords = node.keywords
  70. found_keywords = (
  71. {keyword.arg: keyword.value for keyword in keywords} if keywords else {}
  72. )
  73. name = _extract_namedtuple_arg_or_keyword(position=0, key_name="typename")
  74. names = _extract_namedtuple_arg_or_keyword(position=1, key_name="field_names")
  75. if name and names:
  76. return name.value, names
  77. raise UseInferenceDefault()
  78. def infer_func_form(node, base_type, context=None, enum=False):
  79. """Specific inference function for namedtuple or Python 3 enum."""
  80. # node is a Call node, class name as first argument and generated class
  81. # attributes as second argument
  82. # namedtuple or enums list of attributes can be a list of strings or a
  83. # whitespace-separate string
  84. try:
  85. name, names = _find_func_form_arguments(node, context)
  86. try:
  87. attributes = names.value.replace(",", " ").split()
  88. except AttributeError as exc:
  89. if not enum:
  90. attributes = [
  91. _infer_first(const, context).value for const in names.elts
  92. ]
  93. else:
  94. # Enums supports either iterator of (name, value) pairs
  95. # or mappings.
  96. if hasattr(names, "items") and isinstance(names.items, list):
  97. attributes = [
  98. _infer_first(const[0], context).value
  99. for const in names.items
  100. if isinstance(const[0], nodes.Const)
  101. ]
  102. elif hasattr(names, "elts"):
  103. # Enums can support either ["a", "b", "c"]
  104. # or [("a", 1), ("b", 2), ...], but they can't
  105. # be mixed.
  106. if all(isinstance(const, nodes.Tuple) for const in names.elts):
  107. attributes = [
  108. _infer_first(const.elts[0], context).value
  109. for const in names.elts
  110. if isinstance(const, nodes.Tuple)
  111. ]
  112. else:
  113. attributes = [
  114. _infer_first(const, context).value for const in names.elts
  115. ]
  116. else:
  117. raise AttributeError from exc
  118. if not attributes:
  119. raise AttributeError from exc
  120. except (AttributeError, InferenceError) as exc:
  121. raise UseInferenceDefault from exc
  122. if not enum:
  123. # namedtuple maps sys.intern(str()) over over field_names
  124. attributes = [str(attr) for attr in attributes]
  125. # XXX this should succeed *unless* __str__/__repr__ is incorrect or throws
  126. # in which case we should not have inferred these values and raised earlier
  127. attributes = [attr for attr in attributes if " " not in attr]
  128. # If we can't infer the name of the class, don't crash, up to this point
  129. # we know it is a namedtuple anyway.
  130. name = name or "Uninferable"
  131. # we want to return a Class node instance with proper attributes set
  132. class_node = nodes.ClassDef(name, "docstring")
  133. class_node.parent = node.parent
  134. # set base class=tuple
  135. class_node.bases.append(base_type)
  136. # XXX add __init__(*attributes) method
  137. for attr in attributes:
  138. fake_node = nodes.EmptyNode()
  139. fake_node.parent = class_node
  140. fake_node.attrname = attr
  141. class_node.instance_attrs[attr] = [fake_node]
  142. return class_node, name, attributes
  143. def _has_namedtuple_base(node):
  144. """Predicate for class inference tip
  145. :type node: ClassDef
  146. :rtype: bool
  147. """
  148. return set(node.basenames) & TYPING_NAMEDTUPLE_BASENAMES
  149. def _looks_like(node, name):
  150. func = node.func
  151. if isinstance(func, nodes.Attribute):
  152. return func.attrname == name
  153. if isinstance(func, nodes.Name):
  154. return func.name == name
  155. return False
# Call-node predicates used when registering the transforms below. Each one
# matches a call whose callee is spelled ``namedtuple``, ``Enum`` or
# ``NamedTuple`` (bare or as an attribute); _looks_like only compares the
# spelled name, so the inference functions must do any further checking.
_looks_like_namedtuple = functools.partial(_looks_like, name="namedtuple")
_looks_like_enum = functools.partial(_looks_like, name="Enum")
_looks_like_typing_namedtuple = functools.partial(_looks_like, name="NamedTuple")
  159. def infer_named_tuple(node, context=None):
  160. """Specific inference function for namedtuple Call node"""
  161. tuple_base_name = nodes.Name(name="tuple", parent=node.root())
  162. class_node, name, attributes = infer_func_form(
  163. node, tuple_base_name, context=context
  164. )
  165. call_site = arguments.CallSite.from_call(node, context=context)
  166. node = extract_node("import collections; collections.namedtuple")
  167. try:
  168. func = next(node.infer())
  169. except StopIteration as e:
  170. raise InferenceError(node=node) from e
  171. try:
  172. rename = next(call_site.infer_argument(func, "rename", context)).bool_value()
  173. except (InferenceError, StopIteration):
  174. rename = False
  175. try:
  176. attributes = _check_namedtuple_attributes(name, attributes, rename)
  177. except AstroidTypeError as exc:
  178. raise UseInferenceDefault("TypeError: " + str(exc)) from exc
  179. except AstroidValueError as exc:
  180. raise UseInferenceDefault("ValueError: " + str(exc)) from exc
  181. replace_args = ", ".join(f"{arg}=None" for arg in attributes)
  182. field_def = (
  183. " {name} = property(lambda self: self[{index:d}], "
  184. "doc='Alias for field number {index:d}')"
  185. )
  186. field_defs = "\n".join(
  187. field_def.format(name=name, index=index)
  188. for index, name in enumerate(attributes)
  189. )
  190. fake = AstroidBuilder(AstroidManager()).string_build(
  191. f"""
  192. class {name}(tuple):
  193. __slots__ = ()
  194. _fields = {attributes!r}
  195. def _asdict(self):
  196. return self.__dict__
  197. @classmethod
  198. def _make(cls, iterable, new=tuple.__new__, len=len):
  199. return new(cls, iterable)
  200. def _replace(self, {replace_args}):
  201. return self
  202. def __getnewargs__(self):
  203. return tuple(self)
  204. {field_defs}
  205. """
  206. )
  207. class_node.locals["_asdict"] = fake.body[0].locals["_asdict"]
  208. class_node.locals["_make"] = fake.body[0].locals["_make"]
  209. class_node.locals["_replace"] = fake.body[0].locals["_replace"]
  210. class_node.locals["_fields"] = fake.body[0].locals["_fields"]
  211. for attr in attributes:
  212. class_node.locals[attr] = fake.body[0].locals[attr]
  213. # we use UseInferenceDefault, we can't be a generator so return an iterator
  214. return iter([class_node])
  215. def _get_renamed_namedtuple_attributes(field_names):
  216. names = list(field_names)
  217. seen = set()
  218. for i, name in enumerate(field_names):
  219. if (
  220. not all(c.isalnum() or c == "_" for c in name)
  221. or keyword.iskeyword(name)
  222. or not name
  223. or name[0].isdigit()
  224. or name.startswith("_")
  225. or name in seen
  226. ):
  227. names[i] = "_%d" % i
  228. seen.add(name)
  229. return tuple(names)
  230. def _check_namedtuple_attributes(typename, attributes, rename=False):
  231. attributes = tuple(attributes)
  232. if rename:
  233. attributes = _get_renamed_namedtuple_attributes(attributes)
  234. # The following snippet is derived from the CPython Lib/collections/__init__.py sources
  235. # <snippet>
  236. for name in (typename,) + attributes:
  237. if not isinstance(name, str):
  238. raise AstroidTypeError("Type names and field names must be strings")
  239. if not name.isidentifier():
  240. raise AstroidValueError(
  241. "Type names and field names must be valid" + f"identifiers: {name!r}"
  242. )
  243. if keyword.iskeyword(name):
  244. raise AstroidValueError(
  245. f"Type names and field names cannot be a keyword: {name!r}"
  246. )
  247. seen = set()
  248. for name in attributes:
  249. if name.startswith("_") and not rename:
  250. raise AstroidValueError(
  251. f"Field names cannot start with an underscore: {name!r}"
  252. )
  253. if name in seen:
  254. raise AstroidValueError(f"Encountered duplicate field name: {name!r}")
  255. seen.add(name)
  256. # </snippet>
  257. return attributes
def infer_enum(node, context=None):
    """Specific inference function for enum Call node.

    Infers a functional-form call such as ``Enum("Color", "RED GREEN")`` as
    an instance of a class generated by ``infer_func_form`` whose base is the
    ``EnumMeta`` stub built below.

    :param node: the Call node (matched by ``_looks_like_enum``).
    :param context: inference context forwarded to ``infer_func_form``.
    :returns: an iterator yielding one instance of the generated class.
    :raises UseInferenceDefault: propagated from ``infer_func_form`` when the
        arguments cannot be inferred.
    """
    # Stub used as the generated class's base. It defines __call__, __iter__,
    # __reversed__, __next__ and __getitem__, each producing placeholder
    # objects that carry ``name`` and ``value`` attributes.
    # NOTE(review): __reversed__ returns the EnumAttribute class itself while
    # __iter__ returns an instance — confirm whether that asymmetry is intended.
    enum_meta = extract_node(
        """
    class EnumMeta(object):
        'docstring'
        def __call__(self, node):
            class EnumAttribute(object):
                name = ''
                value = 0
            return EnumAttribute()
        def __iter__(self):
            class EnumAttribute(object):
                name = ''
                value = 0
            return [EnumAttribute()]
        def __reversed__(self):
            class EnumAttribute(object):
                name = ''
                value = 0
            return (EnumAttribute, )
        def __next__(self):
            return next(iter(self))
        def __getitem__(self, attr):
            class Value(object):
                @property
                def name(self):
                    return ''
                @property
                def value(self):
                    return attr
            return Value()
        __members__ = ['']
    """
    )
    class_node = infer_func_form(node, enum_meta, context=context, enum=True)[0]
    return iter([class_node.instantiate_class()])
# Extra dunder methods appended to the mocked member class of an IntFlag
# subclass in ``infer_enum_class``. ``{name}`` is filled in with the member
# name via ``str.format``. The methods are indented one level because the
# text is concatenated onto the body of a generated ``class {name}(...):``.
# NOTE(review): ``__div__`` is the Python 2 division hook; Python 3's ``/``
# looks up ``__truediv__``, so that entry is likely never used.
INT_FLAG_ADDITION_METHODS = """
    def __or__(self, other):
        return {name}(self.value | other.value)
    def __and__(self, other):
        return {name}(self.value & other.value)
    def __xor__(self, other):
        return {name}(self.value ^ other.value)
    def __add__(self, other):
        return {name}(self.value + other.value)
    def __div__(self, other):
        return {name}(self.value / other.value)
    def __invert__(self):
        return {name}(~self.value)
    def __mul__(self, other):
        return {name}(self.value * other.value)
"""
  311. def infer_enum_class(node):
  312. """Specific inference for enums."""
  313. for basename in (b for cls in node.mro() for b in cls.basenames):
  314. if basename not in ENUM_BASE_NAMES:
  315. continue
  316. if node.root().name == "enum":
  317. # Skip if the class is directly from enum module.
  318. break
  319. dunder_members = {}
  320. target_names = set()
  321. for local, values in node.locals.items():
  322. if any(not isinstance(value, nodes.AssignName) for value in values):
  323. continue
  324. stmt = values[0].statement(future=True)
  325. if isinstance(stmt, nodes.Assign):
  326. if isinstance(stmt.targets[0], nodes.Tuple):
  327. targets = stmt.targets[0].itered()
  328. else:
  329. targets = stmt.targets
  330. elif isinstance(stmt, nodes.AnnAssign):
  331. targets = [stmt.target]
  332. else:
  333. continue
  334. inferred_return_value = None
  335. if isinstance(stmt, nodes.Assign):
  336. if isinstance(stmt.value, nodes.Const):
  337. if isinstance(stmt.value.value, str):
  338. inferred_return_value = repr(stmt.value.value)
  339. else:
  340. inferred_return_value = stmt.value.value
  341. else:
  342. inferred_return_value = stmt.value.as_string()
  343. new_targets = []
  344. for target in targets:
  345. if isinstance(target, nodes.Starred):
  346. continue
  347. target_names.add(target.name)
  348. # Replace all the assignments with our mocked class.
  349. classdef = dedent(
  350. """
  351. class {name}({types}):
  352. @property
  353. def value(self):
  354. return {return_value}
  355. @property
  356. def name(self):
  357. return "{name}"
  358. """.format(
  359. name=target.name,
  360. types=", ".join(node.basenames),
  361. return_value=inferred_return_value,
  362. )
  363. )
  364. if "IntFlag" in basename:
  365. # Alright, we need to add some additional methods.
  366. # Unfortunately we still can't infer the resulting objects as
  367. # Enum members, but once we'll be able to do that, the following
  368. # should result in some nice symbolic execution
  369. classdef += INT_FLAG_ADDITION_METHODS.format(name=target.name)
  370. fake = AstroidBuilder(
  371. AstroidManager(), apply_transforms=False
  372. ).string_build(classdef)[target.name]
  373. fake.parent = target.parent
  374. for method in node.mymethods():
  375. fake.locals[method.name] = [method]
  376. new_targets.append(fake.instantiate_class())
  377. dunder_members[local] = fake
  378. node.locals[local] = new_targets
  379. members = nodes.Dict(parent=node)
  380. members.postinit(
  381. [
  382. (nodes.Const(k, parent=members), nodes.Name(v.name, parent=members))
  383. for k, v in dunder_members.items()
  384. ]
  385. )
  386. node.locals["__members__"] = [members]
  387. # The enum.Enum class itself defines two @DynamicClassAttribute data-descriptors
  388. # "name" and "value" (which we override in the mocked class for each enum member
  389. # above). When dealing with inference of an arbitrary instance of the enum
  390. # class, e.g. in a method defined in the class body like:
  391. # class SomeEnum(enum.Enum):
  392. # def method(self):
  393. # self.name # <- here
  394. # In the absence of an enum member called "name" or "value", these attributes
  395. # should resolve to the descriptor on that particular instance, i.e. enum member.
  396. # For "value", we have no idea what that should be, but for "name", we at least
  397. # know that it should be a string, so infer that as a guess.
  398. if "name" not in target_names:
  399. code = dedent(
  400. """
  401. @property
  402. def name(self):
  403. return ''
  404. """
  405. )
  406. name_dynamicclassattr = AstroidBuilder(AstroidManager()).string_build(code)[
  407. "name"
  408. ]
  409. node.locals["name"] = [name_dynamicclassattr]
  410. break
  411. return node
  412. def infer_typing_namedtuple_class(class_node, context=None):
  413. """Infer a subclass of typing.NamedTuple"""
  414. # Check if it has the corresponding bases
  415. annassigns_fields = [
  416. annassign.target.name
  417. for annassign in class_node.body
  418. if isinstance(annassign, nodes.AnnAssign)
  419. ]
  420. code = dedent(
  421. """
  422. from collections import namedtuple
  423. namedtuple({typename!r}, {fields!r})
  424. """
  425. ).format(typename=class_node.name, fields=",".join(annassigns_fields))
  426. node = extract_node(code)
  427. try:
  428. generated_class_node = next(infer_named_tuple(node, context))
  429. except StopIteration as e:
  430. raise InferenceError(node=node, context=context) from e
  431. for method in class_node.mymethods():
  432. generated_class_node.locals[method.name] = [method]
  433. for body_node in class_node.body:
  434. if isinstance(body_node, nodes.Assign):
  435. for target in body_node.targets:
  436. attr = target.name
  437. generated_class_node.locals[attr] = class_node.locals[attr]
  438. elif isinstance(body_node, nodes.ClassDef):
  439. generated_class_node.locals[body_node.name] = [body_node]
  440. return iter((generated_class_node,))
  441. def infer_typing_namedtuple_function(node, context=None):
  442. """
  443. Starting with python3.9, NamedTuple is a function of the typing module.
  444. The class NamedTuple is build dynamically through a call to `type` during
  445. initialization of the `_NamedTuple` variable.
  446. """
  447. klass = extract_node(
  448. """
  449. from typing import _NamedTuple
  450. _NamedTuple
  451. """
  452. )
  453. return klass.infer(context)
  454. def infer_typing_namedtuple(node, context=None):
  455. """Infer a typing.NamedTuple(...) call."""
  456. # This is essentially a namedtuple with different arguments
  457. # so we extract the args and infer a named tuple.
  458. try:
  459. func = next(node.func.infer())
  460. except (InferenceError, StopIteration) as exc:
  461. raise UseInferenceDefault from exc
  462. if func.qname() != "typing.NamedTuple":
  463. raise UseInferenceDefault
  464. if len(node.args) != 2:
  465. raise UseInferenceDefault
  466. if not isinstance(node.args[1], (nodes.List, nodes.Tuple)):
  467. raise UseInferenceDefault
  468. names = []
  469. for elt in node.args[1].elts:
  470. if not isinstance(elt, (nodes.List, nodes.Tuple)):
  471. raise UseInferenceDefault
  472. if len(elt.elts) != 2:
  473. raise UseInferenceDefault
  474. names.append(elt.elts[0].as_string())
  475. typename = node.args[0].as_string()
  476. if names:
  477. field_names = f"({','.join(names)},)"
  478. else:
  479. field_names = "''"
  480. node = extract_node(f"namedtuple({typename}, {field_names})")
  481. return infer_named_tuple(node, context)
  482. def _is_enum_subclass(cls: astroid.ClassDef) -> bool:
  483. """Return whether cls is a subclass of an Enum."""
  484. try:
  485. return any(
  486. klass.name in ENUM_BASE_NAMES
  487. and getattr(klass.root(), "name", None) == "enum"
  488. for klass in cls.mro()
  489. )
  490. except MroError:
  491. return False
# Register the inference tips and transforms defined above. Each statement
# constructs ``AstroidManager()`` afresh.
# NOTE(review): this relies on manager instances sharing their transform
# registry — confirm against astroid.manager.

# ``namedtuple(...)`` / ``collections.namedtuple(...)`` calls.
AstroidManager().register_transform(
    nodes.Call, inference_tip(infer_named_tuple), _looks_like_namedtuple
)
# Functional ``Enum(...)`` calls.
AstroidManager().register_transform(
    nodes.Call, inference_tip(infer_enum), _looks_like_enum
)
# Classes deriving from a stdlib enum: infer_enum_class is applied directly
# (not wrapped in inference_tip) and mutates the ClassDef.
AstroidManager().register_transform(
    nodes.ClassDef, infer_enum_class, predicate=_is_enum_subclass
)
# ``class X(NamedTuple): ...`` definitions.
AstroidManager().register_transform(
    nodes.ClassDef, inference_tip(infer_typing_namedtuple_class), _has_namedtuple_base
)
# ``typing.NamedTuple`` itself when the typing module defines it as a function.
AstroidManager().register_transform(
    nodes.FunctionDef,
    inference_tip(infer_typing_namedtuple_function),
    lambda node: node.name == "NamedTuple"
    and getattr(node.root(), "name", None) == "typing",
)
# Functional ``NamedTuple("X", [(name, type), ...])`` calls.
AstroidManager().register_transform(
    nodes.Call, inference_tip(infer_typing_namedtuple), _looks_like_typing_namedtuple
)
  510. AstroidManager().register_transform(
  511. nodes.Call, inference_tip(infer_typing_namedtuple), _looks_like_typing_namedtuple
  512. )