bloqade-circuit 0.6.4__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/analysis.py +18 -20
  7. bloqade/analysis/measure_id/impls.py +31 -29
  8. bloqade/annotate/__init__.py +6 -0
  9. bloqade/annotate/_dialect.py +3 -0
  10. bloqade/annotate/_interface.py +22 -0
  11. bloqade/annotate/stmts.py +29 -0
  12. bloqade/annotate/types.py +13 -0
  13. bloqade/cirq_utils/__init__.py +4 -2
  14. bloqade/cirq_utils/emit/__init__.py +3 -0
  15. bloqade/cirq_utils/emit/base.py +246 -0
  16. bloqade/cirq_utils/emit/gate.py +104 -0
  17. bloqade/cirq_utils/emit/noise.py +90 -0
  18. bloqade/cirq_utils/emit/qubit.py +35 -0
  19. bloqade/cirq_utils/lowering.py +660 -0
  20. bloqade/cirq_utils/noise/__init__.py +0 -2
  21. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  22. bloqade/cirq_utils/noise/model.py +151 -191
  23. bloqade/cirq_utils/noise/transform.py +2 -2
  24. bloqade/cirq_utils/parallelize.py +9 -6
  25. bloqade/gemini/__init__.py +1 -0
  26. bloqade/gemini/analysis/__init__.py +3 -0
  27. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  28. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  29. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  30. bloqade/gemini/groups.py +67 -0
  31. bloqade/native/__init__.py +23 -0
  32. bloqade/native/_prelude.py +45 -0
  33. bloqade/native/dialects/__init__.py +0 -0
  34. bloqade/native/dialects/gate/__init__.py +2 -0
  35. bloqade/native/dialects/gate/_dialect.py +3 -0
  36. bloqade/native/dialects/gate/_interface.py +32 -0
  37. bloqade/native/dialects/gate/stmts.py +31 -0
  38. bloqade/native/stdlib/__init__.py +0 -0
  39. bloqade/native/stdlib/broadcast.py +246 -0
  40. bloqade/native/stdlib/simple.py +220 -0
  41. bloqade/native/upstream/__init__.py +4 -0
  42. bloqade/native/upstream/squin2native.py +79 -0
  43. bloqade/pyqrack/__init__.py +2 -2
  44. bloqade/pyqrack/base.py +7 -1
  45. bloqade/pyqrack/device.py +192 -18
  46. bloqade/pyqrack/native.py +49 -0
  47. bloqade/pyqrack/reg.py +6 -6
  48. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  49. bloqade/pyqrack/squin/gate/gate.py +136 -0
  50. bloqade/pyqrack/squin/noise/native.py +120 -54
  51. bloqade/pyqrack/squin/qubit.py +39 -36
  52. bloqade/pyqrack/target.py +5 -4
  53. bloqade/pyqrack/task.py +114 -7
  54. bloqade/qasm2/_qasm_loading.py +3 -3
  55. bloqade/qasm2/dialects/core/address.py +21 -12
  56. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  57. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  58. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  59. bloqade/qasm2/dialects/noise/model.py +2 -1
  60. bloqade/qasm2/emit/base.py +16 -11
  61. bloqade/qasm2/emit/gate.py +11 -8
  62. bloqade/qasm2/emit/main.py +103 -3
  63. bloqade/qasm2/emit/target.py +9 -5
  64. bloqade/qasm2/groups.py +3 -2
  65. bloqade/qasm2/parse/lowering.py +0 -1
  66. bloqade/qasm2/passes/fold.py +14 -73
  67. bloqade/qasm2/passes/glob.py +2 -2
  68. bloqade/qasm2/passes/noise.py +1 -1
  69. bloqade/qasm2/passes/parallel.py +7 -5
  70. bloqade/qasm2/rewrite/__init__.py +0 -1
  71. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  72. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  73. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  74. bloqade/qasm2/rewrite/register.py +2 -2
  75. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  76. bloqade/qbraid/lowering.py +1 -0
  77. bloqade/qbraid/schema.py +2 -2
  78. bloqade/qubit/__init__.py +12 -0
  79. bloqade/qubit/_dialect.py +3 -0
  80. bloqade/qubit/_interface.py +49 -0
  81. bloqade/qubit/_prelude.py +45 -0
  82. bloqade/qubit/analysis/__init__.py +1 -0
  83. bloqade/qubit/analysis/address_impl.py +40 -0
  84. bloqade/qubit/stdlib/__init__.py +2 -0
  85. bloqade/qubit/stdlib/_new.py +34 -0
  86. bloqade/qubit/stdlib/broadcast.py +62 -0
  87. bloqade/qubit/stdlib/simple.py +59 -0
  88. bloqade/qubit/stmts.py +60 -0
  89. bloqade/rewrite/passes/__init__.py +6 -0
  90. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  91. bloqade/rewrite/passes/callgraph.py +116 -0
  92. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  93. bloqade/rewrite/rules/split_ifs.py +18 -1
  94. bloqade/squin/__init__.py +47 -14
  95. bloqade/squin/analysis/__init__.py +0 -1
  96. bloqade/squin/analysis/schedule.py +10 -11
  97. bloqade/squin/gate/__init__.py +2 -0
  98. bloqade/squin/gate/_dialect.py +3 -0
  99. bloqade/squin/gate/_interface.py +98 -0
  100. bloqade/squin/gate/stmts.py +125 -0
  101. bloqade/squin/groups.py +5 -22
  102. bloqade/squin/noise/__init__.py +1 -10
  103. bloqade/squin/noise/_dialect.py +1 -1
  104. bloqade/squin/noise/_interface.py +45 -0
  105. bloqade/squin/noise/stmts.py +66 -28
  106. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  107. bloqade/squin/rewrite/__init__.py +0 -2
  108. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  109. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  110. bloqade/squin/stdlib/__init__.py +0 -0
  111. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  112. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  113. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  114. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  115. bloqade/squin/stdlib/simple/__init__.py +33 -0
  116. bloqade/squin/stdlib/simple/gate.py +242 -0
  117. bloqade/squin/stdlib/simple/noise.py +126 -0
  118. bloqade/stim/__init__.py +1 -0
  119. bloqade/stim/_wrappers.py +6 -0
  120. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  121. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  122. bloqade/stim/dialects/gate/emit.py +9 -10
  123. bloqade/stim/dialects/noise/emit.py +17 -13
  124. bloqade/stim/dialects/noise/stmts.py +5 -3
  125. bloqade/stim/emit/__init__.py +1 -0
  126. bloqade/stim/emit/impls.py +16 -0
  127. bloqade/stim/emit/stim_str.py +48 -31
  128. bloqade/stim/groups.py +12 -2
  129. bloqade/stim/parse/lowering.py +14 -17
  130. bloqade/stim/passes/__init__.py +0 -2
  131. bloqade/stim/passes/flatten.py +26 -0
  132. bloqade/stim/passes/simplify_ifs.py +6 -1
  133. bloqade/stim/passes/squin_to_stim.py +9 -84
  134. bloqade/stim/rewrite/__init__.py +2 -4
  135. bloqade/stim/rewrite/get_record_util.py +24 -0
  136. bloqade/stim/rewrite/ifs_to_stim.py +24 -25
  137. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  138. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  139. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  140. bloqade/stim/rewrite/squin_measure.py +9 -18
  141. bloqade/stim/rewrite/squin_noise.py +134 -108
  142. bloqade/stim/rewrite/util.py +5 -192
  143. bloqade/test_utils.py +1 -1
  144. bloqade/types.py +10 -0
  145. bloqade/validation/__init__.py +2 -0
  146. bloqade/validation/analysis/__init__.py +5 -0
  147. bloqade/validation/analysis/analysis.py +41 -0
  148. bloqade/validation/analysis/lattice.py +58 -0
  149. bloqade/validation/kernel_validation.py +77 -0
  150. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  151. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  152. bloqade/pyqrack/squin/op.py +0 -180
  153. bloqade/pyqrack/squin/runtime.py +0 -535
  154. bloqade/pyqrack/squin/wire.py +0 -51
  155. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  156. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  157. bloqade/squin/_typeinfer.py +0 -20
  158. bloqade/squin/analysis/address_impl.py +0 -71
  159. bloqade/squin/analysis/nsites/__init__.py +0 -9
  160. bloqade/squin/analysis/nsites/analysis.py +0 -50
  161. bloqade/squin/analysis/nsites/impls.py +0 -92
  162. bloqade/squin/analysis/nsites/lattice.py +0 -49
  163. bloqade/squin/cirq/__init__.py +0 -280
  164. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  165. bloqade/squin/cirq/emit/noise.py +0 -49
  166. bloqade/squin/cirq/emit/op.py +0 -125
  167. bloqade/squin/cirq/emit/qubit.py +0 -60
  168. bloqade/squin/cirq/emit/runtime.py +0 -242
  169. bloqade/squin/cirq/lowering.py +0 -440
  170. bloqade/squin/lowering.py +0 -54
  171. bloqade/squin/noise/_wrapper.py +0 -40
  172. bloqade/squin/noise/rewrite.py +0 -111
  173. bloqade/squin/op/__init__.py +0 -41
  174. bloqade/squin/op/_dialect.py +0 -3
  175. bloqade/squin/op/_wrapper.py +0 -121
  176. bloqade/squin/op/number.py +0 -5
  177. bloqade/squin/op/rewrite.py +0 -46
  178. bloqade/squin/op/stdlib.py +0 -62
  179. bloqade/squin/op/stmts.py +0 -276
  180. bloqade/squin/op/traits.py +0 -43
  181. bloqade/squin/op/types.py +0 -26
  182. bloqade/squin/qubit.py +0 -184
  183. bloqade/squin/rewrite/canonicalize.py +0 -60
  184. bloqade/squin/rewrite/desugar.py +0 -124
  185. bloqade/squin/types.py +0 -8
  186. bloqade/squin/wire.py +0 -201
  187. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  188. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  189. bloqade_circuit-0.6.4.dist-info/RECORD +0 -234
  190. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  191. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
