guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,271 @@
1
+ """Compilers building array functions on top of hugr standard operations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable
6
+ from dataclasses import dataclass
7
+ from typing import TYPE_CHECKING, Final, TypeVar
8
+
9
+ import hugr.std.collections
10
+ import hugr.std.int
11
+ import hugr.std.prelude
12
+ from hugr import Node, Wire, ops
13
+ from hugr import tys as ht
14
+ from hugr import val as hv
15
+
16
+ from guppylang_internals.compiler.core import CompilerContext, GlobalConstId
17
+ from guppylang_internals.definition.custom import (
18
+ CustomCallCompiler,
19
+ CustomInoutCallCompiler,
20
+ )
21
+ from guppylang_internals.definition.value import CallReturnWires
22
+ from guppylang_internals.error import InternalGuppyError
23
+ from guppylang_internals.nodes import ExitKind
24
+ from guppylang_internals.tys.subst import Inst
25
+
26
+ if TYPE_CHECKING:
27
+ from collections.abc import Callable
28
+
29
+ from hugr.build import function as hf
30
+ from hugr.build.dfg import DfBase
31
+
32
+ from guppylang_internals.tys.common import ToHugrContext
33
+ from guppylang_internals.tys.subst import Inst
34
+
35
+
36
+ # --------------------------------------------
37
+ # --------------- prelude --------------------
38
+ # --------------------------------------------
39
+
40
+
41
def error_type() -> ht.ExtType:
    """Return the hugr type used for prelude error values.

    This is the parameterless ``error`` type of the hugr prelude extension.
    """
    error_def = hugr.std.PRELUDE.types["error"]
    return error_def.instantiate([])
44
+
45
+
46
@dataclass
class ErrorVal(hv.ExtensionValue):
    """Custom constant value for a prelude error.

    Attributes:
        signal: Integer signal code carried by the error.
        message: Error message carried by the error.
    """

    signal: int
    message: str

    def to_value(self) -> hv.Extension:
        """Serialise as a ``ConstError`` extension value of type `error_type()`."""
        name = "ConstError"
        payload = {"signal": self.signal, "message": self.message}
        return hv.Extension(name, typ=error_type(), val=payload)

    def __str__(self) -> str:
        return f"Error({self.signal}): {self.message}"
60
+
61
+
62
def panic(
    inputs: list[ht.Type], outputs: list[ht.Type], kind: ExitKind = ExitKind.Panic
) -> ops.ExtOp:
    """Build a prelude ``panic`` (or ``exit``) operation.

    Args:
        inputs: Types of the extra values consumed by the operation.
        outputs: Types of the values the operation is declared to produce.
        kind: `ExitKind.Panic` selects the ``panic`` op; any other kind selects
            the ``exit`` op.

    Returns:
        An `ops.ExtOp` whose signature takes an error value followed by
        ``inputs`` and yields ``outputs``.
    """
    op_name = "exit"
    if kind == ExitKind.Panic:
        op_name = "panic"
    op_def = hugr.std.PRELUDE.get_op(op_name)

    in_row = ht.ListArg([ht.TypeTypeArg(t) for t in inputs])
    out_row = ht.ListArg([ht.TypeTypeArg(t) for t in outputs])
    type_args: list[ht.TypeArg] = [in_row, out_row]

    signature = ht.FunctionType([error_type(), *inputs], outputs)
    return ops.ExtOp(op_def, signature, type_args)
74
+
75
+
76
+ # ------------------------------------------------------
77
+ # --------- Custom compilers for non-native ops --------
78
+ # ------------------------------------------------------
79
+
80
+
81
def build_panic(
    builder: DfBase[P],
    in_tys: ht.TypeRow,
    out_tys: ht.TypeRow,
    err: Wire,
    *args: Wire,
) -> Node:
    """Add a prelude panic node to ``builder``.

    Args:
        builder: Dataflow builder the node is added to.
        in_tys: Types of the extra values in ``args`` consumed by the panic.
        out_tys: Output types the panic node is declared to produce.
        err: Wire carrying the error value handed to the panic.
        *args: Wires consumed by the panic, matching ``in_tys``.

    Returns:
        The newly added panic node.
    """
    panic_op = panic(in_tys, out_tys, kind=ExitKind.Panic)
    node_inputs = (err, *args)
    return builder.add_op(panic_op, *node_inputs)
91
+
92
+
93
def build_error(builder: DfBase[P], signal: int, msg: str) -> Wire:
    """Add a constant `ErrorVal` to ``builder`` and return a wire loading it."""
    const_node = builder.add_const(ErrorVal(signal, msg))
    return builder.load(const_node)
97
+
98
+
99
+ # TODO: Common up build_unwrap_right and build_unwrap_left below once
100
+ # https://github.com/CQCL/hugr/issues/1596 is fixed
101
+
102
+
103
def build_unwrap_right(
    builder: DfBase[P], either: Wire, error_msg: str, error_signal: int = 1
) -> Node:
    """Extract the right variant of a two-variant sum value.

    A conditional is built over ``either``. The right case (tag 1) forwards its
    values unchanged. The left case (tag 0) panics with an error built from
    ``error_signal`` and ``error_msg``, consuming the left values.

    Returns:
        The conditional node whose outputs are the right variant's values.
    """
    cond = builder.add_conditional(either)
    sum_ty = builder.hugr.port_type(either.out_port())
    assert isinstance(sum_ty, ht.Sum)
    left_row, right_row = sum_ty.variant_rows

    with cond.add_case(0) as left_case:
        err_wire = build_error(left_case, error_signal, error_msg)
        panic_node = build_panic(
            left_case, left_row, right_row, err_wire, *left_case.inputs()
        )
        left_case.set_outputs(*panic_node)

    with cond.add_case(1) as right_case:
        right_case.set_outputs(*right_case.inputs())

    return cond.to_node()
119
+
120
+
121
# Type variable for the parent op of the dataflow builders taken by the helpers in
# this module. Annotations earlier in the module that mention `P` only work because
# of `from __future__ import annotations`.
P = TypeVar("P", bound=ops.DfParentOp)


def build_unwrap_left(
    builder: DfBase[P], either: Wire, error_msg: str, error_signal: int = 1
) -> Node:
    """Extract the left variant of a two-variant sum value.

    Mirror image of `build_unwrap_right`. The left case (tag 0) forwards its
    values unchanged. The right case (tag 1) panics with an error built from
    ``error_signal`` and ``error_msg``, consuming the right values.

    Returns:
        The conditional node whose outputs are the left variant's values.
    """
    cond = builder.add_conditional(either)
    sum_ty = builder.hugr.port_type(either.out_port())
    assert isinstance(sum_ty, ht.Sum)
    left_row, right_row = sum_ty.variant_rows

    with cond.add_case(0) as left_case:
        left_case.set_outputs(*left_case.inputs())

    with cond.add_case(1) as right_case:
        err_wire = build_error(right_case, error_signal, error_msg)
        panic_node = build_panic(
            right_case, right_row, left_row, err_wire, *right_case.inputs()
        )
        right_case.set_outputs(*panic_node)

    return cond.to_node()
140
+
141
+
142
def build_unwrap(
    builder: DfBase[P], option: Wire, error_msg: str, error_signal: int = 1
) -> Node:
    """Unwraps a `hugr.tys.Option` value, panicking with the given message and
    signal if the option is empty (`None`).

    The present value lives in the right variant (tag 1) of the option's sum, so
    this delegates to `build_unwrap_right`.

    Returns:
        The conditional node whose outputs are the unwrapped value(s).
    """
    return build_unwrap_right(builder, option, error_msg, error_signal)
149
+
150
+
151
def build_expect_none(
    builder: DfBase[P], option: Wire, error_msg: str, error_signal: int = 1
) -> Node:
    """Check that a `hugr.tys.Option` value is empty.

    This delegates to `build_unwrap_left`. The empty (tag 0) variant passes
    through, while a present value triggers a panic carrying ``error_msg`` and
    ``error_signal``.
    """
    return build_unwrap_left(
        builder, option, error_msg=error_msg, error_signal=error_signal
    )
158
+
159
+
160
class MemSwapCompiler(CustomCallCompiler):
    """Compiler for the `mem_swap` function.

    No nodes are emitted. The two inout arguments are simply handed back in
    swapped order.
    """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        first, second = args
        swapped = [second, first]
        return CallReturnWires(regular_returns=[], inout_returns=swapped)

    def compile(self, args: list[Wire]) -> list[Wire]:
        # Only the inout-aware entry point is supported for this function.
        raise InternalGuppyError("Call compile_with_inouts instead")
169
+
170
+
171
# Global identifier under which `unwrap_result` declares its shared, polymorphic
# helper function.
UNWRAP_RESULT: Final[GlobalConstId] = GlobalConstId.fresh("unwrap_result")


def _build_unwrap_result(func: hf.Function, result_type_var: ht.Variable) -> None:
    """Fill in the body of the generic "unwrap result" helper ``func``.

    The function's only input is a two-variant sum. Variant 0 carries an error
    value and leads to a panic. Variant 1 carries the result, which is returned.

    Args:
        func: Function builder whose body is populated.
        result_type_var: Type of the result, used as the nominal output type of
            the panic node.
    """
    cond = func.add_conditional(func.inputs()[0])

    with cond.add_case(0) as err_case:
        [err_wire] = list(err_case.inputs())
        # The error wire is both the panic's error argument and its single extra
        # input of type `error_type()`.
        panic_node = build_panic(
            err_case, [error_type()], [result_type_var], err_wire, *err_case.inputs()
        )
        err_case.set_outputs(*panic_node)

    with cond.add_case(1) as ok_case:
        ok_case.set_outputs(*ok_case.inputs())

    func.set_outputs(*cond.outputs())
185
+
186
+
187
def unwrap_result(
    builder: DfBase[P],
    ctx: CompilerContext,
    either: Wire,
) -> Wire:
    """Builds or retrieves and then calls a function that unwraps an `hugr.tys.Either`
    value, panicking if the result is an error.

    A single polymorphic helper (keyed by `UNWRAP_RESULT`) is declared once per
    compiler context and then instantiated at the concrete result type.

    Args:
        builder: Dataflow builder in which the call is emitted.
        ctx: Compiler context used to declare or look up the global helper.
        either: Wire carrying an `Either` whose left variant holds the error and
            whose right variant holds exactly one result value.

    Returns:
        The wire carrying the unwrapped result value.
    """
    either_ty = builder.hugr.port_type(either.out_port())
    assert isinstance(either_ty, ht.Either)
    [error_tys, result_tys] = either_ty.variant_rows
    # The helper is polymorphic over a single type variable, so only single-value
    # results are supported. Fail with a clear message rather than an opaque
    # TypeError from `TypeTypeArg(*result_tys)`.
    assert len(result_tys) == 1, (
        f"unwrap_result expects exactly one result type, got {len(result_tys)}"
    )
    [result_ty] = result_tys
    # Construct the function signature for unwrapping a result of type T.
    func_ty = ht.PolyFuncType(
        params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
        body=ht.FunctionType(
            input=[ht.Either(error_tys, [ht.Variable(0, ht.TypeBound.Linear)])],
            output=[ht.Variable(0, ht.TypeBound.Linear)],
        ),
    )
    # Build global unwrap result function if it doesn't already exist.
    func, already_exists = ctx.declare_global_func(UNWRAP_RESULT, func_ty)
    if not already_exists:
        _build_unwrap_result(func, ht.Variable(0, ht.TypeBound.Linear))
    # Call the global function, instantiated at the concrete result type.
    concrete_ty = ht.FunctionType(
        input=[ht.Either(error_tys, result_tys)], output=result_tys
    )
    func_call = builder.call(
        func.parent_node,
        either,
        instantiation=concrete_ty,
        type_args=[ht.TypeTypeArg(result_ty)],
    )
    [result] = list(func_call.outputs())
    return result
223
+
224
+
225
class UnwrapOpCompiler(CustomInoutCallCompiler):
    """Compiler for operations that require unwrapping a result which could potentially
    cause a panic.

    Args:
        op: Factory building the HUGR operation to emit. It is called with a
            function signature whose single output is an ``Either<error, result>``
            value, the call's type arguments, and the hugr conversion context.
    """

    op: Callable[[ht.FunctionType, Inst, ToHugrContext], ops.DataflowOp]

    def __init__(
        self, op: Callable[[ht.FunctionType, Inst, ToHugrContext], ops.DataflowOp]
    ) -> None:
        self.op = op

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        """Emit the wrapped op and unwrap its ``Either`` output, panicking on error."""
        assert len(self.ty.output) == 1
        # To instantiate the op we need a function signature that wraps the output of
        # the function that is being compiled into an either type.
        opt_func_type = ht.FunctionType(
            input=self.ty.input,
            output=[ht.Either([error_type()], self.ty.output)],
        )
        op = self.op(opt_func_type, self.type_args, self.ctx)
        either = self.builder.add_op(op, *args)
        result = unwrap_result(self.builder, self.ctx, either)
        return CallReturnWires(regular_returns=[result], inout_returns=[])
252
+
253
+
254
class BarrierCompiler(CustomCallCompiler):
    """Compiler for the `barrier` function.

    Emits one prelude ``Barrier`` op over all arguments. Each of its outputs is
    returned as the updated inout value of the corresponding argument.
    """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        tys: list[ht.Type] = []
        for arg in args:
            ty = self.builder.hugr.port_type(arg.out_port())
            # Skipping an untyped wire would make the barrier's signature (and the
            # returned inout wires) disagree with `args`, so fail loudly instead.
            if ty is None:
                raise InternalGuppyError("Barrier argument has no known port type")
            tys.append(ty)

        op = hugr.std.prelude.PRELUDE_EXTENSION.get_op("Barrier").instantiate(
            [ht.ListArg([ht.TypeTypeArg(ty) for ty in tys])]
        )

        barrier_n = self.builder.add_op(op, *args)

        return CallReturnWires(
            regular_returns=[], inout_returns=[barrier_n[i] for i in range(len(tys))]
        )

    def compile(self, args: list[Wire]) -> list[Wire]:
        raise InternalGuppyError("Call compile_with_inouts instead")
@@ -0,0 +1,48 @@
1
+ from hugr import Wire
2
+ from hugr import tys as ht
3
+ from hugr.std.int import int_t
4
+
5
+ from guppylang_internals.definition.custom import CustomInoutCallCompiler
6
+ from guppylang_internals.definition.value import CallReturnWires
7
+ from guppylang_internals.std._internal.compiler.arithmetic import inarrow_s, iwiden_s
8
+ from guppylang_internals.std._internal.compiler.prelude import build_unwrap_right
9
+ from guppylang_internals.std._internal.compiler.quantum import (
10
+ RNGCONTEXT_T,
11
+ )
12
+ from guppylang_internals.std._internal.compiler.tket_exts import (
13
+ QSYSTEM_RANDOM_EXTENSION,
14
+ )
15
+ from guppylang_internals.std._internal.util import external_op
16
+
17
+
18
class RandomIntCompiler(CustomInoutCallCompiler):
    """Compiler drawing a random integer from an RNG context.

    Emits the qsystem random extension's ``RandomInt`` op, which yields a 32-bit
    integer (``int_t(5)``), then sign-extends it to 64 bits.
    """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        (rng_ctx,) = args
        random_int_op = external_op("RandomInt", [], ext=QSYSTEM_RANDOM_EXTENSION)(
            ht.FunctionType([RNGCONTEXT_T], [int_t(5), RNGCONTEXT_T]), [], self.ctx
        )
        raw_value, new_ctx = self.builder.add_op(random_int_op, rng_ctx)
        (widened,) = self.builder.add_op(iwiden_s(5, 6), raw_value)
        return CallReturnWires(regular_returns=[widened], inout_returns=[new_ctx])
29
+
30
+
31
class RandomIntBoundedCompiler(CustomInoutCallCompiler):
    """Compiler drawing a bounded random integer from an RNG context.

    The steps are:
    - Narrow the 64-bit bound to 32 bits, unwrapping the result and panicking if
      the narrowing fails.
    - Pass the narrowed bound to ``RandomIntBounded``.
    - Sign-extend the 32-bit result back to 64 bits.
    """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        rng_ctx, bound64 = args
        narrowed = self.builder.add_op(inarrow_s(6, 5), bound64)
        bound32 = build_unwrap_right(
            self.builder, narrowed, "bound must be a 32-bit integer"
        )
        bounded_op = external_op(
            "RandomIntBounded", [], ext=QSYSTEM_RANDOM_EXTENSION
        )(
            ht.FunctionType([RNGCONTEXT_T, int_t(5)], [int_t(5), RNGCONTEXT_T]),
            [],
            self.ctx,
        )
        raw_value, new_ctx = self.builder.add_op(bounded_op, rng_ctx, bound32)
        (widened,) = self.builder.add_op(iwiden_s(5, 6), raw_value)
        return CallReturnWires(regular_returns=[widened], inout_returns=[new_ctx])
@@ -0,0 +1,118 @@
1
+ """Compilers building list functions on top of hugr standard operations, that involve
2
+ multiple nodes.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from hugr import Wire, ops
8
+ from hugr import ext as he
9
+ from hugr import tys as ht
10
+ from hugr.std.float import FLOAT_T
11
+
12
+ from guppylang_internals.definition.custom import CustomInoutCallCompiler
13
+ from guppylang_internals.definition.value import CallReturnWires
14
+ from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, make_opaque
15
+ from guppylang_internals.std._internal.compiler.tket_exts import (
16
+ QSYSTEM_RANDOM_EXTENSION,
17
+ QUANTUM_EXTENSION,
18
+ ROTATION_EXTENSION,
19
+ )
20
+
21
+ # ----------------------------------------------
22
+ # --------- tket.* extensions -----------------
23
+ # ----------------------------------------------
24
+
25
+
26
# RNG context type from the qsystem random extension.
RNGCONTEXT_T_DEF = QSYSTEM_RANDOM_EXTENSION.get_type("context")
RNGCONTEXT_T = ht.ExtType(RNGCONTEXT_T_DEF)

