bloqade-circuit 0.6.4__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191) hide show
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/analysis.py +18 -20
  7. bloqade/analysis/measure_id/impls.py +31 -29
  8. bloqade/annotate/__init__.py +6 -0
  9. bloqade/annotate/_dialect.py +3 -0
  10. bloqade/annotate/_interface.py +22 -0
  11. bloqade/annotate/stmts.py +29 -0
  12. bloqade/annotate/types.py +13 -0
  13. bloqade/cirq_utils/__init__.py +4 -2
  14. bloqade/cirq_utils/emit/__init__.py +3 -0
  15. bloqade/cirq_utils/emit/base.py +246 -0
  16. bloqade/cirq_utils/emit/gate.py +104 -0
  17. bloqade/cirq_utils/emit/noise.py +90 -0
  18. bloqade/cirq_utils/emit/qubit.py +35 -0
  19. bloqade/cirq_utils/lowering.py +660 -0
  20. bloqade/cirq_utils/noise/__init__.py +0 -2
  21. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  22. bloqade/cirq_utils/noise/model.py +151 -191
  23. bloqade/cirq_utils/noise/transform.py +2 -2
  24. bloqade/cirq_utils/parallelize.py +9 -6
  25. bloqade/gemini/__init__.py +1 -0
  26. bloqade/gemini/analysis/__init__.py +3 -0
  27. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  28. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  29. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  30. bloqade/gemini/groups.py +67 -0
  31. bloqade/native/__init__.py +23 -0
  32. bloqade/native/_prelude.py +45 -0
  33. bloqade/native/dialects/__init__.py +0 -0
  34. bloqade/native/dialects/gate/__init__.py +2 -0
  35. bloqade/native/dialects/gate/_dialect.py +3 -0
  36. bloqade/native/dialects/gate/_interface.py +32 -0
  37. bloqade/native/dialects/gate/stmts.py +31 -0
  38. bloqade/native/stdlib/__init__.py +0 -0
  39. bloqade/native/stdlib/broadcast.py +246 -0
  40. bloqade/native/stdlib/simple.py +220 -0
  41. bloqade/native/upstream/__init__.py +4 -0
  42. bloqade/native/upstream/squin2native.py +79 -0
  43. bloqade/pyqrack/__init__.py +2 -2
  44. bloqade/pyqrack/base.py +7 -1
  45. bloqade/pyqrack/device.py +192 -18
  46. bloqade/pyqrack/native.py +49 -0
  47. bloqade/pyqrack/reg.py +6 -6
  48. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  49. bloqade/pyqrack/squin/gate/gate.py +136 -0
  50. bloqade/pyqrack/squin/noise/native.py +120 -54
  51. bloqade/pyqrack/squin/qubit.py +39 -36
  52. bloqade/pyqrack/target.py +5 -4
  53. bloqade/pyqrack/task.py +114 -7
  54. bloqade/qasm2/_qasm_loading.py +3 -3
  55. bloqade/qasm2/dialects/core/address.py +21 -12
  56. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  57. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  58. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  59. bloqade/qasm2/dialects/noise/model.py +2 -1
  60. bloqade/qasm2/emit/base.py +16 -11
  61. bloqade/qasm2/emit/gate.py +11 -8
  62. bloqade/qasm2/emit/main.py +103 -3
  63. bloqade/qasm2/emit/target.py +9 -5
  64. bloqade/qasm2/groups.py +3 -2
  65. bloqade/qasm2/parse/lowering.py +0 -1
  66. bloqade/qasm2/passes/fold.py +14 -73
  67. bloqade/qasm2/passes/glob.py +2 -2
  68. bloqade/qasm2/passes/noise.py +1 -1
  69. bloqade/qasm2/passes/parallel.py +7 -5
  70. bloqade/qasm2/rewrite/__init__.py +0 -1
  71. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  72. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  73. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  74. bloqade/qasm2/rewrite/register.py +2 -2
  75. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  76. bloqade/qbraid/lowering.py +1 -0
  77. bloqade/qbraid/schema.py +2 -2
  78. bloqade/qubit/__init__.py +12 -0
  79. bloqade/qubit/_dialect.py +3 -0
  80. bloqade/qubit/_interface.py +49 -0
  81. bloqade/qubit/_prelude.py +45 -0
  82. bloqade/qubit/analysis/__init__.py +1 -0
  83. bloqade/qubit/analysis/address_impl.py +40 -0
  84. bloqade/qubit/stdlib/__init__.py +2 -0
  85. bloqade/qubit/stdlib/_new.py +34 -0
  86. bloqade/qubit/stdlib/broadcast.py +62 -0
  87. bloqade/qubit/stdlib/simple.py +59 -0
  88. bloqade/qubit/stmts.py +60 -0
  89. bloqade/rewrite/passes/__init__.py +6 -0
  90. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  91. bloqade/rewrite/passes/callgraph.py +116 -0
  92. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  93. bloqade/rewrite/rules/split_ifs.py +18 -1
  94. bloqade/squin/__init__.py +47 -14
  95. bloqade/squin/analysis/__init__.py +0 -1
  96. bloqade/squin/analysis/schedule.py +10 -11
  97. bloqade/squin/gate/__init__.py +2 -0
  98. bloqade/squin/gate/_dialect.py +3 -0
  99. bloqade/squin/gate/_interface.py +98 -0
  100. bloqade/squin/gate/stmts.py +125 -0
  101. bloqade/squin/groups.py +5 -22
  102. bloqade/squin/noise/__init__.py +1 -10
  103. bloqade/squin/noise/_dialect.py +1 -1
  104. bloqade/squin/noise/_interface.py +45 -0
  105. bloqade/squin/noise/stmts.py +66 -28
  106. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  107. bloqade/squin/rewrite/__init__.py +0 -2
  108. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  109. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  110. bloqade/squin/stdlib/__init__.py +0 -0
  111. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  112. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  113. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  114. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  115. bloqade/squin/stdlib/simple/__init__.py +33 -0
  116. bloqade/squin/stdlib/simple/gate.py +242 -0
  117. bloqade/squin/stdlib/simple/noise.py +126 -0
  118. bloqade/stim/__init__.py +1 -0
  119. bloqade/stim/_wrappers.py +6 -0
  120. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  121. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  122. bloqade/stim/dialects/gate/emit.py +9 -10
  123. bloqade/stim/dialects/noise/emit.py +17 -13
  124. bloqade/stim/dialects/noise/stmts.py +5 -3
  125. bloqade/stim/emit/__init__.py +1 -0
  126. bloqade/stim/emit/impls.py +16 -0
  127. bloqade/stim/emit/stim_str.py +48 -31
  128. bloqade/stim/groups.py +12 -2
  129. bloqade/stim/parse/lowering.py +14 -17
  130. bloqade/stim/passes/__init__.py +0 -2
  131. bloqade/stim/passes/flatten.py +26 -0
  132. bloqade/stim/passes/simplify_ifs.py +6 -1
  133. bloqade/stim/passes/squin_to_stim.py +9 -84
  134. bloqade/stim/rewrite/__init__.py +2 -4
  135. bloqade/stim/rewrite/get_record_util.py +24 -0
  136. bloqade/stim/rewrite/ifs_to_stim.py +24 -25
  137. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  138. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  139. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  140. bloqade/stim/rewrite/squin_measure.py +9 -18
  141. bloqade/stim/rewrite/squin_noise.py +134 -108
  142. bloqade/stim/rewrite/util.py +5 -192
  143. bloqade/test_utils.py +1 -1
  144. bloqade/types.py +10 -0
  145. bloqade/validation/__init__.py +2 -0
  146. bloqade/validation/analysis/__init__.py +5 -0
  147. bloqade/validation/analysis/analysis.py +41 -0
  148. bloqade/validation/analysis/lattice.py +58 -0
  149. bloqade/validation/kernel_validation.py +77 -0
  150. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  151. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  152. bloqade/pyqrack/squin/op.py +0 -180
  153. bloqade/pyqrack/squin/runtime.py +0 -535
  154. bloqade/pyqrack/squin/wire.py +0 -51
  155. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  156. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  157. bloqade/squin/_typeinfer.py +0 -20
  158. bloqade/squin/analysis/address_impl.py +0 -71
  159. bloqade/squin/analysis/nsites/__init__.py +0 -9
  160. bloqade/squin/analysis/nsites/analysis.py +0 -50
  161. bloqade/squin/analysis/nsites/impls.py +0 -92
  162. bloqade/squin/analysis/nsites/lattice.py +0 -49
  163. bloqade/squin/cirq/__init__.py +0 -280
  164. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  165. bloqade/squin/cirq/emit/noise.py +0 -49
  166. bloqade/squin/cirq/emit/op.py +0 -125
  167. bloqade/squin/cirq/emit/qubit.py +0 -60
  168. bloqade/squin/cirq/emit/runtime.py +0 -242
  169. bloqade/squin/cirq/lowering.py +0 -440
  170. bloqade/squin/lowering.py +0 -54
  171. bloqade/squin/noise/_wrapper.py +0 -40
  172. bloqade/squin/noise/rewrite.py +0 -111
  173. bloqade/squin/op/__init__.py +0 -41
  174. bloqade/squin/op/_dialect.py +0 -3
  175. bloqade/squin/op/_wrapper.py +0 -121
  176. bloqade/squin/op/number.py +0 -5
  177. bloqade/squin/op/rewrite.py +0 -46
  178. bloqade/squin/op/stdlib.py +0 -62
  179. bloqade/squin/op/stmts.py +0 -276
  180. bloqade/squin/op/traits.py +0 -43
  181. bloqade/squin/op/types.py +0 -26
  182. bloqade/squin/qubit.py +0 -184
  183. bloqade/squin/rewrite/canonicalize.py +0 -60
  184. bloqade/squin/rewrite/desugar.py +0 -124
  185. bloqade/squin/types.py +0 -8
  186. bloqade/squin/wire.py +0 -201
  187. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  188. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  189. bloqade_circuit-0.6.4.dist-info/RECORD +0 -234
  190. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  191. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,12 +1,18 @@
