accusleepy 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. accusleepy/__init__.py +0 -0
  2. accusleepy/__main__.py +4 -0
  3. accusleepy/bouts.py +142 -0
  4. accusleepy/brain_state_set.py +89 -0
  5. accusleepy/classification.py +285 -0
  6. accusleepy/config.json +24 -0
  7. accusleepy/constants.py +46 -0
  8. accusleepy/fileio.py +179 -0
  9. accusleepy/gui/__init__.py +0 -0
  10. accusleepy/gui/icons/brightness_down.png +0 -0
  11. accusleepy/gui/icons/brightness_up.png +0 -0
  12. accusleepy/gui/icons/double_down_arrow.png +0 -0
  13. accusleepy/gui/icons/double_up_arrow.png +0 -0
  14. accusleepy/gui/icons/down_arrow.png +0 -0
  15. accusleepy/gui/icons/home.png +0 -0
  16. accusleepy/gui/icons/question.png +0 -0
  17. accusleepy/gui/icons/save.png +0 -0
  18. accusleepy/gui/icons/up_arrow.png +0 -0
  19. accusleepy/gui/icons/zoom_in.png +0 -0
  20. accusleepy/gui/icons/zoom_out.png +0 -0
  21. accusleepy/gui/images/primary_window.png +0 -0
  22. accusleepy/gui/images/viewer_window.png +0 -0
  23. accusleepy/gui/images/viewer_window_annotated.png +0 -0
  24. accusleepy/gui/main.py +1494 -0
  25. accusleepy/gui/manual_scoring.py +1096 -0
  26. accusleepy/gui/mplwidget.py +386 -0
  27. accusleepy/gui/primary_window.py +2577 -0
  28. accusleepy/gui/primary_window.ui +3831 -0
  29. accusleepy/gui/resources.qrc +16 -0
  30. accusleepy/gui/resources_rc.py +6710 -0
  31. accusleepy/gui/text/config_guide.txt +27 -0
  32. accusleepy/gui/text/main_guide.md +167 -0
  33. accusleepy/gui/text/manual_scoring_guide.md +23 -0
  34. accusleepy/gui/viewer_window.py +610 -0
  35. accusleepy/gui/viewer_window.ui +926 -0
  36. accusleepy/models.py +108 -0
  37. accusleepy/multitaper.py +661 -0
  38. accusleepy/signal_processing.py +469 -0
  39. accusleepy/temperature_scaling.py +157 -0
  40. accusleepy-0.6.0.dist-info/METADATA +106 -0
  41. accusleepy-0.6.0.dist-info/RECORD +42 -0
  42. accusleepy-0.6.0.dist-info/WHEEL +4 -0
accusleepy/signal_processing.py
@@ -0,0 +1,469 @@
+ import os
+ import warnings
+
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+ from tqdm import trange
+
+ from accusleepy.brain_state_set import BrainStateSet
+ from accusleepy.constants import (
+     ANNOTATIONS_FILENAME,
+     CALIBRATION_ANNOTATION_FILENAME,
+     DEFAULT_MODEL_TYPE,
+     DOWNSAMPLING_START_FREQ,
+     EMG_COPIES,
+     FILENAME_COL,
+     LABEL_COL,
+     MIN_WINDOW_LEN,
+     UPPER_FREQ,
+ )
+ from accusleepy.fileio import Recording, load_labels, load_recording
+ from accusleepy.multitaper import spectrogram
+
+ # note: scipy is lazily imported
+
+ # clip mixture z-scores above and below this level
+ # in the matlab implementation, I used 4.5
+ ABS_MAX_Z_SCORE = 3.5
+ # upper frequency limit when generating EEG spectrograms
+ SPECTROGRAM_UPPER_FREQ = 64
+
+
+ def resample(
+     eeg: np.array, emg: np.array, sampling_rate: int | float, epoch_length: int | float
+ ) -> (np.array, np.array, float):
+     """Resample recording so that epochs contain equal numbers of samples
+
+     If the number of samples per epoch is not an integer, epoch-level calculations
+     are much more difficult. To avoid this, we can resample the EEG and EMG signals
+     and adjust the sampling rate accordingly.
+
+     :param eeg: EEG signal
+     :param emg: EMG signal
+     :param sampling_rate: original sampling rate, in Hz
+     :param epoch_length: epoch length, in seconds
+     :return: resampled EEG & EMG and updated sampling rate
+     """
+     samples_per_epoch = sampling_rate * epoch_length
+     if samples_per_epoch % 1 == 0:
+         return eeg, emg, sampling_rate
+
+     resampled = list()
+     for arr in [eeg, emg]:
+         x = np.arange(0, arr.size)
+         x_new = np.linspace(
+             0,
+             arr.size - 1,
+             round(arr.size * np.ceil(samples_per_epoch) / samples_per_epoch),
+         )
+         resampled.append(np.interp(x_new, x, arr))
+
+     eeg = resampled[0]
+     emg = resampled[1]
+     new_sampling_rate = np.ceil(samples_per_epoch) / samples_per_epoch * sampling_rate
+     return eeg, emg, new_sampling_rate
+
+
+ def standardize_signal_length(
+     eeg: np.array, emg: np.array, sampling_rate: int | float, epoch_length: int | float
+ ) -> (np.array, np.array):
+     """Truncate or pad EEG/EMG signals to have an integer number of epochs
+
+     :param eeg: EEG signal
+     :param emg: EMG signal
+     :param sampling_rate: original sampling rate, in Hz
+     :param epoch_length: epoch length, in seconds
+     :return: EEG and EMG signals
+     """
+     # since resample() was called, this will be extremely close to an integer
+     samples_per_epoch = round(sampling_rate * epoch_length)
+
+     # pad the signal at the end in case we need more samples
+     eeg = np.concatenate((eeg, np.ones(samples_per_epoch) * eeg[-1]))
+     emg = np.concatenate((emg, np.ones(samples_per_epoch) * emg[-1]))
+     padded_signal_length = eeg.size
+
+     # count samples that don't fit in any epoch
+     excess_samples = padded_signal_length % samples_per_epoch
+     # we will definitely remove those
+     last_index = padded_signal_length - excess_samples
+     # and if the last epoch of real data had more than half of
+     # its samples missing, delete it
+     if excess_samples < samples_per_epoch / 2:
+         last_index -= samples_per_epoch
+
+     return eeg[:last_index], emg[:last_index]
+
+
+ def resample_and_standardize(
+     eeg: np.array, emg: np.array, sampling_rate: int | float, epoch_length: int | float
+ ) -> (np.array, np.array, float):
+     """Preprocess EEG and EMG signals
+
+     Adjust the length and sampling rate of the EEG and EMG signals so that
+     each epoch contains an integer number of samples and each recording
+     contains an integer number of epochs.
+
+     :param eeg: EEG signal
+     :param emg: EMG signal
+     :param sampling_rate: sampling rate, in Hz
+     :param epoch_length: epoch length, in seconds
+     :return: processed EEG & EMG signals, and the new sampling rate
+     """
+     eeg, emg, sampling_rate = resample(
+         eeg=eeg, emg=emg, sampling_rate=sampling_rate, epoch_length=epoch_length
+     )
+     eeg, emg = standardize_signal_length(
+         eeg=eeg, emg=emg, sampling_rate=sampling_rate, epoch_length=epoch_length
+     )
+     return eeg, emg, sampling_rate
+
+
+ def create_spectrogram(
+     eeg: np.array,
+     sampling_rate: int | float,
+     epoch_length: int | float,
+     time_bandwidth=2,
+     n_tapers=3,
+ ) -> (np.array, np.array):
+     """Create an EEG spectrogram image
+
+     :param eeg: EEG signal
+     :param sampling_rate: sampling rate, in Hz
+     :param epoch_length: epoch length, in seconds
+     :param time_bandwidth: time-half-bandwidth product
+     :param n_tapers: number of DPSS tapers to use
+     :return: spectrogram and its frequency axis
+     """
+     window_length_sec = max(MIN_WINDOW_LEN, epoch_length)
+     # pad the EEG signal so that the first spectrogram window is centered
+     # on the first epoch
+     # it's possible there's some jank here, if this isn't close to an integer
+     pad_length = round((sampling_rate * (window_length_sec - epoch_length) / 2))
+     padded_eeg = np.concatenate(
+         [eeg[:pad_length][::-1], eeg, eeg[(len(eeg) - pad_length) :][::-1]]
+     )
+
+     spec, _, f = spectrogram(
+         padded_eeg,
+         sampling_rate,
+         frequency_range=[0, SPECTROGRAM_UPPER_FREQ],
+         time_bandwidth=time_bandwidth,
+         num_tapers=n_tapers,
+         window_params=[window_length_sec, epoch_length],
+         min_nfft=0,
+         detrend_opt="off",
+         multiprocess=True,
+         plot_on=False,
+         return_fig=False,
+         verbose=False,
+     )
+
+     # resample frequencies for consistency
+     target_frequencies = np.arange(0, SPECTROGRAM_UPPER_FREQ, 1 / MIN_WINDOW_LEN)
+     freq_idx = list()
+     for i in target_frequencies:
+         freq_idx.append(np.argmin(np.abs(f - i)))
+     f = f[freq_idx]
+     spec = spec[freq_idx, :]
+
+     return spec, f
+
+
+ def get_emg_power(
+     emg: np.array, sampling_rate: int | float, epoch_length: int | float
+ ) -> np.array:
+     """Calculate EMG power for each epoch
+
+     This applies a 20-50 Hz bandpass filter to the EMG, calculates the RMS
+     in each epoch, and takes the log of the result.
+
+     :param emg: EMG signal
+     :param sampling_rate: sampling rate, in Hz
+     :param epoch_length: epoch length, in seconds
+     :return: EMG "power" for each epoch
+     """
+     from scipy.signal import butter, filtfilt
+
+     # filter parameters
+     order = 8
+     bp_lower = 20
+     bp_upper = 50
+
+     b, a = butter(
+         N=order,
+         Wn=[bp_lower, bp_upper],
+         btype="bandpass",
+         output="ba",
+         fs=sampling_rate,
+     )
+     filtered = filtfilt(b, a, x=emg, padlen=int(np.ceil(sampling_rate)))
+
+     # since resample() was called, this will be extremely close to an integer
+     samples_per_epoch = round(sampling_rate * epoch_length)
+     reshaped = np.reshape(
+         filtered,
+         [round(len(emg) / samples_per_epoch), samples_per_epoch],
+     )
+     rms = np.sqrt(np.mean(np.power(reshaped, 2), axis=1))
+
+     return np.log(rms)
+
+
+ def create_eeg_emg_image(
+     eeg: np.array,
+     emg: np.array,
+     sampling_rate: int | float,
+     epoch_length: int | float,
+ ) -> np.array:
+     """Stack EEG spectrogram and EMG power into an image
+
+     This assumes that each epoch contains an integer number of samples and
+     each recording contains an integer number of epochs. Note that a log
+     transformation is applied to the spectrogram.
+
+     :param eeg: EEG signal
+     :param emg: EMG signal
+     :param sampling_rate: sampling rate, in Hz
+     :param epoch_length: epoch length, in seconds
+     :return: combined EEG + EMG image for a recording
+     """
+     spec, f = create_spectrogram(eeg, sampling_rate, epoch_length)
+     f_lower_idx = sum(f < DOWNSAMPLING_START_FREQ)
+     f_upper_idx = sum(f < UPPER_FREQ)
+
+     modified_spectrogram = np.log(
+         spec[
+             np.concatenate(
+                 [np.arange(0, f_lower_idx), np.arange(f_lower_idx, f_upper_idx, 2)]
+             ),
+             :,
+         ]
+     )
+
+     emg_log_rms = get_emg_power(emg, sampling_rate, epoch_length)
+     output = np.concatenate(
+         [modified_spectrogram, np.tile(emg_log_rms, (EMG_COPIES, 1))]
+     )
+     return output
+
+
+ def get_mixture_values(
+     img: np.array, labels: np.array, brain_state_set: BrainStateSet
+ ) -> (np.array, np.array):
+     """Compute weighted feature means and SDs for mixture z-scoring
+
+     The outputs of this function can be used to standardize features
+     extracted from all recordings from one subject under the same
+     recording conditions. Note that labels must be in "class" format
+     (i.e., integers between 0 and the number of scored states).
+
+     :param img: combined EEG + EMG image - see create_eeg_emg_image()
+     :param labels: brain state labels, in "class" format
+     :param brain_state_set: set of brain state options
+     :return: mixture means, mixture standard deviations
+     """
+
+     means = list()
+     variances = list()
+     mixture_weights = brain_state_set.mixture_weights
+
+     # get feature means, variances by class
+     for i in range(brain_state_set.n_classes):
+         means.append(np.mean(img[:, labels == i], axis=1))
+         variances.append(np.var(img[:, labels == i], axis=1))
+     means = np.array(means)
+     variances = np.array(variances)
+
+     # mixture means are just weighted averages across classes
+     mixture_means = means.T @ mixture_weights
+     # mixture variance is given by the law of total variance:
+     # weighted within-class variance plus weighted squared deviation
+     # of each class mean from the mixture mean
+     mixture_sds = np.sqrt(
+         variances.T @ mixture_weights
+         + (
+             (means - np.tile(mixture_means, (brain_state_set.n_classes, 1))) ** 2
+         ).T
+         @ mixture_weights
+     )
+
+     return mixture_means, mixture_sds
+
+
+ def mixture_z_score_img(
+     img: np.array,
+     brain_state_set: BrainStateSet,
+     labels: np.array = None,
+     mixture_means: np.array = None,
+     mixture_sds: np.array = None,
+ ) -> np.array:
+     """Perform mixture z-scoring on a combined EEG+EMG image
+
+     If brain state labels are provided, they will be used to calculate
+     mixture means and SDs. Otherwise, you must provide those inputs.
+     Note that pixel values in the output are in the 0-1 range and will
+     clip z-scores beyond ABS_MAX_Z_SCORE.
+
+     :param img: combined EEG + EMG image - see create_eeg_emg_image()
+     :param brain_state_set: set of brain state options
+     :param labels: labels, in "class" format
+     :param mixture_means: mixture means
+     :param mixture_sds: mixture standard deviations
+     :return: mixture z-scored EEG + EMG image
+     """
+     if labels is None and (mixture_means is None or mixture_sds is None):
+         raise Exception("must provide either labels or mixture means+SDs")
+     if labels is not None and (mixture_means is not None or mixture_sds is not None):
+         warnings.warn("labels were given, mixture means / SDs will be ignored")
+
+     if labels is not None:
+         mixture_means, mixture_sds = get_mixture_values(
+             img=img, labels=labels, brain_state_set=brain_state_set
+         )
+
+     img = ((img.T - mixture_means) / mixture_sds).T
+     img = (img + ABS_MAX_Z_SCORE) / (2 * ABS_MAX_Z_SCORE)
+     img = np.clip(img, 0, 1)
+
+     return img
+
+
+ def format_img(img: np.array, epochs_per_img: int, add_padding: bool) -> np.array:
+     """Adjust the format of an EEG+EMG image
+
+     This function converts the values in a combined EEG+EMG image to uint8.
+     This is a convenient format both for storing individual images as files,
+     and for using the images as input to a classifier.
+     This function also optionally adds new epochs to the beginning/end of the
+     recording's image so that an image can be created for every epoch. For
+     real-time scoring, padding should not be used.
+
+     :param img: combined EEG + EMG image
+     :param epochs_per_img: number of epochs in each individual image
+     :param add_padding: whether to pad each side by (epochs_per_img - 1) / 2
+     :return: formatted EEG + EMG image
+     """
+     # pad beginning and end
+     if add_padding:
+         pad_width = round((epochs_per_img - 1) / 2)
+         img = np.concatenate(
+             [
+                 np.tile(img[:, 0], (pad_width, 1)).T,
+                 img,
+                 np.tile(img[:, -1], (pad_width, 1)).T,
+             ],
+             axis=1,
+         )
+
+     # use 8-bit values
+     img = np.clip(img * 255, 0, 255)
+     img = img.astype(np.uint8)
+
+     return img
+
+
+ def create_training_images(
+     recordings: list[Recording],
+     output_path: str,
+     epoch_length: int | float,
+     epochs_per_img: int,
+     brain_state_set: BrainStateSet,
+     model_type: str,
+     calibration_fraction: float,
+ ) -> list[int]:
+     """Create training dataset
+
+     By default, the current epoch is located in the central column
+     of pixels in each image. For real-time scoring applications,
+     the current epoch is at the right edge of each image.
+
+     :param recordings: list of recordings in the training set
+     :param output_path: where to store training images
+     :param epoch_length: epoch length, in seconds
+     :param epochs_per_img: number of epochs shown in each image
+     :param brain_state_set: set of brain state options
+     :param model_type: default or real-time
+     :param calibration_fraction: fraction of training data to use for calibration
+     :return: list of the names of any recordings that could not
+         be used to create training images.
+     """
+     # recordings that had to be skipped
+     failed_recordings = list()
+     # image filenames for valid epochs
+     filenames = list()
+     # all valid labels from all valid recordings
+     all_labels = list()
+     # try to load each recording and create training images
+     for i in trange(len(recordings)):
+         recording = recordings[i]
+         try:
+             eeg, emg = load_recording(recording.recording_file)
+             sampling_rate = recording.sampling_rate
+             eeg, emg, sampling_rate = resample_and_standardize(
+                 eeg=eeg,
+                 emg=emg,
+                 sampling_rate=sampling_rate,
+                 epoch_length=epoch_length,
+             )
+
+             labels, _ = load_labels(recording.label_file)
+             labels = brain_state_set.convert_digit_to_class(labels)
+             img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
+             img = mixture_z_score_img(
+                 img=img, brain_state_set=brain_state_set, labels=labels
+             )
+             img = format_img(img=img, epochs_per_img=epochs_per_img, add_padding=True)
+
+             # the model type determines which epochs are used in each image
+             if model_type == DEFAULT_MODEL_TYPE:
+                 # here, j is the index of the current epoch in 'labels'
+                 # and the index of the leftmost epoch in 'img'
+                 for j in range(img.shape[1] - (epochs_per_img - 1)):
+                     if labels[j] is None:
+                         continue
+                     im = img[:, j : (j + epochs_per_img)]
+                     filename = f"recording_{recording.name}_{j}_{labels[j]}.png"
+                     filenames.append(filename)
+                     all_labels.append(labels[j])
+                     Image.fromarray(im).save(os.path.join(output_path, filename))
+             else:
+                 # here, j is the index of the current epoch in 'labels'
+                 # but we throw away a few epochs at the start since they
+                 # would require even more padding on the left side.
+                 one_side_padding = round((epochs_per_img - 1) / 2)
+                 for j in range(one_side_padding, len(labels)):
+                     if labels[j] is None:
+                         continue
+                     im = img[:, (j - one_side_padding) : j + one_side_padding + 1]
+                     filename = f"recording_{recording.name}_{j}_{labels[j]}.png"
+                     filenames.append(filename)
+                     all_labels.append(labels[j])
+                     Image.fromarray(im).save(os.path.join(output_path, filename))
+
+         except Exception as e:
+             print(e)
+             failed_recordings.append(recording.name)
+
+     annotations = pd.DataFrame({FILENAME_COL: filenames, LABEL_COL: all_labels})
+
+     # split into training and calibration sets, if necessary
+     if calibration_fraction > 0:
+         calibration_set = annotations.sample(frac=calibration_fraction)
+         training_set = annotations.drop(calibration_set.index)
+         training_set.to_csv(
+             os.path.join(output_path, ANNOTATIONS_FILENAME),
+             index=False,
+         )
+         calibration_set.to_csv(
+             os.path.join(output_path, CALIBRATION_ANNOTATION_FILENAME),
+             index=False,
+         )
+     else:
+         # annotation file contains info on all training images
+         annotations.to_csv(
+             os.path.join(output_path, ANNOTATIONS_FILENAME),
+             index=False,
+         )
+
+     return failed_recordings
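
Taken together, signal_processing.py implements the feature-extraction pipeline: resample so each epoch holds an integer number of samples, stack the log EEG spectrogram on top of the log EMG RMS, mixture-z-score the result, and convert it to a uint8 image. A minimal sketch of that flow is below; the file paths, the 512 Hz sampling rate, the 2.5 s epoch length, the 9-epoch window, and the pre-built brain_state_set object are illustrative assumptions, not values shipped with the package.

    from accusleepy.fileio import load_labels, load_recording
    from accusleepy.signal_processing import (
        create_eeg_emg_image,
        format_img,
        mixture_z_score_img,
        resample_and_standardize,
    )

    # hypothetical inputs: a recording sampled at 512 Hz, scored in 2.5 s epochs;
    # brain_state_set is assumed to be an already-constructed BrainStateSet
    eeg, emg = load_recording("recording.parquet")  # assumed file path
    labels, _ = load_labels("labels.csv")           # assumed file path
    sampling_rate, epoch_length = 512, 2.5

    # make samples-per-epoch and epochs-per-recording integer-valued
    eeg, emg, sampling_rate = resample_and_standardize(
        eeg=eeg, emg=emg, sampling_rate=sampling_rate, epoch_length=epoch_length
    )

    # one column of pixels per epoch: log spectrogram stacked on log EMG RMS
    img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)

    # standardize with label-derived mixture means/SDs, mapping pixels into [0, 1]
    labels = brain_state_set.convert_digit_to_class(labels)
    img = mixture_z_score_img(img=img, brain_state_set=brain_state_set, labels=labels)

    # uint8 image, padded so every epoch can sit at the center of a 9-epoch window
    img = format_img(img=img, epochs_per_img=9, add_padding=True)
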
accusleepy/temperature_scaling.py
@@ -0,0 +1,157 @@
+ import numpy as np
+ import torch
+ from torch import nn, optim
+ from torch.nn import functional as F
+
+
+ class ModelWithTemperature(nn.Module):
+     """
+     A thin decorator, which wraps a model with temperature scaling
+     model (nn.Module):
+         A classification neural network
+         NB: Output of the neural network should be the classification logits,
+         NOT the softmax (or log softmax)!
+     """
+
+     def __init__(self, model):
+         super(ModelWithTemperature, self).__init__()
+         self.model = model
+         # https://github.com/gpleiss/temperature_scaling/issues/20
+         # for another approach, see https://github.com/gpleiss/temperature_scaling/issues/36
+         self.model.eval()
+         self.temperature = nn.Parameter(torch.ones(1) * 1.5)
+
+     def forward(self, x):
+         logits = self.model(x)
+         return self.temperature_scale(logits)
+
+     def temperature_scale(self, logits):
+         """
+         Perform temperature scaling on logits
+         """
+         # Expand temperature to match the size of logits
+         temperature = self.temperature.unsqueeze(1).expand(
+             logits.size(0), logits.size(1)
+         )
+         return logits / temperature
+
+     # This function probably should live outside of this class, but whatever
+     def set_temperature(self, valid_loader):
+         """
+         Tune the temperature of the model (using the validation set).
+         We're going to set it to optimize NLL.
+         valid_loader (DataLoader): validation set loader
+         """
+         if torch.accelerator.is_available():
+             device = torch.accelerator.current_accelerator().type
+         else:
+             device = "cpu"
+
+         # self.cuda()
+         self.to(device)
+         nll_criterion = nn.CrossEntropyLoss().to(device)  # .cuda()
+         ece_criterion = _ECELoss().to(device)  # .cuda()
+
+         # First: collect all the logits and labels for the validation set
+         logits_list = []
+         labels_list = []
+         prediction_list = []
+         with torch.no_grad():
+             for x, label in valid_loader:
+                 x = x.to(device)  # .cuda()
+                 logits = self.model(x)
+                 logits_list.append(logits)
+                 labels_list.append(label)
+
+                 _, pred = torch.max(logits, 1)
+                 prediction_list.append(pred)
+         logits = torch.cat(logits_list).to(device)  # .cuda()
+         labels = torch.cat(labels_list).to(device)  # .cuda()
+         predictions = torch.cat(prediction_list).to(device)
+
+         # Calculate NLL and ECE before temperature scaling
+         before_temperature_nll = nll_criterion(logits, labels).item()
+         before_temperature_ece = ece_criterion(logits, labels).item()
+         print(
+             "Before temperature - NLL: %.3f, ECE: %.3f"
+             % (before_temperature_nll, before_temperature_ece)
+         )
+
+         # Next: optimize the temperature w.r.t. NLL
+         # https://github.com/gpleiss/temperature_scaling/issues/34
+         optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=100)
+
+         def eval():
+             optimizer.zero_grad()
+             loss = nll_criterion(self.temperature_scale(logits), labels)
+             loss.backward()
+             return loss
+
+         optimizer.step(eval)
+
+         # Calculate NLL and ECE after temperature scaling
+         after_temperature_nll = nll_criterion(
+             self.temperature_scale(logits), labels
+         ).item()
+         after_temperature_ece = ece_criterion(
+             self.temperature_scale(logits), labels
+         ).item()
+         print("Optimal temperature: %.3f" % self.temperature.item())
+         print(
+             "After temperature - NLL: %.3f, ECE: %.3f"
+             % (after_temperature_nll, after_temperature_ece)
+         )
+
+         val_acc = round(
+             100 * np.mean(labels.cpu().numpy() == predictions.cpu().numpy()), 2
+         )
+         print(f"Validation accuracy: {val_acc}%")
+
+         return self
+
+
+ class _ECELoss(nn.Module):
+     """
+     Calculates the Expected Calibration Error of a model.
+     (This isn't necessary for temperature scaling, just a cool metric).
+
+     The input to this loss is the logits of a model, NOT the softmax scores.
+
+     This divides the confidence outputs into equally-sized interval bins.
+     In each bin, we compute the confidence gap:
+
+     bin_gap = | avg_confidence_in_bin - accuracy_in_bin |
+
+     We then return a weighted average of the gaps, based on the number
+     of samples in each bin.
+
+     See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
+     "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
+     2015.
+     """
+
+     def __init__(self, n_bins=15):
+         """
+         n_bins (int): number of confidence interval bins
+         """
+         super(_ECELoss, self).__init__()
+         bin_boundaries = torch.linspace(0, 1, n_bins + 1)
+         self.bin_lowers = bin_boundaries[:-1]
+         self.bin_uppers = bin_boundaries[1:]
+
+     def forward(self, logits, labels):
+         softmaxes = F.softmax(logits, dim=1)
+         confidences, predictions = torch.max(softmaxes, 1)
+         accuracies = predictions.eq(labels)
+
+         ece = torch.zeros(1, device=logits.device)
+         for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
+             # Calculate |confidence - accuracy| in each bin
+             in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
+             prop_in_bin = in_bin.float().mean()
+             if prop_in_bin.item() > 0:
+                 accuracy_in_bin = accuracies[in_bin].float().mean()
+                 avg_confidence_in_bin = confidences[in_bin].mean()
+                 ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
+
+         return ece
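
For context, temperature_scaling.py follows the gpleiss/temperature_scaling approach: a single learned scalar divides the logits so that softmax confidences better match empirical accuracy, without changing the predicted class. A minimal usage sketch follows; the stand-in linear model and random calibration data are assumptions for illustration only.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from accusleepy.temperature_scaling import ModelWithTemperature

    # stand-in for a trained classifier that outputs logits (assumption)
    model = torch.nn.Linear(10, 3)

    # stand-in calibration set: 256 feature vectors with integer class labels
    x = torch.randn(256, 10)
    y = torch.randint(0, 3, (256,))
    valid_loader = DataLoader(TensorDataset(x, y), batch_size=64)

    # fit the single temperature parameter by minimizing NLL with LBFGS;
    # this prints NLL/ECE before and after, plus validation accuracy
    scaled_model = ModelWithTemperature(model).set_temperature(valid_loader)

    # at inference time, forward() divides the logits by the learned temperature,
    # which rescales softmax confidences but never changes the argmax
    with torch.no_grad():
        probs = torch.softmax(scaled_model(x[:4]), dim=1)
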