guppylang-internals 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +3 -0
- guppylang_internals/ast_util.py +350 -0
- guppylang_internals/cfg/__init__.py +0 -0
- guppylang_internals/cfg/analysis.py +230 -0
- guppylang_internals/cfg/bb.py +221 -0
- guppylang_internals/cfg/builder.py +606 -0
- guppylang_internals/cfg/cfg.py +117 -0
- guppylang_internals/checker/__init__.py +0 -0
- guppylang_internals/checker/cfg_checker.py +388 -0
- guppylang_internals/checker/core.py +550 -0
- guppylang_internals/checker/errors/__init__.py +0 -0
- guppylang_internals/checker/errors/comptime_errors.py +106 -0
- guppylang_internals/checker/errors/generic.py +45 -0
- guppylang_internals/checker/errors/linearity.py +300 -0
- guppylang_internals/checker/errors/type_errors.py +344 -0
- guppylang_internals/checker/errors/wasm.py +34 -0
- guppylang_internals/checker/expr_checker.py +1413 -0
- guppylang_internals/checker/func_checker.py +269 -0
- guppylang_internals/checker/linearity_checker.py +821 -0
- guppylang_internals/checker/stmt_checker.py +447 -0
- guppylang_internals/compiler/__init__.py +0 -0
- guppylang_internals/compiler/cfg_compiler.py +233 -0
- guppylang_internals/compiler/core.py +613 -0
- guppylang_internals/compiler/expr_compiler.py +989 -0
- guppylang_internals/compiler/func_compiler.py +97 -0
- guppylang_internals/compiler/hugr_extension.py +224 -0
- guppylang_internals/compiler/qtm_platform_extension.py +0 -0
- guppylang_internals/compiler/stmt_compiler.py +212 -0
- guppylang_internals/decorator.py +246 -0
- guppylang_internals/definition/__init__.py +0 -0
- guppylang_internals/definition/common.py +214 -0
- guppylang_internals/definition/const.py +74 -0
- guppylang_internals/definition/custom.py +492 -0
- guppylang_internals/definition/declaration.py +171 -0
- guppylang_internals/definition/extern.py +89 -0
- guppylang_internals/definition/function.py +302 -0
- guppylang_internals/definition/overloaded.py +150 -0
- guppylang_internals/definition/parameter.py +82 -0
- guppylang_internals/definition/pytket_circuits.py +405 -0
- guppylang_internals/definition/struct.py +392 -0
- guppylang_internals/definition/traced.py +151 -0
- guppylang_internals/definition/ty.py +51 -0
- guppylang_internals/definition/value.py +115 -0
- guppylang_internals/definition/wasm.py +61 -0
- guppylang_internals/diagnostic.py +523 -0
- guppylang_internals/dummy_decorator.py +76 -0
- guppylang_internals/engine.py +295 -0
- guppylang_internals/error.py +107 -0
- guppylang_internals/experimental.py +92 -0
- guppylang_internals/ipython_inspect.py +28 -0
- guppylang_internals/nodes.py +427 -0
- guppylang_internals/py.typed +0 -0
- guppylang_internals/span.py +150 -0
- guppylang_internals/std/__init__.py +0 -0
- guppylang_internals/std/_internal/__init__.py +0 -0
- guppylang_internals/std/_internal/checker.py +573 -0
- guppylang_internals/std/_internal/compiler/__init__.py +0 -0
- guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
- guppylang_internals/std/_internal/compiler/array.py +569 -0
- guppylang_internals/std/_internal/compiler/either.py +131 -0
- guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
- guppylang_internals/std/_internal/compiler/futures.py +30 -0
- guppylang_internals/std/_internal/compiler/list.py +348 -0
- guppylang_internals/std/_internal/compiler/mem.py +13 -0
- guppylang_internals/std/_internal/compiler/option.py +78 -0
- guppylang_internals/std/_internal/compiler/prelude.py +271 -0
- guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
- guppylang_internals/std/_internal/compiler/quantum.py +118 -0
- guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
- guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
- guppylang_internals/std/_internal/compiler/wasm.py +135 -0
- guppylang_internals/std/_internal/compiler.py +0 -0
- guppylang_internals/std/_internal/debug.py +95 -0
- guppylang_internals/std/_internal/util.py +271 -0
- guppylang_internals/tracing/__init__.py +0 -0
- guppylang_internals/tracing/builtins_mock.py +62 -0
- guppylang_internals/tracing/frozenlist.py +57 -0
- guppylang_internals/tracing/function.py +186 -0
- guppylang_internals/tracing/object.py +551 -0
- guppylang_internals/tracing/state.py +69 -0
- guppylang_internals/tracing/unpacking.py +194 -0
- guppylang_internals/tracing/util.py +86 -0
- guppylang_internals/tys/__init__.py +0 -0
- guppylang_internals/tys/arg.py +115 -0
- guppylang_internals/tys/builtin.py +382 -0
- guppylang_internals/tys/common.py +110 -0
- guppylang_internals/tys/const.py +114 -0
- guppylang_internals/tys/errors.py +178 -0
- guppylang_internals/tys/param.py +251 -0
- guppylang_internals/tys/parsing.py +425 -0
- guppylang_internals/tys/printing.py +174 -0
- guppylang_internals/tys/subst.py +112 -0
- guppylang_internals/tys/ty.py +876 -0
- guppylang_internals/tys/var.py +49 -0
- guppylang_internals-0.21.0.dist-info/METADATA +253 -0
- guppylang_internals-0.21.0.dist-info/RECORD +98 -0
- guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
- guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
|
@@ -0,0 +1,606 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import copy
|
|
3
|
+
import itertools
|
|
4
|
+
from collections.abc import Iterator
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import ClassVar, NamedTuple
|
|
7
|
+
|
|
8
|
+
from guppylang_internals.ast_util import (
|
|
9
|
+
AstVisitor,
|
|
10
|
+
ContextAdjuster,
|
|
11
|
+
find_nodes,
|
|
12
|
+
set_location_from,
|
|
13
|
+
template_replace,
|
|
14
|
+
with_loc,
|
|
15
|
+
)
|
|
16
|
+
from guppylang_internals.cfg.bb import BB, BBStatement
|
|
17
|
+
from guppylang_internals.cfg.cfg import CFG
|
|
18
|
+
from guppylang_internals.checker.core import Globals
|
|
19
|
+
from guppylang_internals.checker.errors.generic import ExpectedError, UnsupportedError
|
|
20
|
+
from guppylang_internals.diagnostic import Error
|
|
21
|
+
from guppylang_internals.error import GuppyError, InternalGuppyError
|
|
22
|
+
from guppylang_internals.experimental import check_lists_enabled
|
|
23
|
+
from guppylang_internals.nodes import (
|
|
24
|
+
ComptimeExpr,
|
|
25
|
+
DesugaredGenerator,
|
|
26
|
+
DesugaredGeneratorExpr,
|
|
27
|
+
DesugaredListComp,
|
|
28
|
+
IterNext,
|
|
29
|
+
MakeIter,
|
|
30
|
+
NestedFunctionDef,
|
|
31
|
+
)
|
|
32
|
+
from guppylang_internals.tys.ty import NoneType
|
|
33
|
+
|
|
34
|
+
# Expression building needs an inexhaustible supply of fresh, unique temporary
# variable names (`%tmp0`, `%tmp1`, ...) to hold intermediate results. The `%`
# prefix is not valid in a Python identifier, so these names can never clash with
# variables written by the user.
tmp_vars: Iterator[str] = map("%tmp{}".format, itertools.count())
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def is_tmp_var(x: str) -> bool:
    """Return whether `x` is the name of a compiler-generated temporary (`%tmp...`)."""
    prefix = "%tmp"
    return x[: len(prefix)] == prefix
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class Jumps(NamedTuple):
    """Targets that jumping statements link to while the CFG is being constructed.

    The loop targets are only set while translating a loop body; at the top level of
    a function they are `None`.
    """

    # Block that a `return` statement links to.
    return_bb: BB
    # Block that a `continue` statement links to (`None` outside of loops).
    continue_bb: BB | None
    # Block that a `break` statement links to (`None` outside of loops).
    break_bb: BB | None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass(frozen=True)
