guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,1413 @@
1
+ """Type checking and synthesizing code for expressions.
2
+
3
+ Operates on expressions in a basic block after CFG construction. In particular, we
4
+ assume that expressions that involve control flow (i.e. short-circuiting and ternary
5
+ expressions) have been removed during CFG construction.
6
+
7
+ Furthermore, we assume that assignment expressions with the walrus operator := have
8
+ been turned into regular assignments and are no longer present. As a result, expressions
9
+ are assumed to be side effect free, in the sense that they do not modify the variables
10
+ available in the type checking context.
11
+
12
+ We may alter/desugar AST nodes during type checking. In particular, we turn `ast.Name`
13
+ nodes into either `LocalName` or `GlobalName` nodes and `ast.Call` nodes are turned into
14
+ `LocalCall` or `GlobalCall` nodes. Furthermore, all nodes in the resulting AST are
15
+ annotated with their type.
16
+
17
+ Expressions can be checked against a given type by the `ExprChecker`, raising a type
18
+ error if the expression doesn't have the expected type. Checking is used for annotated
19
+ assignments, return values, and function arguments. Alternatively, the `ExprSynthesizer`
20
+ can be used to infer a type for an expression.
21
+ """
22
+
23
+ import ast
24
+ import sys
25
+ import traceback
26
+ from contextlib import suppress
27
+ from dataclasses import replace
28
+ from types import ModuleType
29
+ from typing import TYPE_CHECKING, Any, NoReturn, cast
30
+
31
+ from typing_extensions import assert_never
32
+
33
+ from guppylang_internals.ast_util import (
34
+ AstNode,
35
+ AstVisitor,
36
+ breaks_in_loop,
37
+ get_type_opt,
38
+ return_nodes_in_ast,
39
+ with_loc,
40
+ with_type,
41
+ )
42
+ from guppylang_internals.cfg.builder import is_tmp_var, tmp_vars
43
+ from guppylang_internals.checker.core import (
44
+ Context,
45
+ DummyEvalDict,
46
+ FieldAccess,
47
+ Globals,
48
+ Locals,
49
+ Place,
50
+ PythonObject,
51
+ SetitemCall,
52
+ SubscriptAccess,
53
+ TupleAccess,
54
+ Variable,
55
+ )
56
+ from guppylang_internals.checker.errors.comptime_errors import (
57
+ ComptimeExprEvalError,
58
+ ComptimeExprIncoherentListError,
59
+ ComptimeExprNotCPythonError,
60
+ ComptimeExprNotStaticError,
61
+ ComptimeUnknownError,
62
+ IllegalComptimeExpressionError,
63
+ )
64
+ from guppylang_internals.checker.errors.generic import ExpectedError, UnsupportedError
65
+ from guppylang_internals.checker.errors.linearity import NonDroppableForBreakError
66
+ from guppylang_internals.checker.errors.type_errors import (
67
+ AttributeNotFoundError,
68
+ BadProtocolError,
69
+ BinaryOperatorNotDefinedError,
70
+ ConstMismatchError,
71
+ IllegalConstant,
72
+ IntOverflowError,
73
+ ModuleMemberNotFoundError,
74
+ NonLinearInstantiateError,
75
+ NotCallableError,
76
+ TupleIndexOutOfBoundsError,
77
+ TypeApplyNotGenericError,
78
+ TypeInferenceError,
79
+ TypeMismatchError,
80
+ UnaryOperatorNotDefinedError,
81
+ WrongNumberOfArgsError,
82
+ )
83
+ from guppylang_internals.definition.common import Definition
84
+ from guppylang_internals.definition.ty import TypeDef
85
+ from guppylang_internals.definition.value import CallableDef, ValueDef
86
+ from guppylang_internals.error import (
87
+ GuppyError,
88
+ GuppyTypeError,
89
+ GuppyTypeInferenceError,
90
+ InternalGuppyError,
91
+ )
92
+ from guppylang_internals.experimental import (
93
+ check_function_tensors_enabled,
94
+ check_lists_enabled,
95
+ )
96
+ from guppylang_internals.nodes import (
97
+ ComptimeExpr,
98
+ DesugaredGenerator,
99
+ DesugaredGeneratorExpr,
100
+ DesugaredListComp,
101
+ FieldAccessAndDrop,
102
+ GenericParamValue,
103
+ GlobalName,
104
+ IterEnd,
105
+ IterHasNext,
106
+ IterNext,
107
+ LocalCall,
108
+ MakeIter,
109
+ PartialApply,
110
+ PlaceNode,
111
+ SubscriptAccessAndDrop,
112
+ TensorCall,
113
+ TupleAccessAndDrop,
114
+ TypeApply,
115
+ )
116
+ from guppylang_internals.span import Span, to_span
117
+ from guppylang_internals.tys.arg import TypeArg
118
+ from guppylang_internals.tys.builtin import (
119
+ bool_type,
120
+ float_type,
121
+ frozenarray_type,
122
+ get_element_type,
123
+ int_type,
124
+ is_bool_type,
125
+ is_frozenarray_type,
126
+ is_list_type,
127
+ is_sized_iter_type,
128
+ list_type,
129
+ nat_type,
130
+ option_type,
131
+ string_type,
132
+ )
133
+ from guppylang_internals.tys.const import Const, ConstValue
134
+ from guppylang_internals.tys.param import ConstParam, TypeParam
135
+ from guppylang_internals.tys.parsing import arg_from_ast
136
+ from guppylang_internals.tys.subst import Inst, Subst
137
+ from guppylang_internals.tys.ty import (
138
+ ExistentialTypeVar,
139
+ FuncInput,
140
+ FunctionType,
141
+ InputFlags,
142
+ NoneType,
143
+ NumericType,
144
+ OpaqueType,
145
+ StructType,
146
+ TupleType,
147
+ Type,
148
+ TypeBase,
149
+ function_tensor_signature,
150
+ parse_function_tensor,
151
+ unify,
152
+ )
153
+
154
+ if TYPE_CHECKING:
155
+ from guppylang_internals.diagnostic import SubDiagnostic
156
+
157
+ # Mapping from unary AST op to dunder method and display name
158
+ unary_table: dict[type[ast.unaryop], tuple[str, str]] = {
159
+ ast.UAdd: ("__pos__", "+"),
160
+ ast.USub: ("__neg__", "-"),
161
+ ast.Invert: ("__invert__", "~"),
162
+ } # fmt: skip
163
+
164
+ # Mapping from binary AST op to left dunder method, right dunder method and display name
165
+ AstOp = ast.operator | ast.cmpop
166
+ binary_table: dict[type[AstOp], tuple[str, str, str]] = {
167
+ ast.Add: ("__add__", "__radd__", "+"),
168
+ ast.Sub: ("__sub__", "__rsub__", "-"),
169
+ ast.Mult: ("__mul__", "__rmul__", "*"),
170
+ ast.Div: ("__truediv__", "__rtruediv__", "/"),
171
+ ast.FloorDiv: ("__floordiv__", "__rfloordiv__", "//"),
172
+ ast.Mod: ("__mod__", "__rmod__", "%"),
173
+ ast.Pow: ("__pow__", "__rpow__", "**"),
174
+ ast.LShift: ("__lshift__", "__rlshift__", "<<"),
175
+ ast.RShift: ("__rshift__", "__rrshift__", ">>"),
176
+ ast.BitOr: ("__or__", "__ror__", "|"),
177
+ ast.BitXor: ("__xor__", "__rxor__", "^"),
178
+ ast.BitAnd: ("__and__", "__rand__", "&"),
179
+ ast.MatMult: ("__matmul__", "__rmatmul__", "@"),
180
+ ast.Eq: ("__eq__", "__eq__", "=="),
181
+ ast.NotEq: ("__ne__", "__ne__", "!="),
182
+ ast.Lt: ("__lt__", "__gt__", "<"),
183
+ ast.LtE: ("__le__", "__ge__", "<="),
184
+ ast.Gt: ("__gt__", "__lt__", ">"),
185
+ ast.GtE: ("__ge__", "__le__", ">="),
186
+ } # fmt: skip
187
+
188
+
189
class ExprChecker(AstVisitor[tuple[ast.expr, Subst]]):
    """Checks an expression against a type and produces a new type-annotated AST.

    The type may contain free variables that the checker will try to solve. Note that
    the checker will fail if some free variables cannot be inferred.
    """

    ctx: Context

    # Name for the kind of term we are currently checking against (used in errors).
    # For example, "argument", "return value", or in general "expression".
    _kind: str

    def __init__(self, ctx: Context) -> None:
        self.ctx = ctx
        self._kind = "expression"

    def _fail(
        self,
        expected: Type,
        actual: ast.expr | Type,
        loc: AstNode | None = None,
    ) -> NoReturn:
        """Raises a type error indicating that the type doesn't match.

        If `actual` is an expression rather than a type, its type is synthesized
        (free variables allowed) and, unless `loc` is given, the expression itself is
        used as the error location.

        Raises:
            GuppyTypeError: Always, describing the mismatch.
            InternalGuppyError: If no error location could be determined.
        """
        if not isinstance(actual, TypeBase):
            loc = loc or actual
            _, actual = self._synthesize(actual, allow_free_vars=True)
        if loc is None:
            raise InternalGuppyError("Failure location is required")
        raise GuppyTypeError(TypeMismatchError(loc, expected, actual))

    def check(
        self, expr: ast.expr, ty: Type, kind: str = "expression"
    ) -> tuple[ast.expr, Subst]:
        """Checks an expression against a type.

        The type may have free type variables which will try to be resolved. Returns
        a new desugared expression with type annotations and a substitution with the
        resolved type variables.

        Args:
            expr: The expression to check.
            ty: The type the expression is expected to have.
            kind: Name of the kind of term being checked, used in error messages.
        """
        # If we already have a type for the expression, we just have to match it against
        # the target
        if actual := get_type_opt(expr):
            expr, subst, inst = check_type_against(actual, ty, expr, self.ctx, kind)
            if inst:
                # Store the instantiation in the `inst` field, exactly as
                # `generic_visit` does (previously a stray `tys=` keyword was used).
                expr = with_loc(expr, TypeApply(value=expr, inst=inst))
            return with_type(ty.substitute(subst), expr), subst

        # When checking against a variable, we have to synthesize
        if isinstance(ty, ExistentialTypeVar):
            expr, syn_ty = self._synthesize(expr, allow_free_vars=False)
            return with_type(syn_ty, expr), {ty: syn_ty}

        # Otherwise, invoke the visitor. The kind is restored even if checking raises,
        # so a checker that is reused afterwards reports the correct kind.
        old_kind = self._kind
        self._kind = kind or self._kind
        try:
            expr, subst = self.visit(expr, ty)
        finally:
            self._kind = old_kind
        return with_type(ty.substitute(subst), expr), subst

    def _synthesize(
        self, node: ast.expr, allow_free_vars: bool
    ) -> tuple[ast.expr, Type]:
        """Invokes the type synthesiser on `node` using this checker's context."""
        return ExprSynthesizer(self.ctx).synthesize(node, allow_free_vars)

    def visit_Constant(self, node: ast.Constant, ty: Type) -> tuple[ast.expr, Subst]:
        """Checks a literal constant, using `ty` as a hint for the literal's type."""
        act = python_value_to_guppy_type(node.value, node, self.ctx.globals, ty)
        if act is None:
            raise GuppyError(IllegalConstant(node, type(node.value)))
        node, subst, inst = check_type_against(act, ty, node, self.ctx, self._kind)
        assert inst == [], "Const values are not generic"
        return node, subst

    def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]:
        """Checks a tuple literal element-wise against a tuple type of equal length."""
        if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts):
            return self._fail(ty, node)
        subst: Subst = {}
        for i, el in enumerate(node.elts):
            # Variables solved by earlier elements are applied to later element types
            node.elts[i], s = self.check(el, ty.element_types[i].substitute(subst))
            subst |= s
        return node, subst

    def visit_List(self, node: ast.List, ty: Type) -> tuple[ast.expr, Subst]:
        """Checks a list literal by checking each element against the element type."""
        check_lists_enabled(node)
        if not is_list_type(ty):
            return self._fail(ty, node)
        el_ty = get_element_type(ty)
        subst: Subst = {}
        for i, el in enumerate(node.elts):
            # Variables solved by earlier elements are applied to later elements
            node.elts[i], s = self.check(el, el_ty.substitute(subst))
            subst |= s
        return node, subst

    def visit_DesugaredListComp(
        self, node: DesugaredListComp, ty: Type
    ) -> tuple[ast.expr, Subst]:
        """Checks a list comprehension by synthesizing its element type and unifying
        it with the element type of the expected list type."""
        if not is_list_type(ty):
            return self._fail(ty, node)
        node.generators, node.elt, elt_ty = synthesize_comprehension(
            node, node.generators, node.elt, self.ctx
        )
        subst = unify(get_element_type(ty), elt_ty, {})
        if subst is None:
            actual = list_type(elt_ty)
            return self._fail(ty, actual, node)
        return node, subst

    def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]:
        """Checks a call expression whose result is expected to have type `ty`.

        Handles direct calls of global callable definitions, calls of partially
        applied functions, higher-order function values, function tensors, and
        objects implementing `__call__`.

        Raises:
            GuppyError: If keyword arguments or polymorphic function tensors are used.
            GuppyTypeError: If the callee is not callable.
        """
        if len(node.keywords) > 0:
            raise GuppyError(UnsupportedError(node.keywords[0], "Keyword arguments"))
        node.func, func_ty = self._synthesize(node.func, allow_free_vars=False)

        # First handle direct calls of user-defined functions and extension functions
        if isinstance(node.func, GlobalName):
            defn = self.ctx.globals[node.func.def_id]
            if isinstance(defn, CallableDef):
                return defn.check_call(node.args, ty, node, self.ctx)

        # When calling a `PartialApply` node, we just move the args into this call
        if isinstance(node.func, PartialApply):
            node.args = [*node.func.args, *node.args]
            node.func = node.func.func
            return self.visit_Call(node, ty)

        # Otherwise, it must be a function as a higher-order value - something
        # whose type is either a FunctionType or a Tuple of FunctionTypes
        if isinstance(func_ty, FunctionType):
            # In checking mode `check_call` yields a substitution (not a return type)
            args, subst, inst = check_call(func_ty, node.args, ty, node, self.ctx)
            check_inst(func_ty, inst, node)
            node.func = instantiate_poly(node.func, func_ty, inst)
            return with_loc(node, LocalCall(func=node.func, args=args)), subst

        if isinstance(func_ty, TupleType) and (
            function_elements := parse_function_tensor(func_ty)
        ):
            check_function_tensors_enabled(node.func)
            if any(f.parametrized for f in function_elements):
                raise GuppyError(
                    UnsupportedError(node.func, "Polymorphic function tensors")
                )

            tensor_ty = function_tensor_signature(function_elements)

            processed_args, subst, inst = check_call(
                tensor_ty, node.args, ty, node, self.ctx
            )
            assert len(inst) == 0
            return with_loc(
                node,
                TensorCall(func=node.func, args=processed_args, tensor_ty=tensor_ty),
            ), subst

        if callee := self.ctx.globals.get_instance_func(func_ty, "__call__"):
            return callee.check_call(node.args, ty, node, self.ctx)
        raise GuppyTypeError(NotCallableError(node.func, func_ty))

    def visit_ComptimeExpr(
        self, node: ComptimeExpr, ty: Type
    ) -> tuple[ast.expr, Subst]:
        """Evaluates a comptime expression and checks the resulting Python value.

        The value is turned into a constant node annotated with the inferred type.

        Raises:
            GuppyError: If the Python value cannot be represented as a Guppy value.
            GuppyTypeError: If the value's type does not unify with `ty`.
        """
        python_val = eval_comptime_expr(node, self.ctx)
        if act := python_value_to_guppy_type(
            python_val, node.value, self.ctx.globals, ty
        ):
            subst = unify(ty, act, {})
            if subst is None:
                return self._fail(ty, act, node)
            act = act.substitute(subst)
            # Only report solutions for variables that are unsolved in the target type
            subst = {x: s for x, s in subst.items() if x in ty.unsolved_vars}
            return with_type(act, with_loc(node, ast.Constant(value=python_val))), subst

        raise GuppyError(IllegalComptimeExpressionError(node.value, type(python_val)))

    def generic_visit(self, node: ast.expr, ty: Type) -> tuple[ast.expr, Subst]:
        """Fallback: synthesizes a type for `node` and checks it against `ty`."""
        # Try to synthesize and then check if we can unify it with the given type
        node, synth = self._synthesize(node, allow_free_vars=False)
        node, subst, inst = check_type_against(synth, ty, node, self.ctx, self._kind)

        # Apply instantiation of quantified type variables
        if inst:
            node = with_loc(node, TypeApply(value=node, inst=inst))

        return node, subst
