bloqade-circuit 0.6.2__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192) hide show
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/__init__.py +4 -1
  7. bloqade/analysis/measure_id/analysis.py +29 -20
  8. bloqade/analysis/measure_id/impls.py +72 -31
  9. bloqade/annotate/__init__.py +6 -0
  10. bloqade/annotate/_dialect.py +3 -0
  11. bloqade/annotate/_interface.py +22 -0
  12. bloqade/annotate/stmts.py +29 -0
  13. bloqade/annotate/types.py +13 -0
  14. bloqade/cirq_utils/__init__.py +4 -2
  15. bloqade/cirq_utils/emit/__init__.py +3 -0
  16. bloqade/cirq_utils/emit/base.py +246 -0
  17. bloqade/cirq_utils/emit/gate.py +104 -0
  18. bloqade/cirq_utils/emit/noise.py +90 -0
  19. bloqade/cirq_utils/emit/qubit.py +35 -0
  20. bloqade/cirq_utils/lowering.py +660 -0
  21. bloqade/cirq_utils/noise/__init__.py +0 -2
  22. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  23. bloqade/cirq_utils/noise/model.py +151 -191
  24. bloqade/cirq_utils/noise/transform.py +2 -2
  25. bloqade/cirq_utils/parallelize.py +9 -6
  26. bloqade/gemini/__init__.py +1 -0
  27. bloqade/gemini/analysis/__init__.py +3 -0
  28. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  29. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  30. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  31. bloqade/gemini/groups.py +67 -0
  32. bloqade/native/__init__.py +23 -0
  33. bloqade/native/_prelude.py +45 -0
  34. bloqade/native/dialects/__init__.py +0 -0
  35. bloqade/native/dialects/gate/__init__.py +2 -0
  36. bloqade/native/dialects/gate/_dialect.py +3 -0
  37. bloqade/native/dialects/gate/_interface.py +32 -0
  38. bloqade/native/dialects/gate/stmts.py +31 -0
  39. bloqade/native/stdlib/__init__.py +0 -0
  40. bloqade/native/stdlib/broadcast.py +246 -0
  41. bloqade/native/stdlib/simple.py +220 -0
  42. bloqade/native/upstream/__init__.py +4 -0
  43. bloqade/native/upstream/squin2native.py +79 -0
  44. bloqade/pyqrack/__init__.py +2 -2
  45. bloqade/pyqrack/base.py +7 -1
  46. bloqade/pyqrack/device.py +190 -4
  47. bloqade/pyqrack/native.py +49 -0
  48. bloqade/pyqrack/reg.py +6 -6
  49. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  50. bloqade/pyqrack/squin/gate/gate.py +136 -0
  51. bloqade/pyqrack/squin/noise/native.py +120 -54
  52. bloqade/pyqrack/squin/qubit.py +39 -36
  53. bloqade/pyqrack/target.py +5 -4
  54. bloqade/pyqrack/task.py +114 -7
  55. bloqade/qasm2/_qasm_loading.py +3 -3
  56. bloqade/qasm2/dialects/core/address.py +21 -12
  57. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  58. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  59. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  60. bloqade/qasm2/dialects/noise/model.py +2 -1
  61. bloqade/qasm2/emit/base.py +16 -11
  62. bloqade/qasm2/emit/gate.py +11 -8
  63. bloqade/qasm2/emit/main.py +103 -3
  64. bloqade/qasm2/emit/target.py +9 -5
  65. bloqade/qasm2/groups.py +3 -2
  66. bloqade/qasm2/parse/lowering.py +0 -1
  67. bloqade/qasm2/passes/fold.py +14 -73
  68. bloqade/qasm2/passes/glob.py +2 -2
  69. bloqade/qasm2/passes/noise.py +1 -1
  70. bloqade/qasm2/passes/parallel.py +7 -5
  71. bloqade/qasm2/rewrite/__init__.py +0 -1
  72. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  73. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  74. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  75. bloqade/qasm2/rewrite/register.py +2 -2
  76. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  77. bloqade/qbraid/lowering.py +1 -0
  78. bloqade/qbraid/schema.py +2 -2
  79. bloqade/qubit/__init__.py +12 -0
  80. bloqade/qubit/_dialect.py +3 -0
  81. bloqade/qubit/_interface.py +49 -0
  82. bloqade/qubit/_prelude.py +45 -0
  83. bloqade/qubit/analysis/__init__.py +1 -0
  84. bloqade/qubit/analysis/address_impl.py +40 -0
  85. bloqade/qubit/stdlib/__init__.py +2 -0
  86. bloqade/qubit/stdlib/_new.py +34 -0
  87. bloqade/qubit/stdlib/broadcast.py +62 -0
  88. bloqade/qubit/stdlib/simple.py +59 -0
  89. bloqade/qubit/stmts.py +60 -0
  90. bloqade/rewrite/passes/__init__.py +6 -0
  91. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  92. bloqade/rewrite/passes/callgraph.py +116 -0
  93. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  94. bloqade/rewrite/rules/split_ifs.py +18 -1
  95. bloqade/squin/__init__.py +47 -14
  96. bloqade/squin/analysis/__init__.py +0 -1
  97. bloqade/squin/analysis/schedule.py +10 -11
  98. bloqade/squin/gate/__init__.py +2 -0
  99. bloqade/squin/gate/_dialect.py +3 -0
  100. bloqade/squin/gate/_interface.py +98 -0
  101. bloqade/squin/gate/stmts.py +125 -0
  102. bloqade/squin/groups.py +5 -22
  103. bloqade/squin/noise/__init__.py +1 -10
  104. bloqade/squin/noise/_dialect.py +1 -1
  105. bloqade/squin/noise/_interface.py +45 -0
  106. bloqade/squin/noise/stmts.py +66 -28
  107. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  108. bloqade/squin/rewrite/__init__.py +0 -2
  109. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  110. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  111. bloqade/squin/stdlib/__init__.py +0 -0
  112. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  113. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  114. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  115. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  116. bloqade/squin/stdlib/simple/__init__.py +33 -0
  117. bloqade/squin/stdlib/simple/gate.py +242 -0
  118. bloqade/squin/stdlib/simple/noise.py +126 -0
  119. bloqade/stim/__init__.py +1 -0
  120. bloqade/stim/_wrappers.py +6 -0
  121. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  122. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  123. bloqade/stim/dialects/gate/emit.py +9 -10
  124. bloqade/stim/dialects/noise/emit.py +17 -13
  125. bloqade/stim/dialects/noise/stmts.py +5 -3
  126. bloqade/stim/emit/__init__.py +1 -0
  127. bloqade/stim/emit/impls.py +16 -0
  128. bloqade/stim/emit/stim_str.py +48 -31
  129. bloqade/stim/groups.py +12 -2
  130. bloqade/stim/parse/lowering.py +14 -17
  131. bloqade/stim/passes/__init__.py +3 -1
  132. bloqade/stim/passes/flatten.py +26 -0
  133. bloqade/stim/passes/simplify_ifs.py +16 -2
  134. bloqade/stim/passes/squin_to_stim.py +18 -60
  135. bloqade/stim/rewrite/__init__.py +3 -4
  136. bloqade/stim/rewrite/get_record_util.py +24 -0
  137. bloqade/stim/rewrite/ifs_to_stim.py +29 -31
  138. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  139. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  140. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  141. bloqade/stim/rewrite/squin_measure.py +11 -79
  142. bloqade/stim/rewrite/squin_noise.py +134 -108
  143. bloqade/stim/rewrite/util.py +5 -192
  144. bloqade/test_utils.py +1 -1
  145. bloqade/types.py +10 -0
  146. bloqade/validation/__init__.py +2 -0
  147. bloqade/validation/analysis/__init__.py +5 -0
  148. bloqade/validation/analysis/analysis.py +41 -0
  149. bloqade/validation/analysis/lattice.py +58 -0
  150. bloqade/validation/kernel_validation.py +77 -0
  151. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  152. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  153. bloqade/pyqrack/squin/op.py +0 -166
  154. bloqade/pyqrack/squin/runtime.py +0 -535
  155. bloqade/pyqrack/squin/wire.py +0 -51
  156. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  157. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  158. bloqade/squin/_typeinfer.py +0 -20
  159. bloqade/squin/analysis/address_impl.py +0 -71
  160. bloqade/squin/analysis/nsites/__init__.py +0 -9
  161. bloqade/squin/analysis/nsites/analysis.py +0 -50
  162. bloqade/squin/analysis/nsites/impls.py +0 -92
  163. bloqade/squin/analysis/nsites/lattice.py +0 -49
  164. bloqade/squin/cirq/__init__.py +0 -265
  165. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  166. bloqade/squin/cirq/emit/noise.py +0 -49
  167. bloqade/squin/cirq/emit/op.py +0 -125
  168. bloqade/squin/cirq/emit/qubit.py +0 -60
  169. bloqade/squin/cirq/emit/runtime.py +0 -242
  170. bloqade/squin/cirq/lowering.py +0 -440
  171. bloqade/squin/lowering.py +0 -54
  172. bloqade/squin/noise/_wrapper.py +0 -40
  173. bloqade/squin/noise/rewrite.py +0 -111
  174. bloqade/squin/op/__init__.py +0 -41
  175. bloqade/squin/op/_dialect.py +0 -3
  176. bloqade/squin/op/_wrapper.py +0 -121
  177. bloqade/squin/op/number.py +0 -5
  178. bloqade/squin/op/rewrite.py +0 -46
  179. bloqade/squin/op/stdlib.py +0 -62
  180. bloqade/squin/op/stmts.py +0 -276
  181. bloqade/squin/op/traits.py +0 -43
  182. bloqade/squin/op/types.py +0 -26
  183. bloqade/squin/qubit.py +0 -184
  184. bloqade/squin/rewrite/canonicalize.py +0 -60
  185. bloqade/squin/rewrite/desugar.py +0 -124
  186. bloqade/squin/types.py +0 -8
  187. bloqade/squin/wire.py +0 -201
  188. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  189. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  190. bloqade_circuit-0.6.2.dist-info/RECORD +0 -234
  191. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  192. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -98,6 +98,8 @@ def loads(
98
98
  signature=func.Signature((), return_node.value.type),
99
99
  body=body,
100
100
  )
