accusleepy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accusleepy/__init__.py +0 -0
- accusleepy/__main__.py +4 -0
- accusleepy/brain_state_set.py +89 -0
- accusleepy/classification.py +267 -0
- accusleepy/config.json +22 -0
- accusleepy/constants.py +37 -0
- accusleepy/fileio.py +201 -0
- accusleepy/gui/__init__.py +0 -0
- accusleepy/gui/icons/brightness_down.png +0 -0
- accusleepy/gui/icons/brightness_up.png +0 -0
- accusleepy/gui/icons/double_down_arrow.png +0 -0
- accusleepy/gui/icons/double_up_arrow.png +0 -0
- accusleepy/gui/icons/down_arrow.png +0 -0
- accusleepy/gui/icons/home.png +0 -0
- accusleepy/gui/icons/question.png +0 -0
- accusleepy/gui/icons/save.png +0 -0
- accusleepy/gui/icons/up_arrow.png +0 -0
- accusleepy/gui/icons/zoom_in.png +0 -0
- accusleepy/gui/icons/zoom_out.png +0 -0
- accusleepy/gui/main.py +1372 -0
- accusleepy/gui/manual_scoring.py +1086 -0
- accusleepy/gui/mplwidget.py +356 -0
- accusleepy/gui/primary_window.py +2330 -0
- accusleepy/gui/primary_window.ui +3432 -0
- accusleepy/gui/resources.qrc +16 -0
- accusleepy/gui/resources_rc.py +6710 -0
- accusleepy/gui/text/config_guide.txt +24 -0
- accusleepy/gui/text/main_guide.txt +142 -0
- accusleepy/gui/text/manual_scoring_guide.txt +28 -0
- accusleepy/gui/viewer_window.py +598 -0
- accusleepy/gui/viewer_window.ui +894 -0
- accusleepy/models.py +48 -0
- accusleepy/multitaper.py +659 -0
- accusleepy/signal_processing.py +589 -0
- accusleepy-0.1.0.dist-info/METADATA +57 -0
- accusleepy-0.1.0.dist-info/RECORD +37 -0
- accusleepy-0.1.0.dist-info/WHEEL +4 -0
accusleepy/gui/main.py
ADDED
|
@@ -0,0 +1,1372 @@
|
|
|
1
|
+
# AccuSleePy main window
|
|
2
|
+
# Icon source: Arkinasi, https://www.flaticon.com/authors/arkinasi
|
|
3
|
+
|
|
4
|
+
import datetime
|
|
5
|
+
import os
|
|
6
|
+
import shutil
|
|
7
|
+
import sys
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from functools import partial
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
from PySide6 import QtCore, QtGui, QtWidgets
|
|
13
|
+
|
|
14
|
+
from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
|
|
15
|
+
from accusleepy.classification import (
|
|
16
|
+
create_calibration_file,
|
|
17
|
+
score_recording,
|
|
18
|
+
train_model,
|
|
19
|
+
)
|
|
20
|
+
from accusleepy.constants import (
|
|
21
|
+
CALIBRATION_FILE_TYPE,
|
|
22
|
+
DEFAULT_MODEL_TYPE,
|
|
23
|
+
LABEL_FILE_TYPE,
|
|
24
|
+
MODEL_FILE_TYPE,
|
|
25
|
+
REAL_TIME_MODEL_TYPE,
|
|
26
|
+
RECORDING_FILE_TYPES,
|
|
27
|
+
RECORDING_LIST_FILE_TYPE,
|
|
28
|
+
UNDEFINED_LABEL,
|
|
29
|
+
)
|
|
30
|
+
from accusleepy.fileio import (
|
|
31
|
+
Recording,
|
|
32
|
+
load_calibration_file,
|
|
33
|
+
load_config,
|
|
34
|
+
load_labels,
|
|
35
|
+
load_model,
|
|
36
|
+
load_recording,
|
|
37
|
+
load_recording_list,
|
|
38
|
+
save_config,
|
|
39
|
+
save_labels,
|
|
40
|
+
save_model,
|
|
41
|
+
save_recording_list,
|
|
42
|
+
)
|
|
43
|
+
from accusleepy.gui.manual_scoring import ManualScoringWindow
|
|
44
|
+
from accusleepy.gui.primary_window import Ui_PrimaryWindow
|
|
45
|
+
from accusleepy.signal_processing import (
|
|
46
|
+
ANNOTATIONS_FILENAME,
|
|
47
|
+
create_training_images,
|
|
48
|
+
enforce_min_bout_length,
|
|
49
|
+
resample_and_standardize,
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
# upper bound on the number of messages to display
MESSAGE_BOX_MAX_DEPTH = 50
# message text for a label file whose length disagrees with its recording
LABEL_LENGTH_ERROR = "label file length does not match recording length"
# relative paths of the bundled help text files (user manual, config guide)
USER_MANUAL_FILE, CONFIG_GUIDE_FILE = (
    "text/main_guide.txt",
    "text/config_guide.txt",
)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass
class StateSettings:
    """Settings-tab widgets that together configure a single brain state.

    Field order defines the generated constructor's positional arguments.
    """

    # digit identifying the brain state (presumably its number key - confirm)
    digit: int
    # checkbox: whether this brain state is enabled
    enabled_widget: QtWidgets.QCheckBox
    # label holding the brain state's name
    name_widget: QtWidgets.QLabel
    # checkbox: whether this brain state is scored
    is_scored_widget: QtWidgets.QCheckBox
    # spin box for the state's frequency setting
    # (presumably its expected relative frequency - confirm)
    frequency_widget: QtWidgets.QDoubleSpinBox
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
72
|
+
"""AccuSleePy primary window"""
|
|
73
|
+
|
|
74
|
+
def __init__(self):
    super().__init__()

    # build the widgets generated from the Qt Designer layout
    self.ui = Ui_PrimaryWindow()
    self.ui.setupUi(self)
    self.setWindowTitle("AccuSleePy")

    # the brain state configuration populates the settings tab
    self.brain_state_set = load_config()
    self.settings_widgets = None
    self.initialize_settings_tab()

    # recording and classification settings
    self.epoch_length = 0
    self.model = None
    self.only_overwrite_undefined = False
    self.min_bout_length = 5

    # model training settings
    self.training_epochs_per_img = 9
    self.delete_training_images = True
    self.training_image_dir = ""
    self.model_type = DEFAULT_MODEL_TYPE

    # properties of the classification model that is currently loaded
    self.model_epoch_length = None
    self.model_epochs_per_img = None

    # start with a single recording in the list
    recording_list = self.ui.recording_list_widget
    first_recording = Recording(
        widget=QtWidgets.QListWidgetItem("Recording 1", recording_list),
    )
    recording_list.addItem(first_recording.widget)
    recording_list.setCurrentRow(0)
    # position of the selected recording within self.recordings
    self.recording_index = 0
    # every recording the user has added
    self.recordings = [first_recording]

    # messages to display
    self.messages = []

    # Ctrl+W closes the window
    close_shortcut = QtGui.QShortcut(
        QtGui.QKeySequence(
            QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_W)
        ),
        self,
    )
    close_shortcut.activated.connect(self.close)

    # connect each UI signal to its handler (same order as listed)
    ui = self.ui
    signal_handlers = (
        (ui.add_button.clicked, self.add_recording),
        (ui.remove_button.clicked, self.remove_recording),
        (ui.recording_list_widget.currentRowChanged, self.select_recording),
        (ui.sampling_rate_input.valueChanged, self.update_sampling_rate),
        (ui.epoch_length_input.valueChanged, self.update_epoch_length),
        (ui.recording_file_button.clicked, self.select_recording_file),
        (ui.select_label_button.clicked, self.select_label_file),
        (ui.create_label_button.clicked, self.create_label_file),
        (ui.manual_scoring_button.clicked, self.manual_scoring),
        (ui.create_calibration_button.clicked, self.create_calibration_file),
        (ui.select_calibration_button.clicked, self.select_calibration_file),
        (ui.load_model_button.clicked, partial(self.load_model, None)),
        (ui.score_all_button.clicked, self.score_all),
        (ui.overwritecheckbox.stateChanged, self.update_overwrite_policy),
        (ui.bout_length_input.valueChanged, self.update_min_bout_length),
        (ui.user_manual_button.clicked, self.show_user_manual),
        (ui.image_number_input.valueChanged, self.update_epochs_per_img),
        (ui.delete_image_box.stateChanged, self.update_image_deletion),
        (ui.training_folder_button.clicked, self.set_training_folder),
        (ui.train_model_button.clicked, self.train_model),
        (ui.save_config_button.clicked, self.save_brain_state_config),
        (ui.export_button.clicked, self.export_recording_list),
        (ui.import_button.clicked, self.import_recording_list),
        (ui.default_type_button.toggled, self.model_type_radio_buttons),
    )
    for signal, handler in signal_handlers:
        signal.connect(handler)

    # labels that accept drag-and-dropped files (handled in eventFilter)
    for drop_target in (
        ui.recording_file_label,
        ui.label_file_label,
        ui.calibration_file_label,
        ui.model_label,
    ):
        drop_target.installEventFilter(self)

    self.show()
