bloqade-circuit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (153)
  1. bloqade/analysis/__init__.py +0 -0
  2. bloqade/analysis/address/__init__.py +11 -0
  3. bloqade/analysis/address/analysis.py +60 -0
  4. bloqade/analysis/address/impls.py +228 -0
  5. bloqade/analysis/address/lattice.py +85 -0
  6. bloqade/noise/__init__.py +1 -0
  7. bloqade/noise/native/__init__.py +20 -0
  8. bloqade/noise/native/_dialect.py +3 -0
  9. bloqade/noise/native/_wrappers.py +34 -0
  10. bloqade/noise/native/model.py +347 -0
  11. bloqade/noise/native/rewrite.py +35 -0
  12. bloqade/noise/native/stmts.py +46 -0
  13. bloqade/pyqrack/__init__.py +18 -0
  14. bloqade/pyqrack/base.py +131 -0
  15. bloqade/pyqrack/noise/__init__.py +0 -0
  16. bloqade/pyqrack/noise/native.py +100 -0
  17. bloqade/pyqrack/qasm2/__init__.py +0 -0
  18. bloqade/pyqrack/qasm2/core.py +79 -0
  19. bloqade/pyqrack/qasm2/parallel.py +46 -0
  20. bloqade/pyqrack/qasm2/uop.py +247 -0
  21. bloqade/pyqrack/reg.py +109 -0
  22. bloqade/pyqrack/target.py +112 -0
  23. bloqade/qasm2/__init__.py +19 -0
  24. bloqade/qasm2/_wrappers.py +674 -0
  25. bloqade/qasm2/dialects/__init__.py +10 -0
  26. bloqade/qasm2/dialects/core/__init__.py +3 -0
  27. bloqade/qasm2/dialects/core/_dialect.py +3 -0
  28. bloqade/qasm2/dialects/core/_emit.py +68 -0
  29. bloqade/qasm2/dialects/core/_typeinfer.py +23 -0
  30. bloqade/qasm2/dialects/core/address.py +38 -0
  31. bloqade/qasm2/dialects/core/stmts.py +94 -0
  32. bloqade/qasm2/dialects/expr/__init__.py +3 -0
  33. bloqade/qasm2/dialects/expr/_dialect.py +3 -0
  34. bloqade/qasm2/dialects/expr/_emit.py +103 -0
  35. bloqade/qasm2/dialects/expr/_from_python.py +86 -0
  36. bloqade/qasm2/dialects/expr/_interp.py +75 -0
  37. bloqade/qasm2/dialects/expr/stmts.py +262 -0
  38. bloqade/qasm2/dialects/glob.py +45 -0
  39. bloqade/qasm2/dialects/indexing.py +64 -0
  40. bloqade/qasm2/dialects/inline.py +76 -0
  41. bloqade/qasm2/dialects/noise.py +16 -0
  42. bloqade/qasm2/dialects/parallel.py +110 -0
  43. bloqade/qasm2/dialects/uop/__init__.py +4 -0
  44. bloqade/qasm2/dialects/uop/_dialect.py +3 -0
  45. bloqade/qasm2/dialects/uop/_emit.py +211 -0
  46. bloqade/qasm2/dialects/uop/schedule.py +89 -0
  47. bloqade/qasm2/dialects/uop/stmts.py +325 -0
  48. bloqade/qasm2/emit/__init__.py +1 -0
  49. bloqade/qasm2/emit/base.py +72 -0
  50. bloqade/qasm2/emit/gate.py +102 -0
  51. bloqade/qasm2/emit/main.py +106 -0
  52. bloqade/qasm2/emit/target.py +165 -0
  53. bloqade/qasm2/glob.py +24 -0
  54. bloqade/qasm2/groups.py +120 -0
  55. bloqade/qasm2/parallel.py +48 -0
  56. bloqade/qasm2/parse/__init__.py +37 -0
  57. bloqade/qasm2/parse/ast.py +235 -0
  58. bloqade/qasm2/parse/build.py +289 -0
  59. bloqade/qasm2/parse/lowering.py +553 -0
  60. bloqade/qasm2/parse/parser.py +5 -0
  61. bloqade/qasm2/parse/print.py +293 -0
  62. bloqade/qasm2/parse/qasm2.lark +75 -0
  63. bloqade/qasm2/parse/visitor.py +16 -0
  64. bloqade/qasm2/parse/visitor.pyi +39 -0
  65. bloqade/qasm2/passes/__init__.py +5 -0
  66. bloqade/qasm2/passes/fold.py +94 -0
  67. bloqade/qasm2/passes/glob.py +119 -0
  68. bloqade/qasm2/passes/noise.py +61 -0
  69. bloqade/qasm2/passes/parallel.py +176 -0
  70. bloqade/qasm2/passes/py2qasm.py +63 -0
  71. bloqade/qasm2/passes/qasm2py.py +61 -0
  72. bloqade/qasm2/rewrite/__init__.py +12 -0
  73. bloqade/qasm2/rewrite/desugar.py +28 -0
  74. bloqade/qasm2/rewrite/glob.py +103 -0
  75. bloqade/qasm2/rewrite/heuristic_noise.py +247 -0
  76. bloqade/qasm2/rewrite/native_gates.py +447 -0
  77. bloqade/qasm2/rewrite/parallel_to_uop.py +83 -0
  78. bloqade/qasm2/rewrite/register.py +45 -0
  79. bloqade/qasm2/rewrite/uop_to_parallel.py +395 -0
  80. bloqade/qasm2/types.py +39 -0
  81. bloqade/qbraid/__init__.py +2 -0
  82. bloqade/qbraid/lowering.py +324 -0
  83. bloqade/qbraid/schema.py +252 -0
  84. bloqade/qbraid/simulation_result.py +99 -0
  85. bloqade/qbraid/target.py +86 -0
  86. bloqade/squin/__init__.py +2 -0
  87. bloqade/squin/analysis/__init__.py +0 -0
  88. bloqade/squin/analysis/nsites/__init__.py +8 -0
  89. bloqade/squin/analysis/nsites/analysis.py +52 -0
  90. bloqade/squin/analysis/nsites/impls.py +69 -0
  91. bloqade/squin/analysis/nsites/lattice.py +49 -0
  92. bloqade/squin/analysis/schedule.py +244 -0
  93. bloqade/squin/groups.py +38 -0
  94. bloqade/squin/op/__init__.py +132 -0
  95. bloqade/squin/op/_dialect.py +3 -0
  96. bloqade/squin/op/complex.py +6 -0
  97. bloqade/squin/op/stmts.py +220 -0
  98. bloqade/squin/op/traits.py +43 -0
  99. bloqade/squin/op/types.py +10 -0
  100. bloqade/squin/qubit.py +118 -0
  101. bloqade/squin/wire.py +103 -0
  102. bloqade/stim/__init__.py +6 -0
  103. bloqade/stim/_wrappers.py +186 -0
  104. bloqade/stim/dialects/__init__.py +5 -0
  105. bloqade/stim/dialects/aux/__init__.py +11 -0
  106. bloqade/stim/dialects/aux/_dialect.py +3 -0
  107. bloqade/stim/dialects/aux/emit.py +102 -0
  108. bloqade/stim/dialects/aux/interp.py +39 -0
  109. bloqade/stim/dialects/aux/lowering.py +40 -0
  110. bloqade/stim/dialects/aux/stmts/__init__.py +14 -0
  111. bloqade/stim/dialects/aux/stmts/annotate.py +47 -0
  112. bloqade/stim/dialects/aux/stmts/const.py +95 -0
  113. bloqade/stim/dialects/aux/types.py +19 -0
  114. bloqade/stim/dialects/collapse/__init__.py +3 -0
  115. bloqade/stim/dialects/collapse/_dialect.py +3 -0
  116. bloqade/stim/dialects/collapse/emit.py +68 -0
  117. bloqade/stim/dialects/collapse/stmts/__init__.py +3 -0
  118. bloqade/stim/dialects/collapse/stmts/measure.py +45 -0
  119. bloqade/stim/dialects/collapse/stmts/pp_measure.py +14 -0
  120. bloqade/stim/dialects/collapse/stmts/reset.py +26 -0
  121. bloqade/stim/dialects/gate/__init__.py +3 -0
  122. bloqade/stim/dialects/gate/_dialect.py +3 -0
  123. bloqade/stim/dialects/gate/emit.py +87 -0
  124. bloqade/stim/dialects/gate/stmts/__init__.py +14 -0
  125. bloqade/stim/dialects/gate/stmts/base.py +31 -0
  126. bloqade/stim/dialects/gate/stmts/clifford_1q.py +53 -0
  127. bloqade/stim/dialects/gate/stmts/clifford_2q.py +11 -0
  128. bloqade/stim/dialects/gate/stmts/control_2q.py +21 -0
  129. bloqade/stim/dialects/gate/stmts/pp.py +15 -0
  130. bloqade/stim/dialects/noise/__init__.py +3 -0
  131. bloqade/stim/dialects/noise/_dialect.py +3 -0
  132. bloqade/stim/dialects/noise/emit.py +66 -0
  133. bloqade/stim/dialects/noise/stmts.py +77 -0
  134. bloqade/stim/emit/__init__.py +1 -0
  135. bloqade/stim/emit/stim.py +54 -0
  136. bloqade/stim/groups.py +26 -0
  137. bloqade/test_utils.py +35 -0
  138. bloqade/types.py +24 -0
  139. bloqade/visual/__init__.py +1 -0
  140. bloqade/visual/animation/__init__.py +0 -0
  141. bloqade/visual/animation/animate.py +267 -0
  142. bloqade/visual/animation/base.py +346 -0
  143. bloqade/visual/animation/gate_event.py +24 -0
  144. bloqade/visual/animation/runtime/__init__.py +0 -0
  145. bloqade/visual/animation/runtime/aod.py +36 -0
  146. bloqade/visual/animation/runtime/atoms.py +55 -0
  147. bloqade/visual/animation/runtime/ppoly.py +50 -0
  148. bloqade/visual/animation/runtime/qpustate.py +119 -0
  149. bloqade/visual/animation/runtime/utils.py +43 -0
  150. bloqade_circuit-0.1.0.dist-info/METADATA +70 -0
  151. bloqade_circuit-0.1.0.dist-info/RECORD +153 -0
  152. bloqade_circuit-0.1.0.dist-info/WHEEL +4 -0
  153. bloqade_circuit-0.1.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,247 @@
