accusleepy 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accusleepy/classification.py +49 -15
- accusleepy/config.json +15 -1
- accusleepy/constants.py +29 -2
- accusleepy/fileio.py +107 -33
- accusleepy/gui/images/primary_window.png +0 -0
- accusleepy/gui/images/viewer_window.png +0 -0
- accusleepy/gui/images/viewer_window_annotated.png +0 -0
- accusleepy/gui/main.py +220 -42
- accusleepy/gui/manual_scoring.py +38 -8
- accusleepy/gui/mplwidget.py +54 -29
- accusleepy/gui/primary_window.py +937 -254
- accusleepy/gui/primary_window.ui +3182 -2227
- accusleepy/gui/resources.qrc +1 -1
- accusleepy/gui/text/main_guide.md +18 -12
- accusleepy/gui/viewer_window.py +19 -7
- accusleepy/gui/viewer_window.ui +34 -2
- accusleepy/models.py +11 -1
- accusleepy/signal_processing.py +40 -17
- accusleepy/temperature_scaling.py +157 -0
- {accusleepy-0.5.0.dist-info → accusleepy-0.7.0.dist-info}/METADATA +11 -2
- accusleepy-0.7.0.dist-info/RECORD +41 -0
- {accusleepy-0.5.0.dist-info → accusleepy-0.7.0.dist-info}/WHEEL +1 -1
- accusleepy/gui/text/config_guide.txt +0 -29
- accusleepy-0.5.0.dist-info/RECORD +0 -41
accusleepy/gui/main.py
CHANGED
|
@@ -36,8 +36,16 @@ from accusleepy.bouts import enforce_min_bout_length
|
|
|
36
36
|
from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
|
|
37
37
|
from accusleepy.constants import (
|
|
38
38
|
ANNOTATIONS_FILENAME,
|
|
39
|
+
CALIBRATION_ANNOTATION_FILENAME,
|
|
39
40
|
CALIBRATION_FILE_TYPE,
|
|
40
41
|
DEFAULT_MODEL_TYPE,
|
|
42
|
+
DEFAULT_EMG_FILTER_ORDER,
|
|
43
|
+
DEFAULT_EMG_BP_LOWER,
|
|
44
|
+
DEFAULT_EMG_BP_UPPER,
|
|
45
|
+
DEFAULT_BATCH_SIZE,
|
|
46
|
+
DEFAULT_LEARNING_RATE,
|
|
47
|
+
DEFAULT_MOMENTUM,
|
|
48
|
+
DEFAULT_TRAINING_EPOCHS,
|
|
41
49
|
LABEL_FILE_TYPE,
|
|
42
50
|
MODEL_FILE_TYPE,
|
|
43
51
|
REAL_TIME_MODEL_TYPE,
|
|
@@ -55,6 +63,8 @@ from accusleepy.fileio import (
|
|
|
55
63
|
save_config,
|
|
56
64
|
save_labels,
|
|
57
65
|
save_recording_list,
|
|
66
|
+
EMGFilter,
|
|
67
|
+
Hyperparameters,
|
|
58
68
|
)
|
|
59
69
|
from accusleepy.gui.manual_scoring import ManualScoringWindow
|
|
60
70
|
from accusleepy.gui.primary_window import Ui_PrimaryWindow
|
|
@@ -96,21 +106,31 @@ class AccuSleepWindow(QMainWindow):
|
|
|
96
106
|
self.setWindowTitle("AccuSleePy")
|
|
97
107
|
|
|
98
108
|
# fill in settings tab
|
|
99
|
-
|
|
109
|
+
(
|
|
110
|
+
self.brain_state_set,
|
|
111
|
+
self.epoch_length,
|
|
112
|
+
self.only_overwrite_undefined,
|
|
113
|
+
self.save_confidence_scores,
|
|
114
|
+
self.min_bout_length,
|
|
115
|
+
self.emg_filter,
|
|
116
|
+
self.hyperparameters,
|
|
117
|
+
) = load_config()
|
|
118
|
+
|
|
100
119
|
self.settings_widgets = None
|
|
101
120
|
self.initialize_settings_tab()
|
|
102
121
|
|
|
103
122
|
# initialize info about the recordings, classification data / settings
|
|
104
123
|
self.ui.epoch_length_input.setValue(self.epoch_length)
|
|
124
|
+
self.ui.overwritecheckbox.setChecked(self.only_overwrite_undefined)
|
|
125
|
+
self.ui.save_confidence_checkbox.setChecked(self.save_confidence_scores)
|
|
126
|
+
self.ui.bout_length_input.setValue(self.min_bout_length)
|
|
105
127
|
self.model = None
|
|
106
|
-
self.only_overwrite_undefined = False
|
|
107
|
-
self.min_bout_length = 5
|
|
108
128
|
|
|
109
129
|
# initialize model training variables
|
|
110
130
|
self.training_epochs_per_img = 9
|
|
111
131
|
self.delete_training_images = True
|
|
112
|
-
self.training_image_dir = ""
|
|
113
132
|
self.model_type = DEFAULT_MODEL_TYPE
|
|
133
|
+
self.calibrate_trained_model = True
|
|
114
134
|
|
|
115
135
|
# metadata for the currently loaded classification model
|
|
116
136
|
self.model_epoch_length = None
|
|
@@ -166,16 +186,25 @@ class AccuSleepWindow(QMainWindow):
|
|
|
166
186
|
self.ui.load_model_button.clicked.connect(partial(self.load_model, None))
|
|
167
187
|
self.ui.score_all_button.clicked.connect(self.score_all)
|
|
168
188
|
self.ui.overwritecheckbox.stateChanged.connect(self.update_overwrite_policy)
|
|
189
|
+
self.ui.save_confidence_checkbox.stateChanged.connect(
|
|
190
|
+
self.update_confidence_policy
|
|
191
|
+
)
|
|
169
192
|
self.ui.bout_length_input.valueChanged.connect(self.update_min_bout_length)
|
|
170
193
|
self.ui.user_manual_button.clicked.connect(self.show_user_manual)
|
|
171
194
|
self.ui.image_number_input.valueChanged.connect(self.update_epochs_per_img)
|
|
172
195
|
self.ui.delete_image_box.stateChanged.connect(self.update_image_deletion)
|
|
173
|
-
self.ui.
|
|
196
|
+
self.ui.calibrate_checkbox.stateChanged.connect(
|
|
197
|
+
self.update_training_calibration
|
|
198
|
+
)
|
|
174
199
|
self.ui.train_model_button.clicked.connect(self.train_model)
|
|
175
200
|
self.ui.save_config_button.clicked.connect(self.save_brain_state_config)
|
|
176
201
|
self.ui.export_button.clicked.connect(self.export_recording_list)
|
|
177
202
|
self.ui.import_button.clicked.connect(self.import_recording_list)
|
|
178
203
|
self.ui.default_type_button.toggled.connect(self.model_type_radio_buttons)
|
|
204
|
+
self.ui.reset_emg_params_button.clicked.connect(self.reset_emg_filter_settings)
|
|
205
|
+
self.ui.reset_hyperparams_button.clicked.connect(
|
|
206
|
+
self.reset_hyperparams_settings
|
|
207
|
+
)
|
|
179
208
|
|
|
180
209
|
# user input: drag and drop
|
|
181
210
|
self.ui.recording_file_label.installEventFilter(self)
|
|
@@ -294,11 +323,12 @@ class AccuSleepWindow(QMainWindow):
|
|
|
294
323
|
)
|
|
295
324
|
)
|
|
296
325
|
return
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
)
|
|
301
|
-
|
|
326
|
+
|
|
327
|
+
# determine fraction of training data to use for calibration
|
|
328
|
+
if self.calibrate_trained_model:
|
|
329
|
+
calibration_fraction = self.ui.calibration_spinbox.value() / 100
|
|
330
|
+
else:
|
|
331
|
+
calibration_fraction = 0
|
|
302
332
|
|
|
303
333
|
# check some inputs for each recording
|
|
304
334
|
for recording_index in range(len(self.recordings)):
|
|
@@ -320,9 +350,10 @@ class AccuSleepWindow(QMainWindow):
|
|
|
320
350
|
return
|
|
321
351
|
model_filename = os.path.normpath(model_filename)
|
|
322
352
|
|
|
323
|
-
# create (probably temporary) image folder
|
|
353
|
+
# create (probably temporary) image folder in
|
|
354
|
+
# the same folder as the trained model
|
|
324
355
|
temp_image_dir = os.path.join(
|
|
325
|
-
|
|
356
|
+
os.path.dirname(model_filename),
|
|
326
357
|
"images_" + datetime.datetime.now().strftime("%Y%m%d%H%M"),
|
|
327
358
|
)
|
|
328
359
|
|
|
@@ -334,7 +365,12 @@ class AccuSleepWindow(QMainWindow):
|
|
|
334
365
|
|
|
335
366
|
# create training images
|
|
336
367
|
self.show_message("Training, please wait. See console for progress updates.")
|
|
337
|
-
self.
|
|
368
|
+
if not self.delete_training_images:
|
|
369
|
+
self.show_message((f"Creating training images in {temp_image_dir}"))
|
|
370
|
+
else:
|
|
371
|
+
self.show_message(
|
|
372
|
+
(f"Creating temporary folder of training images: {temp_image_dir}")
|
|
373
|
+
)
|
|
338
374
|
self.ui.message_area.repaint()
|
|
339
375
|
QApplication.processEvents()
|
|
340
376
|
print("Creating training images")
|
|
@@ -345,14 +381,17 @@ class AccuSleepWindow(QMainWindow):
|
|
|
345
381
|
epochs_per_img=self.training_epochs_per_img,
|
|
346
382
|
brain_state_set=self.brain_state_set,
|
|
347
383
|
model_type=self.model_type,
|
|
384
|
+
calibration_fraction=calibration_fraction,
|
|
385
|
+
emg_filter=self.emg_filter,
|
|
348
386
|
)
|
|
349
387
|
if len(failed_recordings) > 0:
|
|
350
388
|
if len(failed_recordings) == len(self.recordings):
|
|
351
389
|
self.show_message("ERROR: no recordings were valid!")
|
|
390
|
+
return
|
|
352
391
|
else:
|
|
353
392
|
self.show_message(
|
|
354
393
|
(
|
|
355
|
-
"WARNING: the following recordings could not be"
|
|
394
|
+
"WARNING: the following recordings could not be "
|
|
356
395
|
"loaded and will not be used for training: "
|
|
357
396
|
f"{', '.join([str(r) for r in failed_recordings])}"
|
|
358
397
|
)
|
|
@@ -363,16 +402,32 @@ class AccuSleepWindow(QMainWindow):
|
|
|
363
402
|
self.ui.message_area.repaint()
|
|
364
403
|
QApplication.processEvents()
|
|
365
404
|
print("Training model")
|
|
366
|
-
from accusleepy.classification import train_ssann
|
|
405
|
+
from accusleepy.classification import create_dataloader, train_ssann
|
|
367
406
|
from accusleepy.models import save_model
|
|
407
|
+
from accusleepy.temperature_scaling import ModelWithTemperature
|
|
368
408
|
|
|
369
409
|
model = train_ssann(
|
|
370
410
|
annotations_file=os.path.join(temp_image_dir, ANNOTATIONS_FILENAME),
|
|
371
411
|
img_dir=temp_image_dir,
|
|
372
412
|
mixture_weights=self.brain_state_set.mixture_weights,
|
|
373
413
|
n_classes=self.brain_state_set.n_classes,
|
|
414
|
+
hyperparameters=self.hyperparameters,
|
|
374
415
|
)
|
|
375
416
|
|
|
417
|
+
# calibrate the model
|
|
418
|
+
if self.calibrate_trained_model:
|
|
419
|
+
calibration_annotation_file = os.path.join(
|
|
420
|
+
temp_image_dir, CALIBRATION_ANNOTATION_FILENAME
|
|
421
|
+
)
|
|
422
|
+
calibration_dataloader = create_dataloader(
|
|
423
|
+
annotations_file=calibration_annotation_file,
|
|
424
|
+
img_dir=temp_image_dir,
|
|
425
|
+
hyperparameters=self.hyperparameters,
|
|
426
|
+
)
|
|
427
|
+
model = ModelWithTemperature(model)
|
|
428
|
+
print("Calibrating model")
|
|
429
|
+
model.set_temperature(calibration_dataloader)
|
|
430
|
+
|
|
376
431
|
# save model
|
|
377
432
|
save_model(
|
|
378
433
|
model=model,
|
|
@@ -381,29 +436,26 @@ class AccuSleepWindow(QMainWindow):
|
|
|
381
436
|
epochs_per_img=self.training_epochs_per_img,
|
|
382
437
|
model_type=self.model_type,
|
|
383
438
|
brain_state_set=self.brain_state_set,
|
|
439
|
+
is_calibrated=self.calibrate_trained_model,
|
|
384
440
|
)
|
|
385
441
|
|
|
386
442
|
# optionally delete images
|
|
387
443
|
if self.delete_training_images:
|
|
444
|
+
print("Cleaning up training image folder")
|
|
388
445
|
shutil.rmtree(temp_image_dir)
|
|
389
446
|
|
|
390
447
|
self.show_message(f"Training complete. Saved model to {model_filename}")
|
|
391
448
|
print("Training complete.")
|
|
392
449
|
|
|
393
|
-
def set_training_folder(self) -> None:
|
|
394
|
-
"""Select location in which to create a folder for training images"""
|
|
395
|
-
training_folder_parent = QFileDialog.getExistingDirectory(
|
|
396
|
-
self, "Select directory for training images"
|
|
397
|
-
)
|
|
398
|
-
if training_folder_parent:
|
|
399
|
-
training_folder_parent = os.path.normpath(training_folder_parent)
|
|
400
|
-
self.training_image_dir = training_folder_parent
|
|
401
|
-
self.ui.image_folder_label.setText(training_folder_parent)
|
|
402
|
-
|
|
403
450
|
def update_image_deletion(self) -> None:
|
|
404
451
|
"""Update choice of whether to delete images after training"""
|
|
405
452
|
self.delete_training_images = self.ui.delete_image_box.isChecked()
|
|
406
453
|
|
|
454
|
+
def update_training_calibration(self) -> None:
|
|
455
|
+
"""Update choice of whether to calibrate model after training"""
|
|
456
|
+
self.calibrate_trained_model = self.ui.calibrate_checkbox.isChecked()
|
|
457
|
+
self.ui.calibration_spinbox.setEnabled(self.calibrate_trained_model)
|
|
458
|
+
|
|
407
459
|
def update_epochs_per_img(self, new_value) -> None:
|
|
408
460
|
"""Update number of epochs per image
|
|
409
461
|
|
|
@@ -491,7 +543,8 @@ class AccuSleepWindow(QMainWindow):
|
|
|
491
543
|
label_file = self.recordings[recording_index].label_file
|
|
492
544
|
if os.path.isfile(label_file):
|
|
493
545
|
try:
|
|
494
|
-
|
|
546
|
+
# ignore any existing confidence scores; they will all be overwritten
|
|
547
|
+
existing_labels, _ = load_labels(label_file)
|
|
495
548
|
except Exception:
|
|
496
549
|
self.show_message(
|
|
497
550
|
(
|
|
@@ -544,7 +597,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
544
597
|
)
|
|
545
598
|
continue
|
|
546
599
|
|
|
547
|
-
labels = score_recording(
|
|
600
|
+
labels, confidence_scores = score_recording(
|
|
548
601
|
model=self.model,
|
|
549
602
|
eeg=eeg,
|
|
550
603
|
emg=emg,
|
|
@@ -554,6 +607,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
554
607
|
epoch_length=self.epoch_length,
|
|
555
608
|
epochs_per_img=self.model_epochs_per_img,
|
|
556
609
|
brain_state_set=self.brain_state_set,
|
|
610
|
+
emg_filter=self.emg_filter,
|
|
557
611
|
)
|
|
558
612
|
|
|
559
613
|
# overwrite as needed
|
|
@@ -569,8 +623,14 @@ class AccuSleepWindow(QMainWindow):
|
|
|
569
623
|
min_bout_length=self.min_bout_length,
|
|
570
624
|
)
|
|
571
625
|
|
|
626
|
+
# ignore confidence scores if desired
|
|
627
|
+
if not self.save_confidence_scores:
|
|
628
|
+
confidence_scores = None
|
|
629
|
+
|
|
572
630
|
# save results
|
|
573
|
-
save_labels(
|
|
631
|
+
save_labels(
|
|
632
|
+
labels=labels, filename=label_file, confidence_scores=confidence_scores
|
|
633
|
+
)
|
|
574
634
|
self.show_message(
|
|
575
635
|
(
|
|
576
636
|
"Saved labels for recording "
|
|
@@ -586,8 +646,6 @@ class AccuSleepWindow(QMainWindow):
|
|
|
586
646
|
|
|
587
647
|
:param filename: model filename, if it's known
|
|
588
648
|
"""
|
|
589
|
-
from accusleepy.models import load_model
|
|
590
|
-
|
|
591
649
|
if filename is None:
|
|
592
650
|
file_dialog = QFileDialog(self)
|
|
593
651
|
file_dialog.setWindowTitle("Select classification model")
|
|
@@ -606,6 +664,12 @@ class AccuSleepWindow(QMainWindow):
|
|
|
606
664
|
self.show_message("ERROR: model file does not exist")
|
|
607
665
|
return
|
|
608
666
|
|
|
667
|
+
self.show_message("Loading classification model")
|
|
668
|
+
self.ui.message_area.repaint()
|
|
669
|
+
QApplication.processEvents()
|
|
670
|
+
|
|
671
|
+
from accusleepy.models import load_model
|
|
672
|
+
|
|
609
673
|
try:
|
|
610
674
|
model, epoch_length, epochs_per_img, model_type, brain_states = load_model(
|
|
611
675
|
filename=filename
|
|
@@ -648,6 +712,8 @@ class AccuSleepWindow(QMainWindow):
|
|
|
648
712
|
if len(config_warnings) > 0:
|
|
649
713
|
for w in config_warnings:
|
|
650
714
|
self.show_message(w)
|
|
715
|
+
else:
|
|
716
|
+
self.show_message(f"Loaded classification model from {filename}")
|
|
651
717
|
|
|
652
718
|
self.ui.model_label.setText(filename)
|
|
653
719
|
|
|
@@ -716,7 +782,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
716
782
|
self.show_message("ERROR: label file does not exist")
|
|
717
783
|
return
|
|
718
784
|
try:
|
|
719
|
-
labels = load_labels(label_file)
|
|
785
|
+
labels, _ = load_labels(label_file)
|
|
720
786
|
except Exception:
|
|
721
787
|
self.ui.calibration_status.setText("could not load labels")
|
|
722
788
|
self.show_message(
|
|
@@ -728,6 +794,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
728
794
|
return
|
|
729
795
|
label_error_message = check_label_validity(
|
|
730
796
|
labels=labels,
|
|
797
|
+
confidence_scores=None,
|
|
731
798
|
samples_in_recording=eeg.size,
|
|
732
799
|
sampling_rate=sampling_rate,
|
|
733
800
|
epoch_length=self.epoch_length,
|
|
@@ -758,6 +825,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
758
825
|
sampling_rate=sampling_rate,
|
|
759
826
|
epoch_length=self.epoch_length,
|
|
760
827
|
brain_state_set=self.brain_state_set,
|
|
828
|
+
emg_filter=self.emg_filter,
|
|
761
829
|
)
|
|
762
830
|
|
|
763
831
|
self.ui.calibration_status.setText("")
|
|
@@ -814,6 +882,15 @@ class AccuSleepWindow(QMainWindow):
|
|
|
814
882
|
"""
|
|
815
883
|
self.only_overwrite_undefined = checked
|
|
816
884
|
|
|
885
|
+
def update_confidence_policy(self, checked) -> None:
|
|
886
|
+
"""Toggle policy for saving confidence scores
|
|
887
|
+
|
|
888
|
+
If the checkbox is enabled, confidence scores will be saved to the label files.
|
|
889
|
+
|
|
890
|
+
:param checked: state of the checkbox
|
|
891
|
+
"""
|
|
892
|
+
self.save_confidence_scores = checked
|
|
893
|
+
|
|
817
894
|
def manual_scoring(self) -> None:
|
|
818
895
|
"""View the selected recording for manual scoring"""
|
|
819
896
|
# immediately display a status message
|
|
@@ -833,7 +910,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
833
910
|
label_file = self.recordings[self.recording_index].label_file
|
|
834
911
|
if os.path.isfile(label_file):
|
|
835
912
|
try:
|
|
836
|
-
labels = load_labels(label_file)
|
|
913
|
+
labels, confidence_scores = load_labels(label_file)
|
|
837
914
|
except Exception:
|
|
838
915
|
self.ui.manual_scoring_status.setText("could not load labels")
|
|
839
916
|
self.show_message(
|
|
@@ -848,10 +925,14 @@ class AccuSleepWindow(QMainWindow):
|
|
|
848
925
|
np.ones(int(eeg.size / (sampling_rate * self.epoch_length)))
|
|
849
926
|
* UNDEFINED_LABEL
|
|
850
927
|
).astype(int)
|
|
928
|
+
# manual scoring will not add a new confidence score column
|
|
929
|
+
# to a label file that does not have one
|
|
930
|
+
confidence_scores = None
|
|
851
931
|
|
|
852
932
|
# check that all labels are valid
|
|
853
933
|
label_error = check_label_validity(
|
|
854
934
|
labels=labels,
|
|
935
|
+
confidence_scores=confidence_scores,
|
|
855
936
|
samples_in_recording=eeg.size,
|
|
856
937
|
sampling_rate=sampling_rate,
|
|
857
938
|
epoch_length=self.epoch_length,
|
|
@@ -866,6 +947,10 @@ class AccuSleepWindow(QMainWindow):
|
|
|
866
947
|
epochs_in_recording = round(eeg.size / samples_per_epoch)
|
|
867
948
|
if epochs_in_recording - labels.size == 1:
|
|
868
949
|
labels = np.concatenate((labels, np.array([UNDEFINED_LABEL])))
|
|
950
|
+
if confidence_scores is not None:
|
|
951
|
+
confidence_scores = np.concatenate(
|
|
952
|
+
(confidence_scores, np.array([0]))
|
|
953
|
+
)
|
|
869
954
|
self.show_message(
|
|
870
955
|
(
|
|
871
956
|
"WARNING: an undefined epoch was added to "
|
|
@@ -874,6 +959,8 @@ class AccuSleepWindow(QMainWindow):
|
|
|
874
959
|
)
|
|
875
960
|
elif labels.size - epochs_in_recording == 1:
|
|
876
961
|
labels = labels[:-1]
|
|
962
|
+
if confidence_scores is not None:
|
|
963
|
+
confidence_scores = confidence_scores[:-1]
|
|
877
964
|
self.show_message(
|
|
878
965
|
(
|
|
879
966
|
"WARNING: the last epoch was removed from "
|
|
@@ -900,8 +987,10 @@ class AccuSleepWindow(QMainWindow):
|
|
|
900
987
|
emg=emg,
|
|
901
988
|
label_file=label_file,
|
|
902
989
|
labels=labels,
|
|
990
|
+
confidence_scores=confidence_scores,
|
|
903
991
|
sampling_rate=sampling_rate,
|
|
904
992
|
epoch_length=self.epoch_length,
|
|
993
|
+
emg_filter=self.emg_filter,
|
|
905
994
|
)
|
|
906
995
|
manual_scoring_window.setWindowTitle(f"AccuSleePy viewer: {label_file}")
|
|
907
996
|
manual_scoring_window.exec()
|
|
@@ -1067,15 +1156,6 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1067
1156
|
|
|
1068
1157
|
def initialize_settings_tab(self):
|
|
1069
1158
|
"""Populate settings tab and assign its callbacks"""
|
|
1070
|
-
# show information about the settings tab
|
|
1071
|
-
config_guide_file = open(
|
|
1072
|
-
os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_GUIDE_FILE),
|
|
1073
|
-
"r",
|
|
1074
|
-
)
|
|
1075
|
-
config_guide_text = config_guide_file.read()
|
|
1076
|
-
config_guide_file.close()
|
|
1077
|
-
self.ui.settings_text.setText(config_guide_text)
|
|
1078
|
-
|
|
1079
1159
|
# store dictionary that maps digits to rows of widgets
|
|
1080
1160
|
# in the settings tab
|
|
1081
1161
|
self.settings_widgets = {
|
|
@@ -1152,7 +1232,21 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1152
1232
|
}
|
|
1153
1233
|
|
|
1154
1234
|
# update widget state to display current config
|
|
1235
|
+
# UI defaults
|
|
1155
1236
|
self.ui.default_epoch_input.setValue(self.epoch_length)
|
|
1237
|
+
self.ui.overwrite_default_checkbox.setChecked(self.only_overwrite_undefined)
|
|
1238
|
+
self.ui.confidence_setting_checkbox.setChecked(self.save_confidence_scores)
|
|
1239
|
+
self.ui.default_min_bout_length_spinbox.setValue(self.min_bout_length)
|
|
1240
|
+
# EMG filter
|
|
1241
|
+
self.ui.emg_order_spinbox.setValue(self.emg_filter.order)
|
|
1242
|
+
self.ui.bp_lower_spinbox.setValue(self.emg_filter.bp_lower)
|
|
1243
|
+
self.ui.bp_upper_spinbox.setValue(self.emg_filter.bp_upper)
|
|
1244
|
+
# model training hyperparameters
|
|
1245
|
+
self.ui.batch_size_spinbox.setValue(self.hyperparameters.batch_size)
|
|
1246
|
+
self.ui.learning_rate_spinbox.setValue(self.hyperparameters.learning_rate)
|
|
1247
|
+
self.ui.momentum_spinbox.setValue(self.hyperparameters.momentum)
|
|
1248
|
+
self.ui.training_epochs_spinbox.setValue(self.hyperparameters.training_epochs)
|
|
1249
|
+
# brain states
|
|
1156
1250
|
states = {b.digit: b for b in self.brain_state_set.brain_states}
|
|
1157
1251
|
for digit in range(10):
|
|
1158
1252
|
if digit in states.keys():
|
|
@@ -1171,6 +1265,15 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1171
1265
|
self.settings_widgets[digit].frequency_widget.setEnabled(False)
|
|
1172
1266
|
|
|
1173
1267
|
# set callbacks
|
|
1268
|
+
self.ui.emg_order_spinbox.valueChanged.connect(self.emg_filter_order_changed)
|
|
1269
|
+
self.ui.bp_lower_spinbox.valueChanged.connect(self.emg_filter_bp_lower_changed)
|
|
1270
|
+
self.ui.bp_upper_spinbox.valueChanged.connect(self.emg_filter_bp_upper_changed)
|
|
1271
|
+
self.ui.batch_size_spinbox.valueChanged.connect(self.hyperparameters_changed)
|
|
1272
|
+
self.ui.learning_rate_spinbox.valueChanged.connect(self.hyperparameters_changed)
|
|
1273
|
+
self.ui.momentum_spinbox.valueChanged.connect(self.hyperparameters_changed)
|
|
1274
|
+
self.ui.training_epochs_spinbox.valueChanged.connect(
|
|
1275
|
+
self.hyperparameters_changed
|
|
1276
|
+
)
|
|
1174
1277
|
for digit in range(10):
|
|
1175
1278
|
state = self.settings_widgets[digit]
|
|
1176
1279
|
state.enabled_widget.stateChanged.connect(
|
|
@@ -1233,6 +1336,41 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1233
1336
|
# check that configuration is valid
|
|
1234
1337
|
_ = self.check_config_validity()
|
|
1235
1338
|
|
|
1339
|
+
def emg_filter_order_changed(self, new_value: int) -> None:
|
|
1340
|
+
"""Called when user modifies EMG filter order
|
|
1341
|
+
|
|
1342
|
+
:param new_value: new EMG filter order
|
|
1343
|
+
"""
|
|
1344
|
+
self.emg_filter.order = new_value
|
|
1345
|
+
|
|
1346
|
+
def emg_filter_bp_lower_changed(self, new_value: int | float) -> None:
|
|
1347
|
+
"""Called when user modifies EMG filter lower cutoff
|
|
1348
|
+
|
|
1349
|
+
:param new_value: new lower bandpass cutoff frequency
|
|
1350
|
+
"""
|
|
1351
|
+
self.emg_filter.bp_lower = new_value
|
|
1352
|
+
_ = self.check_config_validity()
|
|
1353
|
+
|
|
1354
|
+
def emg_filter_bp_upper_changed(self, new_value: int | float) -> None:
|
|
1355
|
+
"""Called when user modifies EMG filter upper cutoff
|
|
1356
|
+
|
|
1357
|
+
:param new_value: new upper bandpass cutoff frequency
|
|
1358
|
+
"""
|
|
1359
|
+
self.emg_filter.bp_upper = new_value
|
|
1360
|
+
_ = self.check_config_validity()
|
|
1361
|
+
|
|
1362
|
+
def hyperparameters_changed(self, new_value) -> None:
|
|
1363
|
+
"""Called when user modifies model training hyperparameters
|
|
1364
|
+
|
|
1365
|
+
:param new_value: unused
|
|
1366
|
+
"""
|
|
1367
|
+
self.hyperparameters = Hyperparameters(
|
|
1368
|
+
batch_size=self.ui.batch_size_spinbox.value(),
|
|
1369
|
+
learning_rate=self.ui.learning_rate_spinbox.value(),
|
|
1370
|
+
momentum=self.ui.momentum_spinbox.value(),
|
|
1371
|
+
training_epochs=self.ui.training_epochs_spinbox.value(),
|
|
1372
|
+
)
|
|
1373
|
+
|
|
1236
1374
|
def check_config_validity(self) -> str:
|
|
1237
1375
|
"""Check if brain state configuration on screen is valid"""
|
|
1238
1376
|
# error message, if we get one
|
|
@@ -1259,6 +1397,10 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1259
1397
|
if sum(frequencies) != 1:
|
|
1260
1398
|
message = "Error: sum(frequencies) != 1"
|
|
1261
1399
|
|
|
1400
|
+
# check validity of EMG filter settings
|
|
1401
|
+
if self.emg_filter.bp_lower >= self.emg_filter.bp_upper:
|
|
1402
|
+
message = "Error: EMG filter cutoff frequencies are invalid"
|
|
1403
|
+
|
|
1262
1404
|
if message is not None:
|
|
1263
1405
|
self.ui.save_config_status.setText(message)
|
|
1264
1406
|
self.ui.save_config_button.setEnabled(False)
|
|
@@ -1290,12 +1432,41 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1290
1432
|
self.brain_state_set = BrainStateSet(brain_states, UNDEFINED_LABEL)
|
|
1291
1433
|
|
|
1292
1434
|
# save to file
|
|
1293
|
-
save_config(
|
|
1435
|
+
save_config(
|
|
1436
|
+
brain_state_set=self.brain_state_set,
|
|
1437
|
+
default_epoch_length=self.ui.default_epoch_input.value(),
|
|
1438
|
+
overwrite_setting=self.ui.overwrite_default_checkbox.isChecked(),
|
|
1439
|
+
save_confidence_setting=self.ui.confidence_setting_checkbox.isChecked(),
|
|
1440
|
+
min_bout_length=self.ui.default_min_bout_length_spinbox.value(),
|
|
1441
|
+
emg_filter=EMGFilter(
|
|
1442
|
+
order=self.emg_filter.order,
|
|
1443
|
+
bp_lower=self.emg_filter.bp_lower,
|
|
1444
|
+
bp_upper=self.emg_filter.bp_upper,
|
|
1445
|
+
),
|
|
1446
|
+
hyperparameters=Hyperparameters(
|
|
1447
|
+
batch_size=self.hyperparameters.batch_size,
|
|
1448
|
+
learning_rate=self.hyperparameters.learning_rate,
|
|
1449
|
+
momentum=self.hyperparameters.momentum,
|
|
1450
|
+
training_epochs=self.hyperparameters.training_epochs,
|
|
1451
|
+
),
|
|
1452
|
+
)
|
|
1294
1453
|
self.ui.save_config_status.setText("configuration saved")
|
|
1295
1454
|
|
|
1455
|
+
def reset_emg_filter_settings(self) -> None:
|
|
1456
|
+
self.ui.emg_order_spinbox.setValue(DEFAULT_EMG_FILTER_ORDER)
|
|
1457
|
+
self.ui.bp_lower_spinbox.setValue(DEFAULT_EMG_BP_LOWER)
|
|
1458
|
+
self.ui.bp_upper_spinbox.setValue(DEFAULT_EMG_BP_UPPER)
|
|
1459
|
+
|
|
1460
|
+
def reset_hyperparams_settings(self):
|
|
1461
|
+
self.ui.batch_size_spinbox.setValue(DEFAULT_BATCH_SIZE)
|
|
1462
|
+
self.ui.learning_rate_spinbox.setValue(DEFAULT_LEARNING_RATE)
|
|
1463
|
+
self.ui.momentum_spinbox.setValue(DEFAULT_MOMENTUM)
|
|
1464
|
+
self.ui.training_epochs_spinbox.setValue(DEFAULT_TRAINING_EPOCHS)
|
|
1465
|
+
|
|
1296
1466
|
|
|
1297
1467
|
def check_label_validity(
|
|
1298
1468
|
labels: np.array,
|
|
1469
|
+
confidence_scores: np.array,
|
|
1299
1470
|
samples_in_recording: int,
|
|
1300
1471
|
sampling_rate: int | float,
|
|
1301
1472
|
epoch_length: int | float,
|
|
@@ -1307,6 +1478,7 @@ def check_label_validity(
|
|
|
1307
1478
|
brain state labels.
|
|
1308
1479
|
|
|
1309
1480
|
:param labels: brain state labels
|
|
1481
|
+
:param confidence_scores: confidence scores
|
|
1310
1482
|
:param samples_in_recording: number of samples in the recording
|
|
1311
1483
|
:param sampling_rate: sampling rate, in Hz
|
|
1312
1484
|
:param epoch_length: epoch length, in seconds
|
|
@@ -1325,6 +1497,12 @@ def check_label_validity(
|
|
|
1325
1497
|
):
|
|
1326
1498
|
return "label file contains invalid entries"
|
|
1327
1499
|
|
|
1500
|
+
if confidence_scores is not None:
|
|
1501
|
+
if np.min(confidence_scores) < 0 or np.max(confidence_scores) > 1:
|
|
1502
|
+
return "label file contains invalid confidence scores"
|
|
1503
|
+
|
|
1504
|
+
return None
|
|
1505
|
+
|
|
1328
1506
|
|
|
1329
1507
|
def check_config_consistency(
|
|
1330
1508
|
current_brain_states: dict,
|