accusleepy 0.8.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
accusleepy/gui/main.py CHANGED
@@ -1,15 +1,13 @@
1
1
  # AccuSleePy main window
2
2
  # Icon source: Arkinasi, https://www.flaticon.com/authors/arkinasi
3
3
 
4
- import datetime
4
+ import logging
5
5
  import os
6
- import shutil
7
6
  import sys
8
7
  from dataclasses import dataclass
9
8
  from functools import partial
10
9
 
11
10
  import numpy as np
12
- import toml
13
11
  from PySide6.QtCore import (
14
12
  QEvent,
15
13
  QKeyCombination,
@@ -21,31 +19,17 @@ from PySide6.QtCore import (
21
19
  from PySide6.QtGui import QKeySequence, QShortcut
22
20
  from PySide6.QtWidgets import (
23
21
  QApplication,
24
- QCheckBox,
25
- QDoubleSpinBox,
26
- QFileDialog,
27
22
  QLabel,
28
- QListWidgetItem,
29
23
  QMainWindow,
30
24
  QTextBrowser,
31
25
  QVBoxLayout,
32
26
  QWidget,
33
27
  )
34
28
 
35
- from accusleepy.bouts import enforce_min_bout_length
36
- from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
29
+ from accusleepy.brain_state_set import BRAIN_STATES_KEY
37
30
  from accusleepy.constants import (
38
- ANNOTATIONS_FILENAME,
39
- CALIBRATION_ANNOTATION_FILENAME,
40
31
  CALIBRATION_FILE_TYPE,
41
32
  DEFAULT_MODEL_TYPE,
42
- DEFAULT_EMG_FILTER_ORDER,
43
- DEFAULT_EMG_BP_LOWER,
44
- DEFAULT_EMG_BP_UPPER,
45
- DEFAULT_BATCH_SIZE,
46
- DEFAULT_LEARNING_RATE,
47
- DEFAULT_MOMENTUM,
48
- DEFAULT_TRAINING_EPOCHS,
49
33
  LABEL_FILE_TYPE,
50
34
  MESSAGE_BOX_MAX_DEPTH,
51
35
  MODEL_FILE_TYPE,
@@ -55,31 +39,28 @@ from accusleepy.constants import (
55
39
  UNDEFINED_LABEL,
56
40
  )
57
41
  from accusleepy.fileio import (
58
- Recording,
59
- load_calibration_file,
60
42
  load_config,
61
43
  load_labels,
62
44
  load_recording,
63
- load_recording_list,
64
- save_config,
65
- save_labels,
66
- save_recording_list,
67
- EMGFilter,
68
- Hyperparameters,
45
+ get_version,
69
46
  )
47
+ from accusleepy.gui.dialogs import select_existing_file, select_save_location
70
48
  from accusleepy.gui.manual_scoring import ManualScoringWindow
71
49
  from accusleepy.gui.primary_window import Ui_PrimaryWindow
72
- from accusleepy.signal_processing import (
73
- create_training_images,
74
- resample_and_standardize,
75
- )
76
- from accusleepy.validation import (
77
- check_label_validity,
78
- LABEL_LENGTH_ERROR,
79
- check_config_consistency,
50
+ from accusleepy.gui.recording_manager import RecordingListManager
51
+ from accusleepy.gui.settings_widget import SettingsWidget
52
+ from accusleepy.services import (
53
+ LoadedModel,
54
+ TrainingService,
55
+ check_single_file_inputs,
56
+ create_calibration,
57
+ score_recording_list,
80
58
  )
59
+ from accusleepy.validation import validate_and_correct_labels
60
+ from accusleepy.signal_processing import resample_and_standardize
61
+ from accusleepy.validation import check_config_consistency
81
62
 
82
- # note: functions using torch or scipy are lazily imported
63
+ logger = logging.getLogger(__name__)
83
64
 
84
65
  # on Windows, prevent dark mode from changing the visual style
85
66
  if os.name == "nt":
@@ -91,14 +72,22 @@ MAIN_GUIDE_FILE = os.path.normpath(r"text/main_guide.md")
91
72
 
92
73
 
93
74
  @dataclass
94
- class StateSettings:
95
- """Widgets for config settings for a brain state"""
75
+ class TrainingSettings:
76
+ """Settings for training a new model"""
77
+
78
+ epochs_per_img: int = 9
79
+ delete_images: bool = True
80
+ model_type: str = DEFAULT_MODEL_TYPE
81
+ calibrate: bool = True
82
+
83
+
84
+ @dataclass
85
+ class ScoringSettings:
86
+ """Settings for scoring a recording"""
96
87
 
97
- digit: int
98
- enabled_widget: QCheckBox
99
- name_widget: QLabel
100
- is_scored_widget: QCheckBox
101
- frequency_widget: QDoubleSpinBox
88
+ only_overwrite_undefined: bool
89
+ save_confidence_scores: bool
90
+ min_bout_length: int | float
102
91
 
103
92
 
104
93
  class AccuSleepWindow(QMainWindow):
@@ -112,65 +101,42 @@ class AccuSleepWindow(QMainWindow):
112
101
  self.ui.setupUi(self)
113
102
  self.setWindowTitle("AccuSleePy")
114
103
 
115
- # fill in settings tab
116
- config = load_config()
117
- self.brain_state_set = config.brain_state_set
118
- self.epoch_length = config.default_epoch_length
119
- self.only_overwrite_undefined = config.overwrite_setting
120
- self.save_confidence_scores = config.save_confidence_setting
121
- self.min_bout_length = config.min_bout_length
122
- self.emg_filter = config.emg_filter
123
- self.hyperparameters = config.hyperparameters
124
- self.default_epochs_to_show = config.epochs_to_show
125
- self.default_autoscroll_state = config.autoscroll_state
126
-
127
- self.settings_widgets = None
128
- self.initialize_settings_tab()
104
+ # Load configuration
105
+ loaded_config = load_config()
106
+
107
+ # Apply default values from the configuration
108
+ self.epoch_length = loaded_config.default_epoch_length
109
+ self.scoring = ScoringSettings(
110
+ only_overwrite_undefined=loaded_config.overwrite_setting,
111
+ save_confidence_scores=loaded_config.save_confidence_setting,
112
+ min_bout_length=loaded_config.min_bout_length,
113
+ )
114
+
115
+ # Initialize settings tab (manages Settings tab UI and saved config values)
116
+ self.config = SettingsWidget(ui=self.ui, config=loaded_config, parent=self)
129
117
 
130
118
  # initialize info about the recordings, classification data / settings
131
119
  self.ui.epoch_length_input.setValue(self.epoch_length)
132
- self.ui.overwritecheckbox.setChecked(self.only_overwrite_undefined)
133
- self.ui.save_confidence_checkbox.setChecked(self.save_confidence_scores)
134
- self.ui.bout_length_input.setValue(self.min_bout_length)
135
- self.model = None
120
+ self.ui.overwritecheckbox.setChecked(self.scoring.only_overwrite_undefined)
121
+ self.ui.save_confidence_checkbox.setChecked(self.scoring.save_confidence_scores)
122
+ self.ui.bout_length_input.setValue(self.scoring.min_bout_length)
136
123
 
137
- # initialize model training variables
138
- self.training_epochs_per_img = 9
139
- self.delete_training_images = True
140
- self.model_type = DEFAULT_MODEL_TYPE
141
- self.calibrate_trained_model = True
124
+ # loaded classification model and its metadata
125
+ self.loaded_model = LoadedModel()
142
126
 
143
- # metadata for the currently loaded classification model
144
- self.model_epoch_length = None
145
- self.model_epochs_per_img = None
127
+ # settings for training new models
128
+ self.training = TrainingSettings()
146
129
 
147
130
  # set up the list of recordings
148
- first_recording = Recording(
149
- widget=QListWidgetItem("Recording 1", self.ui.recording_list_widget),
131
+ self.recording_manager = RecordingListManager(
132
+ self.ui.recording_list_widget, parent=self
150
133
  )
151
- self.ui.recording_list_widget.addItem(first_recording.widget)
152
- self.ui.recording_list_widget.setCurrentRow(0)
153
- # index of currently selected recording in the list
154
- self.recording_index = 0
155
- # list of recordings the user has added
156
- self.recordings = [first_recording]
157
134
 
158
135
  # messages to display
159
136
  self.messages = []
160
137
 
161
138
  # display current version
162
- version = ""
163
- toml_file = os.path.join(
164
- os.path.dirname(
165
- os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
166
- ),
167
- "pyproject.toml",
168
- )
169
- if os.path.isfile(toml_file):
170
- toml_data = toml.load(toml_file)
171
- if "project" in toml_data and "version" in toml_data["project"]:
172
- version = toml_data["project"]["version"]
173
- self.ui.version_label.setText(f"v{version}")
139
+ self.ui.version_label.setText(f"v{get_version()}")
174
140
 
175
141
  # user input: keyboard shortcuts
176
142
  keypress_quit = QShortcut(
@@ -183,8 +149,12 @@ class AccuSleepWindow(QMainWindow):
183
149
  self.ui.add_button.clicked.connect(self.add_recording)
184
150
  self.ui.remove_button.clicked.connect(self.remove_recording)
185
151
  self.ui.recording_list_widget.currentRowChanged.connect(self.select_recording)
186
- self.ui.sampling_rate_input.valueChanged.connect(self.update_sampling_rate)
187
- self.ui.epoch_length_input.valueChanged.connect(self.update_epoch_length)
152
+ self.ui.sampling_rate_input.valueChanged.connect(
153
+ lambda v: setattr(self.recording_manager.current, "sampling_rate", v)
154
+ )
155
+ self.ui.epoch_length_input.valueChanged.connect(
156
+ lambda v: setattr(self, "epoch_length", v)
157
+ )
188
158
  self.ui.recording_file_button.clicked.connect(self.select_recording_file)
189
159
  self.ui.select_label_button.clicked.connect(self.select_label_file)
190
160
  self.ui.create_label_button.clicked.connect(self.create_label_file)
@@ -192,27 +162,33 @@ class AccuSleepWindow(QMainWindow):
192
162
  self.ui.create_calibration_button.clicked.connect(self.create_calibration_file)
193
163
  self.ui.select_calibration_button.clicked.connect(self.select_calibration_file)
194
164
  self.ui.load_model_button.clicked.connect(partial(self.load_model, None))
195
- self.ui.score_all_button.clicked.connect(self.score_all)
196
- self.ui.overwritecheckbox.stateChanged.connect(self.update_overwrite_policy)
165
+ self.ui.score_all_button.clicked.connect(self.score_recordings)
166
+ self.ui.overwritecheckbox.stateChanged.connect(
167
+ lambda v: setattr(self.scoring, "only_overwrite_undefined", bool(v))
168
+ )
197
169
  self.ui.save_confidence_checkbox.stateChanged.connect(
198
- self.update_confidence_policy
170
+ lambda v: setattr(self.scoring, "save_confidence_scores", bool(v))
171
+ )
172
+ self.ui.bout_length_input.valueChanged.connect(
173
+ lambda v: setattr(self.scoring, "min_bout_length", v)
199
174
  )
200
- self.ui.bout_length_input.valueChanged.connect(self.update_min_bout_length)
201
175
  self.ui.user_manual_button.clicked.connect(self.show_user_manual)
202
- self.ui.image_number_input.valueChanged.connect(self.update_epochs_per_img)
203
- self.ui.delete_image_box.stateChanged.connect(self.update_image_deletion)
176
+ self.ui.image_number_input.valueChanged.connect(
177
+ lambda v: setattr(self.training, "epochs_per_img", v)
178
+ )
179
+ self.ui.delete_image_box.stateChanged.connect(
180
+ lambda v: setattr(self.training, "delete_images", bool(v))
181
+ )
204
182
  self.ui.calibrate_checkbox.stateChanged.connect(
205
183
  self.update_training_calibration
206
184
  )
207
185
  self.ui.train_model_button.clicked.connect(self.train_model)
208
- self.ui.save_config_button.clicked.connect(self.save_brain_state_config)
186
+ self.ui.save_config_button.clicked.connect(self.config.save_config)
209
187
  self.ui.export_button.clicked.connect(self.export_recording_list)
210
188
  self.ui.import_button.clicked.connect(self.import_recording_list)
211
189
  self.ui.default_type_button.toggled.connect(self.model_type_radio_buttons)
212
- self.ui.reset_emg_params_button.clicked.connect(self.reset_emg_filter_settings)
213
- self.ui.reset_hyperparams_button.clicked.connect(
214
- self.reset_hyperparams_settings
215
- )
190
+ self.ui.reset_emg_params_button.clicked.connect(self.config.reset_emg_filter)
191
+ self.ui.reset_hyperparams_button.clicked.connect(self.config.reset_hyperparams)
216
192
 
217
193
  # user input: drag and drop
218
194
  self.ui.recording_file_label.installEventFilter(self)
@@ -227,52 +203,29 @@ class AccuSleepWindow(QMainWindow):
227
203
 
228
204
  :param default_selected: whether default option is selected
229
205
  """
230
- self.model_type = (
206
+ self.training.model_type = (
231
207
  DEFAULT_MODEL_TYPE if default_selected else REAL_TIME_MODEL_TYPE
232
208
  )
233
209
 
234
210
  def export_recording_list(self) -> None:
235
211
  """Save current list of recordings to file"""
236
- # get the name for the recording list file
237
- filename, _ = QFileDialog.getSaveFileName(
238
- self,
239
- caption="Save list of recordings as",
240
- filter="*" + RECORDING_LIST_FILE_TYPE,
212
+ filename = select_save_location(
213
+ self, "Save list of recordings as", "*" + RECORDING_LIST_FILE_TYPE
241
214
  )
242
215
  if not filename:
243
216
  return
244
- filename = os.path.normpath(filename)
245
- save_recording_list(filename=filename, recordings=self.recordings)
217
+ self.recording_manager.export_to_file(filename)
246
218
  self.show_message(f"Saved list of recordings to {filename}")
247
219
 
248
220
  def import_recording_list(self):
249
221
  """Load list of recordings from file, overwriting current list"""
250
- file_dialog = QFileDialog(self)
251
- file_dialog.setWindowTitle("Select list of recordings")
252
- file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
253
- file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
254
- file_dialog.setNameFilter("*" + RECORDING_LIST_FILE_TYPE)
255
-
256
- if file_dialog.exec():
257
- selected_files = file_dialog.selectedFiles()
258
- filename = selected_files[0]
259
- filename = os.path.normpath(filename)
260
- else:
222
+ filename = select_existing_file(
223
+ self, "Select list of recordings", "*" + RECORDING_LIST_FILE_TYPE
224
+ )
225
+ if not filename:
261
226
  return
262
227
 
263
- # clear widget
264
- self.ui.recording_list_widget.clear()
265
- # overwrite current list
266
- self.recordings = load_recording_list(filename)
267
-
268
- for recording in self.recordings:
269
- recording.widget = QListWidgetItem(
270
- f"Recording {recording.name}", self.ui.recording_list_widget
271
- )
272
- self.ui.recording_list_widget.addItem(self.recordings[-1].widget)
273
-
274
- # display new list
275
- self.ui.recording_list_widget.setCurrentRow(0)
228
+ self.recording_manager.import_from_file(filename)
276
229
  self.show_message(f"Loaded list of recordings from {filename}")
277
230
 
278
231
  def eventFilter(self, obj: QObject, event: QEvent) -> bool:
@@ -297,20 +250,21 @@ class AccuSleepWindow(QMainWindow):
297
250
 
298
251
  if filename is None:
299
252
  return super().eventFilter(obj, event)
253
+ filename = str(filename)
300
254
 
301
255
  _, file_extension = os.path.splitext(filename)
302
256
 
303
257
  if obj == self.ui.recording_file_label:
304
258
  if file_extension in RECORDING_FILE_TYPES:
305
- self.recordings[self.recording_index].recording_file = filename
259
+ self.recording_manager.current.recording_file = filename
306
260
  self.ui.recording_file_label.setText(filename)
307
261
  elif obj == self.ui.label_file_label:
308
262
  if file_extension == LABEL_FILE_TYPE:
309
- self.recordings[self.recording_index].label_file = filename
263
+ self.recording_manager.current.label_file = filename
310
264
  self.ui.label_file_label.setText(filename)
311
265
  elif obj == self.ui.calibration_file_label:
312
266
  if file_extension == CALIBRATION_FILE_TYPE:
313
- self.recordings[self.recording_index].calibration_file = filename
267
+ self.recording_manager.current.calibration_file = filename
314
268
  self.ui.calibration_file_label.setText(filename)
315
269
  elif obj == self.ui.model_label:
316
270
  self.load_model(filename=filename)
@@ -318,335 +272,69 @@ class AccuSleepWindow(QMainWindow):
318
272
  return super().eventFilter(obj, event)
319
273
 
320
274
  def train_model(self) -> None:
321
- # check basic training inputs
322
- if (
323
- self.model_type == DEFAULT_MODEL_TYPE
324
- and self.training_epochs_per_img % 2 == 0
325
- ):
326
- self.show_message(
327
- (
328
- "ERROR: for the default model type, number of epochs "
329
- "per image must be an odd number."
330
- )
331
- )
332
- return
333
-
334
- # determine fraction of training data to use for calibration
335
- if self.calibrate_trained_model:
336
- calibration_fraction = self.ui.calibration_spinbox.value() / 100
337
- else:
338
- calibration_fraction = 0
339
-
340
- # check some inputs for each recording
341
- for recording_index in range(len(self.recordings)):
342
- error_message = self.check_single_file_inputs(recording_index)
343
- if error_message:
344
- self.show_message(
345
- f"ERROR (recording {self.recordings[recording_index].name}): {error_message}"
346
- )
347
- return
348
-
349
- # get filename for the new model
350
- model_filename, _ = QFileDialog.getSaveFileName(
351
- self,
352
- caption="Save classification model file as",
353
- filter="*" + MODEL_FILE_TYPE,
275
+ """Train a classification model using the current recordings."""
276
+ model_filename = select_save_location(
277
+ self, "Save classification model file as", "*" + MODEL_FILE_TYPE
354
278
  )
355
279
  if not model_filename:
356
280
  self.show_message("Model training canceled, no filename given")
357
281
  return
358
- model_filename = os.path.normpath(model_filename)
359
282
 
360
- # create (probably temporary) image folder in
361
- # the same folder as the trained model
362
- temp_image_dir = os.path.join(
363
- os.path.dirname(model_filename),
364
- "images_" + datetime.datetime.now().strftime("%Y%m%d%H%M"),
365
- )
366
-
367
- if os.path.exists(temp_image_dir): # unlikely
368
- self.show_message(
369
- "Warning: training image folder exists, will be overwritten"
370
- )
371
- os.makedirs(temp_image_dir, exist_ok=True)
283
+ # Determine calibration fraction
284
+ if self.training.calibrate:
285
+ calibration_fraction = self.ui.calibration_spinbox.value() / 100
286
+ else:
287
+ calibration_fraction = 0
372
288
 
373
- # create training images
289
+ # Show progress message
374
290
  self.show_message("Training, please wait. See console for progress updates.")
375
- if not self.delete_training_images:
376
- self.show_message((f"Creating training images in {temp_image_dir}"))
377
- else:
378
- self.show_message(
379
- (f"Creating temporary folder of training images: {temp_image_dir}")
380
- )
381
291
  self.ui.message_area.repaint()
382
292
  QApplication.processEvents()
383
- print("Creating training images")
384
- failed_recordings = create_training_images(
385
- recordings=self.recordings,
386
- output_path=temp_image_dir,
387
- epoch_length=self.epoch_length,
388
- epochs_per_img=self.training_epochs_per_img,
389
- brain_state_set=self.brain_state_set,
390
- model_type=self.model_type,
391
- calibration_fraction=calibration_fraction,
392
- emg_filter=self.emg_filter,
393
- )
394
- if len(failed_recordings) > 0:
395
- if len(failed_recordings) == len(self.recordings):
396
- self.show_message("ERROR: no recordings were valid!")
397
- return
398
- else:
399
- self.show_message(
400
- (
401
- "WARNING: the following recordings could not be "
402
- "loaded and will not be used for training: "
403
- f"{', '.join([str(r) for r in failed_recordings])}"
404
- )
405
- )
406
293
 
407
- # train model
408
- self.show_message("Training model")
409
- self.ui.message_area.repaint()
410
- QApplication.processEvents()
411
- print("Training model")
412
- from accusleepy.classification import create_dataloader, train_ssann
413
- from accusleepy.models import save_model
414
- from accusleepy.temperature_scaling import ModelWithTemperature
415
-
416
- model = train_ssann(
417
- annotations_file=os.path.join(temp_image_dir, ANNOTATIONS_FILENAME),
418
- img_dir=temp_image_dir,
419
- mixture_weights=self.brain_state_set.mixture_weights,
420
- n_classes=self.brain_state_set.n_classes,
421
- hyperparameters=self.hyperparameters,
422
- )
423
-
424
- # calibrate the model
425
- if self.calibrate_trained_model:
426
- calibration_annotation_file = os.path.join(
427
- temp_image_dir, CALIBRATION_ANNOTATION_FILENAME
428
- )
429
- calibration_dataloader = create_dataloader(
430
- annotations_file=calibration_annotation_file,
431
- img_dir=temp_image_dir,
432
- hyperparameters=self.hyperparameters,
433
- )
434
- model = ModelWithTemperature(model)
435
- print("Calibrating model")
436
- model.set_temperature(calibration_dataloader)
437
-
438
- # save model
439
- save_model(
440
- model=model,
441
- filename=model_filename,
294
+ # Create service and run training
295
+ service = TrainingService(progress_callback=self.show_message)
296
+ result = service.train_model(
297
+ recordings=list(self.recording_manager),
442
298
  epoch_length=self.epoch_length,
443
- epochs_per_img=self.training_epochs_per_img,
444
- model_type=self.model_type,
445
- brain_state_set=self.brain_state_set,
446
- is_calibrated=self.calibrate_trained_model,
299
+ epochs_per_img=self.training.epochs_per_img,
300
+ model_type=self.training.model_type,
301
+ calibrate=self.training.calibrate,
302
+ calibration_fraction=calibration_fraction,
303
+ brain_state_set=self.config.brain_state_set,
304
+ emg_filter=self.config.emg_filter,
305
+ hyperparameters=self.config.hyperparameters,
306
+ model_filename=model_filename,
307
+ delete_images=self.training.delete_images,
447
308
  )
448
309
 
449
- # optionally delete images
450
- if self.delete_training_images:
451
- print("Cleaning up training image folder")
452
- shutil.rmtree(temp_image_dir)
453
-
454
- self.show_message(f"Training complete. Saved model to {model_filename}")
455
- print("Training complete.")
456
-
457
- def update_image_deletion(self) -> None:
458
- """Update choice of whether to delete images after training"""
459
- self.delete_training_images = self.ui.delete_image_box.isChecked()
310
+ # Display results
311
+ result.report_to(self.show_message)
460
312
 
461
313
  def update_training_calibration(self) -> None:
462
314
  """Update choice of whether to calibrate model after training"""
463
- self.calibrate_trained_model = self.ui.calibrate_checkbox.isChecked()
464
- self.ui.calibration_spinbox.setEnabled(self.calibrate_trained_model)
465
-
466
- def update_epochs_per_img(self, new_value) -> None:
467
- """Update number of epochs per image
468
-
469
- :param new_value: new number of epochs per image
470
- """
471
- self.training_epochs_per_img = new_value
472
-
473
- def score_all(self) -> None:
474
- """Score all recordings using the classification model"""
475
- # check basic inputs
476
- if self.model is None:
477
- self.ui.score_all_status.setText("missing classification model")
478
- self.show_message("ERROR: no classification model file selected")
479
- return
480
- if self.min_bout_length < self.epoch_length:
481
- self.ui.score_all_status.setText("invalid minimum bout length")
482
- self.show_message("ERROR: minimum bout length must be >= epoch length")
483
- return
484
- if self.epoch_length != self.model_epoch_length:
485
- self.ui.score_all_status.setText("invalid epoch length")
486
- self.show_message(
487
- (
488
- "ERROR: model was trained with an epoch length of "
489
- f"{self.model_epoch_length} seconds, but the current "
490
- f"epoch length setting is {self.epoch_length} seconds."
491
- )
492
- )
493
- return
315
+ self.training.calibrate = self.ui.calibrate_checkbox.isChecked()
316
+ self.ui.calibration_spinbox.setEnabled(self.training.calibrate)
494
317
 
318
+ def score_recordings(self) -> None:
319
+ """Score all recordings using the classification model."""
495
320
  self.ui.score_all_status.setText("running...")
496
321
  self.ui.score_all_status.repaint()
497
322
  QApplication.processEvents()
498
323
 
499
- from accusleepy.classification import score_recording
500
-
501
- # check some inputs for each recording
502
- for recording_index in range(len(self.recordings)):
503
- error_message = self.check_single_file_inputs(recording_index)
504
- if error_message:
505
- self.ui.score_all_status.setText(
506
- f"error on recording {self.recordings[recording_index].name}"
507
- )
508
- self.show_message(
509
- f"ERROR (recording {self.recordings[recording_index].name}): {error_message}"
510
- )
511
- return
512
- if self.recordings[recording_index].calibration_file == "":
513
- self.ui.score_all_status.setText(
514
- f"error on recording {self.recordings[recording_index].name}"
515
- )
516
- self.show_message(
517
- (
518
- f"ERROR (recording {self.recordings[recording_index].name}): "
519
- "no calibration file selected"
520
- )
521
- )
522
- return
523
-
524
- # score each recording
525
- for recording_index in range(len(self.recordings)):
526
- # load EEG, EMG
527
- try:
528
- eeg, emg = load_recording(
529
- self.recordings[recording_index].recording_file
530
- )
531
- sampling_rate = self.recordings[recording_index].sampling_rate
532
-
533
- eeg, emg, sampling_rate = resample_and_standardize(
534
- eeg=eeg,
535
- emg=emg,
536
- sampling_rate=sampling_rate,
537
- epoch_length=self.epoch_length,
538
- )
539
- except Exception:
540
- self.show_message(
541
- (
542
- "ERROR: could not load recording "
543
- f"{self.recordings[recording_index].name}."
544
- "This recording will be skipped."
545
- )
546
- )
547
- continue
548
-
549
- # load labels
550
- label_file = self.recordings[recording_index].label_file
551
- if os.path.isfile(label_file):
552
- try:
553
- # ignore any existing confidence scores; they will all be overwritten
554
- existing_labels, _ = load_labels(label_file)
555
- except Exception:
556
- self.show_message(
557
- (
558
- "ERROR: could not load existing labels for recording "
559
- f"{self.recordings[recording_index].name}."
560
- "This recording will be skipped."
561
- )
562
- )
563
- continue
564
- # only check the length
565
- samples_per_epoch = sampling_rate * self.epoch_length
566
- epochs_in_recording = round(eeg.size / samples_per_epoch)
567
- if epochs_in_recording != existing_labels.size:
568
- self.show_message(
569
- (
570
- "ERROR: existing labels for recording "
571
- f"{self.recordings[recording_index].name} "
572
- "do not match the recording length. "
573
- "This recording will be skipped."
574
- )
575
- )
576
- continue
577
- else:
578
- existing_labels = None
579
-
580
- # load calibration data
581
- if not os.path.isfile(self.recordings[recording_index].calibration_file):
582
- self.show_message(
583
- (
584
- "ERROR: calibration file does not exist for recording "
585
- f"{self.recordings[recording_index].name}. "
586
- "This recording will be skipped."
587
- )
588
- )
589
- continue
590
- try:
591
- (
592
- mixture_means,
593
- mixture_sds,
594
- ) = load_calibration_file(
595
- self.recordings[recording_index].calibration_file
596
- )
597
- except Exception:
598
- self.show_message(
599
- (
600
- "ERROR: could not load calibration file for recording "
601
- f"{self.recordings[recording_index].name}. "
602
- "This recording will be skipped."
603
- )
604
- )
605
- continue
606
-
607
- labels, confidence_scores = score_recording(
608
- model=self.model,
609
- eeg=eeg,
610
- emg=emg,
611
- mixture_means=mixture_means,
612
- mixture_sds=mixture_sds,
613
- sampling_rate=sampling_rate,
614
- epoch_length=self.epoch_length,
615
- epochs_per_img=self.model_epochs_per_img,
616
- brain_state_set=self.brain_state_set,
617
- emg_filter=self.emg_filter,
618
- )
619
-
620
- # overwrite as needed
621
- if existing_labels is not None and self.only_overwrite_undefined:
622
- labels[existing_labels != UNDEFINED_LABEL] = existing_labels[
623
- existing_labels != UNDEFINED_LABEL
624
- ]
625
-
626
- # enforce minimum bout length
627
- labels = enforce_min_bout_length(
628
- labels=labels,
629
- epoch_length=self.epoch_length,
630
- min_bout_length=self.min_bout_length,
631
- )
632
-
633
- # ignore confidence scores if desired
634
- if not self.save_confidence_scores:
635
- confidence_scores = None
636
-
637
- # save results
638
- save_labels(
639
- labels=labels, filename=label_file, confidence_scores=confidence_scores
640
- )
641
- self.show_message(
642
- (
643
- "Saved labels for recording "
644
- f"{self.recordings[recording_index].name} "
645
- f"to {label_file}"
646
- )
647
- )
324
+ result = score_recording_list(
325
+ recordings=list(self.recording_manager),
326
+ loaded_model=self.loaded_model,
327
+ epoch_length=self.epoch_length,
328
+ only_overwrite_undefined=self.scoring.only_overwrite_undefined,
329
+ save_confidence_scores=self.scoring.save_confidence_scores,
330
+ min_bout_length=self.scoring.min_bout_length,
331
+ brain_state_set=self.config.brain_state_set,
332
+ emg_filter=self.config.emg_filter,
333
+ )
648
334
 
649
- self.ui.score_all_status.setText("")
335
+ # Display results
336
+ result.report_to(self.show_message)
337
+ self.ui.score_all_status.setText("error" if not result.success else "")
650
338
 
651
339
  def load_model(self, filename=None) -> None:
652
340
  """Load trained classification model from file
@@ -654,17 +342,10 @@ class AccuSleepWindow(QMainWindow):
654
342
  :param filename: model filename, if it's known
655
343
  """
656
344
  if filename is None:
657
- file_dialog = QFileDialog(self)
658
- file_dialog.setWindowTitle("Select classification model")
659
- file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
660
- file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
661
- file_dialog.setNameFilter("*" + MODEL_FILE_TYPE)
662
-
663
- if file_dialog.exec():
664
- selected_files = file_dialog.selectedFiles()
665
- filename = selected_files[0]
666
- filename = os.path.normpath(filename)
667
- else:
345
+ filename = select_existing_file(
346
+ self, "Select classification model", "*" + MODEL_FILE_TYPE
347
+ )
348
+ if not filename:
668
349
  return
669
350
 
670
351
  if not os.path.isfile(filename):
@@ -682,6 +363,7 @@ class AccuSleepWindow(QMainWindow):
682
363
  filename=filename
683
364
  )
684
365
  except Exception:
366
+ logger.exception("Failed to load %s", filename)
685
367
  self.show_message(
686
368
  (
687
369
  "ERROR: could not load classification model. Check "
@@ -702,14 +384,14 @@ class AccuSleepWindow(QMainWindow):
702
384
  )
703
385
  return
704
386
 
705
- self.model = model
706
- self.model_epoch_length = epoch_length
707
- self.model_epochs_per_img = epochs_per_img
387
+ self.loaded_model.model = model
388
+ self.loaded_model.epoch_length = epoch_length
389
+ self.loaded_model.epochs_per_img = epochs_per_img
708
390
 
709
391
  # warn user if the model's expected epoch length or brain states
710
392
  # don't match the current configuration
711
393
  config_warnings = check_config_consistency(
712
- current_brain_states=self.brain_state_set.to_output_dict()[
394
+ current_brain_states=self.config.brain_state_set.to_output_dict()[
713
395
  BRAIN_STATES_KEY
714
396
  ],
715
397
  model_brain_states=brain_states,
@@ -736,17 +418,21 @@ class AccuSleepWindow(QMainWindow):
736
418
  :param status_widget: UI element on which to display error messages
737
419
  :return: EEG data, EMG data, sampling rate, process completion
738
420
  """
739
- error_message = self.check_single_file_inputs(self.recording_index)
421
+ error_message = check_single_file_inputs(
422
+ self.recording_manager.current, self.epoch_length
423
+ )
740
424
  if error_message:
741
425
  status_widget.setText(error_message)
742
426
  self.show_message(f"ERROR: {error_message}")
743
427
  return None, None, None, False
744
428
 
745
429
  try:
746
- eeg, emg = load_recording(
747
- self.recordings[self.recording_index].recording_file
748
- )
430
+ eeg, emg = load_recording(self.recording_manager.current.recording_file)
749
431
  except Exception:
432
+ logger.exception(
433
+ "Failed to load %s",
434
+ self.recording_manager.current.recording_file,
435
+ )
750
436
  status_widget.setText("could not load recording")
751
437
  self.show_message(
752
438
  (
@@ -756,7 +442,7 @@ class AccuSleepWindow(QMainWindow):
756
442
  )
757
443
  return None, None, None, False
758
444
 
759
- sampling_rate = self.recordings[self.recording_index].sampling_rate
445
+ sampling_rate = self.recording_manager.current.sampling_rate
760
446
 
761
447
  eeg, emg, sampling_rate = resample_and_standardize(
762
448
  eeg=eeg,
@@ -768,135 +454,35 @@ class AccuSleepWindow(QMainWindow):
768
454
  return eeg, emg, sampling_rate, True
769
455
 
770
456
  def create_calibration_file(self) -> None:
771
- """Creates a calibration file
457
+ """Creates a calibration file.
772
458
 
773
459
  This loads a recording and its labels, checks that the labels are
774
460
  all valid, creates the calibration file, and sets the
775
461
  "calibration file" property of the current recording to be the
776
462
  newly created file.
777
463
  """
778
- # load the recording
779
- eeg, emg, sampling_rate, success = self.load_single_recording(
780
- self.ui.calibration_status
781
- )
782
- if not success:
783
- return
784
-
785
- # load the labels
786
- label_file = self.recordings[self.recording_index].label_file
787
- if not os.path.isfile(label_file):
788
- self.ui.calibration_status.setText("label file does not exist")
789
- self.show_message("ERROR: label file does not exist")
790
- return
791
- try:
792
- labels, _ = load_labels(label_file)
793
- except Exception:
794
- self.ui.calibration_status.setText("could not load labels")
795
- self.show_message(
796
- (
797
- "ERROR: could not load labels. "
798
- "Check user manual for formatting instructions."
799
- )
800
- )
801
- return
802
- label_error_message = check_label_validity(
803
- labels=labels,
804
- confidence_scores=None,
805
- samples_in_recording=eeg.size,
806
- sampling_rate=sampling_rate,
807
- epoch_length=self.epoch_length,
808
- brain_state_set=self.brain_state_set,
809
- )
810
- if label_error_message:
811
- self.ui.calibration_status.setText("invalid label file")
812
- self.show_message(f"ERROR: {label_error_message}")
813
- return
814
-
815
- # get the name for the calibration file
816
- filename, _ = QFileDialog.getSaveFileName(
817
- self,
818
- caption="Save calibration file as",
819
- filter="*" + CALIBRATION_FILE_TYPE,
464
+ filename = select_save_location(
465
+ self, "Save calibration file as", "*" + CALIBRATION_FILE_TYPE
820
466
  )
821
467
  if not filename:
822
468
  return
823
- filename = os.path.normpath(filename)
824
469
 
825
- from accusleepy.classification import create_calibration_file
826
-
827
- create_calibration_file(
828
- filename=filename,
829
- eeg=eeg,
830
- emg=emg,
831
- labels=labels,
832
- sampling_rate=sampling_rate,
470
+ result = create_calibration(
471
+ recording=self.recording_manager.current,
833
472
  epoch_length=self.epoch_length,
834
- brain_state_set=self.brain_state_set,
835
- emg_filter=self.emg_filter,
836
- )
837
-
838
- self.ui.calibration_status.setText("")
839
- self.show_message(
840
- (
841
- "Created calibration file using recording "
842
- f"{self.recordings[self.recording_index].name} "
843
- f"at {filename}"
844
- )
473
+ brain_state_set=self.config.brain_state_set,
474
+ emg_filter=self.config.emg_filter,
475
+ output_filename=filename,
845
476
  )
846
477
 
847
- self.recordings[self.recording_index].calibration_file = filename
848
- self.ui.calibration_file_label.setText(filename)
849
-
850
- def check_single_file_inputs(self, recording_index: int) -> str | None:
851
- """Check that a recording's inputs appear valid
852
-
853
- This runs some basic tests for whether it will be possible to
854
- load and score a recording. If any test fails, we return an
855
- error message.
856
-
857
- :param recording_index: index of the recording in the list of
858
- all recordings.
859
- :return: error message
860
- """
861
- sampling_rate = self.recordings[recording_index].sampling_rate
862
- if self.epoch_length == 0:
863
- return "epoch length can't be 0"
864
- if sampling_rate == 0:
865
- return "sampling rate can't be 0"
866
- if self.epoch_length > sampling_rate:
867
- return "invalid epoch length or sampling rate"
868
- if self.recordings[self.recording_index].recording_file == "":
869
- return "no recording selected"
870
- if not os.path.isfile(self.recordings[self.recording_index].recording_file):
871
- return "recording file does not exist"
872
- if self.recordings[self.recording_index].label_file == "":
873
- return "no label file selected"
874
-
875
- def update_min_bout_length(self, new_value) -> None:
876
- """Update the minimum bout length
877
-
878
- :param new_value: new minimum bout length, in seconds
879
- """
880
- self.min_bout_length = new_value
881
-
882
- def update_overwrite_policy(self, checked) -> None:
883
- """Toggle overwriting policy
884
-
885
- If the checkbox is enabled, only epochs where the brain state is set to
886
- undefined will be overwritten by the automatic scoring process.
887
-
888
- :param checked: state of the checkbox
889
- """
890
- self.only_overwrite_undefined = checked
891
-
892
- def update_confidence_policy(self, checked) -> None:
893
- """Toggle policy for saving confidence scores
894
-
895
- If the checkbox is enabled, confidence scores will be saved to the label files.
896
-
897
- :param checked: state of the checkbox
898
- """
899
- self.save_confidence_scores = checked
478
+ # Display results
479
+ result.report_to(self.show_message)
480
+ if not result.success:
481
+ self.ui.calibration_status.setText("error")
482
+ else:
483
+ self.ui.calibration_status.setText("")
484
+ self.recording_manager.current.calibration_file = filename
485
+ self.ui.calibration_file_label.setText(filename)
900
486
 
901
487
  def manual_scoring(self) -> None:
902
488
  """View the selected recording for manual scoring"""
@@ -914,11 +500,12 @@ class AccuSleepWindow(QMainWindow):
914
500
 
915
501
  # if the labels exist, load them
916
502
  # otherwise, create a blank set of labels
917
- label_file = self.recordings[self.recording_index].label_file
503
+ label_file = self.recording_manager.current.label_file
918
504
  if os.path.isfile(label_file):
919
505
  try:
920
506
  labels, confidence_scores = load_labels(label_file)
921
507
  except Exception:
508
+ logger.exception("Failed to load %s", label_file)
922
509
  self.ui.manual_scoring_status.setText("could not load labels")
923
510
  self.show_message(
924
511
  (
@@ -936,56 +523,23 @@ class AccuSleepWindow(QMainWindow):
936
523
  # to a label file that does not have one
937
524
  confidence_scores = None
938
525
 
939
- # check that all labels are valid
940
- label_error = check_label_validity(
526
+ # check that labels are valid and correct minor length mismatches
527
+ labels, confidence_scores, validation_message = validate_and_correct_labels(
941
528
  labels=labels,
942
529
  confidence_scores=confidence_scores,
943
530
  samples_in_recording=eeg.size,
944
531
  sampling_rate=sampling_rate,
945
532
  epoch_length=self.epoch_length,
946
- brain_state_set=self.brain_state_set,
533
+ brain_state_set=self.config.brain_state_set,
947
534
  )
948
- if label_error:
949
- # if the label length is only off by one, pad or truncate as needed
950
- # and show a warning
951
- if label_error == LABEL_LENGTH_ERROR:
952
- # should be very close to an integer
953
- samples_per_epoch = round(sampling_rate * self.epoch_length)
954
- epochs_in_recording = round(eeg.size / samples_per_epoch)
955
- if epochs_in_recording - labels.size == 1:
956
- labels = np.concatenate((labels, np.array([UNDEFINED_LABEL])))
957
- if confidence_scores is not None:
958
- confidence_scores = np.concatenate(
959
- (confidence_scores, np.array([0]))
960
- )
961
- self.show_message(
962
- (
963
- "WARNING: an undefined epoch was added to "
964
- "the label file to correct its length."
965
- )
966
- )
967
- elif labels.size - epochs_in_recording == 1:
968
- labels = labels[:-1]
969
- if confidence_scores is not None:
970
- confidence_scores = confidence_scores[:-1]
971
- self.show_message(
972
- (
973
- "WARNING: the last epoch was removed from "
974
- "the label file to correct its length."
975
- )
976
- )
977
- else:
978
- self.ui.manual_scoring_status.setText("invalid label file")
979
- self.show_message(f"ERROR: {label_error}")
980
- return
981
- else:
982
- self.ui.manual_scoring_status.setText("invalid label file")
983
- self.show_message(f"ERROR: {label_error}")
984
- return
535
+ if labels is None:
536
+ self.ui.manual_scoring_status.setText("invalid label file")
537
+ self.show_message(f"ERROR: {validation_message}")
538
+ return
539
+ if validation_message:
540
+ self.show_message(f"WARNING: {validation_message}")
985
541
 
986
- self.show_message(
987
- f"Viewing recording {self.recordings[self.recording_index].name}"
988
- )
542
+ self.show_message(f"Viewing recording {self.recording_manager.current.name}")
989
543
  self.ui.manual_scoring_status.setText("file is open")
990
544
 
991
545
  # launch the manual scoring window
@@ -997,7 +551,7 @@ class AccuSleepWindow(QMainWindow):
997
551
  confidence_scores=confidence_scores,
998
552
  sampling_rate=sampling_rate,
999
553
  epoch_length=self.epoch_length,
1000
- emg_filter=self.emg_filter,
554
+ emg_filter=self.config.emg_filter,
1001
555
  )
1002
556
  manual_scoring_window.setWindowTitle(f"AccuSleePy viewer: {label_file}")
1003
557
  manual_scoring_window.exec()
@@ -1005,89 +559,48 @@ class AccuSleepWindow(QMainWindow):
1005
559
 
1006
560
  def create_label_file(self) -> None:
1007
561
  """Set the filename for a new label file"""
1008
- filename, _ = QFileDialog.getSaveFileName(
562
+ filename = select_save_location(
1009
563
  self,
1010
- caption="Set filename for label file (nothing will be overwritten yet)",
1011
- filter="*" + LABEL_FILE_TYPE,
564
+ "Set filename for label file (nothing will be overwritten yet)",
565
+ "*" + LABEL_FILE_TYPE,
1012
566
  )
1013
567
  if filename:
1014
- filename = os.path.normpath(filename)
1015
- self.recordings[self.recording_index].label_file = filename
568
+ self.recording_manager.current.label_file = filename
1016
569
  self.ui.label_file_label.setText(filename)
1017
570
 
1018
571
  def select_label_file(self) -> None:
1019
572
  """User can select an existing label file"""
1020
- file_dialog = QFileDialog(self)
1021
- file_dialog.setWindowTitle("Select label file")
1022
- file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
1023
- file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
1024
- file_dialog.setNameFilter("*" + LABEL_FILE_TYPE)
1025
-
1026
- if file_dialog.exec():
1027
- selected_files = file_dialog.selectedFiles()
1028
- filename = selected_files[0]
1029
- filename = os.path.normpath(filename)
1030
- self.recordings[self.recording_index].label_file = filename
573
+ filename = select_existing_file(
574
+ self, "Select label file", "*" + LABEL_FILE_TYPE
575
+ )
576
+ if filename:
577
+ self.recording_manager.current.label_file = filename
1031
578
  self.ui.label_file_label.setText(filename)
1032
579
 
1033
580
  def select_calibration_file(self) -> None:
1034
581
  """User can select a calibration file"""
1035
- file_dialog = QFileDialog(self)
1036
- file_dialog.setWindowTitle("Select calibration file")
1037
- file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
1038
- file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
1039
- file_dialog.setNameFilter("*" + CALIBRATION_FILE_TYPE)
1040
-
1041
- if file_dialog.exec():
1042
- selected_files = file_dialog.selectedFiles()
1043
- filename = selected_files[0]
1044
- filename = os.path.normpath(filename)
1045
- self.recordings[self.recording_index].calibration_file = filename
582
+ filename = select_existing_file(
583
+ self, "Select calibration file", "*" + CALIBRATION_FILE_TYPE
584
+ )
585
+ if filename:
586
+ self.recording_manager.current.calibration_file = filename
1046
587
  self.ui.calibration_file_label.setText(filename)
1047
588
 
1048
589
  def select_recording_file(self) -> None:
1049
590
  """User can select a recording file"""
1050
- file_dialog = QFileDialog(self)
1051
- file_dialog.setWindowTitle("Select recording file")
1052
- file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
1053
- file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
1054
- file_dialog.setNameFilter(f"(*{' *'.join(RECORDING_FILE_TYPES)})")
1055
-
1056
- if file_dialog.exec():
1057
- selected_files = file_dialog.selectedFiles()
1058
- filename = selected_files[0]
1059
- filename = os.path.normpath(filename)
1060
- self.recordings[self.recording_index].recording_file = filename
591
+ file_filter = f"(*{' *'.join(RECORDING_FILE_TYPES)})"
592
+ filename = select_existing_file(self, "Select recording file", file_filter)
593
+ if filename:
594
+ self.recording_manager.current.recording_file = filename
1061
595
  self.ui.recording_file_label.setText(filename)
1062
596
 
1063
597
  def show_recording_info(self) -> None:
1064
598
  """Update the UI to show info for the selected recording"""
1065
- self.ui.sampling_rate_input.setValue(
1066
- self.recordings[self.recording_index].sampling_rate
1067
- )
1068
- self.ui.recording_file_label.setText(
1069
- self.recordings[self.recording_index].recording_file
1070
- )
1071
- self.ui.label_file_label.setText(
1072
- self.recordings[self.recording_index].label_file
1073
- )
1074
- self.ui.calibration_file_label.setText(
1075
- self.recordings[self.recording_index].calibration_file
1076
- )
1077
-
1078
- def update_epoch_length(self, new_value: int | float) -> None:
1079
- """Update the epoch length when the widget state changes
1080
-
1081
- :param new_value: new epoch length
1082
- """
1083
- self.epoch_length = new_value
1084
-
1085
- def update_sampling_rate(self, new_value: int | float) -> None:
1086
- """Update recording's sampling rate when the widget state changes
1087
-
1088
- :param new_value: new sampling rate
1089
- """
1090
- self.recordings[self.recording_index].sampling_rate = new_value
599
+ recording = self.recording_manager.current
600
+ self.ui.sampling_rate_input.setValue(recording.sampling_rate)
601
+ self.ui.recording_file_label.setText(recording.recording_file)
602
+ self.ui.label_file_label.setText(recording.label_file)
603
+ self.ui.calibration_file_label.setText(recording.calibration_file)
1091
604
 
1092
605
  def show_message(self, message: str) -> None:
1093
606
  """Display a new message to the user
@@ -1102,50 +615,22 @@ class AccuSleepWindow(QMainWindow):
1102
615
  scrollbar = self.ui.message_area.verticalScrollBar()
1103
616
  scrollbar.setValue(scrollbar.maximum())
1104
617
 
1105
- def select_recording(self, list_index: int) -> None:
1106
- """Callback for when a recording is selected
1107
-
1108
- :param list_index: index of this recording in the list widget
1109
- """
1110
- # get index of this recording
1111
- self.recording_index = list_index
1112
- # display information about this recording
618
+ def select_recording(self, _index: int) -> None:
619
+ """Callback for when a recording is selected"""
1113
620
  self.show_recording_info()
1114
621
  self.ui.selected_recording_groupbox.setTitle(
1115
- f"Data / actions for Recording {self.recordings[list_index].name}"
622
+ f"Data / actions for Recording {self.recording_manager.current.name}"
1116
623
  )
1117
624
 
1118
625
  def add_recording(self) -> None:
1119
626
  """Add new recording to the list"""
1120
- # find name to use for the new recording
1121
- new_name = max([r.name for r in self.recordings]) + 1
1122
-
1123
- # add new recording to list
1124
- self.recordings.append(
1125
- Recording(
1126
- name=new_name,
1127
- sampling_rate=self.recordings[self.recording_index].sampling_rate,
1128
- widget=QListWidgetItem(
1129
- f"Recording {new_name}", self.ui.recording_list_widget
1130
- ),
1131
- )
1132
- )
1133
-
1134
- # display new list
1135
- self.ui.recording_list_widget.addItem(self.recordings[-1].widget)
1136
- self.ui.recording_list_widget.setCurrentRow(len(self.recordings) - 1)
1137
- self.show_message(f"added Recording {new_name}")
627
+ current_sampling_rate = self.recording_manager.current.sampling_rate
628
+ recording = self.recording_manager.add(sampling_rate=current_sampling_rate)
629
+ self.show_message(f"added Recording {recording.name}")
1138
630
 
1139
631
  def remove_recording(self) -> None:
1140
632
  """Delete selected recording from the list"""
1141
- if len(self.recordings) > 1:
1142
- current_list_index = self.ui.recording_list_widget.currentRow()
1143
- _ = self.ui.recording_list_widget.takeItem(current_list_index)
1144
- self.show_message(
1145
- f"deleted Recording {self.recordings[current_list_index].name}"
1146
- )
1147
- del self.recordings[current_list_index]
1148
- self.recording_index = self.ui.recording_list_widget.currentRow()
633
+ self.show_message(self.recording_manager.remove_current())
1149
634
 
1150
635
  def show_user_manual(self) -> None:
1151
636
  """Show a popup window with the user manual"""
@@ -1161,310 +646,12 @@ class AccuSleepWindow(QMainWindow):
1161
646
  self.popup.setGeometry(QRect(100, 100, 600, 600))
1162
647
  self.popup.show()
1163
648
 
1164
- def initialize_settings_tab(self):
1165
- """Populate settings tab and assign its callbacks"""
1166
- # store dictionary that maps digits to rows of widgets
1167
- # in the settings tab
1168
- self.settings_widgets = {
1169
- 1: StateSettings(
1170
- digit=1,
1171
- enabled_widget=self.ui.enable_state_1,
1172
- name_widget=self.ui.state_name_1,
1173
- is_scored_widget=self.ui.state_scored_1,
1174
- frequency_widget=self.ui.state_frequency_1,
1175
- ),
1176
- 2: StateSettings(
1177
- digit=2,
1178
- enabled_widget=self.ui.enable_state_2,
1179
- name_widget=self.ui.state_name_2,
1180
- is_scored_widget=self.ui.state_scored_2,
1181
- frequency_widget=self.ui.state_frequency_2,
1182
- ),
1183
- 3: StateSettings(
1184
- digit=3,
1185
- enabled_widget=self.ui.enable_state_3,
1186
- name_widget=self.ui.state_name_3,
1187
- is_scored_widget=self.ui.state_scored_3,
1188
- frequency_widget=self.ui.state_frequency_3,
1189
- ),
1190
- 4: StateSettings(
1191
- digit=4,
1192
- enabled_widget=self.ui.enable_state_4,
1193
- name_widget=self.ui.state_name_4,
1194
- is_scored_widget=self.ui.state_scored_4,
1195
- frequency_widget=self.ui.state_frequency_4,
1196
- ),
1197
- 5: StateSettings(
1198
- digit=5,
1199
- enabled_widget=self.ui.enable_state_5,
1200
- name_widget=self.ui.state_name_5,
1201
- is_scored_widget=self.ui.state_scored_5,
1202
- frequency_widget=self.ui.state_frequency_5,
1203
- ),
1204
- 6: StateSettings(
1205
- digit=6,
1206
- enabled_widget=self.ui.enable_state_6,
1207
- name_widget=self.ui.state_name_6,
1208
- is_scored_widget=self.ui.state_scored_6,
1209
- frequency_widget=self.ui.state_frequency_6,
1210
- ),
1211
- 7: StateSettings(
1212
- digit=7,
1213
- enabled_widget=self.ui.enable_state_7,
1214
- name_widget=self.ui.state_name_7,
1215
- is_scored_widget=self.ui.state_scored_7,
1216
- frequency_widget=self.ui.state_frequency_7,
1217
- ),
1218
- 8: StateSettings(
1219
- digit=8,
1220
- enabled_widget=self.ui.enable_state_8,
1221
- name_widget=self.ui.state_name_8,
1222
- is_scored_widget=self.ui.state_scored_8,
1223
- frequency_widget=self.ui.state_frequency_8,
1224
- ),
1225
- 9: StateSettings(
1226
- digit=9,
1227
- enabled_widget=self.ui.enable_state_9,
1228
- name_widget=self.ui.state_name_9,
1229
- is_scored_widget=self.ui.state_scored_9,
1230
- frequency_widget=self.ui.state_frequency_9,
1231
- ),
1232
- 0: StateSettings(
1233
- digit=0,
1234
- enabled_widget=self.ui.enable_state_0,
1235
- name_widget=self.ui.state_name_0,
1236
- is_scored_widget=self.ui.state_scored_0,
1237
- frequency_widget=self.ui.state_frequency_0,
1238
- ),
1239
- }
1240
-
1241
- # update widget state to display current config
1242
- # UI defaults
1243
- self.ui.default_epoch_input.setValue(self.epoch_length)
1244
- self.ui.overwrite_default_checkbox.setChecked(self.only_overwrite_undefined)
1245
- self.ui.confidence_setting_checkbox.setChecked(self.save_confidence_scores)
1246
- self.ui.default_min_bout_length_spinbox.setValue(self.min_bout_length)
1247
- self.ui.epochs_to_show_spinbox.setValue(self.default_epochs_to_show)
1248
- self.ui.autoscroll_checkbox.setChecked(self.default_autoscroll_state)
1249
- # EMG filter
1250
- self.ui.emg_order_spinbox.setValue(self.emg_filter.order)
1251
- self.ui.bp_lower_spinbox.setValue(self.emg_filter.bp_lower)
1252
- self.ui.bp_upper_spinbox.setValue(self.emg_filter.bp_upper)
1253
- # model training hyperparameters
1254
- self.ui.batch_size_spinbox.setValue(self.hyperparameters.batch_size)
1255
- self.ui.learning_rate_spinbox.setValue(self.hyperparameters.learning_rate)
1256
- self.ui.momentum_spinbox.setValue(self.hyperparameters.momentum)
1257
- self.ui.training_epochs_spinbox.setValue(self.hyperparameters.training_epochs)
1258
- # brain states
1259
- states = {b.digit: b for b in self.brain_state_set.brain_states}
1260
- for digit in range(10):
1261
- if digit in states.keys():
1262
- self.settings_widgets[digit].enabled_widget.setChecked(True)
1263
- self.settings_widgets[digit].name_widget.setText(states[digit].name)
1264
- self.settings_widgets[digit].is_scored_widget.setChecked(
1265
- states[digit].is_scored
1266
- )
1267
- self.settings_widgets[digit].frequency_widget.setValue(
1268
- states[digit].frequency
1269
- )
1270
- else:
1271
- self.settings_widgets[digit].enabled_widget.setChecked(False)
1272
- self.settings_widgets[digit].name_widget.setEnabled(False)
1273
- self.settings_widgets[digit].is_scored_widget.setEnabled(False)
1274
- self.settings_widgets[digit].frequency_widget.setEnabled(False)
1275
-
1276
- # set callbacks
1277
- self.ui.emg_order_spinbox.valueChanged.connect(self.emg_filter_order_changed)
1278
- self.ui.bp_lower_spinbox.valueChanged.connect(self.emg_filter_bp_lower_changed)
1279
- self.ui.bp_upper_spinbox.valueChanged.connect(self.emg_filter_bp_upper_changed)
1280
- self.ui.batch_size_spinbox.valueChanged.connect(self.hyperparameters_changed)
1281
- self.ui.learning_rate_spinbox.valueChanged.connect(self.hyperparameters_changed)
1282
- self.ui.momentum_spinbox.valueChanged.connect(self.hyperparameters_changed)
1283
- self.ui.training_epochs_spinbox.valueChanged.connect(
1284
- self.hyperparameters_changed
1285
- )
1286
- for digit in range(10):
1287
- state = self.settings_widgets[digit]
1288
- state.enabled_widget.stateChanged.connect(
1289
- partial(self.set_brain_state_enabled, digit)
1290
- )
1291
- state.name_widget.editingFinished.connect(self.check_config_validity)
1292
- state.is_scored_widget.stateChanged.connect(
1293
- partial(self.is_scored_changed, digit)
1294
- )
1295
- state.frequency_widget.valueChanged.connect(self.check_config_validity)
1296
-
1297
- def set_brain_state_enabled(self, digit, e) -> None:
1298
- """Called when user clicks "enabled" checkbox
1299
-
1300
- :param digit: brain state digit
1301
- :param e: unused but mandatory
1302
- """
1303
- # get the widgets for this brain state
1304
- state = self.settings_widgets[digit]
1305
- # update state of these widgets
1306
- is_checked = state.enabled_widget.isChecked()
1307
- for widget in [
1308
- state.name_widget,
1309
- state.is_scored_widget,
1310
- ]:
1311
- widget.setEnabled(is_checked)
1312
- state.frequency_widget.setEnabled(
1313
- is_checked and state.is_scored_widget.isChecked()
1314
- )
1315
- if not is_checked:
1316
- state.name_widget.setText("")
1317
- state.frequency_widget.setValue(0)
1318
- # check that configuration is valid
1319
- _ = self.check_config_validity()
1320
-
1321
- def is_scored_changed(self, digit, e) -> None:
1322
- """Called when user sets whether a state is scored
1323
-
1324
- :param digit: brain state digit
1325
- :param e: unused, but mandatory
1326
- """
1327
- # get the widgets for this brain state
1328
- state = self.settings_widgets[digit]
1329
- # update the state of these widgets
1330
- is_checked = state.is_scored_widget.isChecked()
1331
- state.frequency_widget.setEnabled(is_checked)
1332
- if not is_checked:
1333
- state.frequency_widget.setValue(0)
1334
- # check that configuration is valid
1335
- _ = self.check_config_validity()
1336
-
1337
- def emg_filter_order_changed(self, new_value: int) -> None:
1338
- """Called when user modifies EMG filter order
1339
-
1340
- :param new_value: new EMG filter order
1341
- """
1342
- self.emg_filter.order = new_value
1343
-
1344
- def emg_filter_bp_lower_changed(self, new_value: int | float) -> None:
1345
- """Called when user modifies EMG filter lower cutoff
1346
-
1347
- :param new_value: new lower bandpass cutoff frequency
1348
- """
1349
- self.emg_filter.bp_lower = new_value
1350
- _ = self.check_config_validity()
1351
-
1352
- def emg_filter_bp_upper_changed(self, new_value: int | float) -> None:
1353
- """Called when user modifies EMG filter upper cutoff
1354
-
1355
- :param new_value: new upper bandpass cutoff frequency
1356
- """
1357
- self.emg_filter.bp_upper = new_value
1358
- _ = self.check_config_validity()
1359
-
1360
- def hyperparameters_changed(self, new_value) -> None:
1361
- """Called when user modifies model training hyperparameters
1362
-
1363
- :param new_value: unused
1364
- """
1365
- self.hyperparameters = Hyperparameters(
1366
- batch_size=self.ui.batch_size_spinbox.value(),
1367
- learning_rate=self.ui.learning_rate_spinbox.value(),
1368
- momentum=self.ui.momentum_spinbox.value(),
1369
- training_epochs=self.ui.training_epochs_spinbox.value(),
1370
- )
1371
-
1372
- def check_config_validity(self) -> str:
1373
- """Check if brain state configuration on screen is valid"""
1374
- # error message, if we get one
1375
- message = None
1376
-
1377
- # strip whitespace from brain state names and update display
1378
- for digit in range(10):
1379
- state = self.settings_widgets[digit]
1380
- current_name = state.name_widget.text()
1381
- formatted_name = current_name.strip()
1382
- if current_name != formatted_name:
1383
- state.name_widget.setText(formatted_name)
1384
-
1385
- # check if names are unique and frequencies add up to 1
1386
- names = []
1387
- frequencies = []
1388
- for digit in range(10):
1389
- state = self.settings_widgets[digit]
1390
- if state.enabled_widget.isChecked():
1391
- names.append(state.name_widget.text())
1392
- frequencies.append(state.frequency_widget.value())
1393
- if len(names) != len(set(names)):
1394
- message = "Error: names must be unique"
1395
- if sum(frequencies) != 1:
1396
- message = "Error: sum(frequencies) != 1"
1397
-
1398
- # check validity of EMG filter settings
1399
- if self.emg_filter.bp_lower >= self.emg_filter.bp_upper:
1400
- message = "Error: EMG filter cutoff frequencies are invalid"
1401
-
1402
- if message is not None:
1403
- self.ui.save_config_status.setText(message)
1404
- self.ui.save_config_button.setEnabled(False)
1405
- return message
1406
-
1407
- self.ui.save_config_button.setEnabled(True)
1408
- self.ui.save_config_status.setText("")
1409
-
1410
- def save_brain_state_config(self):
1411
- """Save configuration to file"""
1412
- # check that configuration is valid
1413
- error_message = self.check_config_validity()
1414
- if error_message is not None:
1415
- return
1416
-
1417
- # build a BrainStateMapper object from the current configuration
1418
- brain_states = list()
1419
- for digit in range(10):
1420
- state = self.settings_widgets[digit]
1421
- if state.enabled_widget.isChecked():
1422
- brain_states.append(
1423
- BrainState(
1424
- name=state.name_widget.text(),
1425
- digit=digit,
1426
- is_scored=state.is_scored_widget.isChecked(),
1427
- frequency=state.frequency_widget.value(),
1428
- )
1429
- )
1430
- self.brain_state_set = BrainStateSet(brain_states, UNDEFINED_LABEL)
1431
-
1432
- # save to file
1433
- save_config(
1434
- brain_state_set=self.brain_state_set,
1435
- default_epoch_length=self.ui.default_epoch_input.value(),
1436
- overwrite_setting=self.ui.overwrite_default_checkbox.isChecked(),
1437
- save_confidence_setting=self.ui.confidence_setting_checkbox.isChecked(),
1438
- min_bout_length=self.ui.default_min_bout_length_spinbox.value(),
1439
- emg_filter=EMGFilter(
1440
- order=self.emg_filter.order,
1441
- bp_lower=self.emg_filter.bp_lower,
1442
- bp_upper=self.emg_filter.bp_upper,
1443
- ),
1444
- hyperparameters=Hyperparameters(
1445
- batch_size=self.hyperparameters.batch_size,
1446
- learning_rate=self.hyperparameters.learning_rate,
1447
- momentum=self.hyperparameters.momentum,
1448
- training_epochs=self.hyperparameters.training_epochs,
1449
- ),
1450
- epochs_to_show=self.ui.epochs_to_show_spinbox.value(),
1451
- autoscroll_state=self.ui.autoscroll_checkbox.isChecked(),
1452
- )
1453
- self.ui.save_config_status.setText("configuration saved")
1454
-
1455
- def reset_emg_filter_settings(self) -> None:
1456
- self.ui.emg_order_spinbox.setValue(DEFAULT_EMG_FILTER_ORDER)
1457
- self.ui.bp_lower_spinbox.setValue(DEFAULT_EMG_BP_LOWER)
1458
- self.ui.bp_upper_spinbox.setValue(DEFAULT_EMG_BP_UPPER)
1459
-
1460
- def reset_hyperparams_settings(self):
1461
- self.ui.batch_size_spinbox.setValue(DEFAULT_BATCH_SIZE)
1462
- self.ui.learning_rate_spinbox.setValue(DEFAULT_LEARNING_RATE)
1463
- self.ui.momentum_spinbox.setValue(DEFAULT_MOMENTUM)
1464
- self.ui.training_epochs_spinbox.setValue(DEFAULT_TRAINING_EPOCHS)
1465
-
1466
649
 
1467
650
  def run_primary_window() -> None:
651
+ logging.basicConfig(
652
+ level=logging.INFO,
653
+ format="%(levelname)s - %(name)s - %(message)s",
654
+ )
1468
655
  app = QApplication(sys.argv)
1469
656
  AccuSleepWindow()
1470
657
  sys.exit(app.exec())