guppylang-internals 0.24.0__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39):
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +21 -0
  3. guppylang_internals/cfg/bb.py +20 -0
  4. guppylang_internals/cfg/builder.py +101 -3
  5. guppylang_internals/checker/core.py +4 -0
  6. guppylang_internals/checker/errors/generic.py +32 -1
  7. guppylang_internals/checker/errors/type_errors.py +14 -0
  8. guppylang_internals/checker/expr_checker.py +46 -10
  9. guppylang_internals/checker/func_checker.py +1 -1
  10. guppylang_internals/checker/linearity_checker.py +65 -0
  11. guppylang_internals/checker/modifier_checker.py +116 -0
  12. guppylang_internals/checker/stmt_checker.py +48 -1
  13. guppylang_internals/compiler/core.py +90 -53
  14. guppylang_internals/compiler/expr_compiler.py +49 -114
  15. guppylang_internals/compiler/modifier_compiler.py +174 -0
  16. guppylang_internals/compiler/stmt_compiler.py +15 -8
  17. guppylang_internals/definition/custom.py +35 -1
  18. guppylang_internals/definition/declaration.py +3 -4
  19. guppylang_internals/definition/parameter.py +8 -3
  20. guppylang_internals/definition/pytket_circuits.py +13 -41
  21. guppylang_internals/definition/struct.py +7 -4
  22. guppylang_internals/definition/ty.py +3 -3
  23. guppylang_internals/experimental.py +5 -0
  24. guppylang_internals/nodes.py +124 -0
  25. guppylang_internals/std/_internal/compiler/array.py +94 -282
  26. guppylang_internals/std/_internal/compiler/tket_exts.py +9 -2
  27. guppylang_internals/tracing/unpacking.py +19 -20
  28. guppylang_internals/tys/arg.py +18 -3
  29. guppylang_internals/tys/builtin.py +2 -5
  30. guppylang_internals/tys/const.py +33 -4
  31. guppylang_internals/tys/param.py +31 -16
  32. guppylang_internals/tys/parsing.py +8 -21
  33. guppylang_internals/tys/qubit.py +27 -0
  34. guppylang_internals/tys/subst.py +8 -26
  35. guppylang_internals/tys/ty.py +31 -21
  36. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.25.0.dist-info}/METADATA +3 -3
  37. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.25.0.dist-info}/RECORD +39 -36
  38. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.25.0.dist-info}/WHEEL +0 -0
  39. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.25.0.dist-info}/licenses/LICENCE +0 -0
