bloqade-circuit 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,7 +16,7 @@ from .task import PyQrackSimulatorTask as PyQrackSimulatorTask
16
16
  # NOTE: The following import is for registering the method tables
17
17
  from .noise import native as native
18
18
  from .qasm2 import uop as uop, core as core, glob as glob, parallel as parallel
19
- from .squin import op as op, qubit as qubit
19
+ from .squin import op as op, noise as noise, qubit as qubit
20
20
  from .device import (
21
21
  StackMemorySimulator as StackMemorySimulator,
22
22
  DynamicMemorySimulator as DynamicMemorySimulator,
@@ -0,0 +1 @@
1
+ from . import native as native
@@ -0,0 +1,72 @@
1
+ import random
2
+ import typing
3
+ from dataclasses import dataclass
4
+
5
+ from kirin import interp
6
+ from kirin.dialects import ilist
7
+
8
+ from bloqade.pyqrack import QubitState, PyQrackQubit, PyQrackInterpreter
9
+ from bloqade.squin.noise.stmts import QubitLoss, StochasticUnitaryChannel
10
+ from bloqade.squin.noise._dialect import dialect as squin_noise_dialect
11
+
12
+ from ..runtime import OperatorRuntimeABC
13
+
14
+
15
+ @dataclass(frozen=True)
16
+ class StochasticUnitaryChannelRuntime(OperatorRuntimeABC):
17
+ operators: ilist.IList[OperatorRuntimeABC, typing.Any]
18
+ probabilities: ilist.IList[float, typing.Any]
19
+
20
+ @property
21
+ def n_sites(self) -> int:
22
+ n = self.operators[0].n_sites
23
+ for op in self.operators[1:]:
24
+ assert (
25
+ op.n_sites == n
26
+ ), "Encountered a stochastic unitary channel with operators of different size!"
27
+ return n
28
+
29
+ def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
30
+ # NOTE: probabilities don't necessarily sum to 1; could be no noise event should occur
31
+ p_no_op = 1 - sum(self.probabilities)
32
+ if random.uniform(0.0, 1.0) < p_no_op:
33
+ return
34
+
35
+ selected_ops = random.choices(self.operators, weights=self.probabilities)
36
+ for op in selected_ops:
37
+ op.apply(*qubits, adjoint=adjoint)
38
+
39
+
40
+ @dataclass(frozen=True)
41
+ class QubitLossRuntime(OperatorRuntimeABC):
42
+ p: float
43
+
44
+ @property
45
+ def n_sites(self) -> int:
46
+ return 1
47
+
48
+ def apply(self, qubit: PyQrackQubit, adjoint: bool = False) -> None:
49
+ if random.uniform(0.0, 1.0) < self.p:
50
+ qubit.state = QubitState.Lost
51
+
52
+
53
+ @squin_noise_dialect.register(key="pyqrack")
54
+ class PyQrackMethods(interp.MethodTable):
55
+ @interp.impl(StochasticUnitaryChannel)
56
+ def stochastic_unitary_channel(
57
+ self,
58
+ interp: PyQrackInterpreter,
59
+ frame: interp.Frame,
60
+ stmt: StochasticUnitaryChannel,
61
+ ):
62
+ operators = frame.get(stmt.operators)
63
+ probabilities = frame.get(stmt.probabilities)
64
+
65
+ return (StochasticUnitaryChannelRuntime(operators, probabilities),)
66
+
67
+ @interp.impl(QubitLoss)
68
+ def qubit_loss(
69
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: QubitLoss
70
+ ):
71
+ p = frame.get(stmt.p)
72
+ return (QubitLossRuntime(p),)
bloqade/squin/__init__.py CHANGED
@@ -6,3 +6,12 @@ from . import (
6
6
  lowering as lowering,
7
7
  )
8
8
  from .groups import wired as wired, kernel as kernel
9
+
10
+ try:
11
+ # NOTE: make sure optional cirq dependency is installed
12
+ import cirq as cirq_package # noqa: F401
13
+ except ImportError:
14
+ pass
15
+ else:
16
+ from . import cirq as cirq
17
+ from .cirq import load_circuit as load_circuit
@@ -0,0 +1,89 @@
1
+ from typing import Any
2
+
3
+ import cirq
4
+ from kirin import ir, types
5
+ from kirin.dialects import func
6
+
7
+ from . import lowering as lowering
8
+ from .. import kernel
9
+ from .lowering import Squin
10
+
11
+
12
+ def load_circuit(
13
+ circuit: cirq.Circuit,
14
+ kernel_name: str = "main",
15
+ dialects: ir.DialectGroup = kernel,
16
+ globals: dict[str, Any] | None = None,
17
+ file: str | None = None,
18
+ lineno_offset: int = 0,
19
+ col_offset: int = 0,
20
+ compactify: bool = True,
21
+ ):
22
+ """Converts a cirq.Circuit object into a squin kernel.
23
+
24
+ Args:
25
+ circuit (cirq.Circuit): The circuit to load.
26
+
27
+ Keyword Args:
28
+ kernel_name (str): The name of the kernel to load. Defaults to "main".
29
+ dialects (ir.DialectGroup | None): The dialects to use. Defaults to `squin.kernel`.
30
+ globals (dict[str, Any] | None): The global variables to use. Defaults to None.
31
+ file (str | None): The file name for error reporting. Defaults to None.
32
+ lineno_offset (int): The line number offset for error reporting. Defaults to 0.
33
+ col_offset (int): The column number offset for error reporting. Defaults to 0.
34
+ compactify (bool): Whether to compactify the output. Defaults to True.
35
+
36
+ Example:
37
+
38
+ ```python
39
+ # from cirq's "hello qubit" example
40
+ import cirq
41
+ from bloqade import squin
42
+
43
+ # Pick a qubit.
44
+ qubit = cirq.GridQubit(0, 0)
45
+
46
+ # Create a circuit.
47
+ circuit = cirq.Circuit(
48
+ cirq.X(qubit)**0.5, # Square root of NOT.
49
+ cirq.measure(qubit, key='m') # Measurement.
50
+ )
51
+
52
+ # load the circuit as squin
53
+ main = squin.load_circuit(circuit)
54
+
55
+ # print the resulting IR
56
+ main.print()
57
+ ```
58
+ """
59
+
60
+ target = Squin(dialects=dialects, circuit=circuit)
61
+ body = target.run(
62
+ circuit,
63
+ source=str(circuit), # TODO: proper source string
64
+ file=file,
65
+ globals=globals,
66
+ lineno_offset=lineno_offset,
67
+ col_offset=col_offset,
68
+ compactify=compactify,
69
+ )
70
+
71
+ # NOTE: no return value
72
+ return_value = func.ConstantNone()
73
+ body.blocks[0].stmts.append(return_value)
74
+ body.blocks[0].stmts.append(func.Return(value_or_stmt=return_value))
75
+
76
+ code = func.Function(
77
+ sym_name=kernel_name,
78
+ signature=func.Signature((), types.NoneType),
79
+ body=body,
80
+ )
81
+
82
+ return ir.Method(
83
+ mod=None,
84
+ py_func=None,
85
+ sym_name=kernel_name,
86
+ arg_names=[],
87
+ dialects=dialects,
88
+ code=code,
89
+ )
@@ -0,0 +1,303 @@
1
+ import math
2
+ from typing import Any
3
+ from dataclasses import field, dataclass
4
+
5
+ import cirq
6
+ from kirin import ir, lowering
7
+ from kirin.rewrite import Walk, CFGCompactify
8
+ from kirin.dialects import py, ilist
9
+
10
+ from .. import op, noise, qubit
11
+
12
+ CirqNode = cirq.Circuit | cirq.Moment | cirq.Gate | cirq.Qid | cirq.Operation
13
+
14
+ DecomposeNode = (
15
+ cirq.SwapPowGate
16
+ | cirq.ISwapPowGate
17
+ | cirq.PhasedXPowGate
18
+ | cirq.PhasedXZGate
19
+ | cirq.CSwapGate
20
+ )
21
+
22
+
23
+ @dataclass
24
+ class Squin(lowering.LoweringABC[CirqNode]):
25
+ """Lower a cirq.Circuit object to a squin kernel"""
26
+
27
+ circuit: cirq.Circuit
28
+ qreg: qubit.New = field(init=False)
29
+ qreg_index: dict[cirq.Qid, int] = field(init=False, default_factory=dict)
30
+ next_qreg_index: int = field(init=False, default=0)
31
+
32
+ def lower_qubit_getindex(self, state: lowering.State[CirqNode], qid: cirq.Qid):
33
+ index = self.qreg_index.get(qid)
34
+
35
+ if index is None:
36
+ index = self.next_qreg_index
37
+ self.qreg_index[qid] = index
38
+ self.next_qreg_index += 1
39
+
40
+ index_ssa = state.current_frame.push(py.Constant(index)).result
41
+ qbit_getitem = state.current_frame.push(py.GetItem(self.qreg.result, index_ssa))
42
+ return qbit_getitem.result
43
+
44
+ def lower_qubit_getindices(
45
+ self, state: lowering.State[CirqNode], qids: list[cirq.Qid]
46
+ ):
47
+ qbits_getitem = [self.lower_qubit_getindex(state, qid) for qid in qids]
48
+ qbits_stmt = ilist.New(values=qbits_getitem)
49
+ qbits_result = state.current_frame.get(qbits_stmt.name)
50
+
51
+ if qbits_result is not None:
52
+ return qbits_result
53
+
54
+ state.current_frame.push(qbits_stmt)
55
+ return qbits_stmt.result
56
+
57
+ def run(
58
+ self,
59
+ stmt: CirqNode,
60
+ *,
61
+ source: str | None = None,
62
+ globals: dict[str, Any] | None = None,
63
+ file: str | None = None,
64
+ lineno_offset: int = 0,
65
+ col_offset: int = 0,
66
+ compactify: bool = True,
67
+ ) -> ir.Region:
68
+
69
+ state = lowering.State(
70
+ self,
71
+ file=file,
72
+ lineno_offset=lineno_offset,
73
+ col_offset=col_offset,
74
+ )
75
+
76
+ with state.frame(
77
+ [stmt],
78
+ globals=globals,
79
+ finalize_next=False,
80
+ ) as frame:
81
+ # NOTE: create a global register of qubits first
82
+ # TODO: can there be a circuit without qubits?
83
+ n_qubits = cirq.num_qubits(self.circuit)
84
+ n = frame.push(py.Constant(n_qubits))
85
+ self.qreg = frame.push(qubit.New(n_qubits=n.result))
86
+
87
+ self.visit(state, stmt)
88
+
89
+ if compactify:
90
+ Walk(CFGCompactify()).rewrite(frame.curr_region)
91
+
92
+ region = frame.curr_region
93
+
94
+ return region
95
+
96
+ def visit(self, state: lowering.State[CirqNode], node: CirqNode) -> lowering.Result:
97
+ name = node.__class__.__name__
98
+ return getattr(self, f"visit_{name}", self.generic_visit)(state, node)
99
+
100
+ def generic_visit(self, state: lowering.State[CirqNode], node: CirqNode):
101
+ if isinstance(node, CirqNode):
102
+ raise lowering.BuildError(
103
+ f"Cannot lower {node.__class__.__name__} node: {node}"
104
+ )
105
+ raise lowering.BuildError(
106
+ f"Unexpected `{node.__class__.__name__}` node: {repr(node)} is not an AST node"
107
+ )
108
+
109
+ def lower_literal(self, state: lowering.State[CirqNode], value) -> ir.SSAValue:
110
+ raise lowering.BuildError("Literals not supported in cirq circuit")
111
+
112
+ def lower_global(
113
+ self, state: lowering.State[CirqNode], node: CirqNode
114
+ ) -> lowering.LoweringABC.Result:
115
+ raise lowering.BuildError("Literals not supported in cirq circuit")
116
+
117
+ def visit_Circuit(
118
+ self, state: lowering.State[CirqNode], node: cirq.Circuit
119
+ ) -> lowering.Result:
120
+ for moment in node:
121
+ state.lower(moment)
122
+
123
+ def visit_Moment(
124
+ self, state: lowering.State[CirqNode], node: cirq.Moment
125
+ ) -> lowering.Result:
126
+ for op_ in node.operations:
127
+ state.lower(op_)
128
+
129
+ def visit_GateOperation(
130
+ self, state: lowering.State[CirqNode], node: cirq.GateOperation
131
+ ):
132
+ if isinstance(node.gate, cirq.MeasurementGate):
133
+ # NOTE: special dispatch here, since measurement is a gate + a qubit in cirq,
134
+ # but a single statement in squin
135
+ return self.lower_measurement(state, node)
136
+
137
+ if isinstance(node.gate, DecomposeNode):
138
+ # NOTE: easier to decompose these, but for that we need the qubits too,
139
+ # so we need to do this within this method
140
+ for subnode in cirq.decompose_once(node):
141
+ state.lower(subnode)
142
+ return
143
+
144
+ op_ = state.lower(node.gate).expect_one()
145
+ qbits = self.lower_qubit_getindices(state, node.qubits)
146
+ return state.current_frame.push(qubit.Apply(operator=op_, qubits=qbits))
147
+
148
+ def lower_measurement(
149
+ self, state: lowering.State[CirqNode], node: cirq.GateOperation
150
+ ):
151
+ if len(node.qubits) == 1:
152
+ qbit = self.lower_qubit_getindex(state, node.qubits[0])
153
+ return state.current_frame.push(qubit.MeasureQubit(qbit))
154
+
155
+ qbits = self.lower_qubit_getindices(state, node.qubits)
156
+ return state.current_frame.push(qubit.MeasureQubitList(qbits))
157
+
158
+ def visit_SingleQubitPauliStringGateOperation(
159
+ self,
160
+ state: lowering.State[CirqNode],
161
+ node: cirq.SingleQubitPauliStringGateOperation,
162
+ ):
163
+
164
+ match node.pauli:
165
+ case cirq.X:
166
+ op_ = op.stmts.X()
167
+ case cirq.Y:
168
+ op_ = op.stmts.Y()
169
+ case cirq.Z:
170
+ op_ = op.stmts.Z()
171
+ case cirq.I:
172
+ op_ = op.stmts.Identity(sites=1)
173
+ case _:
174
+ raise lowering.BuildError(f"Unexpected Pauli operation {node.pauli}")
175
+
176
+ state.current_frame.push(op_)
177
+ qargs = self.lower_qubit_getindices(state, [node.qubit])
178
+ return state.current_frame.push(qubit.Apply(op_.result, qargs))
179
+
180
+ def visit_HPowGate(self, state: lowering.State[CirqNode], node: cirq.HPowGate):
181
+ if node.exponent == 1:
182
+ return state.current_frame.push(op.stmts.H())
183
+
184
+ return state.lower(node.in_su2())
185
+
186
+ def visit_XPowGate(self, state: lowering.State[CirqNode], node: cirq.XPowGate):
187
+ if node.exponent == 1:
188
+ return state.current_frame.push(op.stmts.X())
189
+
190
+ return self.visit(state, node.in_su2())
191
+
192
+ def visit_YPowGate(self, state: lowering.State[CirqNode], node: cirq.YPowGate):
193
+ if node.exponent == 1:
194
+ return state.current_frame.push(op.stmts.Y())
195
+
196
+ return self.visit(state, node.in_su2())
197
+
198
+ def visit_ZPowGate(self, state: lowering.State[CirqNode], node: cirq.ZPowGate):
199
+ if node.exponent == 0.5:
200
+ return state.current_frame.push(op.stmts.S())
201
+
202
+ if node.exponent == 0.25:
203
+ return state.current_frame.push(op.stmts.T())
204
+
205
+ if node.exponent == 1:
206
+ return state.current_frame.push(op.stmts.Z())
207
+
208
+ # NOTE: just for the Z gate, an arbitrary exponent is equivalent to the ShiftOp
209
+ t = node.exponent
210
+ theta = state.current_frame.push(py.Constant(math.pi * t))
211
+ return state.current_frame.push(op.stmts.ShiftOp(theta=theta.result))
212
+
213
+ def visit_Rx(self, state: lowering.State[CirqNode], node: cirq.Rx):
214
+ x = state.current_frame.push(op.stmts.X())
215
+ angle = state.current_frame.push(py.Constant(value=math.pi * node.exponent))
216
+ return state.current_frame.push(op.stmts.Rot(axis=x.result, angle=angle.result))
217
+
218
+ def visit_Ry(self, state: lowering.State[CirqNode], node: cirq.Ry):
219
+ y = state.current_frame.push(op.stmts.Y())
220
+ angle = state.current_frame.push(py.Constant(value=math.pi * node.exponent))
221
+ return state.current_frame.push(op.stmts.Rot(axis=y.result, angle=angle.result))
222
+
223
+ def visit_Rz(self, state: lowering.State[CirqNode], node: cirq.Rz):
224
+ z = state.current_frame.push(op.stmts.Z())
225
+ angle = state.current_frame.push(py.Constant(value=math.pi * node.exponent))
226
+ return state.current_frame.push(op.stmts.Rot(axis=z.result, angle=angle.result))
227
+
228
+ def visit_CXPowGate(self, state: lowering.State[CirqNode], node: cirq.CXPowGate):
229
+ x = state.lower(cirq.XPowGate(exponent=node.exponent)).expect_one()
230
+ return state.current_frame.push(op.stmts.Control(x, n_controls=1))
231
+
232
+ def visit_CZPowGate(self, state: lowering.State[CirqNode], node: cirq.CZPowGate):
233
+ z = state.lower(cirq.ZPowGate(exponent=node.exponent)).expect_one()
234
+ return state.current_frame.push(op.stmts.Control(z, n_controls=1))
235
+
236
+ def visit_ControlledOperation(
237
+ self, state: lowering.State[CirqNode], node: cirq.ControlledOperation
238
+ ):
239
+ return self.visit_GateOperation(state, node)
240
+
241
+ def visit_ControlledGate(
242
+ self, state: lowering.State[CirqNode], node: cirq.ControlledGate
243
+ ):
244
+ op_ = state.lower(node.sub_gate).expect_one()
245
+ n_controls = node.num_controls()
246
+ return state.current_frame.push(op.stmts.Control(op_, n_controls=n_controls))
247
+
248
+ def visit_XXPowGate(self, state: lowering.State[CirqNode], node: cirq.XXPowGate):
249
+ x = state.lower(cirq.XPowGate(exponent=node.exponent)).expect_one()
250
+ return state.current_frame.push(op.stmts.Kron(x, x))
251
+
252
+ def visit_YYPowGate(self, state: lowering.State[CirqNode], node: cirq.YYPowGate):
253
+ y = state.lower(cirq.YPowGate(exponent=node.exponent)).expect_one()
254
+ return state.current_frame.push(op.stmts.Kron(y, y))
255
+
256
+ def visit_ZZPowGate(self, state: lowering.State[CirqNode], node: cirq.ZZPowGate):
257
+ z = state.lower(cirq.ZPowGate(exponent=node.exponent)).expect_one()
258
+ return state.current_frame.push(op.stmts.Kron(z, z))
259
+
260
+ def visit_CCXPowGate(self, state: lowering.State[CirqNode], node: cirq.CCXPowGate):
261
+ x = state.lower(cirq.XPowGate(exponent=node.exponent)).expect_one()
262
+ return state.current_frame.push(op.stmts.Control(x, n_controls=2))
263
+
264
+ def visit_CCZPowGate(self, state: lowering.State[CirqNode], node: cirq.CCZPowGate):
265
+ z = state.lower(cirq.ZPowGate(exponent=node.exponent)).expect_one()
266
+ return state.current_frame.push(op.stmts.Control(z, n_controls=2))
267
+
268
+ def visit_BitFlipChannel(
269
+ self, state: lowering.State[CirqNode], node: cirq.BitFlipChannel
270
+ ):
271
+ x = state.current_frame.push(op.stmts.X())
272
+ p = state.current_frame.push(py.Constant(node.p))
273
+ return state.current_frame.push(
274
+ noise.stmts.PauliError(basis=x.result, p=p.result)
275
+ )
276
+
277
+ def visit_AmplitudeDampingChannel(
278
+ self, state: lowering.State[CirqNode], node: cirq.AmplitudeDampingChannel
279
+ ):
280
+ r = state.current_frame.push(op.stmts.Reset())
281
+ p = state.current_frame.push(py.Constant(node.gamma))
282
+
283
+ # TODO: do we need a dedicated noise stmt for this? Using PauliError
284
+ # with this basis feels like a hack
285
+ noise_channel = state.current_frame.push(
286
+ noise.stmts.PauliError(basis=r.result, p=p.result)
287
+ )
288
+
289
+ return noise_channel
290
+
291
+ def visit_GeneralizedAmplitudeDampingChannel(
292
+ self,
293
+ state: lowering.State[CirqNode],
294
+ node: cirq.GeneralizedAmplitudeDampingChannel,
295
+ ):
296
+ raise NotImplementedError("TODO: needs a new operator statement")
297
+ # p = state.current_frame.push(py.Constant(node.p))
298
+ # gamma = state.current_frame.push(py.Constant(node.gamma))
299
+
300
+ # p1 =
301
+
302
+ # x = state.current_frame.push(op.stmts.X())
303
+ # noise_channel1 = noise.stmts.PauliError(basis=x.result, p=)
bloqade/squin/groups.py CHANGED
@@ -3,12 +3,12 @@ from kirin.prelude import structural_no_opt
3
3
  from kirin.rewrite import Walk, Chain
4
4
  from kirin.dialects import ilist
5
5
 
6
- from . import op, wire, qubit
6
+ from . import op, wire, noise, qubit
7
7
  from .op.rewrite import PyMultToSquinMult
8
8
  from .rewrite.desugar import ApplyDesugarRule, MeasureDesugarRule
9
9
 
10
10
 
11
- @ir.dialect_group(structural_no_opt.union([op, qubit]))
11
+ @ir.dialect_group(structural_no_opt.union([op, qubit, noise]))
12
12
  def kernel(self):
13
13
  fold_pass = passes.Fold(self)
14
14
  typeinfer_pass = passes.TypeInfer(self)
@@ -36,7 +36,7 @@ def kernel(self):
36
36
  return run_pass
37
37
 
38
38
 
39
- @ir.dialect_group(structural_no_opt.union([op, wire]))
39
+ @ir.dialect_group(structural_no_opt.union([op, wire, noise]))
40
40
  def wired(self):
41
41
  py_mult_to_mult_pass = PyMultToSquinMult(self)
42
42
 
@@ -4,5 +4,7 @@ from ._wrapper import (
4
4
  pp_error as pp_error,
5
5
  depolarize as depolarize,
6
6
  qubit_loss as qubit_loss,
7
- pauli_channel as pauli_channel,
7
+ pauli_error as pauli_error,
8
+ two_qubit_pauli_channel as two_qubit_pauli_channel,
9
+ single_qubit_pauli_channel as single_qubit_pauli_channel,
8
10
  )
@@ -14,11 +14,15 @@ def pp_error(op: Op, p: float) -> Op: ...
14
14
 
15
15
 
16
16
  @wraps(stmts.Depolarize)
17
- def depolarize(n_qubits: int, p: float) -> Op: ...
17
+ def depolarize(p: float) -> Op: ...
18
18
 
19
19
 
20
- @wraps(stmts.PauliChannel)
21
- def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ...
20
+ @wraps(stmts.SingleQubitPauliChannel)
21
+ def single_qubit_pauli_channel(params: tuple[float, float, float]) -> Op: ...
22
+
23
+
24
+ @wraps(stmts.TwoQubitPauliChannel)
25
+ def two_qubit_pauli_channel(params: tuple[float, ...]) -> Op: ...
22
26
 
23
27
 
24
28
  @wraps(stmts.QubitLoss)
@@ -0,0 +1,111 @@
1
+ import itertools
2
+
3
+ from kirin import ir
4
+ from kirin.passes import Pass
5
+ from kirin.rewrite import Walk
6
+ from kirin.dialects import ilist
7
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
8
+
9
+ from .stmts import (
10
+ PPError,
11
+ QubitLoss,
12
+ Depolarize,
13
+ PauliError,
14
+ NoiseChannel,
15
+ TwoQubitPauliChannel,
16
+ SingleQubitPauliChannel,
17
+ StochasticUnitaryChannel,
18
+ )
19
+ from ..op.stmts import X, Y, Z, Kron, Identity
20
+
21
+
22
+ class _RewriteNoiseStmts(RewriteRule):
23
+ """Rewrites squin noise statements to StochasticUnitaryChannel"""
24
+
25
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
26
+ if not isinstance(node, NoiseChannel) or isinstance(node, QubitLoss):
27
+ return RewriteResult()
28
+
29
+ return getattr(self, "rewrite_" + node.name)(node)
30
+
31
+ def rewrite_pauli_error(self, node: PauliError) -> RewriteResult:
32
+ (operators := ilist.New(values=(node.basis,))).insert_before(node)
33
+ (ps := ilist.New(values=(node.p,))).insert_before(node)
34
+ stochastic_channel = StochasticUnitaryChannel(
35
+ operators=operators.result, probabilities=ps.result
36
+ )
37
+
38
+ node.replace_by(stochastic_channel)
39
+ return RewriteResult(has_done_something=True)
40
+
41
+ def rewrite_single_qubit_pauli_channel(
42
+ self, node: SingleQubitPauliChannel
43
+ ) -> RewriteResult:
44
+ paulis = (X(), Y(), Z())
45
+ paulis_ssa: list[ir.SSAValue] = []
46
+ for op in paulis:
47
+ op.insert_before(node)
48
+ paulis_ssa.append(op.result)
49
+
50
+ (pauli_ops := ilist.New(values=paulis_ssa)).insert_before(node)
51
+
52
+ stochastic_unitary = StochasticUnitaryChannel(
53
+ operators=pauli_ops.result, probabilities=node.params
54
+ )
55
+ node.replace_by(stochastic_unitary)
56
+ return RewriteResult(has_done_something=True)
57
+
58
+ def rewrite_two_qubit_pauli_channel(
59
+ self, node: TwoQubitPauliChannel
60
+ ) -> RewriteResult:
61
+ paulis = (X(), Y(), Z(), Identity(sites=1))
62
+ for op in paulis:
63
+ op.insert_before(node)
64
+
65
+ # NOTE: collect list so we can skip the last entry, which will be two identities
66
+ combinations = list(itertools.product(paulis, repeat=2))[:-1]
67
+ operators: list[ir.SSAValue] = []
68
+ for pauli_1, pauli_2 in combinations:
69
+ op = Kron(pauli_1.result, pauli_2.result)
70
+ op.insert_before(node)
71
+ operators.append(op.result)
72
+
73
+ (operator_list := ilist.New(values=operators)).insert_before(node)
74
+ stochastic_unitary = StochasticUnitaryChannel(
75
+ operators=operator_list.result, probabilities=node.params
76
+ )
77
+
78
+ node.replace_by(stochastic_unitary)
79
+ return RewriteResult(has_done_something=True)
80
+
81
+ def rewrite_p_p_error(self, node: PPError) -> RewriteResult:
82
+ (operators := ilist.New(values=(node.op,))).insert_before(node)
83
+ (ps := ilist.New(values=(node.p,))).insert_before(node)
84
+ stochastic_channel = StochasticUnitaryChannel(
85
+ operators=operators.result, probabilities=ps.result
86
+ )
87
+
88
+ node.replace_by(stochastic_channel)
89
+ return RewriteResult(has_done_something=True)
90
+
91
+ def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
92
+ paulis = (X(), Y(), Z())
93
+ operators: list[ir.SSAValue] = []
94
+ for op in paulis:
95
+ op.insert_before(node)
96
+ operators.append(op.result)
97
+
98
+ (operator_list := ilist.New(values=operators)).insert_before(node)
99
+ (ps := ilist.New(values=[node.p for _ in range(3)])).insert_before(node)
100
+
101
+ stochastic_unitary = StochasticUnitaryChannel(
102
+ operators=operator_list.result, probabilities=ps.result
103
+ )
104
+ node.replace_by(stochastic_unitary)
105
+
106
+ return RewriteResult(has_done_something=True)
107
+
108
+
109
+ class RewriteNoiseStmts(Pass):
110
+ def unsafe_run(self, mt: ir.Method):
111
+ return Walk(_RewriteNoiseStmts()).rewrite(mt.code)
@@ -1,21 +1,23 @@
1
- from kirin import ir, types
1
+ from kirin import ir, types, lowering
2
2
  from kirin.decl import info, statement
3
+ from kirin.dialects import ilist
3
4
 
4
5
  from bloqade.squin.op.types import OpType
5
6
 
6
7
  from ._dialect import dialect
8
+ from ..op.types import NumOperators
7
9
 
8
10
 
9
11
  @statement
10
12
  class NoiseChannel(ir.Statement):
11
- pass
13
+ traits = frozenset({lowering.FromPythonCall()})
14
+ result: ir.ResultValue = info.result(OpType)
12
15
 
13
16
 
14
17
  @statement(dialect=dialect)
15
18
  class PauliError(NoiseChannel):
16
19
  basis: ir.SSAValue = info.argument(OpType)
17
20
  p: ir.SSAValue = info.argument(types.Float)
18
- result: ir.ResultValue = info.result(OpType)
19
21
 
20
22
 
21
23
  @statement(dialect=dialect)
@@ -26,34 +28,37 @@ class PPError(NoiseChannel):
26
28
 
27
29
  op: ir.SSAValue = info.argument(OpType)
28
30
  p: ir.SSAValue = info.argument(types.Float)
29
- result: ir.ResultValue = info.result(OpType)
30
31
 
31
32
 
32
33
  @statement(dialect=dialect)
33
34
  class Depolarize(NoiseChannel):
34
35
  """
35
- Apply n-qubit depolaize error to qubits
36
- NOTE For Stim, this can only accept 1 or 2 qubits
36
+ Apply depolarize error to qubit
37
37
  """
38
38
 
39
- n_qubits: int = info.attribute(types.Int)
40
39
  p: ir.SSAValue = info.argument(types.Float)
41
- result: ir.ResultValue = info.result(OpType)
42
40
 
43
41
 
44
42
  @statement(dialect=dialect)
45
- class PauliChannel(NoiseChannel):
46
- # NOTE:
47
- # 1-qubit 3 params px, py, pz
48
- # 2-qubit 15 params pix, piy, piz, pxi, pxx, pxy, pxz, pyi, pyx ..., pzz
49
- # TODO add validation for params (maybe during lowering via custom lower?)
50
- n_qubits: int = info.attribute()
51
- params: ir.SSAValue = info.argument(types.Tuple[types.Vararg(types.Float)])
52
- result: ir.ResultValue = info.result(OpType)
43
+ class SingleQubitPauliChannel(NoiseChannel):
44
+ params: ir.SSAValue = info.argument(ilist.IListType[types.Float, types.Literal(3)])
45
+
46
+
47
+ @statement(dialect=dialect)
48
+ class TwoQubitPauliChannel(NoiseChannel):
49
+ params: ir.SSAValue = info.argument(ilist.IListType[types.Float, types.Literal(15)])
53
50
 
54
51
 
55
52
  @statement(dialect=dialect)
56
53
  class QubitLoss(NoiseChannel):
57
54
  # NOTE: qubit loss error (not supported by Stim)
58
55
  p: ir.SSAValue = info.argument(types.Float)
56
+
57
+
58
+ @statement(dialect=dialect)
59
+ class StochasticUnitaryChannel(ir.Statement):
60
+ operators: ir.SSAValue = info.argument(ilist.IListType[OpType, NumOperators])
61
+ probabilities: ir.SSAValue = info.argument(
62
+ ilist.IListType[types.Float, NumOperators]
63
+ )
59
64
  result: ir.ResultValue = info.result(OpType)
bloqade/squin/op/stmts.py CHANGED
@@ -9,7 +9,7 @@ from ._dialect import dialect
9
9
 
10
10
  @statement
11
11
  class Operator(ir.Statement):
12
- pass
12
+ result: ir.ResultValue = info.result(OpType)
13
13
 
14
14
 
15
15
  @statement
@@ -26,7 +26,6 @@ class CompositeOp(Operator):
26
26
  class BinaryOp(CompositeOp):
27
27
  lhs: ir.SSAValue = info.argument(OpType)
28
28
  rhs: ir.SSAValue = info.argument(OpType)
29
- result: ir.ResultValue = info.result(OpType)
30
29
 
31
30
 
32
31
  @statement(dialect=dialect)
@@ -46,7 +45,6 @@ class Adjoint(CompositeOp):
46
45
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
47
46
  is_unitary: bool = info.attribute(default=False)
48
47
  op: ir.SSAValue = info.argument(OpType)
49
- result: ir.ResultValue = info.result(OpType)
50
48
 
51
49
 
52
50
  @statement(dialect=dialect)
@@ -55,7 +53,6 @@ class Scale(CompositeOp):
55
53
  is_unitary: bool = info.attribute(default=False)
56
54
  op: ir.SSAValue = info.argument(OpType)
57
55
  factor: ir.SSAValue = info.argument(NumberType)
58
- result: ir.ResultValue = info.result(OpType)
59
56
 
60
57
 
61
58
  @statement(dialect=dialect)
@@ -64,7 +61,6 @@ class Control(CompositeOp):
64
61
  is_unitary: bool = info.attribute(default=False)
65
62
  op: ir.SSAValue = info.argument(OpType)
66
63
  n_controls: int = info.attribute()
67
- result: ir.ResultValue = info.result(OpType)
68
64
 
69
65
 
70
66
  @statement(dialect=dialect)
@@ -72,14 +68,12 @@ class Rot(CompositeOp):
72
68
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary()})
73
69
  axis: ir.SSAValue = info.argument(OpType)
