guppylang-internals 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +21 -0
- guppylang_internals/cfg/bb.py +20 -0
- guppylang_internals/cfg/builder.py +101 -3
- guppylang_internals/checker/core.py +12 -0
- guppylang_internals/checker/errors/generic.py +32 -1
- guppylang_internals/checker/errors/type_errors.py +14 -0
- guppylang_internals/checker/expr_checker.py +55 -29
- guppylang_internals/checker/func_checker.py +171 -22
- guppylang_internals/checker/linearity_checker.py +65 -0
- guppylang_internals/checker/modifier_checker.py +116 -0
- guppylang_internals/checker/stmt_checker.py +49 -2
- guppylang_internals/compiler/core.py +90 -53
- guppylang_internals/compiler/expr_compiler.py +49 -114
- guppylang_internals/compiler/modifier_compiler.py +174 -0
- guppylang_internals/compiler/stmt_compiler.py +15 -8
- guppylang_internals/decorator.py +124 -58
- guppylang_internals/definition/const.py +2 -2
- guppylang_internals/definition/custom.py +36 -2
- guppylang_internals/definition/declaration.py +4 -5
- guppylang_internals/definition/extern.py +2 -2
- guppylang_internals/definition/function.py +1 -1
- guppylang_internals/definition/parameter.py +10 -5
- guppylang_internals/definition/pytket_circuits.py +14 -42
- guppylang_internals/definition/struct.py +17 -14
- guppylang_internals/definition/traced.py +1 -1
- guppylang_internals/definition/ty.py +9 -3
- guppylang_internals/definition/wasm.py +2 -2
- guppylang_internals/engine.py +13 -2
- guppylang_internals/experimental.py +5 -0
- guppylang_internals/nodes.py +124 -23
- guppylang_internals/std/_internal/compiler/array.py +94 -282
- guppylang_internals/std/_internal/compiler/tket_exts.py +12 -8
- guppylang_internals/std/_internal/compiler/wasm.py +37 -26
- guppylang_internals/tracing/function.py +13 -2
- guppylang_internals/tracing/unpacking.py +33 -28
- guppylang_internals/tys/arg.py +18 -3
- guppylang_internals/tys/builtin.py +32 -16
- guppylang_internals/tys/const.py +33 -4
- guppylang_internals/tys/errors.py +6 -0
- guppylang_internals/tys/param.py +31 -16
- guppylang_internals/tys/parsing.py +118 -145
- guppylang_internals/tys/qubit.py +27 -0
- guppylang_internals/tys/subst.py +8 -26
- guppylang_internals/tys/ty.py +31 -21
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/METADATA +4 -4
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/RECORD +49 -46
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/WHEEL +0 -0
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -7,8 +7,8 @@ node straight from the Python AST. We build a CFG, check it, and return a
|
|
|
7
7
|
|
|
8
8
|
import ast
|
|
9
9
|
import sys
|
|
10
|
-
from dataclasses import dataclass
|
|
11
|
-
from typing import TYPE_CHECKING, ClassVar
|
|
10
|
+
from dataclasses import dataclass, replace
|
|
11
|
+
from typing import TYPE_CHECKING, ClassVar, cast
|
|
12
12
|
|
|
13
13
|
from guppylang_internals.ast_util import return_nodes_in_ast, with_loc
|
|
14
14
|
from guppylang_internals.cfg.bb import BB
|
|
@@ -17,13 +17,28 @@ from guppylang_internals.checker.cfg_checker import CheckedCFG, check_cfg
|
|
|
17
17
|
from guppylang_internals.checker.core import Context, Globals, Place, Variable
|
|
18
18
|
from guppylang_internals.checker.errors.generic import UnsupportedError
|
|
19
19
|
from guppylang_internals.definition.common import DefId
|
|
20
|
+
from guppylang_internals.definition.ty import TypeDef
|
|
20
21
|
from guppylang_internals.diagnostic import Error, Help, Note
|
|
21
22
|
from guppylang_internals.engine import DEF_STORE, ENGINE
|
|
22
23
|
from guppylang_internals.error import GuppyError
|
|
23
24
|
from guppylang_internals.experimental import check_capturing_closures_enabled
|
|
24
25
|
from guppylang_internals.nodes import CheckedNestedFunctionDef, NestedFunctionDef
|
|
25
|
-
from guppylang_internals.tys.parsing import
|
|
26
|
-
|
|
26
|
+
from guppylang_internals.tys.parsing import (
|
|
27
|
+
TypeParsingCtx,
|
|
28
|
+
check_function_arg,
|
|
29
|
+
parse_function_arg_annotation,
|
|
30
|
+
type_from_ast,
|
|
31
|
+
type_with_flags_from_ast,
|
|
32
|
+
)
|
|
33
|
+
from guppylang_internals.tys.ty import (
|
|
34
|
+
ExistentialTypeVar,
|
|
35
|
+
FuncInput,
|
|
36
|
+
FunctionType,
|
|
37
|
+
InputFlags,
|
|
38
|
+
NoneType,
|
|
39
|
+
Type,
|
|
40
|
+
unify,
|
|
41
|
+
)
|
|
27
42
|
|
|
28
43
|
if sys.version_info >= (3, 12):
|
|
29
44
|
from guppylang_internals.tys.parsing import parse_parameter
|
|
@@ -53,6 +68,15 @@ class MissingArgAnnotationError(Error):
|
|
|
53
68
|
span_label: ClassVar[str] = "Argument requires a type annotation"
|
|
54
69
|
|
|
55
70
|
|
|
71
|
+
@dataclass(frozen=True)
|
|
72
|
+
class RecursiveSelfError(Error):
|
|
73
|
+
title: ClassVar[str] = "Recursive self annotation"
|
|
74
|
+
span_label: ClassVar[str] = (
|
|
75
|
+
"Type of `{self_arg}` cannot recursively refer to `Self`"
|
|
76
|
+
)
|
|
77
|
+
self_arg: str
|
|
78
|
+
|
|
79
|
+
|
|
56
80
|
@dataclass(frozen=True)
|
|
57
81
|
class MissingReturnAnnotationError(Error):
|
|
58
82
|
title: ClassVar[str] = "Missing type annotation"
|
|
@@ -67,6 +91,43 @@ class MissingReturnAnnotationError(Error):
|
|
|
67
91
|
func: str
|
|
68
92
|
|
|
69
93
|
|
|
94
|
+
@dataclass(frozen=True)
|
|
95
|
+
class InvalidSelfError(Error):
|
|
96
|
+
title: ClassVar[str] = "Invalid self annotation"
|
|
97
|
+
span_label: ClassVar[str] = "`{self_arg}` must be of type `{self_ty}`"
|
|
98
|
+
self_arg: str
|
|
99
|
+
self_ty: Type
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@dataclass(frozen=True)
|
|
103
|
+
class SelfParamsShadowedError(Error):
|
|
104
|
+
title: ClassVar[str] = "Shadowed generic parameters"
|
|
105
|
+
span_label: ClassVar[str] = (
|
|
106
|
+
"Cannot infer type for `{self_arg}` since parameter `{param}` of "
|
|
107
|
+
"`{ty_defn.name}` is shadowed"
|
|
108
|
+
)
|
|
109
|
+
param: str
|
|
110
|
+
ty_defn: "TypeDef"
|
|
111
|
+
self_arg: str
|
|
112
|
+
|
|
113
|
+
@dataclass(frozen=True)
|
|
114
|
+
class ExplicitHelp(Help):
|
|
115
|
+
span_label: ClassVar[str] = (
|
|
116
|
+
"Consider specifying the type explicitly: `{suggestion}`"
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
@property
|
|
120
|
+
def suggestion(self) -> str:
|
|
121
|
+
parent = self._parent
|
|
122
|
+
assert isinstance(parent, SelfParamsShadowedError)
|
|
123
|
+
params = (
|
|
124
|
+
f"[{', '.join(f'?{p.name}' for p in parent.ty_defn.params)}]"
|
|
125
|
+
if parent.ty_defn.params
|
|
126
|
+
else ""
|
|
127
|
+
)
|
|
128
|
+
return f'{parent.self_arg}: "{parent.ty_defn.name}{params}"'
|
|
129
|
+
|
|
130
|
+
|
|
70
131
|
def check_global_func_def(
|
|
71
132
|
func_def: ast.FunctionDef, ty: FunctionType, globals: Globals
|
|
72
133
|
) -> CheckedCFG[Place]:
|
|
@@ -176,9 +237,16 @@ def check_nested_func_def(
|
|
|
176
237
|
return with_loc(func_def, checked_def)
|
|
177
238
|
|
|
178
239
|
|
|
179
|
-
def check_signature(
|
|
240
|
+
def check_signature(
|
|
241
|
+
func_def: ast.FunctionDef, globals: Globals, def_id: DefId | None = None
|
|
242
|
+
) -> FunctionType:
|
|
180
243
|
"""Checks the signature of a function definition and returns the corresponding
|
|
181
|
-
Guppy type.
|
|
244
|
+
Guppy type.
|
|
245
|
+
|
|
246
|
+
If this is a method, then the `DefId` of the associated parent type should also be
|
|
247
|
+
passed. This will be used to check or infer the type annotation for the `self`
|
|
248
|
+
argument.
|
|
249
|
+
"""
|
|
182
250
|
if len(func_def.args.posonlyargs) != 0:
|
|
183
251
|
raise GuppyError(
|
|
184
252
|
UnsupportedError(func_def.args.posonlyargs[0], "Positional-only parameters")
|
|
@@ -208,26 +276,32 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
|
|
|
208
276
|
param_var_mapping: dict[str, Parameter] = {}
|
|
209
277
|
if sys.version_info >= (3, 12):
|
|
210
278
|
for i, param_node in enumerate(func_def.type_params):
|
|
211
|
-
param = parse_parameter(param_node, i, globals)
|
|
279
|
+
param = parse_parameter(param_node, i, globals, param_var_mapping)
|
|
212
280
|
param_var_mapping[param.name] = param
|
|
213
281
|
|
|
214
|
-
|
|
282
|
+
# Figure out if this is a method
|
|
283
|
+
self_defn: TypeDef | None = None
|
|
284
|
+
if def_id is not None and def_id in DEF_STORE.impl_parents:
|
|
285
|
+
self_defn = cast(TypeDef, ENGINE.get_checked(DEF_STORE.impl_parents[def_id]))
|
|
286
|
+
assert isinstance(self_defn, TypeDef)
|
|
287
|
+
|
|
288
|
+
inputs = []
|
|
215
289
|
input_names = []
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
290
|
+
ctx = TypeParsingCtx(globals, param_var_mapping, allow_free_vars=True)
|
|
291
|
+
for i, inp in enumerate(func_def.args.args):
|
|
292
|
+
# Special handling for `self` arguments. Note that `__new__` is excluded here
|
|
293
|
+
# since it's not a method so doesn't take `self`.
|
|
294
|
+
if self_defn and i == 0 and func_def.name != "__new__":
|
|
295
|
+
input = parse_self_arg(inp, self_defn, ctx)
|
|
296
|
+
ctx = replace(ctx, self_ty=input.ty)
|
|
297
|
+
else:
|
|
298
|
+
ty_ast = inp.annotation
|
|
299
|
+
if ty_ast is None:
|
|
300
|
+
raise GuppyError(MissingArgAnnotationError(inp))
|
|
301
|
+
input = parse_function_arg_annotation(ty_ast, inp.arg, ctx)
|
|
302
|
+
inputs.append(input)
|
|
221
303
|
input_names.append(inp.arg)
|
|
222
|
-
|
|
223
|
-
input_nodes,
|
|
224
|
-
func_def.returns,
|
|
225
|
-
input_names,
|
|
226
|
-
func_def,
|
|
227
|
-
globals,
|
|
228
|
-
param_var_mapping,
|
|
229
|
-
True,
|
|
230
|
-
)
|
|
304
|
+
output = type_from_ast(func_def.returns, ctx)
|
|
231
305
|
return FunctionType(
|
|
232
306
|
inputs,
|
|
233
307
|
output,
|
|
@@ -236,6 +310,81 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
|
|
|
236
310
|
)
|
|
237
311
|
|
|
238
312
|
|
|
313
|
+
def parse_self_arg(arg: ast.arg, self_defn: TypeDef, ctx: TypeParsingCtx) -> FuncInput:
|
|
314
|
+
"""Handles parsing of the `self` argument on methods.
|
|
315
|
+
|
|
316
|
+
This argument is special since its type annotation may be omitted. Furthermore, if a
|
|
317
|
+
type is provided then it must match the parent type.
|
|
318
|
+
"""
|
|
319
|
+
assert self_defn.params is not None
|
|
320
|
+
if arg.annotation is None:
|
|
321
|
+
return handle_implicit_self_arg(arg, self_defn, ctx)
|
|
322
|
+
|
|
323
|
+
# If the user has provided an annotation for `self`, then we go ahead and parse it.
|
|
324
|
+
# However, in the annotation the user is also allowed to use `Self`, so we have to
|
|
325
|
+
# specify a `self_ty` in the context.
|
|
326
|
+
self_ty_head = self_defn.check_instantiate(
|
|
327
|
+
[param.to_existential()[0] for param in self_defn.params]
|
|
328
|
+
)
|
|
329
|
+
self_ty_placeholder = ExistentialTypeVar.fresh(
|
|
330
|
+
"Self", copyable=self_ty_head.copyable, droppable=self_ty_head.droppable
|
|
331
|
+
)
|
|
332
|
+
assert ctx.self_ty is None
|
|
333
|
+
ctx = replace(ctx, self_ty=self_ty_placeholder)
|
|
334
|
+
user_ty, user_flags = type_with_flags_from_ast(arg.annotation, ctx)
|
|
335
|
+
|
|
336
|
+
# If the user just annotates `self: Self` then we can fall back to the case where
|
|
337
|
+
# no annotation is provided at all
|
|
338
|
+
if user_ty == self_ty_placeholder:
|
|
339
|
+
return handle_implicit_self_arg(arg, self_defn, ctx, user_flags)
|
|
340
|
+
|
|
341
|
+
# Annotations like `self: Foo[Self]` are not allowed (would be an infinite type)
|
|
342
|
+
if self_ty_placeholder in user_ty.unsolved_vars:
|
|
343
|
+
raise GuppyError(RecursiveSelfError(arg.annotation, arg.arg))
|
|
344
|
+
|
|
345
|
+
# Check that the annotation matches the parent type. We can do this by unifying with
|
|
346
|
+
# the expected self type where all params are instantiated with unification vars
|
|
347
|
+
subst = unify(user_ty, self_ty_head, {})
|
|
348
|
+
if subst is None:
|
|
349
|
+
raise GuppyError(InvalidSelfError(arg.annotation, arg.arg, self_ty_head))
|
|
350
|
+
|
|
351
|
+
return check_function_arg(user_ty, user_flags, arg, arg.arg, ctx)
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def handle_implicit_self_arg(
|
|
355
|
+
arg: ast.arg,
|
|
356
|
+
self_defn: TypeDef,
|
|
357
|
+
ctx: TypeParsingCtx,
|
|
358
|
+
flags: InputFlags = InputFlags.NoFlags,
|
|
359
|
+
) -> FuncInput:
|
|
360
|
+
"""Handles the case where no annotation for `self` is provided.
|
|
361
|
+
|
|
362
|
+
Generates the most generic annotation that is possible by making the function as
|
|
363
|
+
generic as the parent type.
|
|
364
|
+
"""
|
|
365
|
+
# Check that the user hasn't shadowed some of the parent type parameters using a
|
|
366
|
+
# Python 3.12 style parameter declaration
|
|
367
|
+
assert self_defn.params is not None
|
|
368
|
+
shadowed_params = [
|
|
369
|
+
param for param in self_defn.params if param.name in ctx.param_var_mapping
|
|
370
|
+
]
|
|
371
|
+
if shadowed_params:
|
|
372
|
+
param = shadowed_params.pop()
|
|
373
|
+
err = SelfParamsShadowedError(arg, param.name, self_defn, arg.arg)
|
|
374
|
+
err.add_sub_diagnostic(SelfParamsShadowedError.ExplicitHelp(arg))
|
|
375
|
+
raise GuppyError(err)
|
|
376
|
+
|
|
377
|
+
# The generic params inherited from the parent type should appear first in the
|
|
378
|
+
# parameter list, so we have to shift the existing ones
|
|
379
|
+
for name, param in ctx.param_var_mapping.items():
|
|
380
|
+
ctx.param_var_mapping[name] = param.with_idx(param.idx + len(self_defn.params))
|
|
381
|
+
|
|
382
|
+
ctx.param_var_mapping.update({param.name: param for param in self_defn.params})
|
|
383
|
+
self_args = [param.to_bound() for param in self_defn.params]
|
|
384
|
+
self_ty = self_defn.check_instantiate(self_args, loc=arg)
|
|
385
|
+
return check_function_arg(self_ty, flags, arg, arg.arg, ctx)
|
|
386
|
+
|
|
387
|
+
|
|
239
388
|
def parse_function_with_docstring(
|
|
240
389
|
func_ast: ast.FunctionDef,
|
|
241
390
|
) -> tuple[ast.FunctionDef, str | None]:
|
|
@@ -52,6 +52,7 @@ from guppylang_internals.error import GuppyError, GuppyTypeError
|
|
|
52
52
|
from guppylang_internals.nodes import (
|
|
53
53
|
AnyCall,
|
|
54
54
|
BarrierExpr,
|
|
55
|
+
CheckedModifiedBlock,
|
|
55
56
|
CheckedNestedFunctionDef,
|
|
56
57
|
DesugaredArrayComp,
|
|
57
58
|
DesugaredGenerator,
|
|
@@ -621,6 +622,70 @@ class BBLinearityChecker(ast.NodeVisitor):
|
|
|
621
622
|
elif not place.ty.copyable:
|
|
622
623
|
raise GuppyTypeError(ComprAlreadyUsedError(use.node, place, use.kind))
|
|
623
624
|
|
|
625
|
+
def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None:
|
|
626
|
+
# Linear usage of variables in a with statement
|
|
627
|
+
# ```
|
|
628
|
+
# with control(c1, c2, ...):
|
|
629
|
+
# body(q1, q2, ...) # captured variables
|
|
630
|
+
# ````
|
|
631
|
+
# is the same as to assume that this is a function call
|
|
632
|
+
# `WithCtrl(q1, q2, ..., c1, c2, ...)`
|
|
633
|
+
# where `WithCtrl` is a function that takes the control as mutable references.
|
|
634
|
+
# Therefore, we apply the same linearity rules as for function arguments.
|
|
635
|
+
# ```
|
|
636
|
+
# def WithCtrl(q1, q2, ..., c1, c2, ...):
|
|
637
|
+
# body(q1, q2, ...)
|
|
638
|
+
# ```
|
|
639
|
+
|
|
640
|
+
# check control
|
|
641
|
+
for ctrl in node.control:
|
|
642
|
+
for arg in ctrl.ctrl:
|
|
643
|
+
if isinstance(arg, PlaceNode):
|
|
644
|
+
self.visit_PlaceNode(arg, use_kind=UseKind.BORROW, is_call_arg=None)
|
|
645
|
+
else:
|
|
646
|
+
ty = get_type(arg)
|
|
647
|
+
unnamed_err = UnnamedExprNotUsedError(arg, ty)
|
|
648
|
+
unnamed_err.add_sub_diagnostic(UnnamedExprNotUsedError.Fix(None))
|
|
649
|
+
raise GuppyTypeError(unnamed_err)
|
|
650
|
+
|
|
651
|
+
# check power
|
|
652
|
+
for power in node.power:
|
|
653
|
+
if isinstance(power.iter, PlaceNode):
|
|
654
|
+
self.visit_PlaceNode(
|
|
655
|
+
power.iter, use_kind=UseKind.CONSUME, is_call_arg=None
|
|
656
|
+
)
|
|
657
|
+
else:
|
|
658
|
+
self.visit(power.iter)
|
|
659
|
+
|
|
660
|
+
# check captured variables
|
|
661
|
+
for var, use in node.captured.values():
|
|
662
|
+
for place in leaf_places(var):
|
|
663
|
+
use_kind = (
|
|
664
|
+
UseKind.BORROW if InputFlags.Inout in var.flags else UseKind.CONSUME
|
|
665
|
+
)
|
|
666
|
+
|
|
667
|
+
x = place.id
|
|
668
|
+
if (prev_use := self.scope.used(x)) and not place.ty.copyable:
|
|
669
|
+
used_err = AlreadyUsedError(use, place, use_kind)
|
|
670
|
+
used_err.add_sub_diagnostic(
|
|
671
|
+
AlreadyUsedError.PrevUse(prev_use.node, prev_use.kind)
|
|
672
|
+
)
|
|
673
|
+
if has_explicit_copy(place.ty):
|
|
674
|
+
used_err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None))
|
|
675
|
+
raise GuppyError(used_err)
|
|
676
|
+
self.scope.use(x, node, use_kind)
|
|
677
|
+
|
|
678
|
+
# reassign controls
|
|
679
|
+
for ctrl in node.control:
|
|
680
|
+
for arg in ctrl.ctrl:
|
|
681
|
+
assert isinstance(arg, PlaceNode) # Checked above
|
|
682
|
+
self._reassign_single_inout_arg(arg.place, arg.place.defined_at or arg)
|
|
683
|
+
|
|
684
|
+
# reassign captured variables
|
|
685
|
+
for var, use in node.captured.values():
|
|
686
|
+
if InputFlags.Inout in var.flags:
|
|
687
|
+
self._reassign_single_inout_arg(var, var.defined_at or use)
|
|
688
|
+
|
|
624
689
|
|
|
625
690
|
def leaf_places(place: Place) -> Iterator[Place]:
|
|
626
691
|
"""Returns all leaf descendant projections of a place."""
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""Type checking code for modifiers."""
|
|
2
|
+
|
|
3
|
+
import ast
|
|
4
|
+
|
|
5
|
+
from guppylang_internals.ast_util import loop_in_ast, with_loc
|
|
6
|
+
from guppylang_internals.cfg.bb import BB
|
|
7
|
+
from guppylang_internals.checker.cfg_checker import check_cfg
|
|
8
|
+
from guppylang_internals.checker.core import Context, Variable
|
|
9
|
+
from guppylang_internals.checker.errors.generic import InvalidUnderDagger
|
|
10
|
+
from guppylang_internals.definition.common import DefId
|
|
11
|
+
from guppylang_internals.error import GuppyError
|
|
12
|
+
from guppylang_internals.nodes import CheckedModifiedBlock, ModifiedBlock
|
|
13
|
+
from guppylang_internals.tys.ty import (
|
|
14
|
+
FuncInput,
|
|
15
|
+
FunctionType,
|
|
16
|
+
InputFlags,
|
|
17
|
+
NoneType,
|
|
18
|
+
Type,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def check_modified_block(
|
|
23
|
+
modified_block: ModifiedBlock, bb: BB, ctx: Context
|
|
24
|
+
) -> CheckedModifiedBlock:
|
|
25
|
+
"""Type checks a modifier definition."""
|
|
26
|
+
cfg = modified_block.cfg
|
|
27
|
+
|
|
28
|
+
# Find captured variables
|
|
29
|
+
parent_cfg = bb.containing_cfg
|
|
30
|
+
def_ass_before = ctx.locals.keys()
|
|
31
|
+
maybe_ass_before = def_ass_before | parent_cfg.maybe_ass_before[bb]
|
|
32
|
+
|
|
33
|
+
cfg.analyze(def_ass_before, maybe_ass_before, [])
|
|
34
|
+
captured = {
|
|
35
|
+
x: (_set_inout_if_non_copyable(ctx.locals[x]), using_bb.vars.used[x])
|
|
36
|
+
for x, using_bb in cfg.live_before[cfg.entry_bb].items()
|
|
37
|
+
if x in ctx.locals
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
# We do not allow any assignments if it is daggered.
|
|
41
|
+
if modified_block.is_dagger():
|
|
42
|
+
for stmt in modified_block.body:
|
|
43
|
+
loops = loop_in_ast(stmt)
|
|
44
|
+
if len(loops) != 0:
|
|
45
|
+
loop = next(iter(loops))
|
|
46
|
+
err = InvalidUnderDagger(loop, "Loop")
|
|
47
|
+
err.add_sub_diagnostic(
|
|
48
|
+
InvalidUnderDagger.Dagger(modified_block.span_ctxt_manager())
|
|
49
|
+
)
|
|
50
|
+
raise GuppyError(err)
|
|
51
|
+
|
|
52
|
+
for cfg_bb in cfg.bbs:
|
|
53
|
+
if cfg_bb.vars.assigned:
|
|
54
|
+
_, v = next(iter(cfg_bb.vars.assigned.items()))
|
|
55
|
+
err = InvalidUnderDagger(v, "Assignment")
|
|
56
|
+
err.add_sub_diagnostic(
|
|
57
|
+
InvalidUnderDagger.Dagger(modified_block.span_ctxt_manager())
|
|
58
|
+
)
|
|
59
|
+
raise GuppyError(err)
|
|
60
|
+
|
|
61
|
+
# The other checks are done in unitary checking.
|
|
62
|
+
# e.g. call to non-unitary function in a unitary modifier.
|
|
63
|
+
|
|
64
|
+
# Construct inputs for checking the body CFG
|
|
65
|
+
inputs = [v for v, _ in captured.values()]
|
|
66
|
+
inputs = non_copyable_front_others_back(inputs)
|
|
67
|
+
def_id = DefId.fresh()
|
|
68
|
+
globals = ctx.globals
|
|
69
|
+
|
|
70
|
+
# TODO: Ad hoc name for the new function
|
|
71
|
+
# This name could be printed in error messages, for example,
|
|
72
|
+
# when the linearity checker fails in the modifier body
|
|
73
|
+
checked_cfg = check_cfg(cfg, inputs, NoneType(), {}, "__modified__()", globals)
|
|
74
|
+
func_ty = check_modified_block_signature(checked_cfg.input_tys)
|
|
75
|
+
|
|
76
|
+
checked_modifier = CheckedModifiedBlock(
|
|
77
|
+
def_id,
|
|
78
|
+
checked_cfg,
|
|
79
|
+
func_ty,
|
|
80
|
+
captured,
|
|
81
|
+
modified_block.dagger,
|
|
82
|
+
modified_block.control,
|
|
83
|
+
modified_block.power,
|
|
84
|
+
**dict(ast.iter_fields(modified_block)),
|
|
85
|
+
)
|
|
86
|
+
return with_loc(modified_block, checked_modifier)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _set_inout_if_non_copyable(var: Variable) -> Variable:
|
|
90
|
+
"""Set the `inout` flag if the variable is non-copyable."""
|
|
91
|
+
if not var.ty.copyable:
|
|
92
|
+
return var.add_flags(InputFlags.Inout)
|
|
93
|
+
else:
|
|
94
|
+
return var
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def check_modified_block_signature(input_tys: list[Type]) -> FunctionType:
|
|
98
|
+
"""Check and create the signature of a function definition for a body
|
|
99
|
+
of a `With` block."""
|
|
100
|
+
|
|
101
|
+
func_ty = FunctionType(
|
|
102
|
+
[
|
|
103
|
+
FuncInput(t, InputFlags.Inout if not t.copyable else InputFlags.NoFlags)
|
|
104
|
+
for t in input_tys
|
|
105
|
+
],
|
|
106
|
+
NoneType(),
|
|
107
|
+
)
|
|
108
|
+
return func_ty
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def non_copyable_front_others_back(v: list[Variable]) -> list[Variable]:
|
|
112
|
+
"""Reorder variables so that linear ones come first, preserving the relative order
|
|
113
|
+
of linear and non-linear variables."""
|
|
114
|
+
linear_vars = [x for x in v if not x.ty.copyable]
|
|
115
|
+
non_linear_vars = [x for x in v if x.ty.copyable]
|
|
116
|
+
return linear_vars + non_linear_vars
|
|
@@ -42,7 +42,9 @@ from guppylang_internals.checker.errors.type_errors import (
|
|
|
42
42
|
MissingReturnValueError,
|
|
43
43
|
StarredTupleUnpackError,
|
|
44
44
|
TypeInferenceError,
|
|
45
|
+
TypeMismatchError,
|
|
45
46
|
UnpackableError,
|
|
47
|
+
WrongNumberOfArgsError,
|
|
46
48
|
WrongNumberOfUnpacksError,
|
|
47
49
|
)
|
|
48
50
|
from guppylang_internals.checker.expr_checker import (
|
|
@@ -58,6 +60,7 @@ from guppylang_internals.nodes import (
|
|
|
58
60
|
DesugaredArrayComp,
|
|
59
61
|
IterableUnpack,
|
|
60
62
|
MakeIter,
|
|
63
|
+
ModifiedBlock,
|
|
61
64
|
NestedFunctionDef,
|
|
62
65
|
PlaceNode,
|
|
63
66
|
TupleUnpack,
|
|
@@ -73,13 +76,15 @@ from guppylang_internals.tys.builtin import (
|
|
|
73
76
|
is_sized_iter_type,
|
|
74
77
|
nat_type,
|
|
75
78
|
)
|
|
76
|
-
from guppylang_internals.tys.const import ConstValue
|
|
79
|
+
from guppylang_internals.tys.const import ConstValue, ExistentialConstVar
|
|
77
80
|
from guppylang_internals.tys.parsing import type_from_ast
|
|
81
|
+
from guppylang_internals.tys.qubit import is_qubit_ty, qubit_ty
|
|
78
82
|
from guppylang_internals.tys.subst import Subst
|
|
79
83
|
from guppylang_internals.tys.ty import (
|
|
80
84
|
ExistentialTypeVar,
|
|
81
85
|
FunctionType,
|
|
82
86
|
NoneType,
|
|
87
|
+
NumericType,
|
|
83
88
|
StructType,
|
|
84
89
|
TupleType,
|
|
85
90
|
Type,
|
|
@@ -356,7 +361,7 @@ class StmtChecker(AstVisitor[BBStatement]):
|
|
|
356
361
|
def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt:
|
|
357
362
|
if node.value is None:
|
|
358
363
|
raise GuppyError(UnsupportedError(node, "Variable declarations"))
|
|
359
|
-
ty = type_from_ast(node.annotation, self.ctx.
|
|
364
|
+
ty = type_from_ast(node.annotation, self.ctx.parsing_ctx)
|
|
360
365
|
node.value, subst = self._check_expr(node.value, ty)
|
|
361
366
|
assert not ty.unsolved_vars # `ty` must be closed!
|
|
362
367
|
assert len(subst) == 0
|
|
@@ -398,6 +403,48 @@ class StmtChecker(AstVisitor[BBStatement]):
|
|
|
398
403
|
self.ctx.locals[func_def.name] = Variable(func_def.name, func_def.ty, func_def)
|
|
399
404
|
return func_def
|
|
400
405
|
|
|
406
|
+
def visit_ModifiedBlock(self, node: ModifiedBlock) -> ast.stmt:
|
|
407
|
+
from guppylang_internals.checker.modifier_checker import check_modified_block
|
|
408
|
+
|
|
409
|
+
if not self.bb:
|
|
410
|
+
raise InternalGuppyError("BB required to check with block!")
|
|
411
|
+
|
|
412
|
+
# check the body of the modified block
|
|
413
|
+
modified_block = check_modified_block(node, self.bb, self.ctx)
|
|
414
|
+
|
|
415
|
+
# check the arguments of the control and power.
|
|
416
|
+
for control in modified_block.control:
|
|
417
|
+
ctrl = control.ctrl
|
|
418
|
+
# This case is handled during CFG construction.
|
|
419
|
+
assert len(ctrl) > 0
|
|
420
|
+
ctrl[0], ty = self._synth_expr(ctrl[0])
|
|
421
|
+
|
|
422
|
+
if is_array_type(ty):
|
|
423
|
+
if len(ctrl) > 1:
|
|
424
|
+
span = Span(to_span(control.func).end, to_span(control).end)
|
|
425
|
+
raise GuppyError(WrongNumberOfArgsError(span, 1, len(control.args)))
|
|
426
|
+
element_ty = get_element_type(ty)
|
|
427
|
+
if not is_qubit_ty(element_ty):
|
|
428
|
+
n = ExistentialConstVar.fresh(
|
|
429
|
+
"n", NumericType(NumericType.Kind.Nat)
|
|
430
|
+
)
|
|
431
|
+
dummy_array_ty = array_type(qubit_ty(), n)
|
|
432
|
+
raise GuppyTypeError(TypeMismatchError(ctrl[0], dummy_array_ty, ty))
|
|
433
|
+
control.qubit_num = get_array_length(ty)
|
|
434
|
+
else:
|
|
435
|
+
for i in range(len(ctrl)):
|
|
436
|
+
ctrl[i], subst = self._check_expr(ctrl[i], qubit_ty())
|
|
437
|
+
assert len(subst) == 0
|
|
438
|
+
control.qubit_num = len(ctrl)
|
|
439
|
+
|
|
440
|
+
for power in node.power:
|
|
441
|
+
power.iter, subst = self._check_expr(
|
|
442
|
+
power.iter, NumericType(NumericType.Kind.Nat)
|
|
443
|
+
)
|
|
444
|
+
assert len(subst) == 0
|
|
445
|
+
|
|
446
|
+
return modified_block
|
|
447
|
+
|
|
401
448
|
def visit_If(self, node: ast.If) -> None:
|
|
402
449
|
raise InternalGuppyError("Control-flow statement should not be present here.")
|
|
403
450
|
|