bloqade-circuit 0.6.4__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/analysis.py +18 -20
  7. bloqade/analysis/measure_id/impls.py +31 -29
  8. bloqade/annotate/__init__.py +6 -0
  9. bloqade/annotate/_dialect.py +3 -0
  10. bloqade/annotate/_interface.py +22 -0
  11. bloqade/annotate/stmts.py +29 -0
  12. bloqade/annotate/types.py +13 -0
  13. bloqade/cirq_utils/__init__.py +4 -2
  14. bloqade/cirq_utils/emit/__init__.py +3 -0
  15. bloqade/cirq_utils/emit/base.py +246 -0
  16. bloqade/cirq_utils/emit/gate.py +104 -0
  17. bloqade/cirq_utils/emit/noise.py +90 -0
  18. bloqade/cirq_utils/emit/qubit.py +35 -0
  19. bloqade/cirq_utils/lowering.py +660 -0
  20. bloqade/cirq_utils/noise/__init__.py +0 -2
  21. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  22. bloqade/cirq_utils/noise/model.py +151 -191
  23. bloqade/cirq_utils/noise/transform.py +2 -2
  24. bloqade/cirq_utils/parallelize.py +9 -6
  25. bloqade/gemini/__init__.py +1 -0
  26. bloqade/gemini/analysis/__init__.py +3 -0
  27. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  28. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  29. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  30. bloqade/gemini/groups.py +67 -0
  31. bloqade/native/__init__.py +23 -0
  32. bloqade/native/_prelude.py +45 -0
  33. bloqade/native/dialects/__init__.py +0 -0
  34. bloqade/native/dialects/gate/__init__.py +2 -0
  35. bloqade/native/dialects/gate/_dialect.py +3 -0
  36. bloqade/native/dialects/gate/_interface.py +32 -0
  37. bloqade/native/dialects/gate/stmts.py +31 -0
  38. bloqade/native/stdlib/__init__.py +0 -0
  39. bloqade/native/stdlib/broadcast.py +246 -0
  40. bloqade/native/stdlib/simple.py +220 -0
  41. bloqade/native/upstream/__init__.py +4 -0
  42. bloqade/native/upstream/squin2native.py +79 -0
  43. bloqade/pyqrack/__init__.py +2 -2
  44. bloqade/pyqrack/base.py +7 -1
  45. bloqade/pyqrack/device.py +192 -18
  46. bloqade/pyqrack/native.py +49 -0
  47. bloqade/pyqrack/reg.py +6 -6
  48. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  49. bloqade/pyqrack/squin/gate/gate.py +136 -0
  50. bloqade/pyqrack/squin/noise/native.py +120 -54
  51. bloqade/pyqrack/squin/qubit.py +39 -36
  52. bloqade/pyqrack/target.py +5 -4
  53. bloqade/pyqrack/task.py +114 -7
  54. bloqade/qasm2/_qasm_loading.py +3 -3
  55. bloqade/qasm2/dialects/core/address.py +21 -12
  56. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  57. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  58. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  59. bloqade/qasm2/dialects/noise/model.py +2 -1
  60. bloqade/qasm2/emit/base.py +16 -11
  61. bloqade/qasm2/emit/gate.py +11 -8
  62. bloqade/qasm2/emit/main.py +103 -3
  63. bloqade/qasm2/emit/target.py +9 -5
  64. bloqade/qasm2/groups.py +3 -2
  65. bloqade/qasm2/parse/lowering.py +0 -1
  66. bloqade/qasm2/passes/fold.py +14 -73
  67. bloqade/qasm2/passes/glob.py +2 -2
  68. bloqade/qasm2/passes/noise.py +1 -1
  69. bloqade/qasm2/passes/parallel.py +7 -5
  70. bloqade/qasm2/rewrite/__init__.py +0 -1
  71. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  72. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  73. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  74. bloqade/qasm2/rewrite/register.py +2 -2
  75. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  76. bloqade/qbraid/lowering.py +1 -0
  77. bloqade/qbraid/schema.py +2 -2
  78. bloqade/qubit/__init__.py +12 -0
  79. bloqade/qubit/_dialect.py +3 -0
  80. bloqade/qubit/_interface.py +49 -0
  81. bloqade/qubit/_prelude.py +45 -0
  82. bloqade/qubit/analysis/__init__.py +1 -0
  83. bloqade/qubit/analysis/address_impl.py +40 -0
  84. bloqade/qubit/stdlib/__init__.py +2 -0
  85. bloqade/qubit/stdlib/_new.py +34 -0
  86. bloqade/qubit/stdlib/broadcast.py +62 -0
  87. bloqade/qubit/stdlib/simple.py +59 -0
  88. bloqade/qubit/stmts.py +60 -0
  89. bloqade/rewrite/passes/__init__.py +6 -0
  90. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  91. bloqade/rewrite/passes/callgraph.py +116 -0
  92. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  93. bloqade/rewrite/rules/split_ifs.py +18 -1
  94. bloqade/squin/__init__.py +47 -14
  95. bloqade/squin/analysis/__init__.py +0 -1
  96. bloqade/squin/analysis/schedule.py +10 -11
  97. bloqade/squin/gate/__init__.py +2 -0
  98. bloqade/squin/gate/_dialect.py +3 -0
  99. bloqade/squin/gate/_interface.py +98 -0
  100. bloqade/squin/gate/stmts.py +125 -0
  101. bloqade/squin/groups.py +5 -22
  102. bloqade/squin/noise/__init__.py +1 -10
  103. bloqade/squin/noise/_dialect.py +1 -1
  104. bloqade/squin/noise/_interface.py +45 -0
  105. bloqade/squin/noise/stmts.py +66 -28
  106. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  107. bloqade/squin/rewrite/__init__.py +0 -2
  108. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  109. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  110. bloqade/squin/stdlib/__init__.py +0 -0
  111. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  112. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  113. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  114. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  115. bloqade/squin/stdlib/simple/__init__.py +33 -0
  116. bloqade/squin/stdlib/simple/gate.py +242 -0
  117. bloqade/squin/stdlib/simple/noise.py +126 -0
  118. bloqade/stim/__init__.py +1 -0
  119. bloqade/stim/_wrappers.py +6 -0
  120. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  121. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  122. bloqade/stim/dialects/gate/emit.py +9 -10
  123. bloqade/stim/dialects/noise/emit.py +17 -13
  124. bloqade/stim/dialects/noise/stmts.py +5 -3
  125. bloqade/stim/emit/__init__.py +1 -0
  126. bloqade/stim/emit/impls.py +16 -0
  127. bloqade/stim/emit/stim_str.py +48 -31
  128. bloqade/stim/groups.py +12 -2
  129. bloqade/stim/parse/lowering.py +14 -17
  130. bloqade/stim/passes/__init__.py +0 -2
  131. bloqade/stim/passes/flatten.py +26 -0
  132. bloqade/stim/passes/simplify_ifs.py +6 -1
  133. bloqade/stim/passes/squin_to_stim.py +9 -84
  134. bloqade/stim/rewrite/__init__.py +2 -4
  135. bloqade/stim/rewrite/get_record_util.py +24 -0
  136. bloqade/stim/rewrite/ifs_to_stim.py +24 -25
  137. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  138. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  139. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  140. bloqade/stim/rewrite/squin_measure.py +9 -18
  141. bloqade/stim/rewrite/squin_noise.py +134 -108
  142. bloqade/stim/rewrite/util.py +5 -192
  143. bloqade/test_utils.py +1 -1
  144. bloqade/types.py +10 -0
  145. bloqade/validation/__init__.py +2 -0
  146. bloqade/validation/analysis/__init__.py +5 -0
  147. bloqade/validation/analysis/analysis.py +41 -0
  148. bloqade/validation/analysis/lattice.py +58 -0
  149. bloqade/validation/kernel_validation.py +77 -0
  150. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  151. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  152. bloqade/pyqrack/squin/op.py +0 -180
  153. bloqade/pyqrack/squin/runtime.py +0 -535
  154. bloqade/pyqrack/squin/wire.py +0 -51
  155. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  156. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  157. bloqade/squin/_typeinfer.py +0 -20
  158. bloqade/squin/analysis/address_impl.py +0 -71
  159. bloqade/squin/analysis/nsites/__init__.py +0 -9
  160. bloqade/squin/analysis/nsites/analysis.py +0 -50
  161. bloqade/squin/analysis/nsites/impls.py +0 -92
  162. bloqade/squin/analysis/nsites/lattice.py +0 -49
  163. bloqade/squin/cirq/__init__.py +0 -280
  164. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  165. bloqade/squin/cirq/emit/noise.py +0 -49
  166. bloqade/squin/cirq/emit/op.py +0 -125
  167. bloqade/squin/cirq/emit/qubit.py +0 -60
  168. bloqade/squin/cirq/emit/runtime.py +0 -242
  169. bloqade/squin/cirq/lowering.py +0 -440
  170. bloqade/squin/lowering.py +0 -54
  171. bloqade/squin/noise/_wrapper.py +0 -40
  172. bloqade/squin/noise/rewrite.py +0 -111
  173. bloqade/squin/op/__init__.py +0 -41
  174. bloqade/squin/op/_dialect.py +0 -3
  175. bloqade/squin/op/_wrapper.py +0 -121
  176. bloqade/squin/op/number.py +0 -5
  177. bloqade/squin/op/rewrite.py +0 -46
  178. bloqade/squin/op/stdlib.py +0 -62
  179. bloqade/squin/op/stmts.py +0 -276
  180. bloqade/squin/op/traits.py +0 -43
  181. bloqade/squin/op/types.py +0 -26
  182. bloqade/squin/qubit.py +0 -184
  183. bloqade/squin/rewrite/canonicalize.py +0 -60
  184. bloqade/squin/rewrite/desugar.py +0 -124
  185. bloqade/squin/types.py +0 -8
  186. bloqade/squin/wire.py +0 -201
  187. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  188. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  189. bloqade_circuit-0.6.4.dist-info/RECORD +0 -234
  190. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  191. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,440 +0,0 @@
