guppylang-internals 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +3 -0
- guppylang_internals/ast_util.py +350 -0
- guppylang_internals/cfg/__init__.py +0 -0
- guppylang_internals/cfg/analysis.py +230 -0
- guppylang_internals/cfg/bb.py +221 -0
- guppylang_internals/cfg/builder.py +606 -0
- guppylang_internals/cfg/cfg.py +117 -0
- guppylang_internals/checker/__init__.py +0 -0
- guppylang_internals/checker/cfg_checker.py +388 -0
- guppylang_internals/checker/core.py +550 -0
- guppylang_internals/checker/errors/__init__.py +0 -0
- guppylang_internals/checker/errors/comptime_errors.py +106 -0
- guppylang_internals/checker/errors/generic.py +45 -0
- guppylang_internals/checker/errors/linearity.py +300 -0
- guppylang_internals/checker/errors/type_errors.py +344 -0
- guppylang_internals/checker/errors/wasm.py +34 -0
- guppylang_internals/checker/expr_checker.py +1413 -0
- guppylang_internals/checker/func_checker.py +269 -0
- guppylang_internals/checker/linearity_checker.py +821 -0
- guppylang_internals/checker/stmt_checker.py +447 -0
- guppylang_internals/compiler/__init__.py +0 -0
- guppylang_internals/compiler/cfg_compiler.py +233 -0
- guppylang_internals/compiler/core.py +613 -0
- guppylang_internals/compiler/expr_compiler.py +989 -0
- guppylang_internals/compiler/func_compiler.py +97 -0
- guppylang_internals/compiler/hugr_extension.py +224 -0
- guppylang_internals/compiler/qtm_platform_extension.py +0 -0
- guppylang_internals/compiler/stmt_compiler.py +212 -0
- guppylang_internals/decorator.py +246 -0
- guppylang_internals/definition/__init__.py +0 -0
- guppylang_internals/definition/common.py +214 -0
- guppylang_internals/definition/const.py +74 -0
- guppylang_internals/definition/custom.py +492 -0
- guppylang_internals/definition/declaration.py +171 -0
- guppylang_internals/definition/extern.py +89 -0
- guppylang_internals/definition/function.py +302 -0
- guppylang_internals/definition/overloaded.py +150 -0
- guppylang_internals/definition/parameter.py +82 -0
- guppylang_internals/definition/pytket_circuits.py +405 -0
- guppylang_internals/definition/struct.py +392 -0
- guppylang_internals/definition/traced.py +151 -0
- guppylang_internals/definition/ty.py +51 -0
- guppylang_internals/definition/value.py +115 -0
- guppylang_internals/definition/wasm.py +61 -0
- guppylang_internals/diagnostic.py +523 -0
- guppylang_internals/dummy_decorator.py +76 -0
- guppylang_internals/engine.py +295 -0
- guppylang_internals/error.py +107 -0
- guppylang_internals/experimental.py +92 -0
- guppylang_internals/ipython_inspect.py +28 -0
- guppylang_internals/nodes.py +427 -0
- guppylang_internals/py.typed +0 -0
- guppylang_internals/span.py +150 -0
- guppylang_internals/std/__init__.py +0 -0
- guppylang_internals/std/_internal/__init__.py +0 -0
- guppylang_internals/std/_internal/checker.py +573 -0
- guppylang_internals/std/_internal/compiler/__init__.py +0 -0
- guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
- guppylang_internals/std/_internal/compiler/array.py +569 -0
- guppylang_internals/std/_internal/compiler/either.py +131 -0
- guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
- guppylang_internals/std/_internal/compiler/futures.py +30 -0
- guppylang_internals/std/_internal/compiler/list.py +348 -0
- guppylang_internals/std/_internal/compiler/mem.py +13 -0
- guppylang_internals/std/_internal/compiler/option.py +78 -0
- guppylang_internals/std/_internal/compiler/prelude.py +271 -0
- guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
- guppylang_internals/std/_internal/compiler/quantum.py +118 -0
- guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
- guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
- guppylang_internals/std/_internal/compiler/wasm.py +135 -0
- guppylang_internals/std/_internal/compiler.py +0 -0
- guppylang_internals/std/_internal/debug.py +95 -0
- guppylang_internals/std/_internal/util.py +271 -0
- guppylang_internals/tracing/__init__.py +0 -0
- guppylang_internals/tracing/builtins_mock.py +62 -0
- guppylang_internals/tracing/frozenlist.py +57 -0
- guppylang_internals/tracing/function.py +186 -0
- guppylang_internals/tracing/object.py +551 -0
- guppylang_internals/tracing/state.py +69 -0
- guppylang_internals/tracing/unpacking.py +194 -0
- guppylang_internals/tracing/util.py +86 -0
- guppylang_internals/tys/__init__.py +0 -0
- guppylang_internals/tys/arg.py +115 -0
- guppylang_internals/tys/builtin.py +382 -0
- guppylang_internals/tys/common.py +110 -0
- guppylang_internals/tys/const.py +114 -0
- guppylang_internals/tys/errors.py +178 -0
- guppylang_internals/tys/param.py +251 -0
- guppylang_internals/tys/parsing.py +425 -0
- guppylang_internals/tys/printing.py +174 -0
- guppylang_internals/tys/subst.py +112 -0
- guppylang_internals/tys/ty.py +876 -0
- guppylang_internals/tys/var.py +49 -0
- guppylang_internals-0.21.0.dist-info/METADATA +253 -0
- guppylang_internals-0.21.0.dist-info/RECORD +98 -0
- guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
- guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
"""Compilers building prelude functions on top of hugr standard operations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import TYPE_CHECKING, Final, TypeVar
|
|
8
|
+
|
|
9
|
+
import hugr.std.collections
|
|
10
|
+
import hugr.std.int
|
|
11
|
+
import hugr.std.prelude
|
|
12
|
+
from hugr import Node, Wire, ops
|
|
13
|
+
from hugr import tys as ht
|
|
14
|
+
from hugr import val as hv
|
|
15
|
+
|
|
16
|
+
from guppylang_internals.compiler.core import CompilerContext, GlobalConstId
|
|
17
|
+
from guppylang_internals.definition.custom import (
|
|
18
|
+
CustomCallCompiler,
|
|
19
|
+
CustomInoutCallCompiler,
|
|
20
|
+
)
|
|
21
|
+
from guppylang_internals.definition.value import CallReturnWires
|
|
22
|
+
from guppylang_internals.error import InternalGuppyError
|
|
23
|
+
from guppylang_internals.nodes import ExitKind
|
|
24
|
+
from guppylang_internals.tys.subst import Inst
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from collections.abc import Callable
|
|
28
|
+
|
|
29
|
+
from hugr.build import function as hf
|
|
30
|
+
from hugr.build.dfg import DfBase
|
|
31
|
+
|
|
32
|
+
from guppylang_internals.tys.common import ToHugrContext
|
|
33
|
+
from guppylang_internals.tys.subst import Inst
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# --------------------------------------------
|
|
37
|
+
# --------------- prelude --------------------
|
|
38
|
+
# --------------------------------------------
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def error_type() -> ht.ExtType:
    """Return the hugr type used for error values.

    This is the `error` type of the hugr prelude extension, instantiated with no
    type arguments.
    """
    prelude_error_def = hugr.std.PRELUDE.types["error"]
    return prelude_error_def.instantiate([])
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class ErrorVal(hv.ExtensionValue):
    """Custom constant value for a prelude error.

    Serialised as a `ConstError` extension value of type `error_type()`. The
    payload carries an integer signal and a message string.
    """

    # Integer signal code attached to the error.
    signal: int
    # Human-readable error message.
    message: str

    def to_value(self) -> hv.Extension:
        """Convert this error into a hugr `ConstError` extension constant."""
        name = "ConstError"
        payload = {"signal": self.signal, "message": self.message}
        return hv.Extension(name, typ=error_type(), val=payload)

    def __str__(self) -> str:
        return f"Error({self.signal}): {self.message}"
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def panic(
    inputs: list[ht.Type], outputs: list[ht.Type], kind: ExitKind = ExitKind.Panic
) -> ops.ExtOp:
    """Return a prelude operation that aborts execution.

    Args:
        inputs: Types of the extra values consumed by the operation. These come
            after the leading error value.
        outputs: Types of the values the operation declares as outputs.
        kind: `ExitKind.Panic` selects the prelude `panic` op. Any other kind
            selects the prelude `exit` op.

    Returns:
        An `ExtOp` with signature `[error, *inputs] -> outputs`.
    """
    op_name = "panic" if kind == ExitKind.Panic else "exit"

    def _row_arg(row: list[ht.Type]) -> ht.TypeArg:
        # A type row is passed to the op as a list of type arguments.
        return ht.ListArg([ht.TypeTypeArg(t) for t in row])

    type_args: list[ht.TypeArg] = [_row_arg(inputs), _row_arg(outputs)]
    signature = ht.FunctionType([error_type(), *inputs], outputs)
    return ops.ExtOp(hugr.std.PRELUDE.get_op(op_name), signature, type_args)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# ------------------------------------------------------
|
|
77
|
+
# --------- Custom compilers for non-native ops --------
|
|
78
|
+
# ------------------------------------------------------
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def build_panic(
    builder: DfBase[P],
    in_tys: ht.TypeRow,
    out_tys: ht.TypeRow,
    err: Wire,
    *args: Wire,
) -> Node:
    """Add a prelude `panic` node to `builder`.

    Args:
        builder: Dataflow builder to add the node to.
        in_tys: Types of the extra wires passed in `args`.
        out_tys: Output types the panic node declares.
        err: Wire carrying the error value.
        *args: Extra wires consumed by the panic. They must match `in_tys`.

    Returns:
        The newly added panic node.
    """
    panic_op = panic(in_tys, out_tys, ExitKind.Panic)
    wires = (err, *args)
    return builder.add_op(panic_op, *wires)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def build_error(builder: DfBase[P], signal: int, msg: str) -> Wire:
    """Add a constant error value to `builder` and return the wire that loads it.

    Args:
        builder: Dataflow builder that receives the constant and its load node.
        signal: Integer signal code of the error.
        msg: Error message.
    """
    const_node = builder.add_const(ErrorVal(signal, msg))
    return builder.load(const_node)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# TODO: Common up build_unwrap_right and build_unwrap_left below once
|
|
100
|
+
# https://github.com/CQCL/hugr/issues/1596 is fixed
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def build_unwrap_right(
    builder: DfBase[P], either: Wire, error_msg: str, error_signal: int = 1
) -> Node:
    """Unwrap the right variant of a two-variant sum (e.g. `hugr.tys.Either`).

    A conditional is emitted on `either`. In variant 0 (left) it panics with
    `error_msg`/`error_signal`. In variant 1 (right) it forwards the variant's
    values unchanged.

    Returns:
        The conditional node, whose outputs are the right variant's values.
    """
    cond = builder.add_conditional(either)
    sum_ty = builder.hugr.port_type(either.out_port())
    assert isinstance(sum_ty, ht.Sum)
    left_row, right_row = sum_ty.variant_rows

    # Left variant: consume its values and abort.
    with cond.add_case(0) as left_case:
        err_wire = build_error(left_case, error_signal, error_msg)
        panic_node = build_panic(
            left_case, left_row, right_row, err_wire, *left_case.inputs()
        )
        left_case.set_outputs(*panic_node)

    # Right variant: pass the payload straight through.
    with cond.add_case(1) as right_case:
        right_case.set_outputs(*right_case.inputs())

    return cond.to_node()
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
# Type variable for the parent operation of the dataflow builders (`DfBase[P]`)
# accepted by the helpers in this module.
P = TypeVar("P", bound=ops.DfParentOp)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def build_unwrap_left(
    builder: DfBase[P], either: Wire, error_msg: str, error_signal: int = 1
) -> Node:
    """Unwrap the left variant of a two-variant sum (e.g. `hugr.tys.Either`).

    A conditional is emitted on `either`. In variant 0 (left) it forwards the
    variant's values unchanged. In variant 1 (right) it panics with
    `error_msg`/`error_signal`.

    Returns:
        The conditional node, whose outputs are the left variant's values.
    """
    cond = builder.add_conditional(either)
    sum_ty = builder.hugr.port_type(either.out_port())
    assert isinstance(sum_ty, ht.Sum)
    left_row, right_row = sum_ty.variant_rows

    # Left variant: pass the payload straight through.
    with cond.add_case(0) as left_case:
        left_case.set_outputs(*left_case.inputs())

    # Right variant: consume its values and abort.
    with cond.add_case(1) as right_case:
        err_wire = build_error(right_case, error_signal, error_msg)
        panic_node = build_panic(
            right_case, right_row, left_row, err_wire, *right_case.inputs()
        )
        right_case.set_outputs(*panic_node)

    return cond.to_node()
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def build_unwrap(
    builder: DfBase[P], option: Wire, error_msg: str, error_signal: int = 1
) -> Node:
    """Unwrap a `hugr.tys.Option` value, panicking with `error_msg` if it is None.

    Implemented as `build_unwrap_right`: the value-carrying variant of the option
    is variant 1.
    """
    return build_unwrap_right(
        builder, option, error_msg, error_signal=error_signal
    )
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def build_expect_none(
    builder: DfBase[P], option: Wire, error_msg: str, error_signal: int = 1
) -> Node:
    """Check that a `hugr.tys.Option` value is None, panicking with `error_msg`
    otherwise.

    Implemented as `build_unwrap_left`, so the panic happens in variant 1.
    """
    return build_unwrap_left(
        builder, option, error_msg, error_signal=error_signal
    )
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class MemSwapCompiler(CustomCallCompiler):
    """Compiler for the `mem_swap` function.

    No hugr operation is emitted. The two inout wires are simply handed back in
    the opposite order.
    """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        first, second = args
        return CallReturnWires(regular_returns=[], inout_returns=[second, first])

    def compile(self, args: list[Wire]) -> list[Wire]:
        # Only the inout-aware entry point is supported for this function.
        raise InternalGuppyError("Call compile_with_inouts instead")
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# Key of the global function built by `unwrap_result`. `unwrap_result` declares
# it via `ctx.declare_global_func` and only builds its body the first time.
UNWRAP_RESULT: Final[GlobalConstId] = GlobalConstId.fresh("unwrap_result")
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def _build_unwrap_result(func: hf.Function, result_type_var: ht.Variable) -> None:
    """Fill in the body of the generic unwrap-result function `func`.

    The first input of `func` is a sum. Variant 0 holds an error value and
    panics with it. Variant 1 holds the result and is forwarded to the function
    outputs.
    """
    cond = func.add_conditional(func.inputs()[0])

    with cond.add_case(0) as err_case:
        [err_wire] = list(err_case.inputs())
        # NOTE(review): the error wire is used both as the panic payload and as an
        # extra consumed input (hence the `[error_type()]` input row); passing it
        # once would suffice.
        panic_node = build_panic(
            err_case, [error_type()], [result_type_var], err_wire, *err_case.inputs()
        )
        err_case.set_outputs(*panic_node)

    with cond.add_case(1) as ok_case:
        ok_case.set_outputs(*ok_case.inputs())

    func.set_outputs(*cond.outputs())
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def unwrap_result(
    builder: DfBase[P],
    ctx: CompilerContext,
    either: Wire,
) -> Wire:
    """Builds or retrieves and then calls a function that unwraps an `hugr.tys.Either`
    value, panicking if the result is an error.

    The unwrapping logic lives in one global function, keyed by `UNWRAP_RESULT`
    and polymorphic over the result type. It is built the first time it is needed
    and only called afterwards.

    Args:
        builder: Dataflow builder in which the call is emitted.
        ctx: Compiler context used to declare or look up the global function.
        either: Wire of type `Either(error_tys, [T])`. Variant 0 holds the error
            and variant 1 holds the single result value.

    Returns:
        The wire carrying the unwrapped result of type `T`.

    Raises:
        InternalGuppyError: If the result variant of `either` does not carry
            exactly one value. The global function is generic over a single
            result type.
    """
    either_ty = builder.hugr.port_type(either.out_port())
    assert isinstance(either_ty, ht.Either)
    [error_tys, result_tys] = either_ty.variant_rows
    # The shared function takes exactly one type argument `T`, so the result row
    # must contain exactly one type. Check before declaring anything.
    if len(result_tys) != 1:
        raise InternalGuppyError(
            "unwrap_result expects a result variant with exactly one value, "
            f"got {len(result_tys)}"
        )
    [result_ty] = result_tys
    # Construct the function signature for unwrapping a result of type T.
    func_ty = ht.PolyFuncType(
        params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
        body=ht.FunctionType(
            input=[ht.Either(error_tys, [ht.Variable(0, ht.TypeBound.Linear)])],
            output=[ht.Variable(0, ht.TypeBound.Linear)],
        ),
    )
    # Build global unwrap result function if it doesn't already exist.
    func, already_exists = ctx.declare_global_func(UNWRAP_RESULT, func_ty)
    if not already_exists:
        _build_unwrap_result(func, ht.Variable(0, ht.TypeBound.Linear))
    # Call the global function, instantiating T with the concrete result type.
    concrete_ty = ht.FunctionType(
        input=[ht.Either(error_tys, result_tys)], output=result_tys
    )
    func_call = builder.call(
        func.parent_node,
        either,
        instantiation=concrete_ty,
        type_args=[ht.TypeTypeArg(result_ty)],
    )
    [result] = list(func_call.outputs())
    return result
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
class UnwrapOpCompiler(CustomInoutCallCompiler):
    """Compiler for operations whose result must be unwrapped and may panic.

    The wrapped hugr op returns `Either<error, result>`. The compiled call emits
    that op and then unwraps its output with `unwrap_result`, so the call returns
    the plain result.

    Args:
        op: Factory that builds the HUGR operation from its (Either-returning)
            signature, the call's type arguments and the compiler context.
    """

    op: Callable[[ht.FunctionType, Inst, ToHugrContext], ops.DataflowOp]

    def __init__(
        self, op: Callable[[ht.FunctionType, Inst, ToHugrContext], ops.DataflowOp]
    ):
        self.op = op

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        assert len(self.ty.output) == 1
        # The op's signature is the compiled function's signature, with its single
        # output wrapped in an Either whose left variant is a prelude error.
        wrapped_output = ht.Either([error_type()], self.ty.output)
        op_signature = ht.FunctionType(input=self.ty.input, output=[wrapped_output])
        hugr_op = self.op(op_signature, self.type_args, self.ctx)
        either_wire = self.builder.add_op(hugr_op, *args)
        unwrapped = unwrap_result(self.builder, self.ctx, either_wire)
        return CallReturnWires(regular_returns=[unwrapped], inout_returns=[])
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
class BarrierCompiler(CustomCallCompiler):
    """Compiler for the `barrier` function.

    Emits one prelude `Barrier` node over all argument wires. The node's outputs
    are returned as the updated inout values, one per argument.
    """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        # Every argument needs a known type. The barrier's type arguments must
        # line up one-to-one with the wires passed to it, and each wire comes
        # back as an inout value. Skipping unknown types would silently drop
        # wires instead.
        tys: list[ht.Type] = []
        for arg in args:
            ty = self.builder.hugr.port_type(arg.out_port())
            if ty is None:
                raise InternalGuppyError(
                    "Could not determine the type of a barrier argument"
                )
            tys.append(ty)

        op = hugr.std.prelude.PRELUDE_EXTENSION.get_op("Barrier").instantiate(
            [ht.ListArg([ht.TypeTypeArg(ty) for ty in tys])]
        )

        barrier_n = self.builder.add_op(op, *args)

        return CallReturnWires(
            regular_returns=[], inout_returns=[barrier_n[i] for i in range(len(tys))]
        )

    def compile(self, args: list[Wire]) -> list[Wire]:
        # Only the inout-aware entry point is supported for this function.
        raise InternalGuppyError("Call compile_with_inouts instead")
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from hugr import Wire
|
|
2
|
+
from hugr import tys as ht
|
|
3
|
+
from hugr.std.int import int_t
|
|
4
|
+
|
|
5
|
+
from guppylang_internals.definition.custom import CustomInoutCallCompiler
|
|
6
|
+
from guppylang_internals.definition.value import CallReturnWires
|
|
7
|
+
from guppylang_internals.std._internal.compiler.arithmetic import inarrow_s, iwiden_s
|
|
8
|
+
from guppylang_internals.std._internal.compiler.prelude import build_unwrap_right
|
|
9
|
+
from guppylang_internals.std._internal.compiler.quantum import (
|
|
10
|
+
RNGCONTEXT_T,
|
|
11
|
+
)
|
|
12
|
+
from guppylang_internals.std._internal.compiler.tket_exts import (
|
|
13
|
+
QSYSTEM_RANDOM_EXTENSION,
|
|
14
|
+
)
|
|
15
|
+
from guppylang_internals.std._internal.util import external_op
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RandomIntCompiler(CustomInoutCallCompiler):
    """Compiler for drawing a random integer from an RNG context.

    Calls the `RandomInt` op of the qsystem random extension, which produces an
    `int_t(5)` value. The value is widened to `int_t(6)` with `iwiden_s(5, 6)`.
    The RNG context is threaded through as an inout value.
    """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        (rng_in,) = args
        random_int_op = external_op(
            "RandomInt", [], ext=QSYSTEM_RANDOM_EXTENSION
        )(ht.FunctionType([RNGCONTEXT_T], [int_t(5), RNGCONTEXT_T]), [], self.ctx)
        narrow_value, rng_out = self.builder.add_op(random_int_op, rng_in)
        [wide_value] = self.builder.add_op(iwiden_s(5, 6), narrow_value)
        return CallReturnWires(regular_returns=[wide_value], inout_returns=[rng_out])
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class RandomIntBoundedCompiler(CustomInoutCallCompiler):
    """Compiler for drawing a random integer bounded by a caller-supplied bound.

    The `int_t(6)` bound is narrowed with `inarrow_s(6, 5)`. If it does not fit,
    the program panics with "bound must be a 32-bit integer". Otherwise the
    `RandomIntBounded` op of the qsystem random extension is called, and its
    `int_t(5)` result is widened back with `iwiden_s(5, 6)`. The RNG context is
    threaded through as an inout value.
    """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        rng_in, wide_bound = args
        narrowing = self.builder.add_op(inarrow_s(6, 5), wide_bound)
        narrow_bound = build_unwrap_right(
            self.builder, narrowing, "bound must be a 32-bit integer"
        )
        op_factory = external_op("RandomIntBounded", [], ext=QSYSTEM_RANDOM_EXTENSION)
        bounded_op = op_factory(
            ht.FunctionType([RNGCONTEXT_T, int_t(5)], [int_t(5), RNGCONTEXT_T]),
            [],
            self.ctx,
        )
        narrow_value, rng_out = self.builder.add_op(bounded_op, rng_in, narrow_bound)
        [wide_value] = self.builder.add_op(iwiden_s(5, 6), narrow_value)
        return CallReturnWires(regular_returns=[wide_value], inout_returns=[rng_out])
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""Compilers building quantum functions on top of hugr standard operations, that involve
|
|
2
|
+
multiple nodes.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from hugr import Wire, ops
|
|
8
|
+
from hugr import ext as he
|
|
9
|
+
from hugr import tys as ht
|
|
10
|
+
from hugr.std.float import FLOAT_T
|
|
11
|
+
|
|
12
|
+
from guppylang_internals.definition.custom import CustomInoutCallCompiler
|
|
13
|
+
from guppylang_internals.definition.value import CallReturnWires
|
|
14
|
+
from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, make_opaque
|
|
15
|
+
from guppylang_internals.std._internal.compiler.tket_exts import (
|
|
16
|
+
QSYSTEM_RANDOM_EXTENSION,
|
|
17
|
+
QUANTUM_EXTENSION,
|
|
18
|
+
ROTATION_EXTENSION,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# ----------------------------------------------
|
|
22
|
+
# --------- tket.* extensions -----------------
|
|
23
|
+
# ----------------------------------------------
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# `rotation` type of the tket rotation extension, and its instantiation.
ROTATION_T_DEF = ROTATION_EXTENSION.get_type("rotation")
ROTATION_T = ht.ExtType(ROTATION_T_DEF)

# `context` type of the qsystem random extension (the RNG context), and its
# instantiation.
RNGCONTEXT_T_DEF = QSYSTEM_RANDOM_EXTENSION.get_type("context")
RNGCONTEXT_T = ht.ExtType(RNGCONTEXT_T_DEF)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def from_halfturns_unchecked() -> ops.ExtOp:
    """Return the rotation extension's `from_halfturns_unchecked` op.

    Its signature maps a `FLOAT_T` to a `ROTATION_T`. Per the op name, the float
    is presumably a number of half-turns.
    """
    op_def = ROTATION_EXTENSION.get_op("from_halfturns_unchecked")
    signature = ht.FunctionType([FLOAT_T], [ROTATION_T])
    return ops.ExtOp(op_def, signature)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# ------------------------------------------------------
|
|
41
|
+
# --------- Custom compilers for non-native ops --------
|
|
42
|
+
# ------------------------------------------------------
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class InoutMeasureCompiler(CustomInoutCallCompiler):
    """Compiler for measurement functions that keep their qubit (inout), such as
    `project_z`.

    The underlying op returns a native hugr `Bool`. It is converted to a
    `tket.bool` with `make_opaque` before being returned.
    """

    opname: str
    ext: he.Extension

    def __init__(self, opname: str | None = None, ext: he.Extension | None = None):
        self.opname = opname if opname else "Measure"
        self.ext = ext if ext else QUANTUM_EXTENSION

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        # Function-local import, presumably to avoid an import cycle with util.
        from guppylang_internals.std._internal.util import quantum_op

        (qubit_in,) = args
        measure_op = quantum_op(self.opname, ext=self.ext)(
            ht.FunctionType([ht.Qubit], [ht.Qubit, ht.Bool]), [], self.ctx
        )
        qubit_out, native_bit = self.builder.add_op(measure_op, qubit_in)
        opaque_bit = self.builder.add_op(make_opaque(), native_bit)
        return CallReturnWires(regular_returns=[opaque_bit], inout_returns=[qubit_out])
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class InoutMeasureResetCompiler(CustomInoutCallCompiler):
    """Compiler for measure functions with an inout qubit whose underlying op
    already returns a `tket.bool`.

    Unlike `InoutMeasureCompiler`, no `make_opaque` conversion is inserted. The
    op named `opname` from extension `ext` is declared with the signature
    `Qubit -> (Qubit, OpaqueBool)`, and its outputs are returned directly.
    """

    opname: str
    ext: he.Extension

    def __init__(self, opname: str | None = None, ext: he.Extension | None = None):
        # NOTE(review): `InoutMeasureCompiler` declares the default `Measure` op of
        # QUANTUM_EXTENSION with a plain `ht.Bool` result, but it is declared with
        # an `OpaqueBool` result here. Callers presumably always pass an explicit
        # `opname`/`ext`; confirm.
        self.opname = opname or "Measure"
        self.ext = ext or QUANTUM_EXTENSION

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        # Function-local import, presumably to avoid an import cycle with util.
        from guppylang_internals.std._internal.util import quantum_op

        [q] = args
        [q, bit] = self.builder.add_op(
            quantum_op(self.opname, ext=self.ext)(
                ht.FunctionType([ht.Qubit], [ht.Qubit, OpaqueBool]), [], self.ctx
            ),
            q,
        )
        return CallReturnWires(regular_returns=[bit], inout_returns=[q])
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class RotationCompiler(CustomInoutCallCompiler):
    """Compiler for gates taking one or more inout qubits followed by an angle.

    The angle (last argument) is unpacked as a one-element tuple holding a
    float. That float is converted to a `ROTATION_T` with
    `from_halfturns_unchecked`. The quantum op named `opname` is then applied to
    the qubits and the rotation, and the qubits are returned as inout values.
    """

    opname: str

    def __init__(self, opname: str):
        self.opname = opname

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        # Function-local import, presumably to avoid an import cycle with util.
        from guppylang_internals.std._internal.util import quantum_op

        *qubits, angle = args
        [halfturns] = self.builder.add_op(ops.UnpackTuple([FLOAT_T]), angle)
        [rotation] = self.builder.add_op(from_halfturns_unchecked(), halfturns)

        qubit_row = [ht.Qubit] * len(qubits)
        gate_sig = ht.FunctionType([*qubit_row, ROTATION_T], list(qubit_row))
        gate_op = quantum_op(self.opname)(gate_sig, [], self.ctx)
        gate_node = self.builder.add_op(gate_op, *qubits, rotation)
        return CallReturnWires(regular_returns=[], inout_returns=list(gate_node))
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from hugr import ops
|
|
4
|
+
from hugr import tys as ht
|
|
5
|
+
from hugr import val as hv
|
|
6
|
+
|
|
7
|
+
from guppylang_internals.std._internal.compiler.tket_exts import BOOL_EXTENSION
|
|
8
|
+
|
|
9
|
+
# Definition of the `bool` type in the tket opaque-bool extension.
BOOL_DEF = BOOL_EXTENSION.get_type("bool")
# Hugr type of an opaque (`tket.bool`) boolean, instantiated from `BOOL_DEF`.
OpaqueBool = ht.ExtType(BOOL_DEF)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def read_bool() -> ops.ExtOp:
    """Return the `read` op, converting an `OpaqueBool` into a native `ht.Bool`."""
    op_def = BOOL_EXTENSION.get_op("read")
    signature = ht.FunctionType([OpaqueBool], [ht.Bool])
    return ops.ExtOp(op_def, signature)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def make_opaque() -> ops.ExtOp:
    """Return the `make_opaque` op, converting a native `ht.Bool` into an
    `OpaqueBool`."""
    op_def = BOOL_EXTENSION.get_op("make_opaque")
    signature = ht.FunctionType([ht.Bool], [OpaqueBool])
    return ops.ExtOp(op_def, signature)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def not_op() -> ops.ExtOp:
    """Return the extension's `not` op, with signature `OpaqueBool -> OpaqueBool`."""
    op_def = BOOL_EXTENSION.get_op("not")
    signature = ht.FunctionType([OpaqueBool], [OpaqueBool])
    return ops.ExtOp(op_def, signature)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class OpaqueBoolVal(hv.ExtensionValue):
    """Constant opaque boolean (`OpaqueBool`) wrapping a Python `bool`."""

    # The wrapped boolean value.
    v: bool

    def to_value(self) -> hv.Extension:
        """Serialise as a `ConstBool` extension constant of type `OpaqueBool`."""
        return hv.Extension(
            "ConstBool",
            typ=OpaqueBool,
            val=self.v,
            extensions=[BOOL_EXTENSION.name],
        )

    def __str__(self) -> str:
        return str(self.v)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# Shared constant instances for the two opaque boolean values.
