provision.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. import collections
  2. import logging
  3. from . import config
  4. from . import engines
  5. from . import util
  6. from .. import exc
  7. from .. import inspect
  8. from ..engine import url as sa_url
  9. from ..sql import ddl
  10. from ..sql import schema
  11. from ..util import compat
  12. log = logging.getLogger(__name__)
  13. FOLLOWER_IDENT = None
  14. class register(object):
  15. def __init__(self):
  16. self.fns = {}
  17. @classmethod
  18. def init(cls, fn):
  19. return register().for_db("*")(fn)
  20. def for_db(self, *dbnames):
  21. def decorate(fn):
  22. for dbname in dbnames:
  23. self.fns[dbname] = fn
  24. return self
  25. return decorate
  26. def __call__(self, cfg, *arg):
  27. if isinstance(cfg, compat.string_types):
  28. url = sa_url.make_url(cfg)
  29. elif isinstance(cfg, sa_url.URL):
  30. url = cfg
  31. else:
  32. url = cfg.db.url
  33. backend = url.get_backend_name()
  34. if backend in self.fns:
  35. return self.fns[backend](cfg, *arg)
  36. else:
  37. return self.fns["*"](cfg, *arg)
  38. def create_follower_db(follower_ident):
  39. for cfg in _configs_for_db_operation():
  40. log.info("CREATE database %s, URI %r", follower_ident, cfg.db.url)
  41. create_db(cfg, cfg.db, follower_ident)
  42. def setup_config(db_url, options, file_config, follower_ident):
  43. # load the dialect, which should also have it set up its provision
  44. # hooks
  45. dialect = sa_url.make_url(db_url).get_dialect()
  46. dialect.load_provisioning()
  47. if follower_ident:
  48. db_url = follower_url_from_main(db_url, follower_ident)
  49. db_opts = {}
  50. update_db_opts(db_url, db_opts)
  51. db_opts["scope"] = "global"
  52. eng = engines.testing_engine(db_url, db_opts)
  53. post_configure_engine(db_url, eng, follower_ident)
  54. eng.connect().close()
  55. cfg = config.Config.register(eng, db_opts, options, file_config)
  56. # a symbolic name that tests can use if they need to disambiguate
  57. # names across databases
  58. if follower_ident:
  59. config.ident = follower_ident
  60. if follower_ident:
  61. configure_follower(cfg, follower_ident)
  62. return cfg
  63. def drop_follower_db(follower_ident):
  64. for cfg in _configs_for_db_operation():
  65. log.info("DROP database %s, URI %r", follower_ident, cfg.db.url)
  66. drop_db(cfg, cfg.db, follower_ident)
  67. def generate_db_urls(db_urls, extra_drivers):
  68. """Generate a set of URLs to test given configured URLs plus additional
  69. driver names.
  70. Given::
  71. --dburi postgresql://db1 \
  72. --dburi postgresql://db2 \
  73. --dburi postgresql://db2 \
  74. --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true
  75. Noting that the default postgresql driver is psycopg2, the output
  76. would be::
  77. postgresql+psycopg2://db1
  78. postgresql+asyncpg://db1
  79. postgresql+psycopg2://db2
  80. postgresql+psycopg2://db3
  81. That is, for the driver in a --dburi, we want to keep that and use that
  82. driver for each URL it's part of . For a driver that is only
  83. in --dbdrivers, we want to use it just once for one of the URLs.
  84. for a driver that is both coming from --dburi as well as --dbdrivers,
  85. we want to keep it in that dburi.
  86. Driver specific query options can be specified by added them to the
  87. driver name. For example, to enable the async fallback option for
  88. asyncpg::
  89. --dburi postgresql://db1 \
  90. --dbdriver=asyncpg?async_fallback=true
  91. """
  92. urls = set()
  93. backend_to_driver_we_already_have = collections.defaultdict(set)
  94. urls_plus_dialects = [
  95. (url_obj, url_obj.get_dialect())
  96. for url_obj in [sa_url.make_url(db_url) for db_url in db_urls]
  97. ]
  98. for url_obj, dialect in urls_plus_dialects:
  99. backend_to_driver_we_already_have[dialect.name].add(dialect.driver)
  100. backend_to_driver_we_need = {}
  101. for url_obj, dialect in urls_plus_dialects:
  102. backend = dialect.name
  103. dialect.load_provisioning()
  104. if backend not in backend_to_driver_we_need:
  105. backend_to_driver_we_need[backend] = extra_per_backend = set(
  106. extra_drivers
  107. ).difference(backend_to_driver_we_already_have[backend])
  108. else:
  109. extra_per_backend = backend_to_driver_we_need[backend]
  110. for driver_url in _generate_driver_urls(url_obj, extra_per_backend):
  111. if driver_url in urls:
  112. continue
  113. urls.add(driver_url)
  114. yield driver_url
  115. def _generate_driver_urls(url, extra_drivers):
  116. main_driver = url.get_driver_name()
  117. extra_drivers.discard(main_driver)
  118. url = generate_driver_url(url, main_driver, "")
  119. yield str(url)
  120. for drv in list(extra_drivers):
  121. if "?" in drv:
  122. driver_only, query_str = drv.split("?", 1)
  123. else:
  124. driver_only = drv
  125. query_str = None
  126. new_url = generate_driver_url(url, driver_only, query_str)
  127. if new_url:
  128. extra_drivers.remove(drv)
  129. yield str(new_url)
  130. @register.init
  131. def generate_driver_url(url, driver, query_str):
  132. backend = url.get_backend_name()
  133. new_url = url.set(
  134. drivername="%s+%s" % (backend, driver),
  135. )
  136. if query_str:
  137. new_url = new_url.update_query_string(query_str)
  138. try:
  139. new_url.get_dialect()
  140. except exc.NoSuchModuleError:
  141. return None
  142. else:
  143. return new_url
  144. def _configs_for_db_operation():
  145. hosts = set()
  146. for cfg in config.Config.all_configs():
  147. cfg.db.dispose()
  148. for cfg in config.Config.all_configs():
  149. url = cfg.db.url
  150. backend = url.get_backend_name()
  151. host_conf = (backend, url.username, url.host, url.database)
  152. if host_conf not in hosts:
  153. yield cfg
  154. hosts.add(host_conf)
  155. for cfg in config.Config.all_configs():
  156. cfg.db.dispose()
  157. @register.init
  158. def drop_all_schema_objects_pre_tables(cfg, eng):
  159. pass
  160. @register.init
  161. def drop_all_schema_objects_post_tables(cfg, eng):
  162. pass
  163. def drop_all_schema_objects(cfg, eng):
  164. drop_all_schema_objects_pre_tables(cfg, eng)
  165. inspector = inspect(eng)
  166. try:
  167. view_names = inspector.get_view_names()
  168. except NotImplementedError:
  169. pass
  170. else:
  171. with eng.begin() as conn:
  172. for vname in view_names:
  173. conn.execute(
  174. ddl._DropView(schema.Table(vname, schema.MetaData()))
  175. )
  176. if config.requirements.schemas.enabled_for_config(cfg):
  177. try:
  178. view_names = inspector.get_view_names(schema="test_schema")
  179. except NotImplementedError:
  180. pass
  181. else:
  182. with eng.begin() as conn:
  183. for vname in view_names:
  184. conn.execute(
  185. ddl._DropView(
  186. schema.Table(
  187. vname,
  188. schema.MetaData(),
  189. schema="test_schema",
  190. )
  191. )
  192. )
  193. util.drop_all_tables(eng, inspector)
  194. if config.requirements.schemas.enabled_for_config(cfg):
  195. util.drop_all_tables(eng, inspector, schema=cfg.test_schema)
  196. util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2)
  197. drop_all_schema_objects_post_tables(cfg, eng)
  198. if config.requirements.sequences.enabled_for_config(cfg):
  199. with eng.begin() as conn:
  200. for seq in inspector.get_sequence_names():
  201. conn.execute(ddl.DropSequence(schema.Sequence(seq)))
  202. if config.requirements.schemas.enabled_for_config(cfg):
  203. for schema_name in [cfg.test_schema, cfg.test_schema_2]:
  204. for seq in inspector.get_sequence_names(
  205. schema=schema_name
  206. ):
  207. conn.execute(
  208. ddl.DropSequence(
  209. schema.Sequence(seq, schema=schema_name)
  210. )
  211. )
  212. @register.init
  213. def create_db(cfg, eng, ident):
  214. """Dynamically create a database for testing.
  215. Used when a test run will employ multiple processes, e.g., when run
  216. via `tox` or `pytest -n4`.
  217. """
  218. raise NotImplementedError("no DB creation routine for cfg: %s" % eng.url)
  219. @register.init
  220. def drop_db(cfg, eng, ident):
  221. """Drop a database that we dynamically created for testing."""
  222. raise NotImplementedError("no DB drop routine for cfg: %s" % eng.url)
  223. @register.init
  224. def update_db_opts(db_url, db_opts):
  225. """Set database options (db_opts) for a test database that we created."""
  226. pass
  227. @register.init
  228. def post_configure_engine(url, engine, follower_ident):
  229. """Perform extra steps after configuring an engine for testing.
  230. (For the internal dialects, currently only used by sqlite, oracle)
  231. """
  232. pass
  233. @register.init
  234. def follower_url_from_main(url, ident):
  235. """Create a connection URL for a dynamically-created test database.
  236. :param url: the connection URL specified when the test run was invoked
  237. :param ident: the pytest-xdist "worker identifier" to be used as the
  238. database name
  239. """
  240. url = sa_url.make_url(url)
  241. return url.set(database=ident)
  242. @register.init
  243. def configure_follower(cfg, ident):
  244. """Create dialect-specific config settings for a follower database."""
  245. pass
  246. @register.init
  247. def run_reap_dbs(url, ident):
  248. """Remove databases that were created during the test process, after the
  249. process has ended.
  250. This is an optional step that is invoked for certain backends that do not
  251. reliably release locks on the database as long as a process is still in
  252. use. For the internal dialects, this is currently only necessary for
  253. mssql and oracle.
  254. """
  255. pass
  256. def reap_dbs(idents_file):
  257. log.info("Reaping databases...")
  258. urls = collections.defaultdict(set)
  259. idents = collections.defaultdict(set)
  260. dialects = {}
  261. with open(idents_file) as file_:
  262. for line in file_:
  263. line = line.strip()
  264. db_name, db_url = line.split(" ")
  265. url_obj = sa_url.make_url(db_url)
  266. if db_name not in dialects:
  267. dialects[db_name] = url_obj.get_dialect()
  268. dialects[db_name].load_provisioning()
  269. url_key = (url_obj.get_backend_name(), url_obj.host)
  270. urls[url_key].add(db_url)
  271. idents[url_key].add(db_name)
  272. for url_key in urls:
  273. url = list(urls[url_key])[0]
  274. ident = idents[url_key]
  275. run_reap_dbs(url, ident)
  276. @register.init
  277. def temp_table_keyword_args(cfg, eng):
  278. """Specify keyword arguments for creating a temporary Table.
  279. Dialect-specific implementations of this method will return the
  280. kwargs that are passed to the Table method when creating a temporary
  281. table for testing, e.g., in the define_temp_tables method of the
  282. ComponentReflectionTest class in suite/test_reflection.py
  283. """
  284. raise NotImplementedError(
  285. "no temp table keyword args routine for cfg: %s" % eng.url
  286. )
  287. @register.init
  288. def prepare_for_drop_tables(config, connection):
  289. pass
  290. @register.init
  291. def stop_test_class_outside_fixtures(config, db, testcls):
  292. pass
  293. @register.init
  294. def get_temp_table_name(cfg, eng, base_name):
  295. """Specify table name for creating a temporary Table.
  296. Dialect-specific implementations of this method will return the
  297. name to use when creating a temporary table for testing,
  298. e.g., in the define_temp_tables method of the
  299. ComponentReflectionTest class in suite/test_reflection.py
  300. Default to just the base name since that's what most dialects will
  301. use. The mssql dialect's implementation will need a "#" prepended.
  302. """
  303. return base_name
  304. @register.init
  305. def set_default_schema_on_connection(cfg, dbapi_connection, schema_name):
  306. raise NotImplementedError(
  307. "backend does not implement a schema name set function: %s"
  308. % (cfg.db.url,)
  309. )