accusleepy 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. accusleepy/__init__.py +0 -0
  2. accusleepy/__main__.py +4 -0
  3. accusleepy/bouts.py +142 -0
  4. accusleepy/brain_state_set.py +89 -0
  5. accusleepy/classification.py +285 -0
  6. accusleepy/config.json +24 -0
  7. accusleepy/constants.py +46 -0
  8. accusleepy/fileio.py +179 -0
  9. accusleepy/gui/__init__.py +0 -0
  10. accusleepy/gui/icons/brightness_down.png +0 -0
  11. accusleepy/gui/icons/brightness_up.png +0 -0
  12. accusleepy/gui/icons/double_down_arrow.png +0 -0
  13. accusleepy/gui/icons/double_up_arrow.png +0 -0
  14. accusleepy/gui/icons/down_arrow.png +0 -0
  15. accusleepy/gui/icons/home.png +0 -0
  16. accusleepy/gui/icons/question.png +0 -0
  17. accusleepy/gui/icons/save.png +0 -0
  18. accusleepy/gui/icons/up_arrow.png +0 -0
  19. accusleepy/gui/icons/zoom_in.png +0 -0
  20. accusleepy/gui/icons/zoom_out.png +0 -0
  21. accusleepy/gui/images/primary_window.png +0 -0
  22. accusleepy/gui/images/viewer_window.png +0 -0
  23. accusleepy/gui/images/viewer_window_annotated.png +0 -0
  24. accusleepy/gui/main.py +1494 -0
  25. accusleepy/gui/manual_scoring.py +1096 -0
  26. accusleepy/gui/mplwidget.py +386 -0
  27. accusleepy/gui/primary_window.py +2577 -0
  28. accusleepy/gui/primary_window.ui +3831 -0
  29. accusleepy/gui/resources.qrc +16 -0
  30. accusleepy/gui/resources_rc.py +6710 -0
  31. accusleepy/gui/text/config_guide.txt +27 -0
  32. accusleepy/gui/text/main_guide.md +167 -0
  33. accusleepy/gui/text/manual_scoring_guide.md +23 -0
  34. accusleepy/gui/viewer_window.py +610 -0
  35. accusleepy/gui/viewer_window.ui +926 -0
  36. accusleepy/models.py +108 -0
  37. accusleepy/multitaper.py +661 -0
  38. accusleepy/signal_processing.py +469 -0
  39. accusleepy/temperature_scaling.py +157 -0
  40. accusleepy-0.6.0.dist-info/METADATA +106 -0
  41. accusleepy-0.6.0.dist-info/RECORD +42 -0
  42. accusleepy-0.6.0.dist-info/WHEEL +4 -0
accusleepy/gui/main.py ADDED
@@ -0,0 +1,1494 @@
1
+ # AccuSleePy main window
2
+ # Icon source: Arkinasi, https://www.flaticon.com/authors/arkinasi
3
+
4
+ import datetime
5
+ import os
6
+ import shutil
7
+ import sys
8
+ from dataclasses import dataclass
9
+ from functools import partial
10
+
11
+ import numpy as np
12
+ import toml
13
+ from PySide6.QtCore import (
14
+ QEvent,
15
+ QKeyCombination,
16
+ QObject,
17
+ QRect,
18
+ Qt,
19
+ QUrl,
20
+ )
21
+ from PySide6.QtGui import QKeySequence, QShortcut
22
+ from PySide6.QtWidgets import (
23
+ QApplication,
24
+ QCheckBox,
25
+ QDoubleSpinBox,
26
+ QFileDialog,
27
+ QLabel,
28
+ QListWidgetItem,
29
+ QMainWindow,
30
+ QTextBrowser,
31
+ QVBoxLayout,
32
+ QWidget,
33
+ )
34
+
35
+ from accusleepy.bouts import enforce_min_bout_length
36
+ from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
37
+ from accusleepy.constants import (
38
+ ANNOTATIONS_FILENAME,
39
+ CALIBRATION_ANNOTATION_FILENAME,
40
+ CALIBRATION_FILE_TYPE,
41
+ DEFAULT_MODEL_TYPE,
42
+ LABEL_FILE_TYPE,
43
+ MODEL_FILE_TYPE,
44
+ REAL_TIME_MODEL_TYPE,
45
+ RECORDING_FILE_TYPES,
46
+ RECORDING_LIST_FILE_TYPE,
47
+ UNDEFINED_LABEL,
48
+ )
49
+ from accusleepy.fileio import (
50
+ Recording,
51
+ load_calibration_file,
52
+ load_config,
53
+ load_labels,
54
+ load_recording,
55
+ load_recording_list,
56
+ save_config,
57
+ save_labels,
58
+ save_recording_list,
59
+ )
60
+ from accusleepy.gui.manual_scoring import ManualScoringWindow
61
+ from accusleepy.gui.primary_window import Ui_PrimaryWindow
62
+ from accusleepy.signal_processing import (
63
+ create_training_images,
64
+ resample_and_standardize,
65
+ )
66
+
67
# NOTE: anything that depends on torch or scipy is imported lazily,
# inside the method that needs it, rather than at module import time.

# upper limit on how many messages are kept for display
MESSAGE_BOX_MAX_DEPTH = 200
# error message for a label file whose length disagrees with its recording
LABEL_LENGTH_ERROR = "label file length does not match recording length"
# relative paths to the configuration guide and the main user guide
CONFIG_GUIDE_FILE = os.path.normpath(os.path.join("text", "config_guide.txt"))
MAIN_GUIDE_FILE = os.path.normpath(os.path.join("text", "main_guide.md"))
75
+
76
+
77
@dataclass
class StateSettings:
    """Settings-tab widgets (and digit) that describe one brain state."""

    # digit identifying the brain state
    digit: int
    # checkbox for whether the brain state is enabled
    enabled_widget: QCheckBox
    # label that displays the brain state's name
    name_widget: QLabel
    # checkbox for whether the brain state is scored
    is_scored_widget: QCheckBox
    # spin box holding the brain state's frequency value
    frequency_widget: QDoubleSpinBox
