guppylang-internals 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +21 -0
- guppylang_internals/cfg/bb.py +20 -0
- guppylang_internals/cfg/builder.py +118 -5
- guppylang_internals/cfg/cfg.py +3 -0
- guppylang_internals/checker/cfg_checker.py +6 -0
- guppylang_internals/checker/core.py +5 -2
- guppylang_internals/checker/errors/generic.py +32 -1
- guppylang_internals/checker/errors/type_errors.py +14 -0
- guppylang_internals/checker/errors/wasm.py +7 -4
- guppylang_internals/checker/expr_checker.py +58 -17
- guppylang_internals/checker/func_checker.py +18 -14
- guppylang_internals/checker/linearity_checker.py +67 -10
- guppylang_internals/checker/modifier_checker.py +120 -0
- guppylang_internals/checker/stmt_checker.py +48 -1
- guppylang_internals/checker/unitary_checker.py +132 -0
- guppylang_internals/compiler/cfg_compiler.py +7 -6
- guppylang_internals/compiler/core.py +93 -56
- guppylang_internals/compiler/expr_compiler.py +72 -168
- guppylang_internals/compiler/modifier_compiler.py +176 -0
- guppylang_internals/compiler/stmt_compiler.py +15 -8
- guppylang_internals/decorator.py +86 -7
- guppylang_internals/definition/custom.py +39 -1
- guppylang_internals/definition/declaration.py +9 -6
- guppylang_internals/definition/function.py +12 -2
- guppylang_internals/definition/parameter.py +8 -3
- guppylang_internals/definition/pytket_circuits.py +14 -41
- guppylang_internals/definition/struct.py +13 -7
- guppylang_internals/definition/ty.py +3 -3
- guppylang_internals/definition/wasm.py +42 -10
- guppylang_internals/engine.py +9 -3
- guppylang_internals/experimental.py +5 -0
- guppylang_internals/nodes.py +147 -24
- guppylang_internals/std/_internal/checker.py +13 -108
- guppylang_internals/std/_internal/compiler/array.py +95 -283
- guppylang_internals/std/_internal/compiler/list.py +1 -1
- guppylang_internals/std/_internal/compiler/platform.py +153 -0
- guppylang_internals/std/_internal/compiler/prelude.py +12 -4
- guppylang_internals/std/_internal/compiler/tket_exts.py +8 -2
- guppylang_internals/std/_internal/debug.py +18 -9
- guppylang_internals/std/_internal/util.py +1 -1
- guppylang_internals/tracing/object.py +10 -0
- guppylang_internals/tracing/unpacking.py +19 -20
- guppylang_internals/tys/arg.py +18 -3
- guppylang_internals/tys/builtin.py +2 -5
- guppylang_internals/tys/const.py +33 -4
- guppylang_internals/tys/errors.py +23 -1
- guppylang_internals/tys/param.py +31 -16
- guppylang_internals/tys/parsing.py +11 -24
- guppylang_internals/tys/printing.py +2 -8
- guppylang_internals/tys/qubit.py +62 -0
- guppylang_internals/tys/subst.py +8 -26
- guppylang_internals/tys/ty.py +91 -85
- guppylang_internals/wasm_util.py +129 -0
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +6 -5
- guppylang_internals-0.26.0.dist-info/RECORD +104 -0
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
- guppylang_internals-0.24.0.dist-info/RECORD +0 -98
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -46,7 +46,6 @@ from guppylang_internals.std._internal.compiler.array import (
|
|
|
46
46
|
array_new,
|
|
47
47
|
array_unpack,
|
|
48
48
|
)
|
|
49
|
-
from guppylang_internals.std._internal.compiler.prelude import build_unwrap
|
|
50
49
|
from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, make_opaque
|
|
51
50
|
from guppylang_internals.tys.builtin import array_type, bool_type, float_type
|
|
52
51
|
from guppylang_internals.tys.subst import Inst, Subst
|
|
@@ -195,17 +194,9 @@ class ParsedPytketDef(CallableDef, CompilableDef):
|
|
|
195
194
|
# them into separate wires.
|
|
196
195
|
for i, q_reg in enumerate(self.input_circuit.q_registers):
|
|
197
196
|
reg_wire = outer_func.inputs()[i]
|
|
198
|
-
|
|
199
|
-
array_unpack(ht.
|
|
197
|
+
elem_wires = outer_func.add_op(
|
|
198
|
+
array_unpack(ht.Qubit, q_reg.size), reg_wire
|
|
200
199
|
)
|
|
201
|
-
elem_wires = [
|
|
202
|
-
build_unwrap(
|
|
203
|
-
outer_func,
|
|
204
|
-
opt_elem,
|
|
205
|
-
"Internal error: unwrapping of array element failed",
|
|
206
|
-
)
|
|
207
|
-
for opt_elem in opt_elem_wires
|
|
208
|
-
]
|
|
209
200
|
input_list.extend(elem_wires)
|
|
210
201
|
|
|
211
202
|
else:
|
|
@@ -219,7 +210,8 @@ class ParsedPytketDef(CallableDef, CompilableDef):
|
|
|
219
210
|
]
|
|
220
211
|
|
|
221
212
|
# Symbolic parameters (if present) get passed after qubits and bools.
|
|
222
|
-
|
|
213
|
+
num_params = len(self.input_circuit.free_symbols())
|
|
214
|
+
has_params = num_params != 0
|
|
223
215
|
if has_params and "TKET1.input_parameters" not in hugr_func.metadata:
|
|
224
216
|
raise InternalGuppyError(
|
|
225
217
|
"Parameter metadata is missing from pytket circuit HUGR"
|
|
@@ -230,26 +222,17 @@ class ParsedPytketDef(CallableDef, CompilableDef):
|
|
|
230
222
|
if has_params:
|
|
231
223
|
lex_params: list[Wire] = list(outer_func.inputs()[offset:])
|
|
232
224
|
if self.use_arrays:
|
|
233
|
-
|
|
225
|
+
unpack_result = outer_func.add_op(
|
|
234
226
|
array_unpack(
|
|
235
|
-
ht.
|
|
236
|
-
q_reg.size,
|
|
227
|
+
ht.Tuple(float_type().to_hugr(ctx)), num_params
|
|
237
228
|
),
|
|
238
229
|
lex_params[0],
|
|
239
230
|
)
|
|
240
|
-
lex_params =
|
|
241
|
-
build_unwrap(
|
|
242
|
-
outer_func,
|
|
243
|
-
opt_param,
|
|
244
|
-
"Internal error: unwrapping of array element failed",
|
|
245
|
-
)
|
|
246
|
-
for opt_param in opt_param_wires
|
|
247
|
-
]
|
|
231
|
+
lex_params = list(unpack_result)
|
|
248
232
|
param_order = cast(
|
|
249
233
|
list[str], hugr_func.metadata["TKET1.input_parameters"]
|
|
250
234
|
)
|
|
251
235
|
lex_names = sorted(param_order)
|
|
252
|
-
assert len(lex_names) == len(lex_params)
|
|
253
236
|
name_to_param = dict(zip(lex_names, lex_params, strict=True))
|
|
254
237
|
angle_wires = [name_to_param[name] for name in param_order]
|
|
255
238
|
# Need to convert all angles to floats.
|
|
@@ -280,34 +263,23 @@ class ParsedPytketDef(CallableDef, CompilableDef):
|
|
|
280
263
|
]
|
|
281
264
|
|
|
282
265
|
if self.use_arrays:
|
|
283
|
-
|
|
284
|
-
def pack(elems: list[Wire], elem_ty: ht.Type, length: int) -> Wire:
|
|
285
|
-
elem_opts = [
|
|
286
|
-
outer_func.add_op(ops.Some(elem_ty), elem) for elem in elems
|
|
287
|
-
]
|
|
288
|
-
return outer_func.add_op(
|
|
289
|
-
array_new(ht.Option(elem_ty), length), *elem_opts
|
|
290
|
-
)
|
|
291
|
-
|
|
292
266
|
array_wires: list[Wire] = []
|
|
293
267
|
wire_idx = 0
|
|
294
268
|
# First pack bool results into an array.
|
|
295
269
|
for c_reg in self.input_circuit.c_registers:
|
|
296
270
|
array_wires.append(
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
c_reg.size,
|
|
271
|
+
outer_func.add_op(
|
|
272
|
+
array_new(OpaqueBool, c_reg.size),
|
|
273
|
+
*wires[wire_idx : wire_idx + c_reg.size],
|
|
301
274
|
)
|
|
302
275
|
)
|
|
303
276
|
wire_idx = wire_idx + c_reg.size
|
|
304
277
|
# Then the borrowed qubits also need to be put back into arrays.
|
|
305
278
|
for q_reg in self.input_circuit.q_registers:
|
|
306
279
|
array_wires.append(
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
q_reg.size,
|
|
280
|
+
outer_func.add_op(
|
|
281
|
+
array_new(ht.Qubit, q_reg.size),
|
|
282
|
+
*wires[wire_idx : wire_idx + q_reg.size],
|
|
311
283
|
)
|
|
312
284
|
)
|
|
313
285
|
wire_idx = wire_idx + q_reg.size
|
|
@@ -398,6 +370,7 @@ def _signature_from_circuit(
|
|
|
398
370
|
use_arrays: bool = False,
|
|
399
371
|
) -> FunctionType:
|
|
400
372
|
"""Helper function for inferring a function signature from a pytket circuit."""
|
|
373
|
+
# May want to set proper unitary flags in the future.
|
|
401
374
|
try:
|
|
402
375
|
import pytket
|
|
403
376
|
|
|
@@ -131,10 +131,13 @@ class RawStructDef(TypeDef, ParsableDef):
|
|
|
131
131
|
if cls_def.type_params:
|
|
132
132
|
first, last = cls_def.type_params[0], cls_def.type_params[-1]
|
|
133
133
|
params_span = Span(to_span(first).start, to_span(last).end)
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
134
|
+
param_vars_mapping: dict[str, Parameter] = {}
|
|
135
|
+
for idx, param_node in enumerate(cls_def.type_params):
|
|
136
|
+
param = parse_parameter(
|
|
137
|
+
param_node, idx, globals, param_vars_mapping
|
|
138
|
+
)
|
|
139
|
+
param_vars_mapping[param.name] = param
|
|
140
|
+
params.append(param)
|
|
138
141
|
|
|
139
142
|
# The only base we allow is `Generic[...]` to specify generic parameters with
|
|
140
143
|
# the legacy syntax
|
|
@@ -270,13 +273,16 @@ class CheckedStructDef(TypeDef, CompiledDef):
|
|
|
270
273
|
|
|
271
274
|
constructor_sig = FunctionType(
|
|
272
275
|
inputs=[
|
|
273
|
-
FuncInput(
|
|
276
|
+
FuncInput(
|
|
277
|
+
f.ty,
|
|
278
|
+
InputFlags.Owned if f.ty.linear else InputFlags.NoFlags,
|
|
279
|
+
f.name,
|
|
280
|
+
)
|
|
274
281
|
for f in self.fields
|
|
275
282
|
],
|
|
276
283
|
output=StructType(
|
|
277
284
|
defn=self, args=[p.to_bound(i) for i, p in enumerate(self.params)]
|
|
278
285
|
),
|
|
279
|
-
input_names=[f.name for f in self.fields],
|
|
280
286
|
params=self.params,
|
|
281
287
|
)
|
|
282
288
|
constructor_def = CustomFunctionDef(
|
|
@@ -314,7 +320,7 @@ def parse_py_class(
|
|
|
314
320
|
raise GuppyError(UnknownSourceError(None, cls))
|
|
315
321
|
|
|
316
322
|
# We can't rely on `inspect.getsourcelines` since it doesn't work properly for
|
|
317
|
-
# classes prior to Python 3.13. See https://github.com/
|
|
323
|
+
# classes prior to Python 3.13. See https://github.com/quantinuum/guppylang/issues/1107.
|
|
318
324
|
# Instead, we reproduce the behaviour of Python >= 3.13 using the `__firstlineno__`
|
|
319
325
|
# attribute. See https://github.com/python/cpython/blob/3.13/Lib/inspect.py#L1052.
|
|
320
326
|
# In the decorator, we make sure that `__firstlineno__` is set, even if we're not
|
|
@@ -2,7 +2,7 @@ from abc import abstractmethod
|
|
|
2
2
|
from collections.abc import Callable, Sequence
|
|
3
3
|
from dataclasses import dataclass, field
|
|
4
4
|
|
|
5
|
-
from hugr import tys
|
|
5
|
+
from hugr import tys as ht
|
|
6
6
|
|
|
7
7
|
from guppylang_internals.ast_util import AstNode
|
|
8
8
|
from guppylang_internals.definition.common import CompiledDef, Definition
|
|
@@ -42,8 +42,8 @@ class OpaqueTypeDef(TypeDef, CompiledDef):
|
|
|
42
42
|
params: Sequence[Parameter]
|
|
43
43
|
never_copyable: bool
|
|
44
44
|
never_droppable: bool
|
|
45
|
-
to_hugr: Callable[[Sequence[Argument], ToHugrContext],
|
|
46
|
-
bound:
|
|
45
|
+
to_hugr: Callable[[Sequence[Argument], ToHugrContext], ht.Type]
|
|
46
|
+
bound: ht.TypeBound | None = None
|
|
47
47
|
|
|
48
48
|
def check_instantiate(
|
|
49
49
|
self, args: Sequence[Argument], loc: AstNode | None = None
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
1
2
|
from typing import TYPE_CHECKING
|
|
2
3
|
|
|
3
4
|
from guppylang_internals.ast_util import AstNode
|
|
@@ -9,7 +10,8 @@ from guppylang_internals.definition.custom import (
|
|
|
9
10
|
CustomFunctionDef,
|
|
10
11
|
RawCustomFunctionDef,
|
|
11
12
|
)
|
|
12
|
-
from guppylang_internals.
|
|
13
|
+
from guppylang_internals.engine import DEF_STORE
|
|
14
|
+
from guppylang_internals.error import GuppyError, GuppyTypeError
|
|
13
15
|
from guppylang_internals.span import SourceMap
|
|
14
16
|
from guppylang_internals.tys.builtin import wasm_module_name
|
|
15
17
|
from guppylang_internals.tys.ty import (
|
|
@@ -21,24 +23,35 @@ from guppylang_internals.tys.ty import (
|
|
|
21
23
|
TupleType,
|
|
22
24
|
Type,
|
|
23
25
|
)
|
|
26
|
+
from guppylang_internals.wasm_util import WasmSigMismatchError
|
|
24
27
|
|
|
25
28
|
if TYPE_CHECKING:
|
|
26
29
|
from guppylang_internals.checker.core import Globals
|
|
27
30
|
|
|
28
31
|
|
|
32
|
+
@dataclass(frozen=True)
|
|
29
33
|
class RawWasmFunctionDef(RawCustomFunctionDef):
|
|
30
|
-
|
|
34
|
+
# If a function is specified in the @wasm decorator by its index in the wasm
|
|
35
|
+
# file, record what the index was.
|
|
36
|
+
wasm_index: int | None = field(default=None)
|
|
37
|
+
|
|
38
|
+
def sanitise_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
|
|
31
39
|
# Place to highlight in error messages
|
|
32
|
-
match fun_ty.inputs
|
|
33
|
-
case FuncInput(ty=ty, flags=InputFlags.Inout) if wasm_module_name(
|
|
40
|
+
match fun_ty.inputs:
|
|
41
|
+
case [FuncInput(ty=ty, flags=InputFlags.Inout), *args] if wasm_module_name(
|
|
34
42
|
ty
|
|
35
43
|
) is not None:
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
44
|
+
for inp in args:
|
|
45
|
+
if not self.is_type_wasmable(inp.ty):
|
|
46
|
+
raise GuppyError(UnWasmableType(loc, inp.ty))
|
|
47
|
+
case [FuncInput(ty=ty), *_]:
|
|
48
|
+
raise GuppyError(
|
|
49
|
+
FirstArgNotModule(loc).add_sub_diagnostic(
|
|
50
|
+
FirstArgNotModule.GotOtherType(loc, ty)
|
|
51
|
+
)
|
|
52
|
+
)
|
|
53
|
+
case []:
|
|
54
|
+
raise GuppyError(FirstArgNotModule(loc))
|
|
42
55
|
if not self.is_type_wasmable(fun_ty.output):
|
|
43
56
|
match fun_ty.output:
|
|
44
57
|
case NoneType():
|
|
@@ -46,6 +59,23 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
|
|
|
46
59
|
case _:
|
|
47
60
|
raise GuppyError(UnWasmableType(loc, fun_ty.output))
|
|
48
61
|
|
|
62
|
+
def validate_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
|
|
63
|
+
type_in_wasm: FunctionType = DEF_STORE.wasm_functions[self.id]
|
|
64
|
+
assert type_in_wasm is not None
|
|
65
|
+
# Drop the first arg because it should be "self"
|
|
66
|
+
expected_type = FunctionType(fun_ty.inputs[1:], fun_ty.output)
|
|
67
|
+
|
|
68
|
+
if expected_type != type_in_wasm:
|
|
69
|
+
raise GuppyTypeError(
|
|
70
|
+
WasmSigMismatchError(loc)
|
|
71
|
+
.add_sub_diagnostic(
|
|
72
|
+
WasmSigMismatchError.Declaration(None, declared=str(expected_type))
|
|
73
|
+
)
|
|
74
|
+
.add_sub_diagnostic(
|
|
75
|
+
WasmSigMismatchError.Actual(None, actual=str(type_in_wasm))
|
|
76
|
+
)
|
|
77
|
+
)
|
|
78
|
+
|
|
49
79
|
def is_type_wasmable(self, ty: Type) -> bool:
|
|
50
80
|
match ty:
|
|
51
81
|
case NumericType():
|
|
@@ -57,5 +87,7 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
|
|
|
57
87
|
|
|
58
88
|
def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
|
|
59
89
|
parsed = super().parse(globals, sources)
|
|
90
|
+
assert parsed.defined_at is not None
|
|
60
91
|
self.sanitise_type(parsed.defined_at, parsed.ty)
|
|
92
|
+
self.validate_type(parsed.defined_at, parsed.ty)
|
|
61
93
|
return parsed
|
guppylang_internals/engine.py
CHANGED
|
@@ -46,6 +46,7 @@ from guppylang_internals.tys.builtin import (
|
|
|
46
46
|
string_type_def,
|
|
47
47
|
tuple_type_def,
|
|
48
48
|
)
|
|
49
|
+
from guppylang_internals.tys.ty import FunctionType
|
|
49
50
|
|
|
50
51
|
if TYPE_CHECKING:
|
|
51
52
|
from guppylang_internals.compiler.core import MonoDefId
|
|
@@ -87,6 +88,7 @@ class DefinitionStore:
|
|
|
87
88
|
raw_defs: dict[DefId, RawDef]
|
|
88
89
|
impls: defaultdict[DefId, dict[str, DefId]]
|
|
89
90
|
impl_parents: dict[DefId, DefId]
|
|
91
|
+
wasm_functions: dict[DefId, FunctionType]
|
|
90
92
|
frames: dict[DefId, FrameType]
|
|
91
93
|
sources: SourceMap
|
|
92
94
|
|
|
@@ -96,6 +98,7 @@ class DefinitionStore:
|
|
|
96
98
|
self.impl_parents = {}
|
|
97
99
|
self.frames = {}
|
|
98
100
|
self.sources = SourceMap()
|
|
101
|
+
self.wasm_functions = {}
|
|
99
102
|
|
|
100
103
|
def register_def(self, defn: RawDef, frame: FrameType | None) -> None:
|
|
101
104
|
self.raw_defs[defn.id] = defn
|
|
@@ -123,6 +126,9 @@ class DefinitionStore:
|
|
|
123
126
|
assert frame is not None
|
|
124
127
|
self.frames[impl_id] = frame
|
|
125
128
|
|
|
129
|
+
def register_wasm_function(self, fn_id: DefId, sig: FunctionType) -> None:
|
|
130
|
+
self.wasm_functions[fn_id] = sig
|
|
131
|
+
|
|
126
132
|
|
|
127
133
|
DEF_STORE: DefinitionStore = DefinitionStore()
|
|
128
134
|
|
|
@@ -263,8 +269,8 @@ class CompilationEngine:
|
|
|
263
269
|
and isinstance(compiled_def, CompiledCallableDef)
|
|
264
270
|
and not isinstance(graph.hugr[compiled_def.hugr_node].op, ops.FuncDecl)
|
|
265
271
|
):
|
|
266
|
-
# if compiling a region set it as the HUGR entrypoint
|
|
267
|
-
#
|
|
272
|
+
# if compiling a region set it as the HUGR entrypoint can be
|
|
273
|
+
# loosened after https://github.com/quantinuum/hugr/issues/2501 is fixed
|
|
268
274
|
graph.hugr.entrypoint = compiled_def.hugr_node
|
|
269
275
|
|
|
270
276
|
# TODO: Currently the list of extensions is manually managed by the user.
|
|
@@ -278,7 +284,7 @@ class CompilationEngine:
|
|
|
278
284
|
guppylang_internals.compiler.hugr_extension.EXTENSION,
|
|
279
285
|
*self.additional_extensions,
|
|
280
286
|
]
|
|
281
|
-
# TODO replace with computed extensions after https://github.com/
|
|
287
|
+
# TODO replace with computed extensions after https://github.com/quantinuum/guppylang/issues/550
|
|
282
288
|
all_used_extensions = [
|
|
283
289
|
*extensions,
|
|
284
290
|
hugr.std.prelude.PRELUDE_EXTENSION,
|
|
@@ -90,3 +90,8 @@ def check_lists_enabled(loc: AstNode | None = None) -> None:
|
|
|
90
90
|
def check_capturing_closures_enabled(loc: AstNode | None = None) -> None:
|
|
91
91
|
if not EXPERIMENTAL_FEATURES_ENABLED:
|
|
92
92
|
raise GuppyError(UnsupportedError(loc, "Capturing closures"))
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def check_modifiers_enabled(loc: AstNode | None = None) -> None:
|
|
96
|
+
if not EXPERIMENTAL_FEATURES_ENABLED:
|
|
97
|
+
raise GuppyError(ExperimentalFeatureError(loc, "Modifiers"))
|
guppylang_internals/nodes.py
CHANGED
|
@@ -6,9 +6,16 @@ from enum import Enum
|
|
|
6
6
|
from typing import TYPE_CHECKING, Any
|
|
7
7
|
|
|
8
8
|
from guppylang_internals.ast_util import AstNode
|
|
9
|
+
from guppylang_internals.span import Span, to_span
|
|
9
10
|
from guppylang_internals.tys.const import Const
|
|
10
11
|
from guppylang_internals.tys.subst import Inst
|
|
11
|
-
from guppylang_internals.tys.ty import
|
|
12
|
+
from guppylang_internals.tys.ty import (
|
|
13
|
+
FunctionType,
|
|
14
|
+
StructType,
|
|
15
|
+
TupleType,
|
|
16
|
+
Type,
|
|
17
|
+
UnitaryFlags,
|
|
18
|
+
)
|
|
12
19
|
|
|
13
20
|
if TYPE_CHECKING:
|
|
14
21
|
from guppylang_internals.cfg.cfg import CFG
|
|
@@ -249,22 +256,6 @@ class ComptimeExpr(ast.expr):
|
|
|
249
256
|
_fields = ("value",)
|
|
250
257
|
|
|
251
258
|
|
|
252
|
-
class ResultExpr(ast.expr):
|
|
253
|
-
"""A `result(tag, value)` expression."""
|
|
254
|
-
|
|
255
|
-
value: ast.expr
|
|
256
|
-
base_ty: Type
|
|
257
|
-
#: Array length in case this is an array result, otherwise `None`
|
|
258
|
-
array_len: Const | None
|
|
259
|
-
tag: str
|
|
260
|
-
|
|
261
|
-
_fields = ("value", "base_ty", "array_len", "tag")
|
|
262
|
-
|
|
263
|
-
@property
|
|
264
|
-
def args(self) -> list[ast.expr]:
|
|
265
|
-
return [self.value]
|
|
266
|
-
|
|
267
|
-
|
|
268
259
|
class ExitKind(Enum):
|
|
269
260
|
ExitShot = 0 # Exit the current shot
|
|
270
261
|
Panic = 1 # Panic the program ending all shots
|
|
@@ -274,8 +265,8 @@ class PanicExpr(ast.expr):
|
|
|
274
265
|
"""A `panic(msg, *args)` or `exit(msg, *args)` expression ."""
|
|
275
266
|
|
|
276
267
|
kind: ExitKind
|
|
277
|
-
signal:
|
|
278
|
-
msg:
|
|
268
|
+
signal: ast.expr
|
|
269
|
+
msg: ast.expr
|
|
279
270
|
values: list[ast.expr]
|
|
280
271
|
|
|
281
272
|
_fields = ("kind", "signal", "msg", "values")
|
|
@@ -292,17 +283,16 @@ class BarrierExpr(ast.expr):
|
|
|
292
283
|
class StateResultExpr(ast.expr):
|
|
293
284
|
"""A `state_result(tag, *args)` expression."""
|
|
294
285
|
|
|
295
|
-
|
|
286
|
+
tag_value: Const
|
|
287
|
+
tag_expr: ast.expr
|
|
296
288
|
args: list[ast.expr]
|
|
297
289
|
func_ty: FunctionType
|
|
298
290
|
#: Array length in case this is an array result, otherwise `None`
|
|
299
291
|
array_len: Const | None
|
|
300
|
-
_fields = ("
|
|
292
|
+
_fields = ("tag_value", "tag_expr", "args", "func_ty", "has_array_input")
|
|
301
293
|
|
|
302
294
|
|
|
303
|
-
AnyCall =
|
|
304
|
-
LocalCall | GlobalCall | TensorCall | BarrierExpr | ResultExpr | StateResultExpr
|
|
305
|
-
)
|
|
295
|
+
AnyCall = LocalCall | GlobalCall | TensorCall | BarrierExpr | StateResultExpr
|
|
306
296
|
|
|
307
297
|
|
|
308
298
|
class InoutReturnSentinel(ast.expr):
|
|
@@ -422,3 +412,136 @@ class CheckedNestedFunctionDef(ast.FunctionDef):
|
|
|
422
412
|
self.cfg = cfg
|
|
423
413
|
self.ty = ty
|
|
424
414
|
self.captured = captured
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
class Dagger(ast.expr):
|
|
418
|
+
"""The dagger modifier"""
|
|
419
|
+
|
|
420
|
+
def __init__(self, node: ast.expr) -> None:
|
|
421
|
+
super().__init__(**node.__dict__)
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
class Control(ast.Call):
|
|
425
|
+
"""The control modifier"""
|
|
426
|
+
|
|
427
|
+
ctrl: list[ast.expr]
|
|
428
|
+
qubit_num: int | Const | None
|
|
429
|
+
|
|
430
|
+
_fields = ("ctrl",)
|
|
431
|
+
|
|
432
|
+
def __init__(self, node: ast.Call, ctrl: list[ast.expr]) -> None:
|
|
433
|
+
super().__init__(**node.__dict__)
|
|
434
|
+
self.ctrl = ctrl
|
|
435
|
+
self.qubit_num = None
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
class Power(ast.expr):
|
|
439
|
+
"""The power modifier"""
|
|
440
|
+
|
|
441
|
+
iter: ast.expr
|
|
442
|
+
|
|
443
|
+
_fields = ("iter",)
|
|
444
|
+
|
|
445
|
+
def __init__(self, node: ast.expr, iter: ast.expr) -> None:
|
|
446
|
+
super().__init__(**node.__dict__)
|
|
447
|
+
self.iter = iter
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
Modifier = Dagger | Control | Power
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
class ModifiedBlock(ast.With):
|
|
454
|
+
cfg: "CFG"
|
|
455
|
+
dagger: list[Dagger]
|
|
456
|
+
control: list[Control]
|
|
457
|
+
power: list[Power]
|
|
458
|
+
|
|
459
|
+
def __init__(self, cfg: "CFG", *args: Any, **kwargs: Any) -> None:
|
|
460
|
+
super().__init__(*args, **kwargs)
|
|
461
|
+
self.cfg = cfg
|
|
462
|
+
self.dagger = []
|
|
463
|
+
self.control = []
|
|
464
|
+
self.power = []
|
|
465
|
+
|
|
466
|
+
def is_dagger(self) -> bool:
|
|
467
|
+
return len(self.dagger) % 2 == 1
|
|
468
|
+
|
|
469
|
+
def is_control(self) -> bool:
|
|
470
|
+
return len(self.control) > 0
|
|
471
|
+
|
|
472
|
+
def is_power(self) -> bool:
|
|
473
|
+
return len(self.power) > 0
|
|
474
|
+
|
|
475
|
+
def span_ctxt_manager(self) -> Span:
|
|
476
|
+
return Span(
|
|
477
|
+
to_span(self.items[0].context_expr).start,
|
|
478
|
+
to_span(self.items[-1].context_expr).end,
|
|
479
|
+
)
|
|
480
|
+
|
|
481
|
+
def push_modifier(self, modifier: Modifier) -> None:
|
|
482
|
+
"""Pushes a modifier kind onto the modifier."""
|
|
483
|
+
if isinstance(modifier, Dagger):
|
|
484
|
+
self.dagger.append(modifier)
|
|
485
|
+
elif isinstance(modifier, Control):
|
|
486
|
+
self.control.append(modifier)
|
|
487
|
+
elif isinstance(modifier, Power):
|
|
488
|
+
self.power.append(modifier)
|
|
489
|
+
else:
|
|
490
|
+
raise TypeError(f"Unknown modifier: {modifier}")
|
|
491
|
+
|
|
492
|
+
def flags(self) -> UnitaryFlags:
|
|
493
|
+
flags = UnitaryFlags.NoFlags
|
|
494
|
+
if self.is_dagger():
|
|
495
|
+
flags |= UnitaryFlags.Dagger
|
|
496
|
+
if self.is_control():
|
|
497
|
+
flags |= UnitaryFlags.Control
|
|
498
|
+
if self.is_power():
|
|
499
|
+
flags |= UnitaryFlags.Power
|
|
500
|
+
return flags
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
class CheckedModifiedBlock(ast.With):
|
|
504
|
+
def_id: "DefId"
|
|
505
|
+
cfg: "CheckedCFG[Place]"
|
|
506
|
+
dagger: list[Dagger]
|
|
507
|
+
control: list[Control]
|
|
508
|
+
power: list[Power]
|
|
509
|
+
|
|
510
|
+
#: The type of the body of With block.
|
|
511
|
+
ty: FunctionType
|
|
512
|
+
#: Mapping from names to variables captured in the body.
|
|
513
|
+
captured: Mapping[str, tuple["Variable", AstNode]]
|
|
514
|
+
|
|
515
|
+
def __init__(
|
|
516
|
+
self,
|
|
517
|
+
def_id: "DefId",
|
|
518
|
+
cfg: "CheckedCFG[Place]",
|
|
519
|
+
ty: FunctionType,
|
|
520
|
+
captured: Mapping[str, tuple["Variable", AstNode]],
|
|
521
|
+
dagger: list[Dagger],
|
|
522
|
+
control: list[Control],
|
|
523
|
+
power: list[Power],
|
|
524
|
+
*args: Any,
|
|
525
|
+
**kwargs: Any,
|
|
526
|
+
) -> None:
|
|
527
|
+
super().__init__(*args, **kwargs)
|
|
528
|
+
self.def_id = def_id
|
|
529
|
+
self.cfg = cfg
|
|
530
|
+
self.ty = ty
|
|
531
|
+
self.captured = captured
|
|
532
|
+
self.dagger = dagger
|
|
533
|
+
self.control = control
|
|
534
|
+
self.power = power
|
|
535
|
+
|
|
536
|
+
def __str__(self) -> str:
|
|
537
|
+
# generate a function name from the def_id
|
|
538
|
+
return f"__WithBlock__({self.def_id})"
|
|
539
|
+
|
|
540
|
+
def has_dagger(self) -> bool:
|
|
541
|
+
return len(self.dagger) % 2 == 1
|
|
542
|
+
|
|
543
|
+
def has_control(self) -> bool:
|
|
544
|
+
return any(len(c.ctrl) > 0 for c in self.control)
|
|
545
|
+
|
|
546
|
+
def has_power(self) -> bool:
|
|
547
|
+
return len(self.power) > 0
|