bloqade-circuit 0.6.4__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/analysis.py +18 -20
  7. bloqade/analysis/measure_id/impls.py +31 -29
  8. bloqade/annotate/__init__.py +6 -0
  9. bloqade/annotate/_dialect.py +3 -0
  10. bloqade/annotate/_interface.py +22 -0
  11. bloqade/annotate/stmts.py +29 -0
  12. bloqade/annotate/types.py +13 -0
  13. bloqade/cirq_utils/__init__.py +4 -2
  14. bloqade/cirq_utils/emit/__init__.py +3 -0
  15. bloqade/cirq_utils/emit/base.py +246 -0
  16. bloqade/cirq_utils/emit/gate.py +104 -0
  17. bloqade/cirq_utils/emit/noise.py +90 -0
  18. bloqade/cirq_utils/emit/qubit.py +35 -0
  19. bloqade/cirq_utils/lowering.py +660 -0
  20. bloqade/cirq_utils/noise/__init__.py +0 -2
  21. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  22. bloqade/cirq_utils/noise/model.py +151 -191
  23. bloqade/cirq_utils/noise/transform.py +2 -2
  24. bloqade/cirq_utils/parallelize.py +9 -6
  25. bloqade/gemini/__init__.py +1 -0
  26. bloqade/gemini/analysis/__init__.py +3 -0
  27. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  28. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  29. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  30. bloqade/gemini/groups.py +67 -0
  31. bloqade/native/__init__.py +23 -0
  32. bloqade/native/_prelude.py +45 -0
  33. bloqade/native/dialects/__init__.py +0 -0
  34. bloqade/native/dialects/gate/__init__.py +2 -0
  35. bloqade/native/dialects/gate/_dialect.py +3 -0
  36. bloqade/native/dialects/gate/_interface.py +32 -0
  37. bloqade/native/dialects/gate/stmts.py +31 -0
  38. bloqade/native/stdlib/__init__.py +0 -0
  39. bloqade/native/stdlib/broadcast.py +246 -0
  40. bloqade/native/stdlib/simple.py +220 -0
  41. bloqade/native/upstream/__init__.py +4 -0
  42. bloqade/native/upstream/squin2native.py +79 -0
  43. bloqade/pyqrack/__init__.py +2 -2
  44. bloqade/pyqrack/base.py +7 -1
  45. bloqade/pyqrack/device.py +192 -18
  46. bloqade/pyqrack/native.py +49 -0
  47. bloqade/pyqrack/reg.py +6 -6
  48. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  49. bloqade/pyqrack/squin/gate/gate.py +136 -0
  50. bloqade/pyqrack/squin/noise/native.py +120 -54
  51. bloqade/pyqrack/squin/qubit.py +39 -36
  52. bloqade/pyqrack/target.py +5 -4
  53. bloqade/pyqrack/task.py +114 -7
  54. bloqade/qasm2/_qasm_loading.py +3 -3
  55. bloqade/qasm2/dialects/core/address.py +21 -12
  56. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  57. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  58. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  59. bloqade/qasm2/dialects/noise/model.py +2 -1
  60. bloqade/qasm2/emit/base.py +16 -11
  61. bloqade/qasm2/emit/gate.py +11 -8
  62. bloqade/qasm2/emit/main.py +103 -3
  63. bloqade/qasm2/emit/target.py +9 -5
  64. bloqade/qasm2/groups.py +3 -2
  65. bloqade/qasm2/parse/lowering.py +0 -1
  66. bloqade/qasm2/passes/fold.py +14 -73
  67. bloqade/qasm2/passes/glob.py +2 -2
  68. bloqade/qasm2/passes/noise.py +1 -1
  69. bloqade/qasm2/passes/parallel.py +7 -5
  70. bloqade/qasm2/rewrite/__init__.py +0 -1
  71. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  72. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  73. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  74. bloqade/qasm2/rewrite/register.py +2 -2
  75. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  76. bloqade/qbraid/lowering.py +1 -0
  77. bloqade/qbraid/schema.py +2 -2
  78. bloqade/qubit/__init__.py +12 -0
  79. bloqade/qubit/_dialect.py +3 -0
  80. bloqade/qubit/_interface.py +49 -0
  81. bloqade/qubit/_prelude.py +45 -0
  82. bloqade/qubit/analysis/__init__.py +1 -0
  83. bloqade/qubit/analysis/address_impl.py +40 -0
  84. bloqade/qubit/stdlib/__init__.py +2 -0
  85. bloqade/qubit/stdlib/_new.py +34 -0
  86. bloqade/qubit/stdlib/broadcast.py +62 -0
  87. bloqade/qubit/stdlib/simple.py +59 -0
  88. bloqade/qubit/stmts.py +60 -0
  89. bloqade/rewrite/passes/__init__.py +6 -0
  90. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  91. bloqade/rewrite/passes/callgraph.py +116 -0
  92. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  93. bloqade/rewrite/rules/split_ifs.py +18 -1
  94. bloqade/squin/__init__.py +47 -14
  95. bloqade/squin/analysis/__init__.py +0 -1
  96. bloqade/squin/analysis/schedule.py +10 -11
  97. bloqade/squin/gate/__init__.py +2 -0
  98. bloqade/squin/gate/_dialect.py +3 -0
  99. bloqade/squin/gate/_interface.py +98 -0
  100. bloqade/squin/gate/stmts.py +125 -0
  101. bloqade/squin/groups.py +5 -22
  102. bloqade/squin/noise/__init__.py +1 -10
  103. bloqade/squin/noise/_dialect.py +1 -1
  104. bloqade/squin/noise/_interface.py +45 -0
  105. bloqade/squin/noise/stmts.py +66 -28
  106. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  107. bloqade/squin/rewrite/__init__.py +0 -2
  108. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  109. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  110. bloqade/squin/stdlib/__init__.py +0 -0
  111. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  112. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  113. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  114. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  115. bloqade/squin/stdlib/simple/__init__.py +33 -0
  116. bloqade/squin/stdlib/simple/gate.py +242 -0
  117. bloqade/squin/stdlib/simple/noise.py +126 -0
  118. bloqade/stim/__init__.py +1 -0
  119. bloqade/stim/_wrappers.py +6 -0
  120. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  121. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  122. bloqade/stim/dialects/gate/emit.py +9 -10
  123. bloqade/stim/dialects/noise/emit.py +17 -13
  124. bloqade/stim/dialects/noise/stmts.py +5 -3
  125. bloqade/stim/emit/__init__.py +1 -0
  126. bloqade/stim/emit/impls.py +16 -0
  127. bloqade/stim/emit/stim_str.py +48 -31
  128. bloqade/stim/groups.py +12 -2
  129. bloqade/stim/parse/lowering.py +14 -17
  130. bloqade/stim/passes/__init__.py +0 -2
  131. bloqade/stim/passes/flatten.py +26 -0
  132. bloqade/stim/passes/simplify_ifs.py +6 -1
  133. bloqade/stim/passes/squin_to_stim.py +9 -84
  134. bloqade/stim/rewrite/__init__.py +2 -4
  135. bloqade/stim/rewrite/get_record_util.py +24 -0
  136. bloqade/stim/rewrite/ifs_to_stim.py +24 -25
  137. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  138. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  139. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  140. bloqade/stim/rewrite/squin_measure.py +9 -18
  141. bloqade/stim/rewrite/squin_noise.py +134 -108
  142. bloqade/stim/rewrite/util.py +5 -192
  143. bloqade/test_utils.py +1 -1
  144. bloqade/types.py +10 -0
  145. bloqade/validation/__init__.py +2 -0
  146. bloqade/validation/analysis/__init__.py +5 -0
  147. bloqade/validation/analysis/analysis.py +41 -0
  148. bloqade/validation/analysis/lattice.py +58 -0
  149. bloqade/validation/kernel_validation.py +77 -0
  150. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  151. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  152. bloqade/pyqrack/squin/op.py +0 -180
  153. bloqade/pyqrack/squin/runtime.py +0 -535
  154. bloqade/pyqrack/squin/wire.py +0 -51
  155. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  156. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  157. bloqade/squin/_typeinfer.py +0 -20
  158. bloqade/squin/analysis/address_impl.py +0 -71
  159. bloqade/squin/analysis/nsites/__init__.py +0 -9
  160. bloqade/squin/analysis/nsites/analysis.py +0 -50
  161. bloqade/squin/analysis/nsites/impls.py +0 -92
  162. bloqade/squin/analysis/nsites/lattice.py +0 -49
  163. bloqade/squin/cirq/__init__.py +0 -280
  164. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  165. bloqade/squin/cirq/emit/noise.py +0 -49
  166. bloqade/squin/cirq/emit/op.py +0 -125
  167. bloqade/squin/cirq/emit/qubit.py +0 -60
  168. bloqade/squin/cirq/emit/runtime.py +0 -242
  169. bloqade/squin/cirq/lowering.py +0 -440
  170. bloqade/squin/lowering.py +0 -54
  171. bloqade/squin/noise/_wrapper.py +0 -40
  172. bloqade/squin/noise/rewrite.py +0 -111
  173. bloqade/squin/op/__init__.py +0 -41
  174. bloqade/squin/op/_dialect.py +0 -3
  175. bloqade/squin/op/_wrapper.py +0 -121
  176. bloqade/squin/op/number.py +0 -5
  177. bloqade/squin/op/rewrite.py +0 -46
  178. bloqade/squin/op/stdlib.py +0 -62
  179. bloqade/squin/op/stmts.py +0 -276
  180. bloqade/squin/op/traits.py +0 -43
  181. bloqade/squin/op/types.py +0 -26
  182. bloqade/squin/qubit.py +0 -184
  183. bloqade/squin/rewrite/canonicalize.py +0 -60
  184. bloqade/squin/rewrite/desugar.py +0 -124
  185. bloqade/squin/types.py +0 -8
  186. bloqade/squin/wire.py +0 -201
  187. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  188. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  189. bloqade_circuit-0.6.4.dist-info/RECORD +0 -234
  190. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  191. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -98,6 +98,8 @@ def loads(
98
98
  signature=func.Signature((), return_node.value.type),
99
99
  body=body,
100
100
  )
