bloqade-circuit 0.6.2__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192) hide show
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/__init__.py +4 -1
  7. bloqade/analysis/measure_id/analysis.py +29 -20
  8. bloqade/analysis/measure_id/impls.py +72 -31
  9. bloqade/annotate/__init__.py +6 -0
  10. bloqade/annotate/_dialect.py +3 -0
  11. bloqade/annotate/_interface.py +22 -0
  12. bloqade/annotate/stmts.py +29 -0
  13. bloqade/annotate/types.py +13 -0
  14. bloqade/cirq_utils/__init__.py +4 -2
  15. bloqade/cirq_utils/emit/__init__.py +3 -0
  16. bloqade/cirq_utils/emit/base.py +246 -0
  17. bloqade/cirq_utils/emit/gate.py +104 -0
  18. bloqade/cirq_utils/emit/noise.py +90 -0
  19. bloqade/cirq_utils/emit/qubit.py +35 -0
  20. bloqade/cirq_utils/lowering.py +660 -0
  21. bloqade/cirq_utils/noise/__init__.py +0 -2
  22. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  23. bloqade/cirq_utils/noise/model.py +151 -191
  24. bloqade/cirq_utils/noise/transform.py +2 -2
  25. bloqade/cirq_utils/parallelize.py +9 -6
  26. bloqade/gemini/__init__.py +1 -0
  27. bloqade/gemini/analysis/__init__.py +3 -0
  28. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  29. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  30. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  31. bloqade/gemini/groups.py +67 -0
  32. bloqade/native/__init__.py +23 -0
  33. bloqade/native/_prelude.py +45 -0
  34. bloqade/native/dialects/__init__.py +0 -0
  35. bloqade/native/dialects/gate/__init__.py +2 -0
  36. bloqade/native/dialects/gate/_dialect.py +3 -0
  37. bloqade/native/dialects/gate/_interface.py +32 -0
  38. bloqade/native/dialects/gate/stmts.py +31 -0
  39. bloqade/native/stdlib/__init__.py +0 -0
  40. bloqade/native/stdlib/broadcast.py +246 -0
  41. bloqade/native/stdlib/simple.py +220 -0
  42. bloqade/native/upstream/__init__.py +4 -0
  43. bloqade/native/upstream/squin2native.py +79 -0
  44. bloqade/pyqrack/__init__.py +2 -2
  45. bloqade/pyqrack/base.py +7 -1
  46. bloqade/pyqrack/device.py +190 -4
  47. bloqade/pyqrack/native.py +49 -0
  48. bloqade/pyqrack/reg.py +6 -6
  49. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  50. bloqade/pyqrack/squin/gate/gate.py +136 -0
  51. bloqade/pyqrack/squin/noise/native.py +120 -54
  52. bloqade/pyqrack/squin/qubit.py +39 -36
  53. bloqade/pyqrack/target.py +5 -4
  54. bloqade/pyqrack/task.py +114 -7
  55. bloqade/qasm2/_qasm_loading.py +3 -3
  56. bloqade/qasm2/dialects/core/address.py +21 -12
  57. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  58. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  59. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  60. bloqade/qasm2/dialects/noise/model.py +2 -1
  61. bloqade/qasm2/emit/base.py +16 -11
  62. bloqade/qasm2/emit/gate.py +11 -8
  63. bloqade/qasm2/emit/main.py +103 -3
  64. bloqade/qasm2/emit/target.py +9 -5
  65. bloqade/qasm2/groups.py +3 -2
  66. bloqade/qasm2/parse/lowering.py +0 -1
  67. bloqade/qasm2/passes/fold.py +14 -73
  68. bloqade/qasm2/passes/glob.py +2 -2
  69. bloqade/qasm2/passes/noise.py +1 -1
  70. bloqade/qasm2/passes/parallel.py +7 -5
  71. bloqade/qasm2/rewrite/__init__.py +0 -1
  72. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  73. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  74. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  75. bloqade/qasm2/rewrite/register.py +2 -2
  76. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  77. bloqade/qbraid/lowering.py +1 -0
  78. bloqade/qbraid/schema.py +2 -2
  79. bloqade/qubit/__init__.py +12 -0
  80. bloqade/qubit/_dialect.py +3 -0
  81. bloqade/qubit/_interface.py +49 -0
  82. bloqade/qubit/_prelude.py +45 -0
  83. bloqade/qubit/analysis/__init__.py +1 -0
  84. bloqade/qubit/analysis/address_impl.py +40 -0
  85. bloqade/qubit/stdlib/__init__.py +2 -0
  86. bloqade/qubit/stdlib/_new.py +34 -0
  87. bloqade/qubit/stdlib/broadcast.py +62 -0
  88. bloqade/qubit/stdlib/simple.py +59 -0
  89. bloqade/qubit/stmts.py +60 -0
  90. bloqade/rewrite/passes/__init__.py +6 -0
  91. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  92. bloqade/rewrite/passes/callgraph.py +116 -0
  93. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  94. bloqade/rewrite/rules/split_ifs.py +18 -1
  95. bloqade/squin/__init__.py +47 -14
  96. bloqade/squin/analysis/__init__.py +0 -1
  97. bloqade/squin/analysis/schedule.py +10 -11
  98. bloqade/squin/gate/__init__.py +2 -0
  99. bloqade/squin/gate/_dialect.py +3 -0
  100. bloqade/squin/gate/_interface.py +98 -0
  101. bloqade/squin/gate/stmts.py +125 -0
  102. bloqade/squin/groups.py +5 -22
  103. bloqade/squin/noise/__init__.py +1 -10
  104. bloqade/squin/noise/_dialect.py +1 -1
  105. bloqade/squin/noise/_interface.py +45 -0
  106. bloqade/squin/noise/stmts.py +66 -28
  107. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  108. bloqade/squin/rewrite/__init__.py +0 -2
  109. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  110. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  111. bloqade/squin/stdlib/__init__.py +0 -0
  112. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  113. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  114. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  115. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  116. bloqade/squin/stdlib/simple/__init__.py +33 -0
  117. bloqade/squin/stdlib/simple/gate.py +242 -0
  118. bloqade/squin/stdlib/simple/noise.py +126 -0
  119. bloqade/stim/__init__.py +1 -0
  120. bloqade/stim/_wrappers.py +6 -0
  121. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  122. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  123. bloqade/stim/dialects/gate/emit.py +9 -10
  124. bloqade/stim/dialects/noise/emit.py +17 -13
  125. bloqade/stim/dialects/noise/stmts.py +5 -3
  126. bloqade/stim/emit/__init__.py +1 -0
  127. bloqade/stim/emit/impls.py +16 -0
  128. bloqade/stim/emit/stim_str.py +48 -31
  129. bloqade/stim/groups.py +12 -2
  130. bloqade/stim/parse/lowering.py +14 -17
  131. bloqade/stim/passes/__init__.py +3 -1
  132. bloqade/stim/passes/flatten.py +26 -0
  133. bloqade/stim/passes/simplify_ifs.py +16 -2
  134. bloqade/stim/passes/squin_to_stim.py +18 -60
  135. bloqade/stim/rewrite/__init__.py +3 -4
  136. bloqade/stim/rewrite/get_record_util.py +24 -0
  137. bloqade/stim/rewrite/ifs_to_stim.py +29 -31
  138. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  139. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  140. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  141. bloqade/stim/rewrite/squin_measure.py +11 -79
  142. bloqade/stim/rewrite/squin_noise.py +134 -108
  143. bloqade/stim/rewrite/util.py +5 -192
  144. bloqade/test_utils.py +1 -1
  145. bloqade/types.py +10 -0
  146. bloqade/validation/__init__.py +2 -0
  147. bloqade/validation/analysis/__init__.py +5 -0
  148. bloqade/validation/analysis/analysis.py +41 -0
  149. bloqade/validation/analysis/lattice.py +58 -0
  150. bloqade/validation/kernel_validation.py +77 -0
  151. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  152. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  153. bloqade/pyqrack/squin/op.py +0 -166
  154. bloqade/pyqrack/squin/runtime.py +0 -535
  155. bloqade/pyqrack/squin/wire.py +0 -51
  156. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  157. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  158. bloqade/squin/_typeinfer.py +0 -20
  159. bloqade/squin/analysis/address_impl.py +0 -71
  160. bloqade/squin/analysis/nsites/__init__.py +0 -9
  161. bloqade/squin/analysis/nsites/analysis.py +0 -50
  162. bloqade/squin/analysis/nsites/impls.py +0 -92
  163. bloqade/squin/analysis/nsites/lattice.py +0 -49
  164. bloqade/squin/cirq/__init__.py +0 -265
  165. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  166. bloqade/squin/cirq/emit/noise.py +0 -49
  167. bloqade/squin/cirq/emit/op.py +0 -125
  168. bloqade/squin/cirq/emit/qubit.py +0 -60
  169. bloqade/squin/cirq/emit/runtime.py +0 -242
  170. bloqade/squin/cirq/lowering.py +0 -440
  171. bloqade/squin/lowering.py +0 -54
  172. bloqade/squin/noise/_wrapper.py +0 -40
  173. bloqade/squin/noise/rewrite.py +0 -111
  174. bloqade/squin/op/__init__.py +0 -41
  175. bloqade/squin/op/_dialect.py +0 -3
  176. bloqade/squin/op/_wrapper.py +0 -121
  177. bloqade/squin/op/number.py +0 -5
  178. bloqade/squin/op/rewrite.py +0 -46
  179. bloqade/squin/op/stdlib.py +0 -62
  180. bloqade/squin/op/stmts.py +0 -276
  181. bloqade/squin/op/traits.py +0 -43
  182. bloqade/squin/op/types.py +0 -26
  183. bloqade/squin/qubit.py +0 -184
  184. bloqade/squin/rewrite/canonicalize.py +0 -60
  185. bloqade/squin/rewrite/desugar.py +0 -124
  186. bloqade/squin/types.py +0 -8
  187. bloqade/squin/wire.py +0 -201
  188. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  189. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  190. bloqade_circuit-0.6.2.dist-info/RECORD +0 -234
  191. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  192. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,68 @@
