bloqade-circuit 0.6.4__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bloqade/analysis/address/__init__.py +8 -4
- bloqade/analysis/address/analysis.py +123 -33
- bloqade/analysis/address/impls.py +293 -90
- bloqade/analysis/address/lattice.py +209 -24
- bloqade/analysis/fidelity/analysis.py +11 -23
- bloqade/analysis/measure_id/analysis.py +18 -20
- bloqade/analysis/measure_id/impls.py +31 -29
- bloqade/annotate/__init__.py +6 -0
- bloqade/annotate/_dialect.py +3 -0
- bloqade/annotate/_interface.py +22 -0
- bloqade/annotate/stmts.py +29 -0
- bloqade/annotate/types.py +13 -0
- bloqade/cirq_utils/__init__.py +4 -2
- bloqade/cirq_utils/emit/__init__.py +3 -0
- bloqade/cirq_utils/emit/base.py +246 -0
- bloqade/cirq_utils/emit/gate.py +104 -0
- bloqade/cirq_utils/emit/noise.py +90 -0
- bloqade/cirq_utils/emit/qubit.py +35 -0
- bloqade/cirq_utils/lowering.py +660 -0
- bloqade/cirq_utils/noise/__init__.py +0 -2
- bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
- bloqade/cirq_utils/noise/model.py +151 -191
- bloqade/cirq_utils/noise/transform.py +2 -2
- bloqade/cirq_utils/parallelize.py +9 -6
- bloqade/gemini/__init__.py +1 -0
- bloqade/gemini/analysis/__init__.py +3 -0
- bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
- bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
- bloqade/gemini/analysis/logical_validation/impls.py +101 -0
- bloqade/gemini/groups.py +67 -0
- bloqade/native/__init__.py +23 -0
- bloqade/native/_prelude.py +45 -0
- bloqade/native/dialects/__init__.py +0 -0
- bloqade/native/dialects/gate/__init__.py +2 -0
- bloqade/native/dialects/gate/_dialect.py +3 -0
- bloqade/native/dialects/gate/_interface.py +32 -0
- bloqade/native/dialects/gate/stmts.py +31 -0
- bloqade/native/stdlib/__init__.py +0 -0
- bloqade/native/stdlib/broadcast.py +246 -0
- bloqade/native/stdlib/simple.py +220 -0
- bloqade/native/upstream/__init__.py +4 -0
- bloqade/native/upstream/squin2native.py +79 -0
- bloqade/pyqrack/__init__.py +2 -2
- bloqade/pyqrack/base.py +7 -1
- bloqade/pyqrack/device.py +192 -18
- bloqade/pyqrack/native.py +49 -0
- bloqade/pyqrack/reg.py +6 -6
- bloqade/pyqrack/squin/gate/__init__.py +1 -0
- bloqade/pyqrack/squin/gate/gate.py +136 -0
- bloqade/pyqrack/squin/noise/native.py +120 -54
- bloqade/pyqrack/squin/qubit.py +39 -36
- bloqade/pyqrack/target.py +5 -4
- bloqade/pyqrack/task.py +114 -7
- bloqade/qasm2/_qasm_loading.py +3 -3
- bloqade/qasm2/dialects/core/address.py +21 -12
- bloqade/qasm2/dialects/expr/_emit.py +19 -8
- bloqade/qasm2/dialects/expr/stmts.py +7 -7
- bloqade/qasm2/dialects/noise/fidelity.py +4 -8
- bloqade/qasm2/dialects/noise/model.py +2 -1
- bloqade/qasm2/emit/base.py +16 -11
- bloqade/qasm2/emit/gate.py +11 -8
- bloqade/qasm2/emit/main.py +103 -3
- bloqade/qasm2/emit/target.py +9 -5
- bloqade/qasm2/groups.py +3 -2
- bloqade/qasm2/parse/lowering.py +0 -1
- bloqade/qasm2/passes/fold.py +14 -73
- bloqade/qasm2/passes/glob.py +2 -2
- bloqade/qasm2/passes/noise.py +1 -1
- bloqade/qasm2/passes/parallel.py +7 -5
- bloqade/qasm2/rewrite/__init__.py +0 -1
- bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
- bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
- bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
- bloqade/qasm2/rewrite/register.py +2 -2
- bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
- bloqade/qbraid/lowering.py +1 -0
- bloqade/qbraid/schema.py +2 -2
- bloqade/qubit/__init__.py +12 -0
- bloqade/qubit/_dialect.py +3 -0
- bloqade/qubit/_interface.py +49 -0
- bloqade/qubit/_prelude.py +45 -0
- bloqade/qubit/analysis/__init__.py +1 -0
- bloqade/qubit/analysis/address_impl.py +40 -0
- bloqade/qubit/stdlib/__init__.py +2 -0
- bloqade/qubit/stdlib/_new.py +34 -0
- bloqade/qubit/stdlib/broadcast.py +62 -0
- bloqade/qubit/stdlib/simple.py +59 -0
- bloqade/qubit/stmts.py +60 -0
- bloqade/rewrite/passes/__init__.py +6 -0
- bloqade/rewrite/passes/aggressive_unroll.py +103 -0
- bloqade/rewrite/passes/callgraph.py +116 -0
- bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
- bloqade/rewrite/rules/split_ifs.py +18 -1
- bloqade/squin/__init__.py +47 -14
- bloqade/squin/analysis/__init__.py +0 -1
- bloqade/squin/analysis/schedule.py +10 -11
- bloqade/squin/gate/__init__.py +2 -0
- bloqade/squin/gate/_dialect.py +3 -0
- bloqade/squin/gate/_interface.py +98 -0
- bloqade/squin/gate/stmts.py +125 -0
- bloqade/squin/groups.py +5 -22
- bloqade/squin/noise/__init__.py +1 -10
- bloqade/squin/noise/_dialect.py +1 -1
- bloqade/squin/noise/_interface.py +45 -0
- bloqade/squin/noise/stmts.py +66 -28
- bloqade/squin/rewrite/U3_to_clifford.py +70 -51
- bloqade/squin/rewrite/__init__.py +0 -2
- bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
- bloqade/squin/rewrite/wrap_analysis.py +4 -35
- bloqade/squin/stdlib/__init__.py +0 -0
- bloqade/squin/stdlib/broadcast/__init__.py +34 -0
- bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
- bloqade/squin/stdlib/broadcast/gate.py +260 -0
- bloqade/squin/stdlib/broadcast/noise.py +144 -0
- bloqade/squin/stdlib/simple/__init__.py +33 -0
- bloqade/squin/stdlib/simple/gate.py +242 -0
- bloqade/squin/stdlib/simple/noise.py +126 -0
- bloqade/stim/__init__.py +1 -0
- bloqade/stim/_wrappers.py +6 -0
- bloqade/stim/dialects/auxiliary/emit.py +19 -18
- bloqade/stim/dialects/collapse/emit_str.py +7 -8
- bloqade/stim/dialects/gate/emit.py +9 -10
- bloqade/stim/dialects/noise/emit.py +17 -13
- bloqade/stim/dialects/noise/stmts.py +5 -3
- bloqade/stim/emit/__init__.py +1 -0
- bloqade/stim/emit/impls.py +16 -0
- bloqade/stim/emit/stim_str.py +48 -31
- bloqade/stim/groups.py +12 -2
- bloqade/stim/parse/lowering.py +14 -17
- bloqade/stim/passes/__init__.py +0 -2
- bloqade/stim/passes/flatten.py +26 -0
- bloqade/stim/passes/simplify_ifs.py +6 -1
- bloqade/stim/passes/squin_to_stim.py +9 -84
- bloqade/stim/rewrite/__init__.py +2 -4
- bloqade/stim/rewrite/get_record_util.py +24 -0
- bloqade/stim/rewrite/ifs_to_stim.py +24 -25
- bloqade/stim/rewrite/qubit_to_stim.py +90 -41
- bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
- bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
- bloqade/stim/rewrite/squin_measure.py +9 -18
- bloqade/stim/rewrite/squin_noise.py +134 -108
- bloqade/stim/rewrite/util.py +5 -192
- bloqade/test_utils.py +1 -1
- bloqade/types.py +10 -0
- bloqade/validation/__init__.py +2 -0
- bloqade/validation/analysis/__init__.py +5 -0
- bloqade/validation/analysis/analysis.py +41 -0
- bloqade/validation/analysis/lattice.py +58 -0
- bloqade/validation/kernel_validation.py +77 -0
- {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
- bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
- bloqade/pyqrack/squin/op.py +0 -180
- bloqade/pyqrack/squin/runtime.py +0 -535
- bloqade/pyqrack/squin/wire.py +0 -51
- bloqade/rewrite/rules/flatten_ilist.py +0 -51
- bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
- bloqade/squin/_typeinfer.py +0 -20
- bloqade/squin/analysis/address_impl.py +0 -71
- bloqade/squin/analysis/nsites/__init__.py +0 -9
- bloqade/squin/analysis/nsites/analysis.py +0 -50
- bloqade/squin/analysis/nsites/impls.py +0 -92
- bloqade/squin/analysis/nsites/lattice.py +0 -49
- bloqade/squin/cirq/__init__.py +0 -280
- bloqade/squin/cirq/emit/emit_circuit.py +0 -109
- bloqade/squin/cirq/emit/noise.py +0 -49
- bloqade/squin/cirq/emit/op.py +0 -125
- bloqade/squin/cirq/emit/qubit.py +0 -60
- bloqade/squin/cirq/emit/runtime.py +0 -242
- bloqade/squin/cirq/lowering.py +0 -440
- bloqade/squin/lowering.py +0 -54
- bloqade/squin/noise/_wrapper.py +0 -40
- bloqade/squin/noise/rewrite.py +0 -111
- bloqade/squin/op/__init__.py +0 -41
- bloqade/squin/op/_dialect.py +0 -3
- bloqade/squin/op/_wrapper.py +0 -121
- bloqade/squin/op/number.py +0 -5
- bloqade/squin/op/rewrite.py +0 -46
- bloqade/squin/op/stdlib.py +0 -62
- bloqade/squin/op/stmts.py +0 -276
- bloqade/squin/op/traits.py +0 -43
- bloqade/squin/op/types.py +0 -26
- bloqade/squin/qubit.py +0 -184
- bloqade/squin/rewrite/canonicalize.py +0 -60
- bloqade/squin/rewrite/desugar.py +0 -124
- bloqade/squin/types.py +0 -8
- bloqade/squin/wire.py +0 -201
- bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
- bloqade/stim/rewrite/wire_to_stim.py +0 -57
- bloqade_circuit-0.6.4.dist-info/RECORD +0 -234
- {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from dataclasses import field, dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir, passes, rewrite
|
|
4
|
+
from kirin.analysis import CallGraph
|
|
5
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
6
|
+
from kirin.dialects.func.stmts import Invoke
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class ReplaceMethods(RewriteRule):
    """Redirect ``func.Invoke`` statements to replacement callees.

    Every ``Invoke`` whose callee is a key of ``new_symbols`` is replaced by a
    fresh ``Invoke`` with the same inputs and purity that calls the mapped
    method instead. Other statements are left untouched.
    """

    new_symbols: dict[ir.Method, ir.Method]  # original callee -> replacement callee

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, Invoke):
            return RewriteResult()

        new_callee = self.new_symbols.get(node.callee)
        # An identity mapping (e.g. the entry method mapped to itself when it
        # calls itself recursively) is not a change. Rebuilding the Invoke in
        # that case would report progress on every run of this rule.
        if new_callee is None or new_callee is node.callee:
            return RewriteResult()

        node.replace_by(
            Invoke(
                inputs=node.inputs,
                callee=new_callee,
                purity=node.purity,
            )
        )

        return RewriteResult(has_done_something=True)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class UpdateDialectsOnCallGraph(passes.Pass):
    """Update all dialects on the call graph to the dialects given to this pass.

    Usage:
        pass_ = UpdateDialectsOnCallGraph(dialects=new_dialects)
        pass_(some_method)

    Note: The input method itself is neither copied nor re-dialected. Every
    other method found in its call graph is copied with
    ``similar(self.dialects)``. Invokes are then re-linked to the copies, and
    each method in the mapping is folded with ``self.dialects``.
    """

    fold_pass: passes.Fold = field(init=False)

    def __post_init__(self):
        self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)

    def unsafe_run(self, mt: ir.Method) -> RewriteResult:
        mt_map: dict[ir.Method, ir.Method] = {}

        cg = CallGraph(mt)

        # ``cg.defs`` values are collections of methods; flatten them into one
        # set. A comprehension avoids the quadratic ``sum(..., ())`` tuple
        # concatenation while inserting elements in the same order.
        all_methods = {method for methods in cg.defs.values() for method in methods}
        for original_mt in all_methods:
            if original_mt is mt:
                new_mt = original_mt
            else:
                new_mt = original_mt.similar(self.dialects)
            mt_map[original_mt] = new_mt

        result = RewriteResult()

        # The replacement rule only reads ``mt_map``, so build it once.
        replace_rule = ReplaceMethods(mt_map)
        for new_mt in mt_map.values():
            result = rewrite.Walk(replace_rule).rewrite(new_mt.code).join(result)
            self.fold_pass(new_mt)

        return result
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass
class CallGraphPass(passes.Pass):
    """Apply ``rule`` to the input method and to copies of all its callees.

    Usage:
        rule = Walk(SomeRewriteRule())
        pass_ = CallGraphPass(rule=rule, dialects=...)
        pass_(some_method)

    Note: The input method is rewritten in place. Every other method in its
    call graph is copied with ``similar()`` first, and only the copy is
    rewritten. If the rule reported any change, invokes are re-linked to the
    copies and every method is folded.
    """

    rule: RewriteRule
    """The rule to apply to each function in the call graph."""

    fold_pass: passes.Fold = field(init=False)

    def __post_init__(self):
        self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)

    def unsafe_run(self, mt: ir.Method) -> RewriteResult:
        combined = RewriteResult()
        copies = {}

        graph = CallGraph(mt)

        for callee in set(graph.edges.keys()):
            # Keep the entry method itself; work on copies of everything else.
            target = callee if callee is mt else callee.similar()
            combined = self.rule.rewrite(target.code).join(combined)
            copies[callee] = target

        if not combined.has_done_something:
            return combined

        # Point invokes at the copied callees, then fold each method.
        relink_rule = ReplaceMethods(copies)
        for target in copies.values():
            rewrite.Walk(relink_rule).rewrite(target.code)
            self.fold_pass(target)

        return combined
