123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182 |
- from collections.abc import Iterable
- import six
- import sqlalchemy as sa
- from sqlalchemy.ext.hybrid import hybrid_property
- from sqlalchemy.orm import attributes, class_mapper, ColumnProperty
- from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
- from sqlalchemy.orm.session import _state_session
- from sqlalchemy.util import set_creation_order
- from .exceptions import ImproperlyConfigured
- from .functions import identity
- from .functions.orm import _get_class_registry
class GenericAttributeImpl(attributes.ScalarAttributeImpl):
    """Attribute implementation backing a generic relationship.

    Reading the attribute returns the value cached in the instance dict if
    one is present; otherwise the target is loaded lazily through the
    instance's session, using the discriminator (a class name) and the id
    column values stored on the owning instance. Assigning writes the
    target's class name and primary key into those same columns.

    ``self.parent_token`` is the owning relationship property; it exposes
    ``discriminator`` (the discriminator descriptor) and ``id`` (a list of
    id column descriptors).
    """

    def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
        """Return the related object for ``state``, loading it if needed.

        :param state: instance state of the owning object.
        :param dict_: the owning object's attribute dict.
        :param passive: accepted for interface compatibility; not consulted.
        :returns: the cached or loaded target, or ``None`` when the instance
            has no session, when the discriminator does not name a class in
            the class registry, when every id value is ``None``, or when no
            matching row exists.
        """
        if self.key in dict_:
            return dict_[self.key]

        # Retrieve the session bound to the state in order to perform
        # a lazy query for the attribute.
        session = _state_session(state)
        if session is None:
            # State is not bound to a session; we cannot proceed.
            return None

        # Find class for discriminator.
        # TODO: Perhaps optimize with some sort of lookup?
        discriminator = self.get_state_discriminator(state)
        target_class = _get_class_registry(state.class_).get(discriminator)
        if target_class is None:
            # Unknown discriminator; return nothing.
            return None

        ident = self.get_state_id(state)
        if all(value is None for value in ident):
            # Nothing is referenced: a fully NULL identity cannot match a
            # row, so skip the lookup entirely.
            return None

        # Session.get() (SQLAlchemy 1.4+) supersedes the legacy Query.get();
        # fall back to the latter only when the session does not provide it.
        session_get = getattr(session, 'get', None)
        if session_get is not None:
            return session_get(target_class, ident)
        return session.query(target_class).get(ident)

    def get_state_discriminator(self, state):
        """Return the discriminator value currently stored on ``state``.

        A hybrid-property discriminator is evaluated on the live object;
        otherwise the value of the mapped attribute is read from the state.
        """
        discriminator = self.parent_token.discriminator
        if isinstance(discriminator, hybrid_property):
            return getattr(state.obj(), discriminator.__name__)
        return state.attrs[discriminator.key].value

    def get_state_id(self, state):
        """Return the referenced identity as a tuple of id column values.

        Values are ordered like ``self.parent_token.id``.
        """
        return tuple(
            state.attrs[id_prop.key].value
            for id_prop in self.parent_token.id
        )

    def set(self, state, dict_, initiator,
            passive=attributes.PASSIVE_OFF,
            check_old=None,
            pop=False):
        """Assign a target object (or ``None``) to the generic relationship.

        NOTE: this code treats ``initiator`` as the value being assigned.
        It is stored under the relationship key and drives the
        discriminator/id updates. ``passive``, ``check_old`` and ``pop`` are
        accepted for interface compatibility and not consulted.

        Assigning ``None`` clears every id column and the discriminator.
        Otherwise the target's mapper primary key is written into the id
        columns (positionally) and its class ``__name__`` into the
        discriminator.
        """
        # Set us on the state.
        dict_[self.key] = initiator

        if initiator is None:
            # Nullify relationship args
            for id_prop in self.parent_token.id:
                dict_[id_prop.key] = None
            dict_[self.parent_token.discriminator.key] = None
        else:
            # Get the primary key of the initiator and ensure we
            # can support this assignment.
            class_ = type(initiator)
            mapper = class_mapper(class_)

            pk = mapper.identity_key_from_instance(initiator)[1]

            # Set the identifier and the discriminator.
            discriminator = six.text_type(class_.__name__)

            for index, id_prop in enumerate(self.parent_token.id):
                dict_[id_prop.key] = pk[index]
            dict_[self.parent_token.discriminator.key] = discriminator