@@ -0,0 +1,174 @@
1
+ """Hugr generation for modifiers."""
2
+
3
+ from hugr import Wire, ops
4
+ from hugr import tys as ht
5
+
6
+ from guppylang_internals.ast_util import get_type
7
+ from guppylang_internals.checker.modifier_checker import non_copyable_front_others_back
8
+ from guppylang_internals.compiler.cfg_compiler import compile_cfg
9
+ from guppylang_internals.compiler.core import CompilerContext, DFContainer
10
+ from guppylang_internals.compiler.expr_compiler import ExprCompiler
11
+ from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode
12
+ from guppylang_internals.std._internal.compiler.array import (
13
+ array_new,
14
+ array_to_std_array,
15
+ standard_array_type,
16
+ std_array_to_array,
17
+ unpack_array,
18
+ )
19
+ from guppylang_internals.std._internal.compiler.tket_exts import MODIFIER_EXTENSION
20
+ from guppylang_internals.tys.builtin import int_type, is_array_type
21
+ from guppylang_internals.tys.ty import InputFlags
22
+
23
+
24
def compile_modified_block(
    modified_block: CheckedModifiedBlock,
    dfg: DFContainer,
    ctx: CompilerContext,
    expr_compiler: ExprCompiler,
) -> Wire:
    """Compiles a ``with`` block annotated with dagger/power/control modifiers.

    The body of the block is emitted as a separate function definition at the
    module root and loaded as a function value. That value is wrapped by the
    ``DaggerModifier``, ``PowerModifier`` and ``ControlModifier`` operations of the
    tket modifier extension (in this order) and finally invoked via
    ``CallIndirect``, passing the control arrays followed by the captured
    variables.

    After the call, the places used as control operands and all captured inout
    variables are rebound in ``dfg`` to the corresponding outputs of the call.

    Args:
        modified_block: The type-checked modified block.
        dfg: Dataflow container into which the call is emitted.
        ctx: Compiler context used for lowering types and compiling the body CFG.
        expr_compiler: Compiler for the power exponents and the control operands.

    Returns:
        The ``CallIndirect`` node that invokes the modified function.
    """
    dagger_op_def = MODIFIER_EXTENSION.get_op("DaggerModifier")
    control_op_def = MODIFIER_EXTENSION.get_op("ControlModifier")
    power_op_def = MODIFIER_EXTENSION.get_op("PowerModifier")

    body_ty = modified_block.ty
    # TODO: Shouldn't this be `to_hugr_poly` since it can contain
    # a variable with a generic type?
    hugr_ty = body_ty.to_hugr(ctx)
    # Type arguments of the modifier ops: the body inputs (excluding comptime
    # inputs), split into inout inputs and all other inputs.
    in_out_ht = [
        fn_inp.ty.to_hugr(ctx)
        for fn_inp in body_ty.inputs
        if InputFlags.Inout in fn_inp.flags and InputFlags.Comptime not in fn_inp.flags
    ]
    other_in_ht = [
        fn_inp.ty.to_hugr(ctx)
        for fn_inp in body_ty.inputs
        if InputFlags.Inout not in fn_inp.flags
        and InputFlags.Comptime not in fn_inp.flags
    ]
    in_out_arg = ht.ListArg([t.type_arg() for t in in_out_ht])
    other_in_arg = ht.ListArg([t.type_arg() for t in other_in_ht])

    # Compile the body into its own function and load it as a function value
    func_builder = dfg.builder.module_root_builder().define_function(
        str(modified_block), hugr_ty.input, hugr_ty.output
    )
    cfg = compile_cfg(modified_block.cfg, func_builder, func_builder.inputs(), ctx)
    func_builder.set_outputs(*cfg)
    call = dfg.builder.load_function(func_builder, hugr_ty)

    # Captured variables become the function inputs, ordered by
    # `non_copyable_front_others_back`
    captured = [v for v, _ in modified_block.captured.values()]
    captured = non_copyable_front_others_back(captured)
    args = [dfg[v] for v in captured]

    if modified_block.has_dagger():
        dagger_ty = ht.FunctionType([hugr_ty], [hugr_ty])
        call = dfg.builder.add_op(
            ops.ExtOp(dagger_op_def, dagger_ty, [in_out_arg, other_in_arg]), call
        )

    # Each power modifier takes its exponent as an additional integer input
    if modified_block.has_power():
        power_ty = ht.FunctionType([hugr_ty, int_type().to_hugr(ctx)], [hugr_ty])
        for power in modified_block.power:
            num = expr_compiler.compile(power.iter, dfg)
            call = dfg.builder.add_op(
                ops.ExtOp(power_op_def, power_ty, [in_out_arg, other_in_arg]),
                call,
                num,
            )

    # Each control modifier *prepends* a standard array of control qubits to the
    # inputs and outputs of the wrapped function. After applying the controls
    # c_0, ..., c_{k-1} in order, the resulting function therefore expects the
    # control arrays in the order c_{k-1}, ..., c_0.
    qubit_num_args: list[ht.TypeArg] = []
    if modified_block.has_control():
        for control in modified_block.control:
            assert control.qubit_num is not None
            qubit_num: ht.TypeArg
            if isinstance(control.qubit_num, int):
                qubit_num = ht.BoundedNatArg(control.qubit_num)
            else:
                qubit_num = control.qubit_num.to_arg().to_hugr(ctx)
            qubit_num_args.append(qubit_num)
            std_array = standard_array_type(ht.Qubit, qubit_num)

            output_fn_ty = ht.FunctionType(
                [std_array, *hugr_ty.input], [std_array, *hugr_ty.output]
            )
            op = ops.ExtOp(
                control_op_def,
                ht.FunctionType([hugr_ty], [output_fn_ty]),
                [qubit_num, in_out_arg, other_in_arg],
            )
            call = dfg.builder.add_op(op, call)
            # The control array is an additional inout of the wrapped function
            in_out_arg = ht.ListArg([std_array.type_arg(), *in_out_arg.elems])
            hugr_ty = output_fn_ty

    # Build one standard array of control qubits per control modifier. Operands
    # are compiled in source order.
    ctrl_args: list[Wire] = []
    for i, control in enumerate(modified_block.control):
        if is_array_type(get_type(control.ctrl[0])):
            # A single array operand is used as the control register directly
            control_array = expr_compiler.compile(control.ctrl[0], dfg)
        else:
            # Individual qubit operands are packed into a fresh array
            cs = [expr_compiler.compile(c, dfg) for c in control.ctrl]
            control_array = dfg.builder.add_op(array_new(ht.Qubit, len(cs)), *cs)
        ctrl_args.append(
            dfg.builder.add_op(
                array_to_std_array(ht.Qubit, qubit_num_args[i]), control_array
            )
        )

    # Control arrays are passed in reverse modifier order (see above), followed
    # by the captured variables.
    call = dfg.builder.add_op(ops.CallIndirect(), call, *reversed(ctrl_args), *args)
    outports = iter(call)

    # The leading outputs are the control arrays, again in reverse modifier order.
    ctrl_outports = [next(outports) for _ in modified_block.control]
    for i, (control, outport) in enumerate(
        zip(modified_block.control, reversed(ctrl_outports), strict=True)
    ):
        control_array = dfg.builder.add_op(
            std_array_to_array(ht.Qubit, qubit_num_args[i]), outport
        )
        if is_array_type(get_type(control.ctrl[0])):
            c = control.ctrl[0]
            assert isinstance(c, PlaceNode)
            dfg[c.place] = control_array
        else:
            unpacked = unpack_array(dfg.builder, control_array)
            for c, new_c in zip(control.ctrl, unpacked, strict=False):
                assert isinstance(c, PlaceNode)
                dfg[c.place] = new_c

    # The remaining outputs are the new values of the captured inout variables
    for arg in captured:
        if InputFlags.Inout in arg.flags:
            dfg[arg] = next(outports)

    return call
