guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98) hide show
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,989 @@
1
+ import ast
2
+ from collections.abc import Iterator, Sequence
3
+ from contextlib import AbstractContextManager, ExitStack, contextmanager
4
+ from typing import Any, Final, TypeGuard, TypeVar
5
+
6
+ import hugr
7
+ import hugr.std.collections.array
8
+ import hugr.std.float
9
+ import hugr.std.int
10
+ import hugr.std.logic
11
+ import hugr.std.prelude
12
+ from hugr import Wire, ops
13
+ from hugr import tys as ht
14
+ from hugr import val as hv
15
+ from hugr.build import function as hf
16
+ from hugr.build.cond_loop import Conditional
17
+ from hugr.build.dfg import DP, DfBase
18
+ from typing_extensions import assert_never
19
+
20
+ from guppylang_internals.ast_util import AstNode, AstVisitor, get_type
21
+ from guppylang_internals.cfg.builder import tmp_vars
22
+ from guppylang_internals.checker.core import Variable, contains_subscript
23
+ from guppylang_internals.checker.errors.generic import UnsupportedError
24
+ from guppylang_internals.compiler.core import (
25
+ DEBUG_EXTENSION,
26
+ RESULT_EXTENSION,
27
+ CompilerBase,
28
+ CompilerContext,
29
+ DFContainer,
30
+ GlobalConstId,
31
+ )
32
+ from guppylang_internals.compiler.hugr_extension import PartialOp
33
+ from guppylang_internals.definition.custom import CustomFunctionDef
34
+ from guppylang_internals.definition.value import (
35
+ CallableDef,
36
+ CallReturnWires,
37
+ CompiledCallableDef,
38
+ CompiledValueDef,
39
+ )
40
+ from guppylang_internals.engine import ENGINE
41
+ from guppylang_internals.error import GuppyError, InternalGuppyError
42
+ from guppylang_internals.nodes import (
43
+ BarrierExpr,
44
+ DesugaredArrayComp,
45
+ DesugaredGenerator,
46
+ DesugaredListComp,
47
+ ExitKind,
48
+ FieldAccessAndDrop,
49
+ GenericParamValue,
50
+ GlobalCall,
51
+ GlobalName,
52
+ LocalCall,
53
+ PanicExpr,
54
+ PartialApply,
55
+ PlaceNode,
56
+ ResultExpr,
57
+ StateResultExpr,
58
+ SubscriptAccessAndDrop,
59
+ TensorCall,
60
+ TupleAccessAndDrop,
61
+ TypeApply,
62
+ )
63
+ from guppylang_internals.std._internal.compiler.arithmetic import (
64
+ UnsignedIntVal,
65
+ convert_ifromusize,
66
+ )
67
+ from guppylang_internals.std._internal.compiler.array import (
68
+ array_convert_from_std_array,
69
+ array_convert_to_std_array,
70
+ array_map,
71
+ array_new,
72
+ array_repeat,
73
+ standard_array_type,
74
+ unpack_array,
75
+ )
76
+ from guppylang_internals.std._internal.compiler.list import (
77
+ list_new,
78
+ )
79
+ from guppylang_internals.std._internal.compiler.prelude import (
80
+ build_error,
81
+ build_panic,
82
+ build_unwrap,
83
+ panic,
84
+ )
85
+ from guppylang_internals.std._internal.compiler.tket_bool import (
86
+ OpaqueBool,
87
+ OpaqueBoolVal,
88
+ make_opaque,
89
+ not_op,
90
+ read_bool,
91
+ )
92
+ from guppylang_internals.tys.arg import ConstArg
93
+ from guppylang_internals.tys.builtin import (
94
+ bool_type,
95
+ get_element_type,
96
+ int_type,
97
+ is_bool_type,
98
+ is_frozenarray_type,
99
+ )
100
+ from guppylang_internals.tys.const import ConstValue
101
+ from guppylang_internals.tys.subst import Inst
102
+ from guppylang_internals.tys.ty import (
103
+ BoundTypeVar,
104
+ FuncInput,
105
+ FunctionType,
106
+ InputFlags,
107
+ NoneType,
108
+ NumericType,
109
+ OpaqueType,
110
+ TupleType,
111
+ Type,
112
+ type_to_row,
113
+ )
114
+
115
+
116
+ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
117
+ """A compiler from guppylang expressions to Hugr."""
118
+
119
+ dfg: DFContainer
120
+
121
def compile(self, expr: ast.expr, dfg: DFContainer) -> Wire:
    """Compile `expr` into `dfg` and return the single wire holding its value.

    The given container is installed as `self.dfg` before the expression is
    visited.
    """
    self.dfg = dfg
    result = self.visit(expr)
    return result
125
+
126
def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[Wire]:
    """Compile a row expression into one wire per value of the row.

    On Python-level, a top-level tuple is treated like a row, whereas nested
    tuples are treated like regular Guppy tuples. The splitting into row
    elements is delegated to `expr_to_row`.
    """
    wires: list[Wire] = []
    for element in expr_to_row(expr):
        wires.append(self.compile(element, dfg))
    return wires
134
+
135
@property
def builder(self) -> DfBase[ops.DfParentOp]:
    """The Hugr dataflow graph builder of the current `DFContainer`."""
    container = self.dfg
    return container.builder
139
+
140
@contextmanager
def _new_dfcontainer(
    self, inputs: list[PlaceNode], builder: DfBase[DP]
) -> Iterator[None]:
    """Context manager to build a graph inside a new `DFContainer`.

    While the context is active, `self.dfg` is a fresh container wrapping
    `builder` that starts from a copy of the outer locals and has each place in
    `inputs` bound to the corresponding output of the builder's input node.

    The previous container is restored on exit, including when the managed
    body (or the input wiring) raises, so the compiler is never left pointing
    at a half-built inner graph.
    """
    old = self.dfg
    # Check that the input names are unique
    assert len({inp.place.id for inp in inputs}) == len(
        inputs
    ), "Inputs are not unique"
    self.dfg = DFContainer(builder, self.ctx, self.dfg.locals.copy())
    try:
        hugr_input = builder.input_node
        # `strict=True` guards against a mismatch between the declared inputs
        # and the ports of the input node
        for input_node, wire in zip(inputs, hugr_input, strict=True):
            self.dfg[input_node.place] = wire

        yield
    finally:
        self.dfg = old
