guppylang-internals 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +21 -0
  3. guppylang_internals/cfg/bb.py +20 -0
  4. guppylang_internals/cfg/builder.py +101 -3
  5. guppylang_internals/checker/core.py +12 -0
  6. guppylang_internals/checker/errors/generic.py +32 -1
  7. guppylang_internals/checker/errors/type_errors.py +14 -0
  8. guppylang_internals/checker/expr_checker.py +55 -29
  9. guppylang_internals/checker/func_checker.py +171 -22
  10. guppylang_internals/checker/linearity_checker.py +65 -0
  11. guppylang_internals/checker/modifier_checker.py +116 -0
  12. guppylang_internals/checker/stmt_checker.py +49 -2
  13. guppylang_internals/compiler/core.py +90 -53
  14. guppylang_internals/compiler/expr_compiler.py +49 -114
  15. guppylang_internals/compiler/modifier_compiler.py +174 -0
  16. guppylang_internals/compiler/stmt_compiler.py +15 -8
  17. guppylang_internals/decorator.py +124 -58
  18. guppylang_internals/definition/const.py +2 -2
  19. guppylang_internals/definition/custom.py +36 -2
  20. guppylang_internals/definition/declaration.py +4 -5
  21. guppylang_internals/definition/extern.py +2 -2
  22. guppylang_internals/definition/function.py +1 -1
  23. guppylang_internals/definition/parameter.py +10 -5
  24. guppylang_internals/definition/pytket_circuits.py +14 -42
  25. guppylang_internals/definition/struct.py +17 -14
  26. guppylang_internals/definition/traced.py +1 -1
  27. guppylang_internals/definition/ty.py +9 -3
  28. guppylang_internals/definition/wasm.py +2 -2
  29. guppylang_internals/engine.py +13 -2
  30. guppylang_internals/experimental.py +5 -0
  31. guppylang_internals/nodes.py +124 -23
  32. guppylang_internals/std/_internal/compiler/array.py +94 -282
  33. guppylang_internals/std/_internal/compiler/tket_exts.py +12 -8
  34. guppylang_internals/std/_internal/compiler/wasm.py +37 -26
  35. guppylang_internals/tracing/function.py +13 -2
  36. guppylang_internals/tracing/unpacking.py +33 -28
  37. guppylang_internals/tys/arg.py +18 -3
  38. guppylang_internals/tys/builtin.py +32 -16
  39. guppylang_internals/tys/const.py +33 -4
  40. guppylang_internals/tys/errors.py +6 -0
  41. guppylang_internals/tys/param.py +31 -16
  42. guppylang_internals/tys/parsing.py +118 -145
  43. guppylang_internals/tys/qubit.py +27 -0
  44. guppylang_internals/tys/subst.py +8 -26
  45. guppylang_internals/tys/ty.py +31 -21
  46. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/METADATA +4 -4
  47. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/RECORD +49 -46
  48. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/WHEEL +0 -0
  49. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/licenses/LICENCE +0 -0
