sync.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. import asyncio.coroutines
  2. import functools
  3. import inspect
  4. import os
  5. import sys
  6. import threading
  7. import warnings
  8. import weakref
  9. from concurrent.futures import Future, ThreadPoolExecutor
  10. from typing import Any, Callable, Dict, Optional, overload
  11. from .compatibility import get_running_loop
  12. from .current_thread_executor import CurrentThreadExecutor
  13. from .local import Local
  14. if sys.version_info >= (3, 7):
  15. import contextvars
  16. else:
  17. contextvars = None
  18. def _restore_context(context):
  19. # Check for changes in contextvars, and set them to the current
  20. # context for downstream consumers
  21. for cvar in context:
  22. try:
  23. if cvar.get() != context.get(cvar):
  24. cvar.set(context.get(cvar))
  25. except LookupError:
  26. cvar.set(context.get(cvar))
  27. def _iscoroutinefunction_or_partial(func: Any) -> bool:
  28. # Python < 3.8 does not correctly determine partially wrapped
  29. # coroutine functions are coroutine functions, hence the need for
  30. # this to exist. Code taken from CPython.
  31. if sys.version_info >= (3, 8):
  32. return asyncio.iscoroutinefunction(func)
  33. else:
  34. while inspect.ismethod(func):
  35. func = func.__func__
  36. while isinstance(func, functools.partial):
  37. func = func.func
  38. return asyncio.iscoroutinefunction(func)
  39. class ThreadSensitiveContext:
  40. """Async context manager to manage context for thread sensitive mode
  41. This context manager controls which thread pool executor is used when in
  42. thread sensitive mode. By default, a single thread pool executor is shared
  43. within a process.
  44. In Python 3.7+, the ThreadSensitiveContext() context manager may be used to
  45. specify a thread pool per context.
  46. In Python 3.6, usage of this context manager has no effect.
  47. This context manager is re-entrant, so only the outer-most call to
  48. ThreadSensitiveContext will set the context.
  49. Usage:
  50. >>> import time
  51. >>> async with ThreadSensitiveContext():
  52. ... await sync_to_async(time.sleep, 1)()
  53. """
  54. def __init__(self):
  55. self.token = None
  56. if contextvars:
  57. async def __aenter__(self):
  58. try:
  59. SyncToAsync.thread_sensitive_context.get()
  60. except LookupError:
  61. self.token = SyncToAsync.thread_sensitive_context.set(self)
  62. return self
  63. async def __aexit__(self, exc, value, tb):
  64. if not self.token:
  65. return
  66. executor = SyncToAsync.context_to_thread_executor.pop(self, None)
  67. if executor:
  68. executor.shutdown()
  69. SyncToAsync.thread_sensitive_context.reset(self.token)
  70. else:
  71. async def __aenter__(self):
  72. return self
  73. async def __aexit__(self, exc, value, tb):
  74. pass
