bloqade_circuit-0.6.4-py3-none-any.whl → bloqade_circuit-0.9.1-py3-none-any.whl

This diff shows the changes between publicly available versions of a package as released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (191)
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/analysis.py +18 -20
  7. bloqade/analysis/measure_id/impls.py +31 -29
  8. bloqade/annotate/__init__.py +6 -0
  9. bloqade/annotate/_dialect.py +3 -0
  10. bloqade/annotate/_interface.py +22 -0
  11. bloqade/annotate/stmts.py +29 -0
  12. bloqade/annotate/types.py +13 -0
  13. bloqade/cirq_utils/__init__.py +4 -2
  14. bloqade/cirq_utils/emit/__init__.py +3 -0
  15. bloqade/cirq_utils/emit/base.py +246 -0
  16. bloqade/cirq_utils/emit/gate.py +104 -0
  17. bloqade/cirq_utils/emit/noise.py +90 -0
  18. bloqade/cirq_utils/emit/qubit.py +35 -0
  19. bloqade/cirq_utils/lowering.py +660 -0
  20. bloqade/cirq_utils/noise/__init__.py +0 -2
  21. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  22. bloqade/cirq_utils/noise/model.py +151 -191
  23. bloqade/cirq_utils/noise/transform.py +2 -2
  24. bloqade/cirq_utils/parallelize.py +9 -6
  25. bloqade/gemini/__init__.py +1 -0
  26. bloqade/gemini/analysis/__init__.py +3 -0
  27. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  28. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  29. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  30. bloqade/gemini/groups.py +67 -0
  31. bloqade/native/__init__.py +23 -0
  32. bloqade/native/_prelude.py +45 -0
  33. bloqade/native/dialects/__init__.py +0 -0
  34. bloqade/native/dialects/gate/__init__.py +2 -0
  35. bloqade/native/dialects/gate/_dialect.py +3 -0
  36. bloqade/native/dialects/gate/_interface.py +32 -0
  37. bloqade/native/dialects/gate/stmts.py +31 -0
  38. bloqade/native/stdlib/__init__.py +0 -0
  39. bloqade/native/stdlib/broadcast.py +246 -0
  40. bloqade/native/stdlib/simple.py +220 -0
  41. bloqade/native/upstream/__init__.py +4 -0
  42. bloqade/native/upstream/squin2native.py +79 -0
  43. bloqade/pyqrack/__init__.py +2 -2
  44. bloqade/pyqrack/base.py +7 -1
  45. bloqade/pyqrack/device.py +192 -18
  46. bloqade/pyqrack/native.py +49 -0
  47. bloqade/pyqrack/reg.py +6 -6
  48. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  49. bloqade/pyqrack/squin/gate/gate.py +136 -0
  50. bloqade/pyqrack/squin/noise/native.py +120 -54
  51. bloqade/pyqrack/squin/qubit.py +39 -36
  52. bloqade/pyqrack/target.py +5 -4
  53. bloqade/pyqrack/task.py +114 -7
  54. bloqade/qasm2/_qasm_loading.py +3 -3
  55. bloqade/qasm2/dialects/core/address.py +21 -12
  56. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  57. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  58. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  59. bloqade/qasm2/dialects/noise/model.py +2 -1
  60. bloqade/qasm2/emit/base.py +16 -11
  61. bloqade/qasm2/emit/gate.py +11 -8
  62. bloqade/qasm2/emit/main.py +103 -3
  63. bloqade/qasm2/emit/target.py +9 -5
  64. bloqade/qasm2/groups.py +3 -2
  65. bloqade/qasm2/parse/lowering.py +0 -1
  66. bloqade/qasm2/passes/fold.py +14 -73
  67. bloqade/qasm2/passes/glob.py +2 -2
  68. bloqade/qasm2/passes/noise.py +1 -1
  69. bloqade/qasm2/passes/parallel.py +7 -5
  70. bloqade/qasm2/rewrite/__init__.py +0 -1
  71. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  72. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  73. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  74. bloqade/qasm2/rewrite/register.py +2 -2
  75. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  76. bloqade/qbraid/lowering.py +1 -0
  77. bloqade/qbraid/schema.py +2 -2
  78. bloqade/qubit/__init__.py +12 -0
  79. bloqade/qubit/_dialect.py +3 -0
  80. bloqade/qubit/_interface.py +49 -0
  81. bloqade/qubit/_prelude.py +45 -0
  82. bloqade/qubit/analysis/__init__.py +1 -0
  83. bloqade/qubit/analysis/address_impl.py +40 -0
  84. bloqade/qubit/stdlib/__init__.py +2 -0
  85. bloqade/qubit/stdlib/_new.py +34 -0
  86. bloqade/qubit/stdlib/broadcast.py +62 -0
  87. bloqade/qubit/stdlib/simple.py +59 -0
  88. bloqade/qubit/stmts.py +60 -0
  89. bloqade/rewrite/passes/__init__.py +6 -0
  90. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  91. bloqade/rewrite/passes/callgraph.py +116 -0
  92. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  93. bloqade/rewrite/rules/split_ifs.py +18 -1
  94. bloqade/squin/__init__.py +47 -14
  95. bloqade/squin/analysis/__init__.py +0 -1
  96. bloqade/squin/analysis/schedule.py +10 -11
  97. bloqade/squin/gate/__init__.py +2 -0
  98. bloqade/squin/gate/_dialect.py +3 -0
  99. bloqade/squin/gate/_interface.py +98 -0
  100. bloqade/squin/gate/stmts.py +125 -0
  101. bloqade/squin/groups.py +5 -22
  102. bloqade/squin/noise/__init__.py +1 -10
  103. bloqade/squin/noise/_dialect.py +1 -1
  104. bloqade/squin/noise/_interface.py +45 -0
  105. bloqade/squin/noise/stmts.py +66 -28
  106. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  107. bloqade/squin/rewrite/__init__.py +0 -2
  108. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  109. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  110. bloqade/squin/stdlib/__init__.py +0 -0
  111. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  112. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  113. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  114. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  115. bloqade/squin/stdlib/simple/__init__.py +33 -0
  116. bloqade/squin/stdlib/simple/gate.py +242 -0
  117. bloqade/squin/stdlib/simple/noise.py +126 -0
  118. bloqade/stim/__init__.py +1 -0
  119. bloqade/stim/_wrappers.py +6 -0
  120. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  121. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  122. bloqade/stim/dialects/gate/emit.py +9 -10
  123. bloqade/stim/dialects/noise/emit.py +17 -13
  124. bloqade/stim/dialects/noise/stmts.py +5 -3
  125. bloqade/stim/emit/__init__.py +1 -0
  126. bloqade/stim/emit/impls.py +16 -0
  127. bloqade/stim/emit/stim_str.py +48 -31
  128. bloqade/stim/groups.py +12 -2
  129. bloqade/stim/parse/lowering.py +14 -17
  130. bloqade/stim/passes/__init__.py +0 -2
  131. bloqade/stim/passes/flatten.py +26 -0
  132. bloqade/stim/passes/simplify_ifs.py +6 -1
  133. bloqade/stim/passes/squin_to_stim.py +9 -84
  134. bloqade/stim/rewrite/__init__.py +2 -4
  135. bloqade/stim/rewrite/get_record_util.py +24 -0
  136. bloqade/stim/rewrite/ifs_to_stim.py +24 -25
  137. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  138. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  139. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  140. bloqade/stim/rewrite/squin_measure.py +9 -18
  141. bloqade/stim/rewrite/squin_noise.py +134 -108
  142. bloqade/stim/rewrite/util.py +5 -192
  143. bloqade/test_utils.py +1 -1
  144. bloqade/types.py +10 -0
  145. bloqade/validation/__init__.py +2 -0
  146. bloqade/validation/analysis/__init__.py +5 -0
  147. bloqade/validation/analysis/analysis.py +41 -0
  148. bloqade/validation/analysis/lattice.py +58 -0
  149. bloqade/validation/kernel_validation.py +77 -0
  150. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  151. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  152. bloqade/pyqrack/squin/op.py +0 -180
  153. bloqade/pyqrack/squin/runtime.py +0 -535
  154. bloqade/pyqrack/squin/wire.py +0 -51
  155. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  156. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  157. bloqade/squin/_typeinfer.py +0 -20
  158. bloqade/squin/analysis/address_impl.py +0 -71
  159. bloqade/squin/analysis/nsites/__init__.py +0 -9
  160. bloqade/squin/analysis/nsites/analysis.py +0 -50
  161. bloqade/squin/analysis/nsites/impls.py +0 -92
  162. bloqade/squin/analysis/nsites/lattice.py +0 -49
  163. bloqade/squin/cirq/__init__.py +0 -280
  164. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  165. bloqade/squin/cirq/emit/noise.py +0 -49
  166. bloqade/squin/cirq/emit/op.py +0 -125
  167. bloqade/squin/cirq/emit/qubit.py +0 -60
  168. bloqade/squin/cirq/emit/runtime.py +0 -242
  169. bloqade/squin/cirq/lowering.py +0 -440
  170. bloqade/squin/lowering.py +0 -54
  171. bloqade/squin/noise/_wrapper.py +0 -40
  172. bloqade/squin/noise/rewrite.py +0 -111
  173. bloqade/squin/op/__init__.py +0 -41
  174. bloqade/squin/op/_dialect.py +0 -3
  175. bloqade/squin/op/_wrapper.py +0 -121
  176. bloqade/squin/op/number.py +0 -5
  177. bloqade/squin/op/rewrite.py +0 -46
  178. bloqade/squin/op/stdlib.py +0 -62
  179. bloqade/squin/op/stmts.py +0 -276
  180. bloqade/squin/op/traits.py +0 -43
  181. bloqade/squin/op/types.py +0 -26
  182. bloqade/squin/qubit.py +0 -184
  183. bloqade/squin/rewrite/canonicalize.py +0 -60
  184. bloqade/squin/rewrite/desugar.py +0 -124
  185. bloqade/squin/types.py +0 -8
  186. bloqade/squin/wire.py +0 -201
  187. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  188. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  189. bloqade_circuit-0.6.4.dist-info/RECORD +0 -234
  190. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  191. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,126 @@
