guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,876 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections.abc import Sequence
3
+ from dataclasses import dataclass, field
4
+ from enum import Enum, Flag, auto
5
+ from functools import cached_property, total_ordering
6
+ from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast
7
+
8
+ import hugr.std.float
9
+ import hugr.std.int
10
+ from hugr import tys as ht
11
+
12
+ from guppylang_internals.error import InternalGuppyError
13
+ from guppylang_internals.tys.arg import Argument, ConstArg, TypeArg
14
+ from guppylang_internals.tys.common import (
15
+ QuantifiedToHugrContext,
16
+ ToHugr,
17
+ ToHugrContext,
18
+ Transformable,
19
+ Transformer,
20
+ Visitor,
21
+ )
22
+ from guppylang_internals.tys.const import Const, ConstValue, ExistentialConstVar
23
+ from guppylang_internals.tys.param import ConstParam, Parameter
24
+ from guppylang_internals.tys.var import BoundVar, ExistentialVar
25
+
26
+ if TYPE_CHECKING:
27
+ from guppylang_internals.definition.struct import CheckedStructDef, StructField
28
+ from guppylang_internals.definition.ty import OpaqueTypeDef
29
+ from guppylang_internals.tys.subst import Inst, PartialInst, Subst
30
+
31
+
@dataclass(frozen=True)
class TypeBase(ToHugr[ht.Type], Transformable["Type"], ABC):
    """Abstract base class for all Guppy types.

    Concrete types provide their copy/drop capabilities, their Hugr bound and a
    `cast` into the `Type` union. Note that all subclasses are expected to be
    immutable.
    """

    @cached_property
    @abstractmethod
    def copyable(self) -> bool:
        """Whether objects of this type can be implicitly copied."""

    @cached_property
    @abstractmethod
    def droppable(self) -> bool:
        """Whether objects of this type can be dropped."""

    @property
    def linear(self) -> bool:
        """Whether this type should be treated linearly.

        A linear type can neither be copied nor dropped.
        """
        return not self.copyable and not self.droppable

    @property
    def affine(self) -> bool:
        """Whether this type should be treated in an affine way.

        An affine type can be dropped but not copied.
        """
        return not self.copyable and self.droppable

    @cached_property
    @abstractmethod
    def hugr_bound(self) -> ht.TypeBound:
        """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`.

        This needs to be specified explicitly, since opaque nonlinear types in a Hugr
        extension could be either declared as copyable or equatable. If we don't get the
        bound exactly right during serialisation, the Hugr validator will complain.
        """

    @abstractmethod
    def cast(self) -> "Type":
        """Casts an implementor of `TypeBase` into a `Type`.

        This enforces that all implementors of `TypeBase` can be embedded into the
        `Type` union type.
        """

    @cached_property
    def unsolved_vars(self) -> set[ExistentialVar]:
        """The existential type variables contained in this type.

        The default is the empty set; subclasses that can contain variables override
        this.
        """
        return set()

    def substitute(self, subst: "Subst") -> "Type":
        """Substitutes existential variables in this type according to `subst`."""
        # Imported locally; `Subst` is only imported under TYPE_CHECKING at module
        # level, presumably to avoid an import cycle with `tys.subst`.
        from guppylang_internals.tys.subst import Substituter

        return self.transform(Substituter(subst))

    def to_arg(self) -> TypeArg:
        """Wraps this type into a type argument."""
        return TypeArg(self.cast())

    def __str__(self) -> str:
        """Returns a human-readable representation of the type."""
        from guppylang_internals.tys.printing import TypePrinter

        # We use a custom printer that takes care of inserting parentheses and choosing
        # unique names
        return TypePrinter().visit(cast(Type, self))
