transforms.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. # Copyright (c) 2015-2016, 2018 Claudiu Popa <pcmanticore@gmail.com>
  2. # Copyright (c) 2016 Ceridwen <ceridwenv@gmail.com>
  3. # Copyright (c) 2018 Nick Drozd <nicholasdrozd@gmail.com>
  4. # Copyright (c) 2021 Pierre Sassoulas <pierre.sassoulas@gmail.com>
  5. # Copyright (c) 2021 David Liu <david@cs.toronto.edu>
  6. # Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com>
  7. # Copyright (c) 2021 Andrew Haigh <hello@nelf.in>
  8. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  9. # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
  10. import collections
  11. from functools import lru_cache
  12. from astroid.context import _invalidate_cache
  13. class TransformVisitor:
  14. """A visitor for handling transforms.
  15. The standard approach of using it is to call
  16. :meth:`~visit` with an *astroid* module and the class
  17. will take care of the rest, walking the tree and running the
  18. transforms for each encountered node.
  19. """
  20. TRANSFORM_MAX_CACHE_SIZE = 10000
  21. def __init__(self):
  22. self.transforms = collections.defaultdict(list)
  23. @lru_cache(maxsize=TRANSFORM_MAX_CACHE_SIZE)
  24. def _transform(self, node):
  25. """Call matching transforms for the given node if any and return the
  26. transformed node.
  27. """
  28. cls = node.__class__
  29. if cls not in self.transforms:
  30. # no transform registered for this class of node
  31. return node
  32. transforms = self.transforms[cls]
  33. for transform_func, predicate in transforms:
  34. if predicate is None or predicate(node):
  35. ret = transform_func(node)
  36. # if the transformation function returns something, it's
  37. # expected to be a replacement for the node
  38. if ret is not None:
  39. _invalidate_cache()
  40. node = ret
  41. if ret.__class__ != cls:
  42. # Can no longer apply the rest of the transforms.
  43. break
  44. return node
  45. def _visit(self, node):
  46. if hasattr(node, "_astroid_fields"):
  47. for name in node._astroid_fields:
  48. value = getattr(node, name)
  49. visited = self._visit_generic(value)
  50. if visited != value:
  51. setattr(node, name, visited)
  52. return self._transform(node)
  53. def _visit_generic(self, node):
  54. if isinstance(node, list):
  55. return [self._visit_generic(child) for child in node]
  56. if isinstance(node, tuple):
  57. return tuple(self._visit_generic(child) for child in node)
  58. if not node or isinstance(node, str):
  59. return node
  60. return self._visit(node)
  61. def register_transform(self, node_class, transform, predicate=None):
  62. """Register `transform(node)` function to be applied on the given
  63. astroid's `node_class` if `predicate` is None or returns true
  64. when called with the node as argument.
  65. The transform function may return a value which is then used to
  66. substitute the original node in the tree.
  67. """
  68. self.transforms[node_class].append((transform, predicate))
  69. def unregister_transform(self, node_class, transform, predicate=None):
  70. """Unregister the given transform."""
  71. self.transforms[node_class].remove((transform, predicate))
  72. def visit(self, module):
  73. """Walk the given astroid *tree* and transform each encountered node
  74. Only the nodes which have transforms registered will actually
  75. be replaced or changed.
  76. """
  77. return self._visit(module)