guppylang-internals 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/cfg/builder.py +20 -2
- guppylang_internals/cfg/cfg.py +3 -0
- guppylang_internals/checker/cfg_checker.py +6 -0
- guppylang_internals/checker/core.py +1 -2
- guppylang_internals/checker/errors/linearity.py +6 -2
- guppylang_internals/checker/errors/wasm.py +7 -4
- guppylang_internals/checker/expr_checker.py +39 -19
- guppylang_internals/checker/func_checker.py +17 -13
- guppylang_internals/checker/linearity_checker.py +2 -10
- guppylang_internals/checker/modifier_checker.py +6 -2
- guppylang_internals/checker/unitary_checker.py +132 -0
- guppylang_internals/compiler/cfg_compiler.py +7 -6
- guppylang_internals/compiler/core.py +5 -5
- guppylang_internals/compiler/expr_compiler.py +72 -81
- guppylang_internals/compiler/modifier_compiler.py +5 -0
- guppylang_internals/decorator.py +88 -7
- guppylang_internals/definition/custom.py +4 -0
- guppylang_internals/definition/declaration.py +6 -2
- guppylang_internals/definition/function.py +26 -3
- guppylang_internals/definition/metadata.py +87 -0
- guppylang_internals/definition/overloaded.py +11 -2
- guppylang_internals/definition/pytket_circuits.py +7 -2
- guppylang_internals/definition/struct.py +6 -3
- guppylang_internals/definition/wasm.py +42 -10
- guppylang_internals/diagnostic.py +72 -15
- guppylang_internals/engine.py +10 -13
- guppylang_internals/nodes.py +55 -24
- guppylang_internals/std/_internal/checker.py +13 -108
- guppylang_internals/std/_internal/compiler/array.py +37 -2
- guppylang_internals/std/_internal/compiler/either.py +14 -2
- guppylang_internals/std/_internal/compiler/list.py +1 -1
- guppylang_internals/std/_internal/compiler/platform.py +153 -0
- guppylang_internals/std/_internal/compiler/prelude.py +12 -4
- guppylang_internals/std/_internal/compiler/tket_bool.py +1 -6
- guppylang_internals/std/_internal/compiler/tket_exts.py +4 -5
- guppylang_internals/std/_internal/debug.py +18 -9
- guppylang_internals/std/_internal/util.py +1 -1
- guppylang_internals/tracing/object.py +14 -0
- guppylang_internals/tys/errors.py +23 -1
- guppylang_internals/tys/parsing.py +3 -3
- guppylang_internals/tys/printing.py +2 -8
- guppylang_internals/tys/qubit.py +37 -2
- guppylang_internals/tys/ty.py +60 -64
- guppylang_internals/wasm_util.py +129 -0
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/METADATA +5 -4
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/RECORD +49 -45
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/WHEEL +1 -1
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -185,7 +185,7 @@ class CompilerContext(ToHugrContext):
|
|
|
185
185
|
# make the call to `ENGINE.get_checked` below fail. For now, let's just short-
|
|
186
186
|
# cut if the function doesn't take any generic params (as is the case for all
|
|
187
187
|
# nested functions).
|
|
188
|
-
# See https://github.com/
|
|
188
|
+
# See https://github.com/quantinuum/guppylang/issues/1032
|
|
189
189
|
if (def_id, ()) in self.compiled:
|
|
190
190
|
assert type_args == []
|
|
191
191
|
return self.compiled[def_id, ()], type_args
|
|
@@ -247,7 +247,7 @@ class CompilerContext(ToHugrContext):
|
|
|
247
247
|
|
|
248
248
|
# Insert explicit drops for affine types
|
|
249
249
|
# TODO: This is a quick workaround until we can properly insert these drops
|
|
250
|
-
# during linearity checking. See https://github.com/
|
|
250
|
+
# during linearity checking. See https://github.com/quantinuum/guppylang/issues/1082
|
|
251
251
|
insert_drops(self.module.hugr)
|
|
252
252
|
|
|
253
253
|
return entry_compiled
|
|
@@ -659,7 +659,7 @@ def track_hugr_side_effects() -> Iterator[None]:
|
|
|
659
659
|
|
|
660
660
|
def qualified_name(type_def: he.TypeDef) -> str:
|
|
661
661
|
"""Returns the qualified name of a Hugr extension type.
|
|
662
|
-
TODO: Remove once upstreamed, see https://github.com/
|
|
662
|
+
TODO: Remove once upstreamed, see https://github.com/quantinuum/hugr/issues/2426
|
|
663
663
|
"""
|
|
664
664
|
if type_def._extension is not None:
|
|
665
665
|
return f"{type_def._extension.name}.{type_def.name}"
|
|
@@ -711,14 +711,14 @@ def drop_op(ty: ht.Type) -> ops.ExtOp:
|
|
|
711
711
|
def insert_drops(hugr: Hugr[OpVarCov]) -> None:
|
|
712
712
|
"""Inserts explicit drop ops for unconnected ports into the Hugr.
|
|
713
713
|
TODO: This is a quick workaround until we can properly insert these drops during
|
|
714
|
-
linearity checking. See https://github.com/
|
|
714
|
+
linearity checking. See https://github.com/quantinuum/guppylang/issues/1082
|
|
715
715
|
"""
|
|
716
716
|
for node in hugr:
|
|
717
717
|
data = hugr[node]
|
|
718
718
|
# Iterating over `node.outputs()` doesn't work reliably since it sometimes
|
|
719
719
|
# raises an `IncompleteOp` exception. Instead, we query the number of out ports
|
|
720
720
|
# and look them up by index. However, this method is *also* broken when
|
|
721
|
-
# inspecting `FuncDefn` nodes due to https://github.com/
|
|
721
|
+
# inspecting `FuncDefn` nodes due to https://github.com/quantinuum/hugr/issues/2438.
|
|
722
722
|
if isinstance(data.op, ops.FuncDefn):
|
|
723
723
|
continue
|
|
724
724
|
for i in range(hugr.num_out_ports(node)):
|
|
@@ -15,7 +15,6 @@ from hugr import val as hv
|
|
|
15
15
|
from hugr.build import function as hf
|
|
16
16
|
from hugr.build.cond_loop import Conditional
|
|
17
17
|
from hugr.build.dfg import DP, DfBase
|
|
18
|
-
from typing_extensions import assert_never
|
|
19
18
|
|
|
20
19
|
from guppylang_internals.ast_util import AstNode, AstVisitor, get_type
|
|
21
20
|
from guppylang_internals.cfg.builder import tmp_vars
|
|
@@ -23,7 +22,6 @@ from guppylang_internals.checker.core import Variable, contains_subscript
|
|
|
23
22
|
from guppylang_internals.checker.errors.generic import UnsupportedError
|
|
24
23
|
from guppylang_internals.compiler.core import (
|
|
25
24
|
DEBUG_EXTENSION,
|
|
26
|
-
RESULT_EXTENSION,
|
|
27
25
|
CompilerBase,
|
|
28
26
|
CompilerContext,
|
|
29
27
|
DFContainer,
|
|
@@ -53,7 +51,6 @@ from guppylang_internals.nodes import (
|
|
|
53
51
|
PanicExpr,
|
|
54
52
|
PartialApply,
|
|
55
53
|
PlaceNode,
|
|
56
|
-
ResultExpr,
|
|
57
54
|
StateResultExpr,
|
|
58
55
|
SubscriptAccessAndDrop,
|
|
59
56
|
TensorCall,
|
|
@@ -66,7 +63,6 @@ from guppylang_internals.std._internal.compiler.arithmetic import (
|
|
|
66
63
|
convert_itousize,
|
|
67
64
|
)
|
|
68
65
|
from guppylang_internals.std._internal.compiler.array import (
|
|
69
|
-
array_clone,
|
|
70
66
|
array_map,
|
|
71
67
|
array_new,
|
|
72
68
|
array_to_std_array,
|
|
@@ -80,8 +76,8 @@ from guppylang_internals.std._internal.compiler.list import (
|
|
|
80
76
|
list_new,
|
|
81
77
|
)
|
|
82
78
|
from guppylang_internals.std._internal.compiler.prelude import (
|
|
83
|
-
build_error,
|
|
84
79
|
build_panic,
|
|
80
|
+
make_error,
|
|
85
81
|
panic,
|
|
86
82
|
)
|
|
87
83
|
from guppylang_internals.std._internal.compiler.tket_bool import (
|
|
@@ -93,14 +89,12 @@ from guppylang_internals.std._internal.compiler.tket_bool import (
|
|
|
93
89
|
)
|
|
94
90
|
from guppylang_internals.tys.arg import ConstArg
|
|
95
91
|
from guppylang_internals.tys.builtin import (
|
|
96
|
-
array_type,
|
|
97
92
|
bool_type,
|
|
98
93
|
get_element_type,
|
|
99
94
|
int_type,
|
|
100
|
-
is_bool_type,
|
|
101
95
|
is_frozenarray_type,
|
|
102
96
|
)
|
|
103
|
-
from guppylang_internals.tys.const import ConstValue
|
|
97
|
+
from guppylang_internals.tys.const import BoundConstVar, Const, ConstValue
|
|
104
98
|
from guppylang_internals.tys.subst import Inst
|
|
105
99
|
from guppylang_internals.tys.ty import (
|
|
106
100
|
BoundTypeVar,
|
|
@@ -331,14 +325,7 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
|
331
325
|
|
|
332
326
|
def _pack_returns(self, returns: Sequence[Wire], return_ty: Type) -> Wire:
|
|
333
327
|
"""Groups function return values into a tuple"""
|
|
334
|
-
|
|
335
|
-
types = type_to_row(return_ty)
|
|
336
|
-
assert len(returns) == len(types)
|
|
337
|
-
return self._pack_tuple(returns, types)
|
|
338
|
-
assert (
|
|
339
|
-
len(returns) == 1
|
|
340
|
-
), f"Expected a single return value. Got {returns}. return type {return_ty}"
|
|
341
|
-
return returns[0]
|
|
328
|
+
return pack_returns(returns, return_ty, self.builder, self.ctx)
|
|
342
329
|
|
|
343
330
|
def _update_inout_ports(
|
|
344
331
|
self,
|
|
@@ -535,74 +522,48 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
|
535
522
|
tuple_port = self.visit(node.value)
|
|
536
523
|
return self._unpack_tuple(tuple_port, node.tuple_ty.element_types)[node.index]
|
|
537
524
|
|
|
538
|
-
def
|
|
539
|
-
|
|
540
|
-
base_ty = node.base_ty.to_hugr(self.ctx)
|
|
541
|
-
extra_args: list[ht.TypeArg] = []
|
|
542
|
-
if isinstance(node.base_ty, NumericType):
|
|
543
|
-
match node.base_ty.kind:
|
|
544
|
-
case NumericType.Kind.Nat:
|
|
545
|
-
base_name = "uint"
|
|
546
|
-
extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)]
|
|
547
|
-
case NumericType.Kind.Int:
|
|
548
|
-
base_name = "int"
|
|
549
|
-
extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)]
|
|
550
|
-
case NumericType.Kind.Float:
|
|
551
|
-
base_name = "f64"
|
|
552
|
-
case kind:
|
|
553
|
-
assert_never(kind)
|
|
554
|
-
else:
|
|
555
|
-
# The only other valid base type is bool
|
|
556
|
-
assert is_bool_type(node.base_ty)
|
|
557
|
-
base_name = "bool"
|
|
558
|
-
if node.array_len is not None:
|
|
559
|
-
op_name = f"result_array_{base_name}"
|
|
560
|
-
size_arg = node.array_len.to_arg().to_hugr(self.ctx)
|
|
561
|
-
extra_args = [size_arg, *extra_args]
|
|
562
|
-
# As `borrow_array`s used by Guppy are linear, we need to clone it (knowing
|
|
563
|
-
# that all elements in it are copyable) to avoid linearity violations when
|
|
564
|
-
# both passing it to the result operation and returning it (as an inout
|
|
565
|
-
# argument).
|
|
566
|
-
value_wire, inout_wire = self.builder.add_op(
|
|
567
|
-
array_clone(base_ty, size_arg), value_wire
|
|
568
|
-
)
|
|
569
|
-
func_ty = FunctionType(
|
|
570
|
-
[
|
|
571
|
-
FuncInput(
|
|
572
|
-
array_type(node.base_ty, node.array_len), InputFlags.Inout
|
|
573
|
-
),
|
|
574
|
-
],
|
|
575
|
-
NoneType(),
|
|
576
|
-
)
|
|
577
|
-
self._update_inout_ports(node.args, iter([inout_wire]), func_ty)
|
|
578
|
-
if is_bool_type(node.base_ty):
|
|
579
|
-
# We need to coerce a read on all the array elements if they are bools.
|
|
580
|
-
array_read = array_read_bool(self.ctx)
|
|
581
|
-
array_read = self.builder.load_function(array_read)
|
|
582
|
-
map_op = array_map(OpaqueBool, size_arg, ht.Bool)
|
|
583
|
-
value_wire = self.builder.add_op(map_op, value_wire, array_read)
|
|
584
|
-
base_ty = ht.Bool
|
|
585
|
-
# Turn `borrow_array` into regular `array`
|
|
586
|
-
value_wire = self.builder.add_op(
|
|
587
|
-
array_to_std_array(base_ty, size_arg), value_wire
|
|
588
|
-
)
|
|
589
|
-
hugr_ty: ht.Type = hugr.std.collections.array.Array(base_ty, size_arg)
|
|
590
|
-
else:
|
|
591
|
-
if is_bool_type(node.base_ty):
|
|
592
|
-
base_ty = ht.Bool
|
|
593
|
-
value_wire = self.builder.add_op(read_bool(), value_wire)
|
|
594
|
-
op_name = f"result_{base_name}"
|
|
595
|
-
hugr_ty = base_ty
|
|
525
|
+
def _visit_result_tag(self, tag: Const, loc: ast.expr) -> str:
|
|
526
|
+
"""Helper method to resolve the tag string in `state_result` expressions.
|
|
596
527
|
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
528
|
+
Also takes care of checking that the tag fits into the maximum tag length.
|
|
529
|
+
Once we go ahead with https://github.com/quantinuum/guppylang/discussions/1299,
|
|
530
|
+
this can be moved into type checking.
|
|
531
|
+
"""
|
|
532
|
+
from guppylang_internals.std._internal.compiler.platform import (
|
|
533
|
+
TAG_MAX_LEN,
|
|
534
|
+
TooLongError,
|
|
535
|
+
)
|
|
600
536
|
|
|
601
|
-
|
|
602
|
-
|
|
537
|
+
is_generic: BoundConstVar | None = None
|
|
538
|
+
match tag:
|
|
539
|
+
case ConstValue(value=str(v)):
|
|
540
|
+
tag_value = v
|
|
541
|
+
case BoundConstVar(idx=idx) as var:
|
|
542
|
+
assert self.ctx.current_mono_args is not None
|
|
543
|
+
match self.ctx.current_mono_args[idx]:
|
|
544
|
+
case ConstArg(const=ConstValue(value=str(v))):
|
|
545
|
+
tag_value = v
|
|
546
|
+
is_generic = var
|
|
547
|
+
case _:
|
|
548
|
+
raise InternalGuppyError("Unexpected tag monomorphization")
|
|
549
|
+
case _:
|
|
550
|
+
raise InternalGuppyError("Unexpected tag value")
|
|
551
|
+
|
|
552
|
+
if len(tag_value.encode("utf-8")) > TAG_MAX_LEN:
|
|
553
|
+
err = TooLongError(loc)
|
|
554
|
+
err.add_sub_diagnostic(TooLongError.Hint(None))
|
|
555
|
+
if is_generic:
|
|
556
|
+
err.add_sub_diagnostic(
|
|
557
|
+
TooLongError.GenericHint(None, is_generic.display_name, tag_value)
|
|
558
|
+
)
|
|
559
|
+
raise GuppyError(err)
|
|
560
|
+
return tag_value
|
|
603
561
|
|
|
604
562
|
def visit_PanicExpr(self, node: PanicExpr) -> Wire:
|
|
605
|
-
|
|
563
|
+
signal = self.visit(node.signal)
|
|
564
|
+
signal_usize = self.builder.add_op(convert_itousize(), signal)
|
|
565
|
+
msg = self.visit(node.msg)
|
|
566
|
+
err = self.builder.add_op(make_error(), signal_usize, msg)
|
|
606
567
|
in_tys = [get_type(e).to_hugr(self.ctx) for e in node.values]
|
|
607
568
|
out_tys = [ty.to_hugr(self.ctx) for ty in type_to_row(get_type(node))]
|
|
608
569
|
args = [self.visit(e) for e in node.values]
|
|
@@ -627,12 +588,13 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
|
627
588
|
return self._pack_returns([], NoneType())
|
|
628
589
|
|
|
629
590
|
def visit_StateResultExpr(self, node: StateResultExpr) -> Wire:
|
|
591
|
+
tag_value = self._visit_result_tag(node.tag_value, node.tag_expr)
|
|
630
592
|
num_qubits_arg = (
|
|
631
593
|
node.array_len.to_arg().to_hugr(self.ctx)
|
|
632
594
|
if node.array_len
|
|
633
595
|
else ht.BoundedNatArg(len(node.args) - 1)
|
|
634
596
|
)
|
|
635
|
-
args = [ht.StringArg(
|
|
597
|
+
args = [ht.StringArg(tag_value), num_qubits_arg]
|
|
636
598
|
sig = ht.FunctionType(
|
|
637
599
|
[standard_array_type(ht.Qubit, num_qubits_arg)],
|
|
638
600
|
[standard_array_type(ht.Qubit, num_qubits_arg)],
|
|
@@ -791,6 +753,35 @@ def expr_to_row(expr: ast.expr) -> list[ast.expr]:
|
|
|
791
753
|
return expr.elts if isinstance(expr, ast.Tuple) else [expr]
|
|
792
754
|
|
|
793
755
|
|
|
756
|
+
def pack_returns(
|
|
757
|
+
returns: Sequence[Wire],
|
|
758
|
+
return_ty: Type,
|
|
759
|
+
builder: DfBase[ops.DfParentOp],
|
|
760
|
+
ctx: CompilerContext,
|
|
761
|
+
) -> Wire:
|
|
762
|
+
"""Groups function return values into a tuple"""
|
|
763
|
+
if isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve:
|
|
764
|
+
types = type_to_row(return_ty)
|
|
765
|
+
assert len(returns) == len(types)
|
|
766
|
+
hugr_tys = [t.to_hugr(ctx) for t in types]
|
|
767
|
+
return builder.add_op(ops.MakeTuple(hugr_tys), *returns)
|
|
768
|
+
assert (
|
|
769
|
+
len(returns) == 1
|
|
770
|
+
), f"Expected a single return value. Got {returns}. return type {return_ty}"
|
|
771
|
+
return returns[0]
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
def unpack_wire(
|
|
775
|
+
wire: Wire, return_ty: Type, builder: DfBase[ops.DfParentOp], ctx: CompilerContext
|
|
776
|
+
) -> list[Wire]:
|
|
777
|
+
"""The inverse of `pack_returns`"""
|
|
778
|
+
if isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve:
|
|
779
|
+
types = type_to_row(return_ty)
|
|
780
|
+
hugr_tys = [t.to_hugr(ctx) for t in types]
|
|
781
|
+
return list(builder.add_op(ops.UnpackTuple(hugr_tys), wire).outputs())
|
|
782
|
+
return [wire]
|
|
783
|
+
|
|
784
|
+
|
|
794
785
|
def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool:
|
|
795
786
|
"""Checks if instantiating a polymorphic makes it return a row."""
|
|
796
787
|
if isinstance(func_ty.output, BoundTypeVar):
|
|
@@ -8,6 +8,7 @@ from guppylang_internals.checker.modifier_checker import non_copyable_front_othe
|
|
|
8
8
|
from guppylang_internals.compiler.cfg_compiler import compile_cfg
|
|
9
9
|
from guppylang_internals.compiler.core import CompilerContext, DFContainer
|
|
10
10
|
from guppylang_internals.compiler.expr_compiler import ExprCompiler
|
|
11
|
+
from guppylang_internals.definition.metadata import add_metadata
|
|
11
12
|
from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode
|
|
12
13
|
from guppylang_internals.std._internal.compiler.array import (
|
|
13
14
|
array_new,
|
|
@@ -56,6 +57,10 @@ def compile_modified_block(
|
|
|
56
57
|
func_builder = dfg.builder.module_root_builder().define_function(
|
|
57
58
|
str(modified_block), hugr_ty.input, hugr_ty.output
|
|
58
59
|
)
|
|
60
|
+
add_metadata(
|
|
61
|
+
func_builder,
|
|
62
|
+
additional_metadata={"unitary": modified_block.ty.unitary_flags.value},
|
|
63
|
+
)
|
|
59
64
|
|
|
60
65
|
# compile body
|
|
61
66
|
cfg = compile_cfg(modified_block.cfg, func_builder, func_builder.inputs(), ctx)
|
guppylang_internals/decorator.py
CHANGED
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
+
import pathlib
|
|
4
5
|
from typing import TYPE_CHECKING, ParamSpec, TypeVar, overload
|
|
5
6
|
|
|
6
7
|
from hugr import ops
|
|
7
8
|
from hugr import tys as ht
|
|
8
9
|
|
|
10
|
+
from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
|
|
9
11
|
from guppylang_internals.compiler.core import (
|
|
10
12
|
CompilerContext,
|
|
11
13
|
GlobalConstId,
|
|
@@ -24,6 +26,7 @@ from guppylang_internals.definition.ty import OpaqueTypeDef, TypeDef
|
|
|
24
26
|
from guppylang_internals.definition.wasm import RawWasmFunctionDef
|
|
25
27
|
from guppylang_internals.dummy_decorator import _dummy_custom_decorator, sphinx_running
|
|
26
28
|
from guppylang_internals.engine import DEF_STORE
|
|
29
|
+
from guppylang_internals.error import GuppyError, pretty_errors
|
|
27
30
|
from guppylang_internals.std._internal.checker import WasmCallChecker
|
|
28
31
|
from guppylang_internals.std._internal.compiler.wasm import (
|
|
29
32
|
WasmModuleCallCompiler,
|
|
@@ -39,6 +42,14 @@ from guppylang_internals.tys.ty import (
|
|
|
39
42
|
InputFlags,
|
|
40
43
|
NoneType,
|
|
41
44
|
NumericType,
|
|
45
|
+
UnitaryFlags,
|
|
46
|
+
)
|
|
47
|
+
from guppylang_internals.wasm_util import (
|
|
48
|
+
ConcreteWasmModule,
|
|
49
|
+
WasmFileNotFound,
|
|
50
|
+
WasmFunctionNotInFile,
|
|
51
|
+
WasmSignatureError,
|
|
52
|
+
decode_wasm_functions,
|
|
42
53
|
)
|
|
43
54
|
|
|
44
55
|
if TYPE_CHECKING:
|
|
@@ -47,7 +58,6 @@ if TYPE_CHECKING:
|
|
|
47
58
|
from collections.abc import Callable, Sequence
|
|
48
59
|
from types import FrameType
|
|
49
60
|
|
|
50
|
-
from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
|
|
51
61
|
from guppylang_internals.tys.arg import Argument
|
|
52
62
|
from guppylang_internals.tys.param import Parameter
|
|
53
63
|
from guppylang_internals.tys.subst import Inst
|
|
@@ -75,6 +85,7 @@ def custom_function(
|
|
|
75
85
|
higher_order_value: bool = True,
|
|
76
86
|
name: str = "",
|
|
77
87
|
signature: FunctionType | None = None,
|
|
88
|
+
unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
|
|
78
89
|
) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
|
|
79
90
|
"""Decorator to add custom typing or compilation behaviour to function decls.
|
|
80
91
|
|
|
@@ -86,6 +97,8 @@ def custom_function(
|
|
|
86
97
|
|
|
87
98
|
def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
|
|
88
99
|
call_checker = checker or DefaultCallChecker()
|
|
100
|
+
if signature is not None:
|
|
101
|
+
object.__setattr__(signature, "unitary_flags", unitary_flags)
|
|
89
102
|
func = RawCustomFunctionDef(
|
|
90
103
|
DefId.fresh(),
|
|
91
104
|
name or f.__name__,
|
|
@@ -95,6 +108,7 @@ def custom_function(
|
|
|
95
108
|
compiler or NotImplementedCallCompiler(),
|
|
96
109
|
higher_order_value,
|
|
97
110
|
signature,
|
|
111
|
+
unitary_flags,
|
|
98
112
|
)
|
|
99
113
|
DEF_STORE.register_def(func, get_calling_frame())
|
|
100
114
|
return GuppyFunctionDefinition(func)
|
|
@@ -108,6 +122,7 @@ def hugr_op(
|
|
|
108
122
|
higher_order_value: bool = True,
|
|
109
123
|
name: str = "",
|
|
110
124
|
signature: FunctionType | None = None,
|
|
125
|
+
unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
|
|
111
126
|
) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
|
|
112
127
|
"""Decorator to annotate function declarations as HUGR ops.
|
|
113
128
|
|
|
@@ -119,7 +134,14 @@ def hugr_op(
|
|
|
119
134
|
value.
|
|
120
135
|
name: The name of the function.
|
|
121
136
|
"""
|
|
122
|
-
return custom_function(
|
|
137
|
+
return custom_function(
|
|
138
|
+
OpCompiler(op),
|
|
139
|
+
checker,
|
|
140
|
+
higher_order_value,
|
|
141
|
+
name,
|
|
142
|
+
signature,
|
|
143
|
+
unitary_flags=unitary_flags,
|
|
144
|
+
)
|
|
123
145
|
|
|
124
146
|
|
|
125
147
|
def extend_type(defn: TypeDef, return_class: bool = False) -> Callable[[type], type]:
|
|
@@ -185,9 +207,16 @@ def custom_type(
|
|
|
185
207
|
return dec
|
|
186
208
|
|
|
187
209
|
|
|
210
|
+
@pretty_errors
|
|
188
211
|
def wasm_module(
|
|
189
212
|
filename: str,
|
|
190
213
|
) -> Callable[[builtins.type[T]], GuppyDefinition]:
|
|
214
|
+
wasm_file = pathlib.Path(filename)
|
|
215
|
+
if wasm_file.is_file():
|
|
216
|
+
wasm_sigs = decode_wasm_functions(filename)
|
|
217
|
+
else:
|
|
218
|
+
raise GuppyError(WasmFileNotFound(None, filename))
|
|
219
|
+
|
|
191
220
|
def type_def_wrapper(
|
|
192
221
|
id: DefId,
|
|
193
222
|
name: str,
|
|
@@ -198,10 +227,19 @@ def wasm_module(
|
|
|
198
227
|
assert config is None
|
|
199
228
|
return WasmModuleTypeDef(id, name, defined_at, wasm_file)
|
|
200
229
|
|
|
201
|
-
|
|
202
|
-
type_def_wrapper,
|
|
230
|
+
decorator = ext_module_decorator(
|
|
231
|
+
type_def_wrapper,
|
|
232
|
+
WasmModuleInitCompiler(),
|
|
233
|
+
WasmModuleDiscardCompiler(),
|
|
234
|
+
True,
|
|
235
|
+
wasm_sigs,
|
|
203
236
|
)
|
|
204
|
-
|
|
237
|
+
|
|
238
|
+
def inner_fun(ty: builtins.type[T]) -> GuppyDefinition:
|
|
239
|
+
decorator_inner = decorator(filename, None)
|
|
240
|
+
return decorator_inner(ty)
|
|
241
|
+
|
|
242
|
+
return inner_fun
|
|
205
243
|
|
|
206
244
|
|
|
207
245
|
def ext_module_decorator(
|
|
@@ -209,12 +247,13 @@ def ext_module_decorator(
|
|
|
209
247
|
init_compiler: CustomInoutCallCompiler,
|
|
210
248
|
discard_compiler: CustomInoutCallCompiler,
|
|
211
249
|
init_arg: bool, # Whether the init function should take a nat argument
|
|
250
|
+
wasm_sigs: ConcreteWasmModule
|
|
251
|
+
| None = None, # For @wasm_module, we must be passed a parsed wasm file
|
|
212
252
|
) -> Callable[[str, str | None], Callable[[builtins.type[T]], GuppyDefinition]]:
|
|
213
|
-
from guppylang.defs import GuppyDefinition
|
|
214
|
-
|
|
215
253
|
def fun(
|
|
216
254
|
filename: str, module: str | None
|
|
217
255
|
) -> Callable[[builtins.type[T]], GuppyDefinition]:
|
|
256
|
+
@pretty_errors
|
|
218
257
|
def dec(cls: builtins.type[T]) -> GuppyDefinition:
|
|
219
258
|
# N.B. Only one module per file and vice-versa
|
|
220
259
|
ext_module = type_def(
|
|
@@ -231,6 +270,47 @@ def ext_module_decorator(
|
|
|
231
270
|
for val in cls.__dict__.values():
|
|
232
271
|
if isinstance(val, GuppyDefinition):
|
|
233
272
|
DEF_STORE.register_impl(ext_module.id, val.wrapped.name, val.id)
|
|
273
|
+
wasm_def: RawWasmFunctionDef
|
|
274
|
+
if isinstance(val, GuppyFunctionDefinition) and isinstance(
|
|
275
|
+
val.wrapped, RawWasmFunctionDef
|
|
276
|
+
):
|
|
277
|
+
wasm_def = val.wrapped
|
|
278
|
+
else:
|
|
279
|
+
continue
|
|
280
|
+
# wasm_sigs should only have not been provided if we have
|
|
281
|
+
# defined @wasm functions in a class which didn't use the
|
|
282
|
+
# @wasm_module decorator.
|
|
283
|
+
assert wasm_sigs is not None
|
|
284
|
+
if wasm_def.wasm_index is not None:
|
|
285
|
+
name = wasm_sigs.functions[wasm_def.wasm_index]
|
|
286
|
+
assert name in wasm_sigs.function_sigs
|
|
287
|
+
wasm_sig_or_err = wasm_sigs.function_sigs[name]
|
|
288
|
+
else:
|
|
289
|
+
if wasm_def.name in wasm_sigs.function_sigs:
|
|
290
|
+
wasm_sig_or_err = wasm_sigs.function_sigs[wasm_def.name]
|
|
291
|
+
else:
|
|
292
|
+
raise GuppyError(
|
|
293
|
+
WasmFunctionNotInFile(
|
|
294
|
+
wasm_def.defined_at,
|
|
295
|
+
wasm_def.name,
|
|
296
|
+
).add_sub_diagnostic(
|
|
297
|
+
WasmFunctionNotInFile.WasmFileNote(
|
|
298
|
+
None,
|
|
299
|
+
wasm_sigs.filename,
|
|
300
|
+
)
|
|
301
|
+
)
|
|
302
|
+
)
|
|
303
|
+
if isinstance(wasm_sig_or_err, FunctionType):
|
|
304
|
+
DEF_STORE.register_wasm_function(wasm_def.id, wasm_sig_or_err)
|
|
305
|
+
elif isinstance(wasm_sig_or_err, str):
|
|
306
|
+
raise GuppyError(
|
|
307
|
+
WasmSignatureError(
|
|
308
|
+
None, wasm_def.name, filename
|
|
309
|
+
).add_sub_diagnostic(
|
|
310
|
+
WasmSignatureError.Message(None, wasm_sig_or_err)
|
|
311
|
+
)
|
|
312
|
+
)
|
|
313
|
+
|
|
234
314
|
# Add a constructor to the class
|
|
235
315
|
if init_arg:
|
|
236
316
|
init_fn_ty = FunctionType(
|
|
@@ -315,6 +395,7 @@ def wasm_helper(fn_id: int | None, f: Callable[P, T]) -> GuppyFunctionDefinition
|
|
|
315
395
|
WasmModuleCallCompiler(f.__name__, fn_id),
|
|
316
396
|
True,
|
|
317
397
|
signature=None,
|
|
398
|
+
wasm_index=fn_id,
|
|
318
399
|
)
|
|
319
400
|
DEF_STORE.register_def(func, get_calling_frame())
|
|
320
401
|
return GuppyFunctionDefinition(func)
|
|
@@ -44,6 +44,7 @@ from guppylang_internals.tys.ty import (
|
|
|
44
44
|
InputFlags,
|
|
45
45
|
NoneType,
|
|
46
46
|
Type,
|
|
47
|
+
UnitaryFlags,
|
|
47
48
|
type_to_row,
|
|
48
49
|
)
|
|
49
50
|
|
|
@@ -114,6 +115,8 @@ class RawCustomFunctionDef(ParsableDef):
|
|
|
114
115
|
|
|
115
116
|
signature: FunctionType | None
|
|
116
117
|
|
|
118
|
+
unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags)
|
|
119
|
+
|
|
117
120
|
description: str = field(default="function", init=False)
|
|
118
121
|
|
|
119
122
|
def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
|
|
@@ -136,6 +139,7 @@ class RawCustomFunctionDef(ParsableDef):
|
|
|
136
139
|
raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
|
|
137
140
|
sig = self.signature or self._get_signature(func_ast, globals)
|
|
138
141
|
ty = sig or FunctionType([], NoneType())
|
|
142
|
+
ty = ty.with_unitary_flags(self.unitary_flags)
|
|
139
143
|
return CustomFunctionDef(
|
|
140
144
|
self.id,
|
|
141
145
|
self.name,
|
|
@@ -34,7 +34,7 @@ from guppylang_internals.nodes import GlobalCall
|
|
|
34
34
|
from guppylang_internals.span import SourceMap
|
|
35
35
|
from guppylang_internals.tys.param import Parameter
|
|
36
36
|
from guppylang_internals.tys.subst import Inst, Subst
|
|
37
|
-
from guppylang_internals.tys.ty import Type
|
|
37
|
+
from guppylang_internals.tys.ty import Type, UnitaryFlags
|
|
38
38
|
|
|
39
39
|
|
|
40
40
|
@dataclass(frozen=True)
|
|
@@ -65,10 +65,14 @@ class RawFunctionDecl(ParsableDef):
|
|
|
65
65
|
python_func: PyFunc
|
|
66
66
|
description: str = field(default="function", init=False)
|
|
67
67
|
|
|
68
|
+
unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True)
|
|
69
|
+
|
|
68
70
|
def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
|
|
69
71
|
"""Parses and checks the user-provided signature of the function."""
|
|
70
72
|
func_ast, docstring = parse_py_func(self.python_func, sources)
|
|
71
|
-
ty = check_signature(
|
|
73
|
+
ty = check_signature(
|
|
74
|
+
func_ast, globals, self.id, unitary_flags=self.unitary_flags
|
|
75
|
+
)
|
|
72
76
|
if not has_empty_body(func_ast):
|
|
73
77
|
raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
|
|
74
78
|
# Make sure we won't need monomorphization to compile this declaration
|
|
@@ -33,6 +33,7 @@ from guppylang_internals.definition.common import (
|
|
|
33
33
|
ParsableDef,
|
|
34
34
|
UnknownSourceError,
|
|
35
35
|
)
|
|
36
|
+
from guppylang_internals.definition.metadata import GuppyMetadata, add_metadata
|
|
36
37
|
from guppylang_internals.definition.value import (
|
|
37
38
|
CallableDef,
|
|
38
39
|
CallReturnWires,
|
|
@@ -43,7 +44,7 @@ from guppylang_internals.error import GuppyError
|
|
|
43
44
|
from guppylang_internals.nodes import GlobalCall
|
|
44
45
|
from guppylang_internals.span import SourceMap
|
|
45
46
|
from guppylang_internals.tys.subst import Inst, Subst
|
|
46
|
-
from guppylang_internals.tys.ty import FunctionType, Type, type_to_row
|
|
47
|
+
from guppylang_internals.tys.ty import FunctionType, Type, UnitaryFlags, type_to_row
|
|
47
48
|
|
|
48
49
|
if TYPE_CHECKING:
|
|
49
50
|
from guppylang_internals.tys.param import Parameter
|
|
@@ -70,11 +71,24 @@ class RawFunctionDef(ParsableDef):
|
|
|
70
71
|
|
|
71
72
|
description: str = field(default="function", init=False)
|
|
72
73
|
|
|
74
|
+
unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True)
|
|
75
|
+
|
|
76
|
+
metadata: GuppyMetadata | None = field(default=None, kw_only=True)
|
|
77
|
+
|
|
73
78
|
def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
|
|
74
79
|
"""Parses and checks the user-provided signature of the function."""
|
|
75
80
|
func_ast, docstring = parse_py_func(self.python_func, sources)
|
|
76
|
-
ty = check_signature(
|
|
77
|
-
|
|
81
|
+
ty = check_signature(
|
|
82
|
+
func_ast, globals, self.id, unitary_flags=self.unitary_flags
|
|
83
|
+
)
|
|
84
|
+
return ParsedFunctionDef(
|
|
85
|
+
self.id,
|
|
86
|
+
self.name,
|
|
87
|
+
func_ast,
|
|
88
|
+
ty,
|
|
89
|
+
docstring,
|
|
90
|
+
metadata=self.metadata,
|
|
91
|
+
)
|
|
78
92
|
|
|
79
93
|
|
|
80
94
|
@dataclass(frozen=True)
|
|
@@ -99,6 +113,8 @@ class ParsedFunctionDef(CheckableDef, CallableDef):
|
|
|
99
113
|
|
|
100
114
|
description: str = field(default="function", init=False)
|
|
101
115
|
|
|
116
|
+
metadata: GuppyMetadata | None = field(default=None, kw_only=True)
|
|
117
|
+
|
|
102
118
|
def check(self, globals: Globals) -> "CheckedFunctionDef":
|
|
103
119
|
"""Type checks the body of the function."""
|
|
104
120
|
# Add python variable scope to the globals
|
|
@@ -110,6 +126,7 @@ class ParsedFunctionDef(CheckableDef, CallableDef):
|
|
|
110
126
|
self.ty,
|
|
111
127
|
self.docstring,
|
|
112
128
|
cfg,
|
|
129
|
+
metadata=self.metadata,
|
|
113
130
|
)
|
|
114
131
|
|
|
115
132
|
def check_call(
|
|
@@ -173,6 +190,11 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
|
|
|
173
190
|
func_def = module.module_root_builder().define_function(
|
|
174
191
|
self.name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params
|
|
175
192
|
)
|
|
193
|
+
add_metadata(
|
|
194
|
+
func_def,
|
|
195
|
+
self.metadata,
|
|
196
|
+
additional_metadata={"unitary": self.ty.unitary_flags.value},
|
|
197
|
+
)
|
|
176
198
|
return CompiledFunctionDef(
|
|
177
199
|
self.id,
|
|
178
200
|
self.name,
|
|
@@ -182,6 +204,7 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
|
|
|
182
204
|
self.docstring,
|
|
183
205
|
self.cfg,
|
|
184
206
|
func_def,
|
|
207
|
+
metadata=self.metadata,
|
|
185
208
|
)
|
|
186
209
|
|
|
187
210
|
|