tridec 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tridec/__init__.py ADDED
@@ -0,0 +1,42 @@
1
+ """tridec: vendor-portable GPU decoders for quantum LDPC codes.
2
+
3
+ Triton min-sum BP and Relay-BP decoders that consume any stim
4
+ DetectorErrorModel or raw parity-check matrices, with CPU reference
5
+ implementations, validated against the standard CPU references (ldpc,
6
+ relay-bp), running on NVIDIA (CUDA) and AMD (ROCm) GPUs.
7
+
8
+ Quickstart::
9
+
10
+ import stim, tridec
11
+
12
+ circuit = stim.Circuit.from_file("memory.stim")
13
+ dem = circuit.detector_error_model(decompose_errors=False)
14
+ decoder = tridec.from_dem(dem, backend="auto")
15
+
16
+ dets, obs = circuit.compile_detector_sampler(seed=0).sample(
17
+ 10_000, separate_observables=True)
18
+ pred = decoder.decode_batch(dets) # (shots, n_obs) bool
19
+ ler = (pred != obs).any(axis=1).mean()
20
+ """
21
+ from .api import (
22
+ BpDecoder,
23
+ RelayBpDecoder,
24
+ available_backends,
25
+ from_dem,
26
+ from_matrices,
27
+ resolve_backend,
28
+ )
29
+ from .dem import extract
30
+
31
# Package version; keep in sync with the distribution metadata (the
# published wheel is tridec 0.1.0a1 -- the previous value "0.1.0.dev0" was
# stale and disagreed with what `pip show tridec` reports).
__version__ = "0.1.0a1"

# Public API re-exported from the submodules imported above.
__all__ = [
    "BpDecoder",
    "RelayBpDecoder",
    "available_backends",
    "extract",
    "from_dem",
    "from_matrices",
    "resolve_backend",
    "__version__",
]
@@ -0,0 +1,258 @@
1
+ """Optional CPU reference-decoder adapters on a SHARED DEM (import-guarded).
2
+
3
+ These wrap the standard CPU reference implementations — the `ldpc` package's
4
+ BP / BP-OSD / BP-LSD and IBM's `relay-bp` Rust decoder — behind the same
5
+ ``decode_batch(dets) -> predicted_observables`` surface as the native
6
+ backends, so a matched harness (``tridec.validation.run_matched``) can
7
+ decode the SAME shots with every decoder (apples-to-apples LER). They are the
8
+ validation targets the GPU kernels are held against.
9
+
10
+ Install with the ``decoders`` extra: ``pip install tridec[decoders]``.
11
+ The module imports without either package; each factory raises (or the
12
+ ``*_available()`` probes return False) when its dependency is missing.
13
+
14
+ Interface (every adapter):
15
+ * ``.name`` -- str identifier (e.g. ``"BPOSD-10"``),
16
+ * ``.config`` -- dict of pinned hyperparameters (provenance),
17
+ * ``.dem`` -- the shared ``stim.DetectorErrorModel`` it was built from,
18
+ * ``.tie_break`` -- declared deterministic tie-break (gate G2),
19
+ * ``.decode_batch(dets: bool[shots, n_det]) -> bool[shots, n_obs]``.
20
+
21
+ For an ldpc decoder, each shot's detector syndrome is decoded to an error
22
+ estimate ``e_hat`` (length n_err); predicted observables = ``(Lo @ e_hat) % 2``.
23
+ ldpc 2.4.x exposes only single-shot ``decoder.decode(syndrome)`` (no batched
24
+ entry point), so ldpc adapters loop over shots.
25
+ """
26
+ import numpy as np
27
+
28
+ from ..dem import extract
29
+
30
# Min-sum BP hyperparameters pinned once and shared by every BP-family
# adapter in this module -- the provenance constants the validation grid
# committed to. Each factory both passes them to ldpc and records them in
# the adapter's ``.config``.
_BP_MAX_ITER = 30              # BP iteration cap
_BP_MS_SCALING = 0.625         # normalized-min-sum scaling factor
_BP_METHOD = "minimum_sum"     # min-sum BP, the same rule the GPU kernels run
_BP_SCHEDULE = "parallel"      # ldpc's "parallel" message-update schedule
36
+
37
+
38
def ldpc_available():
    """Report whether the `ldpc` package can be imported.

    Any exception raised while importing (not only ImportError) is treated
    as "unavailable", so a broken install never propagates out of the probe.
    """
    try:
        import ldpc  # noqa: F401
    except Exception:
        importable = False
    else:
        importable = True
    return importable
