guppylang-internals 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +21 -0
  3. guppylang_internals/cfg/bb.py +20 -0
  4. guppylang_internals/cfg/builder.py +118 -5
  5. guppylang_internals/cfg/cfg.py +3 -0
  6. guppylang_internals/checker/cfg_checker.py +6 -0
  7. guppylang_internals/checker/core.py +5 -2
  8. guppylang_internals/checker/errors/generic.py +32 -1
  9. guppylang_internals/checker/errors/type_errors.py +14 -0
  10. guppylang_internals/checker/errors/wasm.py +7 -4
  11. guppylang_internals/checker/expr_checker.py +58 -17
  12. guppylang_internals/checker/func_checker.py +18 -14
  13. guppylang_internals/checker/linearity_checker.py +67 -10
  14. guppylang_internals/checker/modifier_checker.py +120 -0
  15. guppylang_internals/checker/stmt_checker.py +48 -1
  16. guppylang_internals/checker/unitary_checker.py +132 -0
  17. guppylang_internals/compiler/cfg_compiler.py +7 -6
  18. guppylang_internals/compiler/core.py +93 -56
  19. guppylang_internals/compiler/expr_compiler.py +72 -168
  20. guppylang_internals/compiler/modifier_compiler.py +176 -0
  21. guppylang_internals/compiler/stmt_compiler.py +15 -8
  22. guppylang_internals/decorator.py +86 -7
  23. guppylang_internals/definition/custom.py +39 -1
  24. guppylang_internals/definition/declaration.py +9 -6
  25. guppylang_internals/definition/function.py +12 -2
  26. guppylang_internals/definition/parameter.py +8 -3
  27. guppylang_internals/definition/pytket_circuits.py +14 -41
  28. guppylang_internals/definition/struct.py +13 -7
  29. guppylang_internals/definition/ty.py +3 -3
  30. guppylang_internals/definition/wasm.py +42 -10
  31. guppylang_internals/engine.py +9 -3
  32. guppylang_internals/experimental.py +5 -0
  33. guppylang_internals/nodes.py +147 -24
  34. guppylang_internals/std/_internal/checker.py +13 -108
  35. guppylang_internals/std/_internal/compiler/array.py +95 -283
  36. guppylang_internals/std/_internal/compiler/list.py +1 -1
  37. guppylang_internals/std/_internal/compiler/platform.py +153 -0
  38. guppylang_internals/std/_internal/compiler/prelude.py +12 -4
  39. guppylang_internals/std/_internal/compiler/tket_exts.py +8 -2
  40. guppylang_internals/std/_internal/debug.py +18 -9
  41. guppylang_internals/std/_internal/util.py +1 -1
  42. guppylang_internals/tracing/object.py +10 -0
  43. guppylang_internals/tracing/unpacking.py +19 -20
  44. guppylang_internals/tys/arg.py +18 -3
  45. guppylang_internals/tys/builtin.py +2 -5
  46. guppylang_internals/tys/const.py +33 -4
  47. guppylang_internals/tys/errors.py +23 -1
  48. guppylang_internals/tys/param.py +31 -16
  49. guppylang_internals/tys/parsing.py +11 -24
  50. guppylang_internals/tys/printing.py +2 -8
  51. guppylang_internals/tys/qubit.py +62 -0
  52. guppylang_internals/tys/subst.py +8 -26
  53. guppylang_internals/tys/ty.py +91 -85
  54. guppylang_internals/wasm_util.py +129 -0
  55. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +6 -5
  56. guppylang_internals-0.26.0.dist-info/RECORD +104 -0
  57. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
  58. guppylang_internals-0.24.0.dist-info/RECORD +0 -98
  59. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
@@ -328,7 +328,7 @@ def _list_new_classical(
328
328
  builder: DfBase[ops.DfParentOp], elem_type: ht.Type, args: list[Wire]
329
329
  ) -> Wire:
330
330
  # This may be simplified in the future with a `new` or `with_capacity` list op
331
- # See https://github.com/CQCL/hugr/issues/1508
331
+ # See https://github.com/quantinuum/hugr/issues/1508
332
332
  lst = builder.load(ListVal([], elem_ty=elem_type))
333
333
  push_op = list_push(elem_type)
334
334
  for elem in args:
@@ -0,0 +1,153 @@
1
+ from dataclasses import dataclass
2
+ from typing import ClassVar
3
+
4
+ import hugr
5
+ from hugr import Wire, ops, tys
6
+
7
+ from guppylang_internals.ast_util import AstNode
8
+ from guppylang_internals.compiler.core import CompilerContext
9
+ from guppylang_internals.compiler.expr_compiler import array_read_bool
10
+ from guppylang_internals.definition.custom import (
11
+ CustomCallCompiler,
12
+ CustomInoutCallCompiler,
13
+ )
14
+ from guppylang_internals.definition.value import CallReturnWires
15
+ from guppylang_internals.diagnostic import Error, Note
16
+ from guppylang_internals.error import GuppyError, InternalGuppyError
17
+ from guppylang_internals.std._internal.compiler.array import (
18
+ array_clone,
19
+ array_map,
20
+ array_to_std_array,
21
+ )
22
+ from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, read_bool
23
+ from guppylang_internals.std._internal.compiler.tket_exts import RESULT_EXTENSION
24
+ from guppylang_internals.tys.arg import Argument, ConstArg
25
+ from guppylang_internals.tys.builtin import get_element_type, is_bool_type
26
+ from guppylang_internals.tys.const import BoundConstVar, ConstValue
27
+ from guppylang_internals.tys.ty import NumericType
28
+
29
+ #: Maximum length of a tag in the `result` function.
30
+ TAG_MAX_LEN = 200
31
+
32
+
33
+ @dataclass(frozen=True)
34
+ class TooLongError(Error):
35
+ title: ClassVar[str] = "Tag too long"
36
+ span_label: ClassVar[str] = "Result tag is too long"
37
+
38
+ @dataclass(frozen=True)
39
+ class Hint(Note):
40
+ message: ClassVar[str] = f"Result tags are limited to {TAG_MAX_LEN} bytes"
41
+
42
+ @dataclass(frozen=True)
43
+ class GenericHint(Note):
44
+ message: ClassVar[str] = "Parameter `{param}` was instantiated to `{value}`"
45
+ param: str
46
+ value: str
47
+
48
+
49
+ class ResultCompiler(CustomCallCompiler):
50
+ """Custom compiler for overloads of the `result` function.
51
+
52
+ See `ArrayResultCompiler` for the compiler that handles results involving arrays.
53
+ """
54
+
55
+ def __init__(self, op_name: str, with_int_width: bool = False):
56
+ self.op_name = op_name
57
+ self.with_int_width = with_int_width
58
+
59
+ def compile(self, args: list[Wire]) -> list[Wire]:
60
+ assert self.func is not None
61
+ [value] = args
62
+ ty = self.func.ty.inputs[1].ty
63
+ hugr_ty = ty.to_hugr(self.ctx)
64
+ args = [tag_to_hugr(self.type_args[0], self.ctx, self.node)]
65
+ if self.with_int_width:
66
+ args.append(tys.BoundedNatArg(NumericType.INT_WIDTH))
67
+ # Bool results need an extra conversion into regular hugr bools
68
+ if is_bool_type(ty):
69
+ value = self.builder.add_op(read_bool(), value)
70
+ hugr_ty = tys.Bool
71
+ op = RESULT_EXTENSION.get_op(self.op_name)
72
+ sig = tys.FunctionType(input=[hugr_ty], output=[])
73
+ self.builder.add_op(op.instantiate(args, sig), value)
74
+ return []
75
+
76
+
77
+ class ArrayResultCompiler(CustomInoutCallCompiler):
78
+ """Custom compiler for overloads of the `result` function accepting arrays.
79
+
80
+ See `ResultCompiler` for the compiler that handles basic results.
81
+ """
82
+
83
+ def __init__(self, op_name: str, with_int_width: bool = False):
84
+ self.op_name = op_name
85
+ self.with_int_width = with_int_width
86
+
87
+ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
88
+ assert self.func is not None
89
+ array_ty = self.func.ty.inputs[1].ty
90
+ elem_ty = get_element_type(array_ty)
91
+ [tag_arg, size_arg] = self.type_args
92
+ [arr] = args
93
+
94
+ # As `borrow_array`s used by Guppy are linear, we need to clone it (knowing
95
+ # that all elements in it are copyable) to avoid linearity violations when
96
+ # both passing it to the result operation and returning it (as an inout
97
+ # argument).
98
+ hugr_elem_ty = elem_ty.to_hugr(self.ctx)
99
+ hugr_size = size_arg.to_hugr(self.ctx)
100
+ arr, out_arr = self.builder.add_op(array_clone(hugr_elem_ty, hugr_size), arr)
101
+ # For bool arrays, we furthermore need to coerce a read on all the array
102
+ # elements
103
+ if is_bool_type(elem_ty):
104
+ array_read = array_read_bool(self.ctx)
105
+ array_read = self.builder.load_function(array_read)
106
+ map_op = array_map(OpaqueBool, hugr_size, tys.Bool)
107
+ arr = self.builder.add_op(map_op, arr, array_read).out(0)
108
+ hugr_elem_ty = tys.Bool
109
+ # Turn `borrow_array` into regular `array`
110
+ arr = self.builder.add_op(array_to_std_array(hugr_elem_ty, hugr_size), arr).out(
111
+ 0
112
+ )
113
+
114
+ hugr_ty = hugr.std.collections.array.Array(hugr_elem_ty, hugr_size)
115
+ sig = tys.FunctionType(input=[hugr_ty], output=[])
116
+ args = [tag_to_hugr(tag_arg, self.ctx, self.node), hugr_size]
117
+ if self.with_int_width:
118
+ args.append(tys.BoundedNatArg(NumericType.INT_WIDTH))
119
+ op = ops.ExtOp(RESULT_EXTENSION.get_op(self.op_name), signature=sig, args=args)
120
+ self.builder.add_op(op, arr)
121
+ return CallReturnWires([], [out_arr])
122
+
123
+
124
+ def tag_to_hugr(tag_arg: Argument, ctx: CompilerContext, loc: AstNode) -> tys.TypeArg:
125
+ """Helper function to convert the Guppy tag comptime argument into a Hugr type arg.
126
+
127
+ Takes care of reading the tag value from the current monomorphization and checks
128
+ that the tag fits into `TAG_MAX_LEN`.
129
+ """
130
+ is_generic: BoundConstVar | None = None
131
+ match tag_arg:
132
+ case ConstArg(const=ConstValue(value=str(value))):
133
+ tag = value
134
+ case ConstArg(const=BoundConstVar(idx=idx) as var):
135
+ is_generic = var
136
+ assert ctx.current_mono_args is not None
137
+ match ctx.current_mono_args[idx]:
138
+ case ConstArg(const=ConstValue(value=str(value))):
139
+ tag = value
140
+ case _:
141
+ raise InternalGuppyError("Invalid tag monomorphization")
142
+ case _:
143
+ raise InternalGuppyError("Invalid tag argument")
144
+
145
+ if len(tag.encode("utf-8")) > TAG_MAX_LEN:
146
+ err = TooLongError(loc)
147
+ err.add_sub_diagnostic(TooLongError.Hint(None))
148
+ if is_generic:
149
+ err.add_sub_diagnostic(
150
+ TooLongError.GenericHint(None, is_generic.display_name, tag)
151
+ )
152
+ raise GuppyError(err)
153
+ return tys.StringArg(tag)
@@ -73,6 +73,14 @@ def panic(
73
73
  return ops.ExtOp(op_def, sig, args)
74
74
 
75
75
 
76
+ def make_error() -> ops.ExtOp:
77
+ """Returns an operation that makes an error."""
78
+ op_def = hugr.std.PRELUDE.get_op("MakeError")
79
+ args: list[ht.TypeArg] = []
80
+ sig = ht.FunctionType([ht.USize(), hugr.std.prelude.STRING_T], [error_type()])
81
+ return ops.ExtOp(op_def, sig, args)
82
+
83
+
76
84
  # ------------------------------------------------------
77
85
  # --------- Custom compilers for non-native ops --------
78
86
  # ------------------------------------------------------
@@ -90,14 +98,14 @@ def build_panic(
90
98
  return builder.add_op(op, err, *args)
91
99
 
92
100
 
93
- def build_error(builder: DfBase[P], signal: int, msg: str) -> Wire:
101
+ def build_static_error(builder: DfBase[P], signal: int, msg: str) -> Wire:
94
102
  """Constructs and loads a static error value."""
95
103
  val = ErrorVal(signal, msg)
96
104
  return builder.load(builder.add_const(val))
97
105
 
98
106
 
99
107
  # TODO: Common up build_unwrap_right and build_unwrap_left below once
100
- # https://github.com/CQCL/hugr/issues/1596 is fixed
108
+ # https://github.com/quantinuum/hugr/issues/1596 is fixed
101
109
 
102
110
 
103
111
  def build_unwrap_right(
@@ -111,7 +119,7 @@ def build_unwrap_right(
111
119
  assert isinstance(result_ty, ht.Sum)
112
120
  [left_tys, right_tys] = result_ty.variant_rows
113
121
  with conditional.add_case(0) as case:
114
- error = build_error(case, error_signal, error_msg)
122
+ error = build_static_error(case, error_signal, error_msg)
115
123
  case.set_outputs(*build_panic(case, left_tys, right_tys, error, *case.inputs()))
116
124
  with conditional.add_case(1) as case:
117
125
  case.set_outputs(*case.inputs())
@@ -134,7 +142,7 @@ def build_unwrap_left(
134
142
  with conditional.add_case(0) as case:
135
143
  case.set_outputs(*case.inputs())
136
144
  with conditional.add_case(1) as case:
137
- error = build_error(case, error_signal, error_msg)
145
+ error = build_static_error(case, error_signal, error_msg)
138
146
  case.set_outputs(*build_panic(case, right_tys, left_tys, error, *case.inputs()))
139
147
  return conditional.to_node()
140
148
 
@@ -1,11 +1,13 @@
1
1
  from dataclasses import dataclass
2
2
 
3
+ import tket_exts
3
4
  from hugr import val
4
5
  from tket_exts import (
5
6
  debug,
6
7
  futures,
8
+ global_phase,
7
9
  guppy,
8
- opaque_bool,
10
+ modifier,
9
11
  qsystem,
10
12
  qsystem_random,
11
13
  qsystem_utils,
@@ -15,10 +17,12 @@ from tket_exts import (
15
17
  wasm,
16
18
  )
17
19
 
18
- BOOL_EXTENSION = opaque_bool()
20
+ BOOL_EXTENSION = tket_exts.bool()
19
21
  DEBUG_EXTENSION = debug()
20
22
  FUTURES_EXTENSION = futures()
23
+ GLOBAL_PHASE_EXTENSION = global_phase()
21
24
  GUPPY_EXTENSION = guppy()
25
+ MODIFIER_EXTENSION = modifier()
22
26
  QSYSTEM_EXTENSION = qsystem()
23
27
  QSYSTEM_RANDOM_EXTENSION = qsystem_random()
24
28
  QSYSTEM_UTILS_EXTENSION = qsystem_utils()
@@ -31,7 +35,9 @@ TKET_EXTENSIONS = [
31
35
  BOOL_EXTENSION,
32
36
  DEBUG_EXTENSION,
33
37
  FUTURES_EXTENSION,
38
+ GLOBAL_PHASE_EXTENSION,
34
39
  GUPPY_EXTENSION,
40
+ MODIFIER_EXTENSION,
35
41
  QSYSTEM_EXTENSION,
36
42
  QSYSTEM_RANDOM_EXTENSION,
37
43
  QSYSTEM_UTILS_EXTENSION,
@@ -3,6 +3,7 @@ from dataclasses import dataclass
3
3
  from typing import ClassVar, cast
4
4
 
5
5
  from guppylang_internals.ast_util import with_loc
6
+ from guppylang_internals.checker.core import ComptimeVariable
6
7
  from guppylang_internals.checker.errors.generic import ExpectedError
7
8
  from guppylang_internals.checker.errors.type_errors import WrongNumberOfArgsError
8
9
  from guppylang_internals.checker.expr_checker import (
@@ -14,14 +15,14 @@ from guppylang_internals.definition.custom import CustomCallChecker
14
15
  from guppylang_internals.definition.ty import TypeDef
15
16
  from guppylang_internals.diagnostic import Error
16
17
  from guppylang_internals.error import GuppyTypeError
17
- from guppylang_internals.nodes import StateResultExpr
18
- from guppylang_internals.std._internal.checker import TAG_MAX_LEN, TooLongError
18
+ from guppylang_internals.nodes import GenericParamValue, PlaceNode, StateResultExpr
19
19
  from guppylang_internals.tys.builtin import (
20
20
  get_array_length,
21
21
  get_element_type,
22
22
  is_array_type,
23
23
  string_type,
24
24
  )
25
+ from guppylang_internals.tys.const import Const, ConstValue
25
26
  from guppylang_internals.tys.ty import (
26
27
  FuncInput,
27
28
  FunctionType,
@@ -43,12 +44,16 @@ class StateResultChecker(CustomCallChecker):
43
44
 
44
45
  def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
45
46
  tag, _ = ExprChecker(self.ctx).check(args[0], string_type())
46
- if not isinstance(tag, ast.Constant) or not isinstance(tag.value, str):
47
- raise GuppyTypeError(ExpectedError(tag, "a string literal"))
48
- if len(tag.value.encode("utf-8")) > TAG_MAX_LEN:
49
- err: Error = TooLongError(tag)
50
- err.add_sub_diagnostic(TooLongError.Hint(None))
51
- raise GuppyTypeError(err)
47
+ tag_value: Const
48
+ match tag:
49
+ case ast.Constant(value=str(v)):
50
+ tag_value = ConstValue(string_type(), v)
51
+ case PlaceNode(place=ComptimeVariable(static_value=str(v))):
52
+ tag_value = ConstValue(string_type(), v)
53
+ case GenericParamValue() as param_value:
54
+ tag_value = param_value.param.to_bound().const
55
+ case _:
56
+ raise GuppyTypeError(ExpectedError(tag, "a string literal"))
52
57
  syn_args: list[ast.expr] = [tag]
53
58
 
54
59
  if len(args) < 2:
@@ -90,6 +95,10 @@ class StateResultChecker(CustomCallChecker):
90
95
  args, ret_ty, inst = synthesize_call(func_ty, syn_args, self.node, self.ctx)
91
96
  assert len(inst) == 0, "func_ty is not generic"
92
97
  node = StateResultExpr(
93
- tag=tag.value, args=args, func_ty=func_ty, array_len=array_len
98
+ tag_value=tag_value,
99
+ tag_expr=tag,
100
+ args=args,
101
+ func_ty=func_ty,
102
+ array_len=array_len,
94
103
  )
95
104
  return with_loc(self.node, node), ret_ty
@@ -129,7 +129,7 @@ def int_op(
129
129
  # Ideally we'd be able to derive the arguments from the input/output types,
130
130
  # but the amount of variables does not correlate with the signature for the
131
131
  # integer ops in hugr :/
132
- # https://github.com/CQCL/hugr/blob/bfa13e59468feb0fc746677ea3b3a4341b2ed42e/hugr-core/src/std_extensions/arithmetic/int_ops.rs#L116
132
+ # https://github.com/quantinuum/hugr/blob/bfa13e59468feb0fc746677ea3b3a4341b2ed42e/hugr-core/src/std_extensions/arithmetic/int_ops.rs#L116
133
133
  #
134
134
  # For now, we just instantiate every type argument to a 64-bit integer.
135
135
  args: list[ht.TypeArg] = [int_arg() for _ in range(n_vars)]
@@ -539,6 +539,16 @@ class TracingDefMixin(DunderMixin):
539
539
 
540
540
  def to_guppy_object(self) -> GuppyObject:
541
541
  state = get_tracing_state()
542
+ defn = ENGINE.get_checked(self.id)
543
+ # TODO: For generic functions, we need to know an instantiation for their type
544
+ # parameters. Maybe we should pass them to `to_guppy_object`? Either way, this
545
+ # will require some more plumbing of type inference information through the
546
+ # comptime logic. For now, let's just bail on generic functions.
547
+ # See https://github.com/quantinuum/guppylang/issues/1336
548
+ if isinstance(defn, CallableDef) and defn.ty.parametrized:
549
+ raise GuppyComptimeError(
550
+ f"Cannot infer type parameters of generic function `{defn.name}`"
551
+ )
542
552
  defn, [] = state.ctx.build_compiled_def(self.id, type_args=[])
543
553
  if isinstance(defn, CompiledValueDef):
544
554
  wire = defn.load(state.dfg, state.ctx, state.node)
@@ -1,7 +1,6 @@
1
1
  from typing import Any, TypeVar
2
2
 
3
3
  from hugr import ops
4
- from hugr import tys as ht
5
4
  from hugr.build.dfg import DfBase
6
5
 
7
6
  from guppylang_internals.ast_util import AstNode
@@ -13,7 +12,6 @@ from guppylang_internals.compiler.core import CompilerContext
13
12
  from guppylang_internals.compiler.expr_compiler import python_value_to_hugr
14
13
  from guppylang_internals.error import GuppyComptimeError, GuppyError
15
14
  from guppylang_internals.std._internal.compiler.array import array_new, unpack_array
16
- from guppylang_internals.std._internal.compiler.prelude import build_unwrap
17
15
  from guppylang_internals.tracing.frozenlist import frozenlist
18
16
  from guppylang_internals.tracing.object import (
19
17
  GuppyObject,
@@ -71,9 +69,7 @@ def unpack_guppy_object(
71
69
  # them as Guppy objects here
72
70
  return obj
73
71
  elem_ty = get_element_type(ty)
74
- opt_elems = unpack_array(builder, obj._use_wire(None))
75
- err = "Non-copyable array element has already been used"
76
- elems = [build_unwrap(builder, opt_elem, err) for opt_elem in opt_elems]
72
+ elems = unpack_array(builder, obj._use_wire(None))
77
73
  obj_list = [
78
74
  unpack_guppy_object(GuppyObject(elem_ty, wire), builder, frozen)
79
75
  for wire in elems
@@ -128,11 +124,8 @@ def guppy_object_from_py(
128
124
  f"Element at index {i + 1} does not match the type of "
129
125
  f"previous elements. Expected `{elem_ty}`, got `{obj._ty}`."
130
126
  )
131
- hugr_elem_ty = ht.Option(elem_ty.to_hugr(ctx))
132
- wires = [
133
- builder.add_op(ops.Tag(1, hugr_elem_ty), obj._use_wire(None))
134
- for obj in objs
135
- ]
127
+ hugr_elem_ty = elem_ty.to_hugr(ctx)
128
+ wires = [obj._use_wire(None) for obj in objs]
136
129
  return GuppyObject(
137
130
  array_type(elem_ty, len(vs)),
138
131
  builder.add_op(array_new(hugr_elem_ty, len(vs)), *wires),
@@ -172,26 +165,32 @@ def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> bool:
172
165
  assert isinstance(obj._ty, NoneType)
173
166
  case tuple(vs):
174
167
  assert isinstance(obj._ty, TupleType)
175
- wires = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)).outputs()
176
- for v, ty, wire in zip(vs, obj._ty.element_types, wires, strict=True):
177
- success = update_packed_value(v, GuppyObject(ty, wire), builder)
168
+ wire_iterator = builder.add_op(
169
+ ops.UnpackTuple(), obj._use_wire(None)
170
+ ).outputs()
171
+ for v, ty, out_wire in zip(
172
+ vs, obj._ty.element_types, wire_iterator, strict=True
173
+ ):
174
+ success = update_packed_value(v, GuppyObject(ty, out_wire), builder)
178
175
  if not success:
179
176
  return False
180
177
  case GuppyStructObject(_ty=ty, _field_values=values):
181
178
  assert obj._ty == ty
182
- wires = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)).outputs()
183
- for field, wire in zip(ty.fields, wires, strict=True):
179
+ wire_iterator = builder.add_op(
180
+ ops.UnpackTuple(), obj._use_wire(None)
181
+ ).outputs()
182
+ for field, out_wire in zip(ty.fields, wire_iterator, strict=True):
184
183
  v = values[field.name]
185
- success = update_packed_value(v, GuppyObject(field.ty, wire), builder)
184
+ success = update_packed_value(
185
+ v, GuppyObject(field.ty, out_wire), builder
186
+ )
186
187
  if not success:
187
188
  values[field.name] = obj
188
189
  case list(vs) if len(vs) > 0:
189
190
  assert is_array_type(obj._ty)
190
191
  elem_ty = get_element_type(obj._ty)
191
- opt_wires = unpack_array(builder, obj._use_wire(None))
192
- err = "Non-droppable array element has already been used"
193
- for i, (v, opt_wire) in enumerate(zip(vs, opt_wires, strict=True)):
194
- (wire,) = build_unwrap(builder, opt_wire, err).outputs()
192
+ wires = unpack_array(builder, obj._use_wire(None))
193
+ for i, (v, wire) in enumerate(zip(vs, wires, strict=True)):
195
194
  success = update_packed_value(v, GuppyObject(elem_ty, wire), builder)
196
195
  if not success:
197
196
  vs[i] = obj
@@ -1,5 +1,5 @@
1
1
  from abc import ABC, abstractmethod
2
- from dataclasses import dataclass
2
+ from dataclasses import dataclass, field
3
3
  from typing import TYPE_CHECKING, TypeAlias
4
4
 
5
5
  from hugr import tys as ht
@@ -18,7 +18,7 @@ from guppylang_internals.tys.const import (
18
18
  ConstValue,
19
19
  ExistentialConstVar,
20
20
  )
21
- from guppylang_internals.tys.var import ExistentialVar
21
+ from guppylang_internals.tys.var import BoundVar, ExistentialVar
22
22
 
23
23
  if TYPE_CHECKING:
24
24
  from guppylang_internals.tys.ty import Type
@@ -45,19 +45,29 @@ class ArgumentBase(ToHugr[ht.TypeArg], Transformable["Argument"], ABC):
45
45
  def unsolved_vars(self) -> set[ExistentialVar]:
46
46
  """The existential type variables contained in this argument."""
47
47
 
48
+ @property
49
+ @abstractmethod
50
+ def bound_vars(self) -> set[BoundVar]:
51
+ """The bound type variables contained in this argument."""
52
+
48
53
 
49
54
  @dataclass(frozen=True)
50
55
  class TypeArg(ArgumentBase):
51
56
  """Argument that can be instantiated for a `TypeParameter`."""
52
57
 
53
58
  # The type to instantiate
54
- ty: "Type"
59
+ ty: "Type" = field(hash=False) # Types are not hashable
55
60
 
56
61
  @property
57
62
  def unsolved_vars(self) -> set[ExistentialVar]:
58
63
  """The existential type variables contained in this argument."""
59
64
  return self.ty.unsolved_vars
60
65
 
66
+ @property
67
+ def bound_vars(self) -> set[BoundVar]:
68
+ """The bound type variables contained in this type."""
69
+ return self.ty.bound_vars
70
+
61
71
  def to_hugr(self, ctx: ToHugrContext) -> ht.TypeTypeArg:
62
72
  """Computes the Hugr representation of the argument."""
63
73
  ty: ht.Type = self.ty.to_hugr(ctx)
@@ -84,6 +94,11 @@ class ConstArg(ArgumentBase):
84
94
  """The existential const variables contained in this argument."""
85
95
  return self.const.unsolved_vars
86
96
 
97
+ @property
98
+ def bound_vars(self) -> set[BoundVar]:
99
+ """The bound type variables contained in this argument."""
100
+ return self.const.bound_vars
101
+
87
102
  def to_hugr(self, ctx: ToHugrContext) -> ht.TypeArg:
88
103
  """Computes the Hugr representation of this argument."""
89
104
  from guppylang_internals.tys.ty import NumericType
@@ -177,13 +177,10 @@ def _array_to_hugr(args: Sequence[Argument], ctx: ToHugrContext) -> ht.Type:
177
177
  assert isinstance(ty_arg, TypeArg)
178
178
  assert isinstance(len_arg, ConstArg)
179
179
 
180
- # Linear elements are turned into an optional to enable unsafe indexing.
181
- # See `ArrayGetitemCompiler` for details.
182
- # Same also for classical arrays, see https://github.com/CQCL/guppylang/issues/629
183
- elem_ty = ht.Option(ty_arg.ty.to_hugr(ctx))
180
+ elem_ty = ty_arg.ty.to_hugr(ctx)
184
181
  hugr_arg = len_arg.to_hugr(ctx)
185
182
 
186
- return hugr.std.collections.value_array.ValueArray(elem_ty, hugr_arg)
183
+ return hugr.std.collections.borrow_array.BorrowArray(elem_ty, hugr_arg)
187
184
 
188
185
 
189
186
  def _frozenarray_to_hugr(args: Sequence[Argument], ctx: ToHugrContext) -> ht.Type:
@@ -8,6 +8,7 @@ from guppylang_internals.tys.var import BoundVar, ExistentialVar
8
8
 
9
9
  if TYPE_CHECKING:
10
10
  from guppylang_internals.tys.arg import ConstArg
11
+ from guppylang_internals.tys.subst import Subst
11
12
  from guppylang_internals.tys.ty import Type
12
13
 
13
14
 
@@ -39,6 +40,11 @@ class ConstBase(Transformable["Const"], ABC):
39
40
  """The existential type variables contained in this constant."""
40
41
  return set()
41
42
 
43
+ @property
44
+ def bound_vars(self) -> set[BoundVar]:
45
+ """The bound type variables contained in this constant."""
46
+ return self.ty.bound_vars
47
+
42
48
  def __str__(self) -> str:
43
49
  from guppylang_internals.tys.printing import TypePrinter
44
50
 
@@ -48,16 +54,18 @@ class ConstBase(Transformable["Const"], ABC):
48
54
  """Accepts a visitor on this constant."""
49
55
  visitor.visit(self)
50
56
 
51
- def transform(self, transformer: Transformer, /) -> "Const":
52
- """Accepts a transformer on this constant."""
53
- return transformer.transform(self) or self.cast()
54
-
55
57
  def to_arg(self) -> "ConstArg":
56
58
  """Wraps this constant into a type argument."""
57
59
  from guppylang_internals.tys.arg import ConstArg
58
60
 
59
61
  return ConstArg(self.cast())
60
62
 
63
+ def substitute(self, subst: "Subst") -> "Const":
64
+ """Substitutes existential variables in this constant."""
65
+ from guppylang_internals.tys.subst import Substituter
66
+
67
+ return self.transform(Substituter(subst))
68
+
61
69
 
62
70
  @dataclass(frozen=True)
63
71
  class ConstValue(ConstBase):
@@ -74,6 +82,10 @@ class ConstValue(ConstBase):
74
82
  """Casts an implementor of `ConstBase` into a `Const`."""
75
83
  return self
76
84
 
85
+ def transform(self, transformer: Transformer, /) -> "Const":
86
+ """Accepts a transformer on this constant."""
87
+ return transformer.transform(self) or self
88
+
77
89
 
78
90
  @dataclass(frozen=True)
79
91
  class BoundConstVar(BoundVar, ConstBase):
@@ -84,10 +96,21 @@ class BoundConstVar(BoundVar, ConstBase):
84
96
  `BoundConstVar(idx=0)`.
85
97
  """
86
98
 
99
+ @property
100
+ def bound_vars(self) -> set[BoundVar]:
101
+ """The bound type variables contained in this constant."""
102
+ return {self} | self.ty.bound_vars
103
+
87
104
  def cast(self) -> "Const":
88
105
  """Casts an implementor of `ConstBase` into a `Const`."""
89
106
  return self
90
107
 
108
+ def transform(self, transformer: Transformer, /) -> "Const":
109
+ """Accepts a transformer on this constant."""
110
+ return transformer.transform(self) or BoundConstVar(
111
+ transformer.transform(self.ty) or self.ty, self.display_name, self.idx
112
+ )
113
+
91
114
 
92
115
  @dataclass(frozen=True)
93
116
  class ExistentialConstVar(ExistentialVar, ConstBase):
@@ -110,5 +133,11 @@ class ExistentialConstVar(ExistentialVar, ConstBase):
110
133
  """Casts an implementor of `ConstBase` into a `Const`."""
111
134
  return self
112
135
 
136
+ def transform(self, transformer: Transformer, /) -> "Const":
137
+ """Accepts a transformer on this constant."""
138
+ return transformer.transform(self) or ExistentialConstVar(
139
+ transformer.transform(self.ty) or self.ty, self.display_name, self.id
140
+ )
141
+
113
142
 
114
143
  Const: TypeAlias = ConstValue | BoundConstVar | ExistentialConstVar
@@ -5,7 +5,7 @@ from guppylang_internals.diagnostic import Error, Help, Note
5
5
 
6
6
  if TYPE_CHECKING:
7
7
  from guppylang_internals.definition.parameter import ParamDef
8
- from guppylang_internals.tys.ty import Type
8
+ from guppylang_internals.tys.ty import Type, UnitaryFlags
9
9
 
10
10
 
11
11
  @dataclass(frozen=True)
@@ -182,3 +182,25 @@ class InvalidFlagError(Error):
182
182
  class FlagNotAllowedError(Error):
183
183
  title: ClassVar[str] = "Invalid annotation"
184
184
  span_label: ClassVar[str] = "`@` type annotations are not allowed in this position"
185
+
186
+
187
+ @dataclass(frozen=True)
188
+ class UnitaryCallError(Error):
189
+ title: ClassVar[str] = "Unitary constraint violation"
190
+ span_label: ClassVar[str] = (
191
+ "This function cannot be called in a {render_flags} context"
192
+ )
193
+ flags: "UnitaryFlags"
194
+
195
+ @property
196
+ def render_flags(self) -> str:
197
+ from guppylang_internals.tys.ty import UnitaryFlags
198
+
199
+ if self.flags == UnitaryFlags.Dagger:
200
+ return "dagger"
201
+ elif self.flags == UnitaryFlags.Control:
202
+ return "control"
203
+ elif self.flags == UnitaryFlags.Power:
204
+ return "power"
205
+ else:
206
+ return "unitary"