guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,425 @@
1
+ import ast
2
+ import sys
3
+ from collections.abc import Sequence
4
+ from types import ModuleType
5
+
6
+ from guppylang_internals.ast_util import (
7
+ AstNode,
8
+ set_location_from,
9
+ shift_loc,
10
+ )
11
+ from guppylang_internals.cfg.builder import is_comptime_expression
12
+ from guppylang_internals.checker.core import Context, Globals, Locals, PythonObject
13
+ from guppylang_internals.checker.errors.generic import ExpectedError, UnsupportedError
14
+ from guppylang_internals.definition.common import Definition
15
+ from guppylang_internals.definition.parameter import ParamDef
16
+ from guppylang_internals.definition.ty import TypeDef
17
+ from guppylang_internals.engine import ENGINE
18
+ from guppylang_internals.error import GuppyError
19
+ from guppylang_internals.tys.arg import Argument, ConstArg, TypeArg
20
+ from guppylang_internals.tys.builtin import CallableTypeDef, bool_type
21
+ from guppylang_internals.tys.const import ConstValue
22
+ from guppylang_internals.tys.errors import (
23
+ CallableComptimeError,
24
+ ComptimeArgShadowError,
25
+ FlagNotAllowedError,
26
+ FreeTypeVarError,
27
+ HigherKindedTypeVarError,
28
+ IllegalComptimeTypeArgError,
29
+ InvalidCallableTypeError,
30
+ InvalidFlagError,
31
+ InvalidTypeArgError,
32
+ InvalidTypeError,
33
+ LinearComptimeError,
34
+ LinearConstParamError,
35
+ ModuleMemberNotFoundError,
36
+ NonLinearOwnedError,
37
+ )
38
+ from guppylang_internals.tys.param import ConstParam, Parameter, TypeParam
39
+ from guppylang_internals.tys.subst import BoundVarFinder
40
+ from guppylang_internals.tys.ty import (
41
+ FuncInput,
42
+ FunctionType,
43
+ InputFlags,
44
+ NoneType,
45
+ NumericType,
46
+ TupleType,
47
+ Type,
48
+ )
49
+
50
+
51
+ def arg_from_ast(
52
+ node: AstNode,
53
+ globals: Globals,
54
+ param_var_mapping: dict[str, Parameter],
55
+ allow_free_vars: bool = False,
56
+ ) -> Argument:
57
+ """Turns an AST expression into an argument."""
58
+ from guppylang_internals.checker.cfg_checker import VarNotDefinedError
59
+
60
+ # A single (possibly qualified) identifier
61
+ if defn := _try_parse_defn(node, globals):
62
+ return _arg_from_instantiated_defn(
63
+ defn, [], globals, node, param_var_mapping, allow_free_vars
64
+ )
65
+
66
+ # An identifier referring to a quantified variable
67
+ if isinstance(node, ast.Name):
68
+ if node.id in param_var_mapping:
69
+ return param_var_mapping[node.id].to_bound()
70
+ raise GuppyError(VarNotDefinedError(node, node.id))
71
+
72
+ # A parametrised type, e.g. `list[??]`
73
+ if isinstance(node, ast.Subscript) and (
74
+ defn := _try_parse_defn(node.value, globals)
75
+ ):
76
+ arg_nodes = (
77
+ node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
78
+ )
79
+ return _arg_from_instantiated_defn(
80
+ defn, arg_nodes, globals, node, param_var_mapping, allow_free_vars
81
+ )
82
+
83
+ # We allow tuple types to be written as `(int, bool)`
84
+ if isinstance(node, ast.Tuple):
85
+ ty = TupleType(
86
+ [
87
+ type_from_ast(el, globals, param_var_mapping, allow_free_vars)
88
+ for el in node.elts
89
+ ]
90
+ )
91
+ return TypeArg(ty)
92
+
93
+ # Literals
94
+ if isinstance(node, ast.Constant):
95
+ match node.value:
96
+ # `None` is represented as a `ast.Constant` node with value `None`
97
+ case None:
98
+ return TypeArg(NoneType())
99
+ case bool(v):
100
+ return ConstArg(ConstValue(bool_type(), v))
101
+ # Integer literals are turned into nat args.
102
+ # TODO: To support int args, we need proper inference logic here
103
+ # See https://github.com/CQCL/guppylang/issues/1030
104
+ case int(v) if v >= 0:
105
+ nat_ty = NumericType(NumericType.Kind.Nat)
106
+ return ConstArg(ConstValue(nat_ty, v))
107
+ case float(v):
108
+ float_ty = NumericType(NumericType.Kind.Float)
109
+ return ConstArg(ConstValue(float_ty, v))
110
+ # String literals are ignored for now since they could also be stringified
111
+ # types.
112
+ # TODO: To support string args, we need proper inference logic here
113
+ # See https://github.com/CQCL/guppylang/issues/1030
114
+ case str(_):
115
+ pass
116
+
117
+ # Py-expressions can also be used to specify static numbers
118
+ if comptime_expr := is_comptime_expression(node):
119
+ from guppylang_internals.checker.expr_checker import eval_comptime_expr
120
+
121
+ v = eval_comptime_expr(comptime_expr, Context(globals, Locals({}), {}))
122
+ if isinstance(v, int):
123
+ nat_ty = NumericType(NumericType.Kind.Nat)
124
+ return ConstArg(ConstValue(nat_ty, v))
125
+ else:
126
+ raise GuppyError(IllegalComptimeTypeArgError(node, v))
127
+
128
+ # Finally, we also support delayed annotations in strings
129
+ if isinstance(node, ast.Constant) and isinstance(node.value, str):
130
+ node = _parse_delayed_annotation(node.value, node)
131
+ return arg_from_ast(node, globals, param_var_mapping, allow_free_vars)
132
+
133
+ raise GuppyError(InvalidTypeArgError(node))
134
+
135
+
136
+ def _try_parse_defn(node: AstNode, globals: Globals) -> Definition | None:
137
+ """Tries to parse a (possibly qualified) name into a global definition."""
138
+ from guppylang.defs import GuppyDefinition
139
+ from guppylang_internals.checker.cfg_checker import VarNotDefinedError
140
+
141
+ match node:
142
+ case ast.Name(id=x):
143
+ if x not in globals:
144
+ return None
145
+ defn = globals[x]
146
+ if isinstance(defn, PythonObject):
147
+ return None
148
+ return defn
149
+ case ast.Attribute(value=ast.Name(id=module_name) as value, attr=x):
150
+ if module_name not in globals:
151
+ raise GuppyError(VarNotDefinedError(value, module_name))
152
+ match globals[module_name]:
153
+ case PythonObject(ModuleType() as module):
154
+ if x in module.__dict__:
155
+ val = module.__dict__[x]
156
+ if isinstance(val, GuppyDefinition):
157
+ return ENGINE.get_parsed(val.id)
158
+ raise GuppyError(
159
+ ModuleMemberNotFoundError(node, module.__name__, x)
160
+ )
161
+ case _:
162
+ raise GuppyError(ExpectedError(value, "a module"))
163
+ case _:
164
+ return None
165
+
166
+
167
def _arg_from_instantiated_defn(
    defn: Definition,
    arg_nodes: list[ast.expr],
    globals: Globals,
    node: AstNode,
    param_var_mapping: dict[str, Parameter],
    allow_free_vars: bool = False,
) -> Argument:
    """Builds an argument from a global definition applied to `arg_nodes`.

    Supported definitions:
    - `Callable`;
    - other type definitions, instantiated with the parsed arguments;
    - generic parameter definitions, which may not take arguments.

    A parameter not yet in `param_var_mapping` is added to it when
    `allow_free_vars` is set.

    Raises:
        GuppyError: In these cases:
            - the definition is not a type;
            - a parameter is applied to arguments;
            - a free parameter is used while `allow_free_vars` is off.
    """
    # `Callable` has its own `[[inputs], output]` syntax, so it is handled before
    # the generic type-definition case
    if isinstance(defn, CallableTypeDef):
        fn_ty = _parse_callable_type(
            arg_nodes, node, globals, param_var_mapping, allow_free_vars
        )
        return TypeArg(fn_ty)

    # A defined type (e.g. `int`, `bool`, ...) instantiated with its arguments
    if isinstance(defn, TypeDef):
        parsed_args = [
            arg_from_ast(sub_node, globals, param_var_mapping, allow_free_vars)
            for sub_node in arg_nodes
        ]
        return TypeArg(defn.check_instantiate(parsed_args, node))

    # A parameter (e.g. `T`, `n`, ...)
    if isinstance(defn, ParamDef):
        # Parametrised variables like `T[int]` are not supported
        if arg_nodes:
            raise GuppyError(HigherKindedTypeVarError(node, defn))
        if defn.name not in param_var_mapping:
            if not allow_free_vars:
                raise GuppyError(FreeTypeVarError(node, defn))
            param_var_mapping[defn.name] = defn.to_param(len(param_var_mapping))
        return param_var_mapping[defn.name].to_bound()

    err = ExpectedError(node, "a type", got=f"{defn.description} `{defn.name}`")
    raise GuppyError(err)