@@ -0,0 +1,174 @@
1
+ """Hugr generation for modifiers."""
2
+
3
+ from hugr import Wire, ops
4
+ from hugr import tys as ht
5
+
6
+ from guppylang_internals.ast_util import get_type
7
+ from guppylang_internals.checker.modifier_checker import non_copyable_front_others_back
8
+ from guppylang_internals.compiler.cfg_compiler import compile_cfg
9
+ from guppylang_internals.compiler.core import CompilerContext, DFContainer
10
+ from guppylang_internals.compiler.expr_compiler import ExprCompiler
11
+ from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode
12
+ from guppylang_internals.std._internal.compiler.array import (
13
+ array_new,
14
+ array_to_std_array,
15
+ standard_array_type,
16
+ std_array_to_array,
17
+ unpack_array,
18
+ )
19
+ from guppylang_internals.std._internal.compiler.tket_exts import MODIFIER_EXTENSION
20
+ from guppylang_internals.tys.builtin import int_type, is_array_type
21
+ from guppylang_internals.tys.ty import InputFlags
22
+
23
+
def compile_modified_block(
    modified_block: CheckedModifiedBlock,
    dfg: DFContainer,
    ctx: CompilerContext,
    expr_compiler: ExprCompiler,
) -> Wire:
    """Compile a block wrapped in `dagger`, `power` and `control` modifiers.

    The block body is compiled into a new function defined at module level. A
    handle to that function is loaded into `dfg` and transformed by one modifier
    extension op per modifier (dagger first, then every power, then every
    control). The transformed function is invoked with `CallIndirect`. The
    control qubits and the borrowed (`@inout`) captured variables are then
    rebound in `dfg` to the call's outputs.

    Args:
        modified_block: The checked modified block to compile.
        dfg: Dataflow container the call is emitted into. Its variable bindings
            are updated with the call outputs.
        ctx: The compiler context.
        expr_compiler: Compiler for the power exponents and control expressions.

    Returns:
        The `CallIndirect` node that runs the modified function.
    """
    dagger_op_def = MODIFIER_EXTENSION.get_op("DaggerModifier")
    control_op_def = MODIFIER_EXTENSION.get_op("ControlModifier")
    power_op_def = MODIFIER_EXTENSION.get_op("PowerModifier")

    body_ty = modified_block.ty
    # TODO: Shouldn't this be `to_hugr_poly` since it can contain
    # a variable with a generic type?
    hugr_ty = body_ty.to_hugr(ctx)
    # Every modifier op is parametrised by the runtime (non-comptime) inputs of
    # the function it wraps, split into borrowed inputs and all other inputs.
    in_out_ht = [
        fn_inp.ty.to_hugr(ctx)
        for fn_inp in body_ty.inputs
        if InputFlags.Inout in fn_inp.flags and InputFlags.Comptime not in fn_inp.flags
    ]
    other_in_ht = [
        fn_inp.ty.to_hugr(ctx)
        for fn_inp in body_ty.inputs
        if InputFlags.Inout not in fn_inp.flags
        and InputFlags.Comptime not in fn_inp.flags
    ]
    in_out_arg = ht.ListArg([t.type_arg() for t in in_out_ht])
    other_in_arg = ht.ListArg([t.type_arg() for t in other_in_ht])

    # Compile the body into a standalone function at module level.
    func_builder = dfg.builder.module_root_builder().define_function(
        str(modified_block), hugr_ty.input, hugr_ty.output
    )
    cfg = compile_cfg(modified_block.cfg, func_builder, func_builder.inputs(), ctx)
    func_builder.set_outputs(*cfg)

    # Load a handle to the body function. Each modifier op below replaces it
    # with a handle to the modified function.
    call = dfg.builder.load_function(func_builder, hugr_ty)

    # Captured variables are passed in `non_copyable_front_others_back` order.
    captured = non_copyable_front_others_back(
        [v for v, _ in modified_block.captured.values()]
    )
    args = [dfg[v] for v in captured]

    if modified_block.has_dagger():
        dagger_ty = ht.FunctionType([hugr_ty], [hugr_ty])
        call = dfg.builder.add_op(
            ops.ExtOp(dagger_op_def, dagger_ty, [in_out_arg, other_in_arg]), call
        )
    if modified_block.has_power():
        power_ty = ht.FunctionType([hugr_ty, int_type().to_hugr(ctx)], [hugr_ty])
        for power in modified_block.power:
            num = expr_compiler.compile(power.iter, dfg)
            call = dfg.builder.add_op(
                ops.ExtOp(power_op_def, power_ty, [in_out_arg, other_in_arg]),
                call,
                num,
            )

    # Size of each control register as a Hugr type argument, indexed like
    # `modified_block.control`.
    qubit_num_args: list[ht.TypeArg] = []
    if modified_block.has_control():
        for control in modified_block.control:
            assert control.qubit_num is not None
            if isinstance(control.qubit_num, int):
                qubit_num_args.append(ht.BoundedNatArg(control.qubit_num))
            else:
                qubit_num_args.append(control.qubit_num.to_arg().to_hugr(ctx))

        # Each ControlModifier prepends its control array to the inputs and
        # outputs of the function it wraps, so the op applied *last* owns the
        # *first* slot. Wrap in reverse so that slot `i` belongs to
        # `modified_block.control[i]`. That is the order used below both to pass
        # the control arrays in and to read them back out; wrapping in forward
        # order mismatches the call's types when control sizes differ.
        for qubit_num in reversed(qubit_num_args):
            std_array = standard_array_type(ht.Qubit, qubit_num)
            output_fn_ty = ht.FunctionType(
                [std_array, *hugr_ty.input], [std_array, *hugr_ty.output]
            )
            op = ops.ExtOp(
                control_op_def,
                ht.FunctionType([hugr_ty], [output_fn_ty]),
                [qubit_num, in_out_arg, other_in_arg],
            )
            call = dfg.builder.add_op(op, call)
            # The control array is itself borrowed by the modified function.
            in_out_arg = ht.ListArg([std_array.type_arg(), *in_out_arg.elems])
            hugr_ty = output_fn_ty

    # Build one standard qubit array per control modifier.
    ctrl_args: list[Wire] = []
    for control, qubit_num in zip(modified_block.control, qubit_num_args, strict=True):
        control_array: Wire
        if is_array_type(get_type(control.ctrl[0])):
            # A single array-valued control expression.
            control_array = expr_compiler.compile(control.ctrl[0], dfg)
        else:
            # Individual qubits: pack them into an array first.
            qubits = [expr_compiler.compile(c, dfg) for c in control.ctrl]
            control_array = dfg.builder.add_op(
                array_new(ht.Qubit, len(control.ctrl)), *qubits
            )
        ctrl_args.append(
            dfg.builder.add_op(array_to_std_array(ht.Qubit, qubit_num), control_array)
        )

    call = dfg.builder.add_op(ops.CallIndirect(), call, *ctrl_args, *args)
    outports = iter(call)

    # The control arrays come back first, in the order they were passed in.
    for control, qubit_num in zip(modified_block.control, qubit_num_args, strict=True):
        returned = dfg.builder.add_op(
            std_array_to_array(ht.Qubit, qubit_num), next(outports)
        )
        if is_array_type(get_type(control.ctrl[0])):
            c = control.ctrl[0]
            assert isinstance(c, PlaceNode)
            dfg[c.place] = returned
        else:
            unpacked = unpack_array(dfg.builder, returned)
            for c, new_c in zip(control.ctrl, unpacked, strict=False):
                assert isinstance(c, PlaceNode)
                dfg[c.place] = new_c

    # They are followed by the borrowed captured variables.
    for arg in captured:
        if InputFlags.Inout in arg.flags:
            dfg[arg] = next(outports)

    return call
