guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98) hide show
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,606 @@
1
+ import ast
2
+ import copy
3
+ import itertools
4
+ from collections.abc import Iterator
5
+ from dataclasses import dataclass
6
+ from typing import ClassVar, NamedTuple
7
+
8
+ from guppylang_internals.ast_util import (
9
+ AstVisitor,
10
+ ContextAdjuster,
11
+ find_nodes,
12
+ set_location_from,
13
+ template_replace,
14
+ with_loc,
15
+ )
16
+ from guppylang_internals.cfg.bb import BB, BBStatement
17
+ from guppylang_internals.cfg.cfg import CFG
18
+ from guppylang_internals.checker.core import Globals
19
+ from guppylang_internals.checker.errors.generic import ExpectedError, UnsupportedError
20
+ from guppylang_internals.diagnostic import Error
21
+ from guppylang_internals.error import GuppyError, InternalGuppyError
22
+ from guppylang_internals.experimental import check_lists_enabled
23
+ from guppylang_internals.nodes import (
24
+ ComptimeExpr,
25
+ DesugaredGenerator,
26
+ DesugaredGeneratorExpr,
27
+ DesugaredListComp,
28
+ IterNext,
29
+ MakeIter,
30
+ NestedFunctionDef,
31
+ )
32
+ from guppylang_internals.tys.ty import NoneType
33
+
34
# Building expressions needs a never-ending supply of fresh temporary variable names
# ("%tmp0", "%tmp1", ...) that hold intermediate results
tmp_vars: Iterator[str] = map(lambda i: f"%tmp{i}", itertools.count())
37
+
38
+
39
def is_tmp_var(x: str) -> bool:
    """Returns `True` iff `x` names a temporary variable introduced by the builder."""
    tmp_prefix = "%tmp"
    return x.startswith(tmp_prefix)
42
+
43
+
44
class Jumps(NamedTuple):
    """Jump targets for `return`, `continue`, and `break` during CFG construction.

    `continue_bb` and `break_bb` are `None` when the statements being built are not
    inside a loop.
    """

    # Where `return` statements jump to
    return_bb: BB
    # Where `continue` statements jump to (the loop head), if inside a loop
    continue_bb: BB | None
    # Where `break` statements jump to (the loop tail), if inside a loop
    break_bb: BB | None
50
+
51
+
52
@dataclass(frozen=True)
class UnreachableError(Error):
    """Diagnostic reporting a piece of code that is not reachable."""

    title: ClassVar[str] = "Unreachable"
    span_label: ClassVar[str] = "This code is not reachable"
