accusleepy 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accusleepy/classification.py +29 -13
- accusleepy/config.json +14 -1
- accusleepy/constants.py +44 -6
- accusleepy/fileio.py +87 -36
- accusleepy/gui/images/primary_window.png +0 -0
- accusleepy/gui/main.py +133 -163
- accusleepy/gui/manual_scoring.py +45 -47
- accusleepy/gui/primary_window.py +760 -135
- accusleepy/gui/primary_window.ui +2934 -2122
- accusleepy/gui/text/main_guide.md +2 -1
- accusleepy/models.py +1 -12
- accusleepy/signal_processing.py +18 -17
- accusleepy/validation.py +128 -0
- {accusleepy-0.6.0.dist-info → accusleepy-0.7.1.dist-info}/METADATA +4 -1
- {accusleepy-0.6.0.dist-info → accusleepy-0.7.1.dist-info}/RECORD +16 -16
- accusleepy/gui/text/config_guide.txt +0 -27
- {accusleepy-0.6.0.dist-info → accusleepy-0.7.1.dist-info}/WHEEL +0 -0
accusleepy/gui/main.py
CHANGED
|
@@ -39,7 +39,15 @@ from accusleepy.constants import (
|
|
|
39
39
|
CALIBRATION_ANNOTATION_FILENAME,
|
|
40
40
|
CALIBRATION_FILE_TYPE,
|
|
41
41
|
DEFAULT_MODEL_TYPE,
|
|
42
|
+
DEFAULT_EMG_FILTER_ORDER,
|
|
43
|
+
DEFAULT_EMG_BP_LOWER,
|
|
44
|
+
DEFAULT_EMG_BP_UPPER,
|
|
45
|
+
DEFAULT_BATCH_SIZE,
|
|
46
|
+
DEFAULT_LEARNING_RATE,
|
|
47
|
+
DEFAULT_MOMENTUM,
|
|
48
|
+
DEFAULT_TRAINING_EPOCHS,
|
|
42
49
|
LABEL_FILE_TYPE,
|
|
50
|
+
MESSAGE_BOX_MAX_DEPTH,
|
|
43
51
|
MODEL_FILE_TYPE,
|
|
44
52
|
REAL_TIME_MODEL_TYPE,
|
|
45
53
|
RECORDING_FILE_TYPES,
|
|
@@ -56,6 +64,8 @@ from accusleepy.fileio import (
|
|
|
56
64
|
save_config,
|
|
57
65
|
save_labels,
|
|
58
66
|
save_recording_list,
|
|
67
|
+
EMGFilter,
|
|
68
|
+
Hyperparameters,
|
|
59
69
|
)
|
|
60
70
|
from accusleepy.gui.manual_scoring import ManualScoringWindow
|
|
61
71
|
from accusleepy.gui.primary_window import Ui_PrimaryWindow
|
|
@@ -63,14 +73,15 @@ from accusleepy.signal_processing import (
|
|
|
63
73
|
create_training_images,
|
|
64
74
|
resample_and_standardize,
|
|
65
75
|
)
|
|
76
|
+
from accusleepy.validation import (
|
|
77
|
+
check_label_validity,
|
|
78
|
+
LABEL_LENGTH_ERROR,
|
|
79
|
+
check_config_consistency,
|
|
80
|
+
)
|
|
66
81
|
|
|
67
82
|
# note: functions using torch or scipy are lazily imported
|
|
68
83
|
|
|
69
|
-
#
|
|
70
|
-
MESSAGE_BOX_MAX_DEPTH = 200
|
|
71
|
-
LABEL_LENGTH_ERROR = "label file length does not match recording length"
|
|
72
|
-
# relative path to config guide txt file
|
|
73
|
-
CONFIG_GUIDE_FILE = os.path.normpath(r"text/config_guide.txt")
|
|
84
|
+
# relative path to user manual
|
|
74
85
|
MAIN_GUIDE_FILE = os.path.normpath(r"text/main_guide.md")
|
|
75
86
|
|
|
76
87
|
|
|
@@ -97,19 +108,25 @@ class AccuSleepWindow(QMainWindow):
|
|
|
97
108
|
self.setWindowTitle("AccuSleePy")
|
|
98
109
|
|
|
99
110
|
# fill in settings tab
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
111
|
+
(
|
|
112
|
+
self.brain_state_set,
|
|
113
|
+
self.epoch_length,
|
|
114
|
+
self.only_overwrite_undefined,
|
|
115
|
+
self.save_confidence_scores,
|
|
116
|
+
self.min_bout_length,
|
|
117
|
+
self.emg_filter,
|
|
118
|
+
self.hyperparameters,
|
|
119
|
+
) = load_config()
|
|
120
|
+
|
|
103
121
|
self.settings_widgets = None
|
|
104
122
|
self.initialize_settings_tab()
|
|
105
123
|
|
|
106
124
|
# initialize info about the recordings, classification data / settings
|
|
107
125
|
self.ui.epoch_length_input.setValue(self.epoch_length)
|
|
108
|
-
self.ui.
|
|
126
|
+
self.ui.overwritecheckbox.setChecked(self.only_overwrite_undefined)
|
|
127
|
+
self.ui.save_confidence_checkbox.setChecked(self.save_confidence_scores)
|
|
128
|
+
self.ui.bout_length_input.setValue(self.min_bout_length)
|
|
109
129
|
self.model = None
|
|
110
|
-
self.only_overwrite_undefined = False
|
|
111
|
-
self.save_confidence_scores = self.save_confidence_setting
|
|
112
|
-
self.min_bout_length = 5
|
|
113
130
|
|
|
114
131
|
# initialize model training variables
|
|
115
132
|
self.training_epochs_per_img = 9
|
|
@@ -186,6 +203,10 @@ class AccuSleepWindow(QMainWindow):
|
|
|
186
203
|
self.ui.export_button.clicked.connect(self.export_recording_list)
|
|
187
204
|
self.ui.import_button.clicked.connect(self.import_recording_list)
|
|
188
205
|
self.ui.default_type_button.toggled.connect(self.model_type_radio_buttons)
|
|
206
|
+
self.ui.reset_emg_params_button.clicked.connect(self.reset_emg_filter_settings)
|
|
207
|
+
self.ui.reset_hyperparams_button.clicked.connect(
|
|
208
|
+
self.reset_hyperparams_settings
|
|
209
|
+
)
|
|
189
210
|
|
|
190
211
|
# user input: drag and drop
|
|
191
212
|
self.ui.recording_file_label.installEventFilter(self)
|
|
@@ -200,10 +221,9 @@ class AccuSleepWindow(QMainWindow):
|
|
|
200
221
|
|
|
201
222
|
:param default_selected: whether default option is selected
|
|
202
223
|
"""
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
self.model_type = REAL_TIME_MODEL_TYPE
|
|
224
|
+
self.model_type = (
|
|
225
|
+
DEFAULT_MODEL_TYPE if default_selected else REAL_TIME_MODEL_TYPE
|
|
226
|
+
)
|
|
207
227
|
|
|
208
228
|
def export_recording_list(self) -> None:
|
|
209
229
|
"""Save current list of recordings to file"""
|
|
@@ -363,6 +383,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
363
383
|
brain_state_set=self.brain_state_set,
|
|
364
384
|
model_type=self.model_type,
|
|
365
385
|
calibration_fraction=calibration_fraction,
|
|
386
|
+
emg_filter=self.emg_filter,
|
|
366
387
|
)
|
|
367
388
|
if len(failed_recordings) > 0:
|
|
368
389
|
if len(failed_recordings) == len(self.recordings):
|
|
@@ -391,6 +412,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
391
412
|
img_dir=temp_image_dir,
|
|
392
413
|
mixture_weights=self.brain_state_set.mixture_weights,
|
|
393
414
|
n_classes=self.brain_state_set.n_classes,
|
|
415
|
+
hyperparameters=self.hyperparameters,
|
|
394
416
|
)
|
|
395
417
|
|
|
396
418
|
# calibrate the model
|
|
@@ -399,7 +421,9 @@ class AccuSleepWindow(QMainWindow):
|
|
|
399
421
|
temp_image_dir, CALIBRATION_ANNOTATION_FILENAME
|
|
400
422
|
)
|
|
401
423
|
calibration_dataloader = create_dataloader(
|
|
402
|
-
annotations_file=calibration_annotation_file,
|
|
424
|
+
annotations_file=calibration_annotation_file,
|
|
425
|
+
img_dir=temp_image_dir,
|
|
426
|
+
hyperparameters=self.hyperparameters,
|
|
403
427
|
)
|
|
404
428
|
model = ModelWithTemperature(model)
|
|
405
429
|
print("Calibrating model")
|
|
@@ -584,6 +608,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
584
608
|
epoch_length=self.epoch_length,
|
|
585
609
|
epochs_per_img=self.model_epochs_per_img,
|
|
586
610
|
brain_state_set=self.brain_state_set,
|
|
611
|
+
emg_filter=self.emg_filter,
|
|
587
612
|
)
|
|
588
613
|
|
|
589
614
|
# overwrite as needed
|
|
@@ -801,6 +826,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
801
826
|
sampling_rate=sampling_rate,
|
|
802
827
|
epoch_length=self.epoch_length,
|
|
803
828
|
brain_state_set=self.brain_state_set,
|
|
829
|
+
emg_filter=self.emg_filter,
|
|
804
830
|
)
|
|
805
831
|
|
|
806
832
|
self.ui.calibration_status.setText("")
|
|
@@ -965,6 +991,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
965
991
|
confidence_scores=confidence_scores,
|
|
966
992
|
sampling_rate=sampling_rate,
|
|
967
993
|
epoch_length=self.epoch_length,
|
|
994
|
+
emg_filter=self.emg_filter,
|
|
968
995
|
)
|
|
969
996
|
manual_scoring_window.setWindowTitle(f"AccuSleePy viewer: {label_file}")
|
|
970
997
|
manual_scoring_window.exec()
|
|
@@ -1130,15 +1157,6 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1130
1157
|
|
|
1131
1158
|
def initialize_settings_tab(self):
|
|
1132
1159
|
"""Populate settings tab and assign its callbacks"""
|
|
1133
|
-
# show information about the settings tab
|
|
1134
|
-
config_guide_file = open(
|
|
1135
|
-
os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_GUIDE_FILE),
|
|
1136
|
-
"r",
|
|
1137
|
-
)
|
|
1138
|
-
config_guide_text = config_guide_file.read()
|
|
1139
|
-
config_guide_file.close()
|
|
1140
|
-
self.ui.settings_text.setText(config_guide_text)
|
|
1141
|
-
|
|
1142
1160
|
# store dictionary that maps digits to rows of widgets
|
|
1143
1161
|
# in the settings tab
|
|
1144
1162
|
self.settings_widgets = {
|
|
@@ -1215,8 +1233,21 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1215
1233
|
}
|
|
1216
1234
|
|
|
1217
1235
|
# update widget state to display current config
|
|
1236
|
+
# UI defaults
|
|
1218
1237
|
self.ui.default_epoch_input.setValue(self.epoch_length)
|
|
1219
|
-
self.ui.
|
|
1238
|
+
self.ui.overwrite_default_checkbox.setChecked(self.only_overwrite_undefined)
|
|
1239
|
+
self.ui.confidence_setting_checkbox.setChecked(self.save_confidence_scores)
|
|
1240
|
+
self.ui.default_min_bout_length_spinbox.setValue(self.min_bout_length)
|
|
1241
|
+
# EMG filter
|
|
1242
|
+
self.ui.emg_order_spinbox.setValue(self.emg_filter.order)
|
|
1243
|
+
self.ui.bp_lower_spinbox.setValue(self.emg_filter.bp_lower)
|
|
1244
|
+
self.ui.bp_upper_spinbox.setValue(self.emg_filter.bp_upper)
|
|
1245
|
+
# model training hyperparameters
|
|
1246
|
+
self.ui.batch_size_spinbox.setValue(self.hyperparameters.batch_size)
|
|
1247
|
+
self.ui.learning_rate_spinbox.setValue(self.hyperparameters.learning_rate)
|
|
1248
|
+
self.ui.momentum_spinbox.setValue(self.hyperparameters.momentum)
|
|
1249
|
+
self.ui.training_epochs_spinbox.setValue(self.hyperparameters.training_epochs)
|
|
1250
|
+
# brain states
|
|
1220
1251
|
states = {b.digit: b for b in self.brain_state_set.brain_states}
|
|
1221
1252
|
for digit in range(10):
|
|
1222
1253
|
if digit in states.keys():
|
|
@@ -1235,16 +1266,25 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1235
1266
|
self.settings_widgets[digit].frequency_widget.setEnabled(False)
|
|
1236
1267
|
|
|
1237
1268
|
# set callbacks
|
|
1269
|
+
self.ui.emg_order_spinbox.valueChanged.connect(self.emg_filter_order_changed)
|
|
1270
|
+
self.ui.bp_lower_spinbox.valueChanged.connect(self.emg_filter_bp_lower_changed)
|
|
1271
|
+
self.ui.bp_upper_spinbox.valueChanged.connect(self.emg_filter_bp_upper_changed)
|
|
1272
|
+
self.ui.batch_size_spinbox.valueChanged.connect(self.hyperparameters_changed)
|
|
1273
|
+
self.ui.learning_rate_spinbox.valueChanged.connect(self.hyperparameters_changed)
|
|
1274
|
+
self.ui.momentum_spinbox.valueChanged.connect(self.hyperparameters_changed)
|
|
1275
|
+
self.ui.training_epochs_spinbox.valueChanged.connect(
|
|
1276
|
+
self.hyperparameters_changed
|
|
1277
|
+
)
|
|
1238
1278
|
for digit in range(10):
|
|
1239
1279
|
state = self.settings_widgets[digit]
|
|
1240
1280
|
state.enabled_widget.stateChanged.connect(
|
|
1241
1281
|
partial(self.set_brain_state_enabled, digit)
|
|
1242
1282
|
)
|
|
1243
|
-
state.name_widget.editingFinished.connect(self.
|
|
1283
|
+
state.name_widget.editingFinished.connect(self.check_config_validity)
|
|
1244
1284
|
state.is_scored_widget.stateChanged.connect(
|
|
1245
1285
|
partial(self.is_scored_changed, digit)
|
|
1246
1286
|
)
|
|
1247
|
-
state.frequency_widget.valueChanged.connect(self.
|
|
1287
|
+
state.frequency_widget.valueChanged.connect(self.check_config_validity)
|
|
1248
1288
|
|
|
1249
1289
|
def set_brain_state_enabled(self, digit, e) -> None:
|
|
1250
1290
|
"""Called when user clicks "enabled" checkbox
|
|
@@ -1270,17 +1310,6 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1270
1310
|
# check that configuration is valid
|
|
1271
1311
|
_ = self.check_config_validity()
|
|
1272
1312
|
|
|
1273
|
-
def finished_editing_state_name(self) -> None:
|
|
1274
|
-
"""Called when user finishes editing a brain state's name"""
|
|
1275
|
-
_ = self.check_config_validity()
|
|
1276
|
-
|
|
1277
|
-
def state_frequency_changed(self, new_value) -> None:
|
|
1278
|
-
"""Called when user edits a brain state's frequency
|
|
1279
|
-
|
|
1280
|
-
:param new_value: unused
|
|
1281
|
-
"""
|
|
1282
|
-
_ = self.check_config_validity()
|
|
1283
|
-
|
|
1284
1313
|
def is_scored_changed(self, digit, e) -> None:
|
|
1285
1314
|
"""Called when user sets whether a state is scored
|
|
1286
1315
|
|
|
@@ -1297,6 +1326,41 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1297
1326
|
# check that configuration is valid
|
|
1298
1327
|
_ = self.check_config_validity()
|
|
1299
1328
|
|
|
1329
|
+
def emg_filter_order_changed(self, new_value: int) -> None:
|
|
1330
|
+
"""Called when user modifies EMG filter order
|
|
1331
|
+
|
|
1332
|
+
:param new_value: new EMG filter order
|
|
1333
|
+
"""
|
|
1334
|
+
self.emg_filter.order = new_value
|
|
1335
|
+
|
|
1336
|
+
def emg_filter_bp_lower_changed(self, new_value: int | float) -> None:
|
|
1337
|
+
"""Called when user modifies EMG filter lower cutoff
|
|
1338
|
+
|
|
1339
|
+
:param new_value: new lower bandpass cutoff frequency
|
|
1340
|
+
"""
|
|
1341
|
+
self.emg_filter.bp_lower = new_value
|
|
1342
|
+
_ = self.check_config_validity()
|
|
1343
|
+
|
|
1344
|
+
def emg_filter_bp_upper_changed(self, new_value: int | float) -> None:
|
|
1345
|
+
"""Called when user modifies EMG filter upper cutoff
|
|
1346
|
+
|
|
1347
|
+
:param new_value: new upper bandpass cutoff frequency
|
|
1348
|
+
"""
|
|
1349
|
+
self.emg_filter.bp_upper = new_value
|
|
1350
|
+
_ = self.check_config_validity()
|
|
1351
|
+
|
|
1352
|
+
def hyperparameters_changed(self, new_value) -> None:
|
|
1353
|
+
"""Called when user modifies model training hyperparameters
|
|
1354
|
+
|
|
1355
|
+
:param new_value: unused
|
|
1356
|
+
"""
|
|
1357
|
+
self.hyperparameters = Hyperparameters(
|
|
1358
|
+
batch_size=self.ui.batch_size_spinbox.value(),
|
|
1359
|
+
learning_rate=self.ui.learning_rate_spinbox.value(),
|
|
1360
|
+
momentum=self.ui.momentum_spinbox.value(),
|
|
1361
|
+
training_epochs=self.ui.training_epochs_spinbox.value(),
|
|
1362
|
+
)
|
|
1363
|
+
|
|
1300
1364
|
def check_config_validity(self) -> str:
|
|
1301
1365
|
"""Check if brain state configuration on screen is valid"""
|
|
1302
1366
|
# error message, if we get one
|
|
@@ -1323,6 +1387,10 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1323
1387
|
if sum(frequencies) != 1:
|
|
1324
1388
|
message = "Error: sum(frequencies) != 1"
|
|
1325
1389
|
|
|
1390
|
+
# check validity of EMG filter settings
|
|
1391
|
+
if self.emg_filter.bp_lower >= self.emg_filter.bp_upper:
|
|
1392
|
+
message = "Error: EMG filter cutoff frequencies are invalid"
|
|
1393
|
+
|
|
1326
1394
|
if message is not None:
|
|
1327
1395
|
self.ui.save_config_status.setText(message)
|
|
1328
1396
|
self.ui.save_config_button.setEnabled(False)
|
|
@@ -1355,133 +1423,35 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1355
1423
|
|
|
1356
1424
|
# save to file
|
|
1357
1425
|
save_config(
|
|
1358
|
-
self.brain_state_set,
|
|
1359
|
-
self.ui.default_epoch_input.value(),
|
|
1360
|
-
self.ui.
|
|
1426
|
+
brain_state_set=self.brain_state_set,
|
|
1427
|
+
default_epoch_length=self.ui.default_epoch_input.value(),
|
|
1428
|
+
overwrite_setting=self.ui.overwrite_default_checkbox.isChecked(),
|
|
1429
|
+
save_confidence_setting=self.ui.confidence_setting_checkbox.isChecked(),
|
|
1430
|
+
min_bout_length=self.ui.default_min_bout_length_spinbox.value(),
|
|
1431
|
+
emg_filter=EMGFilter(
|
|
1432
|
+
order=self.emg_filter.order,
|
|
1433
|
+
bp_lower=self.emg_filter.bp_lower,
|
|
1434
|
+
bp_upper=self.emg_filter.bp_upper,
|
|
1435
|
+
),
|
|
1436
|
+
hyperparameters=Hyperparameters(
|
|
1437
|
+
batch_size=self.hyperparameters.batch_size,
|
|
1438
|
+
learning_rate=self.hyperparameters.learning_rate,
|
|
1439
|
+
momentum=self.hyperparameters.momentum,
|
|
1440
|
+
training_epochs=self.hyperparameters.training_epochs,
|
|
1441
|
+
),
|
|
1361
1442
|
)
|
|
1362
1443
|
self.ui.save_config_status.setText("configuration saved")
|
|
1363
1444
|
|
|
1445
|
+
def reset_emg_filter_settings(self) -> None:
|
|
1446
|
+
self.ui.emg_order_spinbox.setValue(DEFAULT_EMG_FILTER_ORDER)
|
|
1447
|
+
self.ui.bp_lower_spinbox.setValue(DEFAULT_EMG_BP_LOWER)
|
|
1448
|
+
self.ui.bp_upper_spinbox.setValue(DEFAULT_EMG_BP_UPPER)
|
|
1364
1449
|
|
|
1365
|
-
def
|
|
1366
|
-
|
|
1367
|
-
|
|
1368
|
-
|
|
1369
|
-
|
|
1370
|
-
epoch_length: int | float,
|
|
1371
|
-
brain_state_set: BrainStateSet,
|
|
1372
|
-
) -> str | None:
|
|
1373
|
-
"""Check whether a set of brain state labels is valid
|
|
1374
|
-
|
|
1375
|
-
This returns an error message if a problem is found with the
|
|
1376
|
-
brain state labels.
|
|
1377
|
-
|
|
1378
|
-
:param labels: brain state labels
|
|
1379
|
-
:param confidence_scores: confidence scores
|
|
1380
|
-
:param samples_in_recording: number of samples in the recording
|
|
1381
|
-
:param sampling_rate: sampling rate, in Hz
|
|
1382
|
-
:param epoch_length: epoch length, in seconds
|
|
1383
|
-
:param brain_state_set: BrainStateMapper object
|
|
1384
|
-
:return: error message
|
|
1385
|
-
"""
|
|
1386
|
-
# check that number of labels is correct
|
|
1387
|
-
samples_per_epoch = round(sampling_rate * epoch_length)
|
|
1388
|
-
epochs_in_recording = round(samples_in_recording / samples_per_epoch)
|
|
1389
|
-
if epochs_in_recording != labels.size:
|
|
1390
|
-
return LABEL_LENGTH_ERROR
|
|
1391
|
-
|
|
1392
|
-
# check that entries are valid
|
|
1393
|
-
if not set(labels.tolist()).issubset(
|
|
1394
|
-
set([b.digit for b in brain_state_set.brain_states] + [UNDEFINED_LABEL])
|
|
1395
|
-
):
|
|
1396
|
-
return "label file contains invalid entries"
|
|
1397
|
-
|
|
1398
|
-
if confidence_scores is not None:
|
|
1399
|
-
if np.min(confidence_scores) < 0 or np.max(confidence_scores) > 1:
|
|
1400
|
-
return "label file contains invalid confidence scores"
|
|
1401
|
-
|
|
1402
|
-
return None
|
|
1403
|
-
|
|
1404
|
-
|
|
1405
|
-
def check_config_consistency(
|
|
1406
|
-
current_brain_states: dict,
|
|
1407
|
-
model_brain_states: dict,
|
|
1408
|
-
current_epoch_length: int | float,
|
|
1409
|
-
model_epoch_length: int | float,
|
|
1410
|
-
) -> list[str]:
|
|
1411
|
-
"""Compare current brain state config to the model's config
|
|
1412
|
-
|
|
1413
|
-
This only displays warnings - the user should decide whether to proceed
|
|
1414
|
-
|
|
1415
|
-
:param current_brain_states: current brain state config
|
|
1416
|
-
:param model_brain_states: brain state config when the model was created
|
|
1417
|
-
:param current_epoch_length: current epoch length setting
|
|
1418
|
-
:param model_epoch_length: epoch length used when the model was created
|
|
1419
|
-
"""
|
|
1420
|
-
output = list()
|
|
1421
|
-
|
|
1422
|
-
# make lists of names and digits for scored brain states
|
|
1423
|
-
current_scored_states = {
|
|
1424
|
-
f: [b[f] for b in current_brain_states if b["is_scored"]]
|
|
1425
|
-
for f in ["name", "digit"]
|
|
1426
|
-
}
|
|
1427
|
-
model_scored_states = {
|
|
1428
|
-
f: [b[f] for b in model_brain_states if b["is_scored"]]
|
|
1429
|
-
for f in ["name", "digit"]
|
|
1430
|
-
}
|
|
1431
|
-
|
|
1432
|
-
# generate message comparing the brain state configs
|
|
1433
|
-
config_comparisons = list()
|
|
1434
|
-
for config, config_name in zip(
|
|
1435
|
-
[current_scored_states, model_scored_states], ["current", "model's"]
|
|
1436
|
-
):
|
|
1437
|
-
config_comparisons.append(
|
|
1438
|
-
f"Scored brain states in {config_name} configuration: "
|
|
1439
|
-
f"""{
|
|
1440
|
-
", ".join(
|
|
1441
|
-
[
|
|
1442
|
-
f"{x}: {y}"
|
|
1443
|
-
for x, y in zip(
|
|
1444
|
-
config["digit"],
|
|
1445
|
-
config["name"],
|
|
1446
|
-
)
|
|
1447
|
-
]
|
|
1448
|
-
)
|
|
1449
|
-
}"""
|
|
1450
|
-
)
|
|
1451
|
-
|
|
1452
|
-
# check if the number of scored states is different
|
|
1453
|
-
len_diff = len(current_scored_states["name"]) - len(model_scored_states["name"])
|
|
1454
|
-
if len_diff != 0:
|
|
1455
|
-
output.append(
|
|
1456
|
-
(
|
|
1457
|
-
"WARNING: current brain state configuration has "
|
|
1458
|
-
f"{'fewer' if len_diff < 0 else 'more'} "
|
|
1459
|
-
"scored brain states than the model's configuration."
|
|
1460
|
-
)
|
|
1461
|
-
)
|
|
1462
|
-
output = output + config_comparisons
|
|
1463
|
-
else:
|
|
1464
|
-
# the length is the same, but names might be different
|
|
1465
|
-
if current_scored_states["name"] != model_scored_states["name"]:
|
|
1466
|
-
output.append(
|
|
1467
|
-
(
|
|
1468
|
-
"WARNING: current brain state configuration appears "
|
|
1469
|
-
"to contain different brain states than "
|
|
1470
|
-
"the model's configuration."
|
|
1471
|
-
)
|
|
1472
|
-
)
|
|
1473
|
-
output = output + config_comparisons
|
|
1474
|
-
|
|
1475
|
-
if current_epoch_length != model_epoch_length:
|
|
1476
|
-
output.append(
|
|
1477
|
-
(
|
|
1478
|
-
"Warning: the epoch length used when training this model "
|
|
1479
|
-
f"({model_epoch_length} seconds) "
|
|
1480
|
-
"does not match the current epoch length setting."
|
|
1481
|
-
)
|
|
1482
|
-
)
|
|
1483
|
-
|
|
1484
|
-
return output
|
|
1450
|
+
def reset_hyperparams_settings(self):
|
|
1451
|
+
self.ui.batch_size_spinbox.setValue(DEFAULT_BATCH_SIZE)
|
|
1452
|
+
self.ui.learning_rate_spinbox.setValue(DEFAULT_LEARNING_RATE)
|
|
1453
|
+
self.ui.momentum_spinbox.setValue(DEFAULT_MOMENTUM)
|
|
1454
|
+
self.ui.training_epochs_spinbox.setValue(DEFAULT_TRAINING_EPOCHS)
|
|
1485
1455
|
|
|
1486
1456
|
|
|
1487
1457
|
def run_primary_window() -> None:
|
accusleepy/gui/manual_scoring.py
CHANGED
|
@@ -32,7 +32,7 @@ from PySide6.QtWidgets import (
|
|
|
32
32
|
)
|
|
33
33
|
|
|
34
34
|
from accusleepy.constants import UNDEFINED_LABEL
|
|
35
|
-
from accusleepy.fileio import load_config, save_labels
|
|
35
|
+
from accusleepy.fileio import load_config, save_labels, EMGFilter
|
|
36
36
|
from accusleepy.gui.mplwidget import resample_x_ticks
|
|
37
37
|
from accusleepy.gui.viewer_window import Ui_ViewerWindow
|
|
38
38
|
from accusleepy.signal_processing import create_spectrogram, get_emg_power
|
|
@@ -79,6 +79,12 @@ UNDEFINED_STATE = "undefined"
|
|
|
79
79
|
SCROLL_BOUNDARY = 0.35
|
|
80
80
|
# max number of sequential undo actions allowed
|
|
81
81
|
UNDO_LIMIT = 1000
|
|
82
|
+
# brightness scaling factors for the spectrogram
|
|
83
|
+
BRIGHTER_SCALE_FACTOR = 0.96
|
|
84
|
+
DIMMER_SCALE_FACTOR = 1.07
|
|
85
|
+
# zoom factor for upper plots
|
|
86
|
+
ZOOM_IN_FACTOR = 0.45
|
|
87
|
+
ZOOM_OUT_FACTOR = 1.017
|
|
82
88
|
|
|
83
89
|
|
|
84
90
|
@dataclass
|
|
@@ -102,6 +108,7 @@ class ManualScoringWindow(QDialog):
|
|
|
102
108
|
confidence_scores: np.array,
|
|
103
109
|
sampling_rate: int | float,
|
|
104
110
|
epoch_length: int | float,
|
|
111
|
+
emg_filter: EMGFilter,
|
|
105
112
|
):
|
|
106
113
|
"""Initialize the manual scoring window
|
|
107
114
|
|
|
@@ -112,6 +119,7 @@ class ManualScoringWindow(QDialog):
|
|
|
112
119
|
:param confidence_scores: confidence scores
|
|
113
120
|
:param sampling_rate: sampling rate, in Hz
|
|
114
121
|
:param epoch_length: epoch length, in seconds
|
|
122
|
+
:param emg_filter: EMG filter parameters
|
|
115
123
|
"""
|
|
116
124
|
super(ManualScoringWindow, self).__init__()
|
|
117
125
|
|
|
@@ -122,6 +130,7 @@ class ManualScoringWindow(QDialog):
|
|
|
122
130
|
self.confidence_scores = confidence_scores
|
|
123
131
|
self.sampling_rate = sampling_rate
|
|
124
132
|
self.epoch_length = epoch_length
|
|
133
|
+
self.emg_filter = emg_filter
|
|
125
134
|
|
|
126
135
|
self.n_epochs = len(self.labels)
|
|
127
136
|
|
|
@@ -131,7 +140,7 @@ class ManualScoringWindow(QDialog):
|
|
|
131
140
|
self.setWindowTitle("AccuSleePy manual scoring window")
|
|
132
141
|
|
|
133
142
|
# load set of valid brain states
|
|
134
|
-
self.brain_state_set, _, _ = load_config()
|
|
143
|
+
self.brain_state_set, _, _, _, _, _, _ = load_config()
|
|
135
144
|
|
|
136
145
|
# initial setting for number of epochs to show in the lower plot
|
|
137
146
|
self.epochs_to_show = 5
|
|
@@ -153,7 +162,7 @@ class ManualScoringWindow(QDialog):
|
|
|
153
162
|
|
|
154
163
|
# calculate RMS of EMG for each epoch and apply a ceiling
|
|
155
164
|
self.upper_emg = create_upper_emg_signal(
|
|
156
|
-
self.emg, self.sampling_rate, self.epoch_length
|
|
165
|
+
self.emg, self.sampling_rate, self.epoch_length, self.emg_filter
|
|
157
166
|
)
|
|
158
167
|
|
|
159
168
|
# center and scale the EEG and EMG signals to fit the display
|
|
@@ -229,23 +238,6 @@ class ManualScoringWindow(QDialog):
|
|
|
229
238
|
keypress_zoom_out_x = QShortcut(QKeySequence(Qt.Key.Key_Minus), self)
|
|
230
239
|
keypress_zoom_out_x.activated.connect(partial(self.zoom_x, ZOOM_OUT))
|
|
231
240
|
|
|
232
|
-
keypress_modify_label = list()
|
|
233
|
-
for brain_state in self.brain_state_set.brain_states:
|
|
234
|
-
keypress_modify_label.append(
|
|
235
|
-
QShortcut(
|
|
236
|
-
QKeySequence(Qt.Key[f"Key_{brain_state.digit}"]),
|
|
237
|
-
self,
|
|
238
|
-
)
|
|
239
|
-
)
|
|
240
|
-
keypress_modify_label[-1].activated.connect(
|
|
241
|
-
partial(self.modify_current_epoch_label, brain_state.digit)
|
|
242
|
-
)
|
|
243
|
-
|
|
244
|
-
keypress_delete_label = QShortcut(QKeySequence(Qt.Key.Key_Backspace), self)
|
|
245
|
-
keypress_delete_label.activated.connect(
|
|
246
|
-
partial(self.modify_current_epoch_label, UNDEFINED_LABEL)
|
|
247
|
-
)
|
|
248
|
-
|
|
249
241
|
keypress_quit = QShortcut(
|
|
250
242
|
QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_W)),
|
|
251
243
|
self,
|
|
@@ -258,36 +250,37 @@ class ManualScoringWindow(QDialog):
|
|
|
258
250
|
)
|
|
259
251
|
keypress_save.activated.connect(self.save)
|
|
260
252
|
|
|
253
|
+
keypress_modify_label = list()
|
|
261
254
|
keypress_roi = list()
|
|
262
|
-
|
|
255
|
+
digit_key_label_pairs = [
|
|
256
|
+
(Qt.Key[f"Key_{brain_state.digit}"], brain_state.digit)
|
|
257
|
+
for brain_state in self.brain_state_set.brain_states
|
|
258
|
+
] + [(Qt.Key.Key_Backspace, UNDEFINED_LABEL)]
|
|
259
|
+
|
|
260
|
+
for digit_key, digit_label in digit_key_label_pairs:
|
|
261
|
+
keypress_modify_label.append(
|
|
262
|
+
QShortcut(
|
|
263
|
+
QKeySequence(digit_key),
|
|
264
|
+
self,
|
|
265
|
+
)
|
|
266
|
+
)
|
|
267
|
+
keypress_modify_label[-1].activated.connect(
|
|
268
|
+
partial(self.modify_current_epoch_label, digit_label)
|
|
269
|
+
)
|
|
263
270
|
keypress_roi.append(
|
|
264
271
|
QShortcut(
|
|
265
272
|
QKeySequence(
|
|
266
273
|
QKeyCombination(
|
|
267
274
|
Qt.Modifier.SHIFT,
|
|
268
|
-
|
|
275
|
+
digit_key,
|
|
269
276
|
)
|
|
270
277
|
),
|
|
271
278
|
self,
|
|
272
279
|
)
|
|
273
280
|
)
|
|
274
281
|
keypress_roi[-1].activated.connect(
|
|
275
|
-
partial(self.enter_label_roi_mode,
|
|
276
|
-
)
|
|
277
|
-
keypress_roi.append(
|
|
278
|
-
QShortcut(
|
|
279
|
-
QKeySequence(
|
|
280
|
-
QKeyCombination(
|
|
281
|
-
Qt.Modifier.SHIFT,
|
|
282
|
-
Qt.Key.Key_Backspace,
|
|
283
|
-
)
|
|
284
|
-
),
|
|
285
|
-
self,
|
|
282
|
+
partial(self.enter_label_roi_mode, digit_label)
|
|
286
283
|
)
|
|
287
|
-
)
|
|
288
|
-
keypress_roi[-1].activated.connect(
|
|
289
|
-
partial(self.enter_label_roi_mode, UNDEFINED_LABEL)
|
|
290
|
-
)
|
|
291
284
|
|
|
292
285
|
keypress_esc = QShortcut(QKeySequence(Qt.Key.Key_Escape), self)
|
|
293
286
|
keypress_esc.activated.connect(self.exit_label_roi_mode)
|
|
@@ -623,9 +616,9 @@ class ManualScoringWindow(QDialog):
|
|
|
623
616
|
"""
|
|
624
617
|
vmin, vmax = self.ui.upperfigure.spec_ref.get_clim()
|
|
625
618
|
if direction == BRIGHTER:
|
|
626
|
-
self.ui.upperfigure.spec_ref.set(clim=(vmin, vmax *
|
|
619
|
+
self.ui.upperfigure.spec_ref.set(clim=(vmin, vmax * BRIGHTER_SCALE_FACTOR))
|
|
627
620
|
else:
|
|
628
|
-
self.ui.upperfigure.spec_ref.set(clim=(vmin, vmax *
|
|
621
|
+
self.ui.upperfigure.spec_ref.set(clim=(vmin, vmax * DIMMER_SCALE_FACTOR))
|
|
629
622
|
self.ui.upperfigure.canvas.draw()
|
|
630
623
|
|
|
631
624
|
def update_epochs_shown(self, direction: str) -> None:
|
|
@@ -722,31 +715,29 @@ class ManualScoringWindow(QDialog):
|
|
|
722
715
|
|
|
723
716
|
:param direction: in, out, or reset
|
|
724
717
|
"""
|
|
725
|
-
zoom_in_factor = 0.45
|
|
726
|
-
zoom_out_factor = 1.017
|
|
727
718
|
epochs_shown = self.upper_right_epoch - self.upper_left_epoch + 1
|
|
728
719
|
if direction == ZOOM_IN:
|
|
729
720
|
self.upper_left_epoch = max(
|
|
730
721
|
[
|
|
731
722
|
self.upper_left_epoch,
|
|
732
|
-
round(self.epoch -
|
|
723
|
+
round(self.epoch - ZOOM_IN_FACTOR * epochs_shown),
|
|
733
724
|
]
|
|
734
725
|
)
|
|
735
726
|
|
|
736
727
|
self.upper_right_epoch = min(
|
|
737
728
|
[
|
|
738
729
|
self.upper_right_epoch,
|
|
739
|
-
round(self.epoch +
|
|
730
|
+
round(self.epoch + ZOOM_IN_FACTOR * epochs_shown),
|
|
740
731
|
]
|
|
741
732
|
)
|
|
742
733
|
|
|
743
734
|
elif direction == ZOOM_OUT:
|
|
744
735
|
self.upper_left_epoch = max(
|
|
745
|
-
[0, round(self.epoch -
|
|
736
|
+
[0, round(self.epoch - ZOOM_OUT_FACTOR * epochs_shown)]
|
|
746
737
|
)
|
|
747
738
|
|
|
748
739
|
self.upper_right_epoch = min(
|
|
749
|
-
[self.n_epochs - 1, round(self.epoch +
|
|
740
|
+
[self.n_epochs - 1, round(self.epoch + ZOOM_OUT_FACTOR * epochs_shown)]
|
|
750
741
|
)
|
|
751
742
|
|
|
752
743
|
else: # reset
|
|
@@ -1063,21 +1054,28 @@ def create_confidence_img(confidence_scores: np.array) -> np.array:
|
|
|
1063
1054
|
|
|
1064
1055
|
|
|
1065
1056
|
def create_upper_emg_signal(
|
|
1066
|
-
emg: np.array,
|
|
1057
|
+
emg: np.array,
|
|
1058
|
+
sampling_rate: int | float,
|
|
1059
|
+
epoch_length: int | float,
|
|
1060
|
+
emg_filter: EMGFilter,
|
|
1067
1061
|
) -> np.array:
|
|
1068
1062
|
"""Calculate RMS of EMG for each epoch and apply a ceiling
|
|
1069
1063
|
|
|
1070
1064
|
:param emg: EMG signal
|
|
1071
1065
|
:param sampling_rate: sampling rate, in Hz
|
|
1072
1066
|
:param epoch_length: epoch length, in seconds
|
|
1067
|
+
:param emg_filter: EMG filter parameters
|
|
1073
1068
|
:return: processed EMG signal
|
|
1074
1069
|
"""
|
|
1075
1070
|
emg_rms = get_emg_power(
|
|
1076
1071
|
emg,
|
|
1077
1072
|
sampling_rate,
|
|
1078
1073
|
epoch_length,
|
|
1074
|
+
emg_filter,
|
|
1075
|
+
)
|
|
1076
|
+
return np.clip(
|
|
1077
|
+
emg_rms, np.percentile(emg_rms, 0.1), np.mean(emg_rms) + np.std(emg_rms) * 2.5
|
|
1079
1078
|
)
|
|
1080
|
-
return np.clip(emg_rms, 0, np.mean(emg_rms) + np.std(emg_rms) * 2.5)
|
|
1081
1079
|
|
|
1082
1080
|
|
|
1083
1081
|
def transform_eeg_emg(eeg: np.array, emg: np.array) -> (np.array, np.array):
|