sinter 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sinter has been flagged as potentially problematic; click here for more details.
- sinter/__init__.py +47 -0
- sinter/_collection/__init__.py +10 -0
- sinter/_collection/_collection.py +480 -0
- sinter/_collection/_collection_manager.py +581 -0
- sinter/_collection/_collection_manager_test.py +287 -0
- sinter/_collection/_collection_test.py +317 -0
- sinter/_collection/_collection_worker_loop.py +35 -0
- sinter/_collection/_collection_worker_state.py +259 -0
- sinter/_collection/_collection_worker_test.py +222 -0
- sinter/_collection/_mux_sampler.py +56 -0
- sinter/_collection/_printer.py +65 -0
- sinter/_collection/_sampler_ramp_throttled.py +66 -0
- sinter/_collection/_sampler_ramp_throttled_test.py +144 -0
- sinter/_command/__init__.py +0 -0
- sinter/_command/_main.py +39 -0
- sinter/_command/_main_collect.py +350 -0
- sinter/_command/_main_collect_test.py +482 -0
- sinter/_command/_main_combine.py +84 -0
- sinter/_command/_main_combine_test.py +153 -0
- sinter/_command/_main_plot.py +817 -0
- sinter/_command/_main_plot_test.py +445 -0
- sinter/_command/_main_predict.py +75 -0
- sinter/_command/_main_predict_test.py +36 -0
- sinter/_data/__init__.py +20 -0
- sinter/_data/_anon_task_stats.py +89 -0
- sinter/_data/_anon_task_stats_test.py +35 -0
- sinter/_data/_collection_options.py +106 -0
- sinter/_data/_collection_options_test.py +24 -0
- sinter/_data/_csv_out.py +74 -0
- sinter/_data/_existing_data.py +173 -0
- sinter/_data/_existing_data_test.py +41 -0
- sinter/_data/_task.py +311 -0
- sinter/_data/_task_stats.py +244 -0
- sinter/_data/_task_stats_test.py +140 -0
- sinter/_data/_task_test.py +38 -0
- sinter/_decoding/__init__.py +16 -0
- sinter/_decoding/_decoding.py +419 -0
- sinter/_decoding/_decoding_all_built_in_decoders.py +25 -0
- sinter/_decoding/_decoding_decoder_class.py +161 -0
- sinter/_decoding/_decoding_fusion_blossom.py +193 -0
- sinter/_decoding/_decoding_mwpf.py +302 -0
- sinter/_decoding/_decoding_pymatching.py +81 -0
- sinter/_decoding/_decoding_test.py +480 -0
- sinter/_decoding/_decoding_vacuous.py +38 -0
- sinter/_decoding/_perfectionist_sampler.py +38 -0
- sinter/_decoding/_sampler.py +72 -0
- sinter/_decoding/_stim_then_decode_sampler.py +222 -0
- sinter/_decoding/_stim_then_decode_sampler_test.py +192 -0
- sinter/_plotting.py +619 -0
- sinter/_plotting_test.py +108 -0
- sinter/_predict.py +381 -0
- sinter/_predict_test.py +227 -0
- sinter/_probability_util.py +519 -0
- sinter/_probability_util_test.py +281 -0
- sinter-1.15.0.data/data/README.md +332 -0
- sinter-1.15.0.data/data/readme_example_plot.png +0 -0
- sinter-1.15.0.data/data/requirements.txt +4 -0
- sinter-1.15.0.dist-info/METADATA +354 -0
- sinter-1.15.0.dist-info/RECORD +62 -0
- sinter-1.15.0.dist-info/WHEEL +5 -0
- sinter-1.15.0.dist-info/entry_points.txt +2 -0
- sinter-1.15.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,581 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import contextlib
|
|
3
|
+
import math
|
|
4
|
+
import multiprocessing
|
|
5
|
+
import os
|
|
6
|
+
import pathlib
|
|
7
|
+
import queue
|
|
8
|
+
import tempfile
|
|
9
|
+
import threading
|
|
10
|
+
from typing import Any, Optional, List, Dict, Iterable, Callable, Tuple
|
|
11
|
+
from typing import Union
|
|
12
|
+
from typing import cast
|
|
13
|
+
|
|
14
|
+
from sinter._collection._collection_worker_loop import collection_worker_loop
|
|
15
|
+
from sinter._collection._mux_sampler import MuxSampler
|
|
16
|
+
from sinter._collection._sampler_ramp_throttled import RampThrottledSampler
|
|
17
|
+
from sinter._data import CollectionOptions, Task, AnonTaskStats, TaskStats
|
|
18
|
+
from sinter._decoding import Sampler, Decoder
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class _ManagedWorkerState:
|
|
22
|
+
def __init__(self, worker_id: int, *, cpu_pin: Optional[int] = None):
|
|
23
|
+
self.worker_id: int = worker_id
|
|
24
|
+
self.process: Union[multiprocessing.Process, threading.Thread, None] = None
|
|
25
|
+
self.input_queue: Optional[multiprocessing.Queue[Tuple[str, Any]]] = None
|
|
26
|
+
self.assigned_work_key: Any = None
|
|
27
|
+
self.asked_to_drop_shots: int = 0
|
|
28
|
+
self.cpu_pin = cpu_pin
|
|
29
|
+
|
|
30
|
+
# Shots transfer into this field when manager sends shot requests to workers.
|
|
31
|
+
# Shots transfer out of this field when clients flush results or respond to work return requests.
|
|
32
|
+
self.assigned_shots: int = 0
|
|
33
|
+
|
|
34
|
+
def send_message(self, message: Any):
|
|
35
|
+
self.input_queue.put(message)
|
|
36
|
+
|
|
37
|
+
def ask_to_return_all_shots(self):
|
|
38
|
+
if self.asked_to_drop_shots == 0 and self.assigned_shots > 0:
|
|
39
|
+
self.send_message((
|
|
40
|
+
'return_shots',
|
|
41
|
+
(
|
|
42
|
+
self.assigned_work_key,
|
|
43
|
+
self.assigned_shots,
|
|
44
|
+
),
|
|
45
|
+
))
|
|
46
|
+
self.asked_to_drop_shots = self.assigned_shots
|
|
47
|
+
|
|
48
|
+
def has_returned_all_shots(self) -> bool:
|
|
49
|
+
return self.assigned_shots == 0 and self.asked_to_drop_shots == 0
|
|
50
|
+
|
|
51
|
+
def is_available_to_reassign(self) -> bool:
|
|
52
|
+
return self.assigned_work_key is None
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class _ManagedTaskState:
    """Mutable bookkeeping for one task while the manager is collecting it."""

    def __init__(self, *, partial_task: Task, strong_id: str, shots_left: int, errors_left: int):
        # Identity of the task being collected.
        self.partial_task = partial_task
        self.strong_id = strong_id
        # Remaining collection budget for this task.
        self.shots_left = shots_left
        self.errors_left = errors_left
        # Every shot starts out unassigned; workers are attached later.
        self.shots_unassigned = shots_left
        self.workers_assigned: list[int] = []
        # Count of outstanding 'return_shots' requests sent to workers.
        self.shot_return_requests = 0
        # Workers flush early once this many soft errors accumulate.
        self.assigned_soft_error_flush_threshold: int = errors_left

    def is_completed(self) -> bool:
        """True once either the shot budget or the error budget is exhausted."""
        out_of_shots = self.shots_left <= 0
        out_of_errors = self.errors_left <= 0
        return out_of_shots or out_of_errors
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class CollectionManager:
    """Coordinates worker processes/threads that sample and decode tasks.

    Hands shots out to workers, rebalances work between them, accumulates
    returned statistics, and reports progress through a callback.
    """

    def __init__(
            self,
            *,
            existing_data: Dict[Any, TaskStats],
            collection_options: CollectionOptions,
            custom_decoders: dict[str, Union[Decoder, Sampler]],
            num_workers: int,
            worker_flush_period: float,
            tasks: Iterable[Task],
            progress_callback: Callable[[Optional[TaskStats]], None],
            allowed_cpu_affinity_ids: Optional[Iterable[int]],
            count_observable_error_combos: bool = False,
            count_detection_events: bool = False,
            custom_error_count_key: Optional[str] = None,
            use_threads_for_debugging: bool = False,
    ):
        """
        Args:
            existing_data: Previously collected stats keyed by task strong id;
                subtracted from each task's budget when task ids are computed.
            collection_options: Global limits combined with each task's own
                collection options.
            custom_decoders: Extra decoders/samplers forwarded to the workers.
            num_workers: Number of workers to manage.
            worker_flush_period: Seconds between forced result flushes in the
                worker loop.
            tasks: Tasks to collect; their strong ids are computed later by
                the workers.
            progress_callback: Invoked with each new TaskStats, or None for a
                bookkeeping-only progress tick.
            allowed_cpu_affinity_ids: CPU ids workers may be pinned to, or
                None to allow every CPU.
            count_observable_error_combos: Forwarded to the sampler.
            count_detection_events: Forwarded to the sampler.
            custom_error_count_key: When set, this custom-counts key (instead
                of the errors field) is used for error budgeting.
            use_threads_for_debugging: Run workers as threads instead of
                processes (easier to debug; no true parallelism).
        """
        assert isinstance(custom_decoders, dict)
        self.existing_data = existing_data
        self.num_workers: int = num_workers
        self.custom_decoders = custom_decoders
        self.worker_flush_period: float = worker_flush_period
        self.progress_callback = progress_callback
        self.collection_options = collection_options
        self.partial_tasks: list[Task] = list(tasks)
        self.task_strong_ids: List[Optional[str]] = [None] * len(self.partial_tasks)
        self.allowed_cpu_affinity_ids = None if allowed_cpu_affinity_ids is None else sorted(set(allowed_cpu_affinity_ids))
        self.count_observable_error_combos = count_observable_error_combos
        self.count_detection_events = count_detection_events
        self.custom_error_count_key = custom_error_count_key
        self.use_threads_for_debugging = use_threads_for_debugging

        # Annotation kept as a string: non-simple-target annotations are
        # evaluated at runtime (PEP 526) and multiprocessing.SimpleQueue is a
        # context factory method, so subscripting it could raise.
        self.shared_worker_output_queue: 'Optional[multiprocessing.SimpleQueue[Tuple[str, int, Any]]]' = None
        self.task_states: Dict[Any, _ManagedTaskState] = {}
        self.started: bool = False
        self.total_collected = {k: v.to_anon_stats() for k, v in existing_data.items()}

        # Decide which CPUs workers may be pinned to. os.cpu_count() can
        # return None on platforms where it is undeterminable; treat that as
        # "no CPUs known" (no pinning) instead of raising a TypeError.
        if self.allowed_cpu_affinity_ids is None:
            cpus = range(os.cpu_count() or 0)
        else:
            num_cpus = os.cpu_count() or 0
            cpus = [e for e in self.allowed_cpu_affinity_ids if e < num_cpus]
        self.worker_states: List[_ManagedWorkerState] = []
        for index in range(num_workers):
            # Round-robin workers over the available CPUs.
            cpu_pin = None if len(cpus) == 0 else cpus[index % len(cpus)]
            self.worker_states.append(_ManagedWorkerState(index, cpu_pin=cpu_pin))
        self.tmp_dir: Optional[pathlib.Path] = None
