guppylang-internals 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/cfg/builder.py +17 -2
- guppylang_internals/cfg/cfg.py +3 -0
- guppylang_internals/checker/cfg_checker.py +6 -0
- guppylang_internals/checker/core.py +1 -2
- guppylang_internals/checker/errors/wasm.py +7 -4
- guppylang_internals/checker/expr_checker.py +13 -8
- guppylang_internals/checker/func_checker.py +17 -13
- guppylang_internals/checker/linearity_checker.py +2 -10
- guppylang_internals/checker/modifier_checker.py +6 -2
- guppylang_internals/checker/unitary_checker.py +132 -0
- guppylang_internals/compiler/cfg_compiler.py +7 -6
- guppylang_internals/compiler/core.py +5 -5
- guppylang_internals/compiler/expr_compiler.py +42 -73
- guppylang_internals/compiler/modifier_compiler.py +2 -0
- guppylang_internals/decorator.py +86 -7
- guppylang_internals/definition/custom.py +4 -0
- guppylang_internals/definition/declaration.py +6 -2
- guppylang_internals/definition/function.py +12 -2
- guppylang_internals/definition/pytket_circuits.py +1 -0
- guppylang_internals/definition/struct.py +6 -3
- guppylang_internals/definition/wasm.py +42 -10
- guppylang_internals/engine.py +9 -3
- guppylang_internals/nodes.py +23 -24
- guppylang_internals/std/_internal/checker.py +13 -108
- guppylang_internals/std/_internal/compiler/array.py +1 -1
- guppylang_internals/std/_internal/compiler/list.py +1 -1
- guppylang_internals/std/_internal/compiler/platform.py +153 -0
- guppylang_internals/std/_internal/compiler/prelude.py +12 -4
- guppylang_internals/std/_internal/compiler/tket_exts.py +3 -4
- guppylang_internals/std/_internal/debug.py +18 -9
- guppylang_internals/std/_internal/util.py +1 -1
- guppylang_internals/tracing/object.py +10 -0
- guppylang_internals/tys/errors.py +23 -1
- guppylang_internals/tys/parsing.py +3 -3
- guppylang_internals/tys/printing.py +2 -8
- guppylang_internals/tys/qubit.py +37 -2
- guppylang_internals/tys/ty.py +60 -64
- guppylang_internals/wasm_util.py +129 -0
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +4 -3
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/RECORD +43 -40
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -15,7 +15,6 @@ from hugr import val as hv
|
|
|
15
15
|
from hugr.build import function as hf
|
|
16
16
|
from hugr.build.cond_loop import Conditional
|
|
17
17
|
from hugr.build.dfg import DP, DfBase
|
|
18
|
-
from typing_extensions import assert_never
|
|
19
18
|
|
|
20
19
|
from guppylang_internals.ast_util import AstNode, AstVisitor, get_type
|
|
21
20
|
from guppylang_internals.cfg.builder import tmp_vars
|
|
@@ -23,7 +22,6 @@ from guppylang_internals.checker.core import Variable, contains_subscript
|
|
|
23
22
|
from guppylang_internals.checker.errors.generic import UnsupportedError
|
|
24
23
|
from guppylang_internals.compiler.core import (
|
|
25
24
|
DEBUG_EXTENSION,
|
|
26
|
-
RESULT_EXTENSION,
|
|
27
25
|
CompilerBase,
|
|
28
26
|
CompilerContext,
|
|
29
27
|
DFContainer,
|
|
@@ -53,7 +51,6 @@ from guppylang_internals.nodes import (
|
|
|
53
51
|
PanicExpr,
|
|
54
52
|
PartialApply,
|
|
55
53
|
PlaceNode,
|
|
56
|
-
ResultExpr,
|
|
57
54
|
StateResultExpr,
|
|
58
55
|
SubscriptAccessAndDrop,
|
|
59
56
|
TensorCall,
|
|
@@ -66,7 +63,6 @@ from guppylang_internals.std._internal.compiler.arithmetic import (
|
|
|
66
63
|
convert_itousize,
|
|
67
64
|
)
|
|
68
65
|
from guppylang_internals.std._internal.compiler.array import (
|
|
69
|
-
array_clone,
|
|
70
66
|
array_map,
|
|
71
67
|
array_new,
|
|
72
68
|
array_to_std_array,
|
|
@@ -80,8 +76,8 @@ from guppylang_internals.std._internal.compiler.list import (
|
|
|
80
76
|
list_new,
|
|
81
77
|
)
|
|
82
78
|
from guppylang_internals.std._internal.compiler.prelude import (
|
|
83
|
-
build_error,
|
|
84
79
|
build_panic,
|
|
80
|
+
make_error,
|
|
85
81
|
panic,
|
|
86
82
|
)
|
|
87
83
|
from guppylang_internals.std._internal.compiler.tket_bool import (
|
|
@@ -93,14 +89,12 @@ from guppylang_internals.std._internal.compiler.tket_bool import (
|
|
|
93
89
|
)
|
|
94
90
|
from guppylang_internals.tys.arg import ConstArg
|
|
95
91
|
from guppylang_internals.tys.builtin import (
|
|
96
|
-
array_type,
|
|
97
92
|
bool_type,
|
|
98
93
|
get_element_type,
|
|
99
94
|
int_type,
|
|
100
|
-
is_bool_type,
|
|
101
95
|
is_frozenarray_type,
|
|
102
96
|
)
|
|
103
|
-
from guppylang_internals.tys.const import ConstValue
|
|
97
|
+
from guppylang_internals.tys.const import BoundConstVar, Const, ConstValue
|
|
104
98
|
from guppylang_internals.tys.subst import Inst
|
|
105
99
|
from guppylang_internals.tys.ty import (
|
|
106
100
|
BoundTypeVar,
|
|
@@ -535,74 +529,48 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
|
535
529
|
tuple_port = self.visit(node.value)
|
|
536
530
|
return self._unpack_tuple(tuple_port, node.tuple_ty.element_types)[node.index]
|
|
537
531
|
|
|
538
|
-
def
|
|
539
|
-
|
|
540
|
-
base_ty = node.base_ty.to_hugr(self.ctx)
|
|
541
|
-
extra_args: list[ht.TypeArg] = []
|
|
542
|
-
if isinstance(node.base_ty, NumericType):
|
|
543
|
-
match node.base_ty.kind:
|
|
544
|
-
case NumericType.Kind.Nat:
|
|
545
|
-
base_name = "uint"
|
|
546
|
-
extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)]
|
|
547
|
-
case NumericType.Kind.Int:
|
|
548
|
-
base_name = "int"
|
|
549
|
-
extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)]
|
|
550
|
-
case NumericType.Kind.Float:
|
|
551
|
-
base_name = "f64"
|
|
552
|
-
case kind:
|
|
553
|
-
assert_never(kind)
|
|
554
|
-
else:
|
|
555
|
-
# The only other valid base type is bool
|
|
556
|
-
assert is_bool_type(node.base_ty)
|
|
557
|
-
base_name = "bool"
|
|
558
|
-
if node.array_len is not None:
|
|
559
|
-
op_name = f"result_array_{base_name}"
|
|
560
|
-
size_arg = node.array_len.to_arg().to_hugr(self.ctx)
|
|
561
|
-
extra_args = [size_arg, *extra_args]
|
|
562
|
-
# As `borrow_array`s used by Guppy are linear, we need to clone it (knowing
|
|
563
|
-
# that all elements in it are copyable) to avoid linearity violations when
|
|
564
|
-
# both passing it to the result operation and returning it (as an inout
|
|
565
|
-
# argument).
|
|
566
|
-
value_wire, inout_wire = self.builder.add_op(
|
|
567
|
-
array_clone(base_ty, size_arg), value_wire
|
|
568
|
-
)
|
|
569
|
-
func_ty = FunctionType(
|
|
570
|
-
[
|
|
571
|
-
FuncInput(
|
|
572
|
-
array_type(node.base_ty, node.array_len), InputFlags.Inout
|
|
573
|
-
),
|
|
574
|
-
],
|
|
575
|
-
NoneType(),
|
|
576
|
-
)
|
|
577
|
-
self._update_inout_ports(node.args, iter([inout_wire]), func_ty)
|
|
578
|
-
if is_bool_type(node.base_ty):
|
|
579
|
-
# We need to coerce a read on all the array elements if they are bools.
|
|
580
|
-
array_read = array_read_bool(self.ctx)
|
|
581
|
-
array_read = self.builder.load_function(array_read)
|
|
582
|
-
map_op = array_map(OpaqueBool, size_arg, ht.Bool)
|
|
583
|
-
value_wire = self.builder.add_op(map_op, value_wire, array_read)
|
|
584
|
-
base_ty = ht.Bool
|
|
585
|
-
# Turn `borrow_array` into regular `array`
|
|
586
|
-
value_wire = self.builder.add_op(
|
|
587
|
-
array_to_std_array(base_ty, size_arg), value_wire
|
|
588
|
-
)
|
|
589
|
-
hugr_ty: ht.Type = hugr.std.collections.array.Array(base_ty, size_arg)
|
|
590
|
-
else:
|
|
591
|
-
if is_bool_type(node.base_ty):
|
|
592
|
-
base_ty = ht.Bool
|
|
593
|
-
value_wire = self.builder.add_op(read_bool(), value_wire)
|
|
594
|
-
op_name = f"result_{base_name}"
|
|
595
|
-
hugr_ty = base_ty
|
|
532
|
+
def _visit_result_tag(self, tag: Const, loc: ast.expr) -> str:
|
|
533
|
+
"""Helper method to resolve the tag string in `state_result` expressions.
|
|
596
534
|
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
535
|
+
Also takes care of checking that the tag fits into the maximum tag length.
|
|
536
|
+
Once we go ahead with https://github.com/quantinuum/guppylang/discussions/1299,
|
|
537
|
+
this can be moved into type checking.
|
|
538
|
+
"""
|
|
539
|
+
from guppylang_internals.std._internal.compiler.platform import (
|
|
540
|
+
TAG_MAX_LEN,
|
|
541
|
+
TooLongError,
|
|
542
|
+
)
|
|
600
543
|
|
|
601
|
-
|
|
602
|
-
|
|
544
|
+
is_generic: BoundConstVar | None = None
|
|
545
|
+
match tag:
|
|
546
|
+
case ConstValue(value=str(v)):
|
|
547
|
+
tag_value = v
|
|
548
|
+
case BoundConstVar(idx=idx) as var:
|
|
549
|
+
assert self.ctx.current_mono_args is not None
|
|
550
|
+
match self.ctx.current_mono_args[idx]:
|
|
551
|
+
case ConstArg(const=ConstValue(value=str(v))):
|
|
552
|
+
tag_value = v
|
|
553
|
+
is_generic = var
|
|
554
|
+
case _:
|
|
555
|
+
raise InternalGuppyError("Unexpected tag monomorphization")
|
|
556
|
+
case _:
|
|
557
|
+
raise InternalGuppyError("Unexpected tag value")
|
|
558
|
+
|
|
559
|
+
if len(tag_value.encode("utf-8")) > TAG_MAX_LEN:
|
|
560
|
+
err = TooLongError(loc)
|
|
561
|
+
err.add_sub_diagnostic(TooLongError.Hint(None))
|
|
562
|
+
if is_generic:
|
|
563
|
+
err.add_sub_diagnostic(
|
|
564
|
+
TooLongError.GenericHint(None, is_generic.display_name, tag_value)
|
|
565
|
+
)
|
|
566
|
+
raise GuppyError(err)
|
|
567
|
+
return tag_value
|
|
603
568
|
|
|
604
569
|
def visit_PanicExpr(self, node: PanicExpr) -> Wire:
|
|
605
|
-
|
|
570
|
+
signal = self.visit(node.signal)
|
|
571
|
+
signal_usize = self.builder.add_op(convert_itousize(), signal)
|
|
572
|
+
msg = self.visit(node.msg)
|
|
573
|
+
err = self.builder.add_op(make_error(), signal_usize, msg)
|
|
606
574
|
in_tys = [get_type(e).to_hugr(self.ctx) for e in node.values]
|
|
607
575
|
out_tys = [ty.to_hugr(self.ctx) for ty in type_to_row(get_type(node))]
|
|
608
576
|
args = [self.visit(e) for e in node.values]
|
|
@@ -627,12 +595,13 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
|
627
595
|
return self._pack_returns([], NoneType())
|
|
628
596
|
|
|
629
597
|
def visit_StateResultExpr(self, node: StateResultExpr) -> Wire:
|
|
598
|
+
tag_value = self._visit_result_tag(node.tag_value, node.tag_expr)
|
|
630
599
|
num_qubits_arg = (
|
|
631
600
|
node.array_len.to_arg().to_hugr(self.ctx)
|
|
632
601
|
if node.array_len
|
|
633
602
|
else ht.BoundedNatArg(len(node.args) - 1)
|
|
634
603
|
)
|
|
635
|
-
args = [ht.StringArg(
|
|
604
|
+
args = [ht.StringArg(tag_value), num_qubits_arg]
|
|
636
605
|
sig = ht.FunctionType(
|
|
637
606
|
[standard_array_type(ht.Qubit, num_qubits_arg)],
|
|
638
607
|
[standard_array_type(ht.Qubit, num_qubits_arg)],
|
|
@@ -8,6 +8,7 @@ from guppylang_internals.checker.modifier_checker import non_copyable_front_othe
|
|
|
8
8
|
from guppylang_internals.compiler.cfg_compiler import compile_cfg
|
|
9
9
|
from guppylang_internals.compiler.core import CompilerContext, DFContainer
|
|
10
10
|
from guppylang_internals.compiler.expr_compiler import ExprCompiler
|
|
11
|
+
from guppylang_internals.definition.function import add_unitarity_metadata
|
|
11
12
|
from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode
|
|
12
13
|
from guppylang_internals.std._internal.compiler.array import (
|
|
13
14
|
array_new,
|
|
@@ -56,6 +57,7 @@ def compile_modified_block(
|
|
|
56
57
|
func_builder = dfg.builder.module_root_builder().define_function(
|
|
57
58
|
str(modified_block), hugr_ty.input, hugr_ty.output
|
|
58
59
|
)
|
|
60
|
+
add_unitarity_metadata(func_builder, modified_block.ty.unitary_flags)
|
|
59
61
|
|
|
60
62
|
# compile body
|
|
61
63
|
cfg = compile_cfg(modified_block.cfg, func_builder, func_builder.inputs(), ctx)
|
guppylang_internals/decorator.py
CHANGED
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
+
import pathlib
|
|
4
5
|
from typing import TYPE_CHECKING, ParamSpec, TypeVar, overload
|
|
5
6
|
|
|
6
7
|
from hugr import ops
|
|
7
8
|
from hugr import tys as ht
|
|
8
9
|
|
|
10
|
+
from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
|
|
9
11
|
from guppylang_internals.compiler.core import (
|
|
10
12
|
CompilerContext,
|
|
11
13
|
GlobalConstId,
|
|
@@ -24,6 +26,7 @@ from guppylang_internals.definition.ty import OpaqueTypeDef, TypeDef
|
|
|
24
26
|
from guppylang_internals.definition.wasm import RawWasmFunctionDef
|
|
25
27
|
from guppylang_internals.dummy_decorator import _dummy_custom_decorator, sphinx_running
|
|
26
28
|
from guppylang_internals.engine import DEF_STORE
|
|
29
|
+
from guppylang_internals.error import GuppyError
|
|
27
30
|
from guppylang_internals.std._internal.checker import WasmCallChecker
|
|
28
31
|
from guppylang_internals.std._internal.compiler.wasm import (
|
|
29
32
|
WasmModuleCallCompiler,
|
|
@@ -39,6 +42,14 @@ from guppylang_internals.tys.ty import (
|
|
|
39
42
|
InputFlags,
|
|
40
43
|
NoneType,
|
|
41
44
|
NumericType,
|
|
45
|
+
UnitaryFlags,
|
|
46
|
+
)
|
|
47
|
+
from guppylang_internals.wasm_util import (
|
|
48
|
+
ConcreteWasmModule,
|
|
49
|
+
WasmFileNotFound,
|
|
50
|
+
WasmFunctionNotInFile,
|
|
51
|
+
WasmSignatureError,
|
|
52
|
+
decode_wasm_functions,
|
|
42
53
|
)
|
|
43
54
|
|
|
44
55
|
if TYPE_CHECKING:
|
|
@@ -47,7 +58,6 @@ if TYPE_CHECKING:
|
|
|
47
58
|
from collections.abc import Callable, Sequence
|
|
48
59
|
from types import FrameType
|
|
49
60
|
|
|
50
|
-
from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
|
|
51
61
|
from guppylang_internals.tys.arg import Argument
|
|
52
62
|
from guppylang_internals.tys.param import Parameter
|
|
53
63
|
from guppylang_internals.tys.subst import Inst
|
|
@@ -75,6 +85,7 @@ def custom_function(
|
|
|
75
85
|
higher_order_value: bool = True,
|
|
76
86
|
name: str = "",
|
|
77
87
|
signature: FunctionType | None = None,
|
|
88
|
+
unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
|
|
78
89
|
) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
|
|
79
90
|
"""Decorator to add custom typing or compilation behaviour to function decls.
|
|
80
91
|
|
|
@@ -86,6 +97,8 @@ def custom_function(
|
|
|
86
97
|
|
|
87
98
|
def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
|
|
88
99
|
call_checker = checker or DefaultCallChecker()
|
|
100
|
+
if signature is not None:
|
|
101
|
+
object.__setattr__(signature, "unitary_flags", unitary_flags)
|
|
89
102
|
func = RawCustomFunctionDef(
|
|
90
103
|
DefId.fresh(),
|
|
91
104
|
name or f.__name__,
|
|
@@ -95,6 +108,7 @@ def custom_function(
|
|
|
95
108
|
compiler or NotImplementedCallCompiler(),
|
|
96
109
|
higher_order_value,
|
|
97
110
|
signature,
|
|
111
|
+
unitary_flags,
|
|
98
112
|
)
|
|
99
113
|
DEF_STORE.register_def(func, get_calling_frame())
|
|
100
114
|
return GuppyFunctionDefinition(func)
|
|
@@ -108,6 +122,7 @@ def hugr_op(
|
|
|
108
122
|
higher_order_value: bool = True,
|
|
109
123
|
name: str = "",
|
|
110
124
|
signature: FunctionType | None = None,
|
|
125
|
+
unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
|
|
111
126
|
) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
|
|
112
127
|
"""Decorator to annotate function declarations as HUGR ops.
|
|
113
128
|
|
|
@@ -119,7 +134,14 @@ def hugr_op(
|
|
|
119
134
|
value.
|
|
120
135
|
name: The name of the function.
|
|
121
136
|
"""
|
|
122
|
-
return custom_function(
|
|
137
|
+
return custom_function(
|
|
138
|
+
OpCompiler(op),
|
|
139
|
+
checker,
|
|
140
|
+
higher_order_value,
|
|
141
|
+
name,
|
|
142
|
+
signature,
|
|
143
|
+
unitary_flags=unitary_flags,
|
|
144
|
+
)
|
|
123
145
|
|
|
124
146
|
|
|
125
147
|
def extend_type(defn: TypeDef, return_class: bool = False) -> Callable[[type], type]:
|
|
@@ -188,6 +210,12 @@ def custom_type(
|
|
|
188
210
|
def wasm_module(
|
|
189
211
|
filename: str,
|
|
190
212
|
) -> Callable[[builtins.type[T]], GuppyDefinition]:
|
|
213
|
+
wasm_file = pathlib.Path(filename)
|
|
214
|
+
if wasm_file.is_file():
|
|
215
|
+
wasm_sigs = decode_wasm_functions(filename)
|
|
216
|
+
else:
|
|
217
|
+
raise GuppyError(WasmFileNotFound(None, filename))
|
|
218
|
+
|
|
191
219
|
def type_def_wrapper(
|
|
192
220
|
id: DefId,
|
|
193
221
|
name: str,
|
|
@@ -198,10 +226,19 @@ def wasm_module(
|
|
|
198
226
|
assert config is None
|
|
199
227
|
return WasmModuleTypeDef(id, name, defined_at, wasm_file)
|
|
200
228
|
|
|
201
|
-
|
|
202
|
-
type_def_wrapper,
|
|
229
|
+
decorator = ext_module_decorator(
|
|
230
|
+
type_def_wrapper,
|
|
231
|
+
WasmModuleInitCompiler(),
|
|
232
|
+
WasmModuleDiscardCompiler(),
|
|
233
|
+
True,
|
|
234
|
+
wasm_sigs,
|
|
203
235
|
)
|
|
204
|
-
|
|
236
|
+
|
|
237
|
+
def inner_fun(ty: builtins.type[T]) -> GuppyDefinition:
|
|
238
|
+
decorator_inner = decorator(filename, None)
|
|
239
|
+
return decorator_inner(ty)
|
|
240
|
+
|
|
241
|
+
return inner_fun
|
|
205
242
|
|
|
206
243
|
|
|
207
244
|
def ext_module_decorator(
|
|
@@ -209,9 +246,9 @@ def ext_module_decorator(
|
|
|
209
246
|
init_compiler: CustomInoutCallCompiler,
|
|
210
247
|
discard_compiler: CustomInoutCallCompiler,
|
|
211
248
|
init_arg: bool, # Whether the init function should take a nat argument
|
|
249
|
+
wasm_sigs: ConcreteWasmModule
|
|
250
|
+
| None = None, # For @wasm_module, we must be passed a parsed wasm file
|
|
212
251
|
) -> Callable[[str, str | None], Callable[[builtins.type[T]], GuppyDefinition]]:
|
|
213
|
-
from guppylang.defs import GuppyDefinition
|
|
214
|
-
|
|
215
252
|
def fun(
|
|
216
253
|
filename: str, module: str | None
|
|
217
254
|
) -> Callable[[builtins.type[T]], GuppyDefinition]:
|
|
@@ -231,6 +268,47 @@ def ext_module_decorator(
|
|
|
231
268
|
for val in cls.__dict__.values():
|
|
232
269
|
if isinstance(val, GuppyDefinition):
|
|
233
270
|
DEF_STORE.register_impl(ext_module.id, val.wrapped.name, val.id)
|
|
271
|
+
wasm_def: RawWasmFunctionDef
|
|
272
|
+
if isinstance(val, GuppyFunctionDefinition) and isinstance(
|
|
273
|
+
val.wrapped, RawWasmFunctionDef
|
|
274
|
+
):
|
|
275
|
+
wasm_def = val.wrapped
|
|
276
|
+
else:
|
|
277
|
+
continue
|
|
278
|
+
# wasm_sigs should only have not been provided if we have
|
|
279
|
+
# defined @wasm functions in a class which didn't use the
|
|
280
|
+
# @wasm_module decorator.
|
|
281
|
+
assert wasm_sigs is not None
|
|
282
|
+
if wasm_def.wasm_index is not None:
|
|
283
|
+
name = wasm_sigs.functions[wasm_def.wasm_index]
|
|
284
|
+
assert name in wasm_sigs.function_sigs
|
|
285
|
+
wasm_sig_or_err = wasm_sigs.function_sigs[name]
|
|
286
|
+
else:
|
|
287
|
+
if wasm_def.name in wasm_sigs.function_sigs:
|
|
288
|
+
wasm_sig_or_err = wasm_sigs.function_sigs[wasm_def.name]
|
|
289
|
+
else:
|
|
290
|
+
raise GuppyError(
|
|
291
|
+
WasmFunctionNotInFile(
|
|
292
|
+
wasm_def.defined_at,
|
|
293
|
+
wasm_def.name,
|
|
294
|
+
).add_sub_diagnostic(
|
|
295
|
+
WasmFunctionNotInFile.WasmFileNote(
|
|
296
|
+
None,
|
|
297
|
+
wasm_sigs.filename,
|
|
298
|
+
)
|
|
299
|
+
)
|
|
300
|
+
)
|
|
301
|
+
if isinstance(wasm_sig_or_err, FunctionType):
|
|
302
|
+
DEF_STORE.register_wasm_function(wasm_def.id, wasm_sig_or_err)
|
|
303
|
+
elif isinstance(wasm_sig_or_err, str):
|
|
304
|
+
raise GuppyError(
|
|
305
|
+
WasmSignatureError(
|
|
306
|
+
None, wasm_def.name, filename
|
|
307
|
+
).add_sub_diagnostic(
|
|
308
|
+
WasmSignatureError.Message(None, wasm_sig_or_err)
|
|
309
|
+
)
|
|
310
|
+
)
|
|
311
|
+
|
|
234
312
|
# Add a constructor to the class
|
|
235
313
|
if init_arg:
|
|
236
314
|
init_fn_ty = FunctionType(
|
|
@@ -315,6 +393,7 @@ def wasm_helper(fn_id: int | None, f: Callable[P, T]) -> GuppyFunctionDefinition
|
|
|
315
393
|
WasmModuleCallCompiler(f.__name__, fn_id),
|
|
316
394
|
True,
|
|
317
395
|
signature=None,
|
|
396
|
+
wasm_index=fn_id,
|
|
318
397
|
)
|
|
319
398
|
DEF_STORE.register_def(func, get_calling_frame())
|
|
320
399
|
return GuppyFunctionDefinition(func)
|
|
@@ -44,6 +44,7 @@ from guppylang_internals.tys.ty import (
|
|
|
44
44
|
InputFlags,
|
|
45
45
|
NoneType,
|
|
46
46
|
Type,
|
|
47
|
+
UnitaryFlags,
|
|
47
48
|
type_to_row,
|
|
48
49
|
)
|
|
49
50
|
|
|
@@ -114,6 +115,8 @@ class RawCustomFunctionDef(ParsableDef):
|
|
|
114
115
|
|
|
115
116
|
signature: FunctionType | None
|
|
116
117
|
|
|
118
|
+
unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags)
|
|
119
|
+
|
|
117
120
|
description: str = field(default="function", init=False)
|
|
118
121
|
|
|
119
122
|
def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
|
|
@@ -136,6 +139,7 @@ class RawCustomFunctionDef(ParsableDef):
|
|
|
136
139
|
raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
|
|
137
140
|
sig = self.signature or self._get_signature(func_ast, globals)
|
|
138
141
|
ty = sig or FunctionType([], NoneType())
|
|
142
|
+
ty = ty.with_unitary_flags(self.unitary_flags)
|
|
139
143
|
return CustomFunctionDef(
|
|
140
144
|
self.id,
|
|
141
145
|
self.name,
|
|
@@ -34,7 +34,7 @@ from guppylang_internals.nodes import GlobalCall
|
|
|
34
34
|
from guppylang_internals.span import SourceMap
|
|
35
35
|
from guppylang_internals.tys.param import Parameter
|
|
36
36
|
from guppylang_internals.tys.subst import Inst, Subst
|
|
37
|
-
from guppylang_internals.tys.ty import Type
|
|
37
|
+
from guppylang_internals.tys.ty import Type, UnitaryFlags
|
|
38
38
|
|
|
39
39
|
|
|
40
40
|
@dataclass(frozen=True)
|
|
@@ -65,10 +65,14 @@ class RawFunctionDecl(ParsableDef):
|
|
|
65
65
|
python_func: PyFunc
|
|
66
66
|
description: str = field(default="function", init=False)
|
|
67
67
|
|
|
68
|
+
unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True)
|
|
69
|
+
|
|
68
70
|
def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
|
|
69
71
|
"""Parses and checks the user-provided signature of the function."""
|
|
70
72
|
func_ast, docstring = parse_py_func(self.python_func, sources)
|
|
71
|
-
ty = check_signature(
|
|
73
|
+
ty = check_signature(
|
|
74
|
+
func_ast, globals, self.id, unitary_flags=self.unitary_flags
|
|
75
|
+
)
|
|
72
76
|
if not has_empty_body(func_ast):
|
|
73
77
|
raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
|
|
74
78
|
# Make sure we won't need monomorphization to compile this declaration
|
|
@@ -43,7 +43,7 @@ from guppylang_internals.error import GuppyError
|
|
|
43
43
|
from guppylang_internals.nodes import GlobalCall
|
|
44
44
|
from guppylang_internals.span import SourceMap
|
|
45
45
|
from guppylang_internals.tys.subst import Inst, Subst
|
|
46
|
-
from guppylang_internals.tys.ty import FunctionType, Type, type_to_row
|
|
46
|
+
from guppylang_internals.tys.ty import FunctionType, Type, UnitaryFlags, type_to_row
|
|
47
47
|
|
|
48
48
|
if TYPE_CHECKING:
|
|
49
49
|
from guppylang_internals.tys.param import Parameter
|
|
@@ -70,10 +70,14 @@ class RawFunctionDef(ParsableDef):
|
|
|
70
70
|
|
|
71
71
|
description: str = field(default="function", init=False)
|
|
72
72
|
|
|
73
|
+
unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True)
|
|
74
|
+
|
|
73
75
|
def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
|
|
74
76
|
"""Parses and checks the user-provided signature of the function."""
|
|
75
77
|
func_ast, docstring = parse_py_func(self.python_func, sources)
|
|
76
|
-
ty = check_signature(
|
|
78
|
+
ty = check_signature(
|
|
79
|
+
func_ast, globals, self.id, unitary_flags=self.unitary_flags
|
|
80
|
+
)
|
|
77
81
|
return ParsedFunctionDef(self.id, self.name, func_ast, ty, docstring)
|
|
78
82
|
|
|
79
83
|
|
|
@@ -173,6 +177,7 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
|
|
|
173
177
|
func_def = module.module_root_builder().define_function(
|
|
174
178
|
self.name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params
|
|
175
179
|
)
|
|
180
|
+
add_unitarity_metadata(func_def, self.ty.unitary_flags)
|
|
176
181
|
return CompiledFunctionDef(
|
|
177
182
|
self.id,
|
|
178
183
|
self.name,
|
|
@@ -300,3 +305,8 @@ def parse_source(source_lines: list[str], line_offset: int) -> tuple[str, ast.AS
|
|
|
300
305
|
else:
|
|
301
306
|
node = ast.parse(source).body[0]
|
|
302
307
|
return source, node, line_offset
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def add_unitarity_metadata(func: hf.Function, flags: UnitaryFlags) -> None:
|
|
311
|
+
"""Stores unitarity annotations in the metadate of a Hugr function definition."""
|
|
312
|
+
func.metadata["unitary"] = flags.value
|
|
@@ -370,6 +370,7 @@ def _signature_from_circuit(
|
|
|
370
370
|
use_arrays: bool = False,
|
|
371
371
|
) -> FunctionType:
|
|
372
372
|
"""Helper function for inferring a function signature from a pytket circuit."""
|
|
373
|
+
# May want to set proper unitary flags in the future.
|
|
373
374
|
try:
|
|
374
375
|
import pytket
|
|
375
376
|
|
|
@@ -273,13 +273,16 @@ class CheckedStructDef(TypeDef, CompiledDef):
|
|
|
273
273
|
|
|
274
274
|
constructor_sig = FunctionType(
|
|
275
275
|
inputs=[
|
|
276
|
-
FuncInput(
|
|
276
|
+
FuncInput(
|
|
277
|
+
f.ty,
|
|
278
|
+
InputFlags.Owned if f.ty.linear else InputFlags.NoFlags,
|
|
279
|
+
f.name,
|
|
280
|
+
)
|
|
277
281
|
for f in self.fields
|
|
278
282
|
],
|
|
279
283
|
output=StructType(
|
|
280
284
|
defn=self, args=[p.to_bound(i) for i, p in enumerate(self.params)]
|
|
281
285
|
),
|
|
282
|
-
input_names=[f.name for f in self.fields],
|
|
283
286
|
params=self.params,
|
|
284
287
|
)
|
|
285
288
|
constructor_def = CustomFunctionDef(
|
|
@@ -317,7 +320,7 @@ def parse_py_class(
|
|
|
317
320
|
raise GuppyError(UnknownSourceError(None, cls))
|
|
318
321
|
|
|
319
322
|
# We can't rely on `inspect.getsourcelines` since it doesn't work properly for
|
|
320
|
-
# classes prior to Python 3.13. See https://github.com/
|
|
323
|
+
# classes prior to Python 3.13. See https://github.com/quantinuum/guppylang/issues/1107.
|
|
321
324
|
# Instead, we reproduce the behaviour of Python >= 3.13 using the `__firstlineno__`
|
|
322
325
|
# attribute. See https://github.com/python/cpython/blob/3.13/Lib/inspect.py#L1052.
|
|
323
326
|
# In the decorator, we make sure that `__firstlineno__` is set, even if we're not
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
1
2
|
from typing import TYPE_CHECKING
|
|
2
3
|
|
|
3
4
|
from guppylang_internals.ast_util import AstNode
|
|
@@ -9,7 +10,8 @@ from guppylang_internals.definition.custom import (
|
|
|
9
10
|
CustomFunctionDef,
|
|
10
11
|
RawCustomFunctionDef,
|
|
11
12
|
)
|
|
12
|
-
from guppylang_internals.
|
|
13
|
+
from guppylang_internals.engine import DEF_STORE
|
|
14
|
+
from guppylang_internals.error import GuppyError, GuppyTypeError
|
|
13
15
|
from guppylang_internals.span import SourceMap
|
|
14
16
|
from guppylang_internals.tys.builtin import wasm_module_name
|
|
15
17
|
from guppylang_internals.tys.ty import (
|
|
@@ -21,24 +23,35 @@ from guppylang_internals.tys.ty import (
|
|
|
21
23
|
TupleType,
|
|
22
24
|
Type,
|
|
23
25
|
)
|
|
26
|
+
from guppylang_internals.wasm_util import WasmSigMismatchError
|
|
24
27
|
|
|
25
28
|
if TYPE_CHECKING:
|
|
26
29
|
from guppylang_internals.checker.core import Globals
|
|
27
30
|
|
|
28
31
|
|
|
32
|
+
@dataclass(frozen=True)
|
|
29
33
|
class RawWasmFunctionDef(RawCustomFunctionDef):
|
|
30
|
-
|
|
34
|
+
# If a function is specified in the @wasm decorator by its index in the wasm
|
|
35
|
+
# file, record what the index was.
|
|
36
|
+
wasm_index: int | None = field(default=None)
|
|
37
|
+
|
|
38
|
+
def sanitise_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
|
|
31
39
|
# Place to highlight in error messages
|
|
32
|
-
match fun_ty.inputs
|
|
33
|
-
case FuncInput(ty=ty, flags=InputFlags.Inout) if wasm_module_name(
|
|
40
|
+
match fun_ty.inputs:
|
|
41
|
+
case [FuncInput(ty=ty, flags=InputFlags.Inout), *args] if wasm_module_name(
|
|
34
42
|
ty
|
|
35
43
|
) is not None:
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
44
|
+
for inp in args:
|
|
45
|
+
if not self.is_type_wasmable(inp.ty):
|
|
46
|
+
raise GuppyError(UnWasmableType(loc, inp.ty))
|
|
47
|
+
case [FuncInput(ty=ty), *_]:
|
|
48
|
+
raise GuppyError(
|
|
49
|
+
FirstArgNotModule(loc).add_sub_diagnostic(
|
|
50
|
+
FirstArgNotModule.GotOtherType(loc, ty)
|
|
51
|
+
)
|
|
52
|
+
)
|
|
53
|
+
case []:
|
|
54
|
+
raise GuppyError(FirstArgNotModule(loc))
|
|
42
55
|
if not self.is_type_wasmable(fun_ty.output):
|
|
43
56
|
match fun_ty.output:
|
|
44
57
|
case NoneType():
|
|
@@ -46,6 +59,23 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
|
|
|
46
59
|
case _:
|
|
47
60
|
raise GuppyError(UnWasmableType(loc, fun_ty.output))
|
|
48
61
|
|
|
62
|
+
def validate_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
|
|
63
|
+
type_in_wasm: FunctionType = DEF_STORE.wasm_functions[self.id]
|
|
64
|
+
assert type_in_wasm is not None
|
|
65
|
+
# Drop the first arg because it should be "self"
|
|
66
|
+
expected_type = FunctionType(fun_ty.inputs[1:], fun_ty.output)
|
|
67
|
+
|
|
68
|
+
if expected_type != type_in_wasm:
|
|
69
|
+
raise GuppyTypeError(
|
|
70
|
+
WasmSigMismatchError(loc)
|
|
71
|
+
.add_sub_diagnostic(
|
|
72
|
+
WasmSigMismatchError.Declaration(None, declared=str(expected_type))
|
|
73
|
+
)
|
|
74
|
+
.add_sub_diagnostic(
|
|
75
|
+
WasmSigMismatchError.Actual(None, actual=str(type_in_wasm))
|
|
76
|
+
)
|
|
77
|
+
)
|
|
78
|
+
|
|
49
79
|
def is_type_wasmable(self, ty: Type) -> bool:
|
|
50
80
|
match ty:
|
|
51
81
|
case NumericType():
|
|
@@ -57,5 +87,7 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
|
|
|
57
87
|
|
|
58
88
|
def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
|
|
59
89
|
parsed = super().parse(globals, sources)
|
|
90
|
+
assert parsed.defined_at is not None
|
|
60
91
|
self.sanitise_type(parsed.defined_at, parsed.ty)
|
|
92
|
+
self.validate_type(parsed.defined_at, parsed.ty)
|
|
61
93
|
return parsed
|
guppylang_internals/engine.py
CHANGED
|
@@ -46,6 +46,7 @@ from guppylang_internals.tys.builtin import (
|
|
|
46
46
|
string_type_def,
|
|
47
47
|
tuple_type_def,
|
|
48
48
|
)
|
|
49
|
+
from guppylang_internals.tys.ty import FunctionType
|
|
49
50
|
|
|
50
51
|
if TYPE_CHECKING:
|
|
51
52
|
from guppylang_internals.compiler.core import MonoDefId
|
|
@@ -87,6 +88,7 @@ class DefinitionStore:
|
|
|
87
88
|
raw_defs: dict[DefId, RawDef]
|
|
88
89
|
impls: defaultdict[DefId, dict[str, DefId]]
|
|
89
90
|
impl_parents: dict[DefId, DefId]
|
|
91
|
+
wasm_functions: dict[DefId, FunctionType]
|
|
90
92
|
frames: dict[DefId, FrameType]
|
|
91
93
|
sources: SourceMap
|
|
92
94
|
|
|
@@ -96,6 +98,7 @@ class DefinitionStore:
|
|
|
96
98
|
self.impl_parents = {}
|
|
97
99
|
self.frames = {}
|
|
98
100
|
self.sources = SourceMap()
|
|
101
|
+
self.wasm_functions = {}
|
|
99
102
|
|
|
100
103
|
def register_def(self, defn: RawDef, frame: FrameType | None) -> None:
|
|
101
104
|
self.raw_defs[defn.id] = defn
|
|
@@ -123,6 +126,9 @@ class DefinitionStore:
|
|
|
123
126
|
assert frame is not None
|
|
124
127
|
self.frames[impl_id] = frame
|
|
125
128
|
|
|
129
|
+
def register_wasm_function(self, fn_id: DefId, sig: FunctionType) -> None:
|
|
130
|
+
self.wasm_functions[fn_id] = sig
|
|
131
|
+
|
|
126
132
|
|
|
127
133
|
DEF_STORE: DefinitionStore = DefinitionStore()
|
|
128
134
|
|
|
@@ -263,8 +269,8 @@ class CompilationEngine:
|
|
|
263
269
|
and isinstance(compiled_def, CompiledCallableDef)
|
|
264
270
|
and not isinstance(graph.hugr[compiled_def.hugr_node].op, ops.FuncDecl)
|
|
265
271
|
):
|
|
266
|
-
# if compiling a region set it as the HUGR entrypoint
|
|
267
|
-
#
|
|
272
|
+
# if compiling a region set it as the HUGR entrypoint can be
|
|
273
|
+
# loosened after https://github.com/quantinuum/hugr/issues/2501 is fixed
|
|
268
274
|
graph.hugr.entrypoint = compiled_def.hugr_node
|
|
269
275
|
|
|
270
276
|
# TODO: Currently the list of extensions is manually managed by the user.
|
|
@@ -278,7 +284,7 @@ class CompilationEngine:
|
|
|
278
284
|
guppylang_internals.compiler.hugr_extension.EXTENSION,
|
|
279
285
|
*self.additional_extensions,
|
|
280
286
|
]
|
|
281
|
-
# TODO replace with computed extensions after https://github.com/
|
|
287
|
+
# TODO replace with computed extensions after https://github.com/quantinuum/guppylang/issues/550
|
|
282
288
|
all_used_extensions = [
|
|
283
289
|
*extensions,
|
|
284
290
|
hugr.std.prelude.PRELUDE_EXTENSION,
|