161
+
162
@contextmanager
def _new_loop(
    self,
    just_inputs_vars: list[PlaceNode],
    loop_vars: list[PlaceNode],
    break_predicate: PlaceNode,
) -> Iterator[None]:
    """Context manager to build a graph inside a new `TailLoop` node.

    The loop outputs (break predicate plus the next values of `loop_vars`) are
    set automatically once the managed body has finished. Afterwards, the
    places in `loop_vars` are rebound in the outer container to the outputs of
    the loop node.
    """
    outer_just_inputs = [self.visit(var) for var in just_inputs_vars]
    outer_loop_inputs = [self.visit(var) for var in loop_vars]
    tail_loop = self.builder.add_tail_loop(outer_just_inputs, outer_loop_inputs)
    with self._new_dfcontainer(just_inputs_vars + loop_vars, tail_loop):
        yield
        # We are still inside the loop's container here, so the predicate and
        # the values for the next iteration need fresh calls to `self.visit`
        predicate_wire = self.visit(break_predicate)
        next_iteration = [self.visit(var) for var in loop_vars]
        tail_loop.set_loop_outputs(predicate_wire, *next_iteration)
    # Back in the outer container: expose the loop outputs
    for var, out_wire in zip(loop_vars, tail_loop, strict=True):
        self.dfg[var.place] = out_wire
187
+
188
@contextmanager
def _new_case(
    self,
    inputs: list[PlaceNode],
    outputs: list[PlaceNode],
    conditional: Conditional,
    case_id: int,
) -> Iterator[None]:
    """Context manager to build a graph inside a new `Case` node.

    The case is added to `conditional` under index `case_id`. Once the managed
    body has finished, the current values of `outputs` are set as the outputs
    of the case.
    """
    # TODO: `Case` is `_DfgBase`, but not `Dfg`?
    case_builder = conditional.add_case(case_id)
    with self._new_dfcontainer(inputs, case_builder):
        yield
        output_wires = [self.visit(out) for out in outputs]
        case_builder.set_outputs(*output_wires)
205
+
206
def _if_else(
    self,
    cond: ast.expr,
    inputs: list[PlaceNode],
    outputs: list[PlaceNode],
    only_true_inputs: list[PlaceNode] | None = None,
    only_false_inputs: list[PlaceNode] | None = None,
) -> tuple[AbstractContextManager[None], AbstractContextManager[None]]:
    """Build a `Conditional` and return context managers for its two branches.

    The first returned manager builds the `True` branch (case 1), the second
    one the `False` branch (case 0). Leaving the `False` manager also rebinds
    the places in `outputs` to the outputs of the `Conditional` node.
    """
    predicate = self.visit(cond)
    predicate_ty = self.builder.hugr.port_type(predicate.out_port())
    if predicate_ty == OpaqueBool:
        # Opaque bools have to be read before they can drive a conditional
        predicate = self.builder.add_op(read_bool(), predicate)
    input_wires = [self.visit(inp) for inp in inputs]
    conditional = self.builder.add_conditional(predicate, *input_wires)
    extra_true = only_true_inputs or []
    extra_false = only_false_inputs or []

    @contextmanager
    def true_case() -> Iterator[None]:
        with self._new_case(extra_true + inputs, outputs, conditional, 1):
            yield

    @contextmanager
    def false_case() -> Iterator[None]:
        with self._new_case(extra_false + inputs, outputs, conditional, 0):
            yield
        # Update the DFG with the outputs from the Conditional node
        for out, wire in zip(outputs, conditional, strict=True):
            self.dfg[out.place] = wire

    return true_case(), false_case()
241
+
242
@contextmanager
def _if_true(self, cond: ast.expr, inputs: list[PlaceNode]) -> Iterator[None]:
    """Context manager to build a graph inside the `true` case of a `Conditional`.

    The `false` case is built first with an empty body, i.e. it outputs the
    inputs as is. The managed body then populates the `true` case.
    """
    on_true, on_false = self._if_else(cond, inputs, inputs)
    with on_false:
        # Nothing to do: the inputs are passed through unchanged
        pass
    with on_true:
        # The caller's `with` body runs inside the `true` case
        yield
255
+
256
def visit_Constant(self, node: ast.Constant) -> Wire:
    """Load a Python constant as a Hugr value.

    Raises:
        InternalGuppyError: If `python_value_to_hugr` cannot convert the
            constant, which it signals by returning `None`.
    """
    value = python_value_to_hugr(node.value, get_type(node), self.ctx)
    # Compare against `None` explicitly so that a converted value is never
    # rejected merely because it happens to be falsy
    if value is None:
        raise InternalGuppyError("Unsupported constant expression in compiler")
    return self.builder.load(value)
260
+
261
def visit_PlaceNode(self, node: PlaceNode) -> Wire:
    """Return the wire currently bound to the place of `node`.

    If the place involves a subscript, the subscript's item is compiled first
    (unless it is already bound) and the subscript itself is rebound to the
    result of its `getitem_call`.
    """
    subscript = contains_subscript(node.place)
    if subscript:
        if subscript.item not in self.dfg:
            item_wire = self.visit(subscript.item_expr)
            self.dfg[subscript.item] = item_wire
        getitem_wire = self.visit(subscript.getitem_call)
        self.dfg[subscript] = getitem_wire
    return self.dfg[node.place]
267
+
268
def visit_GlobalName(self, node: GlobalName) -> Wire:
    """Load the value of a global definition.

    Raises:
        GuppyError: If the definition is a callable with a parametrized type.
    """
    checked_def = ENGINE.get_checked(node.def_id)
    is_polymorphic = isinstance(checked_def, CallableDef) and (
        checked_def.ty.parametrized
    )
    if is_polymorphic:
        # TODO: This should be caught during checking
        raise GuppyError(
            UnsupportedError(
                node, "Polymorphic functions as dynamic higher-order values"
            )
        )

    # The empty-list pattern requires that no type args are left over
    compiled_def, [] = self.ctx.build_compiled_def(node.def_id, type_args=[])
    assert isinstance(compiled_def, CompiledValueDef)
    return compiled_def.load(self.dfg, self.ctx, node)
280
+
281
+ def visit_GenericParamValue(self, node: GenericParamValue) -> Wire:
282
+ match node.param.ty:
283
+ case NumericType(NumericType.Kind.Nat):
284
+ # Generic nat parameters are encoded using Hugr bounded nat parameters,
285
+ # so they are not monomorphized when compiling to Hugr
286
+ arg = node.param.to_bound().to_hugr(self.ctx)
287
+ load_nat = hugr.std.PRELUDE.get_op("load_nat").instantiate(
288
+ [arg], ht.FunctionType([], [ht.USize()])
289
+ )
290
+ usize = self.builder.add_op(load_nat)
291
+ return self.builder.add_op(convert_ifromusize(), usize)
292
+ case ty:
293
+ # Look up monomorphization
294
+ assert self.ctx.current_mono_args is not None
295
+ match self.ctx.current_mono_args[node.param.idx]:
296
+ case ConstArg(const=ConstValue(value=v)):
297
+ val = python_value_to_hugr(v, ty, self.ctx)
298
+ assert val is not None
299
+ return self.builder.load(val)
300
+ case _:
301
+ raise InternalGuppyError("Monomorphized const is not a value")
302
+
303
def visit_Name(self, node: ast.Name) -> Wire:
    """Reject plain `ast.Name` nodes; they must not reach the compiler."""
    message = "Node should have been removed during type checking."
    raise InternalGuppyError(message)
