123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465 |
- # testing/exclusions.py
- # Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: https://www.opensource.org/licenses/mit-license.php
- import contextlib
- import operator
- import re
- import sys
- from . import config
- from .. import util
- from ..util import decorator
- from ..util.compat import inspect_getfullargspec
def skip_if(predicate, reason=None):
    """Build a rule that skips the decorated test when ``predicate`` holds.

    ``predicate`` may be anything accepted by ``Predicate.as_predicate``;
    ``reason`` is handed to it as the description.
    """
    rule = compound()
    rule.skips.add(_as_predicate(predicate, reason))
    return rule
def fails_if(predicate, reason=None):
    """Build a rule expecting the decorated test to fail when ``predicate`` holds.

    ``predicate`` may be anything accepted by ``Predicate.as_predicate``;
    ``reason`` is handed to it as the description.
    """
    rule = compound()
    rule.fails.add(_as_predicate(predicate, reason))
    return rule
class compound(object):
    """A combinable set of test exclusion rules.

    ``skips`` holds predicates that cause a test to be skipped when true for
    the active config, ``fails`` holds predicates under which the test is
    expected to raise, and ``tags`` holds labels consulted by
    :meth:`include_test`. Instances decorate tests via ``__call__`` and can
    wrap a block of code via :meth:`fail_if`.
    """

    def __init__(self):
        self.fails = set()
        self.skips = set()
        self.tags = set()

    def __add__(self, other):
        return self.add(other)

    def as_skips(self):
        """Return a copy in which every fail rule is demoted to a skip."""
        demoted = compound()
        for source in (self.skips, self.fails):
            demoted.skips.update(source)
        demoted.tags.update(self.tags)
        return demoted

    def add(self, *others):
        """Return a new compound merging this one with every one of ``others``."""
        merged = compound()
        for source in (self,) + others:
            merged.fails.update(source.fails)
            merged.skips.update(source.skips)
            merged.tags.update(source.tags)
        return merged

    def not_(self):
        """Return a copy with every skip and fail predicate negated."""
        inverted = compound()
        inverted.fails.update(map(NotPredicate, self.fails))
        inverted.skips.update(map(NotPredicate, self.skips))
        inverted.tags.update(self.tags)
        return inverted

    @property
    def enabled(self):
        """True unless a skip or fail predicate holds for the current config."""
        return self.enabled_for_config(config._current)

    def enabled_for_config(self, config):
        """Return False if any skip or fail predicate is true for ``config``."""
        return not any(
            predicate(config) for predicate in self.skips.union(self.fails)
        )

    def matching_config_reasons(self, config):
        """Describe every skip/fail predicate that is true for ``config``."""
        reasons = []
        for predicate in self.skips.union(self.fails):
            if predicate(config):
                reasons.append(predicate._as_string(config))
        return reasons

    def include_test(self, include_tags, exclude_tags):
        """Decide whether tag filtering selects this test.

        Any overlap with ``exclude_tags`` rejects it. Otherwise an empty
        ``include_tags`` accepts it, while a non-empty one requires overlap.
        """
        if self.tags.intersection(exclude_tags):
            return False
        if not include_tags:
            return True
        return bool(self.tags.intersection(include_tags))

    def _extend(self, other):
        """Merge ``other``'s rules and tags into this compound in place."""
        for attr in ("skips", "fails", "tags"):
            getattr(self, attr).update(getattr(other, attr))

    def __call__(self, fn):
        """Apply these rules to test function ``fn``.

        If ``fn`` already carries a compound (``_sa_exclusion_extend``), this
        one is merged into it and ``fn`` is returned unchanged. Otherwise a
        wrapper is returned that runs ``fn`` through :meth:`_do` with the
        current config.
        """
        if hasattr(fn, "_sa_exclusion_extend"):
            fn._sa_exclusion_extend._extend(self)
            return fn

        @decorator
        def decorate(fn, *args, **kw):
            return self._do(config._current, fn, *args, **kw)

        wrapped = decorate(fn)
        # remembered so later compounds stacked on the same test merge here
        wrapped._sa_exclusion_extend = self
        return wrapped

    @contextlib.contextmanager
    def fail_if(self):
        """Context manager treating every skip and fail rule as a fail rule.

        An exception in the block is swallowed only if one of the rules holds
        (otherwise it propagates). A block that succeeds while a rule holds
        raises AssertionError.
        """
        expectations = compound()
        expectations.fails.update(self.skips.union(self.fails))
        try:
            yield
        except Exception as err:
            expectations._expect_failure(config._current, err)
        else:
            expectations._expect_success(config._current)

    def _do(self, cfg, fn, *args, **kw):
        """Run ``fn`` under ``cfg``: honor skips first, then fail expectations."""
        for skip in self.skips:
            if not skip(cfg):
                continue
            config.skip_test(
                "'%s' : %s"
                % (config.get_current_test_name(), skip._as_string(cfg))
            )

        try:
            result = fn(*args, **kw)
        except Exception as err:
            self._expect_failure(cfg, err, name=fn.__name__)
        else:
            self._expect_success(cfg, name=fn.__name__)
            return result

    def _expect_failure(self, config, ex, name="block"):
        """Report ``ex`` as expected if a fail predicate holds, else re-raise it.

        Must be called from inside an ``except`` block, because the traceback
        is taken from ``sys.exc_info()``.
        """
        for fail in self.fails:
            if not fail(config):
                continue
            if util.py2k:
                str_ex = unicode(ex).encode(  # noqa: F821
                    "utf-8", errors="ignore"
                )
            else:
                str_ex = str(ex)
            print(
                "%s failed as expected (%s): %s "
                % (name, fail._as_string(config), str_ex)
            )
            return
        util.raise_(ex, with_traceback=sys.exc_info()[2])

    def _expect_success(self, config, name="block"):
        """Raise AssertionError if any fail predicate holds for ``config``."""
        if any(fail(config) for fail in self.fails):
            raise AssertionError(
                "Unexpected success for '%s' (%s)"
                % (
                    name,
                    " and ".join(f._as_string(config) for f in self.fails),
                )
            )