99
+
100
+
101
+ @dataclass(frozen=True)
102
+ class ParametrizedTypeBase(TypeBase, ABC):
103
+ """Abstract base class for types that depend on parameters.
104
+
105
+ For example, `list`, `tuple`, etc. require arguments in order to be turned into a
106
+ proper type.
107
+
108
+ Note that all subclasses are expected to be immutable.
109
+ """
110
+
111
+ args: Sequence[Argument]
112
+
113
+ def __post_init__(self) -> None:
114
+ # Make sure that we don't have nested generic functions
115
+ for arg in self.args:
116
+ match arg:
117
+ case TypeArg(ty=FunctionType(parametrized=True)):
118
+ raise InternalGuppyError(
119
+ "Tried to construct a higher-rank polymorphic type!"
120
+ )
121
+
122
+ @property
123
+ @abstractmethod
124
+ def intrinsically_copyable(self) -> bool:
125
+ """Whether this type is copyable, independent of the arguments.
126
+
127
+ For example, a parametrized struct containing a qubit is never copyable, even if
128
+ all its arguments are.
129
+ """
130
+
131
+ @cached_property
132
+ def copyable(self) -> bool:
133
+ """Whether this type should be treated as copyable."""
134
+ # Either an argument isn't a type argument, or it must be copyable.
135
+ return self.intrinsically_copyable and all(
136
+ not isinstance(arg, TypeArg) or arg.ty.copyable for arg in self.args
137
+ )
138
+
139
+ @property
140
+ @abstractmethod
141
+ def intrinsically_droppable(self) -> bool:
142
+ """Whether this type is droppable, independent of the arguments.
143
+
144
+ For example, a parametrized struct containing a qubit is never droppable, even
145
+ if all its arguments are.
146
+ """
147
+
148
+ @cached_property
149
+ def droppable(self) -> bool:
150
+ """Whether this type should be treated as copyable."""
151
+ # Either an argument isn't a type argument, or it must be droppable.
152
+ return self.intrinsically_droppable and all(
153
+ not isinstance(arg, TypeArg) or arg.ty.droppable for arg in self.args
154
+ )
155
+
156
+ @cached_property
157
+ def unsolved_vars(self) -> set[ExistentialVar]:
158
+ """The existential type variables contained in this type."""
159
+ return set().union(*(arg.unsolved_vars for arg in self.args))
160
+
161
+ @cached_property
162
+ def hugr_bound(self) -> ht.TypeBound:
163
+ """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`."""
164
+ if self.linear:
165
+ return ht.TypeBound.Linear
166
+ return ht.TypeBound.join(
167
+ *(arg.ty.hugr_bound for arg in self.args if isinstance(arg, TypeArg))
168
+ )
169
+
170
+ def visit(self, visitor: Visitor) -> None:
171
+ """Accepts a visitor on this type."""
172
+ if not visitor.visit(self):
173
+ for arg in self.args:
174
+ visitor.visit(arg)
175
+
176
+
@dataclass(frozen=True)
class BoundTypeVar(TypeBase, BoundVar):
    """Type variable that refers to a `Type`-kinded parameter of an enclosing binder.

    Given the function type `forall T. list[T] -> T`, every occurrence of `T` is
    represented as `BoundTypeVar(idx=0)`. Such a variable is instantiated by a
    `TypeArg`.
    """

    # Capabilities of the variable are stored as plain fields rather than computed.
    copyable: bool
    droppable: bool

    def __str__(self) -> str:
        """Renders the variable by its display name."""
        return self.display_name

    def cast(self) -> "Type":
        """Returns `self`, which already belongs to the `Type` union."""
        return self

    @cached_property
    def hugr_bound(self) -> ht.TypeBound:
        """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`."""
        # Non-linear variables conservatively get `Copyable` instead of requiring
        # equatability; Guppy does not make use of the equatable bound.
        return ht.TypeBound.Linear if self.linear else ht.TypeBound.Copyable

    def to_hugr(self, ctx: ToHugrContext) -> ht.Type:
        """Delegates the Hugr translation of the variable to the context."""
        return ctx.type_var_to_hugr(self)

    def visit(self, visitor: Visitor) -> None:
        """Calls the visitor on this (leaf) type."""
        visitor.visit(self)

    def transform(self, transformer: Transformer) -> "Type":
        """Returns the transformer's replacement, or `self` if there is none."""
        replacement = transformer.transform(self)
        return replacement if replacement else self
218
+
219
+
@dataclass(frozen=True)
class ExistentialTypeVar(ExistentialVar, TypeBase):
    """Type variable standing for a not-yet-known type during type checking.

    The empty list literal `[]`, for instance, gets the type `list[?T]` with `?T` an
    existential type variable. The checker tries to solve every such variable and
    substitute it with a concrete type.
    """

    copyable: bool
    droppable: bool

    @classmethod
    def fresh(
        cls, display_name: str, copyable: bool, droppable: bool
    ) -> "ExistentialTypeVar":
        """Creates a variable whose id is drawn from the `cls._fresh_id` counter."""
        new_id = next(cls._fresh_id)
        return ExistentialTypeVar(display_name, new_id, copyable, droppable)

    def cast(self) -> "Type":
        """Returns `self`, which already belongs to the `Type` union."""
        return self

    @cached_property
    def unsolved_vars(self) -> set[ExistentialVar]:
        """A singleton set containing just this variable."""
        return {self}

    @cached_property
    def hugr_bound(self) -> ht.TypeBound:
        """Always raises: an unsolved variable has no Hugr bound."""
        raise InternalGuppyError(
            "Tried to compute bound of unsolved existential type variable"
        )

    def to_hugr(self, ctx: ToHugrContext) -> ht.Type:
        """Always raises: an unsolved variable cannot be lowered to Hugr."""
        raise InternalGuppyError(
            "Tried to convert unsolved existential type variable to Hugr"
        )

    def visit(self, visitor: Visitor) -> None:
        """Calls the visitor on this (leaf) type."""
        visitor.visit(self)

    def transform(self, transformer: Transformer) -> "Type":
        """Returns the transformer's replacement, or `self` if there is none."""
        replacement = transformer.transform(self)
        return replacement if replacement else self
