guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,447 @@
1
+ """Type checking code for statements.
2
+
3
+ Operates on statements in a basic block after CFG construction. In particular, we
4
+ assume that statements involving control flow (i.e. if, while, break, and continue
5
+ statements) have been removed during CFG construction.
6
+
7
+ After checking, we return a desugared statement where all sub-expressions have been type
8
+ annotated.
9
+ """
10
+
11
+ import ast
12
+ import functools
13
+ from collections.abc import Iterable, Sequence
14
+ from dataclasses import replace
15
+ from itertools import takewhile
16
+ from typing import TypeVar, cast
17
+
18
+ from guppylang_internals.ast_util import (
19
+ AstVisitor,
20
+ get_type,
21
+ with_loc,
22
+ with_type,
23
+ )
24
+ from guppylang_internals.cfg.bb import BB, BBStatement
25
+ from guppylang_internals.cfg.builder import (
26
+ desugar_comprehension,
27
+ make_var,
28
+ tmp_vars,
29
+ )
30
+ from guppylang_internals.checker.core import (
31
+ Context,
32
+ FieldAccess,
33
+ SubscriptAccess,
34
+ Variable,
35
+ )
36
+ from guppylang_internals.checker.errors.generic import UnsupportedError
37
+ from guppylang_internals.checker.errors.type_errors import (
38
+ AssignFieldTypeMismatchError,
39
+ AssignNonPlaceHelp,
40
+ AssignSubscriptTypeMismatchError,
41
+ AttributeNotFoundError,
42
+ MissingReturnValueError,
43
+ StarredTupleUnpackError,
44
+ TypeInferenceError,
45
+ UnpackableError,
46
+ WrongNumberOfUnpacksError,
47
+ )
48
+ from guppylang_internals.checker.expr_checker import (
49
+ ExprChecker,
50
+ ExprSynthesizer,
51
+ check_place_assignable,
52
+ synthesize_comprehension,
53
+ )
54
+ from guppylang_internals.error import GuppyError, GuppyTypeError, InternalGuppyError
55
+ from guppylang_internals.nodes import (
56
+ AnyUnpack,
57
+ DesugaredArrayComp,
58
+ IterableUnpack,
59
+ MakeIter,
60
+ NestedFunctionDef,
61
+ PlaceNode,
62
+ TupleUnpack,
63
+ UnpackPattern,
64
+ )
65
+ from guppylang_internals.span import Span, to_span
66
+ from guppylang_internals.tys.builtin import (
67
+ array_type,
68
+ get_element_type,
69
+ get_iter_size,
70
+ is_array_type,
71
+ is_sized_iter_type,
72
+ nat_type,
73
+ )
74
+ from guppylang_internals.tys.const import ConstValue
75
+ from guppylang_internals.tys.parsing import type_from_ast
76
+ from guppylang_internals.tys.subst import Subst
77
+ from guppylang_internals.tys.ty import (
78
+ ExistentialTypeVar,
79
+ FunctionType,
80
+ NoneType,
81
+ StructType,
82
+ TupleType,
83
+ Type,
84
+ )
85
+
86
+
87
class StmtChecker(AstVisitor[BBStatement]):
    """Type checker for the statements inside a single basic block.

    Each `visit_*` method checks one kind of statement and returns a desugared,
    type-annotated replacement for it. Control-flow statements (`if`, `while`,
    `break`, `continue`) are expected to have been removed during CFG construction
    and raise an internal error if they are encountered here.
    """

    # Checking context; assignments record newly bound variables in `ctx.locals`.
    ctx: Context
    # The basic block being checked. Only required to check nested function defs.
    bb: BB | None
    # Declared return type of the enclosing function. Only required for `return`.
    return_ty: Type | None

    def __init__(
        self, ctx: Context, bb: BB | None = None, return_ty: Type | None = None
    ) -> None:
        # The return type must not contain unsolved variables; `visit_Return` relies
        # on this when it asserts that checking a return value yields no substitution.
        assert not return_ty or not return_ty.unsolved_vars
        self.ctx = ctx
        self.bb = bb
        self.return_ty = return_ty

    def check_stmts(self, stmts: Sequence[BBStatement]) -> list[BBStatement]:
        """Checks the given statements in order and returns the checked versions."""
        return [self.visit(s) for s in stmts]

    def _synth_expr(self, node: ast.expr) -> tuple[ast.expr, Type]:
        """Synthesizes the type of `node`, returning the checked expr and its type."""
        return ExprSynthesizer(self.ctx).synthesize(node)

    def _synth_instance_fun(
        self,
        node: ast.expr,
        args: list[ast.expr],
        func_name: str,
        description: str,
        exp_sig: FunctionType | None = None,
        give_reason: bool = False,
    ) -> tuple[ast.expr, Type]:
        """Synthesizes a call of the instance function `func_name` on `node`.

        Thin wrapper around `ExprSynthesizer.synthesize_instance_func`.
        """
        return ExprSynthesizer(self.ctx).synthesize_instance_func(
            node, args, func_name, description, exp_sig, give_reason
        )

    def _check_expr(
        self, node: ast.expr, ty: Type, kind: str = "expression"
    ) -> tuple[ast.expr, Subst]:
        """Checks `node` against type `ty`.

        Returns the checked expression together with the substitution produced by
        checking. `kind` describes the expression in error messages.
        """
        return ExprChecker(self.ctx).check(node, ty, kind)

    @functools.singledispatchmethod
    def _check_assign(self, lhs: ast.expr, rhs: ast.expr, rhs_ty: Type) -> ast.expr:
        """Helper function to check assignments with patterns.

        Dispatches on the AST node type of the assignment target `lhs`. `rhs` is the
        assigned expression and `rhs_ty` its type. Targets without a registered
        handler are an internal error.
        """
        raise InternalGuppyError("Unexpected assignment pattern")

    @_check_assign.register
    def _check_variable_assign(
        self, lhs: ast.Name, _rhs: ast.expr, rhs_ty: Type
    ) -> PlaceNode:
        """Binds the name `lhs` to a fresh local variable of type `rhs_ty`."""
        x = lhs.id
        var = Variable(x, rhs_ty, lhs)
        self.ctx.locals[x] = var
        return with_loc(lhs, with_type(rhs_ty, PlaceNode(place=var)))

    @_check_assign.register
    def _check_field_assign(
        self, lhs: ast.Attribute, _rhs: ast.expr, rhs_ty: Type
    ) -> PlaceNode:
        """Checks an assignment to a struct field, e.g. `s.a = v`.

        Raises a user error if the value is not a struct with the given field, if the
        field type differs from `rhs_ty`, if the value is not a place, or if the
        field is copyable (mutation of classical fields is unsupported).
        """
        # Unfortunately, the `attr` is just a string, not an AST node, so we
        # have to compute its span by hand. This is fine since linebreaks are
        # not allowed in the identifier following the `.`
        span = to_span(lhs)
        value, attr = lhs.value, lhs.attr
        attr_span = Span(span.end.shift_left(len(attr)), span.end)
        value, struct_ty = self._synth_expr(value)
        if not isinstance(struct_ty, StructType) or attr not in struct_ty.field_dict:
            raise GuppyTypeError(AttributeNotFoundError(attr_span, struct_ty, attr))
        field = struct_ty.field_dict[attr]
        # TODO: In the future, we could infer some type args here
        if field.ty != rhs_ty:
            # TODO: Get hold of a span for the RHS and use a regular `TypeMismatchError`
            # instead (maybe with a custom hint).
            raise GuppyTypeError(AssignFieldTypeMismatchError(attr_span, rhs_ty, field))
        if not isinstance(value, PlaceNode):
            # For now we complain if someone tries to assign to something that is not a
            # place, e.g. `f().a = 4`. This would only make sense if there is another
            # reference to the return value of `f`, otherwise the mutation cannot be
            # observed. We can start supporting this once we have proper reference
            # semantics.
            err = UnsupportedError(value, "Assigning to this expression", singular=True)
            err.add_sub_diagnostic(AssignNonPlaceHelp(None, field))
            raise GuppyError(err)
        if field.ty.copyable:
            raise GuppyError(
                UnsupportedError(
                    attr_span, "Mutation of classical fields", singular=True
                )
            )
        place = FieldAccess(value.place, field, lhs)
        place = check_place_assignable(place, self.ctx, lhs, "assignable")
        return with_loc(lhs, with_type(rhs_ty, PlaceNode(place=place)))

    @_check_assign.register
    def _check_subscript_assign(
        self, lhs: ast.Subscript, rhs: ast.expr, rhs_ty: Type
    ) -> PlaceNode:
        """Checks an assignment to an array element, e.g. `xs[i] = v`.

        Only arrays whose element type equals `rhs_ty` and that are places are
        supported; anything else raises a user error.
        """
        # Check subscript is array subscript.
        value, container_ty = self._synth_expr(lhs.value)
        if not is_array_type(container_ty):
            raise GuppyError(
                UnsupportedError(lhs, "Subscript assignments to non-arrays")
            )

        # Check array element type matches type of RHS.
        element_ty = get_element_type(container_ty)
        if element_ty != rhs_ty:
            raise GuppyTypeError(
                AssignSubscriptTypeMismatchError(to_span(lhs), rhs_ty, element_ty)
            )

        # As with field assignment, only allow place assignments for now.
        if not isinstance(value, PlaceNode):
            raise GuppyError(
                UnsupportedError(value, "Assigning to this expression", singular=True)
            )

        # Create a subscript place; the index is bound to a fresh temporary variable.
        item_expr, item_ty = self._synth_expr(lhs.slice)
        item = Variable(next(tmp_vars), item_ty, item_expr)
        place = SubscriptAccess(value.place, item, rhs_ty, item_expr)

        # Calling `check_place_assignable` makes sure that `__setitem__` is implemented
        place = check_place_assignable(place, self.ctx, lhs, "assignable")
        return with_loc(lhs, with_type(rhs_ty, PlaceNode(place=place)))

    @_check_assign.register
    def _check_tuple_assign(
        self, lhs: ast.Tuple, rhs: ast.expr, rhs_ty: Type
    ) -> AnyUnpack:
        """Checks an unpacking assignment with a tuple target, e.g. `a, b = ...`."""
        return self._check_unpack_assign(lhs, rhs, rhs_ty)

    @_check_assign.register
    def _check_list_assign(
        self, lhs: ast.List, rhs: ast.expr, rhs_ty: Type
    ) -> AnyUnpack:
        """Checks an unpacking assignment with a list target, e.g. `[a, b] = ...`."""
        return self._check_unpack_assign(lhs, rhs, rhs_ty)

    def _check_unpack_assign(
        self, lhs: ast.Tuple | ast.List, rhs: ast.expr, rhs_ty: Type
    ) -> AnyUnpack:
        """Helper function to check unpacking assignments.

        These are the ones where the LHS is either a tuple or a list, e.g.
        `a, *bs, c = rhs`. Nested patterns on the LHS are checked recursively and a
        starred target is bound to an array of the corresponding RHS items.
        """
        # Parse LHS into `left, *starred, right`
        pattern = parse_unpack_pattern(lhs)
        left, starred, right = pattern.left, pattern.starred, pattern.right
        # Check that the RHS has an appropriate type to be unpacked
        unpack, rhs_elts, rhs_tys = self._check_unpackable(rhs, rhs_ty, pattern)

        # Check that the numbers match up on the LHS and RHS
        num_lhs, num_rhs = len(right) + len(left), len(rhs_tys)
        err = WrongNumberOfUnpacksError(
            lhs, num_rhs, num_lhs, at_least=starred is not None
        )
        if num_lhs > num_rhs:
            # Build span that covers the unexpected elts on the LHS
            span = Span(to_span(lhs.elts[num_rhs]).start, to_span(lhs.elts[-1]).end)
            raise GuppyTypeError(replace(err, span=span))
        elif num_lhs < num_rhs and not starred:
            raise GuppyTypeError(err)

        # Recursively check any nested patterns on the left or right
        le, rs = len(left), len(rhs_elts) - len(right)  # left_end, right_start
        unpack.pattern.left = [
            self._check_assign(pat, elt, ty)
            for pat, elt, ty in zip(left, rhs_elts[:le], rhs_tys[:le], strict=True)
        ]
        unpack.pattern.right = [
            self._check_assign(pat, elt, ty)
            for pat, elt, ty in zip(right, rhs_elts[rs:], rhs_tys[rs:], strict=True)
        ]

        # Starred assignments are collected into an array
        if starred:
            starred_tys = rhs_tys[le:rs]
            assert all_equal(starred_tys)
            if starred_tys:
                starred_ty, *_ = starred_tys
            # Starred part could be empty. If it's an iterable unpack, we're still fine
            # since we know the yielded type
            elif isinstance(unpack, IterableUnpack):
                starred_ty = unpack.compr.elt_ty
            # For tuple unpacks, there is no way to infer a type for the empty starred
            # part
            else:
                unsolved = array_type(ExistentialTypeVar.fresh("T", True, True), 0)
                raise GuppyError(TypeInferenceError(starred, unsolved))
            array_ty = array_type(starred_ty, len(starred_tys))
            # `starred` is always an `ast.Name` (see `parse_unpack_pattern`), whose
            # handler does not inspect the RHS expression. We pass the whole RHS since
            # `rhs_elts` is empty when unpacking an iterable of static size 0 (e.g.
            # `*xs, = arr` for a zero-length array), so indexing it would crash.
            unpack.pattern.starred = self._check_assign(starred, rhs, array_ty)

        return with_type(rhs_ty, with_loc(lhs, unpack))

    def _check_unpackable(
        self, expr: ast.expr, ty: Type, pattern: UnpackPattern
    ) -> tuple[AnyUnpack, list[ast.expr], Sequence[Type]]:
        """Checks that the given expression can be used in an unpacking assignment.

        This is the case for expressions with tuple types or ones that are iterable with
        a static size. Also checks that the expression is compatible with the given
        unpacking pattern.

        Returns an AST node capturing the unpacking operation together with expressions
        and types for all unpacked items. Emits a user error if the given expression is
        not unpackable.
        """
        left, starred, right = pattern.left, pattern.starred, pattern.right
        if isinstance(ty, TupleType):
            # Starred assignment of tuples is only allowed if all starred elements have
            # the same type
            if starred:
                starred_tys = (
                    ty.element_types[len(left) : -len(right)]
                    if right
                    else ty.element_types[len(left) :]
                )
                if not all_equal(starred_tys):
                    tuple_ty = TupleType(starred_tys)
                    raise GuppyError(StarredTupleUnpackError(starred, tuple_ty))
            tys = ty.element_types
            # For tuple literals we can use the individual element expressions;
            # otherwise every item refers to the whole RHS expression.
            elts = expr.elts if isinstance(expr, ast.Tuple) else [expr] * len(tys)
            return TupleUnpack(pattern), elts, tys

        elif self.ctx.globals.get_instance_func(ty, "__iter__"):
            size = check_iter_unpack_has_static_size(expr, self.ctx)
            # Create a dummy variable and assign the expression to it. This helps us to
            # wire it up correctly during Hugr generation.
            var = self._check_assign(make_var(next(tmp_vars), expr), expr, ty)
            assert isinstance(var, PlaceNode)
            # We collect the whole RHS into an array. For this, we can reuse the
            # existing array comprehension logic.
            elt = make_var(next(tmp_vars), expr)
            gen = ast.comprehension(target=elt, iter=var, ifs=[], is_async=False)
            [gen], elt = desugar_comprehension([with_loc(expr, gen)], elt, expr)
            # Type check the comprehension
            [gen], elt, elt_ty = synthesize_comprehension(expr, [gen], elt, self.ctx)
            compr = DesugaredArrayComp(
                elt, gen, length=ConstValue(nat_type(), size), elt_ty=elt_ty
            )
            compr = with_type(array_type(elt_ty, size), compr)
            return IterableUnpack(pattern, compr, var), size * [elt], size * [elt_ty]

        # Otherwise, we can't unpack this expression
        raise GuppyError(UnpackableError(expr, ty))

    def visit_Assign(self, node: ast.Assign) -> ast.Assign:
        """Checks a plain assignment `target = value`.

        Multi-target assignments like `a = b = 1` are rejected as unsupported.
        """
        if len(node.targets) > 1:
            # This is the case for assignments like `a = b = 1`
            raise GuppyError(UnsupportedError(node, "Multi assignments"))

        [target] = node.targets
        node.value, ty = self._synth_expr(node.value)
        node.targets = [self._check_assign(target, node.value, ty)]
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt:
        """Checks an annotated assignment `target: ty = value`.

        The value is checked against the annotated type, and the statement is
        desugared into a plain `ast.Assign`. Bare declarations without a value are
        unsupported.
        """
        if node.value is None:
            raise GuppyError(UnsupportedError(node, "Variable declarations"))
        ty = type_from_ast(node.annotation, self.ctx.globals, self.ctx.generic_params)
        node.value, subst = self._check_expr(node.value, ty)
        assert not ty.unsolved_vars  # `ty` must be closed!
        assert len(subst) == 0
        target = self._check_assign(node.target, node.value, ty)
        return with_loc(node, ast.Assign(targets=[target], value=node.value))

    def visit_AugAssign(self, node: ast.AugAssign) -> ast.stmt:
        """Desugars `x op= e` into `x = x op e` and checks the result."""
        bin_op = with_loc(
            node, ast.BinOp(left=node.target, op=node.op, right=node.value)
        )
        assign = with_loc(node, ast.Assign(targets=[node.target], value=bin_op))
        return self.visit_Assign(assign)

    def visit_Expr(self, node: ast.Expr) -> ast.stmt:
        """Checks an expression statement whose value is discarded."""
        # An expression statement where the return value is discarded
        node.value, _ = self._synth_expr(node.value)
        return node

    def visit_Return(self, node: ast.Return) -> ast.stmt:
        """Checks a `return` statement against the declared return type.

        A bare `return` is only allowed if the return type is `None`.
        """
        if not self.return_ty:
            raise InternalGuppyError("return_ty required to check return stmt!")

        if node.value is not None:
            node.value, subst = self._check_expr(
                node.value, self.return_ty, "return value"
            )
            assert len(subst) == 0  # `self.return_ty` is closed!
        elif not isinstance(self.return_ty, NoneType):
            raise GuppyTypeError(MissingReturnValueError(node, self.return_ty))
        return node

    def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> ast.stmt:
        """Checks a nested function definition and binds its name as a local."""
        # Imported locally (presumably to avoid an import cycle with `func_checker`).
        from guppylang_internals.checker.func_checker import check_nested_func_def

        if not self.bb:
            raise InternalGuppyError("BB required to check nested function def!")

        func_def = check_nested_func_def(node, self.bb, self.ctx)
        self.ctx.locals[func_def.name] = Variable(func_def.name, func_def.ty, func_def)
        return func_def

    def visit_If(self, node: ast.If) -> None:
        raise InternalGuppyError("Control-flow statement should not be present here.")

    def visit_While(self, node: ast.While) -> None:
        raise InternalGuppyError("Control-flow statement should not be present here.")

    def visit_Break(self, node: ast.Break) -> None:
        raise InternalGuppyError("Control-flow statement should not be present here.")

    def visit_Continue(self, node: ast.Continue) -> None:
        raise InternalGuppyError("Control-flow statement should not be present here.")
