guppylang-internals 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +21 -0
  3. guppylang_internals/cfg/bb.py +20 -0
  4. guppylang_internals/cfg/builder.py +118 -5
  5. guppylang_internals/cfg/cfg.py +3 -0
  6. guppylang_internals/checker/cfg_checker.py +6 -0
  7. guppylang_internals/checker/core.py +5 -2
  8. guppylang_internals/checker/errors/generic.py +32 -1
  9. guppylang_internals/checker/errors/type_errors.py +14 -0
  10. guppylang_internals/checker/errors/wasm.py +7 -4
  11. guppylang_internals/checker/expr_checker.py +58 -17
  12. guppylang_internals/checker/func_checker.py +18 -14
  13. guppylang_internals/checker/linearity_checker.py +67 -10
  14. guppylang_internals/checker/modifier_checker.py +120 -0
  15. guppylang_internals/checker/stmt_checker.py +48 -1
  16. guppylang_internals/checker/unitary_checker.py +132 -0
  17. guppylang_internals/compiler/cfg_compiler.py +7 -6
  18. guppylang_internals/compiler/core.py +93 -56
  19. guppylang_internals/compiler/expr_compiler.py +72 -168
  20. guppylang_internals/compiler/modifier_compiler.py +176 -0
  21. guppylang_internals/compiler/stmt_compiler.py +15 -8
  22. guppylang_internals/decorator.py +86 -7
  23. guppylang_internals/definition/custom.py +39 -1
  24. guppylang_internals/definition/declaration.py +9 -6
  25. guppylang_internals/definition/function.py +12 -2
  26. guppylang_internals/definition/parameter.py +8 -3
  27. guppylang_internals/definition/pytket_circuits.py +14 -41
  28. guppylang_internals/definition/struct.py +13 -7
  29. guppylang_internals/definition/ty.py +3 -3
  30. guppylang_internals/definition/wasm.py +42 -10
  31. guppylang_internals/engine.py +9 -3
  32. guppylang_internals/experimental.py +5 -0
  33. guppylang_internals/nodes.py +147 -24
  34. guppylang_internals/std/_internal/checker.py +13 -108
  35. guppylang_internals/std/_internal/compiler/array.py +95 -283
  36. guppylang_internals/std/_internal/compiler/list.py +1 -1
  37. guppylang_internals/std/_internal/compiler/platform.py +153 -0
  38. guppylang_internals/std/_internal/compiler/prelude.py +12 -4
  39. guppylang_internals/std/_internal/compiler/tket_exts.py +8 -2
  40. guppylang_internals/std/_internal/debug.py +18 -9
  41. guppylang_internals/std/_internal/util.py +1 -1
  42. guppylang_internals/tracing/object.py +10 -0
  43. guppylang_internals/tracing/unpacking.py +19 -20
  44. guppylang_internals/tys/arg.py +18 -3
  45. guppylang_internals/tys/builtin.py +2 -5
  46. guppylang_internals/tys/const.py +33 -4
  47. guppylang_internals/tys/errors.py +23 -1
  48. guppylang_internals/tys/param.py +31 -16
  49. guppylang_internals/tys/parsing.py +11 -24
  50. guppylang_internals/tys/printing.py +2 -8
  51. guppylang_internals/tys/qubit.py +62 -0
  52. guppylang_internals/tys/subst.py +8 -26
  53. guppylang_internals/tys/ty.py +91 -85
  54. guppylang_internals/wasm_util.py +129 -0
  55. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +6 -5
  56. guppylang_internals-0.26.0.dist-info/RECORD +104 -0
  57. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
  58. guppylang_internals-0.24.0.dist-info/RECORD +0 -98
  59. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
@@ -4,9 +4,9 @@ from typing import ClassVar
4
4
 
5
5
  from typing_extensions import assert_never
6
6
 