|
|
@@ -1,28 +1,34 @@
|
|
|
1
|
-
from dataclasses import dataclass
|
|
1
|
+
from dataclasses import field, dataclass
|
|
2
2
|
|
|
3
|
-
from kirin import ir
|
|
4
|
-
from kirin.passes import Pass
|
|
3
|
+
from kirin import ir, passes
|
|
5
4
|
from kirin.rewrite import (
|
|
6
5
|
Walk,
|
|
7
6
|
Chain,
|
|
8
7
|
Fixpoint,
|
|
8
|
+
DeadCodeElimination,
|
|
9
9
|
)
|
|
10
|
-
from kirin.
|
|
11
|
-
|
|
12
|
-
from ..rules.flatten_ilist import FlattenAddOpIList
|
|
13
|
-
from ..rules.inline_getitem_ilist import InlineGetItemFromIList
|
|
10
|
+
from kirin.dialects.ilist import rewrite
|
|
14
11
|
|
|
15
12
|
|
|
16
13
|
@dataclass
class CanonicalizeIList(passes.Pass):
    """Canonicalize ``IList`` operations in a method, then constant-fold it.

    The ``InlineGetItem``, ``FlattenAdd`` and ``HintLen`` IList rewrites run
    together with dead-code elimination, walked over the method body until a
    fixpoint is reached. A ``Fold`` pass then runs on the result. The returned
    ``RewriteResult`` joins both stages.
    """

    fold_pass: passes.Fold = field(init=False)

    def __post_init__(self):
        self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)

    def unsafe_run(self, mt: ir.Method):
        ilist_rules = Chain(
            rewrite.InlineGetItem(),
            rewrite.FlattenAdd(),
            rewrite.HintLen(),
            DeadCodeElimination(),
        )
        canonical = Fixpoint(Walk(ilist_rules)).rewrite(mt.code)

        folded = self.fold_pass(mt)
        return folded.join(canonical)
