|
- from collections import defaultdict
- from itertools import chain
- import six
- import sqlalchemy as sa
- from sqlalchemy.orm import RelationshipProperty
- from sqlalchemy.orm.attributes import (
- set_committed_value, InstrumentedAttribute
- )
- from sqlalchemy.orm.session import object_session
- from sqlalchemy_utils.generic import GenericRelationshipProperty
- from sqlalchemy_utils.functions.orm import (
- list_local_values,
- list_local_remote_exprs,
- local_values,
- remote_column_names,
- remote_values,
- remote
- )
class PathException(Exception):
    """Raised when an attribute path is invalid or cannot be fetched."""
class DataException(Exception):
    """Raised when a step of a dotted path yields no related entities."""
class with_backrefs(object):
    """
    Wrap an attribute path so that, when it is passed to batch_fetch, the
    backref side of the relationship is populated as well. This is useful
    for certain many-to-many relationship scenarios.
    """
    def __init__(self, path):
        # The wrapped path: a dotted string, an InstrumentedAttribute or a
        # tuple of those (unwrapped again by fetcher_factory).
        self.path = path
class Path(object):
    """
    A class that represents an attribute path.

    A path couples a list of parent entities with the relationship property
    that should be batch fetched for them. On construction the property is
    validated and the matching fetcher is instantiated as ``self.fetcher``.
    """
    def __init__(self, entities, prop, populate_backrefs=False):
        self.validate_property(prop)
        self.property = prop
        self.entities = entities
        self.populate_backrefs = populate_backrefs
        # Created last: the fetcher reads path.property / path.entities.
        self.fetcher = self.fetcher_class(self)

    def validate_property(self, prop):
        """
        Ensure ``prop`` is a relationship (or generic relationship) property.

        :raises PathException: for any other kind of property
        """
        if not isinstance(
            prop, (RelationshipProperty, GenericRelationshipProperty)
        ):
            raise PathException(
                'Given attribute is not a relationship property.'
            )

    @property
    def session(self):
        """Session of the first parent entity."""
        return object_session(self.entities[0])

    @property
    def model(self):
        """Class mapped by the relationship's target mapper."""
        return self.property.mapper.class_

    @classmethod
    def parse(cls, entities, path, populate_backrefs=False):
        """
        Build a path object for ``entities`` from ``path``.

        :param entities: non-empty list of parent entities of the same type
        :param path:
            An InstrumentedAttribute, or a (possibly dotted) attribute name
            such as ``'teams.players'``. For dotted names every step but the
            last is resolved by reading the attribute on each entity; the
            final step becomes the path's relationship property.
        :param populate_backrefs: forwarded to the resulting path
        :raises DataException:
            if an intermediate step yields no related entities
        :raises PathException: if ``path`` is of an unsupported type
        """
        if isinstance(path, six.string_types):
            attrs = path.split('.')
            if len(attrs) > 1:
                related_entities = []
                for entity in entities:
                    value = getattr(entity, attrs[0])
                    if value is None:
                        # Unset scalar relationship: nothing to descend into.
                        continue
                    if hasattr(value, '__iter__'):
                        # Collection relationship: flatten its members.
                        related_entities.extend(value)
                    else:
                        # Scalar (e.g. many-to-one) relationship: the value
                        # is the related entity itself. Extending with it
                        # would raise TypeError.
                        # NOTE(review): a mapped class defining __iter__
                        # would be treated as a collection here.
                        related_entities.append(value)
                if not related_entities:
                    raise DataException('No related entities.')
                subpath = '.'.join(attrs[1:])
                return cls.parse(related_entities, subpath, populate_backrefs)
            attr = getattr(entities[0].__class__, attrs[0])
        elif isinstance(path, InstrumentedAttribute):
            attr = path
        else:
            raise PathException('Unknown path type.')
        return cls(entities, attr.property, populate_backrefs)

    @property
    def fetcher_class(self):
        """
        Fetcher class matching this path's relationship: generic,
        many-to-many (has a secondary table), many-to-one or one-to-many.
        """
        if isinstance(self.property, GenericRelationshipProperty):
            return GenericRelationshipFetcher
        if self.property.secondary is not None:
            return ManyToManyFetcher
        if self.property.direction.name == 'MANYTOONE':
            return ManyToOneFetcher
        return OneToManyFetcher