373
+
374
+
375
+ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
376
+ ctx: Context
377
+
378
+ def __init__(self, ctx: Context) -> None:
379
+ self.ctx = ctx
380
+
381
+ def synthesize(
382
+ self, node: ast.expr, allow_free_vars: bool = False
383
+ ) -> tuple[ast.expr, Type]:
384
+ """Tries to synthesise a type for the given expression.
385
+
386
+ Also returns a new desugared expression with type annotations.
387
+ """
388
+ if ty := get_type_opt(node):
389
+ return node, ty
390
+ node, ty = self.visit(node)
391
+ if ty.unsolved_vars and not allow_free_vars:
392
+ raise GuppyError(TypeInferenceError(node, ty))
393
+ return with_type(ty, node), ty
394
+
395
+ def _check(
396
+ self, expr: ast.expr, ty: Type, kind: str = "expression"
397
+ ) -> tuple[ast.expr, Subst]:
398
+ """Checks an expression against a given type"""
399
+ return ExprChecker(self.ctx).check(expr, ty, kind)
400
+
401
+ def visit_Constant(self, node: ast.Constant) -> tuple[ast.expr, Type]:
402
+ ty = python_value_to_guppy_type(node.value, node, self.ctx.globals)
403
+ if ty is None:
404
+ raise GuppyError(IllegalConstant(node, type(node.value)))
405
+ return node, ty
406
+
407
+ def visit_Name(self, node: ast.Name) -> tuple[ast.expr, Type]:
408
+ x = node.id
409
+ if x in self.ctx.locals:
410
+ var = self.ctx.locals[x]
411
+ return with_loc(node, PlaceNode(place=var)), var.ty
412
+ elif x in self.ctx.generic_params:
413
+ param = self.ctx.generic_params[x]
414
+ match param:
415
+ case ConstParam() as param:
416
+ ast_node = with_loc(node, GenericParamValue(id=x, param=param))
417
+ return ast_node, param.ty
418
+ case TypeParam() as param:
419
+ raise GuppyError(
420
+ ExpectedError(node, "a value", got=f"type `{param.name}`")
421
+ )
422
+ case _:
423
+ return assert_never(param)
424
+ elif x in self.ctx.globals:
425
+ match self.ctx.globals[x]:
426
+ case Definition() as defn:
427
+ return self._check_global(defn, x, node)
428
+ case PythonObject():
429
+ from guppylang_internals.checker.cfg_checker import (
430
+ VarNotDefinedError,
431
+ )
432
+
433
+ # TODO: Emit a hint that the variable could be accessed through a
434
+ # comptime expression
435
+ raise GuppyError(VarNotDefinedError(node, node.id))
436
+
437
+ raise InternalGuppyError(
438
+ f"Variable `{x}` is not defined in `TypeSynthesiser`. This should have "
439
+ "been caught by program analysis!"
440
+ )
441
+
442
+ def _check_global(
443
+ self, defn: Definition, name: str, node: ast.expr
444
+ ) -> tuple[ast.expr, Type]:
445
+ """Checks a global definition in an expression position."""
446
+ match defn:
447
+ case ValueDef() as defn:
448
+ return with_loc(node, GlobalName(id=name, def_id=defn.id)), defn.ty
449
+ # For types, we return their `__new__` constructor
450
+ case TypeDef() as defn if constr := self.ctx.globals.get_instance_func(
451
+ defn, "__new__"
452
+ ):
453
+ return with_loc(node, GlobalName(id=name, def_id=constr.id)), constr.ty
454
+ case defn:
455
+ raise GuppyError(
456
+ ExpectedError(node, "a value", got=f"{defn.description} `{name}`")
457
+ )
458
+
459
+ def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]:
460
+ from guppylang.defs import GuppyDefinition
461
+ from guppylang_internals.engine import ENGINE
462
+
463
+ # A `value.attr` attribute access. Unfortunately, the `attr` is just a string,
464
+ # not an AST node, so we have to compute its span by hand. This is fine since
465
+ # linebreaks are not allowed in the identifier following the `.`
466
+ span = to_span(node)
467
+ attr_span = Span(span.end.shift_left(len(node.attr)), span.end)
468
+ if module := self._is_python_module(node.value):
469
+ if node.attr in module.__dict__:
470
+ val = module.__dict__[node.attr]
471
+ if isinstance(val, GuppyDefinition):
472
+ defn = ENGINE.get_parsed(val.id)
473
+ qual_name = f"{module.__name__}.{defn.name}"
474
+ return self._check_global(defn, qual_name, node)
475
+ raise GuppyError(
476
+ ModuleMemberNotFoundError(attr_span, module.__name__, node.attr)
477
+ )
478
+ node.value, ty = self.synthesize(node.value)
479
+ if isinstance(ty, StructType) and node.attr in ty.field_dict:
480
+ field = ty.field_dict[node.attr]
481
+ expr: ast.expr
482
+ if isinstance(node.value, PlaceNode):
483
+ # Field access on a place is itself a place
484
+ expr = PlaceNode(place=FieldAccess(node.value.place, field, None))
485
+ else:
486
+ # If the struct is not in a place, then there is no way to address the
487
+ # other fields after this one has been projected (e.g. `f().a` makes
488
+ # you lose access to all fields besides `a`).
489
+ expr = FieldAccessAndDrop(value=node.value, struct_ty=ty, field=field)
490
+ return with_loc(node, expr), field.ty
491
+ elif func := self.ctx.globals.get_instance_func(ty, node.attr):
492
+ name = with_type(
493
+ func.ty, with_loc(node, GlobalName(id=func.name, def_id=func.id))
494
+ )
495
+ # Make a closure by partially applying the `self` argument
496
+ # TODO: Try to infer some type args based on `self`
497
+ result_ty = FunctionType(
498
+ func.ty.inputs[1:],
499
+ func.ty.output,
500
+ func.ty.input_names[1:] if func.ty.input_names else None,
501
+ func.ty.params,
502
+ )
503
+ return with_loc(node, PartialApply(func=name, args=[node.value])), result_ty
504
+ raise GuppyTypeError(AttributeNotFoundError(attr_span, ty, node.attr))
505
+
506
+ def _is_python_module(self, node: ast.expr) -> ModuleType | None:
507
+ """Checks whether an AST node corresponds to a Python module in scope."""
508
+ if isinstance(node, ast.Name):
509
+ x = node.id
510
+ globals = self.ctx.globals
511
+ if x in globals.f_locals or x in globals.f_globals:
512
+ val = (
513
+ globals.f_locals[x]
514
+ if x in globals.f_locals
515
+ else globals.f_globals[x]
516
+ )
517
+ if isinstance(val, ModuleType):
518
+ return val
519
+ return None
520
+
521
+ def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, Type]:
522
+ elems = [self.synthesize(elem) for elem in node.elts]
523
+
524
+ node.elts = [n for n, _ in elems]
525
+ return node, TupleType([ty for _, ty in elems])
526
+
527
+ def visit_List(self, node: ast.List) -> tuple[ast.expr, Type]:
528
+ check_lists_enabled(node)
529
+ if len(node.elts) == 0:
530
+ unsolved_ty = list_type(ExistentialTypeVar.fresh("T", True, True))
531
+ raise GuppyTypeInferenceError(TypeInferenceError(node, unsolved_ty))
532
+ node.elts[0], el_ty = self.synthesize(node.elts[0])
533
+ node.elts[1:] = [self._check(el, el_ty)[0] for el in node.elts[1:]]
534
+ return node, list_type(el_ty)
535
+
536
+ def visit_DesugaredListComp(self, node: DesugaredListComp) -> tuple[ast.expr, Type]:
537
+ node.generators, node.elt, elt_ty = synthesize_comprehension(
538
+ node, node.generators, node.elt, self.ctx
539
+ )
540
+ result_ty = list_type(elt_ty)
541
+ return node, result_ty
542
+
543
+ def visit_DesugaredGeneratorExpr(
544
+ self, node: DesugaredGeneratorExpr
545
+ ) -> tuple[ast.expr, Type]:
546
+ # This is a generator in an arbitrary expression position. We don't support
547
+ # generators as first-class value yet, so we always error out here. Special
548
+ # cases where generators are allowed need to explicitly check for them (e.g. see
549
+ # the handling of array comprehensions in the compiler for the `array` function)
550
+ raise GuppyError(UnsupportedError(node, "Generator expressions"))
551
+
552
+ def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, Type]:
553
+ # We need to synthesise the argument type, so we can look up dunder methods
554
+ node.operand, op_ty = self.synthesize(node.operand)
555
+
556
+ # Special case for the `not` operation since it is not implemented via a dunder
557
+ # method or control-flow
558
+ if isinstance(node.op, ast.Not):
559
+ node.operand, bool_ty = to_bool(node.operand, op_ty, self.ctx)
560
+ return node, bool_ty
561
+
562
+ # Check all other unary expressions by calling out to instance dunder methods
563
+ op, display_name = unary_table[node.op.__class__]
564
+ func = self.ctx.globals.get_instance_func(op_ty, op)
565
+ if func is None:
566
+ raise GuppyTypeError(
567
+ UnaryOperatorNotDefinedError(node.operand, op_ty, display_name)
568
+ )
569
+ return func.synthesize_call([node.operand], node, self.ctx)
570
+
571
+ def _synthesize_binary(
572
+ self, left_expr: ast.expr, right_expr: ast.expr, op: AstOp, node: ast.expr
573
+ ) -> tuple[ast.expr, Type]:
574
+ """Helper method to compile binary operators by calling out to dunder methods.
575
+
576
+ For example, first try calling `__add__` on the left operand. If that fails, try
577
+ `__radd__` on the right operand.
578
+ """
579
+ if op.__class__ not in binary_table:
580
+ raise GuppyTypeError(UnsupportedError(node, "Operator", singular=True))
581
+ lop, rop, display_name = binary_table[op.__class__]
582
+ left_expr, left_ty = self.synthesize(left_expr)
583
+ right_expr, right_ty = self.synthesize(right_expr)
584
+
585
+ if func := self.ctx.globals.get_instance_func(left_ty, lop):
586
+ with suppress(GuppyError):
587
+ return func.synthesize_call([left_expr, right_expr], node, self.ctx)
588
+
589
+ if func := self.ctx.globals.get_instance_func(right_ty, rop):
590
+ with suppress(GuppyError):
591
+ return func.synthesize_call([right_expr, left_expr], node, self.ctx)
592
+
593
+ raise GuppyTypeError(
594
+ # TODO: Is there a way to get the span of the operator?
595
+ BinaryOperatorNotDefinedError(node, left_ty, right_ty, display_name)
596
+ )
597
+
598
+ def synthesize_instance_func(
599
+ self,
600
+ node: ast.expr,
601
+ args: list[ast.expr],
602
+ func_name: str,
603
+ description: str,
604
+ exp_sig: FunctionType | None = None,
605
+ give_reason: bool = False,
606
+ ) -> tuple[ast.expr, Type]:
607
+ """Helper method for expressions that are implemented via instance methods.
608
+
609
+ Raises a `GuppyTypeError` if the given instance method is not defined. The error
610
+ message can be customised via the `description` string, and a reason naming the missing
611
+ method is added when `give_reason` is set and an expected signature is given.
612
+
613
+ Optionally, the signature of the instance function can also be checked against a
614
+ given expected signature.
615
+ """
616
+ node, ty = self.synthesize(node)
617
+ func = self.ctx.globals.get_instance_func(ty, func_name)
618
+ if func is None:
619
+ err = BadProtocolError(node, ty, description)
620
+ if give_reason and exp_sig is not None:
621
+ err.add_sub_diagnostic(
622
+ BadProtocolError.MethodMissing(None, func_name, exp_sig)
623
+ )
624
+ raise GuppyTypeError(err)
625
+ if exp_sig and unify(exp_sig, func.ty.unquantified()[0], {}) is None:
626
+ err = BadProtocolError(node, ty, description)
627
+ err.add_sub_diagnostic(
628
+ BadProtocolError.BadSignature(None, ty, func_name, exp_sig, func.ty)
629
+ )
630
+ raise GuppyError(err)
631
+ return func.synthesize_call([node, *args], node, self.ctx)
632
+
633
+ def visit_BinOp(self, node: ast.BinOp) -> tuple[ast.expr, Type]:
634
+ return self._synthesize_binary(node.left, node.right, node.op, node)
635
+
636
+ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, Type]:
637
+ if len(node.comparators) != 1 or len(node.ops) != 1:
638
+ raise InternalGuppyError(
639
+ "BB contains chained comparison. Should have been removed during CFG "
640
+ "construction."
641
+ )
642
+ left_expr, [op], [right_expr] = node.left, node.ops, node.comparators
643
+ return self._synthesize_binary(left_expr, right_expr, op, node)
644
+
645
def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, Type]:
    """Synthesises the type of a subscript expression `value[item]`.

    Three cases are handled:
    * subscripting a function value is an explicit type application `f[T1, ...]`,
    * subscripting a tuple requires an integer literal index that is in bounds,
    * everything else is checked as a call to the `__getitem__` instance function.

    Subscripting a place yields a place again; otherwise the subscript is wrapped in
    a node that drops the remaining parts of the value.
    """
    node.value, value_ty = self.synthesize(node.value)

    # `f[T1, T2, ...]` on a function is a type application
    if isinstance(value_ty, FunctionType):
        type_args = check_type_apply(value_ty, node, self.ctx)
        return (
            instantiate_poly(node.value, value_ty, type_args),
            value_ty.instantiate(type_args),
        )

    item_expr, item_ty = self.synthesize(node.slice)

    # Tuples: the index must be known statically to determine the element type
    if isinstance(value_ty, TupleType):
        if not (
            isinstance(item_expr, ast.Constant) and isinstance(item_expr.value, int)
        ):
            raise GuppyTypeError(ExpectedError(item_expr, "an integer literal"))
        idx = item_expr.value
        num_elems = len(value_ty.element_types)
        if not 0 <= idx < num_elems:
            raise GuppyError(TupleIndexOutOfBoundsError(item_expr, idx, num_elems))
        elem_ty = value_ty.element_types[idx]
        tuple_expr: ast.expr
        if isinstance(node.value, PlaceNode):
            tuple_expr = PlaceNode(
                place=TupleAccess(node.value.place, elem_ty, idx, None)
            )
        else:
            tuple_expr = TupleAccessAndDrop(node.value, value_ty, idx)
        return with_loc(node, tuple_expr), elem_ty

    # Regular `__getitem__` subscript. The item is bound to a fresh temporary name
    # so that a later `__setitem__` call can refer to it again.
    item_var = Variable(next(tmp_vars), item_ty, item_expr)
    item_node = with_type(item_ty, with_loc(item_expr, PlaceNode(place=item_var)))
    getitem_sig = FunctionType(
        [
            FuncInput(value_ty, InputFlags.Inout),
            FuncInput(
                ExistentialTypeVar.fresh("Key", True, True), InputFlags.NoFlags
            ),
        ],
        ExistentialTypeVar.fresh("Val", True, True),
    )
    getitem_expr, result_ty = self.synthesize_instance_func(
        node.value, [item_node], "__getitem__", "subscriptable", getitem_sig
    )

    result: ast.expr
    if isinstance(node.value, PlaceNode):
        # A subscript of a place is itself a place
        result = PlaceNode(
            place=SubscriptAccess(
                node.value.place, item_var, result_ty, item_expr, getitem_expr
            )
        )
    else:
        # A non-place value (e.g. `f()[0]`) offers no way to reach the other
        # elements after this one is projected out, so they are lost
        result = SubscriptAccessAndDrop(
            item=item_var,
            item_expr=item_expr,
            getitem_expr=getitem_expr,
            original_expr=node,
        )
    return with_loc(node, result), result_ty
