guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,3 @@
1
# Version string of the `guppylang_internals` package.
# This is updated by our release-please workflow, triggered by this
# annotation: x-release-please-version
__version__ = "0.21.0"
@@ -0,0 +1,350 @@
1
+ import ast
2
+ import textwrap
3
+ from collections.abc import Callable, Mapping, Sequence
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast
6
+
7
+ if TYPE_CHECKING:
8
+ from guppylang_internals.tys.ty import Type
9
+
10
# Union of the AST node classes accepted by the annotation getters in this module
# (`get_file`, `get_source`, `get_line_offset`, `get_type_opt`, `get_type`).
# NOTE(review): every member is a subclass of `ast.AST`, so for type checking the union
# is equivalent to `ast.AST`; the explicit members presumably serve as documentation.
AstNode = (
    ast.AST
    | ast.operator
    | ast.expr
    | ast.arg
    | ast.stmt
    | ast.Name
    | ast.keyword
    | ast.FunctionDef
)

# Result type produced by `AstVisitor.visit`. Declared covariant, so a visitor
# parameterised with a subtype may be used where one with the supertype is expected.
T = TypeVar("T", covariant=True)
22
+
23
+
24
class AstVisitor(Generic[T]):
    """Dispatching AST visitor whose handlers may take extra arguments.

    Modelled on `ast.NodeVisitor`, with one difference: any additional positional
    and keyword arguments passed to `visit` are forwarded to the selected handler.

    Subclasses define handlers named ``visit_<ClassName>``; for example, an
    `ast.Name` node is dispatched to ``visit_Name``. When no such handler exists,
    `generic_visit` is called instead, and by default it raises
    `NotImplementedError`.

    Unlike `ast.NodeVisitor`, nothing recurses into child nodes automatically. Each
    handler decides which children (if any) to visit and what value to return.
    """

    def visit(self, node: Any, *args: Any, **kwargs: Any) -> T:
        """Dispatch `node` (plus any extra arguments) to its handler."""
        handler_name = f"visit_{node.__class__.__name__}"
        handler = getattr(self, handler_name, self.generic_visit)
        return handler(node, *args, **kwargs)

    def generic_visit(self, node: Any, *args: Any, **kwargs: Any) -> T:
        """Fallback used when no specific handler exists; always raises."""
        missing = node.__class__.__name__
        raise NotImplementedError(f"visit_{missing} is not implemented")
58
+
59
+
60
+ class AstSearcher(ast.NodeVisitor):
61
+ """Visitor that searches for occurrences of specific nodes in an AST."""
62
+
63
+ matcher: Callable[[ast.AST], bool]
64
+ dont_recurse_into: set[type[ast.AST]]
65
+ found: list[ast.AST]
66
+ is_first_node: bool
67
+
68
+ def __init__(
69
+ self,
70
+ matcher: Callable[[ast.AST], bool],
71
+ dont_recurse_into: set[type[ast.AST]] | None = None,
72
+ ) -> None:
73
+ self.matcher = matcher
74
+ self.dont_recurse_into = dont_recurse_into or set()
75
+ self.found = []
76
+ self.is_first_node = True
77
+
78
+ def generic_visit(self, node: ast.AST) -> None:
79
+ if self.matcher(node):
80
+ self.found.append(node)
81
+ if self.is_first_node or type(node) not in self.dont_recurse_into:
82
+ self.is_first_node = False
83
+ super().generic_visit(node)
84
+
85
+
86
def find_nodes(
    matcher: Callable[[ast.AST], bool],
    node: ast.AST,
    dont_recurse_into: set[type[ast.AST]] | None = None,
) -> list[ast.AST]:
    """Collect the nodes at or below `node` that satisfy `matcher`.

    Children of nodes whose type is in `dont_recurse_into` are not searched. `node`
    itself is always searched, whatever its type.
    """
    searcher = AstSearcher(matcher, dont_recurse_into)
    searcher.visit(node)
    matches = searcher.found
    return matches