bloqade/stim/rewrite/set_observable_to_stim.py ADDED
@@ -0,0 +1,52 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
5
+
6
+ from bloqade.stim.dialects import auxiliary
7
+ from bloqade.annotate.stmts import SetObservable
8
+ from bloqade.analysis.measure_id import MeasureIDFrame
9
+ from bloqade.stim.dialects.auxiliary import ObservableInclude
10
+ from bloqade.analysis.measure_id.lattice import MeasureIdTuple
11
+
12
+ from ..rewrite.get_record_util import insert_get_records
13
+
14
+
15
+ @dataclass
16
+ class SetObservableToStim(RewriteRule):
17
+ """
18
+ Rewrite SetObservable to GetRecord and ObservableInclude in the stim dialect
19
+ """
20
+
21
+ measure_id_frame: MeasureIDFrame
22
+
23
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
24
+ match node:
25
+ case SetObservable():
26
+ return self.rewrite_SetObservable(node)
27
+ case _:
28
+ return RewriteResult()
29
+
30
+ def rewrite_SetObservable(self, node: SetObservable) -> RewriteResult:
31
+
32
+ # set idx to 0 for now, but this
33
+ # should be something that a user can set on their own.
34
+ # SetObservable needs to accept an int.
35
+
36
+ idx_stmt = auxiliary.ConstInt(value=0)
37
+ idx_stmt.insert_before(node)
38
+
39
+ measure_ids = self.measure_id_frame.entries[node.measurements]
40
+ assert isinstance(measure_ids, MeasureIdTuple)
41
+
42
+ get_record_list = insert_get_records(
43
+ node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node]
44
+ )
45
+
46
+ observable_include_stmt = ObservableInclude(
47
+ idx=idx_stmt.result, targets=tuple(get_record_list)
48
+ )
49
+
50
+ node.replace_by(observable_include_stmt)
51
+
52
+ return RewriteResult(has_done_something=True)
bloqade/stim/rewrite/squin_measure.py CHANGED
@@ -5,11 +5,10 @@ from kirin import ir
5
5
  from kirin.dialects import py
