guppylang-internals 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +21 -0
  3. guppylang_internals/cfg/bb.py +20 -0
  4. guppylang_internals/cfg/builder.py +118 -5
  5. guppylang_internals/cfg/cfg.py +3 -0
  6. guppylang_internals/checker/cfg_checker.py +6 -0
  7. guppylang_internals/checker/core.py +5 -2
  8. guppylang_internals/checker/errors/generic.py +32 -1
  9. guppylang_internals/checker/errors/type_errors.py +14 -0
  10. guppylang_internals/checker/errors/wasm.py +7 -4
  11. guppylang_internals/checker/expr_checker.py +58 -17
  12. guppylang_internals/checker/func_checker.py +18 -14
  13. guppylang_internals/checker/linearity_checker.py +67 -10
  14. guppylang_internals/checker/modifier_checker.py +120 -0
  15. guppylang_internals/checker/stmt_checker.py +48 -1
  16. guppylang_internals/checker/unitary_checker.py +132 -0
  17. guppylang_internals/compiler/cfg_compiler.py +7 -6
  18. guppylang_internals/compiler/core.py +93 -56
  19. guppylang_internals/compiler/expr_compiler.py +72 -168
  20. guppylang_internals/compiler/modifier_compiler.py +176 -0
  21. guppylang_internals/compiler/stmt_compiler.py +15 -8
  22. guppylang_internals/decorator.py +86 -7
  23. guppylang_internals/definition/custom.py +39 -1
  24. guppylang_internals/definition/declaration.py +9 -6
  25. guppylang_internals/definition/function.py +12 -2
  26. guppylang_internals/definition/parameter.py +8 -3
  27. guppylang_internals/definition/pytket_circuits.py +14 -41
  28. guppylang_internals/definition/struct.py +13 -7
  29. guppylang_internals/definition/ty.py +3 -3
  30. guppylang_internals/definition/wasm.py +42 -10
  31. guppylang_internals/engine.py +9 -3
  32. guppylang_internals/experimental.py +5 -0
  33. guppylang_internals/nodes.py +147 -24
  34. guppylang_internals/std/_internal/checker.py +13 -108
  35. guppylang_internals/std/_internal/compiler/array.py +95 -283
  36. guppylang_internals/std/_internal/compiler/list.py +1 -1
  37. guppylang_internals/std/_internal/compiler/platform.py +153 -0
  38. guppylang_internals/std/_internal/compiler/prelude.py +12 -4
  39. guppylang_internals/std/_internal/compiler/tket_exts.py +8 -2
  40. guppylang_internals/std/_internal/debug.py +18 -9
  41. guppylang_internals/std/_internal/util.py +1 -1
  42. guppylang_internals/tracing/object.py +10 -0
  43. guppylang_internals/tracing/unpacking.py +19 -20
  44. guppylang_internals/tys/arg.py +18 -3
  45. guppylang_internals/tys/builtin.py +2 -5
  46. guppylang_internals/tys/const.py +33 -4
  47. guppylang_internals/tys/errors.py +23 -1
  48. guppylang_internals/tys/param.py +31 -16
  49. guppylang_internals/tys/parsing.py +11 -24
  50. guppylang_internals/tys/printing.py +2 -8
  51. guppylang_internals/tys/qubit.py +62 -0
  52. guppylang_internals/tys/subst.py +8 -26
  53. guppylang_internals/tys/ty.py +91 -85
  54. guppylang_internals/wasm_util.py +129 -0
  55. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +6 -5
  56. guppylang_internals-0.26.0.dist-info/RECORD +104 -0
  57. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
  58. guppylang_internals-0.24.0.dist-info/RECORD +0 -98
  59. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
@@ -46,7 +46,6 @@ from guppylang_internals.std._internal.compiler.array import (
46
46
  array_new,
47
47
  array_unpack,
48
48
  )
49
- from guppylang_internals.std._internal.compiler.prelude import build_unwrap
50
49
  from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, make_opaque
51
50
  from guppylang_internals.tys.builtin import array_type, bool_type, float_type
