sinter 1.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sinter might be problematic; see the registry page for more details.

Files changed (62):
  1. sinter/__init__.py +47 -0
  2. sinter/_collection/__init__.py +10 -0
  3. sinter/_collection/_collection.py +480 -0
  4. sinter/_collection/_collection_manager.py +581 -0
  5. sinter/_collection/_collection_manager_test.py +287 -0
  6. sinter/_collection/_collection_test.py +317 -0
  7. sinter/_collection/_collection_worker_loop.py +35 -0
  8. sinter/_collection/_collection_worker_state.py +259 -0
  9. sinter/_collection/_collection_worker_test.py +222 -0
  10. sinter/_collection/_mux_sampler.py +56 -0
  11. sinter/_collection/_printer.py +65 -0
  12. sinter/_collection/_sampler_ramp_throttled.py +66 -0
  13. sinter/_collection/_sampler_ramp_throttled_test.py +144 -0
  14. sinter/_command/__init__.py +0 -0
  15. sinter/_command/_main.py +39 -0
  16. sinter/_command/_main_collect.py +350 -0
  17. sinter/_command/_main_collect_test.py +482 -0
  18. sinter/_command/_main_combine.py +84 -0
  19. sinter/_command/_main_combine_test.py +153 -0
  20. sinter/_command/_main_plot.py +817 -0
  21. sinter/_command/_main_plot_test.py +445 -0
  22. sinter/_command/_main_predict.py +75 -0
  23. sinter/_command/_main_predict_test.py +36 -0
  24. sinter/_data/__init__.py +20 -0
  25. sinter/_data/_anon_task_stats.py +89 -0
  26. sinter/_data/_anon_task_stats_test.py +35 -0
  27. sinter/_data/_collection_options.py +106 -0
  28. sinter/_data/_collection_options_test.py +24 -0
  29. sinter/_data/_csv_out.py +74 -0
  30. sinter/_data/_existing_data.py +173 -0
  31. sinter/_data/_existing_data_test.py +41 -0
  32. sinter/_data/_task.py +311 -0
  33. sinter/_data/_task_stats.py +244 -0
  34. sinter/_data/_task_stats_test.py +140 -0
  35. sinter/_data/_task_test.py +38 -0
  36. sinter/_decoding/__init__.py +16 -0
  37. sinter/_decoding/_decoding.py +419 -0
  38. sinter/_decoding/_decoding_all_built_in_decoders.py +25 -0
  39. sinter/_decoding/_decoding_decoder_class.py +161 -0
  40. sinter/_decoding/_decoding_fusion_blossom.py +193 -0
  41. sinter/_decoding/_decoding_mwpf.py +302 -0
  42. sinter/_decoding/_decoding_pymatching.py +81 -0
  43. sinter/_decoding/_decoding_test.py +480 -0
  44. sinter/_decoding/_decoding_vacuous.py +38 -0
  45. sinter/_decoding/_perfectionist_sampler.py +38 -0
  46. sinter/_decoding/_sampler.py +72 -0
  47. sinter/_decoding/_stim_then_decode_sampler.py +222 -0
  48. sinter/_decoding/_stim_then_decode_sampler_test.py +192 -0
  49. sinter/_plotting.py +619 -0
  50. sinter/_plotting_test.py +108 -0
  51. sinter/_predict.py +381 -0
  52. sinter/_predict_test.py +227 -0
  53. sinter/_probability_util.py +519 -0
  54. sinter/_probability_util_test.py +281 -0
  55. sinter-1.15.0.data/data/README.md +332 -0
  56. sinter-1.15.0.data/data/readme_example_plot.png +0 -0
  57. sinter-1.15.0.data/data/requirements.txt +4 -0
  58. sinter-1.15.0.dist-info/METADATA +354 -0
  59. sinter-1.15.0.dist-info/RECORD +62 -0
  60. sinter-1.15.0.dist-info/WHEEL +5 -0
  61. sinter-1.15.0.dist-info/entry_points.txt +2 -0
  62. sinter-1.15.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,161 @@
1
+ import abc
2
+ import pathlib
3
+
4
+ import numpy as np
5
+ import stim
6
+
7
+
8
class CompiledDecoder(metaclass=abc.ABCMeta):
    """A decoder that has already been configured for one specific decoding task.

    Objects of this type are what `sinter.Decoder.compile_decoder_for_dem`
    returns. Configuring a decoder can be costly, so when many shots of the
    same task are going to be decoded it is worthwhile to configure once and
    then reuse the configured decoder for every batch, instead of configuring
    again for each batch. Custom decoders may optionally implement
    `compile_decoder_for_dem` (returning an instance of this class) to make
    sampling more efficient.
    """

    @abc.abstractmethod
    def decode_shots_bit_packed(
            self,
            *,
            bit_packed_detection_event_data: np.ndarray,
    ) -> np.ndarray:
        """Predicts which observables flipped, given detection event data.

        Every array taken or returned by this method is bit packed using
        bitorder='little'.

        Args:
            bit_packed_detection_event_data: The detection events of a batch
                of shots, as a bit packed numpy array with

                    dtype: uint8
                    shape: (num_shots, ceil(dem.num_detectors / 8))

                where `num_shots` is how many shots are to be decoded and `dem`
                is the detector error model this instance was compiled for.

                The data is guaranteed to be laid out so that the detection
                events of a single shot are contiguous in memory, i.e.
                bit_packed_detection_event_data.strides[1] == 1.

        Returns:
            The predicted observable flips, as a bit packed numpy array that
            must have

                dtype: uint8
                shape: (num_shots, ceil(dem.num_observables / 8))

            where `num_shots` equals bit_packed_detection_event_data.shape[0]
            and `dem` is the detector error model this instance was compiled
            for.
        """
