qector-decoder-v3 0.5.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1095 @@
1
+ """
2
+ QECTOR Decoder v3 — Source-available Rust/Python QEC decoders with reproducible, artifact-hashed benchmark evidence.
3
+ Rust core + PyO3 bindings. Zero-copy NumPy. GIL-free decode.
4
+ """
5
+
6
+ from .qector_decoder_v3 import (
7
+ UnionFindDecoder as _RustUnionFindDecoder,
8
+ FastUnionFindDecoder as _RustFastUnionFindDecoder,
9
+ BlossomDecoder as _RustBlossomDecoder,
10
+ SlidingWindowDecoder as _RustSlidingWindowDecoder,
11
+ StreamingDecoder as _RustStreamingDecoder,
12
+ BatchDecoder as _RustBatchDecoder,
13
+ CPUBatchDecoder as _RustCPUBatchDecoder,
14
+ OpenCLBatchDecoder as _RustOpenCLBatchDecoder,
15
+ BenchmarkSuite as _RustBenchmarkSuite,
16
+ LookupTableDecoder as _RustLookupTableDecoder,
17
+ BPOSDDecoder as _RustBPOSDDecoder,
18
+ NeuralPredecoder as _RustNeuralPredecoder,
19
+ DetectorGraph as _RustDetectorGraph,
20
+ GNNPredecoder as _RustGNNPredecoder,
21
+ GNNTrainer as _RustGNNTrainer,
22
+ SparseBlossomDecoder as _RustSparseBlossomDecoder,
23
+ HybridDecoder as _RustHybridDecoder,
24
+ py_check_to_edges,
25
+ py_generate_surface_code_checks,
26
+ py_generate_toy_code_checks,
27
+ py_generate_ring_code_checks,
28
+ py_generate_repetition_code_checks,
29
+ run_mcp_server,
30
+ )
31
+
32
# Optional native features.
#
# Each name below is only exported by the extension module when it was compiled
# with the matching feature. A missing name surfaces as ImportError and is
# replaced with an inert stand-in, so importing this package never fails just
# because a feature is absent.

# gRPC server entry point (None when not compiled in).
try:
    from .qector_decoder_v3 import run_grpc_server
except ImportError:
    run_grpc_server = None  # type: ignore[assignment]

# Metrics server entry point (None when not compiled in).
try:
    from .qector_decoder_v3 import start_metrics_server
except ImportError:
    start_metrics_server = None  # type: ignore[assignment]

# OpenCL availability probe.
try:
    from .qector_decoder_v3 import opencl_is_available
except ImportError:

    def opencl_is_available():
        """Stand-in used when the extension exports no OpenCL probe: always False."""
        return False


# CUDA: the class handle and the probe are imported together. If either one is
# missing, both fall back, so the CUDA wrapper can see that support is absent.
try:
    from .qector_decoder_v3 import CUDABatchDecoder as _RustCUDABatchDecoder
    from .qector_decoder_v3 import cuda_is_available
except ImportError:
    _RustCUDABatchDecoder = None  # type: ignore[assignment]

    def cuda_is_available():
        """Stand-in used when the extension was built without CUDA: always False."""
        return False
61
+
62
+ import numpy as np
63
+
64
# Prefer the version string exported by the native extension. Fall back to this
# wrapper's own release number when the extension cannot be imported or does
# not export one.
_FALLBACK_VERSION = "0.5.0"
try:
    from .qector_decoder_v3 import __version__
except (ImportError, AttributeError):
    __version__ = _FALLBACK_VERSION
68
+
69
+
70
class UnionFindDecoder:
    """Production-ready Union-Find quantum error correction decoder.

    Rust core with PyO3 bindings. Zero-copy NumPy interop.
    GIL is released during decode for true parallelism.

    Args:
        check_to_qubits: One entry per check, each an iterable of the qubit
            indices that check touches. Lists of lists, generators and 2-D
            integer NumPy arrays are all accepted.
        n_qubits: Optional qubit count. Forwarded to the Rust core as an int,
            or as ``None`` (presumably the core then infers it).

    Raises:
        ValueError: If ``check_to_qubits`` is ``None`` or contains no checks.
    """

    def __init__(self, check_to_qubits, n_qubits=None):
        if check_to_qubits is None:
            raise ValueError("check_to_qubits must be non-empty")
        # Convert to plain ints (Vec<Vec<u32>> on the Rust side) *before* testing
        # for emptiness. A bare truthiness test raises for multi-element NumPy
        # arrays and always passes for generators.
        c2q = [[int(q) for q in check] for check in check_to_qubits]
        if not c2q:
            raise ValueError("check_to_qubits must be non-empty")
        nq = None if n_qubits is None else int(n_qubits)
        self._inner = _RustUnionFindDecoder(c2q, nq)

    def decode(self, syndrome):
        """Decode one syndrome and return the Rust core's correction.

        Args:
            syndrome: Per-check syndrome bits. Non-array input is converted to
                uint8. An ndarray must already have dtype uint8.

        Raises:
            TypeError: If ``syndrome`` is an ndarray with a dtype other than uint8.
        """
        if not isinstance(syndrome, np.ndarray):
            syndrome = np.array(syndrome, dtype=np.uint8)
        if syndrome.dtype != np.uint8:
            raise TypeError(f"Syndrome must be dtype uint8, got {syndrome.dtype}")
        # Strided views (e.g. ``arr[::2]``) are copied into a C-contiguous buffer.
        # Contiguous input is passed through without a copy.
        # NOTE(review): assumes the zero-copy binding reads a flat slice; confirm.
        return self._inner.decode(np.ascontiguousarray(syndrome))

    def batch_decode(self, syndromes):
        """Decode a 2-D batch of syndromes, expected shape ``(batch, n_checks)``.

        Any array-like is accepted and cast to C-contiguous uint8, using the
        same value cast as ``ndarray.astype``.

        Raises:
            ValueError: If the input is not 2-D.
        """
        syndromes = np.ascontiguousarray(syndromes, dtype=np.uint8)
        if syndromes.ndim != 2:
            raise ValueError(f"syndromes must be 2D, got shape {syndromes.shape}")
        return self._inner.batch_decode(syndromes)

    @property
    def n_qubits(self):
        """Number of qubits, as reported by the Rust core."""
        return self._inner.n_qubits

    @property
    def n_checks(self):
        """Number of checks, as reported by the Rust core."""
        return self._inner.n_checks
