sinter 1.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sinter might be problematic. Click here for more details.

Files changed (62) hide show
  1. sinter/__init__.py +47 -0
  2. sinter/_collection/__init__.py +10 -0
  3. sinter/_collection/_collection.py +480 -0
  4. sinter/_collection/_collection_manager.py +581 -0
  5. sinter/_collection/_collection_manager_test.py +287 -0
  6. sinter/_collection/_collection_test.py +317 -0
  7. sinter/_collection/_collection_worker_loop.py +35 -0
  8. sinter/_collection/_collection_worker_state.py +259 -0
  9. sinter/_collection/_collection_worker_test.py +222 -0
  10. sinter/_collection/_mux_sampler.py +56 -0
  11. sinter/_collection/_printer.py +65 -0
  12. sinter/_collection/_sampler_ramp_throttled.py +66 -0
  13. sinter/_collection/_sampler_ramp_throttled_test.py +144 -0
  14. sinter/_command/__init__.py +0 -0
  15. sinter/_command/_main.py +39 -0
  16. sinter/_command/_main_collect.py +350 -0
  17. sinter/_command/_main_collect_test.py +482 -0
  18. sinter/_command/_main_combine.py +84 -0
  19. sinter/_command/_main_combine_test.py +153 -0
  20. sinter/_command/_main_plot.py +817 -0
  21. sinter/_command/_main_plot_test.py +445 -0
  22. sinter/_command/_main_predict.py +75 -0
  23. sinter/_command/_main_predict_test.py +36 -0
  24. sinter/_data/__init__.py +20 -0
  25. sinter/_data/_anon_task_stats.py +89 -0
  26. sinter/_data/_anon_task_stats_test.py +35 -0
  27. sinter/_data/_collection_options.py +106 -0
  28. sinter/_data/_collection_options_test.py +24 -0
  29. sinter/_data/_csv_out.py +74 -0
  30. sinter/_data/_existing_data.py +173 -0
  31. sinter/_data/_existing_data_test.py +41 -0
  32. sinter/_data/_task.py +311 -0
  33. sinter/_data/_task_stats.py +244 -0
  34. sinter/_data/_task_stats_test.py +140 -0
  35. sinter/_data/_task_test.py +38 -0
  36. sinter/_decoding/__init__.py +16 -0
  37. sinter/_decoding/_decoding.py +419 -0
  38. sinter/_decoding/_decoding_all_built_in_decoders.py +25 -0
  39. sinter/_decoding/_decoding_decoder_class.py +161 -0
  40. sinter/_decoding/_decoding_fusion_blossom.py +193 -0
  41. sinter/_decoding/_decoding_mwpf.py +302 -0
  42. sinter/_decoding/_decoding_pymatching.py +81 -0
  43. sinter/_decoding/_decoding_test.py +480 -0
  44. sinter/_decoding/_decoding_vacuous.py +38 -0
  45. sinter/_decoding/_perfectionist_sampler.py +38 -0
  46. sinter/_decoding/_sampler.py +72 -0
  47. sinter/_decoding/_stim_then_decode_sampler.py +222 -0
  48. sinter/_decoding/_stim_then_decode_sampler_test.py +192 -0
  49. sinter/_plotting.py +619 -0
  50. sinter/_plotting_test.py +108 -0
  51. sinter/_predict.py +381 -0
  52. sinter/_predict_test.py +227 -0
  53. sinter/_probability_util.py +519 -0
  54. sinter/_probability_util_test.py +281 -0
  55. sinter-1.15.0.data/data/README.md +332 -0
  56. sinter-1.15.0.data/data/readme_example_plot.png +0 -0
  57. sinter-1.15.0.data/data/requirements.txt +4 -0
  58. sinter-1.15.0.dist-info/METADATA +354 -0
  59. sinter-1.15.0.dist-info/RECORD +62 -0
  60. sinter-1.15.0.dist-info/WHEEL +5 -0
  61. sinter-1.15.0.dist-info/entry_points.txt +2 -0
  62. sinter-1.15.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,259 @@
