bloqade-circuit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic. Click here for more details.
- bloqade/analysis/__init__.py +0 -0
- bloqade/analysis/address/__init__.py +11 -0
- bloqade/analysis/address/analysis.py +60 -0
- bloqade/analysis/address/impls.py +228 -0
- bloqade/analysis/address/lattice.py +85 -0
- bloqade/noise/__init__.py +1 -0
- bloqade/noise/native/__init__.py +20 -0
- bloqade/noise/native/_dialect.py +3 -0
- bloqade/noise/native/_wrappers.py +34 -0
- bloqade/noise/native/model.py +347 -0
- bloqade/noise/native/rewrite.py +35 -0
- bloqade/noise/native/stmts.py +46 -0
- bloqade/pyqrack/__init__.py +18 -0
- bloqade/pyqrack/base.py +131 -0
- bloqade/pyqrack/noise/__init__.py +0 -0
- bloqade/pyqrack/noise/native.py +100 -0
- bloqade/pyqrack/qasm2/__init__.py +0 -0
- bloqade/pyqrack/qasm2/core.py +79 -0
- bloqade/pyqrack/qasm2/parallel.py +46 -0
- bloqade/pyqrack/qasm2/uop.py +247 -0
- bloqade/pyqrack/reg.py +109 -0
- bloqade/pyqrack/target.py +112 -0
- bloqade/qasm2/__init__.py +19 -0
- bloqade/qasm2/_wrappers.py +674 -0
- bloqade/qasm2/dialects/__init__.py +10 -0
- bloqade/qasm2/dialects/core/__init__.py +3 -0
- bloqade/qasm2/dialects/core/_dialect.py +3 -0
- bloqade/qasm2/dialects/core/_emit.py +68 -0
- bloqade/qasm2/dialects/core/_typeinfer.py +23 -0
- bloqade/qasm2/dialects/core/address.py +38 -0
- bloqade/qasm2/dialects/core/stmts.py +94 -0
- bloqade/qasm2/dialects/expr/__init__.py +3 -0
- bloqade/qasm2/dialects/expr/_dialect.py +3 -0
- bloqade/qasm2/dialects/expr/_emit.py +103 -0
- bloqade/qasm2/dialects/expr/_from_python.py +86 -0
- bloqade/qasm2/dialects/expr/_interp.py +75 -0
- bloqade/qasm2/dialects/expr/stmts.py +262 -0
- bloqade/qasm2/dialects/glob.py +45 -0
- bloqade/qasm2/dialects/indexing.py +64 -0
- bloqade/qasm2/dialects/inline.py +76 -0
- bloqade/qasm2/dialects/noise.py +16 -0
- bloqade/qasm2/dialects/parallel.py +110 -0
- bloqade/qasm2/dialects/uop/__init__.py +4 -0
- bloqade/qasm2/dialects/uop/_dialect.py +3 -0
- bloqade/qasm2/dialects/uop/_emit.py +211 -0
- bloqade/qasm2/dialects/uop/schedule.py +89 -0
- bloqade/qasm2/dialects/uop/stmts.py +325 -0
- bloqade/qasm2/emit/__init__.py +1 -0
- bloqade/qasm2/emit/base.py +72 -0
- bloqade/qasm2/emit/gate.py +102 -0
- bloqade/qasm2/emit/main.py +106 -0
- bloqade/qasm2/emit/target.py +165 -0
- bloqade/qasm2/glob.py +24 -0
- bloqade/qasm2/groups.py +120 -0
- bloqade/qasm2/parallel.py +48 -0
- bloqade/qasm2/parse/__init__.py +37 -0
- bloqade/qasm2/parse/ast.py +235 -0
- bloqade/qasm2/parse/build.py +289 -0
- bloqade/qasm2/parse/lowering.py +553 -0
- bloqade/qasm2/parse/parser.py +5 -0
- bloqade/qasm2/parse/print.py +293 -0
- bloqade/qasm2/parse/qasm2.lark +75 -0
- bloqade/qasm2/parse/visitor.py +16 -0
- bloqade/qasm2/parse/visitor.pyi +39 -0
- bloqade/qasm2/passes/__init__.py +5 -0
- bloqade/qasm2/passes/fold.py +94 -0
- bloqade/qasm2/passes/glob.py +119 -0
- bloqade/qasm2/passes/noise.py +61 -0
- bloqade/qasm2/passes/parallel.py +176 -0
- bloqade/qasm2/passes/py2qasm.py +63 -0
- bloqade/qasm2/passes/qasm2py.py +61 -0
- bloqade/qasm2/rewrite/__init__.py +12 -0
- bloqade/qasm2/rewrite/desugar.py +28 -0
- bloqade/qasm2/rewrite/glob.py +103 -0
- bloqade/qasm2/rewrite/heuristic_noise.py +247 -0
- bloqade/qasm2/rewrite/native_gates.py +447 -0
- bloqade/qasm2/rewrite/parallel_to_uop.py +83 -0
- bloqade/qasm2/rewrite/register.py +45 -0
- bloqade/qasm2/rewrite/uop_to_parallel.py +395 -0
- bloqade/qasm2/types.py +39 -0
- bloqade/qbraid/__init__.py +2 -0
- bloqade/qbraid/lowering.py +324 -0
- bloqade/qbraid/schema.py +252 -0
- bloqade/qbraid/simulation_result.py +99 -0
- bloqade/qbraid/target.py +86 -0
- bloqade/squin/__init__.py +2 -0
- bloqade/squin/analysis/__init__.py +0 -0
- bloqade/squin/analysis/nsites/__init__.py +8 -0
- bloqade/squin/analysis/nsites/analysis.py +52 -0
- bloqade/squin/analysis/nsites/impls.py +69 -0
- bloqade/squin/analysis/nsites/lattice.py +49 -0
- bloqade/squin/analysis/schedule.py +244 -0
- bloqade/squin/groups.py +38 -0
- bloqade/squin/op/__init__.py +132 -0
- bloqade/squin/op/_dialect.py +3 -0
- bloqade/squin/op/complex.py +6 -0
- bloqade/squin/op/stmts.py +220 -0
- bloqade/squin/op/traits.py +43 -0
- bloqade/squin/op/types.py +10 -0
- bloqade/squin/qubit.py +118 -0
- bloqade/squin/wire.py +103 -0
- bloqade/stim/__init__.py +6 -0
- bloqade/stim/_wrappers.py +186 -0
- bloqade/stim/dialects/__init__.py +5 -0
- bloqade/stim/dialects/aux/__init__.py +11 -0
- bloqade/stim/dialects/aux/_dialect.py +3 -0
- bloqade/stim/dialects/aux/emit.py +102 -0
- bloqade/stim/dialects/aux/interp.py +39 -0
- bloqade/stim/dialects/aux/lowering.py +40 -0
- bloqade/stim/dialects/aux/stmts/__init__.py +14 -0
- bloqade/stim/dialects/aux/stmts/annotate.py +47 -0
- bloqade/stim/dialects/aux/stmts/const.py +95 -0
- bloqade/stim/dialects/aux/types.py +19 -0
- bloqade/stim/dialects/collapse/__init__.py +3 -0
- bloqade/stim/dialects/collapse/_dialect.py +3 -0
- bloqade/stim/dialects/collapse/emit.py +68 -0
- bloqade/stim/dialects/collapse/stmts/__init__.py +3 -0
- bloqade/stim/dialects/collapse/stmts/measure.py +45 -0
- bloqade/stim/dialects/collapse/stmts/pp_measure.py +14 -0
- bloqade/stim/dialects/collapse/stmts/reset.py +26 -0
- bloqade/stim/dialects/gate/__init__.py +3 -0
- bloqade/stim/dialects/gate/_dialect.py +3 -0
- bloqade/stim/dialects/gate/emit.py +87 -0
- bloqade/stim/dialects/gate/stmts/__init__.py +14 -0
- bloqade/stim/dialects/gate/stmts/base.py +31 -0
- bloqade/stim/dialects/gate/stmts/clifford_1q.py +53 -0
- bloqade/stim/dialects/gate/stmts/clifford_2q.py +11 -0
- bloqade/stim/dialects/gate/stmts/control_2q.py +21 -0
- bloqade/stim/dialects/gate/stmts/pp.py +15 -0
- bloqade/stim/dialects/noise/__init__.py +3 -0
- bloqade/stim/dialects/noise/_dialect.py +3 -0
- bloqade/stim/dialects/noise/emit.py +66 -0
- bloqade/stim/dialects/noise/stmts.py +77 -0
- bloqade/stim/emit/__init__.py +1 -0
- bloqade/stim/emit/stim.py +54 -0
- bloqade/stim/groups.py +26 -0
- bloqade/test_utils.py +35 -0
- bloqade/types.py +24 -0
- bloqade/visual/__init__.py +1 -0
- bloqade/visual/animation/__init__.py +0 -0
- bloqade/visual/animation/animate.py +267 -0
- bloqade/visual/animation/base.py +346 -0
- bloqade/visual/animation/gate_event.py +24 -0
- bloqade/visual/animation/runtime/__init__.py +0 -0
- bloqade/visual/animation/runtime/aod.py +36 -0
- bloqade/visual/animation/runtime/atoms.py +55 -0
- bloqade/visual/animation/runtime/ppoly.py +50 -0
- bloqade/visual/animation/runtime/qpustate.py +119 -0
- bloqade/visual/animation/runtime/utils.py +43 -0
- bloqade_circuit-0.1.0.dist-info/METADATA +70 -0
- bloqade_circuit-0.1.0.dist-info/RECORD +153 -0
- bloqade_circuit-0.1.0.dist-info/WHEEL +4 -0
- bloqade_circuit-0.1.0.dist-info/licenses/LICENSE +234 -0
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from typing import Dict, List, Optional
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
from kirin import ir
|
|
5
|
+
from kirin.rewrite import abc, result
|
|
6
|
+
|
|
7
|
+
from bloqade.analysis import address
|
|
8
|
+
from bloqade.qasm2.dialects import uop, parallel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class ParallelToUOpRule(abc.RewriteRule):
    """Expand statements of the ``parallel`` dialect into individual ``uop`` gates.

    Each parallel gate is replaced by one ``uop`` gate per qubit (or per
    control/target pair for CZ). The new gates are inserted, in the order of the
    original qubit list, before the parallel statement, which is then deleted.
    If the qubit arguments cannot be resolved to concrete qubit SSA values, the
    statement is left untouched.
    """

    # Maps a qubit address (the integer stored in ``address.AddressQubit.data``)
    # to the SSA value referring to that qubit.
    id_map: Dict[int, ir.SSAValue]
    # Result of the address analysis: SSA value -> inferred address.
    address_analysis: Dict[ir.SSAValue, address.Address]

    def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
        if type(node) not in parallel.dialect.stmts:
            return result.RewriteResult()

        # Dispatch on the statement name (e.g. ``rewrite_cz``). Statements of the
        # dialect without a dedicated handler are left unchanged instead of
        # raising AttributeError.
        rewrite = getattr(self, f"rewrite_{node.name}", None)
        if rewrite is None:
            return result.RewriteResult()
        return rewrite(node)

    def get_qubit_ssa(self, ilist_ref: ir.SSAValue) -> Optional[List[ir.SSAValue]]:
        """Resolve a list-of-qubits SSA value into the SSA values of each qubit.

        Args:
            ilist_ref: SSA value whose address is expected to be an
                ``address.AddressTuple`` of ``address.AddressQubit`` elements.

        Returns:
            The qubit SSA values in list order, or ``None`` if the address is not a
            tuple of qubits or if any qubit id has no entry in ``id_map``.
        """
        addr = self.address_analysis.get(ilist_ref)
        if not isinstance(addr, address.AddressTuple):
            return None

        qubits: List[ir.SSAValue] = []
        for ele in addr.data:
            if not isinstance(ele, address.AddressQubit):
                return None
            qubit = self.id_map.get(ele.data)
            if qubit is None:
                # unknown qubit id: give up rather than crash with KeyError
                return None
            qubits.append(qubit)

        return qubits

    def rewrite_cz(self, node: ir.Statement):
        assert isinstance(node, parallel.CZ)

        ctrls = self.get_qubit_ssa(node.ctrls)
        qargs = self.get_qubit_ssa(node.qargs)

        # ``zip`` would silently drop the extra qubits of a length mismatch;
        # leave such statements untouched instead.
        if ctrls is None or qargs is None or len(ctrls) != len(qargs):
            return result.RewriteResult()

        for ctrl, qarg in zip(ctrls, qargs):
            new_node = uop.CZ(ctrl, qarg)
            new_node.insert_before(node)

        node.delete()

        return result.RewriteResult(has_done_something=True)

    def rewrite_u(self, node: ir.Statement):
        assert isinstance(node, parallel.UGate)

        qargs = self.get_qubit_ssa(node.qargs)

        if qargs is None:
            return result.RewriteResult()

        # insert_before keeps the expanded gates in the original qubit order
        # (repeated insert_after would reverse it).
        for qarg in qargs:
            new_node = uop.UGate(qarg, theta=node.theta, phi=node.phi, lam=node.lam)
            new_node.insert_before(node)

        node.delete()

        return result.RewriteResult(has_done_something=True)

    def rewrite_rz(self, node: ir.Statement):
        assert isinstance(node, parallel.RZ)

        qargs = self.get_qubit_ssa(node.qargs)

        if qargs is None:
            return result.RewriteResult()

        # insert_before keeps the expanded gates in the original qubit order.
        for qarg in qargs:
            new_node = uop.RZ(qarg, theta=node.theta)
            new_node.insert_before(node)

        node.delete()

        return result.RewriteResult(has_done_something=True)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.dialects import py
