sinter-1.15.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (62)
  1. sinter/__init__.py +47 -0
  2. sinter/_collection/__init__.py +10 -0
  3. sinter/_collection/_collection.py +480 -0
  4. sinter/_collection/_collection_manager.py +581 -0
  5. sinter/_collection/_collection_manager_test.py +287 -0
  6. sinter/_collection/_collection_test.py +317 -0
  7. sinter/_collection/_collection_worker_loop.py +35 -0
  8. sinter/_collection/_collection_worker_state.py +259 -0
  9. sinter/_collection/_collection_worker_test.py +222 -0
  10. sinter/_collection/_mux_sampler.py +56 -0
  11. sinter/_collection/_printer.py +65 -0
  12. sinter/_collection/_sampler_ramp_throttled.py +66 -0
  13. sinter/_collection/_sampler_ramp_throttled_test.py +144 -0
  14. sinter/_command/__init__.py +0 -0
  15. sinter/_command/_main.py +39 -0
  16. sinter/_command/_main_collect.py +350 -0
  17. sinter/_command/_main_collect_test.py +482 -0
  18. sinter/_command/_main_combine.py +84 -0
  19. sinter/_command/_main_combine_test.py +153 -0
  20. sinter/_command/_main_plot.py +817 -0
  21. sinter/_command/_main_plot_test.py +445 -0
  22. sinter/_command/_main_predict.py +75 -0
  23. sinter/_command/_main_predict_test.py +36 -0
  24. sinter/_data/__init__.py +20 -0
  25. sinter/_data/_anon_task_stats.py +89 -0
  26. sinter/_data/_anon_task_stats_test.py +35 -0
  27. sinter/_data/_collection_options.py +106 -0
  28. sinter/_data/_collection_options_test.py +24 -0
  29. sinter/_data/_csv_out.py +74 -0
  30. sinter/_data/_existing_data.py +173 -0
  31. sinter/_data/_existing_data_test.py +41 -0
  32. sinter/_data/_task.py +311 -0
  33. sinter/_data/_task_stats.py +244 -0
  34. sinter/_data/_task_stats_test.py +140 -0
  35. sinter/_data/_task_test.py +38 -0
  36. sinter/_decoding/__init__.py +16 -0
  37. sinter/_decoding/_decoding.py +419 -0
  38. sinter/_decoding/_decoding_all_built_in_decoders.py +25 -0
  39. sinter/_decoding/_decoding_decoder_class.py +161 -0
  40. sinter/_decoding/_decoding_fusion_blossom.py +193 -0
  41. sinter/_decoding/_decoding_mwpf.py +302 -0
  42. sinter/_decoding/_decoding_pymatching.py +81 -0
  43. sinter/_decoding/_decoding_test.py +480 -0
  44. sinter/_decoding/_decoding_vacuous.py +38 -0
  45. sinter/_decoding/_perfectionist_sampler.py +38 -0
  46. sinter/_decoding/_sampler.py +72 -0
  47. sinter/_decoding/_stim_then_decode_sampler.py +222 -0
  48. sinter/_decoding/_stim_then_decode_sampler_test.py +192 -0
  49. sinter/_plotting.py +619 -0
  50. sinter/_plotting_test.py +108 -0
  51. sinter/_predict.py +381 -0
  52. sinter/_predict_test.py +227 -0
  53. sinter/_probability_util.py +519 -0
  54. sinter/_probability_util_test.py +281 -0
  55. sinter-1.15.0.data/data/README.md +332 -0
  56. sinter-1.15.0.data/data/readme_example_plot.png +0 -0
  57. sinter-1.15.0.data/data/requirements.txt +4 -0
  58. sinter-1.15.0.dist-info/METADATA +354 -0
  59. sinter-1.15.0.dist-info/RECORD +62 -0
  60. sinter-1.15.0.dist-info/WHEEL +5 -0
  61. sinter-1.15.0.dist-info/entry_points.txt +2 -0
  62. sinter-1.15.0.dist-info/top_level.txt +1 -0
sinter/_collection/_collection_manager.py (new file)
@@ -0,0 +1,581 @@
import collections
import contextlib
import math
import multiprocessing
import os
import pathlib
import queue
import tempfile
import threading
from typing import Any, Optional, List, Dict, Iterable, Callable, Tuple
from typing import Union
from typing import cast

from sinter._collection._collection_worker_loop import collection_worker_loop
from sinter._collection._mux_sampler import MuxSampler
from sinter._collection._sampler_ramp_throttled import RampThrottledSampler
from sinter._data import CollectionOptions, Task, AnonTaskStats, TaskStats
from sinter._decoding import Sampler, Decoder


class _ManagedWorkerState:
    def __init__(self, worker_id: int, *, cpu_pin: Optional[int] = None):
        self.worker_id: int = worker_id
        self.process: Union[multiprocessing.Process, threading.Thread, None] = None
        self.input_queue: Optional[multiprocessing.Queue[Tuple[str, Any]]] = None
        self.assigned_work_key: Any = None
        self.asked_to_drop_shots: int = 0
        # Initialized here because the 'returned_shots' handler further down
        # reads and resets this field alongside asked_to_drop_shots.
        self.asked_to_drop_errors: int = 0
        self.cpu_pin = cpu_pin

        # Shots transfer into this field when the manager sends shot requests
        # to workers. Shots transfer out of it when workers flush results or
        # respond to shot return requests.
        self.assigned_shots: int = 0

    def send_message(self, message: Any):
        self.input_queue.put(message)

    def ask_to_return_all_shots(self):
        if self.asked_to_drop_shots == 0 and self.assigned_shots > 0:
            self.send_message((
                'return_shots',
                (
                    self.assigned_work_key,
                    self.assigned_shots,
                ),
            ))
            self.asked_to_drop_shots = self.assigned_shots

    def has_returned_all_shots(self) -> bool:
        return self.assigned_shots == 0 and self.asked_to_drop_shots == 0

    def is_available_to_reassign(self) -> bool:
        return self.assigned_work_key is None
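
The rest of the file is easier to follow with the message protocol in view. The summary below is reconstructed purely from the send_message calls and process_message branches in this file; the worker-side handling lives elsewhere (e.g. _collection_worker_state.py):

# Manager -> worker, via each worker's input_queue:
#   ('compute_strong_id', task)                        # request a task's strong id
#   ('change_job', (task, options, flush_threshold))   # reassign the worker
#   ('accept_shots', (task_key, num_shots))            # grow the worker's quota
#   ('return_shots', (task_key, num_shots))            # ask for quota back
#   ('set_soft_error_flush_threshold', num_errors)     # flush results sooner
#   'stop'                                             # only for debug threads
#
# Worker -> manager, via shared_worker_output_queue, as (type, worker_id, body):
#   'computed_strong_id', 'changed_job', 'accepted_shots', 'returned_shots',
#   'flushed_results', 'stopped_due_to_exception'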


class _ManagedTaskState:
    def __init__(self, *, partial_task: Task, strong_id: str, shots_left: int, errors_left: int):
        self.partial_task = partial_task
        self.strong_id = strong_id
        self.shots_left = shots_left
        self.errors_left = errors_left
        self.shots_unassigned = shots_left
        self.shot_return_requests = 0
        self.assigned_soft_error_flush_threshold: int = errors_left
        self.workers_assigned: list[int] = []

    def is_completed(self) -> bool:
        return self.shots_left <= 0 or self.errors_left <= 0


class CollectionManager:
    def __init__(
        self,
        *,
        existing_data: Dict[Any, TaskStats],
        collection_options: CollectionOptions,
        custom_decoders: dict[str, Union[Decoder, Sampler]],
        num_workers: int,
        worker_flush_period: float,
        tasks: Iterable[Task],
        progress_callback: Callable[[Optional[TaskStats]], None],
        allowed_cpu_affinity_ids: Optional[Iterable[int]],
        count_observable_error_combos: bool = False,
        count_detection_events: bool = False,
        custom_error_count_key: Optional[str] = None,
        use_threads_for_debugging: bool = False,
    ):
        assert isinstance(custom_decoders, dict)
        self.existing_data = existing_data
        self.num_workers: int = num_workers
        self.custom_decoders = custom_decoders
        self.worker_flush_period: float = worker_flush_period
        self.progress_callback = progress_callback
        self.collection_options = collection_options
        self.partial_tasks: list[Task] = list(tasks)
        self.task_strong_ids: List[Optional[str]] = [None] * len(self.partial_tasks)
        self.allowed_cpu_affinity_ids = None if allowed_cpu_affinity_ids is None else sorted(set(allowed_cpu_affinity_ids))
        self.count_observable_error_combos = count_observable_error_combos
        self.count_detection_events = count_detection_events
        self.custom_error_count_key = custom_error_count_key
        self.use_threads_for_debugging = use_threads_for_debugging

        self.shared_worker_output_queue: Optional[multiprocessing.SimpleQueue[Tuple[str, int, Any]]] = None
        self.task_states: Dict[Any, _ManagedTaskState] = {}
        self.started: bool = False
        self.total_collected = {k: v.to_anon_stats() for k, v in existing_data.items()}

        if self.allowed_cpu_affinity_ids is None:
            cpus = range(os.cpu_count())
        else:
            num_cpus = os.cpu_count()
            cpus = [e for e in self.allowed_cpu_affinity_ids if e < num_cpus]
        self.worker_states: List[_ManagedWorkerState] = []
        for index in range(num_workers):
            cpu_pin = None if len(cpus) == 0 else cpus[index % len(cpus)]
            self.worker_states.append(_ManagedWorkerState(index, cpu_pin=cpu_pin))
        self.tmp_dir: Optional[pathlib.Path] = None

    def __enter__(self):
        self.exit_stack = contextlib.ExitStack().__enter__()
        self.tmp_dir = pathlib.Path(self.exit_stack.enter_context(tempfile.TemporaryDirectory()))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.hard_stop()
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
        self.exit_stack = None
        self.tmp_dir = None

    def start_workers(self, *, actually_start_worker_processes: bool = True):
        assert not self.started

        # Use max_batch_size from collection_options if provided, otherwise
        # default to 1024, since large batch sizes can lead to thrashing.
        max_batch_shots = self.collection_options.max_batch_size or 1024

        sampler = RampThrottledSampler(
            sub_sampler=MuxSampler(
                custom_decoders=self.custom_decoders,
                count_observable_error_combos=self.count_observable_error_combos,
                count_detection_events=self.count_detection_events,
                tmp_dir=self.tmp_dir,
            ),
            target_batch_seconds=1,
            max_batch_shots=max_batch_shots,
        )

        self.started = True
        current_method = multiprocessing.get_start_method()
        try:
            # To ensure the child processes do not accidentally share ANY state
            # related to random number generation, use 'spawn' instead of 'fork'.
            multiprocessing.set_start_method('spawn', force=True)
            # Create queues after setting the start method, to work around a
            # deadlock bug that occurs otherwise.
            self.shared_worker_output_queue = multiprocessing.SimpleQueue()

            for worker_id in range(self.num_workers):
                worker_state = self.worker_states[worker_id]
                worker_state.input_queue = multiprocessing.Queue()
                worker_state.input_queue.cancel_join_thread()
                worker_state.assigned_work_key = None
                args = (
                    self.worker_flush_period,
                    worker_id,
                    sampler,
                    worker_state.input_queue,
                    self.shared_worker_output_queue,
                    worker_state.cpu_pin,
                    self.custom_error_count_key,
                )
                if self.use_threads_for_debugging:
                    worker_state.process = threading.Thread(
                        target=collection_worker_loop,
                        args=args,
                    )
                else:
                    worker_state.process = multiprocessing.Process(
                        target=collection_worker_loop,
                        args=args,
                    )

                if actually_start_worker_processes:
                    worker_state.process.start()
        finally:
            multiprocessing.set_start_method(current_method, force=True)
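
The 'spawn' comment above is worth a demonstration. The standalone snippet below is not sinter code (and is POSIX-only, since it uses 'fork' for contrast); it shows the hazard being avoided: forked children inherit the parent's RNG state, so parallel samples can repeat.

import multiprocessing
import random

def sample(_):
    # Each task draws one "random" number in whichever worker runs it.
    return random.random()

if __name__ == '__main__':
    random.seed(1234)  # The parent seeds its RNG before starting workers.
    for method in ['fork', 'spawn']:
        ctx = multiprocessing.get_context(method)
        with ctx.Pool(processes=2) as pool:
            print(method, pool.map(sample, [0, 1]))
    # 'fork' typically prints two identical values (both children inherited
    # the same RNG state); 'spawn' starts fresh interpreters, so it doesn't.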

    def start_distributing_work(self):
        self._compute_task_ids()
        self._distribute_work()

    def _compute_task_ids(self):
        idle_worker_ids = list(range(self.num_workers))
        unknown_task_ids = list(range(len(self.partial_tasks)))
        worker_to_task_map = {}
        while worker_to_task_map or unknown_task_ids:
            while idle_worker_ids and unknown_task_ids:
                worker_id = idle_worker_ids.pop()
                unknown_task_id = unknown_task_ids.pop()
                worker_to_task_map[worker_id] = unknown_task_id
                self.worker_states[worker_id].send_message(('compute_strong_id', self.partial_tasks[unknown_task_id]))

            try:
                message = self.shared_worker_output_queue.get()
                message_type, worker_id, message_body = message
                if message_type == 'computed_strong_id':
                    assert worker_id in worker_to_task_map
                    assert isinstance(message_body, str)
                    self.task_strong_ids[worker_to_task_map.pop(worker_id)] = message_body
                    idle_worker_ids.append(worker_id)
                elif message_type == 'stopped_due_to_exception':
                    cur_task, cur_shots_left, unflushed_work_done, traceback, ex = message_body
                    raise ValueError(f'Worker failed: traceback={traceback}') from ex
                else:
                    raise NotImplementedError(f'{message_type=}')
                self.progress_callback(None)
            except queue.Empty:
                pass

        assert len(idle_worker_ids) == self.num_workers
        seen = set()
        for k in range(len(self.partial_tasks)):
            options = self.partial_tasks[k].collection_options.combine(self.collection_options)
            key: str = self.task_strong_ids[k]
            if key in seen:
                raise ValueError(f'Same task given twice: {self.partial_tasks[k]!r}')
            seen.add(key)

            shots_left = options.max_shots
            errors_left = options.max_errors
            if errors_left is None:
                errors_left = shots_left
            errors_left = min(errors_left, shots_left)
            if key in self.existing_data:
                val = self.existing_data[key]
                shots_left -= val.shots
                if self.custom_error_count_key is None:
                    errors_left -= val.errors
                else:
                    errors_left -= val.custom_counts[self.custom_error_count_key]
            if shots_left <= 0:
                continue
            self.task_states[key] = _ManagedTaskState(
                partial_task=self.partial_tasks[k],
                strong_id=key,
                shots_left=shots_left,
                errors_left=errors_left,
            )
            if self.task_states[key].is_completed():
                del self.task_states[key]
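
To make the budget arithmetic above concrete, here is the same computation on hypothetical numbers (plain Python, mirroring the lines above):

# A task with max_shots=1_000_000 and max_errors=1_000 that already has
# 250_000 shots and 400 errors recorded enters collection with only the
# remaining budget.
shots_left, errors_left = 1_000_000, 1_000
errors_left = min(errors_left, shots_left)  # can't see more errors than shots
shots_left -= 250_000   # existing shots  -> 750_000 left
errors_left -= 400      # existing errors -> 600 left
print(shots_left, errors_left)  # 750000 600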

    def hard_stop(self):
        if not self.started:
            return

        removed_workers = [state.process for state in self.worker_states]
        for state in self.worker_states:
            if isinstance(state.process, threading.Thread):
                state.send_message('stop')
            state.process = None
            state.assigned_work_key = None
            state.input_queue = None
        self.shared_worker_output_queue = None
        self.started = False
        self.task_states.clear()

        # SIGKILL every worker process (threads can't be killed; they were
        # asked to stop above).
        for w in removed_workers:
            if isinstance(w, multiprocessing.Process):
                w.kill()
        # Wait for them all to be done.
        for w in removed_workers:
            w.join()

    def _handle_task_progress(self, task_id: Any):
        task_state = self.task_states[task_id]
        if task_state.is_completed():
            workers_ready = all(self.worker_states[worker_id].has_returned_all_shots() for worker_id in task_state.workers_assigned)
            if workers_ready:
                # The task is fully completed and can be forgotten entirely.
                # Reassign its workers.
                del self.task_states[task_id]
                for worker_id in task_state.workers_assigned:
                    w = self.worker_states[worker_id]
                    assert w.assigned_shots <= 0
                    assert w.asked_to_drop_shots == 0
                    w.assigned_work_key = None
                self._distribute_work()
            else:
                # The task is sufficiently sampled, but some workers are still
                # running. Ask them to hand their remaining shots back.
                for worker_id in task_state.workers_assigned:
                    self.worker_states[worker_id].ask_to_return_all_shots()
                self.progress_callback(None)
        else:
            self._distribute_unassigned_workers_to_jobs()
            self._distribute_work_within_a_job(task_state)

    def state_summary(self) -> str:
        lines = []
        for worker_id, worker in enumerate(self.worker_states):
            lines.append(f'worker {worker_id}:'
                         f' asked_to_drop_shots={worker.asked_to_drop_shots}'
                         f' assigned_shots={worker.assigned_shots}'
                         f' assigned_work_key={worker.assigned_work_key}')
        for task in self.task_states.values():
            lines.append(f'task {task.strong_id=}:\n'
                         f'    workers_assigned={task.workers_assigned}\n'
                         f'    shot_return_requests={task.shot_return_requests}\n'
                         f'    shots_left={task.shots_left}\n'
                         f'    errors_left={task.errors_left}\n'
                         f'    shots_unassigned={task.shots_unassigned}')
        return '\n' + '\n'.join(lines) + '\n'

    def process_message(self) -> bool:
        try:
            message = self.shared_worker_output_queue.get()
        except queue.Empty:
            return False

        message_type, worker_id, message_body = message
        worker_state = self.worker_states[worker_id]

        if message_type == 'flushed_results':
            task_strong_id, anon_stat = message_body
            assert isinstance(anon_stat, AnonTaskStats)
            assert worker_state.assigned_work_key == task_strong_id
            task_state = self.task_states[task_strong_id]

            worker_state.assigned_shots -= anon_stat.shots
            task_state.shots_left -= anon_stat.shots
            if worker_state.assigned_shots < 0:
                # The worker did more shots than it was assigned. Correct the
                # imbalance by formally assigning it the extra shots.
                extra_shots = abs(worker_state.assigned_shots)
                worker_state.assigned_shots += extra_shots
                task_state.shots_unassigned -= extra_shots
                worker_state.send_message((
                    'accept_shots',
                    (task_state.strong_id, extra_shots),
                ))

            if self.custom_error_count_key is None:
                task_state.errors_left -= anon_stat.errors
            else:
                task_state.errors_left -= anon_stat.custom_counts[self.custom_error_count_key]

            stat = TaskStats(
                strong_id=task_state.strong_id,
                decoder=task_state.partial_task.decoder,
                json_metadata=task_state.partial_task.json_metadata,
                shots=anon_stat.shots,
                discards=anon_stat.discards,
                seconds=anon_stat.seconds,
                errors=anon_stat.errors,
                custom_counts=anon_stat.custom_counts,
            )

            self._handle_task_progress(task_strong_id)

            if stat.strong_id not in self.total_collected:
                self.total_collected[stat.strong_id] = AnonTaskStats()
            self.total_collected[stat.strong_id] += stat.to_anon_stats()
            self.progress_callback(stat)

        elif message_type == 'changed_job':
            pass

        elif message_type == 'accepted_shots':
            pass

        elif message_type == 'returned_shots':
            task_key, shots_returned = message_body
            assert isinstance(shots_returned, int)
            assert shots_returned >= 0
            assert worker_state.assigned_work_key == task_key
            assert worker_state.asked_to_drop_shots or worker_state.asked_to_drop_errors
            task_state = self.task_states[task_key]
            task_state.shot_return_requests -= 1
            worker_state.asked_to_drop_shots = 0
            worker_state.asked_to_drop_errors = 0
            task_state.shots_unassigned += shots_returned
            worker_state.assigned_shots -= shots_returned
            assert worker_state.assigned_shots >= 0
            self._handle_task_progress(task_key)

        elif message_type == 'stopped_due_to_exception':
            cur_task, cur_shots_left, unflushed_work_done, traceback, ex = message_body
            raise RuntimeError(f'Worker failed: traceback={traceback}') from ex

        else:
            raise NotImplementedError(f'{message_type=}')

        return True
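
The negative-balance branch in 'flushed_results' deserves a worked example (hypothetical numbers, mirroring the code above):

# A worker assigned 1000 shots flushes results for 1200 shots, e.g. because
# its final batch overshot. Its balance goes to -200, so the manager grants
# it the 200 extra shots out of the task's unassigned pool.
assigned_shots, flushed, shots_unassigned = 1000, 1200, 5000
assigned_shots -= flushed            # -200
extra_shots = abs(assigned_shots)    # 200
assigned_shots += extra_shots        # back to 0
shots_unassigned -= extra_shots      # 4800; an 'accept_shots' message follows
print(assigned_shots, shots_unassigned)  # 0 4800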

    def run_until_done(self) -> bool:
        try:
            while self.task_states:
                self.process_message()
            return True

        except KeyboardInterrupt:
            return False

        finally:
            self.hard_stop()
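
The context-manager plumbing and the entry points above suggest the overall call sequence. The sketch below is hypothetical, not from the package: it assumes this private module path is importable as-is and uses an empty task list, so it exercises only the plumbing (CollectionOptions and its max_errors keyword are confirmed by this file's own imports and calls).

from sinter._collection._collection_manager import CollectionManager
from sinter._data import CollectionOptions

if __name__ == '__main__':
    with CollectionManager(
        existing_data={},
        collection_options=CollectionOptions(max_errors=1000),
        custom_decoders={},
        num_workers=4,
        worker_flush_period=1.0,
        tasks=[],                          # would normally be an Iterable[Task]
        progress_callback=lambda stats: None,
        allowed_cpu_affinity_ids=None,
    ) as manager:
        manager.start_workers()            # spawn the worker processes
        manager.start_distributing_work()  # hash tasks, set budgets, assign work
        manager.run_until_done()           # pump messages until budgets run out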

    def _distribute_unassigned_workers_to_jobs(self):
        idle_workers = [
            w
            for w in range(self.num_workers)[::-1]
            if self.worker_states[w].is_available_to_reassign()
        ]
        if not idle_workers or not self.started:
            return

        groups = collections.defaultdict(list)
        for work_state in self.task_states.values():
            if not work_state.is_completed():
                groups[len(work_state.workers_assigned)].append(work_state)
        for k in groups.keys():
            groups[k] = groups[k][::-1]
        if not groups:
            return
        min_assigned = min(groups.keys(), default=0)

        # Distribute workers to the unfinished jobs with the fewest workers.
        while idle_workers:
            task_state: _ManagedTaskState = groups[min_assigned].pop()
            groups[min_assigned + 1].append(task_state)
            if not groups[min_assigned]:
                min_assigned += 1

            worker_id = idle_workers.pop()
            task_state.workers_assigned.append(worker_id)
            worker_state = self.worker_states[worker_id]
            worker_state.assigned_work_key = task_state.strong_id
            worker_state.send_message((
                'change_job',
                (task_state.partial_task, CollectionOptions(max_errors=task_state.errors_left), task_state.assigned_soft_error_flush_threshold),
            ))
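
The grouping dance above implements "always top up the least-staffed unfinished task". A simplified standalone model of just that bookkeeping (not sinter code; it counts assignments instead of sending 'change_job' messages):

import collections

def distribute(counts, idle):
    # counts[i] = workers currently on task i; idle = workers to hand out.
    groups = collections.defaultdict(list)
    for task, n in enumerate(counts):
        groups[n].append(task)
    low = min(groups)
    for _ in range(idle):
        task = groups[low].pop()   # least-staffed task gets the next worker
        counts[task] += 1
        groups[low + 1].append(task)
        if not groups[low]:
            low += 1
    return counts

print(distribute([0, 0, 2], 4))  # [2, 2, 2]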

    def _distribute_unassigned_work_to_workers_within_a_job(self, task_state: _ManagedTaskState):
        if not self.started or not task_state.workers_assigned or task_state.shots_left <= 0:
            return

        num_task_workers = len(task_state.workers_assigned)
        expected_shots_per_worker = (task_state.shots_left + num_task_workers - 1) // num_task_workers

        # Give unassigned shots to the workers with the fewest shots first.
        for worker_id in sorted(task_state.workers_assigned, key=lambda wid: self.worker_states[wid].assigned_shots):
            worker_state = self.worker_states[worker_id]
            if worker_state.assigned_shots < expected_shots_per_worker:
                shots_to_assign = min(expected_shots_per_worker - worker_state.assigned_shots,
                                      task_state.shots_unassigned)
                if shots_to_assign > 0:
                    task_state.shots_unassigned -= shots_to_assign
                    worker_state.assigned_shots += shots_to_assign
                    worker_state.send_message((
                        'accept_shots',
                        (task_state.strong_id, shots_to_assign),
                    ))
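
The expected_shots_per_worker expression is a ceiling division; on hypothetical numbers:

shots_left, num_task_workers = 10, 3
expected_shots_per_worker = (shots_left + num_task_workers - 1) // num_task_workers
print(expected_shots_per_worker)  # 4, i.e. ceil(10 / 3)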

    def status_message(self) -> str:
        num_known_task_ids = sum(e is not None for e in self.task_strong_ids)
        if num_known_task_ids < len(self.task_strong_ids):
            return f"Analyzed {num_known_task_ids}/{len(self.task_strong_ids)} tasks..."
        max_errors = self.collection_options.max_errors
        max_shots = self.collection_options.max_shots

        tasks_left = 0
        lines = []
        skipped_lines = []
        for k, strong_id in enumerate(self.task_strong_ids):
            if strong_id not in self.task_states:
                continue
            c = self.total_collected.get(strong_id, AnonTaskStats())
            tasks_left += 1
            w = len(self.task_states[strong_id].workers_assigned)
            dt = None
            if max_shots is not None and c.shots:
                dt = (max_shots - c.shots) * c.seconds / c.shots
            c_errors = c.custom_counts[self.custom_error_count_key] if self.custom_error_count_key is not None else c.errors
            if max_errors is not None and c_errors and c.seconds:
                dt2 = (max_errors - c_errors) * c.seconds / c_errors
                if dt is None:
                    dt = dt2
                else:
                    dt = min(dt, dt2)
            if dt is not None:
                dt /= 60
            if dt is not None and w > 0:
                dt /= w
            line = [
                f'{w}',
                self.partial_tasks[k].decoder,
                ("?" if dt is None or dt == 0 else "[draining]" if dt <= 0 else "<1m" if dt < 1 else str(round(dt)) + 'm') + ('·∞' if w == 0 else ''),
                f'{max_shots - c.shots}' if max_shots is not None else f'{c.shots}',
                f'{max_errors - c_errors}' if max_errors is not None else f'{c_errors}',
                ",".join(
                    [f"{key}={val}" for key, val in self.partial_tasks[k].json_metadata.items()]
                    if isinstance(self.partial_tasks[k].json_metadata, dict)
                    else [str(self.partial_tasks[k].json_metadata)]
                ),
            ]
            if w == 0:
                skipped_lines.append(line)
            else:
                lines.append(line)
        if len(lines) < 50 and skipped_lines:
            missing_lines = 50 - len(lines)
            lines.extend(skipped_lines[:missing_lines])
            skipped_lines = skipped_lines[missing_lines:]

        if lines:
            lines.insert(0, [
                'workers',
                'decoder',
                'eta',
                'shots_left' if max_shots is not None else 'shots_taken',
                'errors_left' if max_errors is not None else 'errors_seen',
                'json_metadata'])
            justs = cast(list[Callable[[str, int], str]], [str.rjust, str.rjust, str.rjust, str.rjust, str.rjust, str.ljust])
            cols = len(lines[0])
            lengths = [
                max(len(lines[row][col]) for row in range(len(lines)))
                for col in range(cols)
            ]
            lines = [
                " " + " ".join(justs[col](row[col], lengths[col]) for col in range(cols))
                for row in lines
            ]
        if skipped_lines:
            lines.append(' ... (' + str(len(skipped_lines)) + ' more tasks) ...')
        return f'{tasks_left} tasks left:\n' + '\n'.join(lines)
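
The eta column comes from linear extrapolation of the observed sampling rate. On hypothetical numbers:

max_shots, shots, seconds, workers = 2_000_000, 500_000, 120.0, 4
dt = (max_shots - shots) * seconds / shots  # 360.0 seconds of work remaining
dt /= 60                                    # 6.0 minutes
dt /= workers                               # 1.5 minutes of wall-clock eta
print(dt)  # 1.5, rendered as "2m" by the formatting above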

    def _update_soft_error_threshold_for_a_job(self, task_state: _ManagedTaskState):
        if task_state.errors_left <= len(task_state.workers_assigned):
            desired_threshold = 1
        elif task_state.errors_left <= task_state.assigned_soft_error_flush_threshold * self.num_workers:
            desired_threshold = max(1, math.ceil(task_state.errors_left * 0.5 / self.num_workers))
        else:
            return

        if task_state.assigned_soft_error_flush_threshold != desired_threshold:
            task_state.assigned_soft_error_flush_threshold = desired_threshold
            for wid in task_state.workers_assigned:
                self.worker_states[wid].send_message(('set_soft_error_flush_threshold', desired_threshold))
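
And the threshold schedule on hypothetical numbers (mirroring the branches above): as errors_left shrinks toward the worker count, workers are told to flush partial results more eagerly so the manager can stop on time.

import math

num_workers, errors_left, threshold = 8, 100, 50
if errors_left <= num_workers:
    threshold = 1
elif errors_left <= threshold * num_workers:  # 100 <= 400, so tighten.
    threshold = max(1, math.ceil(errors_left * 0.5 / num_workers))
print(threshold)  # 7: workers now flush after ~7 errors instead of 50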

    def _take_work_if_unsatisfied_workers_within_a_job(self, task_state: _ManagedTaskState):
        if not self.started or not task_state.workers_assigned or task_state.shots_left <= 0:
            return

        if all(self.worker_states[w].assigned_shots > 0 for w in task_state.workers_assigned):
            return

        num_task_workers = len(task_state.workers_assigned)
        expected_shots_per_worker = (task_state.shots_left + num_task_workers - 1) // num_task_workers

        # There are idle workers that couldn't be given any shots. Take shots
        # back from the most-loaded workers so the idle ones have work to get.
        for worker_id in sorted(task_state.workers_assigned, key=lambda wid: self.worker_states[wid].assigned_shots, reverse=True):
            worker_state = self.worker_states[worker_id]
            if worker_state.asked_to_drop_shots or worker_state.assigned_shots <= expected_shots_per_worker:
                continue
            shots_to_take = worker_state.assigned_shots - expected_shots_per_worker
            assert shots_to_take > 0
            worker_state.asked_to_drop_shots = shots_to_take
            task_state.shot_return_requests += 1
            worker_state.send_message((
                'return_shots',
                (
                    task_state.strong_id,
                    shots_to_take,
                ),
            ))

    def _distribute_work_within_a_job(self, t: _ManagedTaskState):
        self._distribute_unassigned_work_to_workers_within_a_job(t)
        self._take_work_if_unsatisfied_workers_within_a_job(t)

    def _distribute_work(self):
        self._distribute_unassigned_workers_to_jobs()
        for task_state in self.task_states.values():
            if not task_state.is_completed():
                self._distribute_work_within_a_job(task_state)