bloqade-circuit 0.6.2__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192) hide show
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/__init__.py +4 -1
  7. bloqade/analysis/measure_id/analysis.py +29 -20
  8. bloqade/analysis/measure_id/impls.py +72 -31
  9. bloqade/annotate/__init__.py +6 -0
  10. bloqade/annotate/_dialect.py +3 -0
  11. bloqade/annotate/_interface.py +22 -0
  12. bloqade/annotate/stmts.py +29 -0
  13. bloqade/annotate/types.py +13 -0
  14. bloqade/cirq_utils/__init__.py +4 -2
  15. bloqade/cirq_utils/emit/__init__.py +3 -0
  16. bloqade/cirq_utils/emit/base.py +246 -0
  17. bloqade/cirq_utils/emit/gate.py +104 -0
  18. bloqade/cirq_utils/emit/noise.py +90 -0
  19. bloqade/cirq_utils/emit/qubit.py +35 -0
  20. bloqade/cirq_utils/lowering.py +660 -0
  21. bloqade/cirq_utils/noise/__init__.py +0 -2
  22. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  23. bloqade/cirq_utils/noise/model.py +151 -191
  24. bloqade/cirq_utils/noise/transform.py +2 -2
  25. bloqade/cirq_utils/parallelize.py +9 -6
  26. bloqade/gemini/__init__.py +1 -0
  27. bloqade/gemini/analysis/__init__.py +3 -0
  28. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  29. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  30. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  31. bloqade/gemini/groups.py +67 -0
  32. bloqade/native/__init__.py +23 -0
  33. bloqade/native/_prelude.py +45 -0
  34. bloqade/native/dialects/__init__.py +0 -0
  35. bloqade/native/dialects/gate/__init__.py +2 -0
  36. bloqade/native/dialects/gate/_dialect.py +3 -0
  37. bloqade/native/dialects/gate/_interface.py +32 -0
  38. bloqade/native/dialects/gate/stmts.py +31 -0
  39. bloqade/native/stdlib/__init__.py +0 -0
  40. bloqade/native/stdlib/broadcast.py +246 -0
  41. bloqade/native/stdlib/simple.py +220 -0
  42. bloqade/native/upstream/__init__.py +4 -0
  43. bloqade/native/upstream/squin2native.py +79 -0
  44. bloqade/pyqrack/__init__.py +2 -2
  45. bloqade/pyqrack/base.py +7 -1
  46. bloqade/pyqrack/device.py +190 -4
  47. bloqade/pyqrack/native.py +49 -0
  48. bloqade/pyqrack/reg.py +6 -6
  49. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  50. bloqade/pyqrack/squin/gate/gate.py +136 -0
  51. bloqade/pyqrack/squin/noise/native.py +120 -54
  52. bloqade/pyqrack/squin/qubit.py +39 -36
  53. bloqade/pyqrack/target.py +5 -4
  54. bloqade/pyqrack/task.py +114 -7
  55. bloqade/qasm2/_qasm_loading.py +3 -3
  56. bloqade/qasm2/dialects/core/address.py +21 -12
  57. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  58. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  59. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  60. bloqade/qasm2/dialects/noise/model.py +2 -1
  61. bloqade/qasm2/emit/base.py +16 -11
  62. bloqade/qasm2/emit/gate.py +11 -8
  63. bloqade/qasm2/emit/main.py +103 -3
  64. bloqade/qasm2/emit/target.py +9 -5
  65. bloqade/qasm2/groups.py +3 -2
  66. bloqade/qasm2/parse/lowering.py +0 -1
  67. bloqade/qasm2/passes/fold.py +14 -73
  68. bloqade/qasm2/passes/glob.py +2 -2
  69. bloqade/qasm2/passes/noise.py +1 -1
  70. bloqade/qasm2/passes/parallel.py +7 -5
  71. bloqade/qasm2/rewrite/__init__.py +0 -1
  72. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  73. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  74. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  75. bloqade/qasm2/rewrite/register.py +2 -2
  76. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  77. bloqade/qbraid/lowering.py +1 -0
  78. bloqade/qbraid/schema.py +2 -2
  79. bloqade/qubit/__init__.py +12 -0
  80. bloqade/qubit/_dialect.py +3 -0
  81. bloqade/qubit/_interface.py +49 -0
  82. bloqade/qubit/_prelude.py +45 -0
  83. bloqade/qubit/analysis/__init__.py +1 -0
  84. bloqade/qubit/analysis/address_impl.py +40 -0
  85. bloqade/qubit/stdlib/__init__.py +2 -0
  86. bloqade/qubit/stdlib/_new.py +34 -0
  87. bloqade/qubit/stdlib/broadcast.py +62 -0
  88. bloqade/qubit/stdlib/simple.py +59 -0
  89. bloqade/qubit/stmts.py +60 -0
  90. bloqade/rewrite/passes/__init__.py +6 -0
  91. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  92. bloqade/rewrite/passes/callgraph.py +116 -0
  93. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  94. bloqade/rewrite/rules/split_ifs.py +18 -1
  95. bloqade/squin/__init__.py +47 -14
  96. bloqade/squin/analysis/__init__.py +0 -1
  97. bloqade/squin/analysis/schedule.py +10 -11
  98. bloqade/squin/gate/__init__.py +2 -0
  99. bloqade/squin/gate/_dialect.py +3 -0
  100. bloqade/squin/gate/_interface.py +98 -0
  101. bloqade/squin/gate/stmts.py +125 -0
  102. bloqade/squin/groups.py +5 -22
  103. bloqade/squin/noise/__init__.py +1 -10
  104. bloqade/squin/noise/_dialect.py +1 -1
  105. bloqade/squin/noise/_interface.py +45 -0
  106. bloqade/squin/noise/stmts.py +66 -28
  107. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  108. bloqade/squin/rewrite/__init__.py +0 -2
  109. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  110. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  111. bloqade/squin/stdlib/__init__.py +0 -0
  112. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  113. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  114. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  115. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  116. bloqade/squin/stdlib/simple/__init__.py +33 -0
  117. bloqade/squin/stdlib/simple/gate.py +242 -0
  118. bloqade/squin/stdlib/simple/noise.py +126 -0
  119. bloqade/stim/__init__.py +1 -0
  120. bloqade/stim/_wrappers.py +6 -0
  121. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  122. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  123. bloqade/stim/dialects/gate/emit.py +9 -10
  124. bloqade/stim/dialects/noise/emit.py +17 -13
  125. bloqade/stim/dialects/noise/stmts.py +5 -3
  126. bloqade/stim/emit/__init__.py +1 -0
  127. bloqade/stim/emit/impls.py +16 -0
  128. bloqade/stim/emit/stim_str.py +48 -31
  129. bloqade/stim/groups.py +12 -2
  130. bloqade/stim/parse/lowering.py +14 -17
  131. bloqade/stim/passes/__init__.py +3 -1
  132. bloqade/stim/passes/flatten.py +26 -0
  133. bloqade/stim/passes/simplify_ifs.py +16 -2
  134. bloqade/stim/passes/squin_to_stim.py +18 -60
  135. bloqade/stim/rewrite/__init__.py +3 -4
  136. bloqade/stim/rewrite/get_record_util.py +24 -0
  137. bloqade/stim/rewrite/ifs_to_stim.py +29 -31
  138. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  139. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  140. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  141. bloqade/stim/rewrite/squin_measure.py +11 -79
  142. bloqade/stim/rewrite/squin_noise.py +134 -108
  143. bloqade/stim/rewrite/util.py +5 -192
  144. bloqade/test_utils.py +1 -1
  145. bloqade/types.py +10 -0
  146. bloqade/validation/__init__.py +2 -0
  147. bloqade/validation/analysis/__init__.py +5 -0
  148. bloqade/validation/analysis/analysis.py +41 -0
  149. bloqade/validation/analysis/lattice.py +58 -0
  150. bloqade/validation/kernel_validation.py +77 -0
  151. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  152. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  153. bloqade/pyqrack/squin/op.py +0 -166
  154. bloqade/pyqrack/squin/runtime.py +0 -535
  155. bloqade/pyqrack/squin/wire.py +0 -51
  156. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  157. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  158. bloqade/squin/_typeinfer.py +0 -20
  159. bloqade/squin/analysis/address_impl.py +0 -71
  160. bloqade/squin/analysis/nsites/__init__.py +0 -9
  161. bloqade/squin/analysis/nsites/analysis.py +0 -50
  162. bloqade/squin/analysis/nsites/impls.py +0 -92
  163. bloqade/squin/analysis/nsites/lattice.py +0 -49
  164. bloqade/squin/cirq/__init__.py +0 -265
  165. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  166. bloqade/squin/cirq/emit/noise.py +0 -49
  167. bloqade/squin/cirq/emit/op.py +0 -125
  168. bloqade/squin/cirq/emit/qubit.py +0 -60
  169. bloqade/squin/cirq/emit/runtime.py +0 -242
  170. bloqade/squin/cirq/lowering.py +0 -440
  171. bloqade/squin/lowering.py +0 -54
  172. bloqade/squin/noise/_wrapper.py +0 -40
  173. bloqade/squin/noise/rewrite.py +0 -111
  174. bloqade/squin/op/__init__.py +0 -41
  175. bloqade/squin/op/_dialect.py +0 -3
  176. bloqade/squin/op/_wrapper.py +0 -121
  177. bloqade/squin/op/number.py +0 -5
  178. bloqade/squin/op/rewrite.py +0 -46
  179. bloqade/squin/op/stdlib.py +0 -62
  180. bloqade/squin/op/stmts.py +0 -276
  181. bloqade/squin/op/traits.py +0 -43
  182. bloqade/squin/op/types.py +0 -26
  183. bloqade/squin/qubit.py +0 -184
  184. bloqade/squin/rewrite/canonicalize.py +0 -60
  185. bloqade/squin/rewrite/desugar.py +0 -124
  186. bloqade/squin/types.py +0 -8
  187. bloqade/squin/wire.py +0 -201
  188. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  189. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  190. bloqade_circuit-0.6.2.dist-info/RECORD +0 -234
  191. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  192. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,12 +1,18 @@