45
+
46
+
47
def relay_bp_available():
    """Report whether relay-bp[stim] can be imported (import-guarded probe).

    Both the core ``relay_bp`` module and its ``relay_bp.stim.CheckMatrices``
    helper must import; any exception during either import counts as
    "unavailable".
    """
    try:
        import relay_bp  # noqa: F401
        from relay_bp.stim import CheckMatrices  # noqa: F401
    except Exception:
        importable = False
    else:
        importable = True
    return importable
55
+
56
+
57
class _LdpcAdapter:
    """Base for ldpc-family adapters: build H/Lo/priors from the shared DEM,
    decode each shot's syndrome to an error estimate, map to observables.

    Args:
        dem: the shared ``stim.DetectorErrorModel``; stored on ``.dem``.
        name: registry identifier, stored on ``.name``.
        config: pinned hyperparameters; a copy is stored on ``.config``.
        decoder: object exposing single-shot ``decode(syndrome)`` that returns
            an error estimate of length n_err.
        tie_break: declared deterministic tie-break string (gate G2).
    """

    def __init__(self, dem, name, config, decoder, tie_break):
        self.dem = dem
        self.name = name
        self.config = dict(config)
        # Declared deterministic tie-break (gate G2). No silent default: the
        # matched harness asserts this is in APPROVED_TIE_BREAKS.
        self.tie_break = tie_break
        self._decoder = decoder
        ex = extract(dem)
        # Lo: (n_obs x n_err) GF2 map from error mechanisms to observables.
        self._Lo = ex["Lo"].toarray().astype(np.uint8)
        self._n_obs = ex["n_obs"]
        self._n_err = ex["n_err"]
        self._n_det = ex["n_det"]

    def decode_batch(self, dets):
        """Decode detector syndromes to predicted observables.

        Args:
            dets: bool-like array of shape (shots, n_det). A 1-D array is
                treated as a single shot, matching the other decoders.

        Returns:
            bool array of shape (shots, n_obs).

        Raises:
            ValueError: if ``dets`` is not 1-D/2-D or its width is not n_det.
        """
        dets = np.asarray(dets, dtype=bool)
        if dets.ndim == 1:
            dets = dets[None, :]
        if dets.ndim != 2 or dets.shape[1] != self._n_det:
            raise ValueError(
                f"expected detection events of shape (shots, {self._n_det}); "
                f"got {dets.shape}")
        syn_u8 = dets.astype(np.uint8)

        # ldpc only offers single-shot decode(), so gather the per-shot error
        # estimates row by row ...
        e_hat = np.zeros((syn_u8.shape[0], self._n_err), dtype=np.uint8)
        for i, syndrome in enumerate(syn_u8):
            e_hat[i] = np.asarray(self._decoder.decode(syndrome), dtype=np.uint8)

        # ... then project all of them at once: pred = (e_hat @ Lo^T) mod 2.
        # uint8 accumulation wraps mod 256, which preserves parity.
        return ((e_hat @ self._Lo.T) & 1).astype(bool)
88
+
89
+
90
def _priors(dem):
    """Return the shared DEM's per-mechanism priors, clipped for ldpc stability.

    Every probability is bounded to ``[1e-6, 1 - 1e-6]`` so it stays strictly
    inside (0, 1); the result is a plain list, as passed to ``error_channel``.
    """
    floor, ceiling = 1e-6, 1 - 1e-6
    raw = extract(dem)["priors"]
    bounded = np.clip(raw, floor, ceiling)
    return list(bounded)
94
+
95
+
96
def make_bp(dem):
    """Pure min-sum BP (no post-processing) backed by ``ldpc.BpDecoder``.

    The same keyword set is handed to the ldpc constructor and recorded in
    the adapter's ``.config``, so provenance cannot drift from what runs.
    """
    from ldpc import BpDecoder

    bp_kwargs = dict(bp_method=_BP_METHOD, ms_scaling_factor=_BP_MS_SCALING,
                     max_iter=_BP_MAX_ITER, schedule=_BP_SCHEDULE)
    check_matrix = extract(dem)["H"]
    decoder = BpDecoder(check_matrix, error_channel=_priors(dem), **bp_kwargs)
    config = dict(decoder="BpDecoder", **bp_kwargs)
    return _LdpcAdapter(dem, "BP", config, decoder,
                        "min_sum_parallel_hard_decision")
108
+
109
+
110
def make_bposd0(dem):
    """BP-OSD order-0 (``osd_0``): the cheapest OSD post-processing.

    One keyword dict feeds both ``ldpc.BpOsdDecoder`` and ``.config``.
    """
    from ldpc import BpOsdDecoder

    osd_kwargs = dict(bp_method=_BP_METHOD, ms_scaling_factor=_BP_MS_SCALING,
                      max_iter=_BP_MAX_ITER, schedule=_BP_SCHEDULE,
                      osd_method="osd_0", osd_order=0)
    check_matrix = extract(dem)["H"]
    decoder = BpOsdDecoder(check_matrix, error_channel=_priors(dem),
                           **osd_kwargs)
    config = dict(decoder="BpOsdDecoder", **osd_kwargs)
    return _LdpcAdapter(dem, "BPOSD-0", config, decoder,
                        "osd0_reliability_order")
