bloqade-circuit 0.7.13__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (136)
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +119 -29
  3. bloqade/analysis/address/impls.py +290 -87
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +2 -2
  6. bloqade/analysis/measure_id/impls.py +3 -27
  7. bloqade/cirq_utils/__init__.py +3 -1
  8. bloqade/cirq_utils/emit/__init__.py +3 -0
  9. bloqade/cirq_utils/emit/base.py +243 -0
  10. bloqade/cirq_utils/emit/gate.py +104 -0
  11. bloqade/cirq_utils/emit/noise.py +90 -0
  12. bloqade/cirq_utils/emit/qubit.py +35 -0
  13. bloqade/cirq_utils/lowering.py +664 -0
  14. bloqade/native/__init__.py +0 -1
  15. bloqade/native/_prelude.py +3 -3
  16. bloqade/native/dialects/gate/__init__.py +2 -0
  17. bloqade/native/dialects/gate/_dialect.py +3 -0
  18. bloqade/native/dialects/{gates → gate}/_interface.py +5 -5
  19. bloqade/native/dialects/{gates → gate}/stmts.py +5 -5
  20. bloqade/native/stdlib/broadcast.py +19 -19
  21. bloqade/native/stdlib/simple.py +14 -13
  22. bloqade/native/upstream/__init__.py +5 -0
  23. bloqade/native/upstream/squin2native.py +136 -0
  24. bloqade/pyqrack/__init__.py +1 -2
  25. bloqade/pyqrack/device.py +6 -17
  26. bloqade/pyqrack/native.py +17 -17
  27. bloqade/pyqrack/reg.py +1 -6
  28. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  29. bloqade/pyqrack/squin/gate/gate.py +136 -0
  30. bloqade/pyqrack/squin/noise/native.py +120 -54
  31. bloqade/pyqrack/squin/qubit.py +25 -41
  32. bloqade/pyqrack/target.py +2 -2
  33. bloqade/qasm2/dialects/core/address.py +21 -12
  34. bloqade/qasm2/dialects/noise/fidelity.py +2 -6
  35. bloqade/qasm2/dialects/noise/model.py +2 -1
  36. bloqade/qasm2/passes/parallel.py +3 -1
  37. bloqade/qasm2/rewrite/__init__.py +0 -1
  38. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  39. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  40. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  41. bloqade/qubit/__init__.py +12 -0
  42. bloqade/qubit/_dialect.py +3 -0
  43. bloqade/qubit/_interface.py +49 -0
  44. bloqade/qubit/_prelude.py +45 -0
  45. bloqade/qubit/analysis/__init__.py +1 -0
  46. bloqade/qubit/analysis/address_impl.py +40 -0
  47. bloqade/qubit/stdlib/__init__.py +2 -0
  48. bloqade/qubit/stdlib/_new.py +34 -0
  49. bloqade/qubit/stdlib/broadcast.py +62 -0
  50. bloqade/qubit/stdlib/simple.py +59 -0
  51. bloqade/qubit/stmts.py +60 -0
  52. bloqade/rewrite/passes/aggressive_unroll.py +2 -1
  53. bloqade/squin/__init__.py +44 -17
  54. bloqade/squin/analysis/__init__.py +0 -1
  55. bloqade/squin/analysis/schedule.py +2 -2
  56. bloqade/squin/gate/__init__.py +2 -0
  57. bloqade/squin/gate/_dialect.py +3 -0
  58. bloqade/squin/gate/_interface.py +98 -0
  59. bloqade/squin/gate/stmts.py +119 -0
  60. bloqade/squin/groups.py +4 -21
  61. bloqade/squin/noise/__init__.py +1 -9
  62. bloqade/squin/noise/_dialect.py +1 -1
  63. bloqade/squin/noise/_interface.py +45 -0
  64. bloqade/squin/noise/stmts.py +65 -29
  65. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  66. bloqade/squin/rewrite/__init__.py +0 -2
  67. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  68. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  69. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  70. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  71. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  72. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  73. bloqade/squin/stdlib/simple/__init__.py +33 -0
  74. bloqade/squin/stdlib/simple/gate.py +242 -0
  75. bloqade/squin/stdlib/simple/noise.py +126 -0
  76. bloqade/stim/__init__.py +1 -0
  77. bloqade/stim/_wrappers.py +6 -0
  78. bloqade/stim/dialects/noise/emit.py +6 -1
  79. bloqade/stim/dialects/noise/stmts.py +5 -3
  80. bloqade/stim/emit/stim_str.py +2 -0
  81. bloqade/stim/parse/lowering.py +12 -17
  82. bloqade/stim/passes/__init__.py +0 -1
  83. bloqade/stim/passes/flatten.py +26 -0
  84. bloqade/stim/passes/simplify_ifs.py +6 -1
  85. bloqade/stim/passes/squin_to_stim.py +4 -70
  86. bloqade/stim/rewrite/__init__.py +0 -4
  87. bloqade/stim/rewrite/ifs_to_stim.py +23 -29
  88. bloqade/stim/rewrite/qubit_to_stim.py +96 -51
  89. bloqade/stim/rewrite/squin_measure.py +9 -18
  90. bloqade/stim/rewrite/squin_noise.py +132 -108
  91. bloqade/stim/rewrite/util.py +5 -204
  92. bloqade/types.py +10 -0
  93. {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/METADATA +2 -2
  94. {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/RECORD +96 -100
  95. bloqade/native/dialects/gates/__init__.py +0 -3
  96. bloqade/native/dialects/gates/_dialect.py +0 -3
  97. bloqade/pyqrack/squin/op.py +0 -180
  98. bloqade/pyqrack/squin/runtime.py +0 -543
  99. bloqade/pyqrack/squin/wire.py +0 -51
  100. bloqade/squin/_typeinfer.py +0 -20
  101. bloqade/squin/analysis/address_impl.py +0 -71
  102. bloqade/squin/analysis/nsites/__init__.py +0 -9
  103. bloqade/squin/analysis/nsites/analysis.py +0 -50
  104. bloqade/squin/analysis/nsites/impls.py +0 -99
  105. bloqade/squin/analysis/nsites/lattice.py +0 -49
  106. bloqade/squin/cirq/__init__.py +0 -306
  107. bloqade/squin/cirq/emit/emit_circuit.py +0 -129
  108. bloqade/squin/cirq/emit/noise.py +0 -49
  109. bloqade/squin/cirq/emit/op.py +0 -176
  110. bloqade/squin/cirq/emit/qubit.py +0 -58
  111. bloqade/squin/cirq/emit/runtime.py +0 -242
  112. bloqade/squin/cirq/lowering.py +0 -439
  113. bloqade/squin/lowering.py +0 -80
  114. bloqade/squin/noise/_wrapper.py +0 -36
  115. bloqade/squin/noise/rewrite.py +0 -129
  116. bloqade/squin/op/__init__.py +0 -41
  117. bloqade/squin/op/_dialect.py +0 -3
  118. bloqade/squin/op/_wrapper.py +0 -121
  119. bloqade/squin/op/number.py +0 -5
  120. bloqade/squin/op/rewrite.py +0 -46
  121. bloqade/squin/op/stdlib.py +0 -62
  122. bloqade/squin/op/stmts.py +0 -300
  123. bloqade/squin/op/traits.py +0 -43
  124. bloqade/squin/op/types.py +0 -128
  125. bloqade/squin/parallel.py +0 -200
  126. bloqade/squin/qubit.py +0 -194
  127. bloqade/squin/rewrite/canonicalize.py +0 -60
  128. bloqade/squin/rewrite/desugar.py +0 -102
  129. bloqade/squin/stdlib/channel.py +0 -86
  130. bloqade/squin/stdlib/gate.py +0 -201
  131. bloqade/squin/types.py +0 -8
  132. bloqade/squin/wire.py +0 -201
  133. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  134. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  135. {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/WHEEL +0 -0
  136. {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,664 @@
1
+ from typing import Any
2
+ from dataclasses import field, dataclass
3
+
4
+ import cirq
5
+ from kirin import ir, types, lowering
6
+ from kirin.rewrite import Walk, CFGCompactify
7
+ from kirin.dialects import py, scf, func, ilist
8
+
9
+ from bloqade import qubit
10
+ from bloqade.squin import gate, noise, kernel, qalloc
11
+
12
+
13
def load_circuit(
    circuit: cirq.Circuit,
    kernel_name: str = "main",
    dialects: ir.DialectGroup = kernel,
    register_as_argument: bool = False,
    return_register: bool = False,
    register_argument_name: str = "q",
    globals: dict[str, Any] | None = None,
    file: str | None = None,
    lineno_offset: int = 0,
    col_offset: int = 0,
    compactify: bool = True,
) -> ir.Method:
    """Converts a cirq.Circuit object into a squin kernel.

    Args:
        circuit (cirq.Circuit): The circuit to load.

    Keyword Args:
        kernel_name (str): The name of the kernel to load. Defaults to "main".
        dialects (ir.DialectGroup): The dialects to use. Defaults to `squin.kernel`.
            The group's own `run_pass` is used to finalize the kernel; groups that
            do not define one fall back to the `squin.kernel` pass.
        register_as_argument (bool): Determine whether the resulting kernel function should accept
            a single `ilist.IList[Qubit, Any]` argument that is a list of qubits used within the
            function. This allows you to compose kernel functions generated from circuits.
            Defaults to `False`.
        return_register (bool): Determine whether the resulting kernel function returns a
            single value of type `ilist.IList[Qubit, Any]` that is the list of qubits used
            in the kernel function. Useful when you want to compose multiple kernel functions
            generated from circuits. Defaults to `False`.
        register_argument_name (str): The name of the argument that represents the qubit register.
            Only used when `register_as_argument=True`. Defaults to "q".
        globals (dict[str, Any] | None): The global variables to use. Defaults to None.
        file (str | None): The file name for error reporting. Defaults to None.
        lineno_offset (int): The line number offset for error reporting. Defaults to 0.
        col_offset (int): The column number offset for error reporting. Defaults to 0.
        compactify (bool): Whether to compactify the output. Defaults to True.

    Returns:
        ir.Method: The lowered kernel. The dialect group's pass pipeline (with type
            inference) has already been run on it.

    ## Usage Examples:

    ```python
    # from cirq's "hello qubit" example
    import cirq
    from bloqade.cirq_utils import load_circuit

    # Pick a qubit.
    qubit = cirq.GridQubit(0, 0)

    # Create a circuit.
    circuit = cirq.Circuit(
        cirq.X(qubit)**0.5,  # Square root of NOT.
        cirq.measure(qubit, key='m')  # Measurement.
    )

    # load the circuit as squin
    main = load_circuit(circuit)

    # print the resulting IR
    main.print()
    ```

    You can also compose kernel functions generated from circuits by passing in
    and / or returning the respective quantum registers:

    ```python
    import cirq
    from bloqade.cirq_utils import load_circuit
    from bloqade import squin

    q = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(cirq.H(q[0]), cirq.CX(*q))

    get_entangled_qubits = load_circuit(
        circuit, return_register=True, kernel_name="get_entangled_qubits"
    )
    get_entangled_qubits.print()

    entangle_qubits = load_circuit(
        circuit, register_as_argument=True, kernel_name="entangle_qubits"
    )

    @squin.kernel
    def main():
        qreg = get_entangled_qubits()
        qreg2 = squin.qalloc(1)
        entangle_qubits([qreg[1], qreg2[0]])
        return squin.qubit.measure(qreg2)
    ```
    """

    target = Squin(dialects=dialects, circuit=circuit)
    body = target.run(
        circuit,
        source=str(circuit),  # TODO: proper source string
        file=file,
        globals=globals,
        lineno_offset=lineno_offset,
        col_offset=col_offset,
        compactify=compactify,
        register_as_argument=register_as_argument,
        register_argument_name=register_argument_name,
    )
    entry_block = body.blocks[0]

    # The kernel returns either the qubit register or an explicit `None`.
    if return_register:
        return_value = target.qreg
    else:
        return_value = func.ConstantNone()
        entry_block.stmts.append(return_value)

    return_node = func.Return(value_or_stmt=return_value)
    entry_block.stmts.append(return_node)

    self_arg_name = kernel_name + "_self"
    arg_names = [self_arg_name]
    if register_as_argument:
        args = (target.qreg.type,)
        arg_names.append(register_argument_name)
    else:
        args = ()

    # NOTE: add _self as argument; need to know signature before so do it after lowering
    signature = func.Signature(args, return_node.value.type)
    entry_block.args.insert_from(
        0,
        types.Generic(ir.Method, types.Tuple.where(signature.inputs), signature.output),
        self_arg_name,
    )

    code = func.Function(
        sym_name=kernel_name,
        signature=signature,
        body=body,
    )

    mt = ir.Method(
        mod=None,
        py_func=None,
        sym_name=kernel_name,
        arg_names=arg_names,
        dialects=dialects,
        code=code,
    )

    # Finalize with the pass pipeline of the dialect group the method was built
    # with, so a custom `dialects` group is not processed by squin's pipeline.
    # Fall back to the squin kernel pipeline if the group defines none.
    run_pass = dialects.run_pass if dialects.run_pass is not None else kernel.run_pass
    assert run_pass is not None
    run_pass(mt, typeinfer=True)

    return mt
159
+
160
+
161
# Union of the cirq object types the `Squin` lowering works with. It is used as
# the node annotation of the visitor methods and, in `generic_visit`, to tell
# unsupported cirq nodes apart from arbitrary objects in error messages.
CirqNode = (
    cirq.Operation
    | cirq.Qid
    | cirq.Gate
    | cirq.Moment
    | cirq.FrozenCircuit
    | cirq.Circuit
)
169
+
170
# Gate types that are not lowered directly. `Squin.visit_GateOperation` expands
# operations on these gates with `cirq.decompose_once` and lowers the pieces.
DecomposeNode = (
    cirq.CCZPowGate
    | cirq.CCXPowGate
    | cirq.YYPowGate
    | cirq.XXPowGate
    | cirq.CSwapGate
    | cirq.PhasedXZGate
    | cirq.PhasedXPowGate
    | cirq.ISwapPowGate
    | cirq.SwapPowGate
)
181
+
182
+
183
+ @dataclass
184
+ class Squin(lowering.LoweringABC[cirq.Circuit]):
185
+ """Lower a cirq.Circuit object to a squin kernel"""
186
+
187
+ circuit: cirq.Circuit
188
+ qreg: ir.SSAValue = field(init=False)
189
+ qreg_index: dict[cirq.Qid, int] = field(init=False, default_factory=dict)
190
+ next_qreg_index: int = field(init=False, default=0)
191
+
192
+ two_qubit_paulis = (
193
+ "IX",
194
+ "IY",
195
+ "IZ",
196
+ "XI",
197
+ "XX",
198
+ "XY",
199
+ "XZ",
200
+ "YI",
201
+ "YX",
202
+ "YY",
203
+ "YZ",
204
+ "ZI",
205
+ "ZX",
206
+ "ZY",
207
+ "ZZ",
208
+ )
209
+
210
+ def __post_init__(self):
211
+ # TODO: sort by cirq ordering
212
+ qbits = sorted(self.circuit.all_qubits())
213
+ self.qreg_index = {qid: idx for (idx, qid) in enumerate(qbits)}
214
+
215
+ def lower_qubit_getindex(self, state: lowering.State[cirq.Circuit], qid: cirq.Qid):
216
+ index = self.qreg_index[qid]
217
+ index_ssa = state.current_frame.push(py.Constant(index)).result
218
+ qbit_getitem = state.current_frame.push(py.GetItem(self.qreg, index_ssa))
219
+ return qbit_getitem.result
220
+
221
+ def lower_qubit_getindices(
222
+ self, state: lowering.State[cirq.Circuit], qids: tuple[cirq.Qid, ...]
223
+ ):
224
+ qbits_getitem = [self.lower_qubit_getindex(state, qid) for qid in qids]
225
+ qbits = state.current_frame.push(ilist.New(values=qbits_getitem))
226
+ return qbits.result
227
+
228
+ def run(
229
+ self,
230
+ stmt: cirq.Circuit,
231
+ *,
232
+ source: str | None = None,
233
+ globals: dict[str, Any] | None = None,
234
+ file: str | None = None,
235
+ lineno_offset: int = 0,
236
+ col_offset: int = 0,
237
+ compactify: bool = True,
238
+ register_as_argument: bool = False,
239
+ register_argument_name: str = "q",
240
+ ) -> ir.Region:
241
+
242
+ state = lowering.State(
243
+ self,
244
+ file=file,
245
+ lineno_offset=lineno_offset,
246
+ col_offset=col_offset,
247
+ )
248
+
249
+ with state.frame([stmt], globals=globals, finalize_next=False) as frame:
250
+
251
+ # NOTE: need a register of qubits before lowering statements
252
+ if register_as_argument:
253
+ # NOTE: register as argument to the kernel; we have freedom of choice for the name here
254
+ frame.curr_block.args.append_from(
255
+ ilist.IListType[qubit.QubitType, types.Any],
256
+ name=register_argument_name,
257
+ )
258
+ self.qreg = frame.curr_block.args[0]
259
+ else:
260
+ # NOTE: create a new register of appropriate size
261
+ n_qubits = len(self.qreg_index)
262
+ n = frame.push(py.Constant(n_qubits))
263
+ self.qreg = frame.push(
264
+ func.Invoke((n.result,), callee=qalloc, kwargs=())
265
+ ).result
266
+
267
+ self.visit(state, stmt)
268
+
269
+ if compactify:
270
+ Walk(CFGCompactify()).rewrite(frame.curr_region)
271
+
272
+ region = frame.curr_region
273
+
274
+ return region
275
+
276
+ def visit(
277
+ self, state: lowering.State[cirq.Circuit], node: CirqNode
278
+ ) -> lowering.Result:
279
+ name = node.__class__.__name__
280
+ return getattr(self, f"visit_{name}", self.generic_visit)(state, node)
281
+
282
+ def generic_visit(self, state: lowering.State[cirq.Circuit], node: CirqNode):
283
+ if isinstance(node, CirqNode):
284
+ raise lowering.BuildError(
285
+ f"Cannot lower {node.__class__.__name__} node: {node}"
286
+ )
287
+ raise lowering.BuildError(f"Cannot lower {node}")
288
+
289
+ # return self.visit_Operation(state, node)
290
+
291
+ def lower_literal(self, state: lowering.State[cirq.Circuit], value) -> ir.SSAValue:
292
+ raise lowering.BuildError("Literals not supported in cirq circuit")
293
+
294
+ def lower_global(
295
+ self, state: lowering.State[cirq.Circuit], node: CirqNode
296
+ ) -> lowering.LoweringABC.Result:
297
+ raise lowering.BuildError("Literals not supported in cirq circuit")
298
+
299
+ def visit_Circuit(
300
+ self,
301
+ state: lowering.State[cirq.Circuit],
302
+ node: cirq.Circuit | cirq.FrozenCircuit,
303
+ ) -> lowering.Result:
304
+ for moment in node:
305
+ self.visit_Moment(state, moment)
306
+
307
+ def visit_Moment(
308
+ self, state: lowering.State[cirq.Circuit], node: cirq.Moment
309
+ ) -> lowering.Result:
310
+ for op_ in node.operations:
311
+ self.visit(state, op_)
312
+
313
+ def visit_GateOperation(
314
+ self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation
315
+ ):
316
+ if isinstance(node.gate, DecomposeNode):
317
+ # NOTE: easier to decompose these, but for that we need the qubits too,
318
+ # so we need to do this within this method
319
+ for subnode in cirq.decompose_once(node):
320
+ self.visit(state, subnode)
321
+ return
322
+
323
+ # NOTE: just forward to the appropriate method by getting the name
324
+ name = node.gate.__class__.__name__
325
+ return getattr(self, f"visit_{name}", self.generic_visit)(state, node)
326
+
327
+ def visit_TaggedOperation(
328
+ self, state: lowering.State[cirq.Circuit], node: cirq.TaggedOperation
329
+ ):
330
+ return self.visit(state, node.untagged)
331
+
332
+ def visit_ClassicallyControlledOperation(
333
+ self,
334
+ state: lowering.State[cirq.Circuit],
335
+ node: cirq.ClassicallyControlledOperation,
336
+ ):
337
+ conditions: list[ir.SSAValue] = []
338
+ for outcome in node.classical_controls:
339
+ key = outcome.key
340
+ if isinstance(key, cirq.MeasurementKey):
341
+ key = key.name
342
+ measurement_outcome = state.current_frame.defs[key]
343
+
344
+ if measurement_outcome.type.is_subseteq(ilist.IListType):
345
+ # NOTE: there is currently no convenient ilist.any method, so we need to use foldl
346
+ # with a simple function that just does an or
347
+
348
+ def bool_op_or(x: bool, y: bool) -> bool:
349
+ return x or y
350
+
351
+ f_code = state.current_frame.push(
352
+ lowering.Python(self.dialects).python_function(bool_op_or)
353
+ )
354
+ fn = ir.Method(
355
+ mod=None,
356
+ py_func=bool_op_or,
357
+ sym_name="bool_op_or",
358
+ arg_names=[],
359
+ dialects=self.dialects,
360
+ code=f_code,
361
+ )
362
+ f_const = state.current_frame.push(py.constant.Constant(fn))
363
+ init_val = state.current_frame.push(py.Constant(False)).result
364
+ condition = state.current_frame.push(
365
+ ilist.Foldl(f_const.result, measurement_outcome, init=init_val)
366
+ ).result
367
+ else:
368
+ condition = measurement_outcome
369
+
370
+ conditions.append(condition)
371
+
372
+ if len(conditions) == 1:
373
+ condition = conditions[0]
374
+ else:
375
+ condition = state.current_frame.push(
376
+ py.boolop.And(conditions[0], conditions[1])
377
+ ).result
378
+ for next_cond in conditions[2:]:
379
+ condition = state.current_frame.push(
380
+ py.boolop.And(condition, next_cond)
381
+ ).result
382
+
383
+ then_stmt = self.visit(state, node.without_classical_controls())
384
+
385
+ assert isinstance(
386
+ then_stmt, ir.Statement
387
+ ), f"Expected operation of classically controlled node {node} to be lowered to a statement, got type {type(then_stmt)}. \
388
+ Please report this issue!"
389
+
390
+ # NOTE: remove stmt from parent block
391
+ then_stmt.detach()
392
+ then_body = ir.Block((then_stmt,))
393
+ then_body.args.append_from(types.Bool, name="cond")
394
+ then_body.stmts.append(scf.Yield())
395
+
396
+ else_body = ir.Block(())
397
+ else_body.args.append_from(types.Bool, name="cond")
398
+ else_body.stmts.append(scf.Yield())
399
+
400
+ return state.current_frame.push(
401
+ scf.IfElse(condition, then_body=then_body, else_body=else_body)
402
+ )
403
+
404
+ def visit_MeasurementGate(
405
+ self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation
406
+ ):
407
+ qubits = self.lower_qubit_getindices(state, node.qubits)
408
+ stmt = state.current_frame.push(qubit.stmts.Measure(qubits))
409
+
410
+ # NOTE: add for classically controlled lowering
411
+ key = node.gate.key
412
+ if isinstance(key, cirq.MeasurementKey):
413
+ key = key.name
414
+ state.current_frame.defs[key] = stmt.result
415
+
416
+ return stmt
417
+
418
+ def visit_SingleQubitPauliStringGateOperation(
419
+ self,
420
+ state: lowering.State[cirq.Circuit],
421
+ node: cirq.SingleQubitPauliStringGateOperation,
422
+ ):
423
+ if isinstance(node.pauli, cirq.IdentityGate):
424
+ # TODO: do we need an identity gate in gate?
425
+ return
426
+
427
+ qargs = self.lower_qubit_getindices(state, (node.qubit,))
428
+ match node.pauli:
429
+ case cirq.X:
430
+ gate_stmt = gate.stmts.X
431
+ case cirq.Y:
432
+ gate_stmt = gate.stmts.Y
433
+ case cirq.Z:
434
+ gate_stmt = gate.stmts.Z
435
+ case _:
436
+ raise lowering.BuildError(f"Unexpected Pauli operation {node.pauli}")
437
+
438
+ return state.current_frame.push(gate_stmt(qargs))
439
+
440
+ def visit_HPowGate(self, state: lowering.State[cirq.Circuit], node: cirq.HPowGate):
441
+ qargs = self.lower_qubit_getindices(state, node.qubits)
442
+
443
+ if node.gate.exponent % 2 == 1:
444
+ return state.current_frame.push(gate.stmts.H(qargs))
445
+
446
+ # NOTE: decompose into products of paulis for arbitrary exponents according to _decompose_ method
447
+ for subnode in cirq.decompose_once(node):
448
+ self.visit(state, subnode)
449
+
450
+ def visit_XPowGate(
451
+ self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation
452
+ ):
453
+ qargs = self.lower_qubit_getindices(state, node.qubits)
454
+ if node.gate.exponent % 2 == 1:
455
+ return state.current_frame.push(gate.stmts.X(qargs))
456
+
457
+ angle = state.current_frame.push(py.Constant(0.5 * node.gate.exponent))
458
+ return state.current_frame.push(gate.stmts.Rx(angle.result, qargs))
459
+
460
+ def visit_YPowGate(
461
+ self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation
462
+ ):
463
+ qargs = self.lower_qubit_getindices(state, node.qubits)
464
+ if node.gate.exponent % 2 == 1:
465
+ return state.current_frame.push(gate.stmts.Y(qargs))
466
+
467
+ angle = state.current_frame.push(py.Constant(0.5 * node.gate.exponent))
468
+ return state.current_frame.push(gate.stmts.Ry(angle.result, qargs))
469
+
470
+ def visit_ZPowGate(
471
+ self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation
472
+ ):
473
+ qargs = self.lower_qubit_getindices(state, node.qubits)
474
+
475
+ if abs(node.gate.exponent) == 0.5:
476
+ adjoint = node.gate.exponent < 0
477
+ return state.current_frame.push(gate.stmts.S(adjoint=adjoint, qubits=qargs))
478
+
479
+ if abs(node.gate.exponent) == 0.25:
480
+ adjoint = node.gate.exponent < 0
481
+ return state.current_frame.push(gate.stmts.T(adjoint=adjoint, qubits=qargs))
482
+
483
+ if node.gate.exponent % 2 == 1:
484
+ return state.current_frame.push(gate.stmts.Z(qubits=qargs))
485
+
486
+ angle = state.current_frame.push(py.Constant(0.5 * node.gate.exponent))
487
+ return state.current_frame.push(gate.stmts.Rz(angle.result, qargs))
488
+
489
+ def visit_Rx(self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation):
490
+ qargs = self.lower_qubit_getindices(state, node.qubits)
491
+ angle = state.current_frame.push(py.Constant(value=0.5 * node.gate.exponent))
492
+ return state.current_frame.push(gate.stmts.Rx(angle.result, qargs))
493
+
494
+ def visit_Ry(self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation):
495
+ qargs = self.lower_qubit_getindices(state, node.qubits)
496
+ angle = state.current_frame.push(py.Constant(value=0.5 * node.gate.exponent))
497
+ return state.current_frame.push(gate.stmts.Ry(angle.result, qargs))
498
+
499
+ def visit_Rz(self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation):
500
+ qargs = self.lower_qubit_getindices(state, node.qubits)
501
+ angle = state.current_frame.push(py.Constant(value=0.5 * node.gate.exponent))
502
+ return state.current_frame.push(gate.stmts.Rz(angle.result, qargs))
503
+
504
+ def visit_CXPowGate(
505
+ self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation
506
+ ):
507
+ if node.gate.exponent % 2 == 0:
508
+ return
509
+
510
+ if node.gate.exponent % 2 != 1:
511
+ raise lowering.BuildError("Exponents of CX gate are not supported!")
512
+
513
+ control, target = node.qubits
514
+ control_qarg = self.lower_qubit_getindices(state, (control,))
515
+ target_qarg = self.lower_qubit_getindices(state, (target,))
516
+ return state.current_frame.push(
517
+ gate.stmts.CX(controls=control_qarg, targets=target_qarg)
518
+ )
519
+
520
+ def visit_CZPowGate(
521
+ self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation
522
+ ):
523
+ if node.gate.exponent % 2 == 0:
524
+ return
525
+
526
+ if node.gate.exponent % 2 != 1:
527
+ raise lowering.BuildError("Exponents of CZ gate are not supported!")
528
+
529
+ control, target = node.qubits
530
+ control_qarg = self.lower_qubit_getindices(state, (control,))
531
+ target_qarg = self.lower_qubit_getindices(state, (target,))
532
+ return state.current_frame.push(
533
+ gate.stmts.CZ(controls=control_qarg, targets=target_qarg)
534
+ )
535
+
536
+ def visit_ZZPowGate(
537
+ self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation
538
+ ):
539
+ if node.gate.exponent % 2 == 0:
540
+ return
541
+
542
+ qubit1, qubit2 = node.qubits
543
+ qarg1 = self.lower_qubit_getindices(state, (qubit1,))
544
+ qarg2 = self.lower_qubit_getindices(state, (qubit2,))
545
+
546
+ if node.gate.exponent % 2 == 1:
547
+ state.current_frame.push(gate.stmts.X(qarg1))
548
+ state.current_frame.push(gate.stmts.X(qarg2))
549
+ return
550
+
551
+ # NOTE: arbitrary exponent, write as CX * Rz * CX (up to global phase)
552
+ state.current_frame.push(gate.stmts.CX(qarg1, qarg2))
553
+ angle = state.current_frame.push(py.Constant(0.5 * node.gate.exponent))
554
+ state.current_frame.push(gate.stmts.Rz(angle.result, qarg2))
555
+ state.current_frame.push(gate.stmts.CX(qarg1, qarg2))
556
+
557
+ def visit_ControlledOperation(
558
+ self, state: lowering.State[cirq.Circuit], node: cirq.ControlledOperation
559
+ ):
560
+ match node.gate.sub_gate:
561
+ case cirq.X:
562
+ stmt = gate.stmts.CX
563
+ case cirq.Y:
564
+ stmt = gate.stmts.CY
565
+ case cirq.Z:
566
+ stmt = gate.stmts.CZ
567
+ case _:
568
+ raise lowering.BuildError(
569
+ f"Cannot lowering controlled operation: {node}"
570
+ )
571
+
572
+ control, target = node.qubits
573
+ control_qarg = self.lower_qubit_getindices(state, (control,))
574
+ target_qarg = self.lower_qubit_getindices(state, (target,))
575
+ return state.current_frame.push(stmt(control_qarg, target_qarg))
576
+
577
+ def visit_FrozenCircuit(
578
+ self, state: lowering.State[cirq.Circuit], node: cirq.FrozenCircuit
579
+ ):
580
+ return self.visit_Circuit(state, node)
581
+
582
+ def visit_CircuitOperation(
583
+ self, state: lowering.State[cirq.Circuit], node: cirq.CircuitOperation
584
+ ):
585
+ reps = node.repetitions
586
+
587
+ if not isinstance(reps, int):
588
+ raise lowering.BuildError(
589
+ f"Cannot lower CircuitOperation with non-integer repetitions: {node}"
590
+ )
591
+
592
+ if reps > 1:
593
+ raise lowering.BuildError(
594
+ "Repetitions of circuit operatiosn not yet supported"
595
+ )
596
+
597
+ return self.visit(state, node.circuit)
598
+
599
+ def visit_BitFlipChannel(
600
+ self, state: lowering.State[cirq.Circuit], node: cirq.BitFlipChannel
601
+ ):
602
+ p = node.gate.p
603
+ p_x = state.current_frame.push(py.Constant(p)).result
604
+ p_y = p_z = state.current_frame.push(py.Constant(0)).result
605
+ qubits = self.lower_qubit_getindices(state, node.qubits)
606
+ return state.current_frame.push(
607
+ noise.stmts.SingleQubitPauliChannel(px=p_x, py=p_y, pz=p_z, qubits=qubits)
608
+ )
609
+
610
+ def visit_DepolarizingChannel(
611
+ self, state: lowering.State[cirq.Circuit], node: cirq.DepolarizingChannel
612
+ ):
613
+ p = state.current_frame.push(py.Constant(node.gate.p)).result
614
+ qubits = self.lower_qubit_getindices(state, node.qubits)
615
+ return state.current_frame.push(noise.stmts.Depolarize(p, qubits=qubits))
616
+
617
+ def visit_AsymmetricDepolarizingChannel(
618
+ self,
619
+ state: lowering.State[cirq.Circuit],
620
+ node: cirq.AsymmetricDepolarizingChannel,
621
+ ):
622
+ nqubits = node.gate.num_qubits()
623
+ if nqubits > 2:
624
+ raise lowering.BuildError(
625
+ "AsymmetricDepolarizingChannel applied to more than 2 qubits is not supported!"
626
+ )
627
+
628
+ if nqubits == 1:
629
+ qubits = self.lower_qubit_getindices(state, node.qubits)
630
+ p_x = state.current_frame.push(py.Constant(node.gate.p_x)).result
631
+ p_y = state.current_frame.push(py.Constant(node.gate.p_y)).result
632
+ p_z = state.current_frame.push(py.Constant(node.gate.p_z)).result
633
+ return state.current_frame.push(
634
+ noise.stmts.SingleQubitPauliChannel(p_x, p_y, p_z, qubits)
635
+ )
636
+
637
+ # NOTE: nqubits == 2
638
+ error_probs = node.gate.error_probabilities
639
+ probability_values = []
640
+ p0 = None
641
+ for key in self.two_qubit_paulis:
642
+ p = error_probs.get(key)
643
+
644
+ if p is None:
645
+ if p0 is None:
646
+ p0 = state.current_frame.push(py.Constant(0)).result
647
+ p_ssa = p0
648
+ else:
649
+ p_ssa = state.current_frame.push(py.Constant(p)).result
650
+ probability_values.append(p_ssa)
651
+
652
+ probabilities = state.current_frame.push(
653
+ ilist.New(values=probability_values)
654
+ ).result
655
+
656
+ control, target = node.qubits
657
+ control_qarg = self.lower_qubit_getindices(state, (control,))
658
+ target_qarg = self.lower_qubit_getindices(state, (target,))
659
+
660
+ return state.current_frame.push(
661
+ noise.stmts.TwoQubitPauliChannel(
662
+ probabilities, controls=control_qarg, targets=target_qarg
663
+ )
664
+ )
@@ -14,7 +14,6 @@ from .stdlib.simple import (
14
14
  ry as ry,
15
15
  rz as rz,
16
16
  u3 as u3,
17
- rot as rot,
18
17
  s_dag as s_dag,
19
18
  shift as shift,
20
19
  sqrt_x as sqrt_x,
@@ -5,12 +5,12 @@ from kirin.passes import Default
5
5
  from kirin.prelude import structural_no_opt
6
6
  from typing_extensions import Doc
7
7
 
8
- from bloqade.squin import qubit
8
+ from bloqade import qubit
9
9
 
10
- from .dialects import gates
10
+ from .dialects import gate
11
11
 
12
12
 
13
- @ir.dialect_group(structural_no_opt.union([gates, qubit]))
13
+ @ir.dialect_group(structural_no_opt.union([gate, qubit]))
14
14
  def kernel(self):
15
15
  """Compile a function to a native kernel."""
16
16
 
@@ -0,0 +1,2 @@
1
+ from . import stmts as stmts
2
+ from ._dialect import dialect as dialect
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
# Kirin dialect object identified by the string "native.gate".
# NOTE(review): presumably the dialect that the native gate statements are
# registered with — confirm against `stmts.py`.
dialect = ir.Dialect("native.gate")