# Rotation type from the tket rotation extension.
ROTATION_T_DEF = ROTATION_EXTENSION.get_type("rotation")
ROTATION_T = ht.ExtType(ROTATION_T_DEF)


def from_halfturns_unchecked() -> ops.ExtOp:
    """Return the rotation extension's ``from_halfturns_unchecked`` op.

    Its signature is ``float -> rotation``.
    """
    op_def = ROTATION_EXTENSION.get_op("from_halfturns_unchecked")
    signature = ht.FunctionType([FLOAT_T], [ROTATION_T])
    return ops.ExtOp(op_def, signature)
38
+
39
+
40
+ # ------------------------------------------------------
41
+ # --------- Custom compilers for non-native ops --------
42
+ # ------------------------------------------------------
43
+
44
+
45
class InoutMeasureCompiler(CustomInoutCallCompiler):
    """Compiler for measurement functions that keep their qubit as an inout
    argument, such as `project_z`.

    The underlying op yields a plain hugr ``Bool``. That bit is converted to a
    tket.bool via `make_opaque` before being returned.

    Args:
        opname: Name of the measurement op; falls back to ``"Measure"`` if falsy.
        ext: Extension providing the op; falls back to the quantum extension.
    """

    opname: str
    ext: he.Extension

    def __init__(self, opname: str | None = None, ext: he.Extension | None = None):
        self.opname = opname if opname else "Measure"
        self.ext = ext if ext else QUANTUM_EXTENSION

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        # Local import (presumably to avoid an import cycle with `util`).
        from guppylang_internals.std._internal.util import quantum_op

        (qubit,) = args
        measure = quantum_op(self.opname, ext=self.ext)(
            ht.FunctionType([ht.Qubit], [ht.Qubit, ht.Bool]), [], self.ctx
        )
        qubit, raw_bit = self.builder.add_op(measure, qubit)
        opaque_bit = self.builder.add_op(make_opaque(), raw_bit)
        return CallReturnWires(regular_returns=[opaque_bit], inout_returns=[qubit])