95
+
96
+
97
def name_nodes_in_ast(node: Any) -> list[ast.Name]:
    """Collect every `ast.Name` node at or below `node`."""

    def is_name(candidate: ast.AST) -> bool:
        return isinstance(candidate, ast.Name)

    return cast(list[ast.Name], find_nodes(is_name, node))
101
+
102
+
103
def return_nodes_in_ast(node: Any) -> list[ast.Return]:
    """Collect the `return` statements at or below `node`.

    Nested `def` statements are not searched, so returns belonging to inner functions
    are excluded. If `node` is itself a function definition, it is still searched.
    """
    nested_scopes: set[type[ast.AST]] = {ast.FunctionDef}
    returns = find_nodes(lambda n: isinstance(n, ast.Return), node, nested_scopes)
    return cast(list[ast.Return], returns)
107
+
108
+
109
def breaks_in_loop(node: Any) -> list[ast.Break]:
    """Collect the `break` statements that belong to the loop `node`.

    Breaks inside nested `for`/`while` loops or nested function definitions are
    excluded.
    """
    barriers: set[type[ast.AST]] = {ast.For, ast.While, ast.FunctionDef}
    breaks = find_nodes(lambda n: isinstance(n, ast.Break), node, barriers)
    return cast(list[ast.Break], breaks)
118
+
119
+
120
class ContextAdjuster(ast.NodeTransformer):
    """Rebuilds expressions so that they carry the given `ast.expr_context`.

    `Name`, `Starred`, `Tuple`, `List`, `Subscript` and `Attribute` nodes are
    recreated with `ctx` set to this adjuster's context. The new node keeps the source
    location of the node it replaces.

    Their sub-expressions are adjusted recursively, except subscript slices, which are
    reused unchanged. Any other node is handled by `ast.NodeTransformer.generic_visit`.
    """

    ctx: ast.expr_context

    def __init__(self, ctx: ast.expr_context) -> None:
        self.ctx = ctx

    def visit(self, node: ast.AST) -> ast.AST:
        adjusted = super().visit(node)
        return cast(ast.AST, adjusted)

    def visit_Name(self, node: ast.Name) -> ast.Name:
        renamed = ast.Name(id=node.id, ctx=self.ctx)
        return with_loc(node, renamed)

    def visit_Starred(self, node: ast.Starred) -> ast.Starred:
        inner = self.visit(node.value)
        return with_loc(node, ast.Starred(value=inner, ctx=self.ctx))

    def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple:
        elements = [self.visit(elt) for elt in node.elts]
        return with_loc(node, ast.Tuple(elts=elements, ctx=self.ctx))

    def visit_List(self, node: ast.List) -> ast.List:
        elements = [self.visit(elt) for elt in node.elts]
        return with_loc(node, ast.List(elts=elements, ctx=self.ctx))

    def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript:
        # Only the subscripted container is adjusted; the slice is deliberately
        # reused as-is.
        container = self.visit(node.value)
        rebuilt = ast.Subscript(value=container, slice=node.slice, ctx=self.ctx)
        return with_loc(node, rebuilt)

    def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute:
        owner = self.visit(node.value)
        rebuilt = ast.Attribute(value=owner, attr=node.attr, ctx=self.ctx)
        return with_loc(node, rebuilt)
