bloqade-circuit 0.6.2__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (192):
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/__init__.py +4 -1
  7. bloqade/analysis/measure_id/analysis.py +29 -20
  8. bloqade/analysis/measure_id/impls.py +72 -31
  9. bloqade/annotate/__init__.py +6 -0
  10. bloqade/annotate/_dialect.py +3 -0
  11. bloqade/annotate/_interface.py +22 -0
  12. bloqade/annotate/stmts.py +29 -0
  13. bloqade/annotate/types.py +13 -0
  14. bloqade/cirq_utils/__init__.py +4 -2
  15. bloqade/cirq_utils/emit/__init__.py +3 -0
  16. bloqade/cirq_utils/emit/base.py +246 -0
  17. bloqade/cirq_utils/emit/gate.py +104 -0
  18. bloqade/cirq_utils/emit/noise.py +90 -0
  19. bloqade/cirq_utils/emit/qubit.py +35 -0
  20. bloqade/cirq_utils/lowering.py +660 -0
  21. bloqade/cirq_utils/noise/__init__.py +0 -2
  22. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  23. bloqade/cirq_utils/noise/model.py +151 -191
  24. bloqade/cirq_utils/noise/transform.py +2 -2
  25. bloqade/cirq_utils/parallelize.py +9 -6
  26. bloqade/gemini/__init__.py +1 -0
  27. bloqade/gemini/analysis/__init__.py +3 -0
  28. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  29. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  30. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  31. bloqade/gemini/groups.py +67 -0
  32. bloqade/native/__init__.py +23 -0
  33. bloqade/native/_prelude.py +45 -0
  34. bloqade/native/dialects/__init__.py +0 -0
  35. bloqade/native/dialects/gate/__init__.py +2 -0
  36. bloqade/native/dialects/gate/_dialect.py +3 -0
  37. bloqade/native/dialects/gate/_interface.py +32 -0
  38. bloqade/native/dialects/gate/stmts.py +31 -0
  39. bloqade/native/stdlib/__init__.py +0 -0
  40. bloqade/native/stdlib/broadcast.py +246 -0
  41. bloqade/native/stdlib/simple.py +220 -0
  42. bloqade/native/upstream/__init__.py +4 -0
  43. bloqade/native/upstream/squin2native.py +79 -0
  44. bloqade/pyqrack/__init__.py +2 -2
  45. bloqade/pyqrack/base.py +7 -1
  46. bloqade/pyqrack/device.py +190 -4
  47. bloqade/pyqrack/native.py +49 -0
  48. bloqade/pyqrack/reg.py +6 -6
  49. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  50. bloqade/pyqrack/squin/gate/gate.py +136 -0
  51. bloqade/pyqrack/squin/noise/native.py +120 -54
  52. bloqade/pyqrack/squin/qubit.py +39 -36
  53. bloqade/pyqrack/target.py +5 -4
  54. bloqade/pyqrack/task.py +114 -7
  55. bloqade/qasm2/_qasm_loading.py +3 -3
  56. bloqade/qasm2/dialects/core/address.py +21 -12
  57. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  58. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  59. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  60. bloqade/qasm2/dialects/noise/model.py +2 -1
  61. bloqade/qasm2/emit/base.py +16 -11
  62. bloqade/qasm2/emit/gate.py +11 -8
  63. bloqade/qasm2/emit/main.py +103 -3
  64. bloqade/qasm2/emit/target.py +9 -5
  65. bloqade/qasm2/groups.py +3 -2
  66. bloqade/qasm2/parse/lowering.py +0 -1
  67. bloqade/qasm2/passes/fold.py +14 -73
  68. bloqade/qasm2/passes/glob.py +2 -2
  69. bloqade/qasm2/passes/noise.py +1 -1
  70. bloqade/qasm2/passes/parallel.py +7 -5
  71. bloqade/qasm2/rewrite/__init__.py +0 -1
  72. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  73. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  74. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  75. bloqade/qasm2/rewrite/register.py +2 -2
  76. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  77. bloqade/qbraid/lowering.py +1 -0
  78. bloqade/qbraid/schema.py +2 -2
  79. bloqade/qubit/__init__.py +12 -0
  80. bloqade/qubit/_dialect.py +3 -0
  81. bloqade/qubit/_interface.py +49 -0
  82. bloqade/qubit/_prelude.py +45 -0
  83. bloqade/qubit/analysis/__init__.py +1 -0
  84. bloqade/qubit/analysis/address_impl.py +40 -0
  85. bloqade/qubit/stdlib/__init__.py +2 -0
  86. bloqade/qubit/stdlib/_new.py +34 -0
  87. bloqade/qubit/stdlib/broadcast.py +62 -0
  88. bloqade/qubit/stdlib/simple.py +59 -0
  89. bloqade/qubit/stmts.py +60 -0
  90. bloqade/rewrite/passes/__init__.py +6 -0
  91. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  92. bloqade/rewrite/passes/callgraph.py +116 -0
  93. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  94. bloqade/rewrite/rules/split_ifs.py +18 -1
  95. bloqade/squin/__init__.py +47 -14
  96. bloqade/squin/analysis/__init__.py +0 -1
  97. bloqade/squin/analysis/schedule.py +10 -11
  98. bloqade/squin/gate/__init__.py +2 -0
  99. bloqade/squin/gate/_dialect.py +3 -0
  100. bloqade/squin/gate/_interface.py +98 -0
  101. bloqade/squin/gate/stmts.py +125 -0
  102. bloqade/squin/groups.py +5 -22
  103. bloqade/squin/noise/__init__.py +1 -10
  104. bloqade/squin/noise/_dialect.py +1 -1
  105. bloqade/squin/noise/_interface.py +45 -0
  106. bloqade/squin/noise/stmts.py +66 -28
  107. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  108. bloqade/squin/rewrite/__init__.py +0 -2
  109. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  110. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  111. bloqade/squin/stdlib/__init__.py +0 -0
  112. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  113. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  114. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  115. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  116. bloqade/squin/stdlib/simple/__init__.py +33 -0
  117. bloqade/squin/stdlib/simple/gate.py +242 -0
  118. bloqade/squin/stdlib/simple/noise.py +126 -0
  119. bloqade/stim/__init__.py +1 -0
  120. bloqade/stim/_wrappers.py +6 -0
  121. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  122. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  123. bloqade/stim/dialects/gate/emit.py +9 -10
  124. bloqade/stim/dialects/noise/emit.py +17 -13
  125. bloqade/stim/dialects/noise/stmts.py +5 -3
  126. bloqade/stim/emit/__init__.py +1 -0
  127. bloqade/stim/emit/impls.py +16 -0
  128. bloqade/stim/emit/stim_str.py +48 -31
  129. bloqade/stim/groups.py +12 -2
  130. bloqade/stim/parse/lowering.py +14 -17
  131. bloqade/stim/passes/__init__.py +3 -1
  132. bloqade/stim/passes/flatten.py +26 -0
  133. bloqade/stim/passes/simplify_ifs.py +16 -2
  134. bloqade/stim/passes/squin_to_stim.py +18 -60
  135. bloqade/stim/rewrite/__init__.py +3 -4
  136. bloqade/stim/rewrite/get_record_util.py +24 -0
  137. bloqade/stim/rewrite/ifs_to_stim.py +29 -31
  138. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  139. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  140. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  141. bloqade/stim/rewrite/squin_measure.py +11 -79
  142. bloqade/stim/rewrite/squin_noise.py +134 -108
  143. bloqade/stim/rewrite/util.py +5 -192
  144. bloqade/test_utils.py +1 -1
  145. bloqade/types.py +10 -0
  146. bloqade/validation/__init__.py +2 -0
  147. bloqade/validation/analysis/__init__.py +5 -0
  148. bloqade/validation/analysis/analysis.py +41 -0
  149. bloqade/validation/analysis/lattice.py +58 -0
  150. bloqade/validation/kernel_validation.py +77 -0
  151. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  152. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  153. bloqade/pyqrack/squin/op.py +0 -166
  154. bloqade/pyqrack/squin/runtime.py +0 -535
  155. bloqade/pyqrack/squin/wire.py +0 -51
  156. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  157. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  158. bloqade/squin/_typeinfer.py +0 -20
  159. bloqade/squin/analysis/address_impl.py +0 -71
  160. bloqade/squin/analysis/nsites/__init__.py +0 -9
  161. bloqade/squin/analysis/nsites/analysis.py +0 -50
  162. bloqade/squin/analysis/nsites/impls.py +0 -92
  163. bloqade/squin/analysis/nsites/lattice.py +0 -49
  164. bloqade/squin/cirq/__init__.py +0 -265
  165. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  166. bloqade/squin/cirq/emit/noise.py +0 -49
  167. bloqade/squin/cirq/emit/op.py +0 -125
  168. bloqade/squin/cirq/emit/qubit.py +0 -60
  169. bloqade/squin/cirq/emit/runtime.py +0 -242
  170. bloqade/squin/cirq/lowering.py +0 -440
  171. bloqade/squin/lowering.py +0 -54
  172. bloqade/squin/noise/_wrapper.py +0 -40
  173. bloqade/squin/noise/rewrite.py +0 -111
  174. bloqade/squin/op/__init__.py +0 -41
  175. bloqade/squin/op/_dialect.py +0 -3
  176. bloqade/squin/op/_wrapper.py +0 -121
  177. bloqade/squin/op/number.py +0 -5
  178. bloqade/squin/op/rewrite.py +0 -46
  179. bloqade/squin/op/stdlib.py +0 -62
  180. bloqade/squin/op/stmts.py +0 -276
  181. bloqade/squin/op/traits.py +0 -43
  182. bloqade/squin/op/types.py +0 -26
  183. bloqade/squin/qubit.py +0 -184
  184. bloqade/squin/rewrite/canonicalize.py +0 -60
  185. bloqade/squin/rewrite/desugar.py +0 -124
  186. bloqade/squin/types.py +0 -8
  187. bloqade/squin/wire.py +0 -201
  188. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  189. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  190. bloqade_circuit-0.6.2.dist-info/RECORD +0 -234
  191. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  192. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,72 +1,138 @@
