test_cte.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. from .. import fixtures
  2. from ..assertions import eq_
  3. from ..schema import Column
  4. from ..schema import Table
  5. from ... import ForeignKey
  6. from ... import Integer
  7. from ... import select
  8. from ... import String
  9. from ... import testing
  class CTETest(fixtures.TablesTest):
      """Round-trip tests for SELECT / INSERT / UPDATE / DELETE using CTEs.

      The ``some_table`` fixture rows form a small tree through
      ``parent_id``::

          d1
          +-- d2
          +-- d3
              +-- d4
              +-- d5

      Every test starts from a CTE named ``some_cte`` over the d2, d3 and
      d4 rows (built by ``_d2_to_d4_cte``).
      """

      __backend__ = True
      # Class-wide requirement: "ctes".
      __requires__ = ("ctes",)

      # NOTE(review): presumably these make the fixture re-insert and clear
      # rows around every test, which matters because several tests write
      # into some_other_table -- confirm against fixtures.TablesTest.
      run_inserts = "each"
      run_deletes = "each"

      @classmethod
      def define_tables(cls, metadata):
          # Self-referential: parent_id references some_table.id.
          Table(
              "some_table",
              metadata,
              Column("id", Integer, primary_key=True),
              Column("data", String(50)),
              Column("parent_id", ForeignKey("some_table.id")),
          )
          # Same columns, but parent_id is a plain Integer with no
          # ForeignKey; tests copy rows in here and then modify them.
          Table(
              "some_other_table",
              metadata,
              Column("id", Integer, primary_key=True),
              Column("data", String(50)),
              Column("parent_id", Integer),
          )

      @classmethod
      def insert_data(cls, connection):
          # (id, data, parent_id) -- the tree drawn in the class docstring.
          rows = [
              (1, "d1", None),
              (2, "d2", 1),
              (3, "d3", 1),
              (4, "d4", 3),
              (5, "d5", 3),
          ]
          connection.execute(
              cls.tables.some_table.insert(),
              [
                  {"id": pk, "data": data, "parent_id": parent}
                  for pk, data, parent in rows
              ],
          )

      def _d2_to_d4_cte(self, recursive=False):
          """Return the ``some_cte`` CTE over the some_table rows d2, d3, d4.

          :param recursive: passed through to ``.cte()``; the recursive test
              sets it so the CTE can be extended with ``union_all``.
          """
          some_table = self.tables.some_table
          return (
              select(some_table)
              .where(some_table.c.data.in_(["d2", "d3", "d4"]))
              .cte("some_cte", recursive=recursive)
          )

      def _insert_into_other_table(self, connection, source):
          """INSERT INTO some_other_table every row SELECTed from ``source``."""
          connection.execute(
              self.tables.some_other_table.insert().from_select(
                  ["id", "data", "parent_id"], select(source)
              )
          )

      def _other_table_rows(self, connection):
          """Return all rows of some_other_table, ordered by id."""
          some_other_table = self.tables.some_other_table
          return connection.execute(
              select(some_other_table).order_by(some_other_table.c.id)
          ).fetchall()

      def test_select_nonrecursive_round_trip(self, connection):
          """Filter a non-recursive CTE again in the outer SELECT.

          The CTE holds d2-d4 and the outer filter keeps d4/d5, so only d4
          passes both.
          """
          cte = self._d2_to_d4_cte()
          stmt = select(cte.c.data).where(cte.c.data.in_(["d4", "d5"]))
          eq_(connection.execute(stmt).fetchall(), [("d4",)])

      def test_select_recursive_round_trip(self, connection):
          """Walk parent_id links upward from d2, d3 and d4.

          UNION ALL keeps duplicates, so an ancestor appears once per path
          that reaches it: after dropping d2, d4 once, d3 twice, d1 three
          times.
          """
          some_table = self.tables.some_table
          anchor = self._d2_to_d4_cte(recursive=True)
          prior = anchor.alias("c1")
          st1 = some_table.alias()
          # note that SQL Server requires this to be UNION ALL,
          # can't be UNION
          walk = anchor.union_all(
              select(st1).where(st1.c.id == prior.c.parent_id)
          )
          stmt = (
              select(walk.c.data)
              .where(walk.c.data != "d2")
              .order_by(walk.c.data.desc())
          )
          eq_(
              connection.execute(stmt).fetchall(),
              [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
          )

      def test_insert_from_select_round_trip(self, connection):
          """INSERT ... SELECT whose source is the CTE copies only d2-d4."""
          self._insert_into_other_table(connection, self._d2_to_d4_cte())
          eq_(
              self._other_table_rows(connection),
              [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
          )

      @testing.requires.ctes_with_update_delete
      @testing.requires.update_from
      def test_update_from_round_trip(self, connection):
          """UPDATE some_other_table rows whose data matches a CTE row.

          The WHERE clause compares against a CTE column (a multi-table
          UPDATE, gated on ``update_from``). Rows d2-d4 get parent_id=5;
          d1 and d5 keep their copied values.
          """
          some_other_table = self.tables.some_other_table
          self._insert_into_other_table(connection, self.tables.some_table)
          cte = self._d2_to_d4_cte()
          connection.execute(
              some_other_table.update()
              .values(parent_id=5)
              .where(some_other_table.c.data == cte.c.data)
          )
          eq_(
              self._other_table_rows(connection),
              [
                  (1, "d1", None),
                  (2, "d2", 5),
                  (3, "d3", 5),
                  (4, "d4", 5),
                  (5, "d5", 3),
              ],
          )

      @testing.requires.ctes_with_update_delete
      @testing.requires.delete_from
      def test_delete_from_round_trip(self, connection):
          """DELETE some_other_table rows whose data matches a CTE row."""
          some_other_table = self.tables.some_other_table
          self._insert_into_other_table(connection, self.tables.some_table)
          cte = self._d2_to_d4_cte()
          connection.execute(
              some_other_table.delete().where(
                  some_other_table.c.data == cte.c.data
              )
          )
          eq_(
              self._other_table_rows(connection),
              [(1, "d1", None), (5, "d5", 3)],
          )

      @testing.requires.ctes_with_update_delete
      def test_delete_scalar_subq_round_trip(self, connection):
          """DELETE using a scalar subquery over the CTE.

          The subquery is correlated on some_other_table.id, so a row is
          deleted when its data equals the CTE's data for the same id,
          which holds for d2-d4.
          """
          some_other_table = self.tables.some_other_table
          self._insert_into_other_table(connection, self.tables.some_table)
          cte = self._d2_to_d4_cte()
          matching_data = (
              select(cte.c.data)
              .where(cte.c.id == some_other_table.c.id)
              .scalar_subquery()
          )
          connection.execute(
              some_other_table.delete().where(
                  some_other_table.c.data == matching_data
              )
          )
          eq_(
              self._other_table_rows(connection),
              [(1, "d1", None), (5, "d5", 3)],
          )