6
6
  from kirin.rewrite.abc import RewriteRule, RewriteResult
7
7
 
8
- from bloqade.squin import wire, qubit
8
+ from bloqade import qubit
9
9
  from bloqade.squin.rewrite import AddressAttribute
10
10
  from bloqade.stim.dialects import collapse
11
11
  from bloqade.stim.rewrite.util import (
12
- is_measure_result_used,
13
12
  insert_qubit_idx_from_address,
14
13
  )
15
14
 
@@ -23,14 +22,12 @@ class SquinMeasureToStim(RewriteRule):
23
22
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
24
23
 
25
24
  match node:
26
- case qubit.MeasureQubit() | qubit.MeasureQubitList() | wire.Measure():
25
+ case qubit.stmts.Measure():
27
26
  return self.rewrite_Measure(node)
28
27
  case _:
29
28
  return RewriteResult()
30
29
 
31
- def rewrite_Measure(
32
- self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
33
- ) -> RewriteResult:
30
+ def rewrite_Measure(self, measure_stmt: qubit.stmts.Measure) -> RewriteResult:
34
31
 
35
32
  qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt)
36
33
  if qubit_idx_ssas is None:
@@ -44,27 +41,21 @@ class SquinMeasureToStim(RewriteRule):
44
41
  prob_noise_stmt.insert_before(measure_stmt)
45
42
  stim_measure_stmt.insert_before(measure_stmt)
46
43
 
47
- if not is_measure_result_used(measure_stmt):
44
+ # if the measurement is not being used anywhere
45
+ # we can safely get rid of it. Measure cannot be DCE'd because
46
+ # it is not pure.
47
+ if not bool(measure_stmt.result.uses):
48
48
  measure_stmt.delete()
49
49
 
50
50
  return RewriteResult(has_done_something=True)
51
51
 
52
52
  def get_qubit_idx_ssas(
53
- self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
53
+ self, measure_stmt: qubit.stmts.Measure
54
54
  ) -> tuple[ir.SSAValue, ...] | None:
55
55
  """
56
56
  Extract the address attribute and insert qubit indices for the given measure statement.
57
57
  """
58
- match measure_stmt:
59
- case qubit.MeasureQubit():
60
- address_attr = measure_stmt.qubit.hints.get("address")
61
- case qubit.MeasureQubitList():
62
- address_attr = measure_stmt.qubits.hints.get("address")
63
- case wire.Measure():
64
- address_attr = measure_stmt.wire.hints.get("address")
65
- case _:
66
- return None
67
-
58
+ address_attr = measure_stmt.qubits.hints.get("address")
68
59
  if address_attr is None:
69
60
  return None
70
61
 
bloqade/stim/rewrite/squin_noise.py CHANGED
@@ -1,17 +1,17 @@
1
+ import itertools
1
2
  from typing import Tuple
2
3
  from dataclasses import dataclass
3
4
 
5
+ from kirin import types
4
6
  from kirin.ir import SSAValue, Statement