1
+ from typing import Any, Literal, TypeVar
2
+
3
+ from kirin.dialects import ilist
4
+
5
+ from bloqade.types import Qubit
6
+
7
+ from .. import broadcast
8
+ from ...groups import kernel
9
+
10
+
11
+ @kernel
12
+ def depolarize(p: float, qubit: Qubit) -> None:
13
+ """
14
+ Apply a depolarizing noise channel to a qubit with probability `p`.
15
+
16
+ This will randomly select one of the Pauli operators X, Y, Z
17
+ with a probability `p / 3` and apply it to the qubit. No operator is applied
18
+ with a probability of `1 - p`.
19
+
20
+ Args:
21
+ p (float): The probability with which a Pauli operator is applied.
22
+ qubit (Qubit): The qubit to which the noise channel is applied.
23
+ """
24
+ broadcast.depolarize(p, ilist.IList([qubit]))
25
+
26
+
27
+ N = TypeVar("N", bound=int)
28
+
29
+
30
+ @kernel
31
+ def depolarize2(p: float, control: Qubit, target: Qubit) -> None:
32
+ """
33
+ Symmetric two-qubit depolarization channel applied to a pair of qubits.
34
+
35
+ This will randomly select one of the pauli products
36
+
37
+ `{IX, IY, IZ, XI, XX, XY, XZ, YI, YX, YY, YZ, ZI, ZX, ZY, ZZ}`
38
+
39
+ each with a probability `p / 15`. No noise is applied with a probability of `1 - p`.
40
+
41
+ Args:
42
+ p (float): The probability with which a Pauli product is applied.
43
+ control (Qubit): The control qubit.
44
+ target (Qubit): The target qubit.
45
+ """
46
+ broadcast.depolarize2(p, ilist.IList([control]), ilist.IList([target]))
47
+
48
+
49
+ @kernel
50
+ def single_qubit_pauli_channel(px: float, py: float, pz: float, qubit: Qubit) -> None:
51
+ """
52
+ Apply a Pauli error channel with weighted `px, py, pz`. No error is applied with a probability
53
+ `1 - (px + py + pz)`.
54
+
55
+ This randomly selects one of the three Pauli operators X, Y, Z, weighted with the given probabilities in that order.
56
+
57
+ Args:
58
+ probabilities (IList[float, Literal[3]]): A list of 3 probabilities corresponding to the probabilities `(p_x, p_y, p_z)` in that order.
59
+ qubit (Qubit): The qubit to which the noise channel is applied.
60
+ """
61
+ broadcast.single_qubit_pauli_channel(px, py, pz, ilist.IList([qubit]))
62
+
63
+
64
+ @kernel
65
+ def two_qubit_pauli_channel(
66
+ probabilities: ilist.IList[float, Literal[15]], control: Qubit, target: Qubit
67
+ ) -> None:
68
+ """
69
+ Apply a Pauli product error with weighted `probabilities` to the pair of qubits.
70
+
71
+ No error is applied with the probability `1 - sum(probabilities)`.
72
+
73
+ This will randomly select one of the pauli products
74
+
75
+ `{IX, IY, IZ, XI, XX, XY, XZ, YI, YX, YY, YZ, ZI, ZX, ZY, ZZ}`
76
+
77
+ weighted with the corresponding list of probabilities.
78
+
79
+ **NOTE**: The order of the given probabilities must match the order of the list of Pauli products above!
80
+ """
81
+ broadcast.two_qubit_pauli_channel(
82
+ probabilities, ilist.IList([control]), ilist.IList([target])
83
+ )
84
+
85
+
86
+ @kernel
87
+ def qubit_loss(p: float, qubit: Qubit) -> None:
88
+ """
89
+ Apply a qubit loss channel to the given qubit.
90
+
91
+ The qubit is lost with a probability `p`.
92
+
93
+ Args:
94
+ p (float): Probability of the atom being lost.
95
+ qubit (Qubit): The qubit to which the noise channel is applied.
96
+ """
97
+ broadcast.qubit_loss(p, ilist.IList([qubit]))
98
+
99
+
100
+ @kernel
101
+ def correlated_qubit_loss(p: float, qubits: ilist.IList[Qubit, Any]) -> None:
102
+ """
103
+ Apply a correlated qubit loss channel to the given qubits.
104
+
105
+ All qubits are lost together with a probability `p`.
106
+
107
+ Args:
108
+ p (float): Probability of the qubits being lost.
109
+ qubits (IList[Qubit, Any]): The list of qubits to which the correlated noise channel is applied.
110
+ """
111
+ broadcast.correlated_qubit_loss(p, ilist.IList([qubits]))
112
+
113
+
114
+ # NOTE: actual stdlib that doesn't wrap statements starts here
115
+
116
+
117
+ @kernel
118
+ def bit_flip(p: float, qubit: Qubit) -> None:
119
+ """
120
+ Apply a bit flip error channel to the qubit with probability `p`.
121
+
122
+ Args:
123
+ p (float): Probability of a bit flip error being applied.
124
+ qubit (Qubit): The qubit to which the noise channel is applied.
125
+ """
126
+ single_qubit_pauli_channel(p, 0, 0, qubit)
bloqade/stim/__init__.py CHANGED
@@ -39,4 +39,5 @@ from ._wrappers import (
39
39
  pauli_channel1 as pauli_channel1,
40
40
  pauli_channel2 as pauli_channel2,
41
41
  observable_include as observable_include,
42
+ correlated_qubit_loss as correlated_qubit_loss,
42
43
  )