56
+
57
+
58
class CFGBuilder(AstVisitor[BB | None]):
    """Constructs a CFG from ast nodes.

    Every `visit_*` method takes the statement to build, the basic block in which it
    is built, and the current `Jumps` targets. It returns the basic block in which
    building continues after the statement, or `None` if control never falls
    through the statement (e.g. after `return`, `continue`, or `break`).
    """

    cfg: CFG
    globals: Globals

    def build(self, nodes: list[ast.stmt], returns_none: bool, globals: Globals) -> CFG:
        """Builds a CFG from a list of ast nodes.

        Return statements jump to the exit BB of the CFG. If the end of `nodes` is
        reachable, an implicit return is added there, which is only allowed if the
        function returns `None`.

        Args:
            nodes: The statements to build.
            returns_none: Whether the function whose body is built returns `None`.
            globals: Globals used to check the signatures of nested functions.

        Returns:
            The constructed CFG.

        Raises:
            GuppyError: If the end of the body is reachable but `returns_none` is
                `False`, i.e. a return statement is missing.
        """
        self.cfg = CFG()
        self.globals = globals

        final_bb = self.visit_stmts(
            nodes, self.cfg.entry_bb, Jumps(self.cfg.exit_bb, None, None)
        )

        # Compute reachable BBs
        self.cfg.update_reachable()

        # If we're still in a basic block after compiling the whole body, we have to
        # add an implicit void return
        if final_bb is not None:
            self.cfg.link(final_bb, self.cfg.exit_bb)
            if final_bb.reachable:
                self.cfg.exit_bb.reachable = True
                if not returns_none:
                    raise GuppyError(ExpectedError(nodes[-1], "return statement"))

        # Prune the CFG such that there are no jumps from unreachable code back into
        # reachable code. Otherwise, unreachable code could lead to unnecessary type
        # checking errors, e.g. if unreachable code changes the type of a variable.
        for bb in self.cfg.bbs:
            if not bb.reachable:
                # Iterate over a copy since we remove edges while looping
                for succ in list(bb.successors):
                    if succ.reachable:
                        bb.successors.remove(succ)
                        succ.predecessors.remove(bb)
            # Similarly, if a BB is reachable, then there is no need to hold on to
            # dummy jumps into it. Dummy jumps are only needed to propagate type
            # information into and between unreachable BBs
            else:
                for pred in bb.dummy_predecessors:
                    pred.dummy_successors.remove(bb)
                bb.dummy_predecessors = []

        return self.cfg

    def visit_stmts(self, nodes: list[ast.stmt], bb: BB, jumps: Jumps) -> BB | None:
        """Builds a sequence of statements, starting in `bb`.

        Returns the BB in which building continues after the last statement, or
        `None` if control does not fall through the end of the sequence.

        Raises:
            NotImplementedError: If a statement follows a `_@functional`
                pseudo-decorator.
        """
        prev_bb = bb
        bb_opt: BB | None = bb
        next_functional = False
        for node in nodes:
            # If the previous statement jumped, then all following statements are
            # unreachable. Just create a new dummy BB and keep going so we can still
            # check the unreachable code.
            if bb_opt is None:
                bb_opt = self.cfg.new_bb()
                self.cfg.dummy_link(prev_bb, bb_opt)
            if is_functional_annotation(node):
                next_functional = True
                continue

            if next_functional:
                # TODO: This should be an assertion that the Hugr can be un-flattened
                raise NotImplementedError(
                    "Statements annotated with `_@functional` are not supported"
                )
            prev_bb, bb_opt = bb_opt, self.visit(node, bb_opt, jumps)
        return bb_opt

    def _build_node_value(self, node: BBStatement, bb: BB) -> BB:
        """Utility method for building a node containing a `value` expression.

        Builds the expression and mutates `node.value` to point to the built expression.
        Returns the BB in which the expression is available and adds the node to it.
        """
        if not isinstance(node, NestedFunctionDef) and node.value is not None:
            node.value, bb = ExprBuilder.build(node.value, self.cfg, bb)
        bb.statements.append(node)
        return bb

    def visit_Assign(self, node: ast.Assign, bb: BB, jumps: Jumps) -> BB | None:
        return self._build_node_value(node, bb)

    def visit_AugAssign(self, node: ast.AugAssign, bb: BB, jumps: Jumps) -> BB | None:
        return self._build_node_value(node, bb)

    def visit_AnnAssign(self, node: ast.AnnAssign, bb: BB, jumps: Jumps) -> BB | None:
        return self._build_node_value(node, bb)

    def visit_Expr(self, node: ast.Expr, bb: BB, jumps: Jumps) -> BB | None:
        # This is an expression statement where the value is discarded
        node.value, bb = ExprBuilder.build(node.value, self.cfg, bb)
        # We don't add it to the BB if it's just a temporary variable. This will be the
        # case if it's a branching expression, e.g. `42 if cond else False`. In that
        # example the type mismatch is actually fine since the result is never used. To
        # achieve this behaviour we must not add the temporary result variable to the BB
        if not isinstance(node.value, ast.Name) or not is_tmp_var(node.value.id):
            bb.statements.append(node)
        return bb

    def visit_If(self, node: ast.If, bb: BB, jumps: Jumps) -> BB | None:
        then_bb, else_bb = self.cfg.new_bb(), self.cfg.new_bb()
        BranchBuilder.add_branch(node.test, self.cfg, bb, then_bb, else_bb)
        then_bb = self.visit_stmts(node.body, then_bb, jumps)
        else_bb = self.visit_stmts(node.orelse, else_bb, jumps)
        # We need to handle different cases depending on whether branches jump (i.e.
        # return, continue, or break)
        if then_bb is None:
            # If branch jumps: We continue in the BB of the else branch (this is
            # `None` as well if both branches jump)
            return else_bb
        elif else_bb is None:
            # Else branch jumps: We continue in the BB of the if branch
            return then_bb
        else:
            # No branch jumps: We have to merge the control flow
            return self.cfg.new_bb(then_bb, else_bb)

    def visit_While(self, node: ast.While, bb: BB, jumps: Jumps) -> BB | None:
        head_bb = self.cfg.new_bb(bb)
        body_bb, tail_bb = self.cfg.new_bb(), self.cfg.new_bb()
        BranchBuilder.add_branch(node.test, self.cfg, head_bb, body_bb, tail_bb)

        new_jumps = Jumps(
            return_bb=jumps.return_bb, continue_bb=head_bb, break_bb=tail_bb
        )
        body_end_bb = self.visit_stmts(node.body, body_bb, new_jumps)

        # Go back to the head (but only if the body doesn't do its own jumping)
        if body_end_bb is not None:
            self.cfg.link(body_end_bb, head_bb)

        # Continue compilation in the tail. This should even happen if the body does
        # its own jumps since the body is not guaranteed to execute
        return tail_bb

    def visit_For(self, node: ast.For, bb: BB, jumps: Jumps) -> BB | None:
        # Desugar `for x in xs: body` into a `while` loop that repeatedly asks the
        # iterator for its next element and stops once there is none left
        template = """
        it = make_iter
        while True:
            res = iter_next
            if not res.is_some():
                res.unwrap_nothing()
                break
            x, it = res.unwrap()
            body
        """

        it = make_var(next(tmp_vars), node.iter)
        res = make_var(next(tmp_vars), node.iter)
        new_nodes = template_replace(
            template,
            node.iter,
            it=it,
            res=res,
            x=node.target,
            make_iter=with_loc(node.iter, MakeIter(value=node.iter, origin_node=node)),
            iter_next=with_loc(node.iter, IterNext(value=it)),
            body=node.body,
        )
        return self.visit_stmts(new_nodes, bb, jumps)

    def visit_Continue(self, node: ast.Continue, bb: BB, jumps: Jumps) -> BB | None:
        if jumps.continue_bb is None:
            raise InternalGuppyError("Continue BB not defined")
        self.cfg.link(bb, jumps.continue_bb)
        return None

    def visit_Break(self, node: ast.Break, bb: BB, jumps: Jumps) -> BB | None:
        if jumps.break_bb is None:
            raise InternalGuppyError("Break BB not defined")
        self.cfg.link(bb, jumps.break_bb)
        return None

    def visit_Return(self, node: ast.Return, bb: BB, jumps: Jumps) -> BB | None:
        bb = self._build_node_value(node, bb)
        self.cfg.link(bb, jumps.return_bb)
        return None

    def visit_Pass(self, node: ast.Pass, bb: BB, jumps: Jumps) -> BB | None:
        return bb

    def visit_FunctionDef(
        self, node: ast.FunctionDef, bb: BB, jumps: Jumps
    ) -> BB | None:
        # Imported locally; NOTE(review): presumably to avoid an import cycle
        from guppylang_internals.checker.func_checker import (
            check_signature,
            parse_function_with_docstring,
        )

        node, docstring = parse_function_with_docstring(node)

        # Nested functions get their own, independently built CFG
        func_ty = check_signature(node, self.globals)
        returns_none = isinstance(func_ty.output, NoneType)
        cfg = CFGBuilder().build(node.body, returns_none, self.globals)

        new_node = NestedFunctionDef(
            cfg,
            func_ty,
            docstring=docstring,
            **dict(ast.iter_fields(node)),
        )
        set_location_from(new_node, node)
        bb.statements.append(new_node)
        return bb

    def generic_visit(self, node: ast.AST, bb: BB, jumps: Jumps) -> BB | None:
        # When adding support for new statements, we have to remember to use the
        # ExprBuilder to transform all included expressions!
        raise GuppyError(UnsupportedError(node, "This statement", singular=True))