396
+
397
T = TypeVar("T")


def all_equal(xs: Iterable[T]) -> bool:
    """Returns whether every item yielded by `xs` compares equal to the first one.

    An empty iterable counts as all-equal. The iterable is traversed only once and
    iteration stops at the first item that differs.
    """
    iterator = iter(xs)
    missing = object()
    head = next(iterator, missing)
    if head is missing:
        return True
    return all(head == item for item in iterator)
408
+
409
+
410
def parse_unpack_pattern(lhs: ast.Tuple | ast.List) -> UnpackPattern:
    """Splits the target of an unpacking assignment into its three parts.

    Handles targets like `a, *bs, c = ...` or `[a, *bs, c] = ...`. The result holds
    the targets before the starred one, the starred target's inner expression (or
    `None` if there is no starred target), and the targets after it.
    """
    elts = lhs.elts
    # The Python grammar guarantees at most one starred expression per target list,
    # so the first one found (if any) is the split point.
    star_idx = next(
        (idx for idx, elt in enumerate(elts) if isinstance(elt, ast.Starred)),
        len(elts),
    )
    left = elts[:star_idx]
    if star_idx < len(elts):
        starred = cast(ast.Starred, elts[star_idx]).value
    else:
        starred = None
    right = elts[star_idx + 1 :]
    assert isinstance(starred, ast.Name | None), "Python grammar"
    return UnpackPattern(left, starred, right)
