guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98) hide show
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,821 @@
1
+ """Linearity checking
2
+
3
+ Linearity checking across control-flow is done by the `CFGChecker`.
4
+ """
5
+
6
+ import ast
7
+ from collections.abc import Generator, Iterator
8
+ from contextlib import contextmanager
9
+ from enum import Enum, auto
10
+ from typing import TYPE_CHECKING, NamedTuple, TypeGuard
11
+
12
+ from guppylang_internals.ast_util import AstNode, find_nodes, get_type
13
+ from guppylang_internals.cfg.analysis import LivenessAnalysis, LivenessDomain
14
+ from guppylang_internals.cfg.bb import BB, VariableStats
15
+ from guppylang_internals.checker.cfg_checker import (
16
+ CheckedBB,
17
+ CheckedCFG,
18
+ Row,
19
+ Signature,
20
+ )
21
+ from guppylang_internals.checker.core import (
22
+ FieldAccess,
23
+ Globals,
24
+ Locals,
25
+ Place,
26
+ PlaceId,
27
+ SubscriptAccess,
28
+ TupleAccess,
29
+ Variable,
30
+ contains_subscript,
31
+ )
32
+ from guppylang_internals.checker.errors.linearity import (
33
+ AlreadyUsedError,
34
+ BorrowShadowedError,
35
+ BorrowSubPlaceUsedError,
36
+ ComprAlreadyUsedError,
37
+ DropAfterCallError,
38
+ MoveOutOfSubscriptError,
39
+ NonCopyableCaptureError,
40
+ NonCopyablePartialApplyError,
41
+ NotOwnedError,
42
+ PlaceNotUsedError,
43
+ UnnamedExprNotUsedError,
44
+ UnnamedFieldNotUsedError,
45
+ UnnamedSubscriptNotUsedError,
46
+ UnnamedTupleNotUsedError,
47
+ )
48
+ from guppylang_internals.definition.custom import CustomFunctionDef
49
+ from guppylang_internals.definition.value import CallableDef
50
+ from guppylang_internals.engine import DEF_STORE, ENGINE
51
+ from guppylang_internals.error import GuppyError, GuppyTypeError
52
+ from guppylang_internals.nodes import (
53
+ AnyCall,
54
+ BarrierExpr,
55
+ CheckedNestedFunctionDef,
56
+ DesugaredArrayComp,
57
+ DesugaredGenerator,
58
+ DesugaredListComp,
59
+ FieldAccessAndDrop,
60
+ GlobalCall,
61
+ InoutReturnSentinel,
62
+ LocalCall,
63
+ PartialApply,
64
+ PlaceNode,
65
+ ResultExpr,
66
+ StateResultExpr,
67
+ SubscriptAccessAndDrop,
68
+ TensorCall,
69
+ TupleAccessAndDrop,
70
+ )
71
+ from guppylang_internals.tys.builtin import get_element_type, is_array_type
72
+ from guppylang_internals.tys.ty import (
73
+ FuncInput,
74
+ FunctionType,
75
+ InputFlags,
76
+ NoneType,
77
+ StructType,
78
+ TupleType,
79
+ Type,
80
+ )
81
+
82
+ if TYPE_CHECKING:
83
+ from guppylang_internals.diagnostic import Error
84
+
85
+
86
class UseKind(Enum):
    """Enumeration of the ways in which a place can be used."""

    #: Copying of a classical value
    COPY = auto()

    #: Borrowing of a value by passing it to a function
    BORROW = auto()

    #: Transfer of ownership of an owned value by passing it to a function
    CONSUME = auto()

    #: Transfer of ownership of an owned value by returning it
    RETURN = auto()

    #: Renaming of an owned value, or storing it in a tuple/list
    MOVE = auto()

    @property
    def indicative(self) -> str:
        """The use phrased in the indicative mood (the lower-cased member name).

        For example: "You cannot *consume* this qubit."
        """
        return self.name.lower()

    @property
    def subjunctive(self) -> str:
        """The use phrased in the subjunctive mood (a past participle).

        For example: "This qubit cannot be *consumed*"
        """
        participles = {
            UseKind.COPY: "copied",
            UseKind.BORROW: "borrowed",
            UseKind.CONSUME: "consumed",
            UseKind.RETURN: "returned",
            UseKind.MOVE: "moved",
        }
        return participles[self]
129
+
130
+
131
class Use(NamedTuple):
    """Records data associated with a use of a place.

    `Scope` stores instances of this tuple in its `used_local` and `used_parent`
    maps, keyed by the id of the place that was used.
    """

    #: The AST node corresponding to the use
    node: AstNode

    #: The kind of use, i.e. is the value consumed, borrowed, returned, ...?
    kind: UseKind
