bloqade-circuit 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic; see the registry's advisory for this release for more details.

Files changed (74)
  1. bloqade/analysis/address/impls.py +5 -9
  2. bloqade/analysis/address/lattice.py +1 -1
  3. bloqade/analysis/fidelity/__init__.py +1 -0
  4. bloqade/analysis/fidelity/analysis.py +69 -0
  5. bloqade/device.py +130 -0
  6. bloqade/noise/__init__.py +2 -1
  7. bloqade/noise/fidelity.py +51 -0
  8. bloqade/noise/native/model.py +1 -2
  9. bloqade/noise/native/rewrite.py +5 -5
  10. bloqade/noise/native/stmts.py +40 -11
  11. bloqade/pyqrack/__init__.py +8 -2
  12. bloqade/pyqrack/base.py +24 -3
  13. bloqade/pyqrack/device.py +166 -0
  14. bloqade/pyqrack/noise/native.py +1 -2
  15. bloqade/pyqrack/qasm2/core.py +31 -15
  16. bloqade/pyqrack/qasm2/glob.py +28 -0
  17. bloqade/pyqrack/qasm2/uop.py +9 -1
  18. bloqade/pyqrack/reg.py +17 -49
  19. bloqade/pyqrack/squin/__init__.py +0 -0
  20. bloqade/pyqrack/squin/op.py +154 -0
  21. bloqade/pyqrack/squin/qubit.py +85 -0
  22. bloqade/pyqrack/squin/runtime.py +515 -0
  23. bloqade/pyqrack/squin/wire.py +69 -0
  24. bloqade/pyqrack/target.py +9 -2
  25. bloqade/pyqrack/task.py +30 -0
  26. bloqade/qasm2/_wrappers.py +11 -1
  27. bloqade/qasm2/dialects/core/stmts.py +15 -4
  28. bloqade/qasm2/dialects/expr/_emit.py +9 -8
  29. bloqade/qasm2/emit/base.py +4 -2
  30. bloqade/qasm2/emit/gate.py +0 -14
  31. bloqade/qasm2/emit/main.py +19 -15
  32. bloqade/qasm2/emit/target.py +2 -6
  33. bloqade/qasm2/glob.py +1 -1
  34. bloqade/qasm2/parse/lowering.py +124 -1
  35. bloqade/qasm2/passes/glob.py +3 -3
  36. bloqade/qasm2/passes/lift_qubits.py +26 -0
  37. bloqade/qasm2/passes/noise.py +6 -14
  38. bloqade/qasm2/passes/parallel.py +3 -3
  39. bloqade/qasm2/passes/py2qasm.py +1 -2
  40. bloqade/qasm2/passes/qasm2py.py +1 -2
  41. bloqade/qasm2/rewrite/desugar.py +6 -6
  42. bloqade/qasm2/rewrite/glob.py +9 -9
  43. bloqade/qasm2/rewrite/heuristic_noise.py +30 -38
  44. bloqade/qasm2/rewrite/insert_qubits.py +34 -0
  45. bloqade/qasm2/rewrite/native_gates.py +54 -55
  46. bloqade/qasm2/rewrite/parallel_to_uop.py +9 -9
  47. bloqade/qasm2/rewrite/uop_to_parallel.py +20 -22
  48. bloqade/qasm2/types.py +3 -6
  49. bloqade/qbraid/schema.py +10 -12
  50. bloqade/squin/__init__.py +1 -1
  51. bloqade/squin/analysis/nsites/analysis.py +4 -6
  52. bloqade/squin/analysis/nsites/impls.py +2 -6
  53. bloqade/squin/analysis/schedule.py +1 -1
  54. bloqade/squin/groups.py +15 -7
  55. bloqade/squin/noise/__init__.py +27 -0
  56. bloqade/squin/noise/_dialect.py +3 -0
  57. bloqade/squin/noise/stmts.py +59 -0
  58. bloqade/squin/op/__init__.py +35 -5
  59. bloqade/squin/op/number.py +5 -0
  60. bloqade/squin/op/rewrite.py +46 -0
  61. bloqade/squin/op/stmts.py +23 -2
  62. bloqade/squin/op/types.py +14 -0
  63. bloqade/squin/qubit.py +79 -11
  64. bloqade/squin/rewrite/__init__.py +0 -0
  65. bloqade/squin/rewrite/measure_desugar.py +33 -0
  66. bloqade/squin/wire.py +31 -2
  67. bloqade/stim/emit/stim.py +1 -1
  68. bloqade/task.py +94 -0
  69. bloqade/visual/animation/base.py +25 -15
  70. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.0.dist-info}/METADATA +8 -2
  71. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.0.dist-info}/RECORD +73 -52
  72. bloqade/squin/op/complex.py +0 -6
  73. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.0.dist-info}/WHEEL +0 -0
  74. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,17 +1,17 @@