|
|
161
|
+
|
|
162
|
+
def model_type_radio_buttons(self, default_selected: bool) -> None:
    """Record which kind of model will be trained.

    Connected to the "default" radio button's ``toggled`` signal.

    :param default_selected: True when the default-type button is checked,
        False when the real-time option is chosen instead
    """
    self.model_type = (
        DEFAULT_MODEL_TYPE if default_selected else REAL_TIME_MODEL_TYPE
    )
|
|
171
|
+
|
|
172
|
+
def export_recording_list(self) -> None:
    """Write the current list of recordings to a file chosen by the user"""
    chosen_path, _selected_filter = QtWidgets.QFileDialog.getSaveFileName(
        self,
        caption="Save list of recordings as",
        filter="*" + RECORDING_LIST_FILE_TYPE,
    )
    # an empty path means the dialog was cancelled
    if chosen_path:
        save_recording_list(filename=chosen_path, recordings=self.recordings)
        self.show_message(f"Saved list of recordings to {chosen_path}")
|
|
184
|
+
|
|
185
|
+
def import_recording_list(self):
    """Load list of recordings from file, overwriting current list.

    The file is loaded before the current list is touched. If loading
    fails, or the file holds no recordings, an error is shown and the
    current recordings and list widget stay unchanged.
    """
    file_dialog = QtWidgets.QFileDialog(self)
    file_dialog.setWindowTitle("Select list of recordings")
    file_dialog.setFileMode(QtWidgets.QFileDialog.FileMode.ExistingFile)
    file_dialog.setViewMode(QtWidgets.QFileDialog.ViewMode.Detail)
    file_dialog.setNameFilter("*" + RECORDING_LIST_FILE_TYPE)

    if not file_dialog.exec():
        return
    filename = file_dialog.selectedFiles()[0]

    # load first, so that a bad file can't leave the widget cleared while
    # self.recordings still holds the old list
    try:
        recordings = load_recording_list(filename)
    except Exception:
        self.show_message(
            f"ERROR: could not load list of recordings from {filename}"
        )
        return
    # an empty list would break every self.recordings[self.recording_index]
    if not recordings:
        self.show_message(f"ERROR: no recordings found in {filename}")
        return

    # clear widget, then overwrite current list
    self.ui.recording_list_widget.clear()
    self.recordings = recordings

    for recording in self.recordings:
        # giving the list widget as parent already appends the item, so no
        # separate addItem call is needed
        recording.widget = QtWidgets.QListWidgetItem(
            f"Recording {recording.name}", self.ui.recording_list_widget
        )

    # display new list
    self.ui.recording_list_widget.setCurrentRow(0)
    self.show_message(f"Loaded list of recordings from {filename}")
|
|
213
|
+
|
|
214
|
+
def eventFilter(self, obj: QtCore.QObject, event: QtCore.QEvent) -> bool:
    """Filter mouse events to detect when user drags/drops a file

    Handles a single file dropped onto one of the file labels:
    - recording, label, or calibration file: assigned to the current
      recording if its extension matches;
    - model label: passed to load_model, which validates it.

    :param obj: UI object receiving the event
    :param event: mouse event
    :return: whether to filter (block) the event
    """
    filename = None
    if obj in [
        self.ui.recording_file_label,
        self.ui.label_file_label,
        self.ui.calibration_file_label,
        self.ui.model_label,
    ]:
        # presumably accepted so that drag-enter/move events allow a drop
        event.accept()
        # fully qualified enum, matching the other Qt enums in this file
        if event.type() == QtCore.QEvent.Type.Drop:
            urls = event.mimeData().urls()
            # only single-file drops are handled
            if len(urls) == 1:
                filename = urls[0].toLocalFile()

    if filename is None:
        return super().eventFilter(obj, event)

    _, file_extension = os.path.splitext(filename)

    if obj == self.ui.recording_file_label:
        if file_extension in RECORDING_FILE_TYPES:
            self.recordings[self.recording_index].recording_file = filename
            self.ui.recording_file_label.setText(filename)
    elif obj == self.ui.label_file_label:
        if file_extension == LABEL_FILE_TYPE:
            self.recordings[self.recording_index].label_file = filename
            self.ui.label_file_label.setText(filename)
    elif obj == self.ui.calibration_file_label:
        if file_extension == CALIBRATION_FILE_TYPE:
            self.recordings[self.recording_index].calibration_file = filename
            self.ui.calibration_file_label.setText(filename)
    elif obj == self.ui.model_label:
        self.load_model(filename=filename)

    return super().eventFilter(obj, event)
|
|
255
|
+
|
|
256
|
+
def train_model(self) -> None:
    """Train a classification model on every recording in the list.

    Steps:
    1. Validate the training settings and each recording.
    2. Ask where to save the model.
    3. Write training images to ``self.training_image_dir``.
    4. Train and save the model.
    5. Delete the images if requested.

    Aborts with a message if the user cancels the save dialog or if no
    recording could be turned into training images.
    """
    # check basic training inputs
    if (
        self.model_type == DEFAULT_MODEL_TYPE
        and self.training_epochs_per_img % 2 == 0
    ):
        self.show_message(
            (
                "ERROR: for the default model type, number of epochs "
                "per image must be an odd number."
            )
        )
        return
    if self.training_image_dir == "":
        self.show_message("ERROR: no folder selected for training images.")
        return

    # check some inputs for each recording
    for recording_index, recording in enumerate(self.recordings):
        error_message = self.check_single_file_inputs(recording_index)
        if error_message:
            self.show_message(f"ERROR ({recording.name}): {error_message}")
            return

    # get filename for the new model
    model_filename, _ = QtWidgets.QFileDialog.getSaveFileName(
        self,
        caption="Save classification model file as",
        filter="*" + MODEL_FILE_TYPE,
    )
    if not model_filename:
        # stop here: without a destination, image creation and training
        # would be wasted and save_model would receive an empty filename
        self.show_message("Model training canceled, no filename given")
        return

    # create image folder
    if os.path.exists(self.training_image_dir):
        self.show_message(
            "Warning: training image folder exists, will be overwritten"
        )
    os.makedirs(self.training_image_dir, exist_ok=True)

    # create training images
    self.show_message(
        (f"Creating training images in {self.training_image_dir}, please wait...")
    )
    self.ui.message_area.repaint()
    QtWidgets.QApplication.processEvents()
    print("Creating training images")
    failed_recordings = create_training_images(
        recordings=self.recordings,
        output_path=self.training_image_dir,
        epoch_length=self.epoch_length,
        epochs_per_img=self.training_epochs_per_img,
        brain_state_set=self.brain_state_set,
        model_type=self.model_type,
    )
    if len(failed_recordings) > 0:
        if len(failed_recordings) == len(self.recordings):
            # no training images were produced, so there is nothing to train on
            self.show_message("ERROR: no recordings were valid!")
            return
        self.show_message(
            (
                "WARNING: the following recordings could not be "
                "loaded and will not be used for training: "
                f"{', '.join([str(r) for r in failed_recordings])}"
            )
        )

    # train model
    self.show_message("Training model, please wait...")
    self.ui.message_area.repaint()
    QtWidgets.QApplication.processEvents()
    print("Training model")
    model = train_model(
        annotations_file=os.path.join(
            self.training_image_dir, ANNOTATIONS_FILENAME
        ),
        img_dir=self.training_image_dir,
        mixture_weights=self.brain_state_set.mixture_weights,
        n_classes=self.brain_state_set.n_classes,
    )

    # save model
    save_model(
        model=model,
        filename=model_filename,
        epoch_length=self.epoch_length,
        epochs_per_img=self.training_epochs_per_img,
        model_type=self.model_type,
        brain_state_set=self.brain_state_set,
    )

    # optionally delete images
    if self.delete_training_images:
        shutil.rmtree(self.training_image_dir)

    self.show_message(f"Training complete, saved model to {model_filename}")