|
|
117
|
+
|
|
118
|
+
def __enter__(self):
|
|
119
|
+
self.exit_stack = contextlib.ExitStack().__enter__()
|
|
120
|
+
self.tmp_dir = pathlib.Path(self.exit_stack.enter_context(tempfile.TemporaryDirectory()))
|
|
121
|
+
return self
|
|
122
|
+
|
|
123
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
124
|
+
self.hard_stop()
|
|
125
|
+
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
|
|
126
|
+
self.exit_stack = None
|
|
127
|
+
self.tmp_dir = None
|
|
128
|
+
|
|
129
|
+
    def start_workers(self, *, actually_start_worker_processes: bool = True):
        """Create (and normally start) one worker per `self.num_workers`.

        Builds the shared sampler stack (a MuxSampler wrapped in a
        RampThrottledSampler), then creates each worker with its own input
        queue and the shared output queue.

        Args:
            actually_start_worker_processes: When False, process/thread
                objects are created but not started (useful for tests).
        """
        assert not self.started

        # Use max_batch_size from collection_options if provided, otherwise default to 1024 as large
        # batch sizes can lead to thrashing
        max_batch_shots = self.collection_options.max_batch_size or 1024

        sampler = RampThrottledSampler(
            sub_sampler=MuxSampler(
                custom_decoders=self.custom_decoders,
                count_observable_error_combos=self.count_observable_error_combos,
                count_detection_events=self.count_detection_events,
                tmp_dir=self.tmp_dir,
            ),
            target_batch_seconds=1,
            max_batch_shots=max_batch_shots,
        )

        self.started = True
        current_method = multiprocessing.get_start_method()
        try:
            # To ensure the child processes do not accidentally share ANY state
            # related to random number generation, we use 'spawn' instead of 'fork'.
            multiprocessing.set_start_method('spawn', force=True)
            # Create queues after setting start method to work around a deadlock
            # bug that occurs otherwise.
            self.shared_worker_output_queue = multiprocessing.SimpleQueue()

            for worker_id in range(self.num_workers):
                worker_state = self.worker_states[worker_id]
                worker_state.input_queue = multiprocessing.Queue()
                # Don't let a lingering feeder thread block interpreter exit.
                worker_state.input_queue.cancel_join_thread()
                worker_state.assigned_work_key = None
                args = (
                    self.worker_flush_period,
                    worker_id,
                    sampler,
                    worker_state.input_queue,
                    self.shared_worker_output_queue,
                    worker_state.cpu_pin,
                    self.custom_error_count_key,
                )
                if self.use_threads_for_debugging:
                    worker_state.process = threading.Thread(
                        target=collection_worker_loop,
                        args=args,
                    )
                else:
                    worker_state.process = multiprocessing.Process(
                        target=collection_worker_loop,
                        args=args,
                    )

                if actually_start_worker_processes:
                    worker_state.process.start()
        finally:
            # Restore the caller's start method; 'spawn' was only needed here.
            multiprocessing.set_start_method(current_method, force=True)