305
+
306
def visit_Tuple(self, node: ast.Tuple) -> Wire:
    """Compile all elements of a tuple expression and pack them into a tuple."""
    element_wires: list[Wire] = []
    for element in node.elts:
        element_wires.append(self.visit(element))
    element_types = [get_type(element) for element in node.elts]
    return self._pack_tuple(element_wires, element_types)
310
+
311
def visit_List(self, node: ast.List) -> Wire:
    """Compile a list literal (i.e. `[e1, e2, ...]`, not a comprehension)."""
    element_wires = [self.visit(element) for element in node.elts]
    element_ty = get_element_type(get_type(node))
    hugr_element_ty = element_ty.to_hugr(self.ctx)
    return list_new(self.builder, hugr_element_ty, element_wires)
317
+
318
def _unpack_tuple(self, wire: Wire, types: Sequence[Type]) -> Sequence[Wire]:
    """Add an `UnpackTuple` op for `wire` and return its outputs as a list.

    `types` are the Guppy types of the tuple elements.
    """
    hugr_tys = [ty.to_hugr(self.ctx) for ty in types]
    unpack_node = self.builder.add_op(ops.UnpackTuple(hugr_tys), wire)
    return list(unpack_node)
322
+
323
def _pack_tuple(self, wires: Sequence[Wire], types: Sequence[Type]) -> Wire:
    """Add a `MakeTuple` op combining `wires` and return the result.

    `types` are the Guppy types of the tuple elements.
    """
    hugr_tys = [ty.to_hugr(self.ctx) for ty in types]
    make_tuple = ops.MakeTuple(hugr_tys)
    return self.builder.add_op(make_tuple, *wires)
327
+
328
def _pack_returns(self, returns: Sequence[Wire], return_ty: Type) -> Wire:
    """Group the return wires of a function into a single wire.

    Tuple and `None` return types that are not marked `preserve` are packed
    into a tuple. For every other return type exactly one wire is expected,
    and it is returned unchanged.
    """
    needs_packing = (
        isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve
    )
    if not needs_packing:
        assert (
            len(returns) == 1
        ), f"Expected a single return value. Got {returns}. return type {return_ty}"
        return returns[0]
    row_types = type_to_row(return_ty)
    assert len(returns) == len(row_types)
    return self._pack_tuple(returns, row_types)
338
+
339
def _update_inout_ports(
    self,
    args: list[ast.expr],
    inout_ports: Iterator[Wire],
    func_ty: FunctionType,
) -> None:
    """Helper method that updates the ports for borrowed arguments after a call.

    One port is taken from `inout_ports` for every input of `func_ty` flagged
    `Inout`; it is rebound to the argument's place if the argument is a
    `PlaceNode`. The iterator must be exhausted afterwards.
    """
    for inp, arg in zip(func_ty.inputs, args, strict=True):
        if InputFlags.Inout not in inp.flags:
            continue
        returned_port = next(inout_ports)
        # Linearity checker ensures that borrowed arguments that are not places
        # can be safely dropped after the call returns
        if not isinstance(arg, PlaceNode):
            continue
        self.dfg[arg.place] = returned_port
        # Places involving subscripts need to generate code for the appropriate
        # `__setitem__` call. Nested subscripts are handled automatically since
        # `arg.place.parent` occurs as an arg of this call, so will also
        # be recursively reassigned.
        subscript = contains_subscript(arg.place)
        if subscript:
            setitem = subscript.setitem_call
            assert setitem is not None
            # Need to assign __setitem__ value before compiling call.
            # Note that the assignment to `self.dfg[arg.place]` also updated
            # `self.dfg[subscript]` so that it now contains the value we want
            # to write back into the subscript.
            self.dfg[setitem.value_var] = self.dfg[subscript]
            self.visit(setitem.call)
    assert next(inout_ports, None) is None, "Too many inout return ports"
367
+
368
def visit_LocalCall(self, node: LocalCall) -> Wire:
    """Compile a call to a function value via a `CallIndirect` op.

    The first outputs of the call are the regular returns; the remaining ones
    are handed to `_update_inout_ports` for the borrowed arguments.
    """
    callee = self.visit(node.func)
    callee_ty = get_type(node.func)
    assert isinstance(callee_ty, FunctionType)
    regular_count = len(type_to_row(callee_ty.output))

    arg_wires = self._compile_call_args(node.args, callee_ty)
    indirect_call = ops.CallIndirect(callee_ty.to_hugr(self.ctx))
    call_node = self.builder.add_op(indirect_call, callee, *arg_wires)
    regular = list(call_node[:regular_count])
    borrowed = call_node[regular_count:]
    self._update_inout_ports(node.args, borrowed, callee_ty)
    return self._pack_returns(regular, callee_ty.output)
382
+
383
def visit_TensorCall(self, node: TensorCall) -> Wire:
    """Compile a call to a tuple of functions.

    The functions are called one after another, each consuming its share of
    `node.args`. All returns are collected and packed according to the output
    of `node.tensor_ty`.
    """
    tensor_wire: Wire = self.visit(node.func)
    tensor_ty = get_type(node.func)
    assert isinstance(tensor_ty, TupleType)

    collected: list[Wire] = []
    pending_args = node.args
    function_wires = self._unpack_tuple(tensor_wire, tensor_ty.element_types)
    for func, func_ty in zip(function_wires, tensor_ty.element_types, strict=True):
        outs, pending_args = self._compile_tensor_with_leftovers(
            func, func_ty, pending_args
        )
        collected.extend(outs)
    assert (
        pending_args == []
    ), "Not all function arguments were consumed after a tensor call"
    return self._pack_returns(collected, node.tensor_ty.output)
