bloqade-circuit 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (79)
  1. bloqade/analysis/address/impls.py +14 -0
  2. bloqade/noise/fidelity.py +3 -3
  3. bloqade/noise/native/_dialect.py +1 -1
  4. bloqade/noise/native/_wrappers.py +35 -6
  5. bloqade/noise/native/stmts.py +1 -1
  6. bloqade/pyqrack/device.py +1 -3
  7. bloqade/pyqrack/qasm2/core.py +4 -1
  8. bloqade/pyqrack/squin/qubit.py +16 -9
  9. bloqade/pyqrack/squin/wire.py +22 -4
  10. bloqade/pyqrack/task.py +13 -5
  11. bloqade/qasm2/__init__.py +1 -0
  12. bloqade/qasm2/_qasm_loading.py +151 -0
  13. bloqade/qasm2/dialects/core/__init__.py +9 -1
  14. bloqade/qasm2/dialects/expr/__init__.py +18 -1
  15. bloqade/qasm2/dialects/noise.py +33 -1
  16. bloqade/qasm2/dialects/uop/__init__.py +39 -3
  17. bloqade/qasm2/dialects/uop/schedule.py +1 -1
  18. bloqade/qasm2/emit/impls/__init__.py +1 -0
  19. bloqade/qasm2/emit/impls/noise_native.py +89 -0
  20. bloqade/qasm2/emit/main.py +21 -0
  21. bloqade/qasm2/emit/target.py +20 -5
  22. bloqade/qasm2/groups.py +2 -0
  23. bloqade/qasm2/parse/__init__.py +7 -4
  24. bloqade/qasm2/parse/lowering.py +20 -130
  25. bloqade/qasm2/parse/qasm2.lark +1 -1
  26. bloqade/qasm2/passes/__init__.py +1 -0
  27. bloqade/qasm2/passes/fold.py +6 -0
  28. bloqade/qasm2/passes/noise.py +22 -2
  29. bloqade/qasm2/passes/parallel.py +9 -0
  30. bloqade/qasm2/passes/unroll_if.py +25 -0
  31. bloqade/qasm2/rewrite/__init__.py +1 -0
  32. bloqade/qasm2/rewrite/desugar.py +3 -2
  33. bloqade/qasm2/rewrite/heuristic_noise.py +1 -9
  34. bloqade/qasm2/rewrite/native_gates.py +67 -4
  35. bloqade/qasm2/rewrite/split_ifs.py +66 -0
  36. bloqade/squin/analysis/nsites/__init__.py +1 -0
  37. bloqade/squin/analysis/nsites/impls.py +25 -1
  38. bloqade/squin/noise/__init__.py +7 -26
  39. bloqade/squin/noise/_wrapper.py +25 -0
  40. bloqade/squin/op/__init__.py +33 -159
  41. bloqade/squin/op/_wrapper.py +101 -0
  42. bloqade/squin/op/stdlib.py +62 -0
  43. bloqade/squin/passes/__init__.py +1 -0
  44. bloqade/squin/passes/stim.py +68 -0
  45. bloqade/squin/rewrite/__init__.py +11 -0
  46. bloqade/squin/rewrite/qubit_to_stim.py +84 -0
  47. bloqade/squin/rewrite/squin_measure.py +98 -0
  48. bloqade/squin/rewrite/stim_rewrite_util.py +158 -0
  49. bloqade/squin/rewrite/wire_identity_elimination.py +24 -0
  50. bloqade/squin/rewrite/wire_to_stim.py +73 -0
  51. bloqade/squin/rewrite/wrap_analysis.py +72 -0
  52. bloqade/squin/wire.py +1 -13
  53. bloqade/stim/__init__.py +39 -5
  54. bloqade/stim/_wrappers.py +14 -12
  55. bloqade/stim/dialects/__init__.py +1 -5
  56. bloqade/stim/dialects/{aux → auxiliary}/__init__.py +12 -1
  57. bloqade/stim/dialects/{aux → auxiliary}/emit.py +1 -1
  58. bloqade/stim/dialects/collapse/__init__.py +13 -2
  59. bloqade/stim/dialects/collapse/{emit.py → emit_str.py} +1 -1
  60. bloqade/stim/dialects/collapse/stmts/pp_measure.py +1 -1
  61. bloqade/stim/dialects/gate/__init__.py +16 -1
  62. bloqade/stim/dialects/gate/emit.py +1 -1
  63. bloqade/stim/dialects/gate/stmts/base.py +1 -1
  64. bloqade/stim/dialects/gate/stmts/pp.py +1 -1
  65. bloqade/stim/dialects/noise/emit.py +1 -1
  66. bloqade/stim/emit/__init__.py +1 -1
  67. bloqade/stim/groups.py +4 -2
  68. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/METADATA +3 -3
  69. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/RECORD +79 -63
  70. /bloqade/stim/dialects/{aux → auxiliary}/_dialect.py +0 -0
  71. /bloqade/stim/dialects/{aux → auxiliary}/interp.py +0 -0
  72. /bloqade/stim/dialects/{aux → auxiliary}/lowering.py +0 -0
  73. /bloqade/stim/dialects/{aux → auxiliary}/stmts/__init__.py +0 -0
  74. /bloqade/stim/dialects/{aux → auxiliary}/stmts/annotate.py +0 -0
  75. /bloqade/stim/dialects/{aux → auxiliary}/stmts/const.py +0 -0
  76. /bloqade/stim/dialects/{aux → auxiliary}/types.py +0 -0
  77. /bloqade/stim/emit/{stim.py → stim_str.py} +0 -0
  78. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/WHEEL +0 -0
  79. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,66 @@