122
+
123
+
124
def make_bposd10(dem):
    """BP-OSD order-10 combination-sweep (``osd_cs``): the strong classical bar.

    One keyword dict feeds both ``ldpc.BpOsdDecoder`` and ``.config``.
    """
    from ldpc import BpOsdDecoder

    osd_kwargs = dict(bp_method=_BP_METHOD, ms_scaling_factor=_BP_MS_SCALING,
                      max_iter=_BP_MAX_ITER, schedule=_BP_SCHEDULE,
                      osd_method="osd_cs", osd_order=10)
    check_matrix = extract(dem)["H"]
    decoder = BpOsdDecoder(check_matrix, error_channel=_priors(dem),
                           **osd_kwargs)
    config = dict(decoder="BpOsdDecoder", **osd_kwargs)
    return _LdpcAdapter(dem, "BPOSD-10", config, decoder, "osd_cs_order10")
136
+
137
+
138
def make_bplsd(dem):
    """BP + Localised-Statistics Decoder (``lsd_cs``, order 10).

    One keyword dict feeds both ``ldpc.BpLsdDecoder`` and ``.config``.
    """
    from ldpc import BpLsdDecoder

    lsd_kwargs = dict(bp_method=_BP_METHOD, ms_scaling_factor=_BP_MS_SCALING,
                      max_iter=_BP_MAX_ITER, schedule=_BP_SCHEDULE,
                      lsd_method="lsd_cs", lsd_order=10)
    check_matrix = extract(dem)["H"]
    decoder = BpLsdDecoder(check_matrix, error_channel=_priors(dem),
                           **lsd_kwargs)
    config = dict(decoder="BpLsdDecoder", **lsd_kwargs)
    return _LdpcAdapter(dem, "BPLSD", config, decoder, "lsd_cs_order10")
152
+
153
+
154
+ # --------------------------------------------------------------------------- #
155
+ # Relay-BP (relay-bp[stim] >= 0.2.2) — IBM's Rust reference decoder. #
156
+ # --------------------------------------------------------------------------- #
157
+ # Construct-from-DEM:
158
+ # from relay_bp.stim import CheckMatrices
159
+ # cm = CheckMatrices.from_dem(dem) # -> .check_matrix (ndet x E csc),
160
+ # # .observables_matrix (nobs x E csc),
161
+ # # .error_priors (E,)
162
+ # dec = relay_bp.RelayDecoderF64(cm.check_matrix, error_priors=cm.error_priors,
163
+ # gamma0=, pre_iter=, num_sets=, set_max_iter=, gamma_dist_interval=,
164
+ # stop_nconv=, stopping_criterion='nconv') # disjoint-relay ensemble
165
+ # runner = relay_bp.ObservableDecoderRunner(dec, cm.observables_matrix,
166
+ # include_decode_result=False)
167
+ # Decode:
168
+ # runner.decode_observables_batch(syndromes uint8 [shots, n_det])
169
+ # -> predicted observables uint8 [shots, n_obs]
170
+ # This is the path relay_bp.stim.SinterDecoder_RelayBP uses internally, minus
171
+ # sinter's bit-packing — the runner is driven directly for a clean decode_batch.
172
# Default Relay-BP hyperparameters for the relay-bp reference adapter;
# RelayBPAdapter overrides them key by key with caller-supplied params.
_RELAY_BP_DEFAULTS = {
    "gamma0": 0.1,
    "pre_iter": 80,
    "num_sets": 60,
    "set_max_iter": 60,
    "gamma_dist_interval": (-0.24, 0.66),
    "stop_nconv": 5,
    "stopping_criterion": "nconv",
}
181
+
182
+
183
class RelayBPAdapter:
    """Relay-BP adapter (in-process). Builds the relay-BP decoder from the SAME
    shared DEM via ``relay_bp.stim.CheckMatrices.from_dem`` and decodes a batch
    of syndromes straight to observables. G1 holds trivially: ``.dem is dem``.

    Args:
        dem: the shared ``stim.DetectorErrorModel``.
        **params: overrides for ``_RELAY_BP_DEFAULTS``; the merged dict is
            forwarded to ``relay_bp.RelayDecoderF64`` and recorded in
            ``.config`` together with the installed relay-bp version.
    """

    def __init__(self, dem, **params):
        import importlib.metadata as _md

        import relay_bp
        from relay_bp.stim import CheckMatrices

        self.dem = dem
        self.name = "RelayBP"
        try:
            ver = _md.version("relay-bp")
        except Exception:  # pragma: no cover - metadata present once installed
            ver = "unknown"
        cfg = dict(_RELAY_BP_DEFAULTS)
        cfg.update(params)
        self.config = dict(decoder="RelayBP", relay_bp_version=ver, **cfg)
        # Deterministic relay schedule (fixed gamma distribution + nconv stop).
        self.tie_break = "relay_bp_nconv_disjoint_ensemble"

        cm = CheckMatrices.from_dem(dem)
        self._n_obs = cm.observables_matrix.shape[0]
        decoder = relay_bp.RelayDecoderF64(
            cm.check_matrix,
            error_priors=cm.error_priors,
            **cfg,
        )
        self._runner = relay_bp.ObservableDecoderRunner(
            decoder, cm.observables_matrix, include_decode_result=False)

    def decode_batch(self, dets):
        """Decode detector syndromes to predicted observables.

        Args:
            dets: bool-like array of shape (shots, n_det); a 1-D array is
                treated as a single shot, matching the other decoders.

        Returns:
            bool array of shape (shots, n_obs).
        """
        dets = np.asarray(dets, dtype=bool)
        if dets.ndim == 1:
            # The runner's entry point is batched; promote one shot to a batch.
            dets = dets[None, :]
        # C-contiguous uint8 buffer for the Rust runner.
        syndromes = np.ascontiguousarray(dets, dtype=np.uint8)
        pred = np.asarray(self._runner.decode_observables_batch(syndromes))
        pred = (pred % 2).astype(bool)
        # Normalise to (shots, n_obs): a single-observable result may come back
        # 1-D; reshaping to the known n_obs (not a hard-coded 1 column) also
        # surfaces any size mismatch instead of silently mis-shaping it.
        return pred.reshape(syndromes.shape[0], self._n_obs)