@@ -2,7 +2,6 @@ import ast
2
2
  import functools
3
3
  from collections.abc import Sequence
4
4
 
5
- import hugr.tys as ht
6
5
  from hugr import Wire, ops
7
6
  from hugr.build.dfg import DfBase
8
7
 
@@ -18,6 +17,7 @@ from guppylang_internals.compiler.expr_compiler import ExprCompiler
18
17
  from guppylang_internals.error import InternalGuppyError
19
18
  from guppylang_internals.nodes import (
20
19
  ArrayUnpack,
20
+ CheckedModifiedBlock,
21
21
  CheckedNestedFunctionDef,
22
22
  IterableUnpack,
23
23
  PlaceNode,
@@ -120,8 +120,9 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
120
120
  ports[len(left) : -len(right)] if right else ports[len(left) :]
121
121
  )
122
122
  elt = get_element_type(array_ty).to_hugr(self.ctx)
123
- opts = [self.builder.add_op(ops.Some(elt), p) for p in starred_ports]
124
- array = self.builder.add_op(array_new(ht.Option(elt), len(opts)), *opts)
123
+ array = self.builder.add_op(
124
+ array_new(elt, len(starred_ports)), *starred_ports
125
+ )
125
126
  self._assign(starred, array)
126
127
 
127
128
  @_assign.register
@@ -130,7 +131,7 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
130
131
  # Given an assignment pattern `left, *starred, right`, pop from the left and
131
132
  # right, leaving us with the starred array in the middle
