guppylang-internals 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +21 -0
- guppylang_internals/cfg/bb.py +20 -0
- guppylang_internals/cfg/builder.py +101 -3
- guppylang_internals/checker/core.py +12 -0
- guppylang_internals/checker/errors/generic.py +32 -1
- guppylang_internals/checker/errors/type_errors.py +14 -0
- guppylang_internals/checker/expr_checker.py +55 -29
- guppylang_internals/checker/func_checker.py +171 -22
- guppylang_internals/checker/linearity_checker.py +65 -0
- guppylang_internals/checker/modifier_checker.py +116 -0
- guppylang_internals/checker/stmt_checker.py +49 -2
- guppylang_internals/compiler/core.py +90 -53
- guppylang_internals/compiler/expr_compiler.py +49 -114
- guppylang_internals/compiler/modifier_compiler.py +174 -0
- guppylang_internals/compiler/stmt_compiler.py +15 -8
- guppylang_internals/decorator.py +124 -58
- guppylang_internals/definition/const.py +2 -2
- guppylang_internals/definition/custom.py +36 -2
- guppylang_internals/definition/declaration.py +4 -5
- guppylang_internals/definition/extern.py +2 -2
- guppylang_internals/definition/function.py +1 -1
- guppylang_internals/definition/parameter.py +10 -5
- guppylang_internals/definition/pytket_circuits.py +14 -42
- guppylang_internals/definition/struct.py +17 -14
- guppylang_internals/definition/traced.py +1 -1
- guppylang_internals/definition/ty.py +9 -3
- guppylang_internals/definition/wasm.py +2 -2
- guppylang_internals/engine.py +13 -2
- guppylang_internals/experimental.py +5 -0
- guppylang_internals/nodes.py +124 -23
- guppylang_internals/std/_internal/compiler/array.py +94 -282
- guppylang_internals/std/_internal/compiler/tket_exts.py +12 -8
- guppylang_internals/std/_internal/compiler/wasm.py +37 -26
- guppylang_internals/tracing/function.py +13 -2
- guppylang_internals/tracing/unpacking.py +33 -28
- guppylang_internals/tys/arg.py +18 -3
- guppylang_internals/tys/builtin.py +32 -16
- guppylang_internals/tys/const.py +33 -4
- guppylang_internals/tys/errors.py +6 -0
- guppylang_internals/tys/param.py +31 -16
- guppylang_internals/tys/parsing.py +118 -145
- guppylang_internals/tys/qubit.py +27 -0
- guppylang_internals/tys/subst.py +8 -26
- guppylang_internals/tys/ty.py +31 -21
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/METADATA +4 -4
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/RECORD +49 -46
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/WHEEL +0 -0
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/licenses/LICENCE +0 -0
guppylang_internals/__init__.py
CHANGED
guppylang_internals/ast_util.py
CHANGED
|
@@ -106,6 +106,14 @@ def return_nodes_in_ast(node: Any) -> list[ast.Return]:
|
|
|
106
106
|
return cast(list[ast.Return], found)
|
|
107
107
|
|
|
108
108
|
|
|
109
|
+
def loop_in_ast(node: Any) -> list[ast.For | ast.While]:
    """Collect every `for` / `while` loop node found in the AST below `node`.

    The traversal is delegated to `find_nodes`; `{ast.FunctionDef}` is passed as
    its third argument, presumably so that loops inside nested function
    definitions are not reported.
    """
    loop_types = (ast.For, ast.While)

    def is_loop(candidate: Any) -> bool:
        return isinstance(candidate, loop_types)

    loops = find_nodes(is_loop, node, {ast.FunctionDef})
    return cast(list[ast.For | ast.While], loops)
