foreign_keys.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. from collections import defaultdict
  2. from itertools import groupby
  3. import sqlalchemy as sa
  4. from sqlalchemy.exc import NoInspectionAvailable
  5. from sqlalchemy.orm import object_session
  6. from sqlalchemy.schema import ForeignKeyConstraint, MetaData, Table
  7. from ..query_chain import QueryChain
  8. from .database import has_index
  9. from .orm import _get_class_registry, get_column_key, get_mapper, get_tables
  10. def get_foreign_key_values(fk, obj):
  11. return dict(
  12. (
  13. fk.constraint.columns.values()[index].key,
  14. getattr(obj, element.column.key)
  15. )
  16. for
  17. index, element
  18. in
  19. enumerate(fk.constraint.elements)
  20. )
  21. def group_foreign_keys(foreign_keys):
  22. """
  23. Return a groupby iterator that groups given foreign keys by table.
  24. :param foreign_keys: a sequence of foreign keys
  25. ::
  26. foreign_keys = get_referencing_foreign_keys(User)
  27. for table, fks in group_foreign_keys(foreign_keys):
  28. # do something
  29. pass
  30. .. seealso:: :func:`get_referencing_foreign_keys`
  31. .. versionadded: 0.26.1
  32. """
  33. foreign_keys = sorted(
  34. foreign_keys, key=lambda key: key.constraint.table.name
  35. )
  36. return groupby(foreign_keys, lambda key: key.constraint.table)
  37. def get_referencing_foreign_keys(mixed):
  38. """
  39. Returns referencing foreign keys for given Table object or declarative
  40. class.
  41. :param mixed:
  42. SA Table object or SA declarative class
  43. ::
  44. get_referencing_foreign_keys(User) # set([ForeignKey('user.id')])
  45. get_referencing_foreign_keys(User.__table__)
  46. This function also understands inheritance. This means it returns
  47. all foreign keys that reference any table in the class inheritance tree.
  48. Let's say you have three classes which use joined table inheritance,
  49. namely TextItem, Article and BlogPost with Article and BlogPost inheriting
  50. TextItem.
  51. ::
  52. # This will check all foreign keys that reference either article table
  53. # or textitem table.
  54. get_referencing_foreign_keys(Article)
  55. .. seealso:: :func:`get_tables`
  56. """
  57. if isinstance(mixed, sa.Table):
  58. tables = [mixed]
  59. else:
  60. tables = get_tables(mixed)
  61. referencing_foreign_keys = set()
  62. for table in mixed.metadata.tables.values():
  63. if table not in tables:
  64. for constraint in table.constraints:
  65. if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint):
  66. for fk in constraint.elements:
  67. if any(fk.references(t) for t in tables):
  68. referencing_foreign_keys.add(fk)
  69. return referencing_foreign_keys
  70. def merge_references(from_, to, foreign_keys=None):
  71. """
  72. Merge the references of an entity into another entity.
  73. Consider the following models::
  74. class User(self.Base):
  75. __tablename__ = 'user'
  76. id = sa.Column(sa.Integer, primary_key=True)
  77. name = sa.Column(sa.String(255))
  78. def __repr__(self):
  79. return 'User(name=%r)' % self.name
  80. class BlogPost(self.Base):
  81. __tablename__ = 'blog_post'
  82. id = sa.Column(sa.Integer, primary_key=True)
  83. title = sa.Column(sa.String(255))
  84. author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
  85. author = sa.orm.relationship(User)
  86. Now lets add some data::
  87. john = self.User(name='John')
  88. jack = self.User(name='Jack')
  89. post = self.BlogPost(title='Some title', author=john)
  90. post2 = self.BlogPost(title='Other title', author=jack)
  91. self.session.add_all([
  92. john,
  93. jack,
  94. post,
  95. post2
  96. ])
  97. self.session.commit()
  98. If we wanted to merge all John's references to Jack it would be as easy as
  99. ::
  100. merge_references(john, jack)
  101. self.session.commit()
  102. post.author # User(name='Jack')
  103. post2.author # User(name='Jack')
  104. :param from_: an entity to merge into another entity
  105. :param to: an entity to merge another entity into
  106. :param foreign_keys: A sequence of foreign keys. By default this is None
  107. indicating all referencing foreign keys should be used.
  108. .. seealso: :func:`dependent_objects`
  109. .. versionadded: 0.26.1
  110. """
  111. if from_.__tablename__ != to.__tablename__:
  112. raise TypeError('The tables of given arguments do not match.')
  113. session = object_session(from_)
  114. foreign_keys = get_referencing_foreign_keys(from_)
  115. for fk in foreign_keys:
  116. old_values = get_foreign_key_values(fk, from_)
  117. new_values = get_foreign_key_values(fk, to)
  118. criteria = (
  119. getattr(fk.constraint.table.c, key) == value
  120. for key, value in old_values.items()
  121. )
  122. try:
  123. mapper = get_mapper(fk.constraint.table)
  124. except ValueError:
  125. query = (
  126. fk.constraint.table
  127. .update()
  128. .where(sa.and_(*criteria))
  129. .values(new_values)
  130. )
  131. session.execute(query)
  132. else:
  133. (
  134. session.query(mapper.class_)
  135. .filter_by(**old_values)
  136. .update(
  137. new_values,
  138. 'evaluate'
  139. )
  140. )
  141. def dependent_objects(obj, foreign_keys=None):
  142. """
  143. Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates
  144. through all dependent objects for given SQLAlchemy object.
  145. Consider a User object is referenced in various articles and also in
  146. various orders. Getting all these dependent objects is as easy as::
  147. from sqlalchemy_utils import dependent_objects
  148. dependent_objects(user)
  149. If you expect an object to have lots of dependent_objects it might be good
  150. to limit the results::
  151. dependent_objects(user).limit(5)
  152. The common use case is checking for all restrict dependent objects before
  153. deleting parent object and inform the user if there are dependent objects
  154. with ondelete='RESTRICT' foreign keys. If this kind of checking is not used
  155. it will lead to nasty IntegrityErrors being raised.
  156. In the following example we delete given user if it doesn't have any
  157. foreign key restricted dependent objects::
  158. from sqlalchemy_utils import get_referencing_foreign_keys
  159. user = session.query(User).get(some_user_id)
  160. deps = list(
  161. dependent_objects(
  162. user,
  163. (
  164. fk for fk in get_referencing_foreign_keys(User)
  165. # On most databases RESTRICT is the default mode hence we
  166. # check for None values also
  167. if fk.ondelete == 'RESTRICT' or fk.ondelete is None
  168. )
  169. ).limit(5)
  170. )
  171. if deps:
  172. # Do something to inform the user
  173. pass
  174. else:
  175. session.delete(user)
  176. :param obj: SQLAlchemy declarative model object
  177. :param foreign_keys:
  178. A sequence of foreign keys to use for searching the dependent_objects
  179. for given object. By default this is None, indicating that all foreign
  180. keys referencing the object will be used.
  181. .. note::
  182. This function does not support exotic mappers that use multiple tables
  183. .. seealso:: :func:`get_referencing_foreign_keys`
  184. .. seealso:: :func:`merge_references`
  185. .. versionadded: 0.26.0
  186. """
  187. if foreign_keys is None:
  188. foreign_keys = get_referencing_foreign_keys(obj)
  189. session = object_session(obj)
  190. chain = QueryChain([])
  191. classes = _get_class_registry(obj.__class__)
  192. for table, keys in group_foreign_keys(foreign_keys):
  193. keys = list(keys)
  194. for class_ in classes.values():
  195. try:
  196. mapper = sa.inspect(class_)
  197. except NoInspectionAvailable:
  198. continue
  199. parent_mapper = mapper.inherits
  200. if (
  201. table in mapper.tables and
  202. not (parent_mapper and table in parent_mapper.tables)
  203. ):
  204. query = session.query(class_).filter(
  205. sa.or_(*_get_criteria(keys, class_, obj))
  206. )
  207. chain.queries.append(query)
  208. return chain
  209. def _get_criteria(keys, class_, obj):
  210. criteria = []
  211. visited_constraints = []
  212. for key in keys:
  213. if key.constraint in visited_constraints:
  214. continue
  215. visited_constraints.append(key.constraint)
  216. subcriteria = []
  217. for index, column in enumerate(key.constraint.columns):
  218. foreign_column = (
  219. key.constraint.elements[index].column
  220. )
  221. subcriteria.append(
  222. getattr(class_, get_column_key(class_, column)) ==
  223. getattr(
  224. obj,
  225. sa.inspect(type(obj))
  226. .get_property_by_column(
  227. foreign_column
  228. ).key
  229. )
  230. )
  231. criteria.append(sa.and_(*subcriteria))
  232. return criteria
  233. def non_indexed_foreign_keys(metadata, engine=None):
  234. """
  235. Finds all non indexed foreign keys from all tables of given MetaData.
  236. Very useful for optimizing postgresql database and finding out which
  237. foreign keys need indexes.
  238. :param metadata: MetaData object to inspect tables from
  239. """
  240. reflected_metadata = MetaData()
  241. if metadata.bind is None and engine is None:
  242. raise Exception(
  243. 'Either pass a metadata object with bind or '
  244. 'pass engine as a second parameter'
  245. )
  246. constraints = defaultdict(list)
  247. for table_name in metadata.tables.keys():
  248. table = Table(
  249. table_name,
  250. reflected_metadata,
  251. autoload=True,
  252. autoload_with=metadata.bind or engine
  253. )
  254. for constraint in table.constraints:
  255. if not isinstance(constraint, ForeignKeyConstraint):
  256. continue
  257. if not has_index(constraint):
  258. constraints[table.name].append(constraint)
  259. return dict(constraints)
  260. def get_fk_constraint_for_columns(table, *columns):
  261. for constraint in table.constraints:
  262. if list(constraint.columns.values()) == list(columns):
  263. return constraint