|
|
186
|
+
|
|
187
|
+
def start_distributing_work(self):
|
|
188
|
+
self._compute_task_ids()
|
|
189
|
+
self._distribute_work()
|
|
190
|
+
|
|
191
|
+
    def _compute_task_ids(self):
        """Have workers compute each task's strong id, then build task states.

        Phase 1: farm `compute_strong_id` requests out to idle workers and
        collect the answers. Phase 2: combine collection options per task,
        subtract already-existing data from the budgets, and create a
        `_ManagedTaskState` for every task that still needs shots.

        Raises:
            ValueError: If the same task (same strong id) is given twice, or
                a worker reports an exception.
            NotImplementedError: On an unrecognized worker message type.
        """
        idle_worker_ids = list(range(self.num_workers))
        unknown_task_ids = list(range(len(self.partial_tasks)))
        worker_to_task_map = {}
        # Loop until every task id is computed and every worker has answered.
        while worker_to_task_map or unknown_task_ids:
            # Hand out as many id-computation jobs as there are idle workers.
            while idle_worker_ids and unknown_task_ids:
                worker_id = idle_worker_ids.pop()
                unknown_task_id = unknown_task_ids.pop()
                worker_to_task_map[worker_id] = unknown_task_id
                self.worker_states[worker_id].send_message(('compute_strong_id', self.partial_tasks[unknown_task_id]))

            try:
                # NOTE(review): SimpleQueue.get() blocks and does not raise
                # queue.Empty; the except clause below looks defensive only.
                message = self.shared_worker_output_queue.get()
                message_type, worker_id, message_body = message
                if message_type == 'computed_strong_id':
                    assert worker_id in worker_to_task_map
                    assert isinstance(message_body, str)
                    self.task_strong_ids[worker_to_task_map.pop(worker_id)] = message_body
                    idle_worker_ids.append(worker_id)
                elif message_type == 'stopped_due_to_exception':
                    cur_task, cur_shots_left, unflushed_work_done, traceback, ex = message_body
                    raise ValueError(f'Worker failed: traceback={traceback}') from ex
                else:
                    raise NotImplementedError(f'{message_type=}')
                # Bookkeeping-only progress tick (no new stats).
                self.progress_callback(None)
            except queue.Empty:
                pass

        assert len(idle_worker_ids) == self.num_workers
        seen = set()
        for k in range(len(self.partial_tasks)):
            # Task-specific options take precedence over the global options.
            options = self.partial_tasks[k].collection_options.combine(self.collection_options)
            key: str = self.task_strong_ids[k]
            if key in seen:
                raise ValueError(f'Same task given twice: {self.partial_tasks[k]!r}')
            seen.add(key)

            # An unset max_errors defaults to the shot budget; either way the
            # error budget never exceeds the shot budget.
            shots_left = options.max_shots
            errors_left = options.max_errors
            if errors_left is None:
                errors_left = shots_left
            errors_left = min(errors_left, shots_left)
            if key in self.existing_data:
                # Credit previously collected data against the budgets.
                val = self.existing_data[key]
                shots_left -= val.shots
                if self.custom_error_count_key is None:
                    errors_left -= val.errors
                else:
                    errors_left -= val.custom_counts[self.custom_error_count_key]
            if shots_left <= 0:
                continue
            self.task_states[key] = _ManagedTaskState(
                partial_task=self.partial_tasks[k],
                strong_id=key,
                shots_left=shots_left,
                errors_left=errors_left,
            )
            # The error budget may already be exhausted even with shots left.
            if self.task_states[key].is_completed():
                del self.task_states[key]