272
+
273
+
274
class ExprBuilder(ast.NodeTransformer):
    """Builds an expression into a basic block.

    Sub-expressions that require control flow (conditional expressions and
    short-circuiting boolean expressions) are lowered into new basic blocks, and
    their result is stored in a fresh temporary variable. `self.bb` always refers to
    the block in which the expression built so far is available.
    """

    cfg: CFG
    bb: BB

    def __init__(self, cfg: CFG, start_bb: BB) -> None:
        self.cfg, self.bb = cfg, start_bb

    @staticmethod
    def build(node: ast.expr, cfg: CFG, bb: BB) -> tuple[ast.expr, BB]:
        """Builds `node` into `cfg`, starting in `bb`.

        New basic blocks may be created along the way (e.g. for `... if ... else ...`
        expressions) and the expression itself may be rewritten. Returns the rewritten
        expression together with the basic block in which its value is available.
        """
        builder = ExprBuilder(cfg, bb)
        built = builder.visit(node)
        return built, builder.bb

    @classmethod
    def _tmp_assign(cls, tmp_name: str, value: ast.expr, bb: BB) -> None:
        """Appends the assignment `tmp_name = value` to the statements of `bb`."""
        assignment = make_assign([make_var(tmp_name, value)], value)
        bb.statements.append(assignment)

    def visit_Name(self, node: ast.Name) -> ast.Name:
        return node

    def visit_NamedExpr(self, node: ast.NamedExpr) -> ast.Name:
        """Lowers an assignment expression such as `x := 42`.

        The assignment becomes a statement in the current block and the expression
        itself is replaced by the assigned name.
        """
        target = node.target
        if not isinstance(target, ast.Name):
            raise InternalGuppyError(f"Unexpected assign target: {node.target}")
        lhs = copy.deepcopy(target)
        rhs = self.visit(node.value)
        assign = ast.Assign(targets=[lhs], value=rhs)
        set_location_from(assign, node)
        self.bb.statements.append(assign)
        return target

    def visit_IfExp(self, node: ast.IfExp) -> ast.Name:
        """Lowers `body if test else orelse` into branches joined by a merge block.

        Both branches store their value into the same temporary variable, which is
        returned as the replacement expression.
        """
        cfg = self.cfg
        then_start, else_start = cfg.new_bb(), cfg.new_bb()
        BranchBuilder.add_branch(node.test, cfg, self.bb, then_start, else_start)

        then_value, then_end = self.build(node.body, cfg, then_start)
        else_value, else_end = self.build(node.orelse, cfg, else_start)

        result_name = next(tmp_vars)
        for value, end_bb in ((then_value, then_end), (else_value, else_end)):
            self._tmp_assign(result_name, value, end_bb)

        # Continue building in a fresh block where both branches join
        self.bb = cfg.new_bb(then_end, else_end)
        return make_var(result_name, node)

    def visit_ListComp(self, node: ast.ListComp) -> DesugaredListComp:
        check_lists_enabled(node)
        gens, element = desugar_comprehension(node.generators, node.elt, node)
        return with_loc(node, DesugaredListComp(elt=element, generators=gens))

    def visit_GeneratorExp(self, node: ast.GeneratorExp) -> DesugaredGeneratorExpr:
        gens, element = desugar_comprehension(node.generators, node.elt, node)
        return with_loc(node, DesugaredGeneratorExpr(elt=element, generators=gens))

    def visit_Call(self, node: ast.Call) -> ast.AST:
        comptime = is_comptime_expression(node)
        if comptime is not None:
            return comptime
        return self.generic_visit(node)

    def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST:
        """Folds a negated numeric literal such as `-5` into a single constant."""
        operand = node.operand
        if (
            isinstance(node.op, ast.USub)
            and isinstance(operand, ast.Constant)
            and isinstance(operand.value, (float, int))
        ):
            operand.value = -operand.value
            return with_loc(node, operand)
        return self.generic_visit(node)

    def generic_visit(self, node: ast.AST) -> ast.AST:
        """Recursively transforms `node`, lowering short-circuit expressions."""
        if not is_short_circuit_expr(node):
            # Everything else is just transformed recursively
            return super().generic_visit(node)

        # Short-circuit expressions must be built using the `BranchBuilder`. We turn
        # them into regular expressions by assigning True/False to a temporary
        # variable in the respective branch and merging the control flow afterwards.
        assert isinstance(node, ast.expr)
        cfg = self.cfg
        true_bb, false_bb = cfg.new_bb(), cfg.new_bb()
        BranchBuilder.add_branch(node, cfg, self.bb, true_bb, false_bb)
        result_name = next(tmp_vars)
        for flag, target_bb in ((True, true_bb), (False, false_bb)):
            literal = ast.Constant(value=flag)
            set_location_from(literal, node)
            self._tmp_assign(result_name, literal, target_bb)
        self.bb = cfg.new_bb(true_bb, false_bb)
        return make_var(result_name, node)
