schema.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. # testing/schema.py
  2. # Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. import sys
  8. from . import config
  9. from . import exclusions
  10. from .. import event
  11. from .. import schema
  12. from .. import types as sqltypes
  13. from ..util import OrderedDict
# Public wrappers exported by this module.
__all__ = ["Table", "Column"]
# Extra keyword arguments merged into every call of the Table() wrapper in
# this module.  They are applied with kw.update(), so an entry here overrides
# a caller-supplied keyword of the same name.
table_options = {}
  16. def Table(*args, **kw):
  17. """A schema.Table wrapper/hook for dialect-specific tweaks."""
  18. test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
  19. kw.update(table_options)
  20. if exclusions.against(config._current, "mysql"):
  21. if (
  22. "mysql_engine" not in kw
  23. and "mysql_type" not in kw
  24. and "autoload_with" not in kw
  25. ):
  26. if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
  27. kw["mysql_engine"] = "InnoDB"
  28. else:
  29. kw["mysql_engine"] = "MyISAM"
  30. elif exclusions.against(config._current, "mariadb"):
  31. if (
  32. "mariadb_engine" not in kw
  33. and "mariadb_type" not in kw
  34. and "autoload_with" not in kw
  35. ):
  36. if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
  37. kw["mariadb_engine"] = "InnoDB"
  38. else:
  39. kw["mariadb_engine"] = "MyISAM"
  40. # Apply some default cascading rules for self-referential foreign keys.
  41. # MySQL InnoDB has some issues around selecting self-refs too.
  42. if exclusions.against(config._current, "firebird"):
  43. table_name = args[0]
  44. unpack = config.db.dialect.identifier_preparer.unformat_identifiers
  45. # Only going after ForeignKeys in Columns. May need to
  46. # expand to ForeignKeyConstraint too.
  47. fks = [
  48. fk
  49. for col in args
  50. if isinstance(col, schema.Column)
  51. for fk in col.foreign_keys
  52. ]
  53. for fk in fks:
  54. # root around in raw spec
  55. ref = fk._colspec
  56. if isinstance(ref, schema.Column):
  57. name = ref.table.name
  58. else:
  59. # take just the table name: on FB there cannot be
  60. # a schema, so the first element is always the
  61. # table name, possibly followed by the field name
  62. name = unpack(ref)[0]
  63. if name == table_name:
  64. if fk.ondelete is None:
  65. fk.ondelete = "CASCADE"
  66. if fk.onupdate is None:
  67. fk.onupdate = "CASCADE"
  68. return schema.Table(*args, **kw)
  69. def Column(*args, **kw):
  70. """A schema.Column wrapper/hook for dialect-specific tweaks."""
  71. test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
  72. if not config.requirements.foreign_key_ddl.enabled_for_config(config):
  73. args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]
  74. col = schema.Column(*args, **kw)
  75. if test_opts.get("test_needs_autoincrement", False) and kw.get(
  76. "primary_key", False
  77. ):
  78. if col.default is None and col.server_default is None:
  79. col.autoincrement = True
  80. # allow any test suite to pick up on this
  81. col.info["test_needs_autoincrement"] = True
  82. # hardcoded rule for firebird, oracle; this should
  83. # be moved out
  84. if exclusions.against(config._current, "firebird", "oracle"):
  85. def add_seq(c, tbl):
  86. c._init_items(
  87. schema.Sequence(
  88. _truncate_name(
  89. config.db.dialect, tbl.name + "_" + c.name + "_seq"
  90. ),
  91. optional=True,
  92. )
  93. )
  94. event.listen(col, "after_parent_attach", add_seq, propagate=True)
  95. return col
  96. class eq_type_affinity(object):
  97. """Helper to compare types inside of datastructures based on affinity.
  98. E.g.::
  99. eq_(
  100. inspect(connection).get_columns("foo"),
  101. [
  102. {
  103. "name": "id",
  104. "type": testing.eq_type_affinity(sqltypes.INTEGER),
  105. "nullable": False,
  106. "default": None,
  107. "autoincrement": False,
  108. },
  109. {
  110. "name": "data",
  111. "type": testing.eq_type_affinity(sqltypes.NullType),
  112. "nullable": True,
  113. "default": None,
  114. "autoincrement": False,
  115. },
  116. ],
  117. )
  118. """
  119. def __init__(self, target):
  120. self.target = sqltypes.to_instance(target)
  121. def __eq__(self, other):
  122. return self.target._type_affinity is other._type_affinity
  123. def __ne__(self, other):
  124. return self.target._type_affinity is not other._type_affinity
  125. class eq_clause_element(object):
  126. """Helper to compare SQL structures based on compare()"""
  127. def __init__(self, target):
  128. self.target = target
  129. def __eq__(self, other):
  130. return self.target.compare(other)
  131. def __ne__(self, other):
  132. return not self.target.compare(other)
  133. def _truncate_name(dialect, name):
  134. if len(name) > dialect.max_identifier_length:
  135. return (
  136. name[0 : max(dialect.max_identifier_length - 6, 0)]
  137. + "_"
  138. + hex(hash(name) % 64)[2:]
  139. )
  140. else:
  141. return name
  142. def pep435_enum(name):
  143. # Implements PEP 435 in the minimal fashion needed by SQLAlchemy
  144. __members__ = OrderedDict()
  145. def __init__(self, name, value, alias=None):
  146. self.name = name
  147. self.value = value
  148. self.__members__[name] = self
  149. value_to_member[value] = self
  150. setattr(self.__class__, name, self)
  151. if alias:
  152. self.__members__[alias] = self
  153. setattr(self.__class__, alias, self)
  154. value_to_member = {}
  155. @classmethod
  156. def get(cls, value):
  157. return value_to_member[value]
  158. someenum = type(
  159. name,
  160. (object,),
  161. {"__members__": __members__, "__init__": __init__, "get": get},
  162. )
  163. # getframe() trick for pickling I don't understand courtesy
  164. # Python namedtuple()
  165. try:
  166. module = sys._getframe(1).f_globals.get("__name__", "__main__")
  167. except (AttributeError, ValueError):
  168. pass
  169. if module is not None:
  170. someenum.__module__ = module
  171. return someenum