bloqade-circuit 0.7.12__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (136) hide show
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +119 -29
  3. bloqade/analysis/address/impls.py +290 -87
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +2 -2
  6. bloqade/analysis/measure_id/impls.py +3 -27
  7. bloqade/cirq_utils/__init__.py +3 -1
  8. bloqade/cirq_utils/emit/__init__.py +3 -0
  9. bloqade/cirq_utils/emit/base.py +243 -0
  10. bloqade/cirq_utils/emit/gate.py +104 -0
  11. bloqade/cirq_utils/emit/noise.py +90 -0
  12. bloqade/cirq_utils/emit/qubit.py +35 -0
  13. bloqade/cirq_utils/lowering.py +664 -0
  14. bloqade/native/__init__.py +0 -1
  15. bloqade/native/_prelude.py +3 -3
  16. bloqade/native/dialects/gate/__init__.py +2 -0
  17. bloqade/native/dialects/gate/_dialect.py +3 -0
  18. bloqade/native/dialects/{gates → gate}/_interface.py +5 -5
  19. bloqade/native/dialects/{gates → gate}/stmts.py +5 -5
  20. bloqade/native/stdlib/broadcast.py +19 -19
  21. bloqade/native/stdlib/simple.py +14 -13
  22. bloqade/native/upstream/__init__.py +5 -0
  23. bloqade/native/upstream/squin2native.py +136 -0
  24. bloqade/pyqrack/__init__.py +1 -2
  25. bloqade/pyqrack/device.py +6 -17
  26. bloqade/pyqrack/native.py +17 -17
  27. bloqade/pyqrack/reg.py +1 -6
  28. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  29. bloqade/pyqrack/squin/gate/gate.py +136 -0
  30. bloqade/pyqrack/squin/noise/native.py +120 -54
  31. bloqade/pyqrack/squin/qubit.py +25 -41
  32. bloqade/pyqrack/target.py +2 -2
  33. bloqade/qasm2/dialects/core/address.py +21 -12
  34. bloqade/qasm2/dialects/noise/fidelity.py +2 -6
  35. bloqade/qasm2/dialects/noise/model.py +2 -1
  36. bloqade/qasm2/passes/parallel.py +3 -1
  37. bloqade/qasm2/rewrite/__init__.py +0 -1
  38. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  39. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  40. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  41. bloqade/qubit/__init__.py +12 -0
  42. bloqade/qubit/_dialect.py +3 -0
  43. bloqade/qubit/_interface.py +49 -0
  44. bloqade/qubit/_prelude.py +45 -0
  45. bloqade/qubit/analysis/__init__.py +1 -0
  46. bloqade/qubit/analysis/address_impl.py +40 -0
  47. bloqade/qubit/stdlib/__init__.py +2 -0
  48. bloqade/qubit/stdlib/_new.py +34 -0
  49. bloqade/qubit/stdlib/broadcast.py +62 -0
  50. bloqade/qubit/stdlib/simple.py +59 -0
  51. bloqade/qubit/stmts.py +60 -0
  52. bloqade/rewrite/passes/aggressive_unroll.py +2 -1
  53. bloqade/squin/__init__.py +44 -17
  54. bloqade/squin/analysis/__init__.py +0 -1
  55. bloqade/squin/analysis/schedule.py +2 -2
  56. bloqade/squin/gate/__init__.py +2 -0
  57. bloqade/squin/gate/_dialect.py +3 -0
  58. bloqade/squin/gate/_interface.py +98 -0
  59. bloqade/squin/gate/stmts.py +119 -0
  60. bloqade/squin/groups.py +4 -21
  61. bloqade/squin/noise/__init__.py +1 -9
  62. bloqade/squin/noise/_dialect.py +1 -1
  63. bloqade/squin/noise/_interface.py +45 -0
  64. bloqade/squin/noise/stmts.py +65 -29
  65. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  66. bloqade/squin/rewrite/__init__.py +0 -2
  67. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  68. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  69. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  70. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  71. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  72. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  73. bloqade/squin/stdlib/simple/__init__.py +33 -0
  74. bloqade/squin/stdlib/simple/gate.py +242 -0
  75. bloqade/squin/stdlib/simple/noise.py +126 -0
  76. bloqade/stim/__init__.py +1 -0
  77. bloqade/stim/_wrappers.py +6 -0
  78. bloqade/stim/dialects/noise/emit.py +6 -1
  79. bloqade/stim/dialects/noise/stmts.py +5 -3
  80. bloqade/stim/emit/stim_str.py +2 -0
  81. bloqade/stim/parse/lowering.py +12 -17
  82. bloqade/stim/passes/__init__.py +0 -1
  83. bloqade/stim/passes/flatten.py +26 -0
  84. bloqade/stim/passes/simplify_ifs.py +6 -1
  85. bloqade/stim/passes/squin_to_stim.py +4 -70
  86. bloqade/stim/rewrite/__init__.py +0 -4
  87. bloqade/stim/rewrite/ifs_to_stim.py +23 -29
  88. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  89. bloqade/stim/rewrite/squin_measure.py +9 -18
  90. bloqade/stim/rewrite/squin_noise.py +132 -108
  91. bloqade/stim/rewrite/util.py +5 -204
  92. bloqade/types.py +10 -0
  93. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/METADATA +2 -2
  94. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/RECORD +96 -100
  95. bloqade/native/dialects/gates/__init__.py +0 -3
  96. bloqade/native/dialects/gates/_dialect.py +0 -3
  97. bloqade/pyqrack/squin/op.py +0 -180
  98. bloqade/pyqrack/squin/runtime.py +0 -543
  99. bloqade/pyqrack/squin/wire.py +0 -51
  100. bloqade/squin/_typeinfer.py +0 -20
  101. bloqade/squin/analysis/address_impl.py +0 -71
  102. bloqade/squin/analysis/nsites/__init__.py +0 -9
  103. bloqade/squin/analysis/nsites/analysis.py +0 -50
  104. bloqade/squin/analysis/nsites/impls.py +0 -99
  105. bloqade/squin/analysis/nsites/lattice.py +0 -49
  106. bloqade/squin/cirq/__init__.py +0 -306
  107. bloqade/squin/cirq/emit/emit_circuit.py +0 -129
  108. bloqade/squin/cirq/emit/noise.py +0 -49
  109. bloqade/squin/cirq/emit/op.py +0 -176
  110. bloqade/squin/cirq/emit/qubit.py +0 -58
  111. bloqade/squin/cirq/emit/runtime.py +0 -242
  112. bloqade/squin/cirq/lowering.py +0 -439
  113. bloqade/squin/lowering.py +0 -80
  114. bloqade/squin/noise/_wrapper.py +0 -36
  115. bloqade/squin/noise/rewrite.py +0 -129
  116. bloqade/squin/op/__init__.py +0 -41
  117. bloqade/squin/op/_dialect.py +0 -3
  118. bloqade/squin/op/_wrapper.py +0 -121
  119. bloqade/squin/op/number.py +0 -5
  120. bloqade/squin/op/rewrite.py +0 -46
  121. bloqade/squin/op/stdlib.py +0 -62
  122. bloqade/squin/op/stmts.py +0 -300
  123. bloqade/squin/op/traits.py +0 -43
  124. bloqade/squin/op/types.py +0 -128
  125. bloqade/squin/parallel.py +0 -200
  126. bloqade/squin/qubit.py +0 -194
  127. bloqade/squin/rewrite/canonicalize.py +0 -60
  128. bloqade/squin/rewrite/desugar.py +0 -102
  129. bloqade/squin/stdlib/channel.py +0 -86
  130. bloqade/squin/stdlib/gate.py +0 -201
  131. bloqade/squin/types.py +0 -8
  132. bloqade/squin/wire.py +0 -201
  133. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  134. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  135. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/WHEEL +0 -0
  136. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,126 @@
