guppylang-internals 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +21 -0
- guppylang_internals/cfg/bb.py +20 -0
- guppylang_internals/cfg/builder.py +101 -3
- guppylang_internals/checker/core.py +12 -0
- guppylang_internals/checker/errors/generic.py +32 -1
- guppylang_internals/checker/errors/type_errors.py +14 -0
- guppylang_internals/checker/expr_checker.py +55 -29
- guppylang_internals/checker/func_checker.py +171 -22
- guppylang_internals/checker/linearity_checker.py +65 -0
- guppylang_internals/checker/modifier_checker.py +116 -0
- guppylang_internals/checker/stmt_checker.py +49 -2
- guppylang_internals/compiler/core.py +90 -53
- guppylang_internals/compiler/expr_compiler.py +49 -114
- guppylang_internals/compiler/modifier_compiler.py +174 -0
- guppylang_internals/compiler/stmt_compiler.py +15 -8
- guppylang_internals/decorator.py +124 -58
- guppylang_internals/definition/const.py +2 -2
- guppylang_internals/definition/custom.py +36 -2
- guppylang_internals/definition/declaration.py +4 -5
- guppylang_internals/definition/extern.py +2 -2
- guppylang_internals/definition/function.py +1 -1
- guppylang_internals/definition/parameter.py +10 -5
- guppylang_internals/definition/pytket_circuits.py +14 -42
- guppylang_internals/definition/struct.py +17 -14
- guppylang_internals/definition/traced.py +1 -1
- guppylang_internals/definition/ty.py +9 -3
- guppylang_internals/definition/wasm.py +2 -2
- guppylang_internals/engine.py +13 -2
- guppylang_internals/experimental.py +5 -0
- guppylang_internals/nodes.py +124 -23
- guppylang_internals/std/_internal/compiler/array.py +94 -282
- guppylang_internals/std/_internal/compiler/tket_exts.py +12 -8
- guppylang_internals/std/_internal/compiler/wasm.py +37 -26
- guppylang_internals/tracing/function.py +13 -2
- guppylang_internals/tracing/unpacking.py +33 -28
- guppylang_internals/tys/arg.py +18 -3
- guppylang_internals/tys/builtin.py +32 -16
- guppylang_internals/tys/const.py +33 -4
- guppylang_internals/tys/errors.py +6 -0
- guppylang_internals/tys/param.py +31 -16
- guppylang_internals/tys/parsing.py +118 -145
- guppylang_internals/tys/qubit.py +27 -0
- guppylang_internals/tys/subst.py +8 -26
- guppylang_internals/tys/ty.py +31 -21
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/METADATA +4 -4
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/RECORD +49 -46
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/WHEEL +0 -0
- {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -46,7 +46,6 @@ from guppylang_internals.std._internal.compiler.array import (
|
|
|
46
46
|
array_new,
|
|
47
47
|
array_unpack,
|
|
48
48
|
)
|
|
49
|
-
from guppylang_internals.std._internal.compiler.prelude import build_unwrap
|
|
50
49
|
from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, make_opaque
|
|
51
50
|
from guppylang_internals.tys.builtin import array_type, bool_type, float_type
|
|
52
51
|
from guppylang_internals.tys.subst import Inst, Subst
|
|
@@ -85,7 +84,7 @@ class RawPytketDef(ParsableDef):
|
|
|
85
84
|
if not has_empty_body(func_ast):
|
|
86
85
|
# Function stub should have empty body.
|
|
87
86
|
raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
|
|
88
|
-
stub_signature = check_signature(func_ast, globals)
|
|
87
|
+
stub_signature = check_signature(func_ast, globals, self.id)
|
|
89
88
|
|
|
90
89
|
# Compare signatures.
|
|
91
90
|
circuit_signature = _signature_from_circuit(self.input_circuit, self.defined_at)
|
|
@@ -195,17 +194,9 @@ class ParsedPytketDef(CallableDef, CompilableDef):
|
|
|
195
194
|
# them into separate wires.
|
|
196
195
|
for i, q_reg in enumerate(self.input_circuit.q_registers):
|
|
197
196
|
reg_wire = outer_func.inputs()[i]
|
|
198
|
-
|
|
199
|
-
array_unpack(ht.
|
|
197
|
+
elem_wires = outer_func.add_op(
|
|
198
|
+
array_unpack(ht.Qubit, q_reg.size), reg_wire
|
|
200
199
|
)
|
|
201
|
-
elem_wires = [
|
|
202
|
-
build_unwrap(
|
|
203
|
-
outer_func,
|
|
204
|
-
opt_elem,
|
|
205
|
-
"Internal error: unwrapping of array element failed",
|
|
206
|
-
)
|
|
207
|
-
for opt_elem in opt_elem_wires
|
|
208
|
-
]
|
|
209
200
|
input_list.extend(elem_wires)
|
|
210
201
|
|
|
211
202
|
else:
|
|
@@ -219,7 +210,8 @@ class ParsedPytketDef(CallableDef, CompilableDef):
|
|
|
219
210
|
]
|
|
220
211
|
|
|
221
212
|
# Symbolic parameters (if present) get passed after qubits and bools.
|
|
222
|
-
|
|
213
|
+
num_params = len(self.input_circuit.free_symbols())
|
|
214
|
+
has_params = num_params != 0
|
|
223
215
|
if has_params and "TKET1.input_parameters" not in hugr_func.metadata:
|
|
224
216
|
raise InternalGuppyError(
|
|
225
217
|
"Parameter metadata is missing from pytket circuit HUGR"
|
|
@@ -230,26 +222,17 @@ class ParsedPytketDef(CallableDef, CompilableDef):
|
|
|
230
222
|
if has_params:
|
|
231
223
|
lex_params: list[Wire] = list(outer_func.inputs()[offset:])
|
|
232
224
|
if self.use_arrays:
|
|
233
|
-
|
|
225
|
+
unpack_result = outer_func.add_op(
|
|
234
226
|
array_unpack(
|
|
235
|
-
ht.
|
|
236
|
-
q_reg.size,
|
|
227
|
+
ht.Tuple(float_type().to_hugr(ctx)), num_params
|
|
237
228
|
),
|
|
238
229
|
lex_params[0],
|
|
239
230
|
)
|
|
240
|
-
lex_params =
|
|
241
|
-
build_unwrap(
|
|
242
|
-
outer_func,
|
|
243
|
-
opt_param,
|
|
244
|
-
"Internal error: unwrapping of array element failed",
|
|
245
|
-
)
|
|
246
|
-
for opt_param in opt_param_wires
|
|
247
|
-
]
|
|
231
|
+
lex_params = list(unpack_result)
|
|
248
232
|
param_order = cast(
|
|
249
233
|
list[str], hugr_func.metadata["TKET1.input_parameters"]
|
|
250
234
|
)
|
|
251
235
|
lex_names = sorted(param_order)
|
|
252
|
-
assert len(lex_names) == len(lex_params)
|
|
253
236
|
name_to_param = dict(zip(lex_names, lex_params, strict=True))
|
|
254
237
|
angle_wires = [name_to_param[name] for name in param_order]
|
|
255
238
|
# Need to convert all angles to floats.
|
|
@@ -280,34 +263,23 @@ class ParsedPytketDef(CallableDef, CompilableDef):
|
|
|
280
263
|
]
|
|
281
264
|
|
|
282
265
|
if self.use_arrays:
|
|
283
|
-
|
|
284
|
-
def pack(elems: list[Wire], elem_ty: ht.Type, length: int) -> Wire:
|
|
285
|
-
elem_opts = [
|
|
286
|
-
outer_func.add_op(ops.Some(elem_ty), elem) for elem in elems
|
|
287
|
-
]
|
|
288
|
-
return outer_func.add_op(
|
|
289
|
-
array_new(ht.Option(elem_ty), length), *elem_opts
|
|
290
|
-
)
|
|
291
|
-
|
|
292
266
|
array_wires: list[Wire] = []
|
|
293
267
|
wire_idx = 0
|
|
294
268
|
# First pack bool results into an array.
|
|
295
269
|
for c_reg in self.input_circuit.c_registers:
|
|
296
270
|
array_wires.append(
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
c_reg.size,
|
|
271
|
+
outer_func.add_op(
|
|
272
|
+
array_new(OpaqueBool, c_reg.size),
|
|
273
|
+
*wires[wire_idx : wire_idx + c_reg.size],
|
|
301
274
|
)
|
|
302
275
|
)
|
|
303
276
|
wire_idx = wire_idx + c_reg.size
|
|
304
277
|
# Then the borrowed qubits also need to be put back into arrays.
|
|
305
278
|
for q_reg in self.input_circuit.q_registers:
|
|
306
279
|
array_wires.append(
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
q_reg.size,
|
|
280
|
+
outer_func.add_op(
|
|
281
|
+
array_new(ht.Qubit, q_reg.size),
|
|
282
|
+
*wires[wire_idx : wire_idx + q_reg.size],
|
|
311
283
|
)
|
|
312
284
|
)
|
|
313
285
|
wire_idx = wire_idx + q_reg.size
|
|
@@ -3,7 +3,7 @@ import inspect
|
|
|
3
3
|
import linecache
|
|
4
4
|
import sys
|
|
5
5
|
from collections.abc import Sequence
|
|
6
|
-
from dataclasses import dataclass
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
7
|
from types import FrameType
|
|
8
8
|
from typing import ClassVar
|
|
9
9
|
|
|
@@ -39,7 +39,7 @@ from guppylang_internals.ipython_inspect import is_running_ipython
|
|
|
39
39
|
from guppylang_internals.span import SourceMap, Span, to_span
|
|
40
40
|
from guppylang_internals.tys.arg import Argument
|
|
41
41
|
from guppylang_internals.tys.param import Parameter, check_all_args
|
|
42
|
-
from guppylang_internals.tys.parsing import type_from_ast
|
|
42
|
+
from guppylang_internals.tys.parsing import TypeParsingCtx, type_from_ast
|
|
43
43
|
from guppylang_internals.tys.ty import (
|
|
44
44
|
FuncInput,
|
|
45
45
|
FunctionType,
|
|
@@ -115,6 +115,7 @@ class RawStructDef(TypeDef, ParsableDef):
|
|
|
115
115
|
"""A raw struct type definition that has not been parsed yet."""
|
|
116
116
|
|
|
117
117
|
python_class: type
|
|
118
|
+
params: None = field(default=None, init=False) # Params not known yet
|
|
118
119
|
|
|
119
120
|
def parse(self, globals: Globals, sources: SourceMap) -> "ParsedStructDef":
|
|
120
121
|
"""Parses the raw class object into an AST and checks that it is well-formed."""
|
|
@@ -130,10 +131,13 @@ class RawStructDef(TypeDef, ParsableDef):
|
|
|
130
131
|
if cls_def.type_params:
|
|
131
132
|
first, last = cls_def.type_params[0], cls_def.type_params[-1]
|
|
132
133
|
params_span = Span(to_span(first).start, to_span(last).end)
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
134
|
+
param_vars_mapping: dict[str, Parameter] = {}
|
|
135
|
+
for idx, param_node in enumerate(cls_def.type_params):
|
|
136
|
+
param = parse_parameter(
|
|
137
|
+
param_node, idx, globals, param_vars_mapping
|
|
138
|
+
)
|
|
139
|
+
param_vars_mapping[param.name] = param
|
|
140
|
+
params.append(param)
|
|
137
141
|
|
|
138
142
|
# The only base we allow is `Generic[...]` to specify generic parameters with
|
|
139
143
|
# the legacy syntax
|
|
@@ -211,15 +215,16 @@ class ParsedStructDef(TypeDef, CheckableDef):
|
|
|
211
215
|
|
|
212
216
|
def check(self, globals: Globals) -> "CheckedStructDef":
|
|
213
217
|
"""Checks that all struct fields have valid types."""
|
|
218
|
+
param_var_mapping = {p.name: p for p in self.params}
|
|
219
|
+
ctx = TypeParsingCtx(globals, param_var_mapping)
|
|
220
|
+
|
|
214
221
|
# Before checking the fields, make sure that this definition is not recursive,
|
|
215
222
|
# otherwise the code below would not terminate.
|
|
216
223
|
# TODO: This is not ideal (see todo in `check_instantiate`)
|
|
217
|
-
|
|
218
|
-
check_not_recursive(self, globals, param_var_mapping)
|
|
224
|
+
check_not_recursive(self, ctx)
|
|
219
225
|
|
|
220
226
|
fields = [
|
|
221
|
-
StructField(f.name, type_from_ast(f.type_ast,
|
|
222
|
-
for f in self.fields
|
|
227
|
+
StructField(f.name, type_from_ast(f.type_ast, ctx)) for f in self.fields
|
|
223
228
|
]
|
|
224
229
|
return CheckedStructDef(
|
|
225
230
|
self.id, self.name, self.defined_at, self.params, fields
|
|
@@ -370,9 +375,7 @@ def params_from_ast(nodes: Sequence[ast.expr], globals: Globals) -> list[Paramet
|
|
|
370
375
|
return params
|
|
371
376
|
|
|
372
377
|
|
|
373
|
-
def check_not_recursive(
|
|
374
|
-
defn: ParsedStructDef, globals: Globals, param_var_mapping: dict[str, Parameter]
|
|
375
|
-
) -> None:
|
|
378
|
+
def check_not_recursive(defn: ParsedStructDef, ctx: TypeParsingCtx) -> None:
|
|
376
379
|
"""Throws a user error if the given struct definition is recursive."""
|
|
377
380
|
|
|
378
381
|
# TODO: The implementation below hijacks the type parsing logic to detect recursive
|
|
@@ -388,5 +391,5 @@ def check_not_recursive(
|
|
|
388
391
|
original = defn.check_instantiate
|
|
389
392
|
object.__setattr__(defn, "check_instantiate", dummy_check_instantiate)
|
|
390
393
|
for fld in defn.fields:
|
|
391
|
-
type_from_ast(fld.type_ast,
|
|
394
|
+
type_from_ast(fld.type_ast, ctx)
|
|
392
395
|
object.__setattr__(defn, "check_instantiate", original)
|
|
@@ -48,7 +48,7 @@ class RawTracedFunctionDef(ParsableDef):
|
|
|
48
48
|
def parse(self, globals: Globals, sources: SourceMap) -> "TracedFunctionDef":
|
|
49
49
|
"""Parses and checks the user-provided signature of the function."""
|
|
50
50
|
func_ast, _docstring = parse_py_func(self.python_func, sources)
|
|
51
|
-
ty = check_signature(func_ast, globals)
|
|
51
|
+
ty = check_signature(func_ast, globals, self.id)
|
|
52
52
|
if ty.parametrized:
|
|
53
53
|
raise GuppyError(UnsupportedError(func_ast, "Generic comptime functions"))
|
|
54
54
|
return TracedFunctionDef(self.id, self.name, func_ast, ty, self.python_func)
|
|
@@ -2,7 +2,7 @@ from abc import abstractmethod
|
|
|
2
2
|
from collections.abc import Callable, Sequence
|
|
3
3
|
from dataclasses import dataclass, field
|
|
4
4
|
|
|
5
|
-
from hugr import tys
|
|
5
|
+
from hugr import tys as ht
|
|
6
6
|
|
|
7
7
|
from guppylang_internals.ast_util import AstNode
|
|
8
8
|
from guppylang_internals.definition.common import CompiledDef, Definition
|
|
@@ -18,6 +18,12 @@ class TypeDef(Definition):
|
|
|
18
18
|
|
|
19
19
|
description: str = field(default="type", init=False)
|
|
20
20
|
|
|
21
|
+
#: Generic parameters of the type. This may be `None` for special types that are
|
|
22
|
+
#: more polymorphic than the regular type system allows (for example `tuple` and
|
|
23
|
+
#: `Callable`), or if this is a raw definition whose parameters are not determined
|
|
24
|
+
#: yet (for example a `RawStructDef`).
|
|
25
|
+
params: Sequence[Parameter] | None
|
|
26
|
+
|
|
21
27
|
@abstractmethod
|
|
22
28
|
def check_instantiate(
|
|
23
29
|
self, args: Sequence[Argument], loc: AstNode | None = None
|
|
@@ -36,8 +42,8 @@ class OpaqueTypeDef(TypeDef, CompiledDef):
|
|
|
36
42
|
params: Sequence[Parameter]
|
|
37
43
|
never_copyable: bool
|
|
38
44
|
never_droppable: bool
|
|
39
|
-
to_hugr: Callable[[Sequence[Argument], ToHugrContext],
|
|
40
|
-
bound:
|
|
45
|
+
to_hugr: Callable[[Sequence[Argument], ToHugrContext], ht.Type]
|
|
46
|
+
bound: ht.TypeBound | None = None
|
|
41
47
|
|
|
42
48
|
def check_instantiate(
|
|
43
49
|
self, args: Sequence[Argument], loc: AstNode | None = None
|
|
@@ -11,7 +11,7 @@ from guppylang_internals.definition.custom import (
|
|
|
11
11
|
)
|
|
12
12
|
from guppylang_internals.error import GuppyError
|
|
13
13
|
from guppylang_internals.span import SourceMap
|
|
14
|
-
from guppylang_internals.tys.builtin import
|
|
14
|
+
from guppylang_internals.tys.builtin import wasm_module_name
|
|
15
15
|
from guppylang_internals.tys.ty import (
|
|
16
16
|
FuncInput,
|
|
17
17
|
FunctionType,
|
|
@@ -30,7 +30,7 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
|
|
|
30
30
|
def sanitise_type(self, loc: AstNode | None, fun_ty: FunctionType) -> None:
|
|
31
31
|
# Place to highlight in error messages
|
|
32
32
|
match fun_ty.inputs[0]:
|
|
33
|
-
case FuncInput(ty=ty, flags=InputFlags.Inout) if
|
|
33
|
+
case FuncInput(ty=ty, flags=InputFlags.Inout) if wasm_module_name(
|
|
34
34
|
ty
|
|
35
35
|
) is not None:
|
|
36
36
|
pass
|
guppylang_internals/engine.py
CHANGED
|
@@ -41,6 +41,7 @@ from guppylang_internals.tys.builtin import (
|
|
|
41
41
|
nat_type_def,
|
|
42
42
|
none_type_def,
|
|
43
43
|
option_type_def,
|
|
44
|
+
self_type_def,
|
|
44
45
|
sized_iter_type_def,
|
|
45
46
|
string_type_def,
|
|
46
47
|
tuple_type_def,
|
|
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|
|
51
52
|
|
|
52
53
|
BUILTIN_DEFS_LIST: list[RawDef] = [
|
|
53
54
|
callable_type_def,
|
|
55
|
+
self_type_def,
|
|
54
56
|
tuple_type_def,
|
|
55
57
|
none_type_def,
|
|
56
58
|
bool_type_def,
|
|
@@ -84,12 +86,14 @@ class DefinitionStore:
|
|
|
84
86
|
|
|
85
87
|
raw_defs: dict[DefId, RawDef]
|
|
86
88
|
impls: defaultdict[DefId, dict[str, DefId]]
|
|
89
|
+
impl_parents: dict[DefId, DefId]
|
|
87
90
|
frames: dict[DefId, FrameType]
|
|
88
91
|
sources: SourceMap
|
|
89
92
|
|
|
90
93
|
def __init__(self) -> None:
|
|
91
94
|
self.raw_defs = {defn.id: defn for defn in BUILTIN_DEFS_LIST}
|
|
92
95
|
self.impls = defaultdict(dict)
|
|
96
|
+
self.impl_parents = {}
|
|
93
97
|
self.frames = {}
|
|
94
98
|
self.sources = SourceMap()
|
|
95
99
|
|
|
@@ -99,7 +103,9 @@ class DefinitionStore:
|
|
|
99
103
|
self.frames[defn.id] = frame
|
|
100
104
|
|
|
101
105
|
def register_impl(self, ty_id: DefId, name: str, impl_id: DefId) -> None:
|
|
106
|
+
assert impl_id not in self.impl_parents, "Already an impl"
|
|
102
107
|
self.impls[ty_id][name] = impl_id
|
|
108
|
+
self.impl_parents[impl_id] = ty_id
|
|
103
109
|
# Update the frame of the definition to the frame of the defining class
|
|
104
110
|
if impl_id in self.frames:
|
|
105
111
|
frame = self.frames[impl_id].f_back
|
|
@@ -138,18 +144,23 @@ class CompilationEngine:
|
|
|
138
144
|
types_to_check_worklist: dict[DefId, ParsedDef]
|
|
139
145
|
to_check_worklist: dict[DefId, ParsedDef]
|
|
140
146
|
|
|
147
|
+
def __init__(self) -> None:
|
|
148
|
+
"""Resets the compilation cache."""
|
|
149
|
+
self.reset()
|
|
150
|
+
self.additional_extensions = []
|
|
151
|
+
|
|
141
152
|
def reset(self) -> None:
|
|
142
153
|
"""Resets the compilation cache."""
|
|
143
154
|
self.parsed = {}
|
|
144
155
|
self.checked = {}
|
|
145
156
|
self.compiled = {}
|
|
146
|
-
self.additional_extensions = []
|
|
147
157
|
self.to_check_worklist = {}
|
|
148
158
|
self.types_to_check_worklist = {}
|
|
149
159
|
|
|
150
160
|
@pretty_errors
|
|
151
161
|
def register_extension(self, extension: Extension) -> None:
|
|
152
|
-
self.additional_extensions
|
|
162
|
+
if extension not in self.additional_extensions:
|
|
163
|
+
self.additional_extensions.append(extension)
|
|
153
164
|
|
|
154
165
|
@pretty_errors
|
|
155
166
|
def get_parsed(self, id: DefId) -> ParsedDef:
|
|
@@ -90,3 +90,8 @@ def check_lists_enabled(loc: AstNode | None = None) -> None:
|
|
|
90
90
|
def check_capturing_closures_enabled(loc: AstNode | None = None) -> None:
|
|
91
91
|
if not EXPERIMENTAL_FEATURES_ENABLED:
|
|
92
92
|
raise GuppyError(UnsupportedError(loc, "Capturing closures"))
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def check_modifiers_enabled(loc: AstNode | None = None) -> None:
|
|
96
|
+
if not EXPERIMENTAL_FEATURES_ENABLED:
|
|
97
|
+
raise GuppyError(ExperimentalFeatureError(loc, "Modifiers"))
|
guppylang_internals/nodes.py
CHANGED
|
@@ -6,6 +6,7 @@ from enum import Enum
|
|
|
6
6
|
from typing import TYPE_CHECKING, Any
|
|
7
7
|
|
|
8
8
|
from guppylang_internals.ast_util import AstNode
|
|
9
|
+
from guppylang_internals.span import Span, to_span
|
|
9
10
|
from guppylang_internals.tys.const import Const
|
|
10
11
|
from guppylang_internals.tys.subst import Inst
|
|
11
12
|
from guppylang_internals.tys.ty import FunctionType, StructType, TupleType, Type
|
|
@@ -166,17 +167,6 @@ class MakeIter(ast.expr):
|
|
|
166
167
|
self.unwrap_size_hint = unwrap_size_hint
|
|
167
168
|
|
|
168
169
|
|
|
169
|
-
class IterHasNext(ast.expr):
|
|
170
|
-
"""Checks if an iterator has a next element using the `__hasnext__` magic method.
|
|
171
|
-
|
|
172
|
-
This node is inserted in `for` loops and list comprehensions.
|
|
173
|
-
"""
|
|
174
|
-
|
|
175
|
-
value: ast.expr
|
|
176
|
-
|
|
177
|
-
_fields = ("value",)
|
|
178
|
-
|
|
179
|
-
|
|
180
170
|
class IterNext(ast.expr):
|
|
181
171
|
"""Obtains the next element of an iterator using the `__next__` magic method.
|
|
182
172
|
|
|
@@ -188,18 +178,6 @@ class IterNext(ast.expr):
|
|
|
188
178
|
_fields = ("value",)
|
|
189
179
|
|
|
190
180
|
|
|
191
|
-
class IterEnd(ast.expr):
|
|
192
|
-
"""Finalises an iterator using the `__end__` magic method.
|
|
193
|
-
|
|
194
|
-
This node is inserted in `for` loops and list comprehensions. It is needed to
|
|
195
|
-
consume linear iterators once they are finished.
|
|
196
|
-
"""
|
|
197
|
-
|
|
198
|
-
value: ast.expr
|
|
199
|
-
|
|
200
|
-
_fields = ("value",)
|
|
201
|
-
|
|
202
|
-
|
|
203
181
|
class DesugaredGenerator(ast.expr):
|
|
204
182
|
"""A single desugared generator in a list comprehension.
|
|
205
183
|
|
|
@@ -445,3 +423,126 @@ class CheckedNestedFunctionDef(ast.FunctionDef):
|
|
|
445
423
|
self.cfg = cfg
|
|
446
424
|
self.ty = ty
|
|
447
425
|
self.captured = captured
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
class Dagger(ast.expr):
|
|
429
|
+
"""The dagger modifier"""
|
|
430
|
+
|
|
431
|
+
def __init__(self, node: ast.expr) -> None:
|
|
432
|
+
super().__init__(**node.__dict__)
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
class Control(ast.Call):
|
|
436
|
+
"""The control modifier"""
|
|
437
|
+
|
|
438
|
+
ctrl: list[ast.expr]
|
|
439
|
+
qubit_num: int | Const | None
|
|
440
|
+
|
|
441
|
+
_fields = ("ctrl",)
|
|
442
|
+
|
|
443
|
+
def __init__(self, node: ast.Call, ctrl: list[ast.expr]) -> None:
|
|
444
|
+
super().__init__(**node.__dict__)
|
|
445
|
+
self.ctrl = ctrl
|
|
446
|
+
self.qubit_num = None
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
class Power(ast.expr):
|
|
450
|
+
"""The power modifier"""
|
|
451
|
+
|
|
452
|
+
iter: ast.expr
|
|
453
|
+
|
|
454
|
+
_fields = ("iter",)
|
|
455
|
+
|
|
456
|
+
def __init__(self, node: ast.expr, iter: ast.expr) -> None:
|
|
457
|
+
super().__init__(**node.__dict__)
|
|
458
|
+
self.iter = iter
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
Modifier = Dagger | Control | Power
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
class ModifiedBlock(ast.With):
|
|
465
|
+
cfg: "CFG"
|
|
466
|
+
dagger: list[Dagger]
|
|
467
|
+
control: list[Control]
|
|
468
|
+
power: list[Power]
|
|
469
|
+
|
|
470
|
+
def __init__(self, cfg: "CFG", *args: Any, **kwargs: Any) -> None:
|
|
471
|
+
super().__init__(*args, **kwargs)
|
|
472
|
+
self.cfg = cfg
|
|
473
|
+
self.dagger = []
|
|
474
|
+
self.control = []
|
|
475
|
+
self.power = []
|
|
476
|
+
|
|
477
|
+
def is_dagger(self) -> bool:
|
|
478
|
+
return len(self.dagger) % 2 == 1
|
|
479
|
+
|
|
480
|
+
def is_control(self) -> bool:
|
|
481
|
+
return len(self.control) > 0
|
|
482
|
+
|
|
483
|
+
def is_power(self) -> bool:
|
|
484
|
+
return len(self.power) > 0
|
|
485
|
+
|
|
486
|
+
def span_ctxt_manager(self) -> Span:
|
|
487
|
+
return Span(
|
|
488
|
+
to_span(self.items[0].context_expr).start,
|
|
489
|
+
to_span(self.items[-1].context_expr).end,
|
|
490
|
+
)
|
|
491
|
+
|
|
492
|
+
def push_modifier(self, modifier: Modifier) -> None:
|
|
493
|
+
"""Pushes a modifier kind onto the modifier."""
|
|
494
|
+
if isinstance(modifier, Dagger):
|
|
495
|
+
self.dagger.append(modifier)
|
|
496
|
+
elif isinstance(modifier, Control):
|
|
497
|
+
self.control.append(modifier)
|
|
498
|
+
elif isinstance(modifier, Power):
|
|
499
|
+
self.power.append(modifier)
|
|
500
|
+
else:
|
|
501
|
+
raise TypeError(f"Unknown modifier: {modifier}")
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
class CheckedModifiedBlock(ast.With):
|
|
505
|
+
def_id: "DefId"
|
|
506
|
+
cfg: "CheckedCFG[Place]"
|
|
507
|
+
dagger: list[Dagger]
|
|
508
|
+
control: list[Control]
|
|
509
|
+
power: list[Power]
|
|
510
|
+
|
|
511
|
+
#: The type of the body of With block.
|
|
512
|
+
ty: FunctionType
|
|
513
|
+
#: Mapping from names to variables captured in the body.
|
|
514
|
+
captured: Mapping[str, tuple["Variable", AstNode]]
|
|
515
|
+
|
|
516
|
+
def __init__(
|
|
517
|
+
self,
|
|
518
|
+
def_id: "DefId",
|
|
519
|
+
cfg: "CheckedCFG[Place]",
|
|
520
|
+
ty: FunctionType,
|
|
521
|
+
captured: Mapping[str, tuple["Variable", AstNode]],
|
|
522
|
+
dagger: list[Dagger],
|
|
523
|
+
control: list[Control],
|
|
524
|
+
power: list[Power],
|
|
525
|
+
*args: Any,
|
|
526
|
+
**kwargs: Any,
|
|
527
|
+
) -> None:
|
|
528
|
+
super().__init__(*args, **kwargs)
|
|
529
|
+
self.def_id = def_id
|
|
530
|
+
self.cfg = cfg
|
|
531
|
+
self.ty = ty
|
|
532
|
+
self.captured = captured
|
|
533
|
+
self.dagger = dagger
|
|
534
|
+
self.control = control
|
|
535
|
+
self.power = power
|
|
536
|
+
|
|
537
|
+
def __str__(self) -> str:
|
|
538
|
+
# generate a function name from the def_id
|
|
539
|
+
return f"__WithBlock__({self.def_id})"
|
|
540
|
+
|
|
541
|
+
def has_dagger(self) -> bool:
|
|
542
|
+
return len(self.dagger) % 2 == 1
|
|
543
|
+
|
|
544
|
+
def has_control(self) -> bool:
|
|
545
|
+
return any(len(c.ctrl) > 0 for c in self.control)
|
|
546
|
+
|
|
547
|
+
def has_power(self) -> bool:
|
|
548
|
+
return len(self.power) > 0
|