711
+
712
def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]:
    """Synthesises the type of a call expression.

    The callee decides how the call is checked:
    * global callable definitions synthesise the call themselves,
    * a `PartialApply` callee has its bound arguments prepended and is re-visited,
    * a function-typed value becomes a `LocalCall`,
    * a tuple of functions (a function tensor) becomes a `TensorCall`,
    * any other type must provide a `__call__` instance function.

    Keyword arguments are rejected as unsupported.
    """
    if node.keywords:
        raise GuppyError(UnsupportedError(node.keywords[0], "Keyword arguments"))
    node.func, callee_ty = self.synthesize(node.func)
    callee = node.func

    # Direct calls of user-defined functions and extension functions
    if isinstance(callee, GlobalName):
        defn = self.ctx.globals[callee.def_id]
        if isinstance(defn, CallableDef):
            return defn.synthesize_call(node.args, node, self.ctx)

    # Partial applications: move the already bound arguments into this call
    if isinstance(callee, PartialApply):
        node.args = [*callee.args, *node.args]
        node.func = callee.func
        return self.visit_Call(node)

    # Higher-order function value
    if isinstance(callee_ty, FunctionType):
        args, return_ty, inst = synthesize_call(
            callee_ty, node.args, node, self.ctx
        )
        node.func = instantiate_poly(node.func, callee_ty, inst)
        return with_loc(node, LocalCall(func=node.func, args=args)), return_ty

    # Function tensor, i.e. a tuple of functions called in parallel
    tensor_elems = (
        parse_function_tensor(callee_ty) if isinstance(callee_ty, TupleType) else None
    )
    if tensor_elems:
        check_function_tensors_enabled(node.func)
        if any(elem.parametrized for elem in tensor_elems):
            raise GuppyError(
                UnsupportedError(node.func, "Polymorphic function tensors")
            )
        tensor_ty = function_tensor_signature(tensor_elems)
        args, return_ty, inst = synthesize_call(
            tensor_ty, node.args, node, self.ctx
        )
        assert len(inst) == 0
        tensor_call = TensorCall(func=node.func, args=args, tensor_ty=tensor_ty)
        return with_loc(node, tensor_call), return_ty

    # Finally, fall back to a `__call__` instance function
    call_impl = self.ctx.globals.get_instance_func(callee_ty, "__call__")
    if call_impl:
        return call_impl.synthesize_call(node.args, node, self.ctx)
    raise GuppyTypeError(NotCallableError(node.func, callee_ty))