424
+
425
+
426
def check_iter_unpack_has_static_size(expr: ast.expr, ctx: Context) -> int:
    """Ensures that an iterable expression can be unpacked in an assignment.

    Unpacking is only possible if the iterator produced from `expr` has a static,
    non-generic size. Returns that size; otherwise raises a `GuppyError` with an
    `UnpackableError` explaining why the expression cannot be unpacked.
    """
    synthesizer = ExprSynthesizer(ctx)
    iter_node = with_loc(expr, MakeIter(expr, expr, unwrap_size_hint=False))
    _, iter_ty = synthesizer.visit_MakeIter(iter_node)
    # Built only after synthesis, since `get_type` reads the type annotation that
    # checking attached to `expr`.
    error = UnpackableError(expr, get_type(expr))

    if not is_sized_iter_type(iter_ty):
        error.add_sub_diagnostic(UnpackableError.NonStaticIter(None))
        raise GuppyError(error)

    size = get_iter_size(iter_ty)
    if isinstance(size, ConstValue) and isinstance(size.value, int):
        return size.value
    error.add_sub_diagnostic(UnpackableError.GenericSize(None, size))
    raise GuppyError(error)
File without changes
@@ -0,0 +1,233 @@
1
+ import functools
2
+ from collections.abc import Sequence
3
+ from typing import cast
4
+
5
+ from hugr import Wire, ops
6
+ from hugr import tys as ht
7
+ from hugr.build import cfg as hc
8
+ from hugr.build.dfg import DP, DfBase
9
+ from hugr.hugr.node_port import ToNode
10
+
11
+ from guppylang_internals.checker.cfg_checker import (
12
+ CheckedBB,
13
+ CheckedCFG,
14
+ Row,
15
+ Signature,
16
+ )
17
+ from guppylang_internals.checker.core import Place, Variable
18
+ from guppylang_internals.compiler.core import (
19
+ CompilerContext,
20
+ DFContainer,
21
+ is_return_var,
22
+ return_var,
23
+ )
24
+ from guppylang_internals.compiler.expr_compiler import ExprCompiler
25
+ from guppylang_internals.compiler.stmt_compiler import StmtCompiler
26
+ from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, read_bool
27
+ from guppylang_internals.tys.ty import SumType, row_to_type, type_to_row
28
+
29
+
30
def compile_cfg(
    cfg: CheckedCFG[Place],
    container: DfBase[DP],
    inputs: Sequence[Wire],
    ctx: CompilerContext,
) -> hc.Cfg:
    """Lowers a checked CFG into a new Hugr CFG node inside `container`.

    The new CFG node is fed with the wires in `inputs`. Every basic block is
    compiled and the control-flow edges between them are added. Returns the builder
    of the created CFG node.
    """
    # Patch the CFG with dummy return variables
    # TODO: This mutates the CFG in-place which leads to problems when trying to lower
    #  the same function to Hugr twice. For now we just check that the return vars
    #  haven't already been inserted, but we should figure out a better way to handle
    #  this: https://github.com/CQCL/guppylang/issues/428
    already_patched = any(
        is_return_var(var.name)
        for var in cfg.exit_bb.sig.input_row
        if isinstance(var, Variable)
    )
    if not already_patched:
        insert_return_vars(cfg)

    cfg_builder = container.add_cfg(*inputs)

    # Hugr can't infer the output types when the exit is unreachable, so we set them
    # explicitly.
    # TODO: Use proper API for this once it's added in hugr-py:
    #  https://github.com/CQCL/hugr/issues/1816
    output_tys = [p.ty.to_hugr(ctx) for p in cfg.exit_bb.sig.input_row]
    cfg_builder._exit_op._cfg_outputs = output_tys
    cfg_builder.parent_op._outputs = output_tys
    cfg_builder.parent_node = cfg_builder.hugr._update_node_outs(
        cfg_builder.parent_node, len(output_tys)
    )

    # Compile all blocks first so that every branch target exists before wiring.
    nodes: dict[CheckedBB[Place], ToNode] = {
        bb: compile_bb(bb, cfg_builder, bb == cfg.entry_bb, ctx) for bb in cfg.bbs
    }
    for bb in cfg.bbs:
        for idx, succ in enumerate(bb.successors):
            cfg_builder.branch(nodes[bb][idx], nodes[succ])

    return cfg_builder