1
+ from typing import Iterable
2
+ from dataclasses import dataclass
3
+
4
+ from kirin import ir
5
+ from kirin.dialects.py import Constant
6
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
7
+
8
+ from bloqade.stim.dialects import auxiliary
9
+ from bloqade.annotate.stmts import SetDetector
10
+ from bloqade.analysis.measure_id import MeasureIDFrame
11
+ from bloqade.stim.dialects.auxiliary import Detector
12
+ from bloqade.analysis.measure_id.lattice import MeasureIdTuple
13
+
14
+ from ..rewrite.get_record_util import insert_get_records
15
+
16
+
17
+ @dataclass
18
+ class SetDetectorToStim(RewriteRule):
19
+ """
20
+ Rewrite SetDetector to GetRecord and Detector in the stim dialect
21
+ """
22
+
23
+ measure_id_frame: MeasureIDFrame
24
+
25
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
26
+ match node:
27
+ case SetDetector():
28
+ return self.rewrite_SetDetector(node)
29
+ case _:
30
+ return RewriteResult()
31
+
32
+ def rewrite_SetDetector(self, node: SetDetector) -> RewriteResult:
33
+
34
+ # get coordinates and generate correct consts
35
+ coord_ssas = []
36
+ if not isinstance(node.coordinates.owner, Constant):
37
+ return RewriteResult()
38
+
39
+ coord_values = node.coordinates.owner.value.unwrap()
40
+
41
+ if not isinstance(coord_values, Iterable):
42
+ return RewriteResult()
43
+
44
+ if any(not isinstance(value, (int, float)) for value in coord_values):
45
+ return RewriteResult()
46
+
47
+ for coord_value in coord_values:
48
+ if isinstance(coord_value, float):
49
+ coord_stmt = auxiliary.ConstFloat(value=coord_value)
50
+ else: # int
51
+ coord_stmt = auxiliary.ConstInt(value=coord_value)
52
+ coord_ssas.append(coord_stmt.result)
53
+ coord_stmt.insert_before(node)
54
+
55
+ measure_ids = self.measure_id_frame.entries[node.measurements]
56
+ assert isinstance(measure_ids, MeasureIdTuple)
57
+
58
+ get_record_list = insert_get_records(
59
+ node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node]
60
+ )
61
+
62
+ detector_stmt = Detector(
63
+ coord=tuple(coord_ssas), targets=tuple(get_record_list)
64
+ )
65
+
66
+ node.replace_by(detector_stmt)
67
+
68
+ return RewriteResult(has_done_something=True)
@@ -0,0 +1,52 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
5
+
6
+ from bloqade.stim.dialects import auxiliary
7
+ from bloqade.annotate.stmts import SetObservable
8
+ from bloqade.analysis.measure_id import MeasureIDFrame
9
+ from bloqade.stim.dialects.auxiliary import ObservableInclude
10
+ from bloqade.analysis.measure_id.lattice import MeasureIdTuple
11
+
12
+ from ..rewrite.get_record_util import insert_get_records
13
+
14
+
15
+ @dataclass
16
+ class SetObservableToStim(RewriteRule):
17
+ """
18
+ Rewrite SetObservable to GetRecord and ObservableInclude in the stim dialect
19
+ """
20
+
21
+ measure_id_frame: MeasureIDFrame
22
+
23
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
24
+ match node:
25
+ case SetObservable():
26
+ return self.rewrite_SetObservable(node)
27
+ case _:
28
+ return RewriteResult()
29
+
30
+ def rewrite_SetObservable(self, node: SetObservable) -> RewriteResult:
31
+
32
+ # set idx to 0 for now, but this
33
+ # should be something that a user can set on their own.
34
+ # SetObservable needs to accept an int.
35
+
36
+ idx_stmt = auxiliary.ConstInt(value=0)
37
+ idx_stmt.insert_before(node)
38
+
39
+ measure_ids = self.measure_id_frame.entries[node.measurements]
40
+ assert isinstance(measure_ids, MeasureIdTuple)
41
+
42
+ get_record_list = insert_get_records(
43
+ node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node]
44
+ )
45
+
46
+ observable_include_stmt = ObservableInclude(
47
+ idx=idx_stmt.result, targets=tuple(get_record_list)
48
+ )
49
+
50
+ node.replace_by(observable_include_stmt)
51
+
52
+ return RewriteResult(has_done_something=True)
@@ -2,47 +2,15 @@
2
2
  from dataclasses import dataclass