def requires_tag(tagname):
    """Build a rule carrying the single tag ``tagname``."""
    return tags((tagname,))
def tags(tagnames):
    """Build a rule with no skip/fail predicates carrying every tag in ``tagnames``."""
    tagged = compound()
    tagged.tags.update(tagnames)
    return tagged
def only_if(predicate, reason=None):
    """Skip the decorated test unless ``predicate`` holds.

    ``reason`` becomes the description of the negated predicate.
    """
    return skip_if(NotPredicate(_as_predicate(predicate)), reason)
def succeeds_if(predicate, reason=None):
    """Expect the decorated test to fail unless ``predicate`` holds."""
    inverted = NotPredicate(_as_predicate(predicate))
    return fails_if(inverted, reason)
class Predicate(object):
    """Base class for the exclusion predicates used by ``compound``.

    Subclasses are called with a test config and return a truthy value when
    they apply; ``_as_string`` renders a human-readable reason.
    """

    @classmethod
    def as_predicate(cls, predicate, description=None):
        """Coerce ``predicate`` into a :class:`Predicate`.

        Accepted inputs:

        * ``compound`` -- wrapped around its ``enabled_for_config`` method,
          so the result is true when the compound is enabled;
        * :class:`Predicate` -- returned unchanged, adopting ``description``
          only if it has none yet;
        * ``list`` / ``set`` -- members coerced individually and combined
          with ``OrPredicate``;
        * ``tuple`` -- unpacked as ``SpecPredicate`` arguments (the
          ``description`` argument is not applied in this case);
        * string -- parsed as ``"<db>[+<driver>] [<op> <version>]"``, e.g.
          ``"postgresql"``, ``"mysql+pymysql"``, ``"sqlite >= 3.8"``;
        * any other callable -- wrapped in ``LambdaPredicate``.

        :raises ValueError: if a string does not start with a DB name.
        :raises AssertionError: for any other input type.
        """
        if isinstance(predicate, compound):
            return cls.as_predicate(predicate.enabled_for_config, description)
        elif isinstance(predicate, Predicate):
            if description and predicate.description is None:
                predicate.description = description
            return predicate
        elif isinstance(predicate, (list, set)):
            return OrPredicate(
                [cls.as_predicate(pred) for pred in predicate], description
            )
        elif isinstance(predicate, tuple):
            return SpecPredicate(*predicate)
        elif isinstance(predicate, util.string_types):
            tokens = re.match(
                r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
            )
            if not tokens:
                raise ValueError(
                    "Couldn't locate DB name in predicate: %r" % predicate
                )
            db = tokens.group(1)
            op = tokens.group(2)
            spec = (
                tuple(int(d) for d in tokens.group(3).split("."))
                if tokens.group(3)
                else None
            )
            return SpecPredicate(db, op, spec, description=description)
        elif callable(predicate):
            return LambdaPredicate(predicate, description)
        else:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; the exception type callers see is unchanged.
            raise AssertionError("unknown predicate type: %s" % predicate)

    def _format_description(self, config, negate=False):
        """Interpolate ``self.description`` for ``config``.

        The description may use the ``%(driver)s``, ``%(database)s``,
        ``%(doesnt_support)s`` and ``%(does_support)s`` placeholders. The
        support wording follows this predicate's result for ``config``,
        inverted when ``negate`` is true.
        """
        bool_ = self(config)
        if negate:
            # Invert the predicate's own result. Inverting the flag
            # (``not negate``) would always yield False here.
            bool_ = not bool_
        return self.description % {
            "driver": config.db.url.get_driver_name()
            if config
            else "<no driver>",
            "database": config.db.url.get_backend_name()
            if config
            else "<no database>",
            "doesnt_support": "doesn't support" if bool_ else "does support",
            "does_support": "does support" if bool_ else "doesn't support",
        }

    def _as_string(self, config=None, negate=False):
        """Render a human-readable reason; implemented by subclasses."""
        raise NotImplementedError()
