guppylang-internals 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/cfg/builder.py +17 -2
  3. guppylang_internals/cfg/cfg.py +3 -0
  4. guppylang_internals/checker/cfg_checker.py +6 -0
  5. guppylang_internals/checker/core.py +1 -2
  6. guppylang_internals/checker/errors/wasm.py +7 -4
  7. guppylang_internals/checker/expr_checker.py +13 -8
  8. guppylang_internals/checker/func_checker.py +17 -13
  9. guppylang_internals/checker/linearity_checker.py +2 -10
  10. guppylang_internals/checker/modifier_checker.py +6 -2
  11. guppylang_internals/checker/unitary_checker.py +132 -0
  12. guppylang_internals/compiler/cfg_compiler.py +7 -6
  13. guppylang_internals/compiler/core.py +5 -5
  14. guppylang_internals/compiler/expr_compiler.py +42 -73
  15. guppylang_internals/compiler/modifier_compiler.py +2 -0
  16. guppylang_internals/decorator.py +86 -7
  17. guppylang_internals/definition/custom.py +4 -0
  18. guppylang_internals/definition/declaration.py +6 -2
  19. guppylang_internals/definition/function.py +12 -2
  20. guppylang_internals/definition/pytket_circuits.py +1 -0
  21. guppylang_internals/definition/struct.py +6 -3
  22. guppylang_internals/definition/wasm.py +42 -10
  23. guppylang_internals/engine.py +9 -3
  24. guppylang_internals/nodes.py +23 -24
  25. guppylang_internals/std/_internal/checker.py +13 -108
  26. guppylang_internals/std/_internal/compiler/array.py +1 -1
  27. guppylang_internals/std/_internal/compiler/list.py +1 -1
  28. guppylang_internals/std/_internal/compiler/platform.py +153 -0
  29. guppylang_internals/std/_internal/compiler/prelude.py +12 -4
  30. guppylang_internals/std/_internal/compiler/tket_exts.py +3 -4
  31. guppylang_internals/std/_internal/debug.py +18 -9
  32. guppylang_internals/std/_internal/util.py +1 -1
  33. guppylang_internals/tracing/object.py +10 -0
  34. guppylang_internals/tys/errors.py +23 -1
  35. guppylang_internals/tys/parsing.py +3 -3
  36. guppylang_internals/tys/printing.py +2 -8
  37. guppylang_internals/tys/qubit.py +37 -2
  38. guppylang_internals/tys/ty.py +60 -64
  39. guppylang_internals/wasm_util.py +129 -0
  40. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +4 -3
  41. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/RECORD +43 -40
  42. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
  43. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
@@ -15,7 +15,6 @@ from hugr import val as hv
15
15
  from hugr.build import function as hf
16
16
  from hugr.build.cond_loop import Conditional
17
17
  from hugr.build.dfg import DP, DfBase
18
- from typing_extensions import assert_never
19
18
 
20
19
  from guppylang_internals.ast_util import AstNode, AstVisitor, get_type
21
20
  from guppylang_internals.cfg.builder import tmp_vars
@@ -23,7 +22,6 @@ from guppylang_internals.checker.core import Variable, contains_subscript
23
22
  from guppylang_internals.checker.errors.generic import UnsupportedError
