sinter 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sinter might be problematic. Click here for more details.
- sinter/__init__.py +47 -0
- sinter/_collection/__init__.py +10 -0
- sinter/_collection/_collection.py +480 -0
- sinter/_collection/_collection_manager.py +581 -0
- sinter/_collection/_collection_manager_test.py +287 -0
- sinter/_collection/_collection_test.py +317 -0
- sinter/_collection/_collection_worker_loop.py +35 -0
- sinter/_collection/_collection_worker_state.py +259 -0
- sinter/_collection/_collection_worker_test.py +222 -0
- sinter/_collection/_mux_sampler.py +56 -0
- sinter/_collection/_printer.py +65 -0
- sinter/_collection/_sampler_ramp_throttled.py +66 -0
- sinter/_collection/_sampler_ramp_throttled_test.py +144 -0
- sinter/_command/__init__.py +0 -0
- sinter/_command/_main.py +39 -0
- sinter/_command/_main_collect.py +350 -0
- sinter/_command/_main_collect_test.py +482 -0
- sinter/_command/_main_combine.py +84 -0
- sinter/_command/_main_combine_test.py +153 -0
- sinter/_command/_main_plot.py +817 -0
- sinter/_command/_main_plot_test.py +445 -0
- sinter/_command/_main_predict.py +75 -0
- sinter/_command/_main_predict_test.py +36 -0
- sinter/_data/__init__.py +20 -0
- sinter/_data/_anon_task_stats.py +89 -0
- sinter/_data/_anon_task_stats_test.py +35 -0
- sinter/_data/_collection_options.py +106 -0
- sinter/_data/_collection_options_test.py +24 -0
- sinter/_data/_csv_out.py +74 -0
- sinter/_data/_existing_data.py +173 -0
- sinter/_data/_existing_data_test.py +41 -0
- sinter/_data/_task.py +311 -0
- sinter/_data/_task_stats.py +244 -0
- sinter/_data/_task_stats_test.py +140 -0
- sinter/_data/_task_test.py +38 -0
- sinter/_decoding/__init__.py +16 -0
- sinter/_decoding/_decoding.py +419 -0
- sinter/_decoding/_decoding_all_built_in_decoders.py +25 -0
- sinter/_decoding/_decoding_decoder_class.py +161 -0
- sinter/_decoding/_decoding_fusion_blossom.py +193 -0
- sinter/_decoding/_decoding_mwpf.py +302 -0
- sinter/_decoding/_decoding_pymatching.py +81 -0
- sinter/_decoding/_decoding_test.py +480 -0
- sinter/_decoding/_decoding_vacuous.py +38 -0
- sinter/_decoding/_perfectionist_sampler.py +38 -0
- sinter/_decoding/_sampler.py +72 -0
- sinter/_decoding/_stim_then_decode_sampler.py +222 -0
- sinter/_decoding/_stim_then_decode_sampler_test.py +192 -0
- sinter/_plotting.py +619 -0
- sinter/_plotting_test.py +108 -0
- sinter/_predict.py +381 -0
- sinter/_predict_test.py +227 -0
- sinter/_probability_util.py +519 -0
- sinter/_probability_util_test.py +281 -0
- sinter-1.15.0.data/data/README.md +332 -0
- sinter-1.15.0.data/data/readme_example_plot.png +0 -0
- sinter-1.15.0.data/data/requirements.txt +4 -0
- sinter-1.15.0.dist-info/METADATA +354 -0
- sinter-1.15.0.dist-info/RECORD +62 -0
- sinter-1.15.0.dist-info/WHEEL +5 -0
- sinter-1.15.0.dist-info/entry_points.txt +2 -0
- sinter-1.15.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
import queue
|
|
2
|
+
import time
|
|
3
|
+
from typing import Any
|
|
4
|
+
from typing import Optional
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
import stim
|
|
8
|
+
|
|
9
|
+
from sinter._data import AnonTaskStats
|
|
10
|
+
from sinter._data import CollectionOptions
|
|
11
|
+
from sinter._data import Task
|
|
12
|
+
from sinter._decoding import CompiledSampler
|
|
13
|
+
from sinter._decoding import Sampler
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
import multiprocessing
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _fill_in_task(task: Task) -> Task:
    """Returns a copy of `task` with its circuit and detector error model materialized.

    If the task already has both a circuit and a detector error model, the task
    is returned unchanged (no copy is made). Otherwise the circuit is loaded
    from `task.circuit_path` and/or a detector error model is derived from the
    circuit, and a new Task carrying the filled-in values is returned.
    """
    changed = False
    circuit = task.circuit
    if circuit is None:
        # The task was specified by file path rather than by in-memory circuit.
        circuit = stim.Circuit.from_file(task.circuit_path)
        changed = True
    dem = task.detector_error_model
    if dem is None:
        # Prefer a decomposed error model (needed by matching-based decoders),
        # falling back to progressively more permissive conversions when the
        # circuit's errors can't be decomposed or loops must be flattened.
        try:
            dem = circuit.detector_error_model(decompose_errors=True, approximate_disjoint_errors=True)
        except ValueError:
            try:
                dem = circuit.detector_error_model(approximate_disjoint_errors=True)
            except ValueError:
                dem = circuit.detector_error_model(approximate_disjoint_errors=True, flatten_loops=True)
        changed = True
    if not changed:
        return task
    return Task(
        circuit=circuit,
        decoder=task.decoder,
        detector_error_model=dem,
        postselection_mask=task.postselection_mask,
        postselected_observables_mask=task.postselected_observables_mask,
        json_metadata=task.json_metadata,
        collection_options=task.collection_options,
    )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class CollectionWorkerState:
    """State machine for a worker process that samples shots for tasks.

    The worker receives commands from a manager over the `inp` queue
    ('change_job', 'accept_shots', 'return_shots', 'flush_results',
    'compute_strong_id', 'set_soft_error_flush_threshold', 'stop') and reports
    progress back over the `out` queue as (message_type, worker_id, payload)
    tuples. Sampled statistics are accumulated locally and flushed to the
    manager periodically, with a flush period that ramps up from a small
    initial value toward `flush_period`.
    """

    def __init__(
        self,
        *,
        flush_period: float,
        worker_id: int,
        inp: 'multiprocessing.Queue',
        out: 'multiprocessing.Queue',
        sampler: Sampler,
        custom_error_count_key: Optional[str],
    ):
        """
        Args:
            flush_period: Maximum number of seconds between result flushes.
            worker_id: Identifier echoed in every message sent to the manager.
            inp: Queue of commands from the manager to this worker.
            out: Queue of status messages from this worker to the manager.
            sampler: Produces compiled samplers for tasks given to the worker.
            custom_error_count_key: If not None, errors are counted from this
                key of the stats' custom_counts instead of the `errors` field.
        """
        assert isinstance(flush_period, (int, float))
        assert isinstance(sampler, Sampler)
        self.max_flush_period = flush_period
        # Flushing starts frequent and ramps up toward max_flush_period (see
        # flush_results), so early progress is visible quickly.
        self.cur_flush_period = 0.01
        self.inp = inp
        self.out = out
        self.sampler = sampler
        self.compiled_sampler: CompiledSampler | None = None
        self.worker_id = worker_id

        self.current_task: Task | None = None
        # Remaining error budget for the current job; sampling stops early
        # once this reaches zero.
        self.current_error_cutoff: int | None = None
        self.custom_error_count_key = custom_error_count_key
        self.current_task_shots_left: int = 0
        # Stats sampled since the last flush message.
        self.unflushed_results: AnonTaskStats = AnonTaskStats()
        self.last_flush_message_time = time.monotonic()
        # Flush immediately once at least this many unflushed errors are seen.
        self.soft_error_flush_threshold: int = 1

    def _send_message_to_manager(self, message: Any):
        """Sends a (message_type, worker_id, payload) tuple to the manager."""
        self.out.put(message)

    def state_summary(self) -> str:
        """Returns a multi-line human-readable dump of the worker's state (for debugging)."""
        lines = [
            f'Worker(id={self.worker_id}) [',
            f'    max_flush_period={self.max_flush_period}',
            f'    cur_flush_period={self.cur_flush_period}',
            f'    sampler={self.sampler}',
            f'    compiled_sampler={self.compiled_sampler}',
            f'    current_task={self.current_task}',
            f'    current_error_cutoff={self.current_error_cutoff}',
            f'    custom_error_count_key={self.custom_error_count_key}',
            f'    current_task_shots_left={self.current_task_shots_left}',
            f'    unflushed_results={self.unflushed_results}',
            f'    last_flush_message_time={self.last_flush_message_time}',
            f'    soft_error_flush_threshold={self.soft_error_flush_threshold}',
            f']',
        ]
        return '\n' + '\n'.join(lines) + '\n'

    def flush_results(self):
        """Sends accumulated stats to the manager, if any.

        Returns:
            True if a 'flushed_results' message was sent, False if there was
            nothing to flush.
        """
        if self.unflushed_results.shots > 0:
            self.last_flush_message_time = time.monotonic()
            # Geometric ramp: each flush lengthens the period, capped at the max.
            self.cur_flush_period = min(self.cur_flush_period * 1.4, self.max_flush_period)
            self._send_message_to_manager((
                'flushed_results',
                self.worker_id,
                (self.current_task.strong_id(), self.unflushed_results),
            ))
            self.unflushed_results = AnonTaskStats()
            return True
        return False

    def accept_shots(self, *, shots_delta: int):
        """Adds shots to the current job's workload and acknowledges to the manager."""
        assert shots_delta >= 0
        self.current_task_shots_left += shots_delta
        self._send_message_to_manager((
            'accepted_shots',
            self.worker_id,
            (self.current_task.strong_id(), shots_delta),
        ))

    def return_shots(self, *, requested_shots: int):
        """Gives back up to `requested_shots` unsampled shots to the manager.

        Returns fewer shots than requested if the worker doesn't have that many
        left. If this empties the worker's workload, pending results are
        flushed first so the manager has up-to-date totals.
        """
        assert requested_shots >= 0
        returned_shots = max(0, min(requested_shots, self.current_task_shots_left))
        self.current_task_shots_left -= returned_shots
        if self.current_task_shots_left <= 0:
            self.flush_results()
        self._send_message_to_manager((
            'returned_shots',
            self.worker_id,
            (self.current_task.strong_id(), returned_shots),
        ))

    def compute_strong_id(self, *, new_task: Task):
        """Computes and reports the strong id of a task (filling it in first if needed)."""
        strong_id = _fill_in_task(new_task).strong_id()
        self._send_message_to_manager((
            'computed_strong_id',
            self.worker_id,
            strong_id,
        ))

    def change_job(self, *, new_task: Task, new_collection_options: CollectionOptions):
        """Switches the worker to a new task, flushing any results from the old one."""
        self.flush_results()

        self.current_task = _fill_in_task(new_task)
        self.current_error_cutoff = new_collection_options.max_errors
        self.compiled_sampler = self.sampler.compiled_sampler_for_task(self.current_task)
        # Filling in the task guarantees the strong id is computable.
        assert self.current_task.strong_id() is not None
        self.current_task_shots_left = 0
        self.last_flush_message_time = time.monotonic()

        self._send_message_to_manager((
            'changed_job',
            self.worker_id,
            (self.current_task.strong_id(),),
        ))

    def process_messages(self) -> int:
        """Drains and handles all pending manager commands.

        Returns:
            The number of messages processed, or -1 if a 'stop' command was
            received (the caller should terminate its loop).
        """
        num_processed = 0
        while True:
            try:
                message = self.inp.get_nowait()
            except queue.Empty:
                return num_processed

            num_processed += 1
            message_type, message_body = message

            if message_type == 'stop':
                return -1

            elif message_type == 'flush_results':
                self.flush_results()

            elif message_type == 'compute_strong_id':
                assert isinstance(message_body, Task)
                self.compute_strong_id(new_task=message_body)

            elif message_type == 'change_job':
                new_task, new_collection_options, soft_error_flush_threshold = message_body
                # New job restarts the flush-period ramp.
                self.cur_flush_period = 0.01
                self.soft_error_flush_threshold = soft_error_flush_threshold
                assert isinstance(new_task, Task)
                self.change_job(new_task=new_task, new_collection_options=new_collection_options)

            elif message_type == 'set_soft_error_flush_threshold':
                soft_error_flush_threshold = message_body
                self.soft_error_flush_threshold = soft_error_flush_threshold

            elif message_type == 'accept_shots':
                job_key, shots_delta = message_body
                assert isinstance(shots_delta, int)
                assert job_key == self.current_task.strong_id()
                self.accept_shots(shots_delta=shots_delta)

            elif message_type == 'return_shots':
                job_key, requested_shots = message_body
                assert isinstance(requested_shots, int)
                assert job_key == self.current_task.strong_id()
                self.return_shots(requested_shots=requested_shots)

            else:
                raise NotImplementedError(f'{message_type=}')

    def num_unflushed_errors(self) -> int:
        """Returns the number of errors accumulated since the last flush.

        Uses the custom count key when one was configured, otherwise the
        standard `errors` field.
        """
        if self.custom_error_count_key is not None:
            return self.unflushed_results.custom_counts[self.custom_error_count_key]
        return self.unflushed_results.errors

    def do_some_work(self) -> bool:
        """Samples a batch of shots (if any are assigned) and flushes results when due.

        Returns:
            True if anything happened (shots sampled or results flushed),
            False if the worker was idle.
        """
        did_some_work = False

        # Sample some stats.
        if self.current_task_shots_left > 0:
            # Don't keep sampling if we've exceeded the number of errors needed.
            if self.current_error_cutoff is not None and self.current_error_cutoff <= 0:
                return self.flush_results()

            some_work_done = self.compiled_sampler.sample(self.current_task_shots_left)
            if some_work_done.shots < 1:
                raise ValueError(f"Sampler didn't do any work. It returned statistics with shots == 0: {some_work_done}.")
            assert isinstance(some_work_done, AnonTaskStats)
            self.current_task_shots_left -= some_work_done.shots
            if self.current_error_cutoff is not None:
                errors_done = some_work_done.custom_counts[self.custom_error_count_key] if self.custom_error_count_key is not None else some_work_done.errors
                self.current_error_cutoff -= errors_done
            self.unflushed_results += some_work_done
            did_some_work = True

        # Report them periodically.
        should_flush = False
        if self.num_unflushed_errors() >= self.soft_error_flush_threshold:
            should_flush = True
        if self.unflushed_results.shots > 0:
            # Flush when the job finished or the (ramping) flush period elapsed.
            if self.current_task_shots_left <= 0 or self.last_flush_message_time + self.cur_flush_period < time.monotonic():
                should_flush = True
        if should_flush:
            did_some_work |= self.flush_results()

        return did_some_work

    def run_message_loop(self):
        """Runs the worker's main loop until told to stop.

        Alternates between handling manager commands and sampling work,
        sleeping briefly when idle. A KeyboardInterrupt exits quietly; any
        other exception is reported to the manager (with traceback) via a
        'stopped_due_to_exception' message instead of propagating.
        """
        try:
            while True:
                num_messages_processed = self.process_messages()
                if num_messages_processed == -1:
                    break
                did_some_work = self.do_some_work()
                if not did_some_work and num_messages_processed == 0:
                    # Idle; avoid busy-spinning on the queues.
                    time.sleep(0.01)

        except KeyboardInterrupt:
            pass

        except BaseException as ex:
            import traceback
            self._send_message_to_manager((
                'stopped_due_to_exception',
                self.worker_id,
                (None if self.current_task is None else self.current_task.strong_id(), self.current_task_shots_left, self.unflushed_results, traceback.format_exc(), ex),
            ))
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import multiprocessing
|
|
3
|
+
import time
|
|
4
|
+
from typing import Any, List
|
|
5
|
+
|
|
6
|
+
import sinter
|
|
7
|
+
import stim
|
|
8
|
+
|
|
9
|
+
from sinter._collection._collection_worker_state import CollectionWorkerState
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class MockWorkHandler(sinter.Sampler, sinter.CompiledSampler):
    """Test double acting as both a Sampler and its own CompiledSampler.

    Scripted via `expected_task` (the task compilation must be asked for) and
    `expected` (a FIFO of (shot_count, canned_stats) pairs that `sample` must
    be called with, in order).
    """

    def __init__(self):
        self.expected = collections.deque()
        self.expected_task = None

    def compiled_sampler_for_task(self, task: sinter.Task) -> sinter.CompiledSampler:
        # Verify the worker compiled the task the test scripted, then hand
        # back this same object as the "compiled" sampler.
        assert task == self.expected_task
        return self

    def handles_throttling(self) -> bool:
        # Claim built-in throttling so no throttling wrapper gets inserted.
        return True

    def sample(self, shots: int) -> sinter.AnonTaskStats:
        # Pop the next scripted call, check the requested shot count matches,
        # and return the canned statistics.
        assert self.expected
        anticipated_shots, canned_result = self.expected.popleft()
        assert shots == anticipated_shots
        return canned_result
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _assert_drain_queue(q: multiprocessing.Queue, expected_contents: List[Any]):
    """Asserts the queue yields exactly `expected_contents`, in order, then is empty."""
    for expected_item in expected_contents:
        actual_item = q.get(timeout=0.1)
        assert actual_item == expected_item
    assert q.empty()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _put_wait_not_empty(q: multiprocessing.Queue, item: Any):
    """Puts `item` on the queue and blocks until it is observable by readers.

    multiprocessing queues hand items to a background feeder thread, so a put
    isn't immediately visible; spin until `empty()` reports the item arrived.
    """
    q.put(item)
    while True:
        if not q.empty():
            return
        time.sleep(0.0001)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def test_worker_stop():
    """Checks a worker changes jobs when told to, and exits its loop on 'stop'."""
    handler = MockWorkHandler()

    inp = multiprocessing.Queue()
    out = multiprocessing.Queue()
    # Don't block process teardown on unflushed queue contents.
    inp.cancel_join_thread()
    out.cancel_join_thread()

    worker = CollectionWorkerState(
        flush_period=-1,
        worker_id=5,
        sampler=handler,
        inp=inp,
        out=out,
        custom_error_count_key=None,
    )

    # Nothing queued yet: no messages processed, nothing sent back.
    assert worker.process_messages() == 0
    _assert_drain_queue(out, [])

    t0 = sinter.Task(
        circuit=stim.Circuit('H 0'),
        detector_error_model=stim.DetectorErrorModel(),
        decoder='mock',
        collection_options=sinter.CollectionOptions(max_shots=100_000_000),
        json_metadata={'a': 3},
    )
    handler.expected_task = t0

    # A 'change_job' command is acknowledged with a 'changed_job' message.
    _put_wait_not_empty(inp, ('change_job', (t0, sinter.CollectionOptions(max_errors=100_000_000), 100_000_000)))
    assert worker.process_messages() == 1
    _assert_drain_queue(out, [('changed_job', 5, (t0.strong_id(),))])

    # 'stop' makes process_messages return -1 (the terminate sentinel).
    _put_wait_not_empty(inp, ('stop', None))
    assert worker.process_messages() == -1
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def test_worker_skip_work():
    """Checks shot accounting: accepting shots, returning them, and idling once none remain."""
    handler = MockWorkHandler()

    inp = multiprocessing.Queue()
    out = multiprocessing.Queue()
    # Don't block process teardown on unflushed queue contents.
    inp.cancel_join_thread()
    out.cancel_join_thread()

    worker = CollectionWorkerState(
        flush_period=-1,
        worker_id=5,
        sampler=handler,
        inp=inp,
        out=out,
        custom_error_count_key=None,
    )

    assert worker.process_messages() == 0
    _assert_drain_queue(out, [])

    t0 = sinter.Task(
        circuit=stim.Circuit('H 0'),
        detector_error_model=stim.DetectorErrorModel(),
        decoder='mock',
        collection_options=sinter.CollectionOptions(max_shots=100_000_000),
        json_metadata={'a': 3},
    )
    handler.expected_task = t0
    _put_wait_not_empty(inp, ('change_job', (t0, sinter.CollectionOptions(max_errors=100_000_000), 100_000_000)))
    assert worker.process_messages() == 1
    _assert_drain_queue(out, [('changed_job', 5, (t0.strong_id(),))])

    # Accepting shots is acknowledged and tracked in current_task_shots_left.
    _put_wait_not_empty(inp, ('accept_shots', (t0.strong_id(), 10000)))
    assert worker.process_messages() == 1
    _assert_drain_queue(out, [('accepted_shots', 5, (t0.strong_id(), 10000))])

    assert worker.current_task == t0
    assert worker.current_task_shots_left == 10000
    assert worker.process_messages() == 0
    _assert_drain_queue(out, [])

    # Returning fewer shots than held returns exactly the requested amount.
    _put_wait_not_empty(inp, ('return_shots', (t0.strong_id(), 2000)))
    assert worker.process_messages() == 1
    _assert_drain_queue(out, [
        ('returned_shots', 5, (t0.strong_id(), 2000)),
    ])

    # Returning more shots than held is clamped to the remaining 8000.
    _put_wait_not_empty(inp, ('return_shots', (t0.strong_id(), 20000000)))
    assert worker.process_messages() == 1
    _assert_drain_queue(out, [
        ('returned_shots', 5, (t0.strong_id(), 8000)),
    ])

    # No shots left, so the worker has nothing to do (the mock would fail if sampled).
    assert not worker.do_some_work()
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def test_worker_finish_work():
    """Checks a worker samples its assigned shots in batches and flushes the stats."""
    handler = MockWorkHandler()

    inp = multiprocessing.Queue()
    out = multiprocessing.Queue()
    # Don't block process teardown on unflushed queue contents.
    inp.cancel_join_thread()
    out.cancel_join_thread()

    worker = CollectionWorkerState(
        flush_period=-1,
        worker_id=5,
        sampler=handler,
        inp=inp,
        out=out,
        custom_error_count_key=None,
    )

    assert worker.process_messages() == 0
    _assert_drain_queue(out, [])

    ta = sinter.Task(
        circuit=stim.Circuit('H 0'),
        detector_error_model=stim.DetectorErrorModel(),
        decoder='mock',
        collection_options=sinter.CollectionOptions(max_shots=100_000_000),
        json_metadata={'a': 3},
    )
    handler.expected_task = ta
    _put_wait_not_empty(inp, ('change_job', (ta, sinter.CollectionOptions(max_errors=100_000_000), 100_000_000)))
    _put_wait_not_empty(inp, ('accept_shots', (ta.strong_id(), 10000)))
    # The feeder thread may deliver the two commands with a delay; poll with a
    # 1 second deadline until both have been processed.
    t0 = time.monotonic()
    num_processed = 0
    while True:
        num_processed += worker.process_messages()
        if num_processed >= 2:
            break
        if time.monotonic() - t0 > 1:
            raise ValueError("Messages not processed")
    assert num_processed == 2
    _assert_drain_queue(out, [
        ('changed_job', 5, (ta.strong_id(),)),
        ('accepted_shots', 5, (ta.strong_id(), 10000)),
    ])

    assert worker.current_task == ta
    assert worker.current_task_shots_left == 10000
    assert worker.process_messages() == 0
    _assert_drain_queue(out, [])

    # Script the mock: asked for 10000 shots, it only completes 1000 of them.
    handler.expected.append((
        10000,
        sinter.AnonTaskStats(
            shots=1000,
            errors=23,
            discards=0,
            seconds=1,
        ),
    ))

    assert worker.do_some_work()
    worker.flush_results()
    _assert_drain_queue(out, [
        ('flushed_results', 5, (ta.strong_id(), sinter.AnonTaskStats(shots=1000, errors=23, discards=0, seconds=1)))])

    # Second batch finishes the remaining 9000 shots.
    handler.expected.append((
        9000,
        sinter.AnonTaskStats(
            shots=9000,
            errors=13,
            discards=0,
            seconds=1,
        ),
    ))

    assert worker.do_some_work()
    worker.flush_results()
    _assert_drain_queue(out, [
        ('flushed_results', 5, (ta.strong_id(), sinter.AnonTaskStats(
            shots=9000,
            errors=13,
            discards=0,
            seconds=1,
        ))),
    ])
    # All shots done: nothing left to sample and nothing left to flush.
    assert not worker.do_some_work()
    worker.flush_results()
    _assert_drain_queue(out, [])
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import pathlib
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
from sinter._data import Task
|
|
6
|
+
from sinter._decoding._decoding_all_built_in_decoders import BUILT_IN_SAMPLERS
|
|
7
|
+
from sinter._decoding._decoding_decoder_class import Decoder
|
|
8
|
+
from sinter._decoding._sampler import CompiledSampler
|
|
9
|
+
from sinter._decoding._sampler import Sampler
|
|
10
|
+
from sinter._decoding._stim_then_decode_sampler import StimThenDecodeSampler
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MuxSampler(Sampler):
    """Looks up the sampler to use for a task, by the task's decoder name."""

    def __init__(
        self,
        *,
        custom_decoders: Union[dict[str, Union[Decoder, Sampler]], None],
        count_observable_error_combos: bool,
        count_detection_events: bool,
        tmp_dir: Optional[pathlib.Path],
    ):
        """
        Args:
            custom_decoders: Optional mapping from decoder name to a custom
                Decoder or Sampler, consulted before the built-in samplers.
                May be None, meaning no custom decoders.
            count_observable_error_combos: Track per-observable-combination
                error counts (only supported for Decoder-backed sampling).
            count_detection_events: Track detection event counts (only
                supported for Decoder-backed sampling).
            tmp_dir: Directory for scratch files used by decoder-backed
                sampling.
        """
        self.custom_decoders = custom_decoders
        self.count_observable_error_combos = count_observable_error_combos
        self.count_detection_events = count_detection_events
        self.tmp_dir = tmp_dir

    def compiled_sampler_for_task(self, task: Task) -> CompiledSampler:
        """Compiles the task using whichever sampler its decoder name resolves to."""
        return self._resolve_sampler(task.decoder).compiled_sampler_for_task(task)

    def _resolve_sampler(self, name: str) -> Sampler:
        """Maps a decoder name to a Sampler, wrapping bare Decoders as needed.

        Raises:
            NotImplementedError: The name is unknown, the resolved object
                can't be turned into a Sampler, or an unsupported counting
                option was combined with a custom Sampler.
        """
        sub_sampler: Union[Decoder, Sampler]

        # Bug fix: custom_decoders is allowed to be None (see __init__'s type),
        # and `name in None` raises TypeError. Guard so unknown names reach the
        # intended NotImplementedError below instead.
        if self.custom_decoders is not None and name in self.custom_decoders:
            sub_sampler = self.custom_decoders[name]
        elif name in BUILT_IN_SAMPLERS:
            sub_sampler = BUILT_IN_SAMPLERS[name]
        else:
            raise NotImplementedError(f'Not a recognized decoder or sampler: {name=}. Did you forget to specify custom_decoders?')

        if isinstance(sub_sampler, Sampler):
            # Custom Samplers run the whole sampling loop themselves, so the
            # extra counting features (implemented in StimThenDecodeSampler)
            # can't be applied to them.
            if self.count_detection_events:
                raise NotImplementedError("'count_detection_events' not supported when using a custom Sampler (instead of a custom Decoder).")
            if self.count_observable_error_combos:
                raise NotImplementedError("'count_observable_error_combos' not supported when using a custom Sampler (instead of a custom Decoder).")
            return sub_sampler
        elif isinstance(sub_sampler, Decoder) or hasattr(sub_sampler, 'compile_decoder_for_dem'):
            # Duck-typed check covers Decoder-like objects that don't inherit
            # from sinter's Decoder class.
            return StimThenDecodeSampler(
                decoder=sub_sampler,
                count_detection_events=self.count_detection_events,
                count_observable_error_combos=self.count_observable_error_combos,
                tmp_dir=self.tmp_dir,
            )
        else:
            raise NotImplementedError(f"Don't know how to turn this into a Sampler: {sub_sampler!r}")
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
from typing import List, Any
|
|
3
|
+
|
|
4
|
+
import sys
|
|
5
|
+
import time
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ThrottledProgressPrinter:
    """Handles printing progress updates interspersed amongst output.

    Throttles the progress updates to not flood the screen when 100 show up
    at the same time, and instead only show the latest one.
    """
    def __init__(self, *, outs: List[Any], print_progress: bool, min_progress_delay: float):
        """
        Args:
            outs: File-like objects that normal (non-progress) output goes to.
            print_progress: If False, progress updates are silently dropped.
            min_progress_delay: Minimum seconds between printed progress lines.
        """
        self.outs = outs
        self.print_progress = print_progress
        # Earliest monotonic time at which the next progress line may print.
        self.next_can_print_time = time.monotonic()
        # Most recent progress message requested / actually printed.
        self.latest_msg = ''
        self.latest_printed_msg = ''
        self.min_progress_delay = min_progress_delay
        # True while a background thread is waiting to print a deferred update.
        self.is_worker_running = False
        # Guards all mutable state above; also serializes printing.
        self.lock = threading.Lock()

    def print_out(self, msg: str) -> None:
        """Prints a normal output line to all configured outputs."""
        with self.lock:
            for out in self.outs:
                print(msg, file=out, flush=True)

    def show_latest_progress(self, msg: str) -> None:
        """Requests that `msg` become the displayed progress line.

        Prints immediately if the throttle window has passed; otherwise spawns
        (at most one) background thread that prints the then-latest message
        once the window opens.
        """
        if not self.print_progress:
            return
        with self.lock:
            if msg == self.latest_msg:
                return
            self.latest_msg = msg
            if not self.is_worker_running:
                dt = self._try_print_else_delay()
                if dt > 0:
                    # Throttled: hand off to a worker thread to print later.
                    self.is_worker_running = True
                    threading.Thread(target=self._print_worker).start()

    def flush(self):
        """Forces the latest progress message out immediately, bypassing the throttle."""
        with self.lock:
            if self.latest_msg != "" and self.latest_printed_msg != self.latest_msg:
                # Progress goes to stderr, wrapped in ANSI red.
                print('\033[31m' + self.latest_msg + '\033[0m', file=sys.stderr, flush=True)
                self.latest_printed_msg = self.latest_msg

    def _try_print_else_delay(self) -> float:
        """Prints the latest message if the throttle window allows it.

        Must be called with `self.lock` held.

        Returns:
            0 if the message was printed (or no wait is needed), else the
            number of seconds to wait before trying again.
        """
        t = time.monotonic()
        dt = self.next_can_print_time - t
        if dt <= 0:
            self.next_can_print_time = t + self.min_progress_delay
            self.is_worker_running = False
            if self.latest_msg != "" and self.latest_msg != self.latest_printed_msg:
                print('\033[31m' + self.latest_msg + '\033[0m', file=sys.stderr, flush=True)
                self.latest_printed_msg = self.latest_msg
        return max(dt, 0)

    def _print_worker(self):
        """Background thread body: sleeps until the throttle opens, then prints once."""
        while True:
            with self.lock:
                dt = self._try_print_else_delay()
                if dt == 0:
                    break
            # Sleep outside the lock so other threads can update latest_msg.
            time.sleep(dt)
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import time
|
|
2
|
+
|
|
3
|
+
from sinter._decoding import Sampler, CompiledSampler
|
|
4
|
+
from sinter._data import Task, AnonTaskStats
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class RampThrottledSampler(Sampler):
    """Wraps a sampler to adjust requested shots to hit a target time.

    This sampler will initially only take 1 shot per call. If the time taken
    significantly undershoots the target time, the maximum number of shots per
    call is increased by a constant factor. If it exceeds the target time, the
    maximum is reduced by a constant factor. The result is that the sampler
    "ramps up" how many shots it does per call until it takes roughly the target
    time, and then dynamically adapts to stay near it.
    """

    def __init__(self, sub_sampler: Sampler, target_batch_seconds: float, max_batch_shots: int):
        """
        Args:
            sub_sampler: The sampler whose compiled samplers get throttled.
            target_batch_seconds: Desired wall-clock duration of one sample call.
            max_batch_shots: Hard upper bound on shots per sample call.
        """
        self.sub_sampler = sub_sampler
        self.target_batch_seconds = target_batch_seconds
        self.max_batch_shots = max_batch_shots

    def __str__(self) -> str:
        # Bug fix: previously returned 'CompiledRampThrottledSampler(...)',
        # which is the compiled wrapper's name, not this class's name.
        return f'RampThrottledSampler({self.sub_sampler})'

    def compiled_sampler_for_task(self, task: Task) -> CompiledSampler:
        """Compiles the task, wrapping the result in throttling unless it throttles itself."""
        compiled_sub_sampler = self.sub_sampler.compiled_sampler_for_task(task)
        if compiled_sub_sampler.handles_throttling():
            return compiled_sub_sampler

        return CompiledRampThrottledSampler(
            sub_sampler=compiled_sub_sampler,
            target_batch_seconds=self.target_batch_seconds,
            max_batch_shots=self.max_batch_shots,
        )
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class CompiledRampThrottledSampler(CompiledSampler):
    """Compiled sampler that adaptively sizes batches to hit a target duration.

    Starts at 1 shot per call and doubles the batch size (up to 4 times per
    call, capped at `max_batch_shots`) while calls run well under the target
    time; halves it when a call runs well over.
    """

    def __init__(self, sub_sampler: CompiledSampler, target_batch_seconds: float, max_batch_shots: int):
        """
        Args:
            sub_sampler: The compiled sampler whose calls get throttled.
            target_batch_seconds: Desired wall-clock duration of one sample call.
            max_batch_shots: Hard upper bound on shots per sample call.
        """
        self.sub_sampler = sub_sampler
        self.target_batch_seconds = target_batch_seconds
        # Current batch size; ramps up/down based on observed call durations.
        self.batch_shots = 1
        self.max_batch_shots = max_batch_shots

    def __str__(self) -> str:
        return f'CompiledRampThrottledSampler({self.sub_sampler})'

    def sample(self, max_shots: int) -> AnonTaskStats:
        """Samples up to min(max_shots, current batch size) shots and adapts the batch size."""
        t0 = time.monotonic()
        actual_shots = min(max_shots, self.batch_shots)
        result = self.sub_sampler.sample(actual_shots)
        dt = time.monotonic() - t0

        # Rebalance number of shots.
        if self.batch_shots > 1 and dt > self.target_batch_seconds * 1.3:
            # Overshot the target noticeably: back off by half.
            self.batch_shots //= 2
        if result.shots * 2 >= actual_shots:
            # The sub-sampler mostly completed the request; consider ramping up.
            # Double at most 4 times, stopping at the cap or once the
            # (extrapolated) duration would reach ~30% of the target.
            for _ in range(4):
                if self.batch_shots * 2 > self.max_batch_shots:
                    break
                if dt > self.target_batch_seconds * 0.3:
                    break
                self.batch_shots *= 2
                dt *= 2  # Assume duration scales linearly with batch size.

        return result
|