159
+
160
+
161
+ @dataclass(frozen=True, eq=False)
162
+ class TemplateReplacer(ast.NodeTransformer):
163
+ """Replaces nodes in a template."""
164
+
165
+ replacements: Mapping[str, ast.AST | Sequence[ast.AST]]
166
+ default_loc: ast.AST
167
+
168
+ def _get_replacement(self, x: str) -> ast.AST | Sequence[ast.AST]:
169
+ if x not in self.replacements:
170
+ msg = f"No replacement for `{x}` is given"
171
+ raise ValueError(msg)
172
+ return self.replacements[x]
173
+
174
+ def visit_Name(self, node: ast.Name) -> ast.AST:
175
+ repl = self._get_replacement(node.id)
176
+ if not isinstance(repl, ast.expr):
177
+ msg = f"Replacement for `{node.id}` must be an expression"
178
+ raise TypeError(msg)
179
+
180
+ # Update the context
181
+ adjuster = ContextAdjuster(node.ctx)
182
+ return with_loc(repl, adjuster.visit(repl))
183
+
184
+ def visit_Expr(self, node: ast.Expr) -> ast.AST | Sequence[ast.AST]:
185
+ if isinstance(node.value, ast.Name):
186
+ repl = self._get_replacement(node.value.id)
187
+ repls = [repl] if not isinstance(repl, Sequence) else repl
188
+ # Wrap expressions to turn them into statements
189
+ return [
190
+ with_loc(r, ast.Expr(value=r)) if isinstance(r, ast.expr) else r
191
+ for r in repls
192
+ ]
193
+ return self.generic_visit(node)
194
+
195
+ def generic_visit(self, node: ast.AST) -> ast.AST:
196
+ # Insert the default location
197
+ node = super().generic_visit(node)
198
+ return with_loc(self.default_loc, node)
199
+
200
+
201
def template_replace(
    template: str, default_loc: ast.AST, **kwargs: ast.AST | Sequence[ast.AST]
) -> list[ast.stmt]:
    """Parse `template` and substitute its placeholders with the given nodes.

    The template source is dedented and parsed as a module. Each keyword argument maps
    a placeholder name to its replacement node(s). `default_loc` supplies the source
    location given to the template's own (non-replaced) nodes.

    Returns the resulting top-level statements. A statement that expands into a list
    of statements is spliced in place.
    """
    statements = ast.parse(textwrap.dedent(template)).body
    replacer = TemplateReplacer(kwargs, default_loc)
    result: list[ast.stmt] = []
    for stmt in statements:
        visited = replacer.visit(stmt)
        if isinstance(visited, list):
            result += visited
        else:
            result.append(visited)
    return result
215
+
216
+
217
def line_col(node: ast.AST) -> tuple[int, int]:
    """Return the ``(lineno, col_offset)`` pair recorded on `node`."""
    line, column = node.lineno, node.col_offset
    return line, column
220
+
221
+
222
def set_location_from(node: ast.AST, loc: ast.AST) -> None:
    """Make `node` report the same source location as `loc`.

    First copies the start/end line and column of `loc` onto `node`. Then annotates
    `node` and all nodes below it with the source text, file name and line offset
    recorded on `loc`. All three annotations must be present (checked by assertions).
    """
    for attr in ("lineno", "col_offset", "end_lineno", "end_col_offset"):
        setattr(node, attr, getattr(loc, attr))

    source = get_source(loc)
    file = get_file(loc)
    line_offset = get_line_offset(loc)
    assert source is not None
    assert file is not None
    assert line_offset is not None
    annotate_location(node, source, file, line_offset)
234
+
235
+
236
def annotate_location(
    node: ast.AST, source: str, file: str, line_offset: int, recurse: bool = True
) -> None:
    """Attach source-file information to `node` and, by default, its descendants.

    Sets the custom attributes ``line_offset``, ``file`` and ``source`` on `node`.
    When `recurse` is true, the same annotation is applied to every AST node reachable
    through the node's fields, including nodes stored in list-valued fields.
    """
    node.line_offset = line_offset  # type: ignore[attr-defined]
    node.file = file  # type: ignore[attr-defined]
    node.source = source  # type: ignore[attr-defined]
    if not recurse:
        return
    for _name, value in ast.iter_fields(node):
        children = value if isinstance(value, list) else [value]
        for child in children:
            if isinstance(child, ast.AST):
                annotate_location(child, source, file, line_offset, recurse)