class BooleanPredicate(Predicate):
    """Predicate with a fixed truth value, independent of the config."""

    def __init__(self, value, description=None):
        self.value = value
        if not description:
            description = "boolean %s" % value
        self.description = description

    def __call__(self, config):
        return self.value

    def _as_string(self, config, negate=False):
        return self._format_description(config, negate=negate)
class SpecPredicate(Predicate):
    """Predicate matching a database backend, optional driver, and version.

    ``db`` is ``"dialect"`` or ``"dialect+driver"``. When ``op`` is given,
    the engine's server version is compared against ``spec`` using ``op``,
    which is either a key of ``_ops`` or a callable taking
    ``(version, spec)``.
    """

    def __init__(self, db, op=None, spec=None, description=None):
        self.db = db
        self.op = op
        self.spec = spec
        self.description = description

    # Named comparisons accepted as ``op``; each is called as
    # ``fn(server_version, spec)``.
    _ops = {
        "<": operator.lt,
        ">": operator.gt,
        "==": operator.eq,
        "!=": operator.ne,
        "<=": operator.le,
        ">=": operator.ge,
        # ``operator.contains(a, b)`` tests ``b in a`` (reversed operands),
        # which would ask whether the spec is inside the version tuple.
        # The intent is "version is one of spec".
        "in": lambda val, options: val in options,
        "between": lambda val, pair: val >= pair[0] and val <= pair[1],
    }

    def __call__(self, config):
        """Return True if ``config``'s engine matches this spec."""
        if config is None:
            return False
        engine = config.db

        if "+" in self.db:
            dialect, driver = self.db.split("+")
        else:
            dialect, driver = self.db, None

        if dialect and engine.name != dialect:
            return False
        if driver is not None and engine.driver != driver:
            return False

        if self.op is None:
            return True

        assert driver is None, "DBAPI version specs not supported yet"
        version = _server_version(engine)
        # conditional expression rather than the fragile ``a and b or c``
        oper = self.op if callable(self.op) else self._ops[self.op]
        return oper(version, self.spec)

    def _as_string(self, config, negate=False):
        """Render the description, or ``[not ]db [op spec]`` if there is none."""
        if self.description is not None:
            return self._format_description(config)
        prefix = "not " if negate else ""
        if self.op is None:
            return "%s%s" % (prefix, self.db)
        return "%s%s %s %s" % (prefix, self.db, self.op, self.spec)
class LambdaPredicate(Predicate):
    """Predicate wrapping an arbitrary callable.

    A callable that declares no positional parameters is adapted so it can
    still be invoked with the config argument. The description falls back to
    the callable's docstring, then to ``"custom function"``.
    """

    def __init__(self, lambda_, description=None, args=None, kw=None):
        takes_positional = bool(inspect_getfullargspec(lambda_)[0])
        if takes_positional:
            self.lambda_ = lambda_
        else:
            self.lambda_ = lambda db: lambda_()
        self.args = args or ()
        self.kw = kw or {}
        self.description = description or lambda_.__doc__ or "custom function"

    def __call__(self, config):
        return self.lambda_(config)

    def _as_string(self, config, negate=False):
        return self._format_description(config)
