bloqade-circuit 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic; see the package's registry page for more details.

Files changed (33)
  1. bloqade/analysis/address/impls.py +21 -68
  2. bloqade/analysis/measure_id/__init__.py +2 -0
  3. bloqade/analysis/measure_id/analysis.py +45 -0
  4. bloqade/analysis/measure_id/impls.py +155 -0
  5. bloqade/analysis/measure_id/lattice.py +82 -0
  6. bloqade/qasm2/passes/unroll_if.py +9 -2
  7. bloqade/rewrite/__init__.py +0 -0
  8. bloqade/rewrite/passes/__init__.py +1 -0
  9. bloqade/rewrite/passes/canonicalize_ilist.py +28 -0
  10. bloqade/rewrite/rules/__init__.py +1 -0
  11. bloqade/rewrite/rules/flatten_ilist.py +51 -0
  12. bloqade/rewrite/rules/inline_getitem_ilist.py +31 -0
  13. bloqade/{qasm2/rewrite → rewrite/rules}/split_ifs.py +15 -8
  14. bloqade/squin/__init__.py +1 -0
  15. bloqade/squin/analysis/__init__.py +1 -0
  16. bloqade/squin/analysis/address_impl.py +71 -0
  17. bloqade/squin/cirq/lowering.py +2 -1
  18. bloqade/squin/noise/stmts.py +1 -1
  19. bloqade/stim/dialects/auxiliary/interp.py +0 -10
  20. bloqade/stim/dialects/auxiliary/stmts/annotate.py +1 -1
  21. bloqade/stim/passes/__init__.py +1 -1
  22. bloqade/stim/passes/simplify_ifs.py +32 -0
  23. bloqade/stim/passes/squin_to_stim.py +95 -27
  24. bloqade/stim/rewrite/ifs_to_stim.py +203 -0
  25. bloqade/stim/rewrite/qubit_to_stim.py +3 -0
  26. bloqade/stim/rewrite/squin_measure.py +68 -5
  27. bloqade/stim/rewrite/util.py +0 -4
  28. bloqade/stim/upstream/__init__.py +1 -0
  29. bloqade/stim/upstream/from_squin.py +10 -0
  30. {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/METADATA +1 -1
  31. {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/RECORD +33 -18
  32. {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/WHEEL +0 -0
  33. {bloqade_circuit-0.5.0.dist-info → bloqade_circuit-0.5.2.dist-info}/licenses/LICENSE +0 -0
@@ -6,13 +6,10 @@ from kirin import interp
6
6
  from kirin.analysis import ForwardFrame, const
7
7
  from kirin.dialects import cf, py, scf, func, ilist
8
8
 
9
- from bloqade import squin
10
-
11
9
  from .lattice import (
12
10
  Address,
13
11
  NotQubit,
14
12
  AddressReg,
15
- AddressWire,
16
13
  AddressQubit,
17
14
  AddressTuple,
18
15
  )
@@ -73,8 +70,19 @@ class PyList(interp.MethodTable):
73
70
  class PyIndexing(interp.MethodTable):
74
71
  @interp.impl(py.GetItem)
75
72
  def getitem(self, interp: AddressAnalysis, frame: interp.Frame, stmt: py.GetItem):
76
- # Integer index into the thing being indexed
77
- idx = interp.get_const_value(int, stmt.index)
73
+
74
+ # determine if the index is an int constant
75
+ # or a slice
76
+ hint = stmt.index.hints.get("const")
77
+ if hint is None:
78
+ return (NotQubit(),)
79
+
80
+ if isinstance(hint, const.Value):
81
+ idx = hint.data
82
+ elif isinstance(hint, slice):
83
+ idx = hint
84
+ else:
85
+ return (NotQubit(),)
78
86
 
79
87
  # The object being indexed into
80
88
  obj = frame.get(stmt.obj)
@@ -82,10 +90,15 @@ class PyIndexing(interp.MethodTable):
82
90
  # so we just extract that here
83
91
  if isinstance(obj, AddressTuple):
84
92
  return (obj.data[idx],)
85
- # an AddressReg is guaranteed to just have some sequence
86
- # of integers which is directly pluggable to AddressQubit
93
+ # If idx is an integer index into an AddressReg,
94
+ # then it's safe to assume a single qubit is being accessed.
95
+ # On the other hand, if it's a slice, we return
96
+ # a new AddressReg to preserve the new sequence.
87
97
  elif isinstance(obj, AddressReg):
88
- return (AddressQubit(obj.data[idx]),)
98
+ if isinstance(idx, slice):
99
+ return (AddressReg(data=obj.data[idx]),)
100
+ if isinstance(idx, int):
101
+ return (AddressQubit(obj.data[idx]),)
89
102
  else:
90
103
  return (NotQubit(),)
91
104
 
@@ -163,63 +176,3 @@ class Scf(scf.absint.Methods):
163
176
  return # if terminate is Return, there is no result
164
177
 
165
178
  return loop_vars
166
-
167
-
168
- # Address lattice elements we can work with:
169
- ## NotQubit (bottom), AnyAddress (top)
170
-
171
- ## AddressTuple -> data: tuple[Address, ...]
172
- ### Recursive type, could contain itself or other variants
173
- ### This pops up in cases where you can have an IList/Tuple
174
- ### That contains elements that could be other Address types
175
-
176
- ## AddressReg -> data: Sequence[int]
177
- ### specific to creation of a register of qubits
178
-
179
- ## AddressQubit -> data: int
180
- ### Base qubit address type
181
-
182
-
183
- @squin.wire.dialect.register(key="qubit.address")
184
- class SquinWireMethodTable(interp.MethodTable):
185
-
186
- @interp.impl(squin.wire.Unwrap)
187
- def unwrap(
188
- self,
189
- interp_: AddressAnalysis,
190
- frame: ForwardFrame[Address],
191
- stmt: squin.wire.Unwrap,
192
- ):
193
-
194
- origin_qubit = frame.get(stmt.qubit)
195
-
196
- if isinstance(origin_qubit, AddressQubit):
197
- return (AddressWire(origin_qubit=origin_qubit),)
198
- else:
199
- return (Address.top(),)
200
-
201
- @interp.impl(squin.wire.Apply)
202
- def apply(
203
- self,
204
- interp_: AddressAnalysis,
205
- frame: ForwardFrame[Address],
206
- stmt: squin.wire.Apply,
207
- ):
208
- return frame.get_values(stmt.inputs)
209
-
210
-
211
- @squin.qubit.dialect.register(key="qubit.address")
212
- class SquinQubitMethodTable(interp.MethodTable):
213
-
214
- # This can be treated like a QRegNew impl
215
- @interp.impl(squin.qubit.New)
216
- def new(
217
- self,
218
- interp_: AddressAnalysis,
219
- frame: ForwardFrame[Address],
220
- stmt: squin.qubit.New,
221
- ):
222
- n_qubits = interp_.get_const_value(int, stmt.n_qubits)
223
- addr = AddressReg(range(interp_.next_address, interp_.next_address + n_qubits))
224
- interp_.next_address += n_qubits
225
- return (addr,)
@@ -0,0 +1,2 @@
1
+ from . import impls as impls
2
+ from .analysis import MeasurementIDAnalysis as MeasurementIDAnalysis
@@ -0,0 +1,45 @@
1
+ from typing import TypeVar
2
+
3
+ from kirin import ir, interp
4
+ from kirin.analysis import Forward, const
5
+ from kirin.analysis.forward import ForwardFrame
6
+
7
+ from .lattice import MeasureId, NotMeasureId
8
+
9
+
10
class MeasurementIDAnalysis(Forward[MeasureId]):
    """Forward dataflow analysis that tags measurement results with indices.

    Every measurement encountered increments ``measure_count``. The
    ``measure_id`` impls label each measured result with the updated count,
    so results are numbered 1, 2, 3, ... in the order the analysis visits
    them.
    """

    keys = ["measure_id"]
    lattice = MeasureId
    # Number of measurement results seen so far in the current run. Impls
    # increment it once per measured qubit; the counts are meant to be used to
    # generate the negative values for target `rec` indices downstream.
    measure_count = 0

    def initialize(self):
        """Reset per-run state before an analysis run.

        Without this, running the same analysis object twice would continue
        numbering measurements from where the previous run stopped.
        """
        super().initialize()
        self.measure_count = 0
        return self

    def eval_stmt_fallback(
        self, frame: ForwardFrame[MeasureId], stmt: ir.Statement
    ) -> tuple[MeasureId, ...]:
        # Statements without a dedicated measure_id impl produce NotMeasureId
        # for every result, rather than the lattice bottom (InvalidMeasureId).
        return tuple(NotMeasureId() for _ in stmt.results)

    def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]):
        # NOTE: we do not support dynamic calls here, so the method object is
        # not propagated; its argument slot is filled with bottom.
        return self.run_callable(method.code, (self.lattice.bottom(),) + args)

    T = TypeVar("T")

    # Xiu-zhe (Roger) Luo came up with this in the address analysis,
    # reused here for convenience.
    # TODO: Remove this function once upgrade to kirin 0.18 happens,
    # method is built-in to interpreter then
    def get_const_value(self, input_type: type[T], value: ir.SSAValue) -> T:
        """Return the constant of type ``input_type`` attached to ``value``.

        Args:
            input_type: Expected Python type of the constant.
            value: SSA value whose ``"const"`` hint is inspected.

        Raises:
            interp.InterpreterError: If ``value`` carries no ``const.Value``
                hint, or the constant is not an instance of ``input_type``.
        """
        if isinstance(hint := value.hints.get("const"), const.Value):
            data = hint.data
            if isinstance(data, input_type):
                return hint.data
            raise interp.InterpreterError(
                f"Expected constant value <type = {input_type}>, got {data}"
            )
        raise interp.InterpreterError(
            f"Expected constant value <type = {input_type}>, got {value}"
        )
@@ -0,0 +1,155 @@
1
+ from kirin import types as kirin_types, interp
2
+ from kirin.dialects import py, scf, func, ilist
3
+
4
+ from bloqade.squin import wire, qubit
5
+
6
+ from .lattice import (
7
+ AnyMeasureId,
8
+ NotMeasureId,
9
+ MeasureIdBool,
10
+ MeasureIdTuple,
11
+ InvalidMeasureId,
12
+ )
13
+ from .analysis import MeasurementIDAnalysis
14
+
15
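+ ## NOTE(review): `wire` is imported above and `SquinWire` below already
+ ## handles `wire.Measure`. The earlier caveat ("can't do wire right now
+ ## because of an unresolved RFC on return type") presumably now applies only
+ ## to other wire statements; confirm and drop this note once resolved.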
18
+
19
+
20
@qubit.dialect.register(key="measure_id")
class SquinQubit(interp.MethodTable):
    """`measure_id` impls for the squin qubit dialect's measurement statements."""

    @interp.impl(qubit.MeasureQubit)
    def measure_qubit(
        self,
        interp: MeasurementIDAnalysis,
        frame: interp.Frame,
        stmt: qubit.MeasureQubit,
    ):
        """Assign the next measurement index to a single-qubit measurement."""
        interp.measure_count += 1
        return (MeasureIdBool(interp.measure_count),)

    @interp.impl(qubit.MeasureQubitList)
    def measure_qubit_list(
        self,
        interp: MeasurementIDAnalysis,
        frame: interp.Frame,
        stmt: qubit.MeasureQubitList,
    ):
        """Assign consecutive measurement indices to a list measurement.

        The number of measured qubits is read from the statically inferred
        type of ``stmt.qubits``. It is assumed to be a generic ``IList[T, L]``
        whose second parameter is a ``Literal`` length. If the length cannot be
        determined statically, the result is ``AnyMeasureId``.
        """
        qubits_type = stmt.qubits.type
        # Non-generic types (e.g. Any/Bottom) have no type parameters, so
        # indexing `.vars` on them would crash the analysis.
        if not isinstance(qubits_type, kirin_types.Generic) or len(qubits_type.vars) < 2:
            return (AnyMeasureId(),)

        # vars[0] is the element type; vars[1] may be a Literal with the length.
        num_qubits = qubits_type.vars[1]
        if not isinstance(num_qubits, kirin_types.Literal) or not isinstance(
            num_qubits.data, int
        ):
            return (AnyMeasureId(),)

        measure_id_bools = []
        for _ in range(num_qubits.data):
            interp.measure_count += 1
            measure_id_bools.append(MeasureIdBool(interp.measure_count))

        return (MeasureIdTuple(data=tuple(measure_id_bools)),)
56
+
57
+
58
@wire.dialect.register(key="measure_id")
class SquinWire(interp.MethodTable):
    """`measure_id` impls for the squin wire dialect."""

    @interp.impl(wire.Measure)
    def measure_qubit(
        self,
        interp: MeasurementIDAnalysis,
        frame: interp.Frame,
        stmt: wire.Measure,
    ):
        """Tag a wire measurement's result with the next measurement index."""
        next_id = interp.measure_count + 1
        interp.measure_count = next_id
        return (MeasureIdBool(next_id),)
70
+
71
+
72
@ilist.dialect.register(key="measure_id")
class IList(interp.MethodTable):
    """`measure_id` impls for immutable-list construction."""

    # Element IDs are copied verbatim, so a user-built list of bools may end
    # up mixing MeasureIdBool entries with NotMeasureId entries.
    @interp.impl(ilist.New)
    def new_ilist(
        self,
        interp: MeasurementIDAnalysis,
        frame: interp.Frame,
        stmt: ilist.New,
    ):
        """An IList's measure ID is the tuple of its elements' measure IDs."""
        element_ids = tuple(frame.get_values(stmt.values))
        return (MeasureIdTuple(data=element_ids),)
87
+
88
+
89
@py.tuple.dialect.register(key="measure_id")
class PyTuple(interp.MethodTable):
    """`measure_id` impls for Python tuple construction."""

    @interp.impl(py.tuple.New)
    def new_tuple(
        self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.tuple.New
    ):
        """A tuple's measure ID is the tuple of its elements' measure IDs."""
        return (MeasureIdTuple(data=tuple(frame.get_values(stmt.args))),)
97
+
98
+
99
@py.indexing.dialect.register(key="measure_id")
class PyIndexing(interp.MethodTable):
    """`measure_id` impls for Python indexing (``obj[index]``)."""

    @interp.impl(py.GetItem)
    def getitem(
        self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.GetItem
    ):
        """Index into a collection of measurement IDs.

        The index must be a compile-time constant:

        * a constant ``slice`` yields a new ``MeasureIdTuple`` with the selected
          elements (mirroring slice support in the address analysis);
        * a constant ``int`` selects a single element, and an out-of-range
          index yields ``InvalidMeasureId`` instead of crashing the analysis;
        * any other index raises ``InterpreterError`` via ``get_const_value``.

        ``AnyMeasureId``/``NotMeasureId`` objects propagate unchanged. Any
        other object yields ``InvalidMeasureId``.
        """
        from kirin.analysis import const

        hint = stmt.index.hints.get("const")
        if isinstance(hint, const.Value) and isinstance(hint.data, slice):
            idx = hint.data
        else:
            idx = interp.get_const_value(int, stmt.index)

        obj = frame.get(stmt.obj)
        if isinstance(obj, MeasureIdTuple):
            if isinstance(idx, slice):
                return (MeasureIdTuple(data=obj.data[idx]),)
            if not -len(obj.data) <= idx < len(obj.data):
                return (InvalidMeasureId(),)
            return (obj.data[idx],)
        # just propagate these down the line
        elif isinstance(obj, (AnyMeasureId, NotMeasureId)):
            return (obj,)
        else:
            return (InvalidMeasureId(),)
114
+
115
+
116
@py.binop.dialect.register(key="measure_id")
class PyBinOp(interp.MethodTable):
    """`measure_id` impls for Python binary operators."""

    @interp.impl(py.Add)
    def add(self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.Add):
        """Concatenate two MeasureIdTuples; any other operands give InvalidMeasureId."""
        left = frame.get(stmt.lhs)
        right = frame.get(stmt.rhs)

        both_tuples = isinstance(left, MeasureIdTuple) and isinstance(
            right, MeasureIdTuple
        )
        if not both_tuples:
            return (InvalidMeasureId(),)
        return (MeasureIdTuple(data=left.data + right.data),)
127
+
128
+
129
@func.dialect.register(key="measure_id")
class Func(interp.MethodTable):
    """`measure_id` impls for function returns and statically known calls."""

    @interp.impl(func.Return)
    def return_(self, _: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Return):
        """Hand the returned value's measure ID back to the caller."""
        returned = frame.get(stmt.value)
        return interp.ReturnValue(returned)

    # Adapted from the Address Analysis implementation by Xiu-zhe (Roger) Luo.
    # Only func.Invoke is handled because its callee is known at compile time;
    # with func.Call the callee would not be known.
    @interp.impl(func.Invoke)
    def invoke(
        self, interp_: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Invoke
    ):
        """Analyse the callee with the caller's argument IDs and return its result."""
        callee = stmt.callee
        call_args = interp_.permute_values(
            callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
        )
        _, ret = interp_.run_method(callee, call_args)
        return (ret,)
149
+
150
+
151
# Reuse kirin's generic abstract-interpretation methods for scf so the
# analysis flows through structured control flow, in particular IfElse.
@scf.dialect.register(key="measure_id")
class Scf(scf.absint.Methods):
    """Inherit the scf abstract-interpretation methods unchanged."""
@@ -0,0 +1,82 @@
1
+ from typing import final
2
+ from dataclasses import dataclass
3
+
4
+ from kirin.lattice import (
5
+ SingletonMeta,
6
+ BoundedLattice,
7
+ SimpleJoinMixin,
8
+ SimpleMeetMixin,
9
+ )
10
+
11
+ # Taken directly from Kai-Hsin Wu's implementation
12
+ # with minor changes to names and addition of CanMeasureId type
13
+
14
+
15
@dataclass
class MeasureId(
    SimpleJoinMixin["MeasureId"],
    SimpleMeetMixin["MeasureId"],
    BoundedLattice["MeasureId"],
):
    """Root of the measurement-ID lattice.

    The lattice is bounded: ``InvalidMeasureId`` is the bottom element and
    ``AnyMeasureId`` is the top element. Join and meet are supplied by kirin's
    ``SimpleJoinMixin`` and ``SimpleMeetMixin``.
    """

    @classmethod
    def bottom(cls) -> "MeasureId":
        """Return the least element of the lattice (``InvalidMeasureId``)."""
        return InvalidMeasureId()

    @classmethod
    def top(cls) -> "MeasureId":
        """Return the greatest element of the lattice (``AnyMeasureId``)."""
        return AnyMeasureId()
29
+
30
+
31
@final
@dataclass
class InvalidMeasureId(MeasureId, metaclass=SingletonMeta):
    """Bottom element of the measurement-ID lattice.

    The impls return it for unsupported operands, e.g. indexing a value that
    is not a tuple of IDs, or adding values that are not both tuples of IDs.
    NOTE(review): this was originally described as arising when a user builds
    a list mixing measurement results with other bools. However, ``ilist.New``
    currently keeps such lists as a MeasureIdTuple containing NotMeasureId
    entries, so confirm the intended semantics.
    """

    def is_subseteq(self, other: MeasureId) -> bool:
        # Bottom is contained in every element.
        return True
40
+
41
+
42
@final
@dataclass
class AnyMeasureId(MeasureId, metaclass=SingletonMeta):
    """Top element: the measurement ID (if any) cannot be determined statically."""

    def is_subseteq(self, other: MeasureId) -> bool:
        # Top is only contained in itself.
        return isinstance(other, AnyMeasureId)
48
+
49
+
50
@final
@dataclass
class NotMeasureId(MeasureId, metaclass=SingletonMeta):
    """A value that does not come from a measurement.

    ``MeasurementIDAnalysis.eval_stmt_fallback`` assigns this to the results
    of statements that have no dedicated ``measure_id`` impl.
    """

    def is_subseteq(self, other: MeasureId) -> bool:
        # Contained in itself and in the top element; every lattice element
        # must be below top, otherwise meet(NotMeasureId, AnyMeasureId) falls
        # through to bottom instead of returning NotMeasureId.
        return isinstance(other, (NotMeasureId, AnyMeasureId))
56
+
57
+
58
@final
@dataclass
class MeasureIdBool(MeasureId):
    """The boolean result of a single measurement.

    Attributes:
        idx: Measurement index assigned by ``MeasurementIDAnalysis``.
    """

    idx: int

    def is_subseteq(self, other: MeasureId) -> bool:
        # Every element is below top.
        if isinstance(other, AnyMeasureId):
            return True
        # Two measurement results are comparable only if they are the same one.
        if isinstance(other, MeasureIdBool):
            return self.idx == other.idx
        return False
67
+
68
+
69
+ # It might be nice to add a print/repr override here so that all the
+ # MeasureId variants are displayed consistently, for readability.
72
+
73
+
74
@final
@dataclass
class MeasureIdTuple(MeasureId):
    """A fixed-length collection (tuple/IList) of measurement IDs.

    Attributes:
        data: Per-element lattice values, in positional order.
    """

    data: tuple[MeasureId, ...]

    def is_subseteq(self, other: MeasureId) -> bool:
        # Every element is below top.
        if isinstance(other, AnyMeasureId):
            return True
        if isinstance(other, MeasureIdTuple):
            # Tuples of different lengths are incomparable; `zip` alone would
            # silently ignore the extra elements of the longer tuple.
            return len(self.data) == len(other.data) and all(
                a.is_subseteq(b) for a, b in zip(self.data, other.data)
            )
        return False
@@ -7,15 +7,22 @@ from kirin.rewrite import (
7
7
  ConstantFold,
8
8
  CommonSubexpressionElimination,
9
9
  )
10
+ from kirin.dialects import scf, func
10
11
 
11
- from ..rewrite.split_ifs import LiftThenBody, SplitIfStmts
12
+ from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts
13
+
14
+ from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
15
+ from ..dialects.core.stmts import Reset, Measure
16
+
17
# Statement types permitted to stay inside the *then* body of a QASM2 `if`.
AllowedThenType = (SingleQubitGate, TwoQubitCtrlGate, Measure, Reset)
# Statement types LiftThenBody must not hoist out of the then body: the
# allowed types above plus scf.Yield, func.Return and func.Invoke.
DontLiftType = (*AllowedThenType, scf.Yield, func.Return, func.Invoke)
12
19
 
13
20
 
14
21
  class UnrollIfs(Pass):
15
22
  """This pass lifts statements that are not UOP out of the if body and then splits whatever is left into multiple if statements so you obtain valid QASM2"""
16
23
 
17
24
  def unsafe_run(self, mt: ir.Method):
18
- result = Walk(LiftThenBody()).rewrite(mt.code)
25
+ result = Walk(LiftThenBody(exclude_stmts=DontLiftType)).rewrite(mt.code)
19
26
  result = Walk(SplitIfStmts()).rewrite(mt.code).join(result)
20
27
  result = (
21
28
  Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))
File without changes (bloqade/rewrite/__init__.py, +0 -0)
@@ -0,0 +1 @@
1
+ from .canonicalize_ilist import CanonicalizeIList as CanonicalizeIList
@@ -0,0 +1,28 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.passes import Pass
5
+ from kirin.rewrite import (
6
+ Walk,
7
+ Chain,
8
+ Fixpoint,
9
+ )
10
+ from kirin.analysis import const
11
+
12
+ from ..rules.flatten_ilist import FlattenAddOpIList
13
+ from ..rules.inline_getitem_ilist import InlineGetItemFromIList
14
+
15
+
16
@dataclass
class CanonicalizeIList(Pass):
    """Simplify IList-building code until no further rewrite applies.

    Constant propagation is run once up front. Then ``InlineGetItemFromIList``
    and ``FlattenAddOpIList`` are applied repeatedly over the method body.
    """

    def unsafe_run(self, mt: ir.Method):
        constprop = const.Propagate(dialects=mt.dialects)
        result_frame, _ = constprop.run_analysis(mt)

        inline_getitem = Walk(
            InlineGetItemFromIList(constprop_result=result_frame.entries)
        )
        flatten_add = Walk(FlattenAddOpIList())
        return Fixpoint(Chain(inline_getitem, flatten_add)).rewrite(mt.code)
@@ -0,0 +1 @@
1
+ from .split_ifs import LiftThenBody as LiftThenBody, SplitIfStmts as SplitIfStmts
@@ -0,0 +1,51 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.dialects import py, ilist
5
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
6
+
7
+
8
@dataclass
class FlattenAddOpIList(RewriteRule):
    """Fold ``lhs + rhs`` into a single ``ilist.New`` when both are literal lists.

    An operand qualifies if it is produced by an ``ilist.New`` statement, or is
    a ``py.Constant`` holding an empty ``IList``. If either operand does not
    qualify, the statement is left untouched.
    """

    @staticmethod
    def _literal_elements(value: ir.SSAValue) -> "tuple[ir.SSAValue, ...] | None":
        """Return the element SSA values of a literal list operand, or None."""
        owner = value.owner
        if isinstance(owner, ilist.New):
            return () + owner.values
        # An empty constant IList contributes no elements.
        if isinstance(owner, py.Constant):
            const_ilist = owner.value.unwrap()
            if isinstance(const_ilist, ilist.IList) and len(const_ilist.data) == 0:
                return ()
        return None

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, py.binop.Add):
            return RewriteResult()

        lhs_elements = self._literal_elements(node.lhs)
        if lhs_elements is None:
            return RewriteResult()
        rhs_elements = self._literal_elements(node.rhs)
        if rhs_elements is None:
            return RewriteResult()

        node.replace_by(ilist.New(values=lhs_elements + rhs_elements))
        return RewriteResult(has_done_something=True)
@@ -0,0 +1,31 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.analysis import const
5
+ from kirin.dialects import py, ilist
6
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
7
+
8
+
9
@dataclass
class InlineGetItemFromIList(RewriteRule):
    """Replace ``xs[i]`` by the i-th operand when ``xs`` comes from ``ilist.New``.

    Attributes:
        constprop_result: Constant-propagation results keyed by SSA value,
            used to look up a constant index.
    """

    constprop_result: dict[ir.SSAValue, const.Result]

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, py.indexing.GetItem):
            return RewriteResult()

        # Already inlined: nothing uses the result any more. Reporting a change
        # again would keep an enclosing Fixpoint from ever converging.
        if not node.result.uses:
            return RewriteResult()

        list_stmt = node.obj.owner
        if not isinstance(list_stmt, ilist.New):
            return RewriteResult()

        index_value = self.constprop_result.get(node.index)
        if not isinstance(index_value, const.Value):
            return RewriteResult()

        idx = index_value.data
        elements = list_stmt.values
        # Only an in-range integer index maps to a single SSA value; slices
        # (which would produce a tuple) and out-of-range indices are left alone.
        if not isinstance(idx, int) or not -len(elements) <= idx < len(elements):
            return RewriteResult()

        node.result.replace_by(elements[idx])
        return RewriteResult(has_done_something=True)
@@ -1,18 +1,23 @@
1
+ from dataclasses import field, dataclass
2
+
1
3
  from kirin import ir
2
4
  from kirin.dialects import scf, func
3
5
  from kirin.rewrite.abc import RewriteRule, RewriteResult
4
6
 
5
- from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
6
- from ..dialects.core.stmts import Reset, Measure
7
7
 
8
- # TODO: unify with PR #248
9
- AllowedThenType = SingleQubitGate | TwoQubitCtrlGate | Measure | Reset
8
+ @dataclass
9
+ class LiftThenBody(RewriteRule):
10
+ """
11
+ Lifts anything that's not in the `exclude_stmts` in the *then* body
12
+
10
13
 
11
- DontLiftType = AllowedThenType | scf.Yield | func.Return | func.Invoke
14
+ Args:
15
+ exclude_stmts: A tuple of statement types that should not be lifted from the then body.
16
+ Defaults to an empty tuple, meaning all statements are lifted.
12
17
 
18
+ """
13
19
 
14
- class LiftThenBody(RewriteRule):
15
- """Lifts anything that's not a UOP or a yield/return out of the then body"""
20
+ exclude_stmts: tuple[type[ir.Statement], ...] = field(default_factory=tuple)
16
21
 
17
22
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
18
23
  if not isinstance(node, scf.IfElse):
@@ -20,7 +25,9 @@ class LiftThenBody(RewriteRule):
20
25
 
21
26
  then_stmts = node.then_body.stmts()
22
27
 
23
- lift_stmts = [stmt for stmt in then_stmts if not isinstance(stmt, DontLiftType)]
28
+ lift_stmts = [
29
+ stmt for stmt in then_stmts if not isinstance(stmt, self.exclude_stmts)
30
+ ]
24
31
 
25
32
  if len(lift_stmts) == 0:
26
33
  return RewriteResult()
bloqade/squin/__init__.py CHANGED
@@ -3,6 +3,7 @@ from . import (
3
3
  wire as wire,
4
4
  noise as noise,
5
5
  qubit as qubit,
6
+ analysis as analysis,
6
7
  lowering as lowering,
7
8
  _typeinfer as _typeinfer,
8
9
  )
@@ -0,0 +1 @@
1
+ from . import address_impl as address_impl
@@ -0,0 +1,71 @@
1
+ from kirin import interp
2
+ from kirin.analysis import ForwardFrame
3
+
4
+ from bloqade.analysis.address.lattice import (
5
+ Address,
6
+ AddressReg,
7
+ AddressWire,
8
+ AddressQubit,
9
+ )
10
+ from bloqade.analysis.address.analysis import AddressAnalysis
11
+
12
+ from .. import wire, qubit
13
+
14
+ # Address lattice elements we can work with:
15
+ ## NotQubit (bottom), AnyAddress (top)
16
+
17
+ ## AddressTuple -> data: tuple[Address, ...]
18
+ ### Recursive type, could contain itself or other variants
19
+ ### This pops up in cases where you can have an IList/Tuple
20
+ ### That contains elements that could be other Address types
21
+
22
+ ## AddressReg -> data: Sequence[int]
23
+ ### specific to creation of a register of qubits
24
+
25
+ ## AddressQubit -> data: int
26
+ ### Base qubit address type
27
+
28
+
29
@wire.dialect.register(key="qubit.address")
class SquinWireMethodTable(interp.MethodTable):
    """Address-analysis impls for the squin wire dialect."""

    @interp.impl(wire.Unwrap)
    def unwrap(
        self,
        interp_: AddressAnalysis,
        frame: ForwardFrame[Address],
        stmt: wire.Unwrap,
    ):
        """Record which qubit a wire came from, or top if that qubit is unknown."""
        origin_qubit = frame.get(stmt.qubit)
        if not isinstance(origin_qubit, AddressQubit):
            return (Address.top(),)
        return (AddressWire(origin_qubit=origin_qubit),)

    @interp.impl(wire.Apply)
    def apply(
        self,
        interp_: AddressAnalysis,
        frame: ForwardFrame[Address],
        stmt: wire.Apply,
    ):
        """Output wires carry the input wires' addresses, in the same order."""
        return frame.get_values(stmt.inputs)
55
+
56
+
57
@qubit.dialect.register(key="qubit.address")
class SquinQubitMethodTable(interp.MethodTable):
    """Address-analysis impls for the squin qubit dialect."""

    # Treated like a QRegNew impl: allocate a contiguous block of addresses.
    @interp.impl(qubit.New)
    def new(
        self,
        interp_: AddressAnalysis,
        frame: ForwardFrame[Address],
        stmt: qubit.New,
    ):
        """Reserve the next ``n_qubits`` consecutive addresses as a register.

        ``n_qubits`` must be a compile-time constant int. Presumably
        ``get_const_value`` raises otherwise, as the MeasurementIDAnalysis
        copy of that helper does (confirm).
        """
        n_qubits = interp_.get_const_value(int, stmt.n_qubits)
        start = interp_.next_address
        interp_.next_address = start + n_qubits
        return (AddressReg(range(start, start + n_qubits)),)
@@ -281,7 +281,8 @@ class Squin(lowering.LoweringABC[CirqNode]):
281
281
  return state.current_frame.push(op.stmts.Z())
282
282
 
283
283
  # NOTE: just for the Z gate, an arbitrary exponent is equivalent to the ShiftOp
284
- t = node.exponent
284
+ # up to a minus sign!
285
+ t = -node.exponent
285
286
  theta = state.current_frame.push(py.Constant(math.pi * t))
286
287
  return state.current_frame.push(op.stmts.ShiftOp(theta=theta.result))
287
288