bloqade/stim/_wrappers.py CHANGED
@@ -194,3 +194,9 @@ def z_error(p: float, targets: tuple[int, ...]) -> None: ...
194
194
 
195
195
  @wraps(noise.QubitLoss)
196
196
  def qubit_loss(probs: tuple[float, ...], targets: tuple[int, ...]) -> None: ...
197
+
198
+
199
+ @wraps(noise.CorrelatedQubitLoss)
200
+ def correlated_qubit_loss(
201
+ probs: tuple[float, ...], targets: tuple[int, ...]
202
+ ) -> None: ...
@@ -1,7 +1,6 @@
1
- from kirin.emit import EmitStrFrame
2
1
  from kirin.interp import MethodTable, impl
3
2
 
4
- from bloqade.stim.emit.stim_str import EmitStimMain
3
+ from bloqade.stim.emit.stim_str import EmitStimMain, EmitStimFrame
5
4
 
6
5
  from . import stmts
7
6
  from ._dialect import dialect
@@ -11,7 +10,7 @@ from ._dialect import dialect
11
10
  class EmitStimAuxMethods(MethodTable):
12
11
 
13
12
  @impl(stmts.ConstInt)
14
- def const_int(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ConstInt):
13
+ def const_int(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.ConstInt):
15
14
 