1
+ from typing import Dict, List, Tuple
2
+ from dataclasses import field, dataclass
3
+
4
+ from kirin import ir
5
+ from kirin.rewrite import abc as result_abc, result
6
+ from kirin.dialects import py, ilist
7
+
8
+ from bloqade.noise import native
9
+ from bloqade.analysis import address
10
+ from bloqade.qasm2.dialects import uop, core, glob, parallel
11
+
12
+
13
@dataclass
class NoiseRewriteRule(result_abc.RewriteRule):
    """
    NOTE: This pass is not guaranteed to be supported long-term in bloqade. We will be
    moving towards a more general approach to noise modeling in the future.

    Insert heuristic noise statements in front of QASM2 gate statements.

    - `core.QRegNew` is expanded into one `core.QRegGet` per qubit. The
      resulting SSA values are recorded in `qubit_ssa_value` so that later
      noise statements can address individual qubits.
    - Single-qubit, parallel single-qubit and global single-qubit gates get a
      `native.PauliChannel` followed by a `native.AtomLossChannel`.
    - CZ and parallel CZ gates get move noise (from `noise_model`) plus
      `native.CZPauliChannel` / `native.AtomLossChannel` gate noise.

    Statements whose addresses cannot be resolved to concrete qubits are left
    unchanged (apart from the CZ gate noise, which does not need addresses).
    """

    # Address lattice element for each SSA value (from bloqade.analysis.address).
    address_analysis: Dict[ir.SSAValue, address.Address]
    # Probabilities used for the inserted gate-noise channels.
    gate_noise_params: native.GateNoiseParams = field(
        default_factory=native.GateNoiseParams
    )
    # Model queried for move noise around CZ gates.
    noise_model: native.MoveNoiseModelABC = field(
        default_factory=native.TwoRowZoneModel
    )
    # Qubit id -> SSA value of the QRegGet created for it by rewrite_qreg_new.
    qubit_ssa_value: Dict[int, ir.SSAValue] = field(default_factory=dict, init=False)

    def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
        """Dispatch *node* to the matching noise-insertion method."""
        if isinstance(node, core.QRegNew):
            return self.rewrite_qreg_new(node)
        elif isinstance(node, uop.SingleQubitGate):
            return self.rewrite_single_qubit_gate(node)
        elif isinstance(node, uop.CZ):
            return self.rewrite_cz_gate(node)
        elif isinstance(node, (parallel.UGate, parallel.RZ)):
            return self.rewrite_parallel_single_qubit_gate(node)
        elif isinstance(node, parallel.CZ):
            return self.rewrite_parallel_cz_gate(node)
        elif isinstance(node, glob.UGate):
            return self.rewrite_global_single_qubit_gate(node)
        else:
            return result.RewriteResult()

    def rewrite_qreg_new(self, node: core.QRegNew):
        """Create one `QRegGet` per qubit of a newly allocated register.

        Each qubit's SSA value is stored in `qubit_ssa_value`. Qubits that
        already have an entry are skipped.
        """
        addr = self.address_analysis.get(node.result)
        if not isinstance(addr, address.AddressReg):
            return result.RewriteResult()

        has_done_something = False
        for idx_val, qid in enumerate(addr.data):
            if qid in self.qubit_ssa_value:
                continue

            has_done_something = True
            idx = py.constant.Constant(value=idx_val)
            qubit = core.QRegGet(node.result, idx=idx.result)
            self.qubit_ssa_value[qid] = qubit.result
            # Both are inserted directly after `node`. Inserting `qubit` first
            # and `idx` second leaves `idx` in front of `qubit`, so the index
            # constant is defined before it is used.
            qubit.insert_after(node)
            idx.insert_after(node)

        return result.RewriteResult(has_done_something=has_done_something)

    def insert_single_qubit_noise(
        self,
        node: ir.Statement,
        qargs: ir.SSAValue,
        probs: Tuple[float, float, float, float],
    ):
        """Insert a Pauli channel and an atom-loss channel before *node*.

        Args:
            node: Statement the noise is inserted in front of.
            qargs: SSA value holding the list of qubits the noise acts on.
            probs: ``(px, py, pz, loss_prob)``.
        """
        native.PauliChannel(qargs, px=probs[0], py=probs[1], pz=probs[2]).insert_before(
            node
        )
        native.AtomLossChannel(qargs, prob=probs[3]).insert_before(node)

        return result.RewriteResult(has_done_something=True)

    def rewrite_single_qubit_gate(self, node: uop.SingleQubitGate):
        """Add local single-qubit noise in front of a `uop` single-qubit gate."""
        probs = (
            self.gate_noise_params.local_px,
            self.gate_noise_params.local_py,
            self.gate_noise_params.local_pz,
            self.gate_noise_params.local_loss_prob,
        )
        (qargs := ilist.New(values=(node.qarg,))).insert_before(node)
        return self.insert_single_qubit_noise(node, qargs.result, probs)

    def rewrite_global_single_qubit_gate(self, node: glob.UGate):
        """Add global single-qubit noise to every qubit of a `glob.UGate`.

        The rule only fires when every register resolves to an `AddressReg`.
        It also requires that every qubit already has an SSA value in
        `qubit_ssa_value`. Otherwise the gate is left unchanged.
        """
        addrs = self.address_analysis.get(node.registers)
        if not isinstance(addrs, address.AddressTuple):
            return result.RewriteResult()

        qubit_values: List[ir.SSAValue] = []
        for addr in addrs.data:
            if not isinstance(addr, address.AddressReg):
                return result.RewriteResult()

            for qid in addr.data:
                qubit = self.qubit_ssa_value.get(qid)
                if qubit is None:
                    # No QRegGet was recorded for this qubit (presumably its
                    # QRegNew was not rewritten), so there is no SSA value
                    # to attach the noise to.
                    return result.RewriteResult()
                qubit_values.append(qubit)

        probs = (
            self.gate_noise_params.global_px,
            self.gate_noise_params.global_py,
            self.gate_noise_params.global_pz,
            self.gate_noise_params.global_loss_prob,
        )
        (qargs := ilist.New(values=tuple(qubit_values))).insert_before(node)
        return self.insert_single_qubit_noise(node, qargs.result, probs)

    def rewrite_parallel_single_qubit_gate(self, node: parallel.RZ | parallel.UGate):
        """Add local single-qubit noise in front of a parallel single-qubit gate.

        The gate's own `qargs` list is reused, so it must be produced by an
        `ilist.New` statement and resolve to a tuple of `AddressQubit`s.
        """
        addrs = self.address_analysis.get(node.qargs)
        if not isinstance(addrs, address.AddressTuple):
            return result.RewriteResult()

        if not all(isinstance(addr, address.AddressQubit) for addr in addrs.data):
            return result.RewriteResult()

        probs = (
            self.gate_noise_params.local_px,
            self.gate_noise_params.local_py,
            self.gate_noise_params.local_pz,
            self.gate_noise_params.local_loss_prob,
        )
        assert isinstance(node.qargs, ir.ResultValue)
        assert isinstance(node.qargs.stmt, ilist.New)
        return self.insert_single_qubit_noise(node, node.qargs, probs)

    def move_noise_stmts(
        self,
        errors: Dict[Tuple[float, float, float, float], List[int]],
    ) -> list[ir.Statement]:
        """Build (without inserting) move-noise statements.

        Args:
            errors: Map from ``(px, py, pz, loss_prob)`` to the qubit ids that
                receive that noise. Empty qubit lists are skipped.

        Returns:
            For each group: an `ilist.New` of the qubits, an `AtomLossChannel`
            and a `PauliChannel` acting on that list, in that order.
        """
        nodes = []

        for probs, qubits in errors.items():
            if len(qubits) == 0:
                continue

            nodes.append(
                qargs := ilist.New(tuple(self.qubit_ssa_value[q] for q in qubits))
            )
            nodes.append(native.AtomLossChannel(qargs.result, prob=probs[3]))
            nodes.append(
                native.PauliChannel(qargs.result, px=probs[0], py=probs[1], pz=probs[2])
            )

        return nodes

    def cz_gate_noise(
        self,
        ctrls: ir.SSAValue,
        qargs: ir.SSAValue,
    ) -> list[ir.Statement]:
        """Build (without inserting) the gate-noise statements for a CZ.

        The result contains:
        - a `CZPauliChannel` with ``paired=True`` using the ``cz_paired_*``
          parameters;
        - a `CZPauliChannel` with ``paired=False`` using the ``cz_unpaired_*``
          parameters;
        - one `AtomLossChannel` each for *ctrls* and *qargs*.
        """
        return [
            native.CZPauliChannel(
                ctrls,
                qargs,
                px_ctrl=self.gate_noise_params.cz_paired_gate_px,
                py_ctrl=self.gate_noise_params.cz_paired_gate_py,
                pz_ctrl=self.gate_noise_params.cz_paired_gate_pz,
                px_qarg=self.gate_noise_params.cz_paired_gate_px,
                py_qarg=self.gate_noise_params.cz_paired_gate_py,
                pz_qarg=self.gate_noise_params.cz_paired_gate_pz,
                paired=True,
            ),
            native.CZPauliChannel(
                ctrls,
                qargs,
                px_ctrl=self.gate_noise_params.cz_unpaired_gate_px,
                py_ctrl=self.gate_noise_params.cz_unpaired_gate_py,
                pz_ctrl=self.gate_noise_params.cz_unpaired_gate_pz,
                px_qarg=self.gate_noise_params.cz_unpaired_gate_px,
                py_qarg=self.gate_noise_params.cz_unpaired_gate_py,
                pz_qarg=self.gate_noise_params.cz_unpaired_gate_pz,
                paired=False,
            ),
            native.AtomLossChannel(
                ctrls, prob=self.gate_noise_params.cz_gate_loss_prob
            ),
            native.AtomLossChannel(
                qargs, prob=self.gate_noise_params.cz_gate_loss_prob
            ),
        ]

    def rewrite_cz_gate(self, node: uop.CZ):
        """Insert move noise and CZ gate noise in front of a `uop.CZ`.

        Move noise is only added when both operands resolve to single qubits.
        Gate noise is always added.
        """
        has_done_something = False

        qarg_addr = self.address_analysis.get(node.qarg)
        ctrl_addr = self.address_analysis.get(node.ctrl)

        (ctrls := ilist.New([node.ctrl])).insert_before(node)
        (qargs := ilist.New([node.qarg])).insert_before(node)

        if isinstance(qarg_addr, address.AddressQubit) and isinstance(
            ctrl_addr, address.AddressQubit
        ):
            # Every other known qubit is a bystander for the move-noise model.
            other_qubits = sorted(
                set(self.qubit_ssa_value.keys()) - {ctrl_addr.data, qarg_addr.data}
            )
            errors = self.noise_model.parallel_cz_errors(
                [ctrl_addr.data], [qarg_addr.data], other_qubits
            )

            for new_node in self.move_noise_stmts(errors):
                new_node.insert_before(node)
                has_done_something = True

        for new_node in self.cz_gate_noise(ctrls.result, qargs.result):
            new_node.insert_before(node)
            has_done_something = True

        return result.RewriteResult(has_done_something=has_done_something)

    def rewrite_parallel_cz_gate(self, node: parallel.CZ):
        """Insert move noise and CZ gate noise in front of a `parallel.CZ`.

        Move noise is only added when both ``ctrls`` and ``qargs`` resolve to
        tuples of single qubits. Gate noise is always added.
        """
        ctrls = self.address_analysis.get(node.ctrls)
        qargs = self.address_analysis.get(node.qargs)

        has_done_something = False
        if (
            isinstance(ctrls, address.AddressTuple)
            and all(isinstance(addr, address.AddressQubit) for addr in ctrls.data)
            and isinstance(qargs, address.AddressTuple)
            and all(isinstance(addr, address.AddressQubit) for addr in qargs.data)
        ):
            ctrl_qubits = [addr.data for addr in ctrls.data]
            qarg_qubits = [addr.data for addr in qargs.data]
            rest = sorted(
                set(self.qubit_ssa_value.keys()) - set(ctrl_qubits + qarg_qubits)
            )
            errors = self.noise_model.parallel_cz_errors(ctrl_qubits, qarg_qubits, rest)

            for new_node in self.move_noise_stmts(errors):
                new_node.insert_before(node)
                has_done_something = True

        for new_node in self.cz_gate_noise(node.ctrls, node.qargs):
            new_node.insert_before(node)
            has_done_something = True

        return result.RewriteResult(has_done_something=has_done_something)
