accusleepy 0.1.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
accusleepy/config.json CHANGED
@@ -18,5 +18,6 @@
18
18
  "is_scored": true,
19
19
  "frequency": 0.55
20
20
  }
21
- ]
22
- }
21
+ ],
22
+ "default_epoch_length": 2.5
23
+ }
accusleepy/constants.py CHANGED
@@ -35,3 +35,5 @@ LABEL_COL = "label"
35
35
  # recording list file header:
36
36
  RECORDING_LIST_NAME = "recording_list"
37
37
  RECORDING_LIST_FILE_TYPE = ".json"
38
+ # key for default epoch length in config
39
+ DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
accusleepy/fileio.py CHANGED
@@ -11,6 +11,7 @@ from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateS
11
11
  from accusleepy.constants import (
12
12
  BRAIN_STATE_COL,
13
13
  CONFIG_FILE,
14
+ DEFAULT_EPOCH_LENGTH_KEY,
14
15
  EEG_COL,
15
16
  EMG_COL,
16
17
  MIXTURE_MEAN_COL,
@@ -82,7 +83,9 @@ def load_model(filename: str) -> tuple[SSANN, int | float, int, str, dict]:
82
83
  (default or real-time), set of brain state options
83
84
  used when training the model
84
85
  """
85
- state_dict = torch.load(filename, weights_only=True)
86
+ state_dict = torch.load(
87
+ filename, weights_only=True, map_location=torch.device("cpu")
88
+ )
86
89
  epoch_length = state_dict.pop("epoch_length")
87
90
  epochs_per_img = state_dict.pop("epochs_per_img")
88
91
  model_type = state_dict.pop("model_type")
@@ -141,10 +144,10 @@ def save_labels(labels: np.array, filename: str) -> None:
141
144
  pd.DataFrame({BRAIN_STATE_COL: labels}).to_csv(filename, index=False)
142
145
 
143
146
 
144
- def load_config() -> BrainStateSet:
147
+ def load_config() -> tuple[BrainStateSet, int | float]:
145
148
  """Load configuration file with brain state options
146
149
 
147
- :return: set of brain state options
150
+ :return: set of brain state options and default epoch length
148
151
  """
149
152
  with open(
150
153
  os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "r"
@@ -152,18 +155,23 @@ def load_config() -> BrainStateSet:
152
155
  data = json.load(f)
153
156
  return BrainStateSet(
154
157
  [BrainState(**b) for b in data[BRAIN_STATES_KEY]], UNDEFINED_LABEL
155
- )
158
+ ), data[DEFAULT_EPOCH_LENGTH_KEY]
156
159
 
157
160
 
158
- def save_config(brain_state_set: BrainStateSet) -> None:
161
+ def save_config(
162
+ brain_state_set: BrainStateSet, default_epoch_length: int | float
163
+ ) -> None:
159
164
  """Save configuration of brain state options to json file
160
165
 
161
166
  :param brain_state_set: set of brain state options
167
+ :param default_epoch_length: epoch length to use when the GUI starts
162
168
  """
169
+ output_dict = brain_state_set.to_output_dict()
170
+ output_dict.update({DEFAULT_EPOCH_LENGTH_KEY: default_epoch_length})
163
171
  with open(
164
172
  os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "w"
165
173
  ) as f:
166
- json.dump(brain_state_set.to_output_dict(), f, indent=4)
174
+ json.dump(output_dict, f, indent=4)
167
175
 
168
176
 
169
177
  def load_recording_list(filename: str) -> list[Recording]:
accusleepy/gui/main.py CHANGED
@@ -40,6 +40,7 @@ from accusleepy.fileio import (
40
40
  save_model,
41
41
  save_recording_list,
42
42
  )
43
+ from accusleepy.gui.text.main_guide_text import MAIN_GUIDE_TEXT
43
44
  from accusleepy.gui.manual_scoring import ManualScoringWindow
44
45
  from accusleepy.gui.primary_window import Ui_PrimaryWindow
45
46
  from accusleepy.signal_processing import (
@@ -52,8 +53,7 @@ from accusleepy.signal_processing import (
52
53
  # max number of messages to display
53
54
  MESSAGE_BOX_MAX_DEPTH = 50
54
55
  LABEL_LENGTH_ERROR = "label file length does not match recording length"
55
- # relative path to user manual txt file
56
- USER_MANUAL_FILE = "text/main_guide.txt"
56
+ # relative path to config guide txt file
57
57
  CONFIG_GUIDE_FILE = "text/config_guide.txt"
58
58
 
59
59
 
@@ -80,12 +80,12 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
80
80
  self.setWindowTitle("AccuSleePy")
81
81
 
82
82
  # fill in settings tab
83
- self.brain_state_set = load_config()
83
+ self.brain_state_set, self.epoch_length = load_config()
84
84
  self.settings_widgets = None
85
85
  self.initialize_settings_tab()
86
86
 
87
87
  # initialize info about the recordings, classification data / settings
88
- self.epoch_length = 0
88
+ self.ui.epoch_length_input.setValue(self.epoch_length)
89
89
  self.model = None
90
90
  self.only_overwrite_undefined = False
91
91
  self.min_bout_length = 5
@@ -267,7 +267,9 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
267
267
  )
268
268
  return
269
269
  if self.training_image_dir == "":
270
- self.show_message("ERROR: no folder selected for training images.")
270
+ self.show_message(
271
+ ("ERROR: no output location selected for training images.")
272
+ )
271
273
  return
272
274
 
273
275
  # check some inputs for each recording
@@ -288,23 +290,28 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
288
290
  if not model_filename:
289
291
  self.show_message("Model training canceled, no filename given")
290
292
 
291
- # create image folder
292
- if os.path.exists(self.training_image_dir):
293
+ # create (probably temporary) image folder
294
+ temp_image_dir = os.path.join(
295
+ self.training_image_dir,
296
+ "images_" + datetime.datetime.now().strftime("%Y%m%d%H%M"),
297
+ )
298
+
299
+ if os.path.exists(temp_image_dir): # unlikely
293
300
  self.show_message(
294
301
  "Warning: training image folder exists, will be overwritten"
295
302
  )
296
- os.makedirs(self.training_image_dir, exist_ok=True)
303
+ os.makedirs(temp_image_dir, exist_ok=True)
297
304
 
298
305
  # create training images
299
306
  self.show_message(
300
- (f"Creating training images in {self.training_image_dir}, please wait...")
307
+ (f"Creating training images in {temp_image_dir}, please wait...")
301
308
  )
302
309
  self.ui.message_area.repaint()
303
310
  QtWidgets.QApplication.processEvents()
304
311
  print("Creating training images")
305
312
  failed_recordings = create_training_images(
306
313
  recordings=self.recordings,
307
- output_path=self.training_image_dir,
314
+ output_path=temp_image_dir,
308
315
  epoch_length=self.epoch_length,
309
316
  epochs_per_img=self.training_epochs_per_img,
310
317
  brain_state_set=self.brain_state_set,
@@ -328,10 +335,8 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
328
335
  QtWidgets.QApplication.processEvents()
329
336
  print("Training model")
330
337
  model = train_model(
331
- annotations_file=os.path.join(
332
- self.training_image_dir, ANNOTATIONS_FILENAME
333
- ),
334
- img_dir=self.training_image_dir,
338
+ annotations_file=os.path.join(temp_image_dir, ANNOTATIONS_FILENAME),
339
+ img_dir=temp_image_dir,
335
340
  mixture_weights=self.brain_state_set.mixture_weights,
336
341
  n_classes=self.brain_state_set.n_classes,
337
342
  )
@@ -348,20 +353,18 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
348
353
 
349
354
  # optionally delete images
350
355
  if self.delete_training_images:
351
- shutil.rmtree(self.training_image_dir)
356
+ shutil.rmtree(temp_image_dir)
352
357
 
353
358
  self.show_message(f"Training complete, saved model to {model_filename}")
354
359
 
355
- def set_training_folder(self):
360
+ def set_training_folder(self) -> None:
361
+ """Select location in which to create a folder for training images"""
356
362
  training_folder_parent = QtWidgets.QFileDialog.getExistingDirectory(
357
363
  self, "Select directory for training images"
358
364
  )
359
365
  if training_folder_parent:
360
- self.training_image_dir = os.path.join(
361
- training_folder_parent,
362
- "images_" + datetime.datetime.now().strftime("%Y%m%d%H%M"),
363
- )
364
- self.ui.image_folder_label.setText(self.training_image_dir)
366
+ self.training_image_dir = training_folder_parent
367
+ self.ui.image_folder_label.setText(training_folder_parent)
365
368
 
366
369
  def update_image_deletion(self) -> None:
367
370
  """Update choice of whether to delete images after training"""
@@ -1002,15 +1005,8 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
1002
1005
 
1003
1006
  def show_user_manual(self) -> None:
1004
1007
  """Show a popup window with the user manual"""
1005
- user_manual_file = open(
1006
- os.path.join(os.path.dirname(os.path.abspath(__file__)), USER_MANUAL_FILE),
1007
- "r",
1008
- )
1009
- user_manual_text = user_manual_file.read()
1010
- user_manual_file.close()
1011
-
1012
1008
  label_widget = QtWidgets.QLabel()
1013
- label_widget.setText(user_manual_text)
1009
+ label_widget.setText(MAIN_GUIDE_TEXT)
1014
1010
  scroll_area = QtWidgets.QScrollArea()
1015
1011
  scroll_area.setStyleSheet("background-color: white;")
1016
1012
  scroll_area.setWidget(label_widget)
@@ -1108,6 +1104,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
1108
1104
  }
1109
1105
 
1110
1106
  # update widget state to display current config
1107
+ self.ui.default_epoch_input.setValue(self.epoch_length)
1111
1108
  states = {b.digit: b for b in self.brain_state_set.brain_states}
1112
1109
  for digit in range(10):
1113
1110
  if digit in states.keys():
@@ -1245,7 +1242,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
1245
1242
  self.brain_state_set = BrainStateSet(brain_states, UNDEFINED_LABEL)
1246
1243
 
1247
1244
  # save to file
1248
- save_config(self.brain_state_set)
1245
+ save_config(self.brain_state_set, self.ui.default_epoch_input.value())
1249
1246
  self.ui.save_config_status.setText("configuration saved")
1250
1247
 
1251
1248
 
@@ -111,7 +111,7 @@ class ManualScoringWindow(QtWidgets.QDialog):
111
111
  self.setWindowTitle("AccuSleePy manual scoring window")
112
112
 
113
113
  # load set of valid brain states
114
- self.brain_state_set = load_config()
114
+ self.brain_state_set, _ = load_config()
115
115
 
116
116
  # initial setting for number of epochs to show in the lower plot
117
117
  self.epochs_to_show = 5
@@ -833,7 +833,7 @@ class ManualScoringWindow(QtWidgets.QDialog):
833
833
  self.adjust_upper_figure_x_limits()
834
834
 
835
835
  # update parts of lower plot
836
- old_window_center = round(self.epochs_to_show / 2) + self.lower_left_epoch
836
+ old_window_center = round((self.epochs_to_show - 1) / 2) + self.lower_left_epoch
837
837
  # change the window bounds if needed
838
838
  if self.epoch < old_window_center and self.lower_left_epoch > 0:
839
839
  self.lower_left_epoch -= 1