bloqade-circuit 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic. Click here for more details.
- bloqade/analysis/address/impls.py +14 -0
- bloqade/analysis/fidelity/analysis.py +27 -2
- bloqade/noise/fidelity.py +3 -3
- bloqade/noise/native/_dialect.py +1 -1
- bloqade/noise/native/_wrappers.py +35 -6
- bloqade/noise/native/stmts.py +1 -1
- bloqade/pyqrack/device.py +109 -21
- bloqade/pyqrack/qasm2/core.py +4 -1
- bloqade/pyqrack/squin/qubit.py +16 -9
- bloqade/pyqrack/squin/wire.py +22 -4
- bloqade/pyqrack/task.py +13 -5
- bloqade/qasm2/__init__.py +1 -0
- bloqade/qasm2/_qasm_loading.py +151 -0
- bloqade/qasm2/dialects/core/__init__.py +9 -1
- bloqade/qasm2/dialects/expr/__init__.py +18 -1
- bloqade/qasm2/dialects/noise.py +33 -1
- bloqade/qasm2/dialects/uop/__init__.py +39 -3
- bloqade/qasm2/dialects/uop/schedule.py +1 -1
- bloqade/qasm2/emit/impls/__init__.py +1 -0
- bloqade/qasm2/emit/impls/noise_native.py +89 -0
- bloqade/qasm2/emit/main.py +21 -0
- bloqade/qasm2/emit/target.py +20 -5
- bloqade/qasm2/groups.py +2 -0
- bloqade/qasm2/parse/__init__.py +7 -4
- bloqade/qasm2/parse/lowering.py +20 -130
- bloqade/qasm2/parse/qasm2.lark +1 -1
- bloqade/qasm2/passes/__init__.py +1 -0
- bloqade/qasm2/passes/fold.py +6 -0
- bloqade/qasm2/passes/noise.py +50 -2
- bloqade/qasm2/passes/parallel.py +9 -0
- bloqade/qasm2/passes/unroll_if.py +25 -0
- bloqade/qasm2/rewrite/__init__.py +1 -0
- bloqade/qasm2/rewrite/desugar.py +3 -2
- bloqade/qasm2/rewrite/heuristic_noise.py +1 -9
- bloqade/qasm2/rewrite/native_gates.py +67 -4
- bloqade/qasm2/rewrite/split_ifs.py +66 -0
- bloqade/squin/analysis/nsites/__init__.py +1 -0
- bloqade/squin/analysis/nsites/impls.py +25 -1
- bloqade/squin/noise/__init__.py +7 -26
- bloqade/squin/noise/_wrapper.py +25 -0
- bloqade/squin/op/__init__.py +33 -159
- bloqade/squin/op/_wrapper.py +101 -0
- bloqade/squin/op/stdlib.py +62 -0
- bloqade/squin/passes/__init__.py +1 -0
- bloqade/squin/passes/stim.py +68 -0
- bloqade/squin/rewrite/__init__.py +11 -0
- bloqade/squin/rewrite/qubit_to_stim.py +84 -0
- bloqade/squin/rewrite/squin_measure.py +98 -0
- bloqade/squin/rewrite/stim_rewrite_util.py +158 -0
- bloqade/squin/rewrite/wire_identity_elimination.py +24 -0
- bloqade/squin/rewrite/wire_to_stim.py +73 -0
- bloqade/squin/rewrite/wrap_analysis.py +72 -0
- bloqade/squin/wire.py +1 -13
- bloqade/stim/__init__.py +39 -5
- bloqade/stim/_wrappers.py +14 -12
- bloqade/stim/dialects/__init__.py +1 -5
- bloqade/stim/dialects/{aux → auxiliary}/__init__.py +12 -1
- bloqade/stim/dialects/{aux → auxiliary}/emit.py +1 -1
- bloqade/stim/dialects/collapse/__init__.py +13 -2
- bloqade/stim/dialects/collapse/{emit.py → emit_str.py} +1 -1
- bloqade/stim/dialects/collapse/stmts/pp_measure.py +1 -1
- bloqade/stim/dialects/gate/__init__.py +16 -1
- bloqade/stim/dialects/gate/emit.py +1 -1
- bloqade/stim/dialects/gate/stmts/base.py +1 -1
- bloqade/stim/dialects/gate/stmts/pp.py +1 -1
- bloqade/stim/dialects/noise/emit.py +1 -1
- bloqade/stim/emit/__init__.py +1 -1
- bloqade/stim/groups.py +4 -2
- {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/METADATA +3 -3
- {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/RECORD +80 -64
- /bloqade/stim/dialects/{aux → auxiliary}/_dialect.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/interp.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/lowering.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/stmts/__init__.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/stmts/annotate.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/stmts/const.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/types.py +0 -0
- /bloqade/stim/emit/{stim.py → stim_str.py} +0 -0
- {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.prelude import structural_no_opt
|
|
3
|
+
|
|
4
|
+
from . import types
|
|
5
|
+
from ._dialect import dialect
|
|
6
|
+
from ._wrapper import h, x, y, z, rot, phase, control
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@ir.dialect_group(structural_no_opt.add(dialect))
def op(self):
    """Dialect group for squin operator kernels: ``structural_no_opt`` plus the op dialect.

    The pass function handed back to the dialect group does nothing with the
    method it receives.
    """

    def _noop_pass(method):
        # deliberately empty: kernels in this group are left as lowered
        return None

    return _noop_pass
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@op
def rx(theta: float) -> types.Op:
    """Rotation X gate: the operator ``rot(x(), theta)``."""
    return rot(x(), theta)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@op
def ry(theta: float) -> types.Op:
    """Rotation Y gate: the operator ``rot(y(), theta)``."""
    return rot(y(), theta)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@op
def rz(theta: float) -> types.Op:
    """Rotation Z gate: the operator ``rot(z(), theta)``."""
    return rot(z(), theta)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@op
def cx() -> types.Op:
    """Controlled X gate: ``x()`` with one control qubit."""
    return control(x(), n_controls=1)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@op
def cy() -> types.Op:
    """Controlled Y gate: ``y()`` with one control qubit."""
    return control(y(), n_controls=1)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@op
def cz() -> types.Op:
    """Controlled Z gate: ``z()`` with one control qubit."""
    return control(z(), n_controls=1)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@op
def ch() -> types.Op:
    """Controlled H gate: ``h()`` with one control qubit."""
    return control(h(), n_controls=1)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@op
def cphase(theta: float) -> types.Op:
    """Controlled phase gate: ``phase(theta)`` with one control qubit."""
    return control(phase(theta), n_controls=1)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .stim import SquinToStim as SquinToStim
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kirin.passes import Fold
|
|
4
|
+
from kirin.rewrite import (
|
|
5
|
+
Walk,
|
|
6
|
+
Chain,
|
|
7
|
+
Fixpoint,
|
|
8
|
+
DeadCodeElimination,
|
|
9
|
+
CommonSubexpressionElimination,
|
|
10
|
+
)
|
|
11
|
+
from kirin.ir.method import Method
|
|
12
|
+
from kirin.passes.abc import Pass
|
|
13
|
+
from kirin.rewrite.abc import RewriteResult
|
|
14
|
+
|
|
15
|
+
from bloqade.squin.rewrite import (
|
|
16
|
+
SquinWireToStim,
|
|
17
|
+
SquinQubitToStim,
|
|
18
|
+
WrapSquinAnalysis,
|
|
19
|
+
SquinMeasureToStim,
|
|
20
|
+
SquinWireIdentityElimination,
|
|
21
|
+
)
|
|
22
|
+
from bloqade.analysis.address import AddressAnalysis
|
|
23
|
+
from bloqade.squin.analysis.nsites import (
|
|
24
|
+
NSitesAnalysis,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class SquinToStim(Pass):
    """Lower squin statements of a method to stim statements.

    Steps: constant folding, address and site analysis (attached to SSA values
    as hints by ``WrapSquinAnalysis``), the squin -> stim rewrite rules, and a
    final DCE + CSE cleanup run to a fixpoint.
    """

    def unsafe_run(self, mt: Method) -> RewriteResult:
        dialects = mt.dialects

        # propagate constants before running the analyses
        result = Fold(dialects)(mt)

        # analysis results that the rewrite rules read back through SSA hints
        address_frame, _ = AddressAnalysis(dialects).run_analysis(mt)
        sites_frame, _ = NSitesAnalysis(dialects).run_analysis(mt)

        lowering_rules = Chain(
            WrapSquinAnalysis(
                address_analysis=address_frame.entries,
                op_site_analysis=sites_frame.entries,
            ),
            SquinQubitToStim(),
            SquinWireToStim(),
            SquinMeasureToStim(),
            SquinWireIdentityElimination(),
        )
        result = Walk(lowering_rules).rewrite(mt.code).join(result)

        cleanup = Fixpoint(
            Walk(Chain(DeadCodeElimination(), CommonSubexpressionElimination()))
        )
        result = cleanup.rewrite(mt.code).join(result)

        return result
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from .wire_to_stim import SquinWireToStim as SquinWireToStim
|
|
2
|
+
from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
|
|
3
|
+
from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
|
|
4
|
+
from .wrap_analysis import (
|
|
5
|
+
SitesAttribute as SitesAttribute,
|
|
6
|
+
AddressAttribute as AddressAttribute,
|
|
7
|
+
WrapSquinAnalysis as WrapSquinAnalysis,
|
|
8
|
+
)
|
|
9
|
+
from .wire_identity_elimination import (
|
|
10
|
+
SquinWireIdentityElimination as SquinWireIdentityElimination,
|
|
11
|
+
)
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
3
|
+
|
|
4
|
+
from bloqade import stim
|
|
5
|
+
from bloqade.squin import op, qubit
|
|
6
|
+
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
|
|
7
|
+
from bloqade.squin.rewrite.stim_rewrite_util import (
|
|
8
|
+
SQUIN_STIM_GATE_MAPPING,
|
|
9
|
+
rewrite_Control,
|
|
10
|
+
insert_qubit_idx_from_address,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SquinQubitToStim(RewriteRule):
    """Rewrite qubit-dialect Apply, Broadcast and Reset statements into stim statements."""

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if isinstance(node, (qubit.Apply, qubit.Broadcast)):
            return self.rewrite_Apply_and_Broadcast(node)
        if isinstance(node, qubit.Reset):
            return self.rewrite_Reset(node)
        return RewriteResult()

    def rewrite_Apply_and_Broadcast(
        self, stmt: qubit.Apply | qubit.Broadcast
    ) -> RewriteResult:
        """
        Rewrite Apply and Broadcast nodes to their stim equivalent statements.

        Controlled operators are delegated to ``rewrite_Control``; operators listed
        in ``SQUIN_STIM_GATE_MAPPING`` become the mapped stim gate. Anything else,
        or qubits without a resolvable address hint, is left untouched.
        """
        # stmt.operator is an SSA value; its owner is the operator statement
        operator_stmt = stmt.operator.owner
        assert isinstance(operator_stmt, op.stmts.Operator)

        if isinstance(operator_stmt, op.stmts.Control):
            return rewrite_Control(stmt)

        stim_gate_cls = SQUIN_STIM_GATE_MAPPING.get(type(operator_stmt))
        if stim_gate_cls is None:
            return RewriteResult()

        address = stmt.qubits.hints.get("address")
        if address is None:
            return RewriteResult()
        assert isinstance(address, AddressAttribute)

        target_ssas = insert_qubit_idx_from_address(
            address=address, stmt_to_insert_before=stmt
        )
        if target_ssas is None:
            return RewriteResult()

        replacement = stim_gate_cls(targets=tuple(target_ssas))
        stmt.replace_by(replacement)
        return RewriteResult(has_done_something=True)

    def rewrite_Reset(self, reset_stmt: qubit.Reset) -> RewriteResult:
        """Replace a qubit Reset with a stim RZ on the addressed qubit indices."""
        # the reset qubits arrive as an ilist, whose address hint is an AddressTuple
        address = reset_stmt.qubits.hints.get("address")
        if address is None:
            return RewriteResult()
        assert isinstance(address, AddressAttribute)

        target_ssas = insert_qubit_idx_from_address(
            address=address, stmt_to_insert_before=reset_stmt
        )
        if target_ssas is None:
            return RewriteResult()

        replacement = stim.collapse.stmts.RZ(targets=target_ssas)
        reset_stmt.replace_by(replacement)
        return RewriteResult(has_done_something=True)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# Measure statements are rewritten by the separate SquinMeasureToStim rule, so this rule only dispatches gate applications and resets.
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
# create rewrite rule name SquinMeasureToStim using kirin
|
|
2
|
+
from kirin import ir
|
|
3
|
+
from kirin.dialects import py
|
|
4
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
5
|
+
|
|
6
|
+
from bloqade import stim
|
|
7
|
+
from bloqade.squin import wire, qubit
|
|
8
|
+
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
|
|
9
|
+
from bloqade.squin.rewrite.stim_rewrite_util import (
|
|
10
|
+
is_measure_result_used,
|
|
11
|
+
insert_qubit_idx_from_address,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SquinMeasureToStim(RewriteRule):
|
|
16
|
+
"""
|
|
17
|
+
Rewrite squin measure-related statements to stim statements.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
21
|
+
|
|
22
|
+
match node:
|
|
23
|
+
case qubit.MeasureQubit() | qubit.MeasureQubitList() | wire.Measure():
|
|
24
|
+
return self.rewrite_Measure(node)
|
|
25
|
+
case qubit.MeasureAndReset() | wire.MeasureAndReset():
|
|
26
|
+
return self.rewrite_MeasureAndReset(node)
|
|
27
|
+
case _:
|
|
28
|
+
return RewriteResult()
|
|
29
|
+
|
|
30
|
+
def rewrite_Measure(
|
|
31
|
+
self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
|
|
32
|
+
) -> RewriteResult:
|
|
33
|
+
if is_measure_result_used(measure_stmt):
|
|
34
|
+
return RewriteResult()
|
|
35
|
+
|
|
36
|
+
qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt)
|
|
37
|
+
if qubit_idx_ssas is None:
|
|
38
|
+
return RewriteResult()
|
|
39
|
+
|
|
40
|
+
prob_noise_stmt = py.constant.Constant(0.0)
|
|
41
|
+
stim_measure_stmt = stim.collapse.MZ(
|
|
42
|
+
p=prob_noise_stmt.result,
|
|
43
|
+
targets=qubit_idx_ssas,
|
|
44
|
+
)
|
|
45
|
+
prob_noise_stmt.insert_before(measure_stmt)
|
|
46
|
+
measure_stmt.replace_by(stim_measure_stmt)
|
|
47
|
+
|
|
48
|
+
return RewriteResult(has_done_something=True)
|
|
49
|
+
|
|
50
|
+
def rewrite_MeasureAndReset(
|
|
51
|
+
self, meas_and_reset_stmt: qubit.MeasureAndReset | wire.MeasureAndReset
|
|
52
|
+
) -> RewriteResult:
|
|
53
|
+
if not is_measure_result_used(meas_and_reset_stmt):
|
|
54
|
+
return RewriteResult()
|
|
55
|
+
|
|
56
|
+
qubit_idx_ssas = self.get_qubit_idx_ssas(meas_and_reset_stmt)
|
|
57
|
+
|
|
58
|
+
if qubit_idx_ssas is None:
|
|
59
|
+
return RewriteResult()
|
|
60
|
+
|
|
61
|
+
error_p_stmt = py.Constant(0.0)
|
|
62
|
+
stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result)
|
|
63
|
+
stim_rz_stmt = stim.collapse.RZ(
|
|
64
|
+
targets=qubit_idx_ssas,
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
error_p_stmt.insert_before(meas_and_reset_stmt)
|
|
68
|
+
stim_mz_stmt.insert_before(meas_and_reset_stmt)
|
|
69
|
+
meas_and_reset_stmt.replace_by(stim_rz_stmt)
|
|
70
|
+
|
|
71
|
+
return RewriteResult(has_done_something=True)
|
|
72
|
+
|
|
73
|
+
def get_qubit_idx_ssas(
|
|
74
|
+
self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
|
|
75
|
+
) -> tuple[ir.SSAValue, ...] | None:
|
|
76
|
+
"""
|
|
77
|
+
Extract the address attribute and insert qubit indices for the given measure statement.
|
|
78
|
+
"""
|
|
79
|
+
match measure_stmt:
|
|
80
|
+
case qubit.MeasureQubit():
|
|
81
|
+
address_attr = measure_stmt.qubit.hints.get("address")
|
|
82
|
+
case qubit.MeasureQubitList():
|
|
83
|
+
address_attr = measure_stmt.qubits.hints.get("address")
|
|
84
|
+
case wire.Measure():
|
|
85
|
+
address_attr = measure_stmt.wire.hints.get("address")
|
|
86
|
+
case _:
|
|
87
|
+
return None
|
|
88
|
+
|
|
89
|
+
if address_attr is None:
|
|
90
|
+
return None
|
|
91
|
+
|
|
92
|
+
assert isinstance(address_attr, AddressAttribute)
|
|
93
|
+
|
|
94
|
+
qubit_idx_ssas = insert_qubit_idx_from_address(
|
|
95
|
+
address=address_attr, stmt_to_insert_before=measure_stmt
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
return qubit_idx_ssas
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.dialects import py
|
|
3
|
+
from kirin.rewrite.abc import RewriteResult
|
|
4
|
+
|
|
5
|
+
from bloqade.squin import op, wire, qubit
|
|
6
|
+
from bloqade.stim.dialects import gate
|
|
7
|
+
from bloqade.analysis.address import AddressWire, AddressQubit, AddressTuple
|
|
8
|
+
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
|
|
9
|
+
|
|
10
|
+
# Squin single-qubit operator statement class -> stim gate statement class.
# The squin -> stim rules look gates up by exact operator class (``type(stmt)``).
SQUIN_STIM_GATE_MAPPING = dict(
    [
        (op.stmts.X, gate.X),
        (op.stmts.Y, gate.Y),
        (op.stmts.Z, gate.Z),
        (op.stmts.H, gate.H),
        (op.stmts.S, gate.S),
        (op.stmts.Identity, gate.Identity),
    ]
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def insert_qubit_idx_from_address(
    address: AddressAttribute, stmt_to_insert_before: ir.Statement
) -> tuple[ir.SSAValue, ...] | None:
    """
    Extract qubit indices from an AddressAttribute and insert them into the SSA form.

    Each qubit index becomes a ``py.Constant`` inserted before
    ``stmt_to_insert_before``. Supported addresses are:

    - an ``AddressTuple`` whose elements are all ``AddressQubit``,
    - an ``AddressWire`` (its origin qubit is used),
    - a single ``AddressQubit``.

    Returns the constants' result values in order, or None if the address cannot
    be resolved. Validation happens before any statement is inserted, so a None
    return leaves the IR unchanged.
    """
    address_data = address.address

    if isinstance(address_data, AddressTuple):
        address_qubits = tuple(address_data.data)
        if not all(isinstance(q, AddressQubit) for q in address_qubits):
            return None
    elif isinstance(address_data, AddressWire):
        address_qubits = (address_data.origin_qubit,)
    elif isinstance(address_data, AddressQubit):
        # e.g. the operand of a single-qubit measurement
        address_qubits = (address_data,)
    else:
        return None

    qubit_idx_ssas = []
    for address_qubit in address_qubits:
        qubit_idx_stmt = py.Constant(address_qubit.data)
        qubit_idx_stmt.insert_before(stmt_to_insert_before)
        qubit_idx_ssas.append(qubit_idx_stmt.result)

    return tuple(qubit_idx_ssas)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def insert_qubit_idx_from_wire_ssa(
    wire_ssas: tuple[ir.SSAValue, ...], stmt_to_insert_before: ir.Statement
) -> tuple[ir.SSAValue, ...] | None:
    """
    Extract qubit indices from wire SSA values and insert them into the SSA form.

    Each wire must carry an ``"address"`` hint holding an ``AddressWire``; its
    origin qubit index becomes a ``py.Constant`` inserted before
    ``stmt_to_insert_before``.

    Returns the constants' result values in wire order, or None if any wire lacks
    a resolvable address. All wires are validated before anything is inserted, so
    a None return leaves the IR unchanged.
    """
    qubit_indices = []
    for wire_ssa in wire_ssas:
        address_attribute = wire_ssa.hints.get("address")
        if address_attribute is None:
            return None
        assert isinstance(address_attribute, AddressAttribute)
        wire_address = address_attribute.address
        if not isinstance(wire_address, AddressWire):
            # unresolved wire (not traced back to a qubit): skip the rewrite
            return None
        qubit_indices.append(wire_address.origin_qubit.data)

    qubit_idx_ssas = []
    for qubit_idx in qubit_indices:
        qubit_idx_stmt = py.Constant(qubit_idx)
        qubit_idx_stmt.insert_before(stmt_to_insert_before)
        qubit_idx_ssas.append(qubit_idx_stmt.result)

    return tuple(qubit_idx_ssas)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def insert_qubit_idx_after_apply(
|
|
72
|
+
stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast,
|
|
73
|
+
) -> tuple[ir.SSAValue, ...] | None:
|
|
74
|
+
"""
|
|
75
|
+
Extract qubit indices from Apply or Broadcast statements.
|
|
76
|
+
"""
|
|
77
|
+
if isinstance(stmt, (qubit.Apply, qubit.Broadcast)):
|
|
78
|
+
qubits = stmt.qubits
|
|
79
|
+
address_attribute = qubits.hints.get("address")
|
|
80
|
+
if address_attribute is None:
|
|
81
|
+
return
|
|
82
|
+
assert isinstance(address_attribute, AddressAttribute)
|
|
83
|
+
return insert_qubit_idx_from_address(
|
|
84
|
+
address=address_attribute, stmt_to_insert_before=stmt
|
|
85
|
+
)
|
|
86
|
+
elif isinstance(stmt, (wire.Apply, wire.Broadcast)):
|
|
87
|
+
wire_ssas = stmt.inputs
|
|
88
|
+
return insert_qubit_idx_from_wire_ssa(
|
|
89
|
+
wire_ssas=wire_ssas, stmt_to_insert_before=stmt
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def rewrite_Control(
    stmt_with_ctrl: qubit.Apply | wire.Apply | qubit.Broadcast | wire.Broadcast,
) -> RewriteResult:
    """
    Handle control gates for Apply and Broadcast statements.

    Only singly-controlled X, Y and Z are supported; they become stim CX, CY and
    CZ respectively. The statement's qubit indices are read pairwise as
    ``(control, target, control, target, ...)``.

    Returns an empty RewriteResult when the controlled operator is unsupported,
    has more than one control, or the qubit indices cannot be resolved or paired.
    """
    ctrl_op = stmt_with_ctrl.operator.owner
    assert isinstance(ctrl_op, op.stmts.Control)

    ctrl_op_target_gate = ctrl_op.op.owner
    assert isinstance(ctrl_op_target_gate, op.stmts.Operator)

    # The pairwise control/target split below is only valid for one control
    # per gate; multiply-controlled operators have no mapping here.
    n_controls = ctrl_op.n_controls
    if isinstance(n_controls, ir.PyAttr):
        n_controls = n_controls.data
    if n_controls != 1:
        return RewriteResult()

    supported_gate_mapping = {
        op.stmts.X: gate.CX,
        op.stmts.Y: gate.CY,
        op.stmts.Z: gate.CZ,
    }

    # Check support before inserting any index constants, so an unsupported
    # gate leaves the IR untouched (consistent with the empty RewriteResult).
    stim_gate = supported_gate_mapping.get(type(ctrl_op_target_gate))
    if stim_gate is None:
        return RewriteResult()

    qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt_with_ctrl)
    if qubit_idx_ssas is None:
        return RewriteResult()

    # Separate control and target qubits: even positions control, odd positions target
    ctrl_qubits = tuple(qubit_idx_ssas[0::2])
    target_qubits = tuple(qubit_idx_ssas[1::2])
    if len(ctrl_qubits) != len(target_qubits):
        # odd number of qubits cannot be paired; the index constants inserted
        # above are left unused (dead code)
        return RewriteResult()

    stim_stmt = stim_gate(controls=ctrl_qubits, targets=target_qubits)

    if isinstance(stmt_with_ctrl, (wire.Apply, wire.Broadcast)):
        # have to "reroute" the input of these statements to directly plug in
        # to subsequent statements, remove dependency on the current statement
        for input_wire, output_wire in zip(
            stmt_with_ctrl.inputs, stmt_with_ctrl.results
        ):
            output_wire.replace_by(input_wire)

    stmt_with_ctrl.replace_by(stim_stmt)

    return RewriteResult(has_done_something=True)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def is_measure_result_used(
    stmt: (
        qubit.MeasureAndReset
        | qubit.MeasureQubit
        | qubit.MeasureQubitList
        | wire.MeasureAndReset
        | wire.Measure
    ),
) -> bool:
    """
    Return True if the ``result`` value of a measure statement has at least one use.

    Only ``stmt.result`` is inspected; any other results of the statement are ignored.
    """
    result_uses = stmt.result.uses
    return len(result_uses) > 0
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
3
|
+
|
|
4
|
+
from bloqade.squin import wire
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SquinWireIdentityElimination(RewriteRule):
    """Remove ``unwrap``/``wrap`` pairs that feed a wire straight back without any operation."""

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """
        Handle the case where an unwrap feeds a wire directly into a wrap,
        equivalent to nothing happening/identity operation

        w = unwrap(qubit)
        wrap(qubit, w)
        """
        if not isinstance(node, wire.Wrap):
            return RewriteResult()

        unwrap_stmt = node.wire.owner
        if not isinstance(unwrap_stmt, wire.Unwrap):
            return RewriteResult()

        # NOTE(review): the qubit given to wrap is not compared with the one given
        # to unwrap, so a pair moving a wire onto a different qubit is removed
        # too — confirm that is intended.
        node.delete()  # drop the wrap first so the unwrapped wire loses its use
        unwrap_stmt.delete()
        return RewriteResult(has_done_something=True)
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
3
|
+
|
|
4
|
+
from bloqade import stim
|
|
5
|
+
from bloqade.squin import op, wire
|
|
6
|
+
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
|
|
7
|
+
from bloqade.squin.rewrite.stim_rewrite_util import (
|
|
8
|
+
SQUIN_STIM_GATE_MAPPING,
|
|
9
|
+
rewrite_Control,
|
|
10
|
+
insert_qubit_idx_from_address,
|
|
11
|
+
insert_qubit_idx_from_wire_ssa,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SquinWireToStim(RewriteRule):
    """Rewrite wire-dialect Apply, Broadcast and Reset statements into stim statements."""

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if isinstance(node, (wire.Apply, wire.Broadcast)):
            return self.rewrite_Apply_and_Broadcast(node)
        if isinstance(node, wire.Reset):
            return self.rewrite_Reset(node)
        return RewriteResult()

    def rewrite_Apply_and_Broadcast(
        self, stmt: wire.Apply | wire.Broadcast
    ) -> RewriteResult:
        """Replace a wire Apply/Broadcast of a supported operator with its stim gate.

        Controlled operators are delegated to ``rewrite_Control``.
        """
        # stmt.operator is an SSA value; its owner is the operator statement
        operator_stmt = stmt.operator.owner
        assert isinstance(operator_stmt, op.stmts.Operator)

        if isinstance(operator_stmt, op.stmts.Control):
            return rewrite_Control(stmt)

        stim_gate_cls = SQUIN_STIM_GATE_MAPPING.get(type(operator_stmt))
        if stim_gate_cls is None:
            return RewriteResult()

        target_ssas = insert_qubit_idx_from_wire_ssa(
            wire_ssas=stmt.inputs, stmt_to_insert_before=stmt
        )
        if target_ssas is None:
            return RewriteResult()

        replacement = stim_gate_cls(targets=tuple(target_ssas))

        # stim gates return no wires: forward each input wire to the users of the
        # matching output wire before the statement is replaced
        for in_wire, out_wire in zip(stmt.inputs, stmt.results):
            out_wire.replace_by(in_wire)

        stmt.replace_by(replacement)
        return RewriteResult(has_done_something=True)

    def rewrite_Reset(self, reset_stmt: wire.Reset) -> RewriteResult:
        """Replace a wire Reset with a stim RZ on the index resolved from the wire's address hint."""
        address = reset_stmt.wire.hints.get("address")
        if address is None:
            return RewriteResult()
        assert isinstance(address, AddressAttribute)

        target_ssas = insert_qubit_idx_from_address(
            address=address, stmt_to_insert_before=reset_stmt
        )
        if target_ssas is None:
            return RewriteResult()

        replacement = stim.collapse.stmts.RZ(targets=target_ssas)
        reset_stmt.replace_by(replacement)
        return RewriteResult(has_done_something=True)
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
5
|
+
from kirin.print.printer import Printer
|
|
6
|
+
|
|
7
|
+
from bloqade.squin import op, wire
|
|
8
|
+
from bloqade.analysis.address import Address
|
|
9
|
+
from bloqade.squin.analysis.nsites import Sites
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@wire.dialect.register
@dataclass
class AddressAttribute(ir.Attribute):
    """IR attribute wrapping an address-analysis lattice value."""

    name = "Address"
    address: Address

    def print_impl(self, printer: Printer) -> None:
        # placeholder formatting: delegate to the printer for the wrapped value
        printer.print(self.address)

    def __hash__(self) -> int:
        # hash by the wrapped address value
        return hash(self.address)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@op.dialect.register
@dataclass
class SitesAttribute(ir.Attribute):
    """IR attribute wrapping an NSites-analysis lattice value."""

    name = "Sites"
    sites: Sites

    def print_impl(self, printer: Printer) -> None:
        # placeholder formatting: delegate to the printer for the wrapped value
        printer.print(self.sites)

    def __hash__(self) -> int:
        # hash by the wrapped sites value
        return hash(self.sites)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class WrapSquinAnalysis(RewriteRule):
    """Attach analysis results to SSA values as ``"address"`` and ``"sites"`` hints.

    Block arguments and statement results are looked up in the address and
    NSites analysis entries. Values missing from an analysis (e.g. in code the
    analysis did not reach) get no hint for that analysis instead of raising.
    """

    address_analysis: dict[ir.SSAValue, Address]
    op_site_analysis: dict[ir.SSAValue, Sites]

    def wrap(self, value: ir.SSAValue) -> bool:
        """Attach available analysis hints to ``value``; return True if any hint was set."""
        if value.hints.get("address") and value.hints.get("sites"):
            return False

        # .get(): an SSA value may be absent from an analysis frame, which
        # previously raised KeyError before the check above could even run
        address_analysis_result = self.address_analysis.get(value)
        op_site_analysis_result = self.op_site_analysis.get(value)
        if address_analysis_result is None and op_site_analysis_result is None:
            return False

        if address_analysis_result is not None:
            value.hints["address"] = AddressAttribute(address_analysis_result)
        if op_site_analysis_result is not None:
            value.hints["sites"] = SitesAttribute(op_site_analysis_result)

        return True

    def rewrite_Block(self, node: ir.Block) -> RewriteResult:
        # wrap every argument (no short-circuit), then report whether any changed
        has_done_something = False
        for arg in node.args:
            if self.wrap(arg):
                has_done_something = True
        return RewriteResult(has_done_something=has_done_something)

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        # wrap every result (no short-circuit), then report whether any changed
        has_done_something = False
        for result in node.results:
            if self.wrap(result):
                has_done_something = True
        return RewriteResult(has_done_something=has_done_something)
|
bloqade/squin/wire.py
CHANGED
|
@@ -6,7 +6,7 @@ circuits. Thus we do not define wrapping functions for the statements in this
|
|
|
6
6
|
dialect.
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
|
-
from kirin import ir, types,
|
|
9
|
+
from kirin import ir, types, lowering
|
|
10
10
|
from kirin.decl import info, statement
|
|
11
11
|
from kirin.lowering import wraps
|
|
12
12
|
|
|
@@ -112,18 +112,6 @@ class Reset(ir.Statement):
|
|
|
112
112
|
wire: ir.SSAValue = info.argument(WireType)
|
|
113
113
|
|
|
114
114
|
|
|
115
|
-
# Issue where constant propagation can't handle
|
|
116
|
-
# multiple return values from Apply properly
|
|
117
|
-
@dialect.register(key="constprop")
|
|
118
|
-
class ConstPropWire(interp.MethodTable):
|
|
119
|
-
|
|
120
|
-
@interp.impl(Apply)
|
|
121
|
-
@interp.impl(Broadcast)
|
|
122
|
-
def apply(self, interp, frame, stmt: Apply):
|
|
123
|
-
|
|
124
|
-
return frame.get_values(stmt.inputs)
|
|
125
|
-
|
|
126
|
-
|
|
127
115
|
@wraps(Unwrap)
|
|
128
116
|
def unwrap(qubit: Qubit) -> Wire: ...
|
|
129
117
|
|