70
+
71
+
72
def compile_bb(
    bb: CheckedBB[Place],
    builder: hc.Cfg,
    is_entry: bool,
    ctx: CompilerContext,
) -> ToNode:
    """Compiles a single basic block to Hugr.

    Adds a block node for `bb` to the CFG `builder`, compiles its statements (and,
    if it has several successors, its branch predicate) into it, and sets the block
    outputs. Returns the node created for the block. If `bb` is the exit block, no
    new node is created and the builder's exit node is returned instead.
    """
    # The exit BB is completely empty
    if bb.is_exit:
        assert len(bb.statements) == 0
        return builder.exit

    # Unreachable BBs (besides the exit) should have been removed by now
    assert bb.reachable

    # Otherwise, we use a regular `Block` node
    block: hc.Block
    inputs: Sequence[Place]
    if is_entry:
        # The entry block takes its inputs in signature order (not sorted).
        inputs = bb.sig.input_row
        block = builder.add_entry()
    else:
        # Non-entry blocks take their inputs in `sort_vars` order, matching the sorted
        # outputs that non-exit predecessors emit below.
        inputs = sort_vars(bb.sig.input_row)
        block = builder.add_block(*(v.ty.to_hugr(ctx) for v in inputs))

    # Add input node and compile the statements
    dfg = DFContainer(block, ctx)
    for v, wire in zip(inputs, block.input_node, strict=True):
        dfg[v] = wire
    dfg = StmtCompiler(ctx).compile_stmts(bb.statements, dfg)

    # If we branch, we also have to compile the branch predicate
    if len(bb.successors) > 1:
        assert bb.branch_pred is not None
        branch_port = ExprCompiler(ctx).compile(bb.branch_pred, dfg)
        # Convert the bool predicate into a sum for branching.
        pred_ty = builder.hugr.port_type(branch_port.out_port())
        assert pred_ty == OpaqueBool
        branch_port = dfg.builder.add_op(read_bool(), branch_port)
        branch_port = cast(Wire, branch_port)
    else:
        # Even if we don't branch, we still have to add a `Sum(())` predicates
        branch_port = dfg.builder.add_op(ops.Tag(0, ht.UnitSum(1)))

    # Finally, we have to add the block output.
    outputs: Sequence[Place]
    if len(bb.successors) == 1:
        # The easy case is if we don't branch: We just output all variables that are
        # specified by the signature
        [outputs] = bb.sig.output_rows
    else:
        # CFG building ensures that branching BBs don't branch to the exit (exit jumps
        # must always be unconditional)
        assert not any(succ.is_exit for succ in bb.successors)

        # If we branch and the branches use the same places, then we can use a
        # regular output
        first, *rest = bb.sig.output_rows
        if all({p.id for p in first} == {p.id for p in r} for r in rest):
            outputs = first
        else:
            # Otherwise, we have to output a TupleSum: We put all non-linear variables
            # into the branch TupleSum and all linear variables in the normal output
            # (since they are shared between all successors). This is in line with the
            # ordering on variables which puts linear variables at the end.
            # We don't need to worry about the order of return vars since this isn't
            # a branch to an exit (see assert above).
            branch_port = choose_vars_for_tuple_sum(
                unit_sum=branch_port,
                output_vars=[
                    [v for v in sort_vars(row) if v.ty.droppable]
                    for row in bb.sig.output_rows
                ],
                dfg=dfg,
            )
            outputs = [v for v in first if not v.ty.droppable]

    # If this is *not* a jump to the exit BB, we need to sort the outputs to make the
    # signature consistent with what the next BB expects
    if not any(succ.is_exit for succ in bb.successors):
        outputs = sort_vars(outputs)
    else:
        # Exit variables are not allowed to be sorted since their order corresponds to
        # the function outputs
        assert len(bb.successors) == 1, "Exit jumps are always unconditional"

    block.set_block_outputs(branch_port, *(dfg[v] for v in outputs))
    return block
