routers.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
"""
Routers provide a convenient and consistent way of automatically
determining the URL conf for your API.

They are used by simply instantiating a Router class, and then registering
all the required ViewSets with that router.

For example, you might have a `urls.py` that looks something like this:

    router = routers.DefaultRouter()
    router.register('users', UserViewSet, 'user')
    router.register('accounts', AccountViewSet, 'account')

    urlpatterns = router.urls
"""
  12. import itertools
  13. from collections import OrderedDict, namedtuple
  14. from django.core.exceptions import ImproperlyConfigured
  15. from django.urls import NoReverseMatch, re_path
  16. from rest_framework import views
  17. from rest_framework.response import Response
  18. from rest_framework.reverse import reverse
  19. from rest_framework.schemas import SchemaGenerator
  20. from rest_framework.schemas.views import SchemaView
  21. from rest_framework.settings import api_settings
  22. from rest_framework.urlpatterns import format_suffix_patterns
  23. Route = namedtuple('Route', ['url', 'mapping', 'name', 'detail', 'initkwargs'])
  24. DynamicRoute = namedtuple('DynamicRoute', ['url', 'name', 'detail', 'initkwargs'])
  25. def escape_curly_brackets(url_path):
  26. """
  27. Double brackets in regex of url_path for escape string formatting
  28. """
  29. return url_path.replace('{', '{{').replace('}', '}}')
  30. def flatten(list_of_lists):
  31. """
  32. Takes an iterable of iterables, returns a single iterable containing all items
  33. """
  34. return itertools.chain(*list_of_lists)
  35. class BaseRouter:
  36. def __init__(self):
  37. self.registry = []
  38. def register(self, prefix, viewset, basename=None):
  39. if basename is None:
  40. basename = self.get_default_basename(viewset)
  41. self.registry.append((prefix, viewset, basename))
  42. # invalidate the urls cache
  43. if hasattr(self, '_urls'):
  44. del self._urls
  45. def get_default_basename(self, viewset):
  46. """
  47. If `basename` is not specified, attempt to automatically determine
  48. it from the viewset.
  49. """
  50. raise NotImplementedError('get_default_basename must be overridden')
  51. def get_urls(self):
  52. """
  53. Return a list of URL patterns, given the registered viewsets.
  54. """
  55. raise NotImplementedError('get_urls must be overridden')
  56. @property
  57. def urls(self):
  58. if not hasattr(self, '_urls'):
  59. self._urls = self.get_urls()
  60. return self._urls
  61. class SimpleRouter(BaseRouter):
  62. routes = [
  63. # List route.
  64. Route(
  65. url=r'^{prefix}{trailing_slash}$',
  66. mapping={
  67. 'get': 'list',
  68. 'post': 'create'
  69. },
  70. name='{basename}-list',
  71. detail=False,
  72. initkwargs={'suffix': 'List'}
  73. ),
  74. # Dynamically generated list routes. Generated using
  75. # @action(detail=False) decorator on methods of the viewset.
  76. DynamicRoute(
  77. url=r'^{prefix}/{url_path}{trailing_slash}$',
  78. name='{basename}-{url_name}',
  79. detail=False,
  80. initkwargs={}
  81. ),
  82. # Detail route.
  83. Route(
  84. url=r'^{prefix}/{lookup}{trailing_slash}$',
  85. mapping={
  86. 'get': 'retrieve',
  87. 'put': 'update',
  88. 'patch': 'partial_update',
  89. 'delete': 'destroy'
  90. },
  91. name='{basename}-detail',
  92. detail=True,
  93. initkwargs={'suffix': 'Instance'}
  94. ),
  95. # Dynamically generated detail routes. Generated using
  96. # @action(detail=True) decorator on methods of the viewset.
  97. DynamicRoute(
  98. url=r'^{prefix}/{lookup}/{url_path}{trailing_slash}$',
  99. name='{basename}-{url_name}',
  100. detail=True,
  101. initkwargs={}
  102. ),
  103. ]
  104. def __init__(self, trailing_slash=True):
  105. self.trailing_slash = '/' if trailing_slash else ''
  106. super().__init__()
  107. def get_default_basename(self, viewset):
  108. """
  109. If `basename` is not specified, attempt to automatically determine
  110. it from the viewset.
  111. """
  112. queryset = getattr(viewset, 'queryset', None)
  113. assert queryset is not None, '`basename` argument not specified, and could ' \
  114. 'not automatically determine the name from the viewset, as ' \
  115. 'it does not have a `.queryset` attribute.'
  116. return queryset.model._meta.object_name.lower()
  117. def get_routes(self, viewset):
  118. """
  119. Augment `self.routes` with any dynamically generated routes.
  120. Returns a list of the Route namedtuple.
  121. """
  122. # converting to list as iterables are good for one pass, known host needs to be checked again and again for
  123. # different functions.
  124. known_actions = list(flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)]))
  125. extra_actions = viewset.get_extra_actions()
  126. # checking action names against the known actions list
  127. not_allowed = [
  128. action.__name__ for action in extra_actions
  129. if action.__name__ in known_actions
  130. ]
  131. if not_allowed:
  132. msg = ('Cannot use the @action decorator on the following '
  133. 'methods, as they are existing routes: %s')
  134. raise ImproperlyConfigured(msg % ', '.join(not_allowed))
  135. # partition detail and list actions
  136. detail_actions = [action for action in extra_actions if action.detail]
  137. list_actions = [action for action in extra_actions if not action.detail]
  138. routes = []
  139. for route in self.routes:
  140. if isinstance(route, DynamicRoute) and route.detail:
  141. routes += [self._get_dynamic_route(route, action) for action in detail_actions]
  142. elif isinstance(route, DynamicRoute) and not route.detail:
  143. routes += [self._get_dynamic_route(route, action) for action in list_actions]
  144. else:
  145. routes.append(route)
  146. return routes
  147. def _get_dynamic_route(self, route, action):
  148. initkwargs = route.initkwargs.copy()
  149. initkwargs.update(action.kwargs)
  150. url_path = escape_curly_brackets(action.url_path)
  151. return Route(
  152. url=route.url.replace('{url_path}', url_path),
  153. mapping=action.mapping,
  154. name=route.name.replace('{url_name}', action.url_name),
  155. detail=route.detail,
  156. initkwargs=initkwargs,
  157. )
  158. def get_method_map(self, viewset, method_map):
  159. """
  160. Given a viewset, and a mapping of http methods to actions,
  161. return a new mapping which only includes any mappings that
  162. are actually implemented by the viewset.
  163. """
  164. bound_methods = {}
  165. for method, action in method_map.items():
  166. if hasattr(viewset, action):
  167. bound_methods[method] = action
  168. return bound_methods
  169. def get_lookup_regex(self, viewset, lookup_prefix=''):
  170. """
  171. Given a viewset, return the portion of URL regex that is used
  172. to match against a single instance.
  173. Note that lookup_prefix is not used directly inside REST rest_framework
  174. itself, but is required in order to nicely support nested router
  175. implementations, such as drf-nested-routers.
  176. https://github.com/alanjds/drf-nested-routers
  177. """
  178. base_regex = '(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})'
  179. # Use `pk` as default field, unset set. Default regex should not
  180. # consume `.json` style suffixes and should break at '/' boundaries.
  181. lookup_field = getattr(viewset, 'lookup_field', 'pk')
  182. lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
  183. lookup_value = getattr(viewset, 'lookup_value_regex', '[^/.]+')
  184. return base_regex.format(
  185. lookup_prefix=lookup_prefix,
  186. lookup_url_kwarg=lookup_url_kwarg,
  187. lookup_value=lookup_value
  188. )
  189. def get_urls(self):
  190. """
  191. Use the registered viewsets to generate a list of URL patterns.
  192. """
  193. ret = []
  194. for prefix, viewset, basename in self.registry:
  195. lookup = self.get_lookup_regex(viewset)
  196. routes = self.get_routes(viewset)
  197. for route in routes:
  198. # Only actions which actually exist on the viewset will be bound
  199. mapping = self.get_method_map(viewset, route.mapping)
  200. if not mapping:
  201. continue
  202. # Build the url pattern
  203. regex = route.url.format(
  204. prefix=prefix,
  205. lookup=lookup,
  206. trailing_slash=self.trailing_slash
  207. )
  208. # If there is no prefix, the first part of the url is probably
  209. # controlled by project's urls.py and the router is in an app,
  210. # so a slash in the beginning will (A) cause Django to give
  211. # warnings and (B) generate URLS that will require using '//'.
  212. if not prefix and regex[:2] == '^/':
  213. regex = '^' + regex[2:]
  214. initkwargs = route.initkwargs.copy()
  215. initkwargs.update({
  216. 'basename': basename,
  217. 'detail': route.detail,
  218. })
  219. view = viewset.as_view(mapping, **initkwargs)
  220. name = route.name.format(basename=basename)
  221. ret.append(re_path(regex, view, name=name))
  222. return ret
  223. class APIRootView(views.APIView):
  224. """
  225. The default basic root view for DefaultRouter
  226. """
  227. _ignore_model_permissions = True
  228. schema = None # exclude from schema
  229. api_root_dict = None
  230. def get(self, request, *args, **kwargs):
  231. # Return a plain {"name": "hyperlink"} response.
  232. ret = OrderedDict()
  233. namespace = request.resolver_match.namespace
  234. for key, url_name in self.api_root_dict.items():
  235. if namespace:
  236. url_name = namespace + ':' + url_name
  237. try:
  238. ret[key] = reverse(
  239. url_name,
  240. args=args,
  241. kwargs=kwargs,
  242. request=request,
  243. format=kwargs.get('format')
  244. )
  245. except NoReverseMatch:
  246. # Don't bail out if eg. no list routes exist, only detail routes.
  247. continue
  248. return Response(ret)
  249. class DefaultRouter(SimpleRouter):
  250. """
  251. The default router extends the SimpleRouter, but also adds in a default
  252. API root view, and adds format suffix patterns to the URLs.
  253. """
  254. include_root_view = True
  255. include_format_suffixes = True
  256. root_view_name = 'api-root'
  257. default_schema_renderers = None
  258. APIRootView = APIRootView
  259. APISchemaView = SchemaView
  260. SchemaGenerator = SchemaGenerator
  261. def __init__(self, *args, **kwargs):
  262. if 'root_renderers' in kwargs:
  263. self.root_renderers = kwargs.pop('root_renderers')
  264. else:
  265. self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
  266. super().__init__(*args, **kwargs)
  267. def get_api_root_view(self, api_urls=None):
  268. """
  269. Return a basic root view.
  270. """
  271. api_root_dict = OrderedDict()
  272. list_name = self.routes[0].name
  273. for prefix, viewset, basename in self.registry:
  274. api_root_dict[prefix] = list_name.format(basename=basename)
  275. return self.APIRootView.as_view(api_root_dict=api_root_dict)
  276. def get_urls(self):
  277. """
  278. Generate the list of URL patterns, including a default root view
  279. for the API, and appending `.json` style format suffixes.
  280. """
  281. urls = super().get_urls()
  282. if self.include_root_view:
  283. view = self.get_api_root_view(api_urls=urls)
  284. root_url = re_path(r'^$', view, name=self.root_view_name)
  285. urls.append(root_url)
  286. if self.include_format_suffixes:
  287. urls = format_suffix_patterns(urls)
  288. return urls