sinter 1.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sinter might be problematic. Click here for more details.

Files changed (62) hide show
  1. sinter/__init__.py +47 -0
  2. sinter/_collection/__init__.py +10 -0
  3. sinter/_collection/_collection.py +480 -0
  4. sinter/_collection/_collection_manager.py +581 -0
  5. sinter/_collection/_collection_manager_test.py +287 -0
  6. sinter/_collection/_collection_test.py +317 -0
  7. sinter/_collection/_collection_worker_loop.py +35 -0
  8. sinter/_collection/_collection_worker_state.py +259 -0
  9. sinter/_collection/_collection_worker_test.py +222 -0
  10. sinter/_collection/_mux_sampler.py +56 -0
  11. sinter/_collection/_printer.py +65 -0
  12. sinter/_collection/_sampler_ramp_throttled.py +66 -0
  13. sinter/_collection/_sampler_ramp_throttled_test.py +144 -0
  14. sinter/_command/__init__.py +0 -0
  15. sinter/_command/_main.py +39 -0
  16. sinter/_command/_main_collect.py +350 -0
  17. sinter/_command/_main_collect_test.py +482 -0
  18. sinter/_command/_main_combine.py +84 -0
  19. sinter/_command/_main_combine_test.py +153 -0
  20. sinter/_command/_main_plot.py +817 -0
  21. sinter/_command/_main_plot_test.py +445 -0
  22. sinter/_command/_main_predict.py +75 -0
  23. sinter/_command/_main_predict_test.py +36 -0
  24. sinter/_data/__init__.py +20 -0
  25. sinter/_data/_anon_task_stats.py +89 -0
  26. sinter/_data/_anon_task_stats_test.py +35 -0
  27. sinter/_data/_collection_options.py +106 -0
  28. sinter/_data/_collection_options_test.py +24 -0
  29. sinter/_data/_csv_out.py +74 -0
  30. sinter/_data/_existing_data.py +173 -0
  31. sinter/_data/_existing_data_test.py +41 -0
  32. sinter/_data/_task.py +311 -0
  33. sinter/_data/_task_stats.py +244 -0
  34. sinter/_data/_task_stats_test.py +140 -0
  35. sinter/_data/_task_test.py +38 -0
  36. sinter/_decoding/__init__.py +16 -0
  37. sinter/_decoding/_decoding.py +419 -0
  38. sinter/_decoding/_decoding_all_built_in_decoders.py +25 -0
  39. sinter/_decoding/_decoding_decoder_class.py +161 -0
  40. sinter/_decoding/_decoding_fusion_blossom.py +193 -0
  41. sinter/_decoding/_decoding_mwpf.py +302 -0
  42. sinter/_decoding/_decoding_pymatching.py +81 -0
  43. sinter/_decoding/_decoding_test.py +480 -0
  44. sinter/_decoding/_decoding_vacuous.py +38 -0
  45. sinter/_decoding/_perfectionist_sampler.py +38 -0
  46. sinter/_decoding/_sampler.py +72 -0
  47. sinter/_decoding/_stim_then_decode_sampler.py +222 -0
  48. sinter/_decoding/_stim_then_decode_sampler_test.py +192 -0
  49. sinter/_plotting.py +619 -0
  50. sinter/_plotting_test.py +108 -0
  51. sinter/_predict.py +381 -0
  52. sinter/_predict_test.py +227 -0
  53. sinter/_probability_util.py +519 -0
  54. sinter/_probability_util_test.py +281 -0
  55. sinter-1.15.0.data/data/README.md +332 -0
  56. sinter-1.15.0.data/data/readme_example_plot.png +0 -0
  57. sinter-1.15.0.data/data/requirements.txt +4 -0
  58. sinter-1.15.0.dist-info/METADATA +354 -0
  59. sinter-1.15.0.dist-info/RECORD +62 -0
  60. sinter-1.15.0.dist-info/WHEEL +5 -0
  61. sinter-1.15.0.dist-info/entry_points.txt +2 -0
  62. sinter-1.15.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,140 @@