1
+ from kirin import ir
2
+ from kirin.dialects import scf, func
3
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
4
+
5
+ from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
6
+ from ..dialects.core.stmts import Reset, Measure
7
+
8
# TODO: unify with PR #248
# Statement types treated as allowed inside the then body of an scf.IfElse:
# single-qubit gates, two-qubit controlled gates, measurements and resets.
# LiftThenBody never lifts these (they are part of DontLiftType below).
AllowedThenType = SingleQubitGate | TwoQubitCtrlGate | Measure | Reset

# Statement types LiftThenBody leaves in place: the allowed statements above,
# the block terminators (scf.Yield / func.Return), and func.Invoke.
DontLiftType = AllowedThenType | scf.Yield | func.Return | func.Invoke
12
+
13
+
14
class LiftThenBody(RewriteRule):
    """Hoist non-allowed statements out of the then body of an ``scf.IfElse``.

    Every statement in the then body whose type is not in ``DontLiftType``
    is detached and re-inserted directly before the if-else, keeping the
    original relative order. Returns a result with ``has_done_something``
    set only when at least one statement was moved.

    NOTE(review): hoisted statements run unconditionally afterwards; this
    looks intended for pure statements (e.g. constants) — confirm.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if isinstance(node, scf.IfElse):
            # Materialize first: detaching while iterating the region is unsafe.
            to_hoist = [
                candidate
                for candidate in node.then_body.stmts()
                if not isinstance(candidate, DontLiftType)
            ]
        else:
            to_hoist = []

        if not to_hoist:
            return RewriteResult()

        # Inserting each one before `node` in turn preserves their order.
        for candidate in to_hoist:
            candidate.detach()
            candidate.insert_before(node)

        return RewriteResult(has_done_something=True)
33
+
34
+
35
class SplitIfStmts(RewriteRule):
    """Split a multi-statement then body into one if-else per statement.

    For an ``scf.IfElse`` whose then body is ``s1; ...; sN; <terminator>``
    (the terminator being an ``scf.Yield`` or a ``func.Return``), this emits
    N ``scf.IfElse`` statements on the same condition, each holding a single
    ``s_i`` plus a fresh terminator of the same kind. They are inserted before
    the original, which is then deleted.

    If-else statements with zero or one body statement are left untouched.

    NOTE(review): the original else body is cloned into every new if-else,
    so any statements it contains end up duplicated N times. This appears to
    assume an empty else body — confirm.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, scf.IfElse):
            return RewriteResult()

        then_stmts = list(node.then_body.stmts())

        # At least two body statements plus the terminator are needed before
        # there is anything to split. Bailing out here also keeps an if-else
        # whose then body is only a terminator; previously the loop below did
        # nothing and `node.delete()` silently dropped the else branch. An
        # entirely empty then body previously raised ValueError on unpacking.
        if len(then_stmts) < 3:
            return RewriteResult()

        *stmts, terminator = then_stmts
        is_yield = isinstance(terminator, scf.Yield)

        for stmt in stmts:
            stmt.detach()

            # Each new if-else needs its own terminator of the same kind.
            new_terminator = scf.Yield() if is_yield else func.Return()

            then_block = ir.Block((stmt, new_terminator), argtypes=(node.cond.type,))
            then_body = ir.Region(then_block)
            # Every new if-else receives a fresh copy of the original else body.
            else_body = node.else_body.clone()
            else_body.detach()
            new_if = scf.IfElse(
                cond=node.cond, then_body=then_body, else_body=else_body
            )

            new_if.insert_before(node)

        node.delete()

        return RewriteResult(has_done_something=True)
