guppylang-internals 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +3 -0
- guppylang_internals/ast_util.py +350 -0
- guppylang_internals/cfg/__init__.py +0 -0
- guppylang_internals/cfg/analysis.py +230 -0
- guppylang_internals/cfg/bb.py +221 -0
- guppylang_internals/cfg/builder.py +606 -0
- guppylang_internals/cfg/cfg.py +117 -0
- guppylang_internals/checker/__init__.py +0 -0
- guppylang_internals/checker/cfg_checker.py +388 -0
- guppylang_internals/checker/core.py +550 -0
- guppylang_internals/checker/errors/__init__.py +0 -0
- guppylang_internals/checker/errors/comptime_errors.py +106 -0
- guppylang_internals/checker/errors/generic.py +45 -0
- guppylang_internals/checker/errors/linearity.py +300 -0
- guppylang_internals/checker/errors/type_errors.py +344 -0
- guppylang_internals/checker/errors/wasm.py +34 -0
- guppylang_internals/checker/expr_checker.py +1413 -0
- guppylang_internals/checker/func_checker.py +269 -0
- guppylang_internals/checker/linearity_checker.py +821 -0
- guppylang_internals/checker/stmt_checker.py +447 -0
- guppylang_internals/compiler/__init__.py +0 -0
- guppylang_internals/compiler/cfg_compiler.py +233 -0
- guppylang_internals/compiler/core.py +613 -0
- guppylang_internals/compiler/expr_compiler.py +989 -0
- guppylang_internals/compiler/func_compiler.py +97 -0
- guppylang_internals/compiler/hugr_extension.py +224 -0
- guppylang_internals/compiler/qtm_platform_extension.py +0 -0
- guppylang_internals/compiler/stmt_compiler.py +212 -0
- guppylang_internals/decorator.py +246 -0
- guppylang_internals/definition/__init__.py +0 -0
- guppylang_internals/definition/common.py +214 -0
- guppylang_internals/definition/const.py +74 -0
- guppylang_internals/definition/custom.py +492 -0
- guppylang_internals/definition/declaration.py +171 -0
- guppylang_internals/definition/extern.py +89 -0
- guppylang_internals/definition/function.py +302 -0
- guppylang_internals/definition/overloaded.py +150 -0
- guppylang_internals/definition/parameter.py +82 -0
- guppylang_internals/definition/pytket_circuits.py +405 -0
- guppylang_internals/definition/struct.py +392 -0
- guppylang_internals/definition/traced.py +151 -0
- guppylang_internals/definition/ty.py +51 -0
- guppylang_internals/definition/value.py +115 -0
- guppylang_internals/definition/wasm.py +61 -0
- guppylang_internals/diagnostic.py +523 -0
- guppylang_internals/dummy_decorator.py +76 -0
- guppylang_internals/engine.py +295 -0
- guppylang_internals/error.py +107 -0
- guppylang_internals/experimental.py +92 -0
- guppylang_internals/ipython_inspect.py +28 -0
- guppylang_internals/nodes.py +427 -0
- guppylang_internals/py.typed +0 -0
- guppylang_internals/span.py +150 -0
- guppylang_internals/std/__init__.py +0 -0
- guppylang_internals/std/_internal/__init__.py +0 -0
- guppylang_internals/std/_internal/checker.py +573 -0
- guppylang_internals/std/_internal/compiler/__init__.py +0 -0
- guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
- guppylang_internals/std/_internal/compiler/array.py +569 -0
- guppylang_internals/std/_internal/compiler/either.py +131 -0
- guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
- guppylang_internals/std/_internal/compiler/futures.py +30 -0
- guppylang_internals/std/_internal/compiler/list.py +348 -0
- guppylang_internals/std/_internal/compiler/mem.py +13 -0
- guppylang_internals/std/_internal/compiler/option.py +78 -0
- guppylang_internals/std/_internal/compiler/prelude.py +271 -0
- guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
- guppylang_internals/std/_internal/compiler/quantum.py +118 -0
- guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
- guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
- guppylang_internals/std/_internal/compiler/wasm.py +135 -0
- guppylang_internals/std/_internal/compiler.py +0 -0
- guppylang_internals/std/_internal/debug.py +95 -0
- guppylang_internals/std/_internal/util.py +271 -0
- guppylang_internals/tracing/__init__.py +0 -0
- guppylang_internals/tracing/builtins_mock.py +62 -0
- guppylang_internals/tracing/frozenlist.py +57 -0
- guppylang_internals/tracing/function.py +186 -0
- guppylang_internals/tracing/object.py +551 -0
- guppylang_internals/tracing/state.py +69 -0
- guppylang_internals/tracing/unpacking.py +194 -0
- guppylang_internals/tracing/util.py +86 -0
- guppylang_internals/tys/__init__.py +0 -0
- guppylang_internals/tys/arg.py +115 -0
- guppylang_internals/tys/builtin.py +382 -0
- guppylang_internals/tys/common.py +110 -0
- guppylang_internals/tys/const.py +114 -0
- guppylang_internals/tys/errors.py +178 -0
- guppylang_internals/tys/param.py +251 -0
- guppylang_internals/tys/parsing.py +425 -0
- guppylang_internals/tys/printing.py +174 -0
- guppylang_internals/tys/subst.py +112 -0
- guppylang_internals/tys/ty.py +876 -0
- guppylang_internals/tys/var.py +49 -0
- guppylang_internals-0.21.0.dist-info/METADATA +253 -0
- guppylang_internals-0.21.0.dist-info/RECORD +98 -0
- guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
- guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
|
@@ -0,0 +1,876 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from enum import Enum, Flag, auto
|
|
5
|
+
from functools import cached_property, total_ordering
|
|
6
|
+
from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast
|
|
7
|
+
|
|
8
|
+
import hugr.std.float
|
|
9
|
+
import hugr.std.int
|
|
10
|
+
from hugr import tys as ht
|
|
11
|
+
|
|
12
|
+
from guppylang_internals.error import InternalGuppyError
|
|
13
|
+
from guppylang_internals.tys.arg import Argument, ConstArg, TypeArg
|
|
14
|
+
from guppylang_internals.tys.common import (
|
|
15
|
+
QuantifiedToHugrContext,
|
|
16
|
+
ToHugr,
|
|
17
|
+
ToHugrContext,
|
|
18
|
+
Transformable,
|
|
19
|
+
Transformer,
|
|
20
|
+
Visitor,
|
|
21
|
+
)
|
|
22
|
+
from guppylang_internals.tys.const import Const, ConstValue, ExistentialConstVar
|
|
23
|
+
from guppylang_internals.tys.param import ConstParam, Parameter
|
|
24
|
+
from guppylang_internals.tys.var import BoundVar, ExistentialVar
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from guppylang_internals.definition.struct import CheckedStructDef, StructField
|
|
28
|
+
from guppylang_internals.definition.ty import OpaqueTypeDef
|
|
29
|
+
from guppylang_internals.tys.subst import Inst, PartialInst, Subst
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass(frozen=True)
class TypeBase(ToHugr[ht.Type], Transformable["Type"], ABC):
    """Abstract root of the Guppy type hierarchy.

    Every subclass is expected to be immutable.
    """

    @cached_property
    @abstractmethod
    def copyable(self) -> bool:
        """Whether values of this type may be copied implicitly."""

    @cached_property
    @abstractmethod
    def droppable(self) -> bool:
        """Whether values of this type may be dropped."""

    @property
    def linear(self) -> bool:
        """Whether this type is linear, i.e. neither copyable nor droppable."""
        return not (self.copyable or self.droppable)

    @property
    def affine(self) -> bool:
        """Whether this type is affine, i.e. droppable but not copyable."""
        if self.copyable:
            return False
        return self.droppable

    @cached_property
    @abstractmethod
    def hugr_bound(self) -> ht.TypeBound:
        """The Hugr bound of this type.

        Subclasses have to state the bound explicitly: an opaque nonlinear type in a
        Hugr extension may be declared with different bounds, and the Hugr validator
        will complain if the bound is not exactly right during serialisation.
        """

    @abstractmethod
    def cast(self) -> "Type":
        """Returns this object viewed as a member of the `Type` union.

        Requiring this method enforces that every implementor of `TypeBase` can be
        embedded into the `Type` union type.
        """

    @cached_property
    def unsolved_vars(self) -> set[ExistentialVar]:
        """The existential variables occurring in this type (none by default)."""
        return set()

    def substitute(self, subst: "Subst") -> "Type":
        """Replaces existential variables in this type according to `subst`."""
        # Imported locally, mirroring the `TYPE_CHECKING`-only import of `subst`
        from guppylang_internals.tys.subst import Substituter

        substituter = Substituter(subst)
        return self.transform(substituter)

    def to_arg(self) -> TypeArg:
        """Wraps this type into a type argument."""
        as_type = self.cast()
        return TypeArg(as_type)

    def __str__(self) -> str:
        """Returns a human-readable rendering of the type."""
        from guppylang_internals.tys.printing import TypePrinter

        # A dedicated printer takes care of inserting parentheses and of choosing
        # unique names
        printer = TypePrinter()
        return printer.visit(cast(Type, self))
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@dataclass(frozen=True)
class ParametrizedTypeBase(TypeBase, ABC):
    """Abstract base for types that are built from a sequence of arguments.

    For example, `list`, `tuple`, etc. only become proper types once they are given
    arguments.

    Every subclass is expected to be immutable.
    """

    # The arguments this type has been applied to
    args: Sequence[Argument]

    def __post_init__(self) -> None:
        """Rejects arguments that are themselves parametrized function types."""
        # NOTE(review): dataclasses only call `__post_init__` from a generated
        # `__init__`; a subclass that defines its own `__init__` has to invoke it
        # explicitly if this check should apply there as well.
        for arg in self.args:
            if not isinstance(arg, TypeArg):
                continue
            nested = arg.ty
            if isinstance(nested, FunctionType) and nested.parametrized:
                raise InternalGuppyError(
                    "Tried to construct a higher-rank polymorphic type!"
                )

    @property
    @abstractmethod
    def intrinsically_copyable(self) -> bool:
        """Whether this type is copyable when the arguments are disregarded.

        For example, a parametrized struct containing a qubit is never copyable, no
        matter whether all of its arguments are.
        """

    @cached_property
    def copyable(self) -> bool:
        """Whether this type should be treated as copyable."""
        if not self.intrinsically_copyable:
            return False
        # Only type arguments matter here; each of them has to be copyable.
        return all(arg.ty.copyable for arg in self.args if isinstance(arg, TypeArg))

    @property
    @abstractmethod
    def intrinsically_droppable(self) -> bool:
        """Whether this type is droppable when the arguments are disregarded.

        For example, a parametrized struct containing a qubit is never droppable, no
        matter whether all of its arguments are.
        """

    @cached_property
    def droppable(self) -> bool:
        """Whether this type should be treated as droppable."""
        if not self.intrinsically_droppable:
            return False
        # Only type arguments matter here; each of them has to be droppable.
        return all(arg.ty.droppable for arg in self.args if isinstance(arg, TypeArg))

    @cached_property
    def unsolved_vars(self) -> set[ExistentialVar]:
        """The existential variables occurring in any of the arguments."""
        collected: set[ExistentialVar] = set()
        for arg in self.args:
            collected.update(arg.unsolved_vars)
        return collected

    @cached_property
    def hugr_bound(self) -> ht.TypeBound:
        """The Hugr bound of this type.

        Linear types get `ht.TypeBound.Linear`; otherwise the bound is the join of
        the bounds of all type arguments.
        """
        if self.linear:
            return ht.TypeBound.Linear
        arg_bounds = [
            arg.ty.hugr_bound for arg in self.args if isinstance(arg, TypeArg)
        ]
        return ht.TypeBound.join(*arg_bounds)

    def visit(self, visitor: Visitor) -> None:
        """Accepts a visitor on this type.

        The arguments are only visited if `visitor.visit(self)` returns a falsy
        value.
        """
        if visitor.visit(self):
            return
        for arg in self.args:
            visitor.visit(arg)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@dataclass(frozen=True)