OPAQUE_TRUE = OpaqueBoolVal(v=True)
OPAQUE_FALSE = OpaqueBoolVal(v=False)
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from hugr import val
|
|
4
|
+
from tket_exts import (
|
|
5
|
+
debug,
|
|
6
|
+
futures,
|
|
7
|
+
opaque_bool,
|
|
8
|
+
qsystem,
|
|
9
|
+
qsystem_random,
|
|
10
|
+
qsystem_utils,
|
|
11
|
+
quantum,
|
|
12
|
+
result,
|
|
13
|
+
rotation,
|
|
14
|
+
wasm,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Extension definitions obtained from the `tket_exts` factories, built once at
# import time in this order.
BOOL_EXTENSION = opaque_bool()
DEBUG_EXTENSION = debug()
FUTURES_EXTENSION = futures()
QSYSTEM_EXTENSION = qsystem()
QSYSTEM_RANDOM_EXTENSION = qsystem_random()
QSYSTEM_UTILS_EXTENSION = qsystem_utils()
QUANTUM_EXTENSION = quantum()
RESULT_EXTENSION = result()
ROTATION_EXTENSION = rotation()
WASM_EXTENSION = wasm()

# Every extension above, in the same order.
TKET_EXTENSIONS = [
    BOOL_EXTENSION,
    DEBUG_EXTENSION,
    FUTURES_EXTENSION,
    QSYSTEM_EXTENSION,
    QSYSTEM_RANDOM_EXTENSION,
    QSYSTEM_UTILS_EXTENSION,
    QUANTUM_EXTENSION,
    RESULT_EXTENSION,
    ROTATION_EXTENSION,
    WASM_EXTENSION,
]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass(frozen=True)
class ConstWasmModule(val.ExtensionValue):
    """Python wrapper for the tket ConstWasmModule type.

    Identifies a WASM module by file name and hash. It is serialised as a
    `tket.wasm.ConstWasmModule` constant of the wasm extension's `module` type.
    """

    # Name of the WASM file backing the module.
    wasm_file: str
    # Hash associated with the module (presumably of the file contents; confirm).
    wasm_hash: int

    def to_value(self) -> val.Extension:
        ty = WASM_EXTENSION.get_type("module").instantiate([])

        name = "tket.wasm.ConstWasmModule"
        payload = {"name": self.wasm_file, "hash": self.wasm_hash}
        # Name the extension that defines `ty` instead of hard-coding its id, so
        # the two cannot drift apart (cf. `OpaqueBoolVal`).
        return val.Extension(
            name, typ=ty, val=payload, extensions=[WASM_EXTENSION.name]
        )

    def __str__(self) -> str:
        return (
            f"ConstWasmModule(wasm_file={self.wasm_file}, wasm_hash={self.wasm_hash})"
        )
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
from hugr import Wire, ops
|
|
2
|
+
from hugr import tys as ht
|
|
3
|
+
|
|
4
|
+
from guppylang_internals.definition.custom import CustomInoutCallCompiler
|
|
5
|
+
from guppylang_internals.definition.value import CallReturnWires
|
|
6
|
+
from guppylang_internals.error import InternalGuppyError
|
|
7
|
+
from guppylang_internals.nodes import GlobalCall
|
|
8
|
+
from guppylang_internals.std._internal.compiler.arithmetic import convert_itousize
|
|
9
|
+
from guppylang_internals.std._internal.compiler.prelude import build_unwrap
|
|
10
|
+
from guppylang_internals.std._internal.compiler.tket_exts import (
|
|
11
|
+
FUTURES_EXTENSION,
|
|
12
|
+
WASM_EXTENSION,
|
|
13
|
+
ConstWasmModule,
|
|
14
|
+
)
|
|
15
|
+
from guppylang_internals.tys.builtin import (
|
|
16
|
+
wasm_module_info,
|
|
17
|
+
)
|
|
18
|
+
from guppylang_internals.tys.ty import (
|
|
19
|
+
FunctionType,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class WasmModuleInitCompiler(CustomInoutCallCompiler):
    """Compiler for initialising WASM modules.

    Converts the single integer argument to a `usize` and calls tket's
    "get_context" op with it. The resulting `Option` is unwrapped, panicking
    with "Failed to spawn WASM context" if it is empty.
    Returns a `tket.wasm.context` wire.
    """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        # The single argument is an integer, presumably identifying the context
        # to obtain. `get_context` takes it as a `usize`.
        assert len(args) == 1
        ctx_arg = args[0]
        ctx_wire = self.builder.add_op(convert_itousize(), ctx_arg)

        ctx_ty = WASM_EXTENSION.get_type("context").instantiate([])
        get_ctx_op = ops.ExtOp(
            WASM_EXTENSION.get_op("get_context"),
            ht.FunctionType([ht.USize()], [ht.Option(ctx_ty)]),
        )
        node = self.builder.add_op(get_ctx_op, ctx_wire)
        opt_w: Wire = node[0]
        err = "Failed to spawn WASM context"
        out_node = build_unwrap(self.builder, opt_w, err)
        return CallReturnWires(regular_returns=[out_node], inout_returns=[])
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class WasmModuleDiscardCompiler(CustomInoutCallCompiler):
    """Compiler for discarding WASM contexts.

    Feeds the single context wire into tket's `dispose_context` op. Nothing is
    returned.
    """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        assert len(args) == 1
        [context] = args
        dispose_op = WASM_EXTENSION.get_op("dispose_context").instantiate([])
        self.builder.add_op(dispose_op, context)
        return CallReturnWires(regular_returns=[], inout_returns=[])
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class WasmModuleCallCompiler(CustomInoutCallCompiler):
    """Compiler for WASM calls.

    When a wasm method is called in guppy, we turn it into these tket ops:
      * lookup: wasm.module -> wasm.func
      * call: wasm.context * wasm.func * inputs -> wasm.context * future
      * Read: future -> tuple(outputs)
    followed by an `UnpackTuple` that exposes the individual outputs.

    For the wasm.module used in lookup, a constant is created for each call. It
    is built from the wasm file information embedded in the type of the method's
    `self` argument.
    """

    fn_name: str

    def __init__(self, name: str) -> None:
        self.fn_name = name

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        # The arguments should be:
        # - a WASM context
        # - any args meant for the WASM function
        assert len(args) >= 1
        module_ty = WASM_EXTENSION.get_type("module").instantiate([])
        ctx_ty = WASM_EXTENSION.get_type("context").instantiate([])

        fn_name_arg = ht.StringArg(self.fn_name)
        assert isinstance(self.node, GlobalCall)
        assert self.func is not None
        # Hugr signature of the WASM function itself: the guppy signature without
        # its leading `self` argument.
        wasm_sig = FunctionType(
            self.func.ty.inputs[1:],
            self.func.ty.output,
        ).to_hugr(self.ctx)

        inputs_row_arg = ht.ListArg([ty.type_arg() for ty in wasm_sig.input])
        output_row_arg = ht.ListArg([ty.type_arg() for ty in wasm_sig.output])

        func_ty = WASM_EXTENSION.get_type("func").instantiate(
            [inputs_row_arg, output_row_arg]
        )
        # The call yields a future of a tuple holding all the function's outputs.
        output_tuple_ty = ht.Tuple(*wasm_sig.output)
        future_ty = FUTURES_EXTENSION.get_type("Future").instantiate(
            [output_tuple_ty.type_arg()]
        )

        # Get the WASM module information from the type of the `self` argument
        selfarg = self.func.ty.inputs[0].ty
        if info := wasm_module_info(selfarg):
            const_module = self.builder.add_const(ConstWasmModule(*info))
        else:
            raise InternalGuppyError(
                "Expected cached signature to have WASM module as first arg"
            )

        wasm_module = self.builder.load(const_module)

        # Lookup the function we want
        wasm_opdef = WASM_EXTENSION.get_op("lookup").instantiate(
            [fn_name_arg, inputs_row_arg, output_row_arg],
            ht.FunctionType([module_ty], [func_ty]),
        )
        wasm_func = self.builder.add_op(wasm_opdef, wasm_module)

        # Call the function; the context is threaded through as an inout value
        call_op = WASM_EXTENSION.get_op("call").instantiate(
            [inputs_row_arg, output_row_arg],
            ht.FunctionType([ctx_ty, func_ty, *wasm_sig.input], [ctx_ty, future_ty]),
        )
        ctx, future = self.builder.add_op(call_op, args[0], wasm_func, *args[1:])

        # Read the future's value, then split the tuple into its components
        read_opdef = FUTURES_EXTENSION.get_op("Read").instantiate(
            [output_tuple_ty.type_arg()],
            ht.FunctionType([future_ty], [output_tuple_ty]),
        )
        read_node = self.builder.add_op(read_opdef, future)
        unpack_node = self.builder.add_op(
            ops.UnpackTuple(wasm_sig.output), *read_node[:]
        )
        outputs: list[Wire] = list(unpack_node[:])

        return CallReturnWires(regular_returns=outputs, inout_returns=[ctx])
|
|
File without changes
|