224
+
225
+
226
def make_relay_bp(dem, **params):
    """Factory for the Relay-BP reference adapter on ``dem``.

    ``params`` override ``_RELAY_BP_DEFAULTS`` key by key (see
    ``RelayBPAdapter``). Requires the relay-bp[stim] package.
    """
    adapter = RelayBPAdapter(dem, **params)
    return adapter
228
+
229
+
230
# Registry of ldpc-family adapters: registry name -> factory(dem).
_FACTORIES = {
    "BPOSD-0": make_bposd0,
    "BPOSD-10": make_bposd10,
    "BPLSD": make_bplsd,
    "BP": make_bp,
}

# Default build set for ``build_decoders``: every registered adapter, in
# registry (insertion) order.
DEFAULT_DECODERS = tuple(_FACTORIES)
239
+
240
+
241
def build_decoders(dem, which=DEFAULT_DECODERS, include_relay=False):
    """Construct all requested adapters from ONE shared DEM object.

    Every returned adapter has ``.dem is dem`` (provenance for the matched
    harness). ``which`` selects/orders the ldpc-family adapters by registry
    name; a single name may be passed as a plain string. Relay-BP is OPT-IN
    via ``include_relay=True`` and is added ONLY when its package is available
    (import-guarded), so the core set always builds.

    Raises:
        KeyError: if any name in ``which`` is not registered. All names are
            checked before any (potentially expensive) adapter is built.
    """
    # A bare string would otherwise be iterated character by character.
    names = (which,) if isinstance(which, str) else tuple(which)

    unknown = [name for name in names if name not in _FACTORIES]
    if unknown:
        raise KeyError(
            f"unknown decoder {unknown[0]!r}; known: {sorted(_FACTORIES)}")

    decoders = [_FACTORIES[name](dem) for name in names]

    if include_relay and relay_bp_available():
        decoders.append(make_relay_bp(dem))

    return decoders
tridec/api.py ADDED
@@ -0,0 +1,337 @@
1
+ """Public decoder API: ``from_dem`` / ``from_matrices`` + backend dispatch.
2
+
3
+ Backends
4
+ --------
5
+ * ``"numpy"`` — pure-numpy normalized-min-sum BP (always available; the
6
+ CPU reference the GPU paths are validated against).
7
+ * ``"torch"`` — batched torch edge-list BP; bit-identical to numpy at fp64
8
+ for one iteration; runs on CPU and CUDA/ROCm devices.
9
+ * ``"triton"`` — the Triton kernels (min-sum BP and Relay-BP); requires
10
+ triton + a CUDA or ROCm GPU. fp32 messages on the BP path
11
+ (>=99.5% hard-decision agreement vs the fp64 references,
12
+ LER-validated on H200 and MI300X — see bench/receipts/).
13
+ * ``"auto"`` — triton if importable AND a GPU is visible, else torch if
14
+ importable, else numpy.
15
+
16
+ Algorithms per backend (honest availability matrix):
17
+
18
============  =====  =====  ======
algorithm     numpy  torch  triton
============  =====  =====  ======
bp (min-sum)  yes    yes    yes
relay         no     no     yes
============  =====  =====  ======
24
+
25
+ Relay-BP has no in-package CPU implementation; its CPU reference is IBM's
26
+ ``relay-bp`` Rust decoder, available through ``tridec.adapters`` (the
27
+ ``decoders`` extra) and used as the validation oracle for the Triton path.
28
+ """
29
+ import numpy as np
30
+ import scipy.sparse as sp
31
+
32
+ from .dem import extract
33
+
34
# Every backend name ``resolve_backend`` accepts.
_BACKENDS = ("auto", "numpy", "torch", "triton")