@@ -0,0 +1,447 @@
1
+ import math
2
+ from typing import List, Optional
3
+ from functools import cached_property
4
+ from dataclasses import field, dataclass
5
+
6
+ import cirq
7
+ import numpy as np
8
+ import cirq.transformers
9
+ import cirq.contrib.qasm_import
10
+ import cirq.transformers.target_gatesets
11
+ import cirq.transformers.target_gatesets.compilation_target_gateset
12
+ from kirin import ir
13
+ from kirin.rewrite import abc, result
14
+ from kirin.dialects import py
15
+ from cirq.circuits.qasm_output import QasmUGate
16
+ from cirq.transformers.target_gatesets.compilation_target_gateset import (
17
+ CompilationTargetGateset,
18
+ )
19
+
20
+ from bloqade.qasm2.dialects import uop, expr
21
+
22
+
23
# Rydberg gate sets
class RydbergTargetGateset(cirq.CZTargetGateset):
    """CZ-based cirq target gateset, optionally extended with multi-controlled Z.

    Args:
        cnz_max_size: Width (in qubits) of the largest multi-controlled Z gate
            to accept. ``cirq.Z.controlled(n)`` is added for every
            ``2 <= n < cnz_max_size``. The default of 2 adds no extra gates.
        atol: Absolute tolerance forwarded to ``cirq.CZTargetGateset``.
    """

    def __init__(self, *, cnz_max_size: int = 2, atol: float = 1e-8):
        extra_gates = []
        for num_controls in range(2, cnz_max_size):
            extra_gates.append(cirq.Z.controlled(num_controls))
        super().__init__(atol=atol, additional_gates=extra_gates)
        self.cnz_max_size = cnz_max_size

    @property
    def num_qubits(self) -> int:
        """Maximum gate width handled by this gateset (never less than 2)."""
        return self.cnz_max_size if self.cnz_max_size > 2 else 2
