sinter 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sinter might be problematic. Click here for more details.
- sinter/__init__.py +47 -0
- sinter/_collection/__init__.py +10 -0
- sinter/_collection/_collection.py +480 -0
- sinter/_collection/_collection_manager.py +581 -0
- sinter/_collection/_collection_manager_test.py +287 -0
- sinter/_collection/_collection_test.py +317 -0
- sinter/_collection/_collection_worker_loop.py +35 -0
- sinter/_collection/_collection_worker_state.py +259 -0
- sinter/_collection/_collection_worker_test.py +222 -0
- sinter/_collection/_mux_sampler.py +56 -0
- sinter/_collection/_printer.py +65 -0
- sinter/_collection/_sampler_ramp_throttled.py +66 -0
- sinter/_collection/_sampler_ramp_throttled_test.py +144 -0
- sinter/_command/__init__.py +0 -0
- sinter/_command/_main.py +39 -0
- sinter/_command/_main_collect.py +350 -0
- sinter/_command/_main_collect_test.py +482 -0
- sinter/_command/_main_combine.py +84 -0
- sinter/_command/_main_combine_test.py +153 -0
- sinter/_command/_main_plot.py +817 -0
- sinter/_command/_main_plot_test.py +445 -0
- sinter/_command/_main_predict.py +75 -0
- sinter/_command/_main_predict_test.py +36 -0
- sinter/_data/__init__.py +20 -0
- sinter/_data/_anon_task_stats.py +89 -0
- sinter/_data/_anon_task_stats_test.py +35 -0
- sinter/_data/_collection_options.py +106 -0
- sinter/_data/_collection_options_test.py +24 -0
- sinter/_data/_csv_out.py +74 -0
- sinter/_data/_existing_data.py +173 -0
- sinter/_data/_existing_data_test.py +41 -0
- sinter/_data/_task.py +311 -0
- sinter/_data/_task_stats.py +244 -0
- sinter/_data/_task_stats_test.py +140 -0
- sinter/_data/_task_test.py +38 -0
- sinter/_decoding/__init__.py +16 -0
- sinter/_decoding/_decoding.py +419 -0
- sinter/_decoding/_decoding_all_built_in_decoders.py +25 -0
- sinter/_decoding/_decoding_decoder_class.py +161 -0
- sinter/_decoding/_decoding_fusion_blossom.py +193 -0
- sinter/_decoding/_decoding_mwpf.py +302 -0
- sinter/_decoding/_decoding_pymatching.py +81 -0
- sinter/_decoding/_decoding_test.py +480 -0
- sinter/_decoding/_decoding_vacuous.py +38 -0
- sinter/_decoding/_perfectionist_sampler.py +38 -0
- sinter/_decoding/_sampler.py +72 -0
- sinter/_decoding/_stim_then_decode_sampler.py +222 -0
- sinter/_decoding/_stim_then_decode_sampler_test.py +192 -0
- sinter/_plotting.py +619 -0
- sinter/_plotting_test.py +108 -0
- sinter/_predict.py +381 -0
- sinter/_predict_test.py +227 -0
- sinter/_probability_util.py +519 -0
- sinter/_probability_util_test.py +281 -0
- sinter-1.15.0.data/data/README.md +332 -0
- sinter-1.15.0.data/data/readme_example_plot.png +0 -0
- sinter-1.15.0.data/data/requirements.txt +4 -0
- sinter-1.15.0.dist-info/METADATA +354 -0
- sinter-1.15.0.dist-info/RECORD +62 -0
- sinter-1.15.0.dist-info/WHEEL +5 -0
- sinter-1.15.0.dist-info/entry_points.txt +2 -0
- sinter-1.15.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
from unittest import mock
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
import sinter
|
|
6
|
+
import stim
|
|
7
|
+
from sinter._collection._sampler_ramp_throttled import (
|
|
8
|
+
CompiledRampThrottledSampler,
|
|
9
|
+
RampThrottledSampler,
|
|
10
|
+
)
|
|
11
|
+
from sinter._data import AnonTaskStats, Task
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MockSampler(sinter.Sampler, sinter.CompiledSampler):
    """Fake sampler that records every `suggested_shots` value it is asked for.

    The same object plays both roles (sampler factory and compiled sampler), so
    the throttling wrapper under test can be pointed straight at it.
    """

    def __init__(self):
        # History of `suggested_shots` values, in the order they were requested.
        self.calls = []

    def compiled_sampler_for_task(self, task: Task) -> sinter.CompiledSampler:
        # Every task shares this one object; nothing is compiled per task.
        return self

    def sample(self, suggested_shots: int) -> AnonTaskStats:
        self.calls.append(suggested_shots)
        # Report a runtime that grows linearly with the number of shots requested.
        simulated_seconds = 0.001 * suggested_shots
        return AnonTaskStats(
            shots=suggested_shots,
            errors=1,
            seconds=simulated_seconds,
        )
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@pytest.fixture
def mock_sampler():
    """Hand each test its own `MockSampler`, starting with an empty call log."""
    fresh = MockSampler()
    return fresh
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def test_initial_batch_size(mock_sampler):
    """A freshly built throttled sampler begins by requesting one-shot batches."""
    throttled = CompiledRampThrottledSampler(
        sub_sampler=mock_sampler,
        target_batch_seconds=1.0,
        max_batch_shots=1024,
    )

    throttled.sample(100)

    # 100 shots were offered, but only a single shot is requested the first time.
    first_request = mock_sampler.calls[0]
    assert first_request == 1
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def test_batch_size_ramps_up(mock_sampler):
    """Quick batches make the requested batch size grow."""
    throttled = CompiledRampThrottledSampler(
        sub_sampler=mock_sampler,
        target_batch_seconds=1.0,
        max_batch_shots=1024,
    )

    # Fake clock: one (start, stop) reading pair per sample() call, and each
    # batch appears to take 1 ms.
    clock_readings = [0.0, 0.001, 0.02, 0.021, 0.03, 0.031]
    with mock.patch("time.monotonic", side_effect=clock_readings):
        for _ in range(3):
            throttled.sample(100)

    # Expected: 1 shot first; then four doublings to 16; then more doubling,
    # clamped to the 100 shots offered.
    assert mock_sampler.calls == [1, 16, 100]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def test_batch_size_decreases(mock_sampler):
    """Slow batches make the requested batch size shrink."""
    throttled = CompiledRampThrottledSampler(
        sub_sampler=mock_sampler,
        target_batch_seconds=0.1,
        max_batch_shots=1024,
    )
    # Begin from a large batch so there is room to shrink.
    throttled.batch_shots = 64

    # Each (start, stop) pair spans 0.15 s, more than 1.3x the 0.1 s target.
    clock_readings = [0.0, 0.15, 0.5, 0.65]
    with mock.patch("time.monotonic", side_effect=clock_readings):
        for _ in range(2):
            throttled.sample(100)

    # First request uses the preset 64; the slow batch then halves it to 32.
    assert mock_sampler.calls == [64, 32]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def test_respects_max_batch_shots(mock_sampler):
    """The requested batch size never exceeds `max_batch_shots`."""
    throttled = CompiledRampThrottledSampler(
        sub_sampler=mock_sampler,
        target_batch_seconds=1.0,
        max_batch_shots=16,  # deliberately tiny cap
    )
    # Start just below the cap.
    throttled.batch_shots = 8

    # Very quick batches: one (start, stop) reading pair per sample() call.
    clock_readings = [0.0, 0.001, 0.02, 0.021, 0.03, 0.031]
    with mock.patch("time.monotonic", side_effect=clock_readings):
        for _ in range(3):
            throttled.sample(100)

    # 8 doubles to 16, then stays pinned at the 16-shot cap.
    assert mock_sampler.calls == [8, 16, 16]
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def test_respects_max_shots_parameter(mock_sampler):
    """A batch never requests more shots than `sample` was offered."""
    throttled = CompiledRampThrottledSampler(
        sub_sampler=mock_sampler,
        target_batch_seconds=1.0,
        max_batch_shots=1024,
    )
    # Preset a batch size well above what is about to be offered.
    throttled.batch_shots = 100

    offered_shots = 10
    throttled.sample(offered_shots)

    # The sub-sampler must see the 10 offered shots, not the preset 100.
    assert mock_sampler.calls[0] == offered_shots
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def test_sub_sampler_parameter_pass_through(mock_sampler):
    """Compiling a task forwards the factory's settings to the compiled sampler."""
    factory = RampThrottledSampler(
        sub_sampler=mock_sampler,
        target_batch_seconds=0.5,
        max_batch_shots=512,
    )

    empty_task = Task(circuit=stim.Circuit(), decoder="test")
    compiled = factory.compiled_sampler_for_task(empty_task)

    assert isinstance(compiled, CompiledRampThrottledSampler)
    assert (compiled.target_batch_seconds, compiled.max_batch_shots) == (0.5, 512)
    # Ramp-up starts from a single shot.
    assert compiled.batch_shots == 1