52
51
  from guppylang_internals.tys.subst import Inst, Subst
@@ -195,17 +194,9 @@ class ParsedPytketDef(CallableDef, CompilableDef):
195
194
  # them into separate wires.
196
195
  for i, q_reg in enumerate(self.input_circuit.q_registers):
197
196
  reg_wire = outer_func.inputs()[i]
198
- opt_elem_wires = outer_func.add_op(
199
- array_unpack(ht.Option(ht.Qubit), q_reg.size), reg_wire
197
+ elem_wires = outer_func.add_op(
198
+ array_unpack(ht.Qubit, q_reg.size), reg_wire
200
199
  )
201
- elem_wires = [
202
- build_unwrap(
203
- outer_func,
204
- opt_elem,
205
- "Internal error: unwrapping of array element failed",
206
- )
207
- for opt_elem in opt_elem_wires
208
- ]
209
200
  input_list.extend(elem_wires)
210
201
 
211
202
  else:
@@ -219,7 +210,8 @@ class ParsedPytketDef(CallableDef, CompilableDef):
219
210
  ]
220
211
 
221
212
  # Symbolic parameters (if present) get passed after qubits and bools.
222
- has_params = len(self.input_circuit.free_symbols()) != 0
213
+ num_params = len(self.input_circuit.free_symbols())
214
+ has_params = num_params != 0
223
215
  if has_params and "TKET1.input_parameters" not in hugr_func.metadata:
224
216
  raise InternalGuppyError(
225
217
  "Parameter metadata is missing from pytket circuit HUGR"
@@ -230,26 +222,17 @@ class ParsedPytketDef(CallableDef, CompilableDef):
230
222
  if has_params:
231
223
  lex_params: list[Wire] = list(outer_func.inputs()[offset:])
232
224
  if self.use_arrays:
233
- opt_param_wires = outer_func.add_op(
225
+ unpack_result = outer_func.add_op(
234
226
  array_unpack(
235
- ht.Option(ht.Tuple(float_type().to_hugr(ctx))),
236
- q_reg.size,
227
+ ht.Tuple(float_type().to_hugr(ctx)), num_params
237
228
  ),
238
229
  lex_params[0],
239
230
  )
240
- lex_params = [
241
- build_unwrap(
242
- outer_func,
243
- opt_param,
244
- "Internal error: unwrapping of array element failed",
245
- )
246
- for opt_param in opt_param_wires
247
- ]
231
+ lex_params = list(unpack_result)
248
232
  param_order = cast(
249
233
  list[str], hugr_func.metadata["TKET1.input_parameters"]
250
234
  )
251
235
  lex_names = sorted(param_order)
252
- assert len(lex_names) == len(lex_params)
253
236
  name_to_param = dict(zip(lex_names, lex_params, strict=True))
254
237
  angle_wires = [name_to_param[name] for name in param_order]
255
238
  # Need to convert all angles to floats.
@@ -280,34 +263,23 @@ class ParsedPytketDef(CallableDef, CompilableDef):
280
263
  ]
281
264
 
282
265
  if self.use_arrays:
283
-
284
- def pack(elems: list[Wire], elem_ty: ht.Type, length: int) -> Wire:
285
- elem_opts = [
286
- outer_func.add_op(ops.Some(elem_ty), elem) for elem in elems
287
- ]
288
- return outer_func.add_op(
289
- array_new(ht.Option(elem_ty), length), *elem_opts
290
- )
291
-
292
266
  array_wires: list[Wire] = []
293
267
  wire_idx = 0
294
268
  # First pack bool results into an array.
295
269
  for c_reg in self.input_circuit.c_registers:
296
270
  array_wires.append(
297
- pack(
298
- wires[wire_idx : wire_idx + c_reg.size],
299
- OpaqueBool,
300
- c_reg.size,
271
+ outer_func.add_op(
272
+ array_new(OpaqueBool, c_reg.size),
273
+ *wires[wire_idx : wire_idx + c_reg.size],
301
274
  )
302
275
  )
303
276
  wire_idx = wire_idx + c_reg.size
304
277
  # Then the borrowed qubits also need to be put back into arrays.
305
278
  for q_reg in self.input_circuit.q_registers:
306
279
  array_wires.append(
307
- pack(
308
- wires[wire_idx : wire_idx + q_reg.size],
309
- ht.Qubit,
310
- q_reg.size,
280
+ outer_func.add_op(
281
+ array_new(ht.Qubit, q_reg.size),
282
+ *wires[wire_idx : wire_idx + q_reg.size],
311
283
  )
312
284
  )
313
285
  wire_idx = wire_idx + q_reg.size
@@ -398,6 +370,7 @@ def _signature_from_circuit(
398
370
  use_arrays: bool = False,
399
371
  ) -> FunctionType:
400
372
  """Helper function for inferring a function signature from a pytket circuit."""
373
+ # May want to set proper unitary flags in the future.
401
374
  try:
402
375
  import pytket
403
376
 
@@ -131,10 +131,13 @@ class RawStructDef(TypeDef, ParsableDef):
131
131
  if cls_def.type_params:
132
132
  first, last = cls_def.type_params[0], cls_def.type_params[-1]
133
133
  params_span = Span(to_span(first).start, to_span(last).end)
134
- params = [
135
- parse_parameter(node, idx, globals)
136
- for idx, node in enumerate(cls_def.type_params)
137
- ]
134
+ param_vars_mapping: dict[str, Parameter] = {}
135
+ for idx, param_node in enumerate(cls_def.type_params):
136
+ param = parse_parameter(
137
+ param_node, idx, globals, param_vars_mapping
138
+ )
139
+ param_vars_mapping[param.name] = param
140
+ params.append(param)
138
141
 
139
142
  # The only base we allow is `Generic[...]` to specify generic parameters with
140
143
  # the legacy syntax
@@ -270,13 +273,16 @@ class CheckedStructDef(TypeDef, CompiledDef):
270
273
 
271
274
  constructor_sig = FunctionType(
272
275
  inputs=[
273
- FuncInput(f.ty, InputFlags.Owned if f.ty.linear else InputFlags.NoFlags)
276
+ FuncInput(
277
+ f.ty,
278
+ InputFlags.Owned if f.ty.linear else InputFlags.NoFlags,
279
+ f.name,
280
+ )
274
281
  for f in self.fields
275
282
  ],
276
283
  output=StructType(
277
284
  defn=self, args=[p.to_bound(i) for i, p in enumerate(self.params)]
278
285
  ),
279
- input_names=[f.name for f in self.fields],
280
286
  params=self.params,
281
287
  )
282
288
  constructor_def = CustomFunctionDef(
@@ -314,7 +320,7 @@ def parse_py_class(
314
320
  raise GuppyError(UnknownSourceError(None, cls))
315
321
 
316
322
  # We can't rely on `inspect.getsourcelines` since it doesn't work properly for
317
- # classes prior to Python 3.13. See https://github.com/CQCL/guppylang/issues/1107.
323
+ # classes prior to Python 3.13. See https://github.com/quantinuum/guppylang/issues/1107.
318
324
  # Instead, we reproduce the behaviour of Python >= 3.13 using the `__firstlineno__`
319
325
  # attribute. See https://github.com/python/cpython/blob/3.13/Lib/inspect.py#L1052.
320
326
  # In the decorator, we make sure that `__firstlineno__` is set, even if we're not
@@ -2,7 +2,7 @@ from abc import abstractmethod
2
2
  from collections.abc import Callable, Sequence
3
3
  from dataclasses import dataclass, field
4
4
 
5
- from hugr import tys
5
+ from hugr import tys as ht
6
6
 
7
7
  from guppylang_internals.ast_util import AstNode
8
8
  from guppylang_internals.definition.common import CompiledDef, Definition
@@ -42,8 +42,8 @@ class OpaqueTypeDef(TypeDef, CompiledDef):
42
42
  params: Sequence[Parameter]
43
43
  never_copyable: bool
44
44
  never_droppable: bool
45
- to_hugr: Callable[[Sequence[Argument], ToHugrContext], tys.Type]
46
- bound: tys.TypeBound | None = None
45
+ to_hugr: Callable[[Sequence[Argument], ToHugrContext], ht.Type]
46
+ bound: ht.TypeBound | None = None
47
47
 
48
48
  def check_instantiate(
49
49
  self, args: Sequence[Argument], loc: AstNode | None = None
@@ -1,3 +1,4 @@
1
+ from dataclasses import dataclass, field
1
2
  from typing import TYPE_CHECKING
2
3
 
3
4
  from guppylang_internals.ast_util import AstNode
@@ -9,7 +10,8 @@ from guppylang_internals.definition.custom import (
9
10
  CustomFunctionDef,
10
11
  RawCustomFunctionDef,
11
12
  )
12
- from guppylang_internals.error import GuppyError
13
+ from guppylang_internals.engine import DEF_STORE
14
+ from guppylang_internals.error import GuppyError, GuppyTypeError
13
15
  from guppylang_internals.span import SourceMap
14
16
  from guppylang_internals.tys.builtin import wasm_module_name
15
17
  from guppylang_internals.tys.ty import (
@@ -21,24 +23,35 @@ from guppylang_internals.tys.ty import (
21
23
  TupleType,
22
24
  Type,
23
25
  )
26
+ from guppylang_internals.wasm_util import WasmSigMismatchError
24
27
 
25
28
  if TYPE_CHECKING:
26
29
  from guppylang_internals.checker.core import Globals
27
30
 
28
31
 
32
+ @dataclass(frozen=True)
29
33
  class RawWasmFunctionDef(RawCustomFunctionDef):
30
- def sanitise_type(self, loc: AstNode | None, fun_ty: FunctionType) -> None:
34
+ # If a function is specified in the @wasm decorator by its index in the wasm
35
+ # file, record what the index was.
36
+ wasm_index: int | None = field(default=None)
37
+
38
+ def sanitise_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
31
39
  # Place to highlight in error messages
32
- match fun_ty.inputs[0]:
33
- case FuncInput(ty=ty, flags=InputFlags.Inout) if wasm_module_name(
40
+ match fun_ty.inputs:
41
+ case [FuncInput(ty=ty, flags=InputFlags.Inout), *args] if wasm_module_name(
34
42
  ty
35
43
  ) is not None:
36
- pass
37
- case FuncInput(ty=ty):
38
- raise GuppyError(FirstArgNotModule(loc, ty))
39
- for inp in fun_ty.inputs[1:]:
40
- if not self.is_type_wasmable(inp.ty):
41
- raise GuppyError(UnWasmableType(loc, inp.ty))
44
+ for inp in args:
45
+ if not self.is_type_wasmable(inp.ty):
46
+ raise GuppyError(UnWasmableType(loc, inp.ty))
47
+ case [FuncInput(ty=ty), *_]:
48
+ raise GuppyError(
49
+ FirstArgNotModule(loc).add_sub_diagnostic(
50
+ FirstArgNotModule.GotOtherType(loc, ty)
51
+ )
52
+ )
53
+ case []:
54
+ raise GuppyError(FirstArgNotModule(loc))
42
55
  if not self.is_type_wasmable(fun_ty.output):
43
56
  match fun_ty.output:
44
57
  case NoneType():
@@ -46,6 +59,23 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
46
59
  case _:
47
60
  raise GuppyError(UnWasmableType(loc, fun_ty.output))
48
61
 
62
+ def validate_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
63
+ type_in_wasm: FunctionType = DEF_STORE.wasm_functions[self.id]
64
+ assert type_in_wasm is not None
65
+ # Drop the first arg because it should be "self"
66
+ expected_type = FunctionType(fun_ty.inputs[1:], fun_ty.output)
67
+
68
+ if expected_type != type_in_wasm:
69
+ raise GuppyTypeError(
70
+ WasmSigMismatchError(loc)
71
+ .add_sub_diagnostic(
72
+ WasmSigMismatchError.Declaration(None, declared=str(expected_type))
73
+ )
74
+ .add_sub_diagnostic(
75
+ WasmSigMismatchError.Actual(None, actual=str(type_in_wasm))
76
+ )
77
+ )
78
+
49
79
  def is_type_wasmable(self, ty: Type) -> bool:
50
80
  match ty:
51
81
  case NumericType():
@@ -57,5 +87,7 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
57
87
 
58
88
  def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
59
89
  parsed = super().parse(globals, sources)
90
+ assert parsed.defined_at is not None
60
91
  self.sanitise_type(parsed.defined_at, parsed.ty)
92
+ self.validate_type(parsed.defined_at, parsed.ty)
61
93
  return parsed
@@ -46,6 +46,7 @@ from guppylang_internals.tys.builtin import (
46
46
  string_type_def,
47
47
  tuple_type_def,
48
48
  )
49
+ from guppylang_internals.tys.ty import FunctionType
49
50
 
50
51
  if TYPE_CHECKING:
51
52
  from guppylang_internals.compiler.core import MonoDefId
@@ -87,6 +88,7 @@ class DefinitionStore:
87
88
  raw_defs: dict[DefId, RawDef]
88
89
  impls: defaultdict[DefId, dict[str, DefId]]
89
90
  impl_parents: dict[DefId, DefId]
91
+ wasm_functions: dict[DefId, FunctionType]
90
92
  frames: dict[DefId, FrameType]
91
93
  sources: SourceMap
92
94
 
@@ -96,6 +98,7 @@ class DefinitionStore:
96
98
  self.impl_parents = {}
97
99
  self.frames = {}
98
100
  self.sources = SourceMap()
101
+ self.wasm_functions = {}
99
102
 
100
103
  def register_def(self, defn: RawDef, frame: FrameType | None) -> None:
101
104
  self.raw_defs[defn.id] = defn
@@ -123,6 +126,9 @@ class DefinitionStore:
123
126
  assert frame is not None
124
127
  self.frames[impl_id] = frame
125
128
 
129
+ def register_wasm_function(self, fn_id: DefId, sig: FunctionType) -> None:
130
+ self.wasm_functions[fn_id] = sig
131
+
126
132
 
127
133
  DEF_STORE: DefinitionStore = DefinitionStore()
128
134
 
@@ -263,8 +269,8 @@ class CompilationEngine:
263
269
  and isinstance(compiled_def, CompiledCallableDef)
264
270
  and not isinstance(graph.hugr[compiled_def.hugr_node].op, ops.FuncDecl)
265
271
  ):
266
- # if compiling a region set it as the HUGR entrypoint
267
- # can be loosened after https://github.com/CQCL/hugr/issues/2501 is fixed
272
+ # if compiling a region set it as the HUGR entrypoint can be
273
+ # loosened after https://github.com/quantinuum/hugr/issues/2501 is fixed
268
274
  graph.hugr.entrypoint = compiled_def.hugr_node
269
275
 
270
276
  # TODO: Currently the list of extensions is manually managed by the user.
@@ -278,7 +284,7 @@ class CompilationEngine:
278
284
  guppylang_internals.compiler.hugr_extension.EXTENSION,
279
285
  *self.additional_extensions,
280
286
  ]
281
- # TODO replace with computed extensions after https://github.com/CQCL/guppylang/issues/550
287
+ # TODO replace with computed extensions after https://github.com/quantinuum/guppylang/issues/550
282
288
  all_used_extensions = [
283
289
  *extensions,
284
290
  hugr.std.prelude.PRELUDE_EXTENSION,
@@ -90,3 +90,8 @@ def check_lists_enabled(loc: AstNode | None = None) -> None:
90
90
  def check_capturing_closures_enabled(loc: AstNode | None = None) -> None:
91
91
  if not EXPERIMENTAL_FEATURES_ENABLED:
92
92
  raise GuppyError(UnsupportedError(loc, "Capturing closures"))
93
+
94
+
95
+ def check_modifiers_enabled(loc: AstNode | None = None) -> None:
96
+ if not EXPERIMENTAL_FEATURES_ENABLED:
97
+ raise GuppyError(ExperimentalFeatureError(loc, "Modifiers"))
@@ -6,9 +6,16 @@ from enum import Enum
6
6
  from typing import TYPE_CHECKING, Any
7
7
 
8
8
  from guppylang_internals.ast_util import AstNode
9
+ from guppylang_internals.span import Span, to_span
9
10
  from guppylang_internals.tys.const import Const
10
11
  from guppylang_internals.tys.subst import Inst
11
- from guppylang_internals.tys.ty import FunctionType, StructType, TupleType, Type
12
+ from guppylang_internals.tys.ty import (
13
+ FunctionType,
14
+ StructType,
15
+ TupleType,
16
+ Type,
17
+ UnitaryFlags,
18
+ )
12
19
 
13
20
  if TYPE_CHECKING:
14
21
  from guppylang_internals.cfg.cfg import CFG
@@ -249,22 +256,6 @@ class ComptimeExpr(ast.expr):
249
256
  _fields = ("value",)
250
257
 
251
258
 
252
- class ResultExpr(ast.expr):
253
- """A `result(tag, value)` expression."""
254
-
255
- value: ast.expr
256
- base_ty: Type
257
- #: Array length in case this is an array result, otherwise `None`
258
- array_len: Const | None
259
- tag: str
260
-
261
- _fields = ("value", "base_ty", "array_len", "tag")
262
-
263
- @property
264
- def args(self) -> list[ast.expr]:
265
- return [self.value]
266
-
267
-
268
259
  class ExitKind(Enum):
269
260
  ExitShot = 0 # Exit the current shot
270
261
  Panic = 1 # Panic the program ending all shots
@@ -274,8 +265,8 @@ class PanicExpr(ast.expr):
274
265
  """A `panic(msg, *args)` or `exit(msg, *args)` expression ."""
275
266
 
276
267
  kind: ExitKind
277
- signal: int
278
- msg: str
268
+ signal: ast.expr
269
+ msg: ast.expr
279
270
  values: list[ast.expr]
280
271
 
281
272
  _fields = ("kind", "signal", "msg", "values")
@@ -292,17 +283,16 @@ class BarrierExpr(ast.expr):
292
283
  class StateResultExpr(ast.expr):
293
284
  """A `state_result(tag, *args)` expression."""
294
285
 
295
- tag: str
286
+ tag_value: Const
287
+ tag_expr: ast.expr
296
288
  args: list[ast.expr]
297
289
  func_ty: FunctionType
298
290
  #: Array length in case this is an array result, otherwise `None`
299
291
  array_len: Const | None
300
- _fields = ("tag", "args", "func_ty", "has_array_input")
292
+ _fields = ("tag_value", "tag_expr", "args", "func_ty", "has_array_input")
301
293
 
302
294
 
303
- AnyCall = (
304
- LocalCall | GlobalCall | TensorCall | BarrierExpr | ResultExpr | StateResultExpr
305
- )
295
+ AnyCall = LocalCall | GlobalCall | TensorCall | BarrierExpr | StateResultExpr
306
296
 
307
297
 
308
298
  class InoutReturnSentinel(ast.expr):
@@ -422,3 +412,136 @@ class CheckedNestedFunctionDef(ast.FunctionDef):
422
412
  self.cfg = cfg
423
413
  self.ty = ty
424
414
  self.captured = captured
415
+
416
+
417
+ class Dagger(ast.expr):
418
+ """The dagger modifier"""
419
+
420
+ def __init__(self, node: ast.expr) -> None:
421
+ super().__init__(**node.__dict__)
422
+
423
+
424
+ class Control(ast.Call):
425
+ """The control modifier"""
426
+
427
+ ctrl: list[ast.expr]
428
+ qubit_num: int | Const | None
429
+
430
+ _fields = ("ctrl",)
431
+
432
+ def __init__(self, node: ast.Call, ctrl: list[ast.expr]) -> None:
433
+ super().__init__(**node.__dict__)
434
+ self.ctrl = ctrl
435
+ self.qubit_num = None
436
+
437
+
438
+ class Power(ast.expr):
439
+ """The power modifier"""
440
+
441
+ iter: ast.expr
442
+
443
+ _fields = ("iter",)
444
+
445
+ def __init__(self, node: ast.expr, iter: ast.expr) -> None:
446
+ super().__init__(**node.__dict__)
447
+ self.iter = iter
448
+
449
+
450
+ Modifier = Dagger | Control | Power
451
+
452
+
453
+ class ModifiedBlock(ast.With):
454
+ cfg: "CFG"
455
+ dagger: list[Dagger]
456
+ control: list[Control]
457
+ power: list[Power]
458
+
459
+ def __init__(self, cfg: "CFG", *args: Any, **kwargs: Any) -> None:
460
+ super().__init__(*args, **kwargs)
461
+ self.cfg = cfg
462
+ self.dagger = []
463
+ self.control = []
464
+ self.power = []
465
+
466
+ def is_dagger(self) -> bool:
467
+ return len(self.dagger) % 2 == 1
468
+
469
+ def is_control(self) -> bool:
470
+ return len(self.control) > 0
471
+
472
+ def is_power(self) -> bool:
473
+ return len(self.power) > 0
474
+
475
+ def span_ctxt_manager(self) -> Span:
476
+ return Span(
477
+ to_span(self.items[0].context_expr).start,
478
+ to_span(self.items[-1].context_expr).end,
479
+ )
480
+
481
+ def push_modifier(self, modifier: Modifier) -> None:
482
+ """Pushes a modifier kind onto the modifier."""
483
+ if isinstance(modifier, Dagger):
484
+ self.dagger.append(modifier)
485
+ elif isinstance(modifier, Control):
486
+ self.control.append(modifier)
487
+ elif isinstance(modifier, Power):
488
+ self.power.append(modifier)
489
+ else:
490
+ raise TypeError(f"Unknown modifier: {modifier}")
491
+
492
+ def flags(self) -> UnitaryFlags:
493
+ flags = UnitaryFlags.NoFlags
494
+ if self.is_dagger():
495
+ flags |= UnitaryFlags.Dagger
496
+ if self.is_control():
497
+ flags |= UnitaryFlags.Control
498
+ if self.is_power():
499
+ flags |= UnitaryFlags.Power
500
+ return flags
501
+
502
+
503
+ class CheckedModifiedBlock(ast.With):
504
+ def_id: "DefId"
505
+ cfg: "CheckedCFG[Place]"
506
+ dagger: list[Dagger]
507
+ control: list[Control]
508
+ power: list[Power]
509
+
510
+ #: The type of the body of With block.
511
+ ty: FunctionType
512
+ #: Mapping from names to variables captured in the body.
513
+ captured: Mapping[str, tuple["Variable", AstNode]]
514
+
515
+ def __init__(
516
+ self,
517
+ def_id: "DefId",
518
+ cfg: "CheckedCFG[Place]",
519
+ ty: FunctionType,
520
+ captured: Mapping[str, tuple["Variable", AstNode]],
521
+ dagger: list[Dagger],
522
+ control: list[Control],
523
+ power: list[Power],
524
+ *args: Any,
525
+ **kwargs: Any,
526
+ ) -> None:
527
+ super().__init__(*args, **kwargs)
528
+ self.def_id = def_id
529
+ self.cfg = cfg
530
+ self.ty = ty
531
+ self.captured = captured
532
+ self.dagger = dagger
533
+ self.control = control
534
+ self.power = power
535
+
536
+ def __str__(self) -> str:
537
+ # generate a function name from the def_id
538
+ return f"__WithBlock__({self.def_id})"
539
+
540
+ def has_dagger(self) -> bool:
541
+ return len(self.dagger) % 2 == 1
542
+
543
+ def has_control(self) -> bool:
544
+ return any(len(c.ctrl) > 0 for c in self.control)
545
+
546
+ def has_power(self) -> bool:
547
+ return len(self.power) > 0