16
15
  out: str = f"{stmt.value}"
17
16
 
@@ -19,7 +18,7 @@ class EmitStimAuxMethods(MethodTable):
19
18
 
20
19
  @impl(stmts.ConstFloat)
21
20
  def const_float(
22
- self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ConstFloat
21
+ self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.ConstFloat
23
22
  ):
24
23
 
25
24
  out: str = f"{stmt.value:.8f}"
@@ -28,26 +27,28 @@ class EmitStimAuxMethods(MethodTable):
28
27
 
29
28
  @impl(stmts.ConstBool)
30
29
  def const_bool(
31
- self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ConstBool
30
+ self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.ConstBool
32
31
  ):
33
32
  out: str = "!" if stmt.value else ""
34
33
 
35
34
  return (out,)
36
35
 
37
36
  @impl(stmts.ConstStr)
38
- def const_str(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ConstBool):
37
+ def const_str(
38
+ self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.ConstBool
39
+ ):
39
40
 
40
41
  return (stmt.value,)
41
42
 
42
43
  @impl(stmts.Neg)
43
- def neg(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.Neg):
44
+ def neg(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.Neg):
44
45
 
45
46
  operand: str = frame.get(stmt.operand)
46
47
 
47
48
  return ("-" + operand,)
48
49
 
49
50
  @impl(stmts.GetRecord)
50
- def get_rec(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.GetRecord):
51
+ def get_rec(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.GetRecord):
51
52
 
52
53
  id: str = frame.get(stmt.id)
53
54
  out: str = f"rec[{id}]"
@@ -55,14 +56,14 @@ class EmitStimAuxMethods(MethodTable):
55
56
  return (out,)
56
57
 
57
58
  @impl(stmts.Tick)
58
- def tick(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.Tick):
59
+ def tick(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.Tick):
59
60
 
60
- emit.writeln(frame, "TICK")
61
+ frame.write_line("TICK")
61
62
 
62
63
  return ()
63
64
 
64
65
  @impl(stmts.Detector)
65
- def detector(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.Detector):
66
+ def detector(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.Detector):
66
67
 
67
68
  coords: tuple[str, ...] = frame.get_values(stmt.coord)
68
69
  targets: tuple[str, ...] = frame.get_values(stmt.targets)
@@ -70,27 +71,27 @@ class EmitStimAuxMethods(MethodTable):
70
71
  coord_str: str = ", ".join(coords)