24
23
  from guppylang_internals.compiler.core import (
25
24
  DEBUG_EXTENSION,
26
- RESULT_EXTENSION,
27
25
  CompilerBase,
28
26
  CompilerContext,
29
27
  DFContainer,
@@ -53,7 +51,6 @@ from guppylang_internals.nodes import (
53
51
  PanicExpr,
54
52
  PartialApply,
55
53
  PlaceNode,
56
- ResultExpr,
57
54
  StateResultExpr,
58
55
  SubscriptAccessAndDrop,
59
56
  TensorCall,
@@ -66,7 +63,6 @@ from guppylang_internals.std._internal.compiler.arithmetic import (
66
63
  convert_itousize,
67
64
  )
68
65
  from guppylang_internals.std._internal.compiler.array import (
69
- array_clone,
70
66
  array_map,
71
67
  array_new,
72
68
  array_to_std_array,
@@ -80,8 +76,8 @@ from guppylang_internals.std._internal.compiler.list import (
80
76
  list_new,
81
77
  )
82
78
  from guppylang_internals.std._internal.compiler.prelude import (
83
- build_error,
84
79
  build_panic,
80
+ make_error,
85
81
  panic,
86
82
  )
87
83
  from guppylang_internals.std._internal.compiler.tket_bool import (
@@ -93,14 +89,12 @@ from guppylang_internals.std._internal.compiler.tket_bool import (
93
89
  )
94
90
  from guppylang_internals.tys.arg import ConstArg
95
91
  from guppylang_internals.tys.builtin import (
96
- array_type,
97
92
  bool_type,
98
93
  get_element_type,
99
94
  int_type,
100
- is_bool_type,
101
95
  is_frozenarray_type,
102
96
  )
103
- from guppylang_internals.tys.const import ConstValue
97
+ from guppylang_internals.tys.const import BoundConstVar, Const, ConstValue
104
98
  from guppylang_internals.tys.subst import Inst
105
99
  from guppylang_internals.tys.ty import (
106
100
  BoundTypeVar,
@@ -535,74 +529,48 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
535
529
  tuple_port = self.visit(node.value)
536
530
  return self._unpack_tuple(tuple_port, node.tuple_ty.element_types)[node.index]
537
531
 
538
- def visit_ResultExpr(self, node: ResultExpr) -> Wire:
539
- value_wire = self.visit(node.value)
540
- base_ty = node.base_ty.to_hugr(self.ctx)
541
- extra_args: list[ht.TypeArg] = []
542
- if isinstance(node.base_ty, NumericType):
543
- match node.base_ty.kind:
544
- case NumericType.Kind.Nat:
545
- base_name = "uint"
546
- extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)]
547
- case NumericType.Kind.Int:
548
- base_name = "int"
549
- extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)]
550
- case NumericType.Kind.Float:
551
- base_name = "f64"
552
- case kind:
553
- assert_never(kind)
554
- else:
555
- # The only other valid base type is bool
556
- assert is_bool_type(node.base_ty)
557
- base_name = "bool"
558
- if node.array_len is not None:
559
- op_name = f"result_array_{base_name}"
560
- size_arg = node.array_len.to_arg().to_hugr(self.ctx)
561
- extra_args = [size_arg, *extra_args]
562
- # As `borrow_array`s used by Guppy are linear, we need to clone it (knowing
563
- # that all elements in it are copyable) to avoid linearity violations when
564
- # both passing it to the result operation and returning it (as an inout
565
- # argument).
566
- value_wire, inout_wire = self.builder.add_op(
567
- array_clone(base_ty, size_arg), value_wire
568
- )
569
- func_ty = FunctionType(
570
- [
571
- FuncInput(
572
- array_type(node.base_ty, node.array_len), InputFlags.Inout
573
- ),
574
- ],
575
- NoneType(),
576
- )
577
- self._update_inout_ports(node.args, iter([inout_wire]), func_ty)
578
- if is_bool_type(node.base_ty):
579
- # We need to coerce a read on all the array elements if they are bools.
580
- array_read = array_read_bool(self.ctx)
581
- array_read = self.builder.load_function(array_read)
582
- map_op = array_map(OpaqueBool, size_arg, ht.Bool)
583
- value_wire = self.builder.add_op(map_op, value_wire, array_read)
584
- base_ty = ht.Bool
585
- # Turn `borrow_array` into regular `array`
586
- value_wire = self.builder.add_op(
587
- array_to_std_array(base_ty, size_arg), value_wire
588
- )
589
- hugr_ty: ht.Type = hugr.std.collections.array.Array(base_ty, size_arg)
590
- else:
591
- if is_bool_type(node.base_ty):
592
- base_ty = ht.Bool
593
- value_wire = self.builder.add_op(read_bool(), value_wire)
594
- op_name = f"result_{base_name}"
595
- hugr_ty = base_ty
532
+ def _visit_result_tag(self, tag: Const, loc: ast.expr) -> str:
533
+ """Helper method to resolve the tag string in `state_result` expressions.
596
534
 
597
- sig = ht.FunctionType(input=[hugr_ty], output=[])
598
- args = [ht.StringArg(node.tag), *extra_args]
599
- op = ops.ExtOp(RESULT_EXTENSION.get_op(op_name), signature=sig, args=args)
535
+ Also takes care of checking that the tag fits into the maximum tag length.
536
+ Once we go ahead with https://github.com/quantinuum/guppylang/discussions/1299,
537
+ this can be moved into type checking.
538
+ """
539
+ from guppylang_internals.std._internal.compiler.platform import (
540
+ TAG_MAX_LEN,
541
+ TooLongError,
542
+ )
600
543
 
601
- self.builder.add_op(op, value_wire)
602
- return self._pack_returns([], NoneType())
544
+ is_generic: BoundConstVar | None = None
545
+ match tag:
546
+ case ConstValue(value=str(v)):
547
+ tag_value = v
548
+ case BoundConstVar(idx=idx) as var:
549
+ assert self.ctx.current_mono_args is not None
550
+ match self.ctx.current_mono_args[idx]:
551
+ case ConstArg(const=ConstValue(value=str(v))):
552
+ tag_value = v
553
+ is_generic = var
554
+ case _:
555
+ raise InternalGuppyError("Unexpected tag monomorphization")
556
+ case _:
557
+ raise InternalGuppyError("Unexpected tag value")
558
+
559
+ if len(tag_value.encode("utf-8")) > TAG_MAX_LEN:
560
+ err = TooLongError(loc)
561
+ err.add_sub_diagnostic(TooLongError.Hint(None))
562
+ if is_generic:
563
+ err.add_sub_diagnostic(
564
+ TooLongError.GenericHint(None, is_generic.display_name, tag_value)
565
+ )
566
+ raise GuppyError(err)
567
+ return tag_value
603
568
 
604
569
  def visit_PanicExpr(self, node: PanicExpr) -> Wire:
605
- err = build_error(self.builder, node.signal, node.msg)
570
+ signal = self.visit(node.signal)
571
+ signal_usize = self.builder.add_op(convert_itousize(), signal)
572
+ msg = self.visit(node.msg)
573
+ err = self.builder.add_op(make_error(), signal_usize, msg)
606
574
  in_tys = [get_type(e).to_hugr(self.ctx) for e in node.values]
607
575
  out_tys = [ty.to_hugr(self.ctx) for ty in type_to_row(get_type(node))]
608
576
  args = [self.visit(e) for e in node.values]
@@ -627,12 +595,13 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
627
595
  return self._pack_returns([], NoneType())
628
596
 
629
597
  def visit_StateResultExpr(self, node: StateResultExpr) -> Wire:
598
+ tag_value = self._visit_result_tag(node.tag_value, node.tag_expr)
630
599
  num_qubits_arg = (
631
600
  node.array_len.to_arg().to_hugr(self.ctx)
632
601
  if node.array_len
633
602
  else ht.BoundedNatArg(len(node.args) - 1)
634
603
  )
635
- args = [ht.StringArg(node.tag), num_qubits_arg]
604
+ args = [ht.StringArg(tag_value), num_qubits_arg]
636
605
  sig = ht.FunctionType(
637
606
  [standard_array_type(ht.Qubit, num_qubits_arg)],
638
607
  [standard_array_type(ht.Qubit, num_qubits_arg)],
@@ -8,6 +8,7 @@ from guppylang_internals.checker.modifier_checker import non_copyable_front_othe
8
8
  from guppylang_internals.compiler.cfg_compiler import compile_cfg
9
9
  from guppylang_internals.compiler.core import CompilerContext, DFContainer
10
10
  from guppylang_internals.compiler.expr_compiler import ExprCompiler
11
+ from guppylang_internals.definition.function import add_unitarity_metadata
11
12
  from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode
12
13
  from guppylang_internals.std._internal.compiler.array import (
13
14
  array_new,
@@ -56,6 +57,7 @@ def compile_modified_block(
56
57
  func_builder = dfg.builder.module_root_builder().define_function(
57
58
  str(modified_block), hugr_ty.input, hugr_ty.output
58
59
  )
60
+ add_unitarity_metadata(func_builder, modified_block.ty.unitary_flags)
59
61
 
60
62
  # compile body
61
63
  cfg = compile_cfg(modified_block.cfg, func_builder, func_builder.inputs(), ctx)
@@ -1,11 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
+ import pathlib
4
5
  from typing import TYPE_CHECKING, ParamSpec, TypeVar, overload
5
6
 
6
7
  from hugr import ops
7
8
  from hugr import tys as ht
8
9
 
10
+ from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
9
11
  from guppylang_internals.compiler.core import (
10
12
  CompilerContext,
11
13
  GlobalConstId,
@@ -24,6 +26,7 @@ from guppylang_internals.definition.ty import OpaqueTypeDef, TypeDef
24
26
  from guppylang_internals.definition.wasm import RawWasmFunctionDef
25
27
  from guppylang_internals.dummy_decorator import _dummy_custom_decorator, sphinx_running
26
28
  from guppylang_internals.engine import DEF_STORE
29
+ from guppylang_internals.error import GuppyError
27
30
  from guppylang_internals.std._internal.checker import WasmCallChecker
28
31
  from guppylang_internals.std._internal.compiler.wasm import (
29
32
  WasmModuleCallCompiler,
@@ -39,6 +42,14 @@ from guppylang_internals.tys.ty import (
39
42
  InputFlags,
40
43
  NoneType,
41
44
  NumericType,
45
+ UnitaryFlags,
46
+ )
47
+ from guppylang_internals.wasm_util import (
48
+ ConcreteWasmModule,
49
+ WasmFileNotFound,
50
+ WasmFunctionNotInFile,
51
+ WasmSignatureError,
52
+ decode_wasm_functions,
42
53
  )
43
54
 
44
55
  if TYPE_CHECKING:
@@ -47,7 +58,6 @@ if TYPE_CHECKING:
47
58
  from collections.abc import Callable, Sequence
48
59
  from types import FrameType
49
60
 
50
- from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
51
61
  from guppylang_internals.tys.arg import Argument
52
62
  from guppylang_internals.tys.param import Parameter
53
63
  from guppylang_internals.tys.subst import Inst
@@ -75,6 +85,7 @@ def custom_function(
75
85
  higher_order_value: bool = True,
76
86
  name: str = "",
77
87
  signature: FunctionType | None = None,
88
+ unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
78
89
  ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
79
90
  """Decorator to add custom typing or compilation behaviour to function decls.
80
91
 
@@ -86,6 +97,8 @@ def custom_function(
86
97
 
87
98
  def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
88
99
  call_checker = checker or DefaultCallChecker()
100
+ if signature is not None:
101
+ object.__setattr__(signature, "unitary_flags", unitary_flags)
89
102
  func = RawCustomFunctionDef(
90
103
  DefId.fresh(),
91
104
  name or f.__name__,
@@ -95,6 +108,7 @@ def custom_function(
95
108
  compiler or NotImplementedCallCompiler(),
96
109
  higher_order_value,
97
110
  signature,
111
+ unitary_flags,
98
112
  )
99
113
  DEF_STORE.register_def(func, get_calling_frame())
100
114
  return GuppyFunctionDefinition(func)
@@ -108,6 +122,7 @@ def hugr_op(
108
122
  higher_order_value: bool = True,
109
123
  name: str = "",
110
124
  signature: FunctionType | None = None,
125
+ unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
111
126
  ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
112
127
  """Decorator to annotate function declarations as HUGR ops.
113
128
 
@@ -119,7 +134,14 @@ def hugr_op(
119
134
  value.
120
135
  name: The name of the function.
121
136
  """
122
- return custom_function(OpCompiler(op), checker, higher_order_value, name, signature)
137
+ return custom_function(
138
+ OpCompiler(op),
139
+ checker,
140
+ higher_order_value,
141
+ name,
142
+ signature,
143
+ unitary_flags=unitary_flags,
144
+ )
123
145
 
124
146
 
125
147
  def extend_type(defn: TypeDef, return_class: bool = False) -> Callable[[type], type]:
@@ -188,6 +210,12 @@ def custom_type(
188
210
  def wasm_module(
189
211
  filename: str,
190
212
  ) -> Callable[[builtins.type[T]], GuppyDefinition]:
213
+ wasm_file = pathlib.Path(filename)
214
+ if wasm_file.is_file():
215
+ wasm_sigs = decode_wasm_functions(filename)
216
+ else:
217
+ raise GuppyError(WasmFileNotFound(None, filename))
218
+
191
219
  def type_def_wrapper(
192
220
  id: DefId,
193
221
  name: str,
@@ -198,10 +226,19 @@ def wasm_module(
198
226
  assert config is None
199
227
  return WasmModuleTypeDef(id, name, defined_at, wasm_file)
200
228
 
201
- f = ext_module_decorator(
202
- type_def_wrapper, WasmModuleInitCompiler(), WasmModuleDiscardCompiler(), True
229
+ decorator = ext_module_decorator(
230
+ type_def_wrapper,
231
+ WasmModuleInitCompiler(),
232
+ WasmModuleDiscardCompiler(),
233
+ True,
234
+ wasm_sigs,
203
235
  )
204
- return f(filename, None)
236
+
237
+ def inner_fun(ty: builtins.type[T]) -> GuppyDefinition:
238
+ decorator_inner = decorator(filename, None)
239
+ return decorator_inner(ty)
240
+
241
+ return inner_fun
205
242
 
206
243
 
207
244
  def ext_module_decorator(
@@ -209,9 +246,9 @@ def ext_module_decorator(
209
246
  init_compiler: CustomInoutCallCompiler,
210
247
  discard_compiler: CustomInoutCallCompiler,
211
248
  init_arg: bool, # Whether the init function should take a nat argument
249
+ wasm_sigs: ConcreteWasmModule
250
+ | None = None, # For @wasm_module, we must be passed a parsed wasm file
212
251
  ) -> Callable[[str, str | None], Callable[[builtins.type[T]], GuppyDefinition]]:
213
- from guppylang.defs import GuppyDefinition
214
-
215
252
  def fun(
216
253
  filename: str, module: str | None
217
254
  ) -> Callable[[builtins.type[T]], GuppyDefinition]:
@@ -231,6 +268,47 @@ def ext_module_decorator(
231
268
  for val in cls.__dict__.values():
232
269
  if isinstance(val, GuppyDefinition):
233
270
  DEF_STORE.register_impl(ext_module.id, val.wrapped.name, val.id)
271
+ wasm_def: RawWasmFunctionDef
272
+ if isinstance(val, GuppyFunctionDefinition) and isinstance(
273
+ val.wrapped, RawWasmFunctionDef
274
+ ):
275
+ wasm_def = val.wrapped
276
+ else:
277
+ continue
278
+ # wasm_sigs should only have not been provided if we have
279
+ # defined @wasm functions in a class which didn't use the
280
+ # @wasm_module decorator.
281
+ assert wasm_sigs is not None
282
+ if wasm_def.wasm_index is not None:
283
+ name = wasm_sigs.functions[wasm_def.wasm_index]
284
+ assert name in wasm_sigs.function_sigs
285
+ wasm_sig_or_err = wasm_sigs.function_sigs[name]
286
+ else:
287
+ if wasm_def.name in wasm_sigs.function_sigs:
288
+ wasm_sig_or_err = wasm_sigs.function_sigs[wasm_def.name]
289
+ else:
290
+ raise GuppyError(
291
+ WasmFunctionNotInFile(
292
+ wasm_def.defined_at,
293
+ wasm_def.name,
294
+ ).add_sub_diagnostic(
295
+ WasmFunctionNotInFile.WasmFileNote(
296
+ None,
297
+ wasm_sigs.filename,
298
+ )
299
+ )
300
+ )
301
+ if isinstance(wasm_sig_or_err, FunctionType):
302
+ DEF_STORE.register_wasm_function(wasm_def.id, wasm_sig_or_err)
303
+ elif isinstance(wasm_sig_or_err, str):
304
+ raise GuppyError(
305
+ WasmSignatureError(
306
+ None, wasm_def.name, filename
307
+ ).add_sub_diagnostic(
308
+ WasmSignatureError.Message(None, wasm_sig_or_err)
309
+ )
310
+ )
311
+
234
312
  # Add a constructor to the class
235
313
  if init_arg:
236
314
  init_fn_ty = FunctionType(
@@ -315,6 +393,7 @@ def wasm_helper(fn_id: int | None, f: Callable[P, T]) -> GuppyFunctionDefinition
315
393
  WasmModuleCallCompiler(f.__name__, fn_id),
316
394
  True,
317
395
  signature=None,
396
+ wasm_index=fn_id,
318
397
  )
319
398
  DEF_STORE.register_def(func, get_calling_frame())
320
399
  return GuppyFunctionDefinition(func)
@@ -44,6 +44,7 @@ from guppylang_internals.tys.ty import (
44
44
  InputFlags,
45
45
  NoneType,
46
46
  Type,
47
+ UnitaryFlags,
47
48
  type_to_row,
48
49
  )
49
50
 
@@ -114,6 +115,8 @@ class RawCustomFunctionDef(ParsableDef):
114
115
 
115
116
  signature: FunctionType | None
116
117
 
118
+ unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags)
119
+
117
120
  description: str = field(default="function", init=False)
118
121
 
119
122
  def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
@@ -136,6 +139,7 @@ class RawCustomFunctionDef(ParsableDef):
136
139
  raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
137
140
  sig = self.signature or self._get_signature(func_ast, globals)
138
141
  ty = sig or FunctionType([], NoneType())
142
+ ty = ty.with_unitary_flags(self.unitary_flags)
139
143
  return CustomFunctionDef(
140
144
  self.id,
141
145
  self.name,
@@ -34,7 +34,7 @@ from guppylang_internals.nodes import GlobalCall
34
34
  from guppylang_internals.span import SourceMap
35
35
  from guppylang_internals.tys.param import Parameter
36
36
  from guppylang_internals.tys.subst import Inst, Subst
37
- from guppylang_internals.tys.ty import Type
37
+ from guppylang_internals.tys.ty import Type, UnitaryFlags
38
38
 
39
39
 
40
40
  @dataclass(frozen=True)
@@ -65,10 +65,14 @@ class RawFunctionDecl(ParsableDef):
65
65
  python_func: PyFunc
66
66
  description: str = field(default="function", init=False)
67
67
 
68
+ unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True)
69
+
68
70
  def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
69
71
  """Parses and checks the user-provided signature of the function."""
70
72
  func_ast, docstring = parse_py_func(self.python_func, sources)
71
- ty = check_signature(func_ast, globals, self.id)
73
+ ty = check_signature(
74
+ func_ast, globals, self.id, unitary_flags=self.unitary_flags
75
+ )
72
76
  if not has_empty_body(func_ast):
73
77
  raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
74
78
  # Make sure we won't need monomorphization to compile this declaration
@@ -43,7 +43,7 @@ from guppylang_internals.error import GuppyError
43
43
  from guppylang_internals.nodes import GlobalCall
44
44
  from guppylang_internals.span import SourceMap
45
45
  from guppylang_internals.tys.subst import Inst, Subst
46
- from guppylang_internals.tys.ty import FunctionType, Type, type_to_row
46
+ from guppylang_internals.tys.ty import FunctionType, Type, UnitaryFlags, type_to_row
47
47
 
48
48
  if TYPE_CHECKING:
49
49
  from guppylang_internals.tys.param import Parameter
@@ -70,10 +70,14 @@ class RawFunctionDef(ParsableDef):
70
70
 
71
71
  description: str = field(default="function", init=False)
72
72
 
73
+ unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True)
74
+
73
75
  def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
74
76
  """Parses and checks the user-provided signature of the function."""
75
77
  func_ast, docstring = parse_py_func(self.python_func, sources)
76
- ty = check_signature(func_ast, globals, self.id)
78
+ ty = check_signature(
79
+ func_ast, globals, self.id, unitary_flags=self.unitary_flags
80
+ )
77
81
  return ParsedFunctionDef(self.id, self.name, func_ast, ty, docstring)
78
82
 
79
83
 
@@ -173,6 +177,7 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
173
177
  func_def = module.module_root_builder().define_function(
174
178
  self.name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params
175
179
  )
180
+ add_unitarity_metadata(func_def, self.ty.unitary_flags)
176
181
  return CompiledFunctionDef(
177
182
  self.id,
178
183
  self.name,
@@ -300,3 +305,8 @@ def parse_source(source_lines: list[str], line_offset: int) -> tuple[str, ast.AS
300
305
  else:
301
306
  node = ast.parse(source).body[0]
302
307
  return source, node, line_offset
308
+
309
+
310
+ def add_unitarity_metadata(func: hf.Function, flags: UnitaryFlags) -> None:
311
+ """Stores unitarity annotations in the metadate of a Hugr function definition."""
312
+ func.metadata["unitary"] = flags.value
@@ -370,6 +370,7 @@ def _signature_from_circuit(
370
370
  use_arrays: bool = False,
371
371
  ) -> FunctionType:
372
372
  """Helper function for inferring a function signature from a pytket circuit."""
373
+ # May want to set proper unitary flags in the future.
373
374
  try:
374
375
  import pytket
375
376
 
@@ -273,13 +273,16 @@ class CheckedStructDef(TypeDef, CompiledDef):
273
273
 
274
274
  constructor_sig = FunctionType(
275
275
  inputs=[
276
- FuncInput(f.ty, InputFlags.Owned if f.ty.linear else InputFlags.NoFlags)
276
+ FuncInput(
277
+ f.ty,
278
+ InputFlags.Owned if f.ty.linear else InputFlags.NoFlags,
279
+ f.name,
280
+ )
277
281
  for f in self.fields
278
282
  ],
279
283
  output=StructType(
280
284
  defn=self, args=[p.to_bound(i) for i, p in enumerate(self.params)]
281
285
  ),
282
- input_names=[f.name for f in self.fields],
283
286
  params=self.params,
284
287
  )
285
288
  constructor_def = CustomFunctionDef(
@@ -317,7 +320,7 @@ def parse_py_class(
317
320
  raise GuppyError(UnknownSourceError(None, cls))
318
321
 
319
322
  # We can't rely on `inspect.getsourcelines` since it doesn't work properly for
320
- # classes prior to Python 3.13. See https://github.com/CQCL/guppylang/issues/1107.
323
+ # classes prior to Python 3.13. See https://github.com/quantinuum/guppylang/issues/1107.
321
324
  # Instead, we reproduce the behaviour of Python >= 3.13 using the `__firstlineno__`
322
325
  # attribute. See https://github.com/python/cpython/blob/3.13/Lib/inspect.py#L1052.
323
326
  # In the decorator, we make sure that `__firstlineno__` is set, even if we're not
@@ -1,3 +1,4 @@
1
+ from dataclasses import dataclass, field
1
2
  from typing import TYPE_CHECKING
2
3
 
3
4
  from guppylang_internals.ast_util import AstNode
@@ -9,7 +10,8 @@ from guppylang_internals.definition.custom import (
9
10
  CustomFunctionDef,
10
11
  RawCustomFunctionDef,
11
12
  )
12
- from guppylang_internals.error import GuppyError
13
+ from guppylang_internals.engine import DEF_STORE
14
+ from guppylang_internals.error import GuppyError, GuppyTypeError
13
15
  from guppylang_internals.span import SourceMap
14
16
  from guppylang_internals.tys.builtin import wasm_module_name
15
17
  from guppylang_internals.tys.ty import (
@@ -21,24 +23,35 @@ from guppylang_internals.tys.ty import (
21
23
  TupleType,
22
24
  Type,
23
25
  )
26
+ from guppylang_internals.wasm_util import WasmSigMismatchError
24
27
 
25
28
  if TYPE_CHECKING:
26
29
  from guppylang_internals.checker.core import Globals
27
30
 
28
31
 
32
+ @dataclass(frozen=True)
29
33
  class RawWasmFunctionDef(RawCustomFunctionDef):
30
- def sanitise_type(self, loc: AstNode | None, fun_ty: FunctionType) -> None:
34
+ # If a function is specified in the @wasm decorator by its index in the wasm
35
+ # file, record what the index was.
36
+ wasm_index: int | None = field(default=None)
37
+
38
+ def sanitise_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
31
39
  # Place to highlight in error messages
32
- match fun_ty.inputs[0]:
33
- case FuncInput(ty=ty, flags=InputFlags.Inout) if wasm_module_name(
40
+ match fun_ty.inputs:
41
+ case [FuncInput(ty=ty, flags=InputFlags.Inout), *args] if wasm_module_name(
34
42
  ty
35
43
  ) is not None:
36
- pass
37
- case FuncInput(ty=ty):
38
- raise GuppyError(FirstArgNotModule(loc, ty))
39
- for inp in fun_ty.inputs[1:]:
40
- if not self.is_type_wasmable(inp.ty):
41
- raise GuppyError(UnWasmableType(loc, inp.ty))
44
+ for inp in args:
45
+ if not self.is_type_wasmable(inp.ty):
46
+ raise GuppyError(UnWasmableType(loc, inp.ty))
47
+ case [FuncInput(ty=ty), *_]:
48
+ raise GuppyError(
49
+ FirstArgNotModule(loc).add_sub_diagnostic(
50
+ FirstArgNotModule.GotOtherType(loc, ty)
51
+ )
52
+ )
53
+ case []:
54
+ raise GuppyError(FirstArgNotModule(loc))
42
55
  if not self.is_type_wasmable(fun_ty.output):
43
56
  match fun_ty.output:
44
57
  case NoneType():
@@ -46,6 +59,23 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
46
59
  case _:
47
60
  raise GuppyError(UnWasmableType(loc, fun_ty.output))
48
61
 
62
+ def validate_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
63
+ type_in_wasm: FunctionType = DEF_STORE.wasm_functions[self.id]
64
+ assert type_in_wasm is not None
65
+ # Drop the first arg because it should be "self"
66
+ expected_type = FunctionType(fun_ty.inputs[1:], fun_ty.output)
67
+
68
+ if expected_type != type_in_wasm:
69
+ raise GuppyTypeError(
70
+ WasmSigMismatchError(loc)
71
+ .add_sub_diagnostic(
72
+ WasmSigMismatchError.Declaration(None, declared=str(expected_type))
73
+ )
74
+ .add_sub_diagnostic(
75
+ WasmSigMismatchError.Actual(None, actual=str(type_in_wasm))
76
+ )
77
+ )
78
+
49
79
  def is_type_wasmable(self, ty: Type) -> bool:
50
80
  match ty:
51
81
  case NumericType():
@@ -57,5 +87,7 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
57
87
 
58
88
  def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
59
89
  parsed = super().parse(globals, sources)
90
+ assert parsed.defined_at is not None
60
91
  self.sanitise_type(parsed.defined_at, parsed.ty)
92
+ self.validate_type(parsed.defined_at, parsed.ty)
61
93
  return parsed
@@ -46,6 +46,7 @@ from guppylang_internals.tys.builtin import (
46
46
  string_type_def,
47
47
  tuple_type_def,
48
48
  )
49
+ from guppylang_internals.tys.ty import FunctionType
49
50
 
50
51
  if TYPE_CHECKING:
51
52
  from guppylang_internals.compiler.core import MonoDefId
@@ -87,6 +88,7 @@ class DefinitionStore:
87
88
  raw_defs: dict[DefId, RawDef]
88
89
  impls: defaultdict[DefId, dict[str, DefId]]
89
90
  impl_parents: dict[DefId, DefId]
91
+ wasm_functions: dict[DefId, FunctionType]
90
92
  frames: dict[DefId, FrameType]
91
93
  sources: SourceMap
92
94
 
@@ -96,6 +98,7 @@ class DefinitionStore:
96
98
  self.impl_parents = {}
97
99
  self.frames = {}
98
100
  self.sources = SourceMap()
101
+ self.wasm_functions = {}
99
102
 
100
103
  def register_def(self, defn: RawDef, frame: FrameType | None) -> None:
101
104
  self.raw_defs[defn.id] = defn
@@ -123,6 +126,9 @@ class DefinitionStore:
123
126
  assert frame is not None
124
127
  self.frames[impl_id] = frame
125
128
 
129
+ def register_wasm_function(self, fn_id: DefId, sig: FunctionType) -> None:
130
+ self.wasm_functions[fn_id] = sig
131
+
126
132
 
127
133
  DEF_STORE: DefinitionStore = DefinitionStore()
128
134
 
@@ -263,8 +269,8 @@ class CompilationEngine:
263
269
  and isinstance(compiled_def, CompiledCallableDef)
264
270
  and not isinstance(graph.hugr[compiled_def.hugr_node].op, ops.FuncDecl)
265
271
  ):
266
- # if compiling a region set it as the HUGR entrypoint
267
- # can be loosened after https://github.com/CQCL/hugr/issues/2501 is fixed
272
+ # if compiling a region set it as the HUGR entrypoint can be
273
+ # loosened after https://github.com/quantinuum/hugr/issues/2501 is fixed
268
274
  graph.hugr.entrypoint = compiled_def.hugr_node
269
275
 
270
276
  # TODO: Currently the list of extensions is manually managed by the user.
@@ -278,7 +284,7 @@ class CompilationEngine:
278
284
  guppylang_internals.compiler.hugr_extension.EXTENSION,
279
285
  *self.additional_extensions,
280
286
  ]
281
- # TODO replace with computed extensions after https://github.com/CQCL/guppylang/issues/550
287
+ # TODO replace with computed extensions after https://github.com/quantinuum/guppylang/issues/550
282
288
  all_used_extensions = [
283
289
  *extensions,
284
290
  hugr.std.prelude.PRELUDE_EXTENSION,