123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241 |
- # orm/evaluator.py
- # Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: https://www.opensource.org/licenses/mit-license.php
- import operator
- from .. import inspect
- from .. import util
- from ..sql import and_
- from ..sql import operators
class UnevaluatableError(Exception):
    """Raised when a SQL expression cannot be evaluated in Python."""
class _NoObject(operators.ColumnOperators):
    """Placeholder operand whose every operation evaluates to ``None``.

    Both operator hooks ignore their arguments and return ``None``, so an
    expression built on this placeholder yields ``None`` rather than
    raising.
    """

    def operate(self, *args, **kwargs):
        """Return ``None`` regardless of the operator or operands."""
        return None

    def reverse_operate(self, *args, **kwargs):
        """Return ``None`` regardless of the operator or operands."""
        return None


# Module-wide singleton; column evaluators return it when they are asked to
# evaluate against ``None`` instead of a real object.
_NO_OBJECT = _NoObject()
# Operator functions that can be applied as-is to the two Python values of a
# binary expression; each name is looked up on the ``operators`` module.
_straight_ops = {
    getattr(operators, op_name)
    for op_name in (
        "add",
        "mul",
        "sub",
        "div",
        "mod",
        "truediv",
        "lt",
        "le",
        "ne",
        "gt",
        "ge",
        "eq",
    )
}
# Operators that need a hand-written Python equivalent.  Each callable takes
# the left and right values and yields ``None`` when the left value is the
# ``_NO_OBJECT`` placeholder.
_extended_ops = {
    operators.in_op: (
        lambda value, collection: None
        if value is _NO_OBJECT
        else value in collection
    ),
    operators.not_in_op: (
        lambda value, collection: None
        if value is _NO_OBJECT
        else value not in collection
    ),
}
# Operator functions with no Python-side implementation in this module; each
# name is looked up on the ``operators`` module.
_notimplemented_ops = {
    getattr(operators, op_name)
    for op_name in (
        "like_op",
        "not_like_op",
        "ilike_op",
        "not_ilike_op",
        "startswith_op",
        "between_op",
        "endswith_op",
        "concat_op",
    )
}
class EvaluatorCompiler(object):
    """Translate SQL expression constructs into plain Python callables.

    :meth:`process` dispatches on a clause's ``__visit_name__`` to the
    matching ``visit_<name>`` method.  Every visitor returns a function that
    accepts a single object and returns the Python value of the clause for
    that object.  ``None`` plays the role of SQL NULL: most evaluators
    return ``None`` as soon as one of their operands evaluates to ``None``.

    :param target_cls: optional class the compiled evaluators are meant to
     be applied to.  When given, :meth:`visit_column` rejects columns that
     are mapped to a class ``target_cls`` does not derive from.
    """

    def __init__(self, target_cls=None):
        self.target_cls = target_cls

    def process(self, *clauses):
        """Return an evaluator function for the given clause(s).

        When more than one clause is passed they are joined with ``and_()``
        before being compiled.

        :param clauses: one or more clause objects exposing
         ``__visit_name__``.
        :return: a callable accepting one object.
        :raises TypeError: if no clause is passed at all.
        :raises UnevaluatableError: if there is no ``visit_`` method for the
         clause's ``__visit_name__``.
        """
        if not clauses:
            # Without this guard ``clause`` below would be unbound and the
            # caller would get an opaque UnboundLocalError.
            raise TypeError("process() requires at least one clause")

        if len(clauses) > 1:
            clause = and_(*clauses)
        else:
            clause = clauses[0]

        meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
        if not meth:
            raise UnevaluatableError(
                "Cannot evaluate %s" % type(clause).__name__
            )
        return meth(clause)

    def visit_grouping(self, clause):
        """Compile the element wrapped by a grouping clause."""
        return self.process(clause.element)

    def visit_null(self, clause):
        """Return an evaluator that always yields ``None``."""
        return lambda obj: None

    def visit_false(self, clause):
        """Return an evaluator that always yields ``False``."""
        return lambda obj: False

    def visit_true(self, clause):
        """Return an evaluator that always yields ``True``."""
        return lambda obj: True

    def visit_column(self, clause):
        """Return an evaluator that reads the column's attribute.

        The attribute name comes from the ``parentmapper`` annotation when
        the column carries one; otherwise ``clause.key`` is used, which is
        only accepted (with a warning) if ``target_cls`` has a column
        attribute of that name.

        :raises UnevaluatableError: if the column is mapped to a class that
         ``target_cls`` is not a subclass of, or if an un-annotated column
         cannot be matched to a column attribute of ``target_cls``.
        """
        if "parentmapper" in clause._annotations:
            parentmapper = clause._annotations["parentmapper"]
            if self.target_cls and not issubclass(
                self.target_cls, parentmapper.class_
            ):
                raise UnevaluatableError(
                    "Can't evaluate criteria against alternate class %s"
                    % parentmapper.class_
                )
            key = parentmapper._columntoproperty[clause].key
        else:
            key = clause.key
            if (
                self.target_cls
                and key in inspect(self.target_cls).column_attrs
            ):
                util.warn(
                    "Evaluating non-mapped column expression '%s' onto "
                    "ORM instances; this is a deprecated use case.  Please "
                    "make use of the actual mapped columns in ORM-evaluated "
                    "UPDATE / DELETE expressions." % clause
                )
            else:
                raise UnevaluatableError("Cannot evaluate column: %s" % clause)

        get_corresponding_attr = operator.attrgetter(key)
        # With no object to read from, hand back the placeholder instead of
        # attempting attribute access on None.
        return (
            lambda obj: get_corresponding_attr(obj)
            if obj is not None
            else _NO_OBJECT
        )

    def visit_tuple(self, clause):
        """Compile a tuple clause the same way as a clause list."""
        return self.visit_clauselist(clause)

    def visit_clauselist(self, clause):
        """Return an evaluator for an OR / AND / comma-separated list.

        :raises UnevaluatableError: if the list uses any other operator.
        """
        evaluators = [self.process(sub) for sub in clause.clauses]

        if clause.operator is operators.or_:

            def evaluate(obj):
                # True wins outright; otherwise any None makes the result
                # None rather than False.
                has_null = False
                for sub_evaluate in evaluators:
                    value = sub_evaluate(obj)
                    if value:
                        return True
                    has_null = has_null or value is None
                if has_null:
                    return None
                return False

        elif clause.operator is operators.and_:

            def evaluate(obj):
                # The first falsy member decides: None / placeholder gives
                # None, anything else gives False.
                for sub_evaluate in evaluators:
                    value = sub_evaluate(obj)
                    if not value:
                        if value is None or value is _NO_OBJECT:
                            return None
                        return False
                return True

        elif clause.operator is operators.comma_op:

            def evaluate(obj):
                # Build a tuple of member values; a single None /
                # placeholder member makes the whole result None.
                values = []
                for sub_evaluate in evaluators:
                    value = sub_evaluate(obj)
                    if value is None or value is _NO_OBJECT:
                        return None
                    values.append(value)
                return tuple(values)

        else:
            raise UnevaluatableError(
                "Cannot evaluate clauselist with operator %s" % clause.operator
            )

        return evaluate

    def visit_binary(self, clause):
        """Return an evaluator for a two-operand expression.

        ``is_`` / ``is_not`` compare the two values with ``==`` / ``!=``.
        Operators found in ``_extended_ops`` or ``_straight_ops`` are
        applied to the two values unless either one is ``None``, in which
        case the result is ``None``.

        :raises UnevaluatableError: for any other operator.
        """
        eval_left = self.process(clause.left)
        eval_right = self.process(clause.right)
        # Named ``op`` so the imported ``operator`` module is not shadowed.
        op = clause.operator

        if op is operators.is_:

            def evaluate(obj):
                return eval_left(obj) == eval_right(obj)

        elif op is operators.is_not:

            def evaluate(obj):
                return eval_left(obj) != eval_right(obj)

        elif op in _extended_ops:

            def evaluate(obj):
                left_val = eval_left(obj)
                right_val = eval_right(obj)
                if left_val is None or right_val is None:
                    return None
                return _extended_ops[op](left_val, right_val)

        elif op in _straight_ops:

            def evaluate(obj):
                left_val = eval_left(obj)
                right_val = eval_right(obj)
                if left_val is None or right_val is None:
                    return None
                # Reuse the values already computed for the None check
                # instead of running both sub-evaluators a second time.
                return op(left_val, right_val)

        else:
            raise UnevaluatableError(
                "Cannot evaluate %s with operator %s"
                % (type(clause).__name__, clause.operator)
            )

        return evaluate

    def visit_unary(self, clause):
        """Return an evaluator for a unary expression.

        Only ``operators.inv`` is supported: the result is the boolean
        negation of the inner value, or ``None`` if that value is ``None``.

        :raises UnevaluatableError: for any other operator.
        """
        eval_inner = self.process(clause.element)

        if clause.operator is operators.inv:

            def evaluate(obj):
                value = eval_inner(obj)
                if value is None:
                    return None
                return not value

            return evaluate

        raise UnevaluatableError(
            "Cannot evaluate %s with operator %s"
            % (type(clause).__name__, clause.operator)
        )

    def visit_bindparam(self, clause):
        """Return an evaluator yielding the bound parameter's value.

        The value is resolved once, at compile time: ``clause.callable()``
        if a callable is set, else ``clause.value``.
        """
        if clause.callable:
            val = clause.callable()
        else:
            val = clause.value
        return lambda obj: val
|