bloqade-circuit 0.4.5__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic. Click here for more details.
- bloqade/analysis/address/impls.py +21 -68
- bloqade/analysis/measure_id/__init__.py +2 -0
- bloqade/analysis/measure_id/analysis.py +45 -0
- bloqade/analysis/measure_id/impls.py +155 -0
- bloqade/analysis/measure_id/lattice.py +82 -0
- bloqade/cirq_utils/__init__.py +7 -0
- bloqade/cirq_utils/lineprog.py +295 -0
- bloqade/cirq_utils/parallelize.py +400 -0
- bloqade/pyqrack/squin/op.py +7 -2
- bloqade/pyqrack/squin/runtime.py +4 -2
- bloqade/qasm2/dialects/expr/stmts.py +2 -20
- bloqade/qasm2/parse/lowering.py +1 -0
- bloqade/qasm2/passes/parallel.py +18 -0
- bloqade/qasm2/passes/unroll_if.py +9 -2
- bloqade/qasm2/rewrite/__init__.py +1 -0
- bloqade/qasm2/rewrite/parallel_to_glob.py +82 -0
- bloqade/rewrite/__init__.py +0 -0
- bloqade/rewrite/passes/__init__.py +1 -0
- bloqade/rewrite/passes/canonicalize_ilist.py +28 -0
- bloqade/rewrite/rules/__init__.py +1 -0
- bloqade/rewrite/rules/flatten_ilist.py +51 -0
- bloqade/rewrite/rules/inline_getitem_ilist.py +31 -0
- bloqade/{qasm2/rewrite → rewrite/rules}/split_ifs.py +15 -8
- bloqade/squin/__init__.py +2 -0
- bloqade/squin/_typeinfer.py +20 -0
- bloqade/squin/analysis/__init__.py +1 -0
- bloqade/squin/analysis/address_impl.py +71 -0
- bloqade/squin/analysis/nsites/impls.py +6 -1
- bloqade/squin/cirq/lowering.py +19 -6
- bloqade/squin/noise/stmts.py +1 -1
- bloqade/squin/op/__init__.py +1 -0
- bloqade/squin/op/_wrapper.py +4 -0
- bloqade/squin/op/stmts.py +20 -2
- bloqade/squin/qubit.py +8 -5
- bloqade/squin/rewrite/__init__.py +1 -0
- bloqade/squin/rewrite/canonicalize.py +60 -0
- bloqade/squin/rewrite/desugar.py +52 -5
- bloqade/squin/types.py +8 -0
- bloqade/squin/wire.py +91 -5
- bloqade/stim/__init__.py +1 -0
- bloqade/stim/_wrappers.py +4 -0
- bloqade/stim/dialects/auxiliary/interp.py +0 -10
- bloqade/stim/dialects/auxiliary/stmts/annotate.py +1 -1
- bloqade/stim/dialects/noise/emit.py +1 -0
- bloqade/stim/dialects/noise/stmts.py +5 -0
- bloqade/stim/passes/__init__.py +1 -1
- bloqade/stim/passes/simplify_ifs.py +32 -0
- bloqade/stim/passes/squin_to_stim.py +109 -26
- bloqade/stim/rewrite/__init__.py +1 -0
- bloqade/stim/rewrite/ifs_to_stim.py +203 -0
- bloqade/stim/rewrite/qubit_to_stim.py +13 -6
- bloqade/stim/rewrite/squin_measure.py +68 -5
- bloqade/stim/rewrite/squin_noise.py +120 -0
- bloqade/stim/rewrite/util.py +40 -9
- bloqade/stim/rewrite/wire_to_stim.py +8 -3
- bloqade/stim/upstream/__init__.py +1 -0
- bloqade/stim/upstream/from_squin.py +10 -0
- {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/METADATA +4 -2
- {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/RECORD +61 -38
- {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from typing import cast
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.rewrite import abc
|
|
5
|
+
from kirin.dialects import cf
|
|
6
|
+
|
|
7
|
+
from .. import wire
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CanonicalizeWired(abc.RewriteRule):
    """Inline the body of a `wire.Wired` statement that takes no qubits.

    A `Wired` statement with an empty `qubits` list threads no wires, so its
    body can be spliced directly into the enclosing region's control flow:

    1. every statement after the `Wired` node in its block is moved into a
       new continuation block (``after_block``), whose block arguments replace
       the node's results;
    2. the parent block gets an unconditional `cf.Branch` into the body's
       entry block;
    3. the body blocks are moved (in order) between the parent block and
       ``after_block``; any block ending in `wire.Yield` has that yield turned
       into a `cf.Branch` to ``after_block`` carrying the yielded values.
    """

    def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
        if (
            not isinstance(node, wire.Wired)
            or len(node.qubits) != 0
            or (parent_region := node.parent_region) is None
            # an empty body has no entry block to branch into
            or len(node.body.blocks) == 0
        ):
            return abc.RewriteResult()

        parent_block = cast(ir.Block, node.parent_block)
        entry_block = node.body.blocks[0]

        # The node has no qubits, so (per the original author) the body holds no
        # quantum operations and can be inlined into the parent CFG directly.

        # Move all statements after `node` in the current block into a fresh
        # continuation block.
        after_block = ir.Block()
        stmt = node.next_stmt
        while stmt is not None:
            stmt.detach()
            after_block.stmts.append(stmt)
            # after detaching, `node.next_stmt` is the following statement
            stmt = node.next_stmt

        # Uses of the node's results now read the continuation block's arguments.
        for result in node.results:
            arg = after_block.args.append_from(result.type, result.name)
            result.replace_by(arg)

        parent_block_idx = parent_region._block_idx[parent_block]
        # Place the continuation block directly after the parent block.
        parent_region.blocks.insert(parent_block_idx + 1, after_block)

        # Jump from the parent block into the body's entry block. No arguments
        # are passed: Wired.check requires one entry-block argument per qubit,
        # and this node has none.
        parent_block.stmts.append(cf.Branch(arguments=(), successor=entry_block))

        # Move the body blocks between the parent block and the continuation
        # block, keeping their original order (insert in reverse at the same
        # index). Iterate over a snapshot so detaching blocks cannot disturb
        # the iteration.
        for block in reversed(list(node.body.blocks)):
            block.detach()
            if isinstance((yield_stmt := block.last_stmt), wire.Yield):
                yield_stmt.replace_by(
                    cf.Branch(yield_stmt.values, successor=after_block)
                )
            parent_region.blocks.insert(parent_block_idx + 1, block)

        node.delete()
        return abc.RewriteResult(has_done_something=True)
|
bloqade/squin/rewrite/desugar.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from kirin import ir, types
|
|
2
|
-
from kirin.dialects import ilist
|
|
2
|
+
from kirin.dialects import py, ilist
|
|
3
3
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
4
4
|
|
|
5
5
|
from bloqade.squin.qubit import (
|
|
@@ -53,12 +53,59 @@ class ApplyDesugarRule(RewriteRule):
|
|
|
53
53
|
op = node.operator
|
|
54
54
|
qubits = node.qubits
|
|
55
55
|
|
|
56
|
-
if len(qubits)
|
|
57
|
-
|
|
56
|
+
if len(qubits) > 1 and all(q.type.is_subseteq(QubitType) for q in qubits):
|
|
57
|
+
(qubits_ilist_stmt := ilist.New(qubits)).insert_before(node)
|
|
58
|
+
qubits_ilist = qubits_ilist_stmt.result
|
|
59
|
+
|
|
60
|
+
elif len(qubits) == 1 and qubits[0].type.is_subseteq(QubitType):
|
|
61
|
+
(qubits_ilist_stmt := ilist.New(qubits)).insert_before(node)
|
|
62
|
+
qubits_ilist = qubits_ilist_stmt.result
|
|
63
|
+
|
|
64
|
+
elif len(qubits) == 1 and qubits[0].type.is_subseteq(
|
|
65
|
+
ilist.IListType[QubitType, types.Any]
|
|
66
|
+
):
|
|
58
67
|
qubits_ilist = qubits[0]
|
|
59
|
-
|
|
60
|
-
|
|
68
|
+
|
|
69
|
+
elif len(qubits) == 1:
|
|
70
|
+
# TODO: remove this elif clause once we're at kirin v0.18
|
|
71
|
+
# NOTE: this is a temporary workaround for kirin#408
|
|
72
|
+
# currently type inference fails here in for loops since the loop var
|
|
73
|
+
# is an IList for some reason
|
|
74
|
+
|
|
75
|
+
if not isinstance(qubits[0], ir.ResultValue):
|
|
76
|
+
return RewriteResult()
|
|
77
|
+
|
|
78
|
+
is_ilist = isinstance(qbit_stmt := qubits[0].stmt, ilist.New)
|
|
79
|
+
if is_ilist:
|
|
80
|
+
if len(qbit_stmt.values) != 1:
|
|
81
|
+
return RewriteResult()
|
|
82
|
+
|
|
83
|
+
if not isinstance(
|
|
84
|
+
qbit_getindex_result := qbit_stmt.values[0], ir.ResultValue
|
|
85
|
+
):
|
|
86
|
+
return RewriteResult()
|
|
87
|
+
|
|
88
|
+
qbit_getindex = qbit_getindex_result.stmt
|
|
89
|
+
else:
|
|
90
|
+
qbit_getindex = qubits[0].stmt
|
|
91
|
+
|
|
92
|
+
if not isinstance(qbit_getindex, py.indexing.GetItem):
|
|
93
|
+
return RewriteResult()
|
|
94
|
+
|
|
95
|
+
if not qbit_getindex.obj.type.is_subseteq(
|
|
96
|
+
ilist.IListType[QubitType, types.Any]
|
|
97
|
+
):
|
|
98
|
+
return RewriteResult()
|
|
99
|
+
|
|
100
|
+
if is_ilist:
|
|
101
|
+
values = qbit_stmt.values
|
|
102
|
+
else:
|
|
103
|
+
values = [qubits[0]]
|
|
104
|
+
|
|
105
|
+
(qubits_ilist_stmt := ilist.New(values=values)).insert_before(node)
|
|
61
106
|
qubits_ilist = qubits_ilist_stmt.result
|
|
107
|
+
else:
|
|
108
|
+
return RewriteResult()
|
|
62
109
|
|
|
63
110
|
stmt = Apply(operator=op, qubits=qubits_ilist)
|
|
64
111
|
node.replace_by(stmt)
|
bloqade/squin/types.py
ADDED
bloqade/squin/wire.py
CHANGED
|
@@ -6,12 +6,15 @@ circuits. Thus we do not define wrapping functions for the statements in this
|
|
|
6
6
|
dialect.
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
|
-
from kirin import ir, types, lowering
|
|
9
|
+
from kirin import ir, types, lowering, exception
|
|
10
10
|
from kirin.decl import info, statement
|
|
11
|
+
from kirin.dialects import func
|
|
11
12
|
from kirin.lowering import wraps
|
|
13
|
+
from kirin.ir.attrs.types import TypeAttribute
|
|
12
14
|
|
|
13
15
|
from bloqade.types import Qubit, QubitType
|
|
14
16
|
|
|
17
|
+
from .types import MeasurementResultType
|
|
15
18
|
from .op.types import Op, OpType
|
|
16
19
|
|
|
17
20
|
# from kirin.lowering import wraps
|
|
@@ -49,11 +52,87 @@ class Unwrap(ir.Statement):
|
|
|
49
52
|
result: ir.ResultValue = info.result(WireType)
|
|
50
53
|
|
|
51
54
|
|
|
55
|
+
@statement(dialect=dialect)
class Wired(ir.Statement):
    """Region statement that threads one wire per qubit through `body`.

    The entry block of `body` must declare exactly one `WireType` argument
    for each value in `qubits` (enforced by `check`). Blocks ending in a
    `Yield` hand their values out as this statement's results.
    `memory_zone` is stored as a `PyAttr` attribute.
    """

    traits = frozenset()

    qubits: tuple[ir.SSAValue, ...] = info.argument(QubitType)
    memory_zone: str = info.attribute()
    body: ir.Region = info.region(multi=True)

    def __init__(
        self,
        body: ir.Region,
        *qubits: ir.SSAValue,
        memory_zone: str,
        result_types: tuple[TypeAttribute, ...] | None = None,
    ):
        # Without explicit result types, infer them from the first block that
        # ends in a `Yield`. A body with no `Yield` produces no results.
        if result_types is None:
            for block in body.blocks:
                if isinstance(block.last_stmt, Yield):
                    result_types = tuple(arg.type for arg in block.last_stmt.values)
                    break

        if result_types is None:
            result_types = ()

        super().__init__(
            args=qubits,
            args_slice={
                "qubits": slice(0, None),
            },
            regions=[body],  # body of the wired statement
            attributes={
                "memory_zone": ir.PyAttr(memory_zone),
            },
            result_types=result_types,
        )

    def check(self):
        """Validate the entry-block signature and the block terminators.

        Raises:
            exception.StaticCheckError: if the body has no blocks, the entry
                block's argument count differs from the number of qubits, an
                entry argument is not a wire, a block ends in `func.Return`,
                or a `Yield` carries a different number of values than this
                statement has results.
        """
        # Guard before indexing: an empty body would otherwise surface as an
        # IndexError rather than a static-check error.
        if len(self.body.blocks) == 0:
            raise exception.StaticCheckError(
                "Expected the body of a Wired statement to contain at least one block."
            )

        entry_block = self.body.blocks[0]

        if len(entry_block.args) != len(self.qubits):
            raise exception.StaticCheckError(
                f"Expected {len(self.qubits)} arguments, got {len(entry_block.args)}."
            )
        for arg in entry_block.args:
            if not arg.type.is_subseteq(WireType):
                raise exception.StaticCheckError(
                    f"Expected argument of type {WireType}, got {arg.type}."
                )
        for block in self.body.blocks:
            last_stmt = block.last_stmt
            if isinstance(last_stmt, func.Return):
                raise exception.StaticCheckError(
                    "Return statements are not allowed in the body of a Wired statement."
                )
            elif isinstance(last_stmt, Yield) and len(last_stmt.values) != len(
                self.results
            ):
                raise exception.StaticCheckError(
                    f"Expected {len(self.results)} return values, got {len(last_stmt.values)}."
                )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@statement(dialect=dialect)
class Yield(ir.Statement):
    """Hand wires out of a `Wired` body.

    `Wired` uses the types of the first `Yield` it finds to infer its result
    types, and `Wired.check` requires the value count to match its results.
    """

    traits = frozenset()
    values: tuple[ir.SSAValue, ...] = info.argument(WireType)

    def __init__(self, *args: ir.SSAValue):
        # every positional SSA value belongs to the variadic `values` field
        super().__init__(args=args, args_slice={"values": slice(0, None)})
|
|
129
|
+
|
|
130
|
+
|
|
52
131
|
# In Quake, you put a wire in and get a wire out when you "apply" an operator
|
|
53
132
|
# In this case though we just need to indicate that an operator is applied to list[wires]
|
|
54
133
|
@statement(dialect=dialect)
|
|
55
134
|
class Apply(ir.Statement): # apply(op, w1, w2, ...)
|
|
56
|
-
traits = frozenset({lowering.FromPythonCall()
|
|
135
|
+
traits = frozenset({lowering.FromPythonCall()})
|
|
57
136
|
operator: ir.SSAValue = info.argument(OpType)
|
|
58
137
|
inputs: tuple[ir.SSAValue, ...] = info.argument(WireType)
|
|
59
138
|
|
|
@@ -88,6 +167,13 @@ class Broadcast(ir.Statement):
|
|
|
88
167
|
) # custom lowering required for wrapper to work here
|
|
89
168
|
|
|
90
169
|
|
|
170
|
+
@statement(dialect=dialect)
class RegionMeasure(ir.Statement):
    """Measure a wire, producing a measurement result.

    Unlike `Measure`, this statement takes only the wire (no qubit argument).
    It carries `WireTerminator`, so the wire does not continue after it.
    """

    traits = frozenset({lowering.FromPythonCall(), WireTerminator()})

    # operand: the wire being measured
    wire: ir.SSAValue = info.argument(WireType)
    # result: the measurement outcome
    result: ir.ResultValue = info.result(MeasurementResultType)
|
|
175
|
+
|
|
176
|
+
|
|
91
177
|
# NOTE: measurement cannot be pure because they will collapse the state
|
|
92
178
|
# of the qubit. The state is a hidden state that is not visible to
|
|
93
179
|
# the user in the wire dialect.
|
|
@@ -96,14 +182,14 @@ class Measure(ir.Statement):
|
|
|
96
182
|
traits = frozenset({lowering.FromPythonCall(), WireTerminator()})
|
|
97
183
|
wire: ir.SSAValue = info.argument(WireType)
|
|
98
184
|
qubit: ir.SSAValue = info.argument(QubitType)
|
|
99
|
-
result: ir.ResultValue = info.result(
|
|
185
|
+
result: ir.ResultValue = info.result(MeasurementResultType)
|
|
100
186
|
|
|
101
187
|
|
|
102
188
|
@statement(dialect=dialect)
class LossResolvingMeasure(ir.Statement):
    """Measure `input_wire`, returning the outcome and a continuing wire.

    Unlike `Measure`, this statement does not carry `WireTerminator`; the
    wire continues as `out_wire`.
    """

    traits = frozenset({lowering.FromPythonCall()})

    input_wire: ir.SSAValue = info.argument(WireType)
    # results, in order: the measurement outcome, then the continuing wire
    result: ir.ResultValue = info.result(MeasurementResultType)
    out_wire: ir.ResultValue = info.result(WireType)
|
|
108
194
|
|
|
109
195
|
|
bloqade/stim/__init__.py
CHANGED
bloqade/stim/_wrappers.py
CHANGED
|
@@ -190,3 +190,7 @@ def y_error(p: float, targets: tuple[int, ...]) -> None: ...
|
|
|
190
190
|
|
|
191
191
|
@wraps(noise.ZError)
def z_error(p: float, targets: tuple[int, ...]) -> None:
    """Python-level signature for the Stim `noise.ZError` statement.

    The body is a stub; `wraps` presumably routes calls to the statement
    during lowering. `p` is presumably the error probability — confirm.
    """
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
@wraps(noise.QubitLoss)
def qubit_loss(probs: tuple[float, ...], targets: tuple[int, ...]) -> None:
    """Python-level signature for the Stim `noise.QubitLoss` statement.

    The body is a stub; `wraps` presumably routes calls to the statement
    during lowering. `probs` presumably holds loss probabilities — confirm.
    """
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from kirin import interp
|
|
2
2
|
|
|
3
3
|
from . import stmts
|
|
4
|
-
from .types import RecordResult
|
|
5
4
|
from ._dialect import dialect
|
|
6
5
|
|
|
7
6
|
|
|
@@ -28,12 +27,3 @@ class StimAuxMethods(interp.MethodTable):
|
|
|
28
27
|
stmt: stmts.Neg,
|
|
29
28
|
):
|
|
30
29
|
return (-frame.get(stmt.operand),)
|
|
31
|
-
|
|
32
|
-
@interp.impl(stmts.GetRecord)
|
|
33
|
-
def get_rec(
|
|
34
|
-
self,
|
|
35
|
-
interpreter: interp.Interpreter,
|
|
36
|
-
frame: interp.Frame,
|
|
37
|
-
stmt: stmts.GetRecord,
|
|
38
|
-
):
|
|
39
|
-
return (RecordResult(value=frame.get(stmt.id)),)
|
|
@@ -10,7 +10,7 @@ PyNum = types.Union(types.Int, types.Float)
|
|
|
10
10
|
@statement(dialect=dialect)
class GetRecord(ir.Statement):
    """Look up a measurement-record entry by integer `id` (named `get_rec`).

    Marked `ir.Pure()`, so rewrites may treat it as side-effect free.
    """

    name = "get_rec"
    traits = frozenset({ir.Pure(), lowering.FromPythonCall()})

    id: ir.SSAValue = info.argument(type=types.Int)
    result: ir.ResultValue = info.result(type=RecordType)
|
|
16
16
|
|
bloqade/stim/passes/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
from .squin_to_stim import
|
|
1
|
+
from .squin_to_stim import SquinToStimPass as SquinToStimPass
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.passes import Pass
|
|
5
|
+
from kirin.rewrite import (
|
|
6
|
+
Walk,
|
|
7
|
+
Chain,
|
|
8
|
+
Fixpoint,
|
|
9
|
+
ConstantFold,
|
|
10
|
+
CommonSubexpressionElimination,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from ..rewrite.ifs_to_stim import StimLiftThenBody, StimSplitIfStmts
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class StimSimplifyIfs(Pass):
    """Reshape if statements into a form the Stim if-rewrites can handle.

    First lifts then-bodies (to a fixpoint) and splits if statements. It then
    cleans up with constant folding plus common subexpression elimination,
    also to a fixpoint.
    """

    def unsafe_run(self, mt: ir.Method):
        lift_then_bodies = Fixpoint(Walk(StimLiftThenBody()))
        split_ifs = Walk(StimSplitIfStmts())
        structural = Chain(lift_then_bodies, split_ifs).rewrite(mt.code)

        # folding can expose new CSE opportunities and vice versa
        cleanup_rule = Fixpoint(
            Walk(Chain(ConstantFold(), CommonSubexpressionElimination()))
        )
        cleanup = cleanup_rule.rewrite(mt.code)

        return cleanup.join(structural)
|
|
@@ -5,58 +5,153 @@ from kirin.rewrite import (
|
|
|
5
5
|
Walk,
|
|
6
6
|
Chain,
|
|
7
7
|
Fixpoint,
|
|
8
|
+
CFGCompactify,
|
|
9
|
+
InlineGetItem,
|
|
10
|
+
InlineGetField,
|
|
8
11
|
DeadCodeElimination,
|
|
9
12
|
CommonSubexpressionElimination,
|
|
10
13
|
)
|
|
14
|
+
from kirin.analysis import const
|
|
15
|
+
from kirin.dialects import scf, ilist
|
|
11
16
|
from kirin.ir.method import Method
|
|
12
17
|
from kirin.passes.abc import Pass
|
|
13
18
|
from kirin.rewrite.abc import RewriteResult
|
|
19
|
+
from kirin.passes.inline import InlinePass
|
|
14
20
|
|
|
15
|
-
from bloqade.stim.groups import main as stim_main_group
|
|
16
21
|
from bloqade.stim.rewrite import (
|
|
17
22
|
SquinWireToStim,
|
|
18
23
|
PyConstantToStim,
|
|
24
|
+
SquinNoiseToStim,
|
|
19
25
|
SquinQubitToStim,
|
|
20
26
|
SquinMeasureToStim,
|
|
21
27
|
SquinWireIdentityElimination,
|
|
22
28
|
)
|
|
23
|
-
from bloqade.squin.rewrite import
|
|
29
|
+
from bloqade.squin.rewrite import (
|
|
30
|
+
SquinU3ToClifford,
|
|
31
|
+
RemoveDeadRegister,
|
|
32
|
+
WrapAddressAnalysis,
|
|
33
|
+
)
|
|
34
|
+
from bloqade.rewrite.passes import CanonicalizeIList
|
|
35
|
+
from bloqade.analysis.address import AddressAnalysis
|
|
36
|
+
from bloqade.analysis.measure_id import MeasurementIDAnalysis
|
|
37
|
+
|
|
38
|
+
from .simplify_ifs import StimSimplifyIfs
|
|
39
|
+
from ..rewrite.ifs_to_stim import IfToStim
|
|
24
40
|
|
|
25
41
|
|
|
26
42
|
@dataclass
|
|
27
|
-
class
|
|
43
|
+
class SquinToStimPass(Pass):
|
|
28
44
|
|
|
29
45
|
def unsafe_run(self, mt: Method) -> RewriteResult:
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
46
|
+
|
|
47
|
+
cp_frame, _ = const.Propagate(dialects=mt.dialects).run_analysis(mt)
|
|
48
|
+
cp_results = cp_frame.entries
|
|
33
49
|
|
|
34
50
|
# Assume that address analysis and
|
|
35
51
|
# wrapping has been done before this pass!
|
|
52
|
+
# inline aggressively:
|
|
53
|
+
rewrite_result = InlinePass(
|
|
54
|
+
dialects=mt.dialects, no_raise=self.no_raise
|
|
55
|
+
).unsafe_run(mt)
|
|
56
|
+
|
|
57
|
+
rule = Chain(
|
|
58
|
+
InlineGetField(),
|
|
59
|
+
InlineGetItem(),
|
|
60
|
+
scf.unroll.ForLoop(),
|
|
61
|
+
scf.trim.UnusedYield(),
|
|
62
|
+
)
|
|
63
|
+
rewrite_result = Fixpoint(Walk(rule)).rewrite(mt.code).join(rewrite_result)
|
|
64
|
+
# fold_pass = Fold(mt.dialects, no_raise=self.no_raise)
|
|
65
|
+
# rewrite_result = fold_pass(mt)
|
|
66
|
+
rewrite_result = (
|
|
67
|
+
Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(rewrite_result)
|
|
68
|
+
)
|
|
69
|
+
rewrite_result = (
|
|
70
|
+
StimSimplifyIfs(mt.dialects, no_raise=self.no_raise)
|
|
71
|
+
.unsafe_run(mt)
|
|
72
|
+
.join(rewrite_result)
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
# run typeinfer again after unroll etc. because we now insert
|
|
76
|
+
# a lot of new nodes, which might have more precise types
|
|
77
|
+
# self.typeinfer.unsafe_run(mt)
|
|
78
|
+
rewrite_result = (
|
|
79
|
+
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
|
|
80
|
+
.rewrite(mt.code)
|
|
81
|
+
.join(rewrite_result)
|
|
82
|
+
)
|
|
83
|
+
rewrite_result = Fold(mt.dialects, no_raise=self.no_raise)(mt)
|
|
84
|
+
|
|
85
|
+
rewrite_result = (
|
|
86
|
+
CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
|
|
87
|
+
.unsafe_run(mt)
|
|
88
|
+
.join(rewrite_result)
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
# after this the program should be in a state where it is analyzable
|
|
92
|
+
# -------------------------------------------------------------------
|
|
93
|
+
|
|
94
|
+
mia = MeasurementIDAnalysis(dialects=mt.dialects)
|
|
95
|
+
meas_analysis_frame, _ = mia.run_analysis(mt, no_raise=self.no_raise)
|
|
96
|
+
|
|
97
|
+
aa = AddressAnalysis(dialects=mt.dialects)
|
|
98
|
+
address_analysis_frame, _ = aa.run_analysis(mt, no_raise=self.no_raise)
|
|
99
|
+
|
|
100
|
+
# wrap the address analysis result
|
|
101
|
+
rewrite_result = (
|
|
102
|
+
Walk(WrapAddressAnalysis(address_analysis=address_analysis_frame.entries))
|
|
103
|
+
.rewrite(mt.code)
|
|
104
|
+
.join(rewrite_result)
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
# 2. rewrite
|
|
108
|
+
rewrite_result = (
|
|
109
|
+
Walk(
|
|
110
|
+
IfToStim(
|
|
111
|
+
measure_analysis=meas_analysis_frame.entries,
|
|
112
|
+
measure_count=mia.measure_count,
|
|
113
|
+
)
|
|
114
|
+
)
|
|
115
|
+
.rewrite(mt.code)
|
|
116
|
+
.join(rewrite_result)
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
# Rewrite the noise statements first.
|
|
120
|
+
rewrite_result = (
|
|
121
|
+
Walk(SquinNoiseToStim(cp_results=cp_results))
|
|
122
|
+
.rewrite(mt.code)
|
|
123
|
+
.join(rewrite_result)
|
|
124
|
+
)
|
|
36
125
|
|
|
37
126
|
# Wrap Rewrite + SquinToStim can happen w/ standard walk
|
|
127
|
+
rewrite_result = Walk(SquinU3ToClifford()).rewrite(mt.code).join(rewrite_result)
|
|
128
|
+
|
|
38
129
|
rewrite_result = (
|
|
39
130
|
Walk(
|
|
40
131
|
Chain(
|
|
41
132
|
SquinQubitToStim(),
|
|
42
133
|
SquinWireToStim(),
|
|
43
|
-
SquinMeasureToStim(
|
|
134
|
+
SquinMeasureToStim(
|
|
135
|
+
measure_id_result=meas_analysis_frame.entries,
|
|
136
|
+
total_measure_count=mia.measure_count,
|
|
137
|
+
), # reduce duplicated logic, can split out even more rules later
|
|
44
138
|
SquinWireIdentityElimination(),
|
|
45
139
|
)
|
|
46
140
|
)
|
|
47
141
|
.rewrite(mt.code)
|
|
48
142
|
.join(rewrite_result)
|
|
49
143
|
)
|
|
50
|
-
|
|
51
|
-
# Convert all PyConsts to Stim Constants
|
|
52
144
|
rewrite_result = (
|
|
53
|
-
|
|
145
|
+
CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
|
|
146
|
+
.unsafe_run(mt)
|
|
147
|
+
.join(rewrite_result)
|
|
54
148
|
)
|
|
55
149
|
|
|
56
|
-
#
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
150
|
+
# Convert all PyConsts to Stim Constants
|
|
151
|
+
rewrite_result = Walk(PyConstantToStim()).rewrite(mt.code).join(rewrite_result)
|
|
152
|
+
|
|
153
|
+
# clear up leftover stmts
|
|
154
|
+
# - remove any squin.qubit.new that's left around
|
|
60
155
|
rewrite_result = (
|
|
61
156
|
Fixpoint(
|
|
62
157
|
Walk(
|
|
@@ -71,16 +166,4 @@ class SquinToStim(Pass):
|
|
|
71
166
|
.join(rewrite_result)
|
|
72
167
|
)
|
|
73
168
|
|
|
74
|
-
# do program verification here,
|
|
75
|
-
# ideally use built-in .verify() to catch any
|
|
76
|
-
# incompatible statements as the full rewrite process should not
|
|
77
|
-
# leave statements from any other dialects (other than the stim main group)
|
|
78
|
-
mt_verification_clone = mt.similar(stim_main_group)
|
|
79
|
-
|
|
80
|
-
# suggested by Kai, will work for now
|
|
81
|
-
for stmt in mt_verification_clone.code.walk():
|
|
82
|
-
assert (
|
|
83
|
-
stmt.dialect in stim_main_group
|
|
84
|
-
), "Statements detected that are not part of the stim dialect, please verify the original code is valid for rewrite!"
|
|
85
|
-
|
|
86
169
|
return rewrite_result
|
bloqade/stim/rewrite/__init__.py
CHANGED