sinter 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sinter might be problematic. Click here for more details.
- sinter/__init__.py +47 -0
- sinter/_collection/__init__.py +10 -0
- sinter/_collection/_collection.py +480 -0
- sinter/_collection/_collection_manager.py +581 -0
- sinter/_collection/_collection_manager_test.py +287 -0
- sinter/_collection/_collection_test.py +317 -0
- sinter/_collection/_collection_worker_loop.py +35 -0
- sinter/_collection/_collection_worker_state.py +259 -0
- sinter/_collection/_collection_worker_test.py +222 -0
- sinter/_collection/_mux_sampler.py +56 -0
- sinter/_collection/_printer.py +65 -0
- sinter/_collection/_sampler_ramp_throttled.py +66 -0
- sinter/_collection/_sampler_ramp_throttled_test.py +144 -0
- sinter/_command/__init__.py +0 -0
- sinter/_command/_main.py +39 -0
- sinter/_command/_main_collect.py +350 -0
- sinter/_command/_main_collect_test.py +482 -0
- sinter/_command/_main_combine.py +84 -0
- sinter/_command/_main_combine_test.py +153 -0
- sinter/_command/_main_plot.py +817 -0
- sinter/_command/_main_plot_test.py +445 -0
- sinter/_command/_main_predict.py +75 -0
- sinter/_command/_main_predict_test.py +36 -0
- sinter/_data/__init__.py +20 -0
- sinter/_data/_anon_task_stats.py +89 -0
- sinter/_data/_anon_task_stats_test.py +35 -0
- sinter/_data/_collection_options.py +106 -0
- sinter/_data/_collection_options_test.py +24 -0
- sinter/_data/_csv_out.py +74 -0
- sinter/_data/_existing_data.py +173 -0
- sinter/_data/_existing_data_test.py +41 -0
- sinter/_data/_task.py +311 -0
- sinter/_data/_task_stats.py +244 -0
- sinter/_data/_task_stats_test.py +140 -0
- sinter/_data/_task_test.py +38 -0
- sinter/_decoding/__init__.py +16 -0
- sinter/_decoding/_decoding.py +419 -0
- sinter/_decoding/_decoding_all_built_in_decoders.py +25 -0
- sinter/_decoding/_decoding_decoder_class.py +161 -0
- sinter/_decoding/_decoding_fusion_blossom.py +193 -0
- sinter/_decoding/_decoding_mwpf.py +302 -0
- sinter/_decoding/_decoding_pymatching.py +81 -0
- sinter/_decoding/_decoding_test.py +480 -0
- sinter/_decoding/_decoding_vacuous.py +38 -0
- sinter/_decoding/_perfectionist_sampler.py +38 -0
- sinter/_decoding/_sampler.py +72 -0
- sinter/_decoding/_stim_then_decode_sampler.py +222 -0
- sinter/_decoding/_stim_then_decode_sampler_test.py +192 -0
- sinter/_plotting.py +619 -0
- sinter/_plotting_test.py +108 -0
- sinter/_predict.py +381 -0
- sinter/_predict_test.py +227 -0
- sinter/_probability_util.py +519 -0
- sinter/_probability_util_test.py +281 -0
- sinter-1.15.0.data/data/README.md +332 -0
- sinter-1.15.0.data/data/readme_example_plot.png +0 -0
- sinter-1.15.0.data/data/requirements.txt +4 -0
- sinter-1.15.0.dist-info/METADATA +354 -0
- sinter-1.15.0.dist-info/RECORD +62 -0
- sinter-1.15.0.dist-info/WHEEL +5 -0
- sinter-1.15.0.dist-info/entry_points.txt +2 -0
- sinter-1.15.0.dist-info/top_level.txt +1 -0
sinter/_plotting.py
ADDED
|
@@ -0,0 +1,619 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import Callable, TypeVar, List, Any, Iterable, Optional, TYPE_CHECKING, Dict, Union, Literal, Tuple
|
|
3
|
+
from typing import Sequence
|
|
4
|
+
from typing import cast
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from sinter._probability_util import fit_binomial, shot_error_rate_to_piece_error_rate, Fit
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import sinter
|
|
12
|
+
import matplotlib.pyplot as plt
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Marker codes handed to curves in order. The 18-code cycle is repeated 100
# times; `plot_custom` additionally indexes it modulo its length.
MARKERS: str = 100 * "ov*sp^<>8PhH+xXDd|"

