guppylang-internals 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/cfg/builder.py +20 -2
  3. guppylang_internals/cfg/cfg.py +3 -0
  4. guppylang_internals/checker/cfg_checker.py +6 -0
  5. guppylang_internals/checker/core.py +1 -2
  6. guppylang_internals/checker/errors/linearity.py +6 -2
  7. guppylang_internals/checker/errors/wasm.py +7 -4
  8. guppylang_internals/checker/expr_checker.py +39 -19
  9. guppylang_internals/checker/func_checker.py +17 -13
  10. guppylang_internals/checker/linearity_checker.py +2 -10
  11. guppylang_internals/checker/modifier_checker.py +6 -2
  12. guppylang_internals/checker/unitary_checker.py +132 -0
  13. guppylang_internals/compiler/cfg_compiler.py +7 -6
  14. guppylang_internals/compiler/core.py +5 -5
  15. guppylang_internals/compiler/expr_compiler.py +72 -81
  16. guppylang_internals/compiler/modifier_compiler.py +5 -0
  17. guppylang_internals/decorator.py +88 -7
  18. guppylang_internals/definition/custom.py +4 -0
  19. guppylang_internals/definition/declaration.py +6 -2
  20. guppylang_internals/definition/function.py +26 -3
  21. guppylang_internals/definition/metadata.py +87 -0
  22. guppylang_internals/definition/overloaded.py +11 -2
  23. guppylang_internals/definition/pytket_circuits.py +7 -2
  24. guppylang_internals/definition/struct.py +6 -3
  25. guppylang_internals/definition/wasm.py +42 -10
  26. guppylang_internals/diagnostic.py +72 -15
  27. guppylang_internals/engine.py +10 -13
  28. guppylang_internals/nodes.py +55 -24
  29. guppylang_internals/std/_internal/checker.py +13 -108
  30. guppylang_internals/std/_internal/compiler/array.py +37 -2
  31. guppylang_internals/std/_internal/compiler/either.py +14 -2
  32. guppylang_internals/std/_internal/compiler/list.py +1 -1
  33. guppylang_internals/std/_internal/compiler/platform.py +153 -0
  34. guppylang_internals/std/_internal/compiler/prelude.py +12 -4
  35. guppylang_internals/std/_internal/compiler/tket_bool.py +1 -6
  36. guppylang_internals/std/_internal/compiler/tket_exts.py +4 -5
  37. guppylang_internals/std/_internal/debug.py +18 -9
  38. guppylang_internals/std/_internal/util.py +1 -1
  39. guppylang_internals/tracing/object.py +14 -0
  40. guppylang_internals/tys/errors.py +23 -1
  41. guppylang_internals/tys/parsing.py +3 -3
  42. guppylang_internals/tys/printing.py +2 -8
  43. guppylang_internals/tys/qubit.py +37 -2
  44. guppylang_internals/tys/ty.py +60 -64
  45. guppylang_internals/wasm_util.py +129 -0
  46. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/METADATA +5 -4
  47. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/RECORD +49 -45
  48. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/WHEEL +1 -1
  49. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/licenses/LICENCE +0 -0
@@ -185,7 +185,7 @@ class CompilerContext(ToHugrContext):
185
185
  # make the call to `ENGINE.get_checked` below fail. For now, let's just short-
186
186
  # cut if the function doesn't take any generic params (as is the case for all
187
187
  # nested functions).
188
- # See https://github.com/CQCL/guppylang/issues/1032
188
+ # See https://github.com/quantinuum/guppylang/issues/1032
189
189
  if (def_id, ()) in self.compiled:
190
190
  assert type_args == []
191
191
  return self.compiled[def_id, ()], type_args
@@ -247,7 +247,7 @@ class CompilerContext(ToHugrContext):
247
247
 
248
248
  # Insert explicit drops for affine types
249
249
  # TODO: This is a quick workaround until we can properly insert these drops
250
- # during linearity checking. See https://github.com/CQCL/guppylang/issues/1082
250
+ # during linearity checking. See https://github.com/quantinuum/guppylang/issues/1082
251
251
  insert_drops(self.module.hugr)
252
252
 
253
253
  return entry_compiled