403
+
404
def _compile_tensor_with_leftovers(
    self, func: Wire, func_ty: Type, args: list[ast.expr]
) -> tuple[
    list[Wire],  # Compiled outputs
    list[ast.expr],  # Leftover args
]:
    """Compile a call that consumes as many of `args` as needed.

    Tuples of functions are handled recursively, element by element. Returns
    the regular return wires together with the arguments that were not used.

    Raises:
        InternalGuppyError: If `func_ty` is neither a tuple nor a function type.
    """
    if isinstance(func_ty, FunctionType):
        arity = len(func_ty.inputs)
        regular_count = len(type_to_row(func_ty.output))
        own_args, leftover = args[0:arity], args[arity:]
        own_wires = self._compile_call_args(own_args, func_ty)
        call_node = self.builder.add_op(
            ops.CallIndirect(func_ty.to_hugr(self.ctx)), func, *own_wires
        )
        regular: list[Wire] = list(call_node[:regular_count])
        borrowed = call_node[regular_count:]
        self._update_inout_ports(own_args, borrowed, func_ty)
        return regular, leftover

    if not isinstance(func_ty, TupleType):
        raise InternalGuppyError("Tensor element wasn't function or tuple")

    leftover = args
    collected = []
    element_wires = self._unpack_tuple(func, func_ty.element_types)
    for elem, elem_ty in zip(element_wires, func_ty.element_types, strict=True):
        outs, leftover = self._compile_tensor_with_leftovers(elem, elem_ty, leftover)
        collected.extend(outs)
    return collected, leftover
441
+
442
def visit_GlobalCall(self, node: GlobalCall) -> Wire:
    """Compile a call to a global callable definition.

    For a `CustomFunctionDef` without a signature, the function type is
    reconstructed from the types of the arguments and of the call node.
    Otherwise the definition's type is instantiated with the remaining type
    args.
    """
    callee, rem_args = self.ctx.build_compiled_def(node.def_id, node.type_args)
    assert isinstance(callee, CompiledCallableDef)

    lacks_signature = (
        isinstance(callee, CustomFunctionDef) and not callee.has_signature
    )
    if not lacks_signature:
        callee_ty = callee.ty.instantiate(rem_args)
    else:
        arg_inputs = [
            FuncInput(get_type(arg), InputFlags.NoFlags) for arg in node.args
        ]
        callee_ty = FunctionType(arg_inputs, get_type(node))

    arg_wires = self._compile_call_args(node.args, callee_ty)
    returns = callee.compile_call(arg_wires, rem_args, self.dfg, self.ctx, node)
    self._update_inout_ports(node.args, iter(returns.inout_returns), callee_ty)
    return self._pack_returns(returns.regular_returns, callee_ty.output)
458
+
459
def _compile_call_args(
    self, args: list[ast.expr], func_ty: FunctionType
) -> list[Wire]:
    """Compile the arguments of a function call into wires.

    Arguments whose input is flagged `Comptime` are skipped: they are provided
    via generic args or monomorphization instead of wires.
    """
    wires: list[Wire] = []
    for arg, inp in zip(args, func_ty.inputs, strict=True):
        # Don't compile comptime args since we already include them as type args
        if InputFlags.Comptime in inp.flags:
            continue
        wires.append(self.visit(arg))
    return wires
473
+
474
def visit_Call(self, node: ast.Call) -> Wire:
    """Reject plain `ast.Call` nodes; they must not reach the compiler."""
    message = "Node should have been removed during type checking."
    raise InternalGuppyError(message)
476
+
477
def visit_PartialApply(self, node: PartialApply) -> Wire:
    """Compile a partial application by adding a `PartialOp` to the graph.

    The op takes the function value followed by the applied arguments.
    """
    closure_ty = get_type(node.func)
    assert isinstance(closure_ty, FunctionType)
    applied_tys = [get_type(arg).to_hugr(self.ctx) for arg in node.args]
    partial_op = PartialOp.from_closure(closure_ty.to_hugr(self.ctx), applied_tys)
    func_wire = self.visit(node.func)
    applied_wires = [self.visit(arg) for arg in node.args]
    return self.builder.add_op(partial_op, func_wire, *applied_wires)
487
+
488
def visit_TypeApply(self, node: TypeApply) -> Wire:
    """Load a global callable instantiated with the type args `node.inst`.

    Raises:
        InternalGuppyError: If the applied value is not a `GlobalName`.
        GuppyError: If `instantiation_needs_unpacking` reports that the
            instantiation would need unpacking.
    """
    # For now, we can only TypeApply global FunctionDefs/Decls.
    target = node.value
    if not isinstance(target, GlobalName):
        raise InternalGuppyError("Dynamic TypeApply not supported yet!")
    compiled_def, rem_args = self.ctx.build_compiled_def(target.def_id, node.inst)
    assert isinstance(compiled_def, CompiledCallableDef)

    # We have to be very careful here: If we instantiate `foo: forall T. T -> T`
    # with a tuple type `tuple[A, B]`, we get the type `tuple[A, B] -> tuple[A, B]`.
    # Normally, this would be represented in Hugr as a function with two output
    # ports types A and B. However, when TypeApplying `foo`, we actually get a
    # function with a single output port typed `tuple[A, B]`.
    # TODO: We would need to do manual monomorphisation in that case to obtain a
    #  function that returns two ports as expected
    if instantiation_needs_unpacking(compiled_def.ty, node.inst):
        raise GuppyError(
            UnsupportedError(node, "Generic function instantiations returning rows")
        )

    return compiled_def.load_with_args(rem_args, self.dfg, self.ctx, node)
509
+
510
def visit_UnaryOp(self, node: ast.UnaryOp) -> Wire:
    """Compile a `not` expression.

    The only case that is not desugared by the type checker is the `not`
    operation since it is not implemented via a dunder method. Any other unary
    operator is an internal error.
    """
    if not isinstance(node.op, ast.Not):
        raise InternalGuppyError(
            "Node should have been removed during type checking."
        )
    operand_wire = self.visit(node.operand)
    return self.builder.add_op(not_op(), operand_wire)
518
+
519
def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> Wire:
    """Unpack a struct value and return the wire of the accessed field."""
    struct_wire = self.visit(node.value)
    fields = node.struct_ty.fields
    position = fields.index(node.field)
    field_wires = self._unpack_tuple(struct_wire, [field.ty for field in fields])
    return field_wires[position]
525
+
526
def visit_SubscriptAccessAndDrop(self, node: SubscriptAccessAndDrop) -> Wire:
    """Bind the subscript item, then compile the `getitem` expression."""
    item_wire = self.visit(node.item_expr)
    self.dfg[node.item] = item_wire
    return self.visit(node.getitem_expr)
529
+
530
def visit_TupleAccessAndDrop(self, node: TupleAccessAndDrop) -> Wire:
    """Unpack a tuple value and return the wire at `node.index`."""
    tuple_wire = self.visit(node.value)
    element_wires = self._unpack_tuple(tuple_wire, node.tuple_ty.element_types)
    return element_wires[node.index]
