guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,136 @@
1
+ """Native arithmetic operations from the HUGR std, and compilers for non native ones."""
2
+
3
+ from collections.abc import Sequence
4
+ from dataclasses import dataclass
5
+
6
+ import hugr.std.int
7
+ from hugr import model, ops, val
8
+ from hugr import tys as ht
9
+ from hugr.std.int import int_t
10
+
11
+ from guppylang_internals.std._internal.compiler.prelude import error_type
12
+ from guppylang_internals.tys.ty import NumericType
13
+
14
# Hugr integer type used for Guppy's native `int`, built from the width constant
# `NumericType.INT_WIDTH`.
INT_T = int_t(NumericType.INT_WIDTH)
15
+
16
+
17
@dataclass
class UnsignedIntVal(val.ExtensionValue):  # TODO: Upstream this to hugr-py?
    """Custom value for an unsigned integer.

    Serialised as an `arithmetic.int` `ConstInt` extension value.

    Attributes:
        v: The integer value. Must be non-negative and fit into `2**width` bits.
        width: The log2 of the bit width. It is stored under the `log_width` key
            of the payload and passed to `int_t`.
    """

    v: int
    width: int

    def __post_init__(self) -> None:
        # Validate with explicit exceptions rather than `assert`, so the checks are
        # not stripped when Python runs with `-O`. An out-of-range value would
        # otherwise silently produce an invalid Hugr constant.
        if self.width < 0:
            raise ValueError(f"Integer log-width must be non-negative, got {self.width}")
        if self.v < 0:
            raise ValueError(f"Unsigned integer value must be non-negative, got {self.v}")
        # `bit_length` avoids materialising `2 ** (2 ** width)` for large widths.
        if self.v.bit_length() > (1 << self.width):
            raise ValueError(
                f"Value {self.v} does not fit into an unsigned integer of "
                f"log-width {self.width}"
            )

    def to_value(self) -> val.Extension:
        payload = {"log_width": self.width, "value": self.v}
        return val.Extension("ConstInt", typ=int_t(self.width), val=payload)

    def __str__(self) -> str:
        return f"{self.v}"

    def to_model(self) -> model.Term:
        return model.Apply(
            "arithmetic.int.const", [model.Literal(self.width), model.Literal(self.v)]
        )
38
+
39
+
40
+ # ------------------------------------------------------
41
+ # --------- std.arithmetic.int operations --------------
42
+ # ------------------------------------------------------
43
+
44
+
45
def _instantiate_int_op(
    name: str,
    int_width: int | Sequence[int],
    inp: list[ht.Type],
    out: list[ht.Type],
) -> ops.ExtOp:
    """Instantiate the `std.arithmetic.int` operation called `name`.

    `int_width` provides the op's bounded-nat type arguments. It is either a single
    width, or a sequence of widths for ops parameterised by several widths (e.g. a
    source and a target width). `inp` and `out` form the concrete signature.
    """
    op_def = hugr.std.int.INT_OPS_EXTENSION.get_op(name)
    widths = (int_width,) if isinstance(int_width, int) else int_width
    signature = ht.FunctionType(inp, out)
    type_args = [ht.BoundedNatArg(w) for w in widths]
    return ops.ExtOp(op_def, signature, type_args)
58
+
59
+
60
def ieq(width: int) -> ops.ExtOp:
    """Build a `std.arithmetic.int.ieq` op: two `int_t(width)` inputs, `Bool` out."""
    operand_ty = int_t(width)
    return _instantiate_int_op("ieq", width, inp=[operand_ty, operand_ty], out=[ht.Bool])
63
+
64
+
65
def ine(width: int) -> ops.ExtOp:
    """Build a `std.arithmetic.int.ine` op: two `int_t(width)` inputs, `Bool` out."""
    operand_ty = int_t(width)
    return _instantiate_int_op("ine", width, inp=[operand_ty, operand_ty], out=[ht.Bool])
68
+
69
+
70
def iwiden_u(from_width: int, to_width: int) -> ops.ExtOp:
    """Returns an unsigned `std.arithmetic.int.iwiden_u` operation.

    Converts a value of type `int_t(from_width)` into `int_t(to_width)`. The op is
    instantiated with the type arguments `[from_width, to_width]`.
    """
    return _instantiate_int_op(
        "iwiden_u", [from_width, to_width], [int_t(from_width)], [int_t(to_width)]
    )