271
+
272
+
@dataclass(frozen=True)
class NoneType(TypeBase):
    """The type of `None`, lowered to the empty Hugr tuple."""

    copyable: bool = True
    droppable: bool = True
    # Fixed bound; not accepted by the constructor.
    hugr_bound: ht.TypeBound = field(default=ht.TypeBound.Copyable, init=False)

    # When set, `type_to_row()` keeps this type as a single entry instead of
    # flattening it into an empty row. This keeps type variables that were
    # instantiated to `None` intact when generating a Hugr. Not part of equality.
    preserve: bool = field(default=False, compare=False)

    def cast(self) -> "Type":
        """Returns `self`, which already belongs to the `Type` union."""
        return self

    def to_hugr(self, ctx: ToHugrContext) -> ht.Tuple:
        """Lowers `None` to an empty Hugr tuple."""
        empty_tuple = ht.Tuple()
        return empty_tuple

    def visit(self, visitor: Visitor) -> None:
        """Calls the visitor on this (leaf) type."""
        visitor.visit(self)

    def transform(self, transformer: Transformer) -> "Type":
        """Returns the transformer's replacement, or `self` if there is none."""
        replacement = transformer.transform(self)
        return replacement if replacement else self
301
+
302
+
303
+ @dataclass(frozen=True)
304
+ class NumericType(TypeBase):
305
+ """Numeric types like `int` and `float`."""
306
+
307
+ kind: "Kind"
308
+
309
+ @total_ordering
310
+ class Kind(Enum):
311
+ """The different kinds of numeric types."""
312
+
313
+ Nat = auto()
314
+ Int = auto()
315
+ Float = auto()
316
+
317
+ def __lt__(self, other: "NumericType.Kind") -> bool:
318
+ return self.value < other.value
319
+
320
+ INT_WIDTH: ClassVar[int] = 6
321
+
322
+ @property
323
+ def copyable(self) -> bool:
324
+ """Whether objects of this type can be implicitly copied."""
325
+ return True
326
+
327
+ @property
328
+ def droppable(self) -> bool:
329
+ """Whether objects of this type can be dropped."""
330
+ return True
331
+
332
+ def cast(self) -> "Type":
333
+ """Casts an implementor of `TypeBase` into a `Type`."""
334
+ return self
335
+
336
+ def to_hugr(self, ctx: ToHugrContext) -> ht.ExtType:
337
+ """Computes the Hugr representation of the type."""
338
+ match self.kind:
339
+ case NumericType.Kind.Nat | NumericType.Kind.Int:
340
+ return hugr.std.int.int_t(NumericType.INT_WIDTH)
341
+ case NumericType.Kind.Float:
342
+ return hugr.std.float.FLOAT_T
343
+
344
+ @property
345
+ def hugr_bound(self) -> ht.TypeBound:
346
+ """The Hugr bound of this type, i.e. `Any` or `Copyable`"""
347
+ return ht.TypeBound.Copyable
348
+
349
+ def visit(self, visitor: Visitor) -> None:
350
+ """Accepts a visitor on this type."""
351
+ visitor.visit(self)
352
+
353
+ def transform(self, transformer: Transformer) -> "Type":
354
+ """Accepts a transformer on this type."""
355
+ return transformer.transform(self) or self
356
+
357
+
class InputFlags(Flag):
    """Bit flags attached to individual inputs of a function type.

    Further flags (for example `Frozen`) may be added in the future.
    """

    NoFlags = 0
    Inout = 1 << 0
    Owned = 1 << 1
    Comptime = 1 << 2
368
+
369
+
@dataclass(frozen=True)
class FuncInput:
    """One input slot of a function type: its type together with its flags."""

    # Type of the value passed for this input.
    ty: "Type"
    # Flags such as `Inout`, `Owned` or `Comptime` set on this input.
    flags: InputFlags
