asserts.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
"""
The functions in this module can be used for testing the constraints of your
models. Each assert function runs SQL UPDATEs that check for the existence of
the given constraint. Consider the following model::

    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.String(200), nullable=True)
        email = sa.Column(sa.String(255), nullable=False)

    user = User(name='John Doe', email='john@example.com')
    session.add(user)
    session.commit()

We can easily test the constraints using the assert_* functions::

    from sqlalchemy_utils import (
        assert_nullable,
        assert_non_nullable,
        assert_max_length
    )

    assert_nullable(user, 'name')
    assert_non_nullable(user, 'email')
    assert_max_length(user, 'name', 200)

    # raises AssertionError because the max length of email is 255
    assert_max_length(user, 'email', 300)
"""
  25. from decimal import Decimal
  26. import sqlalchemy as sa
  27. from sqlalchemy.dialects.postgresql import ARRAY
  28. from sqlalchemy.exc import DataError, IntegrityError
  29. def _update_field(obj, field, value):
  30. session = sa.orm.object_session(obj)
  31. column = sa.inspect(obj.__class__).columns[field]
  32. query = column.table.update().values(**{column.key: value})
  33. session.execute(query)
  34. session.flush()
  35. def _expect_successful_update(obj, field, value, reraise_exc):
  36. try:
  37. _update_field(obj, field, value)
  38. except (reraise_exc) as e:
  39. session = sa.orm.object_session(obj)
  40. session.rollback()
  41. assert False, str(e)
  42. def _expect_failing_update(obj, field, value, expected_exc):
  43. try:
  44. _update_field(obj, field, value)
  45. except expected_exc:
  46. pass
  47. else:
  48. raise AssertionError('Expected update to raise %s' % expected_exc)
  49. finally:
  50. session = sa.orm.object_session(obj)
  51. session.rollback()
  52. def _repeated_value(type_):
  53. if isinstance(type_, ARRAY):
  54. if isinstance(type_.item_type, sa.Integer):
  55. return [0]
  56. elif isinstance(type_.item_type, sa.String):
  57. return [u'a']
  58. elif isinstance(type_.item_type, sa.Numeric):
  59. return [Decimal('0')]
  60. else:
  61. raise TypeError('Unknown array item type')
  62. else:
  63. return u'a'
  64. def _expected_exception(type_):
  65. if isinstance(type_, ARRAY):
  66. return IntegrityError
  67. else:
  68. return DataError
  69. def assert_nullable(obj, column):
  70. """
  71. Assert that given column is nullable. This is checked by running an SQL
  72. update that assigns given column as None.
  73. :param obj: SQLAlchemy declarative model object
  74. :param column: Name of the column
  75. """
  76. _expect_successful_update(obj, column, None, IntegrityError)
  77. def assert_non_nullable(obj, column):
  78. """
  79. Assert that given column is not nullable. This is checked by running an SQL
  80. update that assigns given column as None.
  81. :param obj: SQLAlchemy declarative model object
  82. :param column: Name of the column
  83. """
  84. _expect_failing_update(obj, column, None, IntegrityError)
  85. def assert_max_length(obj, column, max_length):
  86. """
  87. Assert that the given column is of given max length. This function supports
  88. string typed columns as well as PostgreSQL array typed columns.
  89. In the following example we add a check constraint that user can have a
  90. maximum of 5 favorite colors and then test this.::
  91. class User(Base):
  92. __tablename__ = 'user'
  93. id = sa.Column(sa.Integer, primary_key=True)
  94. favorite_colors = sa.Column(ARRAY(sa.String), nullable=False)
  95. __table_args__ = (
  96. sa.CheckConstraint(
  97. sa.func.array_length(favorite_colors, 1) <= 5
  98. )
  99. )
  100. user = User(name='John Doe', favorite_colors=['red', 'blue'])
  101. session.add(user)
  102. session.commit()
  103. assert_max_length(user, 'favorite_colors', 5)
  104. :param obj: SQLAlchemy declarative model object
  105. :param column: Name of the column
  106. :param max_length: Maximum length of given column
  107. """
  108. type_ = sa.inspect(obj.__class__).columns[column].type
  109. _expect_successful_update(
  110. obj,
  111. column,
  112. _repeated_value(type_) * max_length,
  113. _expected_exception(type_)
  114. )
  115. _expect_failing_update(
  116. obj,
  117. column,
  118. _repeated_value(type_) * (max_length + 1),
  119. _expected_exception(type_)
  120. )
  121. def assert_min_value(obj, column, min_value):
  122. """
  123. Assert that the given column must have a minimum value of `min_value`.
  124. :param obj: SQLAlchemy declarative model object
  125. :param column: Name of the column
  126. :param min_value: The minimum allowed value for given column
  127. """
  128. _expect_successful_update(obj, column, min_value, IntegrityError)
  129. _expect_failing_update(obj, column, min_value - 1, IntegrityError)
  130. def assert_max_value(obj, column, min_value):
  131. """
  132. Assert that the given column must have a minimum value of `max_value`.
  133. :param obj: SQLAlchemy declarative model object
  134. :param column: Name of the column
  135. :param max_value: The maximum allowed value for given column
  136. """
  137. _expect_successful_update(obj, column, min_value, IntegrityError)
  138. _expect_failing_update(obj, column, min_value + 1, IntegrityError)