|
|
115
|
+
|
|
116
|
+
|
|
109
117
|
def breaks_in_loop(node: Any) -> list[ast.Break]:
|
|
110
118
|
"""Returns all `Break` nodes occurring in a loop.
|
|
111
119
|
|
|
@@ -117,6 +125,19 @@ def breaks_in_loop(node: Any) -> list[ast.Break]:
|
|
|
117
125
|
return cast(list[ast.Break], found)
|
|
118
126
|
|
|
119
127
|
|
|
128
|
+
def loop_controls_in_loop(node: Any) -> list[ast.Break | ast.Continue]:
    """Returns all `Break` and `Continue` nodes occurring in a loop.

    Note that `break` and `continue` statements inside nested loops (`for` or
    `while`) and inside nested function definitions are excluded: these node
    types are passed to `find_nodes` as the set of nodes not to descend into.
    Such statements belong to the inner loop/function, not to `node`.
    """
    found = find_nodes(
        lambda n: isinstance(n, ast.Break | ast.Continue),
        node,
        {ast.For, ast.While, ast.FunctionDef},
    )
    return cast(list[ast.Break | ast.Continue], found)
|
|
139
|
+
|
|
140
|
+
|
|
120
141
|
class ContextAdjuster(ast.NodeTransformer):
|
|
121
142
|
"""Updates the `ast.Context` indicating if expressions occur on the LHS or RHS."""
|
|
122
143
|
|
guppylang_internals/cfg/bb.py
CHANGED
|
@@ -13,6 +13,7 @@ from guppylang_internals.nodes import (
|
|
|
13
13
|
DesugaredGenerator,
|
|
14
14
|
DesugaredGeneratorExpr,
|
|
15
15
|
DesugaredListComp,
|
|
16
|
+
ModifiedBlock,
|
|
16
17
|
NestedFunctionDef,
|
|
17
18
|
)
|
|
18
19
|
|
|
@@ -44,6 +45,7 @@ BBStatement = (
|
|
|
44
45
|
| ast.Expr
|
|
45
46
|
| ast.Return
|
|
46
47
|
| NestedFunctionDef
|
|
48
|
+
| ModifiedBlock
|
|
47
49
|
)
|
|
48
50
|
|
|
49
51
|
|
|
@@ -219,3 +221,21 @@ class VariableVisitor(ast.NodeVisitor):
|
|
|
219
221
|
|
|
220
222
|
# The name of the function is now assigned
|
|
221
223
|
self.stats.assigned[node.name] = node
|
|
224
|
+
|
|
225
|
+
def visit_ModifiedBlock(self, node: ModifiedBlock) -> None:
    """Record variable usage for a modifier (`with ...:`) block.

    The `control(...)` and `power(...)` modifier nodes are visited in the
    current BB. The block body lives in its own CFG (`node.cfg`), so, as for
    nested functions, the variables live at the entry of that CFG are treated
    as uses in the current BB. Variables already assigned earlier in the
    current BB are excluded.
    """
    for modifier in (*node.control, *node.power):
        self.visit(modifier)

    # Handled the same way as nested functions (local import kept as in the
    # original implementation).
    from guppylang_internals.cfg.analysis import LivenessAnalysis

    body_cfg = node.cfg
    per_bb_stats = {block: block.compute_variable_stats() for block in body_cfg.bbs}
    liveness = LivenessAnalysis(per_bb_stats).run(body_cfg.bbs)
    live_at_entry = liveness[body_cfg.entry_bb]
    already_assigned = self.stats.assigned.keys()
    self.stats.used |= {
        name: user_bb.vars.used[name]
        for name, user_bb in live_at_entry.items()
        if name not in already_assigned
    }
|
|
@@ -9,6 +9,8 @@ from guppylang_internals.ast_util import (
|
|
|
9
9
|
AstVisitor,
|
|
10
10
|
ContextAdjuster,
|
|
11
11
|
find_nodes,
|
|
12
|
+
loop_controls_in_loop,
|
|
13
|
+
return_nodes_in_ast,
|
|
12
14
|
set_location_from,
|
|
13
15
|
template_replace,
|
|
14
16
|
with_loc,
|
|
@@ -16,19 +18,34 @@ from guppylang_internals.ast_util import (
|
|
|
16
18
|
from guppylang_internals.cfg.bb import BB, BBStatement
|
|
17
19
|
from guppylang_internals.cfg.cfg import CFG
|
|
18
20
|
from guppylang_internals.checker.core import Globals
|
|
19
|
-
from guppylang_internals.checker.errors.generic import
|
|
21
|
+
from guppylang_internals.checker.errors.generic import (
|
|
22
|
+
ExpectedError,
|
|
23
|
+
UnexpectedInWithBlockError,
|
|
24
|
+
UnknownModifierError,
|
|
25
|
+
UnsupportedError,
|
|
26
|
+
)
|
|
27
|
+
from guppylang_internals.checker.errors.type_errors import WrongNumberOfArgsError
|
|
20
28
|
from guppylang_internals.diagnostic import Error
|
|
21
29
|
from guppylang_internals.error import GuppyError, InternalGuppyError
|
|
22
|
-
from guppylang_internals.experimental import
|
|
30
|
+
from guppylang_internals.experimental import (
|
|
31
|
+
check_lists_enabled,
|
|
32
|
+
check_modifiers_enabled,
|
|
33
|
+
)
|
|
23
34
|
from guppylang_internals.nodes import (
|
|
24
35
|
ComptimeExpr,
|
|
36
|
+
Control,
|
|
37
|
+
Dagger,
|
|
25
38
|
DesugaredGenerator,
|
|
26
39
|
DesugaredGeneratorExpr,
|
|
27
40
|
DesugaredListComp,
|
|
28
41
|
IterNext,
|
|
29
42
|
MakeIter,
|
|
43
|
+
ModifiedBlock,
|
|
44
|
+
Modifier,
|
|
30
45
|
NestedFunctionDef,
|
|
46
|
+
Power,
|
|
31
47
|
)
|
|
48
|
+
from guppylang_internals.span import Span, to_span
|
|
32
49
|
from guppylang_internals.tys.ty import NoneType
|
|
33
50
|
|
|
34
51
|
# In order to build expressions, need an endless stream of unique temporary variables
|
|
@@ -135,7 +152,10 @@ class CFGBuilder(AstVisitor[BB | None]):
|
|
|
135
152
|
Builds the expression and mutates `node.value` to point to the built expression.
|
|
136
153
|
Returns the BB in which the expression is available and adds the node to it.
|
|
137
154
|
"""
|
|
138
|
-
if
|
|
155
|
+
if (
|
|
156
|
+
not isinstance(node, NestedFunctionDef | ModifiedBlock)
|
|
157
|
+
and node.value is not None
|
|
158
|
+
):
|
|
139
159
|
node.value, bb = ExprBuilder.build(node.value, self.cfg, bb)
|
|
140
160
|
bb.statements.append(node)
|
|
141
161
|
return bb
|
|
@@ -265,6 +285,84 @@ class CFGBuilder(AstVisitor[BB | None]):
|
|
|
265
285
|
bb.statements.append(new_node)
|
|
266
286
|
return bb
|
|
267
287
|
|
|
288
|
+
def visit_With(self, node: ast.With, bb: BB, jumps: Jumps) -> BB | None:
    """Turn a `with <modifiers>:` statement into a `ModifiedBlock` statement.

    The body is built into a separate CFG. Each modifier expression is built
    in the current CFG via `ExprBuilder.build`, which may hand back a different
    BB. The new node is appended to whichever BB is current at the end, and
    that BB is returned.
    """
    check_modifiers_enabled(node)
    self._validate_modified_block(node)

    body_cfg = CFGBuilder().build(node.body, True, self.globals)
    with_fields = dict(ast.iter_fields(node))
    modified = ModifiedBlock(cfg=body_cfg, **with_fields)

    for withitem in node.items:
        withitem.context_expr, bb = ExprBuilder.build(
            withitem.context_expr, self.cfg, bb
        )
        modified.push_modifier(self._handle_withitem(withitem))

    set_location_from(modified, node)
    bb.statements.append(modified)
    return bb