class BoundTypeVar(TypeBase, BoundVar):
    """Bound type variable that refers to a parameter of kind `Type`.

    For example, in the function type `forall T. list[T] -> T` the variable `T` is
    represented as a `BoundTypeVar(idx=0)`.

    A bound type variable can be instantiated with a `TypeArg` argument.
    """

    # Whether values of the type this variable stands for may be implicitly copied
    copyable: bool
    # Whether values of the type this variable stands for may be dropped
    droppable: bool

    @cached_property
    def hugr_bound(self) -> ht.TypeBound:
        """The Hugr bound of this type: `Linear` if linear, otherwise `Copyable`."""
        # For non-linear variables we stay conservative and do not require
        # equatability. That is fine since Guppy doesn't use the equatable feature
        # anyways.
        return ht.TypeBound.Linear if self.linear else ht.TypeBound.Copyable

    def cast(self) -> "Type":
        """Returns this object viewed as a member of the `Type` union."""
        return self

    def to_hugr(self, ctx: ToHugrContext) -> ht.Type:
        """Computes the Hugr representation by delegating to the context."""
        return ctx.type_var_to_hugr(self)

    def visit(self, visitor: Visitor) -> None:
        """Accepts a visitor on this type."""
        visitor.visit(self)

    def transform(self, transformer: Transformer) -> "Type":
        """Accepts a transformer on this type.

        Falls back to `self` when the transformer returns a falsy result.
        """
        replacement = transformer.transform(self)
        return replacement or self

    def __str__(self) -> str:
        """Returns the display name of the variable."""
        return self.display_name
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@dataclass(frozen=True)
class ExistentialTypeVar(ExistentialVar, TypeBase):
    """Existential type variable.

    For example, the empty list literal `[]` is given the type `list[?T]`, where `?T`
    is an existential type variable.

    Type checking tries to solve every existential type variable and to substitute it
    with a concrete type.
    """

    # Whether values of the type this variable stands for may be implicitly copied
    copyable: bool
    # Whether values of the type this variable stands for may be dropped
    droppable: bool

    @classmethod
    def fresh(
        cls, display_name: str, copyable: bool, droppable: bool
    ) -> "ExistentialTypeVar":
        """Creates a new variable using the next id drawn from `cls._fresh_id`."""
        new_id = next(cls._fresh_id)
        return ExistentialTypeVar(display_name, new_id, copyable, droppable)

    @cached_property
    def unsolved_vars(self) -> set[ExistentialVar]:
        """The existential variables in this type, i.e. just this variable."""
        return {self}

    @cached_property
    def hugr_bound(self) -> ht.TypeBound:
        """Always raises: an unsolved existential variable has no Hugr bound.

        Raises:
            InternalGuppyError: On every access.
        """
        raise InternalGuppyError(
            "Tried to compute bound of unsolved existential type variable"
        )

    def cast(self) -> "Type":
        """Returns this object viewed as a member of the `Type` union."""
        return self

    def to_hugr(self, ctx: ToHugrContext) -> ht.Type:
        """Always raises: an unsolved existential variable cannot be lowered.

        Raises:
            InternalGuppyError: On every call.
        """
        raise InternalGuppyError(
            "Tried to convert unsolved existential type variable to Hugr"
        )

    def visit(self, visitor: Visitor) -> None:
        """Accepts a visitor on this type."""
        visitor.visit(self)

    def transform(self, transformer: Transformer) -> "Type":
        """Accepts a transformer on this type.

        Falls back to `self` when the transformer returns a falsy result.
        """
        replacement = transformer.transform(self)
        return replacement or self
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
@dataclass(frozen=True)
class NoneType(TypeBase):
    """The type of `None`."""

    # These fields provide concrete values for the abstract properties of `TypeBase`
    copyable: bool = field(default=True, init=True)
    droppable: bool = field(default=True, init=True)
    hugr_bound: ht.TypeBound = field(default=ht.TypeBound.Copyable, init=False)

    # When set, `type_to_row()` should not turn this type into a row. This makes sure
    # that type vars instantiated to `None` are not broken up into empty rows when
    # generating a Hugr. The flag is excluded from equality comparisons.
    preserve: bool = field(default=False, compare=False)

    def cast(self) -> "Type":
        """Returns this object viewed as a member of the `Type` union."""
        return self

    def to_hugr(self, ctx: ToHugrContext) -> ht.Tuple:
        """Computes the Hugr representation of the type: an empty `ht.Tuple`."""
        return ht.Tuple()

    def visit(self, visitor: Visitor) -> None:
        """Accepts a visitor on this type."""
        visitor.visit(self)

    def transform(self, transformer: Transformer) -> "Type":
        """Accepts a transformer on this type.

        Falls back to `self` when the transformer returns a falsy result.
        """
        replacement = transformer.transform(self)
        return replacement or self
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
@dataclass(frozen=True)
|
|
304
|
+
class NumericType(TypeBase):
|
|
305
|
+
"""Numeric types like `int` and `float`."""
|
|
306
|
+
|
|
307
|
+
kind: "Kind"
|
|
308
|
+
|
|
309
|
+
@total_ordering
|
|
310
|
+
class Kind(Enum):
|
|
311
|
+
"""The different kinds of numeric types."""
|
|
312
|
+
|
|
313
|
+
Nat = auto()
|
|
314
|
+
Int = auto()
|
|
315
|
+
Float = auto()
|
|
316
|
+
|
|
317
|
+
def __lt__(self, other: "NumericType.Kind") -> bool:
|
|
318
|
+
return self.value < other.value
|
|
319
|
+
|
|
320
|
+
INT_WIDTH: ClassVar[int] = 6
|
|
321
|
+
|
|
322
|
+
@property
|
|
323
|
+
def copyable(self) -> bool:
|
|
324
|
+
"""Whether objects of this type can be implicitly copied."""
|
|
325
|
+
return True
|
|
326
|
+
|
|
327
|
+
@property
|
|
328
|
+
def droppable(self) -> bool:
|
|
329
|
+
"""Whether objects of this type can be dropped."""
|
|
330
|
+
return True
|
|
331
|
+
|
|
332
|
+
def cast(self) -> "Type":
|
|
333
|
+
"""Casts an implementor of `TypeBase` into a `Type`."""
|
|
334
|
+
return self
|
|
335
|
+
|
|
336
|
+
def to_hugr(self, ctx: ToHugrContext) -> ht.ExtType:
|
|
337
|
+
"""Computes the Hugr representation of the type."""
|
|
338
|
+
match self.kind:
|
|
339
|
+
case NumericType.Kind.Nat | NumericType.Kind.Int:
|
|
340
|
+
return hugr.std.int.int_t(NumericType.INT_WIDTH)
|
|
341
|
+
case NumericType.Kind.Float:
|
|
342
|
+
return hugr.std.float.FLOAT_T
|
|
343
|
+
|
|
344
|
+
@property
|
|
345
|
+
def hugr_bound(self) -> ht.TypeBound:
|
|
346
|
+
"""The Hugr bound of this type, i.e. `Any` or `Copyable`"""
|
|
347
|
+
return ht.TypeBound.Copyable
|
|
348
|
+
|
|
349
|
+
def visit(self, visitor: Visitor) -> None:
|
|
350
|
+
"""Accepts a visitor on this type."""
|
|
351
|
+
visitor.visit(self)
|
|
352
|
+
|
|
353
|
+
def transform(self, transformer: Transformer) -> "Type":
|
|
354
|
+
"""Accepts a transformer on this type."""
|
|
355
|
+
return transformer.transform(self) or self
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
class InputFlags(Flag):
    """Bit flags that may be attached to an input of a function type.

    Further flags (for instance `Frozen`) could be added in the future.
    """

    # No flag set
    NoFlags = 0

    # Each of the following flags occupies its own bit so that they can be combined
    Inout = auto()
    Owned = auto()
    Comptime = auto()
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
@dataclass(frozen=True)
class FuncInput:
    """One input of a function type: its type plus the flags set on it."""

    # Type of the input
    ty: "Type"
    # Flags set on the input; `InputFlags.NoFlags` if there are none
    flags: InputFlags
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
@dataclass(frozen=True, init=False)
|
|
379
|
+
class FunctionType(ParametrizedTypeBase):
|
|
380
|
+
"""Type of (potentially generic) functions."""
|
|
381
|
+
|
|
382
|
+
inputs: Sequence[FuncInput]
|
|
383
|
+
output: "Type"
|
|
384
|
+
params: Sequence[Parameter]
|
|
385
|
+
input_names: Sequence[str] | None
|
|
386
|
+
comptime_args: Sequence[ConstArg]
|
|
387
|
+
|
|
388
|
+
args: Sequence[Argument] = field(init=False)
|
|
389
|
+
copyable: bool = field(default=True, init=True)
|
|
390
|
+
droppable: bool = field(default=True, init=True)
|
|
391
|
+
intrinsically_copyable: bool = field(default=True, init=True)
|
|
392
|
+
intrinsically_droppable: bool = field(default=True, init=True)
|
|
393
|
+
hugr_bound: ht.TypeBound = field(default=ht.TypeBound.Copyable, init=False)
|
|
394
|
+
|
|
395
|
+
def __init__(
|
|
396
|
+
self,
|
|
397
|
+
inputs: Sequence[FuncInput],
|
|
398
|
+
output: "Type",
|
|
399
|
+
input_names: Sequence[str] | None = None,
|
|
400
|
+
params: Sequence[Parameter] | None = None,
|
|
401
|
+
comptime_args: Sequence[ConstArg] | None = None,
|
|
402
|
+
) -> None:
|
|
403
|
+
# We need a custom __init__ to set the args
|
|
404
|
+
args: list[Argument] = [TypeArg(inp.ty) for inp in inputs]
|
|
405
|
+
args.append(TypeArg(output))
|
|
406
|
+
|
|
407
|
+
# If no explicit comptime args are provided, assume that all of them are bound
|
|
408
|
+
params = params or []
|
|
409
|
+
if comptime_args is None:
|
|
410
|
+
comptime_args = [
|
|
411
|
+
param.to_bound()
|
|
412
|
+
for param in params
|
|
413
|
+
if isinstance(param, ConstParam) and param.from_comptime_arg
|
|
414
|
+
]
|
|
415
|
+
args += comptime_args
|
|
416
|
+
|
|
417
|
+
object.__setattr__(self, "args", args)
|
|
418
|
+
object.__setattr__(self, "comptime_args", comptime_args)
|
|
419
|
+
object.__setattr__(self, "inputs", inputs)
|
|
420
|
+
object.__setattr__(self, "output", output)
|
|
421
|
+
object.__setattr__(self, "input_names", input_names or [])
|
|
422
|
+
object.__setattr__(self, "params", params)
|
|
423
|
+
|
|
424
|
+
@property
|
|
425
|
+
def parametrized(self) -> bool:
|
|
426
|
+
"""Whether the function is parametrized."""
|
|
427
|
+
return len(self.params) > 0
|
|
428
|
+
|
|
429
|
+
def cast(self) -> "Type":
|
|
430
|
+
"""Casts an implementor of `TypeBase` into a `Type`."""
|
|
431
|
+
return self
|
|
432
|
+
|
|
433
|
+
def to_hugr(self, ctx: ToHugrContext) -> ht.FunctionType:
|
|
434
|
+
"""Computes the Hugr representation of the type."""
|
|
435
|
+
if self.parametrized:
|
|
436
|
+
raise InternalGuppyError(
|
|
437
|
+
"Tried to convert parametrised function type to Hugr. Use "
|
|
438
|
+
"`to_hugr_poly` instead"
|
|
439
|
+
)
|
|
440
|
+
return self._to_hugr_function_type(ctx)
|
|
441
|
+
|
|
442
|
+
def to_hugr_poly(self, ctx: ToHugrContext) -> ht.PolyFuncType:
|
|
443
|
+
"""Computes the Hugr `PolyFuncType` representation of the type."""
|
|
444
|
+
# Function body needs to be translated in a new context where the variables are
|
|
445
|
+
# bound to the quantifier.
|
|
446
|
+
inner_ctx = QuantifiedToHugrContext(self.params)
|
|
447
|
+
func_ty = self._to_hugr_function_type(inner_ctx)
|
|
448
|
+
return ht.PolyFuncType(
|
|
449
|
+
params=[p.to_hugr(ctx) for p in self.params], body=func_ty
|
|
450
|
+
)
|
|
451
|
+
|
|
452
|
+
def _to_hugr_function_type(self, ctx: ToHugrContext) -> ht.FunctionType:
|
|
453
|
+
"""Helper method to compute the Hugr `FunctionType` representation of the type.
|
|
454
|
+
|
|
455
|
+
The resulting `FunctionType` can then be embedded into a Hugr `Type` or a Hugr
|
|
456
|
+
`PolyFuncType`.
|
|
457
|
+
"""
|
|
458
|
+
ins = [
|
|
459
|
+
inp.ty.to_hugr(ctx)
|
|
460
|
+
for inp in self.inputs
|
|
461
|
+
# Comptime inputs are turned into generic args, so are not included here
|
|
462
|
+
if InputFlags.Comptime not in inp.flags
|
|
463
|
+
]
|
|
464
|
+
outs = [
|
|
465
|
+
*(t.to_hugr(ctx) for t in type_to_row(self.output)),
|
|
466
|
+
# We might have additional borrowed args that will be also outputted
|
|
467
|
+
*(
|
|
468
|
+
inp.ty.to_hugr(ctx)
|
|
469
|
+
for inp in self.inputs
|
|
470
|
+
if InputFlags.Inout in inp.flags
|
|
471
|
+
),
|
|
472
|
+
]
|
|
473
|
+
return ht.FunctionType(input=ins, output=outs)
|
|
474
|
+
|
|
475
|
+
def visit(self, visitor: Visitor) -> None:
|
|
476
|
+
"""Accepts a visitor on this type."""
|
|
477
|
+
if not visitor.visit(self):
|
|
478
|
+
for inp in self.inputs:
|
|
479
|
+
visitor.visit(inp)
|
|
480
|
+
visitor.visit(self.output)
|
|
481
|
+
for param in self.params:
|
|
482
|
+
visitor.visit(param)
|
|
483
|
+
|
|
484
|
+
def transform(self, transformer: Transformer) -> "Type":
|
|
485
|
+
"""Accepts a transformer on this type."""
|
|
486
|
+
return transformer.transform(self) or FunctionType(
|
|
487
|
+
[
|
|
488
|
+
FuncInput(inp.ty.transform(transformer), inp.flags)
|
|
489
|
+
for inp in self.inputs
|
|
490
|
+
],
|
|
491
|
+
self.output.transform(transformer),
|
|
492
|
+
self.input_names,
|
|
493
|
+
self.params,
|
|
494
|
+
)
|
|
495
|
+
|
|
496
|
+
def instantiate_partial(self, args: "PartialInst") -> "FunctionType":
|
|
497
|
+
"""Instantiates a subset of the function parameters with concrete types."""
|
|
498
|
+
from guppylang_internals.tys.subst import Instantiator
|
|
499
|
+
|
|
500
|
+
assert len(args) == len(self.params)
|
|
501
|
+
|
|
502
|
+
full_inst: list[Argument] = []
|
|
503
|
+
remaining_params: list[Parameter] = []
|
|
504
|
+
for param, arg in zip(self.params, args, strict=True):
|
|
505
|
+
# If no instantiation for this param is provided, it should stay around.
|
|
506
|
+
# However, we have to down-shift the de Bruijn index.
|
|
507
|
+
if arg is None:
|
|
508
|
+
param = param.with_idx(len(remaining_params))
|
|
509
|
+
remaining_params.append(param)
|
|
510
|
+
arg = param.to_bound()
|
|
511
|
+
|
|
512
|
+
# Set the `preserve` flag for instantiated tuples and None
|
|
513
|
+
if isinstance(arg, TypeArg):
|
|
514
|
+
if isinstance(arg.ty, TupleType):
|
|
515
|
+
arg = TypeArg(TupleType(arg.ty.element_types, preserve=True))
|
|
516
|
+
elif isinstance(arg.ty, NoneType):
|
|
517
|
+
arg = TypeArg(NoneType(preserve=True))
|
|
518
|
+
full_inst.append(arg)
|
|
519
|
+
|
|
520
|
+
inst = Instantiator(full_inst)
|
|
521
|
+
return FunctionType(
|
|
522
|
+
[FuncInput(inp.ty.transform(inst), inp.flags) for inp in self.inputs],
|
|
523
|
+
self.output.transform(inst),
|
|
524
|
+
self.input_names,
|
|
525
|
+
remaining_params,
|
|
526
|
+
# Comptime type arguments also need to be instantiated
|
|
527
|
+
comptime_args=[
|
|
528
|
+
cast(ConstArg, arg.transform(inst)) for arg in self.comptime_args
|
|
529
|
+
],
|
|
530
|
+
)
|
|
531
|
+
|
|
532
|
+
def instantiate(self, args: "Inst") -> "FunctionType":
|
|
533
|
+
"""Instantiates all function parameters with concrete types."""
|
|
534
|
+
return self.instantiate_partial(args)
|
|
535
|
+
|
|
536
|
+
def unquantified(self) -> tuple["FunctionType", Sequence[ExistentialVar]]:
|
|
537
|
+
"""Instantiates all parameters with existential variables."""
|
|
538
|
+
exs = [param.to_existential() for param in self.params]
|
|
539
|
+
return self.instantiate([arg for arg, _ in exs]), [var for _, var in exs]
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
@dataclass(frozen=True, init=False)
class TupleType(ParametrizedTypeBase):
    """The type of tuples."""

    # Types of the tuple elements, in order
    element_types: Sequence["Type"]

    # When set, `type_to_row()` should not turn this tuple into a row. This makes
    # sure that type vars instantiated to tuples are not broken up into rows when
    # generating a Hugr. The flag is excluded from equality comparisons.
    preserve: bool = field(default=False, compare=False)

    def __init__(self, element_types: Sequence["Type"], preserve: bool = False) -> None:
        """Builds the tuple type, deriving one `TypeArg` per element type."""
        # A hand-written `__init__` is required in order to populate `args`. The
        # dataclass is frozen, so fields are written via `object.__setattr__`.
        derived_args = [TypeArg(elem) for elem in element_types]
        object.__setattr__(self, "args", derived_args)
        object.__setattr__(self, "element_types", element_types)
        object.__setattr__(self, "preserve", preserve)

    @property
    def intrinsically_copyable(self) -> bool:
        """Tuples place no copyability restriction beyond that of their elements."""
        return True

    @property
    def intrinsically_droppable(self) -> bool:
        """Tuples place no droppability restriction beyond that of their elements."""
        return True

    def cast(self) -> "Type":
        """Returns this object viewed as a member of the `Type` union."""
        return self

    def to_hugr(self, ctx: ToHugrContext) -> ht.Tuple:
        """Computes the Hugr representation: a `ht.Tuple` of the element types."""
        hugr_elems = row_to_hugr(self.element_types, ctx)
        return ht.Tuple(*hugr_elems)

    def transform(self, transformer: Transformer) -> "Type":
        """Accepts a transformer on this type.

        If the transformer returns a falsy result for this type, a new tuple type is
        built from the transformed element types, keeping the `preserve` flag.
        """
        replacement = transformer.transform(self)
        if replacement:
            return replacement
        new_elems = [elem.transform(transformer) for elem in self.element_types]
        return TupleType(new_elems, self.preserve)
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
@dataclass(frozen=True, init=False)
class SumType(ParametrizedTypeBase):
    """The type of sums.

    This type is only used internally while constructing the Hugr; users cannot write
    it down.
    """

    # Types of the variants, in order
    element_types: Sequence["Type"]

    def __init__(self, element_types: Sequence["Type"]) -> None:
        """Builds the sum type, deriving one `TypeArg` per element type."""
        # A hand-written `__init__` is required in order to populate `args`. The
        # dataclass is frozen, so fields are written via `object.__setattr__`.
        derived_args = [TypeArg(elem) for elem in element_types]
        object.__setattr__(self, "args", derived_args)
        object.__setattr__(self, "element_types", element_types)

    @property
    def intrinsically_copyable(self) -> bool:
        """Sums place no copyability restriction beyond that of their elements."""
        return True

    @property
    def intrinsically_droppable(self) -> bool:
        """Sums place no droppability restriction beyond that of their elements."""
        return True

    def cast(self) -> "Type":
        """Returns this object viewed as a member of the `Type` union."""
        return self

    def to_hugr(self, ctx: ToHugrContext) -> ht.Sum:
        """Computes the Hugr representation of the type.

        Each element type is first turned into a row. If all rows are empty, the
        result is a `ht.UnitSum`; a single row yields a `ht.Tuple`; everything else
        becomes a general `ht.Sum`.
        """
        rows = [type_to_row(elem) for elem in self.element_types]
        if not any(len(row) != 0 for row in rows):
            return ht.UnitSum(size=len(rows))
        if len(rows) == 1:
            (only_row,) = rows
            return ht.Tuple(*row_to_hugr(only_row, ctx))
        return ht.Sum(variant_rows=rows_to_hugr(rows, ctx))

    def transform(self, transformer: Transformer) -> "Type":
        """Accepts a transformer on this type.

        If the transformer returns a falsy result for this type, a new sum type is
        built from the transformed element types.
        """
        replacement = transformer.transform(self)
        if replacement:
            return replacement
        return SumType([elem.transform(transformer) for elem in self.element_types])
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
@dataclass(frozen=True)
class OpaqueType(ParametrizedTypeBase):
    """Type that is directly backed by a Hugr opaque type.

    For example, many builtin types like `int`, `float`, `list` etc. are directly
    backed by a Hugr extension.
    """

    # Definition that this type was created from
    defn: "OpaqueTypeDef"

    @property
    def intrinsically_copyable(self) -> bool:
        """Whether the definition allows copying (`never_copyable` is not set)."""
        return not self.defn.never_copyable

    @property
    def intrinsically_droppable(self) -> bool:
        """Whether the definition allows dropping (`never_droppable` is not set)."""
        return not self.defn.never_droppable

    @property
    def hugr_bound(self) -> ht.TypeBound:
        """The Hugr bound of this type.

        A bound given by the definition takes precedence; otherwise the bound
        computed by the base class is used.
        """
        if self.defn.bound is None:
            return super().hugr_bound
        return self.defn.bound

    def cast(self) -> "Type":
        """Returns this object viewed as a member of the `Type` union."""
        return self

    def to_hugr(self, ctx: ToHugrContext) -> ht.Type:
        """Computes the Hugr representation by delegating to the definition."""
        return self.defn.to_hugr(self.args, ctx)

    def transform(self, transformer: Transformer) -> "Type":
        """Accepts a transformer on this type.

        If the transformer returns a falsy result for this type, a new opaque type
        with transformed arguments and the same definition is built.
        """
        replacement = transformer.transform(self)
        if replacement:
            return replacement
        new_args = [arg.transform(transformer) for arg in self.args]
        return OpaqueType(new_args, self.defn)
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
@dataclass(frozen=True)
class StructType(ParametrizedTypeBase):
    """A struct type: a checked struct definition applied to type arguments."""

    #: The struct definition whose fields are instantiated with `self.args`.
    defn: "CheckedStructDef"

    @cached_property
    def fields(self) -> list["StructField"]:
        """The fields of this struct type.

        Each field type of the definition is instantiated with the arguments
        of this type.
        """
        # Imported inside the method rather than at module level — presumably
        # to avoid an import cycle; TODO confirm.
        from guppylang_internals.definition.struct import StructField
        from guppylang_internals.tys.subst import Instantiator

        instantiator = Instantiator(self.args)
        return [
            StructField(field.name, field.ty.transform(instantiator))
            for field in self.defn.fields
        ]

    @cached_property
    def field_dict(self) -> "dict[str, StructField]":
        """Lookup table from field name to the corresponding field."""
        return {fld.name: fld for fld in self.fields}

    @cached_property
    def intrinsically_copyable(self) -> bool:
        """Whether values of this type may be copied implicitly.

        This holds when every field type is copyable.
        """
        return all(fld.ty.copyable for fld in self.fields)

    @cached_property
    def intrinsically_droppable(self) -> bool:
        """Whether values of this type may be dropped.

        This holds when every field type is droppable.
        """
        return all(fld.ty.droppable for fld in self.fields)

    def cast(self) -> "Type":
        """Views this `TypeBase` implementor as a `Type` (returns `self`)."""
        return self

    def to_hugr(self, ctx: ToHugrContext) -> ht.Tuple:
        """Lowers this type to a Hugr tuple of the lowered field types."""
        field_types = [fld.ty.to_hugr(ctx) for fld in self.fields]
        return ht.Tuple(*field_types)

    def transform(self, transformer: Transformer) -> "Type":
        """Applies a transformer to this type.

        If the transformer yields a truthy result for this type, that result is
        returned. Otherwise, the type is rebuilt for the same definition with
        every argument transformed.
        """
        replacement = transformer.transform(self)
        if replacement:
            return replacement
        transformed_args = [arg.transform(transformer) for arg in self.args]
        return StructType(transformed_args, self.defn)
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
#: Union of all parametrized Guppy types.
ParametrizedType: TypeAlias = (
    FunctionType
    | TupleType
    | SumType
    | OpaqueType
    | StructType
)