@@ -659,7 +659,7 @@ def track_hugr_side_effects() -> Iterator[None]:
659
659
 
660
660
  def qualified_name(type_def: he.TypeDef) -> str:
661
661
  """Returns the qualified name of a Hugr extension type.
662
- TODO: Remove once upstreamed, see https://github.com/CQCL/hugr/issues/2426
662
+ TODO: Remove once upstreamed, see https://github.com/quantinuum/hugr/issues/2426
663
663
  """
664
664
  if type_def._extension is not None:
665
665
  return f"{type_def._extension.name}.{type_def.name}"
@@ -711,14 +711,14 @@ def drop_op(ty: ht.Type) -> ops.ExtOp:
711
711
  def insert_drops(hugr: Hugr[OpVarCov]) -> None:
712
712
  """Inserts explicit drop ops for unconnected ports into the Hugr.
713
713
  TODO: This is a quick workaround until we can properly insert these drops during
714
- linearity checking. See https://github.com/CQCL/guppylang/issues/1082
714
+ linearity checking. See https://github.com/quantinuum/guppylang/issues/1082
715
715
  """
716
716
  for node in hugr:
717
717
  data = hugr[node]
718
718
  # Iterating over `node.outputs()` doesn't work reliably since it sometimes
719
719
  # raises an `IncompleteOp` exception. Instead, we query the number of out ports
720
720
  # and look them up by index. However, this method is *also* broken when
721
- # inspecting `FuncDefn` nodes due to https://github.com/CQCL/hugr/issues/2438.
721
+ # inspecting `FuncDefn` nodes due to https://github.com/quantinuum/hugr/issues/2438.
722
722
  if isinstance(data.op, ops.FuncDefn):
723
723
  continue
724
724
  for i in range(hugr.num_out_ports(node)):
@@ -15,7 +15,6 @@ from hugr import val as hv
15
15
  from hugr.build import function as hf
16
16
  from hugr.build.cond_loop import Conditional
17
17
  from hugr.build.dfg import DP, DfBase
18
- from typing_extensions import assert_never
19
18
 
20
19
  from guppylang_internals.ast_util import AstNode, AstVisitor, get_type
21
20
  from guppylang_internals.cfg.builder import tmp_vars
@@ -23,7 +22,6 @@ from guppylang_internals.checker.core import Variable, contains_subscript
23
22
  from guppylang_internals.checker.errors.generic import UnsupportedError
