bloqade-circuit 0.7.13__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic. Click here for more details.
- bloqade/analysis/address/__init__.py +8 -4
- bloqade/analysis/address/analysis.py +119 -29
- bloqade/analysis/address/impls.py +290 -87
- bloqade/analysis/address/lattice.py +209 -24
- bloqade/analysis/fidelity/analysis.py +2 -2
- bloqade/analysis/measure_id/impls.py +3 -27
- bloqade/cirq_utils/__init__.py +3 -1
- bloqade/cirq_utils/emit/__init__.py +3 -0
- bloqade/cirq_utils/emit/base.py +243 -0
- bloqade/cirq_utils/emit/gate.py +104 -0
- bloqade/cirq_utils/emit/noise.py +90 -0
- bloqade/cirq_utils/emit/qubit.py +35 -0
- bloqade/cirq_utils/lowering.py +664 -0
- bloqade/native/__init__.py +0 -1
- bloqade/native/_prelude.py +3 -3
- bloqade/native/dialects/gate/__init__.py +2 -0
- bloqade/native/dialects/gate/_dialect.py +3 -0
- bloqade/native/dialects/{gates → gate}/_interface.py +5 -5
- bloqade/native/dialects/{gates → gate}/stmts.py +5 -5
- bloqade/native/stdlib/broadcast.py +19 -19
- bloqade/native/stdlib/simple.py +14 -13
- bloqade/native/upstream/__init__.py +5 -0
- bloqade/native/upstream/squin2native.py +136 -0
- bloqade/pyqrack/__init__.py +1 -2
- bloqade/pyqrack/device.py +6 -17
- bloqade/pyqrack/native.py +17 -17
- bloqade/pyqrack/reg.py +1 -6
- bloqade/pyqrack/squin/gate/__init__.py +1 -0
- bloqade/pyqrack/squin/gate/gate.py +136 -0
- bloqade/pyqrack/squin/noise/native.py +120 -54
- bloqade/pyqrack/squin/qubit.py +25 -41
- bloqade/pyqrack/target.py +2 -2
- bloqade/qasm2/dialects/core/address.py +21 -12
- bloqade/qasm2/dialects/noise/fidelity.py +2 -6
- bloqade/qasm2/dialects/noise/model.py +2 -1
- bloqade/qasm2/passes/parallel.py +3 -1
- bloqade/qasm2/rewrite/__init__.py +0 -1
- bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
- bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
- bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
- bloqade/qubit/__init__.py +12 -0
- bloqade/qubit/_dialect.py +3 -0
- bloqade/qubit/_interface.py +49 -0
- bloqade/qubit/_prelude.py +45 -0
- bloqade/qubit/analysis/__init__.py +1 -0
- bloqade/qubit/analysis/address_impl.py +40 -0
- bloqade/qubit/stdlib/__init__.py +2 -0
- bloqade/qubit/stdlib/_new.py +34 -0
- bloqade/qubit/stdlib/broadcast.py +62 -0
- bloqade/qubit/stdlib/simple.py +59 -0
- bloqade/qubit/stmts.py +60 -0
- bloqade/rewrite/passes/aggressive_unroll.py +2 -1
- bloqade/squin/__init__.py +44 -17
- bloqade/squin/analysis/__init__.py +0 -1
- bloqade/squin/analysis/schedule.py +2 -2
- bloqade/squin/gate/__init__.py +2 -0
- bloqade/squin/gate/_dialect.py +3 -0
- bloqade/squin/gate/_interface.py +98 -0
- bloqade/squin/gate/stmts.py +119 -0
- bloqade/squin/groups.py +4 -21
- bloqade/squin/noise/__init__.py +1 -9
- bloqade/squin/noise/_dialect.py +1 -1
- bloqade/squin/noise/_interface.py +45 -0
- bloqade/squin/noise/stmts.py +65 -29
- bloqade/squin/rewrite/U3_to_clifford.py +70 -51
- bloqade/squin/rewrite/__init__.py +0 -2
- bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
- bloqade/squin/rewrite/wrap_analysis.py +4 -35
- bloqade/squin/stdlib/broadcast/__init__.py +34 -0
- bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
- bloqade/squin/stdlib/broadcast/gate.py +260 -0
- bloqade/squin/stdlib/broadcast/noise.py +144 -0
- bloqade/squin/stdlib/simple/__init__.py +33 -0
- bloqade/squin/stdlib/simple/gate.py +242 -0
- bloqade/squin/stdlib/simple/noise.py +126 -0
- bloqade/stim/__init__.py +1 -0
- bloqade/stim/_wrappers.py +6 -0
- bloqade/stim/dialects/noise/emit.py +6 -1
- bloqade/stim/dialects/noise/stmts.py +5 -3
- bloqade/stim/emit/stim_str.py +2 -0
- bloqade/stim/parse/lowering.py +12 -17
- bloqade/stim/passes/__init__.py +0 -1
- bloqade/stim/passes/flatten.py +26 -0
- bloqade/stim/passes/simplify_ifs.py +6 -1
- bloqade/stim/passes/squin_to_stim.py +4 -70
- bloqade/stim/rewrite/__init__.py +0 -4
- bloqade/stim/rewrite/ifs_to_stim.py +23 -29
- bloqade/stim/rewrite/qubit_to_stim.py +96 -51
- bloqade/stim/rewrite/squin_measure.py +9 -18
- bloqade/stim/rewrite/squin_noise.py +132 -108
- bloqade/stim/rewrite/util.py +5 -204
- bloqade/types.py +10 -0
- {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/METADATA +2 -2
- {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/RECORD +96 -100
- bloqade/native/dialects/gates/__init__.py +0 -3
- bloqade/native/dialects/gates/_dialect.py +0 -3
- bloqade/pyqrack/squin/op.py +0 -180
- bloqade/pyqrack/squin/runtime.py +0 -543
- bloqade/pyqrack/squin/wire.py +0 -51
- bloqade/squin/_typeinfer.py +0 -20
- bloqade/squin/analysis/address_impl.py +0 -71
- bloqade/squin/analysis/nsites/__init__.py +0 -9
- bloqade/squin/analysis/nsites/analysis.py +0 -50
- bloqade/squin/analysis/nsites/impls.py +0 -99
- bloqade/squin/analysis/nsites/lattice.py +0 -49
- bloqade/squin/cirq/__init__.py +0 -306
- bloqade/squin/cirq/emit/emit_circuit.py +0 -129
- bloqade/squin/cirq/emit/noise.py +0 -49
- bloqade/squin/cirq/emit/op.py +0 -176
- bloqade/squin/cirq/emit/qubit.py +0 -58
- bloqade/squin/cirq/emit/runtime.py +0 -242
- bloqade/squin/cirq/lowering.py +0 -439
- bloqade/squin/lowering.py +0 -80
- bloqade/squin/noise/_wrapper.py +0 -36
- bloqade/squin/noise/rewrite.py +0 -129
- bloqade/squin/op/__init__.py +0 -41
- bloqade/squin/op/_dialect.py +0 -3
- bloqade/squin/op/_wrapper.py +0 -121
- bloqade/squin/op/number.py +0 -5
- bloqade/squin/op/rewrite.py +0 -46
- bloqade/squin/op/stdlib.py +0 -62
- bloqade/squin/op/stmts.py +0 -300
- bloqade/squin/op/traits.py +0 -43
- bloqade/squin/op/types.py +0 -128
- bloqade/squin/parallel.py +0 -200
- bloqade/squin/qubit.py +0 -194
- bloqade/squin/rewrite/canonicalize.py +0 -60
- bloqade/squin/rewrite/desugar.py +0 -102
- bloqade/squin/stdlib/channel.py +0 -86
- bloqade/squin/stdlib/gate.py +0 -201
- bloqade/squin/types.py +0 -8
- bloqade/squin/wire.py +0 -201
- bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
- bloqade/stim/rewrite/wire_to_stim.py +0 -57
- {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
from typing import Any, Literal, TypeVar
|
|
2
|
+
|
|
3
|
+
from kirin.dialects import ilist
|
|
4
|
+
|
|
5
|
+
from bloqade.types import Qubit
|
|
6
|
+
|
|
7
|
+
from .. import broadcast
|
|
8
|
+
from ...groups import kernel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@kernel
def depolarize(p: float, qubit: Qubit) -> None:
    """
    Apply a single-qubit depolarizing channel with total error probability `p`.

    With probability `p` one Pauli operator is drawn uniformly from X, Y, Z
    (so each occurs with probability `p / 3`) and applied to `qubit`; with
    probability `1 - p` the qubit is left unchanged.

    Args:
        p (float): Total probability that some Pauli operator is applied.
        qubit (Qubit): The qubit the channel acts on.
    """
    # The broadcast variant operates on lists of qubits; wrap the single qubit.
    targets = ilist.IList([qubit])
    broadcast.depolarize(p, targets)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Type variable bounded by `int`. No definition in this module references it.
# NOTE(review): possibly a leftover from an earlier signature; confirm no other
# module imports it from here before removing.
N = TypeVar("N", bound=int)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@kernel
def depolarize2(p: float, control: Qubit, target: Qubit) -> None:
    """
    Apply a symmetric two-qubit depolarizing channel to `(control, target)`.

    With total probability `p` one of the fifteen non-identity Pauli products

    `{IX, IY, IZ, XI, XX, XY, XZ, YI, YX, YY, YZ, ZI, ZX, ZY, ZZ}`

    is chosen uniformly (each with probability `p / 15`) and applied to the pair.
    With probability `1 - p` nothing happens.

    Args:
        p (float): Total probability that some Pauli product is applied.
        control (Qubit): The control qubit.
        target (Qubit): The target qubit.
    """
    # Build one-element lists so the broadcast variant sees exactly one pair.
    controls = ilist.IList([control])
    targets = ilist.IList([target])
    broadcast.depolarize2(p, controls, targets)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@kernel
def single_qubit_pauli_channel(px: float, py: float, pz: float, qubit: Qubit) -> None:
    """
    Apply a Pauli error channel with weights `px, py, pz`. No error is applied with
    probability `1 - (px + py + pz)`.

    This randomly selects one of the three Pauli operators X, Y, Z, weighted with the
    given probabilities in that order. The three weights should therefore sum to at
    most 1.

    Args:
        px (float): Probability of applying a Pauli X error.
        py (float): Probability of applying a Pauli Y error.
        pz (float): Probability of applying a Pauli Z error.
        qubit (Qubit): The qubit to which the noise channel is applied.
    """
    broadcast.single_qubit_pauli_channel(px, py, pz, ilist.IList([qubit]))
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@kernel
def two_qubit_pauli_channel(
    probabilities: ilist.IList[float, Literal[15]], control: Qubit, target: Qubit
) -> None:
    """
    Apply a weighted two-qubit Pauli product error to `(control, target)`.

    One of the Pauli products

    `{IX, IY, IZ, XI, XX, XY, XZ, YI, YX, YY, YZ, ZI, ZX, ZY, ZZ}`

    is selected according to the matching entry of `probabilities`. With
    probability `1 - sum(probabilities)` no error is applied.

    **NOTE**: The order of the given probabilities must match the order of the list of Pauli products above!

    Args:
        probabilities (IList[float, Literal[15]]): The 15 weights, one per Pauli
            product, in the order listed above.
        control (Qubit): The control qubit.
        target (Qubit): The target qubit.
    """
    # Wrap each qubit so the broadcast variant receives exactly one pair.
    controls = ilist.IList([control])
    targets = ilist.IList([target])
    broadcast.two_qubit_pauli_channel(probabilities, controls, targets)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@kernel
def qubit_loss(p: float, qubit: Qubit) -> None:
    """
    Apply a qubit-loss channel to a single qubit.

    With probability `p` the qubit is lost.

    Args:
        p (float): Probability of the atom being lost.
        qubit (Qubit): The qubit the loss channel acts on.
    """
    # The broadcast variant expects a list of qubits.
    targets = ilist.IList([qubit])
    broadcast.qubit_loss(p, targets)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@kernel
def correlated_qubit_loss(p: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """
    Apply a correlated loss channel to a group of qubits.

    With probability `p` every qubit in `qubits` is lost together.

    Args:
        p (float): Probability of the qubits being lost.
        qubits (IList[Qubit, Any]): The qubits that are lost as a single group.
    """
    # `qubits` is wrapped in a one-element list, so the broadcast call receives
    # exactly one group. NOTE(review): assumes broadcast.correlated_qubit_loss
    # takes a list of qubit groups; confirm against its signature.
    groups = ilist.IList([qubits])
    broadcast.correlated_qubit_loss(p, groups)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
# NOTE: actual stdlib that doesn't wrap statements starts here
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@kernel
def bit_flip(p: float, qubit: Qubit) -> None:
    """
    Apply a bit flip error channel to the qubit with probability `p`.

    This is the single-qubit Pauli channel with `px = p` and `py = pz = 0`: a
    Pauli X is applied with probability `p`, and no error with probability `1 - p`.

    Args:
        p (float): Probability of a bit flip error being applied.
        qubit (Qubit): The qubit to which the noise channel is applied.
    """
    # Use float zeros so the arguments match the `float` annotations of
    # `single_qubit_pauli_channel`. Int literals would presumably be inferred
    # as `int` in the kernel IR.
    single_qubit_pauli_channel(p, 0.0, 0.0, qubit)
|
bloqade/stim/__init__.py
CHANGED
bloqade/stim/_wrappers.py
CHANGED
|
@@ -194,3 +194,9 @@ def z_error(p: float, targets: tuple[int, ...]) -> None: ...
|
|
|
194
194
|
|
|
195
195
|
# Python-level stub for the `noise.QubitLoss` statement. The `...` body is not
# meant to do anything itself; presumably `wraps` maps calls inside a kernel
# onto the statement (confirm against `wraps`).
@wraps(noise.QubitLoss)
def qubit_loss(probs: tuple[float, ...], targets: tuple[int, ...]) -> None: ...
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
# Python-level stub for the `noise.CorrelatedQubitLoss` statement, for loss
# applied jointly to all `targets`. As with the other wrappers, the `...` body
# is a placeholder; presumably `wraps` maps calls onto the statement.
@wraps(noise.CorrelatedQubitLoss)
def correlated_qubit_loss(
    probs: tuple[float, ...], targets: tuple[int, ...]
) -> None: ...
|
|
@@ -81,6 +81,7 @@ class EmitStimNoiseMethods(MethodTable):
|
|
|
81
81
|
return ()
|
|
82
82
|
|
|
83
83
|
@impl(stmts.TrivialCorrelatedError)
|
|
84
|
+
@impl(stmts.CorrelatedQubitLoss)
|
|
84
85
|
def non_stim_corr_error(
|
|
85
86
|
self,
|
|
86
87
|
emit: EmitStimMain,
|
|
@@ -92,7 +93,11 @@ class EmitStimNoiseMethods(MethodTable):
|
|
|
92
93
|
prob: tuple[str, ...] = frame.get_values(stmt.probs)
|
|
93
94
|
prob_str: str = ", ".join(prob)
|
|
94
95
|
|
|
95
|
-
res =
|
|
96
|
+
res = (
|
|
97
|
+
f"I_ERROR[{stmt.name}:{emit.correlated_error_count}]({prob_str}) "
|
|
98
|
+
+ " ".join(targets)
|
|
99
|
+
)
|
|
100
|
+
emit.correlated_error_count += 1
|
|
96
101
|
emit.writeln(frame, res)
|
|
97
102
|
|
|
98
103
|
return ()
|
|
@@ -89,9 +89,6 @@ class NonStimError(ir.Statement):
|
|
|
89
89
|
class NonStimCorrelatedError(ir.Statement):
|
|
90
90
|
name = "NonStimCorrelatedError"
|
|
91
91
|
traits = frozenset({lowering.FromPythonCall()})
|
|
92
|
-
nonce: int = (
|
|
93
|
-
info.attribute()
|
|
94
|
-
) # Must be a unique value, otherwise stim might merge two correlated errors with equal probabilities
|
|
95
92
|
probs: tuple[ir.SSAValue, ...] = info.argument(types.Float)
|
|
96
93
|
targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
|
|
97
94
|
|
|
@@ -109,3 +106,8 @@ class TrivialError(NonStimError):
|
|
|
109
106
|
@statement(dialect=dialect)
class QubitLoss(NonStimError):
    # Non-stim qubit loss error; inherits its operands from NonStimError.
    # NOTE(review): `name` is presumably the tag used when this statement is
    # written to / parsed from stim text; confirm against the emitter.
    name = "loss"
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@statement(dialect=dialect)
class CorrelatedQubitLoss(NonStimCorrelatedError):
    # Correlated loss over all targets; `probs` and `targets` come from
    # NonStimCorrelatedError. The stim emitter formats this statement as
    # `I_ERROR[correlated_loss:<id>](<probs>) <targets>`, where <id> is a
    # per-emitter counter that stops stim from merging two correlated errors
    # that happen to have equal probabilities.
    name = "correlated_loss"
|
bloqade/stim/emit/stim_str.py
CHANGED
|
@@ -20,11 +20,13 @@ class EmitStimMain(EmitStr):
|
|
|
20
20
|
keys = ["emit.stim"]
|
|
21
21
|
dialects: ir.DialectGroup = field(default_factory=_default_dialect_group)
|
|
22
22
|
file: StringIO = field(default_factory=StringIO)
|
|
23
|
+
correlation_identifier_offset: int = 0
|
|
23
24
|
|
|
24
25
|
def initialize(self):
    """Reset the emitter's output buffer and correlated-error counter; return `self`."""
    super().initialize()
    # Discard any previously emitted text and rewind the write position.
    self.file.truncate(0)
    self.file.seek(0)
    # Each correlated I_ERROR is tagged `<name>:<id>` with a distinct id so stim
    # does not merge correlated errors with equal probabilities. Restart the id
    # counter at the configured offset (presumably so ids stay unique when
    # several programs' outputs are combined).
    # NOTE(review): `correlated_error_count` is not declared among the visible
    # dataclass fields; type checkers may flag this dynamic attribute.
    self.correlated_error_count = self.correlation_identifier_offset
    return self
|
|
29
31
|
|
|
30
32
|
def eval_stmt_fallback(
|
bloqade/stim/parse/lowering.py
CHANGED
|
@@ -627,10 +627,13 @@ class Stim(lowering.LoweringABC[Node]):
|
|
|
627
627
|
# Parse tag
|
|
628
628
|
tag_parts = node.tag.split(";", maxsplit=1)[0].split(":", maxsplit=1)
|
|
629
629
|
nonstim_name = tag_parts[0]
|
|
630
|
-
nonce = 0
|
|
631
630
|
if len(tag_parts) == 2:
|
|
631
|
+
# This should be a correlated error of the form, e.g.,
|
|
632
|
+
# I_ERROR[correlated_loss:<identifier>](0.01) 0 1 2
|
|
633
|
+
# The identifier is a unique number that prevents stim from merging
|
|
634
|
+
# correlated errors. We discard the identifier, but verify it is an integer.
|
|
632
635
|
try:
|
|
633
|
-
|
|
636
|
+
_ = int(tag_parts[1])
|
|
634
637
|
except ValueError:
|
|
635
638
|
# String was not an integer
|
|
636
639
|
if self.error_unknown_nonstim:
|
|
@@ -643,22 +646,14 @@ class Stim(lowering.LoweringABC[Node]):
|
|
|
643
646
|
f"Unknown non-stim statement name: {nonstim_name!r} ({node!r})"
|
|
644
647
|
)
|
|
645
648
|
statement_cls = self.nonstim_noise_ops.get(nonstim_name)
|
|
649
|
+
stmt = None
|
|
646
650
|
if statement_cls is not None:
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
),
|
|
654
|
-
)
|
|
655
|
-
else:
|
|
656
|
-
stmt = statement_cls(
|
|
657
|
-
probs=self._get_float_args_ssa(state, node.gate_args_copy()),
|
|
658
|
-
targets=self._get_multiple_qubit_or_rec_ssa(
|
|
659
|
-
state, node, node.targets_copy()
|
|
660
|
-
),
|
|
661
|
-
)
|
|
651
|
+
stmt = statement_cls(
|
|
652
|
+
probs=self._get_float_args_ssa(state, node.gate_args_copy()),
|
|
653
|
+
targets=self._get_multiple_qubit_or_rec_ssa(
|
|
654
|
+
state, node, node.targets_copy()
|
|
655
|
+
),
|
|
656
|
+
)
|
|
662
657
|
return stmt
|
|
663
658
|
|
|
664
659
|
def visit_CircuitInstruction(
|
bloqade/stim/passes/__init__.py
CHANGED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Taken from Phillip Weinberg's bloqade-shuttle implementation
|
|
2
|
+
from dataclasses import field, dataclass
|
|
3
|
+
|
|
4
|
+
from kirin import ir
|
|
5
|
+
from kirin.passes import Pass
|
|
6
|
+
from kirin.rewrite.abc import RewriteResult
|
|
7
|
+
|
|
8
|
+
from bloqade.qasm2.passes.fold import AggressiveUnroll
|
|
9
|
+
from bloqade.stim.passes.simplify_ifs import StimSimplifyIfs
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class Flatten(Pass):
    """Flatten a method's control flow so it can be lowered to stim.

    One run first simplifies `if` statements with `StimSimplifyIfs` and then
    unrolls with `AggressiveUnroll`. Callers (e.g. `SquinToStimPass`) run it
    to a fixpoint via `Pass.fixpoint`.
    """

    # Sub-passes are built in __post_init__ from this pass's own settings.
    unroll: AggressiveUnroll = field(init=False)
    simplify_if: StimSimplifyIfs = field(init=False)

    def __post_init__(self):
        # Share dialect group and error-handling mode with both sub-passes.
        dialect_group, quiet = self.dialects, self.no_raise
        self.unroll = AggressiveUnroll(dialect_group, no_raise=quiet)
        self.simplify_if = StimSimplifyIfs(dialect_group, no_raise=quiet)

    def unsafe_run(self, mt: ir.Method) -> RewriteResult:
        # If-simplification must precede unrolling; results are accumulated.
        after_ifs = self.simplify_if(mt).join(RewriteResult())
        return self.unroll(mt).join(after_ifs)
|
|
@@ -7,8 +7,10 @@ from kirin.rewrite import (
|
|
|
7
7
|
Chain,
|
|
8
8
|
Fixpoint,
|
|
9
9
|
ConstantFold,
|
|
10
|
+
DeadCodeElimination,
|
|
10
11
|
CommonSubexpressionElimination,
|
|
11
12
|
)
|
|
13
|
+
from kirin.dialects.scf.trim import UnusedYield
|
|
12
14
|
from kirin.dialects.ilist.passes import ConstList2IList
|
|
13
15
|
|
|
14
16
|
from ..rewrite.ifs_to_stim import StimLiftThenBody, StimSplitIfStmts
|
|
@@ -20,7 +22,10 @@ class StimSimplifyIfs(Pass):
|
|
|
20
22
|
def unsafe_run(self, mt: ir.Method):
|
|
21
23
|
|
|
22
24
|
result = Chain(
|
|
23
|
-
|
|
25
|
+
Walk(UnusedYield()),
|
|
26
|
+
Walk(StimLiftThenBody()),
|
|
27
|
+
# remove yields (if possible), then lift out as much stuff as possible
|
|
28
|
+
Walk(DeadCodeElimination()),
|
|
24
29
|
Walk(StimSplitIfStmts()),
|
|
25
30
|
).rewrite(mt.code)
|
|
26
31
|
|
|
@@ -1,29 +1,21 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
2
|
|
|
3
|
-
from kirin.passes import Fold, TypeInfer
|
|
4
3
|
from kirin.rewrite import (
|
|
5
4
|
Walk,
|
|
6
5
|
Chain,
|
|
7
6
|
Fixpoint,
|
|
8
|
-
CFGCompactify,
|
|
9
7
|
DeadCodeElimination,
|
|
10
8
|
CommonSubexpressionElimination,
|
|
11
9
|
)
|
|
12
|
-
from kirin.dialects import ilist
|
|
13
10
|
from kirin.ir.method import Method
|
|
14
11
|
from kirin.passes.abc import Pass
|
|
15
12
|
from kirin.rewrite.abc import RewriteResult
|
|
16
|
-
from kirin.passes.inline import InlinePass
|
|
17
|
-
from kirin.rewrite.alias import InlineAlias
|
|
18
|
-
from kirin.passes.aggressive import UnrollScf
|
|
19
13
|
|
|
20
14
|
from bloqade.stim.rewrite import (
|
|
21
|
-
SquinWireToStim,
|
|
22
15
|
PyConstantToStim,
|
|
23
16
|
SquinNoiseToStim,
|
|
24
17
|
SquinQubitToStim,
|
|
25
18
|
SquinMeasureToStim,
|
|
26
|
-
SquinWireIdentityElimination,
|
|
27
19
|
)
|
|
28
20
|
from bloqade.squin.rewrite import (
|
|
29
21
|
SquinU3ToClifford,
|
|
@@ -33,9 +25,8 @@ from bloqade.squin.rewrite import (
|
|
|
33
25
|
from bloqade.rewrite.passes import CanonicalizeIList
|
|
34
26
|
from bloqade.analysis.address import AddressAnalysis
|
|
35
27
|
from bloqade.analysis.measure_id import MeasurementIDAnalysis
|
|
36
|
-
from bloqade.
|
|
28
|
+
from bloqade.stim.passes.flatten import Flatten
|
|
37
29
|
|
|
38
|
-
from .simplify_ifs import StimSimplifyIfs
|
|
39
30
|
from ..rewrite.ifs_to_stim import IfToStim
|
|
40
31
|
|
|
41
32
|
|
|
@@ -45,63 +36,8 @@ class SquinToStimPass(Pass):
|
|
|
45
36
|
def unsafe_run(self, mt: Method) -> RewriteResult:
|
|
46
37
|
|
|
47
38
|
# inline aggressively:
|
|
48
|
-
rewrite_result =
|
|
49
|
-
|
|
50
|
-
).unsafe_run(mt)
|
|
51
|
-
|
|
52
|
-
rewrite_result = Walk(ilist.rewrite.HintLen()).rewrite(mt.code)
|
|
53
|
-
rewrite_result = Fold(self.dialects).unsafe_run(mt).join(rewrite_result)
|
|
54
|
-
|
|
55
|
-
rewrite_result = (
|
|
56
|
-
UnrollScf(dialects=mt.dialects, no_raise=self.no_raise)
|
|
57
|
-
.fixpoint(mt)
|
|
58
|
-
.join(rewrite_result)
|
|
59
|
-
)
|
|
60
|
-
|
|
61
|
-
rewrite_result = (
|
|
62
|
-
Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(rewrite_result)
|
|
63
|
-
)
|
|
64
|
-
|
|
65
|
-
rewrite_result = Walk(InlineAlias()).rewrite(mt.code).join(rewrite_result)
|
|
66
|
-
|
|
67
|
-
rewrite_result = (
|
|
68
|
-
StimSimplifyIfs(mt.dialects, no_raise=self.no_raise)
|
|
69
|
-
.unsafe_run(mt)
|
|
70
|
-
.join(rewrite_result)
|
|
71
|
-
)
|
|
72
|
-
|
|
73
|
-
rewrite_result = (
|
|
74
|
-
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
|
|
75
|
-
.rewrite(mt.code)
|
|
76
|
-
.join(rewrite_result)
|
|
77
|
-
)
|
|
78
|
-
rewrite_result = Fold(mt.dialects, no_raise=self.no_raise)(mt)
|
|
79
|
-
|
|
80
|
-
rewrite_result = (
|
|
81
|
-
UnrollScf(mt.dialects, no_raise=self.no_raise)
|
|
82
|
-
.fixpoint(mt)
|
|
83
|
-
.join(rewrite_result)
|
|
84
|
-
)
|
|
85
|
-
|
|
86
|
-
rewrite_result = (
|
|
87
|
-
CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
|
|
88
|
-
.unsafe_run(mt)
|
|
89
|
-
.join(rewrite_result)
|
|
90
|
-
)
|
|
91
|
-
|
|
92
|
-
rewrite_result = TypeInfer(
|
|
93
|
-
dialects=mt.dialects, no_raise=self.no_raise
|
|
94
|
-
).unsafe_run(mt)
|
|
95
|
-
|
|
96
|
-
rewrite_result = (
|
|
97
|
-
Walk(
|
|
98
|
-
Chain(
|
|
99
|
-
ApplyDesugarRule(),
|
|
100
|
-
MeasureDesugarRule(),
|
|
101
|
-
)
|
|
102
|
-
)
|
|
103
|
-
.rewrite(mt.code)
|
|
104
|
-
.join(rewrite_result)
|
|
39
|
+
rewrite_result = Flatten(dialects=mt.dialects, no_raise=self.no_raise).fixpoint(
|
|
40
|
+
mt
|
|
105
41
|
)
|
|
106
42
|
|
|
107
43
|
# after this the program should be in a state where it is analyzable
|
|
@@ -145,8 +81,6 @@ class SquinToStimPass(Pass):
|
|
|
145
81
|
Chain(
|
|
146
82
|
SquinQubitToStim(),
|
|
147
83
|
SquinMeasureToStim(),
|
|
148
|
-
SquinWireToStim(),
|
|
149
|
-
SquinWireIdentityElimination(),
|
|
150
84
|
)
|
|
151
85
|
)
|
|
152
86
|
.rewrite(mt.code)
|
|
@@ -163,7 +97,7 @@ class SquinToStimPass(Pass):
|
|
|
163
97
|
rewrite_result = Walk(PyConstantToStim()).rewrite(mt.code).join(rewrite_result)
|
|
164
98
|
|
|
165
99
|
# clear up leftover stmts
|
|
166
|
-
# - remove any squin.
|
|
100
|
+
# - remove any squin.qalloc that's left around
|
|
167
101
|
rewrite_result = (
|
|
168
102
|
Fixpoint(
|
|
169
103
|
Walk(
|
bloqade/stim/rewrite/__init__.py
CHANGED
|
@@ -1,9 +1,5 @@
|
|
|
1
1
|
from .ifs_to_stim import IfToStim as IfToStim
|
|
2
2
|
from .squin_noise import SquinNoiseToStim as SquinNoiseToStim
|
|
3
|
-
from .wire_to_stim import SquinWireToStim as SquinWireToStim
|
|
4
3
|
from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
|
|
5
4
|
from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
|
|
6
5
|
from .py_constant_to_stim import PyConstantToStim as PyConstantToStim
|
|
7
|
-
from .wire_identity_elimination import (
|
|
8
|
-
SquinWireIdentityElimination as SquinWireIdentityElimination,
|
|
9
|
-
)
|
|
@@ -4,13 +4,13 @@ from kirin import ir
|
|
|
4
4
|
from kirin.dialects import py, scf, func
|
|
5
5
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
6
6
|
|
|
7
|
-
from bloqade.squin import
|
|
7
|
+
from bloqade.squin import gate
|
|
8
8
|
from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts
|
|
9
9
|
from bloqade.squin.rewrite import AddressAttribute
|
|
10
10
|
from bloqade.stim.rewrite.util import (
|
|
11
|
-
SQUIN_STIM_CONTROL_GATE_MAPPING,
|
|
12
11
|
insert_qubit_idx_from_address,
|
|
13
12
|
)
|
|
13
|
+
from bloqade.stim.dialects.gate import CX as stim_CX, CY as stim_CY, CZ as stim_CZ
|
|
14
14
|
from bloqade.analysis.measure_id import MeasureIDFrame
|
|
15
15
|
from bloqade.stim.dialects.auxiliary import GetRecord
|
|
16
16
|
from bloqade.analysis.measure_id.lattice import (
|
|
@@ -58,8 +58,7 @@ class IfElseSimplification:
|
|
|
58
58
|
"""Check if the IfElse statement has an else body."""
|
|
59
59
|
if stmt.else_body.blocks and not (
|
|
60
60
|
len(stmt.else_body.blocks[0].stmts) == 1
|
|
61
|
-
and isinstance(
|
|
62
|
-
and not else_term.values # empty yield
|
|
61
|
+
and isinstance(stmt.else_body.blocks[0].last_stmt, scf.Yield)
|
|
63
62
|
):
|
|
64
63
|
return True
|
|
65
64
|
|
|
@@ -67,12 +66,13 @@ class IfElseSimplification:
|
|
|
67
66
|
|
|
68
67
|
|
|
69
68
|
DontLiftType = (
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
69
|
+
gate.stmts.SingleQubitGate,
|
|
70
|
+
gate.stmts.RotationGate,
|
|
71
|
+
gate.stmts.ControlledGate,
|
|
73
72
|
func.Return,
|
|
74
73
|
func.Invoke,
|
|
75
74
|
scf.IfElse,
|
|
75
|
+
scf.Yield,
|
|
76
76
|
)
|
|
77
77
|
|
|
78
78
|
|
|
@@ -99,16 +99,16 @@ class StimSplitIfStmts(IfElseSimplification, SplitIfStmts):
|
|
|
99
99
|
Given an IfElse with multiple valid statements in the then-body:
|
|
100
100
|
|
|
101
101
|
if measure_result:
|
|
102
|
-
squin.
|
|
103
|
-
squin.
|
|
102
|
+
squin.x(q0)
|
|
103
|
+
squin.y(q1)
|
|
104
104
|
|
|
105
105
|
this should be rewritten to:
|
|
106
106
|
|
|
107
107
|
if measure_result:
|
|
108
|
-
squin.
|
|
108
|
+
squin.x(q0)
|
|
109
109
|
|
|
110
110
|
if measure_result:
|
|
111
|
-
squin.
|
|
111
|
+
squin.y(q1)
|
|
112
112
|
"""
|
|
113
113
|
|
|
114
114
|
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
@@ -139,24 +139,23 @@ class IfToStim(IfElseSimplification, RewriteRule):
|
|
|
139
139
|
|
|
140
140
|
def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
|
|
141
141
|
|
|
142
|
+
# Check the condition is a singular MeasurementIdBool
|
|
142
143
|
if not isinstance(self.measure_frame.entries[stmt.cond], MeasureIdBool):
|
|
143
144
|
return RewriteResult()
|
|
144
145
|
|
|
145
|
-
#
|
|
146
|
-
#
|
|
147
|
-
# Can reuse logic from SplitIf
|
|
146
|
+
# Reusing code from SplitIf,
|
|
147
|
+
# there should only be one statement in the body and it should be a pauli X, Y, or Z
|
|
148
148
|
*stmts, _ = stmt.then_body.stmts()
|
|
149
|
-
if len(stmts) != 1
|
|
149
|
+
if len(stmts) != 1:
|
|
150
150
|
return RewriteResult()
|
|
151
151
|
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
if stim_gate is None:
|
|
152
|
+
if isinstance(stmts[0], gate.stmts.X):
|
|
153
|
+
stim_gate = stim_CX
|
|
154
|
+
elif isinstance(stmts[0], gate.stmts.Y):
|
|
155
|
+
stim_gate = stim_CY
|
|
156
|
+
elif isinstance(stmts[0], gate.stmts.Z):
|
|
157
|
+
stim_gate = stim_CZ
|
|
158
|
+
else:
|
|
160
159
|
return RewriteResult()
|
|
161
160
|
|
|
162
161
|
# get necessary measurement ID type from analysis
|
|
@@ -169,12 +168,7 @@ class IfToStim(IfElseSimplification, RewriteRule):
|
|
|
169
168
|
)
|
|
170
169
|
get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841
|
|
171
170
|
|
|
172
|
-
|
|
173
|
-
if len(apply_or_broadcast.qubits) != 1:
|
|
174
|
-
# NOTE: this is actually invalid since we are dealing with single-qubit operators here
|
|
175
|
-
return RewriteResult()
|
|
176
|
-
|
|
177
|
-
address_attr = apply_or_broadcast.qubits[0].hints.get("address")
|
|
171
|
+
address_attr = stmts[0].qubits.hints.get("address")
|
|
178
172
|
|
|
179
173
|
if address_attr is None:
|
|
180
174
|
return RewriteResult()
|