|
|
@@ -46,9 +46,13 @@ class SplitIfStmts(RewriteRule):
|
|
|
46
46
|
if not isinstance(node, scf.IfElse):
|
|
47
47
|
return RewriteResult()
|
|
48
48
|
|
|
49
|
+
# NOTE: only empty else bodies are allowed in valid QASM2
|
|
50
|
+
if not self._has_empty_else(node):
|
|
51
|
+
return RewriteResult()
|
|
52
|
+
|
|
49
53
|
*stmts, yield_or_return = node.then_body.stmts()
|
|
50
54
|
|
|
51
|
-
if len(stmts)
|
|
55
|
+
if len(stmts) <= 1:
|
|
52
56
|
return RewriteResult()
|
|
53
57
|
|
|
54
58
|
is_yield = isinstance(yield_or_return, scf.Yield)
|
|
@@ -71,3 +75,16 @@ class SplitIfStmts(RewriteRule):
|
|
|
71
75
|
node.delete()
|
|
72
76
|
|
|
73
77
|
return RewriteResult(has_done_something=True)
|
|
78
|
+
|
|
79
|
+
def _has_empty_else(self, node: scf.IfElse) -> bool:
    """Return True when the else branch of ``node`` does nothing.

    The else body counts as empty if it has no statements at all, or exactly
    one statement that is an ``scf.Yield`` yielding no values.
    """
    else_stmts = list(node.else_body.stmts())
    if not else_stmts:
        return True
    if len(else_stmts) != 1:
        return False

    (only_stmt,) = else_stmts
    return isinstance(only_stmt, scf.Yield) and len(only_stmt.values) == 0
