guppylang-internals 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +21 -0
- guppylang_internals/cfg/bb.py +20 -0
- guppylang_internals/cfg/builder.py +118 -5
- guppylang_internals/cfg/cfg.py +3 -0
- guppylang_internals/checker/cfg_checker.py +6 -0
- guppylang_internals/checker/core.py +5 -2
- guppylang_internals/checker/errors/generic.py +32 -1
- guppylang_internals/checker/errors/type_errors.py +14 -0
- guppylang_internals/checker/errors/wasm.py +7 -4
- guppylang_internals/checker/expr_checker.py +58 -17
- guppylang_internals/checker/func_checker.py +18 -14
- guppylang_internals/checker/linearity_checker.py +67 -10
- guppylang_internals/checker/modifier_checker.py +120 -0
- guppylang_internals/checker/stmt_checker.py +48 -1
- guppylang_internals/checker/unitary_checker.py +132 -0
- guppylang_internals/compiler/cfg_compiler.py +7 -6
- guppylang_internals/compiler/core.py +93 -56
- guppylang_internals/compiler/expr_compiler.py +72 -168
- guppylang_internals/compiler/modifier_compiler.py +176 -0
- guppylang_internals/compiler/stmt_compiler.py +15 -8
- guppylang_internals/decorator.py +86 -7
- guppylang_internals/definition/custom.py +39 -1
- guppylang_internals/definition/declaration.py +9 -6
- guppylang_internals/definition/function.py +12 -2
- guppylang_internals/definition/parameter.py +8 -3
- guppylang_internals/definition/pytket_circuits.py +14 -41
- guppylang_internals/definition/struct.py +13 -7
- guppylang_internals/definition/ty.py +3 -3
- guppylang_internals/definition/wasm.py +42 -10
- guppylang_internals/engine.py +9 -3
- guppylang_internals/experimental.py +5 -0
- guppylang_internals/nodes.py +147 -24
- guppylang_internals/std/_internal/checker.py +13 -108
- guppylang_internals/std/_internal/compiler/array.py +95 -283
- guppylang_internals/std/_internal/compiler/list.py +1 -1
- guppylang_internals/std/_internal/compiler/platform.py +153 -0
- guppylang_internals/std/_internal/compiler/prelude.py +12 -4
- guppylang_internals/std/_internal/compiler/tket_exts.py +8 -2
- guppylang_internals/std/_internal/debug.py +18 -9
- guppylang_internals/std/_internal/util.py +1 -1
- guppylang_internals/tracing/object.py +10 -0
- guppylang_internals/tracing/unpacking.py +19 -20
- guppylang_internals/tys/arg.py +18 -3
- guppylang_internals/tys/builtin.py +2 -5
- guppylang_internals/tys/const.py +33 -4
- guppylang_internals/tys/errors.py +23 -1
- guppylang_internals/tys/param.py +31 -16
- guppylang_internals/tys/parsing.py +11 -24
- guppylang_internals/tys/printing.py +2 -8
- guppylang_internals/tys/qubit.py +62 -0
- guppylang_internals/tys/subst.py +8 -26
- guppylang_internals/tys/ty.py +91 -85
- guppylang_internals/wasm_util.py +129 -0
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +6 -5
- guppylang_internals-0.26.0.dist-info/RECORD +104 -0
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
- guppylang_internals-0.24.0.dist-info/RECORD +0 -98
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -16,6 +16,7 @@ from guppylang_internals.cfg.builder import CFGBuilder
|
|
|
16
16
|
from guppylang_internals.checker.cfg_checker import CheckedCFG, check_cfg
|
|
17
17
|
from guppylang_internals.checker.core import Context, Globals, Place, Variable
|
|
18
18
|
from guppylang_internals.checker.errors.generic import UnsupportedError
|
|
19
|
+
from guppylang_internals.checker.unitary_checker import check_invalid_under_dagger
|
|
19
20
|
from guppylang_internals.definition.common import DefId
|
|
20
21
|
from guppylang_internals.definition.ty import TypeDef
|
|
21
22
|
from guppylang_internals.diagnostic import Error, Help, Note
|
|
@@ -37,6 +38,7 @@ from guppylang_internals.tys.ty import (
|
|
|
37
38
|
InputFlags,
|
|
38
39
|
NoneType,
|
|
39
40
|
Type,
|
|
41
|
+
UnitaryFlags,
|
|
40
42
|
unify,
|
|
41
43
|
)
|
|
42
44
|
|
|
@@ -134,12 +136,13 @@ def check_global_func_def(
|
|
|
134
136
|
"""Type checks a top-level function definition."""
|
|
135
137
|
args = func_def.args.args
|
|
136
138
|
returns_none = isinstance(ty.output, NoneType)
|
|
137
|
-
assert
|
|
139
|
+
assert all(inp.name is not None for inp in ty.inputs)
|
|
138
140
|
|
|
139
|
-
|
|
141
|
+
check_invalid_under_dagger(func_def, ty.unitary_flags)
|
|
142
|
+
cfg = CFGBuilder().build(func_def.body, returns_none, globals, ty.unitary_flags)
|
|
140
143
|
inputs = [
|
|
141
|
-
Variable(
|
|
142
|
-
for
|
|
144
|
+
Variable(cast(str, inp.name), inp.ty, loc, inp.flags, is_func_input=True)
|
|
145
|
+
for inp, loc in zip(ty.inputs, args, strict=True)
|
|
143
146
|
# Comptime inputs are turned into generic args, so are not included here
|
|
144
147
|
if InputFlags.Comptime not in inp.flags
|
|
145
148
|
]
|
|
@@ -150,7 +153,9 @@ def check_global_func_def(
|
|
|
150
153
|
|
|
151
154
|
|
|
152
155
|
def check_nested_func_def(
|
|
153
|
-
func_def: NestedFunctionDef,
|
|
156
|
+
func_def: NestedFunctionDef,
|
|
157
|
+
bb: BB,
|
|
158
|
+
ctx: Context,
|
|
154
159
|
) -> CheckedNestedFunctionDef:
|
|
155
160
|
"""Type checks a local (nested) function definition."""
|
|
156
161
|
func_ty = check_signature(func_def, ctx.globals)
|
|
@@ -194,10 +199,8 @@ def check_nested_func_def(
|
|
|
194
199
|
|
|
195
200
|
# Construct inputs for checking the body CFG
|
|
196
201
|
inputs = [v for v, _ in captured.values()] + [
|
|
197
|
-
Variable(
|
|
198
|
-
for
|
|
199
|
-
zip(func_ty.input_names, func_ty.inputs, strict=True)
|
|
200
|
-
)
|
|
202
|
+
Variable(cast(str, inp.name), inp.ty, arg, inp.flags, is_func_input=True)
|
|
203
|
+
for arg, inp in zip(func_def.args.args, func_ty.inputs, strict=True)
|
|
201
204
|
# Comptime inputs are turned into generic args, so are not included here
|
|
202
205
|
if InputFlags.Comptime not in inp.flags
|
|
203
206
|
]
|
|
@@ -238,7 +241,10 @@ def check_nested_func_def(
|
|
|
238
241
|
|
|
239
242
|
|
|
240
243
|
def check_signature(
|
|
241
|
-
func_def: ast.FunctionDef,
|
|
244
|
+
func_def: ast.FunctionDef,
|
|
245
|
+
globals: Globals,
|
|
246
|
+
def_id: DefId | None = None,
|
|
247
|
+
unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
|
|
242
248
|
) -> FunctionType:
|
|
243
249
|
"""Checks the signature of a function definition and returns the corresponding
|
|
244
250
|
Guppy type.
|
|
@@ -276,7 +282,7 @@ def check_signature(
|
|
|
276
282
|
param_var_mapping: dict[str, Parameter] = {}
|
|
277
283
|
if sys.version_info >= (3, 12):
|
|
278
284
|
for i, param_node in enumerate(func_def.type_params):
|
|
279
|
-
param = parse_parameter(param_node, i, globals)
|
|
285
|
+
param = parse_parameter(param_node, i, globals, param_var_mapping)
|
|
280
286
|
param_var_mapping[param.name] = param
|
|
281
287
|
|
|
282
288
|
# Figure out if this is a method
|
|
@@ -286,7 +292,6 @@ def check_signature(
|
|
|
286
292
|
assert isinstance(self_defn, TypeDef)
|
|
287
293
|
|
|
288
294
|
inputs = []
|
|
289
|
-
input_names = []
|
|
290
295
|
ctx = TypeParsingCtx(globals, param_var_mapping, allow_free_vars=True)
|
|
291
296
|
for i, inp in enumerate(func_def.args.args):
|
|
292
297
|
# Special handling for `self` arguments. Note that `__new__` is excluded here
|
|
@@ -300,13 +305,12 @@ def check_signature(
|
|
|
300
305
|
raise GuppyError(MissingArgAnnotationError(inp))
|
|
301
306
|
input = parse_function_arg_annotation(ty_ast, inp.arg, ctx)
|
|
302
307
|
inputs.append(input)
|
|
303
|
-
input_names.append(inp.arg)
|
|
304
308
|
output = type_from_ast(func_def.returns, ctx)
|
|
305
309
|
return FunctionType(
|
|
306
310
|
inputs,
|
|
307
311
|
output,
|
|
308
|
-
input_names,
|
|
309
312
|
sorted(param_var_mapping.values(), key=lambda v: v.idx),
|
|
313
|
+
unitary_flags=unitary_flags,
|
|
310
314
|
)
|
|
311
315
|
|
|
312
316
|
|
|
@@ -52,6 +52,7 @@ from guppylang_internals.error import GuppyError, GuppyTypeError
|
|
|
52
52
|
from guppylang_internals.nodes import (
|
|
53
53
|
AnyCall,
|
|
54
54
|
BarrierExpr,
|
|
55
|
+
CheckedModifiedBlock,
|
|
55
56
|
CheckedNestedFunctionDef,
|
|
56
57
|
DesugaredArrayComp,
|
|
57
58
|
DesugaredGenerator,
|
|
@@ -62,7 +63,6 @@ from guppylang_internals.nodes import (
|
|
|
62
63
|
LocalCall,
|
|
63
64
|
PartialApply,
|
|
64
65
|
PlaceNode,
|
|
65
|
-
ResultExpr,
|
|
66
66
|
StateResultExpr,
|
|
67
67
|
SubscriptAccessAndDrop,
|
|
68
68
|
TensorCall,
|
|
@@ -73,7 +73,6 @@ from guppylang_internals.tys.ty import (
|
|
|
73
73
|
FuncInput,
|
|
74
74
|
FunctionType,
|
|
75
75
|
InputFlags,
|
|
76
|
-
NoneType,
|
|
77
76
|
StructType,
|
|
78
77
|
TupleType,
|
|
79
78
|
Type,
|
|
@@ -450,13 +449,6 @@ class BBLinearityChecker(ast.NodeVisitor):
|
|
|
450
449
|
self._visit_call_args(node.func_ty, node)
|
|
451
450
|
self._reassign_inout_args(node.func_ty, node)
|
|
452
451
|
|
|
453
|
-
def visit_ResultExpr(self, node: ResultExpr) -> None:
|
|
454
|
-
ty = get_type(node.value)
|
|
455
|
-
flag = InputFlags.Inout if not ty.copyable else InputFlags.NoFlags
|
|
456
|
-
func_ty = FunctionType([FuncInput(ty, flag)], NoneType())
|
|
457
|
-
self._visit_call_args(func_ty, node)
|
|
458
|
-
self._reassign_inout_args(func_ty, node)
|
|
459
|
-
|
|
460
452
|
def visit_StateResultExpr(self, node: StateResultExpr) -> None:
|
|
461
453
|
self._visit_call_args(node.func_ty, node)
|
|
462
454
|
self._reassign_inout_args(node.func_ty, node)
|
|
@@ -581,7 +573,7 @@ class BBLinearityChecker(ast.NodeVisitor):
|
|
|
581
573
|
# can feed them through the loop. Note that we could also use non-local
|
|
582
574
|
# edges, but we can't handle them in lower parts of the stack yet :/
|
|
583
575
|
# TODO: Reinstate use of non-local edges.
|
|
584
|
-
# See https://github.com/
|
|
576
|
+
# See https://github.com/quantinuum/guppylang/issues/963
|
|
585
577
|
gen.used_outer_places = []
|
|
586
578
|
for x, use in inner_scope.used_parent.items():
|
|
587
579
|
place = inner_scope[x]
|
|
@@ -621,6 +613,70 @@ class BBLinearityChecker(ast.NodeVisitor):
|
|
|
621
613
|
elif not place.ty.copyable:
|
|
622
614
|
raise GuppyTypeError(ComprAlreadyUsedError(use.node, place, use.kind))
|
|
623
615
|
|
|
616
|
+
def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None:
|
|
617
|
+
# Linear usage of variables in a with statement
|
|
618
|
+
# ```
|
|
619
|
+
# with control(c1, c2, ...):
|
|
620
|
+
# body(q1, q2, ...) # captured variables
|
|
621
|
+
# ````
|
|
622
|
+
# is the same as to assume that this is a function call
|
|
623
|
+
# `WithCtrl(q1, q2, ..., c1, c2, ...)`
|
|
624
|
+
# where `WithCtrl` is a function that takes the control as mutable references.
|
|
625
|
+
# Therefore, we apply the same linearity rules as for function arguments.
|
|
626
|
+
# ```
|
|
627
|
+
# def WithCtrl(q1, q2, ..., c1, c2, ...):
|
|
628
|
+
# body(q1, q2, ...)
|
|
629
|
+
# ```
|
|
630
|
+
|
|
631
|
+
# check control
|
|
632
|
+
for ctrl in node.control:
|
|
633
|
+
for arg in ctrl.ctrl:
|
|
634
|
+
if isinstance(arg, PlaceNode):
|
|
635
|
+
self.visit_PlaceNode(arg, use_kind=UseKind.BORROW, is_call_arg=None)
|
|
636
|
+
else:
|
|
637
|
+
ty = get_type(arg)
|
|
638
|
+
unnamed_err = UnnamedExprNotUsedError(arg, ty)
|
|
639
|
+
unnamed_err.add_sub_diagnostic(UnnamedExprNotUsedError.Fix(None))
|
|
640
|
+
raise GuppyTypeError(unnamed_err)
|
|
641
|
+
|
|
642
|
+
# check power
|
|
643
|
+
for power in node.power:
|
|
644
|
+
if isinstance(power.iter, PlaceNode):
|
|
645
|
+
self.visit_PlaceNode(
|
|
646
|
+
power.iter, use_kind=UseKind.CONSUME, is_call_arg=None
|
|
647
|
+
)
|
|
648
|
+
else:
|
|
649
|
+
self.visit(power.iter)
|
|
650
|
+
|
|
651
|
+
# check captured variables
|
|
652
|
+
for var, use in node.captured.values():
|
|
653
|
+
for place in leaf_places(var):
|
|
654
|
+
use_kind = (
|
|
655
|
+
UseKind.BORROW if InputFlags.Inout in var.flags else UseKind.CONSUME
|
|
656
|
+
)
|
|
657
|
+
|
|
658
|
+
x = place.id
|
|
659
|
+
if (prev_use := self.scope.used(x)) and not place.ty.copyable:
|
|
660
|
+
used_err = AlreadyUsedError(use, place, use_kind)
|
|
661
|
+
used_err.add_sub_diagnostic(
|
|
662
|
+
AlreadyUsedError.PrevUse(prev_use.node, prev_use.kind)
|
|
663
|
+
)
|
|
664
|
+
if has_explicit_copy(place.ty):
|
|
665
|
+
used_err.add_sub_diagnostic(AlreadyUsedError.MakeCopy(None))
|
|
666
|
+
raise GuppyError(used_err)
|
|
667
|
+
self.scope.use(x, node, use_kind)
|
|
668
|
+
|
|
669
|
+
# reassign controls
|
|
670
|
+
for ctrl in node.control:
|
|
671
|
+
for arg in ctrl.ctrl:
|
|
672
|
+
assert isinstance(arg, PlaceNode) # Checked above
|
|
673
|
+
self._reassign_single_inout_arg(arg.place, arg.place.defined_at or arg)
|
|
674
|
+
|
|
675
|
+
# reassign captured variables
|
|
676
|
+
for var, use in node.captured.values():
|
|
677
|
+
if InputFlags.Inout in var.flags:
|
|
678
|
+
self._reassign_single_inout_arg(var, var.defined_at or use)
|
|
679
|
+
|
|
624
680
|
|
|
625
681
|
def leaf_places(place: Place) -> Iterator[Place]:
|
|
626
682
|
"""Returns all leaf descendant projections of a place."""
|
|
@@ -815,6 +871,7 @@ def check_cfg_linearity(
|
|
|
815
871
|
result_cfg.maybe_ass_before = {
|
|
816
872
|
checked[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs
|
|
817
873
|
}
|
|
874
|
+
result_cfg.unitary_flags = cfg.unitary_flags
|
|
818
875
|
for bb in cfg.bbs:
|
|
819
876
|
checked[bb].predecessors = [checked[pred] for pred in bb.predecessors]
|
|
820
877
|
checked[bb].successors = [checked[succ] for succ in bb.successors]
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Type checking code for modifiers."""
|
|
2
|
+
|
|
3
|
+
import ast
|
|
4
|
+
|
|
5
|
+
from guppylang_internals.ast_util import loop_in_ast, with_loc
|
|
6
|
+
from guppylang_internals.cfg.bb import BB
|
|
7
|
+
from guppylang_internals.checker.cfg_checker import check_cfg
|
|
8
|
+
from guppylang_internals.checker.core import Context, Variable
|
|
9
|
+
from guppylang_internals.checker.errors.generic import InvalidUnderDagger
|
|
10
|
+
from guppylang_internals.definition.common import DefId
|
|
11
|
+
from guppylang_internals.error import GuppyError
|
|
12
|
+
from guppylang_internals.nodes import CheckedModifiedBlock, ModifiedBlock
|
|
13
|
+
from guppylang_internals.tys.ty import (
|
|
14
|
+
FuncInput,
|
|
15
|
+
FunctionType,
|
|
16
|
+
InputFlags,
|
|
17
|
+
NoneType,
|
|
18
|
+
Type,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def check_modified_block(
|
|
23
|
+
modified_block: ModifiedBlock, bb: BB, ctx: Context
|
|
24
|
+
) -> CheckedModifiedBlock:
|
|
25
|
+
"""Type checks a modifier definition."""
|
|
26
|
+
cfg = modified_block.cfg
|
|
27
|
+
|
|
28
|
+
# Find captured variables
|
|
29
|
+
parent_cfg = bb.containing_cfg
|
|
30
|
+
def_ass_before = ctx.locals.keys()
|
|
31
|
+
maybe_ass_before = def_ass_before | parent_cfg.maybe_ass_before[bb]
|
|
32
|
+
|
|
33
|
+
cfg.analyze(def_ass_before, maybe_ass_before, [])
|
|
34
|
+
captured = {
|
|
35
|
+
x: (_set_inout_if_non_copyable(ctx.locals[x]), using_bb.vars.used[x])
|
|
36
|
+
for x, using_bb in cfg.live_before[cfg.entry_bb].items()
|
|
37
|
+
if x in ctx.locals
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
# We do not allow any assignments if it is daggered.
|
|
41
|
+
if modified_block.is_dagger():
|
|
42
|
+
for stmt in modified_block.body:
|
|
43
|
+
loops = loop_in_ast(stmt)
|
|
44
|
+
if len(loops) != 0:
|
|
45
|
+
loop = next(iter(loops))
|
|
46
|
+
err = InvalidUnderDagger(loop, "Loop")
|
|
47
|
+
err.add_sub_diagnostic(
|
|
48
|
+
InvalidUnderDagger.Dagger(modified_block.span_ctxt_manager())
|
|
49
|
+
)
|
|
50
|
+
raise GuppyError(err)
|
|
51
|
+
|
|
52
|
+
for cfg_bb in cfg.bbs:
|
|
53
|
+
if cfg_bb.vars.assigned:
|
|
54
|
+
_, v = next(iter(cfg_bb.vars.assigned.items()))
|
|
55
|
+
err = InvalidUnderDagger(v, "Assignment")
|
|
56
|
+
err.add_sub_diagnostic(
|
|
57
|
+
InvalidUnderDagger.Dagger(modified_block.span_ctxt_manager())
|
|
58
|
+
)
|
|
59
|
+
raise GuppyError(err)
|
|
60
|
+
|
|
61
|
+
# The other checks are done in unitary checking.
|
|
62
|
+
# e.g. call to non-unitary function in a unitary modifier.
|
|
63
|
+
|
|
64
|
+
# Construct inputs for checking the body CFG
|
|
65
|
+
inputs = [v for v, _ in captured.values()]
|
|
66
|
+
inputs = non_copyable_front_others_back(inputs)
|
|
67
|
+
def_id = DefId.fresh()
|
|
68
|
+
globals = ctx.globals
|
|
69
|
+
|
|
70
|
+
# TODO: Ad hoc name for the new function
|
|
71
|
+
# This name could be printed in error messages, for example,
|
|
72
|
+
# when the linearity checker fails in the modifier body
|
|
73
|
+
checked_cfg = check_cfg(cfg, inputs, NoneType(), {}, "__modified__()", globals)
|
|
74
|
+
func_ty = check_modified_block_signature(modified_block, checked_cfg.input_tys)
|
|
75
|
+
|
|
76
|
+
checked_modifier = CheckedModifiedBlock(
|
|
77
|
+
def_id,
|
|
78
|
+
checked_cfg,
|
|
79
|
+
func_ty,
|
|
80
|
+
captured,
|
|
81
|
+
modified_block.dagger,
|
|
82
|
+
modified_block.control,
|
|
83
|
+
modified_block.power,
|
|
84
|
+
**dict(ast.iter_fields(modified_block)),
|
|
85
|
+
)
|
|
86
|
+
return with_loc(modified_block, checked_modifier)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _set_inout_if_non_copyable(var: Variable) -> Variable:
|
|
90
|
+
"""Set the `inout` flag if the variable is non-copyable."""
|
|
91
|
+
if not var.ty.copyable:
|
|
92
|
+
return var.add_flags(InputFlags.Inout)
|
|
93
|
+
else:
|
|
94
|
+
return var
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def check_modified_block_signature(
|
|
98
|
+
modified_block: ModifiedBlock, input_tys: list[Type]
|
|
99
|
+
) -> FunctionType:
|
|
100
|
+
"""Check and create the signature of a function definition for a body
|
|
101
|
+
of a `With` block."""
|
|
102
|
+
unitary_flags = modified_block.flags()
|
|
103
|
+
|
|
104
|
+
func_ty = FunctionType(
|
|
105
|
+
[
|
|
106
|
+
FuncInput(t, InputFlags.Inout if not t.copyable else InputFlags.NoFlags)
|
|
107
|
+
for t in input_tys
|
|
108
|
+
],
|
|
109
|
+
NoneType(),
|
|
110
|
+
unitary_flags=unitary_flags,
|
|
111
|
+
)
|
|
112
|
+
return func_ty
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def non_copyable_front_others_back(v: list[Variable]) -> list[Variable]:
|
|
116
|
+
"""Reorder variables so that linear ones come first, preserving the relative order
|
|
117
|
+
of linear and non-linear variables."""
|
|
118
|
+
linear_vars = [x for x in v if not x.ty.copyable]
|
|
119
|
+
non_linear_vars = [x for x in v if x.ty.copyable]
|
|
120
|
+
return linear_vars + non_linear_vars
|
|
@@ -42,7 +42,9 @@ from guppylang_internals.checker.errors.type_errors import (
|
|
|
42
42
|
MissingReturnValueError,
|
|
43
43
|
StarredTupleUnpackError,
|
|
44
44
|
TypeInferenceError,
|
|
45
|
+
TypeMismatchError,
|
|
45
46
|
UnpackableError,
|
|
47
|
+
WrongNumberOfArgsError,
|
|
46
48
|
WrongNumberOfUnpacksError,
|
|
47
49
|
)
|
|
48
50
|
from guppylang_internals.checker.expr_checker import (
|
|
@@ -58,6 +60,7 @@ from guppylang_internals.nodes import (
|
|
|
58
60
|
DesugaredArrayComp,
|
|
59
61
|
IterableUnpack,
|
|
60
62
|
MakeIter,
|
|
63
|
+
ModifiedBlock,
|
|
61
64
|
NestedFunctionDef,
|
|
62
65
|
PlaceNode,
|
|
63
66
|
TupleUnpack,
|
|
@@ -73,13 +76,15 @@ from guppylang_internals.tys.builtin import (
|
|
|
73
76
|
is_sized_iter_type,
|
|
74
77
|
nat_type,
|
|
75
78
|
)
|
|
76
|
-
from guppylang_internals.tys.const import ConstValue
|
|
79
|
+
from guppylang_internals.tys.const import ConstValue, ExistentialConstVar
|
|
77
80
|
from guppylang_internals.tys.parsing import type_from_ast
|
|
81
|
+
from guppylang_internals.tys.qubit import is_qubit_ty, qubit_ty
|
|
78
82
|
from guppylang_internals.tys.subst import Subst
|
|
79
83
|
from guppylang_internals.tys.ty import (
|
|
80
84
|
ExistentialTypeVar,
|
|
81
85
|
FunctionType,
|
|
82
86
|
NoneType,
|
|
87
|
+
NumericType,
|
|
83
88
|
StructType,
|
|
84
89
|
TupleType,
|
|
85
90
|
Type,
|
|
@@ -398,6 +403,48 @@ class StmtChecker(AstVisitor[BBStatement]):
|
|
|
398
403
|
self.ctx.locals[func_def.name] = Variable(func_def.name, func_def.ty, func_def)
|
|
399
404
|
return func_def
|
|
400
405
|
|
|
406
|
+
def visit_ModifiedBlock(self, node: ModifiedBlock) -> ast.stmt:
|
|
407
|
+
from guppylang_internals.checker.modifier_checker import check_modified_block
|
|
408
|
+
|
|
409
|
+
if not self.bb:
|
|
410
|
+
raise InternalGuppyError("BB required to check with block!")
|
|
411
|
+
|
|
412
|
+
# check the body of the modified block
|
|
413
|
+
modified_block = check_modified_block(node, self.bb, self.ctx)
|
|
414
|
+
|
|
415
|
+
# check the arguments of the control and power.
|
|
416
|
+
for control in modified_block.control:
|
|
417
|
+
ctrl = control.ctrl
|
|
418
|
+
# This case is handled during CFG construction.
|
|
419
|
+
assert len(ctrl) > 0
|
|
420
|
+
ctrl[0], ty = self._synth_expr(ctrl[0])
|
|
421
|
+
|
|
422
|
+
if is_array_type(ty):
|
|
423
|
+
if len(ctrl) > 1:
|
|
424
|
+
span = Span(to_span(control.func).end, to_span(control).end)
|
|
425
|
+
raise GuppyError(WrongNumberOfArgsError(span, 1, len(control.args)))
|
|
426
|
+
element_ty = get_element_type(ty)
|
|
427
|
+
if not is_qubit_ty(element_ty):
|
|
428
|
+
n = ExistentialConstVar.fresh(
|
|
429
|
+
"n", NumericType(NumericType.Kind.Nat)
|
|
430
|
+
)
|
|
431
|
+
dummy_array_ty = array_type(qubit_ty(), n)
|
|
432
|
+
raise GuppyTypeError(TypeMismatchError(ctrl[0], dummy_array_ty, ty))
|
|
433
|
+
control.qubit_num = get_array_length(ty)
|
|
434
|
+
else:
|
|
435
|
+
for i in range(len(ctrl)):
|
|
436
|
+
ctrl[i], subst = self._check_expr(ctrl[i], qubit_ty())
|
|
437
|
+
assert len(subst) == 0
|
|
438
|
+
control.qubit_num = len(ctrl)
|
|
439
|
+
|
|
440
|
+
for power in node.power:
|
|
441
|
+
power.iter, subst = self._check_expr(
|
|
442
|
+
power.iter, NumericType(NumericType.Kind.Nat)
|
|
443
|
+
)
|
|
444
|
+
assert len(subst) == 0
|
|
445
|
+
|
|
446
|
+
return modified_block
|
|
447
|
+
|
|
401
448
|
def visit_If(self, node: ast.If) -> None:
|
|
402
449
|
raise InternalGuppyError("Control-flow statement should not be present here.")
|
|
403
450
|
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
|
|
3
|
+
from guppylang_internals.ast_util import find_nodes, get_type, loop_in_ast
|
|
4
|
+
from guppylang_internals.checker.cfg_checker import CheckedBB, CheckedCFG
|
|
5
|
+
from guppylang_internals.checker.core import Place, contains_subscript
|
|
6
|
+
from guppylang_internals.checker.errors.generic import (
|
|
7
|
+
InvalidUnderDagger,
|
|
8
|
+
UnsupportedError,
|
|
9
|
+
)
|
|
10
|
+
from guppylang_internals.definition.value import CallableDef
|
|
11
|
+
from guppylang_internals.engine import ENGINE
|
|
12
|
+
from guppylang_internals.error import GuppyError, GuppyTypeError
|
|
13
|
+
from guppylang_internals.nodes import (
|
|
14
|
+
AnyCall,
|
|
15
|
+
BarrierExpr,
|
|
16
|
+
GlobalCall,
|
|
17
|
+
LocalCall,
|
|
18
|
+
PlaceNode,
|
|
19
|
+
StateResultExpr,
|
|
20
|
+
TensorCall,
|
|
21
|
+
)
|
|
22
|
+
from guppylang_internals.tys.errors import UnitaryCallError
|
|
23
|
+
from guppylang_internals.tys.qubit import contain_qubit_ty
|
|
24
|
+
from guppylang_internals.tys.ty import FunctionType, UnitaryFlags
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def check_invalid_under_dagger(
|
|
28
|
+
fn_def: ast.FunctionDef, unitary_flags: UnitaryFlags
|
|
29
|
+
) -> None:
|
|
30
|
+
"""Check that there are no invalid constructs in a daggered CFG.
|
|
31
|
+
This checker checks the case the UnitaryFlags is given by
|
|
32
|
+
annotation (i.e., not inferred from `with dagger:`).
|
|
33
|
+
"""
|
|
34
|
+
if UnitaryFlags.Dagger not in unitary_flags:
|
|
35
|
+
return
|
|
36
|
+
|
|
37
|
+
for stmt in fn_def.body:
|
|
38
|
+
loops = loop_in_ast(stmt)
|
|
39
|
+
if len(loops) != 0:
|
|
40
|
+
loop = next(iter(loops))
|
|
41
|
+
err = InvalidUnderDagger(loop, "Loop")
|
|
42
|
+
raise GuppyError(err)
|
|
43
|
+
# Note: sub-diagnostic for dagger context is not available here
|
|
44
|
+
|
|
45
|
+
found = find_nodes(
|
|
46
|
+
lambda n: isinstance(n, ast.Assign | ast.AnnAssign | ast.AugAssign),
|
|
47
|
+
stmt,
|
|
48
|
+
{ast.FunctionDef},
|
|
49
|
+
)
|
|
50
|
+
if len(found) != 0:
|
|
51
|
+
assign = next(iter(found))
|
|
52
|
+
err = InvalidUnderDagger(assign, "Assignment")
|
|
53
|
+
raise GuppyError(err)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class BBUnitaryChecker(ast.NodeVisitor):
|
|
57
|
+
"""AST visitor that checks whether the modifiers (dagger, control, power)
|
|
58
|
+
are applicable."""
|
|
59
|
+
|
|
60
|
+
flags: UnitaryFlags
|
|
61
|
+
|
|
62
|
+
def check(self, bb: CheckedBB[Place], unitary_flags: UnitaryFlags) -> None:
|
|
63
|
+
self.flags = unitary_flags
|
|
64
|
+
for stmt in bb.statements:
|
|
65
|
+
self.visit(stmt)
|
|
66
|
+
|
|
67
|
+
def _check_classical_args(self, args: list[ast.expr]) -> bool:
|
|
68
|
+
for arg in args:
|
|
69
|
+
self.visit(arg)
|
|
70
|
+
if contain_qubit_ty(get_type(arg)):
|
|
71
|
+
return False
|
|
72
|
+
return True
|
|
73
|
+
|
|
74
|
+
def _check_call(self, node: AnyCall, ty: FunctionType) -> None:
|
|
75
|
+
classic = self._check_classical_args(node.args)
|
|
76
|
+
flag_ok = self.flags in ty.unitary_flags
|
|
77
|
+
if not classic and not flag_ok:
|
|
78
|
+
raise GuppyTypeError(
|
|
79
|
+
UnitaryCallError(node, self.flags & (~ty.unitary_flags))
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
def visit_GlobalCall(self, node: GlobalCall) -> None:
|
|
83
|
+
func = ENGINE.get_parsed(node.def_id)
|
|
84
|
+
assert isinstance(func, CallableDef)
|
|
85
|
+
self._check_call(node, func.ty)
|
|
86
|
+
|
|
87
|
+
def visit_LocalCall(self, node: LocalCall) -> None:
|
|
88
|
+
func = get_type(node.func)
|
|
89
|
+
assert isinstance(func, FunctionType)
|
|
90
|
+
self._check_call(node, func)
|
|
91
|
+
|
|
92
|
+
def visit_TensorCall(self, node: TensorCall) -> None:
|
|
93
|
+
self._check_call(node, node.tensor_ty)
|
|
94
|
+
|
|
95
|
+
def visit_BarrierExpr(self, node: BarrierExpr) -> None:
|
|
96
|
+
# Barrier is always allowed
|
|
97
|
+
pass
|
|
98
|
+
|
|
99
|
+
def visit_StateResultExpr(self, node: StateResultExpr) -> None:
|
|
100
|
+
# StateResult is always allowed
|
|
101
|
+
pass
|
|
102
|
+
|
|
103
|
+
def _check_assign(self, node: ast.Assign | ast.AnnAssign | ast.AugAssign) -> None:
|
|
104
|
+
if UnitaryFlags.Dagger in self.flags:
|
|
105
|
+
raise GuppyError(InvalidUnderDagger(node, "Assignment"))
|
|
106
|
+
if node.value is not None:
|
|
107
|
+
self.visit(node.value)
|
|
108
|
+
|
|
109
|
+
def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
|
|
110
|
+
self._check_assign(node)
|
|
111
|
+
|
|
112
|
+
def visit_Assign(self, node: ast.Assign) -> None:
|
|
113
|
+
self._check_assign(node)
|
|
114
|
+
|
|
115
|
+
def visit_AugAssign(self, node: ast.AugAssign) -> None:
|
|
116
|
+
self._check_assign(node)
|
|
117
|
+
|
|
118
|
+
def visit_PlaceNode(self, node: PlaceNode) -> None:
|
|
119
|
+
if UnitaryFlags.Dagger in self.flags and contains_subscript(node.place):
|
|
120
|
+
raise GuppyError(
|
|
121
|
+
UnsupportedError(node, "index access", True, "dagger context")
|
|
122
|
+
)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def check_cfg_unitary(
|
|
126
|
+
cfg: CheckedCFG[Place],
|
|
127
|
+
unitary_flags: UnitaryFlags,
|
|
128
|
+
) -> None:
|
|
129
|
+
"""Checks that the given unitary flags are valid for a CFG."""
|
|
130
|
+
bb_checker = BBUnitaryChecker()
|
|
131
|
+
for bb in cfg.bbs:
|
|
132
|
+
bb_checker.check(bb, unitary_flags)
|
|
@@ -24,7 +24,7 @@ from guppylang_internals.compiler.core import (
|
|
|
24
24
|
from guppylang_internals.compiler.expr_compiler import ExprCompiler
|
|
25
25
|
from guppylang_internals.compiler.stmt_compiler import StmtCompiler
|
|
26
26
|
from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, read_bool
|
|
27
|
-
from guppylang_internals.tys.ty import
|
|
27
|
+
from guppylang_internals.tys.ty import type_to_row
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
def compile_cfg(
|
|
@@ -38,7 +38,7 @@ def compile_cfg(
|
|
|
38
38
|
# TODO: This mutates the CFG in-place which leads to problems when trying to lower
|
|
39
39
|
# the same function to Hugr twice. For now we just check that the return vars
|
|
40
40
|
# haven't already been inserted, but we should figure out a better way to handle
|
|
41
|
-
# this: https://github.com/
|
|
41
|
+
# this: https://github.com/quantinuum/guppylang/issues/428
|
|
42
42
|
if all(
|
|
43
43
|
not is_return_var(v.name)
|
|
44
44
|
for v in cfg.exit_bb.sig.input_row
|
|
@@ -52,7 +52,7 @@ def compile_cfg(
|
|
|
52
52
|
# unreachable
|
|
53
53
|
out_tys = [place.ty.to_hugr(ctx) for place in cfg.exit_bb.sig.input_row]
|
|
54
54
|
# TODO: Use proper API for this once it's added in hugr-py:
|
|
55
|
-
# https://github.com/
|
|
55
|
+
# https://github.com/quantinuum/hugr/issues/1816
|
|
56
56
|
builder._exit_op._cfg_outputs = out_tys
|
|
57
57
|
builder.parent_op._outputs = out_tys
|
|
58
58
|
builder.parent_node = builder.hugr._update_node_outs(
|
|
@@ -194,13 +194,14 @@ def choose_vars_for_tuple_sum(
|
|
|
194
194
|
constructs a TupleSum value of type `Sum(#s1, #s2, ...)`.
|
|
195
195
|
"""
|
|
196
196
|
assert all(v.ty.droppable for var_row in output_vars for v in var_row)
|
|
197
|
-
|
|
198
|
-
|
|
197
|
+
sum_type = ht.Sum(
|
|
198
|
+
[[v.ty.to_hugr(dfg.ctx) for v in var_row] for var_row in output_vars]
|
|
199
|
+
)
|
|
199
200
|
|
|
200
201
|
# We pass all values into the conditional instead of relying on non-local edges.
|
|
201
202
|
# This is because we can't handle them in lower parts of the stack yet :/
|
|
202
203
|
# TODO: Reinstate use of non-local edges.
|
|
203
|
-
# See https://github.com/
|
|
204
|
+
# See https://github.com/quantinuum/guppylang/issues/963
|
|
204
205
|
all_vars = {v.id: dfg[v] for var_row in output_vars for v in var_row}
|
|
205
206
|
all_vars_wires = list(all_vars.values())
|
|
206
207
|
all_vars_idxs = {x: i for i, x in enumerate(all_vars.keys())}
|