1
- import math
2
- from typing import Any
3
- from dataclasses import field, dataclass
4
-
5
- import cirq
6
- from kirin import ir, types, lowering
7
- from kirin.rewrite import Walk, CFGCompactify
8
- from kirin.dialects import py, scf, ilist
9
-
10
- from .. import op, noise, qubit
11
-
12
# Union of every cirq object the ``Squin`` lowering may be asked to visit:
# whole circuits, moments, gates, qubits and operations.
CirqNode = cirq.Operation | cirq.Qid | cirq.Gate | cirq.Moment | cirq.Circuit

# Gates that ``visit_GateOperation`` lowers by calling ``cirq.decompose_once``
# on the operation instead of through a dedicated ``visit_*`` method.
DecomposeNode = (
    cirq.CSwapGate
    | cirq.PhasedXZGate
    | cirq.PhasedXPowGate
    | cirq.ISwapPowGate
    | cirq.SwapPowGate
)
21
-
22
-
23
- @dataclass
24
- class Squin(lowering.LoweringABC[CirqNode]):
25
- """Lower a cirq.Circuit object to a squin kernel"""
26
-
27
- circuit: cirq.Circuit
28
- qreg: ir.SSAValue = field(init=False)
29
- qreg_index: dict[cirq.Qid, int] = field(init=False, default_factory=dict)
30
- next_qreg_index: int = field(init=False, default=0)
31
-
32
- def __post_init__(self):
33
- # TODO: sort by cirq ordering
34
- qbits = sorted(self.circuit.all_qubits())
35
- self.qreg_index = {qid: idx for (idx, qid) in enumerate(qbits)}
36
-
37
- def lower_qubit_getindex(self, state: lowering.State[CirqNode], qid: cirq.Qid):
38
- index = self.qreg_index[qid]
39
- index_ssa = state.current_frame.push(py.Constant(index)).result
40
- qbit_getitem = state.current_frame.push(py.GetItem(self.qreg, index_ssa))
41
- return qbit_getitem.result
42
-
43
- def lower_qubit_getindices(
44
- self, state: lowering.State[CirqNode], qids: list[cirq.Qid]
45
- ):
46
- qbits_getitem = [self.lower_qubit_getindex(state, qid) for qid in qids]
47
- qbits_stmt = ilist.New(values=qbits_getitem, elem_type=qubit.QubitType)
48
- qbits_result = state.current_frame.get(qbits_stmt.name)
49
-
50
- if qbits_result is not None:
51
- return qbits_result
52
-
53
- state.current_frame.push(qbits_stmt)
54
- return qbits_stmt.result
55
-
56
- def run(
57
- self,
58
- stmt: CirqNode,
59
- *,
60
- source: str | None = None,
61
- globals: dict[str, Any] | None = None,
62
- file: str | None = None,
63
- lineno_offset: int = 0,
64
- col_offset: int = 0,
65
- compactify: bool = True,
66
- register_as_argument: bool = False,
67
- register_argument_name: str = "q",
68
- ) -> ir.Region:
69
-
70
- state = lowering.State(
71
- self,
72
- file=file,
73
- lineno_offset=lineno_offset,
74
- col_offset=col_offset,
75
- )
76
-
77
- with state.frame([stmt], globals=globals, finalize_next=False) as frame:
78
-
79
- # NOTE: need a register of qubits before lowering statements
80
- if register_as_argument:
81
- # NOTE: register as argument to the kernel; we have freedom of choice for the name here
82
- frame.curr_block.args.append_from(
83
- ilist.IListType[qubit.QubitType, types.Any],
84
- name=register_argument_name,
85
- )
86
- self.qreg = frame.curr_block.args[0]
87
- else:
88
- # NOTE: create a new register of appropriate size
89
- n_qubits = len(self.qreg_index)
90
- n = frame.push(py.Constant(n_qubits))
91
- self.qreg = frame.push(qubit.New(n_qubits=n.result)).result
92
-
93
- self.visit(state, stmt)
94
-
95
- if compactify:
96
- Walk(CFGCompactify()).rewrite(frame.curr_region)
97
-
98
- region = frame.curr_region
99
-
100
- return region
101
-
102
- def visit(self, state: lowering.State[CirqNode], node: CirqNode) -> lowering.Result:
103
- name = node.__class__.__name__
104
- return getattr(self, f"visit_{name}", self.generic_visit)(state, node)
105
-
106
- def generic_visit(self, state: lowering.State[CirqNode], node: CirqNode):
107
- if isinstance(node, CirqNode):
108
- raise lowering.BuildError(
109
- f"Cannot lower {node.__class__.__name__} node: {node}"
110
- )
111
- raise lowering.BuildError(
112
- f"Unexpected `{node.__class__.__name__}` node: {repr(node)} is not an AST node"
113
- )
114
-
115
- def lower_literal(self, state: lowering.State[CirqNode], value) -> ir.SSAValue:
116
- raise lowering.BuildError("Literals not supported in cirq circuit")
117
-
118
- def lower_global(
119
- self, state: lowering.State[CirqNode], node: CirqNode
120
- ) -> lowering.LoweringABC.Result:
121
- raise lowering.BuildError("Literals not supported in cirq circuit")
122
-
123
- def visit_Circuit(
124
- self, state: lowering.State[CirqNode], node: cirq.Circuit
125
- ) -> lowering.Result:
126
- for moment in node:
127
- state.lower(moment)
128
-
129
- def visit_Moment(
130
- self, state: lowering.State[CirqNode], node: cirq.Moment
131
- ) -> lowering.Result:
132
- for op_ in node.operations:
133
- state.lower(op_)
134
-
135
- def visit_GateOperation(
136
- self, state: lowering.State[CirqNode], node: cirq.GateOperation
137
- ):
138
- if isinstance(node.gate, cirq.MeasurementGate):
139
- # NOTE: special dispatch here, since measurement is a gate + a qubit in cirq,
140
- # but a single statement in squin
141
- return self.lower_measurement(state, node)
142
-
143
- if isinstance(node.gate, DecomposeNode):
144
- # NOTE: easier to decompose these, but for that we need the qubits too,
145
- # so we need to do this within this method
146
- for subnode in cirq.decompose_once(node):
147
- state.lower(subnode)
148
- return
149
-
150
- op_ = state.lower(node.gate).expect_one()
151
- qbits = self.lower_qubit_getindices(state, node.qubits)
152
- return state.current_frame.push(qubit.Apply(operator=op_, qubits=qbits))
153
-
154
- def lower_measurement(
155
- self, state: lowering.State[CirqNode], node: cirq.GateOperation
156
- ):
157
- if len(node.qubits) == 1:
158
- qbit = self.lower_qubit_getindex(state, node.qubits[0])
159
- stmt = state.current_frame.push(qubit.MeasureQubit(qbit))
160
- else:
161
- qbits = self.lower_qubit_getindices(state, node.qubits)
162
- stmt = state.current_frame.push(qubit.MeasureQubitList(qbits))
163
-
164
- key = node.gate.key
165
- if isinstance(key, cirq.MeasurementKey):
166
- key = key.name
167
-
168
- state.current_frame.defs[key] = stmt.result
169
- return stmt
170
-
171
- def visit_ClassicallyControlledOperation(
172
- self, state: lowering.State[CirqNode], node: cirq.ClassicallyControlledOperation
173
- ):
174
- conditions: list[ir.SSAValue] = []
175
- for outcome in node.classical_controls:
176
- key = outcome.key
177
- if isinstance(key, cirq.MeasurementKey):
178
- key = key.name
179
- measurement_outcome = state.current_frame.defs[key]
180
-
181
- if measurement_outcome.type.is_subseteq(ilist.IListType):
182
- # NOTE: there is currently no convenient ilist.any method, so we need to use foldl
183
- # with a simple function that just does an or
184
-
185
- def bool_op_or(x: bool, y: bool) -> bool:
186
- return x or y
187
-
188
- f_code = state.current_frame.push(
189
- lowering.Python(self.dialects).python_function(bool_op_or)
190
- )
191
- fn = ir.Method(
192
- mod=None,
193
- py_func=bool_op_or,
194
- sym_name="bool_op_or",
195
- arg_names=[],
196
- dialects=self.dialects,
197
- code=f_code,
198
- )
199
- f_const = state.current_frame.push(py.constant.Constant(fn))
200
- init_val = state.current_frame.push(py.Constant(False)).result
201
- condition = state.current_frame.push(
202
- ilist.Foldl(f_const.result, measurement_outcome, init=init_val)
203
- ).result
204
- else:
205
- condition = measurement_outcome
206
-
207
- conditions.append(condition)
208
-
209
- if len(conditions) == 1:
210
- condition = conditions[0]
211
- else:
212
- condition = state.current_frame.push(
213
- py.boolop.And(conditions[0], conditions[1])
214
- ).result
215
- for next_cond in conditions[2:]:
216
- condition = state.current_frame.push(
217
- py.boolop.And(condition, next_cond)
218
- ).result
219
-
220
- then_stmt = self.visit(state, node.without_classical_controls())
221
-
222
- assert isinstance(
223
- then_stmt, ir.Statement
224
- ), f"Expected operation of classically controlled node {node} to be lowered to a statement, got type {type(then_stmt)}. \
225
- Please report this issue!"
226
-
227
- # NOTE: remove stmt from parent block
228
- then_stmt.detach()
229
- then_body = ir.Block((then_stmt,))
230
-
231
- return state.current_frame.push(scf.IfElse(condition, then_body=then_body))
232
-
233
- def visit_SingleQubitPauliStringGateOperation(
234
- self,
235
- state: lowering.State[CirqNode],
236
- node: cirq.SingleQubitPauliStringGateOperation,
237
- ):
238
-
239
- match node.pauli:
240
- case cirq.X:
241
- op_ = op.stmts.X()
242
- case cirq.Y:
243
- op_ = op.stmts.Y()
244
- case cirq.Z:
245
- op_ = op.stmts.Z()
246
- case cirq.I:
247
- op_ = op.stmts.Identity(sites=1)
248
- case _:
249
- raise lowering.BuildError(f"Unexpected Pauli operation {node.pauli}")
250
-
251
- state.current_frame.push(op_)
252
- qargs = self.lower_qubit_getindices(state, [node.qubit])
253
- return state.current_frame.push(qubit.Apply(op_.result, qargs))
254
-
255
- def visit_HPowGate(self, state: lowering.State[CirqNode], node: cirq.HPowGate):
256
- if abs(node.exponent) == 1:
257
- return state.current_frame.push(op.stmts.H())
258
-
259
- # NOTE: decompose into products of paulis for arbitrary exponents according to _decompose_ method
260
- # can't use decompose directly since that method requires qubits to be passed in for some reason
261
- y_rhs = state.lower(cirq.YPowGate(exponent=0.25)).expect_one()
262
- x = state.lower(
263
- cirq.XPowGate(exponent=node.exponent, global_shift=node.global_shift)
264
- ).expect_one()
265
- y_lhs = state.lower(cirq.YPowGate(exponent=-0.25)).expect_one()
266
-
267
- # NOTE: reversed order since we're creating a mult stmt
268
- m_lhs = state.current_frame.push(op.stmts.Mult(y_lhs, x))
269
- return state.current_frame.push(op.stmts.Mult(m_lhs.result, y_rhs))
270
-
271
- def visit_XPowGate(self, state: lowering.State[CirqNode], node: cirq.XPowGate):
272
- if abs(node.exponent == 1):
273
- return state.current_frame.push(op.stmts.X())
274
-
275
- return self.visit(state, node.in_su2())
276
-
277
- def visit_YPowGate(self, state: lowering.State[CirqNode], node: cirq.YPowGate):
278
- if abs(node.exponent == 1):
279
- return state.current_frame.push(op.stmts.Y())
280
-
281
- return self.visit(state, node.in_su2())
282
-
283
- def visit_ZPowGate(self, state: lowering.State[CirqNode], node: cirq.ZPowGate):
284
- if node.exponent == 0.5:
285
- return state.current_frame.push(op.stmts.S())
286
-
287
- if node.exponent == 0.25:
288
- return state.current_frame.push(op.stmts.T())
289
-
290
- if abs(node.exponent == 1):
291
- return state.current_frame.push(op.stmts.Z())
292
-
293
- # NOTE: just for the Z gate, an arbitrary exponent is equivalent to the ShiftOp
294
- # up to a minus sign!
295
- t = -node.exponent
296
- theta = state.current_frame.push(py.Constant(math.pi * t))
297
- return state.current_frame.push(op.stmts.ShiftOp(theta=theta.result))
298
-
299
- def visit_Rx(self, state: lowering.State[CirqNode], node: cirq.Rx):
300
- x = state.current_frame.push(op.stmts.X())
301
- angle = state.current_frame.push(py.Constant(value=math.pi * node.exponent))
302
- return state.current_frame.push(op.stmts.Rot(axis=x.result, angle=angle.result))
303
-
304
- def visit_Ry(self, state: lowering.State[CirqNode], node: cirq.Ry):
305
- y = state.current_frame.push(op.stmts.Y())
306
- angle = state.current_frame.push(py.Constant(value=math.pi * node.exponent))
307
- return state.current_frame.push(op.stmts.Rot(axis=y.result, angle=angle.result))
308
-
309
- def visit_Rz(self, state: lowering.State[CirqNode], node: cirq.Rz):
310
- z = state.current_frame.push(op.stmts.Z())
311
- angle = state.current_frame.push(py.Constant(value=math.pi * node.exponent))
312
- return state.current_frame.push(op.stmts.Rot(axis=z.result, angle=angle.result))
313
-
314
- def visit_CXPowGate(self, state: lowering.State[CirqNode], node: cirq.CXPowGate):
315
- x = state.lower(cirq.XPowGate(exponent=node.exponent)).expect_one()
316
- return state.current_frame.push(op.stmts.Control(x, n_controls=1))
317
-
318
- def visit_CZPowGate(self, state: lowering.State[CirqNode], node: cirq.CZPowGate):
319
- z = state.lower(cirq.ZPowGate(exponent=node.exponent)).expect_one()
320
- return state.current_frame.push(op.stmts.Control(z, n_controls=1))
321
-
322
- def visit_ControlledOperation(
323
- self, state: lowering.State[CirqNode], node: cirq.ControlledOperation
324
- ):
325
- return self.visit_GateOperation(state, node)
326
-
327
- def visit_ControlledGate(
328
- self, state: lowering.State[CirqNode], node: cirq.ControlledGate
329
- ):
330
- op_ = state.lower(node.sub_gate).expect_one()
331
- n_controls = node.num_controls()
332
- return state.current_frame.push(op.stmts.Control(op_, n_controls=n_controls))
333
-
334
- def visit_XXPowGate(self, state: lowering.State[CirqNode], node: cirq.XXPowGate):
335
- x = state.lower(cirq.XPowGate(exponent=node.exponent)).expect_one()
336
- return state.current_frame.push(op.stmts.Kron(x, x))
337
-
338
- def visit_YYPowGate(self, state: lowering.State[CirqNode], node: cirq.YYPowGate):
339
- y = state.lower(cirq.YPowGate(exponent=node.exponent)).expect_one()
340
- return state.current_frame.push(op.stmts.Kron(y, y))
341
-
342
- def visit_ZZPowGate(self, state: lowering.State[CirqNode], node: cirq.ZZPowGate):
343
- z = state.lower(cirq.ZPowGate(exponent=node.exponent)).expect_one()
344
- return state.current_frame.push(op.stmts.Kron(z, z))
345
-
346
- def visit_CCXPowGate(self, state: lowering.State[CirqNode], node: cirq.CCXPowGate):
347
- x = state.lower(cirq.XPowGate(exponent=node.exponent)).expect_one()
348
- return state.current_frame.push(op.stmts.Control(x, n_controls=2))
349
-
350
- def visit_CCZPowGate(self, state: lowering.State[CirqNode], node: cirq.CCZPowGate):
351
- z = state.lower(cirq.ZPowGate(exponent=node.exponent)).expect_one()
352
- return state.current_frame.push(op.stmts.Control(z, n_controls=2))
353
-
354
- def visit_BitFlipChannel(
355
- self, state: lowering.State[CirqNode], node: cirq.BitFlipChannel
356
- ):
357
- x = state.current_frame.push(op.stmts.X())
358
- p = state.current_frame.push(py.Constant(node.p))
359
- return state.current_frame.push(
360
- noise.stmts.PauliError(basis=x.result, p=p.result)
361
- )
362
-
363
- def visit_AmplitudeDampingChannel(
364
- self, state: lowering.State[CirqNode], node: cirq.AmplitudeDampingChannel
365
- ):
366
- r = state.current_frame.push(op.stmts.Reset())
367
- p = state.current_frame.push(py.Constant(node.gamma))
368
-
369
- # TODO: do we need a dedicated noise stmt for this? Using PauliError
370
- # with this basis feels like a hack
371
- noise_channel = state.current_frame.push(
372
- noise.stmts.PauliError(basis=r.result, p=p.result)
373
- )
374
-
375
- return noise_channel
376
-
377
- def visit_GeneralizedAmplitudeDampingChannel(
378
- self,
379
- state: lowering.State[CirqNode],
380
- node: cirq.GeneralizedAmplitudeDampingChannel,
381
- ):
382
- p = state.current_frame.push(py.Constant(node.p)).result
383
- gamma = state.current_frame.push(py.Constant(node.gamma)).result
384
-
385
- # NOTE: cirq has a weird convention here: if p == 1, we have AmplitudeDampingChannel,
386
- # which basically means p is the probability of the environment being in the vacuum state
387
- prob0 = state.current_frame.push(py.binop.Mult(p, gamma)).result
388
- one_ = state.current_frame.push(py.Constant(1)).result
389
- p_minus_1 = state.current_frame.push(py.binop.Sub(one_, p)).result
390
- prob1 = state.current_frame.push(py.binop.Mult(p_minus_1, gamma)).result
391
-
392
- r0 = state.current_frame.push(op.stmts.Reset()).result
393
- r1 = state.current_frame.push(op.stmts.ResetToOne()).result
394
-
395
- probs = state.current_frame.push(ilist.New(values=(prob0, prob1))).result
396
- ops = state.current_frame.push(ilist.New(values=(r0, r1))).result
397
-
398
- noise_channel = state.current_frame.push(
399
- noise.stmts.StochasticUnitaryChannel(probabilities=probs, operators=ops)
400
- )
401
-
402
- return noise_channel
403
-
404
- def visit_DepolarizingChannel(
405
- self, state: lowering.State[CirqNode], node: cirq.DepolarizingChannel
406
- ):
407
- p = state.current_frame.push(py.Constant(node.p)).result
408
- return state.current_frame.push(noise.stmts.Depolarize(p))
409
-
410
- def visit_AsymmetricDepolarizingChannel(
411
- self, state: lowering.State[CirqNode], node: cirq.AsymmetricDepolarizingChannel
412
- ):
413
- nqubits = node.num_qubits()
414
- if nqubits > 2:
415
- raise lowering.BuildError(
416
- "AsymmetricDepolarizingChannel applied to more than 2 qubits is not supported!"
417
- )
418
-
419
- if nqubits == 1:
420
- p_x = state.current_frame.push(py.Constant(node.p_x)).result
421
- p_y = state.current_frame.push(py.Constant(node.p_y)).result
422
- p_z = state.current_frame.push(py.Constant(node.p_z)).result
423
- params = state.current_frame.push(ilist.New(values=(p_x, p_y, p_z))).result
424
- return state.current_frame.push(noise.stmts.SingleQubitPauliChannel(params))
425
-
426
- # NOTE: nqubits == 2
427
- error_probs = node.error_probabilities
428
- paulis = ("I", "X", "Y", "Z")
429
- values = []
430
- for p1 in paulis:
431
- for p2 in paulis:
432
- if p1 == p2 == "I":
433
- continue
434
-
435
- p = error_probs.get(p1 + p2, 0.0)
436
- p_ssa = state.current_frame.push(py.Constant(p)).result
437
- values.append(p_ssa)
438
-
439
- params = state.current_frame.push(ilist.New(values=values)).result
440
- return state.current_frame.push(noise.stmts.TwoQubitPauliChannel(params))
bloqade/squin/lowering.py DELETED
@@ -1,54 +0,0 @@
1
- import ast
2
- from dataclasses import dataclass
3
-
4
- from kirin import lowering
5
-
6
- from . import qubit
7
-
8
-
9
@dataclass(frozen=True)
class ApplyAnyCallLowering(lowering.FromPythonCall["qubit.ApplyAny"]):
    """
    Custom lowering for ApplyAny that collects vararg qubits into a single tuple argument.

    Accepted call shapes (the operator always comes first):

    * ``apply(op, q0, q1, ...)``
    * ``apply(op, qubits=[q0, q1, ...])`` or ``apply(op, qubits=(q0, q1, ...))``
    * ``apply(op, qubits=q)`` where ``q`` is any single expression
    * ``apply(operator=op, qubits=...)``
    """

    def lower(
        self, stmt: type["qubit.ApplyAny"], state: lowering.State, node: ast.Call
    ):
        """Lower the operator, then each qubit expression, and push ``stmt(op, qubits)``.

        Raises:
            lowering.BuildError: if fewer than two arguments are given or the
                argument shape is not one of the accepted forms.
        """
        if len(node.args) + len(node.keywords) < 2:
            raise lowering.BuildError(
                "Apply requires at least one operator and one qubit as arguments!"
            )

        op, qubits = self.unpack_arguments(node)

        op_ssa = state.lower(op).expect_one()
        qubits_lowered = tuple(state.lower(qbit).expect_one() for qbit in qubits)

        s = stmt(op_ssa, qubits_lowered)
        return state.current_frame.push(s)

    def unpack_arguments(self, node: ast.Call) -> tuple[ast.expr, list[ast.expr]]:
        """Split a call into its operator expression and list of qubit expressions.

        A list or tuple literal passed as ``qubits=`` is unpacked element-wise;
        any other expression is treated as a single qubit argument.

        Raises:
            lowering.BuildError: on unsupported keywords, a missing operator, or
                positional arguments mixed with ``operator=``.
        """
        if not node.keywords:
            if not node.args:
                raise lowering.BuildError("Missing operator argument")
            op, *qubits = node.args
            return op, qubits

        # NOTE: ``**kwargs`` splats show up with ``kw.arg is None`` and are rejected below.
        kwargs = {kw.arg: kw.value for kw in node.keywords}
        if len(kwargs) > 2 or "qubits" not in kwargs:
            raise lowering.BuildError(f"Got unsupported keyword argument {kwargs}")

        qubits = kwargs["qubits"]
        if len(kwargs) == 1:
            if not node.args:
                raise lowering.BuildError("Missing operator argument")
            if len(node.args) > 1:
                raise lowering.BuildError(
                    "Only the operator may be passed positionally when qubits are "
                    "given as a keyword argument"
                )
            op = node.args[0]
        else:
            if "operator" not in kwargs:
                raise lowering.BuildError(f"Got unsupported keyword argument {kwargs}")
            if node.args:
                # These would otherwise be silently dropped.
                raise lowering.BuildError(
                    "Positional arguments cannot be combined with the operator= "
                    "and qubits= keyword arguments"
                )
            op = kwargs["operator"]

        if isinstance(qubits, (ast.List, ast.Tuple)):
            return op, list(qubits.elts)

        return op, [qubits]
