inference_tip.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
  3. """Transform utilities (filters and decorator)"""
  4. import typing
  5. import wrapt
  6. from astroid.exceptions import InferenceOverwriteError
  7. from astroid.nodes import NodeNG
  8. InferFn = typing.Callable[..., typing.Any]
  9. _cache: typing.Dict[typing.Tuple[InferFn, NodeNG], typing.Any] = {}
  10. def clear_inference_tip_cache():
  11. """Clear the inference tips cache."""
  12. _cache.clear()
  13. @wrapt.decorator
  14. def _inference_tip_cached(func, instance, args, kwargs):
  15. """Cache decorator used for inference tips"""
  16. node = args[0]
  17. try:
  18. result = _cache[func, node]
  19. except KeyError:
  20. result = _cache[func, node] = list(func(*args, **kwargs))
  21. return iter(result)
  22. def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) -> InferFn:
  23. """Given an instance specific inference function, return a function to be
  24. given to AstroidManager().register_transform to set this inference function.
  25. :param bool raise_on_overwrite: Raise an `InferenceOverwriteError`
  26. if the inference tip will overwrite another. Used for debugging
  27. Typical usage
  28. .. sourcecode:: python
  29. AstroidManager().register_transform(Call, inference_tip(infer_named_tuple),
  30. predicate)
  31. .. Note::
  32. Using an inference tip will override
  33. any previously set inference tip for the given
  34. node. Use a predicate in the transform to prevent
  35. excess overwrites.
  36. """
  37. def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG:
  38. if (
  39. raise_on_overwrite
  40. and node._explicit_inference is not None
  41. and node._explicit_inference is not infer_function
  42. ):
  43. raise InferenceOverwriteError(
  44. "Inference already set to {existing_inference}. "
  45. "Trying to overwrite with {new_inference} for {node}".format(
  46. existing_inference=infer_function,
  47. new_inference=node._explicit_inference,
  48. node=node,
  49. )
  50. )
  51. # pylint: disable=no-value-for-parameter
  52. node._explicit_inference = _inference_tip_cached(infer_function)
  53. return node
  54. return transform