3
3
 
4
4
  from kirin import ir
5
- from kirin.dialects import py, ilist
5
+ from kirin.dialects import py
6
6
  from kirin.rewrite.abc import RewriteRule, RewriteResult
7
7
 
8
- from bloqade.squin import wire, qubit
8
+ from bloqade import qubit
9
9
  from bloqade.squin.rewrite import AddressAttribute
10
- from bloqade.stim.dialects import collapse, auxiliary
10
+ from bloqade.stim.dialects import collapse
11
11
  from bloqade.stim.rewrite.util import (
12
- is_measure_result_used,
13
12
  insert_qubit_idx_from_address,
14
13
  )
15
- from bloqade.analysis.measure_id.lattice import MeasureId, MeasureIdBool, MeasureIdTuple
16
-
17
-
18
- def replace_get_record(
19
- node: ir.Statement, measure_id_bool: MeasureIdBool, meas_count: int
20
- ):
21
- assert isinstance(measure_id_bool, MeasureIdBool)
22
- target_rec_idx = (measure_id_bool.idx - 1) - meas_count
23
- idx_stmt = py.constant.Constant(target_rec_idx)
24
- idx_stmt.insert_before(node)
25
- get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
26
- node.replace_by(get_record_stmt)
27
-
28
-
29
- def insert_get_record_list(
30
- node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count: int
31
- ):
32
- """
33
- Insert GetRecord statements before the given node
34
- """
35
- get_record_ssas = []
36
- for measure_id_bool in measure_id_tuple.data:
37
- assert isinstance(measure_id_bool, MeasureIdBool)
38
- target_rec_idx = (measure_id_bool.idx - 1) - meas_count
39
- idx_stmt = py.constant.Constant(target_rec_idx)
40
- idx_stmt.insert_before(node)
41
- get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
42
- get_record_stmt.insert_before(node)
43
- get_record_ssas.append(get_record_stmt.result)
44
-
45
- node.replace_by(ilist.New(values=get_record_ssas))
46
14
 