5
- from kirin.dialects import py, ilist
7
+ from kirin.dialects import py
6
8
  from kirin.rewrite.abc import RewriteRule, RewriteResult
7
9
 
8
- from bloqade.squin import op, wire, noise as squin_noise, qubit
10
+ from bloqade.squin import noise as squin_noise
9
11
  from bloqade.stim.dialects import noise as stim_noise
10
- from bloqade.stim.rewrite.util import (
11
- get_const_value,
12
- create_wire_passthrough,
13
- insert_qubit_idx_after_apply,
14
- )
12
+ from bloqade.stim.rewrite.util import insert_qubit_idx_from_address
13
+ from bloqade.analysis.address.lattice import AddressReg, PartialIList
14
+ from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
15
15
 
16
16
 
17
17
  @dataclass
@@ -19,157 +19,183 @@ class SquinNoiseToStim(RewriteRule):
19
19
 
20
20
  def rewrite_Statement(self, node: Statement) -> RewriteResult:
21
21
  match node:
22
- case qubit.Apply() | qubit.Broadcast():
23
- return self.rewrite_Apply_and_Broadcast(node)
22
+ case squin_noise.stmts.NoiseChannel():
23
+ return self.rewrite_NoiseChannel(node)
24
24
  case _:
25
25
  return RewriteResult()
26
26
 
27
- def rewrite_Apply_and_Broadcast(
28
- self, stmt: qubit.Apply | qubit.Broadcast
27
+ def rewrite_NoiseChannel(
28
+ self, stmt: squin_noise.stmts.NoiseChannel
29
29
  ) -> RewriteResult:
30
- """Rewrite Apply and Broadcast to their stim statements."""
30
+ """Rewrite NoiseChannel statements to their stim equivalents."""
31
31
 
32
- # this is an SSAValue, need it to be the actual operator
33
- applied_op = stmt.operator.owner
32
+ rewrite_method = getattr(self, f"rewrite_{type(stmt).__name__}", None)
34
33
 
35
- if isinstance(applied_op, squin_noise.stmts.QubitLoss):
34
+ # No rewrite method exists and the rewrite should stop
35
+ if rewrite_method is None:
36
36
  return RewriteResult()
37
+ if isinstance(stmt, squin_noise.stmts.CorrelatedQubitLoss):
38
+ # CorrelatedQubitLoss represents a broadcast operation, but Stim does not
39
+ # support broadcasting for multi-qubit noise channels.
40
+ # Therefore, we must expand the broadcast into individual stim statements.
41
+ qubit_address_attr = stmt.qubits.hints.get("address", None)
37
42
 
38
- if isinstance(applied_op, squin_noise.stmts.NoiseChannel):
43
+ if not isinstance(qubit_address_attr, AddressAttribute):
44
+ return RewriteResult()
45
+
46
+ if not isinstance(address := qubit_address_attr.address, PartialIList):
47
+ return RewriteResult()
39
48
 
40
- qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt)
41
- if qubit_idx_ssas is None:
49
+ if not types.is_tuple_of(data := address.data, AddressReg):
42
50
  return RewriteResult()
43
51
 
44
- rewrite_method = getattr(self, f"rewrite_{type(applied_op).__name__}")
45
- stim_stmt = rewrite_method(stmt, qubit_idx_ssas)
52
+ for address_reg in data:
46
53
 
47
- if isinstance(stmt, (wire.Apply, wire.Broadcast)):
48
- create_wire_passthrough(stmt)
54
+ qubit_idx_ssas = insert_qubit_idx_from_address(
55
+ AddressAttribute(address_reg), stmt
56
+ )
49
57
 
50
- if stim_stmt is not None:
51
- stmt.replace_by(stim_stmt)
52
- if len(stmt.operator.owner.result.uses) == 0:
53
- stmt.operator.owner.delete()
58
+ stim_stmt = rewrite_method(stmt, qubit_idx_ssas)
59
+ stim_stmt.insert_before(stmt)
60
+
61
+ stmt.delete()
54
62
 
55
63
  return RewriteResult(has_done_something=True)
56
- return RewriteResult()
57
64
 