|
|
354
|
+
|
|
355
|
+
def set_training_folder(self):
    """Ask for a parent directory and derive the training image folder.

    The folder path is ``<chosen dir>/images_<YYYYMMDDHHMM>``. It is only
    recorded and displayed here, not created.
    """
    parent_dir = QtWidgets.QFileDialog.getExistingDirectory(
        self, "Select directory for training images"
    )
    # empty result: dialog cancelled, keep the previous folder
    if not parent_dir:
        return
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M")
    self.training_image_dir = os.path.join(parent_dir, "images_" + timestamp)
    self.ui.image_folder_label.setText(self.training_image_dir)
|
|
365
|
+
|
|
366
|
+
def update_image_deletion(self) -> None:
    """Mirror the delete-images checkbox into the training settings.

    When the flag is set, train_model removes the training image folder
    once the model has been saved.
    """
    delete_box = self.ui.delete_image_box
    self.delete_training_images = delete_box.isChecked()
|
|
369
|
+
|
|
370
|
+
def update_epochs_per_img(self, new_value) -> None:
    """Store the number of epochs per training image.

    The value is validated against the model type later, in train_model.

    :param new_value: epochs per image chosen by the user
    """
    self.training_epochs_per_img = new_value
|
|
376
|
+
|
|
377
|
+
def score_all(self) -> None:
    """Score all recordings using the classification model.

    Global settings and every recording's inputs are validated first; if
    any check fails, nothing is scored. After that, a recording whose
    data, existing labels, or calibration file cannot be loaded is skipped
    with a message, and the remaining recordings are still scored.

    For each scored recording, the predicted labels are:
    - merged with existing labels when only undefined epochs may be
      overwritten;
    - post-processed to enforce the minimum bout length;
    - saved to the recording's label file.
    """
    # check basic inputs
    if self.model is None:
        self.ui.score_all_status.setText("missing classification model")
        self.show_message("ERROR: no classification model file selected")
        return
    if self.min_bout_length < self.epoch_length:
        self.ui.score_all_status.setText("invalid minimum bout length")
        self.show_message("ERROR: minimum bout length must be >= epoch length")
        return
    if self.epoch_length != self.model_epoch_length:
        self.ui.score_all_status.setText("invalid epoch length")
        self.show_message(
            (
                "ERROR: model was trained with an epoch length of "
                f"{self.model_epoch_length} seconds, but the current "
                f"epoch length setting is {self.epoch_length} seconds."
            )
        )
        return

    self.ui.score_all_status.setText("running...")
    self.ui.score_all_status.repaint()
    QtWidgets.QApplication.processEvents()

    # check some inputs for each recording before scoring any of them
    for recording_index, recording in enumerate(self.recordings):
        error_message = self.check_single_file_inputs(recording_index)
        if not error_message and recording.calibration_file == "":
            error_message = "no calibration file selected"
        if error_message:
            self.ui.score_all_status.setText(
                f"error on recording {recording.name}"
            )
            self.show_message(f"ERROR ({recording.name}): {error_message}")
            return

    # score each recording
    for recording in self.recordings:
        # load EEG, EMG
        try:
            eeg, emg = load_recording(recording.recording_file)
            eeg, emg, sampling_rate = resample_and_standardize(
                eeg=eeg,
                emg=emg,
                sampling_rate=recording.sampling_rate,
                epoch_length=self.epoch_length,
            )
        except Exception:
            self.show_message(
                (
                    "ERROR: could not load recording "
                    f"{recording.name}. "
                    "This recording will be skipped."
                )
            )
            continue

        # load existing labels, if there are any
        label_file = recording.label_file
        existing_labels = None
        if os.path.isfile(label_file):
            try:
                existing_labels = load_labels(label_file)
            except Exception:
                self.show_message(
                    (
                        "ERROR: could not load existing labels for recording "
                        f"{recording.name}. "
                        "This recording will be skipped."
                    )
                )
                continue
            # only check the length
            samples_per_epoch = sampling_rate * self.epoch_length
            epochs_in_recording = round(eeg.size / samples_per_epoch)
            if epochs_in_recording != existing_labels.size:
                self.show_message(
                    (
                        "ERROR: existing labels for recording "
                        f"{recording.name} "
                        "do not match the recording length. "
                        "This recording will be skipped."
                    )
                )
                continue

        # load calibration data
        if not os.path.isfile(recording.calibration_file):
            self.show_message(
                (
                    "ERROR: calibration file does not exist for recording "
                    f"{recording.name}. "
                    "This recording will be skipped."
                )
            )
            continue
        try:
            mixture_means, mixture_sds = load_calibration_file(
                recording.calibration_file
            )
        except Exception:
            self.show_message(
                (
                    "ERROR: could not load calibration file for recording "
                    f"{recording.name}. "
                    "This recording will be skipped."
                )
            )
            continue

        labels = score_recording(
            model=self.model,
            eeg=eeg,
            emg=emg,
            mixture_means=mixture_means,
            mixture_sds=mixture_sds,
            sampling_rate=sampling_rate,
            epoch_length=self.epoch_length,
            epochs_per_img=self.model_epochs_per_img,
            brain_state_set=self.brain_state_set,
        )

        # keep every existing label that is not undefined, if requested
        if existing_labels is not None and self.only_overwrite_undefined:
            already_scored = existing_labels != UNDEFINED_LABEL
            labels[already_scored] = existing_labels[already_scored]

        # enforce minimum bout length
        labels = enforce_min_bout_length(
            labels=labels,
            epoch_length=self.epoch_length,
            min_bout_length=self.min_bout_length,
        )

        # save results
        save_labels(labels, label_file)
        self.show_message(
            (
                "Saved labels for recording "
                f"{recording.name} "
                f"to {label_file}"
            )
        )

    self.ui.score_all_status.setText("")
|
|
544
|
+
|
|
545
|
+
def load_model(self, filename=None) -> None:
    """Load a trained classification model and make it the active model.

    Only "default"-type models are accepted. After loading, any mismatch
    between the model and the current brain state / epoch length
    configuration is reported as messages; the model is still kept.

    :param filename: path to the model file; when None, the user picks
        one with a file dialog
    """
    if filename is None:
        picker = QtWidgets.QFileDialog(self)
        picker.setWindowTitle("Select classification model")
        picker.setFileMode(QtWidgets.QFileDialog.FileMode.ExistingFile)
        picker.setViewMode(QtWidgets.QFileDialog.ViewMode.Detail)
        picker.setNameFilter("*" + MODEL_FILE_TYPE)
        # dialog cancelled: nothing to load
        if not picker.exec():
            return
        filename = picker.selectedFiles()[0]

    if not os.path.isfile(filename):
        self.show_message("ERROR: model file does not exist")
        return

    # this call resolves to the module-level load_model imported from fileio
    try:
        model, epoch_length, epochs_per_img, model_type, brain_states = load_model(
            filename=filename
        )
    except Exception:
        self.show_message(
            "ERROR: could not load classification model. Check "
            "user manual for instructions on creating this file."
        )
        return

    if model_type != DEFAULT_MODEL_TYPE:
        self.show_message(
            "ERROR: only 'default'-style models can be used. "
            "'Real-time' models are not supported. "
            "See classification.example_real_time_scoring_function.py "
            "for an example of how to classify brain states in real time."
        )
        return

    self.model = model
    self.model_epoch_length = epoch_length
    self.model_epochs_per_img = epochs_per_img

    # report any differences between the model and the current configuration
    current_states = self.brain_state_set.to_output_dict()[BRAIN_STATES_KEY]
    for warning in check_config_consistency(
        current_brain_states=current_states,
        model_brain_states=brain_states,
        current_epoch_length=self.epoch_length,
        model_epoch_length=epoch_length,
    ):
        self.show_message(warning)

    self.ui.model_label.setText(filename)