@@ -1,40 +0,0 @@
1
- from typing import Literal
2
-
3
- from kirin.dialects import ilist
4
- from kirin.lowering import wraps
5
-
6
- from bloqade.squin.op.types import Op
7
-
8
- from . import stmts
9
-
10
-
11
@wraps(stmts.PauliError)
def pauli_error(basis: Op, p: float) -> Op:
    """Wrapper declared via ``kirin.lowering.wraps`` for ``stmts.PauliError``.

    The Python body is an empty stub.

    Args:
        basis: Operator passed as the statement's ``basis``.
        p: Passed as the statement's ``p`` (presumably the error probability — confirm).
    """
    ...
13
-
14
-
15
@wraps(stmts.PPError)
def pp_error(op: Op, p: float) -> Op:
    """Wrapper declared via ``kirin.lowering.wraps`` for ``stmts.PPError``.

    The Python body is an empty stub.

    Args:
        op: Operator passed as the statement's ``op``.
        p: Passed as the statement's ``p`` (presumably the error probability — confirm).
    """
    ...
17
-
18
-
19
@wraps(stmts.Depolarize)
def depolarize(p: float) -> Op:
    """Wrapper declared via ``kirin.lowering.wraps`` for ``stmts.Depolarize``.

    The Python body is an empty stub.

    Args:
        p: Passed as the statement's ``p`` (presumably the depolarizing probability — confirm).
    """
    ...