1
+ import queue
2
+ import time
3
+ from typing import Any
4
+ from typing import Optional
5
+ from typing import TYPE_CHECKING
6
+
7
+ import stim
8
+
9
+ from sinter._data import AnonTaskStats
10
+ from sinter._data import CollectionOptions
11
+ from sinter._data import Task
12
+ from sinter._decoding import CompiledSampler
13
+ from sinter._decoding import Sampler
14
+
15
+ if TYPE_CHECKING:
16
+ import multiprocessing
17
+
18
+
19
def _detector_error_model_for(circuit: stim.Circuit) -> stim.DetectorErrorModel:
    """Derive a detector error model from `circuit`, relaxing options on failure.

    Tries, in order:
    1. decomposed errors with approximate disjoint errors;
    2. approximate disjoint errors only;
    3. approximate disjoint errors with flattened loops.

    It moves to the next option whenever stim raises ValueError. The attempts
    are nested inside the `except` clauses so a final failure keeps the
    earlier errors in its exception context.
    """
    try:
        return circuit.detector_error_model(decompose_errors=True, approximate_disjoint_errors=True)
    except ValueError:
        try:
            return circuit.detector_error_model(approximate_disjoint_errors=True)
        except ValueError:
            return circuit.detector_error_model(approximate_disjoint_errors=True, flatten_loops=True)


def _fill_in_task(task: Task) -> Task:
    """Return a task whose circuit and detector error model are both populated.

    A missing circuit is loaded from `task.circuit_path`. A missing detector
    error model is derived from the circuit. If nothing was missing, the
    original task object is returned unchanged. Otherwise a new Task is built
    that carries over the decoder, masks, metadata and collection options.
    """
    circuit = task.circuit
    loaded_circuit = circuit is None
    if loaded_circuit:
        circuit = stim.Circuit.from_file(task.circuit_path)

    dem = task.detector_error_model
    derived_dem = dem is None
    if derived_dem:
        dem = _detector_error_model_for(circuit)

    if not (loaded_circuit or derived_dem):
        return task

    return Task(
        circuit=circuit,
        decoder=task.decoder,
        detector_error_model=dem,
        postselection_mask=task.postselection_mask,
        postselected_observables_mask=task.postselected_observables_mask,
        json_metadata=task.json_metadata,
        collection_options=task.collection_options,
    )
