guppylang-internals 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +21 -0
  3. guppylang_internals/cfg/bb.py +20 -0
  4. guppylang_internals/cfg/builder.py +118 -5
  5. guppylang_internals/cfg/cfg.py +3 -0
  6. guppylang_internals/checker/cfg_checker.py +6 -0
  7. guppylang_internals/checker/core.py +5 -2
  8. guppylang_internals/checker/errors/generic.py +32 -1
  9. guppylang_internals/checker/errors/type_errors.py +14 -0
  10. guppylang_internals/checker/errors/wasm.py +7 -4
  11. guppylang_internals/checker/expr_checker.py +58 -17
  12. guppylang_internals/checker/func_checker.py +18 -14
  13. guppylang_internals/checker/linearity_checker.py +67 -10
  14. guppylang_internals/checker/modifier_checker.py +120 -0
  15. guppylang_internals/checker/stmt_checker.py +48 -1
  16. guppylang_internals/checker/unitary_checker.py +132 -0
  17. guppylang_internals/compiler/cfg_compiler.py +7 -6
  18. guppylang_internals/compiler/core.py +93 -56
  19. guppylang_internals/compiler/expr_compiler.py +72 -168
  20. guppylang_internals/compiler/modifier_compiler.py +176 -0
  21. guppylang_internals/compiler/stmt_compiler.py +15 -8
  22. guppylang_internals/decorator.py +86 -7
  23. guppylang_internals/definition/custom.py +39 -1
  24. guppylang_internals/definition/declaration.py +9 -6
  25. guppylang_internals/definition/function.py +12 -2
  26. guppylang_internals/definition/parameter.py +8 -3
  27. guppylang_internals/definition/pytket_circuits.py +14 -41
  28. guppylang_internals/definition/struct.py +13 -7
  29. guppylang_internals/definition/ty.py +3 -3
  30. guppylang_internals/definition/wasm.py +42 -10
  31. guppylang_internals/engine.py +9 -3
  32. guppylang_internals/experimental.py +5 -0
  33. guppylang_internals/nodes.py +147 -24
  34. guppylang_internals/std/_internal/checker.py +13 -108
  35. guppylang_internals/std/_internal/compiler/array.py +95 -283
  36. guppylang_internals/std/_internal/compiler/list.py +1 -1
  37. guppylang_internals/std/_internal/compiler/platform.py +153 -0
  38. guppylang_internals/std/_internal/compiler/prelude.py +12 -4
  39. guppylang_internals/std/_internal/compiler/tket_exts.py +8 -2
  40. guppylang_internals/std/_internal/debug.py +18 -9
  41. guppylang_internals/std/_internal/util.py +1 -1
  42. guppylang_internals/tracing/object.py +10 -0
  43. guppylang_internals/tracing/unpacking.py +19 -20
  44. guppylang_internals/tys/arg.py +18 -3
  45. guppylang_internals/tys/builtin.py +2 -5
  46. guppylang_internals/tys/const.py +33 -4
  47. guppylang_internals/tys/errors.py +23 -1
  48. guppylang_internals/tys/param.py +31 -16
  49. guppylang_internals/tys/parsing.py +11 -24
  50. guppylang_internals/tys/printing.py +2 -8
  51. guppylang_internals/tys/qubit.py +62 -0
  52. guppylang_internals/tys/subst.py +8 -26
  53. guppylang_internals/tys/ty.py +91 -85
  54. guppylang_internals/wasm_util.py +129 -0
  55. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +6 -5
  56. guppylang_internals-0.26.0.dist-info/RECORD +104 -0
  57. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
  58. guppylang_internals-0.24.0.dist-info/RECORD +0 -98
  59. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
