test.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. # Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
  2. # to make it harder for the user to import the wrong thing without realizing.
  3. import io
  4. from importlib import import_module
  5. import django
  6. from django.conf import settings
  7. from django.core.exceptions import ImproperlyConfigured
  8. from django.core.handlers.wsgi import WSGIHandler
  9. from django.test import override_settings, testcases
  10. from django.test.client import Client as DjangoClient
  11. from django.test.client import ClientHandler
  12. from django.test.client import RequestFactory as DjangoRequestFactory
  13. from django.utils.encoding import force_bytes
  14. from django.utils.http import urlencode
  15. from rest_framework.compat import coreapi, requests
  16. from rest_framework.settings import api_settings
  17. def force_authenticate(request, user=None, token=None):
  18. request._force_auth_user = user
  19. request._force_auth_token = token
  20. if requests is not None:
  21. class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
  22. def get_all(self, key, default):
  23. return self.getheaders(key)
  24. class MockOriginalResponse:
  25. def __init__(self, headers):
  26. self.msg = HeaderDict(headers)
  27. self.closed = False
  28. def isclosed(self):
  29. return self.closed
  30. def close(self):
  31. self.closed = True
  32. class DjangoTestAdapter(requests.adapters.HTTPAdapter):
  33. """
  34. A transport adapter for `requests`, that makes requests via the
  35. Django WSGI app, rather than making actual HTTP requests over the network.
  36. """
  37. def __init__(self):
  38. self.app = WSGIHandler()
  39. self.factory = DjangoRequestFactory()
  40. def get_environ(self, request):
  41. """
  42. Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
  43. """
  44. method = request.method
  45. url = request.url
  46. kwargs = {}
  47. # Set request content, if any exists.
  48. if request.body is not None:
  49. if hasattr(request.body, 'read'):
  50. kwargs['data'] = request.body.read()
  51. else:
  52. kwargs['data'] = request.body
  53. if 'content-type' in request.headers:
  54. kwargs['content_type'] = request.headers['content-type']
  55. # Set request headers.
  56. for key, value in request.headers.items():
  57. key = key.upper()
  58. if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
  59. continue
  60. kwargs['HTTP_%s' % key.replace('-', '_')] = value
  61. return self.factory.generic(method, url, **kwargs).environ
  62. def send(self, request, *args, **kwargs):
  63. """
  64. Make an outgoing request to the Django WSGI application.
  65. """
  66. raw_kwargs = {}
  67. def start_response(wsgi_status, wsgi_headers, exc_info=None):
  68. status, _, reason = wsgi_status.partition(' ')
  69. raw_kwargs['status'] = int(status)
  70. raw_kwargs['reason'] = reason
  71. raw_kwargs['headers'] = wsgi_headers
  72. raw_kwargs['version'] = 11
  73. raw_kwargs['preload_content'] = False
  74. raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)
  75. # Make the outgoing request via WSGI.
  76. environ = self.get_environ(request)
  77. wsgi_response = self.app(environ, start_response)
  78. # Build the underlying urllib3.HTTPResponse
  79. raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
  80. raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
  81. # Build the requests.Response
  82. return self.build_response(request, raw)
  83. def close(self):
  84. pass
  85. class RequestsClient(requests.Session):
  86. def __init__(self, *args, **kwargs):
  87. super().__init__(*args, **kwargs)
  88. adapter = DjangoTestAdapter()
  89. self.mount('http://', adapter)
  90. self.mount('https://', adapter)
  91. def request(self, method, url, *args, **kwargs):
  92. if not url.startswith('http'):
  93. raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url)
  94. return super().request(method, url, *args, **kwargs)
  95. else:
  96. def RequestsClient(*args, **kwargs):
  97. raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.')
  98. if coreapi is not None:
  99. class CoreAPIClient(coreapi.Client):
  100. def __init__(self, *args, **kwargs):
  101. self._session = RequestsClient()
  102. kwargs['transports'] = [coreapi.transports.HTTPTransport(session=self.session)]
  103. super().__init__(*args, **kwargs)
  104. @property
  105. def session(self):
  106. return self._session
  107. else:
  108. def CoreAPIClient(*args, **kwargs):
  109. raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')
  110. class APIRequestFactory(DjangoRequestFactory):
  111. renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
  112. default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
  113. def __init__(self, enforce_csrf_checks=False, **defaults):
  114. self.enforce_csrf_checks = enforce_csrf_checks
  115. self.renderer_classes = {}
  116. for cls in self.renderer_classes_list:
  117. self.renderer_classes[cls.format] = cls
  118. super().__init__(**defaults)
  119. def _encode_data(self, data, format=None, content_type=None):
  120. """
  121. Encode the data returning a two tuple of (bytes, content_type)
  122. """
  123. if data is None:
  124. return ('', content_type)
  125. assert format is None or content_type is None, (
  126. 'You may not set both `format` and `content_type`.'
  127. )
  128. if content_type:
  129. # Content type specified explicitly, treat data as a raw bytestring
  130. ret = force_bytes(data, settings.DEFAULT_CHARSET)
  131. else:
  132. format = format or self.default_format
  133. assert format in self.renderer_classes, (
  134. "Invalid format '{}'. Available formats are {}. "
  135. "Set TEST_REQUEST_RENDERER_CLASSES to enable "
  136. "extra request formats.".format(
  137. format,
  138. ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes])
  139. )
  140. )
  141. # Use format and render the data into a bytestring
  142. renderer = self.renderer_classes[format]()
  143. ret = renderer.render(data)
  144. # Determine the content-type header from the renderer
  145. content_type = renderer.media_type
  146. if renderer.charset:
  147. content_type = "{}; charset={}".format(
  148. content_type, renderer.charset
  149. )
  150. # Coerce text to bytes if required.
  151. if isinstance(ret, str):
  152. ret = ret.encode(renderer.charset)
  153. return ret, content_type
  154. def get(self, path, data=None, **extra):
  155. r = {
  156. 'QUERY_STRING': urlencode(data or {}, doseq=True),
  157. }
  158. if not data and '?' in path:
  159. # Fix to support old behavior where you have the arguments in the
  160. # url. See #1461.
  161. query_string = force_bytes(path.split('?')[1])
  162. query_string = query_string.decode('iso-8859-1')
  163. r['QUERY_STRING'] = query_string
  164. r.update(extra)
  165. return self.generic('GET', path, **r)
  166. def post(self, path, data=None, format=None, content_type=None, **extra):
  167. data, content_type = self._encode_data(data, format, content_type)
  168. return self.generic('POST', path, data, content_type, **extra)
  169. def put(self, path, data=None, format=None, content_type=None, **extra):
  170. data, content_type = self._encode_data(data, format, content_type)
  171. return self.generic('PUT', path, data, content_type, **extra)
  172. def patch(self, path, data=None, format=None, content_type=None, **extra):
  173. data, content_type = self._encode_data(data, format, content_type)
  174. return self.generic('PATCH', path, data, content_type, **extra)
  175. def delete(self, path, data=None, format=None, content_type=None, **extra):
  176. data, content_type = self._encode_data(data, format, content_type)
  177. return self.generic('DELETE', path, data, content_type, **extra)
  178. def options(self, path, data=None, format=None, content_type=None, **extra):
  179. data, content_type = self._encode_data(data, format, content_type)
  180. return self.generic('OPTIONS', path, data, content_type, **extra)
  181. def generic(self, method, path, data='',
  182. content_type='application/octet-stream', secure=False, **extra):
  183. # Include the CONTENT_TYPE, regardless of whether or not data is empty.
  184. if content_type is not None:
  185. extra['CONTENT_TYPE'] = str(content_type)
  186. return super().generic(
  187. method, path, data, content_type, secure, **extra)
  188. def request(self, **kwargs):
  189. request = super().request(**kwargs)
  190. request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
  191. return request
  192. class ForceAuthClientHandler(ClientHandler):
  193. """
  194. A patched version of ClientHandler that can enforce authentication
  195. on the outgoing requests.
  196. """
  197. def __init__(self, *args, **kwargs):
  198. self._force_user = None
  199. self._force_token = None
  200. super().__init__(*args, **kwargs)
  201. def get_response(self, request):
  202. # This is the simplest place we can hook into to patch the
  203. # request object.
  204. force_authenticate(request, self._force_user, self._force_token)
  205. return super().get_response(request)
  206. class APIClient(APIRequestFactory, DjangoClient):
  207. def __init__(self, enforce_csrf_checks=False, **defaults):
  208. super().__init__(**defaults)
  209. self.handler = ForceAuthClientHandler(enforce_csrf_checks)
  210. self._credentials = {}
  211. def credentials(self, **kwargs):
  212. """
  213. Sets headers that will be used on every outgoing request.
  214. """
  215. self._credentials = kwargs
  216. def force_authenticate(self, user=None, token=None):
  217. """
  218. Forcibly authenticates outgoing requests with the given
  219. user and/or token.
  220. """
  221. self.handler._force_user = user
  222. self.handler._force_token = token
  223. if user is None:
  224. self.logout() # Also clear any possible session info if required
  225. def request(self, **kwargs):
  226. # Ensure that any credentials set get added to every request.
  227. kwargs.update(self._credentials)
  228. return super().request(**kwargs)
  229. def get(self, path, data=None, follow=False, **extra):
  230. response = super().get(path, data=data, **extra)
  231. if follow:
  232. response = self._handle_redirects(response, **extra)
  233. return response
  234. def post(self, path, data=None, format=None, content_type=None,
  235. follow=False, **extra):
  236. response = super().post(
  237. path, data=data, format=format, content_type=content_type, **extra)
  238. if follow:
  239. response = self._handle_redirects(response, **extra)
  240. return response
  241. def put(self, path, data=None, format=None, content_type=None,
  242. follow=False, **extra):
  243. response = super().put(
  244. path, data=data, format=format, content_type=content_type, **extra)
  245. if follow:
  246. response = self._handle_redirects(response, **extra)
  247. return response
  248. def patch(self, path, data=None, format=None, content_type=None,
  249. follow=False, **extra):
  250. response = super().patch(
  251. path, data=data, format=format, content_type=content_type, **extra)
  252. if follow:
  253. response = self._handle_redirects(response, **extra)
  254. return response
  255. def delete(self, path, data=None, format=None, content_type=None,
  256. follow=False, **extra):
  257. response = super().delete(
  258. path, data=data, format=format, content_type=content_type, **extra)
  259. if follow:
  260. response = self._handle_redirects(response, **extra)
  261. return response
  262. def options(self, path, data=None, format=None, content_type=None,
  263. follow=False, **extra):
  264. response = super().options(
  265. path, data=data, format=format, content_type=content_type, **extra)
  266. if follow:
  267. response = self._handle_redirects(response, **extra)
  268. return response
  269. def logout(self):
  270. self._credentials = {}
  271. # Also clear any `force_authenticate`
  272. self.handler._force_user = None
  273. self.handler._force_token = None
  274. if self.session:
  275. super().logout()
  276. class APITransactionTestCase(testcases.TransactionTestCase):
  277. client_class = APIClient
  278. class APITestCase(testcases.TestCase):
  279. client_class = APIClient
  280. class APISimpleTestCase(testcases.SimpleTestCase):
  281. client_class = APIClient
  282. class APILiveServerTestCase(testcases.LiveServerTestCase):
  283. client_class = APIClient
  284. def cleanup_url_patterns(cls):
  285. if hasattr(cls, '_module_urlpatterns'):
  286. cls._module.urlpatterns = cls._module_urlpatterns
  287. else:
  288. del cls._module.urlpatterns
  289. class URLPatternsTestCase(testcases.SimpleTestCase):
  290. """
  291. Isolate URL patterns on a per-TestCase basis. For example,
  292. class ATestCase(URLPatternsTestCase):
  293. urlpatterns = [...]
  294. def test_something(self):
  295. ...
  296. class AnotherTestCase(URLPatternsTestCase):
  297. urlpatterns = [...]
  298. def test_something_else(self):
  299. ...
  300. """
  301. @classmethod
  302. def setUpClass(cls):
  303. # Get the module of the TestCase subclass
  304. cls._module = import_module(cls.__module__)
  305. cls._override = override_settings(ROOT_URLCONF=cls.__module__)
  306. if hasattr(cls._module, 'urlpatterns'):
  307. cls._module_urlpatterns = cls._module.urlpatterns
  308. cls._module.urlpatterns = cls.urlpatterns
  309. cls._override.enable()
  310. if django.VERSION > (4, 0):
  311. cls.addClassCleanup(cls._override.disable)
  312. cls.addClassCleanup(cleanup_url_patterns, cls)
  313. super().setUpClass()
  314. if django.VERSION < (4, 0):
  315. @classmethod
  316. def tearDownClass(cls):
  317. super().tearDownClass()
  318. cls._override.disable()
  319. if hasattr(cls, '_module_urlpatterns'):
  320. cls._module.urlpatterns = cls._module_urlpatterns
  321. else:
  322. del cls._module.urlpatterns