206
+
207
+
208
def _parse_delayed_annotation(ast_str: str, node: ast.Constant) -> ast.expr:
    """Parses a type annotation that was written inside a string literal.

    The parsed expression is relocated so that its source positions point into the
    string literal `node`, not into a fresh one-line module.

    Raises:
        GuppyError: If the string does not parse as exactly one expression
            statement.
    """
    try:
        module_body = ast.parse(ast_str).body
        # Unpacking raises `ValueError` unless there is exactly one statement
        [expr_stmt] = module_body
        if not isinstance(expr_stmt, ast.Expr):
            raise GuppyError(InvalidTypeError(node))
        set_location_from(expr_stmt, loc=node)
        shift_loc(
            expr_stmt,
            delta_lineno=node.lineno - 1,  # lines are 1-based
            delta_col_offset=node.col_offset + 1,  # skip the opening quote
        )
    except (SyntaxError, ValueError):
        raise GuppyError(InvalidTypeError(node)) from None
    return expr_stmt.value
224
+
225
+
226
def _parse_callable_type(
    args: list[ast.expr],
    loc: AstNode,
    globals: Globals,
    param_var_mapping: dict[str, Parameter],
    allow_free_vars: bool = False,
) -> FunctionType:
    """Parses the arguments of a `Callable[[<arguments>], <return type>]` type.

    Raises:
        GuppyError: If there are not exactly two arguments, or the first is not a
            list literal of input types.
    """
    is_well_formed = len(args) == 2 and isinstance(args[0], ast.List)
    if not is_well_formed:
        raise GuppyError(InvalidCallableTypeError(loc))
    input_list, output_node = args
    func_inputs, func_output = parse_function_io_types(
        input_list.elts,
        output_node,
        None,
        loc,
        globals,
        param_var_mapping,
        allow_free_vars,
    )
    return FunctionType(func_inputs, func_output)