1
- import random
2
- import typing
3
- from dataclasses import dataclass
4
-
5
1
  from kirin import interp
6
- from kirin.dialects import ilist
7
2
 
8
- from bloqade.pyqrack import QubitState, PyQrackQubit, PyQrackInterpreter
9
- from bloqade.squin.noise.stmts import QubitLoss, StochasticUnitaryChannel
3
+ from bloqade.pyqrack import PyQrackQubit, PyQrackInterpreter
4
+ from bloqade.squin.noise.stmts import (
5
+ QubitLoss,
6
+ Depolarize,
7
+ Depolarize2,
8
+ CorrelatedQubitLoss,
9
+ TwoQubitPauliChannel,
10
+ SingleQubitPauliChannel,
11
+ )
10
12
  from bloqade.squin.noise._dialect import dialect as squin_noise_dialect
11
13
 
12
- from ..runtime import OperatorRuntimeABC
13
-
14
-
15
- @dataclass(frozen=True)
16
- class StochasticUnitaryChannelRuntime(OperatorRuntimeABC):
17
- operators: ilist.IList[OperatorRuntimeABC, typing.Any]
18
- probabilities: ilist.IList[float, typing.Any]
19
-
20
- @property
21
- def n_sites(self) -> int:
22
- n = self.operators[0].n_sites
23
- for op in self.operators[1:]:
24
- assert (
25
- op.n_sites == n
26
- ), "Encountered a stochastic unitary channel with operators of different size!"
27
- return n
28
-
29
- def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
30
- # NOTE: probabilities don't necessarily sum to 1; could be no noise event should occur
31
- p_no_op = 1 - sum(self.probabilities)
32
- if random.uniform(0.0, 1.0) < p_no_op:
33
- return
34
-
35
- selected_ops = random.choices(self.operators, weights=self.probabilities)
36
- for op in selected_ops:
37
- op.apply(*qubits, adjoint=adjoint)
38
-
39
14
 
