123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- import sqlalchemy as sa
class ProxyDict(object):
    """Dict-like access to the children in one collection of ``parent``.

    Children are addressed by the value of a single key attribute
    (``mapping_attr.key``) of the child class (``mapping_attr.class_``).
    Reading a key that has no child creates one and appends it to the
    collection.  Results are memoized in ``self.cache``.

    The collection attribute must support ``filter_by(...).first()``,
    ``values(...)``, ``append(...)`` and ``remove(...)``.
    NOTE(review): presumably a SQLAlchemy ``lazy='dynamic'`` relationship
    -- confirm against callers.

    :param parent: object owning the collection attribute.
    :param collection_name: name of the collection attribute on ``parent``.
    :param mapping_attr: attribute object exposing ``class_`` (the child
        class) and ``key`` (name of the child's key attribute).
    """

    def __init__(self, parent, collection_name, mapping_attr):
        self.parent = parent
        self.collection_name = collection_name
        self.child_class = mapping_attr.class_
        self.key_name = mapping_attr.key
        # Maps key -> child object, or None when a lookup found nothing.
        self.cache = {}

    @property
    def collection(self):
        """Return the current value of the proxied collection attribute."""
        return getattr(self.parent, self.collection_name)

    def keys(self):
        """Return a list with the key attribute value of every child."""
        descriptor = getattr(self.child_class, self.key_name)
        return [row[0] for row in self.collection.values(descriptor)]

    def _lookup(self, key):
        """Return the cached or fetched child for ``key``, else ``None``.

        Unlike ``__getitem__`` this never creates a new child.
        """
        if key in self.cache:
            return self.cache[key]
        return self.fetch(key)

    def __contains__(self, key):
        """Return True when a child for ``key`` is cached or can be fetched."""
        return self._lookup(key) is not None

    def has_key(self, key):
        """Alias of ``key in self``."""
        return self.__contains__(key)

    def fetch(self, key):
        """Query the collection for the child whose key attribute is ``key``.

        The query only runs when ``parent`` is attached to a session and
        has an identity; its result (possibly ``None``) is then cached.
        Otherwise nothing is cached and ``None`` is returned.
        """
        session = sa.orm.object_session(self.parent)
        if session and sa.orm.util.has_identity(self.parent):
            obj = self.collection.filter_by(**{self.key_name: key}).first()
            self.cache[key] = obj
            return obj
        return None

    def create_new_instance(self, key):
        """Create a child for ``key``, append it to the collection, cache it."""
        value = self.child_class(**{self.key_name: key})
        self.collection.append(value)
        self.cache[key] = value
        return value

    def __getitem__(self, key):
        """Return the child for ``key``, creating it when none exists."""
        value = self._lookup(key)
        # Compare against None explicitly: a child object that happens to
        # evaluate as falsy must not be treated as missing and duplicated.
        if value is not None:
            return value
        return self.create_new_instance(key)

    def __setitem__(self, key, value):
        """Make ``value`` the child for ``key``, replacing any existing one."""
        # Use the non-creating lookup.  Going through ``self[key]`` would
        # instantiate a throwaway child and append it to the collection
        # just to remove it again.
        existing = self._lookup(key)
        if existing is not None:
            try:
                self.collection.remove(existing)
            except KeyError:
                # Best effort, as before: tolerate a collection that no
                # longer holds the previously known child.
                pass
        self.collection.append(value)
        self.cache[key] = value
def proxy_dict(parent, collection_name, mapping_attr):
    """Return the ``ProxyDict`` for ``collection_name`` on ``parent``.

    Instances are memoized per collection name in ``parent._proxy_dicts``,
    which is created on first use, so repeated calls return the same
    object.

    :param parent: object owning the collection; receives the
        ``_proxy_dicts`` attribute.
    :param collection_name: name of the collection attribute on ``parent``.
    :param mapping_attr: passed through to ``ProxyDict`` when a new
        instance has to be built.
    """
    if not hasattr(parent, '_proxy_dicts'):
        parent._proxy_dicts = {}

    registry = parent._proxy_dicts
    if collection_name not in registry:
        registry[collection_name] = ProxyDict(
            parent, collection_name, mapping_attr
        )
    return registry[collection_name]
def expire_proxy_dicts(target, context):
    """Reset ``target._proxy_dicts`` to an empty dict if it exists.

    Objects that never had the attribute are left untouched.

    :param target: object whose memoized proxy dicts are discarded.
    :param context: unused; accepted to match the listener signature.
    """
    try:
        target._proxy_dicts
    except AttributeError:
        return
    target._proxy_dicts = {}
# Discard memoized proxy dicts whenever a mapped instance is expired.
# The ``Mapper`` class is used as the listen target instead of the legacy
# ``mapper`` function: both mean "all mapped classes", but only the class
# is still available in SQLAlchemy 2.0.
sa.event.listen(sa.orm.Mapper, 'expire', expire_proxy_dicts)
|