21
-
22
-
23
@wraps(stmts.Depolarize2)
def depolarize2(p: float) -> Op:
    """Wrapper declared via ``kirin.lowering.wraps`` for ``stmts.Depolarize2``.

    The Python body is an empty stub.

    Args:
        p: Passed as the statement's ``p`` (presumably the depolarizing probability — confirm).
    """
    ...
25
-
26
-
27
@wraps(stmts.SingleQubitPauliChannel)
def single_qubit_pauli_channel(
    params: ilist.IList[float, Literal[3]] | list[float] | tuple[float, float, float],
) -> Op:
    """Wrapper declared via ``kirin.lowering.wraps`` for ``stmts.SingleQubitPauliChannel``.

    The Python body is an empty stub.

    Args:
        params: Three floats, per the annotation. Presumably these are the
            (X, Y, Z) error probabilities, in that order — confirm against
            the statement definition.
    """
    ...
31
-
32
-
33
@wraps(stmts.TwoQubitPauliChannel)
def two_qubit_pauli_channel(
    params: ilist.IList[float, Literal[15]] | list[float] | tuple[float, ...],
) -> Op:
    """Wrapper declared via ``kirin.lowering.wraps`` for ``stmts.TwoQubitPauliChannel``.

    The Python body is an empty stub.

    Args:
        params: Fifteen floats when given as an ``IList``, per the annotation.
            Presumably these are the probabilities of the non-identity
            two-qubit Pauli products — confirm ordering against the statement.
    """
    ...
