accusleepy 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accusleepy/__init__.py +0 -0
- accusleepy/__main__.py +4 -0
- accusleepy/bouts.py +142 -0
- accusleepy/brain_state_set.py +89 -0
- accusleepy/classification.py +285 -0
- accusleepy/config.json +24 -0
- accusleepy/constants.py +46 -0
- accusleepy/fileio.py +179 -0
- accusleepy/gui/__init__.py +0 -0
- accusleepy/gui/icons/brightness_down.png +0 -0
- accusleepy/gui/icons/brightness_up.png +0 -0
- accusleepy/gui/icons/double_down_arrow.png +0 -0
- accusleepy/gui/icons/double_up_arrow.png +0 -0
- accusleepy/gui/icons/down_arrow.png +0 -0
- accusleepy/gui/icons/home.png +0 -0
- accusleepy/gui/icons/question.png +0 -0
- accusleepy/gui/icons/save.png +0 -0
- accusleepy/gui/icons/up_arrow.png +0 -0
- accusleepy/gui/icons/zoom_in.png +0 -0
- accusleepy/gui/icons/zoom_out.png +0 -0
- accusleepy/gui/images/primary_window.png +0 -0
- accusleepy/gui/images/viewer_window.png +0 -0
- accusleepy/gui/images/viewer_window_annotated.png +0 -0
- accusleepy/gui/main.py +1494 -0
- accusleepy/gui/manual_scoring.py +1096 -0
- accusleepy/gui/mplwidget.py +386 -0
- accusleepy/gui/primary_window.py +2577 -0
- accusleepy/gui/primary_window.ui +3831 -0
- accusleepy/gui/resources.qrc +16 -0
- accusleepy/gui/resources_rc.py +6710 -0
- accusleepy/gui/text/config_guide.txt +27 -0
- accusleepy/gui/text/main_guide.md +167 -0
- accusleepy/gui/text/manual_scoring_guide.md +23 -0
- accusleepy/gui/viewer_window.py +610 -0
- accusleepy/gui/viewer_window.ui +926 -0
- accusleepy/models.py +108 -0
- accusleepy/multitaper.py +661 -0
- accusleepy/signal_processing.py +469 -0
- accusleepy/temperature_scaling.py +157 -0
- accusleepy-0.6.0.dist-info/METADATA +106 -0
- accusleepy-0.6.0.dist-info/RECORD +42 -0
- accusleepy-0.6.0.dist-info/WHEEL +4 -0

accusleepy/signal_processing.py
@@ -0,0 +1,469 @@
import os
import warnings

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import trange

from accusleepy.brain_state_set import BrainStateSet
from accusleepy.constants import (
    ANNOTATIONS_FILENAME,
    CALIBRATION_ANNOTATION_FILENAME,
    DEFAULT_MODEL_TYPE,
    DOWNSAMPLING_START_FREQ,
    EMG_COPIES,
    FILENAME_COL,
    LABEL_COL,
    MIN_WINDOW_LEN,
    UPPER_FREQ,
)
from accusleepy.fileio import Recording, load_labels, load_recording
from accusleepy.multitaper import spectrogram

# note: scipy is lazily imported

# clip mixture z-scores above and below this level
# in the matlab implementation, I used 4.5
ABS_MAX_Z_SCORE = 3.5
# upper frequency limit when generating EEG spectrograms
SPECTROGRAM_UPPER_FREQ = 64


def resample(
    eeg: np.array, emg: np.array, sampling_rate: int | float, epoch_length: int | float
) -> (np.array, np.array, float):
    """Resample recording so that epochs contain equal numbers of samples

    If the number of samples per epoch is not an integer, epoch-level calculations
    are much more difficult. To avoid this, we can resample the EEG and EMG signals
    and adjust the sampling rate accordingly.

    :param eeg: EEG signal
    :param emg: EMG signal
    :param sampling_rate: original sampling rate, in Hz
    :param epoch_length: epoch length, in seconds
    :return: resampled EEG & EMG and updated sampling rate
    """
    samples_per_epoch = sampling_rate * epoch_length
    if samples_per_epoch % 1 == 0:
        return eeg, emg, sampling_rate

    resampled = list()
    for arr in [eeg, emg]:
        x = np.arange(0, arr.size)
        x_new = np.linspace(
            0,
            arr.size - 1,
            round(arr.size * np.ceil(samples_per_epoch) / samples_per_epoch),
        )
        resampled.append(np.interp(x_new, x, arr))

    eeg = resampled[0]
    emg = resampled[1]
    new_sampling_rate = np.ceil(samples_per_epoch) / samples_per_epoch * sampling_rate
    return eeg, emg, new_sampling_rate
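

# A minimal usage sketch (not part of the package source; the numbers are made
# up for illustration). At 102.4 Hz with 4 s epochs, an epoch would hold 409.6
# samples, so resample() stretches both signals slightly and reports a new
# rate of 102.5 Hz, giving exactly 410 samples per epoch.
def _resample_example():
    rng = np.random.default_rng(0)
    eeg = rng.standard_normal(102_400)
    emg = rng.standard_normal(102_400)
    eeg, emg, new_rate = resample(
        eeg=eeg, emg=emg, sampling_rate=102.4, epoch_length=4
    )
    assert (new_rate * 4) % 1 == 0  # now an integer number of samples per epoch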


def standardize_signal_length(
    eeg: np.array, emg: np.array, sampling_rate: int | float, epoch_length: int | float
) -> (np.array, np.array):
    """Truncate or pad EEG/EMG signals to have an integer number of epochs

    :param eeg: EEG signal
    :param emg: EMG signal
    :param sampling_rate: original sampling rate, in Hz
    :param epoch_length: epoch length, in seconds
    :return: EEG and EMG signals
    """
    # since resample() was called, this will be extremely close to an integer
    samples_per_epoch = round(sampling_rate * epoch_length)

    # pad the signal at the end in case we need more samples
    eeg = np.concatenate((eeg, np.ones(samples_per_epoch) * eeg[-1]))
    emg = np.concatenate((emg, np.ones(samples_per_epoch) * emg[-1]))
    padded_signal_length = eeg.size

    # count samples that don't fit in any epoch
    excess_samples = padded_signal_length % samples_per_epoch
    # we will definitely remove those
    last_index = padded_signal_length - excess_samples
    # and if the last epoch of real data had more than half of
    # its samples missing, delete it
    if excess_samples < samples_per_epoch / 2:
        last_index -= samples_per_epoch

    return eeg[:last_index], emg[:last_index]
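

# A minimal sketch of the truncate-or-pad rule (not part of the package
# source; the lengths are made up). At 100 Hz with 1 s epochs, a 1030-sample
# signal loses its 30-sample partial epoch, while a 1080-sample signal has its
# last epoch padded out to a full 1100 samples.
def _standardize_length_example():
    fs, epoch = 100, 1
    shorter = np.zeros(1030)
    eeg, _ = standardize_signal_length(shorter, shorter.copy(), fs, epoch)
    assert eeg.size == 1000  # partial epoch (< half full) was dropped
    longer = np.zeros(1080)
    eeg, _ = standardize_signal_length(longer, longer.copy(), fs, epoch)
    assert eeg.size == 1100  # partial epoch (>= half full) was padded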


def resample_and_standardize(
    eeg: np.array, emg: np.array, sampling_rate: int | float, epoch_length: int | float
) -> (np.array, np.array, float):
    """Preprocess EEG and EMG signals

    Adjust the length and sampling rate of the EEG and EMG signals so that
    each epoch contains an integer number of samples and each recording
    contains an integer number of epochs.

    :param eeg: EEG signal
    :param emg: EMG signal
    :param sampling_rate: sampling rate, in Hz
    :param epoch_length: epoch length, in seconds
    :return: processed EEG & EMG signals, and the new sampling rate
    """
    eeg, emg, sampling_rate = resample(
        eeg=eeg, emg=emg, sampling_rate=sampling_rate, epoch_length=epoch_length
    )
    eeg, emg = standardize_signal_length(
        eeg=eeg, emg=emg, sampling_rate=sampling_rate, epoch_length=epoch_length
    )
    return eeg, emg, sampling_rate


def create_spectrogram(
    eeg: np.array,
    sampling_rate: int | float,
    epoch_length: int | float,
    time_bandwidth=2,
    n_tapers=3,
) -> (np.array, np.array):
    """Create an EEG spectrogram image

    :param eeg: EEG signal
    :param sampling_rate: sampling rate, in Hz
    :param epoch_length: epoch length, in seconds
    :param time_bandwidth: time-half bandwidth product
    :param n_tapers: number of DPSS tapers to use
    :return: spectrogram and its frequency axis
    """
    window_length_sec = max(MIN_WINDOW_LEN, epoch_length)
    # pad the EEG signal so that the first spectrogram window is centered
    # on the first epoch
    # it's possible there's some jank here, if this isn't close to an integer
    pad_length = round((sampling_rate * (window_length_sec - epoch_length) / 2))
    padded_eeg = np.concatenate(
        [eeg[:pad_length][::-1], eeg, eeg[(len(eeg) - pad_length) :][::-1]]
    )

    spec, _, f = spectrogram(
        padded_eeg,
        sampling_rate,
        frequency_range=[0, SPECTROGRAM_UPPER_FREQ],
        time_bandwidth=time_bandwidth,
        num_tapers=n_tapers,
        window_params=[window_length_sec, epoch_length],
        min_nfft=0,
        detrend_opt="off",
        multiprocess=True,
        plot_on=False,
        return_fig=False,
        verbose=False,
    )

    # resample frequencies for consistency
    target_frequencies = np.arange(0, SPECTROGRAM_UPPER_FREQ, 1 / MIN_WINDOW_LEN)
    freq_idx = list()
    for i in target_frequencies:
        freq_idx.append(np.argmin(np.abs(f - i)))
    f = f[freq_idx]
    spec = spec[freq_idx, :]

    return spec, f
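

# A minimal shape sketch (not part of the package source; the rate and epoch
# length are made up). One spectrogram column is produced per epoch, and the
# frequency axis is snapped to a fixed grid (steps of 1 / MIN_WINDOW_LEN up to
# SPECTROGRAM_UPPER_FREQ) so images from different recordings line up.
def _spectrogram_shape_example():
    fs, epoch, n_epochs = 128, 4, 30
    eeg = np.random.default_rng(0).standard_normal(fs * epoch * n_epochs)
    spec, f = create_spectrogram(eeg, fs, epoch)
    # spec is (frequency, time): len(f) rows, roughly one column per epoch
    return spec, f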


def get_emg_power(
    emg: np.array, sampling_rate: int | float, epoch_length: int | float
) -> np.array:
    """Calculate EMG power for each epoch

    This applies a 20-50 Hz bandpass filter to the EMG, calculates the RMS
    in each epoch, and takes the log of the result.

    :param emg: EMG signal
    :param sampling_rate: sampling rate, in Hz
    :param epoch_length: epoch length, in seconds
    :return: EMG "power" for each epoch
    """
    from scipy.signal import butter, filtfilt

    # filter parameters
    order = 8
    bp_lower = 20
    bp_upper = 50

    b, a = butter(
        N=order,
        Wn=[bp_lower, bp_upper],
        btype="bandpass",
        output="ba",
        fs=sampling_rate,
    )
    filtered = filtfilt(b, a, x=emg, padlen=int(np.ceil(sampling_rate)))

    # since resample() was called, this will be extremely close to an integer
    samples_per_epoch = round(sampling_rate * epoch_length)
    reshaped = np.reshape(
        filtered,
        [round(len(emg) / samples_per_epoch), samples_per_epoch],
    )
    rms = np.sqrt(np.mean(np.power(reshaped, 2), axis=1))

    return np.log(rms)
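

# A rough numeric sketch (not part of the package source; the signal is made
# up). A 35 Hz sine sits inside the 20-50 Hz passband, so it survives
# filtering nearly unchanged, and each epoch's value is approximately
# log(RMS) = log(a / sqrt(2)) for amplitude a.
def _emg_power_example():
    fs, epoch = 200, 2
    t = np.arange(0, 20, 1 / fs)  # 10 epochs of a 35 Hz sine, amplitude 3
    emg = 3 * np.sin(2 * np.pi * 35 * t)
    power = get_emg_power(emg, fs, epoch)
    expected = np.log(3 / np.sqrt(2))
    return power, expected  # power should hover near `expected`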


def create_eeg_emg_image(
    eeg: np.array,
    emg: np.array,
    sampling_rate: int | float,
    epoch_length: int | float,
) -> np.array:
    """Stack EEG spectrogram and EMG power into an image

    This assumes that each epoch contains an integer number of samples and
    each recording contains an integer number of epochs. Note that a log
    transformation is applied to the spectrogram.

    :param eeg: EEG signal
    :param emg: EMG signal
    :param sampling_rate: sampling rate, in Hz
    :param epoch_length: epoch length, in seconds
    :return: combined EEG + EMG image for a recording
    """
    spec, f = create_spectrogram(eeg, sampling_rate, epoch_length)
    f_lower_idx = sum(f < DOWNSAMPLING_START_FREQ)
    f_upper_idx = sum(f < UPPER_FREQ)

    modified_spectrogram = np.log(
        spec[
            np.concatenate(
                [np.arange(0, f_lower_idx), np.arange(f_lower_idx, f_upper_idx, 2)]
            ),
            :,
        ]
    )

    emg_log_rms = get_emg_power(emg, sampling_rate, epoch_length)
    output = np.concatenate(
        [modified_spectrogram, np.tile(emg_log_rms, (EMG_COPIES, 1))]
    )
    return output
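

# A layout sketch (not part of the package source; the arguments are
# hypothetical stand-ins). The output image has one column per epoch: its top
# rows are the log spectrogram (every frequency bin below
# DOWNSAMPLING_START_FREQ, then every other bin up to UPPER_FREQ), and its
# bottom EMG_COPIES rows all repeat the EMG log-RMS trace.
def _image_layout_example(eeg, emg, fs, epoch):
    img = create_eeg_emg_image(eeg, emg, fs, epoch)
    emg_rows = img[-EMG_COPIES:, :]
    # every EMG row is an identical copy of the log-RMS values
    assert np.allclose(emg_rows, emg_rows[0])
    return img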


def get_mixture_values(
    img: np.array, labels: np.array, brain_state_set: BrainStateSet
) -> (np.array, np.array):
    """Compute weighted feature means and SDs for mixture z-scoring

    The outputs of this function can be used to standardize features
    extracted from all recordings from one subject under the same
    recording conditions. Note that labels must be in "class" format
    (i.e., integers between 0 and the number of scored states).

    :param img: combined EEG + EMG image - see create_eeg_emg_image()
    :param labels: brain state labels, in "class" format
    :param brain_state_set: set of brain state options
    :return: mixture means, mixture standard deviations
    """

    means = list()
    variances = list()
    mixture_weights = brain_state_set.mixture_weights

    # get feature means, variances by class
    for i in range(brain_state_set.n_classes):
        means.append(np.mean(img[:, labels == i], axis=1))
        variances.append(np.var(img[:, labels == i], axis=1))
    means = np.array(means)
    variances = np.array(variances)

    # mixture means are just weighted averages across classes
    mixture_means = means.T @ mixture_weights
    # mixture variance is given by the law of total variance
    mixture_sds = np.sqrt(
        variances.T @ mixture_weights
        + ((means - np.tile(mixture_means, (brain_state_set.n_classes, 1))) ** 2).T
        @ mixture_weights
    )

    return mixture_means, mixture_sds
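

# The weighting above follows the law of total variance; a small worked case
# (made-up numbers, not from the package): two classes with weights 0.5/0.5,
# per-class means 1 and 3, and per-class variances 1 and 1 give
#   mixture mean = 0.5*1 + 0.5*3 = 2
#   mixture var  = E[Var] + Var[E] = 1 + (0.5*(1-2)**2 + 0.5*(3-2)**2) = 2
# so the mixture SD is sqrt(2), larger than either within-class SD.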


def mixture_z_score_img(
    img: np.array,
    brain_state_set: BrainStateSet,
    labels: np.array = None,
    mixture_means: np.array = None,
    mixture_sds: np.array = None,
) -> np.array:
    """Perform mixture z-scoring on a combined EEG+EMG image

    If brain state labels are provided, they will be used to calculate
    mixture means and SDs. Otherwise, you must provide those inputs.
    Note that pixel values in the output are in the 0-1 range and will
    clip z-scores beyond ABS_MAX_Z_SCORE.

    :param img: combined EEG + EMG image - see create_eeg_emg_image()
    :param brain_state_set: set of brain state options
    :param labels: labels, in "class" format
    :param mixture_means: mixture means
    :param mixture_sds: mixture standard deviations
    :return: mixture z-scored image
    """
    if labels is None and (mixture_means is None or mixture_sds is None):
        raise Exception("must provide either labels or mixture means+SDs")
    if labels is not None and ((mixture_means is not None) ^ (mixture_sds is not None)):
        warnings.warn("labels were given, mixture means / SDs will be ignored")

    if labels is not None:
        mixture_means, mixture_sds = get_mixture_values(
            img=img, labels=labels, brain_state_set=brain_state_set
        )

    img = ((img.T - mixture_means) / mixture_sds).T
    img = (img + ABS_MAX_Z_SCORE) / (2 * ABS_MAX_Z_SCORE)
    img = np.clip(img, 0, 1)

    return img
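

# The scaling above maps z-scores onto pixel intensities; as a worked example
# (values implied by ABS_MAX_Z_SCORE = 3.5): z = -3.5 -> 0.0, z = 0 -> 0.5,
# z = +3.5 -> 1.0, and anything beyond +/-3.5 is clipped to the edges.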


def format_img(img: np.array, epochs_per_img: int, add_padding: bool) -> np.array:
    """Adjust the format of an EEG+EMG image

    This function converts the values in a combined EEG+EMG image to uint8.
    This is a convenient format both for storing individual images as files,
    and for using the images as input to a classifier.
    This function also optionally adds new epochs to the beginning/end of the
    recording's image so that an image can be created for every epoch. For
    real-time scoring, padding should not be used.

    :param img: combined EEG + EMG image
    :param epochs_per_img: number of epochs in each individual image
    :param add_padding: whether to pad each side by (epochs_per_img - 1) / 2
    :return: formatted EEG + EMG image
    """
    # pad beginning and end
    if add_padding:
        pad_width = round((epochs_per_img - 1) / 2)
        img = np.concatenate(
            [
                np.tile(img[:, 0], (pad_width, 1)).T,
                img,
                np.tile(img[:, -1], (pad_width, 1)).T,
            ],
            axis=1,
        )

    # use 8-bit values
    img = np.clip(img * 255, 0, 255)
    img = img.astype(np.uint8)

    return img
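

# A minimal sketch (not part of the package source; the array is made up).
# With 9 epochs per image, padding prepends and appends 4 copies of the edge
# columns, so a recording with 10 epochs yields an 18-column image from which
# 10 windows can be cut.
def _format_img_example():
    img = np.linspace(0, 1, 50).reshape(5, 10)  # 10 epochs, values in [0, 1]
    out = format_img(img=img, epochs_per_img=9, add_padding=True)
    assert out.shape == (5, 18) and out.dtype == np.uint8
    return out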


def create_training_images(
    recordings: list[Recording],
    output_path: str,
    epoch_length: int | float,
    epochs_per_img: int,
    brain_state_set: BrainStateSet,
    model_type: str,
    calibration_fraction: float,
) -> list[int]:
    """Create training dataset

    By default, the current epoch is located in the central column
    of pixels in each image. For real-time scoring applications,
    the current epoch is at the right edge of each image.

    :param recordings: list of recordings in the training set
    :param output_path: where to store training images
    :param epoch_length: epoch length, in seconds
    :param epochs_per_img: number of epochs shown in each image
    :param brain_state_set: set of brain state options
    :param model_type: default or real-time
    :param calibration_fraction: fraction of training data to use for calibration
    :return: list of the names of any recordings that could not
        be used to create training images.
    """
    # recordings that had to be skipped
    failed_recordings = list()
    # image filenames for valid epochs
    filenames = list()
    # all valid labels from all valid recordings
    all_labels = list()
    # try to load each recording and create training images
    for i in trange(len(recordings)):
        recording = recordings[i]
        try:
            eeg, emg = load_recording(recording.recording_file)
            sampling_rate = recording.sampling_rate
            eeg, emg, sampling_rate = resample_and_standardize(
                eeg=eeg,
                emg=emg,
                sampling_rate=sampling_rate,
                epoch_length=epoch_length,
            )

            labels, _ = load_labels(recording.label_file)
            labels = brain_state_set.convert_digit_to_class(labels)
            img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
            img = mixture_z_score_img(
                img=img, brain_state_set=brain_state_set, labels=labels
            )
            img = format_img(img=img, epochs_per_img=epochs_per_img, add_padding=True)

            # the model type determines which epochs are used in each image
            if model_type == DEFAULT_MODEL_TYPE:
                # here, j is the index of the current epoch in 'labels'
                # and the index of the leftmost epoch in 'img'
                for j in range(img.shape[1] - (epochs_per_img - 1)):
                    if labels[j] is None:
                        continue
                    im = img[:, j : (j + epochs_per_img)]
                    filename = f"recording_{recording.name}_{j}_{labels[j]}.png"
                    filenames.append(filename)
                    all_labels.append(labels[j])
                    Image.fromarray(im).save(os.path.join(output_path, filename))
            else:
                # here, j is the index of the current epoch in 'labels'
                # but we throw away a few epochs at the start since they
                # would require even more padding on the left side.
                one_side_padding = round((epochs_per_img - 1) / 2)
                for j in range(one_side_padding, len(labels)):
                    if labels[j] is None:
                        continue
                    im = img[:, (j - one_side_padding) : j + one_side_padding + 1]
                    filename = f"recording_{recording.name}_{j}_{labels[j]}.png"
                    filenames.append(filename)
                    all_labels.append(labels[j])
                    Image.fromarray(im).save(os.path.join(output_path, filename))

        except Exception as e:
            print(e)
            failed_recordings.append(recording.name)

    annotations = pd.DataFrame({FILENAME_COL: filenames, LABEL_COL: all_labels})

    # split into training and calibration sets, if necessary
    if calibration_fraction > 0:
        calibration_set = annotations.sample(frac=calibration_fraction)
        training_set = annotations.drop(calibration_set.index)
        training_set.to_csv(
            os.path.join(output_path, ANNOTATIONS_FILENAME),
            index=False,
        )
        calibration_set.to_csv(
            os.path.join(output_path, CALIBRATION_ANNOTATION_FILENAME),
            index=False,
        )
    else:
        # annotation file contains info on all training images
        annotations.to_csv(
            os.path.join(output_path, ANNOTATIONS_FILENAME),
            index=False,
        )

    return failed_recordings

accusleepy/temperature_scaling.py
@@ -0,0 +1,157 @@
import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F


class ModelWithTemperature(nn.Module):
    """
    A thin decorator, which wraps a model with temperature scaling
    model (nn.Module):
        A classification neural network
        NB: Output of the neural network should be the classification logits,
            NOT the softmax (or log softmax)!
    """

    def __init__(self, model):
        super(ModelWithTemperature, self).__init__()
        self.model = model
        # https://github.com/gpleiss/temperature_scaling/issues/20
        # for another approach, see https://github.com/gpleiss/temperature_scaling/issues/36
        self.model.eval()
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, x):
        logits = self.model(x)
        return self.temperature_scale(logits)

    def temperature_scale(self, logits):
        """
        Perform temperature scaling on logits
        """
        # Expand temperature to match the size of logits
        temperature = self.temperature.unsqueeze(1).expand(
            logits.size(0), logits.size(1)
        )
        return logits / temperature

    # This function probably should live outside of this class, but whatever
    def set_temperature(self, valid_loader):
        """
        Tune the temperature of the model (using the validation set).
        We're going to set it to optimize NLL.
        valid_loader (DataLoader): validation set loader
        """
        if torch.accelerator.is_available():
            device = torch.accelerator.current_accelerator().type
        else:
            device = "cpu"

        # self.cuda()
        self.to(device)
        nll_criterion = nn.CrossEntropyLoss().to(device)  # .cuda()
        ece_criterion = _ECELoss().to(device)  # .cuda()

        # First: collect all the logits and labels for the validation set
        logits_list = []
        labels_list = []
        prediction_list = []
        with torch.no_grad():
            for x, label in valid_loader:
                x = x.to(device)  # .cuda()
                logits = self.model(x)
                logits_list.append(logits)
                labels_list.append(label)

                _, pred = torch.max(logits, 1)
                prediction_list.append(pred)
        logits = torch.cat(logits_list).to(device)  # .cuda()
        labels = torch.cat(labels_list).to(device)  # .cuda()
        predictions = torch.cat(prediction_list).to(device)

        # Calculate NLL and ECE before temperature scaling
        before_temperature_nll = nll_criterion(logits, labels).item()
        before_temperature_ece = ece_criterion(logits, labels).item()
        print(
            "Before temperature - NLL: %.3f, ECE: %.3f"
            % (before_temperature_nll, before_temperature_ece)
        )

        # Next: optimize the temperature w.r.t. NLL
        # https://github.com/gpleiss/temperature_scaling/issues/34
        optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=100)

        def eval():
            optimizer.zero_grad()
            loss = nll_criterion(self.temperature_scale(logits), labels)
            loss.backward()
            return loss

        optimizer.step(eval)

        # Calculate NLL and ECE after temperature scaling
        after_temperature_nll = nll_criterion(
            self.temperature_scale(logits), labels
        ).item()
        after_temperature_ece = ece_criterion(
            self.temperature_scale(logits), labels
        ).item()
        print("Optimal temperature: %.3f" % self.temperature.item())
        print(
            "After temperature - NLL: %.3f, ECE: %.3f"
            % (after_temperature_nll, after_temperature_ece)
        )

        val_acc = round(
            100 * np.mean(labels.cpu().numpy() == predictions.cpu().numpy()), 2
        )
        print(f"Validation accuracy: {val_acc}%")

        return self
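

# A minimal usage sketch (not part of the package source); `trained_model`
# and `val_loader` are hypothetical stand-ins for a trained classifier and
# its validation DataLoader. Dividing logits by a single learned temperature
# rescales confidence without changing the argmax prediction.
def _temperature_scaling_example(trained_model, val_loader):
    scaled = ModelWithTemperature(trained_model)
    scaled.set_temperature(val_loader)  # fits `temperature` by LBFGS on NLL
    return scaled  # calling scaled(x) now returns temperature-scaled logits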


class _ECELoss(nn.Module):
    """
    Calculates the Expected Calibration Error of a model.
    (This isn't necessary for temperature scaling, just a cool metric).

    The input to this loss is the logits of a model, NOT the softmax scores.

    This divides the confidence outputs into equally-sized interval bins.
    In each bin, we compute the confidence gap:

        bin_gap = | avg_confidence_in_bin - accuracy_in_bin |

    We then return a weighted average of the gaps, based on the number
    of samples in each bin

    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """

    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Calculate |confidence - accuracy| in each bin
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece
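

# A worked ECE sketch (made-up numbers, not from the package): with 15 bins,
# suppose all 100 samples land in the (0.8, 0.8667] confidence bin with
# average confidence 0.83 but only 75% accuracy; then
#   ECE = |0.83 - 0.75| * (100 / 100) = 0.08
# i.e. the model is overconfident by 8 points on average.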