108
+
109
+
110
class FastUnionFindDecoder:
    """SIMD-accelerated zero-allocation Union-Find decoder.

    Uses pre-allocated reusable buffers, AVX2 runtime dispatch, and FFI.
    Same API as UnionFindDecoder but with lower overhead.

    Args:
        check_to_qubits: One entry per check, each an iterable of the qubit
            indices that check touches. Lists of lists, generators and 2-D
            integer NumPy arrays are all accepted.
        n_qubits: Optional qubit count, forwarded to the Rust core as an int
            or as ``None``.

    Raises:
        ValueError: If ``check_to_qubits`` is ``None`` or contains no checks.
    """

    def __init__(self, check_to_qubits, n_qubits=None):
        if check_to_qubits is None:
            raise ValueError("check_to_qubits must be non-empty")
        # Convert first, then test for emptiness. A bare truthiness test raises
        # for multi-element NumPy arrays and always passes for generators.
        c2q = [[int(q) for q in check] for check in check_to_qubits]
        if not c2q:
            raise ValueError("check_to_qubits must be non-empty")
        nq = None if n_qubits is None else int(n_qubits)
        self._inner = _RustFastUnionFindDecoder(c2q, nq)

    def decode(self, syndrome):
        """Decode one syndrome and return the Rust core's correction.

        Non-array input is converted to uint8. An ndarray must already be uint8.

        Raises:
            TypeError: If ``syndrome`` is an ndarray with a dtype other than uint8.
        """
        if not isinstance(syndrome, np.ndarray):
            syndrome = np.array(syndrome, dtype=np.uint8)
        if syndrome.dtype != np.uint8:
            raise TypeError(f"Syndrome must be dtype uint8, got {syndrome.dtype}")
        # Copy strided views into a C-contiguous buffer; contiguous input is not
        # copied. NOTE(review): assumes the native side reads a flat slice.
        return self._inner.decode(np.ascontiguousarray(syndrome))

    def batch_decode(self, syndromes):
        """Decode a 2-D batch of syndromes, expected shape ``(batch, n_checks)``.

        Any array-like is cast to C-contiguous uint8 (same value cast as ``astype``).

        Raises:
            ValueError: If the input is not 2-D.
        """
        syndromes = np.ascontiguousarray(syndromes, dtype=np.uint8)
        if syndromes.ndim != 2:
            raise ValueError(f"syndromes must be 2D, got shape {syndromes.shape}")
        return self._inner.batch_decode(syndromes)

    @property
    def n_qubits(self):
        """Number of qubits, as reported by the Rust core."""
        return self._inner.n_qubits

    @property
    def n_checks(self):
        """Number of checks, as reported by the Rust core."""
        return self._inner.n_checks
147
+
148
+
149
class BlossomDecoder:
    """Minimum-Weight Perfect Matching (MWPM) decoder via Edmonds' Blossom algorithm.

    Supports weighted edges for higher decoding accuracy on realistic codes.

    Args:
        check_to_qubits: One entry per check, each an iterable of the qubit
            indices that check touches. Lists of lists, generators and 2-D
            integer NumPy arrays are all accepted.
        n_qubits: Optional qubit count, forwarded as an int or ``None``.
        edge_weights: Optional weights, forwarded to the Rust core unchanged.
            NOTE(review): the expected layout is defined by the Rust binding;
            document it here.

    Raises:
        ValueError: If ``check_to_qubits`` is ``None`` or contains no checks.
    """

    def __init__(self, check_to_qubits, n_qubits=None, edge_weights=None):
        if check_to_qubits is None:
            raise ValueError("check_to_qubits must be non-empty")
        # Convert first, then test for emptiness. A bare truthiness test raises
        # for multi-element NumPy arrays and always passes for generators.
        c2q = [[int(q) for q in check] for check in check_to_qubits]
        if not c2q:
            raise ValueError("check_to_qubits must be non-empty")
        nq = None if n_qubits is None else int(n_qubits)
        self._inner = _RustBlossomDecoder(c2q, nq, edge_weights)

    def decode(self, syndrome):
        """Decode one syndrome and return the Rust core's correction.

        Non-array input is converted to uint8. An ndarray must already be uint8.

        Raises:
            TypeError: If ``syndrome`` is an ndarray with a dtype other than uint8.
        """
        if not isinstance(syndrome, np.ndarray):
            syndrome = np.array(syndrome, dtype=np.uint8)
        if syndrome.dtype != np.uint8:
            raise TypeError(f"Syndrome must be dtype uint8, got {syndrome.dtype}")
        # Copy strided views into a C-contiguous buffer; contiguous input is not copied.
        return self._inner.decode(np.ascontiguousarray(syndrome))

    def batch_decode(self, syndromes):
        """Decode a 2-D batch of syndromes, expected shape ``(batch, n_checks)``.

        Any array-like is cast to C-contiguous uint8 (same value cast as ``astype``).

        Raises:
            ValueError: If the input is not 2-D.
        """
        syndromes = np.ascontiguousarray(syndromes, dtype=np.uint8)
        if syndromes.ndim != 2:
            raise ValueError(f"syndromes must be 2D, got shape {syndromes.shape}")
        return self._inner.batch_decode(syndromes)

    @property
    def n_qubits(self):
        """Number of qubits, as reported by the Rust core."""
        return self._inner.n_qubits

    @property
    def n_checks(self):
        """Number of checks, as reported by the Rust core."""
        return self._inner.n_checks

    @property
    def edges(self):
        """Matching-graph edges, as reported by the Rust core."""
        return self._inner.edges
189
+
190
+
191
class SlidingWindowDecoder:
    """Sliding-window decoder with exponential decay weighting.

    Maintains a window of the last W rounds. Each round's syndrome is weighted
    by ``decay_factor ** age`` so that more recent rounds contribute more.
    The weighted cumulative syndrome is thresholded at 0.5 and decoded with
    the standard Union-Find decoder.

    Args:
        check_to_qubits: One entry per check, each an iterable of the qubit
            indices that check touches. Lists of lists, generators and 2-D
            integer NumPy arrays are all accepted.
        n_qubits: Optional qubit count, forwarded as an int or ``None``.
        window_size: Number of rounds kept in the window (forwarded unchanged).
        decay_factor: Per-round decay weight (forwarded unchanged).

    Raises:
        ValueError: If ``check_to_qubits`` is ``None`` or contains no checks.
    """

    def __init__(self, check_to_qubits, n_qubits=None, window_size=10, decay_factor=0.8):
        if check_to_qubits is None:
            raise ValueError("check_to_qubits must be non-empty")
        # Convert first, then test for emptiness. A bare truthiness test raises
        # for multi-element NumPy arrays and always passes for generators.
        c2q = [[int(q) for q in check] for check in check_to_qubits]
        if not c2q:
            raise ValueError("check_to_qubits must be non-empty")
        nq = None if n_qubits is None else int(n_qubits)
        self._inner = _RustSlidingWindowDecoder(c2q, nq, window_size, decay_factor)

    def update(self, round_syndrome):
        """Push one round's syndrome into the window; return the Rust core's result.

        Non-array input is converted to uint8. An ndarray must already be uint8.

        Raises:
            TypeError: If ``round_syndrome`` is an ndarray with a dtype other than uint8.
        """
        if not isinstance(round_syndrome, np.ndarray):
            round_syndrome = np.array(round_syndrome, dtype=np.uint8)
        if round_syndrome.dtype != np.uint8:
            raise TypeError(f"Syndrome must be dtype uint8, got {round_syndrome.dtype}")
        # Copy strided views into a C-contiguous buffer; contiguous input is not copied.
        return self._inner.update(np.ascontiguousarray(round_syndrome))

    def flush(self):
        """Ask the Rust core to flush its window state."""
        self._inner.flush()

    def decode(self, syndrome):
        """Decode one syndrome and return the Rust core's correction.

        Raises:
            TypeError: If ``syndrome`` is an ndarray with a dtype other than uint8.
        """
        if not isinstance(syndrome, np.ndarray):
            syndrome = np.array(syndrome, dtype=np.uint8)
        if syndrome.dtype != np.uint8:
            raise TypeError(f"Syndrome must be dtype uint8, got {syndrome.dtype}")
        return self._inner.decode(np.ascontiguousarray(syndrome))

    @property
    def n_qubits(self):
        """Number of qubits, as reported by the Rust core."""
        return self._inner.n_qubits

    @property
    def n_checks(self):
        """Number of checks, as reported by the Rust core."""
        return self._inner.n_checks

    @property
    def window_size(self):
        """Window length, as reported by the Rust core."""
        return self._inner.window_size

    @property
    def decay_factor(self):
        """Decay factor, as reported by the Rust core."""
        return self._inner.decay_factor

    @property
    def current_round(self):
        """Current round counter, as reported by the Rust core."""
        return self._inner.current_round