74
70
  angle: ir.SSAValue = info.argument(types.Float)
75
- result: ir.ResultValue = info.result(OpType)
76
71
 
77
72
 
78
73
  @statement(dialect=dialect)
79
74
  class Identity(CompositeOp):
80
75
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()})
81
76
  sites: int = info.attribute()
82
- result: ir.ResultValue = info.result(OpType)
83
77
 
84
78
 
85
79
  @statement
@@ -87,7 +81,6 @@ class ConstantOp(PrimitiveOp):
87
81
  traits = frozenset(
88
82
  {ir.Pure(), lowering.FromPythonCall(), ir.ConstantLike(), FixedSites(1)}
89
83
  )
90
- result: ir.ResultValue = info.result(OpType)
91
84
 
92
85
 
93
86
  @statement
@@ -109,7 +102,6 @@ class U3(PrimitiveOp):
109
102
  theta: ir.SSAValue = info.argument(types.Float)
110
103
  phi: ir.SSAValue = info.argument(types.Float)
111
104
  lam: ir.SSAValue = info.argument(types.Float)
112
- result: ir.ResultValue = info.result(OpType)
113
105
 
114
106
 
115
107
  @statement(dialect=dialect)
@@ -124,7 +116,6 @@ class PhaseOp(PrimitiveOp):
124
116
 
