guppylang-internals 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +3 -0
- guppylang_internals/ast_util.py +350 -0
- guppylang_internals/cfg/__init__.py +0 -0
- guppylang_internals/cfg/analysis.py +230 -0
- guppylang_internals/cfg/bb.py +221 -0
- guppylang_internals/cfg/builder.py +606 -0
- guppylang_internals/cfg/cfg.py +117 -0
- guppylang_internals/checker/__init__.py +0 -0
- guppylang_internals/checker/cfg_checker.py +388 -0
- guppylang_internals/checker/core.py +550 -0
- guppylang_internals/checker/errors/__init__.py +0 -0
- guppylang_internals/checker/errors/comptime_errors.py +106 -0
- guppylang_internals/checker/errors/generic.py +45 -0
- guppylang_internals/checker/errors/linearity.py +300 -0
- guppylang_internals/checker/errors/type_errors.py +344 -0
- guppylang_internals/checker/errors/wasm.py +34 -0
- guppylang_internals/checker/expr_checker.py +1413 -0
- guppylang_internals/checker/func_checker.py +269 -0
- guppylang_internals/checker/linearity_checker.py +821 -0
- guppylang_internals/checker/stmt_checker.py +447 -0
- guppylang_internals/compiler/__init__.py +0 -0
- guppylang_internals/compiler/cfg_compiler.py +233 -0
- guppylang_internals/compiler/core.py +613 -0
- guppylang_internals/compiler/expr_compiler.py +989 -0
- guppylang_internals/compiler/func_compiler.py +97 -0
- guppylang_internals/compiler/hugr_extension.py +224 -0
- guppylang_internals/compiler/qtm_platform_extension.py +0 -0
- guppylang_internals/compiler/stmt_compiler.py +212 -0
- guppylang_internals/decorator.py +246 -0
- guppylang_internals/definition/__init__.py +0 -0
- guppylang_internals/definition/common.py +214 -0
- guppylang_internals/definition/const.py +74 -0
- guppylang_internals/definition/custom.py +492 -0
- guppylang_internals/definition/declaration.py +171 -0
- guppylang_internals/definition/extern.py +89 -0
- guppylang_internals/definition/function.py +302 -0
- guppylang_internals/definition/overloaded.py +150 -0
- guppylang_internals/definition/parameter.py +82 -0
- guppylang_internals/definition/pytket_circuits.py +405 -0
- guppylang_internals/definition/struct.py +392 -0
- guppylang_internals/definition/traced.py +151 -0
- guppylang_internals/definition/ty.py +51 -0
- guppylang_internals/definition/value.py +115 -0
- guppylang_internals/definition/wasm.py +61 -0
- guppylang_internals/diagnostic.py +523 -0
- guppylang_internals/dummy_decorator.py +76 -0
- guppylang_internals/engine.py +295 -0
- guppylang_internals/error.py +107 -0
- guppylang_internals/experimental.py +92 -0
- guppylang_internals/ipython_inspect.py +28 -0
- guppylang_internals/nodes.py +427 -0
- guppylang_internals/py.typed +0 -0
- guppylang_internals/span.py +150 -0
- guppylang_internals/std/__init__.py +0 -0
- guppylang_internals/std/_internal/__init__.py +0 -0
- guppylang_internals/std/_internal/checker.py +573 -0
- guppylang_internals/std/_internal/compiler/__init__.py +0 -0
- guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
- guppylang_internals/std/_internal/compiler/array.py +569 -0
- guppylang_internals/std/_internal/compiler/either.py +131 -0
- guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
- guppylang_internals/std/_internal/compiler/futures.py +30 -0
- guppylang_internals/std/_internal/compiler/list.py +348 -0
- guppylang_internals/std/_internal/compiler/mem.py +13 -0
- guppylang_internals/std/_internal/compiler/option.py +78 -0
- guppylang_internals/std/_internal/compiler/prelude.py +271 -0
- guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
- guppylang_internals/std/_internal/compiler/quantum.py +118 -0
- guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
- guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
- guppylang_internals/std/_internal/compiler/wasm.py +135 -0
- guppylang_internals/std/_internal/compiler.py +0 -0
- guppylang_internals/std/_internal/debug.py +95 -0
- guppylang_internals/std/_internal/util.py +271 -0
- guppylang_internals/tracing/__init__.py +0 -0
- guppylang_internals/tracing/builtins_mock.py +62 -0
- guppylang_internals/tracing/frozenlist.py +57 -0
- guppylang_internals/tracing/function.py +186 -0
- guppylang_internals/tracing/object.py +551 -0
- guppylang_internals/tracing/state.py +69 -0
- guppylang_internals/tracing/unpacking.py +194 -0
- guppylang_internals/tracing/util.py +86 -0
- guppylang_internals/tys/__init__.py +0 -0
- guppylang_internals/tys/arg.py +115 -0
- guppylang_internals/tys/builtin.py +382 -0
- guppylang_internals/tys/common.py +110 -0
- guppylang_internals/tys/const.py +114 -0
- guppylang_internals/tys/errors.py +178 -0
- guppylang_internals/tys/param.py +251 -0
- guppylang_internals/tys/parsing.py +425 -0
- guppylang_internals/tys/printing.py +174 -0
- guppylang_internals/tys/subst.py +112 -0
- guppylang_internals/tys/ty.py +876 -0
- guppylang_internals/tys/var.py +49 -0
- guppylang_internals-0.21.0.dist-info/METADATA +253 -0
- guppylang_internals-0.21.0.dist-info/RECORD +98 -0
- guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
- guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
|
@@ -0,0 +1,350 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import textwrap
|
|
3
|
+
from collections.abc import Callable, Mapping, Sequence
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from guppylang_internals.tys.ty import Type
|
|
9
|
+
|
|
10
|
+
# Union of the AST node classes accepted by the annotation helpers in this
# module (e.g. `get_file`, `get_source`, `get_type`).
AstNode = (
    ast.AST
    | ast.operator
    | ast.expr
    | ast.arg
    | ast.stmt
    | ast.Name
    | ast.keyword
    | ast.FunctionDef
)
|
|
20
|
+
|
|
21
|
+
T = TypeVar("T", covariant=True)