|
bloqade/squin/__init__.py
CHANGED
|
@@ -1,19 +1,52 @@
|
|
|
1
1
|
from . import (
|
|
2
|
-
|
|
3
|
-
wire as wire,
|
|
2
|
+
gate as gate,
|
|
4
3
|
noise as noise,
|
|
5
|
-
qubit as qubit,
|
|
6
4
|
analysis as analysis,
|
|
7
|
-
lowering as lowering,
|
|
8
|
-
_typeinfer as _typeinfer,
|
|
9
5
|
)
|
|
10
|
-
from
|
|
6
|
+
from .. import qubit as qubit, annotate as annotate
|
|
7
|
+
from ..qubit import (
|
|
8
|
+
reset as reset,
|
|
9
|
+
qalloc as qalloc,
|
|
10
|
+
measure as measure,
|
|
11
|
+
get_qubit_id as get_qubit_id,
|
|
12
|
+
get_measurement_id as get_measurement_id,
|
|
13
|
+
)
|
|
14
|
+
from .groups import kernel as kernel
|
|
15
|
+
from ..annotate import set_detector as set_detector, set_observable as set_observable
|
|
16
|
+
from .stdlib.simple import (
|
|
17
|
+
h as h,
|
|
18
|
+
s as s,
|
|
19
|
+
t as t,
|
|
20
|
+
x as x,
|
|
21
|
+
y as y,
|
|
22
|
+
z as z,
|
|
23
|
+
cx as cx,
|
|
24
|
+
cy as cy,
|
|
25
|
+
cz as cz,
|
|
26
|
+
rx as rx,
|
|
27
|
+
ry as ry,
|
|
28
|
+
rz as rz,
|
|
29
|
+
u3 as u3,
|
|
30
|
+
s_adj as s_adj,
|
|
31
|
+
shift as shift,
|
|
32
|
+
t_adj as t_adj,
|
|
33
|
+
sqrt_x as sqrt_x,
|
|
34
|
+
sqrt_y as sqrt_y,
|
|
35
|
+
sqrt_z as sqrt_z,
|
|
36
|
+
bit_flip as bit_flip,
|
|
37
|
+
depolarize as depolarize,
|
|
38
|
+
qubit_loss as qubit_loss,
|
|
39
|
+
sqrt_x_adj as sqrt_x_adj,
|
|
40
|
+
sqrt_y_adj as sqrt_y_adj,
|
|
41
|
+
sqrt_z_adj as sqrt_z_adj,
|
|
42
|
+
depolarize2 as depolarize2,
|
|
43
|
+
correlated_qubit_loss as correlated_qubit_loss,
|
|
44
|
+
two_qubit_pauli_channel as two_qubit_pauli_channel,
|
|
45
|
+
single_qubit_pauli_channel as single_qubit_pauli_channel,
|
|
46
|
+
)
|
|
11
47
|
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
else:
|
|
18
|
-
from . import cirq as cirq
|
|
19
|
-
from .cirq import load_circuit as load_circuit
|
|
48
|
+
# NOTE: it's important to keep these imports here since they import squin.kernel
|
|
49
|
+
# we skip isort here
|
|
50
|
+
from .stdlib import ( # isort: skip
|
|
51
|
+
broadcast as broadcast,
|
|
52
|
+
)
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
from . import address_impl as address_impl
|
|
@@ -185,18 +185,17 @@ class DagScheduleAnalysis(Forward[GateSchedule]):
|
|
|
185
185
|
self.stmt_dag = StmtDag()
|
|
186
186
|
self.use_def = {}
|
|
187
187
|
|
|
188
|
-
def
|
|
189
|
-
|
|
190
|
-
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
|
|
188
|
+
def method_self(self, method: ir.Method) -> GateSchedule:
    """Return the abstract value bound to the method's ``self`` argument.

    This is always the lattice bottom.
    """
    bottom = self.lattice.bottom()
    return bottom
