sinter-1.15.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of sinter might be problematic.

Files changed (62)
  1. sinter/__init__.py +47 -0
  2. sinter/_collection/__init__.py +10 -0
  3. sinter/_collection/_collection.py +480 -0
  4. sinter/_collection/_collection_manager.py +581 -0
  5. sinter/_collection/_collection_manager_test.py +287 -0
  6. sinter/_collection/_collection_test.py +317 -0
  7. sinter/_collection/_collection_worker_loop.py +35 -0
  8. sinter/_collection/_collection_worker_state.py +259 -0
  9. sinter/_collection/_collection_worker_test.py +222 -0
  10. sinter/_collection/_mux_sampler.py +56 -0
  11. sinter/_collection/_printer.py +65 -0
  12. sinter/_collection/_sampler_ramp_throttled.py +66 -0
  13. sinter/_collection/_sampler_ramp_throttled_test.py +144 -0
  14. sinter/_command/__init__.py +0 -0
  15. sinter/_command/_main.py +39 -0
  16. sinter/_command/_main_collect.py +350 -0
  17. sinter/_command/_main_collect_test.py +482 -0
  18. sinter/_command/_main_combine.py +84 -0
  19. sinter/_command/_main_combine_test.py +153 -0
  20. sinter/_command/_main_plot.py +817 -0
  21. sinter/_command/_main_plot_test.py +445 -0
  22. sinter/_command/_main_predict.py +75 -0
  23. sinter/_command/_main_predict_test.py +36 -0
  24. sinter/_data/__init__.py +20 -0
  25. sinter/_data/_anon_task_stats.py +89 -0
  26. sinter/_data/_anon_task_stats_test.py +35 -0
  27. sinter/_data/_collection_options.py +106 -0
  28. sinter/_data/_collection_options_test.py +24 -0
  29. sinter/_data/_csv_out.py +74 -0
  30. sinter/_data/_existing_data.py +173 -0
  31. sinter/_data/_existing_data_test.py +41 -0
  32. sinter/_data/_task.py +311 -0
  33. sinter/_data/_task_stats.py +244 -0
  34. sinter/_data/_task_stats_test.py +140 -0
  35. sinter/_data/_task_test.py +38 -0
  36. sinter/_decoding/__init__.py +16 -0
  37. sinter/_decoding/_decoding.py +419 -0
  38. sinter/_decoding/_decoding_all_built_in_decoders.py +25 -0
  39. sinter/_decoding/_decoding_decoder_class.py +161 -0
  40. sinter/_decoding/_decoding_fusion_blossom.py +193 -0
  41. sinter/_decoding/_decoding_mwpf.py +302 -0
  42. sinter/_decoding/_decoding_pymatching.py +81 -0
  43. sinter/_decoding/_decoding_test.py +480 -0
  44. sinter/_decoding/_decoding_vacuous.py +38 -0
  45. sinter/_decoding/_perfectionist_sampler.py +38 -0
  46. sinter/_decoding/_sampler.py +72 -0
  47. sinter/_decoding/_stim_then_decode_sampler.py +222 -0
  48. sinter/_decoding/_stim_then_decode_sampler_test.py +192 -0
  49. sinter/_plotting.py +619 -0
  50. sinter/_plotting_test.py +108 -0
  51. sinter/_predict.py +381 -0
  52. sinter/_predict_test.py +227 -0
  53. sinter/_probability_util.py +519 -0
  54. sinter/_probability_util_test.py +281 -0
  55. sinter-1.15.0.data/data/README.md +332 -0
  56. sinter-1.15.0.data/data/readme_example_plot.png +0 -0
  57. sinter-1.15.0.data/data/requirements.txt +4 -0
  58. sinter-1.15.0.dist-info/METADATA +354 -0
  59. sinter-1.15.0.dist-info/RECORD +62 -0
  60. sinter-1.15.0.dist-info/WHEEL +5 -0
  61. sinter-1.15.0.dist-info/entry_points.txt +2 -0
  62. sinter-1.15.0.dist-info/top_level.txt +1 -0
sinter/__init__.py ADDED
@@ -0,0 +1,47 @@
+__version__ = '1.15.0'
+
+from sinter._collection import (
+    collect,
+    iter_collect,
+    post_selection_mask_from_4th_coord,
+    Progress,
+)
+from sinter._data import (
+    AnonTaskStats,
+    CollectionOptions,
+    CSV_HEADER,
+    read_stats_from_csv_files,
+    stats_from_csv_files,
+    Task,
+    TaskStats,
+)
+from sinter._decoding import (
+    CompiledDecoder,
+    Decoder,
+    BUILT_IN_DECODERS,
+    BUILT_IN_SAMPLERS,
+    Sampler,
+    CompiledSampler,
+)
+from sinter._probability_util import (
+    comma_separated_key_values,
+    Fit,
+    fit_binomial,
+    fit_line_slope,
+    fit_line_y_at_x,
+    log_binomial,
+    log_factorial,
+    shot_error_rate_to_piece_error_rate,
+)
+from sinter._plotting import (
+    better_sorted_str_terms,
+    plot_discard_rate,
+    plot_error_rate,
+    group_by,
+)
+from sinter._predict import (
+    predict_discards_bit_packed,
+    predict_observables_bit_packed,
+    predict_on_disk,
+    predict_observables,
+)
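
Everything above is re-exported at the top level, so downstream code uses the flat `sinter` namespace. As a minimal sketch (not part of the diff; the shot and hit counts are made up), one of the re-exported statistics helpers can be exercised like this:

    import sinter

    # Estimate an error rate with a likelihood-ratio confidence interval.
    fit = sinter.fit_binomial(num_shots=100_000, num_hits=23, max_likelihood_factor=1000)
    print(fit.low, fit.best, fit.high)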
sinter/_collection/__init__.py ADDED
@@ -0,0 +1,10 @@
+from sinter._collection._collection import (
+    collect,
+    iter_collect,
+    post_selection_mask_from_4th_coord,
+    post_selection_mask_from_predicate,
+    Progress,
+)
+from sinter._collection._printer import (
+    ThrottledProgressPrinter,
+)
sinter/_collection/_collection.py ADDED
@@ -0,0 +1,480 @@
+import contextlib
+import dataclasses
+import pathlib
+from typing import Any, Callable, Iterator, Optional, Union, Iterable, List, TYPE_CHECKING, Tuple, Dict
+
+import math
+import numpy as np
+import stim
+
+from sinter._data import CSV_HEADER, ExistingData, TaskStats, CollectionOptions, Task
+from sinter._collection._collection_manager import CollectionManager
+from sinter._collection._printer import ThrottledProgressPrinter
+
+if TYPE_CHECKING:
+    import sinter
+
+
+@dataclasses.dataclass(frozen=True)
+class Progress:
+    """Describes statistics and status messages from ongoing sampling.
+
+    This is the type yielded by `sinter.iter_collect`, and given to the
+    `progress_callback` argument of `sinter.collect`.
+
+    Attributes:
+        new_stats: New sampled statistics collected since the last progress
+            update.
+        status_message: A free form human readable string describing the current
+            collection status, such as the number of tasks left and the
+            estimated time to completion for each task.
+    """
+    new_stats: Tuple[TaskStats, ...]
+    status_message: str
+
+
+def iter_collect(*,
+                 num_workers: int,
+                 tasks: Union[Iterator['sinter.Task'],
+                              Iterable['sinter.Task']],
+                 hint_num_tasks: Optional[int] = None,
+                 additional_existing_data: Union[None, dict[str, 'TaskStats'], Iterable['TaskStats']] = None,
+                 max_shots: Optional[int] = None,
+                 max_errors: Optional[int] = None,
+                 decoders: Optional[Iterable[str]] = None,
+                 max_batch_seconds: Optional[int] = None,
+                 max_batch_size: Optional[int] = None,
+                 start_batch_size: Optional[int] = None,
+                 count_observable_error_combos: bool = False,
+                 count_detection_events: bool = False,
+                 custom_decoders: Optional[Dict[str, Union['sinter.Decoder', 'sinter.Sampler']]] = None,
+                 custom_error_count_key: Optional[str] = None,
+                 allowed_cpu_affinity_ids: Optional[Iterable[int]] = None,
+                 ) -> Iterator['sinter.Progress']:
+    """Iterates error correction statistics collected from worker processes.
+
+    It is important to iterate until the sequence ends, or worker processes will
+    be left alive. The values yielded during iteration are progress updates from
+    the workers.
+
+    Note: if max_batch_size and max_batch_seconds are both not used (or
+    explicitly set to None), a default batch-size-limiting mechanism will be
+    chosen.
+
+    Args:
+        num_workers: The number of worker processes to use.
+        tasks: Decoding problems to sample.
+        hint_num_tasks: If `tasks` is an iterator or a generator, its length
+            can be given here so that progress printouts can say how many cases
+            are left.
+        additional_existing_data: Defaults to None (no additional data).
+            Statistical data that has already been collected, in addition to
+            anything included in each task's `previous_stats` field.
+        decoders: Defaults to None (specified by each Task). The names of the
+            decoders to use on each Task. It must either be the case that each
+            Task specifies a decoder and this is set to None, or this is an
+            iterable and each Task has its decoder set to None.
+        max_shots: Defaults to None (unused). Stops the sampling process
+            after this many samples have been taken from the circuit.
+        max_errors: Defaults to None (unused). Stops the sampling process
+            after this many errors have been seen in samples taken from the
+            circuit. The actual number of sampled errors may be larger due to
+            batching.
+        count_observable_error_combos: Defaults to False. When set to True,
+            the returned stats will have a custom counts field with keys
+            like `obs_mistake_mask=E_E__` counting how many times specific
+            combinations of observables were mispredicted by the decoder.
+        count_detection_events: Defaults to False. When set to True, the
+            returned stats will have a custom counts field with the
+            key `detection_events` counting the number of times a detector fired
+            and also `detectors_checked` counting the number of detectors that
+            were executed. The detection fraction is the ratio of these two
+            numbers.
+        start_batch_size: Defaults to None (collector's choice). The very
+            first shots taken from the circuit will use a batch of this
+            size, and no other batches will be taken in parallel. Once this
+            initial fact finding batch is done, batches can be taken in
+            parallel and the normal batch size limiting processes take over.
+        max_batch_size: Defaults to None (unused). Limits batches from
+            taking more than this many shots at once. For example, this can
+            be used to ensure memory usage stays below some limit.
+        max_batch_seconds: Defaults to None (unused). When set, the recorded
+            data from previous shots is used to estimate how much time is
+            taken per shot. This information is then used to predict the
+            biggest batch size that can finish in under the given number of
+            seconds. Limits each batch to be no larger than that.
+        custom_decoders: Custom decoders that can be used if requested by name.
+            If not specified, only decoders built into sinter, such as
+            'pymatching' and 'fusion_blossom', can be used.
+        custom_error_count_key: Makes `max_errors` apply to `stat.custom_counts[key]`
+            instead of `stat.errors`.
+        allowed_cpu_affinity_ids: Controls which CPUs the workers can be pinned to. The
+            set of allowed IDs should be at least as large as the number of workers, though
+            this is not strictly required. If not set, defaults to all CPUs being allowed.
+
+    Yields:
+        sinter.Progress instances recording incremental statistical data as it
+        is collected by workers.
+
+    Examples:
+        >>> import sinter
+        >>> import stim
+        >>> tasks = [
+        ...     sinter.Task(
+        ...         circuit=stim.Circuit.generated(
+        ...             'repetition_code:memory',
+        ...             distance=5,
+        ...             rounds=5,
+        ...             before_round_data_depolarization=1e-3,
+        ...         ),
+        ...         json_metadata={'d': 5},
+        ...     ),
+        ...     sinter.Task(
+        ...         circuit=stim.Circuit.generated(
+        ...             'repetition_code:memory',
+        ...             distance=7,
+        ...             rounds=5,
+        ...             before_round_data_depolarization=1e-3,
+        ...         ),
+        ...         json_metadata={'d': 7},
+        ...     ),
+        ... ]
+        >>> iterator = sinter.iter_collect(
+        ...     tasks=tasks,
+        ...     decoders=['vacuous'],
+        ...     num_workers=2,
+        ...     max_shots=100,
+        ... )
+        >>> total_shots = 0
+        >>> for progress in iterator:
+        ...     for stat in progress.new_stats:
+        ...         total_shots += stat.shots
+        >>> print(total_shots)
+        200
+    """
+    existing_data: dict[str, TaskStats]
+    if isinstance(additional_existing_data, ExistingData):
+        existing_data = additional_existing_data.data
+    elif isinstance(additional_existing_data, dict):
+        existing_data = additional_existing_data
+    elif additional_existing_data is None:
+        existing_data = {}
+    else:
+        acc = ExistingData()
+        for stat in additional_existing_data:
+            acc.add_sample(stat)
+        existing_data = acc.data
+
+    if isinstance(decoders, str):
+        decoders = [decoders]
+
+    if hint_num_tasks is None:
+        try:
+            # noinspection PyTypeChecker
+            hint_num_tasks = len(tasks)
+        except TypeError:
+            pass
+
+    if decoders is not None:
+        old_tasks = tasks
+        tasks = (
+            Task(
+                circuit=task.circuit,
+                decoder=decoder,
+                detector_error_model=task.detector_error_model,
+                postselection_mask=task.postselection_mask,
+                postselected_observables_mask=task.postselected_observables_mask,
+                json_metadata=task.json_metadata,
+                collection_options=task.collection_options,
+                circuit_path=task.circuit_path,
+            )
+            for task in old_tasks
+            for decoder in (decoders if task.decoder is None else [task.decoder])
+        )
+
+    progress_log: list[Optional[TaskStats]] = []
+    def log_progress(e: Optional[TaskStats]):
+        progress_log.append(e)
+    with CollectionManager(
+        num_workers=num_workers,
+        tasks=tasks,
+        collection_options=CollectionOptions(
+            max_shots=max_shots,
+            max_errors=max_errors,
+            max_batch_seconds=max_batch_seconds,
+            start_batch_size=start_batch_size,
+            max_batch_size=max_batch_size,
+        ),
+        existing_data=existing_data,
+        count_observable_error_combos=count_observable_error_combos,
+        count_detection_events=count_detection_events,
+        custom_error_count_key=custom_error_count_key,
+        custom_decoders=custom_decoders or {},
+        allowed_cpu_affinity_ids=allowed_cpu_affinity_ids,
+        worker_flush_period=max_batch_seconds or 120,
+        progress_callback=log_progress,
+    ) as manager:
+        try:
+            yield Progress(
+                new_stats=(),
+                status_message=f"Starting {num_workers} workers..."
+            )
+            manager.start_workers()
+            manager.start_distributing_work()
+
+            while manager.task_states:
+                manager.process_message()
+                if progress_log:
+                    vals = list(progress_log)
+                    progress_log.clear()
+                    for e in vals:
+                        if e is not None:
+                            yield Progress(
+                                new_stats=(e,),
+                                status_message=manager.status_message(),
+                            )
+
+        except KeyboardInterrupt:
+            yield Progress(
+                new_stats=(),
+                status_message='KeyboardInterrupt',
+            )
+            raise
+
+
+def collect(*,
+            num_workers: int,
+            tasks: Union[Iterator['sinter.Task'], Iterable['sinter.Task']],
+            existing_data_filepaths: Iterable[Union[str, pathlib.Path]] = (),
+            save_resume_filepath: Union[None, str, pathlib.Path] = None,
+            progress_callback: Optional[Callable[['sinter.Progress'], None]] = None,
+            max_shots: Optional[int] = None,
+            max_errors: Optional[int] = None,
+            count_observable_error_combos: bool = False,
+            count_detection_events: bool = False,
+            decoders: Optional[Iterable[str]] = None,
+            max_batch_seconds: Optional[int] = None,
+            max_batch_size: Optional[int] = None,
+            start_batch_size: Optional[int] = None,
+            print_progress: bool = False,
+            hint_num_tasks: Optional[int] = None,
+            custom_decoders: Optional[Dict[str, Union['sinter.Decoder', 'sinter.Sampler']]] = None,
+            custom_error_count_key: Optional[str] = None,
+            allowed_cpu_affinity_ids: Optional[Iterable[int]] = None,
+            ) -> List['sinter.TaskStats']:
+    """Collects statistics from the given tasks, using multiprocessing.
+
+    Args:
+        num_workers: The number of worker processes to use.
+        tasks: Decoding problems to sample.
+        save_resume_filepath: Defaults to None (unused). If set to a filepath,
+            results will be saved to that file while they are collected. If the
+            python interpreter is stopped or killed, calling this method again
+            with the same save_resume_filepath will load the previous results
+            from the file so it can resume where it left off.
+
+            The stats in this file will be counted in addition to each task's
+            previous_stats field (as opposed to overriding the field).
+        existing_data_filepaths: CSV data saved to these files will be loaded,
+            included in the returned results, and count towards things like
+            max_shots and max_errors.
+        progress_callback: Defaults to None (unused). If specified, then each
+            time new sample statistics are acquired from a worker this method
+            will be invoked with the new `sinter.Progress`.
+        hint_num_tasks: If `tasks` is an iterator or a generator, its length
+            can be given here so that progress printouts can say how many cases
+            are left.
+        decoders: Defaults to None (specified by each Task). The names of the
+            decoders to use on each Task. It must either be the case that each
+            Task specifies a decoder and this is set to None, or this is an
+            iterable and each Task has its decoder set to None.
+        count_observable_error_combos: Defaults to False. When set to True,
+            the returned stats will have a custom counts field with keys
+            like `obs_mistake_mask=E_E__` counting how many times specific
+            combinations of observables were mispredicted by the decoder.
+        count_detection_events: Defaults to False. When set to True, the
+            returned stats will have a custom counts field with the
+            key `detection_events` counting the number of times a detector fired
+            and also `detectors_checked` counting the number of detectors that
+            were executed. The detection fraction is the ratio of these two
+            numbers.
+        max_shots: Defaults to None (unused). Stops the sampling process
+            after this many samples have been taken from the circuit.
+        max_errors: Defaults to None (unused). Stops the sampling process
+            after this many errors have been seen in samples taken from the
+            circuit. The actual number of sampled errors may be larger due to
+            batching.
+        start_batch_size: Defaults to None (collector's choice). The very
+            first shots taken from the circuit will use a batch of this
+            size, and no other batches will be taken in parallel. Once this
+            initial fact finding batch is done, batches can be taken in
+            parallel and the normal batch size limiting processes take over.
+        max_batch_size: Defaults to None (unused). Limits batches from
+            taking more than this many shots at once. For example, this can
+            be used to ensure memory usage stays below some limit.
+        print_progress: When True, progress is printed to stderr while
+            collection runs.
+        max_batch_seconds: Defaults to None (unused). When set, the recorded
+            data from previous shots is used to estimate how much time is
+            taken per shot. This information is then used to predict the
+            biggest batch size that can finish in under the given number of
+            seconds. Limits each batch to be no larger than that.
+        custom_decoders: Named child classes of `sinter.Decoder` that can be
+            used if requested by name by a task or by the decoders list.
+            If not specified, only decoders with support built into sinter, such
+            as 'pymatching' and 'fusion_blossom', can be used.
+        custom_error_count_key: Makes `max_errors` apply to `stat.custom_counts[key]`
+            instead of `stat.errors`.
+        allowed_cpu_affinity_ids: Controls which CPUs the workers can be pinned to. The
+            set of allowed IDs should be at least as large as the number of workers, though
+            this is not strictly required. If not set, defaults to all CPUs being allowed.
+
+    Returns:
+        A list of sample statistics, one from each problem. The list is not in
+        any specific order. This is the same data that would have been written
+        to a CSV file, but aggregated so that each problem has exactly one
+        sample statistic instead of potentially multiple.
+
+    Examples:
+        >>> import sinter
+        >>> import stim
+        >>> tasks = [
+        ...     sinter.Task(
+        ...         circuit=stim.Circuit.generated(
+        ...             'repetition_code:memory',
+        ...             distance=5,
+        ...             rounds=5,
+        ...             before_round_data_depolarization=1e-3,
+        ...         ),
+        ...         json_metadata={'d': 5},
+        ...     ),
+        ...     sinter.Task(
+        ...         circuit=stim.Circuit.generated(
+        ...             'repetition_code:memory',
+        ...             distance=7,
+        ...             rounds=5,
+        ...             before_round_data_depolarization=1e-3,
+        ...         ),
+        ...         json_metadata={'d': 7},
+        ...     ),
+        ... ]
+        >>> stats = sinter.collect(
+        ...     tasks=tasks,
+        ...     decoders=['vacuous'],
+        ...     num_workers=2,
+        ...     max_shots=100,
+        ... )
+        >>> for stat in sorted(stats, key=lambda e: e.json_metadata['d']):
+        ...     print(stat.json_metadata, stat.shots)
+        {'d': 5} 100
+        {'d': 7} 100
+    """
+    # Load existing data.
+    additional_existing_data = ExistingData()
+    for existing in existing_data_filepaths:
+        additional_existing_data += ExistingData.from_file(existing)
+
+    if save_resume_filepath in existing_data_filepaths:
+        raise ValueError("save_resume_filepath in existing_data_filepaths")
+
+    progress_printer = ThrottledProgressPrinter(
+        outs=[],
+        print_progress=print_progress,
+        min_progress_delay=0.1,
+    )
+    with contextlib.ExitStack() as exit_stack:
+        # Open save/resume file.
+        if save_resume_filepath is not None:
+            save_resume_filepath = pathlib.Path(save_resume_filepath)
+            if save_resume_filepath.exists():
+                additional_existing_data += ExistingData.from_file(save_resume_filepath)
+                save_resume_file = exit_stack.enter_context(
+                    open(save_resume_filepath, 'a'))  # type: ignore
+            else:
+                save_resume_filepath.parent.mkdir(exist_ok=True)
+                save_resume_file = exit_stack.enter_context(
+                    open(save_resume_filepath, 'w'))  # type: ignore
+                print(CSV_HEADER, file=save_resume_file, flush=True)
+        else:
+            save_resume_file = None
+
+        # Collect data.
+        result = ExistingData()
+        result.data = dict(additional_existing_data.data)
+        for progress in iter_collect(
+            num_workers=num_workers,
+            max_shots=max_shots,
+            max_errors=max_errors,
+            max_batch_seconds=max_batch_seconds,
+            start_batch_size=start_batch_size,
+            max_batch_size=max_batch_size,
+            count_observable_error_combos=count_observable_error_combos,
+            count_detection_events=count_detection_events,
+            decoders=decoders,
+            tasks=tasks,
+            hint_num_tasks=hint_num_tasks,
+            additional_existing_data=additional_existing_data,
+            custom_decoders=custom_decoders,
+            custom_error_count_key=custom_error_count_key,
+            allowed_cpu_affinity_ids=allowed_cpu_affinity_ids,
+        ):
+            for stats in progress.new_stats:
+                result.add_sample(stats)
+                if save_resume_file is not None:
+                    print(stats.to_csv_line(), file=save_resume_file, flush=True)
+            if print_progress:
+                progress_printer.show_latest_progress(progress.status_message)
+            if progress_callback is not None:
+                progress_callback(progress)
+        if print_progress:
+            progress_printer.flush()
+        return list(result.data.values())
+
+
+def post_selection_mask_from_predicate(
+        circuit_or_dem: Union[stim.Circuit, stim.DetectorErrorModel],
+        *,
+        postselected_detectors_predicate: Callable[[int, Any, Tuple[float, ...]], bool],
+        metadata: Any,
+) -> np.ndarray:
+    num_dets = circuit_or_dem.num_detectors
+    post_selection_mask = np.zeros(dtype=np.uint8, shape=math.ceil(num_dets / 8))
+    for k, coord in circuit_or_dem.get_detector_coordinates().items():
+        if postselected_detectors_predicate(k, metadata, coord):
+            post_selection_mask[k // 8] |= 1 << (k % 8)
+    return post_selection_mask
+
+
+def post_selection_mask_from_4th_coord(dem: Union[stim.Circuit, stim.DetectorErrorModel]) -> np.ndarray:
+    """Returns a mask that postselects detectors with a non-zero 4th coordinate.
+
+    This method is a leftover from before the existence of the command line
+    argument `--postselected_detectors_predicate`, when
+    `--postselect_detectors_with_non_zero_4th_coord` was the only way to do
+    post selection of detectors.
+
+    Args:
+        dem: The detector error model to pull coordinate data from.
+
+    Returns:
+        A bit packed numpy array where detectors with non-zero 4th coordinate
+        data have a True bit at their corresponding index.
+
+    Examples:
+        >>> import sinter
+        >>> import stim
+        >>> dem = stim.DetectorErrorModel('''
+        ...     detector(1, 2, 3) D0
+        ...     detector(1, 1, 1, 1) D1
+        ...     detector(1, 1, 1, 0) D2
+        ...     detector(1, 1, 1, 999) D80
+        ... ''')
+        >>> sinter.post_selection_mask_from_4th_coord(dem)
+        array([2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], dtype=uint8)
+    """
+    num_dets = dem.num_detectors
+    post_selection_mask = np.zeros(dtype=np.uint8, shape=math.ceil(num_dets / 8))
+    for k, coord in dem.get_detector_coordinates().items():
+        if len(coord) >= 4 and coord[3]:
+            post_selection_mask[k // 8] |= 1 << (k % 8)
+    return post_selection_mask
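
Unlike its sibling above, `post_selection_mask_from_predicate` ships without a docstring example. A minimal sketch of calling it (the detector error model and predicate are illustrative; note that this wheel's top-level `sinter/__init__.py` only re-exports `post_selection_mask_from_4th_coord`, so the sketch imports from the private `sinter._collection` module):

    import stim
    from sinter._collection import post_selection_mask_from_predicate

    dem = stim.DetectorErrorModel('''
        detector(0, 0, 0) D0
        detector(1, 0, 0) D1
    ''')

    # Postselect detectors whose first coordinate is non-zero (illustrative predicate).
    mask = post_selection_mask_from_predicate(
        dem,
        postselected_detectors_predicate=lambda index, metadata, coords: coords[0] != 0,
        metadata=None,
    )
    print(mask)  # [2] -- bit 1 is set, postselecting D1

The same bit-packing convention applies as in `post_selection_mask_from_4th_coord`: detector k maps to bit `k % 8` of byte `k // 8`.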