bloqade-circuit 0.6.2__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bloqade/analysis/address/__init__.py +8 -4
- bloqade/analysis/address/analysis.py +123 -33
- bloqade/analysis/address/impls.py +293 -90
- bloqade/analysis/address/lattice.py +209 -24
- bloqade/analysis/fidelity/analysis.py +11 -23
- bloqade/analysis/measure_id/__init__.py +4 -1
- bloqade/analysis/measure_id/analysis.py +29 -20
- bloqade/analysis/measure_id/impls.py +72 -31
- bloqade/annotate/__init__.py +6 -0
- bloqade/annotate/_dialect.py +3 -0
- bloqade/annotate/_interface.py +22 -0
- bloqade/annotate/stmts.py +29 -0
- bloqade/annotate/types.py +13 -0
- bloqade/cirq_utils/__init__.py +4 -2
- bloqade/cirq_utils/emit/__init__.py +3 -0
- bloqade/cirq_utils/emit/base.py +246 -0
- bloqade/cirq_utils/emit/gate.py +104 -0
- bloqade/cirq_utils/emit/noise.py +90 -0
- bloqade/cirq_utils/emit/qubit.py +35 -0
- bloqade/cirq_utils/lowering.py +660 -0
- bloqade/cirq_utils/noise/__init__.py +0 -2
- bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
- bloqade/cirq_utils/noise/model.py +151 -191
- bloqade/cirq_utils/noise/transform.py +2 -2
- bloqade/cirq_utils/parallelize.py +9 -6
- bloqade/gemini/__init__.py +1 -0
- bloqade/gemini/analysis/__init__.py +3 -0
- bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
- bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
- bloqade/gemini/analysis/logical_validation/impls.py +101 -0
- bloqade/gemini/groups.py +67 -0
- bloqade/native/__init__.py +23 -0
- bloqade/native/_prelude.py +45 -0
- bloqade/native/dialects/__init__.py +0 -0
- bloqade/native/dialects/gate/__init__.py +2 -0
- bloqade/native/dialects/gate/_dialect.py +3 -0
- bloqade/native/dialects/gate/_interface.py +32 -0
- bloqade/native/dialects/gate/stmts.py +31 -0
- bloqade/native/stdlib/__init__.py +0 -0
- bloqade/native/stdlib/broadcast.py +246 -0
- bloqade/native/stdlib/simple.py +220 -0
- bloqade/native/upstream/__init__.py +4 -0
- bloqade/native/upstream/squin2native.py +79 -0
- bloqade/pyqrack/__init__.py +2 -2
- bloqade/pyqrack/base.py +7 -1
- bloqade/pyqrack/device.py +190 -4
- bloqade/pyqrack/native.py +49 -0
- bloqade/pyqrack/reg.py +6 -6
- bloqade/pyqrack/squin/gate/__init__.py +1 -0
- bloqade/pyqrack/squin/gate/gate.py +136 -0
- bloqade/pyqrack/squin/noise/native.py +120 -54
- bloqade/pyqrack/squin/qubit.py +39 -36
- bloqade/pyqrack/target.py +5 -4
- bloqade/pyqrack/task.py +114 -7
- bloqade/qasm2/_qasm_loading.py +3 -3
- bloqade/qasm2/dialects/core/address.py +21 -12
- bloqade/qasm2/dialects/expr/_emit.py +19 -8
- bloqade/qasm2/dialects/expr/stmts.py +7 -7
- bloqade/qasm2/dialects/noise/fidelity.py +4 -8
- bloqade/qasm2/dialects/noise/model.py +2 -1
- bloqade/qasm2/emit/base.py +16 -11
- bloqade/qasm2/emit/gate.py +11 -8
- bloqade/qasm2/emit/main.py +103 -3
- bloqade/qasm2/emit/target.py +9 -5
- bloqade/qasm2/groups.py +3 -2
- bloqade/qasm2/parse/lowering.py +0 -1
- bloqade/qasm2/passes/fold.py +14 -73
- bloqade/qasm2/passes/glob.py +2 -2
- bloqade/qasm2/passes/noise.py +1 -1
- bloqade/qasm2/passes/parallel.py +7 -5
- bloqade/qasm2/rewrite/__init__.py +0 -1
- bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
- bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
- bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
- bloqade/qasm2/rewrite/register.py +2 -2
- bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
- bloqade/qbraid/lowering.py +1 -0
- bloqade/qbraid/schema.py +2 -2
- bloqade/qubit/__init__.py +12 -0
- bloqade/qubit/_dialect.py +3 -0
- bloqade/qubit/_interface.py +49 -0
- bloqade/qubit/_prelude.py +45 -0
- bloqade/qubit/analysis/__init__.py +1 -0
- bloqade/qubit/analysis/address_impl.py +40 -0
- bloqade/qubit/stdlib/__init__.py +2 -0
- bloqade/qubit/stdlib/_new.py +34 -0
- bloqade/qubit/stdlib/broadcast.py +62 -0
- bloqade/qubit/stdlib/simple.py +59 -0
- bloqade/qubit/stmts.py +60 -0
- bloqade/rewrite/passes/__init__.py +6 -0
- bloqade/rewrite/passes/aggressive_unroll.py +103 -0
- bloqade/rewrite/passes/callgraph.py +116 -0
- bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
- bloqade/rewrite/rules/split_ifs.py +18 -1
- bloqade/squin/__init__.py +47 -14
- bloqade/squin/analysis/__init__.py +0 -1
- bloqade/squin/analysis/schedule.py +10 -11
- bloqade/squin/gate/__init__.py +2 -0
- bloqade/squin/gate/_dialect.py +3 -0
- bloqade/squin/gate/_interface.py +98 -0
- bloqade/squin/gate/stmts.py +125 -0
- bloqade/squin/groups.py +5 -22
- bloqade/squin/noise/__init__.py +1 -10
- bloqade/squin/noise/_dialect.py +1 -1
- bloqade/squin/noise/_interface.py +45 -0
- bloqade/squin/noise/stmts.py +66 -28
- bloqade/squin/rewrite/U3_to_clifford.py +70 -51
- bloqade/squin/rewrite/__init__.py +0 -2
- bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
- bloqade/squin/rewrite/wrap_analysis.py +4 -35
- bloqade/squin/stdlib/__init__.py +0 -0
- bloqade/squin/stdlib/broadcast/__init__.py +34 -0
- bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
- bloqade/squin/stdlib/broadcast/gate.py +260 -0
- bloqade/squin/stdlib/broadcast/noise.py +144 -0
- bloqade/squin/stdlib/simple/__init__.py +33 -0
- bloqade/squin/stdlib/simple/gate.py +242 -0
- bloqade/squin/stdlib/simple/noise.py +126 -0
- bloqade/stim/__init__.py +1 -0
- bloqade/stim/_wrappers.py +6 -0
- bloqade/stim/dialects/auxiliary/emit.py +19 -18
- bloqade/stim/dialects/collapse/emit_str.py +7 -8
- bloqade/stim/dialects/gate/emit.py +9 -10
- bloqade/stim/dialects/noise/emit.py +17 -13
- bloqade/stim/dialects/noise/stmts.py +5 -3
- bloqade/stim/emit/__init__.py +1 -0
- bloqade/stim/emit/impls.py +16 -0
- bloqade/stim/emit/stim_str.py +48 -31
- bloqade/stim/groups.py +12 -2
- bloqade/stim/parse/lowering.py +14 -17
- bloqade/stim/passes/__init__.py +3 -1
- bloqade/stim/passes/flatten.py +26 -0
- bloqade/stim/passes/simplify_ifs.py +16 -2
- bloqade/stim/passes/squin_to_stim.py +18 -60
- bloqade/stim/rewrite/__init__.py +3 -4
- bloqade/stim/rewrite/get_record_util.py +24 -0
- bloqade/stim/rewrite/ifs_to_stim.py +29 -31
- bloqade/stim/rewrite/qubit_to_stim.py +90 -41
- bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
- bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
- bloqade/stim/rewrite/squin_measure.py +11 -79
- bloqade/stim/rewrite/squin_noise.py +134 -108
- bloqade/stim/rewrite/util.py +5 -192
- bloqade/test_utils.py +1 -1
- bloqade/types.py +10 -0
- bloqade/validation/__init__.py +2 -0
- bloqade/validation/analysis/__init__.py +5 -0
- bloqade/validation/analysis/analysis.py +41 -0
- bloqade/validation/analysis/lattice.py +58 -0
- bloqade/validation/kernel_validation.py +77 -0
- {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
- bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
- bloqade/pyqrack/squin/op.py +0 -166
- bloqade/pyqrack/squin/runtime.py +0 -535
- bloqade/pyqrack/squin/wire.py +0 -51
- bloqade/rewrite/rules/flatten_ilist.py +0 -51
- bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
- bloqade/squin/_typeinfer.py +0 -20
- bloqade/squin/analysis/address_impl.py +0 -71
- bloqade/squin/analysis/nsites/__init__.py +0 -9
- bloqade/squin/analysis/nsites/analysis.py +0 -50
- bloqade/squin/analysis/nsites/impls.py +0 -92
- bloqade/squin/analysis/nsites/lattice.py +0 -49
- bloqade/squin/cirq/__init__.py +0 -265
- bloqade/squin/cirq/emit/emit_circuit.py +0 -109
- bloqade/squin/cirq/emit/noise.py +0 -49
- bloqade/squin/cirq/emit/op.py +0 -125
- bloqade/squin/cirq/emit/qubit.py +0 -60
- bloqade/squin/cirq/emit/runtime.py +0 -242
- bloqade/squin/cirq/lowering.py +0 -440
- bloqade/squin/lowering.py +0 -54
- bloqade/squin/noise/_wrapper.py +0 -40
- bloqade/squin/noise/rewrite.py +0 -111
- bloqade/squin/op/__init__.py +0 -41
- bloqade/squin/op/_dialect.py +0 -3
- bloqade/squin/op/_wrapper.py +0 -121
- bloqade/squin/op/number.py +0 -5
- bloqade/squin/op/rewrite.py +0 -46
- bloqade/squin/op/stdlib.py +0 -62
- bloqade/squin/op/stmts.py +0 -276
- bloqade/squin/op/traits.py +0 -43
- bloqade/squin/op/types.py +0 -26
- bloqade/squin/qubit.py +0 -184
- bloqade/squin/rewrite/canonicalize.py +0 -60
- bloqade/squin/rewrite/desugar.py +0 -124
- bloqade/squin/types.py +0 -8
- bloqade/squin/wire.py +0 -201
- bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
- bloqade/stim/rewrite/wire_to_stim.py +0 -57
- bloqade_circuit-0.6.2.dist-info/RECORD +0 -234
- {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
bloqade/stim/parse/lowering.py
CHANGED
|
@@ -98,6 +98,8 @@ def loads(
|
|
|
98
98
|
signature=func.Signature((), return_node.value.type),
|
|
99
99
|
body=body,
|
|
100
100
|
)
|
|
101
|
+
self_arg = ir.BlockArgument(body.blocks[0], 0) # Self argument
|
|
102
|
+
body.blocks[0]._args = (self_arg,)
|
|
101
103
|
return ir.Method(
|
|
102
104
|
mod=None,
|
|
103
105
|
py_func=None,
|
|
@@ -627,10 +629,13 @@ class Stim(lowering.LoweringABC[Node]):
|
|
|
627
629
|
# Parse tag
|
|
628
630
|
tag_parts = node.tag.split(";", maxsplit=1)[0].split(":", maxsplit=1)
|
|
629
631
|
nonstim_name = tag_parts[0]
|
|
630
|
-
nonce = 0
|
|
631
632
|
if len(tag_parts) == 2:
|
|
633
|
+
# This should be a correlated error of the form, e.g.,
|
|
634
|
+
# I_ERROR[correlated_loss:<identifier>](0.01) 0 1 2
|
|
635
|
+
# The identifier is a unique number that prevents stim from merging
|
|
636
|
+
# correlated errors. We discard the identifier, but verify it is an integer.
|
|
632
637
|
try:
|
|
633
|
-
|
|
638
|
+
_ = int(tag_parts[1])
|
|
634
639
|
except ValueError:
|
|
635
640
|
# String was not an integer
|
|
636
641
|
if self.error_unknown_nonstim:
|
|
@@ -643,22 +648,14 @@ class Stim(lowering.LoweringABC[Node]):
|
|
|
643
648
|
f"Unknown non-stim statement name: {nonstim_name!r} ({node!r})"
|
|
644
649
|
)
|
|
645
650
|
statement_cls = self.nonstim_noise_ops.get(nonstim_name)
|
|
651
|
+
stmt = None
|
|
646
652
|
if statement_cls is not None:
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
),
|
|
654
|
-
)
|
|
655
|
-
else:
|
|
656
|
-
stmt = statement_cls(
|
|
657
|
-
probs=self._get_float_args_ssa(state, node.gate_args_copy()),
|
|
658
|
-
targets=self._get_multiple_qubit_or_rec_ssa(
|
|
659
|
-
state, node, node.targets_copy()
|
|
660
|
-
),
|
|
661
|
-
)
|
|
653
|
+
stmt = statement_cls(
|
|
654
|
+
probs=self._get_float_args_ssa(state, node.gate_args_copy()),
|
|
655
|
+
targets=self._get_multiple_qubit_or_rec_ssa(
|
|
656
|
+
state, node, node.targets_copy()
|
|
657
|
+
),
|
|
658
|
+
)
|
|
662
659
|
return stmt
|
|
663
660
|
|
|
664
661
|
def visit_CircuitInstruction(
|
bloqade/stim/passes/__init__.py
CHANGED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Taken from Phillip Weinberg's bloqade-shuttle implementation
|
|
2
|
+
from dataclasses import field, dataclass
|
|
3
|
+
|
|
4
|
+
from kirin import ir
|
|
5
|
+
from kirin.passes import Pass
|
|
6
|
+
from kirin.rewrite.abc import RewriteResult
|
|
7
|
+
|
|
8
|
+
from bloqade.rewrite.passes import AggressiveUnroll
|
|
9
|
+
from bloqade.stim.passes.simplify_ifs import StimSimplifyIfs
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
|
|
13
|
+
class Flatten(Pass):
|
|
14
|
+
|
|
15
|
+
unroll: AggressiveUnroll = field(init=False)
|
|
16
|
+
simplify_if: StimSimplifyIfs = field(init=False)
|
|
17
|
+
|
|
18
|
+
def __post_init__(self):
|
|
19
|
+
self.unroll = AggressiveUnroll(self.dialects, no_raise=self.no_raise)
|
|
20
|
+
self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise)
|
|
21
|
+
|
|
22
|
+
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
|
|
23
|
+
rewrite_result = RewriteResult()
|
|
24
|
+
rewrite_result = self.simplify_if(mt).join(rewrite_result)
|
|
25
|
+
rewrite_result = self.unroll(mt).join(rewrite_result)
|
|
26
|
+
return rewrite_result
|
|
@@ -7,8 +7,11 @@ from kirin.rewrite import (
|
|
|
7
7
|
Chain,
|
|
8
8
|
Fixpoint,
|
|
9
9
|
ConstantFold,
|
|
10
|
+
DeadCodeElimination,
|
|
10
11
|
CommonSubexpressionElimination,
|
|
11
12
|
)
|
|
13
|
+
from kirin.dialects.scf.trim import UnusedYield
|
|
14
|
+
from kirin.dialects.ilist.passes import ConstList2IList
|
|
12
15
|
|
|
13
16
|
from ..rewrite.ifs_to_stim import StimLiftThenBody, StimSplitIfStmts
|
|
14
17
|
|
|
@@ -19,12 +22,23 @@ class StimSimplifyIfs(Pass):
|
|
|
19
22
|
def unsafe_run(self, mt: ir.Method):
|
|
20
23
|
|
|
21
24
|
result = Chain(
|
|
22
|
-
|
|
25
|
+
Walk(UnusedYield()),
|
|
26
|
+
Walk(StimLiftThenBody()),
|
|
27
|
+
# remove yields (if possible), then lift out as much stuff as possible
|
|
28
|
+
Walk(DeadCodeElimination()),
|
|
23
29
|
Walk(StimSplitIfStmts()),
|
|
24
30
|
).rewrite(mt.code)
|
|
25
31
|
|
|
32
|
+
# because nested python lists don't have their
|
|
33
|
+
# member lists converted to ILists, ConstantFold
|
|
34
|
+
# can add python lists that can't be hashed, causing
|
|
35
|
+
# issues with CSE. ConstList2IList remedies that problem here.
|
|
26
36
|
result = (
|
|
27
|
-
|
|
37
|
+
Chain(
|
|
38
|
+
Fixpoint(Walk(ConstantFold())),
|
|
39
|
+
Walk(ConstList2IList()),
|
|
40
|
+
Walk(CommonSubexpressionElimination()),
|
|
41
|
+
)
|
|
28
42
|
.rewrite(mt.code)
|
|
29
43
|
.join(result)
|
|
30
44
|
)
|
|
@@ -1,29 +1,21 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
2
|
|
|
3
|
-
from kirin.passes import Fold
|
|
4
3
|
from kirin.rewrite import (
|
|
5
4
|
Walk,
|
|
6
5
|
Chain,
|
|
7
6
|
Fixpoint,
|
|
8
|
-
CFGCompactify,
|
|
9
|
-
InlineGetItem,
|
|
10
|
-
InlineGetField,
|
|
11
7
|
DeadCodeElimination,
|
|
12
8
|
CommonSubexpressionElimination,
|
|
13
9
|
)
|
|
14
|
-
from kirin.dialects import scf, ilist
|
|
15
10
|
from kirin.ir.method import Method
|
|
16
11
|
from kirin.passes.abc import Pass
|
|
17
12
|
from kirin.rewrite.abc import RewriteResult
|
|
18
|
-
from kirin.passes.inline import InlinePass
|
|
19
13
|
|
|
20
14
|
from bloqade.stim.rewrite import (
|
|
21
|
-
SquinWireToStim,
|
|
22
15
|
PyConstantToStim,
|
|
23
16
|
SquinNoiseToStim,
|
|
24
17
|
SquinQubitToStim,
|
|
25
18
|
SquinMeasureToStim,
|
|
26
|
-
SquinWireIdentityElimination,
|
|
27
19
|
)
|
|
28
20
|
from bloqade.squin.rewrite import (
|
|
29
21
|
SquinU3ToClifford,
|
|
@@ -33,9 +25,9 @@ from bloqade.squin.rewrite import (
|
|
|
33
25
|
from bloqade.rewrite.passes import CanonicalizeIList
|
|
34
26
|
from bloqade.analysis.address import AddressAnalysis
|
|
35
27
|
from bloqade.analysis.measure_id import MeasurementIDAnalysis
|
|
28
|
+
from bloqade.stim.passes.flatten import Flatten
|
|
36
29
|
|
|
37
|
-
from
|
|
38
|
-
from ..rewrite.ifs_to_stim import IfToStim
|
|
30
|
+
from ..rewrite import IfToStim, SetDetectorToStim, SetObservableToStim
|
|
39
31
|
|
|
40
32
|
|
|
41
33
|
@dataclass
|
|
@@ -44,52 +36,18 @@ class SquinToStimPass(Pass):
|
|
|
44
36
|
def unsafe_run(self, mt: Method) -> RewriteResult:
|
|
45
37
|
|
|
46
38
|
# inline aggressively:
|
|
47
|
-
rewrite_result =
|
|
48
|
-
|
|
49
|
-
).unsafe_run(mt)
|
|
50
|
-
|
|
51
|
-
rule = Chain(
|
|
52
|
-
InlineGetField(),
|
|
53
|
-
InlineGetItem(),
|
|
54
|
-
scf.unroll.ForLoop(),
|
|
55
|
-
scf.trim.UnusedYield(),
|
|
56
|
-
)
|
|
57
|
-
rewrite_result = Fixpoint(Walk(rule)).rewrite(mt.code).join(rewrite_result)
|
|
58
|
-
# fold_pass = Fold(mt.dialects, no_raise=self.no_raise)
|
|
59
|
-
# rewrite_result = fold_pass(mt)
|
|
60
|
-
rewrite_result = (
|
|
61
|
-
Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(rewrite_result)
|
|
62
|
-
)
|
|
63
|
-
rewrite_result = (
|
|
64
|
-
StimSimplifyIfs(mt.dialects, no_raise=self.no_raise)
|
|
65
|
-
.unsafe_run(mt)
|
|
66
|
-
.join(rewrite_result)
|
|
67
|
-
)
|
|
68
|
-
|
|
69
|
-
# run typeinfer again after unroll etc. because we now insert
|
|
70
|
-
# a lot of new nodes, which might have more precise types
|
|
71
|
-
# self.typeinfer.unsafe_run(mt)
|
|
72
|
-
rewrite_result = (
|
|
73
|
-
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
|
|
74
|
-
.rewrite(mt.code)
|
|
75
|
-
.join(rewrite_result)
|
|
76
|
-
)
|
|
77
|
-
rewrite_result = Fold(mt.dialects, no_raise=self.no_raise)(mt)
|
|
78
|
-
|
|
79
|
-
rewrite_result = (
|
|
80
|
-
CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
|
|
81
|
-
.unsafe_run(mt)
|
|
82
|
-
.join(rewrite_result)
|
|
39
|
+
rewrite_result = Flatten(dialects=mt.dialects, no_raise=self.no_raise).fixpoint(
|
|
40
|
+
mt
|
|
83
41
|
)
|
|
84
42
|
|
|
85
43
|
# after this the program should be in a state where it is analyzable
|
|
86
44
|
# -------------------------------------------------------------------
|
|
87
45
|
|
|
88
46
|
mia = MeasurementIDAnalysis(dialects=mt.dialects)
|
|
89
|
-
meas_analysis_frame, _ = mia.
|
|
47
|
+
meas_analysis_frame, _ = mia.run(mt)
|
|
90
48
|
|
|
91
49
|
aa = AddressAnalysis(dialects=mt.dialects)
|
|
92
|
-
address_analysis_frame, _ = aa.
|
|
50
|
+
address_analysis_frame, _ = aa.run(mt)
|
|
93
51
|
|
|
94
52
|
# wrap the address analysis result
|
|
95
53
|
rewrite_result = (
|
|
@@ -99,12 +57,16 @@ class SquinToStimPass(Pass):
|
|
|
99
57
|
)
|
|
100
58
|
|
|
101
59
|
# 2. rewrite
|
|
60
|
+
## Invoke DCE afterwards to eliminate any GetItems
|
|
61
|
+
## that are no longer being used. This allows for
|
|
62
|
+
## SquinMeasureToStim to safely eliminate
|
|
63
|
+
## unused measure statements.
|
|
102
64
|
rewrite_result = (
|
|
103
|
-
|
|
104
|
-
IfToStim(
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
)
|
|
65
|
+
Chain(
|
|
66
|
+
Walk(IfToStim(measure_frame=meas_analysis_frame)),
|
|
67
|
+
Walk(SetDetectorToStim(measure_id_frame=meas_analysis_frame)),
|
|
68
|
+
Walk(SetObservableToStim(measure_id_frame=meas_analysis_frame)),
|
|
69
|
+
Fixpoint(Walk(DeadCodeElimination())),
|
|
108
70
|
)
|
|
109
71
|
.rewrite(mt.code)
|
|
110
72
|
.join(rewrite_result)
|
|
@@ -120,17 +82,13 @@ class SquinToStimPass(Pass):
|
|
|
120
82
|
Walk(
|
|
121
83
|
Chain(
|
|
122
84
|
SquinQubitToStim(),
|
|
123
|
-
|
|
124
|
-
SquinMeasureToStim(
|
|
125
|
-
measure_id_result=meas_analysis_frame.entries,
|
|
126
|
-
total_measure_count=mia.measure_count,
|
|
127
|
-
), # reduce duplicated logic, can split out even more rules later
|
|
128
|
-
SquinWireIdentityElimination(),
|
|
85
|
+
SquinMeasureToStim(),
|
|
129
86
|
)
|
|
130
87
|
)
|
|
131
88
|
.rewrite(mt.code)
|
|
132
89
|
.join(rewrite_result)
|
|
133
90
|
)
|
|
91
|
+
|
|
134
92
|
rewrite_result = (
|
|
135
93
|
CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
|
|
136
94
|
.unsafe_run(mt)
|
|
@@ -141,7 +99,7 @@ class SquinToStimPass(Pass):
|
|
|
141
99
|
rewrite_result = Walk(PyConstantToStim()).rewrite(mt.code).join(rewrite_result)
|
|
142
100
|
|
|
143
101
|
# clear up leftover stmts
|
|
144
|
-
# - remove any squin.
|
|
102
|
+
# - remove any squin.qalloc that's left around
|
|
145
103
|
rewrite_result = (
|
|
146
104
|
Fixpoint(
|
|
147
105
|
Walk(
|
bloqade/stim/rewrite/__init__.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
|
+
from .ifs_to_stim import IfToStim as IfToStim
|
|
1
2
|
from .squin_noise import SquinNoiseToStim as SquinNoiseToStim
|
|
2
|
-
from .wire_to_stim import SquinWireToStim as SquinWireToStim
|
|
3
3
|
from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
|
|
4
4
|
from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
|
|
5
5
|
from .py_constant_to_stim import PyConstantToStim as PyConstantToStim
|
|
6
|
-
from .
|
|
7
|
-
|
|
8
|
-
)
|
|
6
|
+
from .set_detector_to_stim import SetDetectorToStim as SetDetectorToStim
|
|
7
|
+
from .set_observable_to_stim import SetObservableToStim as SetObservableToStim
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.dialects import py
|
|
3
|
+
|
|
4
|
+
from bloqade.stim.dialects import auxiliary
|
|
5
|
+
from bloqade.analysis.measure_id.lattice import MeasureIdBool, MeasureIdTuple
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def insert_get_records(
|
|
9
|
+
node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count_at_stmt: int
|
|
10
|
+
):
|
|
11
|
+
"""
|
|
12
|
+
Insert GetRecord statements before the given node
|
|
13
|
+
"""
|
|
14
|
+
get_record_ssas = []
|
|
15
|
+
for measure_id_bool in measure_id_tuple.data:
|
|
16
|
+
assert isinstance(measure_id_bool, MeasureIdBool)
|
|
17
|
+
target_rec_idx = (measure_id_bool.idx - 1) - meas_count_at_stmt
|
|
18
|
+
idx_stmt = py.constant.Constant(target_rec_idx)
|
|
19
|
+
idx_stmt.insert_before(node)
|
|
20
|
+
get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
|
|
21
|
+
get_record_stmt.insert_before(node)
|
|
22
|
+
get_record_ssas.append(get_record_stmt.result)
|
|
23
|
+
|
|
24
|
+
return get_record_ssas
|
|
@@ -4,16 +4,16 @@ from kirin import ir
|
|
|
4
4
|
from kirin.dialects import py, scf, func
|
|
5
5
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
6
6
|
|
|
7
|
-
from bloqade.squin import
|
|
7
|
+
from bloqade.squin import gate
|
|
8
8
|
from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts
|
|
9
9
|
from bloqade.squin.rewrite import AddressAttribute
|
|
10
10
|
from bloqade.stim.rewrite.util import (
|
|
11
|
-
SQUIN_STIM_CONTROL_GATE_MAPPING,
|
|
12
11
|
insert_qubit_idx_from_address,
|
|
13
12
|
)
|
|
13
|
+
from bloqade.stim.dialects.gate import CX as stim_CX, CY as stim_CY, CZ as stim_CZ
|
|
14
|
+
from bloqade.analysis.measure_id import MeasureIDFrame
|
|
14
15
|
from bloqade.stim.dialects.auxiliary import GetRecord
|
|
15
16
|
from bloqade.analysis.measure_id.lattice import (
|
|
16
|
-
MeasureId,
|
|
17
17
|
MeasureIdBool,
|
|
18
18
|
)
|
|
19
19
|
|
|
@@ -58,8 +58,7 @@ class IfElseSimplification:
|
|
|
58
58
|
"""Check if the IfElse statement has an else body."""
|
|
59
59
|
if stmt.else_body.blocks and not (
|
|
60
60
|
len(stmt.else_body.blocks[0].stmts) == 1
|
|
61
|
-
and isinstance(
|
|
62
|
-
and not else_term.values # empty yield
|
|
61
|
+
and isinstance(stmt.else_body.blocks[0].last_stmt, scf.Yield)
|
|
63
62
|
):
|
|
64
63
|
return True
|
|
65
64
|
|
|
@@ -67,12 +66,13 @@ class IfElseSimplification:
|
|
|
67
66
|
|
|
68
67
|
|
|
69
68
|
DontLiftType = (
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
69
|
+
gate.stmts.SingleQubitGate,
|
|
70
|
+
gate.stmts.RotationGate,
|
|
71
|
+
gate.stmts.ControlledGate,
|
|
73
72
|
func.Return,
|
|
74
73
|
func.Invoke,
|
|
75
74
|
scf.IfElse,
|
|
75
|
+
scf.Yield,
|
|
76
76
|
)
|
|
77
77
|
|
|
78
78
|
|
|
@@ -99,16 +99,16 @@ class StimSplitIfStmts(IfElseSimplification, SplitIfStmts):
|
|
|
99
99
|
Given an IfElse with multiple valid statements in the then-body:
|
|
100
100
|
|
|
101
101
|
if measure_result:
|
|
102
|
-
squin.
|
|
103
|
-
squin.
|
|
102
|
+
squin.x(q0)
|
|
103
|
+
squin.y(q1)
|
|
104
104
|
|
|
105
105
|
this should be rewritten to:
|
|
106
106
|
|
|
107
107
|
if measure_result:
|
|
108
|
-
squin.
|
|
108
|
+
squin.x(q0)
|
|
109
109
|
|
|
110
110
|
if measure_result:
|
|
111
|
-
squin.
|
|
111
|
+
squin.y(q1)
|
|
112
112
|
"""
|
|
113
113
|
|
|
114
114
|
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
@@ -127,8 +127,7 @@ class IfToStim(IfElseSimplification, RewriteRule):
|
|
|
127
127
|
Rewrite if statements to stim equivalent statements.
|
|
128
128
|
"""
|
|
129
129
|
|
|
130
|
-
|
|
131
|
-
measure_count: int
|
|
130
|
+
measure_frame: MeasureIDFrame
|
|
132
131
|
|
|
133
132
|
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
134
133
|
|
|
@@ -140,38 +139,37 @@ class IfToStim(IfElseSimplification, RewriteRule):
|
|
|
140
139
|
|
|
141
140
|
def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
|
|
142
141
|
|
|
143
|
-
|
|
142
|
+
# Check the condition is a singular MeasurementIdBool
|
|
143
|
+
if not isinstance(self.measure_frame.entries[stmt.cond], MeasureIdBool):
|
|
144
144
|
return RewriteResult()
|
|
145
145
|
|
|
146
|
-
#
|
|
147
|
-
#
|
|
148
|
-
# Can reuse logic from SplitIf
|
|
146
|
+
# Reusing code from SplitIf,
|
|
147
|
+
# there should only be one statement in the body and it should be a pauli X, Y, or Z
|
|
149
148
|
*stmts, _ = stmt.then_body.stmts()
|
|
150
|
-
if len(stmts) != 1
|
|
149
|
+
if len(stmts) != 1:
|
|
151
150
|
return RewriteResult()
|
|
152
151
|
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
if stim_gate is None:
|
|
152
|
+
if isinstance(stmts[0], gate.stmts.X):
|
|
153
|
+
stim_gate = stim_CX
|
|
154
|
+
elif isinstance(stmts[0], gate.stmts.Y):
|
|
155
|
+
stim_gate = stim_CY
|
|
156
|
+
elif isinstance(stmts[0], gate.stmts.Z):
|
|
157
|
+
stim_gate = stim_CZ
|
|
158
|
+
else:
|
|
161
159
|
return RewriteResult()
|
|
162
160
|
|
|
163
161
|
# get necessary measurement ID type from analysis
|
|
164
|
-
measure_id_bool = self.
|
|
162
|
+
measure_id_bool = self.measure_frame.entries[stmt.cond]
|
|
165
163
|
assert isinstance(measure_id_bool, MeasureIdBool)
|
|
166
164
|
|
|
167
165
|
# generate get record statement
|
|
168
166
|
measure_id_idx_stmt = py.Constant(
|
|
169
|
-
(measure_id_bool.idx - 1) - self.
|
|
167
|
+
(measure_id_bool.idx - 1) - self.measure_frame.num_measures_at_stmt[stmt]
|
|
170
168
|
)
|
|
171
169
|
get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841
|
|
172
170
|
|
|
173
|
-
|
|
174
|
-
|
|
171
|
+
address_attr = stmts[0].qubits.hints.get("address")
|
|
172
|
+
|
|
175
173
|
if address_attr is None:
|
|
176
174
|
return RewriteResult()
|
|
177
175
|
assert isinstance(address_attr, AddressAttribute)
|
|
@@ -1,13 +1,11 @@
|
|
|
1
1
|
from kirin import ir
|
|
2
2
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
3
3
|
|
|
4
|
-
from bloqade
|
|
4
|
+
from bloqade import qubit
|
|
5
|
+
from bloqade.squin import gate
|
|
5
6
|
from bloqade.squin.rewrite import AddressAttribute
|
|
6
|
-
from bloqade.stim.dialects import gate
|
|
7
|
+
from bloqade.stim.dialects import gate as stim_gate, collapse as stim_collapse
|
|
7
8
|
from bloqade.stim.rewrite.util import (
|
|
8
|
-
SQUIN_STIM_OP_MAPPING,
|
|
9
|
-
rewrite_Control,
|
|
10
|
-
rewrite_QubitLoss,
|
|
11
9
|
insert_qubit_idx_from_address,
|
|
12
10
|
)
|
|
13
11
|
|
|
@@ -20,64 +18,115 @@ class SquinQubitToStim(RewriteRule):
|
|
|
20
18
|
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
21
19
|
|
|
22
20
|
match node:
|
|
23
|
-
|
|
24
|
-
|
|
21
|
+
# not supported by Stim
|
|
22
|
+
case gate.stmts.T() | gate.stmts.RotationGate():
|
|
23
|
+
return RewriteResult()
|
|
24
|
+
# If you've reached this point all gates have stim equivalents
|
|
25
|
+
case qubit.stmts.Reset():
|
|
26
|
+
return self.rewrite_Reset(node)
|
|
27
|
+
case gate.stmts.SingleQubitGate():
|
|
28
|
+
return self.rewrite_SingleQubitGate(node)
|
|
29
|
+
case gate.stmts.ControlledGate():
|
|
30
|
+
return self.rewrite_ControlledGate(node)
|
|
25
31
|
case _:
|
|
26
32
|
return RewriteResult()
|
|
27
33
|
|
|
28
|
-
def
|
|
29
|
-
self, stmt: qubit.Apply | qubit.Broadcast
|
|
30
|
-
) -> RewriteResult:
|
|
31
|
-
"""
|
|
32
|
-
Rewrite Apply and Broadcast nodes to their stim equivalent statements.
|
|
33
|
-
"""
|
|
34
|
-
|
|
35
|
-
# this is an SSAValue, need it to be the actual operator
|
|
36
|
-
applied_op = stmt.operator.owner
|
|
34
|
+
def rewrite_Reset(self, stmt: qubit.stmts.Reset) -> RewriteResult:
|
|
37
35
|
|
|
38
|
-
|
|
39
|
-
return rewrite_QubitLoss(stmt)
|
|
36
|
+
qubit_addr_attr = stmt.qubits.hints.get("address", None)
|
|
40
37
|
|
|
41
|
-
|
|
38
|
+
if qubit_addr_attr is None:
|
|
39
|
+
return RewriteResult()
|
|
42
40
|
|
|
43
|
-
|
|
44
|
-
return rewrite_Control(stmt)
|
|
41
|
+
assert isinstance(qubit_addr_attr, AddressAttribute)
|
|
45
42
|
|
|
46
|
-
|
|
43
|
+
qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
44
|
+
address=qubit_addr_attr, stmt_to_insert_before=stmt
|
|
45
|
+
)
|
|
47
46
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
if isinstance(applied_op, op.stmts.Adjoint):
|
|
51
|
-
if not applied_op.is_unitary:
|
|
52
|
-
return RewriteResult()
|
|
47
|
+
if qubit_idx_ssas is None:
|
|
48
|
+
return RewriteResult()
|
|
53
49
|
|
|
54
|
-
|
|
55
|
-
|
|
50
|
+
stim_stmt = stim_collapse.RZ(targets=tuple(qubit_idx_ssas))
|
|
51
|
+
stmt.replace_by(stim_stmt)
|
|
56
52
|
|
|
57
|
-
|
|
58
|
-
if stim_1q_op is None:
|
|
59
|
-
return RewriteResult()
|
|
53
|
+
return RewriteResult(has_done_something=True)
|
|
60
54
|
|
|
61
|
-
|
|
55
|
+
def rewrite_SingleQubitGate(
|
|
56
|
+
self, stmt: gate.stmts.SingleQubitGate
|
|
57
|
+
) -> RewriteResult:
|
|
58
|
+
"""
|
|
59
|
+
Rewrite single qubit gate nodes to their stim equivalent statements.
|
|
60
|
+
Address Analysis should have been run along with Wrap Analysis before this rewrite is applied.
|
|
61
|
+
"""
|
|
62
62
|
|
|
63
|
-
|
|
63
|
+
qubit_addr_attr = stmt.qubits.hints.get("address", None)
|
|
64
|
+
if qubit_addr_attr is None:
|
|
64
65
|
return RewriteResult()
|
|
65
66
|
|
|
66
|
-
assert isinstance(
|
|
67
|
+
assert isinstance(qubit_addr_attr, AddressAttribute)
|
|
68
|
+
|
|
67
69
|
qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
68
|
-
address=
|
|
70
|
+
address=qubit_addr_attr, stmt_to_insert_before=stmt
|
|
69
71
|
)
|
|
70
72
|
|
|
71
73
|
if qubit_idx_ssas is None:
|
|
72
74
|
return RewriteResult()
|
|
73
75
|
|
|
74
|
-
if
|
|
75
|
-
|
|
76
|
+
# Get the name of the inputted stmt and see if there is an
|
|
77
|
+
# equivalently named statement in stim,
|
|
78
|
+
# then create an instance of that stim statement
|
|
79
|
+
stmt_name = type(stmt).__name__
|
|
80
|
+
stim_stmt_cls = getattr(stim_gate.stmts, stmt_name, None)
|
|
81
|
+
if stim_stmt_cls is None:
|
|
82
|
+
return RewriteResult()
|
|
83
|
+
|
|
84
|
+
if isinstance(stmt, gate.stmts.SingleQubitNonHermitianGate):
|
|
85
|
+
stim_stmt = stim_stmt_cls(
|
|
86
|
+
targets=tuple(qubit_idx_ssas), dagger=stmt.adjoint
|
|
87
|
+
)
|
|
76
88
|
else:
|
|
77
|
-
|
|
78
|
-
stmt.replace_by(
|
|
89
|
+
stim_stmt = stim_stmt_cls(targets=tuple(qubit_idx_ssas))
|
|
90
|
+
stmt.replace_by(stim_stmt)
|
|
79
91
|
|
|
80
92
|
return RewriteResult(has_done_something=True)
|
|
81
93
|
|
|
94
|
+
def rewrite_ControlledGate(self, stmt: gate.stmts.ControlledGate) -> RewriteResult:
|
|
95
|
+
"""
|
|
96
|
+
Rewrite controlled gate nodes to their stim equivalent statements.
|
|
97
|
+
Address Analysis should have been run along with Wrap Analysis before this rewrite is applied.
|
|
98
|
+
"""
|
|
82
99
|
|
|
83
|
-
|
|
100
|
+
controls_addr_attr = stmt.controls.hints.get("address", None)
|
|
101
|
+
targets_addr_attr = stmt.targets.hints.get("address", None)
|
|
102
|
+
|
|
103
|
+
if controls_addr_attr is None or targets_addr_attr is None:
|
|
104
|
+
return RewriteResult()
|
|
105
|
+
|
|
106
|
+
assert isinstance(controls_addr_attr, AddressAttribute)
|
|
107
|
+
assert isinstance(targets_addr_attr, AddressAttribute)
|
|
108
|
+
|
|
109
|
+
controls_idx_ssas = insert_qubit_idx_from_address(
|
|
110
|
+
address=controls_addr_attr, stmt_to_insert_before=stmt
|
|
111
|
+
)
|
|
112
|
+
targets_idx_ssas = insert_qubit_idx_from_address(
|
|
113
|
+
address=targets_addr_attr, stmt_to_insert_before=stmt
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
if controls_idx_ssas is None or targets_idx_ssas is None:
|
|
117
|
+
return RewriteResult()
|
|
118
|
+
|
|
119
|
+
# Get the name of the inputted stmt and see if there is an
|
|
120
|
+
# equivalently named statement in stim,
|
|
121
|
+
# then create an instance of that stim statement
|
|
122
|
+
stmt_name = type(stmt).__name__
|
|
123
|
+
stim_stmt_cls = getattr(stim_gate.stmts, stmt_name, None)
|
|
124
|
+
if stim_stmt_cls is None:
|
|
125
|
+
return RewriteResult()
|
|
126
|
+
|
|
127
|
+
stim_stmt = stim_stmt_cls(
|
|
128
|
+
targets=tuple(targets_idx_ssas), controls=tuple(controls_idx_ssas)
|
|
129
|
+
)
|
|
130
|
+
stmt.replace_by(stim_stmt)
|
|
131
|
+
|
|
132
|
+
return RewriteResult(has_done_something=True)
|