sinter-1.15.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (62)
  1. sinter/__init__.py +47 -0
  2. sinter/_collection/__init__.py +10 -0
  3. sinter/_collection/_collection.py +480 -0
  4. sinter/_collection/_collection_manager.py +581 -0
  5. sinter/_collection/_collection_manager_test.py +287 -0
  6. sinter/_collection/_collection_test.py +317 -0
  7. sinter/_collection/_collection_worker_loop.py +35 -0
  8. sinter/_collection/_collection_worker_state.py +259 -0
  9. sinter/_collection/_collection_worker_test.py +222 -0
  10. sinter/_collection/_mux_sampler.py +56 -0
  11. sinter/_collection/_printer.py +65 -0
  12. sinter/_collection/_sampler_ramp_throttled.py +66 -0
  13. sinter/_collection/_sampler_ramp_throttled_test.py +144 -0
  14. sinter/_command/__init__.py +0 -0
  15. sinter/_command/_main.py +39 -0
  16. sinter/_command/_main_collect.py +350 -0
  17. sinter/_command/_main_collect_test.py +482 -0
  18. sinter/_command/_main_combine.py +84 -0
  19. sinter/_command/_main_combine_test.py +153 -0
  20. sinter/_command/_main_plot.py +817 -0
  21. sinter/_command/_main_plot_test.py +445 -0
  22. sinter/_command/_main_predict.py +75 -0
  23. sinter/_command/_main_predict_test.py +36 -0
  24. sinter/_data/__init__.py +20 -0
  25. sinter/_data/_anon_task_stats.py +89 -0
  26. sinter/_data/_anon_task_stats_test.py +35 -0
  27. sinter/_data/_collection_options.py +106 -0
  28. sinter/_data/_collection_options_test.py +24 -0
  29. sinter/_data/_csv_out.py +74 -0
  30. sinter/_data/_existing_data.py +173 -0
  31. sinter/_data/_existing_data_test.py +41 -0
  32. sinter/_data/_task.py +311 -0
  33. sinter/_data/_task_stats.py +244 -0
  34. sinter/_data/_task_stats_test.py +140 -0
  35. sinter/_data/_task_test.py +38 -0
  36. sinter/_decoding/__init__.py +16 -0
  37. sinter/_decoding/_decoding.py +419 -0
  38. sinter/_decoding/_decoding_all_built_in_decoders.py +25 -0
  39. sinter/_decoding/_decoding_decoder_class.py +161 -0
  40. sinter/_decoding/_decoding_fusion_blossom.py +193 -0
  41. sinter/_decoding/_decoding_mwpf.py +302 -0
  42. sinter/_decoding/_decoding_pymatching.py +81 -0
  43. sinter/_decoding/_decoding_test.py +480 -0
  44. sinter/_decoding/_decoding_vacuous.py +38 -0
  45. sinter/_decoding/_perfectionist_sampler.py +38 -0
  46. sinter/_decoding/_sampler.py +72 -0
  47. sinter/_decoding/_stim_then_decode_sampler.py +222 -0
  48. sinter/_decoding/_stim_then_decode_sampler_test.py +192 -0
  49. sinter/_plotting.py +619 -0
  50. sinter/_plotting_test.py +108 -0
  51. sinter/_predict.py +381 -0
  52. sinter/_predict_test.py +227 -0
  53. sinter/_probability_util.py +519 -0
  54. sinter/_probability_util_test.py +281 -0
  55. sinter-1.15.0.data/data/README.md +332 -0
  56. sinter-1.15.0.data/data/readme_example_plot.png +0 -0
  57. sinter-1.15.0.data/data/requirements.txt +4 -0
  58. sinter-1.15.0.dist-info/METADATA +354 -0
  59. sinter-1.15.0.dist-info/RECORD +62 -0
  60. sinter-1.15.0.dist-info/WHEEL +5 -0
  61. sinter-1.15.0.dist-info/entry_points.txt +2 -0
  62. sinter-1.15.0.dist-info/top_level.txt +1 -0
sinter/_plotting.py ADDED
@@ -0,0 +1,619 @@
+import math
+from typing import Callable, TypeVar, List, Any, Iterable, Optional, TYPE_CHECKING, Dict, Union, Literal, Tuple
+from typing import Sequence
+from typing import cast
+
+import numpy as np
+
+from sinter._probability_util import fit_binomial, shot_error_rate_to_piece_error_rate, Fit
+
+if TYPE_CHECKING:
+    import sinter
+    import matplotlib.pyplot as plt
+
+
+MARKERS: str = "ov*sp^<>8PhH+xXDd|" * 100
+LINESTYLES: tuple[str, ...] = (
+    'solid',
+    'dotted',
+    'dashed',
+    'dashdot',
+    'loosely dotted',
+    'dotted',
+    'densely dotted',
+    'long dash with offset',
+    'loosely dashed',
+    'dashed',
+    'densely dashed',
+    'loosely dashdotted',
+    'dashdotted',
+    'densely dashdotted',
+    'dashdotdotted',
+    'loosely dashdotdotted',
+    'densely dashdotdotted',
+)
+T = TypeVar('T')
+TVal = TypeVar('TVal')
+TKey = TypeVar('TKey')
+
+
+def split_by(vs: Iterable[T], key_func: Callable[[T], Any]) -> List[List[T]]:
+    cur_key: Any = None
+    out: List[List[T]] = []
+    buf: List[T] = []
+    for item in vs:
+        key = key_func(item)
+        if key != cur_key:
+            cur_key = key
+            if buf:
+                out.append(buf)
+                buf = []
+        buf.append(item)
+    if buf:
+        out.append(buf)
+    return out
+
+
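`split_by` has no docstring; as a quick illustration (the input string below is made up, not from the wheel), it chunks an iterable into runs of consecutive items whose key agrees, which is how `better_sorted_str_terms` tokenizes strings into text runs and number runs:

    # Runs of digit/period characters are kept separate from other runs.
    from sinter._plotting import split_by

    runs = split_by("d=19,p=0.01", lambda c: c in '.0123456789')
    print([''.join(run) for run in runs])
    # prints: ['d=', '19', ',p=', '0.01']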
+class LooseCompare:
+    def __init__(self, val: Any):
+        self.val: Any = None
+
+        self.val = val.val if isinstance(val, LooseCompare) else val
+
+    def __lt__(self, other: Any) -> bool:
+        other_val = other.val if isinstance(other, LooseCompare) else other
+        if isinstance(self.val, (int, float)) and isinstance(other_val, (int, float)):
+            return self.val < other_val
+        if isinstance(self.val, (tuple, list)) and isinstance(other_val, (tuple, list)):
+            return tuple(LooseCompare(e) for e in self.val) < tuple(LooseCompare(e) for e in other_val)
+        return str(self.val) < str(other_val)
+
+    def __gt__(self, other: Any) -> bool:
+        other_val = other.val if isinstance(other, LooseCompare) else other
+        if isinstance(self.val, (int, float)) and isinstance(other_val, (int, float)):
+            return self.val > other_val
+        if isinstance(self.val, (tuple, list)) and isinstance(other_val, (tuple, list)):
+            return tuple(LooseCompare(e) for e in self.val) > tuple(LooseCompare(e) for e in other_val)
+        return str(self.val) > str(other_val)
+
+    def __str__(self) -> str:
+        return str(self.val)
+
+    def __eq__(self, other: Any) -> bool:
+        if isinstance(other, LooseCompare):
+            other_val = other.val
+        else:
+            other_val = other
+        if isinstance(self.val, (int, float)) and isinstance(other_val, (int, float)):
+            return self.val == other_val
+        return str(self.val) == str(other_val)
+
+
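`LooseCompare` exists so that sort keys built from mixed metadata types never raise `TypeError`; unlike types fall back to comparing their `str()` forms. A small sketch (the values are invented):

    from sinter._plotting import LooseCompare

    print(LooseCompare(2) < LooseCompare(10))       # True, compared numerically
    print(LooseCompare(2) < LooseCompare('abc'))    # True, falls back to '2' < 'abc'
    print(sorted([3, 'x', 1.5], key=LooseCompare))  # [1.5, 3, 'x']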
+def better_sorted_str_terms(val: Any) -> Any:
+    """A function that orders "a10000" after "a9", instead of before.
+
+    Normally, sorting strings sorts them lexicographically, treating numbers so
+    that "1999999" ends up being less than "2". This method splits the string
+    into a tuple of text pairs and parsed number parts, so that sorting by this
+    key puts "2" before "1999999".
+
+    Because this method is intended for use in plotting, where it's more
+    important to see a bad result than to see nothing, it returns a type that
+    tries to be comparable to everything.
+
+    Args:
+        val: The value to convert into a value with a better sorting order.
+
+    Returns:
+        A custom type of object with a better sorting order.
+
+    Examples:
+        >>> import sinter
+        >>> items = [
+        ...     "distance=199999, rounds=3",
+        ...     "distance=2, rounds=3",
+        ...     "distance=199999, rounds=199999",
+        ...     "distance=2, rounds=199999",
+        ... ]
+        >>> for e in sorted(items, key=sinter.better_sorted_str_terms):
+        ...     print(e)
+        distance=2, rounds=3
+        distance=2, rounds=199999
+        distance=199999, rounds=3
+        distance=199999, rounds=199999
+    """
+    if val is None:
+        return 'None'
+    if isinstance(val, tuple):
+        return tuple(better_sorted_str_terms(e) for e in val)
+    if not isinstance(val, str):
+        return LooseCompare(val)
+    terms = split_by(val, lambda c: c in '.0123456789')
+    result = []
+    for term in terms:
+        term = ''.join(term)
+        if '.' in term:
+            try:
+                term = float(term)
+            except ValueError:
+                try:
+                    term = tuple(int(e) for e in term.split('.'))
+                except ValueError:
+                    pass
+        else:
+            try:
+                term = int(term)
+            except ValueError:
+                pass
+        result.append(term)
+    if len(result) == 1 and isinstance(result[0], (int, float)):
+        return LooseCompare(result[0])
+    return tuple(LooseCompare(e) for e in result)
+
+
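Beyond the doctest above, the practical payoff shows up when sorting metadata-style labels whose numeric parts should sort numerically (the labels here are made up for illustration):

    import sinter

    labels = ['d=9', 'd=11', 'd=101']
    print(sorted(labels))                                      # ['d=101', 'd=11', 'd=9']
    print(sorted(labels, key=sinter.better_sorted_str_terms))  # ['d=9', 'd=11', 'd=101']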
+def group_by(items: Iterable[TVal],
+             *,
+             key: Callable[[TVal], TKey],
+             ) -> Dict[TKey, List[TVal]]:
+    """Groups items based on whether they produce the same key from a function.
+
+    Args:
+        items: The items to group.
+        key: Items that produce the same value from this function get grouped together.
+
+    Returns:
+        A dictionary mapping outputs that were produced by the grouping function to
+        the list of items that produced that output.
+
+    Examples:
+        >>> import sinter
+        >>> sinter.group_by([1, 2, 3], key=lambda i: i == 2)
+        {False: [1, 3], True: [2]}
+
+        >>> sinter.group_by(range(10), key=lambda i: i % 3)
+        {0: [0, 3, 6, 9], 1: [1, 4, 7], 2: [2, 5, 8]}
+    """
+
+    result: Dict[TKey, List[TVal]] = {}
+
+    for item in items:
+        curve_id = key(item)
+        result.setdefault(curve_id, []).append(item)
+
+    return result
+
+
+TCurveId = TypeVar('TCurveId')
+
+
+class _FrozenDict:
+    def __init__(self, v: dict):
+        self._v = dict(v)
+        self._eq = frozenset(v.items())
+        self._hash = hash(self._eq)
+
+        terms = []
+        for k in sorted(self._v.keys(), key=lambda e: (e != 'sort', e)):
+            terms.append(k)
+            terms.append(better_sorted_str_terms(self._v[k])
+            )
+        self._order = tuple(terms)
+
+    def __eq__(self, other):
+        if isinstance(other, _FrozenDict):
+            return self._eq == other._eq
+        return NotImplemented
+
+    def __lt__(self, other):
+        if isinstance(other, _FrozenDict):
+            return self._order < other._order
+        return NotImplemented
+
+    def __ne__(self, other):
+        return not (self == other)
+
+    def __hash__(self):
+        return self._hash
+
+    def __getitem__(self, item):
+        return self._v[item]
+
+    def get(self, item, alternate = None):
+        return self._v.get(item, alternate)
+
+    def __str__(self):
+        return " ".join(str(v) for _, v in sorted(self._v.items()))
+
+
+def plot_discard_rate(
+        *,
+        ax: 'plt.Axes',
+        stats: 'Iterable[sinter.TaskStats]',
+        x_func: Callable[['sinter.TaskStats'], Any],
+        failure_units_per_shot_func: Callable[['sinter.TaskStats'], Any] = lambda _: 1,
+        group_func: Callable[['sinter.TaskStats'], TCurveId] = lambda _: None,
+        filter_func: Callable[['sinter.TaskStats'], Any] = lambda _: True,
+        plot_args_func: Callable[[int, TCurveId, List['sinter.TaskStats']], Dict[str, Any]] = lambda index, group_key, group_stats: dict(),
+        highlight_max_likelihood_factor: Optional[float] = 1e3,
+        point_label_func: Callable[['sinter.TaskStats'], Any] = lambda _: None,
+) -> None:
+    """Plots discard rates in curves with uncertainty highlights.
+
+    Args:
+        ax: The plt.Axes to plot onto. For example, the `ax` value from `fig, ax = plt.subplots(1, 1)`.
+        stats: The collected statistics to plot.
+        x_func: The X coordinate to use for each stat's data point. For example, this could be
+            `x_func=lambda stat: stat.json_metadata['physical_error_rate']`.
+        failure_units_per_shot_func: How many discard chances there are per shot. This rescales what the
+            discard rate means. By default, it is the discard rate per shot, but this allows
+            you to instead make it the discard rate per round. For example, if the metadata
+            associated with a shot has a field 'r' which is the number of rounds, then this can be
+            achieved with `failure_units_per_shot_func=lambda stats: stats.metadata['r']`.
+        group_func: Optional. When specified, multiple curves will be plotted instead of one curve.
+            The statistics are grouped into curves based on whether or not they get the same result
+            out of this function. For example, this could be `group_func=lambda stat: stat.decoder`.
+            If the result of the function is a dictionary, then optional keys in the dictionary will
+            also control the plotting of each curve. Available keys are:
+                'label': the label added to the legend for the curve
+                'color': the color used for plotting the curve
+                'marker': the marker used for the curve
+                'linestyle': the linestyle used for the curve
+                'sort': the order in which the curves will be plotted and added to the legend
+            e.g. if two curves (with different resulting dictionaries from group_func) share the same
+            value for key 'marker', they will be plotted with the same marker.
+            Colors, markers and linestyles are assigned in order, sorted by the values for those keys.
+        filter_func: Optional. When specified, some curves will not be plotted.
+            The statistics are filtered and only plotted if filter_func(stat) returns True.
+            For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats
+            where the saved metadata indicates the basis was 'x'.
+        plot_args_func: Optional. Specifies additional arguments to give the underlying calls to
+            `plot` and `fill_between` used to do the actual plotting. For example, this can be used
+            to specify markers and colors. Takes the index of the curve in sorted order and also a
+            curve_id (these will be 0 and None respectively if group_func is not specified). For example,
+            this could be:
+
+                plot_args_func=lambda index, curve_id: {'color': 'red'
+                                                        if curve_id == 'pymatching'
+                                                        else 'blue'}
+
+        highlight_max_likelihood_factor: Controls how wide the uncertainty highlight region around curves is.
+            Must be 1 or larger. Hypothesis probabilities at most that many times as unlikely as the max likelihood
+            hypothesis will be highlighted.
+        point_label_func: Optional. Specifies text to draw next to data points.
+    """
+    if highlight_max_likelihood_factor is None:
+        highlight_max_likelihood_factor = 1
+
+    def y_func(stat: 'sinter.TaskStats') -> Union[float, 'sinter.Fit']:
+        result = fit_binomial(
+            num_shots=stat.shots,
+            num_hits=stat.discards,
+            max_likelihood_factor=highlight_max_likelihood_factor,
+        )
+
+        pieces = failure_units_per_shot_func(stat)
+        result = Fit(
+            low=shot_error_rate_to_piece_error_rate(result.low, pieces=pieces),
+            best=shot_error_rate_to_piece_error_rate(result.best, pieces=pieces),
+            high=shot_error_rate_to_piece_error_rate(result.high, pieces=pieces),
+        )
+
+        if highlight_max_likelihood_factor == 1:
+            return result.best
+        return result
+
+    plot_custom(
+        ax=ax,
+        stats=stats,
+        x_func=x_func,
+        y_func=y_func,
+        group_func=group_func,
+        filter_func=filter_func,
+        plot_args_func=plot_args_func,
+        point_label_func=point_label_func,
+    )
+
+
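A usage sketch for `plot_discard_rate` (not taken from the wheel; the `TaskStats` values and the metadata key 'p' are invented for illustration):

    import matplotlib.pyplot as plt
    import sinter

    # Hand-made statistics standing in for results of a postselected experiment.
    stats = [
        sinter.TaskStats(strong_id=f'example{i}', decoder='pymatching',
                         json_metadata={'p': p}, shots=100_000,
                         errors=0, discards=d, seconds=1)
        for i, (p, d) in enumerate([(0.001, 1_200), (0.002, 4_800), (0.005, 21_000)])
    ]

    fig, ax = plt.subplots(1, 1)
    sinter.plot_discard_rate(
        ax=ax,
        stats=stats,
        x_func=lambda stat: stat.json_metadata['p'],
        group_func=lambda stat: stat.decoder,
    )
    ax.set_xlabel('physical error rate')
    ax.set_ylabel('discard rate per shot')
    ax.legend()
    fig.savefig('discard_rate.png')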
+def plot_error_rate(
+        *,
+        ax: 'plt.Axes',
+        stats: 'Iterable[sinter.TaskStats]',
+        x_func: Callable[['sinter.TaskStats'], Any],
+        failure_units_per_shot_func: Callable[['sinter.TaskStats'], Any] = lambda _: 1,
+        failure_values_func: Callable[['sinter.TaskStats'], Any] = lambda _: 1,
+        group_func: Callable[['sinter.TaskStats'], TCurveId] = lambda _: None,
+        filter_func: Callable[['sinter.TaskStats'], Any] = lambda _: True,
+        plot_args_func: Callable[[int, TCurveId, List['sinter.TaskStats']], Dict[str, Any]] = lambda index, group_key, group_stats: dict(),
+        highlight_max_likelihood_factor: Optional[float] = 1e3,
+        line_fits: Optional[Tuple[Literal['linear', 'log', 'sqrt'], Literal['linear', 'log', 'sqrt']]] = None,
+        point_label_func: Callable[['sinter.TaskStats'], Any] = lambda _: None,
+) -> None:
+    """Plots error rates in curves with uncertainty highlights.
+
+    Args:
+        ax: The plt.Axes to plot onto. For example, the `ax` value from `fig, ax = plt.subplots(1, 1)`.
+        stats: The collected statistics to plot.
+        x_func: The X coordinate to use for each stat's data point. For example, this could be
+            `x_func=lambda stat: stat.json_metadata['physical_error_rate']`.
+        failure_units_per_shot_func: How many error chances there are per shot. This rescales what the
+            logical error rate means. By default, it is the logical error rate per shot, but this allows
+            you to instead make it the logical error rate per round. For example, if the metadata
+            associated with a shot has a field 'r' which is the number of rounds, then this can be
+            achieved with `failure_units_per_shot_func=lambda stats: stats.metadata['r']`.
+        failure_values_func: How many independent ways there are for a shot to fail, such as
+            the number of independent observables in a memory experiment. This affects how the failure
+            units rescaling plays out (e.g. with 1 independent failure the "center" of the conversion
+            is at 50% whereas for 2 independent failures the "center" is at 75%).
+        group_func: Optional. When specified, multiple curves will be plotted instead of one curve.
+            The statistics are grouped into curves based on whether or not they get the same result
+            out of this function. For example, this could be `group_func=lambda stat: stat.decoder`.
+            If the result of the function is a dictionary, then optional keys in the dictionary will
+            also control the plotting of each curve. Available keys are:
+                'label': the label added to the legend for the curve
+                'color': the color used for plotting the curve
+                'marker': the marker used for the curve
+                'linestyle': the linestyle used for the curve
+                'sort': the order in which the curves will be plotted and added to the legend
+            e.g. if two curves (with different resulting dictionaries from group_func) share the same
+            value for key 'marker', they will be plotted with the same marker.
+            Colors, markers and linestyles are assigned in order, sorted by the values for those keys.
+        filter_func: Optional. When specified, some curves will not be plotted.
+            The statistics are filtered and only plotted if filter_func(stat) returns True.
+            For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats
+            where the saved metadata indicates the basis was 'x'.
+        plot_args_func: Optional. Specifies additional arguments to give the underlying calls to
+            `plot` and `fill_between` used to do the actual plotting. For example, this can be used
+            to specify markers and colors. Takes the index of the curve in sorted order and also a
+            curve_id (these will be 0 and None respectively if group_func is not specified). For example,
+            this could be:
+
+                plot_args_func=lambda index, curve_id: {'color': 'red'
+                                                        if curve_id == 'pymatching'
+                                                        else 'blue'}
+
+        highlight_max_likelihood_factor: Controls how wide the uncertainty highlight region around curves is.
+            Must be 1 or larger. Hypothesis probabilities at most that many times as unlikely as the max likelihood
+            hypothesis will be highlighted.
+        line_fits: Defaults to None. Set this to a tuple (x_scale, y_scale) to include a dashed line
+            fit to every curve. The scales determine how to transform the coordinates before
+            performing the fit, and can be set to 'linear', 'sqrt', or 'log'.
+        point_label_func: Optional. Specifies text to draw next to data points.
+    """
+    if highlight_max_likelihood_factor is None:
+        highlight_max_likelihood_factor = 1
+    if not (highlight_max_likelihood_factor >= 1):
+        raise ValueError(f"not (highlight_max_likelihood_factor={highlight_max_likelihood_factor} >= 1)")
+
+    def y_func(stat: 'sinter.TaskStats') -> Union[float, 'sinter.Fit']:
+        result = fit_binomial(
+            num_shots=stat.shots - stat.discards,
+            num_hits=stat.errors,
+            max_likelihood_factor=highlight_max_likelihood_factor,
+        )
+
+        pieces = failure_units_per_shot_func(stat)
+        values = failure_values_func(stat)
+        result = Fit(
+            low=shot_error_rate_to_piece_error_rate(result.low, pieces=pieces, values=values),
+            best=shot_error_rate_to_piece_error_rate(result.best, pieces=pieces, values=values),
+            high=shot_error_rate_to_piece_error_rate(result.high, pieces=pieces, values=values),
+        )
+
+        if stat.errors == 0:
+            result = Fit(low=result.low, high=result.high, best=float('nan'))
+
+        if highlight_max_likelihood_factor == 1:
+            return result.best
+        return result
+
+    plot_custom(
+        ax=ax,
+        stats=stats,
+        x_func=x_func,
+        y_func=y_func,
+        group_func=group_func,
+        filter_func=filter_func,
+        plot_args_func=plot_args_func,
+        line_fits=line_fits,
+        point_label_func=point_label_func,
+    )
+
+
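A fuller sketch for `plot_error_rate`, following the README-style workflow (it assumes stim and the pymatching decoder are installed; the circuit family, metadata keys 'd' and 'p', and the shot limits are illustrative choices, not taken from this file):

    import matplotlib.pyplot as plt
    import sinter
    import stim

    # Generate a small sweep of repetition-code tasks to collect statistics for.
    tasks = [
        sinter.Task(
            circuit=stim.Circuit.generated(
                "repetition_code:memory",
                rounds=d * 3,
                distance=d,
                after_clifford_depolarization=p,
            ),
            json_metadata={'d': d, 'p': p},
        )
        for d in [3, 5, 7]
        for p in [0.01, 0.02, 0.05]
    ]
    stats = sinter.collect(
        num_workers=4,
        tasks=tasks,
        decoders=['pymatching'],
        max_shots=100_000,
        max_errors=500,
    )

    fig, ax = plt.subplots(1, 1)
    sinter.plot_error_rate(
        ax=ax,
        stats=stats,
        x_func=lambda stat: stat.json_metadata['p'],
        # A dict-valued group_func sets the legend label and the curve order.
        group_func=lambda stat: {'label': f"d={stat.json_metadata['d']}",
                                 'sort': stat.json_metadata['d']},
        # Rescale from error rate per shot to error rate per round (rounds = d * 3).
        failure_units_per_shot_func=lambda stat: stat.json_metadata['d'] * 3,
        line_fits=('log', 'log'),
    )
    ax.loglog()
    ax.set_ylim(1e-6, 1)
    ax.grid(which='major')
    ax.set_xlabel('physical error rate')
    ax.set_ylabel('logical error rate per round')
    ax.legend()
    fig.savefig('error_rate.png')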
+def _rescale(v: Sequence[float], scale: str, invert: bool) -> np.ndarray:
+    if scale == 'linear':
+        return np.array(v)
+    elif scale == 'log':
+        return np.exp(v) if invert else np.log(v)
+    elif scale == 'sqrt':
+        return np.array(v)**2 if invert else np.sqrt(v)
+    else:
+        raise NotImplementedError(f'{scale=}')
+
+
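`_rescale` is the helper behind `line_fits`: it maps coordinates into the space where the fit is done as a straight line, and `invert=True` maps the fitted line back. A tiny check (the inputs are invented):

    from sinter._plotting import _rescale

    xs = [1.0, 4.0, 9.0]
    forward = _rescale(xs, 'sqrt', False)       # array([1., 2., 3.])
    restored = _rescale(forward, 'sqrt', True)  # array([1., 4., 9.])
    print(forward, restored)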
+def plot_custom(
+        *,
+        ax: 'plt.Axes',
+        stats: 'Iterable[sinter.TaskStats]',
+        x_func: Callable[['sinter.TaskStats'], Any],
+        y_func: Callable[['sinter.TaskStats'], Union['sinter.Fit', float, int]],
+        group_func: Callable[['sinter.TaskStats'], TCurveId] = lambda _: None,
+        point_label_func: Callable[['sinter.TaskStats'], Any] = lambda _: None,
+        filter_func: Callable[['sinter.TaskStats'], Any] = lambda _: True,
+        plot_args_func: Callable[[int, TCurveId, List['sinter.TaskStats']], Dict[str, Any]] = lambda index, group_key, group_stats: dict(),
+        line_fits: Optional[Tuple[Literal['linear', 'log', 'sqrt'], Literal['linear', 'log', 'sqrt']]] = None,
+) -> None:
+    """Plots error rates in curves with uncertainty highlights.
+
+    Args:
+        ax: The plt.Axes to plot onto. For example, the `ax` value from `fig, ax = plt.subplots(1, 1)`.
+        stats: The collected statistics to plot.
+        x_func: The X coordinate to use for each stat's data point. For example, this could be
+            `x_func=lambda stat: stat.json_metadata['physical_error_rate']`.
+        y_func: The Y value to use for each stat's data point. This can be a float or it can be a
+            sinter.Fit value, in which case the curve will follow the fit.best value and a
+            highlighted area will be shown from fit.low to fit.high.
+        group_func: Optional. When specified, multiple curves will be plotted instead of one curve.
+            The statistics are grouped into curves based on whether or not they get the same result
+            out of this function. For example, this could be `group_func=lambda stat: stat.decoder`.
+            If the result of the function is a dictionary, then optional keys in the dictionary will
+            also control the plotting of each curve. Available keys are:
+                'label': the label added to the legend for the curve
+                'color': the color used for plotting the curve
+                'marker': the marker used for the curve
+                'linestyle': the linestyle used for the curve
+                'sort': the order in which the curves will be plotted and added to the legend
+            e.g. if two curves (with different resulting dictionaries from group_func) share the same
+            value for key 'marker', they will be plotted with the same marker.
+            Colors, markers and linestyles are assigned in order, sorted by the values for those keys.
+        point_label_func: Optional. Specifies text to draw next to data points.
+        filter_func: Optional. When specified, some curves will not be plotted.
+            The statistics are filtered and only plotted if filter_func(stat) returns True.
+            For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats
+            where the saved metadata indicates the basis was 'x'.
+        plot_args_func: Optional. Specifies additional arguments to give the underlying calls to
+            `plot` and `fill_between` used to do the actual plotting. For example, this can be used
+            to specify markers and colors. Takes the index of the curve in sorted order and also a
+            curve_id (these will be 0 and None respectively if group_func is not specified). For example,
+            this could be:
+
+                plot_args_func=lambda index, group_key, group_stats: {
+                    'color': (
+                        'red'
+                        if group_key == 'decoder=pymatching p=0.001'
+                        else 'blue'
+                    ),
+                }
+        line_fits: Defaults to None. Set this to a tuple (x_scale, y_scale) to include a dashed line
+            fit to every curve. The scales determine how to transform the coordinates before
+            performing the fit, and can be set to 'linear', 'sqrt', or 'log'.
+    """
+
+    def group_dict_func(item: 'sinter.TaskStats') -> _FrozenDict:
+        e = group_func(item)
+        return _FrozenDict(e if isinstance(e, dict) else {'label': str(e)})
+
+    # Backwards compatibility to when the group stats argument wasn't present.
+    import inspect
+    if len(inspect.signature(plot_args_func).parameters) == 2:
+        old_plot_args_func = cast(Callable[[int, TCurveId], Any], plot_args_func)
+        plot_args_func = lambda a, b, _: old_plot_args_func(a, b)
+
+    filtered_stats: List['sinter.TaskStats'] = [
+        stat
+        for stat in stats
+        if filter_func(stat)
+    ]
+
+    curve_groups = group_by(filtered_stats, key=group_dict_func)
+    colors = {
+        k: f'C{i}'
+        for i, k in enumerate(sorted({g.get('color', g) for g in curve_groups.keys()}, key=better_sorted_str_terms))
+    }
+    markers = {
+        k: MARKERS[i % len(MARKERS)]
+        for i, k in enumerate(sorted({g.get('marker', g) for g in curve_groups.keys()}, key=better_sorted_str_terms))
+    }
+    linestyles = {
+        k: LINESTYLES[i % len(LINESTYLES)]
+        for i, k in enumerate(sorted({g.get('linestyle', None) for g in curve_groups.keys()}, key=better_sorted_str_terms))
+    }
+
+    def sort_key(a: Any) -> Any:
+        if isinstance(a, _FrozenDict):
+            return a.get('sort', better_sorted_str_terms(a))
+        return better_sorted_str_terms(a)
+
+    for k, group_key in enumerate(sorted(curve_groups.keys(), key=sort_key)):
+        group = curve_groups[group_key]
+        group = sorted(group, key=x_func)
+        color = colors[group_key.get('color', group_key)]
+        marker = markers[group_key.get('marker', group_key)]
+        linestyle = linestyles[group_key.get('linestyle', None)]
+        label = str(group_key.get('label', group_key))
+        xs_label: list[float] = []
+        ys_label: list[float] = []
+        vs_label: list[float] = []
+        xs_best: list[float] = []
+        ys_best: list[float] = []
+        xs_low_high: list[float] = []
+        ys_low: list[float] = []
+        ys_high: list[float] = []
+        for item in group:
+            x = x_func(item)
+            y = y_func(item)
+            point_label = point_label_func(item)
+            if isinstance(y, Fit):
+                if y.low is not None and y.high is not None and not math.isnan(y.low) and not math.isnan(y.high):
+                    xs_low_high.append(x)
+                    ys_low.append(y.low)
+                    ys_high.append(y.high)
+                if y.best is not None and not math.isnan(y.best):
+                    ys_best.append(y.best)
+                    xs_best.append(x)
+
+                if point_label:
+                    cy = None
+                    for e in [y.best, y.high, y.low]:
+                        if e is not None and not math.isnan(e):
+                            cy = e
+                            break
+                    if cy is not None:
+                        xs_label.append(x)
+                        ys_label.append(cy)
+                        vs_label.append(point_label)
+            elif not math.isnan(y):
+                xs_best.append(x)
+                ys_best.append(y)
+                if point_label:
+                    xs_label.append(x)
+                    ys_label.append(y)
+                    vs_label.append(point_label)
+        args = dict(plot_args_func(k, group_func(group[0]), group))
+        if 'linestyle' not in args:
+            args['linestyle'] = linestyle
+        if 'marker' not in args:
+            args['marker'] = marker
+        if 'color' not in args:
+            args['color'] = color
+        if 'label' not in args:
+            args['label'] = label
+        ax.plot(xs_best, ys_best, **args)
+        for x, y, lbl in zip(xs_label, ys_label, vs_label):
+            if lbl:
+                ax.annotate(lbl, (x, y))
+        if len(xs_low_high) > 1:
+            ax.fill_between(xs_low_high, ys_low, ys_high, color=args['color'], alpha=0.2, zorder=-100)
+        elif len(xs_low_high) == 1:
+            l, = ys_low
+            h, = ys_high
+            m = (l + h) / 2
+            ax.errorbar(xs_low_high, [m], yerr=([m - l], [h - m]), marker='', elinewidth=1, ecolor=color, capsize=5)
+
+        if line_fits is not None and len(set(xs_best)) >= 2:
+            x_scale, y_scale = line_fits
+            fit_xs = _rescale(xs_best, x_scale, False)
+            fit_ys = _rescale(ys_best, y_scale, False)
+
+            from scipy.stats import linregress
+            line_fit = linregress(fit_xs, fit_ys)
+
+            x0 = fit_xs[0]
+            x1 = fit_xs[-1]
+            dx = x1 - x0
+            x0 -= dx*10
+            x1 += dx*10
+            if x0 < 0 <= fit_xs[0] > x0 and x_scale == 'sqrt':
+                x0 = 0
+
+            out_xs = np.linspace(x0, x1, 1000)
+            out_ys = out_xs * line_fit.slope + line_fit.intercept
+            out_xs = _rescale(out_xs, x_scale, True)
+            out_ys = _rescale(out_ys, y_scale, True)
+
+            line_fit_kwargs = args.copy()
+            line_fit_kwargs.pop('marker', None)
+            line_fit_kwargs.pop('label', None)
+            line_fit_kwargs['linestyle'] = '--'
+            line_fit_kwargs.setdefault('linewidth', 1)
+            line_fit_kwargs['linewidth'] /= 2
+            ax.plot(out_xs, out_ys, **line_fit_kwargs)
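`plot_custom` is the general entry point that the two functions above delegate to; any per-stat quantity can be plotted by supplying `y_func`. A sketch plotting decoding seconds per shot (the statistics are invented and the metadata key 'd' is only for illustration):

    import matplotlib.pyplot as plt
    import sinter

    # Invented statistics standing in for collected results.
    stats = [
        sinter.TaskStats(strong_id=f'example{d}', decoder='pymatching',
                         json_metadata={'d': d}, shots=50_000,
                         errors=25, discards=0, seconds=t)
        for d, t in [(3, 1.1), (5, 3.9), (7, 10.4)]
    ]

    fig, ax = plt.subplots(1, 1)
    sinter.plot_custom(
        ax=ax,
        stats=stats,
        x_func=lambda stat: stat.json_metadata['d'],
        y_func=lambda stat: stat.seconds / stat.shots,  # a plain float: no uncertainty band
        group_func=lambda stat: {'label': stat.decoder},
    )
    ax.set_xlabel('code distance')
    ax.set_ylabel('decoding seconds per shot')
    ax.legend()
    fig.savefig('custom_plot.png')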