guppylang-internals 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +21 -0
- guppylang_internals/cfg/bb.py +20 -0
- guppylang_internals/cfg/builder.py +118 -5
- guppylang_internals/cfg/cfg.py +3 -0
- guppylang_internals/checker/cfg_checker.py +6 -0
- guppylang_internals/checker/core.py +5 -2
- guppylang_internals/checker/errors/generic.py +32 -1
- guppylang_internals/checker/errors/type_errors.py +14 -0
- guppylang_internals/checker/errors/wasm.py +7 -4
- guppylang_internals/checker/expr_checker.py +58 -17
- guppylang_internals/checker/func_checker.py +18 -14
- guppylang_internals/checker/linearity_checker.py +67 -10
- guppylang_internals/checker/modifier_checker.py +120 -0
- guppylang_internals/checker/stmt_checker.py +48 -1
- guppylang_internals/checker/unitary_checker.py +132 -0
- guppylang_internals/compiler/cfg_compiler.py +7 -6
- guppylang_internals/compiler/core.py +93 -56
- guppylang_internals/compiler/expr_compiler.py +72 -168
- guppylang_internals/compiler/modifier_compiler.py +176 -0
- guppylang_internals/compiler/stmt_compiler.py +15 -8
- guppylang_internals/decorator.py +86 -7
- guppylang_internals/definition/custom.py +39 -1
- guppylang_internals/definition/declaration.py +9 -6
- guppylang_internals/definition/function.py +12 -2
- guppylang_internals/definition/parameter.py +8 -3
- guppylang_internals/definition/pytket_circuits.py +14 -41
- guppylang_internals/definition/struct.py +13 -7
- guppylang_internals/definition/ty.py +3 -3
- guppylang_internals/definition/wasm.py +42 -10
- guppylang_internals/engine.py +9 -3
- guppylang_internals/experimental.py +5 -0
- guppylang_internals/nodes.py +147 -24
- guppylang_internals/std/_internal/checker.py +13 -108
- guppylang_internals/std/_internal/compiler/array.py +95 -283
- guppylang_internals/std/_internal/compiler/list.py +1 -1
- guppylang_internals/std/_internal/compiler/platform.py +153 -0
- guppylang_internals/std/_internal/compiler/prelude.py +12 -4
- guppylang_internals/std/_internal/compiler/tket_exts.py +8 -2
- guppylang_internals/std/_internal/debug.py +18 -9
- guppylang_internals/std/_internal/util.py +1 -1
- guppylang_internals/tracing/object.py +10 -0
- guppylang_internals/tracing/unpacking.py +19 -20
- guppylang_internals/tys/arg.py +18 -3
- guppylang_internals/tys/builtin.py +2 -5
- guppylang_internals/tys/const.py +33 -4
- guppylang_internals/tys/errors.py +23 -1
- guppylang_internals/tys/param.py +31 -16
- guppylang_internals/tys/parsing.py +11 -24
- guppylang_internals/tys/printing.py +2 -8
- guppylang_internals/tys/qubit.py +62 -0
- guppylang_internals/tys/subst.py +8 -26
- guppylang_internals/tys/ty.py +91 -85
- guppylang_internals/wasm_util.py +129 -0
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +6 -5
- guppylang_internals-0.26.0.dist-info/RECORD +104 -0
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
- guppylang_internals-0.24.0.dist-info/RECORD +0 -98
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""Hugr generation for modifiers."""
|
|
2
|
+
|
|
3
|
+
from hugr import Wire, ops
|
|
4
|
+
from hugr import tys as ht
|
|
5
|
+
|
|
6
|
+
from guppylang_internals.ast_util import get_type
|
|
7
|
+
from guppylang_internals.checker.modifier_checker import non_copyable_front_others_back
|
|
8
|
+
from guppylang_internals.compiler.cfg_compiler import compile_cfg
|
|
9
|
+
from guppylang_internals.compiler.core import CompilerContext, DFContainer
|
|
10
|
+
from guppylang_internals.compiler.expr_compiler import ExprCompiler
|
|
11
|
+
from guppylang_internals.definition.function import add_unitarity_metadata
|
|
12
|
+
from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode
|
|
13
|
+
from guppylang_internals.std._internal.compiler.array import (
|
|
14
|
+
array_new,
|
|
15
|
+
array_to_std_array,
|
|
16
|
+
standard_array_type,
|
|
17
|
+
std_array_to_array,
|
|
18
|
+
unpack_array,
|
|
19
|
+
)
|
|
20
|
+
from guppylang_internals.std._internal.compiler.tket_exts import MODIFIER_EXTENSION
|
|
21
|
+
from guppylang_internals.tys.builtin import int_type, is_array_type
|
|
22
|
+
from guppylang_internals.tys.ty import InputFlags
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def compile_modified_block(
    modified_block: CheckedModifiedBlock,
    dfg: DFContainer,
    ctx: CompilerContext,
    expr_compiler: ExprCompiler,
) -> Wire:
    """Compiles a modified block (dagger / power / control) into Hugr.

    The block body is compiled into a standalone function definition at the module
    root. That function is loaded as a value and wrapped by the modifier extension
    ops in this order: dagger (if present), each power, then each control.

    The wrapped function is invoked via `CallIndirect` with the control qubit arrays
    followed by the captured variables. Afterwards, the returned control qubits and
    the captured variables flagged as `Inout` are written back into `dfg`.

    Returns the result of the `CallIndirect` `add_op` call.
    """
    DAGGER_OP_NAME = "DaggerModifier"
    CONTROL_OP_NAME = "ControlModifier"
    POWER_OP_NAME = "PowerModifier"

    dagger_op_def = MODIFIER_EXTENSION.get_op(DAGGER_OP_NAME)
    control_op_def = MODIFIER_EXTENSION.get_op(CONTROL_OP_NAME)
    power_op_def = MODIFIER_EXTENSION.get_op(POWER_OP_NAME)

    body_ty = modified_block.ty
    # TODO: Shouldn't this be `to_hugr_poly` since it can contain
    # a variable with a generic type?
    hugr_ty = body_ty.to_hugr(ctx)
    # The modifier ops are parametrised over the Hugr types of the non-comptime
    # inputs, split into inout (borrowed) inputs and all other inputs.
    in_out_ht = [
        fn_inp.ty.to_hugr(ctx)
        for fn_inp in body_ty.inputs
        if InputFlags.Inout in fn_inp.flags and InputFlags.Comptime not in fn_inp.flags
    ]
    other_in_ht = [
        fn_inp.ty.to_hugr(ctx)
        for fn_inp in body_ty.inputs
        if InputFlags.Inout not in fn_inp.flags
        and InputFlags.Comptime not in fn_inp.flags
    ]
    in_out_arg = ht.ListArg([t.type_arg() for t in in_out_ht])
    other_in_arg = ht.ListArg([t.type_arg() for t in other_in_ht])

    # Define the body as a function at the module root
    func_builder = dfg.builder.module_root_builder().define_function(
        str(modified_block), hugr_ty.input, hugr_ty.output
    )
    add_unitarity_metadata(func_builder, modified_block.ty.unitary_flags)

    # compile body
    cfg = compile_cfg(modified_block.cfg, func_builder, func_builder.inputs(), ctx)
    func_builder.set_outputs(*cfg)

    # LoadFunc
    call = dfg.builder.load_function(func_builder, hugr_ty)

    # Function inputs: captured variables, non-copyable ones first
    captured = [v for v, _ in modified_block.captured.values()]
    captured = non_copyable_front_others_back(captured)
    args = [dfg[v] for v in captured]

    # Apply modifiers
    if modified_block.has_dagger():
        dagger_ty = ht.FunctionType([hugr_ty], [hugr_ty])
        call = dfg.builder.add_op(
            ops.ExtOp(
                dagger_op_def,
                dagger_ty,
                [in_out_arg, other_in_arg],
            ),
            call,
        )
    if modified_block.has_power():
        power_ty = ht.FunctionType([hugr_ty, int_type().to_hugr(ctx)], [hugr_ty])
        for power in modified_block.power:
            num = expr_compiler.compile(power.iter, dfg)
            call = dfg.builder.add_op(
                ops.ExtOp(
                    power_op_def,
                    power_ty,
                    [in_out_arg, other_in_arg],
                ),
                call,
                num,
            )
    # Size type argument of each control array, indexed like `modified_block.control`
    qubit_num_args: list[ht.TypeArg] = []
    if modified_block.has_control():
        for control in modified_block.control:
            assert control.qubit_num is not None
            qubit_num: ht.TypeArg
            if isinstance(control.qubit_num, int):
                qubit_num = ht.BoundedNatArg(control.qubit_num)
            else:
                qubit_num = control.qubit_num.to_arg().to_hugr(ctx)
            qubit_num_args.append(qubit_num)
            std_array = standard_array_type(ht.Qubit, qubit_num)

            # control operator: prepends the control array to the signature
            input_fn_ty = hugr_ty
            output_fn_ty = ht.FunctionType(
                [std_array, *hugr_ty.input], [std_array, *hugr_ty.output]
            )
            op = ops.ExtOp(
                control_op_def,
                ht.FunctionType([input_fn_ty], [output_fn_ty]),
                [qubit_num, in_out_arg, other_in_arg],
            )
            call = dfg.builder.add_op(op, call)
            # update types
            in_out_arg = ht.ListArg([std_array.type_arg(), *in_out_arg.elems])
            hugr_ty = output_fn_ty

    # Prepare control arguments (in the order of `modified_block.control`)
    ctrl_args: list[Wire] = []
    for i, control in enumerate(modified_block.control):
        if is_array_type(get_type(control.ctrl[0])):
            control_array = expr_compiler.compile(control.ctrl[0], dfg)
            control_array = dfg.builder.add_op(
                array_to_std_array(ht.Qubit, qubit_num_args[i]), control_array
            )
            ctrl_args.append(control_array)
        else:
            cs = [expr_compiler.compile(c, dfg) for c in control.ctrl]
            control_array = dfg.builder.add_op(
                array_new(ht.Qubit, len(control.ctrl)), *cs
            )
            control_array = dfg.builder.add_op(
                array_to_std_array(ht.Qubit, qubit_num_args[i]), *control_array
            )
            ctrl_args.append(control_array)

    # Call. Every `ControlModifier` above *prepended* its control array, so after
    # applying controls c_0, ..., c_{k-1} the signature starts with
    # [c_{k-1}, ..., c_0]. Pass the control arrays in that (reversed) order so each
    # one lands in the slot typed for its own size.
    call = dfg.builder.add_op(
        ops.CallIndirect(),
        call,
        *reversed(ctrl_args),
        *args,
    )
    outports = iter(call)

    # Unpack controls, reading them back in the same reversed order.
    for i, control in reversed(list(enumerate(modified_block.control))):
        outport = next(outports)
        if is_array_type(get_type(control.ctrl[0])):
            control_array = dfg.builder.add_op(
                std_array_to_array(ht.Qubit, qubit_num_args[i]), outport
            )
            c = control.ctrl[0]
            assert isinstance(c, PlaceNode)
            dfg[c.place] = control_array
        else:
            control_array = dfg.builder.add_op(
                std_array_to_array(ht.Qubit, qubit_num_args[i]), outport
            )
            unpacked = unpack_array(dfg.builder, control_array)
            for c, new_c in zip(control.ctrl, unpacked, strict=False):
                assert isinstance(c, PlaceNode)
                dfg[c.place] = new_c

    # Remaining outputs are the borrowed (inout) captured variables
    for arg in captured:
        if InputFlags.Inout in arg.flags:
            dfg[arg] = next(outports)

    return call
|
|
@@ -2,7 +2,6 @@ import ast
|
|
|
2
2
|
import functools
|
|
3
3
|
from collections.abc import Sequence
|
|
4
4
|
|
|
5
|
-
import hugr.tys as ht
|
|
6
5
|
from hugr import Wire, ops
|
|
7
6
|
from hugr.build.dfg import DfBase
|
|
8
7
|
|
|
@@ -18,6 +17,7 @@ from guppylang_internals.compiler.expr_compiler import ExprCompiler
|
|
|
18
17
|
from guppylang_internals.error import InternalGuppyError
|
|
19
18
|
from guppylang_internals.nodes import (
|
|
20
19
|
ArrayUnpack,
|
|
20
|
+
CheckedModifiedBlock,
|
|
21
21
|
CheckedNestedFunctionDef,
|
|
22
22
|
IterableUnpack,
|
|
23
23
|
PlaceNode,
|
|
@@ -120,8 +120,9 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
|
|
|
120
120
|
ports[len(left) : -len(right)] if right else ports[len(left) :]
|
|
121
121
|
)
|
|
122
122
|
elt = get_element_type(array_ty).to_hugr(self.ctx)
|
|
123
|
-
|
|
124
|
-
|
|
123
|
+
array = self.builder.add_op(
|
|
124
|
+
array_new(elt, len(starred_ports)), *starred_ports
|
|
125
|
+
)
|
|
125
126
|
self._assign(starred, array)
|
|
126
127
|
|
|
127
128
|
@_assign.register
|
|
@@ -130,7 +131,7 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
|
|
|
130
131
|
# Given an assignment pattern `left, *starred, right`, pop from the left and
|
|
131
132
|
# right, leaving us with the starred array in the middle
|
|
132
133
|
length = lhs.length
|
|
133
|
-
|
|
134
|
+
elt_ty = lhs.elt_type.to_hugr(self.ctx)
|
|
134
135
|
|
|
135
136
|
def pop(
|
|
136
137
|
array: Wire, length: int, pats: list[ast.expr], from_left: bool
|
|
@@ -141,10 +142,9 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
|
|
|
141
142
|
elts = []
|
|
142
143
|
for i in range(num_pats):
|
|
143
144
|
res = self.builder.add_op(
|
|
144
|
-
array_pop(
|
|
145
|
+
array_pop(elt_ty, length - i, from_left), array
|
|
145
146
|
)
|
|
146
|
-
[
|
|
147
|
-
[elt] = build_unwrap(self.builder, elt_opt, err)
|
|
147
|
+
[elt, array] = build_unwrap(self.builder, res, err)
|
|
148
148
|
elts.append(elt)
|
|
149
149
|
# Assign elements to the given patterns
|
|
150
150
|
for pat, elt in zip(
|
|
@@ -164,7 +164,7 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
|
|
|
164
164
|
self._assign(lhs.pattern.starred, array)
|
|
165
165
|
else:
|
|
166
166
|
assert length == 0
|
|
167
|
-
self.builder.add_op(array_discard_empty(
|
|
167
|
+
self.builder.add_op(array_discard_empty(elt_ty), array)
|
|
168
168
|
|
|
169
169
|
@_assign.register
|
|
170
170
|
def _assign_iterable(self, lhs: IterableUnpack, port: Wire) -> None:
|
|
@@ -221,3 +221,10 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
|
|
|
221
221
|
var = Variable(node.name, node.ty, node)
|
|
222
222
|
loaded_func = compile_local_func_def(node, self.dfg, self.ctx)
|
|
223
223
|
self.dfg[var] = loaded_func
|
|
224
|
+
|
|
225
|
+
def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None:
    """Compiles a checked modifier block by delegating to `compile_modified_block`."""
    # Function-scope import, presumably to avoid an import cycle between the
    # statement compiler and the modifier compiler (TODO confirm).
    from guppylang_internals.compiler import modifier_compiler

    modifier_compiler.compile_modified_block(
        node, self.dfg, self.ctx, self.expr_compiler
    )
|
guppylang_internals/decorator.py
CHANGED
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
+
import pathlib
|
|
4
5
|
from typing import TYPE_CHECKING, ParamSpec, TypeVar, overload
|
|
5
6
|
|
|
6
7
|
from hugr import ops
|
|
7
8
|
from hugr import tys as ht
|
|
8
9
|
|
|
10
|
+
from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
|
|
9
11
|
from guppylang_internals.compiler.core import (
|
|
10
12
|
CompilerContext,
|
|
11
13
|
GlobalConstId,
|
|
@@ -24,6 +26,7 @@ from guppylang_internals.definition.ty import OpaqueTypeDef, TypeDef
|
|
|
24
26
|
from guppylang_internals.definition.wasm import RawWasmFunctionDef
|
|
25
27
|
from guppylang_internals.dummy_decorator import _dummy_custom_decorator, sphinx_running
|
|
26
28
|
from guppylang_internals.engine import DEF_STORE
|
|
29
|
+
from guppylang_internals.error import GuppyError
|
|
27
30
|
from guppylang_internals.std._internal.checker import WasmCallChecker
|
|
28
31
|
from guppylang_internals.std._internal.compiler.wasm import (
|
|
29
32
|
WasmModuleCallCompiler,
|
|
@@ -39,6 +42,14 @@ from guppylang_internals.tys.ty import (
|
|
|
39
42
|
InputFlags,
|
|
40
43
|
NoneType,
|
|
41
44
|
NumericType,
|
|
45
|
+
UnitaryFlags,
|
|
46
|
+
)
|
|
47
|
+
from guppylang_internals.wasm_util import (
|
|
48
|
+
ConcreteWasmModule,
|
|
49
|
+
WasmFileNotFound,
|
|
50
|
+
WasmFunctionNotInFile,
|
|
51
|
+
WasmSignatureError,
|
|
52
|
+
decode_wasm_functions,
|
|
42
53
|
)
|
|
43
54
|
|
|
44
55
|
if TYPE_CHECKING:
|
|
@@ -47,7 +58,6 @@ if TYPE_CHECKING:
|
|
|
47
58
|
from collections.abc import Callable, Sequence
|
|
48
59
|
from types import FrameType
|
|
49
60
|
|
|
50
|
-
from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
|
|
51
61
|
from guppylang_internals.tys.arg import Argument
|
|
52
62
|
from guppylang_internals.tys.param import Parameter
|
|
53
63
|
from guppylang_internals.tys.subst import Inst
|
|
@@ -75,6 +85,7 @@ def custom_function(
|
|
|
75
85
|
higher_order_value: bool = True,
|
|
76
86
|
name: str = "",
|
|
77
87
|
signature: FunctionType | None = None,
|
|
88
|
+
unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
|
|
78
89
|
) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
|
|
79
90
|
"""Decorator to add custom typing or compilation behaviour to function decls.
|
|
80
91
|
|
|
@@ -86,6 +97,8 @@ def custom_function(
|
|
|
86
97
|
|
|
87
98
|
def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
|
|
88
99
|
call_checker = checker or DefaultCallChecker()
|
|
100
|
+
if signature is not None:
|
|
101
|
+
object.__setattr__(signature, "unitary_flags", unitary_flags)
|
|
89
102
|
func = RawCustomFunctionDef(
|
|
90
103
|
DefId.fresh(),
|
|
91
104
|
name or f.__name__,
|
|
@@ -95,6 +108,7 @@ def custom_function(
|
|
|
95
108
|
compiler or NotImplementedCallCompiler(),
|
|
96
109
|
higher_order_value,
|
|
97
110
|
signature,
|
|
111
|
+
unitary_flags,
|
|
98
112
|
)
|
|
99
113
|
DEF_STORE.register_def(func, get_calling_frame())
|
|
100
114
|
return GuppyFunctionDefinition(func)
|
|
@@ -108,6 +122,7 @@ def hugr_op(
|
|
|
108
122
|
higher_order_value: bool = True,
|
|
109
123
|
name: str = "",
|
|
110
124
|
signature: FunctionType | None = None,
|
|
125
|
+
unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
|
|
111
126
|
) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
|
|
112
127
|
"""Decorator to annotate function declarations as HUGR ops.
|
|
113
128
|
|
|
@@ -119,7 +134,14 @@ def hugr_op(
|
|
|
119
134
|
value.
|
|
120
135
|
name: The name of the function.
|
|
121
136
|
"""
|
|
122
|
-
return custom_function(
|
|
137
|
+
return custom_function(
|
|
138
|
+
OpCompiler(op),
|
|
139
|
+
checker,
|
|
140
|
+
higher_order_value,
|
|
141
|
+
name,
|
|
142
|
+
signature,
|
|
143
|
+
unitary_flags=unitary_flags,
|
|
144
|
+
)
|
|
123
145
|
|
|
124
146
|
|
|
125
147
|
def extend_type(defn: TypeDef, return_class: bool = False) -> Callable[[type], type]:
|
|
@@ -188,6 +210,12 @@ def custom_type(
|
|
|
188
210
|
def wasm_module(
|
|
189
211
|
filename: str,
|
|
190
212
|
) -> Callable[[builtins.type[T]], GuppyDefinition]:
|
|
213
|
+
wasm_file = pathlib.Path(filename)
|
|
214
|
+
if wasm_file.is_file():
|
|
215
|
+
wasm_sigs = decode_wasm_functions(filename)
|
|
216
|
+
else:
|
|
217
|
+
raise GuppyError(WasmFileNotFound(None, filename))
|
|
218
|
+
|
|
191
219
|
def type_def_wrapper(
|
|
192
220
|
id: DefId,
|
|
193
221
|
name: str,
|
|
@@ -198,10 +226,19 @@ def wasm_module(
|
|
|
198
226
|
assert config is None
|
|
199
227
|
return WasmModuleTypeDef(id, name, defined_at, wasm_file)
|
|
200
228
|
|
|
201
|
-
|
|
202
|
-
type_def_wrapper,
|
|
229
|
+
decorator = ext_module_decorator(
|
|
230
|
+
type_def_wrapper,
|
|
231
|
+
WasmModuleInitCompiler(),
|
|
232
|
+
WasmModuleDiscardCompiler(),
|
|
233
|
+
True,
|
|
234
|
+
wasm_sigs,
|
|
203
235
|
)
|
|
204
|
-
|
|
236
|
+
|
|
237
|
+
def inner_fun(ty: builtins.type[T]) -> GuppyDefinition:
    """Applies the module decorator for `filename` (no module name) to `ty`."""
    return decorator(filename, None)(ty)
|
|
240
|
+
|
|
241
|
+
return inner_fun
|
|
205
242
|
|
|
206
243
|
|
|
207
244
|
def ext_module_decorator(
|
|
@@ -209,9 +246,9 @@ def ext_module_decorator(
|
|
|
209
246
|
init_compiler: CustomInoutCallCompiler,
|
|
210
247
|
discard_compiler: CustomInoutCallCompiler,
|
|
211
248
|
init_arg: bool, # Whether the init function should take a nat argument
|
|
249
|
+
wasm_sigs: ConcreteWasmModule
|
|
250
|
+
| None = None, # For @wasm_module, we must be passed a parsed wasm file
|
|
212
251
|
) -> Callable[[str, str | None], Callable[[builtins.type[T]], GuppyDefinition]]:
|
|
213
|
-
from guppylang.defs import GuppyDefinition
|
|
214
|
-
|
|
215
252
|
def fun(
|
|
216
253
|
filename: str, module: str | None
|
|
217
254
|
) -> Callable[[builtins.type[T]], GuppyDefinition]:
|
|
@@ -231,6 +268,47 @@ def ext_module_decorator(
|
|
|
231
268
|
for val in cls.__dict__.values():
|
|
232
269
|
if isinstance(val, GuppyDefinition):
|
|
233
270
|
DEF_STORE.register_impl(ext_module.id, val.wrapped.name, val.id)
|
|
271
|
+
wasm_def: RawWasmFunctionDef
|
|
272
|
+
if isinstance(val, GuppyFunctionDefinition) and isinstance(
|
|
273
|
+
val.wrapped, RawWasmFunctionDef
|
|
274
|
+
):
|
|
275
|
+
wasm_def = val.wrapped
|
|
276
|
+
else:
|
|
277
|
+
continue
|
|
278
|
+
# wasm_sigs should only have not been provided if we have
|
|
279
|
+
# defined @wasm functions in a class which didn't use the
|
|
280
|
+
# @wasm_module decorator.
|
|
281
|
+
assert wasm_sigs is not None
|
|
282
|
+
if wasm_def.wasm_index is not None:
|
|
283
|
+
name = wasm_sigs.functions[wasm_def.wasm_index]
|
|
284
|
+
assert name in wasm_sigs.function_sigs
|
|
285
|
+
wasm_sig_or_err = wasm_sigs.function_sigs[name]
|
|
286
|
+
else:
|
|
287
|
+
if wasm_def.name in wasm_sigs.function_sigs:
|
|
288
|
+
wasm_sig_or_err = wasm_sigs.function_sigs[wasm_def.name]
|
|
289
|
+
else:
|
|
290
|
+
raise GuppyError(
|
|
291
|
+
WasmFunctionNotInFile(
|
|
292
|
+
wasm_def.defined_at,
|
|
293
|
+
wasm_def.name,
|
|
294
|
+
).add_sub_diagnostic(
|
|
295
|
+
WasmFunctionNotInFile.WasmFileNote(
|
|
296
|
+
None,
|
|
297
|
+
wasm_sigs.filename,
|
|
298
|
+
)
|
|
299
|
+
)
|
|
300
|
+
)
|
|
301
|
+
if isinstance(wasm_sig_or_err, FunctionType):
|
|
302
|
+
DEF_STORE.register_wasm_function(wasm_def.id, wasm_sig_or_err)
|
|
303
|
+
elif isinstance(wasm_sig_or_err, str):
|
|
304
|
+
raise GuppyError(
|
|
305
|
+
WasmSignatureError(
|
|
306
|
+
None, wasm_def.name, filename
|
|
307
|
+
).add_sub_diagnostic(
|
|
308
|
+
WasmSignatureError.Message(None, wasm_sig_or_err)
|
|
309
|
+
)
|
|
310
|
+
)
|
|
311
|
+
|
|
234
312
|
# Add a constructor to the class
|
|
235
313
|
if init_arg:
|
|
236
314
|
init_fn_ty = FunctionType(
|
|
@@ -315,6 +393,7 @@ def wasm_helper(fn_id: int | None, f: Callable[P, T]) -> GuppyFunctionDefinition
|
|
|
315
393
|
WasmModuleCallCompiler(f.__name__, fn_id),
|
|
316
394
|
True,
|
|
317
395
|
signature=None,
|
|
396
|
+
wasm_index=fn_id,
|
|
318
397
|
)
|
|
319
398
|
DEF_STORE.register_def(func, get_calling_frame())
|
|
320
399
|
return GuppyFunctionDefinition(func)
|
|
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, ClassVar
|
|
|
7
7
|
from hugr import Wire, ops
|
|
8
8
|
from hugr import tys as ht
|
|
9
9
|
from hugr.build.dfg import DfBase
|
|
10
|
+
from hugr.std.collections.borrow_array import EXTENSION as BORROW_ARRAY_EXTENSION
|
|
10
11
|
|
|
11
12
|
from guppylang_internals.ast_util import (
|
|
12
13
|
AstNode,
|
|
@@ -23,6 +24,7 @@ from guppylang_internals.compiler.core import (
|
|
|
23
24
|
DFContainer,
|
|
24
25
|
GlobalConstId,
|
|
25
26
|
partially_monomorphize_args,
|
|
27
|
+
qualified_name,
|
|
26
28
|
)
|
|
27
29
|
from guppylang_internals.definition.common import ParsableDef
|
|
28
30
|
from guppylang_internals.definition.value import CallReturnWires, CompiledCallableDef
|
|
@@ -42,6 +44,7 @@ from guppylang_internals.tys.ty import (
|
|
|
42
44
|
InputFlags,
|
|
43
45
|
NoneType,
|
|
44
46
|
Type,
|
|
47
|
+
UnitaryFlags,
|
|
45
48
|
type_to_row,
|
|
46
49
|
)
|
|
47
50
|
|
|
@@ -112,6 +115,8 @@ class RawCustomFunctionDef(ParsableDef):
|
|
|
112
115
|
|
|
113
116
|
signature: FunctionType | None
|
|
114
117
|
|
|
118
|
+
unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags)
|
|
119
|
+
|
|
115
120
|
description: str = field(default="function", init=False)
|
|
116
121
|
|
|
117
122
|
def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
|
|
@@ -134,6 +139,7 @@ class RawCustomFunctionDef(ParsableDef):
|
|
|
134
139
|
raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
|
|
135
140
|
sig = self.signature or self._get_signature(func_ast, globals)
|
|
136
141
|
ty = sig or FunctionType([], NoneType())
|
|
142
|
+
ty = ty.with_unitary_flags(self.unitary_flags)
|
|
137
143
|
return CustomFunctionDef(
|
|
138
144
|
self.id,
|
|
139
145
|
self.name,
|
|
@@ -486,7 +492,39 @@ class NoopCompiler(CustomCallCompiler):
|
|
|
486
492
|
|
|
487
493
|
|
|
488
494
|
class CopyInoutCompiler(CustomInoutCallCompiler):
    """Call compiler for functions that borrow one argument to copy it."""

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        # The copied value is the single (borrowed) input of the signature.
        assert len(self.ty.input) == 1
        inp_ty = self.ty.input[0]
        if inp_ty.type_bound() != ht.TypeBound.Linear:
            # Non-linear in Hugr: the incoming wire can simply be shared between
            # the copy and the returned borrow.
            return CallReturnWires(regular_returns=args, inout_returns=args)
        (arg,) = args
        clones = self._handle_affine_type(inp_ty, arg)
        return CallReturnWires(
            regular_returns=[clones[0]], inout_returns=[clones[1]]
        )

    # Affine types in Guppy backed by a linear Hugr type need to be copied explicitly.
    # TODO: Handle affine extension types more generally (borrow arrays are currently
    # the only case).
    def _handle_affine_type(self, ty: ht.Type, arg: Wire) -> list[Wire]:
        """Emits an explicit clone of `arg` and returns the output wires.

        Only `borrow_array` extension types are supported. Any other type raises
        an `InternalGuppyError`.
        """
        if isinstance(ty, ht.ExtType):
            type_def, type_args = ty.type_def, ty.args
            borrow_array_def = BORROW_ARRAY_EXTENSION.get_type("borrow_array")
            if qualified_name(type_def) == qualified_name(borrow_array_def):
                assert len(type_args) == 2
                # Instantiate the op by hand (avoids a circular import) and reuse
                # the type args of the incoming type directly.
                clone_sig = ht.FunctionType(self.ty.input, self.ty.output)
                clone_op = BORROW_ARRAY_EXTENSION.get_op("clone").instantiate(
                    type_args, clone_sig
                )
                return list(self.builder.add_op(clone_op, arg))
        raise InternalGuppyError(
            f"Type `{ty}` needs an explicit handler in the `copy` compiler as "
            "it is an affine Guppy type backed by a linear Hugr type."
        )
|
|
@@ -13,7 +13,7 @@ from guppylang_internals.checker.func_checker import check_signature
|
|
|
13
13
|
from guppylang_internals.compiler.core import (
|
|
14
14
|
CompilerContext,
|
|
15
15
|
DFContainer,
|
|
16
|
-
|
|
16
|
+
require_monomorphization,
|
|
17
17
|
)
|
|
18
18
|
from guppylang_internals.definition.common import CompilableDef, ParsableDef
|
|
19
19
|
from guppylang_internals.definition.function import (
|
|
@@ -34,7 +34,7 @@ from guppylang_internals.nodes import GlobalCall
|
|
|
34
34
|
from guppylang_internals.span import SourceMap
|
|
35
35
|
from guppylang_internals.tys.param import Parameter
|
|
36
36
|
from guppylang_internals.tys.subst import Inst, Subst
|
|
37
|
-
from guppylang_internals.tys.ty import Type
|
|
37
|
+
from guppylang_internals.tys.ty import Type, UnitaryFlags
|
|
38
38
|
|
|
39
39
|
|
|
40
40
|
@dataclass(frozen=True)
|
|
@@ -65,16 +65,19 @@ class RawFunctionDecl(ParsableDef):
|
|
|
65
65
|
python_func: PyFunc
|
|
66
66
|
description: str = field(default="function", init=False)
|
|
67
67
|
|
|
68
|
+
unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True)
|
|
69
|
+
|
|
68
70
|
def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
|
|
69
71
|
"""Parses and checks the user-provided signature of the function."""
|
|
70
72
|
func_ast, docstring = parse_py_func(self.python_func, sources)
|
|
71
|
-
ty = check_signature(
|
|
73
|
+
ty = check_signature(
|
|
74
|
+
func_ast, globals, self.id, unitary_flags=self.unitary_flags
|
|
75
|
+
)
|
|
72
76
|
if not has_empty_body(func_ast):
|
|
73
77
|
raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
|
|
74
78
|
# Make sure we won't need monomorphization to compile this declaration
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
raise GuppyError(MonomorphizeError(func_ast, self.name, param))
|
|
79
|
+
if mono_params := require_monomorphization(ty.params):
|
|
80
|
+
raise GuppyError(MonomorphizeError(func_ast, self.name, mono_params.pop()))
|
|
78
81
|
return CheckedFunctionDecl(
|
|
79
82
|
self.id,
|
|
80
83
|
self.name,
|
|
@@ -43,7 +43,7 @@ from guppylang_internals.error import GuppyError
|
|
|
43
43
|
from guppylang_internals.nodes import GlobalCall
|
|
44
44
|
from guppylang_internals.span import SourceMap
|
|
45
45
|
from guppylang_internals.tys.subst import Inst, Subst
|
|
46
|
-
from guppylang_internals.tys.ty import FunctionType, Type, type_to_row
|
|
46
|
+
from guppylang_internals.tys.ty import FunctionType, Type, UnitaryFlags, type_to_row
|
|
47
47
|
|
|
48
48
|
if TYPE_CHECKING:
|
|
49
49
|
from guppylang_internals.tys.param import Parameter
|
|
@@ -70,10 +70,14 @@ class RawFunctionDef(ParsableDef):
|
|
|
70
70
|
|
|
71
71
|
description: str = field(default="function", init=False)
|
|
72
72
|
|
|
73
|
+
unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True)
|
|
74
|
+
|
|
73
75
|
def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
    """Parses the Python function and checks its user-provided signature.

    The signature is checked with this definition's `unitary_flags`. The result is
    wrapped, together with the function AST and docstring, in a `ParsedFunctionDef`.
    """
    parsed_ast, doc = parse_py_func(self.python_func, sources)
    signature = check_signature(
        parsed_ast, globals, self.id, unitary_flags=self.unitary_flags
    )
    return ParsedFunctionDef(self.id, self.name, parsed_ast, signature, doc)
|
|
78
82
|
|
|
79
83
|
|
|
@@ -173,6 +177,7 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
|
|
|
173
177
|
func_def = module.module_root_builder().define_function(
|
|
174
178
|
self.name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params
|
|
175
179
|
)
|
|
180
|
+
add_unitarity_metadata(func_def, self.ty.unitary_flags)
|
|
176
181
|
return CompiledFunctionDef(
|
|
177
182
|
self.id,
|
|
178
183
|
self.name,
|
|
@@ -300,3 +305,8 @@ def parse_source(source_lines: list[str], line_offset: int) -> tuple[str, ast.AS
|
|
|
300
305
|
else:
|
|
301
306
|
node = ast.parse(source).body[0]
|
|
302
307
|
return source, node, line_offset
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def add_unitarity_metadata(func: hf.Function, flags: UnitaryFlags) -> None:
    """Stores unitarity annotations in the metadata of a Hugr function definition.

    The flags are recorded as `flags.value` under the `"unitary"` metadata key of
    `func`, overwriting any previous entry under that key.
    """
    func.metadata["unitary"] = flags.value
|
|
@@ -38,14 +38,19 @@ class ParamDef(Definition):
|
|
|
38
38
|
class TypeVarDef(ParamDef, CompiledDef):
    """A type variable definition."""

    # Forwarded as `must_be_copyable` to the `TypeParam` built by `to_param`.
    copyable: bool
    # Forwarded as `must_be_droppable` to the `TypeParam` built by `to_param`.
    droppable: bool

    description: str = field(default="type variable", init=False)

    def to_param(self, idx: int) -> TypeParam:
        """Builds the `TypeParam` at position `idx` described by this definition."""
        bounds = {
            "must_be_copyable": self.copyable,
            "must_be_droppable": self.droppable,
        }
        return TypeParam(idx, self.name, **bounds)
|
|
49
54
|
|
|
50
55
|
|
|
51
56
|
@dataclass(frozen=True)
|