guppylang-internals 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +21 -0
- guppylang_internals/cfg/bb.py +20 -0
- guppylang_internals/cfg/builder.py +118 -5
- guppylang_internals/cfg/cfg.py +3 -0
- guppylang_internals/checker/cfg_checker.py +6 -0
- guppylang_internals/checker/core.py +5 -2
- guppylang_internals/checker/errors/generic.py +32 -1
- guppylang_internals/checker/errors/type_errors.py +14 -0
- guppylang_internals/checker/errors/wasm.py +7 -4
- guppylang_internals/checker/expr_checker.py +58 -17
- guppylang_internals/checker/func_checker.py +18 -14
- guppylang_internals/checker/linearity_checker.py +67 -10
- guppylang_internals/checker/modifier_checker.py +120 -0
- guppylang_internals/checker/stmt_checker.py +48 -1
- guppylang_internals/checker/unitary_checker.py +132 -0
- guppylang_internals/compiler/cfg_compiler.py +7 -6
- guppylang_internals/compiler/core.py +93 -56
- guppylang_internals/compiler/expr_compiler.py +72 -168
- guppylang_internals/compiler/modifier_compiler.py +176 -0
- guppylang_internals/compiler/stmt_compiler.py +15 -8
- guppylang_internals/decorator.py +86 -7
- guppylang_internals/definition/custom.py +39 -1
- guppylang_internals/definition/declaration.py +9 -6
- guppylang_internals/definition/function.py +12 -2
- guppylang_internals/definition/parameter.py +8 -3
- guppylang_internals/definition/pytket_circuits.py +14 -41
- guppylang_internals/definition/struct.py +13 -7
- guppylang_internals/definition/ty.py +3 -3
- guppylang_internals/definition/wasm.py +42 -10
- guppylang_internals/engine.py +9 -3
- guppylang_internals/experimental.py +5 -0
- guppylang_internals/nodes.py +147 -24
- guppylang_internals/std/_internal/checker.py +13 -108
- guppylang_internals/std/_internal/compiler/array.py +95 -283
- guppylang_internals/std/_internal/compiler/list.py +1 -1
- guppylang_internals/std/_internal/compiler/platform.py +153 -0
- guppylang_internals/std/_internal/compiler/prelude.py +12 -4
- guppylang_internals/std/_internal/compiler/tket_exts.py +8 -2
- guppylang_internals/std/_internal/debug.py +18 -9
- guppylang_internals/std/_internal/util.py +1 -1
- guppylang_internals/tracing/object.py +10 -0
- guppylang_internals/tracing/unpacking.py +19 -20
- guppylang_internals/tys/arg.py +18 -3
- guppylang_internals/tys/builtin.py +2 -5
- guppylang_internals/tys/const.py +33 -4
- guppylang_internals/tys/errors.py +23 -1
- guppylang_internals/tys/param.py +31 -16
- guppylang_internals/tys/parsing.py +11 -24
- guppylang_internals/tys/printing.py +2 -8
- guppylang_internals/tys/qubit.py +62 -0
- guppylang_internals/tys/subst.py +8 -26
- guppylang_internals/tys/ty.py +91 -85
- guppylang_internals/wasm_util.py +129 -0
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +6 -5
- guppylang_internals-0.26.0.dist-info/RECORD +104 -0
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
- guppylang_internals-0.24.0.dist-info/RECORD +0 -98
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import itertools
|
|
2
2
|
from abc import ABC
|
|
3
|
-
from collections import defaultdict
|
|
4
3
|
from collections.abc import Callable, Iterator, Sequence
|
|
5
4
|
from contextlib import contextmanager
|
|
6
5
|
from dataclasses import dataclass, field
|
|
@@ -16,6 +15,7 @@ from hugr.hugr.base import OpVarCov
|
|
|
16
15
|
from hugr.hugr.node_port import ToNode
|
|
17
16
|
from hugr.std import PRELUDE
|
|
18
17
|
from hugr.std.collections.array import EXTENSION as ARRAY_EXTENSION
|
|
18
|
+
from hugr.std.collections.borrow_array import EXTENSION as BORROW_ARRAY_EXTENSION
|
|
19
19
|
from typing_extensions import assert_never
|
|
20
20
|
|
|
21
21
|
from guppylang_internals.checker.core import (
|
|
@@ -43,8 +43,8 @@ from guppylang_internals.tys.arg import ConstArg, TypeArg
|
|
|
43
43
|
from guppylang_internals.tys.builtin import nat_type
|
|
44
44
|
from guppylang_internals.tys.common import ToHugrContext
|
|
45
45
|
from guppylang_internals.tys.const import BoundConstVar, ConstValue
|
|
46
|
-
from guppylang_internals.tys.param import ConstParam, Parameter
|
|
47
|
-
from guppylang_internals.tys.subst import Inst
|
|
46
|
+
from guppylang_internals.tys.param import ConstParam, Parameter
|
|
47
|
+
from guppylang_internals.tys.subst import Inst, Instantiator
|
|
48
48
|
from guppylang_internals.tys.ty import (
|
|
49
49
|
BoundTypeVar,
|
|
50
50
|
NumericType,
|
|
@@ -185,7 +185,7 @@ class CompilerContext(ToHugrContext):
|
|
|
185
185
|
# make the call to `ENGINE.get_checked` below fail. For now, let's just short-
|
|
186
186
|
# cut if the function doesn't take any generic params (as is the case for all
|
|
187
187
|
# nested functions).
|
|
188
|
-
# See https://github.com/
|
|
188
|
+
# See https://github.com/quantinuum/guppylang/issues/1032
|
|
189
189
|
if (def_id, ()) in self.compiled:
|
|
190
190
|
assert type_args == []
|
|
191
191
|
return self.compiled[def_id, ()], type_args
|
|
@@ -221,10 +221,10 @@ class CompilerContext(ToHugrContext):
|
|
|
221
221
|
match ENGINE.get_checked(defn.id):
|
|
222
222
|
case MonomorphizableDef(params=params) as defn:
|
|
223
223
|
# Entry point is not allowed to require monomorphization
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
224
|
+
if mono_params := require_monomorphization(params):
|
|
225
|
+
mono_param = mono_params.pop()
|
|
226
|
+
err = EntryMonomorphizeError(defn.defined_at, defn.name, mono_param)
|
|
227
|
+
raise GuppyError(err)
|
|
228
228
|
# Thus, the partial monomorphization for the entry point is always empty
|
|
229
229
|
entry_mono_args = tuple(None for _ in params)
|
|
230
230
|
entry_compiled = defn.monomorphize(self.module, entry_mono_args, self)
|
|
@@ -247,7 +247,7 @@ class CompilerContext(ToHugrContext):
|
|
|
247
247
|
|
|
248
248
|
# Insert explicit drops for affine types
|
|
249
249
|
# TODO: This is a quick workaround until we can properly insert these drops
|
|
250
|
-
#
|
|
250
|
+
# during linearity checking. See https://github.com/quantinuum/guppylang/issues/1082
|
|
251
251
|
insert_drops(self.module.hugr)
|
|
252
252
|
|
|
253
253
|
return entry_compiled
|
|
@@ -467,19 +467,27 @@ def is_return_var(x: str) -> bool:
|
|
|
467
467
|
return x.startswith("%ret")
|
|
468
468
|
|
|
469
469
|
|
|
470
|
-
def
|
|
471
|
-
"""
|
|
470
|
+
def require_monomorphization(params: Sequence[Parameter]) -> set[Parameter]:
|
|
471
|
+
"""Returns the subset of type parameters that must be monomorphized before compiling
|
|
472
|
+
to Hugr.
|
|
472
473
|
|
|
473
474
|
This is required for some Guppy language features that cannot be encoded in Hugr
|
|
474
|
-
yet. Currently, this
|
|
475
|
+
yet. Currently, this applies to:
|
|
476
|
+
* non-nat const parameters
|
|
477
|
+
* parameters that occur in the type of const parameters
|
|
475
478
|
"""
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
479
|
+
mono_params: set[Parameter] = set()
|
|
480
|
+
for param in params:
|
|
481
|
+
match param:
|
|
482
|
+
case ConstParam(ty=ty) if ty != NumericType(NumericType.Kind.Nat):
|
|
483
|
+
mono_params.add(param)
|
|
484
|
+
# If the constant type refers to any bound variables, then those will
|
|
485
|
+
# need to be monomorphized as well
|
|
486
|
+
for var in ty.bound_vars:
|
|
487
|
+
mono_params.add(params[var.idx])
|
|
488
|
+
case _:
|
|
489
|
+
pass
|
|
490
|
+
return mono_params
|
|
483
491
|
|
|
484
492
|
|
|
485
493
|
def partially_monomorphize_args(
|
|
@@ -493,27 +501,36 @@ def partially_monomorphize_args(
|
|
|
493
501
|
|
|
494
502
|
Also takes care of normalising bound variables w.r.t. the current monomorphization.
|
|
495
503
|
"""
|
|
496
|
-
|
|
497
|
-
|
|
504
|
+
# Normalise args w.r.t. the current outer monomorphisation
|
|
505
|
+
if ctx.current_mono_args is not None:
|
|
506
|
+
instantiator = Instantiator(ctx.current_mono_args, allow_partial=True)
|
|
507
|
+
args = [arg.transform(instantiator) for arg in args]
|
|
508
|
+
|
|
509
|
+
# Filter args depending on whether they need monomorphization or not. For this, we
|
|
510
|
+
# can't purely rely on the `require_monomorphization` function above since the
|
|
511
|
+
# instantiation also needs to be taken into account. For example, consider
|
|
512
|
+
# a function `def foo[T, x: T](...)`. Normally, both parameters would need to be
|
|
513
|
+
# monomorphized. However, if we instantiate `T := nat`, suddenly `x` no longer needs
|
|
514
|
+
# be monomorphized since we can use bounded nats in Hugr. Thus, we have to replicate
|
|
515
|
+
# the behaviour of `require_monomorphization` while taking `args` into account.
|
|
516
|
+
mono_args: list[Argument | None] = [None] * len(args)
|
|
498
517
|
for param, arg in zip(params, args, strict=True):
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
mono_args.append(None)
|
|
516
|
-
rem_args.append(arg)
|
|
518
|
+
match param, param.instantiate_bounds(args):
|
|
519
|
+
case ConstParam(ty=original_ty), ConstParam(ty=inst_ty):
|
|
520
|
+
# If the original const type is not `nat`, then all variables occurring
|
|
521
|
+
# in the type need to be monomorphised
|
|
522
|
+
if original_ty != nat_type():
|
|
523
|
+
for var in original_ty.bound_vars:
|
|
524
|
+
mono_args[var.idx] = args[var.idx]
|
|
525
|
+
# If the const type is still not `nat` after normalisation, then the
|
|
526
|
+
# const arg itself also needs to be monomorphized
|
|
527
|
+
if inst_ty != nat_type():
|
|
528
|
+
mono_args[param.idx] = arg
|
|
529
|
+
|
|
530
|
+
# Mono-arguments should not refer to any bound variables
|
|
531
|
+
assert all(mono_arg is None or not mono_arg.bound_vars for mono_arg in mono_args)
|
|
532
|
+
|
|
533
|
+
rem_args = [arg for i, arg in enumerate(args) if mono_args[i] is None]
|
|
517
534
|
return tuple(mono_args), rem_args
|
|
518
535
|
|
|
519
536
|
|
|
@@ -565,6 +582,9 @@ def may_have_side_effect(op: ops.Op) -> bool:
|
|
|
565
582
|
# precise answer
|
|
566
583
|
return True
|
|
567
584
|
case _:
|
|
585
|
+
# There is no need to handle TailLoop (in case of non-termination) since
|
|
586
|
+
# TailLoops are only generated for array comprehensions which must have
|
|
587
|
+
# statically-guaranteed (finite) size. TODO revisit this for lists.
|
|
568
588
|
return False
|
|
569
589
|
|
|
570
590
|
|
|
@@ -578,9 +598,7 @@ def track_hugr_side_effects() -> Iterator[None]:
|
|
|
578
598
|
# Remember original `Hugr.add_node` method that is monkey-patched below.
|
|
579
599
|
hugr_add_node = Hugr.add_node
|
|
580
600
|
# Last node with potential side effects for each dataflow parent
|
|
581
|
-
prev_node_with_side_effect:
|
|
582
|
-
lambda: None
|
|
583
|
-
)
|
|
601
|
+
prev_node_with_side_effect: dict[Node, tuple[Node, Hugr[Any]]] = {}
|
|
584
602
|
|
|
585
603
|
def hugr_add_node_with_order(
|
|
586
604
|
self: Hugr[OpVarCov],
|
|
@@ -601,29 +619,47 @@ def track_hugr_side_effects() -> Iterator[None]:
|
|
|
601
619
|
"""Performs the actual order-edge insertion, assuming that `node` has a side-
|
|
602
620
|
effect."""
|
|
603
621
|
parent = hugr[node].parent
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
622
|
+
assert parent is not None
|
|
623
|
+
|
|
624
|
+
if prev := prev_node_with_side_effect.get(parent):
|
|
625
|
+
prev_node = prev[0]
|
|
626
|
+
else:
|
|
627
|
+
# This is the first side-effectful op in this DFG. Recurse on the parent
|
|
628
|
+
# since the parent is also considered side-effectful now. We shouldn't walk
|
|
629
|
+
# up through function definitions (only the Module is above)
|
|
630
|
+
if not isinstance(hugr[parent].op, ops.FuncDefn):
|
|
631
|
+
handle_side_effect(parent, hugr)
|
|
632
|
+
# For DataflowBlocks and Cases, recurse to mark their containing CFG
|
|
633
|
+
# or Conditional as side-effectful as well, but there is nothing to do
|
|
634
|
+
# locally: we cannot add order edges, but Conditional/CFG semantics
|
|
635
|
+
# ensure execution if appropriate.
|
|
636
|
+
if isinstance(hugr[parent].op, ops.Conditional | ops.CFG):
|
|
637
|
+
return
|
|
638
|
+
prev_node = hugr.children(parent)[0]
|
|
639
|
+
assert isinstance(hugr[prev_node].op, ops.Input)
|
|
640
|
+
|
|
641
|
+
# Add edge, but avoid self-loops for containers when recursing up the hierarchy.
|
|
642
|
+
if prev_node != node:
|
|
643
|
+
hugr.add_order_link(prev_node, node)
|
|
644
|
+
prev_node_with_side_effect[parent] = (node, hugr)
|
|
615
645
|
|
|
616
646
|
# Monkey-patch the `add_node` method
|
|
617
647
|
Hugr.add_node = hugr_add_node_with_order # type: ignore[method-assign]
|
|
618
648
|
try:
|
|
619
649
|
yield
|
|
650
|
+
for parent, (last, hugr) in prev_node_with_side_effect.items():
|
|
651
|
+
# Connect the last side-effecting node to Output
|
|
652
|
+
outp = hugr.children(parent)[1]
|
|
653
|
+
assert isinstance(hugr[outp].op, ops.Output)
|
|
654
|
+
assert last != outp
|
|
655
|
+
hugr.add_order_link(last, outp)
|
|
620
656
|
finally:
|
|
621
657
|
Hugr.add_node = hugr_add_node # type: ignore[method-assign]
|
|
622
658
|
|
|
623
659
|
|
|
624
660
|
def qualified_name(type_def: he.TypeDef) -> str:
|
|
625
661
|
"""Returns the qualified name of a Hugr extension type.
|
|
626
|
-
TODO: Remove once upstreamed, see https://github.com/
|
|
662
|
+
TODO: Remove once upstreamed, see https://github.com/quantinuum/hugr/issues/2426
|
|
627
663
|
"""
|
|
628
664
|
if type_def._extension is not None:
|
|
629
665
|
return f"{type_def._extension.name}.{type_def.name}"
|
|
@@ -634,6 +670,7 @@ def qualified_name(type_def: he.TypeDef) -> str:
|
|
|
634
670
|
#: insertion of an explicit drop operation.
|
|
635
671
|
AFFINE_EXTENSION_TYS: list[str] = [
|
|
636
672
|
qualified_name(ARRAY_EXTENSION.get_type("array")),
|
|
673
|
+
qualified_name(BORROW_ARRAY_EXTENSION.get_type("borrow_array")),
|
|
637
674
|
]
|
|
638
675
|
|
|
639
676
|
|
|
@@ -674,14 +711,14 @@ def drop_op(ty: ht.Type) -> ops.ExtOp:
|
|
|
674
711
|
def insert_drops(hugr: Hugr[OpVarCov]) -> None:
|
|
675
712
|
"""Inserts explicit drop ops for unconnected ports into the Hugr.
|
|
676
713
|
TODO: This is a quick workaround until we can properly insert these drops during
|
|
677
|
-
linearity checking. See https://github.com/
|
|
714
|
+
linearity checking. See https://github.com/quantinuum/guppylang/issues/1082
|
|
678
715
|
"""
|
|
679
716
|
for node in hugr:
|
|
680
717
|
data = hugr[node]
|
|
681
718
|
# Iterating over `node.outputs()` doesn't work reliably since it sometimes
|
|
682
719
|
# raises an `IncompleteOp` exception. Instead, we query the number of out ports
|
|
683
720
|
# and look them up by index. However, this method is *also* broken when
|
|
684
|
-
#
|
|
721
|
+
# inspecting `FuncDefn` nodes due to https://github.com/quantinuum/hugr/issues/2438.
|
|
685
722
|
if isinstance(data.op, ops.FuncDefn):
|
|
686
723
|
continue
|
|
687
724
|
for i in range(hugr.num_out_ports(node)):
|
|
@@ -15,7 +15,6 @@ from hugr import val as hv
|
|
|
15
15
|
from hugr.build import function as hf
|
|
16
16
|
from hugr.build.cond_loop import Conditional
|
|
17
17
|
from hugr.build.dfg import DP, DfBase
|
|
18
|
-
from typing_extensions import assert_never
|
|
19
18
|
|
|
20
19
|
from guppylang_internals.ast_util import AstNode, AstVisitor, get_type
|
|
21
20
|
from guppylang_internals.cfg.builder import tmp_vars
|
|
@@ -23,7 +22,6 @@ from guppylang_internals.checker.core import Variable, contains_subscript
|
|
|
23
22
|
from guppylang_internals.checker.errors.generic import UnsupportedError
|
|
24
23
|
from guppylang_internals.compiler.core import (
|
|
25
24
|
DEBUG_EXTENSION,
|
|
26
|
-
RESULT_EXTENSION,
|
|
27
25
|
CompilerBase,
|
|
28
26
|
CompilerContext,
|
|
29
27
|
DFContainer,
|
|
@@ -53,7 +51,6 @@ from guppylang_internals.nodes import (
|
|
|
53
51
|
PanicExpr,
|
|
54
52
|
PartialApply,
|
|
55
53
|
PlaceNode,
|
|
56
|
-
ResultExpr,
|
|
57
54
|
StateResultExpr,
|
|
58
55
|
SubscriptAccessAndDrop,
|
|
59
56
|
TensorCall,
|
|
@@ -63,23 +60,24 @@ from guppylang_internals.nodes import (
|
|
|
63
60
|
from guppylang_internals.std._internal.compiler.arithmetic import (
|
|
64
61
|
UnsignedIntVal,
|
|
65
62
|
convert_ifromusize,
|
|
63
|
+
convert_itousize,
|
|
66
64
|
)
|
|
67
65
|
from guppylang_internals.std._internal.compiler.array import (
|
|
68
|
-
array_convert_from_std_array,
|
|
69
|
-
array_convert_to_std_array,
|
|
70
66
|
array_map,
|
|
71
67
|
array_new,
|
|
72
|
-
|
|
68
|
+
array_to_std_array,
|
|
69
|
+
barray_new_all_borrowed,
|
|
70
|
+
barray_return,
|
|
73
71
|
standard_array_type,
|
|
72
|
+
std_array_to_array,
|
|
74
73
|
unpack_array,
|
|
75
74
|
)
|
|
76
75
|
from guppylang_internals.std._internal.compiler.list import (
|
|
77
76
|
list_new,
|
|
78
77
|
)
|
|
79
78
|
from guppylang_internals.std._internal.compiler.prelude import (
|
|
80
|
-
build_error,
|
|
81
79
|
build_panic,
|
|
82
|
-
|
|
80
|
+
make_error,
|
|
83
81
|
panic,
|
|
84
82
|
)
|
|
85
83
|
from guppylang_internals.std._internal.compiler.tket_bool import (
|
|
@@ -94,10 +92,9 @@ from guppylang_internals.tys.builtin import (
|
|
|
94
92
|
bool_type,
|
|
95
93
|
get_element_type,
|
|
96
94
|
int_type,
|
|
97
|
-
is_bool_type,
|
|
98
95
|
is_frozenarray_type,
|
|
99
96
|
)
|
|
100
|
-
from guppylang_internals.tys.const import ConstValue
|
|
97
|
+
from guppylang_internals.tys.const import BoundConstVar, Const, ConstValue
|
|
101
98
|
from guppylang_internals.tys.subst import Inst
|
|
102
99
|
from guppylang_internals.tys.ty import (
|
|
103
100
|
BoundTypeVar,
|
|
@@ -279,11 +276,13 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
|
279
276
|
return defn.load(self.dfg, self.ctx, node)
|
|
280
277
|
|
|
281
278
|
def visit_GenericParamValue(self, node: GenericParamValue) -> Wire:
|
|
282
|
-
|
|
279
|
+
assert self.ctx.current_mono_args is not None
|
|
280
|
+
param = node.param.instantiate_bounds(self.ctx.current_mono_args)
|
|
281
|
+
match param.ty:
|
|
283
282
|
case NumericType(NumericType.Kind.Nat):
|
|
284
283
|
# Generic nat parameters are encoded using Hugr bounded nat parameters,
|
|
285
284
|
# so they are not monomorphized when compiling to Hugr
|
|
286
|
-
arg =
|
|
285
|
+
arg = param.to_bound().to_hugr(self.ctx)
|
|
287
286
|
load_nat = hugr.std.PRELUDE.get_op("load_nat").instantiate(
|
|
288
287
|
[arg], ht.FunctionType([], [ht.USize()])
|
|
289
288
|
)
|
|
@@ -291,7 +290,6 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
|
291
290
|
return self.builder.add_op(convert_ifromusize(), usize)
|
|
292
291
|
case ty:
|
|
293
292
|
# Look up monomorphization
|
|
294
|
-
assert self.ctx.current_mono_args is not None
|
|
295
293
|
match self.ctx.current_mono_args[node.param.idx]:
|
|
296
294
|
case ConstArg(const=ConstValue(value=v)):
|
|
297
295
|
val = python_value_to_hugr(v, ty, self.ctx)
|
|
@@ -531,67 +529,48 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
|
531
529
|
tuple_port = self.visit(node.value)
|
|
532
530
|
return self._unpack_tuple(tuple_port, node.tuple_ty.element_types)[node.index]
|
|
533
531
|
|
|
534
|
-
def
|
|
535
|
-
|
|
536
|
-
base_ty = node.base_ty.to_hugr(self.ctx)
|
|
537
|
-
extra_args: list[ht.TypeArg] = []
|
|
538
|
-
if isinstance(node.base_ty, NumericType):
|
|
539
|
-
match node.base_ty.kind:
|
|
540
|
-
case NumericType.Kind.Nat:
|
|
541
|
-
base_name = "uint"
|
|
542
|
-
extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)]
|
|
543
|
-
case NumericType.Kind.Int:
|
|
544
|
-
base_name = "int"
|
|
545
|
-
extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)]
|
|
546
|
-
case NumericType.Kind.Float:
|
|
547
|
-
base_name = "f64"
|
|
548
|
-
case kind:
|
|
549
|
-
assert_never(kind)
|
|
550
|
-
else:
|
|
551
|
-
# The only other valid base type is bool
|
|
552
|
-
assert is_bool_type(node.base_ty)
|
|
553
|
-
base_name = "bool"
|
|
554
|
-
if node.array_len is not None:
|
|
555
|
-
op_name = f"result_array_{base_name}"
|
|
556
|
-
size_arg = node.array_len.to_arg().to_hugr(self.ctx)
|
|
557
|
-
extra_args = [size_arg, *extra_args]
|
|
558
|
-
# Remove the option wrapping in the array
|
|
559
|
-
unwrap = array_unwrap_elem(self.ctx)
|
|
560
|
-
unwrap = self.builder.load_function(
|
|
561
|
-
unwrap,
|
|
562
|
-
instantiation=ht.FunctionType([ht.Option(base_ty)], [base_ty]),
|
|
563
|
-
type_args=[ht.TypeTypeArg(base_ty)],
|
|
564
|
-
)
|
|
565
|
-
map_op = array_map(ht.Option(base_ty), size_arg, base_ty)
|
|
566
|
-
value_wire = self.builder.add_op(map_op, value_wire, unwrap)
|
|
567
|
-
if is_bool_type(node.base_ty):
|
|
568
|
-
# We need to coerce a read on all the array elements if they are bools.
|
|
569
|
-
array_read = array_read_bool(self.ctx)
|
|
570
|
-
array_read = self.builder.load_function(array_read)
|
|
571
|
-
map_op = array_map(OpaqueBool, size_arg, ht.Bool)
|
|
572
|
-
value_wire = self.builder.add_op(map_op, value_wire, array_read)
|
|
573
|
-
base_ty = ht.Bool
|
|
574
|
-
# Turn `value_array` into regular linear `array`
|
|
575
|
-
value_wire = self.builder.add_op(
|
|
576
|
-
array_convert_to_std_array(base_ty, size_arg), value_wire
|
|
577
|
-
)
|
|
578
|
-
hugr_ty: ht.Type = hugr.std.collections.array.Array(base_ty, size_arg)
|
|
579
|
-
else:
|
|
580
|
-
if is_bool_type(node.base_ty):
|
|
581
|
-
base_ty = ht.Bool
|
|
582
|
-
value_wire = self.builder.add_op(read_bool(), value_wire)
|
|
583
|
-
op_name = f"result_{base_name}"
|
|
584
|
-
hugr_ty = base_ty
|
|
532
|
+
def _visit_result_tag(self, tag: Const, loc: ast.expr) -> str:
|
|
533
|
+
"""Helper method to resolve the tag string in `state_result` expressions.
|
|
585
534
|
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
535
|
+
Also takes care of checking that the tag fits into the maximum tag length.
|
|
536
|
+
Once we go ahead with https://github.com/quantinuum/guppylang/discussions/1299,
|
|
537
|
+
this can be moved into type checking.
|
|
538
|
+
"""
|
|
539
|
+
from guppylang_internals.std._internal.compiler.platform import (
|
|
540
|
+
TAG_MAX_LEN,
|
|
541
|
+
TooLongError,
|
|
542
|
+
)
|
|
589
543
|
|
|
590
|
-
|
|
591
|
-
|
|
544
|
+
is_generic: BoundConstVar | None = None
|
|
545
|
+
match tag:
|
|
546
|
+
case ConstValue(value=str(v)):
|
|
547
|
+
tag_value = v
|
|
548
|
+
case BoundConstVar(idx=idx) as var:
|
|
549
|
+
assert self.ctx.current_mono_args is not None
|
|
550
|
+
match self.ctx.current_mono_args[idx]:
|
|
551
|
+
case ConstArg(const=ConstValue(value=str(v))):
|
|
552
|
+
tag_value = v
|
|
553
|
+
is_generic = var
|
|
554
|
+
case _:
|
|
555
|
+
raise InternalGuppyError("Unexpected tag monomorphization")
|
|
556
|
+
case _:
|
|
557
|
+
raise InternalGuppyError("Unexpected tag value")
|
|
558
|
+
|
|
559
|
+
if len(tag_value.encode("utf-8")) > TAG_MAX_LEN:
|
|
560
|
+
err = TooLongError(loc)
|
|
561
|
+
err.add_sub_diagnostic(TooLongError.Hint(None))
|
|
562
|
+
if is_generic:
|
|
563
|
+
err.add_sub_diagnostic(
|
|
564
|
+
TooLongError.GenericHint(None, is_generic.display_name, tag_value)
|
|
565
|
+
)
|
|
566
|
+
raise GuppyError(err)
|
|
567
|
+
return tag_value
|
|
592
568
|
|
|
593
569
|
def visit_PanicExpr(self, node: PanicExpr) -> Wire:
|
|
594
|
-
|
|
570
|
+
signal = self.visit(node.signal)
|
|
571
|
+
signal_usize = self.builder.add_op(convert_itousize(), signal)
|
|
572
|
+
msg = self.visit(node.msg)
|
|
573
|
+
err = self.builder.add_op(make_error(), signal_usize, msg)
|
|
595
574
|
in_tys = [get_type(e).to_hugr(self.ctx) for e in node.values]
|
|
596
575
|
out_tys = [ty.to_hugr(self.ctx) for ty in type_to_row(get_type(node))]
|
|
597
576
|
args = [self.visit(e) for e in node.values]
|
|
@@ -616,12 +595,13 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
|
616
595
|
return self._pack_returns([], NoneType())
|
|
617
596
|
|
|
618
597
|
def visit_StateResultExpr(self, node: StateResultExpr) -> Wire:
|
|
598
|
+
tag_value = self._visit_result_tag(node.tag_value, node.tag_expr)
|
|
619
599
|
num_qubits_arg = (
|
|
620
600
|
node.array_len.to_arg().to_hugr(self.ctx)
|
|
621
601
|
if node.array_len
|
|
622
602
|
else ht.BoundedNatArg(len(node.args) - 1)
|
|
623
603
|
)
|
|
624
|
-
args = [ht.StringArg(
|
|
604
|
+
args = [ht.StringArg(tag_value), num_qubits_arg]
|
|
625
605
|
sig = ht.FunctionType(
|
|
626
606
|
[standard_array_type(ht.Qubit, num_qubits_arg)],
|
|
627
607
|
[standard_array_type(ht.Qubit, num_qubits_arg)],
|
|
@@ -635,20 +615,20 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
|
635
615
|
qubit_arr_in = self.builder.add_op(
|
|
636
616
|
array_new(ht.Qubit, len(node.args) - 1), *qubits_in
|
|
637
617
|
)
|
|
638
|
-
# Turn into standard array from
|
|
618
|
+
# Turn into standard array from borrow array.
|
|
639
619
|
qubit_arr_in = self.builder.add_op(
|
|
640
|
-
|
|
620
|
+
array_to_std_array(ht.Qubit, num_qubits_arg), qubit_arr_in
|
|
641
621
|
)
|
|
642
622
|
|
|
643
623
|
qubit_arr_out = self.builder.add_op(op, qubit_arr_in)
|
|
644
624
|
|
|
645
625
|
qubit_arr_out = self.builder.add_op(
|
|
646
|
-
|
|
626
|
+
std_array_to_array(ht.Qubit, num_qubits_arg), qubit_arr_out
|
|
647
627
|
)
|
|
648
628
|
qubits_out = unpack_array(self.builder, qubit_arr_out)
|
|
649
629
|
else:
|
|
650
|
-
# If the input is an array of qubits, we need to
|
|
651
|
-
#
|
|
630
|
+
# If the input is an array of qubits, we need to convert to a standard
|
|
631
|
+
# array.
|
|
652
632
|
qubits_in = [self.visit(node.args[1])]
|
|
653
633
|
qubits_out = [
|
|
654
634
|
apply_array_op_with_conversions(
|
|
@@ -680,17 +660,10 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
|
680
660
|
assert isinstance(array_ty, OpaqueType)
|
|
681
661
|
array_var = Variable(next(tmp_vars), array_ty, node)
|
|
682
662
|
count_var = Variable(next(tmp_vars), int_type(), node)
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
# Initialise array with `None`s
|
|
686
|
-
make_none = array_comprehension_init_func(self.ctx)
|
|
687
|
-
make_none = self.builder.load_function(
|
|
688
|
-
make_none,
|
|
689
|
-
instantiation=ht.FunctionType([], [hugr_elt_ty]),
|
|
690
|
-
type_args=[ht.TypeTypeArg(node.elt_ty.to_hugr(self.ctx))],
|
|
691
|
-
)
|
|
663
|
+
hugr_elt_ty = node.elt_ty.to_hugr(self.ctx)
|
|
664
|
+
# Initialise empty array.
|
|
692
665
|
self.dfg[array_var] = self.builder.add_op(
|
|
693
|
-
|
|
666
|
+
barray_new_all_borrowed(hugr_elt_ty, node.length.to_arg().to_hugr(self.ctx))
|
|
694
667
|
)
|
|
695
668
|
self.dfg[count_var] = self.builder.load(
|
|
696
669
|
hugr.std.int.IntVal(0, width=NumericType.INT_WIDTH)
|
|
@@ -698,8 +671,12 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
|
698
671
|
with self._build_generators([node.generator], [array_var, count_var]):
|
|
699
672
|
elt = self.visit(node.elt)
|
|
700
673
|
array, count = self.dfg[array_var], self.dfg[count_var]
|
|
701
|
-
|
|
702
|
-
|
|
674
|
+
idx = self.builder.add_op(convert_itousize(), count)
|
|
675
|
+
self.dfg[array_var] = self.builder.add_op(
|
|
676
|
+
barray_return(hugr_elt_ty, node.length.to_arg().to_hugr(self.ctx)),
|
|
677
|
+
array,
|
|
678
|
+
idx,
|
|
679
|
+
elt,
|
|
703
680
|
)
|
|
704
681
|
# Update `count += 1`
|
|
705
682
|
one = self.builder.load(hugr.std.int.IntVal(1, width=NumericType.INT_WIDTH))
|
|
@@ -835,10 +812,6 @@ def python_value_to_hugr(v: Any, exp_ty: Type, ctx: CompilerContext) -> hv.Value
|
|
|
835
812
|
return None
|
|
836
813
|
|
|
837
814
|
|
|
838
|
-
ARRAY_COMPREHENSION_INIT: Final[GlobalConstId] = GlobalConstId.fresh(
|
|
839
|
-
"array.__comprehension.init"
|
|
840
|
-
)
|
|
841
|
-
|
|
842
815
|
ARRAY_UNWRAP_ELEM: Final[GlobalConstId] = GlobalConstId.fresh("array.__unwrap_elem")
|
|
843
816
|
ARRAY_WRAP_ELEM: Final[GlobalConstId] = GlobalConstId.fresh("array.__wrap_elem")
|
|
844
817
|
|
|
@@ -848,54 +821,6 @@ ARRAY_MAKE_OPAQUE_BOOL: Final[GlobalConstId] = GlobalConstId.fresh(
|
|
|
848
821
|
)
|
|
849
822
|
|
|
850
823
|
|
|
851
|
-
def array_comprehension_init_func(ctx: CompilerContext) -> hf.Function:
|
|
852
|
-
"""Returns the Hugr function that is used to initialise arrays elements before a
|
|
853
|
-
comprehension.
|
|
854
|
-
|
|
855
|
-
Just returns the `None` variant of the optional element type.
|
|
856
|
-
|
|
857
|
-
See https://github.com/CQCL/guppylang/issues/629
|
|
858
|
-
"""
|
|
859
|
-
v = ht.Variable(0, ht.TypeBound(ht.TypeBound.Linear))
|
|
860
|
-
sig = ht.PolyFuncType(
|
|
861
|
-
params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
|
|
862
|
-
body=ht.FunctionType([], [ht.Option(v)]),
|
|
863
|
-
)
|
|
864
|
-
func, already_defined = ctx.declare_global_func(ARRAY_COMPREHENSION_INIT, sig)
|
|
865
|
-
if not already_defined:
|
|
866
|
-
func.set_outputs(func.add_op(ops.Tag(0, ht.Option(v))))
|
|
867
|
-
return func
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
def array_unwrap_elem(ctx: CompilerContext) -> hf.Function:
|
|
871
|
-
"""Returns the Hugr function that is used to unwrap the elements in an option array
|
|
872
|
-
to turn it into a regular array."""
|
|
873
|
-
v = ht.Variable(0, ht.TypeBound(ht.TypeBound.Linear))
|
|
874
|
-
sig = ht.PolyFuncType(
|
|
875
|
-
params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
|
|
876
|
-
body=ht.FunctionType([ht.Option(v)], [v]),
|
|
877
|
-
)
|
|
878
|
-
func, already_defined = ctx.declare_global_func(ARRAY_UNWRAP_ELEM, sig)
|
|
879
|
-
if not already_defined:
|
|
880
|
-
msg = "Linear array element has already been used"
|
|
881
|
-
func.set_outputs(build_unwrap(func, func.inputs()[0], msg))
|
|
882
|
-
return func
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
def array_wrap_elem(ctx: CompilerContext) -> hf.Function:
|
|
886
|
-
"""Returns the Hugr function that is used to wrap the elements in an regular array
|
|
887
|
-
to turn it into a option array."""
|
|
888
|
-
v = ht.Variable(0, ht.TypeBound(ht.TypeBound.Linear))
|
|
889
|
-
sig = ht.PolyFuncType(
|
|
890
|
-
params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
|
|
891
|
-
body=ht.FunctionType([v], [ht.Option(v)]),
|
|
892
|
-
)
|
|
893
|
-
func, already_defined = ctx.declare_global_func(ARRAY_WRAP_ELEM, sig)
|
|
894
|
-
if not already_defined:
|
|
895
|
-
func.set_outputs(func.add_op(ops.Tag(1, ht.Option(v)), func.inputs()[0]))
|
|
896
|
-
return func
|
|
897
|
-
|
|
898
|
-
|
|
899
824
|
def array_read_bool(ctx: CompilerContext) -> hf.Function:
|
|
900
825
|
"""Returns the Hugr function that is used to unwrap the elements in an option array
|
|
901
826
|
to turn it into a regular array."""
|
|
@@ -944,35 +869,21 @@ def apply_array_op_with_conversions(
|
|
|
944
869
|
output array.
|
|
945
870
|
|
|
946
871
|
Transformations:
|
|
947
|
-
1.
|
|
948
|
-
|
|
949
|
-
2. Converts from / to value array to / from standard Hugr array.
|
|
872
|
+
1. (Optional) Converts from / to opaque bool to / from Hugr bool.
|
|
873
|
+
2. Converts from / to borrow array to / from standard Hugr array.
|
|
950
874
|
"""
|
|
951
|
-
unwrap = array_unwrap_elem(ctx)
|
|
952
|
-
unwrap = builder.load_function(
|
|
953
|
-
unwrap,
|
|
954
|
-
instantiation=ht.FunctionType([ht.Option(elem_ty)], [elem_ty]),
|
|
955
|
-
type_args=[ht.TypeTypeArg(elem_ty)],
|
|
956
|
-
)
|
|
957
|
-
map_op = array_map(ht.Option(elem_ty), size_arg, elem_ty)
|
|
958
|
-
unwrapped_array = builder.add_op(map_op, input_array, unwrap)
|
|
959
|
-
|
|
960
875
|
if convert_bool:
|
|
961
876
|
array_read = array_read_bool(ctx)
|
|
962
877
|
array_read = builder.load_function(array_read)
|
|
963
878
|
map_op = array_map(OpaqueBool, size_arg, ht.Bool)
|
|
964
|
-
|
|
879
|
+
input_array = builder.add_op(map_op, input_array, array_read)
|
|
965
880
|
elem_ty = ht.Bool
|
|
966
881
|
|
|
967
|
-
|
|
968
|
-
array_convert_to_std_array(elem_ty, size_arg), unwrapped_array
|
|
969
|
-
)
|
|
882
|
+
input_array = builder.add_op(array_to_std_array(elem_ty, size_arg), input_array)
|
|
970
883
|
|
|
971
|
-
result_array = builder.add_op(op,
|
|
884
|
+
result_array = builder.add_op(op, input_array)
|
|
972
885
|
|
|
973
|
-
result_array = builder.add_op(
|
|
974
|
-
array_convert_from_std_array(elem_ty, size_arg), result_array
|
|
975
|
-
)
|
|
886
|
+
result_array = builder.add_op(std_array_to_array(elem_ty, size_arg), result_array)
|
|
976
887
|
|
|
977
888
|
if convert_bool:
|
|
978
889
|
array_make_opaque = array_make_opaque_bool(ctx)
|
|
@@ -981,11 +892,4 @@ def apply_array_op_with_conversions(
|
|
|
981
892
|
result_array = builder.add_op(map_op, result_array, array_make_opaque)
|
|
982
893
|
elem_ty = OpaqueBool
|
|
983
894
|
|
|
984
|
-
|
|
985
|
-
wrap = builder.load_function(
|
|
986
|
-
wrap,
|
|
987
|
-
instantiation=ht.FunctionType([elem_ty], [ht.Option(elem_ty)]),
|
|
988
|
-
type_args=[ht.TypeTypeArg(elem_ty)],
|
|
989
|
-
)
|
|
990
|
-
map_op = array_map(elem_ty, size_arg, ht.Option(elem_ty))
|
|
991
|
-
return builder.add_op(map_op, result_array, wrap)
|
|
895
|
+
return result_array
|