68
+
69
+
70
class InoutMeasureResetCompiler(CustomInoutCallCompiler):
    """Compiler for measurement functions that keep their qubit as an inout
    argument, where the underlying op already returns an `OpaqueBool`.

    Unlike `InoutMeasureCompiler`, no Bool-to-tket.bool conversion is emitted.

    Args:
        opname: Name of the measurement op; falls back to ``"Measure"`` if falsy.
        ext: Extension providing the op; falls back to the quantum extension.
    """

    opname: str
    ext: he.Extension

    def __init__(self, opname: str | None = None, ext: he.Extension | None = None):
        self.opname = opname if opname else "Measure"
        self.ext = ext if ext else QUANTUM_EXTENSION

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        # Local import (presumably to avoid an import cycle with `util`).
        from guppylang_internals.std._internal.util import quantum_op

        (qubit,) = args
        measure = quantum_op(self.opname, ext=self.ext)(
            ht.FunctionType([ht.Qubit], [ht.Qubit, OpaqueBool]), [], self.ctx
        )
        qubit, opaque_bit = self.builder.add_op(measure, qubit)
        return CallReturnWires(regular_returns=[opaque_bit], inout_returns=[qubit])
92
+
93
+
94
class RotationCompiler(CustomInoutCallCompiler):
    """Compiler for rotation gates acting on one or more inout qubits.

    The last argument is the angle. It is unpacked as a one-element tuple holding
    a float (treated as half-turns) and converted to a rotation. The rotation and
    the qubits are then passed to the named quantum op.
    """

    opname: str

    def __init__(self, opname: str):
        self.opname = opname

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        # Local import (presumably to avoid an import cycle with `util`).
        from guppylang_internals.std._internal.util import quantum_op

        *qubits, angle = args
        (halfturns,) = self.builder.add_op(ops.UnpackTuple([FLOAT_T]), angle)
        (rotation,) = self.builder.add_op(from_halfturns_unchecked(), halfturns)

        gate_sig = ht.FunctionType(
            [*(ht.Qubit for _ in qubits), ROTATION_T], [ht.Qubit for _ in qubits]
        )
        gate = quantum_op(self.opname)(gate_sig, [], self.ctx)
        gate_node = self.builder.add_op(gate, *qubits, rotation)
        return CallReturnWires(regular_returns=[], inout_returns=list(gate_node))