@@ -0,0 +1,176 @@
1
+ """Hugr generation for modifiers."""
2
+
3
+ from hugr import Wire, ops
4
+ from hugr import tys as ht
5
+
6
+ from guppylang_internals.ast_util import get_type
7
+ from guppylang_internals.checker.modifier_checker import non_copyable_front_others_back
8
+ from guppylang_internals.compiler.cfg_compiler import compile_cfg
9
+ from guppylang_internals.compiler.core import CompilerContext, DFContainer
10
+ from guppylang_internals.compiler.expr_compiler import ExprCompiler
11
+ from guppylang_internals.definition.function import add_unitarity_metadata
12
+ from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode
13
+ from guppylang_internals.std._internal.compiler.array import (
14
+ array_new,
15
+ array_to_std_array,
16
+ standard_array_type,
17
+ std_array_to_array,
18
+ unpack_array,
19
+ )
20
+ from guppylang_internals.std._internal.compiler.tket_exts import MODIFIER_EXTENSION
21
+ from guppylang_internals.tys.builtin import int_type, is_array_type
22
+ from guppylang_internals.tys.ty import InputFlags
23
+
24
+
25
def compile_modified_block(
    modified_block: CheckedModifiedBlock,
    dfg: DFContainer,
    ctx: CompilerContext,
    expr_compiler: ExprCompiler,
) -> Wire:
    """Compiles a block guarded by modifiers (dagger, power, control) into Hugr.

    The body of the block is compiled into a fresh function definition at the
    module root, which is then loaded as a function value. The requested modifiers
    are applied to that value using the ops of the modifier extension. The
    resulting function is invoked via `CallIndirect` with the control-qubit arrays
    followed by the captured variables. Afterwards, the control qubits and every
    captured variable flagged as `Inout` are rebound in `dfg` to the wires returned
    by the call.

    Args:
        modified_block: The checked block together with its modifiers.
        dfg: The dataflow container the call is emitted into; its variable
            bindings are updated in place.
        ctx: The compiler context.
        expr_compiler: Used to compile the power exponents and the control-qubit
            expressions.

    Returns:
        The `CallIndirect` node performing the call to the modified function.
    """
    DAGGER_OP_NAME = "DaggerModifier"
    CONTROL_OP_NAME = "ControlModifier"
    POWER_OP_NAME = "PowerModifier"

    dagger_op_def = MODIFIER_EXTENSION.get_op(DAGGER_OP_NAME)
    control_op_def = MODIFIER_EXTENSION.get_op(CONTROL_OP_NAME)
    power_op_def = MODIFIER_EXTENSION.get_op(POWER_OP_NAME)

    body_ty = modified_block.ty
    # TODO: Shouldn't this be `to_hugr_poly` since it can contain
    # a variable with a generic type?
    hugr_ty = body_ty.to_hugr(ctx)
    # The modifier ops are parametrised over the Hugr types of the borrowed
    # (`Inout`) runtime inputs and of the remaining runtime inputs. Comptime
    # inputs are excluded from both lists.
    in_out_ht = [
        fn_inp.ty.to_hugr(ctx)
        for fn_inp in body_ty.inputs
        if InputFlags.Inout in fn_inp.flags and InputFlags.Comptime not in fn_inp.flags
    ]
    other_in_ht = [
        fn_inp.ty.to_hugr(ctx)
        for fn_inp in body_ty.inputs
        if InputFlags.Inout not in fn_inp.flags
        and InputFlags.Comptime not in fn_inp.flags
    ]
    in_out_arg = ht.ListArg([t.type_arg() for t in in_out_ht])
    other_in_arg = ht.ListArg([t.type_arg() for t in other_in_ht])

    func_builder = dfg.builder.module_root_builder().define_function(
        str(modified_block), hugr_ty.input, hugr_ty.output
    )
    add_unitarity_metadata(func_builder, modified_block.ty.unitary_flags)

    # Compile the body into the new function definition
    cfg = compile_cfg(modified_block.cfg, func_builder, func_builder.inputs(), ctx)
    func_builder.set_outputs(*cfg)

    # Load the compiled body as a function value
    call = dfg.builder.load_function(func_builder, hugr_ty)

    # Function inputs: the captured variables, non-copyable ones first
    captured = [v for v, _ in modified_block.captured.values()]
    captured = non_copyable_front_others_back(captured)
    args = [dfg[v] for v in captured]

    # Apply modifiers
    if modified_block.has_dagger():
        dagger_ty = ht.FunctionType([hugr_ty], [hugr_ty])
        call = dfg.builder.add_op(
            ops.ExtOp(
                dagger_op_def,
                dagger_ty,
                [in_out_arg, other_in_arg],
            ),
            call,
        )
    if modified_block.has_power():
        power_ty = ht.FunctionType([hugr_ty, int_type().to_hugr(ctx)], [hugr_ty])
        for power in modified_block.power:
            num = expr_compiler.compile(power.iter, dfg)
            call = dfg.builder.add_op(
                ops.ExtOp(
                    power_op_def,
                    power_ty,
                    [in_out_arg, other_in_arg],
                ),
                call,
                num,
            )
    qubit_num_args: list[ht.TypeArg] = []
    if modified_block.has_control():
        for control in modified_block.control:
            assert control.qubit_num is not None
            qubit_num: ht.TypeArg
            if isinstance(control.qubit_num, int):
                qubit_num = ht.BoundedNatArg(control.qubit_num)
            else:
                qubit_num = control.qubit_num.to_arg().to_hugr(ctx)
            qubit_num_args.append(qubit_num)
            std_array = standard_array_type(ht.Qubit, qubit_num)

            # Control operator: it *prepends* a new control array to the signature
            # of the function it is applied to. After applying controls
            # c_1, ..., c_n, the inputs are ordered [c_n, ..., c_1, *inputs].
            input_fn_ty = hugr_ty
            output_fn_ty = ht.FunctionType(
                [std_array, *hugr_ty.input], [std_array, *hugr_ty.output]
            )
            op = ops.ExtOp(
                control_op_def,
                ht.FunctionType([input_fn_ty], [output_fn_ty]),
                [qubit_num, in_out_arg, other_in_arg],
            )
            call = dfg.builder.add_op(op, call)
            # update types
            in_out_arg = ht.ListArg([std_array.type_arg(), *in_out_arg.elems])
            hugr_ty = output_fn_ty

    # Prepare control arguments (in the order of `modified_block.control`)
    ctrl_args: list[Wire] = []
    for i, control in enumerate(modified_block.control):
        if is_array_type(get_type(control.ctrl[0])):
            control_array = expr_compiler.compile(control.ctrl[0], dfg)
            control_array = dfg.builder.add_op(
                array_to_std_array(ht.Qubit, qubit_num_args[i]), control_array
            )
        else:
            cs = [expr_compiler.compile(c, dfg) for c in control.ctrl]
            control_array = dfg.builder.add_op(
                array_new(ht.Qubit, len(control.ctrl)), *cs
            )
            control_array = dfg.builder.add_op(
                array_to_std_array(ht.Qubit, qubit_num_args[i]), *control_array
            )
        ctrl_args.append(control_array)

    # Call. Every control modifier prepended its array, so the modified function
    # expects the control arrays in reverse order of application.
    call = dfg.builder.add_op(
        ops.CallIndirect(),
        call,
        *reversed(ctrl_args),
        *args,
    )
    outports = iter(call)

    # The control arrays are returned in the same reversed order; restore the
    # order of `modified_block.control` before unpacking them.
    ctrl_outports = [next(outports) for _ in ctrl_args]
    ctrl_outports.reverse()

    # Unpack controls and rebind the control qubits
    for i, (control, outport) in enumerate(
        zip(modified_block.control, ctrl_outports, strict=True)
    ):
        control_array = dfg.builder.add_op(
            std_array_to_array(ht.Qubit, qubit_num_args[i]), outport
        )
        if is_array_type(get_type(control.ctrl[0])):
            c = control.ctrl[0]
            assert isinstance(c, PlaceNode)
            dfg[c.place] = control_array
        else:
            unpacked = unpack_array(dfg.builder, control_array)
            for c, new_c in zip(control.ctrl, unpacked, strict=False):
                assert isinstance(c, PlaceNode)
                dfg[c.place] = new_c

    # Rebind the borrowed captured variables to the values returned by the call
    for arg in captured:
        if InputFlags.Inout in arg.flags:
            dfg[arg] = next(outports)

    return call