#: Union of all Guppy types.
#:
#: This alias spans every Guppy type defined in this module. It models an
#: algebraic data type and enables exhaustiveness checking in pattern matches
#: etc.
#:
#: This might become obsolete in case the @sealed decorator is added:
#:  * https://peps.python.org/pep-0622/#sealed-classes-as-algebraic-data-types
#:  * https://github.com/johnthagen/sealed-typing-pep
Type: TypeAlias = (
    BoundTypeVar
    | ExistentialTypeVar
    | NumericType
    | NoneType
    | ParametrizedType
)

#: An immutable row of Guppy types.
TypeRow: TypeAlias = Sequence[Type]
|
|
739
|
+
|
|
740
|
+
|
|
741
|
+
def row_to_type(row: TypeRow) -> Type:
|
|
742
|
+
"""Turns a row of types into a single type by packing into a tuple."""
|
|
743
|
+
if len(row) == 0:
|
|
744
|
+
return NoneType()
|
|
745
|
+
elif len(row) == 1:
|
|
746
|
+
return row[0]
|
|
747
|
+
else:
|
|
748
|
+
return TupleType(row)
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
def type_to_row(ty: Type) -> TypeRow:
    """Unpacks a type into a row of types.

    A `NoneType` yields the empty row and a `TupleType` yields its element
    types, unless the type has `preserve` set. Every other type, and every
    preserved one, yields a singleton row holding the type itself.
    """
    unpackable = isinstance(ty, (NoneType, TupleType)) and not ty.preserve
    if not unpackable:
        return [ty]
    if isinstance(ty, NoneType):
        return []
    return ty.element_types
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
def row_to_hugr(row: TypeRow, ctx: ToHugrContext) -> ht.TypeRow:
    """Lowers every type of a row to Hugr, preserving the order."""
    return [entry.to_hugr(ctx) for entry in row]
