persistence.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. #
  2. # ---- persistence global module ----
  3. #
  4. # This module manages the general SQLite database mechanisms (creation, access, etc.).
  5. #
  6. import logging
  7. from utility.app_logging import logger_name
  8. from sqlalchemy import create_engine
  9. from sqlalchemy.ext.declarative import declarative_base
  10. from sqlalchemy.orm.session import sessionmaker
  11. from sqlalchemy_utils import database_exists, create_database
  # Logger for the persistence layer, named under the application logger.
  logger = logging.getLogger(logger_name + ".PERSISTENCE")

  # SQLite database file and the SQLAlchemy URL built from it.
  __sqlite_db_path = "OpenISP.db"
  __sqlite_db_pattern = 'sqlite:///' + __sqlite_db_path

  # Declarative base handed out by get_db_base() before initialization.
  __db_Base = declarative_base()

  # Engine and session factory; both remain None until init() assigns them.
  __db_engine = None
  __db_session = None

  # Set to True by init() once the session factory has been created.
  __initialized = False
  def get_session():
      """Return the session factory created by ``init()``.

      This returns the factory object itself (the value ``init()`` stores from
      ``sessionmaker``), not an open session; call the returned object to get
      a session.

      Raises:
          RuntimeError: if the persistence engine has not been initialized.
      """
      if not __initialized:
          message = "can't get session, persistence engine is not initialized"
          logger.error(message)
          # Raising a plain string is itself a TypeError on Python 3, which
          # hid the real message; raise a proper exception instead.
          raise RuntimeError(message)
      return __db_session
  def get_Session_Instance():
      """Create and return a new session from the module's session factory.

      Raises:
          RuntimeError: if the persistence engine has not been initialized.
      """
      if not __initialized:
          message = "can't get session, persistence engine is not initialized"
          logger.error(message)
          # Raising a plain string is itself a TypeError on Python 3; raise a
          # real exception carrying the same message.
          raise RuntimeError(message)
      # Calling the factory builds a fresh session on every call.
      return __db_session()
  def get_db_base():
      """Return the declarative base that model classes are declared on.

      The base is only handed out before initialization; once ``init()`` has
      completed, asking for it is treated as an error.

      Raises:
          RuntimeError: if the persistence engine is already initialized.
      """
      if __initialized:
          message = ("can't get declarative Base, persistence engine is "
                     "already initialized")
          logger.error(message)
          # Raising a plain string is itself a TypeError on Python 3; raise a
          # real exception carrying the same message.
          raise RuntimeError(message)
      return __db_Base
  def is_init():
      """Report the module's initialization flag (set to True by ``init()``)."""
      initialized = __initialized
      return initialized
  def init():
      """Create the SQLite engine, build the model and publish the session factory.

      Steps, in order: create the engine, enable SQLite foreign-key
      enforcement on every new connection, declare the model, create the
      database and its tables when missing, check the database against the
      model, then store the session factory and mark the module initialized.

      Raises:
          RuntimeError: if the database does not match the declared model.
      """
      global __db_engine, __db_session, __initialized

      from sqlalchemy import event, text

      logger.debug("creating db engine....")
      __db_engine = create_engine(__sqlite_db_pattern)

      # from https://stackoverflow.com/questions/2614984/sqlite-sqlalchemy-how-to-enforce-foreign-keys
      def _fk_pragma_on_connect(dbapi_con, con_record):
          dbapi_con.execute('pragma foreign_keys=ON')

      # Registered before any connection is opened so each new DBAPI
      # connection gets the pragma.
      event.listen(__db_engine, 'connect', _fk_pragma_on_connect)

      # enforce sqlite foreign keys constraint option
      # Engine.execute() no longer exists in SQLAlchemy 2.0; an explicit
      # connection with text() works on both 1.x and 2.x.
      with __db_engine.connect() as connection:
          connection.execute(text('pragma foreign_keys=on'))
      logger.info("Database Engine Created.")

      logger.debug("building declarative model...")
      # NOTE(review): imported locally, presumably to avoid a circular import
      # with the model package -- confirm.
      import Model.model_manager
      Model.model_manager.declare_model()
      logger.info("declarative model built.")

      if not database_exists(__db_engine.url):
          logger.info("Database is not created, creating one...")
          create_database(__db_engine.url)
          logger.info("Done !")

      __db_Base.metadata.create_all(__db_engine)

      if not is_sane_database(__db_Base, __db_engine):
          logger.error("database is NOT compliant to the model")
          # Raising a plain string is itself a TypeError on Python 3.
          raise RuntimeError("database is not compliant to the model !")
      logger.info("database is compliant to the model")

      __db_session = sessionmaker(bind=__db_engine)
      # Flag flipped last so accessors only succeed after a complete setup.
      __initialized = True
      logger.info("Persistence Session started.")
  def is_sane_database(Base, engine):
      """Check that every model table and column exists in the database.

      Only names are compared; column types and other attributes are not
      verified yet.

      Args:
          Base: declarative base whose ``metadata`` describes the model.
          engine: SQLAlchemy engine for the database to inspect.

      Returns:
          True when each model table and each of its columns is present in
          the database, False as soon as one of them is missing.
      """
      from sqlalchemy import inspect

      inspector = inspect(engine)
      # A set gives direct membership tests instead of a scan per model table.
      db_table_names = set(inspector.get_table_names())

      for model_table in Base.metadata.sorted_tables:
          logger.debug("table '%s' in model, checking in database...",
                       model_table.name)
          if model_table.name not in db_table_names:
              logger.error("model table not found in db : %s",
                           model_table.name)
              return False
          logger.debug("model table found in db : %s", model_table.name)

          logger.debug(
              "now checking columns informations between model and database "
              "for table : %s", model_table.name)
          # Reflect the table's columns once rather than once per model column.
          db_column_names = {
              db_col["name"]
              for db_col in inspector.get_columns(model_table.name)
          }
          for model_col in model_table.columns:
              logger.debug("checking model column : %s", model_col.name)
              if model_col.name not in db_column_names:
                  logger.error("column '%s' in table '%s'not found in db ",
                               model_col.name, model_table.name)
                  return False
              # TODO type checking and other attributes. Model types render
              # like "VARCHAR(500)" while reflected types render like
              # "VARCHAR(length=500)", so they need normalizing before
              # they can be compared.
              logger.debug(
                  "table column : '%s' found and compliant to the model "
                  "table informations", model_col.name)
      return True
  def wipeout_database():
      """Delete the SQLite database file from disk.

      Only the file is removed; the engine and session factory held by this
      module are not touched here.
      """
      import os

      db_file = __sqlite_db_path
      # os.unlink is the same operation as os.remove.
      os.unlink(db_file)
  122. # TODO adapt to program, find what is an inspector.
  def drop_all_tables(engine, inspector, schema=None, include_names=None):
      """Drop tables (and stand-alone foreign-key constraints) from a database.

      Tables are processed in the reverse of the order reported by
      ``inspector.get_sorted_table_and_fkc_names``. Entries without a table
      name carry foreign-key constraints that are dropped on their own, which
      is skipped when the dialect does not support ALTER.

      Args:
          engine: SQLAlchemy engine to run the DDL on.
          inspector: object providing ``get_sorted_table_and_fkc_names``
              (e.g. the result of ``sqlalchemy.inspect(engine)``).
          schema: optional schema name passed to the inspector and the tables.
          include_names: optional iterable of table names; when given, only
              these tables (and constraints on them) are dropped.
      """
      from sqlalchemy import Column, Table, Integer, MetaData, \
          ForeignKeyConstraint
      from sqlalchemy.schema import DropTable, DropConstraint

      if include_names is not None:
          include_names = set(include_names)

      # engine.begin() commits on success. A bare engine.connect() relies on
      # autocommit, which SQLAlchemy 2.0 removed, so the drops would be
      # rolled back when the connection closes.
      with engine.begin() as conn:
          for tname, fkcs in reversed(
                  inspector.get_sorted_table_and_fkc_names(schema=schema)):
              if tname:
                  if include_names is not None and tname not in include_names:
                      continue
                  conn.execute(DropTable(
                      Table(tname, MetaData(), schema=schema)
                  ))
              elif fkcs:
                  if not engine.dialect.supports_alter:
                      continue
                  # Distinct name so the outer loop's ``tname`` is not shadowed.
                  for fk_table_name, fkc in fkcs:
                      if include_names is not None and \
                              fk_table_name not in include_names:
                          continue
                      # Placeholder table and columns: only the constraint
                      # name matters for emitting DROP CONSTRAINT.
                      tb = Table(
                          fk_table_name, MetaData(),
                          Column('x', Integer),
                          Column('y', Integer),
                          schema=schema
                      )
                      conn.execute(DropConstraint(
                          ForeignKeyConstraint(
                              [tb.c.x], [tb.c.y], name=fkc)
                      ))