56
+
57
+
58
class Decoder:
    """Abstract base class for custom decoders.

    Custom decoders can be explained to sinter by inheriting from this class and
    implementing its methods.

    Decoder classes MUST be serializable (e.g. via pickling), so that they can
    be given to worker processes when using python multiprocessing.

    Child classes should implement `compile_decoder_for_dem`, but (for legacy
    reasons) can alternatively implement `decode_via_files`. At least one of
    the two methods must be implemented.
    """

    def compile_decoder_for_dem(
        self,
        *,
        dem: stim.DetectorErrorModel,
    ) -> CompiledDecoder:
        """Creates a decoder preconfigured for the given detector error model.

        This method is optional to implement. By default, it will raise a
        NotImplementedError. When sampling, sinter will attempt to use this
        method first and otherwise fallback to using `decode_via_files`.

        The idea is that the preconfigured decoder amortizes the cost of
        configuration over more calls. This makes smaller batch sizes efficient,
        reducing the amount of memory used for storing each batch, improving
        overall efficiency.

        Args:
            dem: A detector error model for the samples that will need to be
                decoded. What to configure the decoder to decode.

        Returns:
            An instance of `sinter.CompiledDecoder` that can be used to invoke
            the preconfigured decoder.

        Raises:
            NotImplementedError: This `sinter.Decoder` doesn't support compiling
                for a dem.
        """
        raise NotImplementedError('compile_decoder_for_dem')

    def decode_via_files(self,
                         *,
                         num_shots: int,
                         num_dets: int,
                         num_obs: int,
                         dem_path: pathlib.Path,
                         dets_b8_in_path: pathlib.Path,
                         obs_predictions_b8_out_path: pathlib.Path,
                         tmp_dir: pathlib.Path,
                         ) -> None:
        """Performs decoding by reading/writing problems and answers from disk.

        The default implementation loads the detector error model, obtains a
        compiled decoder from `compile_decoder_for_dem`, reads all detection
        event data into memory in one go, decodes it in a single call, and
        writes the predictions to `obs_predictions_b8_out_path`.

        Args:
            num_shots: The number of times the circuit was sampled. The number
                of problems to be solved.
            num_dets: The number of detectors in the circuit. The number of
                detection event bits in each shot.
            num_obs: The number of observables in the circuit. The number of
                predicted bits in each shot.
            dem_path: The file path where the detector error model should be
                read from, e.g. using `stim.DetectorErrorModel.from_file`. The
                error mechanisms specified by the detector error model should be
                used to configure the decoder.
            dets_b8_in_path: The file path that detection event data should be
                read from. Note that the file may be a named pipe instead of a
                fixed size object. The detection events will be in b8 format
                (see
                https://github.com/quantumlib/Stim/blob/main/doc/result_formats.md
                ). The number of detection events per shot is available via the
                `num_dets` argument or via the detector error model at
                `dem_path`.
            obs_predictions_b8_out_path: The file path that decoder predictions
                must be written to. The predictions must be written in b8 format
                (see
                https://github.com/quantumlib/Stim/blob/main/doc/result_formats.md
                ). The number of observables per shot is available via the
                `num_obs` argument or via the detector error model at
                `dem_path`.
            tmp_dir: Any temporary files generated by the decoder during its
                operation MUST be put into this directory. The reason for this
                requirement is because sinter is allowed to kill the decoding
                process without warning, without giving it time to clean up any
                temporary objects. All cleanup should be done via sinter
                deleting this directory after killing the decoder.

        Raises:
            NotImplementedError: Neither `compile_decoder_for_dem` nor
                `decode_via_files` was overridden.
            IOError: The detection event file contained fewer than
                `num_shots * ceil(num_dets / 8)` bytes.
            ValueError: The compiled decoder returned data with the wrong
                dtype or shape.
        """
        dem = stim.DetectorErrorModel.from_file(dem_path)

        try:
            compiled = self.compile_decoder_for_dem(dem=dem)
        except NotImplementedError as ex:
            raise NotImplementedError(f"{type(self).__qualname__} didn't implement `compile_decoder_for_dem` or `decode_via_files`.") from ex

        # Ceiling division: bytes needed to hold one shot's bit packed data.
        num_det_bytes = -(-num_dets // 8)
        num_obs_bytes = -(-num_obs // 8)
        expected_det_bytes = num_shots * num_det_bytes
        dets = np.fromfile(dets_b8_in_path, dtype=np.uint8, count=expected_det_bytes)
        if dets.shape != (expected_det_bytes,):
            # Report truncated input explicitly instead of failing inside reshape.
            raise IOError(
                f"Missing dets data. Expected {expected_det_bytes} bytes "
                f"({num_shots} shots of {num_det_bytes} bytes) from {dets_b8_in_path} "
                f"but only read {dets.shape[0]}.")
        # A fresh C-contiguous uint8 array, so strides[1] == 1 as promised to
        # CompiledDecoder implementations.
        dets = dets.reshape(num_shots, num_det_bytes)
        obs = compiled.decode_shots_bit_packed(bit_packed_detection_event_data=dets)
        if obs.dtype != np.uint8 or obs.shape != (num_shots, num_obs_bytes):
            raise ValueError(f"Got a numpy array with dtype={obs.dtype},shape={obs.shape} instead of dtype={np.uint8},shape={(num_shots, num_obs_bytes)} from {type(self).__qualname__}(...).compile_decoder_for_dem(...).decode_shots_bit_packed(...).")
        obs.tofile(obs_predictions_b8_out_path)
@@ -0,0 +1,193 @@
1
+ import math
2
+ import pathlib
3
+ from typing import Callable, List, TYPE_CHECKING, Tuple
4
+
5
+ import numpy as np
6
+ import stim
7
+
8
+ from sinter._decoding._decoding_decoder_class import Decoder, CompiledDecoder
9
+
10
+ if TYPE_CHECKING:
11
+ import fusion_blossom
12
+
13
+
14
class FusionBlossomCompiledDecoder(CompiledDecoder):
    """A fusion_blossom solver preconfigured for one detector error model.

    Args:
        solver: The fusion_blossom solver built for the model. It is reused
            for every shot (solve, read subgraph, clear).
        fault_masks: Observable bitmasks, indexed by the edge indices that
            `solver.subgraph()` returns; XOR-ing the masks of the chosen edges
            gives the predicted observable flips.
        num_dets: Number of detectors in the detector error model.
        num_obs: Number of observables in the detector error model.
    """

    def __init__(self, solver: 'fusion_blossom.SolverSerial', fault_masks: 'np.ndarray', num_dets: int, num_obs: int):
        self.solver = solver
        self.fault_masks = fault_masks
        self.num_dets = num_dets
        self.num_obs = num_obs

    def decode_shots_bit_packed(
            self,
            *,
            bit_packed_detection_event_data: 'np.ndarray',
    ) -> 'np.ndarray':
        """Decodes each shot with the stored solver.

        Args:
            bit_packed_detection_event_data: uint8 array of shape
                (num_shots, ceil(num_dets / 8)), bit packed with
                bitorder='little'.

        Returns:
            uint8 array of shape (num_shots, ceil(num_obs / 8)) holding the
            predicted observable flips, bit packed with bitorder='little'.
        """
        import fusion_blossom

        num_shots = bit_packed_detection_event_data.shape[0]
        num_obs_bytes = (self.num_obs + 7) // 8
        predictions = np.zeros(shape=(num_shots, num_obs_bytes), dtype=np.uint8)
        for shot in range(num_shots):
            dets_sparse = np.flatnonzero(np.unpackbits(
                bit_packed_detection_event_data[shot],
                count=self.num_dets,
                bitorder='little',
            ))
            syndrome = fusion_blossom.SyndromePattern(syndrome_vertices=dets_sparse)
            self.solver.solve(syndrome)
            # XOR together the observable masks of every edge in the matching.
            prediction = int(np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()]))
            self.solver.clear()
            # The little-endian bytes of the mask are exactly its bit packed
            # (bitorder='little') form. This also yields zero bytes when
            # num_obs == 0, matching the (num_shots, 0) output shape.
            predictions[shot] = np.frombuffer(
                prediction.to_bytes(num_obs_bytes, byteorder='little'),
                dtype=np.uint8,
            )
        return predictions
