tridec 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tridec/__init__.py +42 -0
- tridec/adapters/__init__.py +258 -0
- tridec/api.py +337 -0
- tridec/backends/__init__.py +11 -0
- tridec/backends/bp_numpy.py +221 -0
- tridec/backends/bp_torch.py +358 -0
- tridec/backends/bp_triton.py +480 -0
- tridec/backends/relay_triton.py +549 -0
- tridec/dem.py +49 -0
- tridec/tanner.py +48 -0
- tridec/validation/__init__.py +40 -0
- tridec/validation/analysis.py +239 -0
- tridec/validation/harness.py +231 -0
- tridec/validation/stats.py +63 -0
- tridec-0.1.0a1.dist-info/METADATA +121 -0
- tridec-0.1.0a1.dist-info/RECORD +19 -0
- tridec-0.1.0a1.dist-info/WHEEL +5 -0
- tridec-0.1.0a1.dist-info/licenses/LICENSE +202 -0
- tridec-0.1.0a1.dist-info/top_level.txt +1 -0
tridec/__init__.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""tridec: vendor-portable GPU decoders for quantum LDPC codes.
|
|
2
|
+
|
|
3
|
+
Triton min-sum BP and Relay-BP decoders that consume any stim
|
|
4
|
+
DetectorErrorModel or raw parity-check matrices, with CPU reference
|
|
5
|
+
implementations, validated against the standard CPU references (ldpc,
|
|
6
|
+
relay-bp), running on NVIDIA (CUDA) and AMD (ROCm) GPUs.
|
|
7
|
+
|
|
8
|
+
Quickstart::
|
|
9
|
+
|
|
10
|
+
import stim, tridec
|
|
11
|
+
|
|
12
|
+
circuit = stim.Circuit.from_file("memory.stim")
|
|
13
|
+
dem = circuit.detector_error_model(decompose_errors=False)
|
|
14
|
+
decoder = tridec.from_dem(dem, backend="auto")
|
|
15
|
+
|
|
16
|
+
dets, obs = circuit.compile_detector_sampler(seed=0).sample(
|
|
17
|
+
10_000, separate_observables=True)
|
|
18
|
+
pred = decoder.decode_batch(dets) # (shots, n_obs) bool
|
|
19
|
+
ler = (pred != obs).any(axis=1).mean()
|
|
20
|
+
"""
|
|
21
|
+
from .api import (
|
|
22
|
+
BpDecoder,
|
|
23
|
+
RelayBpDecoder,
|
|
24
|
+
available_backends,
|
|
25
|
+
from_dem,
|
|
26
|
+
from_matrices,
|
|
27
|
+
resolve_backend,
|
|
28
|
+
)
|
|
29
|
+
from .dem import extract
|
|
30
|
+
|
|
31
|
+
__version__ = "0.1.0.dev0"
|
|
32
|
+
|
|
33
|
+
__all__ = [
|
|
34
|
+
"BpDecoder",
|
|
35
|
+
"RelayBpDecoder",
|
|
36
|
+
"available_backends",
|
|
37
|
+
"extract",
|
|
38
|
+
"from_dem",
|
|
39
|
+
"from_matrices",
|
|
40
|
+
"resolve_backend",
|
|
41
|
+
"__version__",
|
|
42
|
+
]
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
"""Optional CPU reference-decoder adapters on a SHARED DEM (import-guarded).
|
|
2
|
+
|
|
3
|
+
These wrap the standard CPU reference implementations — the `ldpc` package's
|
|
4
|
+
BP / BP-OSD / BP-LSD and IBM's `relay-bp` Rust decoder — behind the same
|
|
5
|
+
``decode_batch(dets) -> predicted_observables`` surface as the native
|
|
6
|
+
backends, so a matched harness (``tridec.validation.run_matched``) can
|
|
7
|
+
decode the SAME shots with every decoder (apples-to-apples LER). They are the
|
|
8
|
+
validation targets the GPU kernels are held against.
|
|
9
|
+
|
|
10
|
+
Install with the ``decoders`` extra: ``pip install tridec[decoders]``.
|
|
11
|
+
The module imports without either package; each factory raises (or the
|
|
12
|
+
``*_available()`` probes return False) when its dependency is missing.
|
|
13
|
+
|
|
14
|
+
Interface (every adapter):
|
|
15
|
+
* ``.name`` -- str identifier (e.g. ``"BPOSD-10"``),
|
|
16
|
+
* ``.config`` -- dict of pinned hyperparameters (provenance),
|
|
17
|
+
* ``.dem`` -- the shared ``stim.DetectorErrorModel`` it was built from,
|
|
18
|
+
* ``.tie_break`` -- declared deterministic tie-break (gate G2),
|
|
19
|
+
* ``.decode_batch(dets: bool[shots, n_det]) -> bool[shots, n_obs]``.
|
|
20
|
+
|
|
21
|
+
For an ldpc decoder, each shot's detector syndrome is decoded to an error
|
|
22
|
+
estimate ``e_hat`` (length n_err); predicted observables = ``(Lo @ e_hat) % 2``.
|
|
23
|
+
ldpc 2.4.x exposes only single-shot ``decoder.decode(syndrome)`` (no batched
|
|
24
|
+
entry point), so ldpc adapters loop over shots.
|
|
25
|
+
"""
|
|
26
|
+
import numpy as np
|
|
27
|
+
|
|
28
|
+
from ..dem import extract
|
|
29
|
+
|
|
30
|
+
# Min-sum BP hyperparameters pinned for every BP-family adapter below; they
# are the provenance constants the validation grid committed to.
_BP_METHOD = "minimum_sum"  # ldpc name for min-sum BP (the kernel target)
_BP_SCHEDULE = "parallel"  # ldpc message-update schedule name
_BP_MAX_ITER = 30  # BP iteration cap
_BP_MS_SCALING = 0.625  # standard normalized-min-sum scaling factor
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def ldpc_available():
    """Report whether the optional `ldpc` package can be imported.

    Any exception raised while importing counts as "not available".
    """
    try:
        import ldpc  # noqa: F401
    except Exception:
        importable = False
    else:
        importable = True
    return importable
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def relay_bp_available():
    """Report whether relay-bp with its stim integration can be imported.

    Both ``relay_bp`` and ``relay_bp.stim.CheckMatrices`` must import. Any
    exception raised during either import counts as "not available".
    """
    try:
        import relay_bp  # noqa: F401
        from relay_bp.stim import CheckMatrices  # noqa: F401
    except Exception:
        found = False
    else:
        found = True
    return found
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class _LdpcAdapter:
    """Base for ldpc-family adapters built on the shared DEM.

    Takes the observable map ``Lo`` from the shared DEM. Each shot's detector
    syndrome is decoded to an error estimate with the wrapped ldpc decoder,
    and that estimate is mapped to predicted observables via
    ``(Lo @ e_hat) % 2``.
    """

    def __init__(self, dem, name, config, decoder, tie_break):
        """
        Args:
            dem: the shared ``stim.DetectorErrorModel``, kept as ``.dem``
                for provenance.
            name: identifier reported as ``.name``.
            config: mapping of pinned hyperparameters; copied to ``.config``.
            decoder: ldpc decoder object exposing ``decode(syndrome)``.
            tie_break: declared deterministic tie-break label (gate G2).
        """
        self.dem = dem
        self.name = name
        self.config = dict(config)
        # Declared deterministic tie-break (gate G2). No silent default: the
        # matched harness asserts this is in APPROVED_TIE_BREAKS.
        self.tie_break = tie_break
        self._decoder = decoder
        ex = extract(dem)
        # Lo: (n_obs x n_err) GF2 map from error mechanisms to observables.
        self._Lo = ex["Lo"].toarray().astype(np.uint8)
        self._n_obs = ex["n_obs"]
        self._n_err = ex["n_err"]
        self._n_det = ex["n_det"]

    def decode_batch(self, dets):
        """Decode detector events to predicted observables.

        Args:
            dets: bool-like array of shape (shots, n_det). A 1-D vector of
                length n_det is accepted as a single shot, matching the
                native ``tridec`` decoders.

        Returns:
            bool ndarray of shape (shots, n_obs).

        Raises:
            ValueError: if ``dets`` is not 1-D/2-D, or its detector count
                differs from the DEM's ``n_det``.
        """
        dets = np.asarray(dets, dtype=bool)
        if dets.ndim == 1:
            # Single shot: promote to a batch of one.
            dets = dets[None, :]
        if dets.ndim != 2:
            raise ValueError(
                "expected detector events of shape (shots, n_det) or "
                f"(n_det,); got shape {dets.shape}")
        if dets.shape[1] != self._n_det:
            raise ValueError(
                f"detector count mismatch: got {dets.shape[1]} columns, "
                f"DEM has n_det={self._n_det}")

        out = np.zeros((dets.shape[0], self._n_obs), dtype=bool)
        Lo = self._Lo
        decode = self._decoder.decode  # hoisted out of the per-shot loop
        for i, syndrome in enumerate(dets.astype(np.uint8)):
            e_hat = np.asarray(decode(syndrome), dtype=np.uint8)
            # The uint8 matmul may wrap mod 256. 256 is even, so "& 1" still
            # yields the GF2 product (Lo @ e_hat) % 2.
            out[i] = ((Lo @ e_hat) & 1).astype(bool)
        return out
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _priors(dem):
    """Return the shared DEM's per-mechanism priors as a Python list.

    Values are clipped into [1e-6, 1 - 1e-6] so ldpc never sees a prior of
    exactly 0 or 1.
    """
    eps = 1e-6
    clipped = np.clip(extract(dem)["priors"], eps, 1 - eps)
    return list(clipped)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def make_bp(dem):
    """Build the pure min-sum BP reference (ldpc.BpDecoder, no post-processing).

    Args:
        dem: the shared ``stim.DetectorErrorModel``.

    Returns:
        An ``_LdpcAdapter`` named ``"BP"`` whose ``.config`` records the
        pinned BP hyperparameters.

    Raises:
        ImportError: if the optional ``ldpc`` package is missing.
    """
    from ldpc import BpDecoder

    H = extract(dem)["H"]
    # One keyword set, both recorded in .config and handed to ldpc.
    bp_kwargs = dict(bp_method=_BP_METHOD, ms_scaling_factor=_BP_MS_SCALING,
                     max_iter=_BP_MAX_ITER, schedule=_BP_SCHEDULE)
    cfg = dict(decoder="BpDecoder", **bp_kwargs)
    dec = BpDecoder(H, error_channel=_priors(dem), **bp_kwargs)
    return _LdpcAdapter(dem, "BP", cfg, dec, "min_sum_parallel_hard_decision")
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def make_bposd0(dem):
    """Build BP-OSD order 0 (``osd_0``), the cheapest OSD post-processing.

    Args:
        dem: the shared ``stim.DetectorErrorModel``.

    Returns:
        An ``_LdpcAdapter`` named ``"BPOSD-0"``.

    Raises:
        ImportError: if the optional ``ldpc`` package is missing.
    """
    from ldpc import BpOsdDecoder

    H = extract(dem)["H"]
    bp_kwargs = dict(bp_method=_BP_METHOD, ms_scaling_factor=_BP_MS_SCALING,
                     max_iter=_BP_MAX_ITER, schedule=_BP_SCHEDULE)
    osd_kwargs = dict(osd_method="osd_0", osd_order=0)
    cfg = dict(decoder="BpOsdDecoder", **bp_kwargs, **osd_kwargs)
    dec = BpOsdDecoder(H, error_channel=_priors(dem), **bp_kwargs,
                       **osd_kwargs)
    return _LdpcAdapter(dem, "BPOSD-0", cfg, dec, "osd0_reliability_order")
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def make_bposd10(dem):
    """Build BP-OSD order 10 with combination sweep (``osd_cs``).

    This is the strong classical bar.

    Args:
        dem: the shared ``stim.DetectorErrorModel``.

    Returns:
        An ``_LdpcAdapter`` named ``"BPOSD-10"``.

    Raises:
        ImportError: if the optional ``ldpc`` package is missing.
    """
    from ldpc import BpOsdDecoder

    H = extract(dem)["H"]
    bp_kwargs = dict(bp_method=_BP_METHOD, ms_scaling_factor=_BP_MS_SCALING,
                     max_iter=_BP_MAX_ITER, schedule=_BP_SCHEDULE)
    osd_kwargs = dict(osd_method="osd_cs", osd_order=10)
    cfg = dict(decoder="BpOsdDecoder", **bp_kwargs, **osd_kwargs)
    dec = BpOsdDecoder(H, error_channel=_priors(dem), **bp_kwargs,
                       **osd_kwargs)
    return _LdpcAdapter(dem, "BPOSD-10", cfg, dec, "osd_cs_order10")
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def make_bplsd(dem):
    """Build BP + Localised-Statistics Decoding (``lsd_cs``, order 10).

    Args:
        dem: the shared ``stim.DetectorErrorModel``.

    Returns:
        An ``_LdpcAdapter`` named ``"BPLSD"``.

    Raises:
        ImportError: if the optional ``ldpc`` package is missing.
    """
    from ldpc import BpLsdDecoder

    H = extract(dem)["H"]
    lsd_order = 10
    bp_kwargs = dict(bp_method=_BP_METHOD, ms_scaling_factor=_BP_MS_SCALING,
                     max_iter=_BP_MAX_ITER, schedule=_BP_SCHEDULE)
    lsd_kwargs = dict(lsd_method="lsd_cs", lsd_order=lsd_order)
    cfg = dict(decoder="BpLsdDecoder", **bp_kwargs, **lsd_kwargs)
    dec = BpLsdDecoder(H, error_channel=_priors(dem), **bp_kwargs,
                       **lsd_kwargs)
    return _LdpcAdapter(dem, "BPLSD", cfg, dec, "lsd_cs_order10")
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
# --------------------------------------------------------------------------- #
# Relay-BP (relay-bp[stim] >= 0.2.2) — IBM's Rust reference decoder.          #
# --------------------------------------------------------------------------- #
# Construct-from-DEM:
#   from relay_bp.stim import CheckMatrices
#   cm = CheckMatrices.from_dem(dem)   # -> .check_matrix (ndet x E csc),
#                                      #    .observables_matrix (nobs x E csc),
#                                      #    .error_priors (E,)
#   dec = relay_bp.RelayDecoderF64(cm.check_matrix, error_priors=cm.error_priors,
#         gamma0=, pre_iter=, num_sets=, set_max_iter=, gamma_dist_interval=,
#         stop_nconv=, stopping_criterion='nconv')   # disjoint-relay ensemble
#   runner = relay_bp.ObservableDecoderRunner(dec, cm.observables_matrix,
#                                             include_decode_result=False)
# Decode:
#   runner.decode_observables_batch(syndromes uint8 [shots, n_det])
#     -> predicted observables uint8 [shots, n_obs]
# This is the path relay_bp.stim.SinterDecoder_RelayBP uses internally, minus
# sinter's bit-packing — the runner is driven directly for a clean decode_batch.
#
# Pinned relay-bp hyperparameters; RelayBPAdapter(dem, **params) overlays
# caller overrides on a copy of this mapping.
_RELAY_BP_DEFAULTS = {
    "gamma0": 0.1,
    "pre_iter": 80,
    "num_sets": 60,
    "set_max_iter": 60,
    "gamma_dist_interval": (-0.24, 0.66),
    "stop_nconv": 5,
    "stopping_criterion": "nconv",
}
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class RelayBPAdapter:
    """Relay-BP adapter (in-process) around IBM's ``relay-bp`` Rust decoder.

    Builds the relay-BP decoder from the SAME shared DEM via
    ``relay_bp.stim.CheckMatrices.from_dem`` and decodes a batch of syndromes
    straight to observables. G1 holds trivially: ``.dem is dem``.
    """

    def __init__(self, dem, **params):
        """Build the Relay-BP decoder and observable runner for ``dem``.

        Args:
            dem: the shared ``stim.DetectorErrorModel``.
            **params: overrides for ``_RELAY_BP_DEFAULTS``. They are forwarded
                to ``relay_bp.RelayDecoderF64`` and recorded in ``.config``.

        Raises:
            ImportError: if relay-bp[stim] is not installed.
        """
        import importlib.metadata as _md

        import relay_bp
        from relay_bp.stim import CheckMatrices

        self.dem = dem
        self.name = "RelayBP"
        try:
            ver = _md.version("relay-bp")
        except Exception:  # pragma: no cover - metadata present once installed
            ver = "unknown"
        cfg = dict(_RELAY_BP_DEFAULTS)
        cfg.update(params)
        self.config = dict(decoder="RelayBP", relay_bp_version=ver, **cfg)
        # Deterministic relay schedule (fixed gamma distribution + nconv stop).
        self.tie_break = "relay_bp_nconv_disjoint_ensemble"

        cm = CheckMatrices.from_dem(dem)
        self._n_obs = cm.observables_matrix.shape[0]
        decoder = relay_bp.RelayDecoderF64(
            cm.check_matrix,
            error_priors=cm.error_priors,
            **cfg,
        )
        self._runner = relay_bp.ObservableDecoderRunner(
            decoder, cm.observables_matrix, include_decode_result=False)

    def decode_batch(self, dets):
        """Decode detector events to predicted observables.

        Args:
            dets: bool-like array of shape (shots, n_det). A 1-D vector is
                treated as a single shot, matching the native decoders.

        Returns:
            bool ndarray of shape (shots, n_obs).
        """
        dets = np.asarray(dets, dtype=bool)
        if dets.ndim == 1:
            # Single shot: the runner expects a (shots, n_det) batch.
            dets = dets[None, :]
        shots = dets.shape[0]
        pred = np.asarray(
            self._runner.decode_observables_batch(dets.astype(np.uint8)))
        pred = (pred % 2).astype(bool)
        if pred.ndim == 1:
            # A flat result (e.g. a single observable) is restored to the
            # known (shots, n_obs) layout rather than assumed to be one column.
            pred = pred.reshape(shots, self._n_obs)
        return pred
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def make_relay_bp(dem, **params):
    """Construct the Relay-BP reference adapter for ``dem``.

    ``params`` override the pinned relay-bp defaults (see ``RelayBPAdapter``).
    """
    adapter = RelayBPAdapter(dem, **params)
    return adapter
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
# Registry of ldpc-family adapters: name -> factory(dem).
_FACTORIES = dict([
    ("BPOSD-0", make_bposd0),
    ("BPOSD-10", make_bposd10),
    ("BPLSD", make_bplsd),
    ("BP", make_bp),
])

# Adapters built by build_decoders() when ``which`` is not given; the same
# names as the registry, in the same order.
DEFAULT_DECODERS = (
    "BPOSD-0",
    "BPOSD-10",
    "BPLSD",
    "BP",
)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def build_decoders(dem, which=DEFAULT_DECODERS, include_relay=False):
    """Construct all requested adapters from ONE shared DEM object.

    Every returned adapter has ``.dem is dem`` (provenance for the matched
    harness).

    Args:
        dem: the shared ``stim.DetectorErrorModel``.
        which: registry names selecting and ordering the ldpc-family
            adapters. A single name string is accepted as a one-item
            selection.
        include_relay: opt in to Relay-BP. It is appended ONLY when its
            package is importable, so the core set always builds.

    Returns:
        List of adapters in ``which`` order, with Relay-BP last when added.

    Raises:
        KeyError: if any name in ``which`` is not registered. This is raised
            before any decoder is constructed.
    """
    # A bare string would otherwise be iterated character by character.
    names = (which,) if isinstance(which, str) else tuple(which)
    # Validate everything first so a typo late in ``which`` fails fast,
    # before paying for the earlier (possibly expensive) constructions.
    for name in names:
        if name not in _FACTORIES:
            raise KeyError(f"unknown decoder {name!r}; known: {sorted(_FACTORIES)}")

    decoders = [_FACTORIES[name](dem) for name in names]

    if include_relay and relay_bp_available():
        decoders.append(make_relay_bp(dem))

    return decoders
|
tridec/api.py
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
"""Public decoder API: ``from_dem`` / ``from_matrices`` + backend dispatch.
|
|
2
|
+
|
|
3
|
+
Backends
|
|
4
|
+
--------
|
|
5
|
+
* ``"numpy"`` — pure-numpy normalized-min-sum BP (always available; the
|
|
6
|
+
CPU reference the GPU paths are validated against).
|
|
7
|
+
* ``"torch"`` — batched torch edge-list BP; bit-identical to numpy at fp64
|
|
8
|
+
for one iteration; runs on CPU and CUDA/ROCm devices.
|
|
9
|
+
* ``"triton"`` — the Triton kernels (min-sum BP and Relay-BP); requires
|
|
10
|
+
triton + a CUDA or ROCm GPU. fp32 messages on the BP path
|
|
11
|
+
(>=99.5% hard-decision agreement vs the fp64 references,
|
|
12
|
+
LER-validated on H200 and MI300X — see bench/receipts/).
|
|
13
|
+
* ``"auto"`` — triton if importable AND a GPU is visible, else torch if
|
|
14
|
+
importable, else numpy.
|
|
15
|
+
|
|
16
|
+
Algorithms per backend (honest availability matrix):
|
|
17
|
+
|
|
18
|
+
=========== ======= ======= ========
|
|
19
|
+
algorithm numpy torch triton
|
|
20
|
+
=========== ======= ======= ========
|
|
21
|
+
bp (min-sum) yes yes yes
|
|
22
|
+
relay no no yes
|
|
23
|
+
=========== ======= ======= ========
|
|
24
|
+
|
|
25
|
+
Relay-BP has no in-package CPU implementation; its CPU reference is IBM's
|
|
26
|
+
``relay-bp`` Rust decoder, available through ``tridec.adapters`` (the
|
|
27
|
+
``decoders`` extra) and used as the validation oracle for the Triton path.
|
|
28
|
+
"""
|
|
29
|
+
import numpy as np
|
|
30
|
+
import scipy.sparse as sp
|
|
31
|
+
|
|
32
|
+
from .dem import extract
|
|
33
|
+
|
|
34
|
+
# Backend names accepted by resolve_backend (and hence every factory).
_BACKENDS = ("auto", "numpy", "torch", "triton")

# Validated defaults (the configuration the carried receipts were measured at).
_BP_DEFAULTS = {"max_iter": 30, "ms_scaling_factor": 0.625}
_RELAY_DEFAULTS = {
    "gamma0": 0.1,
    "pre_iter": 80,
    "num_sets": 60,
    "set_max_iter": 60,
    "gamma_dist_interval": (-0.24, 0.66),
    "stop_nconv": 5,
    "stopping_criterion": "nconv",
}
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# --------------------------------------------------------------------------- #
# Backend availability / resolution.                                          #
# --------------------------------------------------------------------------- #
def _torch_available():
    """Return True when ``torch`` can be imported in this environment.

    Any exception raised by the import counts as "not available".
    """
    try:
        import torch  # noqa: F401
    except Exception:
        return False
    return True
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _triton_gpu_available():
    """Return True iff triton imports AND torch reports a visible GPU.

    GPU visibility is ``torch.cuda.is_available()``. Any exception along the
    way yields False.
    """
    try:
        import triton  # noqa: F401
        import torch
        gpu_visible = torch.cuda.is_available()
    except Exception:
        return False
    return bool(gpu_visible)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def available_backends():
    """List the backends usable in THIS environment, best first.

    ``"numpy"`` is always present and always last. ``"triton"`` and
    ``"torch"`` are included only when their availability probes succeed.
    """
    probes = (
        ("triton", _triton_gpu_available),
        ("torch", _torch_available),
    )
    usable = [name for name, probe in probes if probe()]
    usable.append("numpy")
    return usable
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def resolve_backend(backend="auto"):
    """Map a backend request to a concrete backend name.

    ``"auto"`` picks the first available of triton (importable AND a CUDA or
    ROCm GPU visible), torch (importable), then numpy.

    Raises:
        ValueError: if ``backend`` is not one of ``_BACKENDS``.
        RuntimeError: if an explicitly requested backend is unavailable; the
            message says why.
    """
    if backend not in _BACKENDS:
        raise ValueError(
            f"unknown backend {backend!r}; expected one of {_BACKENDS}")

    if backend == "auto":
        for name, probe in (("triton", _triton_gpu_available),
                            ("torch", _torch_available)):
            if probe():
                return name
        return "numpy"

    if backend == "torch" and not _torch_available():
        raise RuntimeError(
            "torch backend requested but torch is not importable; "
            "install with the [torch] extra: pip install tridec[torch]")
    if backend == "triton" and not _triton_gpu_available():
        raise RuntimeError(
            "triton backend requested but triton + a CUDA/ROCm GPU are not "
            "available (triton importable: requires the [gpu] extra; GPU "
            "visible: torch.cuda.is_available() must be True)")
    return backend
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _default_device(backend, device):
    """Choose the torch device string for ``backend`` unless one is given.

    An explicit ``device`` always wins. Otherwise triton gets ``"cuda"``,
    torch gets ``"cuda"`` when torch reports a GPU (else ``"cpu"``), and
    every other backend gets ``"cpu"``.
    """
    if device is not None:
        return device
    if backend == "triton":
        return "cuda"
    if backend != "torch":
        return "cpu"
    try:
        import torch
        has_gpu = torch.cuda.is_available()
    except Exception:  # pragma: no cover
        return "cpu"
    return "cuda" if has_gpu else "cpu"
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _dense_uint8(M):
    """Return ``M`` as a dense uint8 array with entries reduced mod 2.

    ``None`` passes through unchanged. scipy sparse matrices are densified
    first; anything else goes straight to ``np.asarray``.
    """
    if M is None:
        return M
    dense = M.toarray() if sp.issparse(M) else M
    return np.asarray(dense, dtype=np.uint8) % 2
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# --------------------------------------------------------------------------- #
# Decoders.                                                                   #
# --------------------------------------------------------------------------- #
class BpDecoder:
    """Normalized min-sum BP over the numpy / torch / triton backends.

    Construct via ``tridec.from_dem`` / ``tridec.from_matrices``
    (or directly). ``decode_batch(dets)`` returns predicted observables
    (bool[shots, n_obs]) when an observable map is available (always the case
    via ``from_dem``), else hard error estimates (uint8[shots, n_bits]).

    Public attributes set by ``__init__``:

    * ``backend``, ``device``, ``dem``, ``max_iter``, ``ms_scaling_factor``;
    * ``n_obs`` (None without an observable map), ``n_bits``, ``n_checks``;
    * the harness provenance fields ``name``, ``tie_break`` and ``config``.
    """

    # Algorithm tag for this decoder family.
    algorithm = "bp"

    def __init__(self, H, priors, observables=None, backend="auto", device=None,
                 max_iter=30, ms_scaling_factor=0.625, block_s=256, dem=None):
        """Build the backend implementation for a parity-check matrix.

        Args:
            H: (n_checks x n_bits) GF2 check matrix, passed to the backend.
            priors: per-bit error probabilities, passed to the backend.
            observables: optional (n_obs x n_bits) GF2 observable map, dense
                or scipy sparse. It is densified to uint8 and reduced mod 2.
            backend: "auto" | "numpy" | "torch" | "triton", resolved with
                ``resolve_backend``.
            device: torch device string. When None, ``_default_device``
                picks "cuda" for triton, cuda-if-available for torch, and
                "cpu" otherwise.
            max_iter: BP iteration cap.
            ms_scaling_factor: normalized-min-sum scaling factor.
            block_s: kernel block size; only the triton backend uses it.
            dem: optional source DEM, stored as ``.dem`` for provenance.

        Raises:
            ValueError, RuntimeError: from ``resolve_backend``, for an unknown
                or unavailable backend.
        """
        self.backend = resolve_backend(backend)
        self.device = _default_device(self.backend, device)
        self.dem = dem
        self.max_iter = int(max_iter)
        self.ms_scaling_factor = float(ms_scaling_factor)

        # Backend modules are imported lazily so a missing optional extra
        # never breaks the core package.
        if self.backend == "numpy":
            from .backends.bp_numpy import BpBaseline
            self._impl = BpBaseline(H, priors, max_iter=max_iter,
                                    ms_scaling_factor=ms_scaling_factor)
        elif self.backend == "torch":
            from .backends.bp_torch import BpGpu
            self._impl = BpGpu(H, priors, max_iter=max_iter,
                               ms_scaling_factor=ms_scaling_factor)
        else:  # triton
            from .backends.bp_triton import BpTriton
            self._impl = BpTriton(H, priors, max_iter=max_iter,
                                  ms_scaling_factor=ms_scaling_factor,
                                  block_s=block_s)

        Lo = _dense_uint8(observables)
        self._Lo = Lo
        if Lo is not None:
            # Attach the observable map to the backend impl so its validated
            # decode_batch path (e_hat -> (Lo @ e_hat) % 2) applies unchanged.
            self._impl._Lo = Lo
            self._impl._n_obs = int(Lo.shape[0])
        self.n_obs = None if Lo is None else int(Lo.shape[0])
        self.n_bits = self._impl.n_bits
        self.n_checks = self._impl.n_checks

        # Provenance fields consumed by the matched validation harness.
        self.name = f"portable-bp[{self.backend}]"
        self.tie_break = "min_sum_parallel_hard_decision"
        self.config = dict(
            decoder="tridec.BpDecoder", backend=self.backend,
            bp_method="minimum_sum", ms_scaling_factor=self.ms_scaling_factor,
            max_iter=self.max_iter, schedule="parallel")

    @classmethod
    def from_dem(cls, dem, backend="auto", device=None, **opts):
        """Build from a ``stim.DetectorErrorModel``.

        H, priors and the observable map come from ``extract(dem)``.
        ``opts`` override the validated ``_BP_DEFAULTS``.
        """
        kw = dict(_BP_DEFAULTS)
        kw.update(opts)
        ex = extract(dem)
        obj = cls(ex["H"], ex["priors"], observables=ex["Lo"], backend=backend,
                  device=device, dem=dem, **kw)
        return obj

    # -- decode surfaces ---------------------------------------------------- #
    def decode_batch(self, detection_events):
        """Decode a batch of detector-event vectors.

        A 1-D input is treated as a single shot.

        Returns predicted observables (bool[shots, n_obs]) when an observable
        map is present, else hard error estimates (uint8[shots, n_bits]).
        """
        dets = np.asarray(detection_events)
        if dets.ndim == 1:
            dets = dets[None, :]
        if self._Lo is not None:
            # The numpy backend's decode_batch takes no device argument.
            if self.backend == "numpy":
                return self._impl.decode_batch(dets.astype(bool))
            return self._impl.decode_batch(dets.astype(bool), device=self.device)
        # No observable map: return hard error estimates.
        syn = dets.astype(np.uint8)
        if self.backend == "numpy":
            # numpy backend: decode shot by shot.
            out = np.zeros((syn.shape[0], self.n_bits), dtype=np.uint8)
            for i in range(syn.shape[0]):
                out[i] = self._impl.decode(syn[i])
            return out
        # torch/triton: run the batched iterations, then hard-decide
        # negative posterior values as errors. NOTE(review): this assumes
        # run_iterations_batch returns a numpy array — confirm in the backend.
        post = self._impl.run_iterations_batch(syn, n_iter=self.max_iter,
                                               device=self.device)
        return (post < 0.0).astype(np.uint8)

    def decode(self, detection_events):
        """Single-shot convenience: 1-D in, 1-D out."""
        out = self.decode_batch(np.asarray(detection_events)[None, :])
        return out[0]
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class RelayBpDecoder:
    """Relay-BP (disordered-memory min-sum relay ensemble), Triton backend only.

    Defaults match the ``relay_bp`` Rust oracle configuration the kernels were
    LER-validated against (gamma0=0.1, pre_iter=80, num_sets=60,
    set_max_iter=60, gamma_dist_interval=(-0.24, 0.66), stop_nconv=5).
    """

    # Algorithm tag for this decoder family.
    algorithm = "relay"

    def __init__(self, H, priors, observables=None, backend="auto", device=None,
                 block_s=256, dtype="float64", dem=None, **relay_params):
        """Build the Triton Relay-BP implementation.

        Args:
            H: (n_checks x n_bits) GF2 check matrix.
            priors: per-bit error probabilities.
            observables: optional (n_obs x n_bits) GF2 observable map, dense
                or scipy sparse.
            backend: backend request. It must resolve to "triton".
            device: torch device string; defaults to "cuda".
            block_s: kernel block size.
            dtype: message precision string passed to the kernel.
            dem: optional source DEM, stored as ``.dem`` for provenance.
            **relay_params: overrides for ``_RELAY_DEFAULTS``.

        Raises:
            NotImplementedError: if the backend resolves to anything other
                than triton (there is no in-package CPU Relay-BP).
            ValueError, RuntimeError: from ``resolve_backend``.
        """
        resolved = resolve_backend(backend)
        if resolved != "triton":
            raise NotImplementedError(
                f"Relay-BP is implemented on the triton backend only (resolved "
                f"backend: {resolved!r}). There is no in-package CPU Relay-BP; "
                f"for a CPU reference use the relay-bp adapter "
                f"(tridec.adapters.make_relay_bp, [decoders] extra).")
        self.backend = "triton"
        self.device = _default_device("triton", device)
        self.dem = dem

        # Relay defaults are merged here (not in from_dem), so direct
        # construction and from_dem share the same validated configuration.
        cfg = dict(_RELAY_DEFAULTS)
        cfg.update(relay_params)
        from .backends.relay_triton import RelayBpTriton
        self._impl = RelayBpTriton(H, priors, block_s=block_s, dtype=dtype,
                                   **cfg)

        Lo = _dense_uint8(observables)
        self._Lo = Lo
        if Lo is not None:
            # Attach the observable map so the impl's decode_batch can map
            # error estimates to observables.
            self._impl._Lo = Lo
            self._impl._n_obs = int(Lo.shape[0])
        self.n_obs = None if Lo is None else int(Lo.shape[0])
        self.n_bits = self._impl.n_bits
        self.n_checks = self._impl.n_checks

        # Provenance fields consumed by the matched validation harness.
        self.name = "portable-relay-bp[triton]"
        self.tie_break = "relay_bp_nconv_disjoint_ensemble"
        self.config = dict(
            decoder="tridec.RelayBpDecoder", backend="triton",
            dtype=dtype, **cfg)

    @classmethod
    def from_dem(cls, dem, backend="auto", device=None, **opts):
        """Build from a ``stim.DetectorErrorModel``.

        H, priors and the observable map come from ``extract(dem)``.
        ``opts`` are forwarded as relay hyperparameter overrides.
        """
        ex = extract(dem)
        return cls(ex["H"], ex["priors"], observables=ex["Lo"], backend=backend,
                   device=device, dem=dem, **opts)

    def decode_batch(self, detection_events):
        """Decode a batch of detector-event vectors (1-D = a single shot).

        Returns the impl's predicted observables when an observable map is
        present. Otherwise returns uint8 error estimates, one row per shot.
        """
        dets = np.asarray(detection_events)
        if dets.ndim == 1:
            dets = dets[None, :]
        if self._Lo is not None:
            return self._impl.decode_batch(dets.astype(bool), device=self.device)
        # No observable map: return the lowest-weight valid error estimate.
        import torch
        dev = torch.device(self.device)
        syn_t = torch.as_tensor(dets.astype(bool), device=dev)
        # (N, S): presumably N = n_bits, S = shots — TODO confirm against
        # relay_triton. The transpose below makes it shot-major.
        best_eh = self._impl._relay_posteriors(syn_t, dev)  # (N, S)
        return best_eh.t().cpu().numpy().astype(np.uint8)

    def decode(self, detection_events):
        """Single-shot convenience: 1-D in, 1-D out."""
        out = self.decode_batch(np.asarray(detection_events)[None, :])
        return out[0]
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
# --------------------------------------------------------------------------- #
# Factories.                                                                  #
# --------------------------------------------------------------------------- #
def from_dem(dem, backend="auto", algorithm="bp", device=None, **opts):
    """Create a decoder for a ``stim.DetectorErrorModel``.

    Args:
        dem: the DEM. Build it with ``decompose_errors=False``; the decoders
            consume the raw hyperedge mechanism set.
        backend: "auto" | "numpy" | "torch" | "triton" (see module docstring).
        algorithm: "bp" (min-sum BP, every backend) or "relay" (Relay-BP,
            triton backend only).
        device: optional torch device string for the torch/triton backends.
        **opts: decoder hyperparameters, e.g. max_iter and ms_scaling_factor
            for bp; gamma0, pre_iter, num_sets, ... for relay.

    Returns:
        A decoder exposing ``decode_batch(detection_events) ->
        predicted_observables`` and a single-shot ``decode``.

    Raises:
        ValueError: for an algorithm other than "bp" or "relay".
    """
    if algorithm == "bp":
        decoder_cls = BpDecoder
    elif algorithm == "relay":
        decoder_cls = RelayBpDecoder
    else:
        raise ValueError(f"unknown algorithm {algorithm!r}; expected 'bp' or 'relay'")
    return decoder_cls.from_dem(dem, backend=backend, device=device, **opts)
|
|
316
|
+
raise ValueError(f"unknown algorithm {algorithm!r}; expected 'bp' or 'relay'")
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def from_matrices(H, priors, observables=None, backend="auto", algorithm="bp",
                  device=None, **opts):
    """Create a decoder from a raw GF2 parity-check matrix and per-bit priors.

    Args:
        H: (n_checks x n_bits) GF2 check matrix, dense or scipy sparse.
        priors: per-bit error probabilities, length n_bits.
        observables: optional (n_obs x n_bits) GF2 observable map. With it,
            ``decode_batch`` returns predicted observables; without it, it
            returns hard error estimates.
        backend, algorithm, device, **opts: as in ``from_dem``.

    Raises:
        ValueError: for an algorithm other than "bp" or "relay".
    """
    if algorithm == "bp":
        decoder_cls = BpDecoder
    elif algorithm == "relay":
        decoder_cls = RelayBpDecoder
    else:
        raise ValueError(f"unknown algorithm {algorithm!r}; expected 'bp' or 'relay'")
    return decoder_cls(H, priors, observables=observables, backend=backend,
                       device=device, **opts)
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Decoder backends.
|
|
2
|
+
|
|
3
|
+
``bp_numpy`` is always importable (numpy/scipy only). ``bp_torch`` requires
|
|
4
|
+
torch; ``bp_triton`` / ``relay_triton`` additionally require triton and a GPU
|
|
5
|
+
to RUN (they import without one — the kernels compile only where triton
|
|
6
|
+
exists). The API layer (``tridec.api``) imports the optional backends
|
|
7
|
+
lazily, so a missing extra never breaks the core package.
|
|
8
|
+
"""
|
|
9
|
+
from .bp_numpy import BpBaseline
|
|
10
|
+
|
|
11
|
+
__all__ = ["BpBaseline"]
|