757
+
758
def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]:
    """Synthesises a `MakeIter` node by calling `__iter__` on the wrapped value.

    A non-copyable value is passed to `__iter__` as an owned input. If the result is
    a sized iterator and `node.unwrap_size_hint` is set, `unwrap_iter` is called on
    it as well. For iterators created by a `for` loop with a non-droppable iterator
    type, `break` and `return` inside the loop are rejected here with a dedicated
    error.
    """
    node.value, iterable_ty = self.synthesize(node.value)
    input_flags = InputFlags.NoFlags if iterable_ty.copyable else InputFlags.Owned
    iter_sig = FunctionType(
        [FuncInput(iterable_ty, input_flags)],
        ExistentialTypeVar.fresh("Iter", True, True),
    )
    iter_expr, iter_ty = self.synthesize_instance_func(
        node.value, [], "__iter__", "iterable", iter_sig, True
    )
    if is_sized_iter_type(iter_ty) and node.unwrap_size_hint:
        iter_expr, iter_ty = self.synthesize_instance_func(
            iter_expr, [], "unwrap_iter", ""
        )

    # The extra check below only concerns non-droppable iterators of `for` loops
    if iter_ty.droppable or not isinstance(node.origin_node, ast.For):
        return iter_expr, iter_ty

    # `break` and `return` are not allowed when looping over a non-droppable
    # iterator. Reporting them here yields a nicer error than a generic linearity
    # violation.
    early_exits = breaks_in_loop(node.origin_node) or return_nodes_in_ast(
        node.origin_node
    )
    if early_exits:
        err = NonDroppableForBreakError(early_exits[0])
        err.add_sub_diagnostic(
            NonDroppableForBreakError.NonDroppableIteratorType(node, iter_ty)
        )
        raise GuppyTypeError(err)
    return iter_expr, iter_ty
