bloqade-circuit 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic. Click here for more details.
- bloqade/analysis/address/impls.py +14 -0
- bloqade/noise/fidelity.py +3 -3
- bloqade/noise/native/_dialect.py +1 -1
- bloqade/noise/native/_wrappers.py +35 -6
- bloqade/noise/native/stmts.py +1 -1
- bloqade/pyqrack/device.py +1 -3
- bloqade/pyqrack/qasm2/core.py +4 -1
- bloqade/pyqrack/squin/qubit.py +16 -9
- bloqade/pyqrack/squin/wire.py +22 -4
- bloqade/pyqrack/task.py +13 -5
- bloqade/qasm2/__init__.py +1 -0
- bloqade/qasm2/_qasm_loading.py +151 -0
- bloqade/qasm2/dialects/core/__init__.py +9 -1
- bloqade/qasm2/dialects/expr/__init__.py +18 -1
- bloqade/qasm2/dialects/noise.py +33 -1
- bloqade/qasm2/dialects/uop/__init__.py +39 -3
- bloqade/qasm2/dialects/uop/schedule.py +1 -1
- bloqade/qasm2/emit/impls/__init__.py +1 -0
- bloqade/qasm2/emit/impls/noise_native.py +89 -0
- bloqade/qasm2/emit/main.py +21 -0
- bloqade/qasm2/emit/target.py +20 -5
- bloqade/qasm2/groups.py +2 -0
- bloqade/qasm2/parse/__init__.py +7 -4
- bloqade/qasm2/parse/lowering.py +20 -130
- bloqade/qasm2/parse/qasm2.lark +1 -1
- bloqade/qasm2/passes/__init__.py +1 -0
- bloqade/qasm2/passes/fold.py +6 -0
- bloqade/qasm2/passes/noise.py +22 -2
- bloqade/qasm2/passes/parallel.py +9 -0
- bloqade/qasm2/passes/unroll_if.py +25 -0
- bloqade/qasm2/rewrite/__init__.py +1 -0
- bloqade/qasm2/rewrite/desugar.py +3 -2
- bloqade/qasm2/rewrite/heuristic_noise.py +1 -9
- bloqade/qasm2/rewrite/native_gates.py +67 -4
- bloqade/qasm2/rewrite/split_ifs.py +66 -0
- bloqade/squin/analysis/nsites/__init__.py +1 -0
- bloqade/squin/analysis/nsites/impls.py +25 -1
- bloqade/squin/noise/__init__.py +7 -26
- bloqade/squin/noise/_wrapper.py +25 -0
- bloqade/squin/op/__init__.py +33 -159
- bloqade/squin/op/_wrapper.py +101 -0
- bloqade/squin/op/stdlib.py +62 -0
- bloqade/squin/passes/__init__.py +1 -0
- bloqade/squin/passes/stim.py +68 -0
- bloqade/squin/rewrite/__init__.py +11 -0
- bloqade/squin/rewrite/qubit_to_stim.py +84 -0
- bloqade/squin/rewrite/squin_measure.py +98 -0
- bloqade/squin/rewrite/stim_rewrite_util.py +158 -0
- bloqade/squin/rewrite/wire_identity_elimination.py +24 -0
- bloqade/squin/rewrite/wire_to_stim.py +73 -0
- bloqade/squin/rewrite/wrap_analysis.py +72 -0
- bloqade/squin/wire.py +1 -13
- bloqade/stim/__init__.py +39 -5
- bloqade/stim/_wrappers.py +14 -12
- bloqade/stim/dialects/__init__.py +1 -5
- bloqade/stim/dialects/{aux → auxiliary}/__init__.py +12 -1
- bloqade/stim/dialects/{aux → auxiliary}/emit.py +1 -1
- bloqade/stim/dialects/collapse/__init__.py +13 -2
- bloqade/stim/dialects/collapse/{emit.py → emit_str.py} +1 -1
- bloqade/stim/dialects/collapse/stmts/pp_measure.py +1 -1
- bloqade/stim/dialects/gate/__init__.py +16 -1
- bloqade/stim/dialects/gate/emit.py +1 -1
- bloqade/stim/dialects/gate/stmts/base.py +1 -1
- bloqade/stim/dialects/gate/stmts/pp.py +1 -1
- bloqade/stim/dialects/noise/emit.py +1 -1
- bloqade/stim/emit/__init__.py +1 -1
- bloqade/stim/groups.py +4 -2
- {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/METADATA +3 -3
- {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/RECORD +79 -63
- /bloqade/stim/dialects/{aux → auxiliary}/_dialect.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/interp.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/lowering.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/stmts/__init__.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/stmts/annotate.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/stmts/const.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/types.py +0 -0
- /bloqade/stim/emit/{stim.py → stim_str.py} +0 -0
- {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.dialects import scf, func
|
|
3
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
4
|
+
|
|
5
|
+
from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
|
|
6
|
+
from ..dialects.core.stmts import Reset, Measure
|
|
7
|
+
|
|
8
|
+
# TODO: unify with PR #248
# Gate, measurement and reset statement classes; all of them are part of
# `DontLiftType` below.
AllowedThenType = SingleQubitGate | TwoQubitCtrlGate | Measure | Reset

# Statement classes that `LiftThenBody` leaves inside the `then` body:
# everything in `AllowedThenType` plus `scf.Yield`, `func.Return` and
# `func.Invoke`.
DontLiftType = AllowedThenType | scf.Yield | func.Return | func.Invoke
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class LiftThenBody(RewriteRule):
    """Hoist statements out of the `then` body of an `scf.IfElse`.

    Every statement of the `then` body that is not an instance of
    `DontLiftType` is detached and re-inserted directly in front of the
    `scf.IfElse`, keeping the statements' relative order.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Lift the non-`DontLiftType` statements of `node`'s `then` body.

        Returns an empty `RewriteResult` when `node` is not an `scf.IfElse`
        or when nothing needs to be lifted.
        """
        if not isinstance(node, scf.IfElse):
            return RewriteResult()

        # The list is built completely before any statement is moved.
        to_hoist = []
        for candidate in node.then_body.stmts():
            if isinstance(candidate, DontLiftType):
                continue
            to_hoist.append(candidate)

        if not to_hoist:
            return RewriteResult()

        for hoisted in to_hoist:
            hoisted.detach()
            hoisted.insert_before(node)

        return RewriteResult(has_done_something=True)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class SplitIfStmts(RewriteRule):
    """Split the `then` body of an `scf.IfElse` into one `scf.IfElse` per statement.

    For every non-terminator statement of the `then` body a new `scf.IfElse`
    with the same condition is inserted before the original node; the
    original node is deleted afterwards.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Replace `node` by a sequence of single-statement `scf.IfElse` nodes.

        Returns an empty `RewriteResult` when `node` is not an `scf.IfElse`
        or when its `then` body holds exactly one statement before the
        terminator.
        """
        if not isinstance(node, scf.IfElse):
            return RewriteResult()

        # The last statement of the body is its terminator (yield or return).
        *body_stmts, terminator = node.then_body.stmts()

        if len(body_stmts) == 1:
            return RewriteResult()

        # Every new block gets a fresh terminator of the same kind as the
        # original one.
        make_terminator = (
            scf.Yield if isinstance(terminator, scf.Yield) else func.Return
        )

        for body_stmt in body_stmts:
            body_stmt.detach()

            single_block = ir.Block(
                (body_stmt, make_terminator()), argtypes=(node.cond.type,)
            )

            # Each new `if` receives its own detached copy of the else region.
            cloned_else = node.else_body.clone()
            cloned_else.detach()

            scf.IfElse(
                cond=node.cond,
                then_body=ir.Region(single_block),
                else_body=cloned_else,
            ).insert_before(node)

        node.delete()

        return RewriteResult(has_done_something=True)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from kirin import interp
|
|
2
2
|
|
|
3
|
-
from bloqade.squin import op
|
|
3
|
+
from bloqade.squin import op, wire
|
|
4
4
|
|
|
5
5
|
from .lattice import (
|
|
6
6
|
NoSites,
|
|
@@ -9,6 +9,30 @@ from .lattice import (
|
|
|
9
9
|
from .analysis import NSitesAnalysis
|
|
10
10
|
|
|
11
11
|
|
|
12
|
+
@wire.dialect.register(key="op.nsites")
class SquinWire(interp.MethodTable):
    """Method table registered on the squin `wire` dialect under `op.nsites`."""

    @interp.impl(wire.Apply)
    @interp.impl(wire.Broadcast)
    def apply(
        self,
        interp: NSitesAnalysis,
        frame: interp.Frame,
        stmt: wire.Apply | wire.Broadcast,
    ):
        """Return the frame entry of every value in `stmt.inputs`, in order."""
        looked_up = []
        for operand in stmt.inputs:
            looked_up.append(frame.get(operand))
        return tuple(looked_up)

    @interp.impl(wire.MeasureAndReset)
    def measure_and_reset(
        self, interp: NSitesAnalysis, frame: interp.Frame, stmt: wire.MeasureAndReset
    ):
        """Return a pair of `NoSites` values, one per result of the statement."""
        # MeasureAndReset produces both a new wire
        # and an integer which don't have any sites at all
        wire_sites = NoSites()
        int_sites = NoSites()
        return (wire_sites, int_sites)
|
|
34
|
+
|
|
35
|
+
|
|
12
36
|
@op.dialect.register(key="op.nsites")
|
|
13
37
|
class SquinOp(interp.MethodTable):
|
|
14
38
|
|
bloqade/squin/noise/__init__.py
CHANGED
|
@@ -1,27 +1,8 @@
|
|
|
1
|
-
# Put all the proper wrappers here
|
|
2
|
-
|
|
3
|
-
from kirin.lowering import wraps as _wraps
|
|
4
|
-
|
|
5
|
-
from bloqade.squin.op.types import Op
|
|
6
|
-
|
|
7
1
|
from . import stmts as stmts
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def pp_error(op: Op, p: float) -> Op: ...
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
@_wraps(stmts.Depolarize)
|
|
19
|
-
def depolarize(n_qubits: int, p: float) -> Op: ...
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
@_wraps(stmts.PauliChannel)
|
|
23
|
-
def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ...
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
@_wraps(stmts.QubitLoss)
|
|
27
|
-
def qubit_loss(p: float) -> Op: ...
|
|
2
|
+
from ._dialect import dialect as dialect

# Re-export every wrapper defined in `_wrapper`, including `pauli_error`.
from ._wrapper import (
    pp_error as pp_error,
    depolarize as depolarize,
    qubit_loss as qubit_loss,
    pauli_error as pauli_error,
    pauli_channel as pauli_channel,
)
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from kirin.lowering import wraps
|
|
2
|
+
|
|
3
|
+
from bloqade.squin.op.types import Op
|
|
4
|
+
|
|
5
|
+
from . import stmts
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@wraps(stmts.PauliError)
def pauli_error(basis: Op, p: float) -> Op:
    """Wrapper stub bound to `stmts.PauliError` via `wraps`.

    Takes an operator `basis` and a float `p` (presumably a probability;
    confirm against `stmts.PauliError`). The `...` body is a placeholder.
    """
    ...
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@wraps(stmts.PPError)
def pp_error(op: Op, p: float) -> Op:
    """Wrapper stub bound to `stmts.PPError` via `wraps`.

    Takes an operator `op` and a float `p` (presumably a probability;
    confirm against `stmts.PPError`). The `...` body is a placeholder.
    """
    ...
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@wraps(stmts.Depolarize)
def depolarize(n_qubits: int, p: float) -> Op:
    """Wrapper stub bound to `stmts.Depolarize` via `wraps`.

    Takes an integer `n_qubits` and a float `p` (presumably a probability;
    confirm against `stmts.Depolarize`). The `...` body is a placeholder.
    """
    ...
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@wraps(stmts.PauliChannel)
def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op:
    """Wrapper stub bound to `stmts.PauliChannel` via `wraps`.

    Takes an integer `n_qubits` and a tuple of floats `params`; the expected
    length and meaning of `params` are defined by `stmts.PauliChannel`.
    The `...` body is a placeholder.
    """
    ...
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@wraps(stmts.QubitLoss)
def qubit_loss(p: float) -> Op:
    """Wrapper stub bound to `stmts.QubitLoss` via `wraps`.

    Takes a float `p` (presumably a probability; confirm against
    `stmts.QubitLoss`). The `...` body is a placeholder.
    """
    ...
|
bloqade/squin/op/__init__.py
CHANGED
|
@@ -1,162 +1,36 @@
|
|
|
1
|
-
from kirin import ir as _ir
|
|
2
|
-
from kirin.prelude import structural_no_opt as _structural_no_opt
|
|
3
|
-
from kirin.lowering import wraps as _wraps
|
|
4
|
-
|
|
5
1
|
from . import stmts as stmts, types as types, rewrite as rewrite
|
|
2
|
+
from .stdlib import (
|
|
3
|
+
ch as ch,
|
|
4
|
+
cx as cx,
|
|
5
|
+
cy as cy,
|
|
6
|
+
cz as cz,
|
|
7
|
+
rx as rx,
|
|
8
|
+
ry as ry,
|
|
9
|
+
rz as rz,
|
|
10
|
+
cphase as cphase,
|
|
11
|
+
)
|
|
6
12
|
from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary
|
|
7
13
|
from ._dialect import dialect as dialect
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
Note, that when considering atom loss, the operator will not be applied if
|
|
32
|
-
any of the controls has been lost.
|
|
33
|
-
|
|
34
|
-
Args:
|
|
35
|
-
operator: The operator to apply under the control.
|
|
36
|
-
n_controls: The number qubits to be used as control.
|
|
37
|
-
|
|
38
|
-
Returns:
|
|
39
|
-
Operator
|
|
40
|
-
"""
|
|
41
|
-
...
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
@_wraps(stmts.Identity)
|
|
45
|
-
def identity(*, sites: int) -> types.Op: ...
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
@_wraps(stmts.Rot)
|
|
49
|
-
def rot(axis: types.Op, angle: float) -> types.Op: ...
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
@_wraps(stmts.ShiftOp)
|
|
53
|
-
def shift(theta: float) -> types.Op: ...
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
@_wraps(stmts.PhaseOp)
|
|
57
|
-
def phase(theta: float) -> types.Op: ...
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
@_wraps(stmts.X)
|
|
61
|
-
def x() -> types.Op: ...
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
@_wraps(stmts.Y)
|
|
65
|
-
def y() -> types.Op: ...
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
@_wraps(stmts.Z)
|
|
69
|
-
def z() -> types.Op: ...
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
@_wraps(stmts.H)
|
|
73
|
-
def h() -> types.Op: ...
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
@_wraps(stmts.S)
|
|
77
|
-
def s() -> types.Op: ...
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
@_wraps(stmts.T)
|
|
81
|
-
def t() -> types.Op: ...
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
@_wraps(stmts.P0)
|
|
85
|
-
def p0() -> types.Op: ...
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
@_wraps(stmts.P1)
|
|
89
|
-
def p1() -> types.Op: ...
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
@_wraps(stmts.Sn)
|
|
93
|
-
def spin_n() -> types.Op: ...
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
@_wraps(stmts.Sp)
|
|
97
|
-
def spin_p() -> types.Op: ...
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
@_wraps(stmts.U3)
|
|
101
|
-
def u(theta: float, phi: float, lam: float) -> types.Op: ...
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
@_wraps(stmts.PauliString)
|
|
105
|
-
def pauli_string(*, string: str) -> types.Op: ...
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
# stdlibs
|
|
109
|
-
@_ir.dialect_group(_structural_no_opt.add(dialect))
|
|
110
|
-
def op(self):
|
|
111
|
-
def run_pass(method):
|
|
112
|
-
pass
|
|
113
|
-
|
|
114
|
-
return run_pass
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
@op
|
|
118
|
-
def rx(theta: float) -> types.Op:
|
|
119
|
-
"""Rotation X gate."""
|
|
120
|
-
return rot(x(), theta)
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
@op
|
|
124
|
-
def ry(theta: float) -> types.Op:
|
|
125
|
-
"""Rotation Y gate."""
|
|
126
|
-
return rot(y(), theta)
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
@op
|
|
130
|
-
def rz(theta: float) -> types.Op:
|
|
131
|
-
"""Rotation Z gate."""
|
|
132
|
-
return rot(z(), theta)
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
@op
|
|
136
|
-
def cx() -> types.Op:
|
|
137
|
-
"""Controlled X gate."""
|
|
138
|
-
return control(x(), n_controls=1)
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
@op
|
|
142
|
-
def cy() -> types.Op:
|
|
143
|
-
"""Controlled Y gate."""
|
|
144
|
-
return control(y(), n_controls=1)
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
@op
|
|
148
|
-
def cz() -> types.Op:
|
|
149
|
-
"""Control Z gate."""
|
|
150
|
-
return control(z(), n_controls=1)
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
@op
|
|
154
|
-
def ch() -> types.Op:
|
|
155
|
-
"""Control H gate."""
|
|
156
|
-
return control(h(), n_controls=1)
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
@op
|
|
160
|
-
def cphase(theta: float) -> types.Op:
|
|
161
|
-
"""Control Phase gate."""
|
|
162
|
-
return control(phase(theta), n_controls=1)
|
|
14
|
+
from ._wrapper import (
|
|
15
|
+
h as h,
|
|
16
|
+
s as s,
|
|
17
|
+
t as t,
|
|
18
|
+
u as u,
|
|
19
|
+
x as x,
|
|
20
|
+
y as y,
|
|
21
|
+
z as z,
|
|
22
|
+
p0 as p0,
|
|
23
|
+
p1 as p1,
|
|
24
|
+
rot as rot,
|
|
25
|
+
kron as kron,
|
|
26
|
+
mult as mult,
|
|
27
|
+
phase as phase,
|
|
28
|
+
scale as scale,
|
|
29
|
+
shift as shift,
|
|
30
|
+
spin_n as spin_n,
|
|
31
|
+
spin_p as spin_p,
|
|
32
|
+
adjoint as adjoint,
|
|
33
|
+
control as control,
|
|
34
|
+
identity as identity,
|
|
35
|
+
pauli_string as pauli_string,
|
|
36
|
+
)
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
from kirin.lowering import wraps
|
|
2
|
+
|
|
3
|
+
from . import stmts, types
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@wraps(stmts.Kron)
def kron(lhs: types.Op, rhs: types.Op) -> types.Op:
    """Wrapper stub bound to `stmts.Kron` via `wraps`; takes two operators."""
    ...
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@wraps(stmts.Mult)
def mult(lhs: types.Op, rhs: types.Op) -> types.Op:
    """Wrapper stub bound to `stmts.Mult` via `wraps`; takes two operators."""
    ...
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@wraps(stmts.Scale)
def scale(op: types.Op, factor: complex) -> types.Op:
    """Wrapper stub bound to `stmts.Scale` via `wraps`.

    Takes an operator `op` and a complex `factor`.
    """
    ...
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@wraps(stmts.Adjoint)
def adjoint(op: types.Op) -> types.Op:
    """Wrapper stub bound to `stmts.Adjoint` via `wraps`; takes one operator."""
    ...
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@wraps(stmts.Control)
def control(op: types.Op, *, n_controls: int) -> types.Op:
    """
    Create a controlled operator.

    Note that when considering atom loss, the operator will not be applied if
    any of the controls has been lost.

    Args:
        op: The operator to apply under the control.
        n_controls: The number of qubits to be used as control.

    Returns:
        Operator
    """
    ...
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@wraps(stmts.Identity)
def identity(*, sites: int) -> types.Op:
    """Wrapper stub bound to `stmts.Identity` via `wraps`.

    `sites` is keyword-only and an integer.
    """
    ...
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@wraps(stmts.Rot)
def rot(axis: types.Op, angle: float) -> types.Op:
    """Wrapper stub bound to `stmts.Rot` via `wraps`.

    Takes an operator `axis` and a float `angle`.
    """
    ...
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@wraps(stmts.ShiftOp)
def shift(theta: float) -> types.Op:
    """Wrapper stub bound to `stmts.ShiftOp` via `wraps`; takes a float `theta`."""
    ...
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@wraps(stmts.PhaseOp)
def phase(theta: float) -> types.Op:
    """Wrapper stub bound to `stmts.PhaseOp` via `wraps`; takes a float `theta`."""
    ...
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@wraps(stmts.X)
def x() -> types.Op:
    """Wrapper stub bound to `stmts.X` via `wraps`; takes no arguments."""
    ...
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@wraps(stmts.Y)
def y() -> types.Op:
    """Wrapper stub bound to `stmts.Y` via `wraps`; takes no arguments."""
    ...
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@wraps(stmts.Z)
def z() -> types.Op:
    """Wrapper stub bound to `stmts.Z` via `wraps`; takes no arguments."""
    ...
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@wraps(stmts.H)
def h() -> types.Op:
    """Wrapper stub bound to `stmts.H` via `wraps`; takes no arguments."""
    ...
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@wraps(stmts.S)
def s() -> types.Op:
    """Wrapper stub bound to `stmts.S` via `wraps`; takes no arguments."""
    ...
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@wraps(stmts.T)
def t() -> types.Op:
    """Wrapper stub bound to `stmts.T` via `wraps`; takes no arguments."""
    ...
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@wraps(stmts.P0)
def p0() -> types.Op:
    """Wrapper stub bound to `stmts.P0` via `wraps`; takes no arguments."""
    ...
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@wraps(stmts.P1)
def p1() -> types.Op:
    """Wrapper stub bound to `stmts.P1` via `wraps`; takes no arguments."""
    ...
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@wraps(stmts.Sn)
def spin_n() -> types.Op:
    """Wrapper stub bound to `stmts.Sn` via `wraps`; takes no arguments."""
    ...
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@wraps(stmts.Sp)
def spin_p() -> types.Op:
    """Wrapper stub bound to `stmts.Sp` via `wraps`; takes no arguments."""
    ...
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@wraps(stmts.U3)
def u(theta: float, phi: float, lam: float) -> types.Op:
    """Wrapper stub bound to `stmts.U3` via `wraps`.

    Takes three floats `theta`, `phi` and `lam`.
    """
    ...
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@wraps(stmts.PauliString)
def pauli_string(*, string: str) -> types.Op:
    """Wrapper stub bound to `stmts.PauliString` via `wraps`.

    `string` is keyword-only and a `str`; its accepted format is defined by
    `stmts.PauliString`.
    """
    ...
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.prelude import structural_no_opt
|
|
3
|
+
|
|
4
|
+
from . import types
|
|
5
|
+
from ._dialect import dialect
|
|
6
|
+
from ._wrapper import h, x, y, z, rot, phase, control
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Dialect group built from `structural_no_opt` with the squin op `dialect`
# added; it decorates the operator definitions below.
@ir.dialect_group(structural_no_opt.add(dialect))
def op(self):
    def run_pass(method):
        # No-op: the returned pass does nothing with `method`.
        pass

    return run_pass
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@op
def rx(theta: float) -> types.Op:
    """Rotation X gate: `rot` with axis `x()` and angle `theta`."""
    return rot(x(), theta)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@op
def ry(theta: float) -> types.Op:
    """Rotation Y gate: `rot` with axis `y()` and angle `theta`."""
    return rot(y(), theta)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@op
def rz(theta: float) -> types.Op:
    """Rotation Z gate: `rot` with axis `z()` and angle `theta`."""
    return rot(z(), theta)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@op
def cx() -> types.Op:
    """Controlled X gate: `x()` under `control` with one control qubit."""
    return control(x(), n_controls=1)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@op
def cy() -> types.Op:
    """Controlled Y gate: `y()` under `control` with one control qubit."""
    return control(y(), n_controls=1)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@op
def cz() -> types.Op:
    """Controlled Z gate: `z()` under `control` with one control qubit."""
    return control(z(), n_controls=1)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@op
def ch() -> types.Op:
    """Controlled H gate: `h()` under `control` with one control qubit."""
    return control(h(), n_controls=1)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@op
def cphase(theta: float) -> types.Op:
    """Controlled phase gate: `phase(theta)` under `control` with one control qubit."""
    return control(phase(theta), n_controls=1)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .stim import SquinToStim as SquinToStim
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kirin.passes import Fold
|
|
4
|
+
from kirin.rewrite import (
|
|
5
|
+
Walk,
|
|
6
|
+
Chain,
|
|
7
|
+
Fixpoint,
|
|
8
|
+
DeadCodeElimination,
|
|
9
|
+
CommonSubexpressionElimination,
|
|
10
|
+
)
|
|
11
|
+
from kirin.ir.method import Method
|
|
12
|
+
from kirin.passes.abc import Pass
|
|
13
|
+
from kirin.rewrite.abc import RewriteResult
|
|
14
|
+
|
|
15
|
+
from bloqade.squin.rewrite import (
|
|
16
|
+
SquinWireToStim,
|
|
17
|
+
SquinQubitToStim,
|
|
18
|
+
WrapSquinAnalysis,
|
|
19
|
+
SquinMeasureToStim,
|
|
20
|
+
SquinWireIdentityElimination,
|
|
21
|
+
)
|
|
22
|
+
from bloqade.analysis.address import AddressAnalysis
|
|
23
|
+
from bloqade.squin.analysis.nsites import (
|
|
24
|
+
NSitesAnalysis,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class SquinToStim(Pass):
    """Pass that folds constants, runs the squin-to-stim rewrite rules, then cleans up."""

    def unsafe_run(self, mt: Method) -> RewriteResult:
        """Run the pass on `mt` and return the joined result of all stages."""
        # propagate constants
        result = Fold(mt.dialects)(mt)

        # Get necessary analysis results to plug into hints
        address_frame, _ = AddressAnalysis(mt.dialects).run_analysis(mt)
        sites_frame, _ = NSitesAnalysis(mt.dialects).run_analysis(mt)

        # Wrap Rewrite + SquinToStim can happen w/ standard walk
        squin_to_stim_rules = Chain(
            WrapSquinAnalysis(
                address_analysis=address_frame.entries,
                op_site_analysis=sites_frame.entries,
            ),
            SquinQubitToStim(),
            SquinWireToStim(),
            # reduce duplicated logic, can split out even more rules later
            SquinMeasureToStim(),
            SquinWireIdentityElimination(),
        )
        result = Walk(squin_to_stim_rules).rewrite(mt.code).join(result)

        # Repeat dead-code and common-subexpression elimination via `Fixpoint`.
        cleanup_rules = Chain(DeadCodeElimination(), CommonSubexpressionElimination())
        result = Fixpoint(Walk(cleanup_rules)).rewrite(mt.code).join(result)

        return result
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from .wire_to_stim import SquinWireToStim as SquinWireToStim
|
|
2
|
+
from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
|
|
3
|
+
from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
|
|
4
|
+
from .wrap_analysis import (
|
|
5
|
+
SitesAttribute as SitesAttribute,
|
|
6
|
+
AddressAttribute as AddressAttribute,
|
|
7
|
+
WrapSquinAnalysis as WrapSquinAnalysis,
|
|
8
|
+
)
|
|
9
|
+
from .wire_identity_elimination import (
|
|
10
|
+
SquinWireIdentityElimination as SquinWireIdentityElimination,
|
|
11
|
+
)
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
3
|
+
|
|
4
|
+
from bloqade import stim
|
|
5
|
+
from bloqade.squin import op, qubit
|
|
6
|
+
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
|
|
7
|
+
from bloqade.squin.rewrite.stim_rewrite_util import (
|
|
8
|
+
SQUIN_STIM_GATE_MAPPING,
|
|
9
|
+
rewrite_Control,
|
|
10
|
+
insert_qubit_idx_from_address,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SquinQubitToStim(RewriteRule):
    """Rewrite squin `qubit.Apply`, `qubit.Broadcast` and `qubit.Reset` to stim statements."""

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Dispatch on the statement class; other statements are left untouched."""
        if isinstance(node, (qubit.Apply, qubit.Broadcast)):
            return self.rewrite_Apply_and_Broadcast(node)
        if isinstance(node, qubit.Reset):
            return self.rewrite_Reset(node)
        return RewriteResult()

    @staticmethod
    def _qubit_indices_for(qubits_ssa, anchor_stmt):
        """Look up the "address" hint on `qubits_ssa` and build index values.

        Returns `None` when the hint is missing; otherwise returns whatever
        `insert_qubit_idx_from_address` returns (which may also be `None`).
        """
        address_hint = qubits_ssa.hints.get("address")
        if address_hint is None:
            return None

        assert isinstance(address_hint, AddressAttribute)
        return insert_qubit_idx_from_address(
            address=address_hint, stmt_to_insert_before=anchor_stmt
        )

    def rewrite_Apply_and_Broadcast(
        self, stmt: qubit.Apply | qubit.Broadcast
    ) -> RewriteResult:
        """
        Rewrite Apply and Broadcast nodes to their stim equivalent statements.
        """
        # `stmt.operator` is an SSAValue; its owner is the actual operator.
        operator_stmt = stmt.operator.owner
        assert isinstance(operator_stmt, op.stmts.Operator)

        # Control is handled through separate means.
        if isinstance(operator_stmt, op.stmts.Control):
            return rewrite_Control(stmt)

        # Operators without an entry in the gate mapping are not rewritten.
        stim_gate_cls = SQUIN_STIM_GATE_MAPPING.get(type(operator_stmt))
        if stim_gate_cls is None:
            return RewriteResult()

        target_idxs = self._qubit_indices_for(stmt.qubits, stmt)
        if target_idxs is None:
            return RewriteResult()

        stmt.replace_by(stim_gate_cls(targets=tuple(target_idxs)))

        return RewriteResult(has_done_something=True)

    def rewrite_Reset(self, reset_stmt: qubit.Reset) -> RewriteResult:
        """Replace a `qubit.Reset` by a `stim.collapse.stmts.RZ` on the same qubits."""
        # qubits are in an ilist which makes up an AddressTuple
        target_idxs = self._qubit_indices_for(reset_stmt.qubits, reset_stmt)
        if target_idxs is None:
            return RewriteResult()

        reset_stmt.replace_by(stim.collapse.stmts.RZ(targets=target_idxs))

        return RewriteResult(has_done_something=True)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# put rewrites for measure statements in separate rule, then just have to dispatch
|