versioning.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. import re
  2. from django.utils.translation import gettext_lazy as _
  3. from rest_framework import exceptions
  4. from rest_framework.compat import unicode_http_header
  5. from rest_framework.reverse import _reverse
  6. from rest_framework.settings import api_settings
  7. from rest_framework.templatetags.rest_framework import replace_query_param
  8. from rest_framework.utils.mediatypes import _MediaType
  9. class BaseVersioning:
  10. default_version = api_settings.DEFAULT_VERSION
  11. allowed_versions = api_settings.ALLOWED_VERSIONS
  12. version_param = api_settings.VERSION_PARAM
  13. def determine_version(self, request, *args, **kwargs):
  14. msg = '{cls}.determine_version() must be implemented.'
  15. raise NotImplementedError(msg.format(
  16. cls=self.__class__.__name__
  17. ))
  18. def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
  19. return _reverse(viewname, args, kwargs, request, format, **extra)
  20. def is_allowed_version(self, version):
  21. if not self.allowed_versions:
  22. return True
  23. return ((version is not None and version == self.default_version) or
  24. (version in self.allowed_versions))
  25. class AcceptHeaderVersioning(BaseVersioning):
  26. """
  27. GET /something/ HTTP/1.1
  28. Host: example.com
  29. Accept: application/json; version=1.0
  30. """
  31. invalid_version_message = _('Invalid version in "Accept" header.')
  32. def determine_version(self, request, *args, **kwargs):
  33. media_type = _MediaType(request.accepted_media_type)
  34. version = media_type.params.get(self.version_param, self.default_version)
  35. version = unicode_http_header(version)
  36. if not self.is_allowed_version(version):
  37. raise exceptions.NotAcceptable(self.invalid_version_message)
  38. return version
  39. # We don't need to implement `reverse`, as the versioning is based
  40. # on the `Accept` header, not on the request URL.
  41. class URLPathVersioning(BaseVersioning):
  42. """
  43. To the client this is the same style as `NamespaceVersioning`.
  44. The difference is in the backend - this implementation uses
  45. Django's URL keyword arguments to determine the version.
  46. An example URL conf for two views that accept two different versions.
  47. urlpatterns = [
  48. re_path(r'^(?P<version>[v1|v2]+)/users/$', users_list, name='users-list'),
  49. re_path(r'^(?P<version>[v1|v2]+)/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail')
  50. ]
  51. GET /1.0/something/ HTTP/1.1
  52. Host: example.com
  53. Accept: application/json
  54. """
  55. invalid_version_message = _('Invalid version in URL path.')
  56. def determine_version(self, request, *args, **kwargs):
  57. version = kwargs.get(self.version_param, self.default_version)
  58. if version is None:
  59. version = self.default_version
  60. if not self.is_allowed_version(version):
  61. raise exceptions.NotFound(self.invalid_version_message)
  62. return version
  63. def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
  64. if request.version is not None:
  65. kwargs = {} if (kwargs is None) else kwargs
  66. kwargs[self.version_param] = request.version
  67. return super().reverse(
  68. viewname, args, kwargs, request, format, **extra
  69. )
  70. class NamespaceVersioning(BaseVersioning):
  71. """
  72. To the client this is the same style as `URLPathVersioning`.
  73. The difference is in the backend - this implementation uses
  74. Django's URL namespaces to determine the version.
  75. An example URL conf that is namespaced into two separate versions
  76. # users/urls.py
  77. urlpatterns = [
  78. path('/users/', users_list, name='users-list'),
  79. path('/users/<int:pk>/', users_detail, name='users-detail')
  80. ]
  81. # urls.py
  82. urlpatterns = [
  83. path('v1/', include('users.urls', namespace='v1')),
  84. path('v2/', include('users.urls', namespace='v2'))
  85. ]
  86. GET /1.0/something/ HTTP/1.1
  87. Host: example.com
  88. Accept: application/json
  89. """
  90. invalid_version_message = _('Invalid version in URL path. Does not match any version namespace.')
  91. def determine_version(self, request, *args, **kwargs):
  92. resolver_match = getattr(request, 'resolver_match', None)
  93. if resolver_match is None or not resolver_match.namespace:
  94. return self.default_version
  95. # Allow for possibly nested namespaces.
  96. possible_versions = resolver_match.namespace.split(':')
  97. for version in possible_versions:
  98. if self.is_allowed_version(version):
  99. return version
  100. raise exceptions.NotFound(self.invalid_version_message)
  101. def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
  102. if request.version is not None:
  103. viewname = self.get_versioned_viewname(viewname, request)
  104. return super().reverse(
  105. viewname, args, kwargs, request, format, **extra
  106. )
  107. def get_versioned_viewname(self, viewname, request):
  108. return request.version + ':' + viewname
  109. class HostNameVersioning(BaseVersioning):
  110. """
  111. GET /something/ HTTP/1.1
  112. Host: v1.example.com
  113. Accept: application/json
  114. """
  115. hostname_regex = re.compile(r'^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$')
  116. invalid_version_message = _('Invalid version in hostname.')
  117. def determine_version(self, request, *args, **kwargs):
  118. hostname, separator, port = request.get_host().partition(':')
  119. match = self.hostname_regex.match(hostname)
  120. if not match:
  121. return self.default_version
  122. version = match.group(1)
  123. if not self.is_allowed_version(version):
  124. raise exceptions.NotFound(self.invalid_version_message)
  125. return version
  126. # We don't need to implement `reverse`, as the hostname will already be
  127. # preserved as part of the REST framework `reverse` implementation.
  128. class QueryParameterVersioning(BaseVersioning):
  129. """
  130. GET /something/?version=0.1 HTTP/1.1
  131. Host: example.com
  132. Accept: application/json
  133. """
  134. invalid_version_message = _('Invalid version in query parameter.')
  135. def determine_version(self, request, *args, **kwargs):
  136. version = request.query_params.get(self.version_param, self.default_version)
  137. if not self.is_allowed_version(version):
  138. raise exceptions.NotFound(self.invalid_version_message)
  139. return version
  140. def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
  141. url = super().reverse(
  142. viewname, args, kwargs, request, format, **extra
  143. )
  144. if request.version is not None:
  145. return replace_query_param(url, self.version_param, request.version)
  146. return url