@@ -2,7 +2,6 @@ import ast
2
2
  import functools
3
3
  from collections.abc import Sequence
4
4
 
5
- import hugr.tys as ht
6
5
  from hugr import Wire, ops
7
6
  from hugr.build.dfg import DfBase
8
7
 
@@ -18,6 +17,7 @@ from guppylang_internals.compiler.expr_compiler import ExprCompiler
18
17
  from guppylang_internals.error import InternalGuppyError
19
18
  from guppylang_internals.nodes import (
20
19
  ArrayUnpack,
20
+ CheckedModifiedBlock,
21
21
  CheckedNestedFunctionDef,
22
22
  IterableUnpack,
23
23
  PlaceNode,
@@ -120,8 +120,9 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
120
120
  ports[len(left) : -len(right)] if right else ports[len(left) :]
121
121
  )
122
122
  elt = get_element_type(array_ty).to_hugr(self.ctx)
123
- opts = [self.builder.add_op(ops.Some(elt), p) for p in starred_ports]
124
- array = self.builder.add_op(array_new(ht.Option(elt), len(opts)), *opts)
123
+ array = self.builder.add_op(
124
+ array_new(elt, len(starred_ports)), *starred_ports
125
+ )
125
126
  self._assign(starred, array)
126
127
 
127
128
  @_assign.register
