guppylang-internals 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +21 -0
  3. guppylang_internals/cfg/bb.py +20 -0
  4. guppylang_internals/cfg/builder.py +118 -5
  5. guppylang_internals/cfg/cfg.py +3 -0
  6. guppylang_internals/checker/cfg_checker.py +6 -0
  7. guppylang_internals/checker/core.py +5 -2
  8. guppylang_internals/checker/errors/generic.py +32 -1
  9. guppylang_internals/checker/errors/type_errors.py +14 -0
  10. guppylang_internals/checker/errors/wasm.py +7 -4
  11. guppylang_internals/checker/expr_checker.py +58 -17
  12. guppylang_internals/checker/func_checker.py +18 -14
  13. guppylang_internals/checker/linearity_checker.py +67 -10
  14. guppylang_internals/checker/modifier_checker.py +120 -0
  15. guppylang_internals/checker/stmt_checker.py +48 -1
  16. guppylang_internals/checker/unitary_checker.py +132 -0
  17. guppylang_internals/compiler/cfg_compiler.py +7 -6
  18. guppylang_internals/compiler/core.py +93 -56
  19. guppylang_internals/compiler/expr_compiler.py +72 -168
  20. guppylang_internals/compiler/modifier_compiler.py +176 -0
  21. guppylang_internals/compiler/stmt_compiler.py +15 -8
  22. guppylang_internals/decorator.py +86 -7
  23. guppylang_internals/definition/custom.py +39 -1
  24. guppylang_internals/definition/declaration.py +9 -6
  25. guppylang_internals/definition/function.py +12 -2
  26. guppylang_internals/definition/parameter.py +8 -3
  27. guppylang_internals/definition/pytket_circuits.py +14 -41
  28. guppylang_internals/definition/struct.py +13 -7
  29. guppylang_internals/definition/ty.py +3 -3
  30. guppylang_internals/definition/wasm.py +42 -10
  31. guppylang_internals/engine.py +9 -3
  32. guppylang_internals/experimental.py +5 -0
  33. guppylang_internals/nodes.py +147 -24
  34. guppylang_internals/std/_internal/checker.py +13 -108
  35. guppylang_internals/std/_internal/compiler/array.py +95 -283
  36. guppylang_internals/std/_internal/compiler/list.py +1 -1
  37. guppylang_internals/std/_internal/compiler/platform.py +153 -0
  38. guppylang_internals/std/_internal/compiler/prelude.py +12 -4
  39. guppylang_internals/std/_internal/compiler/tket_exts.py +8 -2
  40. guppylang_internals/std/_internal/debug.py +18 -9
  41. guppylang_internals/std/_internal/util.py +1 -1
  42. guppylang_internals/tracing/object.py +10 -0
  43. guppylang_internals/tracing/unpacking.py +19 -20
  44. guppylang_internals/tys/arg.py +18 -3
  45. guppylang_internals/tys/builtin.py +2 -5
  46. guppylang_internals/tys/const.py +33 -4
  47. guppylang_internals/tys/errors.py +23 -1
  48. guppylang_internals/tys/param.py +31 -16
  49. guppylang_internals/tys/parsing.py +11 -24
  50. guppylang_internals/tys/printing.py +2 -8
  51. guppylang_internals/tys/qubit.py +62 -0
  52. guppylang_internals/tys/subst.py +8 -26
  53. guppylang_internals/tys/ty.py +91 -85
  54. guppylang_internals/wasm_util.py +129 -0
  55. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +6 -5
  56. guppylang_internals-0.26.0.dist-info/RECORD +104 -0
  57. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
  58. guppylang_internals-0.24.0.dist-info/RECORD +0 -98
  59. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
@@ -1,6 +1,6 @@
1
1
  from abc import ABC, abstractmethod
2
2
  from collections.abc import Sequence
3
- from dataclasses import dataclass, field
3
+ from dataclasses import dataclass, field, replace
4
4
  from typing import TYPE_CHECKING, TypeAlias
5
5
 
6
6
  from hugr import tys as ht
@@ -17,9 +17,9 @@ from guppylang_internals.tys.errors import WrongNumberOfTypeArgsError
17
17
  from guppylang_internals.tys.var import ExistentialVar
18
18
 
19
19
  if TYPE_CHECKING:
20
+ from guppylang_internals.tys.subst import PartialInst
20
21
  from guppylang_internals.tys.ty import Type
21
22
 
22
-
23
23
  # We define the `Parameter` type as a union of all `ParameterBase` subclasses defined