86
+
87
+
88
+ class AccuSleepWindow(QMainWindow):
89
+ """AccuSleePy primary window"""
90
+
91
def __init__(self):
    """Create the AccuSleePy primary window and show it.

    Builds the generated UI, loads the saved configuration into the
    settings tab, initializes scoring and training options, creates one
    empty recording entry, displays the package version, connects widget
    signals to their handlers, installs drag-and-drop event filters on the
    file labels, and shows the window.
    """
    super(AccuSleepWindow, self).__init__()

    # initialize the UI
    self.ui = Ui_PrimaryWindow()
    self.ui.setupUi(self)
    self.setWindowTitle("AccuSleePy")

    # fill in settings tab
    self.brain_state_set, self.epoch_length, self.save_confidence_setting = (
        load_config()
    )
    self.settings_widgets = None
    self.initialize_settings_tab()

    # initialize info about the recordings, classification data / settings
    self.ui.epoch_length_input.setValue(self.epoch_length)
    self.ui.save_confidence_checkbox.setChecked(self.save_confidence_setting)
    self.model = None
    self.only_overwrite_undefined = False
    self.save_confidence_scores = self.save_confidence_setting
    self.min_bout_length = 5

    # initialize model training variables
    self.training_epochs_per_img = 9
    self.delete_training_images = True
    self.model_type = DEFAULT_MODEL_TYPE
    self.calibrate_trained_model = True

    # metadata for the currently loaded classification model
    self.model_epoch_length = None
    self.model_epochs_per_img = None

    # set up the list of recordings
    first_recording = Recording(
        widget=QListWidgetItem("Recording 1", self.ui.recording_list_widget),
    )
    self.ui.recording_list_widget.addItem(first_recording.widget)
    self.ui.recording_list_widget.setCurrentRow(0)
    # index of currently selected recording in the list
    self.recording_index = 0
    # list of recordings the user has added
    self.recordings = [first_recording]

    # messages to display
    self.messages = []

    # display current version
    # A source checkout has pyproject.toml one level above the package
    # directory. An installed wheel does not ship that file, so fall back
    # to the installed distribution's metadata.
    version = ""
    toml_file = os.path.join(
        os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        ),
        "pyproject.toml",
    )
    if os.path.isfile(toml_file):
        toml_data = toml.load(toml_file)
        if "project" in toml_data and "version" in toml_data["project"]:
            version = toml_data["project"]["version"]
    if not version:
        from importlib.metadata import PackageNotFoundError
        from importlib.metadata import version as package_version

        try:
            version = package_version("accusleepy")
        except PackageNotFoundError:
            version = ""
    self.ui.version_label.setText(f"v{version}")

    # user input: keyboard shortcuts
    keypress_quit = QShortcut(
        QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_W)),
        self,
    )
    keypress_quit.activated.connect(self.close)

    # user input: button presses
    self.ui.add_button.clicked.connect(self.add_recording)
    self.ui.remove_button.clicked.connect(self.remove_recording)
    self.ui.recording_list_widget.currentRowChanged.connect(self.select_recording)
    self.ui.sampling_rate_input.valueChanged.connect(self.update_sampling_rate)
    self.ui.epoch_length_input.valueChanged.connect(self.update_epoch_length)
    self.ui.recording_file_button.clicked.connect(self.select_recording_file)
    self.ui.select_label_button.clicked.connect(self.select_label_file)
    self.ui.create_label_button.clicked.connect(self.create_label_file)
    self.ui.manual_scoring_button.clicked.connect(self.manual_scoring)
    self.ui.create_calibration_button.clicked.connect(self.create_calibration_file)
    self.ui.select_calibration_button.clicked.connect(self.select_calibration_file)
    self.ui.load_model_button.clicked.connect(partial(self.load_model, None))
    self.ui.score_all_button.clicked.connect(self.score_all)
    self.ui.overwritecheckbox.stateChanged.connect(self.update_overwrite_policy)
    self.ui.save_confidence_checkbox.stateChanged.connect(
        self.update_confidence_policy
    )
    self.ui.bout_length_input.valueChanged.connect(self.update_min_bout_length)
    self.ui.user_manual_button.clicked.connect(self.show_user_manual)
    self.ui.image_number_input.valueChanged.connect(self.update_epochs_per_img)
    self.ui.delete_image_box.stateChanged.connect(self.update_image_deletion)
    self.ui.calibrate_checkbox.stateChanged.connect(
        self.update_training_calibration
    )
    self.ui.train_model_button.clicked.connect(self.train_model)
    self.ui.save_config_button.clicked.connect(self.save_brain_state_config)
    self.ui.export_button.clicked.connect(self.export_recording_list)
    self.ui.import_button.clicked.connect(self.import_recording_list)
    self.ui.default_type_button.toggled.connect(self.model_type_radio_buttons)

    # user input: drag and drop
    self.ui.recording_file_label.installEventFilter(self)
    self.ui.label_file_label.installEventFilter(self)
    self.ui.calibration_file_label.installEventFilter(self)
    self.ui.model_label.installEventFilter(self)

    self.show()
197
+
198
def model_type_radio_buttons(self, default_selected: bool) -> None:
    """Record which model type the training radio buttons select.

    :param default_selected: True when the "default" option is checked;
        otherwise the real-time model type is selected
    """
    self.model_type = (
        DEFAULT_MODEL_TYPE if default_selected else REAL_TIME_MODEL_TYPE
    )
207
+
208
def export_recording_list(self) -> None:
    """Ask for a destination and write the current recording list to it.

    Nothing happens if the user cancels the save dialog.
    """
    chosen_path, _ = QFileDialog.getSaveFileName(
        self,
        caption="Save list of recordings as",
        filter="*" + RECORDING_LIST_FILE_TYPE,
    )
    if chosen_path:
        chosen_path = os.path.normpath(chosen_path)
        save_recording_list(filename=chosen_path, recordings=self.recordings)
        self.show_message(f"Saved list of recordings to {chosen_path}")
221
+
222
def import_recording_list(self):
    """Replace the current recording list with one loaded from a file.

    Prompts for an existing recording-list file. If the user cancels,
    nothing changes. Otherwise the list widget is cleared, the recordings
    are loaded from the file, and one list entry is created per recording.
    The first row is then selected.
    """
    file_dialog = QFileDialog(self)
    file_dialog.setWindowTitle("Select list of recordings")
    file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
    file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
    file_dialog.setNameFilter("*" + RECORDING_LIST_FILE_TYPE)

    if not file_dialog.exec():
        return
    filename = os.path.normpath(file_dialog.selectedFiles()[0])

    # clear widget
    self.ui.recording_list_widget.clear()
    # overwrite current list
    self.recordings = load_recording_list(filename)

    for recording in self.recordings:
        recording.widget = QListWidgetItem(
            f"Recording {recording.name}", self.ui.recording_list_widget
        )
        # register *this* recording's item, mirroring how __init__ adds
        # the first recording (previously this referenced recordings[-1])
        self.ui.recording_list_widget.addItem(recording.widget)

    # display new list
    self.ui.recording_list_widget.setCurrentRow(0)
    self.show_message(f"Loaded list of recordings from {filename}")