class AsyncToSync:
    """
    Utility class which turns an awaitable that only works on the thread with
    the event loop into a synchronous callable that works in a subthread.

    If the call stack contains an async loop, the code runs there.
    Otherwise, the code runs in a new loop in a new thread.

    Either way, this thread then pauses and waits to run any thread_sensitive
    code called from further down the call stack using SyncToAsync, before
    finally exiting once the async task returns.
    """

    # Maps launched Tasks to the threads that launched them (for locals impl)
    launch_map: "Dict[asyncio.Task[object], threading.Thread]" = {}

    # Keeps track of which CurrentThreadExecutor to use. This uses an asgiref
    # Local, not a threadlocal, so that tasks can work out what their parent used.
    executors = Local()

    def __init__(self, awaitable, force_new_loop=False):
        """
        Wrap ``awaitable`` (an async callable) so it can be called synchronously.

        The target event loop is resolved once, here, at construction time:
        - ``force_new_loop=True``: no loop is recorded, so every call runs in
          a fresh loop on a new thread.
        - otherwise the running loop of this thread, if any;
        - otherwise the loop stored in SyncToAsync's threadlocal by an
          enclosing SyncToAsync call, provided it was recorded by this same
          process (guards against use after fork);
        - otherwise None.

        A non-async callable only triggers a warning, not an error.
        """
        if not callable(awaitable) or not _iscoroutinefunction_or_partial(awaitable):
            # Python does not have very reliable detection of async functions
            # (lots of false negatives) so this is just a warning.
            warnings.warn("async_to_sync was passed a non-async-marked callable")
        self.awaitable = awaitable
        try:
            # Expose the awaitable's __self__ (when it is a bound method) on
            # this wrapper as well.
            self.__self__ = self.awaitable.__self__
        except AttributeError:
            pass
        if force_new_loop:
            # They have asked that we always run in a new sub-loop.
            self.main_event_loop = None
        else:
            try:
                self.main_event_loop = get_running_loop()
            except RuntimeError:
                # There's no event loop in this thread. Look for the threadlocal if
                # we're inside SyncToAsync
                main_event_loop_pid = getattr(
                    SyncToAsync.threadlocal, "main_event_loop_pid", None
                )
                # We make sure the parent loop is from the same process - if
                # they've forked, this is not going to be valid any more (#194)
                if main_event_loop_pid and main_event_loop_pid == os.getpid():
                    self.main_event_loop = getattr(
                        SyncToAsync.threadlocal, "main_event_loop", None
                    )
                else:
                    self.main_event_loop = None

    def __call__(self, *args, **kwargs):
        """
        Run the wrapped awaitable with ``args``/``kwargs`` and block until it
        has finished, returning its result (or re-raising its exception via
        ``call_result.result()``).

        While blocked, this thread services a fresh CurrentThreadExecutor, so
        thread-sensitive SyncToAsync calls made further down can run here.

        Raises RuntimeError if called from a thread with a running event loop.
        """
        # You can't call AsyncToSync from a thread with a running event loop
        try:
            event_loop = get_running_loop()
        except RuntimeError:
            pass
        else:
            if event_loop.is_running():
                raise RuntimeError(
                    "You cannot use AsyncToSync in the same thread as an async event loop - "
                    "just await the async function directly."
                )

        if contextvars is not None:
            # Wrapping context in list so it can be reassigned from within
            # `main_wrap`.
            context = [contextvars.copy_context()]
        else:
            context = None

        # Make a future for the return information
        call_result = Future()
        # Get the source thread
        source_thread = threading.current_thread()
        # Make a CurrentThreadExecutor we'll use to idle in this thread - we
        # need one for every sync frame, even if there's one above us in the
        # same thread.
        if hasattr(self.executors, "current"):
            old_current_executor = self.executors.current
        else:
            old_current_executor = None
        current_executor = CurrentThreadExecutor()
        self.executors.current = current_executor
        # Use call_soon_threadsafe to schedule a synchronous callback on the
        # main event loop's thread if it's there, otherwise make a new loop
        # in this thread.
        try:
            # sys.exc_info() is captured here so main_wrap can re-raise it and
            # the awaitable sees the same "currently handled" exception.
            awaitable = self.main_wrap(
                args, kwargs, call_result, source_thread, sys.exc_info(), context
            )

            if not (self.main_event_loop and self.main_event_loop.is_running()):
                # Make our own event loop - in a new thread - and run inside that.
                loop = asyncio.new_event_loop()
                loop_executor = ThreadPoolExecutor(max_workers=1)
                loop_future = loop_executor.submit(
                    self._run_event_loop, loop, awaitable
                )
                if current_executor:
                    # Run the CurrentThreadExecutor until the future is done
                    current_executor.run_until_future(loop_future)
                # Wait for future and/or allow for exception propagation
                loop_future.result()
            else:
                # Call it inside the existing loop
                self.main_event_loop.call_soon_threadsafe(
                    self.main_event_loop.create_task, awaitable
                )
                if current_executor:
                    # Run the CurrentThreadExecutor until the future is done
                    current_executor.run_until_future(call_result)
        finally:
            # Clean up any executor we were running, restoring the one that
            # belonged to an enclosing sync frame (if any).
            if hasattr(self.executors, "current"):
                del self.executors.current
            if old_current_executor:
                self.executors.current = old_current_executor
            if contextvars is not None:
                # main_wrap stored its final context in context[0]; copy its
                # variables back into this thread's context.
                _restore_context(context[0])

        # Wait for results from the future.
        return call_result.result()

    def _run_event_loop(self, loop, coro):
        """
        Runs the given event loop (designed to be called in a thread).

        Sets ``loop`` as this thread's event loop and runs ``coro`` to
        completion. Afterwards, mirroring asyncio.run(), it cancels and drains
        any tasks still on the loop, reports their non-cancellation exceptions
        to the loop's exception handler, shuts down async generators, closes
        the loop, and sets the thread's event loop to ``self.main_event_loop``.
        """
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(coro)
        finally:
            try:
                # mimic asyncio.run() behavior
                # cancel any tasks still pending on the loop
                if sys.version_info >= (3, 7, 0):
                    tasks = asyncio.all_tasks(loop)
                else:
                    tasks = asyncio.Task.all_tasks(loop)
                for task in tasks:
                    task.cancel()

                async def gather():
                    await asyncio.gather(*tasks, return_exceptions=True)

                # Let the cancellations run to completion.
                loop.run_until_complete(gather())
                for task in tasks:
                    if task.cancelled():
                        continue
                    if task.exception() is not None:
                        loop.call_exception_handler(
                            {
                                "message": "unhandled exception during loop shutdown",
                                "exception": task.exception(),
                                "task": task,
                            }
                        )
                # Finalize unexhausted async generators (when supported).
                if hasattr(loop, "shutdown_asyncgens"):
                    loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()
                asyncio.set_event_loop(self.main_event_loop)

    def __get__(self, parent, objtype):
        """
        Include self for methods

        When accessed through an instance, return a partial that passes that
        instance as the first positional argument, carrying the awaitable's
        metadata (name, docstring, ...) via functools.update_wrapper.
        """
        func = functools.partial(self.__call__, parent)
        return functools.update_wrapper(func, self.awaitable)

    async def main_wrap(
        self, args, kwargs, call_result, source_thread, exc_info, context
    ):
        """
        Wraps the awaitable with something that puts the result into the
        result/exception future.

        Runs inside the target event loop. It records this task in launch_map
        against the calling thread while the awaitable runs. When ``context``
        is given, it first copies the caller's context variables in, and at
        the end stores a fresh copy of the task's context back into
        ``context[0]`` for the caller to restore.
        """
        if context is not None:
            _restore_context(context[0])

        current_task = SyncToAsync.get_current_task()
        self.launch_map[current_task] = source_thread
        try:
            # If we have an exception, run the function inside the except block
            # after raising it so exc_info is correctly populated.
            if exc_info[1]:
                try:
                    raise exc_info[1]
                except BaseException:
                    result = await self.awaitable(*args, **kwargs)
            else:
                result = await self.awaitable(*args, **kwargs)
        except BaseException as e:
            call_result.set_exception(e)
        else:
            call_result.set_result(result)
        finally:
            del self.launch_map[current_task]

            if context is not None:
                context[0] = contextvars.copy_context()
  259. class SyncToAsync:
  260. """
  261. Utility class which turns a synchronous callable into an awaitable that
  262. runs in a threadpool. It also sets a threadlocal inside the thread so
  263. calls to AsyncToSync can escape it.
  264. If thread_sensitive is passed, the code will run in the same thread as any
  265. outer code. This is needed for underlying Python code that is not
  266. threadsafe (for example, code which handles SQLite database connections).
  267. If the outermost program is async (i.e. SyncToAsync is outermost), then
  268. this will be a dedicated single sub-thread that all sync code runs in,
  269. one after the other. If the outermost program is sync (i.e. AsyncToSync is
  270. outermost), this will just be the main thread. This is achieved by idling
  271. with a CurrentThreadExecutor while AsyncToSync is blocking its sync parent,
  272. rather than just blocking.
  273. If executor is passed in, that will be used instead of the loop's default executor.
  274. In order to pass in an executor, thread_sensitive must be set to False, otherwise
  275. a TypeError will be raised.
  276. """
  277. # If they've set ASGI_THREADS, update the default asyncio executor for now
  278. if "ASGI_THREADS" in os.environ:
  279. loop = get_running_loop()
  280. loop.set_default_executor(
  281. ThreadPoolExecutor(max_workers=int(os.environ["ASGI_THREADS"]))
  282. )
  283. # Maps launched threads to the coroutines that spawned them
  284. launch_map: "Dict[threading.Thread, asyncio.Task[object]]" = {}
  285. # Storage for main event loop references
  286. threadlocal = threading.local()
  287. # Single-thread executor for thread-sensitive code
  288. single_thread_executor = ThreadPoolExecutor(max_workers=1)
  289. # Maintain a contextvar for the current execution context. Optionally used
  290. # for thread sensitive mode.
  291. if sys.version_info >= (3, 7):
  292. thread_sensitive_context: "contextvars.ContextVar[str]" = (
  293. contextvars.ContextVar("thread_sensitive_context")
  294. )
  295. else:
  296. thread_sensitive_context: None = None
  297. # Contextvar that is used to detect if the single thread executor
  298. # would be awaited on while already being used in the same context
  299. if sys.version_info >= (3, 7):
  300. deadlock_context: "contextvars.ContextVar[bool]" = contextvars.ContextVar(
  301. "deadlock_context"
  302. )
  303. else:
  304. deadlock_context: None = None
  305. # Maintaining a weak reference to the context ensures that thread pools are
  306. # erased once the context goes out of scope. This terminates the thread pool.
  307. context_to_thread_executor: "weakref.WeakKeyDictionary[object, ThreadPoolExecutor]" = (
  308. weakref.WeakKeyDictionary()
  309. )
  310. def __init__(
  311. self,
  312. func: Callable[..., Any],
  313. thread_sensitive: bool = True,
  314. executor: Optional["ThreadPoolExecutor"] = None,
  315. ) -> None:
  316. if not callable(func) or _iscoroutinefunction_or_partial(func):
  317. raise TypeError("sync_to_async can only be applied to sync functions.")
  318. self.func = func
  319. functools.update_wrapper(self, func)
  320. self._thread_sensitive = thread_sensitive
  321. self._is_coroutine = asyncio.coroutines._is_coroutine # type: ignore
  322. if thread_sensitive and executor is not None:
  323. raise TypeError("executor must not be set when thread_sensitive is True")
  324. self._executor = executor
  325. try:
  326. self.__self__ = func.__self__ # type: ignore
  327. except AttributeError:
  328. pass
  329. async def __call__(self, *args, **kwargs):
  330. loop = get_running_loop()
  331. # Work out what thread to run the code in
  332. if self._thread_sensitive:
  333. if hasattr(AsyncToSync.executors, "current"):
  334. # If we have a parent sync thread above somewhere, use that
  335. executor = AsyncToSync.executors.current
  336. elif self.thread_sensitive_context and self.thread_sensitive_context.get(
  337. None
  338. ):
  339. # If we have a way of retrieving the current context, attempt
  340. # to use a per-context thread pool executor
  341. thread_sensitive_context = self.thread_sensitive_context.get()
  342. if thread_sensitive_context in self.context_to_thread_executor:
  343. # Re-use thread executor in current context
  344. executor = self.context_to_thread_executor[thread_sensitive_context]
  345. else:
  346. # Create new thread executor in current context
  347. executor = ThreadPoolExecutor(max_workers=1)
  348. self.context_to_thread_executor[thread_sensitive_context] = executor
  349. elif self.deadlock_context and self.deadlock_context.get(False):
  350. raise RuntimeError(
  351. "Single thread executor already being used, would deadlock"
  352. )
  353. else:
  354. # Otherwise, we run it in a fixed single thread
  355. executor = self.single_thread_executor
  356. if self.deadlock_context:
  357. self.deadlock_context.set(True)
  358. else:
  359. # Use the passed in executor, or the loop's default if it is None
  360. executor = self._executor
  361. if contextvars is not None:
  362. context = contextvars.copy_context()
  363. child = functools.partial(self.func, *args, **kwargs)
  364. func = context.run
  365. args = (child,)
  366. kwargs = {}
  367. else:
  368. func = self.func
  369. try:
  370. # Run the code in the right thread
  371. future = loop.run_in_executor(
  372. executor,
  373. functools.partial(
  374. self.thread_handler,
  375. loop,
  376. self.get_current_task(),
  377. sys.exc_info(),
  378. func,
  379. *args,
  380. **kwargs,
  381. ),
  382. )
  383. ret = await asyncio.wait_for(future, timeout=None)
  384. finally:
  385. if contextvars is not None:
  386. _restore_context(context)
  387. if self.deadlock_context:
  388. self.deadlock_context.set(False)
  389. return ret
  390. def __get__(self, parent, objtype):
  391. """
  392. Include self for methods
  393. """
  394. return functools.partial(self.__call__, parent)
  395. def thread_handler(self, loop, source_task, exc_info, func, *args, **kwargs):
  396. """
  397. Wraps the sync application with exception handling.
  398. """
  399. # Set the threadlocal for AsyncToSync
  400. self.threadlocal.main_event_loop = loop
  401. self.threadlocal.main_event_loop_pid = os.getpid()
  402. # Set the task mapping (used for the locals module)
  403. current_thread = threading.current_thread()
  404. if AsyncToSync.launch_map.get(source_task) == current_thread:
  405. # Our parent task was launched from this same thread, so don't make
  406. # a launch map entry - let it shortcut over us! (and stop infinite loops)
  407. parent_set = False
  408. else:
  409. self.launch_map[current_thread] = source_task
  410. parent_set = True
  411. # Run the function
  412. try:
  413. # If we have an exception, run the function inside the except block
  414. # after raising it so exc_info is correctly populated.
  415. if exc_info[1]:
  416. try:
  417. raise exc_info[1]
  418. except BaseException:
  419. return func(*args, **kwargs)
  420. else:
  421. return func(*args, **kwargs)
  422. finally:
  423. # Only delete the launch_map parent if we set it, otherwise it is
  424. # from someone else.
  425. if parent_set:
  426. del self.launch_map[current_thread]
  427. @staticmethod
  428. def get_current_task():
  429. """
  430. Cross-version implementation of asyncio.current_task()
  431. Returns None if there is no task.
  432. """
  433. try:
  434. if hasattr(asyncio, "current_task"):
  435. # Python 3.7 and up
  436. return asyncio.current_task()
  437. else:
  438. # Python 3.6
  439. return asyncio.Task.current_task()
  440. except RuntimeError:
  441. return None
  442. # Lowercase aliases (and decorator friendliness)
  443. async_to_sync = AsyncToSync
  444. @overload
  445. def sync_to_async(
  446. func: None = None,
  447. thread_sensitive: bool = True,
  448. executor: Optional["ThreadPoolExecutor"] = None,
  449. ) -> Callable[[Callable[..., Any]], SyncToAsync]:
  450. ...
  451. @overload
  452. def sync_to_async(
  453. func: Callable[..., Any],
  454. thread_sensitive: bool = True,
  455. executor: Optional["ThreadPoolExecutor"] = None,
  456. ) -> SyncToAsync:
  457. ...
  458. def sync_to_async(
  459. func=None,
  460. thread_sensitive=True,
  461. executor=None,
  462. ):
  463. if func is None:
  464. return lambda f: SyncToAsync(
  465. f,
  466. thread_sensitive=thread_sensitive,
  467. executor=executor,
  468. )
  469. return SyncToAsync(
  470. func,
  471. thread_sensitive=thread_sensitive,
  472. executor=executor,
  473. )