sinter 1.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sinter might be problematic. Click here for more details.

Files changed (62)
  1. sinter/__init__.py +47 -0
  2. sinter/_collection/__init__.py +10 -0
  3. sinter/_collection/_collection.py +480 -0
  4. sinter/_collection/_collection_manager.py +581 -0
  5. sinter/_collection/_collection_manager_test.py +287 -0
  6. sinter/_collection/_collection_test.py +317 -0
  7. sinter/_collection/_collection_worker_loop.py +35 -0
  8. sinter/_collection/_collection_worker_state.py +259 -0
  9. sinter/_collection/_collection_worker_test.py +222 -0
  10. sinter/_collection/_mux_sampler.py +56 -0
  11. sinter/_collection/_printer.py +65 -0
  12. sinter/_collection/_sampler_ramp_throttled.py +66 -0
  13. sinter/_collection/_sampler_ramp_throttled_test.py +144 -0
  14. sinter/_command/__init__.py +0 -0
  15. sinter/_command/_main.py +39 -0
  16. sinter/_command/_main_collect.py +350 -0
  17. sinter/_command/_main_collect_test.py +482 -0
  18. sinter/_command/_main_combine.py +84 -0
  19. sinter/_command/_main_combine_test.py +153 -0
  20. sinter/_command/_main_plot.py +817 -0
  21. sinter/_command/_main_plot_test.py +445 -0
  22. sinter/_command/_main_predict.py +75 -0
  23. sinter/_command/_main_predict_test.py +36 -0
  24. sinter/_data/__init__.py +20 -0
  25. sinter/_data/_anon_task_stats.py +89 -0
  26. sinter/_data/_anon_task_stats_test.py +35 -0
  27. sinter/_data/_collection_options.py +106 -0
  28. sinter/_data/_collection_options_test.py +24 -0
  29. sinter/_data/_csv_out.py +74 -0
  30. sinter/_data/_existing_data.py +173 -0
  31. sinter/_data/_existing_data_test.py +41 -0
  32. sinter/_data/_task.py +311 -0
  33. sinter/_data/_task_stats.py +244 -0
  34. sinter/_data/_task_stats_test.py +140 -0
  35. sinter/_data/_task_test.py +38 -0
  36. sinter/_decoding/__init__.py +16 -0
  37. sinter/_decoding/_decoding.py +419 -0
  38. sinter/_decoding/_decoding_all_built_in_decoders.py +25 -0
  39. sinter/_decoding/_decoding_decoder_class.py +161 -0
  40. sinter/_decoding/_decoding_fusion_blossom.py +193 -0
  41. sinter/_decoding/_decoding_mwpf.py +302 -0
  42. sinter/_decoding/_decoding_pymatching.py +81 -0
  43. sinter/_decoding/_decoding_test.py +480 -0
  44. sinter/_decoding/_decoding_vacuous.py +38 -0
  45. sinter/_decoding/_perfectionist_sampler.py +38 -0
  46. sinter/_decoding/_sampler.py +72 -0
  47. sinter/_decoding/_stim_then_decode_sampler.py +222 -0
  48. sinter/_decoding/_stim_then_decode_sampler_test.py +192 -0
  49. sinter/_plotting.py +619 -0
  50. sinter/_plotting_test.py +108 -0
  51. sinter/_predict.py +381 -0
  52. sinter/_predict_test.py +227 -0
  53. sinter/_probability_util.py +519 -0
  54. sinter/_probability_util_test.py +281 -0
  55. sinter-1.15.0.data/data/README.md +332 -0
  56. sinter-1.15.0.data/data/readme_example_plot.png +0 -0
  57. sinter-1.15.0.data/data/requirements.txt +4 -0
  58. sinter-1.15.0.dist-info/METADATA +354 -0
  59. sinter-1.15.0.dist-info/RECORD +62 -0
  60. sinter-1.15.0.dist-info/WHEEL +5 -0
  61. sinter-1.15.0.dist-info/entry_points.txt +2 -0
  62. sinter-1.15.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,144 @@