251
+
252
+
253
def shift_loc(node: ast.AST, delta_lineno: int, delta_col_offset: int) -> None:
    """Move `node` and every node below it by fixed line and column deltas.

    Start positions are shifted whenever the attribute exists. End positions are
    shifted only when they exist and are not ``None``.
    """
    adjustments = (
        ("lineno", delta_lineno, False),
        ("end_lineno", delta_lineno, True),
        ("col_offset", delta_col_offset, False),
        ("end_col_offset", delta_col_offset, True),
    )
    for attr, delta, may_be_none in adjustments:
        if not hasattr(node, attr):
            continue
        current = getattr(node, attr)
        if may_be_none and current is None:
            continue
        setattr(node, attr, current + delta)

    for _, value in ast.iter_fields(node):
        nested = value if isinstance(value, list) else [value]
        for child in nested:
            if isinstance(child, ast.AST):
                shift_loc(child, delta_lineno, delta_col_offset)
270
+
271
+
272
def get_file(node: AstNode) -> str | None:
    """Return the ``file`` annotation of `node`.

    Returns ``None`` if the annotation is absent or is not a string.
    """
    file = getattr(node, "file", None)
    if isinstance(file, str):
        return file
    return None
279
+
280
+
281
def get_source(node: AstNode) -> str | None:
    """Return the ``source`` annotation of `node`.

    Returns ``None`` if the annotation is absent or is not a string.
    """
    source = getattr(node, "source", None)
    if isinstance(source, str):
        return source
    return None
288
+
289
+
290
def get_line_offset(node: AstNode) -> int | None:
    """Return the ``line_offset`` annotation of `node`.

    Returns ``None`` if the annotation is absent or is not an `int`.
    """
    line_offset = getattr(node, "line_offset", None)
    if isinstance(line_offset, int):
        return line_offset
    return None
297
+
298
+
299
# Type variable for helpers that return the same AST node they were given
# (`with_loc`, `with_type`), so callers keep the precise node type.
A = TypeVar("A", bound=ast.AST)
300
+
301
+
302
def with_loc(loc: ast.AST, node: A) -> A:
    """Give `node` the source location of `loc`, then return `node`.

    Expression-friendly wrapper around `set_location_from`. Note that the argument
    order is reversed compared with that function.
    """
    set_location_from(node, loc)
    return node
306
+
307
+
308
def with_type(ty: "Type", node: A) -> A:
    """Record `ty` as the ``type`` attribute of `node` and hand `node` back."""
    annotated = node
    annotated.type = ty  # type: ignore[attr-defined]
    return annotated
312
+
313
+
314
def get_type_opt(node: AstNode) -> Optional["Type"]:
    """Return the type annotation stored on `node`, if any.

    Returns ``None`` when `node` has no ``type`` attribute, or when that attribute is
    not a `TypeBase` instance.
    """
    # Imported lazily; presumably to avoid an import cycle with `tys.ty`, which is
    # only imported under TYPE_CHECKING at module level.
    from guppylang_internals.tys.ty import Type, TypeBase

    ty = getattr(node, "type", None)
    if not isinstance(ty, TypeBase):
        return None
    return cast(Type, ty)
323
+
324
+
325
def get_type(node: AstNode) -> "Type":
    """Return the type annotation of a node that is known to be annotated.

    A missing annotation is treated as an internal error and trips an ``assert``.
    """
    annotated = get_type_opt(node)
    assert annotated is not None
    return annotated
333
+
334
+
335
def has_empty_body(func_ast: ast.FunctionDef) -> bool:
    """Returns `True` if the body of a function definition is empty.

    A body counts as empty if it has no statements at all, or if it consists of
    exactly one of the following:
      * a `pass` statement, or
      * an ellipsis `...` expression statement.
    """
    if len(func_ast.body) == 0:
        return True
    if len(func_ast.body) > 1:
        return False
    [n] = func_ast.body
    # Both `pass` and `...` are the conventional placeholders for a body that is
    # intentionally left empty.
    if isinstance(n, ast.Pass):
        return True
    return (
        isinstance(n, ast.Expr)
        and isinstance(n.value, ast.Constant)
        and n.value.value is Ellipsis
    )