|
|
763
|
+
|
|
764
|
+
|
|
765
|
+
def rows_to_hugr(rows: Sequence[TypeRow], ctx: ToHugrContext) -> list[ht.TypeRow]:
    """Lowers each row of a sequence of rows to Hugr via `row_to_hugr`."""
    return [row_to_hugr(single_row, ctx) for single_row in rows]
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
def unify(s: Type | Const, t: Type | Const, subst: "Subst | None") -> "Subst | None":
|
|
771
|
+
"""Computes a most general unifier for two types or constants.
|
|
772
|
+
|
|
773
|
+
Return a substitutions `subst` such that `s[subst] == t[subst]` or `None` if this
|
|
774
|
+
not possible.
|
|
775
|
+
"""
|
|
776
|
+
# Make sure that s and t are either both constants or both types
|
|
777
|
+
assert isinstance(s, TypeBase) == isinstance(t, TypeBase)
|
|
778
|
+
if subst is None:
|
|
779
|
+
return None
|
|
780
|
+
match s, t:
|
|
781
|
+
case ExistentialVar(id=s_id), ExistentialVar(id=t_id) if s_id == t_id:
|
|
782
|
+
return subst
|
|
783
|
+
case ExistentialTypeVar() | ExistentialConstVar() as s_var, t:
|
|
784
|
+
return _unify_var(s_var, t, subst)
|
|
785
|
+
case s, ExistentialTypeVar() | ExistentialConstVar() as t_var:
|
|
786
|
+
return _unify_var(t_var, s, subst)
|
|
787
|
+
case BoundVar(idx=s_idx), BoundVar(idx=t_idx) if s_idx == t_idx:
|
|
788
|
+
return subst
|
|
789
|
+
case ConstValue(value=c_value), ConstValue(value=d_value) if c_value == d_value:
|
|
790
|
+
return subst
|
|
791
|
+
case NumericType(kind=s_kind), NumericType(kind=t_kind) if s_kind == t_kind:
|
|
792
|
+
return subst
|
|
793
|
+
case NoneType(), NoneType():
|
|
794
|
+
return subst
|
|
795
|
+
case FunctionType() as s, FunctionType() as t if s.params == t.params:
|
|
796
|
+
if len(s.inputs) != len(t.inputs):
|
|
797
|
+
return None
|
|
798
|
+
for a, b in zip(s.inputs, t.inputs, strict=True):
|
|
799
|
+
if a.ty.linear and b.ty.linear and a.flags != b.flags:
|
|
800
|
+
return None
|
|
801
|
+
return _unify_args(s, t, subst)
|
|
802
|
+
case TupleType() as s, TupleType() as t:
|
|
803
|
+
return _unify_args(s, t, subst)
|
|
804
|
+
case SumType() as s, SumType() as t:
|
|
805
|
+
return _unify_args(s, t, subst)
|
|
806
|
+
case OpaqueType() as s, OpaqueType() as t if s.defn == t.defn:
|
|
807
|
+
return _unify_args(s, t, subst)
|
|
808
|
+
case StructType() as s, StructType() as t if s.defn == t.defn:
|
|
809
|
+
return _unify_args(s, t, subst)
|
|
810
|
+
case _:
|
|
811
|
+
return None
|
|
812
|
+
|
|
813
|
+
|
|
814
|
+
def _unify_var(
    var: ExistentialTypeVar | ExistentialConstVar, t: Type | Const, subst: "Subst"
) -> "Subst | None":
    """Unifies an existential type or const variable with `t`.

    Returns the extended substitution, or `None` if `var` cannot be unified
    with `t`.
    """
    # `var` has already been solved: unify its solution with `t` instead
    if var in subst:
        return unify(subst[var], t, subst)
    # `t` is a type variable that has already been solved: unify `var` with
    # that solution instead.
    # NOTE(review): only `ExistentialTypeVar` is resolved through `subst` here,
    # not `ExistentialConstVar` — confirm that this is intended.
    if isinstance(t, ExistentialTypeVar) and t in subst:
        return unify(var, subst[t], subst)
    # Occurs check: `var` must not be among the unsolved variables of `t`
    if var in t.unsolved_vars:
        return None
    # Record the new solution without mutating the incoming substitution
    return {var: t, **subst}