46
+
47
+
48
class CollectionWorkerState:
    """State machine for one collection worker.

    The worker reads `(message_type, message_body)` tuples from `inp`. It
    samples shots for its current task using a sampler compiled from
    `sampler`, and reports to the manager by putting
    `(message_type, worker_id, payload)` tuples on `out`.
    """

    def __init__(
        self,
        *,
        flush_period: float,
        worker_id: int,
        inp: 'multiprocessing.Queue',
        out: 'multiprocessing.Queue',
        sampler: Sampler,
        custom_error_count_key: Optional[str],
    ):
        """Create an idle worker with no current task.

        Args:
            flush_period: Upper bound, in seconds, on the delay between
                periodic result flushes. The delay starts at 0.01 and is
                multiplied by 1.4 after each flush, capped at this value.
            worker_id: Included in every message sent to the manager.
            inp: Queue of messages from the manager to this worker.
            out: Queue of messages from this worker to the manager.
            sampler: Used to compile a sampler for each new task.
            custom_error_count_key: When not None, errors are counted using
                `custom_counts[custom_error_count_key]` instead of `errors`.
        """
        assert isinstance(flush_period, (int, float))
        assert isinstance(sampler, Sampler)
        self.max_flush_period = flush_period
        self.cur_flush_period = 0.01
        self.inp = inp
        self.out = out
        self.sampler = sampler
        self.compiled_sampler: CompiledSampler | None = None
        self.worker_id = worker_id

        self.current_task: Task | None = None
        # Remaining error budget for the current task; None means no error limit.
        self.current_error_cutoff: int | None = None
        self.custom_error_count_key = custom_error_count_key
        # Shots assigned by the manager that have not been sampled yet.
        self.current_task_shots_left: int = 0
        # Statistics sampled since the last 'flushed_results' message.
        self.unflushed_results: AnonTaskStats = AnonTaskStats()
        self.last_flush_message_time = time.monotonic()
        # Flush as soon as at least this many errors are unflushed.
        self.soft_error_flush_threshold: int = 1

    def _send_message_to_manager(self, message: Any):
        self.out.put(message)

    def state_summary(self) -> str:
        """Return a multi-line, human-readable dump of this worker's fields."""
        lines = [
            f'Worker(id={self.worker_id}) [',
            f'    max_flush_period={self.max_flush_period}',
            f'    cur_flush_period={self.cur_flush_period}',
            f'    sampler={self.sampler}',
            f'    compiled_sampler={self.compiled_sampler}',
            f'    current_task={self.current_task}',
            f'    current_error_cutoff={self.current_error_cutoff}',
            f'    custom_error_count_key={self.custom_error_count_key}',
            f'    current_task_shots_left={self.current_task_shots_left}',
            f'    unflushed_results={self.unflushed_results}',
            f'    last_flush_message_time={self.last_flush_message_time}',
            f'    soft_error_flush_threshold={self.soft_error_flush_threshold}',
            ']',
        ]
        return '\n' + '\n'.join(lines) + '\n'

    def flush_results(self) -> bool:
        """Send the unflushed statistics to the manager, if there are any.

        Each successful flush also grows the periodic flush delay by 1.4x,
        up to `max_flush_period`.

        Returns:
            True if a 'flushed_results' message was sent. False if there were
            no unflushed shots.
        """
        if self.unflushed_results.shots > 0:
            self.last_flush_message_time = time.monotonic()
            self.cur_flush_period = min(self.cur_flush_period * 1.4, self.max_flush_period)
            self._send_message_to_manager((
                'flushed_results',
                self.worker_id,
                (self.current_task.strong_id(), self.unflushed_results),
            ))
            self.unflushed_results = AnonTaskStats()
            return True
        return False

    def accept_shots(self, *, shots_delta: int):
        """Add `shots_delta` shots to the current task's budget and acknowledge it."""
        assert shots_delta >= 0
        self.current_task_shots_left += shots_delta
        self._send_message_to_manager((
            'accepted_shots',
            self.worker_id,
            (self.current_task.strong_id(), shots_delta),
        ))

    def return_shots(self, *, requested_shots: int):
        """Give back up to `requested_shots` unsampled shots to the manager.

        At most the remaining budget is returned. If that empties the budget,
        pending results are flushed before the 'returned_shots' message is
        sent.
        """
        assert requested_shots >= 0
        returned_shots = max(0, min(requested_shots, self.current_task_shots_left))
        self.current_task_shots_left -= returned_shots
        if self.current_task_shots_left <= 0:
            self.flush_results()
        self._send_message_to_manager((
            'returned_shots',
            self.worker_id,
            (self.current_task.strong_id(), returned_shots),
        ))

    def compute_strong_id(self, *, new_task: Task):
        """Compute the strong id of `new_task` (after filling in missing parts) and report it."""
        strong_id = _fill_in_task(new_task).strong_id()
        self._send_message_to_manager((
            'computed_strong_id',
            self.worker_id,
            strong_id,
        ))

    def change_job(self, *, new_task: Task, new_collection_options: CollectionOptions):
        """Switch to `new_task`.

        Results for the previous task are flushed first. The shot budget
        starts at zero, and the error cutoff comes from
        `new_collection_options.max_errors`.
        """
        self.flush_results()

        self.current_task = _fill_in_task(new_task)
        self.current_error_cutoff = new_collection_options.max_errors
        self.compiled_sampler = self.sampler.compiled_sampler_for_task(self.current_task)
        assert self.current_task.strong_id() is not None
        self.current_task_shots_left = 0
        self.last_flush_message_time = time.monotonic()

        self._send_message_to_manager((
            'changed_job',
            self.worker_id,
            (self.current_task.strong_id(),),
        ))

    def process_messages(self) -> int:
        """Handle every message currently waiting on `inp` without blocking.

        Returns:
            The number of messages handled, or -1 if a 'stop' message was
            received. Messages after 'stop' are left in the queue.

        Raises:
            NotImplementedError: On an unrecognized message type.
        """
        num_processed = 0
        while True:
            try:
                message = self.inp.get_nowait()
            except queue.Empty:
                return num_processed

            num_processed += 1
            message_type, message_body = message

            if message_type == 'stop':
                return -1

            elif message_type == 'flush_results':
                self.flush_results()

            elif message_type == 'compute_strong_id':
                assert isinstance(message_body, Task)
                self.compute_strong_id(new_task=message_body)

            elif message_type == 'change_job':
                new_task, new_collection_options, soft_error_flush_threshold = message_body
                self.soft_error_flush_threshold = soft_error_flush_threshold
                assert isinstance(new_task, Task)
                self.change_job(new_task=new_task, new_collection_options=new_collection_options)
                # Reset the adaptive flush period only after change_job().
                # change_job() flushes the previous task's results, and that
                # flush grows cur_flush_period. The new task should start from
                # the initial short period.
                self.cur_flush_period = 0.01

            elif message_type == 'set_soft_error_flush_threshold':
                soft_error_flush_threshold = message_body
                self.soft_error_flush_threshold = soft_error_flush_threshold

            elif message_type == 'accept_shots':
                job_key, shots_delta = message_body
                assert isinstance(shots_delta, int)
                assert job_key == self.current_task.strong_id()
                self.accept_shots(shots_delta=shots_delta)

            elif message_type == 'return_shots':
                job_key, requested_shots = message_body
                assert isinstance(requested_shots, int)
                assert job_key == self.current_task.strong_id()
                self.return_shots(requested_shots=requested_shots)

            else:
                raise NotImplementedError(f'{message_type=}')

    def num_unflushed_errors(self) -> int:
        """Return the error count (default or custom key) of the unflushed results."""
        if self.custom_error_count_key is not None:
            return self.unflushed_results.custom_counts[self.custom_error_count_key]
        return self.unflushed_results.errors

    def do_some_work(self) -> bool:
        """Sample one batch, if shots remain, then flush if a flush is due.

        A flush is due when:
        - the unflushed error count reaches `soft_error_flush_threshold`; or
        - there are unflushed shots and either the budget is exhausted or
          `cur_flush_period` has elapsed since the last flush.

        Returns:
            True if any sampling or flushing happened.

        Raises:
            ValueError: If the compiled sampler returns statistics with zero shots.
        """
        did_some_work = False

        # Sample some stats.
        if self.current_task_shots_left > 0:
            # Don't keep sampling if we've exceeded the number of errors needed.
            if self.current_error_cutoff is not None and self.current_error_cutoff <= 0:
                return self.flush_results()

            some_work_done = self.compiled_sampler.sample(self.current_task_shots_left)
            # Validate the type before reading any of its fields.
            assert isinstance(some_work_done, AnonTaskStats)
            if some_work_done.shots < 1:
                raise ValueError(f"Sampler didn't do any work. It returned statistics with shots == 0: {some_work_done}.")
            self.current_task_shots_left -= some_work_done.shots
            if self.current_error_cutoff is not None:
                errors_done = some_work_done.custom_counts[self.custom_error_count_key] if self.custom_error_count_key is not None else some_work_done.errors
                self.current_error_cutoff -= errors_done
            self.unflushed_results += some_work_done
            did_some_work = True

        # Report them periodically.
        should_flush = False
        if self.num_unflushed_errors() >= self.soft_error_flush_threshold:
            should_flush = True
        if self.unflushed_results.shots > 0:
            if self.current_task_shots_left <= 0 or self.last_flush_message_time + self.cur_flush_period < time.monotonic():
                should_flush = True
        if should_flush:
            did_some_work |= self.flush_results()

        return did_some_work

    def run_message_loop(self):
        """Alternate between handling messages and sampling until told to stop.

        The loop sleeps for 10ms on iterations where nothing happened.
        KeyboardInterrupt ends the loop silently. Any other exception is
        reported to the manager as a 'stopped_due_to_exception' message,
        along with the current task id, remaining shots, unflushed results
        and the formatted traceback.
        """
        try:
            while True:
                num_messages_processed = self.process_messages()
                if num_messages_processed == -1:
                    break
                did_some_work = self.do_some_work()
                if not did_some_work and num_messages_processed == 0:
                    time.sleep(0.01)

        except KeyboardInterrupt:
            pass

        except BaseException as ex:
            import traceback
            self._send_message_to_manager((
                'stopped_due_to_exception',
                self.worker_id,
                (None if self.current_task is None else self.current_task.strong_id(), self.current_task_shots_left, self.unflushed_results, traceback.format_exc(), ex),
            ))