7
- from guppylang_internals.ast_util import get_type, with_loc
8
- from guppylang_internals.checker.core import ComptimeVariable, Context
9
- from guppylang_internals.checker.errors.generic import ExpectedError, UnsupportedError
7
+ from guppylang_internals.ast_util import get_type, with_loc, with_type
8
+ from guppylang_internals.checker.core import Context
9
+ from guppylang_internals.checker.errors.generic import UnsupportedError
10
10
  from guppylang_internals.checker.errors.type_errors import (
11
11
  ArrayComprUnknownSizeError,
12
12
  TypeMismatchError,
@@ -33,8 +33,6 @@ from guppylang_internals.nodes import (
33
33
  GlobalCall,
34
34
  MakeIter,
35
35
  PanicExpr,
36
- PlaceNode,
37
- ResultExpr,
38
36
  )
39
37
  from guppylang_internals.tys.arg import ConstArg, TypeArg
40
38
  from guppylang_internals.tys.builtin import (
@@ -45,7 +43,6 @@ from guppylang_internals.tys.builtin import (
45
43
  get_iter_size,
46
44
  int_type,
47
45
  is_array_type,
48
- is_bool_type,
49
46
  is_sized_iter_type,
50
47
  nat_type,
51
48
  sized_iter_type,
@@ -58,7 +55,6 @@ from guppylang_internals.tys.ty import (
58
55
  FunctionType,
59
56
  InputFlags,
60
57
  NoneType,
61
- NumericType,
62
58
  Type,
63
59
  unify,
64
60
  )
@@ -302,80 +298,6 @@ class NewArrayChecker(CustomCallChecker):
302
298
  return with_loc(compr, array_compr), array_type(elt_ty, size)
303
299
 
304
300
 
305
- #: Maximum length of a tag in the `result` function.
306
- TAG_MAX_LEN = 200
307
-
308
-
309
- @dataclass(frozen=True)
310
- class TooLongError(Error):
311
- title: ClassVar[str] = "Tag too long"
312
- span_label: ClassVar[str] = "Result tag is too long"
313
-
314
- @dataclass(frozen=True)
315
- class Hint(Note):
316
- message: ClassVar[str] = f"Result tags are limited to {TAG_MAX_LEN} bytes"
317
-
318
-
319
- class ResultChecker(CustomCallChecker):
320
- """Call checker for the `result` function."""
321
-
322
- @dataclass(frozen=True)
323
- class InvalidError(Error):
324
- title: ClassVar[str] = "Invalid Result"
325
- span_label: ClassVar[str] = "Expression of type `{ty}` is not a valid result."
326
- ty: Type
327
-
328
- @dataclass(frozen=True)
329
- class Explanation(Note):
330
- message: ClassVar[str] = (
331
- "Only numeric values or arrays thereof are allowed as results"
332
- )
333
-
334
- def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
335
- check_num_args(2, len(args), self.node)
336
- [tag, value] = args
337
- tag, _ = ExprChecker(self.ctx).check(tag, string_type())
338
- match tag:
339
- case ast.Constant(value=str(v)):
340
- tag_value = v
341
- case PlaceNode(place=ComptimeVariable(static_value=str(v))):
342
- tag_value = v
343
- case _:
344
- raise GuppyTypeError(ExpectedError(tag, "a string literal"))
345
- if len(tag_value.encode("utf-8")) > TAG_MAX_LEN:
346
- err: Error = TooLongError(tag)
347
- err.add_sub_diagnostic(TooLongError.Hint(None))
348
- raise GuppyTypeError(err)
349
- value, ty = ExprSynthesizer(self.ctx).synthesize(value)
350
- # We only allow numeric values or vectors of numeric values
351
- err = ResultChecker.InvalidError(value, ty)
352
- err.add_sub_diagnostic(ResultChecker.InvalidError.Explanation(None))
353
- if self._is_numeric_or_bool_type(ty):
354
- base_ty = ty
355
- array_len: Const | None = None
356
- elif is_array_type(ty):
357
- [ty_arg, len_arg] = ty.args
358
- assert isinstance(ty_arg, TypeArg)
359
- assert isinstance(len_arg, ConstArg)
360
- if not self._is_numeric_or_bool_type(ty_arg.ty):
361
- raise GuppyError(err)
362
- base_ty = ty_arg.ty
363
- array_len = len_arg.const
364
- else:
365
- raise GuppyError(err)
366
- node = ResultExpr(value, base_ty, array_len, tag_value)
367
- return with_loc(self.node, node), NoneType()
368
-
369
- def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
370
- expr, res_ty = self.synthesize(args)
371
- expr, subst, _ = check_type_against(res_ty, ty, expr, self.ctx)
372
- return expr, subst
373
-
374
- @staticmethod
375
- def _is_numeric_or_bool_type(ty: Type) -> bool:
376
- return isinstance(ty, NumericType) or is_bool_type(ty)
377
-
378
-
379
301
  class PanicChecker(CustomCallChecker):
380
302
  """Call checker for the `panic` function."""
381
303
 
@@ -395,18 +317,16 @@ class PanicChecker(CustomCallChecker):
395
317
  err.add_sub_diagnostic(PanicChecker.NoMessageError.Suggestion(None))
396
318
  raise GuppyTypeError(err)
397
319
  case [msg, *rest]:
320
+ # Check type of message and synthesize types for additional values.
398
321
  msg, _ = ExprChecker(self.ctx).check(msg, string_type())
399
- match msg:
400
- case ast.Constant(value=str(v)):
401
- msg_value = v
402
- case PlaceNode(place=ComptimeVariable(static_value=str(v))):
403
- msg_value = v
404
- case _:
405
- raise GuppyTypeError(ExpectedError(msg, "a string literal"))
406
322
  vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest]
407
323
  # TODO variable signals once default arguments are available
324
+ # TODO this will also allow us to remove this manual AST node hack
325
+ signal_expr = with_type(
326
+ int_type(), with_loc(self.node, ast.Constant(value=1))
327
+ )
408
328
  node = PanicExpr(
409
- kind=ExitKind.Panic, msg=msg_value, values=vals, signal=1
329
+ kind=ExitKind.Panic, msg=msg, values=vals, signal=signal_expr
410
330
  )
411
331
  return with_loc(self.node, node), NoneType()
412
332
  case args:
@@ -454,31 +374,16 @@ class ExitChecker(CustomCallChecker):
454
374
  )
455
375
  raise GuppyTypeError(signal_err)
456
376
  case [msg, signal, *rest]:
377
+ # Check types for message and signal and synthesize types for additional
378
+ # values.
457
379
  msg, _ = ExprChecker(self.ctx).check(msg, string_type())
458
- match msg:
459
- case ast.Constant(value=str(v)):
460
- msg_value = v
461
- case PlaceNode(place=ComptimeVariable(static_value=str(v))):
462
- msg_value = v
463
- case _:
464
- raise GuppyTypeError(ExpectedError(msg, "a string literal"))
465
- # TODO allow variable signals after https://github.com/CQCL/hugr/issues/1863
466
380
  signal, _ = ExprChecker(self.ctx).check(signal, int_type())
467
- match signal:
468
- case ast.Constant(value=int(s)):
469
- signal_value = s
470
- case PlaceNode(place=ComptimeVariable(static_value=int(s))):
471
- signal_value = s
472
- case _:
473
- raise GuppyTypeError(
474
- ExpectedError(signal, "an integer literal")
475
- )
476
381
  vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest]
