bloqade-circuit 0.2.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (101)
  1. bloqade/analysis/address/impls.py +3 -2
  2. bloqade/pyqrack/device.py +1 -3
  3. bloqade/pyqrack/noise/native.py +8 -8
  4. bloqade/pyqrack/qasm2/core.py +4 -1
  5. bloqade/pyqrack/squin/op.py +7 -0
  6. bloqade/pyqrack/squin/qubit.py +5 -27
  7. bloqade/pyqrack/squin/runtime.py +18 -0
  8. bloqade/pyqrack/squin/wire.py +4 -22
  9. bloqade/pyqrack/task.py +13 -5
  10. bloqade/qasm2/__init__.py +1 -0
  11. bloqade/qasm2/_qasm_loading.py +151 -0
  12. bloqade/qasm2/dialects/core/__init__.py +9 -1
  13. bloqade/qasm2/dialects/expr/__init__.py +18 -1
  14. bloqade/{noise/native → qasm2/dialects/noise}/__init__.py +1 -7
  15. bloqade/qasm2/dialects/noise/_dialect.py +3 -0
  16. bloqade/{noise → qasm2/dialects/noise}/fidelity.py +4 -4
  17. bloqade/qasm2/dialects/noise/model.py +278 -0
  18. bloqade/{noise/native → qasm2/dialects/noise}/stmts.py +1 -1
  19. bloqade/qasm2/dialects/uop/__init__.py +39 -3
  20. bloqade/qasm2/dialects/uop/schedule.py +1 -1
  21. bloqade/qasm2/emit/impls/__init__.py +1 -0
  22. bloqade/qasm2/emit/impls/noise.py +89 -0
  23. bloqade/qasm2/emit/main.py +23 -4
  24. bloqade/qasm2/emit/target.py +19 -4
  25. bloqade/qasm2/noise.py +67 -0
  26. bloqade/qasm2/parse/__init__.py +7 -4
  27. bloqade/qasm2/parse/lowering.py +20 -130
  28. bloqade/qasm2/parse/qasm2.lark +1 -1
  29. bloqade/qasm2/passes/__init__.py +1 -0
  30. bloqade/qasm2/passes/fold.py +6 -0
  31. bloqade/qasm2/passes/glob.py +12 -8
  32. bloqade/qasm2/passes/noise.py +27 -16
  33. bloqade/qasm2/passes/parallel.py +9 -0
  34. bloqade/qasm2/passes/unroll_if.py +25 -0
  35. bloqade/qasm2/rewrite/__init__.py +3 -0
  36. bloqade/qasm2/rewrite/desugar.py +3 -2
  37. bloqade/qasm2/rewrite/native_gates.py +67 -4
  38. bloqade/qasm2/rewrite/noise/__init__.py +0 -0
  39. bloqade/qasm2/rewrite/{heuristic_noise.py → noise/heuristic_noise.py} +32 -62
  40. bloqade/{noise/native/rewrite.py → qasm2/rewrite/noise/remove_noise.py} +2 -2
  41. bloqade/qasm2/rewrite/split_ifs.py +66 -0
  42. bloqade/qbraid/lowering.py +8 -8
  43. bloqade/squin/__init__.py +7 -1
  44. bloqade/squin/analysis/nsites/__init__.py +1 -0
  45. bloqade/squin/analysis/nsites/impls.py +16 -1
  46. bloqade/squin/groups.py +4 -4
  47. bloqade/squin/lowering.py +27 -0
  48. bloqade/squin/noise/__init__.py +7 -26
  49. bloqade/squin/noise/_wrapper.py +25 -0
  50. bloqade/squin/op/__init__.py +34 -159
  51. bloqade/squin/op/_wrapper.py +105 -0
  52. bloqade/squin/op/stdlib.py +62 -0
  53. bloqade/squin/op/stmts.py +10 -0
  54. bloqade/squin/passes/__init__.py +1 -0
  55. bloqade/squin/passes/stim.py +68 -0
  56. bloqade/squin/qubit.py +32 -37
  57. bloqade/squin/rewrite/__init__.py +11 -0
  58. bloqade/squin/rewrite/desugar.py +65 -0
  59. bloqade/squin/rewrite/qubit_to_stim.py +61 -0
  60. bloqade/squin/rewrite/squin_measure.py +73 -0
  61. bloqade/squin/rewrite/stim_rewrite_util.py +153 -0
  62. bloqade/squin/rewrite/wire_identity_elimination.py +24 -0
  63. bloqade/squin/rewrite/wire_to_stim.py +52 -0
  64. bloqade/squin/rewrite/wrap_analysis.py +72 -0
  65. bloqade/squin/wire.py +5 -22
  66. bloqade/stim/__init__.py +40 -5
  67. bloqade/stim/_wrappers.py +18 -12
  68. bloqade/stim/dialects/__init__.py +1 -5
  69. bloqade/stim/dialects/{aux → auxiliary}/__init__.py +13 -1
  70. bloqade/stim/dialects/{aux → auxiliary}/emit.py +18 -3
  71. bloqade/stim/dialects/{aux → auxiliary}/stmts/__init__.py +1 -0
  72. bloqade/stim/dialects/{aux → auxiliary}/stmts/annotate.py +8 -0
  73. bloqade/stim/dialects/collapse/__init__.py +13 -2
  74. bloqade/stim/dialects/collapse/{emit.py → emit_str.py} +4 -2
  75. bloqade/stim/dialects/collapse/stmts/pp_measure.py +1 -1
  76. bloqade/stim/dialects/gate/__init__.py +16 -1
  77. bloqade/stim/dialects/gate/emit.py +10 -3
  78. bloqade/stim/dialects/gate/stmts/base.py +1 -1
  79. bloqade/stim/dialects/gate/stmts/pp.py +1 -1
  80. bloqade/stim/dialects/noise/emit.py +33 -2
  81. bloqade/stim/dialects/noise/stmts.py +29 -0
  82. bloqade/stim/emit/__init__.py +1 -1
  83. bloqade/stim/groups.py +4 -2
  84. bloqade/stim/parse/__init__.py +1 -0
  85. bloqade/stim/parse/lowering.py +686 -0
  86. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/METADATA +5 -3
  87. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/RECORD +95 -77
  88. bloqade/noise/__init__.py +0 -2
  89. bloqade/noise/native/_dialect.py +0 -3
  90. bloqade/noise/native/_wrappers.py +0 -34
  91. bloqade/noise/native/model.py +0 -346
  92. bloqade/qasm2/dialects/noise.py +0 -16
  93. bloqade/squin/rewrite/measure_desugar.py +0 -33
  94. /bloqade/stim/dialects/{aux → auxiliary}/_dialect.py +0 -0
  95. /bloqade/stim/dialects/{aux → auxiliary}/interp.py +0 -0
  96. /bloqade/stim/dialects/{aux → auxiliary}/lowering.py +0 -0
  97. /bloqade/stim/dialects/{aux → auxiliary}/stmts/const.py +0 -0
  98. /bloqade/stim/dialects/{aux → auxiliary}/types.py +0 -0
  99. /bloqade/stim/emit/{stim.py → stim_str.py} +0 -0
  100. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/WHEEL +0 -0
  101. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,105 @@