244
+
245
+
246
def parse_function_io_types(
    input_nodes: list[ast.expr],
    output_node: ast.expr,
    input_names: list[str] | None,
    loc: AstNode,
    globals: Globals,
    param_var_mapping: dict[str, Parameter],
    allow_free_vars: bool = False,
) -> tuple[list[FuncInput], Type]:
    """Parses the input and output type annotations of a function type.

    Input flags are validated here. A non-copyable input without `@owned` gets the
    `Inout` flag. Each `@comptime` input registers an implicit const parameter
    named after the input in `param_var_mapping`. This requires `input_names`, so
    it is rejected for `Callable` types, which pass `None`.

    Returns:
        The parsed inputs and the parsed output type.

    Raises:
        GuppyError: In these cases:
            - an annotation is invalid;
            - `@comptime` is used without input names;
            - a comptime input shadows an already-known type variable.
    """
    parsed_inputs: list[FuncInput] = []
    for idx, input_node in enumerate(input_nodes):
        inp_ty, inp_flags = type_with_flags_from_ast(
            input_node, globals, param_var_mapping, allow_free_vars
        )
        is_owned = InputFlags.Owned in inp_flags
        if is_owned and inp_ty.copyable:
            raise GuppyError(NonLinearOwnedError(loc, inp_ty))
        # Non-copyable inputs that are not explicitly owned are marked `Inout`
        if not inp_ty.copyable and not is_owned:
            inp_flags |= InputFlags.Inout

        if InputFlags.Comptime in inp_flags:
            if input_names is None:
                raise GuppyError(CallableComptimeError(input_node))
            arg_name = input_names[idx]
            # Refuse to shadow a type variable of the same name that was already
            # used further left, e.g.
            #
            #    n = guppy.type_var("n")
            #    def foo(xs: array[int, n], n: nat @comptime)
            #
            # TODO: In principle we could lift this restriction by tracking multiple
            #  params referring to the same name in `param_var_mapping`, but not sure
            #  if this would be worth it...
            if arg_name in param_var_mapping:
                raise GuppyError(ComptimeArgShadowError(input_node, arg_name))
            param_var_mapping[arg_name] = ConstParam(
                len(param_var_mapping), arg_name, inp_ty, from_comptime_arg=True
            )

        parsed_inputs.append(FuncInput(inp_ty, inp_flags))

    output_ty = type_from_ast(output_node, globals, param_var_mapping, allow_free_vars)
    return parsed_inputs, output_ty