1
+ from typing import Any, Literal, TypeVar
2
+
3
+ from kirin.dialects import ilist
4
+
5
+ from bloqade.types import Qubit
6
+
7
+ from .. import broadcast
8
+ from ...groups import kernel
9
+
10
+
11
+ @kernel
12
+ def depolarize(p: float, qubit: Qubit) -> None:
13
+ """
14
+ Apply a depolarizing noise channel to a qubit with probability `p`.
15
+
16
+ This will randomly select one of the Pauli operators X, Y, Z
17
+ with a probability `p / 3` and apply it to the qubit. No operator is applied
18
+ with a probability of `1 - p`.
19
+
20
+ Args:
21
+ p (float): The probability with which a Pauli operator is applied.
22
+ qubit (Qubit): The qubit to which the noise channel is applied.
23
+ """
24
+ broadcast.depolarize(p, ilist.IList([qubit]))
25
+
26
+
27
+ N = TypeVar("N", bound=int)
28
+
29
+
30
+ @kernel
31
+ def depolarize2(p: float, control: Qubit, target: Qubit) -> None:
32
+ """
33
+ Symmetric two-qubit depolarization channel applied to a pair of qubits.
34
+
35
+ This will randomly select one of the pauli products
36
+
37
+ `{IX, IY, IZ, XI, XX, XY, XZ, YI, YX, YY, YZ, ZI, ZX, ZY, ZZ}`
38
+
39
+ each with a probability `p / 15`. No noise is applied with a probability of `1 - p`.
40
+
41
+ Args:
42
+ p (float): The probability with which a Pauli product is applied.
43
+ control (Qubit): The control qubit.
44
+ target (Qubit): The target qubit.
45
+ """
46
+ broadcast.depolarize2(p, ilist.IList([control]), ilist.IList([target]))
47
+
48
+
49
+ @kernel
50
+ def single_qubit_pauli_channel(px: float, py: float, pz: float, qubit: Qubit) -> None:
51
+ """
52
+ Apply a Pauli error channel with the probabilities `px, py, pz`. No error is applied with a probability
53
+ `1 - (px + py + pz)`.
54
+
55
+ This randomly selects one of the three Pauli operators X, Y, Z, weighted with the given probabilities in that order.
56
+
57
+ Args:
58
+ px, py, pz (float): The probabilities with which the Pauli operators X, Y and Z are applied, in that order.
59
+ qubit (Qubit): The qubit to which the noise channel is applied.
60
+ """
61
+ broadcast.single_qubit_pauli_channel(px, py, pz, ilist.IList([qubit]))
62
+
63
+
64
+ @kernel
65
+ def two_qubit_pauli_channel(
66
+ probabilities: ilist.IList[float, Literal[15]], control: Qubit, target: Qubit
67
+ ) -> None:
68
+ """
69
+ Apply a Pauli product error with weighted `probabilities` to the pair of qubits.
70
+
71
+ No error is applied with the probability `1 - sum(probabilities)`.
72
+
73
+ This will randomly select one of the pauli products
74
+
75
+ `{IX, IY, IZ, XI, XX, XY, XZ, YI, YX, YY, YZ, ZI, ZX, ZY, ZZ}`
76
+
77
+ weighted with the corresponding list of probabilities.
78
+
79
+ **NOTE**: The order of the given probabilities must match the order of the list of Pauli products above!
80
+ """
81
+ broadcast.two_qubit_pauli_channel(
82
+ probabilities, ilist.IList([control]), ilist.IList([target])
83
+ )
84
+
85
+
86
+ @kernel
87
+ def qubit_loss(p: float, qubit: Qubit) -> None:
88
+ """
89
+ Apply a qubit loss channel to the given qubit.
90
+
91
+ The qubit is lost with a probability `p`.
92
+
93
+ Args:
94
+ p (float): Probability of the atom being lost.
95
+ qubit (Qubit): The qubit to which the noise channel is applied.
96
+ """
97
+ broadcast.qubit_loss(p, ilist.IList([qubit]))
98
+
99
+
100
+ @kernel
101
+ def correlated_qubit_loss(p: float, qubits: ilist.IList[Qubit, Any]) -> None:
102
+ """
103
+ Apply a correlated qubit loss channel to the given qubits.
104
+
105
+ All qubits are lost together with a probability `p`.
106
+
107
+ Args:
108
+ p (float): Probability of the qubits being lost.
109
+ qubits (IList[Qubit, Any]): The list of qubits to which the correlated noise channel is applied.
110
+ """
111
+ broadcast.correlated_qubit_loss(p, ilist.IList([qubits]))
112
+
113
+
114
+ # NOTE: actual stdlib that doesn't wrap statements starts here
115
+
116
+
117
+ @kernel
118
+ def bit_flip(p: float, qubit: Qubit) -> None:
119
+ """
120
+ Apply a bit flip error channel to the qubit with probability `p`.
121
+
122
+ Args:
123
+ p (float): Probability of a bit flip error being applied.
124
+ qubit (Qubit): The qubit to which the noise channel is applied.
125
+ """
126
+ single_qubit_pauli_channel(p, 0, 0, qubit)
bloqade/stim/__init__.py CHANGED
@@ -39,4 +39,5 @@ from ._wrappers import (
39
39
  pauli_channel1 as pauli_channel1,
40
40
  pauli_channel2 as pauli_channel2,
41
41
  observable_include as observable_include,
42
+ correlated_qubit_loss as correlated_qubit_loss,
42
43
  )