# Validated defaults: the configuration the carried receipts were measured at.
_BP_DEFAULTS = {"max_iter": 30, "ms_scaling_factor": 0.625}
_RELAY_DEFAULTS = {
    "gamma0": 0.1,
    "pre_iter": 80,
    "num_sets": 60,
    "set_max_iter": 60,
    "gamma_dist_interval": (-0.24, 0.66),
    "stop_nconv": 5,
    "stopping_criterion": "nconv",
}
47
+
48
+
49
+ # --------------------------------------------------------------------------- #
50
+ # Backend availability / resolution. #
51
+ # --------------------------------------------------------------------------- #
52
def _torch_available():
    """True iff ``import torch`` succeeds; any import-time exception -> False."""
    try:
        import torch  # noqa: F401
    except Exception:
        return False
    return True
58
+
59
+
60
def _triton_gpu_available():
    """True iff triton and torch both import AND torch reports a GPU.

    ``torch.cuda.is_available()`` is the GPU probe; any exception along the
    way yields False.
    """
    try:
        import triton  # noqa: F401
        import torch
        gpu_visible = torch.cuda.is_available()
    except Exception:
        return False
    return bool(gpu_visible)
67
+
68
+
69
def available_backends():
    """List the backends usable in THIS environment, best first.

    ``"triton"`` and ``"torch"`` appear only when their probes succeed;
    ``"numpy"`` is always present and always last.
    """
    probes = (("triton", _triton_gpu_available), ("torch", _torch_available))
    usable = [name for name, probe in probes if probe()]
    usable.append("numpy")
    return usable
78
+
79
+
80
def resolve_backend(backend="auto"):
    """Map a backend request to the concrete backend that will run.

    ``"auto"`` picks triton when it is importable AND a GPU (CUDA or ROCm) is
    visible, otherwise torch when importable, otherwise numpy. An explicit
    request is returned unchanged when usable.

    Raises:
        ValueError: ``backend`` is not one of ``_BACKENDS``.
        RuntimeError: an explicitly requested torch/triton backend is not
            available (the message states the reason).
    """
    if backend not in _BACKENDS:
        raise ValueError(
            f"unknown backend {backend!r}; expected one of {_BACKENDS}")

    if backend != "auto":
        if backend == "torch" and not _torch_available():
            raise RuntimeError(
                "torch backend requested but torch is not importable; "
                "install with the [torch] extra: pip install tridec[torch]")
        if backend == "triton" and not _triton_gpu_available():
            raise RuntimeError(
                "triton backend requested but triton + a CUDA/ROCm GPU are not "
                "available (triton importable: requires the [gpu] extra; GPU "
                "visible: torch.cuda.is_available() must be True)")
        return backend

    if _triton_gpu_available():
        return "triton"
    return "torch" if _torch_available() else "numpy"
106
+
107
+
108
def _default_device(backend, device):
    """Choose the torch device string for ``backend``.

    An explicit (non-None) ``device`` always wins. Otherwise triton gets
    ``"cuda"``; torch gets ``"cuda"`` when ``torch.cuda.is_available()`` is
    true and ``"cpu"`` otherwise (including when that probe raises); any
    other backend gets ``"cpu"``.
    """
    if device is not None:
        return device
    if backend == "triton":
        return "cuda"
    if backend != "torch":
        return "cpu"
    try:
        import torch
        gpu_visible = torch.cuda.is_available()
    except Exception:  # pragma: no cover
        gpu_visible = False
    return "cuda" if gpu_visible else "cpu"
120
+
121
+
122
def _dense_uint8(M):
    """Return ``M`` as a dense uint8 GF(2) array (entries reduced mod 2).

    ``None`` passes through unchanged; scipy sparse inputs are densified.

    The mod-2 reduction happens BEFORE narrowing to uint8. Casting first
    (``np.asarray(M, dtype=np.uint8)``) raises OverflowError under NumPy >= 2
    for Python-int entries outside [0, 255] -- e.g. -1, or counts > 255 from
    an integer matrix product -- even though their parity is well defined.
    """
    if M is None:
        return None
    if sp.issparse(M):
        M = M.toarray()
    return (np.asarray(M) % 2).astype(np.uint8)
