guppylang-internals 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +21 -0
- guppylang_internals/cfg/bb.py +20 -0
- guppylang_internals/cfg/builder.py +118 -5
- guppylang_internals/cfg/cfg.py +3 -0
- guppylang_internals/checker/cfg_checker.py +6 -0
- guppylang_internals/checker/core.py +5 -2
- guppylang_internals/checker/errors/generic.py +32 -1
- guppylang_internals/checker/errors/type_errors.py +14 -0
- guppylang_internals/checker/errors/wasm.py +7 -4
- guppylang_internals/checker/expr_checker.py +58 -17
- guppylang_internals/checker/func_checker.py +18 -14
- guppylang_internals/checker/linearity_checker.py +67 -10
- guppylang_internals/checker/modifier_checker.py +120 -0
- guppylang_internals/checker/stmt_checker.py +48 -1
- guppylang_internals/checker/unitary_checker.py +132 -0
- guppylang_internals/compiler/cfg_compiler.py +7 -6
- guppylang_internals/compiler/core.py +93 -56
- guppylang_internals/compiler/expr_compiler.py +72 -168
- guppylang_internals/compiler/modifier_compiler.py +176 -0
- guppylang_internals/compiler/stmt_compiler.py +15 -8
- guppylang_internals/decorator.py +86 -7
- guppylang_internals/definition/custom.py +39 -1
- guppylang_internals/definition/declaration.py +9 -6
- guppylang_internals/definition/function.py +12 -2
- guppylang_internals/definition/parameter.py +8 -3
- guppylang_internals/definition/pytket_circuits.py +14 -41
- guppylang_internals/definition/struct.py +13 -7
- guppylang_internals/definition/ty.py +3 -3
- guppylang_internals/definition/wasm.py +42 -10
- guppylang_internals/engine.py +9 -3
- guppylang_internals/experimental.py +5 -0
- guppylang_internals/nodes.py +147 -24
- guppylang_internals/std/_internal/checker.py +13 -108
- guppylang_internals/std/_internal/compiler/array.py +95 -283
- guppylang_internals/std/_internal/compiler/list.py +1 -1
- guppylang_internals/std/_internal/compiler/platform.py +153 -0
- guppylang_internals/std/_internal/compiler/prelude.py +12 -4
- guppylang_internals/std/_internal/compiler/tket_exts.py +8 -2
- guppylang_internals/std/_internal/debug.py +18 -9
- guppylang_internals/std/_internal/util.py +1 -1
- guppylang_internals/tracing/object.py +10 -0
- guppylang_internals/tracing/unpacking.py +19 -20
- guppylang_internals/tys/arg.py +18 -3
- guppylang_internals/tys/builtin.py +2 -5
- guppylang_internals/tys/const.py +33 -4
- guppylang_internals/tys/errors.py +23 -1
- guppylang_internals/tys/param.py +31 -16
- guppylang_internals/tys/parsing.py +11 -24
- guppylang_internals/tys/printing.py +2 -8
- guppylang_internals/tys/qubit.py +62 -0
- guppylang_internals/tys/subst.py +8 -26
- guppylang_internals/tys/ty.py +91 -85
- guppylang_internals/wasm_util.py +129 -0
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +6 -5
- guppylang_internals-0.26.0.dist-info/RECORD +104 -0
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
- guppylang_internals-0.24.0.dist-info/RECORD +0 -98
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
guppylang_internals/__init__.py
CHANGED
guppylang_internals/ast_util.py
CHANGED
|
@@ -106,6 +106,14 @@ def return_nodes_in_ast(node: Any) -> list[ast.Return]:
|
|
|
106
106
|
return cast(list[ast.Return], found)
|
|
107
107
|
|
|
108
108
|
|
|
109
|
+
def loop_in_ast(node: Any) -> list[ast.For | ast.While]:
|
|
110
|
+
"""Returns all `For` and `While` nodes occurring in an AST."""
|
|
111
|
+
found = find_nodes(
|
|
112
|
+
lambda n: isinstance(n, ast.For | ast.While), node, {ast.FunctionDef}
|
|
113
|
+
)
|
|
114
|
+
return cast(list[ast.For | ast.While], found)
|
|
115
|
+
|
|
116
|
+
|
|
109
117
|
def breaks_in_loop(node: Any) -> list[ast.Break]:
|
|
110
118
|
"""Returns all `Break` nodes occurring in a loop.
|
|
111
119
|
|
|
@@ -117,6 +125,19 @@ def breaks_in_loop(node: Any) -> list[ast.Break]:
|
|
|
117
125
|
return cast(list[ast.Break], found)
|
|
118
126
|
|
|
119
127
|
|
|
128
|
+
def loop_controls_in_loop(node: Any) -> list[ast.Break | ast.Continue]:
|
|
129
|
+
"""Returns all `Break` and `Continue` nodes occurring in a loop.
|
|
130
|
+
|
|
131
|
+
Note that breaks in nested loops are excluded.
|
|
132
|
+
"""
|
|
133
|
+
found = find_nodes(
|
|
134
|
+
lambda n: isinstance(n, ast.Break | ast.Continue),
|
|
135
|
+
node,
|
|
136
|
+
{ast.For, ast.While, ast.FunctionDef},
|
|
137
|
+
)
|
|
138
|
+
return cast(list[ast.Break | ast.Continue], found)
|
|
139
|
+
|
|
140
|
+
|
|
120
141
|
class ContextAdjuster(ast.NodeTransformer):
|
|
121
142
|
"""Updates the `ast.Context` indicating if expressions occur on the LHS or RHS."""
|
|
122
143
|
|
guppylang_internals/cfg/bb.py
CHANGED
|
@@ -13,6 +13,7 @@ from guppylang_internals.nodes import (
|
|
|
13
13
|
DesugaredGenerator,
|
|
14
14
|
DesugaredGeneratorExpr,
|
|
15
15
|
DesugaredListComp,
|
|
16
|
+
ModifiedBlock,
|
|
16
17
|
NestedFunctionDef,
|
|
17
18
|
)
|
|
18
19
|
|
|
@@ -44,6 +45,7 @@ BBStatement = (
|
|
|
44
45
|
| ast.Expr
|
|
45
46
|
| ast.Return
|
|
46
47
|
| NestedFunctionDef
|
|
48
|
+
| ModifiedBlock
|
|
47
49
|
)
|
|
48
50
|
|
|
49
51
|
|
|
@@ -219,3 +221,21 @@ class VariableVisitor(ast.NodeVisitor):
|
|
|
219
221
|
|
|
220
222
|
# The name of the function is now assigned
|
|
221
223
|
self.stats.assigned[node.name] = node
|
|
224
|
+
|
|
225
|
+
def visit_ModifiedBlock(self, node: ModifiedBlock) -> None:
|
|
226
|
+
for item in node.control:
|
|
227
|
+
self.visit(item)
|
|
228
|
+
for item in node.power:
|
|
229
|
+
self.visit(item)
|
|
230
|
+
|
|
231
|
+
# Similarly to nested functions
|
|
232
|
+
from guppylang_internals.cfg.analysis import LivenessAnalysis
|
|
233
|
+
|
|
234
|
+
stats = {bb: bb.compute_variable_stats() for bb in node.cfg.bbs}
|
|
235
|
+
live = LivenessAnalysis(stats).run(node.cfg.bbs)
|
|
236
|
+
assigned_before_in_bb = self.stats.assigned.keys()
|
|
237
|
+
self.stats.used |= {
|
|
238
|
+
x: using_bb.vars.used[x]
|
|
239
|
+
for x, using_bb in live[node.cfg.entry_bb].items()
|
|
240
|
+
if x not in assigned_before_in_bb
|
|
241
|
+
}
|
|
guppylang_internals/cfg/builder.py
CHANGED
|
@@ -9,6 +9,8 @@ from guppylang_internals.ast_util import (
|
|
|
9
9
|
AstVisitor,
|
|
10
10
|
ContextAdjuster,
|
|
11
11
|
find_nodes,
|
|
12
|
+
loop_controls_in_loop,
|
|
13
|
+
return_nodes_in_ast,
|
|
12
14
|
set_location_from,
|
|
13
15
|
template_replace,
|
|
14
16
|
with_loc,
|
|
@@ -16,20 +18,35 @@ from guppylang_internals.ast_util import (
|
|
|
16
18
|
from guppylang_internals.cfg.bb import BB, BBStatement
|
|
17
19
|
from guppylang_internals.cfg.cfg import CFG
|
|
18
20
|
from guppylang_internals.checker.core import Globals
|
|
19
|
-
from guppylang_internals.checker.errors.generic import
|
|
21
|
+
from guppylang_internals.checker.errors.generic import (
|
|
22
|
+
ExpectedError,
|
|
23
|
+
UnexpectedInWithBlockError,
|
|
24
|
+
UnknownModifierError,
|
|
25
|
+
UnsupportedError,
|
|
26
|
+
)
|
|
27
|
+
from guppylang_internals.checker.errors.type_errors import WrongNumberOfArgsError
|
|
20
28
|
from guppylang_internals.diagnostic import Error
|
|
21
29
|
from guppylang_internals.error import GuppyError, InternalGuppyError
|
|
22
|
-
from guppylang_internals.experimental import
|
|
30
|
+
from guppylang_internals.experimental import (
|
|
31
|
+
check_lists_enabled,
|
|
32
|
+
check_modifiers_enabled,
|
|
33
|
+
)
|
|
23
34
|
from guppylang_internals.nodes import (
|
|
24
35
|
ComptimeExpr,
|
|
36
|
+
Control,
|
|
37
|
+
Dagger,
|
|
25
38
|
DesugaredGenerator,
|
|
26
39
|
DesugaredGeneratorExpr,
|
|
27
40
|
DesugaredListComp,
|
|
28
41
|
IterNext,
|
|
29
42
|
MakeIter,
|
|
43
|
+
ModifiedBlock,
|
|
44
|
+
Modifier,
|
|
30
45
|
NestedFunctionDef,
|
|
46
|
+
Power,
|
|
31
47
|
)
|
|
32
|
-
from guppylang_internals.
|
|
48
|
+
from guppylang_internals.span import Span, to_span
|
|
49
|
+
from guppylang_internals.tys.ty import NoneType, UnitaryFlags
|
|
33
50
|
|
|
34
51
|
# In order to build expressions, need an endless stream of unique temporary variables
|
|
35
52
|
# to store intermediate results
|
|
@@ -61,7 +78,13 @@ class CFGBuilder(AstVisitor[BB | None]):
|
|
|
61
78
|
cfg: CFG
|
|
62
79
|
globals: Globals
|
|
63
80
|
|
|
64
|
-
def build(
|
|
81
|
+
def build(
|
|
82
|
+
self,
|
|
83
|
+
nodes: list[ast.stmt],
|
|
84
|
+
returns_none: bool,
|
|
85
|
+
globals: Globals,
|
|
86
|
+
unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
|
|
87
|
+
) -> CFG:
|
|
65
88
|
"""Builds a CFG from a list of ast nodes.
|
|
66
89
|
|
|
67
90
|
We also require the expected number of return ports for the whole CFG. This is
|
|
@@ -69,6 +92,7 @@ class CFGBuilder(AstVisitor[BB | None]):
|
|
|
69
92
|
variables.
|
|
70
93
|
"""
|
|
71
94
|
self.cfg = CFG()
|
|
95
|
+
self.cfg.unitary_flags = unitary_flags
|
|
72
96
|
self.globals = globals
|
|
73
97
|
|
|
74
98
|
final_bb = self.visit_stmts(
|
|
@@ -135,7 +159,10 @@ class CFGBuilder(AstVisitor[BB | None]):
|
|
|
135
159
|
Builds the expression and mutates `node.value` to point to the built expression.
|
|
136
160
|
Returns the BB in which the expression is available and adds the node to it.
|
|
137
161
|
"""
|
|
138
|
-
if
|
|
162
|
+
if (
|
|
163
|
+
not isinstance(node, NestedFunctionDef | ModifiedBlock)
|
|
164
|
+
and node.value is not None
|
|
165
|
+
):
|
|
139
166
|
node.value, bb = ExprBuilder.build(node.value, self.cfg, bb)
|
|
140
167
|
bb.statements.append(node)
|
|
141
168
|
return bb
|
|
@@ -253,6 +280,7 @@ class CFGBuilder(AstVisitor[BB | None]):
|
|
|
253
280
|
|
|
254
281
|
func_ty = check_signature(node, self.globals)
|
|
255
282
|
returns_none = isinstance(func_ty.output, NoneType)
|
|
283
|
+
# No UnitaryFlags are assigned to nested functions
|
|
256
284
|
cfg = CFGBuilder().build(node.body, returns_none, self.globals)
|
|
257
285
|
|
|
258
286
|
new_node = NestedFunctionDef(
|
|
@@ -265,6 +293,91 @@ class CFGBuilder(AstVisitor[BB | None]):
|
|
|
265
293
|
bb.statements.append(new_node)
|
|
266
294
|
return bb
|
|
267
295
|
|
|
296
|
+
def visit_With(self, node: ast.With, bb: BB, jumps: Jumps) -> BB | None:
|
|
297
|
+
check_modifiers_enabled(node)
|
|
298
|
+
self._validate_modified_block(node)
|
|
299
|
+
|
|
300
|
+
cfg = CFGBuilder().build(node.body, True, self.globals)
|
|
301
|
+
new_node = ModifiedBlock(
|
|
302
|
+
cfg=cfg,
|
|
303
|
+
**dict(ast.iter_fields(node)),
|
|
304
|
+
)
|
|
305
|
+
|
|
306
|
+
for item in node.items:
|
|
307
|
+
item.context_expr, bb = ExprBuilder.build(item.context_expr, self.cfg, bb)
|
|
308
|
+
modifier = self._handle_withitem(item)
|
|
309
|
+
new_node.push_modifier(modifier)
|
|
310
|
+
|
|
311
|
+
# FIXME: Currently, the unitary flags is not set correctly if there are nested
|
|
312
|
+
# `with` blocks. This is because the outer block's unitary flags are not
|
|
313
|
+
# propagated to the outer block. The following line should calculate the sum
|
|
314
|
+
# of the unitary flags of the outer block and modifiers applied in this
|
|
315
|
+
# `with` block.
|
|
316
|
+
cfg.unitary_flags = new_node.flags()
|
|
317
|
+
|
|
318
|
+
set_location_from(new_node, node)
|
|
319
|
+
bb.statements.append(new_node)
|
|
320
|
+
return bb
|
|
321
|
+
|
|
322
|
+
def _handle_withitem(self, node: ast.withitem) -> Modifier:
|
|
323
|
+
# Check that `as` notation is not used
|
|
324
|
+
if node.optional_vars is not None:
|
|
325
|
+
span = Span(
|
|
326
|
+
to_span(node.context_expr).start, to_span(node.optional_vars).end
|
|
327
|
+
)
|
|
328
|
+
raise GuppyError(UnsupportedError(span, "`as` expression", singular=True))
|
|
329
|
+
|
|
330
|
+
e = node.context_expr
|
|
331
|
+
modifier: Modifier
|
|
332
|
+
match e:
|
|
333
|
+
case ast.Name(id="dagger"):
|
|
334
|
+
modifier = Dagger(e)
|
|
335
|
+
case ast.Call(func=ast.Name(id="dagger")):
|
|
336
|
+
if len(e.args) != 0:
|
|
337
|
+
span = Span(to_span(e.args[0]).start, to_span(e.args[-1]).end)
|
|
338
|
+
raise GuppyError(WrongNumberOfArgsError(span, 0, len(e.args)))
|
|
339
|
+
modifier = Dagger(e)
|
|
340
|
+
case ast.Call(func=ast.Name(id="control")):
|
|
341
|
+
if len(e.args) == 0:
|
|
342
|
+
span = Span(to_span(e.func).end, to_span(e).end)
|
|
343
|
+
raise GuppyError(WrongNumberOfArgsError(span, 1, len(e.args)))
|
|
344
|
+
modifier = Control(e, e.args)
|
|
345
|
+
case ast.Call(func=ast.Name(id="power")):
|
|
346
|
+
if len(e.args) == 0:
|
|
347
|
+
span = Span(to_span(e.func).end, to_span(e).end)
|
|
348
|
+
raise GuppyError(WrongNumberOfArgsError(span, 1, len(e.args)))
|
|
349
|
+
elif len(e.args) != 1:
|
|
350
|
+
span = Span(to_span(e.args[1]).start, to_span(e.args[-1]).end)
|
|
351
|
+
raise GuppyError(WrongNumberOfArgsError(span, 1, len(e.args)))
|
|
352
|
+
modifier = Power(e, e.args[0])
|
|
353
|
+
case _:
|
|
354
|
+
raise GuppyError(UnknownModifierError(e))
|
|
355
|
+
return modifier
|
|
356
|
+
|
|
357
|
+
def _validate_modified_block(self, node: ast.With) -> None:
|
|
358
|
+
# Check if the body contains a return statement.
|
|
359
|
+
return_in_body = return_nodes_in_ast(node)
|
|
360
|
+
if len(return_in_body) != 0:
|
|
361
|
+
err = UnexpectedInWithBlockError(return_in_body[0], "return", "Return")
|
|
362
|
+
span = Span(
|
|
363
|
+
to_span(node.items[0].context_expr).start,
|
|
364
|
+
to_span(node.items[-1].context_expr).end,
|
|
365
|
+
)
|
|
366
|
+
err.add_sub_diagnostic(UnexpectedInWithBlockError.Modifier(span))
|
|
367
|
+
raise GuppyError(err)
|
|
368
|
+
|
|
369
|
+
loop_controls_in_body = loop_controls_in_loop(node)
|
|
370
|
+
if len(loop_controls_in_body) != 0:
|
|
371
|
+
lc = loop_controls_in_body[0]
|
|
372
|
+
kind = lc.__class__.__name__
|
|
373
|
+
err = UnexpectedInWithBlockError(lc, "loop control", kind)
|
|
374
|
+
span = Span(
|
|
375
|
+
to_span(node.items[0].context_expr).start,
|
|
376
|
+
to_span(node.items[-1].context_expr).end,
|
|
377
|
+
)
|
|
378
|
+
err.add_sub_diagnostic(UnexpectedInWithBlockError.Modifier(span))
|
|
379
|
+
raise GuppyError(err)
|
|
380
|
+
|
|
268
381
|
def generic_visit(self, node: ast.AST, bb: BB, jumps: Jumps) -> BB | None:
|
|
269
382
|
# When adding support for new statements, we have to remember to use the
|
|
270
383
|
# ExprBuilder to transform all included expressions!
|
guppylang_internals/cfg/cfg.py
CHANGED
|
@@ -12,6 +12,7 @@ from guppylang_internals.cfg.analysis import (
|
|
|
12
12
|
)
|
|
13
13
|
from guppylang_internals.cfg.bb import BB, BBStatement, VariableStats
|
|
14
14
|
from guppylang_internals.nodes import InoutReturnSentinel
|
|
15
|
+
from guppylang_internals.tys.ty import UnitaryFlags
|
|
15
16
|
|
|
16
17
|
T = TypeVar("T", bound=BB)
|
|
17
18
|
|
|
@@ -29,6 +30,7 @@ class BaseCFG(Generic[T]):
|
|
|
29
30
|
|
|
30
31
|
#: Set of variables defined in this CFG
|
|
31
32
|
assigned_somewhere: set[str]
|
|
33
|
+
unitary_flags: UnitaryFlags
|
|
32
34
|
|
|
33
35
|
def __init__(
|
|
34
36
|
self, bbs: list[T], entry_bb: T | None = None, exit_bb: T | None = None
|
|
@@ -42,6 +44,7 @@ class BaseCFG(Generic[T]):
|
|
|
42
44
|
self.ass_before = {}
|
|
43
45
|
self.maybe_ass_before = {}
|
|
44
46
|
self.assigned_somewhere = set()
|
|
47
|
+
self.unitary_flags = UnitaryFlags.NoFlags
|
|
45
48
|
|
|
46
49
|
def ancestors(self, *bbs: T) -> Iterator[T]:
|
|
47
50
|
"""Returns an iterator over all ancestors of the given BBs in BFS order."""
|
|
guppylang_internals/checker/cfg_checker.py
CHANGED
|
@@ -149,11 +149,17 @@ def check_cfg(
|
|
|
149
149
|
checked_cfg.maybe_ass_before = {
|
|
150
150
|
compiled[bb]: cfg.maybe_ass_before[bb] for bb in required_bbs
|
|
151
151
|
}
|
|
152
|
+
checked_cfg.unitary_flags = cfg.unitary_flags
|
|
152
153
|
|
|
153
154
|
# Finally, run the linearity check
|
|
154
155
|
from guppylang_internals.checker.linearity_checker import check_cfg_linearity
|
|
155
156
|
|
|
156
157
|
linearity_checked_cfg = check_cfg_linearity(checked_cfg, func_name, globals)
|
|
158
|
+
|
|
159
|
+
from guppylang_internals.checker.unitary_checker import check_cfg_unitary
|
|
160
|
+
|
|
161
|
+
check_cfg_unitary(linearity_checked_cfg, cfg.unitary_flags)
|
|
162
|
+
|
|
157
163
|
return linearity_checked_cfg
|
|
158
164
|
|
|
159
165
|
|
|
guppylang_internals/checker/core.py
CHANGED
|
@@ -47,7 +47,6 @@ from guppylang_internals.tys.ty import (
|
|
|
47
47
|
NumericType,
|
|
48
48
|
OpaqueType,
|
|
49
49
|
StructType,
|
|
50
|
-
SumType,
|
|
51
50
|
TupleType,
|
|
52
51
|
Type,
|
|
53
52
|
)
|
|
@@ -117,6 +116,10 @@ class Variable:
|
|
|
117
116
|
"""Returns a new `Variable` instance with an updated definition location."""
|
|
118
117
|
return replace(self, defined_at=node)
|
|
119
118
|
|
|
119
|
+
def add_flags(self, flags: InputFlags) -> "Variable":
|
|
120
|
+
"""Returns a new `Variable` instance with updated flags."""
|
|
121
|
+
return replace(self, flags=self.flags | flags)
|
|
122
|
+
|
|
120
123
|
|
|
121
124
|
@dataclass(frozen=True, kw_only=True)
|
|
122
125
|
class ComptimeVariable(Variable):
|
|
@@ -356,7 +359,7 @@ class Globals:
|
|
|
356
359
|
match ty:
|
|
357
360
|
case TypeDef() as type_defn:
|
|
358
361
|
pass
|
|
359
|
-
case BoundTypeVar() | ExistentialTypeVar()
|
|
362
|
+
case BoundTypeVar() | ExistentialTypeVar():
|
|
360
363
|
return None
|
|
361
364
|
case NumericType(kind):
|
|
362
365
|
match kind:
|
|
guppylang_internals/checker/errors/generic.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
2
|
from typing import ClassVar
|
|
3
3
|
|
|
4
|
-
from guppylang_internals.diagnostic import Error
|
|
4
|
+
from guppylang_internals.diagnostic import Error, Note
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
@dataclass(frozen=True)
|
|
@@ -43,3 +43,34 @@ class ExpectedError(Error):
|
|
|
43
43
|
@property
|
|
44
44
|
def extra(self) -> str:
|
|
45
45
|
return f", got {self.got}" if self.got else ""
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass(frozen=True)
|
|
49
|
+
class UnknownModifierError(Error):
|
|
50
|
+
title: ClassVar[str] = "Unknown modifier"
|
|
51
|
+
span_label: ClassVar[str] = (
|
|
52
|
+
"Expected one of {{dagger, control(...), or power(...)}}"
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@dataclass(frozen=True)
|
|
57
|
+
class UnexpectedInWithBlockError(Error):
|
|
58
|
+
title: ClassVar[str] = "Unexpected {kind}"
|
|
59
|
+
span_label: ClassVar[str] = "{things} found in a `With` block"
|
|
60
|
+
kind: str
|
|
61
|
+
things: str
|
|
62
|
+
|
|
63
|
+
@dataclass(frozen=True)
|
|
64
|
+
class Modifier(Note):
|
|
65
|
+
span_label: ClassVar[str] = "modifier is used here"
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass(frozen=True)
|
|
69
|
+
class InvalidUnderDagger(Error):
|
|
70
|
+
title: ClassVar[str] = "Invalid expression in dagger"
|
|
71
|
+
span_label: ClassVar[str] = "{things} found in a dagger context"
|
|
72
|
+
things: str
|
|
73
|
+
|
|
74
|
+
@dataclass(frozen=True)
|
|
75
|
+
class Dagger(Note):
|
|
76
|
+
span_label: ClassVar[str] = "dagger modifier is used here"
|
|
guppylang_internals/checker/errors/type_errors.py
CHANGED
|
@@ -95,6 +95,20 @@ class TypeInferenceError(Error):
|
|
|
95
95
|
unsolved_ty: Type
|
|
96
96
|
|
|
97
97
|
|
|
98
|
+
@dataclass(frozen=True)
|
|
99
|
+
class ParameterInferenceError(Error):
|
|
100
|
+
title: ClassVar[str] = "Cannot infer generic parameter"
|
|
101
|
+
span_label: ClassVar[str] = (
|
|
102
|
+
"Cannot infer generic parameter `{param}` of this function"
|
|
103
|
+
)
|
|
104
|
+
param: str
|
|
105
|
+
|
|
106
|
+
@dataclass(frozen=True)
|
|
107
|
+
class SignatureHint(Note):
|
|
108
|
+
message: ClassVar[str] = "Function signature is `{sig}`"
|
|
109
|
+
sig: FunctionType
|
|
110
|
+
|
|
111
|
+
|
|
98
112
|
@dataclass(frozen=True)
|
|
99
113
|
class IllegalConstant(Error):
|
|
100
114
|
title: ClassVar[str] = "Unsupported constant"
|
|
guppylang_internals/checker/errors/wasm.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
2
|
from typing import ClassVar
|
|
3
3
|
|
|
4
|
-
from guppylang_internals.diagnostic import Error
|
|
4
|
+
from guppylang_internals.diagnostic import Error, Note
|
|
5
5
|
from guppylang_internals.tys.ty import Type
|
|
6
6
|
|
|
7
7
|
|
|
@@ -13,10 +13,13 @@ class WasmError(Error):
|
|
|
13
13
|
@dataclass(frozen=True)
|
|
14
14
|
class FirstArgNotModule(WasmError):
|
|
15
15
|
span_label: ClassVar[str] = (
|
|
16
|
-
"First argument to WASM function should be a
|
|
17
|
-
" Found `{ty}` instead"
|
|
16
|
+
"First argument to WASM function should be a WASM module."
|
|
18
17
|
)
|
|
19
|
-
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True)
|
|
20
|
+
class GotOtherType(Note):
|
|
21
|
+
span_label: ClassVar[str] = "Found `{ty}` instead."
|
|
22
|
+
ty: Type
|
|
20
23
|
|
|
21
24
|
|
|
22
25
|
@dataclass(frozen=True)
|
|
guppylang_internals/checker/expr_checker.py
CHANGED
|
@@ -23,6 +23,7 @@ can be used to infer a type for an expression.
|
|
|
23
23
|
import ast
|
|
24
24
|
import sys
|
|
25
25
|
import traceback
|
|
26
|
+
from collections.abc import Sequence
|
|
26
27
|
from contextlib import suppress
|
|
27
28
|
from dataclasses import replace
|
|
28
29
|
from types import ModuleType
|
|
@@ -42,6 +43,7 @@ from guppylang_internals.ast_util import (
|
|
|
42
43
|
)
|
|
43
44
|
from guppylang_internals.cfg.builder import is_tmp_var, tmp_vars
|
|
44
45
|
from guppylang_internals.checker.core import (
|
|
46
|
+
ComptimeVariable,
|
|
45
47
|
Context,
|
|
46
48
|
DummyEvalDict,
|
|
47
49
|
FieldAccess,
|
|
@@ -74,6 +76,7 @@ from guppylang_internals.checker.errors.type_errors import (
|
|
|
74
76
|
ModuleMemberNotFoundError,
|
|
75
77
|
NonLinearInstantiateError,
|
|
76
78
|
NotCallableError,
|
|
79
|
+
ParameterInferenceError,
|
|
77
80
|
TupleIndexOutOfBoundsError,
|
|
78
81
|
TypeApplyNotGenericError,
|
|
79
82
|
TypeInferenceError,
|
|
@@ -130,7 +133,7 @@ from guppylang_internals.tys.builtin import (
|
|
|
130
133
|
string_type,
|
|
131
134
|
)
|
|
132
135
|
from guppylang_internals.tys.const import Const, ConstValue
|
|
133
|
-
from guppylang_internals.tys.param import ConstParam, TypeParam
|
|
136
|
+
from guppylang_internals.tys.param import ConstParam, TypeParam, check_all_args
|
|
134
137
|
from guppylang_internals.tys.parsing import arg_from_ast
|
|
135
138
|
from guppylang_internals.tys.subst import Inst, Subst
|
|
136
139
|
from guppylang_internals.tys.ty import (
|
|
@@ -149,6 +152,7 @@ from guppylang_internals.tys.ty import (
|
|
|
149
152
|
parse_function_tensor,
|
|
150
153
|
unify,
|
|
151
154
|
)
|
|
155
|
+
from guppylang_internals.tys.var import ExistentialVar
|
|
152
156
|
|
|
153
157
|
if TYPE_CHECKING:
|
|
154
158
|
from guppylang_internals.diagnostic import SubDiagnostic
|
|
@@ -462,8 +466,15 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
|
|
|
462
466
|
# A `value.attr` attribute access. Unfortunately, the `attr` is just a string,
|
|
463
467
|
# not an AST node, so we have to compute its span by hand. This is fine since
|
|
464
468
|
# linebreaks are not allowed in the identifier following the `.`
|
|
469
|
+
# The only exception are attributes accesses that are generated during
|
|
470
|
+
# desugaring (for example for iterators in `for` loops). Since those just
|
|
471
|
+
# inherit the span of the sugared code, we could have line breaks there.
|
|
472
|
+
# See https://github.com/quantinuum/guppylang/issues/1301
|
|
465
473
|
span = to_span(node)
|
|
466
|
-
|
|
474
|
+
if span.start.line == span.end.line:
|
|
475
|
+
attr_span = Span(span.end.shift_left(len(node.attr)), span.end)
|
|
476
|
+
else:
|
|
477
|
+
attr_span = span
|
|
467
478
|
if module := self._is_python_module(node.value):
|
|
468
479
|
if node.attr in module.__dict__:
|
|
469
480
|
val = module.__dict__[node.attr]
|
|
@@ -493,12 +504,7 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
|
|
|
493
504
|
)
|
|
494
505
|
# Make a closure by partially applying the `self` argument
|
|
495
506
|
# TODO: Try to infer some type args based on `self`
|
|
496
|
-
result_ty = FunctionType(
|
|
497
|
-
func.ty.inputs[1:],
|
|
498
|
-
func.ty.output,
|
|
499
|
-
func.ty.input_names[1:] if func.ty.input_names else None,
|
|
500
|
-
func.ty.params,
|
|
501
|
-
)
|
|
507
|
+
result_ty = FunctionType(func.ty.inputs[1:], func.ty.output, func.ty.params)
|
|
502
508
|
return with_loc(node, PartialApply(func=name, args=[node.value])), result_ty
|
|
503
509
|
raise GuppyTypeError(AttributeNotFoundError(attr_span, ty, node.attr))
|
|
504
510
|
|
|
@@ -928,10 +934,9 @@ def check_type_apply(ty: FunctionType, node: ast.Subscript, ctx: Context) -> Ins
|
|
|
928
934
|
err.add_sub_diagnostic(WrongNumberOfArgsError.SignatureHint(None, ty))
|
|
929
935
|
raise GuppyError(err)
|
|
930
936
|
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
]
|
|
937
|
+
inst = [arg_from_ast(node, ctx.parsing_ctx) for node in arg_exprs]
|
|
938
|
+
check_all_args(ty.params, inst, "", node, arg_exprs)
|
|
939
|
+
return inst
|
|
935
940
|
|
|
936
941
|
|
|
937
942
|
def check_num_args(
|
|
@@ -975,15 +980,17 @@ def type_check_args(
|
|
|
975
980
|
comptime_args = iter(func_ty.comptime_args)
|
|
976
981
|
for inp, func_inp in zip(inputs, func_ty.inputs, strict=True):
|
|
977
982
|
a, s = ExprChecker(ctx).check(inp, func_inp.ty.substitute(subst), "argument")
|
|
983
|
+
subst |= s
|
|
978
984
|
if InputFlags.Inout in func_inp.flags and isinstance(a, PlaceNode):
|
|
979
985
|
a.place = check_place_assignable(
|
|
980
986
|
a.place, ctx, a, "able to borrow subscripted elements"
|
|
981
987
|
)
|
|
982
988
|
if InputFlags.Comptime in func_inp.flags:
|
|
983
989
|
comptime_arg = next(comptime_args)
|
|
984
|
-
|
|
990
|
+
const = comptime_arg.const.substitute(subst)
|
|
991
|
+
s = check_comptime_arg(a, const, func_inp.ty.substitute(subst), subst)
|
|
992
|
+
subst |= s
|
|
985
993
|
new_args.append(a)
|
|
986
|
-
subst |= s
|
|
987
994
|
assert next(comptime_args, None) is None
|
|
988
995
|
|
|
989
996
|
# If the argument check succeeded, this means that we must have found instantiations
|
|
@@ -1024,7 +1031,14 @@ def check_place_assignable(
|
|
|
1024
1031
|
exp_sig = FunctionType(
|
|
1025
1032
|
[
|
|
1026
1033
|
FuncInput(parent.ty, InputFlags.Inout),
|
|
1027
|
-
FuncInput(
|
|
1034
|
+
FuncInput(
|
|
1035
|
+
# Due to potential coercions that were applied during the
|
|
1036
|
+
# `__getitem__` call (e.g. coercing a nat index to int), we're
|
|
1037
|
+
# not allowed to rely on `item.ty` here.
|
|
1038
|
+
# See https://github.com/CQCL/guppylang/issues/1356
|
|
1039
|
+
ExistentialTypeVar.fresh("T", True, True),
|
|
1040
|
+
InputFlags.NoFlags,
|
|
1041
|
+
),
|
|
1028
1042
|
FuncInput(ty, InputFlags.Owned),
|
|
1029
1043
|
],
|
|
1030
1044
|
NoneType(),
|
|
@@ -1061,6 +1075,8 @@ def check_comptime_arg(
|
|
|
1061
1075
|
match arg:
|
|
1062
1076
|
case ast.Constant(value=v):
|
|
1063
1077
|
const = ConstValue(ty, v)
|
|
1078
|
+
case PlaceNode(place=ComptimeVariable(ty=ty, static_value=v)):
|
|
1079
|
+
const = ConstValue(ty, v)
|
|
1064
1080
|
case GenericParamValue(param=const_param):
|
|
1065
1081
|
const = const_param.to_bound().const
|
|
1066
1082
|
case arg:
|
|
@@ -1103,7 +1119,7 @@ def synthesize_call(
|
|
|
1103
1119
|
|
|
1104
1120
|
# Success implies that the substitution is closed
|
|
1105
1121
|
assert all(not t.unsolved_vars for t in subst.values())
|
|
1106
|
-
inst =
|
|
1122
|
+
inst = check_all_solved(subst, free_vars, func_ty, node)
|
|
1107
1123
|
|
|
1108
1124
|
# Finally, check that the instantiation respects the linearity requirements
|
|
1109
1125
|
check_inst(func_ty, inst, node)
|
|
@@ -1182,7 +1198,7 @@ def check_call(
|
|
|
1182
1198
|
|
|
1183
1199
|
# Success implies that the substitution is closed
|
|
1184
1200
|
assert all(not t.unsolved_vars for t in subst.values())
|
|
1185
|
-
inst =
|
|
1201
|
+
inst = check_all_solved(subst, free_vars, func_ty, node)
|
|
1186
1202
|
subst = {v: t for v, t in subst.items() if v in ty.unsolved_vars}
|
|
1187
1203
|
|
|
1188
1204
|
# Finally, check that the instantiation respects the linearity requirements
|
|
@@ -1191,12 +1207,37 @@ def check_call(
|
|
|
1191
1207
|
return inputs, subst, inst
|
|
1192
1208
|
|
|
1193
1209
|
|
|
1210
|
+
def check_all_solved(
|
|
1211
|
+
subst: Subst,
|
|
1212
|
+
free_vars: Sequence[ExistentialVar],
|
|
1213
|
+
func_ty: FunctionType,
|
|
1214
|
+
loc: AstNode,
|
|
1215
|
+
) -> Inst:
|
|
1216
|
+
"""Checks that a substitution solves all parameters of a function.
|
|
1217
|
+
|
|
1218
|
+
Using 3.12 generic syntax, users can declare parameters that don't occur in the
|
|
1219
|
+
signature. Those will remain unsolved, even after unifying all function arguments,
|
|
1220
|
+
so we have to perform this extra check.
|
|
1221
|
+
|
|
1222
|
+
Returns an instantiation of all free variables, or emits a user error if some are
|
|
1223
|
+
not solved.
|
|
1224
|
+
"""
|
|
1225
|
+
for v in free_vars:
|
|
1226
|
+
if v not in subst:
|
|
1227
|
+
err = ParameterInferenceError(loc, v.display_name)
|
|
1228
|
+
err.add_sub_diagnostic(ParameterInferenceError.SignatureHint(None, func_ty))
|
|
1229
|
+
raise GuppyTypeInferenceError(err)
|
|
1230
|
+
return [subst[v].to_arg() for v in free_vars]
|
|
1231
|
+
|
|
1232
|
+
|
|
1194
1233
|
def check_inst(func_ty: FunctionType, inst: Inst, node: AstNode) -> None:
|
|
1195
1234
|
"""Checks if an instantiation is valid.
|
|
1196
1235
|
|
|
1197
1236
|
Makes sure that the linearity requirements are satisfied.
|
|
1198
1237
|
"""
|
|
1199
1238
|
for param, arg in zip(func_ty.params, inst, strict=True):
|
|
1239
|
+
param = param.instantiate_bounds(inst)
|
|
1240
|
+
|
|
1200
1241
|
# Give a more informative error message for linearity issues
|
|
1201
1242
|
if isinstance(param, TypeParam) and isinstance(arg, TypeArg):
|
|
1202
1243
|
if param.must_be_copyable and not arg.ty.copyable:
|