bloqade-circuit 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic. Click here for more details.
- bloqade/analysis/address/impls.py +5 -9
- bloqade/analysis/address/lattice.py +1 -1
- bloqade/analysis/fidelity/__init__.py +1 -0
- bloqade/analysis/fidelity/analysis.py +69 -0
- bloqade/device.py +130 -0
- bloqade/noise/__init__.py +2 -1
- bloqade/noise/fidelity.py +51 -0
- bloqade/noise/native/model.py +1 -2
- bloqade/noise/native/rewrite.py +5 -5
- bloqade/noise/native/stmts.py +40 -11
- bloqade/pyqrack/__init__.py +8 -2
- bloqade/pyqrack/base.py +24 -3
- bloqade/pyqrack/device.py +166 -0
- bloqade/pyqrack/noise/native.py +1 -2
- bloqade/pyqrack/qasm2/core.py +31 -15
- bloqade/pyqrack/qasm2/glob.py +28 -0
- bloqade/pyqrack/qasm2/uop.py +9 -1
- bloqade/pyqrack/reg.py +17 -49
- bloqade/pyqrack/squin/__init__.py +0 -0
- bloqade/pyqrack/squin/op.py +154 -0
- bloqade/pyqrack/squin/qubit.py +85 -0
- bloqade/pyqrack/squin/runtime.py +515 -0
- bloqade/pyqrack/squin/wire.py +69 -0
- bloqade/pyqrack/target.py +9 -2
- bloqade/pyqrack/task.py +30 -0
- bloqade/qasm2/_wrappers.py +11 -1
- bloqade/qasm2/dialects/core/stmts.py +15 -4
- bloqade/qasm2/dialects/expr/_emit.py +9 -8
- bloqade/qasm2/emit/base.py +4 -2
- bloqade/qasm2/emit/gate.py +0 -14
- bloqade/qasm2/emit/main.py +19 -15
- bloqade/qasm2/emit/target.py +2 -6
- bloqade/qasm2/glob.py +1 -1
- bloqade/qasm2/parse/lowering.py +124 -1
- bloqade/qasm2/passes/glob.py +3 -3
- bloqade/qasm2/passes/lift_qubits.py +26 -0
- bloqade/qasm2/passes/noise.py +6 -14
- bloqade/qasm2/passes/parallel.py +3 -3
- bloqade/qasm2/passes/py2qasm.py +1 -2
- bloqade/qasm2/passes/qasm2py.py +1 -2
- bloqade/qasm2/rewrite/desugar.py +6 -6
- bloqade/qasm2/rewrite/glob.py +9 -9
- bloqade/qasm2/rewrite/heuristic_noise.py +30 -38
- bloqade/qasm2/rewrite/insert_qubits.py +34 -0
- bloqade/qasm2/rewrite/native_gates.py +54 -55
- bloqade/qasm2/rewrite/parallel_to_uop.py +9 -9
- bloqade/qasm2/rewrite/uop_to_parallel.py +20 -22
- bloqade/qasm2/types.py +3 -6
- bloqade/qbraid/schema.py +10 -12
- bloqade/squin/__init__.py +1 -1
- bloqade/squin/analysis/nsites/analysis.py +4 -6
- bloqade/squin/analysis/nsites/impls.py +2 -6
- bloqade/squin/analysis/schedule.py +1 -1
- bloqade/squin/groups.py +15 -7
- bloqade/squin/noise/__init__.py +27 -0
- bloqade/squin/noise/_dialect.py +3 -0
- bloqade/squin/noise/stmts.py +59 -0
- bloqade/squin/op/__init__.py +35 -5
- bloqade/squin/op/number.py +5 -0
- bloqade/squin/op/rewrite.py +46 -0
- bloqade/squin/op/stmts.py +23 -2
- bloqade/squin/op/types.py +14 -0
- bloqade/squin/qubit.py +79 -11
- bloqade/squin/rewrite/__init__.py +0 -0
- bloqade/squin/rewrite/measure_desugar.py +33 -0
- bloqade/squin/wire.py +31 -2
- bloqade/stim/emit/stim.py +1 -1
- bloqade/task.py +94 -0
- bloqade/visual/animation/base.py +25 -15
- {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/METADATA +8 -2
- {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/RECORD +73 -52
- bloqade/squin/op/complex.py +0 -6
- {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,17 +1,17 @@
|
|
|
1
|
-
from typing import Dict, List, Tuple
|
|
1
|
+
from typing import Dict, List, Tuple, cast
|
|
2
2
|
from dataclasses import field, dataclass
|
|
3
3
|
|
|
4
4
|
from kirin import ir
|
|
5
|
-
from kirin.rewrite import abc as
|
|
6
|
-
from kirin.dialects import
|
|
5
|
+
from kirin.rewrite import abc as rewrite_abc
|
|
6
|
+
from kirin.dialects import ilist
|
|
7
7
|
|
|
8
8
|
from bloqade.noise import native
|
|
9
9
|
from bloqade.analysis import address
|
|
10
|
-
from bloqade.qasm2.dialects import uop,
|
|
10
|
+
from bloqade.qasm2.dialects import uop, glob, parallel
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
@dataclass
|
|
14
|
-
class NoiseRewriteRule(
|
|
14
|
+
class NoiseRewriteRule(rewrite_abc.RewriteRule):
|
|
15
15
|
"""
|
|
16
16
|
NOTE: This pass is not guaranteed to be supported long-term in bloqade. We will be
|
|
17
17
|
moving towards a more general approach to noise modeling in the future.
|
|
@@ -24,12 +24,18 @@ class NoiseRewriteRule(result_abc.RewriteRule):
|
|
|
24
24
|
noise_model: native.MoveNoiseModelABC = field(
|
|
25
25
|
default_factory=native.TwoRowZoneModel
|
|
26
26
|
)
|
|
27
|
-
qubit_ssa_value: Dict[int, ir.SSAValue] = field(default_factory=dict, init=False)
|
|
28
27
|
|
|
29
|
-
def
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
28
|
+
def __post_init__(self):
|
|
29
|
+
self.qubit_ssa_value: Dict[int, ir.SSAValue] = {}
|
|
30
|
+
for ssa_value, addr in self.address_analysis.items():
|
|
31
|
+
if (
|
|
32
|
+
isinstance(addr, address.AddressQubit)
|
|
33
|
+
and ssa_value not in self.qubit_ssa_value
|
|
34
|
+
):
|
|
35
|
+
self.qubit_ssa_value[addr.data] = ssa_value
|
|
36
|
+
|
|
37
|
+
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
|
|
38
|
+
if isinstance(node, uop.SingleQubitGate):
|
|
33
39
|
return self.rewrite_single_qubit_gate(node)
|
|
34
40
|
elif isinstance(node, uop.CZ):
|
|
35
41
|
return self.rewrite_cz_gate(node)
|
|
@@ -40,25 +46,7 @@ class NoiseRewriteRule(result_abc.RewriteRule):
|
|
|
40
46
|
elif isinstance(node, glob.UGate):
|
|
41
47
|
return self.rewrite_global_single_qubit_gate(node)
|
|
42
48
|
else:
|
|
43
|
-
return
|
|
44
|
-
|
|
45
|
-
def rewrite_qreg_new(self, node: core.QRegNew):
|
|
46
|
-
|
|
47
|
-
addr = self.address_analysis[node.result]
|
|
48
|
-
if not isinstance(addr, address.AddressReg):
|
|
49
|
-
return result.RewriteResult()
|
|
50
|
-
|
|
51
|
-
has_done_something = False
|
|
52
|
-
for idx_val, qid in enumerate(addr.data):
|
|
53
|
-
if qid not in self.qubit_ssa_value:
|
|
54
|
-
has_done_something = True
|
|
55
|
-
idx = py.constant.Constant(value=idx_val)
|
|
56
|
-
qubit = core.QRegGet(node.result, idx=idx.result)
|
|
57
|
-
self.qubit_ssa_value[qid] = qubit.result
|
|
58
|
-
qubit.insert_after(node)
|
|
59
|
-
idx.insert_after(node)
|
|
60
|
-
|
|
61
|
-
return result.RewriteResult(has_done_something=has_done_something)
|
|
49
|
+
return rewrite_abc.RewriteResult()
|
|
62
50
|
|
|
63
51
|
def insert_single_qubit_noise(
|
|
64
52
|
self,
|
|
@@ -71,7 +59,7 @@ class NoiseRewriteRule(result_abc.RewriteRule):
|
|
|
71
59
|
)
|
|
72
60
|
native.AtomLossChannel(qargs, prob=probs[3]).insert_before(node)
|
|
73
61
|
|
|
74
|
-
return
|
|
62
|
+
return rewrite_abc.RewriteResult(has_done_something=True)
|
|
75
63
|
|
|
76
64
|
def rewrite_single_qubit_gate(self, node: uop.SingleQubitGate):
|
|
77
65
|
probs = (
|
|
@@ -86,13 +74,13 @@ class NoiseRewriteRule(result_abc.RewriteRule):
|
|
|
86
74
|
def rewrite_global_single_qubit_gate(self, node: glob.UGate):
|
|
87
75
|
addrs = self.address_analysis[node.registers]
|
|
88
76
|
if not isinstance(addrs, address.AddressTuple):
|
|
89
|
-
return
|
|
77
|
+
return rewrite_abc.RewriteResult()
|
|
90
78
|
|
|
91
79
|
qargs = []
|
|
92
80
|
|
|
93
81
|
for addr in addrs.data:
|
|
94
82
|
if not isinstance(addr, address.AddressReg):
|
|
95
|
-
return
|
|
83
|
+
return rewrite_abc.RewriteResult()
|
|
96
84
|
|
|
97
85
|
for qid in addr.data:
|
|
98
86
|
qargs.append(self.qubit_ssa_value[qid])
|
|
@@ -109,10 +97,10 @@ class NoiseRewriteRule(result_abc.RewriteRule):
|
|
|
109
97
|
def rewrite_parallel_single_qubit_gate(self, node: parallel.RZ | parallel.UGate):
|
|
110
98
|
addrs = self.address_analysis[node.qargs]
|
|
111
99
|
if not isinstance(addrs, address.AddressTuple):
|
|
112
|
-
return
|
|
100
|
+
return rewrite_abc.RewriteResult()
|
|
113
101
|
|
|
114
102
|
if not all(isinstance(addr, address.AddressQubit) for addr in addrs.data):
|
|
115
|
-
return
|
|
103
|
+
return rewrite_abc.RewriteResult()
|
|
116
104
|
|
|
117
105
|
probs = (
|
|
118
106
|
self.gate_noise_params.local_px,
|
|
@@ -213,7 +201,7 @@ class NoiseRewriteRule(result_abc.RewriteRule):
|
|
|
213
201
|
new_node.insert_before(node)
|
|
214
202
|
has_done_something = True
|
|
215
203
|
|
|
216
|
-
return
|
|
204
|
+
return rewrite_abc.RewriteResult(has_done_something=has_done_something)
|
|
217
205
|
|
|
218
206
|
def rewrite_parallel_cz_gate(self, node: parallel.CZ):
|
|
219
207
|
ctrls = self.address_analysis[node.ctrls]
|
|
@@ -226,8 +214,12 @@ class NoiseRewriteRule(result_abc.RewriteRule):
|
|
|
226
214
|
and isinstance(qargs, address.AddressTuple)
|
|
227
215
|
and all(isinstance(addr, address.AddressQubit) for addr in qargs.data)
|
|
228
216
|
):
|
|
229
|
-
ctrl_qubits = list(
|
|
230
|
-
|
|
217
|
+
ctrl_qubits = list(
|
|
218
|
+
map(lambda addr: cast(address.AddressQubit, addr).data, ctrls.data)
|
|
219
|
+
)
|
|
220
|
+
qarg_qubits = list(
|
|
221
|
+
map(lambda addr: cast(address.AddressQubit, addr).data, qargs.data)
|
|
222
|
+
)
|
|
231
223
|
rest = sorted(
|
|
232
224
|
set(self.qubit_ssa_value.keys()) - set(ctrl_qubits + qarg_qubits)
|
|
233
225
|
)
|
|
@@ -244,4 +236,4 @@ class NoiseRewriteRule(result_abc.RewriteRule):
|
|
|
244
236
|
new_node.insert_before(node)
|
|
245
237
|
has_done_something = True
|
|
246
238
|
|
|
247
|
-
return
|
|
239
|
+
return rewrite_abc.RewriteResult(has_done_something=has_done_something)
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.rewrite import abc as rewrite_abc
|
|
3
|
+
from kirin.dialects import py
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class InsertGetQubit(rewrite_abc.RewriteRule):
|
|
7
|
+
|
|
8
|
+
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
|
|
9
|
+
from bloqade.qasm2 import core
|
|
10
|
+
|
|
11
|
+
if (
|
|
12
|
+
not isinstance(node, core.QRegNew)
|
|
13
|
+
or not isinstance(n_qubits_stmt := node.n_qubits.owner, py.Constant)
|
|
14
|
+
or not isinstance(n_qubits := n_qubits_stmt.value.unwrap(), int)
|
|
15
|
+
or (block := node.parent_block) is None
|
|
16
|
+
):
|
|
17
|
+
return rewrite_abc.RewriteResult()
|
|
18
|
+
|
|
19
|
+
n_qubits_stmt.detach()
|
|
20
|
+
node.detach()
|
|
21
|
+
if block.first_stmt is None:
|
|
22
|
+
block.stmts.append(n_qubits_stmt)
|
|
23
|
+
block.stmts.append(node)
|
|
24
|
+
else:
|
|
25
|
+
node.insert_before(block.first_stmt)
|
|
26
|
+
n_qubits_stmt.insert_before(block.first_stmt)
|
|
27
|
+
|
|
28
|
+
for idx_val in range(n_qubits):
|
|
29
|
+
idx = py.constant.Constant(value=idx_val)
|
|
30
|
+
qubit = core.QRegGet(node.result, idx=idx.result)
|
|
31
|
+
qubit.insert_after(node)
|
|
32
|
+
idx.insert_after(node)
|
|
33
|
+
|
|
34
|
+
return rewrite_abc.RewriteResult(has_done_something=True)
|
|
@@ -10,7 +10,7 @@ import cirq.contrib.qasm_import
|
|
|
10
10
|
import cirq.transformers.target_gatesets
|
|
11
11
|
import cirq.transformers.target_gatesets.compilation_target_gateset
|
|
12
12
|
from kirin import ir
|
|
13
|
-
from kirin.rewrite import abc
|
|
13
|
+
from kirin.rewrite import abc
|
|
14
14
|
from kirin.dialects import py
|
|
15
15
|
from cirq.circuits.qasm_output import QasmUGate
|
|
16
16
|
from cirq.transformers.target_gatesets.compilation_target_gateset import (
|
|
@@ -70,7 +70,7 @@ class CU(cirq.Gate):
|
|
|
70
70
|
return "*", "CU"
|
|
71
71
|
|
|
72
72
|
|
|
73
|
-
def around(val):
|
|
73
|
+
def around(val) -> float:
|
|
74
74
|
return float(np.around(val, 14))
|
|
75
75
|
|
|
76
76
|
|
|
@@ -78,7 +78,7 @@ def one_qubit_gate_to_u3_angles(op: cirq.Operation) -> tuple[float, float, float
|
|
|
78
78
|
lam, theta, phi = ( # Z angle, Y angle, then Z angle
|
|
79
79
|
cirq.deconstruct_single_qubit_matrix_into_angles(cirq.unitary(op))
|
|
80
80
|
)
|
|
81
|
-
return
|
|
81
|
+
return around(theta), around(phi), around(lam)
|
|
82
82
|
|
|
83
83
|
|
|
84
84
|
@dataclass
|
|
@@ -111,25 +111,25 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
|
|
|
111
111
|
else:
|
|
112
112
|
return py.constant.Constant(value=math.pi)
|
|
113
113
|
|
|
114
|
-
def rewrite_Statement(self, node: ir.Statement) ->
|
|
114
|
+
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
|
|
115
115
|
# only deal with uop
|
|
116
116
|
if type(node) in uop.dialect.stmts:
|
|
117
117
|
return getattr(self, f"rewrite_{node.name}")(node)
|
|
118
118
|
|
|
119
|
-
return
|
|
119
|
+
return abc.RewriteResult()
|
|
120
120
|
|
|
121
|
-
def rewrite_barrier(self, node: uop.Barrier) ->
|
|
122
|
-
return
|
|
121
|
+
def rewrite_barrier(self, node: uop.Barrier) -> abc.RewriteResult:
|
|
122
|
+
return abc.RewriteResult()
|
|
123
123
|
|
|
124
|
-
def rewrite_cz(self, node: uop.CZ) ->
|
|
125
|
-
return
|
|
124
|
+
def rewrite_cz(self, node: uop.CZ) -> abc.RewriteResult:
|
|
125
|
+
return abc.RewriteResult()
|
|
126
126
|
|
|
127
|
-
def rewrite_CX(self, node: uop.CX) ->
|
|
127
|
+
def rewrite_CX(self, node: uop.CX) -> abc.RewriteResult:
|
|
128
128
|
return self._rewrite_2q_ctrl_gates(
|
|
129
129
|
cirq.CX(self.cached_qubits[0], self.cached_qubits[1]), node
|
|
130
130
|
)
|
|
131
131
|
|
|
132
|
-
def rewrite_cy(self, node: uop.CY) ->
|
|
132
|
+
def rewrite_cy(self, node: uop.CY) -> abc.RewriteResult:
|
|
133
133
|
return self._rewrite_2q_ctrl_gates(
|
|
134
134
|
cirq.ControlledGate(cirq.Y, 1)(
|
|
135
135
|
self.cached_qubits[0], self.cached_qubits[1]
|
|
@@ -137,92 +137,92 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
|
|
|
137
137
|
node,
|
|
138
138
|
)
|
|
139
139
|
|
|
140
|
-
def rewrite_U(self, node: uop.UGate) ->
|
|
141
|
-
return
|
|
140
|
+
def rewrite_U(self, node: uop.UGate) -> abc.RewriteResult:
|
|
141
|
+
return abc.RewriteResult()
|
|
142
142
|
|
|
143
|
-
def rewrite_id(self, node: uop.Id) ->
|
|
143
|
+
def rewrite_id(self, node: uop.Id) -> abc.RewriteResult:
|
|
144
144
|
node.delete() # just delete the identity gate
|
|
145
|
-
return
|
|
145
|
+
return abc.RewriteResult(has_done_something=True)
|
|
146
146
|
|
|
147
|
-
def rewrite_h(self, node: uop.H) ->
|
|
147
|
+
def rewrite_h(self, node: uop.H) -> abc.RewriteResult:
|
|
148
148
|
return self._rewrite_1q_gates(cirq.H(self.cached_qubits[0]), node)
|
|
149
149
|
|
|
150
|
-
def rewrite_x(self, node: uop.X) ->
|
|
150
|
+
def rewrite_x(self, node: uop.X) -> abc.RewriteResult:
|
|
151
151
|
return self._rewrite_1q_gates(cirq.X(self.cached_qubits[0]), node)
|
|
152
152
|
|
|
153
|
-
def rewrite_y(self, node: uop.Y) ->
|
|
153
|
+
def rewrite_y(self, node: uop.Y) -> abc.RewriteResult:
|
|
154
154
|
return self._rewrite_1q_gates(cirq.Y(self.cached_qubits[0]), node)
|
|
155
155
|
|
|
156
|
-
def rewrite_z(self, node: uop.Z) ->
|
|
156
|
+
def rewrite_z(self, node: uop.Z) -> abc.RewriteResult:
|
|
157
157
|
return self._rewrite_1q_gates(cirq.Z(self.cached_qubits[0]), node)
|
|
158
158
|
|
|
159
|
-
def rewrite_s(self, node: uop.S) ->
|
|
159
|
+
def rewrite_s(self, node: uop.S) -> abc.RewriteResult:
|
|
160
160
|
return self._rewrite_1q_gates(cirq.S(self.cached_qubits[0]), node)
|
|
161
161
|
|
|
162
|
-
def rewrite_sdg(self, node: uop.Sdag) ->
|
|
162
|
+
def rewrite_sdg(self, node: uop.Sdag) -> abc.RewriteResult:
|
|
163
163
|
return self._rewrite_1q_gates(cirq.S(self.cached_qubits[0]) ** -1, node)
|
|
164
164
|
|
|
165
|
-
def rewrite_t(self, node: uop.T) ->
|
|
165
|
+
def rewrite_t(self, node: uop.T) -> abc.RewriteResult:
|
|
166
166
|
return self._rewrite_1q_gates(cirq.T(self.cached_qubits[0]), node)
|
|
167
167
|
|
|
168
|
-
def rewrite_tdg(self, node: uop.Tdag) ->
|
|
168
|
+
def rewrite_tdg(self, node: uop.Tdag) -> abc.RewriteResult:
|
|
169
169
|
return self._rewrite_1q_gates(cirq.T(self.cached_qubits[0]) ** -1, node)
|
|
170
170
|
|
|
171
|
-
def rewrite_sx(self, node: uop.SX) ->
|
|
171
|
+
def rewrite_sx(self, node: uop.SX) -> abc.RewriteResult:
|
|
172
172
|
return self._rewrite_1q_gates(
|
|
173
173
|
cirq.XPowGate(exponent=0.5).on(self.cached_qubits[0]), node
|
|
174
174
|
)
|
|
175
175
|
|
|
176
|
-
def rewrite_sxdg(self, node: uop.SXdag) ->
|
|
176
|
+
def rewrite_sxdg(self, node: uop.SXdag) -> abc.RewriteResult:
|
|
177
177
|
return self._rewrite_1q_gates(
|
|
178
178
|
cirq.XPowGate(exponent=-0.5).on(self.cached_qubits[0]), node
|
|
179
179
|
)
|
|
180
180
|
|
|
181
|
-
def rewrite_u1(self, node: uop.U1) ->
|
|
181
|
+
def rewrite_u1(self, node: uop.U1) -> abc.RewriteResult:
|
|
182
182
|
theta = node.lam
|
|
183
183
|
(phi := self.const_float(value=0.0)).insert_before(node)
|
|
184
184
|
node.replace_by(
|
|
185
185
|
uop.UGate(qarg=node.qarg, theta=phi.result, phi=phi.result, lam=theta)
|
|
186
186
|
)
|
|
187
|
-
return
|
|
187
|
+
return abc.RewriteResult(has_done_something=True)
|
|
188
188
|
|
|
189
|
-
def rewrite_u2(self, node: uop.U2) ->
|
|
189
|
+
def rewrite_u2(self, node: uop.U2) -> abc.RewriteResult:
|
|
190
190
|
phi = node.phi
|
|
191
191
|
lam = node.lam
|
|
192
192
|
(theta := self.const_float(value=math.pi / 2)).insert_before(node)
|
|
193
193
|
node.replace_by(uop.UGate(qarg=node.qarg, theta=theta.result, phi=phi, lam=lam))
|
|
194
|
-
return
|
|
194
|
+
return abc.RewriteResult(has_done_something=True)
|
|
195
195
|
|
|
196
|
-
def rewrite_rx(self, node: uop.RX) ->
|
|
196
|
+
def rewrite_rx(self, node: uop.RX) -> abc.RewriteResult:
|
|
197
197
|
theta = node.theta
|
|
198
198
|
(phi := self.const_float(value=math.pi / 2)).insert_before(node)
|
|
199
199
|
(lam := self.const_float(value=-math.pi / 2)).insert_before(node)
|
|
200
200
|
node.replace_by(
|
|
201
201
|
uop.UGate(qarg=node.qarg, theta=theta, phi=phi.result, lam=lam.result)
|
|
202
202
|
)
|
|
203
|
-
return
|
|
203
|
+
return abc.RewriteResult(has_done_something=True)
|
|
204
204
|
|
|
205
|
-
def rewrite_ry(self, node: uop.RY) ->
|
|
205
|
+
def rewrite_ry(self, node: uop.RY) -> abc.RewriteResult:
|
|
206
206
|
theta = node.theta
|
|
207
207
|
(phi := self.const_float(value=0.0)).insert_before(node)
|
|
208
208
|
node.replace_by(
|
|
209
209
|
uop.UGate(qarg=node.qarg, theta=theta, phi=phi.result, lam=phi.result)
|
|
210
210
|
)
|
|
211
|
-
return
|
|
211
|
+
return abc.RewriteResult(has_done_something=True)
|
|
212
212
|
|
|
213
|
-
def rewrite_rz(self, node: uop.RZ) ->
|
|
213
|
+
def rewrite_rz(self, node: uop.RZ) -> abc.RewriteResult:
|
|
214
214
|
theta = node.theta
|
|
215
215
|
(phi := self.const_float(value=0.0)).insert_before(node)
|
|
216
216
|
node.replace_by(
|
|
217
217
|
uop.UGate(qarg=node.qarg, theta=phi.result, phi=phi.result, lam=theta)
|
|
218
218
|
)
|
|
219
|
-
return
|
|
219
|
+
return abc.RewriteResult(has_done_something=True)
|
|
220
220
|
|
|
221
|
-
def rewrite_crx(self, node: uop.CRX) ->
|
|
221
|
+
def rewrite_crx(self, node: uop.CRX) -> abc.RewriteResult:
|
|
222
222
|
lam = self._get_const_value(node.lam)
|
|
223
223
|
|
|
224
224
|
if lam is None:
|
|
225
|
-
return
|
|
225
|
+
return abc.RewriteResult()
|
|
226
226
|
|
|
227
227
|
return self._rewrite_2q_ctrl_gates(
|
|
228
228
|
cirq.ControlledGate(cirq.Rx(rads=lam), 1).on(
|
|
@@ -231,11 +231,11 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
|
|
|
231
231
|
node,
|
|
232
232
|
)
|
|
233
233
|
|
|
234
|
-
def rewrite_cry(self, node: uop.CRY) ->
|
|
234
|
+
def rewrite_cry(self, node: uop.CRY) -> abc.RewriteResult:
|
|
235
235
|
lam = self._get_const_value(node.lam)
|
|
236
236
|
|
|
237
237
|
if lam is None:
|
|
238
|
-
return
|
|
238
|
+
return abc.RewriteResult()
|
|
239
239
|
|
|
240
240
|
return self._rewrite_2q_ctrl_gates(
|
|
241
241
|
cirq.ControlledGate(cirq.Ry(rads=lam), 1).on(
|
|
@@ -244,11 +244,11 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
|
|
|
244
244
|
node,
|
|
245
245
|
)
|
|
246
246
|
|
|
247
|
-
def rewrite_crz(self, node: uop.CRZ) ->
|
|
247
|
+
def rewrite_crz(self, node: uop.CRZ) -> abc.RewriteResult:
|
|
248
248
|
lam = self._get_const_value(node.lam)
|
|
249
249
|
|
|
250
250
|
if lam is None:
|
|
251
|
-
return
|
|
251
|
+
return abc.RewriteResult()
|
|
252
252
|
|
|
253
253
|
return self._rewrite_2q_ctrl_gates(
|
|
254
254
|
cirq.ControlledGate(cirq.Rz(rads=lam), 1).on(
|
|
@@ -257,12 +257,12 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
|
|
|
257
257
|
node,
|
|
258
258
|
)
|
|
259
259
|
|
|
260
|
-
def rewrite_cu1(self, node: uop.CU1) ->
|
|
260
|
+
def rewrite_cu1(self, node: uop.CU1) -> abc.RewriteResult:
|
|
261
261
|
|
|
262
262
|
lam = self._get_const_value(node.lam)
|
|
263
263
|
|
|
264
264
|
if lam is None:
|
|
265
|
-
return
|
|
265
|
+
return abc.RewriteResult()
|
|
266
266
|
|
|
267
267
|
# cirq.ControlledGate(u3(0, 0, lambda))
|
|
268
268
|
return self._rewrite_2q_ctrl_gates(
|
|
@@ -273,14 +273,13 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
|
|
|
273
273
|
)
|
|
274
274
|
pass
|
|
275
275
|
|
|
276
|
-
def rewrite_cu3(self, node: uop.CU3) ->
|
|
276
|
+
def rewrite_cu3(self, node: uop.CU3) -> abc.RewriteResult:
|
|
277
277
|
|
|
278
278
|
theta = self._get_const_value(node.theta)
|
|
279
279
|
lam = self._get_const_value(node.lam)
|
|
280
280
|
phi = self._get_const_value(node.phi)
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
return result.RewriteResult()
|
|
281
|
+
if theta is None or lam is None or phi is None:
|
|
282
|
+
return abc.RewriteResult()
|
|
284
283
|
|
|
285
284
|
# cirq.ControlledGate(u3(theta, lambda phi))
|
|
286
285
|
return self._rewrite_2q_ctrl_gates(
|
|
@@ -290,7 +289,7 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
|
|
|
290
289
|
node,
|
|
291
290
|
)
|
|
292
291
|
|
|
293
|
-
def rewrite_cu(self, node: uop.CU) ->
|
|
292
|
+
def rewrite_cu(self, node: uop.CU) -> abc.RewriteResult:
|
|
294
293
|
|
|
295
294
|
gamma = self._get_const_value(node.gamma)
|
|
296
295
|
theta = self._get_const_value(node.theta)
|
|
@@ -304,12 +303,12 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
|
|
|
304
303
|
node,
|
|
305
304
|
)
|
|
306
305
|
|
|
307
|
-
def rewrite_rxx(self, node: uop.RXX) ->
|
|
306
|
+
def rewrite_rxx(self, node: uop.RXX) -> abc.RewriteResult:
|
|
308
307
|
|
|
309
308
|
theta = self._get_const_value(node.theta)
|
|
310
309
|
|
|
311
310
|
if theta is None:
|
|
312
|
-
return
|
|
311
|
+
return abc.RewriteResult()
|
|
313
312
|
|
|
314
313
|
# even though the XX gate is not controlled,
|
|
315
314
|
# the end U + CZ decomposition that happens internally means
|
|
@@ -320,11 +319,11 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
|
|
|
320
319
|
node,
|
|
321
320
|
)
|
|
322
321
|
|
|
323
|
-
def rewrite_rzz(self, node: uop.RZZ) ->
|
|
322
|
+
def rewrite_rzz(self, node: uop.RZZ) -> abc.RewriteResult:
|
|
324
323
|
theta = self._get_const_value(node.theta)
|
|
325
324
|
|
|
326
325
|
if theta is None:
|
|
327
|
-
return
|
|
326
|
+
return abc.RewriteResult()
|
|
328
327
|
|
|
329
328
|
return self._rewrite_2q_ctrl_gates(
|
|
330
329
|
cirq.ZZPowGate(exponent=theta / math.pi).on(
|
|
@@ -391,7 +390,7 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
|
|
|
391
390
|
|
|
392
391
|
def _rewrite_1q_gates(
|
|
393
392
|
self, cirq_gate: cirq.Operation, node: uop.SingleQubitGate
|
|
394
|
-
) ->
|
|
393
|
+
) -> abc.RewriteResult:
|
|
395
394
|
new_gate_stmts = self._generate_1q_gate_stmts(cirq_gate, node.qarg)
|
|
396
395
|
return self._rewrite_gate_stmts(new_gate_stmts, node)
|
|
397
396
|
|
|
@@ -427,7 +426,7 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
|
|
|
427
426
|
|
|
428
427
|
def _rewrite_2q_ctrl_gates(
|
|
429
428
|
self, cirq_gate: cirq.Operation, node: uop.TwoQubitCtrlGate
|
|
430
|
-
) ->
|
|
429
|
+
) -> abc.RewriteResult:
|
|
431
430
|
new_gate_stmts = self._generate_2q_ctrl_gate_stmts(
|
|
432
431
|
cirq_gate, [node.ctrl, node.qarg]
|
|
433
432
|
)
|
|
@@ -444,4 +443,4 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
|
|
|
444
443
|
stmt.insert_after(node)
|
|
445
444
|
node = stmt
|
|
446
445
|
|
|
447
|
-
return
|
|
446
|
+
return abc.RewriteResult(has_done_something=True)
|
|
@@ -2,7 +2,7 @@ from typing import Dict, List, Optional
|
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
|
|
4
4
|
from kirin import ir
|
|
5
|
-
from kirin.rewrite import abc
|
|
5
|
+
from kirin.rewrite import abc
|
|
6
6
|
|
|
7
7
|
from bloqade.analysis import address
|
|
8
8
|
from bloqade.qasm2.dialects import uop, parallel
|
|
@@ -13,11 +13,11 @@ class ParallelToUOpRule(abc.RewriteRule):
|
|
|
13
13
|
id_map: Dict[int, ir.SSAValue]
|
|
14
14
|
address_analysis: Dict[ir.SSAValue, address.Address]
|
|
15
15
|
|
|
16
|
-
def rewrite_Statement(self, node: ir.Statement) ->
|
|
16
|
+
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
|
|
17
17
|
if type(node) in parallel.dialect.stmts:
|
|
18
18
|
return getattr(self, f"rewrite_{node.name}")(node)
|
|
19
19
|
|
|
20
|
-
return
|
|
20
|
+
return abc.RewriteResult()
|
|
21
21
|
|
|
22
22
|
def get_qubit_ssa(self, ilist_ref: ir.SSAValue) -> Optional[List[ir.SSAValue]]:
|
|
23
23
|
addr = self.address_analysis.get(ilist_ref)
|
|
@@ -40,7 +40,7 @@ class ParallelToUOpRule(abc.RewriteRule):
|
|
|
40
40
|
qargs = self.get_qubit_ssa(node.qargs)
|
|
41
41
|
|
|
42
42
|
if ctrls is None or qargs is None:
|
|
43
|
-
return
|
|
43
|
+
return abc.RewriteResult()
|
|
44
44
|
|
|
45
45
|
for ctrl, qarg in zip(ctrls, qargs):
|
|
46
46
|
new_node = uop.CZ(ctrl, qarg)
|
|
@@ -48,7 +48,7 @@ class ParallelToUOpRule(abc.RewriteRule):
|
|
|
48
48
|
|
|
49
49
|
node.delete()
|
|
50
50
|
|
|
51
|
-
return
|
|
51
|
+
return abc.RewriteResult(has_done_something=True)
|
|
52
52
|
|
|
53
53
|
def rewrite_u(self, node: ir.Statement):
|
|
54
54
|
assert isinstance(node, parallel.UGate)
|
|
@@ -56,7 +56,7 @@ class ParallelToUOpRule(abc.RewriteRule):
|
|
|
56
56
|
qargs = self.get_qubit_ssa(node.qargs)
|
|
57
57
|
|
|
58
58
|
if qargs is None:
|
|
59
|
-
return
|
|
59
|
+
return abc.RewriteResult()
|
|
60
60
|
|
|
61
61
|
for qarg in qargs:
|
|
62
62
|
new_node = uop.UGate(qarg, theta=node.theta, phi=node.phi, lam=node.lam)
|
|
@@ -64,7 +64,7 @@ class ParallelToUOpRule(abc.RewriteRule):
|
|
|
64
64
|
|
|
65
65
|
node.delete()
|
|
66
66
|
|
|
67
|
-
return
|
|
67
|
+
return abc.RewriteResult(has_done_something=True)
|
|
68
68
|
|
|
69
69
|
def rewrite_rz(self, node: ir.Statement):
|
|
70
70
|
assert isinstance(node, parallel.RZ)
|
|
@@ -72,7 +72,7 @@ class ParallelToUOpRule(abc.RewriteRule):
|
|
|
72
72
|
qargs = self.get_qubit_ssa(node.qargs)
|
|
73
73
|
|
|
74
74
|
if qargs is None:
|
|
75
|
-
return
|
|
75
|
+
return abc.RewriteResult()
|
|
76
76
|
|
|
77
77
|
for qarg in qargs:
|
|
78
78
|
new_node = uop.RZ(qarg, theta=node.theta)
|
|
@@ -80,4 +80,4 @@ class ParallelToUOpRule(abc.RewriteRule):
|
|
|
80
80
|
|
|
81
81
|
node.delete()
|
|
82
82
|
|
|
83
|
-
return
|
|
83
|
+
return abc.RewriteResult(has_done_something=True)
|