sinter 1.14.dev1727164676__tar.gz → 1.15.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (69)
  1. {sinter-1.14.dev1727164676/src/sinter.egg-info → sinter-1.15.0}/PKG-INFO +13 -5
  2. sinter-1.15.0/requirements.txt +4 -0
  3. {sinter-1.14.dev1727164676 → sinter-1.15.0}/setup.py +1 -1
  4. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/__init__.py +1 -1
  5. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_collection/_collection_manager.py +5 -1
  6. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_collection/_collection_test.py +46 -3
  7. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_collection/_collection_worker_test.py +9 -1
  8. sinter-1.15.0/src/sinter/_collection/_sampler_ramp_throttled_test.py +144 -0
  9. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_decoding_decoder_class.py +20 -3
  10. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_decoding_fusion_blossom.py +5 -2
  11. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_decoding_mwpf.py +56 -63
  12. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_decoding_test.py +87 -0
  13. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_stim_then_decode_sampler.py +1 -1
  14. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_probability_util.py +2 -2
  15. {sinter-1.14.dev1727164676 → sinter-1.15.0/src/sinter.egg-info}/PKG-INFO +13 -5
  16. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter.egg-info/SOURCES.txt +1 -0
  17. sinter-1.15.0/src/sinter.egg-info/requires.txt +4 -0
  18. sinter-1.14.dev1727164676/requirements.txt +0 -4
  19. sinter-1.14.dev1727164676/src/sinter.egg-info/requires.txt +0 -4
  20. {sinter-1.14.dev1727164676 → sinter-1.15.0}/MANIFEST.in +0 -0
  21. {sinter-1.14.dev1727164676 → sinter-1.15.0}/README.md +0 -0
  22. {sinter-1.14.dev1727164676 → sinter-1.15.0}/readme_example_plot.png +0 -0
  23. {sinter-1.14.dev1727164676 → sinter-1.15.0}/setup.cfg +0 -0
  24. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_collection/__init__.py +0 -0
  25. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_collection/_collection.py +0 -0
  26. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_collection/_collection_manager_test.py +0 -0
  27. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_collection/_collection_worker_loop.py +0 -0
  28. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_collection/_collection_worker_state.py +0 -0
  29. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_collection/_mux_sampler.py +0 -0
  30. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_collection/_printer.py +0 -0
  31. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_collection/_sampler_ramp_throttled.py +0 -0
  32. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_command/__init__.py +0 -0
  33. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_command/_main.py +0 -0
  34. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_command/_main_collect.py +0 -0
  35. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_command/_main_collect_test.py +0 -0
  36. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_command/_main_combine.py +0 -0
  37. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_command/_main_combine_test.py +0 -0
  38. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_command/_main_plot.py +0 -0
  39. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_command/_main_plot_test.py +0 -0
  40. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_command/_main_predict.py +0 -0
  41. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_command/_main_predict_test.py +0 -0
  42. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_data/__init__.py +0 -0
  43. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_data/_anon_task_stats.py +0 -0
  44. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_data/_anon_task_stats_test.py +0 -0
  45. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_data/_collection_options.py +0 -0
  46. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_data/_collection_options_test.py +0 -0
  47. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_data/_csv_out.py +0 -0
  48. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_data/_existing_data.py +0 -0
  49. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_data/_existing_data_test.py +0 -0
  50. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_data/_task.py +0 -0
  51. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_data/_task_stats.py +0 -0
  52. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_data/_task_stats_test.py +0 -0
  53. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_data/_task_test.py +0 -0
  54. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/__init__.py +0 -0
  55. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_decoding.py +0 -0
  56. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_decoding_all_built_in_decoders.py +0 -0
  57. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_decoding_pymatching.py +0 -0
  58. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_decoding_vacuous.py +0 -0
  59. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_perfectionist_sampler.py +0 -0
  60. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_sampler.py +0 -0
  61. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_stim_then_decode_sampler_test.py +0 -0
  62. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_plotting.py +0 -0
  63. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_plotting_test.py +0 -0
  64. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_predict.py +0 -0
  65. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_predict_test.py +0 -0
  66. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_probability_util_test.py +0 -0
  67. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter.egg-info/dependency_links.txt +0 -0
  68. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter.egg-info/entry_points.txt +0 -0
  69. {sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter.egg-info/top_level.txt +0 -0
{sinter-1.14.dev1727164676/src/sinter.egg-info → sinter-1.15.0}/PKG-INFO
@@ -1,16 +1,24 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.4
  Name: sinter
- Version: 1.14.dev1727164676
+ Version: 1.15.0
  Summary: Samples stim circuits and decodes them using pymatching.
  Author: Craig Gidney
  Author-email: craig.gidney@gmail.com
  License: Apache 2
  Requires-Python: >=3.7.0
  Description-Content-Type: text/markdown
- Requires-Dist: matplotlib~=3.5
- Requires-Dist: numpy~=1.22
+ Requires-Dist: matplotlib
+ Requires-Dist: numpy
  Requires-Dist: stim
- Requires-Dist: scipy~=1.9
+ Requires-Dist: scipy
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: license
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary

  # sinter: fast QEC sampling

sinter-1.15.0/requirements.txt
@@ -0,0 +1,4 @@
+ matplotlib
+ numpy
+ stim
+ scipy
{sinter-1.14.dev1727164676 → sinter-1.15.0}/setup.py
@@ -19,7 +19,7 @@ with open('README.md', encoding='UTF-8') as f:
  with open('requirements.txt', encoding='UTF-8') as f:
      requirements = f.read().splitlines()

- __version__ = '1.14.dev1727164676'
+ __version__ = '1.15.0'

  setup(
      name='sinter',
{sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/__init__.py
@@ -1,4 +1,4 @@
- __version__ = '1.14.dev1727164676'
+ __version__ = '1.15.0'

  from sinter._collection import (
      collect,
{sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_collection/_collection_manager.py
@@ -129,6 +129,10 @@ class CollectionManager:
      def start_workers(self, *, actually_start_worker_processes: bool = True):
          assert not self.started

+         # Use max_batch_size from collection_options if provided; otherwise default
+         # to 1024, since large batch sizes can lead to thrashing.
+         max_batch_shots = self.collection_options.max_batch_size or 1024
+
          sampler = RampThrottledSampler(
              sub_sampler=MuxSampler(
                  custom_decoders=self.custom_decoders,
@@ -137,7 +141,7 @@ class CollectionManager:
                  tmp_dir=self.tmp_dir,
              ),
              target_batch_seconds=1,
-             max_batch_shots=1024,
+             max_batch_shots=max_batch_shots,
          )

          self.started = True
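For context, `max_batch_size` is the `sinter.collect` option that feeds this default; the new test in the next file exercises it end to end. A minimal usage sketch (the worker count, circuit, and cap of 256 are illustrative values, not taken from this release):

    import sinter
    import stim

    # Cap each worker's batch at 256 shots; the ramp still starts at 1 shot
    # and grows while batches finish well under target_batch_seconds.
    stats = sinter.collect(
        num_workers=4,
        decoders=['pymatching'],
        tasks=[
            sinter.Task(
                circuit=stim.Circuit.generated(
                    'repetition_code:memory',
                    distance=5,
                    rounds=5,
                    after_clifford_depolarization=0.01,
                ),
                json_metadata={'d': 5},
            ),
        ],
        max_shots=100_000,
        max_batch_size=256,
    )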
{sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_collection/_collection_test.py
@@ -1,14 +1,13 @@
  import collections
- import math
+ import multiprocessing
  import pathlib
- import sys
  import tempfile
  import time

  import pytest
- import stim

  import sinter
+ import stim


  def test_iter_collect():
@@ -272,3 +271,47 @@ def test_mock_timing_sampler():
          custom_decoders={'MockTimingSampler': MockTimingSampler()},
      )
      assert 1_000_000 <= results[0].shots <= 1_000_000 + 12000
+
+
+ class BatchSizeTrackingSampler(sinter.Sampler, sinter.CompiledSampler):
+     """A sampler that records the suggested batch sizes it receives."""
+
+     def __init__(self, batch_sizes: list[int]):
+         self.batch_sizes = batch_sizes
+
+     def compiled_sampler_for_task(self, task: sinter.Task) -> sinter.CompiledSampler:
+         return self
+
+     def sample(self, suggested_shots: int) -> sinter.AnonTaskStats:
+         self.batch_sizes.append(suggested_shots)
+         return sinter.AnonTaskStats(
+             shots=suggested_shots,
+             errors=1,
+             seconds=0.001,
+         )
+
+
+ def test_ramp_throttled_sampler_respects_max_batch_size():
+     """Tests that the RampThrottledSampler instantiated by CollectionManager
+     respects the `max_batch_size` parameter."""
+
+     # Since the RampThrottledSampler and batch sizing run in the worker process,
+     # a shared list is needed to observe the batch sizes the sampler is given.
+     with multiprocessing.Manager() as manager:
+         tracking_sampler = BatchSizeTrackingSampler(manager.list())
+
+         sinter.collect(
+             num_workers=1,
+             tasks=[
+                 sinter.Task(
+                     circuit=stim.Circuit(),
+                     decoder='tracking_sampler',
+                     json_metadata={'test': 'small_batch'},
+                 )
+             ],
+             max_shots=10_000,
+             max_batch_size=128,  # Set a small max batch size.
+             custom_decoders={'tracking_sampler': tracking_sampler},
+         )
+         # The batch size should start at 1; the maximum seen should be at most 128.
+         assert tracking_sampler.batch_sizes[0] == 1
+         assert 1 < max(tracking_sampler.batch_sizes) <= 128
{sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_collection/_collection_worker_test.py
@@ -163,7 +163,15 @@ def test_worker_finish_work():
      handler.expected_task = ta
      _put_wait_not_empty(inp, ('change_job', (ta, sinter.CollectionOptions(max_errors=100_000_000), 100_000_000)))
      _put_wait_not_empty(inp, ('accept_shots', (ta.strong_id(), 10000)))
-     assert worker.process_messages() == 2
+     t0 = time.monotonic()
+     num_processed = 0
+     while True:
+         num_processed += worker.process_messages()
+         if num_processed >= 2:
+             break
+         if time.monotonic() - t0 > 1:
+             raise ValueError("Messages not processed")
+     assert num_processed == 2
      _assert_drain_queue(out, [
          ('changed_job', 5, (ta.strong_id(),)),
          ('accepted_shots', 5, (ta.strong_id(), 10000)),
sinter-1.15.0/src/sinter/_collection/_sampler_ramp_throttled_test.py
@@ -0,0 +1,144 @@
+ from unittest import mock
+
+ import pytest
+
+ import sinter
+ import stim
+ from sinter._collection._sampler_ramp_throttled import (
+     CompiledRampThrottledSampler,
+     RampThrottledSampler,
+ )
+ from sinter._data import AnonTaskStats, Task
+
+
+ class MockSampler(sinter.Sampler, sinter.CompiledSampler):
+     """Mock sampler that tracks the `suggested_shots` parameter in `sample` calls."""
+
+     def __init__(self):
+         self.calls = []
+
+     def compiled_sampler_for_task(self, task: Task) -> sinter.CompiledSampler:
+         return self
+
+     def sample(self, suggested_shots: int) -> AnonTaskStats:
+         self.calls.append(suggested_shots)
+         return AnonTaskStats(
+             shots=suggested_shots,
+             errors=1,
+             seconds=0.001 * suggested_shots,  # Simulate time proportional to shots.
+         )
+
+
+ @pytest.fixture
+ def mock_sampler():
+     return MockSampler()
+
+
+ def test_initial_batch_size(mock_sampler):
+     """Test that the sampler starts with a batch size of 1."""
+     sampler = CompiledRampThrottledSampler(
+         sub_sampler=mock_sampler,
+         target_batch_seconds=1.0,
+         max_batch_shots=1024,
+     )
+
+     # The first call should use batch_size=1.
+     sampler.sample(100)
+     assert mock_sampler.calls[0] == 1
+
+
+ def test_batch_size_ramps_up(mock_sampler):
+     """Test that the batch size increases when execution is fast."""
+     sampler = CompiledRampThrottledSampler(
+         sub_sampler=mock_sampler,
+         target_batch_seconds=1.0,
+         max_batch_shots=1024,
+     )
+
+     # Mock time.monotonic to simulate fast execution.
+     # time.monotonic is called twice per sample (tic/toc).
+     with mock.patch(
+         "time.monotonic", side_effect=[0.0, 0.001, 0.02, 0.021, 0.03, 0.031]
+     ):
+         sampler.sample(100)  # First call, batch_size=1.
+         sampler.sample(100)  # Should double 4 times to 16.
+         sampler.sample(100)  # Should double 4 times again but hit the limit of 100.
+
+     assert mock_sampler.calls == [1, 16, 100]
+
+
+ def test_batch_size_decreases(mock_sampler):
+     """Test that the batch size decreases when execution is slow."""
+     sampler = CompiledRampThrottledSampler(
+         sub_sampler=mock_sampler,
+         target_batch_seconds=0.1,
+         max_batch_shots=1024,
+     )
+
+     # Set the initial batch size higher for this test.
+     sampler.batch_shots = 64
+
+     # Mock time.monotonic to simulate slow execution (>1.3x target).
+     with mock.patch("time.monotonic", side_effect=[0.0, 0.15, 0.5, 0.65]):
+         sampler.sample(100)  # First call, batch_size=64.
+         sampler.sample(100)  # Should halve to 32.
+
+     assert mock_sampler.calls == [64, 32]
+
+
+ def test_respects_max_batch_shots(mock_sampler):
+     """Test that the batch size never exceeds max_batch_shots."""
+     sampler = CompiledRampThrottledSampler(
+         sub_sampler=mock_sampler,
+         target_batch_seconds=1.0,
+         max_batch_shots=16,  # Small max for testing.
+     )
+
+     # Set the initial batch size close to the max.
+     sampler.batch_shots = 8
+
+     # Mock time.monotonic to simulate very fast execution.
+     # time.monotonic is called twice per sample (tic/toc).
+     with mock.patch(
+         "time.monotonic", side_effect=[0.0, 0.001, 0.02, 0.021, 0.03, 0.031]
+     ):
+         sampler.sample(100)  # First call, batch_size=8.
+         sampler.sample(100)  # Should double to 16.
+         sampler.sample(100)  # Should stay at 16 (the max).
+
+     assert mock_sampler.calls == [8, 16, 16]
+
+
+ def test_respects_max_shots_parameter(mock_sampler):
+     """Test that the sampler respects the max_shots parameter."""
+     sampler = CompiledRampThrottledSampler(
+         sub_sampler=mock_sampler,
+         target_batch_seconds=1.0,
+         max_batch_shots=1024,
+     )
+
+     # Set the batch size higher than max_shots.
+     sampler.batch_shots = 100
+
+     # Call with max_shots=10.
+     sampler.sample(10)
+
+     # Should only request 10 shots, not 100.
+     assert mock_sampler.calls[0] == 10
+
+
+ def test_sub_sampler_parameter_pass_through(mock_sampler):
+     """Test that parameters are passed through to the compiled sub-sampler."""
+     factory = RampThrottledSampler(
+         sub_sampler=mock_sampler,
+         target_batch_seconds=0.5,
+         max_batch_shots=512,
+     )
+
+     task = Task(circuit=stim.Circuit(), decoder="test")
+     compiled = factory.compiled_sampler_for_task(task)
+
+     assert isinstance(compiled, CompiledRampThrottledSampler)
+     assert compiled.target_batch_seconds == 0.5
+     assert compiled.max_batch_shots == 512
+     assert compiled.batch_shots == 1  # Initial batch size.
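Taken together, these tests pin down the throttling policy: each batch is timed with two `time.monotonic()` calls, the batch size starts at 1, doubles up to four times per call while batches finish well under `target_batch_seconds`, halves once a batch runs noticeably over the target (about 1.3x), and is always clamped to both `max_batch_shots` and the caller's suggested shot count. A rough sketch of that update rule, inferred from the tests above rather than taken from `_sampler_ramp_throttled.py`:

    import time

    class ThrottleSketch:
        """Inferred batch-size controller; not the actual sinter implementation."""

        def __init__(self, sub_sampler, target_batch_seconds: float, max_batch_shots: int):
            self.sub_sampler = sub_sampler
            self.target_batch_seconds = target_batch_seconds
            self.max_batch_shots = max_batch_shots
            self.batch_shots = 1  # The ramp starts at a single shot.

        def sample(self, max_shots: int):
            # Clamp to the caller's budget and the configured ceiling.
            shots = min(self.batch_shots, max_shots, self.max_batch_shots)
            t0 = time.monotonic()
            stats = self.sub_sampler.sample(shots)
            dt = time.monotonic() - t0
            if dt > 1.3 * self.target_batch_seconds:
                self.batch_shots = max(1, self.batch_shots // 2)  # Too slow: back off.
            else:
                # Fast enough: double, at most 4 times per call, staying under the cap.
                for _ in range(4):
                    if self.batch_shots >= self.max_batch_shots or dt * 2 > self.target_batch_seconds:
                        break
                    self.batch_shots = min(self.batch_shots * 2, self.max_batch_shots)
                    dt *= 2
            return stats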
{sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_decoding_decoder_class.py
@@ -55,7 +55,7 @@ class CompiledDecoder(metaclass=abc.ABCMeta):
          pass


- class Decoder(metaclass=abc.ABCMeta):
+ class Decoder:
      """Abstract base class for custom decoders.

      Custom decoders can be explained to sinter by inheriting from this class and
@@ -63,6 +63,10 @@ class Decoder(metaclass=abc.ABCMeta):

      Decoder classes MUST be serializable (e.g. via pickling), so that they can
      be given to worker processes when using python multiprocessing.
+
+     Child classes should implement `compile_decoder_for_dem`, but (for legacy
+     reasons) can alternatively implement `decode_via_files`. At least one of
+     the two methods must be implemented.
      """

      def compile_decoder_for_dem(
@@ -95,7 +99,6 @@ class Decoder(metaclass=abc.ABCMeta):
          """
          raise NotImplementedError('compile_decoder_for_dem')

-     @abc.abstractmethod
      def decode_via_files(self,
                           *,
                           num_shots: int,
@@ -141,4 +144,18 @@
              temporary objects. All cleanup should be done via sinter
              deleting this directory after killing the decoder.
          """
-         pass
+         dem = stim.DetectorErrorModel.from_file(dem_path)
+
+         try:
+             compiled = self.compile_decoder_for_dem(dem=dem)
+         except NotImplementedError as ex:
+             raise NotImplementedError(f"{type(self).__qualname__} didn't implement `compile_decoder_for_dem` or `decode_via_files`.") from ex
+
+         num_det_bytes = -(-num_dets // 8)
+         num_obs_bytes = -(-num_obs // 8)
+         dets = np.fromfile(dets_b8_in_path, dtype=np.uint8, count=num_shots * num_det_bytes)
+         dets = dets.reshape(num_shots, num_det_bytes)
+         obs = compiled.decode_shots_bit_packed(bit_packed_detection_event_data=dets)
+         if obs.dtype != np.uint8 or obs.shape != (num_shots, num_obs_bytes):
+             raise ValueError(f"Got a numpy array with dtype={obs.dtype},shape={obs.shape} instead of dtype={np.uint8},shape={(num_shots, num_obs_bytes)} from {type(self).__qualname__}(...).compile_decoder_for_dem(...).decode_shots_bit_packed(...).")
+         obs.tofile(obs_predictions_b8_out_path)
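With this fallback, a custom decoder only has to provide `compile_decoder_for_dem`; the file-based interface is derived automatically. A minimal sketch of such a decoder (the `ZeroDecoder` name and its always-predict-zero behavior are illustrative, mirroring the `TrivialDecoder` test added further down):

    import numpy as np
    import sinter
    import stim

    class ZeroCompiledDecoder(sinter.CompiledDecoder):
        def __init__(self, num_obs: int):
            self.num_obs_bytes = -(-num_obs // 8)  # Ceiling division.

        def decode_shots_bit_packed(
            self, *, bit_packed_detection_event_data: np.ndarray,
        ) -> np.ndarray:
            # Predict "no observable flipped" for every shot; the shape and dtype
            # must match what the decode_via_files fallback checks for.
            num_shots = bit_packed_detection_event_data.shape[0]
            return np.zeros(shape=(num_shots, self.num_obs_bytes), dtype=np.uint8)

    class ZeroDecoder(sinter.Decoder):
        # No decode_via_files override needed: the base class now derives it.
        def compile_decoder_for_dem(
            self, *, dem: stim.DetectorErrorModel,
        ) -> sinter.CompiledDecoder:
            return ZeroCompiledDecoder(num_obs=dem.num_observables)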
{sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_decoding_fusion_blossom.py
@@ -24,14 +24,17 @@ class FusionBlossomCompiledDecoder(CompiledDecoder):
          bit_packed_detection_event_data: 'np.ndarray',
      ) -> 'np.ndarray':
          num_shots = bit_packed_detection_event_data.shape[0]
-         predictions = np.zeros(shape=(num_shots, self.num_obs), dtype=np.uint8)
+         predictions = np.zeros(shape=(num_shots, (self.num_obs + 7) // 8), dtype=np.uint8)
          import fusion_blossom
          for shot in range(num_shots):
              dets_sparse = np.flatnonzero(np.unpackbits(bit_packed_detection_event_data[shot], count=self.num_dets, bitorder='little'))
              syndrome = fusion_blossom.SyndromePattern(syndrome_vertices=dets_sparse)
              self.solver.solve(syndrome)
              prediction = int(np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()]))
-             predictions[shot] = np.packbits(prediction, bitorder='little')
+             predictions[shot] = np.packbits(
+                 np.array(list(np.binary_repr(prediction, width=self.num_obs))[::-1], dtype=np.uint8),
+                 bitorder="little",
+             )
              self.solver.clear()
          return predictions
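This fixes the packing of multi-observable predictions: `predictions` is now sized in bytes rather than observables, and the integer `prediction` mask is expanded into individual bits before packing (previously the bare integer went straight to `np.packbits`, which treats any nonzero element as a single set bit). A standalone sketch of the conversion, with illustrative values:

    import numpy as np

    def int_to_bit_packed(prediction: int, num_obs: int) -> np.ndarray:
        # np.binary_repr(6, width=3) == '110'; reversing the characters puts the
        # least significant bit first, matching bitorder='little'.
        bits = np.array(list(np.binary_repr(prediction, width=num_obs))[::-1], dtype=np.uint8)
        return np.packbits(bits, bitorder='little')

    assert int_to_bit_packed(6, num_obs=3).tolist() == [6]
    assert int_to_bit_packed(2, num_obs=9).tolist() == [2, 0]  # (9 + 7) // 8 == 2 bytes.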
{sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_decoding_mwpf.py
@@ -15,7 +15,7 @@ def mwpf_import_error() -> ImportError:
      return ImportError(
          "The decoder 'MWPF' isn't installed\n"
          "To fix this, install the python package 'MWPF' into your environment.\n"
-         "For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n"
+         "For example, if you are using pip, run `pip install MWPF~=0.1.5`.\n"
      )
@@ -38,7 +38,9 @@ class MwpfCompiledDecoder(CompiledDecoder):
          bit_packed_detection_event_data: "np.ndarray",
      ) -> "np.ndarray":
          num_shots = bit_packed_detection_event_data.shape[0]
-         predictions = np.zeros(shape=(num_shots, self.num_obs), dtype=np.uint8)
+         predictions = np.zeros(
+             shape=(num_shots, (self.num_obs + 7) // 8), dtype=np.uint8
+         )
          import mwpf

          for shot in range(num_shots):
@@ -58,29 +60,48 @@
                  np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()])
              )
              self.solver.clear()
-             predictions[shot] = np.packbits(prediction, bitorder="little")
+             predictions[shot] = np.packbits(
+                 np.array(
+                     list(np.binary_repr(prediction, width=self.num_obs))[::-1],
+                     dtype=np.uint8,
+                 ),
+                 bitorder="little",
+             )
          return predictions


  class MwpfDecoder(Decoder):
      """Use MWPF to predict observables from detection events."""

-     def compile_decoder_for_dem(
+     def __init__(
          self,
-         *,
-         dem: "stim.DetectorErrorModel",
          decoder_cls: Any = None,  # Decoder class used to construct the MWPF decoder.
          # In the Rust implementation, all of these inherit from `SolverSerialPlugins`,
          # just providing different plugins for optimizing the primal and/or dual solutions.
          # For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only
          # grows the clusters until the first valid solution appears; more optimized solvers use
          # one or more plugins to further optimize the solution, which requires longer decoding time.
+         cluster_node_limit: int = 50,  # The maximum number of nodes in a cluster.
+     ):
+         self.decoder_cls = decoder_cls
+         self.cluster_node_limit = cluster_node_limit
+         super().__init__()
+
+     def compile_decoder_for_dem(
+         self,
+         *,
+         dem: "stim.DetectorErrorModel",
      ) -> CompiledDecoder:
          solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
-             dem, decoder_cls=decoder_cls
+             dem,
+             decoder_cls=self.decoder_cls,
+             cluster_node_limit=self.cluster_node_limit,
          )
          return MwpfCompiledDecoder(
-             solver, fault_masks, dem.num_detectors, dem.num_observables
+             solver,
+             fault_masks,
+             dem.num_detectors,
+             dem.num_observables,
          )

      def decode_via_files(
@@ -93,13 +114,14 @@ class MwpfDecoder(Decoder):
          dets_b8_in_path: pathlib.Path,
          obs_predictions_b8_out_path: pathlib.Path,
          tmp_dir: pathlib.Path,
-         decoder_cls: Any = None,
      ) -> None:
          import mwpf

          error_model = stim.DetectorErrorModel.from_file(dem_path)
          solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
-             error_model, decoder_cls=decoder_cls
+             error_model,
+             decoder_cls=self.decoder_cls,
+             cluster_node_limit=self.cluster_node_limit,
          )
          num_det_bytes = math.ceil(num_dets / 8)
          with open(dets_b8_in_path, "rb") as dets_in_f:
@@ -130,44 +152,8 @@


  class HyperUFDecoder(MwpfDecoder):
-     def compile_decoder_for_dem(
-         self, *, dem: "stim.DetectorErrorModel"
-     ) -> CompiledDecoder:
-         try:
-             import mwpf
-         except ImportError as ex:
-             raise mwpf_import_error() from ex
-
-         return super().compile_decoder_for_dem(
-             dem=dem, decoder_cls=mwpf.SolverSerialUnionFind
-         )
-
-     def decode_via_files(
-         self,
-         *,
-         num_shots: int,
-         num_dets: int,
-         num_obs: int,
-         dem_path: pathlib.Path,
-         dets_b8_in_path: pathlib.Path,
-         obs_predictions_b8_out_path: pathlib.Path,
-         tmp_dir: pathlib.Path,
-     ) -> None:
-         try:
-             import mwpf
-         except ImportError as ex:
-             raise mwpf_import_error() from ex
-
-         return super().decode_via_files(
-             num_shots=num_shots,
-             num_dets=num_dets,
-             num_obs=num_obs,
-             dem_path=dem_path,
-             dets_b8_in_path=dets_b8_in_path,
-             obs_predictions_b8_out_path=obs_predictions_b8_out_path,
-             tmp_dir=tmp_dir,
-             decoder_cls=mwpf.SolverSerialUnionFind,
-         )
+     def __init__(self):
+         super().__init__(decoder_cls="SolverSerialUnionFind", cluster_node_limit=0)


  def iter_flatten_model(
@@ -187,16 +173,16 @@ def iter_flatten_model(
              _helper(instruction.body_copy(), instruction.repeat_count)
          elif isinstance(instruction, stim.DemInstruction):
              if instruction.type == "error":
-                 dets: List[int] = []
-                 frames: List[int] = []
+                 dets: set[int] = set()
+                 frames: set[int] = set()
                  t: stim.DemTarget
                  p = instruction.args_copy()[0]
                  for t in instruction.targets_copy():
                      if t.is_relative_detector_id():
-                         dets.append(t.val + det_offset)
+                         dets ^= {t.val + det_offset}
                      elif t.is_logical_observable_id():
-                         frames.append(t.val)
-                 handle_error(p, dets, frames)
+                         frames ^= {t.val}
+                 handle_error(p, list(dets), list(frames))
              elif instruction.type == "shift_detectors":
                  det_offset += instruction.targets_copy()[0]
                  a = np.array(instruction.args_copy())
@@ -220,26 +206,31 @@
  def deduplicate_hyperedges(
      hyperedges: List[Tuple[List[int], float, int]]
  ) -> List[Tuple[List[int], float, int]]:
-     indices: dict[frozenset[int], int] = dict()
+     indices: dict[frozenset[int], Tuple[int, float]] = dict()
      result: List[Tuple[List[int], float, int]] = []
      for dets, weight, mask in hyperedges:
          dets_set = frozenset(dets)
          if dets_set in indices:
-             idx = indices[dets_set]
+             idx, min_weight = indices[dets_set]
              p1 = 1 / (1 + math.exp(weight))
              p2 = 1 / (1 + math.exp(result[idx][1]))
              p = p1 * (1 - p2) + p2 * (1 - p1)
-             # not sure why would this fail? two hyperedges with different masks?
-             # assert mask == result[idx][2], (result[idx], (dets, weight, mask))
-             result[idx] = (dets, math.log((1 - p) / p), result[idx][2])
+             # Choose the mask from the most likely (lowest-weight) error.
+             new_mask = result[idx][2]
+             if weight < min_weight:
+                 indices[dets_set] = (idx, weight)
+                 new_mask = mask
+             result[idx] = (dets, math.log((1 - p) / p), new_mask)
          else:
-             indices[dets_set] = len(result)
+             indices[dets_set] = (len(result), weight)
              result.append((dets, weight, mask))
      return result


  def detector_error_model_to_mwpf_solver_and_fault_masks(
-     model: stim.DetectorErrorModel, decoder_cls: Any = None
+     model: stim.DetectorErrorModel,
+     decoder_cls: Any = None,
+     cluster_node_limit: int = 50,
  ) -> Tuple[Optional["mwpf.SolverSerialJointSingleHair"], np.ndarray]:
      """Convert a stim error model into a NetworkX graph."""
@@ -261,7 +252,7 @@ def detector_error_model_to_mwpf_solver_and_fault_masks(
          # Accept it and keep going, though of course decoding will probably perform terribly.
          return
      if p > 0.5:
-         # mwpf doesn't support negative edge weights.
+         # mwpf doesn't support negative edge weights (yet; they will be supported in the next version).
          # approximate them as weight 0.
          p = 0.5
      weight = math.log((1 - p) / p)
@@ -280,7 +271,7 @@ def detector_error_model_to_mwpf_solver_and_fault_masks(
      # The mwpf package panics on duplicate edges, thus we need to handle them here.
      hyperedges = deduplicate_hyperedges(hyperedges)

-     # fix the input by connecting an edge to all isolated vertices
+     # fix the input by connecting an edge to all isolated vertices (isolated vertices will be supported directly in the next mwpf version)
      for idx in range(num_detectors):
          if not is_detector_connected[idx]:
              hyperedges.append(([idx], 0, 0))
@@ -299,9 +290,11 @@ def detector_error_model_to_mwpf_solver_and_fault_masks(
      if decoder_cls is None:
          # default to the solver with highest accuracy
          decoder_cls = mwpf.SolverSerialJointSingleHair
+     elif isinstance(decoder_cls, str):
+         decoder_cls = getattr(mwpf, decoder_cls)
      return (
          (
-             decoder_cls(initializer)
+             decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit})
              if num_detectors > 0 and len(rescaled_edges) > 0
              else None
          ),
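Because `decoder_cls` and `cluster_node_limit` are now constructor arguments (and a string `decoder_cls` is resolved via `getattr(mwpf, ...)`, as shown above), callers can tune the speed/accuracy trade-off by instantiating `MwpfDecoder` directly and registering it under `custom_decoders`. A usage sketch, assuming the `mwpf` package is installed; the name 'mwpf_fast', the limit of 20, and the circuit are illustrative:

    import sinter
    import stim
    from sinter._decoding._decoding_mwpf import MwpfDecoder

    # cluster_node_limit caps cluster growth: lower is faster but less accurate.
    # (HyperUFDecoder above is just MwpfDecoder with a limit of 0.)
    stats = sinter.collect(
        num_workers=2,
        tasks=[
            sinter.Task(
                circuit=stim.Circuit.generated(
                    'surface_code:rotated_memory_x',
                    distance=3,
                    rounds=3,
                    after_clifford_depolarization=0.001,
                ),
                decoder='mwpf_fast',
                json_metadata={'d': 3},
            ),
        ],
        max_shots=10_000,
        custom_decoders={'mwpf_fast': MwpfDecoder(cluster_node_limit=20)},
    )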
{sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_decoding_test.py
@@ -10,6 +10,7 @@ import pytest
  import sinter
  import stim

+ from sinter import CompiledDecoder
  from sinter._collection import post_selection_mask_from_4th_coord
  from sinter._decoding._decoding_all_built_in_decoders import BUILT_IN_DECODERS
  from sinter._decoding._decoding import sample_decode
@@ -391,3 +392,89 @@ def test_full_scale(decoder: str):
      assert result.discards == 0
      assert result.shots == 1000
      assert result.errors == 0
+
+
+ def test_infer_decode_via_files_from_decode_from_compile_decoder_for_dem():
+     class IncompleteDecoder(sinter.Decoder):
+         pass
+
+     class WrongDecoder(sinter.Decoder, sinter.CompiledDecoder):
+         def compile_decoder_for_dem(
+             self,
+             *,
+             dem: stim.DetectorErrorModel,
+         ) -> CompiledDecoder:
+             return self
+
+         def decode_shots_bit_packed(
+             self,
+             *,
+             bit_packed_detection_event_data: np.ndarray,
+         ) -> np.ndarray:
+             return np.zeros(shape=5, dtype=np.bool_)
+
+     class TrivialCompiledDecoder(sinter.CompiledDecoder):
+         def __init__(self, num_obs: int):
+             self.num_obs = -(-num_obs // 8)
+
+         def decode_shots_bit_packed(
+             self,
+             *,
+             bit_packed_detection_event_data: np.ndarray,
+         ) -> np.ndarray:
+             return np.zeros(dtype=np.uint8, shape=(bit_packed_detection_event_data.shape[0], self.num_obs))
+
+     class TrivialDecoder(sinter.Decoder):
+         def compile_decoder_for_dem(
+             self,
+             *,
+             dem: stim.DetectorErrorModel,
+         ) -> CompiledDecoder:
+             return TrivialCompiledDecoder(num_obs=dem.num_observables)
+
+     circuit = stim.Circuit.generated("repetition_code:memory", distance=3, rounds=3)
+     dem = circuit.detector_error_model()
+
+     with tempfile.TemporaryDirectory() as d:
+         d = pathlib.Path(d)
+
+         circuit.compile_detector_sampler().sample_write(
+             shots=10,
+             filepath=d / 'dets.b8',
+             format='b8',
+         )
+
+         dem.to_file(d / 'dem.dem')
+
+         with pytest.raises(NotImplementedError, match='compile_decoder_for_dem'):
+             IncompleteDecoder().decode_via_files(
+                 num_shots=10,
+                 num_dets=dem.num_detectors,
+                 num_obs=dem.num_observables,
+                 dem_path=d / 'dem.dem',
+                 dets_b8_in_path=d / 'dets.b8',
+                 obs_predictions_b8_out_path=d / 'obs.b8',
+                 tmp_dir=d,
+             )
+
+         with pytest.raises(ValueError, match='shape='):
+             WrongDecoder().decode_via_files(
+                 num_shots=10,
+                 num_dets=dem.num_detectors,
+                 num_obs=dem.num_observables,
+                 dem_path=d / 'dem.dem',
+                 dets_b8_in_path=d / 'dets.b8',
+                 obs_predictions_b8_out_path=d / 'obs.b8',
+                 tmp_dir=d,
+             )
+
+         TrivialDecoder().decode_via_files(
+             num_shots=10,
+             num_dets=dem.num_detectors,
+             num_obs=dem.num_observables,
+             dem_path=d / 'dem.dem',
+             dets_b8_in_path=d / 'dets.b8',
+             obs_predictions_b8_out_path=d / 'obs.b8',
+             tmp_dir=d,
+         )
+         obs = np.fromfile(d / 'obs.b8', dtype=np.uint8, count=10)
+         np.testing.assert_array_equal(obs, [0] * 10)
{sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_decoding/_stim_then_decode_sampler.py
@@ -140,7 +140,7 @@ def _compile_decoder_with_disk_fallback(
  ) -> CompiledDecoder:
      try:
          return decoder.compile_decoder_for_dem(dem=task.detector_error_model)
-     except (NotImplementedError, ValueError):
+     except NotImplementedError:
          pass
      if tmp_dir is None:
          raise ValueError(f"Decoder {task.decoder=} didn't implement `compile_decoder_for_dem`, but no temporary directory was provided for falling back to `decode_via_files`.")
{sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter/_probability_util.py
@@ -51,7 +51,7 @@ def log_binomial(*, p: Union[float, np.ndarray], n: int, hits: int) -> np.ndarra
      Examples:
          >>> import sinter
          >>> sinter.log_binomial(p=0.5, n=100, hits=50)
-         array(-2.5308785, dtype=float32)
+         array(-2.5308762, dtype=float32)
          >>> sinter.log_binomial(p=0.2, n=1_000_000, hits=1_000)
          array(-216626.97, dtype=float32)
          >>> sinter.log_binomial(p=0.1, n=1_000_000, hits=1_000)
@@ -321,7 +321,7 @@ def fit_line_slope(*,

      low_slope = binary_intercept(start_x=fit.slope, step=-1, target_y=base_cost + max_extra_squared_error, func=cost_for_slope, atol=1e-5)
      high_slope = binary_intercept(start_x=fit.slope, step=1, target_y=base_cost + max_extra_squared_error, func=cost_for_slope, atol=1e-5)
-     return Fit(low=low_slope, best=fit.slope, high=high_slope)
+     return Fit(low=float(low_slope), best=float(fit.slope), high=float(high_slope))
{sinter-1.14.dev1727164676 → sinter-1.15.0/src/sinter.egg-info}/PKG-INFO
@@ -1,16 +1,24 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.4
  Name: sinter
- Version: 1.14.dev1727164676
+ Version: 1.15.0
  Summary: Samples stim circuits and decodes them using pymatching.
  Author: Craig Gidney
  Author-email: craig.gidney@gmail.com
  License: Apache 2
  Requires-Python: >=3.7.0
  Description-Content-Type: text/markdown
- Requires-Dist: matplotlib~=3.5
- Requires-Dist: numpy~=1.22
+ Requires-Dist: matplotlib
+ Requires-Dist: numpy
  Requires-Dist: stim
- Requires-Dist: scipy~=1.9
+ Requires-Dist: scipy
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: license
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary

  # sinter: fast QEC sampling

{sinter-1.14.dev1727164676 → sinter-1.15.0}/src/sinter.egg-info/SOURCES.txt
@@ -27,6 +27,7 @@ src/sinter/_collection/_collection_worker_test.py
  src/sinter/_collection/_mux_sampler.py
  src/sinter/_collection/_printer.py
  src/sinter/_collection/_sampler_ramp_throttled.py
+ src/sinter/_collection/_sampler_ramp_throttled_test.py
  src/sinter/_command/__init__.py
  src/sinter/_command/_main.py
  src/sinter/_command/_main_collect.py
sinter-1.15.0/src/sinter.egg-info/requires.txt
@@ -0,0 +1,4 @@
+ matplotlib
+ numpy
+ stim
+ scipy
sinter-1.14.dev1727164676/requirements.txt (deleted)
@@ -1,4 +0,0 @@
- matplotlib ~= 3.5
- numpy ~= 1.22
- stim
- scipy ~= 1.9
sinter-1.14.dev1727164676/src/sinter.egg-info/requires.txt (deleted)
@@ -1,4 +0,0 @@
- matplotlib~=3.5
- numpy~=1.22
- stim
- scipy~=1.9