24
24
  # below. This models an algebraic data type and enables exhaustiveness checking in
25
25
  # pattern matches etc.
@@ -74,6 +74,10 @@ class ParameterBase(ToHugr[ht.TypeParam], ABC):
74
74
  parameter.
75
75
  """
76
76
 
77
+ @abstractmethod
78
+ def instantiate_bounds(self, inst: "PartialInst") -> Self:
79
+ """Instantiates bound variables mentioned in parameter bounds"""
80
+
77
81
 
78
82
  @dataclass(frozen=True)
79
83
  class TypeParam(ParameterBase):
@@ -142,10 +146,17 @@ class TypeParam(ParameterBase):
142
146
  BoundTypeVar(self.name, idx, self.must_be_copyable, self.must_be_droppable)
143
147
  )
144
148
 
149
+ def instantiate_bounds(self, inst: "PartialInst") -> "TypeParam":
150
+ """Instantiates bound variables mentioned in parameter bounds"""
151
+ # For now, type parameters don't have any bounds that could be instantiated
152
+ return self
153
+
145
154
  def to_hugr(self, ctx: ToHugrContext) -> ht.TypeParam:
146
155
  """Computes the Hugr representation of the parameter."""
147
156
  return ht.TypeTypeParam(
148
- bound=ht.TypeBound.Linear if self.can_be_linear else ht.TypeBound.Copyable
157
+ bound=ht.TypeBound.Copyable
158
+ if self.must_be_copyable
159
+ else ht.TypeBound.Linear
149
160
  )
150
161
 
151
162
  def __str__(self) -> str:
@@ -157,18 +168,12 @@ class TypeParam(ParameterBase):
157
168
  class ConstParam(ParameterBase):
158
169
  """A parameter of kind constant. Used to define fixed-size arrays etc."""
159
170
 
160
- ty: "Type"
171
+ ty: "Type" = field(hash=False)
161
172
 
162
173
  #: Marker to annotate if this parameter was implicitly generated by a `@comptime`
163
174
  #: annotated argument in a function signature.
164
175
  from_comptime_arg: bool = field(default=False, kw_only=True)
165
176
 
166
- def __post_init__(self) -> None:
167
- if self.ty.unsolved_vars:
168
- raise InternalGuppyError(
169
- "Attempted to create constant param with unsolved type"
170
- )
171
-
172
177
  def with_idx(self, idx: int) -> "ConstParam":
173
178
  """Returns a copy of the parameter with a new index."""
174
179
  return ConstParam(idx, self.name, self.ty)
@@ -178,13 +183,16 @@ class ConstParam(ParameterBase):
178
183
 
179
184
  Raises a user error if the argument is not valid.
180
185
  """
186
+ from guppylang_internals.tys.ty import unify
187
+
181
188
  match arg:
182
189
  case ConstArg(const):
183
- if const.ty != self.ty:
190
+ subst = unify(const.ty, self.ty, {})
191
+ if subst is None:
184
192
  raise GuppyTypeError(
185
193
  TypeMismatchError(loc, self.ty, const.ty, kind="argument")
186
194
  )
187
- return arg
195
+ return ConstArg(replace(const, ty=const.ty.substitute(subst)))
188
196
  case TypeArg(ty=ty):