163
+
164
+
165
def insert_return_vars(cfg: CheckedCFG[Place]) -> None:
    """Adds dummy return variables to the signatures around the exit BB.

    The statement compiler lowers `return` statements into assignments to dummy
    variables `%ret0`, `%ret1`, etc. This prepends those variables to the exit BB's
    input row and to the single output row of each of its predecessors, so that the
    returned values are passed on to the CFG outputs.
    """
    ret_vars = [
        Variable(return_var(idx), ret_ty, None)
        for idx, ret_ty in enumerate(type_to_row(cfg.output_ty))
    ]

    # The return variables come first in the exit signature
    exit_sig = cfg.exit_bb.sig
    cfg.exit_bb.sig = Signature(
        [*ret_vars, *exit_sig.input_row], exit_sig.output_rows
    )

    for pred in cfg.exit_bb.predecessors:
        # Jumps to the exit are unconditional, so the exit is the only successor
        assert len(pred.sig.output_rows) == 1
        (row,) = pred.sig.output_rows
        pred.sig = Signature(pred.sig.input_row, [[*ret_vars, *row]])
186
+
187
+
188
def choose_vars_for_tuple_sum(
    unit_sum: Wire, output_vars: list[Row[Place]], dfg: DFContainer
) -> Wire:
    """Builds a TupleSum value selecting one row of output variables.

    Given `unit_sum: Sum(*(), *(), ...)` and output variable rows `#s1, #s2, ...`,
    emits a conditional that produces a value of type `Sum(#s1, #s2, ...)`. The
    n-th case packs the variables of row n under tag n. All variables must be
    droppable.
    """
    assert all(v.ty.droppable for row in output_vars for v in row)
    sum_type = SumType(
        [row_to_type([v.ty for v in row]) for row in output_vars]
    ).to_hugr(dfg.ctx)

    # We pass all values into the conditional instead of relying on non-local edges.
    # This is because we can't handle them in lower parts of the stack yet :/
    # TODO: Reinstate use of non-local edges.
    #  See https://github.com/CQCL/guppylang/issues/963
    wire_by_id = {v.id: dfg[v] for row in output_vars for v in row}
    position = {var_id: idx for idx, var_id in enumerate(wire_by_id)}

    with dfg.builder.add_conditional(unit_sum, *wire_by_id.values()) as cond:
        for case_idx, row in enumerate(output_vars):
            with cond.add_case(case_idx) as case:
                case_inputs = case.inputs()
                picked = [case_inputs[position[v.id]] for v in row]
                case.set_outputs(case.add_op(ops.Tag(case_idx, sum_type), *picked))
    return cond
216
+
217
+
218
def compare_var(p1: Place, p2: Place) -> int:
    """Three-way comparison defining the order in which places leave a basic block.

    Non-droppable (linear) places must be output at the end, so places are ordered
    lexicographically by `(not droppable, str(place))`. Returns -1 if `p1` sorts
    before `p2`, 1 if it sorts after, and 0 if both have the same sort key.
    """
    key1 = (not p1.ty.droppable, str(p1))
    key2 = (not p2.ty.droppable, str(p2))
    # Standard cmp idiom. Returning 0 for equal keys is required by the
    # `functools.cmp_to_key` protocol; otherwise e.g. `k <= k` would be False.
    return (key1 > key2) - (key1 < key2)
226
+
227
+
228
def sort_vars(row: Row[Place]) -> list[Place]:
    """Returns the places of `row` in the canonical basic-block output order.

    The order is the one defined by `compare_var`; `row` itself is not modified.
    """
    ordered = list(row)
    ordered.sort(key=functools.cmp_to_key(compare_var))
    return ordered