bloqade-circuit 0.6.4__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bloqade/analysis/address/__init__.py +8 -4
- bloqade/analysis/address/analysis.py +123 -33
- bloqade/analysis/address/impls.py +293 -90
- bloqade/analysis/address/lattice.py +209 -24
- bloqade/analysis/fidelity/analysis.py +11 -23
- bloqade/analysis/measure_id/analysis.py +18 -20
- bloqade/analysis/measure_id/impls.py +31 -29
- bloqade/annotate/__init__.py +6 -0
- bloqade/annotate/_dialect.py +3 -0
- bloqade/annotate/_interface.py +22 -0
- bloqade/annotate/stmts.py +29 -0
- bloqade/annotate/types.py +13 -0
- bloqade/cirq_utils/__init__.py +4 -2
- bloqade/cirq_utils/emit/__init__.py +3 -0
- bloqade/cirq_utils/emit/base.py +246 -0
- bloqade/cirq_utils/emit/gate.py +104 -0
- bloqade/cirq_utils/emit/noise.py +90 -0
- bloqade/cirq_utils/emit/qubit.py +35 -0
- bloqade/cirq_utils/lowering.py +660 -0
- bloqade/cirq_utils/noise/__init__.py +0 -2
- bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
- bloqade/cirq_utils/noise/model.py +151 -191
- bloqade/cirq_utils/noise/transform.py +2 -2
- bloqade/cirq_utils/parallelize.py +9 -6
- bloqade/gemini/__init__.py +1 -0
- bloqade/gemini/analysis/__init__.py +3 -0
- bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
- bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
- bloqade/gemini/analysis/logical_validation/impls.py +101 -0
- bloqade/gemini/groups.py +67 -0
- bloqade/native/__init__.py +23 -0
- bloqade/native/_prelude.py +45 -0
- bloqade/native/dialects/__init__.py +0 -0
- bloqade/native/dialects/gate/__init__.py +2 -0
- bloqade/native/dialects/gate/_dialect.py +3 -0
- bloqade/native/dialects/gate/_interface.py +32 -0
- bloqade/native/dialects/gate/stmts.py +31 -0
- bloqade/native/stdlib/__init__.py +0 -0
- bloqade/native/stdlib/broadcast.py +246 -0
- bloqade/native/stdlib/simple.py +220 -0
- bloqade/native/upstream/__init__.py +4 -0
- bloqade/native/upstream/squin2native.py +79 -0
- bloqade/pyqrack/__init__.py +2 -2
- bloqade/pyqrack/base.py +7 -1
- bloqade/pyqrack/device.py +192 -18
- bloqade/pyqrack/native.py +49 -0
- bloqade/pyqrack/reg.py +6 -6
- bloqade/pyqrack/squin/gate/__init__.py +1 -0
- bloqade/pyqrack/squin/gate/gate.py +136 -0
- bloqade/pyqrack/squin/noise/native.py +120 -54
- bloqade/pyqrack/squin/qubit.py +39 -36
- bloqade/pyqrack/target.py +5 -4
- bloqade/pyqrack/task.py +114 -7
- bloqade/qasm2/_qasm_loading.py +3 -3
- bloqade/qasm2/dialects/core/address.py +21 -12
- bloqade/qasm2/dialects/expr/_emit.py +19 -8
- bloqade/qasm2/dialects/expr/stmts.py +7 -7
- bloqade/qasm2/dialects/noise/fidelity.py +4 -8
- bloqade/qasm2/dialects/noise/model.py +2 -1
- bloqade/qasm2/emit/base.py +16 -11
- bloqade/qasm2/emit/gate.py +11 -8
- bloqade/qasm2/emit/main.py +103 -3
- bloqade/qasm2/emit/target.py +9 -5
- bloqade/qasm2/groups.py +3 -2
- bloqade/qasm2/parse/lowering.py +0 -1
- bloqade/qasm2/passes/fold.py +14 -73
- bloqade/qasm2/passes/glob.py +2 -2
- bloqade/qasm2/passes/noise.py +1 -1
- bloqade/qasm2/passes/parallel.py +7 -5
- bloqade/qasm2/rewrite/__init__.py +0 -1
- bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
- bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
- bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
- bloqade/qasm2/rewrite/register.py +2 -2
- bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
- bloqade/qbraid/lowering.py +1 -0
- bloqade/qbraid/schema.py +2 -2
- bloqade/qubit/__init__.py +12 -0
- bloqade/qubit/_dialect.py +3 -0
- bloqade/qubit/_interface.py +49 -0
- bloqade/qubit/_prelude.py +45 -0
- bloqade/qubit/analysis/__init__.py +1 -0
- bloqade/qubit/analysis/address_impl.py +40 -0
- bloqade/qubit/stdlib/__init__.py +2 -0
- bloqade/qubit/stdlib/_new.py +34 -0
- bloqade/qubit/stdlib/broadcast.py +62 -0
- bloqade/qubit/stdlib/simple.py +59 -0
- bloqade/qubit/stmts.py +60 -0
- bloqade/rewrite/passes/__init__.py +6 -0
- bloqade/rewrite/passes/aggressive_unroll.py +103 -0
- bloqade/rewrite/passes/callgraph.py +116 -0
- bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
- bloqade/rewrite/rules/split_ifs.py +18 -1
- bloqade/squin/__init__.py +47 -14
- bloqade/squin/analysis/__init__.py +0 -1
- bloqade/squin/analysis/schedule.py +10 -11
- bloqade/squin/gate/__init__.py +2 -0
- bloqade/squin/gate/_dialect.py +3 -0
- bloqade/squin/gate/_interface.py +98 -0
- bloqade/squin/gate/stmts.py +125 -0
- bloqade/squin/groups.py +5 -22
- bloqade/squin/noise/__init__.py +1 -10
- bloqade/squin/noise/_dialect.py +1 -1
- bloqade/squin/noise/_interface.py +45 -0
- bloqade/squin/noise/stmts.py +66 -28
- bloqade/squin/rewrite/U3_to_clifford.py +70 -51
- bloqade/squin/rewrite/__init__.py +0 -2
- bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
- bloqade/squin/rewrite/wrap_analysis.py +4 -35
- bloqade/squin/stdlib/__init__.py +0 -0
- bloqade/squin/stdlib/broadcast/__init__.py +34 -0
- bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
- bloqade/squin/stdlib/broadcast/gate.py +260 -0
- bloqade/squin/stdlib/broadcast/noise.py +144 -0
- bloqade/squin/stdlib/simple/__init__.py +33 -0
- bloqade/squin/stdlib/simple/gate.py +242 -0
- bloqade/squin/stdlib/simple/noise.py +126 -0
- bloqade/stim/__init__.py +1 -0
- bloqade/stim/_wrappers.py +6 -0
- bloqade/stim/dialects/auxiliary/emit.py +19 -18
- bloqade/stim/dialects/collapse/emit_str.py +7 -8
- bloqade/stim/dialects/gate/emit.py +9 -10
- bloqade/stim/dialects/noise/emit.py +17 -13
- bloqade/stim/dialects/noise/stmts.py +5 -3
- bloqade/stim/emit/__init__.py +1 -0
- bloqade/stim/emit/impls.py +16 -0
- bloqade/stim/emit/stim_str.py +48 -31
- bloqade/stim/groups.py +12 -2
- bloqade/stim/parse/lowering.py +14 -17
- bloqade/stim/passes/__init__.py +0 -2
- bloqade/stim/passes/flatten.py +26 -0
- bloqade/stim/passes/simplify_ifs.py +6 -1
- bloqade/stim/passes/squin_to_stim.py +9 -84
- bloqade/stim/rewrite/__init__.py +2 -4
- bloqade/stim/rewrite/get_record_util.py +24 -0
- bloqade/stim/rewrite/ifs_to_stim.py +24 -25
- bloqade/stim/rewrite/qubit_to_stim.py +90 -41
- bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
- bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
- bloqade/stim/rewrite/squin_measure.py +9 -18
- bloqade/stim/rewrite/squin_noise.py +134 -108
- bloqade/stim/rewrite/util.py +5 -192
- bloqade/test_utils.py +1 -1
- bloqade/types.py +10 -0
- bloqade/validation/__init__.py +2 -0
- bloqade/validation/analysis/__init__.py +5 -0
- bloqade/validation/analysis/analysis.py +41 -0
- bloqade/validation/analysis/lattice.py +58 -0
- bloqade/validation/kernel_validation.py +77 -0
- {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
- bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
- bloqade/pyqrack/squin/op.py +0 -180
- bloqade/pyqrack/squin/runtime.py +0 -535
- bloqade/pyqrack/squin/wire.py +0 -51
- bloqade/rewrite/rules/flatten_ilist.py +0 -51
- bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
- bloqade/squin/_typeinfer.py +0 -20
- bloqade/squin/analysis/address_impl.py +0 -71
- bloqade/squin/analysis/nsites/__init__.py +0 -9
- bloqade/squin/analysis/nsites/analysis.py +0 -50
- bloqade/squin/analysis/nsites/impls.py +0 -92
- bloqade/squin/analysis/nsites/lattice.py +0 -49
- bloqade/squin/cirq/__init__.py +0 -280
- bloqade/squin/cirq/emit/emit_circuit.py +0 -109
- bloqade/squin/cirq/emit/noise.py +0 -49
- bloqade/squin/cirq/emit/op.py +0 -125
- bloqade/squin/cirq/emit/qubit.py +0 -60
- bloqade/squin/cirq/emit/runtime.py +0 -242
- bloqade/squin/cirq/lowering.py +0 -440
- bloqade/squin/lowering.py +0 -54
- bloqade/squin/noise/_wrapper.py +0 -40
- bloqade/squin/noise/rewrite.py +0 -111
- bloqade/squin/op/__init__.py +0 -41
- bloqade/squin/op/_dialect.py +0 -3
- bloqade/squin/op/_wrapper.py +0 -121
- bloqade/squin/op/number.py +0 -5
- bloqade/squin/op/rewrite.py +0 -46
- bloqade/squin/op/stdlib.py +0 -62
- bloqade/squin/op/stmts.py +0 -276
- bloqade/squin/op/traits.py +0 -43
- bloqade/squin/op/types.py +0 -26
- bloqade/squin/qubit.py +0 -184
- bloqade/squin/rewrite/canonicalize.py +0 -60
- bloqade/squin/rewrite/desugar.py +0 -124
- bloqade/squin/types.py +0 -8
- bloqade/squin/wire.py +0 -201
- bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
- bloqade/stim/rewrite/wire_to_stim.py +0 -57
- bloqade_circuit-0.6.4.dist-info/RECORD +0 -234
- {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -21,16 +21,10 @@ class ParallelToUOpRule(abc.RewriteRule):
|
|
|
21
21
|
|
|
22
22
|
def get_qubit_ssa(self, ilist_ref: ir.SSAValue) -> Optional[List[ir.SSAValue]]:
|
|
23
23
|
addr = self.address_analysis.get(ilist_ref)
|
|
24
|
-
if not isinstance(addr, address.
|
|
24
|
+
if not isinstance(addr, address.AddressReg):
|
|
25
25
|
return None
|
|
26
26
|
|
|
27
|
-
ids =
|
|
28
|
-
for ele in addr.data:
|
|
29
|
-
if not isinstance(ele, address.AddressQubit):
|
|
30
|
-
return None
|
|
31
|
-
|
|
32
|
-
ids.append(ele.data)
|
|
33
|
-
|
|
27
|
+
ids = addr.data
|
|
34
28
|
return [self.id_map[ele] for ele in ids]
|
|
35
29
|
|
|
36
30
|
def rewrite_cz(self, node: ir.Statement):
|
|
@@ -2,7 +2,7 @@ from kirin import ir
|
|
|
2
2
|
from kirin.dialects import py
|
|
3
3
|
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
4
4
|
|
|
5
|
-
from bloqade.qasm2.dialects import core
|
|
5
|
+
from bloqade.qasm2.dialects import core, expr
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class RaiseRegisterRule(RewriteRule):
|
|
@@ -26,7 +26,7 @@ class RaiseRegisterRule(RewriteRule):
|
|
|
26
26
|
n_qubits_ref = node.n_qubits
|
|
27
27
|
|
|
28
28
|
n_qubits = n_qubits_ref.owner
|
|
29
|
-
if isinstance(n_qubits, py.Constant):
|
|
29
|
+
if isinstance(n_qubits, py.Constant | expr.ConstInt):
|
|
30
30
|
# case where the n_qubits comes from a constant
|
|
31
31
|
new_n_qubits = n_qubits.from_stmt(n_qubits)
|
|
32
32
|
new_n_qubits.insert_before(first_stmt)
|
|
@@ -8,7 +8,7 @@ from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
|
8
8
|
from kirin.analysis.const import lattice
|
|
9
9
|
|
|
10
10
|
from bloqade.analysis import address
|
|
11
|
-
from bloqade.qasm2.dialects import uop, core, parallel
|
|
11
|
+
from bloqade.qasm2.dialects import uop, core, expr, parallel
|
|
12
12
|
from bloqade.squin.analysis.schedule import StmtDag
|
|
13
13
|
|
|
14
14
|
|
|
@@ -66,7 +66,7 @@ class SimpleMergePolicy(MergePolicyABC):
|
|
|
66
66
|
assert isinstance(hint1, lattice.Result) and isinstance(
|
|
67
67
|
hint2, lattice.Result
|
|
68
68
|
)
|
|
69
|
-
return hint1.
|
|
69
|
+
return hint1.is_structurally_equal(hint2)
|
|
70
70
|
else:
|
|
71
71
|
return False
|
|
72
72
|
|
|
@@ -194,6 +194,8 @@ class SimpleMergePolicy(MergePolicyABC):
|
|
|
194
194
|
new_qubits.append(new_qubit.result)
|
|
195
195
|
case core.QRegGet(
|
|
196
196
|
reg=reg, idx=ir.ResultValue(stmt=py.Constant() as idx)
|
|
197
|
+
) | core.QRegGet(
|
|
198
|
+
reg=reg, idx=ir.ResultValue(stmt=expr.ConstInt() as idx)
|
|
197
199
|
):
|
|
198
200
|
(new_idx := idx.from_stmt(idx)).insert_before(node)
|
|
199
201
|
(
|
bloqade/qbraid/lowering.py
CHANGED
bloqade/qbraid/schema.py
CHANGED
|
@@ -238,13 +238,13 @@ class NoiseModel(BaseModel, Generic[ErrorModelType], extra="forbid"):
|
|
|
238
238
|
str: The decompiled circuit from hardware execution.
|
|
239
239
|
|
|
240
240
|
"""
|
|
241
|
-
from bloqade.noise import native
|
|
242
241
|
from bloqade.qasm2.emit import QASM2
|
|
243
242
|
from bloqade.qasm2.passes import glob, parallel
|
|
243
|
+
from bloqade.qasm2.rewrite.noise import remove_noise
|
|
244
244
|
|
|
245
245
|
mt = self.lower_noise_model("method")
|
|
246
246
|
|
|
247
|
-
|
|
247
|
+
remove_noise.RemoveNoisePass(mt.dialects)(mt)
|
|
248
248
|
parallel.ParallelToUOp(mt.dialects)(mt)
|
|
249
249
|
glob.GlobalToUOP(mt.dialects)(mt)
|
|
250
250
|
return QASM2(qelib1=True).emit_str(mt)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from bloqade.types import Qubit as Qubit, QubitType as QubitType
|
|
2
|
+
|
|
3
|
+
from . import stmts as stmts, analysis as analysis
|
|
4
|
+
from .stdlib import new as new, qalloc as qalloc, broadcast as broadcast
|
|
5
|
+
from ._dialect import dialect as dialect
|
|
6
|
+
from ._prelude import kernel as kernel
|
|
7
|
+
from .stdlib.simple import (
|
|
8
|
+
reset as reset,
|
|
9
|
+
measure as measure,
|
|
10
|
+
get_qubit_id as get_qubit_id,
|
|
11
|
+
get_measurement_id as get_measurement_id,
|
|
12
|
+
)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from typing import Any, TypeVar
|
|
2
|
+
|
|
3
|
+
from kirin.dialects import ilist
|
|
4
|
+
from kirin.lowering import wraps
|
|
5
|
+
|
|
6
|
+
from bloqade.types import Qubit, MeasurementResult
|
|
7
|
+
|
|
8
|
+
from .stmts import New, Reset, Measure, QubitId, MeasurementId
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@wraps(New)
def new() -> Qubit:
    """Allocate a single fresh qubit.

    This is a lowering stub for the `New` statement; it has no Python body.

    Returns:
        Qubit: The newly created qubit.
    """
    ...
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Type variable that ties the length of an input list to the length of the
# output list in the signatures below.
N = TypeVar("N", bound=int)


@wraps(Measure)
def measure(qubits: ilist.IList[Qubit, N]) -> ilist.IList[MeasurementResult, N]:
    """Measure every qubit in a list (lowering stub for `Measure`).

    Args:
        qubits (IList[Qubit, N]): The qubits to measure.

    Returns:
        IList[MeasurementResult, N]: One measurement result per input qubit.
            A MeasurementResult can encode the outcomes 0 and 1 as well as
            a lost atom.
    """
    ...
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@wraps(QubitId)
def get_qubit_id(qubits: ilist.IList[Qubit, N]) -> ilist.IList[int, N]:
    """Look up the global, unique integer ID of each qubit in a list.

    Lowering stub for the `QubitId` statement.

    Args:
        qubits (IList[Qubit, N]): The qubits whose IDs are requested.

    Returns:
        IList[int, N]: The ID of each qubit.
    """
    ...
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@wraps(MeasurementId)
def get_measurement_id(
    measurements: ilist.IList[MeasurementResult, N],
) -> ilist.IList[int, N]:
    """Look up the global, unique integer ID of each measurement result.

    Lowering stub for the `MeasurementId` statement.

    Args:
        measurements (IList[MeasurementResult, N]): Previously obtained
            measurement results.

    Returns:
        IList[int, N]: The ID of each measurement result.
    """
    ...
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@wraps(Reset)
def reset(qubits: ilist.IList[Qubit, Any]) -> None:
    """Reset every qubit in a list to the zero state.

    Lowering stub for the `Reset` statement.

    Args:
        qubits (IList[Qubit, Any]): The qubits to reset; any length is accepted.
    """
    ...
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from typing import Annotated
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.passes import Default
|
|
5
|
+
from kirin.prelude import structural_no_opt
|
|
6
|
+
from typing_extensions import Doc
|
|
7
|
+
|
|
8
|
+
from . import _dialect as qubit
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@ir.dialect_group(structural_no_opt.union([qubit]))
def kernel(self):
    """Dialect group for qubit kernels.

    Combines kirin's `structural_no_opt` dialects with the qubit dialect and
    returns the `run_pass` callback (presumably invoked by kirin on each method
    decorated with `@kernel` -- confirm against kirin's `dialect_group`).
    """

    def run_pass(
        mt,
        *,
        verify: Annotated[
            bool, Doc("run `verify` before running passes, default is `True`")
        ] = True,
        typeinfer: Annotated[
            bool,
            Doc(
                "run type inference and apply the inferred type to IR, default `False`"
            ),
        ] = False,
        fold: Annotated[bool, Doc("run folding passes, default is `True`")] = True,
        aggressive: Annotated[
            bool, Doc("run aggressive folding passes if `fold=True`")
        ] = False,
        no_raise: Annotated[
            bool, Doc("do not raise exception during analysis, default is `True`")
        ] = True,
    ) -> None:
        # Build kirin's Default pipeline for this group with the requested
        # options and run it on `mt` until it reaches a fixpoint.
        Default(
            self,
            verify=verify,
            fold=fold,
            aggressive=aggressive,
            typeinfer=typeinfer,
            no_raise=no_raise,
        ).fixpoint(mt)

    return run_pass
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from . import address_impl as address_impl
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from kirin import interp
|
|
2
|
+
from kirin.analysis import ForwardFrame
|
|
3
|
+
|
|
4
|
+
from bloqade.analysis.address.lattice import (
|
|
5
|
+
Address,
|
|
6
|
+
AddressQubit,
|
|
7
|
+
)
|
|
8
|
+
from bloqade.analysis.address.analysis import AddressAnalysis
|
|
9
|
+
|
|
10
|
+
from .. import stmts
|
|
11
|
+
from .._dialect import dialect
|
|
12
|
+
|
|
13
|
+
# Address lattice elements we can work with:
|
|
14
|
+
## NotQubit (bottom), AnyAddress (top)
|
|
15
|
+
|
|
16
|
+
## AddressTuple -> data: tuple[Address, ...]
|
|
17
|
+
### Recursive type, could contain itself or other variants
|
|
18
|
+
### This pops up in cases where you can have an IList/Tuple
|
|
19
|
+
### That contains elements that could be other Address types
|
|
20
|
+
|
|
21
|
+
## AddressReg -> data: Sequence[int]
|
|
22
|
+
### specific to creation of a register of qubits
|
|
23
|
+
|
|
24
|
+
## AddressQubit -> data: int
|
|
25
|
+
### Base qubit address type
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dialect.register(key="qubit.address")
class SquinQubitMethodTable(interp.MethodTable):
    """Address-analysis implementations for statements of the qubit dialect."""

    @interp.impl(stmts.New)
    def new_qubit(
        self,
        interp_: AddressAnalysis,
        frame: ForwardFrame[Address],
        stmt: stmts.New,
    ):
        # Take the analysis' running address counter as this qubit's address,
        # then advance the counter so the next `New` gets a different one.
        qubit_index = interp_.next_address
        interp_.next_address = qubit_index + 1
        return (AddressQubit(qubit_index),)
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from kirin.dialects import ilist
|
|
4
|
+
|
|
5
|
+
from .. import _interface as qubit
|
|
6
|
+
from .._prelude import kernel
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@kernel(typeinfer=True)
def new() -> qubit.Qubit:
    """Allocate exactly one new qubit.

    Returns:
        (Qubit): The qubit that was just allocated.
    """
    fresh_qubit = qubit.new()
    return fresh_qubit
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# NOTE: this is a special case that does not follow the usual simple / broadcast
# split of the stdlib.
@kernel(typeinfer=True)
def qalloc(n_qubits: int) -> ilist.IList[qubit.Qubit, Any]:
    """Allocate a new list of qubits.

    Args:
        n_qubits (int): The number of qubits to allocate.

    Returns:
        IList[Qubit, Any]: A list holding `n_qubits` freshly allocated qubits.
    """

    def _new(qid: int) -> qubit.Qubit:
        # `ilist.map` passes the index, but each qubit is allocated independently
        # of it.
        return qubit.new()

    return ilist.map(_new, ilist.range(n_qubits))
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from typing import Any, TypeVar
|
|
2
|
+
|
|
3
|
+
from kirin.dialects import ilist
|
|
4
|
+
|
|
5
|
+
from bloqade.types import Qubit, MeasurementResult
|
|
6
|
+
|
|
7
|
+
from .. import _interface as _qubit
|
|
8
|
+
from .._prelude import kernel
|
|
9
|
+
|
|
10
|
+
# Length parameter shared between an input list and the output list it maps to.
N = TypeVar("N", bound=int)


@kernel
def reset(qubits: ilist.IList[Qubit, Any]) -> None:
    """
    Reset each qubit in the given list to the zero state.

    Args:
        qubits (IList[Qubit, Any]): The qubits to reset.
    """
    _qubit.reset(qubits)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@kernel
def measure(qubits: ilist.IList[Qubit, N]) -> ilist.IList[MeasurementResult, N]:
    """Measure all qubits in a list.

    Args:
        qubits (IList[Qubit, N]): The qubits to measure.

    Returns:
        IList[MeasurementResult, N]: One measurement result per input qubit.
            A MeasurementResult can encode 0 and 1 as well as atom loss.
    """
    outcomes = _qubit.measure(qubits)
    return outcomes
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@kernel
def get_qubit_id(qubits: ilist.IList[Qubit, N]) -> ilist.IList[int, N]:
    """Return the global, unique ID of every qubit in the list.

    Args:
        qubits (IList[Qubit, N]): The qubits whose IDs are requested.

    Returns:
        qubit_ids (IList[int, N]): The ID of each qubit.
    """
    qubit_ids = _qubit.get_qubit_id(qubits)
    return qubit_ids
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@kernel
def get_measurement_id(
    measurements: ilist.IList[MeasurementResult, N],
) -> ilist.IList[int, N]:
    """Get the global, unique ID of each of the measurement results in the list.

    Args:
        measurements (IList[MeasurementResult, N]): The previously taken
            measurements whose IDs you want to know.

    Returns:
        measurement_ids (IList[int, N]): The list of global, unique IDs of the
            measurements.
    """
    return _qubit.get_measurement_id(measurements)
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from kirin.dialects import ilist
|
|
2
|
+
|
|
3
|
+
from bloqade.types import Qubit, MeasurementResult
|
|
4
|
+
|
|
5
|
+
from . import broadcast
|
|
6
|
+
from .._prelude import kernel
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@kernel
def reset(qubit: Qubit) -> None:
    """
    Reset a qubit to the zero state.

    Args:
        qubit (Qubit): The qubit to reset.
    """
    # Delegate to the list-based broadcast version with a one-element list.
    return broadcast.reset(ilist.IList([qubit]))
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@kernel
def measure(qubit: Qubit) -> MeasurementResult:
    """Measure a single qubit.

    Args:
        qubit (Qubit): The qubit to measure.

    Returns:
        MeasurementResult: The measurement outcome. A MeasurementResult can
            encode 0 and 1 as well as a lost atom.
    """
    # Wrap the qubit in a one-element list, measure it via the broadcast
    # version, and unwrap the only result.
    singleton = ilist.IList([qubit])
    outcomes = broadcast.measure(singleton)
    return outcomes[0]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@kernel
def get_qubit_id(qubit: Qubit) -> int:
    """Return the global, unique ID of a single qubit.

    Args:
        qubit (Qubit): The qubit whose ID is requested.

    Returns:
        qubit_id (int): The qubit's global, unique ID.
    """
    # Reuse the broadcast version on a one-element list and take its entry.
    singleton = ilist.IList([qubit])
    qubit_ids = broadcast.get_qubit_id(singleton)
    return qubit_ids[0]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@kernel
def get_measurement_id(measurement: MeasurementResult) -> int:
    """Get the global, unique ID of the measurement result.

    Args:
        measurement (MeasurementResult): The previously taken measurement whose
            ID you want to know.

    Returns:
        measurement_id (int): The global, unique ID of the measurement.
    """
    # Delegate to the broadcast version with a one-element list and unwrap.
    ids = broadcast.get_measurement_id(ilist.IList([measurement]))
    return ids[0]
|
bloqade/qubit/stmts.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from kirin import ir, types, interp, lowering
|
|
2
|
+
from kirin.decl import info, statement
|
|
3
|
+
from kirin.dialects import ilist
|
|
4
|
+
|
|
5
|
+
from bloqade.types import QubitType, MeasurementResultType
|
|
6
|
+
|
|
7
|
+
from ._dialect import dialect
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@statement(dialect=dialect)
class New(ir.Statement):
    """Allocate a single qubit; lowered from a Python call."""

    traits = frozenset([lowering.FromPythonCall()])
    # The allocated qubit.
    result: ir.ResultValue = info.result(QubitType)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Length type variable shared by a statement's input and output lists.
Len = types.TypeVar("Len", bound=types.Int)


@statement(dialect=dialect)
class Measure(ir.Statement):
    """Measure a list of qubits, producing a result list of the same length."""

    traits = frozenset([lowering.FromPythonCall()])
    # The qubits to measure.
    qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len])
    # One measurement result per input qubit.
    result: ir.ResultValue = info.result(ilist.IListType[MeasurementResultType, Len])
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@statement(dialect=dialect)
class QubitId(ir.Statement):
    """Return the integer ID of each qubit in a list (marked `ir.Pure`)."""

    traits = frozenset([lowering.FromPythonCall(), ir.Pure()])
    # The qubits whose IDs are requested.
    qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len])
    # One integer ID per input qubit.
    result: ir.ResultValue = info.result(ilist.IListType[types.Int, Len])
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@statement(dialect=dialect)
class MeasurementId(ir.Statement):
    """Return the integer ID of each measurement result (marked `ir.Pure`)."""

    traits = frozenset([lowering.FromPythonCall(), ir.Pure()])
    # The measurement results whose IDs are requested.
    measurements: ir.SSAValue = info.argument(
        ilist.IListType[MeasurementResultType, Len]
    )
    # One integer ID per input measurement result.
    result: ir.ResultValue = info.result(ilist.IListType[types.Int, Len])
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@statement(dialect=dialect)
class Reset(ir.Statement):
    """Reset a list of qubits of any length; produces no result."""

    traits = frozenset([lowering.FromPythonCall()])
    # The qubits to reset.
    qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# TODO: investigate why this is needed to get type inference to be correct.
@dialect.register(key="typeinfer")
class __TypeInfer(interp.MethodTable):
    """Custom type-inference rule for `Measure`."""

    @interp.impl(Measure)
    def measure_list(self, _interp, frame: interp.AbstractFrame, stmt: Measure):
        # Carry the length parameter of the input list over to the result list
        # when the input type is generic; otherwise the length is unknown.
        input_type = frame.get(stmt.qubits)
        length = (
            input_type.vars[1] if isinstance(input_type, types.Generic) else types.Any
        )
        return (ilist.IListType[MeasurementResultType, length],)
|
|
@@ -1 +1,7 @@
|
|
|
1
|
+
from .callgraph import (
|
|
2
|
+
CallGraphPass as CallGraphPass,
|
|
3
|
+
ReplaceMethods as ReplaceMethods,
|
|
4
|
+
UpdateDialectsOnCallGraph as UpdateDialectsOnCallGraph,
|
|
5
|
+
)
|
|
6
|
+
from .aggressive_unroll import AggressiveUnroll as AggressiveUnroll
|
|
1
7
|
from .canonicalize_ilist import CanonicalizeIList as CanonicalizeIList
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
from typing import Callable
|
|
2
|
+
from dataclasses import field, dataclass
|
|
3
|
+
|
|
4
|
+
from kirin import ir
|
|
5
|
+
from kirin.passes import Pass, HintConst, TypeInfer
|
|
6
|
+
from kirin.rewrite import (
|
|
7
|
+
Walk,
|
|
8
|
+
Chain,
|
|
9
|
+
Inline,
|
|
10
|
+
Fixpoint,
|
|
11
|
+
Call2Invoke,
|
|
12
|
+
ConstantFold,
|
|
13
|
+
CFGCompactify,
|
|
14
|
+
InlineGetItem,
|
|
15
|
+
InlineGetField,
|
|
16
|
+
DeadCodeElimination,
|
|
17
|
+
CommonSubexpressionElimination,
|
|
18
|
+
)
|
|
19
|
+
from kirin.dialects import scf, ilist
|
|
20
|
+
from kirin.ir.method import Method
|
|
21
|
+
from kirin.rewrite.abc import RewriteResult
|
|
22
|
+
from kirin.passes.aggressive import UnrollScf
|
|
23
|
+
|
|
24
|
+
from .canonicalize_ilist import CanonicalizeIList
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class Fold(Pass):
    """Run `HintConst`, then apply a chain of folding rewrites to a fixpoint."""

    hint_const: HintConst = field(init=False)

    def __post_init__(self):
        self.hint_const = HintConst(self.dialects, no_raise=self.no_raise)

    def unsafe_run(self, mt: Method) -> RewriteResult:
        # Constant hints are computed first, presumably so ConstantFold and the
        # ilist rewrites below can see them.
        outcome = self.hint_const.unsafe_run(mt).join(RewriteResult())
        folding_rules = Chain(
            ConstantFold(),
            Call2Invoke(),
            InlineGetField(),
            InlineGetItem(),
            ilist.rewrite.InlineGetItem(),
            ilist.rewrite.FlattenAdd(),
            ilist.rewrite.HintLen(),
            DeadCodeElimination(),
        )
        return Fixpoint(Walk(folding_rules)).rewrite(mt.code).join(outcome)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
class AggressiveUnroll(Pass):
    """A pass to unroll structured control flow and clean up the result.

    `unsafe_run` applies, in order: `Fold`, `UnrollScf`, `TypeInfer` (result
    not joined), `ConstList2IList` + ilist `Unroll`, `Inline` (gated by
    `inline_heuristic`), `CFGCompactify`, `CanonicalizeIList` to a fixpoint,
    and finally common-subexpression and dead-code elimination.

    Attributes:
        additional_inline_heuristic: Extra predicate combined (logical AND) with
            the built-in rule in `inline_heuristic`; the default allows every
            statement.
    """

    additional_inline_heuristic: Callable[[ir.Statement], bool] = lambda node: True

    fold: Fold = field(init=False)
    typeinfer: TypeInfer = field(init=False)
    scf_unroll: UnrollScf = field(init=False)
    canonicalize_ilist: CanonicalizeIList = field(init=False)

    def __post_init__(self):
        self.fold = Fold(self.dialects, no_raise=self.no_raise)
        self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise)
        self.scf_unroll = UnrollScf(self.dialects, no_raise=self.no_raise)
        self.canonicalize_ilist = CanonicalizeIList(
            self.dialects, no_raise=self.no_raise
        )

    def unsafe_run(self, mt: Method) -> RewriteResult:
        result = RewriteResult()
        result = self.fold.unsafe_run(mt).join(result)
        result = self.scf_unroll.unsafe_run(mt).join(result)
        self.typeinfer.unsafe_run(
            mt
        )  # Do not join the result of typeinfer or fixpoint will waste time
        result = (
            Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
            .rewrite(mt.code)
            .join(result)
        )
        result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result)
        result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)
        result = self.canonicalize_ilist.fixpoint(mt).join(result)
        rule = Chain(
            CommonSubexpressionElimination(),
            DeadCodeElimination(),
        )
        result = Walk(rule).rewrite(mt.code).join(result)

        return result

    def inline_heuristic(self, node: ir.Statement) -> bool:
        """Decide whether the statement `node` should be inlined.

        Statements whose direct parent statement is an `scf.For` or an
        `scf.IfElse` are never inlined. Every other statement is inlined only
        if `additional_inline_heuristic(node)` also returns True.
        """
        # Never inline directly inside loops / if-else; elsewhere defer to the
        # user-supplied heuristic (which allows everything by default).
        return not isinstance(
            node.parent_stmt, (scf.For, scf.IfElse)
        ) and self.additional_inline_heuristic(node)
|