533
+
534
+ def visit_ResultExpr(self, node: ResultExpr) -> Wire:
535
+ value_wire = self.visit(node.value)
536
+ base_ty = node.base_ty.to_hugr(self.ctx)
537
+ extra_args: list[ht.TypeArg] = []
538
+ if isinstance(node.base_ty, NumericType):
539
+ match node.base_ty.kind:
540
+ case NumericType.Kind.Nat:
541
+ base_name = "uint"
542
+ extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)]
543
+ case NumericType.Kind.Int:
544
+ base_name = "int"
545
+ extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)]
546
+ case NumericType.Kind.Float:
547
+ base_name = "f64"
548
+ case kind:
549
+ assert_never(kind)
550
+ else:
551
+ # The only other valid base type is bool
552
+ assert is_bool_type(node.base_ty)
553
+ base_name = "bool"
554
+ if node.array_len is not None:
555
+ op_name = f"result_array_{base_name}"
556
+ size_arg = node.array_len.to_arg().to_hugr(self.ctx)
557
+ extra_args = [size_arg, *extra_args]
558
+ # Remove the option wrapping in the array
559
+ unwrap = array_unwrap_elem(self.ctx)
560
+ unwrap = self.builder.load_function(
561
+ unwrap,
562
+ instantiation=ht.FunctionType([ht.Option(base_ty)], [base_ty]),
563
+ type_args=[ht.TypeTypeArg(base_ty)],
564
+ )
565
+ map_op = array_map(ht.Option(base_ty), size_arg, base_ty)
566
+ value_wire = self.builder.add_op(map_op, value_wire, unwrap)
567
+ if is_bool_type(node.base_ty):
568
+ # We need to coerce a read on all the array elements if they are bools.
569
+ array_read = array_read_bool(self.ctx)
570
+ array_read = self.builder.load_function(array_read)
571
+ map_op = array_map(OpaqueBool, size_arg, ht.Bool)
572
+ value_wire = self.builder.add_op(map_op, value_wire, array_read)
573
+ base_ty = ht.Bool
574
+ # Turn `value_array` into regular linear `array`
575
+ value_wire = self.builder.add_op(
576
+ array_convert_to_std_array(base_ty, size_arg), value_wire
577
+ )
578
+ hugr_ty: ht.Type = hugr.std.collections.array.Array(base_ty, size_arg)
579
+ else:
580
+ if is_bool_type(node.base_ty):
581
+ base_ty = ht.Bool
582
+ value_wire = self.builder.add_op(read_bool(), value_wire)
583
+ op_name = f"result_{base_name}"
584
+ hugr_ty = base_ty
585
+
586
+ sig = ht.FunctionType(input=[hugr_ty], output=[])
587
+ args = [ht.StringArg(node.tag), *extra_args]
588
+ op = ops.ExtOp(RESULT_EXTENSION.get_op(op_name), signature=sig, args=args)
589
+
590
+ self.builder.add_op(op, value_wire)
591
+ return self._pack_returns([], NoneType())
592
+
593
+ def visit_PanicExpr(self, node: PanicExpr) -> Wire:
594
+ err = build_error(self.builder, node.signal, node.msg)
595
+ in_tys = [get_type(e).to_hugr(self.ctx) for e in node.values]
596
+ out_tys = [ty.to_hugr(self.ctx) for ty in type_to_row(get_type(node))]
597
+ args = [self.visit(e) for e in node.values]
598
+ match node.kind:
599
+ case ExitKind.Panic:
600
+ h_node = build_panic(self.builder, in_tys, out_tys, err, *args)
601
+ case ExitKind.ExitShot:
602
+ op = panic(in_tys, out_tys, ExitKind.ExitShot)
603
+ h_node = self.builder.add_op(op, err, *args)
604
+ return self._pack_returns(list(h_node.outputs()), get_type(node))
605
+
606
def visit_BarrierExpr(self, node: BarrierExpr) -> Wire:
    """Compile a barrier over `node.args` using the prelude `Barrier` op.

    The outputs of the barrier are written back to the argument places via
    `_update_inout_ports`.
    """
    arg_tys = [get_type(arg).to_hugr(self.ctx) for arg in node.args]
    barrier_def = hugr.std.prelude.PRELUDE_EXTENSION.get_op("Barrier")
    barrier_op = barrier_def.instantiate(
        [ht.ListArg([ht.TypeTypeArg(ty) for ty in arg_tys])],
        ht.FunctionType.endo(arg_tys),
    )

    arg_wires = [self.visit(arg) for arg in node.args]
    barrier_node = self.builder.add_op(barrier_op, *arg_wires)

    self._update_inout_ports(node.args, iter(barrier_node), node.func_ty)
    return self._pack_returns([], NoneType())
617
+
618
def visit_StateResultExpr(self, node: StateResultExpr) -> Wire:
    """Compile a state result expression using the debug `StateResult` op.

    The first entry of `node.args` is skipped when collecting the qubits. If
    `node.array_len` is set, the qubits are given as a single array argument;
    otherwise the remaining arguments are individual qubits that get packed
    into an array for the op and unpacked again afterwards.
    """
    if node.array_len:
        num_qubits_arg = node.array_len.to_arg().to_hugr(self.ctx)
    else:
        num_qubits_arg = ht.BoundedNatArg(len(node.args) - 1)
    op_args = [ht.StringArg(node.tag), num_qubits_arg]
    signature = ht.FunctionType(
        [standard_array_type(ht.Qubit, num_qubits_arg)],
        [standard_array_type(ht.Qubit, num_qubits_arg)],
    )

    state_op = ops.ExtOp(
        DEBUG_EXTENSION.get_op("StateResult"), signature=signature, args=op_args
    )

    if node.array_len:
        # If the input is an array of qubits, we need to unwrap the elements first,
        # and then convert to a value array and back.
        array_in = self.visit(node.args[1])
        qubits_out = [
            apply_array_op_with_conversions(
                self.ctx, self.builder, state_op, ht.Qubit, num_qubits_arg, array_in
            )
        ]
    else:
        # If the input is a sequence of qubits, we pack them into an array.
        qubit_wires = [self.visit(arg) for arg in node.args[1:]]
        packed = self.builder.add_op(
            array_new(ht.Qubit, len(node.args) - 1), *qubit_wires
        )
        # Turn into standard array from value array.
        packed = self.builder.add_op(
            array_convert_to_std_array(ht.Qubit, num_qubits_arg), packed
        )

        reported = self.builder.add_op(state_op, packed)

        reported = self.builder.add_op(
            array_convert_from_std_array(ht.Qubit, num_qubits_arg), reported
        )
        qubits_out = unpack_array(self.builder, reported)

    self._update_inout_ports(node.args, iter(qubits_out), node.func_ty)
    return self._pack_returns([], NoneType())