40
- @dataclass(frozen=True)
41
- class QubitLossRuntime(OperatorRuntimeABC):
42
- p: float
43
-
44
- @property
45
- def n_sites(self) -> int:
46
- return 1
15
+ @squin_noise_dialect.register(key="pyqrack")
16
+ class PyQrackMethods(interp.MethodTable):
47
17
 
48
- def apply(self, qubit: PyQrackQubit, adjoint: bool = False) -> None:
49
- if random.uniform(0.0, 1.0) < self.p:
50
- qubit.state = QubitState.Lost
18
+ single_pauli_choices = ("i", "x", "y", "z")
19
+ two_pauli_choices = (
20
+ "ii",
21
+ "ix",
22
+ "iy",
23
+ "iz",
24
+ "xi",
25
+ "xx",
26
+ "xy",
27
+ "xz",
28
+ "yi",
29
+ "yx",
30
+ "yy",
31
+ "yz",
32
+ "zi",
33
+ "zx",
34
+ "zy",
35
+ "zz",
36
+ )
37
+
38
+ @interp.impl(Depolarize)
39
+ def depolarize(
40
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: Depolarize
41
+ ):
42
+ p = frame.get(stmt.p)
43
+ ps = [p / 3.0] * 3
44
+ qubits = frame.get(stmt.qubits)
45
+ self.apply_single_qubit_pauli_error(interp, ps, qubits)
51
46
 
47
+ @interp.impl(Depolarize2)
48
+ def depolarize2(
49
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: Depolarize2
50
+ ):
51
+ p = frame.get(stmt.p)
52
+ ps = [p / 15.0] * 15
53
+ controls = frame.get(stmt.controls)
54
+ targets = frame.get(stmt.targets)
55
+ self.apply_two_qubit_pauli_error(interp, ps, controls, targets)
52
56
 
