guppylang-internals 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +21 -0
  3. guppylang_internals/cfg/bb.py +20 -0
  4. guppylang_internals/cfg/builder.py +101 -3
  5. guppylang_internals/checker/core.py +12 -0
  6. guppylang_internals/checker/errors/generic.py +32 -1
  7. guppylang_internals/checker/errors/type_errors.py +14 -0
  8. guppylang_internals/checker/expr_checker.py +55 -29
  9. guppylang_internals/checker/func_checker.py +171 -22
  10. guppylang_internals/checker/linearity_checker.py +65 -0
  11. guppylang_internals/checker/modifier_checker.py +116 -0
  12. guppylang_internals/checker/stmt_checker.py +49 -2
  13. guppylang_internals/compiler/core.py +90 -53
  14. guppylang_internals/compiler/expr_compiler.py +49 -114
  15. guppylang_internals/compiler/modifier_compiler.py +174 -0
  16. guppylang_internals/compiler/stmt_compiler.py +15 -8
  17. guppylang_internals/decorator.py +124 -58
  18. guppylang_internals/definition/const.py +2 -2
  19. guppylang_internals/definition/custom.py +36 -2
  20. guppylang_internals/definition/declaration.py +4 -5
  21. guppylang_internals/definition/extern.py +2 -2
  22. guppylang_internals/definition/function.py +1 -1
  23. guppylang_internals/definition/parameter.py +10 -5
  24. guppylang_internals/definition/pytket_circuits.py +14 -42
  25. guppylang_internals/definition/struct.py +17 -14
  26. guppylang_internals/definition/traced.py +1 -1
  27. guppylang_internals/definition/ty.py +9 -3
  28. guppylang_internals/definition/wasm.py +2 -2
  29. guppylang_internals/engine.py +13 -2
  30. guppylang_internals/experimental.py +5 -0
  31. guppylang_internals/nodes.py +124 -23
  32. guppylang_internals/std/_internal/compiler/array.py +94 -282
  33. guppylang_internals/std/_internal/compiler/tket_exts.py +12 -8
  34. guppylang_internals/std/_internal/compiler/wasm.py +37 -26
  35. guppylang_internals/tracing/function.py +13 -2
  36. guppylang_internals/tracing/unpacking.py +33 -28
  37. guppylang_internals/tys/arg.py +18 -3
  38. guppylang_internals/tys/builtin.py +32 -16
  39. guppylang_internals/tys/const.py +33 -4
  40. guppylang_internals/tys/errors.py +6 -0
  41. guppylang_internals/tys/param.py +31 -16
  42. guppylang_internals/tys/parsing.py +118 -145
  43. guppylang_internals/tys/qubit.py +27 -0
  44. guppylang_internals/tys/subst.py +8 -26
  45. guppylang_internals/tys/ty.py +31 -21
  46. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/METADATA +4 -4
  47. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/RECORD +49 -46
  48. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/WHEEL +0 -0
  49. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/licenses/LICENCE +0 -0
@@ -7,8 +7,8 @@ node straight from the Python AST. We build a CFG, check it, and return a
7
7
 
8
8
  import ast
9
9
  import sys
10
- from dataclasses import dataclass
11
- from typing import TYPE_CHECKING, ClassVar
10
+ from dataclasses import dataclass, replace
11
+ from typing import TYPE_CHECKING, ClassVar, cast
12
12
 
13
13
  from guppylang_internals.ast_util import return_nodes_in_ast, with_loc
14
14
  from guppylang_internals.cfg.bb import BB
@@ -17,13 +17,28 @@ from guppylang_internals.checker.cfg_checker import CheckedCFG, check_cfg
17
17
  from guppylang_internals.checker.core import Context, Globals, Place, Variable
18
18
  from guppylang_internals.checker.errors.generic import UnsupportedError
19
19
  from guppylang_internals.definition.common import DefId
20
+ from guppylang_internals.definition.ty import TypeDef
20
21
  from guppylang_internals.diagnostic import Error, Help, Note
21
22
  from guppylang_internals.engine import DEF_STORE, ENGINE