661
+
662
def visit_DesugaredListComp(self, node: DesugaredListComp) -> Wire:
    """Compile a list comprehension.

    A temporary variable is bound to an empty list. Inside the generators,
    each compiled element is appended to it via the list's `append` method.
    """
    # Make up a name for the list under construction and bind it to an empty list
    result_ty = get_type(node)
    assert isinstance(result_ty, OpaqueType)
    element_ty = get_element_type(result_ty)
    accumulator = Variable(next(tmp_vars), result_ty, node)
    self.dfg[accumulator] = list_new(
        self.builder, element_ty.to_hugr(self.ctx), []
    )
    with self._build_generators(node.generators, [accumulator]):
        element_wire = self.visit(node.elt)
        current_list = self.dfg[accumulator]
        # `append` has no regular returns; its single inout return is the
        # updated list
        [], [self.dfg[accumulator]] = self._build_method_call(
            result_ty, "append", node, [current_list, element_wire], result_ty.args
        )
    return self.dfg[accumulator]
676
+
677
def visit_DesugaredArrayComp(self, node: DesugaredArrayComp) -> Wire:
    """Compiles a desugared array comprehension.

    An array of length `node.length` is first filled by repeating the
    comprehension-init function (see `array_comprehension_init_func`), together
    with an integer counter starting at 0. Both are threaded through the loop for
    `node.generator`; each iteration stores the compiled element at the current
    counter position via `__setitem__` and then increments the counter via
    `__add__`. The wire bound to the array after the loop is returned.
    """
    result_ty = get_type(node)
    assert isinstance(result_ty, OpaqueType)
    # Temporaries for the array under construction and the write position
    arr_place = Variable(next(tmp_vars), result_ty, node)
    idx_place = Variable(next(tmp_vars), int_type(), node)

    # Elements are wrapped in an option,
    # see https://github.com/CQCL/guppylang/issues/629
    option_elt_ty = ht.Option(node.elt_ty.to_hugr(self.ctx))
    # Load the function producing the initial value of every array slot
    init_func = self.builder.load_function(
        array_comprehension_init_func(self.ctx),
        instantiation=ht.FunctionType([], [option_elt_ty]),
        type_args=[ht.TypeTypeArg(node.elt_ty.to_hugr(self.ctx))],
    )
    repeat_op = array_repeat(option_elt_ty, node.length.to_arg().to_hugr(self.ctx))
    self.dfg[arr_place] = self.builder.add_op(repeat_op, init_func)
    zero = hugr.std.int.IntVal(0, width=NumericType.INT_WIDTH)
    self.dfg[idx_place] = self.builder.load(zero)

    with self._build_generators([node.generator], [arr_place, idx_place]):
        elt_wire = self.visit(node.elt)
        arr_wire = self.dfg[arr_place]
        idx_wire = self.dfg[idx_place]
        # array[idx] = elt; the updated array comes back as the inout return
        [], [self.dfg[arr_place]] = self._build_method_call(
            result_ty,
            "__setitem__",
            node,
            [arr_wire, idx_wire, elt_wire],
            result_ty.args,
        )
        # idx += 1
        step = self.builder.load(hugr.std.int.IntVal(1, width=NumericType.INT_WIDTH))
        [self.dfg[idx_place]], [] = self._build_method_call(
            int_type(), "__add__", node, [idx_wire, step], []
        )
    return self.dfg[arr_place]
710
+
711
def _build_method_call(
    self, ty: Type, method: str, node: AstNode, args: list[Wire], type_args: Inst
) -> CallReturnWires:
    """Compiles a call to the method named `method` on the type `ty`.

    The method is looked up through the compiler context with the given
    `type_args`; the lookup is asserted to succeed. The call is then compiled into
    the current dataflow graph with `args` as inputs and `node` as the AST node
    associated with the call.
    """
    resolved = self.ctx.build_compiled_instance_func(ty, method, type_args)
    assert resolved is not None
    compiled_func, remaining_type_args = resolved
    return compiled_func.compile_call(
        args, remaining_type_args, self.dfg, self.ctx, node
    )
718
+
719
@contextmanager
def _build_generators(
    self, gens: list[DesugaredGenerator], loop_vars: list[Variable]
) -> Iterator[None]:
    """Context manager to build and enter the `TailLoop`s for a list of generators.

    The provided `loop_vars` will be threaded through and will be available inside
    the loops.

    For every generator, the following contexts are entered (and kept open on an
    `ExitStack` until the caller's `with` body finishes): the loop itself, the
    "has next element" branch of the conditional on `gen.next_call`, and one
    conditional per `if` guard of the generator. Control is yielded to the caller
    once, inside the innermost of these contexts.
    """
    # Imported locally rather than at module level
    # (presumably to avoid a circular import -- TODO confirm)
    from guppylang_internals.compiler.stmt_compiler import StmtCompiler

    compiler = StmtCompiler(self.ctx)
    with ExitStack() as stack:
        for gen in gens:
            # Build the generator
            compiler.compile_stmts([gen.iter_assign], self.dfg)
            assert isinstance(gen.iter, PlaceNode)
            iter_ty = get_type(gen.iter)
            # Loop-carried values: the caller's loop variables plus every outer
            # place that the generator uses
            inputs = [PlaceNode(place=var) for var in loop_vars]
            inputs += [PlaceNode(place=place) for place in gen.used_outer_places]
            # Enter a new tail loop. Note that the iterator is a `just_input`, so
            # will not be outputted by the loop
            break_pred = PlaceNode(Variable(next(tmp_vars), bool_type(), gen.iter))
            stack.enter_context(self._new_loop([gen.iter], inputs, break_pred))
            # Enter a conditional checking if we have a next element
            next_ty = TupleType([get_type(gen.target), iter_ty])
            next_var = PlaceNode(Variable(next(tmp_vars), next_ty, gen.iter))
            hasnext_case, stop_case = self._if_else(
                gen.next_call,
                inputs,
                only_true_inputs=[next_var],
                outputs=[break_pred, *inputs],
            )
            # In the "no" case, we set the break predicate to true
            # (tag 1 of the either type, which carries no values)
            break_pred_hugr_ty = ht.Either([iter_ty.to_hugr(self.ctx)], [])
            with stop_case:
                self.dfg[break_pred.place] = self.dfg.builder.add_op(
                    ops.Tag(1, break_pred_hugr_ty)
                )
            # Otherwise, we continue, set the break predicate to false, and insert
            # the iterator for the next loop iteration
            stack.enter_context(hasnext_case)
            next_wire = self.dfg[next_var.place]
            elt, it = self.dfg.builder.add_op(ops.UnpackTuple(), next_wire)
            # Point the statement compiler at the current dataflow graph before
            # binding the generator target to the next element
            compiler.dfg = self.dfg
            compiler._assign(gen.target, elt)
            # Tag 0 of the either type carries the updated iterator
            self.dfg[break_pred.place] = self.dfg.builder.add_op(
                ops.Tag(0, break_pred_hugr_ty), it
            )
            # Enter nested conditionals for each if guard on the generator
            for if_expr in gen.ifs:
                stack.enter_context(self._if_true(if_expr, [break_pred, *inputs]))
        # Yield control to the caller to build inside the loop
        yield