|
|
250
|
+
|
|
251
|
+
    def hard_stop(self):
        """Forcibly stop all workers and clear all collection state.

        Threads are asked to stop via a message (they cannot be killed);
        processes are SIGKILLed. Safe to call when not started (no-op), and
        idempotent because `self.started` is cleared.
        """
        if not self.started:
            return

        # Snapshot the handles before clearing state so they can be joined.
        removed_workers = [state.process for state in self.worker_states]
        for state in self.worker_states:
            if isinstance(state.process, threading.Thread):
                # Threads can't be killed; ask them to exit cooperatively.
                state.send_message('stop')
            state.process = None
            state.assigned_work_key = None
            state.input_queue = None
        self.shared_worker_output_queue = None
        self.started = False
        self.task_states.clear()

        # SIGKILL everything.
        for w in removed_workers:
            if isinstance(w, multiprocessing.Process):
                w.kill()
        # Wait for them to be done.
        for w in removed_workers:
            w.join()
|
|
273
|
+
|
|
274
|
+
    def _handle_task_progress(self, task_id: Any):
        """React to a change in a task's budgets or its workers' holdings.

        If the task's budget is exhausted: once all of its workers have
        handed back their shots, forget the task and redistribute the freed
        workers; otherwise ask the stragglers to return their shots. If the
        task is not yet complete, rebalance work instead.
        """
        task_state = self.task_states[task_id]
        if task_state.is_completed():
            workers_ready = all(self.worker_states[worker_id].has_returned_all_shots() for worker_id in task_state.workers_assigned)
            if workers_ready:
                # Task is fully completed and can be forgotten entirely. Re-assign the workers.
                del self.task_states[task_id]
                for worker_id in task_state.workers_assigned:
                    w = self.worker_states[worker_id]
                    assert w.assigned_shots <= 0
                    assert w.asked_to_drop_shots == 0
                    w.assigned_work_key = None
                self._distribute_work()
            else:
                # Task is sufficiently sampled, but some workers are still running.
                for worker_id in task_state.workers_assigned:
                    self.worker_states[worker_id].ask_to_return_all_shots()
                # Bookkeeping-only progress tick (no new stats).
                self.progress_callback(None)
        else:
            self._distribute_unassigned_workers_to_jobs()
            self._distribute_work_within_a_job(task_state)
