bloqade-circuit 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic. Click here for more details.
- bloqade/analysis/address/impls.py +21 -68
- bloqade/analysis/measure_id/__init__.py +2 -0
- bloqade/analysis/measure_id/analysis.py +45 -0
- bloqade/analysis/measure_id/impls.py +155 -0
- bloqade/analysis/measure_id/lattice.py +82 -0
- bloqade/qasm2/passes/unroll_if.py +9 -2
- bloqade/rewrite/__init__.py +0 -0
- bloqade/rewrite/passes/__init__.py +1 -0
- bloqade/rewrite/passes/canonicalize_ilist.py +28 -0
- bloqade/rewrite/rules/__init__.py +1 -0
- bloqade/rewrite/rules/flatten_ilist.py +51 -0
- bloqade/rewrite/rules/inline_getitem_ilist.py +31 -0
- bloqade/{qasm2/rewrite → rewrite/rules}/split_ifs.py +15 -8
- bloqade/squin/__init__.py +1 -0
- bloqade/squin/analysis/__init__.py +1 -0
- bloqade/squin/analysis/address_impl.py +71 -0
- bloqade/squin/cirq/lowering.py +2 -1
- bloqade/squin/noise/stmts.py +1 -1
- bloqade/stim/dialects/auxiliary/interp.py +0 -10
- bloqade/stim/dialects/auxiliary/stmts/annotate.py +1 -1
- bloqade/stim/passes/__init__.py +1 -1
- bloqade/stim/passes/simplify_ifs.py +32 -0
- bloqade/stim/passes/squin_to_stim.py +95 -27
- bloqade/stim/rewrite/ifs_to_stim.py +203 -0
- bloqade/stim/rewrite/qubit_to_stim.py +3 -0
- bloqade/stim/rewrite/squin_measure.py +68 -5
- bloqade/stim/rewrite/util.py +0 -4
- bloqade/stim/upstream/__init__.py +1 -0
- bloqade/stim/upstream/from_squin.py +10 -0
- {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/METADATA +1 -1
- {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/RECORD +33 -18
- {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -6,13 +6,10 @@ from kirin import interp
|
|
|
6
6
|
from kirin.analysis import ForwardFrame, const
|
|
7
7
|
from kirin.dialects import cf, py, scf, func, ilist
|
|
8
8
|
|
|
9
|
-
from bloqade import squin
|
|
10
|
-
|
|
11
9
|
from .lattice import (
|
|
12
10
|
Address,
|
|
13
11
|
NotQubit,
|
|
14
12
|
AddressReg,
|
|
15
|
-
AddressWire,
|
|
16
13
|
AddressQubit,
|
|
17
14
|
AddressTuple,
|
|
18
15
|
)
|
|
@@ -73,8 +70,19 @@ class PyList(interp.MethodTable):
|
|
|
73
70
|
class PyIndexing(interp.MethodTable):
|
|
74
71
|
@interp.impl(py.GetItem)
|
|
75
72
|
def getitem(self, interp: AddressAnalysis, frame: interp.Frame, stmt: py.GetItem):
|
|
76
|
-
|
|
77
|
-
|
|
73
|
+
|
|
74
|
+
# determine if the index is an int constant
|
|
75
|
+
# or a slice
|
|
76
|
+
hint = stmt.index.hints.get("const")
|
|
77
|
+
if hint is None:
|
|
78
|
+
return (NotQubit(),)
|
|
79
|
+
|
|
80
|
+
if isinstance(hint, const.Value):
|
|
81
|
+
idx = hint.data
|
|
82
|
+
elif isinstance(hint, slice):
|
|
83
|
+
idx = hint
|
|
84
|
+
else:
|
|
85
|
+
return (NotQubit(),)
|
|
78
86
|
|
|
79
87
|
# The object being indexed into
|
|
80
88
|
obj = frame.get(stmt.obj)
|
|
@@ -82,10 +90,15 @@ class PyIndexing(interp.MethodTable):
|
|
|
82
90
|
# so we just extract that here
|
|
83
91
|
if isinstance(obj, AddressTuple):
|
|
84
92
|
return (obj.data[idx],)
|
|
85
|
-
#
|
|
86
|
-
#
|
|
93
|
+
# If idx is an integer index into an AddressReg,
|
|
94
|
+
# then it's safe to assume a single qubit is being accessed.
|
|
95
|
+
# On the other hand, if it's a slice, we return
|
|
96
|
+
# a new AddressReg to preserve the new sequence.
|
|
87
97
|
elif isinstance(obj, AddressReg):
|
|
88
|
-
|
|
98
|
+
if isinstance(idx, slice):
|
|
99
|
+
return (AddressReg(data=obj.data[idx]),)
|
|
100
|
+
if isinstance(idx, int):
|
|
101
|
+
return (AddressQubit(obj.data[idx]),)
|
|
89
102
|
else:
|
|
90
103
|
return (NotQubit(),)
|
|
91
104
|
|
|
@@ -163,63 +176,3 @@ class Scf(scf.absint.Methods):
|
|
|
163
176
|
return # if terminate is Return, there is no result
|
|
164
177
|
|
|
165
178
|
return loop_vars
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
# Address lattice elements we can work with:
|
|
169
|
-
## NotQubit (bottom), AnyAddress (top)
|
|
170
|
-
|
|
171
|
-
## AddressTuple -> data: tuple[Address, ...]
|
|
172
|
-
### Recursive type, could contain itself or other variants
|
|
173
|
-
### This pops up in cases where you can have an IList/Tuple
|
|
174
|
-
### That contains elements that could be other Address types
|
|
175
|
-
|
|
176
|
-
## AddressReg -> data: Sequence[int]
|
|
177
|
-
### specific to creation of a register of qubits
|
|
178
|
-
|
|
179
|
-
## AddressQubit -> data: int
|
|
180
|
-
### Base qubit address type
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
@squin.wire.dialect.register(key="qubit.address")
|
|
184
|
-
class SquinWireMethodTable(interp.MethodTable):
|
|
185
|
-
|
|
186
|
-
@interp.impl(squin.wire.Unwrap)
|
|
187
|
-
def unwrap(
|
|
188
|
-
self,
|
|
189
|
-
interp_: AddressAnalysis,
|
|
190
|
-
frame: ForwardFrame[Address],
|
|
191
|
-
stmt: squin.wire.Unwrap,
|
|
192
|
-
):
|
|
193
|
-
|
|
194
|
-
origin_qubit = frame.get(stmt.qubit)
|
|
195
|
-
|
|
196
|
-
if isinstance(origin_qubit, AddressQubit):
|
|
197
|
-
return (AddressWire(origin_qubit=origin_qubit),)
|
|
198
|
-
else:
|
|
199
|
-
return (Address.top(),)
|
|
200
|
-
|
|
201
|
-
@interp.impl(squin.wire.Apply)
|
|
202
|
-
def apply(
|
|
203
|
-
self,
|
|
204
|
-
interp_: AddressAnalysis,
|
|
205
|
-
frame: ForwardFrame[Address],
|
|
206
|
-
stmt: squin.wire.Apply,
|
|
207
|
-
):
|
|
208
|
-
return frame.get_values(stmt.inputs)
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
@squin.qubit.dialect.register(key="qubit.address")
|
|
212
|
-
class SquinQubitMethodTable(interp.MethodTable):
|
|
213
|
-
|
|
214
|
-
# This can be treated like a QRegNew impl
|
|
215
|
-
@interp.impl(squin.qubit.New)
|
|
216
|
-
def new(
|
|
217
|
-
self,
|
|
218
|
-
interp_: AddressAnalysis,
|
|
219
|
-
frame: ForwardFrame[Address],
|
|
220
|
-
stmt: squin.qubit.New,
|
|
221
|
-
):
|
|
222
|
-
n_qubits = interp_.get_const_value(int, stmt.n_qubits)
|
|
223
|
-
addr = AddressReg(range(interp_.next_address, interp_.next_address + n_qubits))
|
|
224
|
-
interp_.next_address += n_qubits
|
|
225
|
-
return (addr,)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from typing import TypeVar
|
|
2
|
+
|
|
3
|
+
from kirin import ir, interp
|
|
4
|
+
from kirin.analysis import Forward, const
|
|
5
|
+
from kirin.analysis.forward import ForwardFrame
|
|
6
|
+
|
|
7
|
+
from .lattice import MeasureId, NotMeasureId
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class MeasurementIDAnalysis(Forward[MeasureId]):
    """Forward analysis over the `MeasureId` lattice, keyed as "measure_id"."""

    keys = ["measure_id"]
    lattice = MeasureId
    # Running count of measurements encountered so far. Per the original
    # author's note it is later used to generate the negative values for
    # target rec indices.
    measure_count = 0

    def eval_stmt_fallback(
        self, frame: ForwardFrame[MeasureId], stmt: ir.Statement
    ) -> tuple[MeasureId, ...]:
        """Produce one `NotMeasureId` for every result of `stmt`."""
        return tuple(NotMeasureId() for _result in stmt.results)

    def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]):
        """Run `method.code` with the lattice bottom prepended to `args`.

        NOTE: dynamic calls are not supported here, so the method object
        itself is not propagated; the bottom element takes its slot.
        """
        call_args = (self.lattice.bottom(),) + args
        return self.run_callable(method.code, call_args)

    T = TypeVar("T")

    # Borrowed from the address analysis (credited there to Xiu-zhe (Roger) Luo).
    # TODO: Remove this function once upgrade to kirin 0.18 happens,
    # method is built-in to interpreter then
    def get_const_value(self, input_type: type[T], value: ir.SSAValue) -> T:
        """Return the constant attached to `value` if it has type `input_type`.

        Raises:
            interp.InterpreterError: if `value` has no `const.Value` hint
                under the "const" key, or if the hinted data is not an
                instance of `input_type`.
        """
        hint = value.hints.get("const")
        if not isinstance(hint, const.Value):
            raise interp.InterpreterError(
                f"Expected constant value <type = {input_type}>, got {value}"
            )

        data = hint.data
        if not isinstance(data, input_type):
            raise interp.InterpreterError(
                f"Expected constant value <type = {input_type}>, got {data}"
            )
        return data
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
from kirin import types as kirin_types, interp
|
|
2
|
+
from kirin.dialects import py, scf, func, ilist
|
|
3
|
+
|
|
4
|
+
from bloqade.squin import wire, qubit
|
|
5
|
+
|
|
6
|
+
from .lattice import (
|
|
7
|
+
AnyMeasureId,
|
|
8
|
+
NotMeasureId,
|
|
9
|
+
MeasureIdBool,
|
|
10
|
+
MeasureIdTuple,
|
|
11
|
+
InvalidMeasureId,
|
|
12
|
+
)
|
|
13
|
+
from .analysis import MeasurementIDAnalysis
|
|
14
|
+
|
|
15
|
+
## Can't do wire right now because of
|
|
16
|
+
## unresolved RFC on return type
|
|
17
|
+
# from bloqade.squin import wire
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@qubit.dialect.register(key="measure_id")
class SquinQubit(interp.MethodTable):
    """Measure-id implementations for the squin qubit measurement statements."""

    @interp.impl(qubit.MeasureQubit)
    def measure_qubit(
        self,
        interp: MeasurementIDAnalysis,
        frame: interp.Frame,
        stmt: qubit.MeasureQubit,
    ):
        """Bump the analysis counter and tag the result with the new count."""
        next_id = interp.measure_count + 1
        interp.measure_count = next_id
        return (MeasureIdBool(next_id),)

    @interp.impl(qubit.MeasureQubitList)
    def measure_qubit_list(
        self,
        interp: MeasurementIDAnalysis,
        frame: interp.Frame,
        stmt: qubit.MeasureQubitList,
    ):
        """Assign consecutive measurement ids to every qubit in the list.

        The list length is read from the static type of `stmt.qubits`:
        vars[0] is the element type, vars[1] may be a literal holding the
        length. Without a literal length the result is `AnyMeasureId`.
        """
        length = stmt.qubits.type.vars[1]
        if not isinstance(length, kirin_types.Literal):
            return (AnyMeasureId(),)

        # Ids continue from the current counter: count+1, count+2, ...
        first_id = interp.measure_count + 1
        ids = tuple(MeasureIdBool(first_id + offset) for offset in range(length.data))
        interp.measure_count += len(ids)

        return (MeasureIdTuple(data=ids),)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@wire.dialect.register(key="measure_id")
class SquinWire(interp.MethodTable):
    """Measure-id implementation for the squin wire measurement statement."""

    @interp.impl(wire.Measure)
    def measure_qubit(
        self,
        interp: MeasurementIDAnalysis,
        frame: interp.Frame,
        stmt: wire.Measure,
    ):
        """Bump the analysis counter and tag the result with the new count."""
        next_id = interp.measure_count + 1
        interp.measure_count = next_id
        return (MeasureIdBool(next_id),)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@ilist.dialect.register(key="measure_id")
class IList(interp.MethodTable):
    """Measure-id implementation for `ilist.New`."""

    @interp.impl(ilist.New)
    def new_ilist(
        self,
        interp: MeasurementIDAnalysis,
        frame: interp.Frame,
        stmt: ilist.New,
    ):
        """Wrap the lattice values of the list elements in a `MeasureIdTuple`.

        Because of the way GetItem works, a user could build an ilist of
        bools that mixes `MeasureIdBool` and `NotMeasureId` entries; the
        elements are stored as-is without being checked.
        """
        return (MeasureIdTuple(data=tuple(frame.get_values(stmt.values))),)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@py.tuple.dialect.register(key="measure_id")
class PyTuple(interp.MethodTable):
    """Measure-id implementation for `py.tuple.New`."""

    @interp.impl(py.tuple.New)
    def new_tuple(
        self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.tuple.New
    ):
        """Wrap the lattice values of the tuple arguments in a `MeasureIdTuple`."""
        return (MeasureIdTuple(data=tuple(frame.get_values(stmt.args))),)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@py.indexing.dialect.register(key="measure_id")
class PyIndexing(interp.MethodTable):
    """Measure-id implementation for `py.GetItem`."""

    @interp.impl(py.GetItem)
    def getitem(
        self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.GetItem
    ):
        """Look up the lattice value selected by a constant integer index.

        The index must be a compile-time integer constant;
        `get_const_value` raises otherwise.
        """
        position = interp.get_const_value(int, stmt.index)
        container = frame.get(stmt.obj)

        if isinstance(container, MeasureIdTuple):
            return (container.data[position],)

        # Any / Not carry no per-element information: pass them through unchanged.
        if isinstance(container, (AnyMeasureId, NotMeasureId)):
            return (container,)

        return (InvalidMeasureId(),)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@py.binop.dialect.register(key="measure_id")
class PyBinOp(interp.MethodTable):
    """Measure-id implementation for `py.Add`."""

    @interp.impl(py.Add)
    def add(self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.Add):
        """Concatenate two `MeasureIdTuple` operands; anything else is invalid."""
        left = frame.get(stmt.lhs)
        right = frame.get(stmt.rhs)

        both_tuples = isinstance(left, MeasureIdTuple) and isinstance(
            right, MeasureIdTuple
        )
        if not both_tuples:
            return (InvalidMeasureId(),)

        return (MeasureIdTuple(data=left.data + right.data),)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@func.dialect.register(key="measure_id")
class Func(interp.MethodTable):
    """Measure-id implementations for `func.Return` and `func.Invoke`."""

    @interp.impl(func.Return)
    def return_(self, _: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Return):
        """Return the lattice value of the returned SSA value."""
        returned = frame.get(stmt.value)
        return interp.ReturnValue(returned)

    # Adapted from the address analysis implementation by Xiu-zhe (Roger) Luo.
    # func.Invoke means the callee is already known; func.Call would mean
    # the callee is not known at compile time.
    @interp.impl(func.Invoke)
    def invoke(
        self, interp_: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Invoke
    ):
        """Analyze the callee with the caller's argument values and forward its result."""
        callee = stmt.callee
        call_args = interp_.permute_values(
            callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
        )
        _, ret = interp_.run_method(callee, call_args)
        return (ret,)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# Just let analysis propagate through
|
|
152
|
+
# scf, particularly IfElse
|
|
153
|
+
@scf.dialect.register(key="measure_id")
class Scf(scf.absint.Methods):
    """Reuse the inherited `scf.absint.Methods` implementations unchanged.

    Nothing is overridden: the analysis simply propagates through scf
    statements (IfElse in particular).
    """
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from typing import final
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
from kirin.lattice import (
|
|
5
|
+
SingletonMeta,
|
|
6
|
+
BoundedLattice,
|
|
7
|
+
SimpleJoinMixin,
|
|
8
|
+
SimpleMeetMixin,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
# Taken directly from Kai-Hsin Wu's implementation
|
|
12
|
+
# with minor changes to names and addition of CanMeasureId type
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class MeasureId(
    SimpleJoinMixin["MeasureId"],
    SimpleMeetMixin["MeasureId"],
    BoundedLattice["MeasureId"],
):
    """Base class of the measurement-id lattice elements."""

    @classmethod
    def bottom(cls) -> "MeasureId":
        """Return the bottom element, `InvalidMeasureId`."""
        return InvalidMeasureId()

    @classmethod
    def top(cls) -> "MeasureId":
        """Return the top element, `AnyMeasureId`."""
        return AnyMeasureId()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Can pop up if user constructs some list containing a mixture
|
|
32
|
+
# of bools from measure results and other places,
|
|
33
|
+
# in which case the whole list is invalid
|
|
34
|
+
@final
@dataclass
class InvalidMeasureId(MeasureId, metaclass=SingletonMeta):
    """Bottom element of the lattice (what `MeasureId.bottom()` returns)."""

    def is_subseteq(self, other: MeasureId) -> bool:
        """Always true: this element is contained in every other element."""
        return True
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@final
@dataclass
class AnyMeasureId(MeasureId, metaclass=SingletonMeta):
    """Top element of the lattice (what `MeasureId.top()` returns)."""

    def is_subseteq(self, other: MeasureId) -> bool:
        """True only when `other` is also `AnyMeasureId`."""
        return isinstance(other, AnyMeasureId)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@final
@dataclass
class NotMeasureId(MeasureId, metaclass=SingletonMeta):
    """Element for values that are not measurement results."""

    def is_subseteq(self, other: MeasureId) -> bool:
        """True only when `other` is also `NotMeasureId`."""
        return isinstance(other, NotMeasureId)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@final
@dataclass
class MeasureIdBool(MeasureId):
    """Element for a single measurement result, identified by `idx`."""

    # Identifier of the measurement this value came from.
    idx: int

    def is_subseteq(self, other: MeasureId) -> bool:
        """True only for another `MeasureIdBool` carrying the same `idx`."""
        return isinstance(other, MeasureIdBool) and self.idx == other.idx
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Might be nice to have some print override
|
|
70
|
+
# here so all the CanMeasureId's/other types are consolidated for
|
|
71
|
+
# readability
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@final
@dataclass
class MeasureIdTuple(MeasureId):
    """Element for an ordered collection of `MeasureId` elements."""

    # Lattice values of the contained elements, in container order.
    data: tuple[MeasureId, ...]

    def is_subseteq(self, other: MeasureId) -> bool:
        """Element-wise containment against another tuple of the same length.

        Returns False when `other` is not a `MeasureIdTuple` or when the
        two tuples differ in length.
        """
        if not isinstance(other, MeasureIdTuple):
            return False

        # zip() stops at the shorter input, so without this check a tuple
        # would count as contained in any longer tuple sharing its prefix
        # (and the empty tuple in every tuple).
        if len(self.data) != len(other.data):
            return False

        return all(a.is_subseteq(b) for a, b in zip(self.data, other.data))
|
|
@@ -7,15 +7,22 @@ from kirin.rewrite import (
|
|
|
7
7
|
ConstantFold,
|
|
8
8
|
CommonSubexpressionElimination,
|
|
9
9
|
)
|
|
10
|
+
from kirin.dialects import scf, func
|
|
10
11
|
|
|
11
|
-
from
|
|
12
|
+
from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts
|
|
13
|
+
|
|
14
|
+
from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
|
|
15
|
+
from ..dialects.core.stmts import Reset, Measure
|
|
16
|
+
|
|
17
|
+
AllowedThenType = (SingleQubitGate, TwoQubitCtrlGate, Measure, Reset)
|
|
18
|
+
DontLiftType = AllowedThenType + (scf.Yield, func.Return, func.Invoke)
|
|
12
19
|
|
|
13
20
|
|
|
14
21
|
class UnrollIfs(Pass):
|
|
15
22
|
"""This pass lifts statements that are not UOP out of the if body and then splits whatever is left into multiple if statements so you obtain valid QASM2"""
|
|
16
23
|
|
|
17
24
|
def unsafe_run(self, mt: ir.Method):
|
|
18
|
-
result = Walk(LiftThenBody()).rewrite(mt.code)
|
|
25
|
+
result = Walk(LiftThenBody(exclude_stmts=DontLiftType)).rewrite(mt.code)
|
|
19
26
|
result = Walk(SplitIfStmts()).rewrite(mt.code).join(result)
|
|
20
27
|
result = (
|
|
21
28
|
Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .canonicalize_ilist import CanonicalizeIList as CanonicalizeIList
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.passes import Pass
|
|
5
|
+
from kirin.rewrite import (
|
|
6
|
+
Walk,
|
|
7
|
+
Chain,
|
|
8
|
+
Fixpoint,
|
|
9
|
+
)
|
|
10
|
+
from kirin.analysis import const
|
|
11
|
+
|
|
12
|
+
from ..rules.flatten_ilist import FlattenAddOpIList
|
|
13
|
+
from ..rules.inline_getitem_ilist import InlineGetItemFromIList
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class CanonicalizeIList(Pass):
    """Pass that chains the ilist getitem-inlining and add-flattening rules."""

    def unsafe_run(self, mt: ir.Method):
        """Run constant propagation on `mt`, then rewrite `mt.code`.

        The constant-propagation entries are handed to
        `InlineGetItemFromIList`; both rules are walked over the code and
        the chain is wrapped in `Fixpoint`.
        """
        propagation = const.Propagate(dialects=mt.dialects)
        cp_result_frame, _ = propagation.run_analysis(mt)

        inline_getitem = Walk(
            InlineGetItemFromIList(constprop_result=cp_result_frame.entries)
        )
        flatten_add = Walk(FlattenAddOpIList())

        return Fixpoint(Chain(inline_getitem, flatten_add)).rewrite(mt.code)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .split_ifs import LiftThenBody as LiftThenBody, SplitIfStmts as SplitIfStmts
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.dialects import py, ilist
|
|
5
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class FlattenAddOpIList(RewriteRule):
    """Rewrite `ilist.New(...) + ilist.New(...)` into a single `ilist.New`.

    An operand may also be a `py.Constant` holding an empty `ilist.IList`,
    in which case it contributes no elements.
    """

    @staticmethod
    def _operand_values(operand: ir.SSAValue):
        """Return the element SSA values an operand contributes, or None.

        Returns the `values` of an owning `ilist.New`, an empty tuple for a
        constant empty `ilist.IList`, and None for anything else (meaning
        the addition cannot be flattened).
        """
        owner = operand.owner
        if isinstance(owner, ilist.New):
            return owner.values

        if (
            isinstance(owner, py.Constant)
            and isinstance(const_ilist := owner.value.unwrap(), ilist.IList)
            and len(const_ilist.data) == 0
        ):
            return ()

        return None

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Replace a flattenable `py.binop.Add` with one `ilist.New`.

        Returns an empty `RewriteResult` when `node` is not an Add or when
        either operand is not an `ilist.New` / empty constant ilist.
        """
        if not isinstance(node, py.binop.Add):
            return RewriteResult()

        # The lhs is checked first and bails out before the rhs is inspected,
        # matching the original evaluation order.
        lhs_values = self._operand_values(node.lhs)
        if lhs_values is None:
            return RewriteResult()

        rhs_values = self._operand_values(node.rhs)
        if rhs_values is None:
            return RewriteResult()

        node.replace_by(ilist.New(values=lhs_values + rhs_values))

        return RewriteResult(has_done_something=True)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.analysis import const
|
|
5
|
+
from kirin.dialects import py, ilist
|
|
6
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class InlineGetItemFromIList(RewriteRule):
    """Replace `ilist.New(...)[i]` with the i-th operand when `i` is a known constant."""

    # Constant-propagation results keyed by SSA value; consulted for the index.
    constprop_result: dict[ir.SSAValue, const.Result]

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Redirect uses of a constant-index GetItem to the indexed element.

        Returns an empty `RewriteResult` (no change) unless `node` is a
        `py.indexing.GetItem` on an `ilist.New` whose index is a constant
        integer within the bounds of the list.
        """
        if not isinstance(node, py.indexing.GetItem):
            return RewriteResult()

        container = node.obj.owner
        if not isinstance(container, ilist.New):
            return RewriteResult()

        index_value = self.constprop_result.get(node.index)
        if not isinstance(index_value, const.Value):
            return RewriteResult()

        index = index_value.data
        elements = container.values

        # A constant slice would yield a tuple rather than one SSA value, and
        # an out-of-range index would raise IndexError in the middle of the
        # rewrite; leave both cases untouched. Negative indices are accepted
        # with the usual Python meaning.
        if not isinstance(index, int):
            return RewriteResult()
        if not -len(elements) <= index < len(elements):
            return RewriteResult()

        node.result.replace_by(elements[index])

        return RewriteResult(has_done_something=True)
|
|
@@ -1,18 +1,23 @@
|
|
|
1
|
+
from dataclasses import field, dataclass
|
|
2
|
+
|
|
1
3
|
from kirin import ir
|
|
2
4
|
from kirin.dialects import scf, func
|
|
3
5
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
4
6
|
|
|
5
|
-
from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
|
|
6
|
-
from ..dialects.core.stmts import Reset, Measure
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
|
|
8
|
+
@dataclass
|
|
9
|
+
class LiftThenBody(RewriteRule):
|
|
10
|
+
"""
|
|
11
|
+
Lifts anything that's not in the `exclude_stmts` in the *then* body
|
|
12
|
+
|
|
10
13
|
|
|
11
|
-
|
|
14
|
+
Args:
|
|
15
|
+
exclude_stmts: A tuple of statement types that should not be lifted from the then body.
|
|
16
|
+
Defaults to an empty tuple, meaning all statements are lifted.
|
|
12
17
|
|
|
18
|
+
"""
|
|
13
19
|
|
|
14
|
-
|
|
15
|
-
"""Lifts anything that's not a UOP or a yield/return out of the then body"""
|
|
20
|
+
exclude_stmts: tuple[type[ir.Statement], ...] = field(default_factory=tuple)
|
|
16
21
|
|
|
17
22
|
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
18
23
|
if not isinstance(node, scf.IfElse):
|
|
@@ -20,7 +25,9 @@ class LiftThenBody(RewriteRule):
|
|
|
20
25
|
|
|
21
26
|
then_stmts = node.then_body.stmts()
|
|
22
27
|
|
|
23
|
-
lift_stmts = [
|
|
28
|
+
lift_stmts = [
|
|
29
|
+
stmt for stmt in then_stmts if not isinstance(stmt, self.exclude_stmts)
|
|
30
|
+
]
|
|
24
31
|
|
|
25
32
|
if len(lift_stmts) == 0:
|
|
26
33
|
return RewriteResult()
|
bloqade/squin/__init__.py
CHANGED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from . import address_impl as address_impl
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
from kirin import interp
|
|
2
|
+
from kirin.analysis import ForwardFrame
|
|
3
|
+
|
|
4
|
+
from bloqade.analysis.address.lattice import (
|
|
5
|
+
Address,
|
|
6
|
+
AddressReg,
|
|
7
|
+
AddressWire,
|
|
8
|
+
AddressQubit,
|
|
9
|
+
)
|
|
10
|
+
from bloqade.analysis.address.analysis import AddressAnalysis
|
|
11
|
+
|
|
12
|
+
from .. import wire, qubit
|
|
13
|
+
|
|
14
|
+
# Address lattice elements we can work with:
|
|
15
|
+
## NotQubit (bottom), AnyAddress (top)
|
|
16
|
+
|
|
17
|
+
## AddressTuple -> data: tuple[Address, ...]
|
|
18
|
+
### Recursive type, could contain itself or other variants
|
|
19
|
+
### This pops up in cases where you can have an IList/Tuple
|
|
20
|
+
### That contains elements that could be other Address types
|
|
21
|
+
|
|
22
|
+
## AddressReg -> data: Sequence[int]
|
|
23
|
+
### specific to creation of a register of qubits
|
|
24
|
+
|
|
25
|
+
## AddressQubit -> data: int
|
|
26
|
+
### Base qubit address type
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@wire.dialect.register(key="qubit.address")
class SquinWireMethodTable(interp.MethodTable):
    """Address-analysis implementations for the squin wire statements."""

    @interp.impl(wire.Unwrap)
    def unwrap(
        self,
        interp_: AddressAnalysis,
        frame: ForwardFrame[Address],
        stmt: wire.Unwrap,
    ):
        """Track the qubit a wire was unwrapped from.

        Yields an `AddressWire` pointing at the originating `AddressQubit`;
        when the operand is not an `AddressQubit` the result is the lattice
        top.
        """
        source = frame.get(stmt.qubit)

        if not isinstance(source, AddressQubit):
            return (Address.top(),)

        return (AddressWire(origin_qubit=source),)

    @interp.impl(wire.Apply)
    def apply(
        self,
        interp_: AddressAnalysis,
        frame: ForwardFrame[Address],
        stmt: wire.Apply,
    ):
        """Forward the lattice values of the statement inputs unchanged."""
        input_addresses = frame.get_values(stmt.inputs)
        return input_addresses
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@qubit.dialect.register(key="qubit.address")
class SquinQubitMethodTable(interp.MethodTable):
    """Address-analysis implementation for squin qubit allocation."""

    # This can be treated like a QRegNew impl
    @interp.impl(qubit.New)
    def new(
        self,
        interp_: AddressAnalysis,
        frame: ForwardFrame[Address],
        stmt: qubit.New,
    ):
        """Reserve a contiguous block of addresses for the new qubits.

        The qubit count must be a compile-time integer constant. The block
        starts at `interp_.next_address`, which is then advanced by the
        count.
        """
        n_qubits = interp_.get_const_value(int, stmt.n_qubits)

        start = interp_.next_address
        stop = start + n_qubits
        register = AddressReg(range(start, stop))

        # Advance the allocator only after the register has been built.
        interp_.next_address = stop
        return (register,)
|
bloqade/squin/cirq/lowering.py
CHANGED
|
@@ -281,7 +281,8 @@ class Squin(lowering.LoweringABC[CirqNode]):
|
|
|
281
281
|
return state.current_frame.push(op.stmts.Z())
|
|
282
282
|
|
|
283
283
|
# NOTE: just for the Z gate, an arbitrary exponent is equivalent to the ShiftOp
|
|
284
|
-
|
|
284
|
+
# up to a minus sign!
|
|
285
|
+
t = -node.exponent
|
|
285
286
|
theta = state.current_frame.push(py.Constant(math.pi * t))
|
|
286
287
|
return state.current_frame.push(op.stmts.ShiftOp(theta=theta.result))
|
|
287
288
|
|