|
|
191
190
|
|
|
192
|
-
def
|
|
193
|
-
if
|
|
191
|
+
def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
    """Handle statements without a dedicated implementation.

    For a terminator, ``push_current_dag`` is called with its parent block.
    Every result of ``node`` is assigned the lattice top.
    """
    if node.has_trait(ir.IsTerminator):
        parent = node.parent_block
        assert parent is not None, "Terminator statement has no parent block"
        self.push_current_dag(parent)

    return tuple(self.lattice.top() for _ in node.results)
|
|
200
199
|
|
|
201
200
|
def _update_dag(self, stmt: ir.Statement, addr: address.Address):
|
|
202
201
|
if isinstance(addr, address.AddressQubit):
|
|
@@ -210,8 +209,8 @@ class DagScheduleAnalysis(Forward[GateSchedule]):
|
|
|
210
209
|
if old_stmt is not None:
|
|
211
210
|
self.stmt_dag.add_edge(old_stmt, stmt)
|
|
212
211
|
self.use_def[idx] = stmt
|
|
213
|
-
elif isinstance(addr, address.
|
|
214
|
-
for sub_addr in addr.
|
|
212
|
+
elif isinstance(addr, address.AddressReg):
|
|
213
|
+
for sub_addr in addr.qubits:
|
|
215
214
|
self._update_dag(stmt, sub_addr)
|
|
216
215
|
|
|
217
216
|
def update_dag(self, stmt: ir.Statement, args: Sequence[ir.SSAValue]):
|
|
@@ -226,7 +225,7 @@ class DagScheduleAnalysis(Forward[GateSchedule]):
|
|
|
226
225
|
if args is None:
|
|
227
226
|
args = tuple(self.lattice.top() for _ in mt.args)
|
|
228
227
|
|
|
229
|
-
self.run(mt
|
|
228
|
+
self.run(mt)
|
|
230
229
|
return self.stmt_dags
|
|
231
230
|
|
|
232
231
|
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from typing import Any, TypeVar
|
|
2
|
+
|
|
3
|
+
from kirin.dialects import ilist
|
|
4
|
+
from kirin.lowering import wraps
|
|
5
|
+
|
|
6
|
+
from bloqade.types import Qubit
|
|
7
|
+
|
|
8
|
+
from .stmts import (
|
|
9
|
+
CX,
|
|
10
|
+
CY,
|
|
11
|
+
CZ,
|
|
12
|
+
U3,
|
|
13
|
+
H,
|
|
14
|
+
S,
|
|
15
|
+
T,
|
|
16
|
+
X,
|
|
17
|
+
Y,
|
|
18
|
+
Z,
|
|
19
|
+
Rx,
|
|
20
|
+
Ry,
|
|
21
|
+
Rz,
|
|
22
|
+
SqrtX,
|
|
23
|
+
SqrtY,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@wraps(X)
def x(qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to the ``X`` statement with ``qubits`` as its operand."""
    ...


@wraps(Y)
def y(qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to the ``Y`` statement with ``qubits`` as its operand."""
    ...


@wraps(Z)
def z(qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to the ``Z`` statement with ``qubits`` as its operand."""
    ...


@wraps(H)
def h(qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to the ``H`` statement with ``qubits`` as its operand."""
    ...
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# NOTE: ``adjoint`` defaults to False here to match the default of the
# ``adjoint`` attribute on ``SingleQubitNonHermitianGate``; passing it
# explicitly keeps working exactly as before.


@wraps(T)
def t(qubits: ilist.IList[Qubit, Any], *, adjoint: bool = False) -> None:
    """Lowered to the ``T`` statement on ``qubits``.

    ``adjoint`` selects the statement's ``adjoint`` attribute.
    """


@wraps(S)
def s(qubits: ilist.IList[Qubit, Any], *, adjoint: bool = False) -> None:
    """Lowered to the ``S`` statement on ``qubits``.

    ``adjoint`` selects the statement's ``adjoint`` attribute.
    """


@wraps(SqrtX)
def sqrt_x(qubits: ilist.IList[Qubit, Any], *, adjoint: bool = False) -> None:
    """Lowered to the ``SqrtX`` statement on ``qubits``.

    ``adjoint`` selects the statement's ``adjoint`` attribute.
    """


@wraps(SqrtY)
def sqrt_y(qubits: ilist.IList[Qubit, Any], *, adjoint: bool = False) -> None:
    """Lowered to the ``SqrtY`` statement on ``qubits``.

    ``adjoint`` selects the statement's ``adjoint`` attribute.
    """
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@wraps(Rx)
def rx(angle: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to the ``Rx`` statement; ``angle`` comes before ``qubits``.

    The angle unit is not stated here — NOTE(review): confirm the convention.
    """
    ...


@wraps(Ry)
def ry(angle: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to the ``Ry`` statement; ``angle`` comes before ``qubits``."""
    ...


@wraps(Rz)
def rz(angle: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to the ``Rz`` statement; ``angle`` comes before ``qubits``."""
    ...
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# Shared length parameter: ``controls`` and ``targets`` below use the same
# type variable, so the two lists are expected to have equal length.
Len = TypeVar("Len", bound=int)


@wraps(CX)
def cx(
    controls: ilist.IList[Qubit, Len],
    targets: ilist.IList[Qubit, Len],
) -> None:
    """Lowered to the ``CX`` statement with ``controls`` and ``targets``."""
    ...


@wraps(CY)
def cy(
    controls: ilist.IList[Qubit, Len],
    targets: ilist.IList[Qubit, Len],
) -> None:
    """Lowered to the ``CY`` statement with ``controls`` and ``targets``."""
    ...


@wraps(CZ)
def cz(
    controls: ilist.IList[Qubit, Len],
    targets: ilist.IList[Qubit, Len],
) -> None:
    """Lowered to the ``CZ`` statement with ``controls`` and ``targets``."""
    ...
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@wraps(U3)
def u3(
    theta: float, phi: float, lam: float, qubits: ilist.IList[Qubit, Any]
) -> None:
    """Lowered to the ``U3`` statement.

    The three angles ``theta``, ``phi`` and ``lam`` come before ``qubits``,
    in the same order as the statement's fields.
    """
    ...
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
from kirin import ir, types, lowering
|
|
2
|
+
from kirin.decl import info, statement
|
|
3
|
+
from kirin.dialects import ilist
|
|
4
|
+
|
|
5
|
+
from bloqade.types import QubitType
|
|
6
|
+
|
|
7
|
+
from ._dialect import dialect
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@statement
class Gate(ir.Statement):
    """Common base class of every gate statement declared in this module."""

    # NOTE: just for easier isinstance checks elsewhere, all gates inherit from this class
    pass


@statement
class SingleQubitGate(Gate):
    """Base for gates whose only operand is an ``IList`` of qubits.

    Carries the ``FromPythonCall`` lowering trait. ``qubits`` is typed
    ``IList[QubitType, Any]`` (any length).
    """

    traits = frozenset({lowering.FromPythonCall()})
    qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@statement(dialect=dialect)
class X(SingleQubitGate):
    """``X`` gate statement on ``qubits``."""

    pass


@statement(dialect=dialect)
class Y(SingleQubitGate):
    """``Y`` gate statement on ``qubits``."""

    pass


@statement(dialect=dialect)
class Z(SingleQubitGate):
    """``Z`` gate statement on ``qubits``."""

    pass


@statement(dialect=dialect)
class H(SingleQubitGate):
    """``H`` gate statement on ``qubits``."""

    pass
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@statement
class SingleQubitNonHermitianGate(SingleQubitGate):
    """Single-qubit gate that also carries an ``adjoint`` flag.

    ``adjoint`` is a statement attribute (not an SSA operand) and defaults to
    ``False``.
    """

    adjoint: bool = info.attribute(default=False)


@statement(dialect=dialect)
class T(SingleQubitNonHermitianGate):
    """``T`` gate statement; see ``adjoint`` on the base class."""

    pass


@statement(dialect=dialect)
class S(SingleQubitNonHermitianGate):
    """``S`` gate statement; see ``adjoint`` on the base class."""

    pass


@statement(dialect=dialect)
class SqrtX(SingleQubitNonHermitianGate):
    """``SqrtX`` gate statement; see ``adjoint`` on the base class."""

    pass


@statement(dialect=dialect)
class SqrtY(SingleQubitNonHermitianGate):
    """``SqrtY`` gate statement; see ``adjoint`` on the base class."""

    pass
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@statement
class RotationGate(Gate):
    """Base for single-angle rotation gates.

    Operands, in order: ``angle`` (Float), then ``qubits``
    (``IList[QubitType, Any]``).
    """

    # NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg
    traits = frozenset({lowering.FromPythonCall()})
    angle: ir.SSAValue = info.argument(types.Float)
    qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])


@statement(dialect=dialect)
class Rx(RotationGate):
    """``Rx`` rotation statement."""

    pass


@statement(dialect=dialect)
class Ry(RotationGate):
    """``Ry`` rotation statement."""

    pass


@statement(dialect=dialect)
class Rz(RotationGate):
    """``Rz`` rotation statement."""

    pass
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# Length type variable shared by ``controls`` and ``targets`` below.
N = types.TypeVar("N", bound=types.Int)


@statement
class ControlledGate(Gate):
    """Base for gates with separate control and target qubit lists.

    Both operands are typed ``IList[QubitType, N]`` with the same ``N``, so
    the two lists are expected to have equal length. Presumably
    ``controls[i]`` pairs with ``targets[i]`` — TODO confirm in the lowering
    and emit code.
    """

    traits = frozenset({lowering.FromPythonCall()})
    controls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
    targets: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])


@statement(dialect=dialect)
class CX(ControlledGate):
    """``CX`` statement; its statement name is ``"cx"``."""

    name = "cx"
    pass


@statement(dialect=dialect)
class CY(ControlledGate):
    """``CY`` statement; its statement name is ``"cy"``."""

    name = "cy"
    pass


@statement(dialect=dialect)
class CZ(ControlledGate):
    """``CZ`` statement; its statement name is ``"cz"``."""

    name = "cz"
    pass
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@statement(dialect=dialect)
class U3(Gate):
    """Three-angle ``U3`` gate statement.

    Operands, in order: ``theta``, ``phi``, ``lam`` (all Float), then
    ``qubits`` (``IList[QubitType, Any]``).
    """

    # NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg
    traits = frozenset({lowering.FromPythonCall()})
    theta: ir.SSAValue = info.argument(types.Float)
    phi: ir.SSAValue = info.argument(types.Float)
    lam: ir.SSAValue = info.argument(types.Float)
    qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
|
bloqade/squin/groups.py
CHANGED
|
@@ -1,31 +1,24 @@
|
|
|
1
1
|
from kirin import ir, passes
|
|
2
2
|
from kirin.prelude import structural_no_opt
|
|
3
|
-
from kirin.
|
|
4
|
-
from kirin.dialects import ilist
|
|
3
|
+
from kirin.dialects import debug, ilist
|
|
5
4
|
|
|
6
|
-
from . import
|
|
7
|
-
from
|
|
8
|
-
from .rewrite.desugar import ApplyDesugarRule, MeasureDesugarRule
|
|
5
|
+
from . import gate, noise
|
|
6
|
+
from .. import qubit, annotate
|
|
9
7
|
|
|
10
8
|
|
|
11
|
-
@ir.dialect_group(structural_no_opt.union([
|
|
9
|
+
@ir.dialect_group(structural_no_opt.union([qubit, noise, gate, debug, annotate]))
|
|
12
10
|
def kernel(self):
|
|
13
11
|
fold_pass = passes.Fold(self)
|
|
14
12
|
typeinfer_pass = passes.TypeInfer(self)
|
|
15
13
|
ilist_desugar_pass = ilist.IListDesugar(self)
|
|
16
|
-
desugar_pass = Walk(Chain(MeasureDesugarRule(), ApplyDesugarRule()))
|
|
17
|
-
py_mult_to_mult_pass = PyMultToSquinMult(self)
|
|
18
14
|
|
|
19
15
|
def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
|
|
20
16
|
method.verify()
|
|
21
17
|
if fold:
|
|
22
18
|
fold_pass.fixpoint(method)
|
|
23
19
|
|
|
24
|
-
py_mult_to_mult_pass(method)
|
|
25
|
-
|
|
26
20
|
if typeinfer:
|
|
27
|
-
typeinfer_pass(method)
|
|
28
|
-
desugar_pass.rewrite(method.code)
|
|
21
|
+
typeinfer_pass(method) # infer types before desugaring
|
|
29
22
|
|
|
30
23
|
ilist_desugar_pass(method)
|
|
31
24
|
|
|
@@ -34,13 +27,3 @@ def kernel(self):
|
|
|
34
27
|
method.verify_type()
|
|
35
28
|
|
|
36
29
|
return run_pass
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
@ir.dialect_group(structural_no_opt.union([op, wire, noise]))
|
|
40
|
-
def wired(self):
|
|
41
|
-
py_mult_to_mult_pass = PyMultToSquinMult(self)
|
|
42
|
-
|
|
43
|
-
def run_pass(method):
|
|
44
|
-
py_mult_to_mult_pass(method)
|
|
45
|
-
|
|
46
|
-
return run_pass
|
bloqade/squin/noise/__init__.py
CHANGED
|
@@ -1,11 +1,2 @@
|
|
|
1
|
-
from . import stmts as stmts
|
|
1
|
+
from . import stmts as stmts, _interface as _interface
|
|
2
2
|
from ._dialect import dialect as dialect
|
|
3
|
-
from ._wrapper import (
|
|
4
|
-
pp_error as pp_error,
|
|
5
|
-
depolarize as depolarize,
|
|
6
|
-
qubit_loss as qubit_loss,
|
|
7
|
-
depolarize2 as depolarize2,
|
|
8
|
-
pauli_error as pauli_error,
|
|
9
|
-
two_qubit_pauli_channel as two_qubit_pauli_channel,
|
|
10
|
-
single_qubit_pauli_channel as single_qubit_pauli_channel,
|
|
11
|
-
)
|
bloqade/squin/noise/_interface.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from typing import Any, Literal, TypeVar
|
|
2
|
+
|
|
3
|
+
from kirin.dialects import ilist
|
|
4
|
+
from kirin.lowering import wraps
|
|
5
|
+
|
|
6
|
+
from bloqade.types import Qubit
|
|
7
|
+
|
|
8
|
+
from . import stmts
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@wraps(stmts.Depolarize)
def depolarize(p: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to ``stmts.Depolarize`` with parameter ``p`` on ``qubits``."""
    ...


# Length type variable: ``controls`` and ``targets`` in the two-qubit
# channels below share it, so they are expected to have equal length.
N = TypeVar("N", bound=int)


@wraps(stmts.Depolarize2)
def depolarize2(
    p: float, controls: ilist.IList[Qubit, N], targets: ilist.IList[Qubit, N]
) -> None:
    """Lowered to ``stmts.Depolarize2`` with parameter ``p``.

    ``controls`` and ``targets`` share the length type ``N``.
    """
    ...
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@wraps(stmts.SingleQubitPauliChannel)
def single_qubit_pauli_channel(
    px: float, py: float, pz: float, qubits: ilist.IList[Qubit, Any]
) -> None:
    """Lowered to ``stmts.SingleQubitPauliChannel`` on ``qubits``.

    Takes the three parameters ``px``, ``py`` and ``pz``.
    """
    ...


@wraps(stmts.TwoQubitPauliChannel)
def two_qubit_pauli_channel(
    probabilities: ilist.IList[float, Literal[15]],
    controls: ilist.IList[Qubit, N],
    targets: ilist.IList[Qubit, N],
) -> None:
    """Lowered to ``stmts.TwoQubitPauliChannel``.

    ``probabilities`` is typed as an IList of exactly 15 floats. Their
    ordering is not stated here — TODO confirm against ``stmts``.
    ``controls`` and ``targets`` share the length type ``N``.
    """
    ...
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@wraps(stmts.QubitLoss)
def qubit_loss(p: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to ``stmts.QubitLoss`` with parameter ``p`` on ``qubits``."""
    ...


@wraps(stmts.CorrelatedQubitLoss)
def correlated_qubit_loss(
    p: float, qubits: ilist.IList[ilist.IList[Qubit, N], Any]
) -> None:
    """Lowered to ``stmts.CorrelatedQubitLoss`` with parameter ``p``.

    ``qubits`` is an IList of qubit groups. Every group is typed with the same
    length ``N``. Presumably each group is lost together — TODO confirm.
    """
    ...
|