ltree.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. from __future__ import absolute_import
  2. import re
  3. import six
  4. from ..utils import str_coercible
  5. path_matcher = re.compile(r'^[A-Za-z0-9_]+(\.[A-Za-z0-9_]+)*$')
  6. @str_coercible
  7. class Ltree(object):
  8. """
  9. Ltree class wraps a valid string label path. It provides various
  10. convenience properties and methods.
  11. ::
  12. from sqlalchemy_utils import Ltree
  13. Ltree('1.2.3').path # '1.2.3'
  14. Ltree always validates the given path.
  15. ::
  16. Ltree(None) # raises TypeError
  17. Ltree('..') # raises ValueError
  18. Validator is also available as class method.
  19. ::
  20. Ltree.validate('1.2.3')
  21. Ltree.validate(None) # raises TypeError
  22. Ltree supports equality operators.
  23. ::
  24. Ltree('Countries.Finland') == Ltree('Countries.Finland')
  25. Ltree('Countries.Germany') != Ltree('Countries.Finland')
  26. Ltree objects are hashable.
  27. ::
  28. assert hash(Ltree('Finland')) == hash('Finland')
  29. Ltree objects have length.
  30. ::
  31. assert len(Ltree('1.2')) == 2
  32. assert len(Ltree('some.one.some.where')) # 4
  33. You can easily find subpath indexes.
  34. ::
  35. assert Ltree('1.2.3').index('2.3') == 1
  36. assert Ltree('1.2.3.4.5').index('3.4') == 2
  37. Ltree objects can be sliced.
  38. ::
  39. assert Ltree('1.2.3')[0:2] == Ltree('1.2')
  40. assert Ltree('1.2.3')[1:] == Ltree('2.3')
  41. Finding longest common ancestor.
  42. ::
  43. assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2'
  44. assert Ltree('1.2.3.4.5').lca('1.2', '1.2.3') == '1'
  45. Ltree objects can be concatenated.
  46. ::
  47. assert Ltree('1.2') + Ltree('1.2') == Ltree('1.2.1.2')
  48. """
  49. def __init__(self, path_or_ltree):
  50. if isinstance(path_or_ltree, Ltree):
  51. self.path = path_or_ltree.path
  52. elif isinstance(path_or_ltree, six.string_types):
  53. self.validate(path_or_ltree)
  54. self.path = path_or_ltree
  55. else:
  56. raise TypeError(
  57. "Ltree() argument must be a string or an Ltree, not '{0}'"
  58. .format(
  59. type(path_or_ltree).__name__
  60. )
  61. )
  62. @classmethod
  63. def validate(cls, path):
  64. if path_matcher.match(path) is None:
  65. raise ValueError(
  66. "'{0}' is not a valid ltree path.".format(path)
  67. )
  68. def __len__(self):
  69. return len(self.path.split('.'))
  70. def index(self, other):
  71. subpath = Ltree(other).path.split('.')
  72. parts = self.path.split('.')
  73. for index, _ in enumerate(parts):
  74. if parts[index:len(subpath) + index] == subpath:
  75. return index
  76. raise ValueError('subpath not found')
  77. def descendant_of(self, other):
  78. """
  79. is left argument a descendant of right (or equal)?
  80. ::
  81. assert Ltree('1.2.3.4.5').descendant_of('1.2.3')
  82. """
  83. subpath = self[:len(Ltree(other))]
  84. return subpath == other
  85. def ancestor_of(self, other):
  86. """
  87. is left argument an ancestor of right (or equal)?
  88. ::
  89. assert Ltree('1.2.3').ancestor_of('1.2.3.4.5')
  90. """
  91. subpath = Ltree(other)[:len(self)]
  92. return subpath == self
  93. def __getitem__(self, key):
  94. if isinstance(key, int):
  95. return Ltree(self.path.split('.')[key])
  96. elif isinstance(key, slice):
  97. return Ltree('.'.join(self.path.split('.')[key]))
  98. raise TypeError(
  99. 'Ltree indices must be integers, not {0}'.format(
  100. key.__class__.__name__
  101. )
  102. )
  103. def lca(self, *others):
  104. """
  105. Lowest common ancestor, i.e., longest common prefix of paths
  106. ::
  107. assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2'
  108. """
  109. other_parts = [Ltree(other).path.split('.') for other in others]
  110. parts = self.path.split('.')
  111. for index, element in enumerate(parts):
  112. if any((
  113. other[index] != element or
  114. len(other) <= index + 1 or
  115. len(parts) == index + 1
  116. for other in other_parts
  117. )):
  118. if index == 0:
  119. return None
  120. return Ltree('.'.join(parts[0:index]))
  121. def __add__(self, other):
  122. return Ltree(self.path + '.' + Ltree(other).path)
  123. def __radd__(self, other):
  124. return Ltree(other) + self
  125. def __eq__(self, other):
  126. if isinstance(other, Ltree):
  127. return self.path == other.path
  128. elif isinstance(other, six.string_types):
  129. return self.path == other
  130. else:
  131. return NotImplemented
  132. def __hash__(self):
  133. return hash(self.path)
  134. def __ne__(self, other):
  135. return not (self == other)
  136. def __repr__(self):
  137. return '%s(%r)' % (self.__class__.__name__, self.path)
  138. def __unicode__(self):
  139. return self.path
  140. def __contains__(self, label):
  141. return label in self.path.split('.')