sinter 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sinter might be problematic. Click here for more details.
- sinter/__init__.py +47 -0
- sinter/_collection/__init__.py +10 -0
- sinter/_collection/_collection.py +480 -0
- sinter/_collection/_collection_manager.py +581 -0
- sinter/_collection/_collection_manager_test.py +287 -0
- sinter/_collection/_collection_test.py +317 -0
- sinter/_collection/_collection_worker_loop.py +35 -0
- sinter/_collection/_collection_worker_state.py +259 -0
- sinter/_collection/_collection_worker_test.py +222 -0
- sinter/_collection/_mux_sampler.py +56 -0
- sinter/_collection/_printer.py +65 -0
- sinter/_collection/_sampler_ramp_throttled.py +66 -0
- sinter/_collection/_sampler_ramp_throttled_test.py +144 -0
- sinter/_command/__init__.py +0 -0
- sinter/_command/_main.py +39 -0
- sinter/_command/_main_collect.py +350 -0
- sinter/_command/_main_collect_test.py +482 -0
- sinter/_command/_main_combine.py +84 -0
- sinter/_command/_main_combine_test.py +153 -0
- sinter/_command/_main_plot.py +817 -0
- sinter/_command/_main_plot_test.py +445 -0
- sinter/_command/_main_predict.py +75 -0
- sinter/_command/_main_predict_test.py +36 -0
- sinter/_data/__init__.py +20 -0
- sinter/_data/_anon_task_stats.py +89 -0
- sinter/_data/_anon_task_stats_test.py +35 -0
- sinter/_data/_collection_options.py +106 -0
- sinter/_data/_collection_options_test.py +24 -0
- sinter/_data/_csv_out.py +74 -0
- sinter/_data/_existing_data.py +173 -0
- sinter/_data/_existing_data_test.py +41 -0
- sinter/_data/_task.py +311 -0
- sinter/_data/_task_stats.py +244 -0
- sinter/_data/_task_stats_test.py +140 -0
- sinter/_data/_task_test.py +38 -0
- sinter/_decoding/__init__.py +16 -0
- sinter/_decoding/_decoding.py +419 -0
- sinter/_decoding/_decoding_all_built_in_decoders.py +25 -0
- sinter/_decoding/_decoding_decoder_class.py +161 -0
- sinter/_decoding/_decoding_fusion_blossom.py +193 -0
- sinter/_decoding/_decoding_mwpf.py +302 -0
- sinter/_decoding/_decoding_pymatching.py +81 -0
- sinter/_decoding/_decoding_test.py +480 -0
- sinter/_decoding/_decoding_vacuous.py +38 -0
- sinter/_decoding/_perfectionist_sampler.py +38 -0
- sinter/_decoding/_sampler.py +72 -0
- sinter/_decoding/_stim_then_decode_sampler.py +222 -0
- sinter/_decoding/_stim_then_decode_sampler_test.py +192 -0
- sinter/_plotting.py +619 -0
- sinter/_plotting_test.py +108 -0
- sinter/_predict.py +381 -0
- sinter/_predict_test.py +227 -0
- sinter/_probability_util.py +519 -0
- sinter/_probability_util_test.py +281 -0
- sinter-1.15.0.data/data/README.md +332 -0
- sinter-1.15.0.data/data/readme_example_plot.png +0 -0
- sinter-1.15.0.data/data/requirements.txt +4 -0
- sinter-1.15.0.dist-info/METADATA +354 -0
- sinter-1.15.0.dist-info/RECORD +62 -0
- sinter-1.15.0.dist-info/WHEEL +5 -0
- sinter-1.15.0.dist-info/entry_points.txt +2 -0
- sinter-1.15.0.dist-info/top_level.txt +1 -0
sinter/__init__.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
__version__ = '1.15.0'
|
|
2
|
+
|
|
3
|
+
from sinter._collection import (
|
|
4
|
+
collect,
|
|
5
|
+
iter_collect,
|
|
6
|
+
post_selection_mask_from_4th_coord,
|
|
7
|
+
Progress,
|
|
8
|
+
)
|
|
9
|
+
from sinter._data import (
|
|
10
|
+
AnonTaskStats,
|
|
11
|
+
CollectionOptions,
|
|
12
|
+
CSV_HEADER,
|
|
13
|
+
read_stats_from_csv_files,
|
|
14
|
+
stats_from_csv_files,
|
|
15
|
+
Task,
|
|
16
|
+
TaskStats,
|
|
17
|
+
)
|
|
18
|
+
from sinter._decoding import (
|
|
19
|
+
CompiledDecoder,
|
|
20
|
+
Decoder,
|
|
21
|
+
BUILT_IN_DECODERS,
|
|
22
|
+
BUILT_IN_SAMPLERS,
|
|
23
|
+
Sampler,
|
|
24
|
+
CompiledSampler,
|
|
25
|
+
)
|
|
26
|
+
from sinter._probability_util import (
|
|
27
|
+
comma_separated_key_values,
|
|
28
|
+
Fit,
|
|
29
|
+
fit_binomial,
|
|
30
|
+
fit_line_slope,
|
|
31
|
+
fit_line_y_at_x,
|
|
32
|
+
log_binomial,
|
|
33
|
+
log_factorial,
|
|
34
|
+
shot_error_rate_to_piece_error_rate,
|
|
35
|
+
)
|
|
36
|
+
from sinter._plotting import (
|
|
37
|
+
better_sorted_str_terms,
|
|
38
|
+
plot_discard_rate,
|
|
39
|
+
plot_error_rate,
|
|
40
|
+
group_by,
|
|
41
|
+
)
|
|
42
|
+
from sinter._predict import (
|
|
43
|
+
predict_discards_bit_packed,
|
|
44
|
+
predict_observables_bit_packed,
|
|
45
|
+
predict_on_disk,
|
|
46
|
+
predict_observables,
|
|
47
|
+
)
|
|
sinter/_collection/_collection.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import dataclasses
|
|
3
|
+
import pathlib
|
|
4
|
+
from typing import Any, Callable, Iterator, Optional, Union, Iterable, List, TYPE_CHECKING, Tuple, Dict
|
|
5
|
+
|
|
6
|
+
import math
|
|
7
|
+
import numpy as np
|
|
8
|
+
import stim
|
|
9
|
+
|
|
10
|
+
from sinter._data import CSV_HEADER, ExistingData, TaskStats, CollectionOptions, Task
|
|
11
|
+
from sinter._collection._collection_manager import CollectionManager
|
|
12
|
+
from sinter._collection._printer import ThrottledProgressPrinter
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
import sinter
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclasses.dataclass(frozen=True)
class Progress:
    """A progress update emitted while statistics are being collected.

    Instances of this class are yielded by `sinter.iter_collect`, and are
    passed to the `progress_callback` argument of `sinter.collect`.

    Attributes:
        new_stats: Statistics sampled since the previous progress update.
            May be empty (e.g. for pure status updates).
        status_message: Free-form, human readable text describing the state
            of the collection, such as how many tasks remain and the
            estimated time to completion of each task.
    """

    # Statistics gathered since the last update (possibly an empty tuple).
    new_stats: Tuple[TaskStats, ...]
    # Human readable description of the current collection status.
    status_message: str
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def iter_collect(*,
                 num_workers: int,
                 tasks: Union[Iterator['sinter.Task'],
                              Iterable['sinter.Task']],
                 hint_num_tasks: Optional[int] = None,
                 additional_existing_data: Union[None, dict[str, 'TaskStats'], Iterable['TaskStats']] = None,
                 max_shots: Optional[int] = None,
                 max_errors: Optional[int] = None,
                 decoders: Optional[Iterable[str]] = None,
                 max_batch_seconds: Optional[int] = None,
                 max_batch_size: Optional[int] = None,
                 start_batch_size: Optional[int] = None,
                 count_observable_error_combos: bool = False,
                 count_detection_events: bool = False,
                 custom_decoders: Optional[Dict[str, Union['sinter.Decoder', 'sinter.Sampler']]] = None,
                 custom_error_count_key: Optional[str] = None,
                 allowed_cpu_affinity_ids: Optional[Iterable[int]] = None,
                 ) -> Iterator['sinter.Progress']:
    """Iterates error correction statistics collected from worker processes.

    It is important to iterate until the sequence ends, or worker processes will
    be left alive. The values yielded during iteration are progress updates from
    the workers.

    Note: if max_batch_size and max_batch_seconds are both not used (or
    explicitly set to None), a default batch-size-limiting mechanism will be
    chosen.

    Args:
        num_workers: The number of worker processes to use.
        tasks: Decoding problems to sample.
        hint_num_tasks: If `tasks` is an iterator or a generator, its length
            can be given here so that progress printouts can say how many cases
            are left.
        additional_existing_data: Defaults to None (no additional data).
            Statistical data that has already been collected, in addition to
            anything included in each task's `previous_stats` field. May be
            an `ExistingData`, a dict of stats, or an iterable of stats.
        decoders: Defaults to None (specified by each Task). The names of the
            decoders to use on each Task. It must either be the case that each
            Task specifies a decoder and this is set to None, or this is an
            iterable and each Task has its decoder set to None. A single
            string is treated as a one-element list. Any iterable (including
            a one-shot iterator or generator) is accepted; it is read once
            and reused for every task.
        max_shots: Defaults to None (unused). Stops the sampling process
            after this many samples have been taken from the circuit.
        max_errors: Defaults to None (unused). Stops the sampling process
            after this many errors have been seen in samples taken from the
            circuit. The actual number of sampled errors may be larger due to
            batching.
        count_observable_error_combos: Defaults to False. When set to True,
            the returned stats will have a custom counts field with keys
            like `obs_mistake_mask=E_E__` counting how many times specific
            combinations of observables were mispredicted by the decoder.
        count_detection_events: Defaults to False. When set to True, the
            returned stats will have a custom counts field with the
            key `detection_events` counting the number of times a detector fired
            and also `detectors_checked` counting the number of detectors that
            were executed. The detection fraction is the ratio of these two
            numbers.
        start_batch_size: Defaults to None (collector's choice). The very
            first shots taken from the circuit will use a batch of this
            size, and no other batches will be taken in parallel. Once this
            initial fact finding batch is done, batches can be taken in
            parallel and the normal batch size limiting processes take over.
        max_batch_size: Defaults to None (unused). Limits batches from
            taking more than this many shots at once. For example, this can
            be used to ensure memory usage stays below some limit.
        max_batch_seconds: Defaults to None (unused). When set, the recorded
            data from previous shots is used to estimate how much time is
            taken per shot. This information is then used to predict the
            biggest batch size that can finish in under the given number of
            seconds. Limits each batch to be no larger than that.
        custom_decoders: Custom decoders that can be used if requested by name.
            If not specified, only decoders built into sinter, such as
            'pymatching' and 'fusion_blossom', can be used.
        custom_error_count_key: Makes `max_errors` apply to `stat.custom_counts[key]`
            instead of `stat.errors`.
        allowed_cpu_affinity_ids: Controls which CPUs the workers can be pinned to. The
            set of allowed IDs should be at least as large as the number of workers, though
            this is not strictly required. If not set, defaults to all CPUs being allowed.

    Yields:
        sinter.Progress instances recording incremental statistical data as it
        is collected by workers.

    Examples:
        >>> import sinter
        >>> import stim
        >>> tasks = [
        ...     sinter.Task(
        ...         circuit=stim.Circuit.generated(
        ...             'repetition_code:memory',
        ...             distance=5,
        ...             rounds=5,
        ...             before_round_data_depolarization=1e-3,
        ...         ),
        ...         json_metadata={'d': 5},
        ...     ),
        ...     sinter.Task(
        ...         circuit=stim.Circuit.generated(
        ...             'repetition_code:memory',
        ...             distance=7,
        ...             rounds=5,
        ...             before_round_data_depolarization=1e-3,
        ...         ),
        ...         json_metadata={'d': 7},
        ...     ),
        ... ]
        >>> iterator = sinter.iter_collect(
        ...     tasks=tasks,
        ...     decoders=['vacuous'],
        ...     num_workers=2,
        ...     max_shots=100,
        ... )
        >>> total_shots = 0
        >>> for progress in iterator:
        ...     for stat in progress.new_stats:
        ...         total_shots += stat.shots
        >>> print(total_shots)
        200
    """
    # Normalize the various accepted forms of existing data into a dict.
    existing_data: dict[str, TaskStats]
    if isinstance(additional_existing_data, ExistingData):
        existing_data = additional_existing_data.data
    elif isinstance(additional_existing_data, dict):
        existing_data = additional_existing_data
    elif additional_existing_data is None:
        existing_data = {}
    else:
        acc = ExistingData()
        for stat in additional_existing_data:
            acc.add_sample(stat)
        existing_data = acc.data

    if isinstance(decoders, str):
        decoders = [decoders]
    elif decoders is not None:
        # Materialize once: the task expansion below iterates `decoders` once
        # per task, so a one-shot iterator/generator would otherwise be
        # exhausted after the first task and later tasks would get no decoder.
        decoders = list(decoders)

    if hint_num_tasks is None:
        try:
            # noinspection PyTypeChecker
            hint_num_tasks = len(tasks)
        except TypeError:
            pass
    # NOTE(review): hint_num_tasks is not forwarded to CollectionManager below,
    # so within this function it currently has no effect. Confirm whether the
    # manager should receive it.

    if decoders is not None:
        old_tasks = tasks
        # Lazily expand each task into one task per decoder (tasks that
        # already name a decoder keep only that decoder).
        tasks = (
            Task(
                circuit=task.circuit,
                decoder=decoder,
                detector_error_model=task.detector_error_model,
                postselection_mask=task.postselection_mask,
                postselected_observables_mask=task.postselected_observables_mask,
                json_metadata=task.json_metadata,
                collection_options=task.collection_options,
                circuit_path=task.circuit_path,
            )
            for task in old_tasks
            for decoder in (decoders if task.decoder is None else [task.decoder])
        )

    # The manager reports stats through this callback; they are buffered here
    # and re-emitted as Progress values from the generator loop below.
    progress_log: list[Optional[TaskStats]] = []
    def log_progress(e: Optional[TaskStats]):
        progress_log.append(e)
    with CollectionManager(
            num_workers=num_workers,
            tasks=tasks,
            collection_options=CollectionOptions(
                max_shots=max_shots,
                max_errors=max_errors,
                max_batch_seconds=max_batch_seconds,
                start_batch_size=start_batch_size,
                max_batch_size=max_batch_size,
            ),
            existing_data=existing_data,
            count_observable_error_combos=count_observable_error_combos,
            count_detection_events=count_detection_events,
            custom_error_count_key=custom_error_count_key,
            custom_decoders=custom_decoders or {},
            allowed_cpu_affinity_ids=allowed_cpu_affinity_ids,
            worker_flush_period=max_batch_seconds or 120,
            progress_callback=log_progress,
    ) as manager:
        try:
            yield Progress(
                new_stats=(),
                status_message=f"Starting {num_workers} workers..."
            )
            manager.start_workers()
            manager.start_distributing_work()

            while manager.task_states:
                manager.process_message()
                if progress_log:
                    vals = list(progress_log)
                    progress_log.clear()
                    for e in vals:
                        # None entries carry no stats and are not surfaced.
                        if e is not None:
                            yield Progress(
                                new_stats=(e,),
                                status_message=manager.status_message(),
                            )

        except KeyboardInterrupt:
            yield Progress(
                new_stats=(),
                status_message='KeyboardInterrupt',
            )
            raise
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def collect(*,
            num_workers: int,
            tasks: Union[Iterator['sinter.Task'], Iterable['sinter.Task']],
            existing_data_filepaths: Iterable[Union[str, pathlib.Path]] = (),
            save_resume_filepath: Union[None, str, pathlib.Path] = None,
            progress_callback: Optional[Callable[['sinter.Progress'], None]] = None,
            max_shots: Optional[int] = None,
            max_errors: Optional[int] = None,
            count_observable_error_combos: bool = False,
            count_detection_events: bool = False,
            decoders: Optional[Iterable[str]] = None,
            max_batch_seconds: Optional[int] = None,
            max_batch_size: Optional[int] = None,
            start_batch_size: Optional[int] = None,
            print_progress: bool = False,
            hint_num_tasks: Optional[int] = None,
            custom_decoders: Optional[Dict[str, Union['sinter.Decoder', 'sinter.Sampler']]] = None,
            custom_error_count_key: Optional[str] = None,
            allowed_cpu_affinity_ids: Optional[Iterable[int]] = None,
            ) -> List['sinter.TaskStats']:
    """Collects statistics from the given tasks, using multiprocessing.

    Args:
        num_workers: The number of worker processes to use.
        tasks: Decoding problems to sample.
        save_resume_filepath: Defaults to None (unused). If set to a filepath,
            results will be saved to that file while they are collected. If the
            python interpreter is stopped or killed, calling this method again
            with the same save_resume_filepath will load the previous results
            from the file so it can resume where it left off. Missing parent
            directories are created.

            The stats in this file will be counted in addition to each task's
            previous_stats field (as opposed to overriding the field).
        existing_data_filepaths: CSV data saved to these files will be loaded,
            included in the returned results, and count towards things like
            max_shots and max_errors. Must not include save_resume_filepath.
        progress_callback: Defaults to None (unused). If specified, it is
            invoked with every `sinter.Progress` update produced during
            collection, including updates that carry no new stats (such as
            the initial startup message).
        hint_num_tasks: If `tasks` is an iterator or a generator, its length
            can be given here so that progress printouts can say how many cases
            are left.
        decoders: Defaults to None (specified by each Task). The names of the
            decoders to use on each Task. It must either be the case that each
            Task specifies a decoder and this is set to None, or this is an
            iterable and each Task has its decoder set to None.
        count_observable_error_combos: Defaults to False. When set to True,
            the returned stats will have a custom counts field with keys
            like `obs_mistake_mask=E_E__` counting how many times specific
            combinations of observables were mispredicted by the decoder.
        count_detection_events: Defaults to False. When set to True, the
            returned stats will have a custom counts field with the
            key `detection_events` counting the number of times a detector fired
            and also `detectors_checked` counting the number of detectors that
            were executed. The detection fraction is the ratio of these two
            numbers.
        max_shots: Defaults to None (unused). Stops the sampling process
            after this many samples have been taken from the circuit.
        max_errors: Defaults to None (unused). Stops the sampling process
            after this many errors have been seen in samples taken from the
            circuit. The actual number of sampled errors may be larger due to
            batching.
        start_batch_size: Defaults to None (collector's choice). The very
            first shots taken from the circuit will use a batch of this
            size, and no other batches will be taken in parallel. Once this
            initial fact finding batch is done, batches can be taken in
            parallel and the normal batch size limiting processes take over.
        max_batch_size: Defaults to None (unused). Limits batches from
            taking more than this many shots at once. For example, this can
            be used to ensure memory usage stays below some limit.
        print_progress: When True, progress is printed to stderr while
            collection runs.
        max_batch_seconds: Defaults to None (unused). When set, the recorded
            data from previous shots is used to estimate how much time is
            taken per shot. This information is then used to predict the
            biggest batch size that can finish in under the given number of
            seconds. Limits each batch to be no larger than that.
        custom_decoders: Named instances of `sinter.Decoder` (or
            `sinter.Sampler`) that can be used if requested by name by a task
            or by the decoders list. If not specified, only decoders with
            support built into sinter, such as 'pymatching' and
            'fusion_blossom', can be used.
        custom_error_count_key: Makes `max_errors` apply to `stat.custom_counts[key]`
            instead of `stat.errors`.
        allowed_cpu_affinity_ids: Controls which CPUs the workers can be pinned to. The
            set of allowed IDs should be at least as large as the number of workers, though
            this is not strictly required. If not set, defaults to all CPUs being allowed.

    Returns:
        A list of sample statistics, one from each problem. The list is not in
        any specific order. This is the same data that would have been written
        to a CSV file, but aggregated so that each problem has exactly one
        sample statistic instead of potentially multiple.

    Raises:
        ValueError: save_resume_filepath is also listed in
            existing_data_filepaths.

    Examples:
        >>> import sinter
        >>> import stim
        >>> tasks = [
        ...     sinter.Task(
        ...         circuit=stim.Circuit.generated(
        ...             'repetition_code:memory',
        ...             distance=5,
        ...             rounds=5,
        ...             before_round_data_depolarization=1e-3,
        ...         ),
        ...         json_metadata={'d': 5},
        ...     ),
        ...     sinter.Task(
        ...         circuit=stim.Circuit.generated(
        ...             'repetition_code:memory',
        ...             distance=7,
        ...             rounds=5,
        ...             before_round_data_depolarization=1e-3,
        ...         ),
        ...         json_metadata={'d': 7},
        ...     ),
        ... ]
        >>> stats = sinter.collect(
        ...     tasks=tasks,
        ...     decoders=['vacuous'],
        ...     num_workers=2,
        ...     max_shots=100,
        ... )
        >>> for stat in sorted(stats, key=lambda e: e.json_metadata['d']):
        ...     print(stat.json_metadata, stat.shots)
        {'d': 5} 100
        {'d': 7} 100
    """
    # Materialize the paths so they can be both validated and loaded (a
    # generator would otherwise be exhausted by whichever pass runs first).
    existing_data_filepaths = list(existing_data_filepaths)

    # Validate before loading anything. Compare as paths so that e.g.
    # 'out.csv' and pathlib.Path('out.csv') are recognized as the same file.
    if save_resume_filepath is not None:
        resume_path = pathlib.Path(save_resume_filepath)
        if any(resume_path == pathlib.Path(p) for p in existing_data_filepaths):
            raise ValueError("save_resume_filepath in existing_data_filepaths")

    # Load existing data.
    additional_existing_data = ExistingData()
    for existing in existing_data_filepaths:
        additional_existing_data += ExistingData.from_file(existing)

    progress_printer = ThrottledProgressPrinter(
        outs=[],
        print_progress=print_progress,
        min_progress_delay=0.1,
    )
    with contextlib.ExitStack() as exit_stack:
        # Open save/resume file.
        if save_resume_filepath is not None:
            save_resume_filepath = pathlib.Path(save_resume_filepath)
            if save_resume_filepath.exists():
                # Resume: previously saved stats count as existing data, and
                # new stats are appended after them.
                additional_existing_data += ExistingData.from_file(save_resume_filepath)
                save_resume_file = exit_stack.enter_context(
                        open(save_resume_filepath, 'a'))  # type: ignore
            else:
                save_resume_filepath.parent.mkdir(parents=True, exist_ok=True)
                save_resume_file = exit_stack.enter_context(
                        open(save_resume_filepath, 'w'))  # type: ignore
                print(CSV_HEADER, file=save_resume_file, flush=True)
        else:
            save_resume_file = None

        # Collect data.
        result = ExistingData()
        result.data = dict(additional_existing_data.data)
        for progress in iter_collect(
            num_workers=num_workers,
            max_shots=max_shots,
            max_errors=max_errors,
            max_batch_seconds=max_batch_seconds,
            start_batch_size=start_batch_size,
            max_batch_size=max_batch_size,
            count_observable_error_combos=count_observable_error_combos,
            count_detection_events=count_detection_events,
            decoders=decoders,
            tasks=tasks,
            hint_num_tasks=hint_num_tasks,
            additional_existing_data=additional_existing_data,
            custom_decoders=custom_decoders,
            custom_error_count_key=custom_error_count_key,
            allowed_cpu_affinity_ids=allowed_cpu_affinity_ids,
        ):
            for stats in progress.new_stats:
                result.add_sample(stats)
                if save_resume_file is not None:
                    # Flush per line so a killed process leaves a resumable file.
                    print(stats.to_csv_line(), file=save_resume_file, flush=True)
            if print_progress:
                progress_printer.show_latest_progress(progress.status_message)
            if progress_callback is not None:
                progress_callback(progress)
        if print_progress:
            progress_printer.flush()
        return list(result.data.values())
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
def post_selection_mask_from_predicate(
        circuit_or_dem: 'Union[stim.Circuit, stim.DetectorErrorModel]',
        *,
        postselected_detectors_predicate: Callable[[int, Any, Tuple[float, ...]], bool],
        metadata: Any,
) -> np.ndarray:
    """Returns a bit packed mask of the detectors accepted by a predicate.

    Args:
        circuit_or_dem: The circuit or detector error model to pull detector
            coordinate data from. Only its `num_detectors` attribute and its
            `get_detector_coordinates()` method are used.
        postselected_detectors_predicate: Called as
            `predicate(detector_index, metadata, coordinates)` for each entry
            returned by `circuit_or_dem.get_detector_coordinates()`. Detectors
            for which it returns a truthy value are marked in the mask.
        metadata: Passed through unchanged as the second argument of every
            predicate call.

    Returns:
        A uint8 numpy array with one byte per 8 detectors (rounded up). Bit
        `k % 8` of byte `k // 8` is set when detector `k` was accepted by the
        predicate.
    """
    num_dets = circuit_or_dem.num_detectors
    # Exact integer ceil-division; avoids float rounding of `num_dets / 8`.
    post_selection_mask = np.zeros(dtype=np.uint8, shape=(num_dets + 7) // 8)
    for k, coord in circuit_or_dem.get_detector_coordinates().items():
        if postselected_detectors_predicate(k, metadata, coord):
            post_selection_mask[k // 8] |= 1 << (k % 8)
    return post_selection_mask
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
def post_selection_mask_from_4th_coord(dem: 'Union[stim.Circuit, stim.DetectorErrorModel]') -> np.ndarray:
    """Returns a mask that postselects detectors with a non-zero 4th coordinate.

    This method is a leftover from before the existence of the command line
    argument `--postselected_detectors_predicate`, when
    `--postselect_detectors_with_non_zero_4th_coord` was the only way to do
    post selection of detectors.

    Args:
        dem: The circuit or detector error model to pull coordinate data from.

    Returns:
        A bit packed uint8 numpy array (one byte per 8 detectors, rounded up)
        where detectors with non-zero 4th coordinate data have a True bit at
        their corresponding index (bit `k % 8` of byte `k // 8`). Detectors
        with fewer than 4 coordinates are never selected.

    Examples:
        >>> import sinter
        >>> import stim
        >>> dem = stim.DetectorErrorModel('''
        ...     detector(1, 2, 3) D0
        ...     detector(1, 1, 1, 1) D1
        ...     detector(1, 1, 1, 0) D2
        ...     detector(1, 1, 1, 999) D80
        ... ''')
        >>> sinter.post_selection_mask_from_4th_coord(dem)
        array([2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], dtype=uint8)
    """
    num_dets = dem.num_detectors
    # Exact integer ceil-division; avoids float rounding of `num_dets / 8`.
    post_selection_mask = np.zeros(dtype=np.uint8, shape=(num_dets + 7) // 8)
    for k, coord in dem.get_detector_coordinates().items():
        if len(coord) >= 4 and coord[3]:
            post_selection_mask[k // 8] |= 1 << (k % 8)
    return post_selection_mask
|