horizontal_shard.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. # ext/horizontal_shard.py
  2. # Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Horizontal sharding support.

Defines a rudimentary 'horizontal sharding' system which allows a Session to
distribute queries and persistence operations across multiple databases.

For a usage example, see the :ref:`examples_sharding` example included in
the source distribution.

"""
  13. from .. import event
  14. from .. import exc
  15. from .. import inspect
  16. from .. import util
  17. from ..orm.query import Query
  18. from ..orm.session import Session
  19. __all__ = ["ShardedSession", "ShardedQuery"]
  20. class ShardedQuery(Query):
  21. def __init__(self, *args, **kwargs):
  22. super(ShardedQuery, self).__init__(*args, **kwargs)
  23. self.id_chooser = self.session.id_chooser
  24. self.query_chooser = self.session.query_chooser
  25. self.execute_chooser = self.session.execute_chooser
  26. self._shard_id = None
  27. def set_shard(self, shard_id):
  28. """Return a new query, limited to a single shard ID.
  29. All subsequent operations with the returned query will
  30. be against the single shard regardless of other state.
  31. The shard_id can be passed for a 2.0 style execution to the
  32. bind_arguments dictionary of :meth:`.Session.execute`::
  33. results = session.execute(
  34. stmt,
  35. bind_arguments={"shard_id": "my_shard"}
  36. )
  37. """
  38. return self.execution_options(_sa_shard_id=shard_id)
  39. class ShardedSession(Session):
  40. def __init__(
  41. self,
  42. shard_chooser,
  43. id_chooser,
  44. execute_chooser=None,
  45. shards=None,
  46. query_cls=ShardedQuery,
  47. **kwargs
  48. ):
  49. """Construct a ShardedSession.
  50. :param shard_chooser: A callable which, passed a Mapper, a mapped
  51. instance, and possibly a SQL clause, returns a shard ID. This id
  52. may be based off of the attributes present within the object, or on
  53. some round-robin scheme. If the scheme is based on a selection, it
  54. should set whatever state on the instance to mark it in the future as
  55. participating in that shard.
  56. :param id_chooser: A callable, passed a query and a tuple of identity
  57. values, which should return a list of shard ids where the ID might
  58. reside. The databases will be queried in the order of this listing.
  59. :param execute_chooser: For a given :class:`.ORMExecuteState`,
  60. returns the list of shard_ids
  61. where the query should be issued. Results from all shards returned
  62. will be combined together into a single listing.
  63. .. versionchanged:: 1.4 The ``execute_chooser`` parameter
  64. supersedes the ``query_chooser`` parameter.
  65. :param shards: A dictionary of string shard names
  66. to :class:`~sqlalchemy.engine.Engine` objects.
  67. """
  68. query_chooser = kwargs.pop("query_chooser", None)
  69. super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs)
  70. event.listen(
  71. self, "do_orm_execute", execute_and_instances, retval=True
  72. )
  73. self.shard_chooser = shard_chooser
  74. self.id_chooser = id_chooser
  75. if query_chooser:
  76. util.warn_deprecated(
  77. "The ``query_choser`` parameter is deprecated; "
  78. "please use ``execute_chooser``.",
  79. "1.4",
  80. )
  81. if execute_chooser:
  82. raise exc.ArgumentError(
  83. "Can't pass query_chooser and execute_chooser "
  84. "at the same time."
  85. )
  86. def execute_chooser(orm_context):
  87. return query_chooser(orm_context.statement)
  88. self.execute_chooser = execute_chooser
  89. else:
  90. self.execute_chooser = execute_chooser
  91. self.query_chooser = query_chooser
  92. self.__binds = {}
  93. if shards is not None:
  94. for k in shards:
  95. self.bind_shard(k, shards[k])
  96. def _identity_lookup(
  97. self,
  98. mapper,
  99. primary_key_identity,
  100. identity_token=None,
  101. lazy_loaded_from=None,
  102. **kw
  103. ):
  104. """override the default :meth:`.Session._identity_lookup` method so
  105. that we search for a given non-token primary key identity across all
  106. possible identity tokens (e.g. shard ids).
  107. .. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from
  108. the :class:`_query.Query` object to the :class:`.Session`.
  109. """
  110. if identity_token is not None:
  111. return super(ShardedSession, self)._identity_lookup(
  112. mapper,
  113. primary_key_identity,
  114. identity_token=identity_token,
  115. **kw
  116. )
  117. else:
  118. q = self.query(mapper)
  119. if lazy_loaded_from:
  120. q = q._set_lazyload_from(lazy_loaded_from)
  121. for shard_id in self.id_chooser(q, primary_key_identity):
  122. obj = super(ShardedSession, self)._identity_lookup(
  123. mapper,
  124. primary_key_identity,
  125. identity_token=shard_id,
  126. lazy_loaded_from=lazy_loaded_from,
  127. **kw
  128. )
  129. if obj is not None:
  130. return obj
  131. return None
  132. def _choose_shard_and_assign(self, mapper, instance, **kw):
  133. if instance is not None:
  134. state = inspect(instance)
  135. if state.key:
  136. token = state.key[2]
  137. assert token is not None
  138. return token
  139. elif state.identity_token:
  140. return state.identity_token
  141. shard_id = self.shard_chooser(mapper, instance, **kw)
  142. if instance is not None:
  143. state.identity_token = shard_id
  144. return shard_id
  145. def connection_callable(
  146. self, mapper=None, instance=None, shard_id=None, **kwargs
  147. ):
  148. """Provide a :class:`_engine.Connection` to use in the unit of work
  149. flush process.
  150. """
  151. if shard_id is None:
  152. shard_id = self._choose_shard_and_assign(mapper, instance)
  153. if self.in_transaction():
  154. return self.get_transaction().connection(mapper, shard_id=shard_id)
  155. else:
  156. return self.get_bind(
  157. mapper, shard_id=shard_id, instance=instance
  158. ).connect(**kwargs)
  159. def get_bind(
  160. self, mapper=None, shard_id=None, instance=None, clause=None, **kw
  161. ):
  162. if shard_id is None:
  163. shard_id = self._choose_shard_and_assign(
  164. mapper, instance, clause=clause
  165. )
  166. return self.__binds[shard_id]
  167. def bind_shard(self, shard_id, bind):
  168. self.__binds[shard_id] = bind
  169. def execute_and_instances(orm_context):
  170. if orm_context.is_select:
  171. load_options = active_options = orm_context.load_options
  172. update_options = None
  173. elif orm_context.is_update or orm_context.is_delete:
  174. load_options = None
  175. update_options = active_options = orm_context.update_delete_options
  176. else:
  177. load_options = update_options = active_options = None
  178. session = orm_context.session
  179. def iter_for_shard(shard_id, load_options, update_options):
  180. execution_options = dict(orm_context.local_execution_options)
  181. bind_arguments = dict(orm_context.bind_arguments)
  182. bind_arguments["shard_id"] = shard_id
  183. if orm_context.is_select:
  184. load_options += {"_refresh_identity_token": shard_id}
  185. execution_options["_sa_orm_load_options"] = load_options
  186. elif orm_context.is_update or orm_context.is_delete:
  187. update_options += {"_refresh_identity_token": shard_id}
  188. execution_options["_sa_orm_update_options"] = update_options
  189. return orm_context.invoke_statement(
  190. bind_arguments=bind_arguments, execution_options=execution_options
  191. )
  192. if active_options and active_options._refresh_identity_token is not None:
  193. shard_id = active_options._refresh_identity_token
  194. elif "_sa_shard_id" in orm_context.execution_options:
  195. shard_id = orm_context.execution_options["_sa_shard_id"]
  196. elif "shard_id" in orm_context.bind_arguments:
  197. shard_id = orm_context.bind_arguments["shard_id"]
  198. else:
  199. shard_id = None
  200. if shard_id is not None:
  201. return iter_for_shard(shard_id, load_options, update_options)
  202. else:
  203. partial = []
  204. for shard_id in session.execute_chooser(orm_context):
  205. result_ = iter_for_shard(shard_id, load_options, update_options)
  206. partial.append(result_)
  207. return partial[0].merge(*partial[1:])