1
+ import collections
2
+
3
+ import pytest
4
+
5
+ import sinter
6
+ from sinter._data._task_stats import _is_equal_json_values
7
+
8
+
9
+ def test_repr():
10
+ v = sinter.TaskStats(
11
+ strong_id='test',
12
+ json_metadata={'a': [1, 2, 3]},
13
+ decoder='pymatching',
14
+ shots=22,
15
+ errors=3,
16
+ discards=4,
17
+ seconds=5,
18
+ )
19
+ assert eval(repr(v), {"sinter": sinter}) == v
20
+
21
+
22
+ def test_to_csv_line():
23
+ v = sinter.TaskStats(
24
+ strong_id='test',
25
+ json_metadata={'a': [1, 2, 3]},
26
+ decoder='pymatching',
27
+ shots=22,
28
+ errors=3,
29
+ discards=4,
30
+ seconds=5,
31
+ )
32
+ assert v.to_csv_line() == str(v) == ' 22, 3, 4, 5,pymatching,test,"{""a"":[1,2,3]}",'
33
+
34
+
35
+ def test_to_anon_stats():
36
+ v = sinter.TaskStats(
37
+ strong_id='test',
38
+ json_metadata={'a': [1, 2, 3]},
39
+ decoder='pymatching',
40
+ shots=22,
41
+ errors=3,
42
+ discards=4,
43
+ seconds=5,
44
+ )
45
+ assert v.to_anon_stats() == sinter.AnonTaskStats(shots=22, errors=3, discards=4, seconds=5)
46
+
47
+
48
+ def test_add():
49
+ a = sinter.TaskStats(
50
+ decoder='pymatching',
51
+ json_metadata={'a': 2},
52
+ strong_id='abcdef',
53
+ shots=220,
54
+ errors=30,
55
+ discards=40,
56
+ seconds=50,
57
+ custom_counts=collections.Counter({'a': 10, 'b': 20}),
58
+ )
59
+ b = sinter.TaskStats(
60
+ decoder='pymatching',
61
+ json_metadata={'a': 2},
62
+ strong_id='abcdef',
63
+ shots=50,
64
+ errors=4,
65
+ discards=3,
66
+ seconds=2,
67
+ custom_counts=collections.Counter({'a': 1, 'c': 3}),
68
+ )
69
+ c = sinter.TaskStats(
70
+ decoder='pymatching',
71
+ json_metadata={'a': 2},
72
+ strong_id='abcdef',
73
+ shots=270,
74
+ errors=34,
75
+ discards=43,
76
+ seconds=52,
77
+ custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}),
78
+ )
79
+ assert a + b == c
80
+ with pytest.raises(ValueError):
81
+ a + sinter.TaskStats(
82
+ decoder='pymatching',
83
+ json_metadata={'a': 2},
84
+ strong_id='abcdefDIFFERENT',
85
+ shots=270,
86
+ errors=34,
87
+ discards=43,
88
+ seconds=52,
89
+ custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}),
90
+ )
91
+
92
+
93
+ def test_with_edits():
94
+ v = sinter.TaskStats(
95
+ decoder='pymatching',
96
+ json_metadata={'a': 2},
97
+ strong_id='abcdefDIFFERENT',
98
+ shots=270,
99
+ errors=34,
100
+ discards=43,
101
+ seconds=52,
102
+ custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}),
103
+ )
104
+ assert v.with_edits(json_metadata={'b': 3}) == sinter.TaskStats(
105
+ decoder='pymatching',
106
+ json_metadata={'b': 3},
107
+ strong_id='abcdefDIFFERENT',
108
+ shots=270,
109
+ errors=34,
110
+ discards=43,
111
+ seconds=52,
112
+ custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}),
113
+ )
114
+ assert v == sinter.TaskStats(strong_id='', json_metadata={}, decoder='').with_edits(
115
+ decoder='pymatching',
116
+ json_metadata={'a': 2},
117
+ strong_id='abcdefDIFFERENT',
118
+ shots=270,
119
+ errors=34,
120
+ discards=43,
121
+ seconds=52,
122
+ custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}),
123
+ )
124
+
125
+
126
+ def test_is_equal_json_values():
127
+ assert _is_equal_json_values([1, 2], (1, 2))
128
+ assert _is_equal_json_values([1, [3, (5, 6)]], (1, (3, [5, 6])))
129
+ assert not _is_equal_json_values([1, [3, (5, 6)]], (1, (3, [5, 7])))
130
+ assert not _is_equal_json_values([1, [3, (5, 6)]], (1, (3, [5])))
131
+ assert not _is_equal_json_values([1, 2], (1, 3))
132
+ assert not _is_equal_json_values([1, 2], {1, 2})
133
+ assert _is_equal_json_values({'x': [1, 2]}, {'x': (1, 2)})
134
+ assert _is_equal_json_values({'x': (1, 2)}, {'x': (1, 2)})
135
+ assert not _is_equal_json_values({'x': (1, 2)}, {'y': (1, 2)})
136
+ assert not _is_equal_json_values({'x': (1, 2)}, {'x': (1, 2), 'y': []})
137
+ assert not _is_equal_json_values({'x': (1, 2), 'y': []}, {'x': (1, 2)})
138
+ assert not _is_equal_json_values({'x': (1, 2)}, {'x': (1, 3)})
139
+ assert not _is_equal_json_values(1, 2)
140
+ assert _is_equal_json_values(1, 1)
@@ -0,0 +1,38 @@
1
+ import numpy as np
2
+
3
+ import stim
4
+
5
+ import sinter
6
+
7
+
8
+ def test_repr():
9
+ circuit = stim.Circuit("""
10
+ X_ERROR(0.1) 0 1 2
11
+ M 0 1 2
12
+ DETECTOR rec[-1] rec[-2]
13
+ DETECTOR rec[-2] rec[-3]
14
+ OBSERVABLE_INCLUDE(0) rec[-1]
15
+ """)
16
+ v = sinter.Task(circuit=circuit)
17
+ assert eval(repr(v), {"stim": stim, "sinter": sinter, "np": np}) == v
18
+
19
+ v = sinter.Task(circuit=circuit, detector_error_model=circuit.detector_error_model())
20
+ assert eval(repr(v), {"stim": stim, "sinter": sinter, "np": np}) == v
21
+
22
+ v = sinter.Task(circuit=circuit, postselection_mask=np.array([1], dtype=np.uint8))
23
+ assert eval(repr(v), {"stim": stim, "sinter": sinter, "np": np}) == v
24
+
25
+ v = sinter.Task(circuit=circuit, postselection_mask=np.array([2], dtype=np.uint8))
26
+ assert eval(repr(v), {"stim": stim, "sinter": sinter, "np": np}) == v
27
+
28
+ v = sinter.Task(circuit=circuit, postselected_observables_mask=np.array([1], dtype=np.uint8))
29
+ assert eval(repr(v), {"stim": stim, "sinter": sinter, "np": np}) == v
30
+
31
+ v = sinter.Task(circuit=circuit, collection_options=sinter.CollectionOptions(max_shots=10))
32
+ assert eval(repr(v), {"stim": stim, "sinter": sinter, "np": np}) == v
33
+
34
+ v = sinter.Task(circuit=circuit, json_metadata={'a': 5})
35
+ assert eval(repr(v), {"stim": stim, "sinter": sinter, "np": np}) == v
36
+
37
+ v = sinter.Task(circuit=circuit, decoder='pymatching')
38
+ assert eval(repr(v), {"stim": stim, "sinter": sinter, "np": np}) == v
@@ -0,0 +1,16 @@
1
+ from sinter._decoding._decoding import (
2
+ streaming_post_select,
3
+ sample_decode,
4
+ )
5
+ from sinter._decoding._decoding_decoder_class import (
6
+ CompiledDecoder,
7
+ Decoder,
8
+ )
9
+ from sinter._decoding._decoding_all_built_in_decoders import (
10
+ BUILT_IN_DECODERS,
11
+ BUILT_IN_SAMPLERS,
12
+ )
13
+ from sinter._decoding._sampler import (
14
+ Sampler,
15
+ CompiledSampler,
16
+ )
@@ -0,0 +1,419 @@
1
+ import collections
2
+ from typing import Iterable
3
+ from typing import Optional, Dict, Tuple, TYPE_CHECKING, Union
4
+
5
+ import contextlib
6
+ import pathlib
7
+ import tempfile
8
+ import math
9
+ import time
10
+
11
+ import numpy as np
12
+ import stim
13
+
14
+ from sinter._data import AnonTaskStats
15
+ from sinter._decoding._decoding_all_built_in_decoders import BUILT_IN_DECODERS
16
+ from sinter._decoding._decoding_decoder_class import CompiledDecoder, Decoder
17
+
18
+ if TYPE_CHECKING:
19
+ import sinter
20
+
21
+
22
+ def streaming_post_select(*,
23
+ num_dets: int,
24
+ num_obs: int,
25
+ dets_in_b8: pathlib.Path,
26
+ obs_in_b8: Optional[pathlib.Path],
27
+ dets_out_b8: pathlib.Path,
28
+ obs_out_b8: Optional[pathlib.Path],
29
+ discards_out_b8: Optional[pathlib.Path],
30
+ num_shots: int,
31
+ post_mask: np.ndarray) -> int:
32
+ if post_mask.shape != ((num_dets + 7) // 8,):
33
+ raise ValueError(f"post_mask.shape={post_mask.shape} != (math.ceil(num_detectors / 8),)")
34
+ if post_mask.dtype != np.uint8:
35
+ raise ValueError(f"post_mask.dtype={post_mask.dtype} != np.uint8")
36
+ assert (obs_in_b8 is None) == (obs_out_b8 is None)
37
+
38
+ num_det_bytes = math.ceil(num_dets / 8)
39
+ num_obs_bytes = math.ceil(num_obs / 8)
40
+ num_shots_left = num_shots
41
+ num_discards = 0
42
+
43
+ with contextlib.ExitStack() as ctx:
44
+ dets_in_f = ctx.enter_context(open(dets_in_b8, 'rb'))
45
+ dets_out_f = ctx.enter_context(open(dets_out_b8, 'wb'))
46
+ if obs_in_b8 is not None and obs_out_b8 is not None:
47
+ obs_in_f = ctx.enter_context(open(obs_in_b8, 'rb'))
48
+ obs_out_f = ctx.enter_context(open(obs_out_b8, 'wb'))
49
+ else:
50
+ obs_in_f = None
51
+ obs_out_f = None
52
+ if discards_out_b8 is not None:
53
+ discards_out_f = ctx.enter_context(open(discards_out_b8, 'wb'))
54
+ else:
55
+ discards_out_f = None
56
+
57
+ while num_shots_left:
58
+ batch_size = min(num_shots_left, math.ceil(10 ** 6 / max(1, num_dets)))
59
+
60
+ det_batch = np.fromfile(dets_in_f, dtype=np.uint8, count=num_det_bytes * batch_size)
61
+ det_batch.shape = (batch_size, num_det_bytes)
62
+ discarded = np.any(det_batch & post_mask, axis=1)
63
+ det_left = det_batch[~discarded, :]
64
+ det_left.tofile(dets_out_f)
65
+
66
+ if obs_in_f is not None and obs_out_f is not None:
67
+ obs_batch = np.fromfile(obs_in_f, dtype=np.uint8, count=num_obs_bytes * batch_size)
68
+ obs_batch.shape = (batch_size, num_obs_bytes)
69
+ obs_left = obs_batch[~discarded, :]
70
+ obs_left.tofile(obs_out_f)
71
+ if discards_out_f is not None:
72
+ discarded.tofile(discards_out_f)
73
+
74
+ num_discards += np.count_nonzero(discarded)
75
+ num_shots_left -= batch_size
76
+
77
+ return num_discards
78
+
79
+
80
+ def _streaming_count_mistakes(
81
+ *,
82
+ num_shots: int,
83
+ num_obs: int,
84
+ num_det: int,
85
+ postselected_observable_mask: Optional[np.ndarray] = None,
86
+ dets_in: pathlib.Path,
87
+ obs_in: pathlib.Path,
88
+ predictions_in: pathlib.Path,
89
+ count_detection_events: bool,
90
+ count_observable_error_combos: bool,
91
+ ) -> Tuple[int, int, collections.Counter]:
92
+
93
+ num_det_bytes = math.ceil(num_det / 8)
94
+ num_obs_bytes = math.ceil(num_obs / 8)
95
+ num_errors = 0
96
+ num_discards = 0
97
+ custom_counts = collections.Counter()
98
+ if count_detection_events:
99
+ with open(dets_in, 'rb') as dets_in_f:
100
+ num_shots_left = num_shots
101
+ while num_shots_left:
102
+ batch_size = min(num_shots_left, math.ceil(10**6 / max(num_obs, 1)))
103
+ det_data = np.fromfile(dets_in_f, dtype=np.uint8, count=num_det_bytes * batch_size)
104
+ for b in range(8):
105
+ custom_counts['detection_events'] += np.count_nonzero(det_data & (1 << b))
106
+ num_shots_left -= batch_size
107
+ custom_counts['detectors_checked'] += num_shots * num_det
108
+
109
+ with open(obs_in, 'rb') as obs_in_f:
110
+ with open(predictions_in, 'rb') as predictions_in_f:
111
+ num_shots_left = num_shots
112
+ while num_shots_left:
113
+ batch_size = min(num_shots_left, math.ceil(10**6 / max(num_obs, 1)))
114
+
115
+ obs_batch = np.fromfile(obs_in_f, dtype=np.uint8, count=num_obs_bytes * batch_size)
116
+ pred_batch = np.fromfile(predictions_in_f, dtype=np.uint8, count=num_obs_bytes * batch_size)
117
+ obs_batch.shape = (batch_size, num_obs_bytes)
118
+ pred_batch.shape = (batch_size, num_obs_bytes)
119
+
120
+ cmp_table = pred_batch ^ obs_batch
121
+ err_mask = np.any(cmp_table, axis=1)
122
+ if postselected_observable_mask is not None:
123
+ discard_mask = np.any(cmp_table & postselected_observable_mask, axis=1)
124
+ err_mask &= ~discard_mask
125
+ num_discards += np.count_nonzero(discard_mask)
126
+
127
+ if count_observable_error_combos:
128
+ for misprediction_arr in cmp_table[err_mask]:
129
+ err_key = "obs_mistake_mask=" + ''.join('_E'[b] for b in np.unpackbits(misprediction_arr, count=num_obs, bitorder='little'))
130
+ custom_counts[err_key] += 1
131
+
132
+ num_errors += np.count_nonzero(err_mask)
133
+ num_shots_left -= batch_size
134
+ return num_discards, num_errors, custom_counts
135
+
136
+
137
+ def sample_decode(*,
138
+ circuit_obj: Optional[stim.Circuit],
139
+ circuit_path: Union[None, str, pathlib.Path],
140
+ dem_obj: Optional[stim.DetectorErrorModel],
141
+ dem_path: Union[None, str, pathlib.Path],
142
+ post_mask: Optional[np.ndarray] = None,
143
+ postselected_observable_mask: Optional[np.ndarray] = None,
144
+ count_observable_error_combos: bool = False,
145
+ count_detection_events: bool = False,
146
+ num_shots: int,
147
+ decoder: str,
148
+ tmp_dir: Union[str, pathlib.Path, None] = None,
149
+ custom_decoders: Optional[Dict[str, 'sinter.Decoder']] = None,
150
+ __private__unstable__force_decode_on_disk: Optional[bool] = None,
151
+ ) -> AnonTaskStats:
152
+ """Samples how many times a decoder correctly predicts the logical frame.
153
+
154
+ Args:
155
+ circuit_obj: The noisy circuit to sample from and decode results for.
156
+ Must specify circuit_obj XOR circuit_path.
157
+ circuit_path: The file storing the circuit to sample from.
158
+ Must specify circuit_obj XOR circuit_path.
159
+ dem_obj: The error model to give to the decoder.
160
+ Must specify dem_obj XOR dem_path.
161
+ dem_path: The file storing the error model to give to the decoder.
162
+ Must specify dem_obj XOR dem_path.
163
+ post_mask: Postselection mask. Any samples that have a non-zero result
164
+ at a location where the mask has a 1 bit are discarded. If set to
165
+ None, no postselection is performed.
166
+ postselected_observable_mask: Bit packed mask indicating which observables to
167
+ postselect on. If the decoder incorrectly predicts any of these observables, the
168
+ shot is discarded instead of counted as an error.
169
+ count_observable_error_combos: Defaults to False. When set to to True,
170
+ the returned AnonTaskStats will have a custom counts field with keys
171
+ like `obs_mistake_mask=E_E__` counting how many times specific
172
+ combinations of observables were mispredicted by the decoder.
173
+ count_detection_events: Defaults to False. When set to True, the
174
+ returned AnonTaskStats will have a custom counts field withs the
175
+ key `detection_events` counting the number of times a detector fired
176
+ and also `detectors_checked` counting the number of detectors that
177
+ were executed. The detection fraction is the ratio of these two
178
+ numbers.
179
+ num_shots: The number of sample shots to take from the circuit.
180
+ decoder: The name of the decoder to use. Allowed values are:
181
+ "pymatching":
182
+ Use pymatching min-weight-perfect-match decoder.
183
+ "internal":
184
+ Use internal decoder with uncorrelated decoding.
185
+ "internal_correlated":
186
+ Use internal decoder with correlated decoding.
187
+ tmp_dir: An existing directory that is currently empty where temporary
188
+ files can be written as part of performing decoding. If set to
189
+ None, one is created using the tempfile package.
190
+ custom_decoders: Custom decoders that can be used if requested by name.
191
+ If not specified, only decoders built into sinter, such as
192
+ 'pymatching' and 'fusion_blossom', can be used.
193
+ """
194
+ if (circuit_obj is None) == (circuit_path is None):
195
+ raise ValueError('(circuit_obj is None) == (circuit_path is None)')
196
+ if (dem_obj is None) == (dem_path is None):
197
+ raise ValueError('(dem_obj is None) == (dem_path is None)')
198
+ if num_shots == 0:
199
+ return AnonTaskStats()
200
+
201
+ decoder_obj: Optional[Decoder] = None
202
+ if custom_decoders is not None:
203
+ decoder_obj = custom_decoders.get(decoder)
204
+ if decoder_obj is None:
205
+ decoder_obj = BUILT_IN_DECODERS.get(decoder)
206
+ if decoder_obj is None:
207
+ raise NotImplementedError(f"Unrecognized decoder: {decoder!r}")
208
+
209
+ dem: stim.DetectorErrorModel
210
+ if dem_obj is None:
211
+ dem = stim.DetectorErrorModel.from_file(dem_path)
212
+ else:
213
+ dem = dem_obj
214
+
215
+ circuit: stim.Circuit
216
+ if circuit_path is not None:
217
+ circuit = stim.Circuit.from_file(circuit_path)
218
+ else:
219
+ circuit = circuit_obj
220
+
221
+ start_time = time.monotonic()
222
+ try:
223
+ if __private__unstable__force_decode_on_disk:
224
+ raise NotImplementedError()
225
+ compiled_decoder = decoder_obj.compile_decoder_for_dem(dem=dem)
226
+ return _sample_decode_helper_using_memory(
227
+ circuit=circuit,
228
+ post_mask=post_mask,
229
+ postselected_observable_mask=postselected_observable_mask,
230
+ compiled_decoder=compiled_decoder,
231
+ total_num_shots=num_shots,
232
+ num_det=circuit.num_detectors,
233
+ mini_batch_size=1024,
234
+ start_time_monotonic=start_time,
235
+ num_obs=circuit.num_observables,
236
+ count_observable_error_combos=count_observable_error_combos,
237
+ count_detection_events=count_detection_events,
238
+ )
239
+ except NotImplementedError:
240
+ assert __private__unstable__force_decode_on_disk or __private__unstable__force_decode_on_disk is None
241
+ pass
242
+ return _sample_decode_helper_using_disk(
243
+ circuit=circuit,
244
+ dem=dem,
245
+ dem_path=dem_path,
246
+ post_mask=post_mask,
247
+ postselected_observable_mask=postselected_observable_mask,
248
+ num_shots=num_shots,
249
+ decoder_obj=decoder_obj,
250
+ tmp_dir=tmp_dir,
251
+ start_time_monotonic=start_time,
252
+ count_observable_error_combos=count_observable_error_combos,
253
+ count_detection_events=count_detection_events,
254
+ )
255
+
256
+
257
+ def _sample_decode_helper_using_memory(
258
+ *,
259
+ circuit: stim.Circuit,
260
+ post_mask: Optional[np.ndarray],
261
+ postselected_observable_mask: Optional[np.ndarray],
262
+ num_obs: int,
263
+ num_det: int,
264
+ total_num_shots: int,
265
+ mini_batch_size: int,
266
+ compiled_decoder: CompiledDecoder,
267
+ start_time_monotonic: float,
268
+ count_observable_error_combos: bool,
269
+ count_detection_events: bool,
270
+ ) -> AnonTaskStats:
271
+ sampler: stim.CompiledDetectorSampler = circuit.compile_detector_sampler()
272
+
273
+ out_num_discards = 0
274
+ out_num_errors = 0
275
+ shots_left = total_num_shots
276
+ custom_counts = collections.Counter()
277
+ while shots_left > 0:
278
+ cur_num_shots = min(shots_left, mini_batch_size)
279
+ dets_data, obs_data = sampler.sample(shots=cur_num_shots, separate_observables=True, bit_packed=True)
280
+
281
+ # Discard any shots that contain a postselected detection events.
282
+ if post_mask is not None:
283
+ discarded_flags = np.any(dets_data & post_mask, axis=1)
284
+ cur_num_discarded_shots = np.count_nonzero(discarded_flags)
285
+ if cur_num_discarded_shots:
286
+ out_num_discards += cur_num_discarded_shots
287
+ dets_data = dets_data[~discarded_flags, :]
288
+ obs_data = obs_data[~discarded_flags, :]
289
+
290
+ # Have the decoder predict which observables are flipped.
291
+ predict_data = compiled_decoder.decode_shots_bit_packed(bit_packed_detection_event_data=dets_data)
292
+
293
+ # Discard any shots where the decoder predicts a flipped postselected observable.
294
+ if postselected_observable_mask is not None:
295
+ discarded_flags = np.any(postselected_observable_mask & (predict_data ^ obs_data), axis=1)
296
+ cur_num_discarded_shots = np.count_nonzero(discarded_flags)
297
+ if cur_num_discarded_shots:
298
+ out_num_discards += cur_num_discarded_shots
299
+ obs_data = obs_data[~discarded_flags, :]
300
+ predict_data = predict_data[~discarded_flags, :]
301
+
302
+ # Count how many mistakes the decoder made on non-discarded shots.
303
+ mispredictions = obs_data ^ predict_data
304
+ err_mask = np.any(mispredictions, axis=1)
305
+ if count_detection_events:
306
+ for b in range(8):
307
+ custom_counts['detection_events'] += np.count_nonzero(dets_data & (1 << b))
308
+ if count_observable_error_combos:
309
+ for misprediction_arr in mispredictions[err_mask]:
310
+ err_key = "obs_mistake_mask=" + ''.join('_E'[b] for b in np.unpackbits(misprediction_arr, count=num_obs, bitorder='little'))
311
+ custom_counts[err_key] += 1
312
+ out_num_errors += np.count_nonzero(err_mask)
313
+ shots_left -= cur_num_shots
314
+
315
+ if count_detection_events:
316
+ custom_counts['detectors_checked'] += num_det * total_num_shots
317
+ return AnonTaskStats(
318
+ shots=total_num_shots,
319
+ errors=out_num_errors,
320
+ discards=out_num_discards,
321
+ seconds=time.monotonic() - start_time_monotonic,
322
+ custom_counts=custom_counts,
323
+ )
324
+
325
+
326
+ def _sample_decode_helper_using_disk(
327
+ *,
328
+ circuit: stim.Circuit,
329
+ dem: stim.DetectorErrorModel,
330
+ dem_path: Union[str, pathlib.Path],
331
+ post_mask: Optional[np.ndarray],
332
+ postselected_observable_mask: Optional[np.ndarray],
333
+ num_shots: int,
334
+ decoder_obj: Decoder,
335
+ tmp_dir: Union[str, pathlib.Path, None],
336
+ start_time_monotonic: float,
337
+ count_observable_error_combos: bool,
338
+ count_detection_events: bool,
339
+ ) -> AnonTaskStats:
340
+ with contextlib.ExitStack() as exit_stack:
341
+ if tmp_dir is None:
342
+ tmp_dir = exit_stack.enter_context(tempfile.TemporaryDirectory())
343
+ tmp_dir = pathlib.Path(tmp_dir)
344
+ if dem_path is None:
345
+ dem_path = tmp_dir / 'tmp.dem'
346
+ dem.to_file(dem_path)
347
+ dem_path = pathlib.Path(dem_path)
348
+
349
+ dets_all_path = tmp_dir / 'sinter_dets.all.b8'
350
+ obs_all_path = tmp_dir / 'sinter_obs.all.b8'
351
+ dets_kept_path = tmp_dir / 'sinter_dets.kept.b8'
352
+ obs_kept_path = tmp_dir / 'sinter_obs.kept.b8'
353
+ predictions_path = tmp_dir / 'sinter_predictions.b8'
354
+
355
+ num_dets = circuit.num_detectors
356
+ num_obs = circuit.num_observables
357
+
358
+ # Sample data using Stim.
359
+ sampler: stim.CompiledDetectorSampler = circuit.compile_detector_sampler()
360
+ sampler.sample_write(
361
+ num_shots,
362
+ filepath=str(dets_all_path),
363
+ obs_out_filepath=str(obs_all_path),
364
+ format='b8',
365
+ obs_out_format='b8',
366
+ )
367
+
368
+ # Postselect, then split into detection event data and observable data.
369
+ if post_mask is None:
370
+ num_det_discards = 0
371
+ dets_used_path = dets_all_path
372
+ obs_used_path = obs_all_path
373
+ else:
374
+ num_det_discards = streaming_post_select(
375
+ num_shots=num_shots,
376
+ num_dets=num_dets,
377
+ num_obs=num_obs,
378
+ dets_in_b8=dets_all_path,
379
+ dets_out_b8=dets_kept_path,
380
+ obs_in_b8=obs_all_path,
381
+ obs_out_b8=obs_kept_path,
382
+ post_mask=post_mask,
383
+ discards_out_b8=None,
384
+ )
385
+ dets_used_path = dets_kept_path
386
+ obs_used_path = obs_kept_path
387
+ num_kept_shots = num_shots - num_det_discards
388
+
389
+ # Perform syndrome decoding to predict observables from detection events.
390
+ decoder_obj.decode_via_files(
391
+ num_shots=num_kept_shots,
392
+ num_dets=num_dets,
393
+ num_obs=num_obs,
394
+ dem_path=dem_path,
395
+ dets_b8_in_path=dets_used_path,
396
+ obs_predictions_b8_out_path=predictions_path,
397
+ tmp_dir=tmp_dir,
398
+ )
399
+
400
+ # Count how many predictions matched the actual observable data.
401
+ num_obs_discards, num_errors, custom_counts = _streaming_count_mistakes(
402
+ num_shots=num_kept_shots,
403
+ num_obs=num_obs,
404
+ num_det=num_dets,
405
+ dets_in=dets_all_path,
406
+ obs_in=obs_used_path,
407
+ predictions_in=predictions_path,
408
+ postselected_observable_mask=postselected_observable_mask,
409
+ count_detection_events=count_detection_events,
410
+ count_observable_error_combos=count_observable_error_combos,
411
+ )
412
+
413
+ return AnonTaskStats(
414
+ shots=num_shots,
415
+ errors=num_errors,
416
+ discards=num_obs_discards + num_det_discards,
417
+ seconds=time.monotonic() - start_time_monotonic,
418
+ custom_counts=custom_counts,
419
+ )
@@ -0,0 +1,25 @@
1
+ from typing import Dict
2
+ from typing import Union
3
+
4
+ from sinter._decoding._decoding_decoder_class import Decoder
5
+ from sinter._decoding._decoding_fusion_blossom import FusionBlossomDecoder
6
+ from sinter._decoding._decoding_pymatching import PyMatchingDecoder
7
+ from sinter._decoding._decoding_vacuous import VacuousDecoder
8
+ from sinter._decoding._perfectionist_sampler import PerfectionistSampler
9
+ from sinter._decoding._sampler import Sampler
10
+ from sinter._decoding._decoding_mwpf import HyperUFDecoder, MwpfDecoder
11
+
12
+ BUILT_IN_DECODERS: Dict[str, Decoder] = {
13
+ 'vacuous': VacuousDecoder(),
14
+ 'pymatching': PyMatchingDecoder(),
15
+ 'fusion_blossom': FusionBlossomDecoder(),
16
+ # an implementation of (weighted) hypergraph UF decoder (https://arxiv.org/abs/2103.08049)
17
+ 'hypergraph_union_find': HyperUFDecoder(),
18
+ # Minimum-Weight Parity Factor using similar primal-dual method the blossom algorithm (https://pypi.org/project/mwpf/)
19
+ 'mw_parity_factor': MwpfDecoder(),
20
+ }
21
+
22
+ BUILT_IN_SAMPLERS: Dict[str, Union[Decoder, Sampler]] = {
23
+ **BUILT_IN_DECODERS,
24
+ 'perfectionist': PerfectionistSampler(),
25
+ }