|
|
File without changes
|
sinter/_command/_main.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from typing import Optional, List
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def main(*, command_line_args: Optional[List[str]] = None):
    """Entry point of the `sinter` command line tool.

    Dispatches on the first argument to the `combine`, `collect`, `plot` or
    `predict` sub-command, passing it the remaining arguments. Any other first
    argument (or none) prints the list of available commands to stderr.

    Args:
        command_line_args: The arguments after the program name. Defaults to
            `sys.argv[1:]`.

    Returns:
        The selected sub-command's return value; None after printing usage for an
        explicit help request.

    Raises:
        SystemExit: With code 1 when no command, or an unrecognized one, is given.
    """
    if command_line_args is None:
        command_line_args = sys.argv[1:]

    mode = command_line_args[0] if command_line_args else None
    rest = command_line_args[1:]

    # Each sub-command module is imported only when that sub-command is selected.
    if mode == 'combine':
        from sinter._command._main_combine import main_combine
        return main_combine(command_line_args=rest)
    if mode == 'collect':
        from sinter._command._main_collect import main_collect
        return main_collect(command_line_args=rest)
    if mode == 'plot':
        from sinter._command._main_plot import main_plot
        return main_plot(command_line_args=rest)
    if mode == 'predict':
        from sinter._command._main_predict import main_predict
        return main_predict(command_line_args=rest)

    want_help = mode in ['help', 'h', '--help', '-help', '-h', '--h']
    if not want_help:
        if command_line_args and not command_line_args[0].startswith('-'):
            print(f"\033[31mUnrecognized command: sinter {command_line_args[0]}\033[0m\n", file=sys.stderr)
        else:
            print("\033[31mDidn't specify a command.\033[0m\n", file=sys.stderr)
    # Keep this list in sync with the dispatch above (`predict` was previously missing).
    print("Available commands are:\n"
          " sinter collect\n"
          " sinter combine\n"
          " sinter plot\n"
          " sinter predict", file=sys.stderr)
    if not want_help:
        sys.exit(1)


