bloqade-circuit 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (33) hide show
  1. bloqade/analysis/address/impls.py +21 -68
  2. bloqade/analysis/measure_id/__init__.py +2 -0
  3. bloqade/analysis/measure_id/analysis.py +45 -0
  4. bloqade/analysis/measure_id/impls.py +155 -0
  5. bloqade/analysis/measure_id/lattice.py +82 -0
  6. bloqade/qasm2/passes/unroll_if.py +9 -2
  7. bloqade/rewrite/__init__.py +0 -0
  8. bloqade/rewrite/passes/__init__.py +1 -0
  9. bloqade/rewrite/passes/canonicalize_ilist.py +28 -0
  10. bloqade/rewrite/rules/__init__.py +1 -0
  11. bloqade/rewrite/rules/flatten_ilist.py +51 -0
  12. bloqade/rewrite/rules/inline_getitem_ilist.py +31 -0
  13. bloqade/{qasm2/rewrite → rewrite/rules}/split_ifs.py +15 -8
  14. bloqade/squin/__init__.py +1 -0
  15. bloqade/squin/analysis/__init__.py +1 -0
  16. bloqade/squin/analysis/address_impl.py +71 -0
  17. bloqade/squin/cirq/lowering.py +2 -1
  18. bloqade/squin/noise/stmts.py +1 -1
  19. bloqade/stim/dialects/auxiliary/interp.py +0 -10
  20. bloqade/stim/dialects/auxiliary/stmts/annotate.py +1 -1
  21. bloqade/stim/passes/__init__.py +1 -1
  22. bloqade/stim/passes/simplify_ifs.py +32 -0
  23. bloqade/stim/passes/squin_to_stim.py +95 -27
  24. bloqade/stim/rewrite/ifs_to_stim.py +203 -0
  25. bloqade/stim/rewrite/qubit_to_stim.py +3 -0
  26. bloqade/stim/rewrite/squin_measure.py +68 -5
  27. bloqade/stim/rewrite/util.py +0 -4
  28. bloqade/stim/upstream/__init__.py +1 -0
  29. bloqade/stim/upstream/from_squin.py +10 -0
  30. {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/METADATA +1 -1
  31. {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/RECORD +33 -18
  32. {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/WHEEL +0 -0
  33. {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/licenses/LICENSE +0 -0
@@ -10,7 +10,7 @@ from ..op.types import NumOperators
10
10
 
11
11
  @statement
12
12
  class NoiseChannel(ir.Statement):
13
- traits = frozenset({lowering.FromPythonCall()})
13
+ traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
14
14
  result: ir.ResultValue = info.result(OpType)
15
15
 
16
16
 
@@ -1,7 +1,6 @@
1
1
  from kirin import interp
2
2
 
3
3
  from . import stmts
4
- from .types import RecordResult
5
4
  from ._dialect import dialect
6
5
 
7
6
 
@@ -28,12 +27,3 @@ class StimAuxMethods(interp.MethodTable):
28
27
  stmt: stmts.Neg,
29
28
  ):
30
29
  return (-frame.get(stmt.operand),)
31
-
32
- @interp.impl(stmts.GetRecord)
33
- def get_rec(
34
- self,
35
- interpreter: interp.Interpreter,
36
- frame: interp.Frame,
37
- stmt: stmts.GetRecord,
38
- ):
39
- return (RecordResult(value=frame.get(stmt.id)),)
@@ -10,7 +10,7 @@ PyNum = types.Union(types.Int, types.Float)
10
10
  @statement(dialect=dialect)
11
11
  class GetRecord(ir.Statement):
12
12
  name = "get_rec"
13
- traits = frozenset({lowering.FromPythonCall()})
13
+ traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
14
14
  id: ir.SSAValue = info.argument(type=types.Int)
15
15
  result: ir.ResultValue = info.result(type=RecordType)
16
16
 
@@ -1 +1 @@
1
- from .squin_to_stim import SquinToStim as SquinToStim
1
+ from .squin_to_stim import SquinToStimPass as SquinToStimPass
@@ -0,0 +1,32 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.passes import Pass
5
+ from kirin.rewrite import (
6
+ Walk,
7
+ Chain,
8
+ Fixpoint,
9
+ ConstantFold,
10
+ CommonSubexpressionElimination,
11
+ )
12
+
13
+ from ..rewrite.ifs_to_stim import StimLiftThenBody, StimSplitIfStmts
14
+
15
+
16
+ @dataclass
17
+ class StimSimplifyIfs(Pass):
18
+
19
+ def unsafe_run(self, mt: ir.Method):
20
+
21
+ result = Chain(
22
+ Fixpoint(Walk(StimLiftThenBody())),
23
+ Walk(StimSplitIfStmts()),
24
+ ).rewrite(mt.code)
25
+
26
+ result = (
27
+ Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))
28
+ .rewrite(mt.code)
29
+ .join(result)
30
+ )
31
+
32
+ return result
@@ -5,15 +5,19 @@ from kirin.rewrite import (
5
5
  Walk,
6
6
  Chain,
7
7
  Fixpoint,
8
+ CFGCompactify,
9
+ InlineGetItem,
10
+ InlineGetField,
8
11
  DeadCodeElimination,
9
12
  CommonSubexpressionElimination,
10
13
  )
11
14
  from kirin.analysis import const
15
+ from kirin.dialects import scf, ilist
12
16
  from kirin.ir.method import Method
13
17
  from kirin.passes.abc import Pass
14
18
  from kirin.rewrite.abc import RewriteResult
19
+ from kirin.passes.inline import InlinePass
15
20
 
16
- from bloqade.stim.groups import main as stim_main_group
17
21
  from bloqade.stim.rewrite import (
18
22
  SquinWireToStim,
19
23
  PyConstantToStim,
@@ -22,22 +26,95 @@ from bloqade.stim.rewrite import (
22
26
  SquinMeasureToStim,
23
27
  SquinWireIdentityElimination,
24
28
  )
25
- from bloqade.squin.rewrite import SquinU3ToClifford, RemoveDeadRegister
29
+ from bloqade.squin.rewrite import (
30
+ SquinU3ToClifford,
31
+ RemoveDeadRegister,
32
+ WrapAddressAnalysis,
33
+ )
34
+ from bloqade.rewrite.passes import CanonicalizeIList
35
+ from bloqade.analysis.address import AddressAnalysis
36
+ from bloqade.analysis.measure_id import MeasurementIDAnalysis
37
+
38
+ from .simplify_ifs import StimSimplifyIfs
39
+ from ..rewrite.ifs_to_stim import IfToStim
26
40
 
27
41
 
28
42
  @dataclass
29
- class SquinToStim(Pass):
43
+ class SquinToStimPass(Pass):
30
44
 
31
45
  def unsafe_run(self, mt: Method) -> RewriteResult:
32
- fold_pass = Fold(mt.dialects)
33
- # propagate constants
34
- rewrite_result = fold_pass(mt)
35
46
 
36
47
  cp_frame, _ = const.Propagate(dialects=mt.dialects).run_analysis(mt)
37
48
  cp_results = cp_frame.entries
38
49
 
39
50
  # Assume that address analysis and
40
51
  # wrapping has been done before this pass!
52
+ # inline aggressively:
53
+ rewrite_result = InlinePass(
54
+ dialects=mt.dialects, no_raise=self.no_raise
55
+ ).unsafe_run(mt)
56
+
57
+ rule = Chain(
58
+ InlineGetField(),
59
+ InlineGetItem(),
60
+ scf.unroll.ForLoop(),
61
+ scf.trim.UnusedYield(),
62
+ )
63
+ rewrite_result = Fixpoint(Walk(rule)).rewrite(mt.code).join(rewrite_result)
64
+ # fold_pass = Fold(mt.dialects, no_raise=self.no_raise)
65
+ # rewrite_result = fold_pass(mt)
66
+ rewrite_result = (
67
+ Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(rewrite_result)
68
+ )
69
+ rewrite_result = (
70
+ StimSimplifyIfs(mt.dialects, no_raise=self.no_raise)
71
+ .unsafe_run(mt)
72
+ .join(rewrite_result)
73
+ )
74
+
75
+ # run typeinfer again after unroll etc. because we now insert
76
+ # a lot of new nodes, which might have more precise types
77
+ # self.typeinfer.unsafe_run(mt)
78
+ rewrite_result = (
79
+ Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
80
+ .rewrite(mt.code)
81
+ .join(rewrite_result)
82
+ )
83
+ rewrite_result = Fold(mt.dialects, no_raise=self.no_raise)(mt)
84
+
85
+ rewrite_result = (
86
+ CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
87
+ .unsafe_run(mt)
88
+ .join(rewrite_result)
89
+ )
90
+
91
+ # after this the program should be in a state where it is analyzable
92
+ # -------------------------------------------------------------------
93
+
94
+ mia = MeasurementIDAnalysis(dialects=mt.dialects)
95
+ meas_analysis_frame, _ = mia.run_analysis(mt, no_raise=self.no_raise)
96
+
97
+ aa = AddressAnalysis(dialects=mt.dialects)
98
+ address_analysis_frame, _ = aa.run_analysis(mt, no_raise=self.no_raise)
99
+
100
+ # wrap the address analysis result
101
+ rewrite_result = (
102
+ Walk(WrapAddressAnalysis(address_analysis=address_analysis_frame.entries))
103
+ .rewrite(mt.code)
104
+ .join(rewrite_result)
105
+ )
106
+
107
+ # 2. rewrite
108
+ rewrite_result = (
109
+ Walk(
110
+ IfToStim(
111
+ measure_analysis=meas_analysis_frame.entries,
112
+ measure_count=mia.measure_count,
113
+ )
114
+ )
115
+ .rewrite(mt.code)
116
+ .join(rewrite_result)
117
+ )
41
118
 
42
119
  # Rewrite the noise statements first.
43
120
  rewrite_result = (
@@ -47,7 +124,6 @@ class SquinToStim(Pass):
47
124
  )
48
125
 
49
126
  # Wrap Rewrite + SquinToStim can happen w/ standard walk
50
-
51
127
  rewrite_result = Walk(SquinU3ToClifford()).rewrite(mt.code).join(rewrite_result)
52
128
 
53
129
  rewrite_result = (
@@ -55,23 +131,27 @@ class SquinToStim(Pass):
55
131
  Chain(
56
132
  SquinQubitToStim(),
57
133
  SquinWireToStim(),
58
- SquinMeasureToStim(), # reduce duplicated logic, can split out even more rules later
134
+ SquinMeasureToStim(
135
+ measure_id_result=meas_analysis_frame.entries,
136
+ total_measure_count=mia.measure_count,
137
+ ), # reduce duplicated logic, can split out even more rules later
59
138
  SquinWireIdentityElimination(),
60
139
  )
61
140
  )
62
141
  .rewrite(mt.code)
63
142
  .join(rewrite_result)
64
143
  )
65
-
66
- # Convert all PyConsts to Stim Constants
67
144
  rewrite_result = (
68
- Walk(Chain(PyConstantToStim())).rewrite(mt.code).join(rewrite_result)
145
+ CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
146
+ .unsafe_run(mt)
147
+ .join(rewrite_result)
69
148
  )
70
149
 
71
- # remove any squin.qubit.new that's left around
72
- ## Not considered pure so DCE won't touch it but
73
- ## it isn't being used anymore considering everything is a
74
- ## stim dialect statement
150
+ # Convert all PyConsts to Stim Constants
151
+ rewrite_result = Walk(PyConstantToStim()).rewrite(mt.code).join(rewrite_result)
152
+
153
+ # clear up leftover stmts
154
+ # - remove any squin.qubit.new that's left around
75
155
  rewrite_result = (
76
156
  Fixpoint(
77
157
  Walk(
@@ -86,16 +166,4 @@ class SquinToStim(Pass):
86
166
  .join(rewrite_result)
87
167
  )
88
168
 
89
- # do program verification here,
90
- # ideally use built-in .verify() to catch any
91
- # incompatible statements as the full rewrite process should not
92
- # leave statements from any other dialects (other than the stim main group)
93
- mt_verification_clone = mt.similar(stim_main_group)
94
-
95
- # suggested by Kai, will work for now
96
- for stmt in mt_verification_clone.code.walk():
97
- assert (
98
- stmt.dialect in stim_main_group
99
- ), "Statements detected that are not part of the stim dialect, please verify the original code is valid for rewrite!"
100
-
101
169
  return rewrite_result
@@ -0,0 +1,203 @@
1
+ from dataclasses import field, dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.dialects import py, scf, func
5
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
6
+
7
+ from bloqade.squin import op, qubit
8
+ from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts
9
+ from bloqade.squin.rewrite import AddressAttribute
10
+ from bloqade.stim.rewrite.util import (
11
+ SQUIN_STIM_CONTROL_GATE_MAPPING,
12
+ insert_qubit_idx_from_address,
13
+ )
14
+ from bloqade.stim.dialects.auxiliary import GetRecord
15
+ from bloqade.analysis.measure_id.lattice import (
16
+ MeasureId,
17
+ MeasureIdBool,
18
+ )
19
+
20
+
21
+ @dataclass
22
+ class IfElseSimplification:
23
+
24
+ # Might be better to just do a rewrite_Region?
25
+ def is_rewriteable(self, node: scf.IfElse) -> bool:
26
+ return not (
27
+ self.contains_ifelse(node)
28
+ or self.is_nested_ifelse(node)
29
+ or self.has_else_body(node)
30
+ )
31
+
32
+ # A preliminary check to reject an IfElse from the "top down"
33
+ # use in conjunction with is_nested_ifelse
34
+ # to completely cover cases of nested IfElse statements
35
+ def contains_ifelse(self, stmt: scf.IfElse) -> bool:
36
+ """Check if the IfElse statement contains another IfElse statement."""
37
+ for child in stmt.walk(include_self=False):
38
+ if isinstance(child, scf.IfElse):
39
+ return True
40
+ return False
41
+
42
+ # because rewrite latches onto ANY scf.IfElse,
43
+ # you need a way to determine if you're touching an
44
+ # IfElse that's inside another IfElse
45
+ def is_nested_ifelse(self, stmt: scf.IfElse) -> bool:
46
+ """Check if the IfElse statement is nested."""
47
+ if stmt.parent_stmt is not None:
48
+ if isinstance(stmt.parent_stmt, scf.IfElse) or isinstance(
49
+ stmt.parent_stmt.parent_stmt, scf.IfElse
50
+ ):
51
+ return True
52
+ else:
53
+ return False
54
+ else:
55
+ return False
56
+
57
+ def has_else_body(self, stmt: scf.IfElse) -> bool:
58
+ """Check if the IfElse statement has an else body."""
59
+ if stmt.else_body.blocks and not (
60
+ len(stmt.else_body.blocks[0].stmts) == 1
61
+ and isinstance(else_term := stmt.else_body.blocks[0].last_stmt, scf.Yield)
62
+ and not else_term.values # empty yield
63
+ ):
64
+ return True
65
+
66
+ return False
67
+
68
+
69
+ DontLiftType = (
70
+ qubit.Apply,
71
+ qubit.Broadcast,
72
+ scf.Yield,
73
+ func.Return,
74
+ func.Invoke,
75
+ scf.IfElse,
76
+ )
77
+
78
+
79
+ @dataclass
80
+ class StimLiftThenBody(IfElseSimplification, LiftThenBody):
81
+ exclude_stmts: tuple[type[ir.Statement], ...] = field(default=DontLiftType)
82
+
83
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
84
+
85
+ if not isinstance(node, scf.IfElse):
86
+ return RewriteResult()
87
+
88
+ if not self.is_rewriteable(node):
89
+ return RewriteResult()
90
+
91
+ return super().rewrite_Statement(node)
92
+
93
+
94
+ # Only run this after everything other than qubit.Apply/qubit.Broadcast has been
95
+ # lifted out!
96
+ class StimSplitIfStmts(IfElseSimplification, SplitIfStmts):
97
+ """Splits the then body of an if-else statement into multiple if statements
98
+
99
+ Given an IfElse with multiple valid statements in the then-body:
100
+
101
+ if measure_result:
102
+ squin.qubit.apply(op.X, q0)
103
+ squin.qubit.apply(op.Y, q1)
104
+
105
+ this should be rewritten to:
106
+
107
+ if measure_result:
108
+ squin.qubit.apply(op.X, q0)
109
+
110
+ if measure_result:
111
+ squin.qubit.apply(op.Y, q1)
112
+ """
113
+
114
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
115
+ if not isinstance(node, scf.IfElse):
116
+ return RewriteResult()
117
+
118
+ if not self.is_rewriteable(node):
119
+ return RewriteResult()
120
+
121
+ return super().rewrite_Statement(node)
122
+
123
+
124
+ @dataclass
125
+ class IfToStim(IfElseSimplification, RewriteRule):
126
+ """
127
+ Rewrite if statements to stim equivalent statements.
128
+ """
129
+
130
+ measure_analysis: dict[ir.SSAValue, MeasureId]
131
+ measure_count: int
132
+
133
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
134
+
135
+ match node:
136
+ case scf.IfElse():
137
+ return self.rewrite_IfElse(node)
138
+ case _:
139
+ return RewriteResult()
140
+
141
+ def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
142
+
143
+ if not isinstance(self.measure_analysis[stmt.cond], MeasureIdBool):
144
+ return RewriteResult()
145
+
146
+ # check that there is only qubit.Apply in the then-body,
147
+ # if there's more than that, we can't do a valid rewrite.
148
+ # Can reuse logic from SplitIf
149
+ *stmts, _ = stmt.then_body.stmts()
150
+ if len(stmts) != 1 or not isinstance(stmts[0], (qubit.Apply, qubit.Broadcast)):
151
+ return RewriteResult()
152
+
153
+ apply_or_broadcast = stmts[0]
154
+ # Check that the gate being applied/broadcasted can be converted to a stim
155
+ # controlled gate.
156
+ ctrl_op_target_gate = apply_or_broadcast.operator.owner
157
+ assert isinstance(ctrl_op_target_gate, op.stmts.Operator)
158
+
159
+ stim_gate = SQUIN_STIM_CONTROL_GATE_MAPPING.get(type(ctrl_op_target_gate))
160
+ if stim_gate is None:
161
+ return RewriteResult()
162
+
163
+ # get necessary measurement ID type from analysis
164
+ measure_id_bool = self.measure_analysis[stmt.cond]
165
+ assert isinstance(measure_id_bool, MeasureIdBool)
166
+
167
+ # generate get record statement
168
+ measure_id_idx_stmt = py.Constant(
169
+ (measure_id_bool.idx - 1) - self.measure_count
170
+ )
171
+ get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841
172
+
173
+ # get address attribute and generate qubit idx statements
174
+ address_attr = apply_or_broadcast.qubits.hints.get("address")
175
+ if address_attr is None:
176
+ return RewriteResult()
177
+ assert isinstance(address_attr, AddressAttribute)
178
+
179
+ # note: insert things before (literally above/outside) the If
180
+ qubit_idx_ssas = insert_qubit_idx_from_address(
181
+ address=address_attr, stmt_to_insert_before=stmt
182
+ )
183
+ if qubit_idx_ssas is None:
184
+ return RewriteResult()
185
+
186
+ # Assemble the stim statement
187
+ # let GetRecord's SSA be repeated per each get qubit
188
+ ctrl_records = tuple(get_record_stmt.result for _ in qubit_idx_ssas)
189
+
190
+ stim_stmt = stim_gate(
191
+ targets=tuple(qubit_idx_ssas),
192
+ controls=ctrl_records,
193
+ )
194
+
195
+ # Insert the necessary SSA Values, then get rid of the scf.IfElse.
196
+ # The qubit indices have been successfully added,
197
+ # that just leaves the GetRecord statement and measurement ID index statement
198
+
199
+ measure_id_idx_stmt.insert_before(stmt)
200
+ get_record_stmt.insert_before(stmt)
201
+ stmt.replace_by(stim_stmt)
202
+
203
+ return RewriteResult(has_done_something=True)
@@ -13,6 +13,9 @@ from bloqade.stim.rewrite.util import (
13
13
 
14
14
 
15
15
  class SquinQubitToStim(RewriteRule):
16
+ """
17
+ NOTE this require address analysis result to be wrapped before using this rule.
18
+ """
16
19
 
17
20
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
18
21
 
@@ -1,22 +1,59 @@
1
1
  # create rewrite rule name SquinMeasureToStim using kirin
2
+ from dataclasses import dataclass
3
+
2
4
  from kirin import ir
3
- from kirin.dialects import py
5
+ from kirin.dialects import py, ilist
4
6
  from kirin.rewrite.abc import RewriteRule, RewriteResult
5
7
 
6
8
  from bloqade.squin import wire, qubit
7
9
  from bloqade.squin.rewrite import AddressAttribute
8
- from bloqade.stim.dialects import collapse
10
+ from bloqade.stim.dialects import collapse, auxiliary
9
11
  from bloqade.stim.rewrite.util import (
10
12
  is_measure_result_used,
11
13
  insert_qubit_idx_from_address,
12
14
  )
15
+ from bloqade.analysis.measure_id.lattice import MeasureId, MeasureIdBool, MeasureIdTuple
16
+
17
+
18
+ def replace_get_record(
19
+ node: ir.Statement, measure_id_bool: MeasureIdBool, meas_count: int
20
+ ):
21
+ assert isinstance(measure_id_bool, MeasureIdBool)
22
+ target_rec_idx = (measure_id_bool.idx - 1) - meas_count
23
+ idx_stmt = py.constant.Constant(target_rec_idx)
24
+ idx_stmt.insert_before(node)
25
+ get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
26
+ node.replace_by(get_record_stmt)
27
+
28
+
29
+ def insert_get_record_list(
30
+ node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count: int
31
+ ):
32
+ """
33
+ Insert GetRecord statements before the given node
34
+ """
35
+ get_record_ssas = []
36
+ for measure_id_bool in measure_id_tuple.data:
37
+ assert isinstance(measure_id_bool, MeasureIdBool)
38
+ target_rec_idx = (measure_id_bool.idx - 1) - meas_count
39
+ idx_stmt = py.constant.Constant(target_rec_idx)
40
+ idx_stmt.insert_before(node)
41
+ get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
42
+ get_record_stmt.insert_before(node)
43
+ get_record_ssas.append(get_record_stmt.result)
13
44
 
45
+ node.replace_by(ilist.New(values=get_record_ssas))
14
46
 
47
+
48
+ @dataclass
15
49
  class SquinMeasureToStim(RewriteRule):
16
50
  """
17
51
  Rewrite squin measure-related statements to stim statements.
18
52
  """
19
53
 
54
+ measure_id_result: dict[ir.SSAValue, MeasureId]
55
+ total_measure_count: int
56
+
20
57
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
21
58
 
22
59
  match node:
@@ -28,20 +65,46 @@ class SquinMeasureToStim(RewriteRule):
28
65
  def rewrite_Measure(
29
66
  self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
30
67
  ) -> RewriteResult:
31
- if is_measure_result_used(measure_stmt):
32
- return RewriteResult()
33
68
 
34
69
  qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt)
35
70
  if qubit_idx_ssas is None:
36
71
  return RewriteResult()
37
72
 
73
+ measure_id = self.measure_id_result[measure_stmt.result]
74
+ if not isinstance(measure_id, (MeasureIdBool, MeasureIdTuple)):
75
+ return RewriteResult()
76
+
38
77
  prob_noise_stmt = py.constant.Constant(0.0)
39
78
  stim_measure_stmt = collapse.MZ(
40
79
  p=prob_noise_stmt.result,
41
80
  targets=qubit_idx_ssas,
42
81
  )
43
82
  prob_noise_stmt.insert_before(measure_stmt)
44
- measure_stmt.replace_by(stim_measure_stmt)
83
+ stim_measure_stmt.insert_before(measure_stmt)
84
+
85
+ if not is_measure_result_used(measure_stmt):
86
+ measure_stmt.delete()
87
+ return RewriteResult(has_done_something=True)
88
+
89
+ # replace dataflow with new stmt!
90
+ measure_id = self.measure_id_result[measure_stmt.result]
91
+ if isinstance(measure_id, MeasureIdBool):
92
+ replace_get_record(
93
+ node=measure_stmt,
94
+ measure_id_bool=measure_id,
95
+ meas_count=self.total_measure_count,
96
+ )
97
+ elif isinstance(measure_id, MeasureIdTuple):
98
+ insert_get_record_list(
99
+ node=measure_stmt,
100
+ measure_id_tuple=measure_id,
101
+ meas_count=self.total_measure_count,
102
+ )
103
+ else:
104
+ # already checked before, so this should not happen
105
+ raise ValueError(
106
+ f"Unexpected measure ID type: {type(measure_id)} for measure statement {measure_stmt}"
107
+ )
45
108
 
46
109
  return RewriteResult(has_done_something=True)
47
110
 
@@ -182,10 +182,6 @@ def rewrite_QubitLoss(
182
182
  create_wire_passthrough(stmt)
183
183
 
184
184
  stmt.replace_by(stim_loss_stmt)
185
- # NoiseChannels are not pure,
186
- # need to manually delete because
187
- # DCE won't touch them
188
- stmt.operator.owner.delete()
189
185
 
190
186
  return RewriteResult(has_done_something=True)
191
187
 
@@ -0,0 +1 @@
1
+ from .from_squin import squin_to_stim as squin_to_stim
@@ -0,0 +1,10 @@
1
+ from kirin import ir
2
+
3
+ from ..groups import main
4
+ from ..passes.squin_to_stim import SquinToStimPass
5
+
6
+
7
+ def squin_to_stim(mt: ir.Method) -> ir.Method:
8
+ new_mt = mt.similar()
9
+ SquinToStimPass(mt.dialects, no_raise=False)(new_mt)
10
+ return new_mt.similar(dialects=main)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bloqade-circuit
3
- Version: 0.5.0
3
+ Version: 0.5.2
4
4
  Summary: The software development toolkit for neutral atom arrays.
5
5
  Author-email: Roger-luo <rluo@quera.com>, kaihsin <khwu@quera.com>, weinbe58 <pweinberg@quera.com>, johnzl-777 <jlong@quera.com>
6
6
  License-File: LICENSE