|
|
3
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
4
|
+
|
|
5
|
+
from bloqade.qasm2.dialects import core
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class RaiseRegisterRule(RewriteRule):
    """Move every ``core.QRegNew`` statement to the top of its parent block.

    This is required for the UOpToParallel rules to work correctly: they need to
    handle cases where a register is defined in between two statements that can
    be parallelized.

    Only registers whose size comes from a ``py.Constant`` (which is cloned to
    the top of the block) or from a block argument can be moved. Other registers
    are left in place.
    """

    @staticmethod
    def _already_raised(node: core.QRegNew) -> bool:
        """Return True if only registers and constants precede ``node`` in its block.

        Moving such a register again cannot change which gates it separates.
        Skipping it makes repeated applications of the rule reach a fixed point
        instead of recreating the same register every time.
        """
        stmt = node.prev_stmt
        while stmt is not None:
            if not isinstance(stmt, (core.QRegNew, py.Constant)):
                return False
            stmt = stmt.prev_stmt
        return True

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, core.QRegNew):
            return RewriteResult()

        if node.parent_block is None or node.parent_block.first_stmt is None:
            return RewriteResult()

        if self._already_raised(node):
            return RewriteResult()

        first_stmt = node.parent_block.first_stmt

        n_qubits_ref = node.n_qubits

        if isinstance(n_qubits_ref, ir.BlockArgument):
            # Case where n_qubits is a block argument. It is checked on the SSA
            # value itself: a block argument's ``owner`` is its Block, never a
            # BlockArgument. The value is already available at the top of the
            # block and can be reused as-is.
            new_n_qubits_ref = n_qubits_ref
        elif isinstance(n_qubits_ref.owner, py.Constant):
            # Case where n_qubits comes from a constant: clone the constant so
            # that it is defined before the raised register.
            n_qubits = n_qubits_ref.owner
            new_n_qubits = n_qubits.from_stmt(n_qubits)
            new_n_qubits.insert_before(first_stmt)
            new_n_qubits_ref = new_n_qubits.result
        else:
            # The size is computed by some other statement that cannot be
            # moved safely.
            return RewriteResult()

        new_qreg_stmt = core.QRegNew(n_qubits=new_n_qubits_ref)
        new_qreg_stmt.insert_before(first_stmt)
        node.result.replace_by(new_qreg_stmt.result)
        node.delete()
        return RewriteResult(has_done_something=True)
