bloqade-circuit 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic; see the registry page for this release for more details.

@@ -1,2 +1,5 @@
1
1
  from . import impls as impls
2
- from .analysis import MeasurementIDAnalysis as MeasurementIDAnalysis
2
+ from .analysis import (
3
+ MeasureIDFrame as MeasureIDFrame,
4
+ MeasurementIDAnalysis as MeasurementIDAnalysis,
5
+ )
@@ -1,13 +1,19 @@
1
1
  from typing import TypeVar
2
+ from dataclasses import field, dataclass
2
3
 
3
4
  from kirin import ir, interp
4
- from kirin.analysis import Forward, const
5
+ from kirin.analysis import ForwardExtra, const
5
6
  from kirin.analysis.forward import ForwardFrame
6
7
 
7
8
  from .lattice import MeasureId, NotMeasureId
8
9
 
9
10
 
10
- class MeasurementIDAnalysis(Forward[MeasureId]):
11
@dataclass
class MeasureIDFrame(ForwardFrame[MeasureId]):
    """Forward-analysis frame used by the measurement-ID analysis.

    On top of the per-SSA lattice ``entries`` inherited from ``ForwardFrame``,
    it keeps a per-statement record of the analysis' measurement counter.
    """

    # statement -> value of the analysis' ``measure_count`` observed when the
    # statement was visited (filled in for scf.IfElse by the scf method table)
    num_measures_at_stmt: dict[ir.Statement, int] = field(default_factory=dict)
14
+
15
+
16
+ class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]):
11
17
 
12
18
  keys = ["measure_id"]
13
19
  lattice = MeasureId
@@ -15,6 +21,11 @@ class MeasurementIDAnalysis(Forward[MeasureId]):
15
21
  # then use this to generate the negative values for target rec indices
16
22
  measure_count = 0
17
23
 
24
+ def initialize_frame(
25
+ self, code: ir.Statement, *, has_parent_access: bool = False
26
+ ) -> MeasureIDFrame:
27
+ return MeasureIDFrame(code, has_parent_access=has_parent_access)
28
+
18
29
  # Still default to bottom,
19
30
  # but let constants return the softer "NoMeasureId" type from impl
20
31
  def eval_stmt_fallback(
@@ -1,4 +1,5 @@
1
1
  from kirin import types as kirin_types, interp
2
+ from kirin.analysis import const
2
3
  from kirin.dialects import py, scf, func, ilist
3
4
 
4
5
  from bloqade.squin import wire, qubit
@@ -10,7 +11,7 @@ from .lattice import (
10
11
  MeasureIdTuple,
11
12
  InvalidMeasureId,
12
13
  )
13
- from .analysis import MeasurementIDAnalysis
14
+ from .analysis import MeasureIDFrame, MeasurementIDAnalysis
14
15
 
15
16
  ## Can't do wire right now because of
16
17
  ## unresolved RFC on return type
@@ -113,6 +114,15 @@ class PyIndexing(interp.MethodTable):
113
114
  return (InvalidMeasureId(),)
114
115
 
115
116
 
117
@py.assign.dialect.register(key="measure_id")
class PyAssign(interp.MethodTable):
    """Measurement-ID analysis rules for the ``py.assign`` dialect."""

    @interp.impl(py.Alias)
    def alias(
        self,
        interp_: MeasurementIDAnalysis,
        frame: interp.Frame,
        stmt: py.assign.Alias,
    ):
        # The interpreter parameter is named ``interp_`` (as in Scf.if_else) so
        # it does not shadow the ``interp`` module used in the annotations.
        # An alias only renames a value, so its lattice element is forwarded
        # unchanged.
        return (frame.get(stmt.value),)
124
+
125
+
116
126
  @py.binop.dialect.register(key="measure_id")
117
127
  class PyBinOp(interp.MethodTable):
118
128
  @interp.impl(py.Add)
@@ -152,4 +162,33 @@ class Func(interp.MethodTable):
152
162
  # scf, particularly IfElse
153
163
  @scf.dialect.register(key="measure_id")
154
164
  class Scf(scf.absint.Methods):
155
- pass
165
+
166
+ @interp.impl(scf.IfElse)
167
+ def if_else(
168
+ self,
169
+ interp_: MeasurementIDAnalysis,
170
+ frame: MeasureIDFrame,
171
+ stmt: scf.IfElse,
172
+ ):
173
+
174
+ frame.num_measures_at_stmt[stmt] = interp_.measure_count
175
+
176
+ # rest of the code taken directly from scf.absint.Methods base implementation
177
+
178
+ if isinstance(hint := stmt.cond.hints.get("const"), const.Value):
179
+ if hint.data:
180
+ return self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
181
+ else:
182
+ return self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)
183
+ then_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
184
+ else_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)
185
+
186
+ match (then_results, else_results):
187
+ case (interp.ReturnValue(then_value), interp.ReturnValue(else_value)):
188
+ return interp.ReturnValue(then_value.join(else_value))
189
+ case (interp.ReturnValue(then_value), _):
190
+ return then_results
191
+ case (_, interp.ReturnValue(else_value)):
192
+ return else_results
193
+ case _:
194
+ return interp_.join_results(then_results, else_results)
bloqade/pyqrack/base.py CHANGED
@@ -48,7 +48,7 @@ def _default_pyqrack_args() -> PyQrackOptions:
48
48
  isSchmidtDecomposeMulti=True,
49
49
  isSchmidtDecompose=True,
50
50
  isStabilizerHybrid=False,
51
- isBinaryDecisionTree=True,
51
+ isBinaryDecisionTree=False,
52
52
  isPaged=True,
53
53
  isCpuGpuHybrid=True,
54
54
  isOpenCL=True,
