local.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import random
  2. import string
  3. import sys
  4. import threading
  5. import weakref
  6. class Local:
  7. """
  8. A drop-in replacement for threading.locals that also works with asyncio
  9. Tasks (via the current_task asyncio method), and passes locals through
  10. sync_to_async and async_to_sync.
  11. Specifically:
  12. - Locals work per-coroutine on any thread not spawned using asgiref
  13. - Locals work per-thread on any thread not spawned using asgiref
  14. - Locals are shared with the parent coroutine when using sync_to_async
  15. - Locals are shared with the parent thread when using async_to_sync
  16. (and if that thread was launched using sync_to_async, with its parent
  17. coroutine as well, with this working for indefinite levels of nesting)
  18. Set thread_critical to True to not allow locals to pass from an async Task
  19. to a thread it spawns. This is needed for code that truly needs
  20. thread-safety, as opposed to things used for helpful context (e.g. sqlite
  21. does not like being called from a different thread to the one it is from).
  22. Thread-critical code will still be differentiated per-Task within a thread
  23. as it is expected it does not like concurrent access.
  24. This doesn't use contextvars as it needs to support 3.6. Once it can support
  25. 3.7 only, we can then reimplement the storage more nicely.
  26. """
  27. CLEANUP_INTERVAL = 60 # seconds
  28. def __init__(self, thread_critical: bool = False) -> None:
  29. self._thread_critical = thread_critical
  30. self._thread_lock = threading.RLock()
  31. self._context_refs: "weakref.WeakSet[object]" = weakref.WeakSet()
  32. # Random suffixes stop accidental reuse between different Locals,
  33. # though we try to force deletion as well.
  34. self._attr_name = "_asgiref_local_impl_{}_{}".format(
  35. id(self),
  36. "".join(random.choice(string.ascii_letters) for i in range(8)),
  37. )
  38. def _get_context_id(self):
  39. """
  40. Get the ID we should use for looking up variables
  41. """
  42. # Prevent a circular reference
  43. from .sync import AsyncToSync, SyncToAsync
  44. # First, pull the current task if we can
  45. context_id = SyncToAsync.get_current_task()
  46. context_is_async = True
  47. # OK, let's try for a thread ID
  48. if context_id is None:
  49. context_id = threading.current_thread()
  50. context_is_async = False
  51. # If we're thread-critical, we stop here, as we can't share contexts.
  52. if self._thread_critical:
  53. return context_id
  54. # Now, take those and see if we can resolve them through the launch maps
  55. for i in range(sys.getrecursionlimit()):
  56. try:
  57. if context_is_async:
  58. # Tasks have a source thread in AsyncToSync
  59. context_id = AsyncToSync.launch_map[context_id]
  60. context_is_async = False
  61. else:
  62. # Threads have a source task in SyncToAsync
  63. context_id = SyncToAsync.launch_map[context_id]
  64. context_is_async = True
  65. except KeyError:
  66. break
  67. else:
  68. # Catch infinite loops (they happen if you are screwing around
  69. # with AsyncToSync implementations)
  70. raise RuntimeError("Infinite launch_map loops")
  71. return context_id
  72. def _get_storage(self):
  73. context_obj = self._get_context_id()
  74. if not hasattr(context_obj, self._attr_name):
  75. setattr(context_obj, self._attr_name, {})
  76. self._context_refs.add(context_obj)
  77. return getattr(context_obj, self._attr_name)
  78. def __del__(self):
  79. try:
  80. for context_obj in self._context_refs:
  81. try:
  82. delattr(context_obj, self._attr_name)
  83. except AttributeError:
  84. pass
  85. except TypeError:
  86. # WeakSet.__iter__ can crash when interpreter is shutting down due
  87. # to _IterationGuard being None.
  88. pass
  89. def __getattr__(self, key):
  90. with self._thread_lock:
  91. storage = self._get_storage()
  92. if key in storage:
  93. return storage[key]
  94. else:
  95. raise AttributeError(f"{self!r} object has no attribute {key!r}")
  96. def __setattr__(self, key, value):
  97. if key in ("_context_refs", "_thread_critical", "_thread_lock", "_attr_name"):
  98. return super().__setattr__(key, value)
  99. with self._thread_lock:
  100. storage = self._get_storage()
  101. storage[key] = value
  102. def __delattr__(self, key):
  103. with self._thread_lock:
  104. storage = self._get_storage()
  105. if key in storage:
  106. del storage[key]
  107. else:
  108. raise AttributeError(f"{self!r} object has no attribute {key!r}")