guppylang-internals 0.23.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30):
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/checker/core.py +8 -0
  3. guppylang_internals/checker/expr_checker.py +10 -20
  4. guppylang_internals/checker/func_checker.py +170 -21
  5. guppylang_internals/checker/stmt_checker.py +1 -1
  6. guppylang_internals/decorator.py +124 -58
  7. guppylang_internals/definition/const.py +2 -2
  8. guppylang_internals/definition/custom.py +1 -1
  9. guppylang_internals/definition/declaration.py +1 -1
  10. guppylang_internals/definition/extern.py +2 -2
  11. guppylang_internals/definition/function.py +1 -1
  12. guppylang_internals/definition/parameter.py +2 -2
  13. guppylang_internals/definition/pytket_circuits.py +1 -1
  14. guppylang_internals/definition/struct.py +10 -10
  15. guppylang_internals/definition/traced.py +1 -1
  16. guppylang_internals/definition/ty.py +6 -0
  17. guppylang_internals/definition/wasm.py +2 -2
  18. guppylang_internals/engine.py +13 -2
  19. guppylang_internals/nodes.py +0 -23
  20. guppylang_internals/std/_internal/compiler/tket_exts.py +3 -6
  21. guppylang_internals/std/_internal/compiler/wasm.py +37 -26
  22. guppylang_internals/tracing/function.py +13 -2
  23. guppylang_internals/tracing/unpacking.py +18 -12
  24. guppylang_internals/tys/builtin.py +30 -11
  25. guppylang_internals/tys/errors.py +6 -0
  26. guppylang_internals/tys/parsing.py +111 -125
  27. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.24.0.dist-info}/METADATA +3 -3
  28. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.24.0.dist-info}/RECORD +30 -30
  29. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.24.0.dist-info}/WHEEL +0 -0
  30. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.24.0.dist-info}/licenses/LICENCE +0 -0
@@ -3,7 +3,7 @@ import inspect
3
3
  import linecache
4
4
  import sys
5
5
  from collections.abc import Sequence
6
- from dataclasses import dataclass
6
+ from dataclasses import dataclass, field
7
7
  from types import FrameType
8
8
  from typing import ClassVar
9
9
 
@@ -39,7 +39,7 @@ from guppylang_internals.ipython_inspect import is_running_ipython
39
39
  from guppylang_internals.span import SourceMap, Span, to_span
40
40
  from guppylang_internals.tys.arg import Argument
41
41
  from guppylang_internals.tys.param import Parameter, check_all_args
