guppylang-internals 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +3 -0
- guppylang_internals/ast_util.py +350 -0
- guppylang_internals/cfg/__init__.py +0 -0
- guppylang_internals/cfg/analysis.py +230 -0
- guppylang_internals/cfg/bb.py +221 -0
- guppylang_internals/cfg/builder.py +606 -0
- guppylang_internals/cfg/cfg.py +117 -0
- guppylang_internals/checker/__init__.py +0 -0
- guppylang_internals/checker/cfg_checker.py +388 -0
- guppylang_internals/checker/core.py +550 -0
- guppylang_internals/checker/errors/__init__.py +0 -0
- guppylang_internals/checker/errors/comptime_errors.py +106 -0
- guppylang_internals/checker/errors/generic.py +45 -0
- guppylang_internals/checker/errors/linearity.py +300 -0
- guppylang_internals/checker/errors/type_errors.py +344 -0
- guppylang_internals/checker/errors/wasm.py +34 -0
- guppylang_internals/checker/expr_checker.py +1413 -0
- guppylang_internals/checker/func_checker.py +269 -0
- guppylang_internals/checker/linearity_checker.py +821 -0
- guppylang_internals/checker/stmt_checker.py +447 -0
- guppylang_internals/compiler/__init__.py +0 -0
- guppylang_internals/compiler/cfg_compiler.py +233 -0
- guppylang_internals/compiler/core.py +613 -0
- guppylang_internals/compiler/expr_compiler.py +989 -0
- guppylang_internals/compiler/func_compiler.py +97 -0
- guppylang_internals/compiler/hugr_extension.py +224 -0
- guppylang_internals/compiler/qtm_platform_extension.py +0 -0
- guppylang_internals/compiler/stmt_compiler.py +212 -0
- guppylang_internals/decorator.py +246 -0
- guppylang_internals/definition/__init__.py +0 -0
- guppylang_internals/definition/common.py +214 -0
- guppylang_internals/definition/const.py +74 -0
- guppylang_internals/definition/custom.py +492 -0
- guppylang_internals/definition/declaration.py +171 -0
- guppylang_internals/definition/extern.py +89 -0
- guppylang_internals/definition/function.py +302 -0
- guppylang_internals/definition/overloaded.py +150 -0
- guppylang_internals/definition/parameter.py +82 -0
- guppylang_internals/definition/pytket_circuits.py +405 -0
- guppylang_internals/definition/struct.py +392 -0
- guppylang_internals/definition/traced.py +151 -0
- guppylang_internals/definition/ty.py +51 -0
- guppylang_internals/definition/value.py +115 -0
- guppylang_internals/definition/wasm.py +61 -0
- guppylang_internals/diagnostic.py +523 -0
- guppylang_internals/dummy_decorator.py +76 -0
- guppylang_internals/engine.py +295 -0
- guppylang_internals/error.py +107 -0
- guppylang_internals/experimental.py +92 -0
- guppylang_internals/ipython_inspect.py +28 -0
- guppylang_internals/nodes.py +427 -0
- guppylang_internals/py.typed +0 -0
- guppylang_internals/span.py +150 -0
- guppylang_internals/std/__init__.py +0 -0
- guppylang_internals/std/_internal/__init__.py +0 -0
- guppylang_internals/std/_internal/checker.py +573 -0
- guppylang_internals/std/_internal/compiler/__init__.py +0 -0
- guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
- guppylang_internals/std/_internal/compiler/array.py +569 -0
- guppylang_internals/std/_internal/compiler/either.py +131 -0
- guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
- guppylang_internals/std/_internal/compiler/futures.py +30 -0
- guppylang_internals/std/_internal/compiler/list.py +348 -0
- guppylang_internals/std/_internal/compiler/mem.py +13 -0
- guppylang_internals/std/_internal/compiler/option.py +78 -0
- guppylang_internals/std/_internal/compiler/prelude.py +271 -0
- guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
- guppylang_internals/std/_internal/compiler/quantum.py +118 -0
- guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
- guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
- guppylang_internals/std/_internal/compiler/wasm.py +135 -0
- guppylang_internals/std/_internal/compiler.py +0 -0
- guppylang_internals/std/_internal/debug.py +95 -0
- guppylang_internals/std/_internal/util.py +271 -0
- guppylang_internals/tracing/__init__.py +0 -0
- guppylang_internals/tracing/builtins_mock.py +62 -0
- guppylang_internals/tracing/frozenlist.py +57 -0
- guppylang_internals/tracing/function.py +186 -0
- guppylang_internals/tracing/object.py +551 -0
- guppylang_internals/tracing/state.py +69 -0
- guppylang_internals/tracing/unpacking.py +194 -0
- guppylang_internals/tracing/util.py +86 -0
- guppylang_internals/tys/__init__.py +0 -0
- guppylang_internals/tys/arg.py +115 -0
- guppylang_internals/tys/builtin.py +382 -0
- guppylang_internals/tys/common.py +110 -0
- guppylang_internals/tys/const.py +114 -0
- guppylang_internals/tys/errors.py +178 -0
- guppylang_internals/tys/param.py +251 -0
- guppylang_internals/tys/parsing.py +425 -0
- guppylang_internals/tys/printing.py +174 -0
- guppylang_internals/tys/subst.py +112 -0
- guppylang_internals/tys/ty.py +876 -0
- guppylang_internals/tys/var.py +49 -0
- guppylang_internals-0.21.0.dist-info/METADATA +253 -0
- guppylang_internals-0.21.0.dist-info/RECORD +98 -0
- guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
- guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
|
@@ -0,0 +1,425 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import sys
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from types import ModuleType
|
|
5
|
+
|
|
6
|
+
from guppylang_internals.ast_util import (
|
|
7
|
+
AstNode,
|
|
8
|
+
set_location_from,
|
|
9
|
+
shift_loc,
|
|
10
|
+
)
|
|
11
|
+
from guppylang_internals.cfg.builder import is_comptime_expression
|
|
12
|
+
from guppylang_internals.checker.core import Context, Globals, Locals, PythonObject
|
|
13
|
+
from guppylang_internals.checker.errors.generic import ExpectedError, UnsupportedError
|
|
14
|
+
from guppylang_internals.definition.common import Definition
|
|
15
|
+
from guppylang_internals.definition.parameter import ParamDef
|
|
16
|
+
from guppylang_internals.definition.ty import TypeDef
|
|
17
|
+
from guppylang_internals.engine import ENGINE
|
|
18
|
+
from guppylang_internals.error import GuppyError
|
|
19
|
+
from guppylang_internals.tys.arg import Argument, ConstArg, TypeArg
|
|
20
|
+
from guppylang_internals.tys.builtin import CallableTypeDef, bool_type
|
|
21
|
+
from guppylang_internals.tys.const import ConstValue
|
|
22
|
+
from guppylang_internals.tys.errors import (
|
|
23
|
+
CallableComptimeError,
|
|
24
|
+
ComptimeArgShadowError,
|
|
25
|
+
FlagNotAllowedError,
|
|
26
|
+
FreeTypeVarError,
|
|
27
|
+
HigherKindedTypeVarError,
|
|
28
|
+
IllegalComptimeTypeArgError,
|
|
29
|
+
InvalidCallableTypeError,
|
|
30
|
+
InvalidFlagError,
|
|
31
|
+
InvalidTypeArgError,
|
|
32
|
+
InvalidTypeError,
|
|
33
|
+
LinearComptimeError,
|
|
34
|
+
LinearConstParamError,
|
|
35
|
+
ModuleMemberNotFoundError,
|
|
36
|
+
NonLinearOwnedError,
|
|
37
|
+
)
|
|
38
|
+
from guppylang_internals.tys.param import ConstParam, Parameter, TypeParam
|
|
39
|
+
from guppylang_internals.tys.subst import BoundVarFinder
|
|
40
|
+
from guppylang_internals.tys.ty import (
|
|
41
|
+
FuncInput,
|
|
42
|
+
FunctionType,
|
|
43
|
+
InputFlags,
|
|
44
|
+
NoneType,
|
|
45
|
+
NumericType,
|
|
46
|
+
TupleType,
|
|
47
|
+
Type,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def arg_from_ast(
|
|
52
|
+
node: AstNode,
|
|
53
|
+
globals: Globals,
|
|
54
|
+
param_var_mapping: dict[str, Parameter],
|
|
55
|
+
allow_free_vars: bool = False,
|
|
56
|
+
) -> Argument:
|
|
57
|
+
"""Turns an AST expression into an argument."""
|
|
58
|
+
from guppylang_internals.checker.cfg_checker import VarNotDefinedError
|
|
59
|
+
|
|
60
|
+
# A single (possibly qualified) identifier
|
|
61
|
+
if defn := _try_parse_defn(node, globals):
|
|
62
|
+
return _arg_from_instantiated_defn(
|
|
63
|
+
defn, [], globals, node, param_var_mapping, allow_free_vars
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
# An identifier referring to a quantified variable
|
|
67
|
+
if isinstance(node, ast.Name):
|
|
68
|
+
if node.id in param_var_mapping:
|
|
69
|
+
return param_var_mapping[node.id].to_bound()
|
|
70
|
+
raise GuppyError(VarNotDefinedError(node, node.id))
|
|
71
|
+
|
|
72
|
+
# A parametrised type, e.g. `list[??]`
|
|
73
|
+
if isinstance(node, ast.Subscript) and (
|
|
74
|
+
defn := _try_parse_defn(node.value, globals)
|
|
75
|
+
):
|
|
76
|
+
arg_nodes = (
|
|
77
|
+
node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
|
|
78
|
+
)
|
|
79
|
+
return _arg_from_instantiated_defn(
|
|
80
|
+
defn, arg_nodes, globals, node, param_var_mapping, allow_free_vars
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
# We allow tuple types to be written as `(int, bool)`
|
|
84
|
+
if isinstance(node, ast.Tuple):
|
|
85
|
+
ty = TupleType(
|
|
86
|
+
[
|
|
87
|
+
type_from_ast(el, globals, param_var_mapping, allow_free_vars)
|
|
88
|
+
for el in node.elts
|
|
89
|
+
]
|
|
90
|
+
)
|
|
91
|
+
return TypeArg(ty)
|
|
92
|
+
|
|
93
|
+
# Literals
|
|
94
|
+
if isinstance(node, ast.Constant):
|
|
95
|
+
match node.value:
|
|
96
|
+
# `None` is represented as a `ast.Constant` node with value `None`
|
|
97
|
+
case None:
|
|
98
|
+
return TypeArg(NoneType())
|
|
99
|
+
case bool(v):
|
|
100
|
+
return ConstArg(ConstValue(bool_type(), v))
|
|
101
|
+
# Integer literals are turned into nat args.
|
|
102
|
+
# TODO: To support int args, we need proper inference logic here
|
|
103
|
+
# See https://github.com/CQCL/guppylang/issues/1030
|
|
104
|
+
case int(v) if v >= 0:
|
|
105
|
+
nat_ty = NumericType(NumericType.Kind.Nat)
|
|
106
|
+
return ConstArg(ConstValue(nat_ty, v))
|
|
107
|
+
case float(v):
|
|
108
|
+
float_ty = NumericType(NumericType.Kind.Float)
|
|
109
|
+
return ConstArg(ConstValue(float_ty, v))
|
|
110
|
+
# String literals are ignored for now since they could also be stringified
|
|
111
|
+
# types.
|
|
112
|
+
# TODO: To support string args, we need proper inference logic here
|
|
113
|
+
# See https://github.com/CQCL/guppylang/issues/1030
|
|
114
|
+
case str(_):
|
|
115
|
+
pass
|
|
116
|
+
|
|
117
|
+
# Py-expressions can also be used to specify static numbers
|
|
118
|
+
if comptime_expr := is_comptime_expression(node):
|
|
119
|
+
from guppylang_internals.checker.expr_checker import eval_comptime_expr
|
|
120
|
+
|
|
121
|
+
v = eval_comptime_expr(comptime_expr, Context(globals, Locals({}), {}))
|
|
122
|
+
if isinstance(v, int):
|
|
123
|
+
nat_ty = NumericType(NumericType.Kind.Nat)
|
|
124
|
+
return ConstArg(ConstValue(nat_ty, v))
|
|
125
|
+
else:
|
|
126
|
+
raise GuppyError(IllegalComptimeTypeArgError(node, v))
|
|
127
|
+
|
|
128
|
+
# Finally, we also support delayed annotations in strings
|
|
129
|
+
if isinstance(node, ast.Constant) and isinstance(node.value, str):
|
|
130
|
+
node = _parse_delayed_annotation(node.value, node)
|
|
131
|
+
return arg_from_ast(node, globals, param_var_mapping, allow_free_vars)
|
|
132
|
+
|
|
133
|
+
raise GuppyError(InvalidTypeArgError(node))
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _try_parse_defn(node: AstNode, globals: Globals) -> Definition | None:
|
|
137
|
+
"""Tries to parse a (possibly qualified) name into a global definition."""
|
|
138
|
+
from guppylang.defs import GuppyDefinition
|
|
139
|
+
from guppylang_internals.checker.cfg_checker import VarNotDefinedError
|
|
140
|
+
|
|
141
|
+
match node:
|
|
142
|
+
case ast.Name(id=x):
|
|
143
|
+
if x not in globals:
|
|
144
|
+
return None
|
|
145
|
+
defn = globals[x]
|
|
146
|
+
if isinstance(defn, PythonObject):
|
|
147
|
+
return None
|
|
148
|
+
return defn
|
|
149
|
+
case ast.Attribute(value=ast.Name(id=module_name) as value, attr=x):
|
|
150
|
+
if module_name not in globals:
|
|
151
|
+
raise GuppyError(VarNotDefinedError(value, module_name))
|
|
152
|
+
match globals[module_name]:
|
|
153
|
+
case PythonObject(ModuleType() as module):
|
|
154
|
+
if x in module.__dict__:
|
|
155
|
+
val = module.__dict__[x]
|
|
156
|
+
if isinstance(val, GuppyDefinition):
|
|
157
|
+
return ENGINE.get_parsed(val.id)
|
|
158
|
+
raise GuppyError(
|
|
159
|
+
ModuleMemberNotFoundError(node, module.__name__, x)
|
|
160
|
+
)
|
|
161
|
+
case _:
|
|
162
|
+
raise GuppyError(ExpectedError(value, "a module"))
|
|
163
|
+
case _:
|
|
164
|
+
return None
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _arg_from_instantiated_defn(
    defn: Definition,
    arg_nodes: list[ast.expr],
    globals: Globals,
    node: AstNode,
    param_var_mapping: dict[str, Parameter],
    allow_free_vars: bool = False,
) -> Argument:
    """Parses a globals definition with type args into an argument.

    `Callable` is parsed specially, other type definitions are instantiated with
    the arguments parsed from `arg_nodes`, and parameter definitions resolve to
    bound variables. Any other kind of definition is rejected since a type was
    expected.
    """
    # `Callable` is checked first so it takes precedence over the generic `TypeDef`
    # case below.
    if isinstance(defn, CallableTypeDef):
        fun_ty = _parse_callable_type(
            arg_nodes, node, globals, param_var_mapping, allow_free_vars
        )
        return TypeArg(fun_ty)

    # A defined type (e.g. `int`, `bool`, ...)
    if isinstance(defn, TypeDef):
        type_args = [
            arg_from_ast(arg_node, globals, param_var_mapping, allow_free_vars)
            for arg_node in arg_nodes
        ]
        return TypeArg(defn.check_instantiate(type_args, node))

    # A parameter (e.g. `T`, `n`, ...)
    if isinstance(defn, ParamDef):
        # Parametrised variables like `T[int]` are not allowed
        if arg_nodes:
            raise GuppyError(HigherKindedTypeVarError(node, defn))
        if defn.name not in param_var_mapping:
            if not allow_free_vars:
                raise GuppyError(FreeTypeVarError(node, defn))
            # Bind the free variable to a new parameter at the next free index
            param_var_mapping[defn.name] = defn.to_param(len(param_var_mapping))
        return param_var_mapping[defn.name].to_bound()

    raise GuppyError(
        ExpectedError(node, "a type", got=f"{defn.description} `{defn.name}`")
    )
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _parse_delayed_annotation(ast_str: str, node: ast.Constant) -> ast.expr:
    """Parses a delayed type annotation in a string.

    The string must contain exactly one expression statement. The locations of the
    parsed nodes are shifted so they point into the string literal `node`. Any
    parse failure is reported as an `InvalidTypeError` at `node`.
    """
    try:
        (expr_stmt,) = ast.parse(ast_str).body
        if not isinstance(expr_stmt, ast.Expr):
            raise GuppyError(InvalidTypeError(node))
        set_location_from(expr_stmt, loc=node)
        # NOTE(review): the column shift assumes a single-character opening quote
        # (no string prefix, not triple-quoted).
        shift_loc(
            expr_stmt,
            delta_lineno=node.lineno - 1,  # -1 since lines start at 1
            delta_col_offset=node.col_offset + 1,  # +1 to remove the `"`
        )
    except (SyntaxError, ValueError):
        raise GuppyError(InvalidTypeError(node)) from None
    return expr_stmt.value
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _parse_callable_type(
|
|
227
|
+
args: list[ast.expr],
|
|
228
|
+
loc: AstNode,
|
|
229
|
+
globals: Globals,
|
|
230
|
+
param_var_mapping: dict[str, Parameter],
|
|
231
|
+
allow_free_vars: bool = False,
|
|
232
|
+
) -> FunctionType:
|
|
233
|
+
"""Helper function to parse a `Callable[[<arguments>], <return type>]` type."""
|
|
234
|
+
err = InvalidCallableTypeError(loc)
|
|
235
|
+
if len(args) != 2:
|
|
236
|
+
raise GuppyError(err)
|
|
237
|
+
[inputs, output] = args
|
|
238
|
+
if not isinstance(inputs, ast.List):
|
|
239
|
+
raise GuppyError(err)
|
|
240
|
+
inouts, output = parse_function_io_types(
|
|
241
|
+
inputs.elts, output, None, loc, globals, param_var_mapping, allow_free_vars
|
|
242
|
+
)
|
|
243
|
+
return FunctionType(inouts, output)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def parse_function_io_types(
    input_nodes: list[ast.expr],
    output_node: ast.expr,
    input_names: list[str] | None,
    loc: AstNode,
    globals: Globals,
    param_var_mapping: dict[str, Parameter],
    allow_free_vars: bool = False,
) -> tuple[list[FuncInput], Type]:
    """Parses the inputs and output types of a function type.

    This function takes care of parsing annotations and any related checks.
    Non-copyable inputs that are not annotated with `@owned` get the `Inout` flag.
    Every `@comptime` input also registers an implicit const parameter in
    `param_var_mapping` under the input's name. This requires `input_names` to be
    given; otherwise a `CallableComptimeError` is raised.

    Returns the parsed input and output types.
    """
    parsed_inputs: list[FuncInput] = []
    for idx, inp_node in enumerate(input_nodes):
        inp_ty, inp_flags = type_with_flags_from_ast(
            inp_node, globals, param_var_mapping, allow_free_vars
        )
        owned = InputFlags.Owned in inp_flags
        if owned and inp_ty.copyable:
            raise GuppyError(NonLinearOwnedError(loc, inp_ty))
        if not owned and not inp_ty.copyable:
            inp_flags |= InputFlags.Inout

        if InputFlags.Comptime in inp_flags:
            if input_names is None:
                raise GuppyError(CallableComptimeError(inp_node))
            comptime_name = input_names[idx]
            # Make sure we're not shadowing a type variable with the same name that
            # was already used on the left. E.g.
            #
            #   n = guppy.type_var("n")
            #   def foo(xs: array[int, n], n: nat @comptime)
            #
            # TODO: In principle we could lift this restriction by tracking multiple
            #  params referring to the same name in `param_var_mapping`, but not sure
            #  if this would be worth it...
            if comptime_name in param_var_mapping:
                raise GuppyError(ComptimeArgShadowError(inp_node, comptime_name))
            param_var_mapping[comptime_name] = ConstParam(
                len(param_var_mapping), comptime_name, inp_ty, from_comptime_arg=True
            )

        parsed_inputs.append(FuncInput(inp_ty, inp_flags))

    output_ty = type_from_ast(output_node, globals, param_var_mapping, allow_free_vars)
    return parsed_inputs, output_ty
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
if sys.version_info >= (3, 12):
|
|
296
|
+
|
|
297
|
+
def parse_parameter(node: ast.type_param, idx: int, globals: Globals) -> Parameter:
|
|
298
|
+
"""Parses a `Variable: Bound` generic type parameter declaration."""
|
|
299
|
+
if isinstance(node, ast.TypeVarTuple | ast.ParamSpec):
|
|
300
|
+
raise GuppyError(UnsupportedError(node, "Variadic generic parameters"))
|
|
301
|
+
assert isinstance(node, ast.TypeVar)
|
|
302
|
+
|
|
303
|
+
match node.bound:
|
|
304
|
+
# No bound means it's a linear type parameter
|
|
305
|
+
case None:
|
|
306
|
+
return TypeParam(
|
|
307
|
+
idx, node.name, must_be_copyable=False, must_be_droppable=False
|
|
308
|
+
)
|
|
309
|
+
# Special `Copy` or `Drop` bounds for types
|
|
310
|
+
case ast.Name(id="Copy"):
|
|
311
|
+
return TypeParam(
|
|
312
|
+
idx, node.name, must_be_copyable=True, must_be_droppable=False
|
|
313
|
+
)
|
|
314
|
+
case ast.Name(id="Drop"):
|
|
315
|
+
return TypeParam(
|
|
316
|
+
idx, node.name, must_be_copyable=False, must_be_droppable=True
|
|
317
|
+
)
|
|
318
|
+
# Copy and drop is annotated as `T: (Copy, Drop)`
|
|
319
|
+
# TODO: Should we also allow `T: Copy + Drop`? Mypy would complain about it
|
|
320
|
+
case ast.Tuple(elts=[ast.Name(id=id1), ast.Name(id=id2)]) if {id1, id2} == {
|
|
321
|
+
"Copy",
|
|
322
|
+
"Drop",
|
|
323
|
+
}:
|
|
324
|
+
return TypeParam(
|
|
325
|
+
idx, node.name, must_be_copyable=True, must_be_droppable=True
|
|
326
|
+
)
|
|
327
|
+
# Otherwise, it must be a const parameter
|
|
328
|
+
case bound:
|
|
329
|
+
# For now, we don't allow the types of const params to refer to previous
|
|
330
|
+
# parameters, so we pass an empty dict as the `param_var_mapping`.
|
|
331
|
+
# TODO: In the future we might want to allow stuff like
|
|
332
|
+
# `def foo[T, XS: array[T, 42]]` and so on
|
|
333
|
+
ty = type_from_ast(bound, globals, {}, allow_free_vars=False)
|
|
334
|
+
if not ty.copyable or not ty.droppable:
|
|
335
|
+
raise GuppyError(LinearConstParamError(bound, ty))
|
|
336
|
+
|
|
337
|
+
# TODO: For now we can only do `nat` const args since they lower to
|
|
338
|
+
# Hugr bounded nats. Extend to arbitrary types via monomorphization.
|
|
339
|
+
# See https://github.com/CQCL/guppylang/issues/1008
|
|
340
|
+
if ty != NumericType(NumericType.Kind.Nat):
|
|
341
|
+
raise GuppyError(
|
|
342
|
+
UnsupportedError(bound, f"`{ty}` generic parameters")
|
|
343
|
+
)
|
|
344
|
+
return ConstParam(idx, node.name, ty)
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
# Dummy type parameter. `type_with_flags_from_ast` calls its `check_arg` to ensure
# that a parsed argument is a type. NOTE(review): the two `False` arguments
# presumably mean no copy/drop requirements; confirm against `TypeParam`'s fields.
_type_param = TypeParam(0, "T", False, False)
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def type_with_flags_from_ast(
|
|
351
|
+
node: AstNode,
|
|
352
|
+
globals: Globals,
|
|
353
|
+
param_var_mapping: dict[str, Parameter],
|
|
354
|
+
allow_free_vars: bool = False,
|
|
355
|
+
) -> tuple[Type, InputFlags]:
|
|
356
|
+
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
|
|
357
|
+
ty, flags = type_with_flags_from_ast(
|
|
358
|
+
node.left, globals, param_var_mapping, allow_free_vars
|
|
359
|
+
)
|
|
360
|
+
match node.right:
|
|
361
|
+
case ast.Name(id="owned"):
|
|
362
|
+
if ty.copyable:
|
|
363
|
+
raise GuppyError(NonLinearOwnedError(node.right, ty))
|
|
364
|
+
flags |= InputFlags.Owned
|
|
365
|
+
case ast.Name(id="comptime"):
|
|
366
|
+
flags |= InputFlags.Comptime
|
|
367
|
+
if not ty.copyable or not ty.droppable:
|
|
368
|
+
raise GuppyError(LinearComptimeError(node.right, ty))
|
|
369
|
+
# For now, we don't allow comptime annotations on generic inputs
|
|
370
|
+
# TODO: In the future we might want to allow stuff like
|
|
371
|
+
# `def foo[T: (Copy, Discard](x: T @comptime)`.
|
|
372
|
+
# Also see the todo in `parse_parameter`.
|
|
373
|
+
var_finder = BoundVarFinder()
|
|
374
|
+
ty.visit(var_finder)
|
|
375
|
+
if var_finder.bound_vars:
|
|
376
|
+
raise GuppyError(
|
|
377
|
+
UnsupportedError(node.left, "Generic comptime arguments")
|
|
378
|
+
)
|
|
379
|
+
case _:
|
|
380
|
+
raise GuppyError(InvalidFlagError(node.right))
|
|
381
|
+
return ty, flags
|
|
382
|
+
# We also need to handle the case that this could be a delayed string annotation
|
|
383
|
+
elif isinstance(node, ast.Constant) and isinstance(node.value, str):
|
|
384
|
+
node = _parse_delayed_annotation(node.value, node)
|
|
385
|
+
return type_with_flags_from_ast(
|
|
386
|
+
node, globals, param_var_mapping, allow_free_vars
|
|
387
|
+
)
|
|
388
|
+
else:
|
|
389
|
+
# Parse an argument and check that it's valid for a `TypeParam`
|
|
390
|
+
arg = arg_from_ast(node, globals, param_var_mapping, allow_free_vars)
|
|
391
|
+
tyarg = _type_param.check_arg(arg, node)
|
|
392
|
+
return tyarg.ty, InputFlags.NoFlags
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def type_from_ast(
    node: AstNode,
    globals: Globals,
    param_var_mapping: dict[str, Parameter],
    allow_free_vars: bool = False,
) -> Type:
    """Turns an AST expression into a Guppy type.

    Unlike `type_with_flags_from_ast`, flag annotations such as `@owned` are not
    permitted here and result in a `FlagNotAllowedError`.
    """
    ty, flags = type_with_flags_from_ast(
        node, globals, param_var_mapping, allow_free_vars
    )
    if flags == InputFlags.NoFlags:
        return ty
    assert InputFlags.Inout not in flags  # Users shouldn't be able to set this
    raise GuppyError(FlagNotAllowedError(node))
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def type_row_from_ast(
|
|
412
|
+
node: ast.expr, globals: "Globals", allow_free_vars: bool = False
|
|
413
|
+
) -> Sequence[Type]:
|
|
414
|
+
"""Turns an AST expression into a Guppy type row.
|
|
415
|
+
|
|
416
|
+
This is needed to interpret the return type annotation of functions.
|
|
417
|
+
"""
|
|
418
|
+
# The return type `-> None` is represented in the ast as `ast.Constant(value=None)`
|
|
419
|
+
if isinstance(node, ast.Constant) and node.value is None:
|
|
420
|
+
return []
|
|
421
|
+
ty = type_from_ast(node, globals, {}, allow_free_vars)
|
|
422
|
+
if isinstance(ty, TupleType):
|
|
423
|
+
return ty.element_types
|
|
424
|
+
else:
|
|
425
|
+
return [ty]
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
from functools import singledispatchmethod
|
|
2
|
+
|
|
3
|
+
from guppylang_internals.error import InternalGuppyError
|
|
4
|
+
from guppylang_internals.tys.arg import ConstArg, TypeArg
|
|
5
|
+
from guppylang_internals.tys.const import Const, ConstValue
|
|
6
|
+
from guppylang_internals.tys.param import ConstParam, TypeParam
|
|
7
|
+
from guppylang_internals.tys.ty import (
|
|
8
|
+
FunctionType,
|
|
9
|
+
InputFlags,
|
|
10
|
+
NoneType,
|
|
11
|
+
NumericType,
|
|
12
|
+
OpaqueType,
|
|
13
|
+
StructType,
|
|
14
|
+
SumType,
|
|
15
|
+
TupleType,
|
|
16
|
+
Type,
|
|
17
|
+
)
|
|
18
|
+
from guppylang_internals.tys.var import BoundVar, ExistentialVar, UniqueId
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TypePrinter:
    """Visitor that pretty prints types.

    Takes care of inserting minimal parentheses and renaming variables to make them
    unique.
    """

    # Initialised in `__init__` but not read by any method of this class; renaming
    # bookkeeping is done via `counter`. Kept for compatibility.
    used: dict[str, int]

    # Already chosen names for bound and existential variables. `bound_names` is
    # used as a stack: a parametrised function type pushes one name per parameter
    # while it is being printed and pops exactly those names afterwards.
    bound_names: list[str]
    existential_names: dict[UniqueId, str]

    # Count how often the user has picked the same name to stand for different variables
    counter: dict[str, int]

    def __init__(self) -> None:
        self.used = {}
        self.bound_names = []
        self.existential_names = {}
        self.counter = {}

    def _fresh_name(self, display_name: str) -> str:
        """Returns a name based on `display_name` that was not handed out before."""
        if display_name not in self.counter:
            self.counter[display_name] = 1
            return display_name

        # If the display name `T` has already been used, we start adding indices: `T`,
        # `T'1`, `T'2`, ...
        indexed = f"{display_name}'{self.counter[display_name]}"
        self.counter[display_name] += 1
        return indexed

    def visit(self, ty: Type | Const) -> str:
        """Pretty prints a type or constant at the top level (outside any row)."""
        return self._visit(ty, False)

    @singledispatchmethod
    def _visit(self, ty: Type, inside_row: bool) -> str:
        raise InternalGuppyError(f"Tried to pretty-print unknown type: {ty!r}")

    @_visit.register
    def _visit_BoundVar(self, var: BoundVar, inside_row: bool) -> str:
        # Use the chosen name if the variable is bound by a function type that is
        # currently being printed, otherwise fall back to its own display name.
        if var.idx < len(self.bound_names):
            return self.bound_names[var.idx]
        return var.display_name

    @_visit.register
    def _visit_ExistentialVar(self, var: ExistentialVar, inside_row: bool) -> str:
        if var.id not in self.existential_names:
            self.existential_names[var.id] = self._fresh_name(var.display_name)
        return f"?{self.existential_names[var.id]}"

    @staticmethod
    def _print_flags(flags: InputFlags) -> str:
        """Renders the user-facing input flags as ` @owned` / ` @comptime` suffixes."""
        s = ""
        if InputFlags.Owned in flags:
            s += " @owned"
        if InputFlags.Comptime in flags:
            s += " @comptime"
        return s

    @_visit.register
    def _visit_FunctionType(self, ty: FunctionType, inside_row: bool) -> str:
        # Remember the stack height before binding this function's quantifiers, so
        # that exactly the names introduced here are popped again afterwards.
        # (Previously `del self.bound_names[: -len(ty.params)]` removed everything
        # *except* these names, corrupting names of enclosing function types.)
        num_outer_names = len(self.bound_names)
        if ty.parametrized:
            for p in ty.params:
                self.bound_names.append(self._fresh_name(p.name))
        inputs = ", ".join(
            [
                self._visit(inp.ty, True) + self._print_flags(inp.flags)
                for inp in ty.inputs
            ]
        )
        if len(ty.inputs) != 1:
            inputs = f"({inputs})"
        output = self._visit(ty.output, True)
        if ty.parametrized:
            params = [
                self._visit(param, False)
                for param in ty.params
                # Don't print out implicit parameters generated for comptime arguments
                if not isinstance(param, ConstParam) or not param.from_comptime_arg
            ]
            quantified = ", ".join(params)
            del self.bound_names[num_outer_names:]
            return _wrap(f"forall {quantified}. {inputs} -> {output}", inside_row)
        return _wrap(f"{inputs} -> {output}", inside_row)

    @_visit.register(OpaqueType)
    @_visit.register(StructType)
    def _visit_OpaqueType_StructType(
        self, ty: OpaqueType | StructType, inside_row: bool
    ) -> str:
        if ty.args:
            args = ", ".join(self._visit(arg, True) for arg in ty.args)
            return f"{ty.defn.name}[{args}]"
        return ty.defn.name

    @_visit.register
    def _visit_TupleType(self, ty: TupleType, inside_row: bool) -> str:
        args = ", ".join(self._visit(arg, True) for arg in ty.args)
        return f"({args})"

    @_visit.register
    def _visit_SumType(self, ty: SumType, inside_row: bool) -> str:
        args = ", ".join(self._visit(arg, True) for arg in ty.args)
        return f"Sum[{args}]"

    @_visit.register
    def _visit_NoneType(self, ty: NoneType, inside_row: bool) -> str:
        return "None"

    @_visit.register
    def _visit_NumericType(self, ty: NumericType, inside_row: bool) -> str:
        return ty.kind.name.lower()

    @_visit.register
    def _visit_TypeParam(self, param: TypeParam, inside_row: bool) -> str:
        # TODO: Print linearity?
        return self.bound_names[param.idx]

    @_visit.register
    def _visit_ConstParam(self, param: ConstParam, inside_row: bool) -> str:
        kind = self._visit(param.ty, True)
        name = self.bound_names[param.idx]
        return f"{name}: {kind}"

    @_visit.register
    def _visit_TypeArg(self, arg: TypeArg, inside_row: bool) -> str:
        return self._visit(arg.ty, inside_row)

    @_visit.register
    def _visit_ConstArg(self, arg: ConstArg, inside_row: bool) -> str:
        return self._visit(arg.const, inside_row)

    @_visit.register
    def _visit_ConstValue(self, c: ConstValue, inside_row: bool) -> str:
        return str(c.value)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _wrap(s: str, inside_row: bool) -> str:
|
|
163
|
+
return f"({s})" if inside_row else s
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def signature_to_str(name: str, sig: FunctionType) -> str:
    """Displays a function signature in Python syntax including the function name.

    Each input is rendered as `<input name>: <type>` followed by any `@owned` /
    `@comptime` flags. Requires `sig.input_names` to be set.
    """
    assert sig.input_names is not None
    # Use a distinct loop name so the function-name parameter isn't shadowed
    rendered_inputs = ", ".join(
        f"{arg_name}: {inp.ty}{TypePrinter._print_flags(inp.flags)}"
        for arg_name, inp in zip(sig.input_names, sig.inputs, strict=True)
    )
    return f"def {name}({rendered_inputs}) -> {sig.output}"
|