|
|
611
|
+
|
|
612
|
+
def load_single_recording(
    self, status_widget: QtWidgets.QLabel
) -> tuple[np.ndarray | None, np.ndarray | None, int | float | None, bool]:
    """Load and preprocess the currently selected recording

    This loads one recording, resamples it, and standardizes its length.
    If an error occurs during this process, it is displayed in the
    indicated widget and in the message area.

    :param status_widget: UI element on which to display error messages
    :return: EEG data, EMG data, sampling rate, process completion. On
        failure the first three values are None and the last is False.
    """
    error_message = self.check_single_file_inputs(self.recording_index)
    if error_message:
        status_widget.setText(error_message)
        self.show_message(f"ERROR: {error_message}")
        return None, None, None, False

    try:
        eeg, emg = load_recording(
            self.recordings[self.recording_index].recording_file
        )
    except Exception:
        status_widget.setText("could not load recording")
        self.show_message(
            (
                "ERROR: could not load recording. "
                "Check user manual for formatting instructions."
            )
        )
        return None, None, None, False

    sampling_rate = self.recordings[self.recording_index].sampling_rate

    eeg, emg, sampling_rate = resample_and_standardize(
        eeg=eeg,
        emg=emg,
        sampling_rate=sampling_rate,
        epoch_length=self.epoch_length,
    )

    return eeg, emg, sampling_rate, True
|
|
654
|
+
|
|
655
|
+
def create_calibration_file(self) -> None:
    """Build a calibration file from the selected recording and its labels.

    Steps:
    1. Load the selected recording.
    2. Load its labels and check that they are valid.
    3. Ask where to save the calibration file, then write it.

    On success, the new file becomes the current recording's calibration
    file. Any failure is shown in the calibration status label and in the
    message area.
    """
    status = self.ui.calibration_status

    eeg, emg, sampling_rate, loaded_ok = self.load_single_recording(status)
    if not loaded_ok:
        return

    label_file = self.recordings[self.recording_index].label_file
    if not os.path.isfile(label_file):
        status.setText("label file does not exist")
        self.show_message("ERROR: label file does not exist")
        return
    try:
        labels = load_labels(label_file)
    except Exception:
        status.setText("could not load labels")
        self.show_message(
            "ERROR: could not load labels. "
            "Check user manual for formatting instructions."
        )
        return

    problem = check_label_validity(
        labels=labels,
        samples_in_recording=eeg.size,
        sampling_rate=sampling_rate,
        epoch_length=self.epoch_length,
        brain_state_set=self.brain_state_set,
    )
    if problem:
        status.setText("invalid label file")
        self.show_message(f"ERROR: {problem}")
        return

    filename, _selected_filter = QtWidgets.QFileDialog.getSaveFileName(
        self,
        caption="Save calibration file as",
        filter="*" + CALIBRATION_FILE_TYPE,
    )
    # dialog cancelled
    if not filename:
        return

    # this call resolves to the module-level function imported from classification
    create_calibration_file(
        filename=filename,
        eeg=eeg,
        emg=emg,
        labels=labels,
        sampling_rate=sampling_rate,
        epoch_length=self.epoch_length,
        brain_state_set=self.brain_state_set,
    )

    status.setText("")
    current = self.recordings[self.recording_index]
    self.show_message(
        "Created calibration file using recording "
        f"{current.name} "
        f"at {filename}"
    )
    current.calibration_file = filename
    self.ui.calibration_file_label.setText(filename)