class NotPredicate(Predicate):
    """Logical negation of another predicate."""

    def __init__(self, predicate, description=None):
        self.predicate = predicate
        self.description = description

    def __call__(self, config):
        return not self.predicate(config)

    def _as_string(self, config, negate=False):
        flipped = not negate
        if not self.description:
            # no own description: let the wrapped predicate render itself negated
            return self.predicate._as_string(config, flipped)
        return self._format_description(config, flipped)
class OrPredicate(Predicate):
    """True when any of ``predicates`` is true for the config."""

    def __init__(self, predicates, description=None):
        self.predicates = predicates
        self.description = description

    def __call__(self, config):
        return any(pred(config) for pred in self.predicates)

    def _eval_str(self, config, negate=False):
        # De Morgan: a negated OR is rendered as an AND of negated members
        conjunction = " and " if negate else " or "
        parts = [p._as_string(config, negate=negate) for p in self.predicates]
        return conjunction.join(parts)

    def _negation_str(self, config):
        if self.description is None:
            return self._eval_str(config, negate=True)
        return "Not " + self._format_description(config)

    def _as_string(self, config, negate=False):
        if negate:
            return self._negation_str(config)
        if self.description is None:
            return self._eval_str(config)
        return self._format_description(config)
# Module-level shorthand for ``Predicate.as_predicate``, used by the rule
# builders in this module to coerce their ``predicate`` arguments.
_as_predicate = Predicate.as_predicate
def _is_excluded(db, op, spec):
    """Return whether the current config matches backend ``db`` with ``op spec``."""
    spec_pred = SpecPredicate(db, op, spec)
    return spec_pred(config._current)
- def _server_version(engine):
- """Return a server_version_info tuple."""
- # force metadata to be retrieved
- conn = engine.connect()
- version = getattr(engine.dialect, "server_version_info", None)
- if version is None:
- version = ()
- conn.close()
- return version
def db_spec(*dbs):
    """Return an ``OrPredicate`` that is true when the config matches any of ``dbs``."""
    return OrPredicate(list(map(Predicate.as_predicate, dbs)))
def open():  # noqa
    """Return a rule whose skip predicate is always False, so the test runs."""
    always_run = BooleanPredicate(False, "mark as execute")
    return skip_if(always_run)
def closed():
    """Return a rule whose skip predicate is always True, so the test is skipped."""
    always_skip = BooleanPredicate(True, "marked as skip")
    return skip_if(always_skip)
def fails(reason=None):
    """Return a rule expecting the test to fail under every config."""
    description = reason or "expected to fail"
    return fails_if(BooleanPredicate(True, description))
# NOTE(review): this body never calls ``fn``. It only returns a ``compound``
# built from a LambdaPredicate around ``fn`` itself, so if ``decorator``
# invokes this body in place of the test (as it presumably does), the test
# body is not run. Confirm whether this helper is still used before
# relying on it.
@decorator
def future(fn, *arg):
    """Mark ``fn`` as exercising a not-yet-implemented ("Future") feature."""
    return fails_if(LambdaPredicate(fn), "Future feature")
def fails_on(db, reason=None):
    """Expect the test to fail when the config matches ``db``."""
    return fails_if(predicate=db, reason=reason)
def fails_on_everything_except(*dbs):
    """Expect the test to fail on every backend except those listed in ``dbs``."""
    return succeeds_if(db_spec(*dbs))
def skip(db, reason=None):
    """Skip the test when the config matches ``db``."""
    return skip_if(predicate=db, reason=reason)
def only_on(dbs, reason=None):
    """Skip the test unless the config matches one of ``dbs``.

    ``reason`` is attached as the description of each individual backend
    predicate, not of the combined rule.
    """
    candidates = [
        Predicate.as_predicate(db, reason) for db in util.to_list(dbs)
    ]
    return only_if(OrPredicate(candidates))
def exclude(db, op, spec, reason=None):
    """Skip the test when backend ``db``'s server version satisfies ``op spec``."""
    excluded = SpecPredicate(db, op, spec)
    return skip_if(excluded, reason)
def against(config, *queries):
    """Return whether ``config`` matches any of ``queries``.

    Each query is any form accepted by ``Predicate.as_predicate``, e.g.
    ``"postgresql"`` or ``"mysql+pymysql"``.
    """
    assert queries, "no queries sent!"
    matcher = db_spec(*queries)
    return matcher(config)
|