375
+
376
+
377
class BranchBuilder(AstVisitor[None]):
    """Builds an expression and does branching based on the value.

    This builder should be used to handle all branching on boolean values since it
    handles short-circuit evaluation etc.

    Every `visit_*` method takes the expression, the BB in which evaluation starts,
    and the BBs to jump to when the expression is true (`true_bb`) or false
    (`false_bb`).
    """

    cfg: CFG

    def __init__(self, cfg: CFG):
        """Creates a new `BranchBuilder`."""
        self.cfg = cfg

    @staticmethod
    def add_branch(node: ast.expr, cfg: CFG, bb: BB, true_bb: BB, false_bb: BB) -> None:
        """Builds an expression and branches to `true_bb` or `false_bb`, depending on
        the truth value of the expression."""
        builder = BranchBuilder(cfg)
        builder.visit(node, bb, true_bb, false_bb)

    def visit_Constant(
        self, node: ast.Constant, bb: BB, true_bb: BB, false_bb: BB
    ) -> None:
        # Branching on `True` or `False` constant should be unconditional. The other
        # target only gets a dummy link, which is used to propagate type information
        # into the (unreachable) untaken branch.
        if isinstance(node.value, bool):
            self.cfg.link(bb, true_bb if node.value else false_bb)
            self.cfg.dummy_link(bb, false_bb if node.value else true_bb)
        else:
            self.generic_visit(node, bb, true_bb, false_bb)

    def visit_BoolOp(self, node: ast.BoolOp, bb: BB, true_bb: BB, false_bb: BB) -> None:
        # Add short-circuit evaluation of boolean expression. If there are more than 2
        # operators, we turn the flat operator list into a right-nested tree to allow
        # for recursive processing.
        assert len(node.values) > 1
        if len(node.values) > 2:
            # The nested node covers the operands `values[1:]`, so its source span
            # starts at the second operand and ends at the last one
            rest = node.values[1:]
            r = ast.BoolOp(
                op=node.op,
                values=rest,
                lineno=rest[0].lineno,
                col_offset=rest[0].col_offset,
                end_lineno=rest[-1].end_lineno,
                end_col_offset=rest[-1].end_col_offset,
            )
            node.values = [node.values[0], r]
        [left, right] = node.values

        # `extra_bb` is where the right operand is evaluated: after a true left
        # operand for `and`, after a false left operand for `or`
        extra_bb = self.cfg.new_bb()
        assert type(node.op) in [ast.And, ast.Or]
        if isinstance(node.op, ast.And):
            self.visit(left, bb, extra_bb, false_bb)
        elif isinstance(node.op, ast.Or):
            self.visit(left, bb, true_bb, extra_bb)
        self.visit(right, extra_bb, true_bb, false_bb)

    def visit_UnaryOp(
        self, node: ast.UnaryOp, bb: BB, true_bb: BB, false_bb: BB
    ) -> None:
        # For `not` operator, we can just switch `true_bb` and `false_bb`
        if isinstance(node.op, ast.Not):
            self.visit(node.operand, bb, false_bb, true_bb)
        else:
            self.generic_visit(node, bb, true_bb, false_bb)

    def visit_Compare(
        self, node: ast.Compare, bb: BB, true_bb: BB, false_bb: BB
    ) -> None:
        # Support chained comparisons, e.g. `x <= 5 < y` by compiling to `x <= 5 and
        # 5 < y`. This way we get short-circuit evaluation for free.
        if len(node.comparators) > 1:
            comparators = [node.left, *node.comparators]
            values = [
                ast.Compare(
                    left=left,
                    ops=[op],
                    comparators=[right],
                    lineno=left.lineno,
                    col_offset=left.col_offset,
                    end_lineno=right.end_lineno,
                    end_col_offset=right.end_col_offset,
                )
                for left, op, right in zip(
                    comparators[:-1], node.ops, comparators[1:], strict=True
                )
            ]
            conj = ast.BoolOp(op=ast.And(), values=values)
            set_location_from(conj, node)
            self.visit_BoolOp(conj, bb, true_bb, false_bb)
        else:
            self.generic_visit(node, bb, true_bb, false_bb)

    def visit_IfExp(self, node: ast.IfExp, bb: BB, true_bb: BB, false_bb: BB) -> None:
        # Branch on the test first, then branch on whichever arm was selected
        then_bb, else_bb = self.cfg.new_bb(), self.cfg.new_bb()
        self.visit(node.test, bb, then_bb, else_bb)
        self.visit(node.body, then_bb, true_bb, false_bb)
        self.visit(node.orelse, else_bb, true_bb, false_bb)

    def generic_visit(self, node: ast.expr, bb: BB, true_bb: BB, false_bb: BB) -> None:
        # We can always fall back to building the node as a regular expression and using
        # the result as a branch predicate
        pred, bb = ExprBuilder.build(node, self.cfg, bb)
        bb.branch_pred = pred
        # NOTE(review): the false successor is linked before the true one; the order
        # presumably determines which successor is taken for each predicate value
        self.cfg.link(bb, false_bb)
        self.cfg.link(bb, true_bb)