786
+
787
def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, Type]:
    """Synthesises a call to the iterator's `__hasnext__` instance function.

    The expected signature takes the iterator (owned unless it is copyable) and
    returns a `(bool, iterator)` tuple.
    """
    node.value, iterator_ty = self.synthesize(node.value)
    input_flags = InputFlags.NoFlags if iterator_ty.copyable else InputFlags.Owned
    hasnext_sig = FunctionType(
        [FuncInput(iterator_ty, input_flags)],
        TupleType([bool_type(), iterator_ty]),
    )
    return self.synthesize_instance_func(
        node.value, [], "__hasnext__", "an iterator", hasnext_sig, True
    )
794
+
795
def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]:
    """Synthesises a call to the iterator's `__next__` instance function.

    The expected signature takes the iterator (owned unless it is copyable) and
    returns an optional `(element, iterator)` tuple with a fresh element type.
    """
    node.value, iterator_ty = self.synthesize(node.value)
    input_flags = InputFlags.NoFlags if iterator_ty.copyable else InputFlags.Owned
    elem_var = ExistentialTypeVar.fresh("T", True, True)
    next_sig = FunctionType(
        [FuncInput(iterator_ty, input_flags)],
        option_type(TupleType([elem_var, iterator_ty])),
    )
    return self.synthesize_instance_func(
        node.value, [], "__next__", "an iterator", next_sig, True
    )
805
+
806
def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, Type]:
    """Synthesises a call to the iterator's `__end__` instance function.

    The expected signature takes the iterator (owned unless it is copyable) and
    returns `None`.
    """
    node.value, iterator_ty = self.synthesize(node.value)
    input_flags = InputFlags.NoFlags if iterator_ty.copyable else InputFlags.Owned
    end_sig = FunctionType([FuncInput(iterator_ty, input_flags)], NoneType())
    return self.synthesize_instance_func(
        node.value, [], "__end__", "an iterator", end_sig, True
    )
813
+
814
def visit_ListComp(self, node: ast.ListComp) -> tuple[ast.expr, Type]:
    """Rejects list comprehensions, which must not reach the expression checker.

    They should have been removed during CFG construction, so encountering one here
    is an internal compiler error.

    Raises:
        InternalGuppyError: always.
    """
    # The two literals are concatenated, so the first one needs a trailing space
    raise InternalGuppyError(
        "BB contains `ListComp`. Should have been removed during CFG "
        f"construction: `{ast.unparse(node)}`"
    )
819
+
820
def visit_ComptimeExpr(self, node: ComptimeExpr) -> tuple[ast.expr, Type]:
    """Synthesises a `comptime(...)` expression by evaluating it in Python.

    The resulting Python value is embedded as an `ast.Constant`. Its Guppy type is
    derived with `python_value_to_guppy_type`. Values without a Guppy representation
    are rejected with a `GuppyError`.
    """
    value = eval_comptime_expr(node, self.ctx)
    value_ty = python_value_to_guppy_type(value, node, self.ctx.globals)
    if not value_ty:
        raise GuppyError(IllegalComptimeExpressionError(node.value, type(value)))
    return with_loc(node, ast.Constant(value=value)), value_ty
826
+
827
def visit_NamedExpr(self, node: ast.NamedExpr) -> tuple[ast.expr, Type]:
    """Rejects walrus expressions, which must not reach the expression checker.

    They should have been removed during CFG construction, so encountering one here
    is an internal compiler error.

    Raises:
        InternalGuppyError: always.
    """
    # The two literals are concatenated, so the first one needs a trailing space
    raise InternalGuppyError(
        "BB contains `NamedExpr`. Should have been removed during CFG "
        f"construction: `{ast.unparse(node)}`"
    )
832
+
833
def visit_BoolOp(self, node: ast.BoolOp) -> tuple[ast.expr, Type]:
    """Rejects `and`/`or` expressions, which must not reach the expression checker.

    They should have been removed during CFG construction, so encountering one here
    is an internal compiler error.
    """
    source = ast.unparse(node)
    message = (
        "BB contains `BoolOp`. Should have been removed during CFG construction: "
        f"`{source}`"
    )
    raise InternalGuppyError(message)
838
+
839
def visit_IfExp(self, node: ast.IfExp) -> tuple[ast.expr, Type]:
    """Rejects conditional expressions, which must not reach the expression checker.

    They should have been removed during CFG construction, so encountering one here
    is an internal compiler error.
    """
    source = ast.unparse(node)
    message = (
        "BB contains `IfExp`. Should have been removed during CFG construction: "
        f"`{source}`"
    )
    raise InternalGuppyError(message)
844
+
845
def generic_visit(self, node: ast.expr) -> NoReturn:
    """Fallback for expression kinds that have no dedicated `visit_*` method.

    Always raises a `GuppyError` reporting the expression as unsupported.
    """
    unsupported = UnsupportedError(node, "This expression", singular=True)
    raise GuppyError(unsupported)
