bloqade-circuit 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bloqade/analysis/address/impls.py +3 -16
- bloqade/pyqrack/__init__.py +1 -1
- bloqade/pyqrack/noise/native.py +8 -8
- bloqade/pyqrack/squin/noise/__init__.py +1 -0
- bloqade/pyqrack/squin/noise/native.py +72 -0
- bloqade/pyqrack/squin/op.py +7 -0
- bloqade/pyqrack/squin/qubit.py +0 -29
- bloqade/pyqrack/squin/runtime.py +18 -0
- bloqade/pyqrack/squin/wire.py +0 -36
- bloqade/{noise/native → qasm2/dialects/noise}/__init__.py +1 -7
- bloqade/qasm2/dialects/noise/_dialect.py +3 -0
- bloqade/{noise → qasm2/dialects/noise}/fidelity.py +2 -2
- bloqade/qasm2/dialects/noise/model.py +278 -0
- bloqade/qasm2/emit/impls/__init__.py +1 -1
- bloqade/qasm2/emit/impls/{noise_native.py → noise.py} +11 -11
- bloqade/qasm2/emit/main.py +2 -4
- bloqade/qasm2/emit/target.py +3 -3
- bloqade/qasm2/groups.py +0 -2
- bloqade/{noise/native/_wrappers.py → qasm2/noise.py} +9 -5
- bloqade/qasm2/passes/glob.py +12 -8
- bloqade/qasm2/passes/noise.py +5 -14
- bloqade/qasm2/rewrite/__init__.py +2 -0
- bloqade/qasm2/rewrite/noise/__init__.py +0 -0
- bloqade/qasm2/rewrite/{heuristic_noise.py → noise/heuristic_noise.py} +31 -53
- bloqade/{noise/native/rewrite.py → qasm2/rewrite/noise/remove_noise.py} +2 -2
- bloqade/qbraid/lowering.py +8 -8
- bloqade/squin/__init__.py +16 -1
- bloqade/squin/analysis/nsites/impls.py +0 -9
- bloqade/squin/cirq/__init__.py +89 -0
- bloqade/squin/cirq/lowering.py +303 -0
- bloqade/squin/groups.py +7 -7
- bloqade/squin/lowering.py +27 -0
- bloqade/squin/noise/__init__.py +3 -1
- bloqade/squin/noise/_wrapper.py +7 -3
- bloqade/squin/noise/rewrite.py +111 -0
- bloqade/squin/noise/stmts.py +21 -16
- bloqade/squin/op/__init__.py +1 -0
- bloqade/squin/op/_wrapper.py +4 -0
- bloqade/squin/op/stmts.py +10 -11
- bloqade/squin/op/types.py +2 -0
- bloqade/squin/qubit.py +32 -37
- bloqade/squin/rewrite/desugar.py +65 -0
- bloqade/squin/rewrite/qubit_to_stim.py +0 -23
- bloqade/squin/rewrite/squin_measure.py +2 -27
- bloqade/squin/rewrite/stim_rewrite_util.py +3 -8
- bloqade/squin/rewrite/wire_to_stim.py +0 -21
- bloqade/squin/wire.py +4 -9
- bloqade/stim/__init__.py +2 -1
- bloqade/stim/_wrappers.py +4 -0
- bloqade/stim/dialects/auxiliary/__init__.py +1 -0
- bloqade/stim/dialects/auxiliary/emit.py +17 -2
- bloqade/stim/dialects/auxiliary/stmts/__init__.py +1 -0
- bloqade/stim/dialects/auxiliary/stmts/annotate.py +8 -0
- bloqade/stim/dialects/collapse/emit_str.py +3 -1
- bloqade/stim/dialects/gate/emit.py +9 -2
- bloqade/stim/dialects/noise/emit.py +32 -1
- bloqade/stim/dialects/noise/stmts.py +29 -0
- bloqade/stim/parse/__init__.py +1 -0
- bloqade/stim/parse/lowering.py +686 -0
- {bloqade_circuit-0.3.0.dist-info → bloqade_circuit-0.4.1.dist-info}/METADATA +3 -1
- {bloqade_circuit-0.3.0.dist-info → bloqade_circuit-0.4.1.dist-info}/RECORD +64 -57
- bloqade/noise/__init__.py +0 -2
- bloqade/noise/native/_dialect.py +0 -3
- bloqade/noise/native/model.py +0 -346
- bloqade/qasm2/dialects/noise.py +0 -48
- bloqade/squin/rewrite/measure_desugar.py +0 -33
- /bloqade/{noise/native → qasm2/dialects/noise}/stmts.py +0 -0
- {bloqade_circuit-0.3.0.dist-info → bloqade_circuit-0.4.1.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.3.0.dist-info → bloqade_circuit-0.4.1.dist-info}/licenses/LICENSE +0 -0
bloqade/squin/op/stmts.py
CHANGED
|
@@ -9,7 +9,7 @@ from ._dialect import dialect
|
|
|
9
9
|
|
|
10
10
|
@statement
|
|
11
11
|
class Operator(ir.Statement):
|
|
12
|
-
|
|
12
|
+
result: ir.ResultValue = info.result(OpType)
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
@statement
|
|
@@ -26,7 +26,6 @@ class CompositeOp(Operator):
|
|
|
26
26
|
class BinaryOp(CompositeOp):
|
|
27
27
|
lhs: ir.SSAValue = info.argument(OpType)
|
|
28
28
|
rhs: ir.SSAValue = info.argument(OpType)
|
|
29
|
-
result: ir.ResultValue = info.result(OpType)
|
|
30
29
|
|
|
31
30
|
|
|
32
31
|
@statement(dialect=dialect)
|
|
@@ -46,7 +45,6 @@ class Adjoint(CompositeOp):
|
|
|
46
45
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
|
|
47
46
|
is_unitary: bool = info.attribute(default=False)
|
|
48
47
|
op: ir.SSAValue = info.argument(OpType)
|
|
49
|
-
result: ir.ResultValue = info.result(OpType)
|
|
50
48
|
|
|
51
49
|
|
|
52
50
|
@statement(dialect=dialect)
|
|
@@ -55,7 +53,6 @@ class Scale(CompositeOp):
|
|
|
55
53
|
is_unitary: bool = info.attribute(default=False)
|
|
56
54
|
op: ir.SSAValue = info.argument(OpType)
|
|
57
55
|
factor: ir.SSAValue = info.argument(NumberType)
|
|
58
|
-
result: ir.ResultValue = info.result(OpType)
|
|
59
56
|
|
|
60
57
|
|
|
61
58
|
@statement(dialect=dialect)
|
|
@@ -64,7 +61,6 @@ class Control(CompositeOp):
|
|
|
64
61
|
is_unitary: bool = info.attribute(default=False)
|
|
65
62
|
op: ir.SSAValue = info.argument(OpType)
|
|
66
63
|
n_controls: int = info.attribute()
|
|
67
|
-
result: ir.ResultValue = info.result(OpType)
|
|
68
64
|
|
|
69
65
|
|
|
70
66
|
@statement(dialect=dialect)
|
|
@@ -72,14 +68,12 @@ class Rot(CompositeOp):
|
|
|
72
68
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary()})
|
|
73
69
|
axis: ir.SSAValue = info.argument(OpType)
|
|
74
70
|
angle: ir.SSAValue = info.argument(types.Float)
|
|
75
|
-
result: ir.ResultValue = info.result(OpType)
|
|
76
71
|
|
|
77
72
|
|
|
78
73
|
@statement(dialect=dialect)
|
|
79
74
|
class Identity(CompositeOp):
|
|
80
75
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()})
|
|
81
76
|
sites: int = info.attribute()
|
|
82
|
-
result: ir.ResultValue = info.result(OpType)
|
|
83
77
|
|
|
84
78
|
|
|
85
79
|
@statement
|
|
@@ -87,7 +81,6 @@ class ConstantOp(PrimitiveOp):
|
|
|
87
81
|
traits = frozenset(
|
|
88
82
|
{ir.Pure(), lowering.FromPythonCall(), ir.ConstantLike(), FixedSites(1)}
|
|
89
83
|
)
|
|
90
|
-
result: ir.ResultValue = info.result(OpType)
|
|
91
84
|
|
|
92
85
|
|
|
93
86
|
@statement
|
|
@@ -109,7 +102,6 @@ class U3(PrimitiveOp):
|
|
|
109
102
|
theta: ir.SSAValue = info.argument(types.Float)
|
|
110
103
|
phi: ir.SSAValue = info.argument(types.Float)
|
|
111
104
|
lam: ir.SSAValue = info.argument(types.Float)
|
|
112
|
-
result: ir.ResultValue = info.result(OpType)
|
|
113
105
|
|
|
114
106
|
|
|
115
107
|
@statement(dialect=dialect)
|
|
@@ -124,7 +116,6 @@ class PhaseOp(PrimitiveOp):
|
|
|
124
116
|
|
|
125
117
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), FixedSites(1)})
|
|
126
118
|
theta: ir.SSAValue = info.argument(types.Float)
|
|
127
|
-
result: ir.ResultValue = info.result(OpType)
|
|
128
119
|
|
|
129
120
|
|
|
130
121
|
@statement(dialect=dialect)
|
|
@@ -139,7 +130,15 @@ class ShiftOp(PrimitiveOp):
|
|
|
139
130
|
|
|
140
131
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), FixedSites(1)})
|
|
141
132
|
theta: ir.SSAValue = info.argument(types.Float)
|
|
142
|
-
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
@statement(dialect=dialect)
|
|
136
|
+
class Reset(PrimitiveOp):
|
|
137
|
+
"""
|
|
138
|
+
Reset operator for qubits or wires.
|
|
139
|
+
"""
|
|
140
|
+
|
|
141
|
+
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), FixedSites(1)})
|
|
143
142
|
|
|
144
143
|
|
|
145
144
|
@statement
|
bloqade/squin/op/types.py
CHANGED
bloqade/squin/qubit.py
CHANGED
|
@@ -17,6 +17,8 @@ from kirin.lowering import wraps
|
|
|
17
17
|
from bloqade.types import Qubit, QubitType
|
|
18
18
|
from bloqade.squin.op.types import Op, OpType
|
|
19
19
|
|
|
20
|
+
from .lowering import ApplyAnyCallLowering
|
|
21
|
+
|
|
20
22
|
dialect = ir.Dialect("squin.qubit")
|
|
21
23
|
|
|
22
24
|
|
|
@@ -34,6 +36,14 @@ class Apply(ir.Statement):
|
|
|
34
36
|
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
|
|
35
37
|
|
|
36
38
|
|
|
39
|
+
@statement(dialect=dialect)
|
|
40
|
+
class ApplyAny(ir.Statement):
|
|
41
|
+
# NOTE: custom lowering to deal with vararg calls
|
|
42
|
+
traits = frozenset({ApplyAnyCallLowering()})
|
|
43
|
+
operator: ir.SSAValue = info.argument(OpType)
|
|
44
|
+
qubits: tuple[ir.SSAValue, ...] = info.argument()
|
|
45
|
+
|
|
46
|
+
|
|
37
47
|
@statement(dialect=dialect)
|
|
38
48
|
class Broadcast(ir.Statement):
|
|
39
49
|
traits = frozenset({lowering.FromPythonCall()})
|
|
@@ -68,19 +78,6 @@ class MeasureQubitList(ir.Statement):
|
|
|
68
78
|
result: ir.ResultValue = info.result(ilist.IListType[types.Bool])
|
|
69
79
|
|
|
70
80
|
|
|
71
|
-
@statement(dialect=dialect)
|
|
72
|
-
class MeasureAndReset(ir.Statement):
|
|
73
|
-
traits = frozenset({lowering.FromPythonCall()})
|
|
74
|
-
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
|
|
75
|
-
result: ir.ResultValue = info.result(ilist.IListType[types.Bool])
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
@statement(dialect=dialect)
|
|
79
|
-
class Reset(ir.Statement):
|
|
80
|
-
traits = frozenset({lowering.FromPythonCall()})
|
|
81
|
-
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
|
|
82
|
-
|
|
83
|
-
|
|
84
81
|
# NOTE: no dependent types in Python, so we have to mark it Any...
|
|
85
82
|
@wraps(New)
|
|
86
83
|
def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
|
|
@@ -95,7 +92,7 @@ def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
|
|
|
95
92
|
...
|
|
96
93
|
|
|
97
94
|
|
|
98
|
-
@
|
|
95
|
+
@overload
|
|
99
96
|
def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
|
|
100
97
|
"""Apply an operator to a list of qubits.
|
|
101
98
|
|
|
@@ -112,6 +109,27 @@ def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
|
|
|
112
109
|
...
|
|
113
110
|
|
|
114
111
|
|
|
112
|
+
@overload
|
|
113
|
+
def apply(operator: Op, *qubits: Qubit) -> None:
|
|
114
|
+
"""Apply and operator to any number of qubits.
|
|
115
|
+
|
|
116
|
+
Note, that when considering atom loss, lost qubits will be skipped.
|
|
117
|
+
|
|
118
|
+
Args:
|
|
119
|
+
operator: The operator to apply.
|
|
120
|
+
*qubits: The qubits to apply the operator to. The number of qubits must
|
|
121
|
+
match the size of the operator.
|
|
122
|
+
|
|
123
|
+
Returns:
|
|
124
|
+
None
|
|
125
|
+
"""
|
|
126
|
+
...
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@wraps(ApplyAny)
|
|
130
|
+
def apply(operator: Op, *qubits) -> None: ...
|
|
131
|
+
|
|
132
|
+
|
|
115
133
|
@overload
|
|
116
134
|
def measure(input: Qubit) -> bool: ...
|
|
117
135
|
@overload
|
|
@@ -161,26 +179,3 @@ def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> No
|
|
|
161
179
|
None
|
|
162
180
|
"""
|
|
163
181
|
...
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
@wraps(MeasureAndReset)
|
|
167
|
-
def measure_and_reset(qubits: ilist.IList[Qubit, Any]) -> ilist.IList[bool, Any]:
|
|
168
|
-
"""Measure the qubits in the list and reset them."
|
|
169
|
-
|
|
170
|
-
Args:
|
|
171
|
-
qubits: The list of qubits to measure and reset.
|
|
172
|
-
|
|
173
|
-
Returns:
|
|
174
|
-
list[bool]: The result of the measurement.
|
|
175
|
-
"""
|
|
176
|
-
...
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
@wraps(Reset)
|
|
180
|
-
def reset(qubits: ilist.IList[Qubit, Any]) -> None:
|
|
181
|
-
"""Reset the qubits in the list."
|
|
182
|
-
|
|
183
|
-
Args:
|
|
184
|
-
qubits: The list of qubits to reset.
|
|
185
|
-
"""
|
|
186
|
-
...
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from kirin import ir, types
|
|
2
|
+
from kirin.dialects import ilist
|
|
3
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
4
|
+
|
|
5
|
+
from bloqade.squin.qubit import (
|
|
6
|
+
Apply,
|
|
7
|
+
ApplyAny,
|
|
8
|
+
QubitType,
|
|
9
|
+
MeasureAny,
|
|
10
|
+
MeasureQubit,
|
|
11
|
+
MeasureQubitList,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class MeasureDesugarRule(RewriteRule):
|
|
16
|
+
"""
|
|
17
|
+
Desugar measure operations in the circuit.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
21
|
+
|
|
22
|
+
if not isinstance(node, MeasureAny):
|
|
23
|
+
return RewriteResult()
|
|
24
|
+
|
|
25
|
+
if node.input.type.is_subseteq(QubitType):
|
|
26
|
+
node.replace_by(
|
|
27
|
+
MeasureQubit(
|
|
28
|
+
qubit=node.input,
|
|
29
|
+
)
|
|
30
|
+
)
|
|
31
|
+
return RewriteResult(has_done_something=True)
|
|
32
|
+
elif node.input.type.is_subseteq(ilist.IListType[QubitType, types.Any]):
|
|
33
|
+
node.replace_by(
|
|
34
|
+
MeasureQubitList(
|
|
35
|
+
qubits=node.input,
|
|
36
|
+
)
|
|
37
|
+
)
|
|
38
|
+
return RewriteResult(has_done_something=True)
|
|
39
|
+
|
|
40
|
+
return RewriteResult()
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ApplyDesugarRule(RewriteRule):
|
|
44
|
+
"""
|
|
45
|
+
Desugar apply operators in the kernel.
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
49
|
+
|
|
50
|
+
if not isinstance(node, ApplyAny):
|
|
51
|
+
return RewriteResult()
|
|
52
|
+
|
|
53
|
+
op = node.operator
|
|
54
|
+
qubits = node.qubits
|
|
55
|
+
|
|
56
|
+
if len(qubits) == 1 and qubits[0].type.is_subseteq(ilist.IListType):
|
|
57
|
+
# NOTE: already calling with just a single argument that is already an ilist
|
|
58
|
+
qubits_ilist = qubits[0]
|
|
59
|
+
else:
|
|
60
|
+
(qubits_ilist_stmt := ilist.New(values=qubits)).insert_before(node)
|
|
61
|
+
qubits_ilist = qubits_ilist_stmt.result
|
|
62
|
+
|
|
63
|
+
stmt = Apply(operator=op, qubits=qubits_ilist)
|
|
64
|
+
node.replace_by(stmt)
|
|
65
|
+
return RewriteResult(has_done_something=True)
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from kirin import ir
|
|
2
2
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
3
3
|
|
|
4
|
-
from bloqade import stim
|
|
5
4
|
from bloqade.squin import op, qubit
|
|
6
5
|
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
|
|
7
6
|
from bloqade.squin.rewrite.stim_rewrite_util import (
|
|
@@ -18,8 +17,6 @@ class SquinQubitToStim(RewriteRule):
|
|
|
18
17
|
match node:
|
|
19
18
|
case qubit.Apply() | qubit.Broadcast():
|
|
20
19
|
return self.rewrite_Apply_and_Broadcast(node)
|
|
21
|
-
case qubit.Reset():
|
|
22
|
-
return self.rewrite_Reset(node)
|
|
23
20
|
case _:
|
|
24
21
|
return RewriteResult()
|
|
25
22
|
|
|
@@ -60,25 +57,5 @@ class SquinQubitToStim(RewriteRule):
|
|
|
60
57
|
|
|
61
58
|
return RewriteResult(has_done_something=True)
|
|
62
59
|
|
|
63
|
-
def rewrite_Reset(self, reset_stmt: qubit.Reset) -> RewriteResult:
|
|
64
|
-
qubit_ilist_ssa = reset_stmt.qubits
|
|
65
|
-
# qubits are in an ilist which makes up an AddressTuple
|
|
66
|
-
address_attr = qubit_ilist_ssa.hints.get("address")
|
|
67
|
-
if address_attr is None:
|
|
68
|
-
return RewriteResult()
|
|
69
|
-
|
|
70
|
-
assert isinstance(address_attr, AddressAttribute)
|
|
71
|
-
qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
72
|
-
address=address_attr, stmt_to_insert_before=reset_stmt
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
if qubit_idx_ssas is None:
|
|
76
|
-
return RewriteResult()
|
|
77
|
-
|
|
78
|
-
stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas)
|
|
79
|
-
reset_stmt.replace_by(stim_rz_stmt)
|
|
80
|
-
|
|
81
|
-
return RewriteResult(has_done_something=True)
|
|
82
|
-
|
|
83
60
|
|
|
84
61
|
# put rewrites for measure statements in separate rule, then just have to dispatch
|
|
@@ -3,8 +3,8 @@ from kirin import ir
|
|
|
3
3
|
from kirin.dialects import py
|
|
4
4
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
5
5
|
|
|
6
|
-
from bloqade import stim
|
|
7
6
|
from bloqade.squin import wire, qubit
|
|
7
|
+
from bloqade.stim.dialects import collapse
|
|
8
8
|
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
|
|
9
9
|
from bloqade.squin.rewrite.stim_rewrite_util import (
|
|
10
10
|
is_measure_result_used,
|
|
@@ -22,8 +22,6 @@ class SquinMeasureToStim(RewriteRule):
|
|
|
22
22
|
match node:
|
|
23
23
|
case qubit.MeasureQubit() | qubit.MeasureQubitList() | wire.Measure():
|
|
24
24
|
return self.rewrite_Measure(node)
|
|
25
|
-
case qubit.MeasureAndReset() | wire.MeasureAndReset():
|
|
26
|
-
return self.rewrite_MeasureAndReset(node)
|
|
27
25
|
case _:
|
|
28
26
|
return RewriteResult()
|
|
29
27
|
|
|
@@ -38,7 +36,7 @@ class SquinMeasureToStim(RewriteRule):
|
|
|
38
36
|
return RewriteResult()
|
|
39
37
|
|
|
40
38
|
prob_noise_stmt = py.constant.Constant(0.0)
|
|
41
|
-
stim_measure_stmt =
|
|
39
|
+
stim_measure_stmt = collapse.MZ(
|
|
42
40
|
p=prob_noise_stmt.result,
|
|
43
41
|
targets=qubit_idx_ssas,
|
|
44
42
|
)
|
|
@@ -47,29 +45,6 @@ class SquinMeasureToStim(RewriteRule):
|
|
|
47
45
|
|
|
48
46
|
return RewriteResult(has_done_something=True)
|
|
49
47
|
|
|
50
|
-
def rewrite_MeasureAndReset(
|
|
51
|
-
self, meas_and_reset_stmt: qubit.MeasureAndReset | wire.MeasureAndReset
|
|
52
|
-
) -> RewriteResult:
|
|
53
|
-
if not is_measure_result_used(meas_and_reset_stmt):
|
|
54
|
-
return RewriteResult()
|
|
55
|
-
|
|
56
|
-
qubit_idx_ssas = self.get_qubit_idx_ssas(meas_and_reset_stmt)
|
|
57
|
-
|
|
58
|
-
if qubit_idx_ssas is None:
|
|
59
|
-
return RewriteResult()
|
|
60
|
-
|
|
61
|
-
error_p_stmt = py.Constant(0.0)
|
|
62
|
-
stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result)
|
|
63
|
-
stim_rz_stmt = stim.collapse.RZ(
|
|
64
|
-
targets=qubit_idx_ssas,
|
|
65
|
-
)
|
|
66
|
-
|
|
67
|
-
error_p_stmt.insert_before(meas_and_reset_stmt)
|
|
68
|
-
stim_mz_stmt.insert_before(meas_and_reset_stmt)
|
|
69
|
-
meas_and_reset_stmt.replace_by(stim_rz_stmt)
|
|
70
|
-
|
|
71
|
-
return RewriteResult(has_done_something=True)
|
|
72
|
-
|
|
73
48
|
def get_qubit_idx_ssas(
|
|
74
49
|
self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
|
|
75
50
|
) -> tuple[ir.SSAValue, ...] | None:
|
|
@@ -3,7 +3,7 @@ from kirin.dialects import py
|
|
|
3
3
|
from kirin.rewrite.abc import RewriteResult
|
|
4
4
|
|
|
5
5
|
from bloqade.squin import op, wire, qubit
|
|
6
|
-
from bloqade.stim.dialects import gate
|
|
6
|
+
from bloqade.stim.dialects import gate, collapse
|
|
7
7
|
from bloqade.analysis.address import AddressWire, AddressQubit, AddressTuple
|
|
8
8
|
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
|
|
9
9
|
|
|
@@ -14,6 +14,7 @@ SQUIN_STIM_GATE_MAPPING = {
|
|
|
14
14
|
op.stmts.H: gate.H,
|
|
15
15
|
op.stmts.S: gate.S,
|
|
16
16
|
op.stmts.Identity: gate.Identity,
|
|
17
|
+
op.stmts.Reset: collapse.RZ,
|
|
17
18
|
}
|
|
18
19
|
|
|
19
20
|
|
|
@@ -144,13 +145,7 @@ def rewrite_Control(
|
|
|
144
145
|
|
|
145
146
|
|
|
146
147
|
def is_measure_result_used(
|
|
147
|
-
stmt:
|
|
148
|
-
qubit.MeasureAndReset
|
|
149
|
-
| qubit.MeasureQubit
|
|
150
|
-
| qubit.MeasureQubitList
|
|
151
|
-
| wire.MeasureAndReset
|
|
152
|
-
| wire.Measure
|
|
153
|
-
),
|
|
148
|
+
stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure,
|
|
154
149
|
) -> bool:
|
|
155
150
|
"""
|
|
156
151
|
Check if the result of a measure statement is used in the program.
|
|
@@ -1,13 +1,10 @@
|
|
|
1
1
|
from kirin import ir
|
|
2
2
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
3
3
|
|
|
4
|
-
from bloqade import stim
|
|
5
4
|
from bloqade.squin import op, wire
|
|
6
|
-
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
|
|
7
5
|
from bloqade.squin.rewrite.stim_rewrite_util import (
|
|
8
6
|
SQUIN_STIM_GATE_MAPPING,
|
|
9
7
|
rewrite_Control,
|
|
10
|
-
insert_qubit_idx_from_address,
|
|
11
8
|
insert_qubit_idx_from_wire_ssa,
|
|
12
9
|
)
|
|
13
10
|
|
|
@@ -18,8 +15,6 @@ class SquinWireToStim(RewriteRule):
|
|
|
18
15
|
match node:
|
|
19
16
|
case wire.Apply() | wire.Broadcast():
|
|
20
17
|
return self.rewrite_Apply_and_Broadcast(node)
|
|
21
|
-
case wire.Reset():
|
|
22
|
-
return self.rewrite_Reset(node)
|
|
23
18
|
case _:
|
|
24
19
|
return RewriteResult()
|
|
25
20
|
|
|
@@ -55,19 +50,3 @@ class SquinWireToStim(RewriteRule):
|
|
|
55
50
|
stmt.replace_by(stim_1q_stmt)
|
|
56
51
|
|
|
57
52
|
return RewriteResult(has_done_something=True)
|
|
58
|
-
|
|
59
|
-
def rewrite_Reset(self, reset_stmt: wire.Reset) -> RewriteResult:
|
|
60
|
-
address_attr = reset_stmt.wire.hints.get("address")
|
|
61
|
-
if address_attr is None:
|
|
62
|
-
return RewriteResult()
|
|
63
|
-
assert isinstance(address_attr, AddressAttribute)
|
|
64
|
-
qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
65
|
-
address=address_attr, stmt_to_insert_before=reset_stmt
|
|
66
|
-
)
|
|
67
|
-
if qubit_idx_ssas is None:
|
|
68
|
-
return RewriteResult()
|
|
69
|
-
|
|
70
|
-
stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas)
|
|
71
|
-
reset_stmt.replace_by(stim_rz_stmt)
|
|
72
|
-
|
|
73
|
-
return RewriteResult(has_done_something=True)
|
bloqade/squin/wire.py
CHANGED
|
@@ -95,23 +95,18 @@ class Broadcast(ir.Statement):
|
|
|
95
95
|
class Measure(ir.Statement):
|
|
96
96
|
traits = frozenset({lowering.FromPythonCall(), WireTerminator()})
|
|
97
97
|
wire: ir.SSAValue = info.argument(WireType)
|
|
98
|
+
qubit: ir.SSAValue = info.argument(QubitType)
|
|
98
99
|
result: ir.ResultValue = info.result(types.Int)
|
|
99
100
|
|
|
100
101
|
|
|
101
102
|
@statement(dialect=dialect)
|
|
102
|
-
class
|
|
103
|
-
traits = frozenset({lowering.FromPythonCall()
|
|
104
|
-
|
|
103
|
+
class NonDestructiveMeasure(ir.Statement):
|
|
104
|
+
traits = frozenset({lowering.FromPythonCall()})
|
|
105
|
+
input_wire: ir.SSAValue = info.argument(WireType)
|
|
105
106
|
result: ir.ResultValue = info.result(types.Int)
|
|
106
107
|
out_wire: ir.ResultValue = info.result(WireType)
|
|
107
108
|
|
|
108
109
|
|
|
109
|
-
@statement(dialect=dialect)
|
|
110
|
-
class Reset(ir.Statement):
|
|
111
|
-
traits = frozenset({lowering.FromPythonCall(), WireTerminator()})
|
|
112
|
-
wire: ir.SSAValue = info.argument(WireType)
|
|
113
|
-
|
|
114
|
-
|
|
115
110
|
@wraps(Unwrap)
|
|
116
111
|
def unwrap(qubit: Qubit) -> Wire: ...
|
|
117
112
|
|
bloqade/stim/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from . import emit as emit, dialects as dialects
|
|
1
|
+
from . import emit as emit, parse as parse, dialects as dialects
|
|
2
2
|
from .groups import main as main
|
|
3
3
|
from ._wrappers import (
|
|
4
4
|
h as h,
|
|
@@ -34,6 +34,7 @@ from ._wrappers import (
|
|
|
34
34
|
depolarize1 as depolarize1,
|
|
35
35
|
depolarize2 as depolarize2,
|
|
36
36
|
pauli_string as pauli_string,
|
|
37
|
+
qubit_coords as qubit_coords,
|
|
37
38
|
pauli_channel1 as pauli_channel1,
|
|
38
39
|
pauli_channel2 as pauli_channel2,
|
|
39
40
|
observable_include as observable_include,
|
bloqade/stim/_wrappers.py
CHANGED
|
@@ -99,6 +99,10 @@ def pauli_string(
|
|
|
99
99
|
) -> auxiliary.PauliString: ...
|
|
100
100
|
|
|
101
101
|
|
|
102
|
+
@wraps(auxiliary.QubitCoordinates)
|
|
103
|
+
def qubit_coords(coord: tuple[Union[int, float], ...], target: int) -> None: ...
|
|
104
|
+
|
|
105
|
+
|
|
102
106
|
# dialect:: collapse
|
|
103
107
|
@wraps(collapse.MZ)
|
|
104
108
|
def mz(p: float, targets: tuple[int, ...]) -> None: ...
|
|
@@ -69,8 +69,10 @@ class EmitStimAuxMethods(MethodTable):
|
|
|
69
69
|
|
|
70
70
|
coord_str: str = ", ".join(coords)
|
|
71
71
|
target_str: str = " ".join(targets)
|
|
72
|
-
|
|
73
|
-
|
|
72
|
+
if len(coords):
|
|
73
|
+
emit.writeln(frame, f"DETECTOR({coord_str}) {target_str}")
|
|
74
|
+
else:
|
|
75
|
+
emit.writeln(frame, f"DETECTOR {target_str}")
|
|
74
76
|
return ()
|
|
75
77
|
|
|
76
78
|
@impl(stmts.ObservableInclude)
|
|
@@ -100,3 +102,16 @@ class EmitStimAuxMethods(MethodTable):
|
|
|
100
102
|
)
|
|
101
103
|
|
|
102
104
|
return (out,)
|
|
105
|
+
|
|
106
|
+
@impl(stmts.QubitCoordinates)
|
|
107
|
+
def qubit_coordinates(
|
|
108
|
+
self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.QubitCoordinates
|
|
109
|
+
):
|
|
110
|
+
|
|
111
|
+
coords: tuple[str, ...] = frame.get_values(stmt.coord)
|
|
112
|
+
target: str = frame.get(stmt.target)
|
|
113
|
+
|
|
114
|
+
coord_str: str = ", ".join(coords)
|
|
115
|
+
emit.writeln(frame, f"QUBIT_COORDS({coord_str}) {target}")
|
|
116
|
+
|
|
117
|
+
return ()
|
|
@@ -45,3 +45,11 @@ class NewPauliString(ir.Statement):
|
|
|
45
45
|
flipped: tuple[ir.SSAValue, ...] = info.argument(types.Bool)
|
|
46
46
|
targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
|
|
47
47
|
result: ir.ResultValue = info.result(type=PauliStringType)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@statement(dialect=dialect)
|
|
51
|
+
class QubitCoordinates(ir.Statement):
|
|
52
|
+
name = "qubit_coordinates"
|
|
53
|
+
traits = frozenset({lowering.FromPythonCall()})
|
|
54
|
+
coord: tuple[ir.SSAValue, ...] = info.argument(PyNum)
|
|
55
|
+
target: ir.SSAValue = info.argument(types.Int)
|
|
@@ -60,7 +60,9 @@ class EmitStimCollapseMethods(MethodTable):
|
|
|
60
60
|
self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.PPMeasurement
|
|
61
61
|
):
|
|
62
62
|
probability: str = frame.get(stmt.p)
|
|
63
|
-
targets: tuple[str, ...] =
|
|
63
|
+
targets: tuple[str, ...] = tuple(
|
|
64
|
+
targ.upper() for targ in frame.get_values(stmt.targets)
|
|
65
|
+
)
|
|
64
66
|
|
|
65
67
|
out = f"MPP({probability}) " + " ".join(targets)
|
|
66
68
|
emit.writeln(frame, out)
|
|
@@ -12,6 +12,7 @@ from .stmts.base import SingleQubitGate, ControlledTwoQubitGate
|
|
|
12
12
|
class EmitStimGateMethods(MethodTable):
|
|
13
13
|
|
|
14
14
|
gate_1q_map: dict[str, tuple[str, str]] = {
|
|
15
|
+
stmts.Identity.name: ("I", "I"),
|
|
15
16
|
stmts.X.name: ("X", "X"),
|
|
16
17
|
stmts.Y.name: ("Y", "Y"),
|
|
17
18
|
stmts.Z.name: ("Z", "Z"),
|
|
@@ -22,6 +23,7 @@ class EmitStimGateMethods(MethodTable):
|
|
|
22
23
|
stmts.SqrtZ.name: ("SQRT_Z", "SQRT_Z_DAG"),
|
|
23
24
|
}
|
|
24
25
|
|
|
26
|
+
@impl(stmts.Identity)
|
|
25
27
|
@impl(stmts.X)
|
|
26
28
|
@impl(stmts.Y)
|
|
27
29
|
@impl(stmts.Z)
|
|
@@ -80,8 +82,13 @@ class EmitStimGateMethods(MethodTable):
|
|
|
80
82
|
@impl(stmts.SPP)
|
|
81
83
|
def spp(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.SPP):
|
|
82
84
|
|
|
83
|
-
targets: tuple[str, ...] =
|
|
84
|
-
|
|
85
|
+
targets: tuple[str, ...] = tuple(
|
|
86
|
+
targ.upper() for targ in frame.get_values(stmt.targets)
|
|
87
|
+
)
|
|
88
|
+
if stmt.dagger:
|
|
89
|
+
res = "SPP_DAG " + " ".join(targets)
|
|
90
|
+
else:
|
|
91
|
+
res = "SPP " + " ".join(targets)
|
|
85
92
|
emit.writeln(frame, res)
|
|
86
93
|
|
|
87
94
|
return ()
|
|
@@ -44,7 +44,7 @@ class EmitStimNoiseMethods(MethodTable):
|
|
|
44
44
|
px: str = frame.get(stmt.px)
|
|
45
45
|
py: str = frame.get(stmt.py)
|
|
46
46
|
pz: str = frame.get(stmt.pz)
|
|
47
|
-
res = f"PAULI_CHANNEL_1({px},{py},{pz}) " + " ".join(targets)
|
|
47
|
+
res = f"PAULI_CHANNEL_1({px}, {py}, {pz}) " + " ".join(targets)
|
|
48
48
|
emit.writeln(frame, res)
|
|
49
49
|
|
|
50
50
|
return ()
|
|
@@ -64,3 +64,34 @@ class EmitStimNoiseMethods(MethodTable):
|
|
|
64
64
|
emit.writeln(frame, res)
|
|
65
65
|
|
|
66
66
|
return ()
|
|
67
|
+
|
|
68
|
+
@impl(stmts.TrivialError)
|
|
69
|
+
def non_stim_error(
|
|
70
|
+
self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.TrivialError
|
|
71
|
+
):
|
|
72
|
+
|
|
73
|
+
targets: tuple[str, ...] = frame.get_values(stmt.targets)
|
|
74
|
+
prob: tuple[str, ...] = frame.get_values(stmt.probs)
|
|
75
|
+
prob_str: str = ", ".join(prob)
|
|
76
|
+
|
|
77
|
+
res = f"I_ERROR[{stmt.name}]({prob_str}) " + " ".join(targets)
|
|
78
|
+
emit.writeln(frame, res)
|
|
79
|
+
|
|
80
|
+
return ()
|
|
81
|
+
|
|
82
|
+
@impl(stmts.TrivialCorrelatedError)
|
|
83
|
+
def non_stim_corr_error(
|
|
84
|
+
self,
|
|
85
|
+
emit: EmitStimMain,
|
|
86
|
+
frame: EmitStrFrame,
|
|
87
|
+
stmt: stmts.TrivialCorrelatedError,
|
|
88
|
+
):
|
|
89
|
+
|
|
90
|
+
targets: tuple[str, ...] = frame.get_values(stmt.targets)
|
|
91
|
+
prob: tuple[str, ...] = frame.get_values(stmt.probs)
|
|
92
|
+
prob_str: str = ", ".join(prob)
|
|
93
|
+
|
|
94
|
+
res = f"I_ERROR[{stmt.name}:{stmt.nonce}]({prob_str}) " + " ".join(targets)
|
|
95
|
+
emit.writeln(frame, res)
|
|
96
|
+
|
|
97
|
+
return ()
|
|
@@ -75,3 +75,32 @@ class ZError(ir.Statement):
|
|
|
75
75
|
traits = frozenset({lowering.FromPythonCall()})
|
|
76
76
|
p: ir.SSAValue = info.argument(types.Float)
|
|
77
77
|
targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@statement
|
|
81
|
+
class NonStimError(ir.Statement):
|
|
82
|
+
name = "NonStimError"
|
|
83
|
+
traits = frozenset({lowering.FromPythonCall()})
|
|
84
|
+
probs: tuple[ir.SSAValue, ...] = info.argument(types.Float)
|
|
85
|
+
targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@statement
|
|
89
|
+
class NonStimCorrelatedError(ir.Statement):
|
|
90
|
+
name = "NonStimCorrelatedError"
|
|
91
|
+
traits = frozenset({lowering.FromPythonCall()})
|
|
92
|
+
nonce: int = (
|
|
93
|
+
info.attribute()
|
|
94
|
+
) # Must be a unique value, otherwise stim might merge two correlated errors with equal probabilities
|
|
95
|
+
probs: tuple[ir.SSAValue, ...] = info.argument(types.Float)
|
|
96
|
+
targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@statement(dialect=dialect)
|
|
100
|
+
class TrivialCorrelatedError(NonStimCorrelatedError):
|
|
101
|
+
name = "TRIV_CORR_ERROR"
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@statement(dialect=dialect)
|
|
105
|
+
class TrivialError(NonStimError):
|
|
106
|
+
name = "TRIV_ERROR"
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .lowering import loads as loads, loadfile as loadfile
|