42
- from guppylang_internals.tys.parsing import type_from_ast
42
+ from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast
43
43
  from guppylang_internals.tys.ty import (
44
44
  FuncInput,
45
45
  FunctionType,
@@ -115,6 +115,7 @@ class RawStructDef(TypeDef, ParsableDef):
115
115
  """A raw struct type definition that has not been parsed yet."""
116
116
 
117
117
  python_class: type
118
+ params: None = field(default=None, init=False) # Params not known yet
118
119
 
119
120
  def parse(self, globals: Globals, sources: SourceMap) -> "ParsedStructDef":
120
121
  """Parses the raw class object into an AST and checks that it is well-formed."""
@@ -211,15 +212,16 @@ class ParsedStructDef(TypeDef, CheckableDef):
211
212
 
212
213
  def check(self, globals: Globals) -> "CheckedStructDef":
213
214
  """Checks that all struct fields have valid types."""
215
+ param_var_mapping = {p.name: p for p in self.params}
216
+ ctx = TypeParsingCtx(globals, param_var_mapping)
217
+
214
218
  # Before checking the fields, make sure that this definition is not recursive,
215
219
  # otherwise the code below would not terminate.
216
220
  # TODO: This is not ideal (see todo in `check_instantiate`)
217
- param_var_mapping = {p.name: p for p in self.params}
218
- check_not_recursive(self, globals, param_var_mapping)
221
+ check_not_recursive(self, ctx)
219
222
 
220
223
  fields = [
221
- StructField(f.name, type_from_ast(f.type_ast, globals, param_var_mapping))
222
- for f in self.fields
224
+ StructField(f.name, type_from_ast(f.type_ast, ctx)) for f in self.fields
223
225
  ]
224
226
  return CheckedStructDef(
225
227
  self.id, self.name, self.defined_at, self.params, fields
@@ -370,9 +372,7 @@ def params_from_ast(nodes: Sequence[ast.expr], globals: Globals) -> list[Paramet
370
372
  return params
371
373
 
372
374
 
373
- def check_not_recursive(
374
- defn: ParsedStructDef, globals: Globals, param_var_mapping: dict[str, Parameter]
375
- ) -> None:
375
+ def check_not_recursive(defn: ParsedStructDef, ctx: TypeParsingCtx) -> None:
376
376
  """Throws a user error if the given struct definition is recursive."""
377
377
 
378
378
  # TODO: The implementation below hijacks the type parsing logic to detect recursive
@@ -388,5 +388,5 @@ def check_not_recursive(
388
388
  original = defn.check_instantiate
389
389
  object.__setattr__(defn, "check_instantiate", dummy_check_instantiate)
390
390
  for fld in defn.fields:
391
- type_from_ast(fld.type_ast, globals, param_var_mapping)
391
+ type_from_ast(fld.type_ast, ctx)
392
392
  object.__setattr__(defn, "check_instantiate", original)
@@ -48,7 +48,7 @@ class RawTracedFunctionDef(ParsableDef):
48
48
  def parse(self, globals: Globals, sources: SourceMap) -> "TracedFunctionDef":
49
49
  """Parses and checks the user-provided signature of the function."""
50
50
  func_ast, _docstring = parse_py_func(self.python_func, sources)
51
- ty = check_signature(func_ast, globals)
51
+ ty = check_signature(func_ast, globals, self.id)
52
52
  if ty.parametrized:
53
53
  raise GuppyError(UnsupportedError(func_ast, "Generic comptime functions"))
54
54
  return TracedFunctionDef(self.id, self.name, func_ast, ty, self.python_func)
@@ -18,6 +18,12 @@ class TypeDef(Definition):
18
18
 
19
19
  description: str = field(default="type", init=False)
20
20
 
21
+ #: Generic parameters of the type. This may be `None` for special types that are
22
+ #: more polymorphic than the regular type system allows (for example `tuple` and
23
+ #: `Callable`), or if this is a raw definition whose parameters are not determined
24
+ #: yet (for example a `RawStructDef`).
25
+ params: Sequence[Parameter] | None
26
+
21
27
  @abstractmethod
22
28
  def check_instantiate(
23
29
  self, args: Sequence[Argument], loc: AstNode | None = None
@@ -11,7 +11,7 @@ from guppylang_internals.definition.custom import (
11
11
  )
12
12
  from guppylang_internals.error import GuppyError
13
13
  from guppylang_internals.span import SourceMap
14
- from guppylang_internals.tys.builtin import wasm_module_info
14
+ from guppylang_internals.tys.builtin import wasm_module_name
15
15
  from guppylang_internals.tys.ty import (
16
16
  FuncInput,
17
17
  FunctionType,
@@ -30,7 +30,7 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
30
30
  def sanitise_type(self, loc: AstNode | None, fun_ty: FunctionType) -> None:
31
31
  # Place to highlight in error messages
32
32
  match fun_ty.inputs[0]:
33
- case FuncInput(ty=ty, flags=InputFlags.Inout) if wasm_module_info(
33
+ case FuncInput(ty=ty, flags=InputFlags.Inout) if wasm_module_name(
34
34
  ty
35
35
  ) is not None:
36
36
  pass
@@ -41,6 +41,7 @@ from guppylang_internals.tys.builtin import (
41
41
  nat_type_def,
42
42
  none_type_def,
43
43
  option_type_def,
44
+ self_type_def,
44
45
  sized_iter_type_def,
45
46
  string_type_def,
46
47
  tuple_type_def,
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
51
52
 
52
53
  BUILTIN_DEFS_LIST: list[RawDef] = [
53
54
  callable_type_def,
55
+ self_type_def,
54
56
  tuple_type_def,
55
57
  none_type_def,
56
58
  bool_type_def,
@@ -84,12 +86,14 @@ class DefinitionStore:
84
86
 
85
87
  raw_defs: dict[DefId, RawDef]
86
88
  impls: defaultdict[DefId, dict[str, DefId]]
89
+ impl_parents: dict[DefId, DefId]
87
90
  frames: dict[DefId, FrameType]
88
91
  sources: SourceMap
89
92
 
90
93
  def __init__(self) -> None:
91
94
  self.raw_defs = {defn.id: defn for defn in BUILTIN_DEFS_LIST}
92
95
  self.impls = defaultdict(dict)
96
+ self.impl_parents = {}
93
97
  self.frames = {}
94
98
  self.sources = SourceMap()
95
99
 
@@ -99,7 +103,9 @@ class DefinitionStore:
99
103
  self.frames[defn.id] = frame
100
104
 
101
105
  def register_impl(self, ty_id: DefId, name: str, impl_id: DefId) -> None:
106
+ assert impl_id not in self.impl_parents, "Already an impl"
102
107
  self.impls[ty_id][name] = impl_id
108
+ self.impl_parents[impl_id] = ty_id
103
109
  # Update the frame of the definition to the frame of the defining class
104
110
  if impl_id in self.frames:
105
111
  frame = self.frames[impl_id].f_back
@@ -138,18 +144,23 @@ class CompilationEngine:
138
144
  types_to_check_worklist: dict[DefId, ParsedDef]
139
145
  to_check_worklist: dict[DefId, ParsedDef]
140
146
 
147
+ def __init__(self) -> None:
148
+ """Resets the compilation cache."""
149
+ self.reset()
150
+ self.additional_extensions = []
151
+
141
152
  def reset(self) -> None:
142
153
  """Resets the compilation cache."""
143
154
  self.parsed = {}
144
155
  self.checked = {}
145
156
  self.compiled = {}
146
- self.additional_extensions = []
147
157
  self.to_check_worklist = {}
148
158
  self.types_to_check_worklist = {}
149
159
 
150
160
  @pretty_errors
151
161
  def register_extension(self, extension: Extension) -> None:
152
- self.additional_extensions.append(extension)
162
+ if extension not in self.additional_extensions:
163
+ self.additional_extensions.append(extension)
153
164
 
154
165
  @pretty_errors
155
166
  def get_parsed(self, id: DefId) -> ParsedDef:
@@ -166,17 +166,6 @@ class MakeIter(ast.expr):
166
166
  self.unwrap_size_hint = unwrap_size_hint
167
167
 
168
168
 
169
- class IterHasNext(ast.expr):
170
- """Checks if an iterator has a next element using the `__hasnext__` magic method.
171
-
172
- This node is inserted in `for` loops and list comprehensions.
173
- """
174
-
175
- value: ast.expr
176
-
177
- _fields = ("value",)
178
-
179
-
180
169
  class IterNext(ast.expr):
181
170
  """Obtains the next element of an iterator using the `__next__` magic method.
182
171
 
@@ -188,18 +177,6 @@ class IterNext(ast.expr):
188
177
  _fields = ("value",)
189
178
 
190
179
 
191
- class IterEnd(ast.expr):
192
- """Finalises an iterator using the `__end__` magic method.
193
-
194
- This node is inserted in `for` loops and list comprehensions. It is needed to
195
- consume linear iterators once they are finished.
196
- """
197
-
198
- value: ast.expr
199
-
200
- _fields = ("value",)
201
-
202
-
203
180
  class DesugaredGenerator(ast.expr):
204
181
  """A single desugared generator in a list comprehension.
205
182
 
@@ -47,16 +47,13 @@ class ConstWasmModule(val.ExtensionValue):
47
47
  """Python wrapper for the tket ConstWasmModule type"""
48
48
 
49
49
  wasm_file: str
50
- wasm_hash: int
51
50
 
52
51
  def to_value(self) -> val.Extension:
53
52
  ty = WASM_EXTENSION.get_type("module").instantiate([])
54
53
 
55
- name = "tket.wasm.ConstWasmModule"
56
- payload = {"name": self.wasm_file, "hash": self.wasm_hash}
54
+ name = "ConstWasmModule"
55
+ payload = {"module_filename": self.wasm_file}
57
56
  return val.Extension(name, typ=ty, val=payload, extensions=["tket.wasm"])
58
57
 
59
58
  def __str__(self) -> str:
60
- return (
61
- f"ConstWasmModule(wasm_file={self.wasm_file}, wasm_hash={self.wasm_hash})"
62
- )
59
+ return f"tket.wasm.module(module_filename={self.wasm_file})"
@@ -8,12 +8,11 @@ from guppylang_internals.nodes import GlobalCall
8
8
  from guppylang_internals.std._internal.compiler.arithmetic import convert_itousize
9
9
  from guppylang_internals.std._internal.compiler.prelude import build_unwrap
10
10
  from guppylang_internals.std._internal.compiler.tket_exts import (
11
- FUTURES_EXTENSION,
12
11
  WASM_EXTENSION,
13
12
  ConstWasmModule,
14
13
  )
15
14
  from guppylang_internals.tys.builtin import (
16
- wasm_module_info,
15
+ wasm_module_name,
17
16
  )
18
17
  from guppylang_internals.tys.ty import (
19
18
  FunctionType,
@@ -57,18 +56,20 @@ class WasmModuleDiscardCompiler(CustomInoutCallCompiler):
57
56
 
58
57
  class WasmModuleCallCompiler(CustomInoutCallCompiler):
59
58
  """Compiler for WASM calls
60
- When a wasm method is called in guppy, we turn it into 2 tket ops:
59
+ When a wasm method is called in guppy, we turn it into 3 tket ops:
61
60
  * lookup: wasm.module -> wasm.func
62
- * call: wasm.context * wasm.func * inputs -> wasm.context * output
63
-
61
+ * call: wasm.context * wasm.func * inputs -> wasm.result
62
+ * read_result: wasm.result -> wasm.context * outputs
64
63
  For the wasm.module that we use in lookup, a constant is created for each
65
64
  call, using the wasm file information embedded in method's `self` argument.
66
65
  """
67
66
 
68
67
  fn_name: str
68
+ fn_id: int | None
69
69
 
70
- def __init__(self, name: str) -> None:
70
+ def __init__(self, name: str, id_: int | None) -> None:
71
71
  self.fn_name = name
72
+ self.fn_id = id_
72
73
 
73
74
  def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
74
75
  # The arguments should be:
@@ -93,14 +94,13 @@ class WasmModuleCallCompiler(CustomInoutCallCompiler):
93
94
  func_ty = WASM_EXTENSION.get_type("func").instantiate(
94
95
  [inputs_row_arg, output_row_arg]
95
96
  )
96
- future_ty = FUTURES_EXTENSION.get_type("Future").instantiate(
97
- [ht.Tuple(*wasm_sig.output).type_arg()]
98
- )
97
+ result_ty = WASM_EXTENSION.get_type("result").instantiate([output_row_arg])
99
98
 
100
99
  # Get the WASM module information from the type
101
100
  selfarg = self.func.ty.inputs[0].ty
102
- if info := wasm_module_info(selfarg):
103
- const_module = self.builder.add_const(ConstWasmModule(*info))
101
+ info = wasm_module_name(selfarg)
102
+ if info is not None:
103
+ const_module = self.builder.add_const(ConstWasmModule(info))
104
104
  else:
105
105
  raise InternalGuppyError(
106
106
  "Expected cached signature to have WASM module as first arg"
@@ -109,27 +109,38 @@ class WasmModuleCallCompiler(CustomInoutCallCompiler):
109
109
  wasm_module = self.builder.load(const_module)
110
110
 
111
111
  # Lookup the function we want
112
- wasm_opdef = WASM_EXTENSION.get_op("lookup").instantiate(
113
- [fn_name_arg, inputs_row_arg, output_row_arg],
114
- ht.FunctionType([module_ty], [func_ty]),
115
- )
112
+ if self.fn_id is None:
113
+ fn_name_arg = ht.StringArg(self.fn_name)
114
+ wasm_opdef = WASM_EXTENSION.get_op("lookup_by_name").instantiate(
115
+ [fn_name_arg, inputs_row_arg, output_row_arg],
116
+ ht.FunctionType([module_ty], [func_ty]),
117
+ )
118
+ else:
119
+ fn_id_arg = ht.BoundedNatArg(self.fn_id)
120
+ wasm_opdef = WASM_EXTENSION.get_op("lookup_by_id").instantiate(
121
+ [fn_id_arg, inputs_row_arg, output_row_arg],
122
+ ht.FunctionType([module_ty], [func_ty]),
123
+ )
124
+
116
125
  wasm_func = self.builder.add_op(wasm_opdef, wasm_module)
117
126
 
118
127
  # Call the function
119
128
  call_op = WASM_EXTENSION.get_op("call").instantiate(
120
129
  [inputs_row_arg, output_row_arg],
121
- ht.FunctionType([ctx_ty, func_ty, *wasm_sig.input], [ctx_ty, future_ty]),
130
+ ht.FunctionType([ctx_ty, func_ty, *wasm_sig.input], [result_ty]),
122
131
  )
123
132
 
124
- ctx, future = self.builder.add_op(call_op, args[0], wasm_func, *args[1:])
133
+ result = self.builder.add_op(call_op, args[0], wasm_func, *args[1:])
125
134
 
126
- read_opdef = FUTURES_EXTENSION.get_op("Read").instantiate(
127
- [ht.Tuple(*wasm_sig.output).type_arg()],
128
- ht.FunctionType([future_ty], [ht.Tuple(*wasm_sig.output)]),
135
+ read_opdef = WASM_EXTENSION.get_op("read_result").instantiate(
136
+ [output_row_arg],
137
+ ht.FunctionType([result_ty], [ctx_ty, *wasm_sig.output]),
129
138
  )
130
- result = self.builder.add_op(read_opdef, future)
131
- ws: list[Wire] = list(result[:])
132
- node = self.builder.add_op(ops.UnpackTuple(wasm_sig.output), *ws)
133
- ws: list[Wire] = list(node[:])
134
-
135
- return CallReturnWires(regular_returns=ws, inout_returns=[ctx])
139
+ data = self.builder.add_op(read_opdef, result)
140
+ match list(data[:]):
141
+ case [ctx]:
142
+ return CallReturnWires(regular_returns=[], inout_returns=[ctx])
143
+ case [ctx, *values]:
144
+ return CallReturnWires(regular_returns=[*values], inout_returns=[ctx])
145
+ case _:
146
+ raise AssertionError("impossible")
@@ -176,10 +176,21 @@ def trace_call(func: CallableDef, *args: Any) -> Any:
176
176
  if len(func.ty.inputs) != 0:
177
177
  for inp, arg, var in zip(func.ty.inputs, args, arg_vars, strict=True):
178
178
  if InputFlags.Inout in inp.flags:
179
+ # Note that `inp.ty` could refer to bound variables in the function
180
+ # signature. Instead, make sure to use `var.ty` which will always be a
181
+ # concrete type and type checking has ensured that they unify.
182
+ ty = var.ty
179
183
  inout_wire = state.dfg[var]
180
- update_packed_value(
181
- arg, GuppyObject(inp.ty, inout_wire), state.dfg.builder
184
+ success = update_packed_value(
185
+ arg, GuppyObject(ty, inout_wire), state.dfg.builder
182
186
  )
187
+ if not success:
188
+ # This means the user has passed an object that we cannot update,
189
+ # e.g. calling `mem_swap(x, y)` where the inputs are plain Python
190
+ # objects
191
+ raise GuppyComptimeError(
192
+ f"Cannot borrow Python object of type `{ty}` at comptime"
193
+ )
183
194
 
184
195
  ret_obj = GuppyObject(ret_ty, ret_wire)
185
196
  return unpack_guppy_object(ret_obj, state.dfg.builder)
@@ -150,13 +150,15 @@ def guppy_object_from_py(
150
150
  return GuppyObject(ty, builder.load(hugr_val))
151
151
 
152
152
 
153
- def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> None:
153
+ def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> bool:
154
154
  """Given a Python value `v` and a `GuppyObject` `obj` that was constructed from `v`
155
- using `guppy_object_from_py`, updates the wires of any `GuppyObjects` contained in
156
- `v` to the new wires specified by `obj`.
155
+ using `guppy_object_from_py`, tries to update the wires of any `GuppyObjects`
156
+ contained in `v` to the new wires specified by `obj`.
157
157
 
158
158
  Also resets the used flag on any of those updated wires. This corresponds to making
159
159
  the object available again since it now corresponds to a fresh wire.
160
+
161
+ Returns `True` if all wires could be updated, otherwise `False`.
160
162
  """
161
163
  match v:
162
164
  case GuppyObject() as v_obj:
@@ -172,23 +174,27 @@ def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> None:
172
174
  assert isinstance(obj._ty, TupleType)
173
175
  wires = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)).outputs()
174
176
  for v, ty, wire in zip(vs, obj._ty.element_types, wires, strict=True):
175
- update_packed_value(v, GuppyObject(ty, wire), builder)
177
+ success = update_packed_value(v, GuppyObject(ty, wire), builder)
178
+ if not success:
179
+ return False
176
180
  case GuppyStructObject(_ty=ty, _field_values=values):
177
181
  assert obj._ty == ty
178
182
  wires = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)).outputs()
179
- for (
180
- field,
181
- wire,
182
- ) in zip(ty.fields, wires, strict=True):
183
+ for field, wire in zip(ty.fields, wires, strict=True):
183
184
  v = values[field.name]
184
- update_packed_value(v, GuppyObject(field.ty, wire), builder)
185
+ success = update_packed_value(v, GuppyObject(field.ty, wire), builder)
186
+ if not success:
187
+ values[field.name] = obj
185
188
  case list(vs) if len(vs) > 0:
186
189
  assert is_array_type(obj._ty)
187
190
  elem_ty = get_element_type(obj._ty)
188
191
  opt_wires = unpack_array(builder, obj._use_wire(None))
189
192
  err = "Non-droppable array element has already been used"
190
- for v, opt_wire in zip(vs, opt_wires, strict=True):
193
+ for i, (v, opt_wire) in enumerate(zip(vs, opt_wires, strict=True)):
191
194
  (wire,) = build_unwrap(builder, opt_wire, err).outputs()
192
- update_packed_value(v, GuppyObject(elem_ty, wire), builder)
195
+ success = update_packed_value(v, GuppyObject(elem_ty, wire), builder)
196
+ if not success:
197
+ vs[i] = obj
193
198
  case _:
194
- pass
199
+ return False
200
+ return True
@@ -46,6 +46,27 @@ class CallableTypeDef(TypeDef, CompiledDef):
46
46
  raise InternalGuppyError("Tried to `Callable` type via `check_instantiate`")
47
47
 
48
48
 
49
+ @dataclass(frozen=True)
50
+ class SelfTypeDef(TypeDef, CompiledDef):
51
+ """Type definition associated with the `Self` type on methods.
52
+
53
+ During type parsing, we make sure that this type is replaced with the concrete type
54
+ the method is attached to. Thus, we should never have instances of this type around.
55
+
56
+ In other words, this definition is only a marker so that type parsing doesn't have
57
+ to rely on matching against the string "Self". By making `Self` a definition, we can
58
+ use the existing identifier tracking system and also handle users shadowing the
59
+ `Self` binder or assigning `Self` to some other name.
60
+ """
61
+
62
+ name: Literal["Self"] = field(default="Self", init=False)
63
+
64
+ def check_instantiate(
65
+ self, args: Sequence[Argument], loc: AstNode | None = None
66
+ ) -> FunctionType:
67
+ raise InternalGuppyError("Tried to instantiate abstract `Self` type`")
68
+
69
+
49
70
  @dataclass(frozen=True)
50
71
  class _TupleTypeDef(TypeDef, CompiledDef):
51
72
  """Type definition associated with the builtin `tuple` type.
@@ -106,7 +127,6 @@ class _NumericTypeDef(TypeDef, CompiledDef):
106
127
 
107
128
  class WasmModuleTypeDef(OpaqueTypeDef):
108
129
  wasm_file: str
109
- wasm_hash: int
110
130
 
111
131
  def __init__(
112
132
  self,
@@ -114,11 +134,9 @@ class WasmModuleTypeDef(OpaqueTypeDef):
114
134
  name: str,
115
135
  defined_at: ast.AST | None,
116
136
  wasm_file: str,
117
- wasm_hash: int,
118
137
  ) -> None:
119
138
  super().__init__(id, name, defined_at, [], True, True, self.to_hugr)
120
139
  self.wasm_file = wasm_file
121
- self.wasm_hash = wasm_hash
122
140
 
123
141
  def to_hugr(
124
142
  self, args: Sequence[TypeArg | ConstArg], ctx: ToHugrContext
@@ -189,9 +207,10 @@ def _option_to_hugr(args: Sequence[Argument], ctx: ToHugrContext) -> ht.Type:
189
207
  return ht.Option(arg.ty.to_hugr(ctx))
190
208
 
191
209
 
192
- callable_type_def = CallableTypeDef(DefId.fresh(), None)
193
- tuple_type_def = _TupleTypeDef(DefId.fresh(), None)
194
- none_type_def = _NoneTypeDef(DefId.fresh(), None)
210
+ callable_type_def = CallableTypeDef(DefId.fresh(), None, None)
211
+ self_type_def = SelfTypeDef(DefId.fresh(), None, [])
212
+ tuple_type_def = _TupleTypeDef(DefId.fresh(), None, None)
213
+ none_type_def = _NoneTypeDef(DefId.fresh(), None, [])
195
214
  bool_type_def = OpaqueTypeDef(
196
215
  id=DefId.fresh(),
197
216
  name="bool",
@@ -202,13 +221,13 @@ bool_type_def = OpaqueTypeDef(
202
221
  to_hugr=lambda args, ctx: OpaqueBool,
203
222
  )
204
223
  nat_type_def = _NumericTypeDef(
205
- DefId.fresh(), "nat", None, NumericType(NumericType.Kind.Nat)
224
+ DefId.fresh(), "nat", None, [], NumericType(NumericType.Kind.Nat)
206
225
  )
207
226
  int_type_def = _NumericTypeDef(
208
- DefId.fresh(), "int", None, NumericType(NumericType.Kind.Int)
227
+ DefId.fresh(), "int", None, [], NumericType(NumericType.Kind.Int)
209
228
  )
210
229
  float_type_def = _NumericTypeDef(
211
- DefId.fresh(), "float", None, NumericType(NumericType.Kind.Float)
230
+ DefId.fresh(), "float", None, [], NumericType(NumericType.Kind.Float)
212
231
  )
213
232
  string_type_def = OpaqueTypeDef(
214
233
  id=DefId.fresh(),
@@ -345,9 +364,9 @@ def is_sized_iter_type(ty: Type) -> TypeGuard[OpaqueType]:
345
364
  return isinstance(ty, OpaqueType) and ty.defn == sized_iter_type_def
346
365
 
347
366
 
348
- def wasm_module_info(ty: Type) -> tuple[str, int] | None:
367
+ def wasm_module_name(ty: Type) -> str | None:
349
368
  if isinstance(ty, OpaqueType) and isinstance(ty.defn, WasmModuleTypeDef):
350
- return ty.defn.wasm_file, ty.defn.wasm_hash
369
+ return ty.defn.wasm_file
351
370
  return None
352
371
 
353
372
 
@@ -116,6 +116,12 @@ class InvalidCallableTypeError(Error):
116
116
  self.add_sub_diagnostic(InvalidCallableTypeError.Explain(None))
117
117
 
118
118
 
119
+ @dataclass(frozen=True)
120
+ class SelfTyNotInMethodError(Error):
121
+ title: ClassVar[str] = "Invalid type"
122
+ span_label: ClassVar[str] = "`Self` type annotations are only allowed in methods"
123
+
124
+
119
125
  @dataclass(frozen=True)
120
126
  class NonLinearOwnedError(Error):
121
127
  title: ClassVar[str] = "Invalid annotation"