132
133
  length = lhs.length
133
- opt_elt_ty = ht.Option(lhs.elt_type.to_hugr(self.ctx))
134
+ elt_ty = lhs.elt_type.to_hugr(self.ctx)
134
135
 
135
136
  def pop(
136
137
  array: Wire, length: int, pats: list[ast.expr], from_left: bool
@@ -141,10 +142,9 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
141
142
  elts = []
142
143
  for i in range(num_pats):
143
144
  res = self.builder.add_op(
144
- array_pop(opt_elt_ty, length - i, from_left), array
145
+ array_pop(elt_ty, length - i, from_left), array
145
146
  )
146
- [elt_opt, array] = build_unwrap(self.builder, res, err)
147
- [elt] = build_unwrap(self.builder, elt_opt, err)
147
+ [elt, array] = build_unwrap(self.builder, res, err)
148
148
  elts.append(elt)
149
149
  # Assign elements to the given patterns
150
150
  for pat, elt in zip(
@@ -164,7 +164,7 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
164
164
  self._assign(lhs.pattern.starred, array)
165
165
  else:
166
166
  assert length == 0
167
- self.builder.add_op(array_discard_empty(opt_elt_ty), array)
167
+ self.builder.add_op(array_discard_empty(elt_ty), array)
168
168
 
169
169
  @_assign.register
170
170
  def _assign_iterable(self, lhs: IterableUnpack, port: Wire) -> None:
@@ -221,3 +221,10 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):
221
221
  var = Variable(node.name, node.ty, node)
222
222
  loaded_func = compile_local_func_def(node, self.dfg, self.ctx)
223
223
  self.dfg[var] = loaded_func
224
+
225
def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None:
    """Compiles a ``with`` block annotated with modifiers into the current DFG."""
    # Imported at call time rather than at module level (presumably to avoid an
    # import cycle between the compiler modules -- TODO confirm).
    from guppylang_internals.compiler.modifier_compiler import (
        compile_modified_block,
    )

    compile_modified_block(
        modified_block=node,
        dfg=self.dfg,
        ctx=self.ctx,
        expr_compiler=self.expr_compiler,
    )
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, ClassVar
7
7
  from hugr import Wire, ops
8
8
  from hugr import tys as ht
9
9
  from hugr.build.dfg import DfBase
10
+ from hugr.std.collections.borrow_array import EXTENSION as BORROW_ARRAY_EXTENSION
10
11
 
11
12
  from guppylang_internals.ast_util import (
12
13
  AstNode,
@@ -23,6 +24,7 @@ from guppylang_internals.compiler.core import (
23
24
  DFContainer,
24
25
  GlobalConstId,
25
26
  partially_monomorphize_args,
27
+ qualified_name,
26
28
  )
27
29
  from guppylang_internals.definition.common import ParsableDef
28
30
  from guppylang_internals.definition.value import CallReturnWires, CompiledCallableDef
@@ -486,7 +488,39 @@ class NoopCompiler(CustomCallCompiler):
486
488
 
487
489
 
488
490
class CopyInoutCompiler(CustomInoutCallCompiler):
    """Call compiler for functions that borrow one argument to copy it."""

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        """Emits a copy of the single borrowed argument.

        If the Hugr type of the argument is not linear, the incoming wires are
        returned both as the regular return and as the inout return. Otherwise an
        explicit copy is emitted via ``_handle_affine_type``.
        """
        assert len(self.ty.input) == 1
        inp_ty = self.ty.input[0]
        if inp_ty.type_bound() != ht.TypeBound.Linear:
            return CallReturnWires(regular_returns=args, inout_returns=args)
        (arg,) = args
        copies = self._handle_affine_type(inp_ty, arg)
        copy_wire, borrowed_wire = copies[0], copies[1]
        return CallReturnWires(
            regular_returns=[copy_wire], inout_returns=[borrowed_wire]
        )

    # Affine types in Guppy backed by a linear Hugr type need to be copied explicitly.
    # TODO: Handle affine extension types more generally (borrow arrays are currently
    # the only case).
    def _handle_affine_type(self, ty: ht.Type, arg: Wire) -> list[Wire]:
        """Emits an explicit copy of ``arg`` (of linear Hugr type ``ty``).

        Returns:
            The output wires of the emitted clone operation.

        Raises:
            InternalGuppyError: If ``ty`` has no explicit copy handler.
        """
        if isinstance(ty, ht.ExtType):
            borrow_array_def = BORROW_ARRAY_EXTENSION.get_type("borrow_array")
            if qualified_name(ty.type_def) == qualified_name(borrow_array_def):
                type_args = ty.args
                assert len(type_args) == 2
                # Manually instantiate here to avoid circular import and use
                # type args directly.
                clone_def = BORROW_ARRAY_EXTENSION.get_op("clone")
                clone_sig = ht.FunctionType(self.ty.input, self.ty.output)
                clone_op = clone_def.instantiate(type_args, clone_sig)
                return list(self.builder.add_op(clone_op, arg))
        raise InternalGuppyError(
            f"Type `{ty}` needs an explicit handler in the `copy` compiler as "
            "it is an affine Guppy type backed by a linear Hugr type."
        )
@@ -13,7 +13,7 @@ from guppylang_internals.checker.func_checker import check_signature
13
13
  from guppylang_internals.compiler.core import (
14
14
  CompilerContext,
15
15
  DFContainer,
16
- requires_monomorphization,
16
+ require_monomorphization,
17
17
  )
18
18
  from guppylang_internals.definition.common import CompilableDef, ParsableDef
19
19
  from guppylang_internals.definition.function import (
@@ -72,9 +72,8 @@ class RawFunctionDecl(ParsableDef):
72
72
  if not has_empty_body(func_ast):
73
73
  raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
74
74
  # Make sure we won't need monomorphization to compile this declaration
75
- for param in ty.params:
76
- if requires_monomorphization(param):
77
- raise GuppyError(MonomorphizeError(func_ast, self.name, param))
75
+ if mono_params := require_monomorphization(ty.params):
76
+ raise GuppyError(MonomorphizeError(func_ast, self.name, mono_params.pop()))
78
77
  return CheckedFunctionDecl(
79
78
  self.id,
80
79
  self.name,
@@ -38,14 +38,19 @@ class ParamDef(Definition):
38
38
class TypeVarDef(ParamDef, CompiledDef):
    """A type variable definition."""

    #: Forwarded as `must_be_copyable` to the `TypeParam` built by `to_param`.
    copyable: bool
    #: Forwarded as `must_be_droppable` to the `TypeParam` built by `to_param`.
    droppable: bool

    description: str = field(default="type variable", init=False)

    def to_param(self, idx: int) -> TypeParam:
        """Creates the type parameter at position ``idx`` from this definition."""
        linearity = {
            "must_be_copyable": self.copyable,
            "must_be_droppable": self.droppable,
        }
        return TypeParam(idx, self.name, **linearity)
49
54
 
50
55
 
51
56
  @dataclass(frozen=True)
@@ -46,7 +46,6 @@ from guppylang_internals.std._internal.compiler.array import (
46
46
  array_new,
47
47
  array_unpack,
48
48
  )
49
- from guppylang_internals.std._internal.compiler.prelude import build_unwrap
50
49
  from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, make_opaque
51
50
  from guppylang_internals.tys.builtin import array_type, bool_type, float_type
52
51
  from guppylang_internals.tys.subst import Inst, Subst
@@ -195,17 +194,9 @@ class ParsedPytketDef(CallableDef, CompilableDef):
195
194
  # them into separate wires.
196
195
  for i, q_reg in enumerate(self.input_circuit.q_registers):
197
196
  reg_wire = outer_func.inputs()[i]
198
- opt_elem_wires = outer_func.add_op(
199
- array_unpack(ht.Option(ht.Qubit), q_reg.size), reg_wire
197
+ elem_wires = outer_func.add_op(
198
+ array_unpack(ht.Qubit, q_reg.size), reg_wire
200
199
  )
201
- elem_wires = [
202
- build_unwrap(
203
- outer_func,
204
- opt_elem,
205
- "Internal error: unwrapping of array element failed",
206
- )
207
- for opt_elem in opt_elem_wires
208
- ]
209
200
  input_list.extend(elem_wires)
210
201
 
211
202
  else:
@@ -219,7 +210,8 @@ class ParsedPytketDef(CallableDef, CompilableDef):
219
210
  ]
220
211
 
221
212
  # Symbolic parameters (if present) get passed after qubits and bools.
222
- has_params = len(self.input_circuit.free_symbols()) != 0
213
+ num_params = len(self.input_circuit.free_symbols())
214
+ has_params = num_params != 0
223
215
  if has_params and "TKET1.input_parameters" not in hugr_func.metadata:
224
216
  raise InternalGuppyError(
225
217
  "Parameter metadata is missing from pytket circuit HUGR"
@@ -230,26 +222,17 @@ class ParsedPytketDef(CallableDef, CompilableDef):
230
222
  if has_params:
231
223
  lex_params: list[Wire] = list(outer_func.inputs()[offset:])
232
224
  if self.use_arrays:
233
- opt_param_wires = outer_func.add_op(
225
+ unpack_result = outer_func.add_op(
234
226
  array_unpack(
235
- ht.Option(ht.Tuple(float_type().to_hugr(ctx))),
236
- q_reg.size,
227
+ ht.Tuple(float_type().to_hugr(ctx)), num_params
237
228
  ),
238
229
  lex_params[0],
239
230
  )
240
- lex_params = [
241
- build_unwrap(
242
- outer_func,
243
- opt_param,
244
- "Internal error: unwrapping of array element failed",
245
- )
246
- for opt_param in opt_param_wires
247
- ]
231
+ lex_params = list(unpack_result)
248
232
  param_order = cast(
249
233
  list[str], hugr_func.metadata["TKET1.input_parameters"]
250
234
  )
251
235
  lex_names = sorted(param_order)
252
- assert len(lex_names) == len(lex_params)
253
236
  name_to_param = dict(zip(lex_names, lex_params, strict=True))
254
237
  angle_wires = [name_to_param[name] for name in param_order]
255
238
  # Need to convert all angles to floats.
@@ -280,34 +263,23 @@ class ParsedPytketDef(CallableDef, CompilableDef):
280
263
  ]
281
264
 
282
265
  if self.use_arrays:
283
-
284
- def pack(elems: list[Wire], elem_ty: ht.Type, length: int) -> Wire:
285
- elem_opts = [
286
- outer_func.add_op(ops.Some(elem_ty), elem) for elem in elems
287
- ]
288
- return outer_func.add_op(
289
- array_new(ht.Option(elem_ty), length), *elem_opts
290
- )
291
-
292
266
  array_wires: list[Wire] = []
293
267
  wire_idx = 0
294
268
  # First pack bool results into an array.
295
269
  for c_reg in self.input_circuit.c_registers:
296
270
  array_wires.append(
297
- pack(
298
- wires[wire_idx : wire_idx + c_reg.size],
299
- OpaqueBool,
300
- c_reg.size,
271
+ outer_func.add_op(
272
+ array_new(OpaqueBool, c_reg.size),
273
+ *wires[wire_idx : wire_idx + c_reg.size],
301
274
  )
302
275
  )
303
276
  wire_idx = wire_idx + c_reg.size
304
277
  # Then the borrowed qubits also need to be put back into arrays.
305
278
  for q_reg in self.input_circuit.q_registers:
306
279
  array_wires.append(
307
- pack(
308
- wires[wire_idx : wire_idx + q_reg.size],
309
- ht.Qubit,
310
- q_reg.size,
280
+ outer_func.add_op(
281
+ array_new(ht.Qubit, q_reg.size),
282
+ *wires[wire_idx : wire_idx + q_reg.size],
311
283
  )
312
284
  )
313
285
  wire_idx = wire_idx + q_reg.size
@@ -131,10 +131,13 @@ class RawStructDef(TypeDef, ParsableDef):
131
131
  if cls_def.type_params:
132
132
  first, last = cls_def.type_params[0], cls_def.type_params[-1]
133
133
  params_span = Span(to_span(first).start, to_span(last).end)
134
- params = [
135
- parse_parameter(node, idx, globals)
136
- for idx, node in enumerate(cls_def.type_params)
137
- ]
134
+ param_vars_mapping: dict[str, Parameter] = {}
135
+ for idx, param_node in enumerate(cls_def.type_params):
136
+ param = parse_parameter(
137
+ param_node, idx, globals, param_vars_mapping
138
+ )
139
+ param_vars_mapping[param.name] = param
140
+ params.append(param)
138
141
 
139
142
  # The only base we allow is `Generic[...]` to specify generic parameters with
140
143
  # the legacy syntax
@@ -2,7 +2,7 @@ from abc import abstractmethod
2
2
  from collections.abc import Callable, Sequence
3
3
  from dataclasses import dataclass, field
4
4
 
5
- from hugr import tys
5
+ from hugr import tys as ht
6
6
 
7
7
  from guppylang_internals.ast_util import AstNode
8
8
  from guppylang_internals.definition.common import CompiledDef, Definition
@@ -42,8 +42,8 @@ class OpaqueTypeDef(TypeDef, CompiledDef):
42
42
  params: Sequence[Parameter]
43
43
  never_copyable: bool
44
44
  never_droppable: bool
45
- to_hugr: Callable[[Sequence[Argument], ToHugrContext], tys.Type]
46
- bound: tys.TypeBound | None = None
45
+ to_hugr: Callable[[Sequence[Argument], ToHugrContext], ht.Type]
46
+ bound: ht.TypeBound | None = None
47
47
 
48
48
  def check_instantiate(
49
49
  self, args: Sequence[Argument], loc: AstNode | None = None
@@ -90,3 +90,8 @@ def check_lists_enabled(loc: AstNode | None = None) -> None:
90
90
def check_capturing_closures_enabled(loc: AstNode | None = None) -> None:
    """Ensures experimental features are enabled before capturing closures are used.

    Raises:
        GuppyError: Wrapping an `UnsupportedError` located at ``loc`` if
            experimental features are disabled.
    """
    if EXPERIMENTAL_FEATURES_ENABLED:
        return
    raise GuppyError(UnsupportedError(loc, "Capturing closures"))
93
+
94
+
95
def check_modifiers_enabled(loc: AstNode | None = None) -> None:
    """Ensures experimental features are enabled before modifiers are used.

    Raises:
        GuppyError: Wrapping an `ExperimentalFeatureError` located at ``loc`` if
            experimental features are disabled.
    """
    if EXPERIMENTAL_FEATURES_ENABLED:
        return
    raise GuppyError(ExperimentalFeatureError(loc, "Modifiers"))
@@ -6,6 +6,7 @@ from enum import Enum
6
6
  from typing import TYPE_CHECKING, Any
7
7
 
8
8
  from guppylang_internals.ast_util import AstNode
9
+ from guppylang_internals.span import Span, to_span
9
10
  from guppylang_internals.tys.const import Const
10
11
  from guppylang_internals.tys.subst import Inst
11
12
  from guppylang_internals.tys.ty import FunctionType, StructType, TupleType, Type
@@ -422,3 +423,126 @@ class CheckedNestedFunctionDef(ast.FunctionDef):
422
423
  self.cfg = cfg
423
424
  self.ty = ty
424
425
  self.captured = captured
426
+
427
+
428
class Dagger(ast.expr):
    """Node for the dagger modifier.

    All instance attributes of the wrapped expression (including its source
    location) are forwarded to the AST constructor.
    """

    def __init__(self, node: ast.expr) -> None:
        node_attrs = vars(node)
        super().__init__(**node_attrs)
433
+
434
+
435
class Control(ast.Call):
    """Node for the control modifier.

    Built from the original call node, whose instance attributes are forwarded to
    the AST constructor, together with the list of control operands.
    """

    #: The control operands.
    ctrl: list[ast.expr]
    #: Number of control qubits; starts out as ``None`` (presumably filled in
    #: later by the checker -- TODO confirm).
    qubit_num: int | Const | None

    _fields = ("ctrl",)

    def __init__(self, node: ast.Call, ctrl: list[ast.expr]) -> None:
        super().__init__(**vars(node))
        self.ctrl, self.qubit_num = ctrl, None
447
+
448
+
449
class Power(ast.expr):
    """Node for the power modifier.

    Built from the original modifier expression, whose instance attributes are
    forwarded to the AST constructor, together with the modifier's argument.
    """

    #: The argument expression of the power modifier.
    iter: ast.expr

    _fields = ("iter",)

    def __init__(self, node: ast.expr, iter: ast.expr) -> None:
        super().__init__(**vars(node))
        self.iter = iter
459
+
460
+
461
+ Modifier = Dagger | Control | Power
462
+
463
+
464
+ class ModifiedBlock(ast.With):
465
+ cfg: "CFG"
466
+ dagger: list[Dagger]
467
+ control: list[Control]
468
+ power: list[Power]
469
+
470
+ def __init__(self, cfg: "CFG", *args: Any, **kwargs: Any) -> None:
471
+ super().__init__(*args, **kwargs)
472
+ self.cfg = cfg
473
+ self.dagger = []
474
+ self.control = []
475
+ self.power = []
476
+
477
+ def is_dagger(self) -> bool:
478
+ return len(self.dagger) % 2 == 1
479
+
480
+ def is_control(self) -> bool:
481
+ return len(self.control) > 0
482
+
483
+ def is_power(self) -> bool:
484
+ return len(self.power) > 0
485
+
486
+ def span_ctxt_manager(self) -> Span:
487
+ return Span(
488
+ to_span(self.items[0].context_expr).start,
489
+ to_span(self.items[-1].context_expr).end,
490
+ )
491
+
492
+ def push_modifier(self, modifier: Modifier) -> None:
493
+ """Pushes a modifier kind onto the modifier."""
494
+ if isinstance(modifier, Dagger):
495
+ self.dagger.append(modifier)
496
+ elif isinstance(modifier, Control):
497
+ self.control.append(modifier)
498
+ elif isinstance(modifier, Power):
499
+ self.power.append(modifier)
500
+ else:
501
+ raise TypeError(f"Unknown modifier: {modifier}")
502
+
503
+
504
class CheckedModifiedBlock(ast.With):
    """A type-checked ``with`` block annotated with modifiers."""

    def_id: "DefId"
    cfg: "CheckedCFG[Place]"
    dagger: list[Dagger]
    control: list[Control]
    power: list[Power]

    #: The type of the body of With block.
    ty: FunctionType
    #: Mapping from names to variables captured in the body.
    captured: Mapping[str, tuple["Variable", AstNode]]

    def __init__(
        self,
        def_id: "DefId",
        cfg: "CheckedCFG[Place]",
        ty: FunctionType,
        captured: Mapping[str, tuple["Variable", AstNode]],
        dagger: list[Dagger],
        control: list[Control],
        power: list[Power],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.def_id, self.cfg = def_id, cfg
        self.ty, self.captured = ty, captured
        self.dagger, self.control, self.power = dagger, control, power

    def __str__(self) -> str:
        # A function name derived from `def_id`
        return f"__WithBlock__({self.def_id})"

    def has_dagger(self) -> bool:
        """Whether the block carries an odd number of dagger modifiers."""
        return len(self.dagger) % 2 != 0

    def has_control(self) -> bool:
        """Whether some control modifier has at least one control operand."""
        return any(c.ctrl for c in self.control)

    def has_power(self) -> bool:
        """Whether at least one power modifier is present."""
        return bool(self.power)