guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98):
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,97 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from hugr import Wire
4
+ from hugr import tys as ht
5
+ from hugr.build.function import Function
6
+
7
+ from guppylang_internals.compiler.cfg_compiler import compile_cfg
8
+ from guppylang_internals.compiler.core import CompilerContext, DFContainer
9
+ from guppylang_internals.compiler.hugr_extension import PartialOp
10
+ from guppylang_internals.nodes import CheckedNestedFunctionDef
11
+
12
+ if TYPE_CHECKING:
13
+ from guppylang_internals.definition.function import CheckedFunctionDef
14
+
15
+
16
def compile_global_func_def(
    func: "CheckedFunctionDef",
    builder: Function,
    ctx: CompilerContext,
) -> None:
    """Lowers a checked top-level function definition into the given Hugr builder.

    The function's CFG is compiled with the builder's input wires as its arguments,
    and the wires produced by the CFG become the outputs of the Hugr function.
    """
    inputs = builder.inputs()
    outputs = compile_cfg(func.cfg, builder, inputs, ctx)
    builder.set_outputs(*outputs)
24
+
25
+
26
def compile_local_func_def(
    func: CheckedNestedFunctionDef,
    dfg: DFContainer,
    ctx: CompilerContext,
) -> Wire:
    """Compiles a local (nested) function definition to Hugr and loads it into a value.

    The nested function is defined as a Hugr function at the module root. Its inputs
    are the captured variables followed by the function's own arguments. The function
    is then loaded into `dfg`. If it captures variables, the loaded function is
    partially applied (via `PartialOp`) to the current values of those variables in
    `dfg`.

    Returns the wire carrying the function value: the output of the `LoadFunc`
    operation, or, if variables are captured, the output of the `PartialOp` applied
    to it.
    """
    assert func.ty.input_names is not None

    # Pick an order for the captured variables. The same order is used for the
    # leading inputs of the Hugr function and for the partial applications below.
    captured = list(func.captured.values())
    captured_types = [v.ty.to_hugr(ctx) for v, _ in captured]

    # Whether the function calls itself recursively (i.e. its own name is live on
    # entry to the body's CFG).
    recursive = func.name in func.cfg.live_before[func.cfg.entry_bb]

    # Prepend captured variables to the function arguments
    func_ty = func.ty.to_hugr(ctx)
    closure_ty = ht.FunctionType([*captured_types, *func_ty.input], func_ty.output)
    # The closure is defined at the module root rather than inside `dfg`
    func_builder = dfg.builder.module_root_builder().define_function(
        func.name, closure_ty.input, closure_ty.output
    )

    # Nested functions are not generic, so no need to worry about monomorphization
    mono_args = ()

    # If we have captured variables and the body contains a recursive occurrence of
    # the function itself, then we provide the partially applied function as a local
    # variable
    call_args: list[Wire] = list(func_builder.inputs())
    if len(captured) > 0 and recursive:
        # Load the function inside its own body and partially apply it to the
        # captured inputs (the first `len(captured)` inputs of the function)
        loaded = func_builder.load_function(func_builder, closure_ty)
        partial = func_builder.add_op(
            PartialOp.from_closure(closure_ty, captured_types),
            loaded,
            *func_builder.input_node[: len(captured)],
        )

        # Pass the partial application as an extra trailing CFG input and register
        # its type on the CFG. NOTE: this mutates `func.cfg.input_tys` in place.
        call_args.append(partial)
        func.cfg.input_tys.append(func.ty)

        # Compile the CFG
        cfg = compile_cfg(func.cfg, func_builder, call_args, ctx)
        func_builder.set_outputs(*cfg)
    else:
        # Otherwise, we treat the function like a normal global variable.
        # (Imported locally; presumably to avoid an import cycle, since the module
        # only imports `definition.function` under TYPE_CHECKING at the top.)
        from guppylang_internals.definition.function import CompiledFunctionDef

        ctx.compiled[func.def_id, mono_args] = CompiledFunctionDef(
            func.def_id,
            func.name,
            func,
            mono_args,
            func.ty,
            None,
            func.cfg,
            func_builder,
        )
        ctx.worklist[func.def_id, mono_args] = None  # will compile the CFG later

    # Finally, load the function into the local data-flow graph
    loaded = dfg.builder.load_function(func_builder, closure_ty)
    if len(captured) > 0:
        # Bind the captured inputs to the current values of the captured variables
        loaded = dfg.builder.add_op(
            PartialOp.from_closure(closure_ty, captured_types),
            loaded,
            *(dfg[v] for v, _ in captured),
        )

    return loaded
