test_rowcount.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. from sqlalchemy import bindparam
  2. from sqlalchemy import Column
  3. from sqlalchemy import Integer
  4. from sqlalchemy import select
  5. from sqlalchemy import String
  6. from sqlalchemy import Table
  7. from sqlalchemy import testing
  8. from sqlalchemy import text
  9. from sqlalchemy.testing import eq_
  10. from sqlalchemy.testing import fixtures
  11. class RowCountTest(fixtures.TablesTest):
  12. """test rowcount functionality"""
  13. __requires__ = ("sane_rowcount",)
  14. __backend__ = True
  15. @classmethod
  16. def define_tables(cls, metadata):
  17. Table(
  18. "employees",
  19. metadata,
  20. Column(
  21. "employee_id",
  22. Integer,
  23. autoincrement=False,
  24. primary_key=True,
  25. ),
  26. Column("name", String(50)),
  27. Column("department", String(1)),
  28. )
  29. @classmethod
  30. def insert_data(cls, connection):
  31. cls.data = data = [
  32. ("Angela", "A"),
  33. ("Andrew", "A"),
  34. ("Anand", "A"),
  35. ("Bob", "B"),
  36. ("Bobette", "B"),
  37. ("Buffy", "B"),
  38. ("Charlie", "C"),
  39. ("Cynthia", "C"),
  40. ("Chris", "C"),
  41. ]
  42. employees_table = cls.tables.employees
  43. connection.execute(
  44. employees_table.insert(),
  45. [
  46. {"employee_id": i, "name": n, "department": d}
  47. for i, (n, d) in enumerate(data)
  48. ],
  49. )
  50. def test_basic(self, connection):
  51. employees_table = self.tables.employees
  52. s = select(
  53. employees_table.c.name, employees_table.c.department
  54. ).order_by(employees_table.c.employee_id)
  55. rows = connection.execute(s).fetchall()
  56. eq_(rows, self.data)
  57. def test_update_rowcount1(self, connection):
  58. employees_table = self.tables.employees
  59. # WHERE matches 3, 3 rows changed
  60. department = employees_table.c.department
  61. r = connection.execute(
  62. employees_table.update().where(department == "C"),
  63. {"department": "Z"},
  64. )
  65. assert r.rowcount == 3
  66. def test_update_rowcount2(self, connection):
  67. employees_table = self.tables.employees
  68. # WHERE matches 3, 0 rows changed
  69. department = employees_table.c.department
  70. r = connection.execute(
  71. employees_table.update().where(department == "C"),
  72. {"department": "C"},
  73. )
  74. eq_(r.rowcount, 3)
  75. @testing.requires.sane_rowcount_w_returning
  76. def test_update_rowcount_return_defaults(self, connection):
  77. employees_table = self.tables.employees
  78. department = employees_table.c.department
  79. stmt = (
  80. employees_table.update()
  81. .where(department == "C")
  82. .values(name=employees_table.c.department + "Z")
  83. .return_defaults()
  84. )
  85. r = connection.execute(stmt)
  86. eq_(r.rowcount, 3)
  87. def test_raw_sql_rowcount(self, connection):
  88. # test issue #3622, make sure eager rowcount is called for text
  89. result = connection.exec_driver_sql(
  90. "update employees set department='Z' where department='C'"
  91. )
  92. eq_(result.rowcount, 3)
  93. def test_text_rowcount(self, connection):
  94. # test issue #3622, make sure eager rowcount is called for text
  95. result = connection.execute(
  96. text("update employees set department='Z' " "where department='C'")
  97. )
  98. eq_(result.rowcount, 3)
  99. def test_delete_rowcount(self, connection):
  100. employees_table = self.tables.employees
  101. # WHERE matches 3, 3 rows deleted
  102. department = employees_table.c.department
  103. r = connection.execute(
  104. employees_table.delete().where(department == "C")
  105. )
  106. eq_(r.rowcount, 3)
  107. @testing.requires.sane_multi_rowcount
  108. def test_multi_update_rowcount(self, connection):
  109. employees_table = self.tables.employees
  110. stmt = (
  111. employees_table.update()
  112. .where(employees_table.c.name == bindparam("emp_name"))
  113. .values(department="C")
  114. )
  115. r = connection.execute(
  116. stmt,
  117. [
  118. {"emp_name": "Bob"},
  119. {"emp_name": "Cynthia"},
  120. {"emp_name": "nonexistent"},
  121. ],
  122. )
  123. eq_(r.rowcount, 2)
  124. @testing.requires.sane_multi_rowcount
  125. def test_multi_delete_rowcount(self, connection):
  126. employees_table = self.tables.employees
  127. stmt = employees_table.delete().where(
  128. employees_table.c.name == bindparam("emp_name")
  129. )
  130. r = connection.execute(
  131. stmt,
  132. [
  133. {"emp_name": "Bob"},
  134. {"emp_name": "Cynthia"},
  135. {"emp_name": "nonexistent"},
  136. ],
  137. )
  138. eq_(r.rowcount, 2)