_reloader.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. import fnmatch
  2. import os
  3. import subprocess
  4. import sys
  5. import threading
  6. import time
  7. import typing as t
  8. from itertools import chain
  9. from pathlib import PurePath
  10. from ._internal import _log
  11. # The various system prefixes where imports are found. Base values are
  12. # different when running in a virtualenv. The stat reloader won't scan
  13. # these directories, it would be too inefficient.
  14. prefix = {sys.prefix, sys.base_prefix, sys.exec_prefix, sys.base_exec_prefix}
  15. if hasattr(sys, "real_prefix"):
  16. # virtualenv < 20
  17. prefix.add(sys.real_prefix) # type: ignore
  18. _ignore_prefixes = tuple(prefix)
  19. del prefix
  20. def _iter_module_paths() -> t.Iterator[str]:
  21. """Find the filesystem paths associated with imported modules."""
  22. # List is in case the value is modified by the app while updating.
  23. for module in list(sys.modules.values()):
  24. name = getattr(module, "__file__", None)
  25. if name is None:
  26. continue
  27. while not os.path.isfile(name):
  28. # Zip file, find the base file without the module path.
  29. old = name
  30. name = os.path.dirname(name)
  31. if name == old: # skip if it was all directories somehow
  32. break
  33. else:
  34. yield name
  35. def _remove_by_pattern(paths: t.Set[str], exclude_patterns: t.Set[str]) -> None:
  36. for pattern in exclude_patterns:
  37. paths.difference_update(fnmatch.filter(paths, pattern))
  38. def _find_stat_paths(
  39. extra_files: t.Set[str], exclude_patterns: t.Set[str]
  40. ) -> t.Iterable[str]:
  41. """Find paths for the stat reloader to watch. Returns imported
  42. module files, Python files under non-system paths. Extra files and
  43. Python files under extra directories can also be scanned.
  44. System paths have to be excluded for efficiency. Non-system paths,
  45. such as a project root or ``sys.path.insert``, should be the paths
  46. of interest to the user anyway.
  47. """
  48. paths = set()
  49. for path in chain(list(sys.path), extra_files):
  50. path = os.path.abspath(path)
  51. if os.path.isfile(path):
  52. # zip file on sys.path, or extra file
  53. paths.add(path)
  54. for root, dirs, files in os.walk(path):
  55. # Ignore system prefixes for efficience. Don't scan
  56. # __pycache__, it will have a py or pyc module at the import
  57. # path. As an optimization, ignore .git and .hg since
  58. # nothing interesting will be there.
  59. if root.startswith(_ignore_prefixes) or os.path.basename(root) in {
  60. "__pycache__",
  61. ".git",
  62. ".hg",
  63. }:
  64. dirs.clear()
  65. continue
  66. for name in files:
  67. if name.endswith((".py", ".pyc")):
  68. paths.add(os.path.join(root, name))
  69. paths.update(_iter_module_paths())
  70. _remove_by_pattern(paths, exclude_patterns)
  71. return paths
  72. def _find_watchdog_paths(
  73. extra_files: t.Set[str], exclude_patterns: t.Set[str]
  74. ) -> t.Iterable[str]:
  75. """Find paths for the stat reloader to watch. Looks at the same
  76. sources as the stat reloader, but watches everything under
  77. directories instead of individual files.
  78. """
  79. dirs = set()
  80. for name in chain(list(sys.path), extra_files):
  81. name = os.path.abspath(name)
  82. if os.path.isfile(name):
  83. name = os.path.dirname(name)
  84. dirs.add(name)
  85. for name in _iter_module_paths():
  86. dirs.add(os.path.dirname(name))
  87. _remove_by_pattern(dirs, exclude_patterns)
  88. return _find_common_roots(dirs)
  89. def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]:
  90. root: t.Dict[str, dict] = {}
  91. for chunks in sorted((PurePath(x).parts for x in paths), key=len, reverse=True):
  92. node = root
  93. for chunk in chunks:
  94. node = node.setdefault(chunk, {})
  95. node.clear()
  96. rv = set()
  97. def _walk(node: t.Mapping[str, dict], path: t.Tuple[str, ...]) -> None:
  98. for prefix, child in node.items():
  99. _walk(child, path + (prefix,))
  100. if not node:
  101. rv.add(os.path.join(*path))
  102. _walk(root, ())
  103. return rv
  104. def _get_args_for_reloading() -> t.List[str]:
  105. """Determine how the script was executed, and return the args needed
  106. to execute it again in a new process.
  107. """
  108. rv = [sys.executable]
  109. py_script = sys.argv[0]
  110. args = sys.argv[1:]
  111. # Need to look at main module to determine how it was executed.
  112. __main__ = sys.modules["__main__"]
  113. # The value of __package__ indicates how Python was called. It may
  114. # not exist if a setuptools script is installed as an egg. It may be
  115. # set incorrectly for entry points created with pip on Windows.
  116. if getattr(__main__, "__package__", None) is None or (
  117. os.name == "nt"
  118. and __main__.__package__ == ""
  119. and not os.path.exists(py_script)
  120. and os.path.exists(f"{py_script}.exe")
  121. ):
  122. # Executed a file, like "python app.py".
  123. py_script = os.path.abspath(py_script)
  124. if os.name == "nt":
  125. # Windows entry points have ".exe" extension and should be
  126. # called directly.
  127. if not os.path.exists(py_script) and os.path.exists(f"{py_script}.exe"):
  128. py_script += ".exe"
  129. if (
  130. os.path.splitext(sys.executable)[1] == ".exe"
  131. and os.path.splitext(py_script)[1] == ".exe"
  132. ):
  133. rv.pop(0)
  134. rv.append(py_script)
  135. else:
  136. # Executed a module, like "python -m werkzeug.serving".
  137. if sys.argv[0] == "-m":
  138. # Flask works around previous behavior by putting
  139. # "-m flask" in sys.argv.
  140. # TODO remove this once Flask no longer misbehaves
  141. args = sys.argv
  142. else:
  143. if os.path.isfile(py_script):
  144. # Rewritten by Python from "-m script" to "/path/to/script.py".
  145. py_module = t.cast(str, __main__.__package__)
  146. name = os.path.splitext(os.path.basename(py_script))[0]
  147. if name != "__main__":
  148. py_module += f".{name}"
  149. else:
  150. # Incorrectly rewritten by pydevd debugger from "-m script" to "script".
  151. py_module = py_script
  152. rv.extend(("-m", py_module.lstrip(".")))
  153. rv.extend(args)
  154. return rv
  155. class ReloaderLoop:
  156. name = ""
  157. def __init__(
  158. self,
  159. extra_files: t.Optional[t.Iterable[str]] = None,
  160. exclude_patterns: t.Optional[t.Iterable[str]] = None,
  161. interval: t.Union[int, float] = 1,
  162. ) -> None:
  163. self.extra_files: t.Set[str] = {os.path.abspath(x) for x in extra_files or ()}
  164. self.exclude_patterns: t.Set[str] = set(exclude_patterns or ())
  165. self.interval = interval
  166. def __enter__(self) -> "ReloaderLoop":
  167. """Do any setup, then run one step of the watch to populate the
  168. initial filesystem state.
  169. """
  170. self.run_step()
  171. return self
  172. def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore
  173. """Clean up any resources associated with the reloader."""
  174. pass
  175. def run(self) -> None:
  176. """Continually run the watch step, sleeping for the configured
  177. interval after each step.
  178. """
  179. while True:
  180. self.run_step()
  181. time.sleep(self.interval)
  182. def run_step(self) -> None:
  183. """Run one step for watching the filesystem. Called once to set
  184. up initial state, then repeatedly to update it.
  185. """
  186. pass
  187. def restart_with_reloader(self) -> int:
  188. """Spawn a new Python interpreter with the same arguments as the
  189. current one, but running the reloader thread.
  190. """
  191. while True:
  192. _log("info", f" * Restarting with {self.name}")
  193. args = _get_args_for_reloading()
  194. new_environ = os.environ.copy()
  195. new_environ["WERKZEUG_RUN_MAIN"] = "true"
  196. exit_code = subprocess.call(args, env=new_environ, close_fds=False)
  197. if exit_code != 3:
  198. return exit_code
  199. def trigger_reload(self, filename: str) -> None:
  200. self.log_reload(filename)
  201. sys.exit(3)
  202. def log_reload(self, filename: str) -> None:
  203. filename = os.path.abspath(filename)
  204. _log("info", f" * Detected change in {filename!r}, reloading")
  205. class StatReloaderLoop(ReloaderLoop):
  206. name = "stat"
  207. def __enter__(self) -> ReloaderLoop:
  208. self.mtimes: t.Dict[str, float] = {}
  209. return super().__enter__()
  210. def run_step(self) -> None:
  211. for name in chain(_find_stat_paths(self.extra_files, self.exclude_patterns)):
  212. try:
  213. mtime = os.stat(name).st_mtime
  214. except OSError:
  215. continue
  216. old_time = self.mtimes.get(name)
  217. if old_time is None:
  218. self.mtimes[name] = mtime
  219. continue
  220. if mtime > old_time:
  221. self.trigger_reload(name)
  222. class WatchdogReloaderLoop(ReloaderLoop):
  223. def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
  224. from watchdog.observers import Observer
  225. from watchdog.events import PatternMatchingEventHandler
  226. super().__init__(*args, **kwargs)
  227. trigger_reload = self.trigger_reload
  228. class EventHandler(PatternMatchingEventHandler): # type: ignore
  229. def on_any_event(self, event): # type: ignore
  230. trigger_reload(event.src_path)
  231. reloader_name = Observer.__name__.lower()
  232. if reloader_name.endswith("observer"):
  233. reloader_name = reloader_name[:-8]
  234. self.name = f"watchdog ({reloader_name})"
  235. self.observer = Observer()
  236. # Extra patterns can be non-Python files, match them in addition
  237. # to all Python files in default and extra directories. Ignore
  238. # __pycache__ since a change there will always have a change to
  239. # the source file (or initial pyc file) as well. Ignore Git and
  240. # Mercurial internal changes.
  241. extra_patterns = [p for p in self.extra_files if not os.path.isdir(p)]
  242. self.event_handler = EventHandler(
  243. patterns=["*.py", "*.pyc", "*.zip", *extra_patterns],
  244. ignore_patterns=[
  245. "*/__pycache__/*",
  246. "*/.git/*",
  247. "*/.hg/*",
  248. *self.exclude_patterns,
  249. ],
  250. )
  251. self.should_reload = False
  252. def trigger_reload(self, filename: str) -> None:
  253. # This is called inside an event handler, which means throwing
  254. # SystemExit has no effect.
  255. # https://github.com/gorakhargosh/watchdog/issues/294
  256. self.should_reload = True
  257. self.log_reload(filename)
  258. def __enter__(self) -> ReloaderLoop:
  259. self.watches: t.Dict[str, t.Any] = {}
  260. self.observer.start()
  261. return super().__enter__()
  262. def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore
  263. self.observer.stop()
  264. self.observer.join()
  265. def run(self) -> None:
  266. while not self.should_reload:
  267. self.run_step()
  268. time.sleep(self.interval)
  269. sys.exit(3)
  270. def run_step(self) -> None:
  271. to_delete = set(self.watches)
  272. for path in _find_watchdog_paths(self.extra_files, self.exclude_patterns):
  273. if path not in self.watches:
  274. try:
  275. self.watches[path] = self.observer.schedule(
  276. self.event_handler, path, recursive=True
  277. )
  278. except OSError:
  279. # Clear this path from list of watches We don't want
  280. # the same error message showing again in the next
  281. # iteration.
  282. self.watches[path] = None
  283. to_delete.discard(path)
  284. for path in to_delete:
  285. watch = self.watches.pop(path, None)
  286. if watch is not None:
  287. self.observer.unschedule(watch)
  288. reloader_loops: t.Dict[str, t.Type[ReloaderLoop]] = {
  289. "stat": StatReloaderLoop,
  290. "watchdog": WatchdogReloaderLoop,
  291. }
  292. try:
  293. __import__("watchdog.observers")
  294. except ImportError:
  295. reloader_loops["auto"] = reloader_loops["stat"]
  296. else:
  297. reloader_loops["auto"] = reloader_loops["watchdog"]
  298. def ensure_echo_on() -> None:
  299. """Ensure that echo mode is enabled. Some tools such as PDB disable
  300. it which causes usability issues after a reload."""
  301. # tcgetattr will fail if stdin isn't a tty
  302. if sys.stdin is None or not sys.stdin.isatty():
  303. return
  304. try:
  305. import termios
  306. except ImportError:
  307. return
  308. attributes = termios.tcgetattr(sys.stdin)
  309. if not attributes[3] & termios.ECHO:
  310. attributes[3] |= termios.ECHO
  311. termios.tcsetattr(sys.stdin, termios.TCSANOW, attributes)
  312. def run_with_reloader(
  313. main_func: t.Callable[[], None],
  314. extra_files: t.Optional[t.Iterable[str]] = None,
  315. exclude_patterns: t.Optional[t.Iterable[str]] = None,
  316. interval: t.Union[int, float] = 1,
  317. reloader_type: str = "auto",
  318. ) -> None:
  319. """Run the given function in an independent Python interpreter."""
  320. import signal
  321. signal.signal(signal.SIGTERM, lambda *args: sys.exit(0))
  322. reloader = reloader_loops[reloader_type](
  323. extra_files=extra_files, exclude_patterns=exclude_patterns, interval=interval
  324. )
  325. try:
  326. if os.environ.get("WERKZEUG_RUN_MAIN") == "true":
  327. ensure_echo_on()
  328. t = threading.Thread(target=main_func, args=())
  329. t.daemon = True
  330. # Enter the reloader to set up initial state, then start
  331. # the app thread and reloader update loop.
  332. with reloader:
  333. t.start()
  334. reloader.run()
  335. else:
  336. sys.exit(reloader.restart_with_reloader())
  337. except KeyboardInterrupt:
  338. pass