@@ -0,0 +1,224 @@
1
+ """A hugr extension with guppy-specific operations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import TYPE_CHECKING
7
+
8
+ import hugr.ext as he
9
+ import hugr.tys as ht
10
+ from hugr import ops
11
+
12
+ if TYPE_CHECKING:
13
+ from collections.abc import Iterator, Sequence
14
+
15
# The Hugr extension that holds Guppy-specific operations.
EXTENSION: he.Extension = he.Extension("guppylang", he.Version(0, 1, 0))


# Op definition for partial function application. Its three list parameters are,
# in order: the captured input types (row variable 0), the remaining non-captured
# input types (row variable 1), and the output types (row variable 2).
# The description below names these rows `a`, `b` and `c` respectively.
PARTIAL_OP_DEF: he.OpDef = EXTENSION.add_op_def(
    he.OpDef(
        "partial",
        signature=he.OpDefSig(
            poly_func=ht.PolyFuncType(
                params=[
                    # Captured input types
                    ht.ListParam(ht.TypeTypeParam(ht.TypeBound.Linear)),
                    # Non-captured input types
                    ht.ListParam(ht.TypeTypeParam(ht.TypeBound.Linear)),
                    # Output types
                    ht.ListParam(ht.TypeTypeParam(ht.TypeBound.Linear)),
                ],
                body=ht.FunctionType(
                    input=[
                        # The function to partially apply: (*a, *b -> *c)
                        ht.FunctionType(
                            input=[
                                ht.RowVariable(0, ht.TypeBound.Linear),
                                ht.RowVariable(1, ht.TypeBound.Linear),
                            ],
                            output=[ht.RowVariable(2, ht.TypeBound.Linear)],
                        ),
                        # The captured values: *a
                        ht.RowVariable(0, ht.TypeBound.Linear),
                    ],
                    output=[
                        # The partially applied function: (*b -> *c)
                        ht.FunctionType(
                            input=[ht.RowVariable(1, ht.TypeBound.Linear)],
                            output=[ht.RowVariable(2, ht.TypeBound.Linear)],
                        ),
                    ],
                ),
            )
        ),
        # Kept consistent with the parameter order above:
        # [*a] = captured, [*b] = non-captured, [*c] = outputs.
        description="A partial application of a function."
        " Given arguments [*a],[*b],[*c], represents an operation with type"
        " `(*a, *b -> *c), *a -> (*b -> *c)`",
    )
)
56
+
57
+
58
+ @dataclass
59
+ class PartialOp(ops.AsExtOp):
60
+ """An operation that partially evaluates a function.
61
+
62
+ args:
63
+ captured_inputs: A list of input types `c_0, ..., c_k` to partially apply.
64
+ other_inputs: A list of input types `a_0, ..., a_n` not partially applied.
65
+ outputs: The output types `b_0, ..., b_m` of the partially applied function.
66
+
67
+ returns:
68
+ An operation with type
69
+ ` (c_0, ..., c_k, a_0, ..., a_n -> b_0, ..., b_m ), c_0, ..., c_k`
70
+ `-> (a_0, ..., a_n -> b_0, ..., b_m)`
71
+ """
72
+
73
+ captured_inputs: list[ht.Type]
74
+ other_inputs: list[ht.Type]
75
+ outputs: list[ht.Type]
76
+
77
+ @classmethod
78
+ def from_closure(
79
+ cls, closure_ty: ht.FunctionType, captured_tys: Sequence[ht.Type]
80
+ ) -> PartialOp:
81
+ """An operation that partially evaluates a function.
82
+
83
+ args:
84
+ closure_ty: A function `(c_0, ..., c_k, a_0, ..., a_n) -> b_0, ..., b_m`
85
+ captured_tys: A list `c_0, ..., c_k` of types captured by the function
86
+
87
+ returns:
88
+ An operation with type
89
+ ` (c_0, ..., c_k, a_0, ..., a_n -> b_0, ..., b_m ), c_0, ..., c_k`
90
+ `-> (a_0, ..., a_n -> b_0, ..., b_m)`
91
+ """
92
+ assert len(closure_ty.input) >= len(captured_tys)
93
+ assert captured_tys == closure_ty.input[: len(captured_tys)]
94
+
95
+ other_inputs = closure_ty.input[len(captured_tys) :]
96
+ return cls(
97
+ captured_inputs=list(captured_tys),
98
+ other_inputs=list(other_inputs),
99
+ outputs=list(closure_ty.output),
100
+ )
101
+
102
+ def op_def(self) -> he.OpDef:
103
+ return PARTIAL_OP_DEF
104
+
105
+ def type_args(self) -> list[ht.TypeArg]:
106
+ captured_args: list[ht.TypeArg] = [
107
+ ht.TypeTypeArg(ty) for ty in self.captured_inputs
108
+ ]
109
+ other_args: list[ht.TypeArg] = [ht.TypeTypeArg(ty) for ty in self.other_inputs]
110
+ output_args: list[ht.TypeArg] = [ht.TypeTypeArg(ty) for ty in self.outputs]
111
+ return [
112
+ ht.ListArg(captured_args),
113
+ ht.ListArg(other_args),
114
+ ht.ListArg(output_args),
115
+ ]
116
+
117
+ def cached_signature(self) -> ht.FunctionType | None:
118
+ closure_ty = ht.FunctionType(
119
+ [*self.captured_inputs, *self.other_inputs],
120
+ self.outputs,
121
+ )
122
+ partial_fn_ty = ht.FunctionType(self.other_inputs, closure_ty.output)
123
+ return ht.FunctionType([closure_ty, *self.captured_inputs], [partial_fn_ty])
124
+
125
+ @classmethod
126
+ def from_ext(cls, custom: ops.ExtOp) -> PartialOp:
127
+ match custom:
128
+ case ops.ExtOp(
129
+ _op_def=op_def, args=[captured_args, other_args, output_args]
130
+ ):
131
+ if op_def.qualified_name() == PARTIAL_OP_DEF.qualified_name():
132
+ return cls(
133
+ captured_inputs=[*_arg_seq_to_types(captured_args)],
134
+ other_inputs=[*_arg_seq_to_types(other_args)],
135
+ outputs=[*_arg_seq_to_types(output_args)],
136
+ )
137
+ msg = f"Invalid custom op: {custom}"
138
+ raise ops.AsExtOp.InvalidExtOp(msg)
139
+
140
+ @property
141
+ def num_out(self) -> int:
142
+ return 1
143
+
144
+
145
# Op definition for stubs of operations that Guppy cannot lower. It is parametrised
# by the operation's name (parameter 0), its input row (parameter 1) and its output
# row (parameter 2); the signature simply maps row 1 to row 2.
UNSUPPORTED_OP_DEF: he.OpDef = EXTENSION.add_op_def(
    he.OpDef(
        "unsupported",
        description="An unsupported operation stub emitted by Guppy.",
        signature=he.OpDefSig(
            poly_func=ht.PolyFuncType(
                params=[
                    ht.StringParam(),  # 0: operation name
                    ht.ListParam(ht.TypeTypeParam(ht.TypeBound.Linear)),  # 1: inputs
                    ht.ListParam(ht.TypeTypeParam(ht.TypeBound.Linear)),  # 2: outputs
                ],
                body=ht.FunctionType(
                    output=[ht.RowVariable(2, ht.TypeBound.Linear)],
                    input=[ht.RowVariable(1, ht.TypeBound.Linear)],
                ),
            ),
        ),
    ),
)
167
+
168
+
169
+ @dataclass
170
+ class UnsupportedOp(ops.AsExtOp):
171
+ """An unsupported operation stub emitted by Guppy.
172
+
173
+ args:
174
+ op_name: The name of the unsupported operation.
175
+ inputs: The input types of the operation.
176
+ outputs: The output types of the operation.
177
+ """
178
+
179
+ op_name: str
180
+ inputs: list[ht.Type]
181
+ outputs: list[ht.Type]
182
+
183
+ def op_def(self) -> he.OpDef:
184
+ return UNSUPPORTED_OP_DEF
185
+
186
+ def type_args(self) -> list[ht.TypeArg]:
187
+ op_name = ht.StringArg(self.op_name)
188
+ input_args = ht.ListArg([ht.TypeTypeArg(ty) for ty in self.inputs])
189
+ output_args = ht.ListArg([ht.TypeTypeArg(ty) for ty in self.outputs])
190
+ return [op_name, input_args, output_args]
191
+
192
+ def cached_signature(self) -> ht.FunctionType | None:
193
+ return ht.FunctionType(self.inputs, self.outputs)
194
+
195
+ @classmethod
196
+ def from_ext(cls, custom: ops.ExtOp) -> UnsupportedOp:
197
+ match custom:
198
+ case ops.ExtOp(_op_def=op_def, args=args):
199
+ if op_def.qualified_name() == UNSUPPORTED_OP_DEF.qualified_name():
200
+ [op_name, input_args, output_args] = args
201
+ assert isinstance(op_name, ht.StringArg), (
202
+ "The first argument to a guppylang.unsupported op "
203
+ "must be the operation name"
204
+ )
205
+ op_name = op_name.value
206
+ return cls(
207
+ op_name=op_name,
208
+ inputs=[*_arg_seq_to_types(input_args)],
209
+ outputs=[*_arg_seq_to_types(output_args)],
210
+ )
211
+ msg = f"Invalid custom op: {custom}"
212
+ raise ops.AsExtOp.InvalidExtOp(msg)
213
+
214
+ @property
215
+ def num_out(self) -> int:
216
+ return len(self.outputs)
217
+
218
+
219
def _arg_seq_to_types(args: ht.TypeArg) -> Iterator[ht.Type]:
    """Lazily yields the type wrapped by each element of a list type argument.

    `args` is expected to be a `ht.ListArg` whose elements are all `ht.TypeTypeArg`s.
    Both expectations are checked with assertions as the generator is consumed.
    """
    assert isinstance(args, ht.ListArg)
    for type_arg in args.elems:
        assert isinstance(type_arg, ht.TypeTypeArg)
        yield type_arg.ty
File without changes
@@ -0,0 +1,212 @@
1
+ import ast
2
+ import functools
3
+ from collections.abc import Sequence
4
+
5
+ import hugr.tys as ht
6
+ from hugr import Wire, ops
7
+ from hugr.build.dfg import DfBase
8
+
9
+ from guppylang_internals.ast_util import AstVisitor, get_type
10
+ from guppylang_internals.checker.core import Variable, contains_subscript
11
+ from guppylang_internals.compiler.core import (
12
+ CompilerBase,
13
+ CompilerContext,
14
+ DFContainer,
15
+ return_var,
16
+ )
17
+ from guppylang_internals.compiler.expr_compiler import ExprCompiler
18
+ from guppylang_internals.error import InternalGuppyError
19
+ from guppylang_internals.nodes import (
20
+ CheckedNestedFunctionDef,
21
+ IterableUnpack,
22
+ PlaceNode,
23
+ TupleUnpack,
24
+ )
25
+ from guppylang_internals.std._internal.compiler.array import (
26
+ array_discard_empty,
27
+ array_new,
28
+ array_pop,
29
+ )
30
+ from guppylang_internals.std._internal.compiler.prelude import build_unwrap
31
+ from guppylang_internals.tys.builtin import get_element_type
32
+ from guppylang_internals.tys.const import ConstValue
33
+ from guppylang_internals.tys.ty import TupleType, Type, type_to_row
34
+
35
+
36
class StmtCompiler(CompilerBase, AstVisitor[None]):
    """A compiler for Guppy statements to Hugr.

    Statements are compiled into the data-flow container `self.dfg`. Expressions
    occurring inside them are delegated to an `ExprCompiler`.
    """

    # Compiler used for the expressions that occur inside statements
    expr_compiler: ExprCompiler

    # The data-flow container currently being compiled into; set by `compile_stmts`
    dfg: DFContainer

    def __init__(self, ctx: CompilerContext):
        super().__init__(ctx)
        self.expr_compiler = ExprCompiler(ctx)

    def compile_stmts(
        self,
        stmts: Sequence[ast.stmt],
        dfg: DFContainer,
    ) -> DFContainer:
        """Compiles a list of basic statements into a dataflow node.

        Note that the `dfg` is mutated in-place. After compilation, the DFG will also
        contain all variables that are assigned in the given list of statements.

        Returns `self.dfg`, i.e. the container that was passed in.
        """
        self.dfg = dfg
        for s in stmts:
            self.visit(s)
        return self.dfg

    @property
    def builder(self) -> DfBase[ops.DfParentOp]:
        """The Hugr dataflow graph builder."""
        return self.dfg.builder

    @functools.singledispatchmethod
    def _assign(self, lhs: ast.expr, port: Wire) -> None:
        """Updates the local DFG with assignments.

        Dispatches on the type of the assignment target `lhs` (`PlaceNode`,
        `TupleUnpack` or `IterableUnpack`) and binds the value carried by `port`
        to it. Any other target type is an internal compiler error.
        """
        raise InternalGuppyError("Invalid assign pattern in compiler")

    @_assign.register
    def _assign_place(self, lhs: PlaceNode, port: Wire) -> None:
        """Assigns `port` to a place such as `x`, `x.y`, `xs[i]` or `xs[i].y`.

        If the place contains a subscript, the value is written back through the
        subscript's `__setitem__` call.
        """
        if subscript := contains_subscript(lhs.place):
            assert subscript.setitem_call is not None
            # Compile the index expression unless its value is already bound
            if subscript.item not in self.dfg:
                self.dfg[subscript.item] = self.expr_compiler.compile(
                    subscript.item_expr, self.dfg
                )
            # If the subscript is nested inside the place, e.g. `xs[i].y = ...`, we
            # first need to lookup `tmp = xs[i]`, assign `tmp.y = ...`, and then finally
            # set `xs[i] = tmp`
            if subscript != lhs.place:
                assert subscript.getitem_call is not None
                # Instead of `tmp` just use `xs[i]` as a "name", the dfg tracker doesn't
                # care about this
                self.dfg[subscript] = self.expr_compiler.compile(
                    subscript.getitem_call, self.dfg
                )
            # Assign to the name `xs[i].y`
            self.dfg[lhs.place] = port
            # Look up `xs[i]` again since it was mutated by the assignment above, then
            # compile a call to `__setitem__` to actually mutate
            self.dfg[subscript.setitem_call.value_var] = self.dfg[subscript]
            self.expr_compiler.visit(subscript.setitem_call.call)
        else:
            # Plain place without subscripts: just rebind it in the DFG
            self.dfg[lhs.place] = port

    @_assign.register
    def _assign_tuple(self, lhs: TupleUnpack, port: Wire) -> None:
        """Handles assignment where the RHS is a tuple that should be unpacked."""
        # Unpack the RHS tuple
        left, starred, right = lhs.pattern.left, lhs.pattern.starred, lhs.pattern.right
        types = [ty.to_hugr(self.ctx) for ty in type_to_row(get_type(lhs))]
        unpack = self.builder.add_op(ops.UnpackTuple(types), port)
        ports = list(unpack)

        # Assign left and right
        for pat, wire in zip(left, ports[: len(left)], strict=True):
            self._assign(pat, wire)
        # Guarded because `ports[-0:]` would select *all* ports when `right` is empty
        if right:
            for pat, wire in zip(right, ports[-len(right) :], strict=True):
                self._assign(pat, wire)

        # Starred assignments are collected into an array
        if starred:
            array_ty = get_type(starred)
            starred_ports = (
                ports[len(left) : -len(right)] if right else ports[len(left) :]
            )
            elt = get_element_type(array_ty).to_hugr(self.ctx)
            # Each element is wrapped in `Some`, since the array is built with
            # element type `Option(elt)`
            opts = [self.builder.add_op(ops.Some(elt), p) for p in starred_ports]
            array = self.builder.add_op(array_new(ht.Option(elt), len(opts)), *opts)
            self._assign(starred, array)

    @_assign.register
    def _assign_iterable(self, lhs: IterableUnpack, port: Wire) -> None:
        """Handles assignment where the RHS is an iterable that should be unpacked."""
        # Given an assignment pattern `left, *starred, right`, collect the RHS into an
        # array and pop from the left and right, leaving us with the starred array in
        # the middle
        assert isinstance(lhs.compr.length, ConstValue)
        length = lhs.compr.length.value
        assert isinstance(length, int)
        opt_elt_ty = ht.Option(lhs.compr.elt_ty.to_hugr(self.ctx))

        def pop(
            array: Wire, length: int, pats: list[ast.expr], from_left: bool
        ) -> tuple[Wire, int]:
            """Pops `len(pats)` elements from one end of `array` and assigns them.

            `length` is the current length of `array`. Returns the remaining array
            together with its new length.
            """
            err = "Internal error: unpacking of iterable failed"
            num_pats = len(pats)
            # Pop the number of requested elements from the array
            elts = []
            for i in range(num_pats):
                res = self.builder.add_op(
                    array_pop(opt_elt_ty, length - i, from_left), array
                )
                # The first unwrap yields the popped element and the shrunken
                # array; the element itself is stored as an `Option`, so it is
                # unwrapped a second time.
                [elt_opt, array] = build_unwrap(self.builder, res, err)
                [elt] = build_unwrap(self.builder, elt_opt, err)
                elts.append(elt)
            # Assign elements to the given patterns
            for pat, elt in zip(
                pats,
                # Assignments are evaluated from left to right, so we need to assign in
                # reverse order if we popped from the right
                elts if from_left else reversed(elts),
                strict=True,
            ):
                self._assign(pat, elt)
            return array, length - num_pats

        # Bind the RHS value to its temporary variable; presumably this is what the
        # desugared array comprehension `lhs.compr` reads from (verify in checker).
        self.dfg[lhs.rhs_var.place] = port
        array = self.expr_compiler.visit_DesugaredArrayComp(lhs.compr)
        array, length = pop(array, length, lhs.pattern.left, True)
        array, length = pop(array, length, lhs.pattern.right, False)
        if lhs.pattern.starred:
            self._assign(lhs.pattern.starred, array)
        else:
            # Without a starred target every element must have been consumed, so
            # the remaining (empty) array is discarded explicitly
            assert length == 0
            self.builder.add_op(array_discard_empty(opt_elt_ty), array)

    def visit_Assign(self, node: ast.Assign) -> None:
        # Only single-target assignments are expected here; the destructuring
        # below fails for anything else
        [target] = node.targets
        port = self.expr_compiler.compile(node.value, self.dfg)
        self._assign(target, port)

    def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
        # Annotated assignments without a value (e.g. `x: int`) are not expected here
        assert node.value is not None
        port = self.expr_compiler.compile(node.value, self.dfg)
        self._assign(node.target, port)

    def visit_AugAssign(self, node: ast.AugAssign) -> None:
        raise InternalGuppyError("Node should have been removed during type checking.")

    def visit_Expr(self, node: ast.Expr) -> None:
        # Expression statement: compile it; the resulting wires are not bound to
        # any variable
        self.expr_compiler.compile_row(node.value, self.dfg)

    def visit_Return(self, node: ast.Return) -> None:
        # We turn returns into assignments of dummy variables, i.e. the statement
        # `return e0, e1, e2` is turned into `%ret0 = e0; %ret1 = e1; %ret2 = e2`.
        # A bare `return` binds nothing.
        if node.value is not None:
            return_ty = get_type(node.value)
            port = self.expr_compiler.compile(node.value, self.dfg)

            row: list[tuple[Wire, Type]]
            if isinstance(return_ty, TupleType):
                # Split tuple returns so that each component gets its own variable
                types = [e.to_hugr(self.ctx) for e in return_ty.element_types]
                unpack = self.builder.add_op(ops.UnpackTuple(types), port)
                row = list(zip(unpack, return_ty.element_types, strict=True))
            else:
                row = [(port, return_ty)]

            for i, (wire, ty) in enumerate(row):
                var = Variable(return_var(i), ty, node.value)
                self.dfg[var] = wire

    def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None:
        # Imported locally (presumably to avoid a circular import)
        from guppylang_internals.compiler.func_compiler import compile_local_func_def

        # Bind the loaded (possibly partially applied) function value to a local
        # variable named after the nested function
        var = Variable(node.name, node.ty, node)
        loaded_func = compile_local_func_def(node, self.dfg, self.ctx)
        self.dfg[var] = loaded_func