|
|
295
|
+
|
|
296
|
+
def state_summary(self) -> str:
|
|
297
|
+
lines = []
|
|
298
|
+
for worker_id, worker in enumerate(self.worker_states):
|
|
299
|
+
lines.append(f'worker {worker_id}:'
|
|
300
|
+
f' asked_to_drop_shots={worker.asked_to_drop_shots}'
|
|
301
|
+
f' assigned_shots={worker.assigned_shots}'
|
|
302
|
+
f' assigned_work_key={worker.assigned_work_key}')
|
|
303
|
+
for task in self.task_states.values():
|
|
304
|
+
lines.append(f'task {task.strong_id=}:\n'
|
|
305
|
+
f' workers_assigned={task.workers_assigned}\n'
|
|
306
|
+
f' shot_return_requests={task.shot_return_requests}\n'
|
|
307
|
+
f' shots_left={task.shots_left}\n'
|
|
308
|
+
f' errors_left={task.errors_left}\n'
|
|
309
|
+
f' shots_unassigned={task.shots_unassigned}')
|
|
310
|
+
return '\n' + '\n'.join(lines) + '\n'
|
|
311
|
+
|
|
312
|
+
    def process_message(self) -> bool:
        """Pop and handle one message from the shared worker output queue.

        Returns:
            True if a message was handled, False if the queue was empty.

        Raises:
            RuntimeError: If a worker reported an exception.
            NotImplementedError: On an unrecognized message type.
        """
        try:
            # NOTE(review): SimpleQueue.get() blocks and does not raise
            # queue.Empty; the except clause appears defensive only.
            message = self.shared_worker_output_queue.get()
        except queue.Empty:
            return False

        message_type, worker_id, message_body = message
        worker_state = self.worker_states[worker_id]

        if message_type == 'flushed_results':
            # Worker flushed sampled statistics for its current task.
            task_strong_id, anon_stat = message_body
            assert isinstance(anon_stat, AnonTaskStats)
            assert worker_state.assigned_work_key == task_strong_id
            task_state = self.task_states[task_strong_id]

            worker_state.assigned_shots -= anon_stat.shots
            task_state.shots_left -= anon_stat.shots
            if worker_state.assigned_shots < 0:
                # Worker over-achieved. Correct the imbalance by giving them the shots.
                extra_shots = abs(worker_state.assigned_shots)
                worker_state.assigned_shots += extra_shots
                task_state.shots_unassigned -= extra_shots
                worker_state.send_message((
                    'accept_shots',
                    (task_state.strong_id, extra_shots),
                ))

            # Error budget counts either plain errors or a custom-counts key.
            if self.custom_error_count_key is None:
                task_state.errors_left -= anon_stat.errors
            else:
                task_state.errors_left -= anon_stat.custom_counts[self.custom_error_count_key]

            stat = TaskStats(
                strong_id=task_state.strong_id,
                decoder=task_state.partial_task.decoder,
                json_metadata=task_state.partial_task.json_metadata,
                shots=anon_stat.shots,
                discards=anon_stat.discards,
                seconds=anon_stat.seconds,
                errors=anon_stat.errors,
                custom_counts=anon_stat.custom_counts,
            )

            # May delete the task state and redistribute workers, so build
            # `stat` from task_state before calling it.
            self._handle_task_progress(task_strong_id)

            if stat.strong_id not in self.total_collected:
                self.total_collected[stat.strong_id] = AnonTaskStats()
            self.total_collected[stat.strong_id] += stat.to_anon_stats()
            self.progress_callback(stat)

        elif message_type == 'changed_job':
            # Acknowledgement only; nothing to update.
            pass

        elif message_type == 'accepted_shots':
            # Acknowledgement only; nothing to update.
            pass

        elif message_type == 'returned_shots':
            # Worker handed back shots we asked it to drop.
            task_key, shots_returned = message_body
            assert isinstance(shots_returned, int)
            assert shots_returned >= 0
            assert worker_state.assigned_work_key == task_key
            # NOTE(review): _ManagedWorkerState.__init__ does not define
            # asked_to_drop_errors; this read relies on asked_to_drop_shots
            # being truthy (short-circuit) until the attribute is first
            # assigned below — confirm.
            assert worker_state.asked_to_drop_shots or worker_state.asked_to_drop_errors
            task_state = self.task_states[task_key]
            task_state.shot_return_requests -= 1
            worker_state.asked_to_drop_shots = 0
            worker_state.asked_to_drop_errors = 0
            task_state.shots_unassigned += shots_returned
            worker_state.assigned_shots -= shots_returned
            assert worker_state.assigned_shots >= 0
            self._handle_task_progress(task_key)

        elif message_type == 'stopped_due_to_exception':
            cur_task, cur_shots_left, unflushed_work_done, traceback, ex = message_body
            raise RuntimeError(f'Worker failed: traceback={traceback}') from ex

        else:
            raise NotImplementedError(f'{message_type=}')

        return True
