bloqade-circuit 0.6.2__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bloqade/analysis/address/__init__.py +8 -4
- bloqade/analysis/address/analysis.py +123 -33
- bloqade/analysis/address/impls.py +293 -90
- bloqade/analysis/address/lattice.py +209 -24
- bloqade/analysis/fidelity/analysis.py +11 -23
- bloqade/analysis/measure_id/__init__.py +4 -1
- bloqade/analysis/measure_id/analysis.py +29 -20
- bloqade/analysis/measure_id/impls.py +72 -31
- bloqade/annotate/__init__.py +6 -0
- bloqade/annotate/_dialect.py +3 -0
- bloqade/annotate/_interface.py +22 -0
- bloqade/annotate/stmts.py +29 -0
- bloqade/annotate/types.py +13 -0
- bloqade/cirq_utils/__init__.py +4 -2
- bloqade/cirq_utils/emit/__init__.py +3 -0
- bloqade/cirq_utils/emit/base.py +246 -0
- bloqade/cirq_utils/emit/gate.py +104 -0
- bloqade/cirq_utils/emit/noise.py +90 -0
- bloqade/cirq_utils/emit/qubit.py +35 -0
- bloqade/cirq_utils/lowering.py +660 -0
- bloqade/cirq_utils/noise/__init__.py +0 -2
- bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
- bloqade/cirq_utils/noise/model.py +151 -191
- bloqade/cirq_utils/noise/transform.py +2 -2
- bloqade/cirq_utils/parallelize.py +9 -6
- bloqade/gemini/__init__.py +1 -0
- bloqade/gemini/analysis/__init__.py +3 -0
- bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
- bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
- bloqade/gemini/analysis/logical_validation/impls.py +101 -0
- bloqade/gemini/groups.py +67 -0
- bloqade/native/__init__.py +23 -0
- bloqade/native/_prelude.py +45 -0
- bloqade/native/dialects/__init__.py +0 -0
- bloqade/native/dialects/gate/__init__.py +2 -0
- bloqade/native/dialects/gate/_dialect.py +3 -0
- bloqade/native/dialects/gate/_interface.py +32 -0
- bloqade/native/dialects/gate/stmts.py +31 -0
- bloqade/native/stdlib/__init__.py +0 -0
- bloqade/native/stdlib/broadcast.py +246 -0
- bloqade/native/stdlib/simple.py +220 -0
- bloqade/native/upstream/__init__.py +4 -0
- bloqade/native/upstream/squin2native.py +79 -0
- bloqade/pyqrack/__init__.py +2 -2
- bloqade/pyqrack/base.py +7 -1
- bloqade/pyqrack/device.py +190 -4
- bloqade/pyqrack/native.py +49 -0
- bloqade/pyqrack/reg.py +6 -6
- bloqade/pyqrack/squin/gate/__init__.py +1 -0
- bloqade/pyqrack/squin/gate/gate.py +136 -0
- bloqade/pyqrack/squin/noise/native.py +120 -54
- bloqade/pyqrack/squin/qubit.py +39 -36
- bloqade/pyqrack/target.py +5 -4
- bloqade/pyqrack/task.py +114 -7
- bloqade/qasm2/_qasm_loading.py +3 -3
- bloqade/qasm2/dialects/core/address.py +21 -12
- bloqade/qasm2/dialects/expr/_emit.py +19 -8
- bloqade/qasm2/dialects/expr/stmts.py +7 -7
- bloqade/qasm2/dialects/noise/fidelity.py +4 -8
- bloqade/qasm2/dialects/noise/model.py +2 -1
- bloqade/qasm2/emit/base.py +16 -11
- bloqade/qasm2/emit/gate.py +11 -8
- bloqade/qasm2/emit/main.py +103 -3
- bloqade/qasm2/emit/target.py +9 -5
- bloqade/qasm2/groups.py +3 -2
- bloqade/qasm2/parse/lowering.py +0 -1
- bloqade/qasm2/passes/fold.py +14 -73
- bloqade/qasm2/passes/glob.py +2 -2
- bloqade/qasm2/passes/noise.py +1 -1
- bloqade/qasm2/passes/parallel.py +7 -5
- bloqade/qasm2/rewrite/__init__.py +0 -1
- bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
- bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
- bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
- bloqade/qasm2/rewrite/register.py +2 -2
- bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
- bloqade/qbraid/lowering.py +1 -0
- bloqade/qbraid/schema.py +2 -2
- bloqade/qubit/__init__.py +12 -0
- bloqade/qubit/_dialect.py +3 -0
- bloqade/qubit/_interface.py +49 -0
- bloqade/qubit/_prelude.py +45 -0
- bloqade/qubit/analysis/__init__.py +1 -0
- bloqade/qubit/analysis/address_impl.py +40 -0
- bloqade/qubit/stdlib/__init__.py +2 -0
- bloqade/qubit/stdlib/_new.py +34 -0
- bloqade/qubit/stdlib/broadcast.py +62 -0
- bloqade/qubit/stdlib/simple.py +59 -0
- bloqade/qubit/stmts.py +60 -0
- bloqade/rewrite/passes/__init__.py +6 -0
- bloqade/rewrite/passes/aggressive_unroll.py +103 -0
- bloqade/rewrite/passes/callgraph.py +116 -0
- bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
- bloqade/rewrite/rules/split_ifs.py +18 -1
- bloqade/squin/__init__.py +47 -14
- bloqade/squin/analysis/__init__.py +0 -1
- bloqade/squin/analysis/schedule.py +10 -11
- bloqade/squin/gate/__init__.py +2 -0
- bloqade/squin/gate/_dialect.py +3 -0
- bloqade/squin/gate/_interface.py +98 -0
- bloqade/squin/gate/stmts.py +125 -0
- bloqade/squin/groups.py +5 -22
- bloqade/squin/noise/__init__.py +1 -10
- bloqade/squin/noise/_dialect.py +1 -1
- bloqade/squin/noise/_interface.py +45 -0
- bloqade/squin/noise/stmts.py +66 -28
- bloqade/squin/rewrite/U3_to_clifford.py +70 -51
- bloqade/squin/rewrite/__init__.py +0 -2
- bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
- bloqade/squin/rewrite/wrap_analysis.py +4 -35
- bloqade/squin/stdlib/__init__.py +0 -0
- bloqade/squin/stdlib/broadcast/__init__.py +34 -0
- bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
- bloqade/squin/stdlib/broadcast/gate.py +260 -0
- bloqade/squin/stdlib/broadcast/noise.py +144 -0
- bloqade/squin/stdlib/simple/__init__.py +33 -0
- bloqade/squin/stdlib/simple/gate.py +242 -0
- bloqade/squin/stdlib/simple/noise.py +126 -0
- bloqade/stim/__init__.py +1 -0
- bloqade/stim/_wrappers.py +6 -0
- bloqade/stim/dialects/auxiliary/emit.py +19 -18
- bloqade/stim/dialects/collapse/emit_str.py +7 -8
- bloqade/stim/dialects/gate/emit.py +9 -10
- bloqade/stim/dialects/noise/emit.py +17 -13
- bloqade/stim/dialects/noise/stmts.py +5 -3
- bloqade/stim/emit/__init__.py +1 -0
- bloqade/stim/emit/impls.py +16 -0
- bloqade/stim/emit/stim_str.py +48 -31
- bloqade/stim/groups.py +12 -2
- bloqade/stim/parse/lowering.py +14 -17
- bloqade/stim/passes/__init__.py +3 -1
- bloqade/stim/passes/flatten.py +26 -0
- bloqade/stim/passes/simplify_ifs.py +16 -2
- bloqade/stim/passes/squin_to_stim.py +18 -60
- bloqade/stim/rewrite/__init__.py +3 -4
- bloqade/stim/rewrite/get_record_util.py +24 -0
- bloqade/stim/rewrite/ifs_to_stim.py +29 -31
- bloqade/stim/rewrite/qubit_to_stim.py +90 -41
- bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
- bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
- bloqade/stim/rewrite/squin_measure.py +11 -79
- bloqade/stim/rewrite/squin_noise.py +134 -108
- bloqade/stim/rewrite/util.py +5 -192
- bloqade/test_utils.py +1 -1
- bloqade/types.py +10 -0
- bloqade/validation/__init__.py +2 -0
- bloqade/validation/analysis/__init__.py +5 -0
- bloqade/validation/analysis/analysis.py +41 -0
- bloqade/validation/analysis/lattice.py +58 -0
- bloqade/validation/kernel_validation.py +77 -0
- {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
- bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
- bloqade/pyqrack/squin/op.py +0 -166
- bloqade/pyqrack/squin/runtime.py +0 -535
- bloqade/pyqrack/squin/wire.py +0 -51
- bloqade/rewrite/rules/flatten_ilist.py +0 -51
- bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
- bloqade/squin/_typeinfer.py +0 -20
- bloqade/squin/analysis/address_impl.py +0 -71
- bloqade/squin/analysis/nsites/__init__.py +0 -9
- bloqade/squin/analysis/nsites/analysis.py +0 -50
- bloqade/squin/analysis/nsites/impls.py +0 -92
- bloqade/squin/analysis/nsites/lattice.py +0 -49
- bloqade/squin/cirq/__init__.py +0 -265
- bloqade/squin/cirq/emit/emit_circuit.py +0 -109
- bloqade/squin/cirq/emit/noise.py +0 -49
- bloqade/squin/cirq/emit/op.py +0 -125
- bloqade/squin/cirq/emit/qubit.py +0 -60
- bloqade/squin/cirq/emit/runtime.py +0 -242
- bloqade/squin/cirq/lowering.py +0 -440
- bloqade/squin/lowering.py +0 -54
- bloqade/squin/noise/_wrapper.py +0 -40
- bloqade/squin/noise/rewrite.py +0 -111
- bloqade/squin/op/__init__.py +0 -41
- bloqade/squin/op/_dialect.py +0 -3
- bloqade/squin/op/_wrapper.py +0 -121
- bloqade/squin/op/number.py +0 -5
- bloqade/squin/op/rewrite.py +0 -46
- bloqade/squin/op/stdlib.py +0 -62
- bloqade/squin/op/stmts.py +0 -276
- bloqade/squin/op/traits.py +0 -43
- bloqade/squin/op/types.py +0 -26
- bloqade/squin/qubit.py +0 -184
- bloqade/squin/rewrite/canonicalize.py +0 -60
- bloqade/squin/rewrite/desugar.py +0 -124
- bloqade/squin/types.py +0 -8
- bloqade/squin/wire.py +0 -201
- bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
- bloqade/stim/rewrite/wire_to_stim.py +0 -57
- bloqade_circuit-0.6.2.dist-info/RECORD +0 -234
- {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from typing import Iterable
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
from kirin import ir
|
|
5
|
+
from kirin.dialects.py import Constant
|
|
6
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
7
|
+
|
|
8
|
+
from bloqade.stim.dialects import auxiliary
|
|
9
|
+
from bloqade.annotate.stmts import SetDetector
|
|
10
|
+
from bloqade.analysis.measure_id import MeasureIDFrame
|
|
11
|
+
from bloqade.stim.dialects.auxiliary import Detector
|
|
12
|
+
from bloqade.analysis.measure_id.lattice import MeasureIdTuple
|
|
13
|
+
|
|
14
|
+
from ..rewrite.get_record_util import insert_get_records
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
|
|
18
|
+
class SetDetectorToStim(RewriteRule):
|
|
19
|
+
"""
|
|
20
|
+
Rewrite SetDetector to GetRecord and Detector in the stim dialect
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
measure_id_frame: MeasureIDFrame
|
|
24
|
+
|
|
25
|
+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
26
|
+
match node:
|
|
27
|
+
case SetDetector():
|
|
28
|
+
return self.rewrite_SetDetector(node)
|
|
29
|
+
case _:
|
|
30
|
+
return RewriteResult()
|
|
31
|
+
|
|
32
|
+
def rewrite_SetDetector(self, node: SetDetector) -> RewriteResult:
|
|
33
|
+
|
|
34
|
+
# get coordinates and generate correct consts
|
|
35
|
+
coord_ssas = []
|
|
36
|
+
if not isinstance(node.coordinates.owner, Constant):
|
|
37
|
+
return RewriteResult()
|
|
38
|
+
|
|
39
|
+
coord_values = node.coordinates.owner.value.unwrap()
|
|
40
|
+
|
|
41
|
+
if not isinstance(coord_values, Iterable):
|
|
42
|
+
return RewriteResult()
|
|
43
|
+
|
|
44
|
+
if any(not isinstance(value, (int, float)) for value in coord_values):
|
|
45
|
+
return RewriteResult()
|
|
46
|
+
|
|
47
|
+
for coord_value in coord_values:
|
|
48
|
+
if isinstance(coord_value, float):
|
|
49
|
+
coord_stmt = auxiliary.ConstFloat(value=coord_value)
|
|
50
|
+
else: # int
|
|
51
|
+
coord_stmt = auxiliary.ConstInt(value=coord_value)
|
|
52
|
+
coord_ssas.append(coord_stmt.result)
|
|
53
|
+
coord_stmt.insert_before(node)
|
|
54
|
+
|
|
55
|
+
measure_ids = self.measure_id_frame.entries[node.measurements]
|
|
56
|
+
assert isinstance(measure_ids, MeasureIdTuple)
|
|
57
|
+
|
|
58
|
+
get_record_list = insert_get_records(
|
|
59
|
+
node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node]
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
detector_stmt = Detector(
|
|
63
|
+
coord=tuple(coord_ssas), targets=tuple(get_record_list)
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
node.replace_by(detector_stmt)
|
|
67
|
+
|
|
68
|
+
return RewriteResult(has_done_something=True)
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
5
|
+
|
|
6
|
+
from bloqade.stim.dialects import auxiliary
|
|
7
|
+
from bloqade.annotate.stmts import SetObservable
|
|
8
|
+
from bloqade.analysis.measure_id import MeasureIDFrame
|
|
9
|
+
from bloqade.stim.dialects.auxiliary import ObservableInclude
|
|
10
|
+
from bloqade.analysis.measure_id.lattice import MeasureIdTuple
|
|
11
|
+
|
|
12
|
+
from ..rewrite.get_record_util import insert_get_records
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
|
|
16
|
+
class SetObservableToStim(RewriteRule):
|
|
17
|
+
"""
|
|
18
|
+
Rewrite SetObservable to GetRecord and ObservableInclude in the stim dialect
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
measure_id_frame: MeasureIDFrame
|
|
22
|
+
|
|
23
|
+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
24
|
+
match node:
|
|
25
|
+
case SetObservable():
|
|
26
|
+
return self.rewrite_SetObservable(node)
|
|
27
|
+
case _:
|
|
28
|
+
return RewriteResult()
|
|
29
|
+
|
|
30
|
+
def rewrite_SetObservable(self, node: SetObservable) -> RewriteResult:
|
|
31
|
+
|
|
32
|
+
# set idx to 0 for now, but this
|
|
33
|
+
# should be something that a user can set on their own.
|
|
34
|
+
# SetObservable needs to accept an int.
|
|
35
|
+
|
|
36
|
+
idx_stmt = auxiliary.ConstInt(value=0)
|
|
37
|
+
idx_stmt.insert_before(node)
|
|
38
|
+
|
|
39
|
+
measure_ids = self.measure_id_frame.entries[node.measurements]
|
|
40
|
+
assert isinstance(measure_ids, MeasureIdTuple)
|
|
41
|
+
|
|
42
|
+
get_record_list = insert_get_records(
|
|
43
|
+
node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node]
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
observable_include_stmt = ObservableInclude(
|
|
47
|
+
idx=idx_stmt.result, targets=tuple(get_record_list)
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
node.replace_by(observable_include_stmt)
|
|
51
|
+
|
|
52
|
+
return RewriteResult(has_done_something=True)
|
|
@@ -2,47 +2,15 @@
|
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
|
|
4
4
|
from kirin import ir
|
|
5
|
-
from kirin.dialects import py
|
|
5
|
+
from kirin.dialects import py
|
|
6
6
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
7
7
|
|
|
8
|
-
from bloqade
|
|
8
|
+
from bloqade import qubit
|
|
9
9
|
from bloqade.squin.rewrite import AddressAttribute
|
|
10
|
-
from bloqade.stim.dialects import collapse
|
|
10
|
+
from bloqade.stim.dialects import collapse
|
|
11
11
|
from bloqade.stim.rewrite.util import (
|
|
12
|
-
is_measure_result_used,
|
|
13
12
|
insert_qubit_idx_from_address,
|
|
14
13
|
)
|
|
15
|
-
from bloqade.analysis.measure_id.lattice import MeasureId, MeasureIdBool, MeasureIdTuple
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
def replace_get_record(
|
|
19
|
-
node: ir.Statement, measure_id_bool: MeasureIdBool, meas_count: int
|
|
20
|
-
):
|
|
21
|
-
assert isinstance(measure_id_bool, MeasureIdBool)
|
|
22
|
-
target_rec_idx = (measure_id_bool.idx - 1) - meas_count
|
|
23
|
-
idx_stmt = py.constant.Constant(target_rec_idx)
|
|
24
|
-
idx_stmt.insert_before(node)
|
|
25
|
-
get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
|
|
26
|
-
node.replace_by(get_record_stmt)
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def insert_get_record_list(
|
|
30
|
-
node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count: int
|
|
31
|
-
):
|
|
32
|
-
"""
|
|
33
|
-
Insert GetRecord statements before the given node
|
|
34
|
-
"""
|
|
35
|
-
get_record_ssas = []
|
|
36
|
-
for measure_id_bool in measure_id_tuple.data:
|
|
37
|
-
assert isinstance(measure_id_bool, MeasureIdBool)
|
|
38
|
-
target_rec_idx = (measure_id_bool.idx - 1) - meas_count
|
|
39
|
-
idx_stmt = py.constant.Constant(target_rec_idx)
|
|
40
|
-
idx_stmt.insert_before(node)
|
|
41
|
-
get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
|
|
42
|
-
get_record_stmt.insert_before(node)
|
|
43
|
-
get_record_ssas.append(get_record_stmt.result)
|
|
44
|
-
|
|
45
|
-
node.replace_by(ilist.New(values=get_record_ssas))
|
|
46
14
|
|
|
47
15
|
|
|
48
16
|
@dataclass
|
|
@@ -51,29 +19,20 @@ class SquinMeasureToStim(RewriteRule):
|
|
|
51
19
|
Rewrite squin measure-related statements to stim statements.
|
|
52
20
|
"""
|
|
53
21
|
|
|
54
|
-
measure_id_result: dict[ir.SSAValue, MeasureId]
|
|
55
|
-
total_measure_count: int
|
|
56
|
-
|
|
57
22
|
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
58
23
|
|
|
59
24
|
match node:
|
|
60
|
-
case qubit.
|
|
25
|
+
case qubit.stmts.Measure():
|
|
61
26
|
return self.rewrite_Measure(node)
|
|
62
27
|
case _:
|
|
63
28
|
return RewriteResult()
|
|
64
29
|
|
|
65
|
-
def rewrite_Measure(
|
|
66
|
-
self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
|
|
67
|
-
) -> RewriteResult:
|
|
30
|
+
def rewrite_Measure(self, measure_stmt: qubit.stmts.Measure) -> RewriteResult:
|
|
68
31
|
|
|
69
32
|
qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt)
|
|
70
33
|
if qubit_idx_ssas is None:
|
|
71
34
|
return RewriteResult()
|
|
72
35
|
|
|
73
|
-
measure_id = self.measure_id_result[measure_stmt.result]
|
|
74
|
-
if not isinstance(measure_id, (MeasureIdBool, MeasureIdTuple)):
|
|
75
|
-
return RewriteResult()
|
|
76
|
-
|
|
77
36
|
prob_noise_stmt = py.constant.Constant(0.0)
|
|
78
37
|
stim_measure_stmt = collapse.MZ(
|
|
79
38
|
p=prob_noise_stmt.result,
|
|
@@ -82,48 +41,21 @@ class SquinMeasureToStim(RewriteRule):
|
|
|
82
41
|
prob_noise_stmt.insert_before(measure_stmt)
|
|
83
42
|
stim_measure_stmt.insert_before(measure_stmt)
|
|
84
43
|
|
|
85
|
-
if not
|
|
44
|
+
# if the measurement is not being used anywhere
|
|
45
|
+
# we can safely get rid of it. Measure cannot be DCE'd because
|
|
46
|
+
# it is not pure.
|
|
47
|
+
if not bool(measure_stmt.result.uses):
|
|
86
48
|
measure_stmt.delete()
|
|
87
|
-
return RewriteResult(has_done_something=True)
|
|
88
|
-
|
|
89
|
-
# replace dataflow with new stmt!
|
|
90
|
-
measure_id = self.measure_id_result[measure_stmt.result]
|
|
91
|
-
if isinstance(measure_id, MeasureIdBool):
|
|
92
|
-
replace_get_record(
|
|
93
|
-
node=measure_stmt,
|
|
94
|
-
measure_id_bool=measure_id,
|
|
95
|
-
meas_count=self.total_measure_count,
|
|
96
|
-
)
|
|
97
|
-
elif isinstance(measure_id, MeasureIdTuple):
|
|
98
|
-
insert_get_record_list(
|
|
99
|
-
node=measure_stmt,
|
|
100
|
-
measure_id_tuple=measure_id,
|
|
101
|
-
meas_count=self.total_measure_count,
|
|
102
|
-
)
|
|
103
|
-
else:
|
|
104
|
-
# already checked before, so this should not happen
|
|
105
|
-
raise ValueError(
|
|
106
|
-
f"Unexpected measure ID type: {type(measure_id)} for measure statement {measure_stmt}"
|
|
107
|
-
)
|
|
108
49
|
|
|
109
50
|
return RewriteResult(has_done_something=True)
|
|
110
51
|
|
|
111
52
|
def get_qubit_idx_ssas(
|
|
112
|
-
self, measure_stmt: qubit.
|
|
53
|
+
self, measure_stmt: qubit.stmts.Measure
|
|
113
54
|
) -> tuple[ir.SSAValue, ...] | None:
|
|
114
55
|
"""
|
|
115
56
|
Extract the address attribute and insert qubit indices for the given measure statement.
|
|
116
57
|
"""
|
|
117
|
-
|
|
118
|
-
case qubit.MeasureQubit():
|
|
119
|
-
address_attr = measure_stmt.qubit.hints.get("address")
|
|
120
|
-
case qubit.MeasureQubitList():
|
|
121
|
-
address_attr = measure_stmt.qubits.hints.get("address")
|
|
122
|
-
case wire.Measure():
|
|
123
|
-
address_attr = measure_stmt.wire.hints.get("address")
|
|
124
|
-
case _:
|
|
125
|
-
return None
|
|
126
|
-
|
|
58
|
+
address_attr = measure_stmt.qubits.hints.get("address")
|
|
127
59
|
if address_attr is None:
|
|
128
60
|
return None
|
|
129
61
|
|
|
@@ -1,17 +1,17 @@
|
|
|
1
|
+
import itertools
|
|
1
2
|
from typing import Tuple
|
|
2
3
|
from dataclasses import dataclass
|
|
3
4
|
|
|
5
|
+
from kirin import types
|
|
4
6
|
from kirin.ir import SSAValue, Statement
|
|
5
|
-
from kirin.dialects import py
|
|
7
|
+
from kirin.dialects import py
|
|
6
8
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
7
9
|
|
|
8
|
-
from bloqade.squin import
|
|
10
|
+
from bloqade.squin import noise as squin_noise
|
|
9
11
|
from bloqade.stim.dialects import noise as stim_noise
|
|
10
|
-
from bloqade.stim.rewrite.util import
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
insert_qubit_idx_after_apply,
|
|
14
|
-
)
|
|
12
|
+
from bloqade.stim.rewrite.util import insert_qubit_idx_from_address
|
|
13
|
+
from bloqade.analysis.address.lattice import AddressReg, PartialIList
|
|
14
|
+
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
@dataclass
|
|
@@ -19,157 +19,183 @@ class SquinNoiseToStim(RewriteRule):
|
|
|
19
19
|
|
|
20
20
|
def rewrite_Statement(self, node: Statement) -> RewriteResult:
|
|
21
21
|
match node:
|
|
22
|
-
case
|
|
23
|
-
return self.
|
|
22
|
+
case squin_noise.stmts.NoiseChannel():
|
|
23
|
+
return self.rewrite_NoiseChannel(node)
|
|
24
24
|
case _:
|
|
25
25
|
return RewriteResult()
|
|
26
26
|
|
|
27
|
-
def
|
|
28
|
-
self, stmt:
|
|
27
|
+
def rewrite_NoiseChannel(
|
|
28
|
+
self, stmt: squin_noise.stmts.NoiseChannel
|
|
29
29
|
) -> RewriteResult:
|
|
30
|
-
"""Rewrite
|
|
30
|
+
"""Rewrite NoiseChannel statements to their stim equivalents."""
|
|
31
31
|
|
|
32
|
-
|
|
33
|
-
applied_op = stmt.operator.owner
|
|
32
|
+
rewrite_method = getattr(self, f"rewrite_{type(stmt).__name__}", None)
|
|
34
33
|
|
|
35
|
-
|
|
34
|
+
# No rewrite method exists and the rewrite should stop
|
|
35
|
+
if rewrite_method is None:
|
|
36
36
|
return RewriteResult()
|
|
37
|
+
if isinstance(stmt, squin_noise.stmts.CorrelatedQubitLoss):
|
|
38
|
+
# CorrelatedQubitLoss represents a broadcast operation, but Stim does not
|
|
39
|
+
# support broadcasting for multi-qubit noise channels.
|
|
40
|
+
# Therefore, we must expand the broadcast into individual stim statements.
|
|
41
|
+
qubit_address_attr = stmt.qubits.hints.get("address", None)
|
|
37
42
|
|
|
38
|
-
|
|
43
|
+
if not isinstance(qubit_address_attr, AddressAttribute):
|
|
44
|
+
return RewriteResult()
|
|
45
|
+
|
|
46
|
+
if not isinstance(address := qubit_address_attr.address, PartialIList):
|
|
47
|
+
return RewriteResult()
|
|
39
48
|
|
|
40
|
-
|
|
41
|
-
if qubit_idx_ssas is None:
|
|
49
|
+
if not types.is_tuple_of(data := address.data, AddressReg):
|
|
42
50
|
return RewriteResult()
|
|
43
51
|
|
|
44
|
-
|
|
45
|
-
stim_stmt = rewrite_method(stmt, qubit_idx_ssas)
|
|
52
|
+
for address_reg in data:
|
|
46
53
|
|
|
47
|
-
|
|
48
|
-
|
|
54
|
+
qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
55
|
+
AddressAttribute(address_reg), stmt
|
|
56
|
+
)
|
|
49
57
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
58
|
+
stim_stmt = rewrite_method(stmt, qubit_idx_ssas)
|
|
59
|
+
stim_stmt.insert_before(stmt)
|
|
60
|
+
|
|
61
|
+
stmt.delete()
|
|
54
62
|
|
|
55
63
|
return RewriteResult(has_done_something=True)
|
|
56
|
-
return RewriteResult()
|
|
57
64
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
65
|
+
if isinstance(stmt, squin_noise.stmts.SingleQubitNoiseChannel):
|
|
66
|
+
qubit_address_attr = stmt.qubits.hints.get("address", None)
|
|
67
|
+
if qubit_address_attr is None:
|
|
68
|
+
return RewriteResult()
|
|
69
|
+
qubit_idx_ssas = insert_qubit_idx_from_address(qubit_address_attr, stmt)
|
|
70
|
+
|
|
71
|
+
elif isinstance(stmt, squin_noise.stmts.TwoQubitNoiseChannel):
|
|
72
|
+
control_address_attr = stmt.controls.hints.get("address", None)
|
|
73
|
+
target_address_attr = stmt.targets.hints.get("address", None)
|
|
74
|
+
if control_address_attr is None or target_address_attr is None:
|
|
75
|
+
return RewriteResult()
|
|
76
|
+
control_qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
77
|
+
control_address_attr, stmt
|
|
78
|
+
)
|
|
79
|
+
target_qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
80
|
+
target_address_attr, stmt
|
|
81
|
+
)
|
|
82
|
+
if control_qubit_idx_ssas is None or target_qubit_idx_ssas is None:
|
|
83
|
+
return RewriteResult()
|
|
84
|
+
|
|
85
|
+
# For stim statements you want to interleave the control and target qubit indices:
|
|
86
|
+
# ex: CX controls = (0,1) targets = (2,3) in stim is: CX 0 2 1 3
|
|
87
|
+
qubit_idx_ssas = list(
|
|
88
|
+
itertools.chain.from_iterable(
|
|
89
|
+
zip(control_qubit_idx_ssas, target_qubit_idx_ssas)
|
|
90
|
+
)
|
|
91
|
+
)
|
|
77
92
|
else:
|
|
78
|
-
|
|
79
|
-
|
|
93
|
+
return RewriteResult()
|
|
94
|
+
|
|
95
|
+
# guaranteed that you have a valid stim_stmt to plug in
|
|
96
|
+
stim_stmt = rewrite_method(stmt, tuple(qubit_idx_ssas))
|
|
97
|
+
stmt.replace_by(stim_stmt)
|
|
98
|
+
|
|
99
|
+
return RewriteResult(has_done_something=True)
|
|
80
100
|
|
|
81
101
|
def rewrite_SingleQubitPauliChannel(
|
|
82
102
|
self,
|
|
83
|
-
stmt:
|
|
103
|
+
stmt: squin_noise.stmts.SingleQubitPauliChannel,
|
|
84
104
|
qubit_idx_ssas: Tuple[SSAValue],
|
|
85
105
|
) -> Statement:
|
|
86
106
|
"""Rewrite squin.noise.SingleQubitPauliChannel to stim.PauliChannel1."""
|
|
87
107
|
|
|
88
|
-
squin_channel = stmt.operator.owner
|
|
89
|
-
assert isinstance(squin_channel, squin_noise.stmts.SingleQubitPauliChannel)
|
|
90
|
-
|
|
91
|
-
params = get_const_value(ilist.IList, squin_channel.params)
|
|
92
|
-
new_stmts = [
|
|
93
|
-
p_x := py.Constant(params[0]),
|
|
94
|
-
p_y := py.Constant(params[1]),
|
|
95
|
-
p_z := py.Constant(params[2]),
|
|
96
|
-
]
|
|
97
|
-
for new_stmt in new_stmts:
|
|
98
|
-
new_stmt.insert_before(stmt)
|
|
99
|
-
|
|
100
108
|
stim_stmt = stim_noise.PauliChannel1(
|
|
101
109
|
targets=qubit_idx_ssas,
|
|
102
|
-
px=
|
|
103
|
-
py=
|
|
104
|
-
pz=
|
|
110
|
+
px=stmt.px,
|
|
111
|
+
py=stmt.py,
|
|
112
|
+
pz=stmt.pz,
|
|
105
113
|
)
|
|
106
114
|
return stim_stmt
|
|
107
115
|
|
|
108
|
-
def
|
|
116
|
+
def rewrite_QubitLoss(
|
|
109
117
|
self,
|
|
110
|
-
stmt:
|
|
118
|
+
stmt: squin_noise.stmts.QubitLoss,
|
|
111
119
|
qubit_idx_ssas: Tuple[SSAValue],
|
|
112
120
|
) -> Statement:
|
|
113
|
-
"""Rewrite squin.noise.
|
|
121
|
+
"""Rewrite squin.noise.QubitLoss to stim.TrivialError."""
|
|
114
122
|
|
|
115
|
-
|
|
116
|
-
|
|
123
|
+
stim_stmt = stim_noise.QubitLoss(
|
|
124
|
+
targets=qubit_idx_ssas,
|
|
125
|
+
probs=(stmt.p,),
|
|
126
|
+
)
|
|
117
127
|
|
|
118
|
-
|
|
119
|
-
param_stmts = [py.Constant(p) for p in params]
|
|
120
|
-
for param_stmt in param_stmts:
|
|
121
|
-
param_stmt.insert_before(stmt)
|
|
128
|
+
return stim_stmt
|
|
122
129
|
|
|
123
|
-
|
|
130
|
+
def rewrite_CorrelatedQubitLoss(
|
|
131
|
+
self,
|
|
132
|
+
stmt: squin_noise.stmts.CorrelatedQubitLoss,
|
|
133
|
+
qubit_idx_ssas: Tuple[SSAValue],
|
|
134
|
+
) -> Statement:
|
|
135
|
+
"""Rewrite squin.noise.CorrelatedQubitLoss to stim.CorrelatedQubitLoss."""
|
|
136
|
+
stim_stmt = stim_noise.CorrelatedQubitLoss(
|
|
124
137
|
targets=qubit_idx_ssas,
|
|
125
|
-
|
|
126
|
-
piy=param_stmts[1].result,
|
|
127
|
-
piz=param_stmts[2].result,
|
|
128
|
-
pxi=param_stmts[3].result,
|
|
129
|
-
pxx=param_stmts[4].result,
|
|
130
|
-
pxy=param_stmts[5].result,
|
|
131
|
-
pxz=param_stmts[6].result,
|
|
132
|
-
pyi=param_stmts[7].result,
|
|
133
|
-
pyx=param_stmts[8].result,
|
|
134
|
-
pyy=param_stmts[9].result,
|
|
135
|
-
pyz=param_stmts[10].result,
|
|
136
|
-
pzi=param_stmts[11].result,
|
|
137
|
-
pzx=param_stmts[12].result,
|
|
138
|
-
pzy=param_stmts[13].result,
|
|
139
|
-
pzz=param_stmts[14].result,
|
|
138
|
+
probs=(stmt.p,),
|
|
140
139
|
)
|
|
140
|
+
|
|
141
141
|
return stim_stmt
|
|
142
142
|
|
|
143
|
-
def
|
|
143
|
+
def rewrite_Depolarize(
|
|
144
144
|
self,
|
|
145
|
-
stmt:
|
|
145
|
+
stmt: squin_noise.stmts.Depolarize,
|
|
146
146
|
qubit_idx_ssas: Tuple[SSAValue],
|
|
147
147
|
) -> Statement:
|
|
148
|
-
"""Rewrite squin.noise.
|
|
149
|
-
|
|
150
|
-
squin_channel = stmt.operator.owner
|
|
151
|
-
assert isinstance(squin_channel, squin_noise.stmts.Depolarize2)
|
|
148
|
+
"""Rewrite squin.noise.Depolarize to stim.Depolarize1."""
|
|
152
149
|
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
150
|
+
stim_stmt = stim_noise.Depolarize1(
|
|
151
|
+
targets=qubit_idx_ssas,
|
|
152
|
+
p=stmt.p,
|
|
153
|
+
)
|
|
156
154
|
|
|
157
|
-
stim_stmt = stim_noise.Depolarize2(targets=qubit_idx_ssas, p=p_stmt.result)
|
|
158
155
|
return stim_stmt
|
|
159
156
|
|
|
160
|
-
def
|
|
157
|
+
def rewrite_TwoQubitPauliChannel(
|
|
161
158
|
self,
|
|
162
|
-
stmt:
|
|
159
|
+
stmt: squin_noise.stmts.TwoQubitPauliChannel,
|
|
163
160
|
qubit_idx_ssas: Tuple[SSAValue],
|
|
164
161
|
) -> Statement:
|
|
165
|
-
"""Rewrite squin.noise.
|
|
162
|
+
"""Rewrite squin.noise.TwoQubitPauliChannel to stim.PauliChannel2."""
|
|
163
|
+
|
|
164
|
+
params = stmt.probabilities
|
|
165
|
+
prob_ssas = []
|
|
166
|
+
for idx in range(15):
|
|
167
|
+
idx_stmt = py.Constant(value=idx)
|
|
168
|
+
idx_stmt.insert_before(stmt)
|
|
169
|
+
getitem_stmt = py.GetItem(obj=params, index=idx_stmt.result)
|
|
170
|
+
getitem_stmt.insert_before(stmt)
|
|
171
|
+
prob_ssas.append(getitem_stmt.result)
|
|
166
172
|
|
|
167
|
-
|
|
168
|
-
|
|
173
|
+
stim_stmt = stim_noise.PauliChannel2(
|
|
174
|
+
targets=qubit_idx_ssas,
|
|
175
|
+
pix=prob_ssas[0],
|
|
176
|
+
piy=prob_ssas[1],
|
|
177
|
+
piz=prob_ssas[2],
|
|
178
|
+
pxi=prob_ssas[3],
|
|
179
|
+
pxx=prob_ssas[4],
|
|
180
|
+
pxy=prob_ssas[5],
|
|
181
|
+
pxz=prob_ssas[6],
|
|
182
|
+
pyi=prob_ssas[7],
|
|
183
|
+
pyx=prob_ssas[8],
|
|
184
|
+
pyy=prob_ssas[9],
|
|
185
|
+
pyz=prob_ssas[10],
|
|
186
|
+
pzi=prob_ssas[11],
|
|
187
|
+
pzx=prob_ssas[12],
|
|
188
|
+
pzy=prob_ssas[13],
|
|
189
|
+
pzz=prob_ssas[14],
|
|
190
|
+
)
|
|
191
|
+
return stim_stmt
|
|
169
192
|
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
193
|
+
def rewrite_Depolarize2(
|
|
194
|
+
self,
|
|
195
|
+
stmt: squin_noise.stmts.Depolarize2,
|
|
196
|
+
qubit_idx_ssas: Tuple[SSAValue],
|
|
197
|
+
) -> Statement:
|
|
198
|
+
"""Rewrite squin.noise.Depolarize2 to stim.Depolarize2."""
|
|
173
199
|
|
|
174
|
-
stim_stmt = stim_noise.
|
|
200
|
+
stim_stmt = stim_noise.Depolarize2(targets=qubit_idx_ssas, p=stmt.p)
|
|
175
201
|
return stim_stmt
|