376
+
377
+
@dataclass(frozen=True, init=False)
class FunctionType(ParametrizedTypeBase):
    """Type of (potentially generic) functions.

    The `args` of a function type are, in order: a `TypeArg` per input, a `TypeArg`
    for the output, and finally the comptime arguments.
    """

    inputs: Sequence[FuncInput]
    output: "Type"
    params: Sequence[Parameter]
    input_names: Sequence[str] | None
    comptime_args: Sequence[ConstArg]

    args: Sequence[Argument] = field(init=False)
    # These default to `True` and are never set by the custom `__init__`, so function
    # types are always copyable and droppable regardless of their `args`.
    copyable: bool = field(default=True, init=True)
    droppable: bool = field(default=True, init=True)
    intrinsically_copyable: bool = field(default=True, init=True)
    intrinsically_droppable: bool = field(default=True, init=True)
    hugr_bound: ht.TypeBound = field(default=ht.TypeBound.Copyable, init=False)

    def __init__(
        self,
        inputs: Sequence[FuncInput],
        output: "Type",
        input_names: Sequence[str] | None = None,
        params: Sequence[Parameter] | None = None,
        comptime_args: Sequence[ConstArg] | None = None,
    ) -> None:
        """Creates a function type.

        Args:
            inputs: The function inputs together with their flags.
            output: The return type.
            input_names: Optional names of the inputs (stored as `[]` if omitted).
            params: Generic parameters of the function (stored as `[]` if omitted).
            comptime_args: Arguments for the comptime parameters. If omitted, each
                `ConstParam` in `params` with `from_comptime_arg` set contributes its
                bound variable.
        """
        # We need a custom __init__ to set the args
        args: list[Argument] = [TypeArg(inp.ty) for inp in inputs]
        args.append(TypeArg(output))

        # If no explicit comptime args are provided, assume that all of them are bound
        params = params or []
        if comptime_args is None:
            comptime_args = [
                param.to_bound()
                for param in params
                if isinstance(param, ConstParam) and param.from_comptime_arg
            ]
        args += comptime_args

        object.__setattr__(self, "args", args)
        object.__setattr__(self, "comptime_args", comptime_args)
        object.__setattr__(self, "inputs", inputs)
        object.__setattr__(self, "output", output)
        object.__setattr__(self, "input_names", input_names or [])
        object.__setattr__(self, "params", params)

    @property
    def parametrized(self) -> bool:
        """Whether the function is parametrized."""
        return len(self.params) > 0

    def cast(self) -> "Type":
        """Casts an implementor of `TypeBase` into a `Type`."""
        return self

    def to_hugr(self, ctx: ToHugrContext) -> ht.FunctionType:
        """Computes the Hugr representation of a non-generic function type.

        Raises:
            InternalGuppyError: If the function is parametrized.
        """
        if self.parametrized:
            raise InternalGuppyError(
                "Tried to convert parametrised function type to Hugr. Use "
                "`to_hugr_poly` instead"
            )
        return self._to_hugr_function_type(ctx)

    def to_hugr_poly(self, ctx: ToHugrContext) -> ht.PolyFuncType:
        """Computes the Hugr `PolyFuncType` representation of the type."""
        # Function body needs to be translated in a new context where the variables are
        # bound to the quantifier.
        inner_ctx = QuantifiedToHugrContext(self.params)
        func_ty = self._to_hugr_function_type(inner_ctx)
        return ht.PolyFuncType(
            params=[p.to_hugr(ctx) for p in self.params], body=func_ty
        )

    def _to_hugr_function_type(self, ctx: ToHugrContext) -> ht.FunctionType:
        """Helper method to compute the Hugr `FunctionType` representation of the type.

        The resulting `FunctionType` can then be embedded into a Hugr `Type` or a Hugr
        `PolyFuncType`.
        """
        ins = [
            inp.ty.to_hugr(ctx)
            for inp in self.inputs
            # Comptime inputs are turned into generic args, so are not included here
            if InputFlags.Comptime not in inp.flags
        ]
        outs = [
            *(t.to_hugr(ctx) for t in type_to_row(self.output)),
            # We might have additional borrowed args that will be also outputted
            *(
                inp.ty.to_hugr(ctx)
                for inp in self.inputs
                if InputFlags.Inout in inp.flags
            ),
        ]
        return ht.FunctionType(input=ins, output=outs)

    def visit(self, visitor: Visitor) -> None:
        """Accepts a visitor on this type."""
        if not visitor.visit(self):
            for inp in self.inputs:
                visitor.visit(inp)
            visitor.visit(self.output)
            for param in self.params:
                visitor.visit(param)

    def transform(self, transformer: Transformer) -> "Type":
        """Accepts a transformer on this type.

        Inputs, output and comptime arguments are transformed; parameters and input
        names are carried over unchanged.
        """
        return transformer.transform(self) or FunctionType(
            [
                FuncInput(inp.ty.transform(transformer), inp.flags)
                for inp in self.inputs
            ],
            self.output.transform(transformer),
            self.input_names,
            self.params,
            # Comptime arguments must be passed explicitly: if omitted, the
            # constructor re-derives them from `params`, which would drop arguments
            # that were already instantiated (e.g. existential vars introduced by
            # `unquantified`) and would skip transforming them.
            comptime_args=[
                cast(ConstArg, arg.transform(transformer))
                for arg in self.comptime_args
            ],
        )

    def instantiate_partial(self, args: "PartialInst") -> "FunctionType":
        """Instantiates a subset of the function parameters with concrete types.

        `args` must contain one entry per parameter; `None` entries leave the
        corresponding parameter generic.
        """
        from guppylang_internals.tys.subst import Instantiator

        assert len(args) == len(self.params)

        full_inst: list[Argument] = []
        remaining_params: list[Parameter] = []
        for param, arg in zip(self.params, args, strict=True):
            # If no instantiation for this param is provided, it should stay around.
            # However, we have to down-shift the de Bruijn index.
            if arg is None:
                param = param.with_idx(len(remaining_params))
                remaining_params.append(param)
                arg = param.to_bound()

            # Set the `preserve` flag for instantiated tuples and None
            if isinstance(arg, TypeArg):
                if isinstance(arg.ty, TupleType):
                    arg = TypeArg(TupleType(arg.ty.element_types, preserve=True))
                elif isinstance(arg.ty, NoneType):
                    arg = TypeArg(NoneType(preserve=True))
            full_inst.append(arg)

        inst = Instantiator(full_inst)
        return FunctionType(
            [FuncInput(inp.ty.transform(inst), inp.flags) for inp in self.inputs],
            self.output.transform(inst),
            self.input_names,
            remaining_params,
            # Comptime type arguments also need to be instantiated
            comptime_args=[
                cast(ConstArg, arg.transform(inst)) for arg in self.comptime_args
            ],
        )

    def instantiate(self, args: "Inst") -> "FunctionType":
        """Instantiates all function parameters with concrete types."""
        return self.instantiate_partial(args)

    def unquantified(self) -> tuple["FunctionType", Sequence[ExistentialVar]]:
        """Instantiates all parameters with existential variables.

        Returns the instantiated function type together with the fresh variables.
        """
        exs = [param.to_existential() for param in self.params]
        return self.instantiate([arg for arg, _ in exs]), [var for _, var in exs]