@@ -0,0 +1,222 @@
1
+ import collections
2
+ import multiprocessing
3
+ import time
4
+ from typing import Any, List
5
+
6
+ import sinter
7
+ import stim
8
+
9
+ from sinter._collection._collection_worker_state import CollectionWorkerState
10
+
11
+
12
class MockWorkHandler(sinter.Sampler, sinter.CompiledSampler):
    """Test double that acts as both the Sampler and its compiled form.

    `compiled_sampler_for_task` checks the task against `expected_task`.
    Each call to `sample` consumes the next queued
    `(expected_shots, response)` pair from `expected`.
    """

    def __init__(self):
        self.expected_task = None
        self.expected = collections.deque()

    def compiled_sampler_for_task(self, task: sinter.Task) -> sinter.CompiledSampler:
        assert task == self.expected_task
        return self

    def handles_throttling(self) -> bool:
        # Reporting True makes RampThrottledSampler return this object unwrapped.
        return True

    def sample(self, shots: int) -> sinter.AnonTaskStats:
        assert self.expected
        want_shots, canned_stats = self.expected.popleft()
        assert shots == want_shots
        return canned_stats
29
+
30
+
31
+ def _assert_drain_queue(q: multiprocessing.Queue, expected_contents: List[Any]):
32
+ for v in expected_contents:
33
+ assert q.get(timeout=0.1) == v
34
+ assert q.empty()
35
+
36
+
37
+ def _put_wait_not_empty(q: multiprocessing.Queue, item: Any):
38
+ q.put(item)
39
+ while q.empty():
40
+ time.sleep(0.0001)
41
+
42
+
43
def test_worker_stop():
    """A worker handles a job change and then reports -1 for a 'stop' message."""
    sampler = MockWorkHandler()

    to_worker = multiprocessing.Queue()
    from_worker = multiprocessing.Queue()
    for q in (to_worker, from_worker):
        # See multiprocessing.Queue.cancel_join_thread.
        q.cancel_join_thread()

    worker = CollectionWorkerState(
        flush_period=-1,
        worker_id=5,
        sampler=sampler,
        inp=to_worker,
        out=from_worker,
        custom_error_count_key=None,
    )

    # Nothing queued yet: nothing handled, nothing reported.
    assert worker.process_messages() == 0
    _assert_drain_queue(from_worker, [])

    task = sinter.Task(
        circuit=stim.Circuit('H 0'),
        detector_error_model=stim.DetectorErrorModel(),
        decoder='mock',
        collection_options=sinter.CollectionOptions(max_shots=100_000_000),
        json_metadata={'a': 3},
    )
    sampler.expected_task = task

    change_job_body = (task, sinter.CollectionOptions(max_errors=100_000_000), 100_000_000)
    _put_wait_not_empty(to_worker, ('change_job', change_job_body))
    assert worker.process_messages() == 1
    _assert_drain_queue(from_worker, [('changed_job', 5, (task.strong_id(),))])

    _put_wait_not_empty(to_worker, ('stop', None))
    assert worker.process_messages() == -1