243
+
244
+
245
class StreamingDecoder:
    """Streaming decoder that accumulates syndromes over multiple rounds.

    Rust core with circular history buffer and OR accumulation.

    Args:
        check_to_qubits: One entry per check, each an iterable of the qubit
            indices that check touches. Lists of lists, generators and 2-D
            integer NumPy arrays are all accepted.
        n_qubits: Optional qubit count, forwarded as an int or ``None``.
        history_size: History buffer length (forwarded unchanged).

    Raises:
        ValueError: If ``check_to_qubits`` is ``None`` or contains no checks.
    """

    def __init__(self, check_to_qubits, n_qubits=None, history_size=10):
        if check_to_qubits is None:
            raise ValueError("check_to_qubits must be non-empty")
        # Convert first, then test for emptiness. A bare truthiness test raises
        # for multi-element NumPy arrays and always passes for generators.
        c2q = [[int(q) for q in check] for check in check_to_qubits]
        if not c2q:
            raise ValueError("check_to_qubits must be non-empty")
        nq = None if n_qubits is None else int(n_qubits)
        self._inner = _RustStreamingDecoder(c2q, nq, history_size)

    def update(self, round_syndrome):
        """Feed one round's syndrome to the stream; return the Rust core's result.

        Raises:
            TypeError: If ``round_syndrome`` is an ndarray with a dtype other than uint8.
        """
        if not isinstance(round_syndrome, np.ndarray):
            round_syndrome = np.array(round_syndrome, dtype=np.uint8)
        if round_syndrome.dtype != np.uint8:
            raise TypeError(f"Syndrome must be dtype uint8, got {round_syndrome.dtype}")
        # Copy strided views into a C-contiguous buffer; contiguous input is not copied.
        return self._inner.update(np.ascontiguousarray(round_syndrome))

    def flush(self):
        """Ask the Rust core to flush its accumulated history."""
        self._inner.flush()

    def decode(self, syndrome):
        """Decode one syndrome and return the Rust core's correction.

        Raises:
            TypeError: If ``syndrome`` is an ndarray with a dtype other than uint8.
        """
        if not isinstance(syndrome, np.ndarray):
            syndrome = np.array(syndrome, dtype=np.uint8)
        if syndrome.dtype != np.uint8:
            raise TypeError(f"Syndrome must be dtype uint8, got {syndrome.dtype}")
        return self._inner.decode(np.ascontiguousarray(syndrome))

    @property
    def n_qubits(self):
        """Number of qubits, as reported by the Rust core."""
        return self._inner.n_qubits

    @property
    def n_checks(self):
        """Number of checks, as reported by the Rust core."""
        return self._inner.n_checks
282
+
283
+
284
class BatchDecoder:
    """Parallel batch decoder using Rayon (Rust data parallelism).

    Distributes batch decoding across all CPU cores.

    Args:
        check_to_qubits: One entry per check, each an iterable of the qubit
            indices that check touches. Lists of lists, generators and 2-D
            integer NumPy arrays are all accepted.
        n_qubits: Optional qubit count, forwarded as an int or ``None``.

    Raises:
        ValueError: If ``check_to_qubits`` is ``None`` or contains no checks.
    """

    def __init__(self, check_to_qubits, n_qubits=None):
        if check_to_qubits is None:
            raise ValueError("check_to_qubits must be non-empty")
        # Convert first, then test for emptiness. A bare truthiness test raises
        # for multi-element NumPy arrays and always passes for generators.
        c2q = [[int(q) for q in check] for check in check_to_qubits]
        if not c2q:
            raise ValueError("check_to_qubits must be non-empty")
        nq = None if n_qubits is None else int(n_qubits)
        self._inner = _RustBatchDecoder(c2q, nq)

    def parallel_batch_decode(self, syndromes):
        """Decode a 2-D batch of syndromes, expected shape ``(batch, n_checks)``.

        Any array-like is cast to C-contiguous uint8 (same value cast as ``astype``).

        Raises:
            ValueError: If the input is not 2-D.
        """
        syndromes = np.ascontiguousarray(syndromes, dtype=np.uint8)
        if syndromes.ndim != 2:
            raise ValueError(f"syndromes must be 2D, got shape {syndromes.shape}")
        return self._inner.parallel_batch_decode(syndromes)

    def batch_decode(self, syndromes):
        """Alias for ``parallel_batch_decode`` for API consistency with the
        other batch decoders."""
        return self.parallel_batch_decode(syndromes)

    @property
    def n_qubits(self):
        """Number of qubits, as reported by the Rust core."""
        return self._inner.n_qubits

    @property
    def n_checks(self):
        """Number of checks, as reported by the Rust core."""
        return self._inner.n_checks
318
+
319
+
320
class CPUBatchDecoder:
    """SIMD-friendly CPU batch decoder with pooled buffers and SoA transposition.

    Args:
        check_to_qubits: One entry per check, each an iterable of the qubit
            indices that check touches. Lists of lists, generators and 2-D
            integer NumPy arrays are all accepted.
        n_qubits: Optional qubit count, forwarded as an int or ``None``.

    Raises:
        ValueError: If ``check_to_qubits`` is ``None`` or contains no checks.
    """

    def __init__(self, check_to_qubits, n_qubits=None):
        if check_to_qubits is None:
            raise ValueError("check_to_qubits must be non-empty")
        # Convert first, then test for emptiness. A bare truthiness test raises
        # for multi-element NumPy arrays and always passes for generators.
        c2q = [[int(q) for q in check] for check in check_to_qubits]
        if not c2q:
            raise ValueError("check_to_qubits must be non-empty")
        nq = None if n_qubits is None else int(n_qubits)
        self._inner = _RustCPUBatchDecoder(c2q, nq)

    def decode(self, syndrome):
        """Decode one syndrome and return the Rust core's correction.

        Raises:
            TypeError: If ``syndrome`` is an ndarray with a dtype other than uint8.
        """
        if not isinstance(syndrome, np.ndarray):
            syndrome = np.array(syndrome, dtype=np.uint8)
        if syndrome.dtype != np.uint8:
            raise TypeError(f"Syndrome must be dtype uint8, got {syndrome.dtype}")
        # Copy strided views into a C-contiguous buffer; contiguous input is not copied.
        return self._inner.decode(np.ascontiguousarray(syndrome))

    def batch_decode(self, syndromes):
        """Decode a 2-D batch via the Rust core's ``batch_decode_par``.

        Any array-like is cast to C-contiguous uint8 (same value cast as ``astype``).

        Raises:
            ValueError: If the input is not 2-D.
        """
        syndromes = np.ascontiguousarray(syndromes, dtype=np.uint8)
        if syndromes.ndim != 2:
            raise ValueError(f"syndromes must be 2D, got shape {syndromes.shape}")
        return self._inner.batch_decode_par(syndromes)

    @property
    def n_qubits(self):
        """Number of qubits, as reported by the Rust core."""
        return self._inner.n_qubits

    @property
    def n_checks(self):
        """Number of checks, as reported by the Rust core."""
        return self._inner.n_checks