101
+ self_arg = ir.BlockArgument(body.blocks[0], 0) # Self argument
102
+ body.blocks[0]._args = (self_arg,)
101
103
  return ir.Method(
102
104
  mod=None,
103
105
  py_func=None,
@@ -627,10 +629,13 @@ class Stim(lowering.LoweringABC[Node]):
627
629
  # Parse tag
628
630
  tag_parts = node.tag.split(";", maxsplit=1)[0].split(":", maxsplit=1)
629
631
  nonstim_name = tag_parts[0]
630
- nonce = 0
631
632
  if len(tag_parts) == 2:
633
+ # This should be a correlated error of the form, e.g.,
634
+ # I_ERROR[correlated_loss:<identifier>](0.01) 0 1 2
635
+ # The identifier is a unique number that prevents stim from merging
636
+ # correlated errors. We discard the identifier, but verify it is an integer.
632
637
  try:
633
- nonce = int(tag_parts[1])
638
+ _ = int(tag_parts[1])
634
639
  except ValueError:
635
640
  # String was not an integer
636
641
  if self.error_unknown_nonstim:
@@ -643,22 +648,14 @@ class Stim(lowering.LoweringABC[Node]):
643
648
  f"Unknown non-stim statement name: {nonstim_name!r} ({node!r})"
644
649
  )
