aggregates.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575
  1. """
  2. SQLAlchemy-Utils provides a way of automatically calculating aggregate values of
  3. related models and saving them to the parent model.
  4. This solution is inspired by RoR counter cache,
  5. `counter_culture`_ and `stackoverflow reply by Michael Bayer`_.
  6. Why?
  7. ----
  8. Many times you may have situations where you need to dynamically calculate some
  9. aggregate value for a given model. Some simple examples include:
  10. - Number of products in a catalog
  11. - Average rating for movie
  12. - Latest forum post
  13. - Total price of orders for given customer
  14. All of these aggregates can be elegantly implemented with the SQLAlchemy
  15. column_property_ function. However, as your data grows, calculating these
  16. values on the fly might start to hurt the performance of your application. The
  17. more aggregates you are using the more performance penalty you get.
  18. This module provides a way of calculating these values automatically and
  19. efficiently at the time of modification rather than on the fly.
  20. Features
  21. --------
  22. * Automatically updates aggregate columns when aggregated values change
  23. * Supports aggregate values through an arbitrary number of relationship levels
  24. * Highly optimized: uses single query per transaction per aggregate column
  25. * Aggregated columns can be of any data type and use any selectable scalar
  26. expression
  27. .. _column_property:
  28. http://docs.sqlalchemy.org/en/latest/orm/mapper_config.html#using-column-property
  29. .. _counter_culture: https://github.com/magnusvk/counter_culture
  30. .. _stackoverflow reply by Michael Bayer:
  31. http://stackoverflow.com/questions/13693872/
  32. Simple aggregates
  33. -----------------
  34. ::
  35. from sqlalchemy_utils import aggregated
  36. class Thread(Base):
  37. __tablename__ = 'thread'
  38. id = sa.Column(sa.Integer, primary_key=True)
  39. name = sa.Column(sa.Unicode(255))
  40. @aggregated('comments', sa.Column(sa.Integer))
  41. def comment_count(self):
  42. return sa.func.count('1')
  43. comments = sa.orm.relationship(
  44. 'Comment',
  45. backref='thread'
  46. )
  47. class Comment(Base):
  48. __tablename__ = 'comment'
  49. id = sa.Column(sa.Integer, primary_key=True)
  50. content = sa.Column(sa.UnicodeText)
  51. thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))
  52. thread = Thread(name=u'SQLAlchemy development')
  53. thread.comments.append(Comment(content=u'Going good!'))
  54. thread.comments.append(Comment(content=u'Great new features!'))
  55. session.add(thread)
  56. session.commit()
  57. thread.comment_count # 2
  58. Custom aggregate expressions
  59. ----------------------------
  60. Aggregate expression can be virtually any SQL expression not just a simple
  61. function taking one parameter. You can try things such as subqueries and
  62. different kinds of functions.
  63. In the following example we have a Catalog of products where each catalog
  64. knows the net worth of its products.
  65. ::
  66. from sqlalchemy_utils import aggregated
  67. class Catalog(Base):
  68. __tablename__ = 'catalog'
  69. id = sa.Column(sa.Integer, primary_key=True)
  70. name = sa.Column(sa.Unicode(255))
  71. @aggregated('products', sa.Column(sa.Integer))
  72. def net_worth(self):
  73. return sa.func.sum(Product.price)
  74. products = sa.orm.relationship('Product')
  75. class Product(Base):
  76. __tablename__ = 'product'
  77. id = sa.Column(sa.Integer, primary_key=True)
  78. name = sa.Column(sa.Unicode(255))
  79. price = sa.Column(sa.Numeric)
  80. catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
  81. Now the net_worth column of the Catalog model will be automatically updated whenever:
  82. * A new product is added to the catalog
  83. * A product is deleted from the catalog
  84. * The price of a catalog product is changed
  85. ::
  86. from decimal import Decimal
  87. product1 = Product(name='Some product', price=Decimal(1000))
  88. product2 = Product(name='Some other product', price=Decimal(500))
  89. catalog = Catalog(
  90. name=u'My first catalog',
  91. products=[
  92. product1,
  93. product2
  94. ]
  95. )
  96. session.add(catalog)
  97. session.commit()
  98. session.refresh(catalog)
  99. catalog.net_worth # 1500
  100. session.delete(product2)
  101. session.commit()
  102. session.refresh(catalog)
  103. catalog.net_worth # 1000
  104. product1.price = 2000
  105. session.commit()
  106. session.refresh(catalog)
  107. catalog.net_worth # 2000
  108. Multiple aggregates per class
  109. -----------------------------
  110. Sometimes you may need to define multiple aggregate values for the same class. If
  111. you need to define lots of relationships pointing to the same class, remember to
  112. define the relationships as viewonly when possible.
  113. ::
  114. from sqlalchemy_utils import aggregated
  115. class Customer(Base):
  116. __tablename__ = 'customer'
  117. id = sa.Column(sa.Integer, primary_key=True)
  118. name = sa.Column(sa.Unicode(255))
  119. @aggregated('orders', sa.Column(sa.Integer))
  120. def orders_sum(self):
  121. return sa.func.sum(Order.price)
  122. @aggregated('invoiced_orders', sa.Column(sa.Integer))
  123. def invoiced_orders_sum(self):
  124. return sa.func.sum(Order.price)
  125. orders = sa.orm.relationship('Order')
  126. invoiced_orders = sa.orm.relationship(
  127. 'Order',
  128. primaryjoin=
  129. 'sa.and_(Order.customer_id == Customer.id, Order.invoiced)',
  130. viewonly=True
  131. )
  132. class Order(Base):
  133. __tablename__ = 'order'
  134. id = sa.Column(sa.Integer, primary_key=True)
  135. name = sa.Column(sa.Unicode(255))
  136. price = sa.Column(sa.Numeric)
  137. invoiced = sa.Column(sa.Boolean, default=False)
  138. customer_id = sa.Column(sa.Integer, sa.ForeignKey(Customer.id))
  139. Many-to-Many aggregates
  140. -----------------------
  141. Aggregate expressions also support many-to-many relationships. Typical use
  142. cases include:
  143. 1. Friend count of a user
  144. 2. Number of groups a given user belongs to
  145. ::
  146. user_group = sa.Table('user_group', Base.metadata,
  147. sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
  148. sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
  149. )
  150. class User(Base):
  151. __tablename__ = 'user'
  152. id = sa.Column(sa.Integer, primary_key=True)
  153. name = sa.Column(sa.Unicode(255))
  154. @aggregated('groups', sa.Column(sa.Integer, default=0))
  155. def group_count(self):
  156. return sa.func.count('1')
  157. groups = sa.orm.relationship(
  158. 'Group',
  159. backref='users',
  160. secondary=user_group
  161. )
  162. class Group(Base):
  163. __tablename__ = 'group'
  164. id = sa.Column(sa.Integer, primary_key=True)
  165. name = sa.Column(sa.Unicode(255))
  166. user = User(name=u'John Matrix')
  167. user.groups = [Group(name=u'Group A'), Group(name=u'Group B')]
  168. session.add(user)
  169. session.commit()
  170. session.refresh(user)
  171. user.group_count # 2
  172. Multi-level aggregates
  173. ----------------------
  174. Aggregates can span multiple relationships. In the following example
  175. each Catalog has a net_worth which is the sum of the prices of all products
  176. in all of its categories.
  177. ::
  178. from sqlalchemy_utils import aggregated
  179. class Catalog(Base):
  180. __tablename__ = 'catalog'
  181. id = sa.Column(sa.Integer, primary_key=True)
  182. name = sa.Column(sa.Unicode(255))
  183. @aggregated('categories.products', sa.Column(sa.Integer))
  184. def net_worth(self):
  185. return sa.func.sum(Product.price)
  186. categories = sa.orm.relationship('Category')
  187. class Category(Base):
  188. __tablename__ = 'category'
  189. id = sa.Column(sa.Integer, primary_key=True)
  190. name = sa.Column(sa.Unicode(255))
  191. catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
  192. products = sa.orm.relationship('Product')
  193. class Product(Base):
  194. __tablename__ = 'product'
  195. id = sa.Column(sa.Integer, primary_key=True)
  196. name = sa.Column(sa.Unicode(255))
  197. price = sa.Column(sa.Numeric)
  198. category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
  199. Examples
  200. --------
  201. Average movie rating
  202. ^^^^^^^^^^^^^^^^^^^^
  203. ::
  204. from sqlalchemy_utils import aggregated
  205. class Movie(Base):
  206. __tablename__ = 'movie'
  207. id = sa.Column(sa.Integer, primary_key=True)
  208. name = sa.Column(sa.Unicode(255))
  209. @aggregated('ratings', sa.Column(sa.Numeric))
  210. def avg_rating(self):
  211. return sa.func.avg(Rating.stars)
  212. ratings = sa.orm.relationship('Rating')
  213. class Rating(Base):
  214. __tablename__ = 'rating'
  215. id = sa.Column(sa.Integer, primary_key=True)
  216. stars = sa.Column(sa.Integer)
  217. movie_id = sa.Column(sa.Integer, sa.ForeignKey(Movie.id))
  218. movie = Movie(name=u'Terminator 2')
  219. movie.ratings.append(Rating(stars=5))
  220. movie.ratings.append(Rating(stars=4))
  221. movie.ratings.append(Rating(stars=3))
  222. session.add(movie)
  223. session.commit()
  224. movie.avg_rating # 4
  225. TODO
  226. ----
  227. * Special consideration should be given to `deadlocks`_.
  228. .. _deadlocks:
  229. http://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html
  230. """
  231. from collections import defaultdict
  232. from weakref import WeakKeyDictionary
  233. import sqlalchemy as sa
  234. from sqlalchemy.ext.declarative import declared_attr
  235. from sqlalchemy.sql.functions import _FunctionGenerator
  236. from .compat import get_scalar_subquery
  237. from .functions.orm import get_column_key
  238. from .relationships import (
  239. chained_join,
  240. path_to_relationships,
  241. select_correlated_expression
  242. )
