guppylang-internals 0.22.0__py3-none-any.whl → 0.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/cfg/cfg.py +8 -0
- guppylang_internals/checker/cfg_checker.py +26 -65
- guppylang_internals/checker/core.py +8 -0
- guppylang_internals/checker/expr_checker.py +11 -25
- guppylang_internals/checker/func_checker.py +170 -21
- guppylang_internals/checker/stmt_checker.py +1 -1
- guppylang_internals/decorator.py +124 -58
- guppylang_internals/definition/const.py +2 -2
- guppylang_internals/definition/custom.py +1 -1
- guppylang_internals/definition/declaration.py +1 -1
- guppylang_internals/definition/extern.py +2 -2
- guppylang_internals/definition/function.py +1 -1
- guppylang_internals/definition/parameter.py +2 -2
- guppylang_internals/definition/pytket_circuits.py +1 -1
- guppylang_internals/definition/struct.py +10 -10
- guppylang_internals/definition/traced.py +1 -1
- guppylang_internals/definition/ty.py +6 -0
- guppylang_internals/definition/wasm.py +2 -2
- guppylang_internals/engine.py +13 -2
- guppylang_internals/nodes.py +0 -23
- guppylang_internals/std/_internal/compiler/tket_exts.py +3 -6
- guppylang_internals/std/_internal/compiler/wasm.py +37 -26
- guppylang_internals/tracing/function.py +13 -2
- guppylang_internals/tracing/unpacking.py +18 -12
- guppylang_internals/tys/builtin.py +30 -11
- guppylang_internals/tys/errors.py +6 -0
- guppylang_internals/tys/parsing.py +111 -125
- {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/METADATA +5 -5
- {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/RECORD +32 -32
- {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/WHEEL +0 -0
- {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/licenses/LICENCE +0 -0
guppylang_internals/decorator.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
-
from typing import TYPE_CHECKING, ParamSpec, TypeVar
|
|
4
|
+
from typing import TYPE_CHECKING, ParamSpec, TypeVar, overload
|
|
5
5
|
|
|
6
6
|
from hugr import ops
|
|
7
7
|
from hugr import tys as ht
|
|
@@ -42,6 +42,7 @@ from guppylang_internals.tys.ty import (
|
|
|
42
42
|
)
|
|
43
43
|
|
|
44
44
|
if TYPE_CHECKING:
|
|
45
|
+
import ast
|
|
45
46
|
import builtins
|
|
46
47
|
from collections.abc import Callable, Sequence
|
|
47
48
|
from types import FrameType
|
|
@@ -121,15 +122,19 @@ def hugr_op(
|
|
|
121
122
|
return custom_function(OpCompiler(op), checker, higher_order_value, name, signature)
|
|
122
123
|
|
|
123
124
|
|
|
124
|
-
def extend_type(defn: TypeDef, return_class: bool = False) -> Callable[[type], type]:
    """Build a class decorator that adds new instance functions to a type.

    Every `GuppyDefinition` found directly in the body of the decorated class is
    registered in `DEF_STORE` as an implementation for `defn`, keyed by the name of
    the wrapped definition.

    Args:
        defn: Type definition that receives the new instance functions.
        return_class: If `True`, the decorator returns the decorated class unchanged.
            Otherwise (the default) it returns a `GuppyDefinition` object referring
            to the type.
    """
    from guppylang.defs import GuppyDefinition

    def register_members(c: type) -> type:
        members = [m for m in c.__dict__.values() if isinstance(m, GuppyDefinition)]
        for member in members:
            DEF_STORE.register_impl(defn.id, member.wrapped.name, member.id)
        if return_class:
            return c
        # The annotation promises a class, but by default the definition is returned
        return GuppyDefinition(defn)  # type: ignore[return-value]

    return register_members
|
|
135
140
|
|
|
@@ -181,63 +186,124 @@ def custom_type(
|
|
|
181
186
|
|
|
182
187
|
|
|
183
188
|
def wasm_module(
    filename: str,
) -> Callable[[builtins.type[T]], GuppyDefinition]:
    """Class decorator that declares a WASM module backed by the file `filename`.

    The work is delegated to `ext_module_decorator`, configured with the WASM init
    and discard compilers. No module configuration string is used for WASM, so
    `None` is passed in its place.
    """

    def build_type_def(
        def_id: DefId,
        name: str,
        defined_at: ast.AST | None,
        wasm_file: str,
        config: str | None,
    ) -> OpaqueTypeDef:
        # Internal invariant: this wrapper is only ever invoked with the `None`
        # config supplied at the bottom of `wasm_module`
        assert config is None
        return WasmModuleTypeDef(def_id, name, defined_at, wasm_file)

    make_decorator = ext_module_decorator(
        build_type_def,
        WasmModuleInitCompiler(),
        WasmModuleDiscardCompiler(),
        True,  # the init function takes a nat argument
    )
    return make_decorator(filename, None)
|
|
238
206
|
|
|
207
|
+
def ext_module_decorator(
    type_def: Callable[[DefId, str, ast.AST | None, str, str | None], OpaqueTypeDef],
    init_compiler: CustomInoutCallCompiler,
    discard_compiler: CustomInoutCallCompiler,
    init_arg: bool,  # Whether the init function should take a nat argument
) -> Callable[[str, str | None], Callable[[builtins.type[T]], GuppyDefinition]]:
    """Create a factory for class decorators that declare external modules.

    Args:
        type_def: Callback that builds the opaque type definition of the module. It
            is called with a fresh definition id, the name of the decorated class,
            `None` as the definition location, the file name and the module string.
        init_compiler: Compiler attached to the generated `__new__` constructor.
        discard_compiler: Compiler attached to the generated `discard` method.
        init_arg: Whether the constructor takes one owned `nat` argument instead of
            no arguments at all.

    Returns:
        A function mapping `(filename, module)` to a class decorator. The decorator
        registers the module type, the Guppy definitions found in the class body,
        the constructor and the `discard` method in `DEF_STORE`, and returns a
        `GuppyDefinition` wrapping the module type.
    """
    from guppylang.defs import GuppyDefinition

    def fun(
        filename: str, module: str | None
    ) -> Callable[[builtins.type[T]], GuppyDefinition]:
        def dec(cls: builtins.type[T]) -> GuppyDefinition:
            class_name = cls.__name__

            # N.B. Only one module per file and vice-versa
            ext_module = type_def(DefId.fresh(), class_name, None, filename, module)
            ext_module_ty = ext_module.check_instantiate([], None)
            DEF_STORE.register_def(ext_module, get_calling_frame())

            # Guppy definitions from the class body become impls of the module type
            for item in cls.__dict__.values():
                if isinstance(item, GuppyDefinition):
                    DEF_STORE.register_impl(ext_module.id, item.wrapped.name, item.id)

            # Add a constructor to the class
            ctor_inputs = (
                [FuncInput(NumericType(NumericType.Kind.Nat), flags=InputFlags.Owned)]
                if init_arg
                else []
            )
            init_fn_ty = FunctionType(ctor_inputs, ext_module_ty)
            constructor = CustomFunctionDef(
                DefId.fresh(),
                "__new__",
                None,
                init_fn_ty,
                DefaultCallChecker(),
                init_compiler,
                True,
                GlobalConstId.fresh(f"{class_name}.__new__"),
                True,
            )

            # Method that consumes the module value again
            discard_fn_ty = FunctionType(
                [FuncInput(ext_module_ty, InputFlags.Owned)], NoneType()
            )
            discard = CustomFunctionDef(
                DefId.fresh(),
                "discard",
                None,
                discard_fn_ty,
                DefaultCallChecker(),
                discard_compiler,
                False,
                GlobalConstId.fresh(f"{class_name}.__discard__"),
                True,
            )

            # The frame is queried directly from `dec` for every registration
            DEF_STORE.register_def(constructor, get_calling_frame())
            DEF_STORE.register_impl(ext_module.id, "__new__", constructor.id)
            DEF_STORE.register_def(discard, get_calling_frame())
            DEF_STORE.register_impl(ext_module.id, "discard", discard.id)

            return GuppyDefinition(ext_module)

        return dec

    return fun
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
@overload
def wasm(arg: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: ...


@overload
def wasm(arg: int) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]: ...


def wasm(
    arg: int | Callable[P, T],
) -> (
    GuppyFunctionDefinition[P, T]
    | Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]
):
    """Declare a WASM function, usable bare (`@wasm`) or with an id (`@wasm(3)`).

    Args:
        arg: Either the function being decorated, or an integer that is forwarded
            to `wasm_helper` as the function id.

    Returns:
        The result of `wasm_helper` when applied to a function directly, otherwise
        a decorator that calls `wasm_helper` with the given id.
    """
    # Bare use: the argument already is the function to wrap
    if not isinstance(arg, int):
        return wasm_helper(None, arg)

    def with_fn_id(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
        return wasm_helper(arg, f)

    return with_fn_id
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def wasm_helper(fn_id: int | None, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
|
|
241
307
|
from guppylang.defs import GuppyFunctionDefinition
|
|
242
308
|
|
|
243
309
|
func = RawWasmFunctionDef(
|
|
@@ -246,7 +312,7 @@ def wasm(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
|
|
|
246
312
|
None,
|
|
247
313
|
f,
|
|
248
314
|
WasmCallChecker(),
|
|
249
|
-
WasmModuleCallCompiler(f.__name__),
|
|
315
|
+
WasmModuleCallCompiler(f.__name__, fn_id),
|
|
250
316
|
True,
|
|
251
317
|
signature=None,
|
|
252
318
|
)
|
|
@@ -15,7 +15,7 @@ from guppylang_internals.definition.value import (
|
|
|
15
15
|
ValueDef,
|
|
16
16
|
)
|
|
17
17
|
from guppylang_internals.span import SourceMap
|
|
18
|
-
from guppylang_internals.tys.parsing import type_from_ast
|
|
18
|
+
from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
@dataclass(frozen=True)
|
|
@@ -33,7 +33,7 @@ class RawConstDef(ParsableDef):
|
|
|
33
33
|
self.id,
|
|
34
34
|
self.name,
|
|
35
35
|
self.defined_at,
|
|
36
|
-
type_from_ast(self.type_ast, globals
|
|
36
|
+
type_from_ast(self.type_ast, TypeParsingCtx(globals)),
|
|
37
37
|
self.type_ast,
|
|
38
38
|
self.value,
|
|
39
39
|
)
|
|
@@ -169,7 +169,7 @@ class RawCustomFunctionDef(ParsableDef):
|
|
|
169
169
|
raise GuppyError(NoSignatureError(node, self.name))
|
|
170
170
|
|
|
171
171
|
if requires_type_annotation:
|
|
172
|
-
return check_signature(node, globals)
|
|
172
|
+
return check_signature(node, globals, self.id)
|
|
173
173
|
else:
|
|
174
174
|
return None
|
|
175
175
|
|
|
@@ -68,7 +68,7 @@ class RawFunctionDecl(ParsableDef):
|
|
|
68
68
|
def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
|
|
69
69
|
"""Parses and checks the user-provided signature of the function."""
|
|
70
70
|
func_ast, docstring = parse_py_func(self.python_func, sources)
|
|
71
|
-
ty = check_signature(func_ast, globals)
|
|
71
|
+
ty = check_signature(func_ast, globals, self.id)
|
|
72
72
|
if not has_empty_body(func_ast):
|
|
73
73
|
raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
|
|
74
74
|
# Make sure we won't need monomorphization to compile this declaration
|
|
@@ -14,7 +14,7 @@ from guppylang_internals.definition.value import (
|
|
|
14
14
|
ValueDef,
|
|
15
15
|
)
|
|
16
16
|
from guppylang_internals.span import SourceMap
|
|
17
|
-
from guppylang_internals.tys.parsing import type_from_ast
|
|
17
|
+
from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
@dataclass(frozen=True)
|
|
@@ -33,7 +33,7 @@ class RawExternDef(ParsableDef):
|
|
|
33
33
|
self.id,
|
|
34
34
|
self.name,
|
|
35
35
|
self.defined_at,
|
|
36
|
-
type_from_ast(self.type_ast, globals
|
|
36
|
+
type_from_ast(self.type_ast, TypeParsingCtx(globals)),
|
|
37
37
|
self.symbol,
|
|
38
38
|
self.constant,
|
|
39
39
|
self.type_ast,
|
|
@@ -73,7 +73,7 @@ class RawFunctionDef(ParsableDef):
|
|
|
73
73
|
def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
    """Parses the Python function and checks its user-provided signature.

    The signature is checked against `globals`, with the id of this definition
    passed along to `check_signature`.
    """
    parsed = parse_py_func(self.python_func, sources)
    func_ast, docstring = parsed
    signature = check_signature(func_ast, globals, self.id)
    return ParsedFunctionDef(self.id, self.name, func_ast, signature, docstring)
|
|
78
78
|
|
|
79
79
|
|
|
@@ -56,9 +56,9 @@ class RawConstVarDef(ParamDef, ParsableDef):
|
|
|
56
56
|
description: str = field(default="const variable", init=False)
|
|
57
57
|
|
|
58
58
|
def parse(self, globals: Globals, sources: SourceMap) -> "ConstVarDef":
    """Parses the type annotation of the const variable.

    Raises:
        GuppyError: With a `LinearConstVarError` if the parsed type is not both
            copyable and droppable.
    """
    from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast

    parsing_ctx = TypeParsingCtx(globals)
    ty = type_from_ast(self.type_ast, parsing_ctx)
    if not (ty.copyable and ty.droppable):
        raise GuppyError(LinearConstVarError(self.type_ast, self.name, ty))
    return ConstVarDef(self.id, self.name, self.defined_at, ty)
|
|
@@ -85,7 +85,7 @@ class RawPytketDef(ParsableDef):
|
|
|
85
85
|
if not has_empty_body(func_ast):
|
|
86
86
|
# Function stub should have empty body.
|
|
87
87
|
raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
|
|
88
|
-
stub_signature = check_signature(func_ast, globals)
|
|
88
|
+
stub_signature = check_signature(func_ast, globals, self.id)
|
|
89
89
|
|
|
90
90
|
# Compare signatures.
|
|
91
91
|
circuit_signature = _signature_from_circuit(self.input_circuit, self.defined_at)
|
|
@@ -3,7 +3,7 @@ import inspect
|
|
|
3
3
|
import linecache
|
|
4
4
|
import sys
|
|
5
5
|
from collections.abc import Sequence
|
|
6
|
-
from dataclasses import dataclass
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
7
|
from types import FrameType
|
|
8
8
|
from typing import ClassVar
|
|
9
9
|
|
|
@@ -39,7 +39,7 @@ from guppylang_internals.ipython_inspect import is_running_ipython
|
|
|
39
39
|
from guppylang_internals.span import SourceMap, Span, to_span
|
|
40
40
|
from guppylang_internals.tys.arg import Argument
|
|
41
41
|
from guppylang_internals.tys.param import Parameter, check_all_args
|
|
42
|
-
from guppylang_internals.tys.parsing import type_from_ast
|
|
42
|
+
from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast
|
|
43
43
|
from guppylang_internals.tys.ty import (
|
|
44
44
|
FuncInput,
|
|
45
45
|
FunctionType,
|
|
@@ -115,6 +115,7 @@ class RawStructDef(TypeDef, ParsableDef):
|
|
|
115
115
|
"""A raw struct type definition that has not been parsed yet."""
|
|
116
116
|
|
|
117
117
|
python_class: type
|
|
118
|
+
params: None = field(default=None, init=False) # Params not known yet
|
|
118
119
|
|
|
119
120
|
def parse(self, globals: Globals, sources: SourceMap) -> "ParsedStructDef":
|
|
120
121
|
"""Parses the raw class object into an AST and checks that it is well-formed."""
|
|
@@ -211,15 +212,16 @@ class ParsedStructDef(TypeDef, CheckableDef):
|
|
|
211
212
|
|
|
212
213
|
def check(self, globals: Globals) -> "CheckedStructDef":
|
|
213
214
|
"""Checks that all struct fields have valid types."""
|
|
215
|
+
param_var_mapping = {p.name: p for p in self.params}
|
|
216
|
+
ctx = TypeParsingCtx(globals, param_var_mapping)
|
|
217
|
+
|
|
214
218
|
# Before checking the fields, make sure that this definition is not recursive,
|
|
215
219
|
# otherwise the code below would not terminate.
|
|
216
220
|
# TODO: This is not ideal (see todo in `check_instantiate`)
|
|
217
|
-
|
|
218
|
-
check_not_recursive(self, globals, param_var_mapping)
|
|
221
|
+
check_not_recursive(self, ctx)
|
|
219
222
|
|
|
220
223
|
fields = [
|
|
221
|
-
StructField(f.name, type_from_ast(f.type_ast,
|
|
222
|
-
for f in self.fields
|
|
224
|
+
StructField(f.name, type_from_ast(f.type_ast, ctx)) for f in self.fields
|
|
223
225
|
]
|
|
224
226
|
return CheckedStructDef(
|
|
225
227
|
self.id, self.name, self.defined_at, self.params, fields
|
|
@@ -370,9 +372,7 @@ def params_from_ast(nodes: Sequence[ast.expr], globals: Globals) -> list[Paramet
|
|
|
370
372
|
return params
|
|
371
373
|
|
|
372
374
|
|
|
373
|
-
def check_not_recursive(
|
|
374
|
-
defn: ParsedStructDef, globals: Globals, param_var_mapping: dict[str, Parameter]
|
|
375
|
-
) -> None:
|
|
375
|
+
def check_not_recursive(defn: ParsedStructDef, ctx: TypeParsingCtx) -> None:
|
|
376
376
|
"""Throws a user error if the given struct definition is recursive."""
|
|
377
377
|
|
|
378
378
|
# TODO: The implementation below hijacks the type parsing logic to detect recursive
|
|
@@ -388,5 +388,5 @@ def check_not_recursive(
|
|
|
388
388
|
original = defn.check_instantiate
|
|
389
389
|
object.__setattr__(defn, "check_instantiate", dummy_check_instantiate)
|
|
390
390
|
for fld in defn.fields:
|
|
391
|
-
type_from_ast(fld.type_ast,
|
|
391
|
+
type_from_ast(fld.type_ast, ctx)
|
|
392
392
|
object.__setattr__(defn, "check_instantiate", original)
|
|
@@ -48,7 +48,7 @@ class RawTracedFunctionDef(ParsableDef):
|
|
|
48
48
|
def parse(self, globals: Globals, sources: SourceMap) -> "TracedFunctionDef":
    """Parses the Python function and checks its user-provided signature.

    Raises:
        GuppyError: With an `UnsupportedError` if the checked signature is
            parametrized.
    """
    func_ast, _docstring = parse_py_func(self.python_func, sources)
    signature = check_signature(func_ast, globals, self.id)
    if signature.parametrized:
        unsupported = UnsupportedError(func_ast, "Generic comptime functions")
        raise GuppyError(unsupported)
    return TracedFunctionDef(
        self.id, self.name, func_ast, signature, self.python_func
    )
|
|
@@ -18,6 +18,12 @@ class TypeDef(Definition):
|
|
|
18
18
|
|
|
19
19
|
description: str = field(default="type", init=False)
|
|
20
20
|
|
|
21
|
+
#: Generic parameters of the type. This may be `None` for special types that are
|
|
22
|
+
#: more polymorphic than the regular type system allows (for example `tuple` and
|
|
23
|
+
#: `Callable`), or if this is a raw definition whose parameters are not determined
|
|
24
|
+
#: yet (for example a `RawStructDef`).
|
|
25
|
+
params: Sequence[Parameter] | None
|
|
26
|
+
|
|
21
27
|
@abstractmethod
|
|
22
28
|
def check_instantiate(
|
|
23
29
|
self, args: Sequence[Argument], loc: AstNode | None = None
|
|
@@ -11,7 +11,7 @@ from guppylang_internals.definition.custom import (
|
|
|
11
11
|
)
|
|
12
12
|
from guppylang_internals.error import GuppyError
|
|
13
13
|
from guppylang_internals.span import SourceMap
|
|
14
|
-
from guppylang_internals.tys.builtin import
|
|
14
|
+
from guppylang_internals.tys.builtin import wasm_module_name
|
|
15
15
|
from guppylang_internals.tys.ty import (
|
|
16
16
|
FuncInput,
|
|
17
17
|
FunctionType,
|
|
@@ -30,7 +30,7 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
|
|
|
30
30
|
def sanitise_type(self, loc: AstNode | None, fun_ty: FunctionType) -> None:
|
|
31
31
|
# Place to highlight in error messages
|
|
32
32
|
match fun_ty.inputs[0]:
|
|
33
|
-
case FuncInput(ty=ty, flags=InputFlags.Inout) if
|
|
33
|
+
case FuncInput(ty=ty, flags=InputFlags.Inout) if wasm_module_name(
|
|
34
34
|
ty
|
|
35
35
|
) is not None:
|
|
36
36
|
pass
|
guppylang_internals/engine.py
CHANGED
|
@@ -41,6 +41,7 @@ from guppylang_internals.tys.builtin import (
|
|
|
41
41
|
nat_type_def,
|
|
42
42
|
none_type_def,
|
|
43
43
|
option_type_def,
|
|
44
|
+
self_type_def,
|
|
44
45
|
sized_iter_type_def,
|
|
45
46
|
string_type_def,
|
|
46
47
|
tuple_type_def,
|
|
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|
|
51
52
|
|
|
52
53
|
BUILTIN_DEFS_LIST: list[RawDef] = [
|
|
53
54
|
callable_type_def,
|
|
55
|
+
self_type_def,
|
|
54
56
|
tuple_type_def,
|
|
55
57
|
none_type_def,
|
|
56
58
|
bool_type_def,
|
|
@@ -84,12 +86,14 @@ class DefinitionStore:
|
|
|
84
86
|
|
|
85
87
|
raw_defs: dict[DefId, RawDef]
|
|
86
88
|
impls: defaultdict[DefId, dict[str, DefId]]
|
|
89
|
+
impl_parents: dict[DefId, DefId]
|
|
87
90
|
frames: dict[DefId, FrameType]
|
|
88
91
|
sources: SourceMap
|
|
89
92
|
|
|
90
93
|
def __init__(self) -> None:
    """Creates a store that initially only knows the builtin definitions."""
    # Builtins are indexed by their definition id
    self.raw_defs = {builtin.id: builtin for builtin in BUILTIN_DEFS_LIST}

    # Two views of the impl relation: type id -> {name -> impl id} and the reverse
    # impl id -> type id
    self.impls = defaultdict(dict)
    self.impl_parents = {}

    self.frames = {}
    self.sources = SourceMap()
|
|
95
99
|
|
|
@@ -99,7 +103,9 @@ class DefinitionStore:
|
|
|
99
103
|
self.frames[defn.id] = frame
|
|
100
104
|
|
|
101
105
|
def register_impl(self, ty_id: DefId, name: str, impl_id: DefId) -> None:
|
|
106
|
+
assert impl_id not in self.impl_parents, "Already an impl"
|
|
102
107
|
self.impls[ty_id][name] = impl_id
|
|
108
|
+
self.impl_parents[impl_id] = ty_id
|
|
103
109
|
# Update the frame of the definition to the frame of the defining class
|
|
104
110
|
if impl_id in self.frames:
|
|
105
111
|
frame = self.frames[impl_id].f_back
|
|
@@ -138,18 +144,23 @@ class CompilationEngine:
|
|
|
138
144
|
types_to_check_worklist: dict[DefId, ParsedDef]
|
|
139
145
|
to_check_worklist: dict[DefId, ParsedDef]
|
|
140
146
|
|
|
147
|
+
def __init__(self) -> None:
    """Creates an engine with an empty compilation cache and no extra extensions."""
    self.reset()
    # Initialised here rather than through the `reset()` call above
    self.additional_extensions = []
|
|
151
|
+
|
|
141
152
|
def reset(self) -> None:
    """Resets the compilation cache.

    The parsed, checked and compiled caches and both check worklists are replaced
    by fresh, independent empty dicts. No other attribute is touched.
    """
    # Results of the individual compilation stages
    self.parsed, self.checked, self.compiled = {}, {}, {}
    # Definitions that are still waiting to be checked
    self.to_check_worklist, self.types_to_check_worklist = {}, {}
|
|
149
159
|
|
|
150
160
|
@pretty_errors
def register_extension(self, extension: Extension) -> None:
    """Adds `extension` to the additional extensions unless it is already listed."""
    if extension in self.additional_extensions:
        return
    self.additional_extensions.append(extension)
|
|
153
164
|
|
|
154
165
|
@pretty_errors
|
|
155
166
|
def get_parsed(self, id: DefId) -> ParsedDef:
|
guppylang_internals/nodes.py
CHANGED
|
@@ -166,17 +166,6 @@ class MakeIter(ast.expr):
|
|
|
166
166
|
self.unwrap_size_hint = unwrap_size_hint
|
|
167
167
|
|
|
168
168
|
|
|
169
|
-
class IterHasNext(ast.expr):
|
|
170
|
-
"""Checks if an iterator has a next element using the `__hasnext__` magic method.
|
|
171
|
-
|
|
172
|
-
This node is inserted in `for` loops and list comprehensions.
|
|
173
|
-
"""
|
|
174
|
-
|
|
175
|
-
value: ast.expr
|
|
176
|
-
|
|
177
|
-
_fields = ("value",)
|
|
178
|
-
|
|
179
|
-
|
|
180
169
|
class IterNext(ast.expr):
|
|
181
170
|
"""Obtains the next element of an iterator using the `__next__` magic method.
|
|
182
171
|
|
|
@@ -188,18 +177,6 @@ class IterNext(ast.expr):
|
|
|
188
177
|
_fields = ("value",)
|
|
189
178
|
|
|
190
179
|
|
|
191
|
-
class IterEnd(ast.expr):
|
|
192
|
-
"""Finalises an iterator using the `__end__` magic method.
|
|
193
|
-
|
|
194
|
-
This node is inserted in `for` loops and list comprehensions. It is needed to
|
|
195
|
-
consume linear iterators once they are finished.
|
|
196
|
-
"""
|
|
197
|
-
|
|
198
|
-
value: ast.expr
|
|
199
|
-
|
|
200
|
-
_fields = ("value",)
|
|
201
|
-
|
|
202
|
-
|
|
203
180
|
class DesugaredGenerator(ast.expr):
|
|
204
181
|
"""A single desugared generator in a list comprehension.
|
|
205
182
|
|
|
@@ -47,16 +47,13 @@ class ConstWasmModule(val.ExtensionValue):
|
|
|
47
47
|
"""Python wrapper for the tket ConstWasmModule type"""
|
|
48
48
|
|
|
49
49
|
wasm_file: str
|
|
50
|
-
wasm_hash: int
|
|
51
50
|
|
|
52
51
|
def to_value(self) -> val.Extension:
    """Builds the extension value for this constant.

    The value is named `ConstWasmModule`, has the `module` type of
    `WASM_EXTENSION`, and carries the WASM file name as its only payload entry.
    """
    module_ty = WASM_EXTENSION.get_type("module").instantiate([])
    return val.Extension(
        "ConstWasmModule",
        typ=module_ty,
        val={"module_filename": self.wasm_file},
        extensions=["tket.wasm"],
    )
|
|
58
57
|
|
|
59
58
|
def __str__(self) -> str:
    """Returns a short description that includes the WASM file name."""
    filename = self.wasm_file
    return f"tket.wasm.module(module_filename={filename})"
|
|
@@ -8,12 +8,11 @@ from guppylang_internals.nodes import GlobalCall
|
|
|
8
8
|
from guppylang_internals.std._internal.compiler.arithmetic import convert_itousize
|
|
9
9
|
from guppylang_internals.std._internal.compiler.prelude import build_unwrap
|
|
10
10
|
from guppylang_internals.std._internal.compiler.tket_exts import (
|
|
11
|
-
FUTURES_EXTENSION,
|
|
12
11
|
WASM_EXTENSION,
|
|
13
12
|
ConstWasmModule,
|
|
14
13
|
)
|
|
15
14
|
from guppylang_internals.tys.builtin import (
|
|
16
|
-
|
|
15
|
+
wasm_module_name,
|
|
17
16
|
)
|
|
18
17
|
from guppylang_internals.tys.ty import (
|
|
19
18
|
FunctionType,
|
|
@@ -57,18 +56,20 @@ class WasmModuleDiscardCompiler(CustomInoutCallCompiler):
|
|
|
57
56
|
|
|
58
57
|
class WasmModuleCallCompiler(CustomInoutCallCompiler):
|
|
59
58
|
"""Compiler for WASM calls
|
|
60
|
-
When a wasm method is called in guppy, we turn it into
|
|
59
|
+
When a wasm method is called in guppy, we turn it into 3 tket ops:
|
|
61
60
|
* lookup: wasm.module -> wasm.func
|
|
62
|
-
* call: wasm.context * wasm.func * inputs -> wasm.
|
|
63
|
-
|
|
61
|
+
* call: wasm.context * wasm.func * inputs -> wasm.result
|
|
62
|
+
* read_result: wasm.result -> wasm.context * outputs
|
|
64
63
|
For the wasm.module that we use in lookup, a constant is created for each
|
|
65
64
|
call, using the wasm file information embedded in method's `self` argument.
|
|
66
65
|
"""
|
|
67
66
|
|
|
68
67
|
fn_name: str
|
|
68
|
+
fn_id: int | None
|
|
69
69
|
|
|
70
|
-
def __init__(self, name: str, id_: int | None = None) -> None:
    """Stores how the WASM function should be looked up.

    Args:
        name: Name of the WASM function, stored as `fn_name`.
        id_: Optional numeric id of the function, stored as `fn_id`. When it is
            `None`, `compile_with_inouts` emits a `lookup_by_name` op using
            `fn_name`; otherwise it emits `lookup_by_id`. Defaults to `None` so
            that the previous one-argument construction keeps working.
    """
    self.fn_name = name
    self.fn_id = id_
|
|
72
73
|
|
|
73
74
|
def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
|
|
74
75
|
# The arguments should be:
|
|
@@ -93,14 +94,13 @@ class WasmModuleCallCompiler(CustomInoutCallCompiler):
|
|
|
93
94
|
func_ty = WASM_EXTENSION.get_type("func").instantiate(
|
|
94
95
|
[inputs_row_arg, output_row_arg]
|
|
95
96
|
)
|
|
96
|
-
|
|
97
|
-
[ht.Tuple(*wasm_sig.output).type_arg()]
|
|
98
|
-
)
|
|
97
|
+
result_ty = WASM_EXTENSION.get_type("result").instantiate([output_row_arg])
|
|
99
98
|
|
|
100
99
|
# Get the WASM module information from the type
|
|
101
100
|
selfarg = self.func.ty.inputs[0].ty
|
|
102
|
-
|
|
103
|
-
|
|
101
|
+
info = wasm_module_name(selfarg)
|
|
102
|
+
if info is not None:
|
|
103
|
+
const_module = self.builder.add_const(ConstWasmModule(info))
|
|
104
104
|
else:
|
|
105
105
|
raise InternalGuppyError(
|
|
106
106
|
"Expected cached signature to have WASM module as first arg"
|
|
@@ -109,27 +109,38 @@ class WasmModuleCallCompiler(CustomInoutCallCompiler):
|
|
|
109
109
|
wasm_module = self.builder.load(const_module)
|
|
110
110
|
|
|
111
111
|
# Lookup the function we want
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
112
|
+
if self.fn_id is None:
|
|
113
|
+
fn_name_arg = ht.StringArg(self.fn_name)
|
|
114
|
+
wasm_opdef = WASM_EXTENSION.get_op("lookup_by_name").instantiate(
|
|
115
|
+
[fn_name_arg, inputs_row_arg, output_row_arg],
|
|
116
|
+
ht.FunctionType([module_ty], [func_ty]),
|
|
117
|
+
)
|
|
118
|
+
else:
|
|
119
|
+
fn_id_arg = ht.BoundedNatArg(self.fn_id)
|
|
120
|
+
wasm_opdef = WASM_EXTENSION.get_op("lookup_by_id").instantiate(
|
|
121
|
+
[fn_id_arg, inputs_row_arg, output_row_arg],
|
|
122
|
+
ht.FunctionType([module_ty], [func_ty]),
|
|
123
|
+
)
|
|
124
|
+
|
|
116
125
|
wasm_func = self.builder.add_op(wasm_opdef, wasm_module)
|
|
117
126
|
|
|
118
127
|
# Call the function
|
|
119
128
|
call_op = WASM_EXTENSION.get_op("call").instantiate(
|
|
120
129
|
[inputs_row_arg, output_row_arg],
|
|
121
|
-
ht.FunctionType([ctx_ty, func_ty, *wasm_sig.input], [
|
|
130
|
+
ht.FunctionType([ctx_ty, func_ty, *wasm_sig.input], [result_ty]),
|
|
122
131
|
)
|
|
123
132
|
|
|
124
|
-
|
|
133
|
+
result = self.builder.add_op(call_op, args[0], wasm_func, *args[1:])
|
|
125
134
|
|
|
126
|
-
read_opdef =
|
|
127
|
-
[
|
|
128
|
-
ht.FunctionType([
|
|
135
|
+
read_opdef = WASM_EXTENSION.get_op("read_result").instantiate(
|
|
136
|
+
[output_row_arg],
|
|
137
|
+
ht.FunctionType([result_ty], [ctx_ty, *wasm_sig.output]),
|
|
129
138
|
)
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
139
|
+
data = self.builder.add_op(read_opdef, result)
|
|
140
|
+
match list(data[:]):
|
|
141
|
+
case [ctx]:
|
|
142
|
+
return CallReturnWires(regular_returns=[], inout_returns=[ctx])
|
|
143
|
+
case [ctx, *values]:
|
|
144
|
+
return CallReturnWires(regular_returns=[*values], inout_returns=[ctx])
|
|
145
|
+
case _:
|
|
146
|
+
raise AssertionError("impossible")
|