|
|
825
|
+
|
|
826
|
+
|
|
827
|
+
def _unify_args(
|
|
828
|
+
s: ParametrizedType, t: ParametrizedType, subst: "Subst"
|
|
829
|
+
) -> "Subst | None":
|
|
830
|
+
"""Helper function for unification of type arguments of parametrised types."""
|
|
831
|
+
if len(s.args) != len(t.args):
|
|
832
|
+
return None
|
|
833
|
+
for sa, ta in zip(s.args, t.args, strict=True):
|
|
834
|
+
match sa, ta:
|
|
835
|
+
case TypeArg(ty=sa_ty), TypeArg(ty=ta_ty):
|
|
836
|
+
res = unify(sa_ty, ta_ty, subst)
|
|
837
|
+
if res is None:
|
|
838
|
+
return None
|
|
839
|
+
subst = res
|
|
840
|
+
case ConstArg(const=sa_const), ConstArg(const=ta_const):
|
|
841
|
+
res = unify(sa_const, ta_const, subst)
|
|
842
|
+
if res is None:
|
|
843
|
+
return None
|
|
844
|
+
subst = res
|
|
845
|
+
case _:
|
|
846
|
+
return None
|
|
847
|
+
return subst
|
|
848
|
+
|
|
849
|
+
|
|
850
|
+
### Helpers for working with tuples of functions
|
|
851
|
+
|
|
852
|
+
|
|
853
|
+
def parse_function_tensor(ty: TupleType) -> list[FunctionType] | None:
    """Parses a nested tuple of function types into a flat list of functions.

    The elements are visited left to right and nested tuples are flattened
    recursively.

    Returns `None` if the tuple is not a function tensor, i.e. if any element
    is neither a function nor a tuple, or if a nested tuple fails to parse or
    contributes no functions.
    """
    result: list[FunctionType] = []
    for el in ty.element_types:
        if isinstance(el, FunctionType):
            result.append(el)
        elif isinstance(el, TupleType):
            funcs = parse_function_tensor(el)
            # A nested tuple that is not a function tensor (or that holds no
            # functions at all) invalidates the whole tensor
            if not funcs:
                return None
            result.extend(funcs)
        else:
            # Elements that are neither functions nor tuples must not be
            # skipped silently: the tuple is not a function tensor
            return None
    return result
|
|
866
|
+
|
|
867
|
+
|
|
868
|
+
def function_tensor_signature(tys: list[FunctionType]) -> FunctionType:
    """Computes the combined signature of a list of functions.

    The inputs of all functions are concatenated in order. The outputs are
    flattened via `type_to_row`, concatenated, and packed back into a single
    type via `row_to_type`. None of the functions may be parametrized.
    """
    combined_inputs: list[FuncInput] = []
    combined_outputs: list[Type] = []
    for func in tys:
        assert not func.parametrized
        combined_inputs += func.inputs
        combined_outputs += type_to_row(func.output)
    return FunctionType(combined_inputs, row_to_type(combined_outputs))
|