33
+
34
+
35
# Decompose CU by defining a custom cirq Gate that follows the qelib1 definition.
# U(theta, phi, lambda) in standard QASM2 and cirq's QasmUGate may differ by a
# global phase. That phase would turn into a relative phase if the U gate were
# wrapped in a control, so CU is built only from uncontrolled one-qubit gates
# and CX.
class CU(cirq.Gate):
    """Two-qubit controlled-U gate ``cu(theta, phi, lam, gamma) c, t`` (qelib1).

    Angles are given in radians. The first qubit is the control and the second
    is the target.
    """

    def __init__(self, theta, phi, lam, gamma):
        super().__init__()
        self.theta = theta
        self.phi = phi
        self.lam = lam
        self.gamma = gamma

    def _num_qubits_(self):
        return 2

    def _decompose_(self, qubits):
        ctrl, target = qubits
        # Taken from the qelib1 definition. QasmUGate takes half-turns, hence
        # the division of every angle by pi.
        # p(gamma) c;
        yield QasmUGate(0, 0, self.gamma / math.pi)(ctrl)
        # p((lambda+phi)/2) c;
        yield QasmUGate(0, 0, ((self.lam + self.phi) / 2) / math.pi)(ctrl)
        # p((lambda-phi)/2) t;
        yield QasmUGate(0, 0, ((self.lam - self.phi) / 2) / math.pi)(target)
        # cx c,t;
        yield cirq.CX(ctrl, target)
        # u(-theta/2, 0, -(phi+lambda)/2) t;
        yield QasmUGate(
            (-self.theta / 2) / math.pi, 0, (-(self.phi + self.lam) / 2) / math.pi
        )(target)
        # cx c,t;
        yield cirq.CX(ctrl, target)
        # u(theta/2, phi, 0) t;
        yield QasmUGate((self.theta / 2) / math.pi, self.phi / math.pi, 0)(target)

    def _circuit_diagram_info_(self, args):
        return "*", "CU"