540
+
541
+
@dataclass(frozen=True, init=False)
class TupleType(ParametrizedTypeBase):
    """Type of tuples; each element type is also recorded as a `TypeArg`."""

    element_types: Sequence["Type"]

    # When set, `type_to_row()` keeps the tuple as one entry instead of splicing its
    # elements into the row. Type variables instantiated to tuples rely on this so
    # they are not broken apart when generating a Hugr. Not part of equality.
    preserve: bool = field(default=False, compare=False)

    def __init__(self, element_types: Sequence["Type"], preserve: bool = False) -> None:
        # A hand-written constructor is required so that `args` can be derived from
        # the element types on a frozen dataclass.
        derived_args = list(map(TypeArg, element_types))
        for attr, value in (
            ("args", derived_args),
            ("element_types", element_types),
            ("preserve", preserve),
        ):
            object.__setattr__(self, attr, value)

    @property
    def intrinsically_copyable(self) -> bool:
        """Tuples impose no copy restriction beyond that of their elements."""
        return True

    @property
    def intrinsically_droppable(self) -> bool:
        """Tuples impose no drop restriction beyond that of their elements."""
        return True

    def cast(self) -> "Type":
        """Returns `self`, which already belongs to the `Type` union."""
        return self

    def to_hugr(self, ctx: ToHugrContext) -> ht.Tuple:
        """Lowers the tuple to a Hugr tuple of the lowered element types."""
        lowered = [elem.to_hugr(ctx) for elem in self.element_types]
        return ht.Tuple(*lowered)

    def transform(self, transformer: Transformer) -> "Type":
        """Returns the transformer's replacement, or a tuple of transformed elements."""
        replacement = transformer.transform(self)
        if replacement:
            return replacement
        new_elements = [elem.transform(transformer) for elem in self.element_types]
        return TupleType(new_elements, preserve=self.preserve)