@@ -0,0 +1,55 @@
1
+ from dataclasses import dataclass
2
+
3
+ from hugr import ops
4
+ from hugr import tys as ht
5
+ from hugr import val as hv
6
+
7
+ from guppylang_internals.std._internal.compiler.tket_exts import BOOL_EXTENSION
8
+
9
# Type definition and hugr type of the opaque boolean from the tket bool extension.
BOOL_DEF = BOOL_EXTENSION.get_type("bool")
OpaqueBool = ht.ExtType(BOOL_DEF)


def read_bool() -> ops.ExtOp:
    """Return the bool extension's ``read`` op: `OpaqueBool` -> hugr ``Bool``."""
    op_def = BOOL_EXTENSION.get_op("read")
    signature = ht.FunctionType([OpaqueBool], [ht.Bool])
    return ops.ExtOp(op_def, signature)
18
+
19
+
20
def make_opaque() -> ops.ExtOp:
    """Return the bool extension's ``make_opaque`` op: hugr ``Bool`` -> `OpaqueBool`."""
    op_def = BOOL_EXTENSION.get_op("make_opaque")
    signature = ht.FunctionType([ht.Bool], [OpaqueBool])
    return ops.ExtOp(op_def, signature)
25
+
26
+
27
def not_op() -> ops.ExtOp:
    """Return the bool extension's ``not`` op: `OpaqueBool` -> `OpaqueBool`."""
    op_def = BOOL_EXTENSION.get_op("not")
    signature = ht.FunctionType([OpaqueBool], [OpaqueBool])
    return ops.ExtOp(op_def, signature)