|
|
729
|
+
|
|
730
|
+
def check_single_file_inputs(self, recording_index: int) -> str:
|
|
731
|
+
"""Check that a recording's inputs appear valid
|
|
732
|
+
|
|
733
|
+
This runs some basic tests for whether it will be possible to
|
|
734
|
+
load and score a recording. If any test fails, we return an
|
|
735
|
+
error message.
|
|
736
|
+
|
|
737
|
+
:param recording_index: index of the recording in the list of
|
|
738
|
+
all recordings.
|
|
739
|
+
:return: error message
|
|
740
|
+
"""
|
|
741
|
+
sampling_rate = self.recordings[recording_index].sampling_rate
|
|
742
|
+
if self.epoch_length == 0:
|
|
743
|
+
return "epoch length can't be 0"
|
|
744
|
+
if sampling_rate == 0:
|
|
745
|
+
return "sampling rate can't be 0"
|
|
746
|
+
if self.epoch_length > sampling_rate:
|
|
747
|
+
return "invalid epoch length or sampling rate"
|
|
748
|
+
if self.recordings[self.recording_index].recording_file == "":
|
|
749
|
+
return "no recording selected"
|
|
750
|
+
if self.recordings[self.recording_index].label_file == "":
|
|
751
|
+
return "no label file selected"
|
|
752
|
+
|
|
753
|
+
def update_min_bout_length(self, new_value) -> None:
|
|
754
|
+
"""Update the minimum bout length
|
|
755
|
+
|
|
756
|
+
:param new_value: new minimum bout length, in seconds
|
|
757
|
+
"""
|
|
758
|
+
self.min_bout_length = new_value
|
|
759
|
+
|
|
760
|
+
def update_overwrite_policy(self, checked) -> None:
|
|
761
|
+
"""Toggle overwriting policy
|
|
762
|
+
|
|
763
|
+
If the checkbox is enabled, only epochs where the brain state is set to
|
|
764
|
+
undefined will be overwritten by the automatic scoring process.
|
|
765
|
+
|
|
766
|
+
:param checked: state of the checkbox
|
|
767
|
+
"""
|
|
768
|
+
self.only_overwrite_undefined = checked
|
|
769
|
+
|
|
770
|
+
def manual_scoring(self) -> None:
|
|
771
|
+
"""View the selected recording for manual scoring"""
|
|
772
|
+
# immediately display a status message
|
|
773
|
+
self.ui.manual_scoring_status.setText("loading...")
|
|
774
|
+
self.ui.manual_scoring_status.repaint()
|
|
775
|
+
QtWidgets.QApplication.processEvents()
|
|
776
|
+
|
|
777
|
+
# load the recording
|
|
778
|
+
eeg, emg, sampling_rate, success = self.load_single_recording(
|
|
779
|
+
self.ui.manual_scoring_status
|
|
780
|
+
)
|
|
781
|
+
if not success:
|
|
782
|
+
return
|
|
783
|
+
|
|
784
|
+
# if the labels exist, load them
|
|
785
|
+
# otherwise, create a blank set of labels
|
|
786
|
+
label_file = self.recordings[self.recording_index].label_file
|
|
787
|
+
if os.path.isfile(label_file):
|
|
788
|
+
try:
|
|
789
|
+
labels = load_labels(label_file)
|
|
790
|
+
except Exception:
|
|
791
|
+
self.ui.manual_scoring_status.setText("could not load labels")
|
|
792
|
+
self.show_message(
|
|
793
|
+
(
|
|
794
|
+
"ERROR: could not load labels. "
|
|
795
|
+
"Check user manual for formatting instructions."
|
|
796
|
+
)
|
|
797
|
+
)
|
|
798
|
+
return
|
|
799
|
+
else:
|
|
800
|
+
labels = (
|
|
801
|
+
np.ones(int(eeg.size / (sampling_rate * self.epoch_length)))
|
|
802
|
+
* UNDEFINED_LABEL
|
|
803
|
+
).astype(int)
|
|
804
|
+
|
|
805
|
+
# check that all labels are valid
|
|
806
|
+
label_error = check_label_validity(
|
|
807
|
+
labels=labels,
|
|
808
|
+
samples_in_recording=eeg.size,
|
|
809
|
+
sampling_rate=sampling_rate,
|
|
810
|
+
epoch_length=self.epoch_length,
|
|
811
|
+
brain_state_set=self.brain_state_set,
|
|
812
|
+
)
|
|
813
|
+
if label_error:
|
|
814
|
+
# if the label length is only off by one, pad or truncate as needed
|
|
815
|
+
# and show a warning
|
|
816
|
+
if label_error == LABEL_LENGTH_ERROR:
|
|
817
|
+
# should be very close to an integer
|
|
818
|
+
samples_per_epoch = round(sampling_rate * self.epoch_length)
|
|
819
|
+
epochs_in_recording = round(eeg.size / samples_per_epoch)
|
|
820
|
+
if epochs_in_recording - labels.size == 1:
|
|
821
|
+
labels = np.concatenate((labels, np.array([UNDEFINED_LABEL])))
|
|
822
|
+
self.show_message(
|
|
823
|
+
(
|
|
824
|
+
"WARNING: an undefined epoch was added to "
|
|
825
|
+
"the label file to correct its length."
|
|
826
|
+
)
|
|
827
|
+
)
|
|
828
|
+
elif labels.size - epochs_in_recording == 1:
|
|
829
|
+
labels = labels[:-1]
|
|
830
|
+
self.show_message(
|
|
831
|
+
(
|
|
832
|
+
"WARNING: the last epoch was removed from "
|
|
833
|
+
"the label file to correct its length."
|
|
834
|
+
)
|
|
835
|
+
)
|
|
836
|
+
else:
|
|
837
|
+
self.ui.manual_scoring_status.setText("invalid label file")
|
|
838
|
+
self.show_message(f"ERROR: {label_error}")
|
|
839
|
+
return
|
|
840
|
+
else:
|
|
841
|
+
self.ui.manual_scoring_status.setText("invalid label file")
|
|
842
|
+
self.show_message(f"ERROR: {label_error}")
|
|
843
|
+
return
|
|
844
|
+
|
|
845
|
+
self.show_message(
|
|
846
|
+
f"Viewing recording {self.recordings[self.recording_index].name}"
|
|
847
|
+
)
|
|
848
|
+
self.ui.manual_scoring_status.setText("file is open")
|
|
849
|
+
|
|
850
|
+
# launch the manual scoring window
|
|
851
|
+
manual_scoring_window = ManualScoringWindow(
|
|
852
|
+
eeg=eeg,
|
|
853
|
+
emg=emg,
|
|
854
|
+
label_file=label_file,
|
|
855
|
+
labels=labels,
|
|
856
|
+
sampling_rate=sampling_rate,
|
|
857
|
+
epoch_length=self.epoch_length,
|
|
858
|
+
)
|
|
859
|
+
manual_scoring_window.setWindowTitle(f"AccuSleePy viewer: {label_file}")
|
|
860
|
+
manual_scoring_window.exec()
|
|
861
|
+
self.ui.manual_scoring_status.setText("")
|
|
862
|
+
|
|
863
|
+
def create_label_file(self) -> None:
|
|
864
|
+
"""Set the filename for a new label file"""
|
|
865
|
+
filename, _ = QtWidgets.QFileDialog.getSaveFileName(
|
|
866
|
+
self,
|
|
867
|
+
caption="Set filename for label file (nothing will be overwritten yet)",
|
|
868
|
+
filter="*" + LABEL_FILE_TYPE,
|
|
869
|
+
)
|
|
870
|
+
if filename:
|
|
871
|
+
self.recordings[self.recording_index].label_file = filename
|
|
872
|
+
self.ui.label_file_label.setText(filename)
|
|
873
|
+
|
|
874
|
+
def select_label_file(self) -> None:
|
|
875
|
+
"""User can select an existing label file"""
|
|
876
|
+
file_dialog = QtWidgets.QFileDialog(self)
|
|
877
|
+
file_dialog.setWindowTitle("Select label file")
|
|
878
|
+
file_dialog.setFileMode(QtWidgets.QFileDialog.FileMode.ExistingFile)
|
|
879
|
+
file_dialog.setViewMode(QtWidgets.QFileDialog.ViewMode.Detail)
|
|
880
|
+
file_dialog.setNameFilter("*" + LABEL_FILE_TYPE)
|
|
881
|
+
|
|
882
|
+
if file_dialog.exec():
|
|
883
|
+
selected_files = file_dialog.selectedFiles()
|
|
884
|
+
filename = selected_files[0]
|
|
885
|
+
self.recordings[self.recording_index].label_file = filename
|
|
886
|
+
self.ui.label_file_label.setText(filename)
|
|
887
|
+
|
|
888
|
+
def select_calibration_file(self) -> None:
|
|
889
|
+
"""User can select a calibration file"""
|
|
890
|
+
file_dialog = QtWidgets.QFileDialog(self)
|
|
891
|
+
file_dialog.setWindowTitle("Select calibration file")
|
|
892
|
+
file_dialog.setFileMode(QtWidgets.QFileDialog.FileMode.ExistingFile)
|
|
893
|
+
file_dialog.setViewMode(QtWidgets.QFileDialog.ViewMode.Detail)
|
|
894
|
+
file_dialog.setNameFilter("*" + CALIBRATION_FILE_TYPE)
|
|
895
|
+
|
|
896
|
+
if file_dialog.exec():
|
|
897
|
+
selected_files = file_dialog.selectedFiles()
|
|
898
|
+
filename = selected_files[0]
|
|
899
|
+
self.recordings[self.recording_index].calibration_file = filename
|
|
900
|
+
self.ui.calibration_file_label.setText(filename)
|
|
901
|
+
|
|
902
|
+
def select_recording_file(self) -> None:
|
|
903
|
+
"""User can select a recording file"""
|
|
904
|
+
file_dialog = QtWidgets.QFileDialog(self)
|
|
905
|
+
file_dialog.setWindowTitle("Select recording file")
|
|
906
|
+
file_dialog.setFileMode(QtWidgets.QFileDialog.FileMode.ExistingFile)
|
|
907
|
+
file_dialog.setViewMode(QtWidgets.QFileDialog.ViewMode.Detail)
|
|
908
|
+
file_dialog.setNameFilter(f"(*{' *'.join(RECORDING_FILE_TYPES)})")
|
|
909
|
+
|
|
910
|
+
if file_dialog.exec():
|
|
911
|
+
selected_files = file_dialog.selectedFiles()
|
|
912
|
+
filename = selected_files[0]
|
|
913
|
+
self.recordings[self.recording_index].recording_file = filename
|
|
914
|
+
self.ui.recording_file_label.setText(filename)
|
|
915
|
+
|
|
916
|
+
def show_recording_info(self) -> None:
|
|
917
|
+
"""Update the UI to show info for the selected recording"""
|
|
918
|
+
self.ui.sampling_rate_input.setValue(
|
|
919
|
+
self.recordings[self.recording_index].sampling_rate
|
|
920
|
+
)
|
|
921
|
+
self.ui.recording_file_label.setText(
|
|
922
|
+
self.recordings[self.recording_index].recording_file
|
|
923
|
+
)
|
|
924
|
+
self.ui.label_file_label.setText(
|
|
925
|
+
self.recordings[self.recording_index].label_file
|
|
926
|
+
)
|
|
927
|
+
self.ui.calibration_file_label.setText(
|
|
928
|
+
self.recordings[self.recording_index].calibration_file
|
|
929
|
+
)
|
|
930
|
+
|
|
931
|
+
def update_epoch_length(self, new_value: int | float) -> None:
|
|
932
|
+
"""Update the epoch length when the widget state changes
|
|
933
|
+
|
|
934
|
+
:param new_value: new epoch length
|
|
935
|
+
"""
|
|
936
|
+
self.epoch_length = new_value
|
|
937
|
+
|
|
938
|
+
def update_sampling_rate(self, new_value: int | float) -> None:
|
|
939
|
+
"""Update recording's sampling rate when the widget state changes
|
|
940
|
+
|
|
941
|
+
:param new_value: new sampling rate
|
|
942
|
+
"""
|
|
943
|
+
self.recordings[self.recording_index].sampling_rate = new_value
|
|
944
|
+
|
|
945
|
+
def show_message(self, message: str) -> None:
|
|
946
|
+
"""Display a new message to the user
|
|
947
|
+
|
|
948
|
+
:param message: message to display
|
|
949
|
+
"""
|
|
950
|
+
self.messages.append(message)
|
|
951
|
+
if len(self.messages) > MESSAGE_BOX_MAX_DEPTH:
|
|
952
|
+
del self.messages[0]
|
|
953
|
+
self.ui.message_area.setText("\n".join(self.messages))
|
|
954
|
+
# scroll to the bottom
|
|
955
|
+
scrollbar = self.ui.message_area.verticalScrollBar()
|
|
956
|
+
scrollbar.setValue(scrollbar.maximum())
|
|
957
|
+
|
|
958
|
+
def select_recording(self, list_index: int) -> None:
|
|
959
|
+
"""Callback for when a recording is selected
|
|
960
|
+
|
|
961
|
+
:param list_index: index of this recording in the list widget
|
|
962
|
+
"""
|
|
963
|
+
# get index of this recording
|
|
964
|
+
self.recording_index = list_index
|
|
965
|
+
# display information about this recording
|
|
966
|
+
self.show_recording_info()
|
|
967
|
+
self.ui.selected_recording_groupbox.setTitle(
|
|
968
|
+
f"Data / actions for Recording {self.recordings[list_index].name}"
|
|
969
|
+
)
|
|
970
|
+
|
|
971
|
+
def add_recording(self) -> None:
|
|
972
|
+
"""Add new recording to the list"""
|
|
973
|
+
# find name to use for the new recording
|
|
974
|
+
new_name = max([r.name for r in self.recordings]) + 1
|
|
975
|
+
|
|
976
|
+
# add new recording to list
|
|
977
|
+
self.recordings.append(
|
|
978
|
+
Recording(
|
|
979
|
+
name=new_name,
|
|
980
|
+
sampling_rate=self.recordings[self.recording_index].sampling_rate,
|
|
981
|
+
widget=QtWidgets.QListWidgetItem(
|
|
982
|
+
f"Recording {new_name}", self.ui.recording_list_widget
|
|
983
|
+
),
|
|
984
|
+
)
|
|
985
|
+
)
|
|
986
|
+
|
|
987
|
+
# display new list
|
|
988
|
+
self.ui.recording_list_widget.addItem(self.recordings[-1].widget)
|
|
989
|
+
self.ui.recording_list_widget.setCurrentRow(len(self.recordings) - 1)
|
|
990
|
+
self.show_message(f"added Recording {new_name}")
|
|
991
|
+
|
|
992
|
+
def remove_recording(self) -> None:
|
|
993
|
+
"""Delete selected recording from the list"""
|
|
994
|
+
if len(self.recordings) > 1:
|
|
995
|
+
current_list_index = self.ui.recording_list_widget.currentRow()
|
|
996
|
+
_ = self.ui.recording_list_widget.takeItem(current_list_index)
|
|
997
|
+
self.show_message(
|
|
998
|
+
f"deleted Recording {self.recordings[current_list_index].name}"
|
|
999
|
+
)
|
|
1000
|
+
del self.recordings[current_list_index]
|
|
1001
|
+
self.recording_index = self.ui.recording_list_widget.currentRow()
|
|
1002
|
+
|
|
1003
|
+
def show_user_manual(self) -> None:
|
|
1004
|
+
"""Show a popup window with the user manual"""
|
|
1005
|
+
user_manual_file = open(
|
|
1006
|
+
os.path.join(os.path.dirname(os.path.abspath(__file__)), USER_MANUAL_FILE),
|
|
1007
|
+
"r",
|
|
1008
|
+
)
|
|
1009
|
+
user_manual_text = user_manual_file.read()
|
|
1010
|
+
user_manual_file.close()
|
|
1011
|
+
|
|
1012
|
+
label_widget = QtWidgets.QLabel()
|
|
1013
|
+
label_widget.setText(user_manual_text)
|
|
1014
|
+
scroll_area = QtWidgets.QScrollArea()
|
|
1015
|
+
scroll_area.setStyleSheet("background-color: white;")
|
|
1016
|
+
scroll_area.setWidget(label_widget)
|
|
1017
|
+
grid = QtWidgets.QGridLayout()
|
|
1018
|
+
grid.addWidget(scroll_area)
|
|
1019
|
+
self.popup = QtWidgets.QWidget()
|
|
1020
|
+
self.popup.setLayout(grid)
|
|
1021
|
+
self.popup.setGeometry(QtCore.QRect(100, 100, 600, 600))
|
|
1022
|
+
self.popup.show()
|
|
1023
|
+
|
|
1024
|
+
def initialize_settings_tab(self):
|
|
1025
|
+
"""Populate settings tab and assign its callbacks"""
|
|
1026
|
+
# show information about the settings tab
|
|
1027
|
+
config_guide_file = open(
|
|
1028
|
+
os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_GUIDE_FILE),
|
|
1029
|
+
"r",
|
|
1030
|
+
)
|
|
1031
|
+
config_guide_text = config_guide_file.read()
|
|
1032
|
+
config_guide_file.close()
|
|
1033
|
+
self.ui.settings_text.setText(config_guide_text)
|
|
1034
|
+
|
|
1035
|
+
# store dictionary that maps digits to rows of widgets
|
|
1036
|
+
# in the settings tab
|
|
1037
|
+
self.settings_widgets = {
|
|
1038
|
+
1: StateSettings(
|
|
1039
|
+
digit=1,
|
|
1040
|
+
enabled_widget=self.ui.enable_state_1,
|
|
1041
|
+
name_widget=self.ui.state_name_1,
|
|
1042
|
+
is_scored_widget=self.ui.state_scored_1,
|
|
1043
|
+
frequency_widget=self.ui.state_frequency_1,
|
|
1044
|
+
),
|
|
1045
|
+
2: StateSettings(
|
|
1046
|
+
digit=2,
|
|
1047
|
+
enabled_widget=self.ui.enable_state_2,
|
|
1048
|
+
name_widget=self.ui.state_name_2,
|
|
1049
|
+
is_scored_widget=self.ui.state_scored_2,
|
|
1050
|
+
frequency_widget=self.ui.state_frequency_2,
|
|
1051
|
+
),
|
|
1052
|
+
3: StateSettings(
|
|
1053
|
+
digit=3,
|
|
1054
|
+
enabled_widget=self.ui.enable_state_3,
|
|
1055
|
+
name_widget=self.ui.state_name_3,
|
|
1056
|
+
is_scored_widget=self.ui.state_scored_3,
|
|
1057
|
+
frequency_widget=self.ui.state_frequency_3,
|
|
1058
|
+
),
|
|
1059
|
+
4: StateSettings(
|
|
1060
|
+
digit=4,
|
|
1061
|
+
enabled_widget=self.ui.enable_state_4,
|
|
1062
|
+
name_widget=self.ui.state_name_4,
|
|
1063
|
+
is_scored_widget=self.ui.state_scored_4,
|
|
1064
|
+
frequency_widget=self.ui.state_frequency_4,
|
|
1065
|
+
),
|
|
1066
|
+
5: StateSettings(
|
|
1067
|
+
digit=5,
|
|
1068
|
+
enabled_widget=self.ui.enable_state_5,
|
|
1069
|
+
name_widget=self.ui.state_name_5,
|
|
1070
|
+
is_scored_widget=self.ui.state_scored_5,
|
|
1071
|
+
frequency_widget=self.ui.state_frequency_5,
|
|
1072
|
+
),
|
|
1073
|
+
6: StateSettings(
|
|
1074
|
+
digit=6,
|
|
1075
|
+
enabled_widget=self.ui.enable_state_6,
|
|
1076
|
+
name_widget=self.ui.state_name_6,
|
|
1077
|
+
is_scored_widget=self.ui.state_scored_6,
|
|
1078
|
+
frequency_widget=self.ui.state_frequency_6,
|
|
1079
|
+
),
|
|
1080
|
+
7: StateSettings(
|
|
1081
|
+
digit=7,
|
|
1082
|
+
enabled_widget=self.ui.enable_state_7,
|
|
1083
|
+
name_widget=self.ui.state_name_7,
|
|
1084
|
+
is_scored_widget=self.ui.state_scored_7,
|
|
1085
|
+
frequency_widget=self.ui.state_frequency_7,
|
|
1086
|
+
),
|
|
1087
|
+
8: StateSettings(
|
|
1088
|
+
digit=8,
|
|
1089
|
+
enabled_widget=self.ui.enable_state_8,
|
|
1090
|
+
name_widget=self.ui.state_name_8,
|
|
1091
|
+
is_scored_widget=self.ui.state_scored_8,
|
|
1092
|
+
frequency_widget=self.ui.state_frequency_8,
|
|
1093
|
+
),
|
|
1094
|
+
9: StateSettings(
|
|
1095
|
+
digit=9,
|
|
1096
|
+
enabled_widget=self.ui.enable_state_9,
|
|
1097
|
+
name_widget=self.ui.state_name_9,
|
|
1098
|
+
is_scored_widget=self.ui.state_scored_9,
|
|
1099
|
+
frequency_widget=self.ui.state_frequency_9,
|
|
1100
|
+
),
|
|
1101
|
+
0: StateSettings(
|
|
1102
|
+
digit=0,
|
|
1103
|
+
enabled_widget=self.ui.enable_state_0,
|
|
1104
|
+
name_widget=self.ui.state_name_0,
|
|
1105
|
+
is_scored_widget=self.ui.state_scored_0,
|
|
1106
|
+
frequency_widget=self.ui.state_frequency_0,
|
|
1107
|
+
),
|
|
1108
|
+
}
|
|
1109
|
+
|
|
1110
|
+
# update widget state to display current config
|
|
1111
|
+
states = {b.digit: b for b in self.brain_state_set.brain_states}
|
|
1112
|
+
for digit in range(10):
|
|
1113
|
+
if digit in states.keys():
|
|
1114
|
+
self.settings_widgets[digit].enabled_widget.setChecked(True)
|
|
1115
|
+
self.settings_widgets[digit].name_widget.setText(states[digit].name)
|
|
1116
|
+
self.settings_widgets[digit].is_scored_widget.setChecked(
|
|
1117
|
+
states[digit].is_scored
|
|
1118
|
+
)
|
|
1119
|
+
self.settings_widgets[digit].frequency_widget.setValue(
|
|
1120
|
+
states[digit].frequency
|
|
1121
|
+
)
|
|
1122
|
+
else:
|
|
1123
|
+
self.settings_widgets[digit].enabled_widget.setChecked(False)
|
|
1124
|
+
self.settings_widgets[digit].name_widget.setEnabled(False)
|
|
1125
|
+
self.settings_widgets[digit].is_scored_widget.setEnabled(False)
|
|
1126
|
+
self.settings_widgets[digit].frequency_widget.setEnabled(False)
|
|
1127
|
+
|
|
1128
|
+
# set callbacks
|
|
1129
|
+
for digit in range(10):
|
|
1130
|
+
state = self.settings_widgets[digit]
|
|
1131
|
+
state.enabled_widget.stateChanged.connect(
|
|
1132
|
+
partial(self.set_brain_state_enabled, digit)
|
|
1133
|
+
)
|
|
1134
|
+
state.name_widget.editingFinished.connect(self.finished_editing_state_name)
|
|
1135
|
+
state.is_scored_widget.stateChanged.connect(
|
|
1136
|
+
partial(self.is_scored_changed, digit)
|
|
1137
|
+
)
|
|
1138
|
+
state.frequency_widget.valueChanged.connect(self.state_frequency_changed)
|
|
1139
|
+
|
|
1140
|
+
def set_brain_state_enabled(self, digit, e) -> None:
|
|
1141
|
+
"""Called when user clicks "enabled" checkbox
|
|
1142
|
+
|
|
1143
|
+
:param digit: brain state digit
|
|
1144
|
+
:param e: unused but mandatory
|
|
1145
|
+
"""
|
|
1146
|
+
# get the widgets for this brain state
|
|
1147
|
+
state = self.settings_widgets[digit]
|
|
1148
|
+
# update state of these widgets
|
|
1149
|
+
is_checked = state.enabled_widget.isChecked()
|
|
1150
|
+
for widget in [
|
|
1151
|
+
state.name_widget,
|
|
1152
|
+
state.is_scored_widget,
|
|
1153
|
+
]:
|
|
1154
|
+
widget.setEnabled(is_checked)
|
|
1155
|
+
state.frequency_widget.setEnabled(
|
|
1156
|
+
is_checked and state.is_scored_widget.isChecked()
|
|
1157
|
+
)
|
|
1158
|
+
if not is_checked:
|
|
1159
|
+
state.name_widget.setText("")
|
|
1160
|
+
state.frequency_widget.setValue(0)
|
|
1161
|
+
# check that configuration is valid
|
|
1162
|
+
_ = self.check_config_validity()
|
|
1163
|
+
|
|
1164
|
+
def finished_editing_state_name(self) -> None:
|
|
1165
|
+
"""Called when user finishes editing a brain state's name"""
|
|
1166
|
+
_ = self.check_config_validity()
|
|
1167
|
+
|
|
1168
|
+
def state_frequency_changed(self, new_value) -> None:
|
|
1169
|
+
"""Called when user edits a brain state's frequency
|
|
1170
|
+
|
|
1171
|
+
:param new_value: unused
|
|
1172
|
+
"""
|
|
1173
|
+
_ = self.check_config_validity()
|
|
1174
|
+
|
|
1175
|
+
def is_scored_changed(self, digit, e) -> None:
|
|
1176
|
+
"""Called when user sets whether a state is scored
|
|
1177
|
+
|
|
1178
|
+
:param digit: brain state digit
|
|
1179
|
+
:param e: unused, but mandatory
|
|
1180
|
+
"""
|
|
1181
|
+
# get the widgets for this brain state
|
|
1182
|
+
state = self.settings_widgets[digit]
|
|
1183
|
+
# update the state of these widgets
|
|
1184
|
+
is_checked = state.is_scored_widget.isChecked()
|
|
1185
|
+
state.frequency_widget.setEnabled(is_checked)
|
|
1186
|
+
if not is_checked:
|
|
1187
|
+
state.frequency_widget.setValue(0)
|
|
1188
|
+
# check that configuration is valid
|
|
1189
|
+
_ = self.check_config_validity()
|
|
1190
|
+
|
|
1191
|
+
def check_config_validity(self) -> str:
|
|
1192
|
+
"""Check if brain state configuration on screen is valid"""
|
|
1193
|
+
# error message, if we get one
|
|
1194
|
+
message = None
|
|
1195
|
+
|
|
1196
|
+
# strip whitespace from brain state names and update display
|
|
1197
|
+
for digit in range(10):
|
|
1198
|
+
state = self.settings_widgets[digit]
|
|
1199
|
+
current_name = state.name_widget.text()
|
|
1200
|
+
formatted_name = current_name.strip()
|
|
1201
|
+
if current_name != formatted_name:
|
|
1202
|
+
state.name_widget.setText(formatted_name)
|
|
1203
|
+
|
|
1204
|
+
# check if names are unique and frequencies add up to 1
|
|
1205
|
+
names = []
|
|
1206
|
+
frequencies = []
|
|
1207
|
+
for digit in range(10):
|
|
1208
|
+
state = self.settings_widgets[digit]
|
|
1209
|
+
if state.enabled_widget.isChecked():
|
|
1210
|
+
names.append(state.name_widget.text())
|
|
1211
|
+
frequencies.append(state.frequency_widget.value())
|
|
1212
|
+
if len(names) != len(set(names)):
|
|
1213
|
+
message = "Error: names must be unique"
|
|
1214
|
+
if sum(frequencies) != 1:
|
|
1215
|
+
message = "Error: sum(frequencies) != 1"
|
|
1216
|
+
|
|
1217
|
+
if message is not None:
|
|
1218
|
+
self.ui.save_config_status.setText(message)
|
|
1219
|
+
self.ui.save_config_button.setEnabled(False)
|
|
1220
|
+
return message
|
|
1221
|
+
|
|
1222
|
+
self.ui.save_config_button.setEnabled(True)
|
|
1223
|
+
self.ui.save_config_status.setText("")
|
|
1224
|
+
|
|
1225
|
+
def save_brain_state_config(self):
|
|
1226
|
+
"""Save configuration to file"""
|
|
1227
|
+
# check that configuration is valid
|
|
1228
|
+
error_message = self.check_config_validity()
|
|
1229
|
+
if error_message is not None:
|
|
1230
|
+
return
|
|
1231
|
+
|
|
1232
|
+
# build a BrainStateMapper object from the current configuration
|
|
1233
|
+
brain_states = list()
|
|
1234
|
+
for digit in range(10):
|
|
1235
|
+
state = self.settings_widgets[digit]
|
|
1236
|
+
if state.enabled_widget.isChecked():
|
|
1237
|
+
brain_states.append(
|
|
1238
|
+
BrainState(
|
|
1239
|
+
name=state.name_widget.text(),
|
|
1240
|
+
digit=digit,
|
|
1241
|
+
is_scored=state.is_scored_widget.isChecked(),
|
|
1242
|
+
frequency=state.frequency_widget.value(),
|
|
1243
|
+
)
|
|
1244
|
+
)
|
|
1245
|
+
self.brain_state_set = BrainStateSet(brain_states, UNDEFINED_LABEL)
|
|
1246
|
+
|
|
1247
|
+
# save to file
|
|
1248
|
+
save_config(self.brain_state_set)
|
|
1249
|
+
self.ui.save_config_status.setText("configuration saved")
|
|
1250
|
+
|
|
1251
|
+
|
|
1252
|
+
def check_label_validity(
|
|
1253
|
+
labels: np.array,
|
|
1254
|
+
samples_in_recording: int,
|
|
1255
|
+
sampling_rate: int | float,
|
|
1256
|
+
epoch_length: int | float,
|
|
1257
|
+
brain_state_set: BrainStateSet,
|
|
1258
|
+
) -> str:
|
|
1259
|
+
"""Check whether a set of brain state labels is valid
|
|
1260
|
+
|
|
1261
|
+
This returns an error message if a problem is found with the
|
|
1262
|
+
brain state labels.
|
|
1263
|
+
|
|
1264
|
+
:param labels: brain state labels
|
|
1265
|
+
:param samples_in_recording: number of samples in the recording
|
|
1266
|
+
:param sampling_rate: sampling rate, in Hz
|
|
1267
|
+
:param epoch_length: epoch length, in seconds
|
|
1268
|
+
:param brain_state_set: BrainStateMapper object
|
|
1269
|
+
:return: error message
|
|
1270
|
+
"""
|
|
1271
|
+
# check that number of labels is correct
|
|
1272
|
+
samples_per_epoch = round(sampling_rate * epoch_length)
|
|
1273
|
+
epochs_in_recording = round(samples_in_recording / samples_per_epoch)
|
|
1274
|
+
if epochs_in_recording != labels.size:
|
|
1275
|
+
return LABEL_LENGTH_ERROR
|
|
1276
|
+
|
|
1277
|
+
# check that entries are valid
|
|
1278
|
+
if not set(labels.tolist()).issubset(
|
|
1279
|
+
set([b.digit for b in brain_state_set.brain_states] + [UNDEFINED_LABEL])
|
|
1280
|
+
):
|
|
1281
|
+
return "label file contains invalid entries"
|
|
1282
|
+
|
|
1283
|
+
|
|
1284
|
+
def check_config_consistency(
|
|
1285
|
+
current_brain_states: dict,
|
|
1286
|
+
model_brain_states: dict,
|
|
1287
|
+
current_epoch_length: int | float,
|
|
1288
|
+
model_epoch_length: int | float,
|
|
1289
|
+
) -> list[str]:
|
|
1290
|
+
"""Compare current brain state config to the model's config
|
|
1291
|
+
|
|
1292
|
+
This only displays warnings - the user should decide whether to proceed
|
|
1293
|
+
|
|
1294
|
+
:param current_brain_states: current brain state config
|
|
1295
|
+
:param model_brain_states: brain state config when the model was created
|
|
1296
|
+
:param current_epoch_length: current epoch length setting
|
|
1297
|
+
:param model_epoch_length: epoch length used when the model was created
|
|
1298
|
+
"""
|
|
1299
|
+
output = list()
|
|
1300
|
+
|
|
1301
|
+
# make lists of names and digits for scored brain states
|
|
1302
|
+
current_scored_states = {
|
|
1303
|
+
f: [b[f] for b in current_brain_states if b["is_scored"]]
|
|
1304
|
+
for f in ["name", "digit"]
|
|
1305
|
+
}
|
|
1306
|
+
model_scored_states = {
|
|
1307
|
+
f: [b[f] for b in model_brain_states if b["is_scored"]]
|
|
1308
|
+
for f in ["name", "digit"]
|
|
1309
|
+
}
|
|
1310
|
+
|
|
1311
|
+
# generate message comparing the brain state configs
|
|
1312
|
+
config_comparisons = list()
|
|
1313
|
+
for config, config_name in zip(
|
|
1314
|
+
[current_scored_states, model_scored_states], ["current", "model's"]
|
|
1315
|
+
):
|
|
1316
|
+
config_comparisons.append(
|
|
1317
|
+
f"Scored brain states in {config_name} configuration: "
|
|
1318
|
+
f"""{
|
|
1319
|
+
", ".join(
|
|
1320
|
+
[
|
|
1321
|
+
f"{x}: {y}"
|
|
1322
|
+
for x, y in zip(
|
|
1323
|
+
config["digit"],
|
|
1324
|
+
config["name"],
|
|
1325
|
+
)
|
|
1326
|
+
]
|
|
1327
|
+
)
|
|
1328
|
+
}"""
|
|
1329
|
+
)
|
|
1330
|
+
|
|
1331
|
+
# check if the number of scored states is different
|
|
1332
|
+
len_diff = len(current_scored_states["name"]) - len(model_scored_states["name"])
|
|
1333
|
+
if len_diff != 0:
|
|
1334
|
+
output.append(
|
|
1335
|
+
(
|
|
1336
|
+
"WARNING: current brain state configuration has "
|
|
1337
|
+
f"{'fewer' if len_diff < 0 else 'more'} "
|
|
1338
|
+
"scored brain states than the model's configuration."
|
|
1339
|
+
)
|
|
1340
|
+
)
|
|
1341
|
+
output = output + config_comparisons
|
|
1342
|
+
else:
|
|
1343
|
+
# the length is the same, but names might be different
|
|
1344
|
+
if current_scored_states["name"] != model_scored_states["name"]:
|
|
1345
|
+
output.append(
|
|
1346
|
+
(
|
|
1347
|
+
"WARNING: current brain state configuration appears "
|
|
1348
|
+
"to contain different brain states than "
|
|
1349
|
+
"the model's configuration."
|
|
1350
|
+
)
|
|
1351
|
+
)
|
|
1352
|
+
output = output + config_comparisons
|
|
1353
|
+
|
|
1354
|
+
if current_epoch_length != model_epoch_length:
|
|
1355
|
+
output.append(
|
|
1356
|
+
(
|
|
1357
|
+
"Warning: the epoch length used when training this model "
|
|
1358
|
+
"does not match the current epoch length setting."
|
|
1359
|
+
)
|
|
1360
|
+
)
|
|
1361
|
+
|
|
1362
|
+
return output
|
|
1363
|
+
|
|
1364
|
+
|
|
1365
|
+
def run_primary_window() -> None:
|
|
1366
|
+
app = QtWidgets.QApplication(sys.argv)
|
|
1367
|
+
AccuSleepWindow()
|
|
1368
|
+
sys.exit(app.exec())
|
|
1369
|
+
|
|
1370
|
+
|
|
1371
|
+
if __name__ == "__main__":
|
|
1372
|
+
run_primary_window()
|