bloqade-circuit 0.4.5__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic. Click here for more details.
- bloqade/analysis/address/impls.py +21 -68
- bloqade/analysis/measure_id/__init__.py +2 -0
- bloqade/analysis/measure_id/analysis.py +45 -0
- bloqade/analysis/measure_id/impls.py +155 -0
- bloqade/analysis/measure_id/lattice.py +82 -0
- bloqade/cirq_utils/__init__.py +7 -0
- bloqade/cirq_utils/lineprog.py +295 -0
- bloqade/cirq_utils/parallelize.py +400 -0
- bloqade/pyqrack/squin/op.py +7 -2
- bloqade/pyqrack/squin/runtime.py +4 -2
- bloqade/qasm2/dialects/expr/stmts.py +2 -20
- bloqade/qasm2/parse/lowering.py +1 -0
- bloqade/qasm2/passes/parallel.py +18 -0
- bloqade/qasm2/passes/unroll_if.py +9 -2
- bloqade/qasm2/rewrite/__init__.py +1 -0
- bloqade/qasm2/rewrite/parallel_to_glob.py +82 -0
- bloqade/rewrite/__init__.py +0 -0
- bloqade/rewrite/passes/__init__.py +1 -0
- bloqade/rewrite/passes/canonicalize_ilist.py +28 -0
- bloqade/rewrite/rules/__init__.py +1 -0
- bloqade/rewrite/rules/flatten_ilist.py +51 -0
- bloqade/rewrite/rules/inline_getitem_ilist.py +31 -0
- bloqade/{qasm2/rewrite → rewrite/rules}/split_ifs.py +15 -8
- bloqade/squin/__init__.py +2 -0
- bloqade/squin/_typeinfer.py +20 -0
- bloqade/squin/analysis/__init__.py +1 -0
- bloqade/squin/analysis/address_impl.py +71 -0
- bloqade/squin/analysis/nsites/impls.py +6 -1
- bloqade/squin/cirq/lowering.py +19 -6
- bloqade/squin/noise/stmts.py +1 -1
- bloqade/squin/op/__init__.py +1 -0
- bloqade/squin/op/_wrapper.py +4 -0
- bloqade/squin/op/stmts.py +20 -2
- bloqade/squin/qubit.py +8 -5
- bloqade/squin/rewrite/__init__.py +1 -0
- bloqade/squin/rewrite/canonicalize.py +60 -0
- bloqade/squin/rewrite/desugar.py +52 -5
- bloqade/squin/types.py +8 -0
- bloqade/squin/wire.py +91 -5
- bloqade/stim/__init__.py +1 -0
- bloqade/stim/_wrappers.py +4 -0
- bloqade/stim/dialects/auxiliary/interp.py +0 -10
- bloqade/stim/dialects/auxiliary/stmts/annotate.py +1 -1
- bloqade/stim/dialects/noise/emit.py +1 -0
- bloqade/stim/dialects/noise/stmts.py +5 -0
- bloqade/stim/passes/__init__.py +1 -1
- bloqade/stim/passes/simplify_ifs.py +32 -0
- bloqade/stim/passes/squin_to_stim.py +109 -26
- bloqade/stim/rewrite/__init__.py +1 -0
- bloqade/stim/rewrite/ifs_to_stim.py +203 -0
- bloqade/stim/rewrite/qubit_to_stim.py +13 -6
- bloqade/stim/rewrite/squin_measure.py +68 -5
- bloqade/stim/rewrite/squin_noise.py +120 -0
- bloqade/stim/rewrite/util.py +40 -9
- bloqade/stim/rewrite/wire_to_stim.py +8 -3
- bloqade/stim/upstream/__init__.py +1 -0
- bloqade/stim/upstream/from_squin.py +10 -0
- {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/METADATA +4 -2
- {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/RECORD +61 -38
- {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
from dataclasses import field, dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.dialects import py, scf, func
|
|
5
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
6
|
+
|
|
7
|
+
from bloqade.squin import op, qubit
|
|
8
|
+
from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts
|
|
9
|
+
from bloqade.squin.rewrite import AddressAttribute
|
|
10
|
+
from bloqade.stim.rewrite.util import (
|
|
11
|
+
SQUIN_STIM_CONTROL_GATE_MAPPING,
|
|
12
|
+
insert_qubit_idx_from_address,
|
|
13
|
+
)
|
|
14
|
+
from bloqade.stim.dialects.auxiliary import GetRecord
|
|
15
|
+
from bloqade.analysis.measure_id.lattice import (
|
|
16
|
+
MeasureId,
|
|
17
|
+
MeasureIdBool,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class IfElseSimplification:
    """Shared eligibility checks for rewrite rules that target ``scf.IfElse``.

    Rules mixing this in only act on "flat" IfElse statements: ones that do
    not contain another IfElse, are not themselves inside another IfElse, and
    whose else-branch does no work.
    """

    # Might be better to just do a rewrite_Region?
    def is_rewriteable(self, node: scf.IfElse) -> bool:
        """Return True when ``node`` is a flat IfElse with an empty else branch."""
        return not (
            self.contains_ifelse(node)
            or self.is_nested_ifelse(node)
            or self.has_else_body(node)
        )

    # A preliminary check to reject an IfElse from the "top down";
    # use in conjunction with is_nested_ifelse
    # to completely cover cases of nested IfElse statements.
    def contains_ifelse(self, stmt: scf.IfElse) -> bool:
        """Check if the IfElse statement contains another IfElse statement."""
        return any(
            isinstance(child, scf.IfElse) for child in stmt.walk(include_self=False)
        )

    # Because a rewrite latches onto ANY scf.IfElse, we need a way to determine
    # whether the IfElse being visited sits inside another IfElse.
    def is_nested_ifelse(self, stmt: scf.IfElse) -> bool:
        """Check if the IfElse statement is nested inside another IfElse.

        Every enclosing statement is inspected, not only the parent and the
        grandparent. An IfElse buried under several levels of other
        region-holding statements (e.g. loops inside loops) is therefore
        still detected, mirroring the full-depth walk in ``contains_ifelse``.
        """
        ancestor = stmt.parent_stmt
        while ancestor is not None:
            if isinstance(ancestor, scf.IfElse):
                return True
            ancestor = ancestor.parent_stmt
        return False

    def has_else_body(self, stmt: scf.IfElse) -> bool:
        """Check if the IfElse statement has a non-trivial else body.

        An else region with no blocks, or whose first block consists of a
        single ``scf.Yield`` carrying no values, counts as empty.
        """
        if not stmt.else_body.blocks:
            return False

        else_block = stmt.else_body.blocks[0]
        only_empty_yield = (
            len(else_block.stmts) == 1
            and isinstance(else_term := else_block.last_stmt, scf.Yield)
            and not else_term.values  # empty yield
        )
        return not only_empty_yield
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Default ``exclude_stmts`` for StimLiftThenBody. As the name says, these are
# the statement kinds that must not be lifted out of an IfElse then-body: the
# quantum operations being conditioned, region terminators, invocations and
# nested IfElse statements.
DontLiftType: tuple[type[ir.Statement], ...] = (
    # quantum operations that the condition actually guards
    qubit.Apply,
    qubit.Broadcast,
    # terminators
    scf.Yield,
    func.Return,
    # calls and nested control flow
    func.Invoke,
    scf.IfElse,
)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@dataclass
class StimLiftThenBody(IfElseSimplification, LiftThenBody):
    """``LiftThenBody`` restricted to IfElse statements suitable for Stim.

    Only IfElse statements accepted by ``is_rewriteable`` are handed to the
    base rule. Statement types in ``exclude_stmts`` (``DontLiftType`` by
    default) are passed to the base rule as the kinds to keep in place.
    """

    exclude_stmts: tuple[type[ir.Statement], ...] = field(default=DontLiftType)

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        eligible = isinstance(node, scf.IfElse) and self.is_rewriteable(node)
        if eligible:
            return super().rewrite_Statement(node)
        return RewriteResult()
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# Only run this after everything other than qubit.Apply/qubit.Broadcast has been
# lifted out!
class StimSplitIfStmts(IfElseSimplification, SplitIfStmts):
    """Split the then-body of an IfElse into one IfElse per statement.

    Given an IfElse with several valid statements in its then-body:

        if measure_result:
            squin.qubit.apply(op.X, q0)
            squin.qubit.apply(op.Y, q1)

    the result should be:

        if measure_result:
            squin.qubit.apply(op.X, q0)

        if measure_result:
            squin.qubit.apply(op.Y, q1)

    Only IfElse statements accepted by ``is_rewriteable`` are forwarded to
    ``SplitIfStmts``; everything else is left untouched.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if isinstance(node, scf.IfElse) and self.is_rewriteable(node):
            return super().rewrite_Statement(node)
        return RewriteResult()
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@dataclass
|
|
125
|
+
class IfToStim(IfElseSimplification, RewriteRule):
|
|
126
|
+
"""
|
|
127
|
+
Rewrite if statements to stim equivalent statements.
|
|
128
|
+
"""
|
|
129
|
+
|
|
130
|
+
measure_analysis: dict[ir.SSAValue, MeasureId]
|
|
131
|
+
measure_count: int
|
|
132
|
+
|
|
133
|
+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
134
|
+
|
|
135
|
+
match node:
|
|
136
|
+
case scf.IfElse():
|
|
137
|
+
return self.rewrite_IfElse(node)
|
|
138
|
+
case _:
|
|
139
|
+
return RewriteResult()
|
|
140
|
+
|
|
141
|
+
def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
|
|
142
|
+
|
|
143
|
+
if not isinstance(self.measure_analysis[stmt.cond], MeasureIdBool):
|
|
144
|
+
return RewriteResult()
|
|
145
|
+
|
|
146
|
+
# check that there is only qubit.Apply in the then-body,
|
|
147
|
+
# if there's more than that, we can't do a valid rewrite.
|
|
148
|
+
# Can reuse logic from SplitIf
|
|
149
|
+
*stmts, _ = stmt.then_body.stmts()
|
|
150
|
+
if len(stmts) != 1 or not isinstance(stmts[0], (qubit.Apply, qubit.Broadcast)):
|
|
151
|
+
return RewriteResult()
|
|
152
|
+
|
|
153
|
+
apply_or_broadcast = stmts[0]
|
|
154
|
+
# Check that the gate being applied/broadcasted can be converted to a stim
|
|
155
|
+
# controlled gate.
|
|
156
|
+
ctrl_op_target_gate = apply_or_broadcast.operator.owner
|
|
157
|
+
assert isinstance(ctrl_op_target_gate, op.stmts.Operator)
|
|
158
|
+
|
|
159
|
+
stim_gate = SQUIN_STIM_CONTROL_GATE_MAPPING.get(type(ctrl_op_target_gate))
|
|
160
|
+
if stim_gate is None:
|
|
161
|
+
return RewriteResult()
|
|
162
|
+
|
|
163
|
+
# get necessary measurement ID type from analysis
|
|
164
|
+
measure_id_bool = self.measure_analysis[stmt.cond]
|
|
165
|
+
assert isinstance(measure_id_bool, MeasureIdBool)
|
|
166
|
+
|
|
167
|
+
# generate get record statement
|
|
168
|
+
measure_id_idx_stmt = py.Constant(
|
|
169
|
+
(measure_id_bool.idx - 1) - self.measure_count
|
|
170
|
+
)
|
|
171
|
+
get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841
|
|
172
|
+
|
|
173
|
+
# get address attribute and generate qubit idx statements
|
|
174
|
+
address_attr = apply_or_broadcast.qubits.hints.get("address")
|
|
175
|
+
if address_attr is None:
|
|
176
|
+
return RewriteResult()
|
|
177
|
+
assert isinstance(address_attr, AddressAttribute)
|
|
178
|
+
|
|
179
|
+
# note: insert things before (literally above/outside) the If
|
|
180
|
+
qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
181
|
+
address=address_attr, stmt_to_insert_before=stmt
|
|
182
|
+
)
|
|
183
|
+
if qubit_idx_ssas is None:
|
|
184
|
+
return RewriteResult()
|
|
185
|
+
|
|
186
|
+
# Assemble the stim statement
|
|
187
|
+
# let GetRecord's SSA be repeated per each get qubit
|
|
188
|
+
ctrl_records = tuple(get_record_stmt.result for _ in qubit_idx_ssas)
|
|
189
|
+
|
|
190
|
+
stim_stmt = stim_gate(
|
|
191
|
+
targets=tuple(qubit_idx_ssas),
|
|
192
|
+
controls=ctrl_records,
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
# Insert the necessary SSA Values, then get rid of the scf.IfElse.
|
|
196
|
+
# The qubit indices have been successfully added,
|
|
197
|
+
# that just leaves the GetRecord statement and measurement ID index statement
|
|
198
|
+
|
|
199
|
+
measure_id_idx_stmt.insert_before(stmt)
|
|
200
|
+
get_record_stmt.insert_before(stmt)
|
|
201
|
+
stmt.replace_by(stim_stmt)
|
|
202
|
+
|
|
203
|
+
return RewriteResult(has_done_something=True)
|
|
@@ -1,17 +1,21 @@
|
|
|
1
1
|
from kirin import ir
|
|
2
2
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
3
3
|
|
|
4
|
-
from bloqade.squin import op, qubit
|
|
4
|
+
from bloqade.squin import op, noise, qubit
|
|
5
5
|
from bloqade.squin.rewrite import AddressAttribute
|
|
6
6
|
from bloqade.stim.dialects import gate
|
|
7
7
|
from bloqade.stim.rewrite.util import (
|
|
8
|
-
|
|
8
|
+
SQUIN_STIM_OP_MAPPING,
|
|
9
9
|
rewrite_Control,
|
|
10
|
+
rewrite_QubitLoss,
|
|
10
11
|
insert_qubit_idx_from_address,
|
|
11
12
|
)
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
class SquinQubitToStim(RewriteRule):
|
|
16
|
+
"""
|
|
17
|
+
NOTE this require address analysis result to be wrapped before using this rule.
|
|
18
|
+
"""
|
|
15
19
|
|
|
16
20
|
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
17
21
|
|
|
@@ -30,11 +34,17 @@ class SquinQubitToStim(RewriteRule):
|
|
|
30
34
|
|
|
31
35
|
# this is an SSAValue, need it to be the actual operator
|
|
32
36
|
applied_op = stmt.operator.owner
|
|
37
|
+
|
|
38
|
+
if isinstance(applied_op, noise.stmts.QubitLoss):
|
|
39
|
+
return rewrite_QubitLoss(stmt)
|
|
40
|
+
|
|
33
41
|
assert isinstance(applied_op, op.stmts.Operator)
|
|
34
42
|
|
|
35
43
|
if isinstance(applied_op, op.stmts.Control):
|
|
36
44
|
return rewrite_Control(stmt)
|
|
37
45
|
|
|
46
|
+
# need to handle Control through separate means
|
|
47
|
+
|
|
38
48
|
# check if its adjoint, assume its canonicalized so no nested adjoints.
|
|
39
49
|
is_conj = False
|
|
40
50
|
if isinstance(applied_op, op.stmts.Adjoint):
|
|
@@ -44,9 +54,7 @@ class SquinQubitToStim(RewriteRule):
|
|
|
44
54
|
is_conj = True
|
|
45
55
|
applied_op = applied_op.op.owner
|
|
46
56
|
|
|
47
|
-
|
|
48
|
-
# but we can handle X, Y, Z, H, and S here just fine
|
|
49
|
-
stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op))
|
|
57
|
+
stim_1q_op = SQUIN_STIM_OP_MAPPING.get(type(applied_op))
|
|
50
58
|
if stim_1q_op is None:
|
|
51
59
|
return RewriteResult()
|
|
52
60
|
|
|
@@ -55,7 +63,6 @@ class SquinQubitToStim(RewriteRule):
|
|
|
55
63
|
if address_attr is None:
|
|
56
64
|
return RewriteResult()
|
|
57
65
|
|
|
58
|
-
# sometimes you can get a whole AddressReg...
|
|
59
66
|
assert isinstance(address_attr, AddressAttribute)
|
|
60
67
|
qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
61
68
|
address=address_attr, stmt_to_insert_before=stmt
|
|
@@ -1,22 +1,59 @@
|
|
|
1
1
|
# create rewrite rule name SquinMeasureToStim using kirin
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
2
4
|
from kirin import ir
|
|
3
|
-
from kirin.dialects import py
|
|
5
|
+
from kirin.dialects import py, ilist
|
|
4
6
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
5
7
|
|
|
6
8
|
from bloqade.squin import wire, qubit
|
|
7
9
|
from bloqade.squin.rewrite import AddressAttribute
|
|
8
|
-
from bloqade.stim.dialects import collapse
|
|
10
|
+
from bloqade.stim.dialects import collapse, auxiliary
|
|
9
11
|
from bloqade.stim.rewrite.util import (
|
|
10
12
|
is_measure_result_used,
|
|
11
13
|
insert_qubit_idx_from_address,
|
|
12
14
|
)
|
|
15
|
+
from bloqade.analysis.measure_id.lattice import MeasureId, MeasureIdBool, MeasureIdTuple
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def replace_get_record(
    node: ir.Statement, measure_id_bool: MeasureIdBool, meas_count: int
):
    """Swap ``node`` for a Stim ``GetRecord`` of a single measurement.

    A ``py.constant.Constant`` holding the record offset
    ``(measure_id_bool.idx - 1) - meas_count`` is inserted just before
    ``node``. ``node`` is then replaced by an ``auxiliary.GetRecord`` that
    reads that offset.
    """
    assert isinstance(measure_id_bool, MeasureIdBool)
    offset = py.constant.Constant((measure_id_bool.idx - 1) - meas_count)
    offset.insert_before(node)
    node.replace_by(auxiliary.GetRecord(offset.result))
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def insert_get_record_list(
    node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count: int
):
    """
    Insert GetRecord statements before ``node``, then replace ``node`` with their list.

    For every ``MeasureIdBool`` in ``measure_id_tuple.data``, two statements
    are inserted before ``node``, in this order:

    - a ``py.constant.Constant`` with the record offset ``(idx - 1) - meas_count``;
    - an ``auxiliary.GetRecord`` reading that offset.

    Finally ``node`` is replaced by an ``ilist.New`` that collects all the
    GetRecord results in order.
    """
    record_results = []
    for entry in measure_id_tuple.data:
        assert isinstance(entry, MeasureIdBool)
        offset = py.constant.Constant((entry.idx - 1) - meas_count)
        offset.insert_before(node)

        record = auxiliary.GetRecord(offset.result)
        record.insert_before(node)
        record_results.append(record.result)

    node.replace_by(ilist.New(values=record_results))
|
|
14
46
|
|
|
47
|
+
|
|
48
|
+
@dataclass
|
|
15
49
|
class SquinMeasureToStim(RewriteRule):
|
|
16
50
|
"""
|
|
17
51
|
Rewrite squin measure-related statements to stim statements.
|
|
18
52
|
"""
|
|
19
53
|
|
|
54
|
+
measure_id_result: dict[ir.SSAValue, MeasureId]
|
|
55
|
+
total_measure_count: int
|
|
56
|
+
|
|
20
57
|
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
21
58
|
|
|
22
59
|
match node:
|
|
@@ -28,20 +65,46 @@ class SquinMeasureToStim(RewriteRule):
|
|
|
28
65
|
def rewrite_Measure(
    self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
) -> RewriteResult:
    """Rewrite a squin measurement into a Stim ``MZ`` (with p = 0.0).

    The ``MZ`` is inserted before ``measure_stmt``. If the measurement result
    is unused, ``measure_stmt`` is deleted. Otherwise it is replaced by a
    ``GetRecord`` (single measurement) or an ilist of ``GetRecord``s (list
    measurement), so downstream dataflow keeps working.
    """
    # Validate the measurement-ID analysis result before touching the IR, so
    # an unsupported or missing entry leaves the program unchanged rather than
    # leaving stray qubit-index statements behind.
    measure_id = self.measure_id_result.get(measure_stmt.result)
    if not isinstance(measure_id, (MeasureIdBool, MeasureIdTuple)):
        return RewriteResult()

    qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt)
    if qubit_idx_ssas is None:
        return RewriteResult()

    prob_noise_stmt = py.constant.Constant(0.0)
    stim_measure_stmt = collapse.MZ(
        p=prob_noise_stmt.result,
        targets=qubit_idx_ssas,
    )
    prob_noise_stmt.insert_before(measure_stmt)
    stim_measure_stmt.insert_before(measure_stmt)

    if not is_measure_result_used(measure_stmt):
        measure_stmt.delete()
        return RewriteResult(has_done_something=True)

    # replace dataflow with new stmt!
    if isinstance(measure_id, MeasureIdBool):
        replace_get_record(
            node=measure_stmt,
            measure_id_bool=measure_id,
            meas_count=self.total_measure_count,
        )
    else:
        # guaranteed to be a MeasureIdTuple by the check at the top
        insert_get_record_list(
            node=measure_stmt,
            measure_id_tuple=measure_id,
            meas_count=self.total_measure_count,
        )

    return RewriteResult(has_done_something=True)
|
|
47
110
|
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
from typing import Dict, Tuple
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
from kirin.ir import SSAValue, Statement
|
|
5
|
+
from kirin.analysis import const
|
|
6
|
+
from kirin.dialects import py
|
|
7
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
8
|
+
|
|
9
|
+
from bloqade.squin import wire, noise as squin_noise, qubit
|
|
10
|
+
from bloqade.stim.dialects import noise as stim_noise
|
|
11
|
+
from bloqade.stim.rewrite.util import (
|
|
12
|
+
create_wire_passthrough,
|
|
13
|
+
insert_qubit_idx_after_apply,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
|
|
18
|
+
class SquinNoiseToStim(RewriteRule):
|
|
19
|
+
|
|
20
|
+
cp_results: Dict[SSAValue, const.Result]
|
|
21
|
+
|
|
22
|
+
def rewrite_Statement(self, node: Statement) -> RewriteResult:
|
|
23
|
+
match node:
|
|
24
|
+
case qubit.Apply() | qubit.Broadcast():
|
|
25
|
+
return self.rewrite_Apply_and_Broadcast(node)
|
|
26
|
+
case _:
|
|
27
|
+
return RewriteResult()
|
|
28
|
+
|
|
29
|
+
def rewrite_Apply_and_Broadcast(
|
|
30
|
+
self, stmt: qubit.Apply | qubit.Broadcast
|
|
31
|
+
) -> RewriteResult:
|
|
32
|
+
"""Rewrite Apply and Broadcast to their stim statements."""
|
|
33
|
+
|
|
34
|
+
# this is an SSAValue, need it to be the actual operator
|
|
35
|
+
applied_op = stmt.operator.owner
|
|
36
|
+
|
|
37
|
+
if isinstance(applied_op, squin_noise.stmts.NoiseChannel):
|
|
38
|
+
|
|
39
|
+
qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt)
|
|
40
|
+
if qubit_idx_ssas is None:
|
|
41
|
+
return RewriteResult()
|
|
42
|
+
|
|
43
|
+
stim_stmt = None
|
|
44
|
+
if isinstance(applied_op, squin_noise.stmts.SingleQubitPauliChannel):
|
|
45
|
+
stim_stmt = self.rewrite_SingleQubitPauliChannel(stmt, qubit_idx_ssas)
|
|
46
|
+
elif isinstance(applied_op, squin_noise.stmts.TwoQubitPauliChannel):
|
|
47
|
+
stim_stmt = self.rewrite_TwoQubitPauliChannel(stmt, qubit_idx_ssas)
|
|
48
|
+
|
|
49
|
+
if isinstance(stmt, (wire.Apply, wire.Broadcast)):
|
|
50
|
+
create_wire_passthrough(stmt)
|
|
51
|
+
|
|
52
|
+
if stim_stmt is not None:
|
|
53
|
+
stmt.replace_by(stim_stmt)
|
|
54
|
+
if len(stmt.operator.owner.result.uses) == 0:
|
|
55
|
+
stmt.operator.owner.delete()
|
|
56
|
+
|
|
57
|
+
return RewriteResult(has_done_something=True)
|
|
58
|
+
return RewriteResult()
|
|
59
|
+
|
|
60
|
+
def rewrite_SingleQubitPauliChannel(
|
|
61
|
+
self,
|
|
62
|
+
stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
|
|
63
|
+
qubit_idx_ssas: Tuple[SSAValue],
|
|
64
|
+
) -> Statement:
|
|
65
|
+
"""Rewrite squin.noise.SingleQubitPauliChannel to stim.PauliChannel1."""
|
|
66
|
+
|
|
67
|
+
squin_channel = stmt.operator.owner
|
|
68
|
+
assert isinstance(squin_channel, squin_noise.stmts.SingleQubitPauliChannel)
|
|
69
|
+
|
|
70
|
+
params = self.cp_results.get(squin_channel.params).data
|
|
71
|
+
new_stmts = [
|
|
72
|
+
p_x := py.Constant(params[0]),
|
|
73
|
+
p_y := py.Constant(params[1]),
|
|
74
|
+
p_z := py.Constant(params[2]),
|
|
75
|
+
]
|
|
76
|
+
for new_stmt in new_stmts:
|
|
77
|
+
new_stmt.insert_before(stmt)
|
|
78
|
+
|
|
79
|
+
stim_stmt = stim_noise.PauliChannel1(
|
|
80
|
+
targets=qubit_idx_ssas,
|
|
81
|
+
px=p_x.result,
|
|
82
|
+
py=p_y.result,
|
|
83
|
+
pz=p_z.result,
|
|
84
|
+
)
|
|
85
|
+
return stim_stmt
|
|
86
|
+
|
|
87
|
+
def rewrite_TwoQubitPauliChannel(
|
|
88
|
+
self,
|
|
89
|
+
stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
|
|
90
|
+
qubit_idx_ssas: Tuple[SSAValue],
|
|
91
|
+
) -> Statement:
|
|
92
|
+
"""Rewrite squin.noise.SingleQubitPauliChannel to stim.PauliChannel1."""
|
|
93
|
+
|
|
94
|
+
squin_channel = stmt.operator.owner
|
|
95
|
+
assert isinstance(squin_channel, squin_noise.stmts.TwoQubitPauliChannel)
|
|
96
|
+
|
|
97
|
+
params = self.cp_results.get(squin_channel.params).data
|
|
98
|
+
param_stmts = [py.Constant(p) for p in params]
|
|
99
|
+
for param_stmt in param_stmts:
|
|
100
|
+
param_stmt.insert_before(stmt)
|
|
101
|
+
|
|
102
|
+
stim_stmt = stim_noise.PauliChannel2(
|
|
103
|
+
targets=qubit_idx_ssas,
|
|
104
|
+
pix=param_stmts[0].result,
|
|
105
|
+
piy=param_stmts[1].result,
|
|
106
|
+
piz=param_stmts[2].result,
|
|
107
|
+
pxi=param_stmts[3].result,
|
|
108
|
+
pxx=param_stmts[4].result,
|
|
109
|
+
pxy=param_stmts[5].result,
|
|
110
|
+
pxz=param_stmts[6].result,
|
|
111
|
+
pyi=param_stmts[7].result,
|
|
112
|
+
pyx=param_stmts[8].result,
|
|
113
|
+
pyy=param_stmts[9].result,
|
|
114
|
+
pyz=param_stmts[10].result,
|
|
115
|
+
pzi=param_stmts[11].result,
|
|
116
|
+
pzx=param_stmts[12].result,
|
|
117
|
+
pzy=param_stmts[13].result,
|
|
118
|
+
pzz=param_stmts[14].result,
|
|
119
|
+
)
|
|
120
|
+
return stim_stmt
|
bloqade/stim/rewrite/util.py
CHANGED
|
@@ -2,12 +2,12 @@ from kirin import ir
|
|
|
2
2
|
from kirin.dialects import py
|
|
3
3
|
from kirin.rewrite.abc import RewriteResult
|
|
4
4
|
|
|
5
|
-
from bloqade.squin import op, wire, qubit
|
|
5
|
+
from bloqade.squin import op, wire, noise as squin_noise, qubit
|
|
6
6
|
from bloqade.squin.rewrite import AddressAttribute
|
|
7
|
-
from bloqade.stim.dialects import gate, collapse
|
|
7
|
+
from bloqade.stim.dialects import gate, noise as stim_noise, collapse
|
|
8
8
|
from bloqade.analysis.address import AddressReg, AddressWire, AddressQubit, AddressTuple
|
|
9
9
|
|
|
10
|
-
|
|
10
|
+
SQUIN_STIM_OP_MAPPING = {
|
|
11
11
|
op.stmts.X: gate.X,
|
|
12
12
|
op.stmts.Y: gate.Y,
|
|
13
13
|
op.stmts.Z: gate.Z,
|
|
@@ -17,6 +17,7 @@ SQUIN_STIM_GATE_MAPPING = {
|
|
|
17
17
|
op.stmts.SqrtY: gate.SqrtY,
|
|
18
18
|
op.stmts.Identity: gate.Identity,
|
|
19
19
|
op.stmts.Reset: collapse.RZ,
|
|
20
|
+
squin_noise.stmts.QubitLoss: stim_noise.QubitLoss,
|
|
20
21
|
}
|
|
21
22
|
|
|
22
23
|
# Squin allows creation of control gates where the gate can be any operator,
|
|
@@ -151,18 +152,48 @@ def rewrite_Control(
|
|
|
151
152
|
stim_stmt = stim_gate(controls=ctrl_qubits, targets=target_qubits)
|
|
152
153
|
|
|
153
154
|
if isinstance(stmt_with_ctrl, (wire.Apply, wire.Broadcast)):
|
|
154
|
-
|
|
155
|
-
# to subsequent statements, remove dependency on the current statement
|
|
156
|
-
for input_wire, output_wire in zip(
|
|
157
|
-
stmt_with_ctrl.inputs, stmt_with_ctrl.results
|
|
158
|
-
):
|
|
159
|
-
output_wire.replace_by(input_wire)
|
|
155
|
+
create_wire_passthrough(stmt_with_ctrl)
|
|
160
156
|
|
|
161
157
|
stmt_with_ctrl.replace_by(stim_stmt)
|
|
162
158
|
|
|
163
159
|
return RewriteResult(has_done_something=True)
|
|
164
160
|
|
|
165
161
|
|
|
162
|
+
def rewrite_QubitLoss(
    stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
) -> RewriteResult:
    """
    Rewrite an application of squin's QubitLoss into ``stim_noise.QubitLoss``.

    The squin operator's loss probability SSA value is forwarded as the single
    entry of ``probs``. For wire statements, the wires are routed straight
    through before ``stmt`` is replaced.

    NOTE(review): an earlier docstring called the target "Stim's TrivialError".
    Presumably that is how stim_noise.QubitLoss is emitted; confirm.
    """
    loss_op = stmt.operator.owner
    assert isinstance(loss_op, squin_noise.stmts.QubitLoss)

    targets = insert_qubit_idx_after_apply(stmt=stmt)
    if targets is None:
        return RewriteResult()

    replacement = stim_noise.QubitLoss(targets=targets, probs=(loss_op.p,))

    if isinstance(stmt, (wire.Apply, wire.Broadcast)):
        create_wire_passthrough(stmt)
    stmt.replace_by(replacement)

    return RewriteResult(has_done_something=True)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def create_wire_passthrough(stmt: wire.Apply | wire.Broadcast) -> None:
    """Detach a wire Apply/Broadcast from the surrounding wire dataflow.

    Each output wire of ``stmt`` is replaced by its positionally matching input
    wire. Subsequent statements then consume the inputs directly, and ``stmt``
    can be replaced or deleted without breaking them.
    """
    wire_pairs = zip(stmt.inputs, stmt.results)
    for wire_in, wire_out in wire_pairs:
        wire_out.replace_by(wire_in)
|
|
195
|
+
|
|
196
|
+
|
|
166
197
|
def is_measure_result_used(
|
|
167
198
|
stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure,
|
|
168
199
|
) -> bool:
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from kirin import ir
|
|
2
2
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
3
3
|
|
|
4
|
-
from bloqade.squin import op, wire
|
|
4
|
+
from bloqade.squin import op, wire, noise
|
|
5
5
|
from bloqade.stim.rewrite.util import (
|
|
6
|
-
|
|
6
|
+
SQUIN_STIM_OP_MAPPING,
|
|
7
7
|
rewrite_Control,
|
|
8
|
+
rewrite_QubitLoss,
|
|
8
9
|
insert_qubit_idx_from_wire_ssa,
|
|
9
10
|
)
|
|
10
11
|
|
|
@@ -24,12 +25,16 @@ class SquinWireToStim(RewriteRule):
|
|
|
24
25
|
|
|
25
26
|
# this is an SSAValue, need it to be the actual operator
|
|
26
27
|
applied_op = stmt.operator.owner
|
|
28
|
+
|
|
29
|
+
if isinstance(applied_op, noise.stmts.QubitLoss):
|
|
30
|
+
return rewrite_QubitLoss(stmt)
|
|
31
|
+
|
|
27
32
|
assert isinstance(applied_op, op.stmts.Operator)
|
|
28
33
|
|
|
29
34
|
if isinstance(applied_op, op.stmts.Control):
|
|
30
35
|
return rewrite_Control(stmt)
|
|
31
36
|
|
|
32
|
-
stim_1q_op =
|
|
37
|
+
stim_1q_op = SQUIN_STIM_OP_MAPPING.get(type(applied_op))
|
|
33
38
|
if stim_1q_op is None:
|
|
34
39
|
return RewriteResult()
|
|
35
40
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .from_squin import squin_to_stim as squin_to_stim
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
|
|
3
|
+
from ..groups import main
|
|
4
|
+
from ..passes.squin_to_stim import SquinToStimPass
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def squin_to_stim(mt: ir.Method) -> ir.Method:
    """Convert a squin method to a Stim-dialect method.

    ``SquinToStimPass`` (built with ``mt``'s dialects and ``no_raise=False``)
    is run on ``mt.similar()`` rather than on ``mt`` itself. The converted
    method is returned re-targeted to the Stim ``main`` dialect group.
    """
    converted = mt.similar()
    stim_pass = SquinToStimPass(mt.dialects, no_raise=False)
    stim_pass(converted)
    return converted.similar(dialects=main)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: bloqade-circuit
|
|
3
|
-
Version: 0.4.5
|
|
3
|
+
Version: 0.5.1
|
|
4
4
|
Summary: The software development toolkit for neutral atom arrays.
|
|
5
5
|
Author-email: Roger-luo <rluo@quera.com>, kaihsin <khwu@quera.com>, weinbe58 <pweinberg@quera.com>, johnzl-777 <jlong@quera.com>
|
|
6
6
|
License-File: LICENSE
|
|
@@ -16,6 +16,7 @@ Requires-Dist: scipy>=1.13.1
|
|
|
16
16
|
Provides-Extra: cirq
|
|
17
17
|
Requires-Dist: cirq-core>=1.4.1; extra == 'cirq'
|
|
18
18
|
Requires-Dist: cirq-core[contrib]>=1.4.1; extra == 'cirq'
|
|
19
|
+
Requires-Dist: qpsolvers[clarabel]>=4.7.0; extra == 'cirq'
|
|
19
20
|
Provides-Extra: pyqrack-cuda
|
|
20
21
|
Requires-Dist: pyqrack-cuda>=1.38.2; extra == 'pyqrack-cuda'
|
|
21
22
|
Provides-Extra: pyqrack-opencl
|
|
@@ -29,7 +30,8 @@ Requires-Dist: stim>=1.15.0; extra == 'stim'
|
|
|
29
30
|
Provides-Extra: vis
|
|
30
31
|
Requires-Dist: ffmpeg>=1.4; extra == 'vis'
|
|
31
32
|
Requires-Dist: matplotlib>=3.9.2; extra == 'vis'
|
|
32
|
-
Requires-Dist: pyqt5>=5.15.11; extra == 'vis'
|
|
33
|
+
Requires-Dist: pyqt5>=5.15.11; (sys_platform == 'darwin') and extra == 'vis'
|
|
34
|
+
Requires-Dist: pyqt5>=5.15.11; (sys_platform == 'linux') and extra == 'vis'
|
|
33
35
|
Requires-Dist: tqdm>=4.66.5; extra == 'vis'
|
|
34
36
|
Description-Content-Type: text/markdown
|
|
35
37
|
|