477
382
  node = PanicExpr(
478
383
  kind=ExitKind.ExitShot,
479
- msg=msg_value,
384
+ msg=msg,
480
385
  values=vals,
481
- signal=signal_value,
386
+ signal=signal,
482
387
  )
483
388
  return with_loc(self.node, node), NoneType()
484
389
  case args:
@@ -2,31 +2,23 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import TYPE_CHECKING, Final, TypeVar
5
+ from typing import TYPE_CHECKING, TypeVar
6
6
 
7
7
  import hugr
8
8
  from hugr import Wire, ops
9
9
  from hugr import tys as ht
10
- from hugr.std.collections.value_array import EXTENSION
10
+ from hugr.std.collections.borrow_array import EXTENSION
11
11
 
12
- from guppylang_internals.compiler.core import (
13
- GlobalConstId,
14
- )
15
12
  from guppylang_internals.definition.custom import CustomCallCompiler
16
13
  from guppylang_internals.definition.value import CallReturnWires
17
14
  from guppylang_internals.error import InternalGuppyError
18
15
  from guppylang_internals.std._internal.compiler.arithmetic import convert_itousize
19
16
  from guppylang_internals.std._internal.compiler.prelude import (
20
- build_expect_none,
21
- build_unwrap,
22
- build_unwrap_left,
23
17
  build_unwrap_right,
24
18
  )
25
19
  from guppylang_internals.tys.arg import ConstArg, TypeArg
26
- from guppylang_internals.tys.builtin import int_type
27
20
 
28
21
  if TYPE_CHECKING:
29
- from hugr.build import function as hf
30
22
  from hugr.build.dfg import DfBase
31
23
 
32
24
 
