brain_random.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
  3. import random
  4. from astroid import helpers
  5. from astroid.exceptions import UseInferenceDefault
  6. from astroid.inference_tip import inference_tip
  7. from astroid.manager import AstroidManager
  8. from astroid.nodes.node_classes import (
  9. Attribute,
  10. Call,
  11. Const,
  12. EvaluatedObject,
  13. List,
  14. Name,
  15. Set,
  16. Tuple,
  17. )
  18. ACCEPTED_ITERABLES_FOR_SAMPLE = (List, Set, Tuple)
  19. def _clone_node_with_lineno(node, parent, lineno):
  20. if isinstance(node, EvaluatedObject):
  21. node = node.original
  22. cls = node.__class__
  23. other_fields = node._other_fields
  24. _astroid_fields = node._astroid_fields
  25. init_params = {"lineno": lineno, "col_offset": node.col_offset, "parent": parent}
  26. postinit_params = {param: getattr(node, param) for param in _astroid_fields}
  27. if other_fields:
  28. init_params.update({param: getattr(node, param) for param in other_fields})
  29. new_node = cls(**init_params)
  30. if hasattr(node, "postinit") and _astroid_fields:
  31. new_node.postinit(**postinit_params)
  32. return new_node
  33. def infer_random_sample(node, context=None):
  34. if len(node.args) != 2:
  35. raise UseInferenceDefault
  36. length = node.args[1]
  37. if not isinstance(length, Const):
  38. raise UseInferenceDefault
  39. if not isinstance(length.value, int):
  40. raise UseInferenceDefault
  41. inferred_sequence = helpers.safe_infer(node.args[0], context=context)
  42. if not inferred_sequence:
  43. raise UseInferenceDefault
  44. if not isinstance(inferred_sequence, ACCEPTED_ITERABLES_FOR_SAMPLE):
  45. raise UseInferenceDefault
  46. if length.value > len(inferred_sequence.elts):
  47. # In this case, this will raise a ValueError
  48. raise UseInferenceDefault
  49. try:
  50. elts = random.sample(inferred_sequence.elts, length.value)
  51. except ValueError as exc:
  52. raise UseInferenceDefault from exc
  53. new_node = List(lineno=node.lineno, col_offset=node.col_offset, parent=node.scope())
  54. new_elts = [
  55. _clone_node_with_lineno(elt, parent=new_node, lineno=new_node.lineno)
  56. for elt in elts
  57. ]
  58. new_node.postinit(new_elts)
  59. return iter((new_node,))
  60. def _looks_like_random_sample(node):
  61. func = node.func
  62. if isinstance(func, Attribute):
  63. return func.attrname == "sample"
  64. if isinstance(func, Name):
  65. return func.name == "sample"
  66. return False
  67. AstroidManager().register_transform(
  68. Call, inference_tip(infer_random_sample), _looks_like_random_sample
  69. )