bloqade-circuit 0.4.5__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic; see the package registry's advisory for more details.

Files changed (61):
  1. bloqade/analysis/address/impls.py +21 -68
  2. bloqade/analysis/measure_id/__init__.py +2 -0
  3. bloqade/analysis/measure_id/analysis.py +45 -0
  4. bloqade/analysis/measure_id/impls.py +155 -0
  5. bloqade/analysis/measure_id/lattice.py +82 -0
  6. bloqade/cirq_utils/__init__.py +7 -0
  7. bloqade/cirq_utils/lineprog.py +295 -0
  8. bloqade/cirq_utils/parallelize.py +400 -0
  9. bloqade/pyqrack/squin/op.py +7 -2
  10. bloqade/pyqrack/squin/runtime.py +4 -2
  11. bloqade/qasm2/dialects/expr/stmts.py +2 -20
  12. bloqade/qasm2/parse/lowering.py +1 -0
  13. bloqade/qasm2/passes/parallel.py +18 -0
  14. bloqade/qasm2/passes/unroll_if.py +9 -2
  15. bloqade/qasm2/rewrite/__init__.py +1 -0
  16. bloqade/qasm2/rewrite/parallel_to_glob.py +82 -0
  17. bloqade/rewrite/__init__.py +0 -0
  18. bloqade/rewrite/passes/__init__.py +1 -0
  19. bloqade/rewrite/passes/canonicalize_ilist.py +28 -0
  20. bloqade/rewrite/rules/__init__.py +1 -0
  21. bloqade/rewrite/rules/flatten_ilist.py +51 -0
  22. bloqade/rewrite/rules/inline_getitem_ilist.py +31 -0
  23. bloqade/{qasm2/rewrite → rewrite/rules}/split_ifs.py +15 -8
  24. bloqade/squin/__init__.py +2 -0
  25. bloqade/squin/_typeinfer.py +20 -0
  26. bloqade/squin/analysis/__init__.py +1 -0
  27. bloqade/squin/analysis/address_impl.py +71 -0
  28. bloqade/squin/analysis/nsites/impls.py +6 -1
  29. bloqade/squin/cirq/lowering.py +19 -6
  30. bloqade/squin/noise/stmts.py +1 -1
  31. bloqade/squin/op/__init__.py +1 -0
  32. bloqade/squin/op/_wrapper.py +4 -0
  33. bloqade/squin/op/stmts.py +20 -2
  34. bloqade/squin/qubit.py +8 -5
  35. bloqade/squin/rewrite/__init__.py +1 -0
  36. bloqade/squin/rewrite/canonicalize.py +60 -0
  37. bloqade/squin/rewrite/desugar.py +52 -5
  38. bloqade/squin/types.py +8 -0
  39. bloqade/squin/wire.py +91 -5
  40. bloqade/stim/__init__.py +1 -0
  41. bloqade/stim/_wrappers.py +4 -0
  42. bloqade/stim/dialects/auxiliary/interp.py +0 -10
  43. bloqade/stim/dialects/auxiliary/stmts/annotate.py +1 -1
  44. bloqade/stim/dialects/noise/emit.py +1 -0
  45. bloqade/stim/dialects/noise/stmts.py +5 -0
  46. bloqade/stim/passes/__init__.py +1 -1
  47. bloqade/stim/passes/simplify_ifs.py +32 -0
  48. bloqade/stim/passes/squin_to_stim.py +109 -26
  49. bloqade/stim/rewrite/__init__.py +1 -0
  50. bloqade/stim/rewrite/ifs_to_stim.py +203 -0
  51. bloqade/stim/rewrite/qubit_to_stim.py +13 -6
  52. bloqade/stim/rewrite/squin_measure.py +68 -5
  53. bloqade/stim/rewrite/squin_noise.py +120 -0
  54. bloqade/stim/rewrite/util.py +40 -9
  55. bloqade/stim/rewrite/wire_to_stim.py +8 -3
  56. bloqade/stim/upstream/__init__.py +1 -0
  57. bloqade/stim/upstream/from_squin.py +10 -0
  58. {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/METADATA +4 -2
  59. {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/RECORD +61 -38
  60. {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/WHEEL +0 -0
  61. {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,60 @@
1
+ from typing import cast
2
+
3
+ from kirin import ir
4
+ from kirin.rewrite import abc
5
+ from kirin.dialects import cf
6
+
7
+ from .. import wire
8
+
9
+
class CanonicalizeWired(abc.RewriteRule):
    """Inline ``wire.Wired`` statements that wrap no qubits.

    When a ``Wired`` statement has an empty ``qubits`` tuple, its body is
    spliced into the enclosing region's control-flow graph:

    * every statement that followed the ``Wired`` moves into a new
      continuation block, whose arguments replace the ``Wired`` results;
    * the parent block ends with a ``cf.Branch`` into the body's entry block;
    * the body blocks are placed between the parent block and the
      continuation, and each block ending in ``wire.Yield`` instead branches
      to the continuation, passing the yielded values.
    """

    def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
        # Only qubit-free Wired statements that live inside a region qualify.
        if not isinstance(node, wire.Wired) or len(node.qubits) != 0:
            return abc.RewriteResult()
        parent_region = node.parent_region
        if parent_region is None:
            return abc.RewriteResult()

        parent_block = cast(ir.Block, node.parent_block)

        # Continuation block: collects everything that came after `node`.
        continuation = ir.Block()
        while (trailing := node.next_stmt) is not None:
            trailing.detach()
            continuation.stmts.append(trailing)

        # Uses of the node's results now read the continuation's arguments.
        for res in node.results:
            res.replace_by(continuation.args.append_from(res.type, res.name))

        insert_at = parent_region._block_idx[parent_block] + 1
        parent_region.blocks.insert(insert_at, continuation)

        # Jump from the parent block into the body's entry block.
        parent_block.stmts.append(
            cf.Branch(arguments=(), successor=node.body.blocks[0])
        )

        # Move the body blocks in last-to-first order, each inserted right
        # after the parent block, so they keep their original relative order
        # and end up in front of the continuation block.
        for body_block in reversed(node.body.blocks):
            body_block.detach()
            terminator = body_block.last_stmt
            if isinstance(terminator, wire.Yield):
                terminator.replace_by(
                    cf.Branch(terminator.values, successor=continuation)
                )
            parent_region.blocks.insert(insert_at, body_block)

        node.delete()
        return abc.RewriteResult(has_done_something=True)
@@ -1,5 +1,5 @@
1
1
  from kirin import ir, types
2
- from kirin.dialects import ilist
2
+ from kirin.dialects import py, ilist
3
3
  from kirin.rewrite.abc import RewriteRule, RewriteResult
4
4
 
5
5
  from bloqade.squin.qubit import (
@@ -53,12 +53,59 @@ class ApplyDesugarRule(RewriteRule):
53
53
  op = node.operator
54
54
  qubits = node.qubits
55
55
 
56
- if len(qubits) == 1 and qubits[0].type.is_subseteq(ilist.IListType):
57
- # NOTE: already calling with just a single argument that is already an ilist
56
+ if len(qubits) > 1 and all(q.type.is_subseteq(QubitType) for q in qubits):
57
+ (qubits_ilist_stmt := ilist.New(qubits)).insert_before(node)
58
+ qubits_ilist = qubits_ilist_stmt.result
59
+
60
+ elif len(qubits) == 1 and qubits[0].type.is_subseteq(QubitType):
61
+ (qubits_ilist_stmt := ilist.New(qubits)).insert_before(node)
62
+ qubits_ilist = qubits_ilist_stmt.result
63
+
64
+ elif len(qubits) == 1 and qubits[0].type.is_subseteq(
65
+ ilist.IListType[QubitType, types.Any]
66
+ ):
58
67
  qubits_ilist = qubits[0]
59
- else:
60
- (qubits_ilist_stmt := ilist.New(values=qubits)).insert_before(node)
68
+
69
+ elif len(qubits) == 1:
70
+ # TODO: remove this elif clause once we're at kirin v0.18
71
+ # NOTE: this is a temporary workaround for kirin#408
72
+ # currently type inference fails here in for loops since the loop var
73
+ # is an IList for some reason
74
+
75
+ if not isinstance(qubits[0], ir.ResultValue):
76
+ return RewriteResult()
77
+
78
+ is_ilist = isinstance(qbit_stmt := qubits[0].stmt, ilist.New)
79
+ if is_ilist:
80
+ if len(qbit_stmt.values) != 1:
81
+ return RewriteResult()
82
+
83
+ if not isinstance(
84
+ qbit_getindex_result := qbit_stmt.values[0], ir.ResultValue
85
+ ):
86
+ return RewriteResult()
87
+
88
+ qbit_getindex = qbit_getindex_result.stmt
89
+ else:
90
+ qbit_getindex = qubits[0].stmt
91
+
92
+ if not isinstance(qbit_getindex, py.indexing.GetItem):
93
+ return RewriteResult()
94
+
95
+ if not qbit_getindex.obj.type.is_subseteq(
96
+ ilist.IListType[QubitType, types.Any]
97
+ ):
98
+ return RewriteResult()
99
+
100
+ if is_ilist:
101
+ values = qbit_stmt.values
102
+ else:
103
+ values = [qubits[0]]
104
+
105
+ (qubits_ilist_stmt := ilist.New(values=values)).insert_before(node)
61
106
  qubits_ilist = qubits_ilist_stmt.result
107
+ else:
108
+ return RewriteResult()
62
109
 
63
110
  stmt = Apply(operator=op, qubits=qubits_ilist)
64
111
  node.replace_by(stmt)
bloqade/squin/types.py ADDED
@@ -0,0 +1,8 @@
1
+ from kirin import types
2
+
3
+
class MeasurementResult:
    """Python class backing ``MeasurementResultType``; it holds no data."""


MeasurementResultType = types.PyClass(MeasurementResult)
bloqade/squin/wire.py CHANGED
@@ -6,12 +6,15 @@ circuits. Thus we do not define wrapping functions for the statements in this
6
6
  dialect.
7
7
  """
8
8
 
9
- from kirin import ir, types, lowering
9
+ from kirin import ir, types, lowering, exception
10
10
  from kirin.decl import info, statement
11
+ from kirin.dialects import func
11
12
  from kirin.lowering import wraps
13
+ from kirin.ir.attrs.types import TypeAttribute
12
14
 
13
15
  from bloqade.types import Qubit, QubitType
14
16
 
17
+ from .types import MeasurementResultType
15
18
  from .op.types import Op, OpType
16
19
 
17
20
  # from kirin.lowering import wraps
@@ -49,11 +52,87 @@ class Unwrap(ir.Statement):
49
52
  result: ir.ResultValue = info.result(WireType)
50
53
 
51
54
 
@statement(dialect=dialect)
class Wired(ir.Statement):
    """Region that operates on the wires of ``qubits``.

    ``body`` may contain several blocks. ``check`` enforces the following:

    * the body has at least one block;
    * the entry block takes one ``WireType`` argument per qubit in ``qubits``;
    * no block ends in ``func.Return``;
    * every block ending in ``Yield`` yields as many values as this statement
      has results.

    ``memory_zone`` is stored as a plain string attribute.
    """

    traits = frozenset()

    qubits: tuple[ir.SSAValue, ...] = info.argument(QubitType)
    memory_zone: str = info.attribute()
    body: ir.Region = info.region(multi=True)

    def __init__(
        self,
        body: ir.Region,
        *qubits: ir.SSAValue,
        memory_zone: str,
        result_types: tuple[TypeAttribute, ...] | None = None,
    ):
        """Build a ``Wired`` statement.

        Args:
            body: Region whose entry block receives the wires of ``qubits``.
            *qubits: Qubit SSA values passed as the statement's arguments.
            memory_zone: Value stored in the ``memory_zone`` attribute.
            result_types: Types of the statement's results. When omitted, they
                are taken from the values of the first block (in region order)
                that ends in a ``Yield``. If no block does, the statement has
                no results.
        """
        if result_types is None:
            result_types = ()
            for block in body.blocks:
                terminator = block.last_stmt
                if isinstance(terminator, Yield):
                    result_types = tuple(value.type for value in terminator.values)
                    break

        super().__init__(
            args=qubits,
            args_slice={"qubits": slice(0, None)},
            regions=[body],
            attributes={"memory_zone": ir.PyAttr(memory_zone)},
            result_types=result_types,
        )

    def check(self):
        """Validate the structure of ``body``.

        Raises:
            exception.StaticCheckError: if ``body`` has no blocks; if the
                entry block's arguments do not match ``qubits`` in number or
                are not wires; if any block ends in ``func.Return``; or if a
                ``Yield`` yields a different number of values than this
                statement has results.
        """
        # Guard before indexing. An empty region would otherwise surface as
        # a bare IndexError rather than a verification error.
        if len(self.body.blocks) == 0:
            raise exception.StaticCheckError(
                "The body of a Wired statement must contain at least one block."
            )
        entry_block = self.body.blocks[0]

        if len(entry_block.args) != len(self.qubits):
            raise exception.StaticCheckError(
                f"Expected {len(self.qubits)} arguments, got {len(entry_block.args)}."
            )
        for arg in entry_block.args:
            if not arg.type.is_subseteq(WireType):
                raise exception.StaticCheckError(
                    f"Expected argument of type {WireType}, got {arg.type}."
                )
        for block in self.body.blocks:
            last_stmt = block.last_stmt
            if isinstance(last_stmt, func.Return):
                raise exception.StaticCheckError(
                    "Return statements are not allowed in the body of a Wired statement."
                )
            elif isinstance(last_stmt, Yield) and len(last_stmt.values) != len(
                self.results
            ):
                raise exception.StaticCheckError(
                    f"Expected {len(self.results)} return values, got {len(last_stmt.values)}."
                )
115
+
116
+
@statement(dialect=dialect)
class Yield(ir.Statement):
    """Hand ``values`` out of a block of a ``Wired`` body.

    ``Wired`` infers its result types from the values of the first block that
    ends in a ``Yield``.
    """

    traits = frozenset()
    values: tuple[ir.SSAValue, ...] = info.argument(WireType)

    def __init__(self, *args: ir.SSAValue):
        # Every positional argument belongs to the variadic `values` field.
        super().__init__(args=args, args_slice={"values": slice(0, None)})
129
+
130
+
52
131
  # In Quake, you put a wire in and get a wire out when you "apply" an operator
53
132
  # In this case though we just need to indicate that an operator is applied to list[wires]
54
133
  @statement(dialect=dialect)
55
134
  class Apply(ir.Statement): # apply(op, w1, w2, ...)
56
- traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
135
+ traits = frozenset({lowering.FromPythonCall()})
57
136
  operator: ir.SSAValue = info.argument(OpType)
58
137
  inputs: tuple[ir.SSAValue, ...] = info.argument(WireType)
59
138
 
@@ -88,6 +167,13 @@ class Broadcast(ir.Statement):
88
167
  ) # custom lowering required for wrapper to work here
89
168
 
90
169
 
@statement(dialect=dialect)
class RegionMeasure(ir.Statement):
    """Measure a single wire, yielding a ``MeasurementResultType`` value.

    Unlike ``Measure``, this statement takes only the wire, with no qubit
    operand.
    """

    traits = frozenset({WireTerminator(), lowering.FromPythonCall()})
    # Field order is significant for the statement declaration; keep it.
    wire: ir.SSAValue = info.argument(WireType)
    result: ir.ResultValue = info.result(MeasurementResultType)
175
+
176
+
91
177
  # NOTE: measurement cannot be pure because they will collapse the state
92
178
  # of the qubit. The state is a hidden state that is not visible to
93
179
  # the user in the wire dialect.
@@ -96,14 +182,14 @@ class Measure(ir.Statement):
96
182
  traits = frozenset({lowering.FromPythonCall(), WireTerminator()})
97
183
  wire: ir.SSAValue = info.argument(WireType)
98
184
  qubit: ir.SSAValue = info.argument(QubitType)
99
- result: ir.ResultValue = info.result(types.Int)
185
+ result: ir.ResultValue = info.result(MeasurementResultType)
100
186
 
101
187
 
102
188
  @statement(dialect=dialect)
103
- class NonDestructiveMeasure(ir.Statement):
189
+ class LossResolvingMeasure(ir.Statement):
104
190
  traits = frozenset({lowering.FromPythonCall()})
105
191
  input_wire: ir.SSAValue = info.argument(WireType)
106
- result: ir.ResultValue = info.result(types.Int)
192
+ result: ir.ResultValue = info.result(MeasurementResultType)
107
193
  out_wire: ir.ResultValue = info.result(WireType)
108
194
 
109
195
 
bloqade/stim/__init__.py CHANGED
@@ -31,6 +31,7 @@ from ._wrappers import (
31
31
  z_error as z_error,
32
32
  detector as detector,
33
33
  identity as identity,
34
+ qubit_loss as qubit_loss,
34
35
  depolarize1 as depolarize1,
35
36
  depolarize2 as depolarize2,
36
37
  pauli_string as pauli_string,
bloqade/stim/_wrappers.py CHANGED
@@ -190,3 +190,7 @@ def y_error(p: float, targets: tuple[int, ...]) -> None: ...
190
190
 
@wraps(noise.ZError)
def z_error(p: float, targets: tuple[int, ...]) -> None:
    """Python-side stand-in for the ``noise.ZError`` statement (bound via ``@wraps``)."""
    ...
193
+
194
+
@wraps(noise.QubitLoss)
def qubit_loss(probs: tuple[float, ...], targets: tuple[int, ...]) -> None:
    """Python-side stand-in for the ``noise.QubitLoss`` statement (bound via ``@wraps``)."""
    ...
@@ -1,7 +1,6 @@
1
1
  from kirin import interp
2
2
 
3
3
  from . import stmts
4
- from .types import RecordResult
5
4
  from ._dialect import dialect
6
5
 
7
6
 
@@ -28,12 +27,3 @@ class StimAuxMethods(interp.MethodTable):
28
27
  stmt: stmts.Neg,
29
28
  ):
30
29
  return (-frame.get(stmt.operand),)
31
-
32
- @interp.impl(stmts.GetRecord)
33
- def get_rec(
34
- self,
35
- interpreter: interp.Interpreter,
36
- frame: interp.Frame,
37
- stmt: stmts.GetRecord,
38
- ):
39
- return (RecordResult(value=frame.get(stmt.id)),)
@@ -10,7 +10,7 @@ PyNum = types.Union(types.Int, types.Float)
10
10
  @statement(dialect=dialect)
11
11
  class GetRecord(ir.Statement):
12
12
  name = "get_rec"
13
- traits = frozenset({lowering.FromPythonCall()})
13
+ traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
14
14
  id: ir.SSAValue = info.argument(type=types.Int)
15
15
  result: ir.ResultValue = info.result(type=RecordType)
16
16
 
@@ -66,6 +66,7 @@ class EmitStimNoiseMethods(MethodTable):
66
66
  return ()
67
67
 
68
68
  @impl(stmts.TrivialError)
69
+ @impl(stmts.QubitLoss)
69
70
  def non_stim_error(
70
71
  self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.TrivialError
71
72
  ):
@@ -104,3 +104,8 @@ class TrivialCorrelatedError(NonStimCorrelatedError):
@statement(dialect=dialect)
class TrivialError(NonStimError):
    """``NonStimError`` subclass whose statement name is ``TRIV_ERROR``."""

    name = "TRIV_ERROR"
107
+
108
+
@statement(dialect=dialect)
class QubitLoss(NonStimError):
    """``NonStimError`` subclass for qubit loss, with statement name ``loss``.

    The Stim emitter handles it with the same ``non_stim_error`` method as
    ``TrivialError``.
    """

    name = "loss"
@@ -1 +1 @@
1
- from .squin_to_stim import SquinToStim as SquinToStim
1
+ from .squin_to_stim import SquinToStimPass as SquinToStimPass
@@ -0,0 +1,32 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.passes import Pass
5
+ from kirin.rewrite import (
6
+ Walk,
7
+ Chain,
8
+ Fixpoint,
9
+ ConstantFold,
10
+ CommonSubexpressionElimination,
11
+ )
12
+
13
+ from ..rewrite.ifs_to_stim import StimLiftThenBody, StimSplitIfStmts
14
+
15
+
@dataclass
class StimSimplifyIfs(Pass):
    """Restructure ``if`` statements ahead of the Stim conversion.

    The pass runs three stages in order:

    1. ``StimLiftThenBody`` to a fixpoint.
    2. One walk of ``StimSplitIfStmts``.
    3. Constant folding plus common-subexpression elimination to a fixpoint.

    The combined ``RewriteResult`` of all stages is returned.
    """

    def unsafe_run(self, mt: ir.Method):
        restructure = Chain(
            Fixpoint(Walk(StimLiftThenBody())),
            Walk(StimSplitIfStmts()),
        )
        restructured = restructure.rewrite(mt.code)

        cleanup = Fixpoint(
            Walk(Chain(ConstantFold(), CommonSubexpressionElimination()))
        )
        return cleanup.rewrite(mt.code).join(restructured)
@@ -5,58 +5,153 @@ from kirin.rewrite import (
5
5
  Walk,
6
6
  Chain,
7
7
  Fixpoint,
8
+ CFGCompactify,
9
+ InlineGetItem,
10
+ InlineGetField,
8
11
  DeadCodeElimination,
9
12
  CommonSubexpressionElimination,
10
13
  )
14
+ from kirin.analysis import const
15
+ from kirin.dialects import scf, ilist
11
16
  from kirin.ir.method import Method
12
17
  from kirin.passes.abc import Pass
13
18
  from kirin.rewrite.abc import RewriteResult
19
+ from kirin.passes.inline import InlinePass
14
20
 
15
- from bloqade.stim.groups import main as stim_main_group
16
21
  from bloqade.stim.rewrite import (
17
22
  SquinWireToStim,
18
23
  PyConstantToStim,
24
+ SquinNoiseToStim,
19
25
  SquinQubitToStim,
20
26
  SquinMeasureToStim,
21
27
  SquinWireIdentityElimination,
22
28
  )
23
- from bloqade.squin.rewrite import RemoveDeadRegister
29
+ from bloqade.squin.rewrite import (
30
+ SquinU3ToClifford,
31
+ RemoveDeadRegister,
32
+ WrapAddressAnalysis,
33
+ )
34
+ from bloqade.rewrite.passes import CanonicalizeIList
35
+ from bloqade.analysis.address import AddressAnalysis
36
+ from bloqade.analysis.measure_id import MeasurementIDAnalysis
37
+
38
+ from .simplify_ifs import StimSimplifyIfs
39
+ from ..rewrite.ifs_to_stim import IfToStim
24
40
 
25
41
 
26
42
  @dataclass
27
- class SquinToStim(Pass):
43
+ class SquinToStimPass(Pass):
28
44
 
29
45
  def unsafe_run(self, mt: Method) -> RewriteResult:
30
- fold_pass = Fold(mt.dialects)
31
- # propagate constants
32
- rewrite_result = fold_pass(mt)
46
+
47
+ cp_frame, _ = const.Propagate(dialects=mt.dialects).run_analysis(mt)
48
+ cp_results = cp_frame.entries
33
49
 
34
50
  # Assume that address analysis and
35
51
  # wrapping has been done before this pass!
52
+ # inline aggressively:
53
+ rewrite_result = InlinePass(
54
+ dialects=mt.dialects, no_raise=self.no_raise
55
+ ).unsafe_run(mt)
56
+
57
+ rule = Chain(
58
+ InlineGetField(),
59
+ InlineGetItem(),
60
+ scf.unroll.ForLoop(),
61
+ scf.trim.UnusedYield(),
62
+ )
63
+ rewrite_result = Fixpoint(Walk(rule)).rewrite(mt.code).join(rewrite_result)
64
+ # fold_pass = Fold(mt.dialects, no_raise=self.no_raise)
65
+ # rewrite_result = fold_pass(mt)
66
+ rewrite_result = (
67
+ Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(rewrite_result)
68
+ )
69
+ rewrite_result = (
70
+ StimSimplifyIfs(mt.dialects, no_raise=self.no_raise)
71
+ .unsafe_run(mt)
72
+ .join(rewrite_result)
73
+ )
74
+
75
+ # run typeinfer again after unroll etc. because we now insert
76
+ # a lot of new nodes, which might have more precise types
77
+ # self.typeinfer.unsafe_run(mt)
78
+ rewrite_result = (
79
+ Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
80
+ .rewrite(mt.code)
81
+ .join(rewrite_result)
82
+ )
83
+ rewrite_result = Fold(mt.dialects, no_raise=self.no_raise)(mt)
84
+
85
+ rewrite_result = (
86
+ CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
87
+ .unsafe_run(mt)
88
+ .join(rewrite_result)
89
+ )
90
+
91
+ # after this the program should be in a state where it is analyzable
92
+ # -------------------------------------------------------------------
93
+
94
+ mia = MeasurementIDAnalysis(dialects=mt.dialects)
95
+ meas_analysis_frame, _ = mia.run_analysis(mt, no_raise=self.no_raise)
96
+
97
+ aa = AddressAnalysis(dialects=mt.dialects)
98
+ address_analysis_frame, _ = aa.run_analysis(mt, no_raise=self.no_raise)
99
+
100
+ # wrap the address analysis result
101
+ rewrite_result = (
102
+ Walk(WrapAddressAnalysis(address_analysis=address_analysis_frame.entries))
103
+ .rewrite(mt.code)
104
+ .join(rewrite_result)
105
+ )
106
+
107
+ # 2. rewrite
108
+ rewrite_result = (
109
+ Walk(
110
+ IfToStim(
111
+ measure_analysis=meas_analysis_frame.entries,
112
+ measure_count=mia.measure_count,
113
+ )
114
+ )
115
+ .rewrite(mt.code)
116
+ .join(rewrite_result)
117
+ )
118
+
119
+ # Rewrite the noise statements first.
120
+ rewrite_result = (
121
+ Walk(SquinNoiseToStim(cp_results=cp_results))
122
+ .rewrite(mt.code)
123
+ .join(rewrite_result)
124
+ )
36
125
 
37
126
  # Wrap Rewrite + SquinToStim can happen w/ standard walk
127
+ rewrite_result = Walk(SquinU3ToClifford()).rewrite(mt.code).join(rewrite_result)
128
+
38
129
  rewrite_result = (
39
130
  Walk(
40
131
  Chain(
41
132
  SquinQubitToStim(),
42
133
  SquinWireToStim(),
43
- SquinMeasureToStim(), # reduce duplicated logic, can split out even more rules later
134
+ SquinMeasureToStim(
135
+ measure_id_result=meas_analysis_frame.entries,
136
+ total_measure_count=mia.measure_count,
137
+ ), # reduce duplicated logic, can split out even more rules later
44
138
  SquinWireIdentityElimination(),
45
139
  )
46
140
  )
47
141
  .rewrite(mt.code)
48
142
  .join(rewrite_result)
49
143
  )
50
-
51
- # Convert all PyConsts to Stim Constants
52
144
  rewrite_result = (
53
- Walk(Chain(PyConstantToStim())).rewrite(mt.code).join(rewrite_result)
145
+ CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
146
+ .unsafe_run(mt)
147
+ .join(rewrite_result)
54
148
  )
55
149
 
56
- # remove any squin.qubit.new that's left around
57
- ## Not considered pure so DCE won't touch it but
58
- ## it isn't being used anymore considering everything is a
59
- ## stim dialect statement
150
+ # Convert all PyConsts to Stim Constants
151
+ rewrite_result = Walk(PyConstantToStim()).rewrite(mt.code).join(rewrite_result)
152
+
153
+ # clear up leftover stmts
154
+ # - remove any squin.qubit.new that's left around
60
155
  rewrite_result = (
61
156
  Fixpoint(
62
157
  Walk(
@@ -71,16 +166,4 @@ class SquinToStim(Pass):
71
166
  .join(rewrite_result)
72
167
  )
73
168
 
74
- # do program verification here,
75
- # ideally use built-in .verify() to catch any
76
- # incompatible statements as the full rewrite process should not
77
- # leave statements from any other dialects (other than the stim main group)
78
- mt_verification_clone = mt.similar(stim_main_group)
79
-
80
- # suggested by Kai, will work for now
81
- for stmt in mt_verification_clone.code.walk():
82
- assert (
83
- stmt.dialect in stim_main_group
84
- ), "Statements detected that are not part of the stim dialect, please verify the original code is valid for rewrite!"
85
-
86
169
  return rewrite_result
@@ -1,3 +1,4 @@
1
+ from .squin_noise import SquinNoiseToStim as SquinNoiseToStim
1
2
  from .wire_to_stim import SquinWireToStim as SquinWireToStim
2
3
  from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
3
4
  from .squin_measure import SquinMeasureToStim as SquinMeasureToStim