481
+
482
+
483
def desugar_comprehension(
    generators: list[ast.comprehension], elt: ast.expr, node: ast.AST
) -> tuple[list[DesugaredGenerator], ast.expr]:
    """Desugars the generators and element expression of a comprehension `node`.

    Each generator is turned into a `DesugaredGenerator` holding a fresh iterator
    variable, the assignment that creates the iterator, and the expression that
    fetches its next element.

    Raises:
        GuppyError: If the comprehension contains an expression that is not
            supported inside comprehensions, or an async generator.
    """
    offending = find_nodes(is_illegal_in_list_comp, node)
    if offending:
        raise GuppyError(
            UnsupportedError(
                offending[0],
                "This expression",
                singular=True,
                unsupported_in="a list comprehension",
            )
        )

    # Since the check above rules out control-flow expressions inside the
    # comprehension, a throwaway `ExprBuilder` on a scratch CFG suffices to desugar
    # the sub-expressions.
    # TODO: Refactor so that desugaring is separate from control-flow building
    scratch_cfg = CFG()
    desugarer = ExprBuilder(scratch_cfg, scratch_cfg.entry_bb)

    desugared_gens = []
    for gen in generators:
        if gen.is_async:
            raise GuppyError(UnsupportedError(gen, "Async generators"))
        gen.iter = desugarer.visit(gen.iter)
        iter_var = make_var(next(tmp_vars), gen.iter)
        iter_assign = make_assign(
            [iter_var], with_loc(iter_var, MakeIter(value=gen.iter, origin_node=node))
        )
        next_call = with_loc(iter_var, IterNext(value=iter_var))
        desugared_gens.append(
            DesugaredGenerator(
                iter=iter_var,
                iter_assign=iter_assign,
                next_call=next_call,
                target=gen.target,
                ifs=gen.ifs,
                used_outer_places=[],
            )
        )

    return desugared_gens, desugarer.visit(elt)