139
+
140
+
141
class Scope(Locals[PlaceId, Place]):
    """Scoped collection of assigned places, indexed by their id.

    On top of the assignments, the scope remembers which places have been used:
    `used_local` holds uses of places assigned in this scope, `used_parent` holds
    uses of places that were resolved through a parent scope.
    """

    parent_scope: "Scope | None"
    used_local: dict[PlaceId, Use]
    used_parent: dict[PlaceId, Use]

    def __init__(self, parent: "Scope | None" = None):
        self.used_local = {}
        self.used_parent = {}
        super().__init__({}, parent)

    def used(self, x: PlaceId) -> Use | None:
        """Returns the recorded use of a place, or `None` if it is still unused.

        Places that are not assigned in this scope are looked up in the parent.
        """
        if x not in self.vars:
            assert self.parent_scope is not None
            return self.parent_scope.used(x)
        return self.used_local.get(x)

    def use(self, x: PlaceId, node: AstNode, kind: UseKind) -> None:
        """Records a use of a place.

        Places assigned in this scope are recorded locally. Any other place must be
        known to the parent scope; its use is recorded in `used_parent` and then
        forwarded to the parent.
        """
        if x in self.vars:
            self.used_local[x] = Use(node, kind)
            return
        assert self.parent_scope is not None
        assert x in self.parent_scope
        self.used_parent[x] = Use(node, kind)
        self.parent_scope.use(x, node, kind)

    def assign(self, place: Place) -> None:
        """Records an assignment of a place and forgets any earlier local use."""
        assert place.defined_at is not None
        place_id = place.id
        self.vars[place_id] = place
        self.used_local.pop(place_id, None)

    def stats(self) -> VariableStats[PlaceId]:
        """Summarises this scope as the definition nodes of its assigned places and
        the use nodes of the places taken from a parent scope."""
        definitions = {}
        for place_id, place in self.vars.items():
            assert place.defined_at is not None
            definitions[place_id] = place.defined_at
        outer_uses = {}
        for place_id, use in self.used_parent.items():
            outer_uses[place_id] = use.node
        return VariableStats(definitions, outer_uses)
