batch.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. from collections import defaultdict
  2. from itertools import chain
  3. import six
  4. import sqlalchemy as sa
  5. from sqlalchemy.orm import RelationshipProperty
  6. from sqlalchemy.orm.attributes import (
  7. set_committed_value, InstrumentedAttribute
  8. )
  9. from sqlalchemy.orm.session import object_session
  10. from sqlalchemy_utils.generic import GenericRelationshipProperty
  11. from sqlalchemy_utils.functions.orm import (
  12. list_local_values,
  13. list_local_remote_exprs,
  14. local_values,
  15. remote_column_names,
  16. remote_values,
  17. remote
  18. )
  19. class PathException(Exception):
  20. pass
  21. class DataException(Exception):
  22. pass
  23. class with_backrefs(object):
  24. """
  25. Marks given attribute path so that whenever its fetched with batch_fetch
  26. the backref relations are force set too. Very useful when dealing with
  27. certain many-to-many relationship scenarios.
  28. """
  29. def __init__(self, path):
  30. self.path = path
  31. class Path(object):
  32. """
  33. A class that represents an attribute path.
  34. """
  35. def __init__(self, entities, prop, populate_backrefs=False):
  36. self.validate_property(prop)
  37. self.property = prop
  38. self.entities = entities
  39. self.populate_backrefs = populate_backrefs
  40. self.fetcher = self.fetcher_class(self)
  41. def validate_property(self, prop):
  42. if (
  43. not isinstance(prop, RelationshipProperty) and
  44. not isinstance(prop, GenericRelationshipProperty)
  45. ):
  46. raise PathException(
  47. 'Given attribute is not a relationship property.'
  48. )
  49. @property
  50. def session(self):
  51. return object_session(self.entities[0])
  52. @property
  53. def model(self):
  54. return self.property.mapper.class_
  55. @classmethod
  56. def parse(cls, entities, path, populate_backrefs=False):
  57. if isinstance(path, six.string_types):
  58. attrs = path.split('.')
  59. if len(attrs) > 1:
  60. related_entities = []
  61. for entity in entities:
  62. related_entities.extend(getattr(entity, attrs[0]))
  63. if not related_entities:
  64. raise DataException('No related entities.')
  65. subpath = '.'.join(attrs[1:])
  66. return Path.parse(related_entities, subpath, populate_backrefs)
  67. else:
  68. attr = getattr(
  69. entities[0].__class__, attrs[0]
  70. )
  71. elif isinstance(path, InstrumentedAttribute):
  72. attr = path
  73. else:
  74. raise PathException('Unknown path type.')
  75. return Path(entities, attr.property, populate_backrefs)
  76. @property
  77. def fetcher_class(self):
  78. if isinstance(self.property, GenericRelationshipProperty):
  79. return GenericRelationshipFetcher
  80. else:
  81. if self.property.secondary is not None:
  82. return ManyToManyFetcher
  83. else:
  84. if self.property.direction.name == 'MANYTOONE':
  85. return ManyToOneFetcher
  86. else:
  87. return OneToManyFetcher
  88. def batch_fetch(entities, *attr_paths):
  89. """
  90. Batch fetch given relationship attribute for collection of entities.
  91. This function is in many cases a valid alternative for SQLAlchemy's
  92. subqueryload and performs lot better.
  93. :param entities: list of entities of the same type
  94. :param attr_paths:
  95. List of either InstrumentedAttribute objects or a strings representing
  96. the name of the instrumented attribute
  97. Example::
  98. from sqlalchemy_utils import batch_fetch
  99. users = session.query(User).limit(20).all()
  100. batch_fetch(users, User.phonenumbers)
  101. Function also accepts strings as attribute names: ::
  102. users = session.query(User).limit(20).all()
  103. batch_fetch(users, 'phonenumbers')
  104. Multiple attributes may be provided: ::
  105. clubs = session.query(Club).limit(20).all()
  106. batch_fetch(
  107. clubs,
  108. 'teams',
  109. 'teams.players',
  110. 'teams.players.user_groups'
  111. )
  112. You can also force populate backrefs: ::
  113. from sqlalchemy_utils import with_backrefs
  114. clubs = session.query(Club).limit(20).all()
  115. batch_fetch(
  116. clubs,
  117. 'teams',
  118. 'teams.players',
  119. with_backrefs('teams.players.user_groups')
  120. )
  121. """
  122. if entities:
  123. for path in attr_paths:
  124. try:
  125. fetcher = fetcher_factory(entities, path)
  126. fetcher.fetch()
  127. fetcher.populate()
  128. except DataException:
  129. pass
  130. def get_fetcher(entities, path, populate_backrefs):
  131. return Path.parse(entities, path, populate_backrefs).fetcher
  132. def fetcher_factory(entities, path):
  133. populate_backrefs = False
  134. if isinstance(path, with_backrefs):
  135. path = path.path
  136. populate_backrefs = True
  137. if isinstance(path, tuple):
  138. return CompositeFetcher(
  139. *(get_fetcher(entities, p, populate_backrefs) for p in path)
  140. )
  141. else:
  142. return get_fetcher(entities, path, populate_backrefs)
  143. class CompositeFetcher(object):
  144. def __init__(self, *fetchers):
  145. if not all(
  146. fetchers[0].path.model == fetcher.path.model
  147. for fetcher in fetchers
  148. ):
  149. raise PathException(
  150. 'Each relationship property must have the same class when '
  151. 'using CompositeFetcher.'
  152. )
  153. self.fetchers = fetchers
  154. @property
  155. def session(self):
  156. return self.fetchers[0].path.session
  157. @property
  158. def model(self):
  159. return self.fetchers[0].path.model
  160. @property
  161. def condition(self):
  162. return sa.or_(
  163. *(fetcher.condition for fetcher in self.fetchers)
  164. )
  165. @property
  166. def related_entities(self):
  167. return self.session.query(self.model).filter(self.condition)
  168. def fetch(self):
  169. for entity in self.related_entities:
  170. for fetcher in self.fetchers:
  171. if any(remote_values(fetcher.prop, entity)):
  172. fetcher.append_entity(entity)
  173. def populate(self):
  174. for fetcher in self.fetchers:
  175. fetcher.populate()
  176. class Fetcher(object):
  177. def __init__(self, path):
  178. self.path = path
  179. self.prop = self.path.property
  180. default = list if self.prop.uselist else lambda: None
  181. self.parent_dict = defaultdict(default)
  182. @property
  183. def relation_query_base(self):
  184. return self.path.session.query(self.path.model)
  185. @property
  186. def related_entities(self):
  187. return self.relation_query_base.filter(self.condition)
  188. def populate_backrefs(self, related_entities):
  189. """
  190. Populates backrefs for given related entities.
  191. """
  192. backref_dict = dict(
  193. (local_values(self.prop, value[0]), [])
  194. for value in related_entities
  195. )
  196. for value in related_entities:
  197. backref_dict[local_values(self.prop, value[0])].append(
  198. self.path.session.query(self.path.entities[0].__class__).get(
  199. tuple(value[1:])
  200. )
  201. )
  202. for value in related_entities:
  203. set_committed_value(
  204. value[0],
  205. self.prop.back_populates,
  206. backref_dict[local_values(self.prop, value[0])]
  207. )
  208. def populate(self):
  209. """
  210. Populate batch fetched entities to parent objects.
  211. """
  212. for entity in self.path.entities:
  213. set_committed_value(
  214. entity,
  215. self.prop.key,
  216. self.parent_dict[local_values(self.prop, entity)]
  217. )
  218. if self.path.populate_backrefs:
  219. self.populate_backrefs(self.related_entities)
  220. @property
  221. def condition(self):
  222. names = list(remote_column_names(self.prop))
  223. if len(names) == 1:
  224. attr = getattr(remote(self.prop), names[0])
  225. return attr.in_(
  226. v[0] for v in list_local_values(self.prop, self.path.entities)
  227. )
  228. elif len(names) > 1:
  229. return sa.or_(
  230. *list_local_remote_exprs(self.prop, self.path.entities)
  231. )
  232. else:
  233. raise PathException(
  234. 'Could not obtain remote column names.'
  235. )
  236. def fetch(self):
  237. for entity in self.related_entities:
  238. self.append_entity(entity)
  239. class GenericRelationshipFetcher(object):
  240. def __init__(self, path):
  241. self.path = path
  242. self.prop = self.path.property
  243. self.parent_dict = defaultdict(lambda: None)
  244. def fetch(self):
  245. for entity in self.related_entities:
  246. self.append_entity(entity)
  247. def append_entity(self, entity):
  248. self.parent_dict[remote_values(self.prop, entity)] = entity
  249. def populate(self):
  250. """
  251. Populate batch fetched entities to parent objects.
  252. """
  253. for entity in self.path.entities:
  254. set_committed_value(
  255. entity,
  256. self.prop.key,
  257. self.parent_dict[local_values(self.prop, entity)]
  258. )
  259. @property
  260. def related_entities(self):
  261. id_dict = defaultdict(list)
  262. for entity in self.path.entities:
  263. discriminator = getattr(entity, self.prop._discriminator_col.key)
  264. for id_col in self.prop._id_cols:
  265. id_dict[discriminator].append(
  266. getattr(entity, id_col.key)
  267. )
  268. return chain(*self._queries(sa.inspect(entity), id_dict))
  269. def _queries(self, state, id_dict):
  270. for discriminator, ids in six.iteritems(id_dict):
  271. class_ = state.class_._decl_class_registry.get(discriminator)
  272. yield self.path.session.query(
  273. class_
  274. ).filter(
  275. class_.id.in_(ids)
  276. )
  277. class ManyToManyFetcher(Fetcher):
  278. @property
  279. def relation_query_base(self):
  280. return (
  281. self.path.session
  282. .query(
  283. self.path.model,
  284. *[
  285. getattr(remote(self.prop), name)
  286. for name in remote_column_names(self.prop)
  287. ]
  288. )
  289. .join(
  290. self.prop.secondary, self.prop.secondaryjoin
  291. )
  292. )
  293. def fetch(self):
  294. for value in self.related_entities:
  295. self.parent_dict[tuple(value[1:])].append(
  296. value[0]
  297. )
  298. class ManyToOneFetcher(Fetcher):
  299. def append_entity(self, entity):
  300. self.parent_dict[remote_values(self.prop, entity)] = entity
  301. class OneToManyFetcher(Fetcher):
  302. def append_entity(self, entity):
  303. self.parent_dict[remote_values(self.prop, entity)].append(
  304. entity
  305. )