40
+
41
+
42
class FusionBlossomDecoder(Decoder):
    """Use fusion blossom to predict observables from detection events."""

    @staticmethod
    def _import_error() -> ImportError:
        """Builds the error raised when the optional fusion_blossom package is missing."""
        return ImportError(
            "The decoder 'fusion_blossom' isn't installed\n"
            "To fix this, install the python package 'fusion-blossom' into your environment.\n"
            "For example, if you are using pip, run `pip install fusion-blossom~=0.1.4`.\n"
        )

    def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> CompiledDecoder:
        """Builds a fusion_blossom solver for `dem` and wraps it for reuse.

        Raises:
            ImportError: The fusion_blossom package isn't installed.
        """
        try:
            import fusion_blossom  # noqa: F401  (only checks that it is installed)
        except ImportError as ex:
            raise self._import_error() from ex

        solver, fault_masks = detector_error_model_to_fusion_blossom_solver_and_fault_masks(dem)
        return FusionBlossomCompiledDecoder(solver, fault_masks, dem.num_detectors, dem.num_observables)

    def decode_via_files(self,
                         *,
                         num_shots: int,
                         num_dets: int,
                         num_obs: int,
                         dem_path: pathlib.Path,
                         dets_b8_in_path: pathlib.Path,
                         obs_predictions_b8_out_path: pathlib.Path,
                         tmp_dir: pathlib.Path,
                         ) -> None:
        """Decodes shot by shot, streaming b8 data in from and out to files.

        Shots are read one at a time, so `dets_b8_in_path` may be a pipe.

        Raises:
            ImportError: The fusion_blossom package isn't installed.
            IOError: The detection event input ended before `num_shots` shots
                were read.
        """
        try:
            import fusion_blossom
        except ImportError as ex:
            raise self._import_error() from ex

        error_model = stim.DetectorErrorModel.from_file(dem_path)
        solver, fault_masks = detector_error_model_to_fusion_blossom_solver_and_fault_masks(error_model)
        num_det_bytes = math.ceil(num_dets / 8)
        num_obs_bytes = (num_obs + 7) // 8
        with open(dets_b8_in_path, 'rb') as dets_in_f:
            with open(obs_predictions_b8_out_path, 'wb') as obs_out_f:
                for _ in range(num_shots):
                    dets_bit_packed = np.fromfile(dets_in_f, dtype=np.uint8, count=num_det_bytes)
                    if dets_bit_packed.shape != (num_det_bytes,):
                        raise IOError('Missing dets data.')
                    dets_sparse = np.flatnonzero(np.unpackbits(dets_bit_packed, count=num_dets, bitorder='little'))
                    syndrome = fusion_blossom.SyndromePattern(syndrome_vertices=dets_sparse)
                    solver.solve(syndrome)
                    # XOR together the observable masks of every matched edge.
                    prediction = int(np.bitwise_xor.reduce(fault_masks[solver.subgraph()]))
                    obs_out_f.write(prediction.to_bytes(num_obs_bytes, byteorder='little'))
                    solver.clear()