@@ -2,7 +2,6 @@ import ast
2
2
  import functools
3
3
  from collections.abc import Sequence
4
4
 
5
- import hugr.tys as ht
6
5
  from hugr import Wire, ops
7
6
  from hugr.build.dfg import DfBase
8
7
 
@@ -18,6 +17,7 @@ from guppylang_internals.compiler.expr_compiler import ExprCompiler
18
17
  from guppylang_internals.error import InternalGuppyError
19
18
  from guppylang_internals.nodes import (
20
19
  ArrayUnpack,
20
+ CheckedModifiedBlock,
21
21
  CheckedNestedFunctionDef,
22
22
  IterableUnpack,
23
23
  PlaceNode,
@@ -120,8 +120,9 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
120
120
  ports[len(left) : -len(right)] if right else ports[len(left) :]
121
121
  )
122
122
  elt = get_element_type(array_ty).to_hugr(self.ctx)
123
- opts = [self.builder.add_op(ops.Some(elt), p) for p in starred_ports]
124
- array = self.builder.add_op(array_new(ht.Option(elt), len(opts)), *opts)
123
+ array = self.builder.add_op(
124
+ array_new(elt, len(starred_ports)), *starred_ports
125
+ )
125
126
  self._assign(starred, array)
126
127
 
127
128
  @_assign.register
@@ -130,7 +131,7 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
130
131
  # Given an assignment pattern `left, *starred, right`, pop from the left and
131
132
  # right, leaving us with the starred array in the middle
132
133
  length = lhs.length
133
- opt_elt_ty = ht.Option(lhs.elt_type.to_hugr(self.ctx))
134
+ elt_ty = lhs.elt_type.to_hugr(self.ctx)
134
135
 
