guppylang-internals 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +21 -0
- guppylang_internals/cfg/bb.py +20 -0
- guppylang_internals/cfg/builder.py +101 -3
- guppylang_internals/checker/core.py +12 -0
- guppylang_internals/checker/errors/generic.py +32 -1
- guppylang_internals/checker/errors/type_errors.py +14 -0
- guppylang_internals/checker/expr_checker.py +55 -29
- guppylang_internals/checker/func_checker.py +171 -22
- guppylang_internals/checker/linearity_checker.py +65 -0
- guppylang_internals/checker/modifier_checker.py +116 -0
- guppylang_internals/checker/stmt_checker.py +49 -2
- guppylang_internals/compiler/core.py +90 -53
- guppylang_internals/compiler/expr_compiler.py +49 -114
- guppylang_internals/compiler/modifier_compiler.py +174 -0
- guppylang_internals/compiler/stmt_compiler.py +15 -8
- guppylang_internals/decorator.py +124 -58
- guppylang_internals/definition/const.py +2 -2
- guppylang_internals/definition/custom.py +36 -2
- guppylang_internals/definition/declaration.py +4 -5
- guppylang_internals/definition/extern.py +2 -2
- guppylang_internals/definition/function.py +1 -1
- guppylang_internals/definition/parameter.py +10 -5
- guppylang_internals/definition/pytket_circuits.py +14 -42
- guppylang_internals/definition/struct.py +17 -14
- guppylang_internals/definition/traced.py +1 -1
- guppylang_internals/definition/ty.py +9 -3
- guppylang_internals/definition/wasm.py +2 -2
- guppylang_internals/engine.py +13 -2
- guppylang_internals/experimental.py +5 -0
- guppylang_internals/nodes.py +124 -23
- guppylang_internals/std/_internal/compiler/array.py +94 -282
- guppylang_internals/std/_internal/compiler/tket_exts.py +12 -8
- guppylang_internals/std/_internal/compiler/wasm.py +37 -26
- guppylang_internals/tracing/function.py +13 -2
- guppylang_internals/tracing/unpacking.py +33 -28
- guppylang_internals/tys/arg.py +18 -3
- guppylang_internals/tys/builtin.py +32 -16
- guppylang_internals/tys/const.py +33 -4
- guppylang_internals/tys/errors.py +6 -0
- guppylang_internals/tys/param.py +31 -16
- guppylang_internals/tys/parsing.py +118 -145
- guppylang_internals/tys/qubit.py +27 -0
- guppylang_internals/tys/subst.py +8 -26
- guppylang_internals/tys/ty.py +31 -21
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/METADATA +4 -4
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/RECORD +49 -46
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/WHEEL +0 -0
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -176,10 +176,21 @@ def trace_call(func: CallableDef, *args: Any) -> Any:
|
|
|
176
176
|
if len(func.ty.inputs) != 0:
|
|
177
177
|
for inp, arg, var in zip(func.ty.inputs, args, arg_vars, strict=True):
|
|
178
178
|
if InputFlags.Inout in inp.flags:
|
|
179
|
+
# Note that `inp.ty` could refer to bound variables in the function
|
|
180
|
+
# signature. Instead, make sure to use `var.ty` which will always be a
|
|
181
|
+
# concrete type and type checking has ensured that they unify.
|
|
182
|
+
ty = var.ty
|
|
179
183
|
inout_wire = state.dfg[var]
|
|
180
|
-
update_packed_value(
|
|
181
|
-
arg, GuppyObject(
|
|
184
|
+
success = update_packed_value(
|
|
185
|
+
arg, GuppyObject(ty, inout_wire), state.dfg.builder
|
|
182
186
|
)
|
|
187
|
+
if not success:
|
|
188
|
+
# This means the user has passed an object that we cannot update,
|
|
189
|
+
# e.g. calling `mem_swap(x, y)` where the inputs are plain Python
|
|
190
|
+
# objects
|
|
191
|
+
raise GuppyComptimeError(
|
|
192
|
+
f"Cannot borrow Python object of type `{ty}` at comptime"
|
|
193
|
+
)
|
|
183
194
|
|
|
184
195
|
ret_obj = GuppyObject(ret_ty, ret_wire)
|
|
185
196
|
return unpack_guppy_object(ret_obj, state.dfg.builder)
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from typing import Any, TypeVar
|
|
2
2
|
|
|
3
3
|
from hugr import ops
|
|
4
|
-
from hugr import tys as ht
|
|
5
4
|
from hugr.build.dfg import DfBase
|
|
6
5
|
|
|
7
6
|
from guppylang_internals.ast_util import AstNode
|
|
@@ -13,7 +12,6 @@ from guppylang_internals.compiler.core import CompilerContext
|
|
|
13
12
|
from guppylang_internals.compiler.expr_compiler import python_value_to_hugr
|
|
14
13
|
from guppylang_internals.error import GuppyComptimeError, GuppyError
|
|
15
14
|
from guppylang_internals.std._internal.compiler.array import array_new, unpack_array
|
|
16
|
-
from guppylang_internals.std._internal.compiler.prelude import build_unwrap
|
|
17
15
|
from guppylang_internals.tracing.frozenlist import frozenlist
|
|
18
16
|
from guppylang_internals.tracing.object import (
|
|
19
17
|
GuppyObject,
|
|
@@ -71,9 +69,7 @@ def unpack_guppy_object(
|
|
|
71
69
|
# them as Guppy objects here
|
|
72
70
|
return obj
|
|
73
71
|
elem_ty = get_element_type(ty)
|
|
74
|
-
|
|
75
|
-
err = "Non-copyable array element has already been used"
|
|
76
|
-
elems = [build_unwrap(builder, opt_elem, err) for opt_elem in opt_elems]
|
|
72
|
+
elems = unpack_array(builder, obj._use_wire(None))
|
|
77
73
|
obj_list = [
|
|
78
74
|
unpack_guppy_object(GuppyObject(elem_ty, wire), builder, frozen)
|
|
79
75
|
for wire in elems
|
|
@@ -128,11 +124,8 @@ def guppy_object_from_py(
|
|
|
128
124
|
f"Element at index {i + 1} does not match the type of "
|
|
129
125
|
f"previous elements. Expected `{elem_ty}`, got `{obj._ty}`."
|
|
130
126
|
)
|
|
131
|
-
hugr_elem_ty =
|
|
132
|
-
wires = [
|
|
133
|
-
builder.add_op(ops.Tag(1, hugr_elem_ty), obj._use_wire(None))
|
|
134
|
-
for obj in objs
|
|
135
|
-
]
|
|
127
|
+
hugr_elem_ty = elem_ty.to_hugr(ctx)
|
|
128
|
+
wires = [obj._use_wire(None) for obj in objs]
|
|
136
129
|
return GuppyObject(
|
|
137
130
|
array_type(elem_ty, len(vs)),
|
|
138
131
|
builder.add_op(array_new(hugr_elem_ty, len(vs)), *wires),
|
|
@@ -150,13 +143,15 @@ def guppy_object_from_py(
|
|
|
150
143
|
return GuppyObject(ty, builder.load(hugr_val))
|
|
151
144
|
|
|
152
145
|
|
|
153
|
-
def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> None:
|
|
146
|
+
def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> bool:
|
|
154
147
|
"""Given a Python value `v` and a `GuppyObject` `obj` that was constructed from `v`
|
|
155
|
-
using `guppy_object_from_py`,
|
|
156
|
-
`v` to the new wires specified by `obj`.
|
|
148
|
+
using `guppy_object_from_py`, tries to update the wires of any `GuppyObjects`
|
|
149
|
+
contained in `v` to the new wires specified by `obj`.
|
|
157
150
|
|
|
158
151
|
Also resets the used flag on any of those updated wires. This corresponds to making
|
|
159
152
|
the object available again since it now corresponds to a fresh wire.
|
|
153
|
+
|
|
154
|
+
Returns `True` if all wires could be updated, otherwise `False`.
|
|
160
155
|
"""
|
|
161
156
|
match v:
|
|
162
157
|
case GuppyObject() as v_obj:
|
|
@@ -170,25 +165,35 @@ def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> None:
|
|
|
170
165
|
assert isinstance(obj._ty, NoneType)
|
|
171
166
|
case tuple(vs):
|
|
172
167
|
assert isinstance(obj._ty, TupleType)
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
168
|
+
wire_iterator = builder.add_op(
|
|
169
|
+
ops.UnpackTuple(), obj._use_wire(None)
|
|
170
|
+
).outputs()
|
|
171
|
+
for v, ty, out_wire in zip(
|
|
172
|
+
vs, obj._ty.element_types, wire_iterator, strict=True
|
|
173
|
+
):
|
|
174
|
+
success = update_packed_value(v, GuppyObject(ty, out_wire), builder)
|
|
175
|
+
if not success:
|
|
176
|
+
return False
|
|
176
177
|
case GuppyStructObject(_ty=ty, _field_values=values):
|
|
177
178
|
assert obj._ty == ty
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
) in zip(ty.fields, wires, strict=True):
|
|
179
|
+
wire_iterator = builder.add_op(
|
|
180
|
+
ops.UnpackTuple(), obj._use_wire(None)
|
|
181
|
+
).outputs()
|
|
182
|
+
for field, out_wire in zip(ty.fields, wire_iterator, strict=True):
|
|
183
183
|
v = values[field.name]
|
|
184
|
-
update_packed_value(
|
|
184
|
+
success = update_packed_value(
|
|
185
|
+
v, GuppyObject(field.ty, out_wire), builder
|
|
186
|
+
)
|
|
187
|
+
if not success:
|
|
188
|
+
values[field.name] = obj
|
|
185
189
|
case list(vs) if len(vs) > 0:
|
|
186
190
|
assert is_array_type(obj._ty)
|
|
187
191
|
elem_ty = get_element_type(obj._ty)
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
192
|
+
wires = unpack_array(builder, obj._use_wire(None))
|
|
193
|
+
for i, (v, wire) in enumerate(zip(vs, wires, strict=True)):
|
|
194
|
+
success = update_packed_value(v, GuppyObject(elem_ty, wire), builder)
|
|
195
|
+
if not success:
|
|
196
|
+
vs[i] = obj
|
|
193
197
|
case _:
|
|
194
|
-
|
|
198
|
+
return False
|
|
199
|
+
return True
|
guppylang_internals/tys/arg.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
-
from dataclasses import dataclass
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
3
|
from typing import TYPE_CHECKING, TypeAlias
|
|
4
4
|
|
|
5
5
|
from hugr import tys as ht
|
|
@@ -18,7 +18,7 @@ from guppylang_internals.tys.const import (
|
|
|
18
18
|
ConstValue,
|
|
19
19
|
ExistentialConstVar,
|
|
20
20
|
)
|
|
21
|
-
from guppylang_internals.tys.var import ExistentialVar
|
|
21
|
+
from guppylang_internals.tys.var import BoundVar, ExistentialVar
|
|
22
22
|
|
|
23
23
|
if TYPE_CHECKING:
|
|
24
24
|
from guppylang_internals.tys.ty import Type
|
|
@@ -45,19 +45,29 @@ class ArgumentBase(ToHugr[ht.TypeArg], Transformable["Argument"], ABC):
|
|
|
45
45
|
def unsolved_vars(self) -> set[ExistentialVar]:
|
|
46
46
|
"""The existential type variables contained in this argument."""
|
|
47
47
|
|
|
48
|
+
@property
|
|
49
|
+
@abstractmethod
|
|
50
|
+
def bound_vars(self) -> set[BoundVar]:
|
|
51
|
+
"""The bound type variables contained in this argument."""
|
|
52
|
+
|
|
48
53
|
|
|
49
54
|
@dataclass(frozen=True)
|
|
50
55
|
class TypeArg(ArgumentBase):
|
|
51
56
|
"""Argument that can be instantiated for a `TypeParameter`."""
|
|
52
57
|
|
|
53
58
|
# The type to instantiate
|
|
54
|
-
ty: "Type"
|
|
59
|
+
ty: "Type" = field(hash=False) # Types are not hashable
|
|
55
60
|
|
|
56
61
|
@property
|
|
57
62
|
def unsolved_vars(self) -> set[ExistentialVar]:
|
|
58
63
|
"""The existential type variables contained in this argument."""
|
|
59
64
|
return self.ty.unsolved_vars
|
|
60
65
|
|
|
66
|
+
@property
|
|
67
|
+
def bound_vars(self) -> set[BoundVar]:
|
|
68
|
+
"""The bound type variables contained in this type."""
|
|
69
|
+
return self.ty.bound_vars
|
|
70
|
+
|
|
61
71
|
def to_hugr(self, ctx: ToHugrContext) -> ht.TypeTypeArg:
|
|
62
72
|
"""Computes the Hugr representation of the argument."""
|
|
63
73
|
ty: ht.Type = self.ty.to_hugr(ctx)
|
|
@@ -84,6 +94,11 @@ class ConstArg(ArgumentBase):
|
|
|
84
94
|
"""The existential const variables contained in this argument."""
|
|
85
95
|
return self.const.unsolved_vars
|
|
86
96
|
|
|
97
|
+
@property
|
|
98
|
+
def bound_vars(self) -> set[BoundVar]:
|
|
99
|
+
"""The bound type variables contained in this argument."""
|
|
100
|
+
return self.const.bound_vars
|
|
101
|
+
|
|
87
102
|
def to_hugr(self, ctx: ToHugrContext) -> ht.TypeArg:
|
|
88
103
|
"""Computes the Hugr representation of this argument."""
|
|
89
104
|
from guppylang_internals.tys.ty import NumericType
|
|
@@ -46,6 +46,27 @@ class CallableTypeDef(TypeDef, CompiledDef):
|
|
|
46
46
|
raise InternalGuppyError("Tried to `Callable` type via `check_instantiate`")
|
|
47
47
|
|
|
48
48
|
|
|
49
|
+
@dataclass(frozen=True)
|
|
50
|
+
class SelfTypeDef(TypeDef, CompiledDef):
|
|
51
|
+
"""Type definition associated with the `Self` type on methods.
|
|
52
|
+
|
|
53
|
+
During type parsing, we make sure that this type is replaced with the concrete type
|
|
54
|
+
the method is attached to. Thus, we should never have instances of this type around.
|
|
55
|
+
|
|
56
|
+
In other words, this definition is only a marker so that type parsing doesn't have
|
|
57
|
+
to rely on matching against the string "Self". By making `Self` a definition, we can
|
|
58
|
+
use the existing identifier tracking system and also handle users shadowing the
|
|
59
|
+
`Self` binder or assigning `Self` to some other name.
|
|
60
|
+
"""
|
|
61
|
+
|
|
62
|
+
name: Literal["Self"] = field(default="Self", init=False)
|
|
63
|
+
|
|
64
|
+
def check_instantiate(
|
|
65
|
+
self, args: Sequence[Argument], loc: AstNode | None = None
|
|
66
|
+
) -> FunctionType:
|
|
67
|
+
raise InternalGuppyError("Tried to instantiate abstract `Self` type`")
|
|
68
|
+
|
|
69
|
+
|
|
49
70
|
@dataclass(frozen=True)
|
|
50
71
|
class _TupleTypeDef(TypeDef, CompiledDef):
|
|
51
72
|
"""Type definition associated with the builtin `tuple` type.
|
|
@@ -106,7 +127,6 @@ class _NumericTypeDef(TypeDef, CompiledDef):
|
|
|
106
127
|
|
|
107
128
|
class WasmModuleTypeDef(OpaqueTypeDef):
|
|
108
129
|
wasm_file: str
|
|
109
|
-
wasm_hash: int
|
|
110
130
|
|
|
111
131
|
def __init__(
|
|
112
132
|
self,
|
|
@@ -114,11 +134,9 @@ class WasmModuleTypeDef(OpaqueTypeDef):
|
|
|
114
134
|
name: str,
|
|
115
135
|
defined_at: ast.AST | None,
|
|
116
136
|
wasm_file: str,
|
|
117
|
-
wasm_hash: int,
|
|
118
137
|
) -> None:
|
|
119
138
|
super().__init__(id, name, defined_at, [], True, True, self.to_hugr)
|
|
120
139
|
self.wasm_file = wasm_file
|
|
121
|
-
self.wasm_hash = wasm_hash
|
|
122
140
|
|
|
123
141
|
def to_hugr(
|
|
124
142
|
self, args: Sequence[TypeArg | ConstArg], ctx: ToHugrContext
|
|
@@ -159,13 +177,10 @@ def _array_to_hugr(args: Sequence[Argument], ctx: ToHugrContext) -> ht.Type:
|
|
|
159
177
|
assert isinstance(ty_arg, TypeArg)
|
|
160
178
|
assert isinstance(len_arg, ConstArg)
|
|
161
179
|
|
|
162
|
-
|
|
163
|
-
# See `ArrayGetitemCompiler` for details.
|
|
164
|
-
# Same also for classical arrays, see https://github.com/CQCL/guppylang/issues/629
|
|
165
|
-
elem_ty = ht.Option(ty_arg.ty.to_hugr(ctx))
|
|
180
|
+
elem_ty = ty_arg.ty.to_hugr(ctx)
|
|
166
181
|
hugr_arg = len_arg.to_hugr(ctx)
|
|
167
182
|
|
|
168
|
-
return hugr.std.collections.
|
|
183
|
+
return hugr.std.collections.borrow_array.BorrowArray(elem_ty, hugr_arg)
|
|
169
184
|
|
|
170
185
|
|
|
171
186
|
def _frozenarray_to_hugr(args: Sequence[Argument], ctx: ToHugrContext) -> ht.Type:
|
|
@@ -189,9 +204,10 @@ def _option_to_hugr(args: Sequence[Argument], ctx: ToHugrContext) -> ht.Type:
|
|
|
189
204
|
return ht.Option(arg.ty.to_hugr(ctx))
|
|
190
205
|
|
|
191
206
|
|
|
192
|
-
callable_type_def = CallableTypeDef(DefId.fresh(), None)
|
|
193
|
-
|
|
194
|
-
|
|
207
|
+
callable_type_def = CallableTypeDef(DefId.fresh(), None, None)
|
|
208
|
+
self_type_def = SelfTypeDef(DefId.fresh(), None, [])
|
|
209
|
+
tuple_type_def = _TupleTypeDef(DefId.fresh(), None, None)
|
|
210
|
+
none_type_def = _NoneTypeDef(DefId.fresh(), None, [])
|
|
195
211
|
bool_type_def = OpaqueTypeDef(
|
|
196
212
|
id=DefId.fresh(),
|
|
197
213
|
name="bool",
|
|
@@ -202,13 +218,13 @@ bool_type_def = OpaqueTypeDef(
|
|
|
202
218
|
to_hugr=lambda args, ctx: OpaqueBool,
|
|
203
219
|
)
|
|
204
220
|
nat_type_def = _NumericTypeDef(
|
|
205
|
-
DefId.fresh(), "nat", None, NumericType(NumericType.Kind.Nat)
|
|
221
|
+
DefId.fresh(), "nat", None, [], NumericType(NumericType.Kind.Nat)
|
|
206
222
|
)
|
|
207
223
|
int_type_def = _NumericTypeDef(
|
|
208
|
-
DefId.fresh(), "int", None, NumericType(NumericType.Kind.Int)
|
|
224
|
+
DefId.fresh(), "int", None, [], NumericType(NumericType.Kind.Int)
|
|
209
225
|
)
|
|
210
226
|
float_type_def = _NumericTypeDef(
|
|
211
|
-
DefId.fresh(), "float", None, NumericType(NumericType.Kind.Float)
|
|
227
|
+
DefId.fresh(), "float", None, [], NumericType(NumericType.Kind.Float)
|
|
212
228
|
)
|
|
213
229
|
string_type_def = OpaqueTypeDef(
|
|
214
230
|
id=DefId.fresh(),
|
|
@@ -345,9 +361,9 @@ def is_sized_iter_type(ty: Type) -> TypeGuard[OpaqueType]:
|
|
|
345
361
|
return isinstance(ty, OpaqueType) and ty.defn == sized_iter_type_def
|
|
346
362
|
|
|
347
363
|
|
|
348
|
-
def
|
|
364
|
+
def wasm_module_name(ty: Type) -> str | None:
|
|
349
365
|
if isinstance(ty, OpaqueType) and isinstance(ty.defn, WasmModuleTypeDef):
|
|
350
|
-
return ty.defn.wasm_file
|
|
366
|
+
return ty.defn.wasm_file
|
|
351
367
|
return None
|
|
352
368
|
|
|
353
369
|
|
guppylang_internals/tys/const.py
CHANGED
|
@@ -8,6 +8,7 @@ from guppylang_internals.tys.var import BoundVar, ExistentialVar
|
|
|
8
8
|
|
|
9
9
|
if TYPE_CHECKING:
|
|
10
10
|
from guppylang_internals.tys.arg import ConstArg
|
|
11
|
+
from guppylang_internals.tys.subst import Subst
|
|
11
12
|
from guppylang_internals.tys.ty import Type
|
|
12
13
|
|
|
13
14
|
|
|
@@ -39,6 +40,11 @@ class ConstBase(Transformable["Const"], ABC):
|
|
|
39
40
|
"""The existential type variables contained in this constant."""
|
|
40
41
|
return set()
|
|
41
42
|
|
|
43
|
+
@property
|
|
44
|
+
def bound_vars(self) -> set[BoundVar]:
|
|
45
|
+
"""The bound type variables contained in this constant."""
|
|
46
|
+
return self.ty.bound_vars
|
|
47
|
+
|
|
42
48
|
def __str__(self) -> str:
|
|
43
49
|
from guppylang_internals.tys.printing import TypePrinter
|
|
44
50
|
|
|
@@ -48,16 +54,18 @@ class ConstBase(Transformable["Const"], ABC):
|
|
|
48
54
|
"""Accepts a visitor on this constant."""
|
|
49
55
|
visitor.visit(self)
|
|
50
56
|
|
|
51
|
-
def transform(self, transformer: Transformer, /) -> "Const":
|
|
52
|
-
"""Accepts a transformer on this constant."""
|
|
53
|
-
return transformer.transform(self) or self.cast()
|
|
54
|
-
|
|
55
57
|
def to_arg(self) -> "ConstArg":
|
|
56
58
|
"""Wraps this constant into a type argument."""
|
|
57
59
|
from guppylang_internals.tys.arg import ConstArg
|
|
58
60
|
|
|
59
61
|
return ConstArg(self.cast())
|
|
60
62
|
|
|
63
|
+
def substitute(self, subst: "Subst") -> "Const":
|
|
64
|
+
"""Substitutes existential variables in this constant."""
|
|
65
|
+
from guppylang_internals.tys.subst import Substituter
|
|
66
|
+
|
|
67
|
+
return self.transform(Substituter(subst))
|
|
68
|
+
|
|
61
69
|
|
|
62
70
|
@dataclass(frozen=True)
|
|
63
71
|
class ConstValue(ConstBase):
|
|
@@ -74,6 +82,10 @@ class ConstValue(ConstBase):
|
|
|
74
82
|
"""Casts an implementor of `ConstBase` into a `Const`."""
|
|
75
83
|
return self
|
|
76
84
|
|
|
85
|
+
def transform(self, transformer: Transformer, /) -> "Const":
|
|
86
|
+
"""Accepts a transformer on this constant."""
|
|
87
|
+
return transformer.transform(self) or self
|
|
88
|
+
|
|
77
89
|
|
|
78
90
|
@dataclass(frozen=True)
|
|
79
91
|
class BoundConstVar(BoundVar, ConstBase):
|
|
@@ -84,10 +96,21 @@ class BoundConstVar(BoundVar, ConstBase):
|
|
|
84
96
|
`BoundConstVar(idx=0)`.
|
|
85
97
|
"""
|
|
86
98
|
|
|
99
|
+
@property
|
|
100
|
+
def bound_vars(self) -> set[BoundVar]:
|
|
101
|
+
"""The bound type variables contained in this constant."""
|
|
102
|
+
return {self} | self.ty.bound_vars
|
|
103
|
+
|
|
87
104
|
def cast(self) -> "Const":
|
|
88
105
|
"""Casts an implementor of `ConstBase` into a `Const`."""
|
|
89
106
|
return self
|
|
90
107
|
|
|
108
|
+
def transform(self, transformer: Transformer, /) -> "Const":
|
|
109
|
+
"""Accepts a transformer on this constant."""
|
|
110
|
+
return transformer.transform(self) or BoundConstVar(
|
|
111
|
+
transformer.transform(self.ty) or self.ty, self.display_name, self.idx
|
|
112
|
+
)
|
|
113
|
+
|
|
91
114
|
|
|
92
115
|
@dataclass(frozen=True)
|
|
93
116
|
class ExistentialConstVar(ExistentialVar, ConstBase):
|
|
@@ -110,5 +133,11 @@ class ExistentialConstVar(ExistentialVar, ConstBase):
|
|
|
110
133
|
"""Casts an implementor of `ConstBase` into a `Const`."""
|
|
111
134
|
return self
|
|
112
135
|
|
|
136
|
+
def transform(self, transformer: Transformer, /) -> "Const":
|
|
137
|
+
"""Accepts a transformer on this constant."""
|
|
138
|
+
return transformer.transform(self) or ExistentialConstVar(
|
|
139
|
+
transformer.transform(self.ty) or self.ty, self.display_name, self.id
|
|
140
|
+
)
|
|
141
|
+
|
|
113
142
|
|
|
114
143
|
Const: TypeAlias = ConstValue | BoundConstVar | ExistentialConstVar
|
|
@@ -116,6 +116,12 @@ class InvalidCallableTypeError(Error):
|
|
|
116
116
|
self.add_sub_diagnostic(InvalidCallableTypeError.Explain(None))
|
|
117
117
|
|
|
118
118
|
|
|
119
|
+
@dataclass(frozen=True)
|
|
120
|
+
class SelfTyNotInMethodError(Error):
|
|
121
|
+
title: ClassVar[str] = "Invalid type"
|
|
122
|
+
span_label: ClassVar[str] = "`Self` type annotations are only allowed in methods"
|
|
123
|
+
|
|
124
|
+
|
|
119
125
|
@dataclass(frozen=True)
|
|
120
126
|
class NonLinearOwnedError(Error):
|
|
121
127
|
title: ClassVar[str] = "Invalid annotation"
|
guppylang_internals/tys/param.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
2
|
from collections.abc import Sequence
|
|
3
|
-
from dataclasses import dataclass, field
|
|
3
|
+
from dataclasses import dataclass, field, replace
|
|
4
4
|
from typing import TYPE_CHECKING, TypeAlias
|
|
5
5
|
|
|
6
6
|
from hugr import tys as ht
|
|
@@ -17,9 +17,9 @@ from guppylang_internals.tys.errors import WrongNumberOfTypeArgsError
|
|
|
17
17
|
from guppylang_internals.tys.var import ExistentialVar
|
|
18
18
|
|
|
19
19
|
if TYPE_CHECKING:
|
|
20
|
+
from guppylang_internals.tys.subst import PartialInst
|
|
20
21
|
from guppylang_internals.tys.ty import Type
|
|
21
22
|
|
|
22
|
-
|
|
23
23
|
# We define the `Parameter` type as a union of all `ParameterBase` subclasses defined
|
|
24
24
|
# below. This models an algebraic data type and enables exhaustiveness checking in
|
|
25
25
|
# pattern matches etc.
|
|
@@ -74,6 +74,10 @@ class ParameterBase(ToHugr[ht.TypeParam], ABC):
|
|
|
74
74
|
parameter.
|
|
75
75
|
"""
|
|
76
76
|
|
|
77
|
+
@abstractmethod
|
|
78
|
+
def instantiate_bounds(self, inst: "PartialInst") -> Self:
|
|
79
|
+
"""Instantiates bound variables mentioned in parameter bounds"""
|
|
80
|
+
|
|
77
81
|
|
|
78
82
|
@dataclass(frozen=True)
|
|
79
83
|
class TypeParam(ParameterBase):
|
|
@@ -142,10 +146,17 @@ class TypeParam(ParameterBase):
|
|
|
142
146
|
BoundTypeVar(self.name, idx, self.must_be_copyable, self.must_be_droppable)
|
|
143
147
|
)
|
|
144
148
|
|
|
149
|
+
def instantiate_bounds(self, inst: "PartialInst") -> "TypeParam":
|
|
150
|
+
"""Instantiates bound variables mentioned in parameter bounds"""
|
|
151
|
+
# For now, type parameters don't have any bounds that could be instantiated
|
|
152
|
+
return self
|
|
153
|
+
|
|
145
154
|
def to_hugr(self, ctx: ToHugrContext) -> ht.TypeParam:
|
|
146
155
|
"""Computes the Hugr representation of the parameter."""
|
|
147
156
|
return ht.TypeTypeParam(
|
|
148
|
-
bound=ht.TypeBound.
|
|
157
|
+
bound=ht.TypeBound.Copyable
|
|
158
|
+
if self.must_be_copyable
|
|
159
|
+
else ht.TypeBound.Linear
|
|
149
160
|
)
|
|
150
161
|
|
|
151
162
|
def __str__(self) -> str:
|
|
@@ -157,18 +168,12 @@ class TypeParam(ParameterBase):
|
|
|
157
168
|
class ConstParam(ParameterBase):
|
|
158
169
|
"""A parameter of kind constant. Used to define fixed-size arrays etc."""
|
|
159
170
|
|
|
160
|
-
ty: "Type"
|
|
171
|
+
ty: "Type" = field(hash=False)
|
|
161
172
|
|
|
162
173
|
#: Marker to annotate if this parameter was implicitly generated by a `@comptime`
|
|
163
174
|
#: annotated argument in a function signature.
|
|
164
175
|
from_comptime_arg: bool = field(default=False, kw_only=True)
|
|
165
176
|
|
|
166
|
-
def __post_init__(self) -> None:
|
|
167
|
-
if self.ty.unsolved_vars:
|
|
168
|
-
raise InternalGuppyError(
|
|
169
|
-
"Attempted to create constant param with unsolved type"
|
|
170
|
-
)
|
|
171
|
-
|
|
172
177
|
def with_idx(self, idx: int) -> "ConstParam":
|
|
173
178
|
"""Returns a copy of the parameter with a new index."""
|
|
174
179
|
return ConstParam(idx, self.name, self.ty)
|
|
@@ -178,13 +183,16 @@ class ConstParam(ParameterBase):
|
|
|
178
183
|
|
|
179
184
|
Raises a user error if the argument is not valid.
|
|
180
185
|
"""
|
|
186
|
+
from guppylang_internals.tys.ty import unify
|
|
187
|
+
|
|
181
188
|
match arg:
|
|
182
189
|
case ConstArg(const):
|
|
183
|
-
|
|
190
|
+
subst = unify(const.ty, self.ty, {})
|
|
191
|
+
if subst is None:
|
|
184
192
|
raise GuppyTypeError(
|
|
185
193
|
TypeMismatchError(loc, self.ty, const.ty, kind="argument")
|
|
186
194
|
)
|
|
187
|
-
return
|
|
195
|
+
return ConstArg(replace(const, ty=const.ty.substitute(subst)))
|
|
188
196
|
case TypeArg(ty=ty):
|
|
189
197
|
err = ExpectedError(
|
|
190
198
|
loc, f"expression of type `{self.ty}`", got=f"type `{ty}`"
|
|
@@ -208,6 +216,13 @@ class ConstParam(ParameterBase):
|
|
|
208
216
|
idx = self.idx
|
|
209
217
|
return ConstArg(BoundConstVar(self.ty, self.name, idx))
|
|
210
218
|
|
|
219
|
+
def instantiate_bounds(self, inst: "PartialInst") -> "ConstParam":
|
|
220
|
+
"""Instantiates bound variables mentioned in parameter bounds"""
|
|
221
|
+
from guppylang_internals.tys.subst import Instantiator
|
|
222
|
+
|
|
223
|
+
instantiator = Instantiator(inst)
|
|
224
|
+
return replace(self, ty=self.ty.transform(instantiator))
|
|
225
|
+
|
|
211
226
|
def to_hugr(self, ctx: ToHugrContext) -> ht.TypeParam:
|
|
212
227
|
"""Computes the Hugr representation of the parameter."""
|
|
213
228
|
from guppylang_internals.tys.ty import NumericType
|
|
@@ -231,6 +246,7 @@ def check_all_args(
|
|
|
231
246
|
args: Sequence[Argument],
|
|
232
247
|
type_name: str,
|
|
233
248
|
loc: AstNode | None = None,
|
|
249
|
+
arg_locs: Sequence[AstNode] | None = None,
|
|
234
250
|
) -> None:
|
|
235
251
|
"""Checks a list of arguments against the given parameters.
|
|
236
252
|
|
|
@@ -245,7 +261,6 @@ def check_all_args(
|
|
|
245
261
|
raise GuppyError(WrongNumberOfTypeArgsError(loc, exp, act, type_name))
|
|
246
262
|
|
|
247
263
|
# Now check that the kinds match up
|
|
248
|
-
for param, arg in zip(params, args, strict=True):
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
param.check_arg(arg, loc)
|
|
264
|
+
for i, (param, arg) in enumerate(zip(params, args, strict=True)):
|
|
265
|
+
arg_loc = arg_locs[i] if arg_locs else loc
|
|
266
|
+
param.instantiate_bounds(args).check_arg(arg, arg_loc)
|