93
+
94
+
95
def iter_flatten_model(model: stim.DetectorErrorModel,
                       handle_error: Callable[[float, List[int], List[int]], None],
                       handle_detector_coords: Callable[[int, np.ndarray], None]):
    """Walks a detector error model, unrolling REPEAT blocks and applying offsets.

    For every `error` instruction, `handle_error(p, dets, frames)` is called
    once per component, where components are split at `^` separator targets.
    Detector indices are made absolute using the accumulated
    `shift_detectors` offset. For every target of every `detector`
    instruction, `handle_detector_coords(index, coords)` is called with the
    absolute index and the coordinates shifted by the accumulated coordinate
    offset.

    Raises:
        NotImplementedError: An instruction of an unrecognized type was seen.
    """
    det_offset = 0
    # Accumulated coordinate shift. Grown on demand so detectors or shifts
    # with many coordinates don't trigger a numpy broadcasting error.
    coords_offset = np.zeros(100, dtype=np.float64)

    def _coords_offset_prefix(n: int) -> np.ndarray:
        """Returns a writable view of the first `n` coordinate offsets, growing storage if needed."""
        nonlocal coords_offset
        if n > len(coords_offset):
            grown = np.zeros(max(n, 2 * len(coords_offset)), dtype=np.float64)
            grown[:len(coords_offset)] = coords_offset
            coords_offset = grown
        return coords_offset[:n]

    def _helper(m: stim.DetectorErrorModel, reps: int):
        nonlocal det_offset
        for _ in range(reps):
            for instruction in m:
                if isinstance(instruction, stim.DemRepeatBlock):
                    _helper(instruction.body_copy(), instruction.repeat_count)
                elif isinstance(instruction, stim.DemInstruction):
                    if instruction.type == "error":
                        dets: List[int] = []
                        frames: List[int] = []
                        p = instruction.args_copy()[0]
                        for t in instruction.targets_copy():
                            if t.is_relative_detector_id():
                                dets.append(t.val + det_offset)
                            elif t.is_logical_observable_id():
                                frames.append(t.val)
                            elif t.is_separator():
                                # Treat each component of a decomposed error as an independent error.
                                # (Ideally we could configure some sort of correlated analysis; oh well.)
                                handle_error(p, dets, frames)
                                frames = []
                                dets = []
                        # Handle last component.
                        handle_error(p, dets, frames)
                    elif instruction.type == "shift_detectors":
                        det_offset += instruction.targets_copy()[0]
                        a = np.array(instruction.args_copy())
                        offset = _coords_offset_prefix(len(a))
                        offset += a  # In place: `offset` is a view into coords_offset.
                    elif instruction.type == "detector":
                        a = np.array(instruction.args_copy())
                        for t in instruction.targets_copy():
                            handle_detector_coords(t.val + det_offset, a + _coords_offset_prefix(len(a)))
                    elif instruction.type == "logical_observable":
                        pass
                    else:
                        raise NotImplementedError(f"Unsupported detector error model instruction: {instruction!r}")
                else:
                    raise NotImplementedError(f"Unsupported detector error model item: {instruction!r}")

    _helper(model, 1)
