generic.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. from collections.abc import Iterable
  2. import six
  3. import sqlalchemy as sa
  4. from sqlalchemy.ext.hybrid import hybrid_property
  5. from sqlalchemy.orm import attributes, class_mapper, ColumnProperty
  6. from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
  7. from sqlalchemy.orm.session import _state_session
  8. from sqlalchemy.util import set_creation_order
  9. from .exceptions import ImproperlyConfigured
  10. from .functions import identity
  11. from .functions.orm import _get_class_registry
  12. class GenericAttributeImpl(attributes.ScalarAttributeImpl):
  13. def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
  14. if self.key in dict_:
  15. return dict_[self.key]
  16. # Retrieve the session bound to the state in order to perform
  17. # a lazy query for the attribute.
  18. session = _state_session(state)
  19. if session is None:
  20. # State is not bound to a session; we cannot proceed.
  21. return None
  22. # Find class for discriminator.
  23. # TODO: Perhaps optimize with some sort of lookup?
  24. discriminator = self.get_state_discriminator(state)
  25. target_class = _get_class_registry(state.class_).get(discriminator)
  26. if target_class is None:
  27. # Unknown discriminator; return nothing.
  28. return None
  29. id = self.get_state_id(state)
  30. target = session.query(target_class).get(id)
  31. # Return found (or not found) target.
  32. return target
  33. def get_state_discriminator(self, state):
  34. discriminator = self.parent_token.discriminator
  35. if isinstance(discriminator, hybrid_property):
  36. return getattr(state.obj(), discriminator.__name__)
  37. else:
  38. return state.attrs[discriminator.key].value
  39. def get_state_id(self, state):
  40. # Lookup row with the discriminator and id.
  41. return tuple(state.attrs[id.key].value for id in self.parent_token.id)
  42. def set(self, state, dict_, initiator,
  43. passive=attributes.PASSIVE_OFF,
  44. check_old=None,
  45. pop=False):
  46. # Set us on the state.
  47. dict_[self.key] = initiator
  48. if initiator is None:
  49. # Nullify relationship args
  50. for id in self.parent_token.id:
  51. dict_[id.key] = None
  52. dict_[self.parent_token.discriminator.key] = None
  53. else:
  54. # Get the primary key of the initiator and ensure we
  55. # can support this assignment.
  56. class_ = type(initiator)
  57. mapper = class_mapper(class_)
  58. pk = mapper.identity_key_from_instance(initiator)[1]
  59. # Set the identifier and the discriminator.
  60. discriminator = six.text_type(class_.__name__)
  61. for index, id in enumerate(self.parent_token.id):
  62. dict_[id.key] = pk[index]
  63. dict_[self.parent_token.discriminator.key] = discriminator
  64. class GenericRelationshipProperty(MapperProperty):
  65. """A generic form of the relationship property.
  66. Creates a 1 to many relationship between the parent model
  67. and any other models using a descriminator (the table name).
  68. :param discriminator
  69. Field to discriminate which model we are referring to.
  70. :param id:
  71. Field to point to the model we are referring to.
  72. """
  73. def __init__(self, discriminator, id, doc=None):
  74. super(GenericRelationshipProperty, self).__init__()
  75. self._discriminator_col = discriminator
  76. self._id_cols = id
  77. self._id = None
  78. self._discriminator = None
  79. self.doc = doc
  80. set_creation_order(self)
  81. def _column_to_property(self, column):
  82. if isinstance(column, hybrid_property):
  83. attr_key = column.__name__
  84. for key, attr in self.parent.all_orm_descriptors.items():
  85. if key == attr_key:
  86. return attr
  87. else:
  88. for attr in self.parent.attrs.values():
  89. if isinstance(attr, ColumnProperty):
  90. if attr.columns[0].name == column.name:
  91. return attr
  92. def init(self):
  93. def convert_strings(column):
  94. if isinstance(column, six.string_types):
  95. return self.parent.columns[column]
  96. return column
  97. self._discriminator_col = convert_strings(self._discriminator_col)
  98. self._id_cols = convert_strings(self._id_cols)
  99. if isinstance(self._id_cols, Iterable):
  100. self._id_cols = list(map(convert_strings, self._id_cols))
  101. else:
  102. self._id_cols = [self._id_cols]
  103. self.discriminator = self._column_to_property(self._discriminator_col)
  104. if self.discriminator is None:
  105. raise ImproperlyConfigured(
  106. 'Could not find discriminator descriptor.'
  107. )
  108. self.id = list(map(self._column_to_property, self._id_cols))
  109. class Comparator(PropComparator):
  110. def __init__(self, prop, parentmapper):
  111. self.property = prop
  112. self._parententity = parentmapper
  113. def __eq__(self, other):
  114. discriminator = six.text_type(type(other).__name__)
  115. q = self.property._discriminator_col == discriminator
  116. other_id = identity(other)
  117. for index, id in enumerate(self.property._id_cols):
  118. q &= id == other_id[index]
  119. return q
  120. def __ne__(self, other):
  121. return ~(self == other)
  122. def is_type(self, other):
  123. mapper = sa.inspect(other)
  124. # Iterate through the weak sequence in order to get the actual
  125. # mappers
  126. class_names = [six.text_type(other.__name__)]
  127. class_names.extend([
  128. six.text_type(submapper.class_.__name__)
  129. for submapper in mapper._inheriting_mappers
  130. ])
  131. return self.property._discriminator_col.in_(class_names)
  132. def instrument_class(self, mapper):
  133. attributes.register_attribute(
  134. mapper.class_,
  135. self.key,
  136. comparator=self.Comparator(self, mapper),
  137. parententity=mapper,
  138. doc=self.doc,
  139. impl_class=GenericAttributeImpl,
  140. parent_token=self
  141. )
  142. def generic_relationship(*args, **kwargs):
  143. return GenericRelationshipProperty(*args, **kwargs)