bloqade-circuit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic. Click here for more details.
- bloqade/analysis/__init__.py +0 -0
- bloqade/analysis/address/__init__.py +11 -0
- bloqade/analysis/address/analysis.py +60 -0
- bloqade/analysis/address/impls.py +228 -0
- bloqade/analysis/address/lattice.py +85 -0
- bloqade/noise/__init__.py +1 -0
- bloqade/noise/native/__init__.py +20 -0
- bloqade/noise/native/_dialect.py +3 -0
- bloqade/noise/native/_wrappers.py +34 -0
- bloqade/noise/native/model.py +347 -0
- bloqade/noise/native/rewrite.py +35 -0
- bloqade/noise/native/stmts.py +46 -0
- bloqade/pyqrack/__init__.py +18 -0
- bloqade/pyqrack/base.py +131 -0
- bloqade/pyqrack/noise/__init__.py +0 -0
- bloqade/pyqrack/noise/native.py +100 -0
- bloqade/pyqrack/qasm2/__init__.py +0 -0
- bloqade/pyqrack/qasm2/core.py +79 -0
- bloqade/pyqrack/qasm2/parallel.py +46 -0
- bloqade/pyqrack/qasm2/uop.py +247 -0
- bloqade/pyqrack/reg.py +109 -0
- bloqade/pyqrack/target.py +112 -0
- bloqade/qasm2/__init__.py +19 -0
- bloqade/qasm2/_wrappers.py +674 -0
- bloqade/qasm2/dialects/__init__.py +10 -0
- bloqade/qasm2/dialects/core/__init__.py +3 -0
- bloqade/qasm2/dialects/core/_dialect.py +3 -0
- bloqade/qasm2/dialects/core/_emit.py +68 -0
- bloqade/qasm2/dialects/core/_typeinfer.py +23 -0
- bloqade/qasm2/dialects/core/address.py +38 -0
- bloqade/qasm2/dialects/core/stmts.py +94 -0
- bloqade/qasm2/dialects/expr/__init__.py +3 -0
- bloqade/qasm2/dialects/expr/_dialect.py +3 -0
- bloqade/qasm2/dialects/expr/_emit.py +103 -0
- bloqade/qasm2/dialects/expr/_from_python.py +86 -0
- bloqade/qasm2/dialects/expr/_interp.py +75 -0
- bloqade/qasm2/dialects/expr/stmts.py +262 -0
- bloqade/qasm2/dialects/glob.py +45 -0
- bloqade/qasm2/dialects/indexing.py +64 -0
- bloqade/qasm2/dialects/inline.py +76 -0
- bloqade/qasm2/dialects/noise.py +16 -0
- bloqade/qasm2/dialects/parallel.py +110 -0
- bloqade/qasm2/dialects/uop/__init__.py +4 -0
- bloqade/qasm2/dialects/uop/_dialect.py +3 -0
- bloqade/qasm2/dialects/uop/_emit.py +211 -0
- bloqade/qasm2/dialects/uop/schedule.py +89 -0
- bloqade/qasm2/dialects/uop/stmts.py +325 -0
- bloqade/qasm2/emit/__init__.py +1 -0
- bloqade/qasm2/emit/base.py +72 -0
- bloqade/qasm2/emit/gate.py +102 -0
- bloqade/qasm2/emit/main.py +106 -0
- bloqade/qasm2/emit/target.py +165 -0
- bloqade/qasm2/glob.py +24 -0
- bloqade/qasm2/groups.py +120 -0
- bloqade/qasm2/parallel.py +48 -0
- bloqade/qasm2/parse/__init__.py +37 -0
- bloqade/qasm2/parse/ast.py +235 -0
- bloqade/qasm2/parse/build.py +289 -0
- bloqade/qasm2/parse/lowering.py +553 -0
- bloqade/qasm2/parse/parser.py +5 -0
- bloqade/qasm2/parse/print.py +293 -0
- bloqade/qasm2/parse/qasm2.lark +75 -0
- bloqade/qasm2/parse/visitor.py +16 -0
- bloqade/qasm2/parse/visitor.pyi +39 -0
- bloqade/qasm2/passes/__init__.py +5 -0
- bloqade/qasm2/passes/fold.py +94 -0
- bloqade/qasm2/passes/glob.py +119 -0
- bloqade/qasm2/passes/noise.py +61 -0
- bloqade/qasm2/passes/parallel.py +176 -0
- bloqade/qasm2/passes/py2qasm.py +63 -0
- bloqade/qasm2/passes/qasm2py.py +61 -0
- bloqade/qasm2/rewrite/__init__.py +12 -0
- bloqade/qasm2/rewrite/desugar.py +28 -0
- bloqade/qasm2/rewrite/glob.py +103 -0
- bloqade/qasm2/rewrite/heuristic_noise.py +247 -0
- bloqade/qasm2/rewrite/native_gates.py +447 -0
- bloqade/qasm2/rewrite/parallel_to_uop.py +83 -0
- bloqade/qasm2/rewrite/register.py +45 -0
- bloqade/qasm2/rewrite/uop_to_parallel.py +395 -0
- bloqade/qasm2/types.py +39 -0
- bloqade/qbraid/__init__.py +2 -0
- bloqade/qbraid/lowering.py +324 -0
- bloqade/qbraid/schema.py +252 -0
- bloqade/qbraid/simulation_result.py +99 -0
- bloqade/qbraid/target.py +86 -0
- bloqade/squin/__init__.py +2 -0
- bloqade/squin/analysis/__init__.py +0 -0
- bloqade/squin/analysis/nsites/__init__.py +8 -0
- bloqade/squin/analysis/nsites/analysis.py +52 -0
- bloqade/squin/analysis/nsites/impls.py +69 -0
- bloqade/squin/analysis/nsites/lattice.py +49 -0
- bloqade/squin/analysis/schedule.py +244 -0
- bloqade/squin/groups.py +38 -0
- bloqade/squin/op/__init__.py +132 -0
- bloqade/squin/op/_dialect.py +3 -0
- bloqade/squin/op/complex.py +6 -0
- bloqade/squin/op/stmts.py +220 -0
- bloqade/squin/op/traits.py +43 -0
- bloqade/squin/op/types.py +10 -0
- bloqade/squin/qubit.py +118 -0
- bloqade/squin/wire.py +103 -0
- bloqade/stim/__init__.py +6 -0
- bloqade/stim/_wrappers.py +186 -0
- bloqade/stim/dialects/__init__.py +5 -0
- bloqade/stim/dialects/aux/__init__.py +11 -0
- bloqade/stim/dialects/aux/_dialect.py +3 -0
- bloqade/stim/dialects/aux/emit.py +102 -0
- bloqade/stim/dialects/aux/interp.py +39 -0
- bloqade/stim/dialects/aux/lowering.py +40 -0
- bloqade/stim/dialects/aux/stmts/__init__.py +14 -0
- bloqade/stim/dialects/aux/stmts/annotate.py +47 -0
- bloqade/stim/dialects/aux/stmts/const.py +95 -0
- bloqade/stim/dialects/aux/types.py +19 -0
- bloqade/stim/dialects/collapse/__init__.py +3 -0
- bloqade/stim/dialects/collapse/_dialect.py +3 -0
- bloqade/stim/dialects/collapse/emit.py +68 -0
- bloqade/stim/dialects/collapse/stmts/__init__.py +3 -0
- bloqade/stim/dialects/collapse/stmts/measure.py +45 -0
- bloqade/stim/dialects/collapse/stmts/pp_measure.py +14 -0
- bloqade/stim/dialects/collapse/stmts/reset.py +26 -0
- bloqade/stim/dialects/gate/__init__.py +3 -0
- bloqade/stim/dialects/gate/_dialect.py +3 -0
- bloqade/stim/dialects/gate/emit.py +87 -0
- bloqade/stim/dialects/gate/stmts/__init__.py +14 -0
- bloqade/stim/dialects/gate/stmts/base.py +31 -0
- bloqade/stim/dialects/gate/stmts/clifford_1q.py +53 -0
- bloqade/stim/dialects/gate/stmts/clifford_2q.py +11 -0
- bloqade/stim/dialects/gate/stmts/control_2q.py +21 -0
- bloqade/stim/dialects/gate/stmts/pp.py +15 -0
- bloqade/stim/dialects/noise/__init__.py +3 -0
- bloqade/stim/dialects/noise/_dialect.py +3 -0
- bloqade/stim/dialects/noise/emit.py +66 -0
- bloqade/stim/dialects/noise/stmts.py +77 -0
- bloqade/stim/emit/__init__.py +1 -0
- bloqade/stim/emit/stim.py +54 -0
- bloqade/stim/groups.py +26 -0
- bloqade/test_utils.py +35 -0
- bloqade/types.py +24 -0
- bloqade/visual/__init__.py +1 -0
- bloqade/visual/animation/__init__.py +0 -0
- bloqade/visual/animation/animate.py +267 -0
- bloqade/visual/animation/base.py +346 -0
- bloqade/visual/animation/gate_event.py +24 -0
- bloqade/visual/animation/runtime/__init__.py +0 -0
- bloqade/visual/animation/runtime/aod.py +36 -0
- bloqade/visual/animation/runtime/atoms.py +55 -0
- bloqade/visual/animation/runtime/ppoly.py +50 -0
- bloqade/visual/animation/runtime/qpustate.py +119 -0
- bloqade/visual/animation/runtime/utils.py +43 -0
- bloqade_circuit-0.1.0.dist-info/METADATA +70 -0
- bloqade_circuit-0.1.0.dist-info/RECORD +153 -0
- bloqade_circuit-0.1.0.dist-info/WHEEL +4 -0
- bloqade_circuit-0.1.0.dist-info/licenses/LICENSE +234 -0
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from dataclasses import field, dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.passes import Pass
|
|
5
|
+
from kirin.rewrite import (
|
|
6
|
+
Walk,
|
|
7
|
+
Chain,
|
|
8
|
+
Fixpoint,
|
|
9
|
+
ConstantFold,
|
|
10
|
+
DeadCodeElimination,
|
|
11
|
+
CommonSubexpressionElimination,
|
|
12
|
+
)
|
|
13
|
+
from kirin.rewrite.result import RewriteResult
|
|
14
|
+
|
|
15
|
+
from bloqade.noise import native
|
|
16
|
+
from bloqade.analysis import address
|
|
17
|
+
from bloqade.qasm2.rewrite.heuristic_noise import NoiseRewriteRule
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class NoisePass(Pass):
    """Apply a noise model to a quantum circuit.

    NOTE: This pass is not guaranteed to be supported long-term in bloqade. We will be
    moving towards a more general approach to noise modeling in the future.

    The pass runs the address analysis on the method, walks the method body
    with a `NoiseRewriteRule` configured from the fields below, and then
    repeatedly applies constant folding, dead code elimination and common
    subexpression elimination until nothing changes.
    """

    # Noise model forwarded to `NoiseRewriteRule`; a fresh
    # `native.TwoRowZoneModel()` is created when none is given.
    noise_model: native.MoveNoiseModelABC = field(
        default_factory=native.TwoRowZoneModel
    )
    # Gate noise parameters forwarded to `NoiseRewriteRule`; a fresh
    # `native.GateNoiseParams()` is created when none is given.
    gate_noise_params: native.GateNoiseParams = field(
        default_factory=native.GateNoiseParams
    )
    # Not an `__init__` argument: built in `__post_init__` from `self.dialects`.
    address_analysis: address.AddressAnalysis = field(init=False)

    def __post_init__(self):
        """Create the address analysis for the dialects this pass was built with."""
        self.address_analysis = address.AddressAnalysis(self.dialects)

    def unsafe_run(self, mt: ir.Method):
        """Insert noise into `mt` and clean up the rewritten code.

        Args:
            mt: The method whose code is rewritten in place.

        Returns:
            The rewrite result of the clean-up stage joined with the result
            of the noise rewrite.
        """
        # Only the analysis frame is needed; the second element of the
        # returned pair is not used by this pass.
        analysis_frame, _ = self.address_analysis.run_analysis(mt, no_raise=False)

        noise_rule = NoiseRewriteRule(
            address_analysis=analysis_frame.entries,
            noise_model=self.noise_model,
            gate_noise_params=self.gate_noise_params,
        )
        outcome = Walk(noise_rule).rewrite(mt.code).join(RewriteResult())

        cleanup = Chain(
            ConstantFold(),
            DeadCodeElimination(),
            CommonSubexpressionElimination(),
        )
        return Fixpoint(Walk(cleanup)).rewrite(mt.code).join(outcome)
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Passes for converting parallel gates into multiple single gates as well as
|
|
3
|
+
converting multiple single gates to parallel gates.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Type
|
|
7
|
+
from dataclasses import field, dataclass
|
|
8
|
+
|
|
9
|
+
from kirin import ir
|
|
10
|
+
from kirin.passes import Pass
|
|
11
|
+
from kirin.rewrite import (
|
|
12
|
+
Walk,
|
|
13
|
+
Chain,
|
|
14
|
+
Fixpoint,
|
|
15
|
+
WrapConst,
|
|
16
|
+
ConstantFold,
|
|
17
|
+
DeadCodeElimination,
|
|
18
|
+
CommonSubexpressionElimination,
|
|
19
|
+
result,
|
|
20
|
+
)
|
|
21
|
+
from kirin.analysis import const
|
|
22
|
+
|
|
23
|
+
from bloqade.analysis import address
|
|
24
|
+
from bloqade.qasm2.rewrite import (
|
|
25
|
+
MergePolicyABC,
|
|
26
|
+
ParallelToUOpRule,
|
|
27
|
+
RaiseRegisterRule,
|
|
28
|
+
UOpToParallelRule,
|
|
29
|
+
SimpleOptimalMergePolicy,
|
|
30
|
+
)
|
|
31
|
+
from bloqade.squin.analysis import schedule
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class ParallelToUOp(Pass):
    """Pass to convert parallel gates into single gates.

    Every parallel gate from the `qasm2.parallel` dialect is rewritten into
    multiple single gates from the `qasm2.uop` dialect, which brings the
    program closer to conforming to standard QASM2 syntax.

    ## Usage Examples
    ```
    # Define kernel
    @qasm2.extended
    def main():
        q = qasm2.qreg(4)

        qasm2.parallel.cz(ctrls=[q[0], q[2]], qargs=[q[1], q[3]])

    # Run rewrite
    ParallelToUOp(main.dialects)(main)
    ```

    The `qasm2.parallel.cz` statement has been rewritten to individual gates:

    ```
    qasm2.uop.cz(ctrl=q[0], qarg=q[1])
    qasm2.uop.cz(ctrl=q[2], qarg=q[3])
    ```

    """

    def generate_rule(self, mt: ir.Method) -> ParallelToUOpRule:
        """Build the rewrite rule for `mt` from a fresh address analysis.

        Args:
            mt: The method to analyse; the analysis uses `mt.dialects`.

        Returns:
            A `ParallelToUOpRule` holding the analysis entries and a map from
            each qubit id to the first SSA value the analysis associates
            with that qubit.
        """
        analysis_frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)

        first_ssa_by_qubit = {}
        for ssa_value, addr in analysis_frame.entries.items():
            # Only entries that resolve to a single qubit are of interest.
            if isinstance(addr, address.AddressQubit):
                # `setdefault` keeps the first SSA value seen for a qubit id
                # and ignores every later one.
                first_ssa_by_qubit.setdefault(addr.data, ssa_value)

        return ParallelToUOpRule(
            id_map=first_ssa_by_qubit, address_analysis=analysis_frame.entries
        )

    def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
        """Rewrite the parallel gates in `mt` and clean up afterwards.

        Args:
            mt: The method whose code is rewritten in place.

        Returns:
            The result of the clean-up stage joined with the result of the
            parallel-to-uop rewrite.
        """
        expanded = Walk(self.generate_rule(mt)).rewrite(mt.code)

        cleanup = Chain(
            ConstantFold(),
            DeadCodeElimination(),
            CommonSubexpressionElimination(),
        )
        return Fixpoint(Walk(cleanup)).rewrite(mt.code).join(expanded)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@dataclass