# Linestyle names handed to curves in order (indexed modulo the length).
# NOTE(review): only 'solid', 'dotted', 'dashed' and 'dashdot' are names that
# matplotlib's `linestyle=` accepts directly; the remaining entries are labels
# from matplotlib's linestyle gallery (which denote dash tuples) — confirm any
# consumer translates them before handing them to matplotlib.
LINESTYLES: tuple[str, ...] = (
    # Names matplotlib accepts as-is.
    'solid', 'dotted', 'dashed', 'dashdot',
    # Gallery labels for dash patterns.
    'loosely dotted', 'dotted', 'densely dotted',
    'long dash with offset', 'loosely dashed', 'dashed', 'densely dashed',
    'loosely dashdotted', 'dashdotted', 'densely dashdotted',
    'dashdotdotted', 'loosely dashdotdotted', 'densely dashdotdotted',
)

# Type variables for the generic helpers `split_by` and `group_by`.
T = TypeVar('T')
TVal = TypeVar('TVal')
TKey = TypeVar('TKey')
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def split_by(vs: Iterable[T], key_func: Callable[[T], Any]) -> List[List[T]]:
    """Splits a sequence into runs of consecutive items sharing a key.

    A new run starts whenever `key_func` of the current item differs (via `!=`)
    from the key that started the previous run. Runs are never empty.

    Args:
        vs: The items to split.
        key_func: Computes the key of each item; it is called once per item.

    Returns:
        The list of runs, each run being a list of consecutive items.
    """
    runs: List[List[T]] = []
    previous_key: Any = None
    for value in vs:
        current_key = key_func(value)
        if current_key != previous_key:
            previous_key = current_key
            runs.append([value])
        elif runs:
            runs[-1].append(value)
        else:
            # The very first key equalled the initial sentinel (None).
            runs.append([value])
    return runs
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class LooseCompare:
    """A wrapper that makes a value orderable against (almost) anything.

    Comparisons act on the wrapped values (a LooseCompare on either side is
    unwrapped first) and use the first rule that applies:

    - both are int/float: compare numerically;
    - both are tuple/list: compare element-wise, each element wrapped in
      LooseCompare;
    - otherwise: compare their `str` forms.

    Equality follows the same rules as ordering, so exactly one of `<`, `==`,
    `>` holds for well-behaved values. Used as a sort key so that plotting
    produces *some* order instead of raising TypeError on mixed types.
    """

    # Equality mixes numeric and textual rules (LooseCompare(1) equals both
    # LooseCompare(1.0) and LooseCompare('1')), which no ordinary hash agrees
    # with. Defining __eq__ already disables hashing; this makes it explicit.
    __hash__ = None  # type: ignore[assignment]

    def __init__(self, val: Any):
        # Unwrap nested wrappers so comparisons always act on the raw value.
        self.val: Any = val.val if isinstance(val, LooseCompare) else val

    @staticmethod
    def _unwrap(value: Any) -> Any:
        """Returns the raw value held by `value` if it is a LooseCompare."""
        return value.val if isinstance(value, LooseCompare) else value

    @staticmethod
    def _wrap_items(values: Any) -> tuple:
        """Wraps every element of a tuple/list so elements compare loosely."""
        return tuple(LooseCompare(e) for e in values)

    def __lt__(self, other: Any) -> bool:
        other_val = LooseCompare._unwrap(other)
        if isinstance(self.val, (int, float)) and isinstance(other_val, (int, float)):
            return self.val < other_val
        if isinstance(self.val, (tuple, list)) and isinstance(other_val, (tuple, list)):
            return LooseCompare._wrap_items(self.val) < LooseCompare._wrap_items(other_val)
        return str(self.val) < str(other_val)

    def __gt__(self, other: Any) -> bool:
        other_val = LooseCompare._unwrap(other)
        if isinstance(self.val, (int, float)) and isinstance(other_val, (int, float)):
            return self.val > other_val
        if isinstance(self.val, (tuple, list)) and isinstance(other_val, (tuple, list)):
            return LooseCompare._wrap_items(self.val) > LooseCompare._wrap_items(other_val)
        return str(self.val) > str(other_val)

    def __le__(self, other: Any) -> bool:
        return self < other or self == other

    def __ge__(self, other: Any) -> bool:
        return self > other or self == other

    def __eq__(self, other: Any) -> bool:
        other_val = LooseCompare._unwrap(other)
        if isinstance(self.val, (int, float)) and isinstance(other_val, (int, float)):
            return self.val == other_val
        if isinstance(self.val, (tuple, list)) and isinstance(other_val, (tuple, list)):
            # Element-wise, matching __lt__/__gt__; a plain str comparison
            # would call (1, 2.0) and (1, 2) unequal yet neither smaller.
            return LooseCompare._wrap_items(self.val) == LooseCompare._wrap_items(other_val)
        return str(self.val) == str(other_val)

    def __str__(self) -> str:
        return str(self.val)

    def __repr__(self) -> str:
        return f'LooseCompare({self.val!r})'
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def better_sorted_str_terms(val: Any) -> Any:
    """A function that orders "a10000" after "a9", instead of before.

    Normally, sorting strings sorts them lexicographically, comparing digits as
    characters, so that "1999999" ends up being less than "2". This method
    splits the string into a tuple of text parts and parsed number parts, so
    that sorting by this key puts "2" before "1999999".

    Because this method is intended for use in plotting, where it's more
    important to see a bad result than to see nothing, it returns a type that
    tries to be comparable to everything.

    Args:
        val: The value to convert into a value with a better sorting order.
            Tuples are converted element-wise. Non-string values (including
            None) are wrapped in LooseCompare.

    Returns:
        A custom type of object with a better sorting order.

    Examples:
        >>> import sinter
        >>> items = [
        ...     "distance=199999, rounds=3",
        ...     "distance=2, rounds=3",
        ...     "distance=199999, rounds=199999",
        ...     "distance=2, rounds=199999",
        ... ]
        >>> for e in sorted(items, key=sinter.better_sorted_str_terms):
        ...     print(e)
        distance=2, rounds=3
        distance=2, rounds=199999
        distance=199999, rounds=3
        distance=199999, rounds=199999
    """
    if val is None:
        # Wrapped rather than returned as the bare string 'None': a bare str
        # compared against the tuples produced for string values raises
        # TypeError (e.g. sorting {None, 'dashed'}). The wrapper still compares
        # as the text 'None' against everything else.
        return LooseCompare(None)
    if isinstance(val, tuple):
        return tuple(better_sorted_str_terms(e) for e in val)
    if not isinstance(val, str):
        return LooseCompare(val)

    result: List[Any] = []
    for chars in split_by(val, lambda c: c in '.0123456789'):
        term: Any = ''.join(chars)
        if '.' in term:
            # Numeric run containing dots: a decimal ("1.5") or, failing that,
            # a dotted version ("1.2.3"). Otherwise it stays as text.
            try:
                term = float(term)
            except ValueError:
                try:
                    term = tuple(int(e) for e in term.split('.'))
                except ValueError:
                    pass
        else:
            # Digit runs become ints; text runs fail to parse and stay strings.
            try:
                term = int(term)
            except ValueError:
                pass
        result.append(term)

    if len(result) == 1 and isinstance(result[0], (int, float)):
        return LooseCompare(result[0])
    return tuple(LooseCompare(e) for e in result)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def group_by(items: Iterable[TVal],
             *,
             key: Callable[[TVal], TKey],
             ) -> Dict[TKey, List[TVal]]:
    """Buckets items by the value a key function produces for them.

    Args:
        items: The items to group.
        key: Called once per item; items with equal results share a bucket.

    Returns:
        A dict from each produced key to the items (in input order) that
        produced it. Keys appear in the order they were first produced.

    Examples:
        >>> import sinter
        >>> sinter.group_by([1, 2, 3], key=lambda i: i == 2)
        {False: [1, 3], True: [2]}

        >>> sinter.group_by(range(10), key=lambda i: i % 3)
        {0: [0, 3, 6, 9], 1: [1, 4, 7], 2: [2, 5, 8]}
    """
    groups: Dict[TKey, List[TVal]] = {}
    for item in items:
        group_key = key(item)
        bucket = groups.get(group_key)
        if bucket is None:
            bucket = []
            groups[group_key] = bucket
        bucket.append(item)
    return groups
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
# Type of the curve identifier returned by the `group_func` callbacks of the
# plotting functions below (and passed back to `plot_args_func`).
TCurveId = TypeVar('TCurveId')
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class _FrozenDict:
    """An immutable, hashable, orderable copy of a dictionary.

    Equality and hashing use the set of (key, value) items, so all values must
    be hashable. Ordering (`<`) compares a tuple listing the 'sort' key first
    and then the remaining keys in sorted order, each key followed by its value
    converted with `better_sorted_str_terms`.
    """

    def __init__(self, v: dict):
        self._v = dict(v)
        self._eq = frozenset(v.items())
        self._hash = hash(self._eq)

        # (name != 'sort') is False only for 'sort', so that key sorts first.
        ordered_keys = sorted(self._v.keys(), key=lambda name: (name != 'sort', name))
        flattened = []
        for name in ordered_keys:
            flattened.extend((name, better_sorted_str_terms(self._v[name])))
        self._order = tuple(flattened)

    def __eq__(self, other):
        if not isinstance(other, _FrozenDict):
            return NotImplemented
        return self._eq == other._eq

    def __lt__(self, other):
        if not isinstance(other, _FrozenDict):
            return NotImplemented
        return self._order < other._order

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        return self._hash

    def __getitem__(self, item):
        return self._v[item]

    def get(self, item, alternate=None):
        return self._v.get(item, alternate)

    def __str__(self):
        # Values only, ordered by their keys.
        items_by_key = sorted(self._v.items())
        return " ".join(str(value) for _key, value in items_by_key)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def plot_discard_rate(
        *,
        ax: 'plt.Axes',
        stats: 'Iterable[sinter.TaskStats]',
        x_func: Callable[['sinter.TaskStats'], Any],
        failure_units_per_shot_func: Callable[['sinter.TaskStats'], Any] = lambda _: 1,
        group_func: Callable[['sinter.TaskStats'], TCurveId] = lambda _: None,
        filter_func: Callable[['sinter.TaskStats'], Any] = lambda _: True,
        plot_args_func: Callable[[int, TCurveId, List['sinter.TaskStats']], Dict[str, Any]] = lambda index, group_key, group_stats: dict(),
        highlight_max_likelihood_factor: Optional[float] = 1e3,
        point_label_func: Callable[['sinter.TaskStats'], Any] = lambda _: None,
        line_fits: Optional[Tuple[Literal['linear', 'log', 'sqrt'], Literal['linear', 'log', 'sqrt']]] = None,
) -> None:
    """Plots discard rates in curves with uncertainty highlights.

    Args:
        ax: The plt.Axes to plot onto. For example, the `ax` value from `fig, ax = plt.subplots(1, 1)`.
        stats: The collected statistics to plot.
        x_func: The X coordinate to use for each stat's data point. For example, this could be
            `x_func=lambda stat: stat.json_metadata['physical_error_rate']`.
        failure_units_per_shot_func: How many discard chances there are per shot. This rescales what the
            discard rate means. By default, it is the discard rate per shot, but this allows
            you to instead make it the discard rate per round. For example, if the metadata
            associated with a shot has a field 'r' which is the number of rounds, then this can be
            achieved with `failure_units_per_shot_func=lambda stats: stats.json_metadata['r']`.
        group_func: Optional. When specified, multiple curves will be plotted instead of one curve.
            The statistics are grouped into curves based on whether or not they get the same result
            out of this function. For example, this could be `group_func=lambda stat: stat.decoder`.
            If the result of the function is a dictionary, then optional keys in the dictionary will
            also control the plotting of each curve. Available keys are:
                'label': the label added to the legend for the curve
                'color': the color used for plotting the curve
                'marker': the marker used for the curve
                'linestyle': the linestyle used for the curve
                'sort': the order in which the curves will be plotted and added to the legend
            e.g. if two curves (with different resulting dictionaries from group_func) share the same
            value for key 'marker', they will be plotted with the same marker.
            Colors, markers and linestyles are assigned in order, sorted by the values for those keys.
        filter_func: Optional. When specified, some stats will not be plotted.
            The statistics are filtered and only plotted if filter_func(stat) returns True.
            For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats
            where the saved metadata indicates the basis was 'x'.
        plot_args_func: Optional. Specifies additional arguments to give the underlying calls to
            `plot` and `fill_between` used to do the actual plotting. For example, this can be used
            to specify markers and colors. Takes the index of the curve in sorted order, the
            curve_id (these will be 0 and None respectively if group_func is not specified), and the
            list of stats in the curve. A two-argument function (index, curve_id) is also accepted.
            For example, this could be:

                plot_args_func=lambda index, curve_id, curve_stats: {'color': 'red'
                                                                     if curve_id == 'pymatching'
                                                                     else 'blue'}

        highlight_max_likelihood_factor: Controls how wide the uncertainty highlight region around curves is.
            Must be 1 or larger. Hypothesis probabilities at most that many times as unlikely as the max likelihood
            hypothesis will be highlighted. None is treated as 1 (no highlight region, best estimates only).
        point_label_func: Optional. Specifies text to draw next to data points.
        line_fits: Defaults to None. Set this to a tuple (x_scale, y_scale) to include a dashed line
            fit to every curve. The scales determine how to transform the coordinates before
            performing the fit, and can be set to 'linear', 'sqrt', or 'log'.

    Raises:
        ValueError: highlight_max_likelihood_factor is less than 1 (or NaN).
    """
    if highlight_max_likelihood_factor is None:
        highlight_max_likelihood_factor = 1
    # Same validation as plot_error_rate; `not (x >= 1)` also rejects NaN.
    if not (highlight_max_likelihood_factor >= 1):
        raise ValueError(f"not (highlight_max_likelihood_factor={highlight_max_likelihood_factor} >= 1)")

    def y_func(stat: 'sinter.TaskStats') -> Union[float, 'sinter.Fit']:
        # Discards are the "hits" out of all shots taken.
        result = fit_binomial(
            num_shots=stat.shots,
            num_hits=stat.discards,
            max_likelihood_factor=highlight_max_likelihood_factor,
        )

        # Convert per-shot discard rates into per-unit rates.
        pieces = failure_units_per_shot_func(stat)
        result = Fit(
            low=shot_error_rate_to_piece_error_rate(result.low, pieces=pieces),
            best=shot_error_rate_to_piece_error_rate(result.best, pieces=pieces),
            high=shot_error_rate_to_piece_error_rate(result.high, pieces=pieces),
        )

        # A factor of 1 means no highlight region, so plot plain floats.
        if highlight_max_likelihood_factor == 1:
            return result.best
        return result

    plot_custom(
        ax=ax,
        stats=stats,
        x_func=x_func,
        y_func=y_func,
        group_func=group_func,
        filter_func=filter_func,
        plot_args_func=plot_args_func,
        line_fits=line_fits,
        point_label_func=point_label_func,
    )
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def plot_error_rate(
        *,
        ax: 'plt.Axes',
        stats: 'Iterable[sinter.TaskStats]',
        x_func: Callable[['sinter.TaskStats'], Any],
        failure_units_per_shot_func: Callable[['sinter.TaskStats'], Any] = lambda _: 1,
        failure_values_func: Callable[['sinter.TaskStats'], Any] = lambda _: 1,
        group_func: Callable[['sinter.TaskStats'], TCurveId] = lambda _: None,
        filter_func: Callable[['sinter.TaskStats'], Any] = lambda _: True,
        plot_args_func: Callable[[int, TCurveId, List['sinter.TaskStats']], Dict[str, Any]] = lambda index, group_key, group_stats: dict(),
        highlight_max_likelihood_factor: Optional[float] = 1e3,
        line_fits: Optional[Tuple[Literal['linear', 'log', 'sqrt'], Literal['linear', 'log', 'sqrt']]] = None,
        point_label_func: Callable[['sinter.TaskStats'], Any] = lambda _: None,
) -> None:
    """Plots logical error rates as curves, with uncertainty highlights.

    The error rate of each stat is estimated from its errors out of its
    non-discarded shots, then rescaled by `failure_units_per_shot_func` and
    `failure_values_func`.

    Args:
        ax: The plt.Axes to plot onto, e.g. the `ax` from `fig, ax = plt.subplots(1, 1)`.
        stats: The collected statistics to plot.
        x_func: The X coordinate of each stat's data point, e.g.
            `x_func=lambda stat: stat.json_metadata['physical_error_rate']`.
        failure_units_per_shot_func: How many error chances there are per shot. By default the
            plotted value is the logical error rate per shot; returning e.g. the number of rounds
            turns it into the logical error rate per round:
            `failure_units_per_shot_func=lambda stats: stats.json_metadata['r']`.
        failure_values_func: How many independent ways there are for a shot to fail, such as
            the number of independent observables in a memory experiment. This affects how the
            failure-units rescaling plays out (with 1 independent failure the "center" of the
            conversion is at 50%, whereas with 2 it is at 75%).
        group_func: Optional. Splits the stats into multiple curves: stats producing the same
            result share a curve, e.g. `group_func=lambda stat: stat.decoder`. A dictionary result
            may carry the optional keys 'label', 'color', 'marker', 'linestyle' and 'sort', which
            control the curve's legend label, color, marker, linestyle and plotting/legend order.
            Curves sharing a value for one of these keys (e.g. the same 'marker') share that
            style; styles are assigned in the sorted order of those values.
        filter_func: Optional. Only stats for which filter_func(stat) returns True are plotted,
            e.g. `filter_func=lambda s: s.json_metadata['basis'] == 'x'`.
        plot_args_func: Optional. Extra keyword arguments for the underlying `plot` and
            `fill_between` calls (e.g. markers and colors). Called with the curve's index in
            sorted order, its curve_id (0 and None when group_func is not given) and the curve's
            stats; a two-argument (index, curve_id) function is also accepted. For example:

                plot_args_func=lambda index, curve_id, curve_stats: {'color': 'red'
                                                                     if curve_id == 'pymatching'
                                                                     else 'blue'}

        highlight_max_likelihood_factor: Width of the uncertainty highlight around each curve.
            Must be at least 1: hypotheses at most this many times less likely than the max
            likelihood hypothesis are highlighted. None behaves like 1 (best estimates only).
        line_fits: Optional (x_scale, y_scale) pair. When given, a dashed line is fit to every
            curve after transforming coordinates by the scales ('linear', 'sqrt' or 'log').
        point_label_func: Optional. Text to draw next to each data point.

    Raises:
        ValueError: highlight_max_likelihood_factor is below 1 (or NaN).
    """
    likelihood_factor = 1 if highlight_max_likelihood_factor is None else highlight_max_likelihood_factor
    if not likelihood_factor >= 1:
        raise ValueError(f"not (highlight_max_likelihood_factor={likelihood_factor} >= 1)")
    best_only = likelihood_factor == 1

    def y_func(stat: 'sinter.TaskStats') -> Union[float, 'sinter.Fit']:
        # Discarded shots carry no information about logical errors.
        shot_fit = fit_binomial(
            num_shots=stat.shots - stat.discards,
            num_hits=stat.errors,
            max_likelihood_factor=likelihood_factor,
        )
        pieces = failure_units_per_shot_func(stat)
        values = failure_values_func(stat)
        low, best, high = [
            shot_error_rate_to_piece_error_rate(rate, pieces=pieces, values=values)
            for rate in (shot_fit.low, shot_fit.best, shot_fit.high)
        ]
        if stat.errors == 0:
            # No observed errors: hide the point, keep the uncertainty band.
            best = float('nan')
        if best_only:
            return best
        return Fit(low=low, best=best, high=high)

    plot_custom(
        ax=ax,
        stats=stats,
        x_func=x_func,
        y_func=y_func,
        group_func=group_func,
        filter_func=filter_func,
        plot_args_func=plot_args_func,
        line_fits=line_fits,
        point_label_func=point_label_func,
    )
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def _rescale(v: Sequence[float], scale: str, invert: bool) -> np.ndarray:
|
|
423
|
+
if scale == 'linear':
|
|
424
|
+
return np.array(v)
|
|
425
|
+
elif scale == 'log':
|
|
426
|
+
return np.exp(v) if invert else np.log(v)
|
|
427
|
+
elif scale == 'sqrt':
|
|
428
|
+
return np.array(v)**2 if invert else np.sqrt(v)
|
|
429
|
+
else:
|
|
430
|
+
raise NotImplementedError(f'{scale=}')
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
# Matplotlib's `linestyle` argument accepts only a few names ('solid', 'dotted',
# 'dashed', 'dashdot' and their short forms). Most LINESTYLES entries are names
# from matplotlib's "Linestyles" gallery, where they label (offset, on-off
# sequence) dash patterns; passing those names to `ax.plot` raises ValueError,
# so they are translated to the patterns before plotting.
_GALLERY_DASH_PATTERNS: Dict[str, Tuple[int, Tuple[int, ...]]] = {
    'loosely dotted': (0, (1, 10)),
    'densely dotted': (0, (1, 1)),
    'long dash with offset': (5, (10, 3)),
    'loosely dashed': (0, (5, 10)),
    'densely dashed': (0, (5, 1)),
    'loosely dashdotted': (0, (3, 10, 1, 10)),
    'dashdotted': (0, (3, 5, 1, 5)),
    'densely dashdotted': (0, (3, 1, 1, 1)),
    'dashdotdotted': (0, (3, 5, 1, 5, 1, 5)),
    'loosely dashdotdotted': (0, (3, 10, 1, 10, 1, 10)),
    'densely dashdotdotted': (0, (3, 1, 1, 1, 1, 1)),
}