class UnreachableError(Error):
    """Diagnostic pointing at a span of code that can never be executed."""

    title: ClassVar[str] = "Unreachable"
    span_label: ClassVar[str] = "This code is not reachable"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class CFGBuilder(AstVisitor[BB | None]):
    """Constructs a CFG from ast nodes.

    Every `visit_*` method takes the statement to translate, the basic block `bb`
    that is currently being filled, and the active `Jumps` targets. It returns the
    basic block in which translation continues after the statement, or `None` if the
    statement always transfers control elsewhere (`return`, `break`, `continue`).
    """

    cfg: CFG
    globals: Globals

    def build(self, nodes: list[ast.stmt], returns_none: bool, globals: Globals) -> CFG:
        """Builds a CFG from a list of ast nodes.

        Args:
            nodes: The statements making up a function body.
            returns_none: Whether the function's return type is `None`. Only then may
                control reach the end of `nodes` (an implicit `return`); otherwise a
                missing return statement is reported.
            globals: Passed on to the signature checking of nested function
                definitions.

        Returns:
            The constructed CFG with reachability computed and unreachable jumps into
            reachable code pruned.

        Raises:
            GuppyError: If the end of the body is reachable although `returns_none` is
                `False`, or if an unsupported statement is encountered.
        """
        self.cfg = CFG()
        self.globals = globals

        final_bb = self.visit_stmts(
            nodes, self.cfg.entry_bb, Jumps(self.cfg.exit_bb, None, None)
        )

        # Compute reachable BBs
        self.cfg.update_reachable()

        # If we're still in a basic block after compiling the whole body, we have to
        # add an implicit void return. This link is added after the reachability
        # analysis above, so the exit BB's reachability is updated by hand.
        if final_bb is not None:
            self.cfg.link(final_bb, self.cfg.exit_bb)
            if final_bb.reachable:
                self.cfg.exit_bb.reachable = True
                if not returns_none:
                    raise GuppyError(ExpectedError(nodes[-1], "return statement"))

        # Prune the CFG such that there are no jumps from unreachable code back into
        # reachable code. Otherwise, unreachable code could lead to unnecessary type
        # checking errors, e.g. if unreachable code changes the type of a variable.
        for bb in self.cfg.bbs:
            if not bb.reachable:
                for succ in list(bb.successors):
                    if succ.reachable:
                        bb.successors.remove(succ)
                        succ.predecessors.remove(bb)
            else:
                # Similarly, if a BB is reachable, then there is no need to hold on to
                # dummy jumps into it. Dummy jumps are only needed to propagate type
                # information into and between unreachable BBs.
                for pred in bb.dummy_predecessors:
                    pred.dummy_successors.remove(bb)
                bb.dummy_predecessors = []

        return self.cfg

    def visit_stmts(self, nodes: list[ast.stmt], bb: BB, jumps: Jumps) -> BB | None:
        """Translates a sequence of statements, starting in `bb`.

        Returns the BB in which control continues after the last statement, or `None`
        if the last statement jumps away.
        """
        prev_bb = bb
        bb_opt: BB | None = bb
        next_functional = False
        for node in nodes:
            # If the previous statement jumped, then all following statements are
            # unreachable. Just create a new dummy BB and keep going so we can still
            # check the unreachable code.
            if bb_opt is None:
                bb_opt = self.cfg.new_bb()
                self.cfg.dummy_link(prev_bb, bb_opt)
            if is_functional_annotation(node):
                next_functional = True
                continue

            if next_functional:
                # TODO: This should be an assertion that the Hugr can be un-flattened
                raise NotImplementedError
            prev_bb, bb_opt = bb_opt, self.visit(node, bb_opt, jumps)
        return bb_opt

    def _build_node_value(self, node: BBStatement, bb: BB) -> BB:
        """Utility method for building a node containing a `value` expression.

        Builds the expression and mutates `node.value` to point to the built expression.
        Returns the BB in which the expression is available and adds the node to it.
        """
        if not isinstance(node, NestedFunctionDef) and node.value is not None:
            node.value, bb = ExprBuilder.build(node.value, self.cfg, bb)
        bb.statements.append(node)
        return bb

    def visit_Assign(self, node: ast.Assign, bb: BB, jumps: Jumps) -> BB | None:
        return self._build_node_value(node, bb)

    def visit_AugAssign(self, node: ast.AugAssign, bb: BB, jumps: Jumps) -> BB | None:
        return self._build_node_value(node, bb)

    def visit_AnnAssign(self, node: ast.AnnAssign, bb: BB, jumps: Jumps) -> BB | None:
        return self._build_node_value(node, bb)

    def visit_Expr(self, node: ast.Expr, bb: BB, jumps: Jumps) -> BB | None:
        # This is an expression statement where the value is discarded
        node.value, bb = ExprBuilder.build(node.value, self.cfg, bb)
        # We don't add it to the BB if it's just a temporary variable. This will be the
        # case if it's a branching expression, e.g. `42 if cond else False`. In that
        # example the type mismatch is actually fine since the result is never used. To
        # achieve this behaviour we must not add the temporary result variable to the BB
        if not isinstance(node.value, ast.Name) or not is_tmp_var(node.value.id):
            bb.statements.append(node)
        return bb

    def visit_If(self, node: ast.If, bb: BB, jumps: Jumps) -> BB | None:
        then_bb, else_bb = self.cfg.new_bb(), self.cfg.new_bb()
        BranchBuilder.add_branch(node.test, self.cfg, bb, then_bb, else_bb)
        then_end = self.visit_stmts(node.body, then_bb, jumps)
        else_end = self.visit_stmts(node.orelse, else_bb, jumps)
        # We need to handle different cases depending on whether branches jump (i.e.
        # return, continue, or break)
        if then_end is None:
            # If branch jumps: We continue in the BB of the else branch (which is
            # `None` as well if both branches jump)
            return else_end
        if else_end is None:
            # Else branch jumps: We continue in the BB of the if branch
            return then_end
        # No branch jumps: We have to merge the control flow
        return self.cfg.new_bb(then_end, else_end)

    @staticmethod
    def _reject_loop_else(node: ast.While | ast.For) -> None:
        """Raises an error if the loop has an `else` block.

        The loop translation has no place for the `else` block, so accepting it would
        silently drop code that Python runs whenever the loop ends without `break`.
        """
        if node.orelse:
            raise GuppyError(UnsupportedError(node.orelse[0], "Loop `else` blocks"))

    def visit_While(self, node: ast.While, bb: BB, jumps: Jumps) -> BB | None:
        self._reject_loop_else(node)
        head_bb = self.cfg.new_bb(bb)
        body_bb, tail_bb = self.cfg.new_bb(), self.cfg.new_bb()
        BranchBuilder.add_branch(node.test, self.cfg, head_bb, body_bb, tail_bb)

        new_jumps = Jumps(
            return_bb=jumps.return_bb, continue_bb=head_bb, break_bb=tail_bb
        )
        body_end_bb = self.visit_stmts(node.body, body_bb, new_jumps)

        # Go back to the head (but only if the body doesn't end with a jump itself)
        if body_end_bb is not None:
            self.cfg.link(body_end_bb, head_bb)

        # Continue compilation in the tail. This should even happen if the body does
        # its own jumps since the body is not guaranteed to execute
        return tail_bb

    def visit_For(self, node: ast.For, bb: BB, jumps: Jumps) -> BB | None:
        self._reject_loop_else(node)
        # The loop is desugared into a `while` loop over an explicit iterator. The
        # template text starts at column 0 so it parses as-is.
        template = """
it = make_iter
while True:
    res = iter_next
    if not res.is_some():
        res.unwrap_nothing()
        break
    x, it = res.unwrap()
    body
"""

        it = make_var(next(tmp_vars), node.iter)
        res = make_var(next(tmp_vars), node.iter)
        new_nodes = template_replace(
            template,
            node.iter,
            it=it,
            res=res,
            x=node.target,
            make_iter=with_loc(node.iter, MakeIter(value=node.iter, origin_node=node)),
            iter_next=with_loc(node.iter, IterNext(value=it)),
            body=node.body,
        )
        return self.visit_stmts(new_nodes, bb, jumps)

    def visit_Continue(self, node: ast.Continue, bb: BB, jumps: Jumps) -> BB | None:
        if jumps.continue_bb is None:
            raise InternalGuppyError("Continue BB not defined")
        self.cfg.link(bb, jumps.continue_bb)
        return None

    def visit_Break(self, node: ast.Break, bb: BB, jumps: Jumps) -> BB | None:
        if jumps.break_bb is None:
            raise InternalGuppyError("Break BB not defined")
        self.cfg.link(bb, jumps.break_bb)
        return None

    def visit_Return(self, node: ast.Return, bb: BB, jumps: Jumps) -> BB | None:
        bb = self._build_node_value(node, bb)
        self.cfg.link(bb, jumps.return_bb)
        return None

    def visit_Pass(self, node: ast.Pass, bb: BB, jumps: Jumps) -> BB | None:
        return bb

    def visit_FunctionDef(
        self, node: ast.FunctionDef, bb: BB, jumps: Jumps
    ) -> BB | None:
        # Function-local import; presumably to avoid an import cycle with the checker
        # module -- TODO confirm.
        from guppylang_internals.checker.func_checker import (
            check_signature,
            parse_function_with_docstring,
        )

        node, docstring = parse_function_with_docstring(node)

        # Nested functions get their own, independently built CFG.
        func_ty = check_signature(node, self.globals)
        returns_none = isinstance(func_ty.output, NoneType)
        cfg = CFGBuilder().build(node.body, returns_none, self.globals)

        new_node = NestedFunctionDef(
            cfg,
            func_ty,
            docstring=docstring,
            **dict(ast.iter_fields(node)),
        )
        set_location_from(new_node, node)
        bb.statements.append(new_node)
        return bb

    def generic_visit(self, node: ast.AST, bb: BB, jumps: Jumps) -> BB | None:
        # When adding support for new statements, we have to remember to use the
        # ExprBuilder to transform all included expressions!
        raise GuppyError(UnsupportedError(node, "This statement", singular=True))
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class ExprBuilder(ast.NodeTransformer):
    """Builds an expression into a basic block.

    Sub-expressions that require control flow (conditional expressions and
    short-circuiting operations) are lowered into fresh basic blocks whose result is
    written to a temporary variable. `self.bb` always points at the block in which
    the remainder of the expression is evaluated.
    """

    cfg: CFG
    bb: BB

    def __init__(self, cfg: CFG, start_bb: BB) -> None:
        self.bb = start_bb
        self.cfg = cfg

    @staticmethod
    def build(node: ast.expr, cfg: CFG, bb: BB) -> tuple[ast.expr, BB]:
        """Builds an expression into a CFG.

        The expression may be transformed and new basic blocks may be created (for
        example for `... if ... else ...` expressions). Returns the new expression and
        the final basic block in which the expression can be used."""
        builder = ExprBuilder(cfg, bb)
        transformed = builder.visit(node)
        return transformed, builder.bb

    @classmethod
    def _tmp_assign(cls, tmp_name: str, value: ast.expr, bb: BB) -> None:
        """Appends the assignment `tmp_name = value` to the statements of `bb`."""
        target = make_var(tmp_name, value)
        assignment = make_assign([target], value)
        bb.statements.append(assignment)

    def visit_Name(self, node: ast.Name) -> ast.Name:
        # Names need no lowering; handling them here also stops the generic recursion.
        return node

    def visit_NamedExpr(self, node: ast.NamedExpr) -> ast.Name:
        # An assignment expression such as `x := 42` is split into the statement
        # `x = 42` in the current block, and the expression itself is replaced by `x`.
        target = node.target
        if not isinstance(target, ast.Name):
            raise InternalGuppyError(f"Unexpected assign target: {node.target}")
        stored_target = copy.deepcopy(target)
        lowered_value = self.visit(node.value)
        statement = ast.Assign(targets=[stored_target], value=lowered_value)
        set_location_from(statement, node)
        self.bb.statements.append(statement)
        return target

    def visit_IfExp(self, node: ast.IfExp) -> ast.Name:
        # Lower `body if test else orelse` into a diamond: branch on `test`, evaluate
        # each arm in its own block, store it in a shared temporary and join again.
        then_bb = self.cfg.new_bb()
        else_bb = self.cfg.new_bb()
        BranchBuilder.add_branch(node.test, self.cfg, self.bb, then_bb, else_bb)

        then_value, then_bb = ExprBuilder.build(node.body, self.cfg, then_bb)
        else_value, else_bb = ExprBuilder.build(node.orelse, self.cfg, else_bb)

        result = next(tmp_vars)
        for arm_value, arm_bb in ((then_value, then_bb), (else_value, else_bb)):
            self._tmp_assign(result, arm_value, arm_bb)

        # Evaluation continues in the join block, which reads the temporary.
        self.bb = self.cfg.new_bb(then_bb, else_bb)
        return make_var(result, node)

    def visit_ListComp(self, node: ast.ListComp) -> DesugaredListComp:
        check_lists_enabled(node)
        gens, element = desugar_comprehension(node.generators, node.elt, node)
        comprehension = DesugaredListComp(elt=element, generators=gens)
        return with_loc(node, comprehension)

    def visit_GeneratorExp(self, node: ast.GeneratorExp) -> DesugaredGeneratorExpr:
        gens, element = desugar_comprehension(node.generators, node.elt, node)
        generator = DesugaredGeneratorExpr(elt=element, generators=gens)
        return with_loc(node, generator)

    def visit_Call(self, node: ast.Call) -> ast.AST:
        # `comptime(...)` / `py(...)` calls become `ComptimeExpr` nodes whose
        # arguments are left untouched; any other call is transformed recursively.
        comptime_expr = is_comptime_expression(node)
        if comptime_expr:
            return comptime_expr
        return self.generic_visit(node)

    def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST:
        # Fold a minus sign applied to a numeric literal into the literal itself.
        operand = node.operand
        negates_number = (
            isinstance(node.op, ast.USub)
            and isinstance(operand, ast.Constant)
            and isinstance(operand.value, (float, int))
        )
        if not negates_number:
            return self.generic_visit(node)
        operand.value = -operand.value
        return with_loc(node, operand)

    def generic_visit(self, node: ast.AST) -> ast.AST:
        if not is_short_circuit_expr(node):
            # Ordinary expressions: let the node transformer recurse into children.
            return super().generic_visit(node)

        # Short-circuit expressions must go through the `BranchBuilder`. They are
        # turned back into a plain value by writing `True`/`False` into a temporary
        # in the respective branch and then joining the control flow again.
        assert isinstance(node, ast.expr)
        true_bb = self.cfg.new_bb()
        false_bb = self.cfg.new_bb()
        BranchBuilder.add_branch(node, self.cfg, self.bb, true_bb, false_bb)
        result = next(tmp_vars)
        for flag, branch_bb in ((True, true_bb), (False, false_bb)):
            literal = ast.Constant(value=flag)
            set_location_from(literal, node)
            self._tmp_assign(result, literal, branch_bb)
        self.bb = self.cfg.new_bb(true_bb, false_bb)
        return make_var(result, node)
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
class BranchBuilder(AstVisitor[None]):
    """Builds an expression and does branching based on the value.

    This builder should be used to handle all branching on boolean values since it
    handles short-circuit evaluation etc.

    Every `visit_*` method takes the expression, the block `bb` in which it is
    evaluated, and the blocks `true_bb`/`false_bb` to branch to depending on its
    truth value.
    """

    cfg: CFG

    def __init__(self, cfg: CFG):
        """Creates a new `BranchBuilder`."""
        self.cfg = cfg

    @staticmethod
    def add_branch(node: ast.expr, cfg: CFG, bb: BB, true_bb: BB, false_bb: BB) -> None:
        """Builds an expression and branches to `true_bb` or `false_bb`, depending on
        the truth value of the expression."""
        builder = BranchBuilder(cfg)
        builder.visit(node, bb, true_bb, false_bb)

    def visit_Constant(
        self, node: ast.Constant, bb: BB, true_bb: BB, false_bb: BB
    ) -> None:
        # Branching on `True` or `False` constant should be unconditional. The other
        # target only gets a dummy link (used to propagate type information into the
        # then unreachable block).
        if isinstance(node.value, bool):
            self.cfg.link(bb, true_bb if node.value else false_bb)
            self.cfg.dummy_link(bb, false_bb if node.value else true_bb)
        else:
            self.generic_visit(node, bb, true_bb, false_bb)

    def visit_BoolOp(self, node: ast.BoolOp, bb: BB, true_bb: BB, false_bb: BB) -> None:
        # Add short-circuit evaluation of boolean expression. If there are more than 2
        # operands, we treat the flat operand list as right-nested, i.e.
        # `a op b op c` as `a op (b op c)`, to allow for recursive processing. The
        # input node is not modified.
        assert len(node.values) > 1
        left, *rest = node.values
        if len(rest) == 1:
            right = rest[0]
        else:
            # The synthesised nested operation spans exactly its own operands.
            right = ast.BoolOp(
                op=node.op,
                values=rest,
                lineno=rest[0].lineno,
                col_offset=rest[0].col_offset,
                end_lineno=rest[-1].end_lineno,
                end_col_offset=rest[-1].end_col_offset,
            )

        # `right` is only evaluated in `extra_bb`, i.e. if `left` did not already
        # decide the result.
        extra_bb = self.cfg.new_bb()
        assert type(node.op) in [ast.And, ast.Or]
        if isinstance(node.op, ast.And):
            self.visit(left, bb, extra_bb, false_bb)
        elif isinstance(node.op, ast.Or):
            self.visit(left, bb, true_bb, extra_bb)
        self.visit(right, extra_bb, true_bb, false_bb)

    @staticmethod
    def _is_trivial_operand(expr: ast.expr) -> bool:
        """Whether evaluating `expr` twice is harmless (a plain name or constant)."""
        return isinstance(expr, ast.Name | ast.Constant)

    @staticmethod
    def _bind_to_tmp(expr: ast.expr) -> tuple[ast.expr, ast.expr]:
        """Returns `(%tmpN := expr, %tmpN)` for a fresh temporary `%tmpN`.

        The first component evaluates `expr` and stores the result; the second one
        reads the stored value back without re-evaluating `expr`.
        """
        name = next(tmp_vars)
        target = ast.Name(id=name, ctx=ast.Store())
        set_location_from(target, expr)
        bound = ast.NamedExpr(target=target, value=expr)
        set_location_from(bound, expr)
        return bound, make_var(name, expr)

    def visit_Compare(
        self, node: ast.Compare, bb: BB, true_bb: BB, false_bb: BB
    ) -> None:
        # Support chained comparisons, e.g. `x <= y < z`, by compiling them to
        # `x <= y and y < z`. This way we get short-circuit evaluation for free.
        if len(node.comparators) <= 1:
            self.generic_visit(node, bb, true_bb, false_bb)
            return

        operands = [node.left, *node.comparators]
        last = len(operands) - 1
        values: list[ast.expr] = []
        left = operands[0]
        for i, op in enumerate(node.ops):
            right = operands[i + 1]
            next_left = right
            # A middle operand like `y` above occurs in two comparisons, but Python
            # evaluates it only once. Unless it is a plain name or constant, bind it
            # to a temporary on first use (`x <= (%tmp := y)`) and refer to that
            # temporary in the next comparison (`%tmp < z`).
            if i + 1 < last and not self._is_trivial_operand(right):
                if i == 0 and not self._is_trivial_operand(left):
                    # The `ExprBuilder` hoists `:=` into an assignment statement that
                    # runs before the comparison, so `x` is hoisted as well to keep
                    # it evaluated before `y`.
                    left, _ = self._bind_to_tmp(left)
                right, next_left = self._bind_to_tmp(right)
            values.append(
                ast.Compare(
                    left=left,
                    ops=[op],
                    comparators=[right],
                    lineno=left.lineno,
                    col_offset=left.col_offset,
                    end_lineno=right.end_lineno,
                    end_col_offset=right.end_col_offset,
                )
            )
            left = next_left
        conj = ast.BoolOp(op=ast.And(), values=values)
        set_location_from(conj, node)
        self.visit_BoolOp(conj, bb, true_bb, false_bb)

    def visit_IfExp(self, node: ast.IfExp, bb: BB, true_bb: BB, false_bb: BB) -> None:
        # Branch on the test first, then let whichever arm is taken decide the final
        # branch.
        then_bb, else_bb = self.cfg.new_bb(), self.cfg.new_bb()
        self.visit(node.test, bb, then_bb, else_bb)
        self.visit(node.body, then_bb, true_bb, false_bb)
        self.visit(node.orelse, else_bb, true_bb, false_bb)

    def generic_visit(self, node: ast.expr, bb: BB, true_bb: BB, false_bb: BB) -> None:
        # We can always fall back to building the node as a regular expression and using
        # the result as a branch predicate
        pred, bb = ExprBuilder.build(node, self.cfg, bb)
        bb.branch_pred = pred
        # NOTE(review): the false successor is linked first, then the true one; the
        # successor order presumably encodes the branch outcome -- keep it as is.
        self.cfg.link(bb, false_bb)
        self.cfg.link(bb, true_bb)
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def desugar_comprehension(
    generators: list[ast.comprehension], elt: ast.expr, node: ast.AST
) -> tuple[list[DesugaredGenerator], ast.expr]:
    """Helper function to desugar a comprehension node.

    Args:
        generators: The `for ... in ... if ...` clauses of the comprehension.
        elt: The element expression.
        node: The whole comprehension node. It is searched for illegal expressions and
            recorded as the origin of the generated `MakeIter` nodes.

    Returns:
        The desugared generators and the desugared element expression.

    Raises:
        GuppyError: If the comprehension contains a control-flow expression or an
            `async for` clause.
    """
    # Check for illegal expressions
    illegals = find_nodes(is_illegal_in_list_comp, node)
    if illegals:
        err = UnsupportedError(
            illegals[0],
            "This expression",
            singular=True,
            unsupported_in="a list comprehension",
        )
        raise GuppyError(err)

    # The check above ensures that the comprehension doesn't contain any control-flow
    # expressions. Thus, we can use a dummy `ExprBuilder` to desugar the insides.
    # TODO: Refactor so that desugaring is separate from control-flow building
    dummy_cfg = CFG()
    builder = ExprBuilder(dummy_cfg, dummy_cfg.entry_bb)

    # Desugar into statements that create the iterator, check for a next element,
    # get the next element, and finalise the iterator.
    gens = []
    for g in generators:
        if g.is_async:
            raise GuppyError(UnsupportedError(g, "Async generators"))
        g.iter = builder.visit(g.iter)
        # The conditions need the same desugaring as the iterator and the element
        # (e.g. nested comprehensions, `comptime(...)` calls, negative literals);
        # otherwise they would be passed on untransformed.
        g.ifs = [builder.visit(cond) for cond in g.ifs]
        it = make_var(next(tmp_vars), g.iter)
        desugared = DesugaredGenerator(
            iter=it,
            iter_assign=make_assign(
                [it], with_loc(it, MakeIter(value=g.iter, origin_node=node))
            ),
            next_call=with_loc(it, IterNext(value=it)),
            target=g.target,
            ifs=g.ifs,
            used_outer_places=[],
        )
        gens.append(desugared)

    elt = builder.visit(elt)
    return gens, elt
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def is_functional_annotation(stmt: ast.stmt) -> bool:
|
|
529
|
+
"""Returns `True` iff the given statement is the functional pseudo-decorator.
|
|
530
|
+
|
|
531
|
+
Pseudo-decorators are built using the matmul operator `@`, i.e. `_@functional`.
|
|
532
|
+
"""
|
|
533
|
+
if isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.BinOp):
|
|
534
|
+
op = stmt.value
|
|
535
|
+
if (
|
|
536
|
+
isinstance(op.op, ast.MatMult)
|
|
537
|
+
and isinstance(op.left, ast.Name)
|
|
538
|
+
and isinstance(op.right, ast.Name)
|
|
539
|
+
):
|
|
540
|
+
return op.left.id == "_" and op.right.id == "functional"
|
|
541
|
+
return False
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
@dataclass(frozen=True)
class EmptyComptimeExprError(Error):
    """Diagnostic for a `comptime(...)` call that was given no argument."""

    title: ClassVar[str] = "Invalid comptime expression"
    span_label: ClassVar[str] = "Comptime expression requires an argument"
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
def is_comptime_expression(node: ast.AST) -> ComptimeExpr | None:
|
|
551
|
+
"""Checks if the given node is a `comptime(...)` expression and turns it into
|
|
552
|
+
a `ComptimeExpr` AST node.
|
|
553
|
+
|
|
554
|
+
Also accepts the `py(...)` alias for `comptime` expressions.
|
|
555
|
+
|
|
556
|
+
Otherwise, returns `None`.
|
|
557
|
+
"""
|
|
558
|
+
if (
|
|
559
|
+
isinstance(node, ast.Call)
|
|
560
|
+
and isinstance(node.func, ast.Name)
|
|
561
|
+
and node.func.id in ("py", "comptime")
|
|
562
|
+
):
|
|
563
|
+
match node.args:
|
|
564
|
+
case []:
|
|
565
|
+
raise GuppyError(EmptyComptimeExprError(node))
|
|
566
|
+
case [arg]:
|
|
567
|
+
pass
|
|
568
|
+
case args:
|
|
569
|
+
arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load))
|
|
570
|
+
return with_loc(node, ComptimeExpr(value=arg))
|
|
571
|
+
return None
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
def is_short_circuit_expr(node: ast.AST) -> bool:
    """Tell whether evaluating `node` involves short-circuiting.

    That is the case for `and`/`or` operations and for chained comparisons such as
    `a < b < c`. Such expressions *must* be compiled using the `BranchBuilder`.
    """
    if isinstance(node, ast.BoolOp):
        return True
    if isinstance(node, ast.Compare):
        return len(node.comparators) > 1
    return False
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def is_illegal_in_list_comp(node: ast.AST) -> bool:
    """Tell whether `node` may not appear inside a list comprehension.

    Conditional expressions, assignment expressions and short-circuiting expressions
    are rejected, since they would require control flow inside the comprehension.
    """
    if isinstance(node, (ast.IfExp, ast.NamedExpr)):
        return True
    return is_short_circuit_expr(node)
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
def make_var(name: str, loc: ast.AST | None = None) -> ast.Name:
|
|
590
|
+
"""Creates an `ast.Name` node."""
|
|
591
|
+
node = ast.Name(id=name, ctx=ast.Load)
|
|
592
|
+
if loc is not None:
|
|
593
|
+
set_location_from(node, loc)
|
|
594
|
+
return node
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
def make_assign(lhs: list[ast.AST], value: ast.expr) -> ast.Assign:
    """Build the statement `lhs[0], ..., lhs[n-1] = value`.

    The left-hand side expressions are switched to store context. A single target is
    assigned directly; several targets are wrapped in a tuple that takes its source
    location from `value`, as does the resulting statement.
    """
    assert len(lhs) > 0
    to_store = ContextAdjuster(ast.Store())
    stores = [to_store.visit(expr) for expr in lhs]
    target = (
        stores[0]
        if len(stores) == 1
        else with_loc(value, ast.Tuple(elts=stores, ctx=ast.Store()))
    )
    return with_loc(value, ast.Assign(targets=[target], value=value))
|