bloqade-circuit 0.6.4__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bloqade/analysis/address/__init__.py +8 -4
- bloqade/analysis/address/analysis.py +123 -33
- bloqade/analysis/address/impls.py +293 -90
- bloqade/analysis/address/lattice.py +209 -24
- bloqade/analysis/fidelity/analysis.py +11 -23
- bloqade/analysis/measure_id/analysis.py +18 -20
- bloqade/analysis/measure_id/impls.py +31 -29
- bloqade/annotate/__init__.py +6 -0
- bloqade/annotate/_dialect.py +3 -0
- bloqade/annotate/_interface.py +22 -0
- bloqade/annotate/stmts.py +29 -0
- bloqade/annotate/types.py +13 -0
- bloqade/cirq_utils/__init__.py +4 -2
- bloqade/cirq_utils/emit/__init__.py +3 -0
- bloqade/cirq_utils/emit/base.py +246 -0
- bloqade/cirq_utils/emit/gate.py +104 -0
- bloqade/cirq_utils/emit/noise.py +90 -0
- bloqade/cirq_utils/emit/qubit.py +35 -0
- bloqade/cirq_utils/lowering.py +660 -0
- bloqade/cirq_utils/noise/__init__.py +0 -2
- bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
- bloqade/cirq_utils/noise/model.py +151 -191
- bloqade/cirq_utils/noise/transform.py +2 -2
- bloqade/cirq_utils/parallelize.py +9 -6
- bloqade/gemini/__init__.py +1 -0
- bloqade/gemini/analysis/__init__.py +3 -0
- bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
- bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
- bloqade/gemini/analysis/logical_validation/impls.py +101 -0
- bloqade/gemini/groups.py +67 -0
- bloqade/native/__init__.py +23 -0
- bloqade/native/_prelude.py +45 -0
- bloqade/native/dialects/__init__.py +0 -0
- bloqade/native/dialects/gate/__init__.py +2 -0
- bloqade/native/dialects/gate/_dialect.py +3 -0
- bloqade/native/dialects/gate/_interface.py +32 -0
- bloqade/native/dialects/gate/stmts.py +31 -0
- bloqade/native/stdlib/__init__.py +0 -0
- bloqade/native/stdlib/broadcast.py +246 -0
- bloqade/native/stdlib/simple.py +220 -0
- bloqade/native/upstream/__init__.py +4 -0
- bloqade/native/upstream/squin2native.py +79 -0
- bloqade/pyqrack/__init__.py +2 -2
- bloqade/pyqrack/base.py +7 -1
- bloqade/pyqrack/device.py +192 -18
- bloqade/pyqrack/native.py +49 -0
- bloqade/pyqrack/reg.py +6 -6
- bloqade/pyqrack/squin/gate/__init__.py +1 -0
- bloqade/pyqrack/squin/gate/gate.py +136 -0
- bloqade/pyqrack/squin/noise/native.py +120 -54
- bloqade/pyqrack/squin/qubit.py +39 -36
- bloqade/pyqrack/target.py +5 -4
- bloqade/pyqrack/task.py +114 -7
- bloqade/qasm2/_qasm_loading.py +3 -3
- bloqade/qasm2/dialects/core/address.py +21 -12
- bloqade/qasm2/dialects/expr/_emit.py +19 -8
- bloqade/qasm2/dialects/expr/stmts.py +7 -7
- bloqade/qasm2/dialects/noise/fidelity.py +4 -8
- bloqade/qasm2/dialects/noise/model.py +2 -1
- bloqade/qasm2/emit/base.py +16 -11
- bloqade/qasm2/emit/gate.py +11 -8
- bloqade/qasm2/emit/main.py +103 -3
- bloqade/qasm2/emit/target.py +9 -5
- bloqade/qasm2/groups.py +3 -2
- bloqade/qasm2/parse/lowering.py +0 -1
- bloqade/qasm2/passes/fold.py +14 -73
- bloqade/qasm2/passes/glob.py +2 -2
- bloqade/qasm2/passes/noise.py +1 -1
- bloqade/qasm2/passes/parallel.py +7 -5
- bloqade/qasm2/rewrite/__init__.py +0 -1
- bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
- bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
- bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
- bloqade/qasm2/rewrite/register.py +2 -2
- bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
- bloqade/qbraid/lowering.py +1 -0
- bloqade/qbraid/schema.py +2 -2
- bloqade/qubit/__init__.py +12 -0
- bloqade/qubit/_dialect.py +3 -0
- bloqade/qubit/_interface.py +49 -0
- bloqade/qubit/_prelude.py +45 -0
- bloqade/qubit/analysis/__init__.py +1 -0
- bloqade/qubit/analysis/address_impl.py +40 -0
- bloqade/qubit/stdlib/__init__.py +2 -0
- bloqade/qubit/stdlib/_new.py +34 -0
- bloqade/qubit/stdlib/broadcast.py +62 -0
- bloqade/qubit/stdlib/simple.py +59 -0
- bloqade/qubit/stmts.py +60 -0
- bloqade/rewrite/passes/__init__.py +6 -0
- bloqade/rewrite/passes/aggressive_unroll.py +103 -0
- bloqade/rewrite/passes/callgraph.py +116 -0
- bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
- bloqade/rewrite/rules/split_ifs.py +18 -1
- bloqade/squin/__init__.py +47 -14
- bloqade/squin/analysis/__init__.py +0 -1
- bloqade/squin/analysis/schedule.py +10 -11
- bloqade/squin/gate/__init__.py +2 -0
- bloqade/squin/gate/_dialect.py +3 -0
- bloqade/squin/gate/_interface.py +98 -0
- bloqade/squin/gate/stmts.py +125 -0
- bloqade/squin/groups.py +5 -22
- bloqade/squin/noise/__init__.py +1 -10
- bloqade/squin/noise/_dialect.py +1 -1
- bloqade/squin/noise/_interface.py +45 -0
- bloqade/squin/noise/stmts.py +66 -28
- bloqade/squin/rewrite/U3_to_clifford.py +70 -51
- bloqade/squin/rewrite/__init__.py +0 -2
- bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
- bloqade/squin/rewrite/wrap_analysis.py +4 -35
- bloqade/squin/stdlib/__init__.py +0 -0
- bloqade/squin/stdlib/broadcast/__init__.py +34 -0
- bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
- bloqade/squin/stdlib/broadcast/gate.py +260 -0
- bloqade/squin/stdlib/broadcast/noise.py +144 -0
- bloqade/squin/stdlib/simple/__init__.py +33 -0
- bloqade/squin/stdlib/simple/gate.py +242 -0
- bloqade/squin/stdlib/simple/noise.py +126 -0
- bloqade/stim/__init__.py +1 -0
- bloqade/stim/_wrappers.py +6 -0
- bloqade/stim/dialects/auxiliary/emit.py +19 -18
- bloqade/stim/dialects/collapse/emit_str.py +7 -8
- bloqade/stim/dialects/gate/emit.py +9 -10
- bloqade/stim/dialects/noise/emit.py +17 -13
- bloqade/stim/dialects/noise/stmts.py +5 -3
- bloqade/stim/emit/__init__.py +1 -0
- bloqade/stim/emit/impls.py +16 -0
- bloqade/stim/emit/stim_str.py +48 -31
- bloqade/stim/groups.py +12 -2
- bloqade/stim/parse/lowering.py +14 -17
- bloqade/stim/passes/__init__.py +0 -2
- bloqade/stim/passes/flatten.py +26 -0
- bloqade/stim/passes/simplify_ifs.py +6 -1
- bloqade/stim/passes/squin_to_stim.py +9 -84
- bloqade/stim/rewrite/__init__.py +2 -4
- bloqade/stim/rewrite/get_record_util.py +24 -0
- bloqade/stim/rewrite/ifs_to_stim.py +24 -25
- bloqade/stim/rewrite/qubit_to_stim.py +90 -41
- bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
- bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
- bloqade/stim/rewrite/squin_measure.py +9 -18
- bloqade/stim/rewrite/squin_noise.py +134 -108
- bloqade/stim/rewrite/util.py +5 -192
- bloqade/test_utils.py +1 -1
- bloqade/types.py +10 -0
- bloqade/validation/__init__.py +2 -0
- bloqade/validation/analysis/__init__.py +5 -0
- bloqade/validation/analysis/analysis.py +41 -0
- bloqade/validation/analysis/lattice.py +58 -0
- bloqade/validation/kernel_validation.py +77 -0
- {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
- bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
- bloqade/pyqrack/squin/op.py +0 -180
- bloqade/pyqrack/squin/runtime.py +0 -535
- bloqade/pyqrack/squin/wire.py +0 -51
- bloqade/rewrite/rules/flatten_ilist.py +0 -51
- bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
- bloqade/squin/_typeinfer.py +0 -20
- bloqade/squin/analysis/address_impl.py +0 -71
- bloqade/squin/analysis/nsites/__init__.py +0 -9
- bloqade/squin/analysis/nsites/analysis.py +0 -50
- bloqade/squin/analysis/nsites/impls.py +0 -92
- bloqade/squin/analysis/nsites/lattice.py +0 -49
- bloqade/squin/cirq/__init__.py +0 -280
- bloqade/squin/cirq/emit/emit_circuit.py +0 -109
- bloqade/squin/cirq/emit/noise.py +0 -49
- bloqade/squin/cirq/emit/op.py +0 -125
- bloqade/squin/cirq/emit/qubit.py +0 -60
- bloqade/squin/cirq/emit/runtime.py +0 -242
- bloqade/squin/cirq/lowering.py +0 -440
- bloqade/squin/lowering.py +0 -54
- bloqade/squin/noise/_wrapper.py +0 -40
- bloqade/squin/noise/rewrite.py +0 -111
- bloqade/squin/op/__init__.py +0 -41
- bloqade/squin/op/_dialect.py +0 -3
- bloqade/squin/op/_wrapper.py +0 -121
- bloqade/squin/op/number.py +0 -5
- bloqade/squin/op/rewrite.py +0 -46
- bloqade/squin/op/stdlib.py +0 -62
- bloqade/squin/op/stmts.py +0 -276
- bloqade/squin/op/traits.py +0 -43
- bloqade/squin/op/types.py +0 -26
- bloqade/squin/qubit.py +0 -184
- bloqade/squin/rewrite/canonicalize.py +0 -60
- bloqade/squin/rewrite/desugar.py +0 -124
- bloqade/squin/types.py +0 -8
- bloqade/squin/wire.py +0 -201
- bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
- bloqade/stim/rewrite/wire_to_stim.py +0 -57
- bloqade_circuit-0.6.4.dist-info/RECORD +0 -234
- {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
5
|
+
|
|
6
|
+
from bloqade.stim.dialects import auxiliary
|
|
7
|
+
from bloqade.annotate.stmts import SetObservable
|
|
8
|
+
from bloqade.analysis.measure_id import MeasureIDFrame
|
|
9
|
+
from bloqade.stim.dialects.auxiliary import ObservableInclude
|
|
10
|
+
from bloqade.analysis.measure_id.lattice import MeasureIdTuple
|
|
11
|
+
|
|
12
|
+
from ..rewrite.get_record_util import insert_get_records
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
|
|
16
|
+
class SetObservableToStim(RewriteRule):
|
|
17
|
+
"""
|
|
18
|
+
Rewrite SetObservable to GetRecord and ObservableInclude in the stim dialect
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
measure_id_frame: MeasureIDFrame
|
|
22
|
+
|
|
23
|
+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
24
|
+
match node:
|
|
25
|
+
case SetObservable():
|
|
26
|
+
return self.rewrite_SetObservable(node)
|
|
27
|
+
case _:
|
|
28
|
+
return RewriteResult()
|
|
29
|
+
|
|
30
|
+
def rewrite_SetObservable(self, node: SetObservable) -> RewriteResult:
|
|
31
|
+
|
|
32
|
+
# set idx to 0 for now, but this
|
|
33
|
+
# should be something that a user can set on their own.
|
|
34
|
+
# SetObservable needs to accept an int.
|
|
35
|
+
|
|
36
|
+
idx_stmt = auxiliary.ConstInt(value=0)
|
|
37
|
+
idx_stmt.insert_before(node)
|
|
38
|
+
|
|
39
|
+
measure_ids = self.measure_id_frame.entries[node.measurements]
|
|
40
|
+
assert isinstance(measure_ids, MeasureIdTuple)
|
|
41
|
+
|
|
42
|
+
get_record_list = insert_get_records(
|
|
43
|
+
node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node]
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
observable_include_stmt = ObservableInclude(
|
|
47
|
+
idx=idx_stmt.result, targets=tuple(get_record_list)
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
node.replace_by(observable_include_stmt)
|
|
51
|
+
|
|
52
|
+
return RewriteResult(has_done_something=True)
|
|
@@ -5,11 +5,10 @@ from kirin import ir
|
|
|
5
5
|
from kirin.dialects import py
|
|
6
6
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
7
7
|
|
|
8
|
-
from bloqade
|
|
8
|
+
from bloqade import qubit
|
|
9
9
|
from bloqade.squin.rewrite import AddressAttribute
|
|
10
10
|
from bloqade.stim.dialects import collapse
|
|
11
11
|
from bloqade.stim.rewrite.util import (
|
|
12
|
-
is_measure_result_used,
|
|
13
12
|
insert_qubit_idx_from_address,
|
|
14
13
|
)
|
|
15
14
|
|
|
@@ -23,14 +22,12 @@ class SquinMeasureToStim(RewriteRule):
|
|
|
23
22
|
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
24
23
|
|
|
25
24
|
match node:
|
|
26
|
-
case qubit.
|
|
25
|
+
case qubit.stmts.Measure():
|
|
27
26
|
return self.rewrite_Measure(node)
|
|
28
27
|
case _:
|
|
29
28
|
return RewriteResult()
|
|
30
29
|
|
|
31
|
-
def rewrite_Measure(
|
|
32
|
-
self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
|
|
33
|
-
) -> RewriteResult:
|
|
30
|
+
def rewrite_Measure(self, measure_stmt: qubit.stmts.Measure) -> RewriteResult:
|
|
34
31
|
|
|
35
32
|
qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt)
|
|
36
33
|
if qubit_idx_ssas is None:
|
|
@@ -44,27 +41,21 @@ class SquinMeasureToStim(RewriteRule):
|
|
|
44
41
|
prob_noise_stmt.insert_before(measure_stmt)
|
|
45
42
|
stim_measure_stmt.insert_before(measure_stmt)
|
|
46
43
|
|
|
47
|
-
if not
|
|
44
|
+
# if the measurement is not being used anywhere
|
|
45
|
+
# we can safely get rid of it. Measure cannot be DCE'd because
|
|
46
|
+
# it is not pure.
|
|
47
|
+
if not bool(measure_stmt.result.uses):
|
|
48
48
|
measure_stmt.delete()
|
|
49
49
|
|
|
50
50
|
return RewriteResult(has_done_something=True)
|
|
51
51
|
|
|
52
52
|
def get_qubit_idx_ssas(
|
|
53
|
-
self, measure_stmt: qubit.
|
|
53
|
+
self, measure_stmt: qubit.stmts.Measure
|
|
54
54
|
) -> tuple[ir.SSAValue, ...] | None:
|
|
55
55
|
"""
|
|
56
56
|
Extract the address attribute and insert qubit indices for the given measure statement.
|
|
57
57
|
"""
|
|
58
|
-
|
|
59
|
-
case qubit.MeasureQubit():
|
|
60
|
-
address_attr = measure_stmt.qubit.hints.get("address")
|
|
61
|
-
case qubit.MeasureQubitList():
|
|
62
|
-
address_attr = measure_stmt.qubits.hints.get("address")
|
|
63
|
-
case wire.Measure():
|
|
64
|
-
address_attr = measure_stmt.wire.hints.get("address")
|
|
65
|
-
case _:
|
|
66
|
-
return None
|
|
67
|
-
|
|
58
|
+
address_attr = measure_stmt.qubits.hints.get("address")
|
|
68
59
|
if address_attr is None:
|
|
69
60
|
return None
|
|
70
61
|
|
|
@@ -1,17 +1,17 @@
|
|
|
1
|
+
import itertools
|
|
1
2
|
from typing import Tuple
|
|
2
3
|
from dataclasses import dataclass
|
|
3
4
|
|
|
5
|
+
from kirin import types
|
|
4
6
|
from kirin.ir import SSAValue, Statement
|
|
5
|
-
from kirin.dialects import py
|
|
7
|
+
from kirin.dialects import py
|
|
6
8
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
7
9
|
|
|
8
|
-
from bloqade.squin import
|
|
10
|
+
from bloqade.squin import noise as squin_noise
|
|
9
11
|
from bloqade.stim.dialects import noise as stim_noise
|
|
10
|
-
from bloqade.stim.rewrite.util import
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
insert_qubit_idx_after_apply,
|
|
14
|
-
)
|
|
12
|
+
from bloqade.stim.rewrite.util import insert_qubit_idx_from_address
|
|
13
|
+
from bloqade.analysis.address.lattice import AddressReg, PartialIList
|
|
14
|
+
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
@dataclass
|
|
@@ -19,157 +19,183 @@ class SquinNoiseToStim(RewriteRule):
|
|
|
19
19
|
|
|
20
20
|
def rewrite_Statement(self, node: Statement) -> RewriteResult:
|
|
21
21
|
match node:
|
|
22
|
-
case
|
|
23
|
-
return self.
|
|
22
|
+
case squin_noise.stmts.NoiseChannel():
|
|
23
|
+
return self.rewrite_NoiseChannel(node)
|
|
24
24
|
case _:
|
|
25
25
|
return RewriteResult()
|
|
26
26
|
|
|
27
|
-
def
|
|
28
|
-
self, stmt:
|
|
27
|
+
def rewrite_NoiseChannel(
|
|
28
|
+
self, stmt: squin_noise.stmts.NoiseChannel
|
|
29
29
|
) -> RewriteResult:
|
|
30
|
-
"""Rewrite
|
|
30
|
+
"""Rewrite NoiseChannel statements to their stim equivalents."""
|
|
31
31
|
|
|
32
|
-
|
|
33
|
-
applied_op = stmt.operator.owner
|
|
32
|
+
rewrite_method = getattr(self, f"rewrite_{type(stmt).__name__}", None)
|
|
34
33
|
|
|
35
|
-
|
|
34
|
+
# No rewrite method exists and the rewrite should stop
|
|
35
|
+
if rewrite_method is None:
|
|
36
36
|
return RewriteResult()
|
|
37
|
+
if isinstance(stmt, squin_noise.stmts.CorrelatedQubitLoss):
|
|
38
|
+
# CorrelatedQubitLoss represents a broadcast operation, but Stim does not
|
|
39
|
+
# support broadcasting for multi-qubit noise channels.
|
|
40
|
+
# Therefore, we must expand the broadcast into individual stim statements.
|
|
41
|
+
qubit_address_attr = stmt.qubits.hints.get("address", None)
|
|
37
42
|
|
|
38
|
-
|
|
43
|
+
if not isinstance(qubit_address_attr, AddressAttribute):
|
|
44
|
+
return RewriteResult()
|
|
45
|
+
|
|
46
|
+
if not isinstance(address := qubit_address_attr.address, PartialIList):
|
|
47
|
+
return RewriteResult()
|
|
39
48
|
|
|
40
|
-
|
|
41
|
-
if qubit_idx_ssas is None:
|
|
49
|
+
if not types.is_tuple_of(data := address.data, AddressReg):
|
|
42
50
|
return RewriteResult()
|
|
43
51
|
|
|
44
|
-
|
|
45
|
-
stim_stmt = rewrite_method(stmt, qubit_idx_ssas)
|
|
52
|
+
for address_reg in data:
|
|
46
53
|
|
|
47
|
-
|
|
48
|
-
|
|
54
|
+
qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
55
|
+
AddressAttribute(address_reg), stmt
|
|
56
|
+
)
|
|
49
57
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
58
|
+
stim_stmt = rewrite_method(stmt, qubit_idx_ssas)
|
|
59
|
+
stim_stmt.insert_before(stmt)
|
|
60
|
+
|
|
61
|
+
stmt.delete()
|
|
54
62
|
|
|
55
63
|
return RewriteResult(has_done_something=True)
|
|
56
|
-
return RewriteResult()
|
|
57
64
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
65
|
+
if isinstance(stmt, squin_noise.stmts.SingleQubitNoiseChannel):
|
|
66
|
+
qubit_address_attr = stmt.qubits.hints.get("address", None)
|
|
67
|
+
if qubit_address_attr is None:
|
|
68
|
+
return RewriteResult()
|
|
69
|
+
qubit_idx_ssas = insert_qubit_idx_from_address(qubit_address_attr, stmt)
|
|
70
|
+
|
|
71
|
+
elif isinstance(stmt, squin_noise.stmts.TwoQubitNoiseChannel):
|
|
72
|
+
control_address_attr = stmt.controls.hints.get("address", None)
|
|
73
|
+
target_address_attr = stmt.targets.hints.get("address", None)
|
|
74
|
+
if control_address_attr is None or target_address_attr is None:
|
|
75
|
+
return RewriteResult()
|
|
76
|
+
control_qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
77
|
+
control_address_attr, stmt
|
|
78
|
+
)
|
|
79
|
+
target_qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
80
|
+
target_address_attr, stmt
|
|
81
|
+
)
|
|
82
|
+
if control_qubit_idx_ssas is None or target_qubit_idx_ssas is None:
|
|
83
|
+
return RewriteResult()
|
|
84
|
+
|
|
85
|
+
# For stim statements you want to interleave the control and target qubit indices:
|
|
86
|
+
# ex: CX controls = (0,1) targets = (2,3) in stim is: CX 0 2 1 3
|
|
87
|
+
qubit_idx_ssas = list(
|
|
88
|
+
itertools.chain.from_iterable(
|
|
89
|
+
zip(control_qubit_idx_ssas, target_qubit_idx_ssas)
|
|
90
|
+
)
|
|
91
|
+
)
|
|
77
92
|
else:
|
|
78
|
-
|
|
79
|
-
|
|
93
|
+
return RewriteResult()
|
|
94
|
+
|
|
95
|
+
# guaranteed that you have a valid stim_stmt to plug in
|
|
96
|
+
stim_stmt = rewrite_method(stmt, tuple(qubit_idx_ssas))
|
|
97
|
+
stmt.replace_by(stim_stmt)
|
|
98
|
+
|
|
99
|
+
return RewriteResult(has_done_something=True)
|
|
80
100
|
|
|
81
101
|
def rewrite_SingleQubitPauliChannel(
|
|
82
102
|
self,
|
|
83
|
-
stmt:
|
|
103
|
+
stmt: squin_noise.stmts.SingleQubitPauliChannel,
|
|
84
104
|
qubit_idx_ssas: Tuple[SSAValue],
|
|
85
105
|
) -> Statement:
|
|
86
106
|
"""Rewrite squin.noise.SingleQubitPauliChannel to stim.PauliChannel1."""
|
|
87
107
|
|
|
88
|
-
squin_channel = stmt.operator.owner
|
|
89
|
-
assert isinstance(squin_channel, squin_noise.stmts.SingleQubitPauliChannel)
|
|
90
|
-
|
|
91
|
-
params = get_const_value(ilist.IList, squin_channel.params)
|
|
92
|
-
new_stmts = [
|
|
93
|
-
p_x := py.Constant(params[0]),
|
|
94
|
-
p_y := py.Constant(params[1]),
|
|
95
|
-
p_z := py.Constant(params[2]),
|
|
96
|
-
]
|
|
97
|
-
for new_stmt in new_stmts:
|
|
98
|
-
new_stmt.insert_before(stmt)
|
|
99
|
-
|
|
100
108
|
stim_stmt = stim_noise.PauliChannel1(
|
|
101
109
|
targets=qubit_idx_ssas,
|
|
102
|
-
px=
|
|
103
|
-
py=
|
|
104
|
-
pz=
|
|
110
|
+
px=stmt.px,
|
|
111
|
+
py=stmt.py,
|
|
112
|
+
pz=stmt.pz,
|
|
105
113
|
)
|
|
106
114
|
return stim_stmt
|
|
107
115
|
|
|
108
|
-
def
|
|
116
|
+
def rewrite_QubitLoss(
|
|
109
117
|
self,
|
|
110
|
-
stmt:
|
|
118
|
+
stmt: squin_noise.stmts.QubitLoss,
|
|
111
119
|
qubit_idx_ssas: Tuple[SSAValue],
|
|
112
120
|
) -> Statement:
|
|
113
|
-
"""Rewrite squin.noise.
|
|
121
|
+
"""Rewrite squin.noise.QubitLoss to stim.TrivialError."""
|
|
114
122
|
|
|
115
|
-
|
|
116
|
-
|
|
123
|
+
stim_stmt = stim_noise.QubitLoss(
|
|
124
|
+
targets=qubit_idx_ssas,
|
|
125
|
+
probs=(stmt.p,),
|
|
126
|
+
)
|
|
117
127
|
|
|
118
|
-
|
|
119
|
-
param_stmts = [py.Constant(p) for p in params]
|
|
120
|
-
for param_stmt in param_stmts:
|
|
121
|
-
param_stmt.insert_before(stmt)
|
|
128
|
+
return stim_stmt
|
|
122
129
|
|
|
123
|
-
|
|
130
|
+
def rewrite_CorrelatedQubitLoss(
|
|
131
|
+
self,
|
|
132
|
+
stmt: squin_noise.stmts.CorrelatedQubitLoss,
|
|
133
|
+
qubit_idx_ssas: Tuple[SSAValue],
|
|
134
|
+
) -> Statement:
|
|
135
|
+
"""Rewrite squin.noise.CorrelatedQubitLoss to stim.CorrelatedQubitLoss."""
|
|
136
|
+
stim_stmt = stim_noise.CorrelatedQubitLoss(
|
|
124
137
|
targets=qubit_idx_ssas,
|
|
125
|
-
|
|
126
|
-
piy=param_stmts[1].result,
|
|
127
|
-
piz=param_stmts[2].result,
|
|
128
|
-
pxi=param_stmts[3].result,
|
|
129
|
-
pxx=param_stmts[4].result,
|
|
130
|
-
pxy=param_stmts[5].result,
|
|
131
|
-
pxz=param_stmts[6].result,
|
|
132
|
-
pyi=param_stmts[7].result,
|
|
133
|
-
pyx=param_stmts[8].result,
|
|
134
|
-
pyy=param_stmts[9].result,
|
|
135
|
-
pyz=param_stmts[10].result,
|
|
136
|
-
pzi=param_stmts[11].result,
|
|
137
|
-
pzx=param_stmts[12].result,
|
|
138
|
-
pzy=param_stmts[13].result,
|
|
139
|
-
pzz=param_stmts[14].result,
|
|
138
|
+
probs=(stmt.p,),
|
|
140
139
|
)
|
|
140
|
+
|
|
141
141
|
return stim_stmt
|
|
142
142
|
|
|
143
|
-
def
|
|
143
|
+
def rewrite_Depolarize(
|
|
144
144
|
self,
|
|
145
|
-
stmt:
|
|
145
|
+
stmt: squin_noise.stmts.Depolarize,
|
|
146
146
|
qubit_idx_ssas: Tuple[SSAValue],
|
|
147
147
|
) -> Statement:
|
|
148
|
-
"""Rewrite squin.noise.
|
|
149
|
-
|
|
150
|
-
squin_channel = stmt.operator.owner
|
|
151
|
-
assert isinstance(squin_channel, squin_noise.stmts.Depolarize2)
|
|
148
|
+
"""Rewrite squin.noise.Depolarize to stim.Depolarize1."""
|
|
152
149
|
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
150
|
+
stim_stmt = stim_noise.Depolarize1(
|
|
151
|
+
targets=qubit_idx_ssas,
|
|
152
|
+
p=stmt.p,
|
|
153
|
+
)
|
|
156
154
|
|
|
157
|
-
stim_stmt = stim_noise.Depolarize2(targets=qubit_idx_ssas, p=p_stmt.result)
|
|
158
155
|
return stim_stmt
|
|
159
156
|
|
|
160
|
-
def
|
|
157
|
+
def rewrite_TwoQubitPauliChannel(
|
|
161
158
|
self,
|
|
162
|
-
stmt:
|
|
159
|
+
stmt: squin_noise.stmts.TwoQubitPauliChannel,
|
|
163
160
|
qubit_idx_ssas: Tuple[SSAValue],
|
|
164
161
|
) -> Statement:
|
|
165
|
-
"""Rewrite squin.noise.
|
|
162
|
+
"""Rewrite squin.noise.TwoQubitPauliChannel to stim.PauliChannel2."""
|
|
163
|
+
|
|
164
|
+
params = stmt.probabilities
|
|
165
|
+
prob_ssas = []
|
|
166
|
+
for idx in range(15):
|
|
167
|
+
idx_stmt = py.Constant(value=idx)
|
|
168
|
+
idx_stmt.insert_before(stmt)
|
|
169
|
+
getitem_stmt = py.GetItem(obj=params, index=idx_stmt.result)
|
|
170
|
+
getitem_stmt.insert_before(stmt)
|
|
171
|
+
prob_ssas.append(getitem_stmt.result)
|
|
166
172
|
|
|
167
|
-
|
|
168
|
-
|
|
173
|
+
stim_stmt = stim_noise.PauliChannel2(
|
|
174
|
+
targets=qubit_idx_ssas,
|
|
175
|
+
pix=prob_ssas[0],
|
|
176
|
+
piy=prob_ssas[1],
|
|
177
|
+
piz=prob_ssas[2],
|
|
178
|
+
pxi=prob_ssas[3],
|
|
179
|
+
pxx=prob_ssas[4],
|
|
180
|
+
pxy=prob_ssas[5],
|
|
181
|
+
pxz=prob_ssas[6],
|
|
182
|
+
pyi=prob_ssas[7],
|
|
183
|
+
pyx=prob_ssas[8],
|
|
184
|
+
pyy=prob_ssas[9],
|
|
185
|
+
pyz=prob_ssas[10],
|
|
186
|
+
pzi=prob_ssas[11],
|
|
187
|
+
pzx=prob_ssas[12],
|
|
188
|
+
pzy=prob_ssas[13],
|
|
189
|
+
pzz=prob_ssas[14],
|
|
190
|
+
)
|
|
191
|
+
return stim_stmt
|
|
169
192
|
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
193
|
+
def rewrite_Depolarize2(
|
|
194
|
+
self,
|
|
195
|
+
stmt: squin_noise.stmts.Depolarize2,
|
|
196
|
+
qubit_idx_ssas: Tuple[SSAValue],
|
|
197
|
+
) -> Statement:
|
|
198
|
+
"""Rewrite squin.noise.Depolarize2 to stim.Depolarize2."""
|
|
173
199
|
|
|
174
|
-
stim_stmt = stim_noise.
|
|
200
|
+
stim_stmt = stim_noise.Depolarize2(targets=qubit_idx_ssas, p=stmt.p)
|
|
175
201
|
return stim_stmt
|
bloqade/stim/rewrite/util.py
CHANGED
|
@@ -1,35 +1,8 @@
|
|
|
1
|
-
from
|
|
2
|
-
|
|
3
|
-
from kirin import ir, interp
|
|
4
|
-
from kirin.analysis import const
|
|
1
|
+
from kirin import ir
|
|
5
2
|
from kirin.dialects import py
|
|
6
|
-
from kirin.rewrite.abc import RewriteResult
|
|
7
3
|
|
|
8
|
-
from bloqade.squin import op, wire, noise as squin_noise, qubit
|
|
9
4
|
from bloqade.squin.rewrite import AddressAttribute
|
|
10
|
-
from bloqade.
|
|
11
|
-
from bloqade.analysis.address import AddressReg, AddressWire, AddressQubit, AddressTuple
|
|
12
|
-
|
|
13
|
-
SQUIN_STIM_OP_MAPPING = {
|
|
14
|
-
op.stmts.X: gate.X,
|
|
15
|
-
op.stmts.Y: gate.Y,
|
|
16
|
-
op.stmts.Z: gate.Z,
|
|
17
|
-
op.stmts.H: gate.H,
|
|
18
|
-
op.stmts.S: gate.S,
|
|
19
|
-
op.stmts.SqrtX: gate.SqrtX,
|
|
20
|
-
op.stmts.SqrtY: gate.SqrtY,
|
|
21
|
-
op.stmts.Identity: gate.Identity,
|
|
22
|
-
op.stmts.Reset: collapse.RZ,
|
|
23
|
-
squin_noise.stmts.QubitLoss: stim_noise.QubitLoss,
|
|
24
|
-
}
|
|
25
|
-
|
|
26
|
-
# Squin allows creation of control gates where the gate can be any operator,
|
|
27
|
-
# but Stim only supports CX, CY, and CZ as control gates.
|
|
28
|
-
SQUIN_STIM_CONTROL_GATE_MAPPING = {
|
|
29
|
-
op.stmts.X: gate.CX,
|
|
30
|
-
op.stmts.Y: gate.CY,
|
|
31
|
-
op.stmts.Z: gate.CZ,
|
|
32
|
-
}
|
|
5
|
+
from bloqade.analysis.address import AddressReg, AddressQubit
|
|
33
6
|
|
|
34
7
|
|
|
35
8
|
def create_and_insert_qubit_idx_stmt(
|
|
@@ -46,177 +19,17 @@ def insert_qubit_idx_from_address(
|
|
|
46
19
|
"""
|
|
47
20
|
Extract qubit indices from an AddressAttribute and insert them into the SSA form.
|
|
48
21
|
"""
|
|
49
|
-
address_data = address.address
|
|
50
22
|
qubit_idx_ssas = []
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
for address_qubit in address_data.data:
|
|
54
|
-
if not isinstance(address_qubit, AddressQubit):
|
|
55
|
-
return
|
|
23
|
+
if isinstance(address_data := address.address, AddressReg):
|
|
24
|
+
for qubit_idx in address_data.qubits:
|
|
56
25
|
create_and_insert_qubit_idx_stmt(
|
|
57
|
-
|
|
58
|
-
)
|
|
59
|
-
elif isinstance(address_data, AddressReg):
|
|
60
|
-
for qubit_idx in address_data.data:
|
|
61
|
-
create_and_insert_qubit_idx_stmt(
|
|
62
|
-
qubit_idx, stmt_to_insert_before, qubit_idx_ssas
|
|
26
|
+
qubit_idx.data, stmt_to_insert_before, qubit_idx_ssas
|
|
63
27
|
)
|
|
64
28
|
elif isinstance(address_data, AddressQubit):
|
|
65
29
|
create_and_insert_qubit_idx_stmt(
|
|
66
30
|
address_data.data, stmt_to_insert_before, qubit_idx_ssas
|
|
67
31
|
)
|
|
68
|
-
elif isinstance(address_data, AddressWire):
|
|
69
|
-
address_qubit = address_data.origin_qubit
|
|
70
|
-
create_and_insert_qubit_idx_stmt(
|
|
71
|
-
address_qubit.data, stmt_to_insert_before, qubit_idx_ssas
|
|
72
|
-
)
|
|
73
32
|
else:
|
|
74
33
|
return
|
|
75
34
|
|
|
76
35
|
return tuple(qubit_idx_ssas)
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
def insert_qubit_idx_from_wire_ssa(
|
|
80
|
-
wire_ssas: tuple[ir.SSAValue, ...], stmt_to_insert_before: ir.Statement
|
|
81
|
-
) -> tuple[ir.SSAValue, ...] | None:
|
|
82
|
-
"""
|
|
83
|
-
Extract qubit indices from wire SSA values and insert them into the SSA form.
|
|
84
|
-
"""
|
|
85
|
-
qubit_idx_ssas = []
|
|
86
|
-
for wire_ssa in wire_ssas:
|
|
87
|
-
address_attribute = wire_ssa.hints.get("address")
|
|
88
|
-
if address_attribute is None:
|
|
89
|
-
return
|
|
90
|
-
assert isinstance(address_attribute, AddressAttribute)
|
|
91
|
-
wire_address = address_attribute.address
|
|
92
|
-
assert isinstance(wire_address, AddressWire)
|
|
93
|
-
qubit_idx = wire_address.origin_qubit.data
|
|
94
|
-
qubit_idx_stmt = py.Constant(qubit_idx)
|
|
95
|
-
qubit_idx_ssas.append(qubit_idx_stmt.result)
|
|
96
|
-
qubit_idx_stmt.insert_before(stmt_to_insert_before)
|
|
97
|
-
|
|
98
|
-
return tuple(qubit_idx_ssas)
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
def insert_qubit_idx_after_apply(
|
|
102
|
-
stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast,
|
|
103
|
-
) -> tuple[ir.SSAValue, ...] | None:
|
|
104
|
-
"""
|
|
105
|
-
Extract qubit indices from Apply or Broadcast statements.
|
|
106
|
-
"""
|
|
107
|
-
if isinstance(stmt, (qubit.Apply, qubit.Broadcast)):
|
|
108
|
-
qubits = stmt.qubits
|
|
109
|
-
address_attribute = qubits.hints.get("address")
|
|
110
|
-
if address_attribute is None:
|
|
111
|
-
return
|
|
112
|
-
assert isinstance(address_attribute, AddressAttribute)
|
|
113
|
-
return insert_qubit_idx_from_address(
|
|
114
|
-
address=address_attribute, stmt_to_insert_before=stmt
|
|
115
|
-
)
|
|
116
|
-
elif isinstance(stmt, (wire.Apply, wire.Broadcast)):
|
|
117
|
-
wire_ssas = stmt.inputs
|
|
118
|
-
return insert_qubit_idx_from_wire_ssa(
|
|
119
|
-
wire_ssas=wire_ssas, stmt_to_insert_before=stmt
|
|
120
|
-
)
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
def rewrite_Control(
|
|
124
|
-
stmt_with_ctrl: qubit.Apply | wire.Apply | qubit.Broadcast | wire.Broadcast,
|
|
125
|
-
) -> RewriteResult:
|
|
126
|
-
"""
|
|
127
|
-
Handle control gates for Apply and Broadcast statements.
|
|
128
|
-
"""
|
|
129
|
-
ctrl_op = stmt_with_ctrl.operator.owner
|
|
130
|
-
assert isinstance(ctrl_op, op.stmts.Control)
|
|
131
|
-
|
|
132
|
-
ctrl_op_target_gate = ctrl_op.op.owner
|
|
133
|
-
assert isinstance(ctrl_op_target_gate, op.stmts.Operator)
|
|
134
|
-
|
|
135
|
-
qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt_with_ctrl)
|
|
136
|
-
if qubit_idx_ssas is None:
|
|
137
|
-
return RewriteResult()
|
|
138
|
-
|
|
139
|
-
# Separate control and target qubits
|
|
140
|
-
target_qubits = []
|
|
141
|
-
ctrl_qubits = []
|
|
142
|
-
for i in range(len(qubit_idx_ssas)):
|
|
143
|
-
if (i % 2) == 0:
|
|
144
|
-
ctrl_qubits.append(qubit_idx_ssas[i])
|
|
145
|
-
else:
|
|
146
|
-
target_qubits.append(qubit_idx_ssas[i])
|
|
147
|
-
|
|
148
|
-
target_qubits = tuple(target_qubits)
|
|
149
|
-
ctrl_qubits = tuple(ctrl_qubits)
|
|
150
|
-
|
|
151
|
-
stim_gate = SQUIN_STIM_CONTROL_GATE_MAPPING.get(type(ctrl_op_target_gate))
|
|
152
|
-
if stim_gate is None:
|
|
153
|
-
return RewriteResult()
|
|
154
|
-
|
|
155
|
-
stim_stmt = stim_gate(controls=ctrl_qubits, targets=target_qubits)
|
|
156
|
-
|
|
157
|
-
if isinstance(stmt_with_ctrl, (wire.Apply, wire.Broadcast)):
|
|
158
|
-
create_wire_passthrough(stmt_with_ctrl)
|
|
159
|
-
|
|
160
|
-
stmt_with_ctrl.replace_by(stim_stmt)
|
|
161
|
-
|
|
162
|
-
return RewriteResult(has_done_something=True)
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
def rewrite_QubitLoss(
|
|
166
|
-
stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
|
|
167
|
-
) -> RewriteResult:
|
|
168
|
-
"""
|
|
169
|
-
Rewrite QubitLoss statements to Stim's TrivialError.
|
|
170
|
-
"""
|
|
171
|
-
|
|
172
|
-
squin_loss_op = stmt.operator.owner
|
|
173
|
-
assert isinstance(squin_loss_op, squin_noise.stmts.QubitLoss)
|
|
174
|
-
|
|
175
|
-
qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt)
|
|
176
|
-
if qubit_idx_ssas is None:
|
|
177
|
-
return RewriteResult()
|
|
178
|
-
|
|
179
|
-
stim_loss_stmt = stim_noise.QubitLoss(
|
|
180
|
-
targets=qubit_idx_ssas,
|
|
181
|
-
probs=(squin_loss_op.p,),
|
|
182
|
-
)
|
|
183
|
-
|
|
184
|
-
if isinstance(stmt, (wire.Apply, wire.Broadcast)):
|
|
185
|
-
create_wire_passthrough(stmt)
|
|
186
|
-
|
|
187
|
-
stmt.replace_by(stim_loss_stmt)
|
|
188
|
-
|
|
189
|
-
return RewriteResult(has_done_something=True)
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
def create_wire_passthrough(stmt: wire.Apply | wire.Broadcast) -> None:
|
|
193
|
-
|
|
194
|
-
for input_wire, output_wire in zip(stmt.inputs, stmt.results):
|
|
195
|
-
# have to "reroute" the input of these statements to directly plug in
|
|
196
|
-
# to subsequent statements, remove dependency on the current statement
|
|
197
|
-
output_wire.replace_by(input_wire)
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
def is_measure_result_used(
|
|
201
|
-
stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure,
|
|
202
|
-
) -> bool:
|
|
203
|
-
"""
|
|
204
|
-
Check if the result of a measure statement is used in the program.
|
|
205
|
-
"""
|
|
206
|
-
return bool(stmt.result.uses)
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
T = TypeVar("T")
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
def get_const_value(typ: type[T], value: ir.SSAValue) -> T:
|
|
213
|
-
if isinstance(hint := value.hints.get("const"), const.Value):
|
|
214
|
-
data = hint.data
|
|
215
|
-
if isinstance(data, typ):
|
|
216
|
-
return hint.data
|
|
217
|
-
raise interp.InterpreterError(
|
|
218
|
-
f"Expected constant value <type = {typ}>, got {data}"
|
|
219
|
-
)
|
|
220
|
-
raise interp.InterpreterError(
|
|
221
|
-
f"Expected constant value <type = {typ}>, got {value}"
|
|
222
|
-
)
|
bloqade/test_utils.py
CHANGED
|
@@ -25,7 +25,7 @@ def print_diff(node: pprint.Printable, expected_node: pprint.Printable):
|
|
|
25
25
|
|
|
26
26
|
def assert_nodes(node: ir.IRNode, expected_node: ir.IRNode):
|
|
27
27
|
try:
|
|
28
|
-
assert node.
|
|
28
|
+
assert node.is_structurally_equal(expected_node)
|
|
29
29
|
except AssertionError as e:
|
|
30
30
|
print_diff(node, expected_node)
|
|
31
31
|
raise e
|
bloqade/types.py
CHANGED
|
@@ -22,3 +22,13 @@ class Qubit(ABC):
|
|
|
22
22
|
|
|
23
23
|
QubitType = types.PyClass(Qubit)
|
|
24
24
|
"""Kirin type for a qubit."""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class MeasurementResult:
|
|
28
|
+
"""Runtime representation of the result of a measurement on a qubit."""
|
|
29
|
+
|
|
30
|
+
pass
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
MeasurementResultType = types.PyClass(MeasurementResult)
|
|
34
|
+
"""Kirin type for a measurement result."""
|