32
+
33
+
34
@dataclass
class OpaqueBoolVal(hv.ExtensionValue):
    """Constant value of the opaque boolean type `OpaqueBool`.

    Attributes:
        v: The wrapped Python boolean.
    """

    v: bool

    def to_value(self) -> hv.Extension:
        """Serialise as a ``ConstBool`` extension value of type `OpaqueBool`."""
        return hv.Extension(
            "ConstBool",
            typ=OpaqueBool,
            val=self.v,
            extensions=[BOOL_EXTENSION.name],
        )

    def __str__(self) -> str:
        return str(self.v)


# Pre-built opaque boolean constants.
OPAQUE_TRUE = OpaqueBoolVal(True)
OPAQUE_FALSE = OpaqueBoolVal(False)
@@ -0,0 +1,59 @@
1
+ from dataclasses import dataclass
2
+
3
+ from hugr import val
4
+ from tket_exts import (
5
+ debug,
6
+ futures,
7
+ opaque_bool,
8
+ qsystem,
9
+ qsystem_random,
10
+ qsystem_utils,
11
+ quantum,
12
+ result,
13
+ rotation,
14
+ wasm,
15
+ )
16
+
17
# One instance of each tket extension definition, created at import time.
BOOL_EXTENSION = opaque_bool()
DEBUG_EXTENSION = debug()
FUTURES_EXTENSION = futures()
QSYSTEM_EXTENSION = qsystem()
QSYSTEM_RANDOM_EXTENSION = qsystem_random()
QSYSTEM_UTILS_EXTENSION = qsystem_utils()
QUANTUM_EXTENSION = quantum()
RESULT_EXTENSION = result()
ROTATION_EXTENSION = rotation()
WASM_EXTENSION = wasm()