File without changes
@@ -0,0 +1,230 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections.abc import Iterable
3
+ from typing import Generic, TypeVar
4
+
5
+ from guppylang_internals.cfg.bb import BB, VariableStats, VId
6
+
7
# Type variable for the lattice domain
T = TypeVar("T")

# Analysis result is a mapping from basic blocks to lattice values.
# Generic alias: `Result[T]` is `dict[BB, T]`. The `run` methods below map each BB to
# the lattice value at the start of that BB.
Result = dict[BB, T]
12
+
13
+
14
class Analysis(Generic[T], ABC):
    """Common interface of dataflow analyses over a lattice of values of type `T`.

    A concrete analysis describes its lattice (`initial`, `join` and optionally
    `eq`). It decides whether unreachable code participates (`include_unreachable`).
    It implements `run`, which computes the lattice value at the start of every BB.
    """

    def eq(self, t1: T, t2: T, /) -> bool:
        """Decide whether two lattice values are equal; defaults to ``==``."""
        return t1 == t2

    @abstractmethod
    def include_unreachable(self) -> bool:
        """Report whether unreachable BBs and jumps are considered by the
        analysis."""

    @abstractmethod
    def initial(self) -> T:
        """Lattice value every BB starts out with."""

    @abstractmethod
    def join(self, *ts: T) -> T:
        """Combine any number of lattice values into one (lattice join)."""

    @abstractmethod
    def run(self, bbs: Iterable[BB]) -> Result[T]:
        """Execute the analysis over `bbs`.

        The result maps every analysed BB to its lattice value at the start of that
        BB.
        """
40
+
41
+
42
class ForwardAnalysis(Generic[T], Analysis[T], ABC):
    """Abstract base class for a program analysis pass running in forward direction.

    Subclasses provide `apply_bb`, the transfer function of a single basic block.
    `run` then iterates to a fixpoint using a worklist.
    """

    @abstractmethod
    def apply_bb(self, val_before: T, bb: BB, /) -> T:
        """Transformation a basic block applies to a lattice value.

        Receives the value at the start of `bb` and returns the value at its end.
        """

    def run(self, bbs: Iterable[BB]) -> Result[T]:
        """Runs the analysis pass.

        Returns a mapping from basic blocks to lattice values at the start of each BB.
        Unless `include_unreachable()` is true, unreachable BBs are left out of the
        analysis and dummy edges are ignored.
        """
        include_unreachable = self.include_unreachable()
        # Materialise `bbs` exactly once: it is traversed several times below, so a
        # one-shot iterator (e.g. a generator) would otherwise be exhausted after the
        # first traversal.
        bbs = [bb for bb in bbs if include_unreachable or bb.reachable]
        vals_before = {bb: self.initial() for bb in bbs}  # return value
        vals_after = {bb: self.apply_bb(vals_before[bb], bb) for bb in bbs}  # cache
        queue = set(bbs)
        while queue:
            bb = queue.pop()
            if include_unreachable:
                preds = bb.predecessors + bb.dummy_predecessors
                succs = bb.successors + bb.dummy_successors
            else:
                preds = bb.predecessors
                succs = bb.successors
            vals_before[bb] = self.join(*(vals_after[pred] for pred in preds))
            val_after = self.apply_bb(vals_before[bb], bb)
            if not self.eq(val_after, vals_after[bb]):
                vals_after[bb] = val_after
                # Re-examine every BB whose join reads this value. This must use the
                # same edge set as the join above; otherwise BBs connected only via
                # dummy edges would keep a stale input.
                queue.update(succs)
        return vals_before