191
+
192
+
193
+ class BBLinearityChecker(ast.NodeVisitor):
194
+ """AST visitor that checks linearity for a single basic block."""
195
+
196
+ scope: Scope
197
+ stats: VariableStats[PlaceId]
198
+ func_name: str
199
+ func_inputs: dict[PlaceId, Variable]
200
+ globals: Globals
201
+
202
+ def check(
203
+ self,
204
+ bb: "CheckedBB[Variable]",
205
+ is_entry: bool,
206
+ func_name: str,
207
+ func_inputs: dict[PlaceId, Variable],
208
+ globals: Globals,
209
+ ) -> Scope:
210
+ # Manufacture a scope that holds all places that are live at the start
211
+ # of this BB
212
+ input_scope = Scope()
213
+ for var in bb.sig.input_row:
214
+ for place in leaf_places(var):
215
+ input_scope.assign(place)
216
+ self.func_name = func_name
217
+ self.func_inputs = func_inputs
218
+ self.globals = globals
219
+
220
+ # Open up a new nested scope to check the BB contents. This way we can track
221
+ # when we use variables from the outside vs ones assigned in this BB. The only
222
+ # exception is the entry BB since function arguments should be treated as part
223
+ # of the entry BB
224
+ self.scope = input_scope if is_entry else Scope(input_scope)
225
+
226
+ for stmt in bb.statements:
227
+ self.visit(stmt)
228
+ if bb.branch_pred:
229
+ self.visit(bb.branch_pred)
230
+ return self.scope
231
+
232
+ @contextmanager
233
+ def new_scope(self) -> Generator[Scope, None, None]:
234
+ scope, new_scope = self.scope, Scope(self.scope)
235
+ self.scope = new_scope
236
+ yield new_scope
237
+ self.scope = scope
238
+
239
+ def visit_PlaceNode(
240
+ self,
241
+ node: PlaceNode,
242
+ /,
243
+ use_kind: UseKind = UseKind.MOVE,
244
+ is_call_arg: AnyCall | None = None,
245
+ ) -> None:
246
+ # Usage of borrowed variables is generally forbidden. The only exception is
247
+ # letting them be reborrowed by another function call. In that case, our
248
+ # `_visit_call_args` helper will set `use_kind=UseKind.BORROW`.
249
+ is_inout_arg = use_kind == UseKind.BORROW
250
+ if is_inout_var(node.place) and not is_inout_arg:
251
+ err: Error = NotOwnedError(
252
+ node,
253
+ node.place,
254
+ use_kind,
255
+ is_call_arg is not None,
256
+ self._call_name(is_call_arg),
257
+ self.func_name,
258
+ )
259
+ arg_span = self.func_inputs[node.place.root.id].defined_at
260
+ err.add_sub_diagnostic(NotOwnedError.MakeOwned(arg_span))
261
+ # If the argument is a classical array, we can also suggest copying it.
262
+ if has_explicit_copy(node.place.ty):
263
+ err.add_sub_diagnostic(NotOwnedError.MakeCopy(node))
264
+ raise GuppyError(err)
265
+ # Places involving subscripts are handled differently since we ignore everything
266
+ # after the subscript for the purposes of linearity checking.
267
+ if subscript := contains_subscript(node.place):
268
+ # We have to check the item type to determine if we can move out of the
269
+ # subscript.
270
+ if not is_inout_arg and not subscript.ty.copyable:
271
+ err = MoveOutOfSubscriptError(node, use_kind, subscript.parent)
272
+ err.add_sub_diagnostic(MoveOutOfSubscriptError.Explanation(None))
273
+ raise GuppyError(err)
274
+ self.visit(subscript.item_expr)
275
+ self.scope.assign(subscript.item)
276
+ # Visiting the `__getitem__(place.parent, place.item)` call ensures that we
277
+ # linearity-check the parent and element.
278
+ assert subscript.getitem_call is not None
279
+ self.visit(subscript.getitem_call)
280
+ # For all other places, we record uses of all leaves
281
+ else:
282
+ for place in leaf_places(node.place):
283
+ x = place.id
284
+ if (prev_use := self.scope.used(x)) and not place.ty.copyable:
285
+ err = AlreadyUsedError(node, place, use_kind)
286
+ err.add_sub_diagnostic(
287
+ AlreadyUsedError.PrevUse(prev_use.node, prev_use.kind)
288
+ )
289
+ if has_explicit_copy(place.ty):
290
+ err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None))
291
+ raise GuppyError(err)
292
+ self.scope.use(x, node, use_kind)
293
+
294
+ def visit_Assign(self, node: ast.Assign) -> None:
295
+ self.visit(node.value)
296
+ self._check_assign_targets(node.targets)
297
+
298
+ # Check that borrowed vars are not being shadowed. This would also be caught by
299
+ # the dataflow analysis later, however we can give nicer error messages here.
300
+ [target] = node.targets
301
+ for tgt in find_nodes(lambda n: isinstance(n, PlaceNode), target):
302
+ assert isinstance(tgt, PlaceNode)
303
+ if tgt.place.id in self.func_inputs:
304
+ entry_place = self.func_inputs[tgt.place.id]
305
+ if is_inout_var(entry_place):
306
+ err = BorrowShadowedError(tgt.place.defined_at, entry_place)
307
+ err.add_sub_diagnostic(BorrowShadowedError.Rename(None))
308
+ raise GuppyError(err)
309
+
310
+ def visit_Return(self, node: ast.Return) -> None:
311
+ # Intercept returns of places, so we can set the appropriate `use_kind` to get
312
+ # nicer error messages
313
+ if isinstance(node.value, PlaceNode):
314
+ self.visit_PlaceNode(node.value, use_kind=UseKind.RETURN)
315
+ elif isinstance(node.value, ast.Tuple):
316
+ for elt in node.value.elts:
317
+ if isinstance(elt, PlaceNode):
318
+ self.visit_PlaceNode(elt, use_kind=UseKind.RETURN)
319
+ else:
320
+ self.visit(elt)
321
+ elif node.value:
322
+ self.visit(node.value)
323
+
324
+ def _visit_call_args(self, func_ty: FunctionType, call: AnyCall) -> None:
325
+ """Helper function to check the arguments of a function call.
326
+
327
+ Populates the `use_kind` kwarg of `visit_PlaceNode` in case some of the
328
+ arguments are places.
329
+ """
330
+ for inp, arg in zip(func_ty.inputs, call.args, strict=True):
331
+ if isinstance(arg, PlaceNode):
332
+ use_kind = (
333
+ UseKind.BORROW if InputFlags.Inout in inp.flags else UseKind.CONSUME
334
+ )
335
+ self.visit_PlaceNode(arg, use_kind=use_kind, is_call_arg=call)
336
+ else:
337
+ self.visit(arg)
338
+
339
+ def _reassign_inout_args(self, func_ty: FunctionType, call: AnyCall) -> None:
340
+ """Helper function to reassign the borrowed arguments after a function call."""
341
+ for inp, arg in zip(func_ty.inputs, call.args, strict=True):
342
+ if InputFlags.Inout in inp.flags:
343
+ match arg:
344
+ case PlaceNode(place=place):
345
+ self._reassign_single_inout_arg(place, place.defined_at or arg)
346
+ case arg if not inp.ty.droppable:
347
+ err = DropAfterCallError(arg, inp.ty, self._call_name(call))
348
+ err.add_sub_diagnostic(DropAfterCallError.Assign(None))
349
+ raise GuppyError(err)
350
+
351
+ def _reassign_single_inout_arg(self, place: Place, node: AstNode) -> None:
352
+ """Helper function to reassign a single borrowed argument after a function
353
+ call."""
354
+ # Places involving subscripts are given back by visiting the `__setitem__` call
355
+ if subscript := contains_subscript(place):
356
+ assert subscript.setitem_call is not None
357
+ for leaf in leaf_places(subscript.setitem_call.value_var):
358
+ self.scope.assign(leaf)
359
+ self.visit(subscript.setitem_call.call)
360
+ self._reassign_single_inout_arg(subscript.parent, node)
361
+ else:
362
+ for leaf in leaf_places(place):
363
+ assert not isinstance(leaf, SubscriptAccess)
364
+ leaf = leaf.replace_defined_at(node)
365
+ self.scope.assign(leaf)
366
+
367
+ def _call_name(self, node: AnyCall | None) -> str | None:
368
+ """Tries to extract the name of a called function from a call AST node."""
369
+ if isinstance(node, LocalCall):
370
+ return node.func.id if isinstance(node.func, ast.Name) else None
371
+ elif isinstance(node, GlobalCall):
372
+ return DEF_STORE.raw_defs[node.def_id].name
373
+ return None
374
+
375
+ def visit_GlobalCall(self, node: GlobalCall) -> None:
376
+ func = ENGINE.get_parsed(node.def_id)
377
+ assert isinstance(func, CallableDef)
378
+ if isinstance(func, CustomFunctionDef) and not func.has_signature:
379
+ func_ty = FunctionType(
380
+ [FuncInput(get_type(arg), InputFlags.NoFlags) for arg in node.args],
381
+ get_type(node),
382
+ )
383
+ else:
384
+ func_ty = func.ty.instantiate(node.type_args)
385
+ self._visit_call_args(func_ty, node)
386
+ self._reassign_inout_args(func_ty, node)
387
+
388
+ def visit_LocalCall(self, node: LocalCall) -> None:
389
+ func_ty = get_type(node.func)
390
+ assert isinstance(func_ty, FunctionType)
391
+ self.visit(node.func)
392
+ self._visit_call_args(func_ty, node)
393
+ self._reassign_inout_args(func_ty, node)
394
+
395
+ def visit_TensorCall(self, node: TensorCall) -> None:
396
+ for arg in node.args:
397
+ self.visit(arg)
398
+ self._reassign_inout_args(node.tensor_ty, node)
399
+
400
+ def visit_PartialApply(self, node: PartialApply) -> None:
401
+ self.visit(node.func)
402
+ for arg in node.args:
403
+ ty = get_type(arg)
404
+ if not ty.copyable:
405
+ err = NonCopyablePartialApplyError(node)
406
+ err.add_sub_diagnostic(NonCopyablePartialApplyError.Captured(arg, ty))
407
+ raise GuppyError(err)
408
+ self.visit(arg)
409
+
410
+ def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> None:
411
+ # A field access on a value that is not a place. This means the value can no
412
+ # longer be accessed after the field has been projected out. Thus, this is only
413
+ # legal if there are no remaining linear fields on the value
414
+ self.visit(node.value)
415
+ for field in node.struct_ty.fields:
416
+ if field.name != node.field.name and not field.ty.droppable:
417
+ err = UnnamedFieldNotUsedError(node.value, field, node.struct_ty)
418
+ err.add_sub_diagnostic(UnnamedFieldNotUsedError.Fix(None, node.field))
419
+ raise GuppyError(err)
420
+
421
+ def visit_SubscriptAccessAndDrop(self, node: SubscriptAccessAndDrop) -> None:
422
+ # A subscript access on a value that is not a place. This means the value can no
423
+ # longer be accessed after the item has been projected out. Thus, this is only
424
+ # legal if the items in the container are not linear
425
+ elem_ty = get_type(node.getitem_expr)
426
+ if not elem_ty.droppable:
427
+ value = node.original_expr.value
428
+ err = UnnamedSubscriptNotUsedError(value, get_type(value))
429
+ err.add_sub_diagnostic(
430
+ UnnamedSubscriptNotUsedError.SubscriptHint(node.item_expr)
431
+ )
432
+ err.add_sub_diagnostic(UnnamedSubscriptNotUsedError.Fix(None))
433
+ raise GuppyTypeError(err)
434
+ self.visit(node.item_expr)
435
+ self.scope.assign(node.item)
436
+ self.visit(node.getitem_expr)
437
+
438
+ def visit_TupleAccessAndDrop(self, node: TupleAccessAndDrop) -> None:
439
+ # A tuple access on a value that is not a place. This means the value can no
440
+ # longer be accessed after the element has been projected out. Thus, this is
441
+ # only legal if there are no remaining linear elements in the tuple
442
+ self.visit(node.value)
443
+ for idx, elem_ty in enumerate(node.tuple_ty.element_types):
444
+ if idx != node.index and not elem_ty.droppable:
445
+ err = UnnamedTupleNotUsedError(node.value, node.tuple_ty)
446
+ err.add_sub_diagnostic(UnnamedTupleNotUsedError.Fix(None))
447
+ raise GuppyError(err)
448
+
449
+ def visit_BarrierExpr(self, node: BarrierExpr) -> None:
450
+ self._visit_call_args(node.func_ty, node)
451
+ self._reassign_inout_args(node.func_ty, node)
452
+
453
+ def visit_ResultExpr(self, node: ResultExpr) -> None:
454
+ ty = get_type(node.value)
455
+ flag = InputFlags.Inout if not ty.copyable else InputFlags.NoFlags
456
+ func_ty = FunctionType([FuncInput(ty, flag)], NoneType())
457
+ self._visit_call_args(func_ty, node)
458
+ self._reassign_inout_args(func_ty, node)
459
+
460
+ def visit_StateResultExpr(self, node: StateResultExpr) -> None:
461
+ self._visit_call_args(node.func_ty, node)
462
+ self._reassign_inout_args(node.func_ty, node)
463
+
464
+ def visit_Expr(self, node: ast.Expr) -> None:
465
+ # An expression statement where the return value is discarded
466
+ self.visit(node.value)
467
+ ty = get_type(node.value)
468
+ if not ty.droppable:
469
+ err = UnnamedExprNotUsedError(node, ty)
470
+ err.add_sub_diagnostic(UnnamedExprNotUsedError.Fix(None))
471
+ raise GuppyTypeError(err)
472
+
473
+ def visit_DesugaredListComp(self, node: DesugaredListComp) -> None:
474
+ self._check_comprehension(node.generators, node.elt)
475
+
476
+ def visit_DesugaredArrayComp(self, node: DesugaredArrayComp) -> None:
477
+ self._check_comprehension([node.generator], node.elt)
478
+
479
+ def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None:
480
+ # Linearity of the nested function has already been checked. We just need to
481
+ # verify that no linear variables are captured
482
+ # TODO: In the future, we could support capturing of non-linear subplaces
483
+ for var, use in node.captured.values():
484
+ if not var.ty.copyable:
485
+ err = NonCopyableCaptureError(use, var)
486
+ err.add_sub_diagnostic(
487
+ NonCopyableCaptureError.DefinedHere(var.defined_at)
488
+ )
489
+ raise GuppyError(err)
490
+ for place in leaf_places(var):
491
+ self.scope.use(place.id, use, UseKind.COPY)
492
+ self.scope.assign(Variable(node.name, node.ty, node))
493
+
494
+ def _check_assign_targets(self, targets: list[ast.expr]) -> None:
495
+ """Helper function to check assignments."""
496
+ # We're not allowed to override an unused linear place
497
+ [target] = targets
498
+ for tgt in find_nodes(lambda n: isinstance(n, PlaceNode), target):
499
+ assert isinstance(tgt, PlaceNode)
500
+ # Special error message for shadowing of borrowed vars
501
+ x = tgt.place.id
502
+ if x in self.scope.vars and is_inout_var(self.scope[x]):
503
+ err: Error = BorrowShadowedError(tgt, tgt.place)
504
+ err.add_sub_diagnostic(BorrowShadowedError.Rename(None))
505
+ raise GuppyError(err)
506
+ # Subscript assignments also require checking the `__setitem__` call
507
+ if subscript := contains_subscript(tgt.place):
508
+ assert subscript.setitem_call is not None
509
+ self.visit(subscript.item_expr)
510
+ self.scope.assign(subscript.item)
511
+ self.scope.assign(subscript.setitem_call.value_var)
512
+ self.visit(subscript.setitem_call.call)
513
+ else:
514
+ for tgt_place in leaf_places(tgt.place):
515
+ x = tgt_place.id
516
+ # Only check for overrides of places locally defined in this BB.
517
+ # Global checks are handled by dataflow analysis.
518
+ if x in self.scope.vars and x not in self.scope.used_local:
519
+ place = self.scope[x]
520
+ if not place.ty.droppable:
521
+ err = PlaceNotUsedError(place.defined_at, place)
522
+ err.add_sub_diagnostic(PlaceNotUsedError.Fix(None))
523
+ raise GuppyError(err)
524
+ self.scope.assign(tgt_place)
525
+
526
+ def _check_comprehension(
527
+ self, gens: list[DesugaredGenerator], elt: ast.expr
528
+ ) -> None:
529
+ """Helper function to recursively check list comprehensions."""
530
+ if not gens:
531
+ self.visit(elt)
532
+ return
533
+
534
+ # Check the iterator expression in the current scope
535
+ gen, *gens = gens
536
+ self.visit(gen.iter_assign.value)
537
+ assert isinstance(gen.iter, PlaceNode)
538
+
539
+ # The rest is checked in a new nested scope so we can track which variables
540
+ # are introduced and used inside the loop
541
+ with self.new_scope() as inner_scope:
542
+ # In particular, assign the iterator variable in the new scope
543
+ self._check_assign_targets(gen.iter_assign.targets)
544
+ self.visit(gen.next_call)
545
+ self._check_assign_targets([gen.target])
546
+ self._check_assign_targets(gen.iter_assign.targets)
547
+
548
+ # `if` guards are generally not allowed when we're iterating over linear
549
+ # variables. The only exception is if all linear variables are already
550
+ # consumed by the first guard
551
+ if gen.ifs:
552
+ first_if, *other_ifs = gen.ifs
553
+ # Check if there are linear iteration variables that have not been used
554
+ # by the first guard
555
+ self.visit(first_if)
556
+ for place in self.scope.vars.values():
557
+ # The only exception is the iterator variable since we make sure
558
+ # that it is carried through each iteration during Hugr generation
559
+ if place == gen.iter.place:
560
+ continue
561
+ for leaf in leaf_places(place):
562
+ x = leaf.id
563
+ # Also ignore borrowed variables
564
+ if x in inner_scope.used_parent and (
565
+ inner_scope.used_parent[x].kind == UseKind.BORROW
566
+ ):
567
+ continue
568
+ if not self.scope.used(x) and not place.ty.droppable:
569
+ err = PlaceNotUsedError(place.defined_at, place)
570
+ err.add_sub_diagnostic(
571
+ PlaceNotUsedError.Branch(first_if, False)
572
+ )
573
+ raise GuppyTypeError(err)
574
+ for expr in other_ifs:
575
+ self.visit(expr)
576
+
577
+ # Recursively check the remaining generators
578
+ self._check_comprehension(gens, elt)
579
+
580
+ # Look for any variables that are used from the outer scope. This is so we
581
+ # can feed them through the loop. Note that we could also use non-local
582
+ # edges, but we can't handle them in lower parts of the stack yet :/
583
+ # TODO: Reinstate use of non-local edges.
584
+ # See https://github.com/CQCL/guppylang/issues/963
585
+ gen.used_outer_places = []
586
+ for x, use in inner_scope.used_parent.items():
587
+ place = inner_scope[x]
588
+ gen.used_outer_places.append(place)
589
+ if use.kind == UseKind.BORROW:
590
+ # Since `x` was borrowed, we know that is now also assigned in the
591
+ # inner scope since it gets reassigned in the local scope after the
592
+ # borrow expires.
593
+ # Also mark this place as implicitly used so we don't complain about
594
+ # it later.
595
+ for leaf in leaf_places(place):
596
+ inner_scope.use(
597
+ leaf.id, InoutReturnSentinel(leaf), UseKind.RETURN
598
+ )
599
+
600
+ # Mark the iterator as used since it's carried into the next iteration
601
+ for leaf in leaf_places(gen.iter.place):
602
+ self.scope.use(leaf.id, gen.iter, UseKind.CONSUME)
603
+
604
+ # We have to make sure that all linear variables that were introduced in the
605
+ # inner scope have been used
606
+ for place in inner_scope.vars.values():
607
+ for leaf in leaf_places(place):
608
+ x = leaf.id
609
+ if not leaf.ty.droppable and not inner_scope.used(x):
610
+ raise GuppyTypeError(PlaceNotUsedError(leaf.defined_at, leaf))
611
+
612
+ # On the other hand, we have to ensure that no linear places from the
613
+ # outer scope have been used inside the comprehension (they would be used
614
+ # multiple times since the comprehension body is executed repeatedly)
615
+ for x, use in inner_scope.used_parent.items():
616
+ place = inner_scope[x]
617
+ # The only exception are values that are only borrowed from the outer
618
+ # scope. These can be safely reassigned.
619
+ if use.kind == UseKind.BORROW:
620
+ self._reassign_single_inout_arg(place, use.node)
621
+ elif not place.ty.copyable:
622
+ raise GuppyTypeError(ComprAlreadyUsedError(use.node, place, use.kind))
623
+
624
+
625
def leaf_places(place: Place) -> Iterator[Place]:
    """Yields all leaf descendant projections of a place.

    Places of struct type are expanded into one `FieldAccess` per field and places
    of tuple type into one `TupleAccess` per element. Every place of any other type
    is yielded as is.
    """
    pending = [place]
    while pending:
        current = pending.pop()
        ty = current.ty
        if isinstance(ty, StructType):
            pending.extend(
                FieldAccess(current, field, current.defined_at) for field in ty.fields
            )
        elif isinstance(ty, TupleType):
            pending.extend(
                TupleAccess(current, elem_ty, idx, None)
                for idx, elem_ty in enumerate(ty.element_types)
            )
        else:
            yield current
