guppylang-internals 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49):
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +21 -0
  3. guppylang_internals/cfg/bb.py +20 -0
  4. guppylang_internals/cfg/builder.py +101 -3
  5. guppylang_internals/checker/core.py +12 -0
  6. guppylang_internals/checker/errors/generic.py +32 -1
  7. guppylang_internals/checker/errors/type_errors.py +14 -0
  8. guppylang_internals/checker/expr_checker.py +55 -29
  9. guppylang_internals/checker/func_checker.py +171 -22
  10. guppylang_internals/checker/linearity_checker.py +65 -0
  11. guppylang_internals/checker/modifier_checker.py +116 -0
  12. guppylang_internals/checker/stmt_checker.py +49 -2
  13. guppylang_internals/compiler/core.py +90 -53
  14. guppylang_internals/compiler/expr_compiler.py +49 -114
  15. guppylang_internals/compiler/modifier_compiler.py +174 -0
  16. guppylang_internals/compiler/stmt_compiler.py +15 -8
  17. guppylang_internals/decorator.py +124 -58
  18. guppylang_internals/definition/const.py +2 -2
  19. guppylang_internals/definition/custom.py +36 -2
  20. guppylang_internals/definition/declaration.py +4 -5
  21. guppylang_internals/definition/extern.py +2 -2
  22. guppylang_internals/definition/function.py +1 -1
  23. guppylang_internals/definition/parameter.py +10 -5
  24. guppylang_internals/definition/pytket_circuits.py +14 -42
  25. guppylang_internals/definition/struct.py +17 -14
  26. guppylang_internals/definition/traced.py +1 -1
  27. guppylang_internals/definition/ty.py +9 -3
  28. guppylang_internals/definition/wasm.py +2 -2
  29. guppylang_internals/engine.py +13 -2
  30. guppylang_internals/experimental.py +5 -0
  31. guppylang_internals/nodes.py +124 -23
  32. guppylang_internals/std/_internal/compiler/array.py +94 -282
  33. guppylang_internals/std/_internal/compiler/tket_exts.py +12 -8
  34. guppylang_internals/std/_internal/compiler/wasm.py +37 -26
  35. guppylang_internals/tracing/function.py +13 -2
  36. guppylang_internals/tracing/unpacking.py +33 -28
  37. guppylang_internals/tys/arg.py +18 -3
  38. guppylang_internals/tys/builtin.py +32 -16
  39. guppylang_internals/tys/const.py +33 -4
  40. guppylang_internals/tys/errors.py +6 -0
  41. guppylang_internals/tys/param.py +31 -16
  42. guppylang_internals/tys/parsing.py +118 -145
  43. guppylang_internals/tys/qubit.py +27 -0
  44. guppylang_internals/tys/subst.py +8 -26
  45. guppylang_internals/tys/ty.py +31 -21
  46. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/METADATA +4 -4
  47. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/RECORD +49 -46
  48. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/WHEEL +0 -0
  49. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/licenses/LICENCE +0 -0
@@ -176,10 +176,21 @@ def trace_call(func: CallableDef, *args: Any) -> Any:
176
176
  if len(func.ty.inputs) != 0:
177
177
  for inp, arg, var in zip(func.ty.inputs, args, arg_vars, strict=True):
178
178
  if InputFlags.Inout in inp.flags:
179
+ # Note that `inp.ty` could refer to bound variables in the function
180
+ # signature. Instead, make sure to use `var.ty` which will always be a
181
+ # concrete type and type checking has ensured that they unify.
182
+ ty = var.ty
179
183
  inout_wire = state.dfg[var]