72
+
73
+
74
class BackwardAnalysis(Generic[T], Analysis[T], ABC):
    """Abstract base class for a program analysis pass running in backward direction.

    Subclasses provide `apply_bb`, the transfer function of a single basic block.
    `run` then iterates to a fixpoint using a worklist.
    """

    @abstractmethod
    def apply_bb(self, val_after: T, bb: BB, /) -> T:
        """Transformation a basic block applies to a lattice value.

        Receives the value at the end of `bb` and returns the value at its start.
        """

    def run(self, bbs: Iterable[BB]) -> Result[T]:
        """Runs the analysis pass.

        Returns a mapping from basic blocks to lattice values at the start of each BB.
        Dummy edges are only followed when `include_unreachable()` is true.
        """
        include_unreachable = self.include_unreachable()
        # Materialise `bbs` once: it is traversed twice below, which would silently
        # yield nothing the second time for a one-shot iterator.
        bbs = list(bbs)
        vals_before = {bb: self.initial() for bb in bbs}
        queue = set(bbs)
        while queue:
            bb = queue.pop()
            if include_unreachable:
                succs = bb.successors + bb.dummy_successors
                preds = bb.predecessors + bb.dummy_predecessors
            else:
                succs = bb.successors
                preds = bb.predecessors
            val_after = self.join(*(vals_before[succ] for succ in succs))
            val_before = self.apply_bb(val_after, bb)
            if not self.eq(vals_before[bb], val_before):
                vals_before[bb] = val_before
                # Re-examine every BB whose join reads this value. This uses the same
                # edge set as the join above, so dummy predecessors are included too.
                queue.update(preds)
        return vals_before
101
+
102
+
103
# For live variable analysis, we also store a BB in which a use occurs as evidence of
# liveness.
# Keys are the live variables; each value is one BB containing a use of that
# variable. `LivenessAnalysis.eq` compares only the keys, so which evidence BB is
# stored does not affect convergence.
LivenessDomain = dict[VId, BB]
106
+
107
+
108
class LivenessAnalysis(Generic[VId], BackwardAnalysis[LivenessDomain[VId]]):
    """Backward live-variable analysis.

    For each BB, determines which variables are live before the BB executes. Every
    live variable is paired with some BB that uses it, kept as evidence for reporting
    to the user.
    """

    stats: dict[BB, VariableStats[VId]]

    def __init__(
        self,
        stats: dict[BB, VariableStats[VId]],
        initial: LivenessDomain[VId] | None = None,
        include_unreachable: bool = False,
    ) -> None:
        """Set up the analysis.

        Args:
            stats: Per-BB variable statistics, consulted for the `used` and
                `assigned` variables of each BB.
            initial: Starting lattice value for every BB. An empty mapping is used
                when this is omitted or empty.
            include_unreachable: Whether dummy successor edges are followed.
        """
        self.stats = stats
        self._initial = initial if initial else {}
        self._include_unreachable = include_unreachable

    def eq(self, live1: LivenessDomain[VId], live2: LivenessDomain[VId]) -> bool:
        # Convergence depends only on *which* variables are live. The evidence BB
        # recorded for a variable may differ without affecting the result.
        return set(live1) == set(live2)

    def initial(self) -> LivenessDomain[VId]:
        return self._initial

    def include_unreachable(self) -> bool:
        return self._include_unreachable

    def join(self, *ts: LivenessDomain[VId]) -> LivenessDomain[VId]:
        # Union of the live variables. When a variable appears in several inputs,
        # the evidence from the last input wins.
        return {x: evidence for t in ts for x, evidence in t.items()}

    def apply_bb(self, live_after: LivenessDomain[VId], bb: BB) -> LivenessDomain[VId]:
        stats = self.stats[bb]
        # Variables used in `bb` are live, with `bb` as evidence.
        live_before: LivenessDomain[VId] = dict.fromkeys(stats.used, bb)
        # Variables live after `bb` stay live unless `bb` assigns them; their existing
        # evidence takes precedence.
        live_before.update(
            (x, evidence)
            for x, evidence in live_after.items()
            if x not in stats.assigned
        )
        return live_before