71
+
72
+
73
def around(val):
    """Round *val* to 14 decimal places and return it as a plain Python float."""
    rounded = np.round(val, decimals=14)
    return float(rounded)
75
+
76
+
77
def one_qubit_gate_to_u3_angles(op: cirq.Operation) -> tuple[float, float, float]:
    """Return the ``(theta, phi, lam)`` U3 angles, in radians, of a one-qubit op.

    cirq splits the unitary into a first Z rotation, a Y rotation and a final
    Z rotation. These map onto U3's ``lam``, ``theta`` and ``phi``
    respectively. Each angle is passed through ``around``.
    """
    matrix = cirq.unitary(op)
    z_first, y_rotation, z_last = cirq.deconstruct_single_qubit_matrix_into_angles(
        matrix
    )
    return (around(y_rotation), around(z_last), around(z_first))
82
+
83
+
84
+ @dataclass
85
+ class RydbergGateSetRewriteRule(abc.RewriteRule):
86
+ # NOTE
87
+ # 1. this can only rewrite qasm2.main and qasm2.gate!
88
+ dialect_group: ir.DialectGroup
89
+ gateset: CompilationTargetGateset = field(default_factory=RydbergTargetGateset)
90
+
91
+ @cached_property
92
+ def cached_qubits(self) -> tuple[cirq.LineQubit, ...]:
93
+
94
+ # qasm2 stmts only have up to 3 qubits gates, so we cached only 3.
95
+ return tuple(cirq.LineQubit(i) for i in range(3))
96
+
97
+ @cached_property
98
+ def const_float_type(self):
99
+ if expr.dialect in self.dialect_group.data:
100
+ return expr.ConstFloat
101
+ else:
102
+ return py.constant.Constant
103
+
104
+ def const_float(self, value: float):
105
+ return self.const_float_type(value=value)
106
+
107
+ @cached_property
108
+ def const_pi(self):
109
+ if expr in self.dialect_group.data:
110
+ return expr.ConstPI()
111
+ else:
112
+ return py.constant.Constant(value=math.pi)
113
+
114
+ def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
115
+ # only deal with uop
116
+ if type(node) in uop.dialect.stmts:
117
+ return getattr(self, f"rewrite_{node.name}")(node)
118
+
119
+ return result.RewriteResult()
120
+
121
+ def rewrite_barrier(self, node: uop.Barrier) -> result.RewriteResult:
122
+ return result.RewriteResult()
123
+
124
+ def rewrite_cz(self, node: uop.CZ) -> result.RewriteResult:
125
+ return result.RewriteResult()
126
+
127
+ def rewrite_CX(self, node: uop.CX) -> result.RewriteResult:
128
+ return self._rewrite_2q_ctrl_gates(
129
+ cirq.CX(self.cached_qubits[0], self.cached_qubits[1]), node
130
+ )
131
+
132
+ def rewrite_cy(self, node: uop.CY) -> result.RewriteResult:
133
+ return self._rewrite_2q_ctrl_gates(
134
+ cirq.ControlledGate(cirq.Y, 1)(
135
+ self.cached_qubits[0], self.cached_qubits[1]
136
+ ),
137
+ node,
138
+ )
139
+
140
+ def rewrite_U(self, node: uop.UGate) -> result.RewriteResult:
141
+ return result.RewriteResult()
142
+
143
+ def rewrite_id(self, node: uop.Id) -> result.RewriteResult:
144
+ node.delete() # just delete the identity gate
145
+ return result.RewriteResult(has_done_something=True)
146
+
147
+ def rewrite_h(self, node: uop.H) -> result.RewriteResult:
148
+ return self._rewrite_1q_gates(cirq.H(self.cached_qubits[0]), node)
149
+
150
+ def rewrite_x(self, node: uop.X) -> result.RewriteResult:
151
+ return self._rewrite_1q_gates(cirq.X(self.cached_qubits[0]), node)
152
+
153
+ def rewrite_y(self, node: uop.Y) -> result.RewriteResult:
154
+ return self._rewrite_1q_gates(cirq.Y(self.cached_qubits[0]), node)
155
+
156
+ def rewrite_z(self, node: uop.Z) -> result.RewriteResult:
157
+ return self._rewrite_1q_gates(cirq.Z(self.cached_qubits[0]), node)
158
+
159
+ def rewrite_s(self, node: uop.S) -> result.RewriteResult:
160
+ return self._rewrite_1q_gates(cirq.S(self.cached_qubits[0]), node)
161
+
162
+ def rewrite_sdg(self, node: uop.Sdag) -> result.RewriteResult:
163
+ return self._rewrite_1q_gates(cirq.S(self.cached_qubits[0]) ** -1, node)
164
+
165
+ def rewrite_t(self, node: uop.T) -> result.RewriteResult:
166
+ return self._rewrite_1q_gates(cirq.T(self.cached_qubits[0]), node)
167
+
168
+ def rewrite_tdg(self, node: uop.Tdag) -> result.RewriteResult:
169
+ return self._rewrite_1q_gates(cirq.T(self.cached_qubits[0]) ** -1, node)
170
+
171
+ def rewrite_sx(self, node: uop.SX) -> result.RewriteResult:
172
+ return self._rewrite_1q_gates(
173
+ cirq.XPowGate(exponent=0.5).on(self.cached_qubits[0]), node
174
+ )
175
+
176
+ def rewrite_sxdg(self, node: uop.SXdag) -> result.RewriteResult:
177
+ return self._rewrite_1q_gates(
178
+ cirq.XPowGate(exponent=-0.5).on(self.cached_qubits[0]), node
179
+ )
180
+
181
+ def rewrite_u1(self, node: uop.U1) -> result.RewriteResult:
182
+ theta = node.lam
183
+ (phi := self.const_float(value=0.0)).insert_before(node)
184
+ node.replace_by(
185
+ uop.UGate(qarg=node.qarg, theta=phi.result, phi=phi.result, lam=theta)
186
+ )
187
+ return result.RewriteResult(has_done_something=True)
188
+
189
+ def rewrite_u2(self, node: uop.U2) -> result.RewriteResult:
190
+ phi = node.phi
191
+ lam = node.lam
192
+ (theta := self.const_float(value=math.pi / 2)).insert_before(node)
193
+ node.replace_by(uop.UGate(qarg=node.qarg, theta=theta.result, phi=phi, lam=lam))
194
+ return result.RewriteResult(has_done_something=True)
195
+
196
+ def rewrite_rx(self, node: uop.RX) -> result.RewriteResult:
197
+ theta = node.theta
198
+ (phi := self.const_float(value=math.pi / 2)).insert_before(node)
199
+ (lam := self.const_float(value=-math.pi / 2)).insert_before(node)
200
+ node.replace_by(
201
+ uop.UGate(qarg=node.qarg, theta=theta, phi=phi.result, lam=lam.result)
202
+ )
203
+ return result.RewriteResult(has_done_something=True)
204
+
205
+ def rewrite_ry(self, node: uop.RY) -> result.RewriteResult:
206
+ theta = node.theta
207
+ (phi := self.const_float(value=0.0)).insert_before(node)
208
+ node.replace_by(
209
+ uop.UGate(qarg=node.qarg, theta=theta, phi=phi.result, lam=phi.result)
210
+ )
211
+ return result.RewriteResult(has_done_something=True)
212
+
213
+ def rewrite_rz(self, node: uop.RZ) -> result.RewriteResult:
214
+ theta = node.theta
215
+ (phi := self.const_float(value=0.0)).insert_before(node)
216
+ node.replace_by(
217
+ uop.UGate(qarg=node.qarg, theta=phi.result, phi=phi.result, lam=theta)
218
+ )
219
+ return result.RewriteResult(has_done_something=True)
220
+
221
+ def rewrite_crx(self, node: uop.CRX) -> result.RewriteResult:
222
+ lam = self._get_const_value(node.lam)
223
+
224
+ if lam is None:
225
+ return result.RewriteResult()
226
+
227
+ return self._rewrite_2q_ctrl_gates(
228
+ cirq.ControlledGate(cirq.Rx(rads=lam), 1).on(
229
+ self.cached_qubits[0], self.cached_qubits[1]
230
+ ),
231
+ node,
232
+ )
233
+
234
+ def rewrite_cry(self, node: uop.CRY) -> result.RewriteResult:
235
+ lam = self._get_const_value(node.lam)
236
+
237
+ if lam is None:
238
+ return result.RewriteResult()
239
+
240
+ return self._rewrite_2q_ctrl_gates(
241
+ cirq.ControlledGate(cirq.Ry(rads=lam), 1).on(
242
+ self.cached_qubits[0], self.cached_qubits[1]
243
+ ),
244
+ node,
245
+ )
246
+
247
+ def rewrite_crz(self, node: uop.CRZ) -> result.RewriteResult:
248
+ lam = self._get_const_value(node.lam)
249
+
250
+ if lam is None:
251
+ return result.RewriteResult()
252
+
253
+ return self._rewrite_2q_ctrl_gates(
254
+ cirq.ControlledGate(cirq.Rz(rads=lam), 1).on(
255
+ self.cached_qubits[0], self.cached_qubits[1]
256
+ ),
257
+ node,
258
+ )
259
+
260
+ def rewrite_cu1(self, node: uop.CU1) -> result.RewriteResult:
261
+
262
+ lam = self._get_const_value(node.lam)
263
+
264
+ if lam is None:
265
+ return result.RewriteResult()
266
+
267
+ # cirq.ControlledGate(u3(0, 0, lambda))
268
+ return self._rewrite_2q_ctrl_gates(
269
+ cirq.ControlledGate(QasmUGate(0, 0, lam / math.pi)).on(
270
+ self.cached_qubits[0], self.cached_qubits[1]
271
+ ),
272
+ node,
273
+ )
274
+ pass
275
+
276
+ def rewrite_cu3(self, node: uop.CU3) -> result.RewriteResult:
277
+
278
+ theta = self._get_const_value(node.theta)
279
+ lam = self._get_const_value(node.lam)
280
+ phi = self._get_const_value(node.phi)
281
+
282
+ if not all((theta, phi, lam)):
283
+ return result.RewriteResult()
284
+
285
+ # cirq.ControlledGate(u3(theta, lambda phi))
286
+ return self._rewrite_2q_ctrl_gates(
287
+ cirq.ControlledGate(
288
+ QasmUGate(theta / math.pi, phi / math.pi, lam / math.pi)
289
+ ).on(self.cached_qubits[0], self.cached_qubits[1]),
290
+ node,
291
+ )
292
+
293
+ def rewrite_cu(self, node: uop.CU) -> result.RewriteResult:
294
+
295
+ gamma = self._get_const_value(node.gamma)
296
+ theta = self._get_const_value(node.theta)
297
+ lam = self._get_const_value(node.lam)
298
+ phi = self._get_const_value(node.phi)
299
+
300
+ # need to create custom 2q gate, then feed that into rewrite_2q
301
+
302
+ return self._rewrite_2q_ctrl_gates(
303
+ CU(theta, phi, lam, gamma).on(self.cached_qubits[0], self.cached_qubits[1]),
304
+ node,
305
+ )
306
+
307
+ def rewrite_rxx(self, node: uop.RXX) -> result.RewriteResult:
308
+
309
+ theta = self._get_const_value(node.theta)
310
+
311
+ if theta is None:
312
+ return result.RewriteResult()
313
+
314
+ # even though the XX gate is not controlled,
315
+ # the end U + CZ decomposition that happens internally means
316
+ return self._rewrite_2q_ctrl_gates(
317
+ cirq.XXPowGate(exponent=theta / math.pi).on(
318
+ self.cached_qubits[0], self.cached_qubits[1]
319
+ ),
320
+ node,
321
+ )
322
+
323
+ def rewrite_rzz(self, node: uop.RZZ) -> result.RewriteResult:
324
+ theta = self._get_const_value(node.theta)
325
+
326
+ if theta is None:
327
+ return result.RewriteResult()
328
+
329
+ return self._rewrite_2q_ctrl_gates(
330
+ cirq.ZZPowGate(exponent=theta / math.pi).on(
331
+ self.cached_qubits[0], self.cached_qubits[1]
332
+ ),
333
+ node,
334
+ )
335
+
336
+ """
337
+ return self._rewrite_2q_ctrl_gates(
338
+ cirq.ZZPowGate(exponent = theta).on(self.cached_qubits[0], self.cached_qubits[1])
339
+ ,node
340
+ )
341
+ """
342
+
343
+ def rewrite_swap(self, node: uop.Swap):
344
+ return self._rewrite_2q_ctrl_gates(
345
+ cirq.SWAP(self.cached_qubits[0], self.cached_qubits[1]), node
346
+ )
347
+
348
+ def _get_const_value(self, ssa: ir.SSAValue) -> Optional[float | int]:
349
+ if not isinstance(ssa, ir.ResultValue):
350
+ return None
351
+
352
+ match ssa.owner:
353
+ case expr.ConstFloat(value=value):
354
+ return value
355
+ case expr.ConstInt(value=value):
356
+ return value
357
+ case py.constant.Constant(value=float() as value) | py.constant.Constant(
358
+ value=int() as value
359
+ ):
360
+ return value
361
+ case expr.ConstPI():
362
+ return math.pi
363
+ case _:
364
+ return None
365
+
366
+ def _generate_1q_gate_stmts(self, cirq_gate: cirq.Operation, qarg: ir.SSAValue):
367
+ target_gates = self.gateset.decompose_to_target_gateset(cirq_gate, 0)
368
+
369
+ if isinstance(target_gates, cirq.GateOperation):
370
+ target_gates = [target_gates]
371
+
372
+ new_stmts = []
373
+ for new_gate in target_gates:
374
+ theta, phi, lam = one_qubit_gate_to_u3_angles(new_gate)
375
+ theta_stmt = self.const_float(value=theta)
376
+ phi_stmt = self.const_float(value=phi)
377
+ lam_stmt = self.const_float(value=lam)
378
+
379
+ new_stmts.append(theta_stmt)
380
+ new_stmts.append(phi_stmt)
381
+ new_stmts.append(lam_stmt)
382
+ new_stmts.append(
383
+ uop.UGate(
384
+ qarg=qarg,
385
+ theta=theta_stmt.result,
386
+ phi=phi_stmt.result,
387
+ lam=lam_stmt.result,
388
+ )
389
+ )
390
+ return new_stmts
391
+
392
+ def _rewrite_1q_gates(
393
+ self, cirq_gate: cirq.Operation, node: uop.SingleQubitGate
394
+ ) -> result.RewriteResult:
395
+ new_gate_stmts = self._generate_1q_gate_stmts(cirq_gate, node.qarg)
396
+ return self._rewrite_gate_stmts(new_gate_stmts, node)
397
+
398
+ def _generate_2q_ctrl_gate_stmts(
399
+ self, cirq_gate: cirq.Operation, qubits_ssa: List[ir.SSAValue]
400
+ ) -> list[ir.Statement]:
401
+ target_gates = self.gateset.decompose_to_target_gateset(cirq_gate, 0)
402
+ new_stmts = []
403
+ for new_gate in target_gates:
404
+ if len(new_gate.qubits) == 1:
405
+ # 1q
406
+ phi0, phi1, phi2 = one_qubit_gate_to_u3_angles(new_gate)
407
+ phi0_stmt = self.const_float(value=phi0)
408
+ phi1_stmt = self.const_float(value=phi1)
409
+ phi2_stmt = self.const_float(value=phi2)
410
+
411
+ new_stmts.append(phi0_stmt)
412
+ new_stmts.append(phi1_stmt)
413
+ new_stmts.append(phi2_stmt)
414
+ new_stmts.append(
415
+ uop.UGate(
416
+ qarg=qubits_ssa[new_gate.qubits[0].x],
417
+ theta=phi0_stmt.result,
418
+ phi=phi1_stmt.result,
419
+ lam=phi2_stmt.result,
420
+ )
421
+ )
422
+ else:
423
+ # 2q
424
+ new_stmts.append(uop.CZ(ctrl=qubits_ssa[0], qarg=qubits_ssa[1]))
425
+
426
+ return new_stmts
427
+
428
+ def _rewrite_2q_ctrl_gates(
429
+ self, cirq_gate: cirq.Operation, node: uop.TwoQubitCtrlGate
430
+ ) -> result.RewriteResult:
431
+ new_gate_stmts = self._generate_2q_ctrl_gate_stmts(
432
+ cirq_gate, [node.ctrl, node.qarg]
433
+ )
434
+ return self._rewrite_gate_stmts(new_gate_stmts, node)
435
+
436
+ def _rewrite_gate_stmts(
437
+ self, new_gate_stmts: list[ir.Statement], node: ir.Statement
438
+ ):
439
+
440
+ node.replace_by(new_gate_stmts[0])
441
+ node = new_gate_stmts[0]
442
+
443
+ for stmt in new_gate_stmts[1:]:
444
+ stmt.insert_after(node)
445
+ node = stmt
446
+
447
+ return result.RewriteResult(has_done_something=True)