bloqade-circuit 0.6.2__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bloqade/analysis/address/__init__.py +8 -4
- bloqade/analysis/address/analysis.py +123 -33
- bloqade/analysis/address/impls.py +293 -90
- bloqade/analysis/address/lattice.py +209 -24
- bloqade/analysis/fidelity/analysis.py +11 -23
- bloqade/analysis/measure_id/__init__.py +4 -1
- bloqade/analysis/measure_id/analysis.py +29 -20
- bloqade/analysis/measure_id/impls.py +72 -31
- bloqade/annotate/__init__.py +6 -0
- bloqade/annotate/_dialect.py +3 -0
- bloqade/annotate/_interface.py +22 -0
- bloqade/annotate/stmts.py +29 -0
- bloqade/annotate/types.py +13 -0
- bloqade/cirq_utils/__init__.py +4 -2
- bloqade/cirq_utils/emit/__init__.py +3 -0
- bloqade/cirq_utils/emit/base.py +246 -0
- bloqade/cirq_utils/emit/gate.py +104 -0
- bloqade/cirq_utils/emit/noise.py +90 -0
- bloqade/cirq_utils/emit/qubit.py +35 -0
- bloqade/cirq_utils/lowering.py +660 -0
- bloqade/cirq_utils/noise/__init__.py +0 -2
- bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
- bloqade/cirq_utils/noise/model.py +151 -191
- bloqade/cirq_utils/noise/transform.py +2 -2
- bloqade/cirq_utils/parallelize.py +9 -6
- bloqade/gemini/__init__.py +1 -0
- bloqade/gemini/analysis/__init__.py +3 -0
- bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
- bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
- bloqade/gemini/analysis/logical_validation/impls.py +101 -0
- bloqade/gemini/groups.py +67 -0
- bloqade/native/__init__.py +23 -0
- bloqade/native/_prelude.py +45 -0
- bloqade/native/dialects/__init__.py +0 -0
- bloqade/native/dialects/gate/__init__.py +2 -0
- bloqade/native/dialects/gate/_dialect.py +3 -0
- bloqade/native/dialects/gate/_interface.py +32 -0
- bloqade/native/dialects/gate/stmts.py +31 -0
- bloqade/native/stdlib/__init__.py +0 -0
- bloqade/native/stdlib/broadcast.py +246 -0
- bloqade/native/stdlib/simple.py +220 -0
- bloqade/native/upstream/__init__.py +4 -0
- bloqade/native/upstream/squin2native.py +79 -0
- bloqade/pyqrack/__init__.py +2 -2
- bloqade/pyqrack/base.py +7 -1
- bloqade/pyqrack/device.py +190 -4
- bloqade/pyqrack/native.py +49 -0
- bloqade/pyqrack/reg.py +6 -6
- bloqade/pyqrack/squin/gate/__init__.py +1 -0
- bloqade/pyqrack/squin/gate/gate.py +136 -0
- bloqade/pyqrack/squin/noise/native.py +120 -54
- bloqade/pyqrack/squin/qubit.py +39 -36
- bloqade/pyqrack/target.py +5 -4
- bloqade/pyqrack/task.py +114 -7
- bloqade/qasm2/_qasm_loading.py +3 -3
- bloqade/qasm2/dialects/core/address.py +21 -12
- bloqade/qasm2/dialects/expr/_emit.py +19 -8
- bloqade/qasm2/dialects/expr/stmts.py +7 -7
- bloqade/qasm2/dialects/noise/fidelity.py +4 -8
- bloqade/qasm2/dialects/noise/model.py +2 -1
- bloqade/qasm2/emit/base.py +16 -11
- bloqade/qasm2/emit/gate.py +11 -8
- bloqade/qasm2/emit/main.py +103 -3
- bloqade/qasm2/emit/target.py +9 -5
- bloqade/qasm2/groups.py +3 -2
- bloqade/qasm2/parse/lowering.py +0 -1
- bloqade/qasm2/passes/fold.py +14 -73
- bloqade/qasm2/passes/glob.py +2 -2
- bloqade/qasm2/passes/noise.py +1 -1
- bloqade/qasm2/passes/parallel.py +7 -5
- bloqade/qasm2/rewrite/__init__.py +0 -1
- bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
- bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
- bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
- bloqade/qasm2/rewrite/register.py +2 -2
- bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
- bloqade/qbraid/lowering.py +1 -0
- bloqade/qbraid/schema.py +2 -2
- bloqade/qubit/__init__.py +12 -0
- bloqade/qubit/_dialect.py +3 -0
- bloqade/qubit/_interface.py +49 -0
- bloqade/qubit/_prelude.py +45 -0
- bloqade/qubit/analysis/__init__.py +1 -0
- bloqade/qubit/analysis/address_impl.py +40 -0
- bloqade/qubit/stdlib/__init__.py +2 -0
- bloqade/qubit/stdlib/_new.py +34 -0
- bloqade/qubit/stdlib/broadcast.py +62 -0
- bloqade/qubit/stdlib/simple.py +59 -0
- bloqade/qubit/stmts.py +60 -0
- bloqade/rewrite/passes/__init__.py +6 -0
- bloqade/rewrite/passes/aggressive_unroll.py +103 -0
- bloqade/rewrite/passes/callgraph.py +116 -0
- bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
- bloqade/rewrite/rules/split_ifs.py +18 -1
- bloqade/squin/__init__.py +47 -14
- bloqade/squin/analysis/__init__.py +0 -1
- bloqade/squin/analysis/schedule.py +10 -11
- bloqade/squin/gate/__init__.py +2 -0
- bloqade/squin/gate/_dialect.py +3 -0
- bloqade/squin/gate/_interface.py +98 -0
- bloqade/squin/gate/stmts.py +125 -0
- bloqade/squin/groups.py +5 -22
- bloqade/squin/noise/__init__.py +1 -10
- bloqade/squin/noise/_dialect.py +1 -1
- bloqade/squin/noise/_interface.py +45 -0
- bloqade/squin/noise/stmts.py +66 -28
- bloqade/squin/rewrite/U3_to_clifford.py +70 -51
- bloqade/squin/rewrite/__init__.py +0 -2
- bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
- bloqade/squin/rewrite/wrap_analysis.py +4 -35
- bloqade/squin/stdlib/__init__.py +0 -0
- bloqade/squin/stdlib/broadcast/__init__.py +34 -0
- bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
- bloqade/squin/stdlib/broadcast/gate.py +260 -0
- bloqade/squin/stdlib/broadcast/noise.py +144 -0
- bloqade/squin/stdlib/simple/__init__.py +33 -0
- bloqade/squin/stdlib/simple/gate.py +242 -0
- bloqade/squin/stdlib/simple/noise.py +126 -0
- bloqade/stim/__init__.py +1 -0
- bloqade/stim/_wrappers.py +6 -0
- bloqade/stim/dialects/auxiliary/emit.py +19 -18
- bloqade/stim/dialects/collapse/emit_str.py +7 -8
- bloqade/stim/dialects/gate/emit.py +9 -10
- bloqade/stim/dialects/noise/emit.py +17 -13
- bloqade/stim/dialects/noise/stmts.py +5 -3
- bloqade/stim/emit/__init__.py +1 -0
- bloqade/stim/emit/impls.py +16 -0
- bloqade/stim/emit/stim_str.py +48 -31
- bloqade/stim/groups.py +12 -2
- bloqade/stim/parse/lowering.py +14 -17
- bloqade/stim/passes/__init__.py +3 -1
- bloqade/stim/passes/flatten.py +26 -0
- bloqade/stim/passes/simplify_ifs.py +16 -2
- bloqade/stim/passes/squin_to_stim.py +18 -60
- bloqade/stim/rewrite/__init__.py +3 -4
- bloqade/stim/rewrite/get_record_util.py +24 -0
- bloqade/stim/rewrite/ifs_to_stim.py +29 -31
- bloqade/stim/rewrite/qubit_to_stim.py +90 -41
- bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
- bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
- bloqade/stim/rewrite/squin_measure.py +11 -79
- bloqade/stim/rewrite/squin_noise.py +134 -108
- bloqade/stim/rewrite/util.py +5 -192
- bloqade/test_utils.py +1 -1
- bloqade/types.py +10 -0
- bloqade/validation/__init__.py +2 -0
- bloqade/validation/analysis/__init__.py +5 -0
- bloqade/validation/analysis/analysis.py +41 -0
- bloqade/validation/analysis/lattice.py +58 -0
- bloqade/validation/kernel_validation.py +77 -0
- {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
- bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
- bloqade/pyqrack/squin/op.py +0 -166
- bloqade/pyqrack/squin/runtime.py +0 -535
- bloqade/pyqrack/squin/wire.py +0 -51
- bloqade/rewrite/rules/flatten_ilist.py +0 -51
- bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
- bloqade/squin/_typeinfer.py +0 -20
- bloqade/squin/analysis/address_impl.py +0 -71
- bloqade/squin/analysis/nsites/__init__.py +0 -9
- bloqade/squin/analysis/nsites/analysis.py +0 -50
- bloqade/squin/analysis/nsites/impls.py +0 -92
- bloqade/squin/analysis/nsites/lattice.py +0 -49
- bloqade/squin/cirq/__init__.py +0 -265
- bloqade/squin/cirq/emit/emit_circuit.py +0 -109
- bloqade/squin/cirq/emit/noise.py +0 -49
- bloqade/squin/cirq/emit/op.py +0 -125
- bloqade/squin/cirq/emit/qubit.py +0 -60
- bloqade/squin/cirq/emit/runtime.py +0 -242
- bloqade/squin/cirq/lowering.py +0 -440
- bloqade/squin/lowering.py +0 -54
- bloqade/squin/noise/_wrapper.py +0 -40
- bloqade/squin/noise/rewrite.py +0 -111
- bloqade/squin/op/__init__.py +0 -41
- bloqade/squin/op/_dialect.py +0 -3
- bloqade/squin/op/_wrapper.py +0 -121
- bloqade/squin/op/number.py +0 -5
- bloqade/squin/op/rewrite.py +0 -46
- bloqade/squin/op/stdlib.py +0 -62
- bloqade/squin/op/stmts.py +0 -276
- bloqade/squin/op/traits.py +0 -43
- bloqade/squin/op/types.py +0 -26
- bloqade/squin/qubit.py +0 -184
- bloqade/squin/rewrite/canonicalize.py +0 -60
- bloqade/squin/rewrite/desugar.py +0 -124
- bloqade/squin/types.py +0 -8
- bloqade/squin/wire.py +0 -201
- bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
- bloqade/stim/rewrite/wire_to_stim.py +0 -57
- bloqade_circuit-0.6.2.dist-info/RECORD +0 -234
- {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,12 +1,18 @@
|
|
|
1
1
|
from typing import Sequence, final
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
|
|
4
|
+
from kirin import ir, types
|
|
4
5
|
from kirin.lattice import (
|
|
5
6
|
SingletonMeta,
|
|
6
7
|
BoundedLattice,
|
|
7
8
|
SimpleJoinMixin,
|
|
8
9
|
SimpleMeetMixin,
|
|
9
10
|
)
|
|
11
|
+
from kirin.analysis import const
|
|
12
|
+
from kirin.dialects import ilist
|
|
13
|
+
from kirin.ir.attrs.abc import LatticeAttributeMeta
|
|
14
|
+
|
|
15
|
+
from bloqade.types import QubitType
|
|
10
16
|
|
|
11
17
|
|
|
12
18
|
@dataclass
|
|
@@ -18,54 +24,94 @@ class Address(
|
|
|
18
24
|
|
|
19
25
|
@classmethod
|
|
20
26
|
def bottom(cls) -> "Address":
|
|
21
|
-
return
|
|
27
|
+
return Bottom()
|
|
22
28
|
|
|
23
29
|
@classmethod
|
|
24
30
|
def top(cls) -> "Address":
|
|
25
|
-
return
|
|
31
|
+
return Unknown()
|
|
32
|
+
|
|
33
|
+
@classmethod
|
|
34
|
+
def from_type(cls, typ: types.TypeAttribute):
|
|
35
|
+
if typ.is_subseteq(ilist.IListType[QubitType]):
|
|
36
|
+
return UnknownReg()
|
|
37
|
+
elif typ.is_subseteq(QubitType):
|
|
38
|
+
return UnknownQubit()
|
|
39
|
+
else:
|
|
40
|
+
return Unknown()
|
|
26
41
|
|
|
27
42
|
|
|
28
43
|
@final
|
|
29
|
-
|
|
30
|
-
|
|
44
|
+
class Bottom(Address, metaclass=SingletonMeta):
|
|
45
|
+
"""Error during interpretation"""
|
|
31
46
|
|
|
32
47
|
def is_subseteq(self, other: Address) -> bool:
|
|
33
48
|
return True
|
|
34
49
|
|
|
35
50
|
|
|
36
51
|
@final
|
|
37
|
-
|
|
38
|
-
|
|
52
|
+
class Unknown(Address, metaclass=SingletonMeta):
|
|
53
|
+
"""Can't determine if it is an address or constant."""
|
|
39
54
|
|
|
40
55
|
def is_subseteq(self, other: Address) -> bool:
|
|
41
|
-
return isinstance(other,
|
|
56
|
+
return isinstance(other, Unknown)
|
|
42
57
|
|
|
43
58
|
|
|
44
59
|
@final
|
|
45
60
|
@dataclass
|
|
46
|
-
class
|
|
47
|
-
|
|
61
|
+
class ConstResult(Address):
|
|
62
|
+
"""Stores a constant prop result in the lattice"""
|
|
63
|
+
|
|
64
|
+
result: const.Result
|
|
48
65
|
|
|
49
66
|
def is_subseteq(self, other: Address) -> bool:
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
67
|
+
return isinstance(other, ConstResult) and self.result.is_subseteq(other.result)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class QubitLike(Address):
|
|
71
|
+
def join(self, other: Address):
|
|
72
|
+
if isinstance(other, QubitLike):
|
|
73
|
+
return super().join(other)
|
|
74
|
+
return self.bottom()
|
|
75
|
+
|
|
76
|
+
def meet(self, other: Address):
|
|
77
|
+
if isinstance(other, QubitLike):
|
|
78
|
+
return super().meet(other)
|
|
79
|
+
return self.bottom()
|
|
53
80
|
|
|
54
81
|
|
|
55
82
|
@final
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
data: Sequence[int]
|
|
83
|
+
class UnknownQubit(QubitLike, metaclass=SingletonMeta):
|
|
84
|
+
"""A lattice element representing a single qubit with an unknown address."""
|
|
59
85
|
|
|
60
86
|
def is_subseteq(self, other: Address) -> bool:
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
87
|
+
return isinstance(other, QubitLike)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class RegisterLike(Address):
|
|
91
|
+
def join(self, other: Address):
|
|
92
|
+
if isinstance(other, RegisterLike):
|
|
93
|
+
return super().join(other)
|
|
94
|
+
return self.bottom()
|
|
95
|
+
|
|
96
|
+
def meet(self, other: Address):
|
|
97
|
+
if isinstance(other, RegisterLike):
|
|
98
|
+
return super().meet(other)
|
|
99
|
+
return self.bottom()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@final
|
|
103
|
+
class UnknownReg(RegisterLike, metaclass=SingletonMeta):
|
|
104
|
+
"""A lattice element representing a container of qubits with unknown indices."""
|
|
105
|
+
|
|
106
|
+
def is_subseteq(self, other: Address) -> bool:
|
|
107
|
+
return isinstance(other, RegisterLike)
|
|
64
108
|
|
|
65
109
|
|
|
66
110
|
@final
|
|
67
111
|
@dataclass
|
|
68
|
-
class AddressQubit(
|
|
112
|
+
class AddressQubit(QubitLike):
|
|
113
|
+
"""A lattice element representing a single qubit with a known address."""
|
|
114
|
+
|
|
69
115
|
data: int
|
|
70
116
|
|
|
71
117
|
def is_subseteq(self, other: Address) -> bool:
|
|
@@ -76,10 +122,149 @@ class AddressQubit(Address):
|
|
|
76
122
|
|
|
77
123
|
@final
|
|
78
124
|
@dataclass
|
|
79
|
-
class
|
|
80
|
-
|
|
125
|
+
class AddressReg(RegisterLike):
|
|
126
|
+
"""A lattice element representing a container of qubits with known indices."""
|
|
127
|
+
|
|
128
|
+
data: Sequence[int]
|
|
81
129
|
|
|
82
130
|
def is_subseteq(self, other: Address) -> bool:
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
131
|
+
return isinstance(other, AddressReg) and self.data == other.data
|
|
132
|
+
|
|
133
|
+
@property
|
|
134
|
+
def qubits(self) -> tuple[AddressQubit, ...]:
|
|
135
|
+
return tuple(AddressQubit(i) for i in self.data)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@final
|
|
139
|
+
@dataclass
|
|
140
|
+
class PartialLambda(Address):
|
|
141
|
+
"""Represents a partially known lambda function"""
|
|
142
|
+
|
|
143
|
+
argnames: list[str]
|
|
144
|
+
code: ir.Statement
|
|
145
|
+
captured: tuple[Address, ...]
|
|
146
|
+
|
|
147
|
+
def join(self, other: Address) -> Address:
|
|
148
|
+
if other is other.bottom():
|
|
149
|
+
return self
|
|
150
|
+
|
|
151
|
+
if not isinstance(other, PartialLambda):
|
|
152
|
+
return self.top().join(other) # widen self
|
|
153
|
+
|
|
154
|
+
if self.code is not other.code:
|
|
155
|
+
return self.top() # lambda stmt is pure
|
|
156
|
+
|
|
157
|
+
if len(self.captured) != len(other.captured):
|
|
158
|
+
return self.bottom() # err
|
|
159
|
+
|
|
160
|
+
return PartialLambda(
|
|
161
|
+
self.argnames,
|
|
162
|
+
self.code,
|
|
163
|
+
tuple(x.join(y) for x, y in zip(self.captured, other.captured)),
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
def meet(self, other: Address) -> Address:
|
|
167
|
+
if not isinstance(other, PartialLambda):
|
|
168
|
+
return self.top().meet(other)
|
|
169
|
+
|
|
170
|
+
if self.code is not other.code:
|
|
171
|
+
return self.bottom()
|
|
172
|
+
|
|
173
|
+
if len(self.captured) != len(other.captured):
|
|
174
|
+
return self.top()
|
|
175
|
+
|
|
176
|
+
return PartialLambda(
|
|
177
|
+
self.argnames,
|
|
178
|
+
self.code,
|
|
179
|
+
tuple(x.meet(y) for x, y in zip(self.captured, other.captured)),
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
def is_subseteq(self, other: Address) -> bool:
|
|
183
|
+
return (
|
|
184
|
+
isinstance(other, PartialLambda)
|
|
185
|
+
and self.code is other.code
|
|
186
|
+
and self.argnames == other.argnames
|
|
187
|
+
and len(self.captured) == len(other.captured)
|
|
188
|
+
and all(
|
|
189
|
+
self_ele.is_subseteq(other_ele)
|
|
190
|
+
for self_ele, other_ele in zip(self.captured, other.captured)
|
|
191
|
+
)
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
@dataclass
|
|
196
|
+
class StaticContainer(Address):
|
|
197
|
+
"""A lattice element representing the results of any static container, e. g. ilist or tuple."""
|
|
198
|
+
|
|
199
|
+
data: tuple[Address, ...]
|
|
200
|
+
|
|
201
|
+
@classmethod
|
|
202
|
+
def new(cls, data: tuple[Address, ...]):
|
|
203
|
+
return cls(data)
|
|
204
|
+
|
|
205
|
+
def join(self, other: "Address") -> "Address":
|
|
206
|
+
if isinstance(other, type(self)) and len(self.data) == len(other.data):
|
|
207
|
+
return self.new(tuple(x.join(y) for x, y in zip(self.data, other.data)))
|
|
208
|
+
return self.top()
|
|
209
|
+
|
|
210
|
+
def meet(self, other: "Address") -> "Address":
|
|
211
|
+
if isinstance(other, type(self)) and len(self.data) == len(other.data):
|
|
212
|
+
return self.new(tuple(x.meet(y) for x, y in zip(self.data, other.data)))
|
|
213
|
+
return self.bottom()
|
|
214
|
+
|
|
215
|
+
def is_subseteq(self, other: "Address") -> bool:
|
|
216
|
+
return (
|
|
217
|
+
isinstance(other, type(self))
|
|
218
|
+
and len(self.data) == len(other.data)
|
|
219
|
+
and all(x.is_subseteq(y) for x, y in zip(self.data, other.data))
|
|
220
|
+
)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class PartialIListMeta(LatticeAttributeMeta):
|
|
224
|
+
"""This metaclass assures that PartialILists of ConstResults or AddressQubits are canonicalized
|
|
225
|
+
to a single ConstResult or AddressReg respectively.
|
|
226
|
+
|
|
227
|
+
because AddressReg is a specialization of PartialIList, being a container of pure qubit
|
|
228
|
+
addresses. For Operations that act in generic containers (e.g., ilist.ForEach),
|
|
229
|
+
AddressReg is treated as PartialIList but for other types of analysis it is often
|
|
230
|
+
useful to distinguish between a generic IList and a pure qubit address list.
|
|
231
|
+
|
|
232
|
+
Inside the method tables the `GetValuesMixin` implements a method that effectively
|
|
233
|
+
undoes this canonicalization.
|
|
234
|
+
|
|
235
|
+
"""
|
|
236
|
+
|
|
237
|
+
def __call__(cls, data: tuple[Address, ...]):
|
|
238
|
+
# TODO: when constant prop has PartialIList, make sure to canonicalize here.
|
|
239
|
+
if types.is_tuple_of(data, ConstResult) and types.is_tuple_of(
|
|
240
|
+
all_constants := tuple(ele.result for ele in data), const.Value
|
|
241
|
+
):
|
|
242
|
+
# all constants, create constant list
|
|
243
|
+
return ConstResult(
|
|
244
|
+
const.Value(ilist.IList([ele.data for ele in all_constants]))
|
|
245
|
+
)
|
|
246
|
+
elif types.is_tuple_of(data, AddressQubit):
|
|
247
|
+
# all qubits create qubit register
|
|
248
|
+
return AddressReg(tuple(ele.data for ele in data))
|
|
249
|
+
else:
|
|
250
|
+
return super().__call__(data)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
@final
|
|
254
|
+
class PartialIList(StaticContainer, metaclass=PartialIListMeta):
|
|
255
|
+
"""A lattice element representing a partially known ilist."""
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
class PartialTupleMeta(LatticeAttributeMeta):
|
|
259
|
+
"""This metaclass assures that PartialTuples of ConstResults are canonicalized to a single ConstResult."""
|
|
260
|
+
|
|
261
|
+
def __call__(cls, data: tuple[Address, ...]):
|
|
262
|
+
if not types.is_tuple_of(data, ConstResult):
|
|
263
|
+
return super().__call__(data)
|
|
264
|
+
|
|
265
|
+
return ConstResult(const.PartialTuple(tuple(ele.result for ele in data)))
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
@final
|
|
269
|
+
class PartialTuple(StaticContainer, metaclass=PartialTupleMeta):
|
|
270
|
+
"""A lattice element representing a partially known tuple."""
|
|
@@ -4,10 +4,9 @@ from dataclasses import field
|
|
|
4
4
|
from kirin import ir
|
|
5
5
|
from kirin.lattice import EmptyLattice
|
|
6
6
|
from kirin.analysis import Forward
|
|
7
|
-
from kirin.interp.value import Successor
|
|
8
7
|
from kirin.analysis.forward import ForwardFrame
|
|
9
8
|
|
|
10
|
-
from ..address import AddressAnalysis
|
|
9
|
+
from ..address import Address, AddressAnalysis
|
|
11
10
|
|
|
12
11
|
|
|
13
12
|
class FidelityAnalysis(Forward):
|
|
@@ -48,16 +47,12 @@ class FidelityAnalysis(Forward):
|
|
|
48
47
|
The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered.
|
|
49
48
|
"""
|
|
50
49
|
|
|
51
|
-
_current_gate_fidelity: float = field(init=False)
|
|
52
|
-
|
|
53
50
|
atom_survival_probability: list[float] = field(init=False)
|
|
54
51
|
"""
|
|
55
52
|
The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register.
|
|
56
53
|
"""
|
|
57
54
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
addr_frame: ForwardFrame = field(init=False)
|
|
55
|
+
addr_frame: ForwardFrame[Address] = field(init=False)
|
|
61
56
|
|
|
62
57
|
def initialize(self):
|
|
63
58
|
super().initialize()
|
|
@@ -67,28 +62,21 @@ class FidelityAnalysis(Forward):
|
|
|
67
62
|
]
|
|
68
63
|
return self
|
|
69
64
|
|
|
70
|
-
def
|
|
71
|
-
self.gate_fidelity *= self._current_gate_fidelity
|
|
72
|
-
for i, _current_survival in enumerate(self._current_atom_survival_probability):
|
|
73
|
-
self.atom_survival_probability[i] *= _current_survival
|
|
74
|
-
|
|
75
|
-
def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
|
|
65
|
+
def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
|
|
76
66
|
# NOTE: default is to conserve fidelity, so do nothing here
|
|
77
67
|
return
|
|
78
68
|
|
|
79
|
-
def
|
|
80
|
-
|
|
69
|
+
def run(self, method: ir.Method, *args, **kwargs) -> tuple[ForwardFrame, Any]:
|
|
70
|
+
self._run_address_analysis(method)
|
|
71
|
+
return super().run(method, *args, **kwargs)
|
|
81
72
|
|
|
82
|
-
def
|
|
83
|
-
self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True
|
|
84
|
-
) -> tuple[ForwardFrame, Any]:
|
|
85
|
-
self._run_address_analysis(method, no_raise=no_raise)
|
|
86
|
-
return super().run_analysis(method, args, no_raise=no_raise)
|
|
87
|
-
|
|
88
|
-
def _run_address_analysis(self, method: ir.Method, no_raise: bool):
|
|
73
|
+
def _run_address_analysis(self, method: ir.Method):
|
|
89
74
|
addr_analysis = AddressAnalysis(self.dialects)
|
|
90
|
-
addr_frame, _ = addr_analysis.
|
|
75
|
+
addr_frame, _ = addr_analysis.run(method=method)
|
|
91
76
|
self.addr_frame = addr_frame
|
|
92
77
|
|
|
93
78
|
# NOTE: make sure we have as many probabilities as we have addresses
|
|
94
79
|
self.atom_survival_probability = [1.0] * addr_analysis.qubit_count
|
|
80
|
+
|
|
81
|
+
def method_self(self, method: ir.Method) -> EmptyLattice:
|
|
82
|
+
return self.lattice.bottom()
|
|
@@ -1,13 +1,19 @@
|
|
|
1
1
|
from typing import TypeVar
|
|
2
|
+
from dataclasses import field, dataclass
|
|
2
3
|
|
|
3
|
-
from kirin import ir
|
|
4
|
-
from kirin.analysis import
|
|
4
|
+
from kirin import ir
|
|
5
|
+
from kirin.analysis import ForwardExtra, const
|
|
5
6
|
from kirin.analysis.forward import ForwardFrame
|
|
6
7
|
|
|
7
8
|
from .lattice import MeasureId, NotMeasureId
|
|
8
9
|
|
|
9
10
|
|
|
10
|
-
|
|
11
|
+
@dataclass
|
|
12
|
+
class MeasureIDFrame(ForwardFrame[MeasureId]):
|
|
13
|
+
num_measures_at_stmt: dict[ir.Statement, int] = field(default_factory=dict)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]):
|
|
11
17
|
|
|
12
18
|
keys = ["measure_id"]
|
|
13
19
|
lattice = MeasureId
|
|
@@ -15,31 +21,34 @@ class MeasurementIDAnalysis(Forward[MeasureId]):
|
|
|
15
21
|
# then use this to generate the negative values for target rec indices
|
|
16
22
|
measure_count = 0
|
|
17
23
|
|
|
24
|
+
def initialize_frame(
|
|
25
|
+
self, node: ir.Statement, *, has_parent_access: bool = False
|
|
26
|
+
) -> MeasureIDFrame:
|
|
27
|
+
return MeasureIDFrame(node, has_parent_access=has_parent_access)
|
|
28
|
+
|
|
18
29
|
# Still default to bottom,
|
|
19
30
|
# but let constants return the softer "NoMeasureId" type from impl
|
|
20
|
-
def
|
|
21
|
-
self, frame: ForwardFrame[MeasureId],
|
|
31
|
+
def eval_fallback(
|
|
32
|
+
self, frame: ForwardFrame[MeasureId], node: ir.Statement
|
|
22
33
|
) -> tuple[MeasureId, ...]:
|
|
23
|
-
return tuple(NotMeasureId() for _ in
|
|
24
|
-
|
|
25
|
-
def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]):
|
|
26
|
-
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
|
|
27
|
-
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
|
|
28
|
-
|
|
29
|
-
T = TypeVar("T")
|
|
34
|
+
return tuple(NotMeasureId() for _ in node.results)
|
|
30
35
|
|
|
31
36
|
# Xiu-zhe (Roger) Luo came up with this in the address analysis,
|
|
32
|
-
# reused here for convenience
|
|
37
|
+
# reused here for convenience (now modified to be a bit more graceful)
|
|
33
38
|
# TODO: Remove this function once upgrade to kirin 0.18 happens,
|
|
34
39
|
# method is built-in to interpreter then
|
|
35
|
-
|
|
40
|
+
|
|
41
|
+
T = TypeVar("T")
|
|
42
|
+
|
|
43
|
+
def get_const_value(
|
|
44
|
+
self, input_type: type[T] | tuple[type[T], ...], value: ir.SSAValue
|
|
45
|
+
) -> type[T] | None:
|
|
36
46
|
if isinstance(hint := value.hints.get("const"), const.Value):
|
|
37
47
|
data = hint.data
|
|
38
48
|
if isinstance(data, input_type):
|
|
39
49
|
return hint.data
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
)
|
|
50
|
+
|
|
51
|
+
return None
|
|
52
|
+
|
|
53
|
+
def method_self(self, method: ir.Method) -> MeasureId:
|
|
54
|
+
return self.lattice.bottom()
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from kirin import types as kirin_types, interp
|
|
2
|
+
from kirin.analysis import const
|
|
2
3
|
from kirin.dialects import py, scf, func, ilist
|
|
3
4
|
|
|
4
|
-
from bloqade
|
|
5
|
+
from bloqade import qubit, annotate
|
|
5
6
|
|
|
6
7
|
from .lattice import (
|
|
7
8
|
AnyMeasureId,
|
|
@@ -10,7 +11,7 @@ from .lattice import (
|
|
|
10
11
|
MeasureIdTuple,
|
|
11
12
|
InvalidMeasureId,
|
|
12
13
|
)
|
|
13
|
-
from .analysis import MeasurementIDAnalysis
|
|
14
|
+
from .analysis import MeasureIDFrame, MeasurementIDAnalysis
|
|
14
15
|
|
|
15
16
|
## Can't do wire right now because of
|
|
16
17
|
## unresolved RFC on return type
|
|
@@ -20,22 +21,12 @@ from .analysis import MeasurementIDAnalysis
|
|
|
20
21
|
@qubit.dialect.register(key="measure_id")
|
|
21
22
|
class SquinQubit(interp.MethodTable):
|
|
22
23
|
|
|
23
|
-
@interp.impl(qubit.
|
|
24
|
-
def measure_qubit(
|
|
25
|
-
self,
|
|
26
|
-
interp: MeasurementIDAnalysis,
|
|
27
|
-
frame: interp.Frame,
|
|
28
|
-
stmt: qubit.MeasureQubit,
|
|
29
|
-
):
|
|
30
|
-
interp.measure_count += 1
|
|
31
|
-
return (MeasureIdBool(interp.measure_count),)
|
|
32
|
-
|
|
33
|
-
@interp.impl(qubit.MeasureQubitList)
|
|
24
|
+
@interp.impl(qubit.stmts.Measure)
|
|
34
25
|
def measure_qubit_list(
|
|
35
26
|
self,
|
|
36
27
|
interp: MeasurementIDAnalysis,
|
|
37
28
|
frame: interp.Frame,
|
|
38
|
-
stmt: qubit.
|
|
29
|
+
stmt: qubit.stmts.Measure,
|
|
39
30
|
):
|
|
40
31
|
|
|
41
32
|
# try to get the length of the list
|
|
@@ -55,18 +46,18 @@ class SquinQubit(interp.MethodTable):
|
|
|
55
46
|
return (MeasureIdTuple(data=tuple(measure_id_bools)),)
|
|
56
47
|
|
|
57
48
|
|
|
58
|
-
@
|
|
59
|
-
class
|
|
60
|
-
|
|
61
|
-
@interp.impl(
|
|
62
|
-
def
|
|
49
|
+
@annotate.dialect.register(key="measure_id")
|
|
50
|
+
class Annotate(interp.MethodTable):
|
|
51
|
+
@interp.impl(annotate.stmts.SetObservable)
|
|
52
|
+
@interp.impl(annotate.stmts.SetDetector)
|
|
53
|
+
def consumes_measurement_results(
|
|
63
54
|
self,
|
|
64
55
|
interp: MeasurementIDAnalysis,
|
|
65
|
-
frame:
|
|
66
|
-
stmt:
|
|
56
|
+
frame: MeasureIDFrame,
|
|
57
|
+
stmt: annotate.stmts.SetObservable | annotate.stmts.SetDetector,
|
|
67
58
|
):
|
|
68
|
-
interp.measure_count
|
|
69
|
-
return (
|
|
59
|
+
frame.num_measures_at_stmt[stmt] = interp.measure_count
|
|
60
|
+
return (NotMeasureId(),)
|
|
70
61
|
|
|
71
62
|
|
|
72
63
|
@ilist.dialect.register(key="measure_id")
|
|
@@ -102,10 +93,23 @@ class PyIndexing(interp.MethodTable):
|
|
|
102
93
|
def getitem(
|
|
103
94
|
self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.GetItem
|
|
104
95
|
):
|
|
105
|
-
|
|
96
|
+
|
|
97
|
+
idx_or_slice = interp.get_const_value((int, slice), stmt.index)
|
|
98
|
+
if idx_or_slice is None:
|
|
99
|
+
return (InvalidMeasureId(),)
|
|
100
|
+
|
|
101
|
+
# hint = stmt.index.hints.get("const")
|
|
102
|
+
# if hint is None or not isinstance(hint, const.Value):
|
|
103
|
+
# return (InvalidMeasureId(),)
|
|
104
|
+
|
|
106
105
|
obj = frame.get(stmt.obj)
|
|
107
106
|
if isinstance(obj, MeasureIdTuple):
|
|
108
|
-
|
|
107
|
+
if isinstance(idx_or_slice, slice):
|
|
108
|
+
return (MeasureIdTuple(data=obj.data[idx_or_slice]),)
|
|
109
|
+
elif isinstance(idx_or_slice, int):
|
|
110
|
+
return (obj.data[idx_or_slice],)
|
|
111
|
+
else:
|
|
112
|
+
return (InvalidMeasureId(),)
|
|
109
113
|
# just propagate these down the line
|
|
110
114
|
elif isinstance(obj, (AnyMeasureId, NotMeasureId)):
|
|
111
115
|
return (obj,)
|
|
@@ -113,6 +117,15 @@ class PyIndexing(interp.MethodTable):
|
|
|
113
117
|
return (InvalidMeasureId(),)
|
|
114
118
|
|
|
115
119
|
|
|
120
|
+
@py.assign.dialect.register(key="measure_id")
|
|
121
|
+
class PyAssign(interp.MethodTable):
|
|
122
|
+
@interp.impl(py.Alias)
|
|
123
|
+
def alias(
|
|
124
|
+
self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.assign.Alias
|
|
125
|
+
):
|
|
126
|
+
return (frame.get(stmt.value),)
|
|
127
|
+
|
|
128
|
+
|
|
116
129
|
@py.binop.dialect.register(key="measure_id")
|
|
117
130
|
class PyBinOp(interp.MethodTable):
|
|
118
131
|
@interp.impl(py.Add)
|
|
@@ -139,11 +152,10 @@ class Func(interp.MethodTable):
|
|
|
139
152
|
def invoke(
|
|
140
153
|
self, interp_: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Invoke
|
|
141
154
|
):
|
|
142
|
-
_, ret = interp_.
|
|
143
|
-
stmt.callee,
|
|
144
|
-
interp_.
|
|
145
|
-
|
|
146
|
-
),
|
|
155
|
+
_, ret = interp_.call(
|
|
156
|
+
stmt.callee.code,
|
|
157
|
+
interp_.method_self(stmt.callee),
|
|
158
|
+
*frame.get_values(stmt.inputs),
|
|
147
159
|
)
|
|
148
160
|
return (ret,)
|
|
149
161
|
|
|
@@ -152,4 +164,33 @@ class Func(interp.MethodTable):
|
|
|
152
164
|
# scf, particularly IfElse
|
|
153
165
|
@scf.dialect.register(key="measure_id")
|
|
154
166
|
class Scf(scf.absint.Methods):
|
|
155
|
-
|
|
167
|
+
|
|
168
|
+
@interp.impl(scf.IfElse)
|
|
169
|
+
def if_else(
|
|
170
|
+
self,
|
|
171
|
+
interp_: MeasurementIDAnalysis,
|
|
172
|
+
frame: MeasureIDFrame,
|
|
173
|
+
stmt: scf.IfElse,
|
|
174
|
+
):
|
|
175
|
+
|
|
176
|
+
frame.num_measures_at_stmt[stmt] = interp_.measure_count
|
|
177
|
+
|
|
178
|
+
# rest of the code taken directly from scf.absint.Methods base implementation
|
|
179
|
+
|
|
180
|
+
if isinstance(hint := stmt.cond.hints.get("const"), const.Value):
|
|
181
|
+
if hint.data:
|
|
182
|
+
return self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
|
|
183
|
+
else:
|
|
184
|
+
return self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)
|
|
185
|
+
then_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
|
|
186
|
+
else_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)
|
|
187
|
+
|
|
188
|
+
match (then_results, else_results):
|
|
189
|
+
case (interp.ReturnValue(then_value), interp.ReturnValue(else_value)):
|
|
190
|
+
return interp.ReturnValue(then_value.join(else_value))
|
|
191
|
+
case (interp.ReturnValue(then_value), _):
|
|
192
|
+
return then_results
|
|
193
|
+
case (_, interp.ReturnValue(else_value)):
|
|
194
|
+
return else_results
|
|
195
|
+
case _:
|
|
196
|
+
return interp_.join_results(then_results, else_results)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from kirin.dialects import ilist
|
|
4
|
+
from kirin.lowering import wraps
|
|
5
|
+
|
|
6
|
+
from bloqade.types import MeasurementResult
|
|
7
|
+
|
|
8
|
+
from .stmts import SetDetector, SetObservable
|
|
9
|
+
from .types import Detector, Observable
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@wraps(SetDetector)
|
|
13
|
+
def set_detector(
|
|
14
|
+
measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
|
|
15
|
+
coordinates: ilist.IList[float | int, Any] | list[float | int],
|
|
16
|
+
) -> Detector: ...
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@wraps(SetObservable)
|
|
20
|
+
def set_observable(
|
|
21
|
+
measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
|
|
22
|
+
) -> Observable: ...
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from kirin import ir, types as kirin_types, lowering
|
|
2
|
+
from kirin.decl import info, statement
|
|
3
|
+
from kirin.dialects import ilist
|
|
4
|
+
|
|
5
|
+
from bloqade.types import MeasurementResultType
|
|
6
|
+
from bloqade.annotate.types import DetectorType, ObservableType
|
|
7
|
+
|
|
8
|
+
from ._dialect import dialect
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@statement
|
|
12
|
+
class ConsumesMeasurementResults(ir.Statement):
|
|
13
|
+
traits = frozenset({lowering.FromPythonCall()})
|
|
14
|
+
measurements: ir.SSAValue = info.argument(
|
|
15
|
+
ilist.IListType[MeasurementResultType, kirin_types.Any]
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@statement(dialect=dialect)
|
|
20
|
+
class SetDetector(ConsumesMeasurementResults):
|
|
21
|
+
coordinates: ir.SSAValue = info.argument(
|
|
22
|
+
type=ilist.IListType[kirin_types.Int | kirin_types.Float, kirin_types.Any]
|
|
23
|
+
)
|
|
24
|
+
result: ir.ResultValue = info.result(DetectorType)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@statement(dialect=dialect)
|
|
28
|
+
class SetObservable(ConsumesMeasurementResults):
|
|
29
|
+
result: ir.ResultValue = info.result(ObservableType)
|