353
+
354
+
355
class OpenCLBatchDecoder:
    """GPU-accelerated OpenCL batch decoder.

    Uses NVIDIA/AMD/Intel GPU via OpenCL for parallel batch decoding.
    Falls back to CPU UnionFind for small batches (< 8) or after repeated GPU failures.
    Automatically recovers from degraded mode after periodic GPU health checks.

    Args:
        check_to_qubits: One entry per check, each an iterable of the qubit
            indices that check touches. Lists of lists, generators and 2-D
            integer NumPy arrays are all accepted.
        n_qubits: Optional qubit count, forwarded as an int or ``None``.

    Raises:
        ValueError: If ``check_to_qubits`` is ``None`` or contains no checks.
    """

    def __init__(self, check_to_qubits, n_qubits=None):
        if check_to_qubits is None:
            raise ValueError("check_to_qubits must be non-empty")
        # Convert first, then test for emptiness. A bare truthiness test raises
        # for multi-element NumPy arrays and always passes for generators.
        c2q = [[int(q) for q in check] for check in check_to_qubits]
        if not c2q:
            raise ValueError("check_to_qubits must be non-empty")
        nq = None if n_qubits is None else int(n_qubits)
        self._inner = _RustOpenCLBatchDecoder(c2q, nq)

    def batch_decode(self, syndromes):
        """Decode a 2-D batch of syndromes, expected shape ``(batch, n_checks)``.

        Any array-like is cast to C-contiguous uint8 (same value cast as
        ``astype``), so the buffer handed to the native side is flat.

        Raises:
            ValueError: If the input is not 2-D.
        """
        syndromes = np.ascontiguousarray(syndromes, dtype=np.uint8)
        if syndromes.ndim != 2:
            raise ValueError(f"syndromes must be 2D, got shape {syndromes.shape}")
        return self._inner.batch_decode(syndromes)

    def reset(self):
        """Reset all GPU counters and exit degraded mode.

        Forces the decoder to retry GPU on the next call.
        Useful after driver update, GPU maintenance, or manual intervention.
        """
        self._inner.reset()

    @property
    def n_qubits(self):
        """Number of qubits, as reported by the Rust core."""
        return self._inner.n_qubits

    @property
    def n_checks(self):
        """Number of checks, as reported by the Rust core."""
        return self._inner.n_checks

    @property
    def device_name(self) -> str:
        """OpenCL device name reported by the Rust core, as ``str``."""
        return str(self._inner.device_name)

    @property
    def consecutive_failures(self):
        """Number of consecutive GPU failures since last success."""
        return self._inner.consecutive_failures

    @property
    def total_failures(self):
        """Total number of GPU failures since decoder creation."""
        return self._inner.total_failures

    @property
    def is_degraded(self):
        """True if decoder is in CPU-only mode after repeated GPU failures."""
        return self._inner.is_degraded

    @property
    def gpu_recoveries(self):
        """Number of times the GPU recovered after being in degraded mode."""
        return self._inner.gpu_recoveries

    @staticmethod
    def is_available():
        """Return True if an OpenCL GPU is available on this system."""
        return _RustOpenCLBatchDecoder.is_available()
423
+
424
+
425
class CUDABatchDecoder:
    """GPU-accelerated native CUDA batch decoder.

    Uses a compiled CUDA kernel loaded through the CUDA Driver API. Falls back
    to CPU UnionFind for tiny batches or after repeated CUDA failures.

    Args:
        check_to_qubits: One entry per check, each an iterable of the qubit
            indices that check touches. Lists of lists, generators and 2-D
            integer NumPy arrays are all accepted.
        n_qubits: Optional qubit count, forwarded as an int or ``None``.

    Raises:
        RuntimeError: If the extension was built without CUDA support.
        ValueError: If ``check_to_qubits`` is ``None`` or contains no checks.
    """

    def __init__(self, check_to_qubits, n_qubits=None):
        if _RustCUDABatchDecoder is None:
            raise RuntimeError("qector-decoder-v3 was built without the 'cuda' feature")
        if check_to_qubits is None:
            raise ValueError("check_to_qubits must be non-empty")
        # Convert first, then test for emptiness. A bare truthiness test raises
        # for multi-element NumPy arrays and always passes for generators.
        c2q = [[int(q) for q in check] for check in check_to_qubits]
        if not c2q:
            raise ValueError("check_to_qubits must be non-empty")
        nq = None if n_qubits is None else int(n_qubits)
        self._inner = _RustCUDABatchDecoder(c2q, nq)

    def batch_decode(self, syndromes):
        """Decode a 2-D batch of syndromes, expected shape ``(batch, n_checks)``.

        Any array-like is cast to C-contiguous uint8 (same value cast as ``astype``).

        Raises:
            ValueError: If the input is not 2-D.
        """
        syndromes = np.ascontiguousarray(syndromes, dtype=np.uint8)
        if syndromes.ndim != 2:
            raise ValueError(f"syndromes must be 2D, got shape {syndromes.shape}")
        return self._inner.batch_decode(syndromes)

    def reset(self):
        """Ask the Rust core to reset its CUDA failure state."""
        self._inner.reset()

    @property
    def n_qubits(self):
        """Number of qubits, as reported by the Rust core."""
        return self._inner.n_qubits

    @property
    def n_checks(self):
        """Number of checks, as reported by the Rust core."""
        return self._inner.n_checks

    @property
    def device_name(self) -> str:
        """CUDA device name reported by the Rust core, as ``str``."""
        return str(self._inner.device_name)

    @property
    def compute_capability(self):
        """Compute capability, as reported by the Rust core."""
        return self._inner.compute_capability

    @property
    def consecutive_failures(self):
        """Consecutive CUDA failures, as reported by the Rust core."""
        return self._inner.consecutive_failures

    @property
    def total_failures(self):
        """Total CUDA failures, as reported by the Rust core."""
        return self._inner.total_failures

    @property
    def is_degraded(self):
        """Degraded (CPU-fallback) flag, as reported by the Rust core."""
        return self._inner.is_degraded

    @property
    def gpu_recoveries(self):
        """GPU recovery count, as reported by the Rust core."""
        return self._inner.gpu_recoveries

    @staticmethod
    def is_available():
        """Return True if a CUDA driver device is available in this build."""
        if _RustCUDABatchDecoder is None:
            return False
        return _RustCUDABatchDecoder.is_available()