128
+
129
+
130
+ # --------------------------------------------------------------------------- #
131
+ # Decoders. #
132
+ # --------------------------------------------------------------------------- #
133
class BpDecoder:
    """Normalized min-sum BP over the numpy / torch / triton backends.

    Construct via ``tridec.from_dem`` / ``tridec.from_matrices``
    (or directly). ``decode_batch(dets)`` returns predicted observables
    (bool[shots, n_obs]) when an observable map is available (always the case
    via ``from_dem``), else hard error estimates (uint8[shots, n_bits]).

    Args:
        H: parity-check matrix, forwarded unchanged to the backend impl.
        priors: per-bit error probabilities, forwarded unchanged.
        observables: optional (n_obs x n_bits) GF2 observable map (dense or
            scipy sparse); a 1-D array is treated as one observable row.
        backend: "auto" | "numpy" | "torch" | "triton" (see resolve_backend).
        device: torch device string for torch/triton; None picks a default.
        max_iter: BP iteration cap, coerced with ``int``.
        ms_scaling_factor: min-sum scaling factor, coerced with ``float``.
        block_s: Triton block size; only forwarded to the triton backend.
        dem: optional source DEM, stored on ``.dem`` for provenance.

    Raises:
        ValueError: unknown backend name (from resolve_backend), or an
            observable map without exactly one column per bit.
        RuntimeError: an explicitly requested backend is unavailable.
    """

    algorithm = "bp"

    def __init__(self, H, priors, observables=None, backend="auto", device=None,
                 max_iter=30, ms_scaling_factor=0.625, block_s=256, dem=None):
        self.backend = resolve_backend(backend)
        self.device = _default_device(self.backend, device)
        self.dem = dem
        self.max_iter = int(max_iter)
        self.ms_scaling_factor = float(ms_scaling_factor)

        # Forward the *normalized* hyperparameters so the backend runs exactly
        # what ``self.config`` records (e.g. max_iter=30.0 arrives as 30).
        bp_kw = dict(max_iter=self.max_iter,
                     ms_scaling_factor=self.ms_scaling_factor)
        if self.backend == "numpy":
            from .backends.bp_numpy import BpBaseline
            self._impl = BpBaseline(H, priors, **bp_kw)
        elif self.backend == "torch":
            from .backends.bp_torch import BpGpu
            self._impl = BpGpu(H, priors, **bp_kw)
        else:  # triton
            from .backends.bp_triton import BpTriton
            self._impl = BpTriton(H, priors, block_s=block_s, **bp_kw)
        self.n_bits = self._impl.n_bits
        self.n_checks = self._impl.n_checks

        Lo = self._check_observables(_dense_uint8(observables), self.n_bits)
        self._Lo = Lo
        if Lo is not None:
            # Attach the observable map to the backend impl so its validated
            # decode_batch path (e_hat -> (Lo @ e_hat) % 2) applies unchanged.
            self._impl._Lo = Lo
            self._impl._n_obs = int(Lo.shape[0])
        self.n_obs = None if Lo is None else int(Lo.shape[0])

        self.name = f"portable-bp[{self.backend}]"
        self.tie_break = "min_sum_parallel_hard_decision"
        self.config = dict(
            decoder="tridec.BpDecoder", backend=self.backend,
            bp_method="minimum_sum", ms_scaling_factor=self.ms_scaling_factor,
            max_iter=self.max_iter, schedule="parallel")

    @staticmethod
    def _check_observables(Lo, n_bits):
        """Promote a 1-D observable row to 2-D and require n_bits columns."""
        if Lo is None:
            return None
        Lo = np.atleast_2d(Lo)
        if Lo.ndim != 2 or Lo.shape[1] != n_bits:
            raise ValueError(
                f"observables must have shape (n_obs, {n_bits}); "
                f"got {Lo.shape}")
        return Lo

    @staticmethod
    def _as_batch(detection_events):
        """Return detection events as a bool (shots, n_det) array.

        A 1-D input is one shot; anything other than 1-D/2-D is rejected.
        Normalising to bool first means non-0/1 entries map to 1 in BOTH
        decode paths (previously only the observable path did this).
        """
        dets = np.asarray(detection_events)
        if dets.ndim == 1:
            dets = dets[None, :]
        if dets.ndim != 2:
            raise ValueError(
                f"detection_events must be 1-D or 2-D; got shape {dets.shape}")
        return dets.astype(bool)

    @classmethod
    def from_dem(cls, dem, backend="auto", device=None, **opts):
        kw = dict(_BP_DEFAULTS)
        kw.update(opts)
        ex = extract(dem)
        obj = cls(ex["H"], ex["priors"], observables=ex["Lo"], backend=backend,
                  device=device, dem=dem, **kw)
        return obj

    # -- decode surfaces ---------------------------------------------------- #
    def decode_batch(self, detection_events):
        """Decode a batch of detector-event vectors.

        Returns predicted observables (bool[shots, n_obs]) when an observable
        map is present, else hard error estimates (uint8[shots, n_bits]).

        Raises:
            ValueError: if ``detection_events`` is not 1-D or 2-D.
        """
        dets = self._as_batch(detection_events)
        if self._Lo is not None:
            if self.backend == "numpy":
                return self._impl.decode_batch(dets)
            return self._impl.decode_batch(dets, device=self.device)
        # No observable map: return hard error estimates.
        syn = dets.astype(np.uint8)
        if self.backend == "numpy":
            out = np.zeros((syn.shape[0], self.n_bits), dtype=np.uint8)
            for i in range(syn.shape[0]):
                out[i] = self._impl.decode(syn[i])
            return out
        post = self._impl.run_iterations_batch(syn, n_iter=self.max_iter,
                                               device=self.device)
        # Negative posterior -> bit flipped (hard decision).
        return (post < 0.0).astype(np.uint8)

    def decode(self, detection_events):
        """Single-shot convenience: 1-D in, 1-D out."""
        out = self.decode_batch(np.asarray(detection_events)[None, :])
        return out[0]