|
|
391
|
+
|
|
392
|
+
def run_until_done(self) -> bool:
|
|
393
|
+
try:
|
|
394
|
+
while self.task_states:
|
|
395
|
+
self.process_message()
|
|
396
|
+
return True
|
|
397
|
+
|
|
398
|
+
except KeyboardInterrupt:
|
|
399
|
+
return False
|
|
400
|
+
|
|
401
|
+
finally:
|
|
402
|
+
self.hard_stop()
|
|
403
|
+
|
|
404
|
+
    def _distribute_unassigned_workers_to_jobs(self):
        """Assign every idle worker to an unfinished task.

        Tasks are grouped by how many workers they already have, and idle
        workers are repeatedly given to a task from the least-staffed group,
        so worker counts stay as even as possible across tasks.
        """
        idle_workers = [
            w
            for w in range(self.num_workers)[::-1]
            if self.worker_states[w].is_available_to_reassign()
        ]
        if not idle_workers or not self.started:
            return

        # groups[n] = tasks that currently have n workers assigned.
        groups = collections.defaultdict(list)
        for work_state in self.task_states.values():
            if not work_state.is_completed():
                groups[len(work_state.workers_assigned)].append(work_state)
        # Reverse each group so pops come off in original iteration order.
        for k in groups.keys():
            groups[k] = groups[k][::-1]
        if not groups:
            return
        min_assigned = min(groups.keys(), default=0)

        # Distribute workers to unfinished jobs with the fewest workers.
        while idle_workers:
            # Move the task up one staffing level before assigning, so the
            # next iteration considers the now-better-staffed distribution.
            task_state: _ManagedTaskState = groups[min_assigned].pop()
            groups[min_assigned + 1].append(task_state)
            if not groups[min_assigned]:
                min_assigned += 1

            worker_id = idle_workers.pop()
            task_state.workers_assigned.append(worker_id)
            worker_state = self.worker_states[worker_id]
            worker_state.assigned_work_key = task_state.strong_id
            worker_state.send_message((
                'change_job',
                (task_state.partial_task, CollectionOptions(max_errors=task_state.errors_left), task_state.assigned_soft_error_flush_threshold),
            ))