class AstVisitor(Generic[T]):
    """Visitor base class whose `visit` methods accept extra arguments.

    Modelled on `ast.NodeVisitor`, with two differences: additional positional
    and keyword arguments given to `visit` are forwarded to the selected
    handler, and the handler's return value is passed back to the caller.

    Dispatch is by class name: a node of class `Foo` is handled by the method
    `visit_Foo` if the subclass defines one, otherwise by `generic_visit`.
    Unlike `ast.NodeVisitor`, the default `generic_visit` does not walk into
    child nodes; it raises `NotImplementedError`.
    """

    def visit(self, node: Any, *args: Any, **kwargs: Any) -> T:
        """Dispatch `node` to its `visit_<ClassName>` handler and return the result."""
        handler_name = f"visit_{node.__class__.__name__}"
        handler = getattr(self, handler_name, self.generic_visit)
        return handler(node, *args, **kwargs)

    def generic_visit(self, node: Any, *args: Any, **kwargs: Any) -> T:
        """Fallback for nodes without a dedicated handler; always raises."""
        node_kind = node.__class__.__name__
        raise NotImplementedError(f"visit_{node_kind} is not implemented")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class AstSearcher(ast.NodeVisitor):
|
|
61
|
+
"""Visitor that searches for occurrences of specific nodes in an AST."""
|
|
62
|
+
|
|
63
|
+
matcher: Callable[[ast.AST], bool]
|
|
64
|
+
dont_recurse_into: set[type[ast.AST]]
|
|
65
|
+
found: list[ast.AST]
|
|
66
|
+
is_first_node: bool
|
|
67
|
+
|
|
68
|
+
def __init__(
|
|
69
|
+
self,
|
|
70
|
+
matcher: Callable[[ast.AST], bool],
|
|
71
|
+
dont_recurse_into: set[type[ast.AST]] | None = None,
|
|
72
|
+
) -> None:
|
|
73
|
+
self.matcher = matcher
|
|
74
|
+
self.dont_recurse_into = dont_recurse_into or set()
|
|
75
|
+
self.found = []
|
|
76
|
+
self.is_first_node = True
|
|
77
|
+
|
|
78
|
+
def generic_visit(self, node: ast.AST) -> None:
|
|
79
|
+
if self.matcher(node):
|
|
80
|
+
self.found.append(node)
|
|
81
|
+
if self.is_first_node or type(node) not in self.dont_recurse_into:
|
|
82
|
+
self.is_first_node = False
|
|
83
|
+
super().generic_visit(node)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def find_nodes(
    matcher: Callable[[ast.AST], bool],
    node: ast.AST,
    dont_recurse_into: set[type[ast.AST]] | None = None,
) -> list[ast.AST]:
    """Collects the nodes below (and including) `node` accepted by `matcher`.

    Children of nodes whose class is in `dont_recurse_into` are not searched,
    except when such a node is the starting `node` itself.
    """
    searcher = AstSearcher(matcher, dont_recurse_into)
    searcher.visit(node)
    return searcher.found
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def name_nodes_in_ast(node: Any) -> list[ast.Name]:
    """Collects every `ast.Name` node found in the given AST."""

    def is_name(candidate: ast.AST) -> bool:
        return isinstance(candidate, ast.Name)

    return cast(list[ast.Name], find_nodes(is_name, node))
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def return_nodes_in_ast(node: Any) -> list[ast.Return]:
    """Collects every `ast.Return` node found in the given AST.

    Function definitions nested below the starting node are not searched.
    """

    def is_return(candidate: ast.AST) -> bool:
        return isinstance(candidate, ast.Return)

    skipped = {ast.FunctionDef}
    return cast(list[ast.Return], find_nodes(is_return, node, skipped))
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def breaks_in_loop(node: Any) -> list[ast.Break]:
    """Collects the `ast.Break` nodes that belong to the given loop.

    Nested `for`/`while` loops and nested function definitions are not
    searched, so breaks inside them are excluded.
    """

    def is_break(candidate: ast.AST) -> bool:
        return isinstance(candidate, ast.Break)

    skipped = {ast.For, ast.While, ast.FunctionDef}
    return cast(list[ast.Break], find_nodes(is_break, node, skipped))
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class ContextAdjuster(ast.NodeTransformer):
    """Rebuilds assignable expressions with a given `ast.expr_context`.

    Names, starred expressions, tuples, lists, subscripts and attributes are
    replaced by fresh nodes carrying `ctx`; each fresh node takes its location
    from the node it replaces.
    """

    # Context written into every rebuilt node
    ctx: ast.expr_context

    def __init__(self, ctx: ast.expr_context) -> None:
        self.ctx = ctx

    def visit(self, node: ast.AST) -> ast.AST:
        result = super().visit(node)
        return cast(ast.AST, result)

    def visit_Name(self, node: ast.Name) -> ast.Name:
        rebuilt = ast.Name(id=node.id, ctx=self.ctx)
        return with_loc(node, rebuilt)

    def visit_Starred(self, node: ast.Starred) -> ast.Starred:
        inner = self.visit(node.value)
        return with_loc(node, ast.Starred(value=inner, ctx=self.ctx))

    def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple:
        elements = [self.visit(elt) for elt in node.elts]
        return with_loc(node, ast.Tuple(elts=elements, ctx=self.ctx))

    def visit_List(self, node: ast.List) -> ast.List:
        elements = [self.visit(elt) for elt in node.elts]
        return with_loc(node, ast.List(elts=elements, ctx=self.ctx))

    def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript:
        # Only the subscripted value is adjusted; the slice is reused as is.
        target = self.visit(node.value)
        rebuilt = ast.Subscript(value=target, slice=node.slice, ctx=self.ctx)
        return with_loc(node, rebuilt)

    def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute:
        target = self.visit(node.value)
        rebuilt = ast.Attribute(value=target, attr=node.attr, ctx=self.ctx)
        return with_loc(node, rebuilt)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@dataclass(frozen=True, eq=False)