78
+
79
+
80
def test_worker_skip_work():
    """Accepted shots can be handed back to the manager without sampling any."""
    sampler = MockWorkHandler()

    to_worker = multiprocessing.Queue()
    from_worker = multiprocessing.Queue()
    for q in (to_worker, from_worker):
        q.cancel_join_thread()

    worker = CollectionWorkerState(
        flush_period=-1,
        worker_id=5,
        sampler=sampler,
        inp=to_worker,
        out=from_worker,
        custom_error_count_key=None,
    )

    assert worker.process_messages() == 0
    _assert_drain_queue(from_worker, [])

    task = sinter.Task(
        circuit=stim.Circuit('H 0'),
        detector_error_model=stim.DetectorErrorModel(),
        decoder='mock',
        collection_options=sinter.CollectionOptions(max_shots=100_000_000),
        json_metadata={'a': 3},
    )
    sampler.expected_task = task
    change_job_body = (task, sinter.CollectionOptions(max_errors=100_000_000), 100_000_000)
    _put_wait_not_empty(to_worker, ('change_job', change_job_body))
    assert worker.process_messages() == 1
    _assert_drain_queue(from_worker, [('changed_job', 5, (task.strong_id(),))])

    # Grant 10000 shots.
    _put_wait_not_empty(to_worker, ('accept_shots', (task.strong_id(), 10000)))
    assert worker.process_messages() == 1
    _assert_drain_queue(from_worker, [('accepted_shots', 5, (task.strong_id(), 10000))])

    assert worker.current_task == task
    assert worker.current_task_shots_left == 10000
    assert worker.process_messages() == 0
    _assert_drain_queue(from_worker, [])

    # Take back part of the budget...
    _put_wait_not_empty(to_worker, ('return_shots', (task.strong_id(), 2000)))
    assert worker.process_messages() == 1
    _assert_drain_queue(from_worker, [('returned_shots', 5, (task.strong_id(), 2000))])

    # ...then ask for far more than remains; only the remaining 8000 come back.
    _put_wait_not_empty(to_worker, ('return_shots', (task.strong_id(), 20000000)))
    assert worker.process_messages() == 1
    _assert_drain_queue(from_worker, [('returned_shots', 5, (task.strong_id(), 8000))])

    # With no budget left there is nothing to sample.
    assert not worker.do_some_work()
