bloqade-circuit 0.6.7__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit has been flagged by the registry as potentially problematic; see the package's registry page for details.
- bloqade/analysis/measure_id/analysis.py +10 -11
- bloqade/analysis/measure_id/impls.py +15 -2
- bloqade/cirq_utils/noise/__init__.py +0 -2
- bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
- bloqade/cirq_utils/noise/model.py +141 -188
- bloqade/cirq_utils/noise/transform.py +2 -2
- bloqade/pyqrack/device.py +96 -1
- bloqade/pyqrack/squin/qubit.py +4 -2
- bloqade/pyqrack/squin/runtime.py +14 -6
- bloqade/pyqrack/task.py +15 -0
- bloqade/rewrite/rules/split_ifs.py +18 -1
- bloqade/squin/cirq/emit/op.py +52 -1
- bloqade/squin/cirq/emit/qubit.py +4 -4
- bloqade/squin/cirq/lowering.py +3 -9
- bloqade/squin/gate.py +7 -0
- bloqade/squin/lowering.py +26 -0
- bloqade/squin/noise/__init__.py +0 -1
- bloqade/squin/noise/_wrapper.py +2 -6
- bloqade/squin/noise/rewrite.py +0 -11
- bloqade/squin/noise/stmts.py +2 -14
- bloqade/squin/op/_wrapper.py +4 -4
- bloqade/squin/op/stmts.py +33 -9
- bloqade/squin/op/types.py +104 -2
- bloqade/squin/qubit.py +27 -40
- bloqade/squin/rewrite/desugar.py +44 -66
- bloqade/stim/passes/squin_to_stim.py +21 -4
- bloqade/stim/rewrite/ifs_to_stim.py +6 -1
- bloqade/stim/rewrite/qubit_to_stim.py +1 -1
- bloqade/stim/rewrite/squin_noise.py +9 -7
- bloqade/stim/rewrite/util.py +15 -3
- {bloqade_circuit-0.6.7.dist-info → bloqade_circuit-0.7.0.dist-info}/METADATA +1 -1
- {bloqade_circuit-0.6.7.dist-info → bloqade_circuit-0.7.0.dist-info}/RECORD +34 -34
- {bloqade_circuit-0.6.7.dist-info → bloqade_circuit-0.7.0.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.6.7.dist-info → bloqade_circuit-0.7.0.dist-info}/licenses/LICENSE +0 -0
bloqade/pyqrack/squin/qubit.py
CHANGED
|
@@ -25,7 +25,7 @@ class PyQrackMethods(interp.MethodTable):
|
|
|
25
25
|
|
|
26
26
|
@interp.impl(qubit.Apply)
|
|
27
27
|
def apply(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Apply):
|
|
28
|
-
qubits:
|
|
28
|
+
qubits: list[PyQrackQubit] = [frame.get(qbit) for qbit in stmt.qubits]
|
|
29
29
|
operator: OperatorRuntimeABC = frame.get(stmt.operator)
|
|
30
30
|
operator.apply(*qubits)
|
|
31
31
|
|
|
@@ -34,7 +34,9 @@ class PyQrackMethods(interp.MethodTable):
|
|
|
34
34
|
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Broadcast
|
|
35
35
|
):
|
|
36
36
|
operator: OperatorRuntimeABC = frame.get(stmt.operator)
|
|
37
|
-
qubits: ilist.IList[PyQrackQubit, Any] =
|
|
37
|
+
qubits: list[ilist.IList[PyQrackQubit, Any]] = [
|
|
38
|
+
frame.get(qbit) for qbit in stmt.qubits
|
|
39
|
+
]
|
|
38
40
|
operator.broadcast_apply(qubits)
|
|
39
41
|
|
|
40
42
|
def _measure_qubit(self, qbit: PyQrackQubit, interp: PyQrackInterpreter):
|
bloqade/pyqrack/squin/runtime.py
CHANGED
|
@@ -28,17 +28,25 @@ class OperatorRuntimeABC:
|
|
|
28
28
|
) -> None:
|
|
29
29
|
raise RuntimeError(f"Can't apply controlled version of {self}")
|
|
30
30
|
|
|
31
|
-
def broadcast_apply(
|
|
31
|
+
def broadcast_apply(
|
|
32
|
+
self, qubit_lists: list[ilist.IList[PyQrackQubit, Any]], **kwargs
|
|
33
|
+
) -> None:
|
|
32
34
|
n = self.n_sites
|
|
33
35
|
|
|
34
|
-
if
|
|
36
|
+
if n != len(qubit_lists):
|
|
35
37
|
raise RuntimeError(
|
|
36
|
-
f"Cannot
|
|
38
|
+
f"Cannot apply operator of size {n} to {len(qubit_lists)} qubits!"
|
|
37
39
|
)
|
|
38
40
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
41
|
+
m = len(qubit_lists[0])
|
|
42
|
+
for qubit_list in qubit_lists:
|
|
43
|
+
if m != len(qubit_list):
|
|
44
|
+
raise RuntimeError(
|
|
45
|
+
"Cannot broadcast operator on qubit lists of varying length!"
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
for qubits in zip(*qubit_lists):
|
|
49
|
+
self.apply(*qubits, **kwargs)
|
|
42
50
|
|
|
43
51
|
|
|
44
52
|
@dataclass(frozen=True)
|
bloqade/pyqrack/task.py
CHANGED
|
@@ -2,6 +2,7 @@ from typing import TypeVar, ParamSpec, cast
|
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
|
|
4
4
|
from bloqade.task import AbstractSimulatorTask
|
|
5
|
+
from bloqade.pyqrack.reg import QubitState, PyQrackQubit
|
|
5
6
|
from bloqade.pyqrack.base import (
|
|
6
7
|
MemoryABC,
|
|
7
8
|
PyQrackInterpreter,
|
|
@@ -36,3 +37,17 @@ class PyQrackSimulatorTask(AbstractSimulatorTask[Param, RetType, MemoryType]):
|
|
|
36
37
|
"""Returns the state vector of the simulator."""
|
|
37
38
|
self.run()
|
|
38
39
|
return self.state.sim_reg.out_ket()
|
|
40
|
+
|
|
41
|
+
def qubits(self) -> list[PyQrackQubit]:
|
|
42
|
+
"""Returns the qubits in the simulator."""
|
|
43
|
+
try:
|
|
44
|
+
N = self.state.sim_reg.num_qubits()
|
|
45
|
+
return [
|
|
46
|
+
PyQrackQubit(
|
|
47
|
+
addr=i, sim_reg=self.state.sim_reg, state=QubitState.Active
|
|
48
|
+
)
|
|
49
|
+
for i in range(N)
|
|
50
|
+
]
|
|
51
|
+
except AttributeError:
|
|
52
|
+
Warning("Task has not been run, there are no qubits!")
|
|
53
|
+
return []
|
|
@@ -46,9 +46,13 @@ class SplitIfStmts(RewriteRule):
|
|
|
46
46
|
if not isinstance(node, scf.IfElse):
|
|
47
47
|
return RewriteResult()
|
|
48
48
|
|
|
49
|
+
# NOTE: only empty else bodies are allowed in valid QASM2
|
|
50
|
+
if not self._has_empty_else(node):
|
|
51
|
+
return RewriteResult()
|
|
52
|
+
|
|
49
53
|
*stmts, yield_or_return = node.then_body.stmts()
|
|
50
54
|
|
|
51
|
-
if len(stmts)
|
|
55
|
+
if len(stmts) <= 1:
|
|
52
56
|
return RewriteResult()
|
|
53
57
|
|
|
54
58
|
is_yield = isinstance(yield_or_return, scf.Yield)
|
|
@@ -71,3 +75,16 @@ class SplitIfStmts(RewriteRule):
|
|
|
71
75
|
node.delete()
|
|
72
76
|
|
|
73
77
|
return RewriteResult(has_done_something=True)
|
|
78
|
+
|
|
79
|
+
def _has_empty_else(self, node: scf.IfElse) -> bool:
|
|
80
|
+
else_stmts = list(node.else_body.stmts())
|
|
81
|
+
if len(else_stmts) > 1:
|
|
82
|
+
return False
|
|
83
|
+
|
|
84
|
+
if len(else_stmts) == 0:
|
|
85
|
+
return True
|
|
86
|
+
|
|
87
|
+
if not isinstance(else_stmts[0], scf.Yield):
|
|
88
|
+
return False
|
|
89
|
+
|
|
90
|
+
return len(else_stmts[0].values) == 0
|
bloqade/squin/cirq/emit/op.py
CHANGED
|
@@ -2,6 +2,7 @@ import math
|
|
|
2
2
|
|
|
3
3
|
import cirq
|
|
4
4
|
import numpy as np
|
|
5
|
+
from kirin.emit import EmitError
|
|
5
6
|
from kirin.interp import MethodTable, impl
|
|
6
7
|
|
|
7
8
|
from ... import op
|
|
@@ -13,6 +14,7 @@ from .runtime import (
|
|
|
13
14
|
MultRuntime,
|
|
14
15
|
ScaleRuntime,
|
|
15
16
|
AdjointRuntime,
|
|
17
|
+
BasicOpRuntime,
|
|
16
18
|
ControlRuntime,
|
|
17
19
|
UnitaryRuntime,
|
|
18
20
|
HermitianRuntime,
|
|
@@ -116,10 +118,59 @@ class EmitCirqOpMethods(MethodTable):
|
|
|
116
118
|
|
|
117
119
|
@impl(op.stmts.Reset)
|
|
118
120
|
def reset(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Reset):
|
|
119
|
-
return (
|
|
121
|
+
return (BasicOpRuntime(cirq.ResetChannel()),)
|
|
120
122
|
|
|
121
123
|
@impl(op.stmts.PauliString)
|
|
122
124
|
def pauli_string(
|
|
123
125
|
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PauliString
|
|
124
126
|
):
|
|
125
127
|
return (PauliStringRuntime(stmt.string),)
|
|
128
|
+
|
|
129
|
+
@impl(op.stmts.Rot)
|
|
130
|
+
def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Rot):
|
|
131
|
+
axis: OperatorRuntimeABC = frame.get(stmt.axis)
|
|
132
|
+
|
|
133
|
+
if not isinstance(axis, HermitianRuntime):
|
|
134
|
+
raise EmitError(
|
|
135
|
+
f"Circuit emission only supported for Pauli operators! Got axis {axis}"
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
angle = frame.get(stmt.angle)
|
|
139
|
+
|
|
140
|
+
match axis.gate:
|
|
141
|
+
case cirq.X:
|
|
142
|
+
gate = cirq.Rx(rads=angle)
|
|
143
|
+
case cirq.Y:
|
|
144
|
+
gate = cirq.Ry(rads=angle)
|
|
145
|
+
case cirq.Z:
|
|
146
|
+
gate = cirq.Rz(rads=angle)
|
|
147
|
+
case _:
|
|
148
|
+
raise EmitError(
|
|
149
|
+
f"Circuit emission only supported for Pauli operators! Got axis {axis.gate}"
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
return (HermitianRuntime(gate=gate),)
|
|
153
|
+
|
|
154
|
+
@impl(op.stmts.ResetToOne)
|
|
155
|
+
def reset_to_one(
|
|
156
|
+
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ResetToOne
|
|
157
|
+
):
|
|
158
|
+
# NOTE: just apply a reset to 0 and flip in sequence (we re-use the multiplication runtime since it does exactly that)
|
|
159
|
+
gate1 = cirq.ResetChannel()
|
|
160
|
+
gate2 = cirq.X
|
|
161
|
+
|
|
162
|
+
rt1 = BasicOpRuntime(gate1)
|
|
163
|
+
rt2 = HermitianRuntime(gate2)
|
|
164
|
+
|
|
165
|
+
# NOTE: mind the order: rhs is applied first
|
|
166
|
+
return (MultRuntime(rt2, rt1),)
|
|
167
|
+
|
|
168
|
+
@impl(op.stmts.SqrtX)
|
|
169
|
+
def sqrt_x(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.SqrtX):
|
|
170
|
+
cirq_op = cirq.XPowGate(exponent=0.5)
|
|
171
|
+
return (UnitaryRuntime(cirq_op),)
|
|
172
|
+
|
|
173
|
+
@impl(op.stmts.SqrtY)
|
|
174
|
+
def sqrt_y(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.SqrtY):
|
|
175
|
+
cirq_op = cirq.YPowGate(exponent=0.5)
|
|
176
|
+
return (UnitaryRuntime(cirq_op),)
|
bloqade/squin/cirq/emit/qubit.py
CHANGED
|
@@ -25,7 +25,7 @@ class EmitCirqQubitMethods(MethodTable):
|
|
|
25
25
|
@impl(qubit.Apply)
|
|
26
26
|
def apply(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Apply):
|
|
27
27
|
op: OperatorRuntimeABC = frame.get(stmt.operator)
|
|
28
|
-
qbits = frame.get(stmt.qubits
|
|
28
|
+
qbits = [frame.get(qbit) for qbit in stmt.qubits]
|
|
29
29
|
operations = op.apply(qbits)
|
|
30
30
|
for operation in operations:
|
|
31
31
|
frame.circuit.append(operation)
|
|
@@ -34,11 +34,11 @@ class EmitCirqQubitMethods(MethodTable):
|
|
|
34
34
|
@impl(qubit.Broadcast)
|
|
35
35
|
def broadcast(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Broadcast):
|
|
36
36
|
op = frame.get(stmt.operator)
|
|
37
|
-
|
|
37
|
+
qbit_lists = [frame.get(qbit) for qbit in stmt.qubits]
|
|
38
38
|
|
|
39
39
|
cirq_ops = []
|
|
40
|
-
for
|
|
41
|
-
cirq_ops.extend(op.apply(
|
|
40
|
+
for qbits in zip(*qbit_lists):
|
|
41
|
+
cirq_ops.extend(op.apply(qbits))
|
|
42
42
|
|
|
43
43
|
frame.circuit.append(cirq.Moment(cirq_ops))
|
|
44
44
|
return ()
|
bloqade/squin/cirq/lowering.py
CHANGED
|
@@ -44,14 +44,7 @@ class Squin(lowering.LoweringABC[CirqNode]):
|
|
|
44
44
|
self, state: lowering.State[CirqNode], qids: list[cirq.Qid]
|
|
45
45
|
):
|
|
46
46
|
qbits_getitem = [self.lower_qubit_getindex(state, qid) for qid in qids]
|
|
47
|
-
|
|
48
|
-
qbits_result = state.current_frame.get(qbits_stmt.name)
|
|
49
|
-
|
|
50
|
-
if qbits_result is not None:
|
|
51
|
-
return qbits_result
|
|
52
|
-
|
|
53
|
-
state.current_frame.push(qbits_stmt)
|
|
54
|
-
return qbits_stmt.result
|
|
47
|
+
return tuple(qbits_getitem)
|
|
55
48
|
|
|
56
49
|
def run(
|
|
57
50
|
self,
|
|
@@ -159,7 +152,8 @@ class Squin(lowering.LoweringABC[CirqNode]):
|
|
|
159
152
|
stmt = state.current_frame.push(qubit.MeasureQubit(qbit))
|
|
160
153
|
else:
|
|
161
154
|
qbits = self.lower_qubit_getindices(state, node.qubits)
|
|
162
|
-
|
|
155
|
+
qbits_list = state.current_frame.push(ilist.New(values=qbits))
|
|
156
|
+
stmt = state.current_frame.push(qubit.MeasureQubitList(qbits_list.result))
|
|
163
157
|
|
|
164
158
|
key = node.gate.key
|
|
165
159
|
if isinstance(key, cirq.MeasurementKey):
|
bloqade/squin/gate.py
CHANGED
|
@@ -137,6 +137,13 @@ def reset(qubit: Qubit) -> None:
|
|
|
137
137
|
_qubit.apply(op, qubit)
|
|
138
138
|
|
|
139
139
|
|
|
140
|
+
@kernel
|
|
141
|
+
def reset_to_one(qubit: Qubit) -> None:
|
|
142
|
+
"""Reset qubit to 1."""
|
|
143
|
+
op = _op.reset_to_one()
|
|
144
|
+
_qubit.apply(op, qubit)
|
|
145
|
+
|
|
146
|
+
|
|
140
147
|
@kernel
|
|
141
148
|
def cx(control: Qubit, target: Qubit) -> None:
|
|
142
149
|
"""Controlled x gate applied to control and target"""
|
bloqade/squin/lowering.py
CHANGED
|
@@ -52,3 +52,29 @@ class ApplyAnyCallLowering(lowering.FromPythonCall["qubit.ApplyAny"]):
|
|
|
52
52
|
return op, qubits.elts
|
|
53
53
|
|
|
54
54
|
return op, [qubits]
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass(frozen=True)
|
|
58
|
+
class BroadcastCallLowering(lowering.FromPythonCall["qubit.Broadcast"]):
|
|
59
|
+
"""
|
|
60
|
+
Custom lowering for broadcast vararg call.
|
|
61
|
+
|
|
62
|
+
NOTE: we can re-use this to lower Apply too once we remove the deprecated syntax
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
def lower(
|
|
66
|
+
self, stmt: type["qubit.Broadcast"], state: lowering.State, node: ast.Call
|
|
67
|
+
):
|
|
68
|
+
if len(node.args) < 2:
|
|
69
|
+
raise lowering.BuildError(
|
|
70
|
+
"Broadcast requires at least one operator and one qubit list argument"
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
op, *qubit_lists = node.args
|
|
74
|
+
|
|
75
|
+
op_lowered = state.lower(op).expect_one()
|
|
76
|
+
qubits_lists_lowered = [
|
|
77
|
+
state.lower(qubit_list).expect_one() for qubit_list in qubit_lists
|
|
78
|
+
]
|
|
79
|
+
|
|
80
|
+
return state.current_frame.push(stmt(op_lowered, tuple(qubits_lists_lowered)))
|
bloqade/squin/noise/__init__.py
CHANGED
bloqade/squin/noise/_wrapper.py
CHANGED
|
@@ -3,17 +3,13 @@ from typing import Literal
|
|
|
3
3
|
from kirin.dialects import ilist
|
|
4
4
|
from kirin.lowering import wraps
|
|
5
5
|
|
|
6
|
-
from bloqade.squin.op.types import Op
|
|
6
|
+
from bloqade.squin.op.types import Op, MultiQubitPauliOp
|
|
7
7
|
|
|
8
8
|
from . import stmts
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
@wraps(stmts.PauliError)
|
|
12
|
-
def pauli_error(basis:
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
@wraps(stmts.PPError)
|
|
16
|
-
def pp_error(op: Op, p: float) -> Op: ...
|
|
12
|
+
def pauli_error(basis: MultiQubitPauliOp, p: float) -> Op: ...
|
|
17
13
|
|
|
18
14
|
|
|
19
15
|
@wraps(stmts.Depolarize)
|
bloqade/squin/noise/rewrite.py
CHANGED
|
@@ -7,7 +7,6 @@ from kirin.dialects import py, ilist
|
|
|
7
7
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
8
8
|
|
|
9
9
|
from .stmts import (
|
|
10
|
-
PPError,
|
|
11
10
|
QubitLoss,
|
|
12
11
|
Depolarize,
|
|
13
12
|
PauliError,
|
|
@@ -86,16 +85,6 @@ class _RewriteNoiseStmts(RewriteRule):
|
|
|
86
85
|
(operator_list := ilist.New(values=operators)).insert_before(node)
|
|
87
86
|
return operator_list.result
|
|
88
87
|
|
|
89
|
-
def rewrite_p_p_error(self, node: PPError) -> RewriteResult:
|
|
90
|
-
(operators := ilist.New(values=(node.op,))).insert_before(node)
|
|
91
|
-
(ps := ilist.New(values=(node.p,))).insert_before(node)
|
|
92
|
-
stochastic_channel = StochasticUnitaryChannel(
|
|
93
|
-
operators=operators.result, probabilities=ps.result
|
|
94
|
-
)
|
|
95
|
-
|
|
96
|
-
node.replace_by(stochastic_channel)
|
|
97
|
-
return RewriteResult(has_done_something=True)
|
|
98
|
-
|
|
99
88
|
def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
|
|
100
89
|
paulis = (X(), Y(), Z())
|
|
101
90
|
operators: list[ir.SSAValue] = []
|
bloqade/squin/noise/stmts.py
CHANGED
|
@@ -2,10 +2,8 @@ from kirin import ir, types, lowering
|
|
|
2
2
|
from kirin.decl import info, statement
|
|
3
3
|
from kirin.dialects import ilist
|
|
4
4
|
|
|
5
|
-
from bloqade.squin.op.types import OpType
|
|
6
|
-
|
|
7
5
|
from ._dialect import dialect
|
|
8
|
-
from ..op.types import NumOperators
|
|
6
|
+
from ..op.types import OpType, NumOperators, MultiQubitPauliOpType
|
|
9
7
|
|
|
10
8
|
|
|
11
9
|
@statement
|
|
@@ -16,17 +14,7 @@ class NoiseChannel(ir.Statement):
|
|
|
16
14
|
|
|
17
15
|
@statement(dialect=dialect)
|
|
18
16
|
class PauliError(NoiseChannel):
|
|
19
|
-
basis: ir.SSAValue = info.argument(
|
|
20
|
-
p: ir.SSAValue = info.argument(types.Float)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
@statement(dialect=dialect)
|
|
24
|
-
class PPError(NoiseChannel):
|
|
25
|
-
"""
|
|
26
|
-
Pauli Product Error
|
|
27
|
-
"""
|
|
28
|
-
|
|
29
|
-
op: ir.SSAValue = info.argument(OpType)
|
|
17
|
+
basis: ir.SSAValue = info.argument(MultiQubitPauliOpType)
|
|
30
18
|
p: ir.SSAValue = info.argument(types.Float)
|
|
31
19
|
|
|
32
20
|
|
bloqade/squin/op/_wrapper.py
CHANGED
|
@@ -62,15 +62,15 @@ def phase(theta: float) -> types.Op: ...
|
|
|
62
62
|
|
|
63
63
|
|
|
64
64
|
@wraps(stmts.X)
|
|
65
|
-
def x() -> types.
|
|
65
|
+
def x() -> types.PauliOp: ...
|
|
66
66
|
|
|
67
67
|
|
|
68
68
|
@wraps(stmts.Y)
|
|
69
|
-
def y() -> types.
|
|
69
|
+
def y() -> types.PauliOp: ...
|
|
70
70
|
|
|
71
71
|
|
|
72
72
|
@wraps(stmts.Z)
|
|
73
|
-
def z() -> types.
|
|
73
|
+
def z() -> types.PauliOp: ...
|
|
74
74
|
|
|
75
75
|
|
|
76
76
|
@wraps(stmts.SqrtX)
|
|
@@ -118,4 +118,4 @@ def u(theta: float, phi: float, lam: float) -> types.Op: ...
|
|
|
118
118
|
|
|
119
119
|
|
|
120
120
|
@wraps(stmts.PauliString)
|
|
121
|
-
def pauli_string(*, string: str) -> types.
|
|
121
|
+
def pauli_string(*, string: str) -> types.PauliStringOp: ...
|
bloqade/squin/op/stmts.py
CHANGED
|
@@ -1,7 +1,19 @@
|
|
|
1
1
|
from kirin import ir, types, lowering
|
|
2
2
|
from kirin.decl import info, statement
|
|
3
3
|
|
|
4
|
-
from .types import
|
|
4
|
+
from .types import (
|
|
5
|
+
OpType,
|
|
6
|
+
ROpType,
|
|
7
|
+
XOpType,
|
|
8
|
+
YOpType,
|
|
9
|
+
ZOpType,
|
|
10
|
+
KronType,
|
|
11
|
+
MultType,
|
|
12
|
+
PauliOpType,
|
|
13
|
+
ControlOpType,
|
|
14
|
+
PauliStringType,
|
|
15
|
+
ControlledOpType,
|
|
16
|
+
)
|
|
5
17
|
from .number import NumberType
|
|
6
18
|
from .traits import Unitary, HasSites, FixedSites, MaybeUnitary
|
|
7
19
|
from ._dialect import dialect
|
|
@@ -22,22 +34,28 @@ class CompositeOp(Operator):
|
|
|
22
34
|
pass
|
|
23
35
|
|
|
24
36
|
|
|
37
|
+
LhsType = types.TypeVar("Lhs", bound=OpType)
|
|
38
|
+
RhsType = types.TypeVar("Rhs", bound=OpType)
|
|
39
|
+
|
|
40
|
+
|
|
25
41
|
@statement
|
|
26
42
|
class BinaryOp(CompositeOp):
|
|
27
|
-
lhs: ir.SSAValue = info.argument(
|
|
28
|
-
rhs: ir.SSAValue = info.argument(
|
|
43
|
+
lhs: ir.SSAValue = info.argument(LhsType)
|
|
44
|
+
rhs: ir.SSAValue = info.argument(RhsType)
|
|
29
45
|
|
|
30
46
|
|
|
31
47
|
@statement(dialect=dialect)
|
|
32
48
|
class Kron(BinaryOp):
|
|
33
49
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
|
|
34
50
|
is_unitary: bool = info.attribute(default=False)
|
|
51
|
+
result: ir.ResultValue = info.result(KronType[LhsType, RhsType])
|
|
35
52
|
|
|
36
53
|
|
|
37
54
|
@statement(dialect=dialect)
|
|
38
55
|
class Mult(BinaryOp):
|
|
39
56
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
|
|
40
57
|
is_unitary: bool = info.attribute(default=False)
|
|
58
|
+
result: ir.ResultValue = info.result(MultType[LhsType, RhsType])
|
|
41
59
|
|
|
42
60
|
|
|
43
61
|
@statement(dialect=dialect)
|
|
@@ -59,15 +77,20 @@ class Scale(CompositeOp):
|
|
|
59
77
|
class Control(CompositeOp):
|
|
60
78
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
|
|
61
79
|
is_unitary: bool = info.attribute(default=False)
|
|
62
|
-
op: ir.SSAValue = info.argument(
|
|
80
|
+
op: ir.SSAValue = info.argument(ControlledOpType)
|
|
63
81
|
n_controls: int = info.attribute()
|
|
82
|
+
result: ir.ResultValue = info.result(ControlOpType[ControlledOpType])
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
RotationAxisType = types.TypeVar("RotationAxis", bound=OpType)
|
|
64
86
|
|
|
65
87
|
|
|
66
88
|
@statement(dialect=dialect)
|
|
67
89
|
class Rot(CompositeOp):
|
|
68
90
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary()})
|
|
69
|
-
axis: ir.SSAValue = info.argument(
|
|
91
|
+
axis: ir.SSAValue = info.argument(RotationAxisType)
|
|
70
92
|
angle: ir.SSAValue = info.argument(types.Float)
|
|
93
|
+
result: ir.ResultValue = info.result(ROpType[RotationAxisType])
|
|
71
94
|
|
|
72
95
|
|
|
73
96
|
@statement(dialect=dialect)
|
|
@@ -166,13 +189,14 @@ class CliffordOp(ConstantUnitary):
|
|
|
166
189
|
|
|
167
190
|
@statement
|
|
168
191
|
class PauliOp(CliffordOp):
|
|
169
|
-
|
|
192
|
+
result: ir.ResultValue = info.result(type=PauliOpType)
|
|
170
193
|
|
|
171
194
|
|
|
172
195
|
@statement(dialect=dialect)
|
|
173
196
|
class PauliString(ConstantUnitary):
|
|
174
197
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()})
|
|
175
198
|
string: str = info.attribute()
|
|
199
|
+
result: ir.ResultValue = info.result(type=PauliStringType)
|
|
176
200
|
|
|
177
201
|
def verify(self) -> None:
|
|
178
202
|
if not set("XYZ").issuperset(self.string):
|
|
@@ -183,17 +207,17 @@ class PauliString(ConstantUnitary):
|
|
|
183
207
|
|
|
184
208
|
@statement(dialect=dialect)
|
|
185
209
|
class X(PauliOp):
|
|
186
|
-
|
|
210
|
+
result: ir.ResultValue = info.result(XOpType)
|
|
187
211
|
|
|
188
212
|
|
|
189
213
|
@statement(dialect=dialect)
|
|
190
214
|
class Y(PauliOp):
|
|
191
|
-
|
|
215
|
+
result: ir.ResultValue = info.result(YOpType)
|
|
192
216
|
|
|
193
217
|
|
|
194
218
|
@statement(dialect=dialect)
|
|
195
219
|
class Z(PauliOp):
|
|
196
|
-
|
|
220
|
+
result: ir.ResultValue = info.result(ZOpType)
|
|
197
221
|
|
|
198
222
|
|
|
199
223
|
@statement(dialect=dialect)
|
bloqade/squin/op/types.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import overload
|
|
1
|
+
from typing import Generic, TypeVar, overload
|
|
2
2
|
|
|
3
3
|
from kirin import types
|
|
4
4
|
|
|
@@ -23,4 +23,106 @@ class Op:
|
|
|
23
23
|
|
|
24
24
|
OpType = types.PyClass(Op)
|
|
25
25
|
|
|
26
|
-
|
|
26
|
+
|
|
27
|
+
class CompositeOp(Op):
|
|
28
|
+
pass
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
CompositeOpType = types.PyClass(CompositeOp)
|
|
32
|
+
|
|
33
|
+
LhsType = TypeVar("LhsType", bound=Op)
|
|
34
|
+
RhsType = TypeVar("RhsType", bound=Op)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class BinaryOp(Op, Generic[LhsType, RhsType]):
|
|
38
|
+
lhs: LhsType
|
|
39
|
+
rhs: RhsType
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
BinaryOpType = types.Generic(BinaryOp, OpType, OpType)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class Mult(BinaryOp[LhsType, RhsType]):
|
|
46
|
+
pass
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
MultType = types.Generic(Mult, OpType, OpType)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class Kron(BinaryOp[LhsType, RhsType]):
|
|
53
|
+
pass
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
KronType = types.Generic(Kron, OpType, OpType)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class MultiQubitPauliOp(Op):
|
|
60
|
+
pass
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
MultiQubitPauliOpType = types.PyClass(MultiQubitPauliOp)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class PauliStringOp(MultiQubitPauliOp):
|
|
67
|
+
pass
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
PauliStringType = types.PyClass(PauliStringOp)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class PauliOp(MultiQubitPauliOp):
|
|
74
|
+
pass
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
PauliOpType = types.PyClass(PauliOp)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class XOp(PauliOp):
|
|
81
|
+
pass
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
XOpType = types.PyClass(XOp)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class YOp(PauliOp):
|
|
88
|
+
pass
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
YOpType = types.PyClass(YOp)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class ZOp(PauliOp):
|
|
95
|
+
pass
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
ZOpType = types.PyClass(ZOp)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
ControlledOp = TypeVar("ControlledOp", bound=Op)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class ControlOp(CompositeOp, Generic[ControlledOp]):
|
|
105
|
+
op: ControlledOp
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
ControlledOpType = types.TypeVar("ControlledOp", bound=OpType)
|
|
109
|
+
ControlOpType = types.Generic(ControlOp, ControlledOpType)
|
|
110
|
+
CXOpType = ControlOpType[XOpType]
|
|
111
|
+
CYOpType = ControlOpType[YOpType]
|
|
112
|
+
CZOpType = ControlOpType[ZOpType]
|
|
113
|
+
|
|
114
|
+
RotationAxis = TypeVar("RotationAxis", bound=Op)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class ROp(CompositeOp, Generic[RotationAxis]):
|
|
118
|
+
axis: RotationAxis
|
|
119
|
+
angle: float
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
ROpType = types.Generic(ROp, OpType)
|
|
123
|
+
RxOpType = ROpType[XOpType]
|
|
124
|
+
RyOpType = ROpType[YOpType]
|
|
125
|
+
RzOpType = ROpType[ZOpType]
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
NumOperators = types.TypeVar("NumOperators", bound=types.Int)
|