251
+
252
def eventFilter(self, obj: QObject, event: QEvent) -> bool:
    """Handle files dragged and dropped onto the file/model labels.

    Every event aimed at one of the four drop-target labels is accepted.
    When a drop carries exactly one URL, the dropped file is assigned to
    the current recording (recording, label, or calibration file, when the
    extension matches). For the model label, it is loaded as the
    classification model. The event is then passed to the base class.

    :param obj: UI object receiving the event
    :param event: event delivered to ``obj``
    :return: the base-class ``eventFilter`` result
    """
    drop_targets = [
        self.ui.recording_file_label,
        self.ui.label_file_label,
        self.ui.calibration_file_label,
        self.ui.model_label,
    ]
    dropped_path = None
    if obj in drop_targets:
        # presumably accepted so drag-enter/move events let a drop follow
        event.accept()
        if event.type() == QEvent.Drop:
            dropped_urls = event.mimeData().urls()
            if len(dropped_urls) == 1:
                dropped_path = os.path.normpath(dropped_urls[0].toLocalFile())

    if dropped_path is not None:
        extension = os.path.splitext(dropped_path)[1]
        if obj == self.ui.recording_file_label:
            if extension in RECORDING_FILE_TYPES:
                self.recordings[self.recording_index].recording_file = dropped_path
                self.ui.recording_file_label.setText(dropped_path)
        elif obj == self.ui.label_file_label:
            if extension == LABEL_FILE_TYPE:
                self.recordings[self.recording_index].label_file = dropped_path
                self.ui.label_file_label.setText(dropped_path)
        elif obj == self.ui.calibration_file_label:
            if extension == CALIBRATION_FILE_TYPE:
                self.recordings[self.recording_index].calibration_file = dropped_path
                self.ui.calibration_file_label.setText(dropped_path)
        elif obj == self.ui.model_label:
            self.load_model(filename=dropped_path)

    return super().eventFilter(obj, event)
293
+
294
def train_model(self) -> None:
    """Train (and optionally calibrate) a new classification model.

    Validates the training settings and every recording's inputs, then asks
    for an output filename. Training images are written to a timestamped
    folder next to the model file, and the model is trained on them. If
    calibration is enabled, the model is wrapped in ModelWithTemperature
    and calibrated using the calibration annotation file produced with the
    chosen calibration fraction. The model is then saved, and the image
    folder is deleted if the user asked for that. Progress is reported in
    the message area and on the console.
    """
    # check basic training inputs
    if (
        self.model_type == DEFAULT_MODEL_TYPE
        and self.training_epochs_per_img % 2 == 0
    ):
        self.show_message(
            "ERROR: for the default model type, number of epochs "
            "per image must be an odd number."
        )
        return

    # determine fraction of training data to use for calibration
    if self.calibrate_trained_model:
        calibration_fraction = self.ui.calibration_spinbox.value() / 100
    else:
        calibration_fraction = 0

    # check some inputs for each recording
    for recording_index, recording in enumerate(self.recordings):
        error_message = self.check_single_file_inputs(recording_index)
        if error_message:
            self.show_message(
                f"ERROR (recording {recording.name}): {error_message}"
            )
            return

    # get filename for the new model
    model_filename, _ = QFileDialog.getSaveFileName(
        self,
        caption="Save classification model file as",
        filter="*" + MODEL_FILE_TYPE,
    )
    if not model_filename:
        self.show_message("Model training canceled, no filename given")
        return
    model_filename = os.path.normpath(model_filename)

    # create (probably temporary) image folder in
    # the same folder as the trained model
    temp_image_dir = os.path.join(
        os.path.dirname(model_filename),
        "images_" + datetime.datetime.now().strftime("%Y%m%d%H%M"),
    )

    if os.path.exists(temp_image_dir):  # unlikely
        self.show_message(
            "Warning: training image folder exists, will be overwritten"
        )
    os.makedirs(temp_image_dir, exist_ok=True)

    # create training images
    self.show_message("Training, please wait. See console for progress updates.")
    if not self.delete_training_images:
        self.show_message(f"Creating training images in {temp_image_dir}")
    else:
        self.show_message(
            f"Creating temporary folder of training images: {temp_image_dir}"
        )
    self.ui.message_area.repaint()
    QApplication.processEvents()
    print("Creating training images")
    failed_recordings = create_training_images(
        recordings=self.recordings,
        output_path=temp_image_dir,
        epoch_length=self.epoch_length,
        epochs_per_img=self.training_epochs_per_img,
        brain_state_set=self.brain_state_set,
        model_type=self.model_type,
        calibration_fraction=calibration_fraction,
    )
    if len(failed_recordings) > 0:
        if len(failed_recordings) == len(self.recordings):
            self.show_message("ERROR: no recordings were valid!")
            # nothing will be trained; honor the user's choice to delete
            # the image folder instead of leaving it behind
            if self.delete_training_images:
                shutil.rmtree(temp_image_dir, ignore_errors=True)
            return
        self.show_message(
            "WARNING: the following recordings could not be "
            "loaded and will not be used for training: "
            f"{', '.join([str(r) for r in failed_recordings])}"
        )

    # train model
    self.show_message("Training model")
    self.ui.message_area.repaint()
    QApplication.processEvents()
    print("Training model")
    # lazily imported, per the module-level note about torch/scipy
    from accusleepy.classification import create_dataloader, train_ssann
    from accusleepy.models import save_model
    from accusleepy.temperature_scaling import ModelWithTemperature

    model = train_ssann(
        annotations_file=os.path.join(temp_image_dir, ANNOTATIONS_FILENAME),
        img_dir=temp_image_dir,
        mixture_weights=self.brain_state_set.mixture_weights,
        n_classes=self.brain_state_set.n_classes,
    )

    # calibrate the model
    if self.calibrate_trained_model:
        calibration_annotation_file = os.path.join(
            temp_image_dir, CALIBRATION_ANNOTATION_FILENAME
        )
        calibration_dataloader = create_dataloader(
            annotations_file=calibration_annotation_file, img_dir=temp_image_dir
        )
        model = ModelWithTemperature(model)
        print("Calibrating model")
        model.set_temperature(calibration_dataloader)

    # save model
    save_model(
        model=model,
        filename=model_filename,
        epoch_length=self.epoch_length,
        epochs_per_img=self.training_epochs_per_img,
        model_type=self.model_type,
        brain_state_set=self.brain_state_set,
        is_calibrated=self.calibrate_trained_model,
    )

    # optionally delete images
    if self.delete_training_images:
        print("Cleaning up training image folder")
        shutil.rmtree(temp_image_dir)

    self.show_message(f"Training complete. Saved model to {model_filename}")
    print("Training complete.")
426
+
427
def update_image_deletion(self) -> None:
    """Mirror the "delete images after training" checkbox into state."""
    checkbox = self.ui.delete_image_box
    self.delete_training_images = checkbox.isChecked()
430
+
431
def update_training_calibration(self) -> None:
    """Mirror the "calibrate after training" checkbox into state.

    The calibration-fraction spin box is enabled only while calibration
    is turned on.
    """
    enabled = self.ui.calibrate_checkbox.isChecked()
    self.calibrate_trained_model = enabled
    self.ui.calibration_spinbox.setEnabled(enabled)
435
+
436
def update_epochs_per_img(self, new_value) -> None:
    """Store the number of epochs per training image chosen in the UI.

    :param new_value: how many epochs each training image should span
    """
    epochs = new_value
    self.training_epochs_per_img = epochs