75
+
76
+
77
def iwiden_s(from_width: int, to_width: int) -> ops.ExtOp:
    """Returns a signed `std.arithmetic.int.iwiden_s` operation.

    Converts a value of type `int_t(from_width)` into `int_t(to_width)`. The op is
    instantiated with the type arguments `[from_width, to_width]`.
    """
    return _instantiate_int_op(
        "iwiden_s", [from_width, to_width], [int_t(from_width)], [int_t(to_width)]
    )
82
+
83
+
84
def inarrow_u(from_width: int, to_width: int) -> ops.ExtOp:
    """Returns an unsigned `std.arithmetic.int.inarrow_u` operation.

    Takes a value of type `int_t(from_width)` and outputs an `Either`:
    - the left variant carries an error value (presumably produced when the input
      does not fit into the target width);
    - the right variant carries the narrowed `int_t(to_width)`.
    """
    return _instantiate_int_op(
        "inarrow_u",
        [from_width, to_width],
        [int_t(from_width)],
        [ht.Either([error_type()], [int_t(to_width)])],
    )
92
+
93
+
94
def inarrow_s(from_width: int, to_width: int) -> ops.ExtOp:
    """Returns a signed `std.arithmetic.int.inarrow_s` operation.

    Takes a value of type `int_t(from_width)` and outputs an `Either`:
    - the left variant carries an error value (presumably produced when the input
      does not fit into the target width);
    - the right variant carries the narrowed `int_t(to_width)`.
    """
    return _instantiate_int_op(
        "inarrow_s",
        [from_width, to_width],
        [int_t(from_width)],
        [ht.Either([error_type()], [int_t(to_width)])],
    )
102
+
103
+
104
+ # ------------------------------------------------------
105
+ # --------- std.arithmetic.conversions ops -------------
106
+ # ------------------------------------------------------
107
+
108
+
109
def _instantiate_convert_op(
    name: str,
    inp: list[ht.Type],
    out: list[ht.Type],
    args: list[ht.TypeArg] | None = None,
) -> ops.ExtOp:
    """Instantiate the `std.arithmetic.conversions` operation called `name`.

    A missing or empty `args` is replaced by a fresh empty type-argument list.
    """
    op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op(name)
    type_args = args if args else []
    signature = ht.FunctionType(inp, out)
    return ops.ExtOp(op_def, signature, type_args)
117
+
118
+
119
def convert_ifromusize() -> ops.ExtOp:
    """Build a `std.arithmetic.conversions.ifromusize` op (`USize` -> `INT_T`)."""
    return _instantiate_convert_op("ifromusize", inp=[ht.USize()], out=[INT_T])
122
+
123
+
124
def convert_itousize() -> ops.ExtOp:
    """Build a `std.arithmetic.conversions.itousize` op (`INT_T` -> `USize`)."""
    return _instantiate_convert_op("itousize", inp=[INT_T], out=[ht.USize()])
127
+
128
+
129
def convert_ifrombool() -> ops.ExtOp:
    """Build a `std.arithmetic.conversions.ifrombool` op (`Bool` -> `int_t(0)`)."""
    return _instantiate_convert_op("ifrombool", inp=[ht.Bool], out=[int_t(0)])
132
+
133
+
134
def convert_itobool() -> ops.ExtOp:
    """Build a `std.arithmetic.conversions.itobool` op (`int_t(0)` -> `Bool`)."""
    return _instantiate_convert_op("itobool", inp=[int_t(0)], out=[ht.Bool])