645
650
  statement_cls = self.nonstim_noise_ops.get(nonstim_name)
651
+ stmt = None
646
652
  if statement_cls is not None:
647
- if issubclass(statement_cls, noise.NonStimCorrelatedError):
648
- stmt = statement_cls(
649
- nonce=nonce,
650
- probs=self._get_float_args_ssa(state, node.gate_args_copy()),
651
- targets=self._get_multiple_qubit_or_rec_ssa(
652
- state, node, node.targets_copy()
653
- ),
654
- )
655
- else:
656
- stmt = statement_cls(
657
- probs=self._get_float_args_ssa(state, node.gate_args_copy()),
658
- targets=self._get_multiple_qubit_or_rec_ssa(
659
- state, node, node.targets_copy()
660
- ),
661
- )
653
+ stmt = statement_cls(
654
+ probs=self._get_float_args_ssa(state, node.gate_args_copy()),
655
+ targets=self._get_multiple_qubit_or_rec_ssa(
656
+ state, node, node.targets_copy()
657
+ ),
658
+ )
662
659
  return stmt
663
660
 
664
661
  def visit_CircuitInstruction(
@@ -1 +1,3 @@
1
- from .squin_to_stim import SquinToStimPass as SquinToStimPass
1
+ from .squin_to_stim import (
2
+ SquinToStimPass as SquinToStimPass,
3
+ )
@@ -0,0 +1,26 @@
1
+ # Taken from Phillip Weinberg's bloqade-shuttle implementation
2
+ from dataclasses import field, dataclass
3
+
4
+ from kirin import ir
5
+ from kirin.passes import Pass
6
+ from kirin.rewrite.abc import RewriteResult
7
+
8
+ from bloqade.rewrite.passes import AggressiveUnroll
9
+ from bloqade.stim.passes.simplify_ifs import StimSimplifyIfs
10
+
11
+
12
+ @dataclass
13
+ class Flatten(Pass):
14
+
15
+ unroll: AggressiveUnroll = field(init=False)
16
+ simplify_if: StimSimplifyIfs = field(init=False)
17
+
18
+ def __post_init__(self):
19
+ self.unroll = AggressiveUnroll(self.dialects, no_raise=self.no_raise)
20
+ self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise)
21
+
22
+ def unsafe_run(self, mt: ir.Method) -> RewriteResult:
23
+ rewrite_result = RewriteResult()
24
+ rewrite_result = self.simplify_if(mt).join(rewrite_result)
25
+ rewrite_result = self.unroll(mt).join(rewrite_result)
26
+ return rewrite_result
@@ -7,8 +7,11 @@ from kirin.rewrite import (
7
7
  Chain,
8
8
  Fixpoint,
9
9
  ConstantFold,
10
+ DeadCodeElimination,
10
11
  CommonSubexpressionElimination,
11
12
  )