def batch_fetch(entities, *attr_paths):
    """
    Batch fetch given relationship attributes for a collection of entities.

    This function is in many cases a valid alternative to SQLAlchemy's
    subqueryload and performs a lot better.

    :param entities: list of entities of the same type
    :param attr_paths:
        Any number of paths. Each is an InstrumentedAttribute, a string
        naming the attribute (dotted strings traverse nested relationships),
        a tuple of such paths whose relationships target the same class
        (fetched with one combined query), or a ``with_backrefs`` wrapper.
        A path whose intermediate step has no related entities is skipped.

    Example::

        from sqlalchemy_utils import batch_fetch

        users = session.query(User).limit(20).all()

        batch_fetch(users, User.phonenumbers)

    Function also accepts strings as attribute names: ::

        users = session.query(User).limit(20).all()

        batch_fetch(users, 'phonenumbers')

    Multiple attributes may be provided: ::

        clubs = session.query(Club).limit(20).all()

        batch_fetch(
            clubs,
            'teams',
            'teams.players',
            'teams.players.user_groups'
        )

    You can also force populate backrefs: ::

        from sqlalchemy_utils import with_backrefs

        clubs = session.query(Club).limit(20).all()

        batch_fetch(
            clubs,
            'teams',
            'teams.players',
            with_backrefs('teams.players.user_groups')
        )
    """
    if not entities:
        return
    for attr_path in attr_paths:
        try:
            fetcher = fetcher_factory(entities, attr_path)
            fetcher.fetch()
            fetcher.populate()
        except DataException:
            # An intermediate step produced no related entities, so there is
            # nothing to load for this path; move on to the next one.
            continue
def get_fetcher(entities, path, populate_backrefs):
    """Parse ``path`` against ``entities`` and return the path's fetcher."""
    parsed_path = Path.parse(entities, path, populate_backrefs)
    return parsed_path.fetcher
def fetcher_factory(entities, path):
    """
    Create the fetcher for one ``batch_fetch`` path argument.

    A ``with_backrefs`` wrapper is unwrapped and turns on backref
    population. A tuple of paths yields a CompositeFetcher over one fetcher
    per element; any other path yields a single fetcher.
    """
    populate_backrefs = isinstance(path, with_backrefs)
    if populate_backrefs:
        path = path.path
    if not isinstance(path, tuple):
        return get_fetcher(entities, path, populate_backrefs)
    fetchers = [
        get_fetcher(entities, sub_path, populate_backrefs)
        for sub_path in path
    ]
    return CompositeFetcher(*fetchers)
class CompositeFetcher(object):
    """
    Combine several fetchers whose relationships target the same class so
    that their related entities are loaded with a single query.

    The query's condition ORs together the child fetchers' conditions. Each
    fetched entity is then appended to every child fetcher for which the
    entity has a truthy remote value.
    """
    def __init__(self, *fetchers):
        if not fetchers:
            # Without this guard fetchers[0] below raises a bare IndexError
            # (e.g. when batch_fetch is given an empty tuple path).
            raise PathException(
                'CompositeFetcher requires at least one fetcher.'
            )
        model = fetchers[0].path.model
        if not all(model == fetcher.path.model for fetcher in fetchers):
            raise PathException(
                'Each relationship property must have the same class when '
                'using CompositeFetcher.'
            )
        self.fetchers = fetchers

    @property
    def session(self):
        """Session taken from the first child fetcher's path."""
        return self.fetchers[0].path.session

    @property
    def model(self):
        """Related class shared by all child fetchers."""
        return self.fetchers[0].path.model

    @property
    def condition(self):
        """OR of every child fetcher's condition."""
        return sa.or_(
            *(fetcher.condition for fetcher in self.fetchers)
        )

    @property
    def related_entities(self):
        """Single query over the related model for all child fetchers."""
        return self.session.query(self.model).filter(self.condition)

    def fetch(self):
        """Distribute every fetched entity to the fetchers it belongs to."""
        for entity in self.related_entities:
            for fetcher in self.fetchers:
                if any(remote_values(fetcher.prop, entity)):
                    fetcher.append_entity(entity)

    def populate(self):
        """Let every child fetcher populate its parent entities."""
        for fetcher in self.fetchers:
            fetcher.populate()
