bloqade-circuit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic. Click here for more details.
- bloqade/analysis/__init__.py +0 -0
- bloqade/analysis/address/__init__.py +11 -0
- bloqade/analysis/address/analysis.py +60 -0
- bloqade/analysis/address/impls.py +228 -0
- bloqade/analysis/address/lattice.py +85 -0
- bloqade/noise/__init__.py +1 -0
- bloqade/noise/native/__init__.py +20 -0
- bloqade/noise/native/_dialect.py +3 -0
- bloqade/noise/native/_wrappers.py +34 -0
- bloqade/noise/native/model.py +347 -0
- bloqade/noise/native/rewrite.py +35 -0
- bloqade/noise/native/stmts.py +46 -0
- bloqade/pyqrack/__init__.py +18 -0
- bloqade/pyqrack/base.py +131 -0
- bloqade/pyqrack/noise/__init__.py +0 -0
- bloqade/pyqrack/noise/native.py +100 -0
- bloqade/pyqrack/qasm2/__init__.py +0 -0
- bloqade/pyqrack/qasm2/core.py +79 -0
- bloqade/pyqrack/qasm2/parallel.py +46 -0
- bloqade/pyqrack/qasm2/uop.py +247 -0
- bloqade/pyqrack/reg.py +109 -0
- bloqade/pyqrack/target.py +112 -0
- bloqade/qasm2/__init__.py +19 -0
- bloqade/qasm2/_wrappers.py +674 -0
- bloqade/qasm2/dialects/__init__.py +10 -0
- bloqade/qasm2/dialects/core/__init__.py +3 -0
- bloqade/qasm2/dialects/core/_dialect.py +3 -0
- bloqade/qasm2/dialects/core/_emit.py +68 -0
- bloqade/qasm2/dialects/core/_typeinfer.py +23 -0
- bloqade/qasm2/dialects/core/address.py +38 -0
- bloqade/qasm2/dialects/core/stmts.py +94 -0
- bloqade/qasm2/dialects/expr/__init__.py +3 -0
- bloqade/qasm2/dialects/expr/_dialect.py +3 -0
- bloqade/qasm2/dialects/expr/_emit.py +103 -0
- bloqade/qasm2/dialects/expr/_from_python.py +86 -0
- bloqade/qasm2/dialects/expr/_interp.py +75 -0
- bloqade/qasm2/dialects/expr/stmts.py +262 -0
- bloqade/qasm2/dialects/glob.py +45 -0
- bloqade/qasm2/dialects/indexing.py +64 -0
- bloqade/qasm2/dialects/inline.py +76 -0
- bloqade/qasm2/dialects/noise.py +16 -0
- bloqade/qasm2/dialects/parallel.py +110 -0
- bloqade/qasm2/dialects/uop/__init__.py +4 -0
- bloqade/qasm2/dialects/uop/_dialect.py +3 -0
- bloqade/qasm2/dialects/uop/_emit.py +211 -0
- bloqade/qasm2/dialects/uop/schedule.py +89 -0
- bloqade/qasm2/dialects/uop/stmts.py +325 -0
- bloqade/qasm2/emit/__init__.py +1 -0
- bloqade/qasm2/emit/base.py +72 -0
- bloqade/qasm2/emit/gate.py +102 -0
- bloqade/qasm2/emit/main.py +106 -0
- bloqade/qasm2/emit/target.py +165 -0
- bloqade/qasm2/glob.py +24 -0
- bloqade/qasm2/groups.py +120 -0
- bloqade/qasm2/parallel.py +48 -0
- bloqade/qasm2/parse/__init__.py +37 -0
- bloqade/qasm2/parse/ast.py +235 -0
- bloqade/qasm2/parse/build.py +289 -0
- bloqade/qasm2/parse/lowering.py +553 -0
- bloqade/qasm2/parse/parser.py +5 -0
- bloqade/qasm2/parse/print.py +293 -0
- bloqade/qasm2/parse/qasm2.lark +75 -0
- bloqade/qasm2/parse/visitor.py +16 -0
- bloqade/qasm2/parse/visitor.pyi +39 -0
- bloqade/qasm2/passes/__init__.py +5 -0
- bloqade/qasm2/passes/fold.py +94 -0
- bloqade/qasm2/passes/glob.py +119 -0
- bloqade/qasm2/passes/noise.py +61 -0
- bloqade/qasm2/passes/parallel.py +176 -0
- bloqade/qasm2/passes/py2qasm.py +63 -0
- bloqade/qasm2/passes/qasm2py.py +61 -0
- bloqade/qasm2/rewrite/__init__.py +12 -0
- bloqade/qasm2/rewrite/desugar.py +28 -0
- bloqade/qasm2/rewrite/glob.py +103 -0
- bloqade/qasm2/rewrite/heuristic_noise.py +247 -0
- bloqade/qasm2/rewrite/native_gates.py +447 -0
- bloqade/qasm2/rewrite/parallel_to_uop.py +83 -0
- bloqade/qasm2/rewrite/register.py +45 -0
- bloqade/qasm2/rewrite/uop_to_parallel.py +395 -0
- bloqade/qasm2/types.py +39 -0
- bloqade/qbraid/__init__.py +2 -0
- bloqade/qbraid/lowering.py +324 -0
- bloqade/qbraid/schema.py +252 -0
- bloqade/qbraid/simulation_result.py +99 -0
- bloqade/qbraid/target.py +86 -0
- bloqade/squin/__init__.py +2 -0
- bloqade/squin/analysis/__init__.py +0 -0
- bloqade/squin/analysis/nsites/__init__.py +8 -0
- bloqade/squin/analysis/nsites/analysis.py +52 -0
- bloqade/squin/analysis/nsites/impls.py +69 -0
- bloqade/squin/analysis/nsites/lattice.py +49 -0
- bloqade/squin/analysis/schedule.py +244 -0
- bloqade/squin/groups.py +38 -0
- bloqade/squin/op/__init__.py +132 -0
- bloqade/squin/op/_dialect.py +3 -0
- bloqade/squin/op/complex.py +6 -0
- bloqade/squin/op/stmts.py +220 -0
- bloqade/squin/op/traits.py +43 -0
- bloqade/squin/op/types.py +10 -0
- bloqade/squin/qubit.py +118 -0
- bloqade/squin/wire.py +103 -0
- bloqade/stim/__init__.py +6 -0
- bloqade/stim/_wrappers.py +186 -0
- bloqade/stim/dialects/__init__.py +5 -0
- bloqade/stim/dialects/aux/__init__.py +11 -0
- bloqade/stim/dialects/aux/_dialect.py +3 -0
- bloqade/stim/dialects/aux/emit.py +102 -0
- bloqade/stim/dialects/aux/interp.py +39 -0
- bloqade/stim/dialects/aux/lowering.py +40 -0
- bloqade/stim/dialects/aux/stmts/__init__.py +14 -0
- bloqade/stim/dialects/aux/stmts/annotate.py +47 -0
- bloqade/stim/dialects/aux/stmts/const.py +95 -0
- bloqade/stim/dialects/aux/types.py +19 -0
- bloqade/stim/dialects/collapse/__init__.py +3 -0
- bloqade/stim/dialects/collapse/_dialect.py +3 -0
- bloqade/stim/dialects/collapse/emit.py +68 -0
- bloqade/stim/dialects/collapse/stmts/__init__.py +3 -0
- bloqade/stim/dialects/collapse/stmts/measure.py +45 -0
- bloqade/stim/dialects/collapse/stmts/pp_measure.py +14 -0
- bloqade/stim/dialects/collapse/stmts/reset.py +26 -0
- bloqade/stim/dialects/gate/__init__.py +3 -0
- bloqade/stim/dialects/gate/_dialect.py +3 -0
- bloqade/stim/dialects/gate/emit.py +87 -0
- bloqade/stim/dialects/gate/stmts/__init__.py +14 -0
- bloqade/stim/dialects/gate/stmts/base.py +31 -0
- bloqade/stim/dialects/gate/stmts/clifford_1q.py +53 -0
- bloqade/stim/dialects/gate/stmts/clifford_2q.py +11 -0
- bloqade/stim/dialects/gate/stmts/control_2q.py +21 -0
- bloqade/stim/dialects/gate/stmts/pp.py +15 -0
- bloqade/stim/dialects/noise/__init__.py +3 -0
- bloqade/stim/dialects/noise/_dialect.py +3 -0
- bloqade/stim/dialects/noise/emit.py +66 -0
- bloqade/stim/dialects/noise/stmts.py +77 -0
- bloqade/stim/emit/__init__.py +1 -0
- bloqade/stim/emit/stim.py +54 -0
- bloqade/stim/groups.py +26 -0
- bloqade/test_utils.py +35 -0
- bloqade/types.py +24 -0
- bloqade/visual/__init__.py +1 -0
- bloqade/visual/animation/__init__.py +0 -0
- bloqade/visual/animation/animate.py +267 -0
- bloqade/visual/animation/base.py +346 -0
- bloqade/visual/animation/gate_event.py +24 -0
- bloqade/visual/animation/runtime/__init__.py +0 -0
- bloqade/visual/animation/runtime/aod.py +36 -0
- bloqade/visual/animation/runtime/atoms.py +55 -0
- bloqade/visual/animation/runtime/ppoly.py +50 -0
- bloqade/visual/animation/runtime/qpustate.py +119 -0
- bloqade/visual/animation/runtime/utils.py +43 -0
- bloqade_circuit-0.1.0.dist-info/METADATA +70 -0
- bloqade_circuit-0.1.0.dist-info/RECORD +153 -0
- bloqade_circuit-0.1.0.dist-info/WHEEL +4 -0
- bloqade_circuit-0.1.0.dist-info/licenses/LICENSE +234 -0
bloqade/qbraid/target.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Union, Optional
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from qbraid import QbraidProvider
|
|
7
|
+
from qbraid.runtime import QbraidJob
|
|
8
|
+
|
|
9
|
+
from bloqade.qasm2.emit import QASM2
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class qBraid:
    """qBraid target for Bloqade kernels.

    Accepts a Bloqade kernel, emits it as a QASM2 program, and submits that
    program to the QuEra simulator hosted on qBraid. Submission returns a
    `QbraidJob` that can be used to query the status of the run and to
    retrieve its results.
    """

    def __init__(
        self,
        *,
        allow_parallel: bool = False,
        allow_global: bool = False,
        provider: "QbraidProvider",  # inject externally for easier mocking
        qelib1: bool = True,
    ) -> None:
        """Initialize the qBraid target.

        Args:
            allow_parallel (bool):
                Allow parallel gates in the resulting QASM2 AST. Defaults to
                `False`. When `False` and the input kernel uses parallel gates,
                they are rewritten into uop gates.

            allow_global (bool):
                Allow global gates in the resulting QASM2 AST. Defaults to
                `False`. When `False` and the input kernel uses global gates,
                they are rewritten into parallel gates. If both `allow_parallel`
                and `allow_global` are `False`, the input kernel is rewritten to
                use uop gates only.

            provider (QbraidProvider):
                qBraid-provided object used to submit the kernel to the QuEra
                simulator.

            qelib1 (bool):
                Include the `include "qelib1.inc"` line in the resulting QASM2
                AST that is submitted to qBraid. Defaults to `True`.
        """
        self.qelib1 = qelib1
        self.provider = provider
        self.allow_parallel = allow_parallel
        self.allow_global = allow_global

    def emit(
        self,
        method: ir.Method,
        shots: Optional[int] = None,
        tags: Optional[dict[str, str]] = None,
    ) -> Union["QbraidJob", list["QbraidJob"]]:
        """Submit the Bloqade kernel to the QuEra simulator on qBraid.

        Args:
            method (ir.Method):
                The kernel to submit to qBraid.
            shots (Optional[int]):
                Number of times to run the kernel. Defaults to `None`; the value
                is passed through unchanged to the device's `run` call.
            tags (Optional[dict[str, str]]):
                A dictionary of tags to associate with the job. Defaults to
                `None`.

        Returns:
            Union[QbraidJob, list[QbraidJob]]:
                An object you can query for the status of your submission and
                obtain simulator results from.
        """
        # Convert the method to a QASM2 string using this target's settings.
        qasm2_emitter = QASM2(
            allow_parallel=self.allow_parallel,
            allow_global=self.allow_global,
            qelib1=self.qelib1,
        )
        qasm2_prog = qasm2_emitter.emit_str(method)

        # Submit the QASM2 string to the qBraid-hosted QuEra simulator.
        quera_qasm_simulator = self.provider.get_device("quera_qasm_simulator")

        return quera_qasm_simulator.run(qasm2_prog, shots=shots, tags=tags)
