guppylang-internals 0.24.0__py3-none-any.whl → 0.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +21 -0
- guppylang_internals/cfg/bb.py +20 -0
- guppylang_internals/cfg/builder.py +101 -3
- guppylang_internals/checker/core.py +4 -0
- guppylang_internals/checker/errors/generic.py +32 -1
- guppylang_internals/checker/errors/type_errors.py +14 -0
- guppylang_internals/checker/expr_checker.py +46 -10
- guppylang_internals/checker/func_checker.py +1 -1
- guppylang_internals/checker/linearity_checker.py +65 -0
- guppylang_internals/checker/modifier_checker.py +116 -0
- guppylang_internals/checker/stmt_checker.py +48 -1
- guppylang_internals/compiler/core.py +90 -53
- guppylang_internals/compiler/expr_compiler.py +49 -114
- guppylang_internals/compiler/modifier_compiler.py +174 -0
- guppylang_internals/compiler/stmt_compiler.py +15 -8
- guppylang_internals/definition/custom.py +35 -1
- guppylang_internals/definition/declaration.py +3 -4
- guppylang_internals/definition/parameter.py +8 -3
- guppylang_internals/definition/pytket_circuits.py +13 -41
- guppylang_internals/definition/struct.py +7 -4
- guppylang_internals/definition/ty.py +3 -3
- guppylang_internals/experimental.py +5 -0
- guppylang_internals/nodes.py +124 -0
- guppylang_internals/std/_internal/compiler/array.py +94 -282
- guppylang_internals/std/_internal/compiler/tket_exts.py +9 -2
- guppylang_internals/tracing/unpacking.py +19 -20
- guppylang_internals/tys/arg.py +18 -3
- guppylang_internals/tys/builtin.py +2 -5
- guppylang_internals/tys/const.py +33 -4
- guppylang_internals/tys/param.py +31 -16
- guppylang_internals/tys/parsing.py +8 -21
- guppylang_internals/tys/qubit.py +27 -0
- guppylang_internals/tys/subst.py +8 -26
- guppylang_internals/tys/ty.py +31 -21
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.25.0.dist-info}/METADATA +3 -3
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.25.0.dist-info}/RECORD +39 -36
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.25.0.dist-info}/WHEEL +0 -0
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.25.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""Hugr generation for modifiers."""
|
|
2
|
+
|
|
3
|
+
from hugr import Wire, ops
|
|
4
|
+
from hugr import tys as ht
|
|
5
|
+
|
|
6
|
+
from guppylang_internals.ast_util import get_type
|
|
7
|
+
from guppylang_internals.checker.modifier_checker import non_copyable_front_others_back
|
|
8
|
+
from guppylang_internals.compiler.cfg_compiler import compile_cfg
|
|
9
|
+
from guppylang_internals.compiler.core import CompilerContext, DFContainer
|
|
10
|
+
from guppylang_internals.compiler.expr_compiler import ExprCompiler
|
|
11
|
+
from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode
|
|
12
|
+
from guppylang_internals.std._internal.compiler.array import (
|
|
13
|
+
array_new,
|
|
14
|
+
array_to_std_array,
|
|
15
|
+
standard_array_type,
|
|
16
|
+
std_array_to_array,
|
|
17
|
+
unpack_array,
|
|
18
|
+
)
|
|
19
|
+
from guppylang_internals.std._internal.compiler.tket_exts import MODIFIER_EXTENSION
|
|
20
|
+
from guppylang_internals.tys.builtin import int_type, is_array_type
|
|
21
|
+
from guppylang_internals.tys.ty import InputFlags
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def compile_modified_block(
    modified_block: CheckedModifiedBlock,
    dfg: DFContainer,
    ctx: CompilerContext,
    expr_compiler: ExprCompiler,
) -> Wire:
    """Emits Hugr for a `with` block annotated with dagger/control/power modifiers.

    The block body is compiled into a separate function definition in the module
    root. A reference to that function is loaded and successively wrapped by the
    `DaggerModifier`, `PowerModifier` and `ControlModifier` operations of the
    modifier extension. The wrapped function is then invoked via `CallIndirect` with
    the control arrays followed by the captured variables. Finally, the control
    qubits and the captured inout variables are rebound in `dfg` to the matching
    outputs of that call.

    Args:
        modified_block: The checked modified block to compile.
        dfg: The dataflow container into which the call is emitted.
        ctx: The compiler context.
        expr_compiler: Compiler used for the power expressions and the control
            expressions.

    Returns:
        The `CallIndirect` node that invokes the modified function.
    """
    DAGGER_OP_NAME = "DaggerModifier"
    CONTROL_OP_NAME = "ControlModifier"
    POWER_OP_NAME = "PowerModifier"

    dagger_op_def = MODIFIER_EXTENSION.get_op(DAGGER_OP_NAME)
    control_op_def = MODIFIER_EXTENSION.get_op(CONTROL_OP_NAME)
    power_op_def = MODIFIER_EXTENSION.get_op(POWER_OP_NAME)

    body_ty = modified_block.ty
    # TODO: Shouldn't this be `to_hugr_poly` since it can contain
    # a variable with a generic type?
    hugr_ty = body_ty.to_hugr(ctx)
    # Every modifier op is parametrised by the non-comptime inputs of the wrapped
    # function, split into borrowed (inout) inputs and all other inputs.
    in_out_ht = [
        fn_inp.ty.to_hugr(ctx)
        for fn_inp in body_ty.inputs
        if InputFlags.Inout in fn_inp.flags and InputFlags.Comptime not in fn_inp.flags
    ]
    other_in_ht = [
        fn_inp.ty.to_hugr(ctx)
        for fn_inp in body_ty.inputs
        if InputFlags.Inout not in fn_inp.flags
        and InputFlags.Comptime not in fn_inp.flags
    ]
    in_out_arg = ht.ListArg([t.type_arg() for t in in_out_ht])
    other_in_arg = ht.ListArg([t.type_arg() for t in other_in_ht])

    func_builder = dfg.builder.module_root_builder().define_function(
        str(modified_block), hugr_ty.input, hugr_ty.output
    )

    # Compile the body into the freshly defined function.
    cfg = compile_cfg(modified_block.cfg, func_builder, func_builder.inputs(), ctx)
    func_builder.set_outputs(*cfg)

    # Load a reference to the function; the modifiers below wrap this value.
    call = dfg.builder.load_function(func_builder, hugr_ty)

    # Captured variables become the trailing arguments of the final call.
    captured = [v for v, _ in modified_block.captured.values()]
    captured = non_copyable_front_others_back(captured)
    args = [dfg[v] for v in captured]

    # Apply modifiers
    if modified_block.has_dagger():
        dagger_ty = ht.FunctionType([hugr_ty], [hugr_ty])
        call = dfg.builder.add_op(
            ops.ExtOp(dagger_op_def, dagger_ty, [in_out_arg, other_in_arg]),
            call,
        )
    if modified_block.has_power():
        power_ty = ht.FunctionType([hugr_ty, int_type().to_hugr(ctx)], [hugr_ty])
        for power in modified_block.power:
            num = expr_compiler.compile(power.iter, dfg)
            call = dfg.builder.add_op(
                ops.ExtOp(power_op_def, power_ty, [in_out_arg, other_in_arg]),
                call,
                num,
            )

    # Only control modifiers that carry at least one control expression take part.
    # This mirrors `CheckedModifiedBlock.has_control` and guarantees that
    # `control.ctrl[0]` below exists.
    controls = [control for control in modified_block.control if control.ctrl]
    qubit_num_args: list[ht.TypeArg] = []
    for control in controls:
        assert control.qubit_num is not None
        if isinstance(control.qubit_num, int):
            qubit_num_args.append(ht.BoundedNatArg(control.qubit_num))
        else:
            qubit_num_args.append(control.qubit_num.to_arg().to_hugr(ctx))

    # Each `ControlModifier` prepends its control array to the inputs and outputs
    # of the function it wraps, so the last applied control ends up first. Apply
    # them in reverse so that the array of `controls[i]` is the i-th input/output.
    # The argument packing and output unpacking below rely on that order.
    for qubit_num in reversed(qubit_num_args):
        std_array = standard_array_type(ht.Qubit, qubit_num)
        controlled_ty = ht.FunctionType(
            [std_array, *hugr_ty.input], [std_array, *hugr_ty.output]
        )
        op = ops.ExtOp(
            control_op_def,
            ht.FunctionType([hugr_ty], [controlled_ty]),
            [qubit_num, in_out_arg, other_in_arg],
        )
        call = dfg.builder.add_op(op, call)
        # The control array is threaded through as an inout input.
        in_out_arg = ht.ListArg([std_array.type_arg(), *in_out_arg.elems])
        hugr_ty = controlled_ty

    # A control is given either as a single array or as individual qubits.
    ctrl_is_array = [is_array_type(get_type(control.ctrl[0])) for control in controls]

    # Prepare one standard qubit array per control modifier.
    ctrl_args: list[Wire] = []
    for control, qubit_num, is_array in zip(
        controls, qubit_num_args, ctrl_is_array, strict=True
    ):
        if is_array:
            control_array = expr_compiler.compile(control.ctrl[0], dfg)
        else:
            qubits = [expr_compiler.compile(c, dfg) for c in control.ctrl]
            control_array = dfg.builder.add_op(
                array_new(ht.Qubit, len(control.ctrl)), *qubits
            )
        ctrl_args.append(
            dfg.builder.add_op(array_to_std_array(ht.Qubit, qubit_num), control_array)
        )

    # Call
    call = dfg.builder.add_op(ops.CallIndirect(), call, *ctrl_args, *args)
    outports = iter(call)

    # The control arrays come first among the outputs. Convert them back and
    # rebind the places the controls were read from.
    for control, qubit_num, is_array in zip(
        controls, qubit_num_args, ctrl_is_array, strict=True
    ):
        control_array = dfg.builder.add_op(
            std_array_to_array(ht.Qubit, qubit_num), next(outports)
        )
        if is_array:
            c = control.ctrl[0]
            assert isinstance(c, PlaceNode)
            dfg[c.place] = control_array
        else:
            unpacked = unpack_array(dfg.builder, control_array)
            # Strict: every control qubit must get its wire back.
            for c, new_c in zip(control.ctrl, unpacked, strict=True):
                assert isinstance(c, PlaceNode)
                dfg[c.place] = new_c

    # Then the borrowed captured variables, in argument order.
    for arg in captured:
        if InputFlags.Inout in arg.flags:
            dfg[arg] = next(outports)

    return call
|
|
@@ -2,7 +2,6 @@ import ast
|
|
|
2
2
|
import functools
|
|
3
3
|
from collections.abc import Sequence
|
|
4
4
|
|
|
5
|
-
import hugr.tys as ht
|
|
6
5
|
from hugr import Wire, ops
|
|
7
6
|
from hugr.build.dfg import DfBase
|
|
8
7
|
|
|
@@ -18,6 +17,7 @@ from guppylang_internals.compiler.expr_compiler import ExprCompiler
|
|
|
18
17
|
from guppylang_internals.error import InternalGuppyError
|
|
19
18
|
from guppylang_internals.nodes import (
|
|
20
19
|
ArrayUnpack,
|
|
20
|
+
CheckedModifiedBlock,
|
|
21
21
|
CheckedNestedFunctionDef,
|
|
22
22
|
IterableUnpack,
|
|
23
23
|
PlaceNode,
|
|
@@ -120,8 +120,9 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
|
|
|
120
120
|
ports[len(left) : -len(right)] if right else ports[len(left) :]
|
|
121
121
|
)
|
|
122
122
|
elt = get_element_type(array_ty).to_hugr(self.ctx)
|
|
123
|
-
|
|
124
|
-
|
|
123
|
+
array = self.builder.add_op(
|
|
124
|
+
array_new(elt, len(starred_ports)), *starred_ports
|
|
125
|
+
)
|
|
125
126
|
self._assign(starred, array)
|
|
126
127
|
|
|
127
128
|
@_assign.register
|
|
@@ -130,7 +131,7 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
|
|
|
130
131
|
# Given an assignment pattern `left, *starred, right`, pop from the left and
|
|
131
132
|
# right, leaving us with the starred array in the middle
|
|
132
133
|
length = lhs.length
|
|
133
|
-
|
|
134
|
+
elt_ty = lhs.elt_type.to_hugr(self.ctx)
|
|
134
135
|
|
|
135
136
|
def pop(
|
|
136
137
|
array: Wire, length: int, pats: list[ast.expr], from_left: bool
|
|
@@ -141,10 +142,9 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
|
|
|
141
142
|
elts = []
|
|
142
143
|
for i in range(num_pats):
|
|
143
144
|
res = self.builder.add_op(
|
|
144
|
-
array_pop(
|
|
145
|
+
array_pop(elt_ty, length - i, from_left), array
|
|
145
146
|
)
|
|
146
|
-
[
|
|
147
|
-
[elt] = build_unwrap(self.builder, elt_opt, err)
|
|
147
|
+
[elt, array] = build_unwrap(self.builder, res, err)
|
|
148
148
|
elts.append(elt)
|
|
149
149
|
# Assign elements to the given patterns
|
|
150
150
|
for pat, elt in zip(
|
|
@@ -164,7 +164,7 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
|
|
|
164
164
|
self._assign(lhs.pattern.starred, array)
|
|
165
165
|
else:
|
|
166
166
|
assert length == 0
|
|
167
|
-
self.builder.add_op(array_discard_empty(
|
|
167
|
+
self.builder.add_op(array_discard_empty(elt_ty), array)
|
|
168
168
|
|
|
169
169
|
@_assign.register
|
|
170
170
|
def _assign_iterable(self, lhs: IterableUnpack, port: Wire) -> None:
|
|
@@ -221,3 +221,10 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
|
|
|
221
221
|
var = Variable(node.name, node.ty, node)
|
|
222
222
|
loaded_func = compile_local_func_def(node, self.dfg, self.ctx)
|
|
223
223
|
self.dfg[var] = loaded_func
|
|
224
|
+
|
|
225
|
+
def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None:
    """Compiles a `with` block carrying modifiers into the current dataflow graph."""
    # The modifier compiler module is imported lazily at call time rather than
    # at module level.
    import guppylang_internals.compiler.modifier_compiler as modifier_compiler

    modifier_compiler.compile_modified_block(
        node, self.dfg, self.ctx, self.expr_compiler
    )
|
|
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, ClassVar
|
|
|
7
7
|
from hugr import Wire, ops
|
|
8
8
|
from hugr import tys as ht
|
|
9
9
|
from hugr.build.dfg import DfBase
|
|
10
|
+
from hugr.std.collections.borrow_array import EXTENSION as BORROW_ARRAY_EXTENSION
|
|
10
11
|
|
|
11
12
|
from guppylang_internals.ast_util import (
|
|
12
13
|
AstNode,
|
|
@@ -23,6 +24,7 @@ from guppylang_internals.compiler.core import (
|
|
|
23
24
|
DFContainer,
|
|
24
25
|
GlobalConstId,
|
|
25
26
|
partially_monomorphize_args,
|
|
27
|
+
qualified_name,
|
|
26
28
|
)
|
|
27
29
|
from guppylang_internals.definition.common import ParsableDef
|
|
28
30
|
from guppylang_internals.definition.value import CallReturnWires, CompiledCallableDef
|
|
@@ -486,7 +488,39 @@ class NoopCompiler(CustomCallCompiler):
|
|
|
486
488
|
|
|
487
489
|
|
|
488
490
|
class CopyInoutCompiler(CustomInoutCallCompiler):
    """Call compiler for functions that borrow one argument to copy it."""

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        """Wires the single borrowed argument to both the regular and inout return.

        If the Hugr type of the argument is linear, the wire cannot be reused, so
        an explicit clone is emitted and its two outputs are used instead.
        """
        assert len(self.ty.input) == 1
        hugr_inp_ty = self.ty.input[0]
        if hugr_inp_ty.type_bound() != ht.TypeBound.Linear:
            # The same wire can feed both returns.
            return CallReturnWires(regular_returns=args, inout_returns=args)
        (borrowed,) = args
        clones = self._handle_affine_type(hugr_inp_ty, borrowed)
        return CallReturnWires(
            regular_returns=[clones[0]], inout_returns=[clones[1]]
        )

    # Affine types in Guppy backed by a linear Hugr type need to be copied explicitly.
    # TODO: Handle affine extension types more generally (borrow arrays are currently
    # the only case).
    def _handle_affine_type(self, ty: ht.Type, arg: Wire) -> list[Wire]:
        """Emits an explicit clone of `arg` and returns the resulting wires.

        Raises:
            InternalGuppyError: If no clone handler exists for `ty`.
        """
        if isinstance(ty, ht.ExtType):
            borrow_array_def = BORROW_ARRAY_EXTENSION.get_type("borrow_array")
            if qualified_name(ty.type_def) == qualified_name(borrow_array_def):
                type_args = ty.args
                assert len(type_args) == 2
                # Manually instantiate here to avoid circular import and use
                # type args directly.
                signature = ht.FunctionType(self.ty.input, self.ty.output)
                clone_op = BORROW_ARRAY_EXTENSION.get_op("clone").instantiate(
                    type_args, signature
                )
                return list(self.builder.add_op(clone_op, arg))
        raise InternalGuppyError(
            f"Type `{ty}` needs an explicit handler in the `copy` compiler as "
            "it is an affine Guppy type backed by a linear Hugr type."
        )
|
|
@@ -13,7 +13,7 @@ from guppylang_internals.checker.func_checker import check_signature
|
|
|
13
13
|
from guppylang_internals.compiler.core import (
|
|
14
14
|
CompilerContext,
|
|
15
15
|
DFContainer,
|
|
16
|
-
|
|
16
|
+
require_monomorphization,
|
|
17
17
|
)
|
|
18
18
|
from guppylang_internals.definition.common import CompilableDef, ParsableDef
|
|
19
19
|
from guppylang_internals.definition.function import (
|
|
@@ -72,9 +72,8 @@ class RawFunctionDecl(ParsableDef):
|
|
|
72
72
|
if not has_empty_body(func_ast):
|
|
73
73
|
raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
|
|
74
74
|
# Make sure we won't need monomorphization to compile this declaration
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
raise GuppyError(MonomorphizeError(func_ast, self.name, param))
|
|
75
|
+
if mono_params := require_monomorphization(ty.params):
|
|
76
|
+
raise GuppyError(MonomorphizeError(func_ast, self.name, mono_params.pop()))
|
|
78
77
|
return CheckedFunctionDecl(
|
|
79
78
|
self.id,
|
|
80
79
|
self.name,
|
|
@@ -38,14 +38,19 @@ class ParamDef(Definition):
|
|
|
38
38
|
class TypeVarDef(ParamDef, CompiledDef):
    """A type variable definition."""

    # Forwarded to `TypeParam.must_be_copyable` by `to_param`.
    copyable: bool
    # Forwarded to `TypeParam.must_be_droppable` by `to_param`.
    droppable: bool

    description: str = field(default="type variable", init=False)

    def to_param(self, idx: int) -> TypeParam:
        """Builds the type parameter at position `idx` for this type variable."""
        return TypeParam(
            idx,
            self.name,
            must_be_droppable=self.droppable,
            must_be_copyable=self.copyable,
        )
|
|
49
54
|
|
|
50
55
|
|
|
51
56
|
@dataclass(frozen=True)
|
|
@@ -46,7 +46,6 @@ from guppylang_internals.std._internal.compiler.array import (
|
|
|
46
46
|
array_new,
|
|
47
47
|
array_unpack,
|
|
48
48
|
)
|
|
49
|
-
from guppylang_internals.std._internal.compiler.prelude import build_unwrap
|
|
50
49
|
from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, make_opaque
|
|
51
50
|
from guppylang_internals.tys.builtin import array_type, bool_type, float_type
|
|
52
51
|
from guppylang_internals.tys.subst import Inst, Subst
|
|
@@ -195,17 +194,9 @@ class ParsedPytketDef(CallableDef, CompilableDef):
|
|
|
195
194
|
# them into separate wires.
|
|
196
195
|
for i, q_reg in enumerate(self.input_circuit.q_registers):
|
|
197
196
|
reg_wire = outer_func.inputs()[i]
|
|
198
|
-
|
|
199
|
-
array_unpack(ht.
|
|
197
|
+
elem_wires = outer_func.add_op(
|
|
198
|
+
array_unpack(ht.Qubit, q_reg.size), reg_wire
|
|
200
199
|
)
|
|
201
|
-
elem_wires = [
|
|
202
|
-
build_unwrap(
|
|
203
|
-
outer_func,
|
|
204
|
-
opt_elem,
|
|
205
|
-
"Internal error: unwrapping of array element failed",
|
|
206
|
-
)
|
|
207
|
-
for opt_elem in opt_elem_wires
|
|
208
|
-
]
|
|
209
200
|
input_list.extend(elem_wires)
|
|
210
201
|
|
|
211
202
|
else:
|
|
@@ -219,7 +210,8 @@ class ParsedPytketDef(CallableDef, CompilableDef):
|
|
|
219
210
|
]
|
|
220
211
|
|
|
221
212
|
# Symbolic parameters (if present) get passed after qubits and bools.
|
|
222
|
-
|
|
213
|
+
num_params = len(self.input_circuit.free_symbols())
|
|
214
|
+
has_params = num_params != 0
|
|
223
215
|
if has_params and "TKET1.input_parameters" not in hugr_func.metadata:
|
|
224
216
|
raise InternalGuppyError(
|
|
225
217
|
"Parameter metadata is missing from pytket circuit HUGR"
|
|
@@ -230,26 +222,17 @@ class ParsedPytketDef(CallableDef, CompilableDef):
|
|
|
230
222
|
if has_params:
|
|
231
223
|
lex_params: list[Wire] = list(outer_func.inputs()[offset:])
|
|
232
224
|
if self.use_arrays:
|
|
233
|
-
|
|
225
|
+
unpack_result = outer_func.add_op(
|
|
234
226
|
array_unpack(
|
|
235
|
-
ht.
|
|
236
|
-
q_reg.size,
|
|
227
|
+
ht.Tuple(float_type().to_hugr(ctx)), num_params
|
|
237
228
|
),
|
|
238
229
|
lex_params[0],
|
|
239
230
|
)
|
|
240
|
-
lex_params =
|
|
241
|
-
build_unwrap(
|
|
242
|
-
outer_func,
|
|
243
|
-
opt_param,
|
|
244
|
-
"Internal error: unwrapping of array element failed",
|
|
245
|
-
)
|
|
246
|
-
for opt_param in opt_param_wires
|
|
247
|
-
]
|
|
231
|
+
lex_params = list(unpack_result)
|
|
248
232
|
param_order = cast(
|
|
249
233
|
list[str], hugr_func.metadata["TKET1.input_parameters"]
|
|
250
234
|
)
|
|
251
235
|
lex_names = sorted(param_order)
|
|
252
|
-
assert len(lex_names) == len(lex_params)
|
|
253
236
|
name_to_param = dict(zip(lex_names, lex_params, strict=True))
|
|
254
237
|
angle_wires = [name_to_param[name] for name in param_order]
|
|
255
238
|
# Need to convert all angles to floats.
|
|
@@ -280,34 +263,23 @@ class ParsedPytketDef(CallableDef, CompilableDef):
|
|
|
280
263
|
]
|
|
281
264
|
|
|
282
265
|
if self.use_arrays:
|
|
283
|
-
|
|
284
|
-
def pack(elems: list[Wire], elem_ty: ht.Type, length: int) -> Wire:
|
|
285
|
-
elem_opts = [
|
|
286
|
-
outer_func.add_op(ops.Some(elem_ty), elem) for elem in elems
|
|
287
|
-
]
|
|
288
|
-
return outer_func.add_op(
|
|
289
|
-
array_new(ht.Option(elem_ty), length), *elem_opts
|
|
290
|
-
)
|
|
291
|
-
|
|
292
266
|
array_wires: list[Wire] = []
|
|
293
267
|
wire_idx = 0
|
|
294
268
|
# First pack bool results into an array.
|
|
295
269
|
for c_reg in self.input_circuit.c_registers:
|
|
296
270
|
array_wires.append(
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
c_reg.size,
|
|
271
|
+
outer_func.add_op(
|
|
272
|
+
array_new(OpaqueBool, c_reg.size),
|
|
273
|
+
*wires[wire_idx : wire_idx + c_reg.size],
|
|
301
274
|
)
|
|
302
275
|
)
|
|
303
276
|
wire_idx = wire_idx + c_reg.size
|
|
304
277
|
# Then the borrowed qubits also need to be put back into arrays.
|
|
305
278
|
for q_reg in self.input_circuit.q_registers:
|
|
306
279
|
array_wires.append(
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
q_reg.size,
|
|
280
|
+
outer_func.add_op(
|
|
281
|
+
array_new(ht.Qubit, q_reg.size),
|
|
282
|
+
*wires[wire_idx : wire_idx + q_reg.size],
|
|
311
283
|
)
|
|
312
284
|
)
|
|
313
285
|
wire_idx = wire_idx + q_reg.size
|
|
@@ -131,10 +131,13 @@ class RawStructDef(TypeDef, ParsableDef):
|
|
|
131
131
|
if cls_def.type_params:
|
|
132
132
|
first, last = cls_def.type_params[0], cls_def.type_params[-1]
|
|
133
133
|
params_span = Span(to_span(first).start, to_span(last).end)
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
134
|
+
param_vars_mapping: dict[str, Parameter] = {}
|
|
135
|
+
for idx, param_node in enumerate(cls_def.type_params):
|
|
136
|
+
param = parse_parameter(
|
|
137
|
+
param_node, idx, globals, param_vars_mapping
|
|
138
|
+
)
|
|
139
|
+
param_vars_mapping[param.name] = param
|
|
140
|
+
params.append(param)
|
|
138
141
|
|
|
139
142
|
# The only base we allow is `Generic[...]` to specify generic parameters with
|
|
140
143
|
# the legacy syntax
|
|
@@ -2,7 +2,7 @@ from abc import abstractmethod
|
|
|
2
2
|
from collections.abc import Callable, Sequence
|
|
3
3
|
from dataclasses import dataclass, field
|
|
4
4
|
|
|
5
|
-
from hugr import tys
|
|
5
|
+
from hugr import tys as ht
|
|
6
6
|
|
|
7
7
|
from guppylang_internals.ast_util import AstNode
|
|
8
8
|
from guppylang_internals.definition.common import CompiledDef, Definition
|
|
@@ -42,8 +42,8 @@ class OpaqueTypeDef(TypeDef, CompiledDef):
|
|
|
42
42
|
params: Sequence[Parameter]
|
|
43
43
|
never_copyable: bool
|
|
44
44
|
never_droppable: bool
|
|
45
|
-
to_hugr: Callable[[Sequence[Argument], ToHugrContext],
|
|
46
|
-
bound:
|
|
45
|
+
to_hugr: Callable[[Sequence[Argument], ToHugrContext], ht.Type]
|
|
46
|
+
bound: ht.TypeBound | None = None
|
|
47
47
|
|
|
48
48
|
def check_instantiate(
|
|
49
49
|
self, args: Sequence[Argument], loc: AstNode | None = None
|
|
@@ -90,3 +90,8 @@ def check_lists_enabled(loc: AstNode | None = None) -> None:
|
|
|
90
90
|
def check_capturing_closures_enabled(loc: AstNode | None = None) -> None:
    """Rejects capturing closures unless experimental features are enabled.

    Raises:
        GuppyError: If `EXPERIMENTAL_FEATURES_ENABLED` is falsy.
    """
    if EXPERIMENTAL_FEATURES_ENABLED:
        return
    raise GuppyError(UnsupportedError(loc, "Capturing closures"))
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def check_modifiers_enabled(loc: AstNode | None = None) -> None:
    """Rejects modifiers unless experimental features are enabled.

    Raises:
        GuppyError: If `EXPERIMENTAL_FEATURES_ENABLED` is falsy.
    """
    if EXPERIMENTAL_FEATURES_ENABLED:
        return
    raise GuppyError(ExperimentalFeatureError(loc, "Modifiers"))
|
guppylang_internals/nodes.py
CHANGED
|
@@ -6,6 +6,7 @@ from enum import Enum
|
|
|
6
6
|
from typing import TYPE_CHECKING, Any
|
|
7
7
|
|
|
8
8
|
from guppylang_internals.ast_util import AstNode
|
|
9
|
+
from guppylang_internals.span import Span, to_span
|
|
9
10
|
from guppylang_internals.tys.const import Const
|
|
10
11
|
from guppylang_internals.tys.subst import Inst
|
|
11
12
|
from guppylang_internals.tys.ty import FunctionType, StructType, TupleType, Type
|
|
@@ -422,3 +423,126 @@ class CheckedNestedFunctionDef(ast.FunctionDef):
|
|
|
422
423
|
self.cfg = cfg
|
|
423
424
|
self.ty = ty
|
|
424
425
|
self.captured = captured
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
class Dagger(ast.expr):
    """The dagger modifier.

    Initialised from the attribute dictionary (including the source location) of
    the expression it replaces.
    """

    def __init__(self, node: ast.expr) -> None:
        super().__init__(**vars(node))
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
class Control(ast.Call):
    """The control modifier.

    Initialised from the attributes of the original call expression. `qubit_num`
    starts out as `None`; the modifier compiler asserts it has been set before
    code generation.
    """

    # Expressions supplying the control qubits.
    ctrl: list[ast.expr]
    # Number of control qubits, either a literal count or a `Const`.
    qubit_num: int | Const | None

    _fields = ("ctrl",)

    def __init__(self, node: ast.Call, ctrl: list[ast.expr]) -> None:
        super().__init__(**vars(node))
        self.qubit_num = None
        self.ctrl = ctrl
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
class Power(ast.expr):
    """The power modifier.

    `iter` is the expression giving the power; the modifier compiler passes it to
    the power op as an integer input.
    """

    iter: ast.expr

    _fields = ("iter",)

    def __init__(self, node: ast.expr, iter: ast.expr) -> None:
        super().__init__(**vars(node))
        self.iter = iter
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
# Any modifier node kind accepted by `ModifiedBlock.push_modifier`.
Modifier = Dagger | Control | Power
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
class ModifiedBlock(ast.With):
|
|
465
|
+
cfg: "CFG"
|
|
466
|
+
dagger: list[Dagger]
|
|
467
|
+
control: list[Control]
|
|
468
|
+
power: list[Power]
|
|
469
|
+
|
|
470
|
+
def __init__(self, cfg: "CFG", *args: Any, **kwargs: Any) -> None:
|
|
471
|
+
super().__init__(*args, **kwargs)
|
|
472
|
+
self.cfg = cfg
|
|
473
|
+
self.dagger = []
|
|
474
|
+
self.control = []
|
|
475
|
+
self.power = []
|
|
476
|
+
|
|
477
|
+
def is_dagger(self) -> bool:
|
|
478
|
+
return len(self.dagger) % 2 == 1
|
|
479
|
+
|
|
480
|
+
def is_control(self) -> bool:
|
|
481
|
+
return len(self.control) > 0
|
|
482
|
+
|
|
483
|
+
def is_power(self) -> bool:
|
|
484
|
+
return len(self.power) > 0
|
|
485
|
+
|
|
486
|
+
def span_ctxt_manager(self) -> Span:
|
|
487
|
+
return Span(
|
|
488
|
+
to_span(self.items[0].context_expr).start,
|
|
489
|
+
to_span(self.items[-1].context_expr).end,
|
|
490
|
+
)
|
|
491
|
+
|
|
492
|
+
def push_modifier(self, modifier: Modifier) -> None:
|
|
493
|
+
"""Pushes a modifier kind onto the modifier."""
|
|
494
|
+
if isinstance(modifier, Dagger):
|
|
495
|
+
self.dagger.append(modifier)
|
|
496
|
+
elif isinstance(modifier, Control):
|
|
497
|
+
self.control.append(modifier)
|
|
498
|
+
elif isinstance(modifier, Power):
|
|
499
|
+
self.power.append(modifier)
|
|
500
|
+
else:
|
|
501
|
+
raise TypeError(f"Unknown modifier: {modifier}")
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
class CheckedModifiedBlock(ast.With):
    """A modifier `with` block after checking.

    Holds the checked CFG of the body, the body's function type, the variables it
    captures and the dagger/control/power modifiers attached to the block.
    """

    # Identifier from which `__str__` derives the name of the compiled function.
    def_id: "DefId"
    # Checked control-flow graph of the block body.
    cfg: "CheckedCFG[Place]"
    # Modifiers attached to the block, one list per kind.
    dagger: list[Dagger]
    control: list[Control]
    power: list[Power]

    #: The type of the body of With block.
    ty: FunctionType
    #: Mapping from names to variables captured in the body.
    captured: Mapping[str, tuple["Variable", AstNode]]

    def __init__(
        self,
        def_id: "DefId",
        cfg: "CheckedCFG[Place]",
        ty: FunctionType,
        captured: Mapping[str, tuple["Variable", AstNode]],
        dagger: list[Dagger],
        control: list[Control],
        power: list[Power],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.def_id, self.cfg, self.ty = def_id, cfg, ty
        self.captured = captured
        self.dagger, self.control, self.power = dagger, control, power

    def __str__(self) -> str:
        # Also used by `compile_modified_block` as the name of the function the
        # block body is compiled into.
        return f"__WithBlock__({self.def_id})"

    def has_dagger(self) -> bool:
        """Whether an odd number of dagger modifiers is attached."""
        return len(self.dagger) % 2 != 0

    def has_control(self) -> bool:
        """Whether at least one control modifier has control expressions."""
        return any(c.ctrl for c in self.control)

    def has_power(self) -> bool:
        """Whether any power modifier is attached."""
        return bool(self.power)
|