# Short matplotlib aliases that `plot_args_func` may return for keyword
# arguments that `plot_custom` (or its line fit) also fills in by default.
_MATPLOTLIB_KWARG_ALIASES: Tuple[Tuple[str, str], ...] = (
    ('c', 'color'),
    ('ls', 'linestyle'),
    ('lw', 'linewidth'),
)


def _matplotlib_linestyle(name: str) -> Any:
    """Returns a value matplotlib accepts for `linestyle`, given a LINESTYLES entry."""
    return _GALLERY_DASH_PATTERNS.get(name, name)


def _plot_line_fit(
        ax: 'plt.Axes',
        xs: Sequence[float],
        ys: Sequence[float],
        line_fits: Tuple[str, str],
        plot_args: Dict[str, Any],
) -> None:
    """Draws a thin dashed least-squares line through (xs, ys) in the given (x, y) scales."""
    x_scale, y_scale = line_fits
    fit_xs = _rescale(xs, x_scale, False)
    fit_ys = _rescale(ys, y_scale, False)

    # Points that are not finite in fit space (e.g. log(0)) would make the whole
    # regression NaN, so the line is fit through the remaining points.
    finite = np.isfinite(fit_xs) & np.isfinite(fit_ys)
    fit_xs = fit_xs[finite]
    fit_ys = fit_ys[finite]
    if len(set(fit_xs.tolist())) < 2:
        # A regression needs at least two distinct x values.
        return

    from scipy.stats import linregress
    line_fit = linregress(fit_xs, fit_ys)

    # Extend the drawn line 10x the data's x-extent past each end.
    x0 = float(np.min(fit_xs))
    x1 = float(np.max(fit_xs))
    dx = x1 - x0
    x0 -= dx * 10
    x1 += dx * 10
    if x_scale == 'sqrt' and x0 < 0:
        # Squaring negative coordinates back would fold the line onto itself.
        x0 = 0

    out_xs = np.linspace(x0, x1, 1000)
    out_ys = out_xs * line_fit.slope + line_fit.intercept
    out_xs = _rescale(out_xs, x_scale, True)
    out_ys = _rescale(out_ys, y_scale, True)

    line_fit_kwargs = plot_args.copy()
    line_fit_kwargs.pop('marker', None)
    line_fit_kwargs.pop('label', None)
    line_fit_kwargs['linestyle'] = '--'
    line_fit_kwargs.setdefault('linewidth', 1)
    line_fit_kwargs['linewidth'] /= 2
    ax.plot(out_xs, out_ys, **line_fit_kwargs)