134
+
135
+
136
def test_worker_finish_work():
    """A worker samples its whole budget in batches and flushes each batch."""
    sampler = MockWorkHandler()

    to_worker = multiprocessing.Queue()
    from_worker = multiprocessing.Queue()
    for q in (to_worker, from_worker):
        q.cancel_join_thread()

    worker = CollectionWorkerState(
        flush_period=-1,
        worker_id=5,
        sampler=sampler,
        inp=to_worker,
        out=from_worker,
        custom_error_count_key=None,
    )

    assert worker.process_messages() == 0
    _assert_drain_queue(from_worker, [])

    task = sinter.Task(
        circuit=stim.Circuit('H 0'),
        detector_error_model=stim.DetectorErrorModel(),
        decoder='mock',
        collection_options=sinter.CollectionOptions(max_shots=100_000_000),
        json_metadata={'a': 3},
    )
    sampler.expected_task = task
    change_job_body = (task, sinter.CollectionOptions(max_errors=100_000_000), 100_000_000)
    _put_wait_not_empty(to_worker, ('change_job', change_job_body))
    _put_wait_not_empty(to_worker, ('accept_shots', (task.strong_id(), 10000)))

    # Both messages may not become readable at once, so poll for up to a second.
    started = time.monotonic()
    processed = worker.process_messages()
    while processed < 2:
        if time.monotonic() - started > 1:
            raise ValueError("Messages not processed")
        processed += worker.process_messages()
    assert processed == 2
    _assert_drain_queue(from_worker, [
        ('changed_job', 5, (task.strong_id(),)),
        ('accepted_shots', 5, (task.strong_id(), 10000)),
    ])

    assert worker.current_task == task
    assert worker.current_task_shots_left == 10000
    assert worker.process_messages() == 0
    _assert_drain_queue(from_worker, [])

    # First batch: asked for all 10000, sampler only does 1000.
    first_batch = sinter.AnonTaskStats(shots=1000, errors=23, discards=0, seconds=1)
    sampler.expected.append((10000, first_batch))
    assert worker.do_some_work()
    worker.flush_results()
    _assert_drain_queue(from_worker, [
        ('flushed_results', 5, (task.strong_id(), sinter.AnonTaskStats(shots=1000, errors=23, discards=0, seconds=1)))])

    # Second batch finishes the remaining 9000.
    second_batch = sinter.AnonTaskStats(shots=9000, errors=13, discards=0, seconds=1)
    sampler.expected.append((9000, second_batch))
    assert worker.do_some_work()
    worker.flush_results()
    _assert_drain_queue(from_worker, [
        ('flushed_results', 5, (task.strong_id(), sinter.AnonTaskStats(
            shots=9000,
            errors=13,
            discards=0,
            seconds=1,
        ))),
    ])

    # Budget exhausted: no more work and nothing left to flush.
    assert not worker.do_some_work()
    worker.flush_results()
    _assert_drain_queue(from_worker, [])