442
+
443
def score_all(self) -> None:
    """Score every recording in the list with the loaded classification model.

    First validates the model, minimum bout length, epoch length, and each
    recording's inputs (including a calibration file). Then, for each
    recording: load and resample it, load any existing labels, load its
    calibration data, and classify it with score_recording. If only
    undefined epochs may be overwritten, existing non-undefined labels are
    kept. The labels are then passed through enforce_min_bout_length and
    saved (with confidence scores only if that option is on). Recordings
    whose data, labels, or calibration file cannot be loaded are skipped
    with an error message.
    """
    # check basic inputs
    if self.model is None:
        self.ui.score_all_status.setText("missing classification model")
        self.show_message("ERROR: no classification model file selected")
        return
    if self.min_bout_length < self.epoch_length:
        self.ui.score_all_status.setText("invalid minimum bout length")
        self.show_message("ERROR: minimum bout length must be >= epoch length")
        return
    if self.epoch_length != self.model_epoch_length:
        self.ui.score_all_status.setText("invalid epoch length")
        self.show_message(
            "ERROR: model was trained with an epoch length of "
            f"{self.model_epoch_length} seconds, but the current "
            f"epoch length setting is {self.epoch_length} seconds."
        )
        return

    self.ui.score_all_status.setText("running...")
    self.ui.score_all_status.repaint()
    QApplication.processEvents()

    # lazily imported, per the module-level note about torch/scipy
    from accusleepy.classification import score_recording

    # check some inputs for each recording
    for recording_index, recording in enumerate(self.recordings):
        error_message = self.check_single_file_inputs(recording_index)
        if error_message:
            self.ui.score_all_status.setText(
                f"error on recording {recording.name}"
            )
            self.show_message(
                f"ERROR (recording {recording.name}): {error_message}"
            )
            return
        if recording.calibration_file == "":
            self.ui.score_all_status.setText(
                f"error on recording {recording.name}"
            )
            self.show_message(
                f"ERROR (recording {recording.name}): "
                "no calibration file selected"
            )
            return

    # score each recording
    for recording in self.recordings:
        # load EEG, EMG
        try:
            eeg, emg = load_recording(recording.recording_file)
            eeg, emg, sampling_rate = resample_and_standardize(
                eeg=eeg,
                emg=emg,
                sampling_rate=recording.sampling_rate,
                epoch_length=self.epoch_length,
            )
        except Exception:
            self.show_message(
                "ERROR: could not load recording "
                f"{recording.name}. "
                "This recording will be skipped."
            )
            continue

        # load labels
        label_file = recording.label_file
        if os.path.isfile(label_file):
            try:
                # ignore any existing confidence scores; they will all be overwritten
                existing_labels, _ = load_labels(label_file)
            except Exception:
                self.show_message(
                    "ERROR: could not load existing labels for recording "
                    f"{recording.name}. "
                    "This recording will be skipped."
                )
                continue
            # only check the length
            samples_per_epoch = sampling_rate * self.epoch_length
            epochs_in_recording = round(eeg.size / samples_per_epoch)
            if epochs_in_recording != existing_labels.size:
                self.show_message(
                    "ERROR: existing labels for recording "
                    f"{recording.name} "
                    "do not match the recording length. "
                    "This recording will be skipped."
                )
                continue
        else:
            existing_labels = None

        # load calibration data
        calibration_file = recording.calibration_file
        if not os.path.isfile(calibration_file):
            self.show_message(
                "ERROR: calibration file does not exist for recording "
                f"{recording.name}. "
                "This recording will be skipped."
            )
            continue
        try:
            mixture_means, mixture_sds = load_calibration_file(calibration_file)
        except Exception:
            self.show_message(
                "ERROR: could not load calibration file for recording "
                f"{recording.name}. "
                "This recording will be skipped."
            )
            continue

        labels, confidence_scores = score_recording(
            model=self.model,
            eeg=eeg,
            emg=emg,
            mixture_means=mixture_means,
            mixture_sds=mixture_sds,
            sampling_rate=sampling_rate,
            epoch_length=self.epoch_length,
            epochs_per_img=self.model_epochs_per_img,
            brain_state_set=self.brain_state_set,
        )

        # overwrite as needed: keep every existing label that is defined
        if existing_labels is not None and self.only_overwrite_undefined:
            already_scored = existing_labels != UNDEFINED_LABEL
            labels[already_scored] = existing_labels[already_scored]

        # enforce minimum bout length
        labels = enforce_min_bout_length(
            labels=labels,
            epoch_length=self.epoch_length,
            min_bout_length=self.min_bout_length,
        )

        # ignore confidence scores if desired
        if not self.save_confidence_scores:
            confidence_scores = None

        # save results
        save_labels(
            labels=labels, filename=label_file, confidence_scores=confidence_scores
        )
        self.show_message(
            "Saved labels for recording "
            f"{recording.name} "
            f"to {label_file}"
        )

    self.ui.score_all_status.setText("")
619
+
620
def load_model(self, filename=None) -> None:
    """Load a trained classification model and make it the active model.

    Only "default"-type models are accepted. On success the model and its
    epoch length / epochs-per-image are stored. Any mismatch between the
    model's settings and the current configuration is reported; otherwise
    a success message is shown.

    :param filename: path to the model file; when None, a file dialog asks
        the user to choose one (cancelling it leaves the current model)
    """
    if filename is None:
        picker = QFileDialog(self)
        picker.setWindowTitle("Select classification model")
        picker.setFileMode(QFileDialog.FileMode.ExistingFile)
        picker.setViewMode(QFileDialog.ViewMode.Detail)
        picker.setNameFilter("*" + MODEL_FILE_TYPE)
        if not picker.exec():
            return
        filename = os.path.normpath(picker.selectedFiles()[0])

    if not os.path.isfile(filename):
        self.show_message("ERROR: model file does not exist")
        return

    self.show_message("Loading classification model")
    self.ui.message_area.repaint()
    QApplication.processEvents()

    # lazily imported; aliased so it is not confused with this method
    from accusleepy.models import load_model as read_model_file

    try:
        (
            model,
            epoch_length,
            epochs_per_img,
            model_type,
            brain_states,
        ) = read_model_file(filename=filename)
    except Exception:
        self.show_message(
            "ERROR: could not load classification model. Check "
            "user manual for instructions on creating this file."
        )
        return

    # make sure only "default" model type is loaded
    if model_type != DEFAULT_MODEL_TYPE:
        self.show_message(
            "ERROR: only 'default'-style models can be used. "
            "'Real-time' models are not supported. "
            "See classification.example_real_time_scoring_function.py "
            "for an example of how to classify brain states in real time."
        )
        return

    self.model = model
    self.model_epoch_length = epoch_length
    self.model_epochs_per_img = epochs_per_img

    # warn user if the model's expected epoch length or brain states
    # don't match the current configuration
    current_states = self.brain_state_set.to_output_dict()[BRAIN_STATES_KEY]
    config_warnings = check_config_consistency(
        current_brain_states=current_states,
        model_brain_states=brain_states,
        current_epoch_length=self.epoch_length,
        model_epoch_length=epoch_length,
    )
    if len(config_warnings) == 0:
        self.show_message(f"Loaded classification model from {filename}")
    else:
        for warning_text in config_warnings:
            self.show_message(warning_text)

    self.ui.model_label.setText(filename)
