bloqade-circuit 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic. Click here for more details.
- bloqade/analysis/address/impls.py +5 -9
- bloqade/analysis/address/lattice.py +1 -1
- bloqade/analysis/fidelity/__init__.py +1 -0
- bloqade/analysis/fidelity/analysis.py +69 -0
- bloqade/device.py +130 -0
- bloqade/noise/__init__.py +2 -1
- bloqade/noise/fidelity.py +51 -0
- bloqade/noise/native/model.py +1 -2
- bloqade/noise/native/rewrite.py +5 -5
- bloqade/noise/native/stmts.py +40 -11
- bloqade/pyqrack/__init__.py +8 -2
- bloqade/pyqrack/base.py +24 -3
- bloqade/pyqrack/device.py +166 -0
- bloqade/pyqrack/noise/native.py +1 -2
- bloqade/pyqrack/qasm2/core.py +31 -15
- bloqade/pyqrack/qasm2/glob.py +28 -0
- bloqade/pyqrack/qasm2/uop.py +9 -1
- bloqade/pyqrack/reg.py +17 -49
- bloqade/pyqrack/squin/__init__.py +0 -0
- bloqade/pyqrack/squin/op.py +154 -0
- bloqade/pyqrack/squin/qubit.py +85 -0
- bloqade/pyqrack/squin/runtime.py +515 -0
- bloqade/pyqrack/squin/wire.py +69 -0
- bloqade/pyqrack/target.py +9 -2
- bloqade/pyqrack/task.py +30 -0
- bloqade/qasm2/_wrappers.py +11 -1
- bloqade/qasm2/dialects/core/stmts.py +15 -4
- bloqade/qasm2/dialects/expr/_emit.py +9 -8
- bloqade/qasm2/emit/base.py +4 -2
- bloqade/qasm2/emit/gate.py +0 -14
- bloqade/qasm2/emit/main.py +19 -15
- bloqade/qasm2/emit/target.py +2 -6
- bloqade/qasm2/glob.py +1 -1
- bloqade/qasm2/parse/lowering.py +124 -1
- bloqade/qasm2/passes/glob.py +3 -3
- bloqade/qasm2/passes/lift_qubits.py +26 -0
- bloqade/qasm2/passes/noise.py +6 -14
- bloqade/qasm2/passes/parallel.py +3 -3
- bloqade/qasm2/passes/py2qasm.py +1 -2
- bloqade/qasm2/passes/qasm2py.py +1 -2
- bloqade/qasm2/rewrite/desugar.py +6 -6
- bloqade/qasm2/rewrite/glob.py +9 -9
- bloqade/qasm2/rewrite/heuristic_noise.py +30 -38
- bloqade/qasm2/rewrite/insert_qubits.py +34 -0
- bloqade/qasm2/rewrite/native_gates.py +54 -55
- bloqade/qasm2/rewrite/parallel_to_uop.py +9 -9
- bloqade/qasm2/rewrite/uop_to_parallel.py +20 -22
- bloqade/qasm2/types.py +3 -6
- bloqade/qbraid/schema.py +10 -12
- bloqade/squin/__init__.py +1 -1
- bloqade/squin/analysis/nsites/analysis.py +4 -6
- bloqade/squin/analysis/nsites/impls.py +2 -6
- bloqade/squin/analysis/schedule.py +1 -1
- bloqade/squin/groups.py +15 -7
- bloqade/squin/noise/__init__.py +27 -0
- bloqade/squin/noise/_dialect.py +3 -0
- bloqade/squin/noise/stmts.py +59 -0
- bloqade/squin/op/__init__.py +35 -5
- bloqade/squin/op/number.py +5 -0
- bloqade/squin/op/rewrite.py +46 -0
- bloqade/squin/op/stmts.py +23 -2
- bloqade/squin/op/types.py +14 -0
- bloqade/squin/qubit.py +79 -11
- bloqade/squin/rewrite/__init__.py +0 -0
- bloqade/squin/rewrite/measure_desugar.py +33 -0
- bloqade/squin/wire.py +31 -2
- bloqade/stim/emit/stim.py +1 -1
- bloqade/task.py +94 -0
- bloqade/visual/animation/base.py +25 -15
- {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.0.dist-info}/METADATA +8 -2
- {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.0.dist-info}/RECORD +73 -52
- bloqade/squin/op/complex.py +0 -6
- {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.0.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from typing import Literal
|
|
2
2
|
|
|
3
3
|
from kirin import interp
|
|
4
|
-
from kirin.emit.exceptions import EmitError
|
|
5
4
|
|
|
6
5
|
from bloqade.qasm2.parse import ast
|
|
7
6
|
from bloqade.qasm2.types import QubitType
|
|
@@ -19,16 +18,18 @@ class EmitExpr(interp.MethodTable):
|
|
|
19
18
|
self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.GateFunction
|
|
20
19
|
):
|
|
21
20
|
|
|
21
|
+
args: list[ast.Node] = []
|
|
22
22
|
cparams, qparams = [], []
|
|
23
|
-
for arg in stmt.body.blocks[0].args
|
|
24
|
-
name
|
|
25
|
-
|
|
26
|
-
|
|
23
|
+
for arg in stmt.body.blocks[0].args:
|
|
24
|
+
assert arg.name is not None
|
|
25
|
+
|
|
26
|
+
args.append(ast.Name(id=arg.name))
|
|
27
27
|
if arg.type.is_subseteq(QubitType):
|
|
28
|
-
qparams.append(name
|
|
28
|
+
qparams.append(arg.name)
|
|
29
29
|
else:
|
|
30
|
-
cparams.append(name
|
|
31
|
-
|
|
30
|
+
cparams.append(arg.name)
|
|
31
|
+
|
|
32
|
+
emit.run_ssacfg_region(frame, stmt.body, tuple(args))
|
|
32
33
|
emit.output = ast.Gate(
|
|
33
34
|
name=stmt.sym_name,
|
|
34
35
|
cparams=cparams,
|
bloqade/qasm2/emit/base.py
CHANGED
|
@@ -36,8 +36,10 @@ class EmitQASM2Base(
|
|
|
36
36
|
)
|
|
37
37
|
return self
|
|
38
38
|
|
|
39
|
-
def
|
|
40
|
-
|
|
39
|
+
def initialize_frame(
|
|
40
|
+
self, code: ir.Statement, *, has_parent_access: bool = False
|
|
41
|
+
) -> EmitQASM2Frame[StmtType]:
|
|
42
|
+
return EmitQASM2Frame(code, has_parent_access=has_parent_access)
|
|
41
43
|
|
|
42
44
|
def run_method(
|
|
43
45
|
self, method: ir.Method, args: tuple[ast.Node | None, ...]
|
bloqade/qasm2/emit/gate.py
CHANGED
|
@@ -86,17 +86,3 @@ class Func(interp.MethodTable):
|
|
|
86
86
|
@interp.impl(func.ConstantNone)
|
|
87
87
|
def ignore(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt):
|
|
88
88
|
return ()
|
|
89
|
-
|
|
90
|
-
@interp.impl(func.Function)
|
|
91
|
-
def emit_func(
|
|
92
|
-
self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: func.Function
|
|
93
|
-
):
|
|
94
|
-
emit.run_ssacfg_region(frame, stmt.body)
|
|
95
|
-
cparams, qparams = [], []
|
|
96
|
-
for arg in stmt.args:
|
|
97
|
-
if arg.type.is_subseteq(QubitType):
|
|
98
|
-
qparams.append(frame.get(arg))
|
|
99
|
-
else:
|
|
100
|
-
cparams.append(frame.get(arg))
|
|
101
|
-
emit.output = ast.Gate(stmt.sym_name, cparams, qparams, frame.body)
|
|
102
|
-
return ()
|
bloqade/qasm2/emit/main.py
CHANGED
|
@@ -24,7 +24,7 @@ class Func(interp.MethodTable):
|
|
|
24
24
|
):
|
|
25
25
|
from bloqade.qasm2.dialects import glob, noise, parallel
|
|
26
26
|
|
|
27
|
-
emit.run_ssacfg_region(frame, stmt.body)
|
|
27
|
+
emit.run_ssacfg_region(frame, stmt.body, ())
|
|
28
28
|
if emit.dialects.data.intersection(
|
|
29
29
|
(parallel.dialect, glob.dialect, noise.dialect)
|
|
30
30
|
):
|
|
@@ -51,12 +51,14 @@ class Cf(interp.MethodTable):
|
|
|
51
51
|
self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: cf.ConditionalBranch
|
|
52
52
|
):
|
|
53
53
|
cond = emit.assert_node(ast.Cmp, frame.get(stmt.cond))
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
54
|
+
|
|
55
|
+
with emit.new_frame(stmt) as body_frame:
|
|
56
|
+
body_frame.entries.update(frame.entries)
|
|
57
|
+
body_frame.set_values(
|
|
58
|
+
stmt.then_successor.args, frame.get_values(stmt.then_arguments)
|
|
59
|
+
)
|
|
60
|
+
emit.emit_block(body_frame, stmt.then_successor)
|
|
61
|
+
|
|
60
62
|
frame.body.append(
|
|
61
63
|
ast.IfStmt(
|
|
62
64
|
cond,
|
|
@@ -91,15 +93,17 @@ class Scf(interp.MethodTable):
|
|
|
91
93
|
)
|
|
92
94
|
|
|
93
95
|
cond = emit.assert_node(ast.Cmp, frame.get(stmt.cond))
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
96
|
+
|
|
97
|
+
with emit.new_frame(stmt) as then_frame:
|
|
98
|
+
then_frame.entries.update(frame.entries)
|
|
99
|
+
emit.emit_block(then_frame, stmt.then_body.blocks[0])
|
|
100
|
+
frame.body.append(
|
|
101
|
+
ast.IfStmt(
|
|
102
|
+
cond,
|
|
103
|
+
body=then_frame.body, # type: ignore
|
|
104
|
+
)
|
|
101
105
|
)
|
|
102
|
-
|
|
106
|
+
|
|
103
107
|
term = stmt.then_body.blocks[0].last_stmt
|
|
104
108
|
if isinstance(term, scf.Yield):
|
|
105
109
|
return then_frame.get_values(term.values)
|
bloqade/qasm2/emit/target.py
CHANGED
|
@@ -101,9 +101,7 @@ class QASM2:
|
|
|
101
101
|
|
|
102
102
|
Py2QASM(entry.dialects)(entry)
|
|
103
103
|
target_main = EmitQASM2Main(self.main_target)
|
|
104
|
-
target_main.run(
|
|
105
|
-
entry, tuple(ast.Name(name) for name in entry.arg_names[1:])
|
|
106
|
-
).expect()
|
|
104
|
+
target_main.run(entry, ())
|
|
107
105
|
|
|
108
106
|
main_program = target_main.output
|
|
109
107
|
assert main_program is not None, f"failed to emit {entry.sym_name}"
|
|
@@ -133,9 +131,7 @@ class QASM2:
|
|
|
133
131
|
|
|
134
132
|
Py2QASM(fn.dialects)(fn)
|
|
135
133
|
|
|
136
|
-
target_gate.run(
|
|
137
|
-
fn, tuple(ast.Name(name) for name in fn.arg_names[1:])
|
|
138
|
-
).expect()
|
|
134
|
+
target_gate.run(fn, tuple(ast.Name(name) for name in fn.arg_names[1:]))
|
|
139
135
|
assert target_gate.output is not None, f"failed to emit {fn.sym_name}"
|
|
140
136
|
extra.append(target_gate.output)
|
|
141
137
|
|
bloqade/qasm2/glob.py
CHANGED
|
@@ -11,7 +11,7 @@ from .dialects import glob
|
|
|
11
11
|
|
|
12
12
|
@wraps(glob.UGate)
|
|
13
13
|
def u(
|
|
14
|
-
|
|
14
|
+
registers: ilist.IList[QReg, Any] | list, theta: float, phi: float, lam: float
|
|
15
15
|
) -> None:
|
|
16
16
|
"""Apply a U gate to all qubits in the input registers.
|
|
17
17
|
|
bloqade/qasm2/parse/lowering.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pathlib
|
|
1
3
|
from typing import Any
|
|
2
4
|
from dataclasses import field, dataclass
|
|
3
5
|
|
|
@@ -17,6 +19,119 @@ class QASM2(lowering.LoweringABC[ast.Node]):
|
|
|
17
19
|
hint_show_lineno: bool = field(default=True, kw_only=True)
|
|
18
20
|
stacktrace: bool = field(default=True, kw_only=True)
|
|
19
21
|
|
|
22
|
+
def loads(
|
|
23
|
+
self,
|
|
24
|
+
source: str,
|
|
25
|
+
kernel_name: str,
|
|
26
|
+
*,
|
|
27
|
+
returns: str | None = None,
|
|
28
|
+
globals: dict[str, Any] | None = None,
|
|
29
|
+
file: str | None = None,
|
|
30
|
+
lineno_offset: int = 0,
|
|
31
|
+
col_offset: int = 0,
|
|
32
|
+
compactify: bool = True,
|
|
33
|
+
) -> ir.Method:
|
|
34
|
+
from ..parse import loads
|
|
35
|
+
|
|
36
|
+
# TODO: add source info
|
|
37
|
+
stmt = loads(source)
|
|
38
|
+
|
|
39
|
+
state = lowering.State(
|
|
40
|
+
self,
|
|
41
|
+
file=file,
|
|
42
|
+
lineno_offset=lineno_offset,
|
|
43
|
+
col_offset=col_offset,
|
|
44
|
+
)
|
|
45
|
+
with state.frame(
|
|
46
|
+
[stmt],
|
|
47
|
+
globals=globals,
|
|
48
|
+
finalize_next=False,
|
|
49
|
+
) as frame:
|
|
50
|
+
try:
|
|
51
|
+
self.visit(state, stmt)
|
|
52
|
+
# append return statement with the return values
|
|
53
|
+
if returns is not None:
|
|
54
|
+
return_value = frame.get(returns)
|
|
55
|
+
if return_value is None:
|
|
56
|
+
raise lowering.BuildError(f"Cannot find return value {returns}")
|
|
57
|
+
else:
|
|
58
|
+
return_value = func.ConstantNone()
|
|
59
|
+
|
|
60
|
+
return_node = frame.push(func.Return(value_or_stmt=return_value))
|
|
61
|
+
|
|
62
|
+
except lowering.BuildError as e:
|
|
63
|
+
hint = state.error_hint(
|
|
64
|
+
e,
|
|
65
|
+
max_lines=self.max_lines,
|
|
66
|
+
indent=self.hint_indent,
|
|
67
|
+
show_lineno=self.hint_show_lineno,
|
|
68
|
+
)
|
|
69
|
+
if self.stacktrace:
|
|
70
|
+
raise Exception(
|
|
71
|
+
f"{e.args[0]}\n\n{hint}",
|
|
72
|
+
*e.args[1:],
|
|
73
|
+
) from e
|
|
74
|
+
else:
|
|
75
|
+
e.args = (hint,)
|
|
76
|
+
raise e
|
|
77
|
+
|
|
78
|
+
region = frame.curr_region
|
|
79
|
+
|
|
80
|
+
if compactify:
|
|
81
|
+
from kirin.rewrite import Walk, CFGCompactify
|
|
82
|
+
|
|
83
|
+
Walk(CFGCompactify()).rewrite(region)
|
|
84
|
+
|
|
85
|
+
code = func.Function(
|
|
86
|
+
sym_name=kernel_name,
|
|
87
|
+
signature=func.Signature((), return_node.value.type),
|
|
88
|
+
body=region,
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
return ir.Method(
|
|
92
|
+
mod=None,
|
|
93
|
+
py_func=None,
|
|
94
|
+
sym_name=kernel_name,
|
|
95
|
+
arg_names=[],
|
|
96
|
+
dialects=self.dialects,
|
|
97
|
+
code=code,
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
def loadfile(
|
|
101
|
+
self,
|
|
102
|
+
file: str | pathlib.Path,
|
|
103
|
+
*,
|
|
104
|
+
kernel_name: str | None = None,
|
|
105
|
+
returns: str | None = None,
|
|
106
|
+
globals: dict[str, Any] | None = None,
|
|
107
|
+
lineno_offset: int = 0,
|
|
108
|
+
col_offset: int = 0,
|
|
109
|
+
compactify: bool = True,
|
|
110
|
+
) -> ir.Method:
|
|
111
|
+
if isinstance(file, str):
|
|
112
|
+
file = pathlib.Path(*os.path.split(file))
|
|
113
|
+
|
|
114
|
+
if not file.is_file() or not file.name.endswith(".qasm"):
|
|
115
|
+
raise ValueError("File must be a .qasm file")
|
|
116
|
+
|
|
117
|
+
kernel_name = (
|
|
118
|
+
file.name.replace(".qasm", "") if kernel_name is None else kernel_name
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
with file.open("r") as f:
|
|
122
|
+
source = f.read()
|
|
123
|
+
|
|
124
|
+
return self.loads(
|
|
125
|
+
source,
|
|
126
|
+
kernel_name,
|
|
127
|
+
returns=returns,
|
|
128
|
+
globals=globals,
|
|
129
|
+
file=str(file),
|
|
130
|
+
lineno_offset=lineno_offset,
|
|
131
|
+
col_offset=col_offset,
|
|
132
|
+
compactify=compactify,
|
|
133
|
+
)
|
|
134
|
+
|
|
20
135
|
def run(
|
|
21
136
|
self,
|
|
22
137
|
stmt: ast.Node,
|
|
@@ -85,6 +200,10 @@ class QASM2(lowering.LoweringABC[ast.Node]):
|
|
|
85
200
|
stmt = expr.ConstInt(value=value)
|
|
86
201
|
elif isinstance(value, float):
|
|
87
202
|
stmt = expr.ConstFloat(value=value)
|
|
203
|
+
else:
|
|
204
|
+
raise lowering.BuildError(
|
|
205
|
+
f"Expected value of type float or int, got {type(value)}."
|
|
206
|
+
)
|
|
88
207
|
state.current_frame.push(stmt)
|
|
89
208
|
return stmt.result
|
|
90
209
|
|
|
@@ -99,6 +218,8 @@ class QASM2(lowering.LoweringABC[ast.Node]):
|
|
|
99
218
|
dialects = ["qasm2.core", "qasm2.uop", "qasm2.expr"]
|
|
100
219
|
elif isinstance(node.header, ast.Kirin):
|
|
101
220
|
dialects = node.header.dialects
|
|
221
|
+
else:
|
|
222
|
+
raise lowering.BuildError(f"Unexpected node header {node.header}")
|
|
102
223
|
|
|
103
224
|
for dialect in dialects:
|
|
104
225
|
if dialect not in allowed:
|
|
@@ -278,7 +399,7 @@ class QASM2(lowering.LoweringABC[ast.Node]):
|
|
|
278
399
|
def visit_UnaryOp(self, state: lowering.State[ast.Node], node: ast.UnaryOp):
|
|
279
400
|
if node.op == "-":
|
|
280
401
|
stmt = expr.Neg(value=state.lower(node.operand).expect_one())
|
|
281
|
-
return stmt.result
|
|
402
|
+
return state.current_frame.push(stmt).result
|
|
282
403
|
else:
|
|
283
404
|
return state.lower(node.operand).expect_one()
|
|
284
405
|
|
|
@@ -295,6 +416,8 @@ class QASM2(lowering.LoweringABC[ast.Node]):
|
|
|
295
416
|
stmt = core.QRegGet(reg, addr.result)
|
|
296
417
|
elif reg.type.is_subseteq(CRegType):
|
|
297
418
|
stmt = core.CRegGet(reg, addr.result)
|
|
419
|
+
else:
|
|
420
|
+
raise lowering.BuildError(f"Unexpected register type {reg.type}")
|
|
298
421
|
return state.current_frame.push(stmt).result
|
|
299
422
|
|
|
300
423
|
def visit_Call(self, state: lowering.State[ast.Node], node: ast.Call):
|
bloqade/qasm2/passes/glob.py
CHANGED
|
@@ -4,7 +4,7 @@ which converts global gates to single qubit gates.
|
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
from kirin import ir
|
|
7
|
-
from kirin.rewrite import cse, dce, walk
|
|
7
|
+
from kirin.rewrite import abc, cse, dce, walk
|
|
8
8
|
from kirin.passes.abc import Pass
|
|
9
9
|
from kirin.passes.fold import Fold
|
|
10
10
|
from kirin.rewrite.fixpoint import Fixpoint
|
|
@@ -54,7 +54,7 @@ class GlobalToUOP(Pass):
|
|
|
54
54
|
frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
|
|
55
55
|
return GlobalToUOpRule(frame.entries)
|
|
56
56
|
|
|
57
|
-
def unsafe_run(self, mt: ir.Method) ->
|
|
57
|
+
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
|
|
58
58
|
rewriter = walk.Walk(self.generate_rule(mt))
|
|
59
59
|
result = rewriter.rewrite(mt.code)
|
|
60
60
|
|
|
@@ -106,7 +106,7 @@ class GlobalToParallel(Pass):
|
|
|
106
106
|
frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
|
|
107
107
|
return GlobalToParallelRule(frame.entries)
|
|
108
108
|
|
|
109
|
-
def unsafe_run(self, mt: ir.Method) ->
|
|
109
|
+
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
|
|
110
110
|
rewriter = walk.Walk(self.generate_rule(mt))
|
|
111
111
|
result = rewriter.rewrite(mt.code)
|
|
112
112
|
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.passes import Pass
|
|
3
|
+
from kirin.rewrite import (
|
|
4
|
+
Walk,
|
|
5
|
+
Chain,
|
|
6
|
+
Fixpoint,
|
|
7
|
+
ConstantFold,
|
|
8
|
+
CommonSubexpressionElimination,
|
|
9
|
+
)
|
|
10
|
+
from kirin.passes.hint_const import HintConst
|
|
11
|
+
|
|
12
|
+
from bloqade.qasm2.rewrite.insert_qubits import InsertGetQubit
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class LiftQubits(Pass):
|
|
16
|
+
"""This pass lifts the creation of qubits to the block where the register is defined."""
|
|
17
|
+
|
|
18
|
+
def unsafe_run(self, mt: ir.Method):
|
|
19
|
+
result = Walk(InsertGetQubit()).rewrite(mt.code)
|
|
20
|
+
result = HintConst(self.dialects).unsafe_run(mt).join(result)
|
|
21
|
+
result = (
|
|
22
|
+
Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))
|
|
23
|
+
.rewrite(mt.code)
|
|
24
|
+
.join(result)
|
|
25
|
+
)
|
|
26
|
+
return result
|
bloqade/qasm2/passes/noise.py
CHANGED
|
@@ -4,16 +4,13 @@ from kirin import ir
|
|
|
4
4
|
from kirin.passes import Pass
|
|
5
5
|
from kirin.rewrite import (
|
|
6
6
|
Walk,
|
|
7
|
-
Chain,
|
|
8
7
|
Fixpoint,
|
|
9
|
-
ConstantFold,
|
|
10
8
|
DeadCodeElimination,
|
|
11
|
-
CommonSubexpressionElimination,
|
|
12
9
|
)
|
|
13
|
-
from kirin.rewrite.result import RewriteResult
|
|
14
10
|
|
|
15
11
|
from bloqade.noise import native
|
|
16
12
|
from bloqade.analysis import address
|
|
13
|
+
from bloqade.qasm2.passes.lift_qubits import LiftQubits
|
|
17
14
|
from bloqade.qasm2.rewrite.heuristic_noise import NoiseRewriteRule
|
|
18
15
|
|
|
19
16
|
|
|
@@ -38,24 +35,19 @@ class NoisePass(Pass):
|
|
|
38
35
|
self.address_analysis = address.AddressAnalysis(self.dialects)
|
|
39
36
|
|
|
40
37
|
def unsafe_run(self, mt: ir.Method):
|
|
41
|
-
result =
|
|
42
|
-
|
|
43
|
-
frame, res = self.address_analysis.run_analysis(mt, no_raise=False)
|
|
38
|
+
result = LiftQubits(self.dialects).unsafe_run(mt)
|
|
39
|
+
frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise)
|
|
44
40
|
result = (
|
|
45
41
|
Walk(
|
|
46
42
|
NoiseRewriteRule(
|
|
47
43
|
address_analysis=frame.entries,
|
|
48
44
|
noise_model=self.noise_model,
|
|
49
45
|
gate_noise_params=self.gate_noise_params,
|
|
50
|
-
)
|
|
46
|
+
),
|
|
47
|
+
reverse=True,
|
|
51
48
|
)
|
|
52
49
|
.rewrite(mt.code)
|
|
53
50
|
.join(result)
|
|
54
51
|
)
|
|
55
|
-
|
|
56
|
-
ConstantFold(),
|
|
57
|
-
DeadCodeElimination(),
|
|
58
|
-
CommonSubexpressionElimination(),
|
|
59
|
-
)
|
|
60
|
-
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
|
|
52
|
+
result = Fixpoint(Walk(DeadCodeElimination())).rewrite(mt.code).join(result)
|
|
61
53
|
return result
|
bloqade/qasm2/passes/parallel.py
CHANGED
|
@@ -16,7 +16,7 @@ from kirin.rewrite import (
|
|
|
16
16
|
ConstantFold,
|
|
17
17
|
DeadCodeElimination,
|
|
18
18
|
CommonSubexpressionElimination,
|
|
19
|
-
|
|
19
|
+
abc,
|
|
20
20
|
)
|
|
21
21
|
from kirin.analysis import const
|
|
22
22
|
|
|
@@ -84,7 +84,7 @@ class ParallelToUOp(Pass):
|
|
|
84
84
|
|
|
85
85
|
return ParallelToUOpRule(id_map=id_map, address_analysis=frame.entries)
|
|
86
86
|
|
|
87
|
-
def unsafe_run(self, mt: ir.Method) ->
|
|
87
|
+
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
|
|
88
88
|
result = Walk(self.generate_rule(mt)).rewrite(mt.code)
|
|
89
89
|
rule = Chain(
|
|
90
90
|
ConstantFold(),
|
|
@@ -140,7 +140,7 @@ class UOpToParallel(Pass):
|
|
|
140
140
|
def __post_init__(self):
|
|
141
141
|
self.constprop = const.Propagate(self.dialects)
|
|
142
142
|
|
|
143
|
-
def unsafe_run(self, mt: ir.Method) ->
|
|
143
|
+
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
|
|
144
144
|
result = Walk(RaiseRegisterRule()).rewrite(mt.code)
|
|
145
145
|
|
|
146
146
|
# do not run the parallelization because registers are not at the top
|
bloqade/qasm2/passes/py2qasm.py
CHANGED
|
@@ -4,8 +4,7 @@ from kirin import ir
|
|
|
4
4
|
from kirin.passes import Pass
|
|
5
5
|
from kirin.rewrite import Walk, Fixpoint
|
|
6
6
|
from kirin.dialects import py, math
|
|
7
|
-
from kirin.rewrite.abc import RewriteRule
|
|
8
|
-
from kirin.rewrite.result import RewriteResult
|
|
7
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
9
8
|
|
|
10
9
|
from bloqade.qasm2.dialects import core, expr
|
|
11
10
|
|
bloqade/qasm2/passes/qasm2py.py
CHANGED
|
@@ -6,8 +6,7 @@ from kirin import ir
|
|
|
6
6
|
from kirin.passes import Pass
|
|
7
7
|
from kirin.rewrite import Walk, Fixpoint
|
|
8
8
|
from kirin.dialects import py, math
|
|
9
|
-
from kirin.rewrite.abc import RewriteRule
|
|
10
|
-
from kirin.rewrite.result import RewriteResult
|
|
9
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
11
10
|
|
|
12
11
|
from bloqade.qasm2.dialects import core, expr
|
|
13
12
|
|
bloqade/qasm2/rewrite/desugar.py
CHANGED
|
@@ -2,27 +2,27 @@ from dataclasses import dataclass
|
|
|
2
2
|
|
|
3
3
|
from kirin import ir
|
|
4
4
|
from kirin.passes import Pass
|
|
5
|
-
from kirin.rewrite import abc, walk
|
|
5
|
+
from kirin.rewrite import abc, walk
|
|
6
6
|
from kirin.dialects import py
|
|
7
7
|
|
|
8
8
|
from bloqade.qasm2.dialects import core
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class IndexingDesugarRule(abc.RewriteRule):
|
|
12
|
-
def rewrite_Statement(self, node: ir.Statement) ->
|
|
12
|
+
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
|
|
13
13
|
if isinstance(node, py.indexing.GetItem):
|
|
14
14
|
if node.obj.type.is_subseteq(core.QRegType):
|
|
15
15
|
node.replace_by(core.QRegGet(reg=node.obj, idx=node.index))
|
|
16
|
-
return
|
|
16
|
+
return abc.RewriteResult(has_done_something=True)
|
|
17
17
|
elif node.obj.type.is_subseteq(core.CRegType):
|
|
18
18
|
node.replace_by(core.CRegGet(reg=node.obj, idx=node.index))
|
|
19
|
-
return
|
|
19
|
+
return abc.RewriteResult(has_done_something=True)
|
|
20
20
|
|
|
21
|
-
return
|
|
21
|
+
return abc.RewriteResult()
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
@dataclass
|
|
25
25
|
class IndexingDesugarPass(Pass):
|
|
26
|
-
def unsafe_run(self, mt: ir.Method) ->
|
|
26
|
+
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
|
|
27
27
|
|
|
28
28
|
return walk.Walk(IndexingDesugarRule()).rewrite(mt.code)
|
bloqade/qasm2/rewrite/glob.py
CHANGED
|
@@ -2,7 +2,7 @@ from typing import Dict, List
|
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
|
|
4
4
|
from kirin import ir
|
|
5
|
-
from kirin.rewrite import abc
|
|
5
|
+
from kirin.rewrite import abc
|
|
6
6
|
from kirin.dialects import py, ilist
|
|
7
7
|
|
|
8
8
|
from bloqade import qasm2
|
|
@@ -47,18 +47,18 @@ class GlobalRewriteBase:
|
|
|
47
47
|
@dataclass
|
|
48
48
|
class GlobalToParallelRule(abc.RewriteRule, GlobalRewriteBase):
|
|
49
49
|
|
|
50
|
-
def rewrite_Statement(self, node: ir.Statement) ->
|
|
50
|
+
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
|
|
51
51
|
if type(node) in glob.dialect.stmts:
|
|
52
52
|
return getattr(self, f"rewrite_{node.name}")(node)
|
|
53
53
|
|
|
54
|
-
return
|
|
54
|
+
return abc.RewriteResult()
|
|
55
55
|
|
|
56
56
|
def rewrite_ugate(self, node: glob.UGate):
|
|
57
57
|
|
|
58
58
|
new_stmts, qubit_ssa = self.get_qubit_ssa(node)
|
|
59
59
|
|
|
60
60
|
if qubit_ssa is None:
|
|
61
|
-
return
|
|
61
|
+
return abc.RewriteResult()
|
|
62
62
|
|
|
63
63
|
new_stmts.append(qargs := ilist.New(values=qubit_ssa))
|
|
64
64
|
new_stmts.append(
|
|
@@ -72,24 +72,24 @@ class GlobalToParallelRule(abc.RewriteRule, GlobalRewriteBase):
|
|
|
72
72
|
|
|
73
73
|
node.delete()
|
|
74
74
|
|
|
75
|
-
return
|
|
75
|
+
return abc.RewriteResult(has_done_something=True)
|
|
76
76
|
|
|
77
77
|
|
|
78
78
|
@dataclass
|
|
79
79
|
class GlobalToUOpRule(abc.RewriteRule, GlobalRewriteBase):
|
|
80
80
|
|
|
81
|
-
def rewrite_Statement(self, node: ir.Statement) ->
|
|
81
|
+
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
|
|
82
82
|
if type(node) in glob.dialect.stmts:
|
|
83
83
|
return getattr(self, f"rewrite_{node.name}")(node)
|
|
84
84
|
|
|
85
|
-
return
|
|
85
|
+
return abc.RewriteResult()
|
|
86
86
|
|
|
87
87
|
def rewrite_ugate(self, node: glob.UGate):
|
|
88
88
|
|
|
89
89
|
new_stmts, qubit_ssa = self.get_qubit_ssa(node)
|
|
90
90
|
|
|
91
91
|
if qubit_ssa is None:
|
|
92
|
-
return
|
|
92
|
+
return abc.RewriteResult()
|
|
93
93
|
|
|
94
94
|
for qarg in qubit_ssa:
|
|
95
95
|
new_stmts.append(
|
|
@@ -100,4 +100,4 @@ class GlobalToUOpRule(abc.RewriteRule, GlobalRewriteBase):
|
|
|
100
100
|
stmt.insert_before(node)
|
|
101
101
|
|
|
102
102
|
node.delete()
|
|
103
|
-
return
|
|
103
|
+
return abc.RewriteResult(has_done_something=True)
|