bloqade-circuit 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic. Click here for more details.
- bloqade/analysis/address/impls.py +21 -68
- bloqade/analysis/measure_id/__init__.py +2 -0
- bloqade/analysis/measure_id/analysis.py +45 -0
- bloqade/analysis/measure_id/impls.py +155 -0
- bloqade/analysis/measure_id/lattice.py +82 -0
- bloqade/qasm2/passes/unroll_if.py +9 -2
- bloqade/rewrite/__init__.py +0 -0
- bloqade/rewrite/passes/__init__.py +1 -0
- bloqade/rewrite/passes/canonicalize_ilist.py +28 -0
- bloqade/rewrite/rules/__init__.py +1 -0
- bloqade/rewrite/rules/flatten_ilist.py +51 -0
- bloqade/rewrite/rules/inline_getitem_ilist.py +31 -0
- bloqade/{qasm2/rewrite → rewrite/rules}/split_ifs.py +15 -8
- bloqade/squin/__init__.py +1 -0
- bloqade/squin/analysis/__init__.py +1 -0
- bloqade/squin/analysis/address_impl.py +71 -0
- bloqade/squin/cirq/lowering.py +2 -1
- bloqade/squin/noise/stmts.py +1 -1
- bloqade/stim/dialects/auxiliary/interp.py +0 -10
- bloqade/stim/dialects/auxiliary/stmts/annotate.py +1 -1
- bloqade/stim/passes/__init__.py +1 -1
- bloqade/stim/passes/simplify_ifs.py +32 -0
- bloqade/stim/passes/squin_to_stim.py +95 -27
- bloqade/stim/rewrite/ifs_to_stim.py +203 -0
- bloqade/stim/rewrite/qubit_to_stim.py +3 -0
- bloqade/stim/rewrite/squin_measure.py +68 -5
- bloqade/stim/rewrite/util.py +0 -4
- bloqade/stim/upstream/__init__.py +1 -0
- bloqade/stim/upstream/from_squin.py +10 -0
- {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/METADATA +1 -1
- {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/RECORD +33 -18
- {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/licenses/LICENSE +0 -0
bloqade/squin/noise/stmts.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from kirin import interp
|
|
2
2
|
|
|
3
3
|
from . import stmts
|
|
4
|
-
from .types import RecordResult
|
|
5
4
|
from ._dialect import dialect
|
|
6
5
|
|
|
7
6
|
|
|
@@ -28,12 +27,3 @@ class StimAuxMethods(interp.MethodTable):
|
|
|
28
27
|
stmt: stmts.Neg,
|
|
29
28
|
):
|
|
30
29
|
return (-frame.get(stmt.operand),)
|
|
31
|
-
|
|
32
|
-
@interp.impl(stmts.GetRecord)
|
|
33
|
-
def get_rec(
|
|
34
|
-
self,
|
|
35
|
-
interpreter: interp.Interpreter,
|
|
36
|
-
frame: interp.Frame,
|
|
37
|
-
stmt: stmts.GetRecord,
|
|
38
|
-
):
|
|
39
|
-
return (RecordResult(value=frame.get(stmt.id)),)
|
|
@@ -10,7 +10,7 @@ PyNum = types.Union(types.Int, types.Float)
|
|
|
10
10
|
@statement(dialect=dialect)
|
|
11
11
|
class GetRecord(ir.Statement):
|
|
12
12
|
name = "get_rec"
|
|
13
|
-
traits = frozenset({lowering.FromPythonCall()})
|
|
13
|
+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
|
|
14
14
|
id: ir.SSAValue = info.argument(type=types.Int)
|
|
15
15
|
result: ir.ResultValue = info.result(type=RecordType)
|
|
16
16
|
|
bloqade/stim/passes/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
from .squin_to_stim import
|
|
1
|
+
from .squin_to_stim import SquinToStimPass as SquinToStimPass
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.passes import Pass
|
|
5
|
+
from kirin.rewrite import (
|
|
6
|
+
Walk,
|
|
7
|
+
Chain,
|
|
8
|
+
Fixpoint,
|
|
9
|
+
ConstantFold,
|
|
10
|
+
CommonSubexpressionElimination,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from ..rewrite.ifs_to_stim import StimLiftThenBody, StimSplitIfStmts
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class StimSimplifyIfs(Pass):
    """Reshape ``scf.IfElse`` statements before they are lowered to stim.

    Stage 1 repeatedly lifts statements out of eligible then-bodies
    (``StimLiftThenBody``) and then splits multi-statement then-bodies into
    one if-statement per statement (``StimSplitIfStmts``). Stage 2 iterates
    constant folding and common-subexpression elimination to a fixpoint to
    tidy up the IR produced by stage 1.
    """

    def unsafe_run(self, mt: ir.Method):
        # Stage 1: lift hoistable statements, then split what remains.
        lift_then_split = Chain(
            Fixpoint(Walk(StimLiftThenBody())),
            Walk(StimSplitIfStmts()),
        )
        structural_result = lift_then_split.rewrite(mt.code)

        # Stage 2: clean-up of the rewritten IR.
        cleanup = Fixpoint(
            Walk(Chain(ConstantFold(), CommonSubexpressionElimination()))
        )
        cleanup_result = cleanup.rewrite(mt.code)

        return cleanup_result.join(structural_result)
|
|
@@ -5,15 +5,19 @@ from kirin.rewrite import (
|
|
|
5
5
|
Walk,
|
|
6
6
|
Chain,
|
|
7
7
|
Fixpoint,
|
|
8
|
+
CFGCompactify,
|
|
9
|
+
InlineGetItem,
|
|
10
|
+
InlineGetField,
|
|
8
11
|
DeadCodeElimination,
|
|
9
12
|
CommonSubexpressionElimination,
|
|
10
13
|
)
|
|
11
14
|
from kirin.analysis import const
|
|
15
|
+
from kirin.dialects import scf, ilist
|
|
12
16
|
from kirin.ir.method import Method
|
|
13
17
|
from kirin.passes.abc import Pass
|
|
14
18
|
from kirin.rewrite.abc import RewriteResult
|
|
19
|
+
from kirin.passes.inline import InlinePass
|
|
15
20
|
|
|
16
|
-
from bloqade.stim.groups import main as stim_main_group
|
|
17
21
|
from bloqade.stim.rewrite import (
|
|
18
22
|
SquinWireToStim,
|
|
19
23
|
PyConstantToStim,
|
|
@@ -22,22 +26,95 @@ from bloqade.stim.rewrite import (
|
|
|
22
26
|
SquinMeasureToStim,
|
|
23
27
|
SquinWireIdentityElimination,
|
|
24
28
|
)
|
|
25
|
-
from bloqade.squin.rewrite import
|
|
29
|
+
from bloqade.squin.rewrite import (
|
|
30
|
+
SquinU3ToClifford,
|
|
31
|
+
RemoveDeadRegister,
|
|
32
|
+
WrapAddressAnalysis,
|
|
33
|
+
)
|
|
34
|
+
from bloqade.rewrite.passes import CanonicalizeIList
|
|
35
|
+
from bloqade.analysis.address import AddressAnalysis
|
|
36
|
+
from bloqade.analysis.measure_id import MeasurementIDAnalysis
|
|
37
|
+
|
|
38
|
+
from .simplify_ifs import StimSimplifyIfs
|
|
39
|
+
from ..rewrite.ifs_to_stim import IfToStim
|
|
26
40
|
|
|
27
41
|
|
|
28
42
|
@dataclass
|
|
29
|
-
class
|
|
43
|
+
class SquinToStimPass(Pass):
|
|
30
44
|
|
|
31
45
|
def unsafe_run(self, mt: Method) -> RewriteResult:
|
|
32
|
-
fold_pass = Fold(mt.dialects)
|
|
33
|
-
# propagate constants
|
|
34
|
-
rewrite_result = fold_pass(mt)
|
|
35
46
|
|
|
36
47
|
cp_frame, _ = const.Propagate(dialects=mt.dialects).run_analysis(mt)
|
|
37
48
|
cp_results = cp_frame.entries
|
|
38
49
|
|
|
39
50
|
# Assume that address analysis and
|
|
40
51
|
# wrapping has been done before this pass!
|
|
52
|
+
# inline aggressively:
|
|
53
|
+
rewrite_result = InlinePass(
|
|
54
|
+
dialects=mt.dialects, no_raise=self.no_raise
|
|
55
|
+
).unsafe_run(mt)
|
|
56
|
+
|
|
57
|
+
rule = Chain(
|
|
58
|
+
InlineGetField(),
|
|
59
|
+
InlineGetItem(),
|
|
60
|
+
scf.unroll.ForLoop(),
|
|
61
|
+
scf.trim.UnusedYield(),
|
|
62
|
+
)
|
|
63
|
+
rewrite_result = Fixpoint(Walk(rule)).rewrite(mt.code).join(rewrite_result)
|
|
64
|
+
# fold_pass = Fold(mt.dialects, no_raise=self.no_raise)
|
|
65
|
+
# rewrite_result = fold_pass(mt)
|
|
66
|
+
rewrite_result = (
|
|
67
|
+
Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(rewrite_result)
|
|
68
|
+
)
|
|
69
|
+
rewrite_result = (
|
|
70
|
+
StimSimplifyIfs(mt.dialects, no_raise=self.no_raise)
|
|
71
|
+
.unsafe_run(mt)
|
|
72
|
+
.join(rewrite_result)
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
# run typeinfer again after unroll etc. because we now insert
|
|
76
|
+
# a lot of new nodes, which might have more precise types
|
|
77
|
+
# self.typeinfer.unsafe_run(mt)
|
|
78
|
+
rewrite_result = (
|
|
79
|
+
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
|
|
80
|
+
.rewrite(mt.code)
|
|
81
|
+
.join(rewrite_result)
|
|
82
|
+
)
|
|
83
|
+
rewrite_result = Fold(mt.dialects, no_raise=self.no_raise)(mt)
|
|
84
|
+
|
|
85
|
+
rewrite_result = (
|
|
86
|
+
CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
|
|
87
|
+
.unsafe_run(mt)
|
|
88
|
+
.join(rewrite_result)
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
# after this the program should be in a state where it is analyzable
|
|
92
|
+
# -------------------------------------------------------------------
|
|
93
|
+
|
|
94
|
+
mia = MeasurementIDAnalysis(dialects=mt.dialects)
|
|
95
|
+
meas_analysis_frame, _ = mia.run_analysis(mt, no_raise=self.no_raise)
|
|
96
|
+
|
|
97
|
+
aa = AddressAnalysis(dialects=mt.dialects)
|
|
98
|
+
address_analysis_frame, _ = aa.run_analysis(mt, no_raise=self.no_raise)
|
|
99
|
+
|
|
100
|
+
# wrap the address analysis result
|
|
101
|
+
rewrite_result = (
|
|
102
|
+
Walk(WrapAddressAnalysis(address_analysis=address_analysis_frame.entries))
|
|
103
|
+
.rewrite(mt.code)
|
|
104
|
+
.join(rewrite_result)
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
# 2. rewrite
|
|
108
|
+
rewrite_result = (
|
|
109
|
+
Walk(
|
|
110
|
+
IfToStim(
|
|
111
|
+
measure_analysis=meas_analysis_frame.entries,
|
|
112
|
+
measure_count=mia.measure_count,
|
|
113
|
+
)
|
|
114
|
+
)
|
|
115
|
+
.rewrite(mt.code)
|
|
116
|
+
.join(rewrite_result)
|
|
117
|
+
)
|
|
41
118
|
|
|
42
119
|
# Rewrite the noise statements first.
|
|
43
120
|
rewrite_result = (
|
|
@@ -47,7 +124,6 @@ class SquinToStim(Pass):
|
|
|
47
124
|
)
|
|
48
125
|
|
|
49
126
|
# Wrap Rewrite + SquinToStim can happen w/ standard walk
|
|
50
|
-
|
|
51
127
|
rewrite_result = Walk(SquinU3ToClifford()).rewrite(mt.code).join(rewrite_result)
|
|
52
128
|
|
|
53
129
|
rewrite_result = (
|
|
@@ -55,23 +131,27 @@ class SquinToStim(Pass):
|
|
|
55
131
|
Chain(
|
|
56
132
|
SquinQubitToStim(),
|
|
57
133
|
SquinWireToStim(),
|
|
58
|
-
SquinMeasureToStim(
|
|
134
|
+
SquinMeasureToStim(
|
|
135
|
+
measure_id_result=meas_analysis_frame.entries,
|
|
136
|
+
total_measure_count=mia.measure_count,
|
|
137
|
+
), # reduce duplicated logic, can split out even more rules later
|
|
59
138
|
SquinWireIdentityElimination(),
|
|
60
139
|
)
|
|
61
140
|
)
|
|
62
141
|
.rewrite(mt.code)
|
|
63
142
|
.join(rewrite_result)
|
|
64
143
|
)
|
|
65
|
-
|
|
66
|
-
# Convert all PyConsts to Stim Constants
|
|
67
144
|
rewrite_result = (
|
|
68
|
-
|
|
145
|
+
CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
|
|
146
|
+
.unsafe_run(mt)
|
|
147
|
+
.join(rewrite_result)
|
|
69
148
|
)
|
|
70
149
|
|
|
71
|
-
#
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
150
|
+
# Convert all PyConsts to Stim Constants
|
|
151
|
+
rewrite_result = Walk(PyConstantToStim()).rewrite(mt.code).join(rewrite_result)
|
|
152
|
+
|
|
153
|
+
# clear up leftover stmts
|
|
154
|
+
# - remove any squin.qubit.new that's left around
|
|
75
155
|
rewrite_result = (
|
|
76
156
|
Fixpoint(
|
|
77
157
|
Walk(
|
|
@@ -86,16 +166,4 @@ class SquinToStim(Pass):
|
|
|
86
166
|
.join(rewrite_result)
|
|
87
167
|
)
|
|
88
168
|
|
|
89
|
-
# do program verification here,
|
|
90
|
-
# ideally use built-in .verify() to catch any
|
|
91
|
-
# incompatible statements as the full rewrite process should not
|
|
92
|
-
# leave statements from any other dialects (other than the stim main group)
|
|
93
|
-
mt_verification_clone = mt.similar(stim_main_group)
|
|
94
|
-
|
|
95
|
-
# suggested by Kai, will work for now
|
|
96
|
-
for stmt in mt_verification_clone.code.walk():
|
|
97
|
-
assert (
|
|
98
|
-
stmt.dialect in stim_main_group
|
|
99
|
-
), "Statements detected that are not part of the stim dialect, please verify the original code is valid for rewrite!"
|
|
100
|
-
|
|
101
169
|
return rewrite_result
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
from dataclasses import field, dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.dialects import py, scf, func
|
|
5
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
6
|
+
|
|
7
|
+
from bloqade.squin import op, qubit
|
|
8
|
+
from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts
|
|
9
|
+
from bloqade.squin.rewrite import AddressAttribute
|
|
10
|
+
from bloqade.stim.rewrite.util import (
|
|
11
|
+
SQUIN_STIM_CONTROL_GATE_MAPPING,
|
|
12
|
+
insert_qubit_idx_from_address,
|
|
13
|
+
)
|
|
14
|
+
from bloqade.stim.dialects.auxiliary import GetRecord
|
|
15
|
+
from bloqade.analysis.measure_id.lattice import (
|
|
16
|
+
MeasureId,
|
|
17
|
+
MeasureIdBool,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class IfElseSimplification:
    """Eligibility checks shared by the stim ``scf.IfElse`` rewrite rules.

    An ``scf.IfElse`` is considered rewriteable only when it is "flat":
    it contains no other ``scf.IfElse``, it is not nested (at any depth)
    inside another ``scf.IfElse``, and its else-body is absent or consists
    solely of an empty ``scf.Yield``.
    """

    # Might be better to just do a rewrite_Region?
    def is_rewriteable(self, node: scf.IfElse) -> bool:
        """Return True if ``node`` is a flat if-statement with no real else branch."""
        return not (
            self.contains_ifelse(node)
            or self.is_nested_ifelse(node)
            or self.has_else_body(node)
        )

    # A preliminary check to reject an IfElse from the "top down";
    # use in conjunction with is_nested_ifelse to completely cover
    # cases of nested IfElse statements.
    def contains_ifelse(self, stmt: scf.IfElse) -> bool:
        """Check if the IfElse statement contains another IfElse statement."""
        return any(
            isinstance(child, scf.IfElse) for child in stmt.walk(include_self=False)
        )

    # Because the rewrite latches onto ANY scf.IfElse, we need a way to
    # determine whether we are touching an IfElse that sits inside another
    # IfElse (the "bottom up" counterpart of contains_ifelse).
    def is_nested_ifelse(self, stmt: scf.IfElse) -> bool:
        """Check if the IfElse statement is nested inside another IfElse.

        Every enclosing statement is inspected, not only the parent and the
        grandparent, so an IfElse buried under several intermediate
        statements (e.g. nested loops) inside an outer IfElse is detected too.
        """
        ancestor = stmt.parent_stmt
        while ancestor is not None:
            if isinstance(ancestor, scf.IfElse):
                return True
            ancestor = ancestor.parent_stmt
        return False

    def has_else_body(self, stmt: scf.IfElse) -> bool:
        """Check if the IfElse statement has a non-trivial else body.

        An else body whose first block holds nothing but an empty
        ``scf.Yield`` counts as "no else body".
        """
        if not stmt.else_body.blocks:
            return False

        else_block = stmt.else_body.blocks[0]
        only_empty_yield = (
            len(else_block.stmts) == 1
            and isinstance(else_term := else_block.last_stmt, scf.Yield)
            and not else_term.values  # empty yield
        )
        return not only_empty_yield
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Default ``exclude_stmts`` for StimLiftThenBody: statement types that are
# kept inside the then-body of an scf.IfElse instead of being lifted out.
DontLiftType = (
    qubit.Apply, qubit.Broadcast,  # the quantum ops being conditioned
    scf.Yield, func.Return,  # terminators
    func.Invoke,  # calls
    scf.IfElse,  # nested conditionals
)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@dataclass
class StimLiftThenBody(IfElseSimplification, LiftThenBody):
    """``LiftThenBody`` restricted to flat ``scf.IfElse`` statements.

    ``DontLiftType`` is the default set of statement types left in place.
    Only an ``scf.IfElse`` that passes ``is_rewriteable`` is handed to
    ``LiftThenBody``; every other node is reported as unchanged.
    """

    exclude_stmts: tuple[type[ir.Statement], ...] = field(default=DontLiftType)

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        eligible = isinstance(node, scf.IfElse) and self.is_rewriteable(node)
        if eligible:
            return super().rewrite_Statement(node)
        return RewriteResult()
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# Only run this after everything other than qubit.Apply/qubit.Broadcast has
# been lifted out of the then-body (see StimLiftThenBody)!
class StimSplitIfStmts(IfElseSimplification, SplitIfStmts):
    """Split the then-body of a flat if-statement into one if per statement.

    Given an IfElse with multiple valid statements in the then-body:

        if measure_result:
            squin.qubit.apply(op.X, q0)
            squin.qubit.apply(op.Y, q1)

    this should be rewritten to:

        if measure_result:
            squin.qubit.apply(op.X, q0)

        if measure_result:
            squin.qubit.apply(op.Y, q1)

    Only an ``scf.IfElse`` accepted by ``is_rewriteable`` is handed to
    ``SplitIfStmts``; every other node is reported as unchanged.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if isinstance(node, scf.IfElse) and self.is_rewriteable(node):
            return super().rewrite_Statement(node)
        return RewriteResult()
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@dataclass
|
|
125
|
+
class IfToStim(IfElseSimplification, RewriteRule):
|
|
126
|
+
"""
|
|
127
|
+
Rewrite if statements to stim equivalent statements.
|
|
128
|
+
"""
|
|
129
|
+
|
|
130
|
+
measure_analysis: dict[ir.SSAValue, MeasureId]
|
|
131
|
+
measure_count: int
|
|
132
|
+
|
|
133
|
+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
134
|
+
|
|
135
|
+
match node:
|
|
136
|
+
case scf.IfElse():
|
|
137
|
+
return self.rewrite_IfElse(node)
|
|
138
|
+
case _:
|
|
139
|
+
return RewriteResult()
|
|
140
|
+
|
|
141
|
+
def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
|
|
142
|
+
|
|
143
|
+
if not isinstance(self.measure_analysis[stmt.cond], MeasureIdBool):
|
|
144
|
+
return RewriteResult()
|
|
145
|
+
|
|
146
|
+
# check that there is only qubit.Apply in the then-body,
|
|
147
|
+
# if there's more than that, we can't do a valid rewrite.
|
|
148
|
+
# Can reuse logic from SplitIf
|
|
149
|
+
*stmts, _ = stmt.then_body.stmts()
|
|
150
|
+
if len(stmts) != 1 or not isinstance(stmts[0], (qubit.Apply, qubit.Broadcast)):
|
|
151
|
+
return RewriteResult()
|
|
152
|
+
|
|
153
|
+
apply_or_broadcast = stmts[0]
|
|
154
|
+
# Check that the gate being applied/broadcasted can be converted to a stim
|
|
155
|
+
# controlled gate.
|
|
156
|
+
ctrl_op_target_gate = apply_or_broadcast.operator.owner
|
|
157
|
+
assert isinstance(ctrl_op_target_gate, op.stmts.Operator)
|
|
158
|
+
|
|
159
|
+
stim_gate = SQUIN_STIM_CONTROL_GATE_MAPPING.get(type(ctrl_op_target_gate))
|
|
160
|
+
if stim_gate is None:
|
|
161
|
+
return RewriteResult()
|
|
162
|
+
|
|
163
|
+
# get necessary measurement ID type from analysis
|
|
164
|
+
measure_id_bool = self.measure_analysis[stmt.cond]
|
|
165
|
+
assert isinstance(measure_id_bool, MeasureIdBool)
|
|
166
|
+
|
|
167
|
+
# generate get record statement
|
|
168
|
+
measure_id_idx_stmt = py.Constant(
|
|
169
|
+
(measure_id_bool.idx - 1) - self.measure_count
|
|
170
|
+
)
|
|
171
|
+
get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841
|
|
172
|
+
|
|
173
|
+
# get address attribute and generate qubit idx statements
|
|
174
|
+
address_attr = apply_or_broadcast.qubits.hints.get("address")
|
|
175
|
+
if address_attr is None:
|
|
176
|
+
return RewriteResult()
|
|
177
|
+
assert isinstance(address_attr, AddressAttribute)
|
|
178
|
+
|
|
179
|
+
# note: insert things before (literally above/outside) the If
|
|
180
|
+
qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
181
|
+
address=address_attr, stmt_to_insert_before=stmt
|
|
182
|
+
)
|
|
183
|
+
if qubit_idx_ssas is None:
|
|
184
|
+
return RewriteResult()
|
|
185
|
+
|
|
186
|
+
# Assemble the stim statement
|
|
187
|
+
# let GetRecord's SSA be repeated per each get qubit
|
|
188
|
+
ctrl_records = tuple(get_record_stmt.result for _ in qubit_idx_ssas)
|
|
189
|
+
|
|
190
|
+
stim_stmt = stim_gate(
|
|
191
|
+
targets=tuple(qubit_idx_ssas),
|
|
192
|
+
controls=ctrl_records,
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
# Insert the necessary SSA Values, then get rid of the scf.IfElse.
|
|
196
|
+
# The qubit indices have been successfully added,
|
|
197
|
+
# that just leaves the GetRecord statement and measurement ID index statement
|
|
198
|
+
|
|
199
|
+
measure_id_idx_stmt.insert_before(stmt)
|
|
200
|
+
get_record_stmt.insert_before(stmt)
|
|
201
|
+
stmt.replace_by(stim_stmt)
|
|
202
|
+
|
|
203
|
+
return RewriteResult(has_done_something=True)
|
|
@@ -1,22 +1,59 @@
|
|
|
1
1
|
# create rewrite rule name SquinMeasureToStim using kirin
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
2
4
|
from kirin import ir
|
|
3
|
-
from kirin.dialects import py
|
|
5
|
+
from kirin.dialects import py, ilist
|
|
4
6
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
5
7
|
|
|
6
8
|
from bloqade.squin import wire, qubit
|
|
7
9
|
from bloqade.squin.rewrite import AddressAttribute
|
|
8
|
-
from bloqade.stim.dialects import collapse
|
|
10
|
+
from bloqade.stim.dialects import collapse, auxiliary
|
|
9
11
|
from bloqade.stim.rewrite.util import (
|
|
10
12
|
is_measure_result_used,
|
|
11
13
|
insert_qubit_idx_from_address,
|
|
12
14
|
)
|
|
15
|
+
from bloqade.analysis.measure_id.lattice import MeasureId, MeasureIdBool, MeasureIdTuple
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def replace_get_record(
    node: ir.Statement, measure_id_bool: MeasureIdBool, meas_count: int
):
    """Replace ``node`` with a stim ``GetRecord`` of a single measurement.

    A constant holding the record offset ``(idx - 1) - meas_count`` is
    inserted right before ``node``. ``node`` is then replaced by a
    ``GetRecord`` that reads that offset.
    """
    assert isinstance(measure_id_bool, MeasureIdBool)
    offset_stmt = py.constant.Constant((measure_id_bool.idx - 1) - meas_count)
    offset_stmt.insert_before(node)
    node.replace_by(auxiliary.GetRecord(offset_stmt.result))
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def insert_get_record_list(
    node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count: int
):
    """
    Insert GetRecord statements before the given node.

    One ``GetRecord`` (plus its offset constant) is emitted per entry of
    ``measure_id_tuple.data``, in order. ``node`` is then replaced by an
    ``ilist.New`` that collects the record results.
    """

    def emit_record(measure_id_bool) -> ir.SSAValue:
        # Emit `offset = const; rec = GetRecord(offset)` just before `node`.
        assert isinstance(measure_id_bool, MeasureIdBool)
        offset_stmt = py.constant.Constant((measure_id_bool.idx - 1) - meas_count)
        offset_stmt.insert_before(node)
        record_stmt = auxiliary.GetRecord(offset_stmt.result)
        record_stmt.insert_before(node)
        return record_stmt.result

    get_record_ssas = [emit_record(entry) for entry in measure_id_tuple.data]
    node.replace_by(ilist.New(values=get_record_ssas))
|
|
14
46
|
|
|
47
|
+
|
|
48
|
+
@dataclass
|
|
15
49
|
class SquinMeasureToStim(RewriteRule):
|
|
16
50
|
"""
|
|
17
51
|
Rewrite squin measure-related statements to stim statements.
|
|
18
52
|
"""
|
|
19
53
|
|
|
54
|
+
measure_id_result: dict[ir.SSAValue, MeasureId]
|
|
55
|
+
total_measure_count: int
|
|
56
|
+
|
|
20
57
|
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
21
58
|
|
|
22
59
|
match node:
|
|
@@ -28,20 +65,46 @@ class SquinMeasureToStim(RewriteRule):
|
|
|
28
65
|
def rewrite_Measure(
    self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
) -> RewriteResult:
    """Lower a squin measurement to a stim ``MZ`` plus record lookups.

    An ``MZ`` (noise probability 0.0) is inserted before ``measure_stmt``.
    If the measurement result is unused, ``measure_stmt`` is deleted.
    Otherwise its result is replaced by ``GetRecord`` statements (one for a
    single measurement, an ``ilist`` of them for a tuple of measurements).
    """
    # Resolve the measurement id first, so nothing is inserted into the IR
    # when we bail out (get_qubit_idx_ssas presumably inserts index
    # statements). `.get` avoids a KeyError for results the analysis lacks.
    measure_id = self.measure_id_result.get(measure_stmt.result)
    if not isinstance(measure_id, (MeasureIdBool, MeasureIdTuple)):
        return RewriteResult()

    qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt)
    if qubit_idx_ssas is None:
        return RewriteResult()

    prob_noise_stmt = py.constant.Constant(0.0)
    stim_measure_stmt = collapse.MZ(
        p=prob_noise_stmt.result,
        targets=qubit_idx_ssas,
    )
    prob_noise_stmt.insert_before(measure_stmt)
    stim_measure_stmt.insert_before(measure_stmt)

    if not is_measure_result_used(measure_stmt):
        measure_stmt.delete()
        return RewriteResult(has_done_something=True)

    # replace dataflow with new stmt!
    if isinstance(measure_id, MeasureIdBool):
        replace_get_record(
            node=measure_stmt,
            measure_id_bool=measure_id,
            meas_count=self.total_measure_count,
        )
    else:  # MeasureIdTuple, guaranteed by the check above
        insert_get_record_list(
            node=measure_stmt,
            measure_id_tuple=measure_id,
            meas_count=self.total_measure_count,
        )

    return RewriteResult(has_done_something=True)
|
|
47
110
|
|
bloqade/stim/rewrite/util.py
CHANGED
|
@@ -182,10 +182,6 @@ def rewrite_QubitLoss(
|
|
|
182
182
|
create_wire_passthrough(stmt)
|
|
183
183
|
|
|
184
184
|
stmt.replace_by(stim_loss_stmt)
|
|
185
|
-
# NoiseChannels are not pure,
|
|
186
|
-
# need to manually delete because
|
|
187
|
-
# DCE won't touch them
|
|
188
|
-
stmt.operator.owner.delete()
|
|
189
185
|
|
|
190
186
|
return RewriteResult(has_done_something=True)
|
|
191
187
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .from_squin import squin_to_stim as squin_to_stim
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
|
|
3
|
+
from ..groups import main
|
|
4
|
+
from ..passes.squin_to_stim import SquinToStimPass
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def squin_to_stim(mt: ir.Method) -> ir.Method:
    """Lower the squin kernel ``mt`` to a stim-dialect method.

    ``SquinToStimPass`` runs (with ``no_raise=False``) on ``mt.similar()``
    rather than on ``mt`` directly. The result is re-wrapped with the stim
    ``main`` dialect group.
    """
    lowered = mt.similar()
    stim_pass = SquinToStimPass(mt.dialects, no_raise=False)
    stim_pass(lowered)
    return lowered.similar(dialects=main)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: bloqade-circuit
|
|
3
|
-
Version: 0.5.0
|
|
3
|
+
Version: 0.5.2
|
|
4
4
|
Summary: The software development toolkit for neutral atom arrays.
|
|
5
5
|
Author-email: Roger-luo <rluo@quera.com>, kaihsin <khwu@quera.com>, weinbe58 <pweinberg@quera.com>, johnzl-777 <jlong@quera.com>
|
|
6
6
|
License-File: LICENSE
|