# All of the tket extensions above, in a fixed order.
TKET_EXTENSIONS = list(
    (
        BOOL_EXTENSION,
        DEBUG_EXTENSION,
        FUTURES_EXTENSION,
        QSYSTEM_EXTENSION,
        QSYSTEM_RANDOM_EXTENSION,
        QSYSTEM_UTILS_EXTENSION,
        QUANTUM_EXTENSION,
        RESULT_EXTENSION,
        ROTATION_EXTENSION,
        WASM_EXTENSION,
    )
)
40
+
41
+
42
@dataclass(frozen=True)
class ConstWasmModule(val.ExtensionValue):
    """Python wrapper for the tket ConstWasmModule type.

    Attributes:
        wasm_file: WASM file identifier, serialised under the ``"name"`` key.
        wasm_hash: WASM module hash, serialised under the ``"hash"`` key.
    """

    wasm_file: str
    wasm_hash: int

    def to_value(self) -> val.Extension:
        """Serialise as a ``tket.wasm.ConstWasmModule`` of the WASM module type."""
        module_ty = WASM_EXTENSION.get_type("module").instantiate([])
        return val.Extension(
            "tket.wasm.ConstWasmModule",
            typ=module_ty,
            val={"name": self.wasm_file, "hash": self.wasm_hash},
            extensions=["tket.wasm"],
        )

    def __str__(self) -> str:
        return (
            f"ConstWasmModule(wasm_file={self.wasm_file}, wasm_hash={self.wasm_hash})"
        )