class TemplateReplacer(ast.NodeTransformer):
    """Substitutes the placeholder names of a parsed template."""

    # Maps each placeholder name to the node(s) that take its place
    replacements: Mapping[str, ast.AST | Sequence[ast.AST]]
    # Node whose location is given to template nodes handled by `generic_visit`
    default_loc: ast.AST

    def _get_replacement(self, x: str) -> ast.AST | Sequence[ast.AST]:
        """Looks up the replacement for placeholder `x`.

        Raises `ValueError` if no replacement was provided.
        """
        if x in self.replacements:
            return self.replacements[x]
        msg = f"No replacement for `{x}` is given"
        raise ValueError(msg)

    def visit_Name(self, node: ast.Name) -> ast.AST:
        """Replaces a placeholder in expression position.

        Raises `TypeError` if the replacement is not a single expression.
        """
        repl = self._get_replacement(node.id)
        if isinstance(repl, ast.expr):
            # The replacement takes over the context (load/store/...) of the
            # placeholder it stands in for.
            adjusted = ContextAdjuster(node.ctx).visit(repl)
            return with_loc(repl, adjusted)
        msg = f"Replacement for `{node.id}` must be an expression"
        raise TypeError(msg)

    def visit_Expr(self, node: ast.Expr) -> ast.AST | Sequence[ast.AST]:
        """Replaces a placeholder that forms a whole expression statement."""
        placeholder = node.value
        if not isinstance(placeholder, ast.Name):
            return self.generic_visit(node)

        repl = self._get_replacement(placeholder.id)
        items = repl if isinstance(repl, Sequence) else [repl]
        stmts = []
        for item in items:
            if isinstance(item, ast.expr):
                # Expressions are wrapped so that they become statements
                stmts.append(with_loc(item, ast.Expr(value=item)))
            else:
                stmts.append(item)
        return stmts

    def generic_visit(self, node: ast.AST) -> ast.AST:
        """Visits the children, then stamps the default location on the node."""
        visited = super().generic_visit(node)
        return with_loc(self.default_loc, visited)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def template_replace(
    template: str, default_loc: ast.AST, **kwargs: ast.AST | Sequence[ast.AST]
) -> list[ast.stmt]:
    """Parses a source template and substitutes its placeholders.

    The template is dedented and parsed; each top-level statement is run
    through a `TemplateReplacer` built from `kwargs` and `default_loc`. A
    statement that is replaced by a list contributes all of that list's
    elements to the result.
    """
    parsed = ast.parse(textwrap.dedent(template)).body
    replacer = TemplateReplacer(kwargs, default_loc)
    result = []
    for stmt in parsed:
        replaced = replacer.visit(stmt)
        if isinstance(replaced, list):
            result += replaced
        else:
            result.append(replaced)
    return result
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def line_col(node: ast.AST) -> tuple[int, int]:
    """Returns the `(lineno, col_offset)` pair stored on an AST node."""
    line, column = node.lineno, node.col_offset
    return line, column
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def set_location_from(node: ast.AST, loc: ast.AST) -> None:
    """Copies the source location of `loc` onto `node`.

    Besides the line/column span, the `source`, `file` and `line_offset`
    annotations of `loc` are applied to `node` via `annotate_location`.
    `loc` must carry all three annotations.
    """
    for attr in ("lineno", "col_offset", "end_lineno", "end_col_offset"):
        setattr(node, attr, getattr(loc, attr))

    source = get_source(loc)
    file = get_file(loc)
    line_offset = get_line_offset(loc)
    assert source is not None
    assert file is not None
    assert line_offset is not None
    annotate_location(node, source, file, line_offset)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def annotate_location(
    node: ast.AST, source: str, file: str, line_offset: int, recurse: bool = True
) -> None:
    """Attaches `line_offset`, `file` and `source` attributes to an AST node.

    With `recurse` set, every AST node reachable through the node's fields is
    annotated in the same way.
    """
    annotations = (("line_offset", line_offset), ("file", file), ("source", source))
    for attr, value in annotations:
        setattr(node, attr, value)

    if not recurse:
        return
    for child in ast.iter_child_nodes(node):
        annotate_location(child, source, file, line_offset, recurse)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def shift_loc(node: ast.AST, delta_lineno: int, delta_col_offset: int) -> None:
    """Moves the location of a node and of all nodes below it.

    `delta_lineno` is added to the start and end line, `delta_col_offset` to
    the start and end column. Attributes a node does not have are skipped, as
    are end positions that are `None`.
    """
    # (attribute, amount to add, whether a `None` value is tolerated)
    adjustments = (
        ("lineno", delta_lineno, False),
        ("end_lineno", delta_lineno, True),
        ("col_offset", delta_col_offset, False),
        ("end_col_offset", delta_col_offset, True),
    )
    for attr, delta, may_be_none in adjustments:
        if not hasattr(node, attr):
            continue
        current = getattr(node, attr)
        if may_be_none and current is None:
            continue
        setattr(node, attr, current + delta)

    for child in ast.iter_child_nodes(node):
        shift_loc(child, delta_lineno, delta_col_offset)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def get_file(node: AstNode) -> str | None:
    """Returns the `file` annotation of a node.

    Yields `None` if the attribute is missing or is not a string.
    """
    annotation = getattr(node, "file", None)
    if isinstance(annotation, str):
        return annotation
    return None
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def get_source(node: AstNode) -> str | None:
    """Returns the `source` annotation of a node.

    Yields `None` if the attribute is missing or is not a string.
    """
    annotation = getattr(node, "source", None)
    if isinstance(annotation, str):
        return annotation
    return None
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def get_line_offset(node: AstNode) -> int | None:
    """Returns the `line_offset` annotation of a node.

    Yields `None` if the attribute is missing or is not an integer.
    """
    annotation = getattr(node, "line_offset", None)
    if isinstance(annotation, int):
        return annotation
    return None
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
A = TypeVar("A", bound=ast.AST)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def with_loc(loc: ast.AST, node: A) -> A:
    """Gives `node` the source location of `loc` and returns `node`.

    Convenience wrapper around `set_location_from` that can be used inside
    an expression.
    """
    set_location_from(node, loc)
    return node
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def with_type(ty: "Type", node: A) -> A:
    """Stores `ty` in the `type` attribute of `node` and returns `node`."""
    setattr(node, "type", ty)  # noqa: B010
    return node
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def get_type_opt(node: AstNode) -> Optional["Type"]:
    """Returns the type annotation of a node, if it has one.

    Yields `None` if the `type` attribute is missing or is not a `TypeBase`
    instance.
    """
    # NOTE(review): imported locally, presumably to avoid a circular import
    # (the module-level import is guarded by `TYPE_CHECKING`) -- confirm.
    from guppylang_internals.tys.ty import Type, TypeBase

    annotation = getattr(node, "type", None)
    if isinstance(annotation, TypeBase):
        return cast(Type, annotation)
    return None
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def get_type(node: AstNode) -> "Type":
    """Returns the type annotation of a node.

    The node is expected to be annotated; an assertion fails otherwise.
    """
    annotation = get_type_opt(node)
    assert annotation is not None
    return annotation
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def has_empty_body(func_ast: ast.FunctionDef) -> bool:
    """Returns `True` if the body of a function definition is empty.

    This is the case if the body has no statements at all, or consists of a
    single `pass` statement or a single ellipsis `...` expression. Any other
    body, including one holding only a docstring, counts as non-empty.
    """
    body = func_ast.body
    if not body:
        return True
    if len(body) > 1:
        return False
    [stmt] = body
    # A lone `pass` is an empty body, as promised by the docstring
    if isinstance(stmt, ast.Pass):
        return True
    return (
        isinstance(stmt, ast.Expr)
        and isinstance(stmt.value, ast.Constant)
        and stmt.value.value is Ellipsis
    )