222
+
223
+
224
class RelayBpDecoder:
    """Relay-BP (disordered-memory min-sum relay ensemble), Triton backend only.

    Defaults match the ``relay_bp`` Rust oracle configuration the kernels were
    LER-validated against (gamma0=0.1, pre_iter=80, num_sets=60,
    set_max_iter=60, gamma_dist_interval=(-0.24, 0.66), stop_nconv=5).

    Args:
        H: parity-check matrix, forwarded unchanged to the Triton impl.
        priors: per-bit error probabilities, forwarded unchanged.
        observables: optional (n_obs x n_bits) GF2 observable map; a 1-D
            array is treated as one observable row.
        backend: must resolve to "triton".
        device: torch device string; None picks ``"cuda"``.
        block_s: Triton block size.
        dtype: message dtype name forwarded to the impl.
        dem: optional source DEM, stored on ``.dem`` for provenance.
        **relay_params: overrides for ``_RELAY_DEFAULTS``.

    Raises:
        NotImplementedError: the backend does not resolve to "triton".
        ValueError: an observable map without exactly one column per bit.
    """

    algorithm = "relay"

    def __init__(self, H, priors, observables=None, backend="auto", device=None,
                 block_s=256, dtype="float64", dem=None, **relay_params):
        resolved = resolve_backend(backend)
        if resolved != "triton":
            raise NotImplementedError(
                f"Relay-BP is implemented on the triton backend only (resolved "
                f"backend: {resolved!r}). There is no in-package CPU Relay-BP; "
                f"for a CPU reference use the relay-bp adapter "
                f"(tridec.adapters.make_relay_bp, [decoders] extra).")
        self.backend = "triton"
        self.device = _default_device("triton", device)
        self.dem = dem

        cfg = dict(_RELAY_DEFAULTS)
        cfg.update(relay_params)
        from .backends.relay_triton import RelayBpTriton
        self._impl = RelayBpTriton(H, priors, block_s=block_s, dtype=dtype,
                                   **cfg)
        self.n_bits = self._impl.n_bits
        self.n_checks = self._impl.n_checks

        Lo = self._check_observables(_dense_uint8(observables), self.n_bits)
        self._Lo = Lo
        if Lo is not None:
            # Hand the observable map to the impl so its decode_batch can
            # project error estimates to observables.
            self._impl._Lo = Lo
            self._impl._n_obs = int(Lo.shape[0])
        self.n_obs = None if Lo is None else int(Lo.shape[0])

        self.name = "portable-relay-bp[triton]"
        self.tie_break = "relay_bp_nconv_disjoint_ensemble"
        self.config = dict(
            decoder="tridec.RelayBpDecoder", backend="triton",
            dtype=dtype, **cfg)

    @staticmethod
    def _check_observables(Lo, n_bits):
        """Promote a 1-D observable row to 2-D and require n_bits columns."""
        if Lo is None:
            return None
        Lo = np.atleast_2d(Lo)
        if Lo.ndim != 2 or Lo.shape[1] != n_bits:
            raise ValueError(
                f"observables must have shape (n_obs, {n_bits}); "
                f"got {Lo.shape}")
        return Lo

    @staticmethod
    def _as_batch(detection_events):
        """Return detection events as a bool (shots, n_det) array.

        A 1-D input is one shot; anything other than 1-D/2-D is rejected.
        """
        dets = np.asarray(detection_events)
        if dets.ndim == 1:
            dets = dets[None, :]
        if dets.ndim != 2:
            raise ValueError(
                f"detection_events must be 1-D or 2-D; got shape {dets.shape}")
        return dets.astype(bool)

    @classmethod
    def from_dem(cls, dem, backend="auto", device=None, **opts):
        ex = extract(dem)
        return cls(ex["H"], ex["priors"], observables=ex["Lo"], backend=backend,
                   device=device, dem=dem, **opts)

    def decode_batch(self, detection_events):
        """Decode a batch of detector-event vectors.

        Returns predicted observables (bool[shots, n_obs]) when an observable
        map is present; otherwise a uint8 (shots, n_bits) array built from
        the impl's ``_relay_posteriors`` output.

        Raises:
            ValueError: if ``detection_events`` is not 1-D or 2-D.
        """
        dets = self._as_batch(detection_events)
        if self._Lo is not None:
            return self._impl.decode_batch(dets, device=self.device)
        # No observable map: return the lowest-weight valid error estimate.
        # NOTE(review): that description (and the (N, S) layout) comes from
        # the impl's contract, which is not visible here -- confirm in
        # backends/relay_triton.
        import torch
        dev = torch.device(self.device)
        syn_t = torch.as_tensor(dets, device=dev)
        best_eh = self._impl._relay_posteriors(syn_t, dev)  # (N, S)
        return best_eh.t().cpu().numpy().astype(np.uint8)

    def decode(self, detection_events):
        """Single-shot convenience: 1-D in, 1-D out."""
        out = self.decode_batch(np.asarray(detection_events)[None, :])
        return out[0]
