path.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import sqlalchemy as sa
  2. from sqlalchemy.orm.attributes import InstrumentedAttribute
  3. from sqlalchemy.util.langhelpers import symbol
  4. from .utils import str_coercible
  5. @str_coercible
  6. class Path(object):
  7. def __init__(self, path, separator='.'):
  8. if isinstance(path, Path):
  9. self.path = path.path
  10. else:
  11. self.path = path
  12. self.separator = separator
  13. @property
  14. def parts(self):
  15. return self.path.split(self.separator)
  16. def __iter__(self):
  17. for part in self.parts:
  18. yield part
  19. def __len__(self):
  20. return len(self.parts)
  21. def __repr__(self):
  22. return "%s('%s')" % (self.__class__.__name__, self.path)
  23. def index(self, element):
  24. return self.parts.index(element)
  25. def __getitem__(self, slice):
  26. result = self.parts[slice]
  27. if isinstance(result, list):
  28. return self.__class__(
  29. self.separator.join(result),
  30. separator=self.separator
  31. )
  32. return result
  33. def __eq__(self, other):
  34. return self.path == other.path and self.separator == other.separator
  35. def __ne__(self, other):
  36. return not (self == other)
  37. def __unicode__(self):
  38. return self.path
  39. def get_attr(mixed, attr):
  40. if isinstance(mixed, InstrumentedAttribute):
  41. return getattr(
  42. mixed.property.mapper.class_,
  43. attr
  44. )
  45. else:
  46. return getattr(mixed, attr)
  47. @str_coercible
  48. class AttrPath(object):
  49. def __init__(self, class_, path):
  50. self.class_ = class_
  51. self.path = Path(path)
  52. self.parts = []
  53. last_attr = class_
  54. for value in self.path:
  55. last_attr = get_attr(last_attr, value)
  56. self.parts.append(last_attr)
  57. def __iter__(self):
  58. for part in self.parts:
  59. yield part
  60. def __invert__(self):
  61. def get_backref(part):
  62. prop = part.property
  63. backref = prop.backref or prop.back_populates
  64. if backref is None:
  65. raise Exception(
  66. "Invert failed because property '%s' of class "
  67. "%s has no backref." % (
  68. prop.key,
  69. prop.parent.class_.__name__
  70. )
  71. )
  72. if isinstance(backref, tuple):
  73. return backref[0]
  74. else:
  75. return backref
  76. if isinstance(self.parts[-1].property, sa.orm.ColumnProperty):
  77. class_ = self.parts[-1].class_
  78. else:
  79. class_ = self.parts[-1].mapper.class_
  80. return self.__class__(
  81. class_,
  82. '.'.join(map(get_backref, reversed(self.parts)))
  83. )
  84. def index(self, element):
  85. for index, el in enumerate(self.parts):
  86. if el is element:
  87. return index
  88. @property
  89. def direction(self):
  90. symbols = [part.property.direction for part in self.parts]
  91. if symbol('MANYTOMANY') in symbols:
  92. return symbol('MANYTOMANY')
  93. elif symbol('MANYTOONE') in symbols and symbol('ONETOMANY') in symbols:
  94. return symbol('MANYTOMANY')
  95. return symbols[0]
  96. @property
  97. def uselist(self):
  98. return any(part.property.uselist for part in self.parts)
  99. def __getitem__(self, slice):
  100. result = self.parts[slice]
  101. if isinstance(result, list) and result:
  102. if result[0] is self.parts[0]:
  103. class_ = self.class_
  104. else:
  105. class_ = result[0].parent.class_
  106. return self.__class__(
  107. class_,
  108. self.path[slice]
  109. )
  110. else:
  111. return result
  112. def __len__(self):
  113. return len(self.path)
  114. def __repr__(self):
  115. return "%s(%s, %r)" % (
  116. self.__class__.__name__,
  117. self.class_.__name__,
  118. self.path.path
  119. )
  120. def __eq__(self, other):
  121. return self.path == other.path and self.class_ == other.class_
  122. def __ne__(self, other):
  123. return not (self == other)
  124. def __unicode__(self):
  125. return str(self.path)