189
197
  err = ExpectedError(
190
198
  loc, f"expression of type `{self.ty}`", got=f"type `{ty}`"
@@ -208,6 +216,13 @@ class ConstParam(ParameterBase):
208
216
  idx = self.idx
209
217
  return ConstArg(BoundConstVar(self.ty, self.name, idx))
210
218
 
219
+ def instantiate_bounds(self, inst: "PartialInst") -> "ConstParam":
220
+ """Instantiates bound variables mentioned in parameter bounds"""
221
+ from guppylang_internals.tys.subst import Instantiator
222
+
223
+ instantiator = Instantiator(inst)
224
+ return replace(self, ty=self.ty.transform(instantiator))
225
+
211
226
  def to_hugr(self, ctx: ToHugrContext) -> ht.TypeParam:
212
227
  """Computes the Hugr representation of the parameter."""
213
228
  from guppylang_internals.tys.ty import NumericType
@@ -231,6 +246,7 @@ def check_all_args(
231
246
  args: Sequence[Argument],
232
247
  type_name: str,
233
248
  loc: AstNode | None = None,
249
+ arg_locs: Sequence[AstNode] | None = None,
234
250
  ) -> None:
235
251
  """Checks a list of arguments against the given parameters.
236
252
 
@@ -245,7 +261,6 @@ def check_all_args(
245
261
  raise GuppyError(WrongNumberOfTypeArgsError(loc, exp, act, type_name))
246
262
 
247
263
  # Now check that the kinds match up
248
- for param, arg in zip(params, args, strict=True):
249
- # TODO: The error location is bad. We want the location of `arg`, not of the
250
- # whole thing.
251
- param.check_arg(arg, loc)
264
+ for i, (param, arg) in enumerate(zip(params, args, strict=True)):
265
+ arg_loc = arg_locs[i] if arg_locs else loc
266
+ param.instantiate_bounds(args).check_arg(arg, arg_loc)
@@ -39,7 +39,6 @@ from guppylang_internals.tys.errors import (
39
39
  WrongNumberOfTypeArgsError,
40
40
  )
41
41
  from guppylang_internals.tys.param import ConstParam, Parameter, TypeParam
42
- from guppylang_internals.tys.subst import BoundVarFinder
43
42
  from guppylang_internals.tys.ty import (
44
43
  FuncInput,
45
44
  FunctionType,
@@ -108,7 +107,7 @@ def arg_from_ast(node: AstNode, ctx: TypeParsingCtx) -> Argument:
108
107
  return ConstArg(ConstValue(bool_type(), v))
109
108
  # Integer literals are turned into nat args.
110
109
  # TODO: To support int args, we need proper inference logic here
111
- # See https://github.com/CQCL/guppylang/issues/1030
110
+ # See https://github.com/quantinuum/guppylang/issues/1030
112
111
  case int(v) if v >= 0:
113
112
  nat_ty = NumericType(NumericType.Kind.Nat)
114
113
  return ConstArg(ConstValue(nat_ty, v))
@@ -118,7 +117,7 @@ def arg_from_ast(node: AstNode, ctx: TypeParsingCtx) -> Argument:
118
117
  # String literals are ignored for now since they could also be stringified
119
118
  # types.
120
119
  # TODO: To support string args, we need proper inference logic here
121
- # See https://github.com/CQCL/guppylang/issues/1030
120
+ # See https://github.com/quantinuum/guppylang/issues/1030
122
121
  case str(_):
123
122
  pass
124
123
 
@@ -290,12 +289,18 @@ def check_function_arg(
290
289
  ctx.param_var_mapping[name] = ConstParam(
291
290
  len(ctx.param_var_mapping), name, ty, from_comptime_arg=True
292
291
  )
293
- return FuncInput(ty, flags)
292
+ return FuncInput(ty, flags, name)
294
293
 
295
294
 
296
295
  if sys.version_info >= (3, 12):
297
296
 
298
- def parse_parameter(node: ast.type_param, idx: int, globals: Globals) -> Parameter:
297
+ def parse_parameter(
298
+ node: ast.type_param,
299
+ idx: int,
300
+ globals: Globals,
301
+ param_var_mapping: dict[str, Parameter],
302
+ allow_free_vars: bool = False,
303
+ ) -> Parameter:
299
304
  """Parses a `Variable: Bound` generic type parameter declaration."""
300
305
  if isinstance(node, ast.TypeVarTuple | ast.ParamSpec):
301
306
  raise GuppyError(UnsupportedError(node, "Variadic generic parameters"))
@@ -331,18 +336,10 @@ if sys.version_info >= (3, 12):
331
336
  # parameters, so we pass an empty dict as the `param_var_mapping`.
332
337
  # TODO: In the future we might want to allow stuff like
333
338
  # `def foo[T, XS: array[T, 42]]` and so on
334
- ctx = TypeParsingCtx(globals, param_var_mapping={})
339
+ ctx = TypeParsingCtx(globals, param_var_mapping, allow_free_vars)
335
340
  ty = type_from_ast(bound, ctx)
336
341
  if not ty.copyable or not ty.droppable:
337
342
  raise GuppyError(LinearConstParamError(bound, ty))
338
-
339
- # TODO: For now we can only do `nat` const args since they lower to
340
- # Hugr bounded nats. Extend to arbitrary types via monomorphization.
341
- # See https://github.com/CQCL/guppylang/issues/1008
342
- if ty != NumericType(NumericType.Kind.Nat):
343
- raise GuppyError(
344
- UnsupportedError(bound, f"`{ty}` generic parameters")
345
- )
346
343
  return ConstParam(idx, node.name, ty)
347
344
 
348
345
 
@@ -363,16 +360,6 @@ def type_with_flags_from_ast(
363
360
  flags |= InputFlags.Comptime
364
361
  if not ty.copyable or not ty.droppable:
365
362
  raise GuppyError(LinearComptimeError(node.right, ty))
366
- # For now, we don't allow comptime annotations on generic inputs
367
- # TODO: In the future we might want to allow stuff like
368
- # `def foo[T: (Copy, Discard](x: T @comptime)`.
369
- # Also see the todo in `parse_parameter`.
370
- var_finder = BoundVarFinder()
371
- ty.visit(var_finder)
372
- if var_finder.bound_vars:
373
- raise GuppyError(
374
- UnsupportedError(node.left, "Generic comptime arguments")
375
- )
376
363
  case _:
377
364
  raise GuppyError(InvalidFlagError(node.right))
378
365
  return ty, flags
@@ -11,7 +11,6 @@ from guppylang_internals.tys.ty import (
11
11
  NumericType,
12
12
  OpaqueType,
13
13
  StructType,
14
- SumType,
15
14
  TupleType,
16
15
  Type,
17
16
  )
@@ -122,11 +121,6 @@ class TypePrinter:
122
121
  args = ", ".join(self._visit(arg, True) for arg in ty.args)
123
122
  return f"({args})"
124
123
 
125
- @_visit.register
126
- def _visit_SumType(self, ty: SumType, inside_row: bool) -> str:
127
- args = ", ".join(self._visit(arg, True) for arg in ty.args)
128
- return f"Sum[{args}]"
129
-
130
124
  @_visit.register
131
125
  def _visit_NoneType(self, ty: NoneType, inside_row: bool) -> str:
132
126
  return "None"
@@ -168,7 +162,7 @@ def signature_to_str(name: str, sig: FunctionType) -> str:
168
162
  assert sig.input_names is not None
169
163
  s = f"def {name}("
170
164
  s += ", ".join(
171
- f"{name}: {inp.ty}{TypePrinter._print_flags(inp.flags)}"
172
- for name, inp in zip(sig.input_names, sig.inputs, strict=True)
165
+ f"{inp.name}: {inp.ty}{TypePrinter._print_flags(inp.flags)}"
166
+ for inp in sig.inputs
173
167
  )
174
168
  return s + ") -> " + str(sig.output)
@@ -0,0 +1,62 @@
1
+ import functools
2
+ from typing import Any, cast
3
+
4
+ from guppylang_internals.definition.ty import TypeDef
5
+ from guppylang_internals.tys.arg import TypeArg
6
+ from guppylang_internals.tys.common import Visitor
7
+ from guppylang_internals.tys.ty import OpaqueType, Type
8
+
9
+
10
+ @functools.cache
11
+ def qubit_ty() -> Type:
12
+ """Returns the qubit type. Beware that this function imports guppylang definitions,
13
+ so, if called before the definitions are registered,
14
+ it might result in circular imports.
15
+ """
16
+ from guppylang.defs import GuppyDefinition
17
+ from guppylang.std.quantum import qubit
18
+
19
+ assert isinstance(qubit, GuppyDefinition)
20
+ qubit_ty = cast(TypeDef, qubit.wrapped).check_instantiate([])
21
+ return qubit_ty
22
+
23
+
24
+ def is_qubit_ty(ty: Type) -> bool:
25
+ """Checks if the given type is the qubit type.
26
+ This function results in circular imports if called
27
+ before qubit types are registered.
28
+ """
29
+ return ty == qubit_ty()
30
+
31
+
32
+ class QubitFinder(Visitor):
33
+ """Type visitor that checks if a type contains the qubit type."""
34
+
35
+ class FoundFlag(Exception):
36
+ pass
37
+
38
+ @functools.singledispatchmethod
39
+ def visit(self, ty: Any) -> bool: # type: ignore[override]
40
+ return False
41
+
42
+ @visit.register
43
+ def _visit_OpaqueType(self, ty: OpaqueType) -> bool:
44
+ if is_qubit_ty(ty):
45
+ raise self.FoundFlag
46
+ return False
47
+
48
+ @visit.register
49
+ def _visit_TypeArg(self, arg: TypeArg) -> bool:
50
+ arg.ty.visit(self)
51
+ return True
52
+
53
+
54
+ def contain_qubit_ty(ty: Type) -> bool:
55
+ """Checks if the given type contains the qubit type."""
56
+ finder = QubitFinder()
57
+ try:
58
+ ty.visit(finder)
59
+ except QubitFinder.FoundFlag:
60
+ return True
61
+ else:
62
+ return False
@@ -4,7 +4,7 @@ from typing import Any
4
4
 
5
5
  from guppylang_internals.error import InternalGuppyError
6
6
  from guppylang_internals.tys.arg import Argument, ConstArg, TypeArg
7
- from guppylang_internals.tys.common import Transformer, Visitor
7
+ from guppylang_internals.tys.common import Transformer
8
8
  from guppylang_internals.tys.const import (
9
9
  BoundConstVar,
10
10
  Const,
@@ -18,7 +18,7 @@ from guppylang_internals.tys.ty import (
18
18
  Type,
19
19
  TypeBase,
20
20
  )
21
- from guppylang_internals.tys.var import BoundVar, ExistentialVar
21
+ from guppylang_internals.tys.var import ExistentialVar
22
22
 
23
23
  Subst = dict[ExistentialVar, Type | Const]
24
24
  Inst = Sequence[Argument]
@@ -51,7 +51,8 @@ class Substituter(Transformer):
51
51
  class Instantiator(Transformer):
52
52
  """Type transformer that instantiates bound variables."""
53
53
 
54
- def __init__(self, inst: Inst) -> None:
54
+ def __init__(self, inst: PartialInst, allow_partial: bool = False) -> None:
55
+ self.allow_partial = allow_partial
55
56
  self.inst = inst
56
57
 
57
58
  @functools.singledispatchmethod
@@ -63,6 +64,8 @@ class Instantiator(Transformer):
63
64
  # Instantiate if type for the index is available
64
65
  if ty.idx < len(self.inst):
65
66
  arg = self.inst[ty.idx]
67
+ if arg is None and self.allow_partial:
68
+ return None
66
69
  assert isinstance(arg, TypeArg)
67
70
  return arg.ty
68
71
 
@@ -76,6 +79,8 @@ class Instantiator(Transformer):
76
79
  # Instantiate if const value for the index is available
77
80
  if c.idx < len(self.inst):
78
81
  arg = self.inst[c.idx]
82
+ if arg is None and self.allow_partial:
83
+ return None
79
84
  assert isinstance(arg, ConstArg)
80
85
  return arg.const
81
86
 
@@ -87,26 +92,3 @@ class Instantiator(Transformer):
87
92
  if ty.parametrized:
88
93
  raise InternalGuppyError("Tried to instantiate under binder")
89
94
  return None
90
-
91
-
92
- class BoundVarFinder(Visitor):
93
- """Type visitor that looks for occurrences of bound variables."""
94
-
95
- bound_vars: set[BoundVar]
96
-
97
- def __init__(self) -> None:
98
- self.bound_vars = set()
99
-
100
- @functools.singledispatchmethod
101
- def visit(self, ty: Any) -> bool: # type: ignore[override]
102
- return False
103
-
104
- @visit.register
105
- def _transform_BoundTypeVar(self, ty: BoundTypeVar) -> bool:
106
- self.bound_vars.add(ty)
107
- return False
108
-
109
- @visit.register
110
- def _transform_BoundConstVar(self, c: BoundConstVar) -> bool:
111
- self.bound_vars.add(c)
112
- return False
@@ -1,6 +1,6 @@
1
1
  from abc import ABC, abstractmethod
2
2
  from collections.abc import Sequence
3
- from dataclasses import dataclass, field
3
+ from dataclasses import dataclass, field, replace
4
4
  from enum import Enum, Flag, auto
5
5
  from functools import cached_property, total_ordering
6
6
  from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast
@@ -57,14 +57,11 @@ class TypeBase(ToHugr[ht.Type], Transformable["Type"], ABC):
57
57
  return not self.copyable and self.droppable
58
58
 
59
59
  @cached_property
60
- @abstractmethod
61
60
  def hugr_bound(self) -> ht.TypeBound:
62
- """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`.
63
-
64
- This needs to be specified explicitly, since opaque nonlinear types in a Hugr
65
- extension could be either declared as copyable or equatable. If we don't get the
66
- bound exactly right during serialisation, the Hugr validator will complain.
67
- """
61
+ """The Hugr bound of this type, i.e. `Any` or `Copyable`."""
62
+ if self.linear or self.affine:
63
+ return ht.TypeBound.Linear
64
+ return ht.TypeBound.Copyable
68
65
 
69
66
  @abstractmethod
70
67
  def cast(self) -> "Type":
@@ -79,6 +76,11 @@ class TypeBase(ToHugr[ht.Type], Transformable["Type"], ABC):
79
76
  """The existential type variables contained in this type."""
80
77
  return set()
81
78
 
79
+ @cached_property
80
+ def bound_vars(self) -> set[BoundVar]:
81
+ """The bound type variables contained in this type."""
82
+ return set()
83
+
82
84
  def substitute(self, subst: "Subst") -> "Type":
83
85
  """Substitutes existential variables in this type."""
84
86
  from guppylang_internals.tys.subst import Substituter
@@ -158,13 +160,17 @@ class ParametrizedTypeBase(TypeBase, ABC):
158
160
  """The existential type variables contained in this type."""
159
161
  return set().union(*(arg.unsolved_vars for arg in self.args))
160
162
 
163
+ @cached_property
164
+ def bound_vars(self) -> set[BoundVar]:
165
+ """The bound type variables contained in this type."""
166
+ return set().union(*(arg.bound_vars for arg in self.args))
167
+
161
168
  @cached_property
162
169
  def hugr_bound(self) -> ht.TypeBound:
163
- """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`."""
164
- if self.linear:
165
- return ht.TypeBound.Linear
170
+ """The Hugr bound of this type, i.e. `Any` or `Copyable`."""
166
171
  return ht.TypeBound.join(
167
- *(arg.ty.hugr_bound for arg in self.args if isinstance(arg, TypeArg))
172
+ super().hugr_bound,
173
+ *(arg.ty.hugr_bound for arg in self.args if isinstance(arg, TypeArg)),
168
174
  )
169
175
 
170
176
  def visit(self, visitor: Visitor) -> None:
@@ -187,14 +193,10 @@ class BoundTypeVar(TypeBase, BoundVar):
187
193
  copyable: bool
188
194
  droppable: bool
189
195
 
190
- @cached_property
191
- def hugr_bound(self) -> ht.TypeBound:
192
- """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`."""
193
- if self.linear:
194
- return ht.TypeBound.Linear
195
- # We're conservative and don't require equatability for non-linear variables.
196
- # This is fine since Guppy doesn't use the equatable feature anyways.
197
- return ht.TypeBound.Copyable
196
+ @property
197
+ def bound_vars(self) -> set[BoundVar]:
198
+ """The bound type variables contained in this type."""
199
+ return {self}
198
200
 
199
201
  def cast(self) -> "Type":
200
202
  """Casts an implementor of `TypeBase` into a `Type`."""
@@ -367,6 +369,21 @@ class InputFlags(Flag):
367
369
  Comptime = auto()
368
370
 
369
371
 
372
+ class UnitaryFlags(Flag):
373
+ """Flags that can be set on functions to indicate their unitary properties.
374
+
375
+ The flags indicate under which conditions a function can be used
376
+ in a unitary context.
377
+ """
378
+
379
+ NoFlags = 0
380
+ Control = auto()
381
+ Dagger = auto()
382
+ Power = auto()
383
+
384
+ Unitary = Control | Dagger | Power
385
+
386
+
370
387
  @dataclass(frozen=True)
371
388
  class FuncInput:
372
389
  """A single input of a function type."""
@@ -374,6 +391,10 @@ class FuncInput:
374
391
  ty: "Type"
375
392
  flags: InputFlags
376
393
 
394
+ #: Name of this input, or `None` if it is an unnamed argument (e.g. inside a
395
+ #: `Callable`). We use `compare=False` because names are not visible to the caller.
396
+ name: str | None = field(default=None, compare=False)
397
+
377
398
 
378
399
  @dataclass(frozen=True, init=False)
379
400
  class FunctionType(ParametrizedTypeBase):
@@ -382,7 +403,6 @@ class FunctionType(ParametrizedTypeBase):
382
403
  inputs: Sequence[FuncInput]
383
404
  output: "Type"
384
405
  params: Sequence[Parameter]
385
- input_names: Sequence[str] | None
386
406
  comptime_args: Sequence[ConstArg]
387
407
 
388
408
  args: Sequence[Argument] = field(init=False)
@@ -392,13 +412,15 @@ class FunctionType(ParametrizedTypeBase):
392
412
  intrinsically_droppable: bool = field(default=True, init=True)
393
413
  hugr_bound: ht.TypeBound = field(default=ht.TypeBound.Copyable, init=False)
394
414
 
415
+ unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, init=True)
416
+
395
417
  def __init__(
396
418
  self,
397
419
  inputs: Sequence[FuncInput],
398
420
  output: "Type",
399
- input_names: Sequence[str] | None = None,
400
421
  params: Sequence[Parameter] | None = None,
401
422
  comptime_args: Sequence[ConstArg] | None = None,
423
+ unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
402
424
  ) -> None:
403
425
  # We need a custom __init__ to set the args
404
426
  args: list[Argument] = [TypeArg(inp.ty) for inp in inputs]
@@ -414,18 +436,43 @@ class FunctionType(ParametrizedTypeBase):
414
436
  ]
415
437
  args += comptime_args
416
438
 
439
+ # Either all inputs must have unique names, or none of them have names
440
+ names = {inp.name for inp in inputs if inp.name is not None}
441
+ if len(names) not in (0, len(inputs)):
442
+ raise InternalGuppyError(
443
+ "Tried to construct FunctionType with invalid input names"
444
+ )
445
+
417
446
  object.__setattr__(self, "args", args)
418
447
  object.__setattr__(self, "comptime_args", comptime_args)
419
448
  object.__setattr__(self, "inputs", inputs)
420
449
  object.__setattr__(self, "output", output)
421
- object.__setattr__(self, "input_names", input_names or [])
422
450
  object.__setattr__(self, "params", params)
451
+ object.__setattr__(self, "unitary_flags", unitary_flags)
423
452
 
424
453
  @property
425
454
  def parametrized(self) -> bool:
426
455
  """Whether the function is parametrized."""
427
456
  return len(self.params) > 0
428
457
 
458
+ @cached_property
459
+ def bound_vars(self) -> set[BoundVar]:
460
+ """The bound type variables contained in this type."""
461
+ if self.parametrized:
462
+ # Ensures that we don't look inside quantifiers
463
+ return set()
464
+ return super().bound_vars
465
+
466
+ @cached_property
467
+ def input_names(self) -> Sequence[str] | None:
468
+ """Names of all inputs or `None` if there are unnamed inputs."""
469
+ names: list[str] = []
470
+ for inp in self.inputs:
471
+ if inp.name is None:
472
+ return None
473
+ names.append(inp.name)
474
+ return names
475
+
429
476
  def cast(self) -> "Type":
430
477
  """Casts an implementor of `TypeBase` into a `Type`."""
431
478
  return self
@@ -484,12 +531,8 @@ class FunctionType(ParametrizedTypeBase):
484
531
  def transform(self, transformer: Transformer) -> "Type":
485
532
  """Accepts a transformer on this type."""
486
533
  return transformer.transform(self) or FunctionType(
487
- [
488
- FuncInput(inp.ty.transform(transformer), inp.flags)
489
- for inp in self.inputs
490
- ],
534
+ [replace(inp, ty=inp.ty.transform(transformer)) for inp in self.inputs],
491
535
  self.output.transform(transformer),
492
- self.input_names,
493
536
  self.params,
494
537
  )
495
538
 
@@ -506,7 +549,7 @@ class FunctionType(ParametrizedTypeBase):
506
549
  # However, we have to down-shift the de Bruijn index.
507
550
  if arg is None:
508
551
  param = param.with_idx(len(remaining_params))
509
- remaining_params.append(param)
552
+ remaining_params.append(param.instantiate_bounds(full_inst))
510
553
  arg = param.to_bound()
511
554
 
512
555
  # Set the `preserve` flag for instantiated tuples and None
@@ -519,9 +562,8 @@ class FunctionType(ParametrizedTypeBase):
519
562
 
520
563
  inst = Instantiator(full_inst)
521
564
  return FunctionType(
522
- [FuncInput(inp.ty.transform(inst), inp.flags) for inp in self.inputs],
565
+ [replace(inp, ty=inp.ty.transform(inst)) for inp in self.inputs],
523
566
  self.output.transform(inst),
524
- self.input_names,
525
567
  remaining_params,
526
568
  # Comptime type arguments also need to be instantiated
527
569
  comptime_args=[
@@ -538,6 +580,18 @@ class FunctionType(ParametrizedTypeBase):
538
580
  exs = [param.to_existential() for param in self.params]
539
581
  return self.instantiate([arg for arg, _ in exs]), [var for _, var in exs]
540
582
 
583
+ def with_unitary_flags(self, flags: UnitaryFlags) -> "FunctionType":
584
+ """Returns a copy of this function type with the specified unitary flags."""
585
+ # N.B. we can't use `dataclasses.replace` here since `FunctionType` has a custom
586
+ # constructor
587
+ return FunctionType(
588
+ self.inputs,
589
+ self.output,
590
+ self.params,
591
+ self.comptime_args,
592
+ flags,
593
+ )
594
+
541
595
 
542
596
  @dataclass(frozen=True, init=False)
543
597
  class TupleType(ParametrizedTypeBase):
@@ -582,53 +636,6 @@ class TupleType(ParametrizedTypeBase):
582
636
  )
583
637
 
584
638
 
585
- @dataclass(frozen=True, init=False)
586
- class SumType(ParametrizedTypeBase):
587
- """Type of sums.
588
-
589
- Note that this type is only used internally when constructing the Hugr. Users cannot
590
- write down this type.
591
- """
592
-
593
- element_types: Sequence["Type"]
594
-
595
- def __init__(self, element_types: Sequence["Type"]) -> None:
596
- # We need a custom __init__ to set the args
597
- args = [TypeArg(ty) for ty in element_types]
598
- object.__setattr__(self, "args", args)
599
- object.__setattr__(self, "element_types", element_types)
600
-
601
- @property
602
- def intrinsically_copyable(self) -> bool:
603
- """Whether objects of this type can be implicitly copied."""
604
- return True
605
-
606
- @property
607
- def intrinsically_droppable(self) -> bool:
608
- """Whether objects of this type can be dropped."""
609
- return True
610
-
611
- def cast(self) -> "Type":
612
- """Casts an implementor of `TypeBase` into a `Type`."""
613
- return self
614
-
615
- def to_hugr(self, ctx: ToHugrContext) -> ht.Sum:
616
- """Computes the Hugr representation of the type."""
617
- rows = [type_to_row(ty) for ty in self.element_types]
618
- if all(len(row) == 0 for row in rows):
619
- return ht.UnitSum(size=len(rows))
620
- elif len(rows) == 1:
621
- return ht.Tuple(*row_to_hugr(rows[0], ctx))
622
- else:
623
- return ht.Sum(variant_rows=rows_to_hugr(rows, ctx))
624
-
625
- def transform(self, transformer: Transformer) -> "Type":
626
- """Accepts a transformer on this type."""
627
- return transformer.transform(self) or SumType(
628
- [ty.transform(transformer) for ty in self.element_types]
629
- )
630
-
631
-
632
639
  @dataclass(frozen=True)
633
640
  class OpaqueType(ParametrizedTypeBase):
634
641
  """Type that is directly backed by a Hugr opaque type.
@@ -651,7 +658,7 @@ class OpaqueType(ParametrizedTypeBase):
651
658
 
652
659
  @property
653
660
  def hugr_bound(self) -> ht.TypeBound:
654
- """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`."""
661
+ """The Hugr bound of this type, i.e. `Any` or `Copyable`."""
655
662
  if self.defn.bound is not None:
656
663
  return self.defn.bound
657
664
  return super().hugr_bound
@@ -717,9 +724,8 @@ class StructType(ParametrizedTypeBase):
717
724
 
718
725
 
719
726
  #: The type of parametrized Guppy types.
720
- ParametrizedType: TypeAlias = (
721
- FunctionType | TupleType | SumType | OpaqueType | StructType
722
- )
727
+ ParametrizedType: TypeAlias = FunctionType | TupleType | OpaqueType | StructType
728
+
723
729
 
724
730
  #: The type of Guppy types.
725
731
  #:
@@ -801,8 +807,6 @@ def unify(s: Type | Const, t: Type | Const, subst: "Subst | None") -> "Subst | N
801
807
  return _unify_args(s, t, subst)
802
808
  case TupleType() as s, TupleType() as t:
803
809
  return _unify_args(s, t, subst)
804
- case SumType() as s, SumType() as t:
805
- return _unify_args(s, t, subst)
806
810
  case OpaqueType() as s, OpaqueType() as t if s.defn == t.defn:
807
811
  return _unify_args(s, t, subst)
808
812
  case StructType() as s, StructType() as t if s.defn == t.defn:
@@ -871,6 +875,8 @@ def function_tensor_signature(tys: list[FunctionType]) -> FunctionType:
871
875
  outputs: list[Type] = []
872
876
  for fun_ty in tys:
873
877
  assert not fun_ty.parametrized
874
- inputs.extend(fun_ty.inputs)
878
+ # Forget the function input names since they might be non-unique across the
879
+ # tensored functions
880
+ inputs.extend([replace(inp, name=None) for inp in fun_ty.inputs])
875
881
  outputs.extend(type_to_row(fun_ty.output))
876
882
  return FunctionType(inputs, row_to_type(outputs))