58
- def rewrite_PauliError(
59
- self,
60
- stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
61
- qubit_idx_ssas: Tuple[SSAValue],
62
- ) -> Statement:
63
- """Rewrite squin.noise.PauliError to XError, YError, ZError."""
64
- squin_channel = stmt.operator.owner
65
- assert isinstance(squin_channel, squin_noise.stmts.PauliError)
66
- basis = squin_channel.basis.owner
67
- assert isinstance(basis, op.stmts.PauliOp)
68
- p = get_const_value(float, squin_channel.p)
69
-
70
- p_stmt = py.Constant(p)
71
- p_stmt.insert_before(stmt)
72
-
73
- if isinstance(basis, op.stmts.X):
74
- stim_stmt = stim_noise.XError(targets=qubit_idx_ssas, p=p_stmt.result)
75
- elif isinstance(basis, op.stmts.Y):
76
- stim_stmt = stim_noise.YError(targets=qubit_idx_ssas, p=p_stmt.result)
65
+ if isinstance(stmt, squin_noise.stmts.SingleQubitNoiseChannel):
66
+ qubit_address_attr = stmt.qubits.hints.get("address", None)
67
+ if qubit_address_attr is None:
68
+ return RewriteResult()
69
+ qubit_idx_ssas = insert_qubit_idx_from_address(qubit_address_attr, stmt)
70
+
71
+ elif isinstance(stmt, squin_noise.stmts.TwoQubitNoiseChannel):
72
+ control_address_attr = stmt.controls.hints.get("address", None)
73
+ target_address_attr = stmt.targets.hints.get("address", None)
74
+ if control_address_attr is None or target_address_attr is None:
75
+ return RewriteResult()
76
+ control_qubit_idx_ssas = insert_qubit_idx_from_address(
77
+ control_address_attr, stmt
78
+ )
79
+ target_qubit_idx_ssas = insert_qubit_idx_from_address(
80
+ target_address_attr, stmt
81
+ )
82
+ if control_qubit_idx_ssas is None or target_qubit_idx_ssas is None:
83
+ return RewriteResult()
84
+
85
+ # For stim statements you want to interleave the control and target qubit indices:
86
+ # ex: CX controls = (0,1) targets = (2,3) in stim is: CX 0 2 1 3
87
+ qubit_idx_ssas = list(
88
+ itertools.chain.from_iterable(
89
+ zip(control_qubit_idx_ssas, target_qubit_idx_ssas)
90
+ )
91
+ )
77
92
  else:
78
- stim_stmt = stim_noise.ZError(targets=qubit_idx_ssas, p=p_stmt.result)
79
- return stim_stmt
93
+ return RewriteResult()
94
+
95
+ # guaranteed that you have a valid stim_stmt to plug in
96
+ stim_stmt = rewrite_method(stmt, tuple(qubit_idx_ssas))
97
+ stmt.replace_by(stim_stmt)
98
+
99
+ return RewriteResult(has_done_something=True)
80
100
 
81
101
  def rewrite_SingleQubitPauliChannel(
82
102
  self,
83
- stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
103
+ stmt: squin_noise.stmts.SingleQubitPauliChannel,
84
104
  qubit_idx_ssas: Tuple[SSAValue],
85
105
  ) -> Statement:
86
106
  """Rewrite squin.noise.SingleQubitPauliChannel to stim.PauliChannel1."""
87
107
 
88
- squin_channel = stmt.operator.owner
89
- assert isinstance(squin_channel, squin_noise.stmts.SingleQubitPauliChannel)
90
-
91
- params = get_const_value(ilist.IList, squin_channel.params)
92
- new_stmts = [
93
- p_x := py.Constant(params[0]),
94
- p_y := py.Constant(params[1]),
95
- p_z := py.Constant(params[2]),
96
- ]
97
- for new_stmt in new_stmts:
98
- new_stmt.insert_before(stmt)
99
-
100
108
  stim_stmt = stim_noise.PauliChannel1(
101
109
  targets=qubit_idx_ssas,
102
- px=p_x.result,
103
- py=p_y.result,
104
- pz=p_z.result,
110
+ px=stmt.px,
111
+ py=stmt.py,
112
+ pz=stmt.pz,
105
113
  )
106
114
  return stim_stmt
107
115
 