@@ -0,0 +1,569 @@
1
+ """Compilers building array functions on top of hugr standard operations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Final, TypeVar
6
+
7
+ import hugr
8
+ from hugr import Wire, ops
9
+ from hugr import tys as ht
10
+ from hugr.std.collections.value_array import EXTENSION
11
+
12
+ from guppylang_internals.compiler.core import (
13
+ GlobalConstId,
14
+ )
15
+ from guppylang_internals.definition.custom import CustomCallCompiler
16
+ from guppylang_internals.definition.value import CallReturnWires
17
+ from guppylang_internals.error import InternalGuppyError
18
+ from guppylang_internals.std._internal.compiler.arithmetic import convert_itousize
19
+ from guppylang_internals.std._internal.compiler.prelude import (
20
+ build_expect_none,
21
+ build_unwrap,
22
+ build_unwrap_left,
23
+ build_unwrap_right,
24
+ )
25
+ from guppylang_internals.tys.arg import ConstArg, TypeArg
26
+ from guppylang_internals.tys.builtin import int_type
27
+
28
+ if TYPE_CHECKING:
29
+ from hugr.build import function as hf
30
+ from hugr.build.dfg import DfBase
31
+
32
+
33
+ # ------------------------------------------------------
34
+ # --------------- std.array operations -----------------
35
+ # ------------------------------------------------------
36
+
37
+
38
def _instantiate_array_op(
    name: str,
    elem_ty: ht.Type,
    length: ht.TypeArg,
    inp: list[ht.Type],
    out: list[ht.Type],
) -> ops.ExtOp:
    """Instantiate the `value_array` extension operation called `name`.

    The operation receives the type arguments `[length, elem_ty]` and the concrete
    signature `inp -> out`.
    """
    op_def = EXTENSION.get_op(name)
    type_args = [length, ht.TypeTypeArg(elem_ty)]
    signature = ht.FunctionType(inp, out)
    return op_def.instantiate(type_args, signature)
48
+
49
+
50
def array_type(elem_ty: ht.Type, length: ht.TypeArg) -> ht.ExtType:
    """Returns the hugr type of a fixed length array.

    This is the copyable `value_array` type used by Guppy, instantiated with the
    type arguments `[length, elem_ty]`.
    """
    type_args = [length, ht.TypeTypeArg(elem_ty)]
    value_array_def = EXTENSION.types["value_array"]
    return value_array_def.instantiate(type_args)
57
+
58
+
59
def standard_array_type(elem_ty: ht.Type, length: ht.TypeArg) -> ht.ExtType:
    """Returns the hugr type of a linear fixed length array.

    This is the standard `array` type (from `hugr.std.collections.array`) targeted
    by Hugr, instantiated with the type arguments `[length, elem_ty]`.
    """
    type_args = [length, ht.TypeTypeArg(elem_ty)]
    std_array_def = hugr.std.collections.array.EXTENSION.types["array"]
    return std_array_def.instantiate(type_args)
67
+
68
+
69
def array_new(elem_ty: ht.Type, length: int) -> ops.ExtOp:
    """Returns an operation that creates a new fixed length array.

    The op takes `length` inputs of type `elem_ty` and outputs a single array.
    """
    size = ht.BoundedNatArg(length)
    result_ty = array_type(elem_ty, size)
    elem_inputs = [elem_ty for _ in range(length)]
    return _instantiate_array_op("new_array", elem_ty, size, elem_inputs, [result_ty])
76
+
77
+
78
def array_unpack(elem_ty: ht.Type, length: int) -> ops.ExtOp:
    """Returns an operation that unpacks a fixed length array.

    The op takes one array input and outputs its `length` elements individually.
    """
    size = ht.BoundedNatArg(length)
    packed_ty = array_type(elem_ty, size)
    elem_outputs = [elem_ty for _ in range(length)]
    return _instantiate_array_op("unpack", elem_ty, size, [packed_ty], elem_outputs)
85
+
86
+
87
def array_get(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
    """Returns an array `get` operation.

    Signature: `(array, USize) -> (Option(elem_ty), array)`. The element type must
    be copyable (asserted), since the element is returned alongside the array.
    """
    assert elem_ty.type_bound() == ht.TypeBound.Copyable
    arr_ty = array_type(elem_ty, length)
    inputs = [arr_ty, ht.USize()]
    outputs = [ht.Option(elem_ty), arr_ty]
    return _instantiate_array_op("get", elem_ty, length, inputs, outputs)
94
+
95
+
96
def array_set(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
    """Returns an array `set` operation.

    Signature: `(array, USize, elem_ty) -> Either([elem_ty, array], [elem_ty, array])`.
    Both variants of the result carry an element and the array.
    """
    arr_ty = array_type(elem_ty, length)
    inputs = [arr_ty, ht.USize(), elem_ty]
    result_ty = ht.Either([elem_ty, arr_ty], [elem_ty, arr_ty])
    return _instantiate_array_op("set", elem_ty, length, inputs, [result_ty])
106
+
107
+
108
def array_pop(elem_ty: ht.Type, length: int, from_left: bool) -> ops.ExtOp:
    """Returns an operation that pops an element from one end of an array.

    Args:
        elem_ty: Element type of the array.
        length: Length of the input array. Must be positive (asserted). The
            remaining array has length `length - 1`.
        from_left: Pop from the left end (`pop_left`) if true, otherwise from the
            right end (`pop_right`).

    The op outputs an `Option` whose `Some` variant holds the popped element
    together with the shortened array.
    """
    assert length > 0
    length_arg = ht.BoundedNatArg(length)
    arr_ty = array_type(elem_ty, length_arg)
    popped_arr_ty = array_type(elem_ty, ht.BoundedNatArg(length - 1))
    op = "pop_left" if from_left else "pop_right"
    return _instantiate_array_op(
        op, elem_ty, length_arg, [arr_ty], [ht.Option(elem_ty, popped_arr_ty)]
    )
118
+
119
+
120
def array_discard_empty(elem_ty: ht.Type) -> ops.ExtOp:
    """Returns an operation that discards an array of length zero.

    Only the element type is passed as a type argument; the length is fixed at 0.
    """
    empty_arr_ty = array_type(elem_ty, ht.BoundedNatArg(0))
    discard_def = EXTENSION.get_op("discard_empty")
    signature = ht.FunctionType([empty_arr_ty], [])
    return discard_def.instantiate([ht.TypeTypeArg(elem_ty)], signature)
126
+
127
+
128
def array_scan(
    elem_ty: ht.Type,
    length: ht.TypeArg,
    new_elem_ty: ht.Type,
    accumulators: list[ht.Type],
) -> ops.ExtOp:
    """Returns an operation that maps and folds a function across an array.

    Inputs are the array, a step function `(elem, *accs) -> (new_elem, *accs)` and
    the initial accumulator values. Outputs are the mapped array and the final
    accumulators.
    """
    acc_list_arg = ht.ListArg([ht.TypeTypeArg(acc_ty) for acc_ty in accumulators])
    type_args = [
        length,
        ht.TypeTypeArg(elem_ty),
        ht.TypeTypeArg(new_elem_ty),
        acc_list_arg,
    ]
    step_fn_ty = ht.FunctionType(
        [elem_ty, *accumulators], [new_elem_ty, *accumulators]
    )
    inputs = [array_type(elem_ty, length), step_fn_ty, *accumulators]
    outputs = [array_type(new_elem_ty, length), *accumulators]
    scan_def = EXTENSION.get_op("scan")
    return scan_def.instantiate(type_args, ht.FunctionType(inputs, outputs))
148
+
149
+
150
def array_map(elem_ty: ht.Type, length: ht.TypeArg, new_elem_ty: ht.Type) -> ops.ExtOp:
    """Returns an operation that maps a function across an array.

    This is simply a `scan` without any accumulators.
    """
    no_accumulators: list[ht.Type] = []
    return array_scan(elem_ty, length, new_elem_ty, no_accumulators)
153
+
154
+
155
def array_repeat(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
    """Returns an array `repeat` operation.

    The op takes a nullary function producing an `elem_ty` and outputs an array.
    """
    producer_ty = ht.FunctionType([], [elem_ty])
    signature = ht.FunctionType([producer_ty], [array_type(elem_ty, length)])
    type_args = [length, ht.TypeTypeArg(elem_ty)]
    return EXTENSION.get_op("repeat").instantiate(type_args, signature)
163
+
164
+
165
def array_convert_to_std_array(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
    """Returns an array operation to convert the `value_array` type used by Guppy into
    the regular linear `array` in Hugr.
    """
    value_arr_ty = array_type(elem_ty, length)
    std_arr_ty = standard_array_type(elem_ty, length)
    type_args = [length, ht.TypeTypeArg(elem_ty)]
    signature = ht.FunctionType([value_arr_ty], [std_arr_ty])
    return EXTENSION.get_op("to_array").instantiate(type_args, signature)
175
+
176
+
177
def array_convert_from_std_array(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
    """Returns an array operation to convert the `array` type used by Hugr into the
    `value_array` type used by Guppy.
    """
    std_arr_ty = standard_array_type(elem_ty, length)
    value_arr_ty = array_type(elem_ty, length)
    type_args = [length, ht.TypeTypeArg(elem_ty)]
    signature = ht.FunctionType([std_arr_ty], [value_arr_ty])
    return EXTENSION.get_op("from_array").instantiate(type_args, signature)
187
+
188
+
189
+ # ------------------------------------------------------
190
+ # --------- Custom compilers for non-native ops --------
191
+ # ------------------------------------------------------
192
+
193
+
194
# Parent-op type of the dataflow builder accepted by `unpack_array`.
P = TypeVar("P", bound=ops.DfParentOp)
195
+
196
+
197
+ def unpack_array(builder: DfBase[P], array: Wire) -> list[Wire]:
198
+ """Unpacks a fixed length array into its elements."""
199
+ array_ty = builder.hugr.port_type(array.out_port())
200
+ assert isinstance(array_ty, ht.ExtType)
201
+ match array_ty.args:
202
+ case [ht.BoundedNatArg(length), ht.TypeTypeArg(elem_ty)]:
203
+ res = builder.add_op(array_unpack(elem_ty, length), array)
204
+ return [res[i] for i in range(length)]
205
+ case _:
206
+ raise InternalGuppyError("Invalid array type args")
207
+
208
+
209
+ class ArrayCompiler(CustomCallCompiler):
210
+ """Base class for custom array op compilers."""
211
+
212
+ @property
213
+ def elem_ty(self) -> ht.Type:
214
+ """The element type for the array op that is being compiled."""
215
+ match self.type_args:
216
+ case [TypeArg(ty=elem_ty), _]:
217
+ return elem_ty.to_hugr(self.ctx)
218
+ case _:
219
+ raise InternalGuppyError("Invalid array type args")
220
+
221
+ @property
222
+ def length(self) -> ht.TypeArg:
223
+ """The length for the array op that is being compiled."""
224
+ match self.type_args:
225
+ case [_, ConstArg(const)]: # Const includes both literals and variables
226
+ return const.to_arg().to_hugr(self.ctx)
227
+ case _:
228
+ raise InternalGuppyError("Invalid array type args")
229
+
230
+
231
class NewArrayCompiler(ArrayCompiler):
    """Compiler for the `array.__new__` function.

    Each element is wrapped in `Some` and stored in a `value_array` of
    `Option(elem_ty)`. Classical arrays currently reuse the linear lowering
    (see https://github.com/CQCL/guppylang/issues/629).
    """

    def build_classical_array(self, elems: list[Wire]) -> Wire:
        """Lowers a call to `array.__new__` for classical arrays."""
        # See https://github.com/CQCL/guppylang/issues/629
        return self.build_linear_array(elems)

    def build_linear_array(self, elems: list[Wire]) -> Wire:
        """Lowers a call to `array.__new__` for linear arrays.

        Returns the wire carrying the newly created array.
        """
        # `self.elem_ty` lowers the Guppy type to Hugr on every access. Compute it
        # once here instead of once per element.
        elem_ty = self.elem_ty
        elem_opts = [self.builder.add_op(ops.Some(elem_ty), elem) for elem in elems]
        return self.builder.add_op(
            array_new(ht.Option(elem_ty), len(elems)), *elem_opts
        )

    def compile(self, args: list[Wire]) -> list[Wire]:
        """Dispatches on the Hugr bound of the element type."""
        if self.elem_ty.type_bound() == ht.TypeBound.Linear:
            return [self.build_linear_array(args)]
        else:
            return [self.build_classical_array(args)]
253
+
254
+
255
# Identifiers of the helper Hugr functions emitted by the array compilers below.
# Each one is passed to `ctx.declare_global_func`, which also reports whether the
# function already exists. The compilers build the helper's body only when it
# does not exist yet.
ARRAY_GETITEM_CLASSICAL: Final[GlobalConstId] = GlobalConstId.fresh(
    "array.__getitem__.classical"
)
ARRAY_GETITEM_LINEAR: Final[GlobalConstId] = GlobalConstId.fresh(
    "array.__getitem__.linear"
)
ARRAY_SETITEM_CLASSICAL: Final[GlobalConstId] = GlobalConstId.fresh(
    "array.__setitem__.classical"
)
ARRAY_SETITEM_LINEAR: Final[GlobalConstId] = GlobalConstId.fresh(
    "array.__setitem__.linear"
)
ARRAY_ITER_ASSERT_ALL_USED_HELPER: Final[GlobalConstId] = GlobalConstId.fresh(
    "ArrayIter._assert_all_used.helper"
)
270
+
271
+
272
class ArrayGetitemCompiler(ArrayCompiler):
    """Compiler for the `array.__getitem__` function.

    A call is lowered to a call of a polymorphic helper function. The helper is
    declared as a global function and built only if it does not exist yet. There
    is one helper for copyable element types and one for linear element types.
    The helper is instantiated at the concrete element type and length. The array
    is returned as an inout value.
    """

    def _getitem_ty(self, bound: ht.TypeBound) -> ht.PolyFuncType:
        """Constructs a polymorphic function type for `__getitem__`.

        Type parameters:
        - variable 0 is the element type, with the given `bound`;
        - variable 1 is the array length.
        """
        # a(Option(T), N), int -> T, a(Option(T), N)
        # Array element type parameter
        elem_ty_param = ht.TypeTypeParam(bound)
        # Array length parameter
        length_param = ht.BoundedNatParam()
        return ht.PolyFuncType(
            params=[elem_ty_param, length_param],
            body=ht.FunctionType(
                input=[
                    array_type(
                        ht.Option(ht.Variable(0, bound)),
                        ht.VariableArg(1, length_param),
                    ),
                    int_type().to_hugr(self.ctx),
                ],
                output=[
                    ht.Variable(0, bound),
                    array_type(
                        ht.Option(ht.Variable(0, bound)),
                        ht.VariableArg(1, length_param),
                    ),
                ],
            ),
        )

    def _build_classical_getitem(self, func: hf.Function) -> None:
        """Constructs a generic function for `__getitem__` for classical arrays.

        The indexed slot is swapped out for `None` and the retrieved option is then
        written back, so the returned array contains the same elements as the
        input. Both `set` results are unwrapped with the error message
        "Array index out of bounds".
        """
        elem_ty = ht.Variable(0, ht.TypeBound.Copyable)
        length = ht.VariableArg(1, ht.BoundedNatParam())

        # See https://github.com/CQCL/guppylang/issues/629
        elem_opt_ty = ht.Option(elem_ty)
        # Tag 0 of `Option` is the `None` variant
        none = func.add_op(ops.Tag(0, elem_opt_ty))
        # The Guppy `int` index is converted to the `USize` expected by `set`
        idx = func.add_op(convert_itousize(), func.inputs()[1])
        # As copyable elements can be used multiple times, we need to swap the element
        # back after initially swapping it out for `None` to get the value.
        initial_result = func.add_op(
            array_set(elem_opt_ty, length),
            func.inputs()[0],
            idx,
            none,
        )
        elem_opt, arr = build_unwrap_right(
            func, initial_result, "Array index out of bounds"
        )
        # `elem_opt` is copyable, so it may feed both the write-back and the unwrap
        swapped_back = func.add_op(
            array_set(elem_opt_ty, length),
            arr,
            idx,
            elem_opt,
        )
        _, arr = build_unwrap_right(func, swapped_back, "Array index out of bounds")
        elem = build_unwrap(func, elem_opt, "array.__getitem__: Internal error")

        func.set_outputs(elem, arr)

    def _build_linear_getitem(self, func: hf.Function) -> None:
        """Constructs the generic function for `__getitem__` on linear arrays.

        The indexed slot is replaced by `None` and is NOT restored. Reading the
        same index again therefore reaches the "Linear array element has already
        been used" unwrap.
        """
        elem_ty = ht.Variable(0, ht.TypeBound.Linear)
        length = ht.VariableArg(1, ht.BoundedNatParam())

        elem_opt_ty = ht.Option(elem_ty)
        # Tag 0 of `Option` is the `None` variant
        none = func.add_op(ops.Tag(0, elem_opt_ty))
        # The Guppy `int` index is converted to the `USize` expected by `set`
        idx = func.add_op(convert_itousize(), func.inputs()[1])
        result = func.add_op(
            array_set(elem_opt_ty, length),
            func.inputs()[0],
            idx,
            none,
        )
        elem_opt, array = build_unwrap_right(func, result, "Array index out of bounds")
        elem = build_unwrap(
            func, elem_opt, "Linear array element has already been used"
        )

        func.set_outputs(elem, array)

    def _build_call_getitem(
        self,
        func: hf.Function,
        array: Wire,
        idx: Wire,
    ) -> CallReturnWires:
        """Inserts a call to `array.__getitem__`.

        `func` is instantiated at the concrete element type and length. The first
        output (the element) becomes the regular return. The second output (the
        array) becomes the inout return.
        """
        concrete_func_ty = ht.FunctionType(
            input=[
                array_type(ht.Option(self.elem_ty), self.length),
                int_type().to_hugr(self.ctx),
            ],
            output=[self.elem_ty, array_type(ht.Option(self.elem_ty), self.length)],
        )
        # Order matches the type parameters declared in `_getitem_ty`
        type_args = [ht.TypeTypeArg(self.elem_ty), self.length]
        func_call = self.builder.call(
            func.parent_node,
            array,
            idx,
            instantiation=concrete_func_ty,
            type_args=type_args,
        )
        outputs = list(func_call.outputs())
        return CallReturnWires(
            regular_returns=[outputs[0]],
            inout_returns=[outputs[1]],
        )

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        """Compiles `array.__getitem__(array, idx)`.

        The helper variant is chosen by the copyability of the Guppy element type.
        The helper's body is built the first time it is declared.
        """
        [array, idx] = args
        [elem_ty_arg, _] = self.type_args
        assert isinstance(elem_ty_arg, TypeArg)
        if not elem_ty_arg.ty.copyable:
            func_ty = self._getitem_ty(ht.TypeBound.Linear)
            func, already_exists = self.ctx.declare_global_func(
                ARRAY_GETITEM_LINEAR, func_ty
            )
            if not already_exists:
                self._build_linear_getitem(func)
        else:
            func_ty = self._getitem_ty(ht.TypeBound.Copyable)
            func, already_exists = self.ctx.declare_global_func(
                ARRAY_GETITEM_CLASSICAL, func_ty
            )
            if not already_exists:
                self._build_classical_getitem(func)
        return self._build_call_getitem(
            func=func,
            array=array,
            idx=idx,
        )

    def compile(self, args: list[Wire]) -> list[Wire]:
        # Only `compile_with_inouts` is supported, because the array is an inout value
        raise InternalGuppyError("Call compile_with_inouts instead")
408
+
409
+
410
class ArraySetitemCompiler(ArrayCompiler):
    """Compiler for the `array.__setitem__` function.

    A call is lowered to a call of a polymorphic helper function. The helper is
    declared as a global function and built only if it does not exist yet. There
    is one helper for linear element types and one for all others. The updated
    array is returned as an inout value.
    """

    def _setitem_ty(self, bound: ht.TypeBound) -> ht.PolyFuncType:
        """Constructs a polymorphic function type for `__setitem__`.

        Type parameters:
        - variable 0 is the element type, with the given `bound`;
        - variable 1 is the array length.
        """
        # a(Option(T), N), int, T -> a(Option(T), N)
        elem_ty_param = ht.TypeTypeParam(bound)
        length_param = ht.BoundedNatParam()
        return ht.PolyFuncType(
            params=[elem_ty_param, length_param],
            body=ht.FunctionType(
                input=[
                    array_type(
                        ht.Option(ht.Variable(0, bound)),
                        ht.VariableArg(1, length_param),
                    ),
                    int_type().to_hugr(self.ctx),
                    ht.Variable(0, bound),
                ],
                output=[
                    array_type(
                        ht.Option(ht.Variable(0, bound)),
                        ht.VariableArg(1, length_param),
                    ),
                ],
            ),
        )

    def _build_classical_setitem(self, func: hf.Function) -> None:
        """Constructs a generic function for `__setitem__` for classical arrays.

        The new element is written as `Some(elem)`. The previous slot value is
        discarded, which is allowed because it is copyable.
        """
        elem_ty = ht.Variable(0, ht.TypeBound.Copyable)
        length = ht.VariableArg(1, ht.BoundedNatParam())

        elem_opt_ty = ht.Option(elem_ty)
        # The Guppy `int` index is converted to the `USize` expected by `set`
        idx = func.add_op(convert_itousize(), func.inputs()[1])
        elem_opt = func.add_op(ops.Some(elem_ty), func.inputs()[2])
        result = func.add_op(
            array_set(elem_opt_ty, length),
            func.inputs()[0],
            idx,
            elem_opt,
        )
        _, array = build_unwrap_right(func, result, "Array index out of bounds")

        func.set_outputs(array)

    def _build_linear_setitem(self, func: hf.Function) -> None:
        """Constructs the generic function for `__setitem__` on linear arrays.

        The new element is written as `Some(elem)`. The previous slot value is
        then unwrapped with `build_unwrap_left` and the message "Linear array
        element has not been used". For an `Option`, the left variant is the
        `None` variant.
        """
        elem_ty = ht.Variable(0, ht.TypeBound.Linear)
        length = ht.VariableArg(1, ht.BoundedNatParam())

        elem_opt_ty = ht.Option(elem_ty)
        # `elem` is the new element already wrapped in `Some`
        elem = func.add_op(ops.Some(elem_ty), func.inputs()[2])
        # The Guppy `int` index is converted to the `USize` expected by `set`
        idx = func.add_op(convert_itousize(), func.inputs()[1])
        result = func.add_op(
            array_set(elem_opt_ty, length),
            func.inputs()[0],
            idx,
            elem,
        )
        old_elem_opt, array = build_unwrap_right(
            func, result, "Array index out of bounds"
        )
        build_unwrap_left(func, old_elem_opt, "Linear array element has not been used")

        func.set_outputs(array)

    def _build_call_setitem(
        self,
        func: hf.Function,
        array: Wire,
        idx: Wire,
        elem: Wire,
    ) -> CallReturnWires:
        """Inserts a call to `array.__setitem__`.

        `func` is instantiated at the concrete element type and length. There is
        no regular return; the updated array is the only inout return.
        """
        concrete_func_ty = ht.FunctionType(
            input=[
                array_type(ht.Option(self.elem_ty), self.length),
                int_type().to_hugr(self.ctx),
                self.elem_ty,
            ],
            output=[array_type(ht.Option(self.elem_ty), self.length)],
        )
        # Order matches the type parameters declared in `_setitem_ty`
        type_args = [ht.TypeTypeArg(self.elem_ty), self.length]
        func_call = self.builder.call(
            func.parent_node,
            array,
            idx,
            elem,
            instantiation=concrete_func_ty,
            type_args=type_args,
        )
        return CallReturnWires(
            regular_returns=[],
            inout_returns=list(func_call.outputs()),
        )

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        """Compiles `array.__setitem__(array, idx, elem)`.

        The helper variant is chosen by the Hugr bound of the element type. The
        helper's body is built the first time it is declared.
        """
        [array, idx, elem] = args
        if self.elem_ty.type_bound() == ht.TypeBound.Linear:
            func_ty = self._setitem_ty(ht.TypeBound.Linear)
            func, already_exists = self.ctx.declare_global_func(
                ARRAY_SETITEM_LINEAR, func_ty
            )
            if not already_exists:
                self._build_linear_setitem(func)
        else:
            func_ty = self._setitem_ty(ht.TypeBound.Copyable)
            func, already_exists = self.ctx.declare_global_func(
                ARRAY_SETITEM_CLASSICAL, func_ty
            )
            if not already_exists:
                self._build_classical_setitem(func)
        return self._build_call_setitem(func=func, array=array, idx=idx, elem=elem)

    def compile(self, args: list[Wire]) -> list[Wire]:
        # Only `compile_with_inouts` is supported, because the array is an inout value
        raise InternalGuppyError("Call compile_with_inouts instead")
527
+
528
+
529
class ArrayIterAsertAllUsedCompiler(ArrayCompiler):
    """Compiler for the `ArrayIter._assert_all_used` method.

    For linear element types, a helper that expects `None` is mapped over every
    slot of the iterator's underlying array. Nothing is emitted for other element
    types.
    """

    def compile(self, args: list[Wire]) -> list[Wire]:
        # Non-linear elements need no check at all.
        if self.elem_ty.type_bound() != ht.TypeBound.Linear:
            return []

        opt_elem_ty = ht.Option(self.elem_ty)
        unit_ty = ht.UnitSum(1)
        # Load the generic `unwrap_none` helper, instantiated at our element type.
        unwrap_none_fn = self.builder.load_function(
            self.define_unwrap_none_helper(),
            type_args=[ht.TypeTypeArg(self.elem_ty)],
            instantiation=ht.FunctionType([opt_elem_ty], [unit_ty]),
        )
        # The first component of the iterator tuple is the array of optional
        # elements. Mapping the helper over it yields an array of units, which is
        # no longer linear and can therefore be discarded.
        [array_iter] = args
        opt_array, _ = self.builder.add_op(ops.UnpackTuple(), array_iter)
        self.builder.add_op(
            array_map(opt_elem_ty, self.length, unit_ty), opt_array, unwrap_none_fn
        )
        return []

    def define_unwrap_none_helper(self) -> hf.Function:
        """Define an `unwrap_none` function that checks the passed element is `None`.

        The function is declared as a global and its body is built only if it did
        not exist before. It returns an empty tuple.
        """
        unit_ty = ht.UnitSum(1)
        linear_opt_ty = ht.Option(ht.Variable(0, ht.TypeBound.Linear))
        poly_ty = ht.PolyFuncType(
            params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
            body=ht.FunctionType([linear_opt_ty], [unit_ty]),
        )
        helper, exists = self.ctx.declare_global_func(
            ARRAY_ITER_ASSERT_ALL_USED_HELPER, poly_ty
        )
        if exists:
            return helper
        message = "ArrayIter._assert_all_used: array element has not been used"
        build_expect_none(helper, helper.inputs()[0], message)
        helper.set_outputs(helper.add_op(ops.MakeTuple()))
        return helper