491
+
492
+
493
class SparseBlossomDecoder:
    """Region-growing Sparse Blossom decoder with RadixHeap.

    Supports dynamic weight overrides from GNN Pre-Decoder for enriched decoding.

    Args:
        check_to_qubits: One entry per check, each an iterable of the qubit
            indices that check touches. Lists of lists, generators and 2-D
            integer NumPy arrays are all accepted.
        n_qubits: Optional qubit count, forwarded as an int or ``None``.

    Raises:
        ValueError: If ``check_to_qubits`` is ``None`` or contains no checks.
    """

    def __init__(self, check_to_qubits, n_qubits=None):
        if check_to_qubits is None:
            raise ValueError("check_to_qubits must be non-empty")
        # Convert first, then test for emptiness. A bare truthiness test raises
        # for multi-element NumPy arrays and always passes for generators.
        c2q = [[int(q) for q in check] for check in check_to_qubits]
        if not c2q:
            raise ValueError("check_to_qubits must be non-empty")
        nq = None if n_qubits is None else int(n_qubits)
        self._inner = _RustSparseBlossomDecoder(c2q, nq)

    def decode(self, syndrome):
        """Decode one syndrome and return the Rust core's correction.

        Raises:
            TypeError: If ``syndrome`` is an ndarray with a dtype other than uint8.
        """
        if not isinstance(syndrome, np.ndarray):
            syndrome = np.array(syndrome, dtype=np.uint8)
        if syndrome.dtype != np.uint8:
            raise TypeError(f"Syndrome must be dtype uint8, got {syndrome.dtype}")
        # Copy strided views into a C-contiguous buffer; contiguous input is not copied.
        return self._inner.decode(np.ascontiguousarray(syndrome))

    def decode_with_weights(self, syndrome, weights):
        """Decode with per-qubit dynamic weight overrides.

        Args:
            syndrome: np.ndarray of shape (n_checks,) with dtype uint8.
            weights: Iterable of (qubit_id, weight) pairs: tuples, 2-element
                lists, ``dict.items()``, or rows of a ``(k, 2)`` array.

        Returns:
            np.ndarray of shape (n_qubits,) with correction.

        Raises:
            TypeError: If ``syndrome`` is an ndarray with a dtype other than uint8.
        """
        if not isinstance(syndrome, np.ndarray):
            syndrome = np.array(syndrome, dtype=np.uint8)
        if syndrome.dtype != np.uint8:
            raise TypeError(f"Syndrome must be dtype uint8, got {syndrome.dtype}")
        # Normalise every pair to a real tuple with an int qubit id. The native
        # binding is documented to take tuples; lists and array rows would not
        # convert. Weights are forwarded with their original type.
        pairs = [(int(qubit), weight) for qubit, weight in weights]
        return self._inner.decode_with_weights(np.ascontiguousarray(syndrome), pairs)

    def batch_decode(self, syndromes):
        """Decode a 2-D batch of syndromes, expected shape ``(batch, n_checks)``.

        Any array-like is cast to C-contiguous uint8 (same value cast as ``astype``).

        Raises:
            ValueError: If the input is not 2-D.
        """
        syndromes = np.ascontiguousarray(syndromes, dtype=np.uint8)
        if syndromes.ndim != 2:
            raise ValueError(f"syndromes must be 2D, got shape {syndromes.shape}")
        return self._inner.batch_decode(syndromes)

    @property
    def n_qubits(self):
        """Number of qubits, as reported by the Rust core."""
        return self._inner.n_qubits

    @property
    def n_checks(self):
        """Number of checks, as reported by the Rust core."""
        return self._inner.n_checks
547
+
548
+
549
class BPOSDDecoder:
    """Belief Propagation + Ordered Statistics Decoding.

    Min-sum BP with OSD stage for improved LER on complex codes.

    Args:
        check_to_qubits: One entry per check, each an iterable of the qubit
            indices that check touches. Lists of lists, generators and 2-D
            integer NumPy arrays are all accepted.
        n_qubits: Optional qubit count, forwarded as an int or ``None``.
        error_rate: Physical error rate, forwarded to the Rust core unchanged.

    Raises:
        ValueError: If ``check_to_qubits`` is ``None`` or contains no checks.
    """

    def __init__(self, check_to_qubits, n_qubits=None, error_rate=0.1):
        if check_to_qubits is None:
            raise ValueError("check_to_qubits must be non-empty")
        # Convert first, then test for emptiness. A bare truthiness test raises
        # for multi-element NumPy arrays and always passes for generators.
        c2q = [[int(q) for q in check] for check in check_to_qubits]
        if not c2q:
            raise ValueError("check_to_qubits must be non-empty")
        nq = None if n_qubits is None else int(n_qubits)
        self._inner = _RustBPOSDDecoder(c2q, nq, error_rate)

    def decode(self, syndrome):
        """Decode one syndrome (BP followed by OSD) and return the correction.

        Raises:
            TypeError: If ``syndrome`` is an ndarray with a dtype other than uint8.
        """
        if not isinstance(syndrome, np.ndarray):
            syndrome = np.array(syndrome, dtype=np.uint8)
        if syndrome.dtype != np.uint8:
            raise TypeError(f"Syndrome must be dtype uint8, got {syndrome.dtype}")
        # Copy strided views into a C-contiguous buffer; contiguous input is not copied.
        return self._inner.decode(np.ascontiguousarray(syndrome))

    def bp_decode(self, syndrome, max_iterations=20):
        """Run Belief Propagation and return log-likelihood ratios (LLRs) for each qubit.

        Args:
            syndrome: np.ndarray of shape (n_checks,) with dtype uint8.
            max_iterations: Number of BP iterations (default: 20).

        Returns:
            np.ndarray of shape (n_qubits,) with LLR values.
            Positive LLR -> more likely 0, Negative LLR -> more likely 1.

        Raises:
            TypeError: If ``syndrome`` is an ndarray with a dtype other than uint8.
        """
        if not isinstance(syndrome, np.ndarray):
            syndrome = np.array(syndrome, dtype=np.uint8)
        if syndrome.dtype != np.uint8:
            raise TypeError(f"Syndrome must be dtype uint8, got {syndrome.dtype}")
        return self._inner.bp_decode(np.ascontiguousarray(syndrome), max_iterations)

    @property
    def n_qubits(self):
        """Number of qubits, as reported by the Rust core."""
        return self._inner.n_qubits

    @property
    def n_checks(self):
        """Number of checks, as reported by the Rust core."""
        return self._inner.n_checks
593
+
594
+
595
class NeuralPredecoder:
    """Lightweight MLP pre-decoder with Xavier initialization and SGD training.

    Args:
        n_input: Input width (presumably the number of checks; confirm).
        n_output: Output width (presumably the number of qubits; confirm).
        n_hidden1: Optional first hidden-layer width; ``None`` is forwarded as-is.
        n_hidden2: Optional second hidden-layer width; ``None`` is forwarded as-is.
    """

    def __init__(self, n_input, n_output, n_hidden1=None, n_hidden2=None):
        self._inner = _RustNeuralPredecoder(n_input, n_output, n_hidden1, n_hidden2)

    @staticmethod
    def _as_uint8(values):
        """Return ``values`` as a C-contiguous uint8 ndarray.

        Lists and arrays of any dtype (bool, int64, ...) go through the same
        conversion, so the Rust core always receives uint8 data.
        """
        return np.ascontiguousarray(values, dtype=np.uint8)

    def train(self, syndromes, corrections, n_epochs, learning_rate=0.01):
        """Train the MLP on paired syndrome/correction samples.

        Args:
            syndromes: Array-like of syndromes, cast to uint8.
            corrections: Array-like of target corrections, cast to uint8.
            n_epochs: Number of training epochs.
            learning_rate: SGD learning rate.
        """
        self._inner.train(
            self._as_uint8(syndromes), self._as_uint8(corrections), n_epochs, learning_rate
        )

    def predict(self, syndrome):
        """Return the Rust core's prediction for one syndrome (cast to uint8)."""
        return self._inner.predict(self._as_uint8(syndrome))

    def decode(self, syndrome):
        """Return the Rust core's decoded correction for one syndrome (cast to uint8)."""
        return self._inner.decode(self._as_uint8(syndrome))

    @property
    def n_input(self):
        """Input width, as reported by the Rust core."""
        return self._inner.n_input

    @property
    def n_output(self):
        """Output width, as reported by the Rust core."""
        return self._inner.n_output

    @property
    def n_hidden1(self):
        """First hidden-layer width, as reported by the Rust core."""
        return self._inner.n_hidden1

    @property
    def n_hidden2(self):
        """Second hidden-layer width, as reported by the Rust core."""
        return self._inner.n_hidden2