142
+
143
+
144
def detector_error_model_to_fusion_blossom_solver_and_fault_masks(model: stim.DetectorErrorModel) -> Tuple['fusion_blossom.SolverSerial', np.ndarray]:
    """Converts a stim detector error model into a fusion_blossom solver.

    Each error component (as split by `iter_flatten_model`) with nonzero
    probability and one or two detectors becomes a weighted edge. A component
    with a single detector is connected to an extra boundary node whose index
    is `model.num_detectors`. Components with no detectors are dropped.

    Edge weights are log((1 - p) / p), with p clamped to at most 0.5. They are
    then scaled by 2**10 / max_weight, rounded, and doubled, giving even
    integers. max_weight is the largest weight, but at least 1e-4.

    Returns:
        A tuple (solver, fault_masks). `solver` is a
        `fusion_blossom.SolverSerial` over `model.num_detectors + 1` nodes.
        `fault_masks[i]` is the uint64 bitmask of observables flipped by
        edge i.

    Raises:
        NotImplementedError: An error component has more than two detectors.
    """

    import fusion_blossom

    def handle_error(p: float, dets: List[int], frame_changes: List[int]):
        if p == 0:
            return
        if len(dets) == 0:
            # No symptoms for this error.
            # Code probably has distance 1.
            # Accept it and keep going, though of course decoding will probably perform terribly.
            return
        if len(dets) == 1:
            # Connect lone detectors to the boundary node.
            dets = [dets[0], num_detectors]
        if len(dets) > 2:
            raise NotImplementedError(
                f"Error with more than 2 symptoms can't become an edge or boundary edge: {dets!r}.")
        if p > 0.5:
            # fusion_blossom doesn't support negative edge weights.
            # approximate them as weight 0.
            p = 0.5
        weight = math.log((1 - p) / p)
        mask = sum(1 << k for k in frame_changes)
        edges.append((dets[0], dets[1], weight, mask))

    def handle_detector_coords(detector: int, coords: np.ndarray):
        pass

    num_detectors = model.num_detectors
    edges: List[Tuple[int, int, float, int]] = []
    iter_flatten_model(
        model,
        handle_error=handle_error,
        handle_detector_coords=handle_detector_coords,
    )
    max_weight = max(1e-4, max((w for _, _, w, _ in edges), default=1))
    # Doubling after rounding keeps every integer weight even.
    rescaled_edges = [
        (a, b, round(w * 2**10 / max_weight) * 2)
        for a, b, w, _ in edges
    ]
    fault_masks = np.array([e[3] for e in edges], dtype=np.uint64)

    initializer = fusion_blossom.SolverInitializer(
        num_detectors + 1,  # Total number of nodes.
        rescaled_edges,  # Weighted edges.
        [num_detectors],  # Boundary node.
    )

    return fusion_blossom.SolverSerial(initializer), fault_masks
@@ -0,0 +1,302 @@
1
+ import math
2
+ import pathlib
3
+ from typing import Callable, List, TYPE_CHECKING, Tuple, Any, Optional
4
+
5
+ import numpy as np
6
+ import stim
7
+
8
+ from sinter._decoding._decoding_decoder_class import Decoder, CompiledDecoder
9
+
10
+ if TYPE_CHECKING:
11
+ import mwpf
12
+
13
+
14
def mwpf_import_error() -> ImportError:
    """Builds (but does not raise) the error for a missing `mwpf` package.

    Returns:
        An ImportError whose message explains how to install MWPF. Callers
        raise it, e.g. `raise mwpf_import_error() from ex`.
    """
    message = (
        "The decoder 'MWPF' isn't installed\n"
        "To fix this, install the python package 'MWPF' into your environment.\n"
        "For example, if you are using pip, run `pip install MWPF~=0.1.5`.\n"
    )
    return ImportError(message)
