accusleepy 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
accusleepy/gui/main.py CHANGED
@@ -39,7 +39,15 @@ from accusleepy.constants import (
39
39
  CALIBRATION_ANNOTATION_FILENAME,
40
40
  CALIBRATION_FILE_TYPE,
41
41
  DEFAULT_MODEL_TYPE,
42
+ DEFAULT_EMG_FILTER_ORDER,
43
+ DEFAULT_EMG_BP_LOWER,
44
+ DEFAULT_EMG_BP_UPPER,
45
+ DEFAULT_BATCH_SIZE,
46
+ DEFAULT_LEARNING_RATE,
47
+ DEFAULT_MOMENTUM,
48
+ DEFAULT_TRAINING_EPOCHS,
42
49
  LABEL_FILE_TYPE,
50
+ MESSAGE_BOX_MAX_DEPTH,
43
51
  MODEL_FILE_TYPE,
44
52
  REAL_TIME_MODEL_TYPE,
45
53
  RECORDING_FILE_TYPES,
@@ -56,6 +64,8 @@ from accusleepy.fileio import (
56
64
  save_config,
57
65
  save_labels,
58
66
  save_recording_list,
67
+ EMGFilter,
68
+ Hyperparameters,
59
69
  )
60
70
  from accusleepy.gui.manual_scoring import ManualScoringWindow
61
71
  from accusleepy.gui.primary_window import Ui_PrimaryWindow
@@ -63,14 +73,15 @@ from accusleepy.signal_processing import (
63
73
  create_training_images,
64
74
  resample_and_standardize,
65
75
  )
76
+ from accusleepy.validation import (
77
+ check_label_validity,
78
+ LABEL_LENGTH_ERROR,
79
+ check_config_consistency,
80
+ )
66
81
 
67
82
  # note: functions using torch or scipy are lazily imported
68
83
 
69
- # max number of messages to display
70
- MESSAGE_BOX_MAX_DEPTH = 200
71
- LABEL_LENGTH_ERROR = "label file length does not match recording length"
72
- # relative path to config guide txt file
73
- CONFIG_GUIDE_FILE = os.path.normpath(r"text/config_guide.txt")
84
+ # relative path to user manual
74
85
  MAIN_GUIDE_FILE = os.path.normpath(r"text/main_guide.md")
75
86
 
76
87
 
@@ -97,19 +108,25 @@ class AccuSleepWindow(QMainWindow):
97
108
  self.setWindowTitle("AccuSleePy")
98
109
 
99
110
  # fill in settings tab
100
- self.brain_state_set, self.epoch_length, self.save_confidence_setting = (
101
- load_config()
102
- )
111
+ (
112
+ self.brain_state_set,
113
+ self.epoch_length,
114
+ self.only_overwrite_undefined,
115
+ self.save_confidence_scores,
116
+ self.min_bout_length,
117
+ self.emg_filter,
118
+ self.hyperparameters,
119
+ ) = load_config()
120
+
103
121
  self.settings_widgets = None
104
122
  self.initialize_settings_tab()
105
123
 
106
124
  # initialize info about the recordings, classification data / settings
107
125
  self.ui.epoch_length_input.setValue(self.epoch_length)
108
- self.ui.save_confidence_checkbox.setChecked(self.save_confidence_setting)
126
+ self.ui.overwritecheckbox.setChecked(self.only_overwrite_undefined)
127
+ self.ui.save_confidence_checkbox.setChecked(self.save_confidence_scores)
128
+ self.ui.bout_length_input.setValue(self.min_bout_length)
109
129
  self.model = None
110
- self.only_overwrite_undefined = False
111
- self.save_confidence_scores = self.save_confidence_setting
112
- self.min_bout_length = 5
113
130
 
114
131
  # initialize model training variables
115
132
  self.training_epochs_per_img = 9
@@ -186,6 +203,10 @@ class AccuSleepWindow(QMainWindow):
186
203
  self.ui.export_button.clicked.connect(self.export_recording_list)
187
204
  self.ui.import_button.clicked.connect(self.import_recording_list)
188
205
  self.ui.default_type_button.toggled.connect(self.model_type_radio_buttons)
206
+ self.ui.reset_emg_params_button.clicked.connect(self.reset_emg_filter_settings)
207
+ self.ui.reset_hyperparams_button.clicked.connect(
208
+ self.reset_hyperparams_settings
209
+ )
189
210
 
190
211
  # user input: drag and drop
191
212
  self.ui.recording_file_label.installEventFilter(self)
@@ -200,10 +221,9 @@ class AccuSleepWindow(QMainWindow):
200
221
 
201
222
  :param default_selected: whether default option is selected
202
223
  """
203
- if default_selected:
204
- self.model_type = DEFAULT_MODEL_TYPE
205
- else:
206
- self.model_type = REAL_TIME_MODEL_TYPE
224
+ self.model_type = (
225
+ DEFAULT_MODEL_TYPE if default_selected else REAL_TIME_MODEL_TYPE
226
+ )
207
227
 
208
228
  def export_recording_list(self) -> None:
209
229
  """Save current list of recordings to file"""
@@ -363,6 +383,7 @@ class AccuSleepWindow(QMainWindow):
363
383
  brain_state_set=self.brain_state_set,
364
384
  model_type=self.model_type,
365
385
  calibration_fraction=calibration_fraction,
386
+ emg_filter=self.emg_filter,
366
387
  )
367
388
  if len(failed_recordings) > 0:
368
389
  if len(failed_recordings) == len(self.recordings):
@@ -391,6 +412,7 @@ class AccuSleepWindow(QMainWindow):
391
412
  img_dir=temp_image_dir,
392
413
  mixture_weights=self.brain_state_set.mixture_weights,
393
414
  n_classes=self.brain_state_set.n_classes,
415
+ hyperparameters=self.hyperparameters,
394
416
  )
395
417
 
396
418
  # calibrate the model
@@ -399,7 +421,9 @@ class AccuSleepWindow(QMainWindow):
399
421
  temp_image_dir, CALIBRATION_ANNOTATION_FILENAME
400
422
  )
401
423
  calibration_dataloader = create_dataloader(
402
- annotations_file=calibration_annotation_file, img_dir=temp_image_dir
424
+ annotations_file=calibration_annotation_file,
425
+ img_dir=temp_image_dir,
426
+ hyperparameters=self.hyperparameters,
403
427
  )
404
428
  model = ModelWithTemperature(model)
405
429
  print("Calibrating model")
@@ -584,6 +608,7 @@ class AccuSleepWindow(QMainWindow):
584
608
  epoch_length=self.epoch_length,
585
609
  epochs_per_img=self.model_epochs_per_img,
586
610
  brain_state_set=self.brain_state_set,
611
+ emg_filter=self.emg_filter,
587
612
  )
588
613
 
589
614
  # overwrite as needed
@@ -801,6 +826,7 @@ class AccuSleepWindow(QMainWindow):
801
826
  sampling_rate=sampling_rate,
802
827
  epoch_length=self.epoch_length,
803
828
  brain_state_set=self.brain_state_set,
829
+ emg_filter=self.emg_filter,
804
830
  )
805
831
 
806
832
  self.ui.calibration_status.setText("")
@@ -965,6 +991,7 @@ class AccuSleepWindow(QMainWindow):
965
991
  confidence_scores=confidence_scores,
966
992
  sampling_rate=sampling_rate,
967
993
  epoch_length=self.epoch_length,
994
+ emg_filter=self.emg_filter,
968
995
  )
969
996
  manual_scoring_window.setWindowTitle(f"AccuSleePy viewer: {label_file}")
970
997
  manual_scoring_window.exec()
@@ -1130,15 +1157,6 @@ class AccuSleepWindow(QMainWindow):
1130
1157
 
1131
1158
  def initialize_settings_tab(self):
1132
1159
  """Populate settings tab and assign its callbacks"""
1133
- # show information about the settings tab
1134
- config_guide_file = open(
1135
- os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_GUIDE_FILE),
1136
- "r",
1137
- )
1138
- config_guide_text = config_guide_file.read()
1139
- config_guide_file.close()
1140
- self.ui.settings_text.setText(config_guide_text)
1141
-
1142
1160
  # store dictionary that maps digits to rows of widgets
1143
1161
  # in the settings tab
1144
1162
  self.settings_widgets = {
@@ -1215,8 +1233,21 @@ class AccuSleepWindow(QMainWindow):
1215
1233
  }
1216
1234
 
1217
1235
  # update widget state to display current config
1236
+ # UI defaults
1218
1237
  self.ui.default_epoch_input.setValue(self.epoch_length)
1219
- self.ui.confidence_setting_checkbox.setChecked(self.save_confidence_setting)
1238
+ self.ui.overwrite_default_checkbox.setChecked(self.only_overwrite_undefined)
1239
+ self.ui.confidence_setting_checkbox.setChecked(self.save_confidence_scores)
1240
+ self.ui.default_min_bout_length_spinbox.setValue(self.min_bout_length)
1241
+ # EMG filter
1242
+ self.ui.emg_order_spinbox.setValue(self.emg_filter.order)
1243
+ self.ui.bp_lower_spinbox.setValue(self.emg_filter.bp_lower)
1244
+ self.ui.bp_upper_spinbox.setValue(self.emg_filter.bp_upper)
1245
+ # model training hyperparameters
1246
+ self.ui.batch_size_spinbox.setValue(self.hyperparameters.batch_size)
1247
+ self.ui.learning_rate_spinbox.setValue(self.hyperparameters.learning_rate)
1248
+ self.ui.momentum_spinbox.setValue(self.hyperparameters.momentum)
1249
+ self.ui.training_epochs_spinbox.setValue(self.hyperparameters.training_epochs)
1250
+ # brain states
1220
1251
  states = {b.digit: b for b in self.brain_state_set.brain_states}
1221
1252
  for digit in range(10):
1222
1253
  if digit in states.keys():
@@ -1235,16 +1266,25 @@ class AccuSleepWindow(QMainWindow):
1235
1266
  self.settings_widgets[digit].frequency_widget.setEnabled(False)
1236
1267
 
1237
1268
  # set callbacks
1269
+ self.ui.emg_order_spinbox.valueChanged.connect(self.emg_filter_order_changed)
1270
+ self.ui.bp_lower_spinbox.valueChanged.connect(self.emg_filter_bp_lower_changed)
1271
+ self.ui.bp_upper_spinbox.valueChanged.connect(self.emg_filter_bp_upper_changed)
1272
+ self.ui.batch_size_spinbox.valueChanged.connect(self.hyperparameters_changed)
1273
+ self.ui.learning_rate_spinbox.valueChanged.connect(self.hyperparameters_changed)
1274
+ self.ui.momentum_spinbox.valueChanged.connect(self.hyperparameters_changed)
1275
+ self.ui.training_epochs_spinbox.valueChanged.connect(
1276
+ self.hyperparameters_changed
1277
+ )
1238
1278
  for digit in range(10):
1239
1279
  state = self.settings_widgets[digit]
1240
1280
  state.enabled_widget.stateChanged.connect(
1241
1281
  partial(self.set_brain_state_enabled, digit)
1242
1282
  )
1243
- state.name_widget.editingFinished.connect(self.finished_editing_state_name)
1283
+ state.name_widget.editingFinished.connect(self.check_config_validity)
1244
1284
  state.is_scored_widget.stateChanged.connect(
1245
1285
  partial(self.is_scored_changed, digit)
1246
1286
  )
1247
- state.frequency_widget.valueChanged.connect(self.state_frequency_changed)
1287
+ state.frequency_widget.valueChanged.connect(self.check_config_validity)
1248
1288
 
1249
1289
  def set_brain_state_enabled(self, digit, e) -> None:
1250
1290
  """Called when user clicks "enabled" checkbox
@@ -1270,17 +1310,6 @@ class AccuSleepWindow(QMainWindow):
1270
1310
  # check that configuration is valid
1271
1311
  _ = self.check_config_validity()
1272
1312
 
1273
- def finished_editing_state_name(self) -> None:
1274
- """Called when user finishes editing a brain state's name"""
1275
- _ = self.check_config_validity()
1276
-
1277
- def state_frequency_changed(self, new_value) -> None:
1278
- """Called when user edits a brain state's frequency
1279
-
1280
- :param new_value: unused
1281
- """
1282
- _ = self.check_config_validity()
1283
-
1284
1313
  def is_scored_changed(self, digit, e) -> None:
1285
1314
  """Called when user sets whether a state is scored
1286
1315
 
@@ -1297,6 +1326,41 @@ class AccuSleepWindow(QMainWindow):
1297
1326
  # check that configuration is valid
1298
1327
  _ = self.check_config_validity()
1299
1328
 
1329
+ def emg_filter_order_changed(self, new_value: int) -> None:
1330
+ """Called when user modifies EMG filter order
1331
+
1332
+ :param new_value: new EMG filter order
1333
+ """
1334
+ self.emg_filter.order = new_value
1335
+
1336
+ def emg_filter_bp_lower_changed(self, new_value: int | float) -> None:
1337
+ """Called when user modifies EMG filter lower cutoff
1338
+
1339
+ :param new_value: new lower bandpass cutoff frequency
1340
+ """
1341
+ self.emg_filter.bp_lower = new_value
1342
+ _ = self.check_config_validity()
1343
+
1344
+ def emg_filter_bp_upper_changed(self, new_value: int | float) -> None:
1345
+ """Called when user modifies EMG filter upper cutoff
1346
+
1347
+ :param new_value: new upper bandpass cutoff frequency
1348
+ """
1349
+ self.emg_filter.bp_upper = new_value
1350
+ _ = self.check_config_validity()
1351
+
1352
+ def hyperparameters_changed(self, new_value) -> None:
1353
+ """Called when user modifies model training hyperparameters
1354
+
1355
+ :param new_value: unused
1356
+ """
1357
+ self.hyperparameters = Hyperparameters(
1358
+ batch_size=self.ui.batch_size_spinbox.value(),
1359
+ learning_rate=self.ui.learning_rate_spinbox.value(),
1360
+ momentum=self.ui.momentum_spinbox.value(),
1361
+ training_epochs=self.ui.training_epochs_spinbox.value(),
1362
+ )
1363
+
1300
1364
  def check_config_validity(self) -> str:
1301
1365
  """Check if brain state configuration on screen is valid"""
1302
1366
  # error message, if we get one
@@ -1323,6 +1387,10 @@ class AccuSleepWindow(QMainWindow):
1323
1387
  if sum(frequencies) != 1:
1324
1388
  message = "Error: sum(frequencies) != 1"
1325
1389
 
1390
+ # check validity of EMG filter settings
1391
+ if self.emg_filter.bp_lower >= self.emg_filter.bp_upper:
1392
+ message = "Error: EMG filter cutoff frequencies are invalid"
1393
+
1326
1394
  if message is not None:
1327
1395
  self.ui.save_config_status.setText(message)
1328
1396
  self.ui.save_config_button.setEnabled(False)
@@ -1355,133 +1423,35 @@ class AccuSleepWindow(QMainWindow):
1355
1423
 
1356
1424
  # save to file
1357
1425
  save_config(
1358
- self.brain_state_set,
1359
- self.ui.default_epoch_input.value(),
1360
- self.ui.confidence_setting_checkbox.isChecked(),
1426
+ brain_state_set=self.brain_state_set,
1427
+ default_epoch_length=self.ui.default_epoch_input.value(),
1428
+ overwrite_setting=self.ui.overwrite_default_checkbox.isChecked(),
1429
+ save_confidence_setting=self.ui.confidence_setting_checkbox.isChecked(),
1430
+ min_bout_length=self.ui.default_min_bout_length_spinbox.value(),
1431
+ emg_filter=EMGFilter(
1432
+ order=self.emg_filter.order,
1433
+ bp_lower=self.emg_filter.bp_lower,
1434
+ bp_upper=self.emg_filter.bp_upper,
1435
+ ),
1436
+ hyperparameters=Hyperparameters(
1437
+ batch_size=self.hyperparameters.batch_size,
1438
+ learning_rate=self.hyperparameters.learning_rate,
1439
+ momentum=self.hyperparameters.momentum,
1440
+ training_epochs=self.hyperparameters.training_epochs,
1441
+ ),
1361
1442
  )
1362
1443
  self.ui.save_config_status.setText("configuration saved")
1363
1444
 
1445
+ def reset_emg_filter_settings(self) -> None:
1446
+ self.ui.emg_order_spinbox.setValue(DEFAULT_EMG_FILTER_ORDER)
1447
+ self.ui.bp_lower_spinbox.setValue(DEFAULT_EMG_BP_LOWER)
1448
+ self.ui.bp_upper_spinbox.setValue(DEFAULT_EMG_BP_UPPER)
1364
1449
 
1365
- def check_label_validity(
1366
- labels: np.array,
1367
- confidence_scores: np.array,
1368
- samples_in_recording: int,
1369
- sampling_rate: int | float,
1370
- epoch_length: int | float,
1371
- brain_state_set: BrainStateSet,
1372
- ) -> str | None:
1373
- """Check whether a set of brain state labels is valid
1374
-
1375
- This returns an error message if a problem is found with the
1376
- brain state labels.
1377
-
1378
- :param labels: brain state labels
1379
- :param confidence_scores: confidence scores
1380
- :param samples_in_recording: number of samples in the recording
1381
- :param sampling_rate: sampling rate, in Hz
1382
- :param epoch_length: epoch length, in seconds
1383
- :param brain_state_set: BrainStateMapper object
1384
- :return: error message
1385
- """
1386
- # check that number of labels is correct
1387
- samples_per_epoch = round(sampling_rate * epoch_length)
1388
- epochs_in_recording = round(samples_in_recording / samples_per_epoch)
1389
- if epochs_in_recording != labels.size:
1390
- return LABEL_LENGTH_ERROR
1391
-
1392
- # check that entries are valid
1393
- if not set(labels.tolist()).issubset(
1394
- set([b.digit for b in brain_state_set.brain_states] + [UNDEFINED_LABEL])
1395
- ):
1396
- return "label file contains invalid entries"
1397
-
1398
- if confidence_scores is not None:
1399
- if np.min(confidence_scores) < 0 or np.max(confidence_scores) > 1:
1400
- return "label file contains invalid confidence scores"
1401
-
1402
- return None
1403
-
1404
-
1405
- def check_config_consistency(
1406
- current_brain_states: dict,
1407
- model_brain_states: dict,
1408
- current_epoch_length: int | float,
1409
- model_epoch_length: int | float,
1410
- ) -> list[str]:
1411
- """Compare current brain state config to the model's config
1412
-
1413
- This only displays warnings - the user should decide whether to proceed
1414
-
1415
- :param current_brain_states: current brain state config
1416
- :param model_brain_states: brain state config when the model was created
1417
- :param current_epoch_length: current epoch length setting
1418
- :param model_epoch_length: epoch length used when the model was created
1419
- """
1420
- output = list()
1421
-
1422
- # make lists of names and digits for scored brain states
1423
- current_scored_states = {
1424
- f: [b[f] for b in current_brain_states if b["is_scored"]]
1425
- for f in ["name", "digit"]
1426
- }
1427
- model_scored_states = {
1428
- f: [b[f] for b in model_brain_states if b["is_scored"]]
1429
- for f in ["name", "digit"]
1430
- }
1431
-
1432
- # generate message comparing the brain state configs
1433
- config_comparisons = list()
1434
- for config, config_name in zip(
1435
- [current_scored_states, model_scored_states], ["current", "model's"]
1436
- ):
1437
- config_comparisons.append(
1438
- f"Scored brain states in {config_name} configuration: "
1439
- f"""{
1440
- ", ".join(
1441
- [
1442
- f"{x}: {y}"
1443
- for x, y in zip(
1444
- config["digit"],
1445
- config["name"],
1446
- )
1447
- ]
1448
- )
1449
- }"""
1450
- )
1451
-
1452
- # check if the number of scored states is different
1453
- len_diff = len(current_scored_states["name"]) - len(model_scored_states["name"])
1454
- if len_diff != 0:
1455
- output.append(
1456
- (
1457
- "WARNING: current brain state configuration has "
1458
- f"{'fewer' if len_diff < 0 else 'more'} "
1459
- "scored brain states than the model's configuration."
1460
- )
1461
- )
1462
- output = output + config_comparisons
1463
- else:
1464
- # the length is the same, but names might be different
1465
- if current_scored_states["name"] != model_scored_states["name"]:
1466
- output.append(
1467
- (
1468
- "WARNING: current brain state configuration appears "
1469
- "to contain different brain states than "
1470
- "the model's configuration."
1471
- )
1472
- )
1473
- output = output + config_comparisons
1474
-
1475
- if current_epoch_length != model_epoch_length:
1476
- output.append(
1477
- (
1478
- "Warning: the epoch length used when training this model "
1479
- f"({model_epoch_length} seconds) "
1480
- "does not match the current epoch length setting."
1481
- )
1482
- )
1483
-
1484
- return output
1450
+ def reset_hyperparams_settings(self):
1451
+ self.ui.batch_size_spinbox.setValue(DEFAULT_BATCH_SIZE)
1452
+ self.ui.learning_rate_spinbox.setValue(DEFAULT_LEARNING_RATE)
1453
+ self.ui.momentum_spinbox.setValue(DEFAULT_MOMENTUM)
1454
+ self.ui.training_epochs_spinbox.setValue(DEFAULT_TRAINING_EPOCHS)
1485
1455
 
1486
1456
 
1487
1457
  def run_primary_window() -> None:
@@ -32,7 +32,7 @@ from PySide6.QtWidgets import (
32
32
  )
33
33
 
34
34
  from accusleepy.constants import UNDEFINED_LABEL
35
- from accusleepy.fileio import load_config, save_labels
35
+ from accusleepy.fileio import load_config, save_labels, EMGFilter
36
36
  from accusleepy.gui.mplwidget import resample_x_ticks
37
37
  from accusleepy.gui.viewer_window import Ui_ViewerWindow
38
38
  from accusleepy.signal_processing import create_spectrogram, get_emg_power
@@ -79,6 +79,12 @@ UNDEFINED_STATE = "undefined"
79
79
  SCROLL_BOUNDARY = 0.35
80
80
  # max number of sequential undo actions allowed
81
81
  UNDO_LIMIT = 1000
82
+ # brightness scaling factors for the spectrogram
83
+ BRIGHTER_SCALE_FACTOR = 0.96
84
+ DIMMER_SCALE_FACTOR = 1.07
85
+ # zoom factor for upper plots
86
+ ZOOM_IN_FACTOR = 0.45
87
+ ZOOM_OUT_FACTOR = 1.017
82
88
 
83
89
 
84
90
  @dataclass
@@ -102,6 +108,7 @@ class ManualScoringWindow(QDialog):
102
108
  confidence_scores: np.array,
103
109
  sampling_rate: int | float,
104
110
  epoch_length: int | float,
111
+ emg_filter: EMGFilter,
105
112
  ):
106
113
  """Initialize the manual scoring window
107
114
 
@@ -112,6 +119,7 @@ class ManualScoringWindow(QDialog):
112
119
  :param confidence_scores: confidence scores
113
120
  :param sampling_rate: sampling rate, in Hz
114
121
  :param epoch_length: epoch length, in seconds
122
+ :param emg_filter: EMG filter parameters
115
123
  """
116
124
  super(ManualScoringWindow, self).__init__()
117
125
 
@@ -122,6 +130,7 @@ class ManualScoringWindow(QDialog):
122
130
  self.confidence_scores = confidence_scores
123
131
  self.sampling_rate = sampling_rate
124
132
  self.epoch_length = epoch_length
133
+ self.emg_filter = emg_filter
125
134
 
126
135
  self.n_epochs = len(self.labels)
127
136
 
@@ -131,7 +140,7 @@ class ManualScoringWindow(QDialog):
131
140
  self.setWindowTitle("AccuSleePy manual scoring window")
132
141
 
133
142
  # load set of valid brain states
134
- self.brain_state_set, _, _ = load_config()
143
+ self.brain_state_set, _, _, _, _, _, _ = load_config()
135
144
 
136
145
  # initial setting for number of epochs to show in the lower plot
137
146
  self.epochs_to_show = 5
@@ -153,7 +162,7 @@ class ManualScoringWindow(QDialog):
153
162
 
154
163
  # calculate RMS of EMG for each epoch and apply a ceiling
155
164
  self.upper_emg = create_upper_emg_signal(
156
- self.emg, self.sampling_rate, self.epoch_length
165
+ self.emg, self.sampling_rate, self.epoch_length, self.emg_filter
157
166
  )
158
167
 
159
168
  # center and scale the EEG and EMG signals to fit the display
@@ -229,23 +238,6 @@ class ManualScoringWindow(QDialog):
229
238
  keypress_zoom_out_x = QShortcut(QKeySequence(Qt.Key.Key_Minus), self)
230
239
  keypress_zoom_out_x.activated.connect(partial(self.zoom_x, ZOOM_OUT))
231
240
 
232
- keypress_modify_label = list()
233
- for brain_state in self.brain_state_set.brain_states:
234
- keypress_modify_label.append(
235
- QShortcut(
236
- QKeySequence(Qt.Key[f"Key_{brain_state.digit}"]),
237
- self,
238
- )
239
- )
240
- keypress_modify_label[-1].activated.connect(
241
- partial(self.modify_current_epoch_label, brain_state.digit)
242
- )
243
-
244
- keypress_delete_label = QShortcut(QKeySequence(Qt.Key.Key_Backspace), self)
245
- keypress_delete_label.activated.connect(
246
- partial(self.modify_current_epoch_label, UNDEFINED_LABEL)
247
- )
248
-
249
241
  keypress_quit = QShortcut(
250
242
  QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_W)),
251
243
  self,
@@ -258,36 +250,37 @@ class ManualScoringWindow(QDialog):
258
250
  )
259
251
  keypress_save.activated.connect(self.save)
260
252
 
253
+ keypress_modify_label = list()
261
254
  keypress_roi = list()
262
- for brain_state in self.brain_state_set.brain_states:
255
+ digit_key_label_pairs = [
256
+ (Qt.Key[f"Key_{brain_state.digit}"], brain_state.digit)
257
+ for brain_state in self.brain_state_set.brain_states
258
+ ] + [(Qt.Key.Key_Backspace, UNDEFINED_LABEL)]
259
+
260
+ for digit_key, digit_label in digit_key_label_pairs:
261
+ keypress_modify_label.append(
262
+ QShortcut(
263
+ QKeySequence(digit_key),
264
+ self,
265
+ )
266
+ )
267
+ keypress_modify_label[-1].activated.connect(
268
+ partial(self.modify_current_epoch_label, digit_label)
269
+ )
263
270
  keypress_roi.append(
264
271
  QShortcut(
265
272
  QKeySequence(
266
273
  QKeyCombination(
267
274
  Qt.Modifier.SHIFT,
268
- Qt.Key[f"Key_{brain_state.digit}"],
275
+ digit_key,
269
276
  )
270
277
  ),
271
278
  self,
272
279
  )
273
280
  )
274
281
  keypress_roi[-1].activated.connect(
275
- partial(self.enter_label_roi_mode, brain_state.digit)
276
- )
277
- keypress_roi.append(
278
- QShortcut(
279
- QKeySequence(
280
- QKeyCombination(
281
- Qt.Modifier.SHIFT,
282
- Qt.Key.Key_Backspace,
283
- )
284
- ),
285
- self,
282
+ partial(self.enter_label_roi_mode, digit_label)
286
283
  )
287
- )
288
- keypress_roi[-1].activated.connect(
289
- partial(self.enter_label_roi_mode, UNDEFINED_LABEL)
290
- )
291
284
 
292
285
  keypress_esc = QShortcut(QKeySequence(Qt.Key.Key_Escape), self)
293
286
  keypress_esc.activated.connect(self.exit_label_roi_mode)
@@ -623,9 +616,9 @@ class ManualScoringWindow(QDialog):
623
616
  """
624
617
  vmin, vmax = self.ui.upperfigure.spec_ref.get_clim()
625
618
  if direction == BRIGHTER:
626
- self.ui.upperfigure.spec_ref.set(clim=(vmin, vmax * 0.96))
619
+ self.ui.upperfigure.spec_ref.set(clim=(vmin, vmax * BRIGHTER_SCALE_FACTOR))
627
620
  else:
628
- self.ui.upperfigure.spec_ref.set(clim=(vmin, vmax * 1.07))
621
+ self.ui.upperfigure.spec_ref.set(clim=(vmin, vmax * DIMMER_SCALE_FACTOR))
629
622
  self.ui.upperfigure.canvas.draw()
630
623
 
631
624
  def update_epochs_shown(self, direction: str) -> None:
@@ -722,31 +715,29 @@ class ManualScoringWindow(QDialog):
722
715
 
723
716
  :param direction: in, out, or reset
724
717
  """
725
- zoom_in_factor = 0.45
726
- zoom_out_factor = 1.017
727
718
  epochs_shown = self.upper_right_epoch - self.upper_left_epoch + 1
728
719
  if direction == ZOOM_IN:
729
720
  self.upper_left_epoch = max(
730
721
  [
731
722
  self.upper_left_epoch,
732
- round(self.epoch - zoom_in_factor * epochs_shown),
723
+ round(self.epoch - ZOOM_IN_FACTOR * epochs_shown),
733
724
  ]
734
725
  )
735
726
 
736
727
  self.upper_right_epoch = min(
737
728
  [
738
729
  self.upper_right_epoch,
739
- round(self.epoch + zoom_in_factor * epochs_shown),
730
+ round(self.epoch + ZOOM_IN_FACTOR * epochs_shown),
740
731
  ]
741
732
  )
742
733
 
743
734
  elif direction == ZOOM_OUT:
744
735
  self.upper_left_epoch = max(
745
- [0, round(self.epoch - zoom_out_factor * epochs_shown)]
736
+ [0, round(self.epoch - ZOOM_OUT_FACTOR * epochs_shown)]
746
737
  )
747
738
 
748
739
  self.upper_right_epoch = min(
749
- [self.n_epochs - 1, round(self.epoch + zoom_out_factor * epochs_shown)]
740
+ [self.n_epochs - 1, round(self.epoch + ZOOM_OUT_FACTOR * epochs_shown)]
750
741
  )
751
742
 
752
743
  else: # reset
@@ -1063,21 +1054,28 @@ def create_confidence_img(confidence_scores: np.array) -> np.array:
1063
1054
 
1064
1055
 
1065
1056
  def create_upper_emg_signal(
1066
- emg: np.array, sampling_rate: int | float, epoch_length: int | float
1057
+ emg: np.array,
1058
+ sampling_rate: int | float,
1059
+ epoch_length: int | float,
1060
+ emg_filter: EMGFilter,
1067
1061
  ) -> np.array:
1068
1062
  """Calculate RMS of EMG for each epoch and apply a ceiling
1069
1063
 
1070
1064
  :param emg: EMG signal
1071
1065
  :param sampling_rate: sampling rate, in Hz
1072
1066
  :param epoch_length: epoch length, in seconds
1067
+ :param emg_filter: EMG filter parameters
1073
1068
  :return: processed EMG signal
1074
1069
  """
1075
1070
  emg_rms = get_emg_power(
1076
1071
  emg,
1077
1072
  sampling_rate,
1078
1073
  epoch_length,
1074
+ emg_filter,
1075
+ )
1076
+ return np.clip(
1077
+ emg_rms, np.percentile(emg_rms, 0.1), np.mean(emg_rms) + np.std(emg_rms) * 2.5
1079
1078
  )
1080
- return np.clip(emg_rms, 0, np.mean(emg_rms) + np.std(emg_rms) * 2.5)
1081
1079
 
1082
1080
 
1083
1081
  def transform_eeg_emg(eeg: np.array, emg: np.array) -> (np.array, np.array):