# Registry filled by AggregatedAttribute.__get__: maps each class on which an
# aggregated attribute was accessed to a list of (fget, relationship, column)
# tuples, one per aggregated attribute. Keys are held weakly, so an entry
# disappears once its class is garbage-collected.
aggregated_attrs = WeakKeyDictionary()
  244. class AggregatedAttribute(declared_attr):
  245. def __init__(
  246. self,
  247. fget,
  248. relationship,
  249. column,
  250. *args,
  251. **kwargs
  252. ):
  253. super(AggregatedAttribute, self).__init__(fget, *args, **kwargs)
  254. self.__doc__ = fget.__doc__
  255. self.column = column
  256. self.relationship = relationship
  257. def __get__(desc, self, cls):
  258. value = (desc.fget, desc.relationship, desc.column)
  259. if cls not in aggregated_attrs:
  260. aggregated_attrs[cls] = [value]
  261. else:
  262. aggregated_attrs[cls].append(value)
  263. return desc.column
  264. def local_condition(prop, objects):
  265. pairs = prop.local_remote_pairs
  266. if prop.secondary is not None:
  267. parent_column = pairs[1][0]
  268. fetched_column = pairs[1][0]
  269. else:
  270. parent_column = pairs[0][0]
  271. fetched_column = pairs[0][1]
  272. key = get_column_key(prop.mapper, fetched_column)
  273. values = []
  274. for obj in objects:
  275. try:
  276. values.append(getattr(obj, key))
  277. except sa.orm.exc.ObjectDeletedError:
  278. pass
  279. if values:
  280. return parent_column.in_(values)
  281. def aggregate_expression(expr, class_):
  282. if isinstance(expr, sa.sql.visitors.Visitable):
  283. return expr
  284. elif isinstance(expr, _FunctionGenerator):
  285. return expr(sa.sql.text('1'))
  286. else:
  287. return expr(class_)
  288. class AggregatedValue(object):
  289. def __init__(self, class_, attr, path, expr):
  290. self.class_ = class_
  291. self.attr = attr
  292. self.path = path
  293. self.relationships = list(
  294. reversed(path_to_relationships(path, class_))
  295. )
  296. self.expr = aggregate_expression(expr, class_)
  297. @property
  298. def aggregate_query(self):
  299. query = select_correlated_expression(
  300. self.class_,
  301. self.expr,
  302. self.path,
  303. self.relationships[0].mapper.class_
  304. )
  305. return get_scalar_subquery(query)
  306. def update_query(self, objects):
  307. table = self.class_.__table__
  308. query = table.update().values(
  309. {self.attr: self.aggregate_query}
  310. )
  311. if len(self.relationships) == 1:
  312. prop = self.relationships[-1].property
  313. condition = local_condition(prop, objects)
  314. if condition is not None:
  315. return query.where(condition)
  316. else:
  317. # Builds query such as:
  318. #
  319. # UPDATE catalog SET product_count = (aggregate_query)
  320. # WHERE id IN (
  321. # SELECT catalog_id
  322. # FROM category
  323. # INNER JOIN sub_category
  324. # ON category.id = sub_category.category_id
  325. # WHERE sub_category.id IN (product_sub_category_ids)
  326. # )
  327. property_ = self.relationships[-1].property
  328. remote_pairs = property_.local_remote_pairs
  329. local = remote_pairs[0][0]
  330. remote = remote_pairs[0][1]
  331. condition = local_condition(
  332. self.relationships[0].property,
  333. objects
  334. )
  335. if condition is not None:
  336. return query.where(
  337. local.in_(
  338. sa.select(
  339. [remote],
  340. from_obj=[
  341. chained_join(*reversed(self.relationships))
  342. ]
  343. ).where(
  344. condition
  345. )
  346. )
  347. )
  348. class AggregationManager(object):
  349. def __init__(self):
  350. self.reset()
  351. def reset(self):
  352. self.generator_registry = defaultdict(list)
  353. def register_listeners(self):
  354. sa.event.listen(
  355. sa.orm.mapper,
  356. 'after_configured',
  357. self.update_generator_registry
  358. )
  359. sa.event.listen(
  360. sa.orm.session.Session,
  361. 'after_flush',
  362. self.construct_aggregate_queries
  363. )
  364. def update_generator_registry(self):
  365. for class_, attrs in aggregated_attrs.items():
  366. for expr, path, column in attrs:
  367. value = AggregatedValue(
  368. class_=class_,
  369. attr=column,
  370. path=path,
  371. expr=expr(class_)
  372. )
  373. key = value.relationships[0].mapper.class_
  374. self.generator_registry[key].append(
  375. value
  376. )
  377. def construct_aggregate_queries(self, session, ctx):
  378. object_dict = defaultdict(list)
  379. for obj in session:
  380. for class_ in self.generator_registry:
  381. if isinstance(obj, class_):
  382. object_dict[class_].append(obj)
  383. for class_, objects in object_dict.items():
  384. for aggregate_value in self.generator_registry[class_]:
  385. query = aggregate_value.update_query(objects)
  386. if query is not None:
  387. session.execute(query)
# Module-level manager shared by every aggregated attribute. Importing this
# module creates it and registers its SQLAlchemy event listeners (mapper
# 'after_configured' and Session 'after_flush').
manager = AggregationManager()
manager.register_listeners()
  390. def aggregated(
  391. relationship,
  392. column
  393. ):
  394. """
  395. Decorator that generates an aggregated attribute. The decorated function
  396. should return an aggregate select expression.
  397. :param relationship:
  398. Defines the relationship of which the aggregate is calculated from.
  399. The class needs to have given relationship in order to calculate the
  400. aggregate.
  401. :param column:
  402. SQLAlchemy Column object. The column definition of this aggregate
  403. attribute.
  404. """
  405. def wraps(func):
  406. return AggregatedAttribute(
  407. func,
  408. relationship,
  409. column
  410. )
  411. return wraps