def plot_custom(
        *,
        ax: 'plt.Axes',
        stats: 'Iterable[sinter.TaskStats]',
        x_func: Callable[['sinter.TaskStats'], Any],
        y_func: Callable[['sinter.TaskStats'], Union['sinter.Fit', float, int]],
        group_func: Callable[['sinter.TaskStats'], TCurveId] = lambda _: None,
        point_label_func: Callable[['sinter.TaskStats'], Any] = lambda _: None,
        filter_func: Callable[['sinter.TaskStats'], Any] = lambda _: True,
        plot_args_func: Callable[[int, TCurveId, List['sinter.TaskStats']], Dict[str, Any]] = lambda index, group_key, group_stats: dict(),
        line_fits: Optional[Tuple[Literal['linear', 'log', 'sqrt'], Literal['linear', 'log', 'sqrt']]] = None,
) -> None:
    """Plots custom Y values in curves, with optional uncertainty highlights.

    Args:
        ax: The plt.Axes to plot onto. For example, the `ax` value from `fig, ax = plt.subplots(1, 1)`.
        stats: The collected statistics to plot.
        x_func: The X coordinate to use for each stat's data point. For example, this could be
            `x_func=lambda stat: stat.json_metadata['physical_error_rate']`.
        y_func: The Y value to use for each stat's data point. This can be a float or it can be a
            sinter.Fit value, in which case the curve will follow the fit.best value and a
            highlighted area will be shown from fit.low to fit.high. NaN values are skipped.
        group_func: Optional. When specified, multiple curves will be plotted instead of one curve.
            The statistics are grouped into curves based on whether or not they get the same result
            out of this function. For example, this could be `group_func=lambda stat: stat.decoder`.
            If the result of the function is a dictionary, then optional keys in the dictionary will
            also control the plotting of each curve. Available keys are:
                'label': the label added to the legend for the curve
                'color': the color used for plotting the curve
                'marker': the marker used for the curve
                'linestyle': the linestyle used for the curve
                'sort': the order in which the curves will be plotted and added to the legend
            e.g. if two curves (with different resulting dictionaries from group_func) share the same
            value for key 'marker', they will be plotted with the same marker.
            Colors, markers and linestyles are assigned in order, sorted by the values for those keys.
        point_label_func: Optional. Specifies text to draw next to data points.
        filter_func: Optional. When specified, some stats will not be plotted.
            The statistics are filtered and only plotted if filter_func(stat) returns True.
            For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats
            where the saved metadata indicates the basis was 'x'.
        plot_args_func: Optional. Specifies additional arguments to give the underlying calls to
            `plot` and `fill_between` used to do the actual plotting. For example, this can be used
            to specify markers and colors. Takes the index of the curve in sorted order, the
            curve_id (these will be 0 and None respectively if group_func is not specified), and the
            list of stats in the curve. A two-argument (index, curve_id) function is also accepted.
            The short aliases 'c', 'ls' and 'lw' are treated as 'color', 'linestyle' and
            'linewidth'. For example, this could be:

                plot_args_func=lambda index, group_key, group_stats: {
                    'color': (
                        'red'
                        if group_key == 'decoder=pymatching p=0.001'
                        else 'blue'
                    ),
                }
        line_fits: Defaults to None. Set this to a tuple (x_scale, y_scale) to include a dashed line
            fit to every curve. The scales determine how to transform the coordinates before
            performing the fit, and can be set to 'linear', 'sqrt', or 'log'. Points that are not
            finite after transformation (e.g. a zero under 'log') are left out of the fit.
    """

    def group_dict_func(item: 'sinter.TaskStats') -> _FrozenDict:
        # Curves are keyed by a hashable dict; non-dict group ids become a label.
        e = group_func(item)
        return _FrozenDict(e if isinstance(e, dict) else {'label': str(e)})

    # Backwards compatibility to when the group stats argument wasn't present.
    import inspect
    if len(inspect.signature(plot_args_func).parameters) == 2:
        old_plot_args_func = cast(Callable[[int, TCurveId], Any], plot_args_func)
        plot_args_func = lambda a, b, _: old_plot_args_func(a, b)

    filtered_stats: List['sinter.TaskStats'] = [stat for stat in stats if filter_func(stat)]

    curve_groups = group_by(filtered_stats, key=group_dict_func)
    # Styles are handed out in sorted order of the group dict's value for the
    # corresponding key (or of the whole group key when that key is absent).
    colors = {
        k: f'C{i}'
        for i, k in enumerate(sorted({g.get('color', g) for g in curve_groups.keys()}, key=better_sorted_str_terms))
    }
    markers = {
        k: MARKERS[i % len(MARKERS)]
        for i, k in enumerate(sorted({g.get('marker', g) for g in curve_groups.keys()}, key=better_sorted_str_terms))
    }
    linestyles = {
        k: _matplotlib_linestyle(LINESTYLES[i % len(LINESTYLES)])
        for i, k in enumerate(sorted({g.get('linestyle', None) for g in curve_groups.keys()}, key=better_sorted_str_terms))
    }

    def sort_key(a: Any) -> Any:
        # NOTE(review): for a _FrozenDict without 'sort', better_sorted_str_terms
        # wraps the dict itself, so such curves compare by str(group_key) rather
        # than by _FrozenDict's own ordering — confirm that is intended.
        if isinstance(a, _FrozenDict):
            return a.get('sort', better_sorted_str_terms(a))
        return better_sorted_str_terms(a)

    for k, group_key in enumerate(sorted(curve_groups.keys(), key=sort_key)):
        group = sorted(curve_groups[group_key], key=x_func)
        color = colors[group_key.get('color', group_key)]
        marker = markers[group_key.get('marker', group_key)]
        linestyle = linestyles[group_key.get('linestyle', None)]
        label = str(group_key.get('label', group_key))
        xs_label: list[float] = []
        ys_label: list[float] = []
        vs_label: list[float] = []
        xs_best: list[float] = []
        ys_best: list[float] = []
        xs_low_high: list[float] = []
        ys_low: list[float] = []
        ys_high: list[float] = []
        for item in group:
            x = x_func(item)
            y = y_func(item)
            point_label = point_label_func(item)
            if isinstance(y, Fit):
                if y.low is not None and y.high is not None and not math.isnan(y.low) and not math.isnan(y.high):
                    xs_low_high.append(x)
                    ys_low.append(y.low)
                    ys_high.append(y.high)
                if y.best is not None and not math.isnan(y.best):
                    ys_best.append(y.best)
                    xs_best.append(x)

                if point_label:
                    # Anchor the label at the first available of best/high/low.
                    cy = None
                    for e in [y.best, y.high, y.low]:
                        if e is not None and not math.isnan(e):
                            cy = e
                            break
                    if cy is not None:
                        xs_label.append(x)
                        ys_label.append(cy)
                        vs_label.append(point_label)
            elif not math.isnan(y):
                xs_best.append(x)
                ys_best.append(y)
                if point_label:
                    xs_label.append(x)
                    ys_label.append(y)
                    vs_label.append(point_label)

        args = dict(plot_args_func(k, group_func(group[0]), group))
        # Resolve matplotlib's short aliases first; otherwise the defaults below
        # would add e.g. 'color' next to a user-supplied 'c', which matplotlib rejects.
        for alias, canonical in _MATPLOTLIB_KWARG_ALIASES:
            if alias in args and canonical not in args:
                args[canonical] = args.pop(alias)
        args.setdefault('linestyle', linestyle)
        args.setdefault('marker', marker)
        args.setdefault('color', color)
        args.setdefault('label', label)
        ax.plot(xs_best, ys_best, **args)
        for x, y, lbl in zip(xs_label, ys_label, vs_label):
            if lbl:
                ax.annotate(lbl, (x, y))
        if len(xs_low_high) > 1:
            ax.fill_between(xs_low_high, ys_low, ys_high, color=args['color'], alpha=0.2, zorder=-100)
        elif len(xs_low_high) == 1:
            # A single point can't span a filled region; draw an error bar,
            # in the same (possibly user-chosen) color as the curve.
            l, = ys_low
            h, = ys_high
            m = (l + h) / 2
            ax.errorbar(xs_low_high, [m], yerr=([m - l], [h - m]), marker='', elinewidth=1, ecolor=args['color'], capsize=5)

        if line_fits is not None and len(set(xs_best)) >= 2:
            _plot_line_fit(ax, xs_best, ys_best, line_fits, args)
|