expression_parser.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. from collections import OrderedDict
  2. import six
  3. import sqlalchemy as sa
  4. from sqlalchemy.orm.attributes import InstrumentedAttribute
  5. from sqlalchemy.sql.annotation import AnnotatedColumn
  6. from sqlalchemy.sql.expression import (
  7. BooleanClauseList,
  8. BinaryExpression,
  9. UnaryExpression,
  10. BindParameter,
  11. Cast,
  12. )
  13. from sqlalchemy.sql.elements import (
  14. False_,
  15. True_,
  16. Grouping,
  17. ClauseList,
  18. Label,
  19. Case,
  20. Tuple,
  21. Null
  22. )
  23. class ExpressionParser(object):
  24. parsers = OrderedDict((
  25. (BinaryExpression, 'binary_expression'),
  26. (BooleanClauseList, 'boolean_expression'),
  27. (UnaryExpression, 'unary_expression'),
  28. (sa.Column, 'column'),
  29. (AnnotatedColumn, 'column'),
  30. (BindParameter, 'bind_parameter'),
  31. (False_, 'false'),
  32. (True_, 'true'),
  33. (Grouping, 'grouping'),
  34. (ClauseList, 'clause_list'),
  35. (Label, 'label'),
  36. (Cast, 'cast'),
  37. (Case, 'case'),
  38. (Tuple, 'tuple'),
  39. (Null, 'null'),
  40. (InstrumentedAttribute, 'instrumented_attribute')
  41. ))
  42. def expression(self, expr):
  43. if expr is None:
  44. return
  45. for class_, parser in self.parsers.items():
  46. if isinstance(expr, class_):
  47. return getattr(self, parser)(expr)
  48. raise Exception(
  49. 'Unknown expression type %s' % expr.__class__.__name__
  50. )
  51. def instrumented_attribute(self, expr):
  52. return expr
  53. def null(self, expr):
  54. return expr
  55. def tuple(self, expr):
  56. return expr.__class__(
  57. *map(self.expression, expr.clauses),
  58. type_=expr.type
  59. )
  60. def clause_list(self, expr):
  61. return expr.__class__(
  62. *map(self.expression, expr.clauses),
  63. group=expr.group,
  64. group_contents=expr.group_contents,
  65. operator=expr.operator
  66. )
  67. def label(self, expr):
  68. return expr.__class__(
  69. name=expr.name,
  70. element=self.expression(expr._element),
  71. type_=expr.type
  72. )
  73. def cast(self, expr):
  74. return expr.__class__(
  75. expression=self.expression(expr.clause),
  76. type_=expr.type
  77. )
  78. def case(self, expr):
  79. return expr.__class__(
  80. whens=[
  81. tuple(self.expression(x) for x in when) for when in expr.whens
  82. ],
  83. value=self.expression(expr.value),
  84. else_=self.expression(expr.else_)
  85. )
  86. def grouping(self, expr):
  87. return expr.__class__(self.expression(expr.element))
  88. def true(self, expr):
  89. return expr
  90. def false(self, expr):
  91. return expr
  92. def process_table(self, table):
  93. return table
  94. def column(self, column):
  95. table = self.process_table(column.table)
  96. return table.c[column.name]
  97. def unary_expression(self, expr):
  98. return expr.operator(self.expression(expr.element))
  99. def bind_parameter(self, expr):
  100. # somehow bind parameters passed as unicode are converted to
  101. # ascii strings along the way, force convert them back to avoid
  102. # sqlalchemy unicode warnings
  103. if isinstance(expr.type, sa.Unicode):
  104. expr.value = six.text_type(expr.value)
  105. return expr
  106. def binary_expression(self, expr):
  107. return expr.__class__(
  108. left=self.expression(expr.left),
  109. right=self.expression(expr.right),
  110. operator=expr.operator,
  111. type_=expr.type,
  112. negate=expr.negate,
  113. modifiers=expr.modifiers.copy()
  114. )
  115. def boolean_expression(self, expr):
  116. return expr.operator(*[
  117. self.expression(child_expr)
  118. for child_expr in expr.get_children()
  119. ])
  120. def __call__(self, expr):
  121. return self.expression(expr)