@@ -0,0 +1,56 @@
1
+ import pathlib
2
+ from typing import Optional
3
+ from typing import Union
4
+
5
+ from sinter._data import Task
6
+ from sinter._decoding._decoding_all_built_in_decoders import BUILT_IN_SAMPLERS
7
+ from sinter._decoding._decoding_decoder_class import Decoder
8
+ from sinter._decoding._sampler import CompiledSampler
9
+ from sinter._decoding._sampler import Sampler
10
+ from sinter._decoding._stim_then_decode_sampler import StimThenDecodeSampler
11
+
12
+
13
class MuxSampler(Sampler):
    """Looks up the sampler to use for a task, by the task's decoder name.

    Entries in `custom_decoders` take precedence over built-in samplers.
    A resolved Sampler is used directly. A resolved decoder (a Decoder, or
    any object with a `compile_decoder_for_dem` attribute) is wrapped in a
    StimThenDecodeSampler.
    """

    def __init__(
        self,
        *,
        custom_decoders: Union[dict[str, Union[Decoder, Sampler]], None],
        count_observable_error_combos: bool,
        count_detection_events: bool,
        tmp_dir: Optional[pathlib.Path],
    ):
        """
        Args:
            custom_decoders: Extra decoders/samplers keyed by decoder name, or
                None when there are none.
            count_observable_error_combos: Forwarded to StimThenDecodeSampler.
                Not supported for custom Samplers.
            count_detection_events: Forwarded to StimThenDecodeSampler. Not
                supported for custom Samplers.
            tmp_dir: Forwarded to StimThenDecodeSampler.
        """
        self.custom_decoders = custom_decoders
        self.count_observable_error_combos = count_observable_error_combos
        self.count_detection_events = count_detection_events
        self.tmp_dir = tmp_dir

    def compiled_sampler_for_task(self, task: Task) -> CompiledSampler:
        return self._resolve_sampler(task.decoder).compiled_sampler_for_task(task)

    def _resolve_sampler(self, name: str) -> Sampler:
        """Map a decoder name to a Sampler.

        Raises:
            NotImplementedError: If the name is unknown, if a counting option
                is requested for a custom Sampler, or if the resolved object
                is neither a Sampler nor decoder-like.
        """
        sub_sampler: Union[Decoder, Sampler]

        # custom_decoders may be None (per its annotation); treat that as "no
        # custom decoders" instead of failing with a TypeError on `in`.
        if self.custom_decoders is not None and name in self.custom_decoders:
            sub_sampler = self.custom_decoders[name]
        elif name in BUILT_IN_SAMPLERS:
            sub_sampler = BUILT_IN_SAMPLERS[name]
        else:
            raise NotImplementedError(f'Not a recognized decoder or sampler: {name=}. Did you forget to specify custom_decoders?')

        if isinstance(sub_sampler, Sampler):
            if self.count_detection_events:
                raise NotImplementedError("'count_detection_events' not supported when using a custom Sampler (instead of a custom Decoder).")
            if self.count_observable_error_combos:
                raise NotImplementedError("'count_observable_error_combos' not supported when using a custom Sampler (instead of a custom Decoder).")
            return sub_sampler
        elif isinstance(sub_sampler, Decoder) or hasattr(sub_sampler, 'compile_decoder_for_dem'):
            return StimThenDecodeSampler(
                decoder=sub_sampler,
                count_detection_events=self.count_detection_events,
                count_observable_error_combos=self.count_observable_error_combos,
                tmp_dir=self.tmp_dir,
            )
        else:
            raise NotImplementedError(f"Don't know how to turn this into a Sampler: {sub_sampler!r}")
@@ -0,0 +1,65 @@
1
+ import threading
2
+ from typing import List, Any
3
+
4
+ import sys
5
+ import time
6
+
7
+
8
class ThrottledProgressPrinter:
    """Handles printing progress updates interspersed amongst output.

    `print_out` writes a message to every stream in `outs`. Progress
    messages given to `show_latest_progress` are written to stderr (wrapped
    in red ANSI codes) at most once per `min_progress_delay` seconds. When
    updates arrive faster than that, a background thread waits out the delay
    and then prints only the most recent one.
    """

    def __init__(self, *, outs: List[Any], print_progress: bool, min_progress_delay: float):
        self.outs = outs
        self.print_progress = print_progress
        self.min_progress_delay = min_progress_delay
        # Earliest monotonic time at which another progress line may be printed.
        self.next_can_print_time = time.monotonic()
        self.latest_msg = ''
        self.latest_printed_msg = ''
        # True while a background thread is waiting to print the latest message.
        self.is_worker_running = False
        self.lock = threading.Lock()

    def print_out(self, msg: str) -> None:
        """Print `msg` to every output stream, flushing each one."""
        with self.lock:
            for stream in self.outs:
                print(msg, file=stream, flush=True)

    def show_latest_progress(self, msg: str) -> None:
        """Record `msg` as the newest progress update and print it when allowed."""
        if not self.print_progress:
            return
        with self.lock:
            if msg == self.latest_msg:
                return
            self.latest_msg = msg
            if self.is_worker_running:
                # The pending background thread will pick up latest_msg.
                return
            if self._try_print_else_delay() > 0:
                self.is_worker_running = True
                threading.Thread(target=self._print_worker).start()

    def flush(self):
        """Immediately print the newest progress update if it hasn't been shown."""
        with self.lock:
            self._print_latest_if_unprinted()

    def _print_latest_if_unprinted(self) -> None:
        """Print `latest_msg` to stderr unless it is empty or already printed.

        The caller must hold `self.lock`.
        """
        if self.latest_msg == "" or self.latest_msg == self.latest_printed_msg:
            return
        print('\033[31m' + self.latest_msg + '\033[0m', file=sys.stderr, flush=True)
        self.latest_printed_msg = self.latest_msg

    def _try_print_else_delay(self) -> float:
        """Print now if the throttle allows it; otherwise report how long to wait.

        The caller must hold `self.lock`. Returns 0 when the throttle allowed
        a print attempt, else the remaining wait in seconds.
        """
        now = time.monotonic()
        wait = self.next_can_print_time - now
        if wait <= 0:
            self.next_can_print_time = now + self.min_progress_delay
            self.is_worker_running = False
            self._print_latest_if_unprinted()
        return max(wait, 0)

    def _print_worker(self):
        """Background loop: sleep (without holding the lock) until a print succeeds."""
        while True:
            with self.lock:
                wait = self._try_print_else_delay()
            if wait == 0:
                break
            time.sleep(wait)