180
- update_packed_value(
181
- arg, GuppyObject(inp.ty, inout_wire), state.dfg.builder
184
+ success = update_packed_value(
185
+ arg, GuppyObject(ty, inout_wire), state.dfg.builder
182
186
  )
187
+ if not success:
188
+ # This means the user has passed an object that we cannot update,
189
+ # e.g. calling `mem_swap(x, y)` where the inputs are plain Python
190
+ # objects
191
+ raise GuppyComptimeError(
192
+ f"Cannot borrow Python object of type `{ty}` at comptime"
193
+ )
183
194
 
184
195
  ret_obj = GuppyObject(ret_ty, ret_wire)
185
196
  return unpack_guppy_object(ret_obj, state.dfg.builder)
@@ -1,7 +1,6 @@
1
1
  from typing import Any, TypeVar
2
2
 
3
3
  from hugr import ops
4
- from hugr import tys as ht
5
4
  from hugr.build.dfg import DfBase
6
5
 
7
6
  from guppylang_internals.ast_util import AstNode
@@ -13,7 +12,6 @@ from guppylang_internals.compiler.core import CompilerContext
13
12
  from guppylang_internals.compiler.expr_compiler import python_value_to_hugr
14
13
  from guppylang_internals.error import GuppyComptimeError, GuppyError
15
14
  from guppylang_internals.std._internal.compiler.array import array_new, unpack_array
16
- from guppylang_internals.std._internal.compiler.prelude import build_unwrap
17
15
  from guppylang_internals.tracing.frozenlist import frozenlist
18
16
  from guppylang_internals.tracing.object import (
19
17
  GuppyObject,
@@ -71,9 +69,7 @@ def unpack_guppy_object(
71
69
  # them as Guppy objects here
72
70
  return obj
73
71
  elem_ty = get_element_type(ty)
74
- opt_elems = unpack_array(builder, obj._use_wire(None))
75
- err = "Non-copyable array element has already been used"
76
- elems = [build_unwrap(builder, opt_elem, err) for opt_elem in opt_elems]
72
+ elems = unpack_array(builder, obj._use_wire(None))
77
73
  obj_list = [
78
74
  unpack_guppy_object(GuppyObject(elem_ty, wire), builder, frozen)
79
75
  for wire in elems
@@ -128,11 +124,8 @@ def guppy_object_from_py(
128
124
  f"Element at index {i + 1} does not match the type of "
129
125
  f"previous elements. Expected `{elem_ty}`, got `{obj._ty}`."
130
126
  )
131
- hugr_elem_ty = ht.Option(elem_ty.to_hugr(ctx))
132
- wires = [
133
- builder.add_op(ops.Tag(1, hugr_elem_ty), obj._use_wire(None))
134
- for obj in objs
135
- ]
127
+ hugr_elem_ty = elem_ty.to_hugr(ctx)
128
+ wires = [obj._use_wire(None) for obj in objs]
136
129
  return GuppyObject(
137
130
  array_type(elem_ty, len(vs)),
138
131
  builder.add_op(array_new(hugr_elem_ty, len(vs)), *wires),
@@ -150,13 +143,15 @@ def guppy_object_from_py(
150
143
  return GuppyObject(ty, builder.load(hugr_val))
151
144
 
152
145
 
153
- def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> None:
146
+ def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> bool:
154
147
  """Given a Python value `v` and a `GuppyObject` `obj` that was constructed from `v`
155
- using `guppy_object_from_py`, updates the wires of any `GuppyObjects` contained in
156
- `v` to the new wires specified by `obj`.
148
+ using `guppy_object_from_py`, tries to update the wires of any `GuppyObjects`
149
+ contained in `v` to the new wires specified by `obj`.
157
150
 
158
151
  Also resets the used flag on any of those updated wires. This corresponds to making
159
152
  the object available again since it now corresponds to a fresh wire.
153
+
154
+ Returns `True` if all wires could be updated, otherwise `False`.
160
155
  """
161
156
  match v:
162
157
  case GuppyObject() as v_obj:
@@ -170,25 +165,35 @@ def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> None:
170
165
  assert isinstance(obj._ty, NoneType)
171
166
  case tuple(vs):
172
167
  assert isinstance(obj._ty, TupleType)
173
- wires = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)).outputs()
174
- for v, ty, wire in zip(vs, obj._ty.element_types, wires, strict=True):
175
- update_packed_value(v, GuppyObject(ty, wire), builder)
168
+ wire_iterator = builder.add_op(
169
+ ops.UnpackTuple(), obj._use_wire(None)
170
+ ).outputs()
171
+ for v, ty, out_wire in zip(
172
+ vs, obj._ty.element_types, wire_iterator, strict=True
173
+ ):
174
+ success = update_packed_value(v, GuppyObject(ty, out_wire), builder)
175
+ if not success:
176
+ return False
176
177
  case GuppyStructObject(_ty=ty, _field_values=values):
177
178
  assert obj._ty == ty
178
- wires = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)).outputs()
179
- for (
180
- field,
181
- wire,
182
- ) in zip(ty.fields, wires, strict=True):
179
+ wire_iterator = builder.add_op(
180
+ ops.UnpackTuple(), obj._use_wire(None)
181
+ ).outputs()
182
+ for field, out_wire in zip(ty.fields, wire_iterator, strict=True):
183
183
  v = values[field.name]
184
- update_packed_value(v, GuppyObject(field.ty, wire), builder)
184
+ success = update_packed_value(
185
+ v, GuppyObject(field.ty, out_wire), builder
186
+ )
187
+ if not success:
188
+ values[field.name] = obj
185
189
  case list(vs) if len(vs) > 0:
186
190
  assert is_array_type(obj._ty)
187
191
  elem_ty = get_element_type(obj._ty)
188
- opt_wires = unpack_array(builder, obj._use_wire(None))
189
- err = "Non-droppable array element has already been used"
190
- for v, opt_wire in zip(vs, opt_wires, strict=True):
191
- (wire,) = build_unwrap(builder, opt_wire, err).outputs()
192
- update_packed_value(v, GuppyObject(elem_ty, wire), builder)
192
+ wires = unpack_array(builder, obj._use_wire(None))
193
+ for i, (v, wire) in enumerate(zip(vs, wires, strict=True)):
194
+ success = update_packed_value(v, GuppyObject(elem_ty, wire), builder)
195
+ if not success:
196
+ vs[i] = obj
193
197
  case _:
194
- pass
198
+ return False
199
+ return True
@@ -1,5 +1,5 @@
1
1
  from abc import ABC, abstractmethod
2
- from dataclasses import dataclass
2
+ from dataclasses import dataclass, field
3
3
  from typing import TYPE_CHECKING, TypeAlias
4
4
 
5
5
  from hugr import tys as ht
@@ -18,7 +18,7 @@ from guppylang_internals.tys.const import (
18
18
  ConstValue,
19
19
  ExistentialConstVar,
20
20
  )
21
- from guppylang_internals.tys.var import ExistentialVar
21
+ from guppylang_internals.tys.var import BoundVar, ExistentialVar
22
22
 
23
23
  if TYPE_CHECKING:
24
24
  from guppylang_internals.tys.ty import Type
@@ -45,19 +45,29 @@ class ArgumentBase(ToHugr[ht.TypeArg], Transformable["Argument"], ABC):
45
45
  def unsolved_vars(self) -> set[ExistentialVar]:
46
46
  """The existential type variables contained in this argument."""
47
47
 
48
+ @property
49
+ @abstractmethod
50
+ def bound_vars(self) -> set[BoundVar]:
51
+ """The bound type variables contained in this argument."""
52
+
48
53
 
49
54
  @dataclass(frozen=True)
50
55
  class TypeArg(ArgumentBase):
51
56
  """Argument that can be instantiated for a `TypeParameter`."""
52
57
 
53
58
  # The type to instantiate
54
- ty: "Type"
59
+ ty: "Type" = field(hash=False) # Types are not hashable
55
60
 
56
61
  @property
57
62
  def unsolved_vars(self) -> set[ExistentialVar]:
58
63
  """The existential type variables contained in this argument."""
59
64
  return self.ty.unsolved_vars
60
65
 
66
+ @property
67
+ def bound_vars(self) -> set[BoundVar]:
68
+ """The bound type variables contained in this type."""
69
+ return self.ty.bound_vars
70
+
61
71
  def to_hugr(self, ctx: ToHugrContext) -> ht.TypeTypeArg:
62
72
  """Computes the Hugr representation of the argument."""
63
73
  ty: ht.Type = self.ty.to_hugr(ctx)
@@ -84,6 +94,11 @@ class ConstArg(ArgumentBase):
84
94
  """The existential const variables contained in this argument."""
85
95
  return self.const.unsolved_vars
86
96
 
97
+ @property
98
+ def bound_vars(self) -> set[BoundVar]:
99
+ """The bound type variables contained in this argument."""
100
+ return self.const.bound_vars
101
+
87
102
  def to_hugr(self, ctx: ToHugrContext) -> ht.TypeArg:
88
103
  """Computes the Hugr representation of this argument."""
89
104
  from guppylang_internals.tys.ty import NumericType
@@ -46,6 +46,27 @@ class CallableTypeDef(TypeDef, CompiledDef):
46
46
  raise InternalGuppyError("Tried to `Callable` type via `check_instantiate`")
47
47
 
48
48
 
49
+ @dataclass(frozen=True)
50
+ class SelfTypeDef(TypeDef, CompiledDef):
51
+ """Type definition associated with the `Self` type on methods.
52
+
53
+ During type parsing, we make sure that this type is replaced with the concrete type
54
+ the method is attached to. Thus, we should never have instances of this type around.
55
+
56
+ In other words, this definition is only a marker so that type parsing doesn't have
57
+ to rely on matching against the string "Self". By making `Self` a definition, we can
58
+ use the existing identifier tracking system and also handle users shadowing the
59
+ `Self` binder or assigning `Self` to some other name.
60
+ """
61
+
62
+ name: Literal["Self"] = field(default="Self", init=False)
63
+
64
+ def check_instantiate(
65
+ self, args: Sequence[Argument], loc: AstNode | None = None
66
+ ) -> FunctionType:
67
+ raise InternalGuppyError("Tried to instantiate abstract `Self` type`")
68
+
69
+
49
70
  @dataclass(frozen=True)
50
71
  class _TupleTypeDef(TypeDef, CompiledDef):
51
72
  """Type definition associated with the builtin `tuple` type.
@@ -106,7 +127,6 @@ class _NumericTypeDef(TypeDef, CompiledDef):
106
127
 
107
128
  class WasmModuleTypeDef(OpaqueTypeDef):
108
129
  wasm_file: str
109
- wasm_hash: int
110
130
 
111
131
  def __init__(
112
132
  self,
@@ -114,11 +134,9 @@ class WasmModuleTypeDef(OpaqueTypeDef):
114
134
  name: str,
115
135
  defined_at: ast.AST | None,
116
136
  wasm_file: str,
117
- wasm_hash: int,
118
137
  ) -> None:
119
138
  super().__init__(id, name, defined_at, [], True, True, self.to_hugr)
120
139
  self.wasm_file = wasm_file
121
- self.wasm_hash = wasm_hash
122
140
 
123
141
  def to_hugr(
124
142
  self, args: Sequence[TypeArg | ConstArg], ctx: ToHugrContext
@@ -159,13 +177,10 @@ def _array_to_hugr(args: Sequence[Argument], ctx: ToHugrContext) -> ht.Type:
159
177
  assert isinstance(ty_arg, TypeArg)
160
178
  assert isinstance(len_arg, ConstArg)
161
179
 
162
- # Linear elements are turned into an optional to enable unsafe indexing.
163
- # See `ArrayGetitemCompiler` for details.
164
- # Same also for classical arrays, see https://github.com/CQCL/guppylang/issues/629
165
- elem_ty = ht.Option(ty_arg.ty.to_hugr(ctx))
180
+ elem_ty = ty_arg.ty.to_hugr(ctx)
166
181
  hugr_arg = len_arg.to_hugr(ctx)
167
182
 
168
- return hugr.std.collections.value_array.ValueArray(elem_ty, hugr_arg)
183
+ return hugr.std.collections.borrow_array.BorrowArray(elem_ty, hugr_arg)
169
184
 
170
185
 
171
186
  def _frozenarray_to_hugr(args: Sequence[Argument], ctx: ToHugrContext) -> ht.Type:
@@ -189,9 +204,10 @@ def _option_to_hugr(args: Sequence[Argument], ctx: ToHugrContext) -> ht.Type:
189
204
  return ht.Option(arg.ty.to_hugr(ctx))
190
205
 
191
206
 
192
- callable_type_def = CallableTypeDef(DefId.fresh(), None)
193
- tuple_type_def = _TupleTypeDef(DefId.fresh(), None)
194
- none_type_def = _NoneTypeDef(DefId.fresh(), None)
207
+ callable_type_def = CallableTypeDef(DefId.fresh(), None, None)
208
+ self_type_def = SelfTypeDef(DefId.fresh(), None, [])
209
+ tuple_type_def = _TupleTypeDef(DefId.fresh(), None, None)
210
+ none_type_def = _NoneTypeDef(DefId.fresh(), None, [])
195
211
  bool_type_def = OpaqueTypeDef(
196
212
  id=DefId.fresh(),
197
213
  name="bool",
@@ -202,13 +218,13 @@ bool_type_def = OpaqueTypeDef(
202
218
  to_hugr=lambda args, ctx: OpaqueBool,
203
219
  )
204
220
  nat_type_def = _NumericTypeDef(
205
- DefId.fresh(), "nat", None, NumericType(NumericType.Kind.Nat)
221
+ DefId.fresh(), "nat", None, [], NumericType(NumericType.Kind.Nat)
206
222
  )
207
223
  int_type_def = _NumericTypeDef(
208
- DefId.fresh(), "int", None, NumericType(NumericType.Kind.Int)
224
+ DefId.fresh(), "int", None, [], NumericType(NumericType.Kind.Int)
209
225
  )
210
226
  float_type_def = _NumericTypeDef(
211
- DefId.fresh(), "float", None, NumericType(NumericType.Kind.Float)
227
+ DefId.fresh(), "float", None, [], NumericType(NumericType.Kind.Float)
212
228
  )
213
229
  string_type_def = OpaqueTypeDef(
214
230
  id=DefId.fresh(),
@@ -345,9 +361,9 @@ def is_sized_iter_type(ty: Type) -> TypeGuard[OpaqueType]:
345
361
  return isinstance(ty, OpaqueType) and ty.defn == sized_iter_type_def
346
362
 
347
363
 
348
- def wasm_module_info(ty: Type) -> tuple[str, int] | None:
364
+ def wasm_module_name(ty: Type) -> str | None:
349
365
  if isinstance(ty, OpaqueType) and isinstance(ty.defn, WasmModuleTypeDef):
350
- return ty.defn.wasm_file, ty.defn.wasm_hash
366
+ return ty.defn.wasm_file
351
367
  return None
352
368
 
353
369
 
@@ -8,6 +8,7 @@ from guppylang_internals.tys.var import BoundVar, ExistentialVar
8
8
 
9
9
  if TYPE_CHECKING:
10
10
  from guppylang_internals.tys.arg import ConstArg
11
+ from guppylang_internals.tys.subst import Subst
11
12
  from guppylang_internals.tys.ty import Type
12
13
 
13
14
 
@@ -39,6 +40,11 @@ class ConstBase(Transformable["Const"], ABC):
39
40
  """The existential type variables contained in this constant."""
40
41
  return set()
41
42
 
43
+ @property
44
+ def bound_vars(self) -> set[BoundVar]:
45
+ """The bound type variables contained in this constant."""
46
+ return self.ty.bound_vars
47
+
42
48
  def __str__(self) -> str:
43
49
  from guppylang_internals.tys.printing import TypePrinter
44
50
 
@@ -48,16 +54,18 @@ class ConstBase(Transformable["Const"], ABC):
48
54
  """Accepts a visitor on this constant."""
49
55
  visitor.visit(self)
50
56
 
51
- def transform(self, transformer: Transformer, /) -> "Const":
52
- """Accepts a transformer on this constant."""
53
- return transformer.transform(self) or self.cast()
54
-
55
57
  def to_arg(self) -> "ConstArg":
56
58
  """Wraps this constant into a type argument."""
57
59
  from guppylang_internals.tys.arg import ConstArg
58
60
 
59
61
  return ConstArg(self.cast())
60
62
 
63
+ def substitute(self, subst: "Subst") -> "Const":
64
+ """Substitutes existential variables in this constant."""
65
+ from guppylang_internals.tys.subst import Substituter
66
+
67
+ return self.transform(Substituter(subst))
68
+
61
69
 
62
70
  @dataclass(frozen=True)
63
71
  class ConstValue(ConstBase):
@@ -74,6 +82,10 @@ class ConstValue(ConstBase):
74
82
  """Casts an implementor of `ConstBase` into a `Const`."""
75
83
  return self
76
84
 
85
+ def transform(self, transformer: Transformer, /) -> "Const":
86
+ """Accepts a transformer on this constant."""
87
+ return transformer.transform(self) or self
88
+
77
89
 
78
90
  @dataclass(frozen=True)
79
91
  class BoundConstVar(BoundVar, ConstBase):
@@ -84,10 +96,21 @@ class BoundConstVar(BoundVar, ConstBase):
84
96
  `BoundConstVar(idx=0)`.
85
97
  """
86
98
 
99
+ @property
100
+ def bound_vars(self) -> set[BoundVar]:
101
+ """The bound type variables contained in this constant."""
102
+ return {self} | self.ty.bound_vars
103
+
87
104
  def cast(self) -> "Const":
88
105
  """Casts an implementor of `ConstBase` into a `Const`."""
89
106
  return self
90
107
 
108
+ def transform(self, transformer: Transformer, /) -> "Const":
109
+ """Accepts a transformer on this constant."""
110
+ return transformer.transform(self) or BoundConstVar(
111
+ transformer.transform(self.ty) or self.ty, self.display_name, self.idx
112
+ )
113
+
91
114
 
92
115
  @dataclass(frozen=True)
93
116
  class ExistentialConstVar(ExistentialVar, ConstBase):
@@ -110,5 +133,11 @@ class ExistentialConstVar(ExistentialVar, ConstBase):
110
133
  """Casts an implementor of `ConstBase` into a `Const`."""
111
134
  return self
112
135
 
136
+ def transform(self, transformer: Transformer, /) -> "Const":
137
+ """Accepts a transformer on this constant."""
138
+ return transformer.transform(self) or ExistentialConstVar(
139
+ transformer.transform(self.ty) or self.ty, self.display_name, self.id
140
+ )
141
+
113
142
 
114
143
  Const: TypeAlias = ConstValue | BoundConstVar | ExistentialConstVar
@@ -116,6 +116,12 @@ class InvalidCallableTypeError(Error):
116
116
  self.add_sub_diagnostic(InvalidCallableTypeError.Explain(None))
117
117
 
118
118
 
119
+ @dataclass(frozen=True)
120
+ class SelfTyNotInMethodError(Error):
121
+ title: ClassVar[str] = "Invalid type"
122
+ span_label: ClassVar[str] = "`Self` type annotations are only allowed in methods"
123
+
124
+
119
125
  @dataclass(frozen=True)
120
126
  class NonLinearOwnedError(Error):
121
127
  title: ClassVar[str] = "Invalid annotation"
@@ -1,6 +1,6 @@
1
1
  from abc import ABC, abstractmethod
2
2
  from collections.abc import Sequence
3
- from dataclasses import dataclass, field
3
+ from dataclasses import dataclass, field, replace
4
4
  from typing import TYPE_CHECKING, TypeAlias
5
5
 
6
6
  from hugr import tys as ht
@@ -17,9 +17,9 @@ from guppylang_internals.tys.errors import WrongNumberOfTypeArgsError
17
17
  from guppylang_internals.tys.var import ExistentialVar
18
18
 
19
19
  if TYPE_CHECKING:
20
+ from guppylang_internals.tys.subst import PartialInst
20
21
  from guppylang_internals.tys.ty import Type
21
22
 
22
-
23
23
  # We define the `Parameter` type as a union of all `ParameterBase` subclasses defined
24
24
  # below. This models an algebraic data type and enables exhaustiveness checking in
25
25
  # pattern matches etc.
@@ -74,6 +74,10 @@ class ParameterBase(ToHugr[ht.TypeParam], ABC):
74
74
  parameter.
75
75
  """
76
76
 
77
+ @abstractmethod
78
+ def instantiate_bounds(self, inst: "PartialInst") -> Self:
79
+ """Instantiates bound variables mentioned in parameter bounds"""
80
+
77
81
 
78
82
  @dataclass(frozen=True)
79
83
  class TypeParam(ParameterBase):
@@ -142,10 +146,17 @@ class TypeParam(ParameterBase):
142
146
  BoundTypeVar(self.name, idx, self.must_be_copyable, self.must_be_droppable)
143
147
  )
144
148
 
149
+ def instantiate_bounds(self, inst: "PartialInst") -> "TypeParam":
150
+ """Instantiates bound variables mentioned in parameter bounds"""
151
+ # For now, type parameters don't have any bounds that could be instantiated
152
+ return self
153
+
145
154
  def to_hugr(self, ctx: ToHugrContext) -> ht.TypeParam:
146
155
  """Computes the Hugr representation of the parameter."""
147
156
  return ht.TypeTypeParam(
148
- bound=ht.TypeBound.Linear if self.can_be_linear else ht.TypeBound.Copyable
157
+ bound=ht.TypeBound.Copyable
158
+ if self.must_be_copyable
159
+ else ht.TypeBound.Linear
149
160
  )
150
161
 
151
162
  def __str__(self) -> str:
@@ -157,18 +168,12 @@ class TypeParam(ParameterBase):
157
168
  class ConstParam(ParameterBase):
158
169
  """A parameter of kind constant. Used to define fixed-size arrays etc."""
159
170
 
160
- ty: "Type"
171
+ ty: "Type" = field(hash=False)
161
172
 
162
173
  #: Marker to annotate if this parameter was implicitly generated by a `@comptime`
163
174
  #: annotated argument in a function signature.
164
175
  from_comptime_arg: bool = field(default=False, kw_only=True)
165
176
 
166
- def __post_init__(self) -> None:
167
- if self.ty.unsolved_vars:
168
- raise InternalGuppyError(
169
- "Attempted to create constant param with unsolved type"
170
- )
171
-
172
177
  def with_idx(self, idx: int) -> "ConstParam":
173
178
  """Returns a copy of the parameter with a new index."""
174
179
  return ConstParam(idx, self.name, self.ty)
@@ -178,13 +183,16 @@ class ConstParam(ParameterBase):
178
183
 
179
184
  Raises a user error if the argument is not valid.
180
185
  """
186
+ from guppylang_internals.tys.ty import unify
187
+
181
188
  match arg:
182
189
  case ConstArg(const):
183
- if const.ty != self.ty:
190
+ subst = unify(const.ty, self.ty, {})
191
+ if subst is None:
184
192
  raise GuppyTypeError(
185
193
  TypeMismatchError(loc, self.ty, const.ty, kind="argument")
186
194
  )
187
- return arg
195
+ return ConstArg(replace(const, ty=const.ty.substitute(subst)))
188
196
  case TypeArg(ty=ty):
189
197
  err = ExpectedError(
190
198
  loc, f"expression of type `{self.ty}`", got=f"type `{ty}`"
@@ -208,6 +216,13 @@ class ConstParam(ParameterBase):
208
216
  idx = self.idx
209
217
  return ConstArg(BoundConstVar(self.ty, self.name, idx))
210
218
 
219
+ def instantiate_bounds(self, inst: "PartialInst") -> "ConstParam":
220
+ """Instantiates bound variables mentioned in parameter bounds"""
221
+ from guppylang_internals.tys.subst import Instantiator
222
+
223
+ instantiator = Instantiator(inst)
224
+ return replace(self, ty=self.ty.transform(instantiator))
225
+
211
226
  def to_hugr(self, ctx: ToHugrContext) -> ht.TypeParam:
212
227
  """Computes the Hugr representation of the parameter."""
213
228
  from guppylang_internals.tys.ty import NumericType
@@ -231,6 +246,7 @@ def check_all_args(
231
246
  args: Sequence[Argument],
232
247
  type_name: str,
233
248
  loc: AstNode | None = None,
249
+ arg_locs: Sequence[AstNode] | None = None,
234
250
  ) -> None:
235
251
  """Checks a list of arguments against the given parameters.
236
252
 
@@ -245,7 +261,6 @@ def check_all_args(
245
261
  raise GuppyError(WrongNumberOfTypeArgsError(loc, exp, act, type_name))
246
262
 
247
263
  # Now check that the kinds match up
248
- for param, arg in zip(params, args, strict=True):
249
- # TODO: The error location is bad. We want the location of `arg`, not of the
250
- # whole thing.
251
- param.check_arg(arg, loc)
264
+ for i, (param, arg) in enumerate(zip(params, args, strict=True)):
265
+ arg_loc = arg_locs[i] if arg_locs else loc
266
+ param.instantiate_bounds(args).check_arg(arg, arg_loc)