bloqade/squin/__init__.py CHANGED
@@ -9,6 +9,10 @@ from . import (
9
9
  )
10
10
  from .groups import wired as wired, kernel as kernel
11
11
 
12
+ # NOTE: it's important to keep these imports here since they import squin.kernel
13
+ # we skip isort here
14
+ from . import gate as gate, parallel as parallel # isort: skip
15
+
12
16
  try:
13
17
  # NOTE: make sure optional cirq dependency is installed
14
18
  import cirq as cirq_package # noqa: F401
bloqade/squin/gate.py ADDED
@@ -0,0 +1,193 @@
1
+ from bloqade.types import Qubit
2
+
3
+ from . import op as _op, qubit as _qubit
4
+ from .groups import kernel
5
+
6
+
7
@kernel
def x(qubit: Qubit) -> None:
    """Apply the X gate to ``qubit``."""
    _qubit.apply(_op.x(), qubit)
12
+
13
+
14
@kernel
def y(qubit: Qubit) -> None:
    """Apply the Y gate to ``qubit``."""
    _qubit.apply(_op.y(), qubit)
19
+
20
+
21
@kernel
def z(qubit: Qubit) -> None:
    """Apply the Z gate to ``qubit``."""
    _qubit.apply(_op.z(), qubit)
26
+
27
+
28
@kernel
def sqrt_x(qubit: Qubit) -> None:
    """Apply the square-root-of-X gate to ``qubit``."""
    _qubit.apply(_op.sqrt_x(), qubit)
33
+
34
+
35
@kernel
def sqrt_x_adj(qubit: Qubit) -> None:
    """Apply the adjoint of the square-root-of-X gate to ``qubit``."""
    adjoint_op = _op.adjoint(_op.sqrt_x())
    _qubit.apply(adjoint_op, qubit)
40
+
41
+
42
@kernel
def sqrt_y(qubit: Qubit) -> None:
    """Apply the square-root-of-Y gate to ``qubit``."""
    _qubit.apply(_op.sqrt_y(), qubit)
47
+
48
+
49
@kernel
def sqrt_y_adj(qubit: Qubit) -> None:
    """Apply the adjoint of the square-root-of-Y gate to ``qubit``."""
    adjoint_op = _op.adjoint(_op.sqrt_y())
    _qubit.apply(adjoint_op, qubit)
54
+
55
+
56
@kernel
def sqrt_z(qubit: Qubit) -> None:
    """Apply the square-root-of-Z gate (the S operator) to ``qubit``."""
    _qubit.apply(_op.s(), qubit)
61
+
62
+
63
@kernel
def sqrt_z_adj(qubit: Qubit) -> None:
    """Apply the adjoint of the square-root-of-Z (S) gate to ``qubit``."""
    adjoint_op = _op.adjoint(_op.s())
    _qubit.apply(adjoint_op, qubit)
68
+
69
+
70
@kernel
def h(qubit: Qubit) -> None:
    """Apply the Hadamard gate to ``qubit``."""
    _qubit.apply(_op.h(), qubit)
75
+
76
+
77
@kernel
def s(qubit: Qubit) -> None:
    """Apply the S gate to ``qubit``."""
    _qubit.apply(_op.s(), qubit)
82
+
83
+
84
@kernel
def s_adj(qubit: Qubit) -> None:
    """Apply the adjoint of the S gate to ``qubit``."""
    adjoint_op = _op.adjoint(_op.s())
    _qubit.apply(adjoint_op, qubit)
89
+
90
+
91
@kernel
def t(qubit: Qubit) -> None:
    """Apply the T gate to ``qubit``."""
    _qubit.apply(_op.t(), qubit)
96
+
97
+
98
@kernel
def t_adj(qubit: Qubit) -> None:
    """Apply the adjoint of the T gate to ``qubit``."""
    adjoint_op = _op.adjoint(_op.t())
    _qubit.apply(adjoint_op, qubit)
103
+
104
+
105
@kernel
def p0(qubit: Qubit) -> None:
    """Apply the projector onto the 0 state to ``qubit``."""
    _qubit.apply(_op.p0(), qubit)
110
+
111
+
112
@kernel
def p1(qubit: Qubit) -> None:
    """Apply the projector onto the 1 state to ``qubit``."""
    _qubit.apply(_op.p1(), qubit)
117
+
118
+
119
@kernel
def spin_n(qubit: Qubit) -> None:
    """Apply the spin-lowering operator to ``qubit``."""
    _qubit.apply(_op.spin_n(), qubit)
124
+
125
+
126
@kernel
def spin_p(qubit: Qubit) -> None:
    """Apply the spin-raising operator to ``qubit``."""
    _qubit.apply(_op.spin_p(), qubit)
131
+
132
+
133
@kernel
def reset(qubit: Qubit) -> None:
    """Reset ``qubit`` to the 0 state."""
    _qubit.apply(_op.reset(), qubit)
138
+
139
+
140
@kernel
def cx(control: Qubit, target: Qubit) -> None:
    """Apply a controlled-X gate from ``control`` onto ``target``."""
    _qubit.apply(_op.cx(), control, target)
145
+
146
+
147
@kernel
def cy(control: Qubit, target: Qubit) -> None:
    """Apply a controlled-Y gate from ``control`` onto ``target``."""
    _qubit.apply(_op.cy(), control, target)
152
+
153
+
154
@kernel
def cz(control: Qubit, target: Qubit) -> None:
    """Apply a controlled-Z gate from ``control`` onto ``target``."""
    _qubit.apply(_op.cz(), control, target)
159
+
160
+
161
@kernel
def ch(control: Qubit, target: Qubit) -> None:
    """Apply a controlled-Hadamard gate from ``control`` onto ``target``."""
    _qubit.apply(_op.ch(), control, target)