633
+
634
+
635
class GNNPredecoder:
    """Graph Neural Network pre-decoder for dynamic edge-weight prediction.

    A message-passing network (MPNN) followed by an MLP readout predicts
    adjusted edge weights for SparseBlossom. ``n_layers`` (default 2 for this
    class) presumably sets the MPNN depth. All layers are trainable; gradients
    are backpropagated through the MPNN layers.

    Feature dimensions must match the DetectorGraph:
      - node_feat_dim = 10 (NodeFeatures::DIM)
      - edge_feat_dim = 8 (EdgeFeatures::DIM)
    """

    # Standard dimensions matching DetectorGraph
    NODE_FEAT_DIM = 10
    EDGE_FEAT_DIM = 8

    def __init__(self, node_feat_dim=None, edge_feat_dim=None, hidden_size=16, n_layers=2):
        """Create a GNNPredecoder.

        If node_feat_dim and edge_feat_dim are not provided, uses the standard
        dimensions matching DetectorGraph (10 and 8).

        Args:
            node_feat_dim: Node feature width, or ``None`` for NODE_FEAT_DIM.
            edge_feat_dim: Edge feature width, or ``None`` for EDGE_FEAT_DIM.
            hidden_size: Hidden width forwarded to the Rust core.
            n_layers: Layer count forwarded to the Rust core.
        """
        nfd = node_feat_dim if node_feat_dim is not None else self.NODE_FEAT_DIM
        efd = edge_feat_dim if edge_feat_dim is not None else self.EDGE_FEAT_DIM
        self._inner = _RustGNNPredecoder(nfd, efd, hidden_size, n_layers)

    @classmethod
    def new_standard(cls, hidden_size=16, n_layers=2):
        """Create a GNNPredecoder with standard dimensions matching DetectorGraph."""
        return cls(cls.NODE_FEAT_DIM, cls.EDGE_FEAT_DIM, hidden_size, n_layers)

    @property
    def learning_rate(self):
        """Training learning rate stored on the Rust core (read/write)."""
        return self._inner.learning_rate

    @learning_rate.setter
    def learning_rate(self, lr):
        self._inner.learning_rate = lr

    @property
    def l2_lambda(self):
        """L2 regularisation coefficient stored on the Rust core (read/write)."""
        return self._inner.l2_lambda

    @l2_lambda.setter
    def l2_lambda(self, val):
        self._inner.l2_lambda = val

    def forward(self, graph):
        """Predict adjusted edge weights for a DetectorGraph.

        Accepts either the Python ``DetectorGraph`` wrapper (unwrapped here) or
        a native graph object, which is passed through unchanged.
        """
        return self._inner.forward(graph._inner if isinstance(graph, DetectorGraph) else graph)

    def train(self, graphs, targets, n_epochs):
        """Train the GNN on a list of graphs and target edge weights.

        ``DetectorGraph`` wrappers in ``graphs`` are unwrapped; ``targets`` is
        forwarded unchanged. Returns whatever the Rust core's ``train`` returns.
        """
        inner_graphs = [g._inner if isinstance(g, DetectorGraph) else g for g in graphs]
        return self._inner.train(inner_graphs, targets, n_epochs)

    def predict_with_node_probs(self, graph):
        """Predict edge weights and node error probabilities."""
        return self._inner.predict_with_node_probs(graph._inner if isinstance(graph, DetectorGraph) else graph)
694
+
695
+
696
class DetectorGraph:
    """Detector graph used by GNNPredecoder and the hybrid decoding path.

    The feature widths are exposed as class constants with the same values
    GNNPredecoder uses for its standard dimensions.
    """

    NODE_FEAT_DIM = 10
    EDGE_FEAT_DIM = 8

    def __init__(
        self,
        check_to_qubits,
        syndrome,
        check_positions=None,
        check_types=None,
        base_weights=None,
        n_qubits=None,
    ):
        # Coerce checks and syndrome bits to plain ints before handing them to
        # the native constructor. The remaining options are forwarded untouched.
        checks = [list(map(int, qubits)) for qubits in check_to_qubits]
        bits = list(map(int, syndrome))
        self._inner = _RustDetectorGraph(
            checks, bits, check_positions, check_types, base_weights, n_qubits
        )

    def update_syndrome(self, syndrome):
        """Replace the graph's syndrome (bits coerced to int) on the Rust side."""
        bits = list(map(int, syndrome))
        self._inner.update_syndrome(bits)

    @property
    def n_nodes(self):
        """Node count, as reported by the Rust core."""
        return self._inner.n_nodes

    @property
    def n_edges(self):
        """Edge count, as reported by the Rust core."""
        return self._inner.n_edges

    @property
    def node_features(self):
        """Node feature data, as reported by the Rust core."""
        return self._inner.node_features

    @property
    def edge_features(self):
        """Edge feature data, as reported by the Rust core."""
        return self._inner.edge_features

    @property
    def edge_qubit_id(self):
        """Per-edge qubit ids, as reported by the Rust core."""
        return self._inner.edge_qubit_id
744
+
745
+
746
class GNNTrainer:
    """End-to-end GNN training pipeline with Blossom teacher model.

    Generates random syndromes, computes optimal corrections via BlossomDecoder,
    extracts target edge weights, and trains the GNN via SGD.

    Args:
        check_to_qubits: One entry per check, each an iterable of qubit indices.
        n_qubits: Qubit count, coerced to int (``None`` forwarded as-is).
        error_rate: Physical error rate forwarded to the Rust core.
    """

    def __init__(self, check_to_qubits, n_qubits, error_rate=0.1):
        c2q = [[int(q) for q in check] for check in check_to_qubits]
        # Same coercion as the decoder wrappers, so NumPy ints are accepted.
        nq = None if n_qubits is None else int(n_qubits)
        self._inner = _RustGNNTrainer(c2q, nq, error_rate)

    @staticmethod
    def _unwrap(gnn):
        """Return the native predecoder for a Python wrapper; pass natives through."""
        return getattr(gnn, "_inner", gnn)

    def train(self, gnn, n_samples, n_epochs):
        """Train a GNNPredecoder and return the final MSE loss."""
        return self._inner.train(self._unwrap(gnn), n_samples, n_epochs)

    def train_bp(self, gnn, n_samples, n_epochs, max_bp_iter=20):
        """Train a GNNPredecoder with BP marginal targets and return the final MSE loss."""
        return self._inner.train_bp(self._unwrap(gnn), n_samples, n_epochs, max_bp_iter)

    def generate_dataset(self, n_samples):
        """Generate a training dataset and return its size."""
        return self._inner.generate_dataset(n_samples)