class GenericRelationshipProperty(MapperProperty):
    """A generic form of the relationship property.

    Links each row of the parent model to a single instance of any other
    mapped model. The target is identified by a discriminator column, which
    stores the target's class name, together with one or more id columns,
    which store the target's primary key value(s).

    :param discriminator:
        Column, column name, or hybrid property recording which model is
        being referred to.
    :param id:
        Column, column name, or an iterable of them, holding the referenced
        model's primary key value(s).
    :param doc:
        Optional docstring attached to the instrumented attribute.
    """

    def __init__(self, discriminator, id, doc=None):
        super(GenericRelationshipProperty, self).__init__()
        # Raw configuration; resolved against the parent mapper in init().
        self._discriminator_col = discriminator
        self._id_cols = id
        self._id = None
        self._discriminator = None
        self.doc = doc
        set_creation_order(self)

    def _column_to_property(self, column):
        """Return the parent mapper attribute corresponding to ``column``.

        Hybrid properties are looked up by name among the parent's ORM
        descriptors. Plain columns are matched by name against the first
        column of each of the parent's column properties. Returns ``None``
        when nothing matches.
        """
        if isinstance(column, hybrid_property):
            return self.parent.all_orm_descriptors.get(column.__name__)
        for attr in self.parent.attrs.values():
            if isinstance(attr, ColumnProperty):
                if attr.columns[0].name == column.name:
                    return attr
        return None

    def init(self):
        """Resolve the configured columns into attributes of ``self.parent``.

        This overrides ``MapperProperty.init``. String names are looked up in
        the parent mapper's columns, and a single id column is normalized to
        a one-element list. Sets ``self.discriminator`` and ``self.id``.

        :raises ImproperlyConfigured: if the discriminator or any id column
            cannot be matched to a mapped attribute of the parent.
        """
        def convert_strings(column):
            if isinstance(column, six.string_types):
                return self.parent.columns[column]
            return column

        self._discriminator_col = convert_strings(self._discriminator_col)
        self._id_cols = convert_strings(self._id_cols)

        if isinstance(self._id_cols, Iterable):
            self._id_cols = list(map(convert_strings, self._id_cols))
        else:
            self._id_cols = [self._id_cols]

        self.discriminator = self._column_to_property(self._discriminator_col)
        if self.discriminator is None:
            raise ImproperlyConfigured(
                'Could not find discriminator descriptor.'
            )

        self.id = list(map(self._column_to_property, self._id_cols))
        # Fail at configuration time rather than with an AttributeError on
        # ``None.key`` the first time the relationship is read or assigned.
        if any(id_prop is None for id_prop in self.id):
            raise ImproperlyConfigured('Could not find id descriptor.')

    class Comparator(PropComparator):
        """SQL expression support for the generic relationship attribute."""

        def __init__(self, prop, parentmapper):
            # Delegate to PropComparator so that every attribute it defines
            # (prop/property, _parententity, _adapt_to_entity, ...) is set
            # the way the installed SQLAlchemy version expects.
            PropComparator.__init__(self, prop, parentmapper)

        def __eq__(self, other):
            """Return a criterion matching rows that reference ``other``.

            Compares the discriminator against ``other``'s class name and
            each id column, positionally, against ``identity(other)``.
            """
            discriminator = six.text_type(type(other).__name__)
            q = self.property._discriminator_col == discriminator
            other_id = identity(other)
            for index, column in enumerate(self.property._id_cols):
                q &= column == other_id[index]
            return q

        def __ne__(self, other):
            return ~(self == other)

        def is_type(self, other):
            """Return a criterion matching rows whose target is an ``other``.

            ``other`` is a mapped class. The discriminator is tested against
            the class names of ``other``'s mapper and of every mapper that
            inherits from it, at any depth.
            """
            mapper = sa.inspect(other)
            # self_and_descendants walks the whole inheritance tree; the
            # private _inheriting_mappers only lists direct subclasses, which
            # would miss e.g. grandchild classes.
            class_names = [
                six.text_type(submapper.class_.__name__)
                for submapper in mapper.self_and_descendants
            ]
            return self.property._discriminator_col.in_(class_names)

    def instrument_class(self, mapper):
        """Install the generic attribute on the mapped class.

        Uses ``GenericAttributeImpl`` for loading/assignment and this class's
        ``Comparator`` for SQL expressions; ``parent_token`` makes this
        property available to the attribute implementation.
        """
        attributes.register_attribute(
            mapper.class_,
            self.key,
            comparator=self.Comparator(self, mapper),
            parententity=mapper,
            doc=self.doc,
            impl_class=GenericAttributeImpl,
            parent_token=self
        )
def generic_relationship(*args, **kwargs):
    """Build a ``GenericRelationshipProperty``.

    All positional and keyword arguments are forwarded unchanged to the
    property's constructor; this is a convenience alias for declaring a
    generic relationship on a model.
    """
    prop = GenericRelationshipProperty(*args, **kwargs)
    return prop
|