1
+ from kirin.lowering import wraps
2
+
3
+ from . import stmts, types
4
+
5
+
6
+ @wraps(stmts.Kron)
7
+ def kron(lhs: types.Op, rhs: types.Op) -> types.Op: ...
8
+
9
+
10
+ @wraps(stmts.Mult)
11
+ def mult(lhs: types.Op, rhs: types.Op) -> types.Op: ...
12
+
13
+
14
+ @wraps(stmts.Scale)
15
+ def scale(op: types.Op, factor: complex) -> types.Op: ...
16
+
17
+
18
+ @wraps(stmts.Adjoint)
19
+ def adjoint(op: types.Op) -> types.Op: ...
20
+
21
+
22
+ @wraps(stmts.Control)
23
+ def control(op: types.Op, *, n_controls: int) -> types.Op:
24
+ """
25
+ Create a controlled operator.
26
+
27
+ Note that when considering atom loss, the operator will not be applied if
28
+ any of the controls has been lost.
29
+
30
+ Args:
31
+ op: The operator to apply under the control.
32
+ n_controls: The number of qubits to be used as controls.
33
+
34
+ Returns:
35
+ The controlled operator.
36
+ """
37
+ ...
38
+
39
+
40
+ @wraps(stmts.Reset)
41
+ def reset() -> types.Op: ...
42
+
43
+
44
+ @wraps(stmts.Identity)
45
+ def identity(*, sites: int) -> types.Op: ...
46
+
47
+
48
+ @wraps(stmts.Rot)
49
+ def rot(axis: types.Op, angle: float) -> types.Op: ...
50
+
51
+
52
+ @wraps(stmts.ShiftOp)
53
+ def shift(theta: float) -> types.Op: ...
54
+
55
+
56
+ @wraps(stmts.PhaseOp)
57
+ def phase(theta: float) -> types.Op: ...
58
+
59
+
60
+ @wraps(stmts.X)
61
+ def x() -> types.Op: ...
62
+
63
+
64
+ @wraps(stmts.Y)
65
+ def y() -> types.Op: ...
66
+
67
+
68
+ @wraps(stmts.Z)
69
+ def z() -> types.Op: ...
70
+
71
+
72
+ @wraps(stmts.H)
73
+ def h() -> types.Op: ...
74
+
75
+
76
+ @wraps(stmts.S)
77
+ def s() -> types.Op: ...
78
+
79
+
80
+ @wraps(stmts.T)
81
+ def t() -> types.Op: ...
82
+
83
+
84
+ @wraps(stmts.P0)
85
+ def p0() -> types.Op: ...
86
+
87
+
88
+ @wraps(stmts.P1)
89
+ def p1() -> types.Op: ...
90
+
91
+
92
+ @wraps(stmts.Sn)
93
+ def spin_n() -> types.Op: ...
94
+
95
+
96
+ @wraps(stmts.Sp)
97
+ def spin_p() -> types.Op: ...
98
+
99
+
100
+ @wraps(stmts.U3)
101
+ def u(theta: float, phi: float, lam: float) -> types.Op: ...
102
+
103
+
104
+ @wraps(stmts.PauliString)
105
+ def pauli_string(*, string: str) -> types.Op: ...
@@ -0,0 +1,62 @@
1
+ from kirin import ir
2
+ from kirin.prelude import structural_no_opt
3
+
4
+ from . import types
5
+ from ._dialect import dialect
6
+ from ._wrapper import h, x, y, z, rot, phase, control
7
+
8
+
9
+ @ir.dialect_group(structural_no_opt.add(dialect))
10
+ def op(self):
11
+ def run_pass(method):
12
+ pass
13
+
14
+ return run_pass
15
+
16
+
17
+ @op
18
+ def rx(theta: float) -> types.Op:
19
+ """Rotation X gate."""
20
+ return rot(x(), theta)
21
+
22
+
23
+ @op
24
+ def ry(theta: float) -> types.Op:
25
+ """Rotation Y gate."""
26
+ return rot(y(), theta)
27
+
28
+
29
+ @op
30
+ def rz(theta: float) -> types.Op:
31
+ """Rotation Z gate."""
32
+ return rot(z(), theta)
33
+
34
+
35
+ @op
36
+ def cx() -> types.Op:
37
+ """Controlled X gate."""
38
+ return control(x(), n_controls=1)
39
+
40
+
41
+ @op
42
+ def cy() -> types.Op:
43
+ """Controlled Y gate."""
44
+ return control(y(), n_controls=1)
45
+
46
+
47
+ @op
48
+ def cz() -> types.Op:
49
+ """Controlled Z gate."""
50
+ return control(z(), n_controls=1)
51
+
52
+
53
+ @op
54
+ def ch() -> types.Op:
55
+ """Controlled H gate."""
56
+ return control(h(), n_controls=1)
57
+
58
+
59
+ @op
60
+ def cphase(theta: float) -> types.Op:
61
+ """Controlled Phase gate."""
62
+ return control(phase(theta), n_controls=1)
bloqade/squin/op/stmts.py CHANGED
@@ -142,6 +142,16 @@ class ShiftOp(PrimitiveOp):
142
142
  result: ir.ResultValue = info.result(OpType)