290
+
291
+
292
+ # --------------------------------------------------------------------------- #
293
+ # Factories. #
294
+ # --------------------------------------------------------------------------- #
295
def from_dem(dem, backend="auto", algorithm="bp", device=None, **opts):
    """Build a decoder from a ``stim.DetectorErrorModel``.

    Args:
        dem: the DEM (build with ``decompose_errors=False`` -- the decoders
            consume the raw hyperedge mechanism set).
        backend: "auto" | "numpy" | "torch" | "triton" (see module docstring).
        algorithm: "bp" (min-sum BP, all backends) or "relay" (Relay-BP,
            triton backend only).
        device: optional torch device string for the torch/triton backends.
        **opts: decoder hyperparameters (e.g. max_iter, ms_scaling_factor for
            bp; gamma0, pre_iter, num_sets, ... for relay).

    Returns:
        A decoder with ``decode_batch(detection_events) ->
        predicted_observables`` and a single-shot ``decode``.

    Raises:
        ValueError: ``algorithm`` is neither "bp" nor "relay".
    """
    if algorithm not in ("bp", "relay"):
        raise ValueError(f"unknown algorithm {algorithm!r}; expected 'bp' or 'relay'")
    decoder_cls = BpDecoder if algorithm == "bp" else RelayBpDecoder
    return decoder_cls.from_dem(dem, backend=backend, device=device, **opts)
317
+
318
+
319
def from_matrices(H, priors, observables=None, backend="auto", algorithm="bp",
                  device=None, **opts):
    """Build a decoder from a raw GF2 parity-check matrix + per-bit priors.

    Args:
        H: (n_checks x n_bits) GF2 check matrix (dense or scipy sparse).
        priors: per-bit error probabilities, length n_bits.
        observables: optional (n_obs x n_bits) GF2 observable map. With it,
            ``decode_batch`` returns predicted observables; without it, hard
            error estimates.
        backend, algorithm, device, **opts: as in ``from_dem``.

    Raises:
        ValueError: ``algorithm`` is neither "bp" nor "relay".
    """
    if algorithm not in ("bp", "relay"):
        raise ValueError(f"unknown algorithm {algorithm!r}; expected 'bp' or 'relay'")
    decoder_cls = BpDecoder if algorithm == "bp" else RelayBpDecoder
    return decoder_cls(H, priors, observables=observables, backend=backend,
                       device=device, **opts)
@@ -0,0 +1,11 @@
1
+ """Decoder backends.
2
+
3
+ ``bp_numpy`` is always importable (numpy/scipy only). ``bp_torch`` requires
4
+ torch; ``bp_triton`` / ``relay_triton`` additionally require triton and a GPU
5
+ to RUN (they import without one — the kernels compile only where triton
6
+ exists). The API layer (``tridec.api``) imports the optional backends
7
+ lazily, so a missing extra never breaks the core package.
8
+ """
9
+ from .bp_numpy import BpBaseline
10
+
11
# Only the always-importable numpy reference backend is re-exported here;
# the optional torch/triton backends are imported lazily by tridec.api.
__all__ = [
    "BpBaseline",
]