848
+
849
+
850
def check_type_against(
    act: Type, exp: Type, node: ast.expr, ctx: Context, kind: str = "expression"
) -> tuple[ast.expr, Subst, Inst]:
    """Checks that an expression of type `act` can be used where `exp` is expected.

    Returns the (possibly coerced) expression, a substitution for the free variables
    in the expected type, and an instantiation for the parameters of the actual type.
    The expected type may not be parametrised, and the actual type may not contain
    free unification variables.

    Raises `GuppyTypeError` on a mismatch, and `GuppyTypeInferenceError` if a
    parameter of a polymorphic `act` cannot be inferred.
    """
    assert not isinstance(exp, FunctionType) or not exp.parametrized
    assert not act.unsolved_vars

    if isinstance(act, FunctionType) and act.parametrized:
        # Instantiate a polymorphic `act` right away to avoid higher-rank types
        unquantified, free_vars = act.unquantified()
        poly_subst = unify(exp, unquantified, {})
        if poly_subst is None:
            raise GuppyTypeError(TypeMismatchError(node, exp, act, kind))

        # Every parameter must be solved, and solved with a closed type
        for idx, var in enumerate(free_vars):
            param_name = act.params[idx].name
            if var not in poly_subst:
                err = TypeMismatchError(node, exp, act, kind)
                err.add_sub_diagnostic(
                    TypeMismatchError.CantInferParam(None, param_name)
                )
                raise GuppyTypeInferenceError(err)
            if poly_subst[var].unsolved_vars:
                err = TypeMismatchError(node, exp, act, kind)
                err.add_sub_diagnostic(
                    TypeMismatchError.CantInstantiateFreeVars(
                        None, param_name, poly_subst[var]
                    )
                )
                raise GuppyTypeError(err)

        inst = [poly_subst[var].to_arg() for var in free_vars]
        # Only the part of the substitution that concerns `exp` is reported back
        exp_subst = {v: t for v, t in poly_subst.items() if v in exp.unsolved_vars}
        # The instantiation must also respect the linearity requirements
        check_inst(act, inst, node)
        return node, exp_subst, inst

    # `act` is closed here, so unification only solves variables in `exp`
    mono_subst = unify(exp, act, {})
    if mono_subst is not None:
        return node, mono_subst, []
    # Last resort: an implicit coercion of `act` into `exp`
    coerced = try_coerce_to(act, exp, node, ctx)
    if coerced:
        return coerced, {}, []
    raise GuppyTypeError(TypeMismatchError(node, exp, act, kind))
900
+
901
+
902
def try_coerce_to(
    act: Type, exp: Type, node: ast.expr, ctx: Context
) -> ast.expr | None:
    """Tries to implicitly coerce an expression to a different type.

    Only numeric types are coerced, and only towards a larger `NumericType.Kind`.
    The coercion calls the `__<kind>__` instance function of the actual type, e.g.
    `__float__`. Returns the coerced expression, or `None` if no implicit coercion
    applies.
    """
    if not (isinstance(act, NumericType) and isinstance(exp, NumericType)):
        return None
    # The ordering on `NumericType.Kind` defines the coercion relation
    if not act.kind < exp.kind:
        return None
    method_name = f"__{exp.kind.name.lower()}__"
    coercion = ctx.globals.get_instance_func(act, method_name)
    assert coercion is not None
    coerced, subst = coercion.check_call([node], exp, node, ctx)
    assert len(subst) == 0, "Coercion methods are not generic"
    return coerced
920
+
921
+
922
def check_type_apply(ty: FunctionType, node: ast.Subscript, ctx: Context) -> Inst:
    """Checks a `f[T1, T2, ...]` type application of a generic function.

    Each argument expression in the subscript is converted with `arg_from_ast` and
    checked against the corresponding parameter of `ty`. The resulting instantiation
    is returned. A `GuppyError` is raised if `ty` is not generic or if the number of
    arguments does not match the number of parameters.
    """
    func = node.value
    # `f[A, B]` has a non-empty tuple slice; anything else (incl. `f[()]`) is a
    # single argument
    if isinstance(node.slice, ast.Tuple) and node.slice.elts:
        arg_exprs = node.slice.elts
    else:
        arg_exprs = [node.slice]
    global_defs = ctx.globals

    if not ty.parametrized:
        func_name = (
            global_defs[func.def_id].name if isinstance(func, GlobalName) else None
        )
        raise GuppyError(TypeApplyNotGenericError(node, func_name))

    num_params, num_args = len(ty.params), len(arg_exprs)
    assert num_params > 0
    assert num_args > 0
    if num_params != num_args:
        if num_params < num_args:
            # Highlight the surplus type arguments
            span = Span(
                to_span(arg_exprs[num_params]).start, to_span(arg_exprs[-1]).end
            )
        else:
            # Point at the region after the last given argument
            span = Span(to_span(arg_exprs[-1]).end, to_span(node).end)
        err = WrongNumberOfArgsError(
            span, num_params, num_args, detailed=True, is_type_apply=True
        )
        err.add_sub_diagnostic(WrongNumberOfArgsError.SignatureHint(None, ty))
        raise GuppyError(err)

    return [
        param.check_arg(
            arg_from_ast(arg_expr, global_defs, ctx.generic_params), arg_expr
        )
        for arg_expr, param in zip(arg_exprs, ty.params, strict=True)
    ]
952
+
953
+
954
def check_num_args(
    exp: int, act: int, node: AstNode, sig: FunctionType | None = None
) -> None:
    """Checks that the correct number of arguments have been passed to a function.

    `exp` is the expected and `act` the actual argument count. On a mismatch, a
    `GuppyTypeError` wrapping a `WrongNumberOfArgsError` is raised, with `sig`
    attached as a hint if given. For `ast.Call` nodes the error span is narrowed to
    the surplus arguments, or to the region after the last argument if some are
    missing.
    """
    if exp == act:
        return

    if not isinstance(node, ast.Call):
        span, detailed = to_span(node), False
    else:
        # Regular calls allow a more precise error location
        detailed = True
        if exp < act:
            # Surplus arguments
            span = Span(to_span(node.args[exp]).start, to_span(node.args[-1]).end)
        elif act > 0:
            # Missing arguments after the last one given
            span = Span(to_span(node.args[-1]).end, to_span(node).end)
        else:
            # No arguments at all: the region after the callee
            span = Span(to_span(node.func).end, to_span(node).end)

    err = WrongNumberOfArgsError(span, exp, act, detailed)
    if sig:
        err.add_sub_diagnostic(WrongNumberOfArgsError.SignatureHint(None, sig))
    raise GuppyTypeError(err)
974
+
975
+
976
def type_check_args(
    inputs: list[ast.expr],
    func_ty: FunctionType,
    subst: Subst,
    ctx: Context,
    node: AstNode,
) -> tuple[list[ast.expr], Subst]:
    """Checks the arguments of a function call and infers free type variables.

    Parameters are expected to have been replaced with free unification variables
    already. Arguments are checked left to right, each against its input type with
    the substitution found so far applied. Note that `subst` is extended in place and
    also returned.

    In addition:
    * borrowed (`Inout`) place arguments are checked for assignability,
    * `Comptime` arguments are matched against the signature's comptime constants.

    Raises `GuppyTypeInferenceError` if a variable in the return type stays unsolved.
    """
    assert not func_ty.parametrized
    check_num_args(len(func_ty.inputs), len(inputs), node, func_ty)

    checked: list[ast.expr] = []
    comptime_consts = iter(func_ty.comptime_args)
    for arg, inp in zip(inputs, func_ty.inputs, strict=True):
        checker = ExprChecker(ctx)
        arg, arg_subst = checker.check(arg, inp.ty.substitute(subst), "argument")
        if InputFlags.Inout in inp.flags and isinstance(arg, PlaceNode):
            arg.place = check_place_assignable(
                arg.place, ctx, arg, "able to borrow subscripted elements"
            )
        if InputFlags.Comptime in inp.flags:
            expected_const = next(comptime_consts).const
            arg_subst = check_comptime_arg(arg, expected_const, inp.ty, arg_subst)
        checked.append(arg)
        subst |= arg_subst
    # Each comptime input consumes exactly one comptime constant
    assert next(comptime_consts, None) is None

    # Successful argument checks solve every unification variable in the inputs
    assert all(
        set.issubset(inp.ty.unsolved_vars, subst.keys()) for inp in func_ty.inputs
    )

    # The return type may mention variables that no input determines
    if not set.issubset(func_ty.output.unsolved_vars, subst.keys()):
        raise GuppyTypeInferenceError(
            TypeInferenceError(node, func_ty.output.substitute(subst))
        )

    return checked, subst