37
-
38
-
39
@wraps(stmts.QubitLoss)
def qubit_loss(p: float) -> Op:
    """Wrapper declared via ``kirin.lowering.wraps`` for ``stmts.QubitLoss``.

    The Python body is an empty stub.

    Args:
        p: Passed as the statement's ``p`` (presumably the loss probability — confirm).
    """
    ...
@@ -1,111 +0,0 @@
1
- import itertools
2
-
3
- from kirin import ir
4
- from kirin.passes import Pass
5
- from kirin.rewrite import Walk
6
- from kirin.dialects import ilist
7
- from kirin.rewrite.abc import RewriteRule, RewriteResult
8
-
9
- from .stmts import (
10
- PPError,
11
- QubitLoss,
12
- Depolarize,
13
- PauliError,
14
- NoiseChannel,
15
- TwoQubitPauliChannel,
16
- SingleQubitPauliChannel,
17
- StochasticUnitaryChannel,
18
- )
19
- from ..op.stmts import X, Y, Z, Kron, Identity
20
-
21
-
22
class _RewriteNoiseStmts(RewriteRule):
    """Rewrites squin noise statements to StochasticUnitaryChannel.

    Handlers are looked up by statement name: a noise statement whose ``name``
    is ``foo`` is rewritten by ``rewrite_foo``. ``QubitLoss`` and any noise
    statement without a matching handler are left unchanged.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, NoiseChannel) or isinstance(node, QubitLoss):
            return RewriteResult()

        rewrite_fn = getattr(self, "rewrite_" + node.name, None)
        if rewrite_fn is None:
            # NOTE: no handler (e.g. Depolarize2); leave the statement alone
            # instead of failing the whole walk with an AttributeError.
            return RewriteResult()
        return rewrite_fn(node)

    def _insert_all(
        self, node: ir.Statement, stmts: tuple[ir.Statement, ...]
    ) -> list[ir.SSAValue]:
        """Insert ``stmts`` before ``node`` (in order) and return their results."""
        results: list[ir.SSAValue] = []
        for stmt in stmts:
            stmt.insert_before(node)
            results.append(stmt.result)
        return results

    def _replace_with_single_operator_channel(
        self, node: ir.Statement, operator: ir.SSAValue, p: ir.SSAValue
    ) -> RewriteResult:
        """Replace ``node`` by a channel applying ``operator`` with probability ``p``."""
        (operators := ilist.New(values=(operator,))).insert_before(node)
        (ps := ilist.New(values=(p,))).insert_before(node)
        stochastic_channel = StochasticUnitaryChannel(
            operators=operators.result, probabilities=ps.result
        )

        node.replace_by(stochastic_channel)
        return RewriteResult(has_done_something=True)

    def rewrite_pauli_error(self, node: PauliError) -> RewriteResult:
        return self._replace_with_single_operator_channel(node, node.basis, node.p)

    def rewrite_p_p_error(self, node: PPError) -> RewriteResult:
        return self._replace_with_single_operator_channel(node, node.op, node.p)

    def rewrite_single_qubit_pauli_channel(
        self, node: SingleQubitPauliChannel
    ) -> RewriteResult:
        """Use ``node.params`` directly as the probabilities of X, Y and Z."""
        paulis_ssa = self._insert_all(node, (X(), Y(), Z()))
        (pauli_ops := ilist.New(values=paulis_ssa)).insert_before(node)

        stochastic_unitary = StochasticUnitaryChannel(
            operators=pauli_ops.result, probabilities=node.params
        )
        node.replace_by(stochastic_unitary)
        return RewriteResult(has_done_something=True)

    def rewrite_two_qubit_pauli_channel(
        self, node: TwoQubitPauliChannel
    ) -> RewriteResult:
        """Pair ``node.params`` with the Kron products of (I, X, Y, Z)^2, skipping I⊗I."""
        paulis_ssa = self._insert_all(node, (Identity(sites=1), X(), Y(), Z()))

        # NOTE: collect list so we can skip the first entry, which will be two identities
        combinations = list(itertools.product(paulis_ssa, repeat=2))[1:]
        operators: list[ir.SSAValue] = []
        for pauli_1, pauli_2 in combinations:
            kron = Kron(pauli_1, pauli_2)
            kron.insert_before(node)
            operators.append(kron.result)

        (operator_list := ilist.New(values=operators)).insert_before(node)
        stochastic_unitary = StochasticUnitaryChannel(
            operators=operator_list.result, probabilities=node.params
        )

        node.replace_by(stochastic_unitary)
        return RewriteResult(has_done_something=True)

    def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
        """Rewrite Depolarize(p) to X, Y, Z each applied with probability p / 3.

        This matches the convention of ``cirq.DepolarizingChannel(p)``, which the
        cirq lowering maps directly onto ``Depolarize(p)``.
        """
        from kirin.dialects import py

        operators = self._insert_all(node, (X(), Y(), Z()))
        (operator_list := ilist.New(values=operators)).insert_before(node)

        # NOTE: the total probability p is split evenly among the three Paulis;
        # using p for each would apply an error with total probability 3p.
        (three := py.Constant(3.0)).insert_before(node)
        (p_each := py.binop.Div(node.p, three.result)).insert_before(node)
        (ps := ilist.New(values=[p_each.result] * 3)).insert_before(node)

        stochastic_unitary = StochasticUnitaryChannel(
            operators=operator_list.result, probabilities=ps.result
        )
        node.replace_by(stochastic_unitary)

        return RewriteResult(has_done_something=True)
107
-
108
-
109
class RewriteNoiseStmts(Pass):
    """Pass that walks a method's code and applies ``_RewriteNoiseStmts`` to every statement."""

    def unsafe_run(self, mt: ir.Method):
        walker = Walk(_RewriteNoiseStmts())
        return walker.rewrite(mt.code)