695
+
696
def load_single_recording(
    self, status_widget: QLabel
) -> tuple[np.ndarray | None, np.ndarray | None, int | float | None, bool]:
    """Load and preprocess the currently selected recording.

    This loads one recording, resamples it, and standardizes its length.
    If an error occurs during this process, it is displayed in the
    indicated widget and in the message area, and
    ``(None, None, None, False)`` is returned.

    :param status_widget: UI element on which to display error messages
    :return: EEG data, EMG data, sampling rate, process completion
    """
    error_message = self.check_single_file_inputs(self.recording_index)
    if error_message:
        status_widget.setText(error_message)
        self.show_message(f"ERROR: {error_message}")
        return None, None, None, False

    recording = self.recordings[self.recording_index]
    try:
        eeg, emg = load_recording(recording.recording_file)
        # resampling is inside the try too, matching score_all, so a
        # failure is reported in the UI instead of escaping the slot
        eeg, emg, sampling_rate = resample_and_standardize(
            eeg=eeg,
            emg=emg,
            sampling_rate=recording.sampling_rate,
            epoch_length=self.epoch_length,
        )
    except Exception:
        status_widget.setText("could not load recording")
        self.show_message(
            "ERROR: could not load recording. "
            "Check user manual for formatting instructions."
        )
        return None, None, None, False

    return eeg, emg, sampling_rate, True
738
+
739
def create_calibration_file(self) -> None:
    """Build a calibration file from the selected recording and its labels.

    Loads and preprocesses the current recording, loads its label file,
    and checks that the labels are valid. It then asks where to save the
    calibration file, writes it, and assigns it to the current recording
    (also showing it in the calibration file label). Any failure is
    reported in the calibration status widget / message area and stops the
    process. Cancelling the save dialog also stops it.
    """
    status = self.ui.calibration_status

    # load the recording
    eeg, emg, sampling_rate, success = self.load_single_recording(status)
    if not success:
        return

    # load the labels
    label_file = self.recordings[self.recording_index].label_file
    if not os.path.isfile(label_file):
        status.setText("label file does not exist")
        self.show_message("ERROR: label file does not exist")
        return
    try:
        labels, _ = load_labels(label_file)
    except Exception:
        status.setText("could not load labels")
        self.show_message(
            "ERROR: could not load labels. "
            "Check user manual for formatting instructions."
        )
        return

    problem = check_label_validity(
        labels=labels,
        confidence_scores=None,
        samples_in_recording=eeg.size,
        sampling_rate=sampling_rate,
        epoch_length=self.epoch_length,
        brain_state_set=self.brain_state_set,
    )
    if problem:
        status.setText("invalid label file")
        self.show_message(f"ERROR: {problem}")
        return

    # ask where the calibration file should go
    output_path, _ = QFileDialog.getSaveFileName(
        self,
        caption="Save calibration file as",
        filter="*" + CALIBRATION_FILE_TYPE,
    )
    if not output_path:
        return
    output_path = os.path.normpath(output_path)

    from accusleepy.classification import create_calibration_file

    create_calibration_file(
        filename=output_path,
        eeg=eeg,
        emg=emg,
        labels=labels,
        sampling_rate=sampling_rate,
        epoch_length=self.epoch_length,
        brain_state_set=self.brain_state_set,
    )

    status.setText("")
    self.show_message(
        "Created calibration file using recording "
        f"{self.recordings[self.recording_index].name} "
        f"at {output_path}"
    )

    self.recordings[self.recording_index].calibration_file = output_path
    self.ui.calibration_file_label.setText(output_path)
817
+
818
+ def check_single_file_inputs(self, recording_index: int) -> str | None:
819
+ """Check that a recording's inputs appear valid
820
+
821
+ This runs some basic tests for whether it will be possible to
822
+ load and score a recording. If any test fails, we return an
823
+ error message.
824
+
825
+ :param recording_index: index of the recording in the list of
826
+ all recordings.
827
+ :return: error message
828
+ """
829
+ sampling_rate = self.recordings[recording_index].sampling_rate
830
+ if self.epoch_length == 0:
831
+ return "epoch length can't be 0"
832
+ if sampling_rate == 0:
833
+ return "sampling rate can't be 0"
834
+ if self.epoch_length > sampling_rate:
835
+ return "invalid epoch length or sampling rate"
836
+ if self.recordings[self.recording_index].recording_file == "":
837
+ return "no recording selected"
838
+ if not os.path.isfile(self.recordings[self.recording_index].recording_file):
839
+ return "recording file does not exist"
840
+ if self.recordings[self.recording_index].label_file == "":
841
+ return "no label file selected"
842
+
843
def update_min_bout_length(self, new_value) -> None:
    """Store the minimum bout length chosen in the UI.

    :param new_value: shortest allowed bout, in seconds
    """
    seconds = new_value
    self.min_bout_length = seconds
849
+
850
def update_overwrite_policy(self, checked) -> None:
    """Set whether automatic scoring may only replace undefined epochs.

    When this is on and a label file already exists, score_all keeps every
    existing label that is not UNDEFINED_LABEL.

    :param checked: checkbox state as emitted by the widget (stored as-is)
    """
    new_policy = checked
    self.only_overwrite_undefined = new_policy
859
+
860
def update_confidence_policy(self, checked) -> None:
    """Set whether score_all saves confidence scores to the label files.

    When this is off, score_all discards the confidence scores before
    saving labels.

    :param checked: checkbox state as emitted by the widget (stored as-is)
    """
    new_policy = checked
    self.save_confidence_scores = new_policy