class UOpToParallel(Pass):
    """Pass to convert single gates into parallel gates.

    The pass searches for single gates from the `qasm2.uop` dialect that can
    be combined into parallel gates from the `qasm2.parallel` dialect and
    rewrites them accordingly.

    ## Usage Examples
    ```
    # Define kernel
    @qasm2.main
    def test():
        q = qasm2.qreg(4)

        theta = 0.1
        phi = 0.2
        lam = 0.3

        qasm2.u(q[1], theta, phi, lam)
        qasm2.u(q[3], theta, phi, lam)
        qasm2.cx(q[1], q[3])
        qasm2.u(q[2], theta, phi, lam)
        qasm2.u(q[0], theta, phi, lam)
        qasm2.cx(q[0], q[2])

    # Run rewrite
    UOpToParallel(test.dialects)(test)
    ```

    The individual `qasm2.u` statements have now been combined
    into a single `qasm2.parallel.u` statement.

    ```
    qasm2.parallel.u(qargs = [q[0], q[1], q[2], q[3]], theta, phi, lam)
    qasm2.uop.CX(q[1], q[3])
    qasm2.uop.CX(q[0], q[2])
    ```

    """

    # Policy class; its `from_analysis` classmethod is called once per block.
    merge_policy_type: Type[MergePolicyABC] = SimpleOptimalMergePolicy
    # Not an `__init__` argument: built in `__post_init__` from `self.dialects`.
    constprop: const.Propagate = field(init=False)

    def __post_init__(self):
        """Create the constant propagation analysis for this pass's dialects."""
        self.constprop = const.Propagate(self.dialects)

    def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
        """Merge compatible single gates of `mt` into parallel gates.

        Args:
            mt: The method whose code is rewritten in place.

        Returns:
            The joined result of all rewrite stages, or the result of the
            register-raising stage alone when that stage changed nothing.
        """
        outcome = Walk(RaiseRegisterRule()).rewrite(mt.code)

        # do not run the parallelization because registers are not at the top
        if not outcome.has_done_something:
            return outcome

        const_frame, _ = self.constprop.run_analysis(mt)
        outcome = Walk(WrapConst(const_frame)).rewrite(mt.code).join(outcome)

        # The address analysis is run after the rewrites above have been
        # applied to the code.
        addr_frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
        scheduler = schedule.DagScheduleAnalysis(
            mt.dialects, address_analysis=addr_frame.entries
        )
        dags = scheduler.get_dags(mt)

        # One merge policy per block, each built from that block's DAG.
        policies = {
            block: self.merge_policy_type.from_analysis(dag, addr_frame.entries)
            for block, dag in dags.items()
        }
        outcome = Walk(UOpToParallelRule(policies)).rewrite(mt.code).join(outcome)

        cleanup = Chain(
            ConstantFold(),
            DeadCodeElimination(),
            CommonSubexpressionElimination(),
        )
        return Fixpoint(Walk(cleanup)).rewrite(mt.code).join(outcome)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""Rewrite py dialects into qasm dialects."""
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.passes import Pass
|
|
5
|
+
from kirin.rewrite import Walk, Fixpoint
|
|
6
|
+
from kirin.dialects import py, math
|
|
7
|
+
from kirin.rewrite.abc import RewriteRule
|
|
8
|
+
from kirin.rewrite.result import RewriteResult
|
|
9
|
+
|
|
10
|
+
from bloqade.qasm2.dialects import core, expr
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _Py2QASM(RewriteRule):
    """Rewrite py dialects into qasm dialects.

    Constants, arithmetic, supported math functions, `==` comparisons and
    aliases coming from the Python-level dialects are replaced by their
    `qasm2` counterparts. Statements without a counterpart are left alone.
    """

    # Single-operand statements and the `qasm2.expr` statement replacing them.
    # The math entries are keyed by the statement classes in `math.stmts`
    # (the same classes the reverse `_QASM2Py` rule constructs) so that
    # `type(node)` can actually match them.
    UNARY_OPS = {
        py.USub: expr.Neg,
        math.stmts.sin: expr.Sin,
        math.stmts.cos: expr.Cos,
        math.stmts.tan: expr.Tan,
        math.stmts.exp: expr.Exp,
        math.stmts.sqrt: expr.Sqrt,
    }

    # Two-operand statements and the `qasm2.expr` statement replacing them.
    BINARY_OPS = {
        py.Add: expr.Add,
        py.Sub: expr.Sub,
        py.Mult: expr.Mul,
        py.Div: expr.Div,
        py.Pow: expr.Pow,
    }

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Replace `node` by its qasm2 equivalent when one exists.

        Args:
            node: The statement visited by the rewrite driver.

        Returns:
            A result with `has_done_something=True` when `node` was replaced
            or removed, otherwise an empty `RewriteResult`.
        """
        if isinstance(node, py.Constant):
            value = node.value.unwrap()
            # NOTE(review): `bool` is a subclass of `int`, so boolean
            # constants also take the `ConstInt` branch - confirm intended.
            if isinstance(value, int):
                node.replace_by(expr.ConstInt(value=value))
                return RewriteResult(has_done_something=True)
            elif isinstance(value, float):
                node.replace_by(expr.ConstFloat(value=value))
                return RewriteResult(has_done_something=True)
        elif isinstance(node, py.BinOp):
            if (qasm_stmt := self.BINARY_OPS.get(type(node))) is not None:
                node.replace_by(qasm_stmt(node.lhs, node.rhs))
                return RewriteResult(has_done_something=True)
        elif isinstance(node, py.UnaryOp):
            if (qasm_stmt := self.UNARY_OPS.get(type(node))) is not None:
                node.replace_by(qasm_stmt(node.value))
                return RewriteResult(has_done_something=True)
        elif isinstance(node, py.cmp.Eq):
            node.replace_by(core.CRegEq(node.lhs, node.rhs))
            return RewriteResult(has_done_something=True)
        elif isinstance(node, py.assign.Alias):
            # An alias only renames a value: forward its uses and drop it.
            node.result.replace_by(node.value)
            node.delete()
            return RewriteResult(has_done_something=True)
        elif (qasm_stmt := self.UNARY_OPS.get(type(node))) is not None:
            # Math statements are not `py.UnaryOp` instances and keep their
            # operand in `x` rather than `value`, so they need their own
            # branch; without it the math entries above are never used.
            node.replace_by(qasm_stmt(node.x))
            return RewriteResult(has_done_something=True)
        return RewriteResult()
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class Py2QASM(Pass):
    """Pass applying the `_Py2QASM` rule to a method until nothing changes."""

    def unsafe_run(self, mt: ir.Method) -> RewriteResult:
        """Rewrite `mt.code` in place and return the rewrite result."""
        walker = Walk(_Py2QASM())
        return Fixpoint(walker).rewrite(mt.code)
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""Rewrite qasm dialects into py dialects."""
|
|
2
|
+
|
|
3
|
+
import math as pymath
|
|
4
|
+
|
|
5
|
+
from kirin import ir
|
|
6
|
+
from kirin.passes import Pass
|
|
7
|
+
from kirin.rewrite import Walk, Fixpoint
|
|
8
|
+
from kirin.dialects import py, math
|
|
9
|
+
from kirin.rewrite.abc import RewriteRule
|
|
10
|
+
from kirin.rewrite.result import RewriteResult
|
|
11
|
+
|
|
12
|
+
from bloqade.qasm2.dialects import core, expr
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class _QASM2Py(RewriteRule):
    """Rewrite qasm dialects into py dialects."""

    # `qasm2.expr` single-operand statements and their replacement.
    UNARY_OPS = {
        expr.Neg: py.USub,
        expr.Sin: math.stmts.sin,
        expr.Cos: math.stmts.cos,
        expr.Tan: math.stmts.tan,
        expr.Exp: math.stmts.exp,
        expr.Sqrt: math.stmts.sqrt,
    }

    # `qasm2.expr` two-operand statements and their replacement.
    BINARY_OPS = {
        expr.Add: py.Add,
        expr.Sub: py.Sub,
        expr.Mul: py.Mult,
        expr.Div: py.Div,
        expr.Pow: py.Pow,
    }

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Replace `node` by its Python-dialect equivalent when one exists.

        Args:
            node: The statement visited by the rewrite driver.

        Returns:
            A result with `has_done_something=True` when `node` was replaced,
            otherwise an empty `RewriteResult`.
        """
        replacement = self._translate(node)
        if replacement is None:
            return RewriteResult()

        node.replace_by(replacement)
        return RewriteResult(has_done_something=True)

    def _translate(self, node: ir.Statement):
        """Build the replacement statement for `node`, or return `None`."""
        if isinstance(node, (expr.ConstInt, expr.ConstFloat)):
            return py.Constant(value=node.value)

        if isinstance(node, expr.Neg):
            # The negation replacement takes its operand as `value` ...
            return self.UNARY_OPS[type(node)](value=node.value)

        if isinstance(node, (expr.Sin, expr.Cos, expr.Tan, expr.Exp, expr.Sqrt)):
            # ... while the math statements take it as `x`.
            return self.UNARY_OPS[type(node)](x=node.value)

        if isinstance(node, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Pow)):
            return self.BINARY_OPS[type(node)](lhs=node.lhs, rhs=node.rhs)

        if isinstance(node, core.CRegEq):
            return py.cmp.Eq(node.lhs, node.rhs)

        if isinstance(node, expr.ConstPI):
            return py.Constant(value=pymath.pi)

        return None
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class QASM2Py(Pass):
    """Pass applying the `_QASM2Py` rule to a method until nothing changes."""

    def unsafe_run(self, mt: ir.Method) -> RewriteResult:
        """Rewrite `mt.code` in place and return the rewrite result."""
        walker = Walk(_QASM2Py())
        return Fixpoint(walker).rewrite(mt.code)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from .glob import (
|
|
2
|
+
GlobalToUOpRule as GlobalToUOpRule,
|
|
3
|
+
GlobalToParallelRule as GlobalToParallelRule,
|
|
4
|
+
)
|
|
5
|
+
from .register import RaiseRegisterRule as RaiseRegisterRule
|
|
6
|
+
from .parallel_to_uop import ParallelToUOpRule as ParallelToUOpRule
|
|
7
|
+
from .uop_to_parallel import (
|
|
8
|
+
MergePolicyABC as MergePolicyABC,
|
|
9
|
+
UOpToParallelRule as UOpToParallelRule,
|
|
10
|
+
SimpleGreedyMergePolicy as SimpleGreedyMergePolicy,
|
|
11
|
+
SimpleOptimalMergePolicy as SimpleOptimalMergePolicy,
|
|
12
|
+
)
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.passes import Pass
|
|
5
|
+
from kirin.rewrite import abc, walk, result
|
|
6
|
+
from kirin.dialects import py
|
|
7
|
+
|
|
8
|
+
from bloqade.qasm2.dialects import core
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class IndexingDesugarRule(abc.RewriteRule):
    """Turn `py.indexing.GetItem` on a qasm2 register into a register access."""

    def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
        """Replace an indexing statement on a quantum or classical register.

        Args:
            node: The statement visited by the rewrite driver.

        Returns:
            A result with `has_done_something=True` when `node` was replaced
            by `core.QRegGet` or `core.CRegGet`, otherwise an empty result.
        """
        if not isinstance(node, py.indexing.GetItem):
            return result.RewriteResult()

        # Pick the register access matching the type of the indexed object;
        # the quantum register check takes precedence.
        if node.obj.type.is_subseteq(core.QRegType):
            getter = core.QRegGet
        elif node.obj.type.is_subseteq(core.CRegType):
            getter = core.CRegGet
        else:
            return result.RewriteResult()

        node.replace_by(getter(reg=node.obj, idx=node.index))
        return result.RewriteResult(has_done_something=True)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class IndexingDesugarPass(Pass):
    """Pass that walks a method once with `IndexingDesugarRule`."""

    def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
        """Rewrite `mt.code` in place and return the rewrite result."""
        desugar = walk.Walk(IndexingDesugarRule())
        return desugar.rewrite(mt.code)
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
from typing import Dict, List
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
from kirin import ir
|
|
5
|
+
from kirin.rewrite import abc, result
|
|
6
|
+
from kirin.dialects import py, ilist
|
|
7
|
+
|
|
8
|
+
from bloqade import qasm2
|
|
9
|
+
from bloqade.analysis import address
|
|
10
|
+
from bloqade.qasm2.dialects import glob
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class GlobalRewriteBase:
    """Shared helper for rules that expand global gates into per-qubit gates."""

    # Result of the address analysis: address information per SSA value.
    address_analysis: Dict[ir.SSAValue, address.Address]

    def get_qubit_ssa(self, node: glob.UGate):
        """Create the statements that fetch every qubit `node` acts on.

        Args:
            node: The global gate whose `registers` operand is expanded.

        Returns:
            A pair `(new_stmts, qubit_ssa)`. On success `new_stmts` holds, in
            order, one index constant and one `QRegGet` per qubit, and
            `qubit_ssa` holds the results of those `QRegGet` statements. The
            statements are not inserted anywhere; that is left to the caller.
            When the gate cannot be expanded the pair is `([], None)`.
        """
        # can't rewrite if the registers are coming from a block argument
        if not isinstance(node.registers, ir.ResultValue):
            return [], None

        # The individual registers are only known when the list is built by
        # an `ilist.New` statement.
        if not isinstance(node.registers.owner, ilist.New):
            return [], None

        register_ssa_values = node.registers.owner.values

        # Resolve every register before creating any statement. Bailing out
        # halfway would otherwise leave already-built `Constant`/`QRegGet`
        # statements behind that are never inserted into the code.
        register_addrs = [
            self.address_analysis.get(register_ssa, address.Address.top())
            for register_ssa in register_ssa_values
        ]
        if not all(isinstance(addr, address.AddressReg) for addr in register_addrs):
            return [], None

        new_stmts: List[ir.Statement] = []
        qubit_ssa: List[ir.SSAValue] = []
        for register_ssa, addr in zip(register_ssa_values, register_addrs):
            # One access per position of the register's address data.
            for qubit in range(len(addr.data)):
                idx_stmt = py.constant.Constant(value=qubit)
                qubit_stmt = qasm2.core.QRegGet(reg=register_ssa, idx=idx_stmt.result)
                new_stmts.extend((idx_stmt, qubit_stmt))
                qubit_ssa.append(qubit_stmt.result)

        return new_stmts, qubit_ssa
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
class GlobalToParallelRule(abc.RewriteRule, GlobalRewriteBase):
    """Rewrite statements of the `glob` dialect into parallel gates."""

    def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
        """Dispatch `glob` statements to the `rewrite_<name>` method.

        Statements whose type is not part of `glob.dialect.stmts` are left
        untouched and an empty result is returned.
        """
        if type(node) not in glob.dialect.stmts:
            return result.RewriteResult()

        handler = getattr(self, f"rewrite_{node.name}")
        return handler(node)

    def rewrite_ugate(self, node: glob.UGate):
        """Replace a global U gate by one parallel U gate over all qubits.

        Returns an empty result, leaving `node` in place, when the qubits of
        the gate cannot be determined.
        """
        prelude, qubit_ssa = self.get_qubit_ssa(node)
        if qubit_ssa is None:
            return result.RewriteResult()

        qargs = ilist.New(values=qubit_ssa)
        parallel_gate = qasm2.dialects.parallel.UGate(
            qargs=qargs.result, theta=node.theta, phi=node.phi, lam=node.lam
        )

        # Insert in dependency order: qubit accesses, the qubit list, the gate.
        for stmt in (*prelude, qargs, parallel_gate):
            stmt.insert_before(node)

        node.delete()

        return result.RewriteResult(has_done_something=True)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@dataclass