1019
+
1020
+
1021
+ def check_place_assignable(
1022
+ place: Place, ctx: Context, node: ast.expr, reason: str
1023
+ ) -> Place:
1024
+ """Performs additional checks for assignments to places, for example for borrowed
1025
+ place arguments after function returns.
1026
+
1027
+ In particular, we need to check that places involving `place[item]` subscripts
1028
+ implement the corresponding `__setitem__` method.
1029
+ """
1030
+ match place:
1031
+ case Variable():
1032
+ return place
1033
+ case FieldAccess(parent=parent):
1034
+ return replace(
1035
+ place, parent=check_place_assignable(parent, ctx, node, reason)
1036
+ )
1037
+ case SubscriptAccess(parent=parent, item=item, ty=ty):
1038
+ # Create temporary variable for the setitem value
1039
+ tmp_var = Variable(next(tmp_vars), item.ty, node)
1040
+ # Check a call to the `__setitem__` instance function
1041
+ exp_sig = FunctionType(
1042
+ [
1043
+ FuncInput(parent.ty, InputFlags.Inout),
1044
+ FuncInput(item.ty, InputFlags.NoFlags),
1045
+ FuncInput(ty, InputFlags.Owned),
1046
+ ],
1047
+ NoneType(),
1048
+ )
1049
+ setitem_args: list[ast.expr] = [
1050
+ with_type(parent.ty, with_loc(node, PlaceNode(parent))),
1051
+ with_type(item.ty, with_loc(node, PlaceNode(item))),
1052
+ with_type(ty, with_loc(node, PlaceNode(tmp_var))),
1053
+ ]
1054
+ setitem_call, _ = ExprSynthesizer(ctx).synthesize_instance_func(
1055
+ setitem_args[0],
1056
+ setitem_args[1:],
1057
+ "__setitem__",
1058
+ reason,
1059
+ exp_sig,
1060
+ True,
1061
+ )
1062
+ return replace(place, setitem_call=SetitemCall(setitem_call, tmp_var))
1063
+ case TupleAccess(parent=parent):
1064
+ return replace(
1065
+ place, parent=check_place_assignable(parent, ctx, node, reason)
1066
+ )
1067
+
1068
+
1069
def check_comptime_arg(
    arg: ast.expr, exp_const: Const, ty: Type, subst: Subst | None
) -> Subst:
    """Checks that an expression can be passed as a valid `@comptime` argument.

    Only literal constants and references to generic parameters are known at compile
    time. Anything else is rejected with a `GuppyError` carrying a hint about why the
    value is unknown. The resulting constant is then unified with `exp_const`,
    starting from `subst`. The returned substitution solves any existential variables
    occurring in the expected constant. A `GuppyError` is raised if the constants do
    not match.
    """
    const: Const
    if isinstance(arg, ast.Constant):
        const = ConstValue(ty, arg.value)
    elif isinstance(arg, GenericParamValue):
        const = arg.param.to_bound().const
    else:
        # Unknown at comptime; inspect the argument to give a more helpful hint
        err = ComptimeUnknownError(arg, "argument")
        hint: SubDiagnostic
        place = arg.place if isinstance(arg, PlaceNode) else None
        if place is not None and place.root.is_func_input:
            hint = ComptimeUnknownError.InputHint(place.defined_at, place)
        elif place is not None and not is_tmp_var(place.root.name):
            hint = ComptimeUnknownError.VariableHint(place.defined_at, place)
        else:
            hint = ComptimeUnknownError.FallbackHint(arg)
        err.add_sub_diagnostic(hint)
        err.add_sub_diagnostic(ComptimeUnknownError.Feedback(None))
        raise GuppyError(err)

    # Unify with the expected constant to check it and maybe infer some variables
    new_subst = unify(exp_const, const, subst)
    if new_subst is None:
        raise GuppyError(ConstMismatchError(arg, exp_const, const))
    return new_subst
1103
+
1104
+
1105
def synthesize_call(
    func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context
) -> tuple[list[ast.expr], Type, Inst]:
    """Synthesizes the return type of a function call.

    The quantified variables of `func_ty` are replaced by free unification variables,
    which are then solved by checking the arguments. Returns the annotated argument
    list, the synthesized return type, and an instantiation for the quantifiers in
    the function type.
    """
    assert not func_ty.unsolved_vars
    check_num_args(len(func_ty.inputs), len(args), node, func_ty)

    mono_ty, free_vars = func_ty.unquantified()
    checked_args, subst = type_check_args(args, mono_ty, {}, ctx, node)

    # A successful argument check leaves the substitution closed
    assert not any(t.unsolved_vars for t in subst.values())
    inst = [subst[var].to_arg() for var in free_vars]

    # The instantiation must also respect the linearity requirements
    check_inst(func_ty, inst, node)
    return checked_args, mono_ty.output.substitute(subst), inst
1129
+
1130
+
1131
def check_call(
    func_ty: FunctionType,
    inputs: list[ast.expr],
    ty: Type,
    node: AstNode,
    ctx: Context,
    kind: str = "expression",
) -> tuple[list[ast.expr], Subst, Inst]:
    """Checks the return type of a function call against a given type.

    Returns an annotated argument list, a substitution for the free variables in the
    expected type, and an instantiation for the quantifiers in the function type.

    Raises `GuppyTypeError` if the call's return type does not match `ty`, and
    `GuppyTypeInferenceError` if a free variable of `ty` cannot be determined.
    """
    assert not func_ty.unsolved_vars
    check_num_args(len(func_ty.inputs), len(inputs), node, func_ty)

    # When checking, we can use the information from the expected return type to infer
    # some type arguments. However, this pushes errors inwards. For example, given a
    # function `foo: forall T. T -> T`, the following type mismatch would be reported:
    #
    # x: int = foo(None)
    # ^^^^ Expected argument of type `int`, got `None`
    #
    # But the following error location would be more intuitive for users:
    #
    # x: int = foo(None)
    # ^^^^^^^^^ Expected expression of type `int`, got `None`
    #
    # In other words, if we can get away with synthesising the call without the extra
    # information from the expected type, we should do that to improve the error.

    # TODO: The approach below can result in exponential runtime in the worst case.
    # However the bad case, e.g. `x: int = foo(foo(...foo(?)...))`, shouldn't be common
    # in practice. Can we do better than that?

    # First, try to synthesize
    try:
        synth_inputs, synth, inst = synthesize_call(func_ty, inputs, node, ctx)
    except GuppyTypeInferenceError:
        pass
    else:
        subst = unify(ty, synth, {})
        if subst is None:
            raise GuppyTypeError(TypeMismatchError(node, ty, synth, kind))
        return synth_inputs, subst, inst

    # If synthesis fails, we try again, this time also using information from the
    # expected return type
    unquantified, free_vars = func_ty.unquantified()
    subst = unify(ty, unquantified.output, {})
    if subst is None:
        raise GuppyTypeError(TypeMismatchError(node, ty, unquantified.output, kind))

    # Try to infer more by checking against the arguments
    inputs, subst = type_check_args(inputs, unquantified, subst, ctx, node)

    # Also make sure we found an instantiation for all free vars in the type we're
    # checking against
    if not set.issubset(ty.unsolved_vars, subst.keys()):
        # Report a variable of `ty` that is still missing from the substitution. The
        # candidates are `ty.unsolved_vars - subst.keys()`; the reverse difference may
        # be empty and would not name a variable of `ty` anyway.
        unsolved = next(v for v in ty.unsolved_vars if v not in subst)
        # Show the return type with everything inferred so far filled in. Using the
        # unquantified output matters: the quantified `func_ty.output` refers to
        # bound variables that `subst` does not mention.
        err = TypeMismatchError(
            node, ty, unquantified.output.substitute(subst), kind
        )
        err.add_sub_diagnostic(
            TypeMismatchError.CantInferParam(None, unsolved.display_name)
        )
        raise GuppyTypeInferenceError(err)

    # Success implies that the substitution is closed
    assert all(not t.unsolved_vars for t in subst.values())
    inst = [subst[v].to_arg() for v in free_vars]
    subst = {v: t for v, t in subst.items() if v in ty.unsolved_vars}

    # Finally, check that the instantiation respects the linearity requirements
    check_inst(func_ty, inst, node)

    return inputs, subst, inst
1209
+
1210
+
1211
def check_inst(func_ty: FunctionType, inst: Inst, node: AstNode) -> None:
    """Checks if an instantiation is valid.

    For type parameters, a type argument that violates the copyable or droppable
    requirement is reported with a dedicated `NonLinearInstantiateError`. After
    that, every argument is checked by its parameter's own `check_arg`.
    """
    for param, arg in zip(func_ty.params, inst, strict=True):
        if isinstance(param, TypeParam) and isinstance(arg, TypeArg):
            arg_ty = arg.ty
            # A more informative error message for linearity issues
            if (param.must_be_copyable and not arg_ty.copyable) or (
                param.must_be_droppable and not arg_ty.droppable
            ):
                raise GuppyTypeError(
                    NonLinearInstantiateError(node, param, func_ty, arg_ty)
                )
        # Default checking for all remaining requirements
        param.check_arg(arg, node)
1229
+
1230
+
1231
def instantiate_poly(node: ast.expr, ty: FunctionType, inst: Inst) -> ast.expr:
    """Instantiates quantified type arguments in a function.

    With a non-empty instantiation, `node` is wrapped in a `TypeApply` annotated with
    the instantiated function type. Otherwise `node` is just annotated with `ty`.
    """
    assert len(ty.params) == len(inst)
    if not inst:
        return with_type(ty, node)
    applied = with_loc(node, TypeApply(value=with_type(ty, node), inst=inst))
    return with_type(ty.instantiate(inst), applied)