125
117
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), FixedSites(1)})
126
118
  theta: ir.SSAValue = info.argument(types.Float)
127
- result: ir.ResultValue = info.result(OpType)
128
119
 
129
120
 
130
121
  @statement(dialect=dialect)
@@ -139,7 +130,6 @@ class ShiftOp(PrimitiveOp):
139
130
 
140
131
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), FixedSites(1)})
141
132
  theta: ir.SSAValue = info.argument(types.Float)
142
- result: ir.ResultValue = info.result(OpType)
143
133
 
144
134
 
145
135
  @statement(dialect=dialect)
@@ -149,7 +139,6 @@ class Reset(PrimitiveOp):
149
139
  """
150
140
 
151
141
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), FixedSites(1)})
152
- result: ir.ResultValue = info.result(OpType)
153
142
 
154
143
 
155
144
  @statement
bloqade/squin/op/types.py CHANGED
@@ -22,3 +22,5 @@ class Op:
22
22
 
23
23
 
24
24
  OpType = types.PyClass(Op)
25
+
26
+ NumOperators = types.TypeVar("NumOperators")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bloqade-circuit
3
- Version: 0.4.0
3
+ Version: 0.4.1
4
4
  Summary: The software development toolkit for neutral atom arrays.
5
5
  Author-email: Roger-luo <rluo@quera.com>, kaihsin <khwu@quera.com>, weinbe58 <pweinberg@quera.com>, johnzl-777 <jlong@quera.com>
6
6
  License-File: LICENSE
@@ -9,7 +9,7 @@ bloqade/analysis/address/impls.py,sha256=c3FMF2LLkV_fwWaolIRTV7ueXC9qwAR6PUW9kDQ
9
9
  bloqade/analysis/address/lattice.py,sha256=dUq999feqPoBYkqEXe1hjHOn4TP_bkvKip8fyWQ-2-8,1755
10
10
  bloqade/analysis/fidelity/__init__.py,sha256=iJkhoHvCMU9bKxQqgxIWKQWvpqNFRgNBI5DK8-4RAB8,59
11
11
  bloqade/analysis/fidelity/analysis.py,sha256=G6JEYc8eeWJ9mwsbUAIzXuU2nrnTU4te41c04xE71gM,3218
12
- bloqade/pyqrack/__init__.py,sha256=_EQJ4blQxqS1_Z6OTjmgSquFo1R_9rEjECRg4fH97hQ,794
12
+ bloqade/pyqrack/__init__.py,sha256=lonTS-luJkTVujCCtgdZRC12V7FQdoFcozAI-byXwN0,810
13
13
  bloqade/pyqrack/base.py,sha256=9z61PaaAFqCBBwkgsDZSr-qr9IQ5OJ_JUvltmJ7Bgls,4407
14
14
  bloqade/pyqrack/device.py,sha256=cOdyT1k0b73QbOsEIu5KqxHW2OAWP86fi_3XGPhaWGA,7134
15
15
  bloqade/pyqrack/reg.py,sha256=uTL07CT1R0xUsInLmwU9YuuNdV6lV0lCs1zhdUz1qIs,1660
@@ -27,6 +27,8 @@ bloqade/pyqrack/squin/op.py,sha256=CY2qTYqFPwEceWt-iSiax6i_Bs2i9hQhzGR8ni_PmmA,4
27
27
  bloqade/pyqrack/squin/qubit.py,sha256=svQMbsLxv3yjiFMSRc4C7QGllzjtmlSWOsMY1mjTI8Q,2223
28
28
  bloqade/pyqrack/squin/runtime.py,sha256=806S23bJKbh2xBuZ0yPAgjs7ZxDHB-zqQtNeA--x9fw,15349
29
29
  bloqade/pyqrack/squin/wire.py,sha256=rqlAeU-r_EHOwJMqHrEAxpZ_rKsvUpwGG7MP4BW75Nw,1658
30
+ bloqade/pyqrack/squin/noise/__init__.py,sha256=uXgRQPOrHNRp3k2ff2HD8mheUEaqxZPKEnwV-s4BiV4,31
31
+ bloqade/pyqrack/squin/noise/native.py,sha256=KF4VGzU5Ps92DeLcIDIMsxQQtQ97z_3KUHqBPPkZFaM,2286
30
32
  bloqade/qasm2/__init__.py,sha256=W9dR4Qnvigc7e7Ay7puSJHAIuiQk8vWqY-W64SMu5oU,515
31
33
  bloqade/qasm2/_qasm_loading.py,sha256=1EFTt1YDkL8fsoSgSuqD1QcKO4EMFIGuBTX9HCnb6S0,4724
32
34
  bloqade/qasm2/_wrappers.py,sha256=4x3fldC4sV2K_XZ0FPZOorQKAbs_7pualListXtak4A,11148
@@ -104,8 +106,8 @@ bloqade/qbraid/lowering.py,sha256=84RsPONWeQ_2beyOWBNoJbGFKuaMhsWtRiPivTCZ-Q0,11
104
106
  bloqade/qbraid/schema.py,sha256=dTPexUFOiBNBnFv0GEbGh6jpIbMIFHk4hFXmXbeihxA,7854
105
107
  bloqade/qbraid/simulation_result.py,sha256=zdCJcAdbQkEDzFFuC2q3gqOFTOLAXHk4wh8RRDB6cgc,3956
106
108
  bloqade/qbraid/target.py,sha256=LcFHHyLe74yBmrHI9251xHgLN_nUz35lN8RPNwrT6mI,3149
107
- bloqade/squin/__init__.py,sha256=8W5Z26BFrR3XRP6yb6IWDA5tZiXMhNstevx5kP3pP5k,169
108
- bloqade/squin/groups.py,sha256=kFS9WtDC0NDRb4XUfNP8HZdrS961uk39XAAKlufJykA,1289
109
+ bloqade/squin/__init__.py,sha256=0NKHHhSFtlITiAMShgKPkh6tvQbfPzqgBi4I2Wm_LCU,398
110
+ bloqade/squin/groups.py,sha256=RXGJnNZUSXF_f5ljjhZ9At8UhaijayoxFoWvxEsUOWc,1310
109
111
  bloqade/squin/lowering.py,sha256=w-GyOKYZHHKCGA2slcgWNS97Q_znQU65PeYxEIkvChM,816
110
112
  bloqade/squin/qubit.py,sha256=psIZPtbQHsiToCXcT4wuuaZzEifaPVvk_SWrsgPbNwg,5067
111
113
  bloqade/squin/wire.py,sha256=kRmpC7P6qIOZOsuFJHJQyeNsEWcU2Of4ZgcdsMESnKA,3746
@@ -115,19 +117,22 @@ bloqade/squin/analysis/nsites/__init__.py,sha256=RlQg7ivczXCXG5lMeL3ipYKj2oJKC4T
115
117
  bloqade/squin/analysis/nsites/analysis.py,sha256=rIe1RU1MZRItcE2aB8DYahLrv73HfD3IHCX3E_EGQ1c,1773
116
118
  bloqade/squin/analysis/nsites/impls.py,sha256=OaKuAoZ0EAorStYDZxzgc6Dk42kuj19MLkqHWG1MEQM,2592
117
119
  bloqade/squin/analysis/nsites/lattice.py,sha256=ruh0808SHtj3ecuT-C3AZTsLY2j3DRhtezGiTZvcuVs,942
118
- bloqade/squin/noise/__init__.py,sha256=K7wHkzUxWbLF-XQPCAlXY7izz-bS7LJXt2G3GjlkEQc,218
120
+ bloqade/squin/cirq/__init__.py,sha256=aGRwIWqbbKaW-1Pm697uo-wQsM9ZBf0h5Qlo3Y6HJ-Y,2453
121
+ bloqade/squin/cirq/lowering.py,sha256=Crq4oq0O_ucZCjSXV6DNEMvZCo5HCl9Fy0mpYfYOdok,11689
122
+ bloqade/squin/noise/__init__.py,sha256=HQl3FE0SZAGEX3qdveapCaMX391lgLvWeWnoE6Z2pYw,332
119
123
  bloqade/squin/noise/_dialect.py,sha256=2IR98J-lXm5Y3srP9g-FD4JC-qTq2seureM6mKKq1xg,63
120
- bloqade/squin/noise/_wrapper.py,sha256=fHytpR6kINTPnaTfmMa_GRn4geBQkbEGVFRE60V0Zko,474
121
- bloqade/squin/noise/stmts.py,sha256=0XvBdZFkjMABUkaj5Fry8oQ8QmWzoZggX1aecvWBnVk,1607
124
+ bloqade/squin/noise/_wrapper.py,sha256=0jD5va_go9jEW5rC6bZSWU30kjCha2-axFogPON3-V0,580
125
+ bloqade/squin/noise/rewrite.py,sha256=SxIHgMDqYJXepiZDyukHWpe5yaFDSTG-yJ4JONNVr0o,3917
126
+ bloqade/squin/noise/stmts.py,sha256=rktxkIdjdPUYek0MYh9uh83otkl-7UoADCoWHWf57J8,1678
122
127
  bloqade/squin/op/__init__.py,sha256=5OBgT4E44Cy0DNF3yRbXGXkiB8VAJtr48x8hDQEquH4,741
123
128
  bloqade/squin/op/_dialect.py,sha256=66G1IYqmsqUEaCTyUqn2shSHmGYduiTU8GfDXcoMvw4,55
124
129
  bloqade/squin/op/_wrapper.py,sha256=5pqbKGeNWoYQIPa1xkKBr8z5waxwmAm3AV4efGSRT_s,1714
125
130
  bloqade/squin/op/number.py,sha256=yujWUqLrOAr8i8OBDsiS5M882wV7t08u345NgNA6TUc,95
126
131
  bloqade/squin/op/rewrite.py,sha256=Itxz_hTAPNLyLYeLS0PCVk143J1Z558UR7N9-urbnoU,1327
127
132
  bloqade/squin/op/stdlib.py,sha256=4UFK3wKImpums2v5a9OFKuVvz2TLYbYwidg3JYYEi2o,1073
128
- bloqade/squin/op/stmts.py,sha256=0QMB3YC_YLr9CfNQiNapSRRT7V5uU8_DVLy-u8Oq9qw,5761
133
+ bloqade/squin/op/stmts.py,sha256=uBQeCWc1JjMTIVcssWkLWi6MyTLtQMoqTWoKgLsYgWQ,5262
129
134
  bloqade/squin/op/traits.py,sha256=jjsnzWtPtmQK7K3H_D2fvc8XiW1Y3EMBcgeyPax2sjc,1065
130
- bloqade/squin/op/types.py,sha256=2am1KC9FHBl__5q_87hTXrSdx1KMkcEUBeZLSrBZQEw,617
135
+ bloqade/squin/op/types.py,sha256=ozUT0Bv9NuUxPjB2vAeqJ9cpdvUaBfP9trB5mybYxgc,663
131
136
  bloqade/squin/passes/__init__.py,sha256=Bhog-wZBtToNJXfhlYa6S7tE6OoppyRibjMl5JBfY58,45
132
137
  bloqade/squin/passes/stim.py,sha256=VWv3hhfizCWz5sIwwdFt3flWHLzG428evLGIcX8E36Y,1992
133
138
  bloqade/squin/rewrite/__init__.py,sha256=0-9m1cbvFRgjZpQ700NEjW1uKvwZPPbrmUwylhgOjUw,457
@@ -186,7 +191,7 @@ bloqade/visual/animation/runtime/atoms.py,sha256=EmjxhujLiHHPS_HtH_B-7TiqeHgvW5u
186
191
  bloqade/visual/animation/runtime/ppoly.py,sha256=JB9IP53N1w6adBJEue6J5Nmj818Id9JvrlgrmiQTU1I,1385
187
192
  bloqade/visual/animation/runtime/qpustate.py,sha256=rlmxQeJSvaohXrTpXQL5y-NJcpvfW33xPaYM1slv7cc,4270
188
193
  bloqade/visual/animation/runtime/utils.py,sha256=ju9IzOWX-vKwfpqUjlUKu3Ssr_UFPFFq-tzH_Nqyo_c,1212
189
- bloqade_circuit-0.4.0.dist-info/METADATA,sha256=3CjbxNGvQkBKgscX6HpdRqxkGyiAmDLgkgsAYovRvZc,3683
190
- bloqade_circuit-0.4.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
191
- bloqade_circuit-0.4.0.dist-info/licenses/LICENSE,sha256=S5GIJwR6QCixPA9wryYb44ZEek0Nz4rt_zLUqP05UbU,13160
192
- bloqade_circuit-0.4.0.dist-info/RECORD,,
194
+ bloqade_circuit-0.4.1.dist-info/METADATA,sha256=53teyQGCbiQ0-uxPqwBrNAKdeLcpe7h3Aj7fgVq26HU,3683
195
+ bloqade_circuit-0.4.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
196
+ bloqade_circuit-0.4.1.dist-info/licenses/LICENSE,sha256=S5GIJwR6QCixPA9wryYb44ZEek0Nz4rt_zLUqP05UbU,13160
197
+ bloqade_circuit-0.4.1.dist-info/RECORD,,