13
+ from kirin.dialects.scf.trim import UnusedYield
14
+ from kirin.dialects.ilist.passes import ConstList2IList
12
15
 
13
16
  from ..rewrite.ifs_to_stim import StimLiftThenBody, StimSplitIfStmts
14
17
 
@@ -19,12 +22,23 @@ class StimSimplifyIfs(Pass):
19
22
  def unsafe_run(self, mt: ir.Method):
20
23
 
21
24
  result = Chain(
22
- Fixpoint(Walk(StimLiftThenBody())),
25
+ Walk(UnusedYield()),
26
+ Walk(StimLiftThenBody()),
27
+ # remove yields (if possible), then lift out as much stuff as possible
28
+ Walk(DeadCodeElimination()),
23
29
  Walk(StimSplitIfStmts()),
24
30
  ).rewrite(mt.code)
25
31
 
32
+ # because nested python lists don't have their
33
+ # member lists converted to ILists, ConstantFold
34
+ # can add python lists that can't be hashed, causing
35
+ # issues with CSE. ConstList2IList remedies that problem here.
26
36
  result = (
27
- Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))
37
+ Chain(
38
+ Fixpoint(Walk(ConstantFold())),
39
+ Walk(ConstList2IList()),
40
+ Walk(CommonSubexpressionElimination()),
41
+ )
28
42
  .rewrite(mt.code)
29
43
  .join(result)
30
44
  )
@@ -1,29 +1,21 @@
1
1
  from dataclasses import dataclass
2
2
 
3
- from kirin.passes import Fold
4
3
  from kirin.rewrite import (
5
4
  Walk,
6
5
  Chain,
7
6
  Fixpoint,
8
- CFGCompactify,
9
- InlineGetItem,
10
- InlineGetField,
11
7
  DeadCodeElimination,
12
8
  CommonSubexpressionElimination,
13
9
  )
14
- from kirin.dialects import scf, ilist
15
10
  from kirin.ir.method import Method
16
11
  from kirin.passes.abc import Pass
17
12
  from kirin.rewrite.abc import RewriteResult
18
- from kirin.passes.inline import InlinePass
19
13
 
20
14
  from bloqade.stim.rewrite import (
21
- SquinWireToStim,
22
15
  PyConstantToStim,
23
16
  SquinNoiseToStim,
24
17
  SquinQubitToStim,
25
18
  SquinMeasureToStim,
26
- SquinWireIdentityElimination,
27
19
  )