20
+
21
+
22
class MwpfCompiledDecoder(CompiledDecoder):
    """An MWPF solver preconfigured for one detector error model.

    Args:
        solver: The MWPF solver built for the model, reused for every shot.
            It may be None, in which case every prediction is "no flips".
        fault_masks: Observable bitmasks, indexed by the hyperedge indices
            that `solver.subgraph()` returns; XOR-ing the selected masks gives
            the predicted observable flips.
        num_dets: Number of detectors in the detector error model.
        num_obs: Number of observables in the detector error model.
    """

    def __init__(
        self,
        solver: Optional["mwpf.SolverSerialJointSingleHair"],
        fault_masks: "np.ndarray",
        num_dets: int,
        num_obs: int,
    ):
        self.solver = solver
        self.fault_masks = fault_masks
        self.num_dets = num_dets
        self.num_obs = num_obs

    def decode_shots_bit_packed(
        self,
        *,
        bit_packed_detection_event_data: "np.ndarray",
    ) -> "np.ndarray":
        """Decodes each shot with the stored solver.

        Args:
            bit_packed_detection_event_data: uint8 array of shape
                (num_shots, ceil(num_dets / 8)), bit packed with
                bitorder='little'.

        Returns:
            uint8 array of shape (num_shots, ceil(num_obs / 8)) holding the
            predicted observable flips, bit packed with bitorder='little'.
        """
        num_shots = bit_packed_detection_event_data.shape[0]
        num_obs_bytes = (self.num_obs + 7) // 8
        predictions = np.zeros(shape=(num_shots, num_obs_bytes), dtype=np.uint8)
        import mwpf

        if self.solver is None:
            # No solver was built (no detectors or no hyperedges), so every
            # prediction is "no observable flipped": the all-zero rows above.
            return predictions

        for shot in range(num_shots):
            dets_sparse = np.flatnonzero(
                np.unpackbits(
                    bit_packed_detection_event_data[shot],
                    count=self.num_dets,
                    bitorder="little",
                )
            )
            syndrome = mwpf.SyndromePattern(defect_vertices=dets_sparse)
            self.solver.solve(syndrome)
            # XOR together the observable masks of every selected hyperedge.
            prediction = int(
                np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()])
            )
            self.solver.clear()
            # The little-endian bytes of the mask are exactly its bit packed
            # (bitorder='little') form. This also yields zero bytes when
            # num_obs == 0, matching the (num_shots, 0) output shape.
            predictions[shot] = np.frombuffer(
                prediction.to_bytes(num_obs_bytes, byteorder="little"),
                dtype=np.uint8,
            )
        return predictions
71
+
72
+
73
class MwpfDecoder(Decoder):
    """Use MWPF to predict observables from detection events.

    Args:
        decoder_cls: The MWPF solver class used to construct the decoder.
            Either a class, or the name of an attribute of the `mwpf` module
            (e.g. "SolverSerialUnionFind"). When None, defaults to
            `mwpf.SolverSerialJointSingleHair`. In the Rust implementation all
            solvers inherit from `SolverSerialPlugins` and differ only in the
            plugins they use to optimize the primal and/or dual solutions.
            For example, `SolverSerialUnionFind` is the most basic solver with
            no plugin: it only grows clusters until the first valid solution
            appears. More optimized solvers use one or more plugins to further
            improve the solution, at the cost of longer decoding time.
        cluster_node_limit: The maximum number of nodes in a cluster. Passed
            to the solver as `config={"cluster_node_limit": ...}`.
    """

    def __init__(
        self,
        decoder_cls: Any = None,
        cluster_node_limit: int = 50,
    ):
        self.decoder_cls = decoder_cls
        self.cluster_node_limit = cluster_node_limit
        super().__init__()

    def compile_decoder_for_dem(
        self,
        *,
        dem: "stim.DetectorErrorModel",
    ) -> CompiledDecoder:
        """Builds an MWPF solver for `dem` and wraps it for reuse.

        Raises:
            ImportError: The mwpf package isn't installed.
        """
        solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
            dem,
            decoder_cls=self.decoder_cls,
            cluster_node_limit=self.cluster_node_limit,
        )
        return MwpfCompiledDecoder(
            solver,
            fault_masks,
            dem.num_detectors,
            dem.num_observables,
        )

    def decode_via_files(
        self,
        *,
        num_shots: int,
        num_dets: int,
        num_obs: int,
        dem_path: pathlib.Path,
        dets_b8_in_path: pathlib.Path,
        obs_predictions_b8_out_path: pathlib.Path,
        tmp_dir: pathlib.Path,
    ) -> None:
        """Decodes shot by shot, streaming b8 data in from and out to files.

        Raises:
            ImportError: The mwpf package isn't installed (with install
                instructions).
            IOError: The detection event input ended before `num_shots` shots
                were read.
        """
        try:
            import mwpf
        except ImportError as ex:
            raise mwpf_import_error() from ex

        error_model = stim.DetectorErrorModel.from_file(dem_path)
        solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
            error_model,
            decoder_cls=self.decoder_cls,
            cluster_node_limit=self.cluster_node_limit,
        )
        num_det_bytes = math.ceil(num_dets / 8)
        num_obs_bytes = (num_obs + 7) // 8
        with open(dets_b8_in_path, "rb") as dets_in_f:
            with open(obs_predictions_b8_out_path, "wb") as obs_out_f:
                for _ in range(num_shots):
                    dets_bit_packed = np.fromfile(
                        dets_in_f, dtype=np.uint8, count=num_det_bytes
                    )
                    if dets_bit_packed.shape != (num_det_bytes,):
                        raise IOError("Missing dets data.")
                    dets_sparse = np.flatnonzero(
                        np.unpackbits(
                            dets_bit_packed, count=num_dets, bitorder="little"
                        )
                    )
                    syndrome = mwpf.SyndromePattern(defect_vertices=dets_sparse)
                    if solver is None:
                        # No solver was built: predict no observable flips.
                        prediction = 0
                    else:
                        solver.solve(syndrome)
                        prediction = int(
                            np.bitwise_xor.reduce(fault_masks[solver.subgraph()])
                        )
                        solver.clear()
                    obs_out_f.write(
                        prediction.to_bytes(num_obs_bytes, byteorder="little")
                    )