if __name__ == '__main__':
    main()
|
|
@@ -0,0 +1,350 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import math
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
from typing import Iterator, Any, Tuple, List, Callable, Optional
|
|
6
|
+
from typing import cast
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import stim
|
|
10
|
+
|
|
11
|
+
from sinter._collection import ThrottledProgressPrinter
|
|
12
|
+
from sinter._data import Task
|
|
13
|
+
from sinter._collection import collect, Progress, post_selection_mask_from_predicate
|
|
14
|
+
from sinter._command._main_combine import ExistingData, CSV_HEADER
|
|
15
|
+
from sinter._decoding._decoding_all_built_in_decoders import BUILT_IN_SAMPLERS
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def iter_file_paths_into_goals(circuit_paths: Iterator[str],
                               metadata_func: Callable,
                               postselected_detectors_predicate: Optional[Callable[[int, Any, Tuple[float, ...]], bool]],
                               postselected_observables_predicate: Callable[[int, Any], bool],
                               ) -> Iterator[Task]:
    """Lazily turn circuit file paths into `Task`s ready for collection.

    For each path, this function:
    - reads the file and parses it as a `stim.Circuit`;
    - computes the task's JSON metadata;
    - derives the detector and observable postselection masks from the given
      predicates.

    Args:
        circuit_paths: Paths of the circuit files to load, in order.
        metadata_func: Called as `metadata_func(path=..., circuit=...)`. Its result
            becomes the task's `json_metadata` and is handed to both predicates.
        postselected_detectors_predicate: When not None, passed to
            `post_selection_mask_from_predicate` to choose detectors to
            postselect on. None disables detector postselection.
        postselected_observables_predicate: Called as `(index, metadata)` for every
            observable index; a truthy result marks that observable as
            postselected.

    Yields:
        One `Task` per path. Each postselection mask is None when nothing is
        selected.
    """
    for path in circuit_paths:
        with open(path) as f:
            circuit_text = f.read()
        circuit = stim.Circuit(circuit_text)

        metadata = metadata_func(path=path, circuit=circuit)

        post_mask = None
        if postselected_detectors_predicate is not None:
            post_mask = post_selection_mask_from_predicate(
                circuit,
                metadata=metadata,
                postselected_detectors_predicate=postselected_detectors_predicate,
            )
            # An all-zero mask postselects nothing; represent that as "no mask".
            if not np.any(post_mask):
                post_mask = None

        postselected_observables = [
            k
            for k in range(circuit.num_observables)
            if postselected_observables_predicate(k, metadata)
        ]
        # Check the list for emptiness. `any(postselected_observables)` would test
        # the index values themselves, so postselecting only observable 0 (the list
        # `[0]`) would wrongly be treated as "nothing postselected".
        if postselected_observables:
            # Packed bit mask: observable k is bit (k % 8) of byte (k // 8).
            postselected_observables_mask = np.zeros(
                shape=math.ceil(circuit.num_observables / 8), dtype=np.uint8)
            for k in postselected_observables:
                postselected_observables_mask[k // 8] |= 1 << (k % 8)
        else:
            postselected_observables_mask = None

        yield Task(
            circuit=circuit,
            postselection_mask=post_mask,
            postselected_observables_mask=postselected_observables_mask,
            json_metadata=metadata,
        )
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def parse_args(args: List[str]) -> Any:
    """Parse and post-process the command line arguments of `sinter collect`.

    Beyond plain argparse parsing, several string arguments are turned into
    ready-to-use values:

    - `metadata_func` becomes a callable `(*, path, circuit) -> metadata`.
      The value "auto" is shorthand for
      `sinter.comma_separated_key_values(path)`.
    - `postselected_observables_predicate` becomes a callable
      `(index, metadata) -> bool-like`.
    - `postselected_detectors_predicate` becomes a callable
      `(index, metadata, coords) -> bool-like`, or None when detector
      postselection is off.
    - A new `custom_decoders` attribute holds the merged dictionary produced by
      the `--custom_decoders_module_function` functions, or None.
    - `allowed_cpu_affinity_ids` becomes a flat list of ints, or stays None.

    The expression-valued arguments are evaluated as Python code by design. They
    come from the user's own command line and must never be built from
    untrusted input.

    Args:
        args: The command line arguments that follow `sinter collect`.

    Returns:
        The argparse namespace with the post-processed attributes described above.

    Raises:
        ValueError: The postselection flags conflict, or a
            `--custom_decoders_module_function` entry is malformed, or a decoder
            name is unknown, or an `--allowed_cpu_affinity_ids` expression fails
            or isn't an int / iterable of ints.
        SystemExit: Raised by argparse itself for unparseable arguments.
    """
    parser = argparse.ArgumentParser(description='Collect Monte Carlo samples.',
                                     prog='sinter collect')
    parser.add_argument('--circuits',
                        nargs='+',
                        required=True,
                        help='Circuit files to sample from and decode.\n'
                             'This parameter can be given multiple arguments.')
    parser.add_argument('--decoders',
                        type=str,
                        nargs='+',
                        required=True,
                        help='The decoder to use to predict observables from detection events.')
    parser.add_argument('--custom_decoders_module_function',
                        default=None,
                        nargs='+',
                        help='Use the syntax "module:function" to "import function from module" '
                             'and use the result of "function()" as the custom_decoders '
                             'dictionary. The dictionary must map strings to stim.Decoder '
                             'instances.')
    parser.add_argument('--max_shots',
                        type=int,
                        default=None,
                        help='Sampling of a circuit will stop if this many shots have been taken.')
    parser.add_argument('--max_errors',
                        type=int,
                        default=None,
                        help='Sampling of a circuit will stop if this many errors have been seen.')
    parser.add_argument('--processes',
                        default='auto',
                        type=str,
                        help='Number of processes to use for simultaneous sampling and decoding. '
                             'Must be either a number or "auto" which sets it to the number of '
                             'CPUs on the machine.')
    parser.add_argument('--save_resume_filepath',
                        type=str,
                        default=None,
                        help='Activates MERGE mode.\n'
                             "If save_resume_filepath doesn't exist, initializes it with a CSV header.\n"
                             'CSV data already at save_resume_filepath counts towards max_shots and max_errors.\n'
                             'Collected data is appended to save_resume_filepath.\n'
                             'Note that MERGE mode is tolerant to failures: if the process is killed, it can simply be restarted and it will pick up where it left off.\n'
                             'Note that MERGE mode is idempotent: if sufficient data has been collected, no additional work is done when run again.')

    parser.add_argument('--start_batch_size',
                        type=int,
                        default=100,
                        help='Initial number of samples to batch together into one job.\n'
                             'Starting small prevents over-sampling of circuits above threshold.\n'
                             'The allowed batch size increases exponentially from this starting point.')
    parser.add_argument('--max_batch_size',
                        type=int,
                        default=None,
                        help='Maximum number of samples to batch together into one job.\n'
                             'Bigger values increase the delay between jobs finishing.\n'
                             'Smaller values decrease the amount of aggregation of results, increasing the amount of output information.')
    parser.add_argument('--max_batch_seconds',
                        type=int,
                        default=None,
                        help='Limits number of shots in a batch so that the estimated runtime of the batch is below this amount.')
    parser.add_argument('--postselect_detectors_with_non_zero_4th_coord',
                        help='Turns on detector postselection. '
                             'If any detector with a non-zero 4th coordinate fires, the shot is discarded.',
                        action='store_true')
    # Adjacent string literals are concatenated, so each fragment must end with its
    # own space or newline (previously "decoder.The" and "statistic.Available" ran together).
    parser.add_argument('--postselected_detectors_predicate',
                        type=str,
                        default='''False''',
                        help='Specifies a predicate used to decide which detectors to postselect. '
                             'When a postselected detector produces a detection event, the shot is discarded instead of being given to the decoder. '
                             'The number of discarded shots is tracked as a statistic.\n'
                             'Available values:\n'
                             ' index: The unique number identifying the detector, determined by the order of detectors in the circuit file.\n'
                             ' coords: The coordinate data associated with the detector. An empty tuple, if the circuit file did not specify detector coordinates.\n'
                             ' metadata: The metadata associated with the task being sampled.\n'
                             'Expected expression type:\n'
                             ' Something that can be given to `bool` to get False (do not postselect) or True (yes postselect).\n'
                             'Examples:\n'
                             ''' --postselected_detectors_predicate "coords[2] == 0"\n'''
                             ''' --postselected_detectors_predicate "coords[3] < metadata['postselection_level']"\n''')
    parser.add_argument('--postselected_observables_predicate',
                        type=str,
                        default='''False''',
                        help='Specifies a predicate used to decide which observables to postselect. '
                             'When a decoder mispredicts a postselected observable, the shot is discarded instead of counting as an error.\n'
                             'Available values:\n'
                             ' index: The index of the observable to postselect or not.\n'
                             ' metadata: The metadata associated with the task.\n'
                             'Expected expression type:\n'
                             ' Something that can be given to `bool` to get False (do not postselect) or True (yes postselect).\n'
                             'Examples:\n'
                             ''' --postselected_observables_predicate "False"\n'''
                             ''' --postselected_observables_predicate "metadata['d'] == 5 and index >= 2"\n''')
    parser.add_argument('--count_observable_error_combos',
                        help='When set, the returned stats will include custom '
                             'counts like `obs_mistake_mask=E_E__` counting '
                             'how many times the decoder made each pattern of '
                             'observable mistakes.',
                        action='store_true')
    parser.add_argument('--count_detection_events',
                        help='When set, the returned stats will include custom '
                             'counts `detectors_checked` and '
                             '`detection_events`. The detection fraction is '
                             'the ratio of these two numbers.',
                        action='store_true')
    parser.add_argument('--quiet',
                        help='Disables writing progress to stderr.',
                        action='store_true')
    parser.add_argument('--custom_error_count_key',
                        type=str,
                        help='Makes --max_errors apply to `stat.custom_counts[key]` '
                             'instead of to `stat.errors`.',
                        default=None)
    parser.add_argument('--allowed_cpu_affinity_ids',
                        type=str,
                        nargs='+',
                        help='Controls which CPUs workers can be pinned to. By default, all'
                             ' CPUs are used. Specifying this argument makes it so that '
                             'only the given CPU ids can be pinned. The given arguments '
                             'will be evaluated as python expressions. The expressions '
                             'should be integers or iterables of integers. So values like'
                             ' "1" and "[1, 2, 4]" and "range(5, 30)" all work.',
                        default=None)
    parser.add_argument('--also_print_results_to_stdout',
                        help='Even if writing to a file, also write results to stdout.',
                        action='store_true')
    parser.add_argument('--existing_data_filepaths',
                        nargs='*',
                        type=str,
                        default=(),
                        help='CSV data from these files counts towards max_shots and max_errors.\n'
                             'This parameter can be given multiple arguments.')
    parser.add_argument('--metadata_func',
                        type=str,
                        default="{'path': path}",
                        help='A python expression that associates json metadata with a circuit\'s results.\n'
                             'Set to "auto" to use "sinter.comma_separated_key_values(path)"\n'
                             'Values available to the expression:\n'
                             ' path: Relative path to the circuit file, from the command line arguments.\n'
                             ' circuit: The circuit itself, parsed from the file, as a stim.Circuit.\n'
                             'Expected type:\n'
                             ' A value that can be serialized into JSON, like a Dict[str, int].\n'
                             '\n'
                             ' Note that the decoder field is already recorded separately, so storing\n'
                             ' it in the metadata as well would be redundant. But something like\n'
                             ' decoder version could be usefully added.\n'
                             'Examples:\n'
                             ''' --metadata_func "{'path': path}"\n'''
                             ''' --metadata_func "auto"\n'''
                             ''' --metadata_func "{'n': circuit.num_qubits, 'p': float(path.split('/')[-1].split('.')[0])}"\n'''
                        )
    # Local import; the module is exposed to the --metadata_func expression below.
    # (Presumably kept local to avoid an import cycle -- unverified.)
    import sinter
    a = parser.parse_args(args=args)

    # Compile the user-supplied expressions into callables.
    if a.metadata_func == 'auto':
        a.metadata_func = "sinter.comma_separated_key_values(path)"
    a.metadata_func = eval(compile(
        'lambda *, path, circuit: ' + a.metadata_func,
        filename='metadata_func:command_line_arg',
        mode='eval'), {'sinter': sinter})
    a.postselected_observables_predicate = eval(compile(
        'lambda index, metadata: ' + a.postselected_observables_predicate,
        filename='postselected_observables_predicate:command_line_arg',
        mode='eval'))

    # The literal default 'False' means "no custom detector predicate was given".
    if a.postselected_detectors_predicate == 'False':
        if a.postselect_detectors_with_non_zero_4th_coord:
            a.postselected_detectors_predicate = lambda index, metadata, coords: coords[3]
        else:
            a.postselected_detectors_predicate = None
    else:
        if a.postselect_detectors_with_non_zero_4th_coord:
            raise ValueError("Can't specify both --postselect_detectors_with_non_zero_4th_coord and --postselected_detectors_predicate")
        a.postselected_detectors_predicate = eval(compile(
            'lambda index, metadata, coords: ' + cast(str, a.postselected_detectors_predicate),
            filename='postselected_detectors_predicate:command_line_arg',
            mode='eval'))

    if a.custom_decoders_module_function is not None:
        all_custom_decoders = {}
        for entry in a.custom_decoders_module_function:
            terms = entry.split(':')
            if len(terms) != 2:
                raise ValueError("--custom_decoders_module_function didn't have exactly one colon "
                                 "separating a module name from a function name. Expected an argument "
                                 "of the form --custom_decoders_module_function 'module:function'")
            module, function = terms
            exec_namespace = {'__name__': '[]'}
            exec(f"from {module} import {function} as _custom_decoders", exec_namespace)
            custom_decoders = exec_namespace['_custom_decoders']()
            # Later entries win when two functions define the same decoder name.
            all_custom_decoders = {**all_custom_decoders, **custom_decoders}
        a.custom_decoders = all_custom_decoders
    else:
        a.custom_decoders = None

    for decoder in a.decoders:
        if decoder not in BUILT_IN_SAMPLERS and (a.custom_decoders is None or decoder not in a.custom_decoders):
            message = f"Not a recognized decoder or sampler: {decoder=}.\n"
            message += f"Available built-in decoders and samplers: {sorted(name for name in BUILT_IN_SAMPLERS.keys() if 'internal' not in name)}.\n"
            if a.custom_decoders is None:
                message += "No custom decoders are available. --custom_decoders_module_function wasn't specified."
            else:
                message += f"Available custom decoders: {sorted(a.custom_decoders.keys())}."
            raise ValueError(message)

    if a.allowed_cpu_affinity_ids is not None:
        cpu_ids: List[int] = []
        for expr in a.allowed_cpu_affinity_ids:
            try:
                value = eval(expr, {}, {})
                if isinstance(value, int):
                    cpu_ids.append(value)
                else:
                    # Materialize first. Otherwise the type check would consume a
                    # one-shot iterator (e.g. a generator expression) and nothing
                    # would be left to extend with.
                    items = list(value)
                    if not all(isinstance(item, int) for item in items):
                        raise ValueError("Not an integer or iterable of integers.")
                    cpu_ids.extend(items)
            except Exception as ex:
                # Previously missing the `f` prefix, so the offending expression
                # was never shown.
                raise ValueError(f"Failed to eval {expr!r} for --allowed_cpu_affinity_ids") from ex
        a.allowed_cpu_affinity_ids = cpu_ids
    return a
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def open_merge_file(path: str) -> Tuple[Any, ExistingData]:
    """Open a CSV results file for appending, along with the data already in it.

    If the file exists, its contents are loaded with `ExistingData.from_file` and
    the file is reopened in append mode. If it does not exist, it is created with
    a CSV header line and an empty `ExistingData` is returned.

    Returns:
        `(handle, existing)`. The caller owns `handle` and is responsible for
        closing it.
    """
    try:
        prior = ExistingData.from_file(path)
        handle = open(path, 'a')
    except FileNotFoundError:
        handle = open(path, 'w')
        print(CSV_HEADER, file=handle)
        prior = ExistingData()
    return handle, prior
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def main_collect(*, command_line_args: List[str]):
    """Run the `sinter collect` command.

    Parses the arguments, builds tasks from the circuit files, and hands them to
    `collect`. Result stats are written as CSV lines (header first) to stdout in
    two cases: no `--save_resume_filepath` was given, or
    `--also_print_results_to_stdout` is set. Progress is shown through a
    `ThrottledProgressPrinter` unless `--quiet` is set. A KeyboardInterrupt
    during collection ends the command without a traceback.

    Args:
        command_line_args: The arguments that follow `sinter collect`.

    Raises:
        ValueError: `--processes` is neither "auto" nor a positive integer, plus
            anything `parse_args` raises.
    """
    args = parse_args(args=command_line_args)

    iter_tasks = iter_file_paths_into_goals(
        circuit_paths=args.circuits,
        metadata_func=args.metadata_func,
        postselected_detectors_predicate=args.postselected_detectors_predicate,
        postselected_observables_predicate=args.postselected_observables_predicate,
    )
    # Passed to `collect` only as a hint: one task per (circuit, decoder) pair.
    num_tasks = len(args.circuits) * len(args.decoders)

    print_to_stdout = args.also_print_results_to_stdout or args.save_resume_filepath is None

    did_work = False
    printer = ThrottledProgressPrinter(
        outs=[],
        print_progress=not args.quiet,
        min_progress_delay=0.03 if args.also_print_results_to_stdout else 0.1,
    )
    if print_to_stdout:
        printer.outs.append(sys.stdout)

    def on_progress(sample: Progress) -> None:
        nonlocal did_work
        for stats in sample.new_stats:
            # Emit the CSV header lazily, right before the first row of results.
            if not did_work:
                printer.print_out(CSV_HEADER)
                did_work = True
            printer.print_out(stats.to_csv_line())

        msg = sample.status_message
        if msg == 'KeyboardInterrupt':
            printer.show_latest_progress('\nInterrupted. Output is flushed. Cleaning up workers...')
            printer.flush()
        else:
            printer.show_latest_progress(msg)

    if args.processes == 'auto':
        # os.cpu_count() returns None when the CPU count can't be determined.
        num_workers = os.cpu_count() or 1
    else:
        try:
            num_workers = int(args.processes)
        except ValueError:
            num_workers = 0
        if num_workers < 1:
            # Zero is rejected too, so the requirement is "positive", not "non-negative".
            raise ValueError(f'--processes must be a positive integer, or "auto", but was: {args.processes}')
    try:
        collect(
            num_workers=num_workers,
            hint_num_tasks=num_tasks,
            tasks=iter_tasks,
            print_progress=False,
            save_resume_filepath=args.save_resume_filepath,
            existing_data_filepaths=args.existing_data_filepaths,
            progress_callback=on_progress,
            max_errors=args.max_errors,
            max_shots=args.max_shots,
            count_detection_events=args.count_detection_events,
            count_observable_error_combos=args.count_observable_error_combos,
            decoders=args.decoders,
            max_batch_seconds=args.max_batch_seconds,
            max_batch_size=args.max_batch_size,
            start_batch_size=args.start_batch_size,
            custom_decoders=args.custom_decoders,
            custom_error_count_key=args.custom_error_count_key,
            allowed_cpu_affinity_ids=args.allowed_cpu_affinity_ids,
        )
    except KeyboardInterrupt:
        # Deliberate quiet exit. Presumably the interruption was already reported
        # via on_progress's 'KeyboardInterrupt' status -- not verified here.
        pass
|