|
|
306
|
+
|
|
307
|
+
def _handle_withitem(self, node: ast.withitem) -> Modifier:
|
|
308
|
+
# Check that `as` notation is not used
|
|
309
|
+
if node.optional_vars is not None:
|
|
310
|
+
span = Span(
|
|
311
|
+
to_span(node.context_expr).start, to_span(node.optional_vars).end
|
|
312
|
+
)
|
|
313
|
+
raise GuppyError(UnsupportedError(span, "`as` expression", singular=True))
|
|
314
|
+
|
|
315
|
+
e = node.context_expr
|
|
316
|
+
modifier: Modifier
|
|
317
|
+
match e:
|
|
318
|
+
case ast.Name(id="dagger"):
|
|
319
|
+
modifier = Dagger(e)
|
|
320
|
+
case ast.Call(func=ast.Name(id="dagger")):
|
|
321
|
+
if len(e.args) != 0:
|
|
322
|
+
span = Span(to_span(e.args[0]).start, to_span(e.args[-1]).end)
|
|
323
|
+
raise GuppyError(WrongNumberOfArgsError(span, 0, len(e.args)))
|
|
324
|
+
modifier = Dagger(e)
|
|
325
|
+
case ast.Call(func=ast.Name(id="control")):
|
|
326
|
+
if len(e.args) == 0:
|
|
327
|
+
span = Span(to_span(e.func).end, to_span(e).end)
|
|
328
|
+
raise GuppyError(WrongNumberOfArgsError(span, 1, len(e.args)))
|
|
329
|
+
modifier = Control(e, e.args)
|
|
330
|
+
case ast.Call(func=ast.Name(id="power")):
|
|
331
|
+
if len(e.args) == 0:
|
|
332
|
+
span = Span(to_span(e.func).end, to_span(e).end)
|
|
333
|
+
raise GuppyError(WrongNumberOfArgsError(span, 1, len(e.args)))
|
|
334
|
+
elif len(e.args) != 1:
|
|
335
|
+
span = Span(to_span(e.args[1]).start, to_span(e.args[-1]).end)
|
|
336
|
+
raise GuppyError(WrongNumberOfArgsError(span, 1, len(e.args)))
|
|
337
|
+
modifier = Power(e, e.args[0])
|
|
338
|
+
case _:
|
|
339
|
+
raise GuppyError(UnknownModifierError(e))
|
|
340
|
+
return modifier
|
|
341
|
+
|
|
342
|
+
def _validate_modified_block(self, node: ast.With) -> None:
    """Rejects `return`, `break` and `continue` inside a modifier `with` block.

    Returns are checked first, then loop controls. The raised error points at
    the first offending statement. It carries a note whose span covers all
    modifier expressions of the `with` statement.
    """

    def raise_inside_with(offender: ast.AST, kind: str, things: str) -> None:
        err = UnexpectedInWithBlockError(offender, kind, things)
        modifiers_span = Span(
            to_span(node.items[0].context_expr).start,
            to_span(node.items[-1].context_expr).end,
        )
        err.add_sub_diagnostic(UnexpectedInWithBlockError.Modifier(modifiers_span))
        raise GuppyError(err)

    # Check if the body contains a return statement.
    if returns := return_nodes_in_ast(node):
        raise_inside_with(returns[0], "return", "Return")

    # `break`/`continue` that would escape the block ("Break" or "Continue").
    if controls := loop_controls_in_loop(node):
        first_control = controls[0]
        raise_inside_with(first_control, "loop control", type(first_control).__name__)