143
143
 
144
144
 
145
+ @statement(dialect=dialect)
146
+ class Reset(PrimitiveOp):
147
+ """
148
+ Reset operator for qubits or wires.
149
+ """
150
+
151
+ traits = frozenset({ir.Pure(), lowering.FromPythonCall(), FixedSites(1)})
152
+ result: ir.ResultValue = info.result(OpType)
153
+
154
+
145
155
  @statement
146
156
  class PauliOp(ConstantUnitary):
147
157
  pass
@@ -0,0 +1 @@
1
+ from .stim import SquinToStim as SquinToStim
@@ -0,0 +1,68 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin.passes import Fold
4
+ from kirin.rewrite import (
5
+ Walk,
6
+ Chain,
7
+ Fixpoint,
8
+ DeadCodeElimination,
9
+ CommonSubexpressionElimination,
10
+ )
11
+ from kirin.ir.method import Method
12
+ from kirin.passes.abc import Pass
13
+ from kirin.rewrite.abc import RewriteResult
14
+
15
+ from bloqade.squin.rewrite import (
16
+ SquinWireToStim,
17
+ SquinQubitToStim,
18
+ WrapSquinAnalysis,
19
+ SquinMeasureToStim,
20
+ SquinWireIdentityElimination,
21
+ )
22
+ from bloqade.analysis.address import AddressAnalysis
23
+ from bloqade.squin.analysis.nsites import (
24
+ NSitesAnalysis,
25
+ )
26
+
27
+
28
+ @dataclass
29
+ class SquinToStim(Pass):
30
+
31
+ def unsafe_run(self, mt: Method) -> RewriteResult:
32
+ fold_pass = Fold(mt.dialects)
33
+ # propagate constants
34
+ rewrite_result = fold_pass(mt)
35
+
36
+ # Get necessary analysis results to plug into hints
37
+ address_analysis = AddressAnalysis(mt.dialects)
38
+ address_frame, _ = address_analysis.run_analysis(mt)
39
+ site_analysis = NSitesAnalysis(mt.dialects)
40
+ sites_frame, _ = site_analysis.run_analysis(mt)
41
+
42
+ # Wrap Rewrite + SquinToStim can happen w/ standard walk
43
+ rewrite_result = (
44
+ Walk(
45
+ Chain(
46
+ WrapSquinAnalysis(
47
+ address_analysis=address_frame.entries,
48
+ op_site_analysis=sites_frame.entries,
49
+ ),
50
+ SquinQubitToStim(),
51
+ SquinWireToStim(),
52
+ SquinMeasureToStim(), # reduce duplicated logic, can split out even more rules later
53
+ SquinWireIdentityElimination(),
54
+ )
55
+ )
56
+ .rewrite(mt.code)
57
+ .join(rewrite_result)
58
+ )
59
+
60
+ rewrite_result = (
61
+ Fixpoint(
62
+ Walk(Chain(DeadCodeElimination(), CommonSubexpressionElimination()))
63
+ )
64
+ .rewrite(mt.code)
65
+ .join(rewrite_result)
66
+ )
67
+
68
+ return rewrite_result
bloqade/squin/qubit.py CHANGED
@@ -17,6 +17,8 @@ from kirin.lowering import wraps
17
17
  from bloqade.types import Qubit, QubitType
