|
- # Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
- # to make it harder for the user to import the wrong thing without realizing.
- import io
- from importlib import import_module
- import django
- from django.conf import settings
- from django.core.exceptions import ImproperlyConfigured
- from django.core.handlers.wsgi import WSGIHandler
- from django.test import override_settings, testcases
- from django.test.client import Client as DjangoClient
- from django.test.client import ClientHandler
- from django.test.client import RequestFactory as DjangoRequestFactory
- from django.utils.encoding import force_bytes
- from django.utils.http import urlencode
- from rest_framework.compat import coreapi, requests
- from rest_framework.settings import api_settings
def force_authenticate(request, user=None, token=None):
    """
    Attach forced-authentication credentials to `request`.

    `user` and `token` are stored on the request's private
    `_force_auth_user` / `_force_auth_token` attributes (user first, then
    token). They are presumably consumed later by the request's
    authentication handling, which is not visible here.
    """
    for attr_name, attr_value in (
        ('_force_auth_user', user),
        ('_force_auth_token', token),
    ):
        setattr(request, attr_name, attr_value)
if requests is not None:
    class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
        """
        urllib3 header mapping that also offers a `get_all(key, default)`
        lookup on a response's `msg` object.
        """

        def get_all(self, key, default=None):
            """
            Return the list of values stored for header `key`.

            If the header is absent, `default` is returned. This mirrors
            `email.message.Message.get_all`, whose `failobj` also defaults to
            None.
            """
            # `getheaders` yields an empty list for a missing key.
            values = self.getheaders(key)
            return values if values else default

    class MockOriginalResponse:
        """
        Minimal stand-in for the low-level response object that urllib3
        keeps as `original_response`.

        Exposes the headers as `msg`, plus `isclosed()` / `close()` state.
        """

        def __init__(self, headers):
            self.msg = HeaderDict(headers)
            self.closed = False

        def isclosed(self):
            return self.closed

        def close(self):
            self.closed = True

    class DjangoTestAdapter(requests.adapters.HTTPAdapter):
        """
        A transport adapter for `requests`, that makes requests via the
        Django WSGI app, rather than making actual HTTP requests over the network.
        """

        def __init__(self):
            # NOTE(review): HTTPAdapter.__init__ is not called, so no urllib3
            # connection pool is set up; confirm nothing relies on the
            # attributes it would create.
            self.app = WSGIHandler()
            self.factory = DjangoRequestFactory()

        def get_environ(self, request):
            """
            Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
            """
            method = request.method
            url = request.url
            kwargs = {}

            # Set request content, if any exists. File-like bodies are read
            # fully into memory.
            if request.body is not None:
                if hasattr(request.body, 'read'):
                    kwargs['data'] = request.body.read()
                else:
                    kwargs['data'] = request.body
            if 'content-type' in request.headers:
                kwargs['content_type'] = request.headers['content-type']

            # Set request headers as HTTP_* environ keys. Content-Type is
            # passed separately above; Connection and Content-Length are not
            # forwarded.
            for key, value in request.headers.items():
                key = key.upper()
                if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
                    continue
                kwargs['HTTP_%s' % key.replace('-', '_')] = value

            return self.factory.generic(method, url, **kwargs).environ

        def send(self, request, *args, **kwargs):
            """
            Make an outgoing request to the Django WSGI application.
            """
            raw_kwargs = {}

            def start_response(wsgi_status, wsgi_headers, exc_info=None):
                # WSGI status is e.g. "200 OK": split the code from the reason.
                status, _, reason = wsgi_status.partition(' ')
                raw_kwargs['status'] = int(status)
                raw_kwargs['reason'] = reason
                raw_kwargs['headers'] = wsgi_headers
                raw_kwargs['version'] = 11  # HTTP/1.1 in urllib3's encoding
                raw_kwargs['preload_content'] = False
                raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)

            # Make the outgoing request via WSGI.
            environ = self.get_environ(request)
            wsgi_response = self.app(environ, start_response)

            # Build the underlying urllib3.HTTPResponse
            raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
            raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)

            # Build the requests.Response
            return self.build_response(request, raw)

        def close(self):
            # Nothing to release: no network connections are opened.
            pass

    class RequestsClient(requests.Session):
        """
        A `requests.Session` whose http:// and https:// traffic is routed
        to the Django application through `DjangoTestAdapter`.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            adapter = DjangoTestAdapter()
            self.mount('http://', adapter)
            self.mount('https://', adapter)

        def request(self, method, url, *args, **kwargs):
            # Require a real scheme prefix. A bare `startswith('http')` would
            # also accept relative/host-only values such as "httpbin.org/x".
            if not url.startswith(('http://', 'https://')):
                raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url)
            return super().request(method, url, *args, **kwargs)

else:
    def RequestsClient(*args, **kwargs):
        raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.')
if coreapi is None:
    # Placeholder so that importing `CoreAPIClient` still works; using it
    # reports the missing optional dependency.
    def CoreAPIClient(*args, **kwargs):
        raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')

else:
    class CoreAPIClient(coreapi.Client):
        """
        `coreapi.Client` whose only transport is an HTTP transport backed by
        a `RequestsClient` session.

        Any `transports` argument supplied by the caller is replaced.
        """

        def __init__(self, *args, **kwargs):
            self._session = RequestsClient()
            http_transport = coreapi.transports.HTTPTransport(session=self.session)
            kwargs['transports'] = [http_transport]
            super().__init__(*args, **kwargs)

        @property
        def session(self):
            """The `RequestsClient` session used by the HTTP transport."""
            return self._session
class APIRequestFactory(DjangoRequestFactory):
    """
    Django `RequestFactory` that renders request bodies via the configured
    test renderers (chosen with `format=`) and controls whether the built
    requests are flagged to skip CSRF enforcement.
    """
    renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
    default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT

    def __init__(self, enforce_csrf_checks=False, **defaults):
        self.enforce_csrf_checks = enforce_csrf_checks
        # Map each renderer's `format` name to its class; later entries win.
        self.renderer_classes = {
            cls.format: cls for cls in self.renderer_classes_list
        }
        super().__init__(**defaults)

    def _encode_data(self, data, format=None, content_type=None):
        """
        Encode the data returning a two tuple of (bytes, content_type).

        - `data is None`: returns `('', content_type)` untouched.
        - explicit `content_type`: `data` is treated as a raw body and
          coerced to bytes using `settings.DEFAULT_CHARSET`.
        - otherwise: `data` is rendered by the renderer registered for
          `format` (or `default_format`). The content type is taken from
          that renderer's `media_type` / `charset`.

        Raises AssertionError if both `format` and `content_type` are given,
        or if `format` has no registered renderer.
        """
        if data is None:
            return ('', content_type)

        assert format is None or content_type is None, (
            'You may not set both `format` and `content_type`.'
        )

        if content_type:
            # Content type specified explicitly, treat data as a raw bytestring
            ret = force_bytes(data, settings.DEFAULT_CHARSET)

        else:
            format = format or self.default_format

            assert format in self.renderer_classes, (
                "Invalid format '{}'. Available formats are {}. "
                "Set TEST_REQUEST_RENDERER_CLASSES to enable "
                "extra request formats.".format(
                    format,
                    ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes])
                )
            )

            # Use format and render the data into a bytestring
            renderer = self.renderer_classes[format]()
            ret = renderer.render(data)

            # Determine the content-type header from the renderer
            content_type = renderer.media_type
            if renderer.charset:
                content_type = "{}; charset={}".format(
                    content_type, renderer.charset
                )

            # Coerce text to bytes if required. A renderer that returns text
            # but declares no charset falls back to the project default
            # instead of failing on `str.encode(None)`.
            if isinstance(ret, str):
                ret = ret.encode(renderer.charset or settings.DEFAULT_CHARSET)

        return ret, content_type

    def get(self, path, data=None, **extra):
        """
        Build a GET request. `data` is urlencoded into QUERY_STRING.

        When no `data` is given, a query string embedded in `path` is used
        instead.
        """
        r = {
            'QUERY_STRING': urlencode(data or {}, doseq=True),
        }
        if not data and '?' in path:
            # Fix to support old behavior where you have the arguments in the
            # url. See #1461. Split on the first '?' only, so any further '?'
            # characters remain part of the query string.
            query_string = force_bytes(path.split('?', 1)[1])
            # WSGI (PEP 3333) carries environ strings as latin-1-decoded bytes.
            query_string = query_string.decode('iso-8859-1')
            r['QUERY_STRING'] = query_string
        r.update(extra)
        return self.generic('GET', path, **r)

    def post(self, path, data=None, format=None, content_type=None, **extra):
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('POST', path, data, content_type, **extra)

    def put(self, path, data=None, format=None, content_type=None, **extra):
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('PUT', path, data, content_type, **extra)

    def patch(self, path, data=None, format=None, content_type=None, **extra):
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('PATCH', path, data, content_type, **extra)

    def delete(self, path, data=None, format=None, content_type=None, **extra):
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('DELETE', path, data, content_type, **extra)

    def options(self, path, data=None, format=None, content_type=None, **extra):
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('OPTIONS', path, data, content_type, **extra)

    def generic(self, method, path, data='',
                content_type='application/octet-stream', secure=False, **extra):
        # Include the CONTENT_TYPE, regardless of whether or not data is empty.
        if content_type is not None:
            extra['CONTENT_TYPE'] = str(content_type)

        return super().generic(
            method, path, data, content_type, secure, **extra)

    def request(self, **kwargs):
        """
        Build the request via Django, then set `_dont_enforce_csrf_checks`
        to the inverse of `self.enforce_csrf_checks`.
        """
        request = super().request(**kwargs)
        request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
        return request
class ForceAuthClientHandler(ClientHandler):
    """
    `ClientHandler` that stamps a forced user/token onto every request it
    handles.

    The forced user and token are held in `_force_user` / `_force_token`,
    which start as None and are assigned by the owning client.
    """

    def __init__(self, *args, **kwargs):
        self._force_user = self._force_token = None
        super().__init__(*args, **kwargs)

    def get_response(self, request):
        # get_response is the simplest place to reach the request object
        # before the view runs, so the forced credentials are applied here.
        force_authenticate(request, user=self._force_user, token=self._force_token)
        return super().get_response(request)
class APIClient(APIRequestFactory, DjangoClient):
    """
    Test client combining `APIRequestFactory` request building with Django's
    test `Client`. Adds persistent credentials and forced authentication.

    When `follow=True`, the original `data`, `format` and `content_type` are
    passed on to the redirect handler. That lets 307/308 redirects re-send
    the same body, and lets GET redirects keep their query data.
    """

    def __init__(self, enforce_csrf_checks=False, **defaults):
        super().__init__(**defaults)
        self.handler = ForceAuthClientHandler(enforce_csrf_checks)
        self._credentials = {}

    def credentials(self, **kwargs):
        """
        Sets headers that will be used on every outgoing request.

        Each call replaces any previously set credentials.
        """
        self._credentials = kwargs

    def force_authenticate(self, user=None, token=None):
        """
        Forcibly authenticates outgoing requests with the given
        user and/or token.

        Passing `user=None` also calls `logout()`, which clears stored
        credentials and any session login.
        """
        self.handler._force_user = user
        self.handler._force_token = token
        if user is None:
            self.logout()  # Also clear any possible session info if required

    def request(self, **kwargs):
        # Ensure that any credentials set get added to every request.
        kwargs.update(self._credentials)
        return super().request(**kwargs)

    def get(self, path, data=None, follow=False, **extra):
        response = super().get(path, data=data, **extra)
        if follow:
            response = self._handle_redirects(response, data=data, **extra)
        return response

    def post(self, path, data=None, format=None, content_type=None,
             follow=False, **extra):
        response = super().post(
            path, data=data, format=format, content_type=content_type, **extra)
        if follow:
            response = self._handle_redirects(
                response, data=data, format=format,
                content_type=content_type, **extra)
        return response

    def put(self, path, data=None, format=None, content_type=None,
            follow=False, **extra):
        response = super().put(
            path, data=data, format=format, content_type=content_type, **extra)
        if follow:
            response = self._handle_redirects(
                response, data=data, format=format,
                content_type=content_type, **extra)
        return response

    def patch(self, path, data=None, format=None, content_type=None,
              follow=False, **extra):
        response = super().patch(
            path, data=data, format=format, content_type=content_type, **extra)
        if follow:
            response = self._handle_redirects(
                response, data=data, format=format,
                content_type=content_type, **extra)
        return response

    def delete(self, path, data=None, format=None, content_type=None,
               follow=False, **extra):
        response = super().delete(
            path, data=data, format=format, content_type=content_type, **extra)
        if follow:
            response = self._handle_redirects(
                response, data=data, format=format,
                content_type=content_type, **extra)
        return response

    def options(self, path, data=None, format=None, content_type=None,
                follow=False, **extra):
        response = super().options(
            path, data=data, format=format, content_type=content_type, **extra)
        if follow:
            response = self._handle_redirects(
                response, data=data, format=format,
                content_type=content_type, **extra)
        return response

    def logout(self):
        """
        Clear stored credentials and forced authentication. Then log out of
        the Django session, if there is one.
        """
        self._credentials = {}

        # Also clear any `force_authenticate`
        self.handler._force_user = None
        self.handler._force_token = None

        if self.session:
            super().logout()
class APITransactionTestCase(testcases.TransactionTestCase):
    """Django `TransactionTestCase` that uses `APIClient` as its `client_class`."""

    client_class = APIClient
class APITestCase(testcases.TestCase):
    """Django `TestCase` that uses `APIClient` as its `client_class`."""

    client_class = APIClient
class APISimpleTestCase(testcases.SimpleTestCase):
    """Django `SimpleTestCase` that uses `APIClient` as its `client_class`."""

    client_class = APIClient
class APILiveServerTestCase(testcases.LiveServerTestCase):
    """Django `LiveServerTestCase` that uses `APIClient` as its `client_class`."""

    client_class = APIClient
def cleanup_url_patterns(cls):
    """
    Undo the temporary `urlpatterns` installed on `cls._module`.

    If `cls` itself recorded the module's original patterns in
    `_module_urlpatterns`, they are put back. Otherwise the module had no
    `urlpatterns` before, and the temporary attribute is deleted.
    """
    # Only consider a value stored directly on `cls`. A value inherited from
    # a parent test case belongs to the parent's module; restoring it here
    # would copy that module's URL patterns into this one.
    if '_module_urlpatterns' in vars(cls):
        cls._module.urlpatterns = cls._module_urlpatterns
    else:
        del cls._module.urlpatterns
class URLPatternsTestCase(testcases.SimpleTestCase):
    """
    Isolate URL patterns on a per-TestCase basis. For example,

    class ATestCase(URLPatternsTestCase):
        urlpatterns = [...]

        def test_something(self):
            ...

    class AnotherTestCase(URLPatternsTestCase):
        urlpatterns = [...]

        def test_something_else(self):
            ...
    """

    @classmethod
    def setUpClass(cls):
        # Get the module of the TestCase subclass
        cls._module = import_module(cls.__module__)
        cls._override = override_settings(ROOT_URLCONF=cls.__module__)

        # Remember the module's own urlpatterns (if any) so they can be
        # restored once the class is done.
        if hasattr(cls._module, 'urlpatterns'):
            cls._module_urlpatterns = cls._module.urlpatterns

        # ROOT_URLCONF now points at this module, whose urlpatterns are
        # temporarily the test class's own.
        cls._module.urlpatterns = cls.urlpatterns

        cls._override.enable()

        if django.VERSION >= (4, 0):
            # Class cleanups run even if setUpClass() below raises.
            cls.addClassCleanup(cls._override.disable)
            cls.addClassCleanup(cleanup_url_patterns, cls)
            super().setUpClass()
        else:
            # tearDownClass() is skipped when setUpClass() fails, so undo
            # the override and urlpatterns swap here before re-raising.
            try:
                super().setUpClass()
            except Exception:
                cls._override.disable()
                cleanup_url_patterns(cls)
                raise

    if django.VERSION < (4, 0):
        @classmethod
        def tearDownClass(cls):
            # Restore settings and urlpatterns even if the parent teardown fails.
            try:
                super().tearDownClass()
            finally:
                cls._override.disable()
                cleanup_url_patterns(cls)
|