guppylang-internals 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +3 -0
- guppylang_internals/ast_util.py +350 -0
- guppylang_internals/cfg/__init__.py +0 -0
- guppylang_internals/cfg/analysis.py +230 -0
- guppylang_internals/cfg/bb.py +221 -0
- guppylang_internals/cfg/builder.py +606 -0
- guppylang_internals/cfg/cfg.py +117 -0
- guppylang_internals/checker/__init__.py +0 -0
- guppylang_internals/checker/cfg_checker.py +388 -0
- guppylang_internals/checker/core.py +550 -0
- guppylang_internals/checker/errors/__init__.py +0 -0
- guppylang_internals/checker/errors/comptime_errors.py +106 -0
- guppylang_internals/checker/errors/generic.py +45 -0
- guppylang_internals/checker/errors/linearity.py +300 -0
- guppylang_internals/checker/errors/type_errors.py +344 -0
- guppylang_internals/checker/errors/wasm.py +34 -0
- guppylang_internals/checker/expr_checker.py +1413 -0
- guppylang_internals/checker/func_checker.py +269 -0
- guppylang_internals/checker/linearity_checker.py +821 -0
- guppylang_internals/checker/stmt_checker.py +447 -0
- guppylang_internals/compiler/__init__.py +0 -0
- guppylang_internals/compiler/cfg_compiler.py +233 -0
- guppylang_internals/compiler/core.py +613 -0
- guppylang_internals/compiler/expr_compiler.py +989 -0
- guppylang_internals/compiler/func_compiler.py +97 -0
- guppylang_internals/compiler/hugr_extension.py +224 -0
- guppylang_internals/compiler/qtm_platform_extension.py +0 -0
- guppylang_internals/compiler/stmt_compiler.py +212 -0
- guppylang_internals/decorator.py +246 -0
- guppylang_internals/definition/__init__.py +0 -0
- guppylang_internals/definition/common.py +214 -0
- guppylang_internals/definition/const.py +74 -0
- guppylang_internals/definition/custom.py +492 -0
- guppylang_internals/definition/declaration.py +171 -0
- guppylang_internals/definition/extern.py +89 -0
- guppylang_internals/definition/function.py +302 -0
- guppylang_internals/definition/overloaded.py +150 -0
- guppylang_internals/definition/parameter.py +82 -0
- guppylang_internals/definition/pytket_circuits.py +405 -0
- guppylang_internals/definition/struct.py +392 -0
- guppylang_internals/definition/traced.py +151 -0
- guppylang_internals/definition/ty.py +51 -0
- guppylang_internals/definition/value.py +115 -0
- guppylang_internals/definition/wasm.py +61 -0
- guppylang_internals/diagnostic.py +523 -0
- guppylang_internals/dummy_decorator.py +76 -0
- guppylang_internals/engine.py +295 -0
- guppylang_internals/error.py +107 -0
- guppylang_internals/experimental.py +92 -0
- guppylang_internals/ipython_inspect.py +28 -0
- guppylang_internals/nodes.py +427 -0
- guppylang_internals/py.typed +0 -0
- guppylang_internals/span.py +150 -0
- guppylang_internals/std/__init__.py +0 -0
- guppylang_internals/std/_internal/__init__.py +0 -0
- guppylang_internals/std/_internal/checker.py +573 -0
- guppylang_internals/std/_internal/compiler/__init__.py +0 -0
- guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
- guppylang_internals/std/_internal/compiler/array.py +569 -0
- guppylang_internals/std/_internal/compiler/either.py +131 -0
- guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
- guppylang_internals/std/_internal/compiler/futures.py +30 -0
- guppylang_internals/std/_internal/compiler/list.py +348 -0
- guppylang_internals/std/_internal/compiler/mem.py +13 -0
- guppylang_internals/std/_internal/compiler/option.py +78 -0
- guppylang_internals/std/_internal/compiler/prelude.py +271 -0
- guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
- guppylang_internals/std/_internal/compiler/quantum.py +118 -0
- guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
- guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
- guppylang_internals/std/_internal/compiler/wasm.py +135 -0
- guppylang_internals/std/_internal/compiler.py +0 -0
- guppylang_internals/std/_internal/debug.py +95 -0
- guppylang_internals/std/_internal/util.py +271 -0
- guppylang_internals/tracing/__init__.py +0 -0
- guppylang_internals/tracing/builtins_mock.py +62 -0
- guppylang_internals/tracing/frozenlist.py +57 -0
- guppylang_internals/tracing/function.py +186 -0
- guppylang_internals/tracing/object.py +551 -0
- guppylang_internals/tracing/state.py +69 -0
- guppylang_internals/tracing/unpacking.py +194 -0
- guppylang_internals/tracing/util.py +86 -0
- guppylang_internals/tys/__init__.py +0 -0
- guppylang_internals/tys/arg.py +115 -0
- guppylang_internals/tys/builtin.py +382 -0
- guppylang_internals/tys/common.py +110 -0
- guppylang_internals/tys/const.py +114 -0
- guppylang_internals/tys/errors.py +178 -0
- guppylang_internals/tys/param.py +251 -0
- guppylang_internals/tys/parsing.py +425 -0
- guppylang_internals/tys/printing.py +174 -0
- guppylang_internals/tys/subst.py +112 -0
- guppylang_internals/tys/ty.py +876 -0
- guppylang_internals/tys/var.py +49 -0
- guppylang_internals-0.21.0.dist-info/METADATA +253 -0
- guppylang_internals-0.21.0.dist-info/RECORD +98 -0
- guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
- guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
|
@@ -0,0 +1,447 @@
|
|
|
1
|
+
"""Type checking code for statements.
|
|
2
|
+
|
|
3
|
+
Operates on statements in a basic block after CFG construction. In particular, we
|
|
4
|
+
assume that statements involving control flow (i.e. if, while, break, and return
|
|
5
|
+
statements) have been removed during CFG construction.
|
|
6
|
+
|
|
7
|
+
After checking, we return a desugared statement where all sub-expressions have been type
|
|
8
|
+
annotated.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import ast
|
|
12
|
+
import functools
|
|
13
|
+
from collections.abc import Iterable, Sequence
|
|
14
|
+
from dataclasses import replace
|
|
15
|
+
from itertools import takewhile
|
|
16
|
+
from typing import TypeVar, cast
|
|
17
|
+
|
|
18
|
+
from guppylang_internals.ast_util import (
|
|
19
|
+
AstVisitor,
|
|
20
|
+
get_type,
|
|
21
|
+
with_loc,
|
|
22
|
+
with_type,
|
|
23
|
+
)
|
|
24
|
+
from guppylang_internals.cfg.bb import BB, BBStatement
|
|
25
|
+
from guppylang_internals.cfg.builder import (
|
|
26
|
+
desugar_comprehension,
|
|
27
|
+
make_var,
|
|
28
|
+
tmp_vars,
|
|
29
|
+
)
|
|
30
|
+
from guppylang_internals.checker.core import (
|
|
31
|
+
Context,
|
|
32
|
+
FieldAccess,
|
|
33
|
+
SubscriptAccess,
|
|
34
|
+
Variable,
|
|
35
|
+
)
|
|
36
|
+
from guppylang_internals.checker.errors.generic import UnsupportedError
|
|
37
|
+
from guppylang_internals.checker.errors.type_errors import (
|
|
38
|
+
AssignFieldTypeMismatchError,
|
|
39
|
+
AssignNonPlaceHelp,
|
|
40
|
+
AssignSubscriptTypeMismatchError,
|
|
41
|
+
AttributeNotFoundError,
|
|
42
|
+
MissingReturnValueError,
|
|
43
|
+
StarredTupleUnpackError,
|
|
44
|
+
TypeInferenceError,
|
|
45
|
+
UnpackableError,
|
|
46
|
+
WrongNumberOfUnpacksError,
|
|
47
|
+
)
|
|
48
|
+
from guppylang_internals.checker.expr_checker import (
|
|
49
|
+
ExprChecker,
|
|
50
|
+
ExprSynthesizer,
|
|
51
|
+
check_place_assignable,
|
|
52
|
+
synthesize_comprehension,
|
|
53
|
+
)
|
|
54
|
+
from guppylang_internals.error import GuppyError, GuppyTypeError, InternalGuppyError
|
|
55
|
+
from guppylang_internals.nodes import (
|
|
56
|
+
AnyUnpack,
|
|
57
|
+
DesugaredArrayComp,
|
|
58
|
+
IterableUnpack,
|
|
59
|
+
MakeIter,
|
|
60
|
+
NestedFunctionDef,
|
|
61
|
+
PlaceNode,
|
|
62
|
+
TupleUnpack,
|
|
63
|
+
UnpackPattern,
|
|
64
|
+
)
|
|
65
|
+
from guppylang_internals.span import Span, to_span
|
|
66
|
+
from guppylang_internals.tys.builtin import (
|
|
67
|
+
array_type,
|
|
68
|
+
get_element_type,
|
|
69
|
+
get_iter_size,
|
|
70
|
+
is_array_type,
|
|
71
|
+
is_sized_iter_type,
|
|
72
|
+
nat_type,
|
|
73
|
+
)
|
|
74
|
+
from guppylang_internals.tys.const import ConstValue
|
|
75
|
+
from guppylang_internals.tys.parsing import type_from_ast
|
|
76
|
+
from guppylang_internals.tys.subst import Subst
|
|
77
|
+
from guppylang_internals.tys.ty import (
|
|
78
|
+
ExistentialTypeVar,
|
|
79
|
+
FunctionType,
|
|
80
|
+
NoneType,
|
|
81
|
+
StructType,
|
|
82
|
+
TupleType,
|
|
83
|
+
Type,
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class StmtChecker(AstVisitor[BBStatement]):
    """Type checker for the statements of a single basic block.

    The `visit_*` methods each check one kind of statement and return the statement
    with its sub-expressions replaced by type-annotated versions. Variables bound by
    assignments and nested function definitions are recorded in `self.ctx.locals`.
    `if`, `while`, `break`, and `continue` statements must have been removed during
    CFG construction; encountering one raises an internal error.
    """

    # Context in which the statements are checked; newly bound variables are added
    # to `ctx.locals`
    ctx: Context
    # The BB containing the statements; only required to check nested function defs
    bb: BB | None
    # Declared return type of the enclosing function; only required to check
    # `return` statements and must not contain unsolved type variables
    return_ty: Type | None

    def __init__(
        self, ctx: Context, bb: BB | None = None, return_ty: Type | None = None
    ) -> None:
        assert not return_ty or not return_ty.unsolved_vars
        self.ctx = ctx
        self.bb = bb
        self.return_ty = return_ty

    def check_stmts(self, stmts: Sequence[BBStatement]) -> list[BBStatement]:
        """Checks the given statements in order and returns the checked versions."""
        return [self.visit(s) for s in stmts]

    def _synth_expr(self, node: ast.expr) -> tuple[ast.expr, Type]:
        """Synthesizes the type of `node`.

        Returns the type-annotated expression together with its synthesized type.
        """
        return ExprSynthesizer(self.ctx).synthesize(node)

    def _synth_instance_fun(
        self,
        node: ast.expr,
        args: list[ast.expr],
        func_name: str,
        description: str,
        exp_sig: FunctionType | None = None,
        give_reason: bool = False,
    ) -> tuple[ast.expr, Type]:
        """Delegates to `ExprSynthesizer.synthesize_instance_func`, passing all
        arguments through unchanged."""
        return ExprSynthesizer(self.ctx).synthesize_instance_func(
            node, args, func_name, description, exp_sig, give_reason
        )

    def _check_expr(
        self, node: ast.expr, ty: Type, kind: str = "expression"
    ) -> tuple[ast.expr, Subst]:
        """Checks `node` against the expected type `ty`.

        `kind` describes the checked expression (e.g. "return value") and is
        forwarded to the expression checker. Returns the annotated expression
        together with the substitution computed while checking.
        """
        return ExprChecker(self.ctx).check(node, ty, kind)

    @functools.singledispatchmethod
    def _check_assign(self, lhs: ast.expr, rhs: ast.expr, rhs_ty: Type) -> ast.expr:
        """Helper function to check assignments with patterns.

        Dispatches on the AST node type of the assignment target `lhs` (the
        registered implementations below are selected via the annotation of their
        `lhs` parameter). `rhs` is the assigned expression and `rhs_ty` its type.
        Returns the checked target. Unsupported target kinds are an internal error.
        """
        raise InternalGuppyError("Unexpected assignment pattern")

    @_check_assign.register
    def _check_variable_assign(
        self, lhs: ast.Name, _rhs: ast.expr, rhs_ty: Type
    ) -> PlaceNode:
        """Checks an assignment to a plain variable and binds it in
        `self.ctx.locals` with type `rhs_ty`."""
        x = lhs.id
        var = Variable(x, rhs_ty, lhs)
        self.ctx.locals[x] = var
        return with_loc(lhs, with_type(rhs_ty, PlaceNode(place=var)))

    @_check_assign.register
    def _check_field_assign(
        self, lhs: ast.Attribute, _rhs: ast.expr, rhs_ty: Type
    ) -> PlaceNode:
        """Checks an assignment to a struct field like `s.a = ...`.

        Raises a user error if the value is not a struct with field `a`, if the
        field type differs from `rhs_ty`, if the value is not a place (e.g.
        `f().a = ...`), or if the field has a copyable (classical) type.
        """
        # Unfortunately, the `attr` is just a string, not an AST node, so we
        # have to compute its span by hand. This is fine since linebreaks are
        # not allowed in the identifier following the `.`
        span = to_span(lhs)
        value, attr = lhs.value, lhs.attr
        attr_span = Span(span.end.shift_left(len(attr)), span.end)
        value, struct_ty = self._synth_expr(value)
        if not isinstance(struct_ty, StructType) or attr not in struct_ty.field_dict:
            raise GuppyTypeError(AttributeNotFoundError(attr_span, struct_ty, attr))
        field = struct_ty.field_dict[attr]
        # TODO: In the future, we could infer some type args here
        if field.ty != rhs_ty:
            # TODO: Get hold of a span for the RHS and use a regular `TypeMismatchError`
            # instead (maybe with a custom hint).
            raise GuppyTypeError(AssignFieldTypeMismatchError(attr_span, rhs_ty, field))
        if not isinstance(value, PlaceNode):
            # For now we complain if someone tries to assign to something that is not a
            # place, e.g. `f().a = 4`. This would only make sense if there is another
            # reference to the return value of `f`, otherwise the mutation cannot be
            # observed. We can start supporting this once we have proper reference
            # semantics.
            err = UnsupportedError(value, "Assigning to this expression", singular=True)
            err.add_sub_diagnostic(AssignNonPlaceHelp(None, field))
            raise GuppyError(err)
        if field.ty.copyable:
            raise GuppyError(
                UnsupportedError(
                    attr_span, "Mutation of classical fields", singular=True
                )
            )
        place = FieldAccess(value.place, field, lhs)
        place = check_place_assignable(place, self.ctx, lhs, "assignable")
        return with_loc(lhs, with_type(rhs_ty, PlaceNode(place=place)))

    @_check_assign.register
    def _check_subscript_assign(
        self, lhs: ast.Subscript, rhs: ast.expr, rhs_ty: Type
    ) -> PlaceNode:
        """Checks an assignment to an array element like `xs[i] = ...`.

        Raises a user error if the subscripted value is not an array or not a
        place, or if the element type differs from `rhs_ty`.
        """
        # Check subscript is array subscript.
        value, container_ty = self._synth_expr(lhs.value)
        if not is_array_type(container_ty):
            raise GuppyError(
                UnsupportedError(lhs, "Subscript assignments to non-arrays")
            )

        # Check array element type matches type of RHS.
        element_ty = get_element_type(container_ty)
        if element_ty != rhs_ty:
            raise GuppyTypeError(
                AssignSubscriptTypeMismatchError(to_span(lhs), rhs_ty, element_ty)
            )

        # As with field assignment, only allow place assignments for now.
        if not isinstance(value, PlaceNode):
            raise GuppyError(
                UnsupportedError(value, "Assigning to this expression", singular=True)
            )

        # Create a subscript place. The index is bound to a fresh temporary variable.
        item_expr, item_ty = self._synth_expr(lhs.slice)
        item = Variable(next(tmp_vars), item_ty, item_expr)
        place = SubscriptAccess(value.place, item, rhs_ty, item_expr)

        # Calling `check_place_assignable` makes sure that `__setitem__` is implemented
        place = check_place_assignable(place, self.ctx, lhs, "assignable")
        return with_loc(lhs, with_type(rhs_ty, PlaceNode(place=place)))

    @_check_assign.register
    def _check_tuple_assign(
        self, lhs: ast.Tuple, rhs: ast.expr, rhs_ty: Type
    ) -> AnyUnpack:
        """Checks an unpacking assignment with a tuple target like `a, b = ...`."""
        return self._check_unpack_assign(lhs, rhs, rhs_ty)

    @_check_assign.register
    def _check_list_assign(
        self, lhs: ast.List, rhs: ast.expr, rhs_ty: Type
    ) -> AnyUnpack:
        """Checks an unpacking assignment with a list target like `[a, b] = ...`."""
        return self._check_unpack_assign(lhs, rhs, rhs_ty)

    def _check_unpack_assign(
        self, lhs: ast.Tuple | ast.List, rhs: ast.expr, rhs_ty: Type
    ) -> AnyUnpack:
        """Helper function to check unpacking assignments.

        These are the ones where the LHS is either a tuple or a list, e.g.
        `a, *bs, c = ...`. Nested targets are checked recursively and a starred
        target is bound to an array holding the remaining elements.
        """
        # Parse LHS into `left, *starred, right`
        pattern = parse_unpack_pattern(lhs)
        left, starred, right = pattern.left, pattern.starred, pattern.right
        # Check that the RHS has an appropriate type to be unpacked
        unpack, rhs_elts, rhs_tys = self._check_unpackable(rhs, rhs_ty, pattern)

        # Check that the numbers match up on the LHS and RHS
        num_lhs, num_rhs = len(right) + len(left), len(rhs_tys)
        err = WrongNumberOfUnpacksError(
            lhs, num_rhs, num_lhs, at_least=starred is not None
        )
        if num_lhs > num_rhs:
            # Build span that covers the unexpected elts on the LHS
            span = Span(to_span(lhs.elts[num_rhs]).start, to_span(lhs.elts[-1]).end)
            raise GuppyTypeError(replace(err, span=span))
        elif num_lhs < num_rhs and not starred:
            raise GuppyTypeError(err)

        # Recursively check any nested patterns on the left or right
        le, rs = len(left), len(rhs_elts) - len(right)  # left_end, right_start
        unpack.pattern.left = [
            self._check_assign(pat, elt, ty)
            for pat, elt, ty in zip(left, rhs_elts[:le], rhs_tys[:le], strict=True)
        ]
        unpack.pattern.right = [
            self._check_assign(pat, elt, ty)
            for pat, elt, ty in zip(right, rhs_elts[rs:], rhs_tys[rs:], strict=True)
        ]

        # Starred assignments are collected into an array
        if starred:
            starred_tys = rhs_tys[le:rs]
            assert all_equal(starred_tys)
            if starred_tys:
                starred_ty, *_ = starred_tys
            # Starred part could be empty. If it's an iterable unpack, we're still fine
            # since we know the yielded type
            elif isinstance(unpack, IterableUnpack):
                starred_ty = unpack.compr.elt_ty
            # For tuple unpacks, there is no way to infer a type for the empty starred
            # part
            else:
                unsolved = array_type(ExistentialTypeVar.fresh("T", True, True), 0)
                raise GuppyError(TypeInferenceError(starred, unsolved))
            array_ty = array_type(starred_ty, len(starred_tys))
            # The starred target is a plain variable (see `parse_unpack_pattern`),
            # whose checker does not inspect the RHS argument. Pass the whole RHS
            # instead of `rhs_elts[0]`: the latter does not exist when nothing is
            # unpacked at all (e.g. `*xs, = ...` on an iterable of static size 0),
            # which would crash with an `IndexError`.
            unpack.pattern.starred = self._check_assign(starred, rhs, array_ty)

        return with_type(rhs_ty, with_loc(lhs, unpack))

    def _check_unpackable(
        self, expr: ast.expr, ty: Type, pattern: UnpackPattern
    ) -> tuple[AnyUnpack, list[ast.expr], Sequence[Type]]:
        """Checks that the given expression can be used in an unpacking assignment.

        This is the case for expressions with tuple types or ones that are iterable with
        a static size. Also checks that the expression is compatible with the given
        unpacking pattern.

        Returns an AST node capturing the unpacking operation together with expressions
        and types for all unpacked items. Emits a user error if the given expression is
        not unpackable.
        """
        left, starred, right = pattern.left, pattern.starred, pattern.right
        if isinstance(ty, TupleType):
            # Starred assignment of tuples is only allowed if all starred elements have
            # the same type
            if starred:
                starred_tys = (
                    ty.element_types[len(left) : -len(right)]
                    if right
                    else ty.element_types[len(left) :]
                )
                if not all_equal(starred_tys):
                    tuple_ty = TupleType(starred_tys)
                    raise GuppyError(StarredTupleUnpackError(starred, tuple_ty))
            tys = ty.element_types
            # For tuple literals we can use the individual element expressions;
            # otherwise every item refers back to the whole RHS expression
            elts = expr.elts if isinstance(expr, ast.Tuple) else [expr] * len(tys)
            return TupleUnpack(pattern), elts, tys

        elif self.ctx.globals.get_instance_func(ty, "__iter__"):
            size = check_iter_unpack_has_static_size(expr, self.ctx)
            # Create a dummy variable and assign the expression to it. This helps us to
            # wire it up correctly during Hugr generation.
            var = self._check_assign(make_var(next(tmp_vars), expr), expr, ty)
            assert isinstance(var, PlaceNode)
            # We collect the whole RHS into an array. For this, we can reuse the
            # existing array comprehension logic.
            elt = make_var(next(tmp_vars), expr)
            gen = ast.comprehension(target=elt, iter=var, ifs=[], is_async=False)
            [gen], elt = desugar_comprehension([with_loc(expr, gen)], elt, expr)
            # Type check the comprehension
            [gen], elt, elt_ty = synthesize_comprehension(expr, [gen], elt, self.ctx)
            compr = DesugaredArrayComp(
                elt, gen, length=ConstValue(nat_type(), size), elt_ty=elt_ty
            )
            compr = with_type(array_type(elt_ty, size), compr)
            return IterableUnpack(pattern, compr, var), size * [elt], size * [elt_ty]

        # Otherwise, we can't unpack this expression
        raise GuppyError(UnpackableError(expr, ty))

    def visit_Assign(self, node: ast.Assign) -> ast.Assign:
        """Checks a single-target assignment `x = e`.

        Chained assignments like `a = b = 1` are rejected with a user error.
        """
        if len(node.targets) > 1:
            # This is the case for assignments like `a = b = 1`
            raise GuppyError(UnsupportedError(node, "Multi assignments"))

        [target] = node.targets
        node.value, ty = self._synth_expr(node.value)
        node.targets = [self._check_assign(target, node.value, ty)]
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt:
        """Checks an annotated assignment `x: T = e` and desugars it into a plain
        `ast.Assign`.

        The value is checked against the annotated type. Declarations without a
        value (`x: T`) are rejected with a user error.
        """
        if node.value is None:
            raise GuppyError(UnsupportedError(node, "Variable declarations"))
        ty = type_from_ast(node.annotation, self.ctx.globals, self.ctx.generic_params)
        node.value, subst = self._check_expr(node.value, ty)
        assert not ty.unsolved_vars  # `ty` must be closed!
        assert len(subst) == 0
        target = self._check_assign(node.target, node.value, ty)
        return with_loc(node, ast.Assign(targets=[target], value=node.value))

    def visit_AugAssign(self, node: ast.AugAssign) -> ast.stmt:
        """Desugars an augmented assignment `x op= e` into `x = x op e` and checks
        the result."""
        bin_op = with_loc(
            node, ast.BinOp(left=node.target, op=node.op, right=node.value)
        )
        assign = with_loc(node, ast.Assign(targets=[node.target], value=bin_op))
        return self.visit_Assign(assign)

    def visit_Expr(self, node: ast.Expr) -> ast.stmt:
        """Checks an expression statement whose value is discarded."""
        node.value, _ = self._synth_expr(node.value)
        return node

    def visit_Return(self, node: ast.Return) -> ast.stmt:
        """Checks a `return` statement against `self.return_ty`.

        A bare `return` is only allowed if the return type is `None`.
        """
        if not self.return_ty:
            raise InternalGuppyError("return_ty required to check return stmt!")

        if node.value is not None:
            node.value, subst = self._check_expr(
                node.value, self.return_ty, "return value"
            )
            assert len(subst) == 0  # `self.return_ty` is closed!
        elif not isinstance(self.return_ty, NoneType):
            raise GuppyTypeError(MissingReturnValueError(node, self.return_ty))
        return node

    def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> ast.stmt:
        """Checks a nested function definition and binds its name as a local
        variable."""
        # Local import (presumably to avoid an import cycle with `func_checker`)
        from guppylang_internals.checker.func_checker import check_nested_func_def

        if not self.bb:
            raise InternalGuppyError("BB required to check nested function def!")

        func_def = check_nested_func_def(node, self.bb, self.ctx)
        self.ctx.locals[func_def.name] = Variable(func_def.name, func_def.ty, func_def)
        return func_def

    # Control-flow statements are turned into CFG edges by the CFG builder and must
    # not reach the statement checker.

    def visit_If(self, node: ast.If) -> None:
        raise InternalGuppyError("Control-flow statement should not be present here.")

    def visit_While(self, node: ast.While) -> None:
        raise InternalGuppyError("Control-flow statement should not be present here.")

    def visit_Break(self, node: ast.Break) -> None:
        raise InternalGuppyError("Control-flow statement should not be present here.")

    def visit_Continue(self, node: ast.Continue) -> None:
        raise InternalGuppyError("Control-flow statement should not be present here.")
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
T = TypeVar("T")


def all_equal(xs: Iterable[T]) -> bool:
    """Returns whether every element yielded by `xs` compares equal to the first one.

    An empty iterable trivially satisfies this. The iterable is traversed at most
    once, stopping at the first mismatch, so one-shot iterators are accepted.
    """
    # Private marker object so that any real element (including `None`) can be the
    # head of the sequence
    missing = object()
    iterator = iter(xs)
    head = next(iterator, missing)
    if head is missing:
        return True
    return all(head == item for item in iterator)
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def parse_unpack_pattern(lhs: ast.Tuple | ast.List) -> UnpackPattern:
    """Parses the LHS of an unpacking assignment like `a, *bs, c = ...` or
    `[a, *bs, c] = ...`.

    The targets are split into the ones before the starred target (`left`), the
    expression under the star (`starred`, or `None` if there is no starred target),
    and the ones after it (`right`).

    Raises a user error if the starred target is not a plain variable, e.g.
    `a, *b.c = ...`, `a, *xs[0] = ...`, or `a, *(b, c) = ...`. Python accepts
    such targets, but only variables are supported for starred unpacking.
    """
    # Split up LHS into `left, *starred, right` (the Python grammar ensures
    # that there is at most one starred expression)
    left = list(takewhile(lambda e: not isinstance(e, ast.Starred), lhs.elts))
    starred = (
        cast(ast.Starred, lhs.elts[len(left)]).value
        if len(left) < len(lhs.elts)
        else None
    )
    right = lhs.elts[len(left) + 1 :]
    # Python allows arbitrary assignment targets after a `*` (attributes,
    # subscripts, nested tuples, ...). Report those as a user error instead of
    # failing an internal assertion.
    if starred is not None and not isinstance(starred, ast.Name):
        raise GuppyError(
            UnsupportedError(starred, "Starred assignments to non-variables")
        )
    return UnpackPattern(left, starred, right)
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def check_iter_unpack_has_static_size(expr: ast.expr, ctx: Context) -> int:
    """Ensures an iterable expression can be unpacked in an assignment.

    Unpacking requires that the iterator built from `expr` has a static size that
    is a concrete integer rather than a generic parameter.

    Returns that size. Raises a `GuppyError` wrapping an `UnpackableError` (with a
    sub-diagnostic explaining the reason) if the iterator is not sized or if its
    size is not a concrete integer.
    """
    synthesizer = ExprSynthesizer(ctx)
    iter_node = with_loc(expr, MakeIter(expr, expr, unwrap_size_hint=False))
    _, iter_ty = synthesizer.visit_MakeIter(iter_node)
    error = UnpackableError(expr, get_type(expr))

    # The iterator must carry a size at all ...
    if not is_sized_iter_type(iter_ty):
        error.add_sub_diagnostic(UnpackableError.NonStaticIter(None))
        raise GuppyError(error)

    # ... and that size must be a concrete integer constant
    size_arg = get_iter_size(iter_ty)
    if isinstance(size_arg, ConstValue) and isinstance(size_arg.value, int):
        return size_arg.value
    error.add_sub_diagnostic(UnpackableError.GenericSize(None, size_arg))
    raise GuppyError(error)
|
|
File without changes
|
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
from typing import cast
|
|
4
|
+
|
|
5
|
+
from hugr import Wire, ops
|
|
6
|
+
from hugr import tys as ht
|
|
7
|
+
from hugr.build import cfg as hc
|
|
8
|
+
from hugr.build.dfg import DP, DfBase
|
|
9
|
+
from hugr.hugr.node_port import ToNode
|
|
10
|
+
|
|
11
|
+
from guppylang_internals.checker.cfg_checker import (
|
|
12
|
+
CheckedBB,
|
|
13
|
+
CheckedCFG,
|
|
14
|
+
Row,
|
|
15
|
+
Signature,
|
|
16
|
+
)
|
|
17
|
+
from guppylang_internals.checker.core import Place, Variable
|
|
18
|
+
from guppylang_internals.compiler.core import (
|
|
19
|
+
CompilerContext,
|
|
20
|
+
DFContainer,
|
|
21
|
+
is_return_var,
|
|
22
|
+
return_var,
|
|
23
|
+
)
|
|
24
|
+
from guppylang_internals.compiler.expr_compiler import ExprCompiler
|
|
25
|
+
from guppylang_internals.compiler.stmt_compiler import StmtCompiler
|
|
26
|
+
from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, read_bool
|
|
27
|
+
from guppylang_internals.tys.ty import SumType, row_to_type, type_to_row
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def compile_cfg(
    cfg: CheckedCFG[Place],
    container: DfBase[DP],
    inputs: Sequence[Wire],
    ctx: CompilerContext,
) -> hc.Cfg:
    """Compiles a CFG to Hugr.

    Args:
        cfg: The checked CFG to compile. Its BB signatures are patched in place with
            dummy return variables unless that already happened.
        container: The dataflow container to which the Hugr CFG node is added.
        inputs: Wires that are fed into the CFG node.
        ctx: The compiler context.

    Returns:
        The builder of the newly added Hugr CFG node.
    """
    # Patch the CFG with dummy return variables
    # TODO: This mutates the CFG in-place which leads to problems when trying to lower
    #  the same function to Hugr twice. For now we just check that the return vars
    #  haven't already been inserted, but we should figure out a better way to handle
    #  this: https://github.com/CQCL/guppylang/issues/428
    already_patched = any(
        is_return_var(place.name)
        for place in cfg.exit_bb.sig.input_row
        if isinstance(place, Variable)
    )
    if not already_patched:
        insert_return_vars(cfg)

    cfg_builder = container.add_cfg(*inputs)

    # Explicitly annotate the output types since Hugr can't infer them if the exit is
    # unreachable. Note that the exit signature must be re-read here since
    # `insert_return_vars` may have replaced it above.
    output_types = [place.ty.to_hugr(ctx) for place in cfg.exit_bb.sig.input_row]
    # TODO: Use proper API for this once it's added in hugr-py:
    #  https://github.com/CQCL/hugr/issues/1816
    cfg_builder._exit_op._cfg_outputs = output_types
    cfg_builder.parent_op._outputs = output_types
    cfg_builder.parent_node = cfg_builder.hugr._update_node_outs(
        cfg_builder.parent_node, len(output_types)
    )

    # Create a node for every BB first so that branches can target any of them
    nodes: dict[CheckedBB[Place], ToNode] = {
        bb: compile_bb(bb, cfg_builder, bb == cfg.entry_bb, ctx) for bb in cfg.bbs
    }
    for bb in cfg.bbs:
        for idx, succ in enumerate(bb.successors):
            cfg_builder.branch(nodes[bb][idx], nodes[succ])

    return cfg_builder
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def compile_bb(
    bb: CheckedBB[Place],
    builder: hc.Cfg,
    is_entry: bool,
    ctx: CompilerContext,
) -> ToNode:
    """Compiles a single basic block to Hugr.

    Args:
        bb: The checked basic block to compile.
        builder: Builder of the Hugr CFG that the block is added to.
        is_entry: Whether `bb` is the entry block. The entry block takes its inputs
            in the order of its signature; all other blocks take them in the order
            given by `sort_vars`.
        ctx: The compiler context.

    Returns:
        The Hugr node for the block. For the exit BB no new node is created and the
        builder's existing exit node (`builder.exit`) is returned instead.
    """
    # The exit BB is completely empty
    if bb.is_exit:
        assert len(bb.statements) == 0
        return builder.exit

    # Unreachable BBs (besides the exit) should have been removed by now
    assert bb.reachable

    # Otherwise, we use a regular `Block` node
    block: hc.Block
    inputs: Sequence[Place]
    if is_entry:
        # The entry block keeps the signature order (presumably because its inputs
        # are the wires passed into the CFG node)
        inputs = bb.sig.input_row
        block = builder.add_entry()
    else:
        inputs = sort_vars(bb.sig.input_row)
        block = builder.add_block(*(v.ty.to_hugr(ctx) for v in inputs))

    # Add input node and compile the statements
    dfg = DFContainer(block, ctx)
    for v, wire in zip(inputs, block.input_node, strict=True):
        dfg[v] = wire
    dfg = StmtCompiler(ctx).compile_stmts(bb.statements, dfg)

    # If we branch, we also have to compile the branch predicate
    if len(bb.successors) > 1:
        assert bb.branch_pred is not None
        branch_port = ExprCompiler(ctx).compile(bb.branch_pred, dfg)
        # Convert the bool predicate into a sum for branching.
        pred_ty = builder.hugr.port_type(branch_port.out_port())
        assert pred_ty == OpaqueBool
        branch_port = dfg.builder.add_op(read_bool(), branch_port)
        branch_port = cast(Wire, branch_port)
    else:
        # Even if we don't branch, we still have to add a `Sum(())` predicate
        branch_port = dfg.builder.add_op(ops.Tag(0, ht.UnitSum(1)))

    # Finally, we have to add the block output.
    outputs: Sequence[Place]
    if len(bb.successors) == 1:
        # The easy case is if we don't branch: We just output all variables that are
        # specified by the signature
        [outputs] = bb.sig.output_rows
    else:
        # CFG building ensures that branching BBs don't branch to the exit (exit jumps
        # must always be unconditional)
        assert not any(succ.is_exit for succ in bb.successors)

        # If we branch and the branches use the same places, then we can use a
        # regular output
        first, *rest = bb.sig.output_rows
        if all({p.id for p in first} == {p.id for p in r} for r in rest):
            outputs = first
        else:
            # Otherwise, we have to output a TupleSum: We put all droppable variables
            # into the branch TupleSum and all non-droppable (linear) variables in the
            # normal output (since they are shared between all successors). This is
            # in line with the ordering on variables which puts linear variables at
            # the end. We don't need to worry about the order of return vars since
            # this isn't a branch to an exit (see assert above).
            branch_port = choose_vars_for_tuple_sum(
                unit_sum=branch_port,
                output_vars=[
                    [v for v in sort_vars(row) if v.ty.droppable]
                    for row in bb.sig.output_rows
                ],
                dfg=dfg,
            )
            outputs = [v for v in first if not v.ty.droppable]

    # If this is *not* a jump to the exit BB, we need to sort the outputs to make the
    # signature consistent with what the next BB expects
    if not any(succ.is_exit for succ in bb.successors):
        outputs = sort_vars(outputs)
    else:
        # Exit variables are not allowed to be sorted since their order corresponds to
        # the function outputs
        assert len(bb.successors) == 1, "Exit jumps are always unconditional"

    block.set_block_outputs(branch_port, *(dfg[v] for v in outputs))
    return block
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def insert_return_vars(cfg: CheckedCFG[Place]) -> None:
    """Annotate the CFG's BB signatures with the dummy return variables.

    `return` statements are compiled into assignments to the dummy variables
    `%ret0`, `%ret1`, ... (one per component of the CFG output type). This function
    makes those variables flow into the exit BB. It prepends them to the exit BB's
    input row and to the (single) output row of every predecessor of the exit BB.
    """
    ret_row = type_to_row(cfg.output_ty)
    return_vars = [Variable(return_var(idx), ty, None) for idx, ty in enumerate(ret_row)]

    # The exit BB receives the return values before any other variables
    exit_bb = cfg.exit_bb
    old_exit_sig = exit_bb.sig
    exit_bb.sig = Signature(
        [*return_vars, *old_exit_sig.input_row], old_exit_sig.output_rows
    )

    # Every jump into the exit BB is unconditional, so each predecessor has exactly
    # one output row, which must now also carry the return values.
    for pred in exit_bb.predecessors:
        pred_sig = pred.sig
        assert len(pred_sig.output_rows) == 1
        (only_row,) = pred_sig.output_rows
        pred.sig = Signature(pred_sig.input_row, [[*return_vars, *only_row]])
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def choose_vars_for_tuple_sum(
    unit_sum: Wire, output_vars: list[Row[Place]], dfg: DFContainer
) -> Wire:
    """Build a TupleSum that selects one of several rows of output variables.

    `unit_sum` is expected to be of type `Sum((), (), ...)`; its tag chooses a row.
    For `output_vars = [#s1, #s2, ...]` this adds a conditional whose `i`-th case
    tags the values of `#si` into a value of type `Sum(#s1, #s2, ...)`. The
    conditional itself is returned as the wire carrying that sum value.

    All given variables must be droppable (checked by an assertion).
    """
    assert all(v.ty.droppable for row in output_vars for v in row)
    row_types = [[v.ty for v in row] for row in output_vars]
    hugr_sum = SumType([row_to_type(tys) for tys in row_types]).to_hugr(dfg.ctx)

    # All values are fed into the conditional as explicit inputs instead of being
    # accessed through non-local edges, since lower parts of the stack can't handle
    # those yet.
    # TODO: Reinstate use of non-local edges.
    # See https://github.com/CQCL/guppylang/issues/963
    wire_by_id = {v.id: dfg[v] for row in output_vars for v in row}
    input_pos = {var_id: pos for pos, var_id in enumerate(wire_by_id)}
    cond_inputs = [*wire_by_id.values()]

    with dfg.builder.add_conditional(unit_sum, *cond_inputs) as cond:
        for tag_idx, row in enumerate(output_vars):
            with cond.add_case(tag_idx) as case:
                case_ins = case.inputs()
                # Pick this row's variables out of the shared conditional inputs
                picked = [case_ins[input_pos[v.id]] for v in row]
                tagged = case.add_op(ops.Tag(tag_idx, hugr_sum), *picked)
                case.set_outputs(tagged)
    return cond
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def compare_var(p1: Place, p2: Place) -> int:
    """Three-way comparison defining the order in which BB outputs are emitted.

    Linear variables must be output at the end of a basic block. Variables are
    therefore ordered lexicographically by (linearity, name): droppable variables
    come before non-droppable ones, and within each group variables are ordered by
    their string representation.

    Returns `-1` if `p1` sorts before `p2`, `1` if it sorts after, and `0` if both
    have the same sort key, as expected by `functools.cmp_to_key`.
    """
    key1 = (not p1.ty.droppable, str(p1))
    key2 = (not p2.ty.droppable, str(p2))
    # Return 0 for equal keys so the comparator is antisymmetric; returning 1 in both
    # directions would make the `cmp_to_key` wrappers never compare equal.
    return (key1 > key2) - (key1 < key2)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def sort_vars(row: Row[Place]) -> list[Place]:
    """Return the variables of `row` in basic-block output order.

    The ordering is the one defined by `compare_var`. The input row is left
    untouched and a new list is returned.
    """
    ordered = list(row)
    ordered.sort(key=functools.cmp_to_key(compare_var))
    return ordered
|