71
72
  target_str: str = " ".join(targets)
72
73
  if len(coords):
73
- emit.writeln(frame, f"DETECTOR({coord_str}) {target_str}")
74
+ frame.write_line(f"DETECTOR({coord_str}) {target_str}")
74
75
  else:
75
- emit.writeln(frame, f"DETECTOR {target_str}")
76
+ frame.write_line(f"DETECTOR {target_str}")
76
77
  return ()
77
78
 
78
79
  @impl(stmts.ObservableInclude)
79
80
  def obs_include(
80
- self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ObservableInclude
81
+ self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.ObservableInclude
81
82
  ):
82
83
 
83
84
  idx: str = frame.get(stmt.idx)
84
85
  targets: tuple[str, ...] = frame.get_values(stmt.targets)
85
86
 
86
87
  target_str: str = " ".join(targets)
87
- emit.writeln(frame, f"OBSERVABLE_INCLUDE({idx}) {target_str}")
88
+ frame.write_line(f"OBSERVABLE_INCLUDE({idx}) {target_str}")
88
89
 
89
90
  return ()
90
91
 
91
92
  @impl(stmts.NewPauliString)
92
93
  def new_paulistr(
93
- self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.NewPauliString
94
+ self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.NewPauliString
94
95
  ):
95
96
 
96
97
  string: tuple[str, ...] = frame.get_values(stmt.string)
@@ -105,13 +106,13 @@ class EmitStimAuxMethods(MethodTable):
105
106
 
106
107
  @impl(stmts.QubitCoordinates)
107
108
  def qubit_coordinates(
108
- self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.QubitCoordinates
109
+ self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.QubitCoordinates
109
110
  ):
110
111
 
111
112
  coords: tuple[str, ...] = frame.get_values(stmt.coord)
112
113
  target: str = frame.get(stmt.target)
113
114
 
114
115
  coord_str: str = ", ".join(coords)
115
- emit.writeln(frame, f"QUBIT_COORDS({coord_str}) {target}")
116
+ frame.write_line(f"QUBIT_COORDS({coord_str}) {target}")
116
117
 
117
118
  return ()
@@ -1,7 +1,6 @@
1
- from kirin.emit import EmitStrFrame
2
1
  from kirin.interp import MethodTable, impl
3
2
 
4
- from bloqade.stim.emit.stim_str import EmitStimMain
3
+ from bloqade.stim.emit.stim_str import EmitStimMain, EmitStimFrame
5
4
 
6
5
  from . import stmts
7
6
  from ._dialect import dialect
@@ -27,13 +26,13 @@ class EmitStimCollapseMethods(MethodTable):
27
26
  @impl(stmts.MXX)
28
27
  @impl(stmts.MYY)
29
28
  @impl(stmts.MZZ)
30
- def get_measure(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: Measurement):
29
+ def get_measure(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: Measurement):
31
30
 
32
31
  probability: str = frame.get(stmt.p)
33
32
  targets: tuple[str, ...] = frame.get_values(stmt.targets)
34
33
 
35
34
  out = f"{self.meas_map[stmt.name]}({probability}) " + " ".join(targets)
36
- emit.writeln(frame, out)
35
+ frame.write_line(out)
37
36
 
38
37
  return ()
39
38
 
@@ -46,18 +45,18 @@ class EmitStimCollapseMethods(MethodTable):
46
45
  @impl(stmts.RX)
47
46
  @impl(stmts.RY)
48
47
  @impl(stmts.RZ)
49
- def get_reset(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: Reset):
48
+ def get_reset(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: Reset):
50
49
 
51
50
  targets: tuple[str, ...] = frame.get_values(stmt.targets)
52
51
 
53
52
  out = f"{self.reset_map[stmt.name]} " + " ".join(targets)
54
- emit.writeln(frame, out)
53
+ frame.write_line(out)
55
54
 
56
55
  return ()
57
56
 
58
57
  @impl(stmts.PPMeasurement)
59
58
  def pp_measure(
60
- self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.PPMeasurement
59
+ self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.PPMeasurement
61
60
  ):
62
61
  probability: str = frame.get(stmt.p)
63
62
  targets: tuple[str, ...] = tuple(
@@ -65,6 +64,6 @@ class EmitStimCollapseMethods(MethodTable):
65
64
  )
66
65
 
67
66
  out = f"MPP({probability}) " + " ".join(targets)
68
- emit.writeln(frame, out)
67
+ frame.write_line(out)
69
68
 
70
69
  return ()
@@ -1,7 +1,6 @@
1
- from kirin.emit import EmitStrFrame
2
1
  from kirin.interp import MethodTable, impl
3
2
 
4
- from bloqade.stim.emit.stim_str import EmitStimMain
3
+ from bloqade.stim.emit.stim_str import EmitStimMain, EmitStimFrame
5
4
 
6
5
  from . import stmts