class GlobalToUOpRule(abc.RewriteRule, GlobalRewriteBase):
    """Rewrite statements of the `glob` dialect into single-qubit gates."""

    def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
        """Dispatch `glob` statements to the `rewrite_<name>` method.

        Statements whose type is not part of `glob.dialect.stmts` are left
        untouched and an empty result is returned.
        """
        if type(node) not in glob.dialect.stmts:
            return result.RewriteResult()

        handler = getattr(self, f"rewrite_{node.name}")
        return handler(node)

    def rewrite_ugate(self, node: glob.UGate):
        """Replace a global U gate by one `qasm2.uop.UGate` per qubit.

        Returns an empty result, leaving `node` in place, when the qubits of
        the gate cannot be determined.
        """
        prelude, qubit_ssa = self.get_qubit_ssa(node)
        if qubit_ssa is None:
            return result.RewriteResult()

        # Every single-qubit gate reuses the angles of the global gate.
        gates = [
            qasm2.uop.UGate(qarg=qarg, theta=node.theta, phi=node.phi, lam=node.lam)
            for qarg in qubit_ssa
        ]

        # Qubit accesses go first so the gates can refer to their results.
        for stmt in (*prelude, *gates):
            stmt.insert_before(node)

        node.delete()
        return result.RewriteResult(has_done_something=True)
|