768
+
769
+
770
class HybridDecoder:
    """GNN Pre-Decoder + SparseBlossom hybrid decoder.

    Uses a lightweight MPNN to estimate dynamic edge weights, then passes
    them to SparseBlossom for enriched region-growing decoding.

    Args:
        check_to_qubits: One entry per check, each an iterable of the qubit
            indices that check touches. Lists of lists, generators and 2-D
            integer NumPy arrays are all accepted.
        n_qubits: Optional qubit count, forwarded as an int or ``None``.
        check_positions: Optional, forwarded to the Rust core unchanged.
        check_types: Optional, forwarded to the Rust core unchanged.
        base_weights: Optional, forwarded to the Rust core unchanged.
        gnn_hidden_size: Hidden width of the internal GNN.
        gnn_n_layers: Layer count of the internal GNN.

    Raises:
        ValueError: If ``check_to_qubits`` is ``None`` or contains no checks.
    """

    def __init__(
        self,
        check_to_qubits,
        n_qubits=None,
        check_positions=None,
        check_types=None,
        base_weights=None,
        gnn_hidden_size=64,
        gnn_n_layers=3,
    ):
        if check_to_qubits is None:
            raise ValueError("check_to_qubits must be non-empty")
        # Convert first, then test for emptiness. A bare truthiness test raises
        # for multi-element NumPy arrays and always passes for generators.
        c2q = [[int(q) for q in check] for check in check_to_qubits]
        if not c2q:
            raise ValueError("check_to_qubits must be non-empty")
        nq = None if n_qubits is None else int(n_qubits)
        self._inner = _RustHybridDecoder(
            c2q, nq, check_positions, check_types, base_weights, gnn_hidden_size, gnn_n_layers
        )

    @staticmethod
    def _prepare_syndrome(syndrome):
        """Return one syndrome as C-contiguous uint8; TypeError for non-uint8 arrays."""
        if not isinstance(syndrome, np.ndarray):
            syndrome = np.array(syndrome, dtype=np.uint8)
        if syndrome.dtype != np.uint8:
            raise TypeError(f"Syndrome must be dtype uint8, got {syndrome.dtype}")
        return np.ascontiguousarray(syndrome)

    @staticmethod
    def _prepare_batch(syndromes):
        """Return a 2-D batch as C-contiguous uint8.

        Raises TypeError for non-uint8 arrays and ValueError for non-2-D input.
        """
        if not isinstance(syndromes, np.ndarray):
            syndromes = np.array(syndromes, dtype=np.uint8)
        if syndromes.dtype != np.uint8:
            raise TypeError(f"Syndromes must be dtype uint8, got {syndromes.dtype}")
        if syndromes.ndim != 2:
            raise ValueError(f"Expected 2D array, got shape {syndromes.shape}")
        return np.ascontiguousarray(syndromes)

    def decode_hybrid(self, syndrome):
        """Decode one syndrome with GNN-predicted weights + SparseBlossom."""
        return self._inner.decode_hybrid(self._prepare_syndrome(syndrome))

    def decode_heuristic(self, syndrome):
        """Decode one syndrome via the Rust core's heuristic path."""
        return self._inner.decode_heuristic(self._prepare_syndrome(syndrome))

    def decode_standard(self, syndrome):
        """Decode one syndrome via the Rust core's standard path."""
        return self._inner.decode_standard(self._prepare_syndrome(syndrome))

    def batch_decode_hybrid(self, syndromes):
        """Batch decode multiple syndromes using the GNN-enhanced pipeline.

        Args:
            syndromes: np.ndarray of shape (batch, n_checks) or list of lists.

        Returns:
            np.ndarray of shape (batch, n_qubits) with corrections.
        """
        return self._inner.batch_decode_hybrid(self._prepare_batch(syndromes))

    def batch_decode_standard(self, syndromes):
        """Batch decode multiple syndromes using standard SparseBlossom.

        Args:
            syndromes: np.ndarray of shape (batch, n_checks) or list of lists.

        Returns:
            np.ndarray of shape (batch, n_qubits) with corrections.
        """
        return self._inner.batch_decode_standard(self._prepare_batch(syndromes))

    def train(self, n_samples, n_epochs, error_rate=0.1):
        """Train the internal GNN using a Blossom teacher model.

        Generates random syndromes, computes optimal corrections via Blossom,
        and trains the GNN via SGD to predict edge weights.

        Args:
            n_samples: Number of training examples to generate.
            n_epochs: Number of training epochs.
            error_rate: Physical error rate for syndrome generation.

        Returns:
            Final MSE loss after training.
        """
        return self._inner.train(n_samples, n_epochs, error_rate)

    def train_bp(self, n_samples, n_epochs, error_rate=0.1, max_bp_iter=20):
        """Train the internal GNN using BP marginal probability targets.

        Generates random syndromes, computes marginal error probabilities via
        Belief Propagation (min-sum), and trains the GNN to predict these
        probabilities as edge weights.

        Args:
            n_samples: Number of training examples to generate.
            n_epochs: Number of training epochs.
            error_rate: Physical error rate for syndrome generation.
            max_bp_iter: Number of BP iterations for marginal computation.

        Returns:
            Final MSE loss after training.
        """
        return self._inner.train_bp(n_samples, n_epochs, error_rate, max_bp_iter)

    @property
    def n_qubits(self):
        """Number of qubits, as reported by the Rust core."""
        return self._inner.n_qubits

    @property
    def n_checks(self):
        """Number of checks, as reported by the Rust core."""
        return self._inner.n_checks
891
+
892
+
893
class LookupTableDecoder:
    """Exact lookup-table decoder with UnionFind fallback.

    Pre-computes all syndrome → correction mappings for small codes
    (n_qubits ≤ 20, exhaustive; otherwise low-weight enumeration).
    Decoding is O(1) for precomputed syndromes, fallback to UnionFind otherwise.

    Args:
        check_to_qubits: One entry per check, each an iterable of the qubit
            indices that check touches. Lists of lists, generators and 2-D
            integer NumPy arrays are all accepted.
        n_qubits: Optional qubit count, forwarded as an int or ``None``.

    Raises:
        ValueError: If ``check_to_qubits`` is ``None`` or contains no checks.
    """

    def __init__(self, check_to_qubits, n_qubits=None):
        if check_to_qubits is None:
            raise ValueError("check_to_qubits must be non-empty")
        # Convert first, then test for emptiness. A bare truthiness test raises
        # for multi-element NumPy arrays and always passes for generators.
        c2q = [[int(q) for q in check] for check in check_to_qubits]
        if not c2q:
            raise ValueError("check_to_qubits must be non-empty")
        nq = None if n_qubits is None else int(n_qubits)
        self._inner = _RustLookupTableDecoder(c2q, nq)

    def build_table(self, max_entries):
        """Populate the lookup table by enumerating errors up to max_entries."""
        self._inner.build_table(int(max_entries))

    def decode(self, syndrome):
        """Decode one syndrome and return the Rust core's correction.

        Raises:
            TypeError: If ``syndrome`` is an ndarray with a dtype other than uint8.
        """
        if not isinstance(syndrome, np.ndarray):
            syndrome = np.array(syndrome, dtype=np.uint8)
        if syndrome.dtype != np.uint8:
            raise TypeError(f"Syndrome must be dtype uint8, got {syndrome.dtype}")
        # Copy strided views into a C-contiguous buffer; contiguous input is not copied.
        return self._inner.decode(np.ascontiguousarray(syndrome))

    def batch_decode(self, syndromes):
        """Decode a 2-D batch of syndromes, expected shape ``(batch, n_checks)``.

        Any array-like is cast to C-contiguous uint8 (same value cast as ``astype``).

        Raises:
            ValueError: If the input is not 2-D.
        """
        syndromes = np.ascontiguousarray(syndromes, dtype=np.uint8)
        if syndromes.ndim != 2:
            raise ValueError(f"syndromes must be 2D, got shape {syndromes.shape}")
        return self._inner.batch_decode(syndromes)

    @property
    def n_qubits(self):
        """Number of qubits, as reported by the Rust core."""
        return self._inner.n_qubits

    @property
    def n_checks(self):
        """Number of checks, as reported by the Rust core."""
        return self._inner.n_checks

    @property
    def table_size(self):
        """Number of table entries, as reported by the Rust core."""
        return self._inner.table_size