1
+ from unittest import mock
2
+
3
+ import pytest
4
+
5
+ import sinter
6
+ import stim
7
+ from sinter._collection._sampler_ramp_throttled import (
8
+ CompiledRampThrottledSampler,
9
+ RampThrottledSampler,
10
+ )
11
+ from sinter._data import AnonTaskStats, Task
12
+
13
+
14
class MockSampler(sinter.Sampler, sinter.CompiledSampler):
    """Fake sampler that records the `suggested_shots` of every `sample` call.

    It acts as its own compiled sampler, so `compiled_sampler_for_task`
    simply hands back `self`.
    """

    def __init__(self):
        # Every `suggested_shots` value passed to `sample`, in call order.
        self.calls = []

    def compiled_sampler_for_task(self, task: Task) -> sinter.CompiledSampler:
        """Return this same object; the fake needs no per-task compilation."""
        return self

    def sample(self, suggested_shots: int) -> AnonTaskStats:
        """Record the request, then report that many shots with one error.

        The reported runtime is 1 ms per shot.
        """
        self.calls.append(suggested_shots)
        elapsed = suggested_shots * 0.001
        return AnonTaskStats(shots=suggested_shots, errors=1, seconds=elapsed)
30
+
31
+
32
@pytest.fixture
def mock_sampler():
    """Provide each test with a fresh, call-recording MockSampler."""
    sampler = MockSampler()
    return sampler
35
+
36
+
37
def test_initial_batch_size(mock_sampler):
    """The very first batch requested from the sub-sampler has size 1."""
    throttled = CompiledRampThrottledSampler(
        sub_sampler=mock_sampler,
        target_batch_seconds=1.0,
        max_batch_shots=1024,
    )

    # Although 100 shots are allowed, the ramp begins at a single shot.
    throttled.sample(100)
    first_request = mock_sampler.calls[0]
    assert first_request == 1
48
+
49
+
50
def test_batch_size_ramps_up(mock_sampler):
    """Batches grow while each one finishes well under the target time."""
    throttled = CompiledRampThrottledSampler(
        sub_sampler=mock_sampler,
        target_batch_seconds=1.0,
        max_batch_shots=1024,
    )

    # Each `sample` reads the clock twice (start and end). These six fake
    # timestamps therefore cover three calls, each apparently taking ~1 ms.
    fake_clock = [0.0, 0.001, 0.02, 0.021, 0.03, 0.031]
    with mock.patch("time.monotonic", side_effect=fake_clock):
        throttled.sample(100)  # starts with a batch of 1
        throttled.sample(100)  # expected to have doubled 4 times, to 16
        throttled.sample(100)  # would double further but is capped at 100

    assert mock_sampler.calls == [1, 16, 100]
68
+
69
+
70
def test_batch_size_decreases(mock_sampler):
    """Batches shrink when each one takes much longer than the target time."""
    throttled = CompiledRampThrottledSampler(
        sub_sampler=mock_sampler,
        target_batch_seconds=0.1,
        max_batch_shots=1024,
    )
    # Start from a large batch so there is room to shrink.
    throttled.batch_shots = 64

    # Each fake call appears to take 0.15 s, i.e. 1.5x the 0.1 s target
    # (above the 1.3x slow-down threshold).
    fake_clock = [0.0, 0.15, 0.5, 0.65]
    with mock.patch("time.monotonic", side_effect=fake_clock):
        throttled.sample(100)  # uses the preset batch of 64
        throttled.sample(100)  # expected to have halved to 32

    assert mock_sampler.calls == [64, 32]