108
- def rewrite_TwoQubitPauliChannel(
116
+ def rewrite_QubitLoss(
109
117
  self,
110
- stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
118
+ stmt: squin_noise.stmts.QubitLoss,
111
119
  qubit_idx_ssas: Tuple[SSAValue],
112
120
  ) -> Statement:
113
- """Rewrite squin.noise.SingleQubitPauliChannel to stim.PauliChannel1."""
121
+ """Rewrite squin.noise.QubitLoss to stim.TrivialError."""
114
122
 
115
- squin_channel = stmt.operator.owner
116
- assert isinstance(squin_channel, squin_noise.stmts.TwoQubitPauliChannel)
123
+ stim_stmt = stim_noise.QubitLoss(
124
+ targets=qubit_idx_ssas,
125
+ probs=(stmt.p,),
126
+ )
117
127
 
118
- params = get_const_value(ilist.IList, squin_channel.params)
119
- param_stmts = [py.Constant(p) for p in params]
120
- for param_stmt in param_stmts:
121
- param_stmt.insert_before(stmt)
128
+ return stim_stmt
122
129
 
123
- stim_stmt = stim_noise.PauliChannel2(
130
+ def rewrite_CorrelatedQubitLoss(
131
+ self,
132
+ stmt: squin_noise.stmts.CorrelatedQubitLoss,
133
+ qubit_idx_ssas: Tuple[SSAValue],
134
+ ) -> Statement:
135
+ """Rewrite squin.noise.CorrelatedQubitLoss to stim.CorrelatedQubitLoss."""
136
+ stim_stmt = stim_noise.CorrelatedQubitLoss(
124
137
  targets=qubit_idx_ssas,
125
- pix=param_stmts[0].result,
126
- piy=param_stmts[1].result,
127
- piz=param_stmts[2].result,
128
- pxi=param_stmts[3].result,
129
- pxx=param_stmts[4].result,
130
- pxy=param_stmts[5].result,
131
- pxz=param_stmts[6].result,
132
- pyi=param_stmts[7].result,
133
- pyx=param_stmts[8].result,
134
- pyy=param_stmts[9].result,
135
- pyz=param_stmts[10].result,
136
- pzi=param_stmts[11].result,
137
- pzx=param_stmts[12].result,
138
- pzy=param_stmts[13].result,
139
- pzz=param_stmts[14].result,
138
+ probs=(stmt.p,),
140
139
  )
140
+
141
141
  return stim_stmt
142
142
 
143
- def rewrite_Depolarize2(
143
+ def rewrite_Depolarize(
144
144
  self,
145
- stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
145
+ stmt: squin_noise.stmts.Depolarize,
146
146
  qubit_idx_ssas: Tuple[SSAValue],
147
147
  ) -> Statement:
148
- """Rewrite squin.noise.Depolarize2 to stim.Depolarize2."""
149
-
150
- squin_channel = stmt.operator.owner
151
- assert isinstance(squin_channel, squin_noise.stmts.Depolarize2)
148
+ """Rewrite squin.noise.Depolarize to stim.Depolarize1."""
152
149
 
153
- p = get_const_value(float, squin_channel.p)
154
- p_stmt = py.Constant(p)
155
- p_stmt.insert_before(stmt)
150
+ stim_stmt = stim_noise.Depolarize1(
151
+ targets=qubit_idx_ssas,
152
+ p=stmt.p,
153
+ )
156
154
 
157
- stim_stmt = stim_noise.Depolarize2(targets=qubit_idx_ssas, p=p_stmt.result)
158
155
  return stim_stmt
159
156
 
160
- def rewrite_Depolarize(
157
+ def rewrite_TwoQubitPauliChannel(
161
158
  self,
162
- stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
159
+ stmt: squin_noise.stmts.TwoQubitPauliChannel,
163
160
  qubit_idx_ssas: Tuple[SSAValue],
164
161
  ) -> Statement:
165
- """Rewrite squin.noise.Depolarize to stim.Depolarize1."""
162
+ """Rewrite squin.noise.TwoQubitPauliChannel to stim.PauliChannel2."""
163
+
164
+ params = stmt.probabilities
165
+ prob_ssas = []
166
+ for idx in range(15):
167
+ idx_stmt = py.Constant(value=idx)
168
+ idx_stmt.insert_before(stmt)
169
+ getitem_stmt = py.GetItem(obj=params, index=idx_stmt.result)
170
+ getitem_stmt.insert_before(stmt)
171
+ prob_ssas.append(getitem_stmt.result)
166
172
 
167
- squin_channel = stmt.operator.owner
168
- assert isinstance(squin_channel, squin_noise.stmts.Depolarize)
173
+ stim_stmt = stim_noise.PauliChannel2(
174
+ targets=qubit_idx_ssas,
175
+ pix=prob_ssas[0],
176
+ piy=prob_ssas[1],
177
+ piz=prob_ssas[2],
178
+ pxi=prob_ssas[3],
179
+ pxx=prob_ssas[4],
180
+ pxy=prob_ssas[5],
181
+ pxz=prob_ssas[6],
182
+ pyi=prob_ssas[7],
183
+ pyx=prob_ssas[8],
184
+ pyy=prob_ssas[9],
185
+ pyz=prob_ssas[10],
186
+ pzi=prob_ssas[11],
187
+ pzx=prob_ssas[12],
188
+ pzy=prob_ssas[13],
189
+ pzz=prob_ssas[14],
190
+ )
191
+ return stim_stmt
169
192
 
170
- p = get_const_value(float, squin_channel.p)
171
- p_stmt = py.Constant(p)
172
- p_stmt.insert_before(stmt)
193
+ def rewrite_Depolarize2(
194
+ self,
195
+ stmt: squin_noise.stmts.Depolarize2,
196
+ qubit_idx_ssas: Tuple[SSAValue],
197
+ ) -> Statement:
198
+ """Rewrite squin.noise.Depolarize2 to stim.Depolarize2."""
173
199
 
174
- stim_stmt = stim_noise.Depolarize1(targets=qubit_idx_ssas, p=p_stmt.result)
200
+ stim_stmt = stim_noise.Depolarize2(targets=qubit_idx_ssas, p=stmt.p)
175
201
  return stim_stmt
bloqade/stim/rewrite/util.py CHANGED
@@ -1,35 +1,8 @@
1
- from typing import TypeVar
2
-
3
- from kirin import ir, interp
4
- from kirin.analysis import const
1
+ from kirin import ir
5
2
  from kirin.dialects import py
6
- from kirin.rewrite.abc import RewriteResult
7
3
 
8
- from bloqade.squin import op, wire, noise as squin_noise, qubit
9
4
  from bloqade.squin.rewrite import AddressAttribute
10
- from bloqade.stim.dialects import gate, noise as stim_noise, collapse
11
- from bloqade.analysis.address import AddressReg, AddressWire, AddressQubit, AddressTuple
12
-
13
- SQUIN_STIM_OP_MAPPING = {
14
- op.stmts.X: gate.X,
15
- op.stmts.Y: gate.Y,
16
- op.stmts.Z: gate.Z,
17
- op.stmts.H: gate.H,
18
- op.stmts.S: gate.S,
19
- op.stmts.SqrtX: gate.SqrtX,
20
- op.stmts.SqrtY: gate.SqrtY,
21
- op.stmts.Identity: gate.Identity,
22
- op.stmts.Reset: collapse.RZ,
23
- squin_noise.stmts.QubitLoss: stim_noise.QubitLoss,
24
- }
25
-
26
- # Squin allows creation of control gates where the gate can be any operator,
27
- # but Stim only supports CX, CY, and CZ as control gates.
28
- SQUIN_STIM_CONTROL_GATE_MAPPING = {
29
- op.stmts.X: gate.CX,
30
- op.stmts.Y: gate.CY,
31
- op.stmts.Z: gate.CZ,
32
- }
5
+ from bloqade.analysis.address import AddressReg, AddressQubit
33
6
 
34
7
 
35
8
  def create_and_insert_qubit_idx_stmt(
@@ -46,177 +19,17 @@ def insert_qubit_idx_from_address(
46
19
  """
47
20
  Extract qubit indices from an AddressAttribute and insert them into the SSA form.
48
21
  """
49
- address_data = address.address
50
22
  qubit_idx_ssas = []
51
-
52
- if isinstance(address_data, AddressTuple):
53
- for address_qubit in address_data.data:
54
- if not isinstance(address_qubit, AddressQubit):
55
- return
23
+ if isinstance(address_data := address.address, AddressReg):
24
+ for qubit_idx in address_data.qubits:
56
25
  create_and_insert_qubit_idx_stmt(
57
- address_qubit.data, stmt_to_insert_before, qubit_idx_ssas
58
- )
59
- elif isinstance(address_data, AddressReg):
60
- for qubit_idx in address_data.data:
61
- create_and_insert_qubit_idx_stmt(
62
- qubit_idx, stmt_to_insert_before, qubit_idx_ssas
26
+ qubit_idx.data, stmt_to_insert_before, qubit_idx_ssas
63
27
  )
64
28
  elif isinstance(address_data, AddressQubit):
65
29
  create_and_insert_qubit_idx_stmt(
66
30
  address_data.data, stmt_to_insert_before, qubit_idx_ssas
67
31
  )
68
- elif isinstance(address_data, AddressWire):
69
- address_qubit = address_data.origin_qubit
70
- create_and_insert_qubit_idx_stmt(
71
- address_qubit.data, stmt_to_insert_before, qubit_idx_ssas
72
- )
73
32
  else:
74
33
  return
75
34
 
76
35
  return tuple(qubit_idx_ssas)
77
-
78
-
79
- def insert_qubit_idx_from_wire_ssa(
80
- wire_ssas: tuple[ir.SSAValue, ...], stmt_to_insert_before: ir.Statement
81
- ) -> tuple[ir.SSAValue, ...] | None:
82
- """
83
- Extract qubit indices from wire SSA values and insert them into the SSA form.
84
- """
85
- qubit_idx_ssas = []
86
- for wire_ssa in wire_ssas:
87
- address_attribute = wire_ssa.hints.get("address")
88
- if address_attribute is None:
89
- return
90
- assert isinstance(address_attribute, AddressAttribute)
91
- wire_address = address_attribute.address
92
- assert isinstance(wire_address, AddressWire)
93
- qubit_idx = wire_address.origin_qubit.data
94
- qubit_idx_stmt = py.Constant(qubit_idx)
95
- qubit_idx_ssas.append(qubit_idx_stmt.result)
96
- qubit_idx_stmt.insert_before(stmt_to_insert_before)
97
-
98
- return tuple(qubit_idx_ssas)
99
-
100
-
101
- def insert_qubit_idx_after_apply(
102
- stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast,
103
- ) -> tuple[ir.SSAValue, ...] | None:
104
- """
105
- Extract qubit indices from Apply or Broadcast statements.
106
- """
107
- if isinstance(stmt, (qubit.Apply, qubit.Broadcast)):
108
- qubits = stmt.qubits
109
- address_attribute = qubits.hints.get("address")
110
- if address_attribute is None:
111
- return
112
- assert isinstance(address_attribute, AddressAttribute)
113
- return insert_qubit_idx_from_address(
114
- address=address_attribute, stmt_to_insert_before=stmt
115
- )
116
- elif isinstance(stmt, (wire.Apply, wire.Broadcast)):
117
- wire_ssas = stmt.inputs
118
- return insert_qubit_idx_from_wire_ssa(
119
- wire_ssas=wire_ssas, stmt_to_insert_before=stmt
120
- )
121
-
122
-
123
- def rewrite_Control(
124
- stmt_with_ctrl: qubit.Apply | wire.Apply | qubit.Broadcast | wire.Broadcast,
125
- ) -> RewriteResult:
126
- """
127
- Handle control gates for Apply and Broadcast statements.
128
- """
129
- ctrl_op = stmt_with_ctrl.operator.owner
130
- assert isinstance(ctrl_op, op.stmts.Control)
131
-
132
- ctrl_op_target_gate = ctrl_op.op.owner
133
- assert isinstance(ctrl_op_target_gate, op.stmts.Operator)
134
-
135
- qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt_with_ctrl)
136
- if qubit_idx_ssas is None:
137
- return RewriteResult()
138
-
139
- # Separate control and target qubits
140
- target_qubits = []
141
- ctrl_qubits = []
142
- for i in range(len(qubit_idx_ssas)):
143
- if (i % 2) == 0:
144
- ctrl_qubits.append(qubit_idx_ssas[i])
145
- else:
146
- target_qubits.append(qubit_idx_ssas[i])
147
-
148
- target_qubits = tuple(target_qubits)
149
- ctrl_qubits = tuple(ctrl_qubits)
150
-
151
- stim_gate = SQUIN_STIM_CONTROL_GATE_MAPPING.get(type(ctrl_op_target_gate))
152
- if stim_gate is None:
153
- return RewriteResult()
154
-
155
- stim_stmt = stim_gate(controls=ctrl_qubits, targets=target_qubits)
156
-
157
- if isinstance(stmt_with_ctrl, (wire.Apply, wire.Broadcast)):
158
- create_wire_passthrough(stmt_with_ctrl)
159
-
160
- stmt_with_ctrl.replace_by(stim_stmt)
161
-
162
- return RewriteResult(has_done_something=True)
163
-
164
-
165
- def rewrite_QubitLoss(
166
- stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
167
- ) -> RewriteResult:
168
- """
169
- Rewrite QubitLoss statements to Stim's TrivialError.
170
- """
171
-
172
- squin_loss_op = stmt.operator.owner
173
- assert isinstance(squin_loss_op, squin_noise.stmts.QubitLoss)
174
-
175
- qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt)
176
- if qubit_idx_ssas is None:
177
- return RewriteResult()
178
-
179
- stim_loss_stmt = stim_noise.QubitLoss(
180
- targets=qubit_idx_ssas,
181
- probs=(squin_loss_op.p,),
182
- )
183
-
184
- if isinstance(stmt, (wire.Apply, wire.Broadcast)):
185
- create_wire_passthrough(stmt)
186
-
187
- stmt.replace_by(stim_loss_stmt)
188
-
189
- return RewriteResult(has_done_something=True)
190
-
191
-
192
- def create_wire_passthrough(stmt: wire.Apply | wire.Broadcast) -> None:
193
-
194
- for input_wire, output_wire in zip(stmt.inputs, stmt.results):
195
- # have to "reroute" the input of these statements to directly plug in
196
- # to subsequent statements, remove dependency on the current statement
197
- output_wire.replace_by(input_wire)
198
-
199
-
200
- def is_measure_result_used(
201
- stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure,
202
- ) -> bool:
203
- """
204
- Check if the result of a measure statement is used in the program.
205
- """
206
- return bool(stmt.result.uses)
207
-
208
-
209
- T = TypeVar("T")
210
-
211
-
212
- def get_const_value(typ: type[T], value: ir.SSAValue) -> T:
213
- if isinstance(hint := value.hints.get("const"), const.Value):
214
- data = hint.data
215
- if isinstance(data, typ):
216
- return hint.data
217
- raise interp.InterpreterError(
218
- f"Expected constant value <type = {typ}>, got {data}"
219
- )
220
- raise interp.InterpreterError(
221
- f"Expected constant value <type = {typ}>, got {value}"
222
- )
bloqade/test_utils.py CHANGED
@@ -25,7 +25,7 @@ def print_diff(node: pprint.Printable, expected_node: pprint.Printable):
25
25
 
26
26
  def assert_nodes(node: ir.IRNode, expected_node: ir.IRNode):
27
27
  try:
28
- assert node.is_equal(expected_node)
28
+ assert node.is_structurally_equal(expected_node)
29
29
  except AssertionError as e:
30
30
  print_diff(node, expected_node)
31
31
  raise e
bloqade/types.py CHANGED
@@ -22,3 +22,13 @@ class Qubit(ABC):
22
22
 
23
23
  QubitType = types.PyClass(Qubit)
24
24
  """Kirin type for a qubit."""
25
+
26
+
27
+ class MeasurementResult:
28
+ """Runtime representation of the result of a measurement on a qubit."""
29
+
30
+ pass
31
+
32
+
33
+ MeasurementResultType = types.PyClass(MeasurementResult)
34
+ """Kirin type for a measurement result."""
bloqade/validation/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import analysis as analysis
2
+ from .kernel_validation import KernelValidation as KernelValidation
bloqade/validation/analysis/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from . import lattice as lattice
2
+ from .analysis import (
3
+ ValidationFrame as ValidationFrame,
4
+ ValidationAnalysis as ValidationAnalysis,
5
+ )