47
15
 
48
16
  @dataclass
@@ -51,29 +19,20 @@ class SquinMeasureToStim(RewriteRule):
51
19
  Rewrite squin measure-related statements to stim statements.
52
20
  """
53
21
 
54
- measure_id_result: dict[ir.SSAValue, MeasureId]
55
- total_measure_count: int
56
-
57
22
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
58
23
 
59
24
  match node:
60
- case qubit.MeasureQubit() | qubit.MeasureQubitList() | wire.Measure():
25
+ case qubit.stmts.Measure():
61
26
  return self.rewrite_Measure(node)
62
27
  case _:
63
28
  return RewriteResult()
64
29
 
65
- def rewrite_Measure(
66
- self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
67
- ) -> RewriteResult:
30
+ def rewrite_Measure(self, measure_stmt: qubit.stmts.Measure) -> RewriteResult:
68
31
 
69
32
  qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt)
70
33
  if qubit_idx_ssas is None:
71
34
  return RewriteResult()
72
35
 
73
- measure_id = self.measure_id_result[measure_stmt.result]
74
- if not isinstance(measure_id, (MeasureIdBool, MeasureIdTuple)):
75
- return RewriteResult()
76
-
77
36
  prob_noise_stmt = py.constant.Constant(0.0)
78
37
  stim_measure_stmt = collapse.MZ(
79
38
  p=prob_noise_stmt.result,
@@ -82,48 +41,21 @@ class SquinMeasureToStim(RewriteRule):
82
41
  prob_noise_stmt.insert_before(measure_stmt)
83
42
  stim_measure_stmt.insert_before(measure_stmt)
84
43
 
85
- if not is_measure_result_used(measure_stmt):
44
+ # if the measurement is not being used anywhere
45
+ # we can safely get rid of it. Measure cannot be DCE'd because
46
+ # it is not pure.
47
+ if not bool(measure_stmt.result.uses):
86
48
  measure_stmt.delete()
87
- return RewriteResult(has_done_something=True)
88
-
89
- # replace dataflow with new stmt!
90
- measure_id = self.measure_id_result[measure_stmt.result]
91
- if isinstance(measure_id, MeasureIdBool):
92
- replace_get_record(
93
- node=measure_stmt,
94
- measure_id_bool=measure_id,
95
- meas_count=self.total_measure_count,
96
- )
97
- elif isinstance(measure_id, MeasureIdTuple):
98
- insert_get_record_list(
99
- node=measure_stmt,
100
- measure_id_tuple=measure_id,
101
- meas_count=self.total_measure_count,
102
- )
103
- else:
104
- # already checked before, so this should not happen
105
- raise ValueError(
106
- f"Unexpected measure ID type: {type(measure_id)} for measure statement {measure_stmt}"
107
- )
108
49
 
109
50
  return RewriteResult(has_done_something=True)
110
51
 
111
52
  def get_qubit_idx_ssas(
112
- self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
53
+ self, measure_stmt: qubit.stmts.Measure
113
54
  ) -> tuple[ir.SSAValue, ...] | None:
114
55
  """