1
1
  from typing import Sequence, final
2
2
  from dataclasses import dataclass
3
3
 
4
+ from kirin import ir, types
4
5
  from kirin.lattice import (
5
6
  SingletonMeta,
6
7
  BoundedLattice,
7
8
  SimpleJoinMixin,
8
9
  SimpleMeetMixin,
9
10
  )
11
+ from kirin.analysis import const
12
+ from kirin.dialects import ilist
13
+ from kirin.ir.attrs.abc import LatticeAttributeMeta
14
+
15
+ from bloqade.types import QubitType
10
16
 
11
17
 
12
18
  @dataclass
@@ -18,54 +24,94 @@ class Address(
18
24
 
19
25
  @classmethod
20
26
  def bottom(cls) -> "Address":
21
- return NotQubit()
27
+ return Bottom()
22
28
 
23
29
  @classmethod
24
30
  def top(cls) -> "Address":
25
- return AnyAddress()
31
+ return Unknown()
32
+
33
+ @classmethod
34
+ def from_type(cls, typ: types.TypeAttribute):
35
+ if typ.is_subseteq(ilist.IListType[QubitType]):
36
+ return UnknownReg()
37
+ elif typ.is_subseteq(QubitType):
38
+ return UnknownQubit()
39
+ else:
40
+ return Unknown()
26
41
 
27
42
 
28
43
  @final
29
- @dataclass
30
- class NotQubit(Address, metaclass=SingletonMeta):
44
+ class Bottom(Address, metaclass=SingletonMeta):
45
+ """Error during interpretation"""
31
46
 
32
47
  def is_subseteq(self, other: Address) -> bool:
33
48
  return True
34
49
 
35
50
 
36
51
  @final
37
- @dataclass
38
- class AnyAddress(Address, metaclass=SingletonMeta):
52
+ class Unknown(Address, metaclass=SingletonMeta):
53
+ """Can't determine if it is an address or constant."""
39
54
 
40
55
  def is_subseteq(self, other: Address) -> bool:
41
- return isinstance(other, AnyAddress)
56
+ return isinstance(other, Unknown)
42
57
 
43
58
 
44
59
  @final
45
60
  @dataclass
46
- class AddressTuple(Address):
47
- data: tuple[Address, ...]
61
+ class ConstResult(Address):
62
+ """Stores a constant prop result in the lattice"""
63
+
64
+ result: const.Result
48
65
 
49
66
  def is_subseteq(self, other: Address) -> bool:
50
- if isinstance(other, AddressTuple):
51
- return all(a.is_subseteq(b) for a, b in zip(self.data, other.data))
52
- return False
67
+ return isinstance(other, ConstResult) and self.result.is_subseteq(other.result)
68
+
69
+
70
+ class QubitLike(Address):
71
+ def join(self, other: Address):
72
+ if isinstance(other, QubitLike):
73
+ return super().join(other)
74
+ return self.bottom()
75
+
76
+ def meet(self, other: Address):
77
+ if isinstance(other, QubitLike):
78
+ return super().meet(other)
79
+ return self.bottom()
53
80
 
54
81
 
55
82
  @final
56
- @dataclass
57
- class AddressReg(Address):
58
- data: Sequence[int]
83
+ class UnknownQubit(QubitLike, metaclass=SingletonMeta):
84
+ """A lattice element representing a single qubit with an unknown address."""
59
85
 
60
86
  def is_subseteq(self, other: Address) -> bool:
61
- if isinstance(other, AddressReg):
62
- return self.data == other.data
63
- return False
87
+ return isinstance(other, QubitLike)
88
+
89
+
90
+ class RegisterLike(Address):
91
+ def join(self, other: Address):
92
+ if isinstance(other, RegisterLike):
93
+ return super().join(other)
94
+ return self.bottom()
95
+
96
+ def meet(self, other: Address):
97
+ if isinstance(other, RegisterLike):
98
+ return super().meet(other)
99
+ return self.bottom()
100
+
101
+
102
+ @final
103
+ class UnknownReg(RegisterLike, metaclass=SingletonMeta):
104
+ """A lattice element representing a container of qubits with unknown indices."""
105
+
106
+ def is_subseteq(self, other: Address) -> bool:
107
+ return isinstance(other, RegisterLike)
64
108
 
65
109
 
66
110
  @final
67
111
  @dataclass
68
- class AddressQubit(Address):
112
+ class AddressQubit(QubitLike):
113
+ """A lattice element representing a single qubit with a known address."""
114
+
69
115
  data: int
70
116
 
71
117
  def is_subseteq(self, other: Address) -> bool:
@@ -76,10 +122,149 @@ class AddressQubit(Address):
76
122
 
77
123
  @final
78
124
  @dataclass
79
- class AddressWire(Address):
80
- origin_qubit: AddressQubit
125
+ class AddressReg(RegisterLike):
126
+ """A lattice element representing a container of qubits with known indices."""
127
+
128
+ data: Sequence[int]
81
129
 
82
130
  def is_subseteq(self, other: Address) -> bool:
83
- if isinstance(other, AddressWire):
84
- return self.origin_qubit == other.origin_qubit
85
- return False
131
+ return isinstance(other, AddressReg) and self.data == other.data
132
+
133
+ @property
134
+ def qubits(self) -> tuple[AddressQubit, ...]:
135
+ return tuple(AddressQubit(i) for i in self.data)
136
+
137
+
138
+ @final
139
+ @dataclass
140
+ class PartialLambda(Address):
141
+ """Represents a partially known lambda function"""
142
+
143
+ argnames: list[str]
144
+ code: ir.Statement
145
+ captured: tuple[Address, ...]
146
+
147
+ def join(self, other: Address) -> Address:
148
+ if other is other.bottom():
149
+ return self
150
+
151
+ if not isinstance(other, PartialLambda):
152
+ return self.top().join(other) # widen self
153
+
154
+ if self.code is not other.code:
155
+ return self.top() # lambda stmt is pure
156
+
157
+ if len(self.captured) != len(other.captured):
158
+ return self.bottom() # err
159
+
160
+ return PartialLambda(
161
+ self.argnames,
162
+ self.code,
163
+ tuple(x.join(y) for x, y in zip(self.captured, other.captured)),
164
+ )
165
+
166
+ def meet(self, other: Address) -> Address:
167
+ if not isinstance(other, PartialLambda):
168
+ return self.top().meet(other)
169
+
170
+ if self.code is not other.code:
171
+ return self.bottom()
172
+
173
+ if len(self.captured) != len(other.captured):
174
+ return self.top()
175
+
176
+ return PartialLambda(
177
+ self.argnames,
178
+ self.code,
179
+ tuple(x.meet(y) for x, y in zip(self.captured, other.captured)),
180
+ )
181
+
182
+ def is_subseteq(self, other: Address) -> bool:
183
+ return (
184
+ isinstance(other, PartialLambda)
185
+ and self.code is other.code
186
+ and self.argnames == other.argnames
187
+ and len(self.captured) == len(other.captured)
188
+ and all(
189
+ self_ele.is_subseteq(other_ele)
190
+ for self_ele, other_ele in zip(self.captured, other.captured)
191
+ )
192
+ )
193
+
194
+
195
+ @dataclass
196
+ class StaticContainer(Address):
197
+ """A lattice element representing the results of any static container, e. g. ilist or tuple."""
198
+
199
+ data: tuple[Address, ...]
200
+
201
+ @classmethod
202
+ def new(cls, data: tuple[Address, ...]):
203
+ return cls(data)
204
+
205
+ def join(self, other: "Address") -> "Address":
206
+ if isinstance(other, type(self)) and len(self.data) == len(other.data):
207
+ return self.new(tuple(x.join(y) for x, y in zip(self.data, other.data)))
208
+ return self.top()
209
+
210
+ def meet(self, other: "Address") -> "Address":
211
+ if isinstance(other, type(self)) and len(self.data) == len(other.data):
212
+ return self.new(tuple(x.meet(y) for x, y in zip(self.data, other.data)))
213
+ return self.bottom()
214
+
215
+ def is_subseteq(self, other: "Address") -> bool:
216
+ return (
217
+ isinstance(other, type(self))
218
+ and len(self.data) == len(other.data)
219
+ and all(x.is_subseteq(y) for x, y in zip(self.data, other.data))
220
+ )
221
+
222
+
223
+ class PartialIListMeta(LatticeAttributeMeta):
224
+ """This metaclass assures that PartialILists of ConstResults or AddressQubits are canonicalized
225
+ to a single ConstResult or AddressReg respectively.
226
+
227
+ because AddressReg is a specialization of PartialIList, being a container of pure qubit
228
+ addresses. For Operations that act in generic containers (e.g., ilist.ForEach),
229
+ AddressReg is treated as PartialIList but for other types of analysis it is often
230
+ useful to distinguish between a generic IList and a pure qubit address list.
231
+
232
+ Inside the method tables the `GetValuesMixin` implements a method that effectively
233
+ undoes this canonicalization.
234
+
235
+ """
236
+
237
+ def __call__(cls, data: tuple[Address, ...]):
238
+ # TODO: when constant prop has PartialIList, make sure to canonicalize here.
239
+ if types.is_tuple_of(data, ConstResult) and types.is_tuple_of(
240
+ all_constants := tuple(ele.result for ele in data), const.Value
241
+ ):
242
+ # all constants, create constant list
243
+ return ConstResult(
244
+ const.Value(ilist.IList([ele.data for ele in all_constants]))
245
+ )
246
+ elif types.is_tuple_of(data, AddressQubit):
247
+ # all qubits create qubit register
248
+ return AddressReg(tuple(ele.data for ele in data))
249
+ else:
250
+ return super().__call__(data)
251
+
252
+
253
+ @final
254
+ class PartialIList(StaticContainer, metaclass=PartialIListMeta):
255
+ """A lattice element representing a partially known ilist."""
256
+
257
+
258
+ class PartialTupleMeta(LatticeAttributeMeta):
259
+ """This metaclass assures that PartialTuples of ConstResults are canonicalized to a single ConstResult."""
260
+
261
+ def __call__(cls, data: tuple[Address, ...]):
262
+ if not types.is_tuple_of(data, ConstResult):
263
+ return super().__call__(data)
264
+
265
+ return ConstResult(const.PartialTuple(tuple(ele.result for ele in data)))
266
+
267
+
268
+ @final
269
+ class PartialTuple(StaticContainer, metaclass=PartialTupleMeta):
270
+ """A lattice element representing a partially known tuple."""
@@ -4,10 +4,9 @@ from dataclasses import field
4
4
  from kirin import ir
5
5
  from kirin.lattice import EmptyLattice
6
6
  from kirin.analysis import Forward
7
- from kirin.interp.value import Successor
8
7
  from kirin.analysis.forward import ForwardFrame
9
8
 
10
- from ..address import AddressAnalysis
9
+ from ..address import Address, AddressAnalysis
11
10
 
12
11
 
13
12
  class FidelityAnalysis(Forward):
@@ -48,16 +47,12 @@ class FidelityAnalysis(Forward):
48
47
  The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered.
49
48
  """
50
49
 
51
- _current_gate_fidelity: float = field(init=False)
52
-
53
50
  atom_survival_probability: list[float] = field(init=False)
54
51
  """
55
52
  The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register.
56
53
  """
57
54
 
58
- _current_atom_survival_probability: list[float] = field(init=False)
59
-
60
- addr_frame: ForwardFrame = field(init=False)
55
+ addr_frame: ForwardFrame[Address] = field(init=False)
61
56
 
62
57
  def initialize(self):
63
58
  super().initialize()
@@ -67,28 +62,21 @@ class FidelityAnalysis(Forward):
67
62
  ]
68
63
  return self
69
64
 
70
- def posthook_succ(self, frame: ForwardFrame, succ: Successor):
71
- self.gate_fidelity *= self._current_gate_fidelity
72
- for i, _current_survival in enumerate(self._current_atom_survival_probability):
73
- self.atom_survival_probability[i] *= _current_survival
74
-
75
- def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
65
+ def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
76
66
  # NOTE: default is to conserve fidelity, so do nothing here
77
67
  return
78
68
 
79
- def run_method(self, method: ir.Method, args: tuple[EmptyLattice, ...]):
80
- return self.run_callable(method.code, (self.lattice.bottom(),) + args)
69
+ def run(self, method: ir.Method, *args, **kwargs) -> tuple[ForwardFrame, Any]:
70
+ self._run_address_analysis(method)
71
+ return super().run(method, *args, **kwargs)
81
72
 
82
- def run_analysis(
83
- self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True
84
- ) -> tuple[ForwardFrame, Any]:
85
- self._run_address_analysis(method, no_raise=no_raise)
86
- return super().run_analysis(method, args, no_raise=no_raise)
87
-
88
- def _run_address_analysis(self, method: ir.Method, no_raise: bool):
73
+ def _run_address_analysis(self, method: ir.Method):
89
74
  addr_analysis = AddressAnalysis(self.dialects)
90
- addr_frame, _ = addr_analysis.run_analysis(method=method, no_raise=no_raise)
75
+ addr_frame, _ = addr_analysis.run(method=method)
91
76
  self.addr_frame = addr_frame
92
77
 
93
78
  # NOTE: make sure we have as many probabilities as we have addresses
94
79
  self.atom_survival_probability = [1.0] * addr_analysis.qubit_count
80
+
81
+ def method_self(self, method: ir.Method) -> EmptyLattice:
82
+ return self.lattice.bottom()
@@ -1,2 +1,5 @@
1
1
  from . import impls as impls
2
- from .analysis import MeasurementIDAnalysis as MeasurementIDAnalysis
2
+ from .analysis import (
3
+ MeasureIDFrame as MeasureIDFrame,
4
+ MeasurementIDAnalysis as MeasurementIDAnalysis,
5
+ )
@@ -1,13 +1,19 @@
1
1
  from typing import TypeVar
2
+ from dataclasses import field, dataclass
2
3
 
3
- from kirin import ir, interp
4
- from kirin.analysis import Forward, const
4
+ from kirin import ir
5
+ from kirin.analysis import ForwardExtra, const
5
6
  from kirin.analysis.forward import ForwardFrame
6
7
 
7
8
  from .lattice import MeasureId, NotMeasureId
8
9
 
9
10
 
10
- class MeasurementIDAnalysis(Forward[MeasureId]):
11
+ @dataclass
12
+ class MeasureIDFrame(ForwardFrame[MeasureId]):
13
+ num_measures_at_stmt: dict[ir.Statement, int] = field(default_factory=dict)
14
+
15
+
16
+ class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]):
11
17
 
12
18
  keys = ["measure_id"]
13
19
  lattice = MeasureId
@@ -15,31 +21,34 @@ class MeasurementIDAnalysis(Forward[MeasureId]):
15
21
  # then use this to generate the negative values for target rec indices
16
22
  measure_count = 0
17
23
 
24
+ def initialize_frame(
25
+ self, node: ir.Statement, *, has_parent_access: bool = False
26
+ ) -> MeasureIDFrame:
27
+ return MeasureIDFrame(node, has_parent_access=has_parent_access)
28
+
18
29
  # Still default to bottom,
19
30
  # but let constants return the softer "NoMeasureId" type from impl
20
- def eval_stmt_fallback(
21
- self, frame: ForwardFrame[MeasureId], stmt: ir.Statement
31
+ def eval_fallback(
32
+ self, frame: ForwardFrame[MeasureId], node: ir.Statement
22
33
  ) -> tuple[MeasureId, ...]:
23
- return tuple(NotMeasureId() for _ in stmt.results)
24
-
25
- def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]):
26
- # NOTE: we do not support dynamic calls here, thus no need to propagate method object
27
- return self.run_callable(method.code, (self.lattice.bottom(),) + args)
28
-
29
- T = TypeVar("T")
34
+ return tuple(NotMeasureId() for _ in node.results)
30
35
 
31
36
  # Xiu-zhe (Roger) Luo came up with this in the address analysis,
32
- # reused here for convenience
37
+ # reused here for convenience (now modified to be a bit more graceful)
33
38
  # TODO: Remove this function once upgrade to kirin 0.18 happens,
34
39
  # method is built-in to interpreter then
35
- def get_const_value(self, input_type: type[T], value: ir.SSAValue) -> T:
40
+
41
+ T = TypeVar("T")
42
+
43
+ def get_const_value(
44
+ self, input_type: type[T] | tuple[type[T], ...], value: ir.SSAValue
45
+ ) -> type[T] | None:
36
46
  if isinstance(hint := value.hints.get("const"), const.Value):
37
47
  data = hint.data
38
48
  if isinstance(data, input_type):
39
49
  return hint.data
40
- raise interp.InterpreterError(
41
- f"Expected constant value <type = {input_type}>, got {data}"
42
- )
43
- raise interp.InterpreterError(
44
- f"Expected constant value <type = {input_type}>, got {value}"
45
- )
50
+
51
+ return None
52
+
53
+ def method_self(self, method: ir.Method) -> MeasureId:
54
+ return self.lattice.bottom()
@@ -1,7 +1,8 @@
1
1
  from kirin import types as kirin_types, interp
2
+ from kirin.analysis import const
2
3
  from kirin.dialects import py, scf, func, ilist
3
4
 
4
- from bloqade.squin import wire, qubit
5
+ from bloqade import qubit, annotate
5
6
 
6
7
  from .lattice import (
7
8
  AnyMeasureId,
@@ -10,7 +11,7 @@ from .lattice import (
10
11
  MeasureIdTuple,
11
12
  InvalidMeasureId,
12
13
  )
13
- from .analysis import MeasurementIDAnalysis
14
+ from .analysis import MeasureIDFrame, MeasurementIDAnalysis
14
15
 
15
16
  ## Can't do wire right now because of
16
17
  ## unresolved RFC on return type
@@ -20,22 +21,12 @@ from .analysis import MeasurementIDAnalysis
20
21
  @qubit.dialect.register(key="measure_id")
21
22
  class SquinQubit(interp.MethodTable):
22
23
 
23
- @interp.impl(qubit.MeasureQubit)
24
- def measure_qubit(
25
- self,
26
- interp: MeasurementIDAnalysis,
27
- frame: interp.Frame,
28
- stmt: qubit.MeasureQubit,
29
- ):
30
- interp.measure_count += 1
31
- return (MeasureIdBool(interp.measure_count),)
32
-
33
- @interp.impl(qubit.MeasureQubitList)
24
+ @interp.impl(qubit.stmts.Measure)
34
25
  def measure_qubit_list(
35
26
  self,
36
27
  interp: MeasurementIDAnalysis,
37
28
  frame: interp.Frame,
38
- stmt: qubit.MeasureQubitList,
29
+ stmt: qubit.stmts.Measure,
39
30
  ):
40
31
 
41
32
  # try to get the length of the list
@@ -55,18 +46,18 @@ class SquinQubit(interp.MethodTable):
55
46
  return (MeasureIdTuple(data=tuple(measure_id_bools)),)
56
47
 
57
48
 
58
- @wire.dialect.register(key="measure_id")
59
- class SquinWire(interp.MethodTable):
60
-
61
- @interp.impl(wire.Measure)
62
- def measure_qubit(
49
+ @annotate.dialect.register(key="measure_id")
50
+ class Annotate(interp.MethodTable):
51
+ @interp.impl(annotate.stmts.SetObservable)
52
+ @interp.impl(annotate.stmts.SetDetector)
53
+ def consumes_measurement_results(
63
54
  self,
64
55
  interp: MeasurementIDAnalysis,
65
- frame: interp.Frame,
66
- stmt: wire.Measure,
56
+ frame: MeasureIDFrame,
57
+ stmt: annotate.stmts.SetObservable | annotate.stmts.SetDetector,
67
58
  ):
68
- interp.measure_count += 1
69
- return (MeasureIdBool(interp.measure_count),)
59
+ frame.num_measures_at_stmt[stmt] = interp.measure_count
60
+ return (NotMeasureId(),)
70
61
 
71
62
 
72
63
  @ilist.dialect.register(key="measure_id")
@@ -102,10 +93,23 @@ class PyIndexing(interp.MethodTable):
102
93
  def getitem(
103
94
  self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.GetItem
104
95
  ):
105
- idx = interp.get_const_value(int, stmt.index)
96
+
97
+ idx_or_slice = interp.get_const_value((int, slice), stmt.index)
98
+ if idx_or_slice is None:
99
+ return (InvalidMeasureId(),)
100
+
101
+ # hint = stmt.index.hints.get("const")
102
+ # if hint is None or not isinstance(hint, const.Value):
103
+ # return (InvalidMeasureId(),)
104
+
106
105
  obj = frame.get(stmt.obj)
107
106
  if isinstance(obj, MeasureIdTuple):
108
- return (obj.data[idx],)
107
+ if isinstance(idx_or_slice, slice):
108
+ return (MeasureIdTuple(data=obj.data[idx_or_slice]),)
109
+ elif isinstance(idx_or_slice, int):
110
+ return (obj.data[idx_or_slice],)
111
+ else:
112
+ return (InvalidMeasureId(),)
109
113
  # just propagate these down the line
110
114
  elif isinstance(obj, (AnyMeasureId, NotMeasureId)):
111
115
  return (obj,)
@@ -113,6 +117,15 @@ class PyIndexing(interp.MethodTable):
113
117
  return (InvalidMeasureId(),)
114
118
 
115
119
 
120
+ @py.assign.dialect.register(key="measure_id")
121
+ class PyAssign(interp.MethodTable):
122
+ @interp.impl(py.Alias)
123
+ def alias(
124
+ self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.assign.Alias
125
+ ):
126
+ return (frame.get(stmt.value),)
127
+
128
+
116
129
  @py.binop.dialect.register(key="measure_id")
117
130
  class PyBinOp(interp.MethodTable):
118
131
  @interp.impl(py.Add)
@@ -139,11 +152,10 @@ class Func(interp.MethodTable):
139
152
  def invoke(
140
153
  self, interp_: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Invoke
141
154
  ):
142
- _, ret = interp_.run_method(
143
- stmt.callee,
144
- interp_.permute_values(
145
- stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
146
- ),
155
+ _, ret = interp_.call(
156
+ stmt.callee.code,
157
+ interp_.method_self(stmt.callee),
158
+ *frame.get_values(stmt.inputs),
147
159
  )
148
160
  return (ret,)
149
161
 
@@ -152,4 +164,33 @@ class Func(interp.MethodTable):
152
164
  # scf, particularly IfElse
153
165
  @scf.dialect.register(key="measure_id")
154
166
  class Scf(scf.absint.Methods):
155
- pass
167
+
168
+ @interp.impl(scf.IfElse)
169
+ def if_else(
170
+ self,
171
+ interp_: MeasurementIDAnalysis,
172
+ frame: MeasureIDFrame,
173
+ stmt: scf.IfElse,
174
+ ):
175
+
176
+ frame.num_measures_at_stmt[stmt] = interp_.measure_count
177
+
178
+ # rest of the code taken directly from scf.absint.Methods base implementation
179
+
180
+ if isinstance(hint := stmt.cond.hints.get("const"), const.Value):
181
+ if hint.data:
182
+ return self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
183
+ else:
184
+ return self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)
185
+ then_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
186
+ else_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)
187
+
188
+ match (then_results, else_results):
189
+ case (interp.ReturnValue(then_value), interp.ReturnValue(else_value)):
190
+ return interp.ReturnValue(then_value.join(else_value))
191
+ case (interp.ReturnValue(then_value), _):
192
+ return then_results
193
+ case (_, interp.ReturnValue(else_value)):
194
+ return else_results
195
+ case _:
196
+ return interp_.join_results(then_results, else_results)
@@ -0,0 +1,6 @@
1
+ from . import stmts as stmts
2
+ from ._dialect import dialect as dialect
3
+ from ._interface import (
4
+ set_detector as set_detector,
5
+ set_observable as set_observable,
6
+ )
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
+ dialect = ir.Dialect("squin.annotate")
@@ -0,0 +1,22 @@
1
+ from typing import Any
2
+
3
+ from kirin.dialects import ilist
4
+ from kirin.lowering import wraps
5
+
6
+ from bloqade.types import MeasurementResult
7
+
8
+ from .stmts import SetDetector, SetObservable
9
+ from .types import Detector, Observable
10
+
11
+
12
+ @wraps(SetDetector)
13
+ def set_detector(
14
+ measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
15
+ coordinates: ilist.IList[float | int, Any] | list[float | int],
16
+ ) -> Detector: ...
17
+
18
+
19
+ @wraps(SetObservable)
20
+ def set_observable(
21
+ measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
22
+ ) -> Observable: ...
@@ -0,0 +1,29 @@
1
+ from kirin import ir, types as kirin_types, lowering
2
+ from kirin.decl import info, statement
3
+ from kirin.dialects import ilist
4
+
5
+ from bloqade.types import MeasurementResultType
6
+ from bloqade.annotate.types import DetectorType, ObservableType
7
+
8
+ from ._dialect import dialect
9
+
10
+
11
+ @statement
12
+ class ConsumesMeasurementResults(ir.Statement):
13
+ traits = frozenset({lowering.FromPythonCall()})
14
+ measurements: ir.SSAValue = info.argument(
15
+ ilist.IListType[MeasurementResultType, kirin_types.Any]
16
+ )
17
+
18
+
19
+ @statement(dialect=dialect)
20
+ class SetDetector(ConsumesMeasurementResults):
21
+ coordinates: ir.SSAValue = info.argument(
22
+ type=ilist.IListType[kirin_types.Int | kirin_types.Float, kirin_types.Any]
23
+ )
24
+ result: ir.ResultValue = info.result(DetectorType)
25
+
26
+
27
+ @statement(dialect=dialect)
28
+ class SetObservable(ConsumesMeasurementResults):
29
+ result: ir.ResultValue = info.result(ObservableType)
@@ -0,0 +1,13 @@
1
+ from kirin import types
2
+
3
+
4
+ class Detector:
5
+ pass
6
+
7
+
8
+ class Observable:
9
+ pass
10
+
11
+
12
+ DetectorType = types.PyClass(Detector)
13
+ ObservableType = types.PyClass(Observable)