583
+
584
+
@dataclass(frozen=True, init=False)
class SumType(ParametrizedTypeBase):
    """Type of sums, with one variant per element type.

    This type only exists internally for building the Hugr; users cannot spell it.
    """

    element_types: Sequence["Type"]

    def __init__(self, element_types: Sequence["Type"]) -> None:
        # Hand-written so that `args` can be derived on a frozen dataclass.
        object.__setattr__(self, "element_types", element_types)
        object.__setattr__(self, "args", list(map(TypeArg, element_types)))

    @property
    def intrinsically_copyable(self) -> bool:
        """Sums impose no copy restriction beyond that of their variants."""
        return True

    @property
    def intrinsically_droppable(self) -> bool:
        """Sums impose no drop restriction beyond that of their variants."""
        return True

    def cast(self) -> "Type":
        """Returns `self`, which already belongs to the `Type` union."""
        return self

    def to_hugr(self, ctx: ToHugrContext) -> ht.Sum:
        """Lowers the sum, choosing the most specific Hugr sum representation."""
        variant_rows = [type_to_row(elem) for elem in self.element_types]
        # Every variant is empty: a plain unit sum with one tag per variant.
        if not any(len(row) for row in variant_rows):
            return ht.UnitSum(size=len(variant_rows))
        # A single variant is just a tuple of its row.
        if len(variant_rows) == 1:
            (only_row,) = variant_rows
            return ht.Tuple(*row_to_hugr(only_row, ctx))
        return ht.Sum(variant_rows=rows_to_hugr(variant_rows, ctx))

    def transform(self, transformer: Transformer) -> "Type":
        """Returns the transformer's replacement, or a sum of transformed variants."""
        replacement = transformer.transform(self)
        if replacement:
            return replacement
        return SumType([elem.transform(transformer) for elem in self.element_types])
630
+
631
+
@dataclass(frozen=True)
class OpaqueType(ParametrizedTypeBase):
    """Type backed directly by a Hugr opaque type described by `defn`.

    Many builtin types such as `int`, `float` or `list` are of this form, being
    provided by a Hugr extension.
    """

    defn: "OpaqueTypeDef"

    @property
    def intrinsically_copyable(self) -> bool:
        """Copyable unless the definition marks the type as never copyable."""
        never = self.defn.never_copyable
        return not never

    @property
    def intrinsically_droppable(self) -> bool:
        """Droppable unless the definition marks the type as never droppable."""
        never = self.defn.never_droppable
        return not never

    @property
    def hugr_bound(self) -> ht.TypeBound:
        """The definition's explicit bound if it has one, else the derived bound."""
        explicit_bound = self.defn.bound
        return super().hugr_bound if explicit_bound is None else explicit_bound

    def cast(self) -> "Type":
        """Returns `self`, which already belongs to the `Type` union."""
        return self

    def to_hugr(self, ctx: ToHugrContext) -> ht.Type:
        """Delegates lowering to the opaque definition, passing the arguments."""
        return self.defn.to_hugr(self.args, ctx)

    def transform(self, transformer: Transformer) -> "Type":
        """Returns the transformer's replacement, or a copy with transformed args."""
        replacement = transformer.transform(self)
        if replacement:
            return replacement
        new_args = [arg.transform(transformer) for arg in self.args]
        return OpaqueType(args=new_args, defn=self.defn)
672
+
673
+
@dataclass(frozen=True)
class StructType(ParametrizedTypeBase):
    """Type of a struct defined by `defn`, applied to the arguments in `args`."""

    defn: "CheckedStructDef"

    @cached_property
    def fields(self) -> list["StructField"]:
        """The struct's fields with the type arguments substituted into their types."""
        from guppylang_internals.definition.struct import StructField
        from guppylang_internals.tys.subst import Instantiator

        instantiator = Instantiator(self.args)
        return [
            StructField(member.name, member.ty.transform(instantiator))
            for member in self.defn.fields
        ]

    @cached_property
    def field_dict(self) -> "dict[str, StructField]":
        """Lookup table from field name to (instantiated) field."""
        return {member.name: member for member in self.fields}

    @cached_property
    def intrinsically_copyable(self) -> bool:
        """Whether every instantiated field type is copyable."""
        return all(member.ty.copyable for member in self.fields)

    @cached_property
    def intrinsically_droppable(self) -> bool:
        """Whether every instantiated field type is droppable."""
        return all(member.ty.droppable for member in self.fields)

    def cast(self) -> "Type":
        """Returns `self`, which already belongs to the `Type` union."""
        return self

    def to_hugr(self, ctx: ToHugrContext) -> ht.Tuple:
        """Lowers the struct to a Hugr tuple of its lowered field types."""
        lowered_fields = [member.ty.to_hugr(ctx) for member in self.fields]
        return ht.Tuple(*lowered_fields)

    def transform(self, transformer: Transformer) -> "Type":
        """Returns the transformer's replacement, or a copy with transformed args."""
        replacement = transformer.transform(self)
        if replacement:
            return replacement
        new_args = [arg.transform(transformer) for arg in self.args]
        return StructType(args=new_args, defn=self.defn)
