guppylang-internals 0.22.0__py3-none-any.whl → 0.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/cfg/cfg.py +8 -0
- guppylang_internals/checker/cfg_checker.py +26 -65
- guppylang_internals/checker/core.py +8 -0
- guppylang_internals/checker/expr_checker.py +11 -25
- guppylang_internals/checker/func_checker.py +170 -21
- guppylang_internals/checker/stmt_checker.py +1 -1
- guppylang_internals/decorator.py +124 -58
- guppylang_internals/definition/const.py +2 -2
- guppylang_internals/definition/custom.py +1 -1
- guppylang_internals/definition/declaration.py +1 -1
- guppylang_internals/definition/extern.py +2 -2
- guppylang_internals/definition/function.py +1 -1
- guppylang_internals/definition/parameter.py +2 -2
- guppylang_internals/definition/pytket_circuits.py +1 -1
- guppylang_internals/definition/struct.py +10 -10
- guppylang_internals/definition/traced.py +1 -1
- guppylang_internals/definition/ty.py +6 -0
- guppylang_internals/definition/wasm.py +2 -2
- guppylang_internals/engine.py +13 -2
- guppylang_internals/nodes.py +0 -23
- guppylang_internals/std/_internal/compiler/tket_exts.py +3 -6
- guppylang_internals/std/_internal/compiler/wasm.py +37 -26
- guppylang_internals/tracing/function.py +13 -2
- guppylang_internals/tracing/unpacking.py +18 -12
- guppylang_internals/tys/builtin.py +30 -11
- guppylang_internals/tys/errors.py +6 -0
- guppylang_internals/tys/parsing.py +111 -125
- {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/METADATA +5 -5
- {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/RECORD +32 -32
- {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/WHEEL +0 -0
- {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/licenses/LICENCE +0 -0
guppylang_internals/__init__.py
CHANGED
guppylang_internals/cfg/cfg.py
CHANGED
|
@@ -27,6 +27,9 @@ class BaseCFG(Generic[T]):
|
|
|
27
27
|
ass_before: Result[DefAssignmentDomain[str]]
|
|
28
28
|
maybe_ass_before: Result[MaybeAssignmentDomain[str]]
|
|
29
29
|
|
|
30
|
+
#: Set of variables defined in this CFG
|
|
31
|
+
assigned_somewhere: set[str]
|
|
32
|
+
|
|
30
33
|
def __init__(
|
|
31
34
|
self, bbs: list[T], entry_bb: T | None = None, exit_bb: T | None = None
|
|
32
35
|
):
|
|
@@ -38,6 +41,7 @@ class BaseCFG(Generic[T]):
|
|
|
38
41
|
self.live_before = {}
|
|
39
42
|
self.ass_before = {}
|
|
40
43
|
self.maybe_ass_before = {}
|
|
44
|
+
self.assigned_somewhere = set()
|
|
41
45
|
|
|
42
46
|
def ancestors(self, *bbs: T) -> Iterator[T]:
|
|
43
47
|
"""Returns an iterator over all ancestors of the given BBs in BFS order."""
|
|
@@ -101,6 +105,10 @@ class CFG(BaseCFG[BB]):
|
|
|
101
105
|
inout_vars: list[str],
|
|
102
106
|
) -> dict[BB, VariableStats[str]]:
|
|
103
107
|
stats = {bb: bb.compute_variable_stats() for bb in self.bbs}
|
|
108
|
+
# Locals are variables that are assigned somewhere inside the function
|
|
109
|
+
self.assigned_somewhere = def_ass_before.union(
|
|
110
|
+
maybe_ass_before, (x for bb in self.bbs for x in stats[bb].assigned)
|
|
111
|
+
)
|
|
104
112
|
# Mark all borrowed variables as implicitly used in the exit BB
|
|
105
113
|
stats[self.exit_bb].used |= {x: InoutReturnSentinel(var=x) for x in inout_vars}
|
|
106
114
|
# This also means borrowed variables are always live, so we can use them as the
|
|
@@ -8,7 +8,7 @@ import ast
|
|
|
8
8
|
import collections
|
|
9
9
|
from collections.abc import Iterator, Sequence
|
|
10
10
|
from dataclasses import dataclass, field
|
|
11
|
-
from typing import ClassVar, Generic, TypeVar
|
|
11
|
+
from typing import ClassVar, Generic, TypeVar
|
|
12
12
|
|
|
13
13
|
from guppylang_internals.ast_util import line_col
|
|
14
14
|
from guppylang_internals.cfg.bb import BB
|
|
@@ -23,7 +23,6 @@ from guppylang_internals.checker.core import (
|
|
|
23
23
|
)
|
|
24
24
|
from guppylang_internals.checker.expr_checker import ExprSynthesizer, to_bool
|
|
25
25
|
from guppylang_internals.checker.stmt_checker import StmtChecker
|
|
26
|
-
from guppylang_internals.definition.value import ValueDef
|
|
27
26
|
from guppylang_internals.diagnostic import Error, Note
|
|
28
27
|
from guppylang_internals.error import GuppyError
|
|
29
28
|
from guppylang_internals.tys.param import Parameter
|
|
@@ -115,7 +114,7 @@ def check_cfg(
|
|
|
115
114
|
if bb in compiled:
|
|
116
115
|
# If the BB was already compiled, we just have to check that the signatures
|
|
117
116
|
# match.
|
|
118
|
-
check_rows_match(input_row, compiled[bb].sig.input_row, bb
|
|
117
|
+
check_rows_match(input_row, compiled[bb].sig.input_row, bb)
|
|
119
118
|
else:
|
|
120
119
|
# Otherwise, check the BB and enqueue its successors
|
|
121
120
|
checked_bb = check_bb(
|
|
@@ -195,21 +194,6 @@ class BranchTypeError(Error):
|
|
|
195
194
|
span_label: ClassVar[str] = "This is of type `{ty}`"
|
|
196
195
|
ty: Type
|
|
197
196
|
|
|
198
|
-
@dataclass(frozen=True)
|
|
199
|
-
class GlobalHint(Note):
|
|
200
|
-
message: ClassVar[str] = (
|
|
201
|
-
"{ident} may be shadowing a global {defn.description} definition of type "
|
|
202
|
-
"`{defn.ty}` on some branches"
|
|
203
|
-
)
|
|
204
|
-
defn: ValueDef
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
@dataclass(frozen=True)
|
|
208
|
-
class GlobalShadowError(Error):
|
|
209
|
-
title: ClassVar[str] = "Global variable conditionally shadowed"
|
|
210
|
-
span_label: ClassVar[str] = "{ident} may be shadowing a global variable"
|
|
211
|
-
ident: str
|
|
212
|
-
|
|
213
197
|
|
|
214
198
|
def check_bb(
|
|
215
199
|
bb: BB,
|
|
@@ -245,23 +229,27 @@ def check_bb(
|
|
|
245
229
|
|
|
246
230
|
for succ in bb.successors + bb.dummy_successors:
|
|
247
231
|
for x, use_bb in cfg.live_before[succ].items():
|
|
248
|
-
# Check that the variables requested by the successor are defined
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
232
|
+
# Check that the variables requested by the successor are defined. If `x` is
|
|
233
|
+
# a local variable, then we must be able to find it in the context.
|
|
234
|
+
# Following Python, locals are exactly those variables that are defined
|
|
235
|
+
# somewhere in the function body.
|
|
236
|
+
if x in cfg.assigned_somewhere:
|
|
237
|
+
if x not in ctx.locals:
|
|
238
|
+
# If the variable is defined on *some* paths, we can give a more
|
|
239
|
+
# informative error message
|
|
240
|
+
if x in cfg.maybe_ass_before[use_bb]:
|
|
241
|
+
err: Error = VarMaybeNotDefinedError(use_bb.vars.used[x], x)
|
|
242
|
+
if bad_branch := diagnose_maybe_undefined(use_bb, x, cfg):
|
|
243
|
+
branch_expr, truth_value = bad_branch
|
|
244
|
+
note = VarMaybeNotDefinedError.BadBranch(
|
|
245
|
+
branch_expr, x, truth_value
|
|
246
|
+
)
|
|
247
|
+
err.add_sub_diagnostic(note)
|
|
248
|
+
else:
|
|
249
|
+
err = VarNotDefinedError(use_bb.vars.used[x], x)
|
|
264
250
|
raise GuppyError(err)
|
|
251
|
+
# If x is not a local, then it must be a global or generic param
|
|
252
|
+
elif x not in ctx.globals and x not in ctx.generic_params:
|
|
265
253
|
raise GuppyError(VarNotDefinedError(use_bb.vars.used[x], x))
|
|
266
254
|
|
|
267
255
|
# Finally, we need to compute the signature of the basic block
|
|
@@ -287,9 +275,7 @@ def check_bb(
|
|
|
287
275
|
return checked_bb
|
|
288
276
|
|
|
289
277
|
|
|
290
|
-
def check_rows_match(
|
|
291
|
-
row1: Row[Variable], row2: Row[Variable], bb: BB, globals: Globals
|
|
292
|
-
) -> None:
|
|
278
|
+
def check_rows_match(row1: Row[Variable], row2: Row[Variable], bb: BB) -> None:
|
|
293
279
|
"""Checks that the types of two rows match up.
|
|
294
280
|
|
|
295
281
|
Otherwise, an error is thrown, alerting the user that a variable has different
|
|
@@ -299,10 +285,7 @@ def check_rows_match(
|
|
|
299
285
|
for x in map1.keys() | map2.keys():
|
|
300
286
|
# If block signature lengths don't match but no undefined error was thrown, some
|
|
301
287
|
# variables may be shadowing global variables.
|
|
302
|
-
v1 = map1
|
|
303
|
-
assert isinstance(v1, Variable | ValueDef)
|
|
304
|
-
v2 = map2.get(x) or cast(ValueDef, globals[x])
|
|
305
|
-
assert isinstance(v2, Variable | ValueDef)
|
|
288
|
+
v1, v2 = map1[x], map2[x]
|
|
306
289
|
if v1.ty != v2.ty:
|
|
307
290
|
# In the error message, we want to mention the variable that was first
|
|
308
291
|
# defined at the start.
|
|
@@ -320,31 +303,9 @@ def check_rows_match(
|
|
|
320
303
|
# We don't add a location to the type hint for the global variable,
|
|
321
304
|
# since it could lead to cross-file diagnostics (which are not
|
|
322
305
|
# supported) or refer to long function definitions.
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
if isinstance(v1, Variable)
|
|
326
|
-
else BranchTypeError.GlobalHint(None, v1)
|
|
327
|
-
)
|
|
328
|
-
sub2 = (
|
|
329
|
-
BranchTypeError.TypeHint(v2.defined_at, v2.ty)
|
|
330
|
-
if isinstance(v2, Variable)
|
|
331
|
-
else BranchTypeError.GlobalHint(None, v2)
|
|
332
|
-
)
|
|
333
|
-
err.add_sub_diagnostic(sub1)
|
|
334
|
-
err.add_sub_diagnostic(sub2)
|
|
306
|
+
err.add_sub_diagnostic(BranchTypeError.TypeHint(v1.defined_at, v1.ty))
|
|
307
|
+
err.add_sub_diagnostic(BranchTypeError.TypeHint(v2.defined_at, v2.ty))
|
|
335
308
|
raise GuppyError(err)
|
|
336
|
-
else:
|
|
337
|
-
# TODO: Remove once https://github.com/CQCL/guppylang/issues/827 is done.
|
|
338
|
-
# If either is a global variable, don't allow shadowing even if types match.
|
|
339
|
-
if not (isinstance(v1, Variable) and isinstance(v2, Variable)):
|
|
340
|
-
local_var = v1 if isinstance(v1, Variable) else v2
|
|
341
|
-
ident = (
|
|
342
|
-
"Expression"
|
|
343
|
-
if local_var.name.startswith("%")
|
|
344
|
-
else f"Variable `{local_var.name}`"
|
|
345
|
-
)
|
|
346
|
-
glob_err = GlobalShadowError(local_var.defined_at, ident)
|
|
347
|
-
raise GuppyError(glob_err)
|
|
348
309
|
|
|
349
310
|
|
|
350
311
|
def diagnose_maybe_undefined(
|
|
@@ -54,6 +54,7 @@ from guppylang_internals.tys.ty import (
|
|
|
54
54
|
|
|
55
55
|
if TYPE_CHECKING:
|
|
56
56
|
from guppylang_internals.definition.struct import StructField
|
|
57
|
+
from guppylang_internals.tys.parsing import TypeParsingCtx
|
|
57
58
|
|
|
58
59
|
|
|
59
60
|
#: A "place" is a description for a storage location of a local value that users
|
|
@@ -507,6 +508,13 @@ class Context(NamedTuple):
|
|
|
507
508
|
locals: Locals[str, Variable]
|
|
508
509
|
generic_params: dict[str, Parameter]
|
|
509
510
|
|
|
511
|
+
@property
|
|
512
|
+
def parsing_ctx(self) -> "TypeParsingCtx":
|
|
513
|
+
"""A type parsing context derived from this checking context."""
|
|
514
|
+
from guppylang_internals.tys.parsing import TypeParsingCtx
|
|
515
|
+
|
|
516
|
+
return TypeParsingCtx(self.globals, self.generic_params)
|
|
517
|
+
|
|
510
518
|
|
|
511
519
|
class DummyEvalDict(dict[str, Any]):
|
|
512
520
|
"""A custom dict that can be passed to `eval` to give better error messages.
|
|
@@ -34,6 +34,7 @@ from guppylang_internals.ast_util import (
|
|
|
34
34
|
AstNode,
|
|
35
35
|
AstVisitor,
|
|
36
36
|
breaks_in_loop,
|
|
37
|
+
get_type,
|
|
37
38
|
get_type_opt,
|
|
38
39
|
return_nodes_in_ast,
|
|
39
40
|
with_loc,
|
|
@@ -101,8 +102,6 @@ from guppylang_internals.nodes import (
|
|
|
101
102
|
FieldAccessAndDrop,
|
|
102
103
|
GenericParamValue,
|
|
103
104
|
GlobalName,
|
|
104
|
-
IterEnd,
|
|
105
|
-
IterHasNext,
|
|
106
105
|
IterNext,
|
|
107
106
|
LocalCall,
|
|
108
107
|
MakeIter,
|
|
@@ -784,14 +783,6 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
|
|
|
784
783
|
raise GuppyTypeError(err)
|
|
785
784
|
return expr, ty
|
|
786
785
|
|
|
787
|
-
def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, Type]:
|
|
788
|
-
node.value, ty = self.synthesize(node.value)
|
|
789
|
-
flags = InputFlags.Owned if not ty.copyable else InputFlags.NoFlags
|
|
790
|
-
exp_sig = FunctionType([FuncInput(ty, flags)], TupleType([bool_type(), ty]))
|
|
791
|
-
return self.synthesize_instance_func(
|
|
792
|
-
node.value, [], "__hasnext__", "an iterator", exp_sig, True
|
|
793
|
-
)
|
|
794
|
-
|
|
795
786
|
def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]:
|
|
796
787
|
node.value, ty = self.synthesize(node.value)
|
|
797
788
|
flags = InputFlags.Owned if not ty.copyable else InputFlags.NoFlags
|
|
@@ -803,14 +794,6 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
|
|
|
803
794
|
node.value, [], "__next__", "an iterator", exp_sig, True
|
|
804
795
|
)
|
|
805
796
|
|
|
806
|
-
def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, Type]:
|
|
807
|
-
node.value, ty = self.synthesize(node.value)
|
|
808
|
-
flags = InputFlags.Owned if not ty.copyable else InputFlags.NoFlags
|
|
809
|
-
exp_sig = FunctionType([FuncInput(ty, flags)], NoneType())
|
|
810
|
-
return self.synthesize_instance_func(
|
|
811
|
-
node.value, [], "__end__", "an iterator", exp_sig, True
|
|
812
|
-
)
|
|
813
|
-
|
|
814
797
|
def visit_ListComp(self, node: ast.ListComp) -> tuple[ast.expr, Type]:
|
|
815
798
|
raise InternalGuppyError(
|
|
816
799
|
"BB contains `ListComp`. Should have been removed during CFG"
|
|
@@ -946,7 +929,7 @@ def check_type_apply(ty: FunctionType, node: ast.Subscript, ctx: Context) -> Ins
|
|
|
946
929
|
raise GuppyError(err)
|
|
947
930
|
|
|
948
931
|
return [
|
|
949
|
-
param.check_arg(arg_from_ast(arg_expr,
|
|
932
|
+
param.check_arg(arg_from_ast(arg_expr, ctx.parsing_ctx), arg_expr)
|
|
950
933
|
for arg_expr, param in zip(arg_exprs, ty.params, strict=True)
|
|
951
934
|
]
|
|
952
935
|
|
|
@@ -1232,7 +1215,14 @@ def instantiate_poly(node: ast.expr, ty: FunctionType, inst: Inst) -> ast.expr:
|
|
|
1232
1215
|
"""Instantiates quantified type arguments in a function."""
|
|
1233
1216
|
assert len(ty.params) == len(inst)
|
|
1234
1217
|
if len(inst) > 0:
|
|
1235
|
-
|
|
1218
|
+
# Partial applications need to be instantiated on the inside
|
|
1219
|
+
if isinstance(node, PartialApply):
|
|
1220
|
+
full_ty = get_type(node.func)
|
|
1221
|
+
assert isinstance(full_ty, FunctionType)
|
|
1222
|
+
assert full_ty.params == ty.params
|
|
1223
|
+
node.func = instantiate_poly(node.func, full_ty, inst)
|
|
1224
|
+
else:
|
|
1225
|
+
node = with_loc(node, TypeApply(value=with_type(ty, node), inst=inst))
|
|
1236
1226
|
return with_type(ty.instantiate(inst), node)
|
|
1237
1227
|
return with_type(ty, node)
|
|
1238
1228
|
|
|
@@ -1309,11 +1299,7 @@ def eval_comptime_expr(node: ComptimeExpr, ctx: Context) -> Any:
|
|
|
1309
1299
|
raise GuppyError(ComptimeExprNotCPythonError(node))
|
|
1310
1300
|
|
|
1311
1301
|
try:
|
|
1312
|
-
python_val = eval( # noqa: S307
|
|
1313
|
-
ast.unparse(node.value),
|
|
1314
|
-
None,
|
|
1315
|
-
DummyEvalDict(ctx, node.value),
|
|
1316
|
-
)
|
|
1302
|
+
python_val = eval(ast.unparse(node.value), DummyEvalDict(ctx, node.value)) # noqa: S307
|
|
1317
1303
|
except DummyEvalDict.GuppyVarUsedError as e:
|
|
1318
1304
|
raise GuppyError(ComptimeExprNotStaticError(e.node or node, e.var)) from None
|
|
1319
1305
|
except Exception as e:
|
|
@@ -7,8 +7,8 @@ node straight from the Python AST. We build a CFG, check it, and return a
|
|
|
7
7
|
|
|
8
8
|
import ast
|
|
9
9
|
import sys
|
|
10
|
-
from dataclasses import dataclass
|
|
11
|
-
from typing import TYPE_CHECKING, ClassVar
|
|
10
|
+
from dataclasses import dataclass, replace
|
|
11
|
+
from typing import TYPE_CHECKING, ClassVar, cast
|
|
12
12
|
|
|
13
13
|
from guppylang_internals.ast_util import return_nodes_in_ast, with_loc
|
|
14
14
|
from guppylang_internals.cfg.bb import BB
|
|
@@ -17,13 +17,28 @@ from guppylang_internals.checker.cfg_checker import CheckedCFG, check_cfg
|
|
|
17
17
|
from guppylang_internals.checker.core import Context, Globals, Place, Variable
|
|
18
18
|
from guppylang_internals.checker.errors.generic import UnsupportedError
|
|
19
19
|
from guppylang_internals.definition.common import DefId
|
|
20
|
+
from guppylang_internals.definition.ty import TypeDef
|
|
20
21
|
from guppylang_internals.diagnostic import Error, Help, Note
|
|
21
22
|
from guppylang_internals.engine import DEF_STORE, ENGINE
|
|
22
23
|
from guppylang_internals.error import GuppyError
|
|
23
24
|
from guppylang_internals.experimental import check_capturing_closures_enabled
|
|
24
25
|
from guppylang_internals.nodes import CheckedNestedFunctionDef, NestedFunctionDef
|
|
25
|
-
from guppylang_internals.tys.parsing import
|
|
26
|
-
|
|
26
|
+
from guppylang_internals.tys.parsing import (
|
|
27
|
+
TypeParsingCtx,
|
|
28
|
+
check_function_arg,
|
|
29
|
+
parse_function_arg_annotation,
|
|
30
|
+
type_from_ast,
|
|
31
|
+
type_with_flags_from_ast,
|
|
32
|
+
)
|
|
33
|
+
from guppylang_internals.tys.ty import (
|
|
34
|
+
ExistentialTypeVar,
|
|
35
|
+
FuncInput,
|
|
36
|
+
FunctionType,
|
|
37
|
+
InputFlags,
|
|
38
|
+
NoneType,
|
|
39
|
+
Type,
|
|
40
|
+
unify,
|
|
41
|
+
)
|
|
27
42
|
|
|
28
43
|
if sys.version_info >= (3, 12):
|
|
29
44
|
from guppylang_internals.tys.parsing import parse_parameter
|
|
@@ -53,6 +68,15 @@ class MissingArgAnnotationError(Error):
|
|
|
53
68
|
span_label: ClassVar[str] = "Argument requires a type annotation"
|
|
54
69
|
|
|
55
70
|
|
|
71
|
+
@dataclass(frozen=True)
|
|
72
|
+
class RecursiveSelfError(Error):
|
|
73
|
+
title: ClassVar[str] = "Recursive self annotation"
|
|
74
|
+
span_label: ClassVar[str] = (
|
|
75
|
+
"Type of `{self_arg}` cannot recursively refer to `Self`"
|
|
76
|
+
)
|
|
77
|
+
self_arg: str
|
|
78
|
+
|
|
79
|
+
|
|
56
80
|
@dataclass(frozen=True)
|
|
57
81
|
class MissingReturnAnnotationError(Error):
|
|
58
82
|
title: ClassVar[str] = "Missing type annotation"
|
|
@@ -67,6 +91,43 @@ class MissingReturnAnnotationError(Error):
|
|
|
67
91
|
func: str
|
|
68
92
|
|
|
69
93
|
|
|
94
|
+
@dataclass(frozen=True)
|
|
95
|
+
class InvalidSelfError(Error):
|
|
96
|
+
title: ClassVar[str] = "Invalid self annotation"
|
|
97
|
+
span_label: ClassVar[str] = "`{self_arg}` must be of type `{self_ty}`"
|
|
98
|
+
self_arg: str
|
|
99
|
+
self_ty: Type
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@dataclass(frozen=True)
|
|
103
|
+
class SelfParamsShadowedError(Error):
|
|
104
|
+
title: ClassVar[str] = "Shadowed generic parameters"
|
|
105
|
+
span_label: ClassVar[str] = (
|
|
106
|
+
"Cannot infer type for `{self_arg}` since parameter `{param}` of "
|
|
107
|
+
"`{ty_defn.name}` is shadowed"
|
|
108
|
+
)
|
|
109
|
+
param: str
|
|
110
|
+
ty_defn: "TypeDef"
|
|
111
|
+
self_arg: str
|
|
112
|
+
|
|
113
|
+
@dataclass(frozen=True)
|
|
114
|
+
class ExplicitHelp(Help):
|
|
115
|
+
span_label: ClassVar[str] = (
|
|
116
|
+
"Consider specifying the type explicitly: `{suggestion}`"
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
@property
|
|
120
|
+
def suggestion(self) -> str:
|
|
121
|
+
parent = self._parent
|
|
122
|
+
assert isinstance(parent, SelfParamsShadowedError)
|
|
123
|
+
params = (
|
|
124
|
+
f"[{', '.join(f'?{p.name}' for p in parent.ty_defn.params)}]"
|
|
125
|
+
if parent.ty_defn.params
|
|
126
|
+
else ""
|
|
127
|
+
)
|
|
128
|
+
return f'{parent.self_arg}: "{parent.ty_defn.name}{params}"'
|
|
129
|
+
|
|
130
|
+
|
|
70
131
|
def check_global_func_def(
|
|
71
132
|
func_def: ast.FunctionDef, ty: FunctionType, globals: Globals
|
|
72
133
|
) -> CheckedCFG[Place]:
|
|
@@ -176,9 +237,16 @@ def check_nested_func_def(
|
|
|
176
237
|
return with_loc(func_def, checked_def)
|
|
177
238
|
|
|
178
239
|
|
|
179
|
-
def check_signature(
|
|
240
|
+
def check_signature(
|
|
241
|
+
func_def: ast.FunctionDef, globals: Globals, def_id: DefId | None = None
|
|
242
|
+
) -> FunctionType:
|
|
180
243
|
"""Checks the signature of a function definition and returns the corresponding
|
|
181
|
-
Guppy type.
|
|
244
|
+
Guppy type.
|
|
245
|
+
|
|
246
|
+
If this is a method, then the `DefId` of the associated parent type should also be
|
|
247
|
+
passed. This will be used to check or infer the type annotation for the `self`
|
|
248
|
+
argument.
|
|
249
|
+
"""
|
|
182
250
|
if len(func_def.args.posonlyargs) != 0:
|
|
183
251
|
raise GuppyError(
|
|
184
252
|
UnsupportedError(func_def.args.posonlyargs[0], "Positional-only parameters")
|
|
@@ -211,23 +279,29 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
|
|
|
211
279
|
param = parse_parameter(param_node, i, globals)
|
|
212
280
|
param_var_mapping[param.name] = param
|
|
213
281
|
|
|
214
|
-
|
|
282
|
+
# Figure out if this is a method
|
|
283
|
+
self_defn: TypeDef | None = None
|
|
284
|
+
if def_id is not None and def_id in DEF_STORE.impl_parents:
|
|
285
|
+
self_defn = cast(TypeDef, ENGINE.get_checked(DEF_STORE.impl_parents[def_id]))
|
|
286
|
+
assert isinstance(self_defn, TypeDef)
|
|
287
|
+
|
|
288
|
+
inputs = []
|
|
215
289
|
input_names = []
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
290
|
+
ctx = TypeParsingCtx(globals, param_var_mapping, allow_free_vars=True)
|
|
291
|
+
for i, inp in enumerate(func_def.args.args):
|
|
292
|
+
# Special handling for `self` arguments. Note that `__new__` is excluded here
|
|
293
|
+
# since it's not a method so doesn't take `self`.
|
|
294
|
+
if self_defn and i == 0 and func_def.name != "__new__":
|
|
295
|
+
input = parse_self_arg(inp, self_defn, ctx)
|
|
296
|
+
ctx = replace(ctx, self_ty=input.ty)
|
|
297
|
+
else:
|
|
298
|
+
ty_ast = inp.annotation
|
|
299
|
+
if ty_ast is None:
|
|
300
|
+
raise GuppyError(MissingArgAnnotationError(inp))
|
|
301
|
+
input = parse_function_arg_annotation(ty_ast, inp.arg, ctx)
|
|
302
|
+
inputs.append(input)
|
|
221
303
|
input_names.append(inp.arg)
|
|
222
|
-
|
|
223
|
-
input_nodes,
|
|
224
|
-
func_def.returns,
|
|
225
|
-
input_names,
|
|
226
|
-
func_def,
|
|
227
|
-
globals,
|
|
228
|
-
param_var_mapping,
|
|
229
|
-
True,
|
|
230
|
-
)
|
|
304
|
+
output = type_from_ast(func_def.returns, ctx)
|
|
231
305
|
return FunctionType(
|
|
232
306
|
inputs,
|
|
233
307
|
output,
|
|
@@ -236,6 +310,81 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
|
|
|
236
310
|
)
|
|
237
311
|
|
|
238
312
|
|
|
313
|
+
def parse_self_arg(arg: ast.arg, self_defn: TypeDef, ctx: TypeParsingCtx) -> FuncInput:
|
|
314
|
+
"""Handles parsing of the `self` argument on methods.
|
|
315
|
+
|
|
316
|
+
This argument is special since its type annotation may be omitted. Furthermore, if a
|
|
317
|
+
type is provided then it must match the parent type.
|
|
318
|
+
"""
|
|
319
|
+
assert self_defn.params is not None
|
|
320
|
+
if arg.annotation is None:
|
|
321
|
+
return handle_implicit_self_arg(arg, self_defn, ctx)
|
|
322
|
+
|
|
323
|
+
# If the user has provided an annotation for `self`, then we go ahead and parse it.
|
|
324
|
+
# However, in the annotation the user is also allowed to use `Self`, so we have to
|
|
325
|
+
# specify a `self_ty` in the context.
|
|
326
|
+
self_ty_head = self_defn.check_instantiate(
|
|
327
|
+
[param.to_existential()[0] for param in self_defn.params]
|
|
328
|
+
)
|
|
329
|
+
self_ty_placeholder = ExistentialTypeVar.fresh(
|
|
330
|
+
"Self", copyable=self_ty_head.copyable, droppable=self_ty_head.droppable
|
|
331
|
+
)
|
|
332
|
+
assert ctx.self_ty is None
|
|
333
|
+
ctx = replace(ctx, self_ty=self_ty_placeholder)
|
|
334
|
+
user_ty, user_flags = type_with_flags_from_ast(arg.annotation, ctx)
|
|
335
|
+
|
|
336
|
+
# If the user just annotates `self: Self` then we can fall back to the case where
|
|
337
|
+
# no annotation is provided at all
|
|
338
|
+
if user_ty == self_ty_placeholder:
|
|
339
|
+
return handle_implicit_self_arg(arg, self_defn, ctx, user_flags)
|
|
340
|
+
|
|
341
|
+
# Annotations like `self: Foo[Self]` are not allowed (would be an infinite type)
|
|
342
|
+
if self_ty_placeholder in user_ty.unsolved_vars:
|
|
343
|
+
raise GuppyError(RecursiveSelfError(arg.annotation, arg.arg))
|
|
344
|
+
|
|
345
|
+
# Check that the annotation matches the parent type. We can do this by unifying with
|
|
346
|
+
# the expected self type where all params are instantiated with unification vars
|
|
347
|
+
subst = unify(user_ty, self_ty_head, {})
|
|
348
|
+
if subst is None:
|
|
349
|
+
raise GuppyError(InvalidSelfError(arg.annotation, arg.arg, self_ty_head))
|
|
350
|
+
|
|
351
|
+
return check_function_arg(user_ty, user_flags, arg, arg.arg, ctx)
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def handle_implicit_self_arg(
|
|
355
|
+
arg: ast.arg,
|
|
356
|
+
self_defn: TypeDef,
|
|
357
|
+
ctx: TypeParsingCtx,
|
|
358
|
+
flags: InputFlags = InputFlags.NoFlags,
|
|
359
|
+
) -> FuncInput:
|
|
360
|
+
"""Handles the case where no annotation for `self` is provided.
|
|
361
|
+
|
|
362
|
+
Generates the most generic annotation that is possible by making the function as
|
|
363
|
+
generic as the parent type.
|
|
364
|
+
"""
|
|
365
|
+
# Check that the user hasn't shadowed some of the parent type parameters using a
|
|
366
|
+
# Python 3.12 style parameter declaration
|
|
367
|
+
assert self_defn.params is not None
|
|
368
|
+
shadowed_params = [
|
|
369
|
+
param for param in self_defn.params if param.name in ctx.param_var_mapping
|
|
370
|
+
]
|
|
371
|
+
if shadowed_params:
|
|
372
|
+
param = shadowed_params.pop()
|
|
373
|
+
err = SelfParamsShadowedError(arg, param.name, self_defn, arg.arg)
|
|
374
|
+
err.add_sub_diagnostic(SelfParamsShadowedError.ExplicitHelp(arg))
|
|
375
|
+
raise GuppyError(err)
|
|
376
|
+
|
|
377
|
+
# The generic params inherited from the parent type should appear first in the
|
|
378
|
+
# parameter list, so we have to shift the existing ones
|
|
379
|
+
for name, param in ctx.param_var_mapping.items():
|
|
380
|
+
ctx.param_var_mapping[name] = param.with_idx(param.idx + len(self_defn.params))
|
|
381
|
+
|
|
382
|
+
ctx.param_var_mapping.update({param.name: param for param in self_defn.params})
|
|
383
|
+
self_args = [param.to_bound() for param in self_defn.params]
|
|
384
|
+
self_ty = self_defn.check_instantiate(self_args, loc=arg)
|
|
385
|
+
return check_function_arg(self_ty, flags, arg, arg.arg, ctx)
|
|
386
|
+
|
|
387
|
+
|
|
239
388
|
def parse_function_with_docstring(
|
|
240
389
|
func_ast: ast.FunctionDef,
|
|
241
390
|
) -> tuple[ast.FunctionDef, str | None]:
|
|
@@ -356,7 +356,7 @@ class StmtChecker(AstVisitor[BBStatement]):
|
|
|
356
356
|
def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt:
|
|
357
357
|
if node.value is None:
|
|
358
358
|
raise GuppyError(UnsupportedError(node, "Variable declarations"))
|
|
359
|
-
ty = type_from_ast(node.annotation, self.ctx.
|
|
359
|
+
ty = type_from_ast(node.annotation, self.ctx.parsing_ctx)
|
|
360
360
|
node.value, subst = self._check_expr(node.value, ty)
|
|
361
361
|
assert not ty.unsolved_vars # `ty` must be closed!
|
|
362
362
|
assert len(subst) == 0
|