123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
- import sqlalchemy as sa
- from sqlalchemy.orm.attributes import InstrumentedAttribute
- from sqlalchemy.util.langhelpers import symbol
- from .utils import str_coercible
@str_coercible
class Path(object):
    """A string path split into parts on a separator (``'.'`` by default).

    Example: ``Path('a.b.c').parts`` is ``['a', 'b', 'c']`` and
    ``Path('a.b.c')[1:]`` is ``Path('b.c')``.
    """

    def __init__(self, path, separator='.'):
        """
        :param path: the path string, or another ``Path`` whose ``path``
            string is copied (its separator is NOT copied; *separator* is
            used instead).
        :param separator: string that delimits the parts of *path*.
        """
        if isinstance(path, Path):
            self.path = path.path
        else:
            self.path = path
        self.separator = separator

    @property
    def parts(self):
        """List of path segments, recomputed from ``path`` on each access."""
        return self.path.split(self.separator)

    def __iter__(self):
        """Iterate over the path segments."""
        return iter(self.parts)

    def __len__(self):
        """Number of path segments."""
        return len(self.parts)

    def __repr__(self):
        return "%s('%s')" % (self.__class__.__name__, self.path)

    def index(self, element):
        """Return the position of segment *element*.

        :raises ValueError: if *element* is not one of the segments.
        """
        return self.parts.index(element)

    def __getitem__(self, key):
        """Return one segment for an integer *key*, or a new ``Path`` (with
        the same separator) for a slice *key*."""
        result = self.parts[key]
        if isinstance(result, list):
            return self.__class__(
                self.separator.join(result),
                separator=self.separator
            )
        return result

    def __eq__(self, other):
        """Equal when both the path string and the separator are equal."""
        try:
            other_path, other_separator = other.path, other.separator
        except AttributeError:
            # Not path-like: let Python fall back to its default comparison
            # instead of leaking an AttributeError out of ``==``.
            return NotImplemented
        return self.path == other_path and self.separator == other_separator

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__, so spell it out.
        return not (self == other)

    def __unicode__(self):
        return self.path
def get_attr(mixed, attr):
    """Look up *attr* on *mixed*, stepping through mapped attributes.

    When *mixed* is an ``InstrumentedAttribute``, the lookup is done on
    ``mixed.property.mapper.class_`` (presumably the target class of a
    relationship attribute). For anything else it is a plain
    ``getattr(mixed, attr)``.
    """
    if not isinstance(mixed, InstrumentedAttribute):
        return getattr(mixed, attr)
    target_class = mixed.property.mapper.class_
    return getattr(target_class, attr)
@str_coercible
class AttrPath(object):
    """A dotted attribute path resolved step by step from a class.

    Each segment of *path* is looked up with ``get_attr`` on the result of
    the previous step, starting from *class_*. Every resolved attribute is
    stored, in order, in ``parts``.
    """

    def __init__(self, class_, path):
        """
        :param class_: class on which the first path segment is looked up.
        :param path: dotted path string, or a ``Path``.
        """
        self.class_ = class_
        self.path = Path(path)
        self.parts = []
        current = class_
        for segment in self.path:
            # Each segment resolves against the previous step's result.
            current = get_attr(current, segment)
            self.parts.append(current)

    def __iter__(self):
        """Iterate over the resolved attributes in path order."""
        return iter(self.parts)

    def __invert__(self):
        """Return the reversed path, built from each part's backref name.

        The parts are walked in reverse order and each is replaced by the
        name of its ``backref`` (or ``back_populates``). The new path is
        rooted at the class the last part points to.

        :raises Exception: if a part has neither ``backref`` nor
            ``back_populates`` set.
        """
        def get_backref(part):
            prop = part.property
            backref = prop.backref or prop.back_populates
            if backref is None:
                raise Exception(
                    "Invert failed because property '%s' of class "
                    "%s has no backref." % (
                        prop.key,
                        prop.parent.class_.__name__
                    )
                )
            # A backref may be given as a tuple; its first item is the name.
            if isinstance(backref, tuple):
                return backref[0]
            return backref

        last = self.parts[-1]
        # NOTE(review): get_backref() below also reads ``prop.backref`` for
        # a trailing column attribute; if ColumnProperty lacks that
        # attribute, this branch cannot succeed -- confirm intent.
        if isinstance(last.property, sa.orm.ColumnProperty):
            class_ = last.class_
        else:
            class_ = last.mapper.class_
        return self.__class__(
            class_,
            '.'.join(map(get_backref, reversed(self.parts)))
        )

    def index(self, element):
        """Return the position of *element* in ``parts``, or ``None``.

        Matching uses identity (``is``) rather than ``==``. Presumably this
        is because SQLAlchemy attributes overload ``==`` to build SQL
        expressions. Unlike ``Path.index``, a missing element yields
        ``None`` instead of raising.
        """
        for position, part in enumerate(self.parts):
            if part is element:
                return position
        return None

    @property
    def direction(self):
        """Return the overall relationship direction of the path.

        The result is MANYTOMANY if any part is many-to-many, or if the
        path mixes many-to-one and one-to-many steps. Otherwise it is the
        direction of the first part.
        """
        # Use SQLAlchemy's own constants instead of re-creating them with
        # symbol(). On SQLAlchemy 2.x they are RelationshipDirection enum
        # members, which a fresh symbol('MANYTOMANY') never equals. On 1.x
        # they are the very same interned symbols, so behaviour is unchanged.
        from sqlalchemy.orm.interfaces import (
            MANYTOMANY,
            MANYTOONE,
            ONETOMANY,
        )

        directions = [part.property.direction for part in self.parts]
        if MANYTOMANY in directions:
            return MANYTOMANY
        if MANYTOONE in directions and ONETOMANY in directions:
            return MANYTOMANY
        return directions[0]

    @property
    def uselist(self):
        """True if the property of any part has ``uselist`` set."""
        return any(part.property.uselist for part in self.parts)

    def __getitem__(self, key):
        """Index or slice the path.

        An integer *key* returns the resolved attribute at that position.
        A slice selecting at least one part returns a new ``AttrPath``. It
        is rooted at ``class_`` when the slice starts at the first part,
        and otherwise at the class owning the first selected attribute.
        An empty slice returns an empty list.
        """
        result = self.parts[key]
        if isinstance(result, list) and result:
            if result[0] is self.parts[0]:
                class_ = self.class_
            else:
                class_ = result[0].parent.class_
            return self.__class__(class_, self.path[key])
        return result

    def __len__(self):
        """Number of segments in the path."""
        return len(self.path)

    def __repr__(self):
        return "%s(%s, %r)" % (
            self.__class__.__name__,
            self.class_.__name__,
            self.path.path
        )

    def __eq__(self, other):
        """Equal when both the path and the root class are equal."""
        try:
            other_path, other_class = other.path, other.class_
        except AttributeError:
            # Not AttrPath-like: let Python fall back to its default
            # comparison instead of leaking an AttributeError out of ``==``.
            return NotImplemented
        return self.path == other_path and self.class_ == other_class

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__, so spell it out.
        return not (self == other)

    def __unicode__(self):
        return str(self.path)
|