7
6
  from ._dialect import dialect
@@ -33,11 +32,11 @@ class EmitStimGateMethods(MethodTable):
33
32
  @impl(stmts.SqrtY)
34
33
  @impl(stmts.SqrtZ)
35
34
  def single_qubit_gate(
36
- self, emit: EmitStimMain, frame: EmitStrFrame, stmt: SingleQubitGate
35
+ self, emit: EmitStimMain, frame: EmitStimFrame, stmt: SingleQubitGate
37
36
  ):
38
37
  targets: tuple[str, ...] = frame.get_values(stmt.targets)
39
38
  res = f"{self.gate_1q_map[stmt.name][int(stmt.dagger)]} " + " ".join(targets)
40
- emit.writeln(frame, res)
39
+ frame.write_line(res)
41
40
 
42
41
  return ()
43
42
 
@@ -47,13 +46,13 @@ class EmitStimGateMethods(MethodTable):
47
46
 
48
47
  @impl(stmts.Swap)
49
48
  def two_qubit_gate(
50
- self, emit: EmitStimMain, frame: EmitStrFrame, stmt: ControlledTwoQubitGate
49
+ self, emit: EmitStimMain, frame: EmitStimFrame, stmt: ControlledTwoQubitGate
51
50
  ):
52
51
  targets: tuple[str, ...] = frame.get_values(stmt.targets)
53
52
  res = f"{self.gate_ctrl_2q_map[stmt.name][int(stmt.dagger)]} " + " ".join(
54
53
  targets
55
54
  )
56
- emit.writeln(frame, res)
55
+ frame.write_line(res)
57
56
 
58
57
  return ()
59
58
 
@@ -68,19 +67,19 @@ class EmitStimGateMethods(MethodTable):
68
67
  @impl(stmts.CY)
69
68
  @impl(stmts.CZ)
70
69
  def ctrl_two_qubit_gate(
71
- self, emit: EmitStimMain, frame: EmitStrFrame, stmt: ControlledTwoQubitGate
70
+ self, emit: EmitStimMain, frame: EmitStimFrame, stmt: ControlledTwoQubitGate
72
71
  ):
73
72
  controls: tuple[str, ...] = frame.get_values(stmt.controls)
74
73
  targets: tuple[str, ...] = frame.get_values(stmt.targets)
75
74
  res = f"{self.gate_ctrl_2q_map[stmt.name][int(stmt.dagger)]} " + " ".join(
76
75
  f"{ctrl} {tgt}" for ctrl, tgt in zip(controls, targets)
77
76
  )
78
- emit.writeln(frame, res)
77
+ frame.write_line(res)
79
78
 
80
79
  return ()
81
80
 
82
81
  @impl(stmts.SPP)
83
- def spp(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.SPP):
82
+ def spp(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.SPP):
84
83
 
