throttling.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. """
  2. Provides various throttling policies.
  3. """
  4. import time
  5. from django.core.cache import cache as default_cache
  6. from django.core.exceptions import ImproperlyConfigured
  7. from rest_framework.settings import api_settings
  8. class BaseThrottle:
  9. """
  10. Rate throttling of requests.
  11. """
  12. def allow_request(self, request, view):
  13. """
  14. Return `True` if the request should be allowed, `False` otherwise.
  15. """
  16. raise NotImplementedError('.allow_request() must be overridden')
  17. def get_ident(self, request):
  18. """
  19. Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
  20. if present and number of proxies is > 0. If not use all of
  21. HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
  22. """
  23. xff = request.META.get('HTTP_X_FORWARDED_FOR')
  24. remote_addr = request.META.get('REMOTE_ADDR')
  25. num_proxies = api_settings.NUM_PROXIES
  26. if num_proxies is not None:
  27. if num_proxies == 0 or xff is None:
  28. return remote_addr
  29. addrs = xff.split(',')
  30. client_addr = addrs[-min(num_proxies, len(addrs))]
  31. return client_addr.strip()
  32. return ''.join(xff.split()) if xff else remote_addr
  33. def wait(self):
  34. """
  35. Optionally, return a recommended number of seconds to wait before
  36. the next request.
  37. """
  38. return None
  39. class SimpleRateThrottle(BaseThrottle):
  40. """
  41. A simple cache implementation, that only requires `.get_cache_key()`
  42. to be overridden.
  43. The rate (requests / seconds) is set by a `rate` attribute on the Throttle
  44. class. The attribute is a string of the form 'number_of_requests/period'.
  45. Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
  46. Previous request information used for throttling is stored in the cache.
  47. """
  48. cache = default_cache
  49. timer = time.time
  50. cache_format = 'throttle_%(scope)s_%(ident)s'
  51. scope = None
  52. THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
  53. def __init__(self):
  54. if not getattr(self, 'rate', None):
  55. self.rate = self.get_rate()
  56. self.num_requests, self.duration = self.parse_rate(self.rate)
  57. def get_cache_key(self, request, view):
  58. """
  59. Should return a unique cache-key which can be used for throttling.
  60. Must be overridden.
  61. May return `None` if the request should not be throttled.
  62. """
  63. raise NotImplementedError('.get_cache_key() must be overridden')
  64. def get_rate(self):
  65. """
  66. Determine the string representation of the allowed request rate.
  67. """
  68. if not getattr(self, 'scope', None):
  69. msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
  70. self.__class__.__name__)
  71. raise ImproperlyConfigured(msg)
  72. try:
  73. return self.THROTTLE_RATES[self.scope]
  74. except KeyError:
  75. msg = "No default throttle rate set for '%s' scope" % self.scope
  76. raise ImproperlyConfigured(msg)
  77. def parse_rate(self, rate):
  78. """
  79. Given the request rate string, return a two tuple of:
  80. <allowed number of requests>, <period of time in seconds>
  81. """
  82. if rate is None:
  83. return (None, None)
  84. num, period = rate.split('/')
  85. num_requests = int(num)
  86. duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
  87. return (num_requests, duration)
  88. def allow_request(self, request, view):
  89. """
  90. Implement the check to see if the request should be throttled.
  91. On success calls `throttle_success`.
  92. On failure calls `throttle_failure`.
  93. """
  94. if self.rate is None:
  95. return True
  96. self.key = self.get_cache_key(request, view)
  97. if self.key is None:
  98. return True
  99. self.history = self.cache.get(self.key, [])
  100. self.now = self.timer()
  101. # Drop any requests from the history which have now passed the
  102. # throttle duration
  103. while self.history and self.history[-1] <= self.now - self.duration:
  104. self.history.pop()
  105. if len(self.history) >= self.num_requests:
  106. return self.throttle_failure()
  107. return self.throttle_success()
  108. def throttle_success(self):
  109. """
  110. Inserts the current request's timestamp along with the key
  111. into the cache.
  112. """
  113. self.history.insert(0, self.now)
  114. self.cache.set(self.key, self.history, self.duration)
  115. return True
  116. def throttle_failure(self):
  117. """
  118. Called when a request to the API has failed due to throttling.
  119. """
  120. return False
  121. def wait(self):
  122. """
  123. Returns the recommended next request time in seconds.
  124. """
  125. if self.history:
  126. remaining_duration = self.duration - (self.now - self.history[-1])
  127. else:
  128. remaining_duration = self.duration
  129. available_requests = self.num_requests - len(self.history) + 1
  130. if available_requests <= 0:
  131. return None
  132. return remaining_duration / float(available_requests)
  133. class AnonRateThrottle(SimpleRateThrottle):
  134. """
  135. Limits the rate of API calls that may be made by a anonymous users.
  136. The IP address of the request will be used as the unique cache key.
  137. """
  138. scope = 'anon'
  139. def get_cache_key(self, request, view):
  140. if request.user.is_authenticated:
  141. return None # Only throttle unauthenticated requests.
  142. return self.cache_format % {
  143. 'scope': self.scope,
  144. 'ident': self.get_ident(request)
  145. }
  146. class UserRateThrottle(SimpleRateThrottle):
  147. """
  148. Limits the rate of API calls that may be made by a given user.
  149. The user id will be used as a unique cache key if the user is
  150. authenticated. For anonymous requests, the IP address of the request will
  151. be used.
  152. """
  153. scope = 'user'
  154. def get_cache_key(self, request, view):
  155. if request.user.is_authenticated:
  156. ident = request.user.pk
  157. else:
  158. ident = self.get_ident(request)
  159. return self.cache_format % {
  160. 'scope': self.scope,
  161. 'ident': ident
  162. }
  163. class ScopedRateThrottle(SimpleRateThrottle):
  164. """
  165. Limits the rate of API calls by different amounts for various parts of
  166. the API. Any view that has the `throttle_scope` property set will be
  167. throttled. The unique cache key will be generated by concatenating the
  168. user id of the request, and the scope of the view being accessed.
  169. """
  170. scope_attr = 'throttle_scope'
  171. def __init__(self):
  172. # Override the usual SimpleRateThrottle, because we can't determine
  173. # the rate until called by the view.
  174. pass
  175. def allow_request(self, request, view):
  176. # We can only determine the scope once we're called by the view.
  177. self.scope = getattr(view, self.scope_attr, None)
  178. # If a view does not have a `throttle_scope` always allow the request
  179. if not self.scope:
  180. return True
  181. # Determine the allowed request rate as we normally would during
  182. # the `__init__` call.
  183. self.rate = self.get_rate()
  184. self.num_requests, self.duration = self.parse_rate(self.rate)
  185. # We can now proceed as normal.
  186. return super().allow_request(request, view)
  187. def get_cache_key(self, request, view):
  188. """
  189. If `view.throttle_scope` is not set, don't apply this throttle.
  190. Otherwise generate the unique cache key by concatenating the user id
  191. with the '.throttle_scope` property of the view.
  192. """
  193. if request.user.is_authenticated:
  194. ident = request.user.pk
  195. else:
  196. ident = self.get_ident(request)
  197. return self.cache_format % {
  198. 'scope': self.scope,
  199. 'ident': ident
  200. }