293
+
294
+
295
if sys.version_info >= (3, 12):

    def parse_parameter(node: ast.type_param, idx: int, globals: Globals) -> Parameter:
        """Parses a PEP 695 generic parameter declaration `Variable: Bound`.

        The bound determines the kind of parameter:
        - No bound: a type parameter that may be linear.
        - `Copy`, `Drop`, or the tuple `(Copy, Drop)`: a type parameter with those
          requirements.
        - Anything else: the type of a const parameter. Currently only `nat` is
          supported.

        Raises:
            GuppyError: In these cases:
                - the parameter is variadic;
                - a const parameter's type is linear or unsupported.
        """
        if isinstance(node, ast.TypeVarTuple | ast.ParamSpec):
            raise GuppyError(UnsupportedError(node, "Variadic generic parameters"))
        assert isinstance(node, ast.TypeVar)

        bound = node.bound
        if bound is None:
            return TypeParam(
                idx, node.name, must_be_copyable=False, must_be_droppable=False
            )

        # A single `Copy` or `Drop` bound
        if isinstance(bound, ast.Name) and bound.id in ("Copy", "Drop"):
            return TypeParam(
                idx,
                node.name,
                must_be_copyable=bound.id == "Copy",
                must_be_droppable=bound.id == "Drop",
            )

        # Both requirements are written as the tuple `T: (Copy, Drop)` (any order).
        # A non-name element cannot contribute to the set, so the set equality
        # also forces both elements to be names.
        # TODO: Should we also allow `T: Copy + Drop`? Mypy would complain about it
        if isinstance(bound, ast.Tuple) and len(bound.elts) == 2:
            bound_names = {elt.id for elt in bound.elts if isinstance(elt, ast.Name)}
            if bound_names == {"Copy", "Drop"}:
                return TypeParam(
                    idx, node.name, must_be_copyable=True, must_be_droppable=True
                )

        # Otherwise the bound is the type of a const parameter. Its type may not
        # refer to previous parameters yet, hence the empty `param_var_mapping`.
        # TODO: In the future we might want to allow stuff like
        #  `def foo[T, XS: array[T, 42]]` and so on
        ty = type_from_ast(bound, globals, {}, allow_free_vars=False)
        if not (ty.copyable and ty.droppable):
            raise GuppyError(LinearConstParamError(bound, ty))

        # TODO: For now we can only do `nat` const args since they lower to
        #  Hugr bounded nats. Extend to arbitrary types via monomorphization.
        #  See https://github.com/CQCL/guppylang/issues/1008
        if ty != NumericType(NumericType.Kind.Nat):
            raise GuppyError(UnsupportedError(bound, f"`{ty}` generic parameters"))
        return ConstParam(idx, node.name, ty)
