guppylang-internals 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +37 -18
- guppylang_internals/cfg/analysis.py +6 -6
- guppylang_internals/cfg/builder.py +44 -12
- guppylang_internals/cfg/cfg.py +1 -1
- guppylang_internals/checker/core.py +1 -1
- guppylang_internals/checker/errors/comptime_errors.py +0 -12
- guppylang_internals/checker/errors/linearity.py +6 -2
- guppylang_internals/checker/expr_checker.py +53 -28
- guppylang_internals/checker/func_checker.py +4 -3
- guppylang_internals/checker/stmt_checker.py +1 -1
- guppylang_internals/compiler/cfg_compiler.py +1 -1
- guppylang_internals/compiler/core.py +17 -4
- guppylang_internals/compiler/expr_compiler.py +36 -14
- guppylang_internals/compiler/modifier_compiler.py +5 -2
- guppylang_internals/decorator.py +5 -3
- guppylang_internals/definition/common.py +1 -0
- guppylang_internals/definition/custom.py +2 -2
- guppylang_internals/definition/declaration.py +3 -3
- guppylang_internals/definition/function.py +28 -8
- guppylang_internals/definition/metadata.py +87 -0
- guppylang_internals/definition/overloaded.py +11 -2
- guppylang_internals/definition/pytket_circuits.py +50 -67
- guppylang_internals/definition/value.py +1 -1
- guppylang_internals/definition/wasm.py +3 -3
- guppylang_internals/diagnostic.py +89 -16
- guppylang_internals/engine.py +84 -40
- guppylang_internals/error.py +1 -1
- guppylang_internals/nodes.py +301 -3
- guppylang_internals/span.py +7 -3
- guppylang_internals/std/_internal/checker.py +104 -2
- guppylang_internals/std/_internal/compiler/array.py +36 -1
- guppylang_internals/std/_internal/compiler/either.py +14 -2
- guppylang_internals/std/_internal/compiler/tket_bool.py +1 -6
- guppylang_internals/std/_internal/compiler/tket_exts.py +1 -1
- guppylang_internals/std/_internal/debug.py +5 -3
- guppylang_internals/tracing/builtins_mock.py +2 -2
- guppylang_internals/tracing/object.py +6 -2
- guppylang_internals/tys/parsing.py +4 -1
- guppylang_internals/tys/qubit.py +6 -4
- guppylang_internals/tys/subst.py +2 -2
- guppylang_internals/tys/ty.py +2 -2
- guppylang_internals/wasm_util.py +2 -3
- {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/METADATA +5 -4
- {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/RECORD +47 -46
- {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/WHEEL +0 -0
- {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/licenses/LICENCE +0 -0
guppylang_internals/__init__.py
CHANGED
guppylang_internals/ast_util.py
CHANGED
|
@@ -97,13 +97,13 @@ def find_nodes(
|
|
|
97
97
|
def name_nodes_in_ast(node: Any) -> list[ast.Name]:
|
|
98
98
|
"""Returns all `Name` nodes occurring in an AST."""
|
|
99
99
|
found = find_nodes(lambda n: isinstance(n, ast.Name), node)
|
|
100
|
-
return cast(list[ast.Name], found)
|
|
100
|
+
return cast("list[ast.Name]", found)
|
|
101
101
|
|
|
102
102
|
|
|
103
103
|
def return_nodes_in_ast(node: Any) -> list[ast.Return]:
|
|
104
104
|
"""Returns all `Return` nodes occurring in an AST."""
|
|
105
105
|
found = find_nodes(lambda n: isinstance(n, ast.Return), node, {ast.FunctionDef})
|
|
106
|
-
return cast(list[ast.Return], found)
|
|
106
|
+
return cast("list[ast.Return]", found)
|
|
107
107
|
|
|
108
108
|
|
|
109
109
|
def loop_in_ast(node: Any) -> list[ast.For | ast.While]:
|
|
@@ -111,7 +111,7 @@ def loop_in_ast(node: Any) -> list[ast.For | ast.While]:
|
|
|
111
111
|
found = find_nodes(
|
|
112
112
|
lambda n: isinstance(n, ast.For | ast.While), node, {ast.FunctionDef}
|
|
113
113
|
)
|
|
114
|
-
return cast(list[ast.For | ast.While], found)
|
|
114
|
+
return cast("list[ast.For | ast.While]", found)
|
|
115
115
|
|
|
116
116
|
|
|
117
117
|
def breaks_in_loop(node: Any) -> list[ast.Break]:
|
|
@@ -122,7 +122,7 @@ def breaks_in_loop(node: Any) -> list[ast.Break]:
|
|
|
122
122
|
found = find_nodes(
|
|
123
123
|
lambda n: isinstance(n, ast.Break), node, {ast.For, ast.While, ast.FunctionDef}
|
|
124
124
|
)
|
|
125
|
-
return cast(list[ast.Break], found)
|
|
125
|
+
return cast("list[ast.Break]", found)
|
|
126
126
|
|
|
127
127
|
|
|
128
128
|
def loop_controls_in_loop(node: Any) -> list[ast.Break | ast.Continue]:
|
|
@@ -135,7 +135,7 @@ def loop_controls_in_loop(node: Any) -> list[ast.Break | ast.Continue]:
|
|
|
135
135
|
node,
|
|
136
136
|
{ast.For, ast.While, ast.FunctionDef},
|
|
137
137
|
)
|
|
138
|
-
return cast(list[ast.Break | ast.Continue], found)
|
|
138
|
+
return cast("list[ast.Break | ast.Continue]", found)
|
|
139
139
|
|
|
140
140
|
|
|
141
141
|
class ContextAdjuster(ast.NodeTransformer):
|
|
@@ -147,7 +147,7 @@ class ContextAdjuster(ast.NodeTransformer):
|
|
|
147
147
|
self.ctx = ctx
|
|
148
148
|
|
|
149
149
|
def visit(self, node: ast.AST) -> ast.AST:
|
|
150
|
-
return cast(ast.AST, super().visit(node))
|
|
150
|
+
return cast("ast.AST", super().visit(node))
|
|
151
151
|
|
|
152
152
|
def visit_Name(self, node: ast.Name) -> ast.Name:
|
|
153
153
|
return with_loc(node, ast.Name(id=node.id, ctx=self.ctx))
|
|
@@ -156,29 +156,48 @@ class ContextAdjuster(ast.NodeTransformer):
|
|
|
156
156
|
self,
|
|
157
157
|
node: ast.Starred,
|
|
158
158
|
) -> ast.Starred:
|
|
159
|
-
return with_loc(
|
|
159
|
+
return with_loc(
|
|
160
|
+
node,
|
|
161
|
+
ast.Starred(value=self.visit(node.value), ctx=self.ctx), # type: ignore[arg-type]
|
|
162
|
+
)
|
|
160
163
|
|
|
161
164
|
def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple:
|
|
162
165
|
return with_loc(
|
|
163
|
-
node,
|
|
166
|
+
node,
|
|
167
|
+
ast.Tuple(
|
|
168
|
+
elts=[self.visit(elt) for elt in node.elts], # type: ignore[misc]
|
|
169
|
+
ctx=self.ctx,
|
|
170
|
+
),
|
|
164
171
|
)
|
|
165
172
|
|
|
166
173
|
def visit_List(self, node: ast.List) -> ast.List:
|
|
167
174
|
return with_loc(
|
|
168
|
-
node,
|
|
175
|
+
node,
|
|
176
|
+
ast.List(
|
|
177
|
+
elts=[self.visit(elt) for elt in node.elts], # type: ignore[misc]
|
|
178
|
+
ctx=self.ctx,
|
|
179
|
+
),
|
|
169
180
|
)
|
|
170
181
|
|
|
171
182
|
def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript:
|
|
172
183
|
# Don't adjust the slice!
|
|
173
184
|
return with_loc(
|
|
174
185
|
node,
|
|
175
|
-
ast.Subscript(
|
|
186
|
+
ast.Subscript(
|
|
187
|
+
value=self.visit(node.value), # type: ignore[arg-type]
|
|
188
|
+
slice=node.slice,
|
|
189
|
+
ctx=self.ctx,
|
|
190
|
+
),
|
|
176
191
|
)
|
|
177
192
|
|
|
178
193
|
def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute:
|
|
179
194
|
return with_loc(
|
|
180
195
|
node,
|
|
181
|
-
ast.Attribute(
|
|
196
|
+
ast.Attribute(
|
|
197
|
+
value=self.visit(node.value), # type: ignore[arg-type]
|
|
198
|
+
attr=node.attr,
|
|
199
|
+
ctx=self.ctx,
|
|
200
|
+
),
|
|
182
201
|
)
|
|
183
202
|
|
|
184
203
|
|
|
@@ -240,15 +259,15 @@ def template_replace(
|
|
|
240
259
|
|
|
241
260
|
def line_col(node: ast.AST) -> tuple[int, int]:
|
|
242
261
|
"""Returns the line and column of an ast node."""
|
|
243
|
-
return node.lineno, node.col_offset
|
|
262
|
+
return node.lineno, node.col_offset # type: ignore[attr-defined]
|
|
244
263
|
|
|
245
264
|
|
|
246
265
|
def set_location_from(node: ast.AST, loc: ast.AST) -> None:
|
|
247
266
|
"""Copy source location from one AST node to the other."""
|
|
248
|
-
node.lineno = loc.lineno
|
|
249
|
-
node.col_offset = loc.col_offset
|
|
250
|
-
node.end_lineno = loc.end_lineno
|
|
251
|
-
node.end_col_offset = loc.end_col_offset
|
|
267
|
+
node.lineno = loc.lineno # type: ignore[attr-defined]
|
|
268
|
+
node.col_offset = loc.col_offset # type: ignore[attr-defined]
|
|
269
|
+
node.end_lineno = loc.end_lineno # type: ignore[attr-defined]
|
|
270
|
+
node.end_col_offset = loc.end_col_offset # type: ignore[attr-defined]
|
|
252
271
|
|
|
253
272
|
source, file, line_offset = get_source(loc), get_file(loc), get_line_offset(loc)
|
|
254
273
|
assert source is not None
|
|
@@ -341,11 +360,11 @@ def with_type(ty: "Type", node: A) -> A:
|
|
|
341
360
|
|
|
342
361
|
def get_type_opt(node: AstNode) -> Optional["Type"]:
|
|
343
362
|
"""Tries to retrieve a type annotation from an AST node."""
|
|
344
|
-
from guppylang_internals.tys.ty import
|
|
363
|
+
from guppylang_internals.tys.ty import TypeBase
|
|
345
364
|
|
|
346
365
|
try:
|
|
347
366
|
ty = node.type # type: ignore[union-attr]
|
|
348
|
-
return cast(Type, ty) if isinstance(ty, TypeBase) else None
|
|
367
|
+
return cast("Type", ty) if isinstance(ty, TypeBase) else None
|
|
349
368
|
except AttributeError:
|
|
350
369
|
return None
|
|
351
370
|
|
|
@@ -11,7 +11,7 @@ T = TypeVar("T")
|
|
|
11
11
|
Result = dict[BB, T]
|
|
12
12
|
|
|
13
13
|
|
|
14
|
-
class Analysis(Generic[T]
|
|
14
|
+
class Analysis(ABC, Generic[T]):
|
|
15
15
|
"""Abstract base class for a program analysis pass over the lattice `T`"""
|
|
16
16
|
|
|
17
17
|
def eq(self, t1: T, t2: T, /) -> bool:
|
|
@@ -39,7 +39,7 @@ class Analysis(Generic[T], ABC):
|
|
|
39
39
|
"""
|
|
40
40
|
|
|
41
41
|
|
|
42
|
-
class ForwardAnalysis(
|
|
42
|
+
class ForwardAnalysis(Analysis[T], ABC, Generic[T]):
|
|
43
43
|
"""Abstract base class for a program analysis pass running in forward direction."""
|
|
44
44
|
|
|
45
45
|
@abstractmethod
|
|
@@ -71,7 +71,7 @@ class ForwardAnalysis(Generic[T], Analysis[T], ABC):
|
|
|
71
71
|
return vals_before
|
|
72
72
|
|
|
73
73
|
|
|
74
|
-
class BackwardAnalysis(
|
|
74
|
+
class BackwardAnalysis(Analysis[T], ABC, Generic[T]):
|
|
75
75
|
"""Abstract base class for a program analysis pass running in backward direction."""
|
|
76
76
|
|
|
77
77
|
@abstractmethod
|
|
@@ -105,7 +105,7 @@ class BackwardAnalysis(Generic[T], Analysis[T], ABC):
|
|
|
105
105
|
LivenessDomain = dict[VId, BB]
|
|
106
106
|
|
|
107
107
|
|
|
108
|
-
class LivenessAnalysis(
|
|
108
|
+
class LivenessAnalysis(BackwardAnalysis[LivenessDomain[VId]], Generic[VId]):
|
|
109
109
|
"""Live variable analysis pass.
|
|
110
110
|
|
|
111
111
|
Computes the variables that are live before the execution of each BB. The analysis
|
|
@@ -143,7 +143,7 @@ class LivenessAnalysis(Generic[VId], BackwardAnalysis[LivenessDomain[VId]]):
|
|
|
143
143
|
|
|
144
144
|
def apply_bb(self, live_after: LivenessDomain[VId], bb: BB) -> LivenessDomain[VId]:
|
|
145
145
|
stats = self.stats[bb]
|
|
146
|
-
return
|
|
146
|
+
return dict.fromkeys(stats.used, bb) | {
|
|
147
147
|
x: b for x, b in live_after.items() if x not in stats.assigned
|
|
148
148
|
}
|
|
149
149
|
|
|
@@ -159,7 +159,7 @@ MaybeAssignmentDomain = set[VId]
|
|
|
159
159
|
AssignmentDomain = tuple[DefAssignmentDomain[VId], MaybeAssignmentDomain[VId]]
|
|
160
160
|
|
|
161
161
|
|
|
162
|
-
class AssignmentAnalysis(
|
|
162
|
+
class AssignmentAnalysis(ForwardAnalysis[AssignmentDomain[VId]], Generic[VId]):
|
|
163
163
|
"""Assigned variable analysis pass.
|
|
164
164
|
|
|
165
165
|
Computes the set of variables (i.e. `V`s) that are definitely assigned at the start
|
|
@@ -154,27 +154,52 @@ class CFGBuilder(AstVisitor[BB | None]):
|
|
|
154
154
|
return bb_opt
|
|
155
155
|
|
|
156
156
|
def _build_node_value(self, node: BBStatement, bb: BB) -> BB:
|
|
157
|
-
"""Utility method for building a
|
|
157
|
+
"""Utility method for building a nodes `value` expression, if available.
|
|
158
158
|
|
|
159
159
|
Builds the expression and mutates `node.value` to point to the built expression.
|
|
160
|
-
Returns the BB in which the expression is available
|
|
160
|
+
Returns the BB in which the expression is available.
|
|
161
161
|
"""
|
|
162
162
|
if (
|
|
163
163
|
not isinstance(node, NestedFunctionDef | ModifiedBlock)
|
|
164
164
|
and node.value is not None
|
|
165
165
|
):
|
|
166
166
|
node.value, bb = ExprBuilder.build(node.value, self.cfg, bb)
|
|
167
|
-
bb
|
|
167
|
+
return bb
|
|
168
|
+
|
|
169
|
+
def _build_node_targets(self, node: BBStatement, bb: BB) -> BB:
|
|
170
|
+
"""Utility method for building a nodes `target` or `targets` expressions,
|
|
171
|
+
depending on the node type.
|
|
172
|
+
|
|
173
|
+
Builds the expressions and mutates the elements of `node.targets` to point to
|
|
174
|
+
the built expressions. Returns the BB in which the expressions are available.
|
|
175
|
+
"""
|
|
176
|
+
if isinstance(node, ast.Assign):
|
|
177
|
+
for i, target in enumerate(node.targets):
|
|
178
|
+
node.targets[i], bb = ExprBuilder.build(target, self.cfg, bb)
|
|
179
|
+
elif isinstance(node, ast.AugAssign | ast.AnnAssign):
|
|
180
|
+
new_target, bb = ExprBuilder.build(node.target, self.cfg, bb)
|
|
181
|
+
if not isinstance(new_target, ast.Name | ast.Attribute | ast.Subscript):
|
|
182
|
+
raise InternalGuppyError("Unexpected type for built expression.")
|
|
183
|
+
node.target = new_target
|
|
168
184
|
return bb
|
|
169
185
|
|
|
170
186
|
def visit_Assign(self, node: ast.Assign, bb: BB, jumps: Jumps) -> BB | None:
|
|
171
|
-
|
|
187
|
+
bb = self._build_node_value(node, bb)
|
|
188
|
+
bb = self._build_node_targets(node, bb)
|
|
189
|
+
bb.statements.append(node)
|
|
190
|
+
return bb
|
|
172
191
|
|
|
173
192
|
def visit_AugAssign(self, node: ast.AugAssign, bb: BB, jumps: Jumps) -> BB | None:
|
|
174
|
-
|
|
193
|
+
bb = self._build_node_value(node, bb)
|
|
194
|
+
bb = self._build_node_targets(node, bb)
|
|
195
|
+
bb.statements.append(node)
|
|
196
|
+
return bb
|
|
175
197
|
|
|
176
198
|
def visit_AnnAssign(self, node: ast.AnnAssign, bb: BB, jumps: Jumps) -> BB | None:
|
|
177
|
-
|
|
199
|
+
bb = self._build_node_value(node, bb)
|
|
200
|
+
bb = self._build_node_targets(node, bb)
|
|
201
|
+
bb.statements.append(node)
|
|
202
|
+
return bb
|
|
178
203
|
|
|
179
204
|
def visit_Expr(self, node: ast.Expr, bb: BB, jumps: Jumps) -> BB | None:
|
|
180
205
|
# This is an expression statement where the value is discarded
|
|
@@ -262,6 +287,7 @@ class CFGBuilder(AstVisitor[BB | None]):
|
|
|
262
287
|
|
|
263
288
|
def visit_Return(self, node: ast.Return, bb: BB, jumps: Jumps) -> BB | None:
|
|
264
289
|
bb = self._build_node_value(node, bb)
|
|
290
|
+
bb.statements.append(node)
|
|
265
291
|
self.cfg.link(bb, jumps.return_bb)
|
|
266
292
|
return None
|
|
267
293
|
|
|
@@ -572,7 +598,7 @@ class BranchBuilder(AstVisitor[None]):
|
|
|
572
598
|
comparators[:-1], node.ops, comparators[1:], strict=True
|
|
573
599
|
)
|
|
574
600
|
]
|
|
575
|
-
conj = ast.BoolOp(op=ast.And(), values=values)
|
|
601
|
+
conj = ast.BoolOp(op=ast.And(), values=values) # type: ignore[arg-type]
|
|
576
602
|
set_location_from(conj, node)
|
|
577
603
|
self.visit_BoolOp(conj, bb, true_bb, false_bb)
|
|
578
604
|
else:
|
|
@@ -668,6 +694,9 @@ def is_comptime_expression(node: ast.AST) -> ComptimeExpr | None:
|
|
|
668
694
|
|
|
669
695
|
Otherwise, returns `None`.
|
|
670
696
|
"""
|
|
697
|
+
if isinstance(node, ComptimeExpr):
|
|
698
|
+
return node
|
|
699
|
+
|
|
671
700
|
if (
|
|
672
701
|
isinstance(node, ast.Call)
|
|
673
702
|
and isinstance(node.func, ast.Name)
|
|
@@ -679,8 +708,8 @@ def is_comptime_expression(node: ast.AST) -> ComptimeExpr | None:
|
|
|
679
708
|
case [arg]:
|
|
680
709
|
pass
|
|
681
710
|
case args:
|
|
682
|
-
arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load))
|
|
683
|
-
return with_loc(node, ComptimeExpr(
|
|
711
|
+
arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load)) # type: ignore[arg-type]
|
|
712
|
+
return with_loc(node, ComptimeExpr(arg))
|
|
684
713
|
return None
|
|
685
714
|
|
|
686
715
|
|
|
@@ -701,7 +730,7 @@ def is_illegal_in_list_comp(node: ast.AST) -> bool:
|
|
|
701
730
|
|
|
702
731
|
def make_var(name: str, loc: ast.AST | None = None) -> ast.Name:
|
|
703
732
|
"""Creates an `ast.Name` node."""
|
|
704
|
-
node = ast.Name(id=name, ctx=ast.Load)
|
|
733
|
+
node = ast.Name(id=name, ctx=ast.Load) # type: ignore[arg-type]
|
|
705
734
|
if loc is not None:
|
|
706
735
|
set_location_from(node, loc)
|
|
707
736
|
return node
|
|
@@ -715,5 +744,8 @@ def make_assign(lhs: list[ast.AST], value: ast.expr) -> ast.Assign:
|
|
|
715
744
|
if len(lhs) == 1:
|
|
716
745
|
target = lhs[0]
|
|
717
746
|
else:
|
|
718
|
-
target = with_loc(
|
|
719
|
-
|
|
747
|
+
target = with_loc(
|
|
748
|
+
value,
|
|
749
|
+
ast.Tuple(elts=lhs, ctx=ast.Store()), # type: ignore[arg-type]
|
|
750
|
+
)
|
|
751
|
+
return with_loc(value, ast.Assign(targets=[target], value=value)) # type: ignore[list-item]
|
guppylang_internals/cfg/cfg.py
CHANGED
|
@@ -118,7 +118,7 @@ class CFG(BaseCFG[BB]):
|
|
|
118
118
|
# initial value in the liveness analysis. This solves the edge case that
|
|
119
119
|
# borrowed variables should be considered live, even if the exit is actually
|
|
120
120
|
# unreachable (to avoid linearity violations later).
|
|
121
|
-
inout_live =
|
|
121
|
+
inout_live = dict.fromkeys(inout_vars, self.exit_bb)
|
|
122
122
|
self.live_before = LivenessAnalysis(
|
|
123
123
|
stats, initial=inout_live, include_unreachable=True
|
|
124
124
|
).run(self.bbs)
|
|
@@ -384,7 +384,7 @@ class Globals:
|
|
|
384
384
|
case _:
|
|
385
385
|
return assert_never(ty)
|
|
386
386
|
|
|
387
|
-
type_defn = cast(TypeDef, ENGINE.get_checked(type_defn.id))
|
|
387
|
+
type_defn = cast("TypeDef", ENGINE.get_checked(type_defn.id))
|
|
388
388
|
if type_defn.id in DEF_STORE.impls and name in DEF_STORE.impls[type_defn.id]:
|
|
389
389
|
def_id = DEF_STORE.impls[type_defn.id][name]
|
|
390
390
|
defn = ENGINE.get_parsed(def_id)
|
|
@@ -44,18 +44,6 @@ class ComptimeExprIncoherentListError(Error):
|
|
|
44
44
|
span_label: ClassVar[str] = "List contains elements with different types"
|
|
45
45
|
|
|
46
46
|
|
|
47
|
-
@dataclass(frozen=True)
|
|
48
|
-
class TketNotInstalled(Error):
|
|
49
|
-
title: ClassVar[str] = "Tket not installed"
|
|
50
|
-
span_label: ClassVar[str] = (
|
|
51
|
-
"Experimental pytket compatibility requires `tket` to be installed"
|
|
52
|
-
)
|
|
53
|
-
|
|
54
|
-
@dataclass(frozen=True)
|
|
55
|
-
class InstallInstruction(Help):
|
|
56
|
-
message: ClassVar[str] = "Install tket: `pip install tket`"
|
|
57
|
-
|
|
58
|
-
|
|
59
47
|
@dataclass(frozen=True)
|
|
60
48
|
class PytketSignatureMismatch(Error):
|
|
61
49
|
title: ClassVar[str] = "Signature mismatch"
|
|
@@ -33,7 +33,9 @@ class AlreadyUsedError(Error):
|
|
|
33
33
|
|
|
34
34
|
@dataclass(frozen=True)
|
|
35
35
|
class PrevUse(Note):
|
|
36
|
-
span_label: ClassVar[str] =
|
|
36
|
+
span_label: ClassVar[str] = (
|
|
37
|
+
"{place.describe} already {prev_kind.subjunctive} here"
|
|
38
|
+
)
|
|
37
39
|
prev_kind: UseKind
|
|
38
40
|
|
|
39
41
|
@dataclass(frozen=True)
|
|
@@ -55,7 +57,9 @@ class ComprAlreadyUsedError(Error):
|
|
|
55
57
|
|
|
56
58
|
@dataclass(frozen=True)
|
|
57
59
|
class PrevUse(Note):
|
|
58
|
-
span_label: ClassVar[str] =
|
|
60
|
+
span_label: ClassVar[str] = (
|
|
61
|
+
"{place.describe} already {prev_kind.subjunctive} here"
|
|
62
|
+
)
|
|
59
63
|
prev_kind: UseKind
|
|
60
64
|
|
|
61
65
|
|
|
@@ -21,6 +21,7 @@ can be used to infer a type for an expression.
|
|
|
21
21
|
"""
|
|
22
22
|
|
|
23
23
|
import ast
|
|
24
|
+
import copy
|
|
24
25
|
import sys
|
|
25
26
|
import traceback
|
|
26
27
|
from collections.abc import Sequence
|
|
@@ -85,6 +86,7 @@ from guppylang_internals.checker.errors.type_errors import (
|
|
|
85
86
|
WrongNumberOfArgsError,
|
|
86
87
|
)
|
|
87
88
|
from guppylang_internals.definition.common import Definition
|
|
89
|
+
from guppylang_internals.definition.parameter import ParamDef
|
|
88
90
|
from guppylang_internals.definition.ty import TypeDef
|
|
89
91
|
from guppylang_internals.definition.value import CallableDef, ValueDef
|
|
90
92
|
from guppylang_internals.error import (
|
|
@@ -234,7 +236,7 @@ class ExprChecker(AstVisitor[tuple[ast.expr, Subst]]):
|
|
|
234
236
|
if actual := get_type_opt(expr):
|
|
235
237
|
expr, subst, inst = check_type_against(actual, ty, expr, self.ctx, kind)
|
|
236
238
|
if inst:
|
|
237
|
-
expr = with_loc(expr, TypeApply(
|
|
239
|
+
expr = with_loc(expr, TypeApply(expr, inst))
|
|
238
240
|
return with_type(ty.substitute(subst), expr), subst
|
|
239
241
|
|
|
240
242
|
# When checking against a variable, we have to synthesize
|
|
@@ -370,7 +372,7 @@ class ExprChecker(AstVisitor[tuple[ast.expr, Subst]]):
|
|
|
370
372
|
|
|
371
373
|
# Apply instantiation of quantified type variables
|
|
372
374
|
if inst:
|
|
373
|
-
node = with_loc(node, TypeApply(
|
|
375
|
+
node = with_loc(node, TypeApply(node, inst))
|
|
374
376
|
|
|
375
377
|
return node, subst
|
|
376
378
|
|
|
@@ -407,23 +409,27 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
|
|
|
407
409
|
raise GuppyError(IllegalConstant(node, type(node.value)))
|
|
408
410
|
return node, ty
|
|
409
411
|
|
|
412
|
+
def _check_generic_param(self, name: str, node: ast.expr) -> tuple[ast.expr, Type]:
|
|
413
|
+
"""Helper method to check a generic parameter (ConstParam or TypeParam)."""
|
|
414
|
+
param = self.ctx.generic_params[name]
|
|
415
|
+
match param:
|
|
416
|
+
case ConstParam() as param:
|
|
417
|
+
ast_node = with_loc(node, GenericParamValue(id=name, param=param))
|
|
418
|
+
return ast_node, param.ty
|
|
419
|
+
case TypeParam() as param:
|
|
420
|
+
raise GuppyError(
|
|
421
|
+
ExpectedError(node, "a value", got=f"type `{param.name}`")
|
|
422
|
+
)
|
|
423
|
+
case _:
|
|
424
|
+
return assert_never(param)
|
|
425
|
+
|
|
410
426
|
def visit_Name(self, node: ast.Name) -> tuple[ast.expr, Type]:
|
|
411
427
|
x = node.id
|
|
412
428
|
if x in self.ctx.locals:
|
|
413
429
|
var = self.ctx.locals[x]
|
|
414
430
|
return with_loc(node, PlaceNode(place=var)), var.ty
|
|
415
431
|
elif x in self.ctx.generic_params:
|
|
416
|
-
|
|
417
|
-
match param:
|
|
418
|
-
case ConstParam() as param:
|
|
419
|
-
ast_node = with_loc(node, GenericParamValue(id=x, param=param))
|
|
420
|
-
return ast_node, param.ty
|
|
421
|
-
case TypeParam() as param:
|
|
422
|
-
raise GuppyError(
|
|
423
|
-
ExpectedError(node, "a value", got=f"type `{param.name}`")
|
|
424
|
-
)
|
|
425
|
-
case _:
|
|
426
|
-
return assert_never(param)
|
|
432
|
+
return self._check_generic_param(x, node)
|
|
427
433
|
elif x in self.ctx.globals:
|
|
428
434
|
match self.ctx.globals[x]:
|
|
429
435
|
case Definition() as defn:
|
|
@@ -454,6 +460,16 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
|
|
|
454
460
|
defn, "__new__"
|
|
455
461
|
):
|
|
456
462
|
return with_loc(node, GlobalName(id=name, def_id=constr.id)), constr.ty
|
|
463
|
+
# Handle parameter definitions (e.g., nat_var) that may be imported
|
|
464
|
+
case ParamDef():
|
|
465
|
+
# Check if this parameter is in our generic_params
|
|
466
|
+
# (e.g., used in type signature)
|
|
467
|
+
if name in self.ctx.generic_params:
|
|
468
|
+
return self._check_generic_param(name, node)
|
|
469
|
+
# If not in generic_params, it's being used outside its scope
|
|
470
|
+
raise GuppyError(
|
|
471
|
+
ExpectedError(node, "a value", got=f"{defn.description} `{name}`")
|
|
472
|
+
)
|
|
457
473
|
case defn:
|
|
458
474
|
raise GuppyError(
|
|
459
475
|
ExpectedError(node, "a value", got=f"{defn.description} `{name}`")
|
|
@@ -461,6 +477,7 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
|
|
|
461
477
|
|
|
462
478
|
def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]:
|
|
463
479
|
from guppylang.defs import GuppyDefinition
|
|
480
|
+
|
|
464
481
|
from guppylang_internals.engine import ENGINE
|
|
465
482
|
|
|
466
483
|
# A `value.attr` attribute access. Unfortunately, the `attr` is just a string,
|
|
@@ -993,11 +1010,13 @@ def type_check_args(
|
|
|
993
1010
|
new_args.append(a)
|
|
994
1011
|
assert next(comptime_args, None) is None
|
|
995
1012
|
|
|
996
|
-
#
|
|
997
|
-
#
|
|
998
|
-
|
|
999
|
-
set.issubset(inp.ty.unsolved_vars, subst.keys())
|
|
1000
|
-
|
|
1013
|
+
# Check whether we have found instantiations for all unification variables occurring
|
|
1014
|
+
# in the input types
|
|
1015
|
+
for inp in func_ty.inputs:
|
|
1016
|
+
if not set.issubset(inp.ty.unsolved_vars, subst.keys()):
|
|
1017
|
+
raise GuppyTypeInferenceError(
|
|
1018
|
+
TypeInferenceError(node, inp.ty.substitute(subst))
|
|
1019
|
+
)
|
|
1001
1020
|
|
|
1002
1021
|
# We also have to check that we found instantiations for all vars in the return type
|
|
1003
1022
|
if not set.issubset(func_ty.output.unsolved_vars, subst.keys()):
|
|
@@ -1162,19 +1181,25 @@ def check_call(
|
|
|
1162
1181
|
# However the bad case, e.g. `x: int = foo(foo(...foo(?)...))`, shouldn't be common
|
|
1163
1182
|
# in practice. Can we do better than that?
|
|
1164
1183
|
|
|
1165
|
-
#
|
|
1166
|
-
|
|
1184
|
+
# synthesize_call may modify args and node in place,
|
|
1185
|
+
# hence we deepcopy them before passing in the function
|
|
1186
|
+
node_copy = copy.deepcopy(node)
|
|
1187
|
+
inputs_copy = copy.deepcopy(inputs)
|
|
1188
|
+
|
|
1167
1189
|
try:
|
|
1168
1190
|
inputs, synth, inst = synthesize_call(func_ty, inputs, node, ctx)
|
|
1169
|
-
res = synth, inst
|
|
1170
|
-
except GuppyTypeInferenceError:
|
|
1171
|
-
pass
|
|
1172
|
-
if res is not None:
|
|
1173
|
-
synth, inst = res
|
|
1174
1191
|
subst = unify(ty, synth, {})
|
|
1175
1192
|
if subst is None:
|
|
1176
1193
|
raise GuppyTypeError(TypeMismatchError(node, ty, synth, kind))
|
|
1177
|
-
|
|
1194
|
+
else:
|
|
1195
|
+
return inputs, subst, inst
|
|
1196
|
+
except GuppyTypeInferenceError:
|
|
1197
|
+
pass
|
|
1198
|
+
|
|
1199
|
+
# Restore the state of these values from before they were potentially
|
|
1200
|
+
# modified by `synthesize_call`.
|
|
1201
|
+
inputs = inputs_copy
|
|
1202
|
+
node = node_copy
|
|
1178
1203
|
|
|
1179
1204
|
# If synthesis fails, we try again, this time also using information from the
|
|
1180
1205
|
# expected return type
|
|
@@ -1263,7 +1288,7 @@ def instantiate_poly(node: ast.expr, ty: FunctionType, inst: Inst) -> ast.expr:
|
|
|
1263
1288
|
assert full_ty.params == ty.params
|
|
1264
1289
|
node.func = instantiate_poly(node.func, full_ty, inst)
|
|
1265
1290
|
else:
|
|
1266
|
-
node = with_loc(node, TypeApply(
|
|
1291
|
+
node = with_loc(node, TypeApply(with_type(ty, node), inst))
|
|
1267
1292
|
return with_type(ty.instantiate(inst), node)
|
|
1268
1293
|
return with_type(ty, node)
|
|
1269
1294
|
|
|
@@ -1389,7 +1414,7 @@ def python_value_to_guppy_type(
|
|
|
1389
1414
|
]
|
|
1390
1415
|
if any(ty is None for ty in tys):
|
|
1391
1416
|
return None
|
|
1392
|
-
return TupleType(cast(list[Type], tys))
|
|
1417
|
+
return TupleType(cast("list[Type]", tys))
|
|
1393
1418
|
case list():
|
|
1394
1419
|
return _python_list_to_guppy_type(v, node, globals, type_hint)
|
|
1395
1420
|
case None:
|
|
@@ -141,7 +141,7 @@ def check_global_func_def(
|
|
|
141
141
|
check_invalid_under_dagger(func_def, ty.unitary_flags)
|
|
142
142
|
cfg = CFGBuilder().build(func_def.body, returns_none, globals, ty.unitary_flags)
|
|
143
143
|
inputs = [
|
|
144
|
-
Variable(cast(str, inp.name), inp.ty, loc, inp.flags, is_func_input=True)
|
|
144
|
+
Variable(cast("str", inp.name), inp.ty, loc, inp.flags, is_func_input=True)
|
|
145
145
|
for inp, loc in zip(ty.inputs, args, strict=True)
|
|
146
146
|
# Comptime inputs are turned into generic args, so are not included here
|
|
147
147
|
if InputFlags.Comptime not in inp.flags
|
|
@@ -199,7 +199,7 @@ def check_nested_func_def(
|
|
|
199
199
|
|
|
200
200
|
# Construct inputs for checking the body CFG
|
|
201
201
|
inputs = [v for v, _ in captured.values()] + [
|
|
202
|
-
Variable(cast(str, inp.name), inp.ty, arg, inp.flags, is_func_input=True)
|
|
202
|
+
Variable(cast("str", inp.name), inp.ty, arg, inp.flags, is_func_input=True)
|
|
203
203
|
for arg, inp in zip(func_def.args.args, func_ty.inputs, strict=True)
|
|
204
204
|
# Comptime inputs are turned into generic args, so are not included here
|
|
205
205
|
if InputFlags.Comptime not in inp.flags
|
|
@@ -214,6 +214,7 @@ def check_nested_func_def(
|
|
|
214
214
|
if not captured:
|
|
215
215
|
# If there are no captured vars, we treat the function like a global name
|
|
216
216
|
from guppylang.defs import GuppyDefinition
|
|
217
|
+
|
|
217
218
|
from guppylang_internals.definition.function import ParsedFunctionDef
|
|
218
219
|
|
|
219
220
|
func = ParsedFunctionDef(def_id, func_def.name, func_def, func_ty, None)
|
|
@@ -288,7 +289,7 @@ def check_signature(
|
|
|
288
289
|
# Figure out if this is a method
|
|
289
290
|
self_defn: TypeDef | None = None
|
|
290
291
|
if def_id is not None and def_id in DEF_STORE.impl_parents:
|
|
291
|
-
self_defn = cast(TypeDef, ENGINE.get_checked(DEF_STORE.impl_parents[def_id]))
|
|
292
|
+
self_defn = cast("TypeDef", ENGINE.get_checked(DEF_STORE.impl_parents[def_id]))
|
|
292
293
|
assert isinstance(self_defn, TypeDef)
|
|
293
294
|
|
|
294
295
|
inputs = []
|
|
@@ -478,7 +478,7 @@ def parse_unpack_pattern(lhs: ast.Tuple | ast.List) -> UnpackPattern:
|
|
|
478
478
|
# that there is at most one starred expression)
|
|
479
479
|
left = list(takewhile(lambda e: not isinstance(e, ast.Starred), lhs.elts))
|
|
480
480
|
starred = (
|
|
481
|
-
cast(ast.Starred, lhs.elts[len(left)]).value
|
|
481
|
+
cast("ast.Starred", lhs.elts[len(left)]).value
|
|
482
482
|
if len(left) < len(lhs.elts)
|
|
483
483
|
else None
|
|
484
484
|
)
|
|
@@ -111,7 +111,7 @@ def compile_bb(
|
|
|
111
111
|
pred_ty = builder.hugr.port_type(branch_port.out_port())
|
|
112
112
|
assert pred_ty == OpaqueBool
|
|
113
113
|
branch_port = dfg.builder.add_op(read_bool(), branch_port)
|
|
114
|
-
branch_port = cast(Wire, branch_port)
|
|
114
|
+
branch_port = cast("Wire", branch_port)
|
|
115
115
|
else:
|
|
116
116
|
# Even if we don't branch, we still have to add a `Sum(())` predicates
|
|
117
117
|
branch_port = dfg.builder.add_op(ops.Tag(0, ht.UnitSum(1)))
|
|
@@ -31,12 +31,14 @@ from guppylang_internals.definition.common import (
|
|
|
31
31
|
CompilableDef,
|
|
32
32
|
CompiledDef,
|
|
33
33
|
DefId,
|
|
34
|
+
Definition,
|
|
34
35
|
MonomorphizableDef,
|
|
36
|
+
RawDef,
|
|
35
37
|
)
|
|
36
38
|
from guppylang_internals.definition.ty import TypeDef
|
|
37
39
|
from guppylang_internals.definition.value import CompiledCallableDef
|
|
38
40
|
from guppylang_internals.diagnostic import Error
|
|
39
|
-
from guppylang_internals.engine import ENGINE
|
|
41
|
+
from guppylang_internals.engine import DEF_STORE, ENGINE
|
|
40
42
|
from guppylang_internals.error import GuppyError, InternalGuppyError
|
|
41
43
|
from guppylang_internals.std._internal.compiler.tket_exts import GUPPY_EXTENSION
|
|
42
44
|
from guppylang_internals.tys.arg import ConstArg, TypeArg
|
|
@@ -199,7 +201,7 @@ class CompilerContext(ToHugrContext):
|
|
|
199
201
|
params, type_args, self
|
|
200
202
|
)
|
|
201
203
|
compile_outer = lambda: monomorphizable.monomorphize( # noqa: E731 (assign-lambda)
|
|
202
|
-
self.module, mono_args, self
|
|
204
|
+
self.module, mono_args, self, get_parent_type(monomorphizable)
|
|
203
205
|
)
|
|
204
206
|
case CompilableDef() as compilable:
|
|
205
207
|
compile_outer = lambda: compilable.compile_outer(self.module, self) # noqa: E731
|
|
@@ -227,7 +229,9 @@ class CompilerContext(ToHugrContext):
|
|
|
227
229
|
raise GuppyError(err)
|
|
228
230
|
# Thus, the partial monomorphization for the entry point is always empty
|
|
229
231
|
entry_mono_args = tuple(None for _ in params)
|
|
230
|
-
entry_compiled = defn.monomorphize(
|
|
232
|
+
entry_compiled = defn.monomorphize(
|
|
233
|
+
self.module, entry_mono_args, self, get_parent_type(defn)
|
|
234
|
+
)
|
|
231
235
|
case CompilableDef() as defn:
|
|
232
236
|
entry_compiled = defn.compile_outer(self.module, self)
|
|
233
237
|
case CompiledDef() as defn:
|
|
@@ -371,7 +375,7 @@ class DFContainer:
|
|
|
371
375
|
ctx: CompilerContext,
|
|
372
376
|
locals: CompiledLocals | None = None,
|
|
373
377
|
) -> None:
|
|
374
|
-
generic_builder = cast(DfBase[ops.DfParentOp], builder)
|
|
378
|
+
generic_builder = cast("DfBase[ops.DfParentOp]", builder)
|
|
375
379
|
if locals is None:
|
|
376
380
|
locals = {}
|
|
377
381
|
self.builder = generic_builder
|
|
@@ -467,6 +471,15 @@ def is_return_var(x: str) -> bool:
|
|
|
467
471
|
return x.startswith("%ret")
|
|
468
472
|
|
|
469
473
|
|
|
474
|
+
def get_parent_type(defn: Definition) -> "RawDef | None":
|
|
475
|
+
"""Returns the RawDef registered as the parent of `child` in the DEF_STORE,
|
|
476
|
+
or None if it has no parent."""
|
|
477
|
+
if parent_ty_id := DEF_STORE.impl_parents.get(defn.id):
|
|
478
|
+
return DEF_STORE.raw_defs[parent_ty_id]
|
|
479
|
+
else:
|
|
480
|
+
return None
|
|
481
|
+
|
|
482
|
+
|
|
470
483
|
def require_monomorphization(params: Sequence[Parameter]) -> set[Parameter]:
|
|
471
484
|
"""Returns the subset of type parameters that must be monomorphized before compiling
|
|
472
485
|
to Hugr.
|