526
+
527
+
528
+ def is_functional_annotation(stmt: ast.stmt) -> bool:
529
+ """Returns `True` iff the given statement is the functional pseudo-decorator.
530
+
531
+ Pseudo-decorators are built using the matmul operator `@`, i.e. `_@functional`.
532
+ """
533
+ if isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.BinOp):
534
+ op = stmt.value
535
+ if (
536
+ isinstance(op.op, ast.MatMult)
537
+ and isinstance(op.left, ast.Name)
538
+ and isinstance(op.right, ast.Name)
539
+ ):
540
+ return op.left.id == "_" and op.right.id == "functional"
541
+ return False
542
+
543
+
544
@dataclass(frozen=True)
class EmptyComptimeExprError(Error):
    """Diagnostic for a `comptime(...)` / `py(...)` call without any argument."""

    title: ClassVar[str] = "Invalid comptime expression"
    span_label: ClassVar[str] = "Comptime expression requires an argument"
548
+
549
+
550
+ def is_comptime_expression(node: ast.AST) -> ComptimeExpr | None:
551
+ """Checks if the given node is a `comptime(...)` expression and turns it into
552
+ a `ComptimeExpr` AST node.
553
+
554
+ Also accepts the `py(...)` alias for `comptime` expressions.
555
+
556
+ Otherwise, returns `None`.
557
+ """
558
+ if (
559
+ isinstance(node, ast.Call)
560
+ and isinstance(node.func, ast.Name)
561
+ and node.func.id in ("py", "comptime")
562
+ ):
563
+ match node.args:
564
+ case []:
565
+ raise GuppyError(EmptyComptimeExprError(node))
566
+ case [arg]:
567
+ pass
568
+ case args:
569
+ arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load))
570
+ return with_loc(node, ComptimeExpr(value=arg))
571
+ return None
572
+
573
+
574
def is_short_circuit_expr(node: ast.AST) -> bool:
    """Returns whether `node` is evaluated with short-circuiting.

    That is the case for boolean operations (`and`/`or`) and for chained comparisons
    such as `a < b < c`. Those expressions *must* be compiled using the
    `BranchBuilder`.
    """
    if isinstance(node, ast.BoolOp):
        return True
    return isinstance(node, ast.Compare) and len(node.comparators) > 1
