accusleepy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,6 +11,7 @@ from tqdm import trange
11
11
 
12
12
  import accusleepy.constants as c
13
13
  from accusleepy.brain_state_set import BrainStateSet
14
+ from accusleepy.fileio import EMGFilter, Hyperparameters
14
15
  from accusleepy.models import SSANN
15
16
  from accusleepy.signal_processing import (
16
17
  create_eeg_emg_image,
@@ -19,11 +20,6 @@ from accusleepy.signal_processing import (
19
20
  mixture_z_score_img,
20
21
  )
21
22
 
22
- BATCH_SIZE = 64
23
- LEARNING_RATE = 1e-3
24
- MOMENTUM = 0.9
25
- TRAINING_EPOCHS = 6
26
-
27
23
 
28
24
  class AccuSleepImageDataset(Dataset):
29
25
  """Dataset for loading AccuSleep training images"""
@@ -62,12 +58,16 @@ def get_device():
62
58
 
63
59
 
64
60
  def create_dataloader(
65
- annotations_file: str, img_dir: str, shuffle: bool = True
61
+ annotations_file: str,
62
+ img_dir: str,
63
+ hyperparameters: Hyperparameters,
64
+ shuffle: bool = True,
66
65
  ) -> DataLoader:
67
66
  """Create DataLoader for a dataset of training or calibration images
68
67
 
69
68
  :param annotations_file: file with information on each training image
70
69
  :param img_dir: training image location
70
+ :param hyperparameters: model training hyperparameters
71
71
  :param shuffle: reshuffle data for every epoch
72
72
  :return: DataLoader for the data
73
73
 
@@ -76,7 +76,9 @@ def create_dataloader(
76
76
  annotations_file=annotations_file,
77
77
  img_dir=img_dir,
78
78
  )
79
- return DataLoader(image_dataset, batch_size=BATCH_SIZE, shuffle=shuffle)
79
+ return DataLoader(
80
+ image_dataset, batch_size=hyperparameters.batch_size, shuffle=shuffle
81
+ )
80
82
 
81
83
 
82
84
  def train_ssann(
@@ -84,6 +86,7 @@ def train_ssann(
84
86
  img_dir: str,
85
87
  mixture_weights: np.array,
86
88
  n_classes: int,
89
+ hyperparameters: Hyperparameters,
87
90
  ) -> SSANN:
88
91
  """Train a SSANN classification model for sleep scoring
89
92
 
@@ -91,10 +94,13 @@ def train_ssann(
91
94
  :param img_dir: training image location
92
95
  :param mixture_weights: typical relative frequencies of brain states
93
96
  :param n_classes: number of classes the model will learn
97
+ :param hyperparameters: model training hyperparameters
94
98
  :return: trained Sleep Scoring Artificial Neural Network model
95
99
  """
96
100
  train_dataloader = create_dataloader(
97
- annotations_file=annotations_file, img_dir=img_dir
101
+ annotations_file=annotations_file,
102
+ img_dir=img_dir,
103
+ hyperparameters=hyperparameters,
98
104
  )
99
105
 
100
106
  device = get_device()
@@ -106,9 +112,13 @@ def train_ssann(
106
112
  weight = torch.tensor((mixture_weights**-1).astype("float32")).to(device)
107
113
 
108
114
  criterion = nn.CrossEntropyLoss(weight=weight)
109
- optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
115
+ optimizer = optim.SGD(
116
+ model.parameters(),
117
+ lr=hyperparameters.learning_rate,
118
+ momentum=hyperparameters.momentum,
119
+ )
110
120
 
111
- for _ in trange(TRAINING_EPOCHS):
121
+ for _ in trange(hyperparameters.training_epochs):
112
122
  for data in train_dataloader:
113
123
  inputs, labels = data
114
124
  (inputs, labels) = (inputs.to(device), labels.to(device))
@@ -131,6 +141,7 @@ def score_recording(
131
141
  epoch_length: int | float,
132
142
  epochs_per_img: int,
133
143
  brain_state_set: BrainStateSet,
144
+ emg_filter: EMGFilter,
134
145
  ) -> np.array:
135
146
  """Use classification model to get brain state labels for a recording
136
147
 
@@ -146,6 +157,7 @@ def score_recording(
146
157
  :param epoch_length: epoch length, in seconds
147
158
  :param epochs_per_img: number of epochs for the model to consider
148
159
  :param brain_state_set: set of brain state options
160
+ :param emg_filter: EMG filter parameters
149
161
  :return: brain state labels, confidence scores
150
162
  """
151
163
  # prepare model
@@ -154,7 +166,7 @@ def score_recording(
154
166
  model.eval()
155
167
 
156
168
  # create and scale eeg+emg spectrogram
157
- img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
169
+ img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
158
170
  img = mixture_z_score_img(
159
171
  img,
160
172
  mixture_means=mixture_means,
@@ -192,6 +204,7 @@ def example_real_time_scoring_function(
192
204
  epoch_length: int | float,
193
205
  epochs_per_img: int,
194
206
  brain_state_set: BrainStateSet,
207
+ emg_filter: EMGFilter,
195
208
  ) -> int:
196
209
  """Example function that could be used for real-time scoring
197
210
 
@@ -220,6 +233,7 @@ def example_real_time_scoring_function(
220
233
  :param epoch_length: epoch length, in seconds
221
234
  :param epochs_per_img: number of epochs shown to the model at once
222
235
  :param brain_state_set: set of brain state options
236
+ :param emg_filter: EMG filter parameters
223
237
  :return: brain state label
224
238
  """
225
239
  # prepare model
@@ -229,7 +243,7 @@ def example_real_time_scoring_function(
229
243
  model.eval()
230
244
 
231
245
  # create and scale eeg+emg spectrogram
232
- img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
246
+ img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
233
247
  img = mixture_z_score_img(
234
248
  img,
235
249
  mixture_means=mixture_means,
@@ -260,6 +274,7 @@ def create_calibration_file(
260
274
  sampling_rate: int | float,
261
275
  epoch_length: int | float,
262
276
  brain_state_set: BrainStateSet,
277
+ emg_filter: EMGFilter,
263
278
  ) -> None:
264
279
  """Create file of calibration data for a subject
265
280
 
@@ -273,8 +288,9 @@ def create_calibration_file(
273
288
  :param sampling_rate: sampling rate, in Hz
274
289
  :param epoch_length: epoch length, in seconds
275
290
  :param brain_state_set: set of brain state options
291
+ :param emg_filter: EMG filter parameters
276
292
  """
277
- img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
293
+ img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
278
294
  mixture_means, mixture_sds = get_mixture_values(
279
295
  img=img,
280
296
  labels=brain_state_set.convert_digit_to_class(labels),
accusleepy/config.json CHANGED
@@ -20,5 +20,18 @@
20
20
  }
21
21
  ],
22
22
  "default_epoch_length": 2.5,
23
- "save_confidence_setting": true
23
+ "default_overwrite_setting": false,
24
+ "save_confidence_setting": true,
25
+ "default_min_bout_length": 5.0,
26
+ "emg_filter": {
27
+ "order": 8,
28
+ "bp_lower": 20.0,
29
+ "bp_upper": 50.0
30
+ },
31
+ "hyperparameters": {
32
+ "batch_size": 64,
33
+ "learning_rate": 0.001,
34
+ "momentum": 0.9,
35
+ "training_epochs": 6
36
+ }
24
37
  }
accusleepy/constants.py CHANGED
@@ -36,11 +36,33 @@ LABEL_COL = "label"
36
36
  # recording list file header:
37
37
  RECORDING_LIST_NAME = "recording_list"
38
38
  RECORDING_LIST_FILE_TYPE = ".json"
39
- # key for default epoch length in config
40
- DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
41
- # key used for default confidence score behavior in config
42
- DEFAULT_CONFIDENCE_SETTING_KEY = "save_confidence_setting"
43
39
  # filename used to store info about training image datasets
44
40
  ANNOTATIONS_FILENAME = "annotations.csv"
45
41
  # filename for annotation file for the calibration set
46
42
  CALIBRATION_ANNOTATION_FILENAME = "calibration_set.csv"
43
+
44
+ # config file keys
45
+ # ui setting keys
46
+ DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
47
+ DEFAULT_CONFIDENCE_SETTING_KEY = "save_confidence_setting"
48
+ DEFAULT_MIN_BOUT_LENGTH_KEY = "default_min_bout_length"
49
+ DEFAULT_OVERWRITE_KEY = "default_overwrite_setting"
50
+ # EMG filter parameters key
51
+ EMG_FILTER_KEY = "emg_filter"
52
+ # model training hyperparameters key
53
+ HYPERPARAMETERS_KEY = "hyperparameters"
54
+
55
+ # default values
56
+ # default UI settings
57
+ DEFAULT_MIN_BOUT_LENGTH = 5.0
58
+ DEFAULT_CONFIDENCE_SETTING = True
59
+ DEFAULT_OVERWRITE_SETTING = False
60
+ # default EMG filter parameters (order, bandpass frequencies)
61
+ DEFAULT_EMG_FILTER_ORDER = 8
62
+ DEFAULT_EMG_BP_LOWER = 20
63
+ DEFAULT_EMG_BP_UPPER = 50
64
+ # default hyperparameters
65
+ DEFAULT_BATCH_SIZE = 64
66
+ DEFAULT_LEARNING_RATE = 1e-3
67
+ DEFAULT_MOMENTUM = 0.9
68
+ DEFAULT_TRAINING_EPOCHS = 6
accusleepy/fileio.py CHANGED
@@ -7,19 +7,26 @@ import pandas as pd
7
7
  from PySide6.QtWidgets import QListWidgetItem
8
8
 
9
9
  from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
10
- from accusleepy.constants import (
11
- BRAIN_STATE_COL,
12
- CONFIDENCE_SCORE_COL,
13
- CONFIG_FILE,
14
- DEFAULT_CONFIDENCE_SETTING_KEY,
15
- DEFAULT_EPOCH_LENGTH_KEY,
16
- EEG_COL,
17
- EMG_COL,
18
- MIXTURE_MEAN_COL,
19
- MIXTURE_SD_COL,
20
- RECORDING_LIST_NAME,
21
- UNDEFINED_LABEL,
22
- )
10
+ import accusleepy.constants as c
11
+
12
+
13
+ @dataclass
14
+ class EMGFilter:
15
+ """Convenience class for a EMG filter parameters"""
16
+
17
+ order: int # filter order
18
+ bp_lower: int | float # lower bandpass frequency
19
+ bp_upper: int | float # upper bandpass frequency
20
+
21
+
22
+ @dataclass
23
+ class Hyperparameters:
24
+ """Convenience class for model training hyperparameters"""
25
+
26
+ batch_size: int
27
+ learning_rate: float
28
+ momentum: float
29
+ training_epochs: int
23
30
 
24
31
 
25
32
  @dataclass
@@ -41,8 +48,8 @@ def load_calibration_file(filename: str) -> (np.array, np.array):
41
48
  :return: mixture means and SDs
42
49
  """
43
50
  df = pd.read_csv(filename)
44
- mixture_means = df[MIXTURE_MEAN_COL].values
45
- mixture_sds = df[MIXTURE_SD_COL].values
51
+ mixture_means = df[c.MIXTURE_MEAN_COL].values
52
+ mixture_sds = df[c.MIXTURE_SD_COL].values
46
53
  return mixture_means, mixture_sds
47
54
 
48
55
 
@@ -69,8 +76,8 @@ def load_recording(filename: str) -> (np.array, np.array):
69
76
  :return: arrays of EEG and EMG data
70
77
  """
71
78
  df = load_csv_or_parquet(filename)
72
- eeg = df[EEG_COL].values
73
- emg = df[EMG_COL].values
79
+ eeg = df[c.EEG_COL].values
80
+ emg = df[c.EMG_COL].values
74
81
  return eeg, emg
75
82
 
76
83
 
@@ -81,10 +88,10 @@ def load_labels(filename: str) -> (np.array, np.array):
81
88
  :return: array of brain state labels and, optionally, array of confidence scores
82
89
  """
83
90
  df = load_csv_or_parquet(filename)
84
- if CONFIDENCE_SCORE_COL in df.columns:
85
- return df[BRAIN_STATE_COL].values, df[CONFIDENCE_SCORE_COL].values
91
+ if c.CONFIDENCE_SCORE_COL in df.columns:
92
+ return df[c.BRAIN_STATE_COL].values, df[c.CONFIDENCE_SCORE_COL].values
86
93
  else:
87
- return df[BRAIN_STATE_COL].values, None
94
+ return df[c.BRAIN_STATE_COL].values, None
88
95
 
89
96
 
90
97
  def save_labels(
@@ -98,48 +105,92 @@ def save_labels(
98
105
  """
99
106
  if confidence_scores is not None:
100
107
  pd.DataFrame(
101
- {BRAIN_STATE_COL: labels, CONFIDENCE_SCORE_COL: confidence_scores}
108
+ {c.BRAIN_STATE_COL: labels, c.CONFIDENCE_SCORE_COL: confidence_scores}
102
109
  ).to_csv(filename, index=False)
103
110
  else:
104
- pd.DataFrame({BRAIN_STATE_COL: labels}).to_csv(filename, index=False)
111
+ pd.DataFrame({c.BRAIN_STATE_COL: labels}).to_csv(filename, index=False)
105
112
 
106
113
 
107
- def load_config() -> tuple[BrainStateSet, int | float, bool]:
114
+ def load_config() -> tuple[
115
+ BrainStateSet, int | float, bool, bool, int | float, EMGFilter, Hyperparameters
116
+ ]:
108
117
  """Load configuration file with brain state options
109
118
 
110
- :return: set of brain state options, other settings
119
+ :return: set of brain state options,
120
+ default epoch length,
121
+ default overwrite setting,
122
+ default confidence score output setting,
123
+ default minimum bout length,
124
+ EMG filter parameters,
125
+ model training hyperparameters
111
126
  """
112
127
  with open(
113
- os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "r"
128
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), c.CONFIG_FILE), "r"
114
129
  ) as f:
115
130
  data = json.load(f)
116
131
 
117
132
  return (
118
133
  BrainStateSet(
119
- [BrainState(**b) for b in data[BRAIN_STATES_KEY]], UNDEFINED_LABEL
134
+ [BrainState(**b) for b in data[BRAIN_STATES_KEY]], c.UNDEFINED_LABEL
135
+ ),
136
+ data[c.DEFAULT_EPOCH_LENGTH_KEY],
137
+ data.get(c.DEFAULT_OVERWRITE_KEY, c.DEFAULT_OVERWRITE_SETTING),
138
+ data.get(c.DEFAULT_CONFIDENCE_SETTING_KEY, c.DEFAULT_CONFIDENCE_SETTING),
139
+ data.get(c.DEFAULT_MIN_BOUT_LENGTH_KEY, c.DEFAULT_MIN_BOUT_LENGTH),
140
+ EMGFilter(
141
+ **data.get(
142
+ c.EMG_FILTER_KEY,
143
+ {
144
+ "order": c.DEFAULT_EMG_FILTER_ORDER,
145
+ "bp_lower": c.DEFAULT_EMG_BP_LOWER,
146
+ "bp_upper": c.DEFAULT_EMG_BP_UPPER,
147
+ },
148
+ )
149
+ ),
150
+ Hyperparameters(
151
+ **data.get(
152
+ c.HYPERPARAMETERS_KEY,
153
+ {
154
+ "batch_size": c.DEFAULT_BATCH_SIZE,
155
+ "learning_rate": c.DEFAULT_LEARNING_RATE,
156
+ "momentum": c.DEFAULT_MOMENTUM,
157
+ "training_epochs": c.DEFAULT_TRAINING_EPOCHS,
158
+ },
159
+ )
120
160
  ),
121
- data[DEFAULT_EPOCH_LENGTH_KEY],
122
- data.get(DEFAULT_CONFIDENCE_SETTING_KEY, True),
123
161
  )
124
162
 
125
163
 
126
164
  def save_config(
127
165
  brain_state_set: BrainStateSet,
128
166
  default_epoch_length: int | float,
167
+ overwrite_setting: bool,
129
168
  save_confidence_setting: bool,
169
+ min_bout_length: int | float,
170
+ emg_filter: EMGFilter,
171
+ hyperparameters: Hyperparameters,
130
172
  ) -> None:
131
173
  """Save configuration of brain state options to json file
132
174
 
133
175
  :param brain_state_set: set of brain state options
134
- :param default_epoch_length: epoch length to use when the GUI starts
135
- :param save_confidence_setting: whether the option to save confidence
136
- scores should be True by default
176
+ :param default_epoch_length: default epoch length
177
+ :param save_confidence_setting: default setting for
178
+ saving confidence scores
179
+ :param emg_filter: EMG filter parameters
180
+ :param min_bout_length: default minimum bout length
181
+ :param overwrite_setting: default setting for overwriting
182
+ existing labels
183
+ :param hyperparameters: model training hyperparameters
137
184
  """
138
185
  output_dict = brain_state_set.to_output_dict()
139
- output_dict.update({DEFAULT_EPOCH_LENGTH_KEY: default_epoch_length})
140
- output_dict.update({DEFAULT_CONFIDENCE_SETTING_KEY: save_confidence_setting})
186
+ output_dict.update({c.DEFAULT_EPOCH_LENGTH_KEY: default_epoch_length})
187
+ output_dict.update({c.DEFAULT_OVERWRITE_KEY: overwrite_setting})
188
+ output_dict.update({c.DEFAULT_CONFIDENCE_SETTING_KEY: save_confidence_setting})
189
+ output_dict.update({c.DEFAULT_MIN_BOUT_LENGTH_KEY: min_bout_length})
190
+ output_dict.update({c.EMG_FILTER_KEY: emg_filter.__dict__})
191
+ output_dict.update({c.HYPERPARAMETERS_KEY: hyperparameters.__dict__})
141
192
  with open(
142
- os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "w"
193
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), c.CONFIG_FILE), "w"
143
194
  ) as f:
144
195
  json.dump(output_dict, f, indent=4)
145
196
 
@@ -152,7 +203,7 @@ def load_recording_list(filename: str) -> list[Recording]:
152
203
  """
153
204
  with open(filename, "r") as f:
154
205
  data = json.load(f)
155
- recording_list = [Recording(**r) for r in data[RECORDING_LIST_NAME]]
206
+ recording_list = [Recording(**r) for r in data[c.RECORDING_LIST_NAME]]
156
207
  for i, r in enumerate(recording_list):
157
208
  r.name = i + 1
158
209
  return recording_list
@@ -165,7 +216,7 @@ def save_recording_list(filename: str, recordings: list[Recording]) -> None:
165
216
  :param recordings: list of recordings to export
166
217
  """
167
218
  recording_dict = {
168
- RECORDING_LIST_NAME: [
219
+ c.RECORDING_LIST_NAME: [
169
220
  {
170
221
  "recording_file": r.recording_file,
171
222
  "label_file": r.label_file,
Binary file
accusleepy/gui/main.py CHANGED
@@ -39,6 +39,13 @@ from accusleepy.constants import (
39
39
  CALIBRATION_ANNOTATION_FILENAME,
40
40
  CALIBRATION_FILE_TYPE,
41
41
  DEFAULT_MODEL_TYPE,
42
+ DEFAULT_EMG_FILTER_ORDER,
43
+ DEFAULT_EMG_BP_LOWER,
44
+ DEFAULT_EMG_BP_UPPER,
45
+ DEFAULT_BATCH_SIZE,
46
+ DEFAULT_LEARNING_RATE,
47
+ DEFAULT_MOMENTUM,
48
+ DEFAULT_TRAINING_EPOCHS,
42
49
  LABEL_FILE_TYPE,
43
50
  MODEL_FILE_TYPE,
44
51
  REAL_TIME_MODEL_TYPE,
@@ -56,6 +63,8 @@ from accusleepy.fileio import (
56
63
  save_config,
57
64
  save_labels,
58
65
  save_recording_list,
66
+ EMGFilter,
67
+ Hyperparameters,
59
68
  )
60
69
  from accusleepy.gui.manual_scoring import ManualScoringWindow
61
70
  from accusleepy.gui.primary_window import Ui_PrimaryWindow
@@ -97,19 +106,25 @@ class AccuSleepWindow(QMainWindow):
97
106
  self.setWindowTitle("AccuSleePy")
98
107
 
99
108
  # fill in settings tab
100
- self.brain_state_set, self.epoch_length, self.save_confidence_setting = (
101
- load_config()
102
- )
109
+ (
110
+ self.brain_state_set,
111
+ self.epoch_length,
112
+ self.only_overwrite_undefined,
113
+ self.save_confidence_scores,
114
+ self.min_bout_length,
115
+ self.emg_filter,
116
+ self.hyperparameters,
117
+ ) = load_config()
118
+
103
119
  self.settings_widgets = None
104
120
  self.initialize_settings_tab()
105
121
 
106
122
  # initialize info about the recordings, classification data / settings
107
123
  self.ui.epoch_length_input.setValue(self.epoch_length)
108
- self.ui.save_confidence_checkbox.setChecked(self.save_confidence_setting)
124
+ self.ui.overwritecheckbox.setChecked(self.only_overwrite_undefined)
125
+ self.ui.save_confidence_checkbox.setChecked(self.save_confidence_scores)
126
+ self.ui.bout_length_input.setValue(self.min_bout_length)
109
127
  self.model = None
110
- self.only_overwrite_undefined = False
111
- self.save_confidence_scores = self.save_confidence_setting
112
- self.min_bout_length = 5
113
128
 
114
129
  # initialize model training variables
115
130
  self.training_epochs_per_img = 9
@@ -186,6 +201,10 @@ class AccuSleepWindow(QMainWindow):
186
201
  self.ui.export_button.clicked.connect(self.export_recording_list)
187
202
  self.ui.import_button.clicked.connect(self.import_recording_list)
188
203
  self.ui.default_type_button.toggled.connect(self.model_type_radio_buttons)
204
+ self.ui.reset_emg_params_button.clicked.connect(self.reset_emg_filter_settings)
205
+ self.ui.reset_hyperparams_button.clicked.connect(
206
+ self.reset_hyperparams_settings
207
+ )
189
208
 
190
209
  # user input: drag and drop
191
210
  self.ui.recording_file_label.installEventFilter(self)
@@ -363,6 +382,7 @@ class AccuSleepWindow(QMainWindow):
363
382
  brain_state_set=self.brain_state_set,
364
383
  model_type=self.model_type,
365
384
  calibration_fraction=calibration_fraction,
385
+ emg_filter=self.emg_filter,
366
386
  )
367
387
  if len(failed_recordings) > 0:
368
388
  if len(failed_recordings) == len(self.recordings):
@@ -391,6 +411,7 @@ class AccuSleepWindow(QMainWindow):
391
411
  img_dir=temp_image_dir,
392
412
  mixture_weights=self.brain_state_set.mixture_weights,
393
413
  n_classes=self.brain_state_set.n_classes,
414
+ hyperparameters=self.hyperparameters,
394
415
  )
395
416
 
396
417
  # calibrate the model
@@ -399,7 +420,9 @@ class AccuSleepWindow(QMainWindow):
399
420
  temp_image_dir, CALIBRATION_ANNOTATION_FILENAME
400
421
  )
401
422
  calibration_dataloader = create_dataloader(
402
- annotations_file=calibration_annotation_file, img_dir=temp_image_dir
423
+ annotations_file=calibration_annotation_file,
424
+ img_dir=temp_image_dir,
425
+ hyperparameters=self.hyperparameters,
403
426
  )
404
427
  model = ModelWithTemperature(model)
405
428
  print("Calibrating model")
@@ -584,6 +607,7 @@ class AccuSleepWindow(QMainWindow):
584
607
  epoch_length=self.epoch_length,
585
608
  epochs_per_img=self.model_epochs_per_img,
586
609
  brain_state_set=self.brain_state_set,
610
+ emg_filter=self.emg_filter,
587
611
  )
588
612
 
589
613
  # overwrite as needed
@@ -801,6 +825,7 @@ class AccuSleepWindow(QMainWindow):
801
825
  sampling_rate=sampling_rate,
802
826
  epoch_length=self.epoch_length,
803
827
  brain_state_set=self.brain_state_set,
828
+ emg_filter=self.emg_filter,
804
829
  )
805
830
 
806
831
  self.ui.calibration_status.setText("")
@@ -965,6 +990,7 @@ class AccuSleepWindow(QMainWindow):
965
990
  confidence_scores=confidence_scores,
966
991
  sampling_rate=sampling_rate,
967
992
  epoch_length=self.epoch_length,
993
+ emg_filter=self.emg_filter,
968
994
  )
969
995
  manual_scoring_window.setWindowTitle(f"AccuSleePy viewer: {label_file}")
970
996
  manual_scoring_window.exec()
@@ -1130,15 +1156,6 @@ class AccuSleepWindow(QMainWindow):
1130
1156
 
1131
1157
  def initialize_settings_tab(self):
1132
1158
  """Populate settings tab and assign its callbacks"""
1133
- # show information about the settings tab
1134
- config_guide_file = open(
1135
- os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_GUIDE_FILE),
1136
- "r",
1137
- )
1138
- config_guide_text = config_guide_file.read()
1139
- config_guide_file.close()
1140
- self.ui.settings_text.setText(config_guide_text)
1141
-
1142
1159
  # store dictionary that maps digits to rows of widgets
1143
1160
  # in the settings tab
1144
1161
  self.settings_widgets = {
@@ -1215,8 +1232,21 @@ class AccuSleepWindow(QMainWindow):
1215
1232
  }
1216
1233
 
1217
1234
  # update widget state to display current config
1235
+ # UI defaults
1218
1236
  self.ui.default_epoch_input.setValue(self.epoch_length)
1219
- self.ui.confidence_setting_checkbox.setChecked(self.save_confidence_setting)
1237
+ self.ui.overwrite_default_checkbox.setChecked(self.only_overwrite_undefined)
1238
+ self.ui.confidence_setting_checkbox.setChecked(self.save_confidence_scores)
1239
+ self.ui.default_min_bout_length_spinbox.setValue(self.min_bout_length)
1240
+ # EMG filter
1241
+ self.ui.emg_order_spinbox.setValue(self.emg_filter.order)
1242
+ self.ui.bp_lower_spinbox.setValue(self.emg_filter.bp_lower)
1243
+ self.ui.bp_upper_spinbox.setValue(self.emg_filter.bp_upper)
1244
+ # model training hyperparameters
1245
+ self.ui.batch_size_spinbox.setValue(self.hyperparameters.batch_size)
1246
+ self.ui.learning_rate_spinbox.setValue(self.hyperparameters.learning_rate)
1247
+ self.ui.momentum_spinbox.setValue(self.hyperparameters.momentum)
1248
+ self.ui.training_epochs_spinbox.setValue(self.hyperparameters.training_epochs)
1249
+ # brain states
1220
1250
  states = {b.digit: b for b in self.brain_state_set.brain_states}
1221
1251
  for digit in range(10):
1222
1252
  if digit in states.keys():
@@ -1235,6 +1265,15 @@ class AccuSleepWindow(QMainWindow):
1235
1265
  self.settings_widgets[digit].frequency_widget.setEnabled(False)
1236
1266
 
1237
1267
  # set callbacks
1268
+ self.ui.emg_order_spinbox.valueChanged.connect(self.emg_filter_order_changed)
1269
+ self.ui.bp_lower_spinbox.valueChanged.connect(self.emg_filter_bp_lower_changed)
1270
+ self.ui.bp_upper_spinbox.valueChanged.connect(self.emg_filter_bp_upper_changed)
1271
+ self.ui.batch_size_spinbox.valueChanged.connect(self.hyperparameters_changed)
1272
+ self.ui.learning_rate_spinbox.valueChanged.connect(self.hyperparameters_changed)
1273
+ self.ui.momentum_spinbox.valueChanged.connect(self.hyperparameters_changed)
1274
+ self.ui.training_epochs_spinbox.valueChanged.connect(
1275
+ self.hyperparameters_changed
1276
+ )
1238
1277
  for digit in range(10):
1239
1278
  state = self.settings_widgets[digit]
1240
1279
  state.enabled_widget.stateChanged.connect(
@@ -1297,6 +1336,41 @@ class AccuSleepWindow(QMainWindow):
1297
1336
  # check that configuration is valid
1298
1337
  _ = self.check_config_validity()
1299
1338
 
1339
+ def emg_filter_order_changed(self, new_value: int) -> None:
1340
+ """Called when user modifies EMG filter order
1341
+
1342
+ :param new_value: new EMG filter order
1343
+ """
1344
+ self.emg_filter.order = new_value
1345
+
1346
+ def emg_filter_bp_lower_changed(self, new_value: int | float) -> None:
1347
+ """Called when user modifies EMG filter lower cutoff
1348
+
1349
+ :param new_value: new lower bandpass cutoff frequency
1350
+ """
1351
+ self.emg_filter.bp_lower = new_value
1352
+ _ = self.check_config_validity()
1353
+
1354
+ def emg_filter_bp_upper_changed(self, new_value: int | float) -> None:
1355
+ """Called when user modifies EMG filter upper cutoff
1356
+
1357
+ :param new_value: new upper bandpass cutoff frequency
1358
+ """
1359
+ self.emg_filter.bp_upper = new_value
1360
+ _ = self.check_config_validity()
1361
+
1362
+ def hyperparameters_changed(self, new_value) -> None:
1363
+ """Called when user modifies model training hyperparameters
1364
+
1365
+ :param new_value: unused
1366
+ """
1367
+ self.hyperparameters = Hyperparameters(
1368
+ batch_size=self.ui.batch_size_spinbox.value(),
1369
+ learning_rate=self.ui.learning_rate_spinbox.value(),
1370
+ momentum=self.ui.momentum_spinbox.value(),
1371
+ training_epochs=self.ui.training_epochs_spinbox.value(),
1372
+ )
1373
+
1300
1374
  def check_config_validity(self) -> str:
1301
1375
  """Check if brain state configuration on screen is valid"""
1302
1376
  # error message, if we get one
@@ -1323,6 +1397,10 @@ class AccuSleepWindow(QMainWindow):
1323
1397
  if sum(frequencies) != 1:
1324
1398
  message = "Error: sum(frequencies) != 1"
1325
1399
 
1400
+ # check validity of EMG filter settings
1401
+ if self.emg_filter.bp_lower >= self.emg_filter.bp_upper:
1402
+ message = "Error: EMG filter cutoff frequencies are invalid"
1403
+
1326
1404
  if message is not None:
1327
1405
  self.ui.save_config_status.setText(message)
1328
1406
  self.ui.save_config_button.setEnabled(False)
@@ -1355,12 +1433,36 @@ class AccuSleepWindow(QMainWindow):
1355
1433
 
1356
1434
  # save to file
1357
1435
  save_config(
1358
- self.brain_state_set,
1359
- self.ui.default_epoch_input.value(),
1360
- self.ui.confidence_setting_checkbox.isChecked(),
1436
+ brain_state_set=self.brain_state_set,
1437
+ default_epoch_length=self.ui.default_epoch_input.value(),
1438
+ overwrite_setting=self.ui.overwrite_default_checkbox.isChecked(),
1439
+ save_confidence_setting=self.ui.confidence_setting_checkbox.isChecked(),
1440
+ min_bout_length=self.ui.default_min_bout_length_spinbox.value(),
1441
+ emg_filter=EMGFilter(
1442
+ order=self.emg_filter.order,
1443
+ bp_lower=self.emg_filter.bp_lower,
1444
+ bp_upper=self.emg_filter.bp_upper,
1445
+ ),
1446
+ hyperparameters=Hyperparameters(
1447
+ batch_size=self.hyperparameters.batch_size,
1448
+ learning_rate=self.hyperparameters.learning_rate,
1449
+ momentum=self.hyperparameters.momentum,
1450
+ training_epochs=self.hyperparameters.training_epochs,
1451
+ ),
1361
1452
  )
1362
1453
  self.ui.save_config_status.setText("configuration saved")
1363
1454
 
1455
+ def reset_emg_filter_settings(self) -> None:
1456
+ self.ui.emg_order_spinbox.setValue(DEFAULT_EMG_FILTER_ORDER)
1457
+ self.ui.bp_lower_spinbox.setValue(DEFAULT_EMG_BP_LOWER)
1458
+ self.ui.bp_upper_spinbox.setValue(DEFAULT_EMG_BP_UPPER)
1459
+
1460
+ def reset_hyperparams_settings(self):
1461
+ self.ui.batch_size_spinbox.setValue(DEFAULT_BATCH_SIZE)
1462
+ self.ui.learning_rate_spinbox.setValue(DEFAULT_LEARNING_RATE)
1463
+ self.ui.momentum_spinbox.setValue(DEFAULT_MOMENTUM)
1464
+ self.ui.training_epochs_spinbox.setValue(DEFAULT_TRAINING_EPOCHS)
1465
+
1364
1466
 
1365
1467
  def check_label_validity(
1366
1468
  labels: np.array,