|
|
438
|
+
|
|
439
|
+
    def _distribute_unassigned_work_to_workers_within_a_job(self, task_state: _ManagedTaskState):
        """Hand a task's unassigned shots to its under-loaded workers.

        Each worker is topped up toward an even split (ceiling of
        shots_left / worker count), limited by the unassigned pool.
        """
        if not self.started or not task_state.workers_assigned or task_state.shots_left <= 0:
            return

        num_task_workers = len(task_state.workers_assigned)
        # Ceiling division: every worker's fair share of the remaining shots.
        expected_shots_per_worker = (task_state.shots_left + num_task_workers - 1) // num_task_workers

        # Give unassigned shots to idle workers.
        for worker_id in sorted(task_state.workers_assigned, key=lambda wid: self.worker_states[wid].assigned_shots):
            worker_state = self.worker_states[worker_id]
            if worker_state.assigned_shots < expected_shots_per_worker:
                shots_to_assign = min(expected_shots_per_worker - worker_state.assigned_shots,
                                      task_state.shots_unassigned)
                if shots_to_assign > 0:
                    task_state.shots_unassigned -= shots_to_assign
                    worker_state.assigned_shots += shots_to_assign
                    worker_state.send_message((
                        'accept_shots',
                        (task_state.strong_id, shots_to_assign),
                    ))
|
|
459
|
+
|
|
460
|
+
    def status_message(self) -> str:
        """Build a multi-line progress table for display to the user.

        While strong ids are still being computed, returns a short
        "Analyzed x/y tasks..." line. Afterwards, returns one row per active
        task with worker count, decoder, an ETA estimate, remaining budgets,
        and metadata. At most 50 unstaffed tasks are shown inline; the rest
        are summarized in a trailing "... (n more tasks) ..." line.
        """
        num_known_tasks_ids = sum(e is not None for e in self.task_strong_ids)
        if num_known_tasks_ids < len(self.task_strong_ids):
            return f"Analyzed {num_known_tasks_ids}/{len(self.task_strong_ids)} tasks..."
        max_errors = self.collection_options.max_errors
        max_shots = self.collection_options.max_shots

        tasks_left = 0
        lines = []
        skipped_lines = []
        for k, strong_id in enumerate(self.task_strong_ids):
            if strong_id not in self.task_states:
                continue
            c = self.total_collected.get(strong_id, AnonTaskStats())
            tasks_left += 1
            w = len(self.task_states[strong_id].workers_assigned)
            # ETA in seconds, extrapolated from throughput so far; the
            # tighter of the shots-based and errors-based estimates wins.
            dt = None
            if max_shots is not None and c.shots:
                dt = (max_shots - c.shots) * c.seconds / c.shots
            c_errors = c.custom_counts[self.custom_error_count_key] if self.custom_error_count_key is not None else c.errors
            if max_errors is not None and c_errors and c.seconds:
                dt2 = (max_errors - c_errors) * c.seconds / c_errors
                if dt is None:
                    dt = dt2
                else:
                    dt = min(dt, dt2)
            if dt is not None:
                # Convert to minutes.
                dt /= 60
            if dt is not None and w > 0:
                # Assume the assigned workers split the remaining work evenly.
                dt /= w
            line = [
                f'{w}',
                self.partial_tasks[k].decoder,
                ("?" if dt is None or dt == 0 else "[draining]" if dt <= 0 else "<1m" if dt < 1 else str(round(dt)) + 'm') + ('·∞' if w == 0 else ''),
                f'{max_shots - c.shots}' if max_shots is not None else f'{c.shots}',
                f'{max_errors - c_errors}' if max_errors is not None else f'{c_errors}',
                ",".join(
                    [f"{k}={v}" for k, v in self.partial_tasks[k].json_metadata.items()]
                    if isinstance(self.partial_tasks[k].json_metadata, dict)
                    else str(self.partial_tasks[k].json_metadata)
                )
            ]
            # Tasks with no workers are shown only if there's room.
            if w == 0:
                skipped_lines.append(line)
            else:
                lines.append(line)
        if len(lines) < 50 and skipped_lines:
            missing_lines = 50 - len(lines)
            lines.extend(skipped_lines[:missing_lines])
            skipped_lines = skipped_lines[missing_lines:]

        if lines:
            lines.insert(0, [
                'workers',
                'decoder',
                'eta',
                'shots_left' if max_shots is not None else 'shots_taken',
                'errors_left' if max_errors is not None else 'errors_seen',
                'json_metadata'])
            # Right-justify numeric columns, left-justify the metadata column.
            justs = cast(list[Callable[[str, int], str]], [str.rjust, str.rjust, str.rjust, str.rjust, str.rjust, str.ljust])
            cols = len(lines[0])
            lengths = [
                max(len(lines[row][col]) for row in range(len(lines)))
                for col in range(cols)
            ]
            lines = [
                " " + " ".join(justs[col](row[col], lengths[col]) for col in range(cols))
                for row in lines
            ]
        if skipped_lines:
            lines.append(' ... (' + str(len(skipped_lines)) + ' more tasks) ...')
        return f'{tasks_left} tasks left:\n' + '\n'.join(lines)