868
+
869
def manual_scoring(self) -> None:
    """View the selected recording for manual scoring

    Loads the selected recording and its labels, or creates a blank set of
    undefined labels if the label file does not exist. A label file whose
    length is off by exactly one epoch is padded or truncated (with a
    warning); any other problem with the labels is reported and the viewer
    is not opened.
    """
    # immediately display a status message
    self.ui.manual_scoring_status.setText("loading...")
    self.ui.manual_scoring_status.repaint()
    QApplication.processEvents()

    # load the recording
    eeg, emg, sampling_rate, success = self.load_single_recording(
        self.ui.manual_scoring_status
    )
    if not success:
        return

    # number of epochs in the recording, computed exactly as in
    # check_label_validity so that a blank set of labels always has the
    # length that check expects
    samples_per_epoch = round(sampling_rate * self.epoch_length)
    if samples_per_epoch < 1:
        self.ui.manual_scoring_status.setText("invalid epoch length")
        self.show_message("ERROR: invalid epoch length or sampling rate")
        return
    epochs_in_recording = round(eeg.size / samples_per_epoch)

    # if the labels exist, load them
    # otherwise, create a blank set of labels
    label_file = self.recordings[self.recording_index].label_file
    if os.path.isfile(label_file):
        try:
            labels, confidence_scores = load_labels(label_file)
        except Exception:
            self.ui.manual_scoring_status.setText("could not load labels")
            self.show_message(
                (
                    "ERROR: could not load labels. "
                    "Check user manual for formatting instructions."
                )
            )
            return
    else:
        labels = np.full(epochs_in_recording, UNDEFINED_LABEL, dtype=int)
        # manual scoring will not add a new confidence score column
        # to a label file that does not have one
        confidence_scores = None

    def validate(candidate_labels, candidate_scores):
        """Run check_label_validity against the loaded recording"""
        return check_label_validity(
            labels=candidate_labels,
            confidence_scores=candidate_scores,
            samples_in_recording=eeg.size,
            sampling_rate=sampling_rate,
            epoch_length=self.epoch_length,
            brain_state_set=self.brain_state_set,
        )

    # check that all labels are valid
    label_error = validate(labels, confidence_scores)
    if label_error == LABEL_LENGTH_ERROR:
        # if the label length is only off by one, pad or truncate as needed
        # and show a warning
        length_warning = None
        if epochs_in_recording - labels.size == 1:
            labels = np.concatenate((labels, np.array([UNDEFINED_LABEL])))
            if confidence_scores is not None:
                confidence_scores = np.concatenate(
                    (confidence_scores, np.array([0]))
                )
            length_warning = (
                "WARNING: an undefined epoch was added to "
                "the label file to correct its length."
            )
        elif labels.size - epochs_in_recording == 1:
            labels = labels[:-1]
            if confidence_scores is not None:
                confidence_scores = confidence_scores[:-1]
            length_warning = (
                "WARNING: the last epoch was removed from "
                "the label file to correct its length."
            )
        if length_warning is not None:
            # check_label_validity returns the length error before checking
            # the label values and confidence scores, so run the remaining
            # checks now that the length has been corrected
            label_error = validate(labels, confidence_scores)
            if not label_error:
                self.show_message(length_warning)

    if label_error:
        self.ui.manual_scoring_status.setText("invalid label file")
        self.show_message(f"ERROR: {label_error}")
        return

    self.show_message(
        f"Viewing recording {self.recordings[self.recording_index].name}"
    )
    self.ui.manual_scoring_status.setText("file is open")

    # launch the manual scoring window
    manual_scoring_window = ManualScoringWindow(
        eeg=eeg,
        emg=emg,
        label_file=label_file,
        labels=labels,
        confidence_scores=confidence_scores,
        sampling_rate=sampling_rate,
        epoch_length=self.epoch_length,
    )
    manual_scoring_window.setWindowTitle(f"AccuSleePy viewer: {label_file}")
    manual_scoring_window.exec()
    self.ui.manual_scoring_status.setText("")
972
+
973
def create_label_file(self) -> None:
    """Ask the user for the path of a new label file

    Nothing is created or overwritten on disk here; the chosen path is only
    stored on the selected recording and shown in the GUI.
    """
    chosen_path, _ = QFileDialog.getSaveFileName(
        self,
        caption="Set filename for label file (nothing will be overwritten yet)",
        filter="*" + LABEL_FILE_TYPE,
    )
    # an empty string means the dialog was cancelled
    if not chosen_path:
        return
    chosen_path = os.path.normpath(chosen_path)
    self.recordings[self.recording_index].label_file = chosen_path
    self.ui.label_file_label.setText(chosen_path)
984
+
985
def select_label_file(self) -> None:
    """Let the user choose an existing label file for the selected recording"""
    dialog = QFileDialog(self)
    dialog.setWindowTitle("Select label file")
    dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
    dialog.setViewMode(QFileDialog.ViewMode.Detail)
    dialog.setNameFilter("*" + LABEL_FILE_TYPE)

    if not dialog.exec():
        return
    # ExistingFile mode allows a single file, so take the first entry
    chosen_path = os.path.normpath(dialog.selectedFiles()[0])
    self.recordings[self.recording_index].label_file = chosen_path
    self.ui.label_file_label.setText(chosen_path)
999
+
1000
def select_calibration_file(self) -> None:
    """Let the user choose a calibration file for the selected recording"""
    dialog = QFileDialog(self)
    dialog.setWindowTitle("Select calibration file")
    dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
    dialog.setViewMode(QFileDialog.ViewMode.Detail)
    dialog.setNameFilter("*" + CALIBRATION_FILE_TYPE)

    if not dialog.exec():
        return
    # ExistingFile mode allows a single file, so take the first entry
    chosen_path = os.path.normpath(dialog.selectedFiles()[0])
    self.recordings[self.recording_index].calibration_file = chosen_path
    self.ui.calibration_file_label.setText(chosen_path)
1014
+
1015
def select_recording_file(self) -> None:
    """Let the user choose the data file for the selected recording"""
    dialog = QFileDialog(self)
    dialog.setWindowTitle("Select recording file")
    dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
    dialog.setViewMode(QFileDialog.ViewMode.Detail)
    # one filter entry listing every supported extension, e.g. "(*.a *.b)"
    dialog.setNameFilter("(*" + " *".join(RECORDING_FILE_TYPES) + ")")

    if not dialog.exec():
        return
    # ExistingFile mode allows a single file, so take the first entry
    chosen_path = os.path.normpath(dialog.selectedFiles()[0])
    self.recordings[self.recording_index].recording_file = chosen_path
    self.ui.recording_file_label.setText(chosen_path)
1029
+
1030
def show_recording_info(self) -> None:
    """Refresh the widgets that describe the currently selected recording"""
    current = self.recordings[self.recording_index]
    self.ui.sampling_rate_input.setValue(current.sampling_rate)
    for label_widget, text in (
        (self.ui.recording_file_label, current.recording_file),
        (self.ui.label_file_label, current.label_file),
        (self.ui.calibration_file_label, current.calibration_file),
    ):
        label_widget.setText(text)
1044
+
1045
+ def update_epoch_length(self, new_value: int | float) -> None:
1046
+ """Update the epoch length when the widget state changes
1047
+
1048
+ :param new_value: new epoch length
1049
+ """
1050
+ self.epoch_length = new_value
1051
+
1052
+ def update_sampling_rate(self, new_value: int | float) -> None:
1053
+ """Update recording's sampling rate when the widget state changes
1054
+
1055
+ :param new_value: new sampling rate
1056
+ """
1057
+ self.recordings[self.recording_index].sampling_rate = new_value
1058
+
1059
def show_message(self, message: str) -> None:
    """Add a message to the message area and scroll so it is visible

    :param message: text to show
    """
    history = self.messages
    history.append(message)
    # discard the oldest entry once the history is longer than allowed
    if len(history) > MESSAGE_BOX_MAX_DEPTH:
        history.pop(0)
    area = self.ui.message_area
    area.setText("\n".join(history))
    # jump to the end so the newest message is shown
    bar = area.verticalScrollBar()
    bar.setValue(bar.maximum())
1071
+
1072
def select_recording(self, list_index: int) -> None:
    """React to the user choosing a recording in the list widget

    :param list_index: row of the chosen recording in the list widget
    """
    self.recording_index = list_index
    # refresh the per-recording widgets, then the group box title
    self.show_recording_info()
    chosen_name = self.recordings[list_index].name
    self.ui.selected_recording_groupbox.setTitle(
        f"Data / actions for Recording {chosen_name}"
    )