@@ -1,6 +1,7 @@
1
1
  # Need this for impl registration to work properly!
2
2
  from . import impls as impls
3
3
  from .lattice import (
4
+ Sites as Sites,
4
5
  NoSites as NoSites,
5
6
  AnySites as AnySites,
6
7
  NumberSites as NumberSites,
@@ -1,6 +1,6 @@
1
1
  from kirin import interp
2
2
 
3
- from bloqade.squin import op
3
+ from bloqade.squin import op, wire
4
4
 
5
5
  from .lattice import (
6
6
  NoSites,
@@ -9,6 +9,30 @@ from .lattice import (
9
9
  from .analysis import NSitesAnalysis
10
10
 
11
11
 
12
@wire.dialect.register(key="op.nsites")
class SquinWire(interp.MethodTable):
    """NSites analysis rules for statements of the squin ``wire`` dialect."""

    @interp.impl(wire.Apply)
    @interp.impl(wire.Broadcast)
    def apply(
        self,
        interp: NSitesAnalysis,
        frame: interp.Frame,
        stmt: wire.Apply | wire.Broadcast,
    ):
        # Forward the lattice value of every input wire, in order; presumably
        # the i-th result corresponds to the i-th input wire.
        input_sites = [frame.get(wire_in) for wire_in in stmt.inputs]
        return tuple(input_sites)

    @interp.impl(wire.MeasureAndReset)
    def measure_and_reset(
        self, interp: NSitesAnalysis, frame: interp.Frame, stmt: wire.MeasureAndReset
    ):
        # Two results (the new wire and the integer measurement outcome),
        # neither of which carries any sites.
        return NoSites(), NoSites()
34
+
35
+
12
36
  @op.dialect.register(key="op.nsites")
13
37
  class SquinOp(interp.MethodTable):
14
38
 
@@ -1,27 +1,8 @@
1
- # Put all the proper wrappers here
2
-
3
- from kirin.lowering import wraps as _wraps
4
-
5
- from bloqade.squin.op.types import Op
6
-
7
1
  from . import stmts as stmts
8
-
9
-
10
- @_wraps(stmts.PauliError)
11
- def pauli_error(basis: Op, p: float) -> Op: ...
12
-
13
-
14
- @_wraps(stmts.PPError)
15
- def pp_error(op: Op, p: float) -> Op: ...
16
-
17
-
18
- @_wraps(stmts.Depolarize)
19
- def depolarize(n_qubits: int, p: float) -> Op: ...
20
-
21
-
22
- @_wraps(stmts.PauliChannel)
23
- def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ...
24
-
25
-
26
- @_wraps(stmts.QubitLoss)
27
- def qubit_loss(p: float) -> Op: ...
2
from ._dialect import dialect as dialect
from ._wrapper import (
    pp_error as pp_error,
    depolarize as depolarize,
    qubit_loss as qubit_loss,
    # pauli_error was a public name of this package in 0.2.3 and is still
    # defined in _wrapper; keep re-exporting it so callers do not break.
    pauli_error as pauli_error,
    pauli_channel as pauli_channel,
)
@@ -0,0 +1,25 @@
1
+ from kirin.lowering import wraps
2
+
3
+ from bloqade.squin.op.types import Op
4
+
5
+ from . import stmts
6
+
7
+
8
# Kernel-facing stubs: each function is bound to a noise statement via
# `wraps`; the `...` bodies are placeholders.


@wraps(stmts.PauliError)
def pauli_error(basis: Op, p: float) -> Op:
    """Wrapper for ``stmts.PauliError`` on ``basis``; ``p`` is presumably the error probability."""
    ...


@wraps(stmts.PPError)
def pp_error(op: Op, p: float) -> Op:
    """Wrapper for ``stmts.PPError`` on ``op``; ``p`` is presumably the error probability."""
    ...


@wraps(stmts.Depolarize)
def depolarize(n_qubits: int, p: float) -> Op:
    """Wrapper for ``stmts.Depolarize`` over ``n_qubits`` qubits; ``p`` is presumably the error probability."""
    ...


@wraps(stmts.PauliChannel)
def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op:
    """Wrapper for ``stmts.PauliChannel`` over ``n_qubits`` qubits; ``params`` layout is defined by the statement."""
    ...


@wraps(stmts.QubitLoss)
def qubit_loss(p: float) -> Op:
    """Wrapper for ``stmts.QubitLoss``; ``p`` is presumably the loss probability."""
    ...
@@ -1,162 +1,36 @@
1
- from kirin import ir as _ir
2
- from kirin.prelude import structural_no_opt as _structural_no_opt
3
- from kirin.lowering import wraps as _wraps
4
-
5
1
  from . import stmts as stmts, types as types, rewrite as rewrite
2
+ from .stdlib import (
3
+ ch as ch,
4
+ cx as cx,
5
+ cy as cy,
6
+ cz as cz,
7
+ rx as rx,
8
+ ry as ry,
9
+ rz as rz,
10
+ cphase as cphase,
11
+ )
6
12
  from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary
7
13
  from ._dialect import dialect as dialect
8
-
9
-
10
- @_wraps(stmts.Kron)
11
- def kron(lhs: types.Op, rhs: types.Op) -> types.Op: ...
12
-
13
-
14
- @_wraps(stmts.Mult)
15
- def mult(lhs: types.Op, rhs: types.Op) -> types.Op: ...
16
-
17
-
18
- @_wraps(stmts.Scale)
19
- def scale(op: types.Op, factor: complex) -> types.Op: ...
20
-
21
-
22
- @_wraps(stmts.Adjoint)
23
- def adjoint(op: types.Op) -> types.Op: ...
24
-
25
-
26
- @_wraps(stmts.Control)
27
- def control(op: types.Op, *, n_controls: int) -> types.Op:
28
- """
29
- Create a controlled operator.
30
-
31
- Note, that when considering atom loss, the operator will not be applied if
32
- any of the controls has been lost.
33
-
34
- Args:
35
- operator: The operator to apply under the control.
36
- n_controls: The number qubits to be used as control.
37
-
38
- Returns:
39
- Operator
40
- """
41
- ...
42
-
43
-
44
- @_wraps(stmts.Identity)
45
- def identity(*, sites: int) -> types.Op: ...
46
-
47
-
48
- @_wraps(stmts.Rot)
49
- def rot(axis: types.Op, angle: float) -> types.Op: ...
50
-
51
-
52
- @_wraps(stmts.ShiftOp)
53
- def shift(theta: float) -> types.Op: ...
54
-
55
-
56
- @_wraps(stmts.PhaseOp)
57
- def phase(theta: float) -> types.Op: ...
58
-
59
-
60
- @_wraps(stmts.X)
61
- def x() -> types.Op: ...
62
-
63
-
64
- @_wraps(stmts.Y)
65
- def y() -> types.Op: ...
66
-
67
-
68
- @_wraps(stmts.Z)
69
- def z() -> types.Op: ...
70
-
71
-
72
- @_wraps(stmts.H)
73
- def h() -> types.Op: ...
74
-
75
-
76
- @_wraps(stmts.S)
77
- def s() -> types.Op: ...
78
-
79
-
80
- @_wraps(stmts.T)
81
- def t() -> types.Op: ...
82
-
83
-
84
- @_wraps(stmts.P0)
85
- def p0() -> types.Op: ...
86
-
87
-
88
- @_wraps(stmts.P1)
89
- def p1() -> types.Op: ...
90
-
91
-
92
- @_wraps(stmts.Sn)
93
- def spin_n() -> types.Op: ...
94
-
95
-
96
- @_wraps(stmts.Sp)
97
- def spin_p() -> types.Op: ...
98
-
99
-
100
- @_wraps(stmts.U3)
101
- def u(theta: float, phi: float, lam: float) -> types.Op: ...
102
-
103
-
104
- @_wraps(stmts.PauliString)
105
- def pauli_string(*, string: str) -> types.Op: ...
106
-
107
-
108
- # stdlibs
109
- @_ir.dialect_group(_structural_no_opt.add(dialect))
110
- def op(self):
111
- def run_pass(method):
112
- pass
113
-
114
- return run_pass
115
-
116
-
117
- @op
118
- def rx(theta: float) -> types.Op:
119
- """Rotation X gate."""
120
- return rot(x(), theta)
121
-
122
-
123
- @op
124
- def ry(theta: float) -> types.Op:
125
- """Rotation Y gate."""
126
- return rot(y(), theta)
127
-
128
-
129
- @op
130
- def rz(theta: float) -> types.Op:
131
- """Rotation Z gate."""
132
- return rot(z(), theta)
133
-
134
-
135
- @op
136
- def cx() -> types.Op:
137
- """Controlled X gate."""
138
- return control(x(), n_controls=1)
139
-
140
-
141
- @op
142
- def cy() -> types.Op:
143
- """Controlled Y gate."""
144
- return control(y(), n_controls=1)
145
-
146
-
147
- @op
148
- def cz() -> types.Op:
149
- """Control Z gate."""
150
- return control(z(), n_controls=1)
151
-
152
-
153
- @op
154
- def ch() -> types.Op:
155
- """Control H gate."""
156
- return control(h(), n_controls=1)
157
-
158
-
159
- @op
160
- def cphase(theta: float) -> types.Op:
161
- """Control Phase gate."""
162
- return control(phase(theta), n_controls=1)
14
+ from ._wrapper import (
15
+ h as h,
16
+ s as s,
17
+ t as t,
18
+ u as u,
19
+ x as x,
20
+ y as y,
21
+ z as z,
22
+ p0 as p0,
23
+ p1 as p1,
24
+ rot as rot,
25
+ kron as kron,
26
+ mult as mult,
27
+ phase as phase,
28
+ scale as scale,
29
+ shift as shift,
30
+ spin_n as spin_n,
31
+ spin_p as spin_p,
32
+ adjoint as adjoint,
33
+ control as control,
34
+ identity as identity,
35
+ pauli_string as pauli_string,
36
+ )
@@ -0,0 +1,101 @@
1
+ from kirin.lowering import wraps
2
+
3
+ from . import stmts, types
4
+
5
+
6
# Operator-algebra stubs bound to statements via `wraps`; bodies are placeholders.


@wraps(stmts.Kron)
def kron(lhs: types.Op, rhs: types.Op) -> types.Op:
    """Wrapper for ``stmts.Kron`` combining ``lhs`` and ``rhs`` (Kronecker product, per the name)."""
    ...


@wraps(stmts.Mult)
def mult(lhs: types.Op, rhs: types.Op) -> types.Op:
    """Wrapper for ``stmts.Mult`` combining ``lhs`` and ``rhs`` (operator product, per the name)."""
    ...


@wraps(stmts.Scale)
def scale(op: types.Op, factor: complex) -> types.Op:
    """Wrapper for ``stmts.Scale``: ``op`` scaled by the complex ``factor``."""
    ...


@wraps(stmts.Adjoint)
def adjoint(op: types.Op) -> types.Op:
    """Wrapper for ``stmts.Adjoint`` of ``op``."""
    ...
20
+
21
+
22
@wraps(stmts.Control)
def control(op: types.Op, *, n_controls: int) -> types.Op:
    """
    Create a controlled operator.

    Note that when considering atom loss, the operator will not be applied if
    any of the controls has been lost.

    Args:
        op: The operator to apply under the control.
        n_controls: The number of qubits to be used as control.

    Returns:
        The controlled operator.
    """
    ...
38
+
39
+
40
@wraps(stmts.Identity)
def identity(*, sites: int) -> types.Op:
    """Wrapper for ``stmts.Identity`` acting on ``sites`` sites."""
    ...


@wraps(stmts.Rot)
def rot(axis: types.Op, angle: float) -> types.Op:
    """Wrapper for ``stmts.Rot``: rotation about ``axis`` by ``angle`` (used by ``rx``/``ry``/``rz``)."""
    ...


@wraps(stmts.ShiftOp)
def shift(theta: float) -> types.Op:
    """Wrapper for ``stmts.ShiftOp`` parameterized by ``theta``."""
    ...


@wraps(stmts.PhaseOp)
def phase(theta: float) -> types.Op:
    """Wrapper for ``stmts.PhaseOp`` parameterized by ``theta``."""
    ...
54
+
55
+
56
# Parameter-free operator stubs (plus U3 and PauliString), each bound to its
# statement via `wraps`; bodies are placeholders.


@wraps(stmts.X)
def x() -> types.Op:
    """Wrapper for ``stmts.X``."""
    ...


@wraps(stmts.Y)
def y() -> types.Op:
    """Wrapper for ``stmts.Y``."""
    ...


@wraps(stmts.Z)
def z() -> types.Op:
    """Wrapper for ``stmts.Z``."""
    ...


@wraps(stmts.H)
def h() -> types.Op:
    """Wrapper for ``stmts.H``."""
    ...


@wraps(stmts.S)
def s() -> types.Op:
    """Wrapper for ``stmts.S``."""
    ...


@wraps(stmts.T)
def t() -> types.Op:
    """Wrapper for ``stmts.T``."""
    ...


@wraps(stmts.P0)
def p0() -> types.Op:
    """Wrapper for ``stmts.P0``."""
    ...


@wraps(stmts.P1)
def p1() -> types.Op:
    """Wrapper for ``stmts.P1``."""
    ...


@wraps(stmts.Sn)
def spin_n() -> types.Op:
    """Wrapper for ``stmts.Sn``."""
    ...


@wraps(stmts.Sp)
def spin_p() -> types.Op:
    """Wrapper for ``stmts.Sp``."""
    ...


@wraps(stmts.U3)
def u(theta: float, phi: float, lam: float) -> types.Op:
    """Wrapper for ``stmts.U3`` with angles ``theta``, ``phi`` and ``lam``."""
    ...


@wraps(stmts.PauliString)
def pauli_string(*, string: str) -> types.Op:
    """Wrapper for ``stmts.PauliString`` built from ``string``."""
    ...
@@ -0,0 +1,62 @@
1
+ from kirin import ir
2
+ from kirin.prelude import structural_no_opt
3
+
4
+ from . import types
5
+ from ._dialect import dialect
6
+ from ._wrapper import h, x, y, z, rot, phase, control
7
+
8
+
9
@ir.dialect_group(structural_no_opt.add(dialect))
def op(self):
    """Dialect group (structural_no_opt + squin op) for the stdlib kernels below.

    The returned pass does nothing, so methods compiled with this group are
    left as lowered.
    """

    def _no_op_pass(method):
        # Intentionally empty: these small stdlib kernels need no passes.
        return None

    return _no_op_pass
15
+
16
+
17
@op
def rx(theta: float) -> types.Op:
    """Rotation X gate: ``rot(x(), theta)``."""
    return rot(x(), theta)


@op
def ry(theta: float) -> types.Op:
    """Rotation Y gate: ``rot(y(), theta)``."""
    return rot(y(), theta)


@op
def rz(theta: float) -> types.Op:
    """Rotation Z gate: ``rot(z(), theta)``."""
    return rot(z(), theta)
33
+
34
+
35
@op
def cx() -> types.Op:
    """Controlled X gate (one control qubit)."""
    return control(x(), n_controls=1)


@op
def cy() -> types.Op:
    """Controlled Y gate (one control qubit)."""
    return control(y(), n_controls=1)


@op
def cz() -> types.Op:
    """Controlled Z gate (one control qubit)."""
    return control(z(), n_controls=1)


@op
def ch() -> types.Op:
    """Controlled H gate (one control qubit)."""
    return control(h(), n_controls=1)


@op
def cphase(theta: float) -> types.Op:
    """Controlled phase gate ``phase(theta)`` (one control qubit)."""
    return control(phase(theta), n_controls=1)
@@ -0,0 +1 @@
1
+ from .stim import SquinToStim as SquinToStim
@@ -0,0 +1,68 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin.passes import Fold
4
+ from kirin.rewrite import (
5
+ Walk,
6
+ Chain,
7
+ Fixpoint,
8
+ DeadCodeElimination,
9
+ CommonSubexpressionElimination,
10
+ )
11
+ from kirin.ir.method import Method
12
+ from kirin.passes.abc import Pass
13
+ from kirin.rewrite.abc import RewriteResult
14
+
15
+ from bloqade.squin.rewrite import (
16
+ SquinWireToStim,
17
+ SquinQubitToStim,
18
+ WrapSquinAnalysis,
19
+ SquinMeasureToStim,
20
+ SquinWireIdentityElimination,
21
+ )
22
+ from bloqade.analysis.address import AddressAnalysis
23
+ from bloqade.squin.analysis.nsites import (
24
+ NSitesAnalysis,
25
+ )
26
+
27
+
28
@dataclass
class SquinToStim(Pass):
    """Lower a squin method to stim statements.

    Steps, in order:
      1. constant folding (``Fold``);
      2. address and n-sites analyses, whose results are attached as hints;
      3. one walk applying the hint-wrapping and squin -> stim rewrite rules;
      4. dead-code elimination + CSE, repeated to a fixpoint.

    Returns the joined ``RewriteResult`` of all steps.
    """

    def unsafe_run(self, mt: Method) -> RewriteResult:
        # 1. propagate constants
        result = Fold(mt.dialects)(mt)

        # 2. analysis results that the rewrites below consume as hints
        address_frame, _ = AddressAnalysis(mt.dialects).run_analysis(mt)
        sites_frame, _ = NSitesAnalysis(mt.dialects).run_analysis(mt)

        # 3. wrapping the hints and lowering to stim fit in a single walk;
        #    logic is shared across rules and may be split out further later
        to_stim = Walk(
            Chain(
                WrapSquinAnalysis(
                    address_analysis=address_frame.entries,
                    op_site_analysis=sites_frame.entries,
                ),
                SquinQubitToStim(),
                SquinWireToStim(),
                SquinMeasureToStim(),
                SquinWireIdentityElimination(),
            )
        )
        result = to_stim.rewrite(mt.code).join(result)

        # 4. clean up whatever the lowering left behind
        cleanup = Fixpoint(
            Walk(Chain(DeadCodeElimination(), CommonSubexpressionElimination()))
        )
        result = cleanup.rewrite(mt.code).join(result)

        return result
@@ -0,0 +1,11 @@
1
+ from .wire_to_stim import SquinWireToStim as SquinWireToStim
2
+ from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
3
+ from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
4
+ from .wrap_analysis import (
5
+ SitesAttribute as SitesAttribute,
6
+ AddressAttribute as AddressAttribute,
7
+ WrapSquinAnalysis as WrapSquinAnalysis,
8
+ )
9
+ from .wire_identity_elimination import (
10
+ SquinWireIdentityElimination as SquinWireIdentityElimination,
11
+ )
@@ -0,0 +1,84 @@
1
+ from kirin import ir
2
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
3
+
4
+ from bloqade import stim
5
+ from bloqade.squin import op, qubit
6
+ from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
7
+ from bloqade.squin.rewrite.stim_rewrite_util import (
8
+ SQUIN_STIM_GATE_MAPPING,
9
+ rewrite_Control,
10
+ insert_qubit_idx_from_address,
11
+ )
12
+
13
+
14
+ class SquinQubitToStim(RewriteRule):
15
+
16
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
17
+
18
+ match node:
19
+ case qubit.Apply() | qubit.Broadcast():
20
+ return self.rewrite_Apply_and_Broadcast(node)
21
+ case qubit.Reset():
22
+ return self.rewrite_Reset(node)
23
+ case _:
24
+ return RewriteResult()
25
+
26
+ def rewrite_Apply_and_Broadcast(
27
+ self, stmt: qubit.Apply | qubit.Broadcast
28
+ ) -> RewriteResult:
29
+ """
30
+ Rewrite Apply and Broadcast nodes to their stim equivalent statements.
31
+ """
32
+
33
+ # this is an SSAValue, need it to be the actual operator
34
+ applied_op = stmt.operator.owner
35
+ assert isinstance(applied_op, op.stmts.Operator)
36
+
37
+ if isinstance(applied_op, op.stmts.Control):
38
+ return rewrite_Control(stmt)
39
+
40
+ # need to handle Control through separate means
41
+ # but we can handle X, Y, Z, H, and S here just fine
42
+ stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op))
43
+ if stim_1q_op is None:
44
+ return RewriteResult()
45
+
46
+ address_attr = stmt.qubits.hints.get("address")
47
+ if address_attr is None:
48
+ return RewriteResult()
49
+
50
+ assert isinstance(address_attr, AddressAttribute)
51
+ qubit_idx_ssas = insert_qubit_idx_from_address(
52
+ address=address_attr, stmt_to_insert_before=stmt
53
+ )
54
+
55
+ if qubit_idx_ssas is None:
56
+ return RewriteResult()
57
+
58
+ stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
59
+ stmt.replace_by(stim_1q_stmt)
60
+
61
+ return RewriteResult(has_done_something=True)
62
+
63
+ def rewrite_Reset(self, reset_stmt: qubit.Reset) -> RewriteResult:
64
+ qubit_ilist_ssa = reset_stmt.qubits
65
+ # qubits are in an ilist which makes up an AddressTuple
66
+ address_attr = qubit_ilist_ssa.hints.get("address")
67
+ if address_attr is None:
68
+ return RewriteResult()
69
+
70
+ assert isinstance(address_attr, AddressAttribute)
71
+ qubit_idx_ssas = insert_qubit_idx_from_address(
72
+ address=address_attr, stmt_to_insert_before=reset_stmt
73
+ )
74
+
75
+ if qubit_idx_ssas is None:
76
+ return RewriteResult()
77
+
78
+ stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas)
79
+ reset_stmt.replace_by(stim_rz_stmt)
80
+
81
+ return RewriteResult(has_done_something=True)
82
+
83
+
84
+ # Rewrites for measure statements live in a separate rule (SquinMeasureToStim in squin_measure.py); this rule only dispatches Apply/Broadcast/Reset.