guppylang-internals 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/cfg/builder.py +17 -2
  3. guppylang_internals/cfg/cfg.py +3 -0
  4. guppylang_internals/checker/cfg_checker.py +6 -0
  5. guppylang_internals/checker/core.py +1 -2
  6. guppylang_internals/checker/errors/wasm.py +7 -4
  7. guppylang_internals/checker/expr_checker.py +13 -8
  8. guppylang_internals/checker/func_checker.py +17 -13
  9. guppylang_internals/checker/linearity_checker.py +2 -10
  10. guppylang_internals/checker/modifier_checker.py +6 -2
  11. guppylang_internals/checker/unitary_checker.py +132 -0
  12. guppylang_internals/compiler/cfg_compiler.py +7 -6
  13. guppylang_internals/compiler/core.py +5 -5
  14. guppylang_internals/compiler/expr_compiler.py +42 -73
  15. guppylang_internals/compiler/modifier_compiler.py +2 -0
  16. guppylang_internals/decorator.py +86 -7
  17. guppylang_internals/definition/custom.py +4 -0
  18. guppylang_internals/definition/declaration.py +6 -2
  19. guppylang_internals/definition/function.py +12 -2
  20. guppylang_internals/definition/pytket_circuits.py +1 -0
  21. guppylang_internals/definition/struct.py +6 -3
  22. guppylang_internals/definition/wasm.py +42 -10
  23. guppylang_internals/engine.py +9 -3
  24. guppylang_internals/nodes.py +23 -24
  25. guppylang_internals/std/_internal/checker.py +13 -108
  26. guppylang_internals/std/_internal/compiler/array.py +1 -1
  27. guppylang_internals/std/_internal/compiler/list.py +1 -1
  28. guppylang_internals/std/_internal/compiler/platform.py +153 -0
  29. guppylang_internals/std/_internal/compiler/prelude.py +12 -4
  30. guppylang_internals/std/_internal/compiler/tket_exts.py +3 -4
  31. guppylang_internals/std/_internal/debug.py +18 -9
  32. guppylang_internals/std/_internal/util.py +1 -1
  33. guppylang_internals/tracing/object.py +10 -0
  34. guppylang_internals/tys/errors.py +23 -1
  35. guppylang_internals/tys/parsing.py +3 -3
  36. guppylang_internals/tys/printing.py +2 -8
  37. guppylang_internals/tys/qubit.py +37 -2
  38. guppylang_internals/tys/ty.py +60 -64
  39. guppylang_internals/wasm_util.py +129 -0
  40. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +4 -3
  41. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/RECORD +43 -40
  42. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
  43. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
@@ -1,3 +1,3 @@
1
1
  # This is updated by our release-please workflow, triggered by this
2
2
  # annotation: x-release-please-version
3
- __version__ = "0.25.0"
3
+ __version__ = "0.26.0"
@@ -46,7 +46,7 @@ from guppylang_internals.nodes import (
46
46
  Power,
47
47
  )
48
48
  from guppylang_internals.span import Span, to_span
49
- from guppylang_internals.tys.ty import NoneType
49
+ from guppylang_internals.tys.ty import NoneType, UnitaryFlags
50
50
 
51
51
  # In order to build expressions, need an endless stream of unique temporary variables
52
52
  # to store intermediate results
@@ -78,7 +78,13 @@ class CFGBuilder(AstVisitor[BB | None]):
78
78
  cfg: CFG
79
79
  globals: Globals
80
80
 
81
- def build(self, nodes: list[ast.stmt], returns_none: bool, globals: Globals) -> CFG:
81
+ def build(
82
+ self,
83
+ nodes: list[ast.stmt],
84
+ returns_none: bool,
85
+ globals: Globals,
86
+ unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
87
+ ) -> CFG:
82
88
  """Builds a CFG from a list of ast nodes.
83
89
 
84
90
  We also require the expected number of return ports for the whole CFG. This is
@@ -86,6 +92,7 @@ class CFGBuilder(AstVisitor[BB | None]):
86
92
  variables.
87
93
  """
