guppylang-internals 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59):
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +21 -0
  3. guppylang_internals/cfg/bb.py +20 -0
  4. guppylang_internals/cfg/builder.py +118 -5
  5. guppylang_internals/cfg/cfg.py +3 -0
  6. guppylang_internals/checker/cfg_checker.py +6 -0
  7. guppylang_internals/checker/core.py +5 -2
  8. guppylang_internals/checker/errors/generic.py +32 -1
  9. guppylang_internals/checker/errors/type_errors.py +14 -0
  10. guppylang_internals/checker/errors/wasm.py +7 -4
  11. guppylang_internals/checker/expr_checker.py +58 -17
  12. guppylang_internals/checker/func_checker.py +18 -14
  13. guppylang_internals/checker/linearity_checker.py +67 -10
  14. guppylang_internals/checker/modifier_checker.py +120 -0
  15. guppylang_internals/checker/stmt_checker.py +48 -1
  16. guppylang_internals/checker/unitary_checker.py +132 -0
  17. guppylang_internals/compiler/cfg_compiler.py +7 -6
  18. guppylang_internals/compiler/core.py +93 -56
  19. guppylang_internals/compiler/expr_compiler.py +72 -168
  20. guppylang_internals/compiler/modifier_compiler.py +176 -0
  21. guppylang_internals/compiler/stmt_compiler.py +15 -8
  22. guppylang_internals/decorator.py +86 -7
  23. guppylang_internals/definition/custom.py +39 -1
  24. guppylang_internals/definition/declaration.py +9 -6
  25. guppylang_internals/definition/function.py +12 -2
  26. guppylang_internals/definition/parameter.py +8 -3
  27. guppylang_internals/definition/pytket_circuits.py +14 -41
  28. guppylang_internals/definition/struct.py +13 -7
  29. guppylang_internals/definition/ty.py +3 -3
  30. guppylang_internals/definition/wasm.py +42 -10
  31. guppylang_internals/engine.py +9 -3
  32. guppylang_internals/experimental.py +5 -0
  33. guppylang_internals/nodes.py +147 -24
  34. guppylang_internals/std/_internal/checker.py +13 -108
  35. guppylang_internals/std/_internal/compiler/array.py +95 -283
  36. guppylang_internals/std/_internal/compiler/list.py +1 -1
  37. guppylang_internals/std/_internal/compiler/platform.py +153 -0
  38. guppylang_internals/std/_internal/compiler/prelude.py +12 -4
  39. guppylang_internals/std/_internal/compiler/tket_exts.py +8 -2
  40. guppylang_internals/std/_internal/debug.py +18 -9
  41. guppylang_internals/std/_internal/util.py +1 -1
  42. guppylang_internals/tracing/object.py +10 -0
  43. guppylang_internals/tracing/unpacking.py +19 -20
  44. guppylang_internals/tys/arg.py +18 -3
  45. guppylang_internals/tys/builtin.py +2 -5
  46. guppylang_internals/tys/const.py +33 -4
  47. guppylang_internals/tys/errors.py +23 -1
  48. guppylang_internals/tys/param.py +31 -16
  49. guppylang_internals/tys/parsing.py +11 -24
  50. guppylang_internals/tys/printing.py +2 -8
  51. guppylang_internals/tys/qubit.py +62 -0
  52. guppylang_internals/tys/subst.py +8 -26
  53. guppylang_internals/tys/ty.py +91 -85
  54. guppylang_internals/wasm_util.py +129 -0
  55. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +6 -5
  56. guppylang_internals-0.26.0.dist-info/RECORD +104 -0
  57. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
  58. guppylang_internals-0.24.0.dist-info/RECORD +0 -98
  59. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
@@ -16,6 +16,7 @@ from guppylang_internals.cfg.builder import CFGBuilder
16
16
  from guppylang_internals.checker.cfg_checker import CheckedCFG, check_cfg
17
17
  from guppylang_internals.checker.core import Context, Globals, Place, Variable
18
18
  from guppylang_internals.checker.errors.generic import UnsupportedError
19
+ from guppylang_internals.checker.unitary_checker import check_invalid_under_dagger
19
20
  from guppylang_internals.definition.common import DefId