773
+
774
def visit_BinOp(self, node: ast.BinOp) -> Wire:
    """Always raises: binary operations are not expected to reach the compiler."""
    msg = "Node should have been removed during type checking."
    raise InternalGuppyError(msg)
776
+
777
def visit_Compare(self, node: ast.Compare) -> Wire:
    """Always raises: comparisons are not expected to reach the compiler."""
    msg = "Node should have been removed during type checking."
    raise InternalGuppyError(msg)
779
+
780
+
781
def expr_to_row(expr: ast.expr) -> list[ast.expr]:
    """Turns an expression into a row of expressions by unpacking top-level tuples.

    A tuple expression yields its element list (nested tuples are left as they
    are); any other expression yields a one-element row containing just itself.
    """
    if isinstance(expr, ast.Tuple):
        return expr.elts
    return [expr]
784
+
785
+
786
def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool:
    """Checks if instantiating a polymorphic function makes it return a row.

    This is the case when the declared output of `func_ty` is a bound type
    variable and `inst` instantiates that variable with a tuple type or the none
    type.
    """
    output_ty = func_ty.output
    if not isinstance(output_ty, BoundTypeVar):
        return False
    return isinstance(inst[output_ty.idx], TupleType | NoneType)
792
+
793
+
794
def python_value_to_hugr(v: Any, exp_ty: Type, ctx: CompilerContext) -> hv.Value | None:
    """Turns a Python value into a Hugr value.

    Supported values are bools, strings, ints (as nat or int, depending on
    `exp_ty`), floats, tuples, and lists (as static arrays). Tuples and lists are
    converted element-wise.

    Returns None if the Python value cannot be represented in Guppy, including
    the case where any element of a tuple or list cannot be represented.
    """
    # NOTE: `bool` is a subclass of `int`, so it has to be checked first
    if isinstance(v, bool):
        return OpaqueBoolVal(v)
    if isinstance(v, str):
        return hugr.std.prelude.StringVal(v)
    if isinstance(v, int):
        assert isinstance(exp_ty, NumericType)
        if exp_ty.kind == NumericType.Kind.Nat:
            return UnsignedIntVal(v, width=NumericType.INT_WIDTH)
        if exp_ty.kind == NumericType.Kind.Int:
            return hugr.std.int.IntVal(v, width=NumericType.INT_WIDTH)
        raise InternalGuppyError("Unexpected numeric type")
    if isinstance(v, float):
        return hugr.std.float.FloatVal(v)
    if isinstance(v, tuple):
        assert isinstance(exp_ty, TupleType)
        converted = [
            python_value_to_hugr(elt, ty, ctx)
            for elt, ty in zip(v, exp_ty.element_types, strict=True)
        ]
        if not doesnt_contain_none(converted):
            return None
        return hv.Tuple(*converted)
    if isinstance(v, list):
        assert is_frozenarray_type(exp_ty)
        elem_ty = get_element_type(exp_ty)
        converted = [python_value_to_hugr(elt, elem_ty, ctx) for elt in v]
        if not doesnt_contain_none(converted):
            return None
        # Every static array value gets a fresh name
        return hugr.std.collections.static_array.StaticArrayVal(
            converted, elem_ty.to_hugr(ctx), name=f"static_pyarray.{next(tmp_vars)}"
        )
    return None
834
+
835
+
836
# Identifiers under which the helper Hugr functions defined below are declared as
# global functions in the compiler context (via `ctx.declare_global_func`).

# Function producing the initial value of array slots before a comprehension
ARRAY_COMPREHENSION_INIT: Final[GlobalConstId] = GlobalConstId.fresh(
    "array.__comprehension.init"
)

# Functions unwrapping / wrapping a single array element from / into an option
ARRAY_UNWRAP_ELEM: Final[GlobalConstId] = GlobalConstId.fresh("array.__unwrap_elem")
ARRAY_WRAP_ELEM: Final[GlobalConstId] = GlobalConstId.fresh("array.__wrap_elem")

# Functions converting a single element between opaque bool and Hugr bool
ARRAY_READ_BOOL: Final[GlobalConstId] = GlobalConstId.fresh("array.__read_bool")
ARRAY_MAKE_OPAQUE_BOOL: Final[GlobalConstId] = GlobalConstId.fresh(
    "array.__make_opaque_bool"
)
847
+
848
+
849
def array_comprehension_init_func(ctx: CompilerContext) -> hf.Function:
    """Returns the Hugr function used to initialise array elements before a
    comprehension.

    The function is polymorphic in the element type, takes no inputs, and just
    returns the `None` variant (tag 0) of the optional element type. It is
    declared once per compiler context; its body is only built on first use.

    See https://github.com/CQCL/guppylang/issues/629
    """
    elem_var = ht.Variable(0, ht.TypeBound(ht.TypeBound.Linear))
    poly_sig = ht.PolyFuncType(
        params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
        body=ht.FunctionType([], [ht.Option(elem_var)]),
    )
    func, already_defined = ctx.declare_global_func(ARRAY_COMPREHENSION_INIT, poly_sig)
    if already_defined:
        return func
    none_wire = func.add_op(ops.Tag(0, ht.Option(elem_var)))
    func.set_outputs(none_wire)
    return func
866
+
867
+
868
def array_unwrap_elem(ctx: CompilerContext) -> hf.Function:
    """Returns the Hugr function that unwraps a single optional array element.

    Mapping this function over an option array turns it into a regular array.
    The function is polymorphic in the element type. It is declared once per
    compiler context; its body is only built on first use.
    """
    elem_var = ht.Variable(0, ht.TypeBound(ht.TypeBound.Linear))
    poly_sig = ht.PolyFuncType(
        params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
        body=ht.FunctionType([ht.Option(elem_var)], [elem_var]),
    )
    func, already_defined = ctx.declare_global_func(ARRAY_UNWRAP_ELEM, poly_sig)
    if already_defined:
        return func
    # Message passed to `build_unwrap` for the case that the unwrap fails
    msg = "Linear array element has already been used"
    unwrapped = build_unwrap(func, func.inputs()[0], msg)
    func.set_outputs(unwrapped)
    return func
