guppylang-internals 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +3 -0
- guppylang_internals/ast_util.py +350 -0
- guppylang_internals/cfg/__init__.py +0 -0
- guppylang_internals/cfg/analysis.py +230 -0
- guppylang_internals/cfg/bb.py +221 -0
- guppylang_internals/cfg/builder.py +606 -0
- guppylang_internals/cfg/cfg.py +117 -0
- guppylang_internals/checker/__init__.py +0 -0
- guppylang_internals/checker/cfg_checker.py +388 -0
- guppylang_internals/checker/core.py +550 -0
- guppylang_internals/checker/errors/__init__.py +0 -0
- guppylang_internals/checker/errors/comptime_errors.py +106 -0
- guppylang_internals/checker/errors/generic.py +45 -0
- guppylang_internals/checker/errors/linearity.py +300 -0
- guppylang_internals/checker/errors/type_errors.py +344 -0
- guppylang_internals/checker/errors/wasm.py +34 -0
- guppylang_internals/checker/expr_checker.py +1413 -0
- guppylang_internals/checker/func_checker.py +269 -0
- guppylang_internals/checker/linearity_checker.py +821 -0
- guppylang_internals/checker/stmt_checker.py +447 -0
- guppylang_internals/compiler/__init__.py +0 -0
- guppylang_internals/compiler/cfg_compiler.py +233 -0
- guppylang_internals/compiler/core.py +613 -0
- guppylang_internals/compiler/expr_compiler.py +989 -0
- guppylang_internals/compiler/func_compiler.py +97 -0
- guppylang_internals/compiler/hugr_extension.py +224 -0
- guppylang_internals/compiler/qtm_platform_extension.py +0 -0
- guppylang_internals/compiler/stmt_compiler.py +212 -0
- guppylang_internals/decorator.py +246 -0
- guppylang_internals/definition/__init__.py +0 -0
- guppylang_internals/definition/common.py +214 -0
- guppylang_internals/definition/const.py +74 -0
- guppylang_internals/definition/custom.py +492 -0
- guppylang_internals/definition/declaration.py +171 -0
- guppylang_internals/definition/extern.py +89 -0
- guppylang_internals/definition/function.py +302 -0
- guppylang_internals/definition/overloaded.py +150 -0
- guppylang_internals/definition/parameter.py +82 -0
- guppylang_internals/definition/pytket_circuits.py +405 -0
- guppylang_internals/definition/struct.py +392 -0
- guppylang_internals/definition/traced.py +151 -0
- guppylang_internals/definition/ty.py +51 -0
- guppylang_internals/definition/value.py +115 -0
- guppylang_internals/definition/wasm.py +61 -0
- guppylang_internals/diagnostic.py +523 -0
- guppylang_internals/dummy_decorator.py +76 -0
- guppylang_internals/engine.py +295 -0
- guppylang_internals/error.py +107 -0
- guppylang_internals/experimental.py +92 -0
- guppylang_internals/ipython_inspect.py +28 -0
- guppylang_internals/nodes.py +427 -0
- guppylang_internals/py.typed +0 -0
- guppylang_internals/span.py +150 -0
- guppylang_internals/std/__init__.py +0 -0
- guppylang_internals/std/_internal/__init__.py +0 -0
- guppylang_internals/std/_internal/checker.py +573 -0
- guppylang_internals/std/_internal/compiler/__init__.py +0 -0
- guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
- guppylang_internals/std/_internal/compiler/array.py +569 -0
- guppylang_internals/std/_internal/compiler/either.py +131 -0
- guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
- guppylang_internals/std/_internal/compiler/futures.py +30 -0
- guppylang_internals/std/_internal/compiler/list.py +348 -0
- guppylang_internals/std/_internal/compiler/mem.py +13 -0
- guppylang_internals/std/_internal/compiler/option.py +78 -0
- guppylang_internals/std/_internal/compiler/prelude.py +271 -0
- guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
- guppylang_internals/std/_internal/compiler/quantum.py +118 -0
- guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
- guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
- guppylang_internals/std/_internal/compiler/wasm.py +135 -0
- guppylang_internals/std/_internal/compiler.py +0 -0
- guppylang_internals/std/_internal/debug.py +95 -0
- guppylang_internals/std/_internal/util.py +271 -0
- guppylang_internals/tracing/__init__.py +0 -0
- guppylang_internals/tracing/builtins_mock.py +62 -0
- guppylang_internals/tracing/frozenlist.py +57 -0
- guppylang_internals/tracing/function.py +186 -0
- guppylang_internals/tracing/object.py +551 -0
- guppylang_internals/tracing/state.py +69 -0
- guppylang_internals/tracing/unpacking.py +194 -0
- guppylang_internals/tracing/util.py +86 -0
- guppylang_internals/tys/__init__.py +0 -0
- guppylang_internals/tys/arg.py +115 -0
- guppylang_internals/tys/builtin.py +382 -0
- guppylang_internals/tys/common.py +110 -0
- guppylang_internals/tys/const.py +114 -0
- guppylang_internals/tys/errors.py +178 -0
- guppylang_internals/tys/param.py +251 -0
- guppylang_internals/tys/parsing.py +425 -0
- guppylang_internals/tys/printing.py +174 -0
- guppylang_internals/tys/subst.py +112 -0
- guppylang_internals/tys/ty.py +876 -0
- guppylang_internals/tys/var.py +49 -0
- guppylang_internals-0.21.0.dist-info/METADATA +253 -0
- guppylang_internals-0.21.0.dist-info/RECORD +98 -0
- guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
- guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
|
@@ -0,0 +1,1413 @@
|
|
|
1
|
+
"""Type checking and synthesizing code for expressions.
|
|
2
|
+
|
|
3
|
+
Operates on expressions in a basic block after CFG construction. In particular, we
|
|
4
|
+
assume that expressions that involve control flow (i.e. short-circuiting and ternary
|
|
5
|
+
expressions) have been removed during CFG construction.
|
|
6
|
+
|
|
7
|
+
Furthermore, we assume that assignment expressions with the walrus operator := have
|
|
8
|
+
been turned into regular assignments and are no longer present. As a result, expressions
|
|
9
|
+
are assumed to be side effect free, in the sense that they do not modify the variables
|
|
10
|
+
available in the type checking context.
|
|
11
|
+
|
|
12
|
+
We may alter/desugar AST nodes during type checking. In particular, we turn `ast.Name`
|
|
13
|
+
nodes into either `LocalName` or `GlobalName` nodes and `ast.Call` nodes are turned into
|
|
14
|
+
`LocalCall` or `GlobalCall` nodes. Furthermore, all nodes in the resulting AST are
|
|
15
|
+
annotated with their type.
|
|
16
|
+
|
|
17
|
+
Expressions can be checked against a given type by the `ExprChecker`, raising a type
|
|
18
|
+
error if the expression doesn't have the expected type. Checking is used for annotated
|
|
19
|
+
assignments, return values, and function arguments. Alternatively, the `ExprSynthesizer`
|
|
20
|
+
can be used to infer a type for an expression.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import ast
|
|
24
|
+
import sys
|
|
25
|
+
import traceback
|
|
26
|
+
from contextlib import suppress
|
|
27
|
+
from dataclasses import replace
|
|
28
|
+
from types import ModuleType
|
|
29
|
+
from typing import TYPE_CHECKING, Any, NoReturn, cast
|
|
30
|
+
|
|
31
|
+
from typing_extensions import assert_never
|
|
32
|
+
|
|
33
|
+
from guppylang_internals.ast_util import (
|
|
34
|
+
AstNode,
|
|
35
|
+
AstVisitor,
|
|
36
|
+
breaks_in_loop,
|
|
37
|
+
get_type_opt,
|
|
38
|
+
return_nodes_in_ast,
|
|
39
|
+
with_loc,
|
|
40
|
+
with_type,
|
|
41
|
+
)
|
|
42
|
+
from guppylang_internals.cfg.builder import is_tmp_var, tmp_vars
|
|
43
|
+
from guppylang_internals.checker.core import (
|
|
44
|
+
Context,
|
|
45
|
+
DummyEvalDict,
|
|
46
|
+
FieldAccess,
|
|
47
|
+
Globals,
|
|
48
|
+
Locals,
|
|
49
|
+
Place,
|
|
50
|
+
PythonObject,
|
|
51
|
+
SetitemCall,
|
|
52
|
+
SubscriptAccess,
|
|
53
|
+
TupleAccess,
|
|
54
|
+
Variable,
|
|
55
|
+
)
|
|
56
|
+
from guppylang_internals.checker.errors.comptime_errors import (
|
|
57
|
+
ComptimeExprEvalError,
|
|
58
|
+
ComptimeExprIncoherentListError,
|
|
59
|
+
ComptimeExprNotCPythonError,
|
|
60
|
+
ComptimeExprNotStaticError,
|
|
61
|
+
ComptimeUnknownError,
|
|
62
|
+
IllegalComptimeExpressionError,
|
|
63
|
+
)
|
|
64
|
+
from guppylang_internals.checker.errors.generic import ExpectedError, UnsupportedError
|
|
65
|
+
from guppylang_internals.checker.errors.linearity import NonDroppableForBreakError
|
|
66
|
+
from guppylang_internals.checker.errors.type_errors import (
|
|
67
|
+
AttributeNotFoundError,
|
|
68
|
+
BadProtocolError,
|
|
69
|
+
BinaryOperatorNotDefinedError,
|
|
70
|
+
ConstMismatchError,
|
|
71
|
+
IllegalConstant,
|
|
72
|
+
IntOverflowError,
|
|
73
|
+
ModuleMemberNotFoundError,
|
|
74
|
+
NonLinearInstantiateError,
|
|
75
|
+
NotCallableError,
|
|
76
|
+
TupleIndexOutOfBoundsError,
|
|
77
|
+
TypeApplyNotGenericError,
|
|
78
|
+
TypeInferenceError,
|
|
79
|
+
TypeMismatchError,
|
|
80
|
+
UnaryOperatorNotDefinedError,
|
|
81
|
+
WrongNumberOfArgsError,
|
|
82
|
+
)
|
|
83
|
+
from guppylang_internals.definition.common import Definition
|
|
84
|
+
from guppylang_internals.definition.ty import TypeDef
|
|
85
|
+
from guppylang_internals.definition.value import CallableDef, ValueDef
|
|
86
|
+
from guppylang_internals.error import (
|
|
87
|
+
GuppyError,
|
|
88
|
+
GuppyTypeError,
|
|
89
|
+
GuppyTypeInferenceError,
|
|
90
|
+
InternalGuppyError,
|
|
91
|
+
)
|
|
92
|
+
from guppylang_internals.experimental import (
|
|
93
|
+
check_function_tensors_enabled,
|
|
94
|
+
check_lists_enabled,
|
|
95
|
+
)
|
|
96
|
+
from guppylang_internals.nodes import (
|
|
97
|
+
ComptimeExpr,
|
|
98
|
+
DesugaredGenerator,
|
|
99
|
+
DesugaredGeneratorExpr,
|
|
100
|
+
DesugaredListComp,
|
|
101
|
+
FieldAccessAndDrop,
|
|
102
|
+
GenericParamValue,
|
|
103
|
+
GlobalName,
|
|
104
|
+
IterEnd,
|
|
105
|
+
IterHasNext,
|
|
106
|
+
IterNext,
|
|
107
|
+
LocalCall,
|
|
108
|
+
MakeIter,
|
|
109
|
+
PartialApply,
|
|
110
|
+
PlaceNode,
|
|
111
|
+
SubscriptAccessAndDrop,
|
|
112
|
+
TensorCall,
|
|
113
|
+
TupleAccessAndDrop,
|
|
114
|
+
TypeApply,
|
|
115
|
+
)
|
|
116
|
+
from guppylang_internals.span import Span, to_span
|
|
117
|
+
from guppylang_internals.tys.arg import TypeArg
|
|
118
|
+
from guppylang_internals.tys.builtin import (
|
|
119
|
+
bool_type,
|
|
120
|
+
float_type,
|
|
121
|
+
frozenarray_type,
|
|
122
|
+
get_element_type,
|
|
123
|
+
int_type,
|
|
124
|
+
is_bool_type,
|
|
125
|
+
is_frozenarray_type,
|
|
126
|
+
is_list_type,
|
|
127
|
+
is_sized_iter_type,
|
|
128
|
+
list_type,
|
|
129
|
+
nat_type,
|
|
130
|
+
option_type,
|
|
131
|
+
string_type,
|
|
132
|
+
)
|
|
133
|
+
from guppylang_internals.tys.const import Const, ConstValue
|
|
134
|
+
from guppylang_internals.tys.param import ConstParam, TypeParam
|
|
135
|
+
from guppylang_internals.tys.parsing import arg_from_ast
|
|
136
|
+
from guppylang_internals.tys.subst import Inst, Subst
|
|
137
|
+
from guppylang_internals.tys.ty import (
|
|
138
|
+
ExistentialTypeVar,
|
|
139
|
+
FuncInput,
|
|
140
|
+
FunctionType,
|
|
141
|
+
InputFlags,
|
|
142
|
+
NoneType,
|
|
143
|
+
NumericType,
|
|
144
|
+
OpaqueType,
|
|
145
|
+
StructType,
|
|
146
|
+
TupleType,
|
|
147
|
+
Type,
|
|
148
|
+
TypeBase,
|
|
149
|
+
function_tensor_signature,
|
|
150
|
+
parse_function_tensor,
|
|
151
|
+
unify,
|
|
152
|
+
)
|
|
153
|
+
|
|
154
|
+
if TYPE_CHECKING:
|
|
155
|
+
from guppylang_internals.diagnostic import SubDiagnostic
|
|
156
|
+
|
|
157
|
+
# Mapping from unary AST op to dunder method and display name.
# `ast.Not` has no entry here: `ExprSynthesizer.visit_UnaryOp` handles it before
# consulting this table.
unary_table: dict[type[ast.unaryop], tuple[str, str]] = {
    ast.UAdd:   ("__pos__",    "+"),
    ast.USub:   ("__neg__",    "-"),
    ast.Invert: ("__invert__", "~"),
}  # fmt: skip
|
|
163
|
+
|
|
164
|
+
# AST node kinds that may be looked up in `binary_table`: binary operators and
# comparison operators.
AstOp = ast.operator | ast.cmpop

# Mapping from binary AST op to left dunder method, right dunder method and display name.
# The left dunder is looked up on the left operand; the right dunder is looked up on
# the right operand and called with the operands swapped (see `_synthesize_binary`).
# For the ordering comparisons the right entry is therefore the mirrored comparison
# (e.g. `<` pairs `__lt__` with `__gt__`).
binary_table: dict[type[AstOp], tuple[str, str, str]] = {
    ast.Add:      ("__add__",      "__radd__",      "+"),
    ast.Sub:      ("__sub__",      "__rsub__",      "-"),
    ast.Mult:     ("__mul__",      "__rmul__",      "*"),
    ast.Div:      ("__truediv__",  "__rtruediv__",  "/"),
    ast.FloorDiv: ("__floordiv__", "__rfloordiv__", "//"),
    ast.Mod:      ("__mod__",      "__rmod__",      "%"),
    ast.Pow:      ("__pow__",      "__rpow__",      "**"),
    ast.LShift:   ("__lshift__",   "__rlshift__",   "<<"),
    ast.RShift:   ("__rshift__",   "__rrshift__",   ">>"),
    ast.BitOr:    ("__or__",       "__ror__",       "|"),
    ast.BitXor:   ("__xor__",      "__rxor__",      "^"),
    ast.BitAnd:   ("__and__",      "__rand__",      "&"),
    ast.MatMult:  ("__matmul__",   "__rmatmul__",   "@"),
    ast.Eq:       ("__eq__",       "__eq__",        "=="),
    ast.NotEq:    ("__ne__",       "__ne__",        "!="),
    ast.Lt:       ("__lt__",       "__gt__",        "<"),
    ast.LtE:      ("__le__",       "__ge__",        "<="),
    ast.Gt:       ("__gt__",       "__lt__",        ">"),
    ast.GtE:      ("__ge__",       "__le__",        ">="),
}  # fmt: skip
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class ExprChecker(AstVisitor[tuple[ast.expr, Subst]]):
    """Checks an expression against a type and produces a new type-annotated AST.

    The type may contain free variables that the checker will try to solve. Note that
    the checker will fail, if some free variables cannot be inferred.

    Every `visit_*` method receives the node together with the expected type and
    returns the (possibly rewritten) node plus a substitution for the type variables
    that were resolved along the way.
    """

    ctx: Context

    # Name for the kind of term we are currently checking against (used in errors).
    # For example, "argument", "return value", or in general "expression".
    _kind: str

    def __init__(self, ctx: Context) -> None:
        """Creates a checker operating in the given checking context."""
        self.ctx = ctx
        self._kind = "expression"

    def _fail(
        self,
        expected: Type,
        actual: ast.expr | Type,
        loc: AstNode | None = None,
    ) -> NoReturn:
        """Raises a type error indicating that the type doesn't match.

        If `actual` is an expression instead of a type, its type is synthesized first
        (free variables are allowed) and the expression is used as the error location
        unless an explicit `loc` was given.

        Raises:
            InternalGuppyError: If no error location could be determined.
            GuppyTypeError: Otherwise, carrying a `TypeMismatchError`.
        """
        if not isinstance(actual, TypeBase):
            loc = loc or actual
            _, actual = self._synthesize(actual, allow_free_vars=True)
        if loc is None:
            raise InternalGuppyError("Failure location is required")
        raise GuppyTypeError(TypeMismatchError(loc, expected, actual))

    def check(
        self, expr: ast.expr, ty: Type, kind: str = "expression"
    ) -> tuple[ast.expr, Subst]:
        """Checks an expression against a type.

        The type may have free type variables which will try to be resolved. Returns
        a new desugared expression with type annotations and a substitution with the
        resolved type variables.

        `kind` names the sort of term being checked (e.g. "argument") and is only used
        for error reporting.
        """
        # If we already have a type for the expression, we just have to match it against
        # the target
        if actual := get_type_opt(expr):
            expr, subst, inst = check_type_against(actual, ty, expr, self.ctx, kind)
            if inst:
                # Use the `inst` keyword, matching how `generic_visit` constructs
                # `TypeApply` nodes
                expr = with_loc(expr, TypeApply(value=expr, inst=inst))
            return with_type(ty.substitute(subst), expr), subst

        # When checking against a variable, we have to synthesize
        if isinstance(ty, ExistentialTypeVar):
            expr, syn_ty = self._synthesize(expr, allow_free_vars=False)
            return with_type(syn_ty, expr), {ty: syn_ty}

        # Otherwise, invoke the visitor. The previous kind is restored even if the
        # visitor raises, so a failed check never leaves a stale `_kind` behind.
        old_kind = self._kind
        self._kind = kind or self._kind
        try:
            expr, subst = self.visit(expr, ty)
        finally:
            self._kind = old_kind
        return with_type(ty.substitute(subst), expr), subst

    def _synthesize(
        self, node: ast.expr, allow_free_vars: bool
    ) -> tuple[ast.expr, Type]:
        """Invokes the type synthesiser"""
        return ExprSynthesizer(self.ctx).synthesize(node, allow_free_vars)

    def visit_Constant(self, node: ast.Constant, ty: Type) -> tuple[ast.expr, Subst]:
        """Checks a literal constant against `ty`.

        Raises a `GuppyError` with an `IllegalConstant` diagnostic if the Python value
        has no corresponding Guppy type.
        """
        act = python_value_to_guppy_type(node.value, node, self.ctx.globals, ty)
        if act is None:
            raise GuppyError(IllegalConstant(node, type(node.value)))
        node, subst, inst = check_type_against(act, ty, node, self.ctx, self._kind)
        assert inst == [], "Const values are not generic"
        return node, subst

    def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]:
        """Checks a tuple expression element-wise against a tuple type of equal length."""
        if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts):
            return self._fail(ty, node)
        subst: Subst = {}
        for i, el in enumerate(node.elts):
            # Apply what we have solved so far before checking the next element
            node.elts[i], s = self.check(el, ty.element_types[i].substitute(subst))
            subst |= s
        return node, subst

    def visit_List(self, node: ast.List, ty: Type) -> tuple[ast.expr, Subst]:
        """Checks every element of a list literal against the element type of `ty`."""
        check_lists_enabled(node)
        if not is_list_type(ty):
            return self._fail(ty, node)
        el_ty = get_element_type(ty)
        subst: Subst = {}
        for i, el in enumerate(node.elts):
            # Apply what we have solved so far before checking the next element
            node.elts[i], s = self.check(el, el_ty.substitute(subst))
            subst |= s
        return node, subst

    def visit_DesugaredListComp(
        self, node: DesugaredListComp, ty: Type
    ) -> tuple[ast.expr, Subst]:
        """Checks a list comprehension by synthesizing its element type and unifying
        it with the element type of `ty`."""
        if not is_list_type(ty):
            return self._fail(ty, node)
        node.generators, node.elt, elt_ty = synthesize_comprehension(
            node, node.generators, node.elt, self.ctx
        )
        subst = unify(get_element_type(ty), elt_ty, {})
        if subst is None:
            actual = list_type(elt_ty)
            return self._fail(ty, actual, node)
        return node, subst

    def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]:
        """Checks a call expression against the expected return type `ty`.

        Handles, in this order: direct calls of global callable definitions, calls of
        `PartialApply` nodes, calls of values with a function type, calls of function
        tensors, and calls of values that provide a `__call__` instance function.

        Raises:
            GuppyError: For keyword arguments and polymorphic function tensors.
            GuppyTypeError: If the callee is not callable.
        """
        if len(node.keywords) > 0:
            raise GuppyError(UnsupportedError(node.keywords[0], "Keyword arguments"))
        node.func, func_ty = self._synthesize(node.func, allow_free_vars=False)

        # First handle direct calls of user-defined functions and extension functions
        if isinstance(node.func, GlobalName):
            defn = self.ctx.globals[node.func.def_id]
            if isinstance(defn, CallableDef):
                return defn.check_call(node.args, ty, node, self.ctx)

        # When calling a `PartialApply` node, we just move the args into this call
        if isinstance(node.func, PartialApply):
            node.args = [*node.func.args, *node.args]
            node.func = node.func.func
            return self.visit_Call(node, ty)

        # Otherwise, it must be a function as a higher-order value - something
        # whose type is either a FunctionType or a Tuple of FunctionTypes
        if isinstance(func_ty, FunctionType):
            args, subst, inst = check_call(func_ty, node.args, ty, node, self.ctx)
            check_inst(func_ty, inst, node)
            node.func = instantiate_poly(node.func, func_ty, inst)
            return with_loc(node, LocalCall(func=node.func, args=args)), subst

        if isinstance(func_ty, TupleType) and (
            function_elements := parse_function_tensor(func_ty)
        ):
            check_function_tensors_enabled(node.func)
            if any(f.parametrized for f in function_elements):
                raise GuppyError(
                    UnsupportedError(node.func, "Polymorphic function tensors")
                )

            tensor_ty = function_tensor_signature(function_elements)

            processed_args, subst, inst = check_call(
                tensor_ty, node.args, ty, node, self.ctx
            )
            assert len(inst) == 0
            return with_loc(
                node,
                TensorCall(func=node.func, args=processed_args, tensor_ty=tensor_ty),
            ), subst

        elif callee := self.ctx.globals.get_instance_func(func_ty, "__call__"):
            return callee.check_call(node.args, ty, node, self.ctx)
        else:
            raise GuppyTypeError(NotCallableError(node.func, func_ty))

    def visit_ComptimeExpr(
        self, node: ComptimeExpr, ty: Type
    ) -> tuple[ast.expr, Subst]:
        """Evaluates a comptime expression and checks the resulting Python value
        against `ty`.

        On success the node is replaced by an `ast.Constant` holding the value. Raises
        a `GuppyError` with an `IllegalComptimeExpressionError` diagnostic if the
        value has no corresponding Guppy type.
        """
        python_val = eval_comptime_expr(node, self.ctx)
        if act := python_value_to_guppy_type(
            python_val, node.value, self.ctx.globals, ty
        ):
            subst = unify(ty, act, {})
            if subst is None:
                self._fail(ty, act, node)
            act = act.substitute(subst)
            # Only report solutions for variables that occur in the expected type
            subst = {x: s for x, s in subst.items() if x in ty.unsolved_vars}
            return with_type(act, with_loc(node, ast.Constant(value=python_val))), subst

        raise GuppyError(IllegalComptimeExpressionError(node.value, type(python_val)))

    def generic_visit(self, node: ast.expr, ty: Type) -> tuple[ast.expr, Subst]:
        """Fallback for nodes without a dedicated checking rule: synthesize a type
        and match it against `ty`."""
        # Try to synthesize and then check if we can unify it with the given type
        node, synth = self._synthesize(node, allow_free_vars=False)
        node, subst, inst = check_type_against(synth, ty, node, self.ctx, self._kind)

        # Apply instantiation of quantified type variables
        if inst:
            node = with_loc(node, TypeApply(value=node, inst=inst))

        return node, subst
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
|
|
376
|
+
ctx: Context
|
|
377
|
+
|
|
378
|
+
def __init__(self, ctx: Context) -> None:
    """Creates a synthesizer operating in the given checking context."""
    self.ctx = ctx
|
|
380
|
+
|
|
381
|
+
def synthesize(
    self, node: ast.expr, allow_free_vars: bool = False
) -> tuple[ast.expr, Type]:
    """Tries to synthesise a type for the given expression.

    Expressions that already carry a type annotation are returned unchanged.
    Otherwise the visitor is run and the result is returned as a new desugared
    expression annotated with the inferred type.

    Raises a `GuppyError` with a `TypeInferenceError` diagnostic if the inferred type
    has unsolved variables and `allow_free_vars` is not set.
    """
    known_ty = get_type_opt(node)
    if known_ty:
        return node, known_ty

    node, inferred_ty = self.visit(node)
    has_unsolved = bool(inferred_ty.unsolved_vars)
    if has_unsolved and not allow_free_vars:
        raise GuppyError(TypeInferenceError(node, inferred_ty))
    return with_type(inferred_ty, node), inferred_ty
|
|
394
|
+
|
|
395
|
+
def _check(
    self, expr: ast.expr, ty: Type, kind: str = "expression"
) -> tuple[ast.expr, Subst]:
    """Delegates to a fresh `ExprChecker` to check `expr` against `ty`."""
    checker = ExprChecker(self.ctx)
    return checker.check(expr, ty, kind)
|
|
400
|
+
|
|
401
|
+
def visit_Constant(self, node: ast.Constant) -> tuple[ast.expr, Type]:
    """Synthesizes the type of a literal constant.

    Raises a `GuppyError` with an `IllegalConstant` diagnostic if the Python value
    has no corresponding Guppy type.
    """
    const_ty = python_value_to_guppy_type(node.value, node, self.ctx.globals)
    if const_ty is not None:
        return node, const_ty
    raise GuppyError(IllegalConstant(node, type(node.value)))
|
|
406
|
+
|
|
407
|
+
def visit_Name(self, node: ast.Name) -> tuple[ast.expr, Type]:
    """Resolves an identifier and synthesizes its type.

    The name is looked up in the local variables first, then in the generic
    parameters, and finally in the globals.
    """
    name = node.id

    # Local variables become place nodes
    if name in self.ctx.locals:
        place = self.ctx.locals[name]
        return with_loc(node, PlaceNode(place=place)), place.ty

    # Generic parameters: only const parameters can be used as values
    if name in self.ctx.generic_params:
        param = self.ctx.generic_params[name]
        if isinstance(param, ConstParam):
            value_node = with_loc(node, GenericParamValue(id=name, param=param))
            return value_node, param.ty
        if isinstance(param, TypeParam):
            raise GuppyError(
                ExpectedError(node, "a value", got=f"type `{param.name}`")
            )
        return assert_never(param)

    if name in self.ctx.globals:
        entry = self.ctx.globals[name]
        if isinstance(entry, Definition):
            return self._check_global(entry, name, node)
        if isinstance(entry, PythonObject):
            from guppylang_internals.checker.cfg_checker import (
                VarNotDefinedError,
            )

            # TODO: Emit a hint that the variable could be accessed through a
            #  comptime expression
            raise GuppyError(VarNotDefinedError(node, node.id))

    raise InternalGuppyError(
        f"Variable `{name}` is not defined in `TypeSynthesiser`. This should have "
        "been caught by program analysis!"
    )
|
|
441
|
+
|
|
442
|
+
def _check_global(
    self, defn: Definition, name: str, node: ast.expr
) -> tuple[ast.expr, Type]:
    """Checks a global definition in an expression position.

    Value definitions are turned into a `GlobalName` node. Type definitions resolve to
    their `__new__` constructor if one exists. Anything else is rejected with an
    `ExpectedError` diagnostic.
    """
    if isinstance(defn, ValueDef):
        return with_loc(node, GlobalName(id=name, def_id=defn.id)), defn.ty

    # For types, we return their `__new__` constructor
    if isinstance(defn, TypeDef):
        constructor = self.ctx.globals.get_instance_func(defn, "__new__")
        if constructor:
            global_name = GlobalName(id=name, def_id=constructor.id)
            return with_loc(node, global_name), constructor.ty

    raise GuppyError(
        ExpectedError(node, "a value", got=f"{defn.description} `{name}`")
    )
|
|
458
|
+
|
|
459
|
+
def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]:
    """Synthesizes a `value.attr` attribute access.

    Three cases are handled: members of a Python module in scope, struct fields, and
    instance functions (which yield a `PartialApply` closure over the receiver).
    """
    from guppylang.defs import GuppyDefinition
    from guppylang_internals.engine import ENGINE

    # The `attr` is just a string and not an AST node, so its span has to be
    # computed by hand. This works because no linebreak may occur inside the
    # identifier that follows the `.`
    whole_span = to_span(node)
    attr_span = Span(whole_span.end.shift_left(len(node.attr)), whole_span.end)

    module = self._is_python_module(node.value)
    if module:
        member = module.__dict__.get(node.attr)
        if isinstance(member, GuppyDefinition):
            defn = ENGINE.get_parsed(member.id)
            return self._check_global(defn, f"{module.__name__}.{defn.name}", node)
        raise GuppyError(
            ModuleMemberNotFoundError(attr_span, module.__name__, node.attr)
        )

    node.value, value_ty = self.synthesize(node.value)

    if isinstance(value_ty, StructType) and node.attr in value_ty.field_dict:
        field = value_ty.field_dict[node.attr]
        projection: ast.expr
        if isinstance(node.value, PlaceNode):
            # Field access on a place is itself a place
            projection = PlaceNode(place=FieldAccess(node.value.place, field, None))
        else:
            # A struct that doesn't live in a place can't be addressed again once
            # one field has been projected out (e.g. `f().a` makes you lose access
            # to every field other than `a`).
            projection = FieldAccessAndDrop(
                value=node.value, struct_ty=value_ty, field=field
            )
        return with_loc(node, projection), field.ty

    method = self.ctx.globals.get_instance_func(value_ty, node.attr)
    if method:
        method_name = with_type(
            method.ty, with_loc(node, GlobalName(id=method.name, def_id=method.id))
        )
        # Make a closure by partially applying the `self` argument
        # TODO: Try to infer some type args based on `self`
        remaining_names = method.ty.input_names[1:] if method.ty.input_names else None
        closure_ty = FunctionType(
            method.ty.inputs[1:],
            method.ty.output,
            remaining_names,
            method.ty.params,
        )
        closure = PartialApply(func=method_name, args=[node.value])
        return with_loc(node, closure), closure_ty

    raise GuppyTypeError(AttributeNotFoundError(attr_span, value_ty, node.attr))
|
|
505
|
+
|
|
506
|
+
def _is_python_module(self, node: ast.expr) -> ModuleType | None:
|
|
507
|
+
"""Checks whether an AST node corresponds to a Python module in scope."""
|
|
508
|
+
if isinstance(node, ast.Name):
|
|
509
|
+
x = node.id
|
|
510
|
+
globals = self.ctx.globals
|
|
511
|
+
if x in globals.f_locals or x in globals.f_globals:
|
|
512
|
+
val = (
|
|
513
|
+
globals.f_locals[x]
|
|
514
|
+
if x in globals.f_locals
|
|
515
|
+
else globals.f_globals[x]
|
|
516
|
+
)
|
|
517
|
+
if isinstance(val, ModuleType):
|
|
518
|
+
return val
|
|
519
|
+
return None
|
|
520
|
+
|
|
521
|
+
def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, Type]:
    """Synthesizes every element of a tuple expression and combines the element
    types into a `TupleType`."""
    new_elts: list[ast.expr] = []
    elt_tys: list[Type] = []
    for elt in node.elts:
        new_elt, elt_ty = self.synthesize(elt)
        new_elts.append(new_elt)
        elt_tys.append(elt_ty)

    node.elts = new_elts
    return node, TupleType(elt_tys)
|
|
526
|
+
|
|
527
|
+
def visit_List(self, node: ast.List) -> tuple[ast.expr, Type]:
    """Synthesizes the type of a list literal.

    The element type is taken from the first element; all remaining elements are
    checked against it. An empty literal raises a `GuppyTypeInferenceError` because
    there is nothing to infer the element type from.
    """
    check_lists_enabled(node)
    if not node.elts:
        unsolved_ty = list_type(ExistentialTypeVar.fresh("T", True, True))
        raise GuppyTypeInferenceError(TypeInferenceError(node, unsolved_ty))

    head, el_ty = self.synthesize(node.elts[0])
    node.elts[0] = head

    checked_tail: list[ast.expr] = []
    for el in node.elts[1:]:
        checked, _ = self._check(el, el_ty)
        checked_tail.append(checked)
    node.elts[1:] = checked_tail

    return node, list_type(el_ty)
|
|
535
|
+
|
|
536
|
+
def visit_DesugaredListComp(self, node: DesugaredListComp) -> tuple[ast.expr, Type]:
    """Synthesizes a list comprehension; the result is a list of the element type
    reported by `synthesize_comprehension`."""
    generators, elt, elt_ty = synthesize_comprehension(
        node, node.generators, node.elt, self.ctx
    )
    node.generators = generators
    node.elt = elt
    return node, list_type(elt_ty)
|
|
542
|
+
|
|
543
|
+
def visit_DesugaredGeneratorExpr(
    self, node: DesugaredGeneratorExpr
) -> tuple[ast.expr, Type]:
    """Rejects generator expressions in arbitrary expression positions.

    Generators are not supported as first-class values yet, so this always raises.
    Places where generators are allowed must check for them explicitly (e.g. see the
    handling of array comprehensions in the compiler for the `array` function).
    """
    unsupported = UnsupportedError(node, "Generator expressions")
    raise GuppyError(unsupported)
|
|
551
|
+
|
|
552
|
+
def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, Type]:
    """Synthesizes a unary operation.

    `not` is handled by converting the operand with `to_bool`; every other operator
    is dispatched to the dunder method listed in `unary_table`. Raises a
    `GuppyTypeError` if the operand type does not provide that method.
    """
    # The operand type is needed to look up dunder methods
    operand, operand_ty = self.synthesize(node.operand)
    node.operand = operand

    # `not` is special since it is implemented neither via a dunder method nor via
    # control-flow
    if isinstance(node.op, ast.Not):
        node.operand, bool_ty = to_bool(node.operand, operand_ty, self.ctx)
        return node, bool_ty

    # All remaining unary operators call out to instance dunder methods
    dunder, symbol = unary_table[type(node.op)]
    method = self.ctx.globals.get_instance_func(operand_ty, dunder)
    if method is not None:
        return method.synthesize_call([node.operand], node, self.ctx)
    raise GuppyTypeError(
        UnaryOperatorNotDefinedError(node.operand, operand_ty, symbol)
    )
|
|
570
|
+
|
|
571
|
+
def _synthesize_binary(
    self, left_expr: ast.expr, right_expr: ast.expr, op: AstOp, node: ast.expr
) -> tuple[ast.expr, Type]:
    """Helper method to compile binary operators by calling out to dunder methods.

    The left dunder method (e.g. `__add__`) is tried on the left operand first. If
    it is missing or the call fails with a `GuppyError`, the right dunder method
    (e.g. `__radd__`) is tried on the right operand with the operands swapped.

    Raises a `GuppyTypeError` if the operator is not in `binary_table` or if neither
    attempt succeeds.
    """
    op_cls = type(op)
    if op_cls not in binary_table:
        raise GuppyTypeError(UnsupportedError(node, "Operator", singular=True))
    left_dunder, right_dunder, symbol = binary_table[op_cls]

    left_expr, left_ty = self.synthesize(left_expr)
    right_expr, right_ty = self.synthesize(right_expr)

    # (receiver type, dunder name, call arguments) in the order they are attempted
    attempts = (
        (left_ty, left_dunder, [left_expr, right_expr]),
        (right_ty, right_dunder, [right_expr, left_expr]),
    )
    for receiver_ty, dunder, call_args in attempts:
        method = self.ctx.globals.get_instance_func(receiver_ty, dunder)
        if method:
            with suppress(GuppyError):
                return method.synthesize_call(call_args, node, self.ctx)

    raise GuppyTypeError(
        # TODO: Is there a way to get the span of the operator?
        BinaryOperatorNotDefinedError(node, left_ty, right_ty, symbol)
    )
|
|
597
|
+
|
|
598
|
+
def synthesize_instance_func(
    self,
    node: ast.expr,
    args: list[ast.expr],
    func_name: str,
    description: str,
    exp_sig: FunctionType | None = None,
    give_reason: bool = False,
) -> tuple[ast.expr, Type]:
    """Synthesize an expression that is implemented via an instance method.

    `node` is synthesized to obtain the receiver and its type, then the method
    `func_name` is looked up on that type and called with the receiver followed
    by `args`.

    Raises a `GuppyTypeError` carrying a `BadProtocolError` built from
    `description` if the method is not defined. When `give_reason` is set and
    `exp_sig` is provided, a `MethodMissing` sub-diagnostic is attached.

    If `exp_sig` is given and does not unify with the unquantified signature of
    the method, a `GuppyError` with a `BadSignature` sub-diagnostic is raised.
    """
    receiver, receiver_ty = self.synthesize(node)
    method = self.ctx.globals.get_instance_func(receiver_ty, func_name)

    if method is None:
        missing = BadProtocolError(receiver, receiver_ty, description)
        if exp_sig is not None and give_reason:
            missing.add_sub_diagnostic(
                BadProtocolError.MethodMissing(None, func_name, exp_sig)
            )
        raise GuppyTypeError(missing)

    if exp_sig:
        method_sig = method.ty.unquantified()[0]
        if unify(exp_sig, method_sig, {}) is None:
            mismatch = BadProtocolError(receiver, receiver_ty, description)
            mismatch.add_sub_diagnostic(
                BadProtocolError.BadSignature(
                    None, receiver_ty, func_name, exp_sig, method.ty
                )
            )
            raise GuppyError(mismatch)

    return method.synthesize_call([receiver, *args], receiver, self.ctx)
|
|
632
|
+
|
|
633
|
+
def visit_BinOp(self, node: ast.BinOp) -> tuple[ast.expr, Type]:
    """Synthesize a binary operator expression via `_synthesize_binary`."""
    lhs, rhs, operator = node.left, node.right, node.op
    return self._synthesize_binary(lhs, rhs, operator, node)
|
|
635
|
+
|
|
636
|
+
def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, Type]:
    """Synthesize a single comparison via `_synthesize_binary`.

    Only comparisons with exactly one operator and one comparator are
    accepted; anything else raises an `InternalGuppyError`.
    """
    is_single = len(node.comparators) == 1 and len(node.ops) == 1
    if not is_single:
        raise InternalGuppyError(
            "BB contains chained comparison. Should have been removed during CFG "
            "construction."
        )
    operator = node.ops[0]
    rhs = node.comparators[0]
    return self._synthesize_binary(node.left, rhs, operator, node)
|
|
644
|
+
|
|
645
|
+
def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, Type]:
    """Synthesize the type of a subscript expression `value[slice]`.

    Three cases are handled, in this order:

    * The value has a function type: the subscript is a type application that
      is checked with `check_type_apply`.
    * The value has a tuple type: the index must be an integer constant `idx`
      with `0 <= idx < len(element_types)`. The result is a `TupleAccess` place
      if the value is a `PlaceNode`, otherwise a `TupleAccessAndDrop` node.
    * Anything else: the subscript is checked as a call to the `__getitem__`
      instance method. The result is a `SubscriptAccess` place if the value is
      a `PlaceNode`, otherwise a `SubscriptAccessAndDrop` node.

    Raises a `GuppyError` for an out-of-bounds tuple index and a
    `GuppyTypeError` if a tuple index is not an integer constant.
    """
    node.value, ty = self.synthesize(node.value)
    # Special case for subscripts on functions: Those are type applications
    if isinstance(ty, FunctionType):
        inst = check_type_apply(ty, node, self.ctx)
        return instantiate_poly(node.value, ty, inst), ty.instantiate(inst)
    item_expr, item_ty = self.synthesize(node.slice)
    # Special case for tuples: Index needs to be known statically in order to infer
    # element type of subscript
    if isinstance(ty, TupleType):
        match item_expr:
            case ast.Constant(value=int(idx)):
                # Negative indices are rejected by the lower bound check
                if 0 <= idx < len(ty.element_types):
                    result_ty = ty.element_types[idx]
                    expr: ast.expr
                    if isinstance(node.value, PlaceNode):
                        tuple_place = TupleAccess(
                            node.value.place, result_ty, idx, None
                        )
                        expr = PlaceNode(place=tuple_place)
                    else:
                        expr = TupleAccessAndDrop(node.value, ty, idx)
                    return with_loc(node, expr), result_ty
                else:
                    raise GuppyError(
                        TupleIndexOutOfBoundsError(
                            item_expr, idx, len(ty.element_types)
                        )
                    )
            case _:
                raise GuppyTypeError(ExpectedError(item_expr, "an integer literal"))
    # Otherwise, it's a regular __getitem__ subscript
    # Give the item a unique name so we can refer to it later in case we also want
    # to compile a call to `__setitem__`
    item = Variable(next(tmp_vars), item_ty, item_expr)
    item_node = with_type(item_ty, with_loc(item_expr, PlaceNode(place=item)))
    # Check a call to the `__getitem__` instance function
    exp_sig = FunctionType(
        [
            FuncInput(ty, InputFlags.Inout),
            FuncInput(
                ExistentialTypeVar.fresh("Key", True, True), InputFlags.NoFlags
            ),
        ],
        ExistentialTypeVar.fresh("Val", True, True),
    )
    getitem_expr, result_ty = self.synthesize_instance_func(
        node.value, [item_node], "__getitem__", "subscriptable", exp_sig
    )
    # Subscripting a place is itself a place
    if isinstance(node.value, PlaceNode):
        place = SubscriptAccess(
            node.value.place, item, result_ty, item_expr, getitem_expr
        )
        expr = PlaceNode(place=place)
    else:
        # If the subscript is not on a place, then there is no way to address the
        # other indices after this one has been projected out (e.g. `f()[0]` makes
        # you lose access to all elements besides 0).
        expr = SubscriptAccessAndDrop(
            item=item,
            item_expr=item_expr,
            getitem_expr=getitem_expr,
            original_expr=node,
        )
    return with_loc(node, expr), result_ty
|
|
711
|
+
|
|
712
|
+
def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]:
    """Synthesize the type of a call expression.

    The callee is synthesized first and the call is then dispatched on what it
    turned out to be:

    * a `GlobalName` referring to a `CallableDef`: delegated to the definition,
    * a `PartialApply` node: its stored arguments are prepended to the call
      arguments and the call is visited again with the underlying function,
    * a value of function type: checked with `synthesize_call` and turned into
      a `LocalCall`,
    * a tuple type that `parse_function_tensor` accepts: turned into a
      `TensorCall`,
    * any other type with a `__call__` instance method: delegated to it.

    Raises a `GuppyError` for keyword arguments and polymorphic function
    tensors, and a `GuppyTypeError` if the callee is not callable.
    """
    if len(node.keywords) > 0:
        raise GuppyError(UnsupportedError(node.keywords[0], "Keyword arguments"))
    node.func, ty = self.synthesize(node.func)

    # First handle direct calls of user-defined functions and extension functions
    if isinstance(node.func, GlobalName):
        defn = self.ctx.globals[node.func.def_id]
        if isinstance(defn, CallableDef):
            return defn.synthesize_call(node.args, node, self.ctx)

    # When calling a `PartialApply` node, we just move the args into this call
    if isinstance(node.func, PartialApply):
        node.args = [*node.func.args, *node.args]
        node.func = node.func.func
        return self.visit_Call(node)

    # Otherwise, it must be a function as a higher-order value, or a tensor
    if isinstance(ty, FunctionType):
        args, return_ty, inst = synthesize_call(ty, node.args, node, self.ctx)
        node.func = instantiate_poly(node.func, ty, inst)
        return with_loc(node, LocalCall(func=node.func, args=args)), return_ty
    elif isinstance(ty, TupleType) and (
        function_elems := parse_function_tensor(ty)
    ):
        check_function_tensors_enabled(node.func)
        if any(f.parametrized for f in function_elems):
            raise GuppyError(
                UnsupportedError(node.func, "Polymorphic function tensors")
            )

        tensor_ty = function_tensor_signature(function_elems)
        args, return_ty, inst = synthesize_call(
            tensor_ty, node.args, node, self.ctx
        )
        # The polymorphic case was rejected above, so nothing to instantiate
        assert len(inst) == 0

        return with_loc(
            node, TensorCall(func=node.func, args=args, tensor_ty=tensor_ty)
        ), return_ty

    # NOTE(review): only `node.args` is forwarded here; the callee expression
    # `node.func` is not passed as a receiver argument. Confirm that this is
    # what `__call__` implementations expect.
    elif f := self.ctx.globals.get_instance_func(ty, "__call__"):
        return f.synthesize_call(node.args, node, self.ctx)
    else:
        raise GuppyTypeError(NotCallableError(node.func, ty))
|
|
757
|
+
|
|
758
|
+
def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]:
    """Synthesize the construction of an iterator via `__iter__`.

    The iterable is passed as `Owned` if its type is not copyable and without
    flags otherwise. If the resulting iterator is a sized iterator and the node
    requests it, the size hint is stripped by calling `unwrap_iter`.

    For iterators that originate from a `for` loop and whose type is not
    droppable, a `GuppyTypeError` is raised if the loop contains a `break` or
    `return`.
    """
    node.value, iterable_ty = self.synthesize(node.value)
    self_flags = InputFlags.NoFlags if iterable_ty.copyable else InputFlags.Owned
    inputs = [FuncInput(iterable_ty, self_flags)]
    expected = FunctionType(inputs, ExistentialTypeVar.fresh("Iter", True, True))
    iter_expr, iter_ty = self.synthesize_instance_func(
        node.value, [], "__iter__", "iterable", expected, True
    )

    # Drop the size hint if there is one and the caller asked for it
    if is_sized_iter_type(iter_ty) and node.unwrap_size_hint:
        iter_expr, iter_ty = self.synthesize_instance_func(
            iter_expr, [], "unwrap_iter", ""
        )

    # The extra diagnostics below only apply to non-droppable iterators that were
    # created by a `for` loop: leaving such a loop early via `break` or `return`
    # is rejected (`continue` is allowed)
    if iter_ty.droppable or not isinstance(node.origin_node, ast.For):
        return iter_expr, iter_ty

    loop = node.origin_node
    early_exits = breaks_in_loop(loop) or return_nodes_in_ast(loop)
    if early_exits:
        err = NonDroppableForBreakError(early_exits[0])
        err.add_sub_diagnostic(
            NonDroppableForBreakError.NonDroppableIteratorType(node, iter_ty)
        )
        raise GuppyTypeError(err)
    return iter_expr, iter_ty
|
|
786
|
+
|
|
787
|
+
def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, Type]:
    """Synthesize a `__hasnext__` call on an iterator.

    The expected signature takes the iterator (as `Owned` if its type is not
    copyable) and returns a tuple of a bool and the iterator type.
    """
    node.value, iter_ty = self.synthesize(node.value)
    self_flags = InputFlags.NoFlags if iter_ty.copyable else InputFlags.Owned
    inputs = [FuncInput(iter_ty, self_flags)]
    output = TupleType([bool_type(), iter_ty])
    expected = FunctionType(inputs, output)
    return self.synthesize_instance_func(
        node.value, [], "__hasnext__", "an iterator", expected, True
    )
|
|
794
|
+
|
|
795
|
+
def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]:
    """Synthesize a `__next__` call on an iterator.

    The expected signature takes the iterator (as `Owned` if its type is not
    copyable) and returns an option of a tuple holding a fresh element type
    variable and the iterator type.
    """
    node.value, iter_ty = self.synthesize(node.value)
    self_flags = InputFlags.NoFlags if iter_ty.copyable else InputFlags.Owned
    inputs = [FuncInput(iter_ty, self_flags)]
    elem_var = ExistentialTypeVar.fresh("T", True, True)
    output = option_type(TupleType([elem_var, iter_ty]))
    expected = FunctionType(inputs, output)
    return self.synthesize_instance_func(
        node.value, [], "__next__", "an iterator", expected, True
    )
|
|
805
|
+
|
|
806
|
+
def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, Type]:
    """Synthesize an `__end__` call on an iterator.

    The expected signature takes the iterator (as `Owned` if its type is not
    copyable) and returns `None`.
    """
    node.value, iter_ty = self.synthesize(node.value)
    self_flags = InputFlags.NoFlags if iter_ty.copyable else InputFlags.Owned
    inputs = [FuncInput(iter_ty, self_flags)]
    expected = FunctionType(inputs, NoneType())
    return self.synthesize_instance_func(
        node.value, [], "__end__", "an iterator", expected, True
    )
|
|
813
|
+
|
|
814
|
+
def visit_ListComp(self, node: ast.ListComp) -> tuple[ast.expr, Type]:
    """Reject list comprehensions that reach the expression checker.

    Always raises an `InternalGuppyError` whose message includes the unparsed
    source of the offending node.
    """
    raise InternalGuppyError(
        # Keep the trailing space: adjacent literals are joined verbatim
        "BB contains `ListComp`. Should have been removed during CFG "
        f"construction: `{ast.unparse(node)}`"
    )
|
|
819
|
+
|
|
820
|
+
def visit_ComptimeExpr(self, node: ComptimeExpr) -> tuple[ast.expr, Type]:
    """Evaluate a `comptime(...)` expression and turn it into a constant.

    Raises a `GuppyError` if no Guppy type can be determined for the evaluated
    Python value.
    """
    value = eval_comptime_expr(node, self.ctx)
    guppy_ty = python_value_to_guppy_type(value, node, self.ctx.globals)
    if not guppy_ty:
        raise GuppyError(IllegalComptimeExpressionError(node.value, type(value)))
    constant = ast.Constant(value=value)
    return with_loc(node, constant), guppy_ty
|
|
826
|
+
|
|
827
|
+
def visit_NamedExpr(self, node: ast.NamedExpr) -> tuple[ast.expr, Type]:
    """Reject assignment expressions that reach the expression checker.

    Always raises an `InternalGuppyError` whose message includes the unparsed
    source of the offending node.
    """
    raise InternalGuppyError(
        # Keep the trailing space: adjacent literals are joined verbatim
        "BB contains `NamedExpr`. Should have been removed during CFG "
        f"construction: `{ast.unparse(node)}`"
    )
|
|
832
|
+
|
|
833
|
+
def visit_BoolOp(self, node: ast.BoolOp) -> tuple[ast.expr, Type]:
    """Reject boolean operators that reach the expression checker.

    Always raises an `InternalGuppyError` whose message includes the unparsed
    source of the offending node.
    """
    source = ast.unparse(node)
    message = (
        "BB contains `BoolOp`. Should have been removed during CFG construction: "
        f"`{source}`"
    )
    raise InternalGuppyError(message)
|
|
838
|
+
|
|
839
|
+
def visit_IfExp(self, node: ast.IfExp) -> tuple[ast.expr, Type]:
    """Reject conditional expressions that reach the expression checker.

    Always raises an `InternalGuppyError` whose message includes the unparsed
    source of the offending node.
    """
    source = ast.unparse(node)
    message = (
        "BB contains `IfExp`. Should have been removed during CFG construction: "
        f"`{source}`"
    )
    raise InternalGuppyError(message)
|
|
844
|
+
|
|
845
|
+
def generic_visit(self, node: ast.expr) -> NoReturn:
    """Reject any expression for which no explicit visitor method exists."""
    unsupported = UnsupportedError(node, "This expression", singular=True)
    raise GuppyError(unsupported)
|
|
848
|
+
|
|
849
|
+
|
|
850
|
+
def check_type_against(
    act: Type, exp: Type, node: ast.expr, ctx: Context, kind: str = "expression"
) -> tuple[ast.expr, Subst, Inst]:
    """Checks a type against another type.

    Returns the (possibly coerced) expression, a substitution for the free
    variables of the expected type, and an instantiation for the parameters in
    the actual type. Note that the expected type may not be parametrised and the
    actual type may not contain free unification variables.

    Raises a `GuppyTypeError` if the types do not match and no implicit coercion
    applies, and a `GuppyTypeInferenceError` if a parameter of a parametrised
    actual type cannot be inferred. `kind` is forwarded to the
    `TypeMismatchError` diagnostics.
    """
    assert not isinstance(exp, FunctionType) or not exp.parametrized
    assert not act.unsolved_vars

    # The actual type may be parametrised. In that case, we have to find an
    # instantiation to avoid higher-rank types.
    subst: Subst | None
    if isinstance(act, FunctionType) and act.parametrized:
        unquantified, free_vars = act.unquantified()
        subst = unify(exp, unquantified, {})
        if subst is None:
            raise GuppyTypeError(TypeMismatchError(node, exp, act, kind))
        # Check that we have found a valid instantiation for all params
        for i, v in enumerate(free_vars):
            # `free_vars` is indexed in parallel with `act.params`
            param = act.params[i].name
            if v not in subst:
                err = TypeMismatchError(node, exp, act, kind)
                err.add_sub_diagnostic(TypeMismatchError.CantInferParam(None, param))
                raise GuppyTypeInferenceError(err)
            if subst[v].unsolved_vars:
                err = TypeMismatchError(node, exp, act, kind)
                err.add_sub_diagnostic(
                    TypeMismatchError.CantInstantiateFreeVars(None, param, subst[v])
                )
                raise GuppyTypeError(err)
        inst = [subst[v].to_arg() for v in free_vars]
        # Only keep the entries that solve variables of the expected type
        subst = {v: t for v, t in subst.items() if v in exp.unsolved_vars}

        # Finally, check that the instantiation respects the linearity requirements
        check_inst(act, inst, node)

        return node, subst, inst

    # Otherwise, we know that `act` has no unsolved type vars, so unification is trivial
    assert not act.unsolved_vars
    subst = unify(exp, act, {})
    if subst is None:
        # Maybe we can implicitly coerce `act` to `exp`
        if coerced := try_coerce_to(act, exp, node, ctx):
            return coerced, {}, []
        raise GuppyTypeError(TypeMismatchError(node, exp, act, kind))
    return node, subst, []
|
|
900
|
+
|
|
901
|
+
|
|
902
|
+
def try_coerce_to(
    act: Type, exp: Type, node: ast.expr, ctx: Context
) -> ast.expr | None:
    """Attempt an implicit coercion of an expression to another type.

    Only coercions between numeric types are supported, and only when the kind
    of the actual type is smaller than the kind of the expected type. The
    coercion is a checked call to the `__<kind>__` instance method of the
    actual type, where `<kind>` is the lower-cased name of the expected kind.

    Returns the coerced expression, or `None` if no coercion applies.
    """
    both_numeric = isinstance(act, NumericType) and isinstance(exp, NumericType)
    if not both_numeric:
        return None
    # The ordering on `NumericType.Kind` defines which coercions are allowed
    if not (act.kind < exp.kind):
        return None
    method_name = f"__{exp.kind.name.lower()}__"
    coercion = ctx.globals.get_instance_func(act, method_name)
    assert coercion is not None
    coerced, subst = coercion.check_call([node], exp, node, ctx)
    assert len(subst) == 0, "Coercion methods are not generic"
    return coerced
|
|
920
|
+
|
|
921
|
+
|
|
922
|
+
def check_type_apply(ty: FunctionType, node: ast.Subscript, ctx: Context) -> Inst:
    """Check a `f[T1, T2, ...]` type application of a generic function.

    Returns one checked argument per parameter of `ty`.

    Raises a `GuppyError` if `ty` is not parametrized or if the number of
    provided arguments differs from the number of parameters.
    """
    callee = node.value
    # A non-empty tuple slice provides several arguments, anything else just one
    if isinstance(node.slice, ast.Tuple) and node.slice.elts:
        arg_exprs = node.slice.elts
    else:
        arg_exprs = [node.slice]
    global_defs = ctx.globals

    if not ty.parametrized:
        callee_name = None
        if isinstance(callee, GlobalName):
            callee_name = global_defs[callee.def_id].name
        raise GuppyError(TypeApplyNotGenericError(node, callee_name))

    exp, act = len(ty.params), len(arg_exprs)
    assert exp > 0
    assert act > 0
    if exp != act:
        last_end = to_span(arg_exprs[-1]).end
        if exp < act:
            # Too many arguments: highlight the superfluous ones
            span = Span(to_span(arg_exprs[exp]).start, last_end)
        else:
            # Too few arguments: highlight the gap after the last one
            span = Span(last_end, to_span(node).end)
        err = WrongNumberOfArgsError(span, exp, act, detailed=True, is_type_apply=True)
        err.add_sub_diagnostic(WrongNumberOfArgsError.SignatureHint(None, ty))
        raise GuppyError(err)

    inst = []
    for arg_expr, param in zip(arg_exprs, ty.params, strict=True):
        parsed = arg_from_ast(arg_expr, global_defs, ctx.generic_params)
        inst.append(param.check_arg(parsed, arg_expr))
    return inst
|
|
952
|
+
|
|
953
|
+
|
|
954
|
+
def check_num_args(
    exp: int, act: int, node: AstNode, sig: FunctionType | None = None
) -> None:
    """Check that the expected number of arguments was passed to a function.

    Returns silently if `exp` equals `act`; otherwise raises a `GuppyTypeError`
    wrapping a `WrongNumberOfArgsError`. For `ast.Call` nodes the error span is
    narrowed to the superfluous arguments or to the gap where the missing ones
    belong. If `sig` is given, it is attached as a signature hint.
    """
    if exp == act:
        return

    is_call = isinstance(node, ast.Call)
    span = to_span(node)
    if is_call:
        # We can construct a nicer error span if we know it's a regular call
        call_end = None
        if exp < act:
            start = to_span(node.args[exp]).start
            end = to_span(node.args[-1]).end
        else:
            anchor = node.args[-1] if act > 0 else node.func
            start = to_span(anchor).end
            call_end = to_span(node).end
            end = call_end
        span = Span(start, end)

    err = WrongNumberOfArgsError(span, exp, act, is_call)
    if sig:
        err.add_sub_diagnostic(WrongNumberOfArgsError.SignatureHint(None, sig))
    raise GuppyTypeError(err)
|
|
974
|
+
|
|
975
|
+
|
|
976
|
+
def type_check_args(
    inputs: list[ast.expr],
    func_ty: FunctionType,
    subst: Subst,
    ctx: Context,
    node: AstNode,
) -> tuple[list[ast.expr], Subst]:
    """Checks the arguments of a function call and infers free type variables.

    We expect that parameters have been replaced with free unification variables.
    Checks that all unification variables can be inferred.

    Returns the checked arguments together with the extended substitution.
    Raises a `GuppyTypeInferenceError` if the substitution does not cover all
    unsolved variables of the return type.
    """
    assert not func_ty.parametrized
    check_num_args(len(func_ty.inputs), len(inputs), node, func_ty)

    new_args: list[ast.expr] = []
    # Consumed in lockstep with the inputs that carry the `Comptime` flag
    comptime_args = iter(func_ty.comptime_args)
    for inp, func_inp in zip(inputs, func_ty.inputs, strict=True):
        # Apply what was learnt from earlier arguments before checking this one
        a, s = ExprChecker(ctx).check(inp, func_inp.ty.substitute(subst), "argument")
        if InputFlags.Inout in func_inp.flags and isinstance(a, PlaceNode):
            a.place = check_place_assignable(
                a.place, ctx, a, "able to borrow subscripted elements"
            )
        if InputFlags.Comptime in func_inp.flags:
            comptime_arg = next(comptime_args)
            s = check_comptime_arg(a, comptime_arg.const, func_inp.ty, s)
        new_args.append(a)
        # NOTE(review): assuming `Subst` is a plain dict, `|=` updates the
        # caller's `subst` object in place
        subst |= s
    # Every comptime argument must have been matched with an input
    assert next(comptime_args, None) is None

    # If the argument check succeeded, this means that we must have found instantiations
    # for all unification variables occurring in the input types
    assert all(
        set.issubset(inp.ty.unsolved_vars, subst.keys()) for inp in func_ty.inputs
    )

    # We also have to check that we found instantiations for all vars in the return type
    if not set.issubset(func_ty.output.unsolved_vars, subst.keys()):
        raise GuppyTypeInferenceError(
            TypeInferenceError(node, func_ty.output.substitute(subst))
        )

    return new_args, subst
|
|
1019
|
+
|
|
1020
|
+
|
|
1021
|
+
def check_place_assignable(
    place: Place, ctx: Context, node: ast.expr, reason: str
) -> Place:
    """Performs additional checks for assignments to places, for example for borrowed
    place arguments after function returns.

    In particular, we need to check that places involving `place[item]` subscripts
    implement the corresponding `__setitem__` method.

    Returns the place with the checked `__setitem__` call attached to subscript
    accesses. `reason` is used as the description of the error that is raised
    if `__setitem__` is missing.
    """
    match place:
        case Variable():
            return place
        case FieldAccess(parent=parent):
            return replace(
                place, parent=check_place_assignable(parent, ctx, node, reason)
            )
        case SubscriptAccess(parent=parent, item=item, ty=ty):
            # Create temporary variable for the setitem value
            # NOTE(review): the variable is declared with `item.ty` but is used
            # below as the argument of type `ty` - confirm this is intended
            tmp_var = Variable(next(tmp_vars), item.ty, node)
            # Check a call to the `__setitem__` instance function
            exp_sig = FunctionType(
                [
                    FuncInput(parent.ty, InputFlags.Inout),
                    FuncInput(item.ty, InputFlags.NoFlags),
                    FuncInput(ty, InputFlags.Owned),
                ],
                NoneType(),
            )
            setitem_args: list[ast.expr] = [
                with_type(parent.ty, with_loc(node, PlaceNode(parent))),
                with_type(item.ty, with_loc(node, PlaceNode(item))),
                with_type(ty, with_loc(node, PlaceNode(tmp_var))),
            ]
            # The first argument is the receiver, the rest are the call arguments
            setitem_call, _ = ExprSynthesizer(ctx).synthesize_instance_func(
                setitem_args[0],
                setitem_args[1:],
                "__setitem__",
                reason,
                exp_sig,
                True,
            )
            # Unlike the field and tuple cases, the parent is not recursed into
            return replace(place, setitem_call=SetitemCall(setitem_call, tmp_var))
        case TupleAccess(parent=parent):
            return replace(
                place, parent=check_place_assignable(parent, ctx, node, reason)
            )
    # There is no catch-all case: a place of any other kind falls through here
|
|
1067
|
+
|
|
1068
|
+
|
|
1069
|
+
def check_comptime_arg(
    arg: ast.expr, exp_const: Const, ty: Type, subst: Subst | None
) -> Subst:
    """Checks that an expression can be passed as a valid `@comptime` argument.

    Also checks that the value matches the provided constant. Returns a substitution
    that solves any existential variables occurring in provided constant.

    Raises a `GuppyError` if the argument is neither a constant nor a generic
    parameter value, or if it does not unify with the expected constant.
    """
    const: Const
    match arg:
        case ast.Constant(value=v):
            const = ConstValue(ty, v)
        case GenericParamValue(param=const_param):
            const = const_param.to_bound().const
        case arg:
            # Anything else is considered unknown at comptime, but we can give some
            # nicer error hints by inspecting in more detail
            err = ComptimeUnknownError(arg, "argument")
            s: SubDiagnostic
            match arg:
                case PlaceNode(place=place) if place.root.is_func_input:
                    s = ComptimeUnknownError.InputHint(place.defined_at, place)
                case PlaceNode(place=place) if not is_tmp_var(place.root.name):
                    s = ComptimeUnknownError.VariableHint(place.defined_at, place)
                case arg:
                    s = ComptimeUnknownError.FallbackHint(arg)
            err.add_sub_diagnostic(s)
            err.add_sub_diagnostic(ComptimeUnknownError.Feedback(None))
            raise GuppyError(err)
    # Unify with expected constant to check and maybe infer some variables
    subst = unify(exp_const, const, subst)
    if subst is None:
        raise GuppyError(ConstMismatchError(arg, exp_const, const))
    return subst
|
|
1103
|
+
|
|
1104
|
+
|
|
1105
|
+
def synthesize_call(
    func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context
) -> tuple[list[ast.expr], Type, Inst]:
    """Synthesize the return type of a function call.

    Returns the annotated argument list, the synthesized return type, and an
    instantiation for the quantifiers in the function type.
    """
    assert not func_ty.unsolved_vars
    check_num_args(len(func_ty.inputs), len(args), node, func_ty)

    # Swap the quantified variables for free unification variables and let the
    # argument check infer what they stand for
    open_ty, free_vars = func_ty.unquantified()
    checked_args, solution = type_check_args(args, open_ty, {}, ctx, node)

    # Success implies that the substitution is closed
    assert not any(t.unsolved_vars for t in solution.values())
    inst = []
    for var in free_vars:
        inst.append(solution[var].to_arg())

    # Finally, check that the instantiation respects the linearity requirements
    check_inst(func_ty, inst, node)

    return_ty = open_ty.output.substitute(solution)
    return checked_args, return_ty, inst
|
|
1129
|
+
|
|
1130
|
+
|
|
1131
|
+
def check_call(
    func_ty: FunctionType,
    inputs: list[ast.expr],
    ty: Type,
    node: AstNode,
    ctx: Context,
    kind: str = "expression",
) -> tuple[list[ast.expr], Subst, Inst]:
    """Checks the return type of a function call against a given type.

    Returns an annotated argument list, a substitution for the free variables in the
    expected type, and an instantiation for the quantifiers in the function type.

    Raises a `GuppyTypeError` if the return type does not match `ty` and a
    `GuppyTypeInferenceError` if a free variable of `ty` cannot be inferred.
    `kind` is forwarded to the `TypeMismatchError` diagnostics.
    """
    assert not func_ty.unsolved_vars
    check_num_args(len(func_ty.inputs), len(inputs), node, func_ty)

    # When checking, we can use the information from the expected return type to infer
    # some type arguments. However, this pushes errors inwards. For example, given a
    # function `foo: forall T. T -> T`, the following type mismatch would be reported:
    #
    #       x: int = foo(None)
    #                    ^^^^  Expected argument of type `int`, got `None`
    #
    # But the following error location would be more intuitive for users:
    #
    #       x: int = foo(None)
    #                ^^^^^^^^^  Expected expression of type `int`, got `None`
    #
    # In other words, if we can get away with synthesising the call without the extra
    # information from the expected type, we should do that to improve the error.

    # TODO: The approach below can result in exponential runtime in the worst case.
    #  However the bad case, e.g. `x: int = foo(foo(...foo(?)...))`, shouldn't be common
    #  in practice. Can we do better than that?

    # First, try to synthesize
    res: tuple[Type, Inst] | None = None
    try:
        inputs, synth, inst = synthesize_call(func_ty, inputs, node, ctx)
        res = synth, inst
    except GuppyTypeInferenceError:
        # Only inference failures trigger the fallback below; other errors
        # propagate to the caller
        pass
    if res is not None:
        synth, inst = res
        subst = unify(ty, synth, {})
        if subst is None:
            raise GuppyTypeError(TypeMismatchError(node, ty, synth, kind))
        return inputs, subst, inst

    # If synthesis fails, we try again, this time also using information from the
    # expected return type
    unquantified, free_vars = func_ty.unquantified()
    subst = unify(ty, unquantified.output, {})
    if subst is None:
        raise GuppyTypeError(TypeMismatchError(node, ty, unquantified.output, kind))

    # Try to infer more by checking against the arguments
    inputs, subst = type_check_args(inputs, unquantified, subst, ctx, node)

    # Also make sure we found an instantiation for all free vars in the type we're
    # checking against
    if not set.issubset(ty.unsolved_vars, subst.keys()):
        # Pick one of the variables of `ty` that has no solution. The set is
        # guaranteed to be non-empty by the subset check above.
        unsolved = ty.unsolved_vars.difference(subst.keys()).pop()
        err = TypeMismatchError(node, ty, func_ty.output.substitute(subst), kind)
        err.add_sub_diagnostic(
            TypeMismatchError.CantInferParam(None, unsolved.display_name)
        )
        raise GuppyTypeInferenceError(err)

    # Success implies that the substitution is closed
    assert all(not t.unsolved_vars for t in subst.values())
    inst = [subst[v].to_arg() for v in free_vars]
    # Only keep the entries that solve variables of the expected type
    subst = {v: t for v, t in subst.items() if v in ty.unsolved_vars}

    # Finally, check that the instantiation respects the linearity requirements
    check_inst(func_ty, inst, node)

    return inputs, subst, inst
|
|
1209
|
+
|
|
1210
|
+
|
|
1211
|
+
def check_inst(func_ty: FunctionType, inst: Inst, node: AstNode) -> None:
    """Validate an instantiation of the parameters of a function type.

    Raises a `GuppyTypeError` with a `NonLinearInstantiateError` if a type
    parameter requires a copyable or droppable type and the corresponding type
    argument is not. Every argument is additionally passed to the `check_arg`
    method of its parameter.
    """
    for param, arg in zip(func_ty.params, inst, strict=True):
        is_type_pair = isinstance(param, TypeParam) and isinstance(arg, TypeArg)
        # Linearity problems get a dedicated, more informative error
        if is_type_pair and (
            (param.must_be_copyable and not arg.ty.copyable)
            or (param.must_be_droppable and not arg.ty.droppable)
        ):
            raise GuppyTypeError(
                NonLinearInstantiateError(node, param, func_ty, arg.ty)
            )
        # Everything else is left to the default checking implementation
        param.check_arg(arg, node)
|
|
1229
|
+
|
|
1230
|
+
|
|
1231
|
+
def instantiate_poly(node: ast.expr, ty: FunctionType, inst: Inst) -> ast.expr:
    """Instantiate the quantified type arguments of a function expression.

    With an empty instantiation the node is only annotated with `ty`. Otherwise
    it is wrapped in a `TypeApply` node that is annotated with the instantiated
    type.
    """
    assert len(ty.params) == len(inst)
    if len(inst) == 0:
        return with_type(ty, node)
    applied = TypeApply(value=with_type(ty, node), inst=inst)
    located = with_loc(node, applied)
    return with_type(ty.instantiate(inst), located)
|
|
1238
|
+
|
|
1239
|
+
|
|
1240
|
+
def to_bool(node: ast.expr, node_ty: Type, ctx: Context) -> tuple[ast.expr, Type]:
    """Try to turn an expression into a bool.

    Expressions that already have a bool type are returned unchanged. All
    others are converted by a call to their `__bool__` instance method.
    """
    if not is_bool_type(node_ty):
        synthesizer = ExprSynthesizer(ctx)
        self_input = FuncInput(node_ty, InputFlags.Inout)
        expected = FunctionType([self_input], bool_type())
        return synthesizer.synthesize_instance_func(
            node, [], "__bool__", "truthy", expected, True
        )
    return node, node_ty
|
|
1247
|
+
|
|
1248
|
+
|
|
1249
|
+
def synthesize_comprehension(
    node: AstNode, gens: list[DesugaredGenerator], elt: ast.expr, ctx: Context
) -> tuple[list[DesugaredGenerator], ast.expr, Type]:
    """Synthesise the element type of a list comprehension.

    The generators are checked from left to right, each one in the context
    produced by its predecessor. The element expression is synthesized in the
    context of the last generator.

    Returns the checked generators, the checked element and its type.
    """
    # Without generators, only the element itself needs to be synthesized
    if not gens:
        elt, elt_ty = ExprSynthesizer(ctx).synthesize(elt)
        return gens, elt, elt_ty

    checked = []
    scope = ctx
    for generator in gens:
        # Every generator opens a nested context for everything that follows
        generator, scope = check_generator(generator, scope)
        checked.append(generator)

    elt, elt_ty = ExprSynthesizer(scope).synthesize(elt)
    return checked, elt, elt_ty
|
|
1266
|
+
|
|
1267
|
+
|
|
1268
|
+
def check_generator(
    gen: DesugaredGenerator, ctx: Context
) -> tuple[DesugaredGenerator, Context]:
    """Type checks a single desugared generator of a comprehension.

    The generator is updated in place: its iterator assignment is checked in `ctx`,
    while the iterator, the `next` call, the target and all `if` guards are checked
    in a fresh context nested inside of `ctx`.

    Returns the annotated generator together with that nested context, in which the
    generator variables are bound.
    """
    # Imported locally, presumably to avoid a circular import -- TODO confirm
    from guppylang_internals.checker.stmt_checker import StmtChecker

    # The iterator assignment still belongs to the enclosing scope
    gen.iter_assign = StmtChecker(ctx).visit_Assign(gen.iter_assign)

    # Everything else goes into a nested scope so that the generator variables
    # cannot leak out of the comprehension
    nested_locals: Locals[str, Variable] = Locals({}, parent_scope=ctx.locals)
    inner_ctx = Context(ctx.globals, nested_locals, ctx.generic_params)
    synthesizer = ExprSynthesizer(inner_ctx)
    stmt_checker = StmtChecker(inner_ctx)
    iter_expr, iter_ty = synthesizer.visit(gen.iter)
    gen.iter = with_type(iter_ty, iter_expr)

    # `next_call` produces an `Option[tuple[elt_ty, iter_ty]]`, so the element type
    # is the first component of the tuple inside the option
    gen.next_call, option_ty = synthesizer.synthesize(gen.next_call)
    next_ty = get_element_type(option_ty)
    assert isinstance(next_ty, TupleType)
    elt_ty, _ = next_ty.element_types
    gen.target = stmt_checker._check_assign(gen.target, gen.next_call, elt_ty)

    # Every `if` guard is synthesised and then coerced to a bool
    for idx, guard in enumerate(gen.ifs):
        gen.ifs[idx], guard_ty = synthesizer.synthesize(guard)
        gen.ifs[idx], _ = to_bool(gen.ifs[idx], guard_ty, inner_ctx)

    return gen, inner_ctx
|
|
1302
|
+
|
|
1303
|
+
|
|
1304
|
+
def eval_comptime_expr(node: ComptimeExpr, ctx: Context) -> Any:
    """Runs the Python expression wrapped by a `comptime(...)` expression.

    The expression is unparsed and evaluated with a `DummyEvalDict` as its local
    namespace. The resulting Python value is returned.

    Raises a `GuppyError` if the interpreter is not CPython, if the evaluation
    raises `DummyEvalDict.GuppyVarUsedError`, or if it raises any other exception.
    """
    # Obtaining the Python variables in scope relies on CPython specifics (see
    # `get_py_scope()`), so other implementations are rejected up front.
    if sys.implementation.name != "cpython":
        raise GuppyError(ComptimeExprNotCPythonError(node))

    try:
        result = eval(  # noqa: S307
            ast.unparse(node.value), None, DummyEvalDict(ctx, node.value)
        )
    except DummyEvalDict.GuppyVarUsedError as err:
        not_static = ComptimeExprNotStaticError(err.node or node, err.var)
        raise GuppyError(not_static) from None
    except Exception as err:
        # Skip the top traceback frame since it only points to the `eval` call
        tb = err.__traceback__
        if tb is not None:
            tb = tb.tb_next
        formatted = "".join(traceback.format_exception(type(err), err, tb))
        raise GuppyError(ComptimeExprEvalError(node.value, formatted)) from err
    else:
        return result
|
|
1325
|
+
|
|
1326
|
+
|
|
1327
|
+
def python_value_to_guppy_type(
|
|
1328
|
+
v: Any, node: ast.AST, globals: Globals, type_hint: Type | None = None
|
|
1329
|
+
) -> Type | None:
|
|
1330
|
+
"""Turns a primitive Python value into a Guppy type.
|
|
1331
|
+
|
|
1332
|
+
Accepts an optional `type_hint` for the expected expression type that is used to
|
|
1333
|
+
infer a more precise type (e.g. distinguishing between `int` and `nat`). Note that
|
|
1334
|
+
invalid hints are ignored, i.e. no user error are emitted.
|
|
1335
|
+
|
|
1336
|
+
Returns `None` if the Python value cannot be represented in Guppy.
|
|
1337
|
+
"""
|
|
1338
|
+
match v:
|
|
1339
|
+
case bool():
|
|
1340
|
+
return bool_type()
|
|
1341
|
+
case str():
|
|
1342
|
+
return string_type()
|
|
1343
|
+
# Only resolve `int` to `nat` if the user specifically asked for it
|
|
1344
|
+
case int(n) if type_hint == nat_type() and n >= 0:
|
|
1345
|
+
_int_bounds_check(n, node, signed=False)
|
|
1346
|
+
return nat_type()
|
|
1347
|
+
# Otherwise, default to `int` for consistency with Python
|
|
1348
|
+
case int(n):
|
|
1349
|
+
_int_bounds_check(n, node, signed=True)
|
|
1350
|
+
return int_type()
|
|
1351
|
+
case float():
|
|
1352
|
+
return float_type()
|
|
1353
|
+
case tuple(elts):
|
|
1354
|
+
hints = (
|
|
1355
|
+
type_hint.element_types
|
|
1356
|
+
if isinstance(type_hint, TupleType)
|
|
1357
|
+
else len(elts) * [None]
|
|
1358
|
+
)
|
|
1359
|
+
tys = [
|
|
1360
|
+
python_value_to_guppy_type(elt, node, globals, hint)
|
|
1361
|
+
for elt, hint in zip(elts, hints, strict=False)
|
|
1362
|
+
]
|
|
1363
|
+
if any(ty is None for ty in tys):
|
|
1364
|
+
return None
|
|
1365
|
+
return TupleType(cast(list[Type], tys))
|
|
1366
|
+
case list():
|
|
1367
|
+
return _python_list_to_guppy_type(v, node, globals, type_hint)
|
|
1368
|
+
case _:
|
|
1369
|
+
return None
|
|
1370
|
+
|
|
1371
|
+
|
|
1372
|
+
def _int_bounds_check(value: int, node: AstNode, signed: bool) -> None:
    """Checks that a Python integer fits into a Guppy integer.

    The available bit width is `2 ** NumericType.INT_WIDTH`. With `w` denoting this
    width, signed values must lie in `[-2**(w-1), 2**(w-1) - 1]` and unsigned values
    in `[0, 2**w - 1]`.

    Raises a `GuppyTypeError` wrapping an `IntOverflowError` located at `node` if
    `value` is outside of this range.
    """
    bit_width = 1 << NumericType.INT_WIDTH
    # Signed integers lose one bit of magnitude to the sign
    magnitude_bits = bit_width - 1 if signed else bit_width
    upper = (1 << magnitude_bits) - 1
    lower = -(1 << magnitude_bits) if signed else 0
    if lower <= value <= upper:
        return
    # The last argument tells the error whether `value` fell below the lower bound
    raise GuppyTypeError(IntOverflowError(node, signed, bit_width, value < lower))
|
|
1383
|
+
|
|
1384
|
+
|
|
1385
|
+
def _python_list_to_guppy_type(
    vs: list[Any], node: ast.AST, globals: Globals, type_hint: Type | None
) -> OpaqueType | None:
    """Computes the Guppy frozenarray type corresponding to a Python list.

    If `type_hint` is a frozenarray type, its element type is passed on as the hint
    for every list element. An empty list is typed as a frozenarray of length 0 whose
    element type is a fresh `ExistentialTypeVar`.

    Returns `None` if some element is not representable in Guppy.

    Raises a `GuppyError` with a `ComptimeExprIncoherentListError` if the types of the
    elements cannot be unified with each other.
    """
    if not vs:
        return frozenarray_type(ExistentialTypeVar.fresh("T", True, True), 0)

    # Only a frozenarray hint tells us something about the elements
    elt_hint: Type | None = None
    if type_hint and is_frozenarray_type(type_hint):
        elt_hint = get_element_type(type_hint)

    # The first element provides the initial candidate for the element type
    first, others = vs[0], vs[1:]
    el_ty = python_value_to_guppy_type(first, node, globals, elt_hint)
    if el_ty is None:
        return None

    # Each remaining element must unify with the candidate, refining it as we go
    for other in others:
        other_ty = python_value_to_guppy_type(other, node, globals, elt_hint)
        if other_ty is None:
            return None
        subst = unify(other_ty, el_ty, {})
        if subst is None:
            raise GuppyError(ComptimeExprIncoherentListError(node))
        el_ty = el_ty.substitute(subst)

    return frozenarray_type(el_ty, len(vs))
|