1
1
  from typing import Sequence, final
2
2
  from dataclasses import dataclass
3
3
 
4
+ from kirin import ir, types
4
5
  from kirin.lattice import (
5
6
  SingletonMeta,
6
7
  BoundedLattice,
7
8
  SimpleJoinMixin,
8
9
  SimpleMeetMixin,
9
10
  )
11
+ from kirin.analysis import const
12
+ from kirin.dialects import ilist
13
+ from kirin.ir.attrs.abc import LatticeAttributeMeta
14
+
15
+ from bloqade.types import QubitType
10
16
 
11
17
 
12
18
  @dataclass
@@ -18,54 +24,94 @@ class Address(
18
24
 
19
25
  @classmethod
20
26
  def bottom(cls) -> "Address":
21
- return NotQubit()
27
+ return Bottom()
22
28
 
23
29
  @classmethod
24
30
  def top(cls) -> "Address":
25
- return AnyAddress()
31
+ return Unknown()
32
+
33
+ @classmethod
34
+ def from_type(cls, typ: types.TypeAttribute):
35
+ if typ.is_subseteq(ilist.IListType[QubitType]):
36
+ return UnknownReg()
37
+ elif typ.is_subseteq(QubitType):
38
+ return UnknownQubit()
39
+ else:
40
+ return Unknown()
26
41
 
27
42
 
28
43
  @final
29
- @dataclass
30
- class NotQubit(Address, metaclass=SingletonMeta):
44
+ class Bottom(Address, metaclass=SingletonMeta):
45
+ """Error during interpretation"""
31
46
 
32
47
  def is_subseteq(self, other: Address) -> bool:
33
48
  return True
34
49
 
35
50
 
36
51
  @final
37
- @dataclass
38
- class AnyAddress(Address, metaclass=SingletonMeta):
52
+ class Unknown(Address, metaclass=SingletonMeta):
53
+ """Can't determine if it is an address or constant."""
39
54
 
40
55
  def is_subseteq(self, other: Address) -> bool:
41
- return isinstance(other, AnyAddress)
56
+ return isinstance(other, Unknown)
42
57
 
43
58
 
44
59
  @final
45
60
  @dataclass
46
- class AddressTuple(Address):
47
- data: tuple[Address, ...]
61
+ class ConstResult(Address):
62
+ """Stores a constant prop result in the lattice"""
63
+
64
+ result: const.Result
48
65
 
49
66
  def is_subseteq(self, other: Address) -> bool:
50
- if isinstance(other, AddressTuple):
51
- return all(a.is_subseteq(b) for a, b in zip(self.data, other.data))
52
- return False
67
+ return isinstance(other, ConstResult) and self.result.is_subseteq(other.result)
68
+
69
+
70
+ class QubitLike(Address):
71
+ def join(self, other: Address):
72
+ if isinstance(other, QubitLike):
73
+ return super().join(other)
74
+ return self.bottom()
75
+
76
+ def meet(self, other: Address):
77
+ if isinstance(other, QubitLike):
78
+ return super().meet(other)
79
+ return self.bottom()
53
80
 
54
81
 
55
82
  @final
56
- @dataclass
57
- class AddressReg(Address):
58
- data: Sequence[int]
83
+ class UnknownQubit(QubitLike, metaclass=SingletonMeta):
84
+ """A lattice element representing a single qubit with an unknown address."""
59
85
 
60
86
  def is_subseteq(self, other: Address) -> bool:
61
- if isinstance(other, AddressReg):
62
- return self.data == other.data
63
- return False
87
+ return isinstance(other, QubitLike)
88
+
89
+
90
+ class RegisterLike(Address):
91
+ def join(self, other: Address):
92
+ if isinstance(other, RegisterLike):
93
+ return super().join(other)
94
+ return self.bottom()
95
+
96
+ def meet(self, other: Address):
97
+ if isinstance(other, RegisterLike):
98
+ return super().meet(other)
99
+ return self.bottom()
100
+
101
+
102
+ @final
103
+ class UnknownReg(RegisterLike, metaclass=SingletonMeta):
104
+ """A lattice element representing a container of qubits with unknown indices."""
105
+
106
+ def is_subseteq(self, other: Address) -> bool:
107
+ return isinstance(other, RegisterLike)
64
108
 
65
109
 
66
110
  @final
67
111
  @dataclass
68
- class AddressQubit(Address):
112
+ class AddressQubit(QubitLike):
113
+ """A lattice element representing a single qubit with a known address."""
114
+
69
115
  data: int
70
116
 
71
117
  def is_subseteq(self, other: Address) -> bool:
@@ -76,10 +122,149 @@ class AddressQubit(Address):
76
122
 
77
123
  @final
78
124
  @dataclass
79
- class AddressWire(Address):
80
- origin_qubit: AddressQubit
125
+ class AddressReg(RegisterLike):
126
+ """A lattice element representing a container of qubits with known indices."""
127
+
128
+ data: Sequence[int]
81
129
 
82
130
  def is_subseteq(self, other: Address) -> bool:
83
- if isinstance(other, AddressWire):
84
- return self.origin_qubit == other.origin_qubit
85
- return False
131
+ return isinstance(other, AddressReg) and self.data == other.data
132
+
133
+ @property
134
+ def qubits(self) -> tuple[AddressQubit, ...]:
135
+ return tuple(AddressQubit(i) for i in self.data)
136
+
137
+
138
+ @final
139
+ @dataclass
140
+ class PartialLambda(Address):
141
+ """Represents a partially known lambda function"""
142
+
143
+ argnames: list[str]
144
+ code: ir.Statement
145
+ captured: tuple[Address, ...]
146
+
147
+ def join(self, other: Address) -> Address:
148
+ if other is other.bottom():
149
+ return self
150
+
151
+ if not isinstance(other, PartialLambda):
152
+ return self.top().join(other) # widen self
153
+
154
+ if self.code is not other.code:
155
+ return self.top() # lambda stmt is pure
156
+
157
+ if len(self.captured) != len(other.captured):
158
+ return self.bottom() # err
159
+
160
+ return PartialLambda(
161
+ self.argnames,
162
+ self.code,
163
+ tuple(x.join(y) for x, y in zip(self.captured, other.captured)),
164
+ )
165
+
166
+ def meet(self, other: Address) -> Address:
167
+ if not isinstance(other, PartialLambda):
168
+ return self.top().meet(other)
169
+
170
+ if self.code is not other.code:
171
+ return self.bottom()
172
+
173
+ if len(self.captured) != len(other.captured):
174
+ return self.top()
175
+
176
+ return PartialLambda(
177
+ self.argnames,
178
+ self.code,
179
+ tuple(x.meet(y) for x, y in zip(self.captured, other.captured)),
180
+ )
181
+
182
+ def is_subseteq(self, other: Address) -> bool:
183
+ return (
184
+ isinstance(other, PartialLambda)
185
+ and self.code is other.code
186
+ and self.argnames == other.argnames
187
+ and len(self.captured) == len(other.captured)
188
+ and all(
189
+ self_ele.is_subseteq(other_ele)
190
+ for self_ele, other_ele in zip(self.captured, other.captured)
191
+ )
192
+ )
193
+
194
+
195
+ @dataclass
196
+ class StaticContainer(Address):
197
+ """A lattice element representing the results of any static container, e. g. ilist or tuple."""
198
+
199
+ data: tuple[Address, ...]
200
+
201
+ @classmethod
202
+ def new(cls, data: tuple[Address, ...]):
203
+ return cls(data)
204
+
205
+ def join(self, other: "Address") -> "Address":
206
+ if isinstance(other, type(self)) and len(self.data) == len(other.data):
207
+ return self.new(tuple(x.join(y) for x, y in zip(self.data, other.data)))
208
+ return self.top()
209
+
210
+ def meet(self, other: "Address") -> "Address":
211
+ if isinstance(other, type(self)) and len(self.data) == len(other.data):
212
+ return self.new(tuple(x.meet(y) for x, y in zip(self.data, other.data)))
213
+ return self.bottom()
214
+
215
+ def is_subseteq(self, other: "Address") -> bool:
216
+ return (
217
+ isinstance(other, type(self))
218
+ and len(self.data) == len(other.data)
219
+ and all(x.is_subseteq(y) for x, y in zip(self.data, other.data))
220
+ )
221
+
222
+
223
+ class PartialIListMeta(LatticeAttributeMeta):
224
+ """This metaclass assures that PartialILists of ConstResults or AddressQubits are canonicalized
225
+ to a single ConstResult or AddressReg respectively.
226
+
227
+ because AddressReg is a specialization of PartialIList, being a container of pure qubit
228
+ addresses. For Operations that act in generic containers (e.g., ilist.ForEach),
229
+ AddressReg is treated as PartialIList but for other types of analysis it is often
230
+ useful to distinguish between a generic IList and a pure qubit address list.
231
+
232
+ Inside the method tables the `GetValuesMixin` implements a method that effectively
233
+ undoes this canonicalization.
234
+
235
+ """
236
+
237
+ def __call__(cls, data: tuple[Address, ...]):
238
+ # TODO: when constant prop has PartialIList, make sure to canonicalize here.
239
+ if types.is_tuple_of(data, ConstResult) and types.is_tuple_of(
240
+ all_constants := tuple(ele.result for ele in data), const.Value
241
+ ):
242
+ # all constants, create constant list
243
+ return ConstResult(
244
+ const.Value(ilist.IList([ele.data for ele in all_constants]))
245
+ )
246
+ elif types.is_tuple_of(data, AddressQubit):
247
+ # all qubits create qubit register
248
+ return AddressReg(tuple(ele.data for ele in data))
249
+ else:
250
+ return super().__call__(data)
251
+
252
+
253
+ @final
254
+ class PartialIList(StaticContainer, metaclass=PartialIListMeta):
255
+ """A lattice element representing a partially known ilist."""
256
+
257
+
258
+ class PartialTupleMeta(LatticeAttributeMeta):
259
+ """This metaclass assures that PartialTuples of ConstResults are canonicalized to a single ConstResult."""
260
+
261
+ def __call__(cls, data: tuple[Address, ...]):
262
+ if not types.is_tuple_of(data, ConstResult):
263
+ return super().__call__(data)
264
+
265
+ return ConstResult(const.PartialTuple(tuple(ele.result for ele in data)))
266
+
267
+
268
+ @final
269
+ class PartialTuple(StaticContainer, metaclass=PartialTupleMeta):
270
+ """A lattice element representing a partially known tuple."""
@@ -4,10 +4,9 @@ from dataclasses import field
4
4
  from kirin import ir
5
5
  from kirin.lattice import EmptyLattice
6
6
  from kirin.analysis import Forward
7
- from kirin.interp.value import Successor
8
7
  from kirin.analysis.forward import ForwardFrame
9
8
 
10
- from ..address import AddressAnalysis
9
+ from ..address import Address, AddressAnalysis
11
10
 
12
11
 
13
12
  class FidelityAnalysis(Forward):
@@ -48,16 +47,12 @@ class FidelityAnalysis(Forward):
48
47
  The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered.
49
48
  """
50
49
 
51
- _current_gate_fidelity: float = field(init=False)
52
-
53
50
  atom_survival_probability: list[float] = field(init=False)
54
51
  """
55
52
  The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register.
56
53
  """
57
54
 
58
- _current_atom_survival_probability: list[float] = field(init=False)
59
-
60
- addr_frame: ForwardFrame = field(init=False)
55
+ addr_frame: ForwardFrame[Address] = field(init=False)
61
56
 
62
57
  def initialize(self):
63
58
  super().initialize()
@@ -67,28 +62,21 @@ class FidelityAnalysis(Forward):
67
62
  ]
68
63
  return self
69
64
 
70
- def posthook_succ(self, frame: ForwardFrame, succ: Successor):
71
- self.gate_fidelity *= self._current_gate_fidelity
72
- for i, _current_survival in enumerate(self._current_atom_survival_probability):
73
- self.atom_survival_probability[i] *= _current_survival
74
-
75
- def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
65
+ def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
76
66
  # NOTE: default is to conserve fidelity, so do nothing here
77
67
  return
78
68
 
79
- def run_method(self, method: ir.Method, args: tuple[EmptyLattice, ...]):
80
- return self.run_callable(method.code, (self.lattice.bottom(),) + args)
69
+ def run(self, method: ir.Method, *args, **kwargs) -> tuple[ForwardFrame, Any]:
70
+ self._run_address_analysis(method)
71
+ return super().run(method, *args, **kwargs)
81
72
 
82
- def run_analysis(
83
- self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True
84
- ) -> tuple[ForwardFrame, Any]:
85
- self._run_address_analysis(method, no_raise=no_raise)
86
- return super().run_analysis(method, args, no_raise=no_raise)
87
-
88
- def _run_address_analysis(self, method: ir.Method, no_raise: bool):
73
+ def _run_address_analysis(self, method: ir.Method):
89
74
  addr_analysis = AddressAnalysis(self.dialects)
90
- addr_frame, _ = addr_analysis.run_analysis(method=method, no_raise=no_raise)
75
+ addr_frame, _ = addr_analysis.run(method=method)
91
76
  self.addr_frame = addr_frame
92
77
 
93
78
  # NOTE: make sure we have as many probabilities as we have addresses
94
79
  self.atom_survival_probability = [1.0] * addr_analysis.qubit_count
80
+
81
+ def method_self(self, method: ir.Method) -> EmptyLattice:
82
+ return self.lattice.bottom()
@@ -1,7 +1,7 @@
1
1
  from typing import TypeVar
2
2
  from dataclasses import field, dataclass
3
3
 
4
- from kirin import ir, interp
4
+ from kirin import ir
5
5
  from kirin.analysis import ForwardExtra, const
6
6
  from kirin.analysis.forward import ForwardFrame
7
7
 
@@ -22,35 +22,33 @@ class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]):
22
22
  measure_count = 0
23
23
 
24
24
  def initialize_frame(
25
- self, code: ir.Statement, *, has_parent_access: bool = False
25
+ self, node: ir.Statement, *, has_parent_access: bool = False
26
26
  ) -> MeasureIDFrame:
27
- return MeasureIDFrame(code, has_parent_access=has_parent_access)
27
+ return MeasureIDFrame(node, has_parent_access=has_parent_access)
28
28
 
29
29
  # Still default to bottom,
30
30
  # but let constants return the softer "NoMeasureId" type from impl
31
- def eval_stmt_fallback(
32
- self, frame: ForwardFrame[MeasureId], stmt: ir.Statement
31
+ def eval_fallback(
32
+ self, frame: ForwardFrame[MeasureId], node: ir.Statement
33
33
  ) -> tuple[MeasureId, ...]:
34
- return tuple(NotMeasureId() for _ in stmt.results)
35
-
36
- def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]):
37
- # NOTE: we do not support dynamic calls here, thus no need to propagate method object
38
- return self.run_callable(method.code, (self.lattice.bottom(),) + args)
39
-
40
- T = TypeVar("T")
34
+ return tuple(NotMeasureId() for _ in node.results)
41
35
 
42
36
  # Xiu-zhe (Roger) Luo came up with this in the address analysis,
43
- # reused here for convenience
37
+ # reused here for convenience (now modified to be a bit more graceful)
44
38
  # TODO: Remove this function once upgrade to kirin 0.18 happens,
45
39
  # method is built-in to interpreter then
46
- def get_const_value(self, input_type: type[T], value: ir.SSAValue) -> T:
40
+
41
+ T = TypeVar("T")
42
+
43
+ def get_const_value(
44
+ self, input_type: type[T] | tuple[type[T], ...], value: ir.SSAValue
45
+ ) -> type[T] | None:
47
46
  if isinstance(hint := value.hints.get("const"), const.Value):
48
47
  data = hint.data
49
48
  if isinstance(data, input_type):
50
49
  return hint.data
51
- raise interp.InterpreterError(
52
- f"Expected constant value <type = {input_type}>, got {data}"
53
- )
54
- raise interp.InterpreterError(
55
- f"Expected constant value <type = {input_type}>, got {value}"
56
- )
50
+
51
+ return None
52
+
53
+ def method_self(self, method: ir.Method) -> MeasureId:
54
+ return self.lattice.bottom()
@@ -2,7 +2,7 @@ from kirin import types as kirin_types, interp
2
2
  from kirin.analysis import const
3
3
  from kirin.dialects import py, scf, func, ilist
4
4
 
5
- from bloqade.squin import wire, qubit
5
+ from bloqade import qubit, annotate
6
6
 
7
7
  from .lattice import (
8
8
  AnyMeasureId,
@@ -21,22 +21,12 @@ from .analysis import MeasureIDFrame, MeasurementIDAnalysis
21
21
  @qubit.dialect.register(key="measure_id")
22
22
  class SquinQubit(interp.MethodTable):
23
23
 
24
- @interp.impl(qubit.MeasureQubit)
25
- def measure_qubit(
26
- self,
27
- interp: MeasurementIDAnalysis,
28
- frame: interp.Frame,
29
- stmt: qubit.MeasureQubit,
30
- ):
31
- interp.measure_count += 1
32
- return (MeasureIdBool(interp.measure_count),)
33
-
34
- @interp.impl(qubit.MeasureQubitList)
24
+ @interp.impl(qubit.stmts.Measure)
35
25
  def measure_qubit_list(
36
26
  self,
37
27
  interp: MeasurementIDAnalysis,
38
28
  frame: interp.Frame,
39
- stmt: qubit.MeasureQubitList,
29
+ stmt: qubit.stmts.Measure,
40
30
  ):
41
31
 
42
32
  # try to get the length of the list
@@ -56,18 +46,18 @@ class SquinQubit(interp.MethodTable):
56
46
  return (MeasureIdTuple(data=tuple(measure_id_bools)),)
57
47
 
58
48
 
59
- @wire.dialect.register(key="measure_id")
60
- class SquinWire(interp.MethodTable):
61
-
62
- @interp.impl(wire.Measure)
63
- def measure_qubit(
49
+ @annotate.dialect.register(key="measure_id")
50
+ class Annotate(interp.MethodTable):
51
+ @interp.impl(annotate.stmts.SetObservable)
52
+ @interp.impl(annotate.stmts.SetDetector)
53
+ def consumes_measurement_results(
64
54
  self,
65
55
  interp: MeasurementIDAnalysis,
66
- frame: interp.Frame,
67
- stmt: wire.Measure,
56
+ frame: MeasureIDFrame,
57
+ stmt: annotate.stmts.SetObservable | annotate.stmts.SetDetector,
68
58
  ):
69
- interp.measure_count += 1
70
- return (MeasureIdBool(interp.measure_count),)
59
+ frame.num_measures_at_stmt[stmt] = interp.measure_count
60
+ return (NotMeasureId(),)
71
61
 
72
62
 
73
63
  @ilist.dialect.register(key="measure_id")
@@ -103,10 +93,23 @@ class PyIndexing(interp.MethodTable):
103
93
  def getitem(
104
94
  self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.GetItem
105
95
  ):
106
- idx = interp.get_const_value(int, stmt.index)
96
+
97
+ idx_or_slice = interp.get_const_value((int, slice), stmt.index)
98
+ if idx_or_slice is None:
99
+ return (InvalidMeasureId(),)
100
+
101
+ # hint = stmt.index.hints.get("const")
102
+ # if hint is None or not isinstance(hint, const.Value):
103
+ # return (InvalidMeasureId(),)
104
+
107
105
  obj = frame.get(stmt.obj)
108
106
  if isinstance(obj, MeasureIdTuple):
109
- return (obj.data[idx],)
107
+ if isinstance(idx_or_slice, slice):
108
+ return (MeasureIdTuple(data=obj.data[idx_or_slice]),)
109
+ elif isinstance(idx_or_slice, int):
110
+ return (obj.data[idx_or_slice],)
111
+ else:
112
+ return (InvalidMeasureId(),)
110
113
  # just propagate these down the line
111
114
  elif isinstance(obj, (AnyMeasureId, NotMeasureId)):
112
115
  return (obj,)
@@ -149,11 +152,10 @@ class Func(interp.MethodTable):
149
152
  def invoke(
150
153
  self, interp_: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Invoke
151
154
  ):
152
- _, ret = interp_.run_method(
153
- stmt.callee,
154
- interp_.permute_values(
155
- stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
156
- ),
155
+ _, ret = interp_.call(
156
+ stmt.callee.code,
157
+ interp_.method_self(stmt.callee),
158
+ *frame.get_values(stmt.inputs),
157
159
  )
158
160
  return (ret,)
159
161
 
@@ -0,0 +1,6 @@
1
+ from . import stmts as stmts
2
+ from ._dialect import dialect as dialect
3
+ from ._interface import (
4
+ set_detector as set_detector,
5
+ set_observable as set_observable,
6
+ )
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
+ dialect = ir.Dialect("squin.annotate")
@@ -0,0 +1,22 @@
1
+ from typing import Any
2
+
3
+ from kirin.dialects import ilist
4
+ from kirin.lowering import wraps
5
+
6
+ from bloqade.types import MeasurementResult
7
+
8
+ from .stmts import SetDetector, SetObservable
9
+ from .types import Detector, Observable
10
+
11
+
12
+ @wraps(SetDetector)
13
+ def set_detector(
14
+ measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
15
+ coordinates: ilist.IList[float | int, Any] | list[float | int],
16
+ ) -> Detector: ...
17
+
18
+
19
+ @wraps(SetObservable)
20
+ def set_observable(
21
+ measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
22
+ ) -> Observable: ...
@@ -0,0 +1,29 @@
1
+ from kirin import ir, types as kirin_types, lowering
2
+ from kirin.decl import info, statement
3
+ from kirin.dialects import ilist
4
+
5
+ from bloqade.types import MeasurementResultType
6
+ from bloqade.annotate.types import DetectorType, ObservableType
7
+
8
+ from ._dialect import dialect
9
+
10
+
11
+ @statement
12
+ class ConsumesMeasurementResults(ir.Statement):
13
+ traits = frozenset({lowering.FromPythonCall()})
14
+ measurements: ir.SSAValue = info.argument(
15
+ ilist.IListType[MeasurementResultType, kirin_types.Any]
16
+ )
17
+
18
+
19
+ @statement(dialect=dialect)
20
+ class SetDetector(ConsumesMeasurementResults):
21
+ coordinates: ir.SSAValue = info.argument(
22
+ type=ilist.IListType[kirin_types.Int | kirin_types.Float, kirin_types.Any]
23
+ )
24
+ result: ir.ResultValue = info.result(DetectorType)
25
+
26
+
27
+ @statement(dialect=dialect)
28
+ class SetObservable(ConsumesMeasurementResults):
29
+ result: ir.ResultValue = info.result(ObservableType)
@@ -0,0 +1,13 @@
1
+ from kirin import types
2
+
3
+
4
+ class Detector:
5
+ pass
6
+
7
+
8
+ class Observable:
9
+ pass
10
+
11
+
12
+ DetectorType = types.PyClass(Detector)
13
+ ObservableType = types.PyClass(Observable)
@@ -1,8 +1,10 @@
1
- from . import noise as noise
1
+ from . import emit as emit, noise as noise, lowering as lowering
2
+ from .emit import emit_circuit as emit_circuit
3
+ from .lowering import load_circuit as load_circuit
2
4
  from .parallelize import (
3
5
  transpile as transpile,
4
6
  parallelize as parallelize,
5
- no_similarity as no_similarity,
7
+ remove_tags as remove_tags,
6
8
  auto_similarity as auto_similarity,
7
9
  block_similarity as block_similarity,
8
10
  moment_similarity as moment_similarity,
@@ -0,0 +1,3 @@
1
+ # NOTE: just to register methods
2
+ from . import gate as gate, noise as noise, qubit as qubit
3
+ from .base import emit_circuit as emit_circuit