166
+
167
+
168
@kernel
def u(theta: float, phi: float, lam: float, qubit: Qubit) -> None:
    """3D rotation gate ``u(theta, phi, lam)`` applied to a single qubit."""
    op = _op.u(theta, phi, lam)
    _qubit.apply(op, qubit)
173
+
174
+
175
@kernel
def rx(theta: float, qubit: Qubit) -> None:
    """Apply a rotation about the X axis by ``theta`` to ``qubit``."""
    axis = _op.x()
    _qubit.apply(_op.rot(axis, theta), qubit)
180
+
181
+
182
@kernel
def ry(theta: float, qubit: Qubit) -> None:
    """Apply a rotation about the Y axis by ``theta`` to ``qubit``."""
    axis = _op.y()
    _qubit.apply(_op.rot(axis, theta), qubit)
187
+
188
+
189
@kernel
def rz(theta: float, qubit: Qubit) -> None:
    """Apply a rotation about the Z axis by ``theta`` to ``qubit``."""
    axis = _op.z()
    _qubit.apply(_op.rot(axis, theta), qubit)
@@ -0,0 +1,200 @@
1
+ from typing import Any, TypeVar
2
+
3
+ from kirin.dialects import ilist
4
+
5
+ from bloqade.types import Qubit
6
+
7
+ from . import op as _op, qubit as _qubit
8
+ from .groups import kernel
9
+
10
+
11
@kernel
def x(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the X gate over every qubit in ``qubits``."""
    _qubit.broadcast(_op.x(), qubits)
16
+
17
+
18
@kernel
def y(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the Y gate over every qubit in ``qubits``."""
    _qubit.broadcast(_op.y(), qubits)
23
+
24
+
25
@kernel
def z(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the Z gate over every qubit in ``qubits``."""
    _qubit.broadcast(_op.z(), qubits)
30
+
31
+
32
@kernel
def sqrt_x(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the square-root-of-X gate over every qubit in ``qubits``."""
    _qubit.broadcast(_op.sqrt_x(), qubits)
37
+
38
+
39
@kernel
def sqrt_y(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the square-root-of-Y gate over every qubit in ``qubits``."""
    _qubit.broadcast(_op.sqrt_y(), qubits)
44
+
45
+
46
@kernel
def sqrt_z(qubits: ilist.IList[Qubit, Any]) -> None:
    """Square root z gate (S) applied to qubits in parallel."""
    op = _op.s()
    _qubit.broadcast(op, qubits)
51
+
52
+
53
@kernel
def h(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the Hadamard gate over every qubit in ``qubits``."""
    _qubit.broadcast(_op.h(), qubits)
58
+
59
+
60
@kernel
def s(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the S gate over every qubit in ``qubits``."""
    _qubit.broadcast(_op.s(), qubits)
65
+
66
+
67
@kernel
def t(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the T gate over every qubit in ``qubits``."""
    _qubit.broadcast(_op.t(), qubits)
72
+
73
+
74
@kernel
def p0(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the projector onto the 0 state over every qubit in ``qubits``."""
    _qubit.broadcast(_op.p0(), qubits)
79
+
80
+
81
@kernel
def p1(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the projector onto the 1 state over every qubit in ``qubits``."""
    _qubit.broadcast(_op.p1(), qubits)
86
+
87
+
88
@kernel
def spin_n(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the spin-lowering operator over every qubit in ``qubits``."""
    _qubit.broadcast(_op.spin_n(), qubits)
93
+
94
+
95
@kernel
def spin_p(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the spin-raising operator over every qubit in ``qubits``."""
    _qubit.broadcast(_op.spin_p(), qubits)
100
+
101
+
102
@kernel
def reset(qubits: ilist.IList[Qubit, Any]) -> None:
    """Reset all qubits to 0 in parallel."""
    op = _op.reset()
    _qubit.broadcast(op, qubits)
107
+
108
+
109
# Type variable used as the second IList parameter of both ``controls`` and
# ``targets`` in the two-qubit parallel gates; presumably this ties the two
# lists to the same length at the type level — TODO confirm against IList.
N = TypeVar("N")
110
+
111
+
112
@kernel
def cx(controls: ilist.IList[Qubit, N], targets: ilist.IList[Qubit, N]) -> None:
    """Broadcast a controlled-X gate over paired ``controls`` and ``targets``."""
    _qubit.broadcast(_op.cx(), controls, targets)
117
+
118
+
119
@kernel
def cy(controls: ilist.IList[Qubit, N], targets: ilist.IList[Qubit, N]) -> None:
    """Broadcast a controlled-Y gate over paired ``controls`` and ``targets``."""
    _qubit.broadcast(_op.cy(), controls, targets)
124
+
125
+
126
@kernel
def cz(controls: ilist.IList[Qubit, N], targets: ilist.IList[Qubit, N]) -> None:
    """Broadcast a controlled-Z gate over paired ``controls`` and ``targets``."""
    _qubit.broadcast(_op.cz(), controls, targets)
131
+
132
+
133
@kernel
def ch(controls: ilist.IList[Qubit, N], targets: ilist.IList[Qubit, N]) -> None:
    """Broadcast a controlled-Hadamard gate over paired ``controls`` and ``targets``."""
    _qubit.broadcast(_op.ch(), controls, targets)
138
+
139
+
140
@kernel
def u(theta: float, phi: float, lam: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """3D rotation gate ``u(theta, phi, lam)`` applied to qubits in parallel."""
    op = _op.u(theta, phi, lam)
    _qubit.broadcast(op, qubits)
145
+
146
+
147
@kernel
def rx(theta: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast an X-axis rotation by ``theta`` over every qubit in ``qubits``."""
    axis = _op.x()
    _qubit.broadcast(_op.rot(axis, theta), qubits)
152
+
153
+
154
@kernel
def ry(theta: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast a Y-axis rotation by ``theta`` over every qubit in ``qubits``."""
    axis = _op.y()
    _qubit.broadcast(_op.rot(axis, theta), qubits)
159
+
160
+
161
@kernel
def rz(theta: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast a Z-axis rotation by ``theta`` over every qubit in ``qubits``."""
    axis = _op.z()
    _qubit.broadcast(_op.rot(axis, theta), qubits)
166
+
167
+
168
@kernel
def sqrt_x_adj(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the adjoint square-root-of-X gate over every qubit in ``qubits``."""
    adjoint_op = _op.adjoint(_op.sqrt_x())
    _qubit.broadcast(adjoint_op, qubits)
173
+
174
+
175
@kernel
def sqrt_y_adj(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the adjoint square-root-of-Y gate over every qubit in ``qubits``."""
    adjoint_op = _op.adjoint(_op.sqrt_y())
    _qubit.broadcast(adjoint_op, qubits)
180
+
181
+
182
@kernel
def sqrt_z_adj(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the adjoint square-root-of-Z (S) gate over every qubit in ``qubits``."""
    adjoint_op = _op.adjoint(_op.s())
    _qubit.broadcast(adjoint_op, qubits)
187
+
188
+
189
@kernel
def s_adj(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the adjoint S gate over every qubit in ``qubits``."""
    adjoint_op = _op.adjoint(_op.s())
    _qubit.broadcast(adjoint_op, qubits)
194
+
195
+
196
@kernel
def t_adj(qubits: ilist.IList[Qubit, Any]) -> None:
    """Broadcast the adjoint T gate over every qubit in ``qubits``."""
    adjoint_op = _op.adjoint(_op.t())
    _qubit.broadcast(adjoint_op, qubits)
@@ -1 +1,5 @@
1
- from .squin_to_stim import SquinToStimPass as SquinToStimPass
1
+ from .squin_to_stim import (
2
+ SquinToStimPass as SquinToStimPass,
3
+ StimSimplifyIfs as StimSimplifyIfs,
4
+ AggressiveForLoopUnroll as AggressiveForLoopUnroll,
5
+ )
@@ -9,6 +9,7 @@ from kirin.rewrite import (
9
9
  ConstantFold,
10
10
  CommonSubexpressionElimination,
11
11
  )
12
+ from kirin.dialects.ilist.passes import ConstList2IList
12
13
 
13
14
  from ..rewrite.ifs_to_stim import StimLiftThenBody, StimSplitIfStmts
14
15
 
@@ -23,8 +24,16 @@ class StimSimplifyIfs(Pass):
23
24
  Walk(StimSplitIfStmts()),
24
25
  ).rewrite(mt.code)
25
26
 
27
+ # because nested python lists don't have their
28
+ # member lists converted to ILists, ConstantFold
29
+ # can add python lists that can't be hashed, causing
30
+ # issues with CSE. ConstList2IList remedies that problem here.
26
31
  result = (
27
- Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))
32
+ Chain(
33
+ Fixpoint(Walk(ConstantFold())),
34
+ Walk(ConstList2IList()),
35
+ Walk(CommonSubexpressionElimination()),
36
+ )
28
37
  .rewrite(mt.code)
29
38
  .join(result)
30
39
  )
@@ -1,6 +1,6 @@
1
1
  from dataclasses import dataclass
2
2
 
3
- from kirin.passes import Fold
3
+ from kirin.passes import Fold, HintConst, TypeInfer
4
4
  from kirin.rewrite import (
5
5
  Walk,
6
6
  Chain,
@@ -16,6 +16,7 @@ from kirin.ir.method import Method
16
16
  from kirin.passes.abc import Pass
17
17
  from kirin.rewrite.abc import RewriteResult
18
18
  from kirin.passes.inline import InlinePass
19
+ from kirin.rewrite.alias import InlineAlias
19
20
 
20
21
  from bloqade.stim.rewrite import (
21
22
  SquinWireToStim,
@@ -33,11 +34,43 @@ from bloqade.squin.rewrite import (
33
34
  from bloqade.rewrite.passes import CanonicalizeIList
34
35
  from bloqade.analysis.address import AddressAnalysis
35
36
  from bloqade.analysis.measure_id import MeasurementIDAnalysis
37
+ from bloqade.squin.rewrite.desugar import ApplyDesugarRule
36
38
 
37
39
  from .simplify_ifs import StimSimplifyIfs
38
40
  from ..rewrite.ifs_to_stim import IfToStim
39
41
 
40
42
 
43
@dataclass
class AggressiveForLoopUnroll(Pass):
    """Unroll ``for`` loops aggressively, one rewrite sweep per invocation.

    The regular unroll can leave nested loops rolled when constant propagation
    has not yet exposed their bounds. Each call walks the unroll rule chain
    exactly once and then re-runs ``HintConst`` so the next call sees any newly
    exposed constants. Drive this pass with ``fixpoint`` so it repeats until no
    further rewrites happen.
    """

    def unsafe_run(self, mt: Method) -> RewriteResult:
        unroll_rules = Chain(
            InlineGetField(),
            InlineGetItem(),
            scf.unroll.ForLoop(),
            scf.trim.UnusedYield(),
        )

        # A single Walk on purpose (no inner Fixpoint): the outer fixpoint over
        # the whole pass interleaves HintConst between successive sweeps.
        walk_result = Walk(unroll_rules).rewrite(mt.code)

        hint_pass = HintConst(dialects=mt.dialects, no_raise=self.no_raise)
        return hint_pass.unsafe_run(mt).join(walk_result)
72
+
73
+
41
74
  @dataclass
42
75
  class SquinToStimPass(Pass):
43
76
 
@@ -48,27 +81,24 @@ class SquinToStimPass(Pass):
48
81
  dialects=mt.dialects, no_raise=self.no_raise
49
82
  ).unsafe_run(mt)
50
83
 
51
- rule = Chain(
52
- InlineGetField(),
53
- InlineGetItem(),
54
- scf.unroll.ForLoop(),
55
- scf.trim.UnusedYield(),
84
+ rewrite_result = (
85
+ AggressiveForLoopUnroll(dialects=mt.dialects, no_raise=self.no_raise)
86
+ .fixpoint(mt)
87
+ .join(rewrite_result)
56
88
  )
57
- rewrite_result = Fixpoint(Walk(rule)).rewrite(mt.code).join(rewrite_result)
58
- # fold_pass = Fold(mt.dialects, no_raise=self.no_raise)
59
- # rewrite_result = fold_pass(mt)
89
+
60
90
  rewrite_result = (
61
91
  Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(rewrite_result)
62
92
  )
93
+
94
+ Walk(InlineAlias()).rewrite(mt.code).join(rewrite_result)
95
+
63
96
  rewrite_result = (
64
97
  StimSimplifyIfs(mt.dialects, no_raise=self.no_raise)
65
98
  .unsafe_run(mt)
66
99
  .join(rewrite_result)
67
100
  )
68
101
 
69
- # run typeinfer again after unroll etc. because we now insert
70
- # a lot of new nodes, which might have more precise types
71
- # self.typeinfer.unsafe_run(mt)
72
102
  rewrite_result = (
73
103
  Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
74
104
  .rewrite(mt.code)
@@ -82,6 +112,9 @@ class SquinToStimPass(Pass):
82
112
  .join(rewrite_result)
83
113
  )
84
114
 
115
+ TypeInfer(dialects=mt.dialects, no_raise=self.no_raise).unsafe_run(mt)
116
+ Walk(ApplyDesugarRule()).rewrite(mt.code)
117
+
85
118
  # after this the program should be in a state where it is analyzable
86
119
  # -------------------------------------------------------------------
87
120
 
@@ -99,12 +132,14 @@ class SquinToStimPass(Pass):
99
132
  )
100
133
 
101
134
  # 2. rewrite
135
+ ## Invoke DCE afterwards to eliminate any GetItems
136
+ ## that are no longer being used. This allows for
137
+ ## SquinMeasureToStim to safely eliminate
138
+ ## unused measure statements.
102
139
  rewrite_result = (
103
- Walk(
104
- IfToStim(
105
- measure_analysis=meas_analysis_frame.entries,
106
- measure_count=mia.measure_count,
107
- )
140
+ Chain(
141
+ Walk(IfToStim(measure_frame=meas_analysis_frame)),
142
+ Fixpoint(Walk(DeadCodeElimination())),
108
143
  )
109
144
  .rewrite(mt.code)
110
145
  .join(rewrite_result)
@@ -120,17 +155,15 @@ class SquinToStimPass(Pass):
120
155
  Walk(
121
156
  Chain(
122
157
  SquinQubitToStim(),
158
+ SquinMeasureToStim(),
123
159
  SquinWireToStim(),
124
- SquinMeasureToStim(
125
- measure_id_result=meas_analysis_frame.entries,
126
- total_measure_count=mia.measure_count,
127
- ), # reduce duplicated logic, can split out even more rules later
128
160
  SquinWireIdentityElimination(),
129
161
  )
130
162
  )
131
163
  .rewrite(mt.code)
132
164
  .join(rewrite_result)
133
165
  )
166
+
134
167
  rewrite_result = (
135
168
  CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
136
169
  .unsafe_run(mt)
@@ -1,3 +1,4 @@
1
+ from .ifs_to_stim import IfToStim as IfToStim
1
2
  from .squin_noise import SquinNoiseToStim as SquinNoiseToStim
2
3
  from .wire_to_stim import SquinWireToStim as SquinWireToStim
3
4
  from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
@@ -11,9 +11,9 @@ from bloqade.stim.rewrite.util import (
11
11
  SQUIN_STIM_CONTROL_GATE_MAPPING,
12
12
  insert_qubit_idx_from_address,
13
13
  )
14
+ from bloqade.analysis.measure_id import MeasureIDFrame
14
15
  from bloqade.stim.dialects.auxiliary import GetRecord
15
16
  from bloqade.analysis.measure_id.lattice import (
16
- MeasureId,
17
17
  MeasureIdBool,
18
18
  )
19
19
 
@@ -127,8 +127,7 @@ class IfToStim(IfElseSimplification, RewriteRule):
127
127
  Rewrite if statements to stim equivalent statements.
128
128
  """
129
129
 
130
- measure_analysis: dict[ir.SSAValue, MeasureId]
131
- measure_count: int
130
+ measure_frame: MeasureIDFrame
132
131
 
133
132
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
134
133
 
@@ -140,7 +139,7 @@ class IfToStim(IfElseSimplification, RewriteRule):
140
139
 
141
140
  def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
142
141
 
143
- if not isinstance(self.measure_analysis[stmt.cond], MeasureIdBool):
142
+ if not isinstance(self.measure_frame.entries[stmt.cond], MeasureIdBool):
144
143
  return RewriteResult()
145
144
 
146
145
  # check that there is only qubit.Apply in the then-body,
@@ -161,12 +160,12 @@ class IfToStim(IfElseSimplification, RewriteRule):
161
160
  return RewriteResult()
162
161
 
163
162
  # get necessary measurement ID type from analysis
164
- measure_id_bool = self.measure_analysis[stmt.cond]
163
+ measure_id_bool = self.measure_frame.entries[stmt.cond]
165
164
  assert isinstance(measure_id_bool, MeasureIdBool)
166
165
 
167
166
  # generate get record statement
168
167
  measure_id_idx_stmt = py.Constant(
169
- (measure_id_bool.idx - 1) - self.measure_count
168
+ (measure_id_bool.idx - 1) - self.measure_frame.num_measures_at_stmt[stmt]
170
169
  )
171
170
  get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841
172
171
 
@@ -2,47 +2,16 @@
2
2
  from dataclasses import dataclass
3
3
 
4
4
  from kirin import ir
5
- from kirin.dialects import py, ilist
5
+ from kirin.dialects import py
6
6
  from kirin.rewrite.abc import RewriteRule, RewriteResult
7
7
 
8
8
  from bloqade.squin import wire, qubit
9
9
  from bloqade.squin.rewrite import AddressAttribute
10
- from bloqade.stim.dialects import collapse, auxiliary
10
+ from bloqade.stim.dialects import collapse
11
11
  from bloqade.stim.rewrite.util import (
12
12
  is_measure_result_used,
13
13
  insert_qubit_idx_from_address,
14
14
  )
15
- from bloqade.analysis.measure_id.lattice import MeasureId, MeasureIdBool, MeasureIdTuple
16
-
17
-
18
- def replace_get_record(
19
- node: ir.Statement, measure_id_bool: MeasureIdBool, meas_count: int
20
- ):
21
- assert isinstance(measure_id_bool, MeasureIdBool)
22
- target_rec_idx = (measure_id_bool.idx - 1) - meas_count
23
- idx_stmt = py.constant.Constant(target_rec_idx)
24
- idx_stmt.insert_before(node)
25
- get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
26
- node.replace_by(get_record_stmt)
27
-
28
-
29
- def insert_get_record_list(
30
- node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count: int
31
- ):
32
- """
33
- Insert GetRecord statements before the given node
34
- """
35
- get_record_ssas = []
36
- for measure_id_bool in measure_id_tuple.data:
37
- assert isinstance(measure_id_bool, MeasureIdBool)
38
- target_rec_idx = (measure_id_bool.idx - 1) - meas_count
39
- idx_stmt = py.constant.Constant(target_rec_idx)
40
- idx_stmt.insert_before(node)
41
- get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
42
- get_record_stmt.insert_before(node)
43
- get_record_ssas.append(get_record_stmt.result)
44
-
45
- node.replace_by(ilist.New(values=get_record_ssas))
46
15
 
47
16
 
48
17
  @dataclass
@@ -51,9 +20,6 @@ class SquinMeasureToStim(RewriteRule):
51
20
  Rewrite squin measure-related statements to stim statements.
52
21
  """
53
22
 
54
- measure_id_result: dict[ir.SSAValue, MeasureId]
55
- total_measure_count: int
56
-
57
23
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
58
24
 
59
25
  match node:
@@ -70,10 +36,6 @@ class SquinMeasureToStim(RewriteRule):
70
36
  if qubit_idx_ssas is None:
71
37
  return RewriteResult()
72
38
 
73
- measure_id = self.measure_id_result[measure_stmt.result]
74
- if not isinstance(measure_id, (MeasureIdBool, MeasureIdTuple)):
75
- return RewriteResult()
76
-
77
39
  prob_noise_stmt = py.constant.Constant(0.0)
78
40
  stim_measure_stmt = collapse.MZ(
79
41
  p=prob_noise_stmt.result,
@@ -84,27 +46,6 @@ class SquinMeasureToStim(RewriteRule):
84
46
 
85
47
  if not is_measure_result_used(measure_stmt):
86
48
  measure_stmt.delete()
87
- return RewriteResult(has_done_something=True)
88
-
89
- # replace dataflow with new stmt!
90
- measure_id = self.measure_id_result[measure_stmt.result]
91
- if isinstance(measure_id, MeasureIdBool):
92
- replace_get_record(
93
- node=measure_stmt,
94
- measure_id_bool=measure_id,
95
- meas_count=self.total_measure_count,
96
- )
97
- elif isinstance(measure_id, MeasureIdTuple):
98
- insert_get_record_list(
99
- node=measure_stmt,
100
- measure_id_tuple=measure_id,
101
- meas_count=self.total_measure_count,
102
- )
103
- else:
104
- # already checked before, so this should not happen
105
- raise ValueError(
106
- f"Unexpected measure ID type: {type(measure_id)} for measure statement {measure_stmt}"
107
- )
108
49
 
109
50
  return RewriteResult(has_done_something=True)
110
51
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bloqade-circuit
3
- Version: 0.6.3
3
+ Version: 0.6.5
4
4
  Summary: The software development toolkit for neutral atom arrays.
5
5
  Author-email: Roger-luo <rluo@quera.com>, kaihsin <khwu@quera.com>, weinbe58 <pweinberg@quera.com>, johnzl-777 <jlong@quera.com>
6
6
  License-File: LICENSE
@@ -9,9 +9,9 @@ bloqade/analysis/address/impls.py,sha256=cWdq1CNV6HVDECtIgl9ogUa8nvFZm_Sy5Wcfk1f
9
9
  bloqade/analysis/address/lattice.py,sha256=dUq999feqPoBYkqEXe1hjHOn4TP_bkvKip8fyWQ-2-8,1755
10
10
  bloqade/analysis/fidelity/__init__.py,sha256=iJkhoHvCMU9bKxQqgxIWKQWvpqNFRgNBI5DK8-4RAB8,59
11
11
  bloqade/analysis/fidelity/analysis.py,sha256=G6JEYc8eeWJ9mwsbUAIzXuU2nrnTU4te41c04xE71gM,3218
12
- bloqade/analysis/measure_id/__init__.py,sha256=J9I58iIyt4IjB36xK6_q9PV1B-40VJ1Gu9hQ432A6gM,98
13
- bloqade/analysis/measure_id/analysis.py,sha256=93S3a_Wu7Bt1j217l1hZMY2tLlR00aUYN27o0IhZsqU,1748
14
- bloqade/analysis/measure_id/impls.py,sha256=ItS9bLMs8UdAkElYgrqD48Xlh9ZyoJq5xDPWcXjWH1c,4823
12
+ bloqade/analysis/measure_id/__init__.py,sha256=r_R_br1e3H7ZzwkeQw4TnDAP4M_bUaRlRb7ZRdowvNI,145
13
+ bloqade/analysis/measure_id/analysis.py,sha256=F22mLWeOLH_QHm25_4DZLLse4ljsiZr7UNThaIy4pzA,2149
14
+ bloqade/analysis/measure_id/impls.py,sha256=ojFSQrrqj-jdcBklaLa3HHojHvw59Z7Dl2iNdfSuP1w,6376
15
15
  bloqade/analysis/measure_id/lattice.py,sha256=WPrn0R79umCH909BFWsUJ64qx9n_3KYimIW5UaXNuGU,1891
16
16
  bloqade/cirq_utils/__init__.py,sha256=1DRDCF3PpgJCOr0z7iULdrn3dqm7GLpRGs9AlqE7XA8,280
17
17
  bloqade/cirq_utils/lineprog.py,sha256=JosrhfeOHI9FycUT_sYFj8TBzLpo97TL8zK-Ap2U4eQ,11021
@@ -22,7 +22,7 @@ bloqade/cirq_utils/noise/conflict_graph.py,sha256=ZUwPWTknrb6SgtZUVPeICn3YA-nUeW
22
22
  bloqade/cirq_utils/noise/model.py,sha256=06Y_BLChOA-PhhAJcWLSgLVAAJoNjOrAujL1YCwcXA0,20590
23
23
  bloqade/cirq_utils/noise/transform.py,sha256=tvDt4WMLM8dKPME51y0_peSZk2-jKmjq0urOxm0lWuQ,2309
24
24
  bloqade/pyqrack/__init__.py,sha256=lonTS-luJkTVujCCtgdZRC12V7FQdoFcozAI-byXwN0,810
25
- bloqade/pyqrack/base.py,sha256=9z61PaaAFqCBBwkgsDZSr-qr9IQ5OJ_JUvltmJ7Bgls,4407
25
+ bloqade/pyqrack/base.py,sha256=g0GRlEgyJ_P8z-lR8RK2CAuRTj6KPfglKX0iwrgg4DM,4408
26
26
  bloqade/pyqrack/device.py,sha256=40vduanEgA26GAW3buHoRpyqPA0xUt2tONY3w5JeH5s,7524
27
27
  bloqade/pyqrack/reg.py,sha256=uTL07CT1R0xUsInLmwU9YuuNdV6lV0lCs1zhdUz1qIs,1660
28
28
  bloqade/pyqrack/target.py,sha256=c78VtLWAiDNp_0sXwvVzhaEoeFsr1fUVsupxWuo6p3s,3661
@@ -125,10 +125,12 @@ bloqade/rewrite/rules/__init__.py,sha256=3e1Z5T3INqNtP6OU0Vivu_SdMOR_2KDixeA0Yjo
125
125
  bloqade/rewrite/rules/flatten_ilist.py,sha256=QoIxMaBXSlatpWzi5s_MAPnV3bV3GeoWc31RBw0WQ3s,1465
126
126
  bloqade/rewrite/rules/inline_getitem_ilist.py,sha256=uIXQRCsr3_GPMciDT4ghI-ezhQmkDcGcC6pguABPUVw,875
127
127
  bloqade/rewrite/rules/split_ifs.py,sha256=Nm4lpEUHZcnCeewIld0tt7UuGO69LiBGl7Uybuwissw,2119
128
- bloqade/squin/__init__.py,sha256=MH7i5gR9DhTjLMI6vsP_NT7_yoaEowYiQwsYhrrUEX0,454
128
+ bloqade/squin/__init__.py,sha256=b7ZD69ql9GriIPxN6JhWxANJVzuIJh_r-gC24wsW1mM,621
129
129
  bloqade/squin/_typeinfer.py,sha256=bilWfC6whTMwewFCqDgB6vDHZsgXPr3azNOYqqnvtB4,780
130
+ bloqade/squin/gate.py,sha256=tCnjfrSVsXHX7VxkEulZ2SQS5ydtmON8QlcGibM6c2I,4028
130
131
  bloqade/squin/groups.py,sha256=RXGJnNZUSXF_f5ljjhZ9At8UhaijayoxFoWvxEsUOWc,1310
131
132
  bloqade/squin/lowering.py,sha256=SR6q-IfV8WHPKT97M7UFu5KoRgAojfDno8Bft1mUSKM,1736
133
+ bloqade/squin/parallel.py,sha256=X6Ps9kQIgnFMlZO14y2ntdxvivqbIP28PAWF8KmxByM,5172
132
134
  bloqade/squin/qubit.py,sha256=LgNJsm6qCyP7_O-lZg3YT8IiqzF5W5ff1VwQ79nXN4c,5148
133
135
  bloqade/squin/types.py,sha256=T3lkqid4HEWuAK_wRns_p-K5DbLDwlldoyZtVay7A3o,119
134
136
  bloqade/squin/wire.py,sha256=GZhF0EHCu7OU70zTV_N83yann-eQnYG_lM2u0QYFoAs,6596
@@ -203,14 +205,14 @@ bloqade/stim/emit/__init__.py,sha256=N2dPQY7OyqPwHAStDeOgYg2yfxqxMOz-N7pD5Z4JwlI
203
205
  bloqade/stim/emit/stim_str.py,sha256=JyEBoIhLQASogZcUWHI9tMD4JoXYrEqUr2qaZ30gZdc,1491
204
206
  bloqade/stim/parse/__init__.py,sha256=l2DjReB2KkgrDjP_4nP6RnoziiOewoSeZfTno1sVYTw,59
205
207
  bloqade/stim/parse/lowering.py,sha256=L-IcR_exlxsTVv4SQ0bhzIF4_L82P-GEdK6qRd6B86Y,23723
206
- bloqade/stim/passes/__init__.py,sha256=SO4OJLaRxq9Lt2AaUxNqiAz-eJnfWsJ6TijUAX64DM4,62
207
- bloqade/stim/passes/simplify_ifs.py,sha256=45lWKmc6ybt0KjEuE297IouPVJ7NwX_j9q4GNUsIkEc,690
208
- bloqade/stim/passes/squin_to_stim.py,sha256=pZvSd1rmOnLRfZONctQAb6ORjcoMd1Qjpbhyf-7m-MQ,5125
209
- bloqade/stim/rewrite/__init__.py,sha256=SHWryh7rZHXOlIz8BMNpj-w7-8VQCRMLt6PfzYFbBfw,434
210
- bloqade/stim/rewrite/ifs_to_stim.py,sha256=LXkmmTIgi8-Exrz1EGjG8QpdRn6EV3HZM-bZeAaf-cQ,6718
208
+ bloqade/stim/passes/__init__.py,sha256=aysjOZyn0IrJQCQBEqiz8pwZ5u5t2s9TmEzA9Y9KG9w,167
209
+ bloqade/stim/passes/simplify_ifs.py,sha256=zicqggWu_yzfrf2a7uUCt-ZenbYSEnFsyGxDfKw72qQ,1084
210
+ bloqade/stim/passes/squin_to_stim.py,sha256=6TpnpvA6JA3dQyH6mxNpGn9-__sPlSmyUBIdThn9xVg,6018
211
+ bloqade/stim/rewrite/__init__.py,sha256=zL5G73JEsXkehN7gCtUgGnmC2BJ3vKihOd1ohVwM68E,480
212
+ bloqade/stim/rewrite/ifs_to_stim.py,sha256=jmCb6AwNqGXWuLR8BkAYIVvAlOptQzVPHWv52_v-Vsw,6755
211
213
  bloqade/stim/rewrite/py_constant_to_stim.py,sha256=PV8bHvn759-d_0JW4akaGSORW_oxigrlUBhAC51PJAU,1354
212
214
  bloqade/stim/rewrite/qubit_to_stim.py,sha256=oiKmi8BlBwXJq-8kGhN1nXgyxJ2UIt_9uouNkU1J8vs,2624
213
- bloqade/stim/rewrite/squin_measure.py,sha256=zPH2q_ciV2D615GK9l9LWGYmjv3dOju18jKMYERIL7c,4817
215
+ bloqade/stim/rewrite/squin_measure.py,sha256=1zuILosGACN7rPYA87MYVwv0M4pPTala1YTe9owbhkw,2519
214
216
  bloqade/stim/rewrite/squin_noise.py,sha256=NafmAiByT4Y5895fZM4Od8arKjsJuW6F5wvRpAFFo70,6240
215
217
  bloqade/stim/rewrite/util.py,sha256=xnLDiEj45CBoG3mpG-ywE1Jjh1k_OVP4iI1A75VR6sw,7257
216
218
  bloqade/stim/rewrite/wire_identity_elimination.py,sha256=Cscu8yaSslPuW04HvbXx4HJ3JzdUZNUMyFqcvuc4sxY,795
@@ -228,7 +230,7 @@ bloqade/visual/animation/runtime/atoms.py,sha256=EmjxhujLiHHPS_HtH_B-7TiqeHgvW5u
228
230
  bloqade/visual/animation/runtime/ppoly.py,sha256=JB9IP53N1w6adBJEue6J5Nmj818Id9JvrlgrmiQTU1I,1385
229
231
  bloqade/visual/animation/runtime/qpustate.py,sha256=rlmxQeJSvaohXrTpXQL5y-NJcpvfW33xPaYM1slv7cc,4270
230
232
  bloqade/visual/animation/runtime/utils.py,sha256=ju9IzOWX-vKwfpqUjlUKu3Ssr_UFPFFq-tzH_Nqyo_c,1212
231
- bloqade_circuit-0.6.3.dist-info/METADATA,sha256=K2Me5w0S2bHxND1_JeYHn5CBC7HjLOdzfpacoATbqWo,3849
232
- bloqade_circuit-0.6.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
233
- bloqade_circuit-0.6.3.dist-info/licenses/LICENSE,sha256=S5GIJwR6QCixPA9wryYb44ZEek0Nz4rt_zLUqP05UbU,13160
234
- bloqade_circuit-0.6.3.dist-info/RECORD,,
233
+ bloqade_circuit-0.6.5.dist-info/METADATA,sha256=8Zb_1qZrO5V9v3JR32KxFQk9malYU1JPPJpkf1U9RPc,3849
234
+ bloqade_circuit-0.6.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
235
+ bloqade_circuit-0.6.5.dist-info/licenses/LICENSE,sha256=S5GIJwR6QCixPA9wryYb44ZEek0Nz4rt_zLUqP05UbU,13160
236
+ bloqade_circuit-0.6.5.dist-info/RECORD,,