641
+
642
+
643
def is_inout_var(place: Place) -> TypeGuard[Variable]:
    """Tells whether a place is a variable that carries the `Inout` flag, i.e. a
    borrowed variable."""
    if not isinstance(place, Variable):
        return False
    return InputFlags.Inout in place.flags
646
+
647
+
648
def has_explicit_copy(ty: Type) -> bool:
    """Tells whether a type has an explicit copy function.

    At the moment, only array types whose element type is copyable qualify."""
    return is_array_type(ty) and get_element_type(ty).copyable
655
+
656
+
657
def check_cfg_linearity(
    cfg: "CheckedCFG[Variable]", func_name: str, globals: Globals
) -> "CheckedCFG[Place]":
    """Enforces the linearity rules on a CFG and re-expresses it over places.

    Every basic block is first checked on its own. The resulting scopes feed a
    place-level liveness analysis which is then used to reject non-copyable places
    that are used in a BB while still being live afterwards, as well as non-droppable
    places that are neither used in a BB nor live in all of its successors.

    Raises a `GuppyError` if such a linearity violation is found.

    Returns a new CFG whose basic block signatures are stated in terms of *places*
    rather than whole variables.
    """
    checker = BBLinearityChecker()
    func_inputs: dict[PlaceId, Variable] = {
        var.id: var for var in cfg.entry_bb.sig.input_row
    }
    scopes: dict[BB, Scope] = {
        bb: checker.check(
            bb,
            is_entry=bb == cfg.entry_bb,
            func_name=func_name,
            func_inputs=func_inputs,
            globals=globals,
        )
        for bb in cfg.bbs
    }

    # Borrowed arguments are handed back to the caller on return, so every leaf of
    # them counts as implicitly used in the exit BB
    exit_scope = scopes[cfg.exit_bb]
    borrowed = [
        var for var in cfg.entry_bb.sig.input_row if InputFlags.Inout in var.flags
    ]
    for var in borrowed:
        for leaf in leaf_places(var):
            exit_scope.use(leaf.id, InoutReturnSentinel(var=var), UseKind.RETURN)

    # Edge case: An unreachable exit means that the function never terminates, so
    # nothing needs to be given back to the caller. The generated Hugr must still be
    # valid though, which requires threading the borrowed arguments through the
    # non-terminating loop. To get there, borrowed variables are treated as live in
    # every BB even though their use in the exit cannot be reached: they are seeded
    # into the initial value of the liveness analysis below. The previous
    # `CFG.analyze` pass did the analogous thing.
    live_default: LivenessDomain[PlaceId] = {}
    if not cfg.exit_bb.reachable:
        live_default = {
            leaf.id: cfg.exit_bb for var in borrowed for leaf in leaf_places(var)
        }

    # Place-level liveness analysis, starting from the initial value above
    stats = {bb: bb_scope.stats() for bb, bb_scope in scopes.items()}
    liveness = LivenessAnalysis(
        stats, initial=live_default, include_unreachable=False
    )
    live_before = liveness.run(cfg.bbs)

    def check_no_reuse(bb: BB, scope: Scope) -> None:
        """Rejects non-copyable places used in `bb` that are live in a successor."""
        for succ in bb.successors:
            for x, use_bb in live_before[succ].items():
                later_scope = scopes[use_bb]
                place = later_scope[x]
                if place.ty.copyable:
                    continue
                prev_use = scope.used(x)
                if not prev_use:
                    continue
                use = later_scope.used_parent[x]
                if isinstance(use.node, InoutReturnSentinel):
                    # The later use is the implicit return of a borrowed argument,
                    # which gets its own dedicated diagnostic
                    owner = use.node.var
                    assert isinstance(owner, Variable)
                    assert InputFlags.Inout in owner.flags
                    err: Error = BorrowSubPlaceUsedError(
                        owner.defined_at, owner, place
                    )
                    err.add_sub_diagnostic(
                        BorrowSubPlaceUsedError.PrevUse(prev_use.node, prev_use.kind)
                    )
                    err.add_sub_diagnostic(BorrowSubPlaceUsedError.Fix(None))
                else:
                    err = AlreadyUsedError(use.node, place, use.kind)
                    err.add_sub_diagnostic(
                        AlreadyUsedError.PrevUse(prev_use.node, prev_use.kind)
                    )
                    if has_explicit_copy(place.ty):
                        err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None))
                raise GuppyError(err)

    def check_no_leak(
        bb: BB, scope: Scope, live_here: "LivenessDomain[PlaceId]"
    ) -> None:
        """Rejects non-droppable places that `bb` neither uses nor passes on."""
        for place in scope.values():
            for leaf in leaf_places(place):
                x = leaf.id
                # Some values are only in scope because the type checker considered
                # them live in the first (less precise) dataflow analysis. The
                # second, more fine-grained, analysis based on places may find that
                # x is actually not live.
                if x not in live_here and x not in scope.vars:
                    continue
                used_later = all(x in live_before[succ] for succ in bb.successors)
                if leaf.ty.droppable or scope.used(x) or used_later:
                    continue
                err = PlaceNotUsedError(scope[x].defined_at, leaf)
                # If at least one path leads to a consumption, the message can be
                # made nicer by pointing at the branch that causes the leak
                if any(x in live_before[succ] for succ in bb.successors):
                    assert bb.branch_pred is not None
                    left_succ, _ = bb.successors
                    err.add_sub_diagnostic(
                        PlaceNotUsedError.Branch(
                            bb.branch_pred, x in live_before[left_succ]
                        )
                    )
                err.add_sub_diagnostic(PlaceNotUsedError.Fix(None))
                raise GuppyError(err)

    def places_row(
        bb: BB, original_row: Row[Variable], pred_scope: Scope | None
    ) -> Row[Place]:
        """Builds the row of all places that are live at the start of a given BB.

        Entry and exit BBs are the exception: their signature is not split up into
        places, so the original variable row is returned unchanged.
        """
        if bb in (cfg.entry_bb, cfg.exit_bb):
            return original_row
        assert pred_scope is not None
        return [pred_scope[x] for x in live_before[bb]]

    # Build up a CFG that tracks places instead of just variables
    result_cfg: CheckedCFG[Place] = CheckedCFG(cfg.input_tys, cfg.output_ty)
    checked: dict[BB, CheckedBB[Place]] = {}

    for bb, scope in scopes.items():
        live_here = live_before[bb]

        # Used places that are not copyable must not be outputted ...
        check_no_reuse(bb, scope)
        # ... whereas unused places that are not droppable *must* be outputted
        check_no_leak(bb, scope, live_here)

        assert isinstance(bb, CheckedBB)
        input_row = places_row(bb, bb.sig.input_row, scope.parent_scope)
        output_rows = [
            places_row(succ, row, scope)
            for succ, row in zip(bb.successors, bb.sig.output_rows, strict=True)
        ]
        sig = Signature(input_row=input_row, output_rows=output_rows)
        checked[bb] = CheckedBB(
            bb.idx,
            result_cfg,
            bb.statements,
            branch_pred=bb.branch_pred,
            reachable=bb.reachable,
            sig=sig,
        )

    # Populate the remaining fields of the result CFG
    result_cfg.bbs = list(checked.values())
    result_cfg.entry_bb = checked[cfg.entry_bb]
    result_cfg.exit_bb = checked[cfg.exit_bb]
    result_cfg.live_before = {checked[old]: cfg.live_before[old] for old in cfg.bbs}
    result_cfg.ass_before = {checked[old]: cfg.ass_before[old] for old in cfg.bbs}
    result_cfg.maybe_ass_before = {
        checked[old]: cfg.maybe_ass_before[old] for old in cfg.bbs
    }
    for old in cfg.bbs:
        new = checked[old]
        new.predecessors = [checked[pred] for pred in old.predecessors]
        new.successors = [checked[succ] for succ in old.successors]
    return result_cfg