24
23
  from guppylang_internals.compiler.core import (
25
24
  DEBUG_EXTENSION,
26
- RESULT_EXTENSION,
27
25
  CompilerBase,
28
26
  CompilerContext,
29
27
  DFContainer,
@@ -53,7 +51,6 @@ from guppylang_internals.nodes import (
53
51
  PanicExpr,
54
52
  PartialApply,
55
53
  PlaceNode,
56
- ResultExpr,
57
54
  StateResultExpr,
58
55
  SubscriptAccessAndDrop,
59
56
  TensorCall,
@@ -66,7 +63,6 @@ from guppylang_internals.std._internal.compiler.arithmetic import (
66
63
  convert_itousize,
67
64
  )
68
65
  from guppylang_internals.std._internal.compiler.array import (
69
- array_clone,
70
66
  array_map,
71
67
  array_new,
72
68
  array_to_std_array,
@@ -80,8 +76,8 @@ from guppylang_internals.std._internal.compiler.list import (
80
76
  list_new,
81
77
  )
82
78
  from guppylang_internals.std._internal.compiler.prelude import (
83
- build_error,
84
79
  build_panic,
80
+ make_error,
85
81
  panic,
86
82
  )
87
83
  from guppylang_internals.std._internal.compiler.tket_bool import (
@@ -93,14 +89,12 @@ from guppylang_internals.std._internal.compiler.tket_bool import (
93
89
  )
94
90
  from guppylang_internals.tys.arg import ConstArg
95
91
  from guppylang_internals.tys.builtin import (
96
- array_type,
97
92
  bool_type,
98
93
  get_element_type,
99
94
  int_type,
100
- is_bool_type,
101
95
  is_frozenarray_type,
102
96
  )
103
- from guppylang_internals.tys.const import ConstValue
97
+ from guppylang_internals.tys.const import BoundConstVar, Const, ConstValue
104
98
  from guppylang_internals.tys.subst import Inst
105
99
  from guppylang_internals.tys.ty import (
106
100
  BoundTypeVar,
@@ -331,14 +325,7 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
331
325
 
332
326
  def _pack_returns(self, returns: Sequence[Wire], return_ty: Type) -> Wire:
333
327
  """Groups function return values into a tuple"""
334
- if isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve:
335
- types = type_to_row(return_ty)
336
- assert len(returns) == len(types)
337
- return self._pack_tuple(returns, types)
338
- assert (
339
- len(returns) == 1
340
- ), f"Expected a single return value. Got {returns}. return type {return_ty}"
341
- return returns[0]
328
+ return pack_returns(returns, return_ty, self.builder, self.ctx)
342
329
 
343
330
  def _update_inout_ports(
344
331
  self,
@@ -535,74 +522,48 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
535
522
  tuple_port = self.visit(node.value)
536
523
  return self._unpack_tuple(tuple_port, node.tuple_ty.element_types)[node.index]
537
524
 
538
- def visit_ResultExpr(self, node: ResultExpr) -> Wire:
539
- value_wire = self.visit(node.value)
540
- base_ty = node.base_ty.to_hugr(self.ctx)
541
- extra_args: list[ht.TypeArg] = []
542
- if isinstance(node.base_ty, NumericType):
543
- match node.base_ty.kind:
544
- case NumericType.Kind.Nat:
545
- base_name = "uint"
546
- extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)]
547
- case NumericType.Kind.Int:
548
- base_name = "int"
549
- extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)]
550
- case NumericType.Kind.Float:
551
- base_name = "f64"
552
- case kind:
553
- assert_never(kind)
554
- else:
555
- # The only other valid base type is bool
556
- assert is_bool_type(node.base_ty)
557
- base_name = "bool"
558
- if node.array_len is not None:
559
- op_name = f"result_array_{base_name}"
560
- size_arg = node.array_len.to_arg().to_hugr(self.ctx)
561
- extra_args = [size_arg, *extra_args]
562
- # As `borrow_array`s used by Guppy are linear, we need to clone it (knowing
563
- # that all elements in it are copyable) to avoid linearity violations when
564
- # both passing it to the result operation and returning it (as an inout
565
- # argument).
566
- value_wire, inout_wire = self.builder.add_op(
567
- array_clone(base_ty, size_arg), value_wire
568
- )
569
- func_ty = FunctionType(
570
- [
571
- FuncInput(
572
- array_type(node.base_ty, node.array_len), InputFlags.Inout
573
- ),
574
- ],
575
- NoneType(),
576
- )
577
- self._update_inout_ports(node.args, iter([inout_wire]), func_ty)
578
- if is_bool_type(node.base_ty):
579
- # We need to coerce a read on all the array elements if they are bools.
580
- array_read = array_read_bool(self.ctx)
581
- array_read = self.builder.load_function(array_read)
582
- map_op = array_map(OpaqueBool, size_arg, ht.Bool)
583
- value_wire = self.builder.add_op(map_op, value_wire, array_read)
584
- base_ty = ht.Bool
585
- # Turn `borrow_array` into regular `array`
586
- value_wire = self.builder.add_op(
587
- array_to_std_array(base_ty, size_arg), value_wire
588
- )
589
- hugr_ty: ht.Type = hugr.std.collections.array.Array(base_ty, size_arg)
590
- else:
591
- if is_bool_type(node.base_ty):
592
- base_ty = ht.Bool
593
- value_wire = self.builder.add_op(read_bool(), value_wire)
594
- op_name = f"result_{base_name}"
595
- hugr_ty = base_ty
525
+ def _visit_result_tag(self, tag: Const, loc: ast.expr) -> str:
526
+ """Helper method to resolve the tag string in `state_result` expressions.
596
527
 
597
- sig = ht.FunctionType(input=[hugr_ty], output=[])
598
- args = [ht.StringArg(node.tag), *extra_args]
599
- op = ops.ExtOp(RESULT_EXTENSION.get_op(op_name), signature=sig, args=args)
528
+ Also takes care of checking that the tag fits into the maximum tag length.
529
+ Once we go ahead with https://github.com/quantinuum/guppylang/discussions/1299,
530
+ this can be moved into type checking.
531
+ """
532
+ from guppylang_internals.std._internal.compiler.platform import (
533
+ TAG_MAX_LEN,
534
+ TooLongError,
535
+ )
600
536
 
601
- self.builder.add_op(op, value_wire)
602
- return self._pack_returns([], NoneType())
537
+ is_generic: BoundConstVar | None = None
538
+ match tag:
539
+ case ConstValue(value=str(v)):
540
+ tag_value = v
541
+ case BoundConstVar(idx=idx) as var:
542
+ assert self.ctx.current_mono_args is not None
543
+ match self.ctx.current_mono_args[idx]:
544
+ case ConstArg(const=ConstValue(value=str(v))):
545
+ tag_value = v
546
+ is_generic = var
547
+ case _:
548
+ raise InternalGuppyError("Unexpected tag monomorphization")
549
+ case _:
550
+ raise InternalGuppyError("Unexpected tag value")
551
+
552
+ if len(tag_value.encode("utf-8")) > TAG_MAX_LEN:
553
+ err = TooLongError(loc)
554
+ err.add_sub_diagnostic(TooLongError.Hint(None))
555
+ if is_generic:
556
+ err.add_sub_diagnostic(
557
+ TooLongError.GenericHint(None, is_generic.display_name, tag_value)
558
+ )
559
+ raise GuppyError(err)
560
+ return tag_value
603
561
 
604
562
  def visit_PanicExpr(self, node: PanicExpr) -> Wire:
605
- err = build_error(self.builder, node.signal, node.msg)
563
+ signal = self.visit(node.signal)
564
+ signal_usize = self.builder.add_op(convert_itousize(), signal)
565
+ msg = self.visit(node.msg)
566
+ err = self.builder.add_op(make_error(), signal_usize, msg)
606
567
  in_tys = [get_type(e).to_hugr(self.ctx) for e in node.values]
607
568
  out_tys = [ty.to_hugr(self.ctx) for ty in type_to_row(get_type(node))]
608
569
  args = [self.visit(e) for e in node.values]
@@ -627,12 +588,13 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
627
588
  return self._pack_returns([], NoneType())
628
589
 
629
590
  def visit_StateResultExpr(self, node: StateResultExpr) -> Wire:
591
+ tag_value = self._visit_result_tag(node.tag_value, node.tag_expr)
630
592
  num_qubits_arg = (
631
593
  node.array_len.to_arg().to_hugr(self.ctx)
632
594
  if node.array_len
633
595
  else ht.BoundedNatArg(len(node.args) - 1)
634
596
  )
635
- args = [ht.StringArg(node.tag), num_qubits_arg]
597
+ args = [ht.StringArg(tag_value), num_qubits_arg]
636
598
  sig = ht.FunctionType(
637
599
  [standard_array_type(ht.Qubit, num_qubits_arg)],
638
600
  [standard_array_type(ht.Qubit, num_qubits_arg)],
@@ -791,6 +753,35 @@ def expr_to_row(expr: ast.expr) -> list[ast.expr]:
791
753
  return expr.elts if isinstance(expr, ast.Tuple) else [expr]
792
754
 
793
755
 
756
def pack_returns(
    returns: Sequence[Wire],
    return_ty: Type,
    builder: DfBase[ops.DfParentOp],
    ctx: CompilerContext,
) -> Wire:
    """Combine the wires produced for a function's return into a single wire.

    If `return_ty` is a `TupleType` or `NoneType` that is not marked `preserve`,
    it is treated as a row. In that case the wires in `returns` (one per row entry)
    are bundled together with a `MakeTuple` op. For any other return type, exactly
    one wire is expected, and it is returned unchanged.

    Raises:
        AssertionError: (via `assert`) if the number of wires does not match.
    """
    is_row = isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve
    if not is_row:
        assert (
            len(returns) == 1
        ), f"Expected a single return value. Got {returns}. return type {return_ty}"
        return returns[0]
    row = type_to_row(return_ty)
    assert len(returns) == len(row)
    make_tuple = ops.MakeTuple([component.to_hugr(ctx) for component in row])
    return builder.add_op(make_tuple, *returns)
772
+
773
+
774
def unpack_wire(
    wire: Wire, return_ty: Type, builder: DfBase[ops.DfParentOp], ctx: CompilerContext
) -> list[Wire]:
    """Split a packed return wire back into its parts (inverse of `pack_returns`).

    For a `TupleType`/`NoneType` return that is not marked `preserve`, an
    `UnpackTuple` op is added and its output wires are returned. For any other
    return type, the given wire is returned alone in a one-element list.
    """
    if not isinstance(return_ty, TupleType | NoneType) or return_ty.preserve:
        return [wire]
    hugr_tys = [component.to_hugr(ctx) for component in type_to_row(return_ty)]
    unpack_node = builder.add_op(ops.UnpackTuple(hugr_tys), wire)
    return list(unpack_node.outputs())
783
+
784
+
794
785
  def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool:
795
786
  """Checks if instantiating a polymorphic makes it return a row."""
796
787
  if isinstance(func_ty.output, BoundTypeVar):
@@ -8,6 +8,7 @@ from guppylang_internals.checker.modifier_checker import non_copyable_front_othe
8
8
  from guppylang_internals.compiler.cfg_compiler import compile_cfg
9
9
  from guppylang_internals.compiler.core import CompilerContext, DFContainer
10
10
  from guppylang_internals.compiler.expr_compiler import ExprCompiler
11
+ from guppylang_internals.definition.metadata import add_metadata
11
12
  from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode
12
13
  from guppylang_internals.std._internal.compiler.array import (
13
14
  array_new,
@@ -56,6 +57,10 @@ def compile_modified_block(
56
57
  func_builder = dfg.builder.module_root_builder().define_function(
57
58
  str(modified_block), hugr_ty.input, hugr_ty.output
58
59
  )
60
+ add_metadata(
61
+ func_builder,
62
+ additional_metadata={"unitary": modified_block.ty.unitary_flags.value},
63
+ )
59
64
 
60
65
  # compile body
61
66
  cfg = compile_cfg(modified_block.cfg, func_builder, func_builder.inputs(), ctx)
@@ -1,11 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
+ import pathlib
4
5
  from typing import TYPE_CHECKING, ParamSpec, TypeVar, overload
5
6
 
6
7
  from hugr import ops
7
8
  from hugr import tys as ht
8
9
 
10
+ from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
9
11
  from guppylang_internals.compiler.core import (
10
12
  CompilerContext,
11
13
  GlobalConstId,
@@ -24,6 +26,7 @@ from guppylang_internals.definition.ty import OpaqueTypeDef, TypeDef
24
26
  from guppylang_internals.definition.wasm import RawWasmFunctionDef
25
27
  from guppylang_internals.dummy_decorator import _dummy_custom_decorator, sphinx_running
26
28
  from guppylang_internals.engine import DEF_STORE
29
+ from guppylang_internals.error import GuppyError, pretty_errors
27
30
  from guppylang_internals.std._internal.checker import WasmCallChecker
28
31
  from guppylang_internals.std._internal.compiler.wasm import (
29
32
  WasmModuleCallCompiler,
@@ -39,6 +42,14 @@ from guppylang_internals.tys.ty import (
39
42
  InputFlags,
40
43
  NoneType,
41
44
  NumericType,
45
+ UnitaryFlags,
46
+ )
47
+ from guppylang_internals.wasm_util import (
48
+ ConcreteWasmModule,
49
+ WasmFileNotFound,
50
+ WasmFunctionNotInFile,
51
+ WasmSignatureError,
52
+ decode_wasm_functions,
42
53
  )
43
54
 
44
55
  if TYPE_CHECKING:
@@ -47,7 +58,6 @@ if TYPE_CHECKING:
47
58
  from collections.abc import Callable, Sequence
48
59
  from types import FrameType
49
60
 
50
- from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
51
61
  from guppylang_internals.tys.arg import Argument
52
62
  from guppylang_internals.tys.param import Parameter
53
63
  from guppylang_internals.tys.subst import Inst
@@ -75,6 +85,7 @@ def custom_function(
75
85
  higher_order_value: bool = True,
76
86
  name: str = "",
77
87
  signature: FunctionType | None = None,
88
+ unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
78
89
  ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
79
90
  """Decorator to add custom typing or compilation behaviour to function decls.
80
91
 
@@ -86,6 +97,8 @@ def custom_function(
86
97
 
87
98
  def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
88
99
  call_checker = checker or DefaultCallChecker()
100
+ if signature is not None:
101
+ object.__setattr__(signature, "unitary_flags", unitary_flags)
89
102
  func = RawCustomFunctionDef(
90
103
  DefId.fresh(),
91
104
  name or f.__name__,
@@ -95,6 +108,7 @@ def custom_function(
95
108
  compiler or NotImplementedCallCompiler(),
96
109
  higher_order_value,
97
110
  signature,
111
+ unitary_flags,
98
112
  )
99
113
  DEF_STORE.register_def(func, get_calling_frame())
100
114
  return GuppyFunctionDefinition(func)
@@ -108,6 +122,7 @@ def hugr_op(
108
122
  higher_order_value: bool = True,
109
123
  name: str = "",
110
124
  signature: FunctionType | None = None,
125
+ unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
111
126
  ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
112
127
  """Decorator to annotate function declarations as HUGR ops.
113
128
 
@@ -119,7 +134,14 @@ def hugr_op(
119
134
  value.
120
135
  name: The name of the function.
121
136
  """
122
- return custom_function(OpCompiler(op), checker, higher_order_value, name, signature)
137
+ return custom_function(
138
+ OpCompiler(op),
139
+ checker,
140
+ higher_order_value,
141
+ name,
142
+ signature,
143
+ unitary_flags=unitary_flags,
144
+ )
123
145
 
124
146
 
125
147
  def extend_type(defn: TypeDef, return_class: bool = False) -> Callable[[type], type]:
@@ -185,9 +207,16 @@ def custom_type(
185
207
  return dec
186
208
 
187
209
 
210
+ @pretty_errors
188
211
  def wasm_module(
189
212
  filename: str,
190
213
  ) -> Callable[[builtins.type[T]], GuppyDefinition]:
214
+ wasm_file = pathlib.Path(filename)
215
+ if wasm_file.is_file():
216
+ wasm_sigs = decode_wasm_functions(filename)
217
+ else:
218
+ raise GuppyError(WasmFileNotFound(None, filename))
219
+
191
220
  def type_def_wrapper(
192
221
  id: DefId,
193
222
  name: str,
@@ -198,10 +227,19 @@ def wasm_module(
198
227
  assert config is None
199
228
  return WasmModuleTypeDef(id, name, defined_at, wasm_file)
200
229
 
201
- f = ext_module_decorator(
202
- type_def_wrapper, WasmModuleInitCompiler(), WasmModuleDiscardCompiler(), True
230
+ decorator = ext_module_decorator(
231
+ type_def_wrapper,
232
+ WasmModuleInitCompiler(),
233
+ WasmModuleDiscardCompiler(),
234
+ True,
235
+ wasm_sigs,
203
236
  )
204
- return f(filename, None)
237
+
238
+ def inner_fun(ty: builtins.type[T]) -> GuppyDefinition:
239
+ decorator_inner = decorator(filename, None)
240
+ return decorator_inner(ty)
241
+
242
+ return inner_fun
205
243
 
206
244
 
207
245
  def ext_module_decorator(
@@ -209,12 +247,13 @@ def ext_module_decorator(
209
247
  init_compiler: CustomInoutCallCompiler,
210
248
  discard_compiler: CustomInoutCallCompiler,
211
249
  init_arg: bool, # Whether the init function should take a nat argument
250
+ wasm_sigs: ConcreteWasmModule
251
+ | None = None, # For @wasm_module, we must be passed a parsed wasm file
212
252
  ) -> Callable[[str, str | None], Callable[[builtins.type[T]], GuppyDefinition]]:
213
- from guppylang.defs import GuppyDefinition
214
-
215
253
  def fun(
216
254
  filename: str, module: str | None
217
255
  ) -> Callable[[builtins.type[T]], GuppyDefinition]:
256
+ @pretty_errors
218
257
  def dec(cls: builtins.type[T]) -> GuppyDefinition:
219
258
  # N.B. Only one module per file and vice-versa
220
259
  ext_module = type_def(
@@ -231,6 +270,47 @@ def ext_module_decorator(
231
270
  for val in cls.__dict__.values():
232
271
  if isinstance(val, GuppyDefinition):
233
272
  DEF_STORE.register_impl(ext_module.id, val.wrapped.name, val.id)
273
+ wasm_def: RawWasmFunctionDef
274
+ if isinstance(val, GuppyFunctionDefinition) and isinstance(
275
+ val.wrapped, RawWasmFunctionDef
276
+ ):
277
+ wasm_def = val.wrapped
278
+ else:
279
+ continue
280
+ # wasm_sigs should only have not been provided if we have
281
+ # defined @wasm functions in a class which didn't use the
282
+ # @wasm_module decorator.
283
+ assert wasm_sigs is not None
284
+ if wasm_def.wasm_index is not None:
285
+ name = wasm_sigs.functions[wasm_def.wasm_index]
286
+ assert name in wasm_sigs.function_sigs
287
+ wasm_sig_or_err = wasm_sigs.function_sigs[name]
288
+ else:
289
+ if wasm_def.name in wasm_sigs.function_sigs:
290
+ wasm_sig_or_err = wasm_sigs.function_sigs[wasm_def.name]
291
+ else:
292
+ raise GuppyError(
293
+ WasmFunctionNotInFile(
294
+ wasm_def.defined_at,
295
+ wasm_def.name,
296
+ ).add_sub_diagnostic(
297
+ WasmFunctionNotInFile.WasmFileNote(
298
+ None,
299
+ wasm_sigs.filename,
300
+ )
301
+ )
302
+ )
303
+ if isinstance(wasm_sig_or_err, FunctionType):
304
+ DEF_STORE.register_wasm_function(wasm_def.id, wasm_sig_or_err)
305
+ elif isinstance(wasm_sig_or_err, str):
306
+ raise GuppyError(
307
+ WasmSignatureError(
308
+ None, wasm_def.name, filename
309
+ ).add_sub_diagnostic(
310
+ WasmSignatureError.Message(None, wasm_sig_or_err)
311
+ )
312
+ )
313
+
234
314
  # Add a constructor to the class
235
315
  if init_arg:
236
316
  init_fn_ty = FunctionType(
@@ -315,6 +395,7 @@ def wasm_helper(fn_id: int | None, f: Callable[P, T]) -> GuppyFunctionDefinition
315
395
  WasmModuleCallCompiler(f.__name__, fn_id),
316
396
  True,
317
397
  signature=None,
398
+ wasm_index=fn_id,
318
399
  )
319
400
  DEF_STORE.register_def(func, get_calling_frame())
320
401
  return GuppyFunctionDefinition(func)
@@ -44,6 +44,7 @@ from guppylang_internals.tys.ty import (
44
44
  InputFlags,
45
45
  NoneType,
46
46
  Type,
47
+ UnitaryFlags,
47
48
  type_to_row,
48
49
  )
49
50
 
@@ -114,6 +115,8 @@ class RawCustomFunctionDef(ParsableDef):
114
115
 
115
116
  signature: FunctionType | None
116
117
 
118
+ unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags)
119
+
117
120
  description: str = field(default="function", init=False)
118
121
 
119
122
  def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
@@ -136,6 +139,7 @@ class RawCustomFunctionDef(ParsableDef):
136
139
  raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
137
140
  sig = self.signature or self._get_signature(func_ast, globals)
138
141
  ty = sig or FunctionType([], NoneType())
142
+ ty = ty.with_unitary_flags(self.unitary_flags)
139
143
  return CustomFunctionDef(
140
144
  self.id,
141
145
  self.name,
@@ -34,7 +34,7 @@ from guppylang_internals.nodes import GlobalCall
34
34
  from guppylang_internals.span import SourceMap
35
35
  from guppylang_internals.tys.param import Parameter
36
36
  from guppylang_internals.tys.subst import Inst, Subst
37
- from guppylang_internals.tys.ty import Type
37
+ from guppylang_internals.tys.ty import Type, UnitaryFlags
38
38
 
39
39
 
40
40
  @dataclass(frozen=True)
@@ -65,10 +65,14 @@ class RawFunctionDecl(ParsableDef):
65
65
  python_func: PyFunc
66
66
  description: str = field(default="function", init=False)
67
67
 
68
+ unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True)
69
+
68
70
  def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
69
71
  """Parses and checks the user-provided signature of the function."""
70
72
  func_ast, docstring = parse_py_func(self.python_func, sources)
71
- ty = check_signature(func_ast, globals, self.id)
73
+ ty = check_signature(
74
+ func_ast, globals, self.id, unitary_flags=self.unitary_flags
75
+ )
72
76
  if not has_empty_body(func_ast):
73
77
  raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
74
78
  # Make sure we won't need monomorphization to compile this declaration
@@ -33,6 +33,7 @@ from guppylang_internals.definition.common import (
33
33
  ParsableDef,
34
34
  UnknownSourceError,
35
35
  )
36
+ from guppylang_internals.definition.metadata import GuppyMetadata, add_metadata
36
37
  from guppylang_internals.definition.value import (
37
38
  CallableDef,
38
39
  CallReturnWires,
@@ -43,7 +44,7 @@ from guppylang_internals.error import GuppyError
43
44
  from guppylang_internals.nodes import GlobalCall
44
45
  from guppylang_internals.span import SourceMap
45
46
  from guppylang_internals.tys.subst import Inst, Subst
46
- from guppylang_internals.tys.ty import FunctionType, Type, type_to_row
47
+ from guppylang_internals.tys.ty import FunctionType, Type, UnitaryFlags, type_to_row
47
48
 
48
49
  if TYPE_CHECKING:
49
50
  from guppylang_internals.tys.param import Parameter
@@ -70,11 +71,24 @@ class RawFunctionDef(ParsableDef):
70
71
 
71
72
  description: str = field(default="function", init=False)
72
73
 
74
+ unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True)
75
+
76
+ metadata: GuppyMetadata | None = field(default=None, kw_only=True)
77
+
73
78
  def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
74
79
  """Parses and checks the user-provided signature of the function."""
75
80
  func_ast, docstring = parse_py_func(self.python_func, sources)
76
- ty = check_signature(func_ast, globals, self.id)
77
- return ParsedFunctionDef(self.id, self.name, func_ast, ty, docstring)
81
+ ty = check_signature(
82
+ func_ast, globals, self.id, unitary_flags=self.unitary_flags
83
+ )
84
+ return ParsedFunctionDef(
85
+ self.id,
86
+ self.name,
87
+ func_ast,
88
+ ty,
89
+ docstring,
90
+ metadata=self.metadata,
91
+ )
78
92
 
79
93
 
80
94
  @dataclass(frozen=True)
@@ -99,6 +113,8 @@ class ParsedFunctionDef(CheckableDef, CallableDef):
99
113
 
100
114
  description: str = field(default="function", init=False)
101
115
 
116
+ metadata: GuppyMetadata | None = field(default=None, kw_only=True)
117
+
102
118
  def check(self, globals: Globals) -> "CheckedFunctionDef":
103
119
  """Type checks the body of the function."""
104
120
  # Add python variable scope to the globals
@@ -110,6 +126,7 @@ class ParsedFunctionDef(CheckableDef, CallableDef):
110
126
  self.ty,
111
127
  self.docstring,
112
128
  cfg,
129
+ metadata=self.metadata,
113
130
  )
114
131
 
115
132
  def check_call(
@@ -173,6 +190,11 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
173
190
  func_def = module.module_root_builder().define_function(
174
191
  self.name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params
175
192
  )
193
+ add_metadata(
194
+ func_def,
195
+ self.metadata,
196
+ additional_metadata={"unitary": self.ty.unitary_flags.value},
197
+ )
176
198
  return CompiledFunctionDef(
177
199
  self.id,
178
200
  self.name,
@@ -182,6 +204,7 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
182
204
  self.docstring,
183
205
  self.cfg,
184
206
  func_def,
207
+ metadata=self.metadata,
185
208
  )
186
209
 
187
210