345
+
346
+
347
# Placeholder parameter against which parsed arguments are checked to ensure they
# are types rather than constants. Its index and name are arbitrary.
# NOTE(review): the two `False`s are presumably the copyable/droppable
# requirements, so that linear types are accepted as well; confirm in `TypeParam`.
_type_param = TypeParam(0, "T", False, False)


def type_with_flags_from_ast(
    node: AstNode,
    globals: Globals,
    param_var_mapping: dict[str, Parameter],
    allow_free_vars: bool = False,
) -> tuple[Type, InputFlags]:
    """Turns an AST expression into a Guppy type together with its input flags.

    Flags are attached with the `@` operator, e.g. `T @owned` or `nat @comptime`,
    and may be chained. String (delayed) annotations are parsed before
    interpretation. Any other expression must be a plain type and yields
    `InputFlags.NoFlags`.

    Raises:
        GuppyError: In these cases:
            - the flag is unknown;
            - `@owned` is put on a copyable type;
            - `@comptime` is put on a non-copyable, non-droppable or generic type;
            - the expression is not a valid type.
    """
    # Delayed annotation: parse the string, then interpret the result
    if isinstance(node, ast.Constant) and isinstance(node.value, str):
        parsed = _parse_delayed_annotation(node.value, node)
        return type_with_flags_from_ast(
            parsed, globals, param_var_mapping, allow_free_vars
        )

    # `<type> @ <flag>`, where the left operand may carry further flags itself
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
        ty, flags = type_with_flags_from_ast(
            node.left, globals, param_var_mapping, allow_free_vars
        )
        flag_node = node.right
        flag_name = flag_node.id if isinstance(flag_node, ast.Name) else None
        if flag_name == "owned":
            if ty.copyable:
                raise GuppyError(NonLinearOwnedError(flag_node, ty))
            flags |= InputFlags.Owned
        elif flag_name == "comptime":
            flags |= InputFlags.Comptime
            if not ty.copyable or not ty.droppable:
                raise GuppyError(LinearComptimeError(flag_node, ty))
            # For now, comptime annotations on generic inputs are not allowed.
            # TODO: In the future we might want to allow stuff like
            #  `def foo[T: (Copy, Drop)](x: T @comptime)`.
            #  Also see the todo in `parse_parameter`.
            finder = BoundVarFinder()
            ty.visit(finder)
            if finder.bound_vars:
                raise GuppyError(
                    UnsupportedError(node.left, "Generic comptime arguments")
                )
        else:
            raise GuppyError(InvalidFlagError(flag_node))
        return ty, flags

    # No flags: parse an argument and make sure it is a type
    arg = arg_from_ast(node, globals, param_var_mapping, allow_free_vars)
    return _type_param.check_arg(arg, node).ty, InputFlags.NoFlags
393
+
394
+
395
def type_from_ast(
    node: AstNode,
    globals: Globals,
    param_var_mapping: dict[str, Parameter],
    allow_free_vars: bool = False,
) -> Type:
    """Turns an AST expression into a Guppy type, rejecting any input flags.

    Raises:
        GuppyError: If the expression is not a valid type, or carries a flag such as
            `@owned` or `@comptime`.
    """
    parsed_ty, parsed_flags = type_with_flags_from_ast(
        node, globals, param_var_mapping, allow_free_vars
    )
    if parsed_flags == InputFlags.NoFlags:
        return parsed_ty
    # `Inout` is only ever inferred, never written by the user
    assert InputFlags.Inout not in parsed_flags
    raise GuppyError(FlagNotAllowedError(node))