717
+
718
+
#: The type of parametrized Guppy types, i.e. the concrete types in this module that
#: derive from `ParametrizedTypeBase` and carry a sequence of `args`.
ParametrizedType: TypeAlias = (
    FunctionType | TupleType | SumType | OpaqueType | StructType
)

#: The type of Guppy types.
#:
#: This is a type alias for a union of all Guppy types defined in this module. This
#: models an algebraic data type and enables exhaustiveness checking in pattern matches
#: etc.
#:
#: This might become obsolete in case the @sealed decorator is added:
#: * https://peps.python.org/pep-0622/#sealed-classes-as-algebraic-data-types
#: * https://github.com/johnthagen/sealed-typing-pep
Type: TypeAlias = (
    BoundTypeVar | ExistentialTypeVar | NumericType | NoneType | ParametrizedType
)

#: A row of Guppy types. Typed as a read-only `Sequence`; callers are expected not to
#: mutate it (note that plain lists are accepted too).
TypeRow: TypeAlias = Sequence[Type]
739
+
740
+
741
+ def row_to_type(row: TypeRow) -> Type:
742
+ """Turns a row of types into a single type by packing into a tuple."""
743
+ if len(row) == 0:
744
+ return NoneType()
745
+ elif len(row) == 1:
746
+ return row[0]
747
+ else:
748
+ return TupleType(row)
749
+
750
+
751
+ def type_to_row(ty: Type) -> TypeRow:
752
+ """Turns a type into a row of types by unpacking top-level tuples."""
753
+ if isinstance(ty, NoneType) and not ty.preserve:
754
+ return []
755
+ if isinstance(ty, TupleType) and not ty.preserve:
756
+ return ty.element_types
757
+ return [ty]
758
+
759
+
760
def row_to_hugr(row: TypeRow, ctx: ToHugrContext) -> ht.TypeRow:
    """Lowers each type of `row` to Hugr with the same context, preserving order."""
    return [elem_ty.to_hugr(ctx) for elem_ty in row]
763
+
764
+
765
def rows_to_hugr(rows: Sequence[TypeRow], ctx: ToHugrContext) -> list[ht.TypeRow]:
    """Lowers every row in `rows` via `row_to_hugr`, keeping the rows' order."""
    return [row_to_hugr(type_row, ctx) for type_row in rows]
768
+
769
+
770
def unify(s: Type | Const, t: Type | Const, subst: "Subst | None") -> "Subst | None":
    """Computes a most general unifier for two types or constants.

    Returns a substitution `subst` such that `s[subst] == t[subst]`, or `None` if
    this is not possible.

    Args:
        s: First type or constant. Must be of the same sort (type vs. constant) as `t`.
        t: Second type or constant.
        subst: Substitution accumulated so far. If it is `None` (an earlier step
            already failed), `None` is returned immediately.

    Returns:
        The extended substitution, or `None` if `s` and `t` cannot be unified.
    """
    # Make sure that s and t are either both constants or both types
    assert isinstance(s, TypeBase) == isinstance(t, TypeBase)
    if subst is None:
        return None
    # NOTE: case order matters. The identical-variable check must precede the
    # generic existential-variable cases below.
    match s, t:
        case ExistentialVar(id=s_id), ExistentialVar(id=t_id) if s_id == t_id:
            return subst
        # An existential variable on either side is solved via `_unify_var`.
        case ExistentialTypeVar() | ExistentialConstVar() as s_var, t:
            return _unify_var(s_var, t, subst)
        case s, ExistentialTypeVar() | ExistentialConstVar() as t_var:
            return _unify_var(t_var, s, subst)
        # Bound variables only unify with a bound variable of the same index.
        case BoundVar(idx=s_idx), BoundVar(idx=t_idx) if s_idx == t_idx:
            return subst
        case ConstValue(value=c_value), ConstValue(value=d_value) if c_value == d_value:
            return subst
        case NumericType(kind=s_kind), NumericType(kind=t_kind) if s_kind == t_kind:
            return subst
        # Any two `NoneType`s unify; no attributes are compared.
        case NoneType(), NoneType():
            return subst
        # Function types need identical `params` and the same number of inputs.
        # For inputs whose types are linear on both sides, the input flags must
        # also agree. The remaining structure is unified through the type args.
        case FunctionType() as s, FunctionType() as t if s.params == t.params:
            if len(s.inputs) != len(t.inputs):
                return None
            for a, b in zip(s.inputs, t.inputs, strict=True):
                if a.ty.linear and b.ty.linear and a.flags != b.flags:
                    return None
            return _unify_args(s, t, subst)
        # Tuples and sums unify argument-wise (arity mismatches are rejected by
        # `_unify_args`).
        case TupleType() as s, TupleType() as t:
            return _unify_args(s, t, subst)
        case SumType() as s, SumType() as t:
            return _unify_args(s, t, subst)
        # Opaque and struct types additionally require the same definition.
        case OpaqueType() as s, OpaqueType() as t if s.defn == t.defn:
            return _unify_args(s, t, subst)
        case StructType() as s, StructType() as t if s.defn == t.defn:
            return _unify_args(s, t, subst)
        case _:
            return None
