proxy_dict.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import sqlalchemy as sa
  2. class ProxyDict(object):
  3. def __init__(self, parent, collection_name, mapping_attr):
  4. self.parent = parent
  5. self.collection_name = collection_name
  6. self.child_class = mapping_attr.class_
  7. self.key_name = mapping_attr.key
  8. self.cache = {}
  9. @property
  10. def collection(self):
  11. return getattr(self.parent, self.collection_name)
  12. def keys(self):
  13. descriptor = getattr(self.child_class, self.key_name)
  14. return [x[0] for x in self.collection.values(descriptor)]
  15. def __contains__(self, key):
  16. if key in self.cache:
  17. return self.cache[key] is not None
  18. return self.fetch(key) is not None
  19. def has_key(self, key):
  20. return self.__contains__(key)
  21. def fetch(self, key):
  22. session = sa.orm.object_session(self.parent)
  23. if session and sa.orm.util.has_identity(self.parent):
  24. obj = self.collection.filter_by(**{self.key_name: key}).first()
  25. self.cache[key] = obj
  26. return obj
  27. def create_new_instance(self, key):
  28. value = self.child_class(**{self.key_name: key})
  29. self.collection.append(value)
  30. self.cache[key] = value
  31. return value
  32. def __getitem__(self, key):
  33. if key in self.cache:
  34. if self.cache[key] is not None:
  35. return self.cache[key]
  36. else:
  37. value = self.fetch(key)
  38. if value:
  39. return value
  40. return self.create_new_instance(key)
  41. def __setitem__(self, key, value):
  42. try:
  43. existing = self[key]
  44. self.collection.remove(existing)
  45. except KeyError:
  46. pass
  47. self.collection.append(value)
  48. self.cache[key] = value
  49. def proxy_dict(parent, collection_name, mapping_attr):
  50. try:
  51. parent._proxy_dicts
  52. except AttributeError:
  53. parent._proxy_dicts = {}
  54. try:
  55. return parent._proxy_dicts[collection_name]
  56. except KeyError:
  57. parent._proxy_dicts[collection_name] = ProxyDict(
  58. parent,
  59. collection_name,
  60. mapping_attr
  61. )
  62. return parent._proxy_dicts[collection_name]
  63. def expire_proxy_dicts(target, context):
  64. if hasattr(target, '_proxy_dicts'):
  65. target._proxy_dicts = {}
  66. sa.event.listen(sa.orm.mapper, 'expire', expire_proxy_dicts)