accusleepy 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
accusleepy/gui/main.py ADDED
@@ -0,0 +1,1372 @@
1
+ # AccuSleePy main window
2
+ # Icon source: Arkinasi, https://www.flaticon.com/authors/arkinasi
3
+
4
+ import datetime
5
+ import os
6
+ import shutil
7
+ import sys
8
+ from dataclasses import dataclass
9
+ from functools import partial
10
+
11
+ import numpy as np
12
+ from PySide6 import QtCore, QtGui, QtWidgets
13
+
14
+ from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
15
+ from accusleepy.classification import (
16
+ create_calibration_file,
17
+ score_recording,
18
+ train_model,
19
+ )
20
+ from accusleepy.constants import (
21
+ CALIBRATION_FILE_TYPE,
22
+ DEFAULT_MODEL_TYPE,
23
+ LABEL_FILE_TYPE,
24
+ MODEL_FILE_TYPE,
25
+ REAL_TIME_MODEL_TYPE,
26
+ RECORDING_FILE_TYPES,
27
+ RECORDING_LIST_FILE_TYPE,
28
+ UNDEFINED_LABEL,
29
+ )
30
+ from accusleepy.fileio import (
31
+ Recording,
32
+ load_calibration_file,
33
+ load_config,
34
+ load_labels,
35
+ load_model,
36
+ load_recording,
37
+ load_recording_list,
38
+ save_config,
39
+ save_labels,
40
+ save_model,
41
+ save_recording_list,
42
+ )
43
+ from accusleepy.gui.manual_scoring import ManualScoringWindow
44
+ from accusleepy.gui.primary_window import Ui_PrimaryWindow
45
+ from accusleepy.signal_processing import (
46
+ ANNOTATIONS_FILENAME,
47
+ create_training_images,
48
+ enforce_min_bout_length,
49
+ resample_and_standardize,
50
+ )
51
+
52
+ # max number of messages to display
53
+ MESSAGE_BOX_MAX_DEPTH = 50
54
+ LABEL_LENGTH_ERROR = "label file length does not match recording length"
55
+ # relative path to user manual txt file
56
+ USER_MANUAL_FILE = "text/main_guide.txt"
57
+ CONFIG_GUIDE_FILE = "text/config_guide.txt"
58
+
59
+
60
+ @dataclass
61
+ class StateSettings:
62
+ """Widgets for config settings for a brain state"""
63
+
64
+ digit: int
65
+ enabled_widget: QtWidgets.QCheckBox
66
+ name_widget: QtWidgets.QLabel
67
+ is_scored_widget: QtWidgets.QCheckBox
68
+ frequency_widget: QtWidgets.QDoubleSpinBox
69
+
70
+
71
+ class AccuSleepWindow(QtWidgets.QMainWindow):
72
+ """AccuSleePy primary window"""
73
+
74
+ def __init__(self):
75
+ super(AccuSleepWindow, self).__init__()
76
+
77
+ # initialize the UI
78
+ self.ui = Ui_PrimaryWindow()
79
+ self.ui.setupUi(self)
80
+ self.setWindowTitle("AccuSleePy")
81
+
82
+ # fill in settings tab
83
+ self.brain_state_set = load_config()
84
+ self.settings_widgets = None
85
+ self.initialize_settings_tab()
86
+
87
+ # initialize info about the recordings, classification data / settings
88
+ self.epoch_length = 0
89
+ self.model = None
90
+ self.only_overwrite_undefined = False
91
+ self.min_bout_length = 5
92
+
93
+ # initialize model training variables
94
+ self.training_epochs_per_img = 9
95
+ self.delete_training_images = True
96
+ self.training_image_dir = ""
97
+ self.model_type = DEFAULT_MODEL_TYPE
98
+
99
+ # metadata for the currently loaded classification model
100
+ self.model_epoch_length = None
101
+ self.model_epochs_per_img = None
102
+
103
+ # set up the list of recordings
104
+ first_recording = Recording(
105
+ widget=QtWidgets.QListWidgetItem(
106
+ "Recording 1", self.ui.recording_list_widget
107
+ ),
108
+ )
109
+ self.ui.recording_list_widget.addItem(first_recording.widget)
110
+ self.ui.recording_list_widget.setCurrentRow(0)
111
+ # index of currently selected recording in the list
112
+ self.recording_index = 0
113
+ # list of recordings the user has added
114
+ self.recordings = [first_recording]
115
+
116
+ # messages to display
117
+ self.messages = []
118
+
119
+ # user input: keyboard shortcuts
120
+ keypress_quit = QtGui.QShortcut(
121
+ QtGui.QKeySequence(
122
+ QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_W)
123
+ ),
124
+ self,
125
+ )
126
+ keypress_quit.activated.connect(self.close)
127
+
128
+ # user input: button presses
129
+ self.ui.add_button.clicked.connect(self.add_recording)
130
+ self.ui.remove_button.clicked.connect(self.remove_recording)
131
+ self.ui.recording_list_widget.currentRowChanged.connect(self.select_recording)
132
+ self.ui.sampling_rate_input.valueChanged.connect(self.update_sampling_rate)
133
+ self.ui.epoch_length_input.valueChanged.connect(self.update_epoch_length)
134
+ self.ui.recording_file_button.clicked.connect(self.select_recording_file)
135
+ self.ui.select_label_button.clicked.connect(self.select_label_file)
136
+ self.ui.create_label_button.clicked.connect(self.create_label_file)
137
+ self.ui.manual_scoring_button.clicked.connect(self.manual_scoring)
138
+ self.ui.create_calibration_button.clicked.connect(self.create_calibration_file)
139
+ self.ui.select_calibration_button.clicked.connect(self.select_calibration_file)
140
+ self.ui.load_model_button.clicked.connect(partial(self.load_model, None))
141
+ self.ui.score_all_button.clicked.connect(self.score_all)
142
+ self.ui.overwritecheckbox.stateChanged.connect(self.update_overwrite_policy)
143
+ self.ui.bout_length_input.valueChanged.connect(self.update_min_bout_length)
144
+ self.ui.user_manual_button.clicked.connect(self.show_user_manual)
145
+ self.ui.image_number_input.valueChanged.connect(self.update_epochs_per_img)
146
+ self.ui.delete_image_box.stateChanged.connect(self.update_image_deletion)
147
+ self.ui.training_folder_button.clicked.connect(self.set_training_folder)
148
+ self.ui.train_model_button.clicked.connect(self.train_model)
149
+ self.ui.save_config_button.clicked.connect(self.save_brain_state_config)
150
+ self.ui.export_button.clicked.connect(self.export_recording_list)
151
+ self.ui.import_button.clicked.connect(self.import_recording_list)
152
+ self.ui.default_type_button.toggled.connect(self.model_type_radio_buttons)
153
+
154
+ # user input: drag and drop
155
+ self.ui.recording_file_label.installEventFilter(self)
156
+ self.ui.label_file_label.installEventFilter(self)
157
+ self.ui.calibration_file_label.installEventFilter(self)
158
+ self.ui.model_label.installEventFilter(self)
159
+
160
+ self.show()
161
+
162
+ def model_type_radio_buttons(self, default_selected: bool) -> None:
163
+ """Toggle training default or real-time model
164
+
165
+ :param default_selected: whether default option is selected
166
+ """
167
+ if default_selected:
168
+ self.model_type = DEFAULT_MODEL_TYPE
169
+ else:
170
+ self.model_type = REAL_TIME_MODEL_TYPE
171
+
172
+ def export_recording_list(self) -> None:
173
+ """Save current list of recordings to file"""
174
+ # get the name for the recording list file
175
+ filename, _ = QtWidgets.QFileDialog.getSaveFileName(
176
+ self,
177
+ caption="Save list of recordings as",
178
+ filter="*" + RECORDING_LIST_FILE_TYPE,
179
+ )
180
+ if not filename:
181
+ return
182
+ save_recording_list(filename=filename, recordings=self.recordings)
183
+ self.show_message(f"Saved list of recordings to {filename}")
184
+
185
+ def import_recording_list(self):
186
+ """Load list of recordings from file, overwriting current list"""
187
+ file_dialog = QtWidgets.QFileDialog(self)
188
+ file_dialog.setWindowTitle("Select list of recordings")
189
+ file_dialog.setFileMode(QtWidgets.QFileDialog.FileMode.ExistingFile)
190
+ file_dialog.setViewMode(QtWidgets.QFileDialog.ViewMode.Detail)
191
+ file_dialog.setNameFilter("*" + RECORDING_LIST_FILE_TYPE)
192
+
193
+ if file_dialog.exec():
194
+ selected_files = file_dialog.selectedFiles()
195
+ filename = selected_files[0]
196
+ else:
197
+ return
198
+
199
+ # clear widget
200
+ self.ui.recording_list_widget.clear()
201
+ # overwrite current list
202
+ self.recordings = load_recording_list(filename)
203
+
204
+ for recording in self.recordings:
205
+ recording.widget = QtWidgets.QListWidgetItem(
206
+ f"Recording {recording.name}", self.ui.recording_list_widget
207
+ )
208
+ self.ui.recording_list_widget.addItem(recording.widget)
209
+
210
+ # display new list
211
+ self.ui.recording_list_widget.setCurrentRow(0)
212
+ self.show_message(f"Loaded list of recordings from {filename}")
213
+
214
+ def eventFilter(self, obj: QtCore.QObject, event: QtCore.QEvent) -> bool:
215
+ """Filter mouse events to detect when user drags/drops a file
216
+
217
+ :param obj: UI object receiving the event
218
+ :param event: mouse event
219
+ :return: whether to filter (block) the event
220
+ """
221
+ filename = None
222
+ if obj in [
223
+ self.ui.recording_file_label,
224
+ self.ui.label_file_label,
225
+ self.ui.calibration_file_label,
226
+ self.ui.model_label,
227
+ ]:
228
+ event.accept()
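+ # only handle drop events that deliver exactly one local file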
229
+ if event.type() == QtCore.QEvent.Drop:
230
+ urls = event.mimeData().urls()
231
+ if len(urls) == 1:
232
+ filename = urls[0].toLocalFile()
233
+
234
+ if filename is None:
235
+ return super().eventFilter(obj, event)
236
+
237
+ _, file_extension = os.path.splitext(filename)
238
+
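+ # assign the dropped file to the field that received it, provided the extension matches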
239
+ if obj == self.ui.recording_file_label:
240
+ if file_extension in RECORDING_FILE_TYPES:
241
+ self.recordings[self.recording_index].recording_file = filename
242
+ self.ui.recording_file_label.setText(filename)
243
+ elif obj == self.ui.label_file_label:
244
+ if file_extension == LABEL_FILE_TYPE:
245
+ self.recordings[self.recording_index].label_file = filename
246
+ self.ui.label_file_label.setText(filename)
247
+ elif obj == self.ui.calibration_file_label:
248
+ if file_extension == CALIBRATION_FILE_TYPE:
249
+ self.recordings[self.recording_index].calibration_file = filename
250
+ self.ui.calibration_file_label.setText(filename)
251
+ elif obj == self.ui.model_label:
252
+ self.load_model(filename=filename)
253
+
254
+ return super().eventFilter(obj, event)
255
+
256
+ def train_model(self) -> None:
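+ """Create training images from all listed recordings, train a model, and save it"""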
257
+ # check basic training inputs
258
+ if (
259
+ self.model_type == DEFAULT_MODEL_TYPE
260
+ and self.training_epochs_per_img % 2 == 0
261
+ ):
262
+ self.show_message(
263
+ (
264
+ "ERROR: for the default model type, number of epochs "
265
+ "per image must be an odd number."
266
+ )
267
+ )
268
+ return
269
+ if self.training_image_dir == "":
270
+ self.show_message("ERROR: no folder selected for training images.")
271
+ return
272
+
273
+ # check some inputs for each recording
274
+ for recording_index in range(len(self.recordings)):
275
+ error_message = self.check_single_file_inputs(recording_index)
276
+ if error_message:
277
+ self.show_message(
278
+ f"ERROR ({self.recordings[recording_index].name}): {error_message}"
279
+ )
280
+ return
281
+
282
+ # get filename for the new model
283
+ model_filename, _ = QtWidgets.QFileDialog.getSaveFileName(
284
+ self,
285
+ caption="Save classification model file as",
286
+ filter="*" + MODEL_FILE_TYPE,
287
+ )
288
+ if not model_filename:
289
+ self.show_message("Model training canceled, no filename given")
+ return
290
+
291
+ # create image folder
292
+ if os.path.exists(self.training_image_dir):
293
+ self.show_message(
294
+ "Warning: training image folder exists, will be overwritten"
295
+ )
296
+ os.makedirs(self.training_image_dir, exist_ok=True)
297
+
298
+ # create training images
299
+ self.show_message(
300
+ (f"Creating training images in {self.training_image_dir}, please wait...")
301
+ )
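+ # repaint and process pending events so the status message is visible before the long-running step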
302
+ self.ui.message_area.repaint()
303
+ QtWidgets.QApplication.processEvents()
304
+ print("Creating training images")
305
+ failed_recordings = create_training_images(
306
+ recordings=self.recordings,
307
+ output_path=self.training_image_dir,
308
+ epoch_length=self.epoch_length,
309
+ epochs_per_img=self.training_epochs_per_img,
310
+ brain_state_set=self.brain_state_set,
311
+ model_type=self.model_type,
312
+ )
313
+ if len(failed_recordings) > 0:
314
+ if len(failed_recordings) == len(self.recordings):
315
+ self.show_message("ERROR: no recordings were valid!")
316
+ else:
317
+ self.show_message(
318
+ (
319
+ "WARNING: the following recordings could not be"
320
+ "loaded and will not be used for training: "
321
+ f"{', '.join([str(r) for r in failed_recordings])}"
322
+ )
323
+ )
324
+
325
+ # train model
326
+ self.show_message("Training model, please wait...")
327
+ self.ui.message_area.repaint()
328
+ QtWidgets.QApplication.processEvents()
329
+ print("Training model")
330
+ model = train_model(
331
+ annotations_file=os.path.join(
332
+ self.training_image_dir, ANNOTATIONS_FILENAME
333
+ ),
334
+ img_dir=self.training_image_dir,
335
+ mixture_weights=self.brain_state_set.mixture_weights,
336
+ n_classes=self.brain_state_set.n_classes,
337
+ )
338
+
339
+ # save model
340
+ save_model(
341
+ model=model,
342
+ filename=model_filename,
343
+ epoch_length=self.epoch_length,
344
+ epochs_per_img=self.training_epochs_per_img,
345
+ model_type=self.model_type,
346
+ brain_state_set=self.brain_state_set,
347
+ )
348
+
349
+ # optionally delete images
350
+ if self.delete_training_images:
351
+ shutil.rmtree(self.training_image_dir)
352
+
353
+ self.show_message(f"Training complete, saved model to {model_filename}")
354
+
355
+ def set_training_folder(self):
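+ """Choose a parent folder and set a timestamped subdirectory for training images"""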
356
+ training_folder_parent = QtWidgets.QFileDialog.getExistingDirectory(
357
+ self, "Select directory for training images"
358
+ )
359
+ if training_folder_parent:
360
+ self.training_image_dir = os.path.join(
361
+ training_folder_parent,
362
+ "images_" + datetime.datetime.now().strftime("%Y%m%d%H%M"),
363
+ )
364
+ self.ui.image_folder_label.setText(self.training_image_dir)
365
+
366
+ def update_image_deletion(self) -> None:
367
+ """Update choice of whether to delete images after training"""
368
+ self.delete_training_images = self.ui.delete_image_box.isChecked()
369
+
370
+ def update_epochs_per_img(self, new_value) -> None:
371
+ """Update number of epochs per image
372
+
373
+ :param new_value: new number of epochs per image
374
+ """
375
+ self.training_epochs_per_img = new_value
376
+
377
+ def score_all(self) -> None:
378
+ """Score all recordings using the classification model"""
379
+ # check basic inputs
380
+ if self.model is None:
381
+ self.ui.score_all_status.setText("missing classification model")
382
+ self.show_message("ERROR: no classification model file selected")
383
+ return
384
+ if self.min_bout_length < self.epoch_length:
385
+ self.ui.score_all_status.setText("invalid minimum bout length")
386
+ self.show_message("ERROR: minimum bout length must be >= epoch length")
387
+ return
388
+ if self.epoch_length != self.model_epoch_length:
389
+ self.ui.score_all_status.setText("invalid epoch length")
390
+ self.show_message(
391
+ (
392
+ "ERROR: model was trained with an epoch length of "
393
+ f"{self.model_epoch_length} seconds, but the current "
394
+ f"epoch length setting is {self.epoch_length} seconds."
395
+ )
396
+ )
397
+ return
398
+
399
+ self.ui.score_all_status.setText("running...")
400
+ self.ui.score_all_status.repaint()
401
+ QtWidgets.QApplication.processEvents()
402
+
403
+ # check some inputs for each recording
404
+ for recording_index in range(len(self.recordings)):
405
+ error_message = self.check_single_file_inputs(recording_index)
406
+ if error_message:
407
+ self.ui.score_all_status.setText(
408
+ f"error on recording {self.recordings[recording_index].name}"
409
+ )
410
+ self.show_message(
411
+ f"ERROR ({self.recordings[recording_index].name}): {error_message}"
412
+ )
413
+ return
414
+ if self.recordings[recording_index].calibration_file == "":
415
+ self.ui.score_all_status.setText(
416
+ f"error on recording {self.recordings[recording_index].name}"
417
+ )
418
+ self.show_message(
419
+ (
420
+ f"ERROR ({self.recordings[recording_index].name}): "
421
+ "no calibration file selected"
422
+ )
423
+ )
424
+ return
425
+
426
+ # score each recording
427
+ for recording_index in range(len(self.recordings)):
428
+ # load EEG, EMG
429
+ try:
430
+ eeg, emg = load_recording(
431
+ self.recordings[recording_index].recording_file
432
+ )
433
+ sampling_rate = self.recordings[recording_index].sampling_rate
434
+
435
+ eeg, emg, sampling_rate = resample_and_standardize(
436
+ eeg=eeg,
437
+ emg=emg,
438
+ sampling_rate=sampling_rate,
439
+ epoch_length=self.epoch_length,
440
+ )
441
+ except Exception:
442
+ self.show_message(
443
+ (
444
+ "ERROR: could not load recording "
445
+ f"{self.recordings[recording_index].name}."
446
+ "This recording will be skipped."
447
+ )
448
+ )
449
+ continue
450
+
451
+ # load labels
452
+ label_file = self.recordings[recording_index].label_file
453
+ if os.path.isfile(label_file):
454
+ try:
455
+ existing_labels = load_labels(label_file)
456
+ except Exception:
457
+ self.show_message(
458
+ (
459
+ "ERROR: could not load existing labels for recording "
460
+ f"{self.recordings[recording_index].name}."
461
+ "This recording will be skipped."
462
+ )
463
+ )
464
+ continue
465
+ # only check the length
466
+ samples_per_epoch = sampling_rate * self.epoch_length
467
+ epochs_in_recording = round(eeg.size / samples_per_epoch)
468
+ if epochs_in_recording != existing_labels.size:
469
+ self.show_message(
470
+ (
471
+ "ERROR: existing labels for recording "
472
+ f"{self.recordings[recording_index].name} "
473
+ "do not match the recording length. "
474
+ "This recording will be skipped."
475
+ )
476
+ )
477
+ continue
478
+ else:
479
+ existing_labels = None
480
+
481
+ # load calibration data
482
+ if not os.path.isfile(self.recordings[recording_index].calibration_file):
483
+ self.show_message(
484
+ (
485
+ "ERROR: calibration file does not exist for recording "
486
+ f"{self.recordings[recording_index].name}. "
487
+ "This recording will be skipped."
488
+ )
489
+ )
490
+ continue
491
+ try:
492
+ (
493
+ mixture_means,
494
+ mixture_sds,
495
+ ) = load_calibration_file(
496
+ self.recordings[recording_index].calibration_file
497
+ )
498
+ except Exception:
499
+ self.show_message(
500
+ (
501
+ "ERROR: could not load calibration file for recording "
502
+ f"{self.recordings[recording_index].name}. "
503
+ "This recording will be skipped."
504
+ )
505
+ )
506
+ continue
507
+
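+ # classify each epoch using the loaded model and this recording's calibration data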
508
+ labels = score_recording(
509
+ model=self.model,
510
+ eeg=eeg,
511
+ emg=emg,
512
+ mixture_means=mixture_means,
513
+ mixture_sds=mixture_sds,
514
+ sampling_rate=sampling_rate,
515
+ epoch_length=self.epoch_length,
516
+ epochs_per_img=self.model_epochs_per_img,
517
+ brain_state_set=self.brain_state_set,
518
+ )
519
+
520
+ # keep existing labels for epochs that were already defined; only fill in undefined epochs
521
+ if existing_labels is not None and self.only_overwrite_undefined:
522
+ labels[existing_labels != UNDEFINED_LABEL] = existing_labels[
523
+ existing_labels != UNDEFINED_LABEL
524
+ ]
525
+
526
+ # enforce minimum bout length
527
+ labels = enforce_min_bout_length(
528
+ labels=labels,
529
+ epoch_length=self.epoch_length,
530
+ min_bout_length=self.min_bout_length,
531
+ )
532
+
533
+ # save results
534
+ save_labels(labels, label_file)
535
+ self.show_message(
536
+ (
537
+ "Saved labels for recording "
538
+ f"{self.recordings[recording_index].name} "
539
+ f"to {label_file}"
540
+ )
541
+ )
542
+
543
+ self.ui.score_all_status.setText("")
544
+
545
+ def load_model(self, filename=None) -> None:
546
+ """Load trained classification model from file
547
+
548
+ :param filename: model filename, if it's known
549
+ """
550
+ if filename is None:
551
+ file_dialog = QtWidgets.QFileDialog(self)
552
+ file_dialog.setWindowTitle("Select classification model")
553
+ file_dialog.setFileMode(QtWidgets.QFileDialog.FileMode.ExistingFile)
554
+ file_dialog.setViewMode(QtWidgets.QFileDialog.ViewMode.Detail)
555
+ file_dialog.setNameFilter("*" + MODEL_FILE_TYPE)
556
+
557
+ if file_dialog.exec():
558
+ selected_files = file_dialog.selectedFiles()
559
+ filename = selected_files[0]
560
+ else:
561
+ return
562
+
563
+ if not os.path.isfile(filename):
564
+ self.show_message("ERROR: model file does not exist")
565
+ return
566
+
567
+ try:
568
+ model, epoch_length, epochs_per_img, model_type, brain_states = load_model(
569
+ filename=filename
570
+ )
571
+ except Exception:
572
+ self.show_message(
573
+ (
574
+ "ERROR: could not load classification model. Check "
575
+ "user manual for instructions on creating this file."
576
+ )
577
+ )
578
+ return
579
+
580
+ # make sure only "default" model type is loaded
581
+ if model_type != DEFAULT_MODEL_TYPE:
582
+ self.show_message(
583
+ (
584
+ "ERROR: only 'default'-style models can be used. "
585
+ "'Real-time' models are not supported. "
586
+ "See classification.example_real_time_scoring_function.py "
587
+ "for an example of how to classify brain states in real time."
588
+ )
589
+ )
590
+ return
591
+
592
+ self.model = model
593
+ self.model_epoch_length = epoch_length
594
+ self.model_epochs_per_img = epochs_per_img
595
+
596
+ # warn user if the model's expected epoch length or brain states
597
+ # don't match the current configuration
598
+ config_warnings = check_config_consistency(
599
+ current_brain_states=self.brain_state_set.to_output_dict()[
600
+ BRAIN_STATES_KEY
601
+ ],
602
+ model_brain_states=brain_states,
603
+ current_epoch_length=self.epoch_length,
604
+ model_epoch_length=epoch_length,
605
+ )
606
+ if len(config_warnings) > 0:
607
+ for w in config_warnings:
608
+ self.show_message(w)
609
+
610
+ self.ui.model_label.setText(filename)
611
+
612
+ def load_single_recording(
613
+ self, status_widget: QtWidgets.QLabel
614
+ ) -> (np.array, np.array, int | float, bool):
615
+ """Load and preprocess one recording
616
+
617
+ This loads one recording, resamples it, and standardizes its length.
618
+ If an error occurs during this process, it is displayed in the
619
+ indicated widget.
620
+
621
+ :param status_widget: UI element on which to display error messages
622
+ :return: EEG data, EMG data, sampling rate, process completion
623
+ """
624
+ error_message = self.check_single_file_inputs(self.recording_index)
625
+ if error_message:
626
+ status_widget.setText(error_message)
627
+ self.show_message(f"ERROR: {error_message}")
628
+ return None, None, None, False
629
+
630
+ try:
631
+ eeg, emg = load_recording(
632
+ self.recordings[self.recording_index].recording_file
633
+ )
634
+ except Exception:
635
+ status_widget.setText("could not load recording")
636
+ self.show_message(
637
+ (
638
+ "ERROR: could not load recording. "
639
+ "Check user manual for formatting instructions."
640
+ )
641
+ )
642
+ return None, None, None, False
643
+
644
+ sampling_rate = self.recordings[self.recording_index].sampling_rate
645
+
646
+ eeg, emg, sampling_rate = resample_and_standardize(
647
+ eeg=eeg,
648
+ emg=emg,
649
+ sampling_rate=sampling_rate,
650
+ epoch_length=self.epoch_length,
651
+ )
652
+
653
+ return eeg, emg, sampling_rate, True
654
+
655
+ def create_calibration_file(self) -> None:
656
+ """Creates a calibration file
657
+
658
+ This loads a recording and its labels, checks that the labels are
659
+ all valid, creates the calibration file, and sets the
660
+ "calibration file" property of the current recording to be the
661
+ newly created file.
662
+ """
663
+ # load the recording
664
+ eeg, emg, sampling_rate, success = self.load_single_recording(
665
+ self.ui.calibration_status
666
+ )
667
+ if not success:
668
+ return
669
+
670
+ # load the labels
671
+ label_file = self.recordings[self.recording_index].label_file
672
+ if not os.path.isfile(label_file):
673
+ self.ui.calibration_status.setText("label file does not exist")
674
+ self.show_message("ERROR: label file does not exist")
675
+ return
676
+ try:
677
+ labels = load_labels(label_file)
678
+ except Exception:
679
+ self.ui.calibration_status.setText("could not load labels")
680
+ self.show_message(
681
+ (
682
+ "ERROR: could not load labels. "
683
+ "Check user manual for formatting instructions."
684
+ )
685
+ )
686
+ return
687
+ label_error_message = check_label_validity(
688
+ labels=labels,
689
+ samples_in_recording=eeg.size,
690
+ sampling_rate=sampling_rate,
691
+ epoch_length=self.epoch_length,
692
+ brain_state_set=self.brain_state_set,
693
+ )
694
+ if label_error_message:
695
+ self.ui.calibration_status.setText("invalid label file")
696
+ self.show_message(f"ERROR: {label_error_message}")
697
+ return
698
+
699
+ # get the name for the calibration file
700
+ filename, _ = QtWidgets.QFileDialog.getSaveFileName(
701
+ self,
702
+ caption="Save calibration file as",
703
+ filter="*" + CALIBRATION_FILE_TYPE,
704
+ )
705
+ if not filename:
706
+ return
707
+
708
+ create_calibration_file(
709
+ filename=filename,
710
+ eeg=eeg,
711
+ emg=emg,
712
+ labels=labels,
713
+ sampling_rate=sampling_rate,
714
+ epoch_length=self.epoch_length,
715
+ brain_state_set=self.brain_state_set,
716
+ )
717
+
718
+ self.ui.calibration_status.setText("")
719
+ self.show_message(
720
+ (
721
+ "Created calibration file using recording "
722
+ f"{self.recordings[self.recording_index].name} "
723
+ f"at {filename}"
724
+ )
725
+ )
726
+
727
+ self.recordings[self.recording_index].calibration_file = filename
728
+ self.ui.calibration_file_label.setText(filename)
729
+
730
+ def check_single_file_inputs(self, recording_index: int) -> str | None:
731
+ """Check that a recording's inputs appear valid
732
+
733
+ This runs some basic tests for whether it will be possible to
734
+ load and score a recording. If any test fails, we return an
735
+ error message.
736
+
737
+ :param recording_index: index of the recording in the list of
738
+ all recordings.
739
+ :return: error message
740
+ """
741
+ sampling_rate = self.recordings[recording_index].sampling_rate
742
+ if self.epoch_length == 0:
743
+ return "epoch length can't be 0"
744
+ if sampling_rate == 0:
745
+ return "sampling rate can't be 0"
746
+ if self.epoch_length > sampling_rate:
747
+ return "invalid epoch length or sampling rate"
748
+ if self.recordings[recording_index].recording_file == "":
749
+ return "no recording selected"
750
+ if self.recordings[recording_index].label_file == "":
751
+ return "no label file selected"
752
+
753
+ def update_min_bout_length(self, new_value) -> None:
754
+ """Update the minimum bout length
755
+
756
+ :param new_value: new minimum bout length, in seconds
757
+ """
758
+ self.min_bout_length = new_value
759
+
760
+ def update_overwrite_policy(self, checked) -> None:
761
+ """Toggle overwriting policy
762
+
763
+ If the checkbox is enabled, only epochs where the brain state is set to
764
+ undefined will be overwritten by the automatic scoring process.
765
+
766
+ :param checked: state of the checkbox
767
+ """
768
+ self.only_overwrite_undefined = checked
769
+
770
+ def manual_scoring(self) -> None:
771
+ """View the selected recording for manual scoring"""
772
+ # immediately display a status message
773
+ self.ui.manual_scoring_status.setText("loading...")
774
+ self.ui.manual_scoring_status.repaint()
775
+ QtWidgets.QApplication.processEvents()
776
+
777
+ # load the recording
778
+ eeg, emg, sampling_rate, success = self.load_single_recording(
779
+ self.ui.manual_scoring_status
780
+ )
781
+ if not success:
782
+ return
783
+
784
+ # if the labels exist, load them
785
+ # otherwise, create a blank set of labels
786
+ label_file = self.recordings[self.recording_index].label_file
787
+ if os.path.isfile(label_file):
788
+ try:
789
+ labels = load_labels(label_file)
790
+ except Exception:
791
+ self.ui.manual_scoring_status.setText("could not load labels")
792
+ self.show_message(
793
+ (
794
+ "ERROR: could not load labels. "
795
+ "Check user manual for formatting instructions."
796
+ )
797
+ )
798
+ return
799
+ else:
800
+ labels = (
801
+ np.ones(int(eeg.size / (sampling_rate * self.epoch_length)))
802
+ * UNDEFINED_LABEL
803
+ ).astype(int)
804
+
805
+ # check that all labels are valid
806
+ label_error = check_label_validity(
807
+ labels=labels,
808
+ samples_in_recording=eeg.size,
809
+ sampling_rate=sampling_rate,
810
+ epoch_length=self.epoch_length,
811
+ brain_state_set=self.brain_state_set,
812
+ )
813
+ if label_error:
814
+ # if the label length is only off by one, pad or truncate as needed
815
+ # and show a warning
816
+ if label_error == LABEL_LENGTH_ERROR:
817
+ # should be very close to an integer
818
+ samples_per_epoch = round(sampling_rate * self.epoch_length)
819
+ epochs_in_recording = round(eeg.size / samples_per_epoch)
820
+ if epochs_in_recording - labels.size == 1:
821
+ labels = np.concatenate((labels, np.array([UNDEFINED_LABEL])))
822
+ self.show_message(
823
+ (
824
+ "WARNING: an undefined epoch was added to "
825
+ "the label file to correct its length."
826
+ )
827
+ )
828
+ elif labels.size - epochs_in_recording == 1:
829
+ labels = labels[:-1]
830
+ self.show_message(
831
+ (
832
+ "WARNING: the last epoch was removed from "
833
+ "the label file to correct its length."
834
+ )
835
+ )
836
+ else:
837
+ self.ui.manual_scoring_status.setText("invalid label file")
838
+ self.show_message(f"ERROR: {label_error}")
839
+ return
840
+ else:
841
+ self.ui.manual_scoring_status.setText("invalid label file")
842
+ self.show_message(f"ERROR: {label_error}")
843
+ return
844
+
845
+ self.show_message(
846
+ f"Viewing recording {self.recordings[self.recording_index].name}"
847
+ )
848
+ self.ui.manual_scoring_status.setText("file is open")
849
+
850
+ # launch the manual scoring window
851
+ manual_scoring_window = ManualScoringWindow(
852
+ eeg=eeg,
853
+ emg=emg,
854
+ label_file=label_file,
855
+ labels=labels,
856
+ sampling_rate=sampling_rate,
857
+ epoch_length=self.epoch_length,
858
+ )
859
+ manual_scoring_window.setWindowTitle(f"AccuSleePy viewer: {label_file}")
860
+ manual_scoring_window.exec()
861
+ self.ui.manual_scoring_status.setText("")
862
+
863
+ def create_label_file(self) -> None:
864
+ """Set the filename for a new label file"""
865
+ filename, _ = QtWidgets.QFileDialog.getSaveFileName(
866
+ self,
867
+ caption="Set filename for label file (nothing will be overwritten yet)",
868
+ filter="*" + LABEL_FILE_TYPE,
869
+ )
870
+ if filename:
871
+ self.recordings[self.recording_index].label_file = filename
872
+ self.ui.label_file_label.setText(filename)
873
+
874
+ def select_label_file(self) -> None:
875
+ """User can select an existing label file"""
876
+ file_dialog = QtWidgets.QFileDialog(self)
877
+ file_dialog.setWindowTitle("Select label file")
878
+ file_dialog.setFileMode(QtWidgets.QFileDialog.FileMode.ExistingFile)
879
+ file_dialog.setViewMode(QtWidgets.QFileDialog.ViewMode.Detail)
880
+ file_dialog.setNameFilter("*" + LABEL_FILE_TYPE)
881
+
882
+ if file_dialog.exec():
883
+ selected_files = file_dialog.selectedFiles()
884
+ filename = selected_files[0]
885
+ self.recordings[self.recording_index].label_file = filename
886
+ self.ui.label_file_label.setText(filename)
887
+
888
+ def select_calibration_file(self) -> None:
889
+ """User can select a calibration file"""
890
+ file_dialog = QtWidgets.QFileDialog(self)
891
+ file_dialog.setWindowTitle("Select calibration file")
892
+ file_dialog.setFileMode(QtWidgets.QFileDialog.FileMode.ExistingFile)
893
+ file_dialog.setViewMode(QtWidgets.QFileDialog.ViewMode.Detail)
894
+ file_dialog.setNameFilter("*" + CALIBRATION_FILE_TYPE)
895
+
896
+ if file_dialog.exec():
897
+ selected_files = file_dialog.selectedFiles()
898
+ filename = selected_files[0]
899
+ self.recordings[self.recording_index].calibration_file = filename
900
+ self.ui.calibration_file_label.setText(filename)
901
+
902
+ def select_recording_file(self) -> None:
903
+ """User can select a recording file"""
904
+ file_dialog = QtWidgets.QFileDialog(self)
905
+ file_dialog.setWindowTitle("Select recording file")
906
+ file_dialog.setFileMode(QtWidgets.QFileDialog.FileMode.ExistingFile)
907
+ file_dialog.setViewMode(QtWidgets.QFileDialog.ViewMode.Detail)
908
+ file_dialog.setNameFilter(f"(*{' *'.join(RECORDING_FILE_TYPES)})")
909
+
910
+ if file_dialog.exec():
911
+ selected_files = file_dialog.selectedFiles()
912
+ filename = selected_files[0]
913
+ self.recordings[self.recording_index].recording_file = filename
914
+ self.ui.recording_file_label.setText(filename)
915
+
916
+ def show_recording_info(self) -> None:
917
+ """Update the UI to show info for the selected recording"""
918
+ self.ui.sampling_rate_input.setValue(
919
+ self.recordings[self.recording_index].sampling_rate
920
+ )
921
+ self.ui.recording_file_label.setText(
922
+ self.recordings[self.recording_index].recording_file
923
+ )
924
+ self.ui.label_file_label.setText(
925
+ self.recordings[self.recording_index].label_file
926
+ )
927
+ self.ui.calibration_file_label.setText(
928
+ self.recordings[self.recording_index].calibration_file
929
+ )
930
+
931
+ def update_epoch_length(self, new_value: int | float) -> None:
932
+ """Update the epoch length when the widget state changes
933
+
934
+ :param new_value: new epoch length
935
+ """
936
+ self.epoch_length = new_value
937
+
938
+ def update_sampling_rate(self, new_value: int | float) -> None:
939
+ """Update recording's sampling rate when the widget state changes
940
+
941
+ :param new_value: new sampling rate
942
+ """
943
+ self.recordings[self.recording_index].sampling_rate = new_value
944
+
945
+ def show_message(self, message: str) -> None:
946
+ """Display a new message to the user
947
+
948
+ :param message: message to display
949
+ """
950
+ self.messages.append(message)
951
+ if len(self.messages) > MESSAGE_BOX_MAX_DEPTH:
952
+ del self.messages[0]
953
+ self.ui.message_area.setText("\n".join(self.messages))
954
+ # scroll to the bottom
955
+ scrollbar = self.ui.message_area.verticalScrollBar()
956
+ scrollbar.setValue(scrollbar.maximum())
957
+
958
+ def select_recording(self, list_index: int) -> None:
959
+ """Callback for when a recording is selected
960
+
961
+ :param list_index: index of this recording in the list widget
962
+ """
963
+ # get index of this recording
964
+ self.recording_index = list_index
965
+ # display information about this recording
966
+ self.show_recording_info()
967
+ self.ui.selected_recording_groupbox.setTitle(
968
+ f"Data / actions for Recording {self.recordings[list_index].name}"
969
+ )
970
+
971
+ def add_recording(self) -> None:
972
+ """Add new recording to the list"""
973
+ # find name to use for the new recording
974
+ new_name = max([r.name for r in self.recordings]) + 1
975
+
976
+ # add new recording to list
977
+ self.recordings.append(
978
+ Recording(
979
+ name=new_name,
980
+ sampling_rate=self.recordings[self.recording_index].sampling_rate,
981
+ widget=QtWidgets.QListWidgetItem(
982
+ f"Recording {new_name}", self.ui.recording_list_widget
983
+ ),
984
+ )
985
+ )
986
+
987
+ # display new list
988
+ self.ui.recording_list_widget.addItem(self.recordings[-1].widget)
989
+ self.ui.recording_list_widget.setCurrentRow(len(self.recordings) - 1)
990
+ self.show_message(f"added Recording {new_name}")
991
+
992
+ def remove_recording(self) -> None:
993
+ """Delete selected recording from the list"""
994
+ if len(self.recordings) > 1:
995
+ current_list_index = self.ui.recording_list_widget.currentRow()
996
+ _ = self.ui.recording_list_widget.takeItem(current_list_index)
997
+ self.show_message(
998
+ f"deleted Recording {self.recordings[current_list_index].name}"
999
+ )
1000
+ del self.recordings[current_list_index]
1001
+ self.recording_index = self.ui.recording_list_widget.currentRow()
1002
+
1003
+ def show_user_manual(self) -> None:
1004
+ """Show a popup window with the user manual"""
1005
+ user_manual_file = open(
1006
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), USER_MANUAL_FILE),
1007
+ "r",
1008
+ )
1009
+ user_manual_text = user_manual_file.read()
1010
+ user_manual_file.close()
1011
+
1012
+ label_widget = QtWidgets.QLabel()
1013
+ label_widget.setText(user_manual_text)
1014
+ scroll_area = QtWidgets.QScrollArea()
1015
+ scroll_area.setStyleSheet("background-color: white;")
1016
+ scroll_area.setWidget(label_widget)
1017
+ grid = QtWidgets.QGridLayout()
1018
+ grid.addWidget(scroll_area)
1019
+ self.popup = QtWidgets.QWidget()
1020
+ self.popup.setLayout(grid)
1021
+ self.popup.setGeometry(QtCore.QRect(100, 100, 600, 600))
1022
+ self.popup.show()
1023
+
1024
+ def initialize_settings_tab(self):
1025
+ """Populate settings tab and assign its callbacks"""
1026
+ # show information about the settings tab
1027
+ config_guide_file = open(
1028
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_GUIDE_FILE),
1029
+ "r",
1030
+ )
1031
+ config_guide_text = config_guide_file.read()
1032
+ config_guide_file.close()
1033
+ self.ui.settings_text.setText(config_guide_text)
1034
+
1035
+ # store dictionary that maps digits to rows of widgets
1036
+ # in the settings tab
1037
+ self.settings_widgets = {
1038
+ 1: StateSettings(
1039
+ digit=1,
1040
+ enabled_widget=self.ui.enable_state_1,
1041
+ name_widget=self.ui.state_name_1,
1042
+ is_scored_widget=self.ui.state_scored_1,
1043
+ frequency_widget=self.ui.state_frequency_1,
1044
+ ),
1045
+ 2: StateSettings(
1046
+ digit=2,
1047
+ enabled_widget=self.ui.enable_state_2,
1048
+ name_widget=self.ui.state_name_2,
1049
+ is_scored_widget=self.ui.state_scored_2,
1050
+ frequency_widget=self.ui.state_frequency_2,
1051
+ ),
1052
+ 3: StateSettings(
1053
+ digit=3,
1054
+ enabled_widget=self.ui.enable_state_3,
1055
+ name_widget=self.ui.state_name_3,
1056
+ is_scored_widget=self.ui.state_scored_3,
1057
+ frequency_widget=self.ui.state_frequency_3,
1058
+ ),
1059
+ 4: StateSettings(
1060
+ digit=4,
1061
+ enabled_widget=self.ui.enable_state_4,
1062
+ name_widget=self.ui.state_name_4,
1063
+ is_scored_widget=self.ui.state_scored_4,
1064
+ frequency_widget=self.ui.state_frequency_4,
1065
+ ),
1066
+ 5: StateSettings(
1067
+ digit=5,
1068
+ enabled_widget=self.ui.enable_state_5,
1069
+ name_widget=self.ui.state_name_5,
1070
+ is_scored_widget=self.ui.state_scored_5,
1071
+ frequency_widget=self.ui.state_frequency_5,
1072
+ ),
1073
+ 6: StateSettings(
1074
+ digit=6,
1075
+ enabled_widget=self.ui.enable_state_6,
1076
+ name_widget=self.ui.state_name_6,
1077
+ is_scored_widget=self.ui.state_scored_6,
1078
+ frequency_widget=self.ui.state_frequency_6,
1079
+ ),
1080
+ 7: StateSettings(
1081
+ digit=7,
1082
+ enabled_widget=self.ui.enable_state_7,
1083
+ name_widget=self.ui.state_name_7,
1084
+ is_scored_widget=self.ui.state_scored_7,
1085
+ frequency_widget=self.ui.state_frequency_7,
1086
+ ),
1087
+ 8: StateSettings(
1088
+ digit=8,
1089
+ enabled_widget=self.ui.enable_state_8,
1090
+ name_widget=self.ui.state_name_8,
1091
+ is_scored_widget=self.ui.state_scored_8,
1092
+ frequency_widget=self.ui.state_frequency_8,
1093
+ ),
1094
+ 9: StateSettings(
1095
+ digit=9,
1096
+ enabled_widget=self.ui.enable_state_9,
1097
+ name_widget=self.ui.state_name_9,
1098
+ is_scored_widget=self.ui.state_scored_9,
1099
+ frequency_widget=self.ui.state_frequency_9,
1100
+ ),
1101
+ 0: StateSettings(
1102
+ digit=0,
1103
+ enabled_widget=self.ui.enable_state_0,
1104
+ name_widget=self.ui.state_name_0,
1105
+ is_scored_widget=self.ui.state_scored_0,
1106
+ frequency_widget=self.ui.state_frequency_0,
1107
+ ),
1108
+ }
1109
+
1110
+ # update widget state to display current config
1111
+ states = {b.digit: b for b in self.brain_state_set.brain_states}
1112
+ for digit in range(10):
1113
+ if digit in states.keys():
1114
+ self.settings_widgets[digit].enabled_widget.setChecked(True)
1115
+ self.settings_widgets[digit].name_widget.setText(states[digit].name)
1116
+ self.settings_widgets[digit].is_scored_widget.setChecked(
1117
+ states[digit].is_scored
1118
+ )
1119
+ self.settings_widgets[digit].frequency_widget.setValue(
1120
+ states[digit].frequency
1121
+ )
1122
+ else:
1123
+ self.settings_widgets[digit].enabled_widget.setChecked(False)
1124
+ self.settings_widgets[digit].name_widget.setEnabled(False)
1125
+ self.settings_widgets[digit].is_scored_widget.setEnabled(False)
1126
+ self.settings_widgets[digit].frequency_widget.setEnabled(False)
1127
+
1128
+ # set callbacks
1129
+ for digit in range(10):
1130
+ state = self.settings_widgets[digit]
1131
+ state.enabled_widget.stateChanged.connect(
1132
+ partial(self.set_brain_state_enabled, digit)
1133
+ )
1134
+ state.name_widget.editingFinished.connect(self.finished_editing_state_name)
1135
+ state.is_scored_widget.stateChanged.connect(
1136
+ partial(self.is_scored_changed, digit)
1137
+ )
1138
+ state.frequency_widget.valueChanged.connect(self.state_frequency_changed)
1139
+
1140
+ def set_brain_state_enabled(self, digit, e) -> None:
1141
+ """Called when user clicks "enabled" checkbox
1142
+
1143
+ :param digit: brain state digit
1144
+ :param e: unused but mandatory
1145
+ """
1146
+ # get the widgets for this brain state
1147
+ state = self.settings_widgets[digit]
1148
+ # update state of these widgets
1149
+ is_checked = state.enabled_widget.isChecked()
1150
+ for widget in [
1151
+ state.name_widget,
1152
+ state.is_scored_widget,
1153
+ ]:
1154
+ widget.setEnabled(is_checked)
1155
+ state.frequency_widget.setEnabled(
1156
+ is_checked and state.is_scored_widget.isChecked()
1157
+ )
1158
+ if not is_checked:
1159
+ state.name_widget.setText("")
1160
+ state.frequency_widget.setValue(0)
1161
+ # check that configuration is valid
1162
+ _ = self.check_config_validity()
1163
+
1164
+ def finished_editing_state_name(self) -> None:
1165
+ """Called when user finishes editing a brain state's name"""
1166
+ _ = self.check_config_validity()
1167
+
1168
+ def state_frequency_changed(self, new_value) -> None:
1169
+ """Called when user edits a brain state's frequency
1170
+
1171
+ :param new_value: unused
1172
+ """
1173
+ _ = self.check_config_validity()
1174
+
1175
+ def is_scored_changed(self, digit, e) -> None:
1176
+ """Called when user sets whether a state is scored
1177
+
1178
+ :param digit: brain state digit
1179
+ :param e: unused, but mandatory
1180
+ """
1181
+ # get the widgets for this brain state
1182
+ state = self.settings_widgets[digit]
1183
+ # update the state of these widgets
1184
+ is_checked = state.is_scored_widget.isChecked()
1185
+ state.frequency_widget.setEnabled(is_checked)
1186
+ if not is_checked:
1187
+ state.frequency_widget.setValue(0)
1188
+ # check that configuration is valid
1189
+ _ = self.check_config_validity()
1190
+
1191
+ def check_config_validity(self) -> str | None:
1192
+ """Check if brain state configuration on screen is valid"""
1193
+ # error message, if we get one
1194
+ message = None
1195
+
1196
+ # strip whitespace from brain state names and update display
1197
+ for digit in range(10):
1198
+ state = self.settings_widgets[digit]
1199
+ current_name = state.name_widget.text()
1200
+ formatted_name = current_name.strip()
1201
+ if current_name != formatted_name:
1202
+ state.name_widget.setText(formatted_name)
1203
+
1204
+ # check if names are unique and frequencies add up to 1
1205
+ names = []
1206
+ frequencies = []
1207
+ for digit in range(10):
1208
+ state = self.settings_widgets[digit]
1209
+ if state.enabled_widget.isChecked():
1210
+ names.append(state.name_widget.text())
1211
+ frequencies.append(state.frequency_widget.value())
1212
+ if len(names) != len(set(names)):
1213
+ message = "Error: names must be unique"
1214
+ if sum(frequencies) != 1:
1215
+ message = "Error: sum(frequencies) != 1"
1216
+
1217
+ if message is not None:
1218
+ self.ui.save_config_status.setText(message)
1219
+ self.ui.save_config_button.setEnabled(False)
1220
+ return message
1221
+
1222
+ self.ui.save_config_button.setEnabled(True)
1223
+ self.ui.save_config_status.setText("")
1224
+
1225
+ def save_brain_state_config(self):
1226
+ """Save configuration to file"""
1227
+ # check that configuration is valid
1228
+ error_message = self.check_config_validity()
1229
+ if error_message is not None:
1230
+ return
1231
+
1232
+ # build a BrainStateSet object from the current configuration
1233
+ brain_states = list()
1234
+ for digit in range(10):
1235
+ state = self.settings_widgets[digit]
1236
+ if state.enabled_widget.isChecked():
1237
+ brain_states.append(
1238
+ BrainState(
1239
+ name=state.name_widget.text(),
1240
+ digit=digit,
1241
+ is_scored=state.is_scored_widget.isChecked(),
1242
+ frequency=state.frequency_widget.value(),
1243
+ )
1244
+ )
1245
+ self.brain_state_set = BrainStateSet(brain_states, UNDEFINED_LABEL)
1246
+
1247
+ # save to file
1248
+ save_config(self.brain_state_set)
1249
+ self.ui.save_config_status.setText("configuration saved")
1250
+
1251
+
1252
+ def check_label_validity(
1253
+ labels: np.array,
1254
+ samples_in_recording: int,
1255
+ sampling_rate: int | float,
1256
+ epoch_length: int | float,
1257
+ brain_state_set: BrainStateSet,
1258
+ ) -> str | None:
1259
+ """Check whether a set of brain state labels is valid
1260
+
1261
+ This returns an error message if a problem is found with the
1262
+ brain state labels.
1263
+
1264
+ :param labels: brain state labels
1265
+ :param samples_in_recording: number of samples in the recording
1266
+ :param sampling_rate: sampling rate, in Hz
1267
+ :param epoch_length: epoch length, in seconds
1268
+ :param brain_state_set: BrainStateSet object
1269
+ :return: error message
1270
+ """
1271
+ # check that number of labels is correct
1272
+ samples_per_epoch = round(sampling_rate * epoch_length)
1273
+ epochs_in_recording = round(samples_in_recording / samples_per_epoch)
1274
+ if epochs_in_recording != labels.size:
1275
+ return LABEL_LENGTH_ERROR
1276
+
1277
+ # check that entries are valid
1278
+ if not set(labels.tolist()).issubset(
1279
+ set([b.digit for b in brain_state_set.brain_states] + [UNDEFINED_LABEL])
1280
+ ):
1281
+ return "label file contains invalid entries"
1282
+
1283
+
1284
+ def check_config_consistency(
1285
+ current_brain_states: dict,
1286
+ model_brain_states: dict,
1287
+ current_epoch_length: int | float,
1288
+ model_epoch_length: int | float,
1289
+ ) -> list[str]:
1290
+ """Compare current brain state config to the model's config
1291
+
1292
+ This only displays warnings - the user should decide whether to proceed
1293
+
1294
+ :param current_brain_states: current brain state config
1295
+ :param model_brain_states: brain state config when the model was created
1296
+ :param current_epoch_length: current epoch length setting
1297
+ :param model_epoch_length: epoch length used when the model was created
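+ :return: list of warning messages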
1298
+ """
1299
+ output = list()
1300
+
1301
+ # make lists of names and digits for scored brain states
1302
+ current_scored_states = {
1303
+ f: [b[f] for b in current_brain_states if b["is_scored"]]
1304
+ for f in ["name", "digit"]
1305
+ }
1306
+ model_scored_states = {
1307
+ f: [b[f] for b in model_brain_states if b["is_scored"]]
1308
+ for f in ["name", "digit"]
1309
+ }
1310
+
1311
+ # generate message comparing the brain state configs
1312
+ config_comparisons = list()
1313
+ for config, config_name in zip(
1314
+ [current_scored_states, model_scored_states], ["current", "model's"]
1315
+ ):
1316
+ config_comparisons.append(
1317
+ f"Scored brain states in {config_name} configuration: "
1318
+ f"""{
1319
+ ", ".join(
1320
+ [
1321
+ f"{x}: {y}"
1322
+ for x, y in zip(
1323
+ config["digit"],
1324
+ config["name"],
1325
+ )
1326
+ ]
1327
+ )
1328
+ }"""
1329
+ )
1330
+
1331
+ # check if the number of scored states is different
1332
+ len_diff = len(current_scored_states["name"]) - len(model_scored_states["name"])
1333
+ if len_diff != 0:
1334
+ output.append(
1335
+ (
1336
+ "WARNING: current brain state configuration has "
1337
+ f"{'fewer' if len_diff < 0 else 'more'} "
1338
+ "scored brain states than the model's configuration."
1339
+ )
1340
+ )
1341
+ output = output + config_comparisons
1342
+ else:
1343
+ # the length is the same, but names might be different
1344
+ if current_scored_states["name"] != model_scored_states["name"]:
1345
+ output.append(
1346
+ (
1347
+ "WARNING: current brain state configuration appears "
1348
+ "to contain different brain states than "
1349
+ "the model's configuration."
1350
+ )
1351
+ )
1352
+ output = output + config_comparisons
1353
+
1354
+ if current_epoch_length != model_epoch_length:
1355
+ output.append(
1356
+ (
1357
+ "Warning: the epoch length used when training this model "
1358
+ "does not match the current epoch length setting."
1359
+ )
1360
+ )
1361
+
1362
+ return output
1363
+
1364
+
1365
+ def run_primary_window() -> None:
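+ """Create the Qt application and launch the AccuSleePy primary window"""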
1366
+ app = QtWidgets.QApplication(sys.argv)
1367
+ window = AccuSleepWindow()  # keep a reference so the window is not garbage collected
1368
+ sys.exit(app.exec())
1369
+
1370
+
1371
+ if __name__ == "__main__":
1372
+ run_primary_window()