sinter 1.15.dev1742871434.tar.gz → 1.16.dev1768553279.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (67)
  1. {sinter-1.15.dev1742871434/src/sinter.egg-info → sinter-1.16.dev1768553279}/PKG-INFO +1 -1
  2. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/setup.py +1 -1
  3. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/__init__.py +1 -1
  4. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_collection/_collection_manager.py +5 -1
  5. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_collection/_collection_test.py +46 -3
  6. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_collection/_collection_worker_test.py +9 -1
  7. sinter-1.16.dev1768553279/src/sinter/_collection/_sampler_ramp_throttled_test.py +144 -0
  8. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_data/_anon_task_stats.py +6 -5
  9. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_data/_csv_out.py +3 -0
  10. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_data/_task_stats.py +11 -11
  11. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_data/_task_stats_test.py +1 -1
  12. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_decoding/_decoding_decoder_class.py +20 -3
  13. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_decoding/_decoding_test.py +87 -0
  14. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_decoding/_stim_then_decode_sampler.py +1 -1
  15. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_probability_util.py +2 -2
  16. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279/src/sinter.egg-info}/PKG-INFO +1 -1
  17. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter.egg-info/SOURCES.txt +1 -0
  18. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/MANIFEST.in +0 -0
  19. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/README.md +0 -0
  20. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/readme_example_plot.png +0 -0
  21. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/requirements.txt +0 -0
  22. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/setup.cfg +0 -0
  23. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_collection/__init__.py +0 -0
  24. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_collection/_collection.py +0 -0
  25. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_collection/_collection_manager_test.py +0 -0
  26. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_collection/_collection_worker_loop.py +0 -0
  27. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_collection/_collection_worker_state.py +0 -0
  28. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_collection/_mux_sampler.py +0 -0
  29. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_collection/_printer.py +0 -0
  30. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_collection/_sampler_ramp_throttled.py +0 -0
  31. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_command/__init__.py +0 -0
  32. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_command/_main.py +0 -0
  33. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_command/_main_collect.py +0 -0
  34. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_command/_main_collect_test.py +0 -0
  35. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_command/_main_combine.py +0 -0
  36. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_command/_main_combine_test.py +0 -0
  37. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_command/_main_plot.py +0 -0
  38. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_command/_main_plot_test.py +0 -0
  39. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_command/_main_predict.py +0 -0
  40. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_command/_main_predict_test.py +0 -0
  41. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_data/__init__.py +0 -0
  42. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_data/_anon_task_stats_test.py +0 -0
  43. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_data/_collection_options.py +0 -0
  44. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_data/_collection_options_test.py +0 -0
  45. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_data/_existing_data.py +0 -0
  46. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_data/_existing_data_test.py +0 -0
  47. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_data/_task.py +0 -0
  48. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_data/_task_test.py +0 -0
  49. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_decoding/__init__.py +0 -0
  50. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_decoding/_decoding.py +0 -0
  51. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_decoding/_decoding_all_built_in_decoders.py +0 -0
  52. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_decoding/_decoding_fusion_blossom.py +0 -0
  53. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_decoding/_decoding_mwpf.py +0 -0
  54. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_decoding/_decoding_pymatching.py +0 -0
  55. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_decoding/_decoding_vacuous.py +0 -0
  56. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_decoding/_perfectionist_sampler.py +0 -0
  57. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_decoding/_sampler.py +0 -0
  58. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_decoding/_stim_then_decode_sampler_test.py +0 -0
  59. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_plotting.py +0 -0
  60. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_plotting_test.py +0 -0
  61. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_predict.py +0 -0
  62. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_predict_test.py +0 -0
  63. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter/_probability_util_test.py +0 -0
  64. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter.egg-info/dependency_links.txt +0 -0
  65. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter.egg-info/entry_points.txt +0 -0
  66. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter.egg-info/requires.txt +0 -0
  67. {sinter-1.15.dev1742871434 → sinter-1.16.dev1768553279}/src/sinter.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sinter
- Version: 1.15.dev1742871434
+ Version: 1.16.dev1768553279
  Summary: Samples stim circuits and decodes them using pymatching.
  Author: Craig Gidney
  Author-email: craig.gidney@gmail.com
@@ -19,7 +19,7 @@ with open('README.md', encoding='UTF-8') as f:
  with open('requirements.txt', encoding='UTF-8') as f:
      requirements = f.read().splitlines()

- __version__ = '1.15.dev1742871434'
+ __version__ = '1.16.dev1768553279'

  setup(
      name='sinter',
@@ -1,4 +1,4 @@
- __version__ = '1.15.dev1742871434'
+ __version__ = '1.16.dev1768553279'

  from sinter._collection import (
      collect,
@@ -129,6 +129,10 @@ class CollectionManager:
      def start_workers(self, *, actually_start_worker_processes: bool = True):
          assert not self.started

+         # Use max_batch_size from collection_options if provided, otherwise default to 1024 as large
+         # batch sizes can lead to thrashing
+         max_batch_shots = self.collection_options.max_batch_size or 1024
+
          sampler = RampThrottledSampler(
              sub_sampler=MuxSampler(
                  custom_decoders=self.custom_decoders,
@@ -137,7 +141,7 @@ class CollectionManager:
                  tmp_dir=self.tmp_dir,
              ),
              target_batch_seconds=1,
-             max_batch_shots=1024,
+             max_batch_shots=max_batch_shots,
          )

          self.started = True
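
The change above routes the collection option `max_batch_size` into the worker's `RampThrottledSampler` in place of the hard-coded 1024-shot cap. A hypothetical usage sketch (assuming the standard `sinter.collect` entry point; not part of this diff):

import sinter
import stim

# Hypothetical example: cap each sampling batch at 256 shots while collecting.
results = sinter.collect(
    num_workers=2,
    tasks=[
        sinter.Task(
            circuit=stim.Circuit.generated(
                'repetition_code:memory',
                distance=3,
                rounds=9,
                before_round_data_depolarization=0.01,
            ),
            decoder='pymatching',
        ),
    ],
    max_shots=10_000,
    max_batch_size=256,  # Forwarded to the ramp-throttled sampler as max_batch_shots.
)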
@@ -1,14 +1,13 @@
  import collections
- import math
+ import multiprocessing
  import pathlib
- import sys
  import tempfile
  import time

  import pytest
- import stim

  import sinter
+ import stim


  def test_iter_collect():
@@ -272,3 +271,47 @@ def test_mock_timing_sampler():
          custom_decoders={'MockTimingSampler': MockTimingSampler()},
      )
      assert 1_000_000 <= results[0].shots <= 1_000_000 + 12000
+
+ class BatchSizeTrackingSampler(sinter.Sampler, sinter.CompiledSampler):
+     """A sampler that tracks the suggested batch size requests it receives."""
+
+     def __init__(self, batch_sizes: list[int]):
+         self.batch_sizes = batch_sizes
+
+     def compiled_sampler_for_task(self, task: sinter.Task) -> sinter.CompiledSampler:
+         return self
+
+     def sample(self, suggested_shots: int) -> sinter.AnonTaskStats:
+         self.batch_sizes.append(suggested_shots)
+         return sinter.AnonTaskStats(
+             shots=suggested_shots,
+             errors=1,
+             seconds=0.001,
+         )
+
+
+ def test_ramp_throttled_sampler_respects_max_batch_size():
+     """Test that the CollectionManager instantiated RampThrottledSampler respects the `max_batch_size`
+     parameter."""
+
+     # since the RampThrottledSampler and batch sizing happens in the worker process, we need a
+     # shared list to track what goes on with the sampler
+     with multiprocessing.Manager() as manager:
+         tracking_sampler = BatchSizeTrackingSampler(manager.list())
+
+         sinter.collect(
+             num_workers=1,
+             tasks=[
+                 sinter.Task(
+                     circuit=stim.Circuit(),
+                     decoder='tracking_sampler',
+                     json_metadata={'test': 'small_batch'},
+                 )
+             ],
+             max_shots=10_000,
+             max_batch_size=128,  # Set a small max batch size
+             custom_decoders={'tracking_sampler': tracking_sampler},
+         )
+         # batch size should start at one and then maximum seen should be at most 128
+         assert tracking_sampler.batch_sizes[0] == 1
+         assert 1 < max(tracking_sampler.batch_sizes) <= 128
@@ -163,7 +163,15 @@ def test_worker_finish_work():
      handler.expected_task = ta
      _put_wait_not_empty(inp, ('change_job', (ta, sinter.CollectionOptions(max_errors=100_000_000), 100_000_000)))
      _put_wait_not_empty(inp, ('accept_shots', (ta.strong_id(), 10000)))
-     assert worker.process_messages() == 2
+     t0 = time.monotonic()
+     num_processed = 0
+     while True:
+         num_processed += worker.process_messages()
+         if num_processed >= 2:
+             break
+         if time.monotonic() - t0 > 1:
+             raise ValueError("Messages not processed")
+     assert num_processed == 2
      _assert_drain_queue(out, [
          ('changed_job', 5, (ta.strong_id(),)),
          ('accepted_shots', 5, (ta.strong_id(), 10000)),
@@ -0,0 +1,144 @@
+ from unittest import mock
+
+ import pytest
+
+ import sinter
+ import stim
+ from sinter._collection._sampler_ramp_throttled import (
+     CompiledRampThrottledSampler,
+     RampThrottledSampler,
+ )
+ from sinter._data import AnonTaskStats, Task
+
+
+ class MockSampler(sinter.Sampler, sinter.CompiledSampler):
+     """Mock sampler that tracks `suggested_shots` parameter in `sample` calls."""
+
+     def __init__(self):
+         self.calls = []
+
+     def compiled_sampler_for_task(self, task: Task) -> sinter.CompiledSampler:
+         return self
+
+     def sample(self, suggested_shots: int) -> AnonTaskStats:
+         self.calls.append(suggested_shots)
+         return AnonTaskStats(
+             shots=suggested_shots,
+             errors=1,
+             seconds=0.001 * suggested_shots,  # Simulate time proportional to shots
+         )
+
+
+ @pytest.fixture
+ def mock_sampler():
+     return MockSampler()
+
+
+ def test_initial_batch_size(mock_sampler):
+     """Test that the sampler starts with a batch size of 1."""
+     sampler = CompiledRampThrottledSampler(
+         sub_sampler=mock_sampler,
+         target_batch_seconds=1.0,
+         max_batch_shots=1024,
+     )
+
+     # First call should use batch_size=1
+     sampler.sample(100)
+     assert mock_sampler.calls[0] == 1
+
+
+ def test_batch_size_ramps_up(mock_sampler):
+     """Test that the batch size increases when execution is fast."""
+     sampler = CompiledRampThrottledSampler(
+         sub_sampler=mock_sampler,
+         target_batch_seconds=1.0,
+         max_batch_shots=1024,
+     )
+
+     # Mock time.monotonic to simulate fast execution
+     # two calls per sample for tic/toc
+     with mock.patch(
+         "time.monotonic", side_effect=[0.0, 0.001, 0.02, 0.021, 0.03, 0.031]
+     ):
+         sampler.sample(100)  # First call, batch_size=1
+         sampler.sample(100)  # Should double 4 times to 16
+         sampler.sample(100)  # Should double 4 times again but hit limit of 100
+
+     assert mock_sampler.calls == [1, 16, 100]
+
+
+ def test_batch_size_decreases(mock_sampler):
+     """Test that the batch size decreases when execution is slow."""
+     sampler = CompiledRampThrottledSampler(
+         sub_sampler=mock_sampler,
+         target_batch_seconds=0.1,
+         max_batch_shots=1024,
+     )
+
+     # Set initial batch size higher for this test
+     sampler.batch_shots = 64
+
+     # Mock time.monotonic to simulate slow execution (>1.3x target)
+     with mock.patch("time.monotonic", side_effect=[0.0, 0.15, 0.5, 0.65]):
+         sampler.sample(100)  # First call, batch_size=64
+         sampler.sample(100)  # Should halve to 32
+
+     assert mock_sampler.calls == [64, 32]
+
+
+ def test_respects_max_batch_shots(mock_sampler):
+     """Test that the batch size never exceeds max_batch_shots."""
+     sampler = CompiledRampThrottledSampler(
+         sub_sampler=mock_sampler,
+         target_batch_seconds=1.0,
+         max_batch_shots=16,  # Small max for testing
+     )
+
+     # Set initial batch size close to max
+     sampler.batch_shots = 8
+
+     # Mock time.monotonic to simulate very fast execution
+     # two calls per sample for tic/toc
+     with mock.patch(
+         "time.monotonic", side_effect=[0.0, 0.001, 0.02, 0.021, 0.03, 0.031]
+     ):
+         sampler.sample(100)  # First call, batch_size=8
+         sampler.sample(100)  # Should double to 16
+         sampler.sample(100)  # Should stay at 16 (max)
+
+     assert mock_sampler.calls == [8, 16, 16]
+
+
+ def test_respects_max_shots_parameter(mock_sampler):
+     """Test that the sampler respects the max_shots parameter."""
+     sampler = CompiledRampThrottledSampler(
+         sub_sampler=mock_sampler,
+         target_batch_seconds=1.0,
+         max_batch_shots=1024,
+     )
+
+     # Set batch size higher than max_shots
+     sampler.batch_shots = 100
+
+     # Call with max_shots=10
+     sampler.sample(10)
+
+     # Should only request 10 shots, not 100
+     assert mock_sampler.calls[0] == 10
+
+
+ def test_sub_sampler_parameter_pass_through(mock_sampler):
+     """Test that parameters are passed through to compiled sub sampler."""
+     factory = RampThrottledSampler(
+         sub_sampler=mock_sampler,
+         target_batch_seconds=0.5,
+         max_batch_shots=512,
+     )
+
+     task = Task(circuit=stim.Circuit(), decoder="test")
+     compiled = factory.compiled_sampler_for_task(task)
+
+     assert isinstance(compiled, CompiledRampThrottledSampler)
+     assert compiled.target_batch_seconds == 0.5
+     assert compiled.max_batch_shots == 512
+     assert compiled.batch_shots == 1  # Initial batch size
@@ -1,6 +1,7 @@
  import collections
  import dataclasses
  from typing import Counter, Union, TYPE_CHECKING
+ import numpy as np

  if TYPE_CHECKING:
      from sinter._data._task_stats import TaskStats
@@ -34,16 +35,16 @@ class AnonTaskStats:
      custom_counts: Counter[str] = dataclasses.field(default_factory=collections.Counter)

      def __post_init__(self):
-         assert isinstance(self.errors, int)
-         assert isinstance(self.shots, int)
-         assert isinstance(self.discards, int)
-         assert isinstance(self.seconds, (int, float))
+         assert isinstance(self.errors, (int, np.integer))
+         assert isinstance(self.shots, (int, np.integer))
+         assert isinstance(self.discards, (int, np.integer))
+         assert isinstance(self.seconds, (int, float, np.integer, np.floating))
          assert isinstance(self.custom_counts, collections.Counter)
          assert self.errors >= 0
          assert self.discards >= 0
          assert self.seconds >= 0
          assert self.shots >= self.errors + self.discards
-         assert all(isinstance(k, str) and isinstance(v, int) for k, v in self.custom_counts.items())
+         assert all(isinstance(k, str) and isinstance(v, (int, np.integer)) for k, v in self.custom_counts.items())

      def __repr__(self) -> str:
          terms = []
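
With the relaxed asserts above, statistics built from numpy scalar values pass validation. A minimal sketch (assumes numpy; not part of this diff):

import collections
import numpy as np
import sinter

# numpy integer/floating scalars are now accepted alongside plain Python ints/floats.
stats = sinter.AnonTaskStats(
    shots=np.int64(1000),
    errors=np.int64(3),
    discards=np.int64(0),
    seconds=np.float32(0.25),
    custom_counts=collections.Counter({'detection_events': np.int64(41)}),
)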
@@ -36,6 +36,9 @@ def csv_line(*,
              separators=(',', ':'),
              sort_keys=True)
      if custom_counts:
+         # Ensure all custom_counts values are integers before dumping to JSON
+         for k in custom_counts:
+             custom_counts[k] = int(custom_counts[k])
          custom_counts = escape_csv(
              json.dumps(custom_counts,
                  separators=(',', ':'),
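
Background for the coercion above (illustrative, not part of this diff): `json.dumps` rejects numpy integer scalars, so counts that arrive as numpy types must be cast to plain `int` first.

import json
import numpy as np

counts = {'detection_events': np.int64(7)}
# json.dumps(counts)  # would raise TypeError: Object of type int64 is not JSON serializable
counts = {k: int(v) for k, v in counts.items()}
json.dumps(counts, separators=(',', ':'), sort_keys=True)  # '{"detection_events":7}'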
@@ -4,7 +4,7 @@ from typing import Counter, List, Any
  from typing import Optional
  from typing import Union
  from typing import overload
-
+ import numpy as np
  from sinter._data._anon_task_stats import AnonTaskStats
  from sinter._data._csv_out import csv_line

@@ -71,10 +71,10 @@ class TaskStats:
      custom_counts: Counter[str] = dataclasses.field(default_factory=collections.Counter)

      def __post_init__(self):
-         assert isinstance(self.errors, int)
-         assert isinstance(self.shots, int)
-         assert isinstance(self.discards, int)
-         assert isinstance(self.seconds, (int, float))
+         assert isinstance(self.errors, (int, np.integer))
+         assert isinstance(self.shots, (int, np.integer))
+         assert isinstance(self.discards, (int, np.integer))
+         assert isinstance(self.seconds, (int, float, np.integer, np.floating))
          assert isinstance(self.custom_counts, collections.Counter)
          assert isinstance(self.decoder, str)
          assert isinstance(self.strong_id, str)
@@ -83,7 +83,7 @@ class TaskStats:
          assert self.discards >= 0
          assert self.seconds >= 0
          assert self.shots >= self.errors + self.discards
-         assert all(isinstance(k, str) and isinstance(v, int) for k, v in self.custom_counts.items())
+         assert all(isinstance(k, str) and isinstance(v, (int, np.integer)) for k, v in self.custom_counts.items())

      def with_edits(
          self,
@@ -190,13 +190,13 @@ class TaskStats:
          >>> print(sinter.CSV_HEADER)
          shots, errors, discards, seconds,decoder,strong_id,json_metadata,custom_counts
          >>> print(stat.to_csv_line())
-         22, 3, 0, 5,pymatching,test,"{""a"":[1,2,3]}",
+         22, 3, 0, 5.00,pymatching,test,"{""a"":[1,2,3]}",
          """
          return csv_line(
-             shots=self.shots,
-             errors=self.errors,
-             seconds=self.seconds,
-             discards=self.discards,
+             shots=int(self.shots),
+             errors=int(self.errors),
+             seconds=float(self.seconds),
+             discards=int(self.discards),
              strong_id=self.strong_id,
              decoder=self.decoder,
              json_metadata=self.json_metadata,
@@ -29,7 +29,7 @@ def test_to_csv_line():
          discards=4,
          seconds=5,
      )
-     assert v.to_csv_line() == str(v) == ' 22, 3, 4, 5,pymatching,test,"{""a"":[1,2,3]}",'
+     assert v.to_csv_line() == str(v) == ' 22, 3, 4, 5.00,pymatching,test,"{""a"":[1,2,3]}",'


  def test_to_anon_stats():
@@ -55,7 +55,7 @@ class CompiledDecoder(metaclass=abc.ABCMeta):
          pass


- class Decoder(metaclass=abc.ABCMeta):
+ class Decoder:
      """Abstract base class for custom decoders.

      Custom decoders can be explained to sinter by inheriting from this class and
@@ -63,6 +63,10 @@ class Decoder(metaclass=abc.ABCMeta):

      Decoder classes MUST be serializable (e.g. via pickling), so that they can
      be given to worker processes when using python multiprocessing.
+
+     Child classes should implement `compile_decoder_for_dem`, but (for legacy
+     reasons) can alternatively implement `decode_via_files`. At least one of
+     the two methods must be implemented.
      """

      def compile_decoder_for_dem(
@@ -95,7 +99,6 @@ class Decoder(metaclass=abc.ABCMeta):
          """
          raise NotImplementedError('compile_decoder_for_dem')

-     @abc.abstractmethod
      def decode_via_files(self,
                           *,
                           num_shots: int,
@@ -141,4 +144,18 @@ class Decoder(metaclass=abc.ABCMeta):
              temporary objects. All cleanup should be done via sinter
              deleting this directory after killing the decoder.
          """
-         pass
+         dem = stim.DetectorErrorModel.from_file(dem_path)
+
+         try:
+             compiled = self.compile_decoder_for_dem(dem=dem)
+         except NotImplementedError as ex:
+             raise NotImplementedError(f"{type(self).__qualname__} didn't implement `compile_decoder_for_dem` or `decode_via_files`.") from ex
+
+         num_det_bytes = -(-num_dets // 8)
+         num_obs_bytes = -(-num_obs // 8)
+         dets = np.fromfile(dets_b8_in_path, dtype=np.uint8, count=num_shots * num_det_bytes)
+         dets = dets.reshape(num_shots, num_det_bytes)
+         obs = compiled.decode_shots_bit_packed(bit_packed_detection_event_data=dets)
+         if obs.dtype != np.uint8 or obs.shape != (num_shots, num_obs_bytes):
+             raise ValueError(f"Got a numpy array with dtype={obs.dtype},shape={obs.shape} instead of dtype={np.uint8},shape={(num_shots, num_obs_bytes)} from {type(self).__qualname__}(...).compile_decoder_for_dem(...).decode_shots_bit_packed(...).")
+         obs.tofile(obs_predictions_b8_out_path)
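
The default implementation above lets a decoder that only defines `compile_decoder_for_dem` be driven through the file-based interface. A minimal sketch modeled on the `TrivialDecoder` used in the new tests (hypothetical class names; not part of this diff):

import numpy as np
import sinter
import stim


class ZeroCompiledDecoder(sinter.CompiledDecoder):
    def __init__(self, num_obs: int):
        self.num_obs_bytes = -(-num_obs // 8)  # Bit-packed observable bytes per shot.

    def decode_shots_bit_packed(self, *, bit_packed_detection_event_data: np.ndarray) -> np.ndarray:
        # Predict "no observable flipped" for every shot.
        num_shots = bit_packed_detection_event_data.shape[0]
        return np.zeros(shape=(num_shots, self.num_obs_bytes), dtype=np.uint8)


class ZeroDecoder(sinter.Decoder):
    # decode_via_files is inherited; the base class now falls back to compile_decoder_for_dem.
    def compile_decoder_for_dem(self, *, dem: stim.DetectorErrorModel) -> sinter.CompiledDecoder:
        return ZeroCompiledDecoder(num_obs=dem.num_observables)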
@@ -10,6 +10,7 @@ import pytest
  import sinter
  import stim

+ from sinter import CompiledDecoder
  from sinter._collection import post_selection_mask_from_4th_coord
  from sinter._decoding._decoding_all_built_in_decoders import BUILT_IN_DECODERS
  from sinter._decoding._decoding import sample_decode
@@ -391,3 +392,89 @@ def test_full_scale(decoder: str):
      assert result.discards == 0
      assert result.shots == 1000
      assert result.errors == 0
+
+
+ def test_infer_decode_via_files_from_decode_from_compile_decoder_for_dem():
+     class IncompleteDecoder(sinter.Decoder):
+         pass
+
+     class WrongDecoder(sinter.Decoder, sinter.CompiledDecoder):
+         def compile_decoder_for_dem(
+             self,
+             *,
+             dem: stim.DetectorErrorModel,
+         ) -> CompiledDecoder:
+             return self
+         def decode_shots_bit_packed(
+             self,
+             *,
+             bit_packed_detection_event_data: np.ndarray,
+         ) -> np.ndarray:
+             return np.zeros(shape=5, dtype=np.bool_)
+
+     class TrivialCompiledDecoder(sinter.CompiledDecoder):
+         def __init__(self, num_obs: int):
+             self.num_obs = -(-num_obs // 8)
+
+         def decode_shots_bit_packed(
+             self,
+             *,
+             bit_packed_detection_event_data: np.ndarray,
+         ) -> np.ndarray:
+             return np.zeros(dtype=np.uint8, shape=(bit_packed_detection_event_data.shape[0], self.num_obs))
+
+     class TrivialDecoder(sinter.Decoder):
+         def compile_decoder_for_dem(
+             self,
+             *,
+             dem: stim.DetectorErrorModel,
+         ) -> CompiledDecoder:
+             return TrivialCompiledDecoder(num_obs=dem.num_observables)
+
+     circuit = stim.Circuit.generated("repetition_code:memory", distance=3, rounds=3)
+     dem = circuit.detector_error_model()
+
+     with tempfile.TemporaryDirectory() as d:
+         d = pathlib.Path(d)
+
+         circuit.compile_detector_sampler().sample_write(
+             shots=10,
+             filepath=d / 'dets.b8',
+             format='b8',
+         )
+
+         dem.to_file(d / 'dem.dem')
+
+         with pytest.raises(NotImplementedError, match='compile_decoder_for_dem'):
+             IncompleteDecoder().decode_via_files(
+                 num_shots=10,
+                 num_dets=dem.num_detectors,
+                 num_obs=dem.num_observables,
+                 dem_path=d / 'dem.dem',
+                 dets_b8_in_path=d / 'dets.b8',
+                 obs_predictions_b8_out_path=d / 'obs.b8',
+                 tmp_dir=d,
+             )
+
+         with pytest.raises(ValueError, match='shape='):
+             WrongDecoder().decode_via_files(
+                 num_shots=10,
+                 num_dets=dem.num_detectors,
+                 num_obs=dem.num_observables,
+                 dem_path=d / 'dem.dem',
+                 dets_b8_in_path=d / 'dets.b8',
+                 obs_predictions_b8_out_path=d / 'obs.b8',
+                 tmp_dir=d,
+             )
+
+         TrivialDecoder().decode_via_files(
+             num_shots=10,
+             num_dets=dem.num_detectors,
+             num_obs=dem.num_observables,
+             dem_path=d / 'dem.dem',
+             dets_b8_in_path=d / 'dets.b8',
+             obs_predictions_b8_out_path=d / 'obs.b8',
+             tmp_dir=d,
+         )
+         obs = np.fromfile(d / 'obs.b8', dtype=np.uint8, count=10)
+         np.testing.assert_array_equal(obs, [0] * 10)
@@ -140,7 +140,7 @@ def _compile_decoder_with_disk_fallback(
  ) -> CompiledDecoder:
      try:
          return decoder.compile_decoder_for_dem(dem=task.detector_error_model)
-     except (NotImplementedError, ValueError):
+     except NotImplementedError:
          pass
      if tmp_dir is None:
          raise ValueError(f"Decoder {task.decoder=} didn't implement `compile_decoder_for_dem`, but no temporary directory was provided for falling back to `decode_via_files`.")
@@ -51,7 +51,7 @@ def log_binomial(*, p: Union[float, np.ndarray], n: int, hits: int) -> np.ndarra
      Examples:
          >>> import sinter
          >>> sinter.log_binomial(p=0.5, n=100, hits=50)
-         array(-2.5308785, dtype=float32)
+         array(-2.5308762, dtype=float32)
          >>> sinter.log_binomial(p=0.2, n=1_000_000, hits=1_000)
          array(-216626.97, dtype=float32)
          >>> sinter.log_binomial(p=0.1, n=1_000_000, hits=1_000)
@@ -321,7 +321,7 @@ def fit_line_slope(*,

      low_slope = binary_intercept(start_x=fit.slope, step=-1, target_y=base_cost + max_extra_squared_error, func=cost_for_slope, atol=1e-5)
      high_slope = binary_intercept(start_x=fit.slope, step=1, target_y=base_cost + max_extra_squared_error, func=cost_for_slope, atol=1e-5)
-     return Fit(low=low_slope, best=fit.slope, high=high_slope)
+     return Fit(low=float(low_slope), best=float(fit.slope), high=float(high_slope))


  def fit_binomial(
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sinter
- Version: 1.15.dev1742871434
+ Version: 1.16.dev1768553279
  Summary: Samples stim circuits and decodes them using pymatching.
  Author: Craig Gidney
  Author-email: craig.gidney@gmail.com
@@ -27,6 +27,7 @@ src/sinter/_collection/_collection_worker_test.py
  src/sinter/_collection/_mux_sampler.py
  src/sinter/_collection/_printer.py
  src/sinter/_collection/_sampler_ramp_throttled.py
+ src/sinter/_collection/_sampler_ramp_throttled_test.py
  src/sinter/_command/__init__.py
  src/sinter/_command/_main.py
  src/sinter/_command/_main_collect.py