152
+
153
+
154
class HyperUFDecoder(MwpfDecoder):
    """MWPF preset that uses the plain hypergraph union-find solver.

    Equivalent to
    `MwpfDecoder(decoder_cls="SolverSerialUnionFind", cluster_node_limit=0)`.
    """

    def __init__(self):
        super().__init__(
            cluster_node_limit=0,
            decoder_cls="SolverSerialUnionFind",
        )
157
+
158
+
159
def iter_flatten_model(
    model: stim.DetectorErrorModel,
    handle_error: Callable[[float, List[int], List[int]], None],
    handle_detector_coords: Callable[[int, np.ndarray], None],
):
    """Walks a detector error model, unrolling REPEAT blocks and applying offsets.

    Each `error` instruction is reported as a single hyperedge via
    `handle_error(p, dets, frames)`. `^` separators are ignored, and repeated
    detector or observable targets cancel in pairs (symmetric difference).
    Detector indices are made absolute using the accumulated
    `shift_detectors` offset. For every target of every `detector`
    instruction, `handle_detector_coords(index, coords)` is called with the
    absolute index and the coordinates shifted by the accumulated coordinate
    offset.

    Raises:
        NotImplementedError: An instruction of an unrecognized type was seen.
    """
    det_offset = 0
    # Accumulated coordinate shift. Grown on demand so detectors or shifts
    # with many coordinates don't trigger a numpy broadcasting error.
    coords_offset = np.zeros(100, dtype=np.float64)

    def _coords_offset_prefix(n: int) -> np.ndarray:
        """Returns a writable view of the first `n` coordinate offsets, growing storage if needed."""
        nonlocal coords_offset
        if n > len(coords_offset):
            grown = np.zeros(max(n, 2 * len(coords_offset)), dtype=np.float64)
            grown[: len(coords_offset)] = coords_offset
            coords_offset = grown
        return coords_offset[:n]

    def _helper(m: stim.DetectorErrorModel, reps: int):
        nonlocal det_offset
        for _ in range(reps):
            for instruction in m:
                if isinstance(instruction, stim.DemRepeatBlock):
                    _helper(instruction.body_copy(), instruction.repeat_count)
                elif isinstance(instruction, stim.DemInstruction):
                    if instruction.type == "error":
                        dets: set[int] = set()
                        frames: set[int] = set()
                        p = instruction.args_copy()[0]
                        for t in instruction.targets_copy():
                            if t.is_relative_detector_id():
                                # Symmetric difference: a target listed twice cancels out.
                                dets ^= {t.val + det_offset}
                            elif t.is_logical_observable_id():
                                frames ^= {t.val}
                        handle_error(p, list(dets), list(frames))
                    elif instruction.type == "shift_detectors":
                        det_offset += instruction.targets_copy()[0]
                        a = np.array(instruction.args_copy())
                        offset = _coords_offset_prefix(len(a))
                        offset += a  # In place: `offset` is a view into coords_offset.
                    elif instruction.type == "detector":
                        a = np.array(instruction.args_copy())
                        for t in instruction.targets_copy():
                            handle_detector_coords(
                                t.val + det_offset, a + _coords_offset_prefix(len(a))
                            )
                    elif instruction.type == "logical_observable":
                        pass
                    else:
                        raise NotImplementedError(
                            f"Unsupported detector error model instruction: {instruction!r}"
                        )
                else:
                    raise NotImplementedError(
                        f"Unsupported detector error model item: {instruction!r}"
                    )

    _helper(model, 1)