53
- @squin_noise_dialect.register(key="pyqrack")
54
- class PyQrackMethods(interp.MethodTable):
55
- @interp.impl(StochasticUnitaryChannel)
56
- def stochastic_unitary_channel(
57
+ @interp.impl(SingleQubitPauliChannel)
58
+ def single_qubit_pauli_channel(
57
59
  self,
58
60
  interp: PyQrackInterpreter,
59
61
  frame: interp.Frame,
60
- stmt: StochasticUnitaryChannel,
62
+ stmt: SingleQubitPauliChannel,
61
63
  ):
62
- operators = frame.get(stmt.operators)
63
- probabilities = frame.get(stmt.probabilities)
64
-
65
- return (StochasticUnitaryChannelRuntime(operators, probabilities),)
64
+ px = frame.get(stmt.px)
65
+ py = frame.get(stmt.py)
66
+ pz = frame.get(stmt.pz)
67
+ qubits = frame.get(stmt.qubits)
68
+ self.apply_single_qubit_pauli_error(interp, [px, py, pz], qubits)
69
+
70
+ @interp.impl(TwoQubitPauliChannel)
71
+ def two_qubit_pauli_channel(
72
+ self,
73
+ interp: PyQrackInterpreter,
74
+ frame: interp.Frame,
75
+ stmt: TwoQubitPauliChannel,
76
+ ):
77
+ ps = frame.get(stmt.probabilities)
78
+ controls = frame.get(stmt.controls)
79
+ targets = frame.get(stmt.targets)
80
+ self.apply_two_qubit_pauli_error(interp, ps, controls, targets)
66
81
 
67
82
  @interp.impl(QubitLoss)
68
83
  def qubit_loss(
69
84
  self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: QubitLoss
70
85
  ):
71
86
  p = frame.get(stmt.p)
72
- return (QubitLossRuntime(p),)
87
+ qubits: list[PyQrackQubit] = frame.get(stmt.qubits)
88
+ for qbit in qubits:
89
+ if interp.rng_state.uniform(0.0, 1.0) <= p:
90
+ qbit.drop()
91
+
92
+ @interp.impl(CorrelatedQubitLoss)
93
+ def correlated_qubit_loss(
94
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: CorrelatedQubitLoss
95
+ ):
96
+ p = frame.get(stmt.p)
97
+ qubits: list[list[PyQrackQubit]] = frame.get(stmt.qubits)
98
+ for qubit_group in qubits:
99
+ if interp.rng_state.uniform(0.0, 1.0) <= p:
100
+ for qbit in qubit_group:
101
+ qbit.drop()
102
+
103
+ def apply_single_qubit_pauli_error(
104
+ self,
105
+ interp: PyQrackInterpreter,
106
+ ps: list[float],
107
+ qubits: list[PyQrackQubit],
108
+ ):
109
+ pi = 1 - sum(ps)
110
+ probs = [pi] + ps
111
+
112
+ assert all(0 <= x <= 1 for x in probs), "Invalid Pauli error probabilities"
113
+
114
+ for qbit in qubits:
115
+ which = interp.rng_state.choice(self.single_pauli_choices, p=probs)
116
+ self.apply_pauli_error(which, qbit)
117
+
118
+ def apply_two_qubit_pauli_error(
119
+ self,
120
+ interp: PyQrackInterpreter,
121
+ ps: list[float],
122
+ controls: list[PyQrackQubit],
123
+ targets: list[PyQrackQubit],
124
+ ):
125
+ pii = 1 - sum(ps)
126
+ probs = [pii] + ps
127
+ assert all(0 <= x <= 1 for x in probs), "Invalid Pauli error probabilities"
128
+
129
+ for control, target in zip(controls, targets):
130
+ which = interp.rng_state.choice(self.two_pauli_choices, p=probs)
131
+ self.apply_pauli_error(which[0], control)
132
+ self.apply_pauli_error(which[1], target)
133
+
134
+ def apply_pauli_error(self, which: str, qbit: PyQrackQubit):
135
+ if not qbit.is_active() or which == "i":
136
+ return
137
+
138
+ getattr(qbit.sim_reg, which)(qbit.addr)
@@ -3,61 +3,64 @@ from typing import Any
3
3
  from kirin import interp
4
4
  from kirin.dialects import ilist
5
5
 
6
- from bloqade.squin import qubit
7
- from bloqade.pyqrack.reg import QubitState, PyQrackQubit
6
+ from bloqade.qubit import stmts as qubit
7
+ from bloqade.pyqrack.reg import QubitState, Measurement, PyQrackQubit
8
8
  from bloqade.pyqrack.base import PyQrackInterpreter
9
9
 
10
- from .runtime import OperatorRuntimeABC
11
-
12
10
 
13
11
  @qubit.dialect.register(key="pyqrack")
14
12
  class PyQrackMethods(interp.MethodTable):
15
13
  @interp.impl(qubit.New)
16
- def new(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New):
17
- n_qubits: int = frame.get(stmt.n_qubits)
18
- qreg = ilist.IList(
19
- [
20
- PyQrackQubit(i, interp.memory.sim_reg, QubitState.Active)
21
- for i in interp.memory.allocate(n_qubits=n_qubits)
22
- ]
23
- )
24
- return (qreg,)
25
-
26
- @interp.impl(qubit.Apply)
27
- def apply(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Apply):
28
- qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
29
- operator: OperatorRuntimeABC = frame.get(stmt.operator)
30
- operator.apply(*qubits)
31
-
32
- @interp.impl(qubit.Broadcast)
33
- def broadcast(
34
- self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Broadcast
14
+ def new_qubit(
15
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New
35
16
  ):
36
- operator: OperatorRuntimeABC = frame.get(stmt.operator)
37
- qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
38
- operator.broadcast_apply(qubits)
17
+ (addr,) = interp.memory.allocate(1)
18
+ qb = PyQrackQubit(addr, interp.memory.sim_reg, QubitState.Active)
19
+ return (qb,)
39
20
 
40
21
  def _measure_qubit(self, qbit: PyQrackQubit, interp: PyQrackInterpreter):
41
22
  if qbit.is_active():
42
- return bool(qbit.sim_reg.m(qbit.addr))
23
+ m = Measurement(bool(qbit.sim_reg.m(qbit.addr)))
43
24
  else:
44
- return interp.loss_m_result
25
+ m = Measurement(interp.loss_m_result)
26
+
27
+ interp.set_global_measurement_id(m)
28
+ return m
45
29
 
46
- @interp.impl(qubit.MeasureQubitList)
30
+ @interp.impl(qubit.Measure)
47
31
  def measure_qubit_list(
48
32
  self,
49
33
  interp: PyQrackInterpreter,
50
34
  frame: interp.Frame,
51
- stmt: qubit.MeasureQubitList,
35
+ stmt: qubit.Measure,
52
36
  ):
53
37
  qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
54
38
  result = ilist.IList([self._measure_qubit(qbit, interp) for qbit in qubits])
55
39
  return (result,)
56
40
 
57
- @interp.impl(qubit.MeasureQubit)
58
- def measure_qubit(
59
- self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasureQubit
41
+ @interp.impl(qubit.QubitId)
42
+ def qubit_id(
43
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.QubitId
60
44
  ):
61
- qbit: PyQrackQubit = frame.get(stmt.qubit)
62
- result = self._measure_qubit(qbit, interp)
63
- return (result,)
45
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
46
+ ids = ilist.IList([qbit.addr for qbit in qubits])
47
+ return (ids,)
48
+
49
+ @interp.impl(qubit.MeasurementId)
50
+ def measurement_id(
51
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasurementId
52
+ ):
53
+ measurements: ilist.IList[Measurement, Any] = frame.get(stmt.measurements)
54
+ ids = ilist.IList([measurement.measurement_id for measurement in measurements])
55
+ return (ids,)
56
+
57
+ @interp.impl(qubit.Reset)
58
+ def reset(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Reset):
59
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
60
+ for qbit in qubits:
61
+ if not qbit.is_active():
62
+ continue
63
+
64
+ m = qbit.sim_reg.m(qbit.addr)
65
+ if m == Measurement.One:
66
+ qbit.sim_reg.x(qbit.addr)
bloqade/pyqrack/target.py CHANGED
@@ -12,7 +12,7 @@ from bloqade.pyqrack.base import (
12
12
  PyQrackInterpreter,
13
13
  _default_pyqrack_args,
14
14
  )
15
- from bloqade.analysis.address import AnyAddress, AddressAnalysis
15
+ from bloqade.analysis.address import UnknownQubit, AddressAnalysis
16
16
 
17
17
  Params = ParamSpec("Params")
18
18
  RetType = TypeVar("RetType")
@@ -51,9 +51,9 @@ class PyQrack:
51
51
  return PyQrackInterpreter(mt.dialects, memory=DynamicMemory(options))
52
52
  else:
53
53
  address_analysis = AddressAnalysis(mt.dialects)
54
- frame, _ = address_analysis.run_analysis(mt)
54
+ frame, _ = address_analysis.run(mt)
55
55
  if self.min_qubits == 0 and any(
56
- isinstance(a, AnyAddress) for a in frame.entries.values()
56
+ isinstance(a, UnknownQubit) for a in frame.entries.values()
57
57
  ):
58
58
  raise ValueError(
59
59
  "All addresses must be resolved. Or set min_qubits to a positive integer."
@@ -87,7 +87,8 @@ class PyQrack:
87
87
  """
88
88
  fold = Fold(mt.dialects)
89
89
  fold(mt)
90
- return self._get_interp(mt).run(mt, args, kwargs)
90
+ _, ret = self._get_interp(mt).run(mt, *args, **kwargs)
91
+ return ret
91
92
 
92
93
  def multi_run(
93
94
  self,
bloqade/pyqrack/task.py CHANGED
@@ -1,7 +1,12 @@
1
1
  from typing import TypeVar, ParamSpec, cast
2
+ from collections import Counter
2
3
  from dataclasses import dataclass
3
4
 
5
+ import numpy as np
6
+ from kirin.dialects.ilist import IList
7
+
4
8
  from bloqade.task import AbstractSimulatorTask
9
+ from bloqade.pyqrack.reg import QubitState, PyQrackQubit
5
10
  from bloqade.pyqrack.base import (
6
11
  MemoryABC,
7
12
  PyQrackInterpreter,
@@ -19,14 +24,12 @@ class PyQrackSimulatorTask(AbstractSimulatorTask[Param, RetType, MemoryType]):
19
24
  pyqrack_interp: PyQrackInterpreter[MemoryType]
20
25
 
21
26
  def run(self) -> RetType:
22
- return cast(
23
- RetType,
24
- self.pyqrack_interp.run(
25
- self.kernel,
26
- args=self.args,
27
- kwargs=self.kwargs,
28
- ),
27
+ _, ret = self.pyqrack_interp.run(
28
+ self.kernel,
29
+ *self.args,
30
+ **self.kwargs,
29
31
  )
32
+ return cast(RetType, ret)
30
33
 
31
34
  @property
32
35
  def state(self) -> MemoryType:
@@ -36,3 +39,107 @@ class PyQrackSimulatorTask(AbstractSimulatorTask[Param, RetType, MemoryType]):
36
39
  """Returns the state vector of the simulator."""
37
40
  self.run()
38
41
  return self.state.sim_reg.out_ket()
42
+
43
+ def qubits(self) -> list[PyQrackQubit]:
44
+ """Returns the qubits in the simulator."""
45
+ try:
46
+ N = self.state.sim_reg.num_qubits()
47
+ return [
48
+ PyQrackQubit(
49
+ addr=i, sim_reg=self.state.sim_reg, state=QubitState.Active
50
+ )
51
+ for i in range(N)
52
+ ]
53
+ except AttributeError:
54
+ Warning("Task has not been run, there are no qubits!")
55
+ return []
56
+
57
+ def batch_run(self, shots: int = 1) -> dict[RetType, float]:
58
+ """
59
+ Repeatedly run the task to collect statistics on the shot outcomes.
60
+ The average is done over [shots] repetitions and thus is frequentist
61
+ and converges to exact only in the shots -> infinity limit.
62
+
63
+ Args:
64
+ shots (int):
65
+ the number of repetitions of the task
66
+ Returns:
67
+ dict[RetType, float]:
68
+ a dictionary mapping outcomes to their probabilities,
69
+ as estimated from counting the shot outcomes. RetType must be hashable.
70
+ """
71
+
72
+ results: list[RetType] = [self.run() for _ in range(shots)]
73
+
74
+ # Convert IList to tuple so that it is hashable by Counter
75
+ def convert(data):
76
+ if isinstance(data, (list, IList)):
77
+ return tuple(convert(item) for item in data)
78
+ return data
79
+
80
+ results = convert(results)
81
+
82
+ data = {
83
+ key: value / len(results) for key, value in Counter(results).items()
84
+ } # Normalize to probabilities
85
+ return data
86
+
87
+ def batch_state(
88
+ self, shots: int = 1, qubit_map: None = None
89
+ ) -> "QuantumState": # noqa: F821
90
+ """
91
+ Repeatedly run the task to extract the averaged quantum state.
92
+ The average is done over [shots] repetitions and thus is frequentist
93
+ and converges to exact only in the shots -> infinity limit.
94
+
95
+ Args:
96
+ shots (int):
97
+ the number of repetitions of the task
98
+ qubit_map (callable | None):
99
+ an optional callable that takes the output of self.run() and extract
100
+ the [returned] qubits to be used for the quantum state.
101
+ If None, all qubits in the simulator are used, in the order set by the simulator.
102
+ If callable, qubit_map must have the signature
103
+ > qubit_map(output:RetType) -> list[PyQrackQubit]
104
+ and the averaged state is
105
+ > quantum_state(qubit_map(self.run())).
106
+ If qubit_map is not None, self.run() must return qubit(s).
107
+ Two common patterns here are:
108
+ > qubit_map = lambda qubits: qubits
109
+ for the case where self.run() returns a list of qubits, or
110
+ > qubit_map = lambda qubit: [qubits]
111
+ for the case where self.run() returns a single qubit.
112
+ Returns:
113
+ QuantumState:
114
+ the averaged quantum state as a density matrix,
115
+ represented in its eigenbasis.
116
+ """
117
+ # Import here to avoid circular dependencies.
118
+ from bloqade.pyqrack.device import QuantumState, PyQrackSimulatorBase
119
+
120
+ states: list[QuantumState] = []
121
+ for _ in range(shots):
122
+ res = self.run()
123
+ if callable(qubit_map):
124
+ qbs = qubit_map(res)
125
+ else:
126
+ qbs = self.qubits()
127
+ states.append(PyQrackSimulatorBase.quantum_state(qbs))
128
+
129
+ state = QuantumState(
130
+ eigenvectors=np.concatenate(
131
+ [state.eigenvectors for state in states], axis=1
132
+ ),
133
+ eigenvalues=np.concatenate([state.eigenvalues for state in states], axis=0)
134
+ / len(states),
135
+ )
136
+
137
+ # Canonicalize the state by orthoganalizing the basis vectors.
138
+ tol = 1e-7
139
+ s, v, d = np.linalg.svd(
140
+ state.eigenvectors * np.sqrt(state.eigenvalues), full_matrices=False
141
+ )
142
+ mask = v > tol
143
+ v = v[mask] ** 2
144
+ s = s[:, mask]
145
+ return QuantumState(eigenvalues=v, eigenvectors=s)
@@ -4,6 +4,7 @@ import pathlib
4
4
  from typing import Any
5
5
 
6
6
  from kirin import ir, lowering
7
+ from kirin.types import MethodType
7
8
  from kirin.dialects import func
8
9
 
9
10
  from . import parse
@@ -82,11 +83,10 @@ def loads(
82
83
  body=body,
83
84
  )
84
85
 
86
+ body.blocks[0].args.append_from(MethodType, kernel_name + "_self")
87
+
85
88
  mt = ir.Method(
86
- mod=None,
87
- py_func=None,
88
89
  sym_name=kernel_name,
89
- arg_names=[],
90
90
  dialects=qasm2_lowering.dialects,
91
91
  code=code,
92
92
  )
@@ -1,12 +1,14 @@
1
1
  from kirin import interp
2
+ from kirin.analysis import const
2
3
 
3
4
  from bloqade.analysis.address import (
4
5
  Address,
5
- NotQubit,
6
6
  AddressReg,
7
- AddressQubit,
7
+ ConstResult,
8
+ UnknownQubit,
8
9
  AddressAnalysis,
9
10
  )
11
+ from bloqade.analysis.address.lattice import UnknownReg
10
12
 
11
13
  from .stmts import QRegGet, QRegNew
12
14
  from ._dialect import dialect
@@ -22,17 +24,24 @@ class AddressMethodTable(interp.MethodTable):
22
24
  frame: interp.Frame[Address],
23
25
  stmt: QRegNew,
24
26
  ):
25
- n_qubits = interp.get_const_value(int, stmt.n_qubits)
26
- addr = AddressReg(range(interp.next_address, interp.next_address + n_qubits))
27
- interp.next_address += n_qubits
28
- return (addr,)
27
+ n_qubits = frame.get(stmt.n_qubits)
28
+ match n_qubits:
29
+ case ConstResult(const.Value(int() as n)):
30
+ addr = AddressReg(range(interp.next_address, interp.next_address + n))
31
+ interp.next_address += n
32
+ return (addr,)
33
+ case _:
34
+ return (UnknownReg(),)
29
35
 
30
36
  @interp.impl(QRegGet)
31
37
  def get(self, interp: AddressAnalysis, frame: interp.Frame[Address], stmt: QRegGet):
32
38
  addr = frame.get(stmt.reg)
33
- pos = interp.get_const_value(int, stmt.idx)
34
- if isinstance(addr, AddressReg):
35
- global_idx = addr.data[pos]
36
- return (AddressQubit(global_idx),)
37
- else: # this is not reachable
38
- return (NotQubit(),)
39
+ idx = frame.get(stmt.idx)
40
+
41
+ typ, values = interp.unpack_iterable(addr)
42
+ idx_value = interp.get_const_value(idx, int)
43
+
44
+ if typ is not None and idx_value is not None:
45
+ return (values[idx_value],)
46
+
47
+ return (UnknownQubit(),)
@@ -20,7 +20,10 @@ class EmitExpr(interp.MethodTable):
20
20
 
21
21
  args: list[ast.Node] = []
22
22
  cparams, qparams = [], []
23
- for arg in stmt.body.blocks[0].args:
23
+ entry_args = stmt.body.blocks[0].args
24
+ user_args = entry_args[1:] if len(entry_args) > 0 else []
25
+
26
+ for arg in user_args:
24
27
  assert arg.name is not None
25
28
 
26
29
  args.append(ast.Name(id=arg.name))
@@ -29,14 +32,22 @@ class EmitExpr(interp.MethodTable):
29
32
  else:
30
33
  cparams.append(arg.name)
31
34
 
32
- emit.run_ssacfg_region(frame, stmt.body, tuple(args))
33
- emit.output = ast.Gate(
34
- name=stmt.sym_name,
35
- cparams=cparams,
36
- qparams=qparams,
37
- body=frame.body,
35
+ frame.worklist.append(interp.Successor(stmt.body.blocks[0], *args))
36
+ if len(entry_args) > 0:
37
+ frame.set(entry_args[0], ast.Name(stmt.sym_name or "gate"))
38
+
39
+ while (succ := frame.worklist.pop()) is not None:
40
+ frame.set_values(succ.block.args[1:], succ.block_args)
41
+ block_header = emit.emit_block(frame, succ.block)
42
+ frame.block_ref[succ.block] = block_header
43
+ return (
44
+ ast.Gate(
45
+ name=stmt.sym_name,
46
+ cparams=cparams,
47
+ qparams=qparams,
48
+ body=frame.body,
49
+ ),
38
50
  )
39
- return ()
40
51
 
41
52
  @interp.impl(stmts.ConstInt)
42
53
  @interp.impl(stmts.ConstFloat)
@@ -87,7 +87,7 @@ class ConstPI(ir.Statement):
87
87
 
88
88
 
89
89
  # QASM 2.0 arithmetic operations
90
- PyNum = types.Union(types.Int, types.Float)
90
+ PyNum = types.TypeVar("PyNum", bound=types.Union(types.Int, types.Float))
91
91
 
92
92
 
93
93
  @statement(dialect=dialect)
@@ -110,7 +110,7 @@ class Sin(ir.Statement):
110
110
  traits = frozenset({lowering.FromPythonCall()})
111
111
  value: ir.SSAValue = info.argument(PyNum)
112
112
  """value (Union[int, float]): The number to take the sine of."""
113
- result: ir.ResultValue = info.result(PyNum)
113
+ result: ir.ResultValue = info.result(types.Float)
114
114
  """result (float): The sine of the number."""
115
115
 
116
116
 
@@ -122,7 +122,7 @@ class Cos(ir.Statement):
122
122
  traits = frozenset({lowering.FromPythonCall()})
123
123
  value: ir.SSAValue = info.argument(PyNum)
124
124
  """value (Union[int, float]): The number to take the cosine of."""
125
- result: ir.ResultValue = info.result(PyNum)
125
+ result: ir.ResultValue = info.result(types.Float)
126
126
  """result (float): The cosine of the number."""
127
127
 
128
128
 
@@ -134,7 +134,7 @@ class Tan(ir.Statement):
134
134
  traits = frozenset({lowering.FromPythonCall()})
135
135
  value: ir.SSAValue = info.argument(PyNum)
136
136
  """value (Union[int, float]): The number to take the tangent of."""
137
- result: ir.ResultValue = info.result(PyNum)
137
+ result: ir.ResultValue = info.result(types.Float)
138
138
  """result (float): The tangent of the number."""
139
139
 
140
140
 
@@ -146,7 +146,7 @@ class Exp(ir.Statement):
146
146
  traits = frozenset({lowering.FromPythonCall()})
147
147
  value: ir.SSAValue = info.argument(PyNum)
148
148
  """value (Union[int, float]): The number to take the exponential of."""
149
- result: ir.ResultValue = info.result(PyNum)
149
+ result: ir.ResultValue = info.result(types.Float)
150
150
  """result (float): The exponential of the number."""
151
151
 
152
152
 
@@ -158,7 +158,7 @@ class Log(ir.Statement):
158
158
  traits = frozenset({lowering.FromPythonCall()})
159
159
  value: ir.SSAValue = info.argument(PyNum)
160
160
  """value (Union[int, float]): The number to take the natural log of."""
161
- result: ir.ResultValue = info.result(PyNum)
161
+ result: ir.ResultValue = info.result(types.Float)
162
162
  """result (float): The natural log of the number."""
163
163
 
164
164
 
@@ -170,7 +170,7 @@ class Sqrt(ir.Statement):
170
170
  traits = frozenset({lowering.FromPythonCall()})
171
171
  value: ir.SSAValue = info.argument(PyNum)
172
172
  """value (Union[int, float]): The number to take the square root of."""
173
- result: ir.ResultValue = info.result(PyNum)
173
+ result: ir.ResultValue = info.result(types.Float)
174
174
  """result (float): The square root of the number."""
175
175
 
176
176
 
@@ -1,7 +1,6 @@
1
1
  from kirin import interp
2
2
  from kirin.lattice import EmptyLattice
3
3
 
4
- from bloqade.analysis.address import AddressQubit, AddressTuple
5
4
  from bloqade.analysis.fidelity import FidelityAnalysis
6
5
 
7
6
  from .stmts import PauliChannel, CZPauliChannel, AtomLossChannel
@@ -32,7 +31,7 @@ class FidelityMethodTable(interp.MethodTable):
32
31
  # NOTE: fidelity is just the inverse probability of any noise to occur
33
32
  fid = (1 - p) * (1 - p_ctrl)
34
33
 
35
- interp._current_gate_fidelity *= fid
34
+ interp.gate_fidelity *= fid
36
35
 
37
36
  @interp.impl(AtomLossChannel)
38
37
  def atom_loss(
@@ -42,10 +41,7 @@ class FidelityMethodTable(interp.MethodTable):
42
41
  stmt: AtomLossChannel,
43
42
  ):
44
43
  # NOTE: since AtomLossChannel acts on IList[Qubit], we know the assigned address is a tuple
45
- addresses: AddressTuple = interp.addr_frame.get(stmt.qargs)
46
-
44
+ addresses = interp.addr_frame.get(stmt.qargs)
47
45
  # NOTE: get the corresponding index and reduce survival probability accordingly
48
- for qbit_address in addresses.data:
49
- assert isinstance(qbit_address, AddressQubit)
50
- index = qbit_address.data
51
- interp._current_atom_survival_probability[index] *= 1 - stmt.prob
46
+ for index in addresses.data:
47
+ interp.atom_survival_probability[index] *= 1 - stmt.prob