135
136
  def pop(
136
137
  array: Wire, length: int, pats: list[ast.expr], from_left: bool
@@ -141,10 +142,9 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
141
142
  elts = []
142
143
  for i in range(num_pats):
143
144
  res = self.builder.add_op(
144
- array_pop(opt_elt_ty, length - i, from_left), array
145
+ array_pop(elt_ty, length - i, from_left), array
145
146
  )
146
- [elt_opt, array] = build_unwrap(self.builder, res, err)
147
- [elt] = build_unwrap(self.builder, elt_opt, err)
147
+ [elt, array] = build_unwrap(self.builder, res, err)
148
148
  elts.append(elt)
149
149
  # Assign elements to the given patterns
150
150
  for pat, elt in zip(
@@ -164,7 +164,7 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
164
164
  self._assign(lhs.pattern.starred, array)
165
165
  else:
166
166
  assert length == 0
167
- self.builder.add_op(array_discard_empty(opt_elt_ty), array)
167
+ self.builder.add_op(array_discard_empty(elt_ty), array)
168
168
 
169
169
  @_assign.register
170
170
  def _assign_iterable(self, lhs: IterableUnpack, port: Wire) -> None:
@@ -221,3 +221,10 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
221
221
  var = Variable(node.name, node.ty, node)
222
222
  loaded_func = compile_local_func_def(node, self.dfg, self.ctx)
223
223
  self.dfg[var] = loaded_func
224
+
def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None:
    """Emit Hugr for a statement block wrapped in quantum modifiers."""
    # Imported lazily rather than at module level. This is presumably to avoid an
    # import cycle with the modifier compiler — confirm before hoisting.
    from guppylang_internals.compiler import modifier_compiler

    modifier_compiler.compile_modified_block(
        node, self.dfg, self.ctx, self.expr_compiler
    )
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- from typing import TYPE_CHECKING, ParamSpec, TypeVar
4
+ from typing import TYPE_CHECKING, ParamSpec, TypeVar, overload
5
5
 
6
6
  from hugr import ops
7
7
  from hugr import tys as ht
@@ -42,6 +42,7 @@ from guppylang_internals.tys.ty import (
42
42
  )
43
43
 
44
44
  if TYPE_CHECKING:
45
+ import ast
45
46
  import builtins
46
47
  from collections.abc import Callable, Sequence
47
48
  from types import FrameType
@@ -121,15 +122,19 @@ def hugr_op(
121
122
  return custom_function(OpCompiler(op), checker, higher_order_value, name, signature)
122
123
 
123
124
 
def extend_type(defn: TypeDef, return_class: bool = False) -> Callable[[type], type]:
    """Return a class decorator that attaches a class's Guppy members to `defn`.

    Every `GuppyDefinition` found in the decorated class's namespace is
    registered as an implementation on `defn`. The decorator then evaluates to
    a `GuppyDefinition` referring to `defn`. If `return_class=True`, it instead
    evaluates to the decorated class itself, unchanged.
    """
    from guppylang.defs import GuppyDefinition

    def dec(cls: type) -> type:
        for member in cls.__dict__.values():
            if not isinstance(member, GuppyDefinition):
                continue
            DEF_STORE.register_impl(defn.id, member.wrapped.name, member.id)
        if return_class:
            return cls
        return GuppyDefinition(defn)  # type: ignore[return-value]

    return dec
135
140
 
@@ -181,63 +186,124 @@ def custom_type(
181
186
 
182
187
 
def wasm_module(
    filename: str,
) -> Callable[[builtins.type[T]], GuppyDefinition]:
    """Class decorator declaring a Guppy type for the Wasm module in `filename`.

    This is `ext_module_decorator` specialised with the Wasm init/discard
    compilers. Its `init_arg` flag is set, so the generated constructor takes a
    `nat` argument.
    """

    def make_wasm_type_def(
        def_id: DefId,
        type_name: str,
        defined_at: ast.AST | None,
        wasm_file: str,
        config: str | None,
    ) -> OpaqueTypeDef:
        # Wasm modules carry no extra configuration string.
        assert config is None
        return WasmModuleTypeDef(def_id, type_name, defined_at, wasm_file)

    decorator_factory = ext_module_decorator(
        make_wasm_type_def,
        WasmModuleInitCompiler(),
        WasmModuleDiscardCompiler(),
        True,
    )
    return decorator_factory(filename, None)
238
206
 
def ext_module_decorator(
    type_def: Callable[[DefId, str, ast.AST | None, str, str | None], OpaqueTypeDef],
    init_compiler: CustomInoutCallCompiler,
    discard_compiler: CustomInoutCallCompiler,
    init_arg: bool,  # Whether the init function should take a nat argument
) -> Callable[[str, str | None], Callable[[builtins.type[T]], GuppyDefinition]]:
    """Build a decorator factory for classes that describe an external module.

    Calling the returned factory with `(filename, module)` yields a class
    decorator. Decorating a class does the following:

    * creates an opaque type definition via `type_def`;
    * registers it, and registers the class's `GuppyDefinition` members as its
      implementations;
    * adds a `__new__` constructor compiled with `init_compiler`. It takes one
      owned `nat` if `init_arg` is set and no arguments otherwise;
    * adds a `discard` method consuming an owned instance, compiled with
      `discard_compiler`.

    The decorator evaluates to a `GuppyDefinition` for the new type.
    """
    from guppylang.defs import GuppyDefinition

    def factory(
        filename: str, module: str | None
    ) -> Callable[[builtins.type[T]], GuppyDefinition]:
        def decorate(cls: builtins.type[T]) -> GuppyDefinition:
            # N.B. Only one module per file and vice-versa
            ext_module = type_def(DefId.fresh(), cls.__name__, None, filename, module)
            ext_module_ty = ext_module.check_instantiate([], None)

            DEF_STORE.register_def(ext_module, get_calling_frame())
            for member in cls.__dict__.values():
                if not isinstance(member, GuppyDefinition):
                    continue
                DEF_STORE.register_impl(ext_module.id, member.wrapped.name, member.id)

            # Add a constructor to the class
            init_inputs = (
                [
                    FuncInput(
                        NumericType(NumericType.Kind.Nat),
                        flags=InputFlags.Owned,
                    )
                ]
                if init_arg
                else []
            )
            init_fn_ty = FunctionType(init_inputs, ext_module_ty)

            constructor = CustomFunctionDef(
                DefId.fresh(),
                "__new__",
                None,
                init_fn_ty,
                DefaultCallChecker(),
                init_compiler,
                True,
                GlobalConstId.fresh(f"{cls.__name__}.__new__"),
                True,
            )
            discard = CustomFunctionDef(
                DefId.fresh(),
                "discard",
                None,
                FunctionType([FuncInput(ext_module_ty, InputFlags.Owned)], NoneType()),
                DefaultCallChecker(),
                discard_compiler,
                False,
                GlobalConstId.fresh(f"{cls.__name__}.__discard__"),
                True,
            )
            for method_name, method in (("__new__", constructor), ("discard", discard)):
                DEF_STORE.register_def(method, get_calling_frame())
                DEF_STORE.register_impl(ext_module.id, method_name, method.id)

            return GuppyDefinition(ext_module)

        return decorate

    return factory
280
+
281
+
@overload
def wasm(arg: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: ...


@overload
def wasm(arg: int) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]: ...


def wasm(
    arg: int | Callable[P, T],
) -> (
    GuppyFunctionDefinition[P, T]
    | Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]
):
    """Declare a Wasm function, used either bare or with an explicit id.

    `@wasm` applied directly to a function declares it with no function id.
    `@wasm(n)` returns a decorator that declares the function with id `n`.
    """
    if not isinstance(arg, int):
        return wasm_helper(None, arg)

    fn_id = arg

    def with_id(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
        return wasm_helper(fn_id, f)

    return with_id
304
+
305
+
306
+ def wasm_helper(fn_id: int | None, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
241
307
  from guppylang.defs import GuppyFunctionDefinition
242
308
 
243
309
  func = RawWasmFunctionDef(
@@ -246,7 +312,7 @@ def wasm(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
246
312
  None,
247
313
  f,
248
314
  WasmCallChecker(),
249
- WasmModuleCallCompiler(f.__name__),
315
+ WasmModuleCallCompiler(f.__name__, fn_id),
250
316
  True,
251
317
  signature=None,
252
318
  )
@@ -15,7 +15,7 @@ from guppylang_internals.definition.value import (
15
15
  ValueDef,
16
16
  )
17
17
  from guppylang_internals.span import SourceMap
18
- from guppylang_internals.tys.parsing import type_from_ast
18
+ from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast
19
19
 
20
20
 
21
21
  @dataclass(frozen=True)
@@ -33,7 +33,7 @@ class RawConstDef(ParsableDef):
33
33
  self.id,
34
34
  self.name,
35
35
  self.defined_at,
36
- type_from_ast(self.type_ast, globals, {}),
36
+ type_from_ast(self.type_ast, TypeParsingCtx(globals)),
37
37
  self.type_ast,
38
38
  self.value,
39
39
  )
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, ClassVar
7
7
  from hugr import Wire, ops
8
8
  from hugr import tys as ht
9
9
  from hugr.build.dfg import DfBase
10
+ from hugr.std.collections.borrow_array import EXTENSION as BORROW_ARRAY_EXTENSION
10
11
 
11
12
  from guppylang_internals.ast_util import (
12
13
  AstNode,
@@ -23,6 +24,7 @@ from guppylang_internals.compiler.core import (
23
24
  DFContainer,
24
25
  GlobalConstId,
25
26
  partially_monomorphize_args,
27
+ qualified_name,
26
28
  )
27
29
  from guppylang_internals.definition.common import ParsableDef
28
30
  from guppylang_internals.definition.value import CallReturnWires, CompiledCallableDef
@@ -169,7 +171,7 @@ class RawCustomFunctionDef(ParsableDef):
169
171
  raise GuppyError(NoSignatureError(node, self.name))
170
172
 
171
173
  if requires_type_annotation:
172
- return check_signature(node, globals)
174
+ return check_signature(node, globals, self.id)
173
175
  else:
174
176
  return None
175
177
 
@@ -486,7 +488,39 @@ class NoopCompiler(CustomCallCompiler):
486
488
 
487
489
 
class CopyInoutCompiler(CustomInoutCallCompiler):
    """Call compiler for functions that borrow one argument to copy it.

    The borrowed value is handed back as the inout result and also returned as
    the regular result. If its Hugr type is linear, the same wire cannot feed
    both results, so an explicit duplicating op is emitted instead.
    """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        assert len(self.ty.input) == 1
        inp_ty = self.ty.input[0]
        if inp_ty.type_bound() != ht.TypeBound.Linear:
            # Copyable in Hugr: reuse the incoming wire for both results.
            return CallReturnWires(regular_returns=args, inout_returns=args)
        (arg,) = args
        copies = self._handle_affine_type(inp_ty, arg)
        return CallReturnWires(regular_returns=[copies[0]], inout_returns=[copies[1]])

    def _handle_affine_type(self, ty: ht.Type, arg: Wire) -> list[Wire]:
        """Explicitly copy `arg`, an affine Guppy value backed by a linear Hugr type.

        Raises:
            InternalGuppyError: If `ty` has no dedicated copy handler.
        """
        # TODO: Handle affine extension types more generally (borrow arrays are
        # currently the only case).
        if isinstance(ty, ht.ExtType):
            borrow_array_def = BORROW_ARRAY_EXTENSION.get_type("borrow_array")
            if qualified_name(ty.type_def) == qualified_name(borrow_array_def):
                assert len(ty.args) == 2
                # Instantiated by hand to avoid a circular import and to use the
                # type args directly.
                clone_def = BORROW_ARRAY_EXTENSION.get_op("clone")
                clone_op = clone_def.instantiate(
                    ty.args, ht.FunctionType(self.ty.input, self.ty.output)
                )
                return list(self.builder.add_op(clone_op, arg))
        raise InternalGuppyError(
            f"Type `{ty}` needs an explicit handler in the `copy` compiler as "
            "it is an affine Guppy type backed by a linear Hugr type."
        )
@@ -13,7 +13,7 @@ from guppylang_internals.checker.func_checker import check_signature
13
13
  from guppylang_internals.compiler.core import (
14
14
  CompilerContext,
15
15
  DFContainer,
16
- requires_monomorphization,
16
+ require_monomorphization,
17
17
  )
18
18
  from guppylang_internals.definition.common import CompilableDef, ParsableDef
19
19
  from guppylang_internals.definition.function import (
@@ -68,13 +68,12 @@ class RawFunctionDecl(ParsableDef):
68
68
  def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
69
69
  """Parses and checks the user-provided signature of the function."""
70
70
  func_ast, docstring = parse_py_func(self.python_func, sources)
71
- ty = check_signature(func_ast, globals)
71
+ ty = check_signature(func_ast, globals, self.id)
72
72
  if not has_empty_body(func_ast):
73
73
  raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
74
74
  # Make sure we won't need monomorphization to compile this declaration
75
- for param in ty.params:
76
- if requires_monomorphization(param):
77
- raise GuppyError(MonomorphizeError(func_ast, self.name, param))
75
+ if mono_params := require_monomorphization(ty.params):
76
+ raise GuppyError(MonomorphizeError(func_ast, self.name, mono_params.pop()))
78
77
  return CheckedFunctionDecl(
79
78
  self.id,
80
79
  self.name,
@@ -14,7 +14,7 @@ from guppylang_internals.definition.value import (
14
14
  ValueDef,
15
15
  )
16
16
  from guppylang_internals.span import SourceMap
17
- from guppylang_internals.tys.parsing import type_from_ast
17
+ from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast
18
18
 
19
19
 
20
20
  @dataclass(frozen=True)
@@ -33,7 +33,7 @@ class RawExternDef(ParsableDef):
33
33
  self.id,
34
34
  self.name,
35
35
  self.defined_at,
36
- type_from_ast(self.type_ast, globals, {}),
36
+ type_from_ast(self.type_ast, TypeParsingCtx(globals)),
37
37
  self.symbol,
38
38
  self.constant,
39
39
  self.type_ast,
@@ -73,7 +73,7 @@ class RawFunctionDef(ParsableDef):
def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
    """Parse the Python source of the function and check its signature."""
    func_ast, docstring = parse_py_func(self.python_func, sources)
    signature = check_signature(func_ast, globals, self.id)
    return ParsedFunctionDef(
        self.id,
        self.name,
        func_ast,
        signature,
        docstring,
    )
78
78
 
79
79
 
@@ -38,14 +38,19 @@ class ParamDef(Definition):
class TypeVarDef(ParamDef, CompiledDef):
    """A type variable definition.

    `copyable` and `droppable` are forwarded as the `must_be_copyable` and
    `must_be_droppable` flags of the `TypeParam` built by `to_param`.
    """

    # Forwarded as `TypeParam.must_be_copyable`.
    copyable: bool
    # Forwarded as `TypeParam.must_be_droppable`.
    droppable: bool

    description: str = field(default="type variable", init=False)

    def to_param(self, idx: int) -> TypeParam:
        """Build the `TypeParam` for this variable at position `idx`."""
        constraints = {
            "must_be_copyable": self.copyable,
            "must_be_droppable": self.droppable,
        }
        return TypeParam(idx, self.name, **constraints)
49
54
 
50
55
 
51
56
  @dataclass(frozen=True)
@@ -56,9 +61,9 @@ class RawConstVarDef(ParamDef, ParsableDef):
56
61
  description: str = field(default="const variable", init=False)
57
62
 
def parse(self, globals: Globals, sources: SourceMap) -> "ConstVarDef":
    """Parse the annotated type of this const variable and validate it.

    Raises:
        GuppyError: If the parsed type is not both copyable and droppable.
    """
    from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast

    parsed_ty = type_from_ast(self.type_ast, TypeParsingCtx(globals))
    if not (parsed_ty.copyable and parsed_ty.droppable):
        raise GuppyError(LinearConstVarError(self.type_ast, self.name, parsed_ty))
    return ConstVarDef(self.id, self.name, self.defined_at, parsed_ty)