guppylang-internals 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +21 -0
  3. guppylang_internals/cfg/bb.py +20 -0
  4. guppylang_internals/cfg/builder.py +101 -3
  5. guppylang_internals/checker/core.py +12 -0
  6. guppylang_internals/checker/errors/generic.py +32 -1
  7. guppylang_internals/checker/errors/type_errors.py +14 -0
  8. guppylang_internals/checker/expr_checker.py +55 -29
  9. guppylang_internals/checker/func_checker.py +171 -22
  10. guppylang_internals/checker/linearity_checker.py +65 -0
  11. guppylang_internals/checker/modifier_checker.py +116 -0
  12. guppylang_internals/checker/stmt_checker.py +49 -2
  13. guppylang_internals/compiler/core.py +90 -53
  14. guppylang_internals/compiler/expr_compiler.py +49 -114
  15. guppylang_internals/compiler/modifier_compiler.py +174 -0
  16. guppylang_internals/compiler/stmt_compiler.py +15 -8
  17. guppylang_internals/decorator.py +124 -58
  18. guppylang_internals/definition/const.py +2 -2
  19. guppylang_internals/definition/custom.py +36 -2
  20. guppylang_internals/definition/declaration.py +4 -5
  21. guppylang_internals/definition/extern.py +2 -2
  22. guppylang_internals/definition/function.py +1 -1
  23. guppylang_internals/definition/parameter.py +10 -5
  24. guppylang_internals/definition/pytket_circuits.py +14 -42
  25. guppylang_internals/definition/struct.py +17 -14
  26. guppylang_internals/definition/traced.py +1 -1
  27. guppylang_internals/definition/ty.py +9 -3
  28. guppylang_internals/definition/wasm.py +2 -2
  29. guppylang_internals/engine.py +13 -2
  30. guppylang_internals/experimental.py +5 -0
  31. guppylang_internals/nodes.py +124 -23
  32. guppylang_internals/std/_internal/compiler/array.py +94 -282
  33. guppylang_internals/std/_internal/compiler/tket_exts.py +12 -8
  34. guppylang_internals/std/_internal/compiler/wasm.py +37 -26
  35. guppylang_internals/tracing/function.py +13 -2
  36. guppylang_internals/tracing/unpacking.py +33 -28
  37. guppylang_internals/tys/arg.py +18 -3
  38. guppylang_internals/tys/builtin.py +32 -16
  39. guppylang_internals/tys/const.py +33 -4
  40. guppylang_internals/tys/errors.py +6 -0
  41. guppylang_internals/tys/param.py +31 -16
  42. guppylang_internals/tys/parsing.py +118 -145
  43. guppylang_internals/tys/qubit.py +27 -0
  44. guppylang_internals/tys/subst.py +8 -26
  45. guppylang_internals/tys/ty.py +31 -21
  46. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/METADATA +4 -4
  47. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/RECORD +49 -46
  48. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/WHEEL +0 -0
  49. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/licenses/LICENCE +0 -0
@@ -1,6 +1,5 @@
1
1
  import itertools
2
2
  from abc import ABC
3
- from collections import defaultdict
4
3
  from collections.abc import Callable, Iterator, Sequence
5
4
  from contextlib import contextmanager
6
5
  from dataclasses import dataclass, field
@@ -16,6 +15,7 @@ from hugr.hugr.base import OpVarCov
16
15
  from hugr.hugr.node_port import ToNode
17
16
  from hugr.std import PRELUDE
18
17
  from hugr.std.collections.array import EXTENSION as ARRAY_EXTENSION
18
+ from hugr.std.collections.borrow_array import EXTENSION as BORROW_ARRAY_EXTENSION
19
19
  from typing_extensions import assert_never
20
20
 