204
+
205
+
206
def deduplicate_hyperedges(
    hyperedges: List[Tuple[List[int], float, int]]
) -> List[Tuple[List[int], float, int]]:
    """Merges hyperedges that flip exactly the same set of detectors.

    Weights are log-likelihood ratios w = log((1 - p) / p). Two mechanisms
    with identical detector sets combine into one whose probability is the
    chance that exactly one of them fires: p1 * (1 - p2) + p2 * (1 - p1).
    The merged entry takes the most recently seen detector list. Its
    observable mask comes from the lowest-weight (most likely) individual
    mechanism seen so far; on ties, the earlier one is kept. First-appearance
    order is preserved.

    Args:
        hyperedges: (detectors, weight, observable_mask) triples.

    Returns:
        The merged (detectors, weight, observable_mask) triples.
    """
    merged: List[Tuple[List[int], float, int]] = []
    # frozenset of detectors -> (position in `merged`, smallest individual weight so far)
    seen = {}
    for dets, weight, mask in hyperedges:
        key = frozenset(dets)
        entry = seen.get(key)
        if entry is None:
            seen[key] = (len(merged), weight)
            merged.append((dets, weight, mask))
            continue

        pos, best_weight = entry
        _, merged_weight, merged_mask = merged[pos]
        # Convert both weights back to probabilities and XOR-combine them.
        p_new = 1 / (1 + math.exp(weight))
        p_old = 1 / (1 + math.exp(merged_weight))
        p_xor = p_new * (1 - p_old) + p_old * (1 - p_new)
        if weight < best_weight:
            seen[key] = (pos, weight)
            merged_mask = mask
        merged[pos] = (dets, math.log((1 - p_xor) / p_xor), merged_mask)
    return merged
228
+
229
+
230
def detector_error_model_to_mwpf_solver_and_fault_masks(
    model: stim.DetectorErrorModel,
    decoder_cls: Any = None,
    cluster_node_limit: int = 50,
) -> Tuple[Optional["mwpf.SolverSerialJointSingleHair"], np.ndarray]:
    """Converts a stim detector error model into an MWPF hypergraph solver.

    Each `error` instruction with nonzero probability and at least one
    detector becomes a weighted hyperedge with weight log((1 - p) / p), where
    p is clamped to at most 0.5. Hyperedges with identical detector sets are
    merged by `deduplicate_hyperedges`. Every detector not touched by any
    hyperedge gets a weight-0 single-vertex hyperedge. Weights are then scaled
    by 2**10 / max_weight, rounded, and doubled, giving even integers.
    max_weight is the largest weight, but at least 1e-4.

    Args:
        model: The detector error model to convert.
        decoder_cls: MWPF solver class, or the name of an attribute of the
            `mwpf` module. None selects `mwpf.SolverSerialJointSingleHair`.
        cluster_node_limit: Passed to the solver as
            `config={"cluster_node_limit": ...}`.

    Returns:
        A tuple (solver, fault_masks). `solver` is None when the model has no
        detectors or no hyperedges. `fault_masks[i]` is the uint64 bitmask of
        observables flipped by hyperedge i.

    Raises:
        ImportError: The mwpf package isn't installed.
    """

    try:
        import mwpf
    except ImportError as ex:
        raise mwpf_import_error() from ex

    num_detectors = model.num_detectors
    is_detector_connected = np.full(num_detectors, False, dtype=bool)
    hyperedges: List[Tuple[List[int], float, int]] = []

    def handle_error(p: float, dets: List[int], frame_changes: List[int]):
        if p == 0:
            return
        if len(dets) == 0:
            # No symptoms for this error.
            # Code probably has distance 1.
            # Accept it and keep going, though of course decoding will probably perform terribly.
            return
        if p > 0.5:
            # mwpf doesn't support negative edge weights (yet, will be supported in the next version).
            # approximate them as weight 0.
            p = 0.5
        weight = math.log((1 - p) / p)
        mask = sum(1 << k for k in frame_changes)
        is_detector_connected[dets] = True
        hyperedges.append((dets, weight, mask))

    def handle_detector_coords(detector: int, coords: np.ndarray):
        pass

    iter_flatten_model(
        model,
        handle_error=handle_error,
        handle_detector_coords=handle_detector_coords,
    )
    # mwpf package panic on duplicate edges, thus we need to handle them here
    hyperedges = deduplicate_hyperedges(hyperedges)

    # fix the input by connecting an edge to all isolated vertices; will be supported in the next version
    for idx in range(num_detectors):
        if not is_detector_connected[idx]:
            hyperedges.append(([idx], 0, 0))

    max_weight = max(1e-4, max((w for _, w, _ in hyperedges), default=1))
    # Doubling after rounding keeps every integer weight even.
    rescaled_edges = [
        mwpf.HyperEdge(v, round(w * 2**10 / max_weight) * 2) for v, w, _ in hyperedges
    ]
    fault_masks = np.array([e[2] for e in hyperedges], dtype=np.uint64)

    initializer = mwpf.SolverInitializer(
        num_detectors,  # Total number of nodes.
        rescaled_edges,  # Weighted edges.
    )

    if decoder_cls is None:
        # default to the solver with highest accuracy
        decoder_cls = mwpf.SolverSerialJointSingleHair
    elif isinstance(decoder_cls, str):
        decoder_cls = getattr(mwpf, decoder_cls)
    return (
        (
            decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit})
            if num_detectors > 0 and len(rescaled_edges) > 0
            else None
        ),
        fault_masks,
    )