|
|
File without changes
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from collections.abc import Iterable
|
|
3
|
+
from typing import Generic, TypeVar
|
|
4
|
+
|
|
5
|
+
from guppylang_internals.cfg.bb import BB, VariableStats, VId
|
|
6
|
+
|
|
7
|
+
# Type variable for the lattice domain
|
|
8
|
+
T = TypeVar("T")
|
|
9
|
+
|
|
10
|
+
# Analysis result is a mapping from basic blocks to lattice values
|
|
11
|
+
Result = dict[BB, T]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Analysis(Generic[T], ABC):
    """Interface of a program analysis pass over basic blocks.

    The analysis operates on values of the lattice `T`; subclasses supply the
    lattice operations and the traversal strategy.
    """

    def eq(self, t1: T, t2: T, /) -> bool:
        """Decides whether two lattice values are equal (`==` by default)."""
        return t1 == t2

    @abstractmethod
    def include_unreachable(self) -> bool:
        """Tells whether unreachable BBs and jumps are part of the analysis."""

    @abstractmethod
    def initial(self) -> T:
        """Returns the lattice value every BB starts out with."""

    @abstractmethod
    def join(self, *ts: T) -> T:
        """Combines any number of lattice values using the lattice join."""

    @abstractmethod
    def run(self, bbs: Iterable[BB]) -> Result[T]:
        """Executes the analysis on the given basic blocks.

        The result maps each basic block to the lattice value that holds at
        the start of that block.
        """
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ForwardAnalysis(Generic[T], Analysis[T], ABC):
    """Abstract base class for a program analysis pass running in forward direction."""

    @abstractmethod
    def apply_bb(self, val_before: T, bb: BB, /) -> T:
        """Transformation a basic block applies to a lattice value"""

    def run(self, bbs: Iterable[BB]) -> Result[T]:
        """Runs the analysis pass.

        `bbs` may be any iterable, including a one-shot iterator; it is
        consumed exactly once. Unreachable blocks are dropped unless
        `include_unreachable()` is true.

        Returns a mapping from basic blocks to lattice values at the start of each BB.
        """
        include_unreachable = self.include_unreachable()
        # Materialise the blocks up front: they are needed three times below,
        # which would silently yield empty tables for a one-shot iterator.
        blocks = [bb for bb in bbs if include_unreachable or bb.reachable]

        vals_before = {bb: self.initial() for bb in blocks}  # return value
        vals_after = {bb: self.apply_bb(vals_before[bb], bb) for bb in blocks}  # cache

        # Worklist iteration: re-evaluate blocks until no output value changes
        queue = set(blocks)
        while queue:
            bb = queue.pop()
            preds = (
                bb.predecessors + bb.dummy_predecessors
                if include_unreachable
                else bb.predecessors
            )
            vals_before[bb] = self.join(*(vals_after[pred] for pred in preds))
            val_after = self.apply_bb(vals_before[bb], bb)
            if not self.eq(val_after, vals_after[bb]):
                # The output changed, so all successors must be revisited
                vals_after[bb] = val_after
                queue.update(bb.successors)
        return vals_before
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class BackwardAnalysis(Generic[T], Analysis[T], ABC):
    """Abstract base class for a program analysis pass running in backward direction."""

    @abstractmethod
    def apply_bb(self, val_after: T, bb: BB, /) -> T:
        """Transformation a basic block applies to a lattice value"""

    def run(self, bbs: Iterable[BB]) -> Result[T]:
        """Runs the analysis pass.

        `bbs` may be any iterable, including a one-shot iterator; it is
        consumed exactly once.

        Returns a mapping from basic blocks to lattice values at the start of each BB.
        """
        include_unreachable = self.include_unreachable()
        # Materialise the blocks up front: they are needed twice below, which
        # would leave the worklist empty for a one-shot iterator.
        blocks = list(bbs)

        vals_before = {bb: self.initial() for bb in blocks}

        # Worklist iteration: re-evaluate blocks until no input value changes
        queue = set(blocks)
        while queue:
            bb = queue.pop()
            succs = (
                bb.successors + bb.dummy_successors
                if include_unreachable
                else bb.successors
            )
            val_after = self.join(*(vals_before[succ] for succ in succs))
            val_before = self.apply_bb(val_after, bb)
            if not self.eq(vals_before[bb], val_before):
                # The input changed, so all predecessors must be revisited
                vals_before[bb] = val_before
                queue.update(bb.predecessors)
        return vals_before
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# For live variable analysis, we also store a BB in which a use occurs as evidence of
|
|
104
|
+
# liveness.
|
|
105
|
+
LivenessDomain = dict[VId, BB]
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class LivenessAnalysis(Generic[VId], BackwardAnalysis[LivenessDomain[VId]]):
    """Backward pass computing the variables live at the start of each BB.

    A lattice value maps each live variable to one BB that uses it; that BB
    only serves as evidence of the use.
    """

    # Per-BB variable usage and assignment statistics
    stats: dict[BB, VariableStats[VId]]

    def __init__(
        self,
        stats: dict[BB, VariableStats[VId]],
        initial: LivenessDomain[VId] | None = None,
        include_unreachable: bool = False,
    ) -> None:
        self.stats = stats
        self._initial = initial if initial else {}
        self._include_unreachable = include_unreachable

    def eq(self, live1: LivenessDomain[VId], live2: LivenessDomain[VId]) -> bool:
        # Two values count as equal when they hold the same variables. The
        # recorded BBs are ignored, since any BB with a use is good enough to
        # report to the user.
        vars1, vars2 = live1.keys(), live2.keys()
        return vars1 == vars2

    def initial(self) -> LivenessDomain[VId]:
        return self._initial

    def include_unreachable(self) -> bool:
        return self._include_unreachable

    def join(self, *ts: LivenessDomain[VId]) -> LivenessDomain[VId]:
        # Union of all mappings; for a shared variable the later entry wins
        merged: LivenessDomain[VId] = {}
        for live in ts:
            merged.update(live)
        return merged

    def apply_bb(self, live_after: LivenessDomain[VId], bb: BB) -> LivenessDomain[VId]:
        stats = self.stats[bb]
        # Variables used in this BB are live before it ...
        live_before = {x: bb for x in stats.used}
        # ... and so is everything live afterwards that the BB does not assign
        live_before.update(
            (x, b) for x, b in live_after.items() if x not in stats.assigned
        )
        return live_before
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# Set of variables that are definitely assigned at the start of a BB
|
|
152
|
+
DefAssignmentDomain = set[VId]
|
|
153
|
+
|
|
154
|
+
# Set of variables that are assigned on (at least) some paths to a BB. Definitely
|
|
155
|
+
# assigned variables are a subset of this
|
|
156
|
+
MaybeAssignmentDomain = set[VId]
|
|
157
|
+
|
|
158
|
+
# For assignment analysis, we do definite- and maybe-assignment in one pass
|
|
159
|
+
AssignmentDomain = tuple[DefAssignmentDomain[VId], MaybeAssignmentDomain[VId]]
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class AssignmentAnalysis(Generic[VId], ForwardAnalysis[AssignmentDomain[VId]]):
    """Assigned variable analysis pass.

    Computes the set of variables (i.e. `V`s) that are definitely assigned at the start
    of a BB. Additionally, we compute the set of variables that are assigned on (at
    least) some paths to a BB (the definitely assigned variables are a subset of this).
    """

    # Per-BB variable usage and assignment statistics
    stats: dict[BB, VariableStats[VId]]
    # Every variable assigned in some BB or definitely assigned before entry
    all_vars: set[VId]
    # Variables definitely assigned before the CFG is entered
    ass_before_entry: set[VId]
    # Variables possibly assigned before the CFG is entered
    maybe_ass_before_entry: set[VId]

    def __init__(
        self,
        stats: dict[BB, VariableStats[VId]],
        ass_before_entry: set[VId],
        maybe_ass_before_entry: set[VId],
        include_unreachable: bool = False,
    ) -> None:
        """Constructs an `AssignmentAnalysis` pass for a CFG.

        Also takes a set of variables that are definitely assigned before the entry of
        the CFG (for example function arguments) and a superset of it holding the
        variables that may be assigned before the entry.
        """
        assert ass_before_entry.issubset(maybe_ass_before_entry)
        self.stats = stats
        self.ass_before_entry = ass_before_entry
        self.maybe_ass_before_entry = maybe_ass_before_entry
        # Use the bound `set().union(...)`: the unbound `set.union(*...)` raises a
        # `TypeError` when `stats` is empty and there is nothing to unpack.
        self.all_vars = (
            set().union(*(stat.assigned.keys() for stat in stats.values()))
            | ass_before_entry
        )
        self._include_unreachable = include_unreachable

    def initial(self) -> AssignmentDomain[VId]:
        # Note that definite assignment must start with `all_vars` instead of only
        # `ass_before_entry` since we want to compute the *greatest* fixpoint.
        return self.all_vars, self.maybe_ass_before_entry

    def include_unreachable(self) -> bool:
        return self._include_unreachable

    def join(self, *ts: AssignmentDomain[VId]) -> AssignmentDomain[VId]:
        # We always include the variables that are definitely assigned before the entry,
        # even if the join is empty
        if len(ts) == 0:
            # NOTE(review): the maybe-assigned component is also seeded with
            # `ass_before_entry` rather than `maybe_ass_before_entry` -- confirm
            # that this is intended.
            return self.ass_before_entry, self.ass_before_entry

        # Definitely assigned: on all incoming paths; maybe assigned: on any of them
        def_ass = set.intersection(*(def_ass for def_ass, _ in ts))
        maybe_ass = set.union(*(maybe_ass for _, maybe_ass in ts))
        return def_ass, maybe_ass

    def apply_bb(
        self, val_before: AssignmentDomain[VId], bb: BB
    ) -> AssignmentDomain[VId]:
        # Everything the BB assigns is both definitely and maybe assigned afterwards
        stats = self.stats[bb]
        def_ass_before, maybe_ass_before = val_before
        return (
            def_ass_before | stats.assigned.keys(),
            maybe_ass_before | stats.assigned.keys(),
        )

    def run_unpacked(
        self, bbs: Iterable[BB]
    ) -> tuple[Result[DefAssignmentDomain[VId]], Result[MaybeAssignmentDomain[VId]]]:
        """Runs the analysis and unpacks the definite- and maybe-assignment results."""
        res = self.run(bbs)
        return {bb: res[bb][0] for bb in res}, {bb: res[bb][1] for bb in res}
|