guppylang-internals 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +21 -0
- guppylang_internals/cfg/bb.py +20 -0
- guppylang_internals/cfg/builder.py +101 -3
- guppylang_internals/checker/core.py +12 -0
- guppylang_internals/checker/errors/generic.py +32 -1
- guppylang_internals/checker/errors/type_errors.py +14 -0
- guppylang_internals/checker/expr_checker.py +55 -29
- guppylang_internals/checker/func_checker.py +171 -22
- guppylang_internals/checker/linearity_checker.py +65 -0
- guppylang_internals/checker/modifier_checker.py +116 -0
- guppylang_internals/checker/stmt_checker.py +49 -2
- guppylang_internals/compiler/core.py +90 -53
- guppylang_internals/compiler/expr_compiler.py +49 -114
- guppylang_internals/compiler/modifier_compiler.py +174 -0
- guppylang_internals/compiler/stmt_compiler.py +15 -8
- guppylang_internals/decorator.py +124 -58
- guppylang_internals/definition/const.py +2 -2
- guppylang_internals/definition/custom.py +36 -2
- guppylang_internals/definition/declaration.py +4 -5
- guppylang_internals/definition/extern.py +2 -2
- guppylang_internals/definition/function.py +1 -1
- guppylang_internals/definition/parameter.py +10 -5
- guppylang_internals/definition/pytket_circuits.py +14 -42
- guppylang_internals/definition/struct.py +17 -14
- guppylang_internals/definition/traced.py +1 -1
- guppylang_internals/definition/ty.py +9 -3
- guppylang_internals/definition/wasm.py +2 -2
- guppylang_internals/engine.py +13 -2
- guppylang_internals/experimental.py +5 -0
- guppylang_internals/nodes.py +124 -23
- guppylang_internals/std/_internal/compiler/array.py +94 -282
- guppylang_internals/std/_internal/compiler/tket_exts.py +12 -8
- guppylang_internals/std/_internal/compiler/wasm.py +37 -26
- guppylang_internals/tracing/function.py +13 -2
- guppylang_internals/tracing/unpacking.py +33 -28
- guppylang_internals/tys/arg.py +18 -3
- guppylang_internals/tys/builtin.py +32 -16
- guppylang_internals/tys/const.py +33 -4
- guppylang_internals/tys/errors.py +6 -0
- guppylang_internals/tys/param.py +31 -16
- guppylang_internals/tys/parsing.py +118 -145
- guppylang_internals/tys/qubit.py +27 -0
- guppylang_internals/tys/subst.py +8 -26
- guppylang_internals/tys/ty.py +31 -21
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/METADATA +4 -4
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/RECORD +49 -46
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/WHEEL +0 -0
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""Hugr generation for modifiers."""
|
|
2
|
+
|
|
3
|
+
from hugr import Wire, ops
|
|
4
|
+
from hugr import tys as ht
|
|
5
|
+
|
|
6
|
+
from guppylang_internals.ast_util import get_type
|
|
7
|
+
from guppylang_internals.checker.modifier_checker import non_copyable_front_others_back
|
|
8
|
+
from guppylang_internals.compiler.cfg_compiler import compile_cfg
|
|
9
|
+
from guppylang_internals.compiler.core import CompilerContext, DFContainer
|
|
10
|
+
from guppylang_internals.compiler.expr_compiler import ExprCompiler
|
|
11
|
+
from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode
|
|
12
|
+
from guppylang_internals.std._internal.compiler.array import (
|
|
13
|
+
array_new,
|
|
14
|
+
array_to_std_array,
|
|
15
|
+
standard_array_type,
|
|
16
|
+
std_array_to_array,
|
|
17
|
+
unpack_array,
|
|
18
|
+
)
|
|
19
|
+
from guppylang_internals.std._internal.compiler.tket_exts import MODIFIER_EXTENSION
|
|
20
|
+
from guppylang_internals.tys.builtin import int_type, is_array_type
|
|
21
|
+
from guppylang_internals.tys.ty import InputFlags
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def compile_modified_block(
    modified_block: CheckedModifiedBlock,
    dfg: DFContainer,
    ctx: CompilerContext,
    expr_compiler: ExprCompiler,
) -> Wire:
    """Compile a checked modified block to Hugr and call it in the current scope.

    The body of the block is emitted as a separate function on the module root and
    loaded as a function value. That value is then wrapped in one modifier extension
    op per requested modifier (dagger first, then each power, then each control) and
    finally invoked through ``CallIndirect``. The control qubits and every captured
    variable flagged ``Inout`` are re-bound in ``dfg`` to the outputs of that call.

    Args:
        modified_block: The checked block (CFG, captured variables and modifiers).
        dfg: Dataflow container of the enclosing scope; read for the captured
            variables and updated with the values returned by the call.
        ctx: Compiler context used when lowering Guppy types to Hugr types.
        expr_compiler: Compiler used for the power and control expressions.

    Returns:
        The node created for the ``CallIndirect`` operation.
    """
    # Look up all three modifier ops up front, in the same order as before.
    dagger_def = MODIFIER_EXTENSION.get_op("DaggerModifier")
    control_def = MODIFIER_EXTENSION.get_op("ControlModifier")
    power_def = MODIFIER_EXTENSION.get_op("PowerModifier")

    block_ty = modified_block.ty
    # TODO: Shouldn't this be `to_hugr_poly` since it can contain
    # a variable with a generic type?
    fn_ty = block_ty.to_hugr(ctx)

    # Split the non-comptime inputs into borrowed (`Inout`) ones and all others;
    # both lists are passed to the modifier ops as type arguments.
    borrowed_tys = [
        inp.ty.to_hugr(ctx)
        for inp in block_ty.inputs
        if InputFlags.Inout in inp.flags and InputFlags.Comptime not in inp.flags
    ]
    plain_tys = [
        inp.ty.to_hugr(ctx)
        for inp in block_ty.inputs
        if InputFlags.Inout not in inp.flags
        and InputFlags.Comptime not in inp.flags
    ]
    borrowed_arg = ht.ListArg([ty.type_arg() for ty in borrowed_tys])
    plain_arg = ht.ListArg([ty.type_arg() for ty in plain_tys])

    # Emit the block body as its own function on the module root.
    body_fn = dfg.builder.module_root_builder().define_function(
        str(modified_block), fn_ty.input, fn_ty.output
    )
    body_outputs = compile_cfg(modified_block.cfg, body_fn, body_fn.inputs(), ctx)
    body_fn.set_outputs(*body_outputs)

    # Load the function as a value so that the modifier ops can transform it.
    callee = dfg.builder.load_function(body_fn, fn_ty)

    # Captured variables become the call arguments (non-copyable ones first).
    captured = non_copyable_front_others_back(
        [var for var, _ in modified_block.captured.values()]
    )
    captured_wires = [dfg[var] for var in captured]

    # Dagger: function value in, function value of the same type out.
    if modified_block.has_dagger():
        dagger_op = ops.ExtOp(
            dagger_def,
            ht.FunctionType([fn_ty], [fn_ty]),
            [borrowed_arg, plain_arg],
        )
        callee = dfg.builder.add_op(dagger_op, callee)

    # Power: one op per power modifier, each taking the compiled exponent.
    if modified_block.has_power():
        power_sig = ht.FunctionType([fn_ty, int_type().to_hugr(ctx)], [fn_ty])
        for power in modified_block.power:
            exponent = expr_compiler.compile(power.iter, dfg)
            power_op = ops.ExtOp(power_def, power_sig, [borrowed_arg, plain_arg])
            callee = dfg.builder.add_op(power_op, callee, exponent)

    # Control: each op prepends a standard qubit array to the function's inputs
    # and outputs, so the tracked function type grows with every control.
    control_sizes: list[ht.TypeArg] = []
    if modified_block.has_control():
        for control in modified_block.control:
            assert control.qubit_num is not None
            size: ht.TypeArg = (
                ht.BoundedNatArg(control.qubit_num)
                if isinstance(control.qubit_num, int)
                else control.qubit_num.to_arg().to_hugr(ctx)
            )
            control_sizes.append(size)
            ctrl_array_ty = standard_array_type(ht.Qubit, size)

            controlled_ty = ht.FunctionType(
                [ctrl_array_ty, *fn_ty.input], [ctrl_array_ty, *fn_ty.output]
            )
            control_op = ops.ExtOp(
                control_def,
                ht.FunctionType([fn_ty], [controlled_ty]),
                [size, borrowed_arg, plain_arg],
            )
            callee = dfg.builder.add_op(control_op, callee)

            # The next modifier sees the control array as an extra borrowed input.
            borrowed_arg = ht.ListArg(
                [ctrl_array_ty.type_arg(), *borrowed_arg.elems]
            )
            fn_ty = controlled_ty

    # Build one standard array of control qubits per control modifier.
    ctrl_wires: list[Wire] = []
    for idx, control in enumerate(modified_block.control):
        if is_array_type(get_type(control.ctrl[0])):
            # A single array-typed control expression is converted directly.
            to_convert = [expr_compiler.compile(control.ctrl[0], dfg)]
        else:
            # Individual control qubits are first packed into a fresh array.
            qubits = [expr_compiler.compile(c, dfg) for c in control.ctrl]
            packed = dfg.builder.add_op(
                array_new(ht.Qubit, len(control.ctrl)), *qubits
            )
            to_convert = list(packed)
        ctrl_wires.append(
            dfg.builder.add_op(
                array_to_std_array(ht.Qubit, control_sizes[idx]), *to_convert
            )
        )

    # Invoke the (possibly modified) function value.
    call_node = dfg.builder.add_op(
        ops.CallIndirect(),
        callee,
        *ctrl_wires,
        *captured_wires,
    )
    results = iter(call_node)

    # The leading outputs are the control arrays: convert them back and re-bind
    # the places the control qubits were taken from.
    for idx, control in enumerate(modified_block.control):
        returned = next(results)
        is_array = is_array_type(get_type(control.ctrl[0]))
        restored = dfg.builder.add_op(
            std_array_to_array(ht.Qubit, control_sizes[idx]), returned
        )
        if is_array:
            target = control.ctrl[0]
            assert isinstance(target, PlaceNode)
            dfg[target.place] = restored
        else:
            elements = unpack_array(dfg.builder, restored)
            for target, element in zip(control.ctrl, elements, strict=False):
                assert isinstance(target, PlaceNode)
                dfg[target.place] = element

    # The remaining outputs are consumed, in order, by the captured variables
    # that are flagged `Inout`.
    for var in captured:
        if InputFlags.Inout in var.flags:
            dfg[var] = next(results)

    return call_node
|
|
@@ -2,7 +2,6 @@ import ast
|
|
|
2
2
|
import functools
|
|
3
3
|
from collections.abc import Sequence
|
|
4
4
|
|
|
5
|
-
import hugr.tys as ht
|
|
6
5
|
from hugr import Wire, ops
|
|
7
6
|
from hugr.build.dfg import DfBase
|
|
8
7
|
|
|
@@ -18,6 +17,7 @@ from guppylang_internals.compiler.expr_compiler import ExprCompiler
|
|
|
18
17
|
from guppylang_internals.error import InternalGuppyError
|
|
19
18
|
from guppylang_internals.nodes import (
|
|
20
19
|
ArrayUnpack,
|
|
20
|
+
CheckedModifiedBlock,
|
|
21
21
|
CheckedNestedFunctionDef,
|
|
22
22
|
IterableUnpack,
|
|
23
23
|
PlaceNode,
|
|
@@ -120,8 +120,9 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
|
|
|
120
120
|
ports[len(left) : -len(right)] if right else ports[len(left) :]
|
|
121
121
|
)
|
|
122
122
|
elt = get_element_type(array_ty).to_hugr(self.ctx)
|
|
123
|
-
|
|
124
|
-
|
|
123
|
+
array = self.builder.add_op(
|
|
124
|
+
array_new(elt, len(starred_ports)), *starred_ports
|
|
125
|
+
)
|
|
125
126
|
self._assign(starred, array)
|
|
126
127
|
|
|
127
128
|
@_assign.register
|
|
@@ -130,7 +131,7 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
|
|
|
130
131
|
# Given an assignment pattern `left, *starred, right`, pop from the left and
|
|
131
132
|
# right, leaving us with the starred array in the middle
|
|
132
133
|
length = lhs.length
|
|
133
|
-
|
|
134
|
+
elt_ty = lhs.elt_type.to_hugr(self.ctx)
|
|
134
135
|
|
|
135
136
|
def pop(
|
|
136
137
|
array: Wire, length: int, pats: list[ast.expr], from_left: bool
|
|
@@ -141,10 +142,9 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
|
|
|
141
142
|
elts = []
|
|
142
143
|
for i in range(num_pats):
|
|
143
144
|
res = self.builder.add_op(
|
|
144
|
-
array_pop(
|
|
145
|
+
array_pop(elt_ty, length - i, from_left), array
|
|
145
146
|
)
|
|
146
|
-
[
|
|
147
|
-
[elt] = build_unwrap(self.builder, elt_opt, err)
|
|
147
|
+
[elt, array] = build_unwrap(self.builder, res, err)
|
|
148
148
|
elts.append(elt)
|
|
149
149
|
# Assign elements to the given patterns
|
|
150
150
|
for pat, elt in zip(
|
|
@@ -164,7 +164,7 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
|
|
|
164
164
|
self._assign(lhs.pattern.starred, array)
|
|
165
165
|
else:
|
|
166
166
|
assert length == 0
|
|
167
|
-
self.builder.add_op(array_discard_empty(
|
|
167
|
+
self.builder.add_op(array_discard_empty(elt_ty), array)
|
|
168
168
|
|
|
169
169
|
@_assign.register
|
|
170
170
|
def _assign_iterable(self, lhs: IterableUnpack, port: Wire) -> None:
|
|
@@ -221,3 +221,10 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
|
|
|
221
221
|
var = Variable(node.name, node.ty, node)
|
|
222
222
|
loaded_func = compile_local_func_def(node, self.dfg, self.ctx)
|
|
223
223
|
self.dfg[var] = loaded_func
|
|
224
|
+
|
|
225
|
+
def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None:
    """Compile a checked modified block by delegating to the modifier compiler.

    The value returned by `compile_modified_block` is discarded.
    """
    # Function-scope import, as in the original.
    # NOTE(review): presumably this avoids an import cycle with
    # `modifier_compiler` — confirm before moving it to module level.
    from guppylang_internals.compiler.modifier_compiler import (
        compile_modified_block as compile_block,
    )

    compile_block(
        modified_block=node,
        dfg=self.dfg,
        ctx=self.ctx,
        expr_compiler=self.expr_compiler,
    )
|
guppylang_internals/decorator.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
-
from typing import TYPE_CHECKING, ParamSpec, TypeVar
|
|
4
|
+
from typing import TYPE_CHECKING, ParamSpec, TypeVar, overload
|
|
5
5
|
|
|
6
6
|
from hugr import ops
|
|
7
7
|
from hugr import tys as ht
|
|
@@ -42,6 +42,7 @@ from guppylang_internals.tys.ty import (
|
|
|
42
42
|
)
|
|
43
43
|
|
|
44
44
|
if TYPE_CHECKING:
|
|
45
|
+
import ast
|
|
45
46
|
import builtins
|
|
46
47
|
from collections.abc import Callable, Sequence
|
|
47
48
|
from types import FrameType
|
|
@@ -121,15 +122,19 @@ def hugr_op(
|
|
|
121
122
|
return custom_function(OpCompiler(op), checker, higher_order_value, name, signature)
|
|
122
123
|
|
|
123
124
|
|
|
124
|
-
def extend_type(defn: TypeDef) -> Callable[[type], type]:
|
|
125
|
-
"""Decorator to add new instance functions to a type.
|
|
125
|
+
def extend_type(defn: TypeDef, return_class: bool = False) -> Callable[[type], type]:
    """Build a class decorator that attaches new instance functions to a type.

    Every `GuppyDefinition` found in the decorated class' `__dict__` is registered
    in `DEF_STORE` as an implementation on `defn`, keyed by the name of the wrapped
    definition.

    The decorator returns a `GuppyDefinition` wrapping `defn` unless
    `return_class=True` is given, in which case the decorated class itself is
    returned unchanged.
    """
    from guppylang.defs import GuppyDefinition

    def dec(c: type) -> type:
        members = (m for m in c.__dict__.values() if isinstance(m, GuppyDefinition))
        for member in members:
            DEF_STORE.register_impl(defn.id, member.wrapped.name, member.id)
        if return_class:
            return c
        return GuppyDefinition(defn)  # type: ignore[return-value]

    return dec