939
+
940
+
941
def check_to_edges(check_to_qubits):
    """Translate a check → qubits incidence list into an edge list.

    Every qubit index is coerced with ``int`` before the list is handed to
    the compiled ``py_check_to_edges``.
    """
    normalized = [list(map(int, check)) for check in check_to_qubits]
    return py_check_to_edges(normalized)
945
+
946
+
947
def generate_surface_code_checks(distance):
    """Return the checks of the compact periodic surface code used by this API.

    ``distance`` is coerced with ``int`` before reaching the compiled generator.
    """
    d = int(distance)
    return py_generate_surface_code_checks(d)
950
+
951
+
952
def generate_toy_code_checks(distance):
    """Return checks for the legacy toy code with ``distance * distance`` qubits.

    This is not a proper quantum code. It is retained only for backward
    compatibility and reference; use ``generate_surface_code_checks`` for
    real QEC tests.
    """
    d = int(distance)
    return py_generate_toy_code_checks(d)
958
+
959
+
960
def generate_ring_code_checks(distance):
    """Return checks for a simple 1D ring code, intended for testing."""
    d = int(distance)
    return py_generate_ring_code_checks(d)
963
+
964
+
965
def generate_repetition_code_checks(distance):
    """Return checks for a 1D repetition (chain) code, intended for testing."""
    d = int(distance)
    return py_generate_repetition_code_checks(d)
968
+
969
+
970
class BenchmarkSuite:
    """Production benchmark suite. Wraps Rust-native benchmarking.

    Inputs are validated and normalised here; the benchmark itself runs in
    the compiled ``_RustBenchmarkSuite``, which reports results as JSON.
    """

    def __init__(self, check_to_qubits, n_qubits=None, n_samples=10000, seed=42):
        """Create the suite.

        Args:
            check_to_qubits: Iterable of checks, each an iterable of integer
                qubit indices. Lists, tuples, generators and 2-D NumPy integer
                arrays are accepted.
            n_qubits: Optional total qubit count; ``None`` is forwarded as-is.
            n_samples: Number of samples to benchmark (coerced with ``int``).
            seed: RNG seed forwarded to the compiled suite (coerced with ``int``).

        Raises:
            ValueError: If ``check_to_qubits`` is ``None`` or contains no checks.
        """
        if check_to_qubits is None:
            raise ValueError("check_to_qubits must be non-empty")
        # Normalise before the emptiness test: ``not`` on a multi-element NumPy
        # array raises, and an empty generator is truthy.
        c2q = [[int(q) for q in check] for check in check_to_qubits]
        if not c2q:
            raise ValueError("check_to_qubits must be non-empty")
        nq = None if n_qubits is None else int(n_qubits)
        # Coerce like every other numeric argument in this module, so NumPy
        # scalars and integral floats are accepted by the compiled core.
        n_samples = int(n_samples)
        seed = int(seed)
        self._inner = _RustBenchmarkSuite(c2q, nq, n_samples, seed)
        self.n_samples = n_samples

    def run(self):
        """Run the benchmark and return the compiled suite's JSON output, parsed."""
        import json

        raw = self._inner.run()
        return json.loads(raw)

    def save(self, path, results):
        """Write ``results`` to ``path`` as UTF-8 JSON indented by 2 spaces."""
        import json
        from pathlib import Path

        Path(path).write_text(json.dumps(results, indent=2), encoding="utf-8")
992
+
993
+
994
# Public API of the compiled-core layer, grouped by role (order is significant
# and preserved). NOTE(review): "start_metrics_server" and "run_grpc_server"
# are not among the compiled imports at the top of this file; they must be
# defined elsewhere in this module or `from qector_decoder_v3 import *` will
# raise AttributeError — confirm.
__all__ = (
    [
        # Decoder classes
        "UnionFindDecoder",
        "FastUnionFindDecoder",
        "BlossomDecoder",
        "SlidingWindowDecoder",
        "StreamingDecoder",
        "BatchDecoder",
        "CPUBatchDecoder",
        "OpenCLBatchDecoder",
        "CUDABatchDecoder",
        "SparseBlossomDecoder",
        "BPOSDDecoder",
        "NeuralPredecoder",
        "DetectorGraph",
        "GNNPredecoder",
        "GNNTrainer",
        "HybridDecoder",
        "LookupTableDecoder",
    ]
    + [
        # Benchmarking
        "BenchmarkSuite",
    ]
    + [
        # Check / code generation helpers
        "check_to_edges",
        "generate_surface_code_checks",
        "generate_toy_code_checks",
        "generate_ring_code_checks",
        "generate_repetition_code_checks",
    ]
    + [
        # Servers and accelerator-availability probes
        "start_metrics_server",
        "run_mcp_server",
        "cuda_is_available",
        "opencl_is_available",
        "run_grpc_server",
    ]
)
1024
+
1025
+
1026
+ # Ecosystem / tooling layer (pure-Python, built on the compiled core)
1027
+ from . import (
1028
+ codes,
1029
+ dem,
1030
+ result,
1031
+ backend,
1032
+ pymatching_compat,
1033
+ benchmarking,
1034
+ belief_matching,
1035
+ bposd,
1036
+ predecoder,
1037
+ )
1038
+ from . import workbench
1039
+ from .backend import AutoDecoder, BackendConfig, Backend
1040
+ from .result import DecodeResult, decode_with_diagnostics
1041
+ from .belief_matching import BeliefMatching
1042
+ from .bposd import BpOsdDecoder
1043
+ from .predecoder import PredecodedDecoder
1044
+ from .workbench import Workbench
1045
+
1046
# `sinter_compat` imports the optional `sinter` package lazily. Import it
# best-effort: any failure binds the name to None instead of breaking the
# whole package import.
try:
    from . import sinter_compat
except Exception:  # pragma: no cover
    sinter_compat = None  # type: ignore[assignment]

# Export the pure-Python ecosystem layer alongside the compiled core.
__all__.extend(
    [
        # Submodules
        "codes",
        "dem",
        "result",
        "backend",
        "pymatching_compat",
        "benchmarking",
        "belief_matching",
        "bposd",
        "predecoder",
        "sinter_compat",
        # Re-exported classes and helpers
        "AutoDecoder",
        "BackendConfig",
        "Backend",
        "DecodeResult",
        "decode_with_diagnostics",
        "BeliefMatching",
        "BpOsdDecoder",
        "PredecodedDecoder",
        "workbench",
        "Workbench",
    ]
)
1074
+
1075
# Optional third-party integrations. Each one is imported best-effort: if the
# import raises any exception, the name is bound to None so the rest of the
# package stays importable.
try:
    from . import qiskit_plugin
except Exception:  # pragma: no cover
    qiskit_plugin = None  # type: ignore[assignment]

try:
    from . import stim_compat
except Exception:  # pragma: no cover
    stim_compat = None  # type: ignore[assignment]

try:
    from . import rest_api
except Exception:  # pragma: no cover
    rest_api = None  # type: ignore[assignment]

__all__.extend(["qiskit_plugin", "stim_compat", "rest_api"])