accusleepy 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accusleepy/classification.py +49 -15
- accusleepy/config.json +15 -1
- accusleepy/constants.py +29 -2
- accusleepy/fileio.py +107 -33
- accusleepy/gui/images/primary_window.png +0 -0
- accusleepy/gui/images/viewer_window.png +0 -0
- accusleepy/gui/images/viewer_window_annotated.png +0 -0
- accusleepy/gui/main.py +220 -42
- accusleepy/gui/manual_scoring.py +38 -8
- accusleepy/gui/mplwidget.py +54 -29
- accusleepy/gui/primary_window.py +937 -254
- accusleepy/gui/primary_window.ui +3182 -2227
- accusleepy/gui/resources.qrc +1 -1
- accusleepy/gui/text/main_guide.md +18 -12
- accusleepy/gui/viewer_window.py +19 -7
- accusleepy/gui/viewer_window.ui +34 -2
- accusleepy/models.py +11 -1
- accusleepy/signal_processing.py +40 -17
- accusleepy/temperature_scaling.py +157 -0
- {accusleepy-0.5.0.dist-info → accusleepy-0.7.0.dist-info}/METADATA +11 -2
- accusleepy-0.7.0.dist-info/RECORD +41 -0
- {accusleepy-0.5.0.dist-info → accusleepy-0.7.0.dist-info}/WHEEL +1 -1
- accusleepy/gui/text/config_guide.txt +0 -29
- accusleepy-0.5.0.dist-info/RECORD +0 -41
accusleepy/gui/text/main_guide.md
CHANGED
@@ -29,7 +29,8 @@ At this point, you can score the recordings manually.
 5. Score all recordings automatically using the classifier
 
 By default, there are three brain state options: REM, wake, and NREM.
-If you want to change this configuration, click the "Settings" tab
+If you want to change this configuration, click the "Settings" tab and
+choose "Brain states" from the drop-down menu.
 Note that if you change the configuration, you might be unable to load
 existing labels and calibration data, and you may need to train a new
 classification model.
@@ -45,12 +46,14 @@ To select a file in the primary interface, you can either use the
 associated button, or drag/drop the file into the empty box adjacent
 to the button.
 - Recording file: a .parquet or .csv file containing one
-column of EEG
-The column names must be eeg and emg
-- Label file: a .csv file with one column titled brain_state
+column of EEG data and one column of EMG data.
+The column names must be **eeg** and **emg**.
+- Label file: a .csv file with one column titled **brain_state**
 with entries that are either the undefined brain state (by default, this is -1)
 or one of the digits in your brain state configuration.
 By default, these are 1-3 where REM = 1, wake = 2, NREM = 3.
+Optionally, there can be a second column named **confidence_score**
+containing classification confidence scores between 0 and 1.
 - Calibration data file: required for automated scoring. See Section 4
 for details. These have .csv format.
 - Trained classification model: required for automated scoring. See
@@ -126,15 +129,17 @@ To train a new model on your own data:
 type models, this must be an odd number. In general, about 30
 seconds worth of data is enough.
 4. Choose whether the images used to train the model should be
-deleted once training is complete.
-leave this box checked.)
+deleted once training is complete. It's generally best to
+leave this box checked. A (temporary) folder for these files
+will be created in the same location as the trained model.
 5. Choose whether to create a "default" or "real-time"-type model.
 Note that scoring recordings in the primary interface requires
 a default-type model.
-6.
-
-
-
+6. Choose whether to calibrate the model. This process uses part
+of the training data to make the model's confidence scores
+more accurately reflect the probability that the output
+labels are accurate. If using calibration, choose what percent
+of the training data to set aside for calibration.
 7. Click "Train classification model" and enter a
 filename for the trained model. Training can take some time.
 The console will display progress updates.
@@ -153,10 +158,11 @@ Instructions for automatic scoring using this interface are below.
 4. If you wish to preserve any existing labels in the label file, and
 only overwrite undefined epochs, check the box labeled
 "Only overwrite undefined epochs".
-5.
+5. Choose whether to save confidence scores to the label files.
+6. Set the minimum bout length, in seconds. A typical value could be 5.
 Following automatic labeling, any brain state bout shorter than this
 duration will be reassigned to the surrounding state (if the states
 on either side of the bout are the same).
-
+7. Click "Score all automatically" to score all recordings in the
 recording list. To inspect the results, select a recording
 in the list and click "Score manually".
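The updated guide describes the label file layout:

- a required **brain_state** column, holding configured digits or -1 for the default undefined state;
- an optional **confidence_score** column, with values between 0 and 1.

The sketch below writes and reads such a file, assuming pandas. It returns `None` for the scores when the column is absent. `write_label_file` and `read_label_file` are hypothetical helpers, not AccuSleePy's own `load_labels`. That function does appear to return a pair in 0.7.0, judging by the `labels, _ = load_labels(...)` call in `signal_processing.py`.

```python
import io

import pandas as pd

BRAIN_STATE_COL = "brain_state"  # required column, per the guide
CONFIDENCE_COL = "confidence_score"  # optional column, per the guide


def write_label_file(path_or_buf, labels, confidence_scores=None):
    """Write labels (and optionally confidence scores) as a CSV label file."""
    frame = pd.DataFrame({BRAIN_STATE_COL: labels})
    if confidence_scores is not None:
        frame[CONFIDENCE_COL] = confidence_scores
    frame.to_csv(path_or_buf, index=False)


def read_label_file(path_or_buf):
    """Return (labels, confidence_scores); scores are None if the column is absent."""
    frame = pd.read_csv(path_or_buf)
    labels = frame[BRAIN_STATE_COL].to_numpy()
    if CONFIDENCE_COL not in frame.columns:
        return labels, None
    scores = frame[CONFIDENCE_COL].to_numpy(dtype=float)
    # NaN compares as False, so missing scores are not rejected by this check
    if ((scores < 0) | (scores > 1)).any():
        raise ValueError("confidence scores must lie between 0 and 1")
    return labels, scores


# Round trip through an in-memory buffer (default digits: REM = 1, wake = 2, NREM = 3)
buffer = io.StringIO()
write_label_file(buffer, [2, 2, 3, -1], [0.98, 0.91, 0.77, 0.0])
buffer.seek(0)
labels, scores = read_label_file(buffer)
```

Keeping the scores as a separate, optional return value means older label files without the new column still load unchanged.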
accusleepy/gui/viewer_window.py
CHANGED
@@ -3,7 +3,7 @@
 ################################################################################
 ## Form generated from reading UI file 'viewer_window.ui'
 ##
-## Created by: Qt User Interface Compiler version 6.
+## Created by: Qt User Interface Compiler version 6.7.3
 ##
 ## WARNING! All changes made in this file will be lost when recompiling UI file!
 ################################################################################
@@ -90,7 +90,7 @@ class Ui_ViewerWindow(object):
         self.top_plot_buttons = QVBoxLayout()
         self.top_plot_buttons.setObjectName("top_plot_buttons")
         self.topcontroltopspacer = QSpacerItem(
-            5, 10, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.
+            5, 10, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Preferred
         )
 
         self.top_plot_buttons.addItem(self.topcontroltopspacer)
@@ -142,6 +142,12 @@ class Ui_ViewerWindow(object):
 
         self.zoom_and_brightness = QVBoxLayout()
         self.zoom_and_brightness.setObjectName("zoom_and_brightness")
+        self.verticalSpacer_6 = QSpacerItem(
+            5, 5, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Preferred
+        )
+
+        self.zoom_and_brightness.addItem(self.verticalSpacer_6)
+
         self.brightness_buttons = QHBoxLayout()
         self.brightness_buttons.setObjectName("brightness_buttons")
         self.specbrighter = QPushButton(ViewerWindow)
@@ -227,6 +233,12 @@ class Ui_ViewerWindow(object):
 
         self.zoom_and_brightness.addLayout(self.zoom_buttons)
 
+        self.verticalSpacer_7 = QSpacerItem(
+            5, 5, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Preferred
+        )
+
+        self.zoom_and_brightness.addItem(self.verticalSpacer_7)
+
         self.top_plot_buttons.addLayout(self.zoom_and_brightness)
 
         self.topcontrolbottomspacer = QSpacerItem(
@@ -235,11 +247,11 @@ class Ui_ViewerWindow(object):
 
         self.top_plot_buttons.addItem(self.topcontrolbottomspacer)
 
-        self.top_plot_buttons.setStretch(0,
-        self.top_plot_buttons.setStretch(1,
-        self.top_plot_buttons.setStretch(2,
-        self.top_plot_buttons.setStretch(3,
-        self.top_plot_buttons.setStretch(4,
+        self.top_plot_buttons.setStretch(0, 2)
+        self.top_plot_buttons.setStretch(1, 5)
+        self.top_plot_buttons.setStretch(2, 5)
+        self.top_plot_buttons.setStretch(3, 12)
+        self.top_plot_buttons.setStretch(4, 12)
 
         self.all_controls.addLayout(self.top_plot_buttons)
 
accusleepy/gui/viewer_window.ui
CHANGED
@@ -75,14 +75,14 @@
 <enum>QLayout::SizeConstraint::SetDefaultConstraint</enum>
 </property>
 <item>
-<layout class="QVBoxLayout" name="top_plot_buttons" stretch="
+<layout class="QVBoxLayout" name="top_plot_buttons" stretch="2,5,5,12,12">
 <item>
 <spacer name="topcontroltopspacer">
 <property name="orientation">
 <enum>Qt::Orientation::Vertical</enum>
 </property>
 <property name="sizeType">
-<enum>QSizePolicy::Policy::
+<enum>QSizePolicy::Policy::Preferred</enum>
 </property>
 <property name="sizeHint" stdset="0">
 <size>
@@ -184,6 +184,22 @@
 </item>
 <item>
 <layout class="QVBoxLayout" name="zoom_and_brightness">
+<item>
+<spacer name="verticalSpacer_6">
+<property name="orientation">
+<enum>Qt::Orientation::Vertical</enum>
+</property>
+<property name="sizeType">
+<enum>QSizePolicy::Policy::Preferred</enum>
+</property>
+<property name="sizeHint" stdset="0">
+<size>
+<width>5</width>
+<height>5</height>
+</size>
+</property>
+</spacer>
+</item>
 <item>
 <layout class="QHBoxLayout" name="brightness_buttons">
 <item>
@@ -361,6 +377,22 @@
 </item>
 </layout>
 </item>
+<item>
+<spacer name="verticalSpacer_7">
+<property name="orientation">
+<enum>Qt::Orientation::Vertical</enum>
+</property>
+<property name="sizeType">
+<enum>QSizePolicy::Policy::Preferred</enum>
+</property>
+<property name="sizeHint" stdset="0">
+<size>
+<width>5</width>
+<height>5</height>
+</size>
+</property>
+</spacer>
+</item>
 </layout>
 </item>
 <item>
accusleepy/models.py
CHANGED
@@ -10,6 +10,7 @@ from accusleepy.constants import (
     MIN_WINDOW_LEN,
     UPPER_FREQ,
 )
+from accusleepy.temperature_scaling import ModelWithTemperature
 
 # height in pixels of each training image
 IMAGE_HEIGHT = (
@@ -56,20 +57,23 @@ def save_model(
     epochs_per_img: int,
     model_type: str,
     brain_state_set: BrainStateSet,
+    is_calibrated: bool,
 ) -> None:
     """Save classification model and its metadata
 
     :param model: classification model
+    :param filename: filename
     :param epoch_length: epoch length used when training the model
     :param epochs_per_img: number of epochs in each model input
     :param model_type: default or real-time
     :param brain_state_set: set of brain state options
-    :param
+    :param is_calibrated: whether the model has been calibrated
     """
     state_dict = model.state_dict()
     state_dict.update({"epoch_length": epoch_length})
     state_dict.update({"epochs_per_img": epochs_per_img})
     state_dict.update({"model_type": model_type})
+    state_dict.update({"is_calibrated": is_calibrated})
     state_dict.update(
         {BRAIN_STATES_KEY: brain_state_set.to_output_dict()[BRAIN_STATES_KEY]}
     )
@@ -90,9 +94,15 @@ def load_model(filename: str) -> tuple[SSANN, int | float, int, str, dict]:
     epoch_length = state_dict.pop("epoch_length")
     epochs_per_img = state_dict.pop("epochs_per_img")
     model_type = state_dict.pop("model_type")
+    if "is_calibrated" in state_dict:
+        is_calibrated = state_dict.pop("is_calibrated")
+    else:
+        is_calibrated = False
     brain_states = state_dict.pop(BRAIN_STATES_KEY)
     n_classes = len([b for b in brain_states if b["is_scored"]])
 
     model = SSANN(n_classes=n_classes)
+    if is_calibrated:
+        model = ModelWithTemperature(model)
     model.load_state_dict(state_dict)
     return model, epoch_length, epochs_per_img, model_type, brain_states
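In the new `save_model`/`load_model` code, metadata such as `is_calibrated` is stored in the same dictionary as the weights. Model files written before calibration existed simply lack that key, which is why `load_model` needs the `if "is_calibrated" in state_dict` fallback.

The framework-free sketch below shows that unpacking step. It uses `dict.pop(key, default)` as a compact equivalent of the if/else. Some details are assumptions:

- It works on a plain dictionary with placeholder weight entries rather than a PyTorch state dict.
- `split_model_metadata` and `count_scored_classes` are hypothetical names.
- The `"brain_states"` key is assumed; the diff shows only the constant `BRAIN_STATES_KEY`, not its value.

```python
BRAIN_STATES_KEY = "brain_states"  # assumed value; the real constant is not shown
METADATA_KEYS = ("epoch_length", "epochs_per_img", "model_type")


def split_model_metadata(state_dict):
    """Separate metadata entries from weights without mutating the input."""
    weights = dict(state_dict)  # shallow copy so the caller's dict is untouched
    metadata = {key: weights.pop(key) for key in METADATA_KEYS}
    # Files saved before calibration support have no "is_calibrated" entry
    metadata["is_calibrated"] = weights.pop("is_calibrated", False)
    metadata[BRAIN_STATES_KEY] = weights.pop(BRAIN_STATES_KEY)
    return metadata, weights


def count_scored_classes(brain_states):
    """Number of model outputs: only states with is_scored set are counted."""
    return sum(1 for state in brain_states if state["is_scored"])


# A 0.5.0-style file: no "is_calibrated" key; the weight entry is a placeholder
old_style = {
    "epoch_length": 2.5,
    "epochs_per_img": 9,
    "model_type": "default",
    BRAIN_STATES_KEY: [
        {"name": "REM", "is_scored": True},
        {"name": "wake", "is_scored": True},
        {"name": "NREM", "is_scored": True},
    ],
    "conv1.weight": "<tensor>",
}
metadata, weights = split_model_metadata(old_style)
```

In the real function, order matters. A calibrated model's weights were saved from a `ModelWithTemperature` wrapper, so their keys are `model.*` plus `temperature`. That is why, as the diff does, the bare `SSANN` must be wrapped before `load_state_dict` is called.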
accusleepy/signal_processing.py
CHANGED
@@ -9,6 +9,7 @@ from tqdm import trange
 from accusleepy.brain_state_set import BrainStateSet
 from accusleepy.constants import (
     ANNOTATIONS_FILENAME,
+    CALIBRATION_ANNOTATION_FILENAME,
     DEFAULT_MODEL_TYPE,
     DOWNSAMPLING_START_FREQ,
     EMG_COPIES,
@@ -17,7 +18,7 @@ from accusleepy.constants import (
     MIN_WINDOW_LEN,
     UPPER_FREQ,
 )
-from accusleepy.fileio import Recording, load_labels, load_recording
+from accusleepy.fileio import Recording, load_labels, load_recording, EMGFilter
 from accusleepy.multitaper import spectrogram
 
 # note: scipy is lazily imported
@@ -171,7 +172,10 @@ def create_spectrogram(
 
 
 def get_emg_power(
-    emg: np.array,
+    emg: np.array,
+    sampling_rate: int | float,
+    epoch_length: int | float,
+    emg_filter: EMGFilter,
 ) -> np.array:
     """Calculate EMG power for each epoch
 
@@ -181,18 +185,14 @@ def get_emg_power(
     :param emg: EMG signal
     :param sampling_rate: sampling rate, in Hz
     :param epoch_length: epoch length, in seconds
+    :param emg_filter: EMG filter parameters
     :return: EMG "power" for each epoch
     """
     from scipy.signal import butter, filtfilt
 
-    # filter parameters
-    order = 8
-    bp_lower = 20
-    bp_upper = 50
-
     b, a = butter(
-        N=order,
-        Wn=[bp_lower, bp_upper],
+        N=emg_filter.order,
+        Wn=[emg_filter.bp_lower, emg_filter.bp_upper],
         btype="bandpass",
         output="ba",
         fs=sampling_rate,
@@ -215,6 +215,7 @@ def create_eeg_emg_image(
     emg: np.array,
     sampling_rate: int | float,
     epoch_length: int | float,
+    emg_filter: EMGFilter,
 ) -> np.array:
     """Stack EEG spectrogram and EMG power into an image
 
@@ -226,6 +227,7 @@ def create_eeg_emg_image(
     :param emg: EMG signal
     :param sampling_rate: sampling rate, in Hz
     :param epoch_length: epoch length, in seconds
+    :param emg_filter: EMG filter parameters
     :return: combined EEG + EMG image for a recording
     """
     spec, f = create_spectrogram(eeg, sampling_rate, epoch_length)
@@ -241,7 +243,7 @@ def create_eeg_emg_image(
         ]
     )
 
-    emg_log_rms = get_emg_power(emg, sampling_rate, epoch_length)
+    emg_log_rms = get_emg_power(emg, sampling_rate, epoch_length, emg_filter)
     output = np.concatenate(
         [modified_spectrogram, np.tile(emg_log_rms, (EMG_COPIES, 1))]
     )
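These hunks replace the hard-coded band-pass settings (order 8, 20 to 50 Hz, visible in the removed lines) with an `EMGFilter` argument. This release adds that argument to `get_emg_power`, `create_eeg_emg_image` and `create_training_images`.

The sketch below shows the idea in isolation: a small dataclass carrying `order`, `bp_lower` and `bp_upper` (the attribute names the new code reads), a Butterworth band-pass, and a per-epoch log RMS. Some details are assumptions or deliberate changes:

- `EMGFilterParams` and `emg_log_rms` are hypothetical names.
- Cutting the signal into whole epochs is an assumption.
- It uses SciPy's second-order-sections form (`output="sos"` with `sosfiltfilt`) instead of the `output="ba"` coefficients and `filtfilt` import in the diff. SOS is numerically more robust for high-order band-pass designs.

```python
from dataclasses import dataclass

import numpy as np
from scipy.signal import butter, sosfiltfilt


@dataclass
class EMGFilterParams:
    """Hypothetical stand-in for accusleepy.fileio.EMGFilter."""

    order: int = 8  # defaults are the values removed in this diff
    bp_lower: float = 20.0  # Hz
    bp_upper: float = 50.0  # Hz


def emg_log_rms(emg, sampling_rate, epoch_length, emg_filter):
    """Band-pass filter the EMG, then return the log RMS amplitude of each epoch."""
    sos = butter(
        N=emg_filter.order,
        Wn=[emg_filter.bp_lower, emg_filter.bp_upper],
        btype="bandpass",
        output="sos",
        fs=sampling_rate,
    )
    filtered = sosfiltfilt(sos, emg)  # zero-phase: forward and backward pass
    samples_per_epoch = int(round(sampling_rate * epoch_length))
    n_epochs = len(filtered) // samples_per_epoch  # drop any partial final epoch
    usable = filtered[: n_epochs * samples_per_epoch]
    epochs = usable.reshape(n_epochs, samples_per_epoch)
    rms = np.sqrt(np.mean(epochs**2, axis=1))
    return np.log(rms)


# 8 s of synthetic "EMG" at 512 Hz: a 35 Hz tone (in band), then a 5 Hz tone (out of band)
fs, epoch_len = 512, 2.0
t = np.arange(4 * fs) / fs
emg = np.concatenate([np.sin(2 * np.pi * 35 * t), np.sin(2 * np.pi * 5 * t)])
power = emg_log_rms(emg, fs, epoch_len, EMGFilterParams())
```

The in-band epochs keep roughly the tone's full RMS (about 0.707), while the out-of-band epochs are attenuated by many orders of magnitude. Exposing these three numbers as settings is what lets users adapt the EMG band to their own recordings.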
@@ -369,6 +371,8 @@ def create_training_images(
     epochs_per_img: int,
     brain_state_set: BrainStateSet,
     model_type: str,
+    calibration_fraction: float,
+    emg_filter: EMGFilter,
 ) -> list[int]:
     """Create training dataset
 
@@ -382,6 +386,8 @@ def create_training_images(
     :param epochs_per_img: # number of epochs shown in each image
     :param brain_state_set: set of brain state options
     :param model_type: default or real-time
+    :param calibration_fraction: fraction of training data to use for calibration
+    :param emg_filter: EMG filter parameters
     :return: list of the names of any recordings that could not
         be used to create training images.
     """
@@ -404,9 +410,11 @@ def create_training_images(
                 epoch_length=epoch_length,
             )
 
-            labels = load_labels(recording.label_file)
+            labels, _ = load_labels(recording.label_file)
             labels = brain_state_set.convert_digit_to_class(labels)
-            img = create_eeg_emg_image(
+            img = create_eeg_emg_image(
+                eeg, emg, sampling_rate, epoch_length, emg_filter
+            )
             img = mixture_z_score_img(
                 img=img, brain_state_set=brain_state_set, labels=labels
             )
@@ -442,10 +450,25 @@ def create_training_images(
             print(e)
             failed_recordings.append(recording.name)
 
-
-
-
-
-
+    annotations = pd.DataFrame({FILENAME_COL: filenames, LABEL_COL: all_labels})
+
+    # split into training and calibration sets, if necessary
+    if calibration_fraction > 0:
+        calibration_set = annotations.sample(frac=calibration_fraction)
+        training_set = annotations.drop(calibration_set.index)
+        training_set.to_csv(
+            os.path.join(output_path, ANNOTATIONS_FILENAME),
+            index=False,
+        )
+        calibration_set.to_csv(
+            os.path.join(output_path, CALIBRATION_ANNOTATION_FILENAME),
+            index=False,
+        )
+    else:
+        # annotation file contains info on all training images
+        annotations.to_csv(
+            os.path.join(output_path, ANNOTATIONS_FILENAME),
+            index=False,
+        )
 
     return failed_recordings
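When `calibration_fraction` is positive, the new code draws a random subset of the annotation table with `DataFrame.sample(frac=...)`. It keeps the remaining rows for training via `drop(calibration_set.index)`. The sketch below isolates that step so it can be checked on its own. Two details are assumptions:

- The column names are hypothetical stand-ins for `FILENAME_COL` and `LABEL_COL`, whose values the diff does not show.
- The optional `seed` (passed as `random_state`) is an addition for reproducibility. The diff calls `sample` without one, so each training run draws a different calibration set.

```python
import pandas as pd

FILENAME_COL = "filename"  # hypothetical column name
LABEL_COL = "label"  # hypothetical column name


def split_annotations(annotations, calibration_fraction, seed=None):
    """Return (training_set, calibration_set) as disjoint subsets of the rows."""
    if not 0 <= calibration_fraction < 1:
        raise ValueError("calibration_fraction must be in [0, 1)")
    if calibration_fraction == 0:
        # no calibration: all rows are training data, the calibration set is empty
        return annotations, annotations.iloc[0:0]
    calibration_set = annotations.sample(frac=calibration_fraction, random_state=seed)
    training_set = annotations.drop(calibration_set.index)
    return training_set, calibration_set


annotations = pd.DataFrame(
    {
        FILENAME_COL: [f"img_{i:03d}.png" for i in range(100)],
        LABEL_COL: [i % 3 for i in range(100)],
    }
)
train, calib = split_annotations(annotations, 0.2, seed=0)
```

Because `drop` removes rows by index label, this relies on the table having a unique index. A freshly built DataFrame's default RangeIndex provides one.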
accusleepy/temperature_scaling.py
ADDED
@@ -0,0 +1,157 @@
+import numpy as np
+import torch
+from torch import nn, optim
+from torch.nn import functional as F
+
+
+class ModelWithTemperature(nn.Module):
+    """
+    A thin decorator, which wraps a model with temperature scaling
+    model (nn.Module):
+        A classification neural network
+        NB: Output of the neural network should be the classification logits,
+            NOT the softmax (or log softmax)!
+    """
+
+    def __init__(self, model):
+        super(ModelWithTemperature, self).__init__()
+        self.model = model
+        # https://github.com/gpleiss/temperature_scaling/issues/20
+        # for another approach, see https://github.com/gpleiss/temperature_scaling/issues/36
+        self.model.eval()
+        self.temperature = nn.Parameter(torch.ones(1) * 1.5)
+
+    def forward(self, x):
+        logits = self.model(x)
+        return self.temperature_scale(logits)
+
+    def temperature_scale(self, logits):
+        """
+        Perform temperature scaling on logits
+        """
+        # Expand temperature to match the size of logits
+        temperature = self.temperature.unsqueeze(1).expand(
+            logits.size(0), logits.size(1)
+        )
+        return logits / temperature
+
+    # This function probably should live outside of this class, but whatever
+    def set_temperature(self, valid_loader):
+        """
+        Tune the temperature of the model (using the validation set).
+        We're going to set it to optimize NLL.
+        valid_loader (DataLoader): validation set loader
+        """
+        if torch.accelerator.is_available():
+            device = torch.accelerator.current_accelerator().type
+        else:
+            device = "cpu"
+
+        # self.cuda()
+        self.to(device)
+        nll_criterion = nn.CrossEntropyLoss().to(device) # .cuda()
+        ece_criterion = _ECELoss().to(device) # .cuda()
+
+        # First: collect all the logits and labels for the validation set
+        logits_list = []
+        labels_list = []
+        prediction_list = []
+        with torch.no_grad():
+            for x, label in valid_loader:
+                x = x.to(device) # .cuda()
+                logits = self.model(x)
+                logits_list.append(logits)
+                labels_list.append(label)
+
+                _, pred = torch.max(logits, 1)
+                prediction_list.append(pred)
+            logits = torch.cat(logits_list).to(device) # .cuda()
+            labels = torch.cat(labels_list).to(device) # .cuda()
+            predictions = torch.cat(prediction_list).to(device)
+
+        # Calculate NLL and ECE before temperature scaling
+        before_temperature_nll = nll_criterion(logits, labels).item()
+        before_temperature_ece = ece_criterion(logits, labels).item()
+        print(
+            "Before temperature - NLL: %.3f, ECE: %.3f"
+            % (before_temperature_nll, before_temperature_ece)
+        )
+
+        # Next: optimize the temperature w.r.t. NLL
+        # https://github.com/gpleiss/temperature_scaling/issues/34
+        optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=100)
+
+        def eval():
+            optimizer.zero_grad()
+            loss = nll_criterion(self.temperature_scale(logits), labels)
+            loss.backward()
+            return loss
+
+        optimizer.step(eval)
+
+        # Calculate NLL and ECE after temperature scaling
+        after_temperature_nll = nll_criterion(
+            self.temperature_scale(logits), labels
+        ).item()
+        after_temperature_ece = ece_criterion(
+            self.temperature_scale(logits), labels
+        ).item()
+        print("Optimal temperature: %.3f" % self.temperature.item())
+        print(
+            "After temperature - NLL: %.3f, ECE: %.3f"
+            % (after_temperature_nll, after_temperature_ece)
+        )
+
+        val_acc = round(
+            100 * np.mean(labels.cpu().numpy() == predictions.cpu().numpy()), 2
+        )
+        print(f"Validation accuracy: {val_acc}%")
+
+        return self
+
+
+class _ECELoss(nn.Module):
+    """
+    Calculates the Expected Calibration Error of a model.
+    (This isn't necessary for temperature scaling, just a cool metric).
+
+    The input to this loss is the logits of a model, NOT the softmax scores.
+
+    This divides the confidence outputs into equally-sized interval bins.
+    In each bin, we compute the confidence gap:
+
+    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |
+
+    We then return a weighted average of the gaps, based on the number
+    of samples in each bin
+
+    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
+    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
+    2015.
+    """
+
+    def __init__(self, n_bins=15):
+        """
+        n_bins (int): number of confidence interval bins
+        """
+        super(_ECELoss, self).__init__()
+        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
+        self.bin_lowers = bin_boundaries[:-1]
+        self.bin_uppers = bin_boundaries[1:]
+
+    def forward(self, logits, labels):
+        softmaxes = F.softmax(logits, dim=1)
+        confidences, predictions = torch.max(softmaxes, 1)
+        accuracies = predictions.eq(labels)
+
+        ece = torch.zeros(1, device=logits.device)
+        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
+            # Calculated |confidence - accuracy| in each bin
+            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
+            prop_in_bin = in_bin.float().mean()
+            if prop_in_bin.item() > 0:
+                accuracy_in_bin = accuracies[in_bin].float().mean()
+                avg_confidence_in_bin = confidences[in_bin].mean()
+                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
+
+        return ece
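`ModelWithTemperature` divides the network's logits by a single learned scalar T. T is fitted with LBFGS to minimize negative log-likelihood on held-out data. `_ECELoss` reports the expected calibration error over 15 equal-width confidence bins.

To show the mechanics without PyTorch, the NumPy-only sketch below implements both. It is not the code in this file:

- A log-spaced grid search stands in for LBFGS.
- The data are invented for illustration. Labels are drawn from softmax(z) and the model's logits are 3z, so the model is deliberately overconfident and the ideal temperature is about 3.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def nll(logits, labels, temperature):
    """Mean negative log-likelihood of the true labels for logits / temperature."""
    scaled = logits / temperature
    shifted = scaled - scaled.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def fit_temperature(logits, labels, t_min=0.05, t_max=10.0, n_grid=400):
    """Grid-search stand-in for the LBFGS fit: the T in [t_min, t_max] minimizing NLL."""
    grid = np.geomspace(t_min, t_max, n_grid)
    losses = [nll(logits, labels, t) for t in grid]
    return float(grid[int(np.argmin(losses))])


def expected_calibration_error(logits, labels, n_bins=15):
    """ECE over equal-width confidence bins (lower, upper], as in _ECELoss."""
    probs = softmax(logits)
    confidences = probs.max(axis=1)
    accuracies = probs.argmax(axis=1) == labels
    boundaries = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lower, upper in zip(boundaries[:-1], boundaries[1:]):
        in_bin = (confidences > lower) & (confidences <= upper)
        if in_bin.any():
            gap = abs(confidences[in_bin].mean() - accuracies[in_bin].mean())
            ece += gap * in_bin.mean()  # weight by the fraction of samples in the bin
    return float(ece)


# Synthetic, deliberately overconfident classifier (3 classes, 5000 samples)
rng = np.random.default_rng(0)
true_logits = rng.normal(size=(5000, 3))
cumulative = softmax(true_logits).cumsum(axis=1)
# Inverse-CDF sampling of labels from softmax(true_logits); minimum() guards rounding
labels = np.minimum((rng.random((5000, 1)) > cumulative).sum(axis=1), 2)
model_logits = 3.0 * true_logits

temperature = fit_temperature(model_logits, labels)
ece_before = expected_calibration_error(model_logits, labels)
ece_after = expected_calibration_error(model_logits / temperature, labels)
```

Dividing logits by a positive T never changes the argmax. Calibration therefore leaves the predicted labels, and so accuracy, unchanged; only the confidence scores move.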
{accusleepy-0.5.0.dist-info → accusleepy-0.7.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: accusleepy
-Version: 0.
+Version: 0.7.0
 Summary: Python implementation of AccuSleep
 License: GPL-3.0-only
 Author: Zeke Barger
@@ -39,6 +39,7 @@ It offers the following improvements over the MATLAB version (AccuSleep):
 - Model files contain useful metadata (brain state configuration,
 epoch length, number of epochs)
 - Models optimized for real-time scoring can be trained
+- Confidence scores can be saved and visualized
 - Lists of recordings can be imported and exported for repeatable batch processing
 - Undo/redo functionality in the manual scoring interface
 
@@ -75,6 +76,9 @@ to the [config file](accusleepy/config.json).
 
 ## Changelog
 
+- 0.7.0: More settings can be configured in the UI
+- 0.6.0: Confidence scores can now be displayed and saved. Retraining your models is recommended
+since the new calibration feature will make the confidence scores more accurate.
 - 0.5.0: Performance improvements
 - 0.4.5: Added support for python 3.13, **removed support for python 3.10.**
 - 0.4.4: Performance improvements
@@ -93,7 +97,12 @@ Manual scoring interface
 ## Acknowledgements
 
 We would like to thank [Franz Weber](https://www.med.upenn.edu/weberlab/) for creating an
-early version of the manual labeling interface.
+early version of the manual labeling interface. The code that
+creates spectrograms comes from the
+[Prerau lab](https://github.com/preraulab/multitaper_toolbox/blob/master/python/multitaper_spectrogram_python.py)
+with only minor modifications.
 Jim Bohnslav's [deepethogram](https://github.com/jbohnslav/deepethogram) served as an
 incredibly useful reference when reimplementing this project in python.
+The model calibration code added in version 0.6.0 comes from Geoff Pleiss'
+[temperature scaling repo](https://github.com/gpleiss/temperature_scaling).
 
accusleepy-0.7.0.dist-info/RECORD
ADDED
@@ -0,0 +1,41 @@
+accusleepy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+accusleepy/__main__.py,sha256=dKzl2N2Hg9lD264CWYNxThRyDKzWwyMwHRXmJxOmMis,104
+accusleepy/bouts.py,sha256=F_y6DxnpKFfImYb7vCZluZ2eD5I_33gZXmRM8mvebsg,5679
+accusleepy/brain_state_set.py,sha256=fRkrArHLIbEKimub804yt_mUXoyfsjJEfiJnTjeCMkY,3233
+accusleepy/classification.py,sha256=mF35xMrD9QXGldSnl3vkdHbm7CAptPUNjHxUA_agOTA,9778
+accusleepy/config.json,sha256=Ip0qTMAn2LZfof9GVA_azOvpXP0WKnqLCZeSaya1sss,819
+accusleepy/constants.py,sha256=t7x-wzncJ_wVm0Oj6LiUzGukpsTsxfhO0KJiAAsuMN4,2244
+accusleepy/fileio.py,sha256=woIF0zgJt6Lx6T9KBXAQ-AlbQAwOK1_RUmVF710nltI,7383
+accusleepy/gui/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+accusleepy/gui/icons/brightness_down.png,sha256=PLT1fb83RHIhSRuU7MMMx0G7oJAY7o9wUcnqM8veZfM,12432
+accusleepy/gui/icons/brightness_up.png,sha256=64GnUqgPvN5xZ6Um3wOzwqvUmdAWYZT6eFmWpBsHyks,12989
+accusleepy/gui/icons/double_down_arrow.png,sha256=fGiPIP7_RJ3UAonNhORFVX0emXEmtzRlHI3Tfjai064,4964
+accusleepy/gui/icons/double_up_arrow.png,sha256=n7QEo0bZZDve4thwTCKghPKVjTNbQMgyQNsn46iqXbI,5435
+accusleepy/gui/icons/down_arrow.png,sha256=XwS_Gq2j6PoNHRaeaAGoh5kcXJNXWAHWWbJbUsvrRPU,3075
+accusleepy/gui/icons/home.png,sha256=yd3nmHlD9w2a2j3cBd-w_Cuidr-J0apryRoWJoPb66w,5662
+accusleepy/gui/icons/question.png,sha256=IJcIRgQOC9KlzA4vtA5Qu-DQ1-SclhVLeovIsEfl3TU,17105
+accusleepy/gui/icons/save.png,sha256=J3EA8iU1BqLYRSsrq_OdoZlqrv2yfL7oV54DklTy_DI,13555
+accusleepy/gui/icons/up_arrow.png,sha256=V9yF9t1WgjPaUu-mF1YGe_DfaRHg2dUpR_sUVVcvVvY,3329
+accusleepy/gui/icons/zoom_in.png,sha256=MFWnKZp7Rvh4bLPq4Cqo4sB_jQYedUUtT8-ZO8tNYyc,13589
+accusleepy/gui/icons/zoom_out.png,sha256=IB8Jecb3i0U4qjWRR46ridjLpvLCSe7PozBaLqQqYSw,13055
+accusleepy/gui/images/primary_window.png,sha256=ABu49UXTBfI5UwNHrpWPhJoTC7zhyLVvbsbywIuu1nc,602640
+accusleepy/gui/images/viewer_window.png,sha256=b_B7m9WSLMAOzNjctq76SyekO1WfC6qYZVNnYfhjPe8,977197
+accusleepy/gui/images/viewer_window_annotated.png,sha256=uMNUmsZIdzDlQpyoiS3lJGoWlg_T325Oj5hDZhM3Y14,146817
+accusleepy/gui/main.py,sha256=Ywk3XwfwMjAZL76qwDnZRHQ96L7T1DWyMCU8la0N5Qg,62447
+accusleepy/gui/manual_scoring.py,sha256=xk0TERVbH33owj4QYJkMA2LF_qR8jtS5W2enq36_njk,40907
+accusleepy/gui/mplwidget.py,sha256=rJSTtWmLjHn8r3c9Kb23Rc4XzXl3i9B-JrjNjjlNnmQ,13492
+accusleepy/gui/primary_window.py,sha256=sVIsZXnszBsq3ZUC9OIOwB13jPQRcisaxi0Vg-O6A-8,135475
+accusleepy/gui/primary_window.ui,sha256=_SorGY7qJG2B6e45UNt51OheeYLkpzAxxezX1hEfI3Y,201341
+accusleepy/gui/resources.qrc,sha256=wqPendnTLAuKfVI6v2lKHiRqAWM0oaz2ZuF5cucJdS4,803
+accusleepy/gui/resources_rc.py,sha256=Z2e34h30U4snJjnYdZVV9B6yjATKxxfvgTRt5uXtQdo,329727
+accusleepy/gui/text/main_guide.md,sha256=iZDRp5OWyQX9LV7CMeUFIYv2ryKlIcGALRLXjxR8HpI,8288
+accusleepy/gui/text/manual_scoring_guide.md,sha256=ow_RMSjFy05NupEDSCuJtu-V65-BPnIkrZqtssFoZCQ,999
+accusleepy/gui/viewer_window.py,sha256=O4ceqLMYdahxQ9s6DYhloUnNESim-cqIZxFeXEiRjog,24444
+accusleepy/gui/viewer_window.ui,sha256=jsjydsSSyN49AwJw4nVS2mEJ2JBIUTXesAJsij1JNV0,31530
+accusleepy/models.py,sha256=15VjtFoWaYXblyGPbtYgp0yJdyUfGu7t3zCShdtr_7c,3799
+accusleepy/multitaper.py,sha256=D5-iglwkFBRciL5tKSNcunMtcq0rM3zHwRHUVPgem1U,25679
+accusleepy/signal_processing.py,sha256=NOkQuLmUVINd0tFvt48RXRxSl4TDjV42m54XWc_EB9s,16987
+accusleepy/temperature_scaling.py,sha256=glvPcvxHpBdFjwjGfZdNku9L_BozycEmdqZhKKUCCNg,5749
+accusleepy-0.7.0.dist-info/METADATA,sha256=yLkBaWfg1wrN7DevZlLLg881Yq9hhGoTsqHSgCJ0IYk,4536
+accusleepy-0.7.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+accusleepy-0.7.0.dist-info/RECORD,,
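Each RECORD line has the form `path,sha256=<digest>,<size in bytes>`. The digest is the URL-safe base64 encoding of the file's SHA-256 hash, with the trailing `=` padding removed. The RECORD file lists itself with the last two fields empty (`accusleepy-0.7.0.dist-info/RECORD,,`).

Below is a minimal standard-library sketch for producing and checking one entry. `record_entry` and `verify_entry` are hypothetical names. Hashing the empty `accusleepy/__init__.py` reproduces the first entry above.

```python
import base64
import hashlib


def record_entry(path: str, content: bytes) -> str:
    """Build a wheel RECORD line: path, sha256 digest (urlsafe base64, unpadded), size."""
    digest = hashlib.sha256(content).digest()
    encoded = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
    return f"{path},sha256={encoded},{len(content)}"


def verify_entry(line: str, content: bytes) -> bool:
    """Check a RECORD line against file content (ignores CSV quoting of odd paths)."""
    path = line.rsplit(",", 2)[0]
    return record_entry(path, content) == line


print(record_entry("accusleepy/__init__.py", b""))
# prints accusleepy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
```

This also explains why `accusleepy/gui/__init__.py` carries the identical hash: both files are empty.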
accusleepy/gui/text/config_guide.txt
REMOVED
@@ -1,29 +0,0 @@
-This is the current brain state configuration.
-If you make changes, click 'Save' to store them.
-
-Each brain state has several attributes:
-
-- Digit: how the brain state is represented in label files,
-and the key on the keyboard that, during manual scoring,
-sets an epoch to this brain state.
-
-- Enabled: whether a brain state for this digit exists.
-
-- Name: unique name of the brain state (e.g., REM).
-
-- Scored: whether a classification model should output this
-brain state. If you have a state that corresponds to
-missing or corrupted data, for example, you would
-probably want to uncheck this box.
-
-- Frequency: approximate relative frequency of this brain
-state. Does not need to be very accurate, but it can
-influence classification accuracy slightly. The values
-for all scored brain states must sum to 1.
-
-Important notes:
-- Changing these settings can invalidate existing label files,
-calibration files, and trained models!
-- Reinstalling AccuSleePy will overwrite this configuration.
-- You can also choose the default epoch length that is shown
-when the primary interface starts up.
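The removed `config_guide.txt` spelled out the constraints on a brain state configuration. Each state has a digit (also its manual-scoring key), a unique name, a scored flag and an approximate frequency. The frequencies of all scored states must sum to 1.

Below is a minimal sketch of checking those rules. It assumes each enabled state is a dict with hypothetical keys `digit`, `name`, `is_scored` and `frequency`; only `is_scored` appears in this diff, in `load_model`. It also assumes a digit must be a single keyboard key, 0 to 9. `math.isclose` is used because a float sum of decimal fractions is not guaranteed to be exactly 1.0.

```python
import math


def validate_brain_states(brain_states):
    """Raise ValueError if an enabled-state configuration breaks the guide's rules.

    brain_states: list of dicts with hypothetical keys
    "digit", "name", "is_scored" and "frequency".
    """
    digits = [state["digit"] for state in brain_states]
    names = [state["name"] for state in brain_states]
    if any(not isinstance(d, int) or not 0 <= d <= 9 for d in digits):
        raise ValueError("each digit must be a single keyboard digit, 0-9")
    if len(set(digits)) != len(digits):
        raise ValueError("digits must be unique")
    if len(set(names)) != len(names):
        raise ValueError("names must be unique")
    # Only scored states count toward the frequency total
    total = sum(state["frequency"] for state in brain_states if state["is_scored"])
    if not math.isclose(total, 1.0, abs_tol=1e-6):
        raise ValueError(f"scored frequencies sum to {total}, not 1")


# Default digits (REM = 1, wake = 2, NREM = 3); the frequencies are illustrative
states = [
    {"digit": 1, "name": "REM", "is_scored": True, "frequency": 0.1},
    {"digit": 2, "name": "wake", "is_scored": True, "frequency": 0.35},
    {"digit": 3, "name": "NREM", "is_scored": True, "frequency": 0.55},
]
validate_brain_states(states)  # passes silently
```

An unscored state (for example, one reserved for corrupted data) can carry any frequency, since it is excluded from the sum.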