bloqade-circuit 0.4.5__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic. Click here for more details.
- bloqade/analysis/address/impls.py +21 -68
- bloqade/analysis/measure_id/__init__.py +2 -0
- bloqade/analysis/measure_id/analysis.py +45 -0
- bloqade/analysis/measure_id/impls.py +155 -0
- bloqade/analysis/measure_id/lattice.py +82 -0
- bloqade/cirq_utils/__init__.py +7 -0
- bloqade/cirq_utils/lineprog.py +295 -0
- bloqade/cirq_utils/parallelize.py +400 -0
- bloqade/pyqrack/squin/op.py +7 -2
- bloqade/pyqrack/squin/runtime.py +4 -2
- bloqade/qasm2/dialects/expr/stmts.py +2 -20
- bloqade/qasm2/parse/lowering.py +1 -0
- bloqade/qasm2/passes/parallel.py +18 -0
- bloqade/qasm2/passes/unroll_if.py +9 -2
- bloqade/qasm2/rewrite/__init__.py +1 -0
- bloqade/qasm2/rewrite/parallel_to_glob.py +82 -0
- bloqade/rewrite/__init__.py +0 -0
- bloqade/rewrite/passes/__init__.py +1 -0
- bloqade/rewrite/passes/canonicalize_ilist.py +28 -0
- bloqade/rewrite/rules/__init__.py +1 -0
- bloqade/rewrite/rules/flatten_ilist.py +51 -0
- bloqade/rewrite/rules/inline_getitem_ilist.py +31 -0
- bloqade/{qasm2/rewrite → rewrite/rules}/split_ifs.py +15 -8
- bloqade/squin/__init__.py +2 -0
- bloqade/squin/_typeinfer.py +20 -0
- bloqade/squin/analysis/__init__.py +1 -0
- bloqade/squin/analysis/address_impl.py +71 -0
- bloqade/squin/analysis/nsites/impls.py +6 -1
- bloqade/squin/cirq/lowering.py +19 -6
- bloqade/squin/noise/stmts.py +1 -1
- bloqade/squin/op/__init__.py +1 -0
- bloqade/squin/op/_wrapper.py +4 -0
- bloqade/squin/op/stmts.py +20 -2
- bloqade/squin/qubit.py +8 -5
- bloqade/squin/rewrite/__init__.py +1 -0
- bloqade/squin/rewrite/canonicalize.py +60 -0
- bloqade/squin/rewrite/desugar.py +52 -5
- bloqade/squin/types.py +8 -0
- bloqade/squin/wire.py +91 -5
- bloqade/stim/__init__.py +1 -0
- bloqade/stim/_wrappers.py +4 -0
- bloqade/stim/dialects/auxiliary/interp.py +0 -10
- bloqade/stim/dialects/auxiliary/stmts/annotate.py +1 -1
- bloqade/stim/dialects/noise/emit.py +1 -0
- bloqade/stim/dialects/noise/stmts.py +5 -0
- bloqade/stim/passes/__init__.py +1 -1
- bloqade/stim/passes/simplify_ifs.py +32 -0
- bloqade/stim/passes/squin_to_stim.py +109 -26
- bloqade/stim/rewrite/__init__.py +1 -0
- bloqade/stim/rewrite/ifs_to_stim.py +203 -0
- bloqade/stim/rewrite/qubit_to_stim.py +13 -6
- bloqade/stim/rewrite/squin_measure.py +68 -5
- bloqade/stim/rewrite/squin_noise.py +120 -0
- bloqade/stim/rewrite/util.py +40 -9
- bloqade/stim/rewrite/wire_to_stim.py +8 -3
- bloqade/stim/upstream/__init__.py +1 -0
- bloqade/stim/upstream/from_squin.py +10 -0
- {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/METADATA +4 -2
- {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/RECORD +61 -38
- {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/licenses/LICENSE +0 -0
bloqade/pyqrack/squin/op.py
CHANGED
|
@@ -96,10 +96,15 @@ class PyQrackMethods(interp.MethodTable):
|
|
|
96
96
|
return (PhaseOpRuntime(theta, global_=global_),)
|
|
97
97
|
|
|
98
98
|
@interp.impl(op.stmts.Reset)
|
|
99
|
+
@interp.impl(op.stmts.ResetToOne)
|
|
99
100
|
def reset(
|
|
100
|
-
self,
|
|
101
|
+
self,
|
|
102
|
+
interp: PyQrackInterpreter,
|
|
103
|
+
frame: interp.Frame,
|
|
104
|
+
stmt: op.stmts.Reset | op.stmts.ResetToOne,
|
|
101
105
|
) -> tuple[OperatorRuntimeABC]:
|
|
102
|
-
|
|
106
|
+
target_state = isinstance(stmt, op.stmts.ResetToOne)
|
|
107
|
+
return (ResetRuntime(target_state=target_state),)
|
|
103
108
|
|
|
104
109
|
@interp.impl(op.stmts.X)
|
|
105
110
|
@interp.impl(op.stmts.Y)
|
bloqade/pyqrack/squin/runtime.py
CHANGED
|
@@ -43,7 +43,9 @@ class OperatorRuntimeABC:
|
|
|
43
43
|
|
|
44
44
|
@dataclass(frozen=True)
|
|
45
45
|
class ResetRuntime(OperatorRuntimeABC):
|
|
46
|
-
"""Reset the qubit to
|
|
46
|
+
"""Reset the qubit to the target state"""
|
|
47
|
+
|
|
48
|
+
target_state: bool
|
|
47
49
|
|
|
48
50
|
@property
|
|
49
51
|
def n_sites(self) -> int:
|
|
@@ -55,7 +57,7 @@ class ResetRuntime(OperatorRuntimeABC):
|
|
|
55
57
|
continue
|
|
56
58
|
|
|
57
59
|
res: bool = qubit.sim_reg.m(qubit.addr)
|
|
58
|
-
if res:
|
|
60
|
+
if res != self.target_state:
|
|
59
61
|
qubit.sim_reg.x(qubit.addr)
|
|
60
62
|
|
|
61
63
|
|
|
@@ -1,34 +1,16 @@
|
|
|
1
1
|
from kirin import ir, types, lowering
|
|
2
2
|
from kirin.decl import info, statement
|
|
3
|
+
from kirin.dialects import func
|
|
3
4
|
from kirin.print.printer import Printer
|
|
4
|
-
from kirin.dialects.func.attrs import Signature
|
|
5
5
|
|
|
6
6
|
from ._dialect import dialect
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
class GateFuncOpCallableInterface(ir.CallableStmtInterface["GateFunction"]):
|
|
10
|
-
|
|
11
|
-
@classmethod
|
|
12
|
-
def get_callable_region(cls, stmt: "GateFunction") -> ir.Region:
|
|
13
|
-
return stmt.body
|
|
14
|
-
|
|
15
|
-
|
|
16
9
|
@statement(dialect=dialect)
|
|
17
|
-
class GateFunction(
|
|
10
|
+
class GateFunction(func.Function):
|
|
18
11
|
"""Special Function for qasm2 gate subroutine."""
|
|
19
12
|
|
|
20
13
|
name = "gate.func"
|
|
21
|
-
traits = frozenset(
|
|
22
|
-
{
|
|
23
|
-
ir.IsolatedFromAbove(),
|
|
24
|
-
ir.SymbolOpInterface(),
|
|
25
|
-
ir.HasSignature(),
|
|
26
|
-
GateFuncOpCallableInterface(),
|
|
27
|
-
}
|
|
28
|
-
)
|
|
29
|
-
sym_name: str = info.attribute()
|
|
30
|
-
signature: Signature = info.attribute()
|
|
31
|
-
body: ir.Region = info.region(multi=True)
|
|
32
14
|
|
|
33
15
|
def print_impl(self, printer: Printer) -> None:
|
|
34
16
|
with printer.rich(style="red"):
|
bloqade/qasm2/parse/lowering.py
CHANGED
bloqade/qasm2/passes/parallel.py
CHANGED
|
@@ -26,6 +26,7 @@ from bloqade.qasm2.rewrite import (
|
|
|
26
26
|
ParallelToUOpRule,
|
|
27
27
|
RaiseRegisterRule,
|
|
28
28
|
UOpToParallelRule,
|
|
29
|
+
ParallelToGlobalRule,
|
|
29
30
|
SimpleOptimalMergePolicy,
|
|
30
31
|
RydbergGateSetRewriteRule,
|
|
31
32
|
)
|
|
@@ -183,3 +184,20 @@ class UOpToParallel(Pass):
|
|
|
183
184
|
CommonSubexpressionElimination(),
|
|
184
185
|
)
|
|
185
186
|
return Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@dataclass
|
|
190
|
+
class ParallelToGlobal(Pass):
|
|
191
|
+
|
|
192
|
+
def generate_rule(self, mt: ir.Method) -> ParallelToGlobalRule:
|
|
193
|
+
address_analysis = address.AddressAnalysis(mt.dialects)
|
|
194
|
+
frame, _ = address_analysis.run_analysis(mt)
|
|
195
|
+
return ParallelToGlobalRule(frame.entries)
|
|
196
|
+
|
|
197
|
+
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
|
|
198
|
+
rule = self.generate_rule(mt)
|
|
199
|
+
|
|
200
|
+
result = Walk(rule).rewrite(mt.code)
|
|
201
|
+
result = Walk(DeadCodeElimination()).rewrite(mt.code).join(result)
|
|
202
|
+
|
|
203
|
+
return result
|
|
@@ -7,15 +7,22 @@ from kirin.rewrite import (
|
|
|
7
7
|
ConstantFold,
|
|
8
8
|
CommonSubexpressionElimination,
|
|
9
9
|
)
|
|
10
|
+
from kirin.dialects import scf, func
|
|
10
11
|
|
|
11
|
-
from
|
|
12
|
+
from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts
|
|
13
|
+
|
|
14
|
+
from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
|
|
15
|
+
from ..dialects.core.stmts import Reset, Measure
|
|
16
|
+
|
|
17
|
+
AllowedThenType = (SingleQubitGate, TwoQubitCtrlGate, Measure, Reset)
|
|
18
|
+
DontLiftType = AllowedThenType + (scf.Yield, func.Return, func.Invoke)
|
|
12
19
|
|
|
13
20
|
|
|
14
21
|
class UnrollIfs(Pass):
|
|
15
22
|
"""This pass lifts statements that are not UOP out of the if body and then splits whatever is left into multiple if statements so you obtain valid QASM2"""
|
|
16
23
|
|
|
17
24
|
def unsafe_run(self, mt: ir.Method):
|
|
18
|
-
result = Walk(LiftThenBody()).rewrite(mt.code)
|
|
25
|
+
result = Walk(LiftThenBody(exclude_stmts=DontLiftType)).rewrite(mt.code)
|
|
19
26
|
result = Walk(SplitIfStmts()).rewrite(mt.code).join(result)
|
|
20
27
|
result = (
|
|
21
28
|
Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))
|
|
@@ -11,5 +11,6 @@ from .uop_to_parallel import (
|
|
|
11
11
|
SimpleGreedyMergePolicy as SimpleGreedyMergePolicy,
|
|
12
12
|
SimpleOptimalMergePolicy as SimpleOptimalMergePolicy,
|
|
13
13
|
)
|
|
14
|
+
from .parallel_to_glob import ParallelToGlobalRule as ParallelToGlobalRule
|
|
14
15
|
from .noise.remove_noise import RemoveNoisePass as RemoveNoisePass
|
|
15
16
|
from .noise.heuristic_noise import NoiseRewriteRule as NoiseRewriteRule
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from typing import Dict
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
from kirin import ir
|
|
5
|
+
from kirin.rewrite import abc
|
|
6
|
+
from kirin.analysis import const
|
|
7
|
+
from kirin.dialects import ilist
|
|
8
|
+
|
|
9
|
+
from bloqade.analysis import address
|
|
10
|
+
|
|
11
|
+
from ..dialects import core, glob, parallel
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
|
|
15
|
+
class ParallelToGlobalRule(abc.RewriteRule):
|
|
16
|
+
address_analysis: Dict[ir.SSAValue, address.Address]
|
|
17
|
+
|
|
18
|
+
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
|
|
19
|
+
if not isinstance(node, parallel.UGate):
|
|
20
|
+
return abc.RewriteResult()
|
|
21
|
+
|
|
22
|
+
qargs = node.qargs
|
|
23
|
+
qarg_addresses = self.address_analysis.get(qargs, None)
|
|
24
|
+
|
|
25
|
+
if isinstance(qarg_addresses, address.AddressReg):
|
|
26
|
+
# NOTE: we only have an AddressReg if it's an entire register, definitely rewrite that
|
|
27
|
+
return self._rewrite_parallel_to_glob(node)
|
|
28
|
+
|
|
29
|
+
if not isinstance(qarg_addresses, address.AddressTuple):
|
|
30
|
+
return abc.RewriteResult()
|
|
31
|
+
|
|
32
|
+
idxs, qreg = self._find_qreg(qargs.owner, set())
|
|
33
|
+
|
|
34
|
+
if qreg is None:
|
|
35
|
+
# NOTE: no unique register found
|
|
36
|
+
return abc.RewriteResult()
|
|
37
|
+
|
|
38
|
+
if not isinstance(hint := qreg.n_qubits.hints.get("const"), const.Value):
|
|
39
|
+
# NOTE: non-constant number of qubits
|
|
40
|
+
return abc.RewriteResult()
|
|
41
|
+
|
|
42
|
+
n = hint.data
|
|
43
|
+
if len(idxs) != n:
|
|
44
|
+
# NOTE: not all qubits of the register are there
|
|
45
|
+
return abc.RewriteResult()
|
|
46
|
+
|
|
47
|
+
return self._rewrite_parallel_to_glob(node)
|
|
48
|
+
|
|
49
|
+
@staticmethod
|
|
50
|
+
def _rewrite_parallel_to_glob(node: parallel.UGate) -> abc.RewriteResult:
|
|
51
|
+
theta, phi, lam = node.theta, node.phi, node.lam
|
|
52
|
+
global_u = glob.UGate(node.qargs, theta=theta, phi=phi, lam=lam)
|
|
53
|
+
node.replace_by(global_u)
|
|
54
|
+
return abc.RewriteResult(has_done_something=True)
|
|
55
|
+
|
|
56
|
+
@staticmethod
|
|
57
|
+
def _find_qreg(
|
|
58
|
+
qargs_owner: ir.Statement | ir.Block, idxs: set
|
|
59
|
+
) -> tuple[set, core.stmts.QRegNew | None]:
|
|
60
|
+
|
|
61
|
+
if isinstance(qargs_owner, core.stmts.QRegGet):
|
|
62
|
+
idxs.add(qargs_owner.idx)
|
|
63
|
+
qreg = qargs_owner.reg.owner
|
|
64
|
+
if not isinstance(qreg, core.stmts.QRegNew):
|
|
65
|
+
# NOTE: this could potentially be casted
|
|
66
|
+
qreg = None
|
|
67
|
+
return idxs, qreg
|
|
68
|
+
|
|
69
|
+
if isinstance(qargs_owner, ilist.New):
|
|
70
|
+
vals = qargs_owner.values
|
|
71
|
+
if len(vals) == 0:
|
|
72
|
+
return idxs, None
|
|
73
|
+
|
|
74
|
+
idxs, first_qreg = ParallelToGlobalRule._find_qreg(vals[0].owner, idxs)
|
|
75
|
+
for val in vals[1:]:
|
|
76
|
+
idxs, qreg = ParallelToGlobalRule._find_qreg(val.owner, idxs)
|
|
77
|
+
if qreg != first_qreg:
|
|
78
|
+
return idxs, None
|
|
79
|
+
|
|
80
|
+
return idxs, first_qreg
|
|
81
|
+
|
|
82
|
+
return idxs, None
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .canonicalize_ilist import CanonicalizeIList as CanonicalizeIList
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.passes import Pass
|
|
5
|
+
from kirin.rewrite import (
|
|
6
|
+
Walk,
|
|
7
|
+
Chain,
|
|
8
|
+
Fixpoint,
|
|
9
|
+
)
|
|
10
|
+
from kirin.analysis import const
|
|
11
|
+
|
|
12
|
+
from ..rules.flatten_ilist import FlattenAddOpIList
|
|
13
|
+
from ..rules.inline_getitem_ilist import InlineGetItemFromIList
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
|
|
17
|
+
class CanonicalizeIList(Pass):
|
|
18
|
+
|
|
19
|
+
def unsafe_run(self, mt: ir.Method):
|
|
20
|
+
|
|
21
|
+
cp_result_frame, _ = const.Propagate(dialects=mt.dialects).run_analysis(mt)
|
|
22
|
+
|
|
23
|
+
return Fixpoint(
|
|
24
|
+
Chain(
|
|
25
|
+
Walk(InlineGetItemFromIList(constprop_result=cp_result_frame.entries)),
|
|
26
|
+
Walk(FlattenAddOpIList()),
|
|
27
|
+
)
|
|
28
|
+
).rewrite(mt.code)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .split_ifs import LiftThenBody as LiftThenBody, SplitIfStmts as SplitIfStmts
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.dialects import py, ilist
|
|
5
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
|
|
9
|
+
class FlattenAddOpIList(RewriteRule):
|
|
10
|
+
|
|
11
|
+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
12
|
+
if not isinstance(node, py.binop.Add):
|
|
13
|
+
return RewriteResult()
|
|
14
|
+
|
|
15
|
+
# check if we are adding two ilist.New objects
|
|
16
|
+
new_data = ()
|
|
17
|
+
|
|
18
|
+
# lhs:
|
|
19
|
+
if not isinstance(node.lhs.owner, ilist.New):
|
|
20
|
+
if not (
|
|
21
|
+
isinstance(node.lhs.owner, py.Constant)
|
|
22
|
+
and isinstance(
|
|
23
|
+
const_ilist := node.lhs.owner.value.unwrap(), ilist.IList
|
|
24
|
+
)
|
|
25
|
+
and len(const_ilist.data) == 0
|
|
26
|
+
):
|
|
27
|
+
return RewriteResult()
|
|
28
|
+
|
|
29
|
+
else:
|
|
30
|
+
new_data += node.lhs.owner.values
|
|
31
|
+
|
|
32
|
+
# rhs:
|
|
33
|
+
if not isinstance(node.rhs.owner, ilist.New):
|
|
34
|
+
if not (
|
|
35
|
+
isinstance(node.rhs.owner, py.Constant)
|
|
36
|
+
and isinstance(
|
|
37
|
+
const_ilist := node.rhs.owner.value.unwrap(), ilist.IList
|
|
38
|
+
)
|
|
39
|
+
and len(const_ilist.data) == 0
|
|
40
|
+
):
|
|
41
|
+
return RewriteResult()
|
|
42
|
+
|
|
43
|
+
else:
|
|
44
|
+
new_data += node.rhs.owner.values
|
|
45
|
+
|
|
46
|
+
new_stmt = ilist.New(values=new_data)
|
|
47
|
+
node.replace_by(new_stmt)
|
|
48
|
+
|
|
49
|
+
return RewriteResult(
|
|
50
|
+
has_done_something=True,
|
|
51
|
+
)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.analysis import const
|
|
5
|
+
from kirin.dialects import py, ilist
|
|
6
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
|
|
10
|
+
class InlineGetItemFromIList(RewriteRule):
|
|
11
|
+
constprop_result: dict[ir.SSAValue, const.Result]
|
|
12
|
+
|
|
13
|
+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
14
|
+
if not isinstance(node, py.indexing.GetItem):
|
|
15
|
+
return RewriteResult()
|
|
16
|
+
|
|
17
|
+
if not isinstance(node.obj.owner, ilist.New):
|
|
18
|
+
return RewriteResult()
|
|
19
|
+
|
|
20
|
+
if not isinstance(
|
|
21
|
+
index_value := self.constprop_result.get(node.index), const.Value
|
|
22
|
+
):
|
|
23
|
+
return RewriteResult()
|
|
24
|
+
|
|
25
|
+
elem_ssa = node.obj.owner.values[index_value.data]
|
|
26
|
+
|
|
27
|
+
node.result.replace_by(elem_ssa)
|
|
28
|
+
|
|
29
|
+
return RewriteResult(
|
|
30
|
+
has_done_something=True,
|
|
31
|
+
)
|
|
@@ -1,18 +1,23 @@
|
|
|
1
|
+
from dataclasses import field, dataclass
|
|
2
|
+
|
|
1
3
|
from kirin import ir
|
|
2
4
|
from kirin.dialects import scf, func
|
|
3
5
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
4
6
|
|
|
5
|
-
from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
|
|
6
|
-
from ..dialects.core.stmts import Reset, Measure
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
|
|
8
|
+
@dataclass
|
|
9
|
+
class LiftThenBody(RewriteRule):
|
|
10
|
+
"""
|
|
11
|
+
Lifts anything that's not in the `exclude_stmts` in the *then* body
|
|
12
|
+
|
|
10
13
|
|
|
11
|
-
|
|
14
|
+
Args:
|
|
15
|
+
exclude_stmts: A tuple of statement types that should not be lifted from the then body.
|
|
16
|
+
Defaults to an empty tuple, meaning all statements are lifted.
|
|
12
17
|
|
|
18
|
+
"""
|
|
13
19
|
|
|
14
|
-
|
|
15
|
-
"""Lifts anything that's not a UOP or a yield/return out of the then body"""
|
|
20
|
+
exclude_stmts: tuple[type[ir.Statement], ...] = field(default_factory=tuple)
|
|
16
21
|
|
|
17
22
|
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
18
23
|
if not isinstance(node, scf.IfElse):
|
|
@@ -20,7 +25,9 @@ class LiftThenBody(RewriteRule):
|
|
|
20
25
|
|
|
21
26
|
then_stmts = node.then_body.stmts()
|
|
22
27
|
|
|
23
|
-
lift_stmts = [
|
|
28
|
+
lift_stmts = [
|
|
29
|
+
stmt for stmt in then_stmts if not isinstance(stmt, self.exclude_stmts)
|
|
30
|
+
]
|
|
24
31
|
|
|
25
32
|
if len(lift_stmts) == 0:
|
|
26
33
|
return RewriteResult()
|
bloqade/squin/__init__.py
CHANGED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from kirin import types, interp
|
|
2
|
+
from kirin.analysis import TypeInference, const
|
|
3
|
+
from kirin.dialects import ilist
|
|
4
|
+
|
|
5
|
+
from bloqade import squin
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@squin.qubit.dialect.register(key="typeinfer")
|
|
9
|
+
class TypeInfer(interp.MethodTable):
|
|
10
|
+
@interp.impl(squin.qubit.New)
|
|
11
|
+
def _call(self, interp: TypeInference, frame: interp.Frame, stmt: squin.qubit.New):
|
|
12
|
+
# based on Xiu-zhe (Roger) Luo's get_const_value function
|
|
13
|
+
|
|
14
|
+
if (hint := stmt.n_qubits.hints.get("const")) is None:
|
|
15
|
+
return (ilist.IListType[squin.qubit.QubitType, types.Any],)
|
|
16
|
+
|
|
17
|
+
if isinstance(hint, const.Value) and isinstance(hint.data, int):
|
|
18
|
+
return (ilist.IListType[squin.qubit.QubitType, types.Literal(hint.data)],)
|
|
19
|
+
|
|
20
|
+
return (ilist.IListType[squin.qubit.QubitType, types.Any],)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from . import address_impl as address_impl
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
from kirin import interp
|
|
2
|
+
from kirin.analysis import ForwardFrame
|
|
3
|
+
|
|
4
|
+
from bloqade.analysis.address.lattice import (
|
|
5
|
+
Address,
|
|
6
|
+
AddressReg,
|
|
7
|
+
AddressWire,
|
|
8
|
+
AddressQubit,
|
|
9
|
+
)
|
|
10
|
+
from bloqade.analysis.address.analysis import AddressAnalysis
|
|
11
|
+
|
|
12
|
+
from .. import wire, qubit
|
|
13
|
+
|
|
14
|
+
# Address lattice elements we can work with:
|
|
15
|
+
## NotQubit (bottom), AnyAddress (top)
|
|
16
|
+
|
|
17
|
+
## AddressTuple -> data: tuple[Address, ...]
|
|
18
|
+
### Recursive type, could contain itself or other variants
|
|
19
|
+
### This pops up in cases where you can have an IList/Tuple
|
|
20
|
+
### That contains elements that could be other Address types
|
|
21
|
+
|
|
22
|
+
## AddressReg -> data: Sequence[int]
|
|
23
|
+
### specific to creation of a register of qubits
|
|
24
|
+
|
|
25
|
+
## AddressQubit -> data: int
|
|
26
|
+
### Base qubit address type
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@wire.dialect.register(key="qubit.address")
|
|
30
|
+
class SquinWireMethodTable(interp.MethodTable):
|
|
31
|
+
|
|
32
|
+
@interp.impl(wire.Unwrap)
|
|
33
|
+
def unwrap(
|
|
34
|
+
self,
|
|
35
|
+
interp_: AddressAnalysis,
|
|
36
|
+
frame: ForwardFrame[Address],
|
|
37
|
+
stmt: wire.Unwrap,
|
|
38
|
+
):
|
|
39
|
+
|
|
40
|
+
origin_qubit = frame.get(stmt.qubit)
|
|
41
|
+
|
|
42
|
+
if isinstance(origin_qubit, AddressQubit):
|
|
43
|
+
return (AddressWire(origin_qubit=origin_qubit),)
|
|
44
|
+
else:
|
|
45
|
+
return (Address.top(),)
|
|
46
|
+
|
|
47
|
+
@interp.impl(wire.Apply)
|
|
48
|
+
def apply(
|
|
49
|
+
self,
|
|
50
|
+
interp_: AddressAnalysis,
|
|
51
|
+
frame: ForwardFrame[Address],
|
|
52
|
+
stmt: wire.Apply,
|
|
53
|
+
):
|
|
54
|
+
return frame.get_values(stmt.inputs)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@qubit.dialect.register(key="qubit.address")
|
|
58
|
+
class SquinQubitMethodTable(interp.MethodTable):
|
|
59
|
+
|
|
60
|
+
# This can be treated like a QRegNew impl
|
|
61
|
+
@interp.impl(qubit.New)
|
|
62
|
+
def new(
|
|
63
|
+
self,
|
|
64
|
+
interp_: AddressAnalysis,
|
|
65
|
+
frame: ForwardFrame[Address],
|
|
66
|
+
stmt: qubit.New,
|
|
67
|
+
):
|
|
68
|
+
n_qubits = interp_.get_const_value(int, stmt.n_qubits)
|
|
69
|
+
addr = AddressReg(range(interp_.next_address, interp_.next_address + n_qubits))
|
|
70
|
+
interp_.next_address += n_qubits
|
|
71
|
+
return (addr,)
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from kirin import interp
|
|
2
|
-
from kirin.dialects import scf
|
|
2
|
+
from kirin.dialects import scf, func
|
|
3
3
|
from kirin.dialects.scf.typeinfer import TypeInfer as ScfTypeInfer
|
|
4
4
|
|
|
5
5
|
from bloqade.squin import op, wire
|
|
@@ -85,3 +85,8 @@ class SquinOp(interp.MethodTable):
|
|
|
85
85
|
@scf.dialect.register(key="op.nsites")
|
|
86
86
|
class ScfSquinOp(ScfTypeInfer):
|
|
87
87
|
pass
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@func.dialect.register(key="op.nsites")
|
|
91
|
+
class FuncSquinOp(func.typeinfer.TypeInfer):
|
|
92
|
+
pass
|
bloqade/squin/cirq/lowering.py
CHANGED
|
@@ -368,11 +368,24 @@ class Squin(lowering.LoweringABC[CirqNode]):
|
|
|
368
368
|
state: lowering.State[CirqNode],
|
|
369
369
|
node: cirq.GeneralizedAmplitudeDampingChannel,
|
|
370
370
|
):
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
# gamma = state.current_frame.push(py.Constant(node.gamma))
|
|
371
|
+
p = state.current_frame.push(py.Constant(node.p)).result
|
|
372
|
+
gamma = state.current_frame.push(py.Constant(node.gamma)).result
|
|
374
373
|
|
|
375
|
-
#
|
|
374
|
+
# NOTE: cirq has a weird convention here: if p == 1, we have AmplitudeDampingChannel,
|
|
375
|
+
# which basically means p is the probability of the environment being in the vacuum state
|
|
376
|
+
prob0 = state.current_frame.push(py.binop.Mult(p, gamma)).result
|
|
377
|
+
one_ = state.current_frame.push(py.Constant(1)).result
|
|
378
|
+
p_minus_1 = state.current_frame.push(py.binop.Sub(one_, p)).result
|
|
379
|
+
prob1 = state.current_frame.push(py.binop.Mult(p_minus_1, gamma)).result
|
|
376
380
|
|
|
377
|
-
|
|
378
|
-
|
|
381
|
+
r0 = state.current_frame.push(op.stmts.Reset()).result
|
|
382
|
+
r1 = state.current_frame.push(op.stmts.ResetToOne()).result
|
|
383
|
+
|
|
384
|
+
probs = state.current_frame.push(ilist.New(values=(prob0, prob1))).result
|
|
385
|
+
ops = state.current_frame.push(ilist.New(values=(r0, r1))).result
|
|
386
|
+
|
|
387
|
+
noise_channel = state.current_frame.push(
|
|
388
|
+
noise.stmts.StochasticUnitaryChannel(probabilities=probs, operators=ops)
|
|
389
|
+
)
|
|
390
|
+
|
|
391
|
+
return noise_channel
|
bloqade/squin/noise/stmts.py
CHANGED
bloqade/squin/op/__init__.py
CHANGED
bloqade/squin/op/_wrapper.py
CHANGED
bloqade/squin/op/stmts.py
CHANGED
|
@@ -98,6 +98,15 @@ class ConstantUnitary(ConstantOp):
|
|
|
98
98
|
|
|
99
99
|
@statement(dialect=dialect)
|
|
100
100
|
class U3(PrimitiveOp):
|
|
101
|
+
"""
|
|
102
|
+
The rotation operator U3(theta, phi, lam).
|
|
103
|
+
Note that we use the convention from the QASM2 specification, namely
|
|
104
|
+
|
|
105
|
+
$$
|
|
106
|
+
U_3(\theta, \phi, \lambda) = R_z(\phi) R_y(\theta) R_z(\lambda)
|
|
107
|
+
$$
|
|
108
|
+
"""
|
|
109
|
+
|
|
101
110
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), FixedSites(1)})
|
|
102
111
|
theta: ir.SSAValue = info.argument(types.Float)
|
|
103
112
|
phi: ir.SSAValue = info.argument(types.Float)
|
|
@@ -110,7 +119,7 @@ class PhaseOp(PrimitiveOp):
|
|
|
110
119
|
A phase operator.
|
|
111
120
|
|
|
112
121
|
$$
|
|
113
|
-
PhaseOp(theta) = e^{i \theta} I
|
|
122
|
+
PhaseOp(\theta) = e^{i \theta} I
|
|
114
123
|
$$
|
|
115
124
|
"""
|
|
116
125
|
|
|
@@ -124,7 +133,7 @@ class ShiftOp(PrimitiveOp):
|
|
|
124
133
|
A phase shift operator.
|
|
125
134
|
|
|
126
135
|
$$
|
|
127
|
-
Shift(theta) = \\begin{bmatrix} 1 & 0 \\\\ 0 & e^{i \\theta} \\end{bmatrix}
|
|
136
|
+
Shift(\theta) = \\begin{bmatrix} 1 & 0 \\\\ 0 & e^{i \\theta} \\end{bmatrix}
|
|
128
137
|
$$
|
|
129
138
|
"""
|
|
130
139
|
|
|
@@ -141,6 +150,15 @@ class Reset(PrimitiveOp):
|
|
|
141
150
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), FixedSites(1)})
|
|
142
151
|
|
|
143
152
|
|
|
153
|
+
@statement(dialect=dialect)
|
|
154
|
+
class ResetToOne(PrimitiveOp):
|
|
155
|
+
"""
|
|
156
|
+
Reset qubits to the one state. Mainly needed to accommodate cirq's GeneralizedAmplitudeDampingChannel
|
|
157
|
+
"""
|
|
158
|
+
|
|
159
|
+
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), FixedSites(1)})
|
|
160
|
+
|
|
161
|
+
|
|
144
162
|
@statement
|
|
145
163
|
class CliffordOp(ConstantUnitary):
|
|
146
164
|
pass
|
bloqade/squin/qubit.py
CHANGED
|
@@ -17,6 +17,7 @@ from kirin.lowering import wraps
|
|
|
17
17
|
from bloqade.types import Qubit, QubitType
|
|
18
18
|
from bloqade.squin.op.types import Op, OpType
|
|
19
19
|
|
|
20
|
+
from .types import MeasurementResult, MeasurementResultType
|
|
20
21
|
from .lowering import ApplyAnyCallLowering
|
|
21
22
|
|
|
22
23
|
dialect = ir.Dialect("squin.qubit")
|
|
@@ -65,8 +66,8 @@ class MeasureQubit(ir.Statement):
|
|
|
65
66
|
name = "measure.qubit"
|
|
66
67
|
|
|
67
68
|
traits = frozenset({lowering.FromPythonCall()})
|
|
68
|
-
qubit: ir.SSAValue = info.argument(
|
|
69
|
-
result: ir.ResultValue = info.result(
|
|
69
|
+
qubit: ir.SSAValue = info.argument(QubitType)
|
|
70
|
+
result: ir.ResultValue = info.result(MeasurementResultType)
|
|
70
71
|
|
|
71
72
|
|
|
72
73
|
@statement(dialect=dialect)
|
|
@@ -75,7 +76,7 @@ class MeasureQubitList(ir.Statement):
|
|
|
75
76
|
|
|
76
77
|
traits = frozenset({lowering.FromPythonCall()})
|
|
77
78
|
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
|
|
78
|
-
result: ir.ResultValue = info.result(ilist.IListType[
|
|
79
|
+
result: ir.ResultValue = info.result(ilist.IListType[MeasurementResultType])
|
|
79
80
|
|
|
80
81
|
|
|
81
82
|
# NOTE: no dependent types in Python, so we have to mark it Any...
|
|
@@ -131,9 +132,11 @@ def apply(operator: Op, *qubits) -> None: ...
|
|
|
131
132
|
|
|
132
133
|
|
|
133
134
|
@overload
|
|
134
|
-
def measure(input: Qubit) ->
|
|
135
|
+
def measure(input: Qubit) -> MeasurementResult: ...
|
|
135
136
|
@overload
|
|
136
|
-
def measure(
|
|
137
|
+
def measure(
|
|
138
|
+
input: ilist.IList[Qubit, Any] | list[Qubit],
|
|
139
|
+
) -> ilist.IList[MeasurementResult, Any]: ...
|
|
137
140
|
|
|
138
141
|
|
|
139
142
|
@wraps(MeasureAny)
|