decorators.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. # Copyright (c) 2015-2016, 2018, 2020 Claudiu Popa <pcmanticore@gmail.com>
  2. # Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
  3. # Copyright (c) 2015 Florian Bruhin <me@the-compiler.org>
  4. # Copyright (c) 2016 Derek Gustafson <degustaf@gmail.com>
  5. # Copyright (c) 2018, 2021 Nick Drozd <nicholasdrozd@gmail.com>
  6. # Copyright (c) 2018 Tomas Gavenciak <gavento@ucw.cz>
  7. # Copyright (c) 2018 Ashley Whetter <ashley@awhetter.co.uk>
  8. # Copyright (c) 2018 HoverHell <hoverhell@gmail.com>
  9. # Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com>
  10. # Copyright (c) 2020-2021 hippo91 <guillaume.peillex@gmail.com>
  11. # Copyright (c) 2020 Ram Rachum <ram@rachum.com>
  12. # Copyright (c) 2021 Pierre Sassoulas <pierre.sassoulas@gmail.com>
  13. # Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com>
  14. # Copyright (c) 2021 Daniël van Noord <13665637+DanielNoord@users.noreply.github.com>
  15. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  16. # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
"""A few useful function/method decorators."""
  18. import functools
  19. import inspect
  20. import sys
  21. import warnings
  22. from typing import Callable, TypeVar
  23. import wrapt
  24. from astroid import util
  25. from astroid.context import InferenceContext
  26. from astroid.exceptions import InferenceError
  27. if sys.version_info >= (3, 10):
  28. from typing import ParamSpec
  29. else:
  30. from typing_extensions import ParamSpec
  31. R = TypeVar("R")
  32. P = ParamSpec("P")
  33. @wrapt.decorator
  34. def cached(func, instance, args, kwargs):
  35. """Simple decorator to cache result of method calls without args."""
  36. cache = getattr(instance, "__cache", None)
  37. if cache is None:
  38. instance.__cache = cache = {}
  39. try:
  40. return cache[func]
  41. except KeyError:
  42. cache[func] = result = func(*args, **kwargs)
  43. return result
  44. class cachedproperty:
  45. """Provides a cached property equivalent to the stacking of
  46. @cached and @property, but more efficient.
  47. After first usage, the <property_name> becomes part of the object's
  48. __dict__. Doing:
  49. del obj.<property_name> empties the cache.
  50. Idea taken from the pyramid_ framework and the mercurial_ project.
  51. .. _pyramid: http://pypi.python.org/pypi/pyramid
  52. .. _mercurial: http://pypi.python.org/pypi/Mercurial
  53. """
  54. __slots__ = ("wrapped",)
  55. def __init__(self, wrapped):
  56. try:
  57. wrapped.__name__
  58. except AttributeError as exc:
  59. raise TypeError(f"{wrapped} must have a __name__ attribute") from exc
  60. self.wrapped = wrapped
  61. @property
  62. def __doc__(self):
  63. doc = getattr(self.wrapped, "__doc__", None)
  64. return "<wrapped by the cachedproperty decorator>%s" % (
  65. "\n%s" % doc if doc else ""
  66. )
  67. def __get__(self, inst, objtype=None):
  68. if inst is None:
  69. return self
  70. val = self.wrapped(inst)
  71. setattr(inst, self.wrapped.__name__, val)
  72. return val
  73. def path_wrapper(func):
  74. """return the given infer function wrapped to handle the path
  75. Used to stop inference if the node has already been looked
  76. at for a given `InferenceContext` to prevent infinite recursion
  77. """
  78. @functools.wraps(func)
  79. def wrapped(node, context=None, _func=func, **kwargs):
  80. """wrapper function handling context"""
  81. if context is None:
  82. context = InferenceContext()
  83. if context.push(node):
  84. return
  85. yielded = set()
  86. for res in _func(node, context, **kwargs):
  87. # unproxy only true instance, not const, tuple, dict...
  88. if res.__class__.__name__ == "Instance":
  89. ares = res._proxied
  90. else:
  91. ares = res
  92. if ares not in yielded:
  93. yield res
  94. yielded.add(ares)
  95. return wrapped
  96. @wrapt.decorator
  97. def yes_if_nothing_inferred(func, instance, args, kwargs):
  98. generator = func(*args, **kwargs)
  99. try:
  100. yield next(generator)
  101. except StopIteration:
  102. # generator is empty
  103. yield util.Uninferable
  104. return
  105. yield from generator
  106. @wrapt.decorator
  107. def raise_if_nothing_inferred(func, instance, args, kwargs):
  108. generator = func(*args, **kwargs)
  109. try:
  110. yield next(generator)
  111. except StopIteration as error:
  112. # generator is empty
  113. if error.args:
  114. # pylint: disable=not-a-mapping
  115. raise InferenceError(**error.args[0]) from error
  116. raise InferenceError(
  117. "StopIteration raised without any error information."
  118. ) from error
  119. yield from generator
  120. def deprecate_default_argument_values(
  121. astroid_version: str = "3.0", **arguments: str
  122. ) -> Callable[[Callable[P, R]], Callable[P, R]]:
  123. """Decorator which emitts a DeprecationWarning if any arguments specified
  124. are None or not passed at all.
  125. Arguments should be a key-value mapping, with the key being the argument to check
  126. and the value being a type annotation as string for the value of the argument.
  127. """
  128. # Helpful links
  129. # Decorator for DeprecationWarning: https://stackoverflow.com/a/49802489
  130. # Typing of stacked decorators: https://stackoverflow.com/a/68290080
  131. def deco(func: Callable[P, R]) -> Callable[P, R]:
  132. """Decorator function."""
  133. @functools.wraps(func)
  134. def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
  135. """Emit DeprecationWarnings if conditions are met."""
  136. keys = list(inspect.signature(func).parameters.keys())
  137. for arg, type_annotation in arguments.items():
  138. try:
  139. index = keys.index(arg)
  140. except ValueError:
  141. raise Exception(
  142. f"Can't find argument '{arg}' for '{args[0].__class__.__qualname__}'"
  143. ) from None
  144. if (
  145. # Check kwargs
  146. # - if found, check it's not None
  147. (arg in kwargs and kwargs[arg] is None)
  148. # Check args
  149. # - make sure not in kwargs
  150. # - len(args) needs to be long enough, if too short
  151. # arg can't be in args either
  152. # - args[index] should not be None
  153. or arg not in kwargs
  154. and (
  155. index == -1
  156. or len(args) <= index
  157. or (len(args) > index and args[index] is None)
  158. )
  159. ):
  160. warnings.warn(
  161. f"'{arg}' will be a required argument for "
  162. f"'{args[0].__class__.__qualname__}.{func.__name__}' in astroid {astroid_version} "
  163. f"('{arg}' should be of type: '{type_annotation}')",
  164. DeprecationWarning,
  165. )
  166. return func(*args, **kwargs)
  167. return wrapper
  168. return deco