@@ -1,41 +0,0 @@
1
- from . import stmts as stmts, types as types, rewrite as rewrite
2
- from .stdlib import (
3
- ch as ch,
4
- cx as cx,
5
- cy as cy,
6
- cz as cz,
7
- rx as rx,
8
- ry as ry,
9
- rz as rz,
10
- cphase as cphase,
11
- )
12
- from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary
13
- from ._dialect import dialect as dialect
14
- from ._wrapper import (
15
- h as h,
16
- s as s,
17
- t as t,
18
- u as u,
19
- x as x,
20
- y as y,
21
- z as z,
22
- p0 as p0,
23
- p1 as p1,
24
- rot as rot,
25
- kron as kron,
26
- mult as mult,
27
- phase as phase,
28
- reset as reset,
29
- scale as scale,
30
- shift as shift,
31
- spin_n as spin_n,
32
- spin_p as spin_p,
33
- sqrt_x as sqrt_x,
34
- sqrt_y as sqrt_y,
35
- sqrt_z as sqrt_z,
36
- adjoint as adjoint,
37
- control as control,
38
- identity as identity,
39
- pauli_string as pauli_string,
40
- reset_to_one as reset_to_one,
41
- )
@@ -1,3 +0,0 @@
1
- from kirin import ir
2
-
3
# Kirin dialect object registered under the name "squin.op".
dialect = ir.Dialect("squin.op")