|
|
365
|
+
|
|
268
366
|
def generic_visit(self, node: ast.AST, bb: BB, jumps: Jumps) -> BB | None:
|
|
269
367
|
# When adding support for new statements, we have to remember to use the
|
|
270
368
|
# ExprBuilder to transform all included expressions!
|
|
@@ -54,6 +54,7 @@ from guppylang_internals.tys.ty import (
|
|
|
54
54
|
|
|
55
55
|
if TYPE_CHECKING:
|
|
56
56
|
from guppylang_internals.definition.struct import StructField
|
|
57
|
+
from guppylang_internals.tys.parsing import TypeParsingCtx
|
|
57
58
|
|
|
58
59
|
|
|
59
60
|
#: A "place" is a description for a storage location of a local value that users
|
|
@@ -116,6 +117,10 @@ class Variable:
|
|
|
116
117
|
"""Returns a new `Variable` instance with an updated definition location."""
|
|
117
118
|
return replace(self, defined_at=node)
|
|
118
119
|
|
|
120
|
+
def add_flags(self, flags: InputFlags) -> "Variable":
    """Return a copy of this variable with `flags` OR-ed into its current flags.

    The receiver is not modified; the copy is produced with `replace`.
    """
    merged = self.flags | flags
    return replace(self, flags=merged)
|
|
123
|
+
|
|
119
124
|
|
|
120
125
|
@dataclass(frozen=True, kw_only=True)
|
|
121
126
|
class ComptimeVariable(Variable):
|
|
@@ -507,6 +512,13 @@ class Context(NamedTuple):
|
|
|
507
512
|
locals: Locals[str, Variable]
|
|
508
513
|
generic_params: dict[str, Parameter]
|
|
509
514
|
|
|
515
|
+
@property
def parsing_ctx(self) -> "TypeParsingCtx":
    """A `TypeParsingCtx` built from this context's globals and generic params.

    `TypeParsingCtx` is imported at call time; the module-level import only
    exists under `TYPE_CHECKING`, presumably to avoid an import cycle.
    """
    from guppylang_internals.tys.parsing import TypeParsingCtx as _ParsingCtx

    globals_, params = self.globals, self.generic_params
    return _ParsingCtx(globals_, params)
|
|
521
|
+
|
|
510
522
|
|
|
511
523
|
class DummyEvalDict(dict[str, Any]):
|
|
512
524
|
"""A custom dict that can be passed to `eval` to give better error messages.
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
2
|
from typing import ClassVar
|
|
3
3
|
|
|
4
|
-
from guppylang_internals.diagnostic import Error
|
|
4
|
+
from guppylang_internals.diagnostic import Error, Note
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
@dataclass(frozen=True)
|
|
@@ -43,3 +43,34 @@ class ExpectedError(Error):
|
|
|
43
43
|
@property
|
|
44
44
|
def extra(self) -> str:
|
|
45
45
|
return f", got {self.got}" if self.got else ""
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass(frozen=True)
class UnknownModifierError(Error):
    """A `with` item that is not one of the supported modifiers."""

    # The doubled braces escape literal `{`/`}`: labels are templates with
    # `{field}` placeholders elsewhere in this module.
    span_label: ClassVar[str] = (
        "Expected one of {{dagger, control(...), or power(...)}}"
    )
    title: ClassVar[str] = "Unknown modifier"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@dataclass(frozen=True)
class UnexpectedInWithBlockError(Error):
    """A statement that is not allowed inside a modifier `with` block."""

    span_label: ClassVar[str] = "{things} found in a `With` block"
    title: ClassVar[str] = "Unexpected {kind}"

    #: Category of the offending statement, substituted into the title.
    kind: str
    #: Name of the offending statement, substituted into the span label.
    things: str

    @dataclass(frozen=True)
    class Modifier(Note):
        """Note pointing at the modifiers of the offending `with` statement."""

        span_label: ClassVar[str] = "modifier is used here"
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass(frozen=True)
class InvalidUnderDagger(Error):
    """An expression that is not permitted inside a dagger context."""

    span_label: ClassVar[str] = "{things} found in a dagger context"
    title: ClassVar[str] = "Invalid expression in dagger"

    #: Description of the offending expression, substituted into the label.
    things: str

    @dataclass(frozen=True)
    class Dagger(Note):
        """Note pointing at the dagger modifier that opened the context."""

        span_label: ClassVar[str] = "dagger modifier is used here"
|
|
@@ -95,6 +95,20 @@ class TypeInferenceError(Error):
|
|
|
95
95
|
unsolved_ty: Type
|
|
96
96
|
|
|
97
97
|
|
|
98
|
+
@dataclass(frozen=True)
class ParameterInferenceError(Error):
    """A generic parameter of a called function could not be inferred."""

    span_label: ClassVar[str] = (
        "Cannot infer generic parameter `{param}` of this function"
    )
    title: ClassVar[str] = "Cannot infer generic parameter"

    #: Name of the parameter that remained unsolved.
    param: str

    @dataclass(frozen=True)
    class SignatureHint(Note):
        """Note displaying the signature of the function being called."""

        message: ClassVar[str] = "Function signature is `{sig}`"
        sig: FunctionType
|
|
110
|
+
|
|
111
|
+
|
|
98
112
|
@dataclass(frozen=True)
|
|
99
113
|
class IllegalConstant(Error):
|
|
100
114
|
title: ClassVar[str] = "Unsupported constant"
|
|
@@ -23,6 +23,7 @@ can be used to infer a type for an expression.
|
|
|
23
23
|
import ast
|
|
24
24
|
import sys
|
|
25
25
|
import traceback
|
|
26
|
+
from collections.abc import Sequence
|
|
26
27
|
from contextlib import suppress
|
|
27
28
|
from dataclasses import replace
|
|
28
29
|
from types import ModuleType
|
|
@@ -34,6 +35,7 @@ from guppylang_internals.ast_util import (
|
|
|
34
35
|
AstNode,
|
|
35
36
|
AstVisitor,
|
|
36
37
|
breaks_in_loop,
|
|
38
|
+
get_type,
|
|
37
39
|
get_type_opt,
|
|
38
40
|
return_nodes_in_ast,
|
|
39
41
|
with_loc,
|
|
@@ -73,6 +75,7 @@ from guppylang_internals.checker.errors.type_errors import (
|
|
|
73
75
|
ModuleMemberNotFoundError,
|
|
74
76
|
NonLinearInstantiateError,
|
|
75
77
|
NotCallableError,
|
|
78
|
+
ParameterInferenceError,
|
|
76
79
|
TupleIndexOutOfBoundsError,
|
|
77
80
|
TypeApplyNotGenericError,
|
|
78
81
|
TypeInferenceError,
|
|
@@ -101,8 +104,6 @@ from guppylang_internals.nodes import (
|
|
|
101
104
|
FieldAccessAndDrop,
|
|
102
105
|
GenericParamValue,
|
|
103
106
|
GlobalName,
|
|
104
|
-
IterEnd,
|
|
105
|
-
IterHasNext,
|
|
106
107
|
IterNext,
|
|
107
108
|
LocalCall,
|
|
108
109
|
MakeIter,
|
|
@@ -131,7 +132,7 @@ from guppylang_internals.tys.builtin import (
|
|
|
131
132
|
string_type,
|
|
132
133
|
)
|
|
133
134
|
from guppylang_internals.tys.const import Const, ConstValue
|
|
134
|
-
from guppylang_internals.tys.param import ConstParam, TypeParam
|
|
135
|
+
from guppylang_internals.tys.param import ConstParam, TypeParam, check_all_args
|
|
135
136
|
from guppylang_internals.tys.parsing import arg_from_ast
|
|
136
137
|
from guppylang_internals.tys.subst import Inst, Subst
|
|
137
138
|
from guppylang_internals.tys.ty import (
|
|
@@ -150,6 +151,7 @@ from guppylang_internals.tys.ty import (
|
|
|
150
151
|
parse_function_tensor,
|
|
151
152
|
unify,
|
|
152
153
|
)
|
|
154
|
+
from guppylang_internals.tys.var import ExistentialVar
|
|
153
155
|
|
|
154
156
|
if TYPE_CHECKING:
|
|
155
157
|
from guppylang_internals.diagnostic import SubDiagnostic
|
|
@@ -463,8 +465,15 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
|
|
|
463
465
|
# A `value.attr` attribute access. Unfortunately, the `attr` is just a string,
|
|
464
466
|
# not an AST node, so we have to compute its span by hand. This is fine since
|
|
465
467
|
# linebreaks are not allowed in the identifier following the `.`
|
|
468
|
+
# The only exception are attributes accesses that are generated during
|
|
469
|
+
# desugaring (for example for iterators in `for` loops). Since those just
|
|
470
|
+
# inherit the span of the sugared code, we could have line breaks there.
|
|
471
|
+
# See https://github.com/CQCL/guppylang/issues/1301
|
|
466
472
|
span = to_span(node)
|
|
467
|
-
|
|
473
|
+
if span.start.line == span.end.line:
|
|
474
|
+
attr_span = Span(span.end.shift_left(len(node.attr)), span.end)
|
|
475
|
+
else:
|
|
476
|
+
attr_span = span
|
|
468
477
|
if module := self._is_python_module(node.value):
|
|
469
478
|
if node.attr in module.__dict__:
|
|
470
479
|
val = module.__dict__[node.attr]
|
|
@@ -784,14 +793,6 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
|
|
|
784
793
|
raise GuppyTypeError(err)
|
|
785
794
|
return expr, ty
|
|
786
795
|
|
|
787
|
-
def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, Type]:
|
|
788
|
-
node.value, ty = self.synthesize(node.value)
|
|
789
|
-
flags = InputFlags.Owned if not ty.copyable else InputFlags.NoFlags
|
|
790
|
-
exp_sig = FunctionType([FuncInput(ty, flags)], TupleType([bool_type(), ty]))
|
|
791
|
-
return self.synthesize_instance_func(
|
|
792
|
-
node.value, [], "__hasnext__", "an iterator", exp_sig, True
|
|
793
|
-
)
|
|
794
|
-
|
|
795
796
|
def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]:
|
|
796
797
|
node.value, ty = self.synthesize(node.value)
|
|
797
798
|
flags = InputFlags.Owned if not ty.copyable else InputFlags.NoFlags
|
|
@@ -803,14 +804,6 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
|
|
|
803
804
|
node.value, [], "__next__", "an iterator", exp_sig, True
|
|
804
805
|
)
|
|
805
806
|
|
|
806
|
-
def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, Type]:
|
|
807
|
-
node.value, ty = self.synthesize(node.value)
|
|
808
|
-
flags = InputFlags.Owned if not ty.copyable else InputFlags.NoFlags
|
|
809
|
-
exp_sig = FunctionType([FuncInput(ty, flags)], NoneType())
|
|
810
|
-
return self.synthesize_instance_func(
|
|
811
|
-
node.value, [], "__end__", "an iterator", exp_sig, True
|
|
812
|
-
)
|
|
813
|
-
|
|
814
807
|
def visit_ListComp(self, node: ast.ListComp) -> tuple[ast.expr, Type]:
|
|
815
808
|
raise InternalGuppyError(
|
|
816
809
|
"BB contains `ListComp`. Should have been removed during CFG"
|
|
@@ -945,10 +938,9 @@ def check_type_apply(ty: FunctionType, node: ast.Subscript, ctx: Context) -> Ins
|
|
|
945
938
|
err.add_sub_diagnostic(WrongNumberOfArgsError.SignatureHint(None, ty))
|
|
946
939
|
raise GuppyError(err)
|
|
947
940
|
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
]
|
|
941
|
+
inst = [arg_from_ast(node, ctx.parsing_ctx) for node in arg_exprs]
|
|
942
|
+
check_all_args(ty.params, inst, "", node, arg_exprs)
|
|
943
|
+
return inst
|
|
952
944
|
|
|
953
945
|
|
|
954
946
|
def check_num_args(
|
|
@@ -992,15 +984,17 @@ def type_check_args(
|
|
|
992
984
|
comptime_args = iter(func_ty.comptime_args)
|
|
993
985
|
for inp, func_inp in zip(inputs, func_ty.inputs, strict=True):
|
|
994
986
|
a, s = ExprChecker(ctx).check(inp, func_inp.ty.substitute(subst), "argument")
|
|
987
|
+
subst |= s
|
|
995
988
|
if InputFlags.Inout in func_inp.flags and isinstance(a, PlaceNode):
|
|
996
989
|
a.place = check_place_assignable(
|
|
997
990
|
a.place, ctx, a, "able to borrow subscripted elements"
|
|
998
991
|
)
|
|
999
992
|
if InputFlags.Comptime in func_inp.flags:
|
|
1000
993
|
comptime_arg = next(comptime_args)
|
|
1001
|
-
|
|
994
|
+
const = comptime_arg.const.substitute(subst)
|
|
995
|
+
s = check_comptime_arg(a, const, func_inp.ty.substitute(subst), subst)
|
|
996
|
+
subst |= s
|
|
1002
997
|
new_args.append(a)
|
|
1003
|
-
subst |= s
|
|
1004
998
|
assert next(comptime_args, None) is None
|
|
1005
999
|
|
|
1006
1000
|
# If the argument check succeeded, this means that we must have found instantiations
|
|
@@ -1120,7 +1114,7 @@ def synthesize_call(
|
|
|
1120
1114
|
|
|
1121
1115
|
# Success implies that the substitution is closed
|
|
1122
1116
|
assert all(not t.unsolved_vars for t in subst.values())
|
|
1123
|
-
inst =
|
|
1117
|
+
inst = check_all_solved(subst, free_vars, func_ty, node)
|
|
1124
1118
|
|
|
1125
1119
|
# Finally, check that the instantiation respects the linearity requirements
|
|
1126
1120
|
check_inst(func_ty, inst, node)
|
|
@@ -1199,7 +1193,7 @@ def check_call(
|
|
|
1199
1193
|
|
|
1200
1194
|
# Success implies that the substitution is closed
|
|
1201
1195
|
assert all(not t.unsolved_vars for t in subst.values())
|
|
1202
|
-
inst =
|
|
1196
|
+
inst = check_all_solved(subst, free_vars, func_ty, node)
|
|
1203
1197
|
subst = {v: t for v, t in subst.items() if v in ty.unsolved_vars}
|
|
1204
1198
|
|
|
1205
1199
|
# Finally, check that the instantiation respects the linearity requirements
|
|
@@ -1208,12 +1202,37 @@ def check_call(
|
|
|
1208
1202
|
return inputs, subst, inst
|
|
1209
1203
|
|
|
1210
1204
|
|
|
1205
|
+
def check_all_solved(
    subst: Subst,
    free_vars: Sequence[ExistentialVar],
    func_ty: FunctionType,
    loc: AstNode,
) -> Inst:
    """Ensures that `subst` provides a solution for every free variable.

    With Python 3.12 generic syntax, a function may declare parameters that
    never occur in its signature. Unifying the call arguments leaves such
    parameters unsolved, so they have to be caught by this separate check.

    Returns the instantiation: one argument per entry of `free_vars`, in the
    same order. Raises a `GuppyTypeInferenceError` naming the first variable
    in `free_vars` that has no solution.
    """
    missing = next((var for var in free_vars if var not in subst), None)
    if missing is not None:
        err = ParameterInferenceError(loc, missing.display_name)
        err.add_sub_diagnostic(ParameterInferenceError.SignatureHint(None, func_ty))
        raise GuppyTypeInferenceError(err)
    return [subst[var].to_arg() for var in free_vars]
|
|
1226
|
+
|
|
1227
|
+
|
|
1211
1228
|
def check_inst(func_ty: FunctionType, inst: Inst, node: AstNode) -> None:
|
|
1212
1229
|
"""Checks if an instantiation is valid.
|
|
1213
1230
|
|
|
1214
1231
|
Makes sure that the linearity requirements are satisfied.
|
|
1215
1232
|
"""
|
|
1216
1233
|
for param, arg in zip(func_ty.params, inst, strict=True):
|
|
1234
|
+
param = param.instantiate_bounds(inst)
|
|
1235
|
+
|
|
1217
1236
|
# Give a more informative error message for linearity issues
|
|
1218
1237
|
if isinstance(param, TypeParam) and isinstance(arg, TypeArg):
|
|
1219
1238
|
if param.must_be_copyable and not arg.ty.copyable:
|
|
@@ -1232,7 +1251,14 @@ def instantiate_poly(node: ast.expr, ty: FunctionType, inst: Inst) -> ast.expr:
|
|
|
1232
1251
|
"""Instantiates quantified type arguments in a function."""
|
|
1233
1252
|
assert len(ty.params) == len(inst)
|
|
1234
1253
|
if len(inst) > 0:
|
|
1235
|
-
|
|
1254
|
+
# Partial applications need to be instantiated on the inside
|
|
1255
|
+
if isinstance(node, PartialApply):
|
|
1256
|
+
full_ty = get_type(node.func)
|
|
1257
|
+
assert isinstance(full_ty, FunctionType)
|
|
1258
|
+
assert full_ty.params == ty.params
|
|
1259
|
+
node.func = instantiate_poly(node.func, full_ty, inst)
|
|
1260
|
+
else:
|
|
1261
|
+
node = with_loc(node, TypeApply(value=with_type(ty, node), inst=inst))
|
|
1236
1262
|
return with_type(ty.instantiate(inst), node)
|
|
1237
1263
|
return with_type(ty, node)
|
|
1238
1264
|
|