bloqade-circuit 0.6.8__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic. More details were linked from the original page (the link target is not preserved in this text).
- bloqade/analysis/measure_id/analysis.py +10 -11
- bloqade/analysis/measure_id/impls.py +15 -2
- bloqade/cirq_utils/noise/__init__.py +0 -2
- bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
- bloqade/cirq_utils/noise/model.py +141 -188
- bloqade/cirq_utils/noise/transform.py +2 -2
- bloqade/pyqrack/squin/qubit.py +4 -2
- bloqade/pyqrack/squin/runtime.py +14 -6
- bloqade/qasm2/emit/target.py +5 -1
- bloqade/squin/cirq/emit/op.py +37 -5
- bloqade/squin/cirq/emit/qubit.py +4 -4
- bloqade/squin/cirq/emit/runtime.py +0 -15
- bloqade/squin/cirq/lowering.py +3 -9
- bloqade/squin/gate.py +7 -0
- bloqade/squin/lowering.py +26 -0
- bloqade/squin/noise/__init__.py +0 -1
- bloqade/squin/noise/_wrapper.py +2 -6
- bloqade/squin/noise/rewrite.py +0 -11
- bloqade/squin/noise/stmts.py +2 -14
- bloqade/squin/op/_wrapper.py +4 -4
- bloqade/squin/op/stmts.py +33 -9
- bloqade/squin/op/types.py +104 -2
- bloqade/squin/qubit.py +27 -40
- bloqade/squin/rewrite/desugar.py +44 -66
- bloqade/stim/passes/squin_to_stim.py +21 -4
- bloqade/stim/rewrite/ifs_to_stim.py +6 -1
- bloqade/stim/rewrite/qubit_to_stim.py +1 -1
- bloqade/stim/rewrite/squin_noise.py +9 -7
- bloqade/stim/rewrite/util.py +15 -3
- {bloqade_circuit-0.6.8.dist-info → bloqade_circuit-0.7.1.dist-info}/METADATA +2 -2
- {bloqade_circuit-0.6.8.dist-info → bloqade_circuit-0.7.1.dist-info}/RECORD +33 -33
- {bloqade_circuit-0.6.8.dist-info → bloqade_circuit-0.7.1.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.6.8.dist-info → bloqade_circuit-0.7.1.dist-info}/licenses/LICENSE +0 -0
bloqade/squin/cirq/emit/op.py
CHANGED
|
@@ -2,6 +2,7 @@ import math
|
|
|
2
2
|
|
|
3
3
|
import cirq
|
|
4
4
|
import numpy as np
|
|
5
|
+
from kirin.emit import EmitError
|
|
5
6
|
from kirin.interp import MethodTable, impl
|
|
6
7
|
|
|
7
8
|
from ... import op
|
|
@@ -9,11 +10,11 @@ from .runtime import (
|
|
|
9
10
|
SnRuntime,
|
|
10
11
|
SpRuntime,
|
|
11
12
|
U3Runtime,
|
|
12
|
-
RotRuntime,
|
|
13
13
|
KronRuntime,
|
|
14
14
|
MultRuntime,
|
|
15
15
|
ScaleRuntime,
|
|
16
16
|
AdjointRuntime,
|
|
17
|
+
BasicOpRuntime,
|
|
17
18
|
ControlRuntime,
|
|
18
19
|
UnitaryRuntime,
|
|
19
20
|
HermitianRuntime,
|
|
@@ -117,7 +118,7 @@ class EmitCirqOpMethods(MethodTable):
|
|
|
117
118
|
|
|
118
119
|
@impl(op.stmts.Reset)
|
|
119
120
|
def reset(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Reset):
|
|
120
|
-
return (
|
|
121
|
+
return (BasicOpRuntime(cirq.ResetChannel()),)
|
|
121
122
|
|
|
122
123
|
@impl(op.stmts.PauliString)
|
|
123
124
|
def pauli_string(
|
|
@@ -127,11 +128,42 @@ class EmitCirqOpMethods(MethodTable):
|
|
|
127
128
|
|
|
128
129
|
@impl(op.stmts.Rot)
|
|
129
130
|
def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Rot):
|
|
130
|
-
|
|
131
|
+
axis: OperatorRuntimeABC = frame.get(stmt.axis)
|
|
132
|
+
|
|
133
|
+
if not isinstance(axis, HermitianRuntime):
|
|
134
|
+
raise EmitError(
|
|
135
|
+
f"Circuit emission only supported for Pauli operators! Got axis {axis}"
|
|
136
|
+
)
|
|
137
|
+
|
|
131
138
|
angle = frame.get(stmt.angle)
|
|
132
139
|
|
|
133
|
-
|
|
134
|
-
|
|
140
|
+
match axis.gate:
|
|
141
|
+
case cirq.X:
|
|
142
|
+
gate = cirq.Rx(rads=angle)
|
|
143
|
+
case cirq.Y:
|
|
144
|
+
gate = cirq.Ry(rads=angle)
|
|
145
|
+
case cirq.Z:
|
|
146
|
+
gate = cirq.Rz(rads=angle)
|
|
147
|
+
case _:
|
|
148
|
+
raise EmitError(
|
|
149
|
+
f"Circuit emission only supported for Pauli operators! Got axis {axis.gate}"
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
return (HermitianRuntime(gate=gate),)
|
|
153
|
+
|
|
154
|
+
@impl(op.stmts.ResetToOne)
|
|
155
|
+
def reset_to_one(
|
|
156
|
+
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ResetToOne
|
|
157
|
+
):
|
|
158
|
+
# NOTE: just apply a reset to 0 and flip in sequence (we re-use the multiplication runtime since it does exactly that)
|
|
159
|
+
gate1 = cirq.ResetChannel()
|
|
160
|
+
gate2 = cirq.X
|
|
161
|
+
|
|
162
|
+
rt1 = BasicOpRuntime(gate1)
|
|
163
|
+
rt2 = HermitianRuntime(gate2)
|
|
164
|
+
|
|
165
|
+
# NOTE: mind the order: rhs is applied first
|
|
166
|
+
return (MultRuntime(rt2, rt1),)
|
|
135
167
|
|
|
136
168
|
@impl(op.stmts.SqrtX)
|
|
137
169
|
def sqrt_x(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.SqrtX):
|
bloqade/squin/cirq/emit/qubit.py
CHANGED
|
@@ -25,7 +25,7 @@ class EmitCirqQubitMethods(MethodTable):
|
|
|
25
25
|
@impl(qubit.Apply)
|
|
26
26
|
def apply(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Apply):
|
|
27
27
|
op: OperatorRuntimeABC = frame.get(stmt.operator)
|
|
28
|
-
qbits = frame.get(stmt.qubits
|
|
28
|
+
qbits = [frame.get(qbit) for qbit in stmt.qubits]
|
|
29
29
|
operations = op.apply(qbits)
|
|
30
30
|
for operation in operations:
|
|
31
31
|
frame.circuit.append(operation)
|
|
@@ -34,11 +34,11 @@ class EmitCirqQubitMethods(MethodTable):
|
|
|
34
34
|
@impl(qubit.Broadcast)
|
|
35
35
|
def broadcast(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Broadcast):
|
|
36
36
|
op = frame.get(stmt.operator)
|
|
37
|
-
|
|
37
|
+
qbit_lists = [frame.get(qbit) for qbit in stmt.qubits]
|
|
38
38
|
|
|
39
39
|
cirq_ops = []
|
|
40
|
-
for
|
|
41
|
-
cirq_ops.extend(op.apply(
|
|
40
|
+
for qbits in zip(*qbit_lists):
|
|
41
|
+
cirq_ops.extend(op.apply(qbits))
|
|
42
42
|
|
|
43
43
|
frame.circuit.append(cirq.Moment(cirq_ops))
|
|
44
44
|
return ()
|
|
@@ -240,18 +240,3 @@ class PauliStringRuntime(OperatorRuntimeABC):
|
|
|
240
240
|
qbit: pauli_label for (qbit, pauli_label) in zip(qubits, self.string)
|
|
241
241
|
}
|
|
242
242
|
return [cirq.PauliString(pauli_mapping)]
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
@dataclass
|
|
246
|
-
class RotRuntime(OperatorRuntimeABC):
|
|
247
|
-
axis: str
|
|
248
|
-
angle: float
|
|
249
|
-
|
|
250
|
-
def num_qubits(self) -> int:
|
|
251
|
-
return 1
|
|
252
|
-
|
|
253
|
-
def unsafe_apply(
|
|
254
|
-
self, qubits: Sequence[cirq.Qid], adjoint: bool = False
|
|
255
|
-
) -> list[cirq.Operation]:
|
|
256
|
-
rot = getattr(cirq, "R" + self.axis.lower())(rads=self.angle)
|
|
257
|
-
return [rot(*qubits)]
|
bloqade/squin/cirq/lowering.py
CHANGED
|
@@ -44,14 +44,7 @@ class Squin(lowering.LoweringABC[CirqNode]):
|
|
|
44
44
|
self, state: lowering.State[CirqNode], qids: list[cirq.Qid]
|
|
45
45
|
):
|
|
46
46
|
qbits_getitem = [self.lower_qubit_getindex(state, qid) for qid in qids]
|
|
47
|
-
|
|
48
|
-
qbits_result = state.current_frame.get(qbits_stmt.name)
|
|
49
|
-
|
|
50
|
-
if qbits_result is not None:
|
|
51
|
-
return qbits_result
|
|
52
|
-
|
|
53
|
-
state.current_frame.push(qbits_stmt)
|
|
54
|
-
return qbits_stmt.result
|
|
47
|
+
return tuple(qbits_getitem)
|
|
55
48
|
|
|
56
49
|
def run(
|
|
57
50
|
self,
|
|
@@ -159,7 +152,8 @@ class Squin(lowering.LoweringABC[CirqNode]):
|
|
|
159
152
|
stmt = state.current_frame.push(qubit.MeasureQubit(qbit))
|
|
160
153
|
else:
|
|
161
154
|
qbits = self.lower_qubit_getindices(state, node.qubits)
|
|
162
|
-
|
|
155
|
+
qbits_list = state.current_frame.push(ilist.New(values=qbits))
|
|
156
|
+
stmt = state.current_frame.push(qubit.MeasureQubitList(qbits_list.result))
|
|
163
157
|
|
|
164
158
|
key = node.gate.key
|
|
165
159
|
if isinstance(key, cirq.MeasurementKey):
|
bloqade/squin/gate.py
CHANGED
|
@@ -137,6 +137,13 @@ def reset(qubit: Qubit) -> None:
|
|
|
137
137
|
_qubit.apply(op, qubit)
|
|
138
138
|
|
|
139
139
|
|
|
140
|
+
@kernel
|
|
141
|
+
def reset_to_one(qubit: Qubit) -> None:
|
|
142
|
+
"""Reset qubit to 1."""
|
|
143
|
+
op = _op.reset_to_one()
|
|
144
|
+
_qubit.apply(op, qubit)
|
|
145
|
+
|
|
146
|
+
|
|
140
147
|
@kernel
|
|
141
148
|
def cx(control: Qubit, target: Qubit) -> None:
|
|
142
149
|
"""Controlled x gate applied to control and target"""
|
bloqade/squin/lowering.py
CHANGED
|
@@ -52,3 +52,29 @@ class ApplyAnyCallLowering(lowering.FromPythonCall["qubit.ApplyAny"]):
|
|
|
52
52
|
return op, qubits.elts
|
|
53
53
|
|
|
54
54
|
return op, [qubits]
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass(frozen=True)
|
|
58
|
+
class BroadcastCallLowering(lowering.FromPythonCall["qubit.Broadcast"]):
|
|
59
|
+
"""
|
|
60
|
+
Custom lowering for broadcast vararg call.
|
|
61
|
+
|
|
62
|
+
NOTE: we can re-use this to lower Apply too once we remove the deprecated syntax
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
def lower(
|
|
66
|
+
self, stmt: type["qubit.Broadcast"], state: lowering.State, node: ast.Call
|
|
67
|
+
):
|
|
68
|
+
if len(node.args) < 2:
|
|
69
|
+
raise lowering.BuildError(
|
|
70
|
+
"Broadcast requires at least one operator and one qubit list argument"
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
op, *qubit_lists = node.args
|
|
74
|
+
|
|
75
|
+
op_lowered = state.lower(op).expect_one()
|
|
76
|
+
qubits_lists_lowered = [
|
|
77
|
+
state.lower(qubit_list).expect_one() for qubit_list in qubit_lists
|
|
78
|
+
]
|
|
79
|
+
|
|
80
|
+
return state.current_frame.push(stmt(op_lowered, tuple(qubits_lists_lowered)))
|
bloqade/squin/noise/__init__.py
CHANGED
bloqade/squin/noise/_wrapper.py
CHANGED
|
@@ -3,17 +3,13 @@ from typing import Literal
|
|
|
3
3
|
from kirin.dialects import ilist
|
|
4
4
|
from kirin.lowering import wraps
|
|
5
5
|
|
|
6
|
-
from bloqade.squin.op.types import Op
|
|
6
|
+
from bloqade.squin.op.types import Op, MultiQubitPauliOp
|
|
7
7
|
|
|
8
8
|
from . import stmts
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
@wraps(stmts.PauliError)
|
|
12
|
-
def pauli_error(basis:
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
@wraps(stmts.PPError)
|
|
16
|
-
def pp_error(op: Op, p: float) -> Op: ...
|
|
12
|
+
def pauli_error(basis: MultiQubitPauliOp, p: float) -> Op: ...
|
|
17
13
|
|
|
18
14
|
|
|
19
15
|
@wraps(stmts.Depolarize)
|
bloqade/squin/noise/rewrite.py
CHANGED
|
@@ -7,7 +7,6 @@ from kirin.dialects import py, ilist
|
|
|
7
7
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
8
8
|
|
|
9
9
|
from .stmts import (
|
|
10
|
-
PPError,
|
|
11
10
|
QubitLoss,
|
|
12
11
|
Depolarize,
|
|
13
12
|
PauliError,
|
|
@@ -86,16 +85,6 @@ class _RewriteNoiseStmts(RewriteRule):
|
|
|
86
85
|
(operator_list := ilist.New(values=operators)).insert_before(node)
|
|
87
86
|
return operator_list.result
|
|
88
87
|
|
|
89
|
-
def rewrite_p_p_error(self, node: PPError) -> RewriteResult:
|
|
90
|
-
(operators := ilist.New(values=(node.op,))).insert_before(node)
|
|
91
|
-
(ps := ilist.New(values=(node.p,))).insert_before(node)
|
|
92
|
-
stochastic_channel = StochasticUnitaryChannel(
|
|
93
|
-
operators=operators.result, probabilities=ps.result
|
|
94
|
-
)
|
|
95
|
-
|
|
96
|
-
node.replace_by(stochastic_channel)
|
|
97
|
-
return RewriteResult(has_done_something=True)
|
|
98
|
-
|
|
99
88
|
def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
|
|
100
89
|
paulis = (X(), Y(), Z())
|
|
101
90
|
operators: list[ir.SSAValue] = []
|
bloqade/squin/noise/stmts.py
CHANGED
|
@@ -2,10 +2,8 @@ from kirin import ir, types, lowering
|
|
|
2
2
|
from kirin.decl import info, statement
|
|
3
3
|
from kirin.dialects import ilist
|
|
4
4
|
|
|
5
|
-
from bloqade.squin.op.types import OpType
|
|
6
|
-
|
|
7
5
|
from ._dialect import dialect
|
|
8
|
-
from ..op.types import NumOperators
|
|
6
|
+
from ..op.types import OpType, NumOperators, MultiQubitPauliOpType
|
|
9
7
|
|
|
10
8
|
|
|
11
9
|
@statement
|
|
@@ -16,17 +14,7 @@ class NoiseChannel(ir.Statement):
|
|
|
16
14
|
|
|
17
15
|
@statement(dialect=dialect)
|
|
18
16
|
class PauliError(NoiseChannel):
|
|
19
|
-
basis: ir.SSAValue = info.argument(
|
|
20
|
-
p: ir.SSAValue = info.argument(types.Float)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
@statement(dialect=dialect)
|
|
24
|
-
class PPError(NoiseChannel):
|
|
25
|
-
"""
|
|
26
|
-
Pauli Product Error
|
|
27
|
-
"""
|
|
28
|
-
|
|
29
|
-
op: ir.SSAValue = info.argument(OpType)
|
|
17
|
+
basis: ir.SSAValue = info.argument(MultiQubitPauliOpType)
|
|
30
18
|
p: ir.SSAValue = info.argument(types.Float)
|
|
31
19
|
|
|
32
20
|
|
bloqade/squin/op/_wrapper.py
CHANGED
|
@@ -62,15 +62,15 @@ def phase(theta: float) -> types.Op: ...
|
|
|
62
62
|
|
|
63
63
|
|
|
64
64
|
@wraps(stmts.X)
|
|
65
|
-
def x() -> types.
|
|
65
|
+
def x() -> types.PauliOp: ...
|
|
66
66
|
|
|
67
67
|
|
|
68
68
|
@wraps(stmts.Y)
|
|
69
|
-
def y() -> types.
|
|
69
|
+
def y() -> types.PauliOp: ...
|
|
70
70
|
|
|
71
71
|
|
|
72
72
|
@wraps(stmts.Z)
|
|
73
|
-
def z() -> types.
|
|
73
|
+
def z() -> types.PauliOp: ...
|
|
74
74
|
|
|
75
75
|
|
|
76
76
|
@wraps(stmts.SqrtX)
|
|
@@ -118,4 +118,4 @@ def u(theta: float, phi: float, lam: float) -> types.Op: ...
|
|
|
118
118
|
|
|
119
119
|
|
|
120
120
|
@wraps(stmts.PauliString)
|
|
121
|
-
def pauli_string(*, string: str) -> types.
|
|
121
|
+
def pauli_string(*, string: str) -> types.PauliStringOp: ...
|
bloqade/squin/op/stmts.py
CHANGED
|
@@ -1,7 +1,19 @@
|
|
|
1
1
|
from kirin import ir, types, lowering
|
|
2
2
|
from kirin.decl import info, statement
|
|
3
3
|
|
|
4
|
-
from .types import
|
|
4
|
+
from .types import (
|
|
5
|
+
OpType,
|
|
6
|
+
ROpType,
|
|
7
|
+
XOpType,
|
|
8
|
+
YOpType,
|
|
9
|
+
ZOpType,
|
|
10
|
+
KronType,
|
|
11
|
+
MultType,
|
|
12
|
+
PauliOpType,
|
|
13
|
+
ControlOpType,
|
|
14
|
+
PauliStringType,
|
|
15
|
+
ControlledOpType,
|
|
16
|
+
)
|
|
5
17
|
from .number import NumberType
|
|
6
18
|
from .traits import Unitary, HasSites, FixedSites, MaybeUnitary
|
|
7
19
|
from ._dialect import dialect
|
|
@@ -22,22 +34,28 @@ class CompositeOp(Operator):
|
|
|
22
34
|
pass
|
|
23
35
|
|
|
24
36
|
|
|
37
|
+
LhsType = types.TypeVar("Lhs", bound=OpType)
|
|
38
|
+
RhsType = types.TypeVar("Rhs", bound=OpType)
|
|
39
|
+
|
|
40
|
+
|
|
25
41
|
@statement
|
|
26
42
|
class BinaryOp(CompositeOp):
|
|
27
|
-
lhs: ir.SSAValue = info.argument(
|
|
28
|
-
rhs: ir.SSAValue = info.argument(
|
|
43
|
+
lhs: ir.SSAValue = info.argument(LhsType)
|
|
44
|
+
rhs: ir.SSAValue = info.argument(RhsType)
|
|
29
45
|
|
|
30
46
|
|
|
31
47
|
@statement(dialect=dialect)
|
|
32
48
|
class Kron(BinaryOp):
|
|
33
49
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
|
|
34
50
|
is_unitary: bool = info.attribute(default=False)
|
|
51
|
+
result: ir.ResultValue = info.result(KronType[LhsType, RhsType])
|
|
35
52
|
|
|
36
53
|
|
|
37
54
|
@statement(dialect=dialect)
|
|
38
55
|
class Mult(BinaryOp):
|
|
39
56
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
|
|
40
57
|
is_unitary: bool = info.attribute(default=False)
|
|
58
|
+
result: ir.ResultValue = info.result(MultType[LhsType, RhsType])
|
|
41
59
|
|
|
42
60
|
|
|
43
61
|
@statement(dialect=dialect)
|
|
@@ -59,15 +77,20 @@ class Scale(CompositeOp):
|
|
|
59
77
|
class Control(CompositeOp):
|
|
60
78
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
|
|
61
79
|
is_unitary: bool = info.attribute(default=False)
|
|
62
|
-
op: ir.SSAValue = info.argument(
|
|
80
|
+
op: ir.SSAValue = info.argument(ControlledOpType)
|
|
63
81
|
n_controls: int = info.attribute()
|
|
82
|
+
result: ir.ResultValue = info.result(ControlOpType[ControlledOpType])
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
RotationAxisType = types.TypeVar("RotationAxis", bound=OpType)
|
|
64
86
|
|
|
65
87
|
|
|
66
88
|
@statement(dialect=dialect)
|
|
67
89
|
class Rot(CompositeOp):
|
|
68
90
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary()})
|
|
69
|
-
axis: ir.SSAValue = info.argument(
|
|
91
|
+
axis: ir.SSAValue = info.argument(RotationAxisType)
|
|
70
92
|
angle: ir.SSAValue = info.argument(types.Float)
|
|
93
|
+
result: ir.ResultValue = info.result(ROpType[RotationAxisType])
|
|
71
94
|
|
|
72
95
|
|
|
73
96
|
@statement(dialect=dialect)
|
|
@@ -166,13 +189,14 @@ class CliffordOp(ConstantUnitary):
|
|
|
166
189
|
|
|
167
190
|
@statement
|
|
168
191
|
class PauliOp(CliffordOp):
|
|
169
|
-
|
|
192
|
+
result: ir.ResultValue = info.result(type=PauliOpType)
|
|
170
193
|
|
|
171
194
|
|
|
172
195
|
@statement(dialect=dialect)
|
|
173
196
|
class PauliString(ConstantUnitary):
|
|
174
197
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()})
|
|
175
198
|
string: str = info.attribute()
|
|
199
|
+
result: ir.ResultValue = info.result(type=PauliStringType)
|
|
176
200
|
|
|
177
201
|
def verify(self) -> None:
|
|
178
202
|
if not set("XYZ").issuperset(self.string):
|
|
@@ -183,17 +207,17 @@ class PauliString(ConstantUnitary):
|
|
|
183
207
|
|
|
184
208
|
@statement(dialect=dialect)
|
|
185
209
|
class X(PauliOp):
|
|
186
|
-
|
|
210
|
+
result: ir.ResultValue = info.result(XOpType)
|
|
187
211
|
|
|
188
212
|
|
|
189
213
|
@statement(dialect=dialect)
|
|
190
214
|
class Y(PauliOp):
|
|
191
|
-
|
|
215
|
+
result: ir.ResultValue = info.result(YOpType)
|
|
192
216
|
|
|
193
217
|
|
|
194
218
|
@statement(dialect=dialect)
|
|
195
219
|
class Z(PauliOp):
|
|
196
|
-
|
|
220
|
+
result: ir.ResultValue = info.result(ZOpType)
|
|
197
221
|
|
|
198
222
|
|
|
199
223
|
@statement(dialect=dialect)
|
bloqade/squin/op/types.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import overload
|
|
1
|
+
from typing import Generic, TypeVar, overload
|
|
2
2
|
|
|
3
3
|
from kirin import types
|
|
4
4
|
|
|
@@ -23,4 +23,106 @@ class Op:
|
|
|
23
23
|
|
|
24
24
|
OpType = types.PyClass(Op)
|
|
25
25
|
|
|
26
|
-
|
|
26
|
+
|
|
27
|
+
class CompositeOp(Op):
|
|
28
|
+
pass
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
CompositeOpType = types.PyClass(CompositeOp)
|
|
32
|
+
|
|
33
|
+
LhsType = TypeVar("LhsType", bound=Op)
|
|
34
|
+
RhsType = TypeVar("RhsType", bound=Op)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class BinaryOp(Op, Generic[LhsType, RhsType]):
|
|
38
|
+
lhs: LhsType
|
|
39
|
+
rhs: RhsType
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
BinaryOpType = types.Generic(BinaryOp, OpType, OpType)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class Mult(BinaryOp[LhsType, RhsType]):
|
|
46
|
+
pass
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
MultType = types.Generic(Mult, OpType, OpType)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class Kron(BinaryOp[LhsType, RhsType]):
|
|
53
|
+
pass
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
KronType = types.Generic(Kron, OpType, OpType)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class MultiQubitPauliOp(Op):
|
|
60
|
+
pass
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
MultiQubitPauliOpType = types.PyClass(MultiQubitPauliOp)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class PauliStringOp(MultiQubitPauliOp):
|
|
67
|
+
pass
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
PauliStringType = types.PyClass(PauliStringOp)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class PauliOp(MultiQubitPauliOp):
|
|
74
|
+
pass
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
PauliOpType = types.PyClass(PauliOp)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class XOp(PauliOp):
|
|
81
|
+
pass
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
XOpType = types.PyClass(XOp)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class YOp(PauliOp):
|
|
88
|
+
pass
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
YOpType = types.PyClass(YOp)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class ZOp(PauliOp):
|
|
95
|
+
pass
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
ZOpType = types.PyClass(ZOp)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
ControlledOp = TypeVar("ControlledOp", bound=Op)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class ControlOp(CompositeOp, Generic[ControlledOp]):
|
|
105
|
+
op: ControlledOp
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
ControlledOpType = types.TypeVar("ControlledOp", bound=OpType)
|
|
109
|
+
ControlOpType = types.Generic(ControlOp, ControlledOpType)
|
|
110
|
+
CXOpType = ControlOpType[XOpType]
|
|
111
|
+
CYOpType = ControlOpType[YOpType]
|
|
112
|
+
CZOpType = ControlOpType[ZOpType]
|
|
113
|
+
|
|
114
|
+
RotationAxis = TypeVar("RotationAxis", bound=Op)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class ROp(CompositeOp, Generic[RotationAxis]):
|
|
118
|
+
axis: RotationAxis
|
|
119
|
+
angle: float
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
ROpType = types.Generic(ROp, OpType)
|
|
123
|
+
RxOpType = ROpType[XOpType]
|
|
124
|
+
RyOpType = ROpType[YOpType]
|
|
125
|
+
RzOpType = ROpType[ZOpType]
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
NumOperators = types.TypeVar("NumOperators", bound=types.Int)
|
bloqade/squin/qubit.py
CHANGED
|
@@ -7,7 +7,7 @@ Depends on:
|
|
|
7
7
|
- `kirin.dialects.ilist`: provides the `ilist.IListType` type for lists of qubits.
|
|
8
8
|
"""
|
|
9
9
|
|
|
10
|
-
from typing import Any, overload
|
|
10
|
+
from typing import Any, TypeVar, overload
|
|
11
11
|
|
|
12
12
|
from kirin import ir, types, lowering
|
|
13
13
|
from kirin.decl import info, statement
|
|
@@ -18,7 +18,7 @@ from bloqade.types import Qubit, QubitType
|
|
|
18
18
|
from bloqade.squin.op.types import Op, OpType
|
|
19
19
|
|
|
20
20
|
from .types import MeasurementResult, MeasurementResultType
|
|
21
|
-
from .lowering import ApplyAnyCallLowering
|
|
21
|
+
from .lowering import ApplyAnyCallLowering, BroadcastCallLowering
|
|
22
22
|
|
|
23
23
|
dialect = ir.Dialect("squin.qubit")
|
|
24
24
|
|
|
@@ -34,7 +34,7 @@ class New(ir.Statement):
|
|
|
34
34
|
class Apply(ir.Statement):
|
|
35
35
|
traits = frozenset({lowering.FromPythonCall()})
|
|
36
36
|
operator: ir.SSAValue = info.argument(OpType)
|
|
37
|
-
qubits: ir.SSAValue = info.argument(
|
|
37
|
+
qubits: tuple[ir.SSAValue, ...] = info.argument(QubitType)
|
|
38
38
|
|
|
39
39
|
|
|
40
40
|
@statement(dialect=dialect)
|
|
@@ -47,9 +47,9 @@ class ApplyAny(ir.Statement):
|
|
|
47
47
|
|
|
48
48
|
@statement(dialect=dialect)
|
|
49
49
|
class Broadcast(ir.Statement):
|
|
50
|
-
traits = frozenset({
|
|
50
|
+
traits = frozenset({BroadcastCallLowering()})
|
|
51
51
|
operator: ir.SSAValue = info.argument(OpType)
|
|
52
|
-
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
|
|
52
|
+
qubits: tuple[ir.SSAValue, ...] = info.argument(ilist.IListType[QubitType])
|
|
53
53
|
|
|
54
54
|
|
|
55
55
|
@statement(dialect=dialect)
|
|
@@ -93,26 +93,10 @@ def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
|
|
|
93
93
|
...
|
|
94
94
|
|
|
95
95
|
|
|
96
|
-
@
|
|
97
|
-
def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
|
|
98
|
-
"""Apply an operator to a list of qubits.
|
|
99
|
-
|
|
100
|
-
Note, that when considering atom loss, lost qubits will be skipped.
|
|
101
|
-
|
|
102
|
-
Args:
|
|
103
|
-
operator: The operator to apply.
|
|
104
|
-
qubits: The list of qubits to apply the operator to. The size of the list
|
|
105
|
-
must be inferable and match the number of qubits expected by the operator.
|
|
106
|
-
|
|
107
|
-
Returns:
|
|
108
|
-
None
|
|
109
|
-
"""
|
|
110
|
-
...
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
@overload
|
|
96
|
+
@wraps(ApplyAny)
|
|
114
97
|
def apply(operator: Op, *qubits: Qubit) -> None:
|
|
115
|
-
"""Apply
|
|
98
|
+
"""Apply an operator to qubits. The number of qubit arguments must match the
|
|
99
|
+
size of the operator.
|
|
116
100
|
|
|
117
101
|
Note, that when considering atom loss, lost qubits will be skipped.
|
|
118
102
|
|
|
@@ -127,10 +111,6 @@ def apply(operator: Op, *qubits: Qubit) -> None:
|
|
|
127
111
|
...
|
|
128
112
|
|
|
129
113
|
|
|
130
|
-
@wraps(ApplyAny)
|
|
131
|
-
def apply(operator: Op, *qubits) -> None: ...
|
|
132
|
-
|
|
133
|
-
|
|
134
114
|
@overload
|
|
135
115
|
def measure(input: Qubit) -> MeasurementResult: ...
|
|
136
116
|
@overload
|
|
@@ -154,23 +134,30 @@ def measure(input: Any) -> Any:
|
|
|
154
134
|
...
|
|
155
135
|
|
|
156
136
|
|
|
137
|
+
OpSize = TypeVar("OpSize")
|
|
138
|
+
|
|
139
|
+
|
|
157
140
|
@wraps(Broadcast)
|
|
158
|
-
def broadcast(operator: Op, qubits: ilist.IList[Qubit,
|
|
159
|
-
"""Broadcast and apply an operator to
|
|
160
|
-
|
|
141
|
+
def broadcast(operator: Op, *qubits: ilist.IList[Qubit, OpSize] | list[Qubit]) -> None:
|
|
142
|
+
"""Broadcast and apply an operator to lists of qubits. The number of qubit lists must
|
|
143
|
+
match the size of the operator and the lists must be of same length. The operator is
|
|
144
|
+
then applied to the list elements similar to what python's map function does.
|
|
161
145
|
|
|
162
|
-
|
|
163
|
-
For example
|
|
146
|
+
## Usage examples
|
|
164
147
|
|
|
165
|
-
```
|
|
166
|
-
|
|
167
|
-
```
|
|
148
|
+
```python
|
|
149
|
+
from bloqade import squin
|
|
168
150
|
|
|
169
|
-
|
|
151
|
+
@squin.kernel
|
|
152
|
+
def ghz():
|
|
153
|
+
controls = squin.qubit.new(4)
|
|
154
|
+
targets = squin.qubit.new(4)
|
|
170
155
|
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
156
|
+
h = squin.op.h()
|
|
157
|
+
squin.qubit.broadcast(h, controls)
|
|
158
|
+
|
|
159
|
+
cx = squin.op.cx()
|
|
160
|
+
squin.qubit.broadcast(cx, controls, targets)
|
|
174
161
|
```
|
|
175
162
|
|
|
176
163
|
Args:
|