1238
+
1239
+
1240
def to_bool(node: ast.expr, node_ty: Type, ctx: Context) -> tuple[ast.expr, Type]:
    """Tries to turn a node into a bool.

    Boolean nodes are returned unchanged. For any other type, a call to the
    `__bool__` instance function is synthesised, taking the value as an `Inout`
    input.
    """
    if is_bool_type(node_ty):
        return node, node_ty
    bool_sig = FunctionType([FuncInput(node_ty, InputFlags.Inout)], bool_type())
    synthesizer = ExprSynthesizer(ctx)
    return synthesizer.synthesize_instance_func(
        node, [], "__bool__", "truthy", bool_sig, True
    )
1247
+
1248
+
1249
def synthesize_comprehension(
    node: AstNode, gens: list[DesugaredGenerator], elt: ast.expr, ctx: Context
) -> tuple[list[DesugaredGenerator], ast.expr, Type]:
    """Helper function to synthesise the element type of a list comprehension.

    Generators are checked in order. Each one is checked in the context produced by
    its predecessor, so variables bound by earlier generators are visible to later
    ones. The element expression is synthesised last, in the innermost context.
    """
    checked_gens: list[DesugaredGenerator] = []
    scope_ctx = ctx
    for gen in gens:
        checked_gen, scope_ctx = check_generator(gen, scope_ctx)
        checked_gens.append(checked_gen)
    elt, elt_ty = ExprSynthesizer(scope_ctx).synthesize(elt)
    return checked_gens, elt, elt_ty
1266
+
1267
+
1268
def check_generator(
    gen: DesugaredGenerator, ctx: Context
) -> tuple[DesugaredGenerator, Context]:
    """Helper function to check a single generator.

    Returns the type-annotated generator (mutated in place) together with a new
    nested context in which the generator variables are bound.
    """
    from guppylang_internals.checker.stmt_checker import StmtChecker

    # The iterator assignment is checked in the outer context
    gen.iter_assign = StmtChecker(ctx).visit_Assign(gen.iter_assign)

    # Everything else is checked in a nested scope so variables cannot escape
    inner_locals: Locals[str, Variable] = Locals({}, parent_scope=ctx.locals)
    inner_ctx = Context(ctx.globals, inner_locals, ctx.generic_params)
    synthesizer = ExprSynthesizer(inner_ctx)
    stmt_checker = StmtChecker(inner_ctx)

    iter_node, iter_ty = synthesizer.visit(gen.iter)
    gen.iter = with_type(iter_ty, iter_node)

    # `next_call` yields `Option[tuple[elt_ty, iter_ty]]`
    gen.next_call, option_ty = synthesizer.synthesize(gen.next_call)
    next_ty = get_element_type(option_ty)
    assert isinstance(next_ty, TupleType)
    elt_ty, _iter_ty = next_ty.element_types
    gen.target = stmt_checker._check_assign(gen.target, gen.next_call, elt_ty)

    # Each `if` guard must be convertible to a bool
    for i, guard in enumerate(gen.ifs):
        gen.ifs[i], guard_ty = synthesizer.synthesize(guard)
        gen.ifs[i], _ = to_bool(gen.ifs[i], guard_ty, inner_ctx)

    return gen, inner_ctx
1302
+
1303
+
1304
def eval_comptime_expr(node: ComptimeExpr, ctx: Context) -> Any:
    """Evaluates a `comptime(...)` expression.

    The wrapped expression is unparsed back to source and evaluated with Python's
    `eval`, using a `DummyEvalDict` built from `ctx` as the local namespace.

    Raises `GuppyError` if:
    * not running on CPython,
    * the expression uses a Guppy variable (signalled by `DummyEvalDict`),
    * evaluation raises any other exception (the formatted traceback is attached).
    """
    # Obtaining the Python variables in scope only works on CPython
    # (see `get_py_scope()`)
    if sys.implementation.name != "cpython":
        raise GuppyError(ComptimeExprNotCPythonError(node))

    try:
        return eval(  # noqa: S307
            ast.unparse(node.value),
            None,
            DummyEvalDict(ctx, node.value),
        )
    except DummyEvalDict.GuppyVarUsedError as exc:
        raise GuppyError(
            ComptimeExprNotStaticError(exc.node or node, exc.var)
        ) from None
    except Exception as exc:
        # Hide the top frame (the `eval` call itself) from the reported traceback
        trace = exc.__traceback__
        trace = trace.tb_next if trace else None
        formatted = "".join(traceback.format_exception(type(exc), exc, trace))
        raise GuppyError(ComptimeExprEvalError(node.value, formatted)) from exc
1325
+
1326
+
1327
+ def python_value_to_guppy_type(
1328
+ v: Any, node: ast.AST, globals: Globals, type_hint: Type | None = None
1329
+ ) -> Type | None:
1330
+ """Turns a primitive Python value into a Guppy type.
1331
+
1332
+ Accepts an optional `type_hint` for the expected expression type that is used to
1333
+ infer a more precise type (e.g. distinguishing between `int` and `nat`). Note that
1334
+ invalid hints are ignored, i.e. no user error are emitted.
1335
+
1336
+ Returns `None` if the Python value cannot be represented in Guppy.
1337
+ """
1338
+ match v:
1339
+ case bool():
1340
+ return bool_type()
1341
+ case str():
1342
+ return string_type()
1343
+ # Only resolve `int` to `nat` if the user specifically asked for it
1344
+ case int(n) if type_hint == nat_type() and n >= 0:
1345
+ _int_bounds_check(n, node, signed=False)
1346
+ return nat_type()
1347
+ # Otherwise, default to `int` for consistency with Python
1348
+ case int(n):
1349
+ _int_bounds_check(n, node, signed=True)
1350
+ return int_type()
1351
+ case float():
1352
+ return float_type()
1353
+ case tuple(elts):
1354
+ hints = (
1355
+ type_hint.element_types
1356
+ if isinstance(type_hint, TupleType)
1357
+ else len(elts) * [None]
1358
+ )
1359
+ tys = [
1360
+ python_value_to_guppy_type(elt, node, globals, hint)
1361
+ for elt, hint in zip(elts, hints, strict=False)
1362
+ ]
1363
+ if any(ty is None for ty in tys):
1364
+ return None
1365
+ return TupleType(cast(list[Type], tys))
1366
+ case list():
1367
+ return _python_list_to_guppy_type(v, node, globals, type_hint)
1368
+ case _:
1369
+ return None
1370
+
1371
+
1372
def _int_bounds_check(value: int, node: AstNode, signed: bool) -> None:
    """Raises a `GuppyTypeError` if `value` does not fit into a Guppy integer.

    The bit width is `1 << NumericType.INT_WIDTH`. Signed values must lie in the
    two's complement range of that width; unsigned values in `[0, 2**width - 1]`.
    """
    bit_width = 1 << NumericType.INT_WIDTH
    if signed:
        half_range = 1 << (bit_width - 1)
        min_v, max_v = -half_range, half_range - 1
    else:
        min_v, max_v = 0, (1 << bit_width) - 1
    underflow = value < min_v
    if underflow or value > max_v:
        raise GuppyTypeError(IntOverflowError(node, signed, bit_width, underflow))
1383
+
1384
+
1385
def _python_list_to_guppy_type(
    vs: list[Any], node: ast.AST, globals: Globals, type_hint: Type | None
) -> OpaqueType | None:
    """Turns a Python list into a Guppy `frozenarray` type.

    Each element is converted with `python_value_to_guppy_type`. The element type of
    `type_hint` is used as a hint if the hint is a frozenarray type. An empty list
    yields a length-0 frozenarray with a fresh existential element type.

    Returns `None` if some element is not representable in Guppy. Raises a
    `GuppyError` if the element types cannot be unified.
    """
    if not vs:
        return frozenarray_type(ExistentialTypeVar.fresh("T", True, True), 0)

    elt_hint: Type | None = None
    if type_hint and is_frozenarray_type(type_hint):
        elt_hint = get_element_type(type_hint)

    elem_ty: Type | None = None
    for value in vs:
        value_ty = python_value_to_guppy_type(value, node, globals, elt_hint)
        if value_ty is None:
            return None
        if elem_ty is None:
            # The first element fixes the initial element type
            elem_ty = value_ty
            continue
        # Every further element must unify with the element type found so far
        subst = unify(value_ty, elem_ty, {})
        if subst is None:
            raise GuppyError(ComptimeExprIncoherentListError(node))
        elem_ty = elem_ty.substitute(subst)
    assert elem_ty is not None
    return frozenarray_type(elem_ty, len(vs))