@@ -0,0 +1,135 @@
1
+ from hugr import Wire, ops
2
+ from hugr import tys as ht
3
+
4
+ from guppylang_internals.definition.custom import CustomInoutCallCompiler
5
+ from guppylang_internals.definition.value import CallReturnWires
6
+ from guppylang_internals.error import InternalGuppyError
7
+ from guppylang_internals.nodes import GlobalCall
8
+ from guppylang_internals.std._internal.compiler.arithmetic import convert_itousize
9
+ from guppylang_internals.std._internal.compiler.prelude import build_unwrap
10
+ from guppylang_internals.std._internal.compiler.tket_exts import (
11
+ FUTURES_EXTENSION,
12
+ WASM_EXTENSION,
13
+ ConstWasmModule,
14
+ )
15
+ from guppylang_internals.tys.builtin import (
16
+ wasm_module_info,
17
+ )
18
+ from guppylang_internals.tys.ty import (
19
+ FunctionType,
20
+ )
21
+
22
+
23
class WasmModuleInitCompiler(CustomInoutCallCompiler):
    """Compiler for initialising WASM modules.
    Calls tket's "get_context" and unwraps the `Option` result.
    Returns a `tket.wasm.context` wire.
    """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        # The single integer argument is converted to the `usize` that
        # `get_context` takes (presumably a context identifier).
        assert len(args) == 1
        ctx_arg = args[0]
        ctx_wire = self.builder.add_op(convert_itousize(), ctx_arg)

        ctx_ty = WASM_EXTENSION.get_type("context").instantiate([])
        get_ctx_op = ops.ExtOp(
            WASM_EXTENSION.get_op("get_context"),
            ht.FunctionType([ht.USize()], [ht.Option(ctx_ty)]),
        )
        node = self.builder.add_op(get_ctx_op, ctx_wire)
        opt_w: Wire = node[0]
        # Panic if no context was obtained (the option is empty).
        err = "Failed to spawn WASM context"
        out_node = build_unwrap(self.builder, opt_w, err)
        return CallReturnWires(regular_returns=[out_node], inout_returns=[])
