sinter-1.15.0-py3-none-any.whl
- sinter/__init__.py +47 -0
- sinter/_collection/__init__.py +10 -0
- sinter/_collection/_collection.py +480 -0
- sinter/_collection/_collection_manager.py +581 -0
- sinter/_collection/_collection_manager_test.py +287 -0
- sinter/_collection/_collection_test.py +317 -0
- sinter/_collection/_collection_worker_loop.py +35 -0
- sinter/_collection/_collection_worker_state.py +259 -0
- sinter/_collection/_collection_worker_test.py +222 -0
- sinter/_collection/_mux_sampler.py +56 -0
- sinter/_collection/_printer.py +65 -0
- sinter/_collection/_sampler_ramp_throttled.py +66 -0
- sinter/_collection/_sampler_ramp_throttled_test.py +144 -0
- sinter/_command/__init__.py +0 -0
- sinter/_command/_main.py +39 -0
- sinter/_command/_main_collect.py +350 -0
- sinter/_command/_main_collect_test.py +482 -0
- sinter/_command/_main_combine.py +84 -0
- sinter/_command/_main_combine_test.py +153 -0
- sinter/_command/_main_plot.py +817 -0
- sinter/_command/_main_plot_test.py +445 -0
- sinter/_command/_main_predict.py +75 -0
- sinter/_command/_main_predict_test.py +36 -0
- sinter/_data/__init__.py +20 -0
- sinter/_data/_anon_task_stats.py +89 -0
- sinter/_data/_anon_task_stats_test.py +35 -0
- sinter/_data/_collection_options.py +106 -0
- sinter/_data/_collection_options_test.py +24 -0
- sinter/_data/_csv_out.py +74 -0
- sinter/_data/_existing_data.py +173 -0
- sinter/_data/_existing_data_test.py +41 -0
- sinter/_data/_task.py +311 -0
- sinter/_data/_task_stats.py +244 -0
- sinter/_data/_task_stats_test.py +140 -0
- sinter/_data/_task_test.py +38 -0
- sinter/_decoding/__init__.py +16 -0
- sinter/_decoding/_decoding.py +419 -0
- sinter/_decoding/_decoding_all_built_in_decoders.py +25 -0
- sinter/_decoding/_decoding_decoder_class.py +161 -0
- sinter/_decoding/_decoding_fusion_blossom.py +193 -0
- sinter/_decoding/_decoding_mwpf.py +302 -0
- sinter/_decoding/_decoding_pymatching.py +81 -0
- sinter/_decoding/_decoding_test.py +480 -0
- sinter/_decoding/_decoding_vacuous.py +38 -0
- sinter/_decoding/_perfectionist_sampler.py +38 -0
- sinter/_decoding/_sampler.py +72 -0
- sinter/_decoding/_stim_then_decode_sampler.py +222 -0
- sinter/_decoding/_stim_then_decode_sampler_test.py +192 -0
- sinter/_plotting.py +619 -0
- sinter/_plotting_test.py +108 -0
- sinter/_predict.py +381 -0
- sinter/_predict_test.py +227 -0
- sinter/_probability_util.py +519 -0
- sinter/_probability_util_test.py +281 -0
- sinter-1.15.0.data/data/README.md +332 -0
- sinter-1.15.0.data/data/readme_example_plot.png +0 -0
- sinter-1.15.0.data/data/requirements.txt +4 -0
- sinter-1.15.0.dist-info/METADATA +354 -0
- sinter-1.15.0.dist-info/RECORD +62 -0
- sinter-1.15.0.dist-info/WHEEL +5 -0
- sinter-1.15.0.dist-info/entry_points.txt +2 -0
- sinter-1.15.0.dist-info/top_level.txt +1 -0
sinter/_decoding/_decoding_decoder_class.py
@@ -0,0 +1,161 @@
import abc
import pathlib

import numpy as np
import stim


class CompiledDecoder(metaclass=abc.ABCMeta):
    """Abstract class for decoders preconfigured to a specific decoding task.

    This is the type returned by `sinter.Decoder.compile_decoder_for_dem`. The
    idea is that, when many shots of the same decoding task are going to be
    performed, it is valuable to pay the cost of configuring the decoder only
    once instead of once per batch of shots. Custom decoders can optionally
    implement that method, and return this type, to increase sampling
    efficiency.
    """

    @abc.abstractmethod
    def decode_shots_bit_packed(
            self,
            *,
            bit_packed_detection_event_data: np.ndarray,
    ) -> np.ndarray:
        """Predicts observable flips from the given detection events.

        All data taken and returned must be bit packed with bitorder='little'.

        Args:
            bit_packed_detection_event_data: Detection event data stored as a
                bit packed numpy array. The numpy array will have the following
                dtype/shape:

                    dtype: uint8
                    shape: (num_shots, ceil(dem.num_detectors / 8))

                where `num_shots` is the number of shots to decode and `dem` is
                the detector error model this instance was compiled to decode.

                It's guaranteed that the data will be laid out in memory so
                that detection events within a shot are contiguous in memory
                (i.e. that bit_packed_detection_event_data.strides[1] == 1).

        Returns:
            Bit packed observable flip data stored as a bit packed numpy array.
            The numpy array must have the following dtype/shape:

                dtype: uint8
                shape: (num_shots, ceil(dem.num_observables / 8))

            where `num_shots` is bit_packed_detection_event_data.shape[0] and
            `dem` is the detector error model this instance was compiled to
            decode.
        """
        pass


class Decoder:
    """Abstract base class for custom decoders.

    Custom decoders can be explained to sinter by inheriting from this class
    and implementing its methods.

    Decoder classes MUST be serializable (e.g. via pickling), so that they can
    be given to worker processes when using python multiprocessing.

    Child classes should implement `compile_decoder_for_dem`, but (for legacy
    reasons) can alternatively implement `decode_via_files`. At least one of
    the two methods must be implemented.
    """

    def compile_decoder_for_dem(
            self,
            *,
            dem: stim.DetectorErrorModel,
    ) -> CompiledDecoder:
        """Creates a decoder preconfigured for the given detector error model.

        This method is optional to implement. By default, it will raise a
        NotImplementedError. When sampling, sinter will attempt to use this
        method first, and otherwise fall back to using `decode_via_files`.

        The idea is that the preconfigured decoder amortizes the cost of
        configuration over more calls. This makes smaller batch sizes
        efficient, reducing the amount of memory used for storing each batch,
        improving overall efficiency.

        Args:
            dem: A detector error model for the samples that will need to be
                decoded. What to configure the decoder to decode.

        Returns:
            An instance of `sinter.CompiledDecoder` that can be used to invoke
            the preconfigured decoder.

        Raises:
            NotImplementedError: This `sinter.Decoder` doesn't support
                compiling for a dem.
        """
        raise NotImplementedError('compile_decoder_for_dem')

    def decode_via_files(self,
                         *,
                         num_shots: int,
                         num_dets: int,
                         num_obs: int,
                         dem_path: pathlib.Path,
                         dets_b8_in_path: pathlib.Path,
                         obs_predictions_b8_out_path: pathlib.Path,
                         tmp_dir: pathlib.Path,
                         ) -> None:
        """Performs decoding by reading/writing problems and answers from disk.

        Args:
            num_shots: The number of times the circuit was sampled. The number
                of problems to be solved.
            num_dets: The number of detectors in the circuit. The number of
                detection event bits in each shot.
            num_obs: The number of observables in the circuit. The number of
                predicted bits in each shot.
            dem_path: The file path where the detector error model should be
                read from, e.g. using `stim.DetectorErrorModel.from_file`. The
                error mechanisms specified by the detector error model should
                be used to configure the decoder.
            dets_b8_in_path: The file path that detection event data should be
                read from. Note that the file may be a named pipe instead of a
                fixed size object. The detection events will be in b8 format
                (see
                https://github.com/quantumlib/Stim/blob/main/doc/result_formats.md
                ). The number of detection events per shot is available via
                the `num_dets` argument or via the detector error model at
                `dem_path`.
            obs_predictions_b8_out_path: The file path that decoder predictions
                must be written to. The predictions must be written in b8
                format (see
                https://github.com/quantumlib/Stim/blob/main/doc/result_formats.md
                ). The number of observables per shot is available via the
                `num_obs` argument or via the detector error model at
                `dem_path`.
            tmp_dir: Any temporary files generated by the decoder during its
                operation MUST be put into this directory. The reason for this
                requirement is that sinter is allowed to kill the decoding
                process without warning, without giving it time to clean up
                any temporary objects. All cleanup should be done via sinter
                deleting this directory after killing the decoder.
        """
        dem = stim.DetectorErrorModel.from_file(dem_path)
        try:
            compiled = self.compile_decoder_for_dem(dem=dem)
        except NotImplementedError as ex:
            raise NotImplementedError(
                f"{type(self).__qualname__} didn't implement "
                f"`compile_decoder_for_dem` or `decode_via_files`.") from ex

        num_det_bytes = -(-num_dets // 8)  # ceil(num_dets / 8)
        num_obs_bytes = -(-num_obs // 8)  # ceil(num_obs / 8)
        dets = np.fromfile(dets_b8_in_path, dtype=np.uint8, count=num_shots * num_det_bytes)
        dets = dets.reshape(num_shots, num_det_bytes)
        obs = compiled.decode_shots_bit_packed(bit_packed_detection_event_data=dets)
        if obs.dtype != np.uint8 or obs.shape != (num_shots, num_obs_bytes):
            raise ValueError(
                f"Got a numpy array with dtype={obs.dtype}, shape={obs.shape} "
                f"instead of dtype={np.uint8}, shape={(num_shots, num_obs_bytes)} "
                f"from {type(self).__qualname__}(...).compile_decoder_for_dem(...)"
                f".decode_shots_bit_packed(...).")
        obs.tofile(obs_predictions_b8_out_path)
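
Taken together, the two classes above are sinter's plug-in surface for custom decoders. As a minimal sketch of the contract (the `ZeroPrediction*` names are illustrative, not part of the package), a decoder that always predicts that no observable flipped could look like this:

import numpy as np
import stim
import sinter


class ZeroPredictionCompiledDecoder(sinter.CompiledDecoder):
    """Illustrative compiled decoder that predicts no observable ever flips."""

    def __init__(self, num_obs: int):
        self.num_obs = num_obs

    def decode_shots_bit_packed(
            self,
            *,
            bit_packed_detection_event_data: np.ndarray,
    ) -> np.ndarray:
        num_shots = bit_packed_detection_event_data.shape[0]
        # One little-endian bit per observable, packed into uint8 bytes, to
        # satisfy the dtype/shape contract documented above.
        return np.zeros(shape=(num_shots, (self.num_obs + 7) // 8), dtype=np.uint8)


class ZeroPredictionDecoder(sinter.Decoder):
    def compile_decoder_for_dem(self, *, dem: stim.DetectorErrorModel) -> sinter.CompiledDecoder:
        return ZeroPredictionCompiledDecoder(num_obs=dem.num_observables)

A class like this can then be supplied to collection through `sinter.collect`'s `custom_decoders` argument, keyed by the name used in the task's `decoders` list.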

sinter/_decoding/_decoding_fusion_blossom.py
@@ -0,0 +1,193 @@
import math
import pathlib
from typing import Callable, List, TYPE_CHECKING, Tuple

import numpy as np
import stim

from sinter._decoding._decoding_decoder_class import Decoder, CompiledDecoder

if TYPE_CHECKING:
    import fusion_blossom


class FusionBlossomCompiledDecoder(CompiledDecoder):
    def __init__(self, solver: 'fusion_blossom.SolverSerial', fault_masks: 'np.ndarray', num_dets: int, num_obs: int):
        self.solver = solver
        self.fault_masks = fault_masks
        self.num_dets = num_dets
        self.num_obs = num_obs

    def decode_shots_bit_packed(
            self,
            *,
            bit_packed_detection_event_data: 'np.ndarray',
    ) -> 'np.ndarray':
        num_shots = bit_packed_detection_event_data.shape[0]
        predictions = np.zeros(shape=(num_shots, (self.num_obs + 7) // 8), dtype=np.uint8)
        import fusion_blossom
        for shot in range(num_shots):
            dets_sparse = np.flatnonzero(np.unpackbits(bit_packed_detection_event_data[shot], count=self.num_dets, bitorder='little'))
            syndrome = fusion_blossom.SyndromePattern(syndrome_vertices=dets_sparse)
            self.solver.solve(syndrome)
            prediction = int(np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()]))
            predictions[shot] = np.packbits(
                np.array(list(np.binary_repr(prediction, width=self.num_obs))[::-1], dtype=np.uint8),
                bitorder="little",
            )
            self.solver.clear()
        return predictions


class FusionBlossomDecoder(Decoder):
    """Use fusion blossom to predict observables from detection events."""

    def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> CompiledDecoder:
        try:
            import fusion_blossom
        except ImportError as ex:
            raise ImportError(
                "The decoder 'fusion_blossom' isn't installed.\n"
                "To fix this, install the python package 'fusion-blossom' into your environment.\n"
                "For example, if you are using pip, run `pip install fusion-blossom`.\n"
            ) from ex

        solver, fault_masks = detector_error_model_to_fusion_blossom_solver_and_fault_masks(dem)
        return FusionBlossomCompiledDecoder(solver, fault_masks, dem.num_detectors, dem.num_observables)

    def decode_via_files(self,
                         *,
                         num_shots: int,
                         num_dets: int,
                         num_obs: int,
                         dem_path: pathlib.Path,
                         dets_b8_in_path: pathlib.Path,
                         obs_predictions_b8_out_path: pathlib.Path,
                         tmp_dir: pathlib.Path,
                         ) -> None:
        try:
            import fusion_blossom
        except ImportError as ex:
            raise ImportError(
                "The decoder 'fusion_blossom' isn't installed.\n"
                "To fix this, install the python package 'fusion-blossom' into your environment.\n"
                "For example, if you are using pip, run `pip install fusion-blossom~=0.1.4`.\n"
            ) from ex

        error_model = stim.DetectorErrorModel.from_file(dem_path)
        solver, fault_masks = detector_error_model_to_fusion_blossom_solver_and_fault_masks(error_model)
        num_det_bytes = math.ceil(num_dets / 8)
        with open(dets_b8_in_path, 'rb') as dets_in_f:
            with open(obs_predictions_b8_out_path, 'wb') as obs_out_f:
                for _ in range(num_shots):
                    dets_bit_packed = np.fromfile(dets_in_f, dtype=np.uint8, count=num_det_bytes)
                    if dets_bit_packed.shape != (num_det_bytes,):
                        raise IOError('Missing dets data.')
                    dets_sparse = np.flatnonzero(np.unpackbits(dets_bit_packed, count=num_dets, bitorder='little'))
                    syndrome = fusion_blossom.SyndromePattern(syndrome_vertices=dets_sparse)
                    solver.solve(syndrome)
                    prediction = int(np.bitwise_xor.reduce(fault_masks[solver.subgraph()]))
                    obs_out_f.write(prediction.to_bytes((num_obs + 7) // 8, byteorder='little'))
                    solver.clear()


def iter_flatten_model(model: stim.DetectorErrorModel,
                       handle_error: Callable[[float, List[int], List[int]], None],
                       handle_detector_coords: Callable[[int, np.ndarray], None]):
    det_offset = 0
    coords_offset = np.zeros(100, dtype=np.float64)

    def _helper(m: stim.DetectorErrorModel, reps: int):
        nonlocal det_offset
        nonlocal coords_offset
        for _ in range(reps):
            for instruction in m:
                if isinstance(instruction, stim.DemRepeatBlock):
                    _helper(instruction.body_copy(), instruction.repeat_count)
                elif isinstance(instruction, stim.DemInstruction):
                    if instruction.type == "error":
                        dets: List[int] = []
                        frames: List[int] = []
                        t: stim.DemTarget
                        p = instruction.args_copy()[0]
                        for t in instruction.targets_copy():
                            if t.is_relative_detector_id():
                                dets.append(t.val + det_offset)
                            elif t.is_logical_observable_id():
                                frames.append(t.val)
                            elif t.is_separator():
                                # Treat each component of a decomposed error as an independent error.
                                # (Ideally we could configure some sort of correlated analysis; oh well.)
                                handle_error(p, dets, frames)
                                frames = []
                                dets = []
                        # Handle last component.
                        handle_error(p, dets, frames)
                    elif instruction.type == "shift_detectors":
                        det_offset += instruction.targets_copy()[0]
                        a = np.array(instruction.args_copy())
                        coords_offset[:len(a)] += a
                    elif instruction.type == "detector":
                        a = np.array(instruction.args_copy())
                        for t in instruction.targets_copy():
                            handle_detector_coords(t.val + det_offset, a + coords_offset[:len(a)])
                    elif instruction.type == "logical_observable":
                        pass
                    else:
                        raise NotImplementedError()
                else:
                    raise NotImplementedError()
    _helper(model, 1)


def detector_error_model_to_fusion_blossom_solver_and_fault_masks(model: stim.DetectorErrorModel) -> Tuple['fusion_blossom.SolverSerial', np.ndarray]:
    """Converts a stim detector error model into a fusion-blossom solver and fault masks."""
    import fusion_blossom

    def handle_error(p: float, dets: List[int], frame_changes: List[int]):
        if p == 0:
            return
        if len(dets) == 0:
            # No symptoms for this error.
            # Code probably has distance 1.
            # Accept it and keep going, though of course decoding will probably perform terribly.
            return
        if len(dets) == 1:
            # A single-symptom error becomes an edge to the virtual boundary node.
            dets = [dets[0], num_detectors]
        if len(dets) > 2:
            raise NotImplementedError(
                f"Error with more than 2 symptoms can't become an edge or boundary edge: {dets!r}.")
        if p > 0.5:
            # fusion_blossom doesn't support negative edge weights.
            # Approximate them as weight 0.
            p = 0.5
        weight = math.log((1 - p) / p)
        mask = sum(1 << k for k in frame_changes)
        edges.append((dets[0], dets[1], weight, mask))

    def handle_detector_coords(detector: int, coords: np.ndarray):
        # Detector coordinates aren't needed to build the matching graph.
        pass

    num_detectors = model.num_detectors
    edges: List[Tuple[int, int, float, int]] = []
    iter_flatten_model(
        model,
        handle_error=handle_error,
        handle_detector_coords=handle_detector_coords,
    )
    # Rescale weights to even integers, since the solver expects integer weights.
    max_weight = max(1e-4, max((w for _, _, w, _ in edges), default=1))
    rescaled_edges = [
        (a, b, round(w * 2**10 / max_weight) * 2)
        for a, b, w, _ in edges
    ]
    fault_masks = np.array([e[3] for e in edges], dtype=np.uint64)

    initializer = fusion_blossom.SolverInitializer(
        num_detectors + 1,  # Total number of nodes (including the virtual boundary node).
        rescaled_edges,  # Weighted edges.
        [num_detectors],  # Boundary node.
    )

    return fusion_blossom.SolverSerial(initializer), fault_masks
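
A detail worth calling out in the conversion above: each error probability p becomes the log-likelihood weight ln((1 - p) / p), and all weights are then rescaled to even integers, with the largest landing at 2 * 2**10, because the solver works with integer weights. A standalone sketch of just that arithmetic, assuming nothing beyond the standard library:

import math

def rescale_weights(probabilities):
    # Rarer errors get larger (more costly) log-likelihood weights.
    weights = [math.log((1 - p) / p) for p in probabilities]
    max_weight = max(1e-4, max(weights, default=1))
    # Scale so the largest weight maps to 2 * 2**10, rounding to even integers.
    return [round(w * 2**10 / max_weight) * 2 for w in weights]

print(rescale_weights([0.001, 0.01, 0.1]))  # [2048, 1362, 652]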

sinter/_decoding/_decoding_mwpf.py
@@ -0,0 +1,302 @@
import math
import pathlib
from typing import Callable, List, TYPE_CHECKING, Tuple, Any, Optional

import numpy as np
import stim

from sinter._decoding._decoding_decoder_class import Decoder, CompiledDecoder

if TYPE_CHECKING:
    import mwpf


def mwpf_import_error() -> ImportError:
    return ImportError(
        "The decoder 'MWPF' isn't installed.\n"
        "To fix this, install the python package 'MWPF' into your environment.\n"
        "For example, if you are using pip, run `pip install MWPF~=0.1.5`.\n"
    )


class MwpfCompiledDecoder(CompiledDecoder):
    def __init__(
        self,
        solver: "mwpf.SolverSerialJointSingleHair",
        fault_masks: "np.ndarray",
        num_dets: int,
        num_obs: int,
    ):
        self.solver = solver
        self.fault_masks = fault_masks
        self.num_dets = num_dets
        self.num_obs = num_obs

    def decode_shots_bit_packed(
        self,
        *,
        bit_packed_detection_event_data: "np.ndarray",
    ) -> "np.ndarray":
        num_shots = bit_packed_detection_event_data.shape[0]
        predictions = np.zeros(
            shape=(num_shots, (self.num_obs + 7) // 8), dtype=np.uint8
        )
        import mwpf

        for shot in range(num_shots):
            dets_sparse = np.flatnonzero(
                np.unpackbits(
                    bit_packed_detection_event_data[shot],
                    count=self.num_dets,
                    bitorder="little",
                )
            )
            syndrome = mwpf.SyndromePattern(defect_vertices=dets_sparse)
            if self.solver is None:
                # Degenerate case: there was nothing to build a solver from.
                prediction = 0
            else:
                self.solver.solve(syndrome)
                prediction = int(
                    np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()])
                )
                self.solver.clear()
            predictions[shot] = np.packbits(
                np.array(
                    list(np.binary_repr(prediction, width=self.num_obs))[::-1],
                    dtype=np.uint8,
                ),
                bitorder="little",
            )
        return predictions


class MwpfDecoder(Decoder):
    """Use MWPF to predict observables from detection events."""

    def __init__(
        self,
        # Decoder class used to construct the MWPF decoder. In the Rust
        # implementation, all the solver classes inherit from
        # `SolverSerialPlugins`, just providing different plugins for
        # optimizing the primal and/or dual solutions. For example,
        # `SolverSerialUnionFind` is the most basic solver without any plugin:
        # it only grows the clusters until the first valid solution appears;
        # more optimized solvers use one or more plugins to further optimize
        # the solution, which requires longer decoding time.
        decoder_cls: Any = None,
        cluster_node_limit: int = 50,  # The maximum number of nodes in a cluster.
    ):
        self.decoder_cls = decoder_cls
        self.cluster_node_limit = cluster_node_limit
        super().__init__()

    def compile_decoder_for_dem(
        self,
        *,
        dem: "stim.DetectorErrorModel",
    ) -> CompiledDecoder:
        solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
            dem,
            decoder_cls=self.decoder_cls,
            cluster_node_limit=self.cluster_node_limit,
        )
        return MwpfCompiledDecoder(
            solver,
            fault_masks,
            dem.num_detectors,
            dem.num_observables,
        )

    def decode_via_files(
        self,
        *,
        num_shots: int,
        num_dets: int,
        num_obs: int,
        dem_path: pathlib.Path,
        dets_b8_in_path: pathlib.Path,
        obs_predictions_b8_out_path: pathlib.Path,
        tmp_dir: pathlib.Path,
    ) -> None:
        try:
            import mwpf
        except ImportError as ex:
            raise mwpf_import_error() from ex

        error_model = stim.DetectorErrorModel.from_file(dem_path)
        solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
            error_model,
            decoder_cls=self.decoder_cls,
            cluster_node_limit=self.cluster_node_limit,
        )
        num_det_bytes = math.ceil(num_dets / 8)
        with open(dets_b8_in_path, "rb") as dets_in_f:
            with open(obs_predictions_b8_out_path, "wb") as obs_out_f:
                for _ in range(num_shots):
                    dets_bit_packed = np.fromfile(
                        dets_in_f, dtype=np.uint8, count=num_det_bytes
                    )
                    if dets_bit_packed.shape != (num_det_bytes,):
                        raise IOError("Missing dets data.")
                    dets_sparse = np.flatnonzero(
                        np.unpackbits(
                            dets_bit_packed, count=num_dets, bitorder="little"
                        )
                    )
                    syndrome = mwpf.SyndromePattern(defect_vertices=dets_sparse)
                    if solver is None:
                        prediction = 0
                    else:
                        solver.solve(syndrome)
                        prediction = int(
                            np.bitwise_xor.reduce(fault_masks[solver.subgraph()])
                        )
                        solver.clear()
                    obs_out_f.write(
                        prediction.to_bytes((num_obs + 7) // 8, byteorder="little")
                    )


class HyperUFDecoder(MwpfDecoder):
    """MwpfDecoder configured as a plain hypergraph union-find decoder."""

    def __init__(self):
        super().__init__(decoder_cls="SolverSerialUnionFind", cluster_node_limit=0)


def iter_flatten_model(
    model: stim.DetectorErrorModel,
    handle_error: Callable[[float, List[int], List[int]], None],
    handle_detector_coords: Callable[[int, np.ndarray], None],
):
    det_offset = 0
    coords_offset = np.zeros(100, dtype=np.float64)

    def _helper(m: stim.DetectorErrorModel, reps: int):
        nonlocal det_offset
        nonlocal coords_offset
        for _ in range(reps):
            for instruction in m:
                if isinstance(instruction, stim.DemRepeatBlock):
                    _helper(instruction.body_copy(), instruction.repeat_count)
                elif isinstance(instruction, stim.DemInstruction):
                    if instruction.type == "error":
                        # Use xor-sets so that targets appearing twice cancel out.
                        dets: set[int] = set()
                        frames: set[int] = set()
                        t: stim.DemTarget
                        p = instruction.args_copy()[0]
                        for t in instruction.targets_copy():
                            if t.is_relative_detector_id():
                                dets ^= {t.val + det_offset}
                            elif t.is_logical_observable_id():
                                frames ^= {t.val}
                        handle_error(p, list(dets), list(frames))
                    elif instruction.type == "shift_detectors":
                        det_offset += instruction.targets_copy()[0]
                        a = np.array(instruction.args_copy())
                        coords_offset[: len(a)] += a
                    elif instruction.type == "detector":
                        a = np.array(instruction.args_copy())
                        for t in instruction.targets_copy():
                            handle_detector_coords(
                                t.val + det_offset, a + coords_offset[: len(a)]
                            )
                    elif instruction.type == "logical_observable":
                        pass
                    else:
                        raise NotImplementedError()
                else:
                    raise NotImplementedError()

    _helper(model, 1)


def deduplicate_hyperedges(
    hyperedges: List[Tuple[List[int], float, int]],
) -> List[Tuple[List[int], float, int]]:
    indices: dict[frozenset[int], Tuple[int, float]] = dict()
    result: List[Tuple[List[int], float, int]] = []
    for dets, weight, mask in hyperedges:
        dets_set = frozenset(dets)
        if dets_set in indices:
            idx, min_weight = indices[dets_set]
            p1 = 1 / (1 + math.exp(weight))
            p2 = 1 / (1 + math.exp(result[idx][1]))
            p = p1 * (1 - p2) + p2 * (1 - p1)
            # Choose the mask from the most likely (lowest weight) error.
            new_mask = result[idx][2]
            if weight < min_weight:
                indices[dets_set] = (idx, weight)
                new_mask = mask
            result[idx] = (dets, math.log((1 - p) / p), new_mask)
        else:
            indices[dets_set] = (len(result), weight)
            result.append((dets, weight, mask))
    return result


def detector_error_model_to_mwpf_solver_and_fault_masks(
    model: stim.DetectorErrorModel,
    decoder_cls: Any = None,
    cluster_node_limit: int = 50,
) -> Tuple[Optional["mwpf.SolverSerialJointSingleHair"], np.ndarray]:
    """Converts a stim detector error model into an MWPF solver and fault masks."""
    try:
        import mwpf
    except ImportError as ex:
        raise mwpf_import_error() from ex

    num_detectors = model.num_detectors
    is_detector_connected = np.full(num_detectors, False, dtype=bool)
    hyperedges: List[Tuple[List[int], float, int]] = []

    def handle_error(p: float, dets: List[int], frame_changes: List[int]):
        if p == 0:
            return
        if len(dets) == 0:
            # No symptoms for this error.
            # Code probably has distance 1.
            # Accept it and keep going, though of course decoding will probably perform terribly.
            return
        if p > 0.5:
            # mwpf doesn't support negative edge weights (yet; they will be supported in the next version).
            # Approximate them as weight 0.
            p = 0.5
        weight = math.log((1 - p) / p)
        mask = sum(1 << k for k in frame_changes)
        is_detector_connected[dets] = True
        hyperedges.append((dets, weight, mask))

    def handle_detector_coords(detector: int, coords: np.ndarray):
        # Detector coordinates aren't needed to build the hypergraph.
        pass

    iter_flatten_model(
        model,
        handle_error=handle_error,
        handle_detector_coords=handle_detector_coords,
    )
    # The mwpf package panics on duplicate edges, so they have to be merged here.
    hyperedges = deduplicate_hyperedges(hyperedges)

    # Fix up the input by connecting a zero-weight edge to every isolated
    # vertex; isolated vertices will be supported in the next version.
    for idx in range(num_detectors):
        if not is_detector_connected[idx]:
            hyperedges.append(([idx], 0, 0))

    # Rescale weights to even integers, since the solver expects integer weights.
    max_weight = max(1e-4, max((w for _, w, _ in hyperedges), default=1))
    rescaled_edges = [
        mwpf.HyperEdge(v, round(w * 2**10 / max_weight) * 2) for v, w, _ in hyperedges
    ]
    fault_masks = np.array([e[2] for e in hyperedges], dtype=np.uint64)

    initializer = mwpf.SolverInitializer(
        num_detectors,  # Total number of nodes.
        rescaled_edges,  # Weighted edges.
    )

    if decoder_cls is None:
        # Default to the solver with the highest accuracy.
        decoder_cls = mwpf.SolverSerialJointSingleHair
    elif isinstance(decoder_cls, str):
        decoder_cls = getattr(mwpf, decoder_cls)
    return (
        (
            decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit})
            if num_detectors > 0 and len(rescaled_edges) > 0
            else None
        ),
        fault_masks,
    )
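
One step above that deserves a worked example is `deduplicate_hyperedges`: when two error mechanisms produce the same symptom set, their weights are mapped back to probabilities, combined with the exclusive-or rule (the merged edge fires when exactly one of the two mechanisms fires), and mapped back to a weight; the observable mask kept is the one from the more likely mechanism. A standalone numeric sketch of that round trip, assuming nothing beyond the standard library:

import math

def weight_of(p: float) -> float:
    return math.log((1 - p) / p)

w1 = weight_of(0.01)  # ~4.60
w2 = weight_of(0.02)  # ~3.89

# Inverting the weight recovers the probability: 1 / (1 + e**w) == p.
p1 = 1 / (1 + math.exp(w1))  # 0.01
p2 = 1 / (1 + math.exp(w2))  # 0.02
# Exclusive-or combination: the merged mechanism fires iff exactly one fires.
p = p1 * (1 - p2) + p2 * (1 - p1)  # 0.0296
print(weight_of(p))  # ~3.49: the merged edge is more likely (cheaper) than either part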