__init__.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. """Support for presenting detailed information in failing assertions."""
  2. import sys
  3. from typing import Any
  4. from typing import Generator
  5. from typing import List
  6. from typing import Optional
  7. from typing import TYPE_CHECKING
  8. from _pytest.assertion import rewrite
  9. from _pytest.assertion import truncate
  10. from _pytest.assertion import util
  11. from _pytest.assertion.rewrite import assertstate_key
  12. from _pytest.config import Config
  13. from _pytest.config import hookimpl
  14. from _pytest.config.argparsing import Parser
  15. from _pytest.nodes import Item
  16. if TYPE_CHECKING:
  17. from _pytest.main import Session
  18. def pytest_addoption(parser: Parser) -> None:
  19. group = parser.getgroup("debugconfig")
  20. group.addoption(
  21. "--assert",
  22. action="store",
  23. dest="assertmode",
  24. choices=("rewrite", "plain"),
  25. default="rewrite",
  26. metavar="MODE",
  27. help=(
  28. "Control assertion debugging tools.\n"
  29. "'plain' performs no assertion debugging.\n"
  30. "'rewrite' (the default) rewrites assert statements in test modules"
  31. " on import to provide assert expression information."
  32. ),
  33. )
  34. parser.addini(
  35. "enable_assertion_pass_hook",
  36. type="bool",
  37. default=False,
  38. help="Enables the pytest_assertion_pass hook."
  39. "Make sure to delete any previously generated pyc cache files.",
  40. )
  41. def register_assert_rewrite(*names: str) -> None:
  42. """Register one or more module names to be rewritten on import.
  43. This function will make sure that this module or all modules inside
  44. the package will get their assert statements rewritten.
  45. Thus you should make sure to call this before the module is
  46. actually imported, usually in your __init__.py if you are a plugin
  47. using a package.
  48. :raises TypeError: If the given module names are not strings.
  49. """
  50. for name in names:
  51. if not isinstance(name, str):
  52. msg = "expected module names as *args, got {0} instead" # type: ignore[unreachable]
  53. raise TypeError(msg.format(repr(names)))
  54. for hook in sys.meta_path:
  55. if isinstance(hook, rewrite.AssertionRewritingHook):
  56. importhook = hook
  57. break
  58. else:
  59. # TODO(typing): Add a protocol for mark_rewrite() and use it
  60. # for importhook and for PytestPluginManager.rewrite_hook.
  61. importhook = DummyRewriteHook() # type: ignore
  62. importhook.mark_rewrite(*names)
  63. class DummyRewriteHook:
  64. """A no-op import hook for when rewriting is disabled."""
  65. def mark_rewrite(self, *names: str) -> None:
  66. pass
  67. class AssertionState:
  68. """State for the assertion plugin."""
  69. def __init__(self, config: Config, mode) -> None:
  70. self.mode = mode
  71. self.trace = config.trace.root.get("assertion")
  72. self.hook: Optional[rewrite.AssertionRewritingHook] = None
  73. def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
  74. """Try to install the rewrite hook, raise SystemError if it fails."""
  75. config.stash[assertstate_key] = AssertionState(config, "rewrite")
  76. config.stash[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
  77. sys.meta_path.insert(0, hook)
  78. config.stash[assertstate_key].trace("installed rewrite import hook")
  79. def undo() -> None:
  80. hook = config.stash[assertstate_key].hook
  81. if hook is not None and hook in sys.meta_path:
  82. sys.meta_path.remove(hook)
  83. config.add_cleanup(undo)
  84. return hook
  85. def pytest_collection(session: "Session") -> None:
  86. # This hook is only called when test modules are collected
  87. # so for example not in the managing process of pytest-xdist
  88. # (which does not collect test modules).
  89. assertstate = session.config.stash.get(assertstate_key, None)
  90. if assertstate:
  91. if assertstate.hook is not None:
  92. assertstate.hook.set_session(session)
  93. @hookimpl(tryfirst=True, hookwrapper=True)
  94. def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
  95. """Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks.
  96. The rewrite module will use util._reprcompare if it exists to use custom
  97. reporting via the pytest_assertrepr_compare hook. This sets up this custom
  98. comparison for the test.
  99. """
  100. ihook = item.ihook
  101. def callbinrepr(op, left: object, right: object) -> Optional[str]:
  102. """Call the pytest_assertrepr_compare hook and prepare the result.
  103. This uses the first result from the hook and then ensures the
  104. following:
  105. * Overly verbose explanations are truncated unless configured otherwise
  106. (eg. if running in verbose mode).
  107. * Embedded newlines are escaped to help util.format_explanation()
  108. later.
  109. * If the rewrite mode is used embedded %-characters are replaced
  110. to protect later % formatting.
  111. The result can be formatted by util.format_explanation() for
  112. pretty printing.
  113. """
  114. hook_result = ihook.pytest_assertrepr_compare(
  115. config=item.config, op=op, left=left, right=right
  116. )
  117. for new_expl in hook_result:
  118. if new_expl:
  119. new_expl = truncate.truncate_if_required(new_expl, item)
  120. new_expl = [line.replace("\n", "\\n") for line in new_expl]
  121. res = "\n~".join(new_expl)
  122. if item.config.getvalue("assertmode") == "rewrite":
  123. res = res.replace("%", "%%")
  124. return res
  125. return None
  126. saved_assert_hooks = util._reprcompare, util._assertion_pass
  127. util._reprcompare = callbinrepr
  128. util._config = item.config
  129. if ihook.pytest_assertion_pass.get_hookimpls():
  130. def call_assertion_pass_hook(lineno: int, orig: str, expl: str) -> None:
  131. ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl)
  132. util._assertion_pass = call_assertion_pass_hook
  133. yield
  134. util._reprcompare, util._assertion_pass = saved_assert_hooks
  135. util._config = None
  136. def pytest_sessionfinish(session: "Session") -> None:
  137. assertstate = session.config.stash.get(assertstate_key, None)
  138. if assertstate:
  139. if assertstate.hook is not None:
  140. assertstate.hook.set_session(None)
  141. def pytest_assertrepr_compare(
  142. config: Config, op: str, left: Any, right: Any
  143. ) -> Optional[List[str]]:
  144. return util.assertrepr_compare(config=config, op=op, left=left, right=right)