class Fetcher(object):
    """
    Base fetcher for regular (non-generic) relationship properties.

    Related rows are collected into ``parent_dict``. ``populate`` looks them
    up with ``local_values(self.prop, parent)`` and assigns the result to
    each parent's relationship attribute. Subclasses provide
    ``append_entity`` or override ``fetch``.
    """
    def __init__(self, path):
        self.path = path
        self.prop = self.path.property
        # Collection relationships gather lists. Scalar ones default to None
        # so that parents without a match are set to None.
        default = list if self.prop.uselist else lambda: None
        self.parent_dict = defaultdict(default)

    @property
    def relation_query_base(self):
        """Base query over the related model (subclasses may add columns)."""
        return self.path.session.query(self.path.model)

    @property
    def related_entities(self):
        """Query for the related rows of all parent entities."""
        return self.relation_query_base.filter(self.condition)

    def populate_backrefs(self, related_entities):
        """
        Populates backrefs for given related entities.

        :param related_entities:
            Iterable of rows. The first element of each row is a related
            entity; the remaining elements are passed to ``Query.get`` on the
            parent class. This is the row shape produced by
            ManyToManyFetcher's query.
        """
        # Materialize once. ``populate`` passes a Query here, and the rows are
        # walked twice below; iterating the Query each time would re-run it.
        rows = list(related_entities)
        if not rows:
            return
        parent_class = self.path.entities[0].__class__
        backref_dict = defaultdict(list)
        for row in rows:
            backref_dict[local_values(self.prop, row[0])].append(
                self.path.session.query(parent_class).get(tuple(row[1:]))
            )
        for row in rows:
            set_committed_value(
                row[0],
                self.prop.back_populates,
                backref_dict[local_values(self.prop, row[0])]
            )

    def populate(self):
        """
        Populate batch fetched entities to parent objects.

        Each parent's relationship attribute is set via set_committed_value
        from ``parent_dict``. When the path asks for backrefs they are
        populated afterwards, using a fresh related-entities query.
        """
        for entity in self.path.entities:
            set_committed_value(
                entity,
                self.prop.key,
                self.parent_dict[local_values(self.prop, entity)]
            )
        if self.path.populate_backrefs:
            self.populate_backrefs(self.related_entities)

    @property
    def condition(self):
        """
        SQL criterion matching the related rows of all parent entities.

        A single remote column yields an IN clause. Multiple remote columns
        yield an OR of per-parent expressions.

        :raises PathException: if no remote column names can be obtained
        """
        names = list(remote_column_names(self.prop))
        if not names:
            raise PathException(
                'Could not obtain remote column names.'
            )
        if len(names) == 1:
            attr = getattr(remote(self.prop), names[0])
            # Give in_() a concrete list rather than a generator. A list is
            # accepted across SQLAlchemy versions.
            return attr.in_([
                v[0]
                for v in list_local_values(self.prop, self.path.entities)
            ])
        return sa.or_(
            *list_local_remote_exprs(self.prop, self.path.entities)
        )

    def fetch(self):
        """Run the related query and hand each result to append_entity."""
        for entity in self.related_entities:
            self.append_entity(entity)
class GenericRelationshipFetcher(object):
    """
    Fetcher for generic relationships (GenericRelationshipProperty).

    Parents are grouped by their discriminator value, and one query is
    issued per referenced class.
    """
    def __init__(self, path):
        self.path = path
        self.prop = self.path.property
        # Generic relationships hold a single object; parents without a
        # fetched match are set to None.
        self.parent_dict = defaultdict(lambda: None)

    def fetch(self):
        """Run the per-class queries and record every fetched entity."""
        for entity in self.related_entities:
            self.append_entity(entity)

    def append_entity(self, entity):
        self.parent_dict[remote_values(self.prop, entity)] = entity

    def populate(self):
        """
        Populate batch fetched entities to parent objects.
        """
        for entity in self.path.entities:
            set_committed_value(
                entity,
                self.prop.key,
                self.parent_dict[local_values(self.prop, entity)]
            )

    @property
    def related_entities(self):
        """
        Iterator over the related entities of all parents.

        Parents whose discriminator is None have no related object and are
        skipped. An empty iterator is returned when nothing needs loading.
        """
        id_dict = defaultdict(list)
        for entity in self.path.entities:
            discriminator = getattr(entity, self.prop._discriminator_col.key)
            if discriminator is None:
                # Generic relation unset on this parent. Looking up None in
                # the class registry would otherwise crash in _queries.
                continue
            for id_col in self.prop._id_cols:
                id_dict[discriminator].append(
                    getattr(entity, id_col.key)
                )
        if not id_dict:
            return chain()
        # Use an explicit parent instead of the loop variable leaking out of
        # the for-loop above.
        state = sa.inspect(self.path.entities[0])
        return chain(*self._queries(state, id_dict))

    def _queries(self, state, id_dict):
        """
        Yield one query per discriminator.

        Each query resolves the discriminator to a class through the
        declarative class registry of ``state.class_`` and filters on that
        class's ``id`` attribute.
        """
        # NOTE(review): ids from every _id_col are pooled and matched against
        # ``id`` alone, so composite generic keys appear unsupported; confirm.
        for discriminator, ids in six.iteritems(id_dict):
            class_ = state.class_._decl_class_registry.get(discriminator)
            yield self.path.session.query(
                class_
            ).filter(
                class_.id.in_(ids)
            )
class ManyToManyFetcher(Fetcher):
    """
    Fetcher for relationships that go through a secondary (association)
    table.

    The query returns rows of (related entity, *remote key columns), and
    the related entities are grouped by the tuple of those key columns.
    """
    @property
    def relation_query_base(self):
        key_columns = [
            getattr(remote(self.prop), column_name)
            for column_name in remote_column_names(self.prop)
        ]
        session = self.path.session
        query = session.query(self.path.model, *key_columns)
        return query.join(self.prop.secondary, self.prop.secondaryjoin)

    def fetch(self):
        for row in self.related_entities:
            parent_key = tuple(row[1:])
            self.parent_dict[parent_key].append(row[0])
class ManyToOneFetcher(Fetcher):
    """Fetcher where each parent receives a single related entity."""
    def append_entity(self, entity):
        key = remote_values(self.prop, entity)
        self.parent_dict[key] = entity
class OneToManyFetcher(Fetcher):
    """Fetcher where each parent collects a list of related entities."""
    def append_entity(self, entity):
        key = remote_values(self.prop, entity)
        self.parent_dict[key].append(entity)
|