1
- from typing import Dict, List, Tuple
1
+ from typing import Dict, List, Tuple, cast
2
2
  from dataclasses import field, dataclass
3
3
 
4
4
  from kirin import ir
5
- from kirin.rewrite import abc as result_abc, result
6
- from kirin.dialects import py, ilist
5
+ from kirin.rewrite import abc as rewrite_abc
6
+ from kirin.dialects import ilist
7
7
 
8
8
  from bloqade.noise import native
9
9
  from bloqade.analysis import address
10
- from bloqade.qasm2.dialects import uop, core, glob, parallel
10
+ from bloqade.qasm2.dialects import uop, glob, parallel
11
11
 
12
12
 
13
13
  @dataclass
14
- class NoiseRewriteRule(result_abc.RewriteRule):
14
+ class NoiseRewriteRule(rewrite_abc.RewriteRule):
15
15
  """
16
16
  NOTE: This pass is not guaranteed to be supported long-term in bloqade. We will be
17
17
  moving towards a more general approach to noise modeling in the future.
@@ -24,12 +24,18 @@ class NoiseRewriteRule(result_abc.RewriteRule):
24
24
  noise_model: native.MoveNoiseModelABC = field(
25
25
  default_factory=native.TwoRowZoneModel
26
26
  )
27
- qubit_ssa_value: Dict[int, ir.SSAValue] = field(default_factory=dict, init=False)
28
27
 
29
- def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
30
- if isinstance(node, core.QRegNew):
31
- return self.rewrite_qreg_new(node)
32
- elif isinstance(node, uop.SingleQubitGate):
28
+ def __post_init__(self):
29
+ self.qubit_ssa_value: Dict[int, ir.SSAValue] = {}
30
+ for ssa_value, addr in self.address_analysis.items():
31
+ if (
32
+ isinstance(addr, address.AddressQubit)
33
+ and ssa_value not in self.qubit_ssa_value
34
+ ):
35
+ self.qubit_ssa_value[addr.data] = ssa_value
36
+
37
+ def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
38
+ if isinstance(node, uop.SingleQubitGate):
33
39
  return self.rewrite_single_qubit_gate(node)
34
40
  elif isinstance(node, uop.CZ):
35
41
  return self.rewrite_cz_gate(node)
@@ -40,25 +46,7 @@ class NoiseRewriteRule(result_abc.RewriteRule):
40
46
  elif isinstance(node, glob.UGate):
41
47
  return self.rewrite_global_single_qubit_gate(node)
42
48
  else:
43
- return result.RewriteResult()
44
-
45
- def rewrite_qreg_new(self, node: core.QRegNew):
46
-
47
- addr = self.address_analysis[node.result]
48
- if not isinstance(addr, address.AddressReg):
49
- return result.RewriteResult()
50
-
51
- has_done_something = False
52
- for idx_val, qid in enumerate(addr.data):
53
- if qid not in self.qubit_ssa_value:
54
- has_done_something = True
55
- idx = py.constant.Constant(value=idx_val)
56
- qubit = core.QRegGet(node.result, idx=idx.result)
57
- self.qubit_ssa_value[qid] = qubit.result
58
- qubit.insert_after(node)
59
- idx.insert_after(node)
60
-
61
- return result.RewriteResult(has_done_something=has_done_something)
49
+ return rewrite_abc.RewriteResult()
62
50
 
63
51
  def insert_single_qubit_noise(
64
52
  self,
@@ -71,7 +59,7 @@ class NoiseRewriteRule(result_abc.RewriteRule):
71
59
  )
72
60
  native.AtomLossChannel(qargs, prob=probs[3]).insert_before(node)
73
61
 
74
- return result.RewriteResult(has_done_something=True)
62
+ return rewrite_abc.RewriteResult(has_done_something=True)
75
63
 
76
64
  def rewrite_single_qubit_gate(self, node: uop.SingleQubitGate):
77
65
  probs = (
@@ -86,13 +74,13 @@ class NoiseRewriteRule(result_abc.RewriteRule):
86
74
  def rewrite_global_single_qubit_gate(self, node: glob.UGate):
87
75
  addrs = self.address_analysis[node.registers]
88
76
  if not isinstance(addrs, address.AddressTuple):
89
- return result.RewriteResult()
77
+ return rewrite_abc.RewriteResult()
90
78
 
91
79
  qargs = []
92
80
 
93
81
  for addr in addrs.data:
94
82
  if not isinstance(addr, address.AddressReg):
95
- return result.RewriteResult()
83
+ return rewrite_abc.RewriteResult()
96
84
 
97
85
  for qid in addr.data:
98
86
  qargs.append(self.qubit_ssa_value[qid])
@@ -109,10 +97,10 @@ class NoiseRewriteRule(result_abc.RewriteRule):
109
97
  def rewrite_parallel_single_qubit_gate(self, node: parallel.RZ | parallel.UGate):
110
98
  addrs = self.address_analysis[node.qargs]
111
99
  if not isinstance(addrs, address.AddressTuple):
112
- return result.RewriteResult()
100
+ return rewrite_abc.RewriteResult()
113
101
 
114
102
  if not all(isinstance(addr, address.AddressQubit) for addr in addrs.data):
115
- return result.RewriteResult()
103
+ return rewrite_abc.RewriteResult()
116
104
 
117
105
  probs = (
118
106
  self.gate_noise_params.local_px,
@@ -213,7 +201,7 @@ class NoiseRewriteRule(result_abc.RewriteRule):
213
201
  new_node.insert_before(node)
214
202
  has_done_something = True
215
203
 
216
- return result.RewriteResult(has_done_something=has_done_something)
204
+ return rewrite_abc.RewriteResult(has_done_something=has_done_something)
217
205
 
218
206
  def rewrite_parallel_cz_gate(self, node: parallel.CZ):
219
207
  ctrls = self.address_analysis[node.ctrls]
@@ -226,8 +214,12 @@ class NoiseRewriteRule(result_abc.RewriteRule):
226
214
  and isinstance(qargs, address.AddressTuple)
227
215
  and all(isinstance(addr, address.AddressQubit) for addr in qargs.data)
228
216
  ):
229
- ctrl_qubits = list(map(lambda addr: addr.data, ctrls.data))
230
- qarg_qubits = list(map(lambda addr: addr.data, qargs.data))
217
+ ctrl_qubits = list(
218
+ map(lambda addr: cast(address.AddressQubit, addr).data, ctrls.data)
219
+ )
220
+ qarg_qubits = list(
221
+ map(lambda addr: cast(address.AddressQubit, addr).data, qargs.data)
222
+ )
231
223
  rest = sorted(
232
224
  set(self.qubit_ssa_value.keys()) - set(ctrl_qubits + qarg_qubits)
233
225
  )
@@ -244,4 +236,4 @@ class NoiseRewriteRule(result_abc.RewriteRule):
244
236
  new_node.insert_before(node)
245
237
  has_done_something = True
246
238
 
247
- return result.RewriteResult(has_done_something=has_done_something)
239
+ return rewrite_abc.RewriteResult(has_done_something=has_done_something)
@@ -0,0 +1,34 @@
1
+ from kirin import ir
2
+ from kirin.rewrite import abc as rewrite_abc
3
+ from kirin.dialects import py
4
+
5
+
6
+ class InsertGetQubit(rewrite_abc.RewriteRule):
7
+
8
+ def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
9
+ from bloqade.qasm2 import core
10
+
11
+ if (
12
+ not isinstance(node, core.QRegNew)
13
+ or not isinstance(n_qubits_stmt := node.n_qubits.owner, py.Constant)
14
+ or not isinstance(n_qubits := n_qubits_stmt.value.unwrap(), int)
15
+ or (block := node.parent_block) is None
16
+ ):
17
+ return rewrite_abc.RewriteResult()
18
+
19
+ n_qubits_stmt.detach()
20
+ node.detach()
21
+ if block.first_stmt is None:
22
+ block.stmts.append(n_qubits_stmt)
23
+ block.stmts.append(node)
24
+ else:
25
+ node.insert_before(block.first_stmt)
26
+ n_qubits_stmt.insert_before(block.first_stmt)
27
+
28
+ for idx_val in range(n_qubits):
29
+ idx = py.constant.Constant(value=idx_val)
30
+ qubit = core.QRegGet(node.result, idx=idx.result)
31
+ qubit.insert_after(node)
32
+ idx.insert_after(node)
33
+
34
+ return rewrite_abc.RewriteResult(has_done_something=True)
@@ -10,7 +10,7 @@ import cirq.contrib.qasm_import
10
10
  import cirq.transformers.target_gatesets
11
11
  import cirq.transformers.target_gatesets.compilation_target_gateset
12
12
  from kirin import ir
13
- from kirin.rewrite import abc, result
13
+ from kirin.rewrite import abc
14
14
  from kirin.dialects import py
15
15
  from cirq.circuits.qasm_output import QasmUGate
16
16
  from cirq.transformers.target_gatesets.compilation_target_gateset import (
@@ -70,7 +70,7 @@ class CU(cirq.Gate):
70
70
  return "*", "CU"
71
71
 
72
72
 
73
- def around(val):
73
+ def around(val) -> float:
74
74
  return float(np.around(val, 14))
75
75
 
76
76
 
@@ -78,7 +78,7 @@ def one_qubit_gate_to_u3_angles(op: cirq.Operation) -> tuple[float, float, float
78
78
  lam, theta, phi = ( # Z angle, Y angle, then Z angle
79
79
  cirq.deconstruct_single_qubit_matrix_into_angles(cirq.unitary(op))
80
80
  )
81
- return tuple(map(around, (theta, phi, lam)))
81
+ return around(theta), around(phi), around(lam)
82
82
 
83
83
 
84
84
  @dataclass
@@ -111,25 +111,25 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
111
111
  else:
112
112
  return py.constant.Constant(value=math.pi)
113
113
 
114
- def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
114
+ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
115
115
  # only deal with uop
116
116
  if type(node) in uop.dialect.stmts:
117
117
  return getattr(self, f"rewrite_{node.name}")(node)
118
118
 
119
- return result.RewriteResult()
119
+ return abc.RewriteResult()
120
120
 
121
- def rewrite_barrier(self, node: uop.Barrier) -> result.RewriteResult:
122
- return result.RewriteResult()
121
+ def rewrite_barrier(self, node: uop.Barrier) -> abc.RewriteResult:
122
+ return abc.RewriteResult()
123
123
 
124
- def rewrite_cz(self, node: uop.CZ) -> result.RewriteResult:
125
- return result.RewriteResult()
124
+ def rewrite_cz(self, node: uop.CZ) -> abc.RewriteResult:
125
+ return abc.RewriteResult()
126
126
 
127
- def rewrite_CX(self, node: uop.CX) -> result.RewriteResult:
127
+ def rewrite_CX(self, node: uop.CX) -> abc.RewriteResult:
128
128
  return self._rewrite_2q_ctrl_gates(
129
129
  cirq.CX(self.cached_qubits[0], self.cached_qubits[1]), node
130
130
  )
131
131
 
132
- def rewrite_cy(self, node: uop.CY) -> result.RewriteResult:
132
+ def rewrite_cy(self, node: uop.CY) -> abc.RewriteResult:
133
133
  return self._rewrite_2q_ctrl_gates(
134
134
  cirq.ControlledGate(cirq.Y, 1)(
135
135
  self.cached_qubits[0], self.cached_qubits[1]
@@ -137,92 +137,92 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
137
137
  node,
138
138
  )
139
139
 
140
- def rewrite_U(self, node: uop.UGate) -> result.RewriteResult:
141
- return result.RewriteResult()
140
+ def rewrite_U(self, node: uop.UGate) -> abc.RewriteResult:
141
+ return abc.RewriteResult()
142
142
 
143
- def rewrite_id(self, node: uop.Id) -> result.RewriteResult:
143
+ def rewrite_id(self, node: uop.Id) -> abc.RewriteResult:
144
144
  node.delete() # just delete the identity gate
145
- return result.RewriteResult(has_done_something=True)
145
+ return abc.RewriteResult(has_done_something=True)
146
146
 
147
- def rewrite_h(self, node: uop.H) -> result.RewriteResult:
147
+ def rewrite_h(self, node: uop.H) -> abc.RewriteResult:
148
148
  return self._rewrite_1q_gates(cirq.H(self.cached_qubits[0]), node)
149
149
 
150
- def rewrite_x(self, node: uop.X) -> result.RewriteResult:
150
+ def rewrite_x(self, node: uop.X) -> abc.RewriteResult:
151
151
  return self._rewrite_1q_gates(cirq.X(self.cached_qubits[0]), node)
152
152
 
153
- def rewrite_y(self, node: uop.Y) -> result.RewriteResult:
153
+ def rewrite_y(self, node: uop.Y) -> abc.RewriteResult:
154
154
  return self._rewrite_1q_gates(cirq.Y(self.cached_qubits[0]), node)
155
155
 
156
- def rewrite_z(self, node: uop.Z) -> result.RewriteResult:
156
+ def rewrite_z(self, node: uop.Z) -> abc.RewriteResult:
157
157
  return self._rewrite_1q_gates(cirq.Z(self.cached_qubits[0]), node)
158
158
 
159
- def rewrite_s(self, node: uop.S) -> result.RewriteResult:
159
+ def rewrite_s(self, node: uop.S) -> abc.RewriteResult:
160
160
  return self._rewrite_1q_gates(cirq.S(self.cached_qubits[0]), node)
161
161
 
162
- def rewrite_sdg(self, node: uop.Sdag) -> result.RewriteResult:
162
+ def rewrite_sdg(self, node: uop.Sdag) -> abc.RewriteResult:
163
163
  return self._rewrite_1q_gates(cirq.S(self.cached_qubits[0]) ** -1, node)
164
164
 
165
- def rewrite_t(self, node: uop.T) -> result.RewriteResult:
165
+ def rewrite_t(self, node: uop.T) -> abc.RewriteResult:
166
166
  return self._rewrite_1q_gates(cirq.T(self.cached_qubits[0]), node)
167
167
 
168
- def rewrite_tdg(self, node: uop.Tdag) -> result.RewriteResult:
168
+ def rewrite_tdg(self, node: uop.Tdag) -> abc.RewriteResult:
169
169
  return self._rewrite_1q_gates(cirq.T(self.cached_qubits[0]) ** -1, node)
170
170
 
171
- def rewrite_sx(self, node: uop.SX) -> result.RewriteResult:
171
+ def rewrite_sx(self, node: uop.SX) -> abc.RewriteResult:
172
172
  return self._rewrite_1q_gates(
173
173
  cirq.XPowGate(exponent=0.5).on(self.cached_qubits[0]), node
174
174
  )
175
175
 
176
- def rewrite_sxdg(self, node: uop.SXdag) -> result.RewriteResult:
176
+ def rewrite_sxdg(self, node: uop.SXdag) -> abc.RewriteResult:
177
177
  return self._rewrite_1q_gates(
178
178
  cirq.XPowGate(exponent=-0.5).on(self.cached_qubits[0]), node
179
179
  )
180
180
 
181
- def rewrite_u1(self, node: uop.U1) -> result.RewriteResult:
181
+ def rewrite_u1(self, node: uop.U1) -> abc.RewriteResult:
182
182
  theta = node.lam
183
183
  (phi := self.const_float(value=0.0)).insert_before(node)
184
184
  node.replace_by(
185
185
  uop.UGate(qarg=node.qarg, theta=phi.result, phi=phi.result, lam=theta)
186
186
  )
187
- return result.RewriteResult(has_done_something=True)
187
+ return abc.RewriteResult(has_done_something=True)
188
188
 
189
- def rewrite_u2(self, node: uop.U2) -> result.RewriteResult:
189
+ def rewrite_u2(self, node: uop.U2) -> abc.RewriteResult:
190
190
  phi = node.phi
191
191
  lam = node.lam
192
192
  (theta := self.const_float(value=math.pi / 2)).insert_before(node)
193
193
  node.replace_by(uop.UGate(qarg=node.qarg, theta=theta.result, phi=phi, lam=lam))
194
- return result.RewriteResult(has_done_something=True)
194
+ return abc.RewriteResult(has_done_something=True)
195
195
 
196
- def rewrite_rx(self, node: uop.RX) -> result.RewriteResult:
196
+ def rewrite_rx(self, node: uop.RX) -> abc.RewriteResult:
197
197
  theta = node.theta
198
198
  (phi := self.const_float(value=math.pi / 2)).insert_before(node)
199
199
  (lam := self.const_float(value=-math.pi / 2)).insert_before(node)
200
200
  node.replace_by(
201
201
  uop.UGate(qarg=node.qarg, theta=theta, phi=phi.result, lam=lam.result)
202
202
  )
203
- return result.RewriteResult(has_done_something=True)
203
+ return abc.RewriteResult(has_done_something=True)
204
204
 
205
- def rewrite_ry(self, node: uop.RY) -> result.RewriteResult:
205
+ def rewrite_ry(self, node: uop.RY) -> abc.RewriteResult:
206
206
  theta = node.theta
207
207
  (phi := self.const_float(value=0.0)).insert_before(node)
208
208
  node.replace_by(
209
209
  uop.UGate(qarg=node.qarg, theta=theta, phi=phi.result, lam=phi.result)
210
210
  )
211
- return result.RewriteResult(has_done_something=True)
211
+ return abc.RewriteResult(has_done_something=True)
212
212
 
213
- def rewrite_rz(self, node: uop.RZ) -> result.RewriteResult:
213
+ def rewrite_rz(self, node: uop.RZ) -> abc.RewriteResult:
214
214
  theta = node.theta
215
215
  (phi := self.const_float(value=0.0)).insert_before(node)
216
216
  node.replace_by(
217
217
  uop.UGate(qarg=node.qarg, theta=phi.result, phi=phi.result, lam=theta)
218
218
  )
219
- return result.RewriteResult(has_done_something=True)
219
+ return abc.RewriteResult(has_done_something=True)
220
220
 
221
- def rewrite_crx(self, node: uop.CRX) -> result.RewriteResult:
221
+ def rewrite_crx(self, node: uop.CRX) -> abc.RewriteResult:
222
222
  lam = self._get_const_value(node.lam)
223
223
 
224
224
  if lam is None:
225
- return result.RewriteResult()
225
+ return abc.RewriteResult()
226
226
 
227
227
  return self._rewrite_2q_ctrl_gates(
228
228
  cirq.ControlledGate(cirq.Rx(rads=lam), 1).on(
@@ -231,11 +231,11 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
231
231
  node,
232
232
  )
233
233
 
234
- def rewrite_cry(self, node: uop.CRY) -> result.RewriteResult:
234
+ def rewrite_cry(self, node: uop.CRY) -> abc.RewriteResult:
235
235
  lam = self._get_const_value(node.lam)
236
236
 
237
237
  if lam is None:
238
- return result.RewriteResult()
238
+ return abc.RewriteResult()
239
239
 
240
240
  return self._rewrite_2q_ctrl_gates(
241
241
  cirq.ControlledGate(cirq.Ry(rads=lam), 1).on(
@@ -244,11 +244,11 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
244
244
  node,
245
245
  )
246
246
 
247
- def rewrite_crz(self, node: uop.CRZ) -> result.RewriteResult:
247
+ def rewrite_crz(self, node: uop.CRZ) -> abc.RewriteResult:
248
248
  lam = self._get_const_value(node.lam)
249
249
 
250
250
  if lam is None:
251
- return result.RewriteResult()
251
+ return abc.RewriteResult()
252
252
 
253
253
  return self._rewrite_2q_ctrl_gates(
254
254
  cirq.ControlledGate(cirq.Rz(rads=lam), 1).on(
@@ -257,12 +257,12 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
257
257
  node,
258
258
  )
259
259
 
260
- def rewrite_cu1(self, node: uop.CU1) -> result.RewriteResult:
260
+ def rewrite_cu1(self, node: uop.CU1) -> abc.RewriteResult:
261
261
 
262
262
  lam = self._get_const_value(node.lam)
263
263
 
264
264
  if lam is None:
265
- return result.RewriteResult()
265
+ return abc.RewriteResult()
266
266
 
267
267
  # cirq.ControlledGate(u3(0, 0, lambda))
268
268
  return self._rewrite_2q_ctrl_gates(
@@ -273,14 +273,13 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
273
273
  )
274
274
  pass
275
275
 
276
- def rewrite_cu3(self, node: uop.CU3) -> result.RewriteResult:
276
+ def rewrite_cu3(self, node: uop.CU3) -> abc.RewriteResult:
277
277
 
278
278
  theta = self._get_const_value(node.theta)
279
279
  lam = self._get_const_value(node.lam)
280
280
  phi = self._get_const_value(node.phi)
281
-
282
- if not all((theta, phi, lam)):
283
- return result.RewriteResult()
281
+ if theta is None or lam is None or phi is None:
282
+ return abc.RewriteResult()
284
283
 
285
284
  # cirq.ControlledGate(u3(theta, lambda phi))
286
285
  return self._rewrite_2q_ctrl_gates(
@@ -290,7 +289,7 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
290
289
  node,
291
290
  )
292
291
 
293
- def rewrite_cu(self, node: uop.CU) -> result.RewriteResult:
292
+ def rewrite_cu(self, node: uop.CU) -> abc.RewriteResult:
294
293
 
295
294
  gamma = self._get_const_value(node.gamma)
296
295
  theta = self._get_const_value(node.theta)
@@ -304,12 +303,12 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
304
303
  node,
305
304
  )
306
305
 
307
- def rewrite_rxx(self, node: uop.RXX) -> result.RewriteResult:
306
+ def rewrite_rxx(self, node: uop.RXX) -> abc.RewriteResult:
308
307
 
309
308
  theta = self._get_const_value(node.theta)
310
309
 
311
310
  if theta is None:
312
- return result.RewriteResult()
311
+ return abc.RewriteResult()
313
312
 
314
313
  # even though the XX gate is not controlled,
315
314
  # the end U + CZ decomposition that happens internally means
@@ -320,11 +319,11 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
320
319
  node,
321
320
  )
322
321
 
323
- def rewrite_rzz(self, node: uop.RZZ) -> result.RewriteResult:
322
+ def rewrite_rzz(self, node: uop.RZZ) -> abc.RewriteResult:
324
323
  theta = self._get_const_value(node.theta)
325
324
 
326
325
  if theta is None:
327
- return result.RewriteResult()
326
+ return abc.RewriteResult()
328
327
 
329
328
  return self._rewrite_2q_ctrl_gates(
330
329
  cirq.ZZPowGate(exponent=theta / math.pi).on(
@@ -391,7 +390,7 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
391
390
 
392
391
  def _rewrite_1q_gates(
393
392
  self, cirq_gate: cirq.Operation, node: uop.SingleQubitGate
394
- ) -> result.RewriteResult:
393
+ ) -> abc.RewriteResult:
395
394
  new_gate_stmts = self._generate_1q_gate_stmts(cirq_gate, node.qarg)
396
395
  return self._rewrite_gate_stmts(new_gate_stmts, node)
397
396
 
@@ -427,7 +426,7 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
427
426
 
428
427
  def _rewrite_2q_ctrl_gates(
429
428
  self, cirq_gate: cirq.Operation, node: uop.TwoQubitCtrlGate
430
- ) -> result.RewriteResult:
429
+ ) -> abc.RewriteResult:
431
430
  new_gate_stmts = self._generate_2q_ctrl_gate_stmts(
432
431
  cirq_gate, [node.ctrl, node.qarg]
433
432
  )
@@ -444,4 +443,4 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
444
443
  stmt.insert_after(node)
445
444
  node = stmt
446
445
 
447
- return result.RewriteResult(has_done_something=True)
446
+ return abc.RewriteResult(has_done_something=True)
@@ -2,7 +2,7 @@ from typing import Dict, List, Optional
2
2
  from dataclasses import dataclass
3
3
 
4
4
  from kirin import ir
5
- from kirin.rewrite import abc, result
5
+ from kirin.rewrite import abc
6
6
 
7
7
  from bloqade.analysis import address
8
8
  from bloqade.qasm2.dialects import uop, parallel
@@ -13,11 +13,11 @@ class ParallelToUOpRule(abc.RewriteRule):
13
13
  id_map: Dict[int, ir.SSAValue]
14
14
  address_analysis: Dict[ir.SSAValue, address.Address]
15
15
 
16
- def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
16
+ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
17
17
  if type(node) in parallel.dialect.stmts:
18
18
  return getattr(self, f"rewrite_{node.name}")(node)
19
19
 
20
- return result.RewriteResult()
20
+ return abc.RewriteResult()
21
21
 
22
22
  def get_qubit_ssa(self, ilist_ref: ir.SSAValue) -> Optional[List[ir.SSAValue]]:
23
23
  addr = self.address_analysis.get(ilist_ref)
@@ -40,7 +40,7 @@ class ParallelToUOpRule(abc.RewriteRule):
40
40
  qargs = self.get_qubit_ssa(node.qargs)
41
41
 
42
42
  if ctrls is None or qargs is None:
43
- return result.RewriteResult()
43
+ return abc.RewriteResult()
44
44
 
45
45
  for ctrl, qarg in zip(ctrls, qargs):
46
46
  new_node = uop.CZ(ctrl, qarg)
@@ -48,7 +48,7 @@ class ParallelToUOpRule(abc.RewriteRule):
48
48
 
49
49
  node.delete()
50
50
 
51
- return result.RewriteResult(has_done_something=True)
51
+ return abc.RewriteResult(has_done_something=True)
52
52
 
53
53
  def rewrite_u(self, node: ir.Statement):
54
54
  assert isinstance(node, parallel.UGate)
@@ -56,7 +56,7 @@ class ParallelToUOpRule(abc.RewriteRule):
56
56
  qargs = self.get_qubit_ssa(node.qargs)
57
57
 
58
58
  if qargs is None:
59
- return result.RewriteResult()
59
+ return abc.RewriteResult()
60
60
 
61
61
  for qarg in qargs:
62
62
  new_node = uop.UGate(qarg, theta=node.theta, phi=node.phi, lam=node.lam)
@@ -64,7 +64,7 @@ class ParallelToUOpRule(abc.RewriteRule):
64
64
 
65
65
  node.delete()
66
66
 
67
- return result.RewriteResult(has_done_something=True)
67
+ return abc.RewriteResult(has_done_something=True)
68
68
 
69
69
  def rewrite_rz(self, node: ir.Statement):
70
70
  assert isinstance(node, parallel.RZ)
@@ -72,7 +72,7 @@ class ParallelToUOpRule(abc.RewriteRule):
72
72
  qargs = self.get_qubit_ssa(node.qargs)
73
73
 
74
74
  if qargs is None:
75
- return result.RewriteResult()
75
+ return abc.RewriteResult()
76
76
 
77
77
  for qarg in qargs:
78
78
  new_node = uop.RZ(qarg, theta=node.theta)
@@ -80,4 +80,4 @@ class ParallelToUOpRule(abc.RewriteRule):
80
80
 
81
81
  node.delete()
82
82
 
83
- return result.RewriteResult(has_done_something=True)
83
+ return abc.RewriteResult(has_done_something=True)