45
+
46
+
47
class WasmModuleDiscardCompiler(CustomInoutCallCompiler):
    """Compiler for discarding WASM contexts.

    The single context argument is consumed by the WASM extension's
    ``dispose_context`` op, and nothing is returned.
    """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        assert len(args) == 1
        wasm_ctx = args[0]
        dispose = WASM_EXTENSION.get_op("dispose_context").instantiate([])
        self.builder.add_op(dispose, wasm_ctx)
        return CallReturnWires(regular_returns=[], inout_returns=[])
56
+
57
+
58
class WasmModuleCallCompiler(CustomInoutCallCompiler):
    """Compiler for WASM calls.

    When a wasm method is called in guppy, it is turned into these tket ops:
    * lookup: wasm.module -> wasm.func
    * call: wasm.context * wasm.func * inputs -> wasm.context * future
    * Read: future -> tuple of outputs, which is then unpacked

    For the wasm.module that we use in lookup, a constant is created for each
    call, using the wasm file information embedded in method's `self` argument.
    """

    fn_name: str

    def __init__(self, name: str) -> None:
        self.fn_name = name

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        # The arguments should be:
        # - a WASM context
        # - any args meant for the WASM function
        assert len(args) >= 1
        module_ty = WASM_EXTENSION.get_type("module").instantiate([])
        ctx_ty = WASM_EXTENSION.get_type("context").instantiate([])

        fn_name_arg = ht.StringArg(self.fn_name)
        # Function type without Inout context arg (for building)
        assert isinstance(self.node, GlobalCall)
        assert self.func is not None
        wasm_sig = FunctionType(
            self.func.ty.inputs[1:],
            self.func.ty.output,
        ).to_hugr(self.ctx)

        inputs_row_arg = ht.ListArg([ty.type_arg() for ty in wasm_sig.input])
        output_row_arg = ht.ListArg([ty.type_arg() for ty in wasm_sig.output])

        func_ty = WASM_EXTENSION.get_type("func").instantiate(
            [inputs_row_arg, output_row_arg]
        )
        # The WASM call yields its outputs as a single tuple wrapped in a future.
        result_tuple_ty = ht.Tuple(*wasm_sig.output)
        future_ty = FUTURES_EXTENSION.get_type("Future").instantiate(
            [result_tuple_ty.type_arg()]
        )

        # Get the WASM module information from the type
        selfarg = self.func.ty.inputs[0].ty
        info = wasm_module_info(selfarg)
        if not info:
            raise InternalGuppyError(
                "Expected cached signature to have WASM module as first arg"
            )
        const_module = self.builder.add_const(ConstWasmModule(*info))
        wasm_module = self.builder.load(const_module)

        # Lookup the function we want
        lookup_op = WASM_EXTENSION.get_op("lookup").instantiate(
            [fn_name_arg, inputs_row_arg, output_row_arg],
            ht.FunctionType([module_ty], [func_ty]),
        )
        wasm_func = self.builder.add_op(lookup_op, wasm_module)

        # Call the function
        call_op = WASM_EXTENSION.get_op("call").instantiate(
            [inputs_row_arg, output_row_arg],
            ht.FunctionType([ctx_ty, func_ty, *wasm_sig.input], [ctx_ty, future_ty]),
        )
        ctx, future = self.builder.add_op(call_op, args[0], wasm_func, *args[1:])

        # Resolve the future and unpack the result tuple into individual wires.
        read_op = FUTURES_EXTENSION.get_op("Read").instantiate(
            [result_tuple_ty.type_arg()],
            ht.FunctionType([future_ty], [result_tuple_ty]),
        )
        read_node = self.builder.add_op(read_op, future)
        tuple_wires: list[Wire] = list(read_node[:])
        unpack_node = self.builder.add_op(
            ops.UnpackTuple(wasm_sig.output), *tuple_wires
        )
        outputs: list[Wire] = list(unpack_node[:])

        return CallReturnWires(regular_returns=outputs, inout_returns=[ctx])
File without changes