@@ -0,0 +1,66 @@
1
+ import time
2
+
3
+ from sinter._decoding import Sampler, CompiledSampler
4
+ from sinter._data import Task, AnonTaskStats
5
+
6
+
7
class RampThrottledSampler(Sampler):
    """Wraps a sampler to adjust requested shots to hit a target time.

    This sampler will initially only take 1 shot per call. If the time taken
    significantly undershoots the target time, the maximum number of shots per
    call is increased by a constant factor. If it exceeds the target time, the
    maximum is reduced by a constant factor. The result is that the sampler
    "ramps up" how many shots it does per call until it takes roughly the target
    time, and then dynamically adapts to stay near it.

    If the compiled sub-sampler reports `handles_throttling()`, it is
    returned unwrapped.
    """

    def __init__(self, sub_sampler: Sampler, target_batch_seconds: float, max_batch_shots: int):
        """
        Args:
            sub_sampler: The sampler whose compiled samplers get wrapped.
            target_batch_seconds: Desired wall-clock duration of each sample call.
            max_batch_shots: The per-call shot count is never grown past this.
        """
        self.sub_sampler = sub_sampler
        self.target_batch_seconds = target_batch_seconds
        self.max_batch_shots = max_batch_shots

    def __str__(self) -> str:
        # Use this class's own name; 'CompiledRampThrottledSampler' is the
        # compiled wrapper produced by compiled_sampler_for_task.
        return f'RampThrottledSampler({self.sub_sampler})'

    def compiled_sampler_for_task(self, task: Task) -> CompiledSampler:
        compiled_sub_sampler = self.sub_sampler.compiled_sampler_for_task(task)
        if compiled_sub_sampler.handles_throttling():
            return compiled_sub_sampler

        return CompiledRampThrottledSampler(
            sub_sampler=compiled_sub_sampler,
            target_batch_seconds=self.target_batch_seconds,
            max_batch_shots=self.max_batch_shots,
        )
36
+
37
+
38
class CompiledRampThrottledSampler(CompiledSampler):
    """Compiled sampler that adapts its batch size toward a target call duration.

    Each `sample` call requests at most `batch_shots` shots from the wrapped
    compiled sampler, times the call, and then adjusts `batch_shots` for the
    next call.
    """

    def __init__(self, sub_sampler: CompiledSampler, target_batch_seconds: float, max_batch_shots: int):
        self.sub_sampler = sub_sampler
        self.target_batch_seconds = target_batch_seconds
        self.max_batch_shots = max_batch_shots
        # Start with a single shot per call and ramp up from there.
        self.batch_shots = 1

    def __str__(self) -> str:
        return f'CompiledRampThrottledSampler({self.sub_sampler})'

    def sample(self, max_shots: int) -> AnonTaskStats:
        """Sample min(max_shots, batch_shots) shots and return the sub-sampler's stats."""
        started = time.monotonic()
        requested = min(max_shots, self.batch_shots)
        stats = self.sub_sampler.sample(requested)
        elapsed = time.monotonic() - started

        # Well over target: halve the batch (it never drops below 1).
        if self.batch_shots > 1 and elapsed > self.target_batch_seconds * 1.3:
            self.batch_shots //= 2

        # Only grow when the sub-sampler delivered at least half of the request.
        if stats.shots * 2 >= requested:
            # Double up to 4 times while under 30% of target and within the cap,
            # estimating each doubled batch as taking twice as long.
            for _ in range(4):
                if self.batch_shots * 2 > self.max_batch_shots or elapsed > self.target_batch_seconds * 0.3:
                    break
                self.batch_shots *= 2
                elapsed *= 2

        return stats