1084
+
1085
def add_recording(self) -> None:
    """Add a new recording to the list and select it

    The new recording's name is one more than the largest existing name,
    and it inherits the sampling rate of the currently selected recording.
    """
    # find name to use for the new recording
    new_name = max(r.name for r in self.recordings) + 1

    # create the list item without a parent: passing the list widget as the
    # parent would already insert the item, and the addItem call below would
    # then add the same item a second time (undefined behavior in Qt)
    new_item = QListWidgetItem(f"Recording {new_name}")

    # add new recording to list
    self.recordings.append(
        Recording(
            name=new_name,
            sampling_rate=self.recordings[self.recording_index].sampling_rate,
            widget=new_item,
        )
    )

    # display new list and select the new entry
    self.ui.recording_list_widget.addItem(new_item)
    self.ui.recording_list_widget.setCurrentRow(len(self.recordings) - 1)
    self.show_message(f"added Recording {new_name}")
1105
+
1106
def remove_recording(self) -> None:
    """Delete the selected recording from the list

    Nothing happens if only one recording is left or no row is selected.
    """
    if len(self.recordings) <= 1:
        return
    current_list_index = self.ui.recording_list_widget.currentRow()
    if current_list_index < 0:
        # no row is selected; indexing with -1 would remove the wrong entry
        return
    removed_name = self.recordings[current_list_index].name
    # update the list of recordings before the widget: removing the row can
    # change the widget's current row, and anything reacting to that change
    # should see a recordings list that lines up with the widget's rows
    del self.recordings[current_list_index]
    _ = self.ui.recording_list_widget.takeItem(current_list_index)
    self.show_message(f"deleted Recording {removed_name}")
    self.recording_index = self.ui.recording_list_widget.currentRow()
1116
+
1117
def show_user_manual(self) -> None:
    """Open a separate window that displays the user manual"""
    # the widgets are stored on the instance as well as in local names
    self.popup = window = QWidget()
    self.popup_vlayout = layout = QVBoxLayout(window)
    self.guide_textbox = browser = QTextBrowser(window)
    layout.addWidget(browser)

    guide_url = QUrl.fromLocalFile(MAIN_GUIDE_FILE)
    browser.setSource(guide_url)
    # don't let the browser follow links automatically when clicked
    browser.setOpenLinks(False)

    window.setGeometry(QRect(100, 100, 600, 600))
    window.show()
1130
+
1131
def initialize_settings_tab(self):
    """Populate settings tab and assign its callbacks"""
    # show information about the settings tab
    config_guide_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), CONFIG_GUIDE_FILE
    )
    # a context manager closes the file even if reading fails
    with open(config_guide_path, "r", encoding="utf-8") as config_guide_file:
        config_guide_text = config_guide_file.read()
    self.ui.settings_text.setText(config_guide_text)

    # store dictionary that maps digits to rows of widgets in the settings
    # tab. The widgets for digit d are the ui attributes enable_state_d,
    # state_name_d, state_scored_d and state_frequency_d.
    self.settings_widgets = {
        digit: StateSettings(
            digit=digit,
            enabled_widget=getattr(self.ui, f"enable_state_{digit}"),
            name_widget=getattr(self.ui, f"state_name_{digit}"),
            is_scored_widget=getattr(self.ui, f"state_scored_{digit}"),
            frequency_widget=getattr(self.ui, f"state_frequency_{digit}"),
        )
        # keys in the order 1, 2, ..., 9, 0
        for digit in (*range(1, 10), 0)
    }

    # update widget state to display current config
    self.ui.default_epoch_input.setValue(self.epoch_length)
    self.ui.confidence_setting_checkbox.setChecked(self.save_confidence_setting)
    states = {b.digit: b for b in self.brain_state_set.brain_states}
    for digit in range(10):
        widgets = self.settings_widgets[digit]
        if digit in states:
            state = states[digit]
            widgets.enabled_widget.setChecked(True)
            widgets.name_widget.setText(state.name)
            widgets.is_scored_widget.setChecked(state.is_scored)
            widgets.frequency_widget.setValue(state.frequency)
            # same rules as set_brain_state_enabled / is_scored_changed:
            # an enabled state's widgets are editable, but the frequency
            # only for scored states
            widgets.name_widget.setEnabled(True)
            widgets.is_scored_widget.setEnabled(True)
            widgets.frequency_widget.setEnabled(state.is_scored)
        else:
            widgets.enabled_widget.setChecked(False)
            widgets.name_widget.setEnabled(False)
            widgets.is_scored_widget.setEnabled(False)
            widgets.frequency_widget.setEnabled(False)

    # set callbacks (connected after the initial values are set)
    for digit in range(10):
        state = self.settings_widgets[digit]
        state.enabled_widget.stateChanged.connect(
            partial(self.set_brain_state_enabled, digit)
        )
        state.name_widget.editingFinished.connect(self.finished_editing_state_name)
        state.is_scored_widget.stateChanged.connect(
            partial(self.is_scored_changed, digit)
        )
        state.frequency_widget.valueChanged.connect(self.state_frequency_changed)
1248
+
1249
def set_brain_state_enabled(self, digit, e) -> None:
    """Respond to the user toggling a brain state's "enabled" checkbox

    :param digit: digit of the brain state whose checkbox changed
    :param e: ignored; passed along by the stateChanged signal
    """
    row = self.settings_widgets[digit]
    enabled = row.enabled_widget.isChecked()
    row.name_widget.setEnabled(enabled)
    row.is_scored_widget.setEnabled(enabled)
    # the frequency is editable only for enabled, scored states
    row.frequency_widget.setEnabled(enabled and row.is_scored_widget.isChecked())
    if not enabled:
        # clear the name and frequency of a disabled state
        row.name_widget.setText("")
        row.frequency_widget.setValue(0)
    # re-validate the whole configuration
    self.check_config_validity()
1272
+
1273
def finished_editing_state_name(self) -> None:
    """Re-validate the configuration after a brain state's name is edited"""
    self.check_config_validity()
1276
+
1277
def state_frequency_changed(self, new_value) -> None:
    """Re-validate the configuration after a brain state's frequency changes

    :param new_value: ignored; passed along by the valueChanged signal
    """
    self.check_config_validity()
1283
+
1284
def is_scored_changed(self, digit, e) -> None:
    """Respond to the user toggling whether a brain state is scored

    :param digit: digit of the brain state whose checkbox changed
    :param e: ignored; passed along by the stateChanged signal
    """
    row = self.settings_widgets[digit]
    scored = row.is_scored_widget.isChecked()
    row.frequency_widget.setEnabled(scored)
    if not scored:
        # an unscored state gets a frequency of 0
        row.frequency_widget.setValue(0)
    # re-validate the whole configuration
    self.check_config_validity()
