accusleepy 0.8.0__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
accusleepy/gui/main.py CHANGED
@@ -1,15 +1,13 @@
1
1
  # AccuSleePy main window
2
2
  # Icon source: Arkinasi, https://www.flaticon.com/authors/arkinasi
3
3
 
4
- import datetime
4
+ import logging
5
5
  import os
6
- import shutil
7
6
  import sys
8
7
  from dataclasses import dataclass
9
8
  from functools import partial
10
9
 
11
10
  import numpy as np
12
- import toml
13
11
  from PySide6.QtCore import (
14
12
  QEvent,
15
13
  QKeyCombination,
@@ -21,31 +19,17 @@ from PySide6.QtCore import (
21
19
  from PySide6.QtGui import QKeySequence, QShortcut
22
20
  from PySide6.QtWidgets import (
23
21
  QApplication,
24
- QCheckBox,
25
- QDoubleSpinBox,
26
- QFileDialog,
27
22
  QLabel,
28
- QListWidgetItem,
29
23
  QMainWindow,
30
24
  QTextBrowser,
31
25
  QVBoxLayout,
32
26
  QWidget,
33
27
  )
34
28
 
35
- from accusleepy.bouts import enforce_min_bout_length
36
- from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
29
+ from accusleepy.brain_state_set import BRAIN_STATES_KEY
37
30
  from accusleepy.constants import (
38
- ANNOTATIONS_FILENAME,
39
- CALIBRATION_ANNOTATION_FILENAME,
40
31
  CALIBRATION_FILE_TYPE,
41
32
  DEFAULT_MODEL_TYPE,
42
- DEFAULT_EMG_FILTER_ORDER,
43
- DEFAULT_EMG_BP_LOWER,
44
- DEFAULT_EMG_BP_UPPER,
45
- DEFAULT_BATCH_SIZE,
46
- DEFAULT_LEARNING_RATE,
47
- DEFAULT_MOMENTUM,
48
- DEFAULT_TRAINING_EPOCHS,
49
33
  LABEL_FILE_TYPE,
50
34
  MESSAGE_BOX_MAX_DEPTH,
51
35
  MODEL_FILE_TYPE,
@@ -55,31 +39,28 @@ from accusleepy.constants import (
55
39
  UNDEFINED_LABEL,
56
40
  )
57
41
  from accusleepy.fileio import (
58
- Recording,
59
- load_calibration_file,
60
42
  load_config,
61
43
  load_labels,
62
44
  load_recording,
63
- load_recording_list,
64
- save_config,
65
- save_labels,
66
- save_recording_list,
67
- EMGFilter,
68
- Hyperparameters,
45
+ get_version,
69
46
  )
47
+ from accusleepy.gui.dialogs import select_existing_file, select_save_location
70
48
  from accusleepy.gui.manual_scoring import ManualScoringWindow
71
49
  from accusleepy.gui.primary_window import Ui_PrimaryWindow
72
- from accusleepy.signal_processing import (
73
- create_training_images,
74
- resample_and_standardize,
75
- )
76
- from accusleepy.validation import (
77
- check_label_validity,
78
- LABEL_LENGTH_ERROR,
79
- check_config_consistency,
50
+ from accusleepy.gui.recording_manager import RecordingListManager
51
+ from accusleepy.gui.settings_widget import SettingsWidget
52
+ from accusleepy.services import (
53
+ LoadedModel,
54
+ TrainingService,
55
+ check_single_file_inputs,
56
+ create_calibration,
57
+ score_recording_list,
80
58
  )
59
+ from accusleepy.validation import validate_and_correct_labels
60
+ from accusleepy.signal_processing import resample_and_standardize
61
+ from accusleepy.validation import check_config_consistency
81
62
 
82
- # note: functions using torch or scipy are lazily imported
63
+ logger = logging.getLogger(__name__)
83
64
 
84
65
  # on Windows, prevent dark mode from changing the visual style
85
66
  if os.name == "nt":
@@ -91,14 +72,22 @@ MAIN_GUIDE_FILE = os.path.normpath(r"text/main_guide.md")
91
72
 
92
73
 
93
74
  @dataclass
94
- class StateSettings:
95
- """Widgets for config settings for a brain state"""
75
+ class TrainingSettings:
76
+ """Settings for training a new model"""
77
+
78
+ epochs_per_img: int = 9
79
+ delete_images: bool = True
80
+ model_type: str = DEFAULT_MODEL_TYPE
81
+ calibrate: bool = True
82
+
83
+
84
+ @dataclass
85
+ class ScoringSettings:
86
+ """Settings for scoring a recording"""
96
87
 
97
- digit: int
98
- enabled_widget: QCheckBox
99
- name_widget: QLabel
100
- is_scored_widget: QCheckBox
101
- frequency_widget: QDoubleSpinBox
88
+ only_overwrite_undefined: bool
89
+ save_confidence_scores: bool
90
+ min_bout_length: int | float
102
91
 
103
92
 
104
93
  class AccuSleepWindow(QMainWindow):
@@ -112,66 +101,42 @@ class AccuSleepWindow(QMainWindow):
112
101
  self.ui.setupUi(self)
113
102
  self.setWindowTitle("AccuSleePy")
114
103
 
115
- # fill in settings tab
116
- (
117
- self.brain_state_set,
118
- self.epoch_length,
119
- self.only_overwrite_undefined,
120
- self.save_confidence_scores,
121
- self.min_bout_length,
122
- self.emg_filter,
123
- self.hyperparameters,
124
- self.default_epochs_to_show,
125
- self.default_autoscroll_state,
126
- ) = load_config()
127
-
128
- self.settings_widgets = None
129
- self.initialize_settings_tab()
104
+ # Load configuration
105
+ loaded_config = load_config()
106
+
107
+ # Apply default values from the configuration
108
+ self.epoch_length = loaded_config.default_epoch_length
109
+ self.scoring = ScoringSettings(
110
+ only_overwrite_undefined=loaded_config.overwrite_setting,
111
+ save_confidence_scores=loaded_config.save_confidence_setting,
112
+ min_bout_length=loaded_config.min_bout_length,
113
+ )
114
+
115
+ # Initialize settings tab (manages Settings tab UI and saved config values)
116
+ self.config = SettingsWidget(ui=self.ui, config=loaded_config, parent=self)
130
117
 
131
118
  # initialize info about the recordings, classification data / settings
132
119
  self.ui.epoch_length_input.setValue(self.epoch_length)
133
- self.ui.overwritecheckbox.setChecked(self.only_overwrite_undefined)
134
- self.ui.save_confidence_checkbox.setChecked(self.save_confidence_scores)
135
- self.ui.bout_length_input.setValue(self.min_bout_length)
136
- self.model = None
120
+ self.ui.overwritecheckbox.setChecked(self.scoring.only_overwrite_undefined)
121
+ self.ui.save_confidence_checkbox.setChecked(self.scoring.save_confidence_scores)
122
+ self.ui.bout_length_input.setValue(self.scoring.min_bout_length)
137
123
 
138
- # initialize model training variables
139
- self.training_epochs_per_img = 9
140
- self.delete_training_images = True
141
- self.model_type = DEFAULT_MODEL_TYPE
142
- self.calibrate_trained_model = True
124
+ # loaded classification model and its metadata
125
+ self.loaded_model = LoadedModel()
143
126
 
144
- # metadata for the currently loaded classification model
145
- self.model_epoch_length = None
146
- self.model_epochs_per_img = None
127
+ # settings for training new models
128
+ self.training = TrainingSettings()
147
129
 
148
130
  # set up the list of recordings
149
- first_recording = Recording(
150
- widget=QListWidgetItem("Recording 1", self.ui.recording_list_widget),
131
+ self.recording_manager = RecordingListManager(
132
+ self.ui.recording_list_widget, parent=self
151
133
  )
152
- self.ui.recording_list_widget.addItem(first_recording.widget)
153
- self.ui.recording_list_widget.setCurrentRow(0)
154
- # index of currently selected recording in the list
155
- self.recording_index = 0
156
- # list of recordings the user has added
157
- self.recordings = [first_recording]
158
134
 
159
135
  # messages to display
160
136
  self.messages = []
161
137
 
162
138
  # display current version
163
- version = ""
164
- toml_file = os.path.join(
165
- os.path.dirname(
166
- os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
167
- ),
168
- "pyproject.toml",
169
- )
170
- if os.path.isfile(toml_file):
171
- toml_data = toml.load(toml_file)
172
- if "project" in toml_data and "version" in toml_data["project"]:
173
- version = toml_data["project"]["version"]
174
- self.ui.version_label.setText(f"v{version}")
139
+ self.ui.version_label.setText(f"v{get_version()}")
175
140
 
176
141
  # user input: keyboard shortcuts
177
142
  keypress_quit = QShortcut(
@@ -184,8 +149,12 @@ class AccuSleepWindow(QMainWindow):
184
149
  self.ui.add_button.clicked.connect(self.add_recording)
185
150
  self.ui.remove_button.clicked.connect(self.remove_recording)
186
151
  self.ui.recording_list_widget.currentRowChanged.connect(self.select_recording)
187
- self.ui.sampling_rate_input.valueChanged.connect(self.update_sampling_rate)
188
- self.ui.epoch_length_input.valueChanged.connect(self.update_epoch_length)
152
+ self.ui.sampling_rate_input.valueChanged.connect(
153
+ lambda v: setattr(self.recording_manager.current, "sampling_rate", v)
154
+ )
155
+ self.ui.epoch_length_input.valueChanged.connect(
156
+ lambda v: setattr(self, "epoch_length", v)
157
+ )
189
158
  self.ui.recording_file_button.clicked.connect(self.select_recording_file)
190
159
  self.ui.select_label_button.clicked.connect(self.select_label_file)
191
160
  self.ui.create_label_button.clicked.connect(self.create_label_file)
@@ -193,27 +162,33 @@ class AccuSleepWindow(QMainWindow):
193
162
  self.ui.create_calibration_button.clicked.connect(self.create_calibration_file)
194
163
  self.ui.select_calibration_button.clicked.connect(self.select_calibration_file)
195
164
  self.ui.load_model_button.clicked.connect(partial(self.load_model, None))
196
- self.ui.score_all_button.clicked.connect(self.score_all)
197
- self.ui.overwritecheckbox.stateChanged.connect(self.update_overwrite_policy)
165
+ self.ui.score_all_button.clicked.connect(self.score_recordings)
166
+ self.ui.overwritecheckbox.stateChanged.connect(
167
+ lambda v: setattr(self.scoring, "only_overwrite_undefined", bool(v))
168
+ )
198
169
  self.ui.save_confidence_checkbox.stateChanged.connect(
199
- self.update_confidence_policy
170
+ lambda v: setattr(self.scoring, "save_confidence_scores", bool(v))
171
+ )
172
+ self.ui.bout_length_input.valueChanged.connect(
173
+ lambda v: setattr(self.scoring, "min_bout_length", v)
200
174
  )
201
- self.ui.bout_length_input.valueChanged.connect(self.update_min_bout_length)
202
175
  self.ui.user_manual_button.clicked.connect(self.show_user_manual)
203
- self.ui.image_number_input.valueChanged.connect(self.update_epochs_per_img)
204
- self.ui.delete_image_box.stateChanged.connect(self.update_image_deletion)
176
+ self.ui.image_number_input.valueChanged.connect(
177
+ lambda v: setattr(self.training, "epochs_per_img", v)
178
+ )
179
+ self.ui.delete_image_box.stateChanged.connect(
180
+ lambda v: setattr(self.training, "delete_images", bool(v))
181
+ )
205
182
  self.ui.calibrate_checkbox.stateChanged.connect(
206
183
  self.update_training_calibration
207
184
  )
208
185
  self.ui.train_model_button.clicked.connect(self.train_model)
209
- self.ui.save_config_button.clicked.connect(self.save_brain_state_config)
186
+ self.ui.save_config_button.clicked.connect(self.config.save_config)
210
187
  self.ui.export_button.clicked.connect(self.export_recording_list)
211
188
  self.ui.import_button.clicked.connect(self.import_recording_list)
212
189
  self.ui.default_type_button.toggled.connect(self.model_type_radio_buttons)
213
- self.ui.reset_emg_params_button.clicked.connect(self.reset_emg_filter_settings)
214
- self.ui.reset_hyperparams_button.clicked.connect(
215
- self.reset_hyperparams_settings
216
- )
190
+ self.ui.reset_emg_params_button.clicked.connect(self.config.reset_emg_filter)
191
+ self.ui.reset_hyperparams_button.clicked.connect(self.config.reset_hyperparams)
217
192
 
218
193
  # user input: drag and drop
219
194
  self.ui.recording_file_label.installEventFilter(self)
@@ -228,52 +203,29 @@ class AccuSleepWindow(QMainWindow):
228
203
 
229
204
  :param default_selected: whether default option is selected
230
205
  """
231
- self.model_type = (
206
+ self.training.model_type = (
232
207
  DEFAULT_MODEL_TYPE if default_selected else REAL_TIME_MODEL_TYPE
233
208
  )
234
209
 
235
210
  def export_recording_list(self) -> None:
236
211
  """Save current list of recordings to file"""
237
- # get the name for the recording list file
238
- filename, _ = QFileDialog.getSaveFileName(
239
- self,
240
- caption="Save list of recordings as",
241
- filter="*" + RECORDING_LIST_FILE_TYPE,
212
+ filename = select_save_location(
213
+ self, "Save list of recordings as", "*" + RECORDING_LIST_FILE_TYPE
242
214
  )
243
215
  if not filename:
244
216
  return
245
- filename = os.path.normpath(filename)
246
- save_recording_list(filename=filename, recordings=self.recordings)
217
+ self.recording_manager.export_to_file(filename)
247
218
  self.show_message(f"Saved list of recordings to {filename}")
248
219
 
249
220
  def import_recording_list(self):
250
221
  """Load list of recordings from file, overwriting current list"""
251
- file_dialog = QFileDialog(self)
252
- file_dialog.setWindowTitle("Select list of recordings")
253
- file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
254
- file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
255
- file_dialog.setNameFilter("*" + RECORDING_LIST_FILE_TYPE)
256
-
257
- if file_dialog.exec():
258
- selected_files = file_dialog.selectedFiles()
259
- filename = selected_files[0]
260
- filename = os.path.normpath(filename)
261
- else:
222
+ filename = select_existing_file(
223
+ self, "Select list of recordings", "*" + RECORDING_LIST_FILE_TYPE
224
+ )
225
+ if not filename:
262
226
  return
263
227
 
264
- # clear widget
265
- self.ui.recording_list_widget.clear()
266
- # overwrite current list
267
- self.recordings = load_recording_list(filename)
268
-
269
- for recording in self.recordings:
270
- recording.widget = QListWidgetItem(
271
- f"Recording {recording.name}", self.ui.recording_list_widget
272
- )
273
- self.ui.recording_list_widget.addItem(self.recordings[-1].widget)
274
-
275
- # display new list
276
- self.ui.recording_list_widget.setCurrentRow(0)
228
+ self.recording_manager.import_from_file(filename)
277
229
  self.show_message(f"Loaded list of recordings from {filename}")
278
230
 
279
231
  def eventFilter(self, obj: QObject, event: QEvent) -> bool:
@@ -298,20 +250,21 @@ class AccuSleepWindow(QMainWindow):
298
250
 
299
251
  if filename is None:
300
252
  return super().eventFilter(obj, event)
253
+ filename = str(filename)
301
254
 
302
255
  _, file_extension = os.path.splitext(filename)
303
256
 
304
257
  if obj == self.ui.recording_file_label:
305
258
  if file_extension in RECORDING_FILE_TYPES:
306
- self.recordings[self.recording_index].recording_file = filename
259
+ self.recording_manager.current.recording_file = filename
307
260
  self.ui.recording_file_label.setText(filename)
308
261
  elif obj == self.ui.label_file_label:
309
262
  if file_extension == LABEL_FILE_TYPE:
310
- self.recordings[self.recording_index].label_file = filename
263
+ self.recording_manager.current.label_file = filename
311
264
  self.ui.label_file_label.setText(filename)
312
265
  elif obj == self.ui.calibration_file_label:
313
266
  if file_extension == CALIBRATION_FILE_TYPE:
314
- self.recordings[self.recording_index].calibration_file = filename
267
+ self.recording_manager.current.calibration_file = filename
315
268
  self.ui.calibration_file_label.setText(filename)
316
269
  elif obj == self.ui.model_label:
317
270
  self.load_model(filename=filename)
@@ -319,335 +272,69 @@ class AccuSleepWindow(QMainWindow):
319
272
  return super().eventFilter(obj, event)
320
273
 
321
274
  def train_model(self) -> None:
322
- # check basic training inputs
323
- if (
324
- self.model_type == DEFAULT_MODEL_TYPE
325
- and self.training_epochs_per_img % 2 == 0
326
- ):
327
- self.show_message(
328
- (
329
- "ERROR: for the default model type, number of epochs "
330
- "per image must be an odd number."
331
- )
332
- )
333
- return
334
-
335
- # determine fraction of training data to use for calibration
336
- if self.calibrate_trained_model:
337
- calibration_fraction = self.ui.calibration_spinbox.value() / 100
338
- else:
339
- calibration_fraction = 0
340
-
341
- # check some inputs for each recording
342
- for recording_index in range(len(self.recordings)):
343
- error_message = self.check_single_file_inputs(recording_index)
344
- if error_message:
345
- self.show_message(
346
- f"ERROR (recording {self.recordings[recording_index].name}): {error_message}"
347
- )
348
- return
349
-
350
- # get filename for the new model
351
- model_filename, _ = QFileDialog.getSaveFileName(
352
- self,
353
- caption="Save classification model file as",
354
- filter="*" + MODEL_FILE_TYPE,
275
+ """Train a classification model using the current recordings."""
276
+ model_filename = select_save_location(
277
+ self, "Save classification model file as", "*" + MODEL_FILE_TYPE
355
278
  )
356
279
  if not model_filename:
357
280
  self.show_message("Model training canceled, no filename given")
358
281
  return
359
- model_filename = os.path.normpath(model_filename)
360
282
 
361
- # create (probably temporary) image folder in
362
- # the same folder as the trained model
363
- temp_image_dir = os.path.join(
364
- os.path.dirname(model_filename),
365
- "images_" + datetime.datetime.now().strftime("%Y%m%d%H%M"),
366
- )
367
-
368
- if os.path.exists(temp_image_dir): # unlikely
369
- self.show_message(
370
- "Warning: training image folder exists, will be overwritten"
371
- )
372
- os.makedirs(temp_image_dir, exist_ok=True)
283
+ # Determine calibration fraction
284
+ if self.training.calibrate:
285
+ calibration_fraction = self.ui.calibration_spinbox.value() / 100
286
+ else:
287
+ calibration_fraction = 0
373
288
 
374
- # create training images
289
+ # Show progress message
375
290
  self.show_message("Training, please wait. See console for progress updates.")
376
- if not self.delete_training_images:
377
- self.show_message((f"Creating training images in {temp_image_dir}"))
378
- else:
379
- self.show_message(
380
- (f"Creating temporary folder of training images: {temp_image_dir}")
381
- )
382
291
  self.ui.message_area.repaint()
383
292
  QApplication.processEvents()
384
- print("Creating training images")
385
- failed_recordings = create_training_images(
386
- recordings=self.recordings,
387
- output_path=temp_image_dir,
388
- epoch_length=self.epoch_length,
389
- epochs_per_img=self.training_epochs_per_img,
390
- brain_state_set=self.brain_state_set,
391
- model_type=self.model_type,
392
- calibration_fraction=calibration_fraction,
393
- emg_filter=self.emg_filter,
394
- )
395
- if len(failed_recordings) > 0:
396
- if len(failed_recordings) == len(self.recordings):
397
- self.show_message("ERROR: no recordings were valid!")
398
- return
399
- else:
400
- self.show_message(
401
- (
402
- "WARNING: the following recordings could not be "
403
- "loaded and will not be used for training: "
404
- f"{', '.join([str(r) for r in failed_recordings])}"
405
- )
406
- )
407
293
 
408
- # train model
409
- self.show_message("Training model")
410
- self.ui.message_area.repaint()
411
- QApplication.processEvents()
412
- print("Training model")
413
- from accusleepy.classification import create_dataloader, train_ssann
414
- from accusleepy.models import save_model
415
- from accusleepy.temperature_scaling import ModelWithTemperature
416
-
417
- model = train_ssann(
418
- annotations_file=os.path.join(temp_image_dir, ANNOTATIONS_FILENAME),
419
- img_dir=temp_image_dir,
420
- mixture_weights=self.brain_state_set.mixture_weights,
421
- n_classes=self.brain_state_set.n_classes,
422
- hyperparameters=self.hyperparameters,
423
- )
424
-
425
- # calibrate the model
426
- if self.calibrate_trained_model:
427
- calibration_annotation_file = os.path.join(
428
- temp_image_dir, CALIBRATION_ANNOTATION_FILENAME
429
- )
430
- calibration_dataloader = create_dataloader(
431
- annotations_file=calibration_annotation_file,
432
- img_dir=temp_image_dir,
433
- hyperparameters=self.hyperparameters,
434
- )
435
- model = ModelWithTemperature(model)
436
- print("Calibrating model")
437
- model.set_temperature(calibration_dataloader)
438
-
439
- # save model
440
- save_model(
441
- model=model,
442
- filename=model_filename,
294
+ # Create service and run training
295
+ service = TrainingService(progress_callback=self.show_message)
296
+ result = service.train_model(
297
+ recordings=list(self.recording_manager),
443
298
  epoch_length=self.epoch_length,
444
- epochs_per_img=self.training_epochs_per_img,
445
- model_type=self.model_type,
446
- brain_state_set=self.brain_state_set,
447
- is_calibrated=self.calibrate_trained_model,
299
+ epochs_per_img=self.training.epochs_per_img,
300
+ model_type=self.training.model_type,
301
+ calibrate=self.training.calibrate,
302
+ calibration_fraction=calibration_fraction,
303
+ brain_state_set=self.config.brain_state_set,
304
+ emg_filter=self.config.emg_filter,
305
+ hyperparameters=self.config.hyperparameters,
306
+ model_filename=model_filename,
307
+ delete_images=self.training.delete_images,
448
308
  )
449
309
 
450
- # optionally delete images
451
- if self.delete_training_images:
452
- print("Cleaning up training image folder")
453
- shutil.rmtree(temp_image_dir)
454
-
455
- self.show_message(f"Training complete. Saved model to {model_filename}")
456
- print("Training complete.")
457
-
458
- def update_image_deletion(self) -> None:
459
- """Update choice of whether to delete images after training"""
460
- self.delete_training_images = self.ui.delete_image_box.isChecked()
310
+ # Display results
311
+ result.report_to(self.show_message)
461
312
 
462
313
  def update_training_calibration(self) -> None:
463
314
  """Update choice of whether to calibrate model after training"""
464
- self.calibrate_trained_model = self.ui.calibrate_checkbox.isChecked()
465
- self.ui.calibration_spinbox.setEnabled(self.calibrate_trained_model)
466
-
467
- def update_epochs_per_img(self, new_value) -> None:
468
- """Update number of epochs per image
469
-
470
- :param new_value: new number of epochs per image
471
- """
472
- self.training_epochs_per_img = new_value
473
-
474
- def score_all(self) -> None:
475
- """Score all recordings using the classification model"""
476
- # check basic inputs
477
- if self.model is None:
478
- self.ui.score_all_status.setText("missing classification model")
479
- self.show_message("ERROR: no classification model file selected")
480
- return
481
- if self.min_bout_length < self.epoch_length:
482
- self.ui.score_all_status.setText("invalid minimum bout length")
483
- self.show_message("ERROR: minimum bout length must be >= epoch length")
484
- return
485
- if self.epoch_length != self.model_epoch_length:
486
- self.ui.score_all_status.setText("invalid epoch length")
487
- self.show_message(
488
- (
489
- "ERROR: model was trained with an epoch length of "
490
- f"{self.model_epoch_length} seconds, but the current "
491
- f"epoch length setting is {self.epoch_length} seconds."
492
- )
493
- )
494
- return
315
+ self.training.calibrate = self.ui.calibrate_checkbox.isChecked()
316
+ self.ui.calibration_spinbox.setEnabled(self.training.calibrate)
495
317
 
318
+ def score_recordings(self) -> None:
319
+ """Score all recordings using the classification model."""
496
320
  self.ui.score_all_status.setText("running...")
497
321
  self.ui.score_all_status.repaint()
498
322
  QApplication.processEvents()
499
323
 
500
- from accusleepy.classification import score_recording
501
-
502
- # check some inputs for each recording
503
- for recording_index in range(len(self.recordings)):
504
- error_message = self.check_single_file_inputs(recording_index)
505
- if error_message:
506
- self.ui.score_all_status.setText(
507
- f"error on recording {self.recordings[recording_index].name}"
508
- )
509
- self.show_message(
510
- f"ERROR (recording {self.recordings[recording_index].name}): {error_message}"
511
- )
512
- return
513
- if self.recordings[recording_index].calibration_file == "":
514
- self.ui.score_all_status.setText(
515
- f"error on recording {self.recordings[recording_index].name}"
516
- )
517
- self.show_message(
518
- (
519
- f"ERROR (recording {self.recordings[recording_index].name}): "
520
- "no calibration file selected"
521
- )
522
- )
523
- return
524
-
525
- # score each recording
526
- for recording_index in range(len(self.recordings)):
527
- # load EEG, EMG
528
- try:
529
- eeg, emg = load_recording(
530
- self.recordings[recording_index].recording_file
531
- )
532
- sampling_rate = self.recordings[recording_index].sampling_rate
533
-
534
- eeg, emg, sampling_rate = resample_and_standardize(
535
- eeg=eeg,
536
- emg=emg,
537
- sampling_rate=sampling_rate,
538
- epoch_length=self.epoch_length,
539
- )
540
- except Exception:
541
- self.show_message(
542
- (
543
- "ERROR: could not load recording "
544
- f"{self.recordings[recording_index].name}."
545
- "This recording will be skipped."
546
- )
547
- )
548
- continue
549
-
550
- # load labels
551
- label_file = self.recordings[recording_index].label_file
552
- if os.path.isfile(label_file):
553
- try:
554
- # ignore any existing confidence scores; they will all be overwritten
555
- existing_labels, _ = load_labels(label_file)
556
- except Exception:
557
- self.show_message(
558
- (
559
- "ERROR: could not load existing labels for recording "
560
- f"{self.recordings[recording_index].name}."
561
- "This recording will be skipped."
562
- )
563
- )
564
- continue
565
- # only check the length
566
- samples_per_epoch = sampling_rate * self.epoch_length
567
- epochs_in_recording = round(eeg.size / samples_per_epoch)
568
- if epochs_in_recording != existing_labels.size:
569
- self.show_message(
570
- (
571
- "ERROR: existing labels for recording "
572
- f"{self.recordings[recording_index].name} "
573
- "do not match the recording length. "
574
- "This recording will be skipped."
575
- )
576
- )
577
- continue
578
- else:
579
- existing_labels = None
580
-
581
- # load calibration data
582
- if not os.path.isfile(self.recordings[recording_index].calibration_file):
583
- self.show_message(
584
- (
585
- "ERROR: calibration file does not exist for recording "
586
- f"{self.recordings[recording_index].name}. "
587
- "This recording will be skipped."
588
- )
589
- )
590
- continue
591
- try:
592
- (
593
- mixture_means,
594
- mixture_sds,
595
- ) = load_calibration_file(
596
- self.recordings[recording_index].calibration_file
597
- )
598
- except Exception:
599
- self.show_message(
600
- (
601
- "ERROR: could not load calibration file for recording "
602
- f"{self.recordings[recording_index].name}. "
603
- "This recording will be skipped."
604
- )
605
- )
606
- continue
607
-
608
- labels, confidence_scores = score_recording(
609
- model=self.model,
610
- eeg=eeg,
611
- emg=emg,
612
- mixture_means=mixture_means,
613
- mixture_sds=mixture_sds,
614
- sampling_rate=sampling_rate,
615
- epoch_length=self.epoch_length,
616
- epochs_per_img=self.model_epochs_per_img,
617
- brain_state_set=self.brain_state_set,
618
- emg_filter=self.emg_filter,
619
- )
620
-
621
- # overwrite as needed
622
- if existing_labels is not None and self.only_overwrite_undefined:
623
- labels[existing_labels != UNDEFINED_LABEL] = existing_labels[
624
- existing_labels != UNDEFINED_LABEL
625
- ]
626
-
627
- # enforce minimum bout length
628
- labels = enforce_min_bout_length(
629
- labels=labels,
630
- epoch_length=self.epoch_length,
631
- min_bout_length=self.min_bout_length,
632
- )
633
-
634
- # ignore confidence scores if desired
635
- if not self.save_confidence_scores:
636
- confidence_scores = None
637
-
638
- # save results
639
- save_labels(
640
- labels=labels, filename=label_file, confidence_scores=confidence_scores
641
- )
642
- self.show_message(
643
- (
644
- "Saved labels for recording "
645
- f"{self.recordings[recording_index].name} "
646
- f"to {label_file}"
647
- )
648
- )
324
+ result = score_recording_list(
325
+ recordings=list(self.recording_manager),
326
+ loaded_model=self.loaded_model,
327
+ epoch_length=self.epoch_length,
328
+ only_overwrite_undefined=self.scoring.only_overwrite_undefined,
329
+ save_confidence_scores=self.scoring.save_confidence_scores,
330
+ min_bout_length=self.scoring.min_bout_length,
331
+ brain_state_set=self.config.brain_state_set,
332
+ emg_filter=self.config.emg_filter,
333
+ )
649
334
 
650
- self.ui.score_all_status.setText("")
335
+ # Display results
336
+ result.report_to(self.show_message)
337
+ self.ui.score_all_status.setText("error" if not result.success else "")
651
338
 
652
339
  def load_model(self, filename=None) -> None:
653
340
  """Load trained classification model from file
@@ -655,17 +342,10 @@ class AccuSleepWindow(QMainWindow):
655
342
  :param filename: model filename, if it's known
656
343
  """
657
344
  if filename is None:
658
- file_dialog = QFileDialog(self)
659
- file_dialog.setWindowTitle("Select classification model")
660
- file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
661
- file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
662
- file_dialog.setNameFilter("*" + MODEL_FILE_TYPE)
663
-
664
- if file_dialog.exec():
665
- selected_files = file_dialog.selectedFiles()
666
- filename = selected_files[0]
667
- filename = os.path.normpath(filename)
668
- else:
345
+ filename = select_existing_file(
346
+ self, "Select classification model", "*" + MODEL_FILE_TYPE
347
+ )
348
+ if not filename:
669
349
  return
670
350
 
671
351
  if not os.path.isfile(filename):
@@ -683,6 +363,7 @@ class AccuSleepWindow(QMainWindow):
683
363
  filename=filename
684
364
  )
685
365
  except Exception:
366
+ logger.exception("Failed to load %s", filename)
686
367
  self.show_message(
687
368
  (
688
369
  "ERROR: could not load classification model. Check "
@@ -703,14 +384,14 @@ class AccuSleepWindow(QMainWindow):
703
384
  )
704
385
  return
705
386
 
706
- self.model = model
707
- self.model_epoch_length = epoch_length
708
- self.model_epochs_per_img = epochs_per_img
387
+ self.loaded_model.model = model
388
+ self.loaded_model.epoch_length = epoch_length
389
+ self.loaded_model.epochs_per_img = epochs_per_img
709
390
 
710
391
  # warn user if the model's expected epoch length or brain states
711
392
  # don't match the current configuration
712
393
  config_warnings = check_config_consistency(
713
- current_brain_states=self.brain_state_set.to_output_dict()[
394
+ current_brain_states=self.config.brain_state_set.to_output_dict()[
714
395
  BRAIN_STATES_KEY
715
396
  ],
716
397
  model_brain_states=brain_states,
@@ -737,17 +418,21 @@ class AccuSleepWindow(QMainWindow):
737
418
  :param status_widget: UI element on which to display error messages
738
419
  :return: EEG data, EMG data, sampling rate, process completion
739
420
  """
740
- error_message = self.check_single_file_inputs(self.recording_index)
421
+ error_message = check_single_file_inputs(
422
+ self.recording_manager.current, self.epoch_length
423
+ )
741
424
  if error_message:
742
425
  status_widget.setText(error_message)
743
426
  self.show_message(f"ERROR: {error_message}")
744
427
  return None, None, None, False
745
428
 
746
429
  try:
747
- eeg, emg = load_recording(
748
- self.recordings[self.recording_index].recording_file
749
- )
430
+ eeg, emg = load_recording(self.recording_manager.current.recording_file)
750
431
  except Exception:
432
+ logger.exception(
433
+ "Failed to load %s",
434
+ self.recording_manager.current.recording_file,
435
+ )
751
436
  status_widget.setText("could not load recording")
752
437
  self.show_message(
753
438
  (
@@ -757,7 +442,7 @@ class AccuSleepWindow(QMainWindow):
757
442
  )
758
443
  return None, None, None, False
759
444
 
760
- sampling_rate = self.recordings[self.recording_index].sampling_rate
445
+ sampling_rate = self.recording_manager.current.sampling_rate
761
446
 
762
447
  eeg, emg, sampling_rate = resample_and_standardize(
763
448
  eeg=eeg,
@@ -769,135 +454,35 @@ class AccuSleepWindow(QMainWindow):
769
454
  return eeg, emg, sampling_rate, True
770
455
 
771
456
  def create_calibration_file(self) -> None:
772
- """Creates a calibration file
457
+ """Creates a calibration file.
773
458
 
774
459
  This loads a recording and its labels, checks that the labels are
775
460
  all valid, creates the calibration file, and sets the
776
461
  "calibration file" property of the current recording to be the
777
462
  newly created file.
778
463
  """
779
- # load the recording
780
- eeg, emg, sampling_rate, success = self.load_single_recording(
781
- self.ui.calibration_status
782
- )
783
- if not success:
784
- return
785
-
786
- # load the labels
787
- label_file = self.recordings[self.recording_index].label_file
788
- if not os.path.isfile(label_file):
789
- self.ui.calibration_status.setText("label file does not exist")
790
- self.show_message("ERROR: label file does not exist")
791
- return
792
- try:
793
- labels, _ = load_labels(label_file)
794
- except Exception:
795
- self.ui.calibration_status.setText("could not load labels")
796
- self.show_message(
797
- (
798
- "ERROR: could not load labels. "
799
- "Check user manual for formatting instructions."
800
- )
801
- )
802
- return
803
- label_error_message = check_label_validity(
804
- labels=labels,
805
- confidence_scores=None,
806
- samples_in_recording=eeg.size,
807
- sampling_rate=sampling_rate,
808
- epoch_length=self.epoch_length,
809
- brain_state_set=self.brain_state_set,
810
- )
811
- if label_error_message:
812
- self.ui.calibration_status.setText("invalid label file")
813
- self.show_message(f"ERROR: {label_error_message}")
814
- return
815
-
816
- # get the name for the calibration file
817
- filename, _ = QFileDialog.getSaveFileName(
818
- self,
819
- caption="Save calibration file as",
820
- filter="*" + CALIBRATION_FILE_TYPE,
464
+ filename = select_save_location(
465
+ self, "Save calibration file as", "*" + CALIBRATION_FILE_TYPE
821
466
  )
822
467
  if not filename:
823
468
  return
824
- filename = os.path.normpath(filename)
825
469
 
826
- from accusleepy.classification import create_calibration_file
827
-
828
- create_calibration_file(
829
- filename=filename,
830
- eeg=eeg,
831
- emg=emg,
832
- labels=labels,
833
- sampling_rate=sampling_rate,
470
+ result = create_calibration(
471
+ recording=self.recording_manager.current,
834
472
  epoch_length=self.epoch_length,
835
- brain_state_set=self.brain_state_set,
836
- emg_filter=self.emg_filter,
837
- )
838
-
839
- self.ui.calibration_status.setText("")
840
- self.show_message(
841
- (
842
- "Created calibration file using recording "
843
- f"{self.recordings[self.recording_index].name} "
844
- f"at {filename}"
845
- )
473
+ brain_state_set=self.config.brain_state_set,
474
+ emg_filter=self.config.emg_filter,
475
+ output_filename=filename,
846
476
  )
847
477
 
848
- self.recordings[self.recording_index].calibration_file = filename
849
- self.ui.calibration_file_label.setText(filename)
850
-
851
- def check_single_file_inputs(self, recording_index: int) -> str | None:
852
- """Check that a recording's inputs appear valid
853
-
854
- This runs some basic tests for whether it will be possible to
855
- load and score a recording. If any test fails, we return an
856
- error message.
857
-
858
- :param recording_index: index of the recording in the list of
859
- all recordings.
860
- :return: error message
861
- """
862
- sampling_rate = self.recordings[recording_index].sampling_rate
863
- if self.epoch_length == 0:
864
- return "epoch length can't be 0"
865
- if sampling_rate == 0:
866
- return "sampling rate can't be 0"
867
- if self.epoch_length > sampling_rate:
868
- return "invalid epoch length or sampling rate"
869
- if self.recordings[self.recording_index].recording_file == "":
870
- return "no recording selected"
871
- if not os.path.isfile(self.recordings[self.recording_index].recording_file):
872
- return "recording file does not exist"
873
- if self.recordings[self.recording_index].label_file == "":
874
- return "no label file selected"
875
-
876
- def update_min_bout_length(self, new_value) -> None:
877
- """Update the minimum bout length
878
-
879
- :param new_value: new minimum bout length, in seconds
880
- """
881
- self.min_bout_length = new_value
882
-
883
- def update_overwrite_policy(self, checked) -> None:
884
- """Toggle overwriting policy
885
-
886
- If the checkbox is enabled, only epochs where the brain state is set to
887
- undefined will be overwritten by the automatic scoring process.
888
-
889
- :param checked: state of the checkbox
890
- """
891
- self.only_overwrite_undefined = checked
892
-
893
- def update_confidence_policy(self, checked) -> None:
894
- """Toggle policy for saving confidence scores
895
-
896
- If the checkbox is enabled, confidence scores will be saved to the label files.
897
-
898
- :param checked: state of the checkbox
899
- """
900
- self.save_confidence_scores = checked
478
+ # Display results
479
+ result.report_to(self.show_message)
480
+ if not result.success:
481
+ self.ui.calibration_status.setText("error")
482
+ else:
483
+ self.ui.calibration_status.setText("")
484
+ self.recording_manager.current.calibration_file = filename
485
+ self.ui.calibration_file_label.setText(filename)
901
486
 
902
487
  def manual_scoring(self) -> None:
903
488
  """View the selected recording for manual scoring"""
@@ -915,11 +500,12 @@ class AccuSleepWindow(QMainWindow):
915
500
 
916
501
  # if the labels exist, load them
917
502
  # otherwise, create a blank set of labels
918
- label_file = self.recordings[self.recording_index].label_file
503
+ label_file = self.recording_manager.current.label_file
919
504
  if os.path.isfile(label_file):
920
505
  try:
921
506
  labels, confidence_scores = load_labels(label_file)
922
507
  except Exception:
508
+ logger.exception("Failed to load %s", label_file)
923
509
  self.ui.manual_scoring_status.setText("could not load labels")
924
510
  self.show_message(
925
511
  (
@@ -937,56 +523,23 @@ class AccuSleepWindow(QMainWindow):
937
523
  # to a label file that does not have one
938
524
  confidence_scores = None
939
525
 
940
- # check that all labels are valid
941
- label_error = check_label_validity(
526
+ # check that labels are valid and correct minor length mismatches
527
+ labels, confidence_scores, validation_message = validate_and_correct_labels(
942
528
  labels=labels,
943
529
  confidence_scores=confidence_scores,
944
530
  samples_in_recording=eeg.size,
945
531
  sampling_rate=sampling_rate,
946
532
  epoch_length=self.epoch_length,
947
- brain_state_set=self.brain_state_set,
533
+ brain_state_set=self.config.brain_state_set,
948
534
  )
949
- if label_error:
950
- # if the label length is only off by one, pad or truncate as needed
951
- # and show a warning
952
- if label_error == LABEL_LENGTH_ERROR:
953
- # should be very close to an integer
954
- samples_per_epoch = round(sampling_rate * self.epoch_length)
955
- epochs_in_recording = round(eeg.size / samples_per_epoch)
956
- if epochs_in_recording - labels.size == 1:
957
- labels = np.concatenate((labels, np.array([UNDEFINED_LABEL])))
958
- if confidence_scores is not None:
959
- confidence_scores = np.concatenate(
960
- (confidence_scores, np.array([0]))
961
- )
962
- self.show_message(
963
- (
964
- "WARNING: an undefined epoch was added to "
965
- "the label file to correct its length."
966
- )
967
- )
968
- elif labels.size - epochs_in_recording == 1:
969
- labels = labels[:-1]
970
- if confidence_scores is not None:
971
- confidence_scores = confidence_scores[:-1]
972
- self.show_message(
973
- (
974
- "WARNING: the last epoch was removed from "
975
- "the label file to correct its length."
976
- )
977
- )
978
- else:
979
- self.ui.manual_scoring_status.setText("invalid label file")
980
- self.show_message(f"ERROR: {label_error}")
981
- return
982
- else:
983
- self.ui.manual_scoring_status.setText("invalid label file")
984
- self.show_message(f"ERROR: {label_error}")
985
- return
535
+ if labels is None:
536
+ self.ui.manual_scoring_status.setText("invalid label file")
537
+ self.show_message(f"ERROR: {validation_message}")
538
+ return
539
+ if validation_message:
540
+ self.show_message(f"WARNING: {validation_message}")
986
541
 
987
- self.show_message(
988
- f"Viewing recording {self.recordings[self.recording_index].name}"
989
- )
542
+ self.show_message(f"Viewing recording {self.recording_manager.current.name}")
990
543
  self.ui.manual_scoring_status.setText("file is open")
991
544
 
992
545
  # launch the manual scoring window
@@ -998,7 +551,7 @@ class AccuSleepWindow(QMainWindow):
998
551
  confidence_scores=confidence_scores,
999
552
  sampling_rate=sampling_rate,
1000
553
  epoch_length=self.epoch_length,
1001
- emg_filter=self.emg_filter,
554
+ emg_filter=self.config.emg_filter,
1002
555
  )
1003
556
  manual_scoring_window.setWindowTitle(f"AccuSleePy viewer: {label_file}")
1004
557
  manual_scoring_window.exec()
@@ -1006,89 +559,48 @@ class AccuSleepWindow(QMainWindow):
1006
559
 
1007
560
  def create_label_file(self) -> None:
1008
561
  """Set the filename for a new label file"""
1009
- filename, _ = QFileDialog.getSaveFileName(
562
+ filename = select_save_location(
1010
563
  self,
1011
- caption="Set filename for label file (nothing will be overwritten yet)",
1012
- filter="*" + LABEL_FILE_TYPE,
564
+ "Set filename for label file (nothing will be overwritten yet)",
565
+ "*" + LABEL_FILE_TYPE,
1013
566
  )
1014
567
  if filename:
1015
- filename = os.path.normpath(filename)
1016
- self.recordings[self.recording_index].label_file = filename
568
+ self.recording_manager.current.label_file = filename
1017
569
  self.ui.label_file_label.setText(filename)
1018
570
 
1019
571
  def select_label_file(self) -> None:
1020
572
  """User can select an existing label file"""
1021
- file_dialog = QFileDialog(self)
1022
- file_dialog.setWindowTitle("Select label file")
1023
- file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
1024
- file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
1025
- file_dialog.setNameFilter("*" + LABEL_FILE_TYPE)
1026
-
1027
- if file_dialog.exec():
1028
- selected_files = file_dialog.selectedFiles()
1029
- filename = selected_files[0]
1030
- filename = os.path.normpath(filename)
1031
- self.recordings[self.recording_index].label_file = filename
573
+ filename = select_existing_file(
574
+ self, "Select label file", "*" + LABEL_FILE_TYPE
575
+ )
576
+ if filename:
577
+ self.recording_manager.current.label_file = filename
1032
578
  self.ui.label_file_label.setText(filename)
1033
579
 
1034
580
  def select_calibration_file(self) -> None:
1035
581
  """User can select a calibration file"""
1036
- file_dialog = QFileDialog(self)
1037
- file_dialog.setWindowTitle("Select calibration file")
1038
- file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
1039
- file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
1040
- file_dialog.setNameFilter("*" + CALIBRATION_FILE_TYPE)
1041
-
1042
- if file_dialog.exec():
1043
- selected_files = file_dialog.selectedFiles()
1044
- filename = selected_files[0]
1045
- filename = os.path.normpath(filename)
1046
- self.recordings[self.recording_index].calibration_file = filename
582
+ filename = select_existing_file(
583
+ self, "Select calibration file", "*" + CALIBRATION_FILE_TYPE
584
+ )
585
+ if filename:
586
+ self.recording_manager.current.calibration_file = filename
1047
587
  self.ui.calibration_file_label.setText(filename)
1048
588
 
1049
589
  def select_recording_file(self) -> None:
1050
590
  """User can select a recording file"""
1051
- file_dialog = QFileDialog(self)
1052
- file_dialog.setWindowTitle("Select recording file")
1053
- file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
1054
- file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
1055
- file_dialog.setNameFilter(f"(*{' *'.join(RECORDING_FILE_TYPES)})")
1056
-
1057
- if file_dialog.exec():
1058
- selected_files = file_dialog.selectedFiles()
1059
- filename = selected_files[0]
1060
- filename = os.path.normpath(filename)
1061
- self.recordings[self.recording_index].recording_file = filename
591
+ file_filter = f"(*{' *'.join(RECORDING_FILE_TYPES)})"
592
+ filename = select_existing_file(self, "Select recording file", file_filter)
593
+ if filename:
594
+ self.recording_manager.current.recording_file = filename
1062
595
  self.ui.recording_file_label.setText(filename)
1063
596
 
1064
597
  def show_recording_info(self) -> None:
1065
598
  """Update the UI to show info for the selected recording"""
1066
- self.ui.sampling_rate_input.setValue(
1067
- self.recordings[self.recording_index].sampling_rate
1068
- )
1069
- self.ui.recording_file_label.setText(
1070
- self.recordings[self.recording_index].recording_file
1071
- )
1072
- self.ui.label_file_label.setText(
1073
- self.recordings[self.recording_index].label_file
1074
- )
1075
- self.ui.calibration_file_label.setText(
1076
- self.recordings[self.recording_index].calibration_file
1077
- )
1078
-
1079
- def update_epoch_length(self, new_value: int | float) -> None:
1080
- """Update the epoch length when the widget state changes
1081
-
1082
- :param new_value: new epoch length
1083
- """
1084
- self.epoch_length = new_value
1085
-
1086
- def update_sampling_rate(self, new_value: int | float) -> None:
1087
- """Update recording's sampling rate when the widget state changes
1088
-
1089
- :param new_value: new sampling rate
1090
- """
1091
- self.recordings[self.recording_index].sampling_rate = new_value
599
+ recording = self.recording_manager.current
600
+ self.ui.sampling_rate_input.setValue(recording.sampling_rate)
601
+ self.ui.recording_file_label.setText(recording.recording_file)
602
+ self.ui.label_file_label.setText(recording.label_file)
603
+ self.ui.calibration_file_label.setText(recording.calibration_file)
1092
604
 
1093
605
  def show_message(self, message: str) -> None:
1094
606
  """Display a new message to the user
@@ -1103,50 +615,22 @@ class AccuSleepWindow(QMainWindow):
1103
615
  scrollbar = self.ui.message_area.verticalScrollBar()
1104
616
  scrollbar.setValue(scrollbar.maximum())
1105
617
 
1106
- def select_recording(self, list_index: int) -> None:
1107
- """Callback for when a recording is selected
1108
-
1109
- :param list_index: index of this recording in the list widget
1110
- """
1111
- # get index of this recording
1112
- self.recording_index = list_index
1113
- # display information about this recording
618
+ def select_recording(self, _index: int) -> None:
619
+ """Callback for when a recording is selected"""
1114
620
  self.show_recording_info()
1115
621
  self.ui.selected_recording_groupbox.setTitle(
1116
- f"Data / actions for Recording {self.recordings[list_index].name}"
622
+ f"Data / actions for Recording {self.recording_manager.current.name}"
1117
623
  )
1118
624
 
1119
625
  def add_recording(self) -> None:
1120
626
  """Add new recording to the list"""
1121
- # find name to use for the new recording
1122
- new_name = max([r.name for r in self.recordings]) + 1
1123
-
1124
- # add new recording to list
1125
- self.recordings.append(
1126
- Recording(
1127
- name=new_name,
1128
- sampling_rate=self.recordings[self.recording_index].sampling_rate,
1129
- widget=QListWidgetItem(
1130
- f"Recording {new_name}", self.ui.recording_list_widget
1131
- ),
1132
- )
1133
- )
1134
-
1135
- # display new list
1136
- self.ui.recording_list_widget.addItem(self.recordings[-1].widget)
1137
- self.ui.recording_list_widget.setCurrentRow(len(self.recordings) - 1)
1138
- self.show_message(f"added Recording {new_name}")
627
+ current_sampling_rate = self.recording_manager.current.sampling_rate
628
+ recording = self.recording_manager.add(sampling_rate=current_sampling_rate)
629
+ self.show_message(f"added Recording {recording.name}")
1139
630
 
1140
631
  def remove_recording(self) -> None:
1141
632
  """Delete selected recording from the list"""
1142
- if len(self.recordings) > 1:
1143
- current_list_index = self.ui.recording_list_widget.currentRow()
1144
- _ = self.ui.recording_list_widget.takeItem(current_list_index)
1145
- self.show_message(
1146
- f"deleted Recording {self.recordings[current_list_index].name}"
1147
- )
1148
- del self.recordings[current_list_index]
1149
- self.recording_index = self.ui.recording_list_widget.currentRow()
633
+ self.show_message(self.recording_manager.remove_current())
1150
634
 
1151
635
  def show_user_manual(self) -> None:
1152
636
  """Show a popup window with the user manual"""
@@ -1162,310 +646,12 @@ class AccuSleepWindow(QMainWindow):
1162
646
  self.popup.setGeometry(QRect(100, 100, 600, 600))
1163
647
  self.popup.show()
1164
648
 
1165
- def initialize_settings_tab(self):
1166
- """Populate settings tab and assign its callbacks"""
1167
- # store dictionary that maps digits to rows of widgets
1168
- # in the settings tab
1169
- self.settings_widgets = {
1170
- 1: StateSettings(
1171
- digit=1,
1172
- enabled_widget=self.ui.enable_state_1,
1173
- name_widget=self.ui.state_name_1,
1174
- is_scored_widget=self.ui.state_scored_1,
1175
- frequency_widget=self.ui.state_frequency_1,
1176
- ),
1177
- 2: StateSettings(
1178
- digit=2,
1179
- enabled_widget=self.ui.enable_state_2,
1180
- name_widget=self.ui.state_name_2,
1181
- is_scored_widget=self.ui.state_scored_2,
1182
- frequency_widget=self.ui.state_frequency_2,
1183
- ),
1184
- 3: StateSettings(
1185
- digit=3,
1186
- enabled_widget=self.ui.enable_state_3,
1187
- name_widget=self.ui.state_name_3,
1188
- is_scored_widget=self.ui.state_scored_3,
1189
- frequency_widget=self.ui.state_frequency_3,
1190
- ),
1191
- 4: StateSettings(
1192
- digit=4,
1193
- enabled_widget=self.ui.enable_state_4,
1194
- name_widget=self.ui.state_name_4,
1195
- is_scored_widget=self.ui.state_scored_4,
1196
- frequency_widget=self.ui.state_frequency_4,
1197
- ),
1198
- 5: StateSettings(
1199
- digit=5,
1200
- enabled_widget=self.ui.enable_state_5,
1201
- name_widget=self.ui.state_name_5,
1202
- is_scored_widget=self.ui.state_scored_5,
1203
- frequency_widget=self.ui.state_frequency_5,
1204
- ),
1205
- 6: StateSettings(
1206
- digit=6,
1207
- enabled_widget=self.ui.enable_state_6,
1208
- name_widget=self.ui.state_name_6,
1209
- is_scored_widget=self.ui.state_scored_6,
1210
- frequency_widget=self.ui.state_frequency_6,
1211
- ),
1212
- 7: StateSettings(
1213
- digit=7,
1214
- enabled_widget=self.ui.enable_state_7,
1215
- name_widget=self.ui.state_name_7,
1216
- is_scored_widget=self.ui.state_scored_7,
1217
- frequency_widget=self.ui.state_frequency_7,
1218
- ),
1219
- 8: StateSettings(
1220
- digit=8,
1221
- enabled_widget=self.ui.enable_state_8,
1222
- name_widget=self.ui.state_name_8,
1223
- is_scored_widget=self.ui.state_scored_8,
1224
- frequency_widget=self.ui.state_frequency_8,
1225
- ),
1226
- 9: StateSettings(
1227
- digit=9,
1228
- enabled_widget=self.ui.enable_state_9,
1229
- name_widget=self.ui.state_name_9,
1230
- is_scored_widget=self.ui.state_scored_9,
1231
- frequency_widget=self.ui.state_frequency_9,
1232
- ),
1233
- 0: StateSettings(
1234
- digit=0,
1235
- enabled_widget=self.ui.enable_state_0,
1236
- name_widget=self.ui.state_name_0,
1237
- is_scored_widget=self.ui.state_scored_0,
1238
- frequency_widget=self.ui.state_frequency_0,
1239
- ),
1240
- }
1241
-
1242
- # update widget state to display current config
1243
- # UI defaults
1244
- self.ui.default_epoch_input.setValue(self.epoch_length)
1245
- self.ui.overwrite_default_checkbox.setChecked(self.only_overwrite_undefined)
1246
- self.ui.confidence_setting_checkbox.setChecked(self.save_confidence_scores)
1247
- self.ui.default_min_bout_length_spinbox.setValue(self.min_bout_length)
1248
- self.ui.epochs_to_show_spinbox.setValue(self.default_epochs_to_show)
1249
- self.ui.autoscroll_checkbox.setChecked(self.default_autoscroll_state)
1250
- # EMG filter
1251
- self.ui.emg_order_spinbox.setValue(self.emg_filter.order)
1252
- self.ui.bp_lower_spinbox.setValue(self.emg_filter.bp_lower)
1253
- self.ui.bp_upper_spinbox.setValue(self.emg_filter.bp_upper)
1254
- # model training hyperparameters
1255
- self.ui.batch_size_spinbox.setValue(self.hyperparameters.batch_size)
1256
- self.ui.learning_rate_spinbox.setValue(self.hyperparameters.learning_rate)
1257
- self.ui.momentum_spinbox.setValue(self.hyperparameters.momentum)
1258
- self.ui.training_epochs_spinbox.setValue(self.hyperparameters.training_epochs)
1259
- # brain states
1260
- states = {b.digit: b for b in self.brain_state_set.brain_states}
1261
- for digit in range(10):
1262
- if digit in states.keys():
1263
- self.settings_widgets[digit].enabled_widget.setChecked(True)
1264
- self.settings_widgets[digit].name_widget.setText(states[digit].name)
1265
- self.settings_widgets[digit].is_scored_widget.setChecked(
1266
- states[digit].is_scored
1267
- )
1268
- self.settings_widgets[digit].frequency_widget.setValue(
1269
- states[digit].frequency
1270
- )
1271
- else:
1272
- self.settings_widgets[digit].enabled_widget.setChecked(False)
1273
- self.settings_widgets[digit].name_widget.setEnabled(False)
1274
- self.settings_widgets[digit].is_scored_widget.setEnabled(False)
1275
- self.settings_widgets[digit].frequency_widget.setEnabled(False)
1276
-
1277
- # set callbacks
1278
- self.ui.emg_order_spinbox.valueChanged.connect(self.emg_filter_order_changed)
1279
- self.ui.bp_lower_spinbox.valueChanged.connect(self.emg_filter_bp_lower_changed)
1280
- self.ui.bp_upper_spinbox.valueChanged.connect(self.emg_filter_bp_upper_changed)
1281
- self.ui.batch_size_spinbox.valueChanged.connect(self.hyperparameters_changed)
1282
- self.ui.learning_rate_spinbox.valueChanged.connect(self.hyperparameters_changed)
1283
- self.ui.momentum_spinbox.valueChanged.connect(self.hyperparameters_changed)
1284
- self.ui.training_epochs_spinbox.valueChanged.connect(
1285
- self.hyperparameters_changed
1286
- )
1287
- for digit in range(10):
1288
- state = self.settings_widgets[digit]
1289
- state.enabled_widget.stateChanged.connect(
1290
- partial(self.set_brain_state_enabled, digit)
1291
- )
1292
- state.name_widget.editingFinished.connect(self.check_config_validity)
1293
- state.is_scored_widget.stateChanged.connect(
1294
- partial(self.is_scored_changed, digit)
1295
- )
1296
- state.frequency_widget.valueChanged.connect(self.check_config_validity)
1297
-
1298
- def set_brain_state_enabled(self, digit, e) -> None:
1299
- """Called when user clicks "enabled" checkbox
1300
-
1301
- :param digit: brain state digit
1302
- :param e: unused but mandatory
1303
- """
1304
- # get the widgets for this brain state
1305
- state = self.settings_widgets[digit]
1306
- # update state of these widgets
1307
- is_checked = state.enabled_widget.isChecked()
1308
- for widget in [
1309
- state.name_widget,
1310
- state.is_scored_widget,
1311
- ]:
1312
- widget.setEnabled(is_checked)
1313
- state.frequency_widget.setEnabled(
1314
- is_checked and state.is_scored_widget.isChecked()
1315
- )
1316
- if not is_checked:
1317
- state.name_widget.setText("")
1318
- state.frequency_widget.setValue(0)
1319
- # check that configuration is valid
1320
- _ = self.check_config_validity()
1321
-
1322
- def is_scored_changed(self, digit, e) -> None:
1323
- """Called when user sets whether a state is scored
1324
-
1325
- :param digit: brain state digit
1326
- :param e: unused, but mandatory
1327
- """
1328
- # get the widgets for this brain state
1329
- state = self.settings_widgets[digit]
1330
- # update the state of these widgets
1331
- is_checked = state.is_scored_widget.isChecked()
1332
- state.frequency_widget.setEnabled(is_checked)
1333
- if not is_checked:
1334
- state.frequency_widget.setValue(0)
1335
- # check that configuration is valid
1336
- _ = self.check_config_validity()
1337
-
1338
- def emg_filter_order_changed(self, new_value: int) -> None:
1339
- """Called when user modifies EMG filter order
1340
-
1341
- :param new_value: new EMG filter order
1342
- """
1343
- self.emg_filter.order = new_value
1344
-
1345
- def emg_filter_bp_lower_changed(self, new_value: int | float) -> None:
1346
- """Called when user modifies EMG filter lower cutoff
1347
-
1348
- :param new_value: new lower bandpass cutoff frequency
1349
- """
1350
- self.emg_filter.bp_lower = new_value
1351
- _ = self.check_config_validity()
1352
-
1353
- def emg_filter_bp_upper_changed(self, new_value: int | float) -> None:
1354
- """Called when user modifies EMG filter upper cutoff
1355
-
1356
- :param new_value: new upper bandpass cutoff frequency
1357
- """
1358
- self.emg_filter.bp_upper = new_value
1359
- _ = self.check_config_validity()
1360
-
1361
- def hyperparameters_changed(self, new_value) -> None:
1362
- """Called when user modifies model training hyperparameters
1363
-
1364
- :param new_value: unused
1365
- """
1366
- self.hyperparameters = Hyperparameters(
1367
- batch_size=self.ui.batch_size_spinbox.value(),
1368
- learning_rate=self.ui.learning_rate_spinbox.value(),
1369
- momentum=self.ui.momentum_spinbox.value(),
1370
- training_epochs=self.ui.training_epochs_spinbox.value(),
1371
- )
1372
-
1373
- def check_config_validity(self) -> str:
1374
- """Check if brain state configuration on screen is valid"""
1375
- # error message, if we get one
1376
- message = None
1377
-
1378
- # strip whitespace from brain state names and update display
1379
- for digit in range(10):
1380
- state = self.settings_widgets[digit]
1381
- current_name = state.name_widget.text()
1382
- formatted_name = current_name.strip()
1383
- if current_name != formatted_name:
1384
- state.name_widget.setText(formatted_name)
1385
-
1386
- # check if names are unique and frequencies add up to 1
1387
- names = []
1388
- frequencies = []
1389
- for digit in range(10):
1390
- state = self.settings_widgets[digit]
1391
- if state.enabled_widget.isChecked():
1392
- names.append(state.name_widget.text())
1393
- frequencies.append(state.frequency_widget.value())
1394
- if len(names) != len(set(names)):
1395
- message = "Error: names must be unique"
1396
- if sum(frequencies) != 1:
1397
- message = "Error: sum(frequencies) != 1"
1398
-
1399
- # check validity of EMG filter settings
1400
- if self.emg_filter.bp_lower >= self.emg_filter.bp_upper:
1401
- message = "Error: EMG filter cutoff frequencies are invalid"
1402
-
1403
- if message is not None:
1404
- self.ui.save_config_status.setText(message)
1405
- self.ui.save_config_button.setEnabled(False)
1406
- return message
1407
-
1408
- self.ui.save_config_button.setEnabled(True)
1409
- self.ui.save_config_status.setText("")
1410
-
1411
- def save_brain_state_config(self):
1412
- """Save configuration to file"""
1413
- # check that configuration is valid
1414
- error_message = self.check_config_validity()
1415
- if error_message is not None:
1416
- return
1417
-
1418
- # build a BrainStateMapper object from the current configuration
1419
- brain_states = list()
1420
- for digit in range(10):
1421
- state = self.settings_widgets[digit]
1422
- if state.enabled_widget.isChecked():
1423
- brain_states.append(
1424
- BrainState(
1425
- name=state.name_widget.text(),
1426
- digit=digit,
1427
- is_scored=state.is_scored_widget.isChecked(),
1428
- frequency=state.frequency_widget.value(),
1429
- )
1430
- )
1431
- self.brain_state_set = BrainStateSet(brain_states, UNDEFINED_LABEL)
1432
-
1433
- # save to file
1434
- save_config(
1435
- brain_state_set=self.brain_state_set,
1436
- default_epoch_length=self.ui.default_epoch_input.value(),
1437
- overwrite_setting=self.ui.overwrite_default_checkbox.isChecked(),
1438
- save_confidence_setting=self.ui.confidence_setting_checkbox.isChecked(),
1439
- min_bout_length=self.ui.default_min_bout_length_spinbox.value(),
1440
- emg_filter=EMGFilter(
1441
- order=self.emg_filter.order,
1442
- bp_lower=self.emg_filter.bp_lower,
1443
- bp_upper=self.emg_filter.bp_upper,
1444
- ),
1445
- hyperparameters=Hyperparameters(
1446
- batch_size=self.hyperparameters.batch_size,
1447
- learning_rate=self.hyperparameters.learning_rate,
1448
- momentum=self.hyperparameters.momentum,
1449
- training_epochs=self.hyperparameters.training_epochs,
1450
- ),
1451
- epochs_to_show=self.ui.epochs_to_show_spinbox.value(),
1452
- autoscroll_state=self.ui.autoscroll_checkbox.isChecked(),
1453
- )
1454
- self.ui.save_config_status.setText("configuration saved")
1455
-
1456
- def reset_emg_filter_settings(self) -> None:
1457
- self.ui.emg_order_spinbox.setValue(DEFAULT_EMG_FILTER_ORDER)
1458
- self.ui.bp_lower_spinbox.setValue(DEFAULT_EMG_BP_LOWER)
1459
- self.ui.bp_upper_spinbox.setValue(DEFAULT_EMG_BP_UPPER)
1460
-
1461
- def reset_hyperparams_settings(self):
1462
- self.ui.batch_size_spinbox.setValue(DEFAULT_BATCH_SIZE)
1463
- self.ui.learning_rate_spinbox.setValue(DEFAULT_LEARNING_RATE)
1464
- self.ui.momentum_spinbox.setValue(DEFAULT_MOMENTUM)
1465
- self.ui.training_epochs_spinbox.setValue(DEFAULT_TRAINING_EPOCHS)
1466
-
1467
649
 
1468
650
  def run_primary_window() -> None:
651
+ logging.basicConfig(
652
+ level=logging.INFO,
653
+ format="%(levelname)s - %(name)s - %(message)s",
654
+ )
1469
655
  app = QApplication(sys.argv)
1470
656
  AccuSleepWindow()
1471
657
  sys.exit(app.exec())