409
+
410
+
411
def type_row_from_ast(
    node: ast.expr, globals: "Globals", allow_free_vars: bool = False
) -> Sequence[Type]:
    """Turns a function's return-type annotation into a row of Guppy types.

    A literal `None` yields the empty row. A tuple type is flattened into its
    element types. Any other type becomes a one-element row.
    """
    # `-> None` shows up in the AST as `ast.Constant(value=None)`
    if isinstance(node, ast.Constant) and node.value is None:
        return []
    ret_ty = type_from_ast(node, globals, {}, allow_free_vars)
    return ret_ty.element_types if isinstance(ret_ty, TupleType) else [ret_ty]
@@ -0,0 +1,174 @@
1
+ from functools import singledispatchmethod
2
+
3
+ from guppylang_internals.error import InternalGuppyError
4
+ from guppylang_internals.tys.arg import ConstArg, TypeArg
5
+ from guppylang_internals.tys.const import Const, ConstValue
6
+ from guppylang_internals.tys.param import ConstParam, TypeParam
7
+ from guppylang_internals.tys.ty import (
8
+ FunctionType,
9
+ InputFlags,
10
+ NoneType,
11
+ NumericType,
12
+ OpaqueType,
13
+ StructType,
14
+ SumType,
15
+ TupleType,
16
+ Type,
17
+ )
18
+ from guppylang_internals.tys.var import BoundVar, ExistentialVar, UniqueId
19
+
20
+
21
class TypePrinter:
    """Visitor that pretty prints types.

    Takes care of inserting minimal parentheses and renaming variables to make them
    unique.
    """

    # Store how often each user-picked display name is used to stand for different
    # variables.
    # NOTE(review): not read anywhere in this class (`_fresh_name` uses `counter`);
    # kept in case outside code relies on the attribute.
    used: dict[str, int]

    # Already chosen names for bound and existential variables. `bound_names` is
    # used as a stack: a parametrized function type pushes names for its parameters
    # while it is printed and pops exactly those names afterwards.
    bound_names: list[str]
    existential_names: dict[UniqueId, str]

    # Count how often the user has picked the same name to stand for different variables
    counter: dict[str, int]

    def __init__(self) -> None:
        self.used = {}
        self.bound_names = []
        self.existential_names = {}
        self.counter = {}

    def _fresh_name(self, display_name: str) -> str:
        """Return `display_name` on first use, then `display_name'1`, `'2`, ...."""
        if display_name not in self.counter:
            self.counter[display_name] = 1
            return display_name

        # If the display name `T` has already been used, we start adding indices: `T`,
        # `T'1`, `T'2`, ...
        indexed = f"{display_name}'{self.counter[display_name]}"
        self.counter[display_name] += 1
        return indexed

    def visit(self, ty: Type | Const) -> str:
        """Pretty prints a type or constant at the top level (not inside a row)."""
        return self._visit(ty, False)

    @singledispatchmethod
    def _visit(self, ty: Type, inside_row: bool) -> str:
        raise InternalGuppyError(f"Tried to pretty-print unknown type: {ty!r}")

    @_visit.register
    def _visit_BoundVar(self, var: BoundVar, inside_row: bool) -> str:
        # Fall back to the user-facing name if no name has been chosen for this index
        if var.idx < len(self.bound_names):
            return self.bound_names[var.idx]
        return var.display_name

    @_visit.register
    def _visit_ExistentialVar(self, var: ExistentialVar, inside_row: bool) -> str:
        if var.id not in self.existential_names:
            self.existential_names[var.id] = self._fresh_name(var.display_name)
        return f"?{self.existential_names[var.id]}"

    @staticmethod
    def _print_flags(flags: InputFlags) -> str:
        """Renders the user-visible input flags as ` @owned` / ` @comptime` suffixes."""
        s = ""
        if InputFlags.Owned in flags:
            s += " @owned"
        if InputFlags.Comptime in flags:
            s += " @comptime"
        return s

    @_visit.register
    def _visit_FunctionType(self, ty: FunctionType, inside_row: bool) -> str:
        # Remember the stack height so that exactly the names pushed for this
        # function's parameters are popped again below. (Previously
        # `del self.bound_names[: -len(ty.params)]` removed the *outer* names and
        # kept the inner ones.)
        num_outer_names = len(self.bound_names)
        if ty.parametrized:
            for p in ty.params:
                self.bound_names.append(self._fresh_name(p.name))
        inputs = ", ".join(
            [
                self._visit(inp.ty, True) + self._print_flags(inp.flags)
                for inp in ty.inputs
            ]
        )
        if len(ty.inputs) != 1:
            inputs = f"({inputs})"
        output = self._visit(ty.output, True)
        if not ty.parametrized:
            return _wrap(f"{inputs} -> {output}", inside_row)

        # Parameters must be printed before their names are popped from the stack
        params = [
            self._visit(param, False)
            for param in ty.params
            # Don't print out implicit parameters generated for comptime arguments
            if not isinstance(param, ConstParam) or not param.from_comptime_arg
        ]
        del self.bound_names[num_outer_names:]
        if not params:
            # Only implicit comptime parameters: avoid printing an empty `forall . `
            return _wrap(f"{inputs} -> {output}", inside_row)
        quantified = ", ".join(params)
        return _wrap(f"forall {quantified}. {inputs} -> {output}", inside_row)

    @_visit.register(OpaqueType)
    @_visit.register(StructType)
    def _visit_OpaqueType_StructType(
        self, ty: OpaqueType | StructType, inside_row: bool
    ) -> str:
        if ty.args:
            args = ", ".join(self._visit(arg, True) for arg in ty.args)
            return f"{ty.defn.name}[{args}]"
        return ty.defn.name

    @_visit.register
    def _visit_TupleType(self, ty: TupleType, inside_row: bool) -> str:
        args = ", ".join(self._visit(arg, True) for arg in ty.args)
        return f"({args})"

    @_visit.register
    def _visit_SumType(self, ty: SumType, inside_row: bool) -> str:
        args = ", ".join(self._visit(arg, True) for arg in ty.args)
        return f"Sum[{args}]"

    @_visit.register
    def _visit_NoneType(self, ty: NoneType, inside_row: bool) -> str:
        return "None"

    @_visit.register
    def _visit_NumericType(self, ty: NumericType, inside_row: bool) -> str:
        return ty.kind.name.lower()

    @_visit.register
    def _visit_TypeParam(self, param: TypeParam, inside_row: bool) -> str:
        # TODO: Print linearity?
        return self.bound_names[param.idx]

    @_visit.register
    def _visit_ConstParam(self, param: ConstParam, inside_row: bool) -> str:
        kind = self._visit(param.ty, True)
        name = self.bound_names[param.idx]
        return f"{name}: {kind}"

    @_visit.register
    def _visit_TypeArg(self, arg: TypeArg, inside_row: bool) -> str:
        return self._visit(arg.ty, inside_row)

    @_visit.register
    def _visit_ConstArg(self, arg: ConstArg, inside_row: bool) -> str:
        return self._visit(arg.const, inside_row)

    @_visit.register
    def _visit_ConstValue(self, c: ConstValue, inside_row: bool) -> str:
        return str(c.value)
+
161
+
162
+ def _wrap(s: str, inside_row: bool) -> str:
163
+ return f"({s})" if inside_row else s
164
+
165
+
166
def signature_to_str(name: str, sig: FunctionType) -> str:
    """Renders `sig` as a Python-style `def` header for a function called `name`.

    `sig.input_names` must be set. Each input is shown as
    `<input name>: <type><flags>`, and the output type follows `->`.
    """
    assert sig.input_names is not None
    # A distinct loop variable keeps the function `name` from being shadowed
    rendered_inputs = ", ".join(
        f"{arg_name}: {inp.ty}{TypePrinter._print_flags(inp.flags)}"
        for arg_name, inp in zip(sig.input_names, sig.inputs, strict=True)
    )
    return f"def {name}({rendered_inputs}) -> " + str(sig.output)