@@ -50,10 +42,10 @@ def _instantiate_array_op(
50
42
  def array_type(elem_ty: ht.Type, length: ht.TypeArg) -> ht.ExtType:
51
43
  """Returns the hugr type of a fixed length array.
52
44
 
53
- This is the copyable `value_array` type used by Guppy.
45
+ This is the linear `borrow_array` type used by Guppy.
54
46
  """
55
47
  elem_arg = ht.TypeTypeArg(elem_ty)
56
- return EXTENSION.types["value_array"].instantiate([length, elem_arg])
48
+ return EXTENSION.types["borrow_array"].instantiate([length, elem_arg])
57
49
 
58
50
 
59
51
  def standard_array_type(elem_ty: ht.Type, length: ht.TypeArg) -> ht.ExtType:
@@ -162,9 +154,9 @@ def array_repeat(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
162
154
  )
163
155
 
164
156
 
165
- def array_convert_to_std_array(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
166
- """Returns an array operation to convert the `value_array` type used by Guppy into
167
- the regular linear `array` in Hugr.
157
+ def array_to_std_array(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
158
+ """Returns an array operation to convert a value of the `borrow_array` type
159
+ used by Guppy into a standard `array`.
168
160
  """
169
161
  return EXTENSION.get_op("to_array").instantiate(
170
162
  [length, ht.TypeTypeArg(elem_ty)],
@@ -174,9 +166,9 @@ def array_convert_to_std_array(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtO
174
166
  )
175
167
 
176
168
 
177
- def array_convert_from_std_array(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
178
- """Returns an array operation to convert the `array` type used by Hugr into the
179
- `value_array` type used by Guppy.
169
+ def std_array_to_array(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
170
+ """Returns an array operation to convert the standard `array` type into the
171
+ `borrow_array` type used by Guppy.
180
172
  """
181
173
  return EXTENSION.get_op("from_array").instantiate(
182
174
  [length, ht.TypeTypeArg(elem_ty)],
@@ -186,6 +178,42 @@ def array_convert_from_std_array(elem_ty: ht.Type, length: ht.TypeArg) -> ops.Ex
186
178
  )
187
179
 
188
180
 
181
+ def barray_borrow(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
182
+ """Returns an array `borrow` operation."""
183
+ arr_ty = array_type(elem_ty, length)
184
+ return _instantiate_array_op(
185
+ "borrow", elem_ty, length, [arr_ty, ht.USize()], [arr_ty, elem_ty]
186
+ )
187
+
188
+
189
+ def barray_return(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
190
+ """Returns an array `return` operation."""
191
+ arr_ty = array_type(elem_ty, length)
192
+ return _instantiate_array_op(
193
+ "return", elem_ty, length, [arr_ty, ht.USize(), elem_ty], [arr_ty]
194
+ )
195
+
196
+
197
+ def barray_discard_all_borrowed(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
198
+ """Returns an array `discard_all_borrowed` operation."""
199
+ arr_ty = array_type(elem_ty, length)
200
+ return _instantiate_array_op("discard_all_borrowed", elem_ty, length, [arr_ty], [])
201
+
202
+
203
+ def barray_new_all_borrowed(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
204
+ """Returns an array `new_all_borrowed` operation."""
205
+ arr_ty = array_type(elem_ty, length)
206
+ return _instantiate_array_op("new_all_borrowed", elem_ty, length, [], [arr_ty])
207
+
208
+
209
+ def array_clone(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
210
+ """Returns an array `clone` operation for arrays none of whose elements are
211
+ borrowed."""
212
+ assert elem_ty.type_bound() == ht.TypeBound.Copyable
213
+ arr_ty = array_type(elem_ty, length)
214
+ return _instantiate_array_op("clone", elem_ty, length, [arr_ty], [arr_ty, arr_ty])
215
+
216
+
189
217
  # ------------------------------------------------------
190
218
  # --------- Custom compilers for non-native ops --------
191
219
  # ------------------------------------------------------
@@ -233,17 +261,12 @@ class NewArrayCompiler(ArrayCompiler):
233
261
 
234
262
  def build_classical_array(self, elems: list[Wire]) -> Wire:
235
263
  """Lowers a call to `array.__new__` for classical arrays."""
236
- # See https://github.com/CQCL/guppylang/issues/629
264
+ # See https://github.com/quantinuum/guppylang/issues/629
237
265
  return self.build_linear_array(elems)
238
266
 
239
267
  def build_linear_array(self, elems: list[Wire]) -> Wire:
240
268
  """Lowers a call to `array.__new__` for linear arrays."""
241
- elem_opts = [
242
- self.builder.add_op(ops.Some(self.elem_ty), elem) for elem in elems
243
- ]
244
- return self.builder.add_op(
245
- array_new(ht.Option(self.elem_ty), len(elems)), *elem_opts
246
- )
269
+ return self.builder.add_op(array_new(self.elem_ty, len(elems)), *elems)
247
270
 
248
271
  def compile(self, args: list[Wire]) -> list[Wire]:
249
272
  if self.elem_ty.type_bound() == ht.TypeBound.Linear:
@@ -252,131 +275,35 @@ class NewArrayCompiler(ArrayCompiler):
252
275
  return [self.build_classical_array(args)]
253
276
 
254
277
 
255
- ARRAY_GETITEM_CLASSICAL: Final[GlobalConstId] = GlobalConstId.fresh(
256
- "array.__getitem__.classical"
257
- )
258
- ARRAY_GETITEM_LINEAR: Final[GlobalConstId] = GlobalConstId.fresh(
259
- "array.__getitem__.linear"
260
- )
261
- ARRAY_SETITEM_CLASSICAL: Final[GlobalConstId] = GlobalConstId.fresh(
262
- "array.__setitem__.classical"
263
- )
264
- ARRAY_SETITEM_LINEAR: Final[GlobalConstId] = GlobalConstId.fresh(
265
- "array.__setitem__.linear"
266
- )
267
- ARRAY_ITER_ASSERT_ALL_USED_HELPER: Final[GlobalConstId] = GlobalConstId.fresh(
268
- "ArrayIter._assert_all_used.helper"
269
- )
270
-
271
-
272
278
  class ArrayGetitemCompiler(ArrayCompiler):
273
279
  """Compiler for the `array.__getitem__` function."""
274
280
 
275
- def _getitem_ty(self, bound: ht.TypeBound) -> ht.PolyFuncType:
276
- """Constructs a polymorphic function type for `__getitem__`"""
277
- # a(Option(T), N), int -> T, a(Option(T), N)
278
- # Array element type parameter
279
- elem_ty_param = ht.TypeTypeParam(bound)
280
- # Array length parameter
281
- length_param = ht.BoundedNatParam()
282
- return ht.PolyFuncType(
283
- params=[elem_ty_param, length_param],
284
- body=ht.FunctionType(
285
- input=[
286
- array_type(
287
- ht.Option(ht.Variable(0, bound)),
288
- ht.VariableArg(1, length_param),
289
- ),
290
- int_type().to_hugr(self.ctx),
291
- ],
292
- output=[
293
- ht.Variable(0, bound),
294
- array_type(
295
- ht.Option(ht.Variable(0, bound)),
296
- ht.VariableArg(1, length_param),
297
- ),
298
- ],
299
- ),
300
- )
281
+ def _build_classical_getitem(self, array: Wire, idx: Wire) -> CallReturnWires:
282
+ """Constructs `__getitem__` for classical arrays."""
283
+ idx = self.builder.add_op(convert_itousize(), idx)
301
284
 
302
- def _build_classical_getitem(self, func: hf.Function) -> None:
303
- """Constructs a generic function for `__getitem__` for classical arrays."""
304
- elem_ty = ht.Variable(0, ht.TypeBound.Copyable)
305
- length = ht.VariableArg(1, ht.BoundedNatParam())
306
-
307
- # See https://github.com/CQCL/guppylang/issues/629
308
- elem_opt_ty = ht.Option(elem_ty)
309
- none = func.add_op(ops.Tag(0, elem_opt_ty))
310
- idx = func.add_op(convert_itousize(), func.inputs()[1])
311
- # As copyable elements can be used multiple times, we need to swap the element
312
- # back after initially swapping it out for `None` to get the value.
313
- initial_result = func.add_op(
314
- array_set(elem_opt_ty, length),
315
- func.inputs()[0],
316
- idx,
317
- none,
318
- )
319
- elem_opt, arr = build_unwrap_right(
320
- func, initial_result, "Array index out of bounds"
321
- )
322
- swapped_back = func.add_op(
323
- array_set(elem_opt_ty, length),
324
- arr,
325
- idx,
326
- elem_opt,
327
- )
328
- _, arr = build_unwrap_right(func, swapped_back, "Array index out of bounds")
329
- elem = build_unwrap(func, elem_opt, "array.__getitem__: Internal error")
330
-
331
- func.set_outputs(elem, arr)
332
-
333
- def _build_linear_getitem(self, func: hf.Function) -> None:
334
- """Constructs function to call `array.__getitem__` for linear arrays."""
335
- elem_ty = ht.Variable(0, ht.TypeBound.Linear)
336
- length = ht.VariableArg(1, ht.BoundedNatParam())
337
-
338
- elem_opt_ty = ht.Option(elem_ty)
339
- none = func.add_op(ops.Tag(0, elem_opt_ty))
340
- idx = func.add_op(convert_itousize(), func.inputs()[1])
341
- result = func.add_op(
342
- array_set(elem_opt_ty, length),
343
- func.inputs()[0],
285
+ opt_elem, arr = self.builder.add_op(
286
+ array_get(self.elem_ty, self.length),
287
+ array,
344
288
  idx,
345
- none,
346
289
  )
347
- elem_opt, array = build_unwrap_right(func, result, "Array index out of bounds")
348
- elem = build_unwrap(
349
- func, elem_opt, "Linear array element has already been used"
290
+ elem = build_unwrap_right(self.builder, opt_elem, "Array index out of bounds")
291
+ return CallReturnWires(
292
+ regular_returns=[elem],
293
+ inout_returns=[arr],
350
294
  )
351
295
 
352
- func.set_outputs(elem, array)
353
-
354
- def _build_call_getitem(
355
- self,
356
- func: hf.Function,
357
- array: Wire,
358
- idx: Wire,
359
- ) -> CallReturnWires:
360
- """Inserts a call to `array.__getitem__`."""
361
- concrete_func_ty = ht.FunctionType(
362
- input=[
363
- array_type(ht.Option(self.elem_ty), self.length),
364
- int_type().to_hugr(self.ctx),
365
- ],
366
- output=[self.elem_ty, array_type(ht.Option(self.elem_ty), self.length)],
367
- )
368
- type_args = [ht.TypeTypeArg(self.elem_ty), self.length]
369
- func_call = self.builder.call(
370
- func.parent_node,
296
+ def _build_linear_getitem(self, array: Wire, idx: Wire) -> CallReturnWires:
297
+ """Constructs `array.__getitem__` for linear arrays."""
298
+ idx = self.builder.add_op(convert_itousize(), idx)
299
+ arr, elem = self.builder.add_op(
300
+ barray_borrow(self.elem_ty, self.length),
371
301
  array,
372
302
  idx,
373
- instantiation=concrete_func_ty,
374
- type_args=type_args,
375
303
  )
376
- outputs = list(func_call.outputs())
377
304
  return CallReturnWires(
378
- regular_returns=[outputs[0]],
379
- inout_returns=[outputs[1]],
305
+ regular_returns=[elem],
306
+ inout_returns=[arr],
380
307
  )
381
308
 
382
309
  def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
@@ -384,24 +311,9 @@ class ArrayGetitemCompiler(ArrayCompiler):
384
311
  [elem_ty_arg, _] = self.type_args
385
312
  assert isinstance(elem_ty_arg, TypeArg)
386
313
  if not elem_ty_arg.ty.copyable:
387
- func_ty = self._getitem_ty(ht.TypeBound.Linear)
388
- func, already_exists = self.ctx.declare_global_func(
389
- ARRAY_GETITEM_LINEAR, func_ty
390
- )
391
- if not already_exists:
392
- self._build_linear_getitem(func)
314
+ return self._build_linear_getitem(array, idx)
393
315
  else:
394
- func_ty = self._getitem_ty(ht.TypeBound.Copyable)
395
- func, already_exists = self.ctx.declare_global_func(
396
- ARRAY_GETITEM_CLASSICAL, func_ty
397
- )
398
- if not already_exists:
399
- self._build_classical_getitem(func)
400
- return self._build_call_getitem(
401
- func=func,
402
- array=array,
403
- idx=idx,
404
- )
316
+ return self._build_classical_getitem(array, idx)
405
317
 
406
318
  def compile(self, args: list[Wire]) -> list[Wire]:
407
319
  raise InternalGuppyError("Call compile_with_inouts instead")
@@ -410,160 +322,60 @@ class ArrayGetitemCompiler(ArrayCompiler):
410
322
  class ArraySetitemCompiler(ArrayCompiler):
411
323
  """Compiler for the `array.__setitem__` function."""
412
324
 
413
- def _setitem_ty(self, bound: ht.TypeBound) -> ht.PolyFuncType:
414
- """Constructs a polymorphic function type for `__setitem__`"""
415
- # a(Option(T), N), int, T -> a(Option(T), N)
416
- elem_ty_param = ht.TypeTypeParam(bound)
417
- length_param = ht.BoundedNatParam()
418
- return ht.PolyFuncType(
419
- params=[elem_ty_param, length_param],
420
- body=ht.FunctionType(
421
- input=[
422
- array_type(
423
- ht.Option(ht.Variable(0, bound)),
424
- ht.VariableArg(1, length_param),
425
- ),
426
- int_type().to_hugr(self.ctx),
427
- ht.Variable(0, bound),
428
- ],
429
- output=[
430
- array_type(
431
- ht.Option(ht.Variable(0, bound)),
432
- ht.VariableArg(1, length_param),
433
- ),
434
- ],
435
- ),
436
- )
437
-
438
- def _build_classical_setitem(self, func: hf.Function) -> None:
439
- """Constructs a generic function for `__setitem__` for classical arrays."""
440
- elem_ty = ht.Variable(0, ht.TypeBound.Copyable)
441
- length = ht.VariableArg(1, ht.BoundedNatParam())
442
-
443
- elem_opt_ty = ht.Option(elem_ty)
444
- idx = func.add_op(convert_itousize(), func.inputs()[1])
445
- elem_opt = func.add_op(ops.Some(elem_ty), func.inputs()[2])
446
- result = func.add_op(
447
- array_set(elem_opt_ty, length),
448
- func.inputs()[0],
449
- idx,
450
- elem_opt,
451
- )
452
- _, array = build_unwrap_right(func, result, "Array index out of bounds")
453
-
454
- func.set_outputs(array)
455
-
456
- def _build_linear_setitem(self, func: hf.Function) -> None:
457
- """Constructs function to call `array.__setitem__` for linear arrays."""
458
- elem_ty = ht.Variable(0, ht.TypeBound.Linear)
459
- length = ht.VariableArg(1, ht.BoundedNatParam())
460
-
461
- elem_opt_ty = ht.Option(elem_ty)
462
- elem = func.add_op(ops.Some(elem_ty), func.inputs()[2])
463
- idx = func.add_op(convert_itousize(), func.inputs()[1])
464
- result = func.add_op(
465
- array_set(elem_opt_ty, length),
466
- func.inputs()[0],
325
+ def _build_classical_setitem(
326
+ self, array: Wire, idx: Wire, elem: Wire
327
+ ) -> CallReturnWires:
328
+ """Constructs `__setitem__` for classical arrays."""
329
+ idx = self.builder.add_op(convert_itousize(), idx)
330
+ result = self.builder.add_op(
331
+ array_set(self.elem_ty, self.length),
332
+ array,
467
333
  idx,
468
334
  elem,
469
335
  )
470
- old_elem_opt, array = build_unwrap_right(
471
- func, result, "Array index out of bounds"
472
- )
473
- build_unwrap_left(func, old_elem_opt, "Linear array element has not been used")
336
+ _, arr = build_unwrap_right(self.builder, result, "Array index out of bounds")
474
337
 
475
- func.set_outputs(array)
338
+ return CallReturnWires(
339
+ regular_returns=[],
340
+ inout_returns=[arr],
341
+ )
476
342
 
477
- def _build_call_setitem(
478
- self,
479
- func: hf.Function,
480
- array: Wire,
481
- idx: Wire,
482
- elem: Wire,
343
+ def _build_linear_setitem(
344
+ self, array: Wire, idx: Wire, elem: Wire
483
345
  ) -> CallReturnWires:
484
- """Inserts a call to `array.__setitem__`."""
485
- concrete_func_ty = ht.FunctionType(
486
- input=[
487
- array_type(ht.Option(self.elem_ty), self.length),
488
- int_type().to_hugr(self.ctx),
489
- self.elem_ty,
490
- ],
491
- output=[array_type(ht.Option(self.elem_ty), self.length)],
492
- )
493
- type_args = [ht.TypeTypeArg(self.elem_ty), self.length]
494
- func_call = self.builder.call(
495
- func.parent_node,
346
+ """Constructs `array.__setitem__` for linear arrays."""
347
+ idx = self.builder.add_op(convert_itousize(), idx)
348
+ arr = self.builder.add_op(
349
+ barray_return(self.elem_ty, self.length),
496
350
  array,
497
351
  idx,
498
352
  elem,
499
- instantiation=concrete_func_ty,
500
- type_args=type_args,
501
353
  )
354
+
502
355
  return CallReturnWires(
503
356
  regular_returns=[],
504
- inout_returns=list(func_call.outputs()),
357
+ inout_returns=[arr],
505
358
  )
506
359
 
507
360
  def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
508
361
  [array, idx, elem] = args
509
362
  if self.elem_ty.type_bound() == ht.TypeBound.Linear:
510
- func_ty = self._setitem_ty(ht.TypeBound.Linear)
511
- func, already_exists = self.ctx.declare_global_func(
512
- ARRAY_SETITEM_LINEAR, func_ty
513
- )
514
- if not already_exists:
515
- self._build_linear_setitem(func)
363
+ return self._build_linear_setitem(array, idx, elem)
516
364
  else:
517
- func_ty = self._setitem_ty(ht.TypeBound.Copyable)
518
- func, already_exists = self.ctx.declare_global_func(
519
- ARRAY_SETITEM_CLASSICAL, func_ty
520
- )
521
- if not already_exists:
522
- self._build_classical_setitem(func)
523
- return self._build_call_setitem(func=func, array=array, idx=idx, elem=elem)
365
+ return self._build_classical_setitem(array, idx, elem)
524
366
 
525
367
  def compile(self, args: list[Wire]) -> list[Wire]:
526
368
  raise InternalGuppyError("Call compile_with_inouts instead")
527
369
 
528
370
 
529
- class ArrayIterAsertAllUsedCompiler(ArrayCompiler):
530
- """Compiler for the `ArrayIter._assert_all_used` method."""
371
+ class ArrayDiscardAllUsedCompiler(ArrayCompiler):
372
+ """Compiler for the `_array_discard_all_used` method."""
531
373
 
532
374
  def compile(self, args: list[Wire]) -> list[Wire]:
533
- # For linear array iterators, map the array of optional elements to an
534
- # `array[None, n]` that we can discard.
535
375
  if self.elem_ty.type_bound() == ht.TypeBound.Linear:
536
- elem_opt_ty = ht.Option(self.elem_ty)
537
- unit_ty = ht.UnitSum(1)
538
- # Instantiate `unwrap_none` function
539
- func = self.builder.load_function(
540
- self.define_unwrap_none_helper(),
541
- type_args=[ht.TypeTypeArg(self.elem_ty)],
542
- instantiation=ht.FunctionType([elem_opt_ty], [unit_ty]),
543
- )
544
- # Map it over the array so that the resulting array is no longer linear and
545
- # can be discarded
546
- [array_iter] = args
547
- array, _ = self.builder.add_op(ops.UnpackTuple(), array_iter)
376
+ [arr] = args
548
377
  self.builder.add_op(
549
- array_map(elem_opt_ty, self.length, unit_ty), array, func
378
+ barray_discard_all_borrowed(self.elem_ty, self.length),
379
+ arr,
550
380
  )
551
381
  return []
552
-
553
- def define_unwrap_none_helper(self) -> hf.Function:
554
- """Define an `unwrap_none` function that checks that the passed element is
555
- indeed `None`."""
556
- opt_ty = ht.Option(ht.Variable(0, ht.TypeBound.Linear))
557
- unit_ty = ht.UnitSum(1)
558
- func_ty = ht.PolyFuncType(
559
- params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
560
- body=ht.FunctionType([opt_ty], [unit_ty]),
561
- )
562
- func, already_defined = self.ctx.declare_global_func(
563
- ARRAY_ITER_ASSERT_ALL_USED_HELPER, func_ty
564
- )
565
- if not already_defined:
566
- err_msg = "ArrayIter._assert_all_used: array element has not been used"
567
- build_expect_none(func, func.inputs()[0], err_msg)
568
- func.set_outputs(func.add_op(ops.MakeTuple()))
569
- return func