101
+ self_arg = ir.BlockArgument(body.blocks[0], 0) # Self argument
102
+ body.blocks[0]._args = (self_arg,)
101
103
  return ir.Method(
102
104
  mod=None,
103
105
  py_func=None,
@@ -627,10 +629,13 @@ class Stim(lowering.LoweringABC[Node]):
627
629
  # Parse tag
628
630
  tag_parts = node.tag.split(";", maxsplit=1)[0].split(":", maxsplit=1)
629
631
  nonstim_name = tag_parts[0]
630
- nonce = 0
631
632
  if len(tag_parts) == 2:
633
+ # This should be a correlated error of the form, e.g.,
634
+ # I_ERROR[correlated_loss:<identifier>](0.01) 0 1 2
635
+ # The identifier is a unique number that prevents stim from merging
636
+ # correlated errors. We discard the identifier, but verify it is an integer.
632
637
  try:
633
- nonce = int(tag_parts[1])
638
+ _ = int(tag_parts[1])
634
639
  except ValueError:
635
640
  # String was not an integer
636
641
  if self.error_unknown_nonstim:
@@ -643,22 +648,14 @@ class Stim(lowering.LoweringABC[Node]):
643
648
  f"Unknown non-stim statement name: {nonstim_name!r} ({node!r})"
644
649
  )
