idtracking.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. import typing as t
  2. from . import nodes
  3. from .visitor import NodeVisitor
  4. VAR_LOAD_PARAMETER = "param"
  5. VAR_LOAD_RESOLVE = "resolve"
  6. VAR_LOAD_ALIAS = "alias"
  7. VAR_LOAD_UNDEFINED = "undefined"
  8. def find_symbols(
  9. nodes: t.Iterable[nodes.Node], parent_symbols: t.Optional["Symbols"] = None
  10. ) -> "Symbols":
  11. sym = Symbols(parent=parent_symbols)
  12. visitor = FrameSymbolVisitor(sym)
  13. for node in nodes:
  14. visitor.visit(node)
  15. return sym
  16. def symbols_for_node(
  17. node: nodes.Node, parent_symbols: t.Optional["Symbols"] = None
  18. ) -> "Symbols":
  19. sym = Symbols(parent=parent_symbols)
  20. sym.analyze_node(node)
  21. return sym
  22. class Symbols:
  23. def __init__(
  24. self, parent: t.Optional["Symbols"] = None, level: t.Optional[int] = None
  25. ) -> None:
  26. if level is None:
  27. if parent is None:
  28. level = 0
  29. else:
  30. level = parent.level + 1
  31. self.level: int = level
  32. self.parent = parent
  33. self.refs: t.Dict[str, str] = {}
  34. self.loads: t.Dict[str, t.Any] = {}
  35. self.stores: t.Set[str] = set()
  36. def analyze_node(self, node: nodes.Node, **kwargs: t.Any) -> None:
  37. visitor = RootVisitor(self)
  38. visitor.visit(node, **kwargs)
  39. def _define_ref(
  40. self, name: str, load: t.Optional[t.Tuple[str, t.Optional[str]]] = None
  41. ) -> str:
  42. ident = f"l_{self.level}_{name}"
  43. self.refs[name] = ident
  44. if load is not None:
  45. self.loads[ident] = load
  46. return ident
  47. def find_load(self, target: str) -> t.Optional[t.Any]:
  48. if target in self.loads:
  49. return self.loads[target]
  50. if self.parent is not None:
  51. return self.parent.find_load(target)
  52. return None
  53. def find_ref(self, name: str) -> t.Optional[str]:
  54. if name in self.refs:
  55. return self.refs[name]
  56. if self.parent is not None:
  57. return self.parent.find_ref(name)
  58. return None
  59. def ref(self, name: str) -> str:
  60. rv = self.find_ref(name)
  61. if rv is None:
  62. raise AssertionError(
  63. "Tried to resolve a name to a reference that was"
  64. f" unknown to the frame ({name!r})"
  65. )
  66. return rv
  67. def copy(self) -> "Symbols":
  68. rv = t.cast(Symbols, object.__new__(self.__class__))
  69. rv.__dict__.update(self.__dict__)
  70. rv.refs = self.refs.copy()
  71. rv.loads = self.loads.copy()
  72. rv.stores = self.stores.copy()
  73. return rv
  74. def store(self, name: str) -> None:
  75. self.stores.add(name)
  76. # If we have not see the name referenced yet, we need to figure
  77. # out what to set it to.
  78. if name not in self.refs:
  79. # If there is a parent scope we check if the name has a
  80. # reference there. If it does it means we might have to alias
  81. # to a variable there.
  82. if self.parent is not None:
  83. outer_ref = self.parent.find_ref(name)
  84. if outer_ref is not None:
  85. self._define_ref(name, load=(VAR_LOAD_ALIAS, outer_ref))
  86. return
  87. # Otherwise we can just set it to undefined.
  88. self._define_ref(name, load=(VAR_LOAD_UNDEFINED, None))
  89. def declare_parameter(self, name: str) -> str:
  90. self.stores.add(name)
  91. return self._define_ref(name, load=(VAR_LOAD_PARAMETER, None))
  92. def load(self, name: str) -> None:
  93. if self.find_ref(name) is None:
  94. self._define_ref(name, load=(VAR_LOAD_RESOLVE, name))
  95. def branch_update(self, branch_symbols: t.Sequence["Symbols"]) -> None:
  96. stores: t.Dict[str, int] = {}
  97. for branch in branch_symbols:
  98. for target in branch.stores:
  99. if target in self.stores:
  100. continue
  101. stores[target] = stores.get(target, 0) + 1
  102. for sym in branch_symbols:
  103. self.refs.update(sym.refs)
  104. self.loads.update(sym.loads)
  105. self.stores.update(sym.stores)
  106. for name, branch_count in stores.items():
  107. if branch_count == len(branch_symbols):
  108. continue
  109. target = self.find_ref(name) # type: ignore
  110. assert target is not None, "should not happen"
  111. if self.parent is not None:
  112. outer_target = self.parent.find_ref(name)
  113. if outer_target is not None:
  114. self.loads[target] = (VAR_LOAD_ALIAS, outer_target)
  115. continue
  116. self.loads[target] = (VAR_LOAD_RESOLVE, name)
  117. def dump_stores(self) -> t.Dict[str, str]:
  118. rv: t.Dict[str, str] = {}
  119. node: t.Optional["Symbols"] = self
  120. while node is not None:
  121. for name in sorted(node.stores):
  122. if name not in rv:
  123. rv[name] = self.find_ref(name) # type: ignore
  124. node = node.parent
  125. return rv
  126. def dump_param_targets(self) -> t.Set[str]:
  127. rv = set()
  128. node: t.Optional["Symbols"] = self
  129. while node is not None:
  130. for target, (instr, _) in self.loads.items():
  131. if instr == VAR_LOAD_PARAMETER:
  132. rv.add(target)
  133. node = node.parent
  134. return rv
  135. class RootVisitor(NodeVisitor):
  136. def __init__(self, symbols: "Symbols") -> None:
  137. self.sym_visitor = FrameSymbolVisitor(symbols)
  138. def _simple_visit(self, node: nodes.Node, **kwargs: t.Any) -> None:
  139. for child in node.iter_child_nodes():
  140. self.sym_visitor.visit(child)
  141. visit_Template = _simple_visit
  142. visit_Block = _simple_visit
  143. visit_Macro = _simple_visit
  144. visit_FilterBlock = _simple_visit
  145. visit_Scope = _simple_visit
  146. visit_If = _simple_visit
  147. visit_ScopedEvalContextModifier = _simple_visit
  148. def visit_AssignBlock(self, node: nodes.AssignBlock, **kwargs: t.Any) -> None:
  149. for child in node.body:
  150. self.sym_visitor.visit(child)
  151. def visit_CallBlock(self, node: nodes.CallBlock, **kwargs: t.Any) -> None:
  152. for child in node.iter_child_nodes(exclude=("call",)):
  153. self.sym_visitor.visit(child)
  154. def visit_OverlayScope(self, node: nodes.OverlayScope, **kwargs: t.Any) -> None:
  155. for child in node.body:
  156. self.sym_visitor.visit(child)
  157. def visit_For(
  158. self, node: nodes.For, for_branch: str = "body", **kwargs: t.Any
  159. ) -> None:
  160. if for_branch == "body":
  161. self.sym_visitor.visit(node.target, store_as_param=True)
  162. branch = node.body
  163. elif for_branch == "else":
  164. branch = node.else_
  165. elif for_branch == "test":
  166. self.sym_visitor.visit(node.target, store_as_param=True)
  167. if node.test is not None:
  168. self.sym_visitor.visit(node.test)
  169. return
  170. else:
  171. raise RuntimeError("Unknown for branch")
  172. if branch:
  173. for item in branch:
  174. self.sym_visitor.visit(item)
  175. def visit_With(self, node: nodes.With, **kwargs: t.Any) -> None:
  176. for target in node.targets:
  177. self.sym_visitor.visit(target)
  178. for child in node.body:
  179. self.sym_visitor.visit(child)
  180. def generic_visit(self, node: nodes.Node, *args: t.Any, **kwargs: t.Any) -> None:
  181. raise NotImplementedError(f"Cannot find symbols for {type(node).__name__!r}")
  182. class FrameSymbolVisitor(NodeVisitor):
  183. """A visitor for `Frame.inspect`."""
  184. def __init__(self, symbols: "Symbols") -> None:
  185. self.symbols = symbols
  186. def visit_Name(
  187. self, node: nodes.Name, store_as_param: bool = False, **kwargs: t.Any
  188. ) -> None:
  189. """All assignments to names go through this function."""
  190. if store_as_param or node.ctx == "param":
  191. self.symbols.declare_parameter(node.name)
  192. elif node.ctx == "store":
  193. self.symbols.store(node.name)
  194. elif node.ctx == "load":
  195. self.symbols.load(node.name)
  196. def visit_NSRef(self, node: nodes.NSRef, **kwargs: t.Any) -> None:
  197. self.symbols.load(node.name)
  198. def visit_If(self, node: nodes.If, **kwargs: t.Any) -> None:
  199. self.visit(node.test, **kwargs)
  200. original_symbols = self.symbols
  201. def inner_visit(nodes: t.Iterable[nodes.Node]) -> "Symbols":
  202. self.symbols = rv = original_symbols.copy()
  203. for subnode in nodes:
  204. self.visit(subnode, **kwargs)
  205. self.symbols = original_symbols
  206. return rv
  207. body_symbols = inner_visit(node.body)
  208. elif_symbols = inner_visit(node.elif_)
  209. else_symbols = inner_visit(node.else_ or ())
  210. self.symbols.branch_update([body_symbols, elif_symbols, else_symbols])
  211. def visit_Macro(self, node: nodes.Macro, **kwargs: t.Any) -> None:
  212. self.symbols.store(node.name)
  213. def visit_Import(self, node: nodes.Import, **kwargs: t.Any) -> None:
  214. self.generic_visit(node, **kwargs)
  215. self.symbols.store(node.target)
  216. def visit_FromImport(self, node: nodes.FromImport, **kwargs: t.Any) -> None:
  217. self.generic_visit(node, **kwargs)
  218. for name in node.names:
  219. if isinstance(name, tuple):
  220. self.symbols.store(name[1])
  221. else:
  222. self.symbols.store(name)
  223. def visit_Assign(self, node: nodes.Assign, **kwargs: t.Any) -> None:
  224. """Visit assignments in the correct order."""
  225. self.visit(node.node, **kwargs)
  226. self.visit(node.target, **kwargs)
  227. def visit_For(self, node: nodes.For, **kwargs: t.Any) -> None:
  228. """Visiting stops at for blocks. However the block sequence
  229. is visited as part of the outer scope.
  230. """
  231. self.visit(node.iter, **kwargs)
  232. def visit_CallBlock(self, node: nodes.CallBlock, **kwargs: t.Any) -> None:
  233. self.visit(node.call, **kwargs)
  234. def visit_FilterBlock(self, node: nodes.FilterBlock, **kwargs: t.Any) -> None:
  235. self.visit(node.filter, **kwargs)
  236. def visit_With(self, node: nodes.With, **kwargs: t.Any) -> None:
  237. for target in node.values:
  238. self.visit(target)
  239. def visit_AssignBlock(self, node: nodes.AssignBlock, **kwargs: t.Any) -> None:
  240. """Stop visiting at block assigns."""
  241. self.visit(node.target, **kwargs)
  242. def visit_Scope(self, node: nodes.Scope, **kwargs: t.Any) -> None:
  243. """Stop visiting at scopes."""
  244. def visit_Block(self, node: nodes.Block, **kwargs: t.Any) -> None:
  245. """Stop visiting at blocks."""
  246. def visit_OverlayScope(self, node: nodes.OverlayScope, **kwargs: t.Any) -> None:
  247. """Do not visit into overlay scopes."""