|
|
135
140
|
|
|
@@ -181,63 +186,124 @@ def custom_type(
|
|
|
181
186
|
|
|
182
187
|
|
|
183
188
|
def wasm_module(
    filename: str,
) -> Callable[[builtins.type[T]], GuppyDefinition]:
    """Return a class decorator that defines a WASM module backed by `filename`.

    This is a thin specialisation of `ext_module_decorator`: the type definition
    is a `WasmModuleTypeDef`, the constructor and `discard` functions use the WASM
    init/discard compilers, and no extra module configuration is supplied.
    """

    def make_type_def(
        def_id: DefId,
        name: str,
        defined_at: ast.AST | None,
        wasm_file: str,
        config: str | None,
    ) -> OpaqueTypeDef:
        # WASM modules never carry a config; it is always passed as `None` below.
        assert config is None
        return WasmModuleTypeDef(def_id, name, defined_at, wasm_file)

    # The trailing `True` asks for a constructor that takes a nat argument.
    return ext_module_decorator(
        make_type_def,
        WasmModuleInitCompiler(),
        WasmModuleDiscardCompiler(),
        True,
    )(filename, None)
|
|
234
205
|
|
|
235
|
-
return GuppyDefinition(wasm_module)
|
|
236
|
-
|
|
237
|
-
return dec
|
|
238
206
|
|
|
207
|
+
def ext_module_decorator(
    type_def: Callable[[DefId, str, ast.AST | None, str, str | None], OpaqueTypeDef],
    init_compiler: CustomInoutCallCompiler,
    discard_compiler: CustomInoutCallCompiler,
    init_arg: bool,  # Whether the init function should take a nat argument
) -> Callable[[str, str | None], Callable[[builtins.type[T]], GuppyDefinition]]:
    """Create a factory for class decorators that define external modules.

    Args:
        type_def: Builds the opaque type definition from a fresh id, the class
            name, a definition site (always `None` here), the file name and the
            optional module string.
        init_compiler: Call compiler used for the generated `__new__` function.
        discard_compiler: Call compiler used for the generated `discard` function.
        init_arg: Whether `__new__` takes a single owned nat argument (otherwise
            it takes no arguments).

    Returns:
        A function `(filename, module)` producing the actual class decorator. The
        decorator registers the type, every `GuppyDefinition` member of the class,
        and the generated `__new__` / `discard` functions in `DEF_STORE`, and
        returns a `GuppyDefinition` for the new type.
    """
    from guppylang.defs import GuppyDefinition

    def fun(
        filename: str, module: str | None
    ) -> Callable[[builtins.type[T]], GuppyDefinition]:
        def dec(cls: builtins.type[T]) -> GuppyDefinition:
            # N.B. Only one module per file and vice-versa
            module_def = type_def(DefId.fresh(), cls.__name__, None, filename, module)
            module_ty = module_def.check_instantiate([], None)

            DEF_STORE.register_def(module_def, get_calling_frame())
            members = (
                m for m in cls.__dict__.values() if isinstance(m, GuppyDefinition)
            )
            for member in members:
                DEF_STORE.register_impl(module_def.id, member.wrapped.name, member.id)

            # Constructor: optionally parameterised by an owned nat.
            init_inputs = (
                [FuncInput(NumericType(NumericType.Kind.Nat), flags=InputFlags.Owned)]
                if init_arg
                else []
            )
            init_fn_ty = FunctionType(init_inputs, module_ty)

            constructor = CustomFunctionDef(
                DefId.fresh(),
                "__new__",
                None,
                init_fn_ty,
                DefaultCallChecker(),
                init_compiler,
                True,
                GlobalConstId.fresh(f"{cls.__name__}.__new__"),
                True,
            )
            # `discard` consumes the module value and returns nothing.
            discard_ty = FunctionType(
                [FuncInput(module_ty, InputFlags.Owned)], NoneType()
            )
            discarder = CustomFunctionDef(
                DefId.fresh(),
                "discard",
                None,
                discard_ty,
                DefaultCallChecker(),
                discard_compiler,
                False,
                GlobalConstId.fresh(f"{cls.__name__}.__discard__"),
                True,
            )

            # `get_calling_frame()` is deliberately called right here for each
            # registration, exactly as before, rather than from a helper.
            DEF_STORE.register_def(constructor, get_calling_frame())
            DEF_STORE.register_impl(module_def.id, "__new__", constructor.id)
            DEF_STORE.register_def(discarder, get_calling_frame())
            DEF_STORE.register_impl(module_def.id, "discard", discarder.id)

            return GuppyDefinition(module_def)

        return dec

    return fun
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
@overload
def wasm(arg: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: ...


@overload
def wasm(arg: int) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]: ...


def wasm(
    arg: int | Callable[P, T],
) -> (
    GuppyFunctionDefinition[P, T]
    | Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]
):
    """Decorator for WASM functions, usable bare or with an integer argument.

    Used bare (`@wasm`), the decorated function is passed to `wasm_helper` with
    `None` as the function id. Called with an integer (`@wasm(n)`), a decorator
    is returned that passes `n` to `wasm_helper` as the function id.
    """
    if not isinstance(arg, int):
        # Bare usage: `arg` is the decorated function itself.
        return wasm_helper(None, arg)

    def with_fn_id(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
        return wasm_helper(arg, f)

    return with_fn_id
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def wasm_helper(fn_id: int | None, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
|
|
241
307
|
from guppylang.defs import GuppyFunctionDefinition
|
|
242
308
|
|
|
243
309
|
func = RawWasmFunctionDef(
|
|
@@ -246,7 +312,7 @@ def wasm(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
|
|
|
246
312
|
None,
|
|
247
313
|
f,
|
|
248
314
|
WasmCallChecker(),
|
|
249
|
-
WasmModuleCallCompiler(f.__name__),
|
|
315
|
+
WasmModuleCallCompiler(f.__name__, fn_id),
|
|
250
316
|
True,
|
|
251
317
|
signature=None,
|
|
252
318
|
)
|
|
@@ -15,7 +15,7 @@ from guppylang_internals.definition.value import (
|
|
|
15
15
|
ValueDef,
|
|
16
16
|
)
|
|
17
17
|
from guppylang_internals.span import SourceMap
|
|
18
|
-
from guppylang_internals.tys.parsing import type_from_ast
|
|
18
|
+
from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
@dataclass(frozen=True)
|
|
@@ -33,7 +33,7 @@ class RawConstDef(ParsableDef):
|
|
|
33
33
|
self.id,
|
|
34
34
|
self.name,
|
|
35
35
|
self.defined_at,
|
|
36
|
-
type_from_ast(self.type_ast, globals
|
|
36
|
+
type_from_ast(self.type_ast, TypeParsingCtx(globals)),
|
|
37
37
|
self.type_ast,
|
|
38
38
|
self.value,
|
|
39
39
|
)
|
|
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, ClassVar
|
|
|
7
7
|
from hugr import Wire, ops
|
|
8
8
|
from hugr import tys as ht
|
|
9
9
|
from hugr.build.dfg import DfBase
|
|
10
|
+
from hugr.std.collections.borrow_array import EXTENSION as BORROW_ARRAY_EXTENSION
|
|
10
11
|
|
|
11
12
|
from guppylang_internals.ast_util import (
|
|
12
13
|
AstNode,
|
|
@@ -23,6 +24,7 @@ from guppylang_internals.compiler.core import (
|
|
|
23
24
|
DFContainer,
|
|
24
25
|
GlobalConstId,
|
|
25
26
|
partially_monomorphize_args,
|
|
27
|
+
qualified_name,
|
|
26
28
|
)
|
|
27
29
|
from guppylang_internals.definition.common import ParsableDef
|
|
28
30
|
from guppylang_internals.definition.value import CallReturnWires, CompiledCallableDef
|
|
@@ -169,7 +171,7 @@ class RawCustomFunctionDef(ParsableDef):
|
|
|
169
171
|
raise GuppyError(NoSignatureError(node, self.name))
|
|
170
172
|
|
|
171
173
|
if requires_type_annotation:
|
|
172
|
-
return check_signature(node, globals)
|
|
174
|
+
return check_signature(node, globals, self.id)
|
|
173
175
|
else:
|
|
174
176
|
return None
|
|
175
177
|
|
|
@@ -486,7 +488,39 @@ class NoopCompiler(CustomCallCompiler):
|
|
|
486
488
|
|
|
487
489
|
|
|
488
490
|
class CopyInoutCompiler(CustomInoutCallCompiler):
    """Call compiler for functions that borrow one argument to copy it."""

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        """Return the borrowed argument both as the result and as the inout value.

        If the Hugr type of the single input has a linear bound, the wire cannot
        simply be reused twice; an explicit copy is emitted instead and its two
        outputs are used as the regular return and the inout return respectively.
        """
        assert len(self.ty.input) == 1
        borrowed_ty = self.ty.input[0]
        if borrowed_ty.type_bound() != ht.TypeBound.Linear:
            return CallReturnWires(regular_returns=args, inout_returns=args)

        (wire,) = args
        clones = self._handle_affine_type(borrowed_ty, wire)
        return CallReturnWires(
            regular_returns=[clones[0]], inout_returns=[clones[1]]
        )

    def _handle_affine_type(self, ty: ht.Type, arg: Wire) -> list[Wire]:
        """Emit an explicit copy for an affine Guppy type backed by a linear Hugr type.

        Only borrow arrays are handled, by adding the borrow array extension's
        `clone` op; the outputs of that op are returned as a list.

        Raises:
            InternalGuppyError: If `ty` is any other type.
        """
        # TODO: Handle affine extension types more generally (borrow arrays are
        # currently the only case).
        if isinstance(ty, ht.ExtType):
            own_name = qualified_name(ty.type_def)
            borrow_array_name = qualified_name(
                BORROW_ARRAY_EXTENSION.get_type("borrow_array")
            )
            if own_name == borrow_array_name:
                type_args = ty.args
                assert len(type_args) == 2
                # Manually instantiate here to avoid circular import and use
                # type args directly.
                clone_def = BORROW_ARRAY_EXTENSION.get_op("clone")
                clone_op = clone_def.instantiate(
                    type_args,
                    ht.FunctionType(self.ty.input, self.ty.output),
                )
                return list(self.builder.add_op(clone_op, arg))

        raise InternalGuppyError(
            f"Type `{ty}` needs an explicit handler in the `copy` compiler as "
            "it is an affine Guppy type backed by a linear Hugr type."
        )
|
|
@@ -13,7 +13,7 @@ from guppylang_internals.checker.func_checker import check_signature
|
|
|
13
13
|
from guppylang_internals.compiler.core import (
|
|
14
14
|
CompilerContext,
|
|
15
15
|
DFContainer,
|
|
16
|
-
|
|
16
|
+
require_monomorphization,
|
|
17
17
|
)
|
|
18
18
|
from guppylang_internals.definition.common import CompilableDef, ParsableDef
|
|
19
19
|
from guppylang_internals.definition.function import (
|
|
@@ -68,13 +68,12 @@ class RawFunctionDecl(ParsableDef):
|
|
|
68
68
|
def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
|
|
69
69
|
"""Parses and checks the user-provided signature of the function."""
|
|
70
70
|
func_ast, docstring = parse_py_func(self.python_func, sources)
|
|
71
|
-
ty = check_signature(func_ast, globals)
|
|
71
|
+
ty = check_signature(func_ast, globals, self.id)
|
|
72
72
|
if not has_empty_body(func_ast):
|
|
73
73
|
raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
|
|
74
74
|
# Make sure we won't need monomorphization to compile this declaration
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
raise GuppyError(MonomorphizeError(func_ast, self.name, param))
|
|
75
|
+
if mono_params := require_monomorphization(ty.params):
|
|
76
|
+
raise GuppyError(MonomorphizeError(func_ast, self.name, mono_params.pop()))
|
|
78
77
|
return CheckedFunctionDecl(
|
|
79
78
|
self.id,
|
|
80
79
|
self.name,
|
|
@@ -14,7 +14,7 @@ from guppylang_internals.definition.value import (
|
|
|
14
14
|
ValueDef,
|
|
15
15
|
)
|
|
16
16
|
from guppylang_internals.span import SourceMap
|
|
17
|
-
from guppylang_internals.tys.parsing import type_from_ast
|
|
17
|
+
from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
@dataclass(frozen=True)
|
|
@@ -33,7 +33,7 @@ class RawExternDef(ParsableDef):
|
|
|
33
33
|
self.id,
|
|
34
34
|
self.name,
|
|
35
35
|
self.defined_at,
|
|
36
|
-
type_from_ast(self.type_ast, globals
|
|
36
|
+
type_from_ast(self.type_ast, TypeParsingCtx(globals)),
|
|
37
37
|
self.symbol,
|
|
38
38
|
self.constant,
|
|
39
39
|
self.type_ast,
|
|
@@ -73,7 +73,7 @@ class RawFunctionDef(ParsableDef):
|
|
|
73
73
|
def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
    """Parse the Python function and check its user-provided signature.

    The signature is checked against `globals` with this definition's id, and the
    parsed AST, the resulting type and the docstring are bundled into a
    `ParsedFunctionDef`.
    """
    parsed_ast, doc = parse_py_func(self.python_func, sources)
    signature = check_signature(parsed_ast, globals, self.id)
    return ParsedFunctionDef(self.id, self.name, parsed_ast, signature, doc)
|
|
78
78
|
|
|
79
79
|
|
|
@@ -38,14 +38,19 @@ class ParamDef(Definition):
|
|
|
38
38
|
class TypeVarDef(ParamDef, CompiledDef):
    """A type variable definition."""

    # Passed on to `TypeParam` as `must_be_copyable` by `to_param`.
    copyable: bool
    # Passed on to `TypeParam` as `must_be_droppable` by `to_param`.
    droppable: bool

    description: str = field(default="type variable", init=False)

    def to_param(self, idx: int) -> TypeParam:
        """Create the `TypeParam` with index `idx` that corresponds to this definition."""
        bounds = {
            "must_be_copyable": self.copyable,
            "must_be_droppable": self.droppable,
        }
        return TypeParam(idx, self.name, **bounds)
|
|
49
54
|
|
|
50
55
|
|
|
51
56
|
@dataclass(frozen=True)
|
|
@@ -56,9 +61,9 @@ class RawConstVarDef(ParamDef, ParsableDef):
|
|
|
56
61
|
description: str = field(default="const variable", init=False)
|
|
57
62
|
|
|
58
63
|
def parse(self, globals: Globals, sources: SourceMap) -> "ConstVarDef":
    """Parse the type annotation of this const variable into a `ConstVarDef`.

    Raises:
        GuppyError: With a `LinearConstVarError` if the parsed type is not both
            copyable and droppable.
    """
    from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast

    parsing_ctx = TypeParsingCtx(globals)
    const_ty = type_from_ast(self.type_ast, parsing_ctx)
    if not (const_ty.copyable and const_ty.droppable):
        raise GuppyError(LinearConstVarError(self.type_ast, self.name, const_ty))
    return ConstVarDef(self.id, self.name, self.defined_at, const_ty)
|