88
94
  self.cfg = CFG()
95
+ self.cfg.unitary_flags = unitary_flags
89
96
  self.globals = globals
90
97
 
91
98
  final_bb = self.visit_stmts(
@@ -273,6 +280,7 @@ class CFGBuilder(AstVisitor[BB | None]):
273
280
 
274
281
  func_ty = check_signature(node, self.globals)
275
282
  returns_none = isinstance(func_ty.output, NoneType)
283
+ # No UnitaryFlags are assigned to nested functions
276
284
  cfg = CFGBuilder().build(node.body, returns_none, self.globals)
277
285
 
278
286
  new_node = NestedFunctionDef(
@@ -300,6 +308,13 @@ class CFGBuilder(AstVisitor[BB | None]):
300
308
  modifier = self._handle_withitem(item)
301
309
  new_node.push_modifier(modifier)
302
310
 
311
+ # FIXME: Currently, the unitary flags is not set correctly if there are nested
312
+ # `with` blocks. This is because the outer block's unitary flags are not
313
+ # propagated to the outer block. The following line should calculate the sum
314
+ # of the unitary flags of the outer block and modifiers applied in this
315
+ # `with` block.
316
+ cfg.unitary_flags = new_node.flags()
317
+
303
318
  set_location_from(new_node, node)
304
319
  bb.statements.append(new_node)
305
320
  return bb
@@ -12,6 +12,7 @@ from guppylang_internals.cfg.analysis import (
12
12
  )
13
13
  from guppylang_internals.cfg.bb import BB, BBStatement, VariableStats
14
14
  from guppylang_internals.nodes import InoutReturnSentinel
15
+ from guppylang_internals.tys.ty import UnitaryFlags
15
16
 
16
17
  T = TypeVar("T", bound=BB)
17
18
 
@@ -29,6 +30,7 @@ class BaseCFG(Generic[T]):
29
30
 
30
31
  #: Set of variables defined in this CFG
31
32
  assigned_somewhere: set[str]
33
+ unitary_flags: UnitaryFlags
32
34
 
33
35
  def __init__(
34
36
  self, bbs: list[T], entry_bb: T | None = None, exit_bb: T | None = None
@@ -42,6 +44,7 @@ class BaseCFG(Generic[T]):
42
44
  self.ass_before = {}
43
45
  self.maybe_ass_before = {}
44
46
  self.assigned_somewhere = set()
47
+ self.unitary_flags = UnitaryFlags.NoFlags
45
48
 
46
49
  def ancestors(self, *bbs: T) -> Iterator[T]:
47
50
  """Returns an iterator over all ancestors of the given BBs in BFS order."""
@@ -149,11 +149,17 @@ def check_cfg(
149
149
  checked_cfg.maybe_ass_before = {
150
150
  compiled[bb]: cfg.maybe_ass_before[bb] for bb in required_bbs
151
151
  }
152
+ checked_cfg.unitary_flags = cfg.unitary_flags
152
153
 
153
154
  # Finally, run the linearity check
154
155
  from guppylang_internals.checker.linearity_checker import check_cfg_linearity
155
156
 
156
157
  linearity_checked_cfg = check_cfg_linearity(checked_cfg, func_name, globals)
158
+
159
+ from guppylang_internals.checker.unitary_checker import check_cfg_unitary
160
+
161
+ check_cfg_unitary(linearity_checked_cfg, cfg.unitary_flags)
162
+
157
163
  return linearity_checked_cfg
158
164
 
159
165
 
@@ -47,7 +47,6 @@ from guppylang_internals.tys.ty import (
47
47
  NumericType,
48
48
  OpaqueType,
49
49
  StructType,
50
- SumType,
51
50
  TupleType,
52
51
  Type,
53
52
  )
@@ -360,7 +359,7 @@ class Globals:
360
359
  match ty:
361
360
  case TypeDef() as type_defn:
362
361
  pass
363
- case BoundTypeVar() | ExistentialTypeVar() | SumType():
362
+ case BoundTypeVar() | ExistentialTypeVar():
364
363
  return None
365
364
  case NumericType(kind):
366
365
  match kind:
@@ -1,7 +1,7 @@
1
1
  from dataclasses import dataclass
2
2
  from typing import ClassVar
3
3
 
4
- from guppylang_internals.diagnostic import Error
4
+ from guppylang_internals.diagnostic import Error, Note
5
5
  from guppylang_internals.tys.ty import Type
6
6
 
7
7
 
@@ -13,10 +13,13 @@ class WasmError(Error):
13
13
  @dataclass(frozen=True)
14
14
  class FirstArgNotModule(WasmError):
15
15
  span_label: ClassVar[str] = (
16
- "First argument to WASM function should be a reference to a WASM module."
17
- " Found `{ty}` instead"
16
+ "First argument to WASM function should be a WASM module."
18
17
  )
19
- ty: Type
18
+
19
+ @dataclass(frozen=True)
20
+ class GotOtherType(Note):
21
+ span_label: ClassVar[str] = "Found `{ty}` instead."
22
+ ty: Type
20
23
 
21
24
 
22
25
  @dataclass(frozen=True)
@@ -43,6 +43,7 @@ from guppylang_internals.ast_util import (
43
43
  )
44
44
  from guppylang_internals.cfg.builder import is_tmp_var, tmp_vars
45
45
  from guppylang_internals.checker.core import (
46
+ ComptimeVariable,
46
47
  Context,
47
48
  DummyEvalDict,
48
49
  FieldAccess,
@@ -468,7 +469,7 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
468
469
  # The only exception are attributes accesses that are generated during
469
470
  # desugaring (for example for iterators in `for` loops). Since those just
470
471
  # inherit the span of the sugared code, we could have line breaks there.
471
- # See https://github.com/CQCL/guppylang/issues/1301
472
+ # See https://github.com/quantinuum/guppylang/issues/1301
472
473
  span = to_span(node)
473
474
  if span.start.line == span.end.line:
474
475
  attr_span = Span(span.end.shift_left(len(node.attr)), span.end)
@@ -503,12 +504,7 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
503
504
  )
504
505
  # Make a closure by partially applying the `self` argument
505
506
  # TODO: Try to infer some type args based on `self`
506
- result_ty = FunctionType(
507
- func.ty.inputs[1:],
508
- func.ty.output,
509
- func.ty.input_names[1:] if func.ty.input_names else None,
510
- func.ty.params,
511
- )
507
+ result_ty = FunctionType(func.ty.inputs[1:], func.ty.output, func.ty.params)
512
508
  return with_loc(node, PartialApply(func=name, args=[node.value])), result_ty
513
509
  raise GuppyTypeError(AttributeNotFoundError(attr_span, ty, node.attr))
514
510
 
@@ -1035,7 +1031,14 @@ def check_place_assignable(
1035
1031
  exp_sig = FunctionType(
1036
1032
  [
1037
1033
  FuncInput(parent.ty, InputFlags.Inout),
1038
- FuncInput(item.ty, InputFlags.NoFlags),
1034
+ FuncInput(
1035
+ # Due to potential coercions that were applied during the
1036
+ # `__getitem__` call (e.g. coercing a nat index to int), we're
1037
+ # not allowed to rely on `item.ty` here.
1038
+ # See https://github.com/CQCL/guppylang/issues/1356
1039
+ ExistentialTypeVar.fresh("T", True, True),
1040
+ InputFlags.NoFlags,
1041
+ ),
1039
1042
  FuncInput(ty, InputFlags.Owned),
1040
1043
  ],
1041
1044
  NoneType(),
@@ -1072,6 +1075,8 @@ def check_comptime_arg(
1072
1075
  match arg:
1073
1076
  case ast.Constant(value=v):
1074
1077
  const = ConstValue(ty, v)
1078
+ case PlaceNode(place=ComptimeVariable(ty=ty, static_value=v)):
1079
+ const = ConstValue(ty, v)
1075
1080
  case GenericParamValue(param=const_param):
1076
1081
  const = const_param.to_bound().const
1077
1082
  case arg:
@@ -16,6 +16,7 @@ from guppylang_internals.cfg.builder import CFGBuilder
16
16
  from guppylang_internals.checker.cfg_checker import CheckedCFG, check_cfg
17
17
  from guppylang_internals.checker.core import Context, Globals, Place, Variable
18
18
  from guppylang_internals.checker.errors.generic import UnsupportedError
19
+ from guppylang_internals.checker.unitary_checker import check_invalid_under_dagger
19
20
  from guppylang_internals.definition.common import DefId
20
21
  from guppylang_internals.definition.ty import TypeDef
21
22
  from guppylang_internals.diagnostic import Error, Help, Note
@@ -37,6 +38,7 @@ from guppylang_internals.tys.ty import (
37
38
  InputFlags,
38
39
  NoneType,
39
40
  Type,
41
+ UnitaryFlags,
40
42
  unify,
41
43
  )
42
44
 
@@ -134,12 +136,13 @@ def check_global_func_def(
134
136
  """Type checks a top-level function definition."""
135
137
  args = func_def.args.args
136
138
  returns_none = isinstance(ty.output, NoneType)
137
- assert ty.input_names is not None
139
+ assert all(inp.name is not None for inp in ty.inputs)
138
140
 
139
- cfg = CFGBuilder().build(func_def.body, returns_none, globals)
141
+ check_invalid_under_dagger(func_def, ty.unitary_flags)
142
+ cfg = CFGBuilder().build(func_def.body, returns_none, globals, ty.unitary_flags)
140
143
  inputs = [
141
- Variable(x, inp.ty, loc, inp.flags, is_func_input=True)
142
- for x, inp, loc in zip(ty.input_names, ty.inputs, args, strict=True)
144
+ Variable(cast(str, inp.name), inp.ty, loc, inp.flags, is_func_input=True)
145
+ for inp, loc in zip(ty.inputs, args, strict=True)
143
146
  # Comptime inputs are turned into generic args, so are not included here
144
147
  if InputFlags.Comptime not in inp.flags
145
148
  ]
@@ -150,7 +153,9 @@ def check_global_func_def(
150
153
 
151
154
 
152
155
  def check_nested_func_def(
153
- func_def: NestedFunctionDef, bb: BB, ctx: Context
156
+ func_def: NestedFunctionDef,
157
+ bb: BB,
158
+ ctx: Context,
154
159
  ) -> CheckedNestedFunctionDef:
155
160
  """Type checks a local (nested) function definition."""
156
161
  func_ty = check_signature(func_def, ctx.globals)
@@ -194,10 +199,8 @@ def check_nested_func_def(
194
199
 
195
200
  # Construct inputs for checking the body CFG
196
201
  inputs = [v for v, _ in captured.values()] + [
197
- Variable(x, inp.ty, func_def.args.args[i], inp.flags, is_func_input=True)
198
- for i, (x, inp) in enumerate(
199
- zip(func_ty.input_names, func_ty.inputs, strict=True)
200
- )
202
+ Variable(cast(str, inp.name), inp.ty, arg, inp.flags, is_func_input=True)
203
+ for arg, inp in zip(func_def.args.args, func_ty.inputs, strict=True)
201
204
  # Comptime inputs are turned into generic args, so are not included here
202
205
  if InputFlags.Comptime not in inp.flags
203
206
  ]
@@ -238,7 +241,10 @@ def check_nested_func_def(
238
241
 
239
242
 
240
243
  def check_signature(
241
- func_def: ast.FunctionDef, globals: Globals, def_id: DefId | None = None
244
+ func_def: ast.FunctionDef,
245
+ globals: Globals,
246
+ def_id: DefId | None = None,
247
+ unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
242
248
  ) -> FunctionType:
243
249
  """Checks the signature of a function definition and returns the corresponding
244
250
  Guppy type.
@@ -286,7 +292,6 @@ def check_signature(
286
292
  assert isinstance(self_defn, TypeDef)
287
293
 
288
294
  inputs = []
289
- input_names = []
290
295
  ctx = TypeParsingCtx(globals, param_var_mapping, allow_free_vars=True)
291
296
  for i, inp in enumerate(func_def.args.args):
292
297
  # Special handling for `self` arguments. Note that `__new__` is excluded here
@@ -300,13 +305,12 @@ def check_signature(
300
305
  raise GuppyError(MissingArgAnnotationError(inp))
301
306
  input = parse_function_arg_annotation(ty_ast, inp.arg, ctx)
302
307
  inputs.append(input)
303
- input_names.append(inp.arg)
304
308
  output = type_from_ast(func_def.returns, ctx)
305
309
  return FunctionType(
306
310
  inputs,
307
311
  output,
308
- input_names,
309
312
  sorted(param_var_mapping.values(), key=lambda v: v.idx),
313
+ unitary_flags=unitary_flags,
310
314
  )
311
315
 
312
316
 
@@ -63,7 +63,6 @@ from guppylang_internals.nodes import (
63
63
  LocalCall,
64
64
  PartialApply,
65
65
  PlaceNode,
66
- ResultExpr,
67
66
  StateResultExpr,
68
67
  SubscriptAccessAndDrop,
69
68
  TensorCall,
@@ -74,7 +73,6 @@ from guppylang_internals.tys.ty import (
74
73
  FuncInput,
75
74
  FunctionType,
76
75
  InputFlags,
77
- NoneType,
78
76
  StructType,
79
77
  TupleType,
80
78
  Type,
@@ -451,13 +449,6 @@ class BBLinearityChecker(ast.NodeVisitor):
451
449
  self._visit_call_args(node.func_ty, node)
452
450
  self._reassign_inout_args(node.func_ty, node)
453
451
 
454
- def visit_ResultExpr(self, node: ResultExpr) -> None:
455
- ty = get_type(node.value)
456
- flag = InputFlags.Inout if not ty.copyable else InputFlags.NoFlags
457
- func_ty = FunctionType([FuncInput(ty, flag)], NoneType())
458
- self._visit_call_args(func_ty, node)
459
- self._reassign_inout_args(func_ty, node)
460
-
461
452
  def visit_StateResultExpr(self, node: StateResultExpr) -> None:
462
453
  self._visit_call_args(node.func_ty, node)
463
454
  self._reassign_inout_args(node.func_ty, node)
@@ -582,7 +573,7 @@ class BBLinearityChecker(ast.NodeVisitor):
582
573
  # can feed them through the loop. Note that we could also use non-local
583
574
  # edges, but we can't handle them in lower parts of the stack yet :/
584
575
  # TODO: Reinstate use of non-local edges.
585
- # See https://github.com/CQCL/guppylang/issues/963
576
+ # See https://github.com/quantinuum/guppylang/issues/963
586
577
  gen.used_outer_places = []
587
578
  for x, use in inner_scope.used_parent.items():
588
579
  place = inner_scope[x]
@@ -880,6 +871,7 @@ def check_cfg_linearity(
880
871
  result_cfg.maybe_ass_before = {
881
872
  checked[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs
882
873
  }
874
+ result_cfg.unitary_flags = cfg.unitary_flags
883
875
  for bb in cfg.bbs:
884
876
  checked[bb].predecessors = [checked[pred] for pred in bb.predecessors]
885
877
  checked[bb].successors = [checked[succ] for succ in bb.successors]
@@ -71,7 +71,7 @@ def check_modified_block(
71
71
  # This name could be printed in error messages, for example,
72
72
  # when the linearity checker fails in the modifier body
73
73
  checked_cfg = check_cfg(cfg, inputs, NoneType(), {}, "__modified__()", globals)
74
- func_ty = check_modified_block_signature(checked_cfg.input_tys)
74
+ func_ty = check_modified_block_signature(modified_block, checked_cfg.input_tys)
75
75
 
76
76
  checked_modifier = CheckedModifiedBlock(
77
77
  def_id,
@@ -94,9 +94,12 @@ def _set_inout_if_non_copyable(var: Variable) -> Variable:
94
94
  return var
95
95
 
96
96
 
97
- def check_modified_block_signature(input_tys: list[Type]) -> FunctionType:
97
+ def check_modified_block_signature(
98
+ modified_block: ModifiedBlock, input_tys: list[Type]
99
+ ) -> FunctionType:
98
100
  """Check and create the signature of a function definition for a body
99
101
  of a `With` block."""
102
+ unitary_flags = modified_block.flags()
100
103
 
101
104
  func_ty = FunctionType(
102
105
  [
@@ -104,6 +107,7 @@ def check_modified_block_signature(input_tys: list[Type]) -> FunctionType:
104
107
  for t in input_tys
105
108
  ],
106
109
  NoneType(),
110
+ unitary_flags=unitary_flags,
107
111
  )
108
112
  return func_ty
109
113
 
@@ -0,0 +1,132 @@
1
+ import ast
2
+
3
+ from guppylang_internals.ast_util import find_nodes, get_type, loop_in_ast
4
+ from guppylang_internals.checker.cfg_checker import CheckedBB, CheckedCFG
5
+ from guppylang_internals.checker.core import Place, contains_subscript
6
+ from guppylang_internals.checker.errors.generic import (
7
+ InvalidUnderDagger,
8
+ UnsupportedError,
9
+ )
10
+ from guppylang_internals.definition.value import CallableDef
11
+ from guppylang_internals.engine import ENGINE
12
+ from guppylang_internals.error import GuppyError, GuppyTypeError
13
+ from guppylang_internals.nodes import (
14
+ AnyCall,
15
+ BarrierExpr,
16
+ GlobalCall,
17
+ LocalCall,
18
+ PlaceNode,
19
+ StateResultExpr,
20
+ TensorCall,
21
+ )
22
+ from guppylang_internals.tys.errors import UnitaryCallError
23
+ from guppylang_internals.tys.qubit import contain_qubit_ty
24
+ from guppylang_internals.tys.ty import FunctionType, UnitaryFlags
25
+
26
+
27
+ def check_invalid_under_dagger(
28
+ fn_def: ast.FunctionDef, unitary_flags: UnitaryFlags
29
+ ) -> None:
30
+ """Check that there are no invalid constructs in a daggered CFG.
31
+ This checker checks the case the UnitaryFlags is given by
32
+ annotation (i.e., not inferred from `with dagger:`).
33
+ """
34
+ if UnitaryFlags.Dagger not in unitary_flags:
35
+ return
36
+
37
+ for stmt in fn_def.body:
38
+ loops = loop_in_ast(stmt)
39
+ if len(loops) != 0:
40
+ loop = next(iter(loops))
41
+ err = InvalidUnderDagger(loop, "Loop")
42
+ raise GuppyError(err)
43
+ # Note: sub-diagnostic for dagger context is not available here
44
+
45
+ found = find_nodes(
46
+ lambda n: isinstance(n, ast.Assign | ast.AnnAssign | ast.AugAssign),
47
+ stmt,
48
+ {ast.FunctionDef},
49
+ )
50
+ if len(found) != 0:
51
+ assign = next(iter(found))
52
+ err = InvalidUnderDagger(assign, "Assignment")
53
+ raise GuppyError(err)
54
+
55
+
56
+ class BBUnitaryChecker(ast.NodeVisitor):
57
+ """AST visitor that checks whether the modifiers (dagger, control, power)
58
+ are applicable."""
59
+
60
+ flags: UnitaryFlags
61
+
62
+ def check(self, bb: CheckedBB[Place], unitary_flags: UnitaryFlags) -> None:
63
+ self.flags = unitary_flags
64
+ for stmt in bb.statements:
65
+ self.visit(stmt)
66
+
67
+ def _check_classical_args(self, args: list[ast.expr]) -> bool:
68
+ for arg in args:
69
+ self.visit(arg)
70
+ if contain_qubit_ty(get_type(arg)):
71
+ return False
72
+ return True
73
+
74
+ def _check_call(self, node: AnyCall, ty: FunctionType) -> None:
75
+ classic = self._check_classical_args(node.args)
76
+ flag_ok = self.flags in ty.unitary_flags
77
+ if not classic and not flag_ok:
78
+ raise GuppyTypeError(
79
+ UnitaryCallError(node, self.flags & (~ty.unitary_flags))
80
+ )
81
+
82
+ def visit_GlobalCall(self, node: GlobalCall) -> None:
83
+ func = ENGINE.get_parsed(node.def_id)
84
+ assert isinstance(func, CallableDef)
85
+ self._check_call(node, func.ty)
86
+
87
+ def visit_LocalCall(self, node: LocalCall) -> None:
88
+ func = get_type(node.func)
89
+ assert isinstance(func, FunctionType)
90
+ self._check_call(node, func)
91
+
92
+ def visit_TensorCall(self, node: TensorCall) -> None:
93
+ self._check_call(node, node.tensor_ty)
94
+
95
+ def visit_BarrierExpr(self, node: BarrierExpr) -> None:
96
+ # Barrier is always allowed
97
+ pass
98
+
99
+ def visit_StateResultExpr(self, node: StateResultExpr) -> None:
100
+ # StateResult is always allowed
101
+ pass
102
+
103
+ def _check_assign(self, node: ast.Assign | ast.AnnAssign | ast.AugAssign) -> None:
104
+ if UnitaryFlags.Dagger in self.flags:
105
+ raise GuppyError(InvalidUnderDagger(node, "Assignment"))
106
+ if node.value is not None:
107
+ self.visit(node.value)
108
+
109
+ def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
110
+ self._check_assign(node)
111
+
112
+ def visit_Assign(self, node: ast.Assign) -> None:
113
+ self._check_assign(node)
114
+
115
+ def visit_AugAssign(self, node: ast.AugAssign) -> None:
116
+ self._check_assign(node)
117
+
118
+ def visit_PlaceNode(self, node: PlaceNode) -> None:
119
+ if UnitaryFlags.Dagger in self.flags and contains_subscript(node.place):
120
+ raise GuppyError(
121
+ UnsupportedError(node, "index access", True, "dagger context")
122
+ )
123
+
124
+
125
+ def check_cfg_unitary(
126
+ cfg: CheckedCFG[Place],
127
+ unitary_flags: UnitaryFlags,
128
+ ) -> None:
129
+ """Checks that the given unitary flags are valid for a CFG."""
130
+ bb_checker = BBUnitaryChecker()
131
+ for bb in cfg.bbs:
132
+ bb_checker.check(bb, unitary_flags)
@@ -24,7 +24,7 @@ from guppylang_internals.compiler.core import (
24
24
  from guppylang_internals.compiler.expr_compiler import ExprCompiler
25
25
  from guppylang_internals.compiler.stmt_compiler import StmtCompiler
26
26
  from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, read_bool
27
- from guppylang_internals.tys.ty import SumType, row_to_type, type_to_row
27
+ from guppylang_internals.tys.ty import type_to_row
28
28
 
29
29
 
30
30
  def compile_cfg(
@@ -38,7 +38,7 @@ def compile_cfg(
38
38
  # TODO: This mutates the CFG in-place which leads to problems when trying to lower
39
39
  # the same function to Hugr twice. For now we just check that the return vars
40
40
  # haven't already been inserted, but we should figure out a better way to handle
41
- # this: https://github.com/CQCL/guppylang/issues/428
41
+ # this: https://github.com/quantinuum/guppylang/issues/428
42
42
  if all(
43
43
  not is_return_var(v.name)
44
44
  for v in cfg.exit_bb.sig.input_row
@@ -52,7 +52,7 @@ def compile_cfg(
52
52
  # unreachable
53
53
  out_tys = [place.ty.to_hugr(ctx) for place in cfg.exit_bb.sig.input_row]
54
54
  # TODO: Use proper API for this once it's added in hugr-py:
55
- # https://github.com/CQCL/hugr/issues/1816
55
+ # https://github.com/quantinuum/hugr/issues/1816
56
56
  builder._exit_op._cfg_outputs = out_tys
57
57
  builder.parent_op._outputs = out_tys
58
58
  builder.parent_node = builder.hugr._update_node_outs(
@@ -194,13 +194,14 @@ def choose_vars_for_tuple_sum(
194
194
  constructs a TupleSum value of type `Sum(#s1, #s2, ...)`.
195
195
  """
196
196
  assert all(v.ty.droppable for var_row in output_vars for v in var_row)
197
- tys = [[v.ty for v in var_row] for var_row in output_vars]
198
- sum_type = SumType([row_to_type(row) for row in tys]).to_hugr(dfg.ctx)
197
+ sum_type = ht.Sum(
198
+ [[v.ty.to_hugr(dfg.ctx) for v in var_row] for var_row in output_vars]
199
+ )
199
200
 
200
201
  # We pass all values into the conditional instead of relying on non-local edges.
201
202
  # This is because we can't handle them in lower parts of the stack yet :/
202
203
  # TODO: Reinstate use of non-local edges.
203
- # See https://github.com/CQCL/guppylang/issues/963
204
+ # See https://github.com/quantinuum/guppylang/issues/963
204
205
  all_vars = {v.id: dfg[v] for var_row in output_vars for v in var_row}
205
206
  all_vars_wires = list(all_vars.values())
206
207
  all_vars_idxs = {x: i for i, x in enumerate(all_vars.keys())}
@@ -185,7 +185,7 @@ class CompilerContext(ToHugrContext):
185
185
  # make the call to `ENGINE.get_checked` below fail. For now, let's just short-
186
186
  # cut if the function doesn't take any generic params (as is the case for all
187
187
  # nested functions).
188
- # See https://github.com/CQCL/guppylang/issues/1032
188
+ # See https://github.com/quantinuum/guppylang/issues/1032
189
189
  if (def_id, ()) in self.compiled:
190
190
  assert type_args == []
191
191
  return self.compiled[def_id, ()], type_args
@@ -247,7 +247,7 @@ class CompilerContext(ToHugrContext):
247
247
 
248
248
  # Insert explicit drops for affine types
249
249
  # TODO: This is a quick workaround until we can properly insert these drops
250
- # during linearity checking. See https://github.com/CQCL/guppylang/issues/1082
250
+ # during linearity checking. See https://github.com/quantinuum/guppylang/issues/1082
251
251
  insert_drops(self.module.hugr)
252
252
 
253
253
  return entry_compiled
@@ -659,7 +659,7 @@ def track_hugr_side_effects() -> Iterator[None]:
659
659
 
660
660
  def qualified_name(type_def: he.TypeDef) -> str:
661
661
  """Returns the qualified name of a Hugr extension type.
662
- TODO: Remove once upstreamed, see https://github.com/CQCL/hugr/issues/2426
662
+ TODO: Remove once upstreamed, see https://github.com/quantinuum/hugr/issues/2426
663
663
  """
664
664
  if type_def._extension is not None:
665
665
  return f"{type_def._extension.name}.{type_def.name}"
@@ -711,14 +711,14 @@ def drop_op(ty: ht.Type) -> ops.ExtOp:
711
711
  def insert_drops(hugr: Hugr[OpVarCov]) -> None:
712
712
  """Inserts explicit drop ops for unconnected ports into the Hugr.
713
713
  TODO: This is a quick workaround until we can properly insert these drops during
714
- linearity checking. See https://github.com/CQCL/guppylang/issues/1082
714
+ linearity checking. See https://github.com/quantinuum/guppylang/issues/1082
715
715
  """
716
716
  for node in hugr:
717
717
  data = hugr[node]
718
718
  # Iterating over `node.outputs()` doesn't work reliably since it sometimes
719
719
  # raises an `IncompleteOp` exception. Instead, we query the number of out ports
720
720
  # and look them up by index. However, this method is *also* broken when
721
- # inspecting `FuncDefn` nodes due to https://github.com/CQCL/hugr/issues/2438.
721
+ # inspecting `FuncDefn` nodes due to https://github.com/quantinuum/hugr/issues/2438.
722
722
  if isinstance(data.op, ops.FuncDefn):
723
723
  continue
724
724
  for i in range(hugr.num_out_ports(node)):