bloqade/stim/_wrappers.py CHANGED
@@ -194,3 +194,9 @@ def z_error(p: float, targets: tuple[int, ...]) -> None: ...
194
194
 
195
195
  @wraps(noise.QubitLoss)
196
196
  def qubit_loss(probs: tuple[float, ...], targets: tuple[int, ...]) -> None: ...
197
+
198
+
199
+ @wraps(noise.CorrelatedQubitLoss)
200
+ def correlated_qubit_loss(
201
+ probs: tuple[float, ...], targets: tuple[int, ...]
202
+ ) -> None: ...
@@ -81,6 +81,7 @@ class EmitStimNoiseMethods(MethodTable):
81
81
  return ()
82
82
 
83
83
  @impl(stmts.TrivialCorrelatedError)
84
+ @impl(stmts.CorrelatedQubitLoss)
84
85
  def non_stim_corr_error(
85
86
  self,
86
87
  emit: EmitStimMain,
@@ -92,7 +93,11 @@ class EmitStimNoiseMethods(MethodTable):
92
93
  prob: tuple[str, ...] = frame.get_values(stmt.probs)
93
94
  prob_str: str = ", ".join(prob)
94
95
 
95
- res = f"I_ERROR[{stmt.name}:{stmt.nonce}]({prob_str}) " + " ".join(targets)
96
+ res = (
97
+ f"I_ERROR[{stmt.name}:{emit.correlated_error_count}]({prob_str}) "
98
+ + " ".join(targets)
99
+ )
100
+ emit.correlated_error_count += 1
96
101
  emit.writeln(frame, res)
97
102
 
98
103
  return ()
@@ -89,9 +89,6 @@ class NonStimError(ir.Statement):
89
89
  class NonStimCorrelatedError(ir.Statement):
90
90
  name = "NonStimCorrelatedError"
91
91
  traits = frozenset({lowering.FromPythonCall()})
92
- nonce: int = (
93
- info.attribute()
94
- ) # Must be a unique value, otherwise stim might merge two correlated errors with equal probabilities
95
92
  probs: tuple[ir.SSAValue, ...] = info.argument(types.Float)
96
93
  targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
97
94
 
@@ -109,3 +106,8 @@ class TrivialError(NonStimError):
109
106
  @statement(dialect=dialect)
110
107
  class QubitLoss(NonStimError):
111
108
  name = "loss"
109
+
110
+
111
+ @statement(dialect=dialect)
112
+ class CorrelatedQubitLoss(NonStimCorrelatedError):
113
+ name = "correlated_loss"
@@ -20,11 +20,13 @@ class EmitStimMain(EmitStr):
20
20
  keys = ["emit.stim"]
21
21
  dialects: ir.DialectGroup = field(default_factory=_default_dialect_group)
22
22
  file: StringIO = field(default_factory=StringIO)
23
+ correlation_identifier_offset: int = 0
23
24
 
24
25
  def initialize(self):
25
26
  super().initialize()
26
27
  self.file.truncate(0)
27
28
  self.file.seek(0)
29
+ self.correlated_error_count = self.correlation_identifier_offset
28
30
  return self
29
31
 
30
32
  def eval_stmt_fallback(
@@ -627,10 +627,13 @@ class Stim(lowering.LoweringABC[Node]):
627
627
  # Parse tag
628
628
  tag_parts = node.tag.split(";", maxsplit=1)[0].split(":", maxsplit=1)
629
629
  nonstim_name = tag_parts[0]
630
- nonce = 0
631
630
  if len(tag_parts) == 2:
631
+ # This should be a correlated error of the form, e.g.,
632
+ # I_ERROR[correlated_loss:<identifier>](0.01) 0 1 2
633
+ # The identifier is a unique number that prevents stim from merging
634
+ # correlated errors. We discard the identifier, but verify it is an integer.
632
635
  try:
633
- nonce = int(tag_parts[1])
636
+ _ = int(tag_parts[1])
634
637
  except ValueError:
635
638
  # String was not an integer
636
639
  if self.error_unknown_nonstim:
@@ -643,22 +646,14 @@ class Stim(lowering.LoweringABC[Node]):
643
646
  f"Unknown non-stim statement name: {nonstim_name!r} ({node!r})"
644
647
  )
645
648
  statement_cls = self.nonstim_noise_ops.get(nonstim_name)
649
+ stmt = None
646
650
  if statement_cls is not None:
647
- if issubclass(statement_cls, noise.NonStimCorrelatedError):
648
- stmt = statement_cls(
649
- nonce=nonce,
650
- probs=self._get_float_args_ssa(state, node.gate_args_copy()),
651
- targets=self._get_multiple_qubit_or_rec_ssa(
652
- state, node, node.targets_copy()
653
- ),
654
- )
655
- else:
656
- stmt = statement_cls(
657
- probs=self._get_float_args_ssa(state, node.gate_args_copy()),
658
- targets=self._get_multiple_qubit_or_rec_ssa(
659
- state, node, node.targets_copy()
660
- ),
661
- )
651
+ stmt = statement_cls(
652
+ probs=self._get_float_args_ssa(state, node.gate_args_copy()),
653
+ targets=self._get_multiple_qubit_or_rec_ssa(
654
+ state, node, node.targets_copy()
655
+ ),
656
+ )
662
657
  return stmt
663
658
 
664
659
  def visit_CircuitInstruction(
@@ -1,4 +1,3 @@
1
1
  from .squin_to_stim import (
2
2
  SquinToStimPass as SquinToStimPass,
3
- StimSimplifyIfs as StimSimplifyIfs,
4
3
  )
@@ -0,0 +1,26 @@
1
+ # Taken from Phillip Weinberg's bloqade-shuttle implementation
2
+ from dataclasses import field, dataclass
3
+
4
+ from kirin import ir
5
+ from kirin.passes import Pass
6
+ from kirin.rewrite.abc import RewriteResult
7
+
8
+ from bloqade.qasm2.passes.fold import AggressiveUnroll
9
+ from bloqade.stim.passes.simplify_ifs import StimSimplifyIfs
10
+
11
+
12
+ @dataclass
13
+ class Flatten(Pass):
14
+
15
+ unroll: AggressiveUnroll = field(init=False)
16
+ simplify_if: StimSimplifyIfs = field(init=False)
17
+
18
+ def __post_init__(self):
19
+ self.unroll = AggressiveUnroll(self.dialects, no_raise=self.no_raise)
20
+ self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise)
21
+
22
+ def unsafe_run(self, mt: ir.Method) -> RewriteResult:
23
+ rewrite_result = RewriteResult()
24
+ rewrite_result = self.simplify_if(mt).join(rewrite_result)
25
+ rewrite_result = self.unroll(mt).join(rewrite_result)
26
+ return rewrite_result
@@ -7,8 +7,10 @@ from kirin.rewrite import (
7
7
  Chain,
8
8
  Fixpoint,
9
9
  ConstantFold,
10
+ DeadCodeElimination,
10
11
  CommonSubexpressionElimination,
11
12
  )
13
+ from kirin.dialects.scf.trim import UnusedYield
12
14
  from kirin.dialects.ilist.passes import ConstList2IList
13
15
 
14
16
  from ..rewrite.ifs_to_stim import StimLiftThenBody, StimSplitIfStmts
@@ -20,7 +22,10 @@ class StimSimplifyIfs(Pass):
20
22
  def unsafe_run(self, mt: ir.Method):
21
23
 
22
24
  result = Chain(
23
- Fixpoint(Walk(StimLiftThenBody())),
25
+ Walk(UnusedYield()),
26
+ Walk(StimLiftThenBody()),
27
+ # remove yields (if possible), then lift out as much stuff as possible
28
+ Walk(DeadCodeElimination()),
24
29
  Walk(StimSplitIfStmts()),
25
30
  ).rewrite(mt.code)
26
31
 
@@ -1,29 +1,21 @@
1
1
  from dataclasses import dataclass
2
2
 
3
- from kirin.passes import Fold, TypeInfer
4
3
  from kirin.rewrite import (
5
4
  Walk,
6
5
  Chain,
7
6
  Fixpoint,
8
- CFGCompactify,
9
7
  DeadCodeElimination,
10
8
  CommonSubexpressionElimination,
11
9
  )
12
- from kirin.dialects import ilist
13
10
  from kirin.ir.method import Method
14
11
  from kirin.passes.abc import Pass
15
12
  from kirin.rewrite.abc import RewriteResult
16
- from kirin.passes.inline import InlinePass
17
- from kirin.rewrite.alias import InlineAlias
18
- from kirin.passes.aggressive import UnrollScf
19
13
 
20
14
  from bloqade.stim.rewrite import (
21
- SquinWireToStim,
22
15
  PyConstantToStim,
23
16
  SquinNoiseToStim,
24
17
  SquinQubitToStim,
25
18
  SquinMeasureToStim,
26
- SquinWireIdentityElimination,
27
19
  )
28
20
  from bloqade.squin.rewrite import (
29
21
  SquinU3ToClifford,
@@ -33,9 +25,8 @@ from bloqade.squin.rewrite import (
33
25
  from bloqade.rewrite.passes import CanonicalizeIList
34
26
  from bloqade.analysis.address import AddressAnalysis
35
27
  from bloqade.analysis.measure_id import MeasurementIDAnalysis
36
- from bloqade.squin.rewrite.desugar import ApplyDesugarRule, MeasureDesugarRule
28
+ from bloqade.stim.passes.flatten import Flatten
37
29
 
38
- from .simplify_ifs import StimSimplifyIfs
39
30
  from ..rewrite.ifs_to_stim import IfToStim
40
31
 
41
32
 
@@ -45,63 +36,8 @@ class SquinToStimPass(Pass):
45
36
  def unsafe_run(self, mt: Method) -> RewriteResult:
46
37
 
47
38
  # inline aggressively:
48
- rewrite_result = InlinePass(
49
- dialects=mt.dialects, no_raise=self.no_raise
50
- ).unsafe_run(mt)
51
-
52
- rewrite_result = Walk(ilist.rewrite.HintLen()).rewrite(mt.code)
53
- rewrite_result = Fold(self.dialects).unsafe_run(mt).join(rewrite_result)
54
-
55
- rewrite_result = (
56
- UnrollScf(dialects=mt.dialects, no_raise=self.no_raise)
57
- .fixpoint(mt)
58
- .join(rewrite_result)
59
- )
60
-
61
- rewrite_result = (
62
- Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(rewrite_result)
63
- )
64
-
65
- rewrite_result = Walk(InlineAlias()).rewrite(mt.code).join(rewrite_result)
66
-
67
- rewrite_result = (
68
- StimSimplifyIfs(mt.dialects, no_raise=self.no_raise)
69
- .unsafe_run(mt)
70
- .join(rewrite_result)
71
- )
72
-
73
- rewrite_result = (
74
- Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
75
- .rewrite(mt.code)
76
- .join(rewrite_result)
77
- )
78
- rewrite_result = Fold(mt.dialects, no_raise=self.no_raise)(mt)
79
-
80
- rewrite_result = (
81
- UnrollScf(mt.dialects, no_raise=self.no_raise)
82
- .fixpoint(mt)
83
- .join(rewrite_result)
84
- )
85
-
86
- rewrite_result = (
87
- CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
88
- .unsafe_run(mt)
89
- .join(rewrite_result)
90
- )
91
-
92
- rewrite_result = TypeInfer(
93
- dialects=mt.dialects, no_raise=self.no_raise
94
- ).unsafe_run(mt)
95
-
96
- rewrite_result = (
97
- Walk(
98
- Chain(
99
- ApplyDesugarRule(),
100
- MeasureDesugarRule(),
101
- )
102
- )
103
- .rewrite(mt.code)
104
- .join(rewrite_result)
39
+ rewrite_result = Flatten(dialects=mt.dialects, no_raise=self.no_raise).fixpoint(
40
+ mt
105
41
  )
106
42
 
107
43
  # after this the program should be in a state where it is analyzable
@@ -145,8 +81,6 @@ class SquinToStimPass(Pass):
145
81
  Chain(
146
82
  SquinQubitToStim(),
147
83
  SquinMeasureToStim(),
148
- SquinWireToStim(),
149
- SquinWireIdentityElimination(),
150
84
  )
151
85
  )
152
86
  .rewrite(mt.code)
@@ -163,7 +97,7 @@ class SquinToStimPass(Pass):
163
97
  rewrite_result = Walk(PyConstantToStim()).rewrite(mt.code).join(rewrite_result)
164
98
 
165
99
  # clear up leftover stmts
166
- # - remove any squin.qubit.new that's left around
100
+ # - remove any squin.qalloc that's left around
167
101
  rewrite_result = (
168
102
  Fixpoint(
169
103
  Walk(
@@ -1,9 +1,5 @@
1
1
  from .ifs_to_stim import IfToStim as IfToStim
2
2
  from .squin_noise import SquinNoiseToStim as SquinNoiseToStim
3
- from .wire_to_stim import SquinWireToStim as SquinWireToStim
4
3
  from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
5
4
  from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
6
5
  from .py_constant_to_stim import PyConstantToStim as PyConstantToStim
7
- from .wire_identity_elimination import (
8
- SquinWireIdentityElimination as SquinWireIdentityElimination,
9
- )
@@ -4,13 +4,13 @@ from kirin import ir
4
4
  from kirin.dialects import py, scf, func
5
5
  from kirin.rewrite.abc import RewriteRule, RewriteResult
6
6
 
7
- from bloqade.squin import op, qubit
7
+ from bloqade.squin import gate
8
8
  from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts
9
9
  from bloqade.squin.rewrite import AddressAttribute
10
10
  from bloqade.stim.rewrite.util import (
11
- SQUIN_STIM_CONTROL_GATE_MAPPING,
12
11
  insert_qubit_idx_from_address,
13
12
  )
13
+ from bloqade.stim.dialects.gate import CX as stim_CX, CY as stim_CY, CZ as stim_CZ
14
14
  from bloqade.analysis.measure_id import MeasureIDFrame
15
15
  from bloqade.stim.dialects.auxiliary import GetRecord
16
16
  from bloqade.analysis.measure_id.lattice import (
@@ -58,8 +58,7 @@ class IfElseSimplification:
58
58
  """Check if the IfElse statement has an else body."""
59
59
  if stmt.else_body.blocks and not (
60
60
  len(stmt.else_body.blocks[0].stmts) == 1
61
- and isinstance(else_term := stmt.else_body.blocks[0].last_stmt, scf.Yield)
62
- and not else_term.values # empty yield
61
+ and isinstance(stmt.else_body.blocks[0].last_stmt, scf.Yield)
63
62
  ):
64
63
  return True
65
64
 
@@ -67,12 +66,13 @@ class IfElseSimplification:
67
66
 
68
67
 
69
68
  DontLiftType = (
70
- qubit.Apply,
71
- qubit.Broadcast,
72
- scf.Yield,
69
+ gate.stmts.SingleQubitGate,
70
+ gate.stmts.RotationGate,
71
+ gate.stmts.ControlledGate,
73
72
  func.Return,
74
73
  func.Invoke,
75
74
  scf.IfElse,
75
+ scf.Yield,
76
76
  )
77
77
 
78
78
 
@@ -99,16 +99,16 @@ class StimSplitIfStmts(IfElseSimplification, SplitIfStmts):
99
99
  Given an IfElse with multiple valid statements in the then-body:
100
100
 
101
101
  if measure_result:
102
- squin.qubit.apply(op.X, q0)
103
- squin.qubit.apply(op.Y, q1)
102
+ squin.x(q0)
103
+ squin.y(q1)
104
104
 
105
105
  this should be rewritten to:
106
106
 
107
107
  if measure_result:
108
- squin.qubit.apply(op.X, q0)
108
+ squin.x(q0)
109
109
 
110
110
  if measure_result:
111
- squin.qubit.apply(op.Y, q1)
111
+ squin.y(q1)
112
112
  """
113
113
 
114
114
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
@@ -139,24 +139,23 @@ class IfToStim(IfElseSimplification, RewriteRule):
139
139
 
140
140
  def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
141
141
 
142
+ # Check that the condition is a single MeasureIdBool
142
143
  if not isinstance(self.measure_frame.entries[stmt.cond], MeasureIdBool):
143
144
  return RewriteResult()
144
145
 
145
- # check that there is only qubit.Apply in the then-body,
146
- # if there's more than that, we can't do a valid rewrite.
147
- # Can reuse logic from SplitIf
146
+ # Reusing code from SplitIf,
147
+ # there should only be one statement in the body and it should be a Pauli X, Y, or Z
148
148
  *stmts, _ = stmt.then_body.stmts()
149
- if len(stmts) != 1 or not isinstance(stmts[0], (qubit.Apply, qubit.Broadcast)):
149
+ if len(stmts) != 1:
150
150
  return RewriteResult()
151
151
 
152
- apply_or_broadcast = stmts[0]
153
- # Check that the gate being applied/broadcasted can be converted to a stim
154
- # controlled gate.
155
- ctrl_op_target_gate = apply_or_broadcast.operator.owner
156
- assert isinstance(ctrl_op_target_gate, op.stmts.Operator)
157
-
158
- stim_gate = SQUIN_STIM_CONTROL_GATE_MAPPING.get(type(ctrl_op_target_gate))
159
- if stim_gate is None:
152
+ if isinstance(stmts[0], gate.stmts.X):
153
+ stim_gate = stim_CX
154
+ elif isinstance(stmts[0], gate.stmts.Y):
155
+ stim_gate = stim_CY
156
+ elif isinstance(stmts[0], gate.stmts.Z):
157
+ stim_gate = stim_CZ
158
+ else:
160
159
  return RewriteResult()
161
160
 
162
161
  # get necessary measurement ID type from analysis
@@ -169,12 +168,7 @@ class IfToStim(IfElseSimplification, RewriteRule):
169
168
  )
170
169
  get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841
171
170
 
172
- # get address attribute and generate qubit idx statements
173
- if len(apply_or_broadcast.qubits) != 1:
174
- # NOTE: this is actually invalid since we are dealing with single-qubit operators here
175
- return RewriteResult()
176
-
177
- address_attr = apply_or_broadcast.qubits[0].hints.get("address")
171
+ address_attr = stmts[0].qubits.hints.get("address")
178
172
 
179
173
  if address_attr is None:
180
174
  return RewriteResult()