|
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
from typing import Dict, List, Tuple, Iterable
|
|
3
|
+
from dataclasses import field, dataclass
|
|
4
|
+
|
|
5
|
+
from kirin import ir
|
|
6
|
+
from kirin.rewrite import abc as rewrite_abc
|
|
7
|
+
from kirin.dialects import py, ilist
|
|
8
|
+
from kirin.analysis.const import lattice
|
|
9
|
+
|
|
10
|
+
from bloqade.analysis import address
|
|
11
|
+
from bloqade.qasm2.dialects import uop, core, parallel
|
|
12
|
+
from bloqade.squin.analysis.schedule import StmtDag
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class MergePolicyABC(abc.ABC):
    """Interface for policies deciding which gate statements get merged together.

    A concrete policy is constructed with ``from_analysis`` from a statement DAG
    and address-analysis results. It is then invoked on individual statements
    via ``__call__`` to perform the actual rewrite.
    """

    @abc.abstractmethod
    def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
        """Apply the policy to ``node`` and report whether anything changed."""

    @classmethod
    @abc.abstractmethod
    def can_merge(cls, stmt1: ir.Statement, stmt2: ir.Statement) -> bool:
        """Decide whether ``stmt1`` and ``stmt2`` may be fused into one gate."""

    @classmethod
    @abc.abstractmethod
    def merge_gates(
        cls, gate_stmts: Iterable[ir.Statement]
    ) -> List[List[ir.Statement]]:
        """Partition ``gate_stmts`` into groups of mutually mergeable statements."""

    @classmethod
    @abc.abstractmethod
    def from_analysis(
        cls, dag: StmtDag, address_analysis: Dict[ir.SSAValue, address.Address]
    ) -> "MergePolicyABC":
        """Alternate constructor building the policy from analysis results."""
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
|
|
41
|
+
class SimpleMergePolicy(MergePolicyABC):
|
|
42
|
+
"""General merge policy for merging gates based on their type and arguments.
|
|
43
|
+
|
|
44
|
+
Base class to implement a merge policy for CZ, U and RZ gates, To completed the policy implement the
|
|
45
|
+
`merge_gates` class method. This will take an iterable of statements and return a list
|
|
46
|
+
of groups of statements that can be merged together. There are two mix-in classes
|
|
47
|
+
that can be used to implement the `merge_gates` method. The `GreedyMixin` will merge
|
|
48
|
+
gates together greedily, while the `OptimalMixIn` will merge gates together optimally.
|
|
49
|
+
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
address_analysis: Dict[ir.SSAValue, address.Address]
|
|
53
|
+
"""Mapping from SSA values to their address analysis results. Needed for rewrites"""
|
|
54
|
+
merge_groups: List[List[ir.Statement]]
|
|
55
|
+
"""List of groups of statements that can be merged together"""
|
|
56
|
+
group_numbers: Dict[ir.Statement, int]
|
|
57
|
+
"""Mapping from statements to their group number"""
|
|
58
|
+
group_has_merged: Dict[int, bool] = field(default_factory=dict)
|
|
59
|
+
"""Mapping from group number to whether the group has been merged"""
|
|
60
|
+
|
|
61
|
+
@staticmethod
|
|
62
|
+
def same_id_checker(ssa1: ir.SSAValue, ssa2: ir.SSAValue):
|
|
63
|
+
if ssa1 is ssa2:
|
|
64
|
+
return True
|
|
65
|
+
elif (hint1 := ssa1.hints.get("const")) and (hint2 := ssa2.hints.get("const")):
|
|
66
|
+
assert isinstance(hint1, lattice.Result) and isinstance(
|
|
67
|
+
hint2, lattice.Result
|
|
68
|
+
)
|
|
69
|
+
return hint1.is_equal(hint2)
|
|
70
|
+
else:
|
|
71
|
+
return False
|
|
72
|
+
|
|
73
|
+
@classmethod
|
|
74
|
+
def check_equiv_args(
|
|
75
|
+
cls,
|
|
76
|
+
args1: Iterable[ir.SSAValue],
|
|
77
|
+
args2: Iterable[ir.SSAValue],
|
|
78
|
+
):
|
|
79
|
+
try:
|
|
80
|
+
return all(
|
|
81
|
+
cls.same_id_checker(ssa1, ssa2)
|
|
82
|
+
for ssa1, ssa2 in zip(args1, args2, strict=True)
|
|
83
|
+
)
|
|
84
|
+
except ValueError:
|
|
85
|
+
return False
|
|
86
|
+
|
|
87
|
+
@classmethod
|
|
88
|
+
def can_merge(cls, stmt1: ir.Statement, stmt2: ir.Statement) -> bool:
|
|
89
|
+
match stmt1, stmt2:
|
|
90
|
+
case (
|
|
91
|
+
(uop.UGate(), uop.UGate())
|
|
92
|
+
| (uop.RZ(), uop.RZ())
|
|
93
|
+
| (parallel.UGate(), parallel.UGate())
|
|
94
|
+
| (parallel.UGate(), uop.UGate())
|
|
95
|
+
| (uop.UGate(), parallel.UGate())
|
|
96
|
+
| (uop.UGate(), parallel.UGate())
|
|
97
|
+
| (uop.UGate(), parallel.UGate())
|
|
98
|
+
| (parallel.RZ(), parallel.RZ())
|
|
99
|
+
| (uop.RZ(), parallel.RZ())
|
|
100
|
+
| (parallel.RZ(), uop.RZ())
|
|
101
|
+
):
|
|
102
|
+
return cls.check_equiv_args(stmt1.args[1:], stmt2.args[1:])
|
|
103
|
+
case (
|
|
104
|
+
(parallel.CZ(), parallel.CZ())
|
|
105
|
+
| (parallel.CZ(), uop.CZ())
|
|
106
|
+
| (uop.CZ(), parallel.CZ())
|
|
107
|
+
| (uop.CZ(), uop.CZ())
|
|
108
|
+
| (uop.Barrier(), uop.Barrier())
|
|
109
|
+
):
|
|
110
|
+
return True
|
|
111
|
+
|
|
112
|
+
case _:
|
|
113
|
+
return False
|
|
114
|
+
|
|
115
|
+
@classmethod
|
|
116
|
+
def from_analysis(
|
|
117
|
+
cls,
|
|
118
|
+
dag: StmtDag,
|
|
119
|
+
address_analysis: Dict[ir.SSAValue, address.Address],
|
|
120
|
+
):
|
|
121
|
+
|
|
122
|
+
merge_groups = []
|
|
123
|
+
group_numbers = {}
|
|
124
|
+
|
|
125
|
+
for group in dag.topological_groups():
|
|
126
|
+
gate_groups = cls.merge_gates(map(dag.stmts.__getitem__, group))
|
|
127
|
+
gate_groups_iter = (group for group in gate_groups if len(group) > 1)
|
|
128
|
+
|
|
129
|
+
for gate_group in gate_groups_iter:
|
|
130
|
+
group_number = len(merge_groups)
|
|
131
|
+
merge_groups.append(gate_group)
|
|
132
|
+
for stmt in gate_group:
|
|
133
|
+
group_numbers[stmt] = group_number
|
|
134
|
+
|
|
135
|
+
for group in merge_groups:
|
|
136
|
+
group.sort(key=lambda stmt: dag.stmt_index[stmt])
|
|
137
|
+
|
|
138
|
+
return cls(
|
|
139
|
+
address_analysis=address_analysis,
|
|
140
|
+
merge_groups=merge_groups,
|
|
141
|
+
group_numbers=group_numbers,
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
|
|
145
|
+
|
|
146
|
+
if node not in self.group_numbers:
|
|
147
|
+
return rewrite_abc.RewriteResult()
|
|
148
|
+
|
|
149
|
+
group_number = self.group_numbers[node]
|
|
150
|
+
group = self.merge_groups[group_number]
|
|
151
|
+
if node is group[0]:
|
|
152
|
+
result = getattr(self, f"rewrite_group_{node.name}")(node, group)
|
|
153
|
+
|
|
154
|
+
self.group_has_merged[group_number] = result.has_done_something
|
|
155
|
+
return result
|
|
156
|
+
|
|
157
|
+
if self.group_has_merged[group_number]:
|
|
158
|
+
node.delete()
|
|
159
|
+
|
|
160
|
+
return rewrite_abc.RewriteResult(
|
|
161
|
+
has_done_something=self.group_has_merged[group_number]
|
|
162
|
+
)
|
|
163
|
+
|
|
164
|
+
def move_and_collect_qubit_list(
|
|
165
|
+
self, qargs: List[ir.SSAValue], node: ir.Statement
|
|
166
|
+
) -> Tuple[ir.SSAValue, ...] | None:
|
|
167
|
+
|
|
168
|
+
qubits: List[ir.SSAValue] = []
|
|
169
|
+
# collect references to qubits
|
|
170
|
+
for qarg in qargs:
|
|
171
|
+
addr = self.address_analysis[qarg]
|
|
172
|
+
|
|
173
|
+
if isinstance(addr, address.AddressQubit):
|
|
174
|
+
qubits.append(qarg)
|
|
175
|
+
|
|
176
|
+
elif isinstance(addr, address.AddressTuple):
|
|
177
|
+
assert isinstance(qarg, ir.ResultValue)
|
|
178
|
+
assert isinstance(qarg.stmt, ilist.New)
|
|
179
|
+
qubits.extend(qarg.stmt.values)
|
|
180
|
+
else:
|
|
181
|
+
# give up if we cannot determine the address
|
|
182
|
+
return None
|
|
183
|
+
|
|
184
|
+
new_qubits = []
|
|
185
|
+
|
|
186
|
+
# the registers must be moved to the top of the block
|
|
187
|
+
# before this pass can be applied
|
|
188
|
+
for qubit_ref in qubits:
|
|
189
|
+
qubit = qubit_ref.owner
|
|
190
|
+
match qubit:
|
|
191
|
+
case ir.BlockArgument(): # do not need to move the qubit
|
|
192
|
+
new_qubits.append(qubit)
|
|
193
|
+
case core.QRegGet(reg=reg, idx=ir.BlockArgument() as idx):
|
|
194
|
+
new_qubit = core.QRegGet(reg=reg, idx=idx)
|
|
195
|
+
new_qubit.insert_before(node)
|
|
196
|
+
new_qubits.append(new_qubit.result)
|
|
197
|
+
case core.QRegGet(
|
|
198
|
+
reg=reg, idx=ir.ResultValue(stmt=py.Constant() as idx)
|
|
199
|
+
):
|
|
200
|
+
(new_idx := idx.from_stmt(idx)).insert_before(node)
|
|
201
|
+
(
|
|
202
|
+
new_qubit := core.QRegGet(reg=reg, idx=new_idx.result)
|
|
203
|
+
).insert_before(node)
|
|
204
|
+
new_qubits.append(new_qubit.result)
|
|
205
|
+
case _:
|
|
206
|
+
return None
|
|
207
|
+
|
|
208
|
+
return tuple(new_qubits)
|
|
209
|
+
|
|
210
|
+
def rewrite_group_cz(self, node: ir.Statement, group: List[ir.Statement]):
|
|
211
|
+
ctrls = []
|
|
212
|
+
qargs = []
|
|
213
|
+
|
|
214
|
+
for stmt in group:
|
|
215
|
+
if isinstance(stmt, uop.CZ):
|
|
216
|
+
ctrls.append(stmt.ctrl)
|
|
217
|
+
qargs.append(stmt.qarg)
|
|
218
|
+
elif isinstance(stmt, parallel.CZ):
|
|
219
|
+
ctrls.append(stmt.ctrls)
|
|
220
|
+
qargs.append(stmt.qargs)
|
|
221
|
+
else:
|
|
222
|
+
return rewrite_abc.RewriteResult(has_done_something=False)
|
|
223
|
+
|
|
224
|
+
ctrls_values = self.move_and_collect_qubit_list(ctrls, node)
|
|
225
|
+
qargs_values = self.move_and_collect_qubit_list(qargs, node)
|
|
226
|
+
|
|
227
|
+
if ctrls_values is None or qargs_values is None:
|
|
228
|
+
# give up if we cannot determine the address or cannot move the qubits
|
|
229
|
+
return rewrite_abc.RewriteResult(has_done_something=False)
|
|
230
|
+
|
|
231
|
+
new_ctrls = ilist.New(values=ctrls_values)
|
|
232
|
+
new_qargs = ilist.New(values=qargs_values)
|
|
233
|
+
new_gate = parallel.CZ(ctrls=new_ctrls.result, qargs=new_qargs.result)
|
|
234
|
+
|
|
235
|
+
new_ctrls.insert_before(node)
|
|
236
|
+
new_qargs.insert_before(node)
|
|
237
|
+
new_gate.insert_before(node)
|
|
238
|
+
|
|
239
|
+
node.delete()
|
|
240
|
+
|
|
241
|
+
return rewrite_abc.RewriteResult(has_done_something=True)
|
|
242
|
+
|
|
243
|
+
def rewrite_group_U(self, node: ir.Statement, group: List[ir.Statement]):
|
|
244
|
+
return self.rewrite_group_u(node, group)
|
|
245
|
+
|
|
246
|
+
def rewrite_group_u(self, node: ir.Statement, group: List[ir.Statement]):
|
|
247
|
+
qargs = []
|
|
248
|
+
|
|
249
|
+
for stmt in group:
|
|
250
|
+
if isinstance(stmt, uop.UGate):
|
|
251
|
+
qargs.append(stmt.qarg)
|
|
252
|
+
elif isinstance(stmt, parallel.UGate):
|
|
253
|
+
qargs.append(stmt.qargs)
|
|
254
|
+
else:
|
|
255
|
+
return rewrite_abc.RewriteResult(has_done_something=False)
|
|
256
|
+
|
|
257
|
+
assert isinstance(node, (uop.UGate, parallel.UGate))
|
|
258
|
+
qargs_values = self.move_and_collect_qubit_list(qargs, node)
|
|
259
|
+
|
|
260
|
+
if qargs_values is None:
|
|
261
|
+
return rewrite_abc.RewriteResult(has_done_something=False)
|
|
262
|
+
|
|
263
|
+
new_qargs = ilist.New(values=qargs_values)
|
|
264
|
+
new_gate = parallel.UGate(
|
|
265
|
+
qargs=new_qargs.result,
|
|
266
|
+
theta=node.theta,
|
|
267
|
+
phi=node.phi,
|
|
268
|
+
lam=node.lam,
|
|
269
|
+
)
|
|
270
|
+
new_qargs.insert_before(node)
|
|
271
|
+
new_gate.insert_before(node)
|
|
272
|
+
node.delete()
|
|
273
|
+
|
|
274
|
+
return rewrite_abc.RewriteResult(has_done_something=True)
|
|
275
|
+
|
|
276
|
+
def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]):
|
|
277
|
+
qargs = []
|
|
278
|
+
|
|
279
|
+
for stmt in group:
|
|
280
|
+
if isinstance(stmt, uop.RZ):
|
|
281
|
+
qargs.append(stmt.qarg)
|
|
282
|
+
elif isinstance(stmt, parallel.RZ):
|
|
283
|
+
qargs.append(stmt.qargs)
|
|
284
|
+
else:
|
|
285
|
+
return rewrite_abc.RewriteResult(has_done_something=False)
|
|
286
|
+
|
|
287
|
+
assert isinstance(node, (uop.RZ, parallel.RZ))
|
|
288
|
+
|
|
289
|
+
qargs_values = self.move_and_collect_qubit_list(qargs, node)
|
|
290
|
+
|
|
291
|
+
if qargs_values is None:
|
|
292
|
+
return rewrite_abc.RewriteResult(has_done_something=False)
|
|
293
|
+
|
|
294
|
+
new_qargs = ilist.New(values=qargs_values)
|
|
295
|
+
new_gate = parallel.RZ(
|
|
296
|
+
qargs=new_qargs.result,
|
|
297
|
+
theta=node.theta,
|
|
298
|
+
)
|
|
299
|
+
new_qargs.insert_before(node)
|
|
300
|
+
new_gate.insert_before(node)
|
|
301
|
+
node.delete()
|
|
302
|
+
|
|
303
|
+
return rewrite_abc.RewriteResult(has_done_something=True)
|
|
304
|
+
|
|
305
|
+
def rewrite_group_barrier(self, node: uop.Barrier, group: List[uop.Barrier]):
|
|
306
|
+
qargs = []
|
|
307
|
+
for stmt in group:
|
|
308
|
+
qargs.extend(stmt.qargs)
|
|
309
|
+
|
|
310
|
+
qargs_values = self.move_and_collect_qubit_list(qargs, node)
|
|
311
|
+
|
|
312
|
+
if qargs_values is None:
|
|
313
|
+
return rewrite_abc.RewriteResult(has_done_something=False)
|
|
314
|
+
|
|
315
|
+
new_node = uop.Barrier(qargs=qargs_values)
|
|
316
|
+
new_node.insert_before(node)
|
|
317
|
+
node.delete()
|
|
318
|
+
|
|
319
|
+
return rewrite_abc.RewriteResult(has_done_something=True)
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
class GreedyMixin(MergePolicyABC):
    """Merge policy that greedily merges gates together.

    ``merge_gates`` makes a single pass over the statements. Each statement is
    appended to the most recent group when it can be merged with that group's
    last statement; otherwise it starts a new group. This policy has a worst
    case complexity of O(n) calls to ``can_merge``, where n is the number of
    gates in the input iterable.
    """

    @classmethod
    def merge_gates(
        cls, gate_stmts: Iterable[ir.Statement]
    ) -> List[List[ir.Statement]]:
        """Group consecutive mergeable statements, preserving input order.

        Works for any iterable, including one-shot iterators and re-iterable
        containers such as lists. Each statement appears in exactly one group.
        An empty input yields an empty list.
        """
        groups: List[List[ir.Statement]] = []

        for stmt in gate_stmts:
            if groups and cls.can_merge(groups[-1][-1], stmt):
                groups[-1].append(stmt)
            else:
                groups.append([stmt])

        return groups
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
class OptimalMixIn(MergePolicyABC):
    """Merge policy that tries to merge each gate into any existing group.

    ``merge_gates`` compares every gate against the last statement of each group
    built so far, in creation order. The gate is appended to the first group it
    can be merged with; if none matches, it starts a new group. Unlike
    ``GreedyMixin``, compatible gates need not be adjacent to end up in the same
    group.

    This is first-fit grouping. It produces the minimum number of groups when
    ``can_merge`` behaves like an equivalence relation. The worst case
    complexity is O(n^2) calls to ``can_merge``, where n is the number of gates
    in the input iterable.
    """

    @classmethod
    def merge_gates(
        cls, gate_stmts: Iterable[ir.Statement]
    ) -> List[List[ir.Statement]]:
        """Partition ``gate_stmts`` into groups using first-fit matching."""
        groups: List[List[ir.Statement]] = []

        for stmt in gate_stmts:
            for group in groups:
                if cls.can_merge(group[-1], stmt):
                    group.append(stmt)
                    break
            else:
                # no existing group accepted the statement
                groups.append([stmt])

        return groups
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
@dataclass
class SimpleGreedyMergePolicy(GreedyMixin, SimpleMergePolicy):
    """``SimpleMergePolicy`` whose ``merge_gates`` is supplied by ``GreedyMixin``."""
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
@dataclass
class SimpleOptimalMergePolicy(OptimalMixIn, SimpleMergePolicy):
    """``SimpleMergePolicy`` whose ``merge_gates`` is supplied by ``OptimalMixIn``."""
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
@dataclass
class UOpToParallelRule(rewrite_abc.RewriteRule):
    """Delegate each statement to the merge policy registered for its parent block.

    Statements whose parent block has no registered policy are left unchanged.
    """

    # merge policy to apply, keyed by the block containing the statement
    merge_rewriters: Dict[ir.Block | None, MergePolicyABC]

    def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
        block = node.parent_block
        if block not in self.merge_rewriters:
            return rewrite_abc.RewriteResult()
        return self.merge_rewriters[block](node)
|
bloqade/qasm2/types.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from kirin import types
|
|
2
|
+
|
|
3
|
+
from bloqade.types import Qubit as Qubit, QubitType as QubitType
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Bit:
    """Runtime stand-in for a single classical bit.

    Note:
        Serves as the common base for more specific bit types, for example a
        reference into a classical register used by some quantum register
        dialects.
    """
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class QReg:
    """Runtime stand-in for a quantum register; indexing works only inside a kernel."""

    def __getitem__(self, index) -> Qubit:
        # Called from plain Python, indexing always fails.
        message = "cannot call __getitem__ outside of a kernel"
        raise NotImplementedError(message)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class CReg:
    """Runtime stand-in for a classical register; indexing works only inside a kernel."""

    def __getitem__(self, index) -> Bit:
        # Called from plain Python, indexing always fails.
        message = "cannot call __getitem__ outside of a kernel"
        raise NotImplementedError(message)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Kirin type objects wrapping the runtime classes defined above.
BitType = types.PyClass(Bit)
"""Kirin type of a classical bit (``Bit``)."""

QRegType = types.PyClass(QReg)
"""Kirin type of a quantum register (``QReg``)."""

CRegType = types.PyClass(CReg)
"""Kirin type of a classical register (``CReg``)."""
|