149
+
150
+
151
# The three aliases below are generic over the variable identifier type `VId`.

# Set of variables that are definitely assigned at the start of a BB
DefAssignmentDomain = set[VId]

# Set of variables that are assigned on (at least) some paths to a BB. Definitely
# assigned variables are a subset of this
MaybeAssignmentDomain = set[VId]

# For assignment analysis, we do definite- and maybe-assignment in one pass.
# Lattice values are `(definitely_assigned, maybe_assigned)` pairs.
AssignmentDomain = tuple[DefAssignmentDomain[VId], MaybeAssignmentDomain[VId]]
160
+
161
+
162
class AssignmentAnalysis(Generic[VId], ForwardAnalysis[AssignmentDomain[VId]]):
    """Assigned variable analysis pass.

    Computes the set of variables (i.e. `V`s) that are definitely assigned at the start
    of a BB. Additionally, we compute the set of variables that are assigned on (at
    least) some paths to a BB (the definitely assigned variables are a subset of this).
    """

    stats: dict[BB, VariableStats[VId]]
    all_vars: set[VId]
    ass_before_entry: set[VId]
    maybe_ass_before_entry: set[VId]

    def __init__(
        self,
        stats: dict[BB, VariableStats[VId]],
        ass_before_entry: set[VId],
        maybe_ass_before_entry: set[VId],
        include_unreachable: bool = False,
    ) -> None:
        """Constructs an `AssignmentAnalysis` pass for a CFG.

        Also takes a set variables that are definitely assigned before the entry of the
        CFG (for example function arguments), and a superset of it containing the
        variables that are assigned on at least some paths before the entry.
        """
        assert ass_before_entry.issubset(maybe_ass_before_entry)
        self.stats = stats
        self.ass_before_entry = ass_before_entry
        self.maybe_ass_before_entry = maybe_ass_before_entry
        # Collect into an explicit accumulator: `set.union(*...)` with an empty
        # `stats` would raise a `TypeError` because it gets no arguments at all.
        assigned_anywhere: set[VId] = set()
        for stat in stats.values():
            assigned_anywhere.update(stat.assigned.keys())
        self.all_vars = assigned_anywhere | ass_before_entry
        self._include_unreachable = include_unreachable

    def initial(self) -> AssignmentDomain[VId]:
        # Note that definite assignment must start with `all_vars` instead of only
        # `ass_before_entry` since we want to compute the *greatest* fixpoint.
        return self.all_vars, self.maybe_ass_before_entry

    def include_unreachable(self) -> bool:
        return self._include_unreachable

    def join(self, *ts: AssignmentDomain[VId]) -> AssignmentDomain[VId]:
        # A BB without predecessors (such as the entry) starts from what is known
        # before the entry of the CFG: the definitely-assigned *and* the
        # maybe-assigned variables supplied to the constructor.
        if len(ts) == 0:
            return self.ass_before_entry, self.maybe_ass_before_entry

        def_ass = set.intersection(*(def_ass for def_ass, _ in ts))
        maybe_ass = set.union(*(maybe_ass for _, maybe_ass in ts))
        return def_ass, maybe_ass

    def apply_bb(
        self, val_before: AssignmentDomain[VId], bb: BB
    ) -> AssignmentDomain[VId]:
        # Everything assigned in `bb` becomes both definitely and maybe assigned.
        stats = self.stats[bb]
        def_ass_before, maybe_ass_before = val_before
        return (
            def_ass_before | stats.assigned.keys(),
            maybe_ass_before | stats.assigned.keys(),
        )

    def run_unpacked(
        self, bbs: Iterable[BB]
    ) -> tuple[Result[DefAssignmentDomain[VId]], Result[MaybeAssignmentDomain[VId]]]:
        """Runs the analysis and unpacks the definite- and maybe-assignment results."""
        res = self.run(bbs)
        return {bb: res[bb][0] for bb in res}, {bb: res[bb][1] for bb in res}