582
+
583
+
584
def is_illegal_in_list_comp(node: ast.AST) -> bool:
    """Returns whether `node` may not appear inside a list comprehension.

    Conditional expressions, assignment expressions, and short-circuiting
    expressions are rejected.
    """
    if isinstance(node, (ast.IfExp, ast.NamedExpr)):
        return True
    return is_short_circuit_expr(node)
587
+
588
+
589
+ def make_var(name: str, loc: ast.AST | None = None) -> ast.Name:
590
+ """Creates an `ast.Name` node."""
591
+ node = ast.Name(id=name, ctx=ast.Load)
592
+ if loc is not None:
593
+ set_location_from(node, loc)
594
+ return node
595
+
596
+
597
def make_assign(lhs: list[ast.AST], value: ast.expr) -> ast.Assign:
    """Creates an `ast.Assign` node that assigns `value` to the targets in `lhs`.

    All targets are converted to store context. A single target is assigned
    directly, while multiple targets are packed into a tuple target. The new nodes
    take their source location from `value`.
    """
    assert lhs
    to_store = ContextAdjuster(ast.Store())
    targets = [to_store.visit(expr) for expr in lhs]
    target = (
        targets[0]
        if len(targets) == 1
        else with_loc(value, ast.Tuple(elts=targets, ctx=ast.Store()))
    )
    return with_loc(value, ast.Assign(targets=[target], value=value))