|
|
532
|
+
|
|
533
|
+
def _update_soft_error_threshold_for_a_job(self, task_state: _ManagedTaskState):
|
|
534
|
+
if task_state.errors_left <= len(task_state.workers_assigned):
|
|
535
|
+
desired_threshold = 1
|
|
536
|
+
elif task_state.errors_left <= task_state.assigned_soft_error_flush_threshold * self.num_workers:
|
|
537
|
+
desired_threshold = max(1, math.ceil(task_state.errors_left * 0.5 / self.num_workers))
|
|
538
|
+
else:
|
|
539
|
+
return
|
|
540
|
+
|
|
541
|
+
if task_state.assigned_soft_error_flush_threshold != desired_threshold:
|
|
542
|
+
task_state.assigned_soft_error_flush_threshold = desired_threshold
|
|
543
|
+
for wid in task_state.workers_assigned:
|
|
544
|
+
self.worker_states[wid].send_message(('set_soft_error_flush_threshold', desired_threshold))
|
|
545
|
+
|
|
546
|
+
    def _take_work_if_unsatisfied_workers_within_a_job(self, task_state: _ManagedTaskState):
        """Reclaim shots from overloaded workers when some workers are idle.

        Only runs when at least one assigned worker has zero shots; asks the
        most-loaded workers to return everything above their fair share so
        the surplus can be redistributed later.
        """
        if not self.started or not task_state.workers_assigned or task_state.shots_left <= 0:
            return

        # Nothing to do if every worker already has some shots.
        if all(self.worker_states[w].assigned_shots > 0 for w in task_state.workers_assigned):
            return

        w = len(task_state.workers_assigned)
        # Ceiling division: each worker's fair share of the remaining shots.
        expected_shots_per_worker = (task_state.shots_left + w - 1) // w

        # There are idle workers that couldn't be given any shots. Take shots from other workers.
        for worker_id in sorted(task_state.workers_assigned, key=lambda w: self.worker_states[w].assigned_shots, reverse=True):
            worker_state = self.worker_states[worker_id]
            # Skip workers with a pending return request or no surplus.
            if worker_state.asked_to_drop_shots or worker_state.assigned_shots <= expected_shots_per_worker:
                continue
            shots_to_take = worker_state.assigned_shots - expected_shots_per_worker
            assert shots_to_take > 0
            worker_state.asked_to_drop_shots = shots_to_take
            task_state.shot_return_requests += 1
            worker_state.send_message((
                'return_shots',
                (
                    task_state.strong_id,
                    shots_to_take,
                ),
            ))
|
|
572
|
+
|
|
573
|
+
def _distribute_work_within_a_job(self, t: _ManagedTaskState):
|
|
574
|
+
self._distribute_unassigned_work_to_workers_within_a_job(t)
|
|
575
|
+
self._take_work_if_unsatisfied_workers_within_a_job(t)
|
|
576
|
+
|
|
577
|
+
def _distribute_work(self):
|
|
578
|
+
self._distribute_unassigned_workers_to_jobs()
|
|
579
|
+
for w in self.task_states.values():
|
|
580
|
+
if not w.is_completed():
|
|
581
|
+
self._distribute_work_within_a_job(w)
|