85
84
  targets: tuple[str, ...] = tuple(
86
85
  targ.upper() for targ in frame.get_values(stmt.targets)
@@ -89,6 +88,6 @@ class EmitStimGateMethods(MethodTable):
89
88
  res = "SPP_DAG " + " ".join(targets)
90
89
  else:
91
90
  res = "SPP " + " ".join(targets)
92
- emit.writeln(frame, res)
91
+ frame.write_line(res)
93
92
 
94
93
  return ()
@@ -1,7 +1,6 @@
1
- from kirin.emit import EmitStrFrame
2
1
  from kirin.interp import MethodTable, impl
3
2
 
4
- from bloqade.stim.emit.stim_str import EmitStimMain
3
+ from bloqade.stim.emit.stim_str import EmitStimMain, EmitStimFrame
5
4
 
6
5
  from . import stmts
7
6
  from ._dialect import dialect
@@ -24,20 +23,20 @@ class EmitStimNoiseMethods(MethodTable):
24
23
  @impl(stmts.Depolarize1)
25
24
  @impl(stmts.Depolarize2)
26
25
  def single_p_error(
27
- self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.Depolarize1
26
+ self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.Depolarize1
28
27
  ):
29
28
 
30
29
  targets: tuple[str, ...] = frame.get_values(stmt.targets)
31
30
  p: str = frame.get(stmt.p)
32
31
  name = self.single_p_error_map[stmt.name]
33
32
  res = f"{name}({p}) " + " ".join(targets)
34
- emit.writeln(frame, res)
33
+ frame.write_line(res)
35
34
 
36
35
  return ()
37
36
 
38
37
  @impl(stmts.PauliChannel1)
39
38
  def pauli_channel1(
40
- self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.PauliChannel1
39
+ self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.PauliChannel1
41
40
  ):
42
41
 
43
42
  targets: tuple[str, ...] = frame.get_values(stmt.targets)
@@ -45,13 +44,13 @@ class EmitStimNoiseMethods(MethodTable):
45
44
  py: str = frame.get(stmt.py)
46
45
  pz: str = frame.get(stmt.pz)
47
46
  res = f"PAULI_CHANNEL_1({px}, {py}, {pz}) " + " ".join(targets)
48
- emit.writeln(frame, res)
47
+ frame.write_line(res)
49
48
 
50
49
  return ()
51
50
 
52
51
  @impl(stmts.PauliChannel2)
53
52
  def pauli_channel2(
54
- self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.PauliChannel2
53
+ self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.PauliChannel2
55
54
  ):
56
55
 
57
56
  targets: tuple[str, ...] = frame.get_values(stmt.targets)
@@ -61,14 +60,14 @@ class EmitStimNoiseMethods(MethodTable):
61
60
  prob_str: str = ", ".join(prob)
62
61
 
63
62
  res = f"PAULI_CHANNEL_2({prob_str}) " + " ".join(targets)
64
- emit.writeln(frame, res)
63
+ frame.write_line(res)
65
64
 
66
65
  return ()
67
66
 
68
67
  @impl(stmts.TrivialError)
69
68
  @impl(stmts.QubitLoss)
70
69
  def non_stim_error(
71
- self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.TrivialError
70
+ self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.TrivialError
72
71
  ):
73
72
 
74
73
  targets: tuple[str, ...] = frame.get_values(stmt.targets)
@@ -76,15 +75,16 @@ class EmitStimNoiseMethods(MethodTable):
76
75
  prob_str: str = ", ".join(prob)
77
76
 
78
77
  res = f"I_ERROR[{stmt.name}]({prob_str}) " + " ".join(targets)
79
- emit.writeln(frame, res)
78
+ frame.write_line(res)
80
79
 
81
80
  return ()
82
81
 
83
82
  @impl(stmts.TrivialCorrelatedError)
83
+ @impl(stmts.CorrelatedQubitLoss)
84
84
  def non_stim_corr_error(
85
85
  self,
86
86
  emit: EmitStimMain,
87
- frame: EmitStrFrame,
87
+ frame: EmitStimFrame,
88
88
  stmt: stmts.TrivialCorrelatedError,
89
89
  ):
90
90
 
@@ -92,7 +92,11 @@ class EmitStimNoiseMethods(MethodTable):
92
92
  prob: tuple[str, ...] = frame.get_values(stmt.probs)
93
93
  prob_str: str = ", ".join(prob)
94
94
 
95
- res = f"I_ERROR[{stmt.name}:{stmt.nonce}]({prob_str}) " + " ".join(targets)
96
- emit.writeln(frame, res)
95
+ res = (
96
+ f"I_ERROR[{stmt.name}:{emit.correlated_error_count}]({prob_str}) "
97
+ + " ".join(targets)
98
+ )
99
+ emit.correlated_error_count += 1
100
+ frame.write_line(res)
97
101
 
98
102
  return ()
@@ -89,9 +89,6 @@ class NonStimError(ir.Statement):
89
89
  class NonStimCorrelatedError(ir.Statement):
90
90
  name = "NonStimCorrelatedError"
91
91
  traits = frozenset({lowering.FromPythonCall()})
92
- nonce: int = (
93
- info.attribute()
94
- ) # Must be a unique value, otherwise stim might merge two correlated errors with equal probabilities
95
92
  probs: tuple[ir.SSAValue, ...] = info.argument(types.Float)
96
93
  targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
97
94
 
@@ -109,3 +106,8 @@ class TrivialError(NonStimError):
109
106
  @statement(dialect=dialect)
110
107
  class QubitLoss(NonStimError):
111
108
  name = "loss"
109
+
110
+
111
+ @statement(dialect=dialect)
112
+ class CorrelatedQubitLoss(NonStimCorrelatedError):
113
+ name = "correlated_loss"
@@ -1 +1,2 @@
1
+ from . import impls as impls
1
2
  from .stim_str import FuncEmit as FuncEmit, EmitStimMain as EmitStimMain
@@ -0,0 +1,16 @@
1
+ from kirin.interp import MethodTable, impl
2
+ from kirin.dialects.debug import Info, dialect
3
+
4
+ from bloqade.stim.emit.stim_str import EmitStimMain, EmitStimFrame
5
+
6
+
7
+ @dialect.register(key="emit.stim")
8
+ class EmitStimDebugMethods(MethodTable):
9
+
10
+ @impl(Info)
11
+ def info(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: Info):
12
+
13
+ msg: str = frame.get(stmt.msg)
14
+ frame.write_line(f"# {msg}")
15
+
16
+ return ()
@@ -1,54 +1,71 @@
1
- from io import StringIO
2
- from typing import IO, TypeVar
3
- from dataclasses import field, dataclass
1
+ import sys
2
+ from typing import IO, Generic, TypeVar, cast
3
+ from dataclasses import dataclass
4
4
 
5
5
  from kirin import ir, interp
6
- from kirin.emit import EmitStr, EmitStrFrame
7
6
  from kirin.dialects import func
7
+ from kirin.emit.abc import EmitABC, EmitFrame
8
8
 
9
9
  IO_t = TypeVar("IO_t", bound=IO)
10
10
 
11
11
 
12
- def _default_dialect_group() -> ir.DialectGroup:
13
- from ..groups import main
12
+ @dataclass
13
+ class EmitStimFrame(EmitFrame[str], Generic[IO_t]):
14
+ io: IO_t = cast(IO_t, sys.stdout)
15
+
16
+ def write(self, value: str) -> None:
17
+ self.io.write(value)
14
18
 
15
- return main
19
+ def write_line(self, value: str) -> None:
20
+ self.write(" " * self._indent + value + "\n")
16
21
 
17
22
 
18
23
  @dataclass
19
- class EmitStimMain(EmitStr):
20
- keys = ["emit.stim"]
21
- dialects: ir.DialectGroup = field(default_factory=_default_dialect_group)
22
- file: StringIO = field(default_factory=StringIO)
24
+ class EmitStimMain(EmitABC[EmitStimFrame, str], Generic[IO_t]):
25
+ io: IO_t = cast(IO_t, sys.stdout)
26
+ keys = ("emit.stim",)
27
+ void = ""
28
+ correlation_identifier_offset: int = 0
23
29
 
24
- def initialize(self):
30
+ def initialize(self) -> "EmitStimMain":
25
31
  super().initialize()
26
- self.file.truncate(0)
27
- self.file.seek(0)
32
+ self.correlated_error_count = self.correlation_identifier_offset
28
33
  return self
29
34
 
30
- def eval_stmt_fallback(
31
- self, frame: EmitStrFrame, stmt: ir.Statement
32
- ) -> tuple[str, ...]:
33
- return (stmt.name,)
35
+ def initialize_frame(
36
+ self, node: ir.Statement, *, has_parent_access: bool = False
37
+ ) -> EmitStimFrame:
38
+ return EmitStimFrame(node, self.io, has_parent_access=has_parent_access)
39
+
40
+ def frame_call(
41
+ self, frame: EmitStimFrame, node: ir.Statement, *args: str, **kwargs: str
42
+ ) -> str:
43
+ return f"{args[0]}({', '.join(args[1:])})"
34
44
 
35
- def emit_block(self, frame: EmitStrFrame, block: ir.Block) -> str | None:
36
- for stmt in block.stmts:
37
- result = self.eval_stmt(frame, stmt)
38
- if isinstance(result, tuple):
39
- frame.set_values(stmt.results, result)
40
- return None
45
+ def get_attribute(self, frame: EmitStimFrame, node: ir.Attribute) -> str:
46
+ method = self.registry.get(interp.Signature(type(node)))
47
+ if method is None:
48
+ raise ValueError(f"Method not found for node: {node}")
49
+ return method(self, frame, node)
41
50
 
42
- def get_output(self) -> str:
43
- self.file.seek(0)
44
- return self.file.read()
51
+ def reset(self):
52
+ self.io.truncate(0)
53
+ self.io.seek(0)
54
+
55
+ def eval_fallback(self, frame: EmitStimFrame, node: ir.Statement) -> tuple:
56
+ return tuple("" for _ in range(len(node.results)))
45
57
 
46
58
 
47
59
  @func.dialect.register(key="emit.stim")
48
60
  class FuncEmit(interp.MethodTable):
49
-
50
61
  @interp.impl(func.Function)
51
- def emit_func(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: func.Function):
52
- _ = emit.run_ssacfg_region(frame, stmt.body, ())
53
- # emit.output = "\n".join(frame.body)
62
+ def emit_func(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: func.Function):
63
+ for block in stmt.body.blocks:
64
+ frame.current_block = block
65
+ for stmt_ in block.stmts:
66
+ frame.current_stmt = stmt_
67
+ res = emit.frame_eval(frame, stmt_)
68
+ if isinstance(res, tuple):
69
+ frame.set_values(stmt_.results, res)
70
+
54
71
  return ()
bloqade/stim/groups.py CHANGED
@@ -1,12 +1,22 @@
1
1
  from kirin import ir
2
2
  from kirin.passes import Fold, TypeInfer
3
- from kirin.dialects import func, lowering
3
+ from kirin.dialects import func, debug, ssacfg, lowering
4
4
 
5
5
  from .dialects import gate, noise, collapse, auxiliary
6
6
 
7
7
 
8
8
  @ir.dialect_group(
9
- [noise, gate, auxiliary, collapse, func, lowering.func, lowering.call]
9
+ [
10
+ noise,
11
+ gate,
12
+ auxiliary,
13
+ collapse,
14
+ func,
15
+ lowering.func,
16
+ lowering.call,
17
+ debug,
18
+ ssacfg,
19
+ ]
10
20
  )
11
21
  def main(self):
12
22
  typeinfer_pass = TypeInfer(self)