1299
+
1300
+ def check_config_validity(self) -> str:
1301
+ """Check if brain state configuration on screen is valid"""
1302
+ # error message, if we get one
1303
+ message = None
1304
+
1305
+ # strip whitespace from brain state names and update display
1306
+ for digit in range(10):
1307
+ state = self.settings_widgets[digit]
1308
+ current_name = state.name_widget.text()
1309
+ formatted_name = current_name.strip()
1310
+ if current_name != formatted_name:
1311
+ state.name_widget.setText(formatted_name)
1312
+
1313
+ # check if names are unique and frequencies add up to 1
1314
+ names = []
1315
+ frequencies = []
1316
+ for digit in range(10):
1317
+ state = self.settings_widgets[digit]
1318
+ if state.enabled_widget.isChecked():
1319
+ names.append(state.name_widget.text())
1320
+ frequencies.append(state.frequency_widget.value())
1321
+ if len(names) != len(set(names)):
1322
+ message = "Error: names must be unique"
1323
+ if sum(frequencies) != 1:
1324
+ message = "Error: sum(frequencies) != 1"
1325
+
1326
+ if message is not None:
1327
+ self.ui.save_config_status.setText(message)
1328
+ self.ui.save_config_button.setEnabled(False)
1329
+ return message
1330
+
1331
+ self.ui.save_config_button.setEnabled(True)
1332
+ self.ui.save_config_status.setText("")
1333
+
1334
def save_brain_state_config(self):
    """Write the configuration shown in the settings tab to file

    Nothing is saved if check_config_validity reports a problem.
    """
    # bail out if the configuration is invalid
    if self.check_config_validity() is not None:
        return

    # one BrainState for every enabled row of the settings tab
    rows = ((digit, self.settings_widgets[digit]) for digit in range(10))
    brain_states = [
        BrainState(
            name=row.name_widget.text(),
            digit=digit,
            is_scored=row.is_scored_widget.isChecked(),
            frequency=row.frequency_widget.value(),
        )
        for digit, row in rows
        if row.enabled_widget.isChecked()
    ]
    self.brain_state_set = BrainStateSet(brain_states, UNDEFINED_LABEL)

    # persist the brain states together with the other settings-tab values
    save_config(
        self.brain_state_set,
        self.ui.default_epoch_input.value(),
        self.ui.confidence_setting_checkbox.isChecked(),
    )
    self.ui.save_config_status.setText("configuration saved")
1363
+
1364
+
1365
def check_label_validity(
    labels: np.ndarray,
    confidence_scores: np.ndarray | None,
    samples_in_recording: int,
    sampling_rate: int | float,
    epoch_length: int | float,
    brain_state_set: BrainStateSet,
) -> str | None:
    """Check whether a set of brain state labels is valid

    This returns an error message if a problem is found with the
    brain state labels.

    :param labels: brain state labels, one per epoch
    :param confidence_scores: confidence scores, or None if there are none
    :param samples_in_recording: number of samples in the recording
    :param sampling_rate: sampling rate, in Hz
    :param epoch_length: epoch length, in seconds
    :param brain_state_set: BrainStateSet object
    :return: error message, or None if no problem was found
    """
    # check that number of labels is correct
    samples_per_epoch = round(sampling_rate * epoch_length)
    if samples_per_epoch < 1:
        # would otherwise divide by zero below
        return "invalid epoch length or sampling rate"
    epochs_in_recording = round(samples_in_recording / samples_per_epoch)
    if epochs_in_recording != labels.size:
        return LABEL_LENGTH_ERROR

    # check that entries are valid
    valid_labels = {b.digit for b in brain_state_set.brain_states}
    valid_labels.add(UNDEFINED_LABEL)
    if not set(labels.tolist()).issubset(valid_labels):
        return "label file contains invalid entries"

    if confidence_scores is not None:
        scores = np.asarray(confidence_scores)
        # written as "not all within [0, 1]" so that NaN values, which fail
        # every comparison, are reported; also safe for an empty array,
        # where np.min / np.max would raise
        if not np.all((scores >= 0) & (scores <= 1)):
            return "label file contains invalid confidence scores"

    return None
1403
+
1404
+
1405
+ def check_config_consistency(
1406
+ current_brain_states: dict,
1407
+ model_brain_states: dict,
1408
+ current_epoch_length: int | float,
1409
+ model_epoch_length: int | float,
1410
+ ) -> list[str]:
1411
+ """Compare current brain state config to the model's config
1412
+
1413
+ This only displays warnings - the user should decide whether to proceed
1414
+
1415
+ :param current_brain_states: current brain state config
1416
+ :param model_brain_states: brain state config when the model was created
1417
+ :param current_epoch_length: current epoch length setting
1418
+ :param model_epoch_length: epoch length used when the model was created
1419
+ """
1420
+ output = list()
1421
+
1422
+ # make lists of names and digits for scored brain states
1423
+ current_scored_states = {
1424
+ f: [b[f] for b in current_brain_states if b["is_scored"]]
1425
+ for f in ["name", "digit"]
1426
+ }
1427
+ model_scored_states = {
1428
+ f: [b[f] for b in model_brain_states if b["is_scored"]]
1429
+ for f in ["name", "digit"]
1430
+ }
1431
+
1432
+ # generate message comparing the brain state configs
1433
+ config_comparisons = list()
1434
+ for config, config_name in zip(
1435
+ [current_scored_states, model_scored_states], ["current", "model's"]
1436
+ ):
1437
+ config_comparisons.append(
1438
+ f"Scored brain states in {config_name} configuration: "
1439
+ f"""{
1440
+ ", ".join(
1441
+ [
1442
+ f"{x}: {y}"
1443
+ for x, y in zip(
1444
+ config["digit"],
1445
+ config["name"],
1446
+ )
1447
+ ]
1448
+ )
1449
+ }"""
1450
+ )
1451
+
1452
+ # check if the number of scored states is different
1453
+ len_diff = len(current_scored_states["name"]) - len(model_scored_states["name"])
1454
+ if len_diff != 0:
1455
+ output.append(
1456
+ (
1457
+ "WARNING: current brain state configuration has "
1458
+ f"{'fewer' if len_diff < 0 else 'more'} "
1459
+ "scored brain states than the model's configuration."
1460
+ )
1461
+ )
1462
+ output = output + config_comparisons
1463
+ else:
1464
+ # the length is the same, but names might be different
1465
+ if current_scored_states["name"] != model_scored_states["name"]:
1466
+ output.append(
1467
+ (
1468
+ "WARNING: current brain state configuration appears "
1469
+ "to contain different brain states than "
1470
+ "the model's configuration."
1471
+ )
1472
+ )
1473
+ output = output + config_comparisons
1474
+
1475
+ if current_epoch_length != model_epoch_length:
1476
+ output.append(
1477
+ (
1478
+ "Warning: the epoch length used when training this model "
1479
+ f"({model_epoch_length} seconds) "
1480
+ "does not match the current epoch length setting."
1481
+ )
1482
+ )
1483
+
1484
+ return output
1485
+
1486
+
1487
def run_primary_window() -> None:
    """Start the AccuSleePy GUI and run it until the user quits

    This calls sys.exit with the Qt event loop's return code, so it does
    not return normally.
    """
    app = QApplication(sys.argv)
    # hold a reference to the window while the event loop runs so the
    # Python wrapper (and with it the window) can't be garbage collected
    _window = AccuSleepWindow()
    sys.exit(app.exec())
1491
+
1492
+
1493
if __name__ == "__main__":
    # start the GUI when this file is executed as a script
    run_primary_window()