20
21
  from guppylang_internals.definition.ty import TypeDef
21
22
  from guppylang_internals.diagnostic import Error, Help, Note
@@ -37,6 +38,7 @@ from guppylang_internals.tys.ty import (
37
38
  InputFlags,
38
39
  NoneType,
39
40
  Type,
41
+ UnitaryFlags,
40
42
  unify,
41
43
  )
42
44
 
@@ -134,12 +136,13 @@ def check_global_func_def(
134
136
  """Type checks a top-level function definition."""
135
137
  args = func_def.args.args
136
138
  returns_none = isinstance(ty.output, NoneType)
137
- assert ty.input_names is not None
139
+ assert all(inp.name is not None for inp in ty.inputs)
138
140
 
139
- cfg = CFGBuilder().build(func_def.body, returns_none, globals)
141
+ check_invalid_under_dagger(func_def, ty.unitary_flags)
142
+ cfg = CFGBuilder().build(func_def.body, returns_none, globals, ty.unitary_flags)
140
143
  inputs = [
141
- Variable(x, inp.ty, loc, inp.flags, is_func_input=True)
142
- for x, inp, loc in zip(ty.input_names, ty.inputs, args, strict=True)
144
+ Variable(cast(str, inp.name), inp.ty, loc, inp.flags, is_func_input=True)
145
+ for inp, loc in zip(ty.inputs, args, strict=True)
143
146
  # Comptime inputs are turned into generic args, so are not included here
144
147
  if InputFlags.Comptime not in inp.flags
145
148
  ]
@@ -150,7 +153,9 @@ def check_global_func_def(
150
153
 
151
154
 
152
155
  def check_nested_func_def(
153
- func_def: NestedFunctionDef, bb: BB, ctx: Context
156
+ func_def: NestedFunctionDef,
157
+ bb: BB,
158
+ ctx: Context,
154
159
  ) -> CheckedNestedFunctionDef:
155
160
  """Type checks a local (nested) function definition."""
156
161
  func_ty = check_signature(func_def, ctx.globals)
@@ -194,10 +199,8 @@ def check_nested_func_def(
194
199
 
195
200
  # Construct inputs for checking the body CFG
196
201
  inputs = [v for v, _ in captured.values()] + [
197
- Variable(x, inp.ty, func_def.args.args[i], inp.flags, is_func_input=True)
198
- for i, (x, inp) in enumerate(
199
- zip(func_ty.input_names, func_ty.inputs, strict=True)
200
- )
202
+ Variable(cast(str, inp.name), inp.ty, arg, inp.flags, is_func_input=True)
203
+ for arg, inp in zip(func_def.args.args, func_ty.inputs, strict=True)
201
204
  # Comptime inputs are turned into generic args, so are not included here
202
205
  if InputFlags.Comptime not in inp.flags
203
206
  ]
@@ -238,7 +241,10 @@ def check_nested_func_def(
238
241
 
239
242
 
240
243
  def check_signature(
241
- func_def: ast.FunctionDef, globals: Globals, def_id: DefId | None = None
244
+ func_def: ast.FunctionDef,
245
+ globals: Globals,
246
+ def_id: DefId | None = None,
247
+ unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
242
248
  ) -> FunctionType:
243
249
  """Checks the signature of a function definition and returns the corresponding
244
250
  Guppy type.
@@ -276,7 +282,7 @@ def check_signature(
276
282
  param_var_mapping: dict[str, Parameter] = {}
277
283
  if sys.version_info >= (3, 12):
278
284
  for i, param_node in enumerate(func_def.type_params):
279
- param = parse_parameter(param_node, i, globals)
285
+ param = parse_parameter(param_node, i, globals, param_var_mapping)
280
286
  param_var_mapping[param.name] = param
281
287
 
282
288
  # Figure out if this is a method
@@ -286,7 +292,6 @@ def check_signature(
286
292
  assert isinstance(self_defn, TypeDef)
287
293
 
288
294
  inputs = []
289
- input_names = []
290
295
  ctx = TypeParsingCtx(globals, param_var_mapping, allow_free_vars=True)
291
296
  for i, inp in enumerate(func_def.args.args):
292
297
  # Special handling for `self` arguments. Note that `__new__` is excluded here
@@ -300,13 +305,12 @@ def check_signature(
300
305
  raise GuppyError(MissingArgAnnotationError(inp))
301
306
  input = parse_function_arg_annotation(ty_ast, inp.arg, ctx)
302
307
  inputs.append(input)
303
- input_names.append(inp.arg)
304
308
  output = type_from_ast(func_def.returns, ctx)
305
309
  return FunctionType(
306
310
  inputs,
307
311
  output,
308
- input_names,
309
312
  sorted(param_var_mapping.values(), key=lambda v: v.idx),
313
+ unitary_flags=unitary_flags,
310
314
  )
311
315
 
312
316
 
@@ -52,6 +52,7 @@ from guppylang_internals.error import GuppyError, GuppyTypeError
52
52
  from guppylang_internals.nodes import (
53
53
  AnyCall,
54
54
  BarrierExpr,
55
+ CheckedModifiedBlock,
55
56
  CheckedNestedFunctionDef,
56
57
  DesugaredArrayComp,
57
58
  DesugaredGenerator,
@@ -62,7 +63,6 @@ from guppylang_internals.nodes import (
62
63
  LocalCall,
63
64
  PartialApply,
64
65
  PlaceNode,
65
- ResultExpr,
66
66
  StateResultExpr,
67
67
  SubscriptAccessAndDrop,
68
68
  TensorCall,
@@ -73,7 +73,6 @@ from guppylang_internals.tys.ty import (
73
73
  FuncInput,
74
74
  FunctionType,
75
75
  InputFlags,
76
- NoneType,
77
76
  StructType,
78
77
  TupleType,
79
78
  Type,
@@ -450,13 +449,6 @@ class BBLinearityChecker(ast.NodeVisitor):
450
449
  self._visit_call_args(node.func_ty, node)
451
450
  self._reassign_inout_args(node.func_ty, node)
452
451
 
453
- def visit_ResultExpr(self, node: ResultExpr) -> None:
454
- ty = get_type(node.value)
455
- flag = InputFlags.Inout if not ty.copyable else InputFlags.NoFlags
456
- func_ty = FunctionType([FuncInput(ty, flag)], NoneType())
457
- self._visit_call_args(func_ty, node)
458
- self._reassign_inout_args(func_ty, node)
459
-
460
452
  def visit_StateResultExpr(self, node: StateResultExpr) -> None:
461
453
  self._visit_call_args(node.func_ty, node)
462
454
  self._reassign_inout_args(node.func_ty, node)
@@ -581,7 +573,7 @@ class BBLinearityChecker(ast.NodeVisitor):
581
573
  # can feed them through the loop. Note that we could also use non-local
582
574
  # edges, but we can't handle them in lower parts of the stack yet :/
583
575
  # TODO: Reinstate use of non-local edges.
584
- # See https://github.com/CQCL/guppylang/issues/963
576
+ # See https://github.com/quantinuum/guppylang/issues/963
585
577
  gen.used_outer_places = []
586
578
  for x, use in inner_scope.used_parent.items():
587
579
  place = inner_scope[x]
@@ -621,6 +613,70 @@ class BBLinearityChecker(ast.NodeVisitor):
621
613
  elif not place.ty.copyable:
622
614
  raise GuppyTypeError(ComprAlreadyUsedError(use.node, place, use.kind))
623
615
 
616
+ def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None:
617
+ # Linear usage of variables in a with statement
618
+ # ```
619
+ # with control(c1, c2, ...):
620
+ # body(q1, q2, ...) # captured variables
621
+ # ````
622
+ # is the same as to assume that this is a function call
623
+ # `WithCtrl(q1, q2, ..., c1, c2, ...)`
624
+ # where `WithCtrl` is a function that takes the control as mutable references.
625
+ # Therefore, we apply the same linearity rules as for function arguments.
626
+ # ```
627
+ # def WithCtrl(q1, q2, ..., c1, c2, ...):
628
+ # body(q1, q2, ...)
629
+ # ```
630
+
631
+ # check control
632
+ for ctrl in node.control:
633
+ for arg in ctrl.ctrl:
634
+ if isinstance(arg, PlaceNode):
635
+ self.visit_PlaceNode(arg, use_kind=UseKind.BORROW, is_call_arg=None)
636
+ else:
637
+ ty = get_type(arg)
638
+ unnamed_err = UnnamedExprNotUsedError(arg, ty)
639
+ unnamed_err.add_sub_diagnostic(UnnamedExprNotUsedError.Fix(None))
640
+ raise GuppyTypeError(unnamed_err)
641
+
642
+ # check power
643
+ for power in node.power:
644
+ if isinstance(power.iter, PlaceNode):
645
+ self.visit_PlaceNode(
646
+ power.iter, use_kind=UseKind.CONSUME, is_call_arg=None
647
+ )
648
+ else:
649
+ self.visit(power.iter)
650
+
651
+ # check captured variables
652
+ for var, use in node.captured.values():
653
+ for place in leaf_places(var):
654
+ use_kind = (
655
+ UseKind.BORROW if InputFlags.Inout in var.flags else UseKind.CONSUME
656
+ )
657
+
658
+ x = place.id
659
+ if (prev_use := self.scope.used(x)) and not place.ty.copyable:
660
+ used_err = AlreadyUsedError(use, place, use_kind)
661
+ used_err.add_sub_diagnostic(
662
+ AlreadyUsedError.PrevUse(prev_use.node, prev_use.kind)
663
+ )
664
+ if has_explicit_copy(place.ty):
665
+ used_err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None))
666
+ raise GuppyError(used_err)
667
+ self.scope.use(x, node, use_kind)
668
+
669
+ # reassign controls
670
+ for ctrl in node.control:
671
+ for arg in ctrl.ctrl:
672
+ assert isinstance(arg, PlaceNode) # Checked above
673
+ self._reassign_single_inout_arg(arg.place, arg.place.defined_at or arg)
674
+
675
+ # reassign captured variables
676
+ for var, use in node.captured.values():
677
+ if InputFlags.Inout in var.flags:
678
+ self._reassign_single_inout_arg(var, var.defined_at or use)
679
+
624
680
 
625
681
  def leaf_places(place: Place) -> Iterator[Place]:
626
682
  """Returns all leaf descendant projections of a place."""
@@ -815,6 +871,7 @@ def check_cfg_linearity(
815
871
  result_cfg.maybe_ass_before = {
816
872
  checked[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs
817
873
  }
874
+ result_cfg.unitary_flags = cfg.unitary_flags
818
875
  for bb in cfg.bbs:
819
876
  checked[bb].predecessors = [checked[pred] for pred in bb.predecessors]
820
877
  checked[bb].successors = [checked[succ] for succ in bb.successors]
@@ -0,0 +1,120 @@
1
+ """Type checking code for modifiers."""
2
+
3
+ import ast
4
+
5
+ from guppylang_internals.ast_util import loop_in_ast, with_loc
6
+ from guppylang_internals.cfg.bb import BB
7
+ from guppylang_internals.checker.cfg_checker import check_cfg
8
+ from guppylang_internals.checker.core import Context, Variable
9
+ from guppylang_internals.checker.errors.generic import InvalidUnderDagger
10
+ from guppylang_internals.definition.common import DefId
11
+ from guppylang_internals.error import GuppyError
12
+ from guppylang_internals.nodes import CheckedModifiedBlock, ModifiedBlock
13
+ from guppylang_internals.tys.ty import (
14
+ FuncInput,
15
+ FunctionType,
16
+ InputFlags,
17
+ NoneType,
18
+ Type,
19
+ )
20
+
21
+
22
+ def check_modified_block(
23
+ modified_block: ModifiedBlock, bb: BB, ctx: Context
24
+ ) -> CheckedModifiedBlock:
25
+ """Type checks a modifier definition."""
26
+ cfg = modified_block.cfg
27
+
28
+ # Find captured variables
29
+ parent_cfg = bb.containing_cfg
30
+ def_ass_before = ctx.locals.keys()
31
+ maybe_ass_before = def_ass_before | parent_cfg.maybe_ass_before[bb]
32
+
33
+ cfg.analyze(def_ass_before, maybe_ass_before, [])
34
+ captured = {
35
+ x: (_set_inout_if_non_copyable(ctx.locals[x]), using_bb.vars.used[x])
36
+ for x, using_bb in cfg.live_before[cfg.entry_bb].items()
37
+ if x in ctx.locals
38
+ }
39
+
40
+ # We do not allow any assignments if it is daggered.
41
+ if modified_block.is_dagger():
42
+ for stmt in modified_block.body:
43
+ loops = loop_in_ast(stmt)
44
+ if len(loops) != 0:
45
+ loop = next(iter(loops))
46
+ err = InvalidUnderDagger(loop, "Loop")
47
+ err.add_sub_diagnostic(
48
+ InvalidUnderDagger.Dagger(modified_block.span_ctxt_manager())
49
+ )
50
+ raise GuppyError(err)
51
+
52
+ for cfg_bb in cfg.bbs:
53
+ if cfg_bb.vars.assigned:
54
+ _, v = next(iter(cfg_bb.vars.assigned.items()))
55
+ err = InvalidUnderDagger(v, "Assignment")
56
+ err.add_sub_diagnostic(
57
+ InvalidUnderDagger.Dagger(modified_block.span_ctxt_manager())
58
+ )
59
+ raise GuppyError(err)
60
+
61
+ # The other checks are done in unitary checking.
62
+ # e.g. call to non-unitary function in a unitary modifier.
63
+
64
+ # Construct inputs for checking the body CFG
65
+ inputs = [v for v, _ in captured.values()]
66
+ inputs = non_copyable_front_others_back(inputs)
67
+ def_id = DefId.fresh()
68
+ globals = ctx.globals
69
+
70
+ # TODO: Ad hoc name for the new function
71
+ # This name could be printed in error messages, for example,
72
+ # when the linearity checker fails in the modifier body
73
+ checked_cfg = check_cfg(cfg, inputs, NoneType(), {}, "__modified__()", globals)
74
+ func_ty = check_modified_block_signature(modified_block, checked_cfg.input_tys)
75
+
76
+ checked_modifier = CheckedModifiedBlock(
77
+ def_id,
78
+ checked_cfg,
79
+ func_ty,
80
+ captured,
81
+ modified_block.dagger,
82
+ modified_block.control,
83
+ modified_block.power,
84
+ **dict(ast.iter_fields(modified_block)),
85
+ )
86
+ return with_loc(modified_block, checked_modifier)
87
+
88
+
89
+ def _set_inout_if_non_copyable(var: Variable) -> Variable:
90
+ """Set the `inout` flag if the variable is non-copyable."""
91
+ if not var.ty.copyable:
92
+ return var.add_flags(InputFlags.Inout)
93
+ else:
94
+ return var
95
+
96
+
97
+ def check_modified_block_signature(
98
+ modified_block: ModifiedBlock, input_tys: list[Type]
99
+ ) -> FunctionType:
100
+ """Check and create the signature of a function definition for a body
101
+ of a `With` block."""
102
+ unitary_flags = modified_block.flags()
103
+
104
+ func_ty = FunctionType(
105
+ [
106
+ FuncInput(t, InputFlags.Inout if not t.copyable else InputFlags.NoFlags)
107
+ for t in input_tys
108
+ ],
109
+ NoneType(),
110
+ unitary_flags=unitary_flags,
111
+ )
112
+ return func_ty
113
+
114
+
115
+ def non_copyable_front_others_back(v: list[Variable]) -> list[Variable]:
116
+ """Reorder variables so that linear ones come first, preserving the relative order
117
+ of linear and non-linear variables."""
118
+ linear_vars = [x for x in v if not x.ty.copyable]
119
+ non_linear_vars = [x for x in v if x.ty.copyable]
120
+ return linear_vars + non_linear_vars
@@ -42,7 +42,9 @@ from guppylang_internals.checker.errors.type_errors import (
42
42
  MissingReturnValueError,
43
43
  StarredTupleUnpackError,
44
44
  TypeInferenceError,
45
+ TypeMismatchError,
45
46
  UnpackableError,
47
+ WrongNumberOfArgsError,
46
48
  WrongNumberOfUnpacksError,
47
49
  )
48
50
  from guppylang_internals.checker.expr_checker import (
@@ -58,6 +60,7 @@ from guppylang_internals.nodes import (
58
60
  DesugaredArrayComp,
59
61
  IterableUnpack,
60
62
  MakeIter,
63
+ ModifiedBlock,
61
64
  NestedFunctionDef,
62
65
  PlaceNode,
63
66
  TupleUnpack,
@@ -73,13 +76,15 @@ from guppylang_internals.tys.builtin import (
73
76
  is_sized_iter_type,
74
77
  nat_type,
75
78
  )
76
- from guppylang_internals.tys.const import ConstValue
79
+ from guppylang_internals.tys.const import ConstValue, ExistentialConstVar
77
80
  from guppylang_internals.tys.parsing import type_from_ast
81
+ from guppylang_internals.tys.qubit import is_qubit_ty, qubit_ty
78
82
  from guppylang_internals.tys.subst import Subst
79
83
  from guppylang_internals.tys.ty import (
80
84
  ExistentialTypeVar,
81
85
  FunctionType,
82
86
  NoneType,
87
+ NumericType,
83
88
  StructType,
84
89
  TupleType,
85
90
  Type,
@@ -398,6 +403,48 @@ class StmtChecker(AstVisitor[BBStatement]):
398
403
  self.ctx.locals[func_def.name] = Variable(func_def.name, func_def.ty, func_def)
399
404
  return func_def
400
405
 
406
+ def visit_ModifiedBlock(self, node: ModifiedBlock) -> ast.stmt:
407
+ from guppylang_internals.checker.modifier_checker import check_modified_block
408
+
409
+ if not self.bb:
410
+ raise InternalGuppyError("BB required to check with block!")
411
+
412
+ # check the body of the modified block
413
+ modified_block = check_modified_block(node, self.bb, self.ctx)
414
+
415
+ # check the arguments of the control and power.
416
+ for control in modified_block.control:
417
+ ctrl = control.ctrl
418
+ # This case is handled during CFG construction.
419
+ assert len(ctrl) > 0
420
+ ctrl[0], ty = self._synth_expr(ctrl[0])
421
+
422
+ if is_array_type(ty):
423
+ if len(ctrl) > 1:
424
+ span = Span(to_span(control.func).end, to_span(control).end)
425
+ raise GuppyError(WrongNumberOfArgsError(span, 1, len(control.args)))
426
+ element_ty = get_element_type(ty)
427
+ if not is_qubit_ty(element_ty):
428
+ n = ExistentialConstVar.fresh(
429
+ "n", NumericType(NumericType.Kind.Nat)
430
+ )
431
+ dummy_array_ty = array_type(qubit_ty(), n)
432
+ raise GuppyTypeError(TypeMismatchError(ctrl[0], dummy_array_ty, ty))
433
+ control.qubit_num = get_array_length(ty)
434
+ else:
435
+ for i in range(len(ctrl)):
436
+ ctrl[i], subst = self._check_expr(ctrl[i], qubit_ty())
437
+ assert len(subst) == 0
438
+ control.qubit_num = len(ctrl)
439
+
440
+ for power in node.power:
441
+ power.iter, subst = self._check_expr(
442
+ power.iter, NumericType(NumericType.Kind.Nat)
443
+ )
444
+ assert len(subst) == 0
445
+
446
+ return modified_block
447
+
401
448
  def visit_If(self, node: ast.If) -> None:
402
449
  raise InternalGuppyError("Control-flow statement should not be present here.")
403
450
 
@@ -0,0 +1,132 @@
1
+ import ast
2
+
3
+ from guppylang_internals.ast_util import find_nodes, get_type, loop_in_ast
4
+ from guppylang_internals.checker.cfg_checker import CheckedBB, CheckedCFG
5
+ from guppylang_internals.checker.core import Place, contains_subscript
6
+ from guppylang_internals.checker.errors.generic import (
7
+ InvalidUnderDagger,
8
+ UnsupportedError,
9
+ )
10
+ from guppylang_internals.definition.value import CallableDef
11
+ from guppylang_internals.engine import ENGINE
12
+ from guppylang_internals.error import GuppyError, GuppyTypeError
13
+ from guppylang_internals.nodes import (
14
+ AnyCall,
15
+ BarrierExpr,
16
+ GlobalCall,
17
+ LocalCall,
18
+ PlaceNode,
19
+ StateResultExpr,
20
+ TensorCall,
21
+ )
22
+ from guppylang_internals.tys.errors import UnitaryCallError
23
+ from guppylang_internals.tys.qubit import contain_qubit_ty
24
+ from guppylang_internals.tys.ty import FunctionType, UnitaryFlags
25
+
26
+
27
+ def check_invalid_under_dagger(
28
+ fn_def: ast.FunctionDef, unitary_flags: UnitaryFlags
29
+ ) -> None:
30
+ """Check that there are no invalid constructs in a daggered CFG.
31
+ This checker checks the case the UnitaryFlags is given by
32
+ annotation (i.e., not inferred from `with dagger:`).
33
+ """
34
+ if UnitaryFlags.Dagger not in unitary_flags:
35
+ return
36
+
37
+ for stmt in fn_def.body:
38
+ loops = loop_in_ast(stmt)
39
+ if len(loops) != 0:
40
+ loop = next(iter(loops))
41
+ err = InvalidUnderDagger(loop, "Loop")
42
+ raise GuppyError(err)
43
+ # Note: sub-diagnostic for dagger context is not available here
44
+
45
+ found = find_nodes(
46
+ lambda n: isinstance(n, ast.Assign | ast.AnnAssign | ast.AugAssign),
47
+ stmt,
48
+ {ast.FunctionDef},
49
+ )
50
+ if len(found) != 0:
51
+ assign = next(iter(found))
52
+ err = InvalidUnderDagger(assign, "Assignment")
53
+ raise GuppyError(err)
54
+
55
+
56
+ class BBUnitaryChecker(ast.NodeVisitor):
57
+ """AST visitor that checks whether the modifiers (dagger, control, power)
58
+ are applicable."""
59
+
60
+ flags: UnitaryFlags
61
+
62
+ def check(self, bb: CheckedBB[Place], unitary_flags: UnitaryFlags) -> None:
63
+ self.flags = unitary_flags
64
+ for stmt in bb.statements:
65
+ self.visit(stmt)
66
+
67
+ def _check_classical_args(self, args: list[ast.expr]) -> bool:
68
+ for arg in args:
69
+ self.visit(arg)
70
+ if contain_qubit_ty(get_type(arg)):
71
+ return False
72
+ return True
73
+
74
+ def _check_call(self, node: AnyCall, ty: FunctionType) -> None:
75
+ classic = self._check_classical_args(node.args)
76
+ flag_ok = self.flags in ty.unitary_flags
77
+ if not classic and not flag_ok:
78
+ raise GuppyTypeError(
79
+ UnitaryCallError(node, self.flags & (~ty.unitary_flags))
80
+ )
81
+
82
+ def visit_GlobalCall(self, node: GlobalCall) -> None:
83
+ func = ENGINE.get_parsed(node.def_id)
84
+ assert isinstance(func, CallableDef)
85
+ self._check_call(node, func.ty)
86
+
87
+ def visit_LocalCall(self, node: LocalCall) -> None:
88
+ func = get_type(node.func)
89
+ assert isinstance(func, FunctionType)
90
+ self._check_call(node, func)
91
+
92
+ def visit_TensorCall(self, node: TensorCall) -> None:
93
+ self._check_call(node, node.tensor_ty)
94
+
95
+ def visit_BarrierExpr(self, node: BarrierExpr) -> None:
96
+ # Barrier is always allowed
97
+ pass
98
+
99
+ def visit_StateResultExpr(self, node: StateResultExpr) -> None:
100
+ # StateResult is always allowed
101
+ pass
102
+
103
+ def _check_assign(self, node: ast.Assign | ast.AnnAssign | ast.AugAssign) -> None:
104
+ if UnitaryFlags.Dagger in self.flags:
105
+ raise GuppyError(InvalidUnderDagger(node, "Assignment"))
106
+ if node.value is not None:
107
+ self.visit(node.value)
108
+
109
+ def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
110
+ self._check_assign(node)
111
+
112
+ def visit_Assign(self, node: ast.Assign) -> None:
113
+ self._check_assign(node)
114
+
115
+ def visit_AugAssign(self, node: ast.AugAssign) -> None:
116
+ self._check_assign(node)
117
+
118
+ def visit_PlaceNode(self, node: PlaceNode) -> None:
119
+ if UnitaryFlags.Dagger in self.flags and contains_subscript(node.place):
120
+ raise GuppyError(
121
+ UnsupportedError(node, "index access", True, "dagger context")
122
+ )
123
+
124
+
125
+ def check_cfg_unitary(
126
+ cfg: CheckedCFG[Place],
127
+ unitary_flags: UnitaryFlags,
128
+ ) -> None:
129
+ """Checks that the given unitary flags are valid for a CFG."""
130
+ bb_checker = BBUnitaryChecker()
131
+ for bb in cfg.bbs:
132
+ bb_checker.check(bb, unitary_flags)
@@ -24,7 +24,7 @@ from guppylang_internals.compiler.core import (
24
24
  from guppylang_internals.compiler.expr_compiler import ExprCompiler
25
25
  from guppylang_internals.compiler.stmt_compiler import StmtCompiler
26
26
  from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, read_bool
27
- from guppylang_internals.tys.ty import SumType, row_to_type, type_to_row
27
+ from guppylang_internals.tys.ty import type_to_row
28
28
 
29
29
 
30
30
  def compile_cfg(
@@ -38,7 +38,7 @@ def compile_cfg(
38
38
  # TODO: This mutates the CFG in-place which leads to problems when trying to lower
39
39
  # the same function to Hugr twice. For now we just check that the return vars
40
40
  # haven't already been inserted, but we should figure out a better way to handle
41
- # this: https://github.com/CQCL/guppylang/issues/428
41
+ # this: https://github.com/quantinuum/guppylang/issues/428
42
42
  if all(
43
43
  not is_return_var(v.name)
44
44
  for v in cfg.exit_bb.sig.input_row
@@ -52,7 +52,7 @@ def compile_cfg(
52
52
  # unreachable
53
53
  out_tys = [place.ty.to_hugr(ctx) for place in cfg.exit_bb.sig.input_row]
54
54
  # TODO: Use proper API for this once it's added in hugr-py:
55
- # https://github.com/CQCL/hugr/issues/1816
55
+ # https://github.com/quantinuum/hugr/issues/1816
56
56
  builder._exit_op._cfg_outputs = out_tys
57
57
  builder.parent_op._outputs = out_tys
58
58
  builder.parent_node = builder.hugr._update_node_outs(
@@ -194,13 +194,14 @@ def choose_vars_for_tuple_sum(
194
194
  constructs a TupleSum value of type `Sum(#s1, #s2, ...)`.
195
195
  """
196
196
  assert all(v.ty.droppable for var_row in output_vars for v in var_row)
197
- tys = [[v.ty for v in var_row] for var_row in output_vars]
198
- sum_type = SumType([row_to_type(row) for row in tys]).to_hugr(dfg.ctx)
197
+ sum_type = ht.Sum(
198
+ [[v.ty.to_hugr(dfg.ctx) for v in var_row] for var_row in output_vars]
199
+ )
199
200
 
200
201
  # We pass all values into the conditional instead of relying on non-local edges.
201
202
  # This is because we can't handle them in lower parts of the stack yet :/
202
203
  # TODO: Reinstate use of non-local edges.
203
- # See https://github.com/CQCL/guppylang/issues/963
204
+ # See https://github.com/quantinuum/guppylang/issues/963
204
205
  all_vars = {v.id: dfg[v] for var_row in output_vars for v in var_row}
205
206
  all_vars_wires = list(all_vars.values())
206
207
  all_vars_idxs = {x: i for i, x in enumerate(all_vars.keys())}