guppylang-internals 0.23.0__py3-none-any.whl → 0.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/checker/core.py +8 -0
- guppylang_internals/checker/expr_checker.py +10 -20
- guppylang_internals/checker/func_checker.py +170 -21
- guppylang_internals/checker/stmt_checker.py +1 -1
- guppylang_internals/decorator.py +124 -58
- guppylang_internals/definition/const.py +2 -2
- guppylang_internals/definition/custom.py +1 -1
- guppylang_internals/definition/declaration.py +1 -1
- guppylang_internals/definition/extern.py +2 -2
- guppylang_internals/definition/function.py +1 -1
- guppylang_internals/definition/parameter.py +2 -2
- guppylang_internals/definition/pytket_circuits.py +1 -1
- guppylang_internals/definition/struct.py +10 -10
- guppylang_internals/definition/traced.py +1 -1
- guppylang_internals/definition/ty.py +6 -0
- guppylang_internals/definition/wasm.py +2 -2
- guppylang_internals/engine.py +13 -2
- guppylang_internals/nodes.py +0 -23
- guppylang_internals/std/_internal/compiler/tket_exts.py +3 -6
- guppylang_internals/std/_internal/compiler/wasm.py +37 -26
- guppylang_internals/tracing/function.py +13 -2
- guppylang_internals/tracing/unpacking.py +18 -12
- guppylang_internals/tys/builtin.py +30 -11
- guppylang_internals/tys/errors.py +6 -0
- guppylang_internals/tys/parsing.py +111 -125
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.24.0.dist-info}/METADATA +3 -3
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.24.0.dist-info}/RECORD +30 -30
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.24.0.dist-info}/WHEEL +0 -0
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.24.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -3,7 +3,7 @@ import inspect
|
|
|
3
3
|
import linecache
|
|
4
4
|
import sys
|
|
5
5
|
from collections.abc import Sequence
|
|
6
|
-
from dataclasses import dataclass
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
7
|
from types import FrameType
|
|
8
8
|
from typing import ClassVar
|
|
9
9
|
|
|
@@ -39,7 +39,7 @@ from guppylang_internals.ipython_inspect import is_running_ipython
|
|
|
39
39
|
from guppylang_internals.span import SourceMap, Span, to_span
|
|
40
40
|
from guppylang_internals.tys.arg import Argument
|
|
41
41
|
from guppylang_internals.tys.param import Parameter, check_all_args
|
|
42
|
-
from guppylang_internals.tys.parsing import type_from_ast
|
|
42
|
+
from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast
|
|
43
43
|
from guppylang_internals.tys.ty import (
|
|
44
44
|
FuncInput,
|
|
45
45
|
FunctionType,
|
|
@@ -115,6 +115,7 @@ class RawStructDef(TypeDef, ParsableDef):
|
|
|
115
115
|
"""A raw struct type definition that has not been parsed yet."""
|
|
116
116
|
|
|
117
117
|
python_class: type
|
|
118
|
+
params: None = field(default=None, init=False) # Params not known yet
|
|
118
119
|
|
|
119
120
|
def parse(self, globals: Globals, sources: SourceMap) -> "ParsedStructDef":
|
|
120
121
|
"""Parses the raw class object into an AST and checks that it is well-formed."""
|
|
@@ -211,15 +212,16 @@ class ParsedStructDef(TypeDef, CheckableDef):
|
|
|
211
212
|
|
|
212
213
|
def check(self, globals: Globals) -> "CheckedStructDef":
|
|
213
214
|
"""Checks that all struct fields have valid types."""
|
|
215
|
+
param_var_mapping = {p.name: p for p in self.params}
|
|
216
|
+
ctx = TypeParsingCtx(globals, param_var_mapping)
|
|
217
|
+
|
|
214
218
|
# Before checking the fields, make sure that this definition is not recursive,
|
|
215
219
|
# otherwise the code below would not terminate.
|
|
216
220
|
# TODO: This is not ideal (see todo in `check_instantiate`)
|
|
217
|
-
|
|
218
|
-
check_not_recursive(self, globals, param_var_mapping)
|
|
221
|
+
check_not_recursive(self, ctx)
|
|
219
222
|
|
|
220
223
|
fields = [
|
|
221
|
-
StructField(f.name, type_from_ast(f.type_ast,
|
|
222
|
-
for f in self.fields
|
|
224
|
+
StructField(f.name, type_from_ast(f.type_ast, ctx)) for f in self.fields
|
|
223
225
|
]
|
|
224
226
|
return CheckedStructDef(
|
|
225
227
|
self.id, self.name, self.defined_at, self.params, fields
|
|
@@ -370,9 +372,7 @@ def params_from_ast(nodes: Sequence[ast.expr], globals: Globals) -> list[Paramet
|
|
|
370
372
|
return params
|
|
371
373
|
|
|
372
374
|
|
|
373
|
-
def check_not_recursive(
|
|
374
|
-
defn: ParsedStructDef, globals: Globals, param_var_mapping: dict[str, Parameter]
|
|
375
|
-
) -> None:
|
|
375
|
+
def check_not_recursive(defn: ParsedStructDef, ctx: TypeParsingCtx) -> None:
|
|
376
376
|
"""Throws a user error if the given struct definition is recursive."""
|
|
377
377
|
|
|
378
378
|
# TODO: The implementation below hijacks the type parsing logic to detect recursive
|
|
@@ -388,5 +388,5 @@ def check_not_recursive(
|
|
|
388
388
|
original = defn.check_instantiate
|
|
389
389
|
object.__setattr__(defn, "check_instantiate", dummy_check_instantiate)
|
|
390
390
|
for fld in defn.fields:
|
|
391
|
-
type_from_ast(fld.type_ast,
|
|
391
|
+
type_from_ast(fld.type_ast, ctx)
|
|
392
392
|
object.__setattr__(defn, "check_instantiate", original)
|
|
@@ -48,7 +48,7 @@ class RawTracedFunctionDef(ParsableDef):
|
|
|
48
48
|
def parse(self, globals: Globals, sources: SourceMap) -> "TracedFunctionDef":
|
|
49
49
|
"""Parses and checks the user-provided signature of the function."""
|
|
50
50
|
func_ast, _docstring = parse_py_func(self.python_func, sources)
|
|
51
|
-
ty = check_signature(func_ast, globals)
|
|
51
|
+
ty = check_signature(func_ast, globals, self.id)
|
|
52
52
|
if ty.parametrized:
|
|
53
53
|
raise GuppyError(UnsupportedError(func_ast, "Generic comptime functions"))
|
|
54
54
|
return TracedFunctionDef(self.id, self.name, func_ast, ty, self.python_func)
|
|
@@ -18,6 +18,12 @@ class TypeDef(Definition):
|
|
|
18
18
|
|
|
19
19
|
description: str = field(default="type", init=False)
|
|
20
20
|
|
|
21
|
+
#: Generic parameters of the type. This may be `None` for special types that are
|
|
22
|
+
#: more polymorphic than the regular type system allows (for example `tuple` and
|
|
23
|
+
#: `Callable`), or if this is a raw definition whose parameters are not determined
|
|
24
|
+
#: yet (for example a `RawStructDef`).
|
|
25
|
+
params: Sequence[Parameter] | None
|
|
26
|
+
|
|
21
27
|
@abstractmethod
|
|
22
28
|
def check_instantiate(
|
|
23
29
|
self, args: Sequence[Argument], loc: AstNode | None = None
|
|
@@ -11,7 +11,7 @@ from guppylang_internals.definition.custom import (
|
|
|
11
11
|
)
|
|
12
12
|
from guppylang_internals.error import GuppyError
|
|
13
13
|
from guppylang_internals.span import SourceMap
|
|
14
|
-
from guppylang_internals.tys.builtin import
|
|
14
|
+
from guppylang_internals.tys.builtin import wasm_module_name
|
|
15
15
|
from guppylang_internals.tys.ty import (
|
|
16
16
|
FuncInput,
|
|
17
17
|
FunctionType,
|
|
@@ -30,7 +30,7 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
|
|
|
30
30
|
def sanitise_type(self, loc: AstNode | None, fun_ty: FunctionType) -> None:
|
|
31
31
|
# Place to highlight in error messages
|
|
32
32
|
match fun_ty.inputs[0]:
|
|
33
|
-
case FuncInput(ty=ty, flags=InputFlags.Inout) if
|
|
33
|
+
case FuncInput(ty=ty, flags=InputFlags.Inout) if wasm_module_name(
|
|
34
34
|
ty
|
|
35
35
|
) is not None:
|
|
36
36
|
pass
|
guppylang_internals/engine.py
CHANGED
|
@@ -41,6 +41,7 @@ from guppylang_internals.tys.builtin import (
|
|
|
41
41
|
nat_type_def,
|
|
42
42
|
none_type_def,
|
|
43
43
|
option_type_def,
|
|
44
|
+
self_type_def,
|
|
44
45
|
sized_iter_type_def,
|
|
45
46
|
string_type_def,
|
|
46
47
|
tuple_type_def,
|
|
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|
|
51
52
|
|
|
52
53
|
BUILTIN_DEFS_LIST: list[RawDef] = [
|
|
53
54
|
callable_type_def,
|
|
55
|
+
self_type_def,
|
|
54
56
|
tuple_type_def,
|
|
55
57
|
none_type_def,
|
|
56
58
|
bool_type_def,
|
|
@@ -84,12 +86,14 @@ class DefinitionStore:
|
|
|
84
86
|
|
|
85
87
|
raw_defs: dict[DefId, RawDef]
|
|
86
88
|
impls: defaultdict[DefId, dict[str, DefId]]
|
|
89
|
+
impl_parents: dict[DefId, DefId]
|
|
87
90
|
frames: dict[DefId, FrameType]
|
|
88
91
|
sources: SourceMap
|
|
89
92
|
|
|
90
93
|
def __init__(self) -> None:
|
|
91
94
|
self.raw_defs = {defn.id: defn for defn in BUILTIN_DEFS_LIST}
|
|
92
95
|
self.impls = defaultdict(dict)
|
|
96
|
+
self.impl_parents = {}
|
|
93
97
|
self.frames = {}
|
|
94
98
|
self.sources = SourceMap()
|
|
95
99
|
|
|
@@ -99,7 +103,9 @@ class DefinitionStore:
|
|
|
99
103
|
self.frames[defn.id] = frame
|
|
100
104
|
|
|
101
105
|
def register_impl(self, ty_id: DefId, name: str, impl_id: DefId) -> None:
|
|
106
|
+
assert impl_id not in self.impl_parents, "Already an impl"
|
|
102
107
|
self.impls[ty_id][name] = impl_id
|
|
108
|
+
self.impl_parents[impl_id] = ty_id
|
|
103
109
|
# Update the frame of the definition to the frame of the defining class
|
|
104
110
|
if impl_id in self.frames:
|
|
105
111
|
frame = self.frames[impl_id].f_back
|
|
@@ -138,18 +144,23 @@ class CompilationEngine:
|
|
|
138
144
|
types_to_check_worklist: dict[DefId, ParsedDef]
|
|
139
145
|
to_check_worklist: dict[DefId, ParsedDef]
|
|
140
146
|
|
|
147
|
+
def __init__(self) -> None:
|
|
148
|
+
"""Resets the compilation cache."""
|
|
149
|
+
self.reset()
|
|
150
|
+
self.additional_extensions = []
|
|
151
|
+
|
|
141
152
|
def reset(self) -> None:
|
|
142
153
|
"""Resets the compilation cache."""
|
|
143
154
|
self.parsed = {}
|
|
144
155
|
self.checked = {}
|
|
145
156
|
self.compiled = {}
|
|
146
|
-
self.additional_extensions = []
|
|
147
157
|
self.to_check_worklist = {}
|
|
148
158
|
self.types_to_check_worklist = {}
|
|
149
159
|
|
|
150
160
|
@pretty_errors
|
|
151
161
|
def register_extension(self, extension: Extension) -> None:
|
|
152
|
-
self.additional_extensions
|
|
162
|
+
if extension not in self.additional_extensions:
|
|
163
|
+
self.additional_extensions.append(extension)
|
|
153
164
|
|
|
154
165
|
@pretty_errors
|
|
155
166
|
def get_parsed(self, id: DefId) -> ParsedDef:
|
guppylang_internals/nodes.py
CHANGED
|
@@ -166,17 +166,6 @@ class MakeIter(ast.expr):
|
|
|
166
166
|
self.unwrap_size_hint = unwrap_size_hint
|
|
167
167
|
|
|
168
168
|
|
|
169
|
-
class IterHasNext(ast.expr):
|
|
170
|
-
"""Checks if an iterator has a next element using the `__hasnext__` magic method.
|
|
171
|
-
|
|
172
|
-
This node is inserted in `for` loops and list comprehensions.
|
|
173
|
-
"""
|
|
174
|
-
|
|
175
|
-
value: ast.expr
|
|
176
|
-
|
|
177
|
-
_fields = ("value",)
|
|
178
|
-
|
|
179
|
-
|
|
180
169
|
class IterNext(ast.expr):
|
|
181
170
|
"""Obtains the next element of an iterator using the `__next__` magic method.
|
|
182
171
|
|
|
@@ -188,18 +177,6 @@ class IterNext(ast.expr):
|
|
|
188
177
|
_fields = ("value",)
|
|
189
178
|
|
|
190
179
|
|
|
191
|
-
class IterEnd(ast.expr):
|
|
192
|
-
"""Finalises an iterator using the `__end__` magic method.
|
|
193
|
-
|
|
194
|
-
This node is inserted in `for` loops and list comprehensions. It is needed to
|
|
195
|
-
consume linear iterators once they are finished.
|
|
196
|
-
"""
|
|
197
|
-
|
|
198
|
-
value: ast.expr
|
|
199
|
-
|
|
200
|
-
_fields = ("value",)
|
|
201
|
-
|
|
202
|
-
|
|
203
180
|
class DesugaredGenerator(ast.expr):
|
|
204
181
|
"""A single desugared generator in a list comprehension.
|
|
205
182
|
|
|
@@ -47,16 +47,13 @@ class ConstWasmModule(val.ExtensionValue):
|
|
|
47
47
|
"""Python wrapper for the tket ConstWasmModule type"""
|
|
48
48
|
|
|
49
49
|
wasm_file: str
|
|
50
|
-
wasm_hash: int
|
|
51
50
|
|
|
52
51
|
def to_value(self) -> val.Extension:
|
|
53
52
|
ty = WASM_EXTENSION.get_type("module").instantiate([])
|
|
54
53
|
|
|
55
|
-
name = "
|
|
56
|
-
payload = {"
|
|
54
|
+
name = "ConstWasmModule"
|
|
55
|
+
payload = {"module_filename": self.wasm_file}
|
|
57
56
|
return val.Extension(name, typ=ty, val=payload, extensions=["tket.wasm"])
|
|
58
57
|
|
|
59
58
|
def __str__(self) -> str:
|
|
60
|
-
return (
|
|
61
|
-
f"ConstWasmModule(wasm_file={self.wasm_file}, wasm_hash={self.wasm_hash})"
|
|
62
|
-
)
|
|
59
|
+
return f"tket.wasm.module(module_filename={self.wasm_file})"
|
|
@@ -8,12 +8,11 @@ from guppylang_internals.nodes import GlobalCall
|
|
|
8
8
|
from guppylang_internals.std._internal.compiler.arithmetic import convert_itousize
|
|
9
9
|
from guppylang_internals.std._internal.compiler.prelude import build_unwrap
|
|
10
10
|
from guppylang_internals.std._internal.compiler.tket_exts import (
|
|
11
|
-
FUTURES_EXTENSION,
|
|
12
11
|
WASM_EXTENSION,
|
|
13
12
|
ConstWasmModule,
|
|
14
13
|
)
|
|
15
14
|
from guppylang_internals.tys.builtin import (
|
|
16
|
-
|
|
15
|
+
wasm_module_name,
|
|
17
16
|
)
|
|
18
17
|
from guppylang_internals.tys.ty import (
|
|
19
18
|
FunctionType,
|
|
@@ -57,18 +56,20 @@ class WasmModuleDiscardCompiler(CustomInoutCallCompiler):
|
|
|
57
56
|
|
|
58
57
|
class WasmModuleCallCompiler(CustomInoutCallCompiler):
|
|
59
58
|
"""Compiler for WASM calls
|
|
60
|
-
When a wasm method is called in guppy, we turn it into
|
|
59
|
+
When a wasm method is called in guppy, we turn it into 3 tket ops:
|
|
61
60
|
* lookup: wasm.module -> wasm.func
|
|
62
|
-
* call: wasm.context * wasm.func * inputs -> wasm.
|
|
63
|
-
|
|
61
|
+
* call: wasm.context * wasm.func * inputs -> wasm.result
|
|
62
|
+
* read_result: wasm.result -> wasm.context * outputs
|
|
64
63
|
For the wasm.module that we use in lookup, a constant is created for each
|
|
65
64
|
call, using the wasm file information embedded in method's `self` argument.
|
|
66
65
|
"""
|
|
67
66
|
|
|
68
67
|
fn_name: str
|
|
68
|
+
fn_id: int | None
|
|
69
69
|
|
|
70
|
-
def __init__(self, name: str) -> None:
|
|
70
|
+
def __init__(self, name: str, id_: int | None) -> None:
|
|
71
71
|
self.fn_name = name
|
|
72
|
+
self.fn_id = id_
|
|
72
73
|
|
|
73
74
|
def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
|
|
74
75
|
# The arguments should be:
|
|
@@ -93,14 +94,13 @@ class WasmModuleCallCompiler(CustomInoutCallCompiler):
|
|
|
93
94
|
func_ty = WASM_EXTENSION.get_type("func").instantiate(
|
|
94
95
|
[inputs_row_arg, output_row_arg]
|
|
95
96
|
)
|
|
96
|
-
|
|
97
|
-
[ht.Tuple(*wasm_sig.output).type_arg()]
|
|
98
|
-
)
|
|
97
|
+
result_ty = WASM_EXTENSION.get_type("result").instantiate([output_row_arg])
|
|
99
98
|
|
|
100
99
|
# Get the WASM module information from the type
|
|
101
100
|
selfarg = self.func.ty.inputs[0].ty
|
|
102
|
-
|
|
103
|
-
|
|
101
|
+
info = wasm_module_name(selfarg)
|
|
102
|
+
if info is not None:
|
|
103
|
+
const_module = self.builder.add_const(ConstWasmModule(info))
|
|
104
104
|
else:
|
|
105
105
|
raise InternalGuppyError(
|
|
106
106
|
"Expected cached signature to have WASM module as first arg"
|
|
@@ -109,27 +109,38 @@ class WasmModuleCallCompiler(CustomInoutCallCompiler):
|
|
|
109
109
|
wasm_module = self.builder.load(const_module)
|
|
110
110
|
|
|
111
111
|
# Lookup the function we want
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
112
|
+
if self.fn_id is None:
|
|
113
|
+
fn_name_arg = ht.StringArg(self.fn_name)
|
|
114
|
+
wasm_opdef = WASM_EXTENSION.get_op("lookup_by_name").instantiate(
|
|
115
|
+
[fn_name_arg, inputs_row_arg, output_row_arg],
|
|
116
|
+
ht.FunctionType([module_ty], [func_ty]),
|
|
117
|
+
)
|
|
118
|
+
else:
|
|
119
|
+
fn_id_arg = ht.BoundedNatArg(self.fn_id)
|
|
120
|
+
wasm_opdef = WASM_EXTENSION.get_op("lookup_by_id").instantiate(
|
|
121
|
+
[fn_id_arg, inputs_row_arg, output_row_arg],
|
|
122
|
+
ht.FunctionType([module_ty], [func_ty]),
|
|
123
|
+
)
|
|
124
|
+
|
|
116
125
|
wasm_func = self.builder.add_op(wasm_opdef, wasm_module)
|
|
117
126
|
|
|
118
127
|
# Call the function
|
|
119
128
|
call_op = WASM_EXTENSION.get_op("call").instantiate(
|
|
120
129
|
[inputs_row_arg, output_row_arg],
|
|
121
|
-
ht.FunctionType([ctx_ty, func_ty, *wasm_sig.input], [
|
|
130
|
+
ht.FunctionType([ctx_ty, func_ty, *wasm_sig.input], [result_ty]),
|
|
122
131
|
)
|
|
123
132
|
|
|
124
|
-
|
|
133
|
+
result = self.builder.add_op(call_op, args[0], wasm_func, *args[1:])
|
|
125
134
|
|
|
126
|
-
read_opdef =
|
|
127
|
-
[
|
|
128
|
-
ht.FunctionType([
|
|
135
|
+
read_opdef = WASM_EXTENSION.get_op("read_result").instantiate(
|
|
136
|
+
[output_row_arg],
|
|
137
|
+
ht.FunctionType([result_ty], [ctx_ty, *wasm_sig.output]),
|
|
129
138
|
)
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
139
|
+
data = self.builder.add_op(read_opdef, result)
|
|
140
|
+
match list(data[:]):
|
|
141
|
+
case [ctx]:
|
|
142
|
+
return CallReturnWires(regular_returns=[], inout_returns=[ctx])
|
|
143
|
+
case [ctx, *values]:
|
|
144
|
+
return CallReturnWires(regular_returns=[*values], inout_returns=[ctx])
|
|
145
|
+
case _:
|
|
146
|
+
raise AssertionError("impossible")
|
|
@@ -176,10 +176,21 @@ def trace_call(func: CallableDef, *args: Any) -> Any:
|
|
|
176
176
|
if len(func.ty.inputs) != 0:
|
|
177
177
|
for inp, arg, var in zip(func.ty.inputs, args, arg_vars, strict=True):
|
|
178
178
|
if InputFlags.Inout in inp.flags:
|
|
179
|
+
# Note that `inp.ty` could refer to bound variables in the function
|
|
180
|
+
# signature. Instead, make sure to use `var.ty` which will always be a
|
|
181
|
+
# concrete type and type checking has ensured that they unify.
|
|
182
|
+
ty = var.ty
|
|
179
183
|
inout_wire = state.dfg[var]
|
|
180
|
-
update_packed_value(
|
|
181
|
-
arg, GuppyObject(
|
|
184
|
+
success = update_packed_value(
|
|
185
|
+
arg, GuppyObject(ty, inout_wire), state.dfg.builder
|
|
182
186
|
)
|
|
187
|
+
if not success:
|
|
188
|
+
# This means the user has passed an object that we cannot update,
|
|
189
|
+
# e.g. calling `mem_swap(x, y)` where the inputs are plain Python
|
|
190
|
+
# objects
|
|
191
|
+
raise GuppyComptimeError(
|
|
192
|
+
f"Cannot borrow Python object of type `{ty}` at comptime"
|
|
193
|
+
)
|
|
183
194
|
|
|
184
195
|
ret_obj = GuppyObject(ret_ty, ret_wire)
|
|
185
196
|
return unpack_guppy_object(ret_obj, state.dfg.builder)
|
|
@@ -150,13 +150,15 @@ def guppy_object_from_py(
|
|
|
150
150
|
return GuppyObject(ty, builder.load(hugr_val))
|
|
151
151
|
|
|
152
152
|
|
|
153
|
-
def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) ->
|
|
153
|
+
def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> bool:
|
|
154
154
|
"""Given a Python value `v` and a `GuppyObject` `obj` that was constructed from `v`
|
|
155
|
-
using `guppy_object_from_py`,
|
|
156
|
-
`v` to the new wires specified by `obj`.
|
|
155
|
+
using `guppy_object_from_py`, tries to update the wires of any `GuppyObjects`
|
|
156
|
+
contained in `v` to the new wires specified by `obj`.
|
|
157
157
|
|
|
158
158
|
Also resets the used flag on any of those updated wires. This corresponds to making
|
|
159
159
|
the object available again since it now corresponds to a fresh wire.
|
|
160
|
+
|
|
161
|
+
Returns `True` if all wires could be updated, otherwise `False`.
|
|
160
162
|
"""
|
|
161
163
|
match v:
|
|
162
164
|
case GuppyObject() as v_obj:
|
|
@@ -172,23 +174,27 @@ def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> None:
|
|
|
172
174
|
assert isinstance(obj._ty, TupleType)
|
|
173
175
|
wires = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)).outputs()
|
|
174
176
|
for v, ty, wire in zip(vs, obj._ty.element_types, wires, strict=True):
|
|
175
|
-
update_packed_value(v, GuppyObject(ty, wire), builder)
|
|
177
|
+
success = update_packed_value(v, GuppyObject(ty, wire), builder)
|
|
178
|
+
if not success:
|
|
179
|
+
return False
|
|
176
180
|
case GuppyStructObject(_ty=ty, _field_values=values):
|
|
177
181
|
assert obj._ty == ty
|
|
178
182
|
wires = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)).outputs()
|
|
179
|
-
for (
|
|
180
|
-
field,
|
|
181
|
-
wire,
|
|
182
|
-
) in zip(ty.fields, wires, strict=True):
|
|
183
|
+
for field, wire in zip(ty.fields, wires, strict=True):
|
|
183
184
|
v = values[field.name]
|
|
184
|
-
update_packed_value(v, GuppyObject(field.ty, wire), builder)
|
|
185
|
+
success = update_packed_value(v, GuppyObject(field.ty, wire), builder)
|
|
186
|
+
if not success:
|
|
187
|
+
values[field.name] = obj
|
|
185
188
|
case list(vs) if len(vs) > 0:
|
|
186
189
|
assert is_array_type(obj._ty)
|
|
187
190
|
elem_ty = get_element_type(obj._ty)
|
|
188
191
|
opt_wires = unpack_array(builder, obj._use_wire(None))
|
|
189
192
|
err = "Non-droppable array element has already been used"
|
|
190
|
-
for v, opt_wire in zip(vs, opt_wires, strict=True):
|
|
193
|
+
for i, (v, opt_wire) in enumerate(zip(vs, opt_wires, strict=True)):
|
|
191
194
|
(wire,) = build_unwrap(builder, opt_wire, err).outputs()
|
|
192
|
-
update_packed_value(v, GuppyObject(elem_ty, wire), builder)
|
|
195
|
+
success = update_packed_value(v, GuppyObject(elem_ty, wire), builder)
|
|
196
|
+
if not success:
|
|
197
|
+
vs[i] = obj
|
|
193
198
|
case _:
|
|
194
|
-
|
|
199
|
+
return False
|
|
200
|
+
return True
|
|
@@ -46,6 +46,27 @@ class CallableTypeDef(TypeDef, CompiledDef):
|
|
|
46
46
|
raise InternalGuppyError("Tried to `Callable` type via `check_instantiate`")
|
|
47
47
|
|
|
48
48
|
|
|
49
|
+
@dataclass(frozen=True)
|
|
50
|
+
class SelfTypeDef(TypeDef, CompiledDef):
|
|
51
|
+
"""Type definition associated with the `Self` type on methods.
|
|
52
|
+
|
|
53
|
+
During type parsing, we make sure that this type is replaced with the concrete type
|
|
54
|
+
the method is attached to. Thus, we should never have instances of this type around.
|
|
55
|
+
|
|
56
|
+
In other words, this definition is only a marker so that type parsing doesn't have
|
|
57
|
+
to rely on matching against the string "Self". By making `Self` a definition, we can
|
|
58
|
+
use the existing identifier tracking system and also handle users shadowing the
|
|
59
|
+
`Self` binder or assigning `Self` to some other name.
|
|
60
|
+
"""
|
|
61
|
+
|
|
62
|
+
name: Literal["Self"] = field(default="Self", init=False)
|
|
63
|
+
|
|
64
|
+
def check_instantiate(
|
|
65
|
+
self, args: Sequence[Argument], loc: AstNode | None = None
|
|
66
|
+
) -> FunctionType:
|
|
67
|
+
raise InternalGuppyError("Tried to instantiate abstract `Self` type`")
|
|
68
|
+
|
|
69
|
+
|
|
49
70
|
@dataclass(frozen=True)
|
|
50
71
|
class _TupleTypeDef(TypeDef, CompiledDef):
|
|
51
72
|
"""Type definition associated with the builtin `tuple` type.
|
|
@@ -106,7 +127,6 @@ class _NumericTypeDef(TypeDef, CompiledDef):
|
|
|
106
127
|
|
|
107
128
|
class WasmModuleTypeDef(OpaqueTypeDef):
|
|
108
129
|
wasm_file: str
|
|
109
|
-
wasm_hash: int
|
|
110
130
|
|
|
111
131
|
def __init__(
|
|
112
132
|
self,
|
|
@@ -114,11 +134,9 @@ class WasmModuleTypeDef(OpaqueTypeDef):
|
|
|
114
134
|
name: str,
|
|
115
135
|
defined_at: ast.AST | None,
|
|
116
136
|
wasm_file: str,
|
|
117
|
-
wasm_hash: int,
|
|
118
137
|
) -> None:
|
|
119
138
|
super().__init__(id, name, defined_at, [], True, True, self.to_hugr)
|
|
120
139
|
self.wasm_file = wasm_file
|
|
121
|
-
self.wasm_hash = wasm_hash
|
|
122
140
|
|
|
123
141
|
def to_hugr(
|
|
124
142
|
self, args: Sequence[TypeArg | ConstArg], ctx: ToHugrContext
|
|
@@ -189,9 +207,10 @@ def _option_to_hugr(args: Sequence[Argument], ctx: ToHugrContext) -> ht.Type:
|
|
|
189
207
|
return ht.Option(arg.ty.to_hugr(ctx))
|
|
190
208
|
|
|
191
209
|
|
|
192
|
-
callable_type_def = CallableTypeDef(DefId.fresh(), None)
|
|
193
|
-
|
|
194
|
-
|
|
210
|
+
callable_type_def = CallableTypeDef(DefId.fresh(), None, None)
|
|
211
|
+
self_type_def = SelfTypeDef(DefId.fresh(), None, [])
|
|
212
|
+
tuple_type_def = _TupleTypeDef(DefId.fresh(), None, None)
|
|
213
|
+
none_type_def = _NoneTypeDef(DefId.fresh(), None, [])
|
|
195
214
|
bool_type_def = OpaqueTypeDef(
|
|
196
215
|
id=DefId.fresh(),
|
|
197
216
|
name="bool",
|
|
@@ -202,13 +221,13 @@ bool_type_def = OpaqueTypeDef(
|
|
|
202
221
|
to_hugr=lambda args, ctx: OpaqueBool,
|
|
203
222
|
)
|
|
204
223
|
nat_type_def = _NumericTypeDef(
|
|
205
|
-
DefId.fresh(), "nat", None, NumericType(NumericType.Kind.Nat)
|
|
224
|
+
DefId.fresh(), "nat", None, [], NumericType(NumericType.Kind.Nat)
|
|
206
225
|
)
|
|
207
226
|
int_type_def = _NumericTypeDef(
|
|
208
|
-
DefId.fresh(), "int", None, NumericType(NumericType.Kind.Int)
|
|
227
|
+
DefId.fresh(), "int", None, [], NumericType(NumericType.Kind.Int)
|
|
209
228
|
)
|
|
210
229
|
float_type_def = _NumericTypeDef(
|
|
211
|
-
DefId.fresh(), "float", None, NumericType(NumericType.Kind.Float)
|
|
230
|
+
DefId.fresh(), "float", None, [], NumericType(NumericType.Kind.Float)
|
|
212
231
|
)
|
|
213
232
|
string_type_def = OpaqueTypeDef(
|
|
214
233
|
id=DefId.fresh(),
|
|
@@ -345,9 +364,9 @@ def is_sized_iter_type(ty: Type) -> TypeGuard[OpaqueType]:
|
|
|
345
364
|
return isinstance(ty, OpaqueType) and ty.defn == sized_iter_type_def
|
|
346
365
|
|
|
347
366
|
|
|
348
|
-
def
|
|
367
|
+
def wasm_module_name(ty: Type) -> str | None:
|
|
349
368
|
if isinstance(ty, OpaqueType) and isinstance(ty.defn, WasmModuleTypeDef):
|
|
350
|
-
return ty.defn.wasm_file
|
|
369
|
+
return ty.defn.wasm_file
|
|
351
370
|
return None
|
|
352
371
|
|
|
353
372
|
|
|
@@ -116,6 +116,12 @@ class InvalidCallableTypeError(Error):
|
|
|
116
116
|
self.add_sub_diagnostic(InvalidCallableTypeError.Explain(None))
|
|
117
117
|
|
|
118
118
|
|
|
119
|
+
@dataclass(frozen=True)
|
|
120
|
+
class SelfTyNotInMethodError(Error):
|
|
121
|
+
title: ClassVar[str] = "Invalid type"
|
|
122
|
+
span_label: ClassVar[str] = "`Self` type annotations are only allowed in methods"
|
|
123
|
+
|
|
124
|
+
|
|
119
125
|
@dataclass(frozen=True)
|
|
120
126
|
class NonLinearOwnedError(Error):
|
|
121
127
|
title: ClassVar[str] = "Invalid annotation"
|