18
18
  from bloqade.squin.op.types import Op, OpType
19
19
 
20
+ from .lowering import ApplyAnyCallLowering
21
+
20
22
  dialect = ir.Dialect("squin.qubit")
21
23
 
22
24
 
@@ -34,6 +36,14 @@ class Apply(ir.Statement):
34
36
  qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
35
37
 
36
38
 
39
+ @statement(dialect=dialect)
40
+ class ApplyAny(ir.Statement):
41
+ # NOTE: custom lowering to deal with vararg calls
42
+ traits = frozenset({ApplyAnyCallLowering()})
43
+ operator: ir.SSAValue = info.argument(OpType)
44
+ qubits: tuple[ir.SSAValue, ...] = info.argument()
45
+
46
+
37
47
  @statement(dialect=dialect)
38
48
  class Broadcast(ir.Statement):
39
49
  traits = frozenset({lowering.FromPythonCall()})
@@ -68,19 +78,6 @@ class MeasureQubitList(ir.Statement):
68
78
  result: ir.ResultValue = info.result(ilist.IListType[types.Bool])
69
79
 
70
80
 
71
- @statement(dialect=dialect)
72
- class MeasureAndReset(ir.Statement):
73
- traits = frozenset({lowering.FromPythonCall()})
74
- qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
75
- result: ir.ResultValue = info.result(ilist.IListType[types.Bool])
76
-
77
-
78
- @statement(dialect=dialect)
79
- class Reset(ir.Statement):
80
- traits = frozenset({lowering.FromPythonCall()})
81
- qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
82
-
83
-
84
81
  # NOTE: no dependent types in Python, so we have to mark it Any...
85
82
  @wraps(New)
86
83
  def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
@@ -95,7 +92,7 @@ def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
95
92
  ...
96
93
 
97
94
 
98
- @wraps(Apply)
95
+ @overload
99
96
  def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