87
+
88
+
89
def test_respects_max_batch_shots(mock_sampler):
    """Batch-size growth is clamped at max_batch_shots."""
    throttled = CompiledRampThrottledSampler(
        sub_sampler=mock_sampler,
        target_batch_seconds=1.0,
        max_batch_shots=16,  # deliberately small cap
    )
    # Begin a single doubling away from the cap.
    throttled.batch_shots = 8

    # Two clock reads per call; every call appears to take ~1 ms.
    fake_clock = [0.0, 0.001, 0.02, 0.021, 0.03, 0.031]
    with mock.patch("time.monotonic", side_effect=fake_clock):
        for _ in range(3):
            throttled.sample(100)

    # 8, then doubled to 16, then pinned at the 16-shot maximum.
    assert mock_sampler.calls == [8, 16, 16]
110
+
111
+
112
def test_respects_max_shots_parameter(mock_sampler):
    """A single request never asks the sub-sampler for more shots than allowed."""
    throttled = CompiledRampThrottledSampler(
        sub_sampler=mock_sampler,
        target_batch_seconds=1.0,
        max_batch_shots=1024,
    )
    # Pretend the ramp has already climbed to 100 shots per batch...
    throttled.batch_shots = 100

    # ...but this call only permits 10 shots.
    throttled.sample(10)

    assert mock_sampler.calls[0] == 10
128
+
129
+
130
def test_sub_sampler_parameter_pass_through(mock_sampler):
    """Compiling a RampThrottledSampler forwards its settings to the result."""
    factory = RampThrottledSampler(
        sub_sampler=mock_sampler,
        target_batch_seconds=0.5,
        max_batch_shots=512,
    )

    empty_task = Task(circuit=stim.Circuit(), decoder="test")
    compiled = factory.compiled_sampler_for_task(empty_task)

    assert isinstance(compiled, CompiledRampThrottledSampler)
    assert compiled.target_batch_seconds == 0.5
    assert compiled.max_batch_shots == 512
    # A freshly compiled sampler starts its ramp at a single shot.
    assert compiled.batch_shots == 1
File without changes
@@ -0,0 +1,39 @@
1
+ import sys
2
+ from typing import Optional, List
3
+
4
+
5
def main(*, command_line_args: Optional[List[str]] = None):
    """Entry point of the `sinter` command line tool.

    Dispatches on the first argument to the matching sub-command
    (`collect`, `combine`, `plot` or `predict`). It passes along the remaining
    arguments and returns whatever the sub-command returns.

    Args:
        command_line_args: Arguments to interpret. Defaults to `sys.argv[1:]`
            when None.

    If the first argument is a help flag (`help`, `h`, `--help`, `-help`,
    `-h`, `--h`), the list of available commands is printed to stderr and
    None is returned. Otherwise, when no command or an unrecognized command
    is given, an error plus the command list is printed to stderr and the
    process exits with status 1 (via `sys.exit`, i.e. `SystemExit`).
    """
    if command_line_args is None:
        command_line_args = sys.argv[1:]

    mode = command_line_args[0] if command_line_args else None
    rest = command_line_args[1:]

    # Each sub-command's module is only imported once it has been selected.
    if mode == 'combine':
        from sinter._command._main_combine import main_combine
        return main_combine(command_line_args=rest)
    if mode == 'collect':
        from sinter._command._main_collect import main_collect
        return main_collect(command_line_args=rest)
    if mode == 'plot':
        from sinter._command._main_plot import main_plot
        return main_plot(command_line_args=rest)
    if mode == 'predict':
        from sinter._command._main_predict import main_predict
        return main_predict(command_line_args=rest)

    want_help = mode in ('help', 'h', '--help', '-help', '-h', '--h')
    if not want_help:
        if command_line_args and not command_line_args[0].startswith('-'):
            print(f"\033[31mUnrecognized command: sinter {command_line_args[0]}\033[0m\n", file=sys.stderr)
        else:
            print("\033[31mDidn't specify a command.\033[0m\n", file=sys.stderr)
    # Keep this list in sync with the dispatch above (it previously omitted
    # `predict` even though that command is supported).
    print("Available commands are:\n"
          " sinter collect\n"
          " sinter combine\n"
          " sinter plot\n"
          " sinter predict", file=sys.stderr)
    if not want_help:
        sys.exit(1)


