bloqade-circuit 0.6.4__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bloqade/analysis/address/__init__.py +8 -4
- bloqade/analysis/address/analysis.py +123 -33
- bloqade/analysis/address/impls.py +293 -90
- bloqade/analysis/address/lattice.py +209 -24
- bloqade/analysis/fidelity/analysis.py +11 -23
- bloqade/analysis/measure_id/analysis.py +18 -20
- bloqade/analysis/measure_id/impls.py +31 -29
- bloqade/annotate/__init__.py +6 -0
- bloqade/annotate/_dialect.py +3 -0
- bloqade/annotate/_interface.py +22 -0
- bloqade/annotate/stmts.py +29 -0
- bloqade/annotate/types.py +13 -0
- bloqade/cirq_utils/__init__.py +4 -2
- bloqade/cirq_utils/emit/__init__.py +3 -0
- bloqade/cirq_utils/emit/base.py +246 -0
- bloqade/cirq_utils/emit/gate.py +104 -0
- bloqade/cirq_utils/emit/noise.py +90 -0
- bloqade/cirq_utils/emit/qubit.py +35 -0
- bloqade/cirq_utils/lowering.py +660 -0
- bloqade/cirq_utils/noise/__init__.py +0 -2
- bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
- bloqade/cirq_utils/noise/model.py +151 -191
- bloqade/cirq_utils/noise/transform.py +2 -2
- bloqade/cirq_utils/parallelize.py +9 -6
- bloqade/gemini/__init__.py +1 -0
- bloqade/gemini/analysis/__init__.py +3 -0
- bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
- bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
- bloqade/gemini/analysis/logical_validation/impls.py +101 -0
- bloqade/gemini/groups.py +67 -0
- bloqade/native/__init__.py +23 -0
- bloqade/native/_prelude.py +45 -0
- bloqade/native/dialects/__init__.py +0 -0
- bloqade/native/dialects/gate/__init__.py +2 -0
- bloqade/native/dialects/gate/_dialect.py +3 -0
- bloqade/native/dialects/gate/_interface.py +32 -0
- bloqade/native/dialects/gate/stmts.py +31 -0
- bloqade/native/stdlib/__init__.py +0 -0
- bloqade/native/stdlib/broadcast.py +246 -0
- bloqade/native/stdlib/simple.py +220 -0
- bloqade/native/upstream/__init__.py +4 -0
- bloqade/native/upstream/squin2native.py +79 -0
- bloqade/pyqrack/__init__.py +2 -2
- bloqade/pyqrack/base.py +7 -1
- bloqade/pyqrack/device.py +192 -18
- bloqade/pyqrack/native.py +49 -0
- bloqade/pyqrack/reg.py +6 -6
- bloqade/pyqrack/squin/gate/__init__.py +1 -0
- bloqade/pyqrack/squin/gate/gate.py +136 -0
- bloqade/pyqrack/squin/noise/native.py +120 -54
- bloqade/pyqrack/squin/qubit.py +39 -36
- bloqade/pyqrack/target.py +5 -4
- bloqade/pyqrack/task.py +114 -7
- bloqade/qasm2/_qasm_loading.py +3 -3
- bloqade/qasm2/dialects/core/address.py +21 -12
- bloqade/qasm2/dialects/expr/_emit.py +19 -8
- bloqade/qasm2/dialects/expr/stmts.py +7 -7
- bloqade/qasm2/dialects/noise/fidelity.py +4 -8
- bloqade/qasm2/dialects/noise/model.py +2 -1
- bloqade/qasm2/emit/base.py +16 -11
- bloqade/qasm2/emit/gate.py +11 -8
- bloqade/qasm2/emit/main.py +103 -3
- bloqade/qasm2/emit/target.py +9 -5
- bloqade/qasm2/groups.py +3 -2
- bloqade/qasm2/parse/lowering.py +0 -1
- bloqade/qasm2/passes/fold.py +14 -73
- bloqade/qasm2/passes/glob.py +2 -2
- bloqade/qasm2/passes/noise.py +1 -1
- bloqade/qasm2/passes/parallel.py +7 -5
- bloqade/qasm2/rewrite/__init__.py +0 -1
- bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
- bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
- bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
- bloqade/qasm2/rewrite/register.py +2 -2
- bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
- bloqade/qbraid/lowering.py +1 -0
- bloqade/qbraid/schema.py +2 -2
- bloqade/qubit/__init__.py +12 -0
- bloqade/qubit/_dialect.py +3 -0
- bloqade/qubit/_interface.py +49 -0
- bloqade/qubit/_prelude.py +45 -0
- bloqade/qubit/analysis/__init__.py +1 -0
- bloqade/qubit/analysis/address_impl.py +40 -0
- bloqade/qubit/stdlib/__init__.py +2 -0
- bloqade/qubit/stdlib/_new.py +34 -0
- bloqade/qubit/stdlib/broadcast.py +62 -0
- bloqade/qubit/stdlib/simple.py +59 -0
- bloqade/qubit/stmts.py +60 -0
- bloqade/rewrite/passes/__init__.py +6 -0
- bloqade/rewrite/passes/aggressive_unroll.py +103 -0
- bloqade/rewrite/passes/callgraph.py +116 -0
- bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
- bloqade/rewrite/rules/split_ifs.py +18 -1
- bloqade/squin/__init__.py +47 -14
- bloqade/squin/analysis/__init__.py +0 -1
- bloqade/squin/analysis/schedule.py +10 -11
- bloqade/squin/gate/__init__.py +2 -0
- bloqade/squin/gate/_dialect.py +3 -0
- bloqade/squin/gate/_interface.py +98 -0
- bloqade/squin/gate/stmts.py +125 -0
- bloqade/squin/groups.py +5 -22
- bloqade/squin/noise/__init__.py +1 -10
- bloqade/squin/noise/_dialect.py +1 -1
- bloqade/squin/noise/_interface.py +45 -0
- bloqade/squin/noise/stmts.py +66 -28
- bloqade/squin/rewrite/U3_to_clifford.py +70 -51
- bloqade/squin/rewrite/__init__.py +0 -2
- bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
- bloqade/squin/rewrite/wrap_analysis.py +4 -35
- bloqade/squin/stdlib/__init__.py +0 -0
- bloqade/squin/stdlib/broadcast/__init__.py +34 -0
- bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
- bloqade/squin/stdlib/broadcast/gate.py +260 -0
- bloqade/squin/stdlib/broadcast/noise.py +144 -0
- bloqade/squin/stdlib/simple/__init__.py +33 -0
- bloqade/squin/stdlib/simple/gate.py +242 -0
- bloqade/squin/stdlib/simple/noise.py +126 -0
- bloqade/stim/__init__.py +1 -0
- bloqade/stim/_wrappers.py +6 -0
- bloqade/stim/dialects/auxiliary/emit.py +19 -18
- bloqade/stim/dialects/collapse/emit_str.py +7 -8
- bloqade/stim/dialects/gate/emit.py +9 -10
- bloqade/stim/dialects/noise/emit.py +17 -13
- bloqade/stim/dialects/noise/stmts.py +5 -3
- bloqade/stim/emit/__init__.py +1 -0
- bloqade/stim/emit/impls.py +16 -0
- bloqade/stim/emit/stim_str.py +48 -31
- bloqade/stim/groups.py +12 -2
- bloqade/stim/parse/lowering.py +14 -17
- bloqade/stim/passes/__init__.py +0 -2
- bloqade/stim/passes/flatten.py +26 -0
- bloqade/stim/passes/simplify_ifs.py +6 -1
- bloqade/stim/passes/squin_to_stim.py +9 -84
- bloqade/stim/rewrite/__init__.py +2 -4
- bloqade/stim/rewrite/get_record_util.py +24 -0
- bloqade/stim/rewrite/ifs_to_stim.py +24 -25
- bloqade/stim/rewrite/qubit_to_stim.py +90 -41
- bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
- bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
- bloqade/stim/rewrite/squin_measure.py +9 -18
- bloqade/stim/rewrite/squin_noise.py +134 -108
- bloqade/stim/rewrite/util.py +5 -192
- bloqade/test_utils.py +1 -1
- bloqade/types.py +10 -0
- bloqade/validation/__init__.py +2 -0
- bloqade/validation/analysis/__init__.py +5 -0
- bloqade/validation/analysis/analysis.py +41 -0
- bloqade/validation/analysis/lattice.py +58 -0
- bloqade/validation/kernel_validation.py +77 -0
- {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
- bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
- bloqade/pyqrack/squin/op.py +0 -180
- bloqade/pyqrack/squin/runtime.py +0 -535
- bloqade/pyqrack/squin/wire.py +0 -51
- bloqade/rewrite/rules/flatten_ilist.py +0 -51
- bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
- bloqade/squin/_typeinfer.py +0 -20
- bloqade/squin/analysis/address_impl.py +0 -71
- bloqade/squin/analysis/nsites/__init__.py +0 -9
- bloqade/squin/analysis/nsites/analysis.py +0 -50
- bloqade/squin/analysis/nsites/impls.py +0 -92
- bloqade/squin/analysis/nsites/lattice.py +0 -49
- bloqade/squin/cirq/__init__.py +0 -280
- bloqade/squin/cirq/emit/emit_circuit.py +0 -109
- bloqade/squin/cirq/emit/noise.py +0 -49
- bloqade/squin/cirq/emit/op.py +0 -125
- bloqade/squin/cirq/emit/qubit.py +0 -60
- bloqade/squin/cirq/emit/runtime.py +0 -242
- bloqade/squin/cirq/lowering.py +0 -440
- bloqade/squin/lowering.py +0 -54
- bloqade/squin/noise/_wrapper.py +0 -40
- bloqade/squin/noise/rewrite.py +0 -111
- bloqade/squin/op/__init__.py +0 -41
- bloqade/squin/op/_dialect.py +0 -3
- bloqade/squin/op/_wrapper.py +0 -121
- bloqade/squin/op/number.py +0 -5
- bloqade/squin/op/rewrite.py +0 -46
- bloqade/squin/op/stdlib.py +0 -62
- bloqade/squin/op/stmts.py +0 -276
- bloqade/squin/op/traits.py +0 -43
- bloqade/squin/op/types.py +0 -26
- bloqade/squin/qubit.py +0 -184
- bloqade/squin/rewrite/canonicalize.py +0 -60
- bloqade/squin/rewrite/desugar.py +0 -124
- bloqade/squin/types.py +0 -8
- bloqade/squin/wire.py +0 -201
- bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
- bloqade/stim/rewrite/wire_to_stim.py +0 -57
- bloqade_circuit-0.6.4.dist-info/RECORD +0 -234
- {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
bloqade/stim/parse/lowering.py
CHANGED
|
@@ -98,6 +98,8 @@ def loads(
|
|
|
98
98
|
signature=func.Signature((), return_node.value.type),
|
|
99
99
|
body=body,
|
|
100
100
|
)
|
|
101
|
+
self_arg = ir.BlockArgument(body.blocks[0], 0) # Self argument
|
|
102
|
+
body.blocks[0]._args = (self_arg,)
|
|
101
103
|
return ir.Method(
|
|
102
104
|
mod=None,
|
|
103
105
|
py_func=None,
|
|
@@ -627,10 +629,13 @@ class Stim(lowering.LoweringABC[Node]):
|
|
|
627
629
|
# Parse tag
|
|
628
630
|
tag_parts = node.tag.split(";", maxsplit=1)[0].split(":", maxsplit=1)
|
|
629
631
|
nonstim_name = tag_parts[0]
|
|
630
|
-
nonce = 0
|
|
631
632
|
if len(tag_parts) == 2:
|
|
633
|
+
# This should be a correlated error of the form, e.g.,
|
|
634
|
+
# I_ERROR[correlated_loss:<identifier>](0.01) 0 1 2
|
|
635
|
+
# The identifier is a unique number that prevents stim from merging
|
|
636
|
+
# correlated errors. We discard the identifier, but verify it is an integer.
|
|
632
637
|
try:
|
|
633
|
-
|
|
638
|
+
_ = int(tag_parts[1])
|
|
634
639
|
except ValueError:
|
|
635
640
|
# String was not an integer
|
|
636
641
|
if self.error_unknown_nonstim:
|
|
@@ -643,22 +648,14 @@ class Stim(lowering.LoweringABC[Node]):
|
|
|
643
648
|
f"Unknown non-stim statement name: {nonstim_name!r} ({node!r})"
|
|
644
649
|
)
|
|
645
650
|
statement_cls = self.nonstim_noise_ops.get(nonstim_name)
|
|
651
|
+
stmt = None
|
|
646
652
|
if statement_cls is not None:
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
),
|
|
654
|
-
)
|
|
655
|
-
else:
|
|
656
|
-
stmt = statement_cls(
|
|
657
|
-
probs=self._get_float_args_ssa(state, node.gate_args_copy()),
|
|
658
|
-
targets=self._get_multiple_qubit_or_rec_ssa(
|
|
659
|
-
state, node, node.targets_copy()
|
|
660
|
-
),
|
|
661
|
-
)
|
|
653
|
+
stmt = statement_cls(
|
|
654
|
+
probs=self._get_float_args_ssa(state, node.gate_args_copy()),
|
|
655
|
+
targets=self._get_multiple_qubit_or_rec_ssa(
|
|
656
|
+
state, node, node.targets_copy()
|
|
657
|
+
),
|
|
658
|
+
)
|
|
662
659
|
return stmt
|
|
663
660
|
|
|
664
661
|
def visit_CircuitInstruction(
|
bloqade/stim/passes/__init__.py
CHANGED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Taken from Phillip Weinberg's bloqade-shuttle implementation
|
|
2
|
+
from dataclasses import field, dataclass
|
|
3
|
+
|
|
4
|
+
from kirin import ir
|
|
5
|
+
from kirin.passes import Pass
|
|
6
|
+
from kirin.rewrite.abc import RewriteResult
|
|
7
|
+
|
|
8
|
+
from bloqade.rewrite.passes import AggressiveUnroll
|
|
9
|
+
from bloqade.stim.passes.simplify_ifs import StimSimplifyIfs
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
|
|
13
|
+
class Flatten(Pass):
|
|
14
|
+
|
|
15
|
+
unroll: AggressiveUnroll = field(init=False)
|
|
16
|
+
simplify_if: StimSimplifyIfs = field(init=False)
|
|
17
|
+
|
|
18
|
+
def __post_init__(self):
|
|
19
|
+
self.unroll = AggressiveUnroll(self.dialects, no_raise=self.no_raise)
|
|
20
|
+
self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise)
|
|
21
|
+
|
|
22
|
+
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
|
|
23
|
+
rewrite_result = RewriteResult()
|
|
24
|
+
rewrite_result = self.simplify_if(mt).join(rewrite_result)
|
|
25
|
+
rewrite_result = self.unroll(mt).join(rewrite_result)
|
|
26
|
+
return rewrite_result
|
|
@@ -7,8 +7,10 @@ from kirin.rewrite import (
|
|
|
7
7
|
Chain,
|
|
8
8
|
Fixpoint,
|
|
9
9
|
ConstantFold,
|
|
10
|
+
DeadCodeElimination,
|
|
10
11
|
CommonSubexpressionElimination,
|
|
11
12
|
)
|
|
13
|
+
from kirin.dialects.scf.trim import UnusedYield
|
|
12
14
|
from kirin.dialects.ilist.passes import ConstList2IList
|
|
13
15
|
|
|
14
16
|
from ..rewrite.ifs_to_stim import StimLiftThenBody, StimSplitIfStmts
|
|
@@ -20,7 +22,10 @@ class StimSimplifyIfs(Pass):
|
|
|
20
22
|
def unsafe_run(self, mt: ir.Method):
|
|
21
23
|
|
|
22
24
|
result = Chain(
|
|
23
|
-
|
|
25
|
+
Walk(UnusedYield()),
|
|
26
|
+
Walk(StimLiftThenBody()),
|
|
27
|
+
# remove yields (if possible), then lift out as much stuff as possible
|
|
28
|
+
Walk(DeadCodeElimination()),
|
|
24
29
|
Walk(StimSplitIfStmts()),
|
|
25
30
|
).rewrite(mt.code)
|
|
26
31
|
|
|
@@ -1,30 +1,21 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
2
|
|
|
3
|
-
from kirin.passes import Fold, HintConst, TypeInfer
|
|
4
3
|
from kirin.rewrite import (
|
|
5
4
|
Walk,
|
|
6
5
|
Chain,
|
|
7
6
|
Fixpoint,
|
|
8
|
-
CFGCompactify,
|
|
9
|
-
InlineGetItem,
|
|
10
|
-
InlineGetField,
|
|
11
7
|
DeadCodeElimination,
|
|
12
8
|
CommonSubexpressionElimination,
|
|
13
9
|
)
|
|
14
|
-
from kirin.dialects import scf, ilist
|
|
15
10
|
from kirin.ir.method import Method
|
|
16
11
|
from kirin.passes.abc import Pass
|
|
17
12
|
from kirin.rewrite.abc import RewriteResult
|
|
18
|
-
from kirin.passes.inline import InlinePass
|
|
19
|
-
from kirin.rewrite.alias import InlineAlias
|
|
20
13
|
|
|
21
14
|
from bloqade.stim.rewrite import (
|
|
22
|
-
SquinWireToStim,
|
|
23
15
|
PyConstantToStim,
|
|
24
16
|
SquinNoiseToStim,
|
|
25
17
|
SquinQubitToStim,
|
|
26
18
|
SquinMeasureToStim,
|
|
27
|
-
SquinWireIdentityElimination,
|
|
28
19
|
)
|
|
29
20
|
from bloqade.squin.rewrite import (
|
|
30
21
|
SquinU3ToClifford,
|
|
@@ -34,41 +25,9 @@ from bloqade.squin.rewrite import (
|
|
|
34
25
|
from bloqade.rewrite.passes import CanonicalizeIList
|
|
35
26
|
from bloqade.analysis.address import AddressAnalysis
|
|
36
27
|
from bloqade.analysis.measure_id import MeasurementIDAnalysis
|
|
37
|
-
from bloqade.
|
|
28
|
+
from bloqade.stim.passes.flatten import Flatten
|
|
38
29
|
|
|
39
|
-
from
|
|
40
|
-
from ..rewrite.ifs_to_stim import IfToStim
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
@dataclass
|
|
44
|
-
class AggressiveForLoopUnroll(Pass):
|
|
45
|
-
"""
|
|
46
|
-
Aggressive unrolling of for loops, addresses cases where unroll
|
|
47
|
-
does not successfully handle nested loops because of a lack of constprop.
|
|
48
|
-
|
|
49
|
-
This should be invoked via fixpoint to let this be repeatedly applied until
|
|
50
|
-
no further rewrites are possible.
|
|
51
|
-
"""
|
|
52
|
-
|
|
53
|
-
def unsafe_run(self, mt: Method) -> RewriteResult:
|
|
54
|
-
rule = Chain(
|
|
55
|
-
InlineGetField(),
|
|
56
|
-
InlineGetItem(),
|
|
57
|
-
scf.unroll.ForLoop(),
|
|
58
|
-
scf.trim.UnusedYield(),
|
|
59
|
-
)
|
|
60
|
-
|
|
61
|
-
# Intentionally only walk ONCE, let fixpoint happen with the WHOLE pass
|
|
62
|
-
# so that HintConst gets run right after, allowing subsequent unrolls to happen
|
|
63
|
-
rewrite_result = Walk(rule).rewrite(mt.code)
|
|
64
|
-
|
|
65
|
-
rewrite_result = (
|
|
66
|
-
HintConst(dialects=mt.dialects, no_raise=self.no_raise)
|
|
67
|
-
.unsafe_run(mt)
|
|
68
|
-
.join(rewrite_result)
|
|
69
|
-
)
|
|
70
|
-
|
|
71
|
-
return rewrite_result
|
|
30
|
+
from ..rewrite import IfToStim, SetDetectorToStim, SetObservableToStim
|
|
72
31
|
|
|
73
32
|
|
|
74
33
|
@dataclass
|
|
@@ -77,52 +36,18 @@ class SquinToStimPass(Pass):
|
|
|
77
36
|
def unsafe_run(self, mt: Method) -> RewriteResult:
|
|
78
37
|
|
|
79
38
|
# inline aggressively:
|
|
80
|
-
rewrite_result =
|
|
81
|
-
|
|
82
|
-
).unsafe_run(mt)
|
|
83
|
-
|
|
84
|
-
rewrite_result = (
|
|
85
|
-
AggressiveForLoopUnroll(dialects=mt.dialects, no_raise=self.no_raise)
|
|
86
|
-
.fixpoint(mt)
|
|
87
|
-
.join(rewrite_result)
|
|
39
|
+
rewrite_result = Flatten(dialects=mt.dialects, no_raise=self.no_raise).fixpoint(
|
|
40
|
+
mt
|
|
88
41
|
)
|
|
89
42
|
|
|
90
|
-
rewrite_result = (
|
|
91
|
-
Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(rewrite_result)
|
|
92
|
-
)
|
|
93
|
-
|
|
94
|
-
Walk(InlineAlias()).rewrite(mt.code).join(rewrite_result)
|
|
95
|
-
|
|
96
|
-
rewrite_result = (
|
|
97
|
-
StimSimplifyIfs(mt.dialects, no_raise=self.no_raise)
|
|
98
|
-
.unsafe_run(mt)
|
|
99
|
-
.join(rewrite_result)
|
|
100
|
-
)
|
|
101
|
-
|
|
102
|
-
rewrite_result = (
|
|
103
|
-
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
|
|
104
|
-
.rewrite(mt.code)
|
|
105
|
-
.join(rewrite_result)
|
|
106
|
-
)
|
|
107
|
-
rewrite_result = Fold(mt.dialects, no_raise=self.no_raise)(mt)
|
|
108
|
-
|
|
109
|
-
rewrite_result = (
|
|
110
|
-
CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
|
|
111
|
-
.unsafe_run(mt)
|
|
112
|
-
.join(rewrite_result)
|
|
113
|
-
)
|
|
114
|
-
|
|
115
|
-
TypeInfer(dialects=mt.dialects, no_raise=self.no_raise).unsafe_run(mt)
|
|
116
|
-
Walk(ApplyDesugarRule()).rewrite(mt.code)
|
|
117
|
-
|
|
118
43
|
# after this the program should be in a state where it is analyzable
|
|
119
44
|
# -------------------------------------------------------------------
|
|
120
45
|
|
|
121
46
|
mia = MeasurementIDAnalysis(dialects=mt.dialects)
|
|
122
|
-
meas_analysis_frame, _ = mia.
|
|
47
|
+
meas_analysis_frame, _ = mia.run(mt)
|
|
123
48
|
|
|
124
49
|
aa = AddressAnalysis(dialects=mt.dialects)
|
|
125
|
-
address_analysis_frame, _ = aa.
|
|
50
|
+
address_analysis_frame, _ = aa.run(mt)
|
|
126
51
|
|
|
127
52
|
# wrap the address analysis result
|
|
128
53
|
rewrite_result = (
|
|
@@ -139,6 +64,8 @@ class SquinToStimPass(Pass):
|
|
|
139
64
|
rewrite_result = (
|
|
140
65
|
Chain(
|
|
141
66
|
Walk(IfToStim(measure_frame=meas_analysis_frame)),
|
|
67
|
+
Walk(SetDetectorToStim(measure_id_frame=meas_analysis_frame)),
|
|
68
|
+
Walk(SetObservableToStim(measure_id_frame=meas_analysis_frame)),
|
|
142
69
|
Fixpoint(Walk(DeadCodeElimination())),
|
|
143
70
|
)
|
|
144
71
|
.rewrite(mt.code)
|
|
@@ -156,8 +83,6 @@ class SquinToStimPass(Pass):
|
|
|
156
83
|
Chain(
|
|
157
84
|
SquinQubitToStim(),
|
|
158
85
|
SquinMeasureToStim(),
|
|
159
|
-
SquinWireToStim(),
|
|
160
|
-
SquinWireIdentityElimination(),
|
|
161
86
|
)
|
|
162
87
|
)
|
|
163
88
|
.rewrite(mt.code)
|
|
@@ -174,7 +99,7 @@ class SquinToStimPass(Pass):
|
|
|
174
99
|
rewrite_result = Walk(PyConstantToStim()).rewrite(mt.code).join(rewrite_result)
|
|
175
100
|
|
|
176
101
|
# clear up leftover stmts
|
|
177
|
-
# - remove any squin.
|
|
102
|
+
# - remove any squin.qalloc that's left around
|
|
178
103
|
rewrite_result = (
|
|
179
104
|
Fixpoint(
|
|
180
105
|
Walk(
|
bloqade/stim/rewrite/__init__.py
CHANGED
|
@@ -1,9 +1,7 @@
|
|
|
1
1
|
from .ifs_to_stim import IfToStim as IfToStim
|
|
2
2
|
from .squin_noise import SquinNoiseToStim as SquinNoiseToStim
|
|
3
|
-
from .wire_to_stim import SquinWireToStim as SquinWireToStim
|
|
4
3
|
from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
|
|
5
4
|
from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
|
|
6
5
|
from .py_constant_to_stim import PyConstantToStim as PyConstantToStim
|
|
7
|
-
from .
|
|
8
|
-
|
|
9
|
-
)
|
|
6
|
+
from .set_detector_to_stim import SetDetectorToStim as SetDetectorToStim
|
|
7
|
+
from .set_observable_to_stim import SetObservableToStim as SetObservableToStim
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.dialects import py
|
|
3
|
+
|
|
4
|
+
from bloqade.stim.dialects import auxiliary
|
|
5
|
+
from bloqade.analysis.measure_id.lattice import MeasureIdBool, MeasureIdTuple
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def insert_get_records(
|
|
9
|
+
node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count_at_stmt: int
|
|
10
|
+
):
|
|
11
|
+
"""
|
|
12
|
+
Insert GetRecord statements before the given node
|
|
13
|
+
"""
|
|
14
|
+
get_record_ssas = []
|
|
15
|
+
for measure_id_bool in measure_id_tuple.data:
|
|
16
|
+
assert isinstance(measure_id_bool, MeasureIdBool)
|
|
17
|
+
target_rec_idx = (measure_id_bool.idx - 1) - meas_count_at_stmt
|
|
18
|
+
idx_stmt = py.constant.Constant(target_rec_idx)
|
|
19
|
+
idx_stmt.insert_before(node)
|
|
20
|
+
get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
|
|
21
|
+
get_record_stmt.insert_before(node)
|
|
22
|
+
get_record_ssas.append(get_record_stmt.result)
|
|
23
|
+
|
|
24
|
+
return get_record_ssas
|
|
@@ -4,13 +4,13 @@ from kirin import ir
|
|
|
4
4
|
from kirin.dialects import py, scf, func
|
|
5
5
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
6
6
|
|
|
7
|
-
from bloqade.squin import
|
|
7
|
+
from bloqade.squin import gate
|
|
8
8
|
from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts
|
|
9
9
|
from bloqade.squin.rewrite import AddressAttribute
|
|
10
10
|
from bloqade.stim.rewrite.util import (
|
|
11
|
-
SQUIN_STIM_CONTROL_GATE_MAPPING,
|
|
12
11
|
insert_qubit_idx_from_address,
|
|
13
12
|
)
|
|
13
|
+
from bloqade.stim.dialects.gate import CX as stim_CX, CY as stim_CY, CZ as stim_CZ
|
|
14
14
|
from bloqade.analysis.measure_id import MeasureIDFrame
|
|
15
15
|
from bloqade.stim.dialects.auxiliary import GetRecord
|
|
16
16
|
from bloqade.analysis.measure_id.lattice import (
|
|
@@ -58,8 +58,7 @@ class IfElseSimplification:
|
|
|
58
58
|
"""Check if the IfElse statement has an else body."""
|
|
59
59
|
if stmt.else_body.blocks and not (
|
|
60
60
|
len(stmt.else_body.blocks[0].stmts) == 1
|
|
61
|
-
and isinstance(
|
|
62
|
-
and not else_term.values # empty yield
|
|
61
|
+
and isinstance(stmt.else_body.blocks[0].last_stmt, scf.Yield)
|
|
63
62
|
):
|
|
64
63
|
return True
|
|
65
64
|
|
|
@@ -67,12 +66,13 @@ class IfElseSimplification:
|
|
|
67
66
|
|
|
68
67
|
|
|
69
68
|
DontLiftType = (
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
69
|
+
gate.stmts.SingleQubitGate,
|
|
70
|
+
gate.stmts.RotationGate,
|
|
71
|
+
gate.stmts.ControlledGate,
|
|
73
72
|
func.Return,
|
|
74
73
|
func.Invoke,
|
|
75
74
|
scf.IfElse,
|
|
75
|
+
scf.Yield,
|
|
76
76
|
)
|
|
77
77
|
|
|
78
78
|
|
|
@@ -99,16 +99,16 @@ class StimSplitIfStmts(IfElseSimplification, SplitIfStmts):
|
|
|
99
99
|
Given an IfElse with multiple valid statements in the then-body:
|
|
100
100
|
|
|
101
101
|
if measure_result:
|
|
102
|
-
squin.
|
|
103
|
-
squin.
|
|
102
|
+
squin.x(q0)
|
|
103
|
+
squin.y(q1)
|
|
104
104
|
|
|
105
105
|
this should be rewritten to:
|
|
106
106
|
|
|
107
107
|
if measure_result:
|
|
108
|
-
squin.
|
|
108
|
+
squin.x(q0)
|
|
109
109
|
|
|
110
110
|
if measure_result:
|
|
111
|
-
squin.
|
|
111
|
+
squin.y(q1)
|
|
112
112
|
"""
|
|
113
113
|
|
|
114
114
|
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
@@ -139,24 +139,23 @@ class IfToStim(IfElseSimplification, RewriteRule):
|
|
|
139
139
|
|
|
140
140
|
def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
|
|
141
141
|
|
|
142
|
+
# Check the condition is a singular MeasurementIdBool
|
|
142
143
|
if not isinstance(self.measure_frame.entries[stmt.cond], MeasureIdBool):
|
|
143
144
|
return RewriteResult()
|
|
144
145
|
|
|
145
|
-
#
|
|
146
|
-
#
|
|
147
|
-
# Can reuse logic from SplitIf
|
|
146
|
+
# Reusing code from SplitIf,
|
|
147
|
+
# there should only be one statement in the body and it should be a pauli X, Y, or Z
|
|
148
148
|
*stmts, _ = stmt.then_body.stmts()
|
|
149
|
-
if len(stmts) != 1
|
|
149
|
+
if len(stmts) != 1:
|
|
150
150
|
return RewriteResult()
|
|
151
151
|
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
if stim_gate is None:
|
|
152
|
+
if isinstance(stmts[0], gate.stmts.X):
|
|
153
|
+
stim_gate = stim_CX
|
|
154
|
+
elif isinstance(stmts[0], gate.stmts.Y):
|
|
155
|
+
stim_gate = stim_CY
|
|
156
|
+
elif isinstance(stmts[0], gate.stmts.Z):
|
|
157
|
+
stim_gate = stim_CZ
|
|
158
|
+
else:
|
|
160
159
|
return RewriteResult()
|
|
161
160
|
|
|
162
161
|
# get necessary measurement ID type from analysis
|
|
@@ -169,8 +168,8 @@ class IfToStim(IfElseSimplification, RewriteRule):
|
|
|
169
168
|
)
|
|
170
169
|
get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841
|
|
171
170
|
|
|
172
|
-
|
|
173
|
-
|
|
171
|
+
address_attr = stmts[0].qubits.hints.get("address")
|
|
172
|
+
|
|
174
173
|
if address_attr is None:
|
|
175
174
|
return RewriteResult()
|
|
176
175
|
assert isinstance(address_attr, AddressAttribute)
|
|
@@ -1,13 +1,11 @@
|
|
|
1
1
|
from kirin import ir
|
|
2
2
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
3
3
|
|
|
4
|
-
from bloqade
|
|
4
|
+
from bloqade import qubit
|
|
5
|
+
from bloqade.squin import gate
|
|
5
6
|
from bloqade.squin.rewrite import AddressAttribute
|
|
6
|
-
from bloqade.stim.dialects import gate
|
|
7
|
+
from bloqade.stim.dialects import gate as stim_gate, collapse as stim_collapse
|
|
7
8
|
from bloqade.stim.rewrite.util import (
|
|
8
|
-
SQUIN_STIM_OP_MAPPING,
|
|
9
|
-
rewrite_Control,
|
|
10
|
-
rewrite_QubitLoss,
|
|
11
9
|
insert_qubit_idx_from_address,
|
|
12
10
|
)
|
|
13
11
|
|
|
@@ -20,64 +18,115 @@ class SquinQubitToStim(RewriteRule):
|
|
|
20
18
|
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
21
19
|
|
|
22
20
|
match node:
|
|
23
|
-
|
|
24
|
-
|
|
21
|
+
# not supported by Stim
|
|
22
|
+
case gate.stmts.T() | gate.stmts.RotationGate():
|
|
23
|
+
return RewriteResult()
|
|
24
|
+
# If you've reached this point all gates have stim equivalents
|
|
25
|
+
case qubit.stmts.Reset():
|
|
26
|
+
return self.rewrite_Reset(node)
|
|
27
|
+
case gate.stmts.SingleQubitGate():
|
|
28
|
+
return self.rewrite_SingleQubitGate(node)
|
|
29
|
+
case gate.stmts.ControlledGate():
|
|
30
|
+
return self.rewrite_ControlledGate(node)
|
|
25
31
|
case _:
|
|
26
32
|
return RewriteResult()
|
|
27
33
|
|
|
28
|
-
def
|
|
29
|
-
self, stmt: qubit.Apply | qubit.Broadcast
|
|
30
|
-
) -> RewriteResult:
|
|
31
|
-
"""
|
|
32
|
-
Rewrite Apply and Broadcast nodes to their stim equivalent statements.
|
|
33
|
-
"""
|
|
34
|
-
|
|
35
|
-
# this is an SSAValue, need it to be the actual operator
|
|
36
|
-
applied_op = stmt.operator.owner
|
|
34
|
+
def rewrite_Reset(self, stmt: qubit.stmts.Reset) -> RewriteResult:
|
|
37
35
|
|
|
38
|
-
|
|
39
|
-
return rewrite_QubitLoss(stmt)
|
|
36
|
+
qubit_addr_attr = stmt.qubits.hints.get("address", None)
|
|
40
37
|
|
|
41
|
-
|
|
38
|
+
if qubit_addr_attr is None:
|
|
39
|
+
return RewriteResult()
|
|
42
40
|
|
|
43
|
-
|
|
44
|
-
return rewrite_Control(stmt)
|
|
41
|
+
assert isinstance(qubit_addr_attr, AddressAttribute)
|
|
45
42
|
|
|
46
|
-
|
|
43
|
+
qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
44
|
+
address=qubit_addr_attr, stmt_to_insert_before=stmt
|
|
45
|
+
)
|
|
47
46
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
if isinstance(applied_op, op.stmts.Adjoint):
|
|
51
|
-
if not applied_op.is_unitary:
|
|
52
|
-
return RewriteResult()
|
|
47
|
+
if qubit_idx_ssas is None:
|
|
48
|
+
return RewriteResult()
|
|
53
49
|
|
|
54
|
-
|
|
55
|
-
|
|
50
|
+
stim_stmt = stim_collapse.RZ(targets=tuple(qubit_idx_ssas))
|
|
51
|
+
stmt.replace_by(stim_stmt)
|
|
56
52
|
|
|
57
|
-
|
|
58
|
-
if stim_1q_op is None:
|
|
59
|
-
return RewriteResult()
|
|
53
|
+
return RewriteResult(has_done_something=True)
|
|
60
54
|
|
|
61
|
-
|
|
55
|
+
def rewrite_SingleQubitGate(
|
|
56
|
+
self, stmt: gate.stmts.SingleQubitGate
|
|
57
|
+
) -> RewriteResult:
|
|
58
|
+
"""
|
|
59
|
+
Rewrite single qubit gate nodes to their stim equivalent statements.
|
|
60
|
+
Address Analysis should have been run along with Wrap Analysis before this rewrite is applied.
|
|
61
|
+
"""
|
|
62
62
|
|
|
63
|
-
|
|
63
|
+
qubit_addr_attr = stmt.qubits.hints.get("address", None)
|
|
64
|
+
if qubit_addr_attr is None:
|
|
64
65
|
return RewriteResult()
|
|
65
66
|
|
|
66
|
-
assert isinstance(
|
|
67
|
+
assert isinstance(qubit_addr_attr, AddressAttribute)
|
|
68
|
+
|
|
67
69
|
qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
68
|
-
address=
|
|
70
|
+
address=qubit_addr_attr, stmt_to_insert_before=stmt
|
|
69
71
|
)
|
|
70
72
|
|
|
71
73
|
if qubit_idx_ssas is None:
|
|
72
74
|
return RewriteResult()
|
|
73
75
|
|
|
74
|
-
if
|
|
75
|
-
|
|
76
|
+
# Get the name of the inputted stmt and see if there is an
|
|
77
|
+
# equivalently named statement in stim,
|
|
78
|
+
# then create an instance of that stim statement
|
|
79
|
+
stmt_name = type(stmt).__name__
|
|
80
|
+
stim_stmt_cls = getattr(stim_gate.stmts, stmt_name, None)
|
|
81
|
+
if stim_stmt_cls is None:
|
|
82
|
+
return RewriteResult()
|
|
83
|
+
|
|
84
|
+
if isinstance(stmt, gate.stmts.SingleQubitNonHermitianGate):
|
|
85
|
+
stim_stmt = stim_stmt_cls(
|
|
86
|
+
targets=tuple(qubit_idx_ssas), dagger=stmt.adjoint
|
|
87
|
+
)
|
|
76
88
|
else:
|
|
77
|
-
|
|
78
|
-
stmt.replace_by(
|
|
89
|
+
stim_stmt = stim_stmt_cls(targets=tuple(qubit_idx_ssas))
|
|
90
|
+
stmt.replace_by(stim_stmt)
|
|
79
91
|
|
|
80
92
|
return RewriteResult(has_done_something=True)
|
|
81
93
|
|
|
94
|
+
def rewrite_ControlledGate(self, stmt: gate.stmts.ControlledGate) -> RewriteResult:
|
|
95
|
+
"""
|
|
96
|
+
Rewrite controlled gate nodes to their stim equivalent statements.
|
|
97
|
+
Address Analysis should have been run along with Wrap Analysis before this rewrite is applied.
|
|
98
|
+
"""
|
|
82
99
|
|
|
83
|
-
|
|
100
|
+
controls_addr_attr = stmt.controls.hints.get("address", None)
|
|
101
|
+
targets_addr_attr = stmt.targets.hints.get("address", None)
|
|
102
|
+
|
|
103
|
+
if controls_addr_attr is None or targets_addr_attr is None:
|
|
104
|
+
return RewriteResult()
|
|
105
|
+
|
|
106
|
+
assert isinstance(controls_addr_attr, AddressAttribute)
|
|
107
|
+
assert isinstance(targets_addr_attr, AddressAttribute)
|
|
108
|
+
|
|
109
|
+
controls_idx_ssas = insert_qubit_idx_from_address(
|
|
110
|
+
address=controls_addr_attr, stmt_to_insert_before=stmt
|
|
111
|
+
)
|
|
112
|
+
targets_idx_ssas = insert_qubit_idx_from_address(
|
|
113
|
+
address=targets_addr_attr, stmt_to_insert_before=stmt
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
if controls_idx_ssas is None or targets_idx_ssas is None:
|
|
117
|
+
return RewriteResult()
|
|
118
|
+
|
|
119
|
+
# Get the name of the inputted stmt and see if there is an
|
|
120
|
+
# equivalently named statement in stim,
|
|
121
|
+
# then create an instance of that stim statement
|
|
122
|
+
stmt_name = type(stmt).__name__
|
|
123
|
+
stim_stmt_cls = getattr(stim_gate.stmts, stmt_name, None)
|
|
124
|
+
if stim_stmt_cls is None:
|
|
125
|
+
return RewriteResult()
|
|
126
|
+
|
|
127
|
+
stim_stmt = stim_stmt_cls(
|
|
128
|
+
targets=tuple(targets_idx_ssas), controls=tuple(controls_idx_ssas)
|
|
129
|
+
)
|
|
130
|
+
stmt.replace_by(stim_stmt)
|
|
131
|
+
|
|
132
|
+
return RewriteResult(has_done_something=True)
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from typing import Iterable
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
from kirin import ir
|
|
5
|
+
from kirin.dialects.py import Constant
|
|
6
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
7
|
+
|
|
8
|
+
from bloqade.stim.dialects import auxiliary
|
|
9
|
+
from bloqade.annotate.stmts import SetDetector
|
|
10
|
+
from bloqade.analysis.measure_id import MeasureIDFrame
|
|
11
|
+
from bloqade.stim.dialects.auxiliary import Detector
|
|
12
|
+
from bloqade.analysis.measure_id.lattice import MeasureIdTuple
|
|
13
|
+
|
|
14
|
+
from ..rewrite.get_record_util import insert_get_records
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
|
|
18
|
+
class SetDetectorToStim(RewriteRule):
|
|
19
|
+
"""
|
|
20
|
+
Rewrite SetDetector to GetRecord and Detector in the stim dialect
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
measure_id_frame: MeasureIDFrame
|
|
24
|
+
|
|
25
|
+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
26
|
+
match node:
|
|
27
|
+
case SetDetector():
|
|
28
|
+
return self.rewrite_SetDetector(node)
|
|
29
|
+
case _:
|
|
30
|
+
return RewriteResult()
|
|
31
|
+
|
|
32
|
+
def rewrite_SetDetector(self, node: SetDetector) -> RewriteResult:
|
|
33
|
+
|
|
34
|
+
# get coordinates and generate correct consts
|
|
35
|
+
coord_ssas = []
|
|
36
|
+
if not isinstance(node.coordinates.owner, Constant):
|
|
37
|
+
return RewriteResult()
|
|
38
|
+
|
|
39
|
+
coord_values = node.coordinates.owner.value.unwrap()
|
|
40
|
+
|
|
41
|
+
if not isinstance(coord_values, Iterable):
|
|
42
|
+
return RewriteResult()
|
|
43
|
+
|
|
44
|
+
if any(not isinstance(value, (int, float)) for value in coord_values):
|
|
45
|
+
return RewriteResult()
|
|
46
|
+
|
|
47
|
+
for coord_value in coord_values:
|
|
48
|
+
if isinstance(coord_value, float):
|
|
49
|
+
coord_stmt = auxiliary.ConstFloat(value=coord_value)
|
|
50
|
+
else: # int
|
|
51
|
+
coord_stmt = auxiliary.ConstInt(value=coord_value)
|
|
52
|
+
coord_ssas.append(coord_stmt.result)
|
|
53
|
+
coord_stmt.insert_before(node)
|
|
54
|
+
|
|
55
|
+
measure_ids = self.measure_id_frame.entries[node.measurements]
|
|
56
|
+
assert isinstance(measure_ids, MeasureIdTuple)
|
|
57
|
+
|
|
58
|
+
get_record_list = insert_get_records(
|
|
59
|
+
node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node]
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
detector_stmt = Detector(
|
|
63
|
+
coord=tuple(coord_ssas), targets=tuple(get_record_list)
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
node.replace_by(detector_stmt)
|
|
67
|
+
|
|
68
|
+
return RewriteResult(has_done_something=True)
|