_assertionold.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  1. import py
  2. import sys, inspect
  3. from compiler import parse, ast, pycodegen
  4. from py._code.assertion import BuiltinAssertionError, _format_explanation
  5. import types
# Exception types that must never be swallowed: every bare ``except:``
# handler in this module is preceded by ``except passthroughex: raise``.
# Presumably the interpreter-control exceptions (KeyboardInterrupt,
# SystemExit, ...) -- TODO confirm against py.builtin._sysex.
passthroughex = py.builtin._sysex
  7. class Failure:
  8. def __init__(self, node):
  9. self.exc, self.value, self.tb = sys.exc_info()
  10. self.node = node
  11. class View(object):
  12. """View base class.
  13. If C is a subclass of View, then C(x) creates a proxy object around
  14. the object x. The actual class of the proxy is not C in general,
  15. but a *subclass* of C determined by the rules below. To avoid confusion
  16. we call view class the class of the proxy (a subclass of C, so of View)
  17. and object class the class of x.
  18. Attributes and methods not found in the proxy are automatically read on x.
  19. Other operations like setting attributes are performed on the proxy, as
  20. determined by its view class. The object x is available from the proxy
  21. as its __obj__ attribute.
  22. The view class selection is determined by the __view__ tuples and the
  23. optional __viewkey__ method. By default, the selected view class is the
  24. most specific subclass of C whose __view__ mentions the class of x.
  25. If no such subclass is found, the search proceeds with the parent
  26. object classes. For example, C(True) will first look for a subclass
  27. of C with __view__ = (..., bool, ...) and only if it doesn't find any
  28. look for one with __view__ = (..., int, ...), and then ..., object,...
  29. If everything fails the class C itself is considered to be the default.
  30. Alternatively, the view class selection can be driven by another aspect
  31. of the object x, instead of the class of x, by overriding __viewkey__.
  32. See last example at the end of this module.
  33. """
  34. _viewcache = {}
  35. __view__ = ()
  36. def __new__(rootclass, obj, *args, **kwds):
  37. self = object.__new__(rootclass)
  38. self.__obj__ = obj
  39. self.__rootclass__ = rootclass
  40. key = self.__viewkey__()
  41. try:
  42. self.__class__ = self._viewcache[key]
  43. except KeyError:
  44. self.__class__ = self._selectsubclass(key)
  45. return self
  46. def __getattr__(self, attr):
  47. # attributes not found in the normal hierarchy rooted on View
  48. # are looked up in the object's real class
  49. return getattr(self.__obj__, attr)
  50. def __viewkey__(self):
  51. return self.__obj__.__class__
  52. def __matchkey__(self, key, subclasses):
  53. if inspect.isclass(key):
  54. keys = inspect.getmro(key)
  55. else:
  56. keys = [key]
  57. for key in keys:
  58. result = [C for C in subclasses if key in C.__view__]
  59. if result:
  60. return result
  61. return []
  62. def _selectsubclass(self, key):
  63. subclasses = list(enumsubclasses(self.__rootclass__))
  64. for C in subclasses:
  65. if not isinstance(C.__view__, tuple):
  66. C.__view__ = (C.__view__,)
  67. choices = self.__matchkey__(key, subclasses)
  68. if not choices:
  69. return self.__rootclass__
  70. elif len(choices) == 1:
  71. return choices[0]
  72. else:
  73. # combine the multiple choices
  74. return type('?', tuple(choices), {})
  75. def __repr__(self):
  76. return '%s(%r)' % (self.__rootclass__.__name__, self.__obj__)
  77. def enumsubclasses(cls):
  78. for subcls in cls.__subclasses__():
  79. for subsubclass in enumsubclasses(subcls):
  80. yield subsubclass
  81. yield cls
  82. class Interpretable(View):
  83. """A parse tree node with a few extra methods."""
  84. explanation = None
  85. def is_builtin(self, frame):
  86. return False
  87. def eval(self, frame):
  88. # fall-back for unknown expression nodes
  89. try:
  90. expr = ast.Expression(self.__obj__)
  91. expr.filename = '<eval>'
  92. self.__obj__.filename = '<eval>'
  93. co = pycodegen.ExpressionCodeGenerator(expr).getCode()
  94. result = frame.eval(co)
  95. except passthroughex:
  96. raise
  97. except:
  98. raise Failure(self)
  99. self.result = result
  100. self.explanation = self.explanation or frame.repr(self.result)
  101. def run(self, frame):
  102. # fall-back for unknown statement nodes
  103. try:
  104. expr = ast.Module(None, ast.Stmt([self.__obj__]))
  105. expr.filename = '<run>'
  106. co = pycodegen.ModuleCodeGenerator(expr).getCode()
  107. frame.exec_(co)
  108. except passthroughex:
  109. raise
  110. except:
  111. raise Failure(self)
  112. def nice_explanation(self):
  113. return _format_explanation(self.explanation)
  114. class Name(Interpretable):
  115. __view__ = ast.Name
  116. def is_local(self, frame):
  117. source = '%r in locals() is not globals()' % self.name
  118. try:
  119. return frame.is_true(frame.eval(source))
  120. except passthroughex:
  121. raise
  122. except:
  123. return False
  124. def is_global(self, frame):
  125. source = '%r in globals()' % self.name
  126. try:
  127. return frame.is_true(frame.eval(source))
  128. except passthroughex:
  129. raise
  130. except:
  131. return False
  132. def is_builtin(self, frame):
  133. source = '%r not in locals() and %r not in globals()' % (
  134. self.name, self.name)
  135. try:
  136. return frame.is_true(frame.eval(source))
  137. except passthroughex:
  138. raise
  139. except:
  140. return False
  141. def eval(self, frame):
  142. super(Name, self).eval(frame)
  143. if not self.is_local(frame):
  144. self.explanation = self.name
  145. class Compare(Interpretable):
  146. __view__ = ast.Compare
  147. def eval(self, frame):
  148. expr = Interpretable(self.expr)
  149. expr.eval(frame)
  150. for operation, expr2 in self.ops:
  151. if hasattr(self, 'result'):
  152. # shortcutting in chained expressions
  153. if not frame.is_true(self.result):
  154. break
  155. expr2 = Interpretable(expr2)
  156. expr2.eval(frame)
  157. self.explanation = "%s %s %s" % (
  158. expr.explanation, operation, expr2.explanation)
  159. source = "__exprinfo_left %s __exprinfo_right" % operation
  160. try:
  161. self.result = frame.eval(source,
  162. __exprinfo_left=expr.result,
  163. __exprinfo_right=expr2.result)
  164. except passthroughex:
  165. raise
  166. except:
  167. raise Failure(self)
  168. expr = expr2
  169. class And(Interpretable):
  170. __view__ = ast.And
  171. def eval(self, frame):
  172. explanations = []
  173. for expr in self.nodes:
  174. expr = Interpretable(expr)
  175. expr.eval(frame)
  176. explanations.append(expr.explanation)
  177. self.result = expr.result
  178. if not frame.is_true(expr.result):
  179. break
  180. self.explanation = '(' + ' and '.join(explanations) + ')'
  181. class Or(Interpretable):
  182. __view__ = ast.Or
  183. def eval(self, frame):
  184. explanations = []
  185. for expr in self.nodes:
  186. expr = Interpretable(expr)
  187. expr.eval(frame)
  188. explanations.append(expr.explanation)
  189. self.result = expr.result
  190. if frame.is_true(expr.result):
  191. break
  192. self.explanation = '(' + ' or '.join(explanations) + ')'
  193. # == Unary operations ==
  194. keepalive = []
  195. for astclass, astpattern in {
  196. ast.Not : 'not __exprinfo_expr',
  197. ast.Invert : '(~__exprinfo_expr)',
  198. }.items():
  199. class UnaryArith(Interpretable):
  200. __view__ = astclass
  201. def eval(self, frame, astpattern=astpattern):
  202. expr = Interpretable(self.expr)
  203. expr.eval(frame)
  204. self.explanation = astpattern.replace('__exprinfo_expr',
  205. expr.explanation)
  206. try:
  207. self.result = frame.eval(astpattern,
  208. __exprinfo_expr=expr.result)
  209. except passthroughex:
  210. raise
  211. except:
  212. raise Failure(self)
  213. keepalive.append(UnaryArith)
  214. # == Binary operations ==
  215. for astclass, astpattern in {
  216. ast.Add : '(__exprinfo_left + __exprinfo_right)',
  217. ast.Sub : '(__exprinfo_left - __exprinfo_right)',
  218. ast.Mul : '(__exprinfo_left * __exprinfo_right)',
  219. ast.Div : '(__exprinfo_left / __exprinfo_right)',
  220. ast.Mod : '(__exprinfo_left % __exprinfo_right)',
  221. ast.Power : '(__exprinfo_left ** __exprinfo_right)',
  222. }.items():
  223. class BinaryArith(Interpretable):
  224. __view__ = astclass
  225. def eval(self, frame, astpattern=astpattern):
  226. left = Interpretable(self.left)
  227. left.eval(frame)
  228. right = Interpretable(self.right)
  229. right.eval(frame)
  230. self.explanation = (astpattern
  231. .replace('__exprinfo_left', left .explanation)
  232. .replace('__exprinfo_right', right.explanation))
  233. try:
  234. self.result = frame.eval(astpattern,
  235. __exprinfo_left=left.result,
  236. __exprinfo_right=right.result)
  237. except passthroughex:
  238. raise
  239. except:
  240. raise Failure(self)
  241. keepalive.append(BinaryArith)
  242. class CallFunc(Interpretable):
  243. __view__ = ast.CallFunc
  244. def is_bool(self, frame):
  245. source = 'isinstance(__exprinfo_value, bool)'
  246. try:
  247. return frame.is_true(frame.eval(source,
  248. __exprinfo_value=self.result))
  249. except passthroughex:
  250. raise
  251. except:
  252. return False
  253. def eval(self, frame):
  254. node = Interpretable(self.node)
  255. node.eval(frame)
  256. explanations = []
  257. vars = {'__exprinfo_fn': node.result}
  258. source = '__exprinfo_fn('
  259. for a in self.args:
  260. if isinstance(a, ast.Keyword):
  261. keyword = a.name
  262. a = a.expr
  263. else:
  264. keyword = None
  265. a = Interpretable(a)
  266. a.eval(frame)
  267. argname = '__exprinfo_%d' % len(vars)
  268. vars[argname] = a.result
  269. if keyword is None:
  270. source += argname + ','
  271. explanations.append(a.explanation)
  272. else:
  273. source += '%s=%s,' % (keyword, argname)
  274. explanations.append('%s=%s' % (keyword, a.explanation))
  275. if self.star_args:
  276. star_args = Interpretable(self.star_args)
  277. star_args.eval(frame)
  278. argname = '__exprinfo_star'
  279. vars[argname] = star_args.result
  280. source += '*' + argname + ','
  281. explanations.append('*' + star_args.explanation)
  282. if self.dstar_args:
  283. dstar_args = Interpretable(self.dstar_args)
  284. dstar_args.eval(frame)
  285. argname = '__exprinfo_kwds'
  286. vars[argname] = dstar_args.result
  287. source += '**' + argname + ','
  288. explanations.append('**' + dstar_args.explanation)
  289. self.explanation = "%s(%s)" % (
  290. node.explanation, ', '.join(explanations))
  291. if source.endswith(','):
  292. source = source[:-1]
  293. source += ')'
  294. try:
  295. self.result = frame.eval(source, **vars)
  296. except passthroughex:
  297. raise
  298. except:
  299. raise Failure(self)
  300. if not node.is_builtin(frame) or not self.is_bool(frame):
  301. r = frame.repr(self.result)
  302. self.explanation = '%s\n{%s = %s\n}' % (r, r, self.explanation)
  303. class Getattr(Interpretable):
  304. __view__ = ast.Getattr
  305. def eval(self, frame):
  306. expr = Interpretable(self.expr)
  307. expr.eval(frame)
  308. source = '__exprinfo_expr.%s' % self.attrname
  309. try:
  310. self.result = frame.eval(source, __exprinfo_expr=expr.result)
  311. except passthroughex:
  312. raise
  313. except:
  314. raise Failure(self)
  315. self.explanation = '%s.%s' % (expr.explanation, self.attrname)
  316. # if the attribute comes from the instance, its value is interesting
  317. source = ('hasattr(__exprinfo_expr, "__dict__") and '
  318. '%r in __exprinfo_expr.__dict__' % self.attrname)
  319. try:
  320. from_instance = frame.is_true(
  321. frame.eval(source, __exprinfo_expr=expr.result))
  322. except passthroughex:
  323. raise
  324. except:
  325. from_instance = True
  326. if from_instance:
  327. r = frame.repr(self.result)
  328. self.explanation = '%s\n{%s = %s\n}' % (r, r, self.explanation)
  329. # == Re-interpretation of full statements ==
  330. class Assert(Interpretable):
  331. __view__ = ast.Assert
  332. def run(self, frame):
  333. test = Interpretable(self.test)
  334. test.eval(frame)
  335. # simplify 'assert False where False = ...'
  336. if (test.explanation.startswith('False\n{False = ') and
  337. test.explanation.endswith('\n}')):
  338. test.explanation = test.explanation[15:-2]
  339. # print the result as 'assert <explanation>'
  340. self.result = test.result
  341. self.explanation = 'assert ' + test.explanation
  342. if not frame.is_true(test.result):
  343. try:
  344. raise BuiltinAssertionError
  345. except passthroughex:
  346. raise
  347. except:
  348. raise Failure(self)
  349. class Assign(Interpretable):
  350. __view__ = ast.Assign
  351. def run(self, frame):
  352. expr = Interpretable(self.expr)
  353. expr.eval(frame)
  354. self.result = expr.result
  355. self.explanation = '... = ' + expr.explanation
  356. # fall-back-run the rest of the assignment
  357. ass = ast.Assign(self.nodes, ast.Name('__exprinfo_expr'))
  358. mod = ast.Module(None, ast.Stmt([ass]))
  359. mod.filename = '<run>'
  360. co = pycodegen.ModuleCodeGenerator(mod).getCode()
  361. try:
  362. frame.exec_(co, __exprinfo_expr=expr.result)
  363. except passthroughex:
  364. raise
  365. except:
  366. raise Failure(self)
  367. class Discard(Interpretable):
  368. __view__ = ast.Discard
  369. def run(self, frame):
  370. expr = Interpretable(self.expr)
  371. expr.eval(frame)
  372. self.result = expr.result
  373. self.explanation = expr.explanation
  374. class Stmt(Interpretable):
  375. __view__ = ast.Stmt
  376. def run(self, frame):
  377. for stmt in self.nodes:
  378. stmt = Interpretable(stmt)
  379. stmt.run(frame)
  380. def report_failure(e):
  381. explanation = e.node.nice_explanation()
  382. if explanation:
  383. explanation = ", in: " + explanation
  384. else:
  385. explanation = ""
  386. sys.stdout.write("%s: %s%s\n" % (e.exc.__name__, e.value, explanation))
  387. def check(s, frame=None):
  388. if frame is None:
  389. frame = sys._getframe(1)
  390. frame = py.code.Frame(frame)
  391. expr = parse(s, 'eval')
  392. assert isinstance(expr, ast.Expression)
  393. node = Interpretable(expr.node)
  394. try:
  395. node.eval(frame)
  396. except passthroughex:
  397. raise
  398. except Failure:
  399. e = sys.exc_info()[1]
  400. report_failure(e)
  401. else:
  402. if not frame.is_true(node.result):
  403. sys.stderr.write("assertion failed: %s\n" % node.nice_explanation())
  404. ###########################################################
  405. # API / Entry points
  406. # #########################################################
  407. def interpret(source, frame, should_fail=False):
  408. module = Interpretable(parse(source, 'exec').node)
  409. #print "got module", module
  410. if isinstance(frame, types.FrameType):
  411. frame = py.code.Frame(frame)
  412. try:
  413. module.run(frame)
  414. except Failure:
  415. e = sys.exc_info()[1]
  416. return getfailure(e)
  417. except passthroughex:
  418. raise
  419. except:
  420. import traceback
  421. traceback.print_exc()
  422. if should_fail:
  423. return ("(assertion failed, but when it was re-run for "
  424. "printing intermediate values, it did not fail. Suggestions: "
  425. "compute assert expression before the assert or use --nomagic)")
  426. else:
  427. return None
  428. def getmsg(excinfo):
  429. if isinstance(excinfo, tuple):
  430. excinfo = py.code.ExceptionInfo(excinfo)
  431. #frame, line = gettbline(tb)
  432. #frame = py.code.Frame(frame)
  433. #return interpret(line, frame)
  434. tb = excinfo.traceback[-1]
  435. source = str(tb.statement).strip()
  436. x = interpret(source, tb.frame, should_fail=True)
  437. if not isinstance(x, str):
  438. raise TypeError("interpret returned non-string %r" % (x,))
  439. return x
  440. def getfailure(e):
  441. explanation = e.node.nice_explanation()
  442. if str(e.value):
  443. lines = explanation.split('\n')
  444. lines[0] += " << %s" % (e.value,)
  445. explanation = '\n'.join(lines)
  446. text = "%s: %s" % (e.exc.__name__, explanation)
  447. if text.startswith('AssertionError: assert '):
  448. text = text[16:]
  449. return text
  450. def run(s, frame=None):
  451. if frame is None:
  452. frame = sys._getframe(1)
  453. frame = py.code.Frame(frame)
  454. module = Interpretable(parse(s, 'exec').node)
  455. try:
  456. module.run(frame)
  457. except Failure:
  458. e = sys.exc_info()[1]
  459. report_failure(e)
if __name__ == '__main__':
    # Demo: each check()/run() call below re-interprets a snippet in this
    # module's frame.  check() reports an exception on stdout and a false
    # result on stderr; run() reports exceptions (including failed asserts)
    # on stdout.  Apart from run("x = i"), every snippet is expected to
    # produce a report.
    def f():
        return 5
    def g():
        return 3
    def h(x):
        return 'never'
    check("f() * g() == 5")            # 15 != 5
    check("not f()")
    check("not (f() and g() or 0)")
    check("f() == g()")
    i = 4
    check("i == f()")
    check("len(f()) == 0")             # len() of an int raises
    check("isinstance(2+3+4, float)")
    run("x = i")                       # assignment re-run in this frame
    check("x == 5")
    run("assert not f(), 'oops'")
    run("a, b, c = 1, 2")              # too few values to unpack
    run("a, b, c = f()")               # an int cannot be unpacked
    check("max([f(),g()]) == 4")
    check("'hello'[g()] == 'h'")
    run("'guk%d' % h(f())")            # %d given a str raises