accusleepy 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
accusleepy/gui/main.py CHANGED
@@ -36,8 +36,16 @@ from accusleepy.bouts import enforce_min_bout_length
36
36
  from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
37
37
  from accusleepy.constants import (
38
38
  ANNOTATIONS_FILENAME,
39
+ CALIBRATION_ANNOTATION_FILENAME,
39
40
  CALIBRATION_FILE_TYPE,
40
41
  DEFAULT_MODEL_TYPE,
42
+ DEFAULT_EMG_FILTER_ORDER,
43
+ DEFAULT_EMG_BP_LOWER,
44
+ DEFAULT_EMG_BP_UPPER,
45
+ DEFAULT_BATCH_SIZE,
46
+ DEFAULT_LEARNING_RATE,
47
+ DEFAULT_MOMENTUM,
48
+ DEFAULT_TRAINING_EPOCHS,
41
49
  LABEL_FILE_TYPE,
42
50
  MODEL_FILE_TYPE,
43
51
  REAL_TIME_MODEL_TYPE,
@@ -55,6 +63,8 @@ from accusleepy.fileio import (
55
63
  save_config,
56
64
  save_labels,
57
65
  save_recording_list,
66
+ EMGFilter,
67
+ Hyperparameters,
58
68
  )
59
69
  from accusleepy.gui.manual_scoring import ManualScoringWindow
60
70
  from accusleepy.gui.primary_window import Ui_PrimaryWindow
@@ -96,21 +106,31 @@ class AccuSleepWindow(QMainWindow):
96
106
  self.setWindowTitle("AccuSleePy")
97
107
 
98
108
  # fill in settings tab
99
- self.brain_state_set, self.epoch_length = load_config()
109
+ (
110
+ self.brain_state_set,
111
+ self.epoch_length,
112
+ self.only_overwrite_undefined,
113
+ self.save_confidence_scores,
114
+ self.min_bout_length,
115
+ self.emg_filter,
116
+ self.hyperparameters,
117
+ ) = load_config()
118
+
100
119
  self.settings_widgets = None
101
120
  self.initialize_settings_tab()
102
121
 
103
122
  # initialize info about the recordings, classification data / settings
104
123
  self.ui.epoch_length_input.setValue(self.epoch_length)
124
+ self.ui.overwritecheckbox.setChecked(self.only_overwrite_undefined)
125
+ self.ui.save_confidence_checkbox.setChecked(self.save_confidence_scores)
126
+ self.ui.bout_length_input.setValue(self.min_bout_length)
105
127
  self.model = None
106
- self.only_overwrite_undefined = False
107
- self.min_bout_length = 5
108
128
 
109
129
  # initialize model training variables
110
130
  self.training_epochs_per_img = 9
111
131
  self.delete_training_images = True
112
- self.training_image_dir = ""
113
132
  self.model_type = DEFAULT_MODEL_TYPE
133
+ self.calibrate_trained_model = True
114
134
 
115
135
  # metadata for the currently loaded classification model
116
136
  self.model_epoch_length = None
@@ -166,16 +186,25 @@ class AccuSleepWindow(QMainWindow):
166
186
  self.ui.load_model_button.clicked.connect(partial(self.load_model, None))
167
187
  self.ui.score_all_button.clicked.connect(self.score_all)
168
188
  self.ui.overwritecheckbox.stateChanged.connect(self.update_overwrite_policy)
189
+ self.ui.save_confidence_checkbox.stateChanged.connect(
190
+ self.update_confidence_policy
191
+ )
169
192
  self.ui.bout_length_input.valueChanged.connect(self.update_min_bout_length)
170
193
  self.ui.user_manual_button.clicked.connect(self.show_user_manual)
171
194
  self.ui.image_number_input.valueChanged.connect(self.update_epochs_per_img)
172
195
  self.ui.delete_image_box.stateChanged.connect(self.update_image_deletion)
173
- self.ui.training_folder_button.clicked.connect(self.set_training_folder)
196
+ self.ui.calibrate_checkbox.stateChanged.connect(
197
+ self.update_training_calibration
198
+ )
174
199
  self.ui.train_model_button.clicked.connect(self.train_model)
175
200
  self.ui.save_config_button.clicked.connect(self.save_brain_state_config)
176
201
  self.ui.export_button.clicked.connect(self.export_recording_list)
177
202
  self.ui.import_button.clicked.connect(self.import_recording_list)
178
203
  self.ui.default_type_button.toggled.connect(self.model_type_radio_buttons)
204
+ self.ui.reset_emg_params_button.clicked.connect(self.reset_emg_filter_settings)
205
+ self.ui.reset_hyperparams_button.clicked.connect(
206
+ self.reset_hyperparams_settings
207
+ )
179
208
 
180
209
  # user input: drag and drop
181
210
  self.ui.recording_file_label.installEventFilter(self)
@@ -294,11 +323,12 @@ class AccuSleepWindow(QMainWindow):
294
323
  )