100
97
  """Apply an operator to a list of qubits.
101
98
 
@@ -112,6 +109,27 @@ def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
112
109
  ...
113
110
 
114
111
 
112
+ @overload
113
+ def apply(operator: Op, *qubits: Qubit) -> None:
114
+ """Apply an operator to any number of qubits.
115
+
116
+ Note that when considering atom loss, lost qubits will be skipped.
117
+
118
+ Args:
119
+ operator: The operator to apply.
120
+ *qubits: The qubits to apply the operator to. The number of qubits must
121
+ match the size of the operator.
122
+
123
+ Returns:
124
+ None
125
+ """
126
+ ...
127
+
128
+
129
+ @wraps(ApplyAny)
130
+ def apply(operator: Op, *qubits) -> None: ...
131
+
132
+
115
133
  @overload
116
134
  def measure(input: Qubit) -> bool: ...
117
135
  @overload
@@ -161,26 +179,3 @@ def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> No
161
179
  None
162
180
  """
163
181
  ...
164
-
165
-
166
- @wraps(MeasureAndReset)
167
- def measure_and_reset(qubits: ilist.IList[Qubit, Any]) -> ilist.IList[bool, Any]:
168
- """Measure the qubits in the list and reset them."
169
-
170
- Args:
171
- qubits: The list of qubits to measure and reset.
172
-
173
- Returns:
174
- list[bool]: The result of the measurement.
175
- """
176
- ...
177
-
178
-
179
- @wraps(Reset)
180
- def reset(qubits: ilist.IList[Qubit, Any]) -> None:
181
- """Reset the qubits in the list."
182
-
183
- Args:
184
- qubits: The list of qubits to reset.
185
- """
186
- ...
@@ -0,0 +1,11 @@
1
+ from .wire_to_stim import SquinWireToStim as SquinWireToStim
2
+ from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
3
+ from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
4
+ from .wrap_analysis import (
5
+ SitesAttribute as SitesAttribute,
6
+ AddressAttribute as AddressAttribute,
7
+ WrapSquinAnalysis as WrapSquinAnalysis,
8
+ )
9
+ from .wire_identity_elimination import (
10
+ SquinWireIdentityElimination as SquinWireIdentityElimination,
11
+ )
@@ -0,0 +1,65 @@
1
+ from kirin import ir, types
2
+ from kirin.dialects import ilist
3
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
4
+
5
+ from bloqade.squin.qubit import (
6
+ Apply,
7
+ ApplyAny,
8
+ QubitType,
9
+ MeasureAny,
10
+ MeasureQubit,
11
+ MeasureQubitList,
12
+ )
13
+
14
+
15
+ class MeasureDesugarRule(RewriteRule):
16
+ """
17
+ Desugar measure operations in the circuit.
18
+ """
19
+
20
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
21
+
22
+ if not isinstance(node, MeasureAny):
23
+ return RewriteResult()
24
+
25
+ if node.input.type.is_subseteq(QubitType):
26
+ node.replace_by(
27
+ MeasureQubit(
28
+ qubit=node.input,
29
+ )
30
+ )
31
+ return RewriteResult(has_done_something=True)
32
+ elif node.input.type.is_subseteq(ilist.IListType[QubitType, types.Any]):
33
+ node.replace_by(
34
+ MeasureQubitList(
35
+ qubits=node.input,
36
+ )
37
+ )
38
+ return RewriteResult(has_done_something=True)
39
+
40
+ return RewriteResult()
41
+
42
+
43
+ class ApplyDesugarRule(RewriteRule):
44
+ """
45
+ Desugar apply operators in the kernel.
46
+ """
47
+
48
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
49
+
50
+ if not isinstance(node, ApplyAny):
51
+ return RewriteResult()
52
+
53
+ op = node.operator
54
+ qubits = node.qubits
55
+
56
+ if len(qubits) == 1 and qubits[0].type.is_subseteq(ilist.IListType):
57
+ # NOTE: already calling with just a single argument that is already an ilist
58
+ qubits_ilist = qubits[0]
59
+ else:
60
+ (qubits_ilist_stmt := ilist.New(values=qubits)).insert_before(node)
61
+ qubits_ilist = qubits_ilist_stmt.result
62
+
63
+ stmt = Apply(operator=op, qubits=qubits_ilist)
64
+ node.replace_by(stmt)
65
+ return RewriteResult(has_done_something=True)
@@ -0,0 +1,61 @@
1
+ from kirin import ir
2
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
3
+
4
+ from bloqade.squin import op, qubit
5
+ from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
6
+ from bloqade.squin.rewrite.stim_rewrite_util import (
7
+ SQUIN_STIM_GATE_MAPPING,
8
+ rewrite_Control,
9
+ insert_qubit_idx_from_address,
10
+ )
11
+
12
+
13
+ class SquinQubitToStim(RewriteRule):
14
+
15
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
16
+
17
+ match node:
18
+ case qubit.Apply() | qubit.Broadcast():
19
+ return self.rewrite_Apply_and_Broadcast(node)
20
+ case _:
21
+ return RewriteResult()
22
+
23
+ def rewrite_Apply_and_Broadcast(
24
+ self, stmt: qubit.Apply | qubit.Broadcast
25
+ ) -> RewriteResult:
26
+ """
27
+ Rewrite Apply and Broadcast nodes to their stim equivalent statements.
28
+ """
29
+
30
+ # this is an SSAValue, need it to be the actual operator
31
+ applied_op = stmt.operator.owner
32
+ assert isinstance(applied_op, op.stmts.Operator)
33
+
34
+ if isinstance(applied_op, op.stmts.Control):
35
+ return rewrite_Control(stmt)
36
+
37
+ # need to handle Control through separate means
38
+ # but we can handle X, Y, Z, H, and S here just fine
39
+ stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op))
40
+ if stim_1q_op is None:
41
+ return RewriteResult()
42
+
43
+ address_attr = stmt.qubits.hints.get("address")
44
+ if address_attr is None:
45
+ return RewriteResult()
46
+
47
+ assert isinstance(address_attr, AddressAttribute)
48
+ qubit_idx_ssas = insert_qubit_idx_from_address(
49
+ address=address_attr, stmt_to_insert_before=stmt
50
+ )
51
+
52
+ if qubit_idx_ssas is None:
53
+ return RewriteResult()
54
+
55
+ stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
56
+ stmt.replace_by(stim_1q_stmt)
57
+
58
+ return RewriteResult(has_done_something=True)
59
+
60
+
61
+ # put rewrites for measure statements in separate rule, then just have to dispatch
@@ -0,0 +1,73 @@
1
+ # create rewrite rule named SquinMeasureToStim using kirin
2
+ from kirin import ir
3
+ from kirin.dialects import py
4
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
5
+
6
+ from bloqade.squin import wire, qubit
7
+ from bloqade.stim.dialects import collapse
8
+ from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
9
+ from bloqade.squin.rewrite.stim_rewrite_util import (
10
+ is_measure_result_used,
11
+ insert_qubit_idx_from_address,
12
+ )
13
+
14
+
15
+ class SquinMeasureToStim(RewriteRule):
16
+ """
17
+ Rewrite squin measure-related statements to stim statements.
18
+ """
19
+
20
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
21
+
22
+ match node:
23
+ case qubit.MeasureQubit() | qubit.MeasureQubitList() | wire.Measure():
24
+ return self.rewrite_Measure(node)
25
+ case _:
26
+ return RewriteResult()
27
+
28
+ def rewrite_Measure(
29
+ self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
30
+ ) -> RewriteResult:
31
+ if is_measure_result_used(measure_stmt):
32
+ return RewriteResult()
33
+
34
+ qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt)
35
+ if qubit_idx_ssas is None:
36
+ return RewriteResult()
37
+
38
+ prob_noise_stmt = py.constant.Constant(0.0)
39
+ stim_measure_stmt = collapse.MZ(
40
+ p=prob_noise_stmt.result,
41
+ targets=qubit_idx_ssas,
42
+ )
43
+ prob_noise_stmt.insert_before(measure_stmt)
44
+ measure_stmt.replace_by(stim_measure_stmt)
45
+
46
+ return RewriteResult(has_done_something=True)
47
+
48
+ def get_qubit_idx_ssas(
49
+ self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
50
+ ) -> tuple[ir.SSAValue, ...] | None:
51
+ """
52
+ Extract the address attribute and insert qubit indices for the given measure statement.
53
+ """
54
+ match measure_stmt:
55
+ case qubit.MeasureQubit():
56
+ address_attr = measure_stmt.qubit.hints.get("address")
57
+ case qubit.MeasureQubitList():
58
+ address_attr = measure_stmt.qubits.hints.get("address")
59
+ case wire.Measure():
60
+ address_attr = measure_stmt.wire.hints.get("address")
61
+ case _:
62
+ return None
63
+
64
+ if address_attr is None:
65
+ return None
66
+
67
+ assert isinstance(address_attr, AddressAttribute)
68
+
69
+ qubit_idx_ssas = insert_qubit_idx_from_address(
70
+ address=address_attr, stmt_to_insert_before=measure_stmt
71
+ )
72
+
73
+ return qubit_idx_ssas