28
20
  from bloqade.squin.rewrite import (
29
21
  SquinU3ToClifford,
@@ -33,9 +25,9 @@ from bloqade.squin.rewrite import (
33
25
  from bloqade.rewrite.passes import CanonicalizeIList
34
26
  from bloqade.analysis.address import AddressAnalysis
35
27
  from bloqade.analysis.measure_id import MeasurementIDAnalysis
28
+ from bloqade.stim.passes.flatten import Flatten
36
29
 
37
- from .simplify_ifs import StimSimplifyIfs
38
- from ..rewrite.ifs_to_stim import IfToStim
30
+ from ..rewrite import IfToStim, SetDetectorToStim, SetObservableToStim
39
31
 
40
32
 
41
33
  @dataclass
@@ -44,52 +36,18 @@ class SquinToStimPass(Pass):
44
36
  def unsafe_run(self, mt: Method) -> RewriteResult:
45
37
 
46
38
  # inline aggressively:
47
- rewrite_result = InlinePass(
48
- dialects=mt.dialects, no_raise=self.no_raise
49
- ).unsafe_run(mt)
50
-
51
- rule = Chain(
52
- InlineGetField(),
53
- InlineGetItem(),
54
- scf.unroll.ForLoop(),
55
- scf.trim.UnusedYield(),
56
- )
57
- rewrite_result = Fixpoint(Walk(rule)).rewrite(mt.code).join(rewrite_result)
58
- # fold_pass = Fold(mt.dialects, no_raise=self.no_raise)
59
- # rewrite_result = fold_pass(mt)
60
- rewrite_result = (
61
- Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(rewrite_result)
62
- )
63
- rewrite_result = (
64
- StimSimplifyIfs(mt.dialects, no_raise=self.no_raise)
65
- .unsafe_run(mt)
66
- .join(rewrite_result)
67
- )
68
-
69
- # run typeinfer again after unroll etc. because we now insert
70
- # a lot of new nodes, which might have more precise types
71
- # self.typeinfer.unsafe_run(mt)
72
- rewrite_result = (
73
- Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
74
- .rewrite(mt.code)
75
- .join(rewrite_result)
76
- )
77
- rewrite_result = Fold(mt.dialects, no_raise=self.no_raise)(mt)
78
-
79
- rewrite_result = (
80
- CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
81
- .unsafe_run(mt)
82
- .join(rewrite_result)
39
+ rewrite_result = Flatten(dialects=mt.dialects, no_raise=self.no_raise).fixpoint(
40
+ mt
83
41
  )
84
42
 
85
43
  # after this the program should be in a state where it is analyzable
86
44
  # -------------------------------------------------------------------
87
45
 
88
46
  mia = MeasurementIDAnalysis(dialects=mt.dialects)
89
- meas_analysis_frame, _ = mia.run_analysis(mt, no_raise=self.no_raise)
47
+ meas_analysis_frame, _ = mia.run(mt)
90
48
 
91
49
  aa = AddressAnalysis(dialects=mt.dialects)
92
- address_analysis_frame, _ = aa.run_analysis(mt, no_raise=self.no_raise)
50
+ address_analysis_frame, _ = aa.run(mt)
93
51
 
94
52
  # wrap the address analysis result
95
53
  rewrite_result = (
@@ -99,12 +57,16 @@ class SquinToStimPass(Pass):
99
57
  )
100
58
 
101
59
  # 2. rewrite
60
+ ## Invoke DCE afterwards to eliminate any GetItems
61
+ ## that are no longer being used. This allows for
62
+ ## SquinMeasureToStim to safely eliminate
63
+ ## unused measure statements.
102
64
  rewrite_result = (
103
- Walk(
104
- IfToStim(
105
- measure_analysis=meas_analysis_frame.entries,
106
- measure_count=mia.measure_count,
107
- )
65
+ Chain(
66
+ Walk(IfToStim(measure_frame=meas_analysis_frame)),
67
+ Walk(SetDetectorToStim(measure_id_frame=meas_analysis_frame)),
68
+ Walk(SetObservableToStim(measure_id_frame=meas_analysis_frame)),
69
+ Fixpoint(Walk(DeadCodeElimination())),
108
70
  )
109
71
  .rewrite(mt.code)
110
72
  .join(rewrite_result)
@@ -120,17 +82,13 @@ class SquinToStimPass(Pass):
120
82
  Walk(
121
83
  Chain(
122
84
  SquinQubitToStim(),
123
- SquinWireToStim(),
124
- SquinMeasureToStim(
125
- measure_id_result=meas_analysis_frame.entries,
126
- total_measure_count=mia.measure_count,
127
- ), # reduce duplicated logic, can split out even more rules later
128
- SquinWireIdentityElimination(),
85
+ SquinMeasureToStim(),
129
86
  )
130
87
  )
131
88
  .rewrite(mt.code)
132
89
  .join(rewrite_result)
133
90
  )
91
+
134
92
  rewrite_result = (
135
93
  CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
136
94
  .unsafe_run(mt)
@@ -141,7 +99,7 @@ class SquinToStimPass(Pass):
141
99
  rewrite_result = Walk(PyConstantToStim()).rewrite(mt.code).join(rewrite_result)
142
100
 
143
101
  # clear up leftover stmts
144
- # - remove any squin.qubit.new that's left around
102
+ # - remove any squin.qalloc that's left around
145
103
  rewrite_result = (
146
104
  Fixpoint(
147
105
  Walk(
@@ -1,8 +1,7 @@
1
+ from .ifs_to_stim import IfToStim as IfToStim
1
2
  from .squin_noise import SquinNoiseToStim as SquinNoiseToStim
2
- from .wire_to_stim import SquinWireToStim as SquinWireToStim
3
3
  from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
4
4
  from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
5
5
  from .py_constant_to_stim import PyConstantToStim as PyConstantToStim
6
- from .wire_identity_elimination import (
7
- SquinWireIdentityElimination as SquinWireIdentityElimination,
8
- )
6
+ from .set_detector_to_stim import SetDetectorToStim as SetDetectorToStim
7
+ from .set_observable_to_stim import SetObservableToStim as SetObservableToStim
@@ -0,0 +1,24 @@
1
+ from kirin import ir
2
+ from kirin.dialects import py
3
+
4
+ from bloqade.stim.dialects import auxiliary
5
+ from bloqade.analysis.measure_id.lattice import MeasureIdBool, MeasureIdTuple
6
+
7
+
8
+ def insert_get_records(
9
+ node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count_at_stmt: int
10
+ ):
11
+ """
12
+ Insert GetRecord statements before the given node
13
+ """
14
+ get_record_ssas = []
15
+ for measure_id_bool in measure_id_tuple.data:
16
+ assert isinstance(measure_id_bool, MeasureIdBool)
17
+ target_rec_idx = (measure_id_bool.idx - 1) - meas_count_at_stmt
18
+ idx_stmt = py.constant.Constant(target_rec_idx)
19
+ idx_stmt.insert_before(node)
20
+ get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
21
+ get_record_stmt.insert_before(node)
22
+ get_record_ssas.append(get_record_stmt.result)
23
+
24
+ return get_record_ssas
@@ -4,16 +4,16 @@ from kirin import ir
4
4
  from kirin.dialects import py, scf, func
5
5
  from kirin.rewrite.abc import RewriteRule, RewriteResult
6
6
 
7
- from bloqade.squin import op, qubit
7
+ from bloqade.squin import gate
8
8
  from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts
9
9
  from bloqade.squin.rewrite import AddressAttribute
10
10
  from bloqade.stim.rewrite.util import (
11
- SQUIN_STIM_CONTROL_GATE_MAPPING,
12
11
  insert_qubit_idx_from_address,
13
12
  )
13
+ from bloqade.stim.dialects.gate import CX as stim_CX, CY as stim_CY, CZ as stim_CZ
14
+ from bloqade.analysis.measure_id import MeasureIDFrame
14
15
  from bloqade.stim.dialects.auxiliary import GetRecord
15
16
  from bloqade.analysis.measure_id.lattice import (
16
- MeasureId,
17
17
  MeasureIdBool,
18
18
  )
19
19
 
@@ -58,8 +58,7 @@ class IfElseSimplification:
58
58
  """Check if the IfElse statement has an else body."""
59
59
  if stmt.else_body.blocks and not (
60
60
  len(stmt.else_body.blocks[0].stmts) == 1
61
- and isinstance(else_term := stmt.else_body.blocks[0].last_stmt, scf.Yield)
62
- and not else_term.values # empty yield
61
+ and isinstance(stmt.else_body.blocks[0].last_stmt, scf.Yield)
63
62
  ):
64
63
  return True
65
64
 
@@ -67,12 +66,13 @@ class IfElseSimplification:
67
66
 
68
67
 
69
68
  DontLiftType = (
70
- qubit.Apply,
71
- qubit.Broadcast,
72
- scf.Yield,
69
+ gate.stmts.SingleQubitGate,
70
+ gate.stmts.RotationGate,
71
+ gate.stmts.ControlledGate,
73
72
  func.Return,
74
73
  func.Invoke,
75
74
  scf.IfElse,
75
+ scf.Yield,
76
76
  )
77
77
 
78
78
 
@@ -99,16 +99,16 @@ class StimSplitIfStmts(IfElseSimplification, SplitIfStmts):
99
99
  Given an IfElse with multiple valid statements in the then-body:
100
100
 
101
101
  if measure_result:
102
- squin.qubit.apply(op.X, q0)
103
- squin.qubit.apply(op.Y, q1)
102
+ squin.x(q0)
103
+ squin.y(q1)
104
104
 
105
105
  this should be rewritten to:
106
106
 
107
107
  if measure_result:
108
- squin.qubit.apply(op.X, q0)
108
+ squin.x(q0)
109
109
 
110
110
  if measure_result:
111
- squin.qubit.apply(op.Y, q1)
111
+ squin.y(q1)
112
112
  """
113
113
 
114
114
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
@@ -127,8 +127,7 @@ class IfToStim(IfElseSimplification, RewriteRule):
127
127
  Rewrite if statements to stim equivalent statements.
128
128
  """
129
129
 
130
- measure_analysis: dict[ir.SSAValue, MeasureId]
131
- measure_count: int
130
+ measure_frame: MeasureIDFrame
132
131
 
133
132
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
134
133
 
@@ -140,38 +139,37 @@ class IfToStim(IfElseSimplification, RewriteRule):
140
139
 
141
140
  def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
142
141
 
143
- if not isinstance(self.measure_analysis[stmt.cond], MeasureIdBool):
142
+ # Check the condition is a singular MeasurementIdBool
143
+ if not isinstance(self.measure_frame.entries[stmt.cond], MeasureIdBool):
144
144
  return RewriteResult()
145
145
 
146
- # check that there is only qubit.Apply in the then-body,
147
- # if there's more than that, we can't do a valid rewrite.
148
- # Can reuse logic from SplitIf
146
+ # Reusing code from SplitIf,
147
+ # there should only be one statement in the body and it should be a pauli X, Y, or Z
149
148
  *stmts, _ = stmt.then_body.stmts()
150
- if len(stmts) != 1 or not isinstance(stmts[0], (qubit.Apply, qubit.Broadcast)):
149
+ if len(stmts) != 1:
151
150
  return RewriteResult()
152
151
 
153
- apply_or_broadcast = stmts[0]
154
- # Check that the gate being applied/broadcasted can be converted to a stim
155
- # controlled gate.
156
- ctrl_op_target_gate = apply_or_broadcast.operator.owner
157
- assert isinstance(ctrl_op_target_gate, op.stmts.Operator)
158
-
159
- stim_gate = SQUIN_STIM_CONTROL_GATE_MAPPING.get(type(ctrl_op_target_gate))
160
- if stim_gate is None:
152
+ if isinstance(stmts[0], gate.stmts.X):
153
+ stim_gate = stim_CX
154
+ elif isinstance(stmts[0], gate.stmts.Y):
155
+ stim_gate = stim_CY
156
+ elif isinstance(stmts[0], gate.stmts.Z):
157
+ stim_gate = stim_CZ
158
+ else:
161
159
  return RewriteResult()
162
160
 
163
161
  # get necessary measurement ID type from analysis
164
- measure_id_bool = self.measure_analysis[stmt.cond]
162
+ measure_id_bool = self.measure_frame.entries[stmt.cond]
165
163
  assert isinstance(measure_id_bool, MeasureIdBool)
166
164
 
167
165
  # generate get record statement
168
166
  measure_id_idx_stmt = py.Constant(
169
- (measure_id_bool.idx - 1) - self.measure_count
167
+ (measure_id_bool.idx - 1) - self.measure_frame.num_measures_at_stmt[stmt]
170
168
  )
171
169
  get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841
172
170
 
173
- # get address attribute and generate qubit idx statements
174
- address_attr = apply_or_broadcast.qubits.hints.get("address")
171
+ address_attr = stmts[0].qubits.hints.get("address")
172
+
175
173
  if address_attr is None:
176
174
  return RewriteResult()
177
175
  assert isinstance(address_attr, AddressAttribute)
@@ -1,13 +1,11 @@
1
1
  from kirin import ir
2
2
  from kirin.rewrite.abc import RewriteRule, RewriteResult
3
3
 
4
- from bloqade.squin import op, noise, qubit
4
+ from bloqade import qubit
5
+ from bloqade.squin import gate
5
6
  from bloqade.squin.rewrite import AddressAttribute
6
- from bloqade.stim.dialects import gate
7
+ from bloqade.stim.dialects import gate as stim_gate, collapse as stim_collapse
7
8
  from bloqade.stim.rewrite.util import (
8
- SQUIN_STIM_OP_MAPPING,
9
- rewrite_Control,
10
- rewrite_QubitLoss,
11
9
  insert_qubit_idx_from_address,
12
10
  )
13
11
 
@@ -20,64 +18,115 @@ class SquinQubitToStim(RewriteRule):
20
18
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
21
19
 
22
20
  match node:
23
- case qubit.Apply() | qubit.Broadcast():
24
- return self.rewrite_Apply_and_Broadcast(node)
21
+ # not supported by Stim
22
+ case gate.stmts.T() | gate.stmts.RotationGate():
23
+ return RewriteResult()
24
+ # If you've reached this point all gates have stim equivalents
25
+ case qubit.stmts.Reset():
26
+ return self.rewrite_Reset(node)
27
+ case gate.stmts.SingleQubitGate():
28
+ return self.rewrite_SingleQubitGate(node)
29
+ case gate.stmts.ControlledGate():
30
+ return self.rewrite_ControlledGate(node)
25
31
  case _:
26
32
  return RewriteResult()
27
33
 
28
- def rewrite_Apply_and_Broadcast(
29
- self, stmt: qubit.Apply | qubit.Broadcast
30
- ) -> RewriteResult:
31
- """
32
- Rewrite Apply and Broadcast nodes to their stim equivalent statements.
33
- """
34
-
35
- # this is an SSAValue, need it to be the actual operator
36
- applied_op = stmt.operator.owner
34
+ def rewrite_Reset(self, stmt: qubit.stmts.Reset) -> RewriteResult:
37
35
 
38
- if isinstance(applied_op, noise.stmts.QubitLoss):
39
- return rewrite_QubitLoss(stmt)
36
+ qubit_addr_attr = stmt.qubits.hints.get("address", None)
40
37
 
41
- assert isinstance(applied_op, op.stmts.Operator)
38
+ if qubit_addr_attr is None:
39
+ return RewriteResult()
42
40
 
43
- if isinstance(applied_op, op.stmts.Control):
44
- return rewrite_Control(stmt)
41
+ assert isinstance(qubit_addr_attr, AddressAttribute)
45
42
 
46
- # need to handle Control through separate means
43
+ qubit_idx_ssas = insert_qubit_idx_from_address(
44
+ address=qubit_addr_attr, stmt_to_insert_before=stmt
45
+ )
47
46
 
48
- # check if its adjoint, assume its canonicalized so no nested adjoints.
49
- is_conj = False
50
- if isinstance(applied_op, op.stmts.Adjoint):
51
- if not applied_op.is_unitary:
52
- return RewriteResult()
47
+ if qubit_idx_ssas is None:
48
+ return RewriteResult()
53
49
 
54
- is_conj = True
55
- applied_op = applied_op.op.owner
50
+ stim_stmt = stim_collapse.RZ(targets=tuple(qubit_idx_ssas))
51
+ stmt.replace_by(stim_stmt)
56
52
 
57
- stim_1q_op = SQUIN_STIM_OP_MAPPING.get(type(applied_op))
58
- if stim_1q_op is None:
59
- return RewriteResult()
53
+ return RewriteResult(has_done_something=True)
60
54
 
61
- address_attr = stmt.qubits.hints.get("address")
55
+ def rewrite_SingleQubitGate(
56
+ self, stmt: gate.stmts.SingleQubitGate
57
+ ) -> RewriteResult:
58
+ """
59
+ Rewrite single qubit gate nodes to their stim equivalent statements.
60
+ Address Analysis should have been run along with Wrap Analysis before this rewrite is applied.
61
+ """
62
62
 
63
- if address_attr is None:
63
+ qubit_addr_attr = stmt.qubits.hints.get("address", None)
64
+ if qubit_addr_attr is None:
64
65
  return RewriteResult()
65
66
 
66
- assert isinstance(address_attr, AddressAttribute)
67
+ assert isinstance(qubit_addr_attr, AddressAttribute)
68
+
67
69
  qubit_idx_ssas = insert_qubit_idx_from_address(
68
- address=address_attr, stmt_to_insert_before=stmt
70
+ address=qubit_addr_attr, stmt_to_insert_before=stmt
69
71
  )
70
72
 
71
73
  if qubit_idx_ssas is None:
72
74
  return RewriteResult()
73
75
 
74
- if isinstance(stim_1q_op, gate.stmts.Gate):
75
- stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas), dagger=is_conj)
76
+ # Get the name of the inputted stmt and see if there is an
77
+ # equivalently named statement in stim,
78
+ # then create an instance of that stim statement
79
+ stmt_name = type(stmt).__name__
80
+ stim_stmt_cls = getattr(stim_gate.stmts, stmt_name, None)
81
+ if stim_stmt_cls is None:
82
+ return RewriteResult()
83
+
84
+ if isinstance(stmt, gate.stmts.SingleQubitNonHermitianGate):
85
+ stim_stmt = stim_stmt_cls(
86
+ targets=tuple(qubit_idx_ssas), dagger=stmt.adjoint
87
+ )
76
88
  else:
77
- stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
78
- stmt.replace_by(stim_1q_stmt)
89
+ stim_stmt = stim_stmt_cls(targets=tuple(qubit_idx_ssas))
90
+ stmt.replace_by(stim_stmt)
79
91
 
80
92
  return RewriteResult(has_done_something=True)
81
93
 
94
+ def rewrite_ControlledGate(self, stmt: gate.stmts.ControlledGate) -> RewriteResult:
95
+ """
96
+ Rewrite controlled gate nodes to their stim equivalent statements.
97
+ Address Analysis should have been run along with Wrap Analysis before this rewrite is applied.
98
+ """
82
99
 
83
- # put rewrites for measure statements in separate rule, then just have to dispatch
100
+ controls_addr_attr = stmt.controls.hints.get("address", None)
101
+ targets_addr_attr = stmt.targets.hints.get("address", None)
102
+
103
+ if controls_addr_attr is None or targets_addr_attr is None:
104
+ return RewriteResult()
105
+
106
+ assert isinstance(controls_addr_attr, AddressAttribute)
107
+ assert isinstance(targets_addr_attr, AddressAttribute)
108
+
109
+ controls_idx_ssas = insert_qubit_idx_from_address(
110
+ address=controls_addr_attr, stmt_to_insert_before=stmt
111
+ )
112
+ targets_idx_ssas = insert_qubit_idx_from_address(
113
+ address=targets_addr_attr, stmt_to_insert_before=stmt
114
+ )
115
+
116
+ if controls_idx_ssas is None or targets_idx_ssas is None:
117
+ return RewriteResult()
118
+
119
+ # Get the name of the inputted stmt and see if there is an
120
+ # equivalently named statement in stim,
121
+ # then create an instance of that stim statement
122
+ stmt_name = type(stmt).__name__
123
+ stim_stmt_cls = getattr(stim_gate.stmts, stmt_name, None)
124
+ if stim_stmt_cls is None:
125
+ return RewriteResult()
126
+
127
+ stim_stmt = stim_stmt_cls(
128
+ targets=tuple(targets_idx_ssas), controls=tuple(controls_idx_ssas)
129
+ )
130
+ stmt.replace_by(stim_stmt)
131
+
132
+ return RewriteResult(has_done_something=True)