295
324
  )
296
325
  return
297
- if self.training_image_dir == "":
298
- self.show_message(
299
- ("ERROR: no output location selected for training images.")
300
- )
301
- return
326
+
327
+ # determine fraction of training data to use for calibration
328
+ if self.calibrate_trained_model:
329
+ calibration_fraction = self.ui.calibration_spinbox.value() / 100
330
+ else:
331
+ calibration_fraction = 0
302
332
 
303
333
  # check some inputs for each recording
304
334
  for recording_index in range(len(self.recordings)):
@@ -320,9 +350,10 @@ class AccuSleepWindow(QMainWindow):
320
350
  return
321
351
  model_filename = os.path.normpath(model_filename)
322
352
 
323
- # create (probably temporary) image folder
353
+ # create (probably temporary) image folder in
354
+ # the same folder as the trained model
324
355
  temp_image_dir = os.path.join(
325
- self.training_image_dir,
356
+ os.path.dirname(model_filename),
326
357
  "images_" + datetime.datetime.now().strftime("%Y%m%d%H%M"),
327
358
  )
328
359
 
@@ -334,7 +365,12 @@ class AccuSleepWindow(QMainWindow):
334
365
 
335
366
  # create training images
336
367
  self.show_message("Training, please wait. See console for progress updates.")
337
- self.show_message((f"Creating training images in {temp_image_dir}"))
368
+ if not self.delete_training_images:
369
+ self.show_message((f"Creating training images in {temp_image_dir}"))
370
+ else:
371
+ self.show_message(
372
+ (f"Creating temporary folder of training images: {temp_image_dir}")
373
+ )
338
374
  self.ui.message_area.repaint()
339
375
  QApplication.processEvents()
340
376
  print("Creating training images")
@@ -345,14 +381,17 @@ class AccuSleepWindow(QMainWindow):
345
381
  epochs_per_img=self.training_epochs_per_img,
346
382
  brain_state_set=self.brain_state_set,
347
383
  model_type=self.model_type,
384
+ calibration_fraction=calibration_fraction,
385
+ emg_filter=self.emg_filter,
348
386
  )
349
387
  if len(failed_recordings) > 0:
350
388
  if len(failed_recordings) == len(self.recordings):
351
389
  self.show_message("ERROR: no recordings were valid!")
390
+ return
352
391
  else:
353
392
  self.show_message(
354
393
  (
355
- "WARNING: the following recordings could not be"
394
+ "WARNING: the following recordings could not be "
356
395
  "loaded and will not be used for training: "
357
396
  f"{', '.join([str(r) for r in failed_recordings])}"
358
397
  )
@@ -363,16 +402,32 @@ class AccuSleepWindow(QMainWindow):
363
402
  self.ui.message_area.repaint()
364
403
  QApplication.processEvents()
365
404
  print("Training model")
366
- from accusleepy.classification import train_ssann
405
+ from accusleepy.classification import create_dataloader, train_ssann
367
406
  from accusleepy.models import save_model
407
+ from accusleepy.temperature_scaling import ModelWithTemperature
368
408
 
369
409
  model = train_ssann(
370
410
  annotations_file=os.path.join(temp_image_dir, ANNOTATIONS_FILENAME),
371
411
  img_dir=temp_image_dir,
372
412
  mixture_weights=self.brain_state_set.mixture_weights,
373
413
  n_classes=self.brain_state_set.n_classes,
414
+ hyperparameters=self.hyperparameters,
374
415
  )
375
416
 
417
+ # calibrate the model
418
+ if self.calibrate_trained_model:
419
+ calibration_annotation_file = os.path.join(
420
+ temp_image_dir, CALIBRATION_ANNOTATION_FILENAME
421
+ )
422
+ calibration_dataloader = create_dataloader(
423
+ annotations_file=calibration_annotation_file,
424
+ img_dir=temp_image_dir,
425
+ hyperparameters=self.hyperparameters,
426
+ )
427
+ model = ModelWithTemperature(model)
428
+ print("Calibrating model")
429
+ model.set_temperature(calibration_dataloader)
430
+
376
431
  # save model
377
432
  save_model(
378
433
  model=model,
@@ -381,29 +436,26 @@ class AccuSleepWindow(QMainWindow):
381
436
  epochs_per_img=self.training_epochs_per_img,
382
437
  model_type=self.model_type,
383
438
  brain_state_set=self.brain_state_set,
439
+ is_calibrated=self.calibrate_trained_model,
384
440
  )
385
441
 
386
442
  # optionally delete images
387
443
  if self.delete_training_images:
444
+ print("Cleaning up training image folder")
388
445
  shutil.rmtree(temp_image_dir)
389
446
 
390
447
  self.show_message(f"Training complete. Saved model to {model_filename}")
391
448
  print("Training complete.")
392
449
 
393
- def set_training_folder(self) -> None:
394
- """Select location in which to create a folder for training images"""
395
- training_folder_parent = QFileDialog.getExistingDirectory(
396
- self, "Select directory for training images"
397
- )
398
- if training_folder_parent:
399
- training_folder_parent = os.path.normpath(training_folder_parent)
400
- self.training_image_dir = training_folder_parent
401
- self.ui.image_folder_label.setText(training_folder_parent)
402
-
403
450
  def update_image_deletion(self) -> None:
404
451
  """Update choice of whether to delete images after training"""
405
452
  self.delete_training_images = self.ui.delete_image_box.isChecked()
406
453
 
454
+ def update_training_calibration(self) -> None:
455
+ """Update choice of whether to calibrate model after training"""
456
+ self.calibrate_trained_model = self.ui.calibrate_checkbox.isChecked()
457
+ self.ui.calibration_spinbox.setEnabled(self.calibrate_trained_model)
458
+
407
459
  def update_epochs_per_img(self, new_value) -> None:
408
460
  """Update number of epochs per image
409
461
 
@@ -491,7 +543,8 @@ class AccuSleepWindow(QMainWindow):
491
543
  label_file = self.recordings[recording_index].label_file
492
544
  if os.path.isfile(label_file):
493
545
  try:
494
- existing_labels = load_labels(label_file)
546
+ # ignore any existing confidence scores; they will all be overwritten
547
+ existing_labels, _ = load_labels(label_file)
495
548
  except Exception:
496
549
  self.show_message(
497
550
  (
@@ -544,7 +597,7 @@ class AccuSleepWindow(QMainWindow):
544
597
  )
545
598
  continue
546
599
 
547
- labels = score_recording(
600
+ labels, confidence_scores = score_recording(
548
601
  model=self.model,
549
602
  eeg=eeg,
550
603
  emg=emg,
@@ -554,6 +607,7 @@ class AccuSleepWindow(QMainWindow):
554
607
  epoch_length=self.epoch_length,
555
608
  epochs_per_img=self.model_epochs_per_img,
556
609
  brain_state_set=self.brain_state_set,
610
+ emg_filter=self.emg_filter,
557
611
  )
558
612
 
559
613
  # overwrite as needed
@@ -569,8 +623,14 @@ class AccuSleepWindow(QMainWindow):
569
623
  min_bout_length=self.min_bout_length,
570
624
  )
571
625
 
626
+ # ignore confidence scores if desired
627
+ if not self.save_confidence_scores:
628
+ confidence_scores = None
629
+
572
630
  # save results
573
- save_labels(labels, label_file)
631
+ save_labels(
632
+ labels=labels, filename=label_file, confidence_scores=confidence_scores
633
+ )
574
634
  self.show_message(
575
635
  (
576
636
  "Saved labels for recording "
@@ -586,8 +646,6 @@ class AccuSleepWindow(QMainWindow):
586
646
 
587
647
  :param filename: model filename, if it's known
588
648
  """
589
- from accusleepy.models import load_model
590
-
591
649
  if filename is None:
592
650
  file_dialog = QFileDialog(self)
593
651
  file_dialog.setWindowTitle("Select classification model")
@@ -606,6 +664,12 @@ class AccuSleepWindow(QMainWindow):
606
664
  self.show_message("ERROR: model file does not exist")
607
665
  return
608
666
 
667
+ self.show_message("Loading classification model")
668
+ self.ui.message_area.repaint()
669
+ QApplication.processEvents()
670
+
671
+ from accusleepy.models import load_model
672
+
609
673
  try:
610
674
  model, epoch_length, epochs_per_img, model_type, brain_states = load_model(
611
675
  filename=filename
@@ -648,6 +712,8 @@ class AccuSleepWindow(QMainWindow):
648
712
  if len(config_warnings) > 0:
649
713
  for w in config_warnings:
650
714
  self.show_message(w)
715
+ else:
716
+ self.show_message(f"Loaded classification model from {filename}")
651
717
 
652
718
  self.ui.model_label.setText(filename)
653
719
 
@@ -716,7 +782,7 @@ class AccuSleepWindow(QMainWindow):
716
782
  self.show_message("ERROR: label file does not exist")
717
783
  return
718
784
  try:
719
- labels = load_labels(label_file)
785
+ labels, _ = load_labels(label_file)
720
786
  except Exception:
721
787
  self.ui.calibration_status.setText("could not load labels")
722
788
  self.show_message(
@@ -728,6 +794,7 @@ class AccuSleepWindow(QMainWindow):
728
794
  return
729
795
  label_error_message = check_label_validity(
730
796
  labels=labels,
797
+ confidence_scores=None,
731
798
  samples_in_recording=eeg.size,
732
799
  sampling_rate=sampling_rate,
733
800
  epoch_length=self.epoch_length,
@@ -758,6 +825,7 @@ class AccuSleepWindow(QMainWindow):
758
825
  sampling_rate=sampling_rate,
759
826
  epoch_length=self.epoch_length,
760
827
  brain_state_set=self.brain_state_set,
828
+ emg_filter=self.emg_filter,
761
829
  )
762
830
 
763
831
  self.ui.calibration_status.setText("")
@@ -814,6 +882,15 @@ class AccuSleepWindow(QMainWindow):
814
882
  """
815
883
  self.only_overwrite_undefined = checked
816
884
 
885
+ def update_confidence_policy(self, checked) -> None:
886
+ """Toggle policy for saving confidence scores
887
+
888
+ If the checkbox is enabled, confidence scores will be saved to the label files.
889
+
890
+ :param checked: state of the checkbox
891
+ """
892
+ self.save_confidence_scores = checked
893
+
817
894
  def manual_scoring(self) -> None:
818
895
  """View the selected recording for manual scoring"""
819
896
  # immediately display a status message
@@ -833,7 +910,7 @@ class AccuSleepWindow(QMainWindow):
833
910
  label_file = self.recordings[self.recording_index].label_file
834
911
  if os.path.isfile(label_file):
835
912
  try:
836
- labels = load_labels(label_file)
913
+ labels, confidence_scores = load_labels(label_file)
837
914
  except Exception:
838
915
  self.ui.manual_scoring_status.setText("could not load labels")
839
916
  self.show_message(
@@ -848,10 +925,14 @@ class AccuSleepWindow(QMainWindow):
848
925
  np.ones(int(eeg.size / (sampling_rate * self.epoch_length)))
849
926
  * UNDEFINED_LABEL
850
927
  ).astype(int)
928
+ # manual scoring will not add a new confidence score column
929
+ # to a label file that does not have one
930
+ confidence_scores = None
851
931
 
852
932
  # check that all labels are valid
853
933
  label_error = check_label_validity(
854
934
  labels=labels,
935
+ confidence_scores=confidence_scores,
855
936
  samples_in_recording=eeg.size,
856
937
  sampling_rate=sampling_rate,
857
938
  epoch_length=self.epoch_length,
@@ -866,6 +947,10 @@ class AccuSleepWindow(QMainWindow):
866
947
  epochs_in_recording = round(eeg.size / samples_per_epoch)
867
948
  if epochs_in_recording - labels.size == 1:
868
949
  labels = np.concatenate((labels, np.array([UNDEFINED_LABEL])))
950
+ if confidence_scores is not None:
951
+ confidence_scores = np.concatenate(
952
+ (confidence_scores, np.array([0]))
953
+ )
869
954
  self.show_message(
870
955
  (
871
956
  "WARNING: an undefined epoch was added to "
@@ -874,6 +959,8 @@ class AccuSleepWindow(QMainWindow):
874
959
  )
875
960
  elif labels.size - epochs_in_recording == 1:
876
961
  labels = labels[:-1]
962
+ if confidence_scores is not None:
963
+ confidence_scores = confidence_scores[:-1]
877
964
  self.show_message(
878
965
  (
879
966
  "WARNING: the last epoch was removed from "
@@ -900,8 +987,10 @@ class AccuSleepWindow(QMainWindow):
900
987
  emg=emg,
901
988
  label_file=label_file,
902
989
  labels=labels,
990
+ confidence_scores=confidence_scores,
903
991
  sampling_rate=sampling_rate,
904
992
  epoch_length=self.epoch_length,
993
+ emg_filter=self.emg_filter,
905
994
  )
906
995
  manual_scoring_window.setWindowTitle(f"AccuSleePy viewer: {label_file}")
907
996
  manual_scoring_window.exec()
@@ -1067,15 +1156,6 @@ class AccuSleepWindow(QMainWindow):
1067
1156
 
1068
1157
  def initialize_settings_tab(self):
1069
1158
  """Populate settings tab and assign its callbacks"""
1070
- # show information about the settings tab
1071
- config_guide_file = open(
1072
- os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_GUIDE_FILE),
1073
- "r",
1074
- )
1075
- config_guide_text = config_guide_file.read()
1076
- config_guide_file.close()
1077
- self.ui.settings_text.setText(config_guide_text)
1078
-
1079
1159
  # store dictionary that maps digits to rows of widgets
1080
1160
  # in the settings tab
1081
1161
  self.settings_widgets = {
@@ -1152,7 +1232,21 @@ class AccuSleepWindow(QMainWindow):
1152
1232
  }
1153
1233
 
1154
1234
  # update widget state to display current config
1235
+ # UI defaults
1155
1236
  self.ui.default_epoch_input.setValue(self.epoch_length)
1237
+ self.ui.overwrite_default_checkbox.setChecked(self.only_overwrite_undefined)
1238
+ self.ui.confidence_setting_checkbox.setChecked(self.save_confidence_scores)
1239
+ self.ui.default_min_bout_length_spinbox.setValue(self.min_bout_length)
1240
+ # EMG filter
1241
+ self.ui.emg_order_spinbox.setValue(self.emg_filter.order)
1242
+ self.ui.bp_lower_spinbox.setValue(self.emg_filter.bp_lower)
1243
+ self.ui.bp_upper_spinbox.setValue(self.emg_filter.bp_upper)
1244
+ # model training hyperparameters
1245
+ self.ui.batch_size_spinbox.setValue(self.hyperparameters.batch_size)
1246
+ self.ui.learning_rate_spinbox.setValue(self.hyperparameters.learning_rate)
1247
+ self.ui.momentum_spinbox.setValue(self.hyperparameters.momentum)
1248
+ self.ui.training_epochs_spinbox.setValue(self.hyperparameters.training_epochs)
1249
+ # brain states
1156
1250
  states = {b.digit: b for b in self.brain_state_set.brain_states}
1157
1251
  for digit in range(10):
1158
1252
  if digit in states.keys():
@@ -1171,6 +1265,15 @@ class AccuSleepWindow(QMainWindow):
1171
1265
  self.settings_widgets[digit].frequency_widget.setEnabled(False)
1172
1266
 
1173
1267
  # set callbacks
1268
+ self.ui.emg_order_spinbox.valueChanged.connect(self.emg_filter_order_changed)
1269
+ self.ui.bp_lower_spinbox.valueChanged.connect(self.emg_filter_bp_lower_changed)
1270
+ self.ui.bp_upper_spinbox.valueChanged.connect(self.emg_filter_bp_upper_changed)
1271
+ self.ui.batch_size_spinbox.valueChanged.connect(self.hyperparameters_changed)
1272
+ self.ui.learning_rate_spinbox.valueChanged.connect(self.hyperparameters_changed)
1273
+ self.ui.momentum_spinbox.valueChanged.connect(self.hyperparameters_changed)
1274
+ self.ui.training_epochs_spinbox.valueChanged.connect(
1275
+ self.hyperparameters_changed
1276
+ )
1174
1277
  for digit in range(10):
1175
1278
  state = self.settings_widgets[digit]
1176
1279
  state.enabled_widget.stateChanged.connect(
@@ -1233,6 +1336,41 @@ class AccuSleepWindow(QMainWindow):
1233
1336
  # check that configuration is valid
1234
1337
  _ = self.check_config_validity()
1235
1338
 
1339
+ def emg_filter_order_changed(self, new_value: int) -> None:
1340
+ """Called when user modifies EMG filter order
1341
+
1342
+ :param new_value: new EMG filter order
1343
+ """
1344
+ self.emg_filter.order = new_value
1345
+
1346
+ def emg_filter_bp_lower_changed(self, new_value: int | float) -> None:
1347
+ """Called when user modifies EMG filter lower cutoff
1348
+
1349
+ :param new_value: new lower bandpass cutoff frequency
1350
+ """
1351
+ self.emg_filter.bp_lower = new_value
1352
+ _ = self.check_config_validity()
1353
+
1354
+ def emg_filter_bp_upper_changed(self, new_value: int | float) -> None:
1355
+ """Called when user modifies EMG filter upper cutoff
1356
+
1357
+ :param new_value: new upper bandpass cutoff frequency
1358
+ """
1359
+ self.emg_filter.bp_upper = new_value
1360
+ _ = self.check_config_validity()
1361
+
1362
+ def hyperparameters_changed(self, new_value) -> None:
1363
+ """Called when user modifies model training hyperparameters
1364
+
1365
+ :param new_value: unused
1366
+ """
1367
+ self.hyperparameters = Hyperparameters(
1368
+ batch_size=self.ui.batch_size_spinbox.value(),
1369
+ learning_rate=self.ui.learning_rate_spinbox.value(),
1370
+ momentum=self.ui.momentum_spinbox.value(),
1371
+ training_epochs=self.ui.training_epochs_spinbox.value(),
1372
+ )
1373
+
1236
1374
  def check_config_validity(self) -> str:
1237
1375
  """Check if brain state configuration on screen is valid"""
1238
1376
  # error message, if we get one
@@ -1259,6 +1397,10 @@ class AccuSleepWindow(QMainWindow):
1259
1397
  if sum(frequencies) != 1:
1260
1398
  message = "Error: sum(frequencies) != 1"
1261
1399
 
1400
+ # check validity of EMG filter settings
1401
+ if self.emg_filter.bp_lower >= self.emg_filter.bp_upper:
1402
+ message = "Error: EMG filter cutoff frequencies are invalid"
1403
+
1262
1404
  if message is not None:
1263
1405
  self.ui.save_config_status.setText(message)
1264
1406
  self.ui.save_config_button.setEnabled(False)
@@ -1290,12 +1432,41 @@ class AccuSleepWindow(QMainWindow):
1290
1432
  self.brain_state_set = BrainStateSet(brain_states, UNDEFINED_LABEL)
1291
1433
 
1292
1434
  # save to file
1293
- save_config(self.brain_state_set, self.ui.default_epoch_input.value())
1435
+ save_config(
1436
+ brain_state_set=self.brain_state_set,
1437
+ default_epoch_length=self.ui.default_epoch_input.value(),
1438
+ overwrite_setting=self.ui.overwrite_default_checkbox.isChecked(),
1439
+ save_confidence_setting=self.ui.confidence_setting_checkbox.isChecked(),
1440
+ min_bout_length=self.ui.default_min_bout_length_spinbox.value(),
1441
+ emg_filter=EMGFilter(
1442
+ order=self.emg_filter.order,
1443
+ bp_lower=self.emg_filter.bp_lower,
1444
+ bp_upper=self.emg_filter.bp_upper,
1445
+ ),
1446
+ hyperparameters=Hyperparameters(
1447
+ batch_size=self.hyperparameters.batch_size,
1448
+ learning_rate=self.hyperparameters.learning_rate,
1449
+ momentum=self.hyperparameters.momentum,
1450
+ training_epochs=self.hyperparameters.training_epochs,
1451
+ ),
1452
+ )
1294
1453
  self.ui.save_config_status.setText("configuration saved")
1295
1454
 
1455
+ def reset_emg_filter_settings(self) -> None:
1456
+ self.ui.emg_order_spinbox.setValue(DEFAULT_EMG_FILTER_ORDER)
1457
+ self.ui.bp_lower_spinbox.setValue(DEFAULT_EMG_BP_LOWER)
1458
+ self.ui.bp_upper_spinbox.setValue(DEFAULT_EMG_BP_UPPER)
1459
+
1460
+ def reset_hyperparams_settings(self):
1461
+ self.ui.batch_size_spinbox.setValue(DEFAULT_BATCH_SIZE)
1462
+ self.ui.learning_rate_spinbox.setValue(DEFAULT_LEARNING_RATE)
1463
+ self.ui.momentum_spinbox.setValue(DEFAULT_MOMENTUM)
1464
+ self.ui.training_epochs_spinbox.setValue(DEFAULT_TRAINING_EPOCHS)
1465
+
1296
1466
 
1297
1467
  def check_label_validity(
1298
1468
  labels: np.array,
1469
+ confidence_scores: np.array,
1299
1470
  samples_in_recording: int,
1300
1471
  sampling_rate: int | float,
1301
1472
  epoch_length: int | float,
@@ -1307,6 +1478,7 @@ def check_label_validity(
1307
1478
  brain state labels.
1308
1479
 
1309
1480
  :param labels: brain state labels
1481
+ :param confidence_scores: confidence scores
1310
1482
  :param samples_in_recording: number of samples in the recording
1311
1483
  :param sampling_rate: sampling rate, in Hz
1312
1484
  :param epoch_length: epoch length, in seconds
@@ -1325,6 +1497,12 @@ def check_label_validity(
1325
1497
  ):
1326
1498
  return "label file contains invalid entries"
1327
1499
 
1500
+ if confidence_scores is not None:
1501
+ if np.min(confidence_scores) < 0 or np.max(confidence_scores) > 1:
1502
+ return "label file contains invalid confidence scores"
1503
+
1504
+ return None
1505
+
1328
1506
 
1329
1507
  def check_config_consistency(
1330
1508
  current_brain_states: dict,