guppylang-internals 0.22.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/cfg/cfg.py +8 -0
  3. guppylang_internals/checker/cfg_checker.py +26 -65
  4. guppylang_internals/checker/core.py +8 -0
  5. guppylang_internals/checker/expr_checker.py +11 -25
  6. guppylang_internals/checker/func_checker.py +170 -21
  7. guppylang_internals/checker/stmt_checker.py +1 -1
  8. guppylang_internals/decorator.py +124 -58
  9. guppylang_internals/definition/const.py +2 -2
  10. guppylang_internals/definition/custom.py +1 -1
  11. guppylang_internals/definition/declaration.py +1 -1
  12. guppylang_internals/definition/extern.py +2 -2
  13. guppylang_internals/definition/function.py +1 -1
  14. guppylang_internals/definition/parameter.py +2 -2
  15. guppylang_internals/definition/pytket_circuits.py +1 -1
  16. guppylang_internals/definition/struct.py +10 -10
  17. guppylang_internals/definition/traced.py +1 -1
  18. guppylang_internals/definition/ty.py +6 -0
  19. guppylang_internals/definition/wasm.py +2 -2
  20. guppylang_internals/engine.py +13 -2
  21. guppylang_internals/nodes.py +0 -23
  22. guppylang_internals/std/_internal/compiler/tket_exts.py +3 -6
  23. guppylang_internals/std/_internal/compiler/wasm.py +37 -26
  24. guppylang_internals/tracing/function.py +13 -2
  25. guppylang_internals/tracing/unpacking.py +18 -12
  26. guppylang_internals/tys/builtin.py +30 -11
  27. guppylang_internals/tys/errors.py +6 -0
  28. guppylang_internals/tys/parsing.py +111 -125
  29. {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/METADATA +5 -5
  30. {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/RECORD +32 -32
  31. {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/WHEEL +0 -0
  32. {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/licenses/LICENCE +0 -0
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- from typing import TYPE_CHECKING, ParamSpec, TypeVar
4
+ from typing import TYPE_CHECKING, ParamSpec, TypeVar, overload
5
5
 
6
6
  from hugr import ops
7
7
  from hugr import tys as ht
@@ -42,6 +42,7 @@ from guppylang_internals.tys.ty import (
42
42
  )
43
43
 
44
44
  if TYPE_CHECKING:
45
+ import ast
45
46
  import builtins
46
47
  from collections.abc import Callable, Sequence
47
48
  from types import FrameType
@@ -121,15 +122,19 @@ def hugr_op(
121
122
  return custom_function(OpCompiler(op), checker, higher_order_value, name, signature)
122
123
 
123
124
 
124
- def extend_type(defn: TypeDef) -> Callable[[type], type]:
125
- """Decorator to add new instance functions to a type."""
125
+ def extend_type(defn: TypeDef, return_class: bool = False) -> Callable[[type], type]:
126
+ """Decorator to add new instance functions to a type.
127
+
128
+ By default, returns a `GuppyDefinition` object referring to the type. Alternatively,
129
+ `return_class=True` can be set to return the decorated class unchanged.
130
+ """
126
131
  from guppylang.defs import GuppyDefinition
127
132
 
128
133
  def dec(c: type) -> type:
129
134
  for val in c.__dict__.values():
130
135
  if isinstance(val, GuppyDefinition):
131
136
  DEF_STORE.register_impl(defn.id, val.wrapped.name, val.id)
132
- return c
137
+ return c if return_class else GuppyDefinition(defn) # type: ignore[return-value]
133
138
 
134
139
  return dec
135
140
 
@@ -181,63 +186,124 @@ def custom_type(
181
186
 
182
187
 
183
188
  def wasm_module(
184
- filename: str, filehash: int
189
+ filename: str,
185
190
  ) -> Callable[[builtins.type[T]], GuppyDefinition]:
186
- from guppylang.defs import GuppyDefinition
187
-
188
- def dec(cls: builtins.type[T]) -> GuppyDefinition:
189
- # N.B. Only one module per file and vice-versa
190
- wasm_module = WasmModuleTypeDef(
191
- DefId.fresh(),
192
- cls.__name__,
193
- None,
194
- filename,
195
- filehash,
196
- )
197
-
198
- wasm_module_ty = wasm_module.check_instantiate([], None)
199
-
200
- DEF_STORE.register_def(wasm_module, get_calling_frame())
201
- for val in cls.__dict__.values():
202
- if isinstance(val, GuppyDefinition):
203
- DEF_STORE.register_impl(wasm_module.id, val.wrapped.name, val.id)
204
- # Add a constructor to the class
205
- call_method = CustomFunctionDef(
206
- DefId.fresh(),
207
- "__new__",
208
- None,
209
- FunctionType(
210
- [FuncInput(NumericType(NumericType.Kind.Nat), flags=InputFlags.Owned)],
211
- wasm_module_ty,
212
- ),
213
- DefaultCallChecker(),
214
- WasmModuleInitCompiler(),
215
- True,
216
- GlobalConstId.fresh(f"{cls.__name__}.__new__"),
217
- True,
218
- )
219
- discard = CustomFunctionDef(
220
- DefId.fresh(),
221
- "discard",
222
- None,
223
- FunctionType([FuncInput(wasm_module_ty, InputFlags.Owned)], NoneType()),
224
- DefaultCallChecker(),
225
- WasmModuleDiscardCompiler(),
226
- False,
227
- GlobalConstId.fresh(f"{cls.__name__}.__discard__"),
228
- True,
229
- )
230
- DEF_STORE.register_def(call_method, get_calling_frame())
231
- DEF_STORE.register_impl(wasm_module.id, "__new__", call_method.id)
232
- DEF_STORE.register_def(discard, get_calling_frame())
233
- DEF_STORE.register_impl(wasm_module.id, "discard", discard.id)
191
+ def type_def_wrapper(
192
+ id: DefId,
193
+ name: str,
194
+ defined_at: ast.AST | None,
195
+ wasm_file: str,
196
+ config: str | None,
197
+ ) -> OpaqueTypeDef:
198
+ assert config is None
199
+ return WasmModuleTypeDef(id, name, defined_at, wasm_file)
200
+
201
+ f = ext_module_decorator(
202
+ type_def_wrapper, WasmModuleInitCompiler(), WasmModuleDiscardCompiler(), True
203
+ )
204
+ return f(filename, None)
234
205
 
235
- return GuppyDefinition(wasm_module)
236
-
237
- return dec
238
206
 
207
+ def ext_module_decorator(
208
+ type_def: Callable[[DefId, str, ast.AST | None, str, str | None], OpaqueTypeDef],
209
+ init_compiler: CustomInoutCallCompiler,
210
+ discard_compiler: CustomInoutCallCompiler,
211
+ init_arg: bool, # Whether the init function should take a nat argument
212
+ ) -> Callable[[str, str | None], Callable[[builtins.type[T]], GuppyDefinition]]:
213
+ from guppylang.defs import GuppyDefinition
239
214
 
240
- def wasm(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
215
+ def fun(
216
+ filename: str, module: str | None
217
+ ) -> Callable[[builtins.type[T]], GuppyDefinition]:
218
+ def dec(cls: builtins.type[T]) -> GuppyDefinition:
219
+ # N.B. Only one module per file and vice-versa
220
+ ext_module = type_def(
221
+ DefId.fresh(),
222
+ cls.__name__,
223
+ None,
224
+ filename,
225
+ module,
226
+ )
227
+
228
+ ext_module_ty = ext_module.check_instantiate([], None)
229
+
230
+ DEF_STORE.register_def(ext_module, get_calling_frame())
231
+ for val in cls.__dict__.values():
232
+ if isinstance(val, GuppyDefinition):
233
+ DEF_STORE.register_impl(ext_module.id, val.wrapped.name, val.id)
234
+ # Add a constructor to the class
235
+ if init_arg:
236
+ init_fn_ty = FunctionType(
237
+ [
238
+ FuncInput(
239
+ NumericType(NumericType.Kind.Nat),
240
+ flags=InputFlags.Owned,
241
+ )
242
+ ],
243
+ ext_module_ty,
244
+ )
245
+ else:
246
+ init_fn_ty = FunctionType([], ext_module_ty)
247
+
248
+ call_method = CustomFunctionDef(
249
+ DefId.fresh(),
250
+ "__new__",
251
+ None,
252
+ init_fn_ty,
253
+ DefaultCallChecker(),
254
+ init_compiler,
255
+ True,
256
+ GlobalConstId.fresh(f"{cls.__name__}.__new__"),
257
+ True,
258
+ )
259
+ discard = CustomFunctionDef(
260
+ DefId.fresh(),
261
+ "discard",
262
+ None,
263
+ FunctionType([FuncInput(ext_module_ty, InputFlags.Owned)], NoneType()),
264
+ DefaultCallChecker(),
265
+ discard_compiler,
266
+ False,
267
+ GlobalConstId.fresh(f"{cls.__name__}.__discard__"),
268
+ True,
269
+ )
270
+ DEF_STORE.register_def(call_method, get_calling_frame())
271
+ DEF_STORE.register_impl(ext_module.id, "__new__", call_method.id)
272
+ DEF_STORE.register_def(discard, get_calling_frame())
273
+ DEF_STORE.register_impl(ext_module.id, "discard", discard.id)
274
+
275
+ return GuppyDefinition(ext_module)
276
+
277
+ return dec
278
+
279
+ return fun
280
+
281
+
282
+ @overload
283
+ def wasm(arg: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: ...
284
+
285
+
286
+ @overload
287
+ def wasm(arg: int) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]: ...
288
+
289
+
290
+ def wasm(
291
+ arg: int | Callable[P, T],
292
+ ) -> (
293
+ GuppyFunctionDefinition[P, T]
294
+ | Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]
295
+ ):
296
+ if isinstance(arg, int):
297
+
298
+ def wrapper(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
299
+ return wasm_helper(arg, f)
300
+
301
+ return wrapper
302
+ else:
303
+ return wasm_helper(None, arg)
304
+
305
+
306
+ def wasm_helper(fn_id: int | None, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
241
307
  from guppylang.defs import GuppyFunctionDefinition
242
308
 
243
309
  func = RawWasmFunctionDef(
@@ -246,7 +312,7 @@ def wasm(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
246
312
  None,
247
313
  f,
248
314
  WasmCallChecker(),
249
- WasmModuleCallCompiler(f.__name__),
315
+ WasmModuleCallCompiler(f.__name__, fn_id),
250
316
  True,
251
317
  signature=None,
252
318
  )
@@ -15,7 +15,7 @@ from guppylang_internals.definition.value import (
15
15
  ValueDef,
16
16
  )
17
17
  from guppylang_internals.span import SourceMap
18
- from guppylang_internals.tys.parsing import type_from_ast
18
+ from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast
19
19
 
20
20
 
21
21
  @dataclass(frozen=True)
@@ -33,7 +33,7 @@ class RawConstDef(ParsableDef):
33
33
  self.id,
34
34
  self.name,
35
35
  self.defined_at,
36
- type_from_ast(self.type_ast, globals, {}),
36
+ type_from_ast(self.type_ast, TypeParsingCtx(globals)),
37
37
  self.type_ast,
38
38
  self.value,
39
39
  )
@@ -169,7 +169,7 @@ class RawCustomFunctionDef(ParsableDef):
169
169
  raise GuppyError(NoSignatureError(node, self.name))
170
170
 
171
171
  if requires_type_annotation:
172
- return check_signature(node, globals)
172
+ return check_signature(node, globals, self.id)
173
173
  else:
174
174
  return None
175
175
 
@@ -68,7 +68,7 @@ class RawFunctionDecl(ParsableDef):
68
68
  def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
69
69
  """Parses and checks the user-provided signature of the function."""
70
70
  func_ast, docstring = parse_py_func(self.python_func, sources)
71
- ty = check_signature(func_ast, globals)
71
+ ty = check_signature(func_ast, globals, self.id)
72
72
  if not has_empty_body(func_ast):
73
73
  raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
74
74
  # Make sure we won't need monomorphization to compile this declaration
@@ -14,7 +14,7 @@ from guppylang_internals.definition.value import (
14
14
  ValueDef,
15
15
  )
16
16
  from guppylang_internals.span import SourceMap
17
- from guppylang_internals.tys.parsing import type_from_ast
17
+ from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast
18
18
 
19
19
 
20
20
  @dataclass(frozen=True)
@@ -33,7 +33,7 @@ class RawExternDef(ParsableDef):
33
33
  self.id,
34
34
  self.name,
35
35
  self.defined_at,
36
- type_from_ast(self.type_ast, globals, {}),
36
+ type_from_ast(self.type_ast, TypeParsingCtx(globals)),
37
37
  self.symbol,
38
38
  self.constant,
39
39
  self.type_ast,
@@ -73,7 +73,7 @@ class RawFunctionDef(ParsableDef):
73
73
  def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
74
74
  """Parses and checks the user-provided signature of the function."""
75
75
  func_ast, docstring = parse_py_func(self.python_func, sources)
76
- ty = check_signature(func_ast, globals)
76
+ ty = check_signature(func_ast, globals, self.id)
77
77
  return ParsedFunctionDef(self.id, self.name, func_ast, ty, docstring)
78
78
 
79
79
 
@@ -56,9 +56,9 @@ class RawConstVarDef(ParamDef, ParsableDef):
56
56
  description: str = field(default="const variable", init=False)
57
57
 
58
58
  def parse(self, globals: Globals, sources: SourceMap) -> "ConstVarDef":
59
- from guppylang_internals.tys.parsing import type_from_ast
59
+ from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast
60
60
 
61
- ty = type_from_ast(self.type_ast, globals, {})
61
+ ty = type_from_ast(self.type_ast, TypeParsingCtx(globals))
62
62
  if not ty.copyable or not ty.droppable:
63
63
  raise GuppyError(LinearConstVarError(self.type_ast, self.name, ty))
64
64
  return ConstVarDef(self.id, self.name, self.defined_at, ty)
@@ -85,7 +85,7 @@ class RawPytketDef(ParsableDef):
85
85
  if not has_empty_body(func_ast):
86
86
  # Function stub should have empty body.
87
87
  raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
88
- stub_signature = check_signature(func_ast, globals)
88
+ stub_signature = check_signature(func_ast, globals, self.id)
89
89
 
90
90
  # Compare signatures.
91
91
  circuit_signature = _signature_from_circuit(self.input_circuit, self.defined_at)
@@ -3,7 +3,7 @@ import inspect
3
3
  import linecache
4
4
  import sys
5
5
  from collections.abc import Sequence
6
- from dataclasses import dataclass
6
+ from dataclasses import dataclass, field
7
7
  from types import FrameType
8
8
  from typing import ClassVar
9
9
 
@@ -39,7 +39,7 @@ from guppylang_internals.ipython_inspect import is_running_ipython
39
39
  from guppylang_internals.span import SourceMap, Span, to_span
40
40
  from guppylang_internals.tys.arg import Argument
41
41
  from guppylang_internals.tys.param import Parameter, check_all_args
42
- from guppylang_internals.tys.parsing import type_from_ast
42
+ from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast
43
43
  from guppylang_internals.tys.ty import (
44
44
  FuncInput,
45
45
  FunctionType,
@@ -115,6 +115,7 @@ class RawStructDef(TypeDef, ParsableDef):
115
115
  """A raw struct type definition that has not been parsed yet."""
116
116
 
117
117
  python_class: type
118
+ params: None = field(default=None, init=False) # Params not known yet
118
119
 
119
120
  def parse(self, globals: Globals, sources: SourceMap) -> "ParsedStructDef":
120
121
  """Parses the raw class object into an AST and checks that it is well-formed."""
@@ -211,15 +212,16 @@ class ParsedStructDef(TypeDef, CheckableDef):
211
212
 
212
213
  def check(self, globals: Globals) -> "CheckedStructDef":
213
214
  """Checks that all struct fields have valid types."""
215
+ param_var_mapping = {p.name: p for p in self.params}
216
+ ctx = TypeParsingCtx(globals, param_var_mapping)
217
+
214
218
  # Before checking the fields, make sure that this definition is not recursive,
215
219
  # otherwise the code below would not terminate.
216
220
  # TODO: This is not ideal (see todo in `check_instantiate`)
217
- param_var_mapping = {p.name: p for p in self.params}
218
- check_not_recursive(self, globals, param_var_mapping)
221
+ check_not_recursive(self, ctx)
219
222
 
220
223
  fields = [
221
- StructField(f.name, type_from_ast(f.type_ast, globals, param_var_mapping))
222
- for f in self.fields
224
+ StructField(f.name, type_from_ast(f.type_ast, ctx)) for f in self.fields
223
225
  ]
224
226
  return CheckedStructDef(
225
227
  self.id, self.name, self.defined_at, self.params, fields
@@ -370,9 +372,7 @@ def params_from_ast(nodes: Sequence[ast.expr], globals: Globals) -> list[Paramet
370
372
  return params
371
373
 
372
374
 
373
- def check_not_recursive(
374
- defn: ParsedStructDef, globals: Globals, param_var_mapping: dict[str, Parameter]
375
- ) -> None:
375
+ def check_not_recursive(defn: ParsedStructDef, ctx: TypeParsingCtx) -> None:
376
376
  """Throws a user error if the given struct definition is recursive."""
377
377
 
378
378
  # TODO: The implementation below hijacks the type parsing logic to detect recursive
@@ -388,5 +388,5 @@ def check_not_recursive(
388
388
  original = defn.check_instantiate
389
389
  object.__setattr__(defn, "check_instantiate", dummy_check_instantiate)
390
390
  for fld in defn.fields:
391
- type_from_ast(fld.type_ast, globals, param_var_mapping)
391
+ type_from_ast(fld.type_ast, ctx)
392
392
  object.__setattr__(defn, "check_instantiate", original)
@@ -48,7 +48,7 @@ class RawTracedFunctionDef(ParsableDef):
48
48
  def parse(self, globals: Globals, sources: SourceMap) -> "TracedFunctionDef":
49
49
  """Parses and checks the user-provided signature of the function."""
50
50
  func_ast, _docstring = parse_py_func(self.python_func, sources)
51
- ty = check_signature(func_ast, globals)
51
+ ty = check_signature(func_ast, globals, self.id)
52
52
  if ty.parametrized:
53
53
  raise GuppyError(UnsupportedError(func_ast, "Generic comptime functions"))
54
54
  return TracedFunctionDef(self.id, self.name, func_ast, ty, self.python_func)
@@ -18,6 +18,12 @@ class TypeDef(Definition):
18
18
 
19
19
  description: str = field(default="type", init=False)
20
20
 
21
+ #: Generic parameters of the type. This may be `None` for special types that are
22
+ #: more polymorphic than the regular type system allows (for example `tuple` and
23
+ #: `Callable`), or if this is a raw definition whose parameters are not determined
24
+ #: yet (for example a `RawStructDef`).
25
+ params: Sequence[Parameter] | None
26
+
21
27
  @abstractmethod
22
28
  def check_instantiate(
23
29
  self, args: Sequence[Argument], loc: AstNode | None = None
@@ -11,7 +11,7 @@ from guppylang_internals.definition.custom import (
11
11
  )
12
12
  from guppylang_internals.error import GuppyError
13
13
  from guppylang_internals.span import SourceMap
14
- from guppylang_internals.tys.builtin import wasm_module_info
14
+ from guppylang_internals.tys.builtin import wasm_module_name
15
15
  from guppylang_internals.tys.ty import (
16
16
  FuncInput,
17
17
  FunctionType,
@@ -30,7 +30,7 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
30
30
  def sanitise_type(self, loc: AstNode | None, fun_ty: FunctionType) -> None:
31
31
  # Place to highlight in error messages
32
32
  match fun_ty.inputs[0]:
33
- case FuncInput(ty=ty, flags=InputFlags.Inout) if wasm_module_info(
33
+ case FuncInput(ty=ty, flags=InputFlags.Inout) if wasm_module_name(
34
34
  ty
35
35
  ) is not None:
36
36
  pass
@@ -41,6 +41,7 @@ from guppylang_internals.tys.builtin import (
41
41
  nat_type_def,
42
42
  none_type_def,
43
43
  option_type_def,
44
+ self_type_def,
44
45
  sized_iter_type_def,
45
46
  string_type_def,
46
47
  tuple_type_def,
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
51
52
 
52
53
  BUILTIN_DEFS_LIST: list[RawDef] = [
53
54
  callable_type_def,
55
+ self_type_def,
54
56
  tuple_type_def,
55
57
  none_type_def,
56
58
  bool_type_def,
@@ -84,12 +86,14 @@ class DefinitionStore:
84
86
 
85
87
  raw_defs: dict[DefId, RawDef]
86
88
  impls: defaultdict[DefId, dict[str, DefId]]
89
+ impl_parents: dict[DefId, DefId]
87
90
  frames: dict[DefId, FrameType]
88
91
  sources: SourceMap
89
92
 
90
93
  def __init__(self) -> None:
91
94
  self.raw_defs = {defn.id: defn for defn in BUILTIN_DEFS_LIST}
92
95
  self.impls = defaultdict(dict)
96
+ self.impl_parents = {}
93
97
  self.frames = {}
94
98
  self.sources = SourceMap()
95
99
 
@@ -99,7 +103,9 @@ class DefinitionStore:
99
103
  self.frames[defn.id] = frame
100
104
 
101
105
  def register_impl(self, ty_id: DefId, name: str, impl_id: DefId) -> None:
106
+ assert impl_id not in self.impl_parents, "Already an impl"
102
107
  self.impls[ty_id][name] = impl_id
108
+ self.impl_parents[impl_id] = ty_id
103
109
  # Update the frame of the definition to the frame of the defining class
104
110
  if impl_id in self.frames:
105
111
  frame = self.frames[impl_id].f_back
@@ -138,18 +144,23 @@ class CompilationEngine:
138
144
  types_to_check_worklist: dict[DefId, ParsedDef]
139
145
  to_check_worklist: dict[DefId, ParsedDef]
140
146
 
147
+ def __init__(self) -> None:
148
+ """Resets the compilation cache."""
149
+ self.reset()
150
+ self.additional_extensions = []
151
+
141
152
  def reset(self) -> None:
142
153
  """Resets the compilation cache."""
143
154
  self.parsed = {}
144
155
  self.checked = {}
145
156
  self.compiled = {}
146
- self.additional_extensions = []
147
157
  self.to_check_worklist = {}
148
158
  self.types_to_check_worklist = {}
149
159
 
150
160
  @pretty_errors
151
161
  def register_extension(self, extension: Extension) -> None:
152
- self.additional_extensions.append(extension)
162
+ if extension not in self.additional_extensions:
163
+ self.additional_extensions.append(extension)
153
164
 
154
165
  @pretty_errors
155
166
  def get_parsed(self, id: DefId) -> ParsedDef:
@@ -166,17 +166,6 @@ class MakeIter(ast.expr):
166
166
  self.unwrap_size_hint = unwrap_size_hint
167
167
 
168
168
 
169
- class IterHasNext(ast.expr):
170
- """Checks if an iterator has a next element using the `__hasnext__` magic method.
171
-
172
- This node is inserted in `for` loops and list comprehensions.
173
- """
174
-
175
- value: ast.expr
176
-
177
- _fields = ("value",)
178
-
179
-
180
169
  class IterNext(ast.expr):
181
170
  """Obtains the next element of an iterator using the `__next__` magic method.
182
171
 
@@ -188,18 +177,6 @@ class IterNext(ast.expr):
188
177
  _fields = ("value",)
189
178
 
190
179
 
191
- class IterEnd(ast.expr):
192
- """Finalises an iterator using the `__end__` magic method.
193
-
194
- This node is inserted in `for` loops and list comprehensions. It is needed to
195
- consume linear iterators once they are finished.
196
- """
197
-
198
- value: ast.expr
199
-
200
- _fields = ("value",)
201
-
202
-
203
180
  class DesugaredGenerator(ast.expr):
204
181
  """A single desugared generator in a list comprehension.
205
182
 
@@ -47,16 +47,13 @@ class ConstWasmModule(val.ExtensionValue):
47
47
  """Python wrapper for the tket ConstWasmModule type"""
48
48
 
49
49
  wasm_file: str
50
- wasm_hash: int
51
50
 
52
51
  def to_value(self) -> val.Extension:
53
52
  ty = WASM_EXTENSION.get_type("module").instantiate([])
54
53
 
55
- name = "tket.wasm.ConstWasmModule"
56
- payload = {"name": self.wasm_file, "hash": self.wasm_hash}
54
+ name = "ConstWasmModule"
55
+ payload = {"module_filename": self.wasm_file}
57
56
  return val.Extension(name, typ=ty, val=payload, extensions=["tket.wasm"])
58
57
 
59
58
  def __str__(self) -> str:
60
- return (
61
- f"ConstWasmModule(wasm_file={self.wasm_file}, wasm_hash={self.wasm_hash})"
62
- )
59
+ return f"tket.wasm.module(module_filename={self.wasm_file})"
@@ -8,12 +8,11 @@ from guppylang_internals.nodes import GlobalCall
8
8
  from guppylang_internals.std._internal.compiler.arithmetic import convert_itousize
9
9
  from guppylang_internals.std._internal.compiler.prelude import build_unwrap
10
10
  from guppylang_internals.std._internal.compiler.tket_exts import (
11
- FUTURES_EXTENSION,
12
11
  WASM_EXTENSION,
13
12
  ConstWasmModule,
14
13
  )
15
14
  from guppylang_internals.tys.builtin import (
16
- wasm_module_info,
15
+ wasm_module_name,
17
16
  )
18
17
  from guppylang_internals.tys.ty import (
19
18
  FunctionType,
@@ -57,18 +56,20 @@ class WasmModuleDiscardCompiler(CustomInoutCallCompiler):
57
56
 
58
57
  class WasmModuleCallCompiler(CustomInoutCallCompiler):
59
58
  """Compiler for WASM calls
60
- When a wasm method is called in guppy, we turn it into 2 tket ops:
59
+ When a wasm method is called in guppy, we turn it into 3 tket ops:
61
60
  * lookup: wasm.module -> wasm.func
62
- * call: wasm.context * wasm.func * inputs -> wasm.context * output
63
-
61
+ * call: wasm.context * wasm.func * inputs -> wasm.result
62
+ * read_result: wasm.result -> wasm.context * outputs
64
63
  For the wasm.module that we use in lookup, a constant is created for each
65
64
  call, using the wasm file information embedded in method's `self` argument.
66
65
  """
67
66
 
68
67
  fn_name: str
68
+ fn_id: int | None
69
69
 
70
- def __init__(self, name: str) -> None:
70
+ def __init__(self, name: str, id_: int | None) -> None:
71
71
  self.fn_name = name
72
+ self.fn_id = id_
72
73
 
73
74
  def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
74
75
  # The arguments should be:
@@ -93,14 +94,13 @@ class WasmModuleCallCompiler(CustomInoutCallCompiler):
93
94
  func_ty = WASM_EXTENSION.get_type("func").instantiate(
94
95
  [inputs_row_arg, output_row_arg]
95
96
  )
96
- future_ty = FUTURES_EXTENSION.get_type("Future").instantiate(
97
- [ht.Tuple(*wasm_sig.output).type_arg()]
98
- )
97
+ result_ty = WASM_EXTENSION.get_type("result").instantiate([output_row_arg])
99
98
 
100
99
  # Get the WASM module information from the type
101
100
  selfarg = self.func.ty.inputs[0].ty
102
- if info := wasm_module_info(selfarg):
103
- const_module = self.builder.add_const(ConstWasmModule(*info))
101
+ info = wasm_module_name(selfarg)
102
+ if info is not None:
103
+ const_module = self.builder.add_const(ConstWasmModule(info))
104
104
  else:
105
105
  raise InternalGuppyError(
106
106
  "Expected cached signature to have WASM module as first arg"
@@ -109,27 +109,38 @@ class WasmModuleCallCompiler(CustomInoutCallCompiler):
109
109
  wasm_module = self.builder.load(const_module)
110
110
 
111
111
  # Lookup the function we want
112
- wasm_opdef = WASM_EXTENSION.get_op("lookup").instantiate(
113
- [fn_name_arg, inputs_row_arg, output_row_arg],
114
- ht.FunctionType([module_ty], [func_ty]),
115
- )
112
+ if self.fn_id is None:
113
+ fn_name_arg = ht.StringArg(self.fn_name)
114
+ wasm_opdef = WASM_EXTENSION.get_op("lookup_by_name").instantiate(
115
+ [fn_name_arg, inputs_row_arg, output_row_arg],
116
+ ht.FunctionType([module_ty], [func_ty]),
117
+ )
118
+ else:
119
+ fn_id_arg = ht.BoundedNatArg(self.fn_id)
120
+ wasm_opdef = WASM_EXTENSION.get_op("lookup_by_id").instantiate(
121
+ [fn_id_arg, inputs_row_arg, output_row_arg],
122
+ ht.FunctionType([module_ty], [func_ty]),
123
+ )
124
+
116
125
  wasm_func = self.builder.add_op(wasm_opdef, wasm_module)
117
126
 
118
127
  # Call the function
119
128
  call_op = WASM_EXTENSION.get_op("call").instantiate(
120
129
  [inputs_row_arg, output_row_arg],
121
- ht.FunctionType([ctx_ty, func_ty, *wasm_sig.input], [ctx_ty, future_ty]),
130
+ ht.FunctionType([ctx_ty, func_ty, *wasm_sig.input], [result_ty]),
122
131
  )
123
132
 
124
- ctx, future = self.builder.add_op(call_op, args[0], wasm_func, *args[1:])
133
+ result = self.builder.add_op(call_op, args[0], wasm_func, *args[1:])
125
134
 
126
- read_opdef = FUTURES_EXTENSION.get_op("Read").instantiate(
127
- [ht.Tuple(*wasm_sig.output).type_arg()],
128
- ht.FunctionType([future_ty], [ht.Tuple(*wasm_sig.output)]),
135
+ read_opdef = WASM_EXTENSION.get_op("read_result").instantiate(
136
+ [output_row_arg],
137
+ ht.FunctionType([result_ty], [ctx_ty, *wasm_sig.output]),
129
138
  )
130
- result = self.builder.add_op(read_opdef, future)
131
- ws: list[Wire] = list(result[:])
132
- node = self.builder.add_op(ops.UnpackTuple(wasm_sig.output), *ws)
133
- ws: list[Wire] = list(node[:])
134
-
135
- return CallReturnWires(regular_returns=ws, inout_returns=[ctx])
139
+ data = self.builder.add_op(read_opdef, result)
140
+ match list(data[:]):
141
+ case [ctx]:
142
+ return CallReturnWires(regular_returns=[], inout_returns=[ctx])
143
+ case [ctx, *values]:
144
+ return CallReturnWires(regular_returns=[*values], inout_returns=[ctx])
145
+ case _:
146
+ raise AssertionError("impossible")