accusleepy 0.8.1__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accusleepy/bouts.py +3 -3
- accusleepy/brain_state_set.py +6 -4
- accusleepy/classification.py +14 -50
- accusleepy/constants.py +3 -0
- accusleepy/fileio.py +24 -5
- accusleepy/gui/dialogs.py +40 -0
- accusleepy/gui/images/primary_window.png +0 -0
- accusleepy/gui/main.py +212 -1025
- accusleepy/gui/manual_scoring.py +1 -1
- accusleepy/gui/primary_window.py +7 -9
- accusleepy/gui/primary_window.ui +6 -8
- accusleepy/gui/recording_manager.py +110 -0
- accusleepy/gui/settings_widget.py +409 -0
- accusleepy/gui/text/main_guide.md +1 -1
- accusleepy/models.py +1 -1
- accusleepy/services.py +581 -0
- accusleepy/signal_processing.py +110 -38
- accusleepy/temperature_scaling.py +14 -8
- accusleepy/validation.py +67 -2
- {accusleepy-0.8.1.dist-info → accusleepy-0.9.2.dist-info}/METADATA +2 -2
- {accusleepy-0.8.1.dist-info → accusleepy-0.9.2.dist-info}/RECORD +22 -18
- {accusleepy-0.8.1.dist-info → accusleepy-0.9.2.dist-info}/WHEEL +1 -1
accusleepy/gui/main.py
CHANGED
|
@@ -1,15 +1,13 @@
|
|
|
1
1
|
# AccuSleePy main window
|
|
2
2
|
# Icon source: Arkinasi, https://www.flaticon.com/authors/arkinasi
|
|
3
3
|
|
|
4
|
-
import
|
|
4
|
+
import logging
|
|
5
5
|
import os
|
|
6
|
-
import shutil
|
|
7
6
|
import sys
|
|
8
7
|
from dataclasses import dataclass
|
|
9
8
|
from functools import partial
|
|
10
9
|
|
|
11
10
|
import numpy as np
|
|
12
|
-
import toml
|
|
13
11
|
from PySide6.QtCore import (
|
|
14
12
|
QEvent,
|
|
15
13
|
QKeyCombination,
|
|
@@ -21,31 +19,17 @@ from PySide6.QtCore import (
|
|
|
21
19
|
from PySide6.QtGui import QKeySequence, QShortcut
|
|
22
20
|
from PySide6.QtWidgets import (
|
|
23
21
|
QApplication,
|
|
24
|
-
QCheckBox,
|
|
25
|
-
QDoubleSpinBox,
|
|
26
|
-
QFileDialog,
|
|
27
22
|
QLabel,
|
|
28
|
-
QListWidgetItem,
|
|
29
23
|
QMainWindow,
|
|
30
24
|
QTextBrowser,
|
|
31
25
|
QVBoxLayout,
|
|
32
26
|
QWidget,
|
|
33
27
|
)
|
|
34
28
|
|
|
35
|
-
from accusleepy.
|
|
36
|
-
from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
|
|
29
|
+
from accusleepy.brain_state_set import BRAIN_STATES_KEY
|
|
37
30
|
from accusleepy.constants import (
|
|
38
|
-
ANNOTATIONS_FILENAME,
|
|
39
|
-
CALIBRATION_ANNOTATION_FILENAME,
|
|
40
31
|
CALIBRATION_FILE_TYPE,
|
|
41
32
|
DEFAULT_MODEL_TYPE,
|
|
42
|
-
DEFAULT_EMG_FILTER_ORDER,
|
|
43
|
-
DEFAULT_EMG_BP_LOWER,
|
|
44
|
-
DEFAULT_EMG_BP_UPPER,
|
|
45
|
-
DEFAULT_BATCH_SIZE,
|
|
46
|
-
DEFAULT_LEARNING_RATE,
|
|
47
|
-
DEFAULT_MOMENTUM,
|
|
48
|
-
DEFAULT_TRAINING_EPOCHS,
|
|
49
33
|
LABEL_FILE_TYPE,
|
|
50
34
|
MESSAGE_BOX_MAX_DEPTH,
|
|
51
35
|
MODEL_FILE_TYPE,
|
|
@@ -55,31 +39,28 @@ from accusleepy.constants import (
|
|
|
55
39
|
UNDEFINED_LABEL,
|
|
56
40
|
)
|
|
57
41
|
from accusleepy.fileio import (
|
|
58
|
-
Recording,
|
|
59
|
-
load_calibration_file,
|
|
60
42
|
load_config,
|
|
61
43
|
load_labels,
|
|
62
44
|
load_recording,
|
|
63
|
-
|
|
64
|
-
save_config,
|
|
65
|
-
save_labels,
|
|
66
|
-
save_recording_list,
|
|
67
|
-
EMGFilter,
|
|
68
|
-
Hyperparameters,
|
|
45
|
+
get_version,
|
|
69
46
|
)
|
|
47
|
+
from accusleepy.gui.dialogs import select_existing_file, select_save_location
|
|
70
48
|
from accusleepy.gui.manual_scoring import ManualScoringWindow
|
|
71
49
|
from accusleepy.gui.primary_window import Ui_PrimaryWindow
|
|
72
|
-
from accusleepy.
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
50
|
+
from accusleepy.gui.recording_manager import RecordingListManager
|
|
51
|
+
from accusleepy.gui.settings_widget import SettingsWidget
|
|
52
|
+
from accusleepy.services import (
|
|
53
|
+
LoadedModel,
|
|
54
|
+
TrainingService,
|
|
55
|
+
check_single_file_inputs,
|
|
56
|
+
create_calibration,
|
|
57
|
+
score_recording_list,
|
|
80
58
|
)
|
|
59
|
+
from accusleepy.validation import validate_and_correct_labels
|
|
60
|
+
from accusleepy.signal_processing import resample_and_standardize
|
|
61
|
+
from accusleepy.validation import check_config_consistency
|
|
81
62
|
|
|
82
|
-
|
|
63
|
+
logger = logging.getLogger(__name__)
|
|
83
64
|
|
|
84
65
|
# on Windows, prevent dark mode from changing the visual style
|
|
85
66
|
if os.name == "nt":
|
|
@@ -91,14 +72,22 @@ MAIN_GUIDE_FILE = os.path.normpath(r"text/main_guide.md")
|
|
|
91
72
|
|
|
92
73
|
|
|
93
74
|
@dataclass
|
|
94
|
-
class
|
|
95
|
-
"""
|
|
75
|
+
class TrainingSettings:
|
|
76
|
+
"""Settings for training a new model"""
|
|
77
|
+
|
|
78
|
+
epochs_per_img: int = 9
|
|
79
|
+
delete_images: bool = True
|
|
80
|
+
model_type: str = DEFAULT_MODEL_TYPE
|
|
81
|
+
calibrate: bool = True
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass
|
|
85
|
+
class ScoringSettings:
|
|
86
|
+
"""Settings for scoring a recording"""
|
|
96
87
|
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
is_scored_widget: QCheckBox
|
|
101
|
-
frequency_widget: QDoubleSpinBox
|
|
88
|
+
only_overwrite_undefined: bool
|
|
89
|
+
save_confidence_scores: bool
|
|
90
|
+
min_bout_length: int | float
|
|
102
91
|
|
|
103
92
|
|
|
104
93
|
class AccuSleepWindow(QMainWindow):
|
|
@@ -112,65 +101,42 @@ class AccuSleepWindow(QMainWindow):
|
|
|
112
101
|
self.ui.setupUi(self)
|
|
113
102
|
self.setWindowTitle("AccuSleePy")
|
|
114
103
|
|
|
115
|
-
#
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
self.
|
|
120
|
-
self.
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
self.
|
|
128
|
-
self.initialize_settings_tab()
|
|
104
|
+
# Load configuration
|
|
105
|
+
loaded_config = load_config()
|
|
106
|
+
|
|
107
|
+
# Apply default values from the configuration
|
|
108
|
+
self.epoch_length = loaded_config.default_epoch_length
|
|
109
|
+
self.scoring = ScoringSettings(
|
|
110
|
+
only_overwrite_undefined=loaded_config.overwrite_setting,
|
|
111
|
+
save_confidence_scores=loaded_config.save_confidence_setting,
|
|
112
|
+
min_bout_length=loaded_config.min_bout_length,
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
# Initialize settings tab (manages Settings tab UI and saved config values)
|
|
116
|
+
self.config = SettingsWidget(ui=self.ui, config=loaded_config, parent=self)
|
|
129
117
|
|
|
130
118
|
# initialize info about the recordings, classification data / settings
|
|
131
119
|
self.ui.epoch_length_input.setValue(self.epoch_length)
|
|
132
|
-
self.ui.overwritecheckbox.setChecked(self.only_overwrite_undefined)
|
|
133
|
-
self.ui.save_confidence_checkbox.setChecked(self.save_confidence_scores)
|
|
134
|
-
self.ui.bout_length_input.setValue(self.min_bout_length)
|
|
135
|
-
self.model = None
|
|
120
|
+
self.ui.overwritecheckbox.setChecked(self.scoring.only_overwrite_undefined)
|
|
121
|
+
self.ui.save_confidence_checkbox.setChecked(self.scoring.save_confidence_scores)
|
|
122
|
+
self.ui.bout_length_input.setValue(self.scoring.min_bout_length)
|
|
136
123
|
|
|
137
|
-
#
|
|
138
|
-
self.
|
|
139
|
-
self.delete_training_images = True
|
|
140
|
-
self.model_type = DEFAULT_MODEL_TYPE
|
|
141
|
-
self.calibrate_trained_model = True
|
|
124
|
+
# loaded classification model and its metadata
|
|
125
|
+
self.loaded_model = LoadedModel()
|
|
142
126
|
|
|
143
|
-
#
|
|
144
|
-
self.
|
|
145
|
-
self.model_epochs_per_img = None
|
|
127
|
+
# settings for training new models
|
|
128
|
+
self.training = TrainingSettings()
|
|
146
129
|
|
|
147
130
|
# set up the list of recordings
|
|
148
|
-
|
|
149
|
-
|
|
131
|
+
self.recording_manager = RecordingListManager(
|
|
132
|
+
self.ui.recording_list_widget, parent=self
|
|
150
133
|
)
|
|
151
|
-
self.ui.recording_list_widget.addItem(first_recording.widget)
|
|
152
|
-
self.ui.recording_list_widget.setCurrentRow(0)
|
|
153
|
-
# index of currently selected recording in the list
|
|
154
|
-
self.recording_index = 0
|
|
155
|
-
# list of recordings the user has added
|
|
156
|
-
self.recordings = [first_recording]
|
|
157
134
|
|
|
158
135
|
# messages to display
|
|
159
136
|
self.messages = []
|
|
160
137
|
|
|
161
138
|
# display current version
|
|
162
|
-
|
|
163
|
-
toml_file = os.path.join(
|
|
164
|
-
os.path.dirname(
|
|
165
|
-
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
166
|
-
),
|
|
167
|
-
"pyproject.toml",
|
|
168
|
-
)
|
|
169
|
-
if os.path.isfile(toml_file):
|
|
170
|
-
toml_data = toml.load(toml_file)
|
|
171
|
-
if "project" in toml_data and "version" in toml_data["project"]:
|
|
172
|
-
version = toml_data["project"]["version"]
|
|
173
|
-
self.ui.version_label.setText(f"v{version}")
|
|
139
|
+
self.ui.version_label.setText(f"v{get_version()}")
|
|
174
140
|
|
|
175
141
|
# user input: keyboard shortcuts
|
|
176
142
|
keypress_quit = QShortcut(
|
|
@@ -183,8 +149,12 @@ class AccuSleepWindow(QMainWindow):
|
|
|
183
149
|
self.ui.add_button.clicked.connect(self.add_recording)
|
|
184
150
|
self.ui.remove_button.clicked.connect(self.remove_recording)
|
|
185
151
|
self.ui.recording_list_widget.currentRowChanged.connect(self.select_recording)
|
|
186
|
-
self.ui.sampling_rate_input.valueChanged.connect(
|
|
187
|
-
|
|
152
|
+
self.ui.sampling_rate_input.valueChanged.connect(
|
|
153
|
+
lambda v: setattr(self.recording_manager.current, "sampling_rate", v)
|
|
154
|
+
)
|
|
155
|
+
self.ui.epoch_length_input.valueChanged.connect(
|
|
156
|
+
lambda v: setattr(self, "epoch_length", v)
|
|
157
|
+
)
|
|
188
158
|
self.ui.recording_file_button.clicked.connect(self.select_recording_file)
|
|
189
159
|
self.ui.select_label_button.clicked.connect(self.select_label_file)
|
|
190
160
|
self.ui.create_label_button.clicked.connect(self.create_label_file)
|
|
@@ -192,27 +162,33 @@ class AccuSleepWindow(QMainWindow):
|
|
|
192
162
|
self.ui.create_calibration_button.clicked.connect(self.create_calibration_file)
|
|
193
163
|
self.ui.select_calibration_button.clicked.connect(self.select_calibration_file)
|
|
194
164
|
self.ui.load_model_button.clicked.connect(partial(self.load_model, None))
|
|
195
|
-
self.ui.score_all_button.clicked.connect(self.
|
|
196
|
-
self.ui.overwritecheckbox.stateChanged.connect(
|
|
165
|
+
self.ui.score_all_button.clicked.connect(self.score_recordings)
|
|
166
|
+
self.ui.overwritecheckbox.stateChanged.connect(
|
|
167
|
+
lambda v: setattr(self.scoring, "only_overwrite_undefined", bool(v))
|
|
168
|
+
)
|
|
197
169
|
self.ui.save_confidence_checkbox.stateChanged.connect(
|
|
198
|
-
self.
|
|
170
|
+
lambda v: setattr(self.scoring, "save_confidence_scores", bool(v))
|
|
171
|
+
)
|
|
172
|
+
self.ui.bout_length_input.valueChanged.connect(
|
|
173
|
+
lambda v: setattr(self.scoring, "min_bout_length", v)
|
|
199
174
|
)
|
|
200
|
-
self.ui.bout_length_input.valueChanged.connect(self.update_min_bout_length)
|
|
201
175
|
self.ui.user_manual_button.clicked.connect(self.show_user_manual)
|
|
202
|
-
self.ui.image_number_input.valueChanged.connect(
|
|
203
|
-
|
|
176
|
+
self.ui.image_number_input.valueChanged.connect(
|
|
177
|
+
lambda v: setattr(self.training, "epochs_per_img", v)
|
|
178
|
+
)
|
|
179
|
+
self.ui.delete_image_box.stateChanged.connect(
|
|
180
|
+
lambda v: setattr(self.training, "delete_images", bool(v))
|
|
181
|
+
)
|
|
204
182
|
self.ui.calibrate_checkbox.stateChanged.connect(
|
|
205
183
|
self.update_training_calibration
|
|
206
184
|
)
|
|
207
185
|
self.ui.train_model_button.clicked.connect(self.train_model)
|
|
208
|
-
self.ui.save_config_button.clicked.connect(self.
|
|
186
|
+
self.ui.save_config_button.clicked.connect(self.config.save_config)
|
|
209
187
|
self.ui.export_button.clicked.connect(self.export_recording_list)
|
|
210
188
|
self.ui.import_button.clicked.connect(self.import_recording_list)
|
|
211
189
|
self.ui.default_type_button.toggled.connect(self.model_type_radio_buttons)
|
|
212
|
-
self.ui.reset_emg_params_button.clicked.connect(self.
|
|
213
|
-
self.ui.reset_hyperparams_button.clicked.connect(
|
|
214
|
-
self.reset_hyperparams_settings
|
|
215
|
-
)
|
|
190
|
+
self.ui.reset_emg_params_button.clicked.connect(self.config.reset_emg_filter)
|
|
191
|
+
self.ui.reset_hyperparams_button.clicked.connect(self.config.reset_hyperparams)
|
|
216
192
|
|
|
217
193
|
# user input: drag and drop
|
|
218
194
|
self.ui.recording_file_label.installEventFilter(self)
|
|
@@ -227,52 +203,29 @@ class AccuSleepWindow(QMainWindow):
|
|
|
227
203
|
|
|
228
204
|
:param default_selected: whether default option is selected
|
|
229
205
|
"""
|
|
230
|
-
self.model_type = (
|
|
206
|
+
self.training.model_type = (
|
|
231
207
|
DEFAULT_MODEL_TYPE if default_selected else REAL_TIME_MODEL_TYPE
|
|
232
208
|
)
|
|
233
209
|
|
|
234
210
|
def export_recording_list(self) -> None:
|
|
235
211
|
"""Save current list of recordings to file"""
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
self,
|
|
239
|
-
caption="Save list of recordings as",
|
|
240
|
-
filter="*" + RECORDING_LIST_FILE_TYPE,
|
|
212
|
+
filename = select_save_location(
|
|
213
|
+
self, "Save list of recordings as", "*" + RECORDING_LIST_FILE_TYPE
|
|
241
214
|
)
|
|
242
215
|
if not filename:
|
|
243
216
|
return
|
|
244
|
-
|
|
245
|
-
save_recording_list(filename=filename, recordings=self.recordings)
|
|
217
|
+
self.recording_manager.export_to_file(filename)
|
|
246
218
|
self.show_message(f"Saved list of recordings to {filename}")
|
|
247
219
|
|
|
248
220
|
def import_recording_list(self):
|
|
249
221
|
"""Load list of recordings from file, overwriting current list"""
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
file_dialog.setNameFilter("*" + RECORDING_LIST_FILE_TYPE)
|
|
255
|
-
|
|
256
|
-
if file_dialog.exec():
|
|
257
|
-
selected_files = file_dialog.selectedFiles()
|
|
258
|
-
filename = selected_files[0]
|
|
259
|
-
filename = os.path.normpath(filename)
|
|
260
|
-
else:
|
|
222
|
+
filename = select_existing_file(
|
|
223
|
+
self, "Select list of recordings", "*" + RECORDING_LIST_FILE_TYPE
|
|
224
|
+
)
|
|
225
|
+
if not filename:
|
|
261
226
|
return
|
|
262
227
|
|
|
263
|
-
|
|
264
|
-
self.ui.recording_list_widget.clear()
|
|
265
|
-
# overwrite current list
|
|
266
|
-
self.recordings = load_recording_list(filename)
|
|
267
|
-
|
|
268
|
-
for recording in self.recordings:
|
|
269
|
-
recording.widget = QListWidgetItem(
|
|
270
|
-
f"Recording {recording.name}", self.ui.recording_list_widget
|
|
271
|
-
)
|
|
272
|
-
self.ui.recording_list_widget.addItem(self.recordings[-1].widget)
|
|
273
|
-
|
|
274
|
-
# display new list
|
|
275
|
-
self.ui.recording_list_widget.setCurrentRow(0)
|
|
228
|
+
self.recording_manager.import_from_file(filename)
|
|
276
229
|
self.show_message(f"Loaded list of recordings from {filename}")
|
|
277
230
|
|
|
278
231
|
def eventFilter(self, obj: QObject, event: QEvent) -> bool:
|
|
@@ -297,20 +250,21 @@ class AccuSleepWindow(QMainWindow):
|
|
|
297
250
|
|
|
298
251
|
if filename is None:
|
|
299
252
|
return super().eventFilter(obj, event)
|
|
253
|
+
filename = str(filename)
|
|
300
254
|
|
|
301
255
|
_, file_extension = os.path.splitext(filename)
|
|
302
256
|
|
|
303
257
|
if obj == self.ui.recording_file_label:
|
|
304
258
|
if file_extension in RECORDING_FILE_TYPES:
|
|
305
|
-
self.
|
|
259
|
+
self.recording_manager.current.recording_file = filename
|
|
306
260
|
self.ui.recording_file_label.setText(filename)
|
|
307
261
|
elif obj == self.ui.label_file_label:
|
|
308
262
|
if file_extension == LABEL_FILE_TYPE:
|
|
309
|
-
self.
|
|
263
|
+
self.recording_manager.current.label_file = filename
|
|
310
264
|
self.ui.label_file_label.setText(filename)
|
|
311
265
|
elif obj == self.ui.calibration_file_label:
|
|
312
266
|
if file_extension == CALIBRATION_FILE_TYPE:
|
|
313
|
-
self.
|
|
267
|
+
self.recording_manager.current.calibration_file = filename
|
|
314
268
|
self.ui.calibration_file_label.setText(filename)
|
|
315
269
|
elif obj == self.ui.model_label:
|
|
316
270
|
self.load_model(filename=filename)
|
|
@@ -318,335 +272,69 @@ class AccuSleepWindow(QMainWindow):
|
|
|
318
272
|
return super().eventFilter(obj, event)
|
|
319
273
|
|
|
320
274
|
def train_model(self) -> None:
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
self
|
|
324
|
-
and self.training_epochs_per_img % 2 == 0
|
|
325
|
-
):
|
|
326
|
-
self.show_message(
|
|
327
|
-
(
|
|
328
|
-
"ERROR: for the default model type, number of epochs "
|
|
329
|
-
"per image must be an odd number."
|
|
330
|
-
)
|
|
331
|
-
)
|
|
332
|
-
return
|
|
333
|
-
|
|
334
|
-
# determine fraction of training data to use for calibration
|
|
335
|
-
if self.calibrate_trained_model:
|
|
336
|
-
calibration_fraction = self.ui.calibration_spinbox.value() / 100
|
|
337
|
-
else:
|
|
338
|
-
calibration_fraction = 0
|
|
339
|
-
|
|
340
|
-
# check some inputs for each recording
|
|
341
|
-
for recording_index in range(len(self.recordings)):
|
|
342
|
-
error_message = self.check_single_file_inputs(recording_index)
|
|
343
|
-
if error_message:
|
|
344
|
-
self.show_message(
|
|
345
|
-
f"ERROR (recording {self.recordings[recording_index].name}): {error_message}"
|
|
346
|
-
)
|
|
347
|
-
return
|
|
348
|
-
|
|
349
|
-
# get filename for the new model
|
|
350
|
-
model_filename, _ = QFileDialog.getSaveFileName(
|
|
351
|
-
self,
|
|
352
|
-
caption="Save classification model file as",
|
|
353
|
-
filter="*" + MODEL_FILE_TYPE,
|
|
275
|
+
"""Train a classification model using the current recordings."""
|
|
276
|
+
model_filename = select_save_location(
|
|
277
|
+
self, "Save classification model file as", "*" + MODEL_FILE_TYPE
|
|
354
278
|
)
|
|
355
279
|
if not model_filename:
|
|
356
280
|
self.show_message("Model training canceled, no filename given")
|
|
357
281
|
return
|
|
358
|
-
model_filename = os.path.normpath(model_filename)
|
|
359
282
|
|
|
360
|
-
#
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
)
|
|
366
|
-
|
|
367
|
-
if os.path.exists(temp_image_dir): # unlikely
|
|
368
|
-
self.show_message(
|
|
369
|
-
"Warning: training image folder exists, will be overwritten"
|
|
370
|
-
)
|
|
371
|
-
os.makedirs(temp_image_dir, exist_ok=True)
|
|
283
|
+
# Determine calibration fraction
|
|
284
|
+
if self.training.calibrate:
|
|
285
|
+
calibration_fraction = self.ui.calibration_spinbox.value() / 100
|
|
286
|
+
else:
|
|
287
|
+
calibration_fraction = 0
|
|
372
288
|
|
|
373
|
-
#
|
|
289
|
+
# Show progress message
|
|
374
290
|
self.show_message("Training, please wait. See console for progress updates.")
|
|
375
|
-
if not self.delete_training_images:
|
|
376
|
-
self.show_message((f"Creating training images in {temp_image_dir}"))
|
|
377
|
-
else:
|
|
378
|
-
self.show_message(
|
|
379
|
-
(f"Creating temporary folder of training images: {temp_image_dir}")
|
|
380
|
-
)
|
|
381
291
|
self.ui.message_area.repaint()
|
|
382
292
|
QApplication.processEvents()
|
|
383
|
-
print("Creating training images")
|
|
384
|
-
failed_recordings = create_training_images(
|
|
385
|
-
recordings=self.recordings,
|
|
386
|
-
output_path=temp_image_dir,
|
|
387
|
-
epoch_length=self.epoch_length,
|
|
388
|
-
epochs_per_img=self.training_epochs_per_img,
|
|
389
|
-
brain_state_set=self.brain_state_set,
|
|
390
|
-
model_type=self.model_type,
|
|
391
|
-
calibration_fraction=calibration_fraction,
|
|
392
|
-
emg_filter=self.emg_filter,
|
|
393
|
-
)
|
|
394
|
-
if len(failed_recordings) > 0:
|
|
395
|
-
if len(failed_recordings) == len(self.recordings):
|
|
396
|
-
self.show_message("ERROR: no recordings were valid!")
|
|
397
|
-
return
|
|
398
|
-
else:
|
|
399
|
-
self.show_message(
|
|
400
|
-
(
|
|
401
|
-
"WARNING: the following recordings could not be "
|
|
402
|
-
"loaded and will not be used for training: "
|
|
403
|
-
f"{', '.join([str(r) for r in failed_recordings])}"
|
|
404
|
-
)
|
|
405
|
-
)
|
|
406
293
|
|
|
407
|
-
#
|
|
408
|
-
self.show_message
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
print("Training model")
|
|
412
|
-
from accusleepy.classification import create_dataloader, train_ssann
|
|
413
|
-
from accusleepy.models import save_model
|
|
414
|
-
from accusleepy.temperature_scaling import ModelWithTemperature
|
|
415
|
-
|
|
416
|
-
model = train_ssann(
|
|
417
|
-
annotations_file=os.path.join(temp_image_dir, ANNOTATIONS_FILENAME),
|
|
418
|
-
img_dir=temp_image_dir,
|
|
419
|
-
mixture_weights=self.brain_state_set.mixture_weights,
|
|
420
|
-
n_classes=self.brain_state_set.n_classes,
|
|
421
|
-
hyperparameters=self.hyperparameters,
|
|
422
|
-
)
|
|
423
|
-
|
|
424
|
-
# calibrate the model
|
|
425
|
-
if self.calibrate_trained_model:
|
|
426
|
-
calibration_annotation_file = os.path.join(
|
|
427
|
-
temp_image_dir, CALIBRATION_ANNOTATION_FILENAME
|
|
428
|
-
)
|
|
429
|
-
calibration_dataloader = create_dataloader(
|
|
430
|
-
annotations_file=calibration_annotation_file,
|
|
431
|
-
img_dir=temp_image_dir,
|
|
432
|
-
hyperparameters=self.hyperparameters,
|
|
433
|
-
)
|
|
434
|
-
model = ModelWithTemperature(model)
|
|
435
|
-
print("Calibrating model")
|
|
436
|
-
model.set_temperature(calibration_dataloader)
|
|
437
|
-
|
|
438
|
-
# save model
|
|
439
|
-
save_model(
|
|
440
|
-
model=model,
|
|
441
|
-
filename=model_filename,
|
|
294
|
+
# Create service and run training
|
|
295
|
+
service = TrainingService(progress_callback=self.show_message)
|
|
296
|
+
result = service.train_model(
|
|
297
|
+
recordings=list(self.recording_manager),
|
|
442
298
|
epoch_length=self.epoch_length,
|
|
443
|
-
epochs_per_img=self.
|
|
444
|
-
model_type=self.model_type,
|
|
445
|
-
|
|
446
|
-
|
|
299
|
+
epochs_per_img=self.training.epochs_per_img,
|
|
300
|
+
model_type=self.training.model_type,
|
|
301
|
+
calibrate=self.training.calibrate,
|
|
302
|
+
calibration_fraction=calibration_fraction,
|
|
303
|
+
brain_state_set=self.config.brain_state_set,
|
|
304
|
+
emg_filter=self.config.emg_filter,
|
|
305
|
+
hyperparameters=self.config.hyperparameters,
|
|
306
|
+
model_filename=model_filename,
|
|
307
|
+
delete_images=self.training.delete_images,
|
|
447
308
|
)
|
|
448
309
|
|
|
449
|
-
#
|
|
450
|
-
|
|
451
|
-
print("Cleaning up training image folder")
|
|
452
|
-
shutil.rmtree(temp_image_dir)
|
|
453
|
-
|
|
454
|
-
self.show_message(f"Training complete. Saved model to {model_filename}")
|
|
455
|
-
print("Training complete.")
|
|
456
|
-
|
|
457
|
-
def update_image_deletion(self) -> None:
|
|
458
|
-
"""Update choice of whether to delete images after training"""
|
|
459
|
-
self.delete_training_images = self.ui.delete_image_box.isChecked()
|
|
310
|
+
# Display results
|
|
311
|
+
result.report_to(self.show_message)
|
|
460
312
|
|
|
461
313
|
def update_training_calibration(self) -> None:
|
|
462
314
|
"""Update choice of whether to calibrate model after training"""
|
|
463
|
-
self.
|
|
464
|
-
self.ui.calibration_spinbox.setEnabled(self.
|
|
465
|
-
|
|
466
|
-
def update_epochs_per_img(self, new_value) -> None:
|
|
467
|
-
"""Update number of epochs per image
|
|
468
|
-
|
|
469
|
-
:param new_value: new number of epochs per image
|
|
470
|
-
"""
|
|
471
|
-
self.training_epochs_per_img = new_value
|
|
472
|
-
|
|
473
|
-
def score_all(self) -> None:
|
|
474
|
-
"""Score all recordings using the classification model"""
|
|
475
|
-
# check basic inputs
|
|
476
|
-
if self.model is None:
|
|
477
|
-
self.ui.score_all_status.setText("missing classification model")
|
|
478
|
-
self.show_message("ERROR: no classification model file selected")
|
|
479
|
-
return
|
|
480
|
-
if self.min_bout_length < self.epoch_length:
|
|
481
|
-
self.ui.score_all_status.setText("invalid minimum bout length")
|
|
482
|
-
self.show_message("ERROR: minimum bout length must be >= epoch length")
|
|
483
|
-
return
|
|
484
|
-
if self.epoch_length != self.model_epoch_length:
|
|
485
|
-
self.ui.score_all_status.setText("invalid epoch length")
|
|
486
|
-
self.show_message(
|
|
487
|
-
(
|
|
488
|
-
"ERROR: model was trained with an epoch length of "
|
|
489
|
-
f"{self.model_epoch_length} seconds, but the current "
|
|
490
|
-
f"epoch length setting is {self.epoch_length} seconds."
|
|
491
|
-
)
|
|
492
|
-
)
|
|
493
|
-
return
|
|
315
|
+
self.training.calibrate = self.ui.calibrate_checkbox.isChecked()
|
|
316
|
+
self.ui.calibration_spinbox.setEnabled(self.training.calibrate)
|
|
494
317
|
|
|
318
|
+
def score_recordings(self) -> None:
|
|
319
|
+
"""Score all recordings using the classification model."""
|
|
495
320
|
self.ui.score_all_status.setText("running...")
|
|
496
321
|
self.ui.score_all_status.repaint()
|
|
497
322
|
QApplication.processEvents()
|
|
498
323
|
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
f"ERROR (recording {self.recordings[recording_index].name}): {error_message}"
|
|
510
|
-
)
|
|
511
|
-
return
|
|
512
|
-
if self.recordings[recording_index].calibration_file == "":
|
|
513
|
-
self.ui.score_all_status.setText(
|
|
514
|
-
f"error on recording {self.recordings[recording_index].name}"
|
|
515
|
-
)
|
|
516
|
-
self.show_message(
|
|
517
|
-
(
|
|
518
|
-
f"ERROR (recording {self.recordings[recording_index].name}): "
|
|
519
|
-
"no calibration file selected"
|
|
520
|
-
)
|
|
521
|
-
)
|
|
522
|
-
return
|
|
523
|
-
|
|
524
|
-
# score each recording
|
|
525
|
-
for recording_index in range(len(self.recordings)):
|
|
526
|
-
# load EEG, EMG
|
|
527
|
-
try:
|
|
528
|
-
eeg, emg = load_recording(
|
|
529
|
-
self.recordings[recording_index].recording_file
|
|
530
|
-
)
|
|
531
|
-
sampling_rate = self.recordings[recording_index].sampling_rate
|
|
532
|
-
|
|
533
|
-
eeg, emg, sampling_rate = resample_and_standardize(
|
|
534
|
-
eeg=eeg,
|
|
535
|
-
emg=emg,
|
|
536
|
-
sampling_rate=sampling_rate,
|
|
537
|
-
epoch_length=self.epoch_length,
|
|
538
|
-
)
|
|
539
|
-
except Exception:
|
|
540
|
-
self.show_message(
|
|
541
|
-
(
|
|
542
|
-
"ERROR: could not load recording "
|
|
543
|
-
f"{self.recordings[recording_index].name}."
|
|
544
|
-
"This recording will be skipped."
|
|
545
|
-
)
|
|
546
|
-
)
|
|
547
|
-
continue
|
|
548
|
-
|
|
549
|
-
# load labels
|
|
550
|
-
label_file = self.recordings[recording_index].label_file
|
|
551
|
-
if os.path.isfile(label_file):
|
|
552
|
-
try:
|
|
553
|
-
# ignore any existing confidence scores; they will all be overwritten
|
|
554
|
-
existing_labels, _ = load_labels(label_file)
|
|
555
|
-
except Exception:
|
|
556
|
-
self.show_message(
|
|
557
|
-
(
|
|
558
|
-
"ERROR: could not load existing labels for recording "
|
|
559
|
-
f"{self.recordings[recording_index].name}."
|
|
560
|
-
"This recording will be skipped."
|
|
561
|
-
)
|
|
562
|
-
)
|
|
563
|
-
continue
|
|
564
|
-
# only check the length
|
|
565
|
-
samples_per_epoch = sampling_rate * self.epoch_length
|
|
566
|
-
epochs_in_recording = round(eeg.size / samples_per_epoch)
|
|
567
|
-
if epochs_in_recording != existing_labels.size:
|
|
568
|
-
self.show_message(
|
|
569
|
-
(
|
|
570
|
-
"ERROR: existing labels for recording "
|
|
571
|
-
f"{self.recordings[recording_index].name} "
|
|
572
|
-
"do not match the recording length. "
|
|
573
|
-
"This recording will be skipped."
|
|
574
|
-
)
|
|
575
|
-
)
|
|
576
|
-
continue
|
|
577
|
-
else:
|
|
578
|
-
existing_labels = None
|
|
579
|
-
|
|
580
|
-
# load calibration data
|
|
581
|
-
if not os.path.isfile(self.recordings[recording_index].calibration_file):
|
|
582
|
-
self.show_message(
|
|
583
|
-
(
|
|
584
|
-
"ERROR: calibration file does not exist for recording "
|
|
585
|
-
f"{self.recordings[recording_index].name}. "
|
|
586
|
-
"This recording will be skipped."
|
|
587
|
-
)
|
|
588
|
-
)
|
|
589
|
-
continue
|
|
590
|
-
try:
|
|
591
|
-
(
|
|
592
|
-
mixture_means,
|
|
593
|
-
mixture_sds,
|
|
594
|
-
) = load_calibration_file(
|
|
595
|
-
self.recordings[recording_index].calibration_file
|
|
596
|
-
)
|
|
597
|
-
except Exception:
|
|
598
|
-
self.show_message(
|
|
599
|
-
(
|
|
600
|
-
"ERROR: could not load calibration file for recording "
|
|
601
|
-
f"{self.recordings[recording_index].name}. "
|
|
602
|
-
"This recording will be skipped."
|
|
603
|
-
)
|
|
604
|
-
)
|
|
605
|
-
continue
|
|
606
|
-
|
|
607
|
-
labels, confidence_scores = score_recording(
|
|
608
|
-
model=self.model,
|
|
609
|
-
eeg=eeg,
|
|
610
|
-
emg=emg,
|
|
611
|
-
mixture_means=mixture_means,
|
|
612
|
-
mixture_sds=mixture_sds,
|
|
613
|
-
sampling_rate=sampling_rate,
|
|
614
|
-
epoch_length=self.epoch_length,
|
|
615
|
-
epochs_per_img=self.model_epochs_per_img,
|
|
616
|
-
brain_state_set=self.brain_state_set,
|
|
617
|
-
emg_filter=self.emg_filter,
|
|
618
|
-
)
|
|
619
|
-
|
|
620
|
-
# overwrite as needed
|
|
621
|
-
if existing_labels is not None and self.only_overwrite_undefined:
|
|
622
|
-
labels[existing_labels != UNDEFINED_LABEL] = existing_labels[
|
|
623
|
-
existing_labels != UNDEFINED_LABEL
|
|
624
|
-
]
|
|
625
|
-
|
|
626
|
-
# enforce minimum bout length
|
|
627
|
-
labels = enforce_min_bout_length(
|
|
628
|
-
labels=labels,
|
|
629
|
-
epoch_length=self.epoch_length,
|
|
630
|
-
min_bout_length=self.min_bout_length,
|
|
631
|
-
)
|
|
632
|
-
|
|
633
|
-
# ignore confidence scores if desired
|
|
634
|
-
if not self.save_confidence_scores:
|
|
635
|
-
confidence_scores = None
|
|
636
|
-
|
|
637
|
-
# save results
|
|
638
|
-
save_labels(
|
|
639
|
-
labels=labels, filename=label_file, confidence_scores=confidence_scores
|
|
640
|
-
)
|
|
641
|
-
self.show_message(
|
|
642
|
-
(
|
|
643
|
-
"Saved labels for recording "
|
|
644
|
-
f"{self.recordings[recording_index].name} "
|
|
645
|
-
f"to {label_file}"
|
|
646
|
-
)
|
|
647
|
-
)
|
|
324
|
+
result = score_recording_list(
|
|
325
|
+
recordings=list(self.recording_manager),
|
|
326
|
+
loaded_model=self.loaded_model,
|
|
327
|
+
epoch_length=self.epoch_length,
|
|
328
|
+
only_overwrite_undefined=self.scoring.only_overwrite_undefined,
|
|
329
|
+
save_confidence_scores=self.scoring.save_confidence_scores,
|
|
330
|
+
min_bout_length=self.scoring.min_bout_length,
|
|
331
|
+
brain_state_set=self.config.brain_state_set,
|
|
332
|
+
emg_filter=self.config.emg_filter,
|
|
333
|
+
)
|
|
648
334
|
|
|
649
|
-
|
|
335
|
+
# Display results
|
|
336
|
+
result.report_to(self.show_message)
|
|
337
|
+
self.ui.score_all_status.setText("error" if not result.success else "")
|
|
650
338
|
|
|
651
339
|
def load_model(self, filename=None) -> None:
|
|
652
340
|
"""Load trained classification model from file
|
|
@@ -654,17 +342,10 @@ class AccuSleepWindow(QMainWindow):
|
|
|
654
342
|
:param filename: model filename, if it's known
|
|
655
343
|
"""
|
|
656
344
|
if filename is None:
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
file_dialog.setNameFilter("*" + MODEL_FILE_TYPE)
|
|
662
|
-
|
|
663
|
-
if file_dialog.exec():
|
|
664
|
-
selected_files = file_dialog.selectedFiles()
|
|
665
|
-
filename = selected_files[0]
|
|
666
|
-
filename = os.path.normpath(filename)
|
|
667
|
-
else:
|
|
345
|
+
filename = select_existing_file(
|
|
346
|
+
self, "Select classification model", "*" + MODEL_FILE_TYPE
|
|
347
|
+
)
|
|
348
|
+
if not filename:
|
|
668
349
|
return
|
|
669
350
|
|
|
670
351
|
if not os.path.isfile(filename):
|
|
@@ -682,6 +363,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
682
363
|
filename=filename
|
|
683
364
|
)
|
|
684
365
|
except Exception:
|
|
366
|
+
logger.exception("Failed to load %s", filename)
|
|
685
367
|
self.show_message(
|
|
686
368
|
(
|
|
687
369
|
"ERROR: could not load classification model. Check "
|
|
@@ -702,14 +384,14 @@ class AccuSleepWindow(QMainWindow):
|
|
|
702
384
|
)
|
|
703
385
|
return
|
|
704
386
|
|
|
705
|
-
self.model = model
|
|
706
|
-
self.
|
|
707
|
-
self.
|
|
387
|
+
self.loaded_model.model = model
|
|
388
|
+
self.loaded_model.epoch_length = epoch_length
|
|
389
|
+
self.loaded_model.epochs_per_img = epochs_per_img
|
|
708
390
|
|
|
709
391
|
# warn user if the model's expected epoch length or brain states
|
|
710
392
|
# don't match the current configuration
|
|
711
393
|
config_warnings = check_config_consistency(
|
|
712
|
-
current_brain_states=self.brain_state_set.to_output_dict()[
|
|
394
|
+
current_brain_states=self.config.brain_state_set.to_output_dict()[
|
|
713
395
|
BRAIN_STATES_KEY
|
|
714
396
|
],
|
|
715
397
|
model_brain_states=brain_states,
|
|
@@ -736,17 +418,21 @@ class AccuSleepWindow(QMainWindow):
|
|
|
736
418
|
:param status_widget: UI element on which to display error messages
|
|
737
419
|
:return: EEG data, EMG data, sampling rate, process completion
|
|
738
420
|
"""
|
|
739
|
-
error_message =
|
|
421
|
+
error_message = check_single_file_inputs(
|
|
422
|
+
self.recording_manager.current, self.epoch_length
|
|
423
|
+
)
|
|
740
424
|
if error_message:
|
|
741
425
|
status_widget.setText(error_message)
|
|
742
426
|
self.show_message(f"ERROR: {error_message}")
|
|
743
427
|
return None, None, None, False
|
|
744
428
|
|
|
745
429
|
try:
|
|
746
|
-
eeg, emg = load_recording(
|
|
747
|
-
self.recordings[self.recording_index].recording_file
|
|
748
|
-
)
|
|
430
|
+
eeg, emg = load_recording(self.recording_manager.current.recording_file)
|
|
749
431
|
except Exception:
|
|
432
|
+
logger.exception(
|
|
433
|
+
"Failed to load %s",
|
|
434
|
+
self.recording_manager.current.recording_file,
|
|
435
|
+
)
|
|
750
436
|
status_widget.setText("could not load recording")
|
|
751
437
|
self.show_message(
|
|
752
438
|
(
|
|
@@ -756,7 +442,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
756
442
|
)
|
|
757
443
|
return None, None, None, False
|
|
758
444
|
|
|
759
|
-
sampling_rate = self.
|
|
445
|
+
sampling_rate = self.recording_manager.current.sampling_rate
|
|
760
446
|
|
|
761
447
|
eeg, emg, sampling_rate = resample_and_standardize(
|
|
762
448
|
eeg=eeg,
|
|
@@ -768,135 +454,35 @@ class AccuSleepWindow(QMainWindow):
|
|
|
768
454
|
return eeg, emg, sampling_rate, True
|
|
769
455
|
|
|
770
456
|
def create_calibration_file(self) -> None:
|
|
771
|
-
"""Creates a calibration file
|
|
457
|
+
"""Creates a calibration file.
|
|
772
458
|
|
|
773
459
|
This loads a recording and its labels, checks that the labels are
|
|
774
460
|
all valid, creates the calibration file, and sets the
|
|
775
461
|
"calibration file" property of the current recording to be the
|
|
776
462
|
newly created file.
|
|
777
463
|
"""
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
self.ui.calibration_status
|
|
781
|
-
)
|
|
782
|
-
if not success:
|
|
783
|
-
return
|
|
784
|
-
|
|
785
|
-
# load the labels
|
|
786
|
-
label_file = self.recordings[self.recording_index].label_file
|
|
787
|
-
if not os.path.isfile(label_file):
|
|
788
|
-
self.ui.calibration_status.setText("label file does not exist")
|
|
789
|
-
self.show_message("ERROR: label file does not exist")
|
|
790
|
-
return
|
|
791
|
-
try:
|
|
792
|
-
labels, _ = load_labels(label_file)
|
|
793
|
-
except Exception:
|
|
794
|
-
self.ui.calibration_status.setText("could not load labels")
|
|
795
|
-
self.show_message(
|
|
796
|
-
(
|
|
797
|
-
"ERROR: could not load labels. "
|
|
798
|
-
"Check user manual for formatting instructions."
|
|
799
|
-
)
|
|
800
|
-
)
|
|
801
|
-
return
|
|
802
|
-
label_error_message = check_label_validity(
|
|
803
|
-
labels=labels,
|
|
804
|
-
confidence_scores=None,
|
|
805
|
-
samples_in_recording=eeg.size,
|
|
806
|
-
sampling_rate=sampling_rate,
|
|
807
|
-
epoch_length=self.epoch_length,
|
|
808
|
-
brain_state_set=self.brain_state_set,
|
|
809
|
-
)
|
|
810
|
-
if label_error_message:
|
|
811
|
-
self.ui.calibration_status.setText("invalid label file")
|
|
812
|
-
self.show_message(f"ERROR: {label_error_message}")
|
|
813
|
-
return
|
|
814
|
-
|
|
815
|
-
# get the name for the calibration file
|
|
816
|
-
filename, _ = QFileDialog.getSaveFileName(
|
|
817
|
-
self,
|
|
818
|
-
caption="Save calibration file as",
|
|
819
|
-
filter="*" + CALIBRATION_FILE_TYPE,
|
|
464
|
+
filename = select_save_location(
|
|
465
|
+
self, "Save calibration file as", "*" + CALIBRATION_FILE_TYPE
|
|
820
466
|
)
|
|
821
467
|
if not filename:
|
|
822
468
|
return
|
|
823
|
-
filename = os.path.normpath(filename)
|
|
824
469
|
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
create_calibration_file(
|
|
828
|
-
filename=filename,
|
|
829
|
-
eeg=eeg,
|
|
830
|
-
emg=emg,
|
|
831
|
-
labels=labels,
|
|
832
|
-
sampling_rate=sampling_rate,
|
|
470
|
+
result = create_calibration(
|
|
471
|
+
recording=self.recording_manager.current,
|
|
833
472
|
epoch_length=self.epoch_length,
|
|
834
|
-
brain_state_set=self.brain_state_set,
|
|
835
|
-
emg_filter=self.emg_filter,
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
self.ui.calibration_status.setText("")
|
|
839
|
-
self.show_message(
|
|
840
|
-
(
|
|
841
|
-
"Created calibration file using recording "
|
|
842
|
-
f"{self.recordings[self.recording_index].name} "
|
|
843
|
-
f"at {filename}"
|
|
844
|
-
)
|
|
473
|
+
brain_state_set=self.config.brain_state_set,
|
|
474
|
+
emg_filter=self.config.emg_filter,
|
|
475
|
+
output_filename=filename,
|
|
845
476
|
)
|
|
846
477
|
|
|
847
|
-
|
|
848
|
-
self.
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
error message.
|
|
856
|
-
|
|
857
|
-
:param recording_index: index of the recording in the list of
|
|
858
|
-
all recordings.
|
|
859
|
-
:return: error message
|
|
860
|
-
"""
|
|
861
|
-
sampling_rate = self.recordings[recording_index].sampling_rate
|
|
862
|
-
if self.epoch_length == 0:
|
|
863
|
-
return "epoch length can't be 0"
|
|
864
|
-
if sampling_rate == 0:
|
|
865
|
-
return "sampling rate can't be 0"
|
|
866
|
-
if self.epoch_length > sampling_rate:
|
|
867
|
-
return "invalid epoch length or sampling rate"
|
|
868
|
-
if self.recordings[self.recording_index].recording_file == "":
|
|
869
|
-
return "no recording selected"
|
|
870
|
-
if not os.path.isfile(self.recordings[self.recording_index].recording_file):
|
|
871
|
-
return "recording file does not exist"
|
|
872
|
-
if self.recordings[self.recording_index].label_file == "":
|
|
873
|
-
return "no label file selected"
|
|
874
|
-
|
|
875
|
-
def update_min_bout_length(self, new_value) -> None:
|
|
876
|
-
"""Update the minimum bout length
|
|
877
|
-
|
|
878
|
-
:param new_value: new minimum bout length, in seconds
|
|
879
|
-
"""
|
|
880
|
-
self.min_bout_length = new_value
|
|
881
|
-
|
|
882
|
-
def update_overwrite_policy(self, checked) -> None:
|
|
883
|
-
"""Toggle overwriting policy
|
|
884
|
-
|
|
885
|
-
If the checkbox is enabled, only epochs where the brain state is set to
|
|
886
|
-
undefined will be overwritten by the automatic scoring process.
|
|
887
|
-
|
|
888
|
-
:param checked: state of the checkbox
|
|
889
|
-
"""
|
|
890
|
-
self.only_overwrite_undefined = checked
|
|
891
|
-
|
|
892
|
-
def update_confidence_policy(self, checked) -> None:
|
|
893
|
-
"""Toggle policy for saving confidence scores
|
|
894
|
-
|
|
895
|
-
If the checkbox is enabled, confidence scores will be saved to the label files.
|
|
896
|
-
|
|
897
|
-
:param checked: state of the checkbox
|
|
898
|
-
"""
|
|
899
|
-
self.save_confidence_scores = checked
|
|
478
|
+
# Display results
|
|
479
|
+
result.report_to(self.show_message)
|
|
480
|
+
if not result.success:
|
|
481
|
+
self.ui.calibration_status.setText("error")
|
|
482
|
+
else:
|
|
483
|
+
self.ui.calibration_status.setText("")
|
|
484
|
+
self.recording_manager.current.calibration_file = filename
|
|
485
|
+
self.ui.calibration_file_label.setText(filename)
|
|
900
486
|
|
|
901
487
|
def manual_scoring(self) -> None:
|
|
902
488
|
"""View the selected recording for manual scoring"""
|
|
@@ -914,11 +500,12 @@ class AccuSleepWindow(QMainWindow):
|
|
|
914
500
|
|
|
915
501
|
# if the labels exist, load them
|
|
916
502
|
# otherwise, create a blank set of labels
|
|
917
|
-
label_file = self.
|
|
503
|
+
label_file = self.recording_manager.current.label_file
|
|
918
504
|
if os.path.isfile(label_file):
|
|
919
505
|
try:
|
|
920
506
|
labels, confidence_scores = load_labels(label_file)
|
|
921
507
|
except Exception:
|
|
508
|
+
logger.exception("Failed to load %s", label_file)
|
|
922
509
|
self.ui.manual_scoring_status.setText("could not load labels")
|
|
923
510
|
self.show_message(
|
|
924
511
|
(
|
|
@@ -936,56 +523,23 @@ class AccuSleepWindow(QMainWindow):
|
|
|
936
523
|
# to a label file that does not have one
|
|
937
524
|
confidence_scores = None
|
|
938
525
|
|
|
939
|
-
# check that
|
|
940
|
-
|
|
526
|
+
# check that labels are valid and correct minor length mismatches
|
|
527
|
+
labels, confidence_scores, validation_message = validate_and_correct_labels(
|
|
941
528
|
labels=labels,
|
|
942
529
|
confidence_scores=confidence_scores,
|
|
943
530
|
samples_in_recording=eeg.size,
|
|
944
531
|
sampling_rate=sampling_rate,
|
|
945
532
|
epoch_length=self.epoch_length,
|
|
946
|
-
brain_state_set=self.brain_state_set,
|
|
533
|
+
brain_state_set=self.config.brain_state_set,
|
|
947
534
|
)
|
|
948
|
-
if
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
epochs_in_recording = round(eeg.size / samples_per_epoch)
|
|
955
|
-
if epochs_in_recording - labels.size == 1:
|
|
956
|
-
labels = np.concatenate((labels, np.array([UNDEFINED_LABEL])))
|
|
957
|
-
if confidence_scores is not None:
|
|
958
|
-
confidence_scores = np.concatenate(
|
|
959
|
-
(confidence_scores, np.array([0]))
|
|
960
|
-
)
|
|
961
|
-
self.show_message(
|
|
962
|
-
(
|
|
963
|
-
"WARNING: an undefined epoch was added to "
|
|
964
|
-
"the label file to correct its length."
|
|
965
|
-
)
|
|
966
|
-
)
|
|
967
|
-
elif labels.size - epochs_in_recording == 1:
|
|
968
|
-
labels = labels[:-1]
|
|
969
|
-
if confidence_scores is not None:
|
|
970
|
-
confidence_scores = confidence_scores[:-1]
|
|
971
|
-
self.show_message(
|
|
972
|
-
(
|
|
973
|
-
"WARNING: the last epoch was removed from "
|
|
974
|
-
"the label file to correct its length."
|
|
975
|
-
)
|
|
976
|
-
)
|
|
977
|
-
else:
|
|
978
|
-
self.ui.manual_scoring_status.setText("invalid label file")
|
|
979
|
-
self.show_message(f"ERROR: {label_error}")
|
|
980
|
-
return
|
|
981
|
-
else:
|
|
982
|
-
self.ui.manual_scoring_status.setText("invalid label file")
|
|
983
|
-
self.show_message(f"ERROR: {label_error}")
|
|
984
|
-
return
|
|
535
|
+
if labels is None:
|
|
536
|
+
self.ui.manual_scoring_status.setText("invalid label file")
|
|
537
|
+
self.show_message(f"ERROR: {validation_message}")
|
|
538
|
+
return
|
|
539
|
+
if validation_message:
|
|
540
|
+
self.show_message(f"WARNING: {validation_message}")
|
|
985
541
|
|
|
986
|
-
self.show_message(
|
|
987
|
-
f"Viewing recording {self.recordings[self.recording_index].name}"
|
|
988
|
-
)
|
|
542
|
+
self.show_message(f"Viewing recording {self.recording_manager.current.name}")
|
|
989
543
|
self.ui.manual_scoring_status.setText("file is open")
|
|
990
544
|
|
|
991
545
|
# launch the manual scoring window
|
|
@@ -997,7 +551,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
997
551
|
confidence_scores=confidence_scores,
|
|
998
552
|
sampling_rate=sampling_rate,
|
|
999
553
|
epoch_length=self.epoch_length,
|
|
1000
|
-
emg_filter=self.emg_filter,
|
|
554
|
+
emg_filter=self.config.emg_filter,
|
|
1001
555
|
)
|
|
1002
556
|
manual_scoring_window.setWindowTitle(f"AccuSleePy viewer: {label_file}")
|
|
1003
557
|
manual_scoring_window.exec()
|
|
@@ -1005,89 +559,48 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1005
559
|
|
|
1006
560
|
def create_label_file(self) -> None:
|
|
1007
561
|
"""Set the filename for a new label file"""
|
|
1008
|
-
filename
|
|
562
|
+
filename = select_save_location(
|
|
1009
563
|
self,
|
|
1010
|
-
|
|
1011
|
-
|
|
564
|
+
"Set filename for label file (nothing will be overwritten yet)",
|
|
565
|
+
"*" + LABEL_FILE_TYPE,
|
|
1012
566
|
)
|
|
1013
567
|
if filename:
|
|
1014
|
-
|
|
1015
|
-
self.recordings[self.recording_index].label_file = filename
|
|
568
|
+
self.recording_manager.current.label_file = filename
|
|
1016
569
|
self.ui.label_file_label.setText(filename)
|
|
1017
570
|
|
|
1018
571
|
def select_label_file(self) -> None:
|
|
1019
572
|
"""User can select an existing label file"""
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
if file_dialog.exec():
|
|
1027
|
-
selected_files = file_dialog.selectedFiles()
|
|
1028
|
-
filename = selected_files[0]
|
|
1029
|
-
filename = os.path.normpath(filename)
|
|
1030
|
-
self.recordings[self.recording_index].label_file = filename
|
|
573
|
+
filename = select_existing_file(
|
|
574
|
+
self, "Select label file", "*" + LABEL_FILE_TYPE
|
|
575
|
+
)
|
|
576
|
+
if filename:
|
|
577
|
+
self.recording_manager.current.label_file = filename
|
|
1031
578
|
self.ui.label_file_label.setText(filename)
|
|
1032
579
|
|
|
1033
580
|
def select_calibration_file(self) -> None:
|
|
1034
581
|
"""User can select a calibration file"""
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
if file_dialog.exec():
|
|
1042
|
-
selected_files = file_dialog.selectedFiles()
|
|
1043
|
-
filename = selected_files[0]
|
|
1044
|
-
filename = os.path.normpath(filename)
|
|
1045
|
-
self.recordings[self.recording_index].calibration_file = filename
|
|
582
|
+
filename = select_existing_file(
|
|
583
|
+
self, "Select calibration file", "*" + CALIBRATION_FILE_TYPE
|
|
584
|
+
)
|
|
585
|
+
if filename:
|
|
586
|
+
self.recording_manager.current.calibration_file = filename
|
|
1046
587
|
self.ui.calibration_file_label.setText(filename)
|
|
1047
588
|
|
|
1048
589
|
def select_recording_file(self) -> None:
|
|
1049
590
|
"""User can select a recording file"""
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
file_dialog.setNameFilter(f"(*{' *'.join(RECORDING_FILE_TYPES)})")
|
|
1055
|
-
|
|
1056
|
-
if file_dialog.exec():
|
|
1057
|
-
selected_files = file_dialog.selectedFiles()
|
|
1058
|
-
filename = selected_files[0]
|
|
1059
|
-
filename = os.path.normpath(filename)
|
|
1060
|
-
self.recordings[self.recording_index].recording_file = filename
|
|
591
|
+
file_filter = f"(*{' *'.join(RECORDING_FILE_TYPES)})"
|
|
592
|
+
filename = select_existing_file(self, "Select recording file", file_filter)
|
|
593
|
+
if filename:
|
|
594
|
+
self.recording_manager.current.recording_file = filename
|
|
1061
595
|
self.ui.recording_file_label.setText(filename)
|
|
1062
596
|
|
|
1063
597
|
def show_recording_info(self) -> None:
|
|
1064
598
|
"""Update the UI to show info for the selected recording"""
|
|
1065
|
-
self.
|
|
1066
|
-
|
|
1067
|
-
)
|
|
1068
|
-
self.ui.
|
|
1069
|
-
|
|
1070
|
-
)
|
|
1071
|
-
self.ui.label_file_label.setText(
|
|
1072
|
-
self.recordings[self.recording_index].label_file
|
|
1073
|
-
)
|
|
1074
|
-
self.ui.calibration_file_label.setText(
|
|
1075
|
-
self.recordings[self.recording_index].calibration_file
|
|
1076
|
-
)
|
|
1077
|
-
|
|
1078
|
-
def update_epoch_length(self, new_value: int | float) -> None:
|
|
1079
|
-
"""Update the epoch length when the widget state changes
|
|
1080
|
-
|
|
1081
|
-
:param new_value: new epoch length
|
|
1082
|
-
"""
|
|
1083
|
-
self.epoch_length = new_value
|
|
1084
|
-
|
|
1085
|
-
def update_sampling_rate(self, new_value: int | float) -> None:
|
|
1086
|
-
"""Update recording's sampling rate when the widget state changes
|
|
1087
|
-
|
|
1088
|
-
:param new_value: new sampling rate
|
|
1089
|
-
"""
|
|
1090
|
-
self.recordings[self.recording_index].sampling_rate = new_value
|
|
599
|
+
recording = self.recording_manager.current
|
|
600
|
+
self.ui.sampling_rate_input.setValue(recording.sampling_rate)
|
|
601
|
+
self.ui.recording_file_label.setText(recording.recording_file)
|
|
602
|
+
self.ui.label_file_label.setText(recording.label_file)
|
|
603
|
+
self.ui.calibration_file_label.setText(recording.calibration_file)
|
|
1091
604
|
|
|
1092
605
|
def show_message(self, message: str) -> None:
|
|
1093
606
|
"""Display a new message to the user
|
|
@@ -1102,50 +615,22 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1102
615
|
scrollbar = self.ui.message_area.verticalScrollBar()
|
|
1103
616
|
scrollbar.setValue(scrollbar.maximum())
|
|
1104
617
|
|
|
1105
|
-
def select_recording(self,
|
|
1106
|
-
"""Callback for when a recording is selected
|
|
1107
|
-
|
|
1108
|
-
:param list_index: index of this recording in the list widget
|
|
1109
|
-
"""
|
|
1110
|
-
# get index of this recording
|
|
1111
|
-
self.recording_index = list_index
|
|
1112
|
-
# display information about this recording
|
|
618
|
+
def select_recording(self, _index: int) -> None:
|
|
619
|
+
"""Callback for when a recording is selected"""
|
|
1113
620
|
self.show_recording_info()
|
|
1114
621
|
self.ui.selected_recording_groupbox.setTitle(
|
|
1115
|
-
f"Data / actions for Recording {self.
|
|
622
|
+
f"Data / actions for Recording {self.recording_manager.current.name}"
|
|
1116
623
|
)
|
|
1117
624
|
|
|
1118
625
|
def add_recording(self) -> None:
|
|
1119
626
|
"""Add new recording to the list"""
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
# add new recording to list
|
|
1124
|
-
self.recordings.append(
|
|
1125
|
-
Recording(
|
|
1126
|
-
name=new_name,
|
|
1127
|
-
sampling_rate=self.recordings[self.recording_index].sampling_rate,
|
|
1128
|
-
widget=QListWidgetItem(
|
|
1129
|
-
f"Recording {new_name}", self.ui.recording_list_widget
|
|
1130
|
-
),
|
|
1131
|
-
)
|
|
1132
|
-
)
|
|
1133
|
-
|
|
1134
|
-
# display new list
|
|
1135
|
-
self.ui.recording_list_widget.addItem(self.recordings[-1].widget)
|
|
1136
|
-
self.ui.recording_list_widget.setCurrentRow(len(self.recordings) - 1)
|
|
1137
|
-
self.show_message(f"added Recording {new_name}")
|
|
627
|
+
current_sampling_rate = self.recording_manager.current.sampling_rate
|
|
628
|
+
recording = self.recording_manager.add(sampling_rate=current_sampling_rate)
|
|
629
|
+
self.show_message(f"added Recording {recording.name}")
|
|
1138
630
|
|
|
1139
631
|
def remove_recording(self) -> None:
|
|
1140
632
|
"""Delete selected recording from the list"""
|
|
1141
|
-
|
|
1142
|
-
current_list_index = self.ui.recording_list_widget.currentRow()
|
|
1143
|
-
_ = self.ui.recording_list_widget.takeItem(current_list_index)
|
|
1144
|
-
self.show_message(
|
|
1145
|
-
f"deleted Recording {self.recordings[current_list_index].name}"
|
|
1146
|
-
)
|
|
1147
|
-
del self.recordings[current_list_index]
|
|
1148
|
-
self.recording_index = self.ui.recording_list_widget.currentRow()
|
|
633
|
+
self.show_message(self.recording_manager.remove_current())
|
|
1149
634
|
|
|
1150
635
|
def show_user_manual(self) -> None:
|
|
1151
636
|
"""Show a popup window with the user manual"""
|
|
@@ -1161,310 +646,12 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1161
646
|
self.popup.setGeometry(QRect(100, 100, 600, 600))
|
|
1162
647
|
self.popup.show()
|
|
1163
648
|
|
|
1164
|
-
def initialize_settings_tab(self):
|
|
1165
|
-
"""Populate settings tab and assign its callbacks"""
|
|
1166
|
-
# store dictionary that maps digits to rows of widgets
|
|
1167
|
-
# in the settings tab
|
|
1168
|
-
self.settings_widgets = {
|
|
1169
|
-
1: StateSettings(
|
|
1170
|
-
digit=1,
|
|
1171
|
-
enabled_widget=self.ui.enable_state_1,
|
|
1172
|
-
name_widget=self.ui.state_name_1,
|
|
1173
|
-
is_scored_widget=self.ui.state_scored_1,
|
|
1174
|
-
frequency_widget=self.ui.state_frequency_1,
|
|
1175
|
-
),
|
|
1176
|
-
2: StateSettings(
|
|
1177
|
-
digit=2,
|
|
1178
|
-
enabled_widget=self.ui.enable_state_2,
|
|
1179
|
-
name_widget=self.ui.state_name_2,
|
|
1180
|
-
is_scored_widget=self.ui.state_scored_2,
|
|
1181
|
-
frequency_widget=self.ui.state_frequency_2,
|
|
1182
|
-
),
|
|
1183
|
-
3: StateSettings(
|
|
1184
|
-
digit=3,
|
|
1185
|
-
enabled_widget=self.ui.enable_state_3,
|
|
1186
|
-
name_widget=self.ui.state_name_3,
|
|
1187
|
-
is_scored_widget=self.ui.state_scored_3,
|
|
1188
|
-
frequency_widget=self.ui.state_frequency_3,
|
|
1189
|
-
),
|
|
1190
|
-
4: StateSettings(
|
|
1191
|
-
digit=4,
|
|
1192
|
-
enabled_widget=self.ui.enable_state_4,
|
|
1193
|
-
name_widget=self.ui.state_name_4,
|
|
1194
|
-
is_scored_widget=self.ui.state_scored_4,
|
|
1195
|
-
frequency_widget=self.ui.state_frequency_4,
|
|
1196
|
-
),
|
|
1197
|
-
5: StateSettings(
|
|
1198
|
-
digit=5,
|
|
1199
|
-
enabled_widget=self.ui.enable_state_5,
|
|
1200
|
-
name_widget=self.ui.state_name_5,
|
|
1201
|
-
is_scored_widget=self.ui.state_scored_5,
|
|
1202
|
-
frequency_widget=self.ui.state_frequency_5,
|
|
1203
|
-
),
|
|
1204
|
-
6: StateSettings(
|
|
1205
|
-
digit=6,
|
|
1206
|
-
enabled_widget=self.ui.enable_state_6,
|
|
1207
|
-
name_widget=self.ui.state_name_6,
|
|
1208
|
-
is_scored_widget=self.ui.state_scored_6,
|
|
1209
|
-
frequency_widget=self.ui.state_frequency_6,
|
|
1210
|
-
),
|
|
1211
|
-
7: StateSettings(
|
|
1212
|
-
digit=7,
|
|
1213
|
-
enabled_widget=self.ui.enable_state_7,
|
|
1214
|
-
name_widget=self.ui.state_name_7,
|
|
1215
|
-
is_scored_widget=self.ui.state_scored_7,
|
|
1216
|
-
frequency_widget=self.ui.state_frequency_7,
|
|
1217
|
-
),
|
|
1218
|
-
8: StateSettings(
|
|
1219
|
-
digit=8,
|
|
1220
|
-
enabled_widget=self.ui.enable_state_8,
|
|
1221
|
-
name_widget=self.ui.state_name_8,
|
|
1222
|
-
is_scored_widget=self.ui.state_scored_8,
|
|
1223
|
-
frequency_widget=self.ui.state_frequency_8,
|
|
1224
|
-
),
|
|
1225
|
-
9: StateSettings(
|
|
1226
|
-
digit=9,
|
|
1227
|
-
enabled_widget=self.ui.enable_state_9,
|
|
1228
|
-
name_widget=self.ui.state_name_9,
|
|
1229
|
-
is_scored_widget=self.ui.state_scored_9,
|
|
1230
|
-
frequency_widget=self.ui.state_frequency_9,
|
|
1231
|
-
),
|
|
1232
|
-
0: StateSettings(
|
|
1233
|
-
digit=0,
|
|
1234
|
-
enabled_widget=self.ui.enable_state_0,
|
|
1235
|
-
name_widget=self.ui.state_name_0,
|
|
1236
|
-
is_scored_widget=self.ui.state_scored_0,
|
|
1237
|
-
frequency_widget=self.ui.state_frequency_0,
|
|
1238
|
-
),
|
|
1239
|
-
}
|
|
1240
|
-
|
|
1241
|
-
# update widget state to display current config
|
|
1242
|
-
# UI defaults
|
|
1243
|
-
self.ui.default_epoch_input.setValue(self.epoch_length)
|
|
1244
|
-
self.ui.overwrite_default_checkbox.setChecked(self.only_overwrite_undefined)
|
|
1245
|
-
self.ui.confidence_setting_checkbox.setChecked(self.save_confidence_scores)
|
|
1246
|
-
self.ui.default_min_bout_length_spinbox.setValue(self.min_bout_length)
|
|
1247
|
-
self.ui.epochs_to_show_spinbox.setValue(self.default_epochs_to_show)
|
|
1248
|
-
self.ui.autoscroll_checkbox.setChecked(self.default_autoscroll_state)
|
|
1249
|
-
# EMG filter
|
|
1250
|
-
self.ui.emg_order_spinbox.setValue(self.emg_filter.order)
|
|
1251
|
-
self.ui.bp_lower_spinbox.setValue(self.emg_filter.bp_lower)
|
|
1252
|
-
self.ui.bp_upper_spinbox.setValue(self.emg_filter.bp_upper)
|
|
1253
|
-
# model training hyperparameters
|
|
1254
|
-
self.ui.batch_size_spinbox.setValue(self.hyperparameters.batch_size)
|
|
1255
|
-
self.ui.learning_rate_spinbox.setValue(self.hyperparameters.learning_rate)
|
|
1256
|
-
self.ui.momentum_spinbox.setValue(self.hyperparameters.momentum)
|
|
1257
|
-
self.ui.training_epochs_spinbox.setValue(self.hyperparameters.training_epochs)
|
|
1258
|
-
# brain states
|
|
1259
|
-
states = {b.digit: b for b in self.brain_state_set.brain_states}
|
|
1260
|
-
for digit in range(10):
|
|
1261
|
-
if digit in states.keys():
|
|
1262
|
-
self.settings_widgets[digit].enabled_widget.setChecked(True)
|
|
1263
|
-
self.settings_widgets[digit].name_widget.setText(states[digit].name)
|
|
1264
|
-
self.settings_widgets[digit].is_scored_widget.setChecked(
|
|
1265
|
-
states[digit].is_scored
|
|
1266
|
-
)
|
|
1267
|
-
self.settings_widgets[digit].frequency_widget.setValue(
|
|
1268
|
-
states[digit].frequency
|
|
1269
|
-
)
|
|
1270
|
-
else:
|
|
1271
|
-
self.settings_widgets[digit].enabled_widget.setChecked(False)
|
|
1272
|
-
self.settings_widgets[digit].name_widget.setEnabled(False)
|
|
1273
|
-
self.settings_widgets[digit].is_scored_widget.setEnabled(False)
|
|
1274
|
-
self.settings_widgets[digit].frequency_widget.setEnabled(False)
|
|
1275
|
-
|
|
1276
|
-
# set callbacks
|
|
1277
|
-
self.ui.emg_order_spinbox.valueChanged.connect(self.emg_filter_order_changed)
|
|
1278
|
-
self.ui.bp_lower_spinbox.valueChanged.connect(self.emg_filter_bp_lower_changed)
|
|
1279
|
-
self.ui.bp_upper_spinbox.valueChanged.connect(self.emg_filter_bp_upper_changed)
|
|
1280
|
-
self.ui.batch_size_spinbox.valueChanged.connect(self.hyperparameters_changed)
|
|
1281
|
-
self.ui.learning_rate_spinbox.valueChanged.connect(self.hyperparameters_changed)
|
|
1282
|
-
self.ui.momentum_spinbox.valueChanged.connect(self.hyperparameters_changed)
|
|
1283
|
-
self.ui.training_epochs_spinbox.valueChanged.connect(
|
|
1284
|
-
self.hyperparameters_changed
|
|
1285
|
-
)
|
|
1286
|
-
for digit in range(10):
|
|
1287
|
-
state = self.settings_widgets[digit]
|
|
1288
|
-
state.enabled_widget.stateChanged.connect(
|
|
1289
|
-
partial(self.set_brain_state_enabled, digit)
|
|
1290
|
-
)
|
|
1291
|
-
state.name_widget.editingFinished.connect(self.check_config_validity)
|
|
1292
|
-
state.is_scored_widget.stateChanged.connect(
|
|
1293
|
-
partial(self.is_scored_changed, digit)
|
|
1294
|
-
)
|
|
1295
|
-
state.frequency_widget.valueChanged.connect(self.check_config_validity)
|
|
1296
|
-
|
|
1297
|
-
def set_brain_state_enabled(self, digit, e) -> None:
|
|
1298
|
-
"""Called when user clicks "enabled" checkbox
|
|
1299
|
-
|
|
1300
|
-
:param digit: brain state digit
|
|
1301
|
-
:param e: unused but mandatory
|
|
1302
|
-
"""
|
|
1303
|
-
# get the widgets for this brain state
|
|
1304
|
-
state = self.settings_widgets[digit]
|
|
1305
|
-
# update state of these widgets
|
|
1306
|
-
is_checked = state.enabled_widget.isChecked()
|
|
1307
|
-
for widget in [
|
|
1308
|
-
state.name_widget,
|
|
1309
|
-
state.is_scored_widget,
|
|
1310
|
-
]:
|
|
1311
|
-
widget.setEnabled(is_checked)
|
|
1312
|
-
state.frequency_widget.setEnabled(
|
|
1313
|
-
is_checked and state.is_scored_widget.isChecked()
|
|
1314
|
-
)
|
|
1315
|
-
if not is_checked:
|
|
1316
|
-
state.name_widget.setText("")
|
|
1317
|
-
state.frequency_widget.setValue(0)
|
|
1318
|
-
# check that configuration is valid
|
|
1319
|
-
_ = self.check_config_validity()
|
|
1320
|
-
|
|
1321
|
-
def is_scored_changed(self, digit, e) -> None:
|
|
1322
|
-
"""Called when user sets whether a state is scored
|
|
1323
|
-
|
|
1324
|
-
:param digit: brain state digit
|
|
1325
|
-
:param e: unused, but mandatory
|
|
1326
|
-
"""
|
|
1327
|
-
# get the widgets for this brain state
|
|
1328
|
-
state = self.settings_widgets[digit]
|
|
1329
|
-
# update the state of these widgets
|
|
1330
|
-
is_checked = state.is_scored_widget.isChecked()
|
|
1331
|
-
state.frequency_widget.setEnabled(is_checked)
|
|
1332
|
-
if not is_checked:
|
|
1333
|
-
state.frequency_widget.setValue(0)
|
|
1334
|
-
# check that configuration is valid
|
|
1335
|
-
_ = self.check_config_validity()
|
|
1336
|
-
|
|
1337
|
-
def emg_filter_order_changed(self, new_value: int) -> None:
|
|
1338
|
-
"""Called when user modifies EMG filter order
|
|
1339
|
-
|
|
1340
|
-
:param new_value: new EMG filter order
|
|
1341
|
-
"""
|
|
1342
|
-
self.emg_filter.order = new_value
|
|
1343
|
-
|
|
1344
|
-
def emg_filter_bp_lower_changed(self, new_value: int | float) -> None:
|
|
1345
|
-
"""Called when user modifies EMG filter lower cutoff
|
|
1346
|
-
|
|
1347
|
-
:param new_value: new lower bandpass cutoff frequency
|
|
1348
|
-
"""
|
|
1349
|
-
self.emg_filter.bp_lower = new_value
|
|
1350
|
-
_ = self.check_config_validity()
|
|
1351
|
-
|
|
1352
|
-
def emg_filter_bp_upper_changed(self, new_value: int | float) -> None:
|
|
1353
|
-
"""Called when user modifies EMG filter upper cutoff
|
|
1354
|
-
|
|
1355
|
-
:param new_value: new upper bandpass cutoff frequency
|
|
1356
|
-
"""
|
|
1357
|
-
self.emg_filter.bp_upper = new_value
|
|
1358
|
-
_ = self.check_config_validity()
|
|
1359
|
-
|
|
1360
|
-
def hyperparameters_changed(self, new_value) -> None:
|
|
1361
|
-
"""Called when user modifies model training hyperparameters
|
|
1362
|
-
|
|
1363
|
-
:param new_value: unused
|
|
1364
|
-
"""
|
|
1365
|
-
self.hyperparameters = Hyperparameters(
|
|
1366
|
-
batch_size=self.ui.batch_size_spinbox.value(),
|
|
1367
|
-
learning_rate=self.ui.learning_rate_spinbox.value(),
|
|
1368
|
-
momentum=self.ui.momentum_spinbox.value(),
|
|
1369
|
-
training_epochs=self.ui.training_epochs_spinbox.value(),
|
|
1370
|
-
)
|
|
1371
|
-
|
|
1372
|
-
def check_config_validity(self) -> str:
|
|
1373
|
-
"""Check if brain state configuration on screen is valid"""
|
|
1374
|
-
# error message, if we get one
|
|
1375
|
-
message = None
|
|
1376
|
-
|
|
1377
|
-
# strip whitespace from brain state names and update display
|
|
1378
|
-
for digit in range(10):
|
|
1379
|
-
state = self.settings_widgets[digit]
|
|
1380
|
-
current_name = state.name_widget.text()
|
|
1381
|
-
formatted_name = current_name.strip()
|
|
1382
|
-
if current_name != formatted_name:
|
|
1383
|
-
state.name_widget.setText(formatted_name)
|
|
1384
|
-
|
|
1385
|
-
# check if names are unique and frequencies add up to 1
|
|
1386
|
-
names = []
|
|
1387
|
-
frequencies = []
|
|
1388
|
-
for digit in range(10):
|
|
1389
|
-
state = self.settings_widgets[digit]
|
|
1390
|
-
if state.enabled_widget.isChecked():
|
|
1391
|
-
names.append(state.name_widget.text())
|
|
1392
|
-
frequencies.append(state.frequency_widget.value())
|
|
1393
|
-
if len(names) != len(set(names)):
|
|
1394
|
-
message = "Error: names must be unique"
|
|
1395
|
-
if sum(frequencies) != 1:
|
|
1396
|
-
message = "Error: sum(frequencies) != 1"
|
|
1397
|
-
|
|
1398
|
-
# check validity of EMG filter settings
|
|
1399
|
-
if self.emg_filter.bp_lower >= self.emg_filter.bp_upper:
|
|
1400
|
-
message = "Error: EMG filter cutoff frequencies are invalid"
|
|
1401
|
-
|
|
1402
|
-
if message is not None:
|
|
1403
|
-
self.ui.save_config_status.setText(message)
|
|
1404
|
-
self.ui.save_config_button.setEnabled(False)
|
|
1405
|
-
return message
|
|
1406
|
-
|
|
1407
|
-
self.ui.save_config_button.setEnabled(True)
|
|
1408
|
-
self.ui.save_config_status.setText("")
|
|
1409
|
-
|
|
1410
|
-
def save_brain_state_config(self):
|
|
1411
|
-
"""Save configuration to file"""
|
|
1412
|
-
# check that configuration is valid
|
|
1413
|
-
error_message = self.check_config_validity()
|
|
1414
|
-
if error_message is not None:
|
|
1415
|
-
return
|
|
1416
|
-
|
|
1417
|
-
# build a BrainStateMapper object from the current configuration
|
|
1418
|
-
brain_states = list()
|
|
1419
|
-
for digit in range(10):
|
|
1420
|
-
state = self.settings_widgets[digit]
|
|
1421
|
-
if state.enabled_widget.isChecked():
|
|
1422
|
-
brain_states.append(
|
|
1423
|
-
BrainState(
|
|
1424
|
-
name=state.name_widget.text(),
|
|
1425
|
-
digit=digit,
|
|
1426
|
-
is_scored=state.is_scored_widget.isChecked(),
|
|
1427
|
-
frequency=state.frequency_widget.value(),
|
|
1428
|
-
)
|
|
1429
|
-
)
|
|
1430
|
-
self.brain_state_set = BrainStateSet(brain_states, UNDEFINED_LABEL)
|
|
1431
|
-
|
|
1432
|
-
# save to file
|
|
1433
|
-
save_config(
|
|
1434
|
-
brain_state_set=self.brain_state_set,
|
|
1435
|
-
default_epoch_length=self.ui.default_epoch_input.value(),
|
|
1436
|
-
overwrite_setting=self.ui.overwrite_default_checkbox.isChecked(),
|
|
1437
|
-
save_confidence_setting=self.ui.confidence_setting_checkbox.isChecked(),
|
|
1438
|
-
min_bout_length=self.ui.default_min_bout_length_spinbox.value(),
|
|
1439
|
-
emg_filter=EMGFilter(
|
|
1440
|
-
order=self.emg_filter.order,
|
|
1441
|
-
bp_lower=self.emg_filter.bp_lower,
|
|
1442
|
-
bp_upper=self.emg_filter.bp_upper,
|
|
1443
|
-
),
|
|
1444
|
-
hyperparameters=Hyperparameters(
|
|
1445
|
-
batch_size=self.hyperparameters.batch_size,
|
|
1446
|
-
learning_rate=self.hyperparameters.learning_rate,
|
|
1447
|
-
momentum=self.hyperparameters.momentum,
|
|
1448
|
-
training_epochs=self.hyperparameters.training_epochs,
|
|
1449
|
-
),
|
|
1450
|
-
epochs_to_show=self.ui.epochs_to_show_spinbox.value(),
|
|
1451
|
-
autoscroll_state=self.ui.autoscroll_checkbox.isChecked(),
|
|
1452
|
-
)
|
|
1453
|
-
self.ui.save_config_status.setText("configuration saved")
|
|
1454
|
-
|
|
1455
|
-
def reset_emg_filter_settings(self) -> None:
|
|
1456
|
-
self.ui.emg_order_spinbox.setValue(DEFAULT_EMG_FILTER_ORDER)
|
|
1457
|
-
self.ui.bp_lower_spinbox.setValue(DEFAULT_EMG_BP_LOWER)
|
|
1458
|
-
self.ui.bp_upper_spinbox.setValue(DEFAULT_EMG_BP_UPPER)
|
|
1459
|
-
|
|
1460
|
-
def reset_hyperparams_settings(self):
|
|
1461
|
-
self.ui.batch_size_spinbox.setValue(DEFAULT_BATCH_SIZE)
|
|
1462
|
-
self.ui.learning_rate_spinbox.setValue(DEFAULT_LEARNING_RATE)
|
|
1463
|
-
self.ui.momentum_spinbox.setValue(DEFAULT_MOMENTUM)
|
|
1464
|
-
self.ui.training_epochs_spinbox.setValue(DEFAULT_TRAINING_EPOCHS)
|
|
1465
|
-
|
|
1466
649
|
|
|
1467
650
|
def run_primary_window() -> None:
|
|
651
|
+
logging.basicConfig(
|
|
652
|
+
level=logging.INFO,
|
|
653
|
+
format="%(levelname)s - %(name)s - %(message)s",
|
|
654
|
+
)
|
|
1468
655
|
app = QApplication(sys.argv)
|
|
1469
656
|
AccuSleepWindow()
|
|
1470
657
|
sys.exit(app.exec())
|