812
+
813
+
814
def _unify_var(
    var: ExistentialTypeVar | ExistentialConstVar, t: Type | Const, subst: "Subst"
) -> "Subst | None":
    """Helper function for unification of type or const variables.

    Binds the existential variable `var` to `t`, first resolving either side if it
    is already bound in `subst`.

    Args:
        var: The existential type or const variable to solve.
        t: The type or constant that `var` should be unified with.
        subst: The substitution accumulated so far.

    Returns:
        The extended substitution, or `None` if unification fails. This includes
        the case where `var` occurs inside `t`.
    """
    if var in subst:
        return unify(subst[var], t, subst)
    # Dereference `t` if it is itself an already-solved existential variable. This
    # must cover const variables as well as type variables. Otherwise, given
    # `{b: a}`, unifying const vars `a` and `b` would produce the cyclic
    # substitution `{a: b, b: a}`.
    if isinstance(t, (ExistentialTypeVar, ExistentialConstVar)) and t in subst:
        return unify(var, subst[t], subst)
    # Occurs check: binding `var` to a term that mentions `var` would be circular.
    if var in t.unsolved_vars:
        return None
    return {var: t, **subst}
825
+
826
+
827
def _unify_args(
    s: ParametrizedType, t: ParametrizedType, subst: "Subst"
) -> "Subst | None":
    """Unifies the type arguments of two parametrised types pairwise.

    Arguments are unified left to right, threading the substitution through each
    step. Returns `None` if the argument counts differ, if a pair of corresponding
    arguments mixes a type argument with a const argument, or if any pair fails
    to unify.
    """
    if len(s.args) != len(t.args):
        return None
    for s_arg, t_arg in zip(s.args, t.args, strict=True):
        if isinstance(s_arg, TypeArg) and isinstance(t_arg, TypeArg):
            step = unify(s_arg.ty, t_arg.ty, subst)
        elif isinstance(s_arg, ConstArg) and isinstance(t_arg, ConstArg):
            step = unify(s_arg.const, t_arg.const, subst)
        else:
            return None
        if step is None:
            return None
        subst = step
    return subst
848
+
849
+
850
+ ### Helpers for working with tuples of functions
851
+
852
+
853
def parse_function_tensor(ty: TupleType) -> list[FunctionType] | None:
    """Parses a nested tuple of function types into a flat list of functions.

    Nested tuples are flattened depth-first, left to right.

    Args:
        ty: The tuple type to parse.

    Returns:
        The flattened list of function types. Returns `None` if any element is
        neither a function type nor a non-empty nested tuple of function types.
        An empty top-level tuple yields `[]`.
    """
    result: list[FunctionType] = []
    for el in ty.element_types:
        if isinstance(el, FunctionType):
            result.append(el)
        elif isinstance(el, TupleType):
            funcs = parse_function_tensor(el)
            # Nested tuples must themselves be non-empty function tensors.
            if not funcs:
                return None
            result.extend(funcs)
        else:
            # Any other element means this is not a tuple of functions. It must
            # not be silently dropped, or the parsed tensor would no longer match
            # the shape of the tuple.
            return None
    return result
866
+
867
+
868
def function_tensor_signature(tys: list[FunctionType]) -> FunctionType:
    """Computes the combined signature of a list of (non-generic) functions.

    The combined function takes the concatenation of all inputs and returns the
    concatenation of all outputs. Each output is first unpacked into a row via
    `type_to_row`, and the combined output row is repacked with `row_to_type`.
    """
    assert all(not fun_ty.parametrized for fun_ty in tys)
    all_inputs: list[FuncInput] = [inp for fun_ty in tys for inp in fun_ty.inputs]
    all_outputs: list[Type] = [
        out for fun_ty in tys for out in type_to_row(fun_ty.output)
    ]
    return FunctionType(all_inputs, row_to_type(all_outputs))