if __name__ == '__main__':
    main()
@@ -0,0 +1,350 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+ import sys
5
+ from typing import Iterator, Any, Tuple, List, Callable, Optional
6
+ from typing import cast
7
+
8
+ import numpy as np
9
+ import stim
10
+
11
+ from sinter._collection import ThrottledProgressPrinter
12
+ from sinter._data import Task
13
+ from sinter._collection import collect, Progress, post_selection_mask_from_predicate
14
+ from sinter._command._main_combine import ExistingData, CSV_HEADER
15
+ from sinter._decoding._decoding_all_built_in_decoders import BUILT_IN_SAMPLERS
16
+
17
+
18
def iter_file_paths_into_goals(circuit_paths: Iterator[str],
                               metadata_func: Callable,
                               postselected_detectors_predicate: Optional[Callable[[int, Any, Tuple[float, ...]], bool]],
                               postselected_observables_predicate: Callable[[int, Any], bool],
                               ) -> Iterator[Task]:
    """Lazily turns circuit file paths into sinter tasks.

    For each path, the file is read and parsed into a `stim.Circuit`, its
    metadata is computed, and postselection masks are built from the given
    predicates.

    Args:
        circuit_paths: Paths of circuit files to load, in order.
        metadata_func: Called with keyword arguments `path` and `circuit`.
            Its result becomes the task's `json_metadata` and is handed to
            the predicates.
        postselected_detectors_predicate: If not None, it is forwarded to
            `post_selection_mask_from_predicate` to choose the postselected
            detectors. A mask with no bits set is replaced by None.
        postselected_observables_predicate: Called as `(index, metadata)`
            for every observable index. Truthy results mark that observable
            as postselected.

    Yields:
        One `Task` per circuit path.
    """
    for path in circuit_paths:
        with open(path) as f:
            circuit_text = f.read()
        circuit = stim.Circuit(circuit_text)

        metadata = metadata_func(path=path, circuit=circuit)

        post_mask = None
        if postselected_detectors_predicate is not None:
            post_mask = post_selection_mask_from_predicate(circuit, metadata=metadata, postselected_detectors_predicate=postselected_detectors_predicate)
            if not np.any(post_mask):
                post_mask = None

        postselected_observables = [
            k
            for k in range(circuit.num_observables)
            if postselected_observables_predicate(k, metadata)
        ]
        # Test whether the list is empty rather than `any(...)` of its
        # contents. Observable index 0 is falsy, so `any([0])` would silently
        # drop a postselection that only applies to the first observable.
        if postselected_observables:
            # Bit-packed mask: observable k is bit (k % 8) of byte (k // 8).
            postselected_observables_mask = np.zeros(shape=math.ceil(circuit.num_observables / 8), dtype=np.uint8)
            for k in postselected_observables:
                postselected_observables_mask[k // 8] |= 1 << (k % 8)
        else:
            postselected_observables_mask = None

        yield Task(
            circuit=circuit,
            postselection_mask=post_mask,
            postselected_observables_mask=postselected_observables_mask,
            json_metadata=metadata,
        )
53
+
54
+
55
def parse_args(args: List[str]) -> Any:
    """Parses and post-processes the command line arguments of `sinter collect`.

    Beyond plain argparse parsing, this function:
      - compiles `--metadata_func`, `--postselected_observables_predicate` and
        `--postselected_detectors_predicate` into lambdas using `eval`;
      - folds `--postselect_detectors_with_non_zero_4th_coord` into the
        detector predicate, leaving None when no detector postselection is
        requested;
      - imports and calls every `--custom_decoders_module_function` entry and
        merges the returned dictionaries into `custom_decoders`;
      - checks that every requested decoder is built-in or custom;
      - evaluates `--allowed_cpu_affinity_ids` into a flat list of ints.

    NOTE(security): the expression-valued arguments are executed with
    `eval`/`exec`. They must only come from a trusted source, i.e. the user
    running the command.

    Args:
        args: The arguments that followed `sinter collect`.

    Returns:
        The argparse namespace, with the post-processed fields described above.

    Raises:
        ValueError: On conflicting detector-postselection flags, a malformed
            `--custom_decoders_module_function`, an unrecognized decoder, or
            an `--allowed_cpu_affinity_ids` entry that doesn't evaluate to an
            int or an iterable of ints.
    """
    parser = argparse.ArgumentParser(description='Collect Monte Carlo samples.',
                                     prog='sinter collect')
    parser.add_argument('--circuits',
                        nargs='+',
                        required=True,
                        help='Circuit files to sample from and decode.\n'
                             'This parameter can be given multiple arguments.')
    parser.add_argument('--decoders',
                        type=str,
                        nargs='+',
                        required=True,
                        help='The decoder to use to predict observables from detection events.')
    parser.add_argument('--custom_decoders_module_function',
                        default=None,
                        nargs='+',
                        help='Use the syntax "module:function" to "import function from module" '
                             'and use the result of "function()" as the custom_decoders '
                             'dictionary. The dictionary must map strings to stim.Decoder '
                             'instances.')
    parser.add_argument('--max_shots',
                        type=int,
                        default=None,
                        help='Sampling of a circuit will stop if this many shots have been taken.')
    parser.add_argument('--max_errors',
                        type=int,
                        default=None,
                        help='Sampling of a circuit will stop if this many errors have been seen.')
    parser.add_argument('--processes',
                        default='auto',
                        type=str,
                        help='Number of processes to use for simultaneous sampling and decoding. '
                             'Must be either a number or "auto" which sets it to the number of '
                             'CPUs on the machine.')
    parser.add_argument('--save_resume_filepath',
                        type=str,
                        default=None,
                        help='Activates MERGE mode.\n'
                             "If save_resume_filepath doesn't exist, initializes it with a CSV header.\n"
                             'CSV data already at save_resume_filepath counts towards max_shots and max_errors.\n'
                             'Collected data is appended to save_resume_filepath.\n'
                             'Note that MERGE mode is tolerant to failures: if the process is killed, it can simply be restarted and it will pick up where it left off.\n'
                             'Note that MERGE mode is idempotent: if sufficient data has been collected, no additional work is done when run again.')

    parser.add_argument('--start_batch_size',
                        type=int,
                        default=100,
                        help='Initial number of samples to batch together into one job.\n'
                             'Starting small prevents over-sampling of circuits above threshold.\n'
                             'The allowed batch size increases exponentially from this starting point.')
    parser.add_argument('--max_batch_size',
                        type=int,
                        default=None,
                        help='Maximum number of samples to batch together into one job.\n'
                             'Bigger values increase the delay between jobs finishing.\n'
                             'Smaller values decrease the amount of aggregation of results, increasing the amount of output information.')
    parser.add_argument('--max_batch_seconds',
                        type=int,
                        default=None,
                        help='Limits number of shots in a batch so that the estimated runtime of the batch is below this amount.')
    parser.add_argument('--postselect_detectors_with_non_zero_4th_coord',
                        help='Turns on detector postselection. '
                             'If any detector with a non-zero 4th coordinate fires, the shot is discarded.',
                        action='store_true')
    # The adjacent string literals below previously ran sentences together
    # ("decoder.The", "statistic.Available"); separators have been added.
    parser.add_argument('--postselected_detectors_predicate',
                        type=str,
                        default='False',
                        help='Specifies a predicate used to decide which detectors to postselect. '
                             'When a postselected detector produces a detection event, the shot is discarded instead of being given to the decoder. '
                             'The number of discarded shots is tracked as a statistic.\n'
                             'Available values:\n'
                             ' index: The unique number identifying the detector, determined by the order of detectors in the circuit file.\n'
                             ' coords: The coordinate data associated with the detector. An empty tuple, if the circuit file did not specify detector coordinates.\n'
                             ' metadata: The metadata associated with the task being sampled.\n'
                             'Expected expression type:\n'
                             ' Something that can be given to `bool` to get False (do not postselect) or True (yes postselect).\n'
                             'Examples:\n'
                             ''' --postselected_detectors_predicate "coords[2] == 0"\n'''
                             ''' --postselected_detectors_predicate "coords[3] < metadata['postselection_level']"\n''')
    parser.add_argument('--postselected_observables_predicate',
                        type=str,
                        default='False',
                        help='Specifies a predicate used to decide which observables to postselect. '
                             'When a decoder mispredicts a postselected observable, the shot is discarded instead of counting as an error.\n'
                             'Available values:\n'
                             ' index: The index of the observable to postselect or not.\n'
                             ' metadata: The metadata associated with the task.\n'
                             'Expected expression type:\n'
                             ' Something that can be given to `bool` to get False (do not postselect) or True (yes postselect).\n'
                             'Examples:\n'
                             ''' --postselected_observables_predicate "False"\n'''
                             ''' --postselected_observables_predicate "metadata['d'] == 5 and index >= 2"\n''')
    parser.add_argument('--count_observable_error_combos',
                        help='When set, the returned stats will include custom '
                             'counts like `obs_mistake_mask=E_E__` counting '
                             'how many times the decoder made each pattern of '
                             'observable mistakes.',
                        action='store_true')
    parser.add_argument('--count_detection_events',
                        help='When set, the returned stats will include custom '
                             'counts `detectors_checked` and '
                             '`detection_events`. The detection fraction is '
                             'the ratio of these two numbers.',
                        action='store_true')
    parser.add_argument('--quiet',
                        help='Disables writing progress to stderr.',
                        action='store_true')
    parser.add_argument('--custom_error_count_key',
                        type=str,
                        help='Makes --max_errors apply to `stat.custom_counts[key]` '
                             'instead of to `stat.errors`.',
                        default=None)
    parser.add_argument('--allowed_cpu_affinity_ids',
                        type=str,
                        nargs='+',
                        help='Controls which CPUs workers can be pinned to. By default, all'
                             ' CPUs are used. Specifying this argument makes it so that '
                             'only the given CPU ids can be pinned. The given arguments '
                             'will be evaluated as python expressions. The expressions '
                             'should be integers or iterables of integers. So values like'
                             ' "1" and "[1, 2, 4]" and "range(5, 30)" all work.',
                        default=None)
    parser.add_argument('--also_print_results_to_stdout',
                        help='Even if writing to a file, also write results to stdout.',
                        action='store_true')
    parser.add_argument('--existing_data_filepaths',
                        nargs='*',
                        type=str,
                        default=(),
                        help='CSV data from these files counts towards max_shots and max_errors.\n'
                             'This parameter can be given multiple arguments.')
    parser.add_argument('--metadata_func',
                        type=str,
                        default="{'path': path}",
                        help='A python expression that associates json metadata with a circuit\'s results.\n'
                             'Set to "auto" to use "sinter.comma_separated_key_values(path)"\n'
                             'Values available to the expression:\n'
                             ' path: Relative path to the circuit file, from the command line arguments.\n'
                             ' circuit: The circuit itself, parsed from the file, as a stim.Circuit.\n'
                             'Expected type:\n'
                             ' A value that can be serialized into JSON, like a Dict[str, int].\n'
                             '\n'
                             ' Note that the decoder field is already recorded separately, so storing\n'
                             ' it in the metadata as well would be redundant. But something like\n'
                             ' decoder version could be usefully added.\n'
                             'Examples:\n'
                             ''' --metadata_func "{'path': path}"\n'''
                             ''' --metadata_func "auto"\n'''
                             ''' --metadata_func "{'n': circuit.num_qubits, 'p': float(path.split('/')[-1].split('.')[0])}"\n'''
                        )
    import sinter
    a = parser.parse_args(args=args)

    # Turn the user-supplied expressions into callables.
    if a.metadata_func == 'auto':
        a.metadata_func = "sinter.comma_separated_key_values(path)"
    a.metadata_func = eval(compile(
        'lambda *, path, circuit: ' + a.metadata_func,
        filename='metadata_func:command_line_arg',
        mode='eval'), {'sinter': sinter})
    a.postselected_observables_predicate = eval(compile(
        'lambda index, metadata: ' + a.postselected_observables_predicate,
        filename='postselected_observables_predicate:command_line_arg',
        mode='eval'))
    if a.postselected_detectors_predicate == 'False':
        if a.postselect_detectors_with_non_zero_4th_coord:
            a.postselected_detectors_predicate = lambda index, metadata, coords: coords[3]
        else:
            a.postselected_detectors_predicate = None
    else:
        if a.postselect_detectors_with_non_zero_4th_coord:
            raise ValueError("Can't specify both --postselect_detectors_with_non_zero_4th_coord and --postselected_detectors_predicate")
        a.postselected_detectors_predicate = eval(compile(
            'lambda index, metadata, coords: ' + cast(str, a.postselected_detectors_predicate),
            filename='postselected_detectors_predicate:command_line_arg',
            mode='eval'))

    # Load custom decoders; later entries override earlier ones on key clashes.
    if a.custom_decoders_module_function is not None:
        all_custom_decoders = {}
        for entry in a.custom_decoders_module_function:
            terms = entry.split(':')
            if len(terms) != 2:
                raise ValueError("--custom_decoders_module_function didn't have exactly one colon "
                                 "separating a module name from a function name. Expected an argument "
                                 "of the form --custom_decoders_module_function 'module:function'")
            module, function = terms
            namespace = {'__name__': '[]'}
            exec(f"from {module} import {function} as _custom_decoders", namespace)
            custom_decoders = namespace['_custom_decoders']()
            all_custom_decoders = {**all_custom_decoders, **custom_decoders}
        a.custom_decoders = all_custom_decoders
    else:
        a.custom_decoders = None

    for decoder in a.decoders:
        if decoder not in BUILT_IN_SAMPLERS and (a.custom_decoders is None or decoder not in a.custom_decoders):
            message = f"Not a recognized decoder or sampler: {decoder=}.\n"
            message += f"Available built-in decoders and samplers: {sorted(e for e in BUILT_IN_SAMPLERS.keys() if 'internal' not in e)}.\n"
            if a.custom_decoders is None:
                message += "No custom decoders are available. --custom_decoders_module_function wasn't specified."
            else:
                message += f"Available custom decoders: {sorted(a.custom_decoders.keys())}."
            raise ValueError(message)

    if a.allowed_cpu_affinity_ids is not None:
        cpu_ids: List[int] = []
        for expr in a.allowed_cpu_affinity_ids:
            try:
                value = eval(expr, {}, {})
                if isinstance(value, int):
                    cpu_ids.append(value)
                else:
                    # Materialize first: a one-shot iterator (e.g. a generator
                    # expression) would otherwise be exhausted by the type
                    # check below and contribute nothing.
                    items = list(value)
                    if not all(isinstance(item, int) for item in items):
                        raise ValueError("Not an integer or iterable of integers.")
                    cpu_ids.extend(items)
            except Exception as ex:
                # The f prefix was missing here, so the message used to show a
                # literal "{e!r}" instead of the offending expression.
                raise ValueError(f"Failed to eval {expr!r} for --allowed_cpu_affinity_ids") from ex
        a.allowed_cpu_affinity_ids = cpu_ids
    return a
270
+
271
+
272
def open_merge_file(path: str) -> Tuple[Any, ExistingData]:
    """Open a save/resume CSV file for appending, along with its existing stats.

    If `path` can be loaded by `ExistingData.from_file`, that data is returned
    and the file is reopened in append mode. If loading raises
    `FileNotFoundError`, the file is created, the CSV header is written to it,
    and an empty `ExistingData` is returned.

    Returns:
        A `(writable file handle, existing data)` pair. Closing the handle is
        left to the caller.
    """
    try:
        stats_so_far = ExistingData.from_file(path)
        appender = open(path, 'a')
    except FileNotFoundError:
        # Fresh file: start it off with the CSV header row.
        writer = open(path, 'w')
        print(CSV_HEADER, file=writer)
        return writer, ExistingData()
    return appender, stats_so_far
280
+
281
+
282
def main_collect(*, command_line_args: List[str]):
    """Run `sinter collect`: sample and decode circuits, then emit CSV stats.

    Arguments are parsed by `parse_args`, and one task is produced per circuit
    file. The work is then handed to `collect`, with one task per
    (circuit, decoder) pair as the task-count hint. New stats are printed as
    CSV lines, preceded by a header the first time any arrive. They go to
    stdout when `--also_print_results_to_stdout` is set or when no
    `--save_resume_filepath` is given.

    Args:
        command_line_args: The arguments that followed `sinter collect`.

    Raises:
        ValueError: If `--processes` is neither "auto" nor a positive integer
            (plus anything raised by `parse_args`).
    """
    args = parse_args(args=command_line_args)

    iter_tasks = iter_file_paths_into_goals(
        circuit_paths=args.circuits,
        metadata_func=args.metadata_func,
        postselected_detectors_predicate=args.postselected_detectors_predicate,
        postselected_observables_predicate=args.postselected_observables_predicate,
    )
    num_tasks = len(args.circuits) * len(args.decoders)

    print_to_stdout = args.also_print_results_to_stdout or args.save_resume_filepath is None

    did_work = False
    printer = ThrottledProgressPrinter(
        outs=[],
        print_progress=not args.quiet,
        min_progress_delay=0.03 if args.also_print_results_to_stdout else 0.1,
    )
    if print_to_stdout:
        printer.outs.append(sys.stdout)

    def on_progress(sample: Progress) -> None:
        """Print new stats (header first, once) and update the progress line."""
        nonlocal did_work
        for stats in sample.new_stats:
            if not did_work:
                printer.print_out(CSV_HEADER)
                did_work = True
            printer.print_out(stats.to_csv_line())

        msg = sample.status_message
        if msg == 'KeyboardInterrupt':
            printer.show_latest_progress('\nInterrupted. Output is flushed. Cleaning up workers...')
            printer.flush()
        else:
            printer.show_latest_progress(msg)

    if args.processes == 'auto':
        # os.cpu_count() returns None when the CPU count is undetermined;
        # fall back to a single worker instead of crashing.
        num_workers = os.cpu_count() or 1
    else:
        try:
            num_workers = int(args.processes)
        except ValueError:
            num_workers = 0
        if num_workers < 1:
            raise ValueError(f'--processes must be a positive integer, or "auto", but was: {args.processes}')

    try:
        collect(
            num_workers=num_workers,
            hint_num_tasks=num_tasks,
            tasks=iter_tasks,
            print_progress=False,
            save_resume_filepath=args.save_resume_filepath,
            existing_data_filepaths=args.existing_data_filepaths,
            progress_callback=on_progress,
            max_errors=args.max_errors,
            max_shots=args.max_shots,
            count_detection_events=args.count_detection_events,
            count_observable_error_combos=args.count_observable_error_combos,
            decoders=args.decoders,
            max_batch_seconds=args.max_batch_seconds,
            max_batch_size=args.max_batch_size,
            start_batch_size=args.start_batch_size,
            custom_decoders=args.custom_decoders,
            custom_error_count_key=args.custom_error_count_key,
            allowed_cpu_affinity_ids=args.allowed_cpu_affinity_ids,
        )
    except KeyboardInterrupt:
        # Deliberately quiet: on_progress reports the interruption and flushes
        # output when it sees the 'KeyboardInterrupt' status message.
        pass