accusleepy 0.8.0__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accusleepy/bouts.py +46 -38
- accusleepy/brain_state_set.py +6 -4
- accusleepy/classification.py +14 -50
- accusleepy/constants.py +3 -0
- accusleepy/fileio.py +58 -27
- accusleepy/gui/dialogs.py +40 -0
- accusleepy/gui/images/primary_window.png +0 -0
- accusleepy/gui/main.py +212 -1026
- accusleepy/gui/manual_scoring.py +5 -13
- accusleepy/gui/primary_window.py +7 -9
- accusleepy/gui/primary_window.ui +6 -8
- accusleepy/gui/recording_manager.py +110 -0
- accusleepy/gui/settings_widget.py +409 -0
- accusleepy/gui/text/main_guide.md +1 -1
- accusleepy/models.py +1 -1
- accusleepy/services.py +581 -0
- accusleepy/signal_processing.py +110 -38
- accusleepy/temperature_scaling.py +14 -8
- accusleepy/validation.py +67 -2
- {accusleepy-0.8.0.dist-info → accusleepy-0.9.2.dist-info}/METADATA +3 -5
- {accusleepy-0.8.0.dist-info → accusleepy-0.9.2.dist-info}/RECORD +22 -18
- {accusleepy-0.8.0.dist-info → accusleepy-0.9.2.dist-info}/WHEEL +1 -1
accusleepy/gui/main.py
CHANGED
|
@@ -1,15 +1,13 @@
|
|
|
1
1
|
# AccuSleePy main window
|
|
2
2
|
# Icon source: Arkinasi, https://www.flaticon.com/authors/arkinasi
|
|
3
3
|
|
|
4
|
-
import
|
|
4
|
+
import logging
|
|
5
5
|
import os
|
|
6
|
-
import shutil
|
|
7
6
|
import sys
|
|
8
7
|
from dataclasses import dataclass
|
|
9
8
|
from functools import partial
|
|
10
9
|
|
|
11
10
|
import numpy as np
|
|
12
|
-
import toml
|
|
13
11
|
from PySide6.QtCore import (
|
|
14
12
|
QEvent,
|
|
15
13
|
QKeyCombination,
|
|
@@ -21,31 +19,17 @@ from PySide6.QtCore import (
|
|
|
21
19
|
from PySide6.QtGui import QKeySequence, QShortcut
|
|
22
20
|
from PySide6.QtWidgets import (
|
|
23
21
|
QApplication,
|
|
24
|
-
QCheckBox,
|
|
25
|
-
QDoubleSpinBox,
|
|
26
|
-
QFileDialog,
|
|
27
22
|
QLabel,
|
|
28
|
-
QListWidgetItem,
|
|
29
23
|
QMainWindow,
|
|
30
24
|
QTextBrowser,
|
|
31
25
|
QVBoxLayout,
|
|
32
26
|
QWidget,
|
|
33
27
|
)
|
|
34
28
|
|
|
35
|
-
from accusleepy.
|
|
36
|
-
from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
|
|
29
|
+
from accusleepy.brain_state_set import BRAIN_STATES_KEY
|
|
37
30
|
from accusleepy.constants import (
|
|
38
|
-
ANNOTATIONS_FILENAME,
|
|
39
|
-
CALIBRATION_ANNOTATION_FILENAME,
|
|
40
31
|
CALIBRATION_FILE_TYPE,
|
|
41
32
|
DEFAULT_MODEL_TYPE,
|
|
42
|
-
DEFAULT_EMG_FILTER_ORDER,
|
|
43
|
-
DEFAULT_EMG_BP_LOWER,
|
|
44
|
-
DEFAULT_EMG_BP_UPPER,
|
|
45
|
-
DEFAULT_BATCH_SIZE,
|
|
46
|
-
DEFAULT_LEARNING_RATE,
|
|
47
|
-
DEFAULT_MOMENTUM,
|
|
48
|
-
DEFAULT_TRAINING_EPOCHS,
|
|
49
33
|
LABEL_FILE_TYPE,
|
|
50
34
|
MESSAGE_BOX_MAX_DEPTH,
|
|
51
35
|
MODEL_FILE_TYPE,
|
|
@@ -55,31 +39,28 @@ from accusleepy.constants import (
|
|
|
55
39
|
UNDEFINED_LABEL,
|
|
56
40
|
)
|
|
57
41
|
from accusleepy.fileio import (
|
|
58
|
-
Recording,
|
|
59
|
-
load_calibration_file,
|
|
60
42
|
load_config,
|
|
61
43
|
load_labels,
|
|
62
44
|
load_recording,
|
|
63
|
-
|
|
64
|
-
save_config,
|
|
65
|
-
save_labels,
|
|
66
|
-
save_recording_list,
|
|
67
|
-
EMGFilter,
|
|
68
|
-
Hyperparameters,
|
|
45
|
+
get_version,
|
|
69
46
|
)
|
|
47
|
+
from accusleepy.gui.dialogs import select_existing_file, select_save_location
|
|
70
48
|
from accusleepy.gui.manual_scoring import ManualScoringWindow
|
|
71
49
|
from accusleepy.gui.primary_window import Ui_PrimaryWindow
|
|
72
|
-
from accusleepy.
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
50
|
+
from accusleepy.gui.recording_manager import RecordingListManager
|
|
51
|
+
from accusleepy.gui.settings_widget import SettingsWidget
|
|
52
|
+
from accusleepy.services import (
|
|
53
|
+
LoadedModel,
|
|
54
|
+
TrainingService,
|
|
55
|
+
check_single_file_inputs,
|
|
56
|
+
create_calibration,
|
|
57
|
+
score_recording_list,
|
|
80
58
|
)
|
|
59
|
+
from accusleepy.validation import validate_and_correct_labels
|
|
60
|
+
from accusleepy.signal_processing import resample_and_standardize
|
|
61
|
+
from accusleepy.validation import check_config_consistency
|
|
81
62
|
|
|
82
|
-
|
|
63
|
+
logger = logging.getLogger(__name__)
|
|
83
64
|
|
|
84
65
|
# on Windows, prevent dark mode from changing the visual style
|
|
85
66
|
if os.name == "nt":
|
|
@@ -91,14 +72,22 @@ MAIN_GUIDE_FILE = os.path.normpath(r"text/main_guide.md")
|
|
|
91
72
|
|
|
92
73
|
|
|
93
74
|
@dataclass
|
|
94
|
-
class
|
|
95
|
-
"""
|
|
75
|
+
class TrainingSettings:
|
|
76
|
+
"""Settings for training a new model"""
|
|
77
|
+
|
|
78
|
+
epochs_per_img: int = 9
|
|
79
|
+
delete_images: bool = True
|
|
80
|
+
model_type: str = DEFAULT_MODEL_TYPE
|
|
81
|
+
calibrate: bool = True
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass
|
|
85
|
+
class ScoringSettings:
|
|
86
|
+
"""Settings for scoring a recording"""
|
|
96
87
|
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
is_scored_widget: QCheckBox
|
|
101
|
-
frequency_widget: QDoubleSpinBox
|
|
88
|
+
only_overwrite_undefined: bool
|
|
89
|
+
save_confidence_scores: bool
|
|
90
|
+
min_bout_length: int | float
|
|
102
91
|
|
|
103
92
|
|
|
104
93
|
class AccuSleepWindow(QMainWindow):
|
|
@@ -112,66 +101,42 @@ class AccuSleepWindow(QMainWindow):
|
|
|
112
101
|
self.ui.setupUi(self)
|
|
113
102
|
self.setWindowTitle("AccuSleePy")
|
|
114
103
|
|
|
115
|
-
#
|
|
116
|
-
(
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
self.settings_widgets = None
|
|
129
|
-
self.initialize_settings_tab()
|
|
104
|
+
# Load configuration
|
|
105
|
+
loaded_config = load_config()
|
|
106
|
+
|
|
107
|
+
# Apply default values from the configuration
|
|
108
|
+
self.epoch_length = loaded_config.default_epoch_length
|
|
109
|
+
self.scoring = ScoringSettings(
|
|
110
|
+
only_overwrite_undefined=loaded_config.overwrite_setting,
|
|
111
|
+
save_confidence_scores=loaded_config.save_confidence_setting,
|
|
112
|
+
min_bout_length=loaded_config.min_bout_length,
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
# Initialize settings tab (manages Settings tab UI and saved config values)
|
|
116
|
+
self.config = SettingsWidget(ui=self.ui, config=loaded_config, parent=self)
|
|
130
117
|
|
|
131
118
|
# initialize info about the recordings, classification data / settings
|
|
132
119
|
self.ui.epoch_length_input.setValue(self.epoch_length)
|
|
133
|
-
self.ui.overwritecheckbox.setChecked(self.only_overwrite_undefined)
|
|
134
|
-
self.ui.save_confidence_checkbox.setChecked(self.save_confidence_scores)
|
|
135
|
-
self.ui.bout_length_input.setValue(self.min_bout_length)
|
|
136
|
-
self.model = None
|
|
120
|
+
self.ui.overwritecheckbox.setChecked(self.scoring.only_overwrite_undefined)
|
|
121
|
+
self.ui.save_confidence_checkbox.setChecked(self.scoring.save_confidence_scores)
|
|
122
|
+
self.ui.bout_length_input.setValue(self.scoring.min_bout_length)
|
|
137
123
|
|
|
138
|
-
#
|
|
139
|
-
self.
|
|
140
|
-
self.delete_training_images = True
|
|
141
|
-
self.model_type = DEFAULT_MODEL_TYPE
|
|
142
|
-
self.calibrate_trained_model = True
|
|
124
|
+
# loaded classification model and its metadata
|
|
125
|
+
self.loaded_model = LoadedModel()
|
|
143
126
|
|
|
144
|
-
#
|
|
145
|
-
self.
|
|
146
|
-
self.model_epochs_per_img = None
|
|
127
|
+
# settings for training new models
|
|
128
|
+
self.training = TrainingSettings()
|
|
147
129
|
|
|
148
130
|
# set up the list of recordings
|
|
149
|
-
|
|
150
|
-
|
|
131
|
+
self.recording_manager = RecordingListManager(
|
|
132
|
+
self.ui.recording_list_widget, parent=self
|
|
151
133
|
)
|
|
152
|
-
self.ui.recording_list_widget.addItem(first_recording.widget)
|
|
153
|
-
self.ui.recording_list_widget.setCurrentRow(0)
|
|
154
|
-
# index of currently selected recording in the list
|
|
155
|
-
self.recording_index = 0
|
|
156
|
-
# list of recordings the user has added
|
|
157
|
-
self.recordings = [first_recording]
|
|
158
134
|
|
|
159
135
|
# messages to display
|
|
160
136
|
self.messages = []
|
|
161
137
|
|
|
162
138
|
# display current version
|
|
163
|
-
|
|
164
|
-
toml_file = os.path.join(
|
|
165
|
-
os.path.dirname(
|
|
166
|
-
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
167
|
-
),
|
|
168
|
-
"pyproject.toml",
|
|
169
|
-
)
|
|
170
|
-
if os.path.isfile(toml_file):
|
|
171
|
-
toml_data = toml.load(toml_file)
|
|
172
|
-
if "project" in toml_data and "version" in toml_data["project"]:
|
|
173
|
-
version = toml_data["project"]["version"]
|
|
174
|
-
self.ui.version_label.setText(f"v{version}")
|
|
139
|
+
self.ui.version_label.setText(f"v{get_version()}")
|
|
175
140
|
|
|
176
141
|
# user input: keyboard shortcuts
|
|
177
142
|
keypress_quit = QShortcut(
|
|
@@ -184,8 +149,12 @@ class AccuSleepWindow(QMainWindow):
|
|
|
184
149
|
self.ui.add_button.clicked.connect(self.add_recording)
|
|
185
150
|
self.ui.remove_button.clicked.connect(self.remove_recording)
|
|
186
151
|
self.ui.recording_list_widget.currentRowChanged.connect(self.select_recording)
|
|
187
|
-
self.ui.sampling_rate_input.valueChanged.connect(
|
|
188
|
-
|
|
152
|
+
self.ui.sampling_rate_input.valueChanged.connect(
|
|
153
|
+
lambda v: setattr(self.recording_manager.current, "sampling_rate", v)
|
|
154
|
+
)
|
|
155
|
+
self.ui.epoch_length_input.valueChanged.connect(
|
|
156
|
+
lambda v: setattr(self, "epoch_length", v)
|
|
157
|
+
)
|
|
189
158
|
self.ui.recording_file_button.clicked.connect(self.select_recording_file)
|
|
190
159
|
self.ui.select_label_button.clicked.connect(self.select_label_file)
|
|
191
160
|
self.ui.create_label_button.clicked.connect(self.create_label_file)
|
|
@@ -193,27 +162,33 @@ class AccuSleepWindow(QMainWindow):
|
|
|
193
162
|
self.ui.create_calibration_button.clicked.connect(self.create_calibration_file)
|
|
194
163
|
self.ui.select_calibration_button.clicked.connect(self.select_calibration_file)
|
|
195
164
|
self.ui.load_model_button.clicked.connect(partial(self.load_model, None))
|
|
196
|
-
self.ui.score_all_button.clicked.connect(self.
|
|
197
|
-
self.ui.overwritecheckbox.stateChanged.connect(
|
|
165
|
+
self.ui.score_all_button.clicked.connect(self.score_recordings)
|
|
166
|
+
self.ui.overwritecheckbox.stateChanged.connect(
|
|
167
|
+
lambda v: setattr(self.scoring, "only_overwrite_undefined", bool(v))
|
|
168
|
+
)
|
|
198
169
|
self.ui.save_confidence_checkbox.stateChanged.connect(
|
|
199
|
-
self.
|
|
170
|
+
lambda v: setattr(self.scoring, "save_confidence_scores", bool(v))
|
|
171
|
+
)
|
|
172
|
+
self.ui.bout_length_input.valueChanged.connect(
|
|
173
|
+
lambda v: setattr(self.scoring, "min_bout_length", v)
|
|
200
174
|
)
|
|
201
|
-
self.ui.bout_length_input.valueChanged.connect(self.update_min_bout_length)
|
|
202
175
|
self.ui.user_manual_button.clicked.connect(self.show_user_manual)
|
|
203
|
-
self.ui.image_number_input.valueChanged.connect(
|
|
204
|
-
|
|
176
|
+
self.ui.image_number_input.valueChanged.connect(
|
|
177
|
+
lambda v: setattr(self.training, "epochs_per_img", v)
|
|
178
|
+
)
|
|
179
|
+
self.ui.delete_image_box.stateChanged.connect(
|
|
180
|
+
lambda v: setattr(self.training, "delete_images", bool(v))
|
|
181
|
+
)
|
|
205
182
|
self.ui.calibrate_checkbox.stateChanged.connect(
|
|
206
183
|
self.update_training_calibration
|
|
207
184
|
)
|
|
208
185
|
self.ui.train_model_button.clicked.connect(self.train_model)
|
|
209
|
-
self.ui.save_config_button.clicked.connect(self.
|
|
186
|
+
self.ui.save_config_button.clicked.connect(self.config.save_config)
|
|
210
187
|
self.ui.export_button.clicked.connect(self.export_recording_list)
|
|
211
188
|
self.ui.import_button.clicked.connect(self.import_recording_list)
|
|
212
189
|
self.ui.default_type_button.toggled.connect(self.model_type_radio_buttons)
|
|
213
|
-
self.ui.reset_emg_params_button.clicked.connect(self.
|
|
214
|
-
self.ui.reset_hyperparams_button.clicked.connect(
|
|
215
|
-
self.reset_hyperparams_settings
|
|
216
|
-
)
|
|
190
|
+
self.ui.reset_emg_params_button.clicked.connect(self.config.reset_emg_filter)
|
|
191
|
+
self.ui.reset_hyperparams_button.clicked.connect(self.config.reset_hyperparams)
|
|
217
192
|
|
|
218
193
|
# user input: drag and drop
|
|
219
194
|
self.ui.recording_file_label.installEventFilter(self)
|
|
@@ -228,52 +203,29 @@ class AccuSleepWindow(QMainWindow):
|
|
|
228
203
|
|
|
229
204
|
:param default_selected: whether default option is selected
|
|
230
205
|
"""
|
|
231
|
-
self.model_type = (
|
|
206
|
+
self.training.model_type = (
|
|
232
207
|
DEFAULT_MODEL_TYPE if default_selected else REAL_TIME_MODEL_TYPE
|
|
233
208
|
)
|
|
234
209
|
|
|
235
210
|
def export_recording_list(self) -> None:
|
|
236
211
|
"""Save current list of recordings to file"""
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
self,
|
|
240
|
-
caption="Save list of recordings as",
|
|
241
|
-
filter="*" + RECORDING_LIST_FILE_TYPE,
|
|
212
|
+
filename = select_save_location(
|
|
213
|
+
self, "Save list of recordings as", "*" + RECORDING_LIST_FILE_TYPE
|
|
242
214
|
)
|
|
243
215
|
if not filename:
|
|
244
216
|
return
|
|
245
|
-
|
|
246
|
-
save_recording_list(filename=filename, recordings=self.recordings)
|
|
217
|
+
self.recording_manager.export_to_file(filename)
|
|
247
218
|
self.show_message(f"Saved list of recordings to {filename}")
|
|
248
219
|
|
|
249
220
|
def import_recording_list(self):
|
|
250
221
|
"""Load list of recordings from file, overwriting current list"""
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
file_dialog.setNameFilter("*" + RECORDING_LIST_FILE_TYPE)
|
|
256
|
-
|
|
257
|
-
if file_dialog.exec():
|
|
258
|
-
selected_files = file_dialog.selectedFiles()
|
|
259
|
-
filename = selected_files[0]
|
|
260
|
-
filename = os.path.normpath(filename)
|
|
261
|
-
else:
|
|
222
|
+
filename = select_existing_file(
|
|
223
|
+
self, "Select list of recordings", "*" + RECORDING_LIST_FILE_TYPE
|
|
224
|
+
)
|
|
225
|
+
if not filename:
|
|
262
226
|
return
|
|
263
227
|
|
|
264
|
-
|
|
265
|
-
self.ui.recording_list_widget.clear()
|
|
266
|
-
# overwrite current list
|
|
267
|
-
self.recordings = load_recording_list(filename)
|
|
268
|
-
|
|
269
|
-
for recording in self.recordings:
|
|
270
|
-
recording.widget = QListWidgetItem(
|
|
271
|
-
f"Recording {recording.name}", self.ui.recording_list_widget
|
|
272
|
-
)
|
|
273
|
-
self.ui.recording_list_widget.addItem(self.recordings[-1].widget)
|
|
274
|
-
|
|
275
|
-
# display new list
|
|
276
|
-
self.ui.recording_list_widget.setCurrentRow(0)
|
|
228
|
+
self.recording_manager.import_from_file(filename)
|
|
277
229
|
self.show_message(f"Loaded list of recordings from {filename}")
|
|
278
230
|
|
|
279
231
|
def eventFilter(self, obj: QObject, event: QEvent) -> bool:
|
|
@@ -298,20 +250,21 @@ class AccuSleepWindow(QMainWindow):
|
|
|
298
250
|
|
|
299
251
|
if filename is None:
|
|
300
252
|
return super().eventFilter(obj, event)
|
|
253
|
+
filename = str(filename)
|
|
301
254
|
|
|
302
255
|
_, file_extension = os.path.splitext(filename)
|
|
303
256
|
|
|
304
257
|
if obj == self.ui.recording_file_label:
|
|
305
258
|
if file_extension in RECORDING_FILE_TYPES:
|
|
306
|
-
self.
|
|
259
|
+
self.recording_manager.current.recording_file = filename
|
|
307
260
|
self.ui.recording_file_label.setText(filename)
|
|
308
261
|
elif obj == self.ui.label_file_label:
|
|
309
262
|
if file_extension == LABEL_FILE_TYPE:
|
|
310
|
-
self.
|
|
263
|
+
self.recording_manager.current.label_file = filename
|
|
311
264
|
self.ui.label_file_label.setText(filename)
|
|
312
265
|
elif obj == self.ui.calibration_file_label:
|
|
313
266
|
if file_extension == CALIBRATION_FILE_TYPE:
|
|
314
|
-
self.
|
|
267
|
+
self.recording_manager.current.calibration_file = filename
|
|
315
268
|
self.ui.calibration_file_label.setText(filename)
|
|
316
269
|
elif obj == self.ui.model_label:
|
|
317
270
|
self.load_model(filename=filename)
|
|
@@ -319,335 +272,69 @@ class AccuSleepWindow(QMainWindow):
|
|
|
319
272
|
return super().eventFilter(obj, event)
|
|
320
273
|
|
|
321
274
|
def train_model(self) -> None:
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
self
|
|
325
|
-
and self.training_epochs_per_img % 2 == 0
|
|
326
|
-
):
|
|
327
|
-
self.show_message(
|
|
328
|
-
(
|
|
329
|
-
"ERROR: for the default model type, number of epochs "
|
|
330
|
-
"per image must be an odd number."
|
|
331
|
-
)
|
|
332
|
-
)
|
|
333
|
-
return
|
|
334
|
-
|
|
335
|
-
# determine fraction of training data to use for calibration
|
|
336
|
-
if self.calibrate_trained_model:
|
|
337
|
-
calibration_fraction = self.ui.calibration_spinbox.value() / 100
|
|
338
|
-
else:
|
|
339
|
-
calibration_fraction = 0
|
|
340
|
-
|
|
341
|
-
# check some inputs for each recording
|
|
342
|
-
for recording_index in range(len(self.recordings)):
|
|
343
|
-
error_message = self.check_single_file_inputs(recording_index)
|
|
344
|
-
if error_message:
|
|
345
|
-
self.show_message(
|
|
346
|
-
f"ERROR (recording {self.recordings[recording_index].name}): {error_message}"
|
|
347
|
-
)
|
|
348
|
-
return
|
|
349
|
-
|
|
350
|
-
# get filename for the new model
|
|
351
|
-
model_filename, _ = QFileDialog.getSaveFileName(
|
|
352
|
-
self,
|
|
353
|
-
caption="Save classification model file as",
|
|
354
|
-
filter="*" + MODEL_FILE_TYPE,
|
|
275
|
+
"""Train a classification model using the current recordings."""
|
|
276
|
+
model_filename = select_save_location(
|
|
277
|
+
self, "Save classification model file as", "*" + MODEL_FILE_TYPE
|
|
355
278
|
)
|
|
356
279
|
if not model_filename:
|
|
357
280
|
self.show_message("Model training canceled, no filename given")
|
|
358
281
|
return
|
|
359
|
-
model_filename = os.path.normpath(model_filename)
|
|
360
282
|
|
|
361
|
-
#
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
)
|
|
367
|
-
|
|
368
|
-
if os.path.exists(temp_image_dir): # unlikely
|
|
369
|
-
self.show_message(
|
|
370
|
-
"Warning: training image folder exists, will be overwritten"
|
|
371
|
-
)
|
|
372
|
-
os.makedirs(temp_image_dir, exist_ok=True)
|
|
283
|
+
# Determine calibration fraction
|
|
284
|
+
if self.training.calibrate:
|
|
285
|
+
calibration_fraction = self.ui.calibration_spinbox.value() / 100
|
|
286
|
+
else:
|
|
287
|
+
calibration_fraction = 0
|
|
373
288
|
|
|
374
|
-
#
|
|
289
|
+
# Show progress message
|
|
375
290
|
self.show_message("Training, please wait. See console for progress updates.")
|
|
376
|
-
if not self.delete_training_images:
|
|
377
|
-
self.show_message((f"Creating training images in {temp_image_dir}"))
|
|
378
|
-
else:
|
|
379
|
-
self.show_message(
|
|
380
|
-
(f"Creating temporary folder of training images: {temp_image_dir}")
|
|
381
|
-
)
|
|
382
291
|
self.ui.message_area.repaint()
|
|
383
292
|
QApplication.processEvents()
|
|
384
|
-
print("Creating training images")
|
|
385
|
-
failed_recordings = create_training_images(
|
|
386
|
-
recordings=self.recordings,
|
|
387
|
-
output_path=temp_image_dir,
|
|
388
|
-
epoch_length=self.epoch_length,
|
|
389
|
-
epochs_per_img=self.training_epochs_per_img,
|
|
390
|
-
brain_state_set=self.brain_state_set,
|
|
391
|
-
model_type=self.model_type,
|
|
392
|
-
calibration_fraction=calibration_fraction,
|
|
393
|
-
emg_filter=self.emg_filter,
|
|
394
|
-
)
|
|
395
|
-
if len(failed_recordings) > 0:
|
|
396
|
-
if len(failed_recordings) == len(self.recordings):
|
|
397
|
-
self.show_message("ERROR: no recordings were valid!")
|
|
398
|
-
return
|
|
399
|
-
else:
|
|
400
|
-
self.show_message(
|
|
401
|
-
(
|
|
402
|
-
"WARNING: the following recordings could not be "
|
|
403
|
-
"loaded and will not be used for training: "
|
|
404
|
-
f"{', '.join([str(r) for r in failed_recordings])}"
|
|
405
|
-
)
|
|
406
|
-
)
|
|
407
293
|
|
|
408
|
-
#
|
|
409
|
-
self.show_message
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
print("Training model")
|
|
413
|
-
from accusleepy.classification import create_dataloader, train_ssann
|
|
414
|
-
from accusleepy.models import save_model
|
|
415
|
-
from accusleepy.temperature_scaling import ModelWithTemperature
|
|
416
|
-
|
|
417
|
-
model = train_ssann(
|
|
418
|
-
annotations_file=os.path.join(temp_image_dir, ANNOTATIONS_FILENAME),
|
|
419
|
-
img_dir=temp_image_dir,
|
|
420
|
-
mixture_weights=self.brain_state_set.mixture_weights,
|
|
421
|
-
n_classes=self.brain_state_set.n_classes,
|
|
422
|
-
hyperparameters=self.hyperparameters,
|
|
423
|
-
)
|
|
424
|
-
|
|
425
|
-
# calibrate the model
|
|
426
|
-
if self.calibrate_trained_model:
|
|
427
|
-
calibration_annotation_file = os.path.join(
|
|
428
|
-
temp_image_dir, CALIBRATION_ANNOTATION_FILENAME
|
|
429
|
-
)
|
|
430
|
-
calibration_dataloader = create_dataloader(
|
|
431
|
-
annotations_file=calibration_annotation_file,
|
|
432
|
-
img_dir=temp_image_dir,
|
|
433
|
-
hyperparameters=self.hyperparameters,
|
|
434
|
-
)
|
|
435
|
-
model = ModelWithTemperature(model)
|
|
436
|
-
print("Calibrating model")
|
|
437
|
-
model.set_temperature(calibration_dataloader)
|
|
438
|
-
|
|
439
|
-
# save model
|
|
440
|
-
save_model(
|
|
441
|
-
model=model,
|
|
442
|
-
filename=model_filename,
|
|
294
|
+
# Create service and run training
|
|
295
|
+
service = TrainingService(progress_callback=self.show_message)
|
|
296
|
+
result = service.train_model(
|
|
297
|
+
recordings=list(self.recording_manager),
|
|
443
298
|
epoch_length=self.epoch_length,
|
|
444
|
-
epochs_per_img=self.
|
|
445
|
-
model_type=self.model_type,
|
|
446
|
-
|
|
447
|
-
|
|
299
|
+
epochs_per_img=self.training.epochs_per_img,
|
|
300
|
+
model_type=self.training.model_type,
|
|
301
|
+
calibrate=self.training.calibrate,
|
|
302
|
+
calibration_fraction=calibration_fraction,
|
|
303
|
+
brain_state_set=self.config.brain_state_set,
|
|
304
|
+
emg_filter=self.config.emg_filter,
|
|
305
|
+
hyperparameters=self.config.hyperparameters,
|
|
306
|
+
model_filename=model_filename,
|
|
307
|
+
delete_images=self.training.delete_images,
|
|
448
308
|
)
|
|
449
309
|
|
|
450
|
-
#
|
|
451
|
-
|
|
452
|
-
print("Cleaning up training image folder")
|
|
453
|
-
shutil.rmtree(temp_image_dir)
|
|
454
|
-
|
|
455
|
-
self.show_message(f"Training complete. Saved model to {model_filename}")
|
|
456
|
-
print("Training complete.")
|
|
457
|
-
|
|
458
|
-
def update_image_deletion(self) -> None:
|
|
459
|
-
"""Update choice of whether to delete images after training"""
|
|
460
|
-
self.delete_training_images = self.ui.delete_image_box.isChecked()
|
|
310
|
+
# Display results
|
|
311
|
+
result.report_to(self.show_message)
|
|
461
312
|
|
|
462
313
|
def update_training_calibration(self) -> None:
|
|
463
314
|
"""Update choice of whether to calibrate model after training"""
|
|
464
|
-
self.
|
|
465
|
-
self.ui.calibration_spinbox.setEnabled(self.
|
|
466
|
-
|
|
467
|
-
def update_epochs_per_img(self, new_value) -> None:
|
|
468
|
-
"""Update number of epochs per image
|
|
469
|
-
|
|
470
|
-
:param new_value: new number of epochs per image
|
|
471
|
-
"""
|
|
472
|
-
self.training_epochs_per_img = new_value
|
|
473
|
-
|
|
474
|
-
def score_all(self) -> None:
|
|
475
|
-
"""Score all recordings using the classification model"""
|
|
476
|
-
# check basic inputs
|
|
477
|
-
if self.model is None:
|
|
478
|
-
self.ui.score_all_status.setText("missing classification model")
|
|
479
|
-
self.show_message("ERROR: no classification model file selected")
|
|
480
|
-
return
|
|
481
|
-
if self.min_bout_length < self.epoch_length:
|
|
482
|
-
self.ui.score_all_status.setText("invalid minimum bout length")
|
|
483
|
-
self.show_message("ERROR: minimum bout length must be >= epoch length")
|
|
484
|
-
return
|
|
485
|
-
if self.epoch_length != self.model_epoch_length:
|
|
486
|
-
self.ui.score_all_status.setText("invalid epoch length")
|
|
487
|
-
self.show_message(
|
|
488
|
-
(
|
|
489
|
-
"ERROR: model was trained with an epoch length of "
|
|
490
|
-
f"{self.model_epoch_length} seconds, but the current "
|
|
491
|
-
f"epoch length setting is {self.epoch_length} seconds."
|
|
492
|
-
)
|
|
493
|
-
)
|
|
494
|
-
return
|
|
315
|
+
self.training.calibrate = self.ui.calibrate_checkbox.isChecked()
|
|
316
|
+
self.ui.calibration_spinbox.setEnabled(self.training.calibrate)
|
|
495
317
|
|
|
318
|
+
def score_recordings(self) -> None:
|
|
319
|
+
"""Score all recordings using the classification model."""
|
|
496
320
|
self.ui.score_all_status.setText("running...")
|
|
497
321
|
self.ui.score_all_status.repaint()
|
|
498
322
|
QApplication.processEvents()
|
|
499
323
|
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
f"ERROR (recording {self.recordings[recording_index].name}): {error_message}"
|
|
511
|
-
)
|
|
512
|
-
return
|
|
513
|
-
if self.recordings[recording_index].calibration_file == "":
|
|
514
|
-
self.ui.score_all_status.setText(
|
|
515
|
-
f"error on recording {self.recordings[recording_index].name}"
|
|
516
|
-
)
|
|
517
|
-
self.show_message(
|
|
518
|
-
(
|
|
519
|
-
f"ERROR (recording {self.recordings[recording_index].name}): "
|
|
520
|
-
"no calibration file selected"
|
|
521
|
-
)
|
|
522
|
-
)
|
|
523
|
-
return
|
|
524
|
-
|
|
525
|
-
# score each recording
|
|
526
|
-
for recording_index in range(len(self.recordings)):
|
|
527
|
-
# load EEG, EMG
|
|
528
|
-
try:
|
|
529
|
-
eeg, emg = load_recording(
|
|
530
|
-
self.recordings[recording_index].recording_file
|
|
531
|
-
)
|
|
532
|
-
sampling_rate = self.recordings[recording_index].sampling_rate
|
|
533
|
-
|
|
534
|
-
eeg, emg, sampling_rate = resample_and_standardize(
|
|
535
|
-
eeg=eeg,
|
|
536
|
-
emg=emg,
|
|
537
|
-
sampling_rate=sampling_rate,
|
|
538
|
-
epoch_length=self.epoch_length,
|
|
539
|
-
)
|
|
540
|
-
except Exception:
|
|
541
|
-
self.show_message(
|
|
542
|
-
(
|
|
543
|
-
"ERROR: could not load recording "
|
|
544
|
-
f"{self.recordings[recording_index].name}."
|
|
545
|
-
"This recording will be skipped."
|
|
546
|
-
)
|
|
547
|
-
)
|
|
548
|
-
continue
|
|
549
|
-
|
|
550
|
-
# load labels
|
|
551
|
-
label_file = self.recordings[recording_index].label_file
|
|
552
|
-
if os.path.isfile(label_file):
|
|
553
|
-
try:
|
|
554
|
-
# ignore any existing confidence scores; they will all be overwritten
|
|
555
|
-
existing_labels, _ = load_labels(label_file)
|
|
556
|
-
except Exception:
|
|
557
|
-
self.show_message(
|
|
558
|
-
(
|
|
559
|
-
"ERROR: could not load existing labels for recording "
|
|
560
|
-
f"{self.recordings[recording_index].name}."
|
|
561
|
-
"This recording will be skipped."
|
|
562
|
-
)
|
|
563
|
-
)
|
|
564
|
-
continue
|
|
565
|
-
# only check the length
|
|
566
|
-
samples_per_epoch = sampling_rate * self.epoch_length
|
|
567
|
-
epochs_in_recording = round(eeg.size / samples_per_epoch)
|
|
568
|
-
if epochs_in_recording != existing_labels.size:
|
|
569
|
-
self.show_message(
|
|
570
|
-
(
|
|
571
|
-
"ERROR: existing labels for recording "
|
|
572
|
-
f"{self.recordings[recording_index].name} "
|
|
573
|
-
"do not match the recording length. "
|
|
574
|
-
"This recording will be skipped."
|
|
575
|
-
)
|
|
576
|
-
)
|
|
577
|
-
continue
|
|
578
|
-
else:
|
|
579
|
-
existing_labels = None
|
|
580
|
-
|
|
581
|
-
# load calibration data
|
|
582
|
-
if not os.path.isfile(self.recordings[recording_index].calibration_file):
|
|
583
|
-
self.show_message(
|
|
584
|
-
(
|
|
585
|
-
"ERROR: calibration file does not exist for recording "
|
|
586
|
-
f"{self.recordings[recording_index].name}. "
|
|
587
|
-
"This recording will be skipped."
|
|
588
|
-
)
|
|
589
|
-
)
|
|
590
|
-
continue
|
|
591
|
-
try:
|
|
592
|
-
(
|
|
593
|
-
mixture_means,
|
|
594
|
-
mixture_sds,
|
|
595
|
-
) = load_calibration_file(
|
|
596
|
-
self.recordings[recording_index].calibration_file
|
|
597
|
-
)
|
|
598
|
-
except Exception:
|
|
599
|
-
self.show_message(
|
|
600
|
-
(
|
|
601
|
-
"ERROR: could not load calibration file for recording "
|
|
602
|
-
f"{self.recordings[recording_index].name}. "
|
|
603
|
-
"This recording will be skipped."
|
|
604
|
-
)
|
|
605
|
-
)
|
|
606
|
-
continue
|
|
607
|
-
|
|
608
|
-
labels, confidence_scores = score_recording(
|
|
609
|
-
model=self.model,
|
|
610
|
-
eeg=eeg,
|
|
611
|
-
emg=emg,
|
|
612
|
-
mixture_means=mixture_means,
|
|
613
|
-
mixture_sds=mixture_sds,
|
|
614
|
-
sampling_rate=sampling_rate,
|
|
615
|
-
epoch_length=self.epoch_length,
|
|
616
|
-
epochs_per_img=self.model_epochs_per_img,
|
|
617
|
-
brain_state_set=self.brain_state_set,
|
|
618
|
-
emg_filter=self.emg_filter,
|
|
619
|
-
)
|
|
620
|
-
|
|
621
|
-
# overwrite as needed
|
|
622
|
-
if existing_labels is not None and self.only_overwrite_undefined:
|
|
623
|
-
labels[existing_labels != UNDEFINED_LABEL] = existing_labels[
|
|
624
|
-
existing_labels != UNDEFINED_LABEL
|
|
625
|
-
]
|
|
626
|
-
|
|
627
|
-
# enforce minimum bout length
|
|
628
|
-
labels = enforce_min_bout_length(
|
|
629
|
-
labels=labels,
|
|
630
|
-
epoch_length=self.epoch_length,
|
|
631
|
-
min_bout_length=self.min_bout_length,
|
|
632
|
-
)
|
|
633
|
-
|
|
634
|
-
# ignore confidence scores if desired
|
|
635
|
-
if not self.save_confidence_scores:
|
|
636
|
-
confidence_scores = None
|
|
637
|
-
|
|
638
|
-
# save results
|
|
639
|
-
save_labels(
|
|
640
|
-
labels=labels, filename=label_file, confidence_scores=confidence_scores
|
|
641
|
-
)
|
|
642
|
-
self.show_message(
|
|
643
|
-
(
|
|
644
|
-
"Saved labels for recording "
|
|
645
|
-
f"{self.recordings[recording_index].name} "
|
|
646
|
-
f"to {label_file}"
|
|
647
|
-
)
|
|
648
|
-
)
|
|
324
|
+
result = score_recording_list(
|
|
325
|
+
recordings=list(self.recording_manager),
|
|
326
|
+
loaded_model=self.loaded_model,
|
|
327
|
+
epoch_length=self.epoch_length,
|
|
328
|
+
only_overwrite_undefined=self.scoring.only_overwrite_undefined,
|
|
329
|
+
save_confidence_scores=self.scoring.save_confidence_scores,
|
|
330
|
+
min_bout_length=self.scoring.min_bout_length,
|
|
331
|
+
brain_state_set=self.config.brain_state_set,
|
|
332
|
+
emg_filter=self.config.emg_filter,
|
|
333
|
+
)
|
|
649
334
|
|
|
650
|
-
|
|
335
|
+
# Display results
|
|
336
|
+
result.report_to(self.show_message)
|
|
337
|
+
self.ui.score_all_status.setText("error" if not result.success else "")
|
|
651
338
|
|
|
652
339
|
def load_model(self, filename=None) -> None:
|
|
653
340
|
"""Load trained classification model from file
|
|
@@ -655,17 +342,10 @@ class AccuSleepWindow(QMainWindow):
|
|
|
655
342
|
:param filename: model filename, if it's known
|
|
656
343
|
"""
|
|
657
344
|
if filename is None:
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
file_dialog.setNameFilter("*" + MODEL_FILE_TYPE)
|
|
663
|
-
|
|
664
|
-
if file_dialog.exec():
|
|
665
|
-
selected_files = file_dialog.selectedFiles()
|
|
666
|
-
filename = selected_files[0]
|
|
667
|
-
filename = os.path.normpath(filename)
|
|
668
|
-
else:
|
|
345
|
+
filename = select_existing_file(
|
|
346
|
+
self, "Select classification model", "*" + MODEL_FILE_TYPE
|
|
347
|
+
)
|
|
348
|
+
if not filename:
|
|
669
349
|
return
|
|
670
350
|
|
|
671
351
|
if not os.path.isfile(filename):
|
|
@@ -683,6 +363,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
683
363
|
filename=filename
|
|
684
364
|
)
|
|
685
365
|
except Exception:
|
|
366
|
+
logger.exception("Failed to load %s", filename)
|
|
686
367
|
self.show_message(
|
|
687
368
|
(
|
|
688
369
|
"ERROR: could not load classification model. Check "
|
|
@@ -703,14 +384,14 @@ class AccuSleepWindow(QMainWindow):
|
|
|
703
384
|
)
|
|
704
385
|
return
|
|
705
386
|
|
|
706
|
-
self.model = model
|
|
707
|
-
self.
|
|
708
|
-
self.
|
|
387
|
+
self.loaded_model.model = model
|
|
388
|
+
self.loaded_model.epoch_length = epoch_length
|
|
389
|
+
self.loaded_model.epochs_per_img = epochs_per_img
|
|
709
390
|
|
|
710
391
|
# warn user if the model's expected epoch length or brain states
|
|
711
392
|
# don't match the current configuration
|
|
712
393
|
config_warnings = check_config_consistency(
|
|
713
|
-
current_brain_states=self.brain_state_set.to_output_dict()[
|
|
394
|
+
current_brain_states=self.config.brain_state_set.to_output_dict()[
|
|
714
395
|
BRAIN_STATES_KEY
|
|
715
396
|
],
|
|
716
397
|
model_brain_states=brain_states,
|
|
@@ -737,17 +418,21 @@ class AccuSleepWindow(QMainWindow):
|
|
|
737
418
|
:param status_widget: UI element on which to display error messages
|
|
738
419
|
:return: EEG data, EMG data, sampling rate, process completion
|
|
739
420
|
"""
|
|
740
|
-
error_message =
|
|
421
|
+
error_message = check_single_file_inputs(
|
|
422
|
+
self.recording_manager.current, self.epoch_length
|
|
423
|
+
)
|
|
741
424
|
if error_message:
|
|
742
425
|
status_widget.setText(error_message)
|
|
743
426
|
self.show_message(f"ERROR: {error_message}")
|
|
744
427
|
return None, None, None, False
|
|
745
428
|
|
|
746
429
|
try:
|
|
747
|
-
eeg, emg = load_recording(
|
|
748
|
-
self.recordings[self.recording_index].recording_file
|
|
749
|
-
)
|
|
430
|
+
eeg, emg = load_recording(self.recording_manager.current.recording_file)
|
|
750
431
|
except Exception:
|
|
432
|
+
logger.exception(
|
|
433
|
+
"Failed to load %s",
|
|
434
|
+
self.recording_manager.current.recording_file,
|
|
435
|
+
)
|
|
751
436
|
status_widget.setText("could not load recording")
|
|
752
437
|
self.show_message(
|
|
753
438
|
(
|
|
@@ -757,7 +442,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
757
442
|
)
|
|
758
443
|
return None, None, None, False
|
|
759
444
|
|
|
760
|
-
sampling_rate = self.
|
|
445
|
+
sampling_rate = self.recording_manager.current.sampling_rate
|
|
761
446
|
|
|
762
447
|
eeg, emg, sampling_rate = resample_and_standardize(
|
|
763
448
|
eeg=eeg,
|
|
@@ -769,135 +454,35 @@ class AccuSleepWindow(QMainWindow):
|
|
|
769
454
|
return eeg, emg, sampling_rate, True
|
|
770
455
|
|
|
771
456
|
def create_calibration_file(self) -> None:
|
|
772
|
-
"""Creates a calibration file
|
|
457
|
+
"""Creates a calibration file.
|
|
773
458
|
|
|
774
459
|
This loads a recording and its labels, checks that the labels are
|
|
775
460
|
all valid, creates the calibration file, and sets the
|
|
776
461
|
"calibration file" property of the current recording to be the
|
|
777
462
|
newly created file.
|
|
778
463
|
"""
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
self.ui.calibration_status
|
|
782
|
-
)
|
|
783
|
-
if not success:
|
|
784
|
-
return
|
|
785
|
-
|
|
786
|
-
# load the labels
|
|
787
|
-
label_file = self.recordings[self.recording_index].label_file
|
|
788
|
-
if not os.path.isfile(label_file):
|
|
789
|
-
self.ui.calibration_status.setText("label file does not exist")
|
|
790
|
-
self.show_message("ERROR: label file does not exist")
|
|
791
|
-
return
|
|
792
|
-
try:
|
|
793
|
-
labels, _ = load_labels(label_file)
|
|
794
|
-
except Exception:
|
|
795
|
-
self.ui.calibration_status.setText("could not load labels")
|
|
796
|
-
self.show_message(
|
|
797
|
-
(
|
|
798
|
-
"ERROR: could not load labels. "
|
|
799
|
-
"Check user manual for formatting instructions."
|
|
800
|
-
)
|
|
801
|
-
)
|
|
802
|
-
return
|
|
803
|
-
label_error_message = check_label_validity(
|
|
804
|
-
labels=labels,
|
|
805
|
-
confidence_scores=None,
|
|
806
|
-
samples_in_recording=eeg.size,
|
|
807
|
-
sampling_rate=sampling_rate,
|
|
808
|
-
epoch_length=self.epoch_length,
|
|
809
|
-
brain_state_set=self.brain_state_set,
|
|
810
|
-
)
|
|
811
|
-
if label_error_message:
|
|
812
|
-
self.ui.calibration_status.setText("invalid label file")
|
|
813
|
-
self.show_message(f"ERROR: {label_error_message}")
|
|
814
|
-
return
|
|
815
|
-
|
|
816
|
-
# get the name for the calibration file
|
|
817
|
-
filename, _ = QFileDialog.getSaveFileName(
|
|
818
|
-
self,
|
|
819
|
-
caption="Save calibration file as",
|
|
820
|
-
filter="*" + CALIBRATION_FILE_TYPE,
|
|
464
|
+
filename = select_save_location(
|
|
465
|
+
self, "Save calibration file as", "*" + CALIBRATION_FILE_TYPE
|
|
821
466
|
)
|
|
822
467
|
if not filename:
|
|
823
468
|
return
|
|
824
|
-
filename = os.path.normpath(filename)
|
|
825
469
|
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
create_calibration_file(
|
|
829
|
-
filename=filename,
|
|
830
|
-
eeg=eeg,
|
|
831
|
-
emg=emg,
|
|
832
|
-
labels=labels,
|
|
833
|
-
sampling_rate=sampling_rate,
|
|
470
|
+
result = create_calibration(
|
|
471
|
+
recording=self.recording_manager.current,
|
|
834
472
|
epoch_length=self.epoch_length,
|
|
835
|
-
brain_state_set=self.brain_state_set,
|
|
836
|
-
emg_filter=self.emg_filter,
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
self.ui.calibration_status.setText("")
|
|
840
|
-
self.show_message(
|
|
841
|
-
(
|
|
842
|
-
"Created calibration file using recording "
|
|
843
|
-
f"{self.recordings[self.recording_index].name} "
|
|
844
|
-
f"at {filename}"
|
|
845
|
-
)
|
|
473
|
+
brain_state_set=self.config.brain_state_set,
|
|
474
|
+
emg_filter=self.config.emg_filter,
|
|
475
|
+
output_filename=filename,
|
|
846
476
|
)
|
|
847
477
|
|
|
848
|
-
|
|
849
|
-
self.
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
error message.
|
|
857
|
-
|
|
858
|
-
:param recording_index: index of the recording in the list of
|
|
859
|
-
all recordings.
|
|
860
|
-
:return: error message
|
|
861
|
-
"""
|
|
862
|
-
sampling_rate = self.recordings[recording_index].sampling_rate
|
|
863
|
-
if self.epoch_length == 0:
|
|
864
|
-
return "epoch length can't be 0"
|
|
865
|
-
if sampling_rate == 0:
|
|
866
|
-
return "sampling rate can't be 0"
|
|
867
|
-
if self.epoch_length > sampling_rate:
|
|
868
|
-
return "invalid epoch length or sampling rate"
|
|
869
|
-
if self.recordings[self.recording_index].recording_file == "":
|
|
870
|
-
return "no recording selected"
|
|
871
|
-
if not os.path.isfile(self.recordings[self.recording_index].recording_file):
|
|
872
|
-
return "recording file does not exist"
|
|
873
|
-
if self.recordings[self.recording_index].label_file == "":
|
|
874
|
-
return "no label file selected"
|
|
875
|
-
|
|
876
|
-
def update_min_bout_length(self, new_value) -> None:
|
|
877
|
-
"""Update the minimum bout length
|
|
878
|
-
|
|
879
|
-
:param new_value: new minimum bout length, in seconds
|
|
880
|
-
"""
|
|
881
|
-
self.min_bout_length = new_value
|
|
882
|
-
|
|
883
|
-
def update_overwrite_policy(self, checked) -> None:
|
|
884
|
-
"""Toggle overwriting policy
|
|
885
|
-
|
|
886
|
-
If the checkbox is enabled, only epochs where the brain state is set to
|
|
887
|
-
undefined will be overwritten by the automatic scoring process.
|
|
888
|
-
|
|
889
|
-
:param checked: state of the checkbox
|
|
890
|
-
"""
|
|
891
|
-
self.only_overwrite_undefined = checked
|
|
892
|
-
|
|
893
|
-
def update_confidence_policy(self, checked) -> None:
|
|
894
|
-
"""Toggle policy for saving confidence scores
|
|
895
|
-
|
|
896
|
-
If the checkbox is enabled, confidence scores will be saved to the label files.
|
|
897
|
-
|
|
898
|
-
:param checked: state of the checkbox
|
|
899
|
-
"""
|
|
900
|
-
self.save_confidence_scores = checked
|
|
478
|
+
# Display results
|
|
479
|
+
result.report_to(self.show_message)
|
|
480
|
+
if not result.success:
|
|
481
|
+
self.ui.calibration_status.setText("error")
|
|
482
|
+
else:
|
|
483
|
+
self.ui.calibration_status.setText("")
|
|
484
|
+
self.recording_manager.current.calibration_file = filename
|
|
485
|
+
self.ui.calibration_file_label.setText(filename)
|
|
901
486
|
|
|
902
487
|
def manual_scoring(self) -> None:
|
|
903
488
|
"""View the selected recording for manual scoring"""
|
|
@@ -915,11 +500,12 @@ class AccuSleepWindow(QMainWindow):
|
|
|
915
500
|
|
|
916
501
|
# if the labels exist, load them
|
|
917
502
|
# otherwise, create a blank set of labels
|
|
918
|
-
label_file = self.
|
|
503
|
+
label_file = self.recording_manager.current.label_file
|
|
919
504
|
if os.path.isfile(label_file):
|
|
920
505
|
try:
|
|
921
506
|
labels, confidence_scores = load_labels(label_file)
|
|
922
507
|
except Exception:
|
|
508
|
+
logger.exception("Failed to load %s", label_file)
|
|
923
509
|
self.ui.manual_scoring_status.setText("could not load labels")
|
|
924
510
|
self.show_message(
|
|
925
511
|
(
|
|
@@ -937,56 +523,23 @@ class AccuSleepWindow(QMainWindow):
|
|
|
937
523
|
# to a label file that does not have one
|
|
938
524
|
confidence_scores = None
|
|
939
525
|
|
|
940
|
-
# check that
|
|
941
|
-
|
|
526
|
+
# check that labels are valid and correct minor length mismatches
|
|
527
|
+
labels, confidence_scores, validation_message = validate_and_correct_labels(
|
|
942
528
|
labels=labels,
|
|
943
529
|
confidence_scores=confidence_scores,
|
|
944
530
|
samples_in_recording=eeg.size,
|
|
945
531
|
sampling_rate=sampling_rate,
|
|
946
532
|
epoch_length=self.epoch_length,
|
|
947
|
-
brain_state_set=self.brain_state_set,
|
|
533
|
+
brain_state_set=self.config.brain_state_set,
|
|
948
534
|
)
|
|
949
|
-
if
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
epochs_in_recording = round(eeg.size / samples_per_epoch)
|
|
956
|
-
if epochs_in_recording - labels.size == 1:
|
|
957
|
-
labels = np.concatenate((labels, np.array([UNDEFINED_LABEL])))
|
|
958
|
-
if confidence_scores is not None:
|
|
959
|
-
confidence_scores = np.concatenate(
|
|
960
|
-
(confidence_scores, np.array([0]))
|
|
961
|
-
)
|
|
962
|
-
self.show_message(
|
|
963
|
-
(
|
|
964
|
-
"WARNING: an undefined epoch was added to "
|
|
965
|
-
"the label file to correct its length."
|
|
966
|
-
)
|
|
967
|
-
)
|
|
968
|
-
elif labels.size - epochs_in_recording == 1:
|
|
969
|
-
labels = labels[:-1]
|
|
970
|
-
if confidence_scores is not None:
|
|
971
|
-
confidence_scores = confidence_scores[:-1]
|
|
972
|
-
self.show_message(
|
|
973
|
-
(
|
|
974
|
-
"WARNING: the last epoch was removed from "
|
|
975
|
-
"the label file to correct its length."
|
|
976
|
-
)
|
|
977
|
-
)
|
|
978
|
-
else:
|
|
979
|
-
self.ui.manual_scoring_status.setText("invalid label file")
|
|
980
|
-
self.show_message(f"ERROR: {label_error}")
|
|
981
|
-
return
|
|
982
|
-
else:
|
|
983
|
-
self.ui.manual_scoring_status.setText("invalid label file")
|
|
984
|
-
self.show_message(f"ERROR: {label_error}")
|
|
985
|
-
return
|
|
535
|
+
if labels is None:
|
|
536
|
+
self.ui.manual_scoring_status.setText("invalid label file")
|
|
537
|
+
self.show_message(f"ERROR: {validation_message}")
|
|
538
|
+
return
|
|
539
|
+
if validation_message:
|
|
540
|
+
self.show_message(f"WARNING: {validation_message}")
|
|
986
541
|
|
|
987
|
-
self.show_message(
|
|
988
|
-
f"Viewing recording {self.recordings[self.recording_index].name}"
|
|
989
|
-
)
|
|
542
|
+
self.show_message(f"Viewing recording {self.recording_manager.current.name}")
|
|
990
543
|
self.ui.manual_scoring_status.setText("file is open")
|
|
991
544
|
|
|
992
545
|
# launch the manual scoring window
|
|
@@ -998,7 +551,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
998
551
|
confidence_scores=confidence_scores,
|
|
999
552
|
sampling_rate=sampling_rate,
|
|
1000
553
|
epoch_length=self.epoch_length,
|
|
1001
|
-
emg_filter=self.emg_filter,
|
|
554
|
+
emg_filter=self.config.emg_filter,
|
|
1002
555
|
)
|
|
1003
556
|
manual_scoring_window.setWindowTitle(f"AccuSleePy viewer: {label_file}")
|
|
1004
557
|
manual_scoring_window.exec()
|
|
@@ -1006,89 +559,48 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1006
559
|
|
|
1007
560
|
def create_label_file(self) -> None:
|
|
1008
561
|
"""Set the filename for a new label file"""
|
|
1009
|
-
filename
|
|
562
|
+
filename = select_save_location(
|
|
1010
563
|
self,
|
|
1011
|
-
|
|
1012
|
-
|
|
564
|
+
"Set filename for label file (nothing will be overwritten yet)",
|
|
565
|
+
"*" + LABEL_FILE_TYPE,
|
|
1013
566
|
)
|
|
1014
567
|
if filename:
|
|
1015
|
-
|
|
1016
|
-
self.recordings[self.recording_index].label_file = filename
|
|
568
|
+
self.recording_manager.current.label_file = filename
|
|
1017
569
|
self.ui.label_file_label.setText(filename)
|
|
1018
570
|
|
|
1019
571
|
def select_label_file(self) -> None:
|
|
1020
572
|
"""User can select an existing label file"""
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
if file_dialog.exec():
|
|
1028
|
-
selected_files = file_dialog.selectedFiles()
|
|
1029
|
-
filename = selected_files[0]
|
|
1030
|
-
filename = os.path.normpath(filename)
|
|
1031
|
-
self.recordings[self.recording_index].label_file = filename
|
|
573
|
+
filename = select_existing_file(
|
|
574
|
+
self, "Select label file", "*" + LABEL_FILE_TYPE
|
|
575
|
+
)
|
|
576
|
+
if filename:
|
|
577
|
+
self.recording_manager.current.label_file = filename
|
|
1032
578
|
self.ui.label_file_label.setText(filename)
|
|
1033
579
|
|
|
1034
580
|
def select_calibration_file(self) -> None:
|
|
1035
581
|
"""User can select a calibration file"""
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
if file_dialog.exec():
|
|
1043
|
-
selected_files = file_dialog.selectedFiles()
|
|
1044
|
-
filename = selected_files[0]
|
|
1045
|
-
filename = os.path.normpath(filename)
|
|
1046
|
-
self.recordings[self.recording_index].calibration_file = filename
|
|
582
|
+
filename = select_existing_file(
|
|
583
|
+
self, "Select calibration file", "*" + CALIBRATION_FILE_TYPE
|
|
584
|
+
)
|
|
585
|
+
if filename:
|
|
586
|
+
self.recording_manager.current.calibration_file = filename
|
|
1047
587
|
self.ui.calibration_file_label.setText(filename)
|
|
1048
588
|
|
|
1049
589
|
def select_recording_file(self) -> None:
|
|
1050
590
|
"""User can select a recording file"""
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
file_dialog.setNameFilter(f"(*{' *'.join(RECORDING_FILE_TYPES)})")
|
|
1056
|
-
|
|
1057
|
-
if file_dialog.exec():
|
|
1058
|
-
selected_files = file_dialog.selectedFiles()
|
|
1059
|
-
filename = selected_files[0]
|
|
1060
|
-
filename = os.path.normpath(filename)
|
|
1061
|
-
self.recordings[self.recording_index].recording_file = filename
|
|
591
|
+
file_filter = f"(*{' *'.join(RECORDING_FILE_TYPES)})"
|
|
592
|
+
filename = select_existing_file(self, "Select recording file", file_filter)
|
|
593
|
+
if filename:
|
|
594
|
+
self.recording_manager.current.recording_file = filename
|
|
1062
595
|
self.ui.recording_file_label.setText(filename)
|
|
1063
596
|
|
|
1064
597
|
def show_recording_info(self) -> None:
|
|
1065
598
|
"""Update the UI to show info for the selected recording"""
|
|
1066
|
-
self.
|
|
1067
|
-
|
|
1068
|
-
)
|
|
1069
|
-
self.ui.
|
|
1070
|
-
|
|
1071
|
-
)
|
|
1072
|
-
self.ui.label_file_label.setText(
|
|
1073
|
-
self.recordings[self.recording_index].label_file
|
|
1074
|
-
)
|
|
1075
|
-
self.ui.calibration_file_label.setText(
|
|
1076
|
-
self.recordings[self.recording_index].calibration_file
|
|
1077
|
-
)
|
|
1078
|
-
|
|
1079
|
-
def update_epoch_length(self, new_value: int | float) -> None:
|
|
1080
|
-
"""Update the epoch length when the widget state changes
|
|
1081
|
-
|
|
1082
|
-
:param new_value: new epoch length
|
|
1083
|
-
"""
|
|
1084
|
-
self.epoch_length = new_value
|
|
1085
|
-
|
|
1086
|
-
def update_sampling_rate(self, new_value: int | float) -> None:
|
|
1087
|
-
"""Update recording's sampling rate when the widget state changes
|
|
1088
|
-
|
|
1089
|
-
:param new_value: new sampling rate
|
|
1090
|
-
"""
|
|
1091
|
-
self.recordings[self.recording_index].sampling_rate = new_value
|
|
599
|
+
recording = self.recording_manager.current
|
|
600
|
+
self.ui.sampling_rate_input.setValue(recording.sampling_rate)
|
|
601
|
+
self.ui.recording_file_label.setText(recording.recording_file)
|
|
602
|
+
self.ui.label_file_label.setText(recording.label_file)
|
|
603
|
+
self.ui.calibration_file_label.setText(recording.calibration_file)
|
|
1092
604
|
|
|
1093
605
|
def show_message(self, message: str) -> None:
|
|
1094
606
|
"""Display a new message to the user
|
|
@@ -1103,50 +615,22 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1103
615
|
scrollbar = self.ui.message_area.verticalScrollBar()
|
|
1104
616
|
scrollbar.setValue(scrollbar.maximum())
|
|
1105
617
|
|
|
1106
|
-
def select_recording(self,
|
|
1107
|
-
"""Callback for when a recording is selected
|
|
1108
|
-
|
|
1109
|
-
:param list_index: index of this recording in the list widget
|
|
1110
|
-
"""
|
|
1111
|
-
# get index of this recording
|
|
1112
|
-
self.recording_index = list_index
|
|
1113
|
-
# display information about this recording
|
|
618
|
+
def select_recording(self, _index: int) -> None:
|
|
619
|
+
"""Callback for when a recording is selected"""
|
|
1114
620
|
self.show_recording_info()
|
|
1115
621
|
self.ui.selected_recording_groupbox.setTitle(
|
|
1116
|
-
f"Data / actions for Recording {self.
|
|
622
|
+
f"Data / actions for Recording {self.recording_manager.current.name}"
|
|
1117
623
|
)
|
|
1118
624
|
|
|
1119
625
|
def add_recording(self) -> None:
|
|
1120
626
|
"""Add new recording to the list"""
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
# add new recording to list
|
|
1125
|
-
self.recordings.append(
|
|
1126
|
-
Recording(
|
|
1127
|
-
name=new_name,
|
|
1128
|
-
sampling_rate=self.recordings[self.recording_index].sampling_rate,
|
|
1129
|
-
widget=QListWidgetItem(
|
|
1130
|
-
f"Recording {new_name}", self.ui.recording_list_widget
|
|
1131
|
-
),
|
|
1132
|
-
)
|
|
1133
|
-
)
|
|
1134
|
-
|
|
1135
|
-
# display new list
|
|
1136
|
-
self.ui.recording_list_widget.addItem(self.recordings[-1].widget)
|
|
1137
|
-
self.ui.recording_list_widget.setCurrentRow(len(self.recordings) - 1)
|
|
1138
|
-
self.show_message(f"added Recording {new_name}")
|
|
627
|
+
current_sampling_rate = self.recording_manager.current.sampling_rate
|
|
628
|
+
recording = self.recording_manager.add(sampling_rate=current_sampling_rate)
|
|
629
|
+
self.show_message(f"added Recording {recording.name}")
|
|
1139
630
|
|
|
1140
631
|
def remove_recording(self) -> None:
|
|
1141
632
|
"""Delete selected recording from the list"""
|
|
1142
|
-
|
|
1143
|
-
current_list_index = self.ui.recording_list_widget.currentRow()
|
|
1144
|
-
_ = self.ui.recording_list_widget.takeItem(current_list_index)
|
|
1145
|
-
self.show_message(
|
|
1146
|
-
f"deleted Recording {self.recordings[current_list_index].name}"
|
|
1147
|
-
)
|
|
1148
|
-
del self.recordings[current_list_index]
|
|
1149
|
-
self.recording_index = self.ui.recording_list_widget.currentRow()
|
|
633
|
+
self.show_message(self.recording_manager.remove_current())
|
|
1150
634
|
|
|
1151
635
|
def show_user_manual(self) -> None:
|
|
1152
636
|
"""Show a popup window with the user manual"""
|
|
@@ -1162,310 +646,12 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1162
646
|
self.popup.setGeometry(QRect(100, 100, 600, 600))
|
|
1163
647
|
self.popup.show()
|
|
1164
648
|
|
|
1165
|
-
def initialize_settings_tab(self):
|
|
1166
|
-
"""Populate settings tab and assign its callbacks"""
|
|
1167
|
-
# store dictionary that maps digits to rows of widgets
|
|
1168
|
-
# in the settings tab
|
|
1169
|
-
self.settings_widgets = {
|
|
1170
|
-
1: StateSettings(
|
|
1171
|
-
digit=1,
|
|
1172
|
-
enabled_widget=self.ui.enable_state_1,
|
|
1173
|
-
name_widget=self.ui.state_name_1,
|
|
1174
|
-
is_scored_widget=self.ui.state_scored_1,
|
|
1175
|
-
frequency_widget=self.ui.state_frequency_1,
|
|
1176
|
-
),
|
|
1177
|
-
2: StateSettings(
|
|
1178
|
-
digit=2,
|
|
1179
|
-
enabled_widget=self.ui.enable_state_2,
|
|
1180
|
-
name_widget=self.ui.state_name_2,
|
|
1181
|
-
is_scored_widget=self.ui.state_scored_2,
|
|
1182
|
-
frequency_widget=self.ui.state_frequency_2,
|
|
1183
|
-
),
|
|
1184
|
-
3: StateSettings(
|
|
1185
|
-
digit=3,
|
|
1186
|
-
enabled_widget=self.ui.enable_state_3,
|
|
1187
|
-
name_widget=self.ui.state_name_3,
|
|
1188
|
-
is_scored_widget=self.ui.state_scored_3,
|
|
1189
|
-
frequency_widget=self.ui.state_frequency_3,
|
|
1190
|
-
),
|
|
1191
|
-
4: StateSettings(
|
|
1192
|
-
digit=4,
|
|
1193
|
-
enabled_widget=self.ui.enable_state_4,
|
|
1194
|
-
name_widget=self.ui.state_name_4,
|
|
1195
|
-
is_scored_widget=self.ui.state_scored_4,
|
|
1196
|
-
frequency_widget=self.ui.state_frequency_4,
|
|
1197
|
-
),
|
|
1198
|
-
5: StateSettings(
|
|
1199
|
-
digit=5,
|
|
1200
|
-
enabled_widget=self.ui.enable_state_5,
|
|
1201
|
-
name_widget=self.ui.state_name_5,
|
|
1202
|
-
is_scored_widget=self.ui.state_scored_5,
|
|
1203
|
-
frequency_widget=self.ui.state_frequency_5,
|
|
1204
|
-
),
|
|
1205
|
-
6: StateSettings(
|
|
1206
|
-
digit=6,
|
|
1207
|
-
enabled_widget=self.ui.enable_state_6,
|
|
1208
|
-
name_widget=self.ui.state_name_6,
|
|
1209
|
-
is_scored_widget=self.ui.state_scored_6,
|
|
1210
|
-
frequency_widget=self.ui.state_frequency_6,
|
|
1211
|
-
),
|
|
1212
|
-
7: StateSettings(
|
|
1213
|
-
digit=7,
|
|
1214
|
-
enabled_widget=self.ui.enable_state_7,
|
|
1215
|
-
name_widget=self.ui.state_name_7,
|
|
1216
|
-
is_scored_widget=self.ui.state_scored_7,
|
|
1217
|
-
frequency_widget=self.ui.state_frequency_7,
|
|
1218
|
-
),
|
|
1219
|
-
8: StateSettings(
|
|
1220
|
-
digit=8,
|
|
1221
|
-
enabled_widget=self.ui.enable_state_8,
|
|
1222
|
-
name_widget=self.ui.state_name_8,
|
|
1223
|
-
is_scored_widget=self.ui.state_scored_8,
|
|
1224
|
-
frequency_widget=self.ui.state_frequency_8,
|
|
1225
|
-
),
|
|
1226
|
-
9: StateSettings(
|
|
1227
|
-
digit=9,
|
|
1228
|
-
enabled_widget=self.ui.enable_state_9,
|
|
1229
|
-
name_widget=self.ui.state_name_9,
|
|
1230
|
-
is_scored_widget=self.ui.state_scored_9,
|
|
1231
|
-
frequency_widget=self.ui.state_frequency_9,
|
|
1232
|
-
),
|
|
1233
|
-
0: StateSettings(
|
|
1234
|
-
digit=0,
|
|
1235
|
-
enabled_widget=self.ui.enable_state_0,
|
|
1236
|
-
name_widget=self.ui.state_name_0,
|
|
1237
|
-
is_scored_widget=self.ui.state_scored_0,
|
|
1238
|
-
frequency_widget=self.ui.state_frequency_0,
|
|
1239
|
-
),
|
|
1240
|
-
}
|
|
1241
|
-
|
|
1242
|
-
# update widget state to display current config
|
|
1243
|
-
# UI defaults
|
|
1244
|
-
self.ui.default_epoch_input.setValue(self.epoch_length)
|
|
1245
|
-
self.ui.overwrite_default_checkbox.setChecked(self.only_overwrite_undefined)
|
|
1246
|
-
self.ui.confidence_setting_checkbox.setChecked(self.save_confidence_scores)
|
|
1247
|
-
self.ui.default_min_bout_length_spinbox.setValue(self.min_bout_length)
|
|
1248
|
-
self.ui.epochs_to_show_spinbox.setValue(self.default_epochs_to_show)
|
|
1249
|
-
self.ui.autoscroll_checkbox.setChecked(self.default_autoscroll_state)
|
|
1250
|
-
# EMG filter
|
|
1251
|
-
self.ui.emg_order_spinbox.setValue(self.emg_filter.order)
|
|
1252
|
-
self.ui.bp_lower_spinbox.setValue(self.emg_filter.bp_lower)
|
|
1253
|
-
self.ui.bp_upper_spinbox.setValue(self.emg_filter.bp_upper)
|
|
1254
|
-
# model training hyperparameters
|
|
1255
|
-
self.ui.batch_size_spinbox.setValue(self.hyperparameters.batch_size)
|
|
1256
|
-
self.ui.learning_rate_spinbox.setValue(self.hyperparameters.learning_rate)
|
|
1257
|
-
self.ui.momentum_spinbox.setValue(self.hyperparameters.momentum)
|
|
1258
|
-
self.ui.training_epochs_spinbox.setValue(self.hyperparameters.training_epochs)
|
|
1259
|
-
# brain states
|
|
1260
|
-
states = {b.digit: b for b in self.brain_state_set.brain_states}
|
|
1261
|
-
for digit in range(10):
|
|
1262
|
-
if digit in states.keys():
|
|
1263
|
-
self.settings_widgets[digit].enabled_widget.setChecked(True)
|
|
1264
|
-
self.settings_widgets[digit].name_widget.setText(states[digit].name)
|
|
1265
|
-
self.settings_widgets[digit].is_scored_widget.setChecked(
|
|
1266
|
-
states[digit].is_scored
|
|
1267
|
-
)
|
|
1268
|
-
self.settings_widgets[digit].frequency_widget.setValue(
|
|
1269
|
-
states[digit].frequency
|
|
1270
|
-
)
|
|
1271
|
-
else:
|
|
1272
|
-
self.settings_widgets[digit].enabled_widget.setChecked(False)
|
|
1273
|
-
self.settings_widgets[digit].name_widget.setEnabled(False)
|
|
1274
|
-
self.settings_widgets[digit].is_scored_widget.setEnabled(False)
|
|
1275
|
-
self.settings_widgets[digit].frequency_widget.setEnabled(False)
|
|
1276
|
-
|
|
1277
|
-
# set callbacks
|
|
1278
|
-
self.ui.emg_order_spinbox.valueChanged.connect(self.emg_filter_order_changed)
|
|
1279
|
-
self.ui.bp_lower_spinbox.valueChanged.connect(self.emg_filter_bp_lower_changed)
|
|
1280
|
-
self.ui.bp_upper_spinbox.valueChanged.connect(self.emg_filter_bp_upper_changed)
|
|
1281
|
-
self.ui.batch_size_spinbox.valueChanged.connect(self.hyperparameters_changed)
|
|
1282
|
-
self.ui.learning_rate_spinbox.valueChanged.connect(self.hyperparameters_changed)
|
|
1283
|
-
self.ui.momentum_spinbox.valueChanged.connect(self.hyperparameters_changed)
|
|
1284
|
-
self.ui.training_epochs_spinbox.valueChanged.connect(
|
|
1285
|
-
self.hyperparameters_changed
|
|
1286
|
-
)
|
|
1287
|
-
for digit in range(10):
|
|
1288
|
-
state = self.settings_widgets[digit]
|
|
1289
|
-
state.enabled_widget.stateChanged.connect(
|
|
1290
|
-
partial(self.set_brain_state_enabled, digit)
|
|
1291
|
-
)
|
|
1292
|
-
state.name_widget.editingFinished.connect(self.check_config_validity)
|
|
1293
|
-
state.is_scored_widget.stateChanged.connect(
|
|
1294
|
-
partial(self.is_scored_changed, digit)
|
|
1295
|
-
)
|
|
1296
|
-
state.frequency_widget.valueChanged.connect(self.check_config_validity)
|
|
1297
|
-
|
|
1298
|
-
def set_brain_state_enabled(self, digit, e) -> None:
|
|
1299
|
-
"""Called when user clicks "enabled" checkbox
|
|
1300
|
-
|
|
1301
|
-
:param digit: brain state digit
|
|
1302
|
-
:param e: unused but mandatory
|
|
1303
|
-
"""
|
|
1304
|
-
# get the widgets for this brain state
|
|
1305
|
-
state = self.settings_widgets[digit]
|
|
1306
|
-
# update state of these widgets
|
|
1307
|
-
is_checked = state.enabled_widget.isChecked()
|
|
1308
|
-
for widget in [
|
|
1309
|
-
state.name_widget,
|
|
1310
|
-
state.is_scored_widget,
|
|
1311
|
-
]:
|
|
1312
|
-
widget.setEnabled(is_checked)
|
|
1313
|
-
state.frequency_widget.setEnabled(
|
|
1314
|
-
is_checked and state.is_scored_widget.isChecked()
|
|
1315
|
-
)
|
|
1316
|
-
if not is_checked:
|
|
1317
|
-
state.name_widget.setText("")
|
|
1318
|
-
state.frequency_widget.setValue(0)
|
|
1319
|
-
# check that configuration is valid
|
|
1320
|
-
_ = self.check_config_validity()
|
|
1321
|
-
|
|
1322
|
-
def is_scored_changed(self, digit, e) -> None:
|
|
1323
|
-
"""Called when user sets whether a state is scored
|
|
1324
|
-
|
|
1325
|
-
:param digit: brain state digit
|
|
1326
|
-
:param e: unused, but mandatory
|
|
1327
|
-
"""
|
|
1328
|
-
# get the widgets for this brain state
|
|
1329
|
-
state = self.settings_widgets[digit]
|
|
1330
|
-
# update the state of these widgets
|
|
1331
|
-
is_checked = state.is_scored_widget.isChecked()
|
|
1332
|
-
state.frequency_widget.setEnabled(is_checked)
|
|
1333
|
-
if not is_checked:
|
|
1334
|
-
state.frequency_widget.setValue(0)
|
|
1335
|
-
# check that configuration is valid
|
|
1336
|
-
_ = self.check_config_validity()
|
|
1337
|
-
|
|
1338
|
-
def emg_filter_order_changed(self, new_value: int) -> None:
|
|
1339
|
-
"""Called when user modifies EMG filter order
|
|
1340
|
-
|
|
1341
|
-
:param new_value: new EMG filter order
|
|
1342
|
-
"""
|
|
1343
|
-
self.emg_filter.order = new_value
|
|
1344
|
-
|
|
1345
|
-
def emg_filter_bp_lower_changed(self, new_value: int | float) -> None:
|
|
1346
|
-
"""Called when user modifies EMG filter lower cutoff
|
|
1347
|
-
|
|
1348
|
-
:param new_value: new lower bandpass cutoff frequency
|
|
1349
|
-
"""
|
|
1350
|
-
self.emg_filter.bp_lower = new_value
|
|
1351
|
-
_ = self.check_config_validity()
|
|
1352
|
-
|
|
1353
|
-
def emg_filter_bp_upper_changed(self, new_value: int | float) -> None:
|
|
1354
|
-
"""Called when user modifies EMG filter upper cutoff
|
|
1355
|
-
|
|
1356
|
-
:param new_value: new upper bandpass cutoff frequency
|
|
1357
|
-
"""
|
|
1358
|
-
self.emg_filter.bp_upper = new_value
|
|
1359
|
-
_ = self.check_config_validity()
|
|
1360
|
-
|
|
1361
|
-
def hyperparameters_changed(self, new_value) -> None:
|
|
1362
|
-
"""Called when user modifies model training hyperparameters
|
|
1363
|
-
|
|
1364
|
-
:param new_value: unused
|
|
1365
|
-
"""
|
|
1366
|
-
self.hyperparameters = Hyperparameters(
|
|
1367
|
-
batch_size=self.ui.batch_size_spinbox.value(),
|
|
1368
|
-
learning_rate=self.ui.learning_rate_spinbox.value(),
|
|
1369
|
-
momentum=self.ui.momentum_spinbox.value(),
|
|
1370
|
-
training_epochs=self.ui.training_epochs_spinbox.value(),
|
|
1371
|
-
)
|
|
1372
|
-
|
|
1373
|
-
def check_config_validity(self) -> str:
|
|
1374
|
-
"""Check if brain state configuration on screen is valid"""
|
|
1375
|
-
# error message, if we get one
|
|
1376
|
-
message = None
|
|
1377
|
-
|
|
1378
|
-
# strip whitespace from brain state names and update display
|
|
1379
|
-
for digit in range(10):
|
|
1380
|
-
state = self.settings_widgets[digit]
|
|
1381
|
-
current_name = state.name_widget.text()
|
|
1382
|
-
formatted_name = current_name.strip()
|
|
1383
|
-
if current_name != formatted_name:
|
|
1384
|
-
state.name_widget.setText(formatted_name)
|
|
1385
|
-
|
|
1386
|
-
# check if names are unique and frequencies add up to 1
|
|
1387
|
-
names = []
|
|
1388
|
-
frequencies = []
|
|
1389
|
-
for digit in range(10):
|
|
1390
|
-
state = self.settings_widgets[digit]
|
|
1391
|
-
if state.enabled_widget.isChecked():
|
|
1392
|
-
names.append(state.name_widget.text())
|
|
1393
|
-
frequencies.append(state.frequency_widget.value())
|
|
1394
|
-
if len(names) != len(set(names)):
|
|
1395
|
-
message = "Error: names must be unique"
|
|
1396
|
-
if sum(frequencies) != 1:
|
|
1397
|
-
message = "Error: sum(frequencies) != 1"
|
|
1398
|
-
|
|
1399
|
-
# check validity of EMG filter settings
|
|
1400
|
-
if self.emg_filter.bp_lower >= self.emg_filter.bp_upper:
|
|
1401
|
-
message = "Error: EMG filter cutoff frequencies are invalid"
|
|
1402
|
-
|
|
1403
|
-
if message is not None:
|
|
1404
|
-
self.ui.save_config_status.setText(message)
|
|
1405
|
-
self.ui.save_config_button.setEnabled(False)
|
|
1406
|
-
return message
|
|
1407
|
-
|
|
1408
|
-
self.ui.save_config_button.setEnabled(True)
|
|
1409
|
-
self.ui.save_config_status.setText("")
|
|
1410
|
-
|
|
1411
|
-
def save_brain_state_config(self):
|
|
1412
|
-
"""Save configuration to file"""
|
|
1413
|
-
# check that configuration is valid
|
|
1414
|
-
error_message = self.check_config_validity()
|
|
1415
|
-
if error_message is not None:
|
|
1416
|
-
return
|
|
1417
|
-
|
|
1418
|
-
# build a BrainStateMapper object from the current configuration
|
|
1419
|
-
brain_states = list()
|
|
1420
|
-
for digit in range(10):
|
|
1421
|
-
state = self.settings_widgets[digit]
|
|
1422
|
-
if state.enabled_widget.isChecked():
|
|
1423
|
-
brain_states.append(
|
|
1424
|
-
BrainState(
|
|
1425
|
-
name=state.name_widget.text(),
|
|
1426
|
-
digit=digit,
|
|
1427
|
-
is_scored=state.is_scored_widget.isChecked(),
|
|
1428
|
-
frequency=state.frequency_widget.value(),
|
|
1429
|
-
)
|
|
1430
|
-
)
|
|
1431
|
-
self.brain_state_set = BrainStateSet(brain_states, UNDEFINED_LABEL)
|
|
1432
|
-
|
|
1433
|
-
# save to file
|
|
1434
|
-
save_config(
|
|
1435
|
-
brain_state_set=self.brain_state_set,
|
|
1436
|
-
default_epoch_length=self.ui.default_epoch_input.value(),
|
|
1437
|
-
overwrite_setting=self.ui.overwrite_default_checkbox.isChecked(),
|
|
1438
|
-
save_confidence_setting=self.ui.confidence_setting_checkbox.isChecked(),
|
|
1439
|
-
min_bout_length=self.ui.default_min_bout_length_spinbox.value(),
|
|
1440
|
-
emg_filter=EMGFilter(
|
|
1441
|
-
order=self.emg_filter.order,
|
|
1442
|
-
bp_lower=self.emg_filter.bp_lower,
|
|
1443
|
-
bp_upper=self.emg_filter.bp_upper,
|
|
1444
|
-
),
|
|
1445
|
-
hyperparameters=Hyperparameters(
|
|
1446
|
-
batch_size=self.hyperparameters.batch_size,
|
|
1447
|
-
learning_rate=self.hyperparameters.learning_rate,
|
|
1448
|
-
momentum=self.hyperparameters.momentum,
|
|
1449
|
-
training_epochs=self.hyperparameters.training_epochs,
|
|
1450
|
-
),
|
|
1451
|
-
epochs_to_show=self.ui.epochs_to_show_spinbox.value(),
|
|
1452
|
-
autoscroll_state=self.ui.autoscroll_checkbox.isChecked(),
|
|
1453
|
-
)
|
|
1454
|
-
self.ui.save_config_status.setText("configuration saved")
|
|
1455
|
-
|
|
1456
|
-
def reset_emg_filter_settings(self) -> None:
|
|
1457
|
-
self.ui.emg_order_spinbox.setValue(DEFAULT_EMG_FILTER_ORDER)
|
|
1458
|
-
self.ui.bp_lower_spinbox.setValue(DEFAULT_EMG_BP_LOWER)
|
|
1459
|
-
self.ui.bp_upper_spinbox.setValue(DEFAULT_EMG_BP_UPPER)
|
|
1460
|
-
|
|
1461
|
-
def reset_hyperparams_settings(self):
|
|
1462
|
-
self.ui.batch_size_spinbox.setValue(DEFAULT_BATCH_SIZE)
|
|
1463
|
-
self.ui.learning_rate_spinbox.setValue(DEFAULT_LEARNING_RATE)
|
|
1464
|
-
self.ui.momentum_spinbox.setValue(DEFAULT_MOMENTUM)
|
|
1465
|
-
self.ui.training_epochs_spinbox.setValue(DEFAULT_TRAINING_EPOCHS)
|
|
1466
|
-
|
|
1467
649
|
|
|
1468
650
|
def run_primary_window() -> None:
|
|
651
|
+
logging.basicConfig(
|
|
652
|
+
level=logging.INFO,
|
|
653
|
+
format="%(levelname)s - %(name)s - %(message)s",
|
|
654
|
+
)
|
|
1469
655
|
app = QApplication(sys.argv)
|
|
1470
656
|
AccuSleepWindow()
|
|
1471
657
|
sys.exit(app.exec())
|