guppylang-internals 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +21 -0
- guppylang_internals/cfg/bb.py +20 -0
- guppylang_internals/cfg/builder.py +118 -5
- guppylang_internals/cfg/cfg.py +3 -0
- guppylang_internals/checker/cfg_checker.py +6 -0
- guppylang_internals/checker/core.py +5 -2
- guppylang_internals/checker/errors/generic.py +32 -1
- guppylang_internals/checker/errors/type_errors.py +14 -0
- guppylang_internals/checker/errors/wasm.py +7 -4
- guppylang_internals/checker/expr_checker.py +58 -17
- guppylang_internals/checker/func_checker.py +18 -14
- guppylang_internals/checker/linearity_checker.py +67 -10
- guppylang_internals/checker/modifier_checker.py +120 -0
- guppylang_internals/checker/stmt_checker.py +48 -1
- guppylang_internals/checker/unitary_checker.py +132 -0
- guppylang_internals/compiler/cfg_compiler.py +7 -6
- guppylang_internals/compiler/core.py +93 -56
- guppylang_internals/compiler/expr_compiler.py +72 -168
- guppylang_internals/compiler/modifier_compiler.py +176 -0
- guppylang_internals/compiler/stmt_compiler.py +15 -8
- guppylang_internals/decorator.py +86 -7
- guppylang_internals/definition/custom.py +39 -1
- guppylang_internals/definition/declaration.py +9 -6
- guppylang_internals/definition/function.py +12 -2
- guppylang_internals/definition/parameter.py +8 -3
- guppylang_internals/definition/pytket_circuits.py +14 -41
- guppylang_internals/definition/struct.py +13 -7
- guppylang_internals/definition/ty.py +3 -3
- guppylang_internals/definition/wasm.py +42 -10
- guppylang_internals/engine.py +9 -3
- guppylang_internals/experimental.py +5 -0
- guppylang_internals/nodes.py +147 -24
- guppylang_internals/std/_internal/checker.py +13 -108
- guppylang_internals/std/_internal/compiler/array.py +95 -283
- guppylang_internals/std/_internal/compiler/list.py +1 -1
- guppylang_internals/std/_internal/compiler/platform.py +153 -0
- guppylang_internals/std/_internal/compiler/prelude.py +12 -4
- guppylang_internals/std/_internal/compiler/tket_exts.py +8 -2
- guppylang_internals/std/_internal/debug.py +18 -9
- guppylang_internals/std/_internal/util.py +1 -1
- guppylang_internals/tracing/object.py +10 -0
- guppylang_internals/tracing/unpacking.py +19 -20
- guppylang_internals/tys/arg.py +18 -3
- guppylang_internals/tys/builtin.py +2 -5
- guppylang_internals/tys/const.py +33 -4
- guppylang_internals/tys/errors.py +23 -1
- guppylang_internals/tys/param.py +31 -16
- guppylang_internals/tys/parsing.py +11 -24
- guppylang_internals/tys/printing.py +2 -8
- guppylang_internals/tys/qubit.py +62 -0
- guppylang_internals/tys/subst.py +8 -26
- guppylang_internals/tys/ty.py +91 -85
- guppylang_internals/wasm_util.py +129 -0
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +6 -5
- guppylang_internals-0.26.0.dist-info/RECORD +104 -0
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
- guppylang_internals-0.24.0.dist-info/RECORD +0 -98
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
guppylang_internals/tys/param.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
2
|
from collections.abc import Sequence
|
|
3
|
-
from dataclasses import dataclass, field
|
|
3
|
+
from dataclasses import dataclass, field, replace
|
|
4
4
|
from typing import TYPE_CHECKING, TypeAlias
|
|
5
5
|
|
|
6
6
|
from hugr import tys as ht
|
|
@@ -17,9 +17,9 @@ from guppylang_internals.tys.errors import WrongNumberOfTypeArgsError
|
|
|
17
17
|
from guppylang_internals.tys.var import ExistentialVar
|
|
18
18
|
|
|
19
19
|
if TYPE_CHECKING:
|
|
20
|
+
from guppylang_internals.tys.subst import PartialInst
|
|
20
21
|
from guppylang_internals.tys.ty import Type
|
|
21
22
|
|
|
22
|
-
|
|
23
23
|
# We define the `Parameter` type as a union of all `ParameterBase` subclasses defined
|
|
24
24
|
# below. This models an algebraic data type and enables exhaustiveness checking in
|
|
25
25
|
# pattern matches etc.
|
|
@@ -74,6 +74,10 @@ class ParameterBase(ToHugr[ht.TypeParam], ABC):
|
|
|
74
74
|
parameter.
|
|
75
75
|
"""
|
|
76
76
|
|
|
77
|
+
@abstractmethod
|
|
78
|
+
def instantiate_bounds(self, inst: "PartialInst") -> Self:
|
|
79
|
+
"""Instantiates bound variables mentioned in parameter bounds"""
|
|
80
|
+
|
|
77
81
|
|
|
78
82
|
@dataclass(frozen=True)
|
|
79
83
|
class TypeParam(ParameterBase):
|
|
@@ -142,10 +146,17 @@ class TypeParam(ParameterBase):
|
|
|
142
146
|
BoundTypeVar(self.name, idx, self.must_be_copyable, self.must_be_droppable)
|
|
143
147
|
)
|
|
144
148
|
|
|
149
|
+
def instantiate_bounds(self, inst: "PartialInst") -> "TypeParam":
|
|
150
|
+
"""Instantiates bound variables mentioned in parameter bounds"""
|
|
151
|
+
# For now, type parameters don't have any bounds that could be instantiated
|
|
152
|
+
return self
|
|
153
|
+
|
|
145
154
|
def to_hugr(self, ctx: ToHugrContext) -> ht.TypeParam:
|
|
146
155
|
"""Computes the Hugr representation of the parameter."""
|
|
147
156
|
return ht.TypeTypeParam(
|
|
148
|
-
bound=ht.TypeBound.
|
|
157
|
+
bound=ht.TypeBound.Copyable
|
|
158
|
+
if self.must_be_copyable
|
|
159
|
+
else ht.TypeBound.Linear
|
|
149
160
|
)
|
|
150
161
|
|
|
151
162
|
def __str__(self) -> str:
|
|
@@ -157,18 +168,12 @@ class TypeParam(ParameterBase):
|
|
|
157
168
|
class ConstParam(ParameterBase):
|
|
158
169
|
"""A parameter of kind constant. Used to define fixed-size arrays etc."""
|
|
159
170
|
|
|
160
|
-
ty: "Type"
|
|
171
|
+
ty: "Type" = field(hash=False)
|
|
161
172
|
|
|
162
173
|
#: Marker to annotate if this parameter was implicitly generated by a `@comptime`
|
|
163
174
|
#: annotated argument in a function signature.
|
|
164
175
|
from_comptime_arg: bool = field(default=False, kw_only=True)
|
|
165
176
|
|
|
166
|
-
def __post_init__(self) -> None:
|
|
167
|
-
if self.ty.unsolved_vars:
|
|
168
|
-
raise InternalGuppyError(
|
|
169
|
-
"Attempted to create constant param with unsolved type"
|
|
170
|
-
)
|
|
171
|
-
|
|
172
177
|
def with_idx(self, idx: int) -> "ConstParam":
|
|
173
178
|
"""Returns a copy of the parameter with a new index."""
|
|
174
179
|
return ConstParam(idx, self.name, self.ty)
|
|
@@ -178,13 +183,16 @@ class ConstParam(ParameterBase):
|
|
|
178
183
|
|
|
179
184
|
Raises a user error if the argument is not valid.
|
|
180
185
|
"""
|
|
186
|
+
from guppylang_internals.tys.ty import unify
|
|
187
|
+
|
|
181
188
|
match arg:
|
|
182
189
|
case ConstArg(const):
|
|
183
|
-
|
|
190
|
+
subst = unify(const.ty, self.ty, {})
|
|
191
|
+
if subst is None:
|
|
184
192
|
raise GuppyTypeError(
|
|
185
193
|
TypeMismatchError(loc, self.ty, const.ty, kind="argument")
|
|
186
194
|
)
|
|
187
|
-
return
|
|
195
|
+
return ConstArg(replace(const, ty=const.ty.substitute(subst)))
|
|
188
196
|
case TypeArg(ty=ty):
|
|
189
197
|
err = ExpectedError(
|
|
190
198
|
loc, f"expression of type `{self.ty}`", got=f"type `{ty}`"
|
|
@@ -208,6 +216,13 @@ class ConstParam(ParameterBase):
|
|
|
208
216
|
idx = self.idx
|
|
209
217
|
return ConstArg(BoundConstVar(self.ty, self.name, idx))
|
|
210
218
|
|
|
219
|
+
def instantiate_bounds(self, inst: "PartialInst") -> "ConstParam":
|
|
220
|
+
"""Instantiates bound variables mentioned in parameter bounds"""
|
|
221
|
+
from guppylang_internals.tys.subst import Instantiator
|
|
222
|
+
|
|
223
|
+
instantiator = Instantiator(inst)
|
|
224
|
+
return replace(self, ty=self.ty.transform(instantiator))
|
|
225
|
+
|
|
211
226
|
def to_hugr(self, ctx: ToHugrContext) -> ht.TypeParam:
|
|
212
227
|
"""Computes the Hugr representation of the parameter."""
|
|
213
228
|
from guppylang_internals.tys.ty import NumericType
|
|
@@ -231,6 +246,7 @@ def check_all_args(
|
|
|
231
246
|
args: Sequence[Argument],
|
|
232
247
|
type_name: str,
|
|
233
248
|
loc: AstNode | None = None,
|
|
249
|
+
arg_locs: Sequence[AstNode] | None = None,
|
|
234
250
|
) -> None:
|
|
235
251
|
"""Checks a list of arguments against the given parameters.
|
|
236
252
|
|
|
@@ -245,7 +261,6 @@ def check_all_args(
|
|
|
245
261
|
raise GuppyError(WrongNumberOfTypeArgsError(loc, exp, act, type_name))
|
|
246
262
|
|
|
247
263
|
# Now check that the kinds match up
|
|
248
|
-
for param, arg in zip(params, args, strict=True):
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
param.check_arg(arg, loc)
|
|
264
|
+
for i, (param, arg) in enumerate(zip(params, args, strict=True)):
|
|
265
|
+
arg_loc = arg_locs[i] if arg_locs else loc
|
|
266
|
+
param.instantiate_bounds(args).check_arg(arg, arg_loc)
|
|
@@ -39,7 +39,6 @@ from guppylang_internals.tys.errors import (
|
|
|
39
39
|
WrongNumberOfTypeArgsError,
|
|
40
40
|
)
|
|
41
41
|
from guppylang_internals.tys.param import ConstParam, Parameter, TypeParam
|
|
42
|
-
from guppylang_internals.tys.subst import BoundVarFinder
|
|
43
42
|
from guppylang_internals.tys.ty import (
|
|
44
43
|
FuncInput,
|
|
45
44
|
FunctionType,
|
|
@@ -108,7 +107,7 @@ def arg_from_ast(node: AstNode, ctx: TypeParsingCtx) -> Argument:
|
|
|
108
107
|
return ConstArg(ConstValue(bool_type(), v))
|
|
109
108
|
# Integer literals are turned into nat args.
|
|
110
109
|
# TODO: To support int args, we need proper inference logic here
|
|
111
|
-
# See https://github.com/
|
|
110
|
+
# See https://github.com/quantinuum/guppylang/issues/1030
|
|
112
111
|
case int(v) if v >= 0:
|
|
113
112
|
nat_ty = NumericType(NumericType.Kind.Nat)
|
|
114
113
|
return ConstArg(ConstValue(nat_ty, v))
|
|
@@ -118,7 +117,7 @@ def arg_from_ast(node: AstNode, ctx: TypeParsingCtx) -> Argument:
|
|
|
118
117
|
# String literals are ignored for now since they could also be stringified
|
|
119
118
|
# types.
|
|
120
119
|
# TODO: To support string args, we need proper inference logic here
|
|
121
|
-
# See https://github.com/
|
|
120
|
+
# See https://github.com/quantinuum/guppylang/issues/1030
|
|
122
121
|
case str(_):
|
|
123
122
|
pass
|
|
124
123
|
|
|
@@ -290,12 +289,18 @@ def check_function_arg(
|
|
|
290
289
|
ctx.param_var_mapping[name] = ConstParam(
|
|
291
290
|
len(ctx.param_var_mapping), name, ty, from_comptime_arg=True
|
|
292
291
|
)
|
|
293
|
-
return FuncInput(ty, flags)
|
|
292
|
+
return FuncInput(ty, flags, name)
|
|
294
293
|
|
|
295
294
|
|
|
296
295
|
if sys.version_info >= (3, 12):
|
|
297
296
|
|
|
298
|
-
def parse_parameter(
|
|
297
|
+
def parse_parameter(
|
|
298
|
+
node: ast.type_param,
|
|
299
|
+
idx: int,
|
|
300
|
+
globals: Globals,
|
|
301
|
+
param_var_mapping: dict[str, Parameter],
|
|
302
|
+
allow_free_vars: bool = False,
|
|
303
|
+
) -> Parameter:
|
|
299
304
|
"""Parses a `Variable: Bound` generic type parameter declaration."""
|
|
300
305
|
if isinstance(node, ast.TypeVarTuple | ast.ParamSpec):
|
|
301
306
|
raise GuppyError(UnsupportedError(node, "Variadic generic parameters"))
|
|
@@ -331,18 +336,10 @@ if sys.version_info >= (3, 12):
|
|
|
331
336
|
# parameters, so we pass an empty dict as the `param_var_mapping`.
|
|
332
337
|
# TODO: In the future we might want to allow stuff like
|
|
333
338
|
# `def foo[T, XS: array[T, 42]]` and so on
|
|
334
|
-
ctx = TypeParsingCtx(globals, param_var_mapping
|
|
339
|
+
ctx = TypeParsingCtx(globals, param_var_mapping, allow_free_vars)
|
|
335
340
|
ty = type_from_ast(bound, ctx)
|
|
336
341
|
if not ty.copyable or not ty.droppable:
|
|
337
342
|
raise GuppyError(LinearConstParamError(bound, ty))
|
|
338
|
-
|
|
339
|
-
# TODO: For now we can only do `nat` const args since they lower to
|
|
340
|
-
# Hugr bounded nats. Extend to arbitrary types via monomorphization.
|
|
341
|
-
# See https://github.com/CQCL/guppylang/issues/1008
|
|
342
|
-
if ty != NumericType(NumericType.Kind.Nat):
|
|
343
|
-
raise GuppyError(
|
|
344
|
-
UnsupportedError(bound, f"`{ty}` generic parameters")
|
|
345
|
-
)
|
|
346
343
|
return ConstParam(idx, node.name, ty)
|
|
347
344
|
|
|
348
345
|
|
|
@@ -363,16 +360,6 @@ def type_with_flags_from_ast(
|
|
|
363
360
|
flags |= InputFlags.Comptime
|
|
364
361
|
if not ty.copyable or not ty.droppable:
|
|
365
362
|
raise GuppyError(LinearComptimeError(node.right, ty))
|
|
366
|
-
# For now, we don't allow comptime annotations on generic inputs
|
|
367
|
-
# TODO: In the future we might want to allow stuff like
|
|
368
|
-
# `def foo[T: (Copy, Discard](x: T @comptime)`.
|
|
369
|
-
# Also see the todo in `parse_parameter`.
|
|
370
|
-
var_finder = BoundVarFinder()
|
|
371
|
-
ty.visit(var_finder)
|
|
372
|
-
if var_finder.bound_vars:
|
|
373
|
-
raise GuppyError(
|
|
374
|
-
UnsupportedError(node.left, "Generic comptime arguments")
|
|
375
|
-
)
|
|
376
363
|
case _:
|
|
377
364
|
raise GuppyError(InvalidFlagError(node.right))
|
|
378
365
|
return ty, flags
|
|
@@ -11,7 +11,6 @@ from guppylang_internals.tys.ty import (
|
|
|
11
11
|
NumericType,
|
|
12
12
|
OpaqueType,
|
|
13
13
|
StructType,
|
|
14
|
-
SumType,
|
|
15
14
|
TupleType,
|
|
16
15
|
Type,
|
|
17
16
|
)
|
|
@@ -122,11 +121,6 @@ class TypePrinter:
|
|
|
122
121
|
args = ", ".join(self._visit(arg, True) for arg in ty.args)
|
|
123
122
|
return f"({args})"
|
|
124
123
|
|
|
125
|
-
@_visit.register
|
|
126
|
-
def _visit_SumType(self, ty: SumType, inside_row: bool) -> str:
|
|
127
|
-
args = ", ".join(self._visit(arg, True) for arg in ty.args)
|
|
128
|
-
return f"Sum[{args}]"
|
|
129
|
-
|
|
130
124
|
@_visit.register
|
|
131
125
|
def _visit_NoneType(self, ty: NoneType, inside_row: bool) -> str:
|
|
132
126
|
return "None"
|
|
@@ -168,7 +162,7 @@ def signature_to_str(name: str, sig: FunctionType) -> str:
|
|
|
168
162
|
assert sig.input_names is not None
|
|
169
163
|
s = f"def {name}("
|
|
170
164
|
s += ", ".join(
|
|
171
|
-
f"{name}: {inp.ty}{TypePrinter._print_flags(inp.flags)}"
|
|
172
|
-
for
|
|
165
|
+
f"{inp.name}: {inp.ty}{TypePrinter._print_flags(inp.flags)}"
|
|
166
|
+
for inp in sig.inputs
|
|
173
167
|
)
|
|
174
168
|
return s + ") -> " + str(sig.output)
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
from typing import Any, cast
|
|
3
|
+
|
|
4
|
+
from guppylang_internals.definition.ty import TypeDef
|
|
5
|
+
from guppylang_internals.tys.arg import TypeArg
|
|
6
|
+
from guppylang_internals.tys.common import Visitor
|
|
7
|
+
from guppylang_internals.tys.ty import OpaqueType, Type
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@functools.cache
|
|
11
|
+
def qubit_ty() -> Type:
|
|
12
|
+
"""Returns the qubit type. Beware that this function imports guppylang definitions,
|
|
13
|
+
so, if called before the definitions are registered,
|
|
14
|
+
it might result in circular imports.
|
|
15
|
+
"""
|
|
16
|
+
from guppylang.defs import GuppyDefinition
|
|
17
|
+
from guppylang.std.quantum import qubit
|
|
18
|
+
|
|
19
|
+
assert isinstance(qubit, GuppyDefinition)
|
|
20
|
+
qubit_ty = cast(TypeDef, qubit.wrapped).check_instantiate([])
|
|
21
|
+
return qubit_ty
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def is_qubit_ty(ty: Type) -> bool:
|
|
25
|
+
"""Checks if the given type is the qubit type.
|
|
26
|
+
This function results in circular imports if called
|
|
27
|
+
before qubit types are registered.
|
|
28
|
+
"""
|
|
29
|
+
return ty == qubit_ty()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class QubitFinder(Visitor):
|
|
33
|
+
"""Type visitor that checks if a type contains the qubit type."""
|
|
34
|
+
|
|
35
|
+
class FoundFlag(Exception):
|
|
36
|
+
pass
|
|
37
|
+
|
|
38
|
+
@functools.singledispatchmethod
|
|
39
|
+
def visit(self, ty: Any) -> bool: # type: ignore[override]
|
|
40
|
+
return False
|
|
41
|
+
|
|
42
|
+
@visit.register
|
|
43
|
+
def _visit_OpaqueType(self, ty: OpaqueType) -> bool:
|
|
44
|
+
if is_qubit_ty(ty):
|
|
45
|
+
raise self.FoundFlag
|
|
46
|
+
return False
|
|
47
|
+
|
|
48
|
+
@visit.register
|
|
49
|
+
def _visit_TypeArg(self, arg: TypeArg) -> bool:
|
|
50
|
+
arg.ty.visit(self)
|
|
51
|
+
return True
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def contain_qubit_ty(ty: Type) -> bool:
|
|
55
|
+
"""Checks if the given type contains the qubit type."""
|
|
56
|
+
finder = QubitFinder()
|
|
57
|
+
try:
|
|
58
|
+
ty.visit(finder)
|
|
59
|
+
except QubitFinder.FoundFlag:
|
|
60
|
+
return True
|
|
61
|
+
else:
|
|
62
|
+
return False
|
guppylang_internals/tys/subst.py
CHANGED
|
@@ -4,7 +4,7 @@ from typing import Any
|
|
|
4
4
|
|
|
5
5
|
from guppylang_internals.error import InternalGuppyError
|
|
6
6
|
from guppylang_internals.tys.arg import Argument, ConstArg, TypeArg
|
|
7
|
-
from guppylang_internals.tys.common import Transformer
|
|
7
|
+
from guppylang_internals.tys.common import Transformer
|
|
8
8
|
from guppylang_internals.tys.const import (
|
|
9
9
|
BoundConstVar,
|
|
10
10
|
Const,
|
|
@@ -18,7 +18,7 @@ from guppylang_internals.tys.ty import (
|
|
|
18
18
|
Type,
|
|
19
19
|
TypeBase,
|
|
20
20
|
)
|
|
21
|
-
from guppylang_internals.tys.var import
|
|
21
|
+
from guppylang_internals.tys.var import ExistentialVar
|
|
22
22
|
|
|
23
23
|
Subst = dict[ExistentialVar, Type | Const]
|
|
24
24
|
Inst = Sequence[Argument]
|
|
@@ -51,7 +51,8 @@ class Substituter(Transformer):
|
|
|
51
51
|
class Instantiator(Transformer):
|
|
52
52
|
"""Type transformer that instantiates bound variables."""
|
|
53
53
|
|
|
54
|
-
def __init__(self, inst:
|
|
54
|
+
def __init__(self, inst: PartialInst, allow_partial: bool = False) -> None:
|
|
55
|
+
self.allow_partial = allow_partial
|
|
55
56
|
self.inst = inst
|
|
56
57
|
|
|
57
58
|
@functools.singledispatchmethod
|
|
@@ -63,6 +64,8 @@ class Instantiator(Transformer):
|
|
|
63
64
|
# Instantiate if type for the index is available
|
|
64
65
|
if ty.idx < len(self.inst):
|
|
65
66
|
arg = self.inst[ty.idx]
|
|
67
|
+
if arg is None and self.allow_partial:
|
|
68
|
+
return None
|
|
66
69
|
assert isinstance(arg, TypeArg)
|
|
67
70
|
return arg.ty
|
|
68
71
|
|
|
@@ -76,6 +79,8 @@ class Instantiator(Transformer):
|
|
|
76
79
|
# Instantiate if const value for the index is available
|
|
77
80
|
if c.idx < len(self.inst):
|
|
78
81
|
arg = self.inst[c.idx]
|
|
82
|
+
if arg is None and self.allow_partial:
|
|
83
|
+
return None
|
|
79
84
|
assert isinstance(arg, ConstArg)
|
|
80
85
|
return arg.const
|
|
81
86
|
|
|
@@ -87,26 +92,3 @@ class Instantiator(Transformer):
|
|
|
87
92
|
if ty.parametrized:
|
|
88
93
|
raise InternalGuppyError("Tried to instantiate under binder")
|
|
89
94
|
return None
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
class BoundVarFinder(Visitor):
|
|
93
|
-
"""Type visitor that looks for occurrences of bound variables."""
|
|
94
|
-
|
|
95
|
-
bound_vars: set[BoundVar]
|
|
96
|
-
|
|
97
|
-
def __init__(self) -> None:
|
|
98
|
-
self.bound_vars = set()
|
|
99
|
-
|
|
100
|
-
@functools.singledispatchmethod
|
|
101
|
-
def visit(self, ty: Any) -> bool: # type: ignore[override]
|
|
102
|
-
return False
|
|
103
|
-
|
|
104
|
-
@visit.register
|
|
105
|
-
def _transform_BoundTypeVar(self, ty: BoundTypeVar) -> bool:
|
|
106
|
-
self.bound_vars.add(ty)
|
|
107
|
-
return False
|
|
108
|
-
|
|
109
|
-
@visit.register
|
|
110
|
-
def _transform_BoundConstVar(self, c: BoundConstVar) -> bool:
|
|
111
|
-
self.bound_vars.add(c)
|
|
112
|
-
return False
|
guppylang_internals/tys/ty.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
2
|
from collections.abc import Sequence
|
|
3
|
-
from dataclasses import dataclass, field
|
|
3
|
+
from dataclasses import dataclass, field, replace
|
|
4
4
|
from enum import Enum, Flag, auto
|
|
5
5
|
from functools import cached_property, total_ordering
|
|
6
6
|
from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast
|
|
@@ -57,14 +57,11 @@ class TypeBase(ToHugr[ht.Type], Transformable["Type"], ABC):
|
|
|
57
57
|
return not self.copyable and self.droppable
|
|
58
58
|
|
|
59
59
|
@cached_property
|
|
60
|
-
@abstractmethod
|
|
61
60
|
def hugr_bound(self) -> ht.TypeBound:
|
|
62
|
-
"""The Hugr bound of this type, i.e. `Any
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
bound exactly right during serialisation, the Hugr validator will complain.
|
|
67
|
-
"""
|
|
61
|
+
"""The Hugr bound of this type, i.e. `Any` or `Copyable`."""
|
|
62
|
+
if self.linear or self.affine:
|
|
63
|
+
return ht.TypeBound.Linear
|
|
64
|
+
return ht.TypeBound.Copyable
|
|
68
65
|
|
|
69
66
|
@abstractmethod
|
|
70
67
|
def cast(self) -> "Type":
|
|
@@ -79,6 +76,11 @@ class TypeBase(ToHugr[ht.Type], Transformable["Type"], ABC):
|
|
|
79
76
|
"""The existential type variables contained in this type."""
|
|
80
77
|
return set()
|
|
81
78
|
|
|
79
|
+
@cached_property
|
|
80
|
+
def bound_vars(self) -> set[BoundVar]:
|
|
81
|
+
"""The bound type variables contained in this type."""
|
|
82
|
+
return set()
|
|
83
|
+
|
|
82
84
|
def substitute(self, subst: "Subst") -> "Type":
|
|
83
85
|
"""Substitutes existential variables in this type."""
|
|
84
86
|
from guppylang_internals.tys.subst import Substituter
|
|
@@ -158,13 +160,17 @@ class ParametrizedTypeBase(TypeBase, ABC):
|
|
|
158
160
|
"""The existential type variables contained in this type."""
|
|
159
161
|
return set().union(*(arg.unsolved_vars for arg in self.args))
|
|
160
162
|
|
|
163
|
+
@cached_property
|
|
164
|
+
def bound_vars(self) -> set[BoundVar]:
|
|
165
|
+
"""The bound type variables contained in this type."""
|
|
166
|
+
return set().union(*(arg.bound_vars for arg in self.args))
|
|
167
|
+
|
|
161
168
|
@cached_property
|
|
162
169
|
def hugr_bound(self) -> ht.TypeBound:
|
|
163
|
-
"""The Hugr bound of this type, i.e. `Any
|
|
164
|
-
if self.linear:
|
|
165
|
-
return ht.TypeBound.Linear
|
|
170
|
+
"""The Hugr bound of this type, i.e. `Any` or `Copyable`."""
|
|
166
171
|
return ht.TypeBound.join(
|
|
167
|
-
|
|
172
|
+
super().hugr_bound,
|
|
173
|
+
*(arg.ty.hugr_bound for arg in self.args if isinstance(arg, TypeArg)),
|
|
168
174
|
)
|
|
169
175
|
|
|
170
176
|
def visit(self, visitor: Visitor) -> None:
|
|
@@ -187,14 +193,10 @@ class BoundTypeVar(TypeBase, BoundVar):
|
|
|
187
193
|
copyable: bool
|
|
188
194
|
droppable: bool
|
|
189
195
|
|
|
190
|
-
@
|
|
191
|
-
def
|
|
192
|
-
"""The
|
|
193
|
-
|
|
194
|
-
return ht.TypeBound.Linear
|
|
195
|
-
# We're conservative and don't require equatability for non-linear variables.
|
|
196
|
-
# This is fine since Guppy doesn't use the equatable feature anyways.
|
|
197
|
-
return ht.TypeBound.Copyable
|
|
196
|
+
@property
|
|
197
|
+
def bound_vars(self) -> set[BoundVar]:
|
|
198
|
+
"""The bound type variables contained in this type."""
|
|
199
|
+
return {self}
|
|
198
200
|
|
|
199
201
|
def cast(self) -> "Type":
|
|
200
202
|
"""Casts an implementor of `TypeBase` into a `Type`."""
|
|
@@ -367,6 +369,21 @@ class InputFlags(Flag):
|
|
|
367
369
|
Comptime = auto()
|
|
368
370
|
|
|
369
371
|
|
|
372
|
+
class UnitaryFlags(Flag):
|
|
373
|
+
"""Flags that can be set on functions to indicate their unitary properties.
|
|
374
|
+
|
|
375
|
+
The flags indicate under which conditions a function can be used
|
|
376
|
+
in a unitary context.
|
|
377
|
+
"""
|
|
378
|
+
|
|
379
|
+
NoFlags = 0
|
|
380
|
+
Control = auto()
|
|
381
|
+
Dagger = auto()
|
|
382
|
+
Power = auto()
|
|
383
|
+
|
|
384
|
+
Unitary = Control | Dagger | Power
|
|
385
|
+
|
|
386
|
+
|
|
370
387
|
@dataclass(frozen=True)
|
|
371
388
|
class FuncInput:
|
|
372
389
|
"""A single input of a function type."""
|
|
@@ -374,6 +391,10 @@ class FuncInput:
|
|
|
374
391
|
ty: "Type"
|
|
375
392
|
flags: InputFlags
|
|
376
393
|
|
|
394
|
+
#: Name of this input, or `None` if it is an unnamed argument (e.g. inside a
|
|
395
|
+
#: `Callable`). We use `compare=False` because names are not visible to the caller.
|
|
396
|
+
name: str | None = field(default=None, compare=False)
|
|
397
|
+
|
|
377
398
|
|
|
378
399
|
@dataclass(frozen=True, init=False)
|
|
379
400
|
class FunctionType(ParametrizedTypeBase):
|
|
@@ -382,7 +403,6 @@ class FunctionType(ParametrizedTypeBase):
|
|
|
382
403
|
inputs: Sequence[FuncInput]
|
|
383
404
|
output: "Type"
|
|
384
405
|
params: Sequence[Parameter]
|
|
385
|
-
input_names: Sequence[str] | None
|
|
386
406
|
comptime_args: Sequence[ConstArg]
|
|
387
407
|
|
|
388
408
|
args: Sequence[Argument] = field(init=False)
|
|
@@ -392,13 +412,15 @@ class FunctionType(ParametrizedTypeBase):
|
|
|
392
412
|
intrinsically_droppable: bool = field(default=True, init=True)
|
|
393
413
|
hugr_bound: ht.TypeBound = field(default=ht.TypeBound.Copyable, init=False)
|
|
394
414
|
|
|
415
|
+
unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, init=True)
|
|
416
|
+
|
|
395
417
|
def __init__(
|
|
396
418
|
self,
|
|
397
419
|
inputs: Sequence[FuncInput],
|
|
398
420
|
output: "Type",
|
|
399
|
-
input_names: Sequence[str] | None = None,
|
|
400
421
|
params: Sequence[Parameter] | None = None,
|
|
401
422
|
comptime_args: Sequence[ConstArg] | None = None,
|
|
423
|
+
unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
|
|
402
424
|
) -> None:
|
|
403
425
|
# We need a custom __init__ to set the args
|
|
404
426
|
args: list[Argument] = [TypeArg(inp.ty) for inp in inputs]
|
|
@@ -414,18 +436,43 @@ class FunctionType(ParametrizedTypeBase):
|
|
|
414
436
|
]
|
|
415
437
|
args += comptime_args
|
|
416
438
|
|
|
439
|
+
# Either all inputs must have unique names, or none of them have names
|
|
440
|
+
names = {inp.name for inp in inputs if inp.name is not None}
|
|
441
|
+
if len(names) not in (0, len(inputs)):
|
|
442
|
+
raise InternalGuppyError(
|
|
443
|
+
"Tried to construct FunctionType with invalid input names"
|
|
444
|
+
)
|
|
445
|
+
|
|
417
446
|
object.__setattr__(self, "args", args)
|
|
418
447
|
object.__setattr__(self, "comptime_args", comptime_args)
|
|
419
448
|
object.__setattr__(self, "inputs", inputs)
|
|
420
449
|
object.__setattr__(self, "output", output)
|
|
421
|
-
object.__setattr__(self, "input_names", input_names or [])
|
|
422
450
|
object.__setattr__(self, "params", params)
|
|
451
|
+
object.__setattr__(self, "unitary_flags", unitary_flags)
|
|
423
452
|
|
|
424
453
|
@property
|
|
425
454
|
def parametrized(self) -> bool:
|
|
426
455
|
"""Whether the function is parametrized."""
|
|
427
456
|
return len(self.params) > 0
|
|
428
457
|
|
|
458
|
+
@cached_property
|
|
459
|
+
def bound_vars(self) -> set[BoundVar]:
|
|
460
|
+
"""The bound type variables contained in this type."""
|
|
461
|
+
if self.parametrized:
|
|
462
|
+
# Ensures that we don't look inside quantifiers
|
|
463
|
+
return set()
|
|
464
|
+
return super().bound_vars
|
|
465
|
+
|
|
466
|
+
@cached_property
|
|
467
|
+
def input_names(self) -> Sequence[str] | None:
|
|
468
|
+
"""Names of all inputs or `None` if there are unnamed inputs."""
|
|
469
|
+
names: list[str] = []
|
|
470
|
+
for inp in self.inputs:
|
|
471
|
+
if inp.name is None:
|
|
472
|
+
return None
|
|
473
|
+
names.append(inp.name)
|
|
474
|
+
return names
|
|
475
|
+
|
|
429
476
|
def cast(self) -> "Type":
|
|
430
477
|
"""Casts an implementor of `TypeBase` into a `Type`."""
|
|
431
478
|
return self
|
|
@@ -484,12 +531,8 @@ class FunctionType(ParametrizedTypeBase):
|
|
|
484
531
|
def transform(self, transformer: Transformer) -> "Type":
|
|
485
532
|
"""Accepts a transformer on this type."""
|
|
486
533
|
return transformer.transform(self) or FunctionType(
|
|
487
|
-
[
|
|
488
|
-
FuncInput(inp.ty.transform(transformer), inp.flags)
|
|
489
|
-
for inp in self.inputs
|
|
490
|
-
],
|
|
534
|
+
[replace(inp, ty=inp.ty.transform(transformer)) for inp in self.inputs],
|
|
491
535
|
self.output.transform(transformer),
|
|
492
|
-
self.input_names,
|
|
493
536
|
self.params,
|
|
494
537
|
)
|
|
495
538
|
|
|
@@ -506,7 +549,7 @@ class FunctionType(ParametrizedTypeBase):
|
|
|
506
549
|
# However, we have to down-shift the de Bruijn index.
|
|
507
550
|
if arg is None:
|
|
508
551
|
param = param.with_idx(len(remaining_params))
|
|
509
|
-
remaining_params.append(param)
|
|
552
|
+
remaining_params.append(param.instantiate_bounds(full_inst))
|
|
510
553
|
arg = param.to_bound()
|
|
511
554
|
|
|
512
555
|
# Set the `preserve` flag for instantiated tuples and None
|
|
@@ -519,9 +562,8 @@ class FunctionType(ParametrizedTypeBase):
|
|
|
519
562
|
|
|
520
563
|
inst = Instantiator(full_inst)
|
|
521
564
|
return FunctionType(
|
|
522
|
-
[
|
|
565
|
+
[replace(inp, ty=inp.ty.transform(inst)) for inp in self.inputs],
|
|
523
566
|
self.output.transform(inst),
|
|
524
|
-
self.input_names,
|
|
525
567
|
remaining_params,
|
|
526
568
|
# Comptime type arguments also need to be instantiated
|
|
527
569
|
comptime_args=[
|
|
@@ -538,6 +580,18 @@ class FunctionType(ParametrizedTypeBase):
|
|
|
538
580
|
exs = [param.to_existential() for param in self.params]
|
|
539
581
|
return self.instantiate([arg for arg, _ in exs]), [var for _, var in exs]
|
|
540
582
|
|
|
583
|
+
def with_unitary_flags(self, flags: UnitaryFlags) -> "FunctionType":
|
|
584
|
+
"""Returns a copy of this function type with the specified unitary flags."""
|
|
585
|
+
# N.B. we can't use `dataclasses.replace` here since `FunctionType` has a custom
|
|
586
|
+
# constructor
|
|
587
|
+
return FunctionType(
|
|
588
|
+
self.inputs,
|
|
589
|
+
self.output,
|
|
590
|
+
self.params,
|
|
591
|
+
self.comptime_args,
|
|
592
|
+
flags,
|
|
593
|
+
)
|
|
594
|
+
|
|
541
595
|
|
|
542
596
|
@dataclass(frozen=True, init=False)
|
|
543
597
|
class TupleType(ParametrizedTypeBase):
|
|
@@ -582,53 +636,6 @@ class TupleType(ParametrizedTypeBase):
|
|
|
582
636
|
)
|
|
583
637
|
|
|
584
638
|
|
|
585
|
-
@dataclass(frozen=True, init=False)
|
|
586
|
-
class SumType(ParametrizedTypeBase):
|
|
587
|
-
"""Type of sums.
|
|
588
|
-
|
|
589
|
-
Note that this type is only used internally when constructing the Hugr. Users cannot
|
|
590
|
-
write down this type.
|
|
591
|
-
"""
|
|
592
|
-
|
|
593
|
-
element_types: Sequence["Type"]
|
|
594
|
-
|
|
595
|
-
def __init__(self, element_types: Sequence["Type"]) -> None:
|
|
596
|
-
# We need a custom __init__ to set the args
|
|
597
|
-
args = [TypeArg(ty) for ty in element_types]
|
|
598
|
-
object.__setattr__(self, "args", args)
|
|
599
|
-
object.__setattr__(self, "element_types", element_types)
|
|
600
|
-
|
|
601
|
-
@property
|
|
602
|
-
def intrinsically_copyable(self) -> bool:
|
|
603
|
-
"""Whether objects of this type can be implicitly copied."""
|
|
604
|
-
return True
|
|
605
|
-
|
|
606
|
-
@property
|
|
607
|
-
def intrinsically_droppable(self) -> bool:
|
|
608
|
-
"""Whether objects of this type can be dropped."""
|
|
609
|
-
return True
|
|
610
|
-
|
|
611
|
-
def cast(self) -> "Type":
|
|
612
|
-
"""Casts an implementor of `TypeBase` into a `Type`."""
|
|
613
|
-
return self
|
|
614
|
-
|
|
615
|
-
def to_hugr(self, ctx: ToHugrContext) -> ht.Sum:
|
|
616
|
-
"""Computes the Hugr representation of the type."""
|
|
617
|
-
rows = [type_to_row(ty) for ty in self.element_types]
|
|
618
|
-
if all(len(row) == 0 for row in rows):
|
|
619
|
-
return ht.UnitSum(size=len(rows))
|
|
620
|
-
elif len(rows) == 1:
|
|
621
|
-
return ht.Tuple(*row_to_hugr(rows[0], ctx))
|
|
622
|
-
else:
|
|
623
|
-
return ht.Sum(variant_rows=rows_to_hugr(rows, ctx))
|
|
624
|
-
|
|
625
|
-
def transform(self, transformer: Transformer) -> "Type":
|
|
626
|
-
"""Accepts a transformer on this type."""
|
|
627
|
-
return transformer.transform(self) or SumType(
|
|
628
|
-
[ty.transform(transformer) for ty in self.element_types]
|
|
629
|
-
)
|
|
630
|
-
|
|
631
|
-
|
|
632
639
|
@dataclass(frozen=True)
|
|
633
640
|
class OpaqueType(ParametrizedTypeBase):
|
|
634
641
|
"""Type that is directly backed by a Hugr opaque type.
|
|
@@ -651,7 +658,7 @@ class OpaqueType(ParametrizedTypeBase):
|
|
|
651
658
|
|
|
652
659
|
@property
|
|
653
660
|
def hugr_bound(self) -> ht.TypeBound:
|
|
654
|
-
"""The Hugr bound of this type, i.e. `Any
|
|
661
|
+
"""The Hugr bound of this type, i.e. `Any` or `Copyable`."""
|
|
655
662
|
if self.defn.bound is not None:
|
|
656
663
|
return self.defn.bound
|
|
657
664
|
return super().hugr_bound
|
|
@@ -717,9 +724,8 @@ class StructType(ParametrizedTypeBase):
|
|
|
717
724
|
|
|
718
725
|
|
|
719
726
|
#: The type of parametrized Guppy types.
|
|
720
|
-
ParametrizedType: TypeAlias =
|
|
721
|
-
|
|
722
|
-
)
|
|
727
|
+
ParametrizedType: TypeAlias = FunctionType | TupleType | OpaqueType | StructType
|
|
728
|
+
|
|
723
729
|
|
|
724
730
|
#: The type of Guppy types.
|
|
725
731
|
#:
|
|
@@ -801,8 +807,6 @@ def unify(s: Type | Const, t: Type | Const, subst: "Subst | None") -> "Subst | N
|
|
|
801
807
|
return _unify_args(s, t, subst)
|
|
802
808
|
case TupleType() as s, TupleType() as t:
|
|
803
809
|
return _unify_args(s, t, subst)
|
|
804
|
-
case SumType() as s, SumType() as t:
|
|
805
|
-
return _unify_args(s, t, subst)
|
|
806
810
|
case OpaqueType() as s, OpaqueType() as t if s.defn == t.defn:
|
|
807
811
|
return _unify_args(s, t, subst)
|
|
808
812
|
case StructType() as s, StructType() as t if s.defn == t.defn:
|
|
@@ -871,6 +875,8 @@ def function_tensor_signature(tys: list[FunctionType]) -> FunctionType:
|
|
|
871
875
|
outputs: list[Type] = []
|
|
872
876
|
for fun_ty in tys:
|
|
873
877
|
assert not fun_ty.parametrized
|
|
874
|
-
|
|
878
|
+
# Forget the function input names since they might be non-unique across the
|
|
879
|
+
# tensored functions
|
|
880
|
+
inputs.extend([replace(inp, name=None) for inp in fun_ty.inputs])
|
|
875
881
|
outputs.extend(type_to_row(fun_ty.output))
|
|
876
882
|
return FunctionType(inputs, row_to_type(outputs))
|