22
23
  from guppylang_internals.error import GuppyError
23
24
  from guppylang_internals.experimental import check_capturing_closures_enabled
24
25
  from guppylang_internals.nodes import CheckedNestedFunctionDef, NestedFunctionDef
25
- from guppylang_internals.tys.parsing import parse_function_io_types
26
- from guppylang_internals.tys.ty import FunctionType, InputFlags, NoneType
26
+ from guppylang_internals.tys.parsing import (
27
+ TypeParsingCtx,
28
+ check_function_arg,
29
+ parse_function_arg_annotation,
30
+ type_from_ast,
31
+ type_with_flags_from_ast,
32
+ )
33
+ from guppylang_internals.tys.ty import (
34
+ ExistentialTypeVar,
35
+ FuncInput,
36
+ FunctionType,
37
+ InputFlags,
38
+ NoneType,
39
+ Type,
40
+ unify,
41
+ )
27
42
 
28
43
  if sys.version_info >= (3, 12):
29
44
  from guppylang_internals.tys.parsing import parse_parameter
@@ -53,6 +68,15 @@ class MissingArgAnnotationError(Error):
53
68
  span_label: ClassVar[str] = "Argument requires a type annotation"
54
69
 
55
70
 
71
@dataclass(frozen=True)
class RecursiveSelfError(Error):
    """Error for a `self` annotation that mentions `Self` inside a larger type.

    For example `self: Foo[Self]`, which would describe an infinite type.
    """

    title: ClassVar[str] = "Recursive self annotation"
    span_label: ClassVar[str] = (
        "Type of `{self_arg}` cannot recursively refer to `Self`"
    )
    # Name of the offending `self` argument (interpolated into `span_label`).
    self_arg: str
78
+
79
+
56
80
  @dataclass(frozen=True)
57
81
  class MissingReturnAnnotationError(Error):
58
82
  title: ClassVar[str] = "Missing type annotation"
@@ -67,6 +91,43 @@ class MissingReturnAnnotationError(Error):
67
91
  func: str
68
92
 
69
93
 
94
+ @dataclass(frozen=True)
95
+ class InvalidSelfError(Error):
96
+ title: ClassVar[str] = "Invalid self annotation"
97
+ span_label: ClassVar[str] = "`{self_arg}` must be of type `{self_ty}`"
98
+ self_arg: str
99
+ self_ty: Type
100
+
101
+
102
+ @dataclass(frozen=True)
103
+ class SelfParamsShadowedError(Error):
104
+ title: ClassVar[str] = "Shadowed generic parameters"
105
+ span_label: ClassVar[str] = (
106
+ "Cannot infer type for `{self_arg}` since parameter `{param}` of "
107
+ "`{ty_defn.name}` is shadowed"
108
+ )
109
+ param: str
110
+ ty_defn: "TypeDef"
111
+ self_arg: str
112
+
113
+ @dataclass(frozen=True)
114
+ class ExplicitHelp(Help):
115
+ span_label: ClassVar[str] = (
116
+ "Consider specifying the type explicitly: `{suggestion}`"
117
+ )
118
+
119
+ @property
120
+ def suggestion(self) -> str:
121
+ parent = self._parent
122
+ assert isinstance(parent, SelfParamsShadowedError)
123
+ params = (
124
+ f"[{', '.join(f'?{p.name}' for p in parent.ty_defn.params)}]"
125
+ if parent.ty_defn.params
126
+ else ""
127
+ )
128
+ return f'{parent.self_arg}: "{parent.ty_defn.name}{params}"'
129
+
130
+
70
131
  def check_global_func_def(
71
132
  func_def: ast.FunctionDef, ty: FunctionType, globals: Globals
72
133
  ) -> CheckedCFG[Place]:
@@ -176,9 +237,16 @@ def check_nested_func_def(
176
237
  return with_loc(func_def, checked_def)
177
238
 
178
239
 
179
- def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType:
240
+ def check_signature(
241
+ func_def: ast.FunctionDef, globals: Globals, def_id: DefId | None = None
242
+ ) -> FunctionType:
180
243
  """Checks the signature of a function definition and returns the corresponding
181
- Guppy type."""
244
+ Guppy type.
245
+
246
+ If this is a method, then the `DefId` of the associated parent type should also be
247
+ passed. This will be used to check or infer the type annotation for the `self`
248
+ argument.
249
+ """
182
250
  if len(func_def.args.posonlyargs) != 0:
183
251
  raise GuppyError(
184
252
  UnsupportedError(func_def.args.posonlyargs[0], "Positional-only parameters")
@@ -208,26 +276,32 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
208
276
  param_var_mapping: dict[str, Parameter] = {}
209
277
  if sys.version_info >= (3, 12):
210
278
  for i, param_node in enumerate(func_def.type_params):
211
- param = parse_parameter(param_node, i, globals)
279
+ param = parse_parameter(param_node, i, globals, param_var_mapping)
212
280
  param_var_mapping[param.name] = param
213
281
 
214
- input_nodes = []
282
+ # Figure out if this is a method
283
+ self_defn: TypeDef | None = None
284
+ if def_id is not None and def_id in DEF_STORE.impl_parents:
285
+ self_defn = cast(TypeDef, ENGINE.get_checked(DEF_STORE.impl_parents[def_id]))
286
+ assert isinstance(self_defn, TypeDef)
287
+
288
+ inputs = []
215
289
  input_names = []
216
- for inp in func_def.args.args:
217
- ty_ast = inp.annotation
218
- if ty_ast is None:
219
- raise GuppyError(MissingArgAnnotationError(inp))
220
- input_nodes.append(ty_ast)
290
+ ctx = TypeParsingCtx(globals, param_var_mapping, allow_free_vars=True)
291
+ for i, inp in enumerate(func_def.args.args):
292
+ # Special handling for `self` arguments. Note that `__new__` is excluded here
293
+ # since it's not a method so doesn't take `self`.
294
+ if self_defn and i == 0 and func_def.name != "__new__":
295
+ input = parse_self_arg(inp, self_defn, ctx)
296
+ ctx = replace(ctx, self_ty=input.ty)
297
+ else:
298
+ ty_ast = inp.annotation
299
+ if ty_ast is None:
300
+ raise GuppyError(MissingArgAnnotationError(inp))
301
+ input = parse_function_arg_annotation(ty_ast, inp.arg, ctx)
302
+ inputs.append(input)
221
303
  input_names.append(inp.arg)
222
- inputs, output = parse_function_io_types(
223
- input_nodes,
224
- func_def.returns,
225
- input_names,
226
- func_def,
227
- globals,
228
- param_var_mapping,
229
- True,
230
- )
304
+ output = type_from_ast(func_def.returns, ctx)
231
305
  return FunctionType(
232
306
  inputs,
233
307
  output,
@@ -236,6 +310,81 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
236
310
  )
237
311
 
238
312
 
313
def parse_self_arg(arg: ast.arg, self_defn: TypeDef, ctx: TypeParsingCtx) -> FuncInput:
    """Turn the `self` argument of a method into a function input.

    The annotation on `self` is optional; without one (or with a bare `Self`) the
    type is derived from the parent type. An explicit annotation may mention `Self`
    and must unify with the parent type instantiated with fresh existential
    variables.

    Raises:
        GuppyError: If the annotation refers to `Self` recursively, or does not
            match the parent type.
    """
    assert self_defn.params is not None
    annotation = arg.annotation
    if annotation is None:
        return handle_implicit_self_arg(arg, self_defn, ctx)

    # `Self` may occur in the annotation, so bind it to a placeholder variable that
    # has the same copyability/droppability as the parent type.
    fresh_args = [param.to_existential()[0] for param in self_defn.params]
    expected_ty = self_defn.check_instantiate(fresh_args)
    placeholder = ExistentialTypeVar.fresh(
        "Self", copyable=expected_ty.copyable, droppable=expected_ty.droppable
    )
    assert ctx.self_ty is None
    ctx = replace(ctx, self_ty=placeholder)
    annotated_ty, annotated_flags = type_with_flags_from_ast(annotation, ctx)

    # A plain `self: Self` is handled like an omitted annotation (keeping the flags).
    if annotated_ty == placeholder:
        return handle_implicit_self_arg(arg, self_defn, ctx, annotated_flags)

    # Something like `self: Foo[Self]` would be an infinite type.
    if placeholder in annotated_ty.unsolved_vars:
        raise GuppyError(RecursiveSelfError(annotation, arg.arg))

    # The annotation has to agree with the parent type for some choice of its params.
    if unify(annotated_ty, expected_ty, {}) is None:
        raise GuppyError(InvalidSelfError(annotation, arg.arg, expected_ty))

    return check_function_arg(annotated_ty, annotated_flags, arg, arg.arg, ctx)
352
+
353
+
354
def handle_implicit_self_arg(
    arg: ast.arg,
    self_defn: TypeDef,
    ctx: TypeParsingCtx,
    flags: InputFlags = InputFlags.NoFlags,
) -> FuncInput:
    """Build the input for a `self` argument whose type is not spelled out.

    `self` gets the parent type instantiated with the parent's own generic
    parameters. Those parameters are added in front of the method's parameters,
    which makes the method as generic as the parent type. Note that this updates
    `ctx.param_var_mapping` in place.

    Raises:
        GuppyError: If a method-level parameter shadows one of the parent's.
    """
    assert self_defn.params is not None
    parent_params = self_defn.params
    mapping = ctx.param_var_mapping

    # A Python 3.12 style parameter list on the method may reuse the name of one of
    # the parent's parameters, in which case `self` cannot be inferred.
    shadowed = [p for p in parent_params if p.name in mapping]
    if shadowed:
        culprit = shadowed[-1]
        err = SelfParamsShadowedError(arg, culprit.name, self_defn, arg.arg)
        err.add_sub_diagnostic(SelfParamsShadowedError.ExplicitHelp(arg))
        raise GuppyError(err)

    # Parent parameters come first, so move the method's own parameters back.
    offset = len(parent_params)
    mapping.update({name: p.with_idx(p.idx + offset) for name, p in mapping.items()})
    for p in parent_params:
        mapping[p.name] = p

    bound_args = [p.to_bound() for p in parent_params]
    self_ty = self_defn.check_instantiate(bound_args, loc=arg)
    return check_function_arg(self_ty, flags, arg, arg.arg, ctx)
386
+
387
+
239
388
  def parse_function_with_docstring(
240
389
  func_ast: ast.FunctionDef,
241
390
  ) -> tuple[ast.FunctionDef, str | None]:
@@ -52,6 +52,7 @@ from guppylang_internals.error import GuppyError, GuppyTypeError
52
52
  from guppylang_internals.nodes import (
53
53
  AnyCall,
54
54
  BarrierExpr,
55
+ CheckedModifiedBlock,
55
56
  CheckedNestedFunctionDef,
56
57
  DesugaredArrayComp,
57
58
  DesugaredGenerator,
@@ -621,6 +622,70 @@ class BBLinearityChecker(ast.NodeVisitor):
621
622
  elif not place.ty.copyable:
622
623
  raise GuppyTypeError(ComprAlreadyUsedError(use.node, place, use.kind))
623
624
 
625
def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None:
    """Check linear usage of variables around a modified `with` block.

    A block such as

    ```
    with control(c1, c2, ...):
        body(q1, q2, ...)  # q1, q2, ... are captured variables
    ```

    is treated like a call `WithCtrl(q1, q2, ..., c1, c2, ...)` to a function that
    takes the controls as mutable (borrowed) references, so the same linearity
    rules as for function call arguments apply.
    """
    # Controls must be places; they are borrowed, not consumed.
    for control in node.control:
        for ctrl_arg in control.ctrl:
            if not isinstance(ctrl_arg, PlaceNode):
                unnamed_err = UnnamedExprNotUsedError(ctrl_arg, get_type(ctrl_arg))
                unnamed_err.add_sub_diagnostic(UnnamedExprNotUsedError.Fix(None))
                raise GuppyTypeError(unnamed_err)
            self.visit_PlaceNode(ctrl_arg, use_kind=UseKind.BORROW, is_call_arg=None)

    # Power arguments: places are consumed, other expressions are visited normally.
    for power in node.power:
        if not isinstance(power.iter, PlaceNode):
            self.visit(power.iter)
        else:
            self.visit_PlaceNode(
                power.iter, use_kind=UseKind.CONSUME, is_call_arg=None
            )

    # Captured variables: `inout` ones are borrowed by the body, others consumed.
    for var, use in node.captured.values():
        use_kind = (
            UseKind.BORROW if InputFlags.Inout in var.flags else UseKind.CONSUME
        )
        for place in leaf_places(var):
            prev_use = self.scope.used(place.id)
            if prev_use and not place.ty.copyable:
                used_err = AlreadyUsedError(use, place, use_kind)
                used_err.add_sub_diagnostic(
                    AlreadyUsedError.PrevUse(prev_use.node, prev_use.kind)
                )
                if has_explicit_copy(place.ty):
                    used_err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None))
                raise GuppyError(used_err)
            self.scope.use(place.id, node, use_kind)

    # After the block, borrowed controls become available again.
    for control in node.control:
        for ctrl_arg in control.ctrl:
            assert isinstance(ctrl_arg, PlaceNode)  # Guaranteed by the check above
            self._reassign_single_inout_arg(
                ctrl_arg.place, ctrl_arg.place.defined_at or ctrl_arg
            )

    # ... and so do the captured variables that were passed `inout`.
    for var, use in node.captured.values():
        if InputFlags.Inout in var.flags:
            self._reassign_single_inout_arg(var, var.defined_at or use)
688
+
624
689
 
625
690
  def leaf_places(place: Place) -> Iterator[Place]:
626
691
  """Returns all leaf descendant projections of a place."""
@@ -0,0 +1,116 @@
1
+ """Type checking code for modifiers."""
2
+
3
+ import ast
4
+
5
+ from guppylang_internals.ast_util import loop_in_ast, with_loc
6
+ from guppylang_internals.cfg.bb import BB
7
+ from guppylang_internals.checker.cfg_checker import check_cfg
8
+ from guppylang_internals.checker.core import Context, Variable
9
+ from guppylang_internals.checker.errors.generic import InvalidUnderDagger
10
+ from guppylang_internals.definition.common import DefId
11
+ from guppylang_internals.error import GuppyError
12
+ from guppylang_internals.nodes import CheckedModifiedBlock, ModifiedBlock
13
+ from guppylang_internals.tys.ty import (
14
+ FuncInput,
15
+ FunctionType,
16
+ InputFlags,
17
+ NoneType,
18
+ Type,
19
+ )
20
+
21
+
22
def check_modified_block(
    modified_block: ModifiedBlock, bb: BB, ctx: Context
) -> CheckedModifiedBlock:
    """Type check the body of a modified `with` block.

    The body CFG is checked like the CFG of a function. Its inputs are the variables
    it captures from the enclosing scope, i.e. the locals that are live at the entry
    of the body. Non-copyable captures are marked `inout`. If the block is
    daggered, loops and assignments in the body are rejected.

    Args:
        modified_block: The block to check; its body CFG is `modified_block.cfg`.
        bb: The basic block of the enclosing CFG that contains `modified_block`.
        ctx: The type checking context at the start of the block.

    Returns:
        The checked block, carrying the source location of `modified_block`.

    Raises:
        GuppyError: If a daggered body contains a loop or an assignment.
    """
    body_cfg = modified_block.cfg

    # Run the variable analyses on the body, seeded with what is (maybe) assigned in
    # the enclosing CFG before `bb`.
    assigned = ctx.locals.keys()
    maybe_assigned = assigned | bb.containing_cfg.maybe_ass_before[bb]
    body_cfg.analyze(assigned, maybe_assigned, [])

    # Locals that are live at the entry of the body are captured, each stored
    # together with the use recorded for it inside the body.
    captured = {}
    for name, using_bb in body_cfg.live_before[body_cfg.entry_bb].items():
        if name in ctx.locals:
            var = _set_inout_if_non_copyable(ctx.locals[name])
            captured[name] = (var, using_bb.vars.used[name])

    # Under `dagger`, neither loops nor assignments are allowed in the body.
    if modified_block.is_dagger():

        def reject(offender: ast.AST, what: str) -> None:
            err = InvalidUnderDagger(offender, what)
            err.add_sub_diagnostic(
                InvalidUnderDagger.Dagger(modified_block.span_ctxt_manager())
            )
            raise GuppyError(err)

        for stmt in modified_block.body:
            loops = loop_in_ast(stmt)
            if loops:
                reject(next(iter(loops)), "Loop")

        for body_bb in body_cfg.bbs:
            if body_bb.vars.assigned:
                reject(next(iter(body_bb.vars.assigned.values())), "Assignment")

    # The remaining restrictions (e.g. calling a non-unitary function inside a
    # unitary modifier) are enforced by the unitary checks.

    # The body CFG takes the captured variables as inputs, linear ones first.
    inputs = non_copyable_front_others_back([var for var, _ in captured.values()])
    def_id = DefId.fresh()

    # TODO: "__modified__()" is an ad hoc name for the body. It may show up in error
    # messages, e.g. when the linearity checker fails inside the modifier body.
    checked_cfg = check_cfg(
        body_cfg, inputs, NoneType(), {}, "__modified__()", ctx.globals
    )
    func_ty = check_modified_block_signature(checked_cfg.input_tys)

    checked_block = CheckedModifiedBlock(
        def_id,
        checked_cfg,
        func_ty,
        captured,
        modified_block.dagger,
        modified_block.control,
        modified_block.power,
        **dict(ast.iter_fields(modified_block)),
    )
    return with_loc(modified_block, checked_block)
87
+
88
+
89
+ def _set_inout_if_non_copyable(var: Variable) -> Variable:
90
+ """Set the `inout` flag if the variable is non-copyable."""
91
+ if not var.ty.copyable:
92
+ return var.add_flags(InputFlags.Inout)
93
+ else:
94
+ return var
95
+
96
+
97
def check_modified_block_signature(input_tys: list[Type]) -> FunctionType:
    """Build the function type for the body of a modified `with` block.

    Every input is passed `inout` unless its type is copyable, and the body
    returns `None`.
    """
    func_inputs = []
    for ty in input_tys:
        flags = InputFlags.NoFlags if ty.copyable else InputFlags.Inout
        func_inputs.append(FuncInput(ty, flags))
    return FunctionType(func_inputs, NoneType())
109
+
110
+
111
def non_copyable_front_others_back(v: list[Variable]) -> list[Variable]:
    """Return the variables with the non-copyable (linear) ones moved to the front.

    The partition is stable: within each group the original order is kept.
    """
    front = []
    back = []
    for var in v:
        if var.ty.copyable:
            back.append(var)
        else:
            front.append(var)
    return front + back
@@ -42,7 +42,9 @@ from guppylang_internals.checker.errors.type_errors import (
42
42
  MissingReturnValueError,
43
43
  StarredTupleUnpackError,
44
44
  TypeInferenceError,
45
+ TypeMismatchError,
45
46
  UnpackableError,
47
+ WrongNumberOfArgsError,
46
48
  WrongNumberOfUnpacksError,
47
49
  )
48
50
  from guppylang_internals.checker.expr_checker import (
@@ -58,6 +60,7 @@ from guppylang_internals.nodes import (
58
60
  DesugaredArrayComp,
59
61
  IterableUnpack,
60
62
  MakeIter,
63
+ ModifiedBlock,
61
64
  NestedFunctionDef,
62
65
  PlaceNode,
63
66
  TupleUnpack,
@@ -73,13 +76,15 @@ from guppylang_internals.tys.builtin import (
73
76
  is_sized_iter_type,
74
77
  nat_type,
75
78
  )
76
- from guppylang_internals.tys.const import ConstValue
79
+ from guppylang_internals.tys.const import ConstValue, ExistentialConstVar
77
80
  from guppylang_internals.tys.parsing import type_from_ast
81
+ from guppylang_internals.tys.qubit import is_qubit_ty, qubit_ty
78
82
  from guppylang_internals.tys.subst import Subst
79
83
  from guppylang_internals.tys.ty import (
80
84
  ExistentialTypeVar,
81
85
  FunctionType,
82
86
  NoneType,
87
+ NumericType,
83
88
  StructType,
84
89
  TupleType,
85
90
  Type,
@@ -356,7 +361,7 @@ class StmtChecker(AstVisitor[BBStatement]):
356
361
  def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt:
357
362
  if node.value is None:
358
363
  raise GuppyError(UnsupportedError(node, "Variable declarations"))
359
- ty = type_from_ast(node.annotation, self.ctx.globals, self.ctx.generic_params)
364
+ ty = type_from_ast(node.annotation, self.ctx.parsing_ctx)
360
365
  node.value, subst = self._check_expr(node.value, ty)
361
366
  assert not ty.unsolved_vars # `ty` must be closed!
362
367
  assert len(subst) == 0
@@ -398,6 +403,48 @@ class StmtChecker(AstVisitor[BBStatement]):
398
403
  self.ctx.locals[func_def.name] = Variable(func_def.name, func_def.ty, func_def)
399
404
  return func_def
400
405
 
406
def visit_ModifiedBlock(self, node: ModifiedBlock) -> ast.stmt:
    """Check a `with` block that carries modifiers (dagger / control / power).

    The body is checked by `check_modified_block`. Afterwards, the control arguments
    are checked to be either a single array of qubits or one or more qubits, and
    the power arguments are checked against `nat`.

    Returns:
        The `CheckedModifiedBlock` that replaces `node`.

    Raises:
        InternalGuppyError: If no basic block is available.
        GuppyError / GuppyTypeError: On ill-typed control or power arguments.
    """
    # Local import, presumably to avoid an import cycle between the checkers.
    from guppylang_internals.checker.modifier_checker import check_modified_block

    if not self.bb:
        raise InternalGuppyError("BB required to check with block!")

    # Check the body of the modified block.
    modified_block = check_modified_block(node, self.bb, self.ctx)

    # Check the control arguments.
    for control in modified_block.control:
        ctrl = control.ctrl
        # Empty control lists are rejected during CFG construction.
        assert len(ctrl) > 0
        ctrl[0], ty = self._synth_expr(ctrl[0])

        if is_array_type(ty):
            # An array of qubits must be the only control argument.
            if len(ctrl) > 1:
                span = Span(to_span(control.func).end, to_span(control).end)
                raise GuppyError(WrongNumberOfArgsError(span, 1, len(control.args)))
            element_ty = get_element_type(ty)
            if not is_qubit_ty(element_ty):
                # Report the mismatch against `array[qubit, ?n]`.
                n = ExistentialConstVar.fresh("n", NumericType(NumericType.Kind.Nat))
                expected_ty = array_type(qubit_ty(), n)
                raise GuppyTypeError(TypeMismatchError(ctrl[0], expected_ty, ty))
            control.qubit_num = get_array_length(ty)
        else:
            # Otherwise every control argument has to be a single qubit.
            for i, ctrl_arg in enumerate(ctrl):
                ctrl[i], subst = self._check_expr(ctrl[i], qubit_ty())
                assert len(subst) == 0
            control.qubit_num = len(ctrl)

    # Check the power arguments. Iterate the checked block, the same as for the
    # controls above, since that is the node returned and inspected downstream.
    for power in modified_block.power:
        power.iter, subst = self._check_expr(
            power.iter, NumericType(NumericType.Kind.Nat)
        )
        assert len(subst) == 0

    return modified_block
447
+
401
448
  def visit_If(self, node: ast.If) -> None:
402
449
  raise InternalGuppyError("Control-flow statement should not be present here.")
403
450