|
|
File without changes
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# from typing import cast
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.analysis import Forward
|
|
5
|
+
from kirin.analysis.forward import ForwardFrame
|
|
6
|
+
|
|
7
|
+
from bloqade.squin.op.types import OpType
|
|
8
|
+
from bloqade.squin.op.traits import HasSites, FixedSites
|
|
9
|
+
|
|
10
|
+
from .lattice import Sites, NoSites, NumberSites
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class NSitesAnalysis(Forward[Sites]):
    """Forward analysis inferring how many sites each operator value acts on.

    Values are elements of the `Sites` lattice: `NumberSites(n)` for a known
    site count, `NoSites()` (bottom) or `AnySites()` (top) otherwise.
    """

    keys = ["op.nsites"]
    lattice = Sites

    # Take a page from const prop in Kirin: the site count is read from the
    # HasSites / FixedSites traits when no explicit rule is registered.

    ## This gets called before the registry look up
    def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
        """Evaluate `stmt` for the site-count analysis.

        A rule registered under `keys` takes precedence. Otherwise the
        `HasSites` / `FixedSites` traits provide the count, and any other
        statement maps each of its results to `NoSites()`.
        """
        method = self.lookup_registry(frame, stmt)
        if method is not None:
            return method(self, frame, stmt)

        if stmt.has_trait(HasSites):
            has_sites_trait = stmt.get_trait(HasSites)
            sites = has_sites_trait.get_sites(stmt)
            return (NumberSites(sites=sites),)

        if stmt.has_trait(FixedSites):
            sites_trait = stmt.get_trait(FixedSites)
            return (NumberSites(sites=sites_trait.data),)

        # One lattice element per result. A fixed 1-tuple would leave the
        # extra results of multi-result statements without a value.
        return tuple(NoSites() for _ in stmt.results)

    # For when no implementation is found for the statement
    def eval_stmt_fallback(
        self, frame: ForwardFrame[Sites], stmt: ir.Statement
    ) -> tuple[Sites, ...]:
        """Default per-result values: top for `OpType` results, bottom otherwise.

        NOTE(review): `eval_stmt` above never delegates here. Confirm whether
        the base interpreter still calls this hook.
        """
        return tuple(
            (
                self.lattice.top()
                if result.type.is_subseteq(OpType)
                else self.lattice.bottom()
            )
            for result in stmt.results
        )

    def run_method(self, method: ir.Method, args: tuple[Sites, ...]):
        """Run `method` on `args`, with bottom in the slot for the method object."""
        # NOTE: we do not support dynamic calls here, thus no need to propagate method object
        return self.run_callable(method.code, (self.lattice.bottom(),) + args)
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from typing import cast
|
|
2
|
+
|
|
3
|
+
from kirin import ir, interp
|
|
4
|
+
|
|
5
|
+
from bloqade.squin import op
|
|
6
|
+
|
|
7
|
+
from .lattice import (
|
|
8
|
+
NoSites,
|
|
9
|
+
NumberSites,
|
|
10
|
+
)
|
|
11
|
+
from .analysis import NSitesAnalysis
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@op.dialect.register(key="op.nsites")
class SquinOp(interp.MethodTable):
    """Site-count rules for composite statements of the squin `op` dialect.

    Each rule returns a 1-tuple holding the `Sites` element of the statement's
    result. Operands whose count is unknown make the result `NoSites()`.
    """

    @interp.impl(op.stmts.Kron)
    def kron(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Kron):
        """Tensor product: the result acts on the sites of both operands."""
        lhs = frame.get(stmt.lhs)
        rhs = frame.get(stmt.rhs)
        if isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites):
            return (NumberSites(sites=lhs.sites + rhs.sites),)
        return (NoSites(),)

    @interp.impl(op.stmts.Mult)
    def mult(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Mult):
        """Operator product: both operands must act on the same number of sites.

        The product of two n-site operators is again an n-site operator. A
        mismatch yields `NoSites()` instead of raising.
        """
        lhs = frame.get(stmt.lhs)
        rhs = frame.get(stmt.rhs)

        if isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites):
            # I originally considered throwing an exception here
            # but Xiu-zhe (Roger) Luo has pointed out it would be
            # a much better UX to add a type element that
            # could explicitly indicate the error. The downside
            # is you'll have some added complexity in the type lattice.
            if lhs.sites != rhs.sites:
                return (NoSites(),)
            # Multiplying composes the operators on the same sites. It does not
            # add sites (that is `Kron`), so the count is that of either operand.
            return (NumberSites(sites=lhs.sites),)
        return (NoSites(),)

    @interp.impl(op.stmts.Control)
    def control(
        self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Control
    ):
        """Controlled operator: the target's sites plus `n_controls` control sites."""
        op_sites = frame.get(stmt.op)

        if isinstance(op_sites, NumberSites):
            n_controls_attr = stmt.get_attr_or_prop("n_controls")
            n_controls = cast(ir.PyAttr[int], n_controls_attr).data
            return (NumberSites(sites=op_sites.sites + n_controls),)
        return (NoSites(),)

    @interp.impl(op.stmts.Rot)
    def rot(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Rot):
        """Rotation: acts on the same sites as its `axis` operator."""
        op_sites = frame.get(stmt.axis)
        return (op_sites,)

    @interp.impl(op.stmts.Scale)
    def scale(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Scale):
        """Scaling: acts on the same sites as the scaled operator."""
        op_sites = frame.get(stmt.op)
        return (op_sites,)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from typing import final
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
from kirin.lattice import (
|
|
5
|
+
SingletonMeta,
|
|
6
|
+
BoundedLattice,
|
|
7
|
+
SimpleJoinMixin,
|
|
8
|
+
SimpleMeetMixin,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class Sites(
    SimpleJoinMixin["Sites"], SimpleMeetMixin["Sites"], BoundedLattice["Sites"]
):
    """Lattice of operator site counts.

    `NoSites()` is the bottom element, `AnySites()` the top element, and
    `NumberSites(n)` carries a concrete count `n`.
    """

    @classmethod
    def bottom(cls) -> "Sites":
        """Return the bottom element, `NoSites()`."""
        least = NoSites()
        return least

    @classmethod
    def top(cls) -> "Sites":
        """Return the top element, `AnySites()`."""
        greatest = AnySites()
        return greatest
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@final
@dataclass
class NoSites(Sites, metaclass=SingletonMeta):
    """Bottom element of `Sites`: ordered below every element."""

    def is_subseteq(self, other: Sites) -> bool:
        # Bottom sits below everything, so the answer never depends on `other`.
        below_everything = True
        return below_everything
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@final
@dataclass
class AnySites(Sites, metaclass=SingletonMeta):
    """Top element of `Sites`: any / unknown number of sites."""

    def is_subseteq(self, other: Sites) -> bool:
        """Top is only ordered below itself.

        The previous check, `isinstance(other, Sites)`, placed top below every
        element. Under a join that tries `self <= other` then `other <= self`,
        `NumberSites(n).join(AnySites())` then collapsed to `NumberSites(n)`
        instead of top. This matches `Qubit.is_subseteq` in the schedule lattice.
        """
        return isinstance(other, AnySites)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@final
@dataclass
class NumberSites(Sites):
    """A concrete site count.

    Attributes:
        sites: the number of sites the operator acts on.
    """

    sites: int

    def is_subseteq(self, other: Sites) -> bool:
        """`NumberSites(n)` is below `NumberSites(n)` and below top (`AnySites`)."""
        # Every lattice element is below top; previously this returned False.
        if isinstance(other, AnySites):
            return True
        if isinstance(other, NumberSites):
            return self.sites == other.sites
        return False
|
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
from typing import Any, Set, Dict, Iterable, Optional, final
|
|
2
|
+
from itertools import chain
|
|
3
|
+
from collections import OrderedDict
|
|
4
|
+
from dataclasses import field, dataclass
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
|
|
7
|
+
from kirin import ir, graph, interp, idtable
|
|
8
|
+
from kirin.lattice import (
|
|
9
|
+
SingletonMeta,
|
|
10
|
+
BoundedLattice,
|
|
11
|
+
SimpleJoinMixin,
|
|
12
|
+
SimpleMeetMixin,
|
|
13
|
+
)
|
|
14
|
+
from kirin.analysis import Forward, ForwardFrame
|
|
15
|
+
from kirin.dialects import func
|
|
16
|
+
|
|
17
|
+
from bloqade.analysis import address
|
|
18
|
+
from bloqade.qasm2.parse.print import Printer
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class GateSchedule(
    SimpleJoinMixin["GateSchedule"],
    SimpleMeetMixin["GateSchedule"],
    BoundedLattice["GateSchedule"],
):
    """Two-point lattice used by the DAG schedule analysis.

    `NotQubit()` is the bottom element and `Qubit()` the top element.
    """

    @classmethod
    def bottom(cls) -> "GateSchedule":
        """Return the bottom element, `NotQubit()`."""
        least = NotQubit()
        return least

    @classmethod
    def top(cls) -> "GateSchedule":
        """Return the top element, `Qubit()`."""
        greatest = Qubit()
        return greatest
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@final
@dataclass
class NotQubit(GateSchedule, metaclass=SingletonMeta):
    """Bottom element of `GateSchedule`: ordered below every element."""

    def is_subseteq(self, other: GateSchedule) -> bool:
        # Bottom sits below everything, so `other` is not inspected.
        below_everything = True
        return below_everything
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@final
@dataclass
class Qubit(GateSchedule, metaclass=SingletonMeta):
    """Top element of `GateSchedule`: ordered only below itself."""

    def is_subseteq(self, other: GateSchedule) -> bool:
        other_is_top = isinstance(other, Qubit)
        return other_is_top
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Treat global gates as terminators for this analysis, i.e. split the block in two at each global gate.
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@dataclass(slots=True)
class StmtDag(graph.Graph[ir.Statement]):
    """Dependency DAG over IR statements.

    Nodes are stored under the string ids assigned by `id_table`.
    `add_edge(src, dst)` records that `dst` depends on `src`. The schedule
    analysis uses this to chain successive statements touching the same qubit.
    """

    # Assigns a string id to each statement added to the graph.
    id_table: idtable.IdTable[ir.Statement] = field(
        default_factory=lambda: idtable.IdTable()
    )
    # id -> statement
    stmts: Dict[str, ir.Statement] = field(default_factory=OrderedDict)
    # id -> ids of the statements that depend on it
    out_edges: Dict[str, Set[str]] = field(default_factory=OrderedDict)
    # id -> ids of the statements it depends on
    inc_edges: Dict[str, Set[str]] = field(default_factory=OrderedDict)
    # statement -> order in which it was first added (0, 1, 2, ...)
    stmt_index: Dict[ir.Statement, int] = field(default_factory=OrderedDict)

    def update_index(self, node: ir.Statement):
        """Give `node` the next insertion index unless it already has one."""
        if node not in self.stmt_index:
            self.stmt_index[node] = len(self.stmt_index)

    def add_node(self, node: ir.Statement):
        """Register `node`, keeping any existing edges, and return its id."""
        node_id = self.id_table[node]
        self.stmts[node_id] = node
        self.update_index(node)
        self.out_edges.setdefault(node_id, set())
        self.inc_edges.setdefault(node_id, set())
        return node_id

    def add_edge(self, src: ir.Statement, dst: ir.Statement):
        """Record that `dst` depends on `src`, adding both nodes if needed."""
        src_id = self.add_node(src)
        dst_id = self.add_node(dst)

        self.out_edges[src_id].add(dst_id)
        self.inc_edges[dst_id].add(src_id)

    def get_parents(self, node: ir.Statement) -> Iterable[ir.Statement]:
        """Statements that `node` depends on."""
        return (
            self.stmts[node_id]
            for node_id in self.inc_edges.get(self.id_table[node], set())
        )

    def get_children(self, node: ir.Statement) -> Iterable[ir.Statement]:
        """Statements that depend on `node`."""
        return (
            self.stmts[node_id]
            for node_id in self.out_edges.get(self.id_table[node], set())
        )

    def get_neighbors(self, node: ir.Statement) -> Iterable[ir.Statement]:
        """Parents followed by children of `node`."""
        return chain(self.get_parents(node), self.get_children(node))

    def get_nodes(self) -> Iterable[ir.Statement]:
        """All statements in the graph, in the order they were added."""
        return self.stmts.values()

    def get_edges(self) -> Iterable[tuple[ir.Statement, ir.Statement]]:
        """All `(src, dst)` dependency pairs."""
        return (
            (self.stmts[src], self.stmts[dst])
            for src, dsts in self.out_edges.items()
            for dst in dsts
        )

    def print(
        self,
        printer: Optional["Printer"] = None,
        analysis: dict["ir.SSAValue", Any] | None = None,
    ) -> None:
        """Not implemented for this graph."""
        raise NotImplementedError

    def topological_groups(self):
        """Split the dag into topological groups.

        Each group contains nodes that have no dependencies on each other but
        depend on nodes in one or more previous groups.

        Yields:
            List[str]: the node ids of one topological group, ordered by the
            order in which the nodes were added to the dag.

        Raises:
            ValueError: If a cyclic dependency is detected.

        The idea is to yield all nodes with no dependencies, then remove those
        nodes from the graph, repeating until no nodes are left. The worst case
        is a linear dag, so `len(self.stmts)` is the upper bound on the number
        of rounds. If a round finds no dependency-free node while nodes remain,
        those nodes form a cycle.
        """
        inc_edges = {k: set(v) for k, v in self.inc_edges.items()}

        check_next: Iterable[str] = inc_edges.keys()

        for _ in range(len(self.stmts)):
            if not inc_edges:
                break

            group = [node_id for node_id in check_next if not inc_edges[node_id]]
            if not group:
                # Every remaining node still has a parent, so the remaining
                # nodes contain a cycle. Stop here instead of yielding empty groups.
                break

            # Candidates come from sets of string ids, whose iteration order
            # varies with hash randomization. Sort by insertion index so
            # results are reproducible across runs.
            group.sort(key=lambda node_id: self.stmt_index[self.stmts[node_id]])
            yield group

            check_next = set()
            for n in group:
                inc_edges.pop(n)
                for m in self.out_edges[n]:
                    check_next.add(m)
                    inc_edges[m].remove(n)

        if inc_edges:
            raise ValueError("Cyclic dependency detected")
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@dataclass
class DagScheduleAnalysis(Forward[GateSchedule]):
    """Forward analysis that builds a statement dependency DAG per block.

    Statements are linked through the qubit addresses of their arguments, which
    are looked up in `address_analysis`. Each statement gets an edge from the
    previous statement that used the same qubit index. When a terminator is
    reached, the DAG built so far is stored in `stmt_dags` under that block.
    """

    keys = ["qasm2.schedule.dag"]
    lattice = GateSchedule

    # Precomputed address analysis: SSA value -> qubit address lattice element.
    address_analysis: Dict[ir.SSAValue, address.Address]
    # Qubit index -> most recent statement using it in the current block.
    use_def: Dict[int, ir.Statement] = field(init=False)
    # DAG under construction for the current block.
    stmt_dag: StmtDag = field(init=False)
    # Finished DAGs, keyed by the block whose terminator closed them.
    stmt_dags: Dict[ir.Block, StmtDag] = field(init=False)

    def initialize(self):
        """Reset per-run state, then defer to the base initializer."""
        self.use_def = {}
        self.stmt_dag = StmtDag()
        self.stmt_dags = {}
        return super().initialize()

    def push_current_dag(self, block: ir.Block):
        """Store the current DAG under `block` and start a fresh DAG.

        Raises AssertionError if `block` already has a stored DAG.
        """
        # run when hitting terminator statements
        assert block not in self.stmt_dags, "Block already in stmt_dags"

        # Make sure the latest user of every qubit is present as a node.
        for node in self.use_def.values():
            self.stmt_dag.add_node(node)

        self.stmt_dags[block] = self.stmt_dag
        self.stmt_dag = StmtDag()
        self.use_def = {}

    def run_method(self, method: ir.Method, args: tuple[GateSchedule, ...]):
        """Run `method` on `args`, with bottom in the slot for the method object."""
        # NOTE: we do not support dynamic calls here, thus no need to propagate method object
        return self.run_callable(method.code, (self.lattice.bottom(),) + args)

    def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
        """Handle statements without a rule registered under `keys`.

        A terminator closes the current block's DAG. Every result is mapped to
        top (`Qubit()`).
        """
        if stmt.has_trait(ir.IsTerminator):
            assert (
                stmt.parent_block is not None
            ), "Terminator statement has no parent block"
            self.push_current_dag(stmt.parent_block)

        return tuple(self.lattice.top() for _ in stmt.results)

    def _update_dag(self, stmt: ir.Statement, addr: address.Address):
        """Record that `stmt` uses the qubit(s) described by `addr`.

        For each qubit index, add an edge from its previous user (if any) to
        `stmt`, then make `stmt` the latest user. Tuples are walked recursively.
        Any other address kind is ignored.
        """
        if isinstance(addr, address.AddressQubit):
            old_stmt = self.use_def.get(addr.data, None)
            if old_stmt is not None:
                self.stmt_dag.add_edge(old_stmt, stmt)
            self.use_def[addr.data] = stmt
        elif isinstance(addr, address.AddressReg):
            # A register touches every qubit index it contains.
            for idx in addr.data:
                old_stmt = self.use_def.get(idx, None)
                if old_stmt is not None:
                    self.stmt_dag.add_edge(old_stmt, stmt)
                self.use_def[idx] = stmt
        elif isinstance(addr, address.AddressTuple):
            for sub_addr in addr.data:
                self._update_dag(stmt, sub_addr)

    def update_dag(self, stmt: ir.Statement, args: Sequence[ir.SSAValue]):
        """Add `stmt` as a node and link it through the addresses of `args`.

        Arguments missing from `address_analysis` are treated as
        `Address.bottom()`.
        """
        self.stmt_dag.add_node(stmt)

        for arg in args:
            self._update_dag(
                stmt, self.address_analysis.get(arg, address.Address.bottom())
            )

    def get_dags(self, mt: ir.Method, args=None, kwargs=None):
        """Run the analysis on `mt` and return the per-block DAGs.

        When `args` is None, every entry of `mt.args` is given top.
        NOTE(review): confirm whether `mt.args` includes the method's own
        `self` slot, which `run_method` already supplies.
        """
        if args is None:
            args = tuple(self.lattice.top() for _ in mt.args)

        self.run(mt, args, kwargs).expect()
        return self.stmt_dags
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
@func.dialect.register(key="qasm2.schedule.dag")
class FuncImpl(interp.MethodTable):
    """Schedule-DAG rules for `func` calls: each call becomes one DAG node."""

    @interp.impl(func.Invoke)
    @interp.impl(func.Call)
    def invoke(
        self,
        interp: DagScheduleAnalysis,
        frame: ForwardFrame,
        stmt: func.Invoke | func.Call,
    ):
        """Link the call through its input qubits; every result is top."""
        analysis = interp
        analysis.update_dag(stmt, stmt.inputs)
        make_top = analysis.lattice.top
        return tuple(make_top() for _result in stmt.results)
|
bloqade/squin/groups.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from kirin import ir, passes
|
|
2
|
+
from kirin.prelude import structural_no_opt
|
|
3
|
+
from kirin.dialects import ilist
|
|
4
|
+
|
|
5
|
+
from bloqade.qasm2.rewrite.desugar import IndexingDesugarPass
|
|
6
|
+
|
|
7
|
+
from . import op, wire, qubit
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@ir.dialect_group(structural_no_opt.union([op, qubit]))
def kernel(self):
    """Dialect group for squin kernels using the `op` and `qubit` dialects.

    Returns the pass pipeline applied to every method compiled with the group.
    It verifies the method, optionally folds constants to a fixpoint,
    optionally infers types, desugars ilist and indexing syntax, and, if type
    inference is enabled, re-infers types and typechecks.
    """
    folder = passes.Fold(self)
    type_inference = passes.TypeInfer(self)
    desugar_passes = (ilist.IListDesugar(self), IndexingDesugarPass(self))

    def run_pass(method, *, fold=True, typeinfer=True):
        method.verify()
        if fold:
            folder.fixpoint(method)

        if typeinfer:
            type_inference(method)
        for desugar in desugar_passes:
            desugar(method)

        if not typeinfer:
            return
        # Desugaring introduces new statements, so infer types again.
        type_inference(method)
        method.code.typecheck()

    return run_pass
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@ir.dialect_group(structural_no_opt.union([op, wire]))
def wired(self):
    """Dialect group for squin kernels using the `op` and `wire` dialects.

    The returned pipeline does not run any passes.
    """

    def run_pass(method):
        """No-op pass pipeline."""
        return None

    return run_pass
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
from kirin import ir as _ir
|
|
2
|
+
from kirin.prelude import structural_no_opt as _structural_no_opt
|
|
3
|
+
from kirin.lowering import wraps as _wraps
|
|
4
|
+
|
|
5
|
+
from . import stmts as stmts, types as types
|
|
6
|
+
from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary
|
|
7
|
+
from ._dialect import dialect as dialect
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@_wraps(stmts.Kron)
def kron(lhs: types.Op, rhs: types.Op, *, is_unitary: bool = False) -> types.Op:
    """Tensor (Kronecker) product of `lhs` and `rhs`; wrapper for `stmts.Kron`.

    `is_unitary` is forwarded to the statement. Presumably it marks the result
    as unitary; confirm against `stmts.Kron`.
    """
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@_wraps(stmts.Adjoint)
def adjoint(op: types.Op, *, is_unitary: bool = False) -> types.Op:
    """Adjoint of `op`; wrapper for `stmts.Adjoint`.

    `is_unitary` is forwarded to the statement.
    """
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@_wraps(stmts.Control)
def control(op: types.Op, *, n_controls: int, is_unitary: bool = False) -> types.Op:
    """Controlled `op` with `n_controls` control sites; wrapper for `stmts.Control`.

    `is_unitary` is forwarded to the statement.
    """
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@_wraps(stmts.Identity)
def identity(*, size: int) -> types.Op:
    """Identity operator of the given `size`; wrapper for `stmts.Identity`."""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@_wraps(stmts.Rot)
def rot(axis: types.Op, angle: float) -> types.Op:
    """Rotation about the operator `axis` by `angle`; wrapper for `stmts.Rot`."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@_wraps(stmts.ShiftOp)
def shift(theta: float) -> types.Op:
    """Shift operator parameterized by `theta`; wrapper for `stmts.ShiftOp`."""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@_wraps(stmts.PhaseOp)
def phase(theta: float) -> types.Op:
    """Phase operator parameterized by `theta`; wrapper for `stmts.PhaseOp`."""
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@_wraps(stmts.X)
def x() -> types.Op:
    """X gate operator; wrapper for `stmts.X`."""
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@_wraps(stmts.Y)
def y() -> types.Op:
    """Y gate operator; wrapper for `stmts.Y`."""
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@_wraps(stmts.Z)
def z() -> types.Op:
    """Z gate operator; wrapper for `stmts.Z`."""
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@_wraps(stmts.H)
def h() -> types.Op:
    """H (Hadamard) gate operator; wrapper for `stmts.H`."""
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@_wraps(stmts.S)
def s() -> types.Op:
    """S gate operator; wrapper for `stmts.S`."""
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@_wraps(stmts.T)
def t() -> types.Op:
    """T gate operator; wrapper for `stmts.T`."""
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@_wraps(stmts.P0)
def p0() -> types.Op:
    """Wrapper for `stmts.P0` (presumably the projector onto |0>; confirm)."""
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@_wraps(stmts.P1)
def p1() -> types.Op:
    """Wrapper for `stmts.P1` (presumably the projector onto |1>; confirm)."""
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@_wraps(stmts.Sn)
def spin_n() -> types.Op:
    """Wrapper for `stmts.Sn` (presumably the spin lowering operator; confirm)."""
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@_wraps(stmts.Sp)
def spin_p() -> types.Op:
    """Wrapper for `stmts.Sp` (presumably the spin raising operator; confirm)."""
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# stdlibs
|
|
79
|
+
@_ir.dialect_group(_structural_no_opt.add(dialect))
def op(self):
    """Dialect group for the operator standard-library kernels below.

    The returned pipeline does not run any passes.
    """

    def run_pass(method):
        """No-op pass pipeline."""
        return None

    return run_pass
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@op
def rx(theta: float) -> types.Op:
    """Rotation X gate: `rot(x(), theta)`, a rotation about X by `theta`."""
    return rot(x(), theta)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@op
def ry(theta: float) -> types.Op:
    """Rotation Y gate: `rot(y(), theta)`, a rotation about Y by `theta`."""
    return rot(y(), theta)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@op
def rz(theta: float) -> types.Op:
    """Rotation Z gate: `rot(z(), theta)`, a rotation about Z by `theta`."""
    return rot(z(), theta)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@op
def cx() -> types.Op:
    """Controlled X gate: `x()` with one control site."""
    return control(x(), n_controls=1)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@op
def cy() -> types.Op:
    """Controlled Y gate: `y()` with one control site."""
    return control(y(), n_controls=1)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@op
def cz() -> types.Op:
    """Controlled Z gate: `z()` with one control site."""
    return control(z(), n_controls=1)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@op
def ch() -> types.Op:
    """Controlled H gate: `h()` with one control site."""
    return control(h(), n_controls=1)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@op
def cphase(theta: float) -> types.Op:
    """Controlled phase gate: `phase(theta)` with one control site."""
    return control(phase(theta), n_controls=1)
|