21
21
  from guppylang_internals.checker.core import (
@@ -43,8 +43,8 @@ from guppylang_internals.tys.arg import ConstArg, TypeArg
43
43
  from guppylang_internals.tys.builtin import nat_type
44
44
  from guppylang_internals.tys.common import ToHugrContext
45
45
  from guppylang_internals.tys.const import BoundConstVar, ConstValue
46
- from guppylang_internals.tys.param import ConstParam, Parameter, TypeParam
47
- from guppylang_internals.tys.subst import Inst
46
+ from guppylang_internals.tys.param import ConstParam, Parameter
47
+ from guppylang_internals.tys.subst import Inst, Instantiator
48
48
  from guppylang_internals.tys.ty import (
49
49
  BoundTypeVar,
50
50
  NumericType,
@@ -221,10 +221,10 @@ class CompilerContext(ToHugrContext):
221
221
  match ENGINE.get_checked(defn.id):
222
222
  case MonomorphizableDef(params=params) as defn:
223
223
  # Entry point is not allowed to require monomorphization
224
- for param in params:
225
- if requires_monomorphization(param):
226
- err = EntryMonomorphizeError(defn.defined_at, defn.name, param)
227
- raise GuppyError(err)
224
+ if mono_params := require_monomorphization(params):
225
+ mono_param = mono_params.pop()
226
+ err = EntryMonomorphizeError(defn.defined_at, defn.name, mono_param)
227
+ raise GuppyError(err)
228
228
  # Thus, the partial monomorphization for the entry point is always empty
229
229
  entry_mono_args = tuple(None for _ in params)
230
230
  entry_compiled = defn.monomorphize(self.module, entry_mono_args, self)
@@ -247,7 +247,7 @@ class CompilerContext(ToHugrContext):
247
247
 
248
248
  # Insert explicit drops for affine types
249
249
  # TODO: This is a quick workaround until we can properly insert these drops
250
- # during linearity checking. See https://github.com/CQCL/guppylang/issues/1082
250
+ # during linearity checking. See https://github.com/CQCL/guppylang/issues/1082
251
251
  insert_drops(self.module.hugr)
252
252
 
253
253
  return entry_compiled
@@ -467,19 +467,27 @@ def is_return_var(x: str) -> bool:
467
467
  return x.startswith("%ret")
468
468
 
469
469
 
470
- def requires_monomorphization(param: Parameter) -> bool:
471
- """Checks if a type parameter must be monomorphized before compiling to Hugr.
470
+ def require_monomorphization(params: Sequence[Parameter]) -> set[Parameter]:
471
+ """Returns the subset of type parameters that must be monomorphized before compiling
472
+ to Hugr.
472
473
 
473
474
  This is required for some Guppy language features that cannot be encoded in Hugr
474
- yet. Currently, this only applies to non-nat const parameters.
475
+ yet. Currently, this applies to:
476
+ * non-nat const parameters
477
+ * parameters that occur in the type of const parameters
475
478
  """
476
- match param:
477
- case TypeParam():
478
- return False
479
- case ConstParam(ty=ty):
480
- return ty != NumericType(NumericType.Kind.Nat)
481
- case x:
482
- return assert_never(x)
479
+ mono_params: set[Parameter] = set()
480
+ for param in params:
481
+ match param:
482
+ case ConstParam(ty=ty) if ty != NumericType(NumericType.Kind.Nat):
483
+ mono_params.add(param)
484
+ # If the constant type refers to any bound variables, then those will
485
+ # need to be monomorphized as well
486
+ for var in ty.bound_vars:
487
+ mono_params.add(params[var.idx])
488
+ case _:
489
+ pass
490
+ return mono_params
483
491
 
484
492
 
485
493
  def partially_monomorphize_args(
@@ -493,27 +501,36 @@ def partially_monomorphize_args(
493
501
 
494
502
  Also takes care of normalising bound variables w.r.t. the current monomorphization.
495
503
  """
496
- mono_args: list[Argument | None] = []
497
- rem_args = []
504
+ # Normalise args w.r.t. the current outer monomorphisation
505
+ if ctx.current_mono_args is not None:
506
+ instantiator = Instantiator(ctx.current_mono_args, allow_partial=True)
507
+ args = [arg.transform(instantiator) for arg in args]
508
+
509
+ # Filter args depending on whether they need monomorphization or not. For this, we
510
+ # can't purely rely on the `require_monomorphization` function above since the
511
+ # instantiation also needs to be taken into account. For example, consider
512
+ # a function `def foo[T, x: T](...)`. Normally, both parameters would need to be
513
+ # monomorphized. However, if we instantiate `T := nat`, suddenly `x` no longer needs
514
+ # be monomorphized since we can use bounded nats in Hugr. Thus, we have to replicate
515
+ # the behaviour of `require_monomorphization` while taking `args` into account.
516
+ mono_args: list[Argument | None] = [None] * len(args)
498
517
  for param, arg in zip(params, args, strict=True):
499
- if requires_monomorphization(param):
500
- # The constant could still refer to a bound variable, so we need to
501
- # instantiate it w.r.t. to the current monomorphization
502
- match arg:
503
- case ConstArg(const=BoundConstVar(idx=idx)):
504
- assert ctx.current_mono_args is not None
505
- inst = ctx.current_mono_args[idx]
506
- assert inst is not None
507
- mono_args.append(inst)
508
- case TypeArg():
509
- # TODO: Once we also have type args that require monomorphization,
510
- # we'll need to downshift de Bruijn indices here as well
511
- raise NotImplementedError
512
- case arg:
513
- mono_args.append(arg)
514
- else:
515
- mono_args.append(None)
516
- rem_args.append(arg)
518
+ match param, param.instantiate_bounds(args):
519
+ case ConstParam(ty=original_ty), ConstParam(ty=inst_ty):
520
+ # If the original const type is not `nat`, then all variables occurring
521
+ # in the type need to be monomorphised
522
+ if original_ty != nat_type():
523
+ for var in original_ty.bound_vars:
524
+ mono_args[var.idx] = args[var.idx]
525
+ # If the const type is still not `nat` after normalisation, then the
526
+ # const arg itself also needs to be monomorphized
527
+ if inst_ty != nat_type():
528
+ mono_args[param.idx] = arg
529
+
530
+ # Mono-arguments should not refer to any bound variables
531
+ assert all(mono_arg is None or not mono_arg.bound_vars for mono_arg in mono_args)
532
+
533
+ rem_args = [arg for i, arg in enumerate(args) if mono_args[i] is None]
517
534
  return tuple(mono_args), rem_args
518
535
 
519
536
 
@@ -565,6 +582,9 @@ def may_have_side_effect(op: ops.Op) -> bool:
565
582
  # precise answer
566
583
  return True
567
584
  case _:
585
+ # There is no need to handle TailLoop (in case of non-termination) since
586
+ # TailLoops are only generated for array comprehensions which must have
587
+ # statically-guaranteed (finite) size. TODO revisit this for lists.
568
588
  return False
569
589
 
570
590
 
@@ -578,9 +598,7 @@ def track_hugr_side_effects() -> Iterator[None]:
578
598
  # Remember original `Hugr.add_node` method that is monkey-patched below.
579
599
  hugr_add_node = Hugr.add_node
580
600
  # Last node with potential side effects for each dataflow parent
581
- prev_node_with_side_effect: defaultdict[Node, Node | None] = defaultdict(
582
- lambda: None
583
- )
601
+ prev_node_with_side_effect: dict[Node, tuple[Node, Hugr[Any]]] = {}
584
602
 
585
603
  def hugr_add_node_with_order(
586
604
  self: Hugr[OpVarCov],
@@ -601,22 +619,40 @@ def track_hugr_side_effects() -> Iterator[None]:
601
619
  """Performs the actual order-edge insertion, assuming that `node` has a side-
602
620
  effect."""
603
621
  parent = hugr[node].parent
604
- if parent is not None:
605
- if prev := prev_node_with_side_effect[parent]:
606
- hugr.add_order_link(prev, node)
607
- else:
608
- # If this is the first side-effectful op in this DFG, make a recursive
609
- # call with the parent since the parent is also considered side-
610
- # effectful now. We shouldn't walk up through function definitions
611
- # or basic blocks though
612
- if not isinstance(hugr[parent].op, ops.FuncDefn | ops.DataflowBlock):
613
- handle_side_effect(parent, hugr)
614
- prev_node_with_side_effect[parent] = node
622
+ assert parent is not None
623
+
624
+ if prev := prev_node_with_side_effect.get(parent):
625
+ prev_node = prev[0]
626
+ else:
627
+ # This is the first side-effectful op in this DFG. Recurse on the parent
628
+ # since the parent is also considered side-effectful now. We shouldn't walk
629
+ # up through function definitions (only the Module is above)
630
+ if not isinstance(hugr[parent].op, ops.FuncDefn):
631
+ handle_side_effect(parent, hugr)
632
+ # For DataflowBlocks and Cases, recurse to mark their containing CFG
633
+ # or Conditional as side-effectful as well, but there is nothing to do
634
+ # locally: we cannot add order edges, but Conditional/CFG semantics
635
+ # ensure execution if appropriate.
636
+ if isinstance(hugr[parent].op, ops.Conditional | ops.CFG):
637
+ return
638
+ prev_node = hugr.children(parent)[0]
639
+ assert isinstance(hugr[prev_node].op, ops.Input)
640
+
641
+ # Add edge, but avoid self-loops for containers when recursing up the hierarchy.
642
+ if prev_node != node:
643
+ hugr.add_order_link(prev_node, node)
644
+ prev_node_with_side_effect[parent] = (node, hugr)
615
645
 
616
646
  # Monkey-patch the `add_node` method
617
647
  Hugr.add_node = hugr_add_node_with_order # type: ignore[method-assign]
618
648
  try:
619
649
  yield
650
+ for parent, (last, hugr) in prev_node_with_side_effect.items():
651
+ # Connect the last side-effecting node to Output
652
+ outp = hugr.children(parent)[1]
653
+ assert isinstance(hugr[outp].op, ops.Output)
654
+ assert last != outp
655
+ hugr.add_order_link(last, outp)
620
656
  finally:
621
657
  Hugr.add_node = hugr_add_node # type: ignore[method-assign]
622
658
 
@@ -634,6 +670,7 @@ def qualified_name(type_def: he.TypeDef) -> str:
634
670
  #: insertion of an explicit drop operation.
635
671
  AFFINE_EXTENSION_TYS: list[str] = [
636
672
  qualified_name(ARRAY_EXTENSION.get_type("array")),
673
+ qualified_name(BORROW_ARRAY_EXTENSION.get_type("borrow_array")),
637
674
  ]
638
675
 
639
676
 
@@ -681,7 +718,7 @@ def insert_drops(hugr: Hugr[OpVarCov]) -> None:
681
718
  # Iterating over `node.outputs()` doesn't work reliably since it sometimes
682
719
  # raises an `IncompleteOp` exception. Instead, we query the number of out ports
683
720
  # and look them up by index. However, this method is *also* broken when
684
- # isnpecting `FuncDefn` nodes due to https://github.com/CQCL/hugr/issues/2438.
721
+ # inspecting `FuncDefn` nodes due to https://github.com/CQCL/hugr/issues/2438.
685
722
  if isinstance(data.op, ops.FuncDefn):
686
723
  continue
687
724
  for i in range(hugr.num_out_ports(node)):
@@ -63,14 +63,17 @@ from guppylang_internals.nodes import (
63
63
  from guppylang_internals.std._internal.compiler.arithmetic import (
64
64
  UnsignedIntVal,
65
65
  convert_ifromusize,
66
+ convert_itousize,
66
67
  )
67
68
  from guppylang_internals.std._internal.compiler.array import (
68
- array_convert_from_std_array,
69
- array_convert_to_std_array,
69
+ array_clone,
70
70
  array_map,
71
71
  array_new,
72
- array_repeat,
72
+ array_to_std_array,
73
+ barray_new_all_borrowed,
74
+ barray_return,
73
75
  standard_array_type,
76
+ std_array_to_array,
74
77
  unpack_array,
75
78
  )
76
79
  from guppylang_internals.std._internal.compiler.list import (
@@ -79,7 +82,6 @@ from guppylang_internals.std._internal.compiler.list import (
79
82
  from guppylang_internals.std._internal.compiler.prelude import (
80
83
  build_error,
81
84
  build_panic,
82
- build_unwrap,
83
85
  panic,
84
86
  )
85
87
  from guppylang_internals.std._internal.compiler.tket_bool import (
@@ -91,6 +93,7 @@ from guppylang_internals.std._internal.compiler.tket_bool import (
91
93
  )
92
94
  from guppylang_internals.tys.arg import ConstArg
93
95
  from guppylang_internals.tys.builtin import (
96
+ array_type,
94
97
  bool_type,
95
98
  get_element_type,
96
99
  int_type,
@@ -279,11 +282,13 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
279
282
  return defn.load(self.dfg, self.ctx, node)
280
283
 
281
284
  def visit_GenericParamValue(self, node: GenericParamValue) -> Wire:
282
- match node.param.ty:
285
+ assert self.ctx.current_mono_args is not None
286
+ param = node.param.instantiate_bounds(self.ctx.current_mono_args)
287
+ match param.ty:
283
288
  case NumericType(NumericType.Kind.Nat):
284
289
  # Generic nat parameters are encoded using Hugr bounded nat parameters,
285
290
  # so they are not monomorphized when compiling to Hugr
286
- arg = node.param.to_bound().to_hugr(self.ctx)
291
+ arg = param.to_bound().to_hugr(self.ctx)
287
292
  load_nat = hugr.std.PRELUDE.get_op("load_nat").instantiate(
288
293
  [arg], ht.FunctionType([], [ht.USize()])
289
294
  )
@@ -291,7 +296,6 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
291
296
  return self.builder.add_op(convert_ifromusize(), usize)
292
297
  case ty:
293
298
  # Look up monomorphization
294
- assert self.ctx.current_mono_args is not None
295
299
  match self.ctx.current_mono_args[node.param.idx]:
296
300
  case ConstArg(const=ConstValue(value=v)):
297
301
  val = python_value_to_hugr(v, ty, self.ctx)
@@ -555,15 +559,22 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
555
559
  op_name = f"result_array_{base_name}"
556
560
  size_arg = node.array_len.to_arg().to_hugr(self.ctx)
557
561
  extra_args = [size_arg, *extra_args]
558
- # Remove the option wrapping in the array
559
- unwrap = array_unwrap_elem(self.ctx)
560
- unwrap = self.builder.load_function(
561
- unwrap,
562
- instantiation=ht.FunctionType([ht.Option(base_ty)], [base_ty]),
563
- type_args=[ht.TypeTypeArg(base_ty)],
562
+ # As `borrow_array`s used by Guppy are linear, we need to clone it (knowing
563
+ # that all elements in it are copyable) to avoid linearity violations when
564
+ # both passing it to the result operation and returning it (as an inout
565
+ # argument).
566
+ value_wire, inout_wire = self.builder.add_op(
567
+ array_clone(base_ty, size_arg), value_wire
564
568
  )
565
- map_op = array_map(ht.Option(base_ty), size_arg, base_ty)
566
- value_wire = self.builder.add_op(map_op, value_wire, unwrap)
569
+ func_ty = FunctionType(
570
+ [
571
+ FuncInput(
572
+ array_type(node.base_ty, node.array_len), InputFlags.Inout
573
+ ),
574
+ ],
575
+ NoneType(),
576
+ )
577
+ self._update_inout_ports(node.args, iter([inout_wire]), func_ty)
567
578
  if is_bool_type(node.base_ty):
568
579
  # We need to coerce a read on all the array elements if they are bools.
569
580
  array_read = array_read_bool(self.ctx)
@@ -571,9 +582,9 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
571
582
  map_op = array_map(OpaqueBool, size_arg, ht.Bool)
572
583
  value_wire = self.builder.add_op(map_op, value_wire, array_read)
573
584
  base_ty = ht.Bool
574
- # Turn `value_array` into regular linear `array`
585
+ # Turn `borrow_array` into regular `array`
575
586
  value_wire = self.builder.add_op(
576
- array_convert_to_std_array(base_ty, size_arg), value_wire
587
+ array_to_std_array(base_ty, size_arg), value_wire
577
588
  )
578
589
  hugr_ty: ht.Type = hugr.std.collections.array.Array(base_ty, size_arg)
579
590
  else:
@@ -635,20 +646,20 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
635
646
  qubit_arr_in = self.builder.add_op(
636
647
  array_new(ht.Qubit, len(node.args) - 1), *qubits_in
637
648
  )
638
- # Turn into standard array from value array.
649
+ # Turn into standard array from borrow array.
639
650
  qubit_arr_in = self.builder.add_op(
640
- array_convert_to_std_array(ht.Qubit, num_qubits_arg), qubit_arr_in
651
+ array_to_std_array(ht.Qubit, num_qubits_arg), qubit_arr_in
641
652
  )
642
653
 
643
654
  qubit_arr_out = self.builder.add_op(op, qubit_arr_in)
644
655
 
645
656
  qubit_arr_out = self.builder.add_op(
646
- array_convert_from_std_array(ht.Qubit, num_qubits_arg), qubit_arr_out
657
+ std_array_to_array(ht.Qubit, num_qubits_arg), qubit_arr_out
647
658
  )
648
659
  qubits_out = unpack_array(self.builder, qubit_arr_out)
649
660
  else:
650
- # If the input is an array of qubits, we need to unwrap the elements first,
651
- # and then convert to a value array and back.
661
+ # If the input is an array of qubits, we need to convert to a standard
662
+ # array.
652
663
  qubits_in = [self.visit(node.args[1])]
653
664
  qubits_out = [
654
665
  apply_array_op_with_conversions(
@@ -680,17 +691,10 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
680
691
  assert isinstance(array_ty, OpaqueType)
681
692
  array_var = Variable(next(tmp_vars), array_ty, node)
682
693
  count_var = Variable(next(tmp_vars), int_type(), node)
683
- # See https://github.com/CQCL/guppylang/issues/629
684
- hugr_elt_ty = ht.Option(node.elt_ty.to_hugr(self.ctx))
685
- # Initialise array with `None`s
686
- make_none = array_comprehension_init_func(self.ctx)
687
- make_none = self.builder.load_function(
688
- make_none,
689
- instantiation=ht.FunctionType([], [hugr_elt_ty]),
690
- type_args=[ht.TypeTypeArg(node.elt_ty.to_hugr(self.ctx))],
691
- )
694
+ hugr_elt_ty = node.elt_ty.to_hugr(self.ctx)
695
+ # Initialise empty array.
692
696
  self.dfg[array_var] = self.builder.add_op(
693
- array_repeat(hugr_elt_ty, node.length.to_arg().to_hugr(self.ctx)), make_none
697
+ barray_new_all_borrowed(hugr_elt_ty, node.length.to_arg().to_hugr(self.ctx))
694
698
  )
695
699
  self.dfg[count_var] = self.builder.load(
696
700
  hugr.std.int.IntVal(0, width=NumericType.INT_WIDTH)
@@ -698,8 +702,12 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
698
702
  with self._build_generators([node.generator], [array_var, count_var]):
699
703
  elt = self.visit(node.elt)
700
704
  array, count = self.dfg[array_var], self.dfg[count_var]
701
- [], [self.dfg[array_var]] = self._build_method_call(
702
- array_ty, "__setitem__", node, [array, count, elt], array_ty.args
705
+ idx = self.builder.add_op(convert_itousize(), count)
706
+ self.dfg[array_var] = self.builder.add_op(
707
+ barray_return(hugr_elt_ty, node.length.to_arg().to_hugr(self.ctx)),
708
+ array,
709
+ idx,
710
+ elt,
703
711
  )
704
712
  # Update `count += 1`
705
713
  one = self.builder.load(hugr.std.int.IntVal(1, width=NumericType.INT_WIDTH))
@@ -835,10 +843,6 @@ def python_value_to_hugr(v: Any, exp_ty: Type, ctx: CompilerContext) -> hv.Value
835
843
  return None
836
844
 
837
845
 
838
- ARRAY_COMPREHENSION_INIT: Final[GlobalConstId] = GlobalConstId.fresh(
839
- "array.__comprehension.init"
840
- )
841
-
842
846
  ARRAY_UNWRAP_ELEM: Final[GlobalConstId] = GlobalConstId.fresh("array.__unwrap_elem")
843
847
  ARRAY_WRAP_ELEM: Final[GlobalConstId] = GlobalConstId.fresh("array.__wrap_elem")
844
848
 
@@ -848,54 +852,6 @@ ARRAY_MAKE_OPAQUE_BOOL: Final[GlobalConstId] = GlobalConstId.fresh(
848
852
  )
849
853
 
850
854
 
851
- def array_comprehension_init_func(ctx: CompilerContext) -> hf.Function:
852
- """Returns the Hugr function that is used to initialise arrays elements before a
853
- comprehension.
854
-
855
- Just returns the `None` variant of the optional element type.
856
-
857
- See https://github.com/CQCL/guppylang/issues/629
858
- """
859
- v = ht.Variable(0, ht.TypeBound(ht.TypeBound.Linear))
860
- sig = ht.PolyFuncType(
861
- params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
862
- body=ht.FunctionType([], [ht.Option(v)]),
863
- )
864
- func, already_defined = ctx.declare_global_func(ARRAY_COMPREHENSION_INIT, sig)
865
- if not already_defined:
866
- func.set_outputs(func.add_op(ops.Tag(0, ht.Option(v))))
867
- return func
868
-
869
-
870
- def array_unwrap_elem(ctx: CompilerContext) -> hf.Function:
871
- """Returns the Hugr function that is used to unwrap the elements in an option array
872
- to turn it into a regular array."""
873
- v = ht.Variable(0, ht.TypeBound(ht.TypeBound.Linear))
874
- sig = ht.PolyFuncType(
875
- params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
876
- body=ht.FunctionType([ht.Option(v)], [v]),
877
- )
878
- func, already_defined = ctx.declare_global_func(ARRAY_UNWRAP_ELEM, sig)
879
- if not already_defined:
880
- msg = "Linear array element has already been used"
881
- func.set_outputs(build_unwrap(func, func.inputs()[0], msg))
882
- return func
883
-
884
-
885
- def array_wrap_elem(ctx: CompilerContext) -> hf.Function:
886
- """Returns the Hugr function that is used to wrap the elements in an regular array
887
- to turn it into a option array."""
888
- v = ht.Variable(0, ht.TypeBound(ht.TypeBound.Linear))
889
- sig = ht.PolyFuncType(
890
- params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
891
- body=ht.FunctionType([v], [ht.Option(v)]),
892
- )
893
- func, already_defined = ctx.declare_global_func(ARRAY_WRAP_ELEM, sig)
894
- if not already_defined:
895
- func.set_outputs(func.add_op(ops.Tag(1, ht.Option(v)), func.inputs()[0]))
896
- return func
897
-
898
-
899
855
  def array_read_bool(ctx: CompilerContext) -> hf.Function:
900
856
  """Returns the Hugr function that is used to unwrap the elements in an option array
901
857
  to turn it into a regular array."""
@@ -944,35 +900,21 @@ def apply_array_op_with_conversions(
944
900
  output array.
945
901
 
946
902
  Transformations:
947
- 1. Unwraps / wraps elements in options.
948
- 3. (Optional) Converts from / to opaque bool to / from Hugr bool.
949
- 2. Converts from / to value array to / from standard Hugr array.
903
+ 1. (Optional) Converts from / to opaque bool to / from Hugr bool.
904
+ 2. Converts from / to borrow array to / from standard Hugr array.
950
905
  """
951
- unwrap = array_unwrap_elem(ctx)
952
- unwrap = builder.load_function(
953
- unwrap,
954
- instantiation=ht.FunctionType([ht.Option(elem_ty)], [elem_ty]),
955
- type_args=[ht.TypeTypeArg(elem_ty)],
956
- )
957
- map_op = array_map(ht.Option(elem_ty), size_arg, elem_ty)
958
- unwrapped_array = builder.add_op(map_op, input_array, unwrap)
959
-
960
906
  if convert_bool:
961
907
  array_read = array_read_bool(ctx)
962
908
  array_read = builder.load_function(array_read)
963
909
  map_op = array_map(OpaqueBool, size_arg, ht.Bool)
964
- unwrapped_array = builder.add_op(map_op, unwrapped_array, array_read)
910
+ input_array = builder.add_op(map_op, input_array, array_read)
965
911
  elem_ty = ht.Bool
966
912
 
967
- unwrapped_array = builder.add_op(
968
- array_convert_to_std_array(elem_ty, size_arg), unwrapped_array
969
- )
913
+ input_array = builder.add_op(array_to_std_array(elem_ty, size_arg), input_array)
970
914
 
971
- result_array = builder.add_op(op, unwrapped_array)
915
+ result_array = builder.add_op(op, input_array)
972
916
 
973
- result_array = builder.add_op(
974
- array_convert_from_std_array(elem_ty, size_arg), result_array
975
- )
917
+ result_array = builder.add_op(std_array_to_array(elem_ty, size_arg), result_array)
976
918
 
977
919
  if convert_bool:
978
920
  array_make_opaque = array_make_opaque_bool(ctx)
@@ -981,11 +923,4 @@ def apply_array_op_with_conversions(
981
923
  result_array = builder.add_op(map_op, result_array, array_make_opaque)
982
924
  elem_ty = OpaqueBool
983
925
 
984
- wrap = array_wrap_elem(ctx)
985
- wrap = builder.load_function(
986
- wrap,
987
- instantiation=ht.FunctionType([elem_ty], [ht.Option(elem_ty)]),
988
- type_args=[ht.TypeTypeArg(elem_ty)],
989
- )
990
- map_op = array_map(elem_ty, size_arg, ht.Option(elem_ty))
991
- return builder.add_op(map_op, result_array, wrap)
926
+ return result_array