@@ -130,7 +131,7 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
130
131
  # Given an assignment pattern `left, *starred, right`, pop from the left and
131
132
  # right, leaving us with the starred array in the middle
132
133
  length = lhs.length
133
- opt_elt_ty = ht.Option(lhs.elt_type.to_hugr(self.ctx))
134
+ elt_ty = lhs.elt_type.to_hugr(self.ctx)
134
135
 
135
136
  def pop(
136
137
  array: Wire, length: int, pats: list[ast.expr], from_left: bool
@@ -141,10 +142,9 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
141
142
  elts = []
142
143
  for i in range(num_pats):
143
144
  res = self.builder.add_op(
144
- array_pop(opt_elt_ty, length - i, from_left), array
145
+ array_pop(elt_ty, length - i, from_left), array
145
146
  )
146
- [elt_opt, array] = build_unwrap(self.builder, res, err)
147
- [elt] = build_unwrap(self.builder, elt_opt, err)
147
+ [elt, array] = build_unwrap(self.builder, res, err)
148
148
  elts.append(elt)
149
149
  # Assign elements to the given patterns
150
150
  for pat, elt in zip(
@@ -164,7 +164,7 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
164
164
  self._assign(lhs.pattern.starred, array)
165
165
  else:
166
166
  assert length == 0
167
- self.builder.add_op(array_discard_empty(opt_elt_ty), array)
167
+ self.builder.add_op(array_discard_empty(elt_ty), array)
168
168
 
169
169
  @_assign.register
170
170
  def _assign_iterable(self, lhs: IterableUnpack, port: Wire) -> None:
@@ -221,3 +221,10 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
221
221
  var = Variable(node.name, node.ty, node)
222
222
  loaded_func = compile_local_func_def(node, self.dfg, self.ctx)
223
223
  self.dfg[var] = loaded_func
224
+
225
def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None:
    """Compiles a modified block statement via the dedicated modifier compiler."""
    # Function-local import; presumably done to avoid a circular import between
    # the compiler modules (TODO confirm).
    from guppylang_internals.compiler import modifier_compiler

    modifier_compiler.compile_modified_block(
        node, self.dfg, self.ctx, self.expr_compiler
    )
@@ -1,11 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
+ import pathlib
4
5
  from typing import TYPE_CHECKING, ParamSpec, TypeVar, overload
5
6
 
6
7
  from hugr import ops
7
8
  from hugr import tys as ht
8
9
 
10
+ from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
9
11
  from guppylang_internals.compiler.core import (
10
12
  CompilerContext,
11
13
  GlobalConstId,
@@ -24,6 +26,7 @@ from guppylang_internals.definition.ty import OpaqueTypeDef, TypeDef
24
26
  from guppylang_internals.definition.wasm import RawWasmFunctionDef
25
27
  from guppylang_internals.dummy_decorator import _dummy_custom_decorator, sphinx_running
26
28
  from guppylang_internals.engine import DEF_STORE
29
+ from guppylang_internals.error import GuppyError
27
30
  from guppylang_internals.std._internal.checker import WasmCallChecker
28
31
  from guppylang_internals.std._internal.compiler.wasm import (
29
32
  WasmModuleCallCompiler,
@@ -39,6 +42,14 @@ from guppylang_internals.tys.ty import (
39
42
  InputFlags,
40
43
  NoneType,
41
44
  NumericType,
45
+ UnitaryFlags,
46
+ )
47
+ from guppylang_internals.wasm_util import (
48
+ ConcreteWasmModule,
49
+ WasmFileNotFound,
50
+ WasmFunctionNotInFile,
51
+ WasmSignatureError,
52
+ decode_wasm_functions,
42
53
  )
43
54
 
44
55
  if TYPE_CHECKING:
@@ -47,7 +58,6 @@ if TYPE_CHECKING:
47
58
  from collections.abc import Callable, Sequence
48
59
  from types import FrameType
49
60
 
50
- from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
51
61
  from guppylang_internals.tys.arg import Argument
52
62
  from guppylang_internals.tys.param import Parameter
53
63
  from guppylang_internals.tys.subst import Inst
@@ -75,6 +85,7 @@ def custom_function(
75
85
  higher_order_value: bool = True,
76
86
  name: str = "",
77
87
  signature: FunctionType | None = None,
88
+ unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
78
89
  ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
79
90
  """Decorator to add custom typing or compilation behaviour to function decls.
80
91
 
@@ -86,6 +97,8 @@ def custom_function(
86
97
 
87
98
  def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
88
99
  call_checker = checker or DefaultCallChecker()
100
+ if signature is not None:
101
+ object.__setattr__(signature, "unitary_flags", unitary_flags)
89
102
  func = RawCustomFunctionDef(
90
103
  DefId.fresh(),
91
104
  name or f.__name__,
@@ -95,6 +108,7 @@ def custom_function(
95
108
  compiler or NotImplementedCallCompiler(),
96
109
  higher_order_value,
97
110
  signature,
111
+ unitary_flags,
98
112
  )
99
113
  DEF_STORE.register_def(func, get_calling_frame())
100
114
  return GuppyFunctionDefinition(func)
@@ -108,6 +122,7 @@ def hugr_op(
108
122
  higher_order_value: bool = True,
109
123
  name: str = "",
110
124
  signature: FunctionType | None = None,
125
+ unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
111
126
  ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
112
127
  """Decorator to annotate function declarations as HUGR ops.
113
128
 
@@ -119,7 +134,14 @@ def hugr_op(
119
134
  value.
120
135
  name: The name of the function.
121
136
  """
122
- return custom_function(OpCompiler(op), checker, higher_order_value, name, signature)
137
+ return custom_function(
138
+ OpCompiler(op),
139
+ checker,
140
+ higher_order_value,
141
+ name,
142
+ signature,
143
+ unitary_flags=unitary_flags,
144
+ )
123
145
 
124
146
 
125
147
  def extend_type(defn: TypeDef, return_class: bool = False) -> Callable[[type], type]:
@@ -188,6 +210,12 @@ def custom_type(
188
210
  def wasm_module(
189
211
  filename: str,
190
212
  ) -> Callable[[builtins.type[T]], GuppyDefinition]:
213
+ wasm_file = pathlib.Path(filename)
214
+ if wasm_file.is_file():
215
+ wasm_sigs = decode_wasm_functions(filename)
216
+ else:
217
+ raise GuppyError(WasmFileNotFound(None, filename))
218
+
191
219
  def type_def_wrapper(
192
220
  id: DefId,
193
221
  name: str,
@@ -198,10 +226,19 @@ def wasm_module(
198
226
  assert config is None
199
227
  return WasmModuleTypeDef(id, name, defined_at, wasm_file)
200
228
 
201
- f = ext_module_decorator(
202
- type_def_wrapper, WasmModuleInitCompiler(), WasmModuleDiscardCompiler(), True
229
+ decorator = ext_module_decorator(
230
+ type_def_wrapper,
231
+ WasmModuleInitCompiler(),
232
+ WasmModuleDiscardCompiler(),
233
+ True,
234
+ wasm_sigs,
203
235
  )
204
- return f(filename, None)
236
+
237
+ def inner_fun(ty: builtins.type[T]) -> GuppyDefinition:
238
+ decorator_inner = decorator(filename, None)
239
+ return decorator_inner(ty)
240
+
241
+ return inner_fun
205
242
 
206
243
 
207
244
  def ext_module_decorator(
@@ -209,9 +246,9 @@ def ext_module_decorator(
209
246
  init_compiler: CustomInoutCallCompiler,
210
247
  discard_compiler: CustomInoutCallCompiler,
211
248
  init_arg: bool, # Whether the init function should take a nat argument
249
+ wasm_sigs: ConcreteWasmModule
250
+ | None = None, # For @wasm_module, we must be passed a parsed wasm file
212
251
  ) -> Callable[[str, str | None], Callable[[builtins.type[T]], GuppyDefinition]]:
213
- from guppylang.defs import GuppyDefinition
214
-
215
252
  def fun(
216
253
  filename: str, module: str | None
217
254
  ) -> Callable[[builtins.type[T]], GuppyDefinition]:
@@ -231,6 +268,47 @@ def ext_module_decorator(
231
268
  for val in cls.__dict__.values():
232
269
  if isinstance(val, GuppyDefinition):
233
270
  DEF_STORE.register_impl(ext_module.id, val.wrapped.name, val.id)
271
+ wasm_def: RawWasmFunctionDef
272
+ if isinstance(val, GuppyFunctionDefinition) and isinstance(
273
+ val.wrapped, RawWasmFunctionDef
274
+ ):
275
+ wasm_def = val.wrapped
276
+ else:
277
+ continue
278
+ # wasm_sigs should only have not been provided if we have
279
+ # defined @wasm functions in a class which didn't use the
280
+ # @wasm_module decorator.
281
+ assert wasm_sigs is not None
282
+ if wasm_def.wasm_index is not None:
283
+ name = wasm_sigs.functions[wasm_def.wasm_index]
284
+ assert name in wasm_sigs.function_sigs
285
+ wasm_sig_or_err = wasm_sigs.function_sigs[name]
286
+ else:
287
+ if wasm_def.name in wasm_sigs.function_sigs:
288
+ wasm_sig_or_err = wasm_sigs.function_sigs[wasm_def.name]
289
+ else:
290
+ raise GuppyError(
291
+ WasmFunctionNotInFile(
292
+ wasm_def.defined_at,
293
+ wasm_def.name,
294
+ ).add_sub_diagnostic(
295
+ WasmFunctionNotInFile.WasmFileNote(
296
+ None,
297
+ wasm_sigs.filename,
298
+ )
299
+ )
300
+ )
301
+ if isinstance(wasm_sig_or_err, FunctionType):
302
+ DEF_STORE.register_wasm_function(wasm_def.id, wasm_sig_or_err)
303
+ elif isinstance(wasm_sig_or_err, str):
304
+ raise GuppyError(
305
+ WasmSignatureError(
306
+ None, wasm_def.name, filename
307
+ ).add_sub_diagnostic(
308
+ WasmSignatureError.Message(None, wasm_sig_or_err)
309
+ )
310
+ )
311
+
234
312
  # Add a constructor to the class
235
313
  if init_arg:
236
314
  init_fn_ty = FunctionType(
@@ -315,6 +393,7 @@ def wasm_helper(fn_id: int | None, f: Callable[P, T]) -> GuppyFunctionDefinition
315
393
  WasmModuleCallCompiler(f.__name__, fn_id),
316
394
  True,
317
395
  signature=None,
396
+ wasm_index=fn_id,
318
397
  )
319
398
  DEF_STORE.register_def(func, get_calling_frame())
320
399
  return GuppyFunctionDefinition(func)
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, ClassVar
7
7
  from hugr import Wire, ops
8
8
  from hugr import tys as ht
9
9
  from hugr.build.dfg import DfBase
10
+ from hugr.std.collections.borrow_array import EXTENSION as BORROW_ARRAY_EXTENSION
10
11
 
11
12
  from guppylang_internals.ast_util import (
12
13
  AstNode,
@@ -23,6 +24,7 @@ from guppylang_internals.compiler.core import (
23
24
  DFContainer,
24
25
  GlobalConstId,
25
26
  partially_monomorphize_args,
27
+ qualified_name,
26
28
  )
27
29
  from guppylang_internals.definition.common import ParsableDef
28
30
  from guppylang_internals.definition.value import CallReturnWires, CompiledCallableDef
@@ -42,6 +44,7 @@ from guppylang_internals.tys.ty import (
42
44
  InputFlags,
43
45
  NoneType,
44
46
  Type,
47
+ UnitaryFlags,
45
48
  type_to_row,
46
49
  )
47
50
 
@@ -112,6 +115,8 @@ class RawCustomFunctionDef(ParsableDef):
112
115
 
113
116
  signature: FunctionType | None
114
117
 
118
+ unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags)
119
+
115
120
  description: str = field(default="function", init=False)
116
121
 
117
122
  def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
@@ -134,6 +139,7 @@ class RawCustomFunctionDef(ParsableDef):
134
139
  raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
135
140
  sig = self.signature or self._get_signature(func_ast, globals)
136
141
  ty = sig or FunctionType([], NoneType())
142
+ ty = ty.with_unitary_flags(self.unitary_flags)
137
143
  return CustomFunctionDef(
138
144
  self.id,
139
145
  self.name,
@@ -486,7 +492,39 @@ class NoopCompiler(CustomCallCompiler):
486
492
 
487
493
 
488
494
class CopyInoutCompiler(CustomInoutCallCompiler):
    """Call compiler for functions that borrow one argument to copy it."""

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        # Exactly one (borrowed) input is supported.
        assert len(self.ty.input) == 1
        inp_ty = self.ty.input[0]
        if inp_ty.type_bound() != ht.TypeBound.Linear:
            # Non-linear Hugr values can be fanned out to both results directly.
            return CallReturnWires(regular_returns=args, inout_returns=args)
        # Linear Hugr values need an explicit copy operation.
        (arg,) = args
        copies = self._handle_affine_type(inp_ty, arg)
        return CallReturnWires(
            regular_returns=[copies[0]],
            inout_returns=[copies[1]],
        )

    # Affine types in Guppy backed by a linear Hugr type need to be copied explicitly.
    # TODO: Handle affine extension types more generally (borrow arrays are currently
    # the only case).
    def _handle_affine_type(self, ty: ht.Type, arg: Wire) -> list[Wire]:
        """Emits an explicit copy of `arg` for the linear Hugr type `ty`.

        Only borrow arrays are supported; they are copied with the `clone` op of
        the borrow array extension. Any other type raises `InternalGuppyError`.
        """
        is_borrow_array = isinstance(ty, ht.ExtType) and qualified_name(
            ty.type_def
        ) == qualified_name(BORROW_ARRAY_EXTENSION.get_type("borrow_array"))
        if not is_borrow_array:
            raise InternalGuppyError(
                f"Type `{ty}` needs an explicit handler in the `copy` compiler as "
                "it is an affine Guppy type backed by a linear Hugr type."
            )
        type_args = ty.args
        assert len(type_args) == 2
        # Manually instantiate here to avoid circular import and use
        # type args directly.
        clone_def = BORROW_ARRAY_EXTENSION.get_op("clone")
        clone_sig = ht.FunctionType(self.ty.input, self.ty.output)
        clone_op = clone_def.instantiate(type_args, clone_sig)
        return list(self.builder.add_op(clone_op, arg))
@@ -13,7 +13,7 @@ from guppylang_internals.checker.func_checker import check_signature
13
13
  from guppylang_internals.compiler.core import (
14
14
  CompilerContext,
15
15
  DFContainer,
16
- requires_monomorphization,
16
+ require_monomorphization,
17
17
  )
18
18
  from guppylang_internals.definition.common import CompilableDef, ParsableDef
19
19
  from guppylang_internals.definition.function import (
@@ -34,7 +34,7 @@ from guppylang_internals.nodes import GlobalCall
34
34
  from guppylang_internals.span import SourceMap
35
35
  from guppylang_internals.tys.param import Parameter
36
36
  from guppylang_internals.tys.subst import Inst, Subst
37
- from guppylang_internals.tys.ty import Type
37
+ from guppylang_internals.tys.ty import Type, UnitaryFlags
38
38
 
39
39
 
40
40
  @dataclass(frozen=True)
@@ -65,16 +65,19 @@ class RawFunctionDecl(ParsableDef):
65
65
  python_func: PyFunc
66
66
  description: str = field(default="function", init=False)
67
67
 
68
+ unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True)
69
+
68
70
  def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
69
71
  """Parses and checks the user-provided signature of the function."""
70
72
  func_ast, docstring = parse_py_func(self.python_func, sources)
71
- ty = check_signature(func_ast, globals, self.id)
73
+ ty = check_signature(
74
+ func_ast, globals, self.id, unitary_flags=self.unitary_flags
75
+ )
72
76
  if not has_empty_body(func_ast):
73
77
  raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
74
78
  # Make sure we won't need monomorphization to compile this declaration
75
- for param in ty.params:
76
- if requires_monomorphization(param):
77
- raise GuppyError(MonomorphizeError(func_ast, self.name, param))
79
+ if mono_params := require_monomorphization(ty.params):
80
+ raise GuppyError(MonomorphizeError(func_ast, self.name, mono_params.pop()))
78
81
  return CheckedFunctionDecl(
79
82
  self.id,
80
83
  self.name,
@@ -43,7 +43,7 @@ from guppylang_internals.error import GuppyError
43
43
  from guppylang_internals.nodes import GlobalCall
44
44
  from guppylang_internals.span import SourceMap
45
45
  from guppylang_internals.tys.subst import Inst, Subst
46
- from guppylang_internals.tys.ty import FunctionType, Type, type_to_row
46
+ from guppylang_internals.tys.ty import FunctionType, Type, UnitaryFlags, type_to_row
47
47
 
48
48
  if TYPE_CHECKING:
49
49
  from guppylang_internals.tys.param import Parameter
@@ -70,10 +70,14 @@ class RawFunctionDef(ParsableDef):
70
70
 
71
71
  description: str = field(default="function", init=False)
72
72
 
73
+ unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True)
74
+
73
75
  def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
74
76
  """Parses and checks the user-provided signature of the function."""
75
77
  func_ast, docstring = parse_py_func(self.python_func, sources)
76
- ty = check_signature(func_ast, globals, self.id)
78
+ ty = check_signature(
79
+ func_ast, globals, self.id, unitary_flags=self.unitary_flags
80
+ )
77
81
  return ParsedFunctionDef(self.id, self.name, func_ast, ty, docstring)
78
82
 
79
83
 
@@ -173,6 +177,7 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
173
177
  func_def = module.module_root_builder().define_function(
174
178
  self.name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params
175
179
  )
180
+ add_unitarity_metadata(func_def, self.ty.unitary_flags)
176
181
  return CompiledFunctionDef(
177
182
  self.id,
178
183
  self.name,
@@ -300,3 +305,8 @@ def parse_source(source_lines: list[str], line_offset: int) -> tuple[str, ast.AS
300
305
  else:
301
306
  node = ast.parse(source).body[0]
302
307
  return source, node, line_offset
308
+
309
+
310
def add_unitarity_metadata(func: hf.Function, flags: UnitaryFlags) -> None:
    """Records unitarity annotations in the metadata of a Hugr function definition.

    The raw `value` of `flags` is stored under the `"unitary"` metadata key of
    `func`.
    """
    metadata = func.metadata
    metadata["unitary"] = flags.value
@@ -38,14 +38,19 @@ class ParamDef(Definition):
38
38
class TypeVarDef(ParamDef, CompiledDef):
    """Definition of a type variable."""

    # Forwarded as the `must_be_copyable` flag of the created `TypeParam`.
    copyable: bool
    # Forwarded as the `must_be_droppable` flag of the created `TypeParam`.
    droppable: bool

    description: str = field(default="type variable", init=False)

    def to_param(self, idx: int) -> TypeParam:
        """Turns this definition into a type parameter at position `idx`."""
        param_name = self.name
        return TypeParam(
            idx,
            param_name,
            must_be_droppable=self.droppable,
            must_be_copyable=self.copyable,
        )
49
54
 
50
55
 
51
56
  @dataclass(frozen=True)