645
650
  statement_cls = self.nonstim_noise_ops.get(nonstim_name)
651
+ stmt = None
646
652
  if statement_cls is not None:
647
- if issubclass(statement_cls, noise.NonStimCorrelatedError):
648
- stmt = statement_cls(
649
- nonce=nonce,
650
- probs=self._get_float_args_ssa(state, node.gate_args_copy()),
651
- targets=self._get_multiple_qubit_or_rec_ssa(
652
- state, node, node.targets_copy()
653
- ),
654
- )
655
- else:
656
- stmt = statement_cls(
657
- probs=self._get_float_args_ssa(state, node.gate_args_copy()),
658
- targets=self._get_multiple_qubit_or_rec_ssa(
659
- state, node, node.targets_copy()
660
- ),
661
- )
653
+ stmt = statement_cls(
654
+ probs=self._get_float_args_ssa(state, node.gate_args_copy()),
655
+ targets=self._get_multiple_qubit_or_rec_ssa(
656
+ state, node, node.targets_copy()
657
+ ),
658
+ )
662
659
  return stmt
663
660
 
664
661
  def visit_CircuitInstruction(
@@ -1,5 +1,3 @@
1
1
  from .squin_to_stim import (
2
2
  SquinToStimPass as SquinToStimPass,
3
- StimSimplifyIfs as StimSimplifyIfs,
4
- AggressiveForLoopUnroll as AggressiveForLoopUnroll,
5
3
  )
@@ -0,0 +1,26 @@
1
+ # Taken from Phillip Weinberg's bloqade-shuttle implementation
2
+ from dataclasses import field, dataclass
3
+
4
+ from kirin import ir
5
+ from kirin.passes import Pass
6
+ from kirin.rewrite.abc import RewriteResult
7
+
8
+ from bloqade.rewrite.passes import AggressiveUnroll
9
+ from bloqade.stim.passes.simplify_ifs import StimSimplifyIfs
10
+
11
+
12
@dataclass
class Flatten(Pass):
    """Simplify ``if`` statements and aggressively unroll a kernel.

    One application runs ``StimSimplifyIfs`` and then ``AggressiveUnroll`` on the
    method. Both sub-passes use this pass's dialect group and ``no_raise`` flag.
    The joined rewrite result is returned. ``SquinToStimPass`` drives this pass
    with ``fixpoint`` before running its analyses.
    """

    # The sub-passes are derived from this pass's own settings in
    # __post_init__, so they are not constructor arguments.
    unroll: AggressiveUnroll = field(init=False)
    simplify_if: StimSimplifyIfs = field(init=False)

    def __post_init__(self):
        dialects, no_raise = self.dialects, self.no_raise
        self.unroll = AggressiveUnroll(dialects, no_raise=no_raise)
        self.simplify_if = StimSimplifyIfs(dialects, no_raise=no_raise)

    def unsafe_run(self, mt: ir.Method) -> RewriteResult:
        """Run if-simplification, then aggressive unrolling, on ``mt``."""
        outcome = RewriteResult()
        # Simplify ifs first, then unroll; accumulate both results.
        for sub_pass in (self.simplify_if, self.unroll):
            outcome = sub_pass(mt).join(outcome)
        return outcome
@@ -7,8 +7,10 @@ from kirin.rewrite import (
7
7
  Chain,
8
8
  Fixpoint,
9
9
  ConstantFold,
10
+ DeadCodeElimination,
10
11
  CommonSubexpressionElimination,
11
12
  )
13
+ from kirin.dialects.scf.trim import UnusedYield
12
14
  from kirin.dialects.ilist.passes import ConstList2IList
13
15
 
14
16
  from ..rewrite.ifs_to_stim import StimLiftThenBody, StimSplitIfStmts
@@ -20,7 +22,10 @@ class StimSimplifyIfs(Pass):
20
22
  def unsafe_run(self, mt: ir.Method):
21
23
 
22
24
  result = Chain(
23
- Fixpoint(Walk(StimLiftThenBody())),
25
+ Walk(UnusedYield()),
26
+ Walk(StimLiftThenBody()),
27
+ # remove yields (if possible), then lift out as much stuff as possible
28
+ Walk(DeadCodeElimination()),
24
29
  Walk(StimSplitIfStmts()),
25
30
  ).rewrite(mt.code)
26
31
 
@@ -1,30 +1,21 @@
1
1
  from dataclasses import dataclass
2
2
 
3
- from kirin.passes import Fold, HintConst, TypeInfer
4
3
  from kirin.rewrite import (
5
4
  Walk,
6
5
  Chain,
7
6
  Fixpoint,
8
- CFGCompactify,
9
- InlineGetItem,
10
- InlineGetField,
11
7
  DeadCodeElimination,
12
8
  CommonSubexpressionElimination,
13
9
  )
14
- from kirin.dialects import scf, ilist
15
10
  from kirin.ir.method import Method
16
11
  from kirin.passes.abc import Pass
17
12
  from kirin.rewrite.abc import RewriteResult
18
- from kirin.passes.inline import InlinePass
19
- from kirin.rewrite.alias import InlineAlias
20
13
 
21
14
  from bloqade.stim.rewrite import (
22
- SquinWireToStim,
23
15
  PyConstantToStim,
24
16
  SquinNoiseToStim,
25
17
  SquinQubitToStim,
26
18
  SquinMeasureToStim,
27
- SquinWireIdentityElimination,
28
19
  )
29
20
  from bloqade.squin.rewrite import (
30
21
  SquinU3ToClifford,
@@ -34,41 +25,9 @@ from bloqade.squin.rewrite import (
34
25
  from bloqade.rewrite.passes import CanonicalizeIList
35
26
  from bloqade.analysis.address import AddressAnalysis
36
27
  from bloqade.analysis.measure_id import MeasurementIDAnalysis
37
- from bloqade.squin.rewrite.desugar import ApplyDesugarRule
28
+ from bloqade.stim.passes.flatten import Flatten
38
29
 
39
- from .simplify_ifs import StimSimplifyIfs
40
- from ..rewrite.ifs_to_stim import IfToStim
41
-
42
-
43
- @dataclass
44
- class AggressiveForLoopUnroll(Pass):
45
- """
46
- Aggressive unrolling of for loops, addresses cases where unroll
47
- does not successfully handle nested loops because of a lack of constprop.
48
-
49
- This should be invoked via fixpoint to let this be repeatedly applied until
50
- no further rewrites are possible.
51
- """
52
-
53
- def unsafe_run(self, mt: Method) -> RewriteResult:
54
- rule = Chain(
55
- InlineGetField(),
56
- InlineGetItem(),
57
- scf.unroll.ForLoop(),
58
- scf.trim.UnusedYield(),
59
- )
60
-
61
- # Intentionally only walk ONCE, let fixpoint happen with the WHOLE pass
62
- # so that HintConst gets run right after, allowing subsequent unrolls to happen
63
- rewrite_result = Walk(rule).rewrite(mt.code)
64
-
65
- rewrite_result = (
66
- HintConst(dialects=mt.dialects, no_raise=self.no_raise)
67
- .unsafe_run(mt)
68
- .join(rewrite_result)
69
- )
70
-
71
- return rewrite_result
30
+ from ..rewrite import IfToStim, SetDetectorToStim, SetObservableToStim
72
31
 
73
32
 
74
33
  @dataclass
@@ -77,52 +36,18 @@ class SquinToStimPass(Pass):
77
36
  def unsafe_run(self, mt: Method) -> RewriteResult:
78
37
 
79
38
  # inline aggressively:
80
- rewrite_result = InlinePass(
81
- dialects=mt.dialects, no_raise=self.no_raise
82
- ).unsafe_run(mt)
83
-
84
- rewrite_result = (
85
- AggressiveForLoopUnroll(dialects=mt.dialects, no_raise=self.no_raise)
86
- .fixpoint(mt)
87
- .join(rewrite_result)
39
+ rewrite_result = Flatten(dialects=mt.dialects, no_raise=self.no_raise).fixpoint(
40
+ mt
88
41
  )
89
42
 
90
- rewrite_result = (
91
- Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(rewrite_result)
92
- )
93
-
94
- Walk(InlineAlias()).rewrite(mt.code).join(rewrite_result)
95
-
96
- rewrite_result = (
97
- StimSimplifyIfs(mt.dialects, no_raise=self.no_raise)
98
- .unsafe_run(mt)
99
- .join(rewrite_result)
100
- )
101
-
102
- rewrite_result = (
103
- Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
104
- .rewrite(mt.code)
105
- .join(rewrite_result)
106
- )
107
- rewrite_result = Fold(mt.dialects, no_raise=self.no_raise)(mt)
108
-
109
- rewrite_result = (
110
- CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
111
- .unsafe_run(mt)
112
- .join(rewrite_result)
113
- )
114
-
115
- TypeInfer(dialects=mt.dialects, no_raise=self.no_raise).unsafe_run(mt)
116
- Walk(ApplyDesugarRule()).rewrite(mt.code)
117
-
118
43
  # after this the program should be in a state where it is analyzable
119
44
  # -------------------------------------------------------------------
120
45
 
121
46
  mia = MeasurementIDAnalysis(dialects=mt.dialects)
122
- meas_analysis_frame, _ = mia.run_analysis(mt, no_raise=self.no_raise)
47
+ meas_analysis_frame, _ = mia.run(mt)
123
48
 
124
49
  aa = AddressAnalysis(dialects=mt.dialects)
125
- address_analysis_frame, _ = aa.run_analysis(mt, no_raise=self.no_raise)
50
+ address_analysis_frame, _ = aa.run(mt)
126
51
 
127
52
  # wrap the address analysis result
128
53
  rewrite_result = (
@@ -139,6 +64,8 @@ class SquinToStimPass(Pass):
139
64
  rewrite_result = (
140
65
  Chain(
141
66
  Walk(IfToStim(measure_frame=meas_analysis_frame)),
67
+ Walk(SetDetectorToStim(measure_id_frame=meas_analysis_frame)),
68
+ Walk(SetObservableToStim(measure_id_frame=meas_analysis_frame)),
142
69
  Fixpoint(Walk(DeadCodeElimination())),
143
70
  )
144
71
  .rewrite(mt.code)
@@ -156,8 +83,6 @@ class SquinToStimPass(Pass):
156
83
  Chain(
157
84
  SquinQubitToStim(),
158
85
  SquinMeasureToStim(),
159
- SquinWireToStim(),
160
- SquinWireIdentityElimination(),
161
86
  )
162
87
  )
163
88
  .rewrite(mt.code)
@@ -174,7 +99,7 @@ class SquinToStimPass(Pass):
174
99
  rewrite_result = Walk(PyConstantToStim()).rewrite(mt.code).join(rewrite_result)
175
100
 
176
101
  # clear up leftover stmts
177
- # - remove any squin.qubit.new that's left around
102
+ # - remove any squin.qalloc that's left around
178
103
  rewrite_result = (
179
104
  Fixpoint(
180
105
  Walk(
@@ -1,9 +1,7 @@
1
1
  from .ifs_to_stim import IfToStim as IfToStim
2
2
  from .squin_noise import SquinNoiseToStim as SquinNoiseToStim
3
- from .wire_to_stim import SquinWireToStim as SquinWireToStim
4
3
  from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
5
4
  from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
6
5
  from .py_constant_to_stim import PyConstantToStim as PyConstantToStim
7
- from .wire_identity_elimination import (
8
- SquinWireIdentityElimination as SquinWireIdentityElimination,
9
- )
6
+ from .set_detector_to_stim import SetDetectorToStim as SetDetectorToStim
7
+ from .set_observable_to_stim import SetObservableToStim as SetObservableToStim
@@ -0,0 +1,24 @@
1
+ from kirin import ir
2
+ from kirin.dialects import py
3
+
4
+ from bloqade.stim.dialects import auxiliary
5
+ from bloqade.analysis.measure_id.lattice import MeasureIdBool, MeasureIdTuple
6
+
7
+
8
def insert_get_records(
    node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count_at_stmt: int
):
    """Insert stim ``GetRecord`` statements in front of ``node``.

    For every entry of ``measure_id_tuple.data``, two statements are inserted
    immediately before ``node``, in this order:

    - a ``py.constant.Constant`` holding the record index
      ``(idx - 1) - meas_count_at_stmt``;
    - an ``auxiliary.GetRecord`` that reads that constant.

    Args:
        node: Statement before which the new statements are inserted.
        measure_id_tuple: Measurement IDs to look up. Every entry must be a
            ``MeasureIdBool`` (checked with ``assert``).
        meas_count_at_stmt: Measurement count associated with ``node``. It is
            subtracted from each ID to form the record index.

    Returns:
        A list with the result SSA value of each inserted ``GetRecord``, in the
        same order as ``measure_id_tuple.data``.
    """
    results = []
    for entry in measure_id_tuple.data:
        assert isinstance(entry, MeasureIdBool)
        # NOTE(review): with 1-based measurement IDs this yields negative,
        # stim-style ``rec[-k]`` offsets (latest measurement -> -1). Confirm.
        record_index = entry.idx - 1 - meas_count_at_stmt

        index_const = py.constant.Constant(record_index)
        index_const.insert_before(node)

        record = auxiliary.GetRecord(index_const.result)
        record.insert_before(node)

        results.append(record.result)
    return results
@@ -4,13 +4,13 @@ from kirin import ir
4
4
  from kirin.dialects import py, scf, func
5
5
  from kirin.rewrite.abc import RewriteRule, RewriteResult
6
6
 
7
- from bloqade.squin import op, qubit
7
+ from bloqade.squin import gate
8
8
  from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts
9
9
  from bloqade.squin.rewrite import AddressAttribute
10
10
  from bloqade.stim.rewrite.util import (
11
- SQUIN_STIM_CONTROL_GATE_MAPPING,
12
11
  insert_qubit_idx_from_address,
13
12
  )
13
+ from bloqade.stim.dialects.gate import CX as stim_CX, CY as stim_CY, CZ as stim_CZ
14
14
  from bloqade.analysis.measure_id import MeasureIDFrame
15
15
  from bloqade.stim.dialects.auxiliary import GetRecord
16
16
  from bloqade.analysis.measure_id.lattice import (
@@ -58,8 +58,7 @@ class IfElseSimplification:
58
58
  """Check if the IfElse statement has an else body."""
59
59
  if stmt.else_body.blocks and not (
60
60
  len(stmt.else_body.blocks[0].stmts) == 1
61
- and isinstance(else_term := stmt.else_body.blocks[0].last_stmt, scf.Yield)
62
- and not else_term.values # empty yield
61
+ and isinstance(stmt.else_body.blocks[0].last_stmt, scf.Yield)
63
62
  ):
64
63
  return True
65
64
 
@@ -67,12 +66,13 @@ class IfElseSimplification:
67
66
 
68
67
 
69
68
  DontLiftType = (
70
- qubit.Apply,
71
- qubit.Broadcast,
72
- scf.Yield,
69
+ gate.stmts.SingleQubitGate,
70
+ gate.stmts.RotationGate,
71
+ gate.stmts.ControlledGate,
73
72
  func.Return,
74
73
  func.Invoke,
75
74
  scf.IfElse,
75
+ scf.Yield,
76
76
  )
77
77
 
78
78
 
@@ -99,16 +99,16 @@ class StimSplitIfStmts(IfElseSimplification, SplitIfStmts):
99
99
  Given an IfElse with multiple valid statements in the then-body:
100
100
 
101
101
  if measure_result:
102
- squin.qubit.apply(op.X, q0)
103
- squin.qubit.apply(op.Y, q1)
102
+ squin.x(q0)
103
+ squin.y(q1)
104
104
 
105
105
  this should be rewritten to:
106
106
 
107
107
  if measure_result:
108
- squin.qubit.apply(op.X, q0)
108
+ squin.x(q0)
109
109
 
110
110
  if measure_result:
111
- squin.qubit.apply(op.Y, q1)
111
+ squin.y(q1)
112
112
  """
113
113
 
114
114
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
@@ -139,24 +139,23 @@ class IfToStim(IfElseSimplification, RewriteRule):
139
139
 
140
140
  def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
141
141
 
142
+ # Check the condition is a singular MeasurementIdBool
142
143
  if not isinstance(self.measure_frame.entries[stmt.cond], MeasureIdBool):
143
144
  return RewriteResult()
144
145
 
145
- # check that there is only qubit.Apply in the then-body,
146
- # if there's more than that, we can't do a valid rewrite.
147
- # Can reuse logic from SplitIf
146
+ # Reusing code from SplitIf,
147
+ # there should only be one statement in the body and it should be a pauli X, Y, or Z
148
148
  *stmts, _ = stmt.then_body.stmts()
149
- if len(stmts) != 1 or not isinstance(stmts[0], (qubit.Apply, qubit.Broadcast)):
149
+ if len(stmts) != 1:
150
150
  return RewriteResult()
151
151
 
152
- apply_or_broadcast = stmts[0]
153
- # Check that the gate being applied/broadcasted can be converted to a stim
154
- # controlled gate.
155
- ctrl_op_target_gate = apply_or_broadcast.operator.owner
156
- assert isinstance(ctrl_op_target_gate, op.stmts.Operator)
157
-
158
- stim_gate = SQUIN_STIM_CONTROL_GATE_MAPPING.get(type(ctrl_op_target_gate))
159
- if stim_gate is None:
152
+ if isinstance(stmts[0], gate.stmts.X):
153
+ stim_gate = stim_CX
154
+ elif isinstance(stmts[0], gate.stmts.Y):
155
+ stim_gate = stim_CY
156
+ elif isinstance(stmts[0], gate.stmts.Z):
157
+ stim_gate = stim_CZ
158
+ else:
160
159
  return RewriteResult()
161
160
 
162
161
  # get necessary measurement ID type from analysis
@@ -169,8 +168,8 @@ class IfToStim(IfElseSimplification, RewriteRule):
169
168
  )
170
169
  get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841
171
170
 
172
- # get address attribute and generate qubit idx statements
173
- address_attr = apply_or_broadcast.qubits.hints.get("address")
171
+ address_attr = stmts[0].qubits.hints.get("address")
172
+
174
173
  if address_attr is None:
175
174
  return RewriteResult()
176
175
  assert isinstance(address_attr, AddressAttribute)
@@ -1,13 +1,11 @@
1
1
  from kirin import ir
2
2
  from kirin.rewrite.abc import RewriteRule, RewriteResult
3
3
 
4
- from bloqade.squin import op, noise, qubit
4
+ from bloqade import qubit
5
+ from bloqade.squin import gate
5
6
  from bloqade.squin.rewrite import AddressAttribute
6
- from bloqade.stim.dialects import gate
7
+ from bloqade.stim.dialects import gate as stim_gate, collapse as stim_collapse
7
8
  from bloqade.stim.rewrite.util import (
8
- SQUIN_STIM_OP_MAPPING,
9
- rewrite_Control,
10
- rewrite_QubitLoss,
11
9
  insert_qubit_idx_from_address,
12
10
  )
13
11
 
@@ -20,64 +18,115 @@ class SquinQubitToStim(RewriteRule):
20
18
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
21
19
 
22
20
  match node:
23
- case qubit.Apply() | qubit.Broadcast():
24
- return self.rewrite_Apply_and_Broadcast(node)
21
+ # not supported by Stim
22
+ case gate.stmts.T() | gate.stmts.RotationGate():
23
+ return RewriteResult()
24
+ # If you've reached this point all gates have stim equivalents
25
+ case qubit.stmts.Reset():
26
+ return self.rewrite_Reset(node)
27
+ case gate.stmts.SingleQubitGate():
28
+ return self.rewrite_SingleQubitGate(node)
29
+ case gate.stmts.ControlledGate():
30
+ return self.rewrite_ControlledGate(node)
25
31
  case _:
26
32
  return RewriteResult()
27
33
 
28
- def rewrite_Apply_and_Broadcast(
29
- self, stmt: qubit.Apply | qubit.Broadcast
30
- ) -> RewriteResult:
31
- """
32
- Rewrite Apply and Broadcast nodes to their stim equivalent statements.
33
- """
34
-
35
- # this is an SSAValue, need it to be the actual operator
36
- applied_op = stmt.operator.owner
34
+ def rewrite_Reset(self, stmt: qubit.stmts.Reset) -> RewriteResult:
37
35
 
38
- if isinstance(applied_op, noise.stmts.QubitLoss):
39
- return rewrite_QubitLoss(stmt)
36
+ qubit_addr_attr = stmt.qubits.hints.get("address", None)
40
37
 
41
- assert isinstance(applied_op, op.stmts.Operator)
38
+ if qubit_addr_attr is None:
39
+ return RewriteResult()
42
40
 
43
- if isinstance(applied_op, op.stmts.Control):
44
- return rewrite_Control(stmt)
41
+ assert isinstance(qubit_addr_attr, AddressAttribute)
45
42
 
46
- # need to handle Control through separate means
43
+ qubit_idx_ssas = insert_qubit_idx_from_address(
44
+ address=qubit_addr_attr, stmt_to_insert_before=stmt
45
+ )
47
46
 
48
- # check if its adjoint, assume its canonicalized so no nested adjoints.
49
- is_conj = False
50
- if isinstance(applied_op, op.stmts.Adjoint):
51
- if not applied_op.is_unitary:
52
- return RewriteResult()
47
+ if qubit_idx_ssas is None:
48
+ return RewriteResult()
53
49
 
54
- is_conj = True
55
- applied_op = applied_op.op.owner
50
+ stim_stmt = stim_collapse.RZ(targets=tuple(qubit_idx_ssas))
51
+ stmt.replace_by(stim_stmt)
56
52
 
57
- stim_1q_op = SQUIN_STIM_OP_MAPPING.get(type(applied_op))
58
- if stim_1q_op is None:
59
- return RewriteResult()
53
+ return RewriteResult(has_done_something=True)
60
54
 
61
- address_attr = stmt.qubits.hints.get("address")
55
+ def rewrite_SingleQubitGate(
56
+ self, stmt: gate.stmts.SingleQubitGate
57
+ ) -> RewriteResult:
58
+ """
59
+ Rewrite single qubit gate nodes to their stim equivalent statements.
60
+ Address Analysis should have been run along with Wrap Analysis before this rewrite is applied.
61
+ """
62
62
 
63
- if address_attr is None:
63
+ qubit_addr_attr = stmt.qubits.hints.get("address", None)
64
+ if qubit_addr_attr is None:
64
65
  return RewriteResult()
65
66
 
66
- assert isinstance(address_attr, AddressAttribute)
67
+ assert isinstance(qubit_addr_attr, AddressAttribute)
68
+
67
69
  qubit_idx_ssas = insert_qubit_idx_from_address(
68
- address=address_attr, stmt_to_insert_before=stmt
70
+ address=qubit_addr_attr, stmt_to_insert_before=stmt
69
71
  )
70
72
 
71
73
  if qubit_idx_ssas is None:
72
74
  return RewriteResult()
73
75
 
74
- if isinstance(stim_1q_op, gate.stmts.Gate):
75
- stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas), dagger=is_conj)
76
+ # Get the name of the inputted stmt and see if there is an
77
+ # equivalently named statement in stim,
78
+ # then create an instance of that stim statement
79
+ stmt_name = type(stmt).__name__
80
+ stim_stmt_cls = getattr(stim_gate.stmts, stmt_name, None)
81
+ if stim_stmt_cls is None:
82
+ return RewriteResult()
83
+
84
+ if isinstance(stmt, gate.stmts.SingleQubitNonHermitianGate):
85
+ stim_stmt = stim_stmt_cls(
86
+ targets=tuple(qubit_idx_ssas), dagger=stmt.adjoint
87
+ )
76
88
  else:
77
- stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
78
- stmt.replace_by(stim_1q_stmt)
89
+ stim_stmt = stim_stmt_cls(targets=tuple(qubit_idx_ssas))
90
+ stmt.replace_by(stim_stmt)
79
91
 
80
92
  return RewriteResult(has_done_something=True)
81
93
 
94
+ def rewrite_ControlledGate(self, stmt: gate.stmts.ControlledGate) -> RewriteResult:
95
+ """
96
+ Rewrite controlled gate nodes to their stim equivalent statements.
97
+ Address Analysis should have been run along with Wrap Analysis before this rewrite is applied.
98
+ """
82
99
 
83
- # put rewrites for measure statements in separate rule, then just have to dispatch
100
+ controls_addr_attr = stmt.controls.hints.get("address", None)
101
+ targets_addr_attr = stmt.targets.hints.get("address", None)
102
+
103
+ if controls_addr_attr is None or targets_addr_attr is None:
104
+ return RewriteResult()
105
+
106
+ assert isinstance(controls_addr_attr, AddressAttribute)
107
+ assert isinstance(targets_addr_attr, AddressAttribute)
108
+
109
+ controls_idx_ssas = insert_qubit_idx_from_address(
110
+ address=controls_addr_attr, stmt_to_insert_before=stmt
111
+ )
112
+ targets_idx_ssas = insert_qubit_idx_from_address(
113
+ address=targets_addr_attr, stmt_to_insert_before=stmt
114
+ )
115
+
116
+ if controls_idx_ssas is None or targets_idx_ssas is None:
117
+ return RewriteResult()
118
+
119
+ # Get the name of the inputted stmt and see if there is an
120
+ # equivalently named statement in stim,
121
+ # then create an instance of that stim statement
122
+ stmt_name = type(stmt).__name__
123
+ stim_stmt_cls = getattr(stim_gate.stmts, stmt_name, None)
124
+ if stim_stmt_cls is None:
125
+ return RewriteResult()
126
+
127
+ stim_stmt = stim_stmt_cls(
128
+ targets=tuple(targets_idx_ssas), controls=tuple(controls_idx_ssas)
129
+ )
130
+ stmt.replace_by(stim_stmt)
131
+
132
+ return RewriteResult(has_done_something=True)
@@ -0,0 +1,68 @@
1
+ from typing import Iterable
2
+ from dataclasses import dataclass
3
+
4
+ from kirin import ir
5
+ from kirin.dialects.py import Constant
6
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
7
+
8
+ from bloqade.stim.dialects import auxiliary
9
+ from bloqade.annotate.stmts import SetDetector
10
+ from bloqade.analysis.measure_id import MeasureIDFrame
11
+ from bloqade.stim.dialects.auxiliary import Detector
12
+ from bloqade.analysis.measure_id.lattice import MeasureIdTuple
13
+
14
+ from ..rewrite.get_record_util import insert_get_records
15
+
16
+
17
+ @dataclass
18
+ class SetDetectorToStim(RewriteRule):
19
+ """
20
+ Rewrite SetDetector to GetRecord and Detector in the stim dialect
21
+ """
22
+
23
+ measure_id_frame: MeasureIDFrame
24
+
25
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
26
+ match node:
27
+ case SetDetector():
28
+ return self.rewrite_SetDetector(node)
29
+ case _:
30
+ return RewriteResult()
31
+
32
+ def rewrite_SetDetector(self, node: SetDetector) -> RewriteResult:
33
+
34
+ # get coordinates and generate correct consts
35
+ coord_ssas = []
36
+ if not isinstance(node.coordinates.owner, Constant):
37
+ return RewriteResult()
38
+
39
+ coord_values = node.coordinates.owner.value.unwrap()
40
+
41
+ if not isinstance(coord_values, Iterable):
42
+ return RewriteResult()
43
+
44
+ if any(not isinstance(value, (int, float)) for value in coord_values):
45
+ return RewriteResult()
46
+
47
+ for coord_value in coord_values:
48
+ if isinstance(coord_value, float):
49
+ coord_stmt = auxiliary.ConstFloat(value=coord_value)
50
+ else: # int
51
+ coord_stmt = auxiliary.ConstInt(value=coord_value)
52
+ coord_ssas.append(coord_stmt.result)
53
+ coord_stmt.insert_before(node)
54
+
55
+ measure_ids = self.measure_id_frame.entries[node.measurements]
56
+ assert isinstance(measure_ids, MeasureIdTuple)
57
+
58
+ get_record_list = insert_get_records(
59
+ node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node]
60
+ )
61
+
62
+ detector_stmt = Detector(
63
+ coord=tuple(coord_ssas), targets=tuple(get_record_list)
64
+ )
65
+
66
+ node.replace_by(detector_stmt)
67
+
68
+ return RewriteResult(has_done_something=True)