115
56
  Extract the address attribute and insert qubit indices for the given measure statement.
116
57
  """
117
- match measure_stmt:
118
- case qubit.MeasureQubit():
119
- address_attr = measure_stmt.qubit.hints.get("address")
120
- case qubit.MeasureQubitList():
121
- address_attr = measure_stmt.qubits.hints.get("address")
122
- case wire.Measure():
123
- address_attr = measure_stmt.wire.hints.get("address")
124
- case _:
125
- return None
126
-
58
+ address_attr = measure_stmt.qubits.hints.get("address")
127
59
  if address_attr is None:
128
60
  return None
129
61
 
@@ -1,17 +1,17 @@
1
+ import itertools
1
2
  from typing import Tuple
2
3
  from dataclasses import dataclass
3
4
 
5
+ from kirin import types
4
6
  from kirin.ir import SSAValue, Statement
5
- from kirin.dialects import py, ilist
7
+ from kirin.dialects import py
6
8
  from kirin.rewrite.abc import RewriteRule, RewriteResult
7
9
 
8
- from bloqade.squin import op, wire, noise as squin_noise, qubit
10
+ from bloqade.squin import noise as squin_noise
9
11
  from bloqade.stim.dialects import noise as stim_noise
10
- from bloqade.stim.rewrite.util import (
11
- get_const_value,
12
- create_wire_passthrough,
13
- insert_qubit_idx_after_apply,
14
- )
12
+ from bloqade.stim.rewrite.util import insert_qubit_idx_from_address
13
+ from bloqade.analysis.address.lattice import AddressReg, PartialIList
14
+ from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
15
15
 
16
16
 
17
17
  @dataclass
@@ -19,157 +19,183 @@ class SquinNoiseToStim(RewriteRule):
19
19
 
20
20
  def rewrite_Statement(self, node: Statement) -> RewriteResult:
21
21
  match node:
22
- case qubit.Apply() | qubit.Broadcast():
23
- return self.rewrite_Apply_and_Broadcast(node)
22
+ case squin_noise.stmts.NoiseChannel():
23
+ return self.rewrite_NoiseChannel(node)
24
24
  case _:
25
25
  return RewriteResult()
26
26
 
27
- def rewrite_Apply_and_Broadcast(
28
- self, stmt: qubit.Apply | qubit.Broadcast
27
+ def rewrite_NoiseChannel(
28
+ self, stmt: squin_noise.stmts.NoiseChannel
29
29
  ) -> RewriteResult:
30
- """Rewrite Apply and Broadcast to their stim statements."""
30
+ """Rewrite NoiseChannel statements to their stim equivalents."""
31
31
 
32
- # this is an SSAValue, need it to be the actual operator
33
- applied_op = stmt.operator.owner
32
+ rewrite_method = getattr(self, f"rewrite_{type(stmt).__name__}", None)
34
33
 
35
- if isinstance(applied_op, squin_noise.stmts.QubitLoss):
34
+ # No rewrite method exists and the rewrite should stop
35
+ if rewrite_method is None:
36
36
  return RewriteResult()
37
+ if isinstance(stmt, squin_noise.stmts.CorrelatedQubitLoss):
38
+ # CorrelatedQubitLoss represents a broadcast operation, but Stim does not
39
+ # support broadcasting for multi-qubit noise channels.
40
+ # Therefore, we must expand the broadcast into individual stim statements.
41
+ qubit_address_attr = stmt.qubits.hints.get("address", None)
37
42
 
38
- if isinstance(applied_op, squin_noise.stmts.NoiseChannel):
43
+ if not isinstance(qubit_address_attr, AddressAttribute):
44
+ return RewriteResult()
45
+
46
+ if not isinstance(address := qubit_address_attr.address, PartialIList):
47
+ return RewriteResult()
39
48
 
40
- qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt)
41
- if qubit_idx_ssas is None:
49
+ if not types.is_tuple_of(data := address.data, AddressReg):
42
50
  return RewriteResult()
43
51
 
44
- rewrite_method = getattr(self, f"rewrite_{type(applied_op).__name__}")
45
- stim_stmt = rewrite_method(stmt, qubit_idx_ssas)
52
+ for address_reg in data:
46
53
 
47
- if isinstance(stmt, (wire.Apply, wire.Broadcast)):
48
- create_wire_passthrough(stmt)
54
+ qubit_idx_ssas = insert_qubit_idx_from_address(
55
+ AddressAttribute(address_reg), stmt
56
+ )
49
57
 
50
- if stim_stmt is not None:
51
- stmt.replace_by(stim_stmt)
52
- if len(stmt.operator.owner.result.uses) == 0:
53
- stmt.operator.owner.delete()
58
+ stim_stmt = rewrite_method(stmt, qubit_idx_ssas)
59
+ stim_stmt.insert_before(stmt)
60
+
61
+ stmt.delete()
54
62
 
55
63
  return RewriteResult(has_done_something=True)
56
- return RewriteResult()
57
64
 
58
- def rewrite_PauliError(
59
- self,
60
- stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
61
- qubit_idx_ssas: Tuple[SSAValue],
62
- ) -> Statement:
63
- """Rewrite squin.noise.PauliError to XError, YError, ZError."""
64
- squin_channel = stmt.operator.owner
65
- assert isinstance(squin_channel, squin_noise.stmts.PauliError)
66
- basis = squin_channel.basis.owner
67
- assert isinstance(basis, op.stmts.PauliOp)
68
- p = get_const_value(float, squin_channel.p)
69
-
70
- p_stmt = py.Constant(p)
71
- p_stmt.insert_before(stmt)
72
-
73
- if isinstance(basis, op.stmts.X):
74
- stim_stmt = stim_noise.XError(targets=qubit_idx_ssas, p=p_stmt.result)
75
- elif isinstance(basis, op.stmts.Y):
76
- stim_stmt = stim_noise.YError(targets=qubit_idx_ssas, p=p_stmt.result)
65
+ if isinstance(stmt, squin_noise.stmts.SingleQubitNoiseChannel):
66
+ qubit_address_attr = stmt.qubits.hints.get("address", None)
67
+ if qubit_address_attr is None:
68
+ return RewriteResult()
69
+ qubit_idx_ssas = insert_qubit_idx_from_address(qubit_address_attr, stmt)
70
+
71
+ elif isinstance(stmt, squin_noise.stmts.TwoQubitNoiseChannel):
72
+ control_address_attr = stmt.controls.hints.get("address", None)
73
+ target_address_attr = stmt.targets.hints.get("address", None)
74
+ if control_address_attr is None or target_address_attr is None:
75
+ return RewriteResult()
76
+ control_qubit_idx_ssas = insert_qubit_idx_from_address(
77
+ control_address_attr, stmt
78
+ )
79
+ target_qubit_idx_ssas = insert_qubit_idx_from_address(
80
+ target_address_attr, stmt
81
+ )
82
+ if control_qubit_idx_ssas is None or target_qubit_idx_ssas is None:
83
+ return RewriteResult()
84
+
85
+ # For stim statements you want to interleave the control and target qubit indices:
86
+ # ex: CX controls = (0,1) targets = (2,3) in stim is: CX 0 2 1 3
87
+ qubit_idx_ssas = list(
88
+ itertools.chain.from_iterable(
89
+ zip(control_qubit_idx_ssas, target_qubit_idx_ssas)
90
+ )
91
+ )
77
92
  else:
78
- stim_stmt = stim_noise.ZError(targets=qubit_idx_ssas, p=p_stmt.result)
79
- return stim_stmt
93
+ return RewriteResult()
94
+
95
+ # guaranteed that you have a valid stim_stmt to plug in
96
+ stim_stmt = rewrite_method(stmt, tuple(qubit_idx_ssas))
97
+ stmt.replace_by(stim_stmt)
98
+
99
+ return RewriteResult(has_done_something=True)
80
100
 
81
101
  def rewrite_SingleQubitPauliChannel(
82
102
  self,
83
- stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
103
+ stmt: squin_noise.stmts.SingleQubitPauliChannel,
84
104
  qubit_idx_ssas: Tuple[SSAValue],
85
105
  ) -> Statement:
86
106
  """Rewrite squin.noise.SingleQubitPauliChannel to stim.PauliChannel1."""
87
107
 
88
- squin_channel = stmt.operator.owner
89
- assert isinstance(squin_channel, squin_noise.stmts.SingleQubitPauliChannel)
90
-
91
- params = get_const_value(ilist.IList, squin_channel.params)
92
- new_stmts = [
93
- p_x := py.Constant(params[0]),
94
- p_y := py.Constant(params[1]),
95
- p_z := py.Constant(params[2]),
96
- ]
97
- for new_stmt in new_stmts:
98
- new_stmt.insert_before(stmt)
99
-
100
108
  stim_stmt = stim_noise.PauliChannel1(
101
109
  targets=qubit_idx_ssas,
102
- px=p_x.result,
103
- py=p_y.result,
104
- pz=p_z.result,
110
+ px=stmt.px,
111
+ py=stmt.py,
112
+ pz=stmt.pz,
105
113
  )
106
114
  return stim_stmt
107
115
 
108
- def rewrite_TwoQubitPauliChannel(
116
+ def rewrite_QubitLoss(
109
117
  self,
110
- stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
118
+ stmt: squin_noise.stmts.QubitLoss,
111
119
  qubit_idx_ssas: Tuple[SSAValue],
112
120
  ) -> Statement:
113
- """Rewrite squin.noise.SingleQubitPauliChannel to stim.PauliChannel1."""
121
+ """Rewrite squin.noise.QubitLoss to stim.TrivialError."""
114
122
 
115
- squin_channel = stmt.operator.owner
116
- assert isinstance(squin_channel, squin_noise.stmts.TwoQubitPauliChannel)
123
+ stim_stmt = stim_noise.QubitLoss(
124
+ targets=qubit_idx_ssas,
125
+ probs=(stmt.p,),
126
+ )
117
127
 
118
- params = get_const_value(ilist.IList, squin_channel.params)
119
- param_stmts = [py.Constant(p) for p in params]
120
- for param_stmt in param_stmts:
121
- param_stmt.insert_before(stmt)
128
+ return stim_stmt
122
129
 
123
- stim_stmt = stim_noise.PauliChannel2(
130
+ def rewrite_CorrelatedQubitLoss(
131
+ self,
132
+ stmt: squin_noise.stmts.CorrelatedQubitLoss,
133
+ qubit_idx_ssas: Tuple[SSAValue],
134
+ ) -> Statement:
135
+ """Rewrite squin.noise.CorrelatedQubitLoss to stim.CorrelatedQubitLoss."""
136
+ stim_stmt = stim_noise.CorrelatedQubitLoss(
124
137
  targets=qubit_idx_ssas,
125
- pix=param_stmts[0].result,
126
- piy=param_stmts[1].result,
127
- piz=param_stmts[2].result,
128
- pxi=param_stmts[3].result,
129
- pxx=param_stmts[4].result,
130
- pxy=param_stmts[5].result,
131
- pxz=param_stmts[6].result,
132
- pyi=param_stmts[7].result,
133
- pyx=param_stmts[8].result,
134
- pyy=param_stmts[9].result,
135
- pyz=param_stmts[10].result,
136
- pzi=param_stmts[11].result,
137
- pzx=param_stmts[12].result,
138
- pzy=param_stmts[13].result,
139
- pzz=param_stmts[14].result,
138
+ probs=(stmt.p,),
140
139
  )
140
+
141
141
  return stim_stmt
142
142
 
143
- def rewrite_Depolarize2(
143
+ def rewrite_Depolarize(
144
144
  self,
145
- stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
145
+ stmt: squin_noise.stmts.Depolarize,
146
146
  qubit_idx_ssas: Tuple[SSAValue],
147
147
  ) -> Statement:
148
- """Rewrite squin.noise.Depolarize2 to stim.Depolarize2."""
149
-
150
- squin_channel = stmt.operator.owner
151
- assert isinstance(squin_channel, squin_noise.stmts.Depolarize2)
148
+ """Rewrite squin.noise.Depolarize to stim.Depolarize1."""
152
149
 
153
- p = get_const_value(float, squin_channel.p)
154
- p_stmt = py.Constant(p)
155
- p_stmt.insert_before(stmt)
150
+ stim_stmt = stim_noise.Depolarize1(
151
+ targets=qubit_idx_ssas,
152
+ p=stmt.p,
153
+ )
156
154
 
157
- stim_stmt = stim_noise.Depolarize2(targets=qubit_idx_ssas, p=p_stmt.result)
158
155
  return stim_stmt
159
156
 
160
- def rewrite_Depolarize(
157
+ def rewrite_TwoQubitPauliChannel(
161
158
  self,
162
- stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
159
+ stmt: squin_noise.stmts.TwoQubitPauliChannel,
163
160
  qubit_idx_ssas: Tuple[SSAValue],
164
161
  ) -> Statement:
165
- """Rewrite squin.noise.Depolarize to stim.Depolarize1."""
162
+ """Rewrite squin.noise.TwoQubitPauliChannel to stim.PauliChannel2."""
163
+
164
+ params = stmt.probabilities
165
+ prob_ssas = []
166
+ for idx in range(15):
167
+ idx_stmt = py.Constant(value=idx)
168
+ idx_stmt.insert_before(stmt)
169
+ getitem_stmt = py.GetItem(obj=params, index=idx_stmt.result)
170
+ getitem_stmt.insert_before(stmt)
171
+ prob_ssas.append(getitem_stmt.result)
166
172
 
167
- squin_channel = stmt.operator.owner
168
- assert isinstance(squin_channel, squin_noise.stmts.Depolarize)
173
+ stim_stmt = stim_noise.PauliChannel2(
174
+ targets=qubit_idx_ssas,
175
+ pix=prob_ssas[0],
176
+ piy=prob_ssas[1],
177
+ piz=prob_ssas[2],
178
+ pxi=prob_ssas[3],
179
+ pxx=prob_ssas[4],
180
+ pxy=prob_ssas[5],
181
+ pxz=prob_ssas[6],
182
+ pyi=prob_ssas[7],
183
+ pyx=prob_ssas[8],
184
+ pyy=prob_ssas[9],
185
+ pyz=prob_ssas[10],
186
+ pzi=prob_ssas[11],
187
+ pzx=prob_ssas[12],
188
+ pzy=prob_ssas[13],
189
+ pzz=prob_ssas[14],
190
+ )
191
+ return stim_stmt
169
192
 
170
- p = get_const_value(float, squin_channel.p)
171
- p_stmt = py.Constant(p)
172
- p_stmt.insert_before(stmt)
193
+ def rewrite_Depolarize2(
194
+ self,
195
+ stmt: squin_noise.stmts.Depolarize2,
196
+ qubit_idx_ssas: Tuple[SSAValue],
197
+ ) -> Statement:
198
+ """Rewrite squin.noise.Depolarize2 to stim.Depolarize2."""
173
199
 
174
- stim_stmt = stim_noise.Depolarize1(targets=qubit_idx_ssas, p=p_stmt.result)
200
+ stim_stmt = stim_noise.Depolarize2(targets=qubit_idx_ssas, p=stmt.p)
175
201
  return stim_stmt