brain_nose.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright (c) 2015-2016, 2018, 2020 Claudiu Popa <pcmanticore@gmail.com>
  2. # Copyright (c) 2016 Ceridwen <ceridwenv@gmail.com>
  3. # Copyright (c) 2020-2021 hippo91 <guillaume.peillex@gmail.com>
  4. # Copyright (c) 2021 Pierre Sassoulas <pierre.sassoulas@gmail.com>
  5. # Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com>
  6. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  7. # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
  8. """Hooks for nose library."""
  9. import re
  10. import textwrap
  11. import astroid.builder
  12. from astroid.brain.helpers import register_module_extender
  13. from astroid.exceptions import InferenceError
  14. from astroid.manager import AstroidManager
  15. _BUILDER = astroid.builder.AstroidBuilder(AstroidManager())
  16. CAPITALS = re.compile("([A-Z])")
  17. def _pep8(name, caps=CAPITALS):
  18. return caps.sub(lambda m: "_" + m.groups()[0].lower(), name)
  19. def _nose_tools_functions():
  20. """Get an iterator of names and bound methods."""
  21. module = _BUILDER.string_build(
  22. textwrap.dedent(
  23. """
  24. import unittest
  25. class Test(unittest.TestCase):
  26. pass
  27. a = Test()
  28. """
  29. )
  30. )
  31. try:
  32. case = next(module["a"].infer())
  33. except (InferenceError, StopIteration):
  34. return
  35. for method in case.methods():
  36. if method.name.startswith("assert") and "_" not in method.name:
  37. pep8_name = _pep8(method.name)
  38. yield pep8_name, astroid.BoundMethod(method, case)
  39. if method.name == "assertEqual":
  40. # nose also exports assert_equals.
  41. yield "assert_equals", astroid.BoundMethod(method, case)
  42. def _nose_tools_transform(node):
  43. for method_name, method in _nose_tools_functions():
  44. node.locals[method_name] = [method]
  45. def _nose_tools_trivial_transform():
  46. """Custom transform for the nose.tools module."""
  47. stub = _BUILDER.string_build("""__all__ = []""")
  48. all_entries = ["ok_", "eq_"]
  49. for pep8_name, method in _nose_tools_functions():
  50. all_entries.append(pep8_name)
  51. stub[pep8_name] = method
  52. # Update the __all__ variable, since nose.tools
  53. # does this manually with .append.
  54. all_assign = stub["__all__"].parent
  55. all_object = astroid.List(all_entries)
  56. all_object.parent = all_assign
  57. all_assign.value = all_object
  58. return stub
  59. register_module_extender(
  60. AstroidManager(), "nose.tools.trivial", _nose_tools_trivial_transform
  61. )
  62. AstroidManager().register_transform(
  63. astroid.Module, _nose_tools_transform, lambda n: n.name == "nose.tools"
  64. )