881
+
882
+
883
def array_wrap_elem(ctx: CompilerContext) -> hf.Function:
    """Returns the Hugr function that wraps a single array element in an option.

    Mapping this function over a regular array turns it into an option array.
    The function is polymorphic in the element type and tags its input with
    variant 1 of the option type. It is declared once per compiler context; its
    body is only built on first use.
    """
    elem_var = ht.Variable(0, ht.TypeBound(ht.TypeBound.Linear))
    poly_sig = ht.PolyFuncType(
        params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
        body=ht.FunctionType([elem_var], [ht.Option(elem_var)]),
    )
    func, already_defined = ctx.declare_global_func(ARRAY_WRAP_ELEM, poly_sig)
    if already_defined:
        return func
    wrapped = func.add_op(ops.Tag(1, ht.Option(elem_var)), func.inputs()[0])
    func.set_outputs(wrapped)
    return func
895
+
896
+
897
def array_read_bool(ctx: CompilerContext) -> hf.Function:
    """Returns the Hugr function that converts an opaque bool into a Hugr bool.

    Mapping this function over an array of opaque bools turns it into an array of
    Hugr bools (see `apply_array_op_with_conversions`). The function is
    monomorphic: it takes a single `OpaqueBool` and returns the `ht.Bool`
    produced by the `read_bool` op. It is declared once per compiler context; its
    body is only built on first use.
    """
    sig = ht.PolyFuncType(
        params=[],
        body=ht.FunctionType([OpaqueBool], [ht.Bool]),
    )
    func, already_defined = ctx.declare_global_func(ARRAY_READ_BOOL, sig)
    if not already_defined:
        func.set_outputs(func.add_op(read_bool(), func.inputs()[0]))
    return func
908
+
909
+
910
def array_make_opaque_bool(ctx: CompilerContext) -> hf.Function:
    """Returns the Hugr function that converts a Hugr bool into an opaque bool.

    Mapping this function over an array of Hugr bools turns it into an array of
    opaque bools (see `apply_array_op_with_conversions`). The function is
    monomorphic: it takes a single `ht.Bool` and returns the `OpaqueBool`
    produced by the `make_opaque` op. It is declared once per compiler context;
    its body is only built on first use.
    """
    sig = ht.PolyFuncType(
        params=[],
        body=ht.FunctionType([ht.Bool], [OpaqueBool]),
    )
    func, already_defined = ctx.declare_global_func(ARRAY_MAKE_OPAQUE_BOOL, sig)
    if not already_defined:
        func.set_outputs(func.add_op(make_opaque(), func.inputs()[0]))
    return func
921
+
922
+
923
+ T = TypeVar("T")
924
+
925
+
926
+ def doesnt_contain_none(xs: list[T | None]) -> TypeGuard[list[T]]:
927
+ """Checks if a list contains `None`."""
928
+ return all(x is not None for x in xs)
929
+
930
+
931
def apply_array_op_with_conversions(
    ctx: CompilerContext,
    builder: DfBase[ops.DfParentOp],
    op: ops.DataflowOp,
    elem_ty: ht.Type,
    size_arg: ht.TypeArg,
    input_array: Wire,
    convert_bool: bool = False,
) -> Wire:
    """Applies common transformations to a Guppy array input before it can be passed to
    a Hugr op operating on a standard Hugr array, and then reverses them again on the
    output array.

    Transformations (applied in this order on the input, and in reverse order on
    the output):
      1. Unwraps / wraps elements in options.
      2. (Optional) Converts from / to opaque bool to / from Hugr bool.
      3. Converts from / to value array to / from standard Hugr array.

    Args:
        ctx: Compiler context in which the helper functions are declared.
        builder: Dataflow builder that the ops are added to.
        op: The op to apply; it is given the standard array as its only input.
        elem_ty: Hugr type of the (unwrapped) array elements.
        size_arg: Type argument describing the array length.
        input_array: Wire carrying the array of optional elements.
        convert_bool: Whether to additionally convert the elements from opaque
            bools to Hugr bools before applying `op`, and back afterwards. In
            this case the array handed to `op` has element type `ht.Bool` and
            the returned array has element type `ht.Option(OpaqueBool)`
            regardless of `elem_ty` (callers are presumably expected to pass
            `OpaqueBool` as `elem_ty` -- TODO confirm).

    Returns:
        The wire carrying the output array with its elements wrapped in options
        again.
    """
    # 1. Unwrap the optional elements
    unwrap = array_unwrap_elem(ctx)
    unwrap = builder.load_function(
        unwrap,
        instantiation=ht.FunctionType([ht.Option(elem_ty)], [elem_ty]),
        type_args=[ht.TypeTypeArg(elem_ty)],
    )
    map_op = array_map(ht.Option(elem_ty), size_arg, elem_ty)
    unwrapped_array = builder.add_op(map_op, input_array, unwrap)

    # 2. Optionally turn opaque bools into Hugr bools
    if convert_bool:
        array_read = array_read_bool(ctx)
        array_read = builder.load_function(array_read)
        map_op = array_map(OpaqueBool, size_arg, ht.Bool)
        unwrapped_array = builder.add_op(map_op, unwrapped_array, array_read)
        # From here on the elements are Hugr bools
        elem_ty = ht.Bool

    # 3. Convert the value array into a standard array
    unwrapped_array = builder.add_op(
        array_convert_to_std_array(elem_ty, size_arg), unwrapped_array
    )

    result_array = builder.add_op(op, unwrapped_array)

    # Reverse 3: convert the standard array back into a value array
    result_array = builder.add_op(
        array_convert_from_std_array(elem_ty, size_arg), result_array
    )

    # Reverse 2: turn Hugr bools back into opaque bools
    if convert_bool:
        array_make_opaque = array_make_opaque_bool(ctx)
        array_make_opaque = builder.load_function(array_make_opaque)
        map_op = array_map(ht.Bool, size_arg, OpaqueBool)
        result_array = builder.add_op(map_op, result_array, array_make_opaque)
        # From here on the elements are opaque bools again
        elem_ty = OpaqueBool

    # Reverse 1: wrap the elements in options again
    wrap = array_wrap_elem(ctx)
    wrap = builder.load_function(
        wrap,
        instantiation=ht.FunctionType([elem_ty], [ht.Option(elem_ty)]),
        type_args=[ht.TypeTypeArg(elem_ty)],
    )
    map_op = array_map(elem_ty, size_arg, ht.Option(elem_ty))
    return builder.add_op(map_op, result_array, wrap)