accusleepy 0.4.5__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
accusleepy/bouts.py ADDED
@@ -0,0 +1,142 @@
1
+ import re
2
+ from dataclasses import dataclass
3
+ from operator import attrgetter
4
+
5
+ import numpy as np
6
+
7
+
8
+ @dataclass
9
+ class Bout:
10
+ """Stores information about a brain state bout"""
11
+
12
+ length: int # length, in number of epochs
13
+ start_index: int # index where bout starts
14
+ end_index: int # index where bout ends
15
+ surrounding_state: int # brain state on both sides of the bout
16
+
17
+
18
+ def find_last_adjacent_bout(sorted_bouts: list[Bout], bout_index: int) -> int:
19
+ """Find index of last consecutive same-length bout
20
+
21
+ When running the post-processing step that enforces a minimum duration
22
+ for brain state bouts, there is a special case when bouts below the
23
+ duration threshold occur consecutively. This function performs a
24
+ recursive search for the index of a bout at the end of such a sequence.
25
+ When initially called, bout_index will always be 0. If, for example, the
26
+ first three bouts in the list are consecutive, the function will return 2.
27
+
28
+ :param sorted_bouts: list of brain state bouts, sorted by start time
29
+ :param bout_index: index of the bout in question
30
+ :return: index of the last consecutive same-length bout
31
+ """
32
+ # if we're at the end of the bout list, stop
33
+ if bout_index == len(sorted_bouts) - 1:
34
+ return bout_index
35
+
36
+ # if there is an adjacent bout
37
+ if sorted_bouts[bout_index].end_index == sorted_bouts[bout_index + 1].start_index:
38
+ # look for more adjacent bouts using that one as a starting point
39
+ return find_last_adjacent_bout(sorted_bouts, bout_index + 1)
40
+ else:
41
+ return bout_index
42
+
43
+
44
+ def enforce_min_bout_length(
45
+ labels: np.array, epoch_length: int | float, min_bout_length: int | float
46
+ ) -> np.array:
47
+ """Ensure brain state bouts meet the min length requirement
48
+
49
+ As a post-processing step for sleep scoring, we can require that any
50
+ bout (continuous period) of a brain state have a minimum duration.
51
+ This function sets any bout shorter than the minimum duration to the
52
+ surrounding brain state (if the states on the left and right sides
53
+ are the same). In the case where there are consecutive short bouts,
54
+ it either creates a transition at the midpoint or removes all short
55
+ bouts, depending on whether the number is even or odd. For example:
56
+ ...AAABABAAA... -> ...AAAAAAAAA...
57
+ ...AAABABABBB... -> ...AAAAABBBBB...
58
+
59
+ :param labels: brain state labels (digits in the 0-9 range)
60
+ :param epoch_length: epoch length, in seconds
61
+ :param min_bout_length: minimum bout length, in seconds
62
+ :return: updated brain state labels
63
+ """
64
+ # if recording is very short, don't change anything
65
+ if labels.size < 3:
66
+ return labels
67
+
68
+ if epoch_length == min_bout_length:
69
+ return labels
70
+
71
+ # get minimum number of epochs in a bout
72
+ min_epochs = int(np.ceil(min_bout_length / epoch_length))
73
+ # get set of states in the labels
74
+ brain_states = set(labels.tolist())
75
+
76
+ while True: # so true
77
+ # convert labels to a string for regex search
78
+ # There is probably a regex that can find all patterns like ab+a
79
+ # without consuming each "a" but I haven't found it :(
80
+ label_string = "".join(labels.astype(str))
81
+
82
+ bouts = list()
83
+
84
+ for state in brain_states:
85
+ for other_state in brain_states:
86
+ if state == other_state:
87
+ continue
88
+ # get start and end indices of each bout
89
+ expression = (
90
+ f"(?<={other_state}){state}{{1,{min_epochs - 1}}}(?={other_state})"
91
+ )
92
+ matches = re.finditer(expression, label_string)
93
+ spans = [match.span() for match in matches]
94
+
95
+ # if some bouts were found
96
+ for span in spans:
97
+ bouts.append(
98
+ Bout(
99
+ length=span[1] - span[0],
100
+ start_index=span[0],
101
+ end_index=span[1],
102
+ surrounding_state=other_state,
103
+ )
104
+ )
105
+
106
+ if len(bouts) == 0:
107
+ break
108
+
109
+ # only keep the shortest bouts
110
+ min_length_in_list = np.min([bout.length for bout in bouts])
111
+ bouts = [i for i in bouts if i.length == min_length_in_list]
112
+ # sort by start index
113
+ sorted_bouts = sorted(bouts, key=attrgetter("start_index"))
114
+
115
+ while len(sorted_bouts) > 0:
116
+ # get row index of latest adjacent bout (of same length)
117
+ last_adjacent_bout_index = find_last_adjacent_bout(sorted_bouts, 0)
118
+ # if there's an even number of adjacent bouts
119
+ if (last_adjacent_bout_index + 1) % 2 == 0:
120
+ midpoint = sorted_bouts[
121
+ round((last_adjacent_bout_index + 1) / 2)
122
+ ].start_index
123
+ labels[sorted_bouts[0].start_index : midpoint] = sorted_bouts[
124
+ 0
125
+ ].surrounding_state
126
+ labels[midpoint : sorted_bouts[last_adjacent_bout_index].end_index] = (
127
+ sorted_bouts[last_adjacent_bout_index].surrounding_state
128
+ )
129
+ else:
130
+ labels[
131
+ sorted_bouts[0].start_index : sorted_bouts[
132
+ last_adjacent_bout_index
133
+ ].end_index
134
+ ] = sorted_bouts[0].surrounding_state
135
+
136
+ # delete the bouts we just fixed
137
+ if last_adjacent_bout_index == len(sorted_bouts) - 1:
138
+ sorted_bouts = []
139
+ else:
140
+ sorted_bouts = sorted_bouts[(last_adjacent_bout_index + 1) :]
141
+
142
+ return labels
@@ -61,13 +61,13 @@ def get_device():
61
61
  )
62
62
 
63
63
 
64
- def train_model(
64
+ def train_ssann(
65
65
  annotations_file: str,
66
66
  img_dir: str,
67
67
  mixture_weights: np.array,
68
68
  n_classes: int,
69
69
  ) -> SSANN:
70
- """Train a classification model for sleep scoring
70
+ """Train a SSANN classification model for sleep scoring
71
71
 
72
72
  :param annotations_file: file with information on each training image
73
73
  :param img_dir: training image location
accusleepy/constants.py CHANGED
@@ -37,3 +37,5 @@ RECORDING_LIST_NAME = "recording_list"
37
37
  RECORDING_LIST_FILE_TYPE = ".json"
38
38
  # key for default epoch length in config
39
39
  DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
40
+ # filename used to store info about training image datasets
41
+ ANNOTATIONS_FILENAME = "annotations.csv"
accusleepy/fileio.py CHANGED
@@ -4,7 +4,6 @@ from dataclasses import dataclass
4
4
 
5
5
  import numpy as np
6
6
  import pandas as pd
7
- import torch
8
7
  from PySide6.QtWidgets import QListWidgetItem
9
8
 
10
9
  from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
@@ -19,7 +18,6 @@ from accusleepy.constants import (
19
18
  RECORDING_LIST_NAME,
20
19
  UNDEFINED_LABEL,
21
20
  )
22
- from accusleepy.models import SSANN
23
21
 
24
22
 
25
23
  @dataclass
@@ -46,57 +44,6 @@ def load_calibration_file(filename: str) -> (np.array, np.array):
46
44
  return mixture_means, mixture_sds
47
45
 
48
46
 
49
- def save_model(
50
- model: SSANN,
51
- filename: str,
52
- epoch_length: int | float,
53
- epochs_per_img: int,
54
- model_type: str,
55
- brain_state_set: BrainStateSet,
56
- ) -> None:
57
- """Save classification model and its metadata
58
-
59
- :param model: classification model
60
- :param epoch_length: epoch length used when training the model
61
- :param epochs_per_img: number of epochs in each model input
62
- :param model_type: default or real-time
63
- :param brain_state_set: set of brain state options
64
- :param filename: filename
65
- """
66
- state_dict = model.state_dict()
67
- state_dict.update({"epoch_length": epoch_length})
68
- state_dict.update({"epochs_per_img": epochs_per_img})
69
- state_dict.update({"model_type": model_type})
70
- state_dict.update(
71
- {BRAIN_STATES_KEY: brain_state_set.to_output_dict()[BRAIN_STATES_KEY]}
72
- )
73
-
74
- torch.save(state_dict, filename)
75
-
76
-
77
- def load_model(filename: str) -> tuple[SSANN, int | float, int, str, dict]:
78
- """Load classification model and its metadata
79
-
80
- :param filename: filename
81
- :return: model, epoch length used when training the model,
82
- number of epochs in each model input, model type
83
- (default or real-time), set of brain state options
84
- used when training the model
85
- """
86
- state_dict = torch.load(
87
- filename, weights_only=True, map_location=torch.device("cpu")
88
- )
89
- epoch_length = state_dict.pop("epoch_length")
90
- epochs_per_img = state_dict.pop("epochs_per_img")
91
- model_type = state_dict.pop("model_type")
92
- brain_states = state_dict.pop(BRAIN_STATES_KEY)
93
- n_classes = len([b for b in brain_states if b["is_scored"]])
94
-
95
- model = SSANN(n_classes=n_classes)
96
- model.load_state_dict(state_dict)
97
- return model, epoch_length, epochs_per_img, model_type, brain_states
98
-
99
-
100
47
  def load_csv_or_parquet(filename: str) -> pd.DataFrame:
101
48
  """Load a csv or parquet file as a dataframe
102
49
 
Binary file
accusleepy/gui/main.py CHANGED
@@ -5,20 +5,37 @@ import datetime
5
5
  import os
6
6
  import shutil
7
7
  import sys
8
- import toml
9
8
  from dataclasses import dataclass
10
9
  from functools import partial
11
10
 
12
11
  import numpy as np
13
- from PySide6 import QtCore, QtGui, QtWidgets
12
+ import toml
13
+ from PySide6.QtCore import (
14
+ QEvent,
15
+ QKeyCombination,
16
+ QObject,
17
+ QRect,
18
+ Qt,
19
+ QUrl,
20
+ )
21
+ from PySide6.QtGui import QKeySequence, QShortcut
22
+ from PySide6.QtWidgets import (
23
+ QApplication,
24
+ QCheckBox,
25
+ QDoubleSpinBox,
26
+ QFileDialog,
27
+ QLabel,
28
+ QListWidgetItem,
29
+ QMainWindow,
30
+ QTextBrowser,
31
+ QVBoxLayout,
32
+ QWidget,
33
+ )
14
34
 
35
+ from accusleepy.bouts import enforce_min_bout_length
15
36
  from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
16
- from accusleepy.classification import (
17
- create_calibration_file,
18
- score_recording,
19
- train_model,
20
- )
21
37
  from accusleepy.constants import (
38
+ ANNOTATIONS_FILENAME,
22
39
  CALIBRATION_FILE_TYPE,
23
40
  DEFAULT_MODEL_TYPE,
24
41
  LABEL_FILE_TYPE,
@@ -33,23 +50,21 @@ from accusleepy.fileio import (
33
50
  load_calibration_file,
34
51
  load_config,
35
52
  load_labels,
36
- load_model,
37
53
  load_recording,
38
54
  load_recording_list,
39
55
  save_config,
40
56
  save_labels,
41
- save_model,
42
57
  save_recording_list,
43
58
  )
44
59
  from accusleepy.gui.manual_scoring import ManualScoringWindow
45
60
  from accusleepy.gui.primary_window import Ui_PrimaryWindow
46
61
  from accusleepy.signal_processing import (
47
- ANNOTATIONS_FILENAME,
48
62
  create_training_images,
49
- enforce_min_bout_length,
50
63
  resample_and_standardize,
51
64
  )
52
65
 
66
+ # note: functions using torch or scipy are lazily imported
67
+
53
68
  # max number of messages to display
54
69
  MESSAGE_BOX_MAX_DEPTH = 200
55
70
  LABEL_LENGTH_ERROR = "label file length does not match recording length"
@@ -63,13 +78,13 @@ class StateSettings:
63
78
  """Widgets for config settings for a brain state"""
64
79
 
65
80
  digit: int
66
- enabled_widget: QtWidgets.QCheckBox
67
- name_widget: QtWidgets.QLabel
68
- is_scored_widget: QtWidgets.QCheckBox
69
- frequency_widget: QtWidgets.QDoubleSpinBox
81
+ enabled_widget: QCheckBox
82
+ name_widget: QLabel
83
+ is_scored_widget: QCheckBox
84
+ frequency_widget: QDoubleSpinBox
70
85
 
71
86
 
72
- class AccuSleepWindow(QtWidgets.QMainWindow):
87
+ class AccuSleepWindow(QMainWindow):
73
88
  """AccuSleePy primary window"""
74
89
 
75
90
  def __init__(self):
@@ -103,9 +118,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
103
118
 
104
119
  # set up the list of recordings
105
120
  first_recording = Recording(
106
- widget=QtWidgets.QListWidgetItem(
107
- "Recording 1", self.ui.recording_list_widget
108
- ),
121
+ widget=QListWidgetItem("Recording 1", self.ui.recording_list_widget),
109
122
  )
110
123
  self.ui.recording_list_widget.addItem(first_recording.widget)
111
124
  self.ui.recording_list_widget.setCurrentRow(0)
@@ -132,10 +145,8 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
132
145
  self.ui.version_label.setText(f"v{version}")
133
146
 
134
147
  # user input: keyboard shortcuts
135
- keypress_quit = QtGui.QShortcut(
136
- QtGui.QKeySequence(
137
- QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_W)
138
- ),
148
+ keypress_quit = QShortcut(
149
+ QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_W)),
139
150
  self,
140
151
  )
141
152
  keypress_quit.activated.connect(self.close)
@@ -187,7 +198,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
187
198
  def export_recording_list(self) -> None:
188
199
  """Save current list of recordings to file"""
189
200
  # get the name for the recording list file
190
- filename, _ = QtWidgets.QFileDialog.getSaveFileName(
201
+ filename, _ = QFileDialog.getSaveFileName(
191
202
  self,
192
203
  caption="Save list of recordings as",
193
204
  filter="*" + RECORDING_LIST_FILE_TYPE,
@@ -200,10 +211,10 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
200
211
 
201
212
  def import_recording_list(self):
202
213
  """Load list of recordings from file, overwriting current list"""
203
- file_dialog = QtWidgets.QFileDialog(self)
214
+ file_dialog = QFileDialog(self)
204
215
  file_dialog.setWindowTitle("Select list of recordings")
205
- file_dialog.setFileMode(QtWidgets.QFileDialog.FileMode.ExistingFile)
206
- file_dialog.setViewMode(QtWidgets.QFileDialog.ViewMode.Detail)
216
+ file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
217
+ file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
207
218
  file_dialog.setNameFilter("*" + RECORDING_LIST_FILE_TYPE)
208
219
 
209
220
  if file_dialog.exec():
@@ -219,7 +230,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
219
230
  self.recordings = load_recording_list(filename)
220
231
 
221
232
  for recording in self.recordings:
222
- recording.widget = QtWidgets.QListWidgetItem(
233
+ recording.widget = QListWidgetItem(
223
234
  f"Recording {recording.name}", self.ui.recording_list_widget
224
235
  )
225
236
  self.ui.recording_list_widget.addItem(self.recordings[-1].widget)
@@ -228,7 +239,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
228
239
  self.ui.recording_list_widget.setCurrentRow(0)
229
240
  self.show_message(f"Loaded list of recordings from {filename}")
230
241
 
231
- def eventFilter(self, obj: QtCore.QObject, event: QtCore.QEvent) -> bool:
242
+ def eventFilter(self, obj: QObject, event: QEvent) -> bool:
232
243
  """Filter mouse events to detect when user drags/drops a file
233
244
 
234
245
  :param obj: UI object receiving the event
@@ -243,7 +254,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
243
254
  self.ui.model_label,
244
255
  ]:
245
256
  event.accept()
246
- if event.type() == QtCore.QEvent.Drop:
257
+ if event.type() == QEvent.Drop:
247
258
  urls = event.mimeData().urls()
248
259
  if len(urls) == 1:
249
260
  filename = os.path.normpath(urls[0].toLocalFile())
@@ -299,7 +310,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
299
310
  return
300
311
 
301
312
  # get filename for the new model
302
- model_filename, _ = QtWidgets.QFileDialog.getSaveFileName(
313
+ model_filename, _ = QFileDialog.getSaveFileName(
303
314
  self,
304
315
  caption="Save classification model file as",
305
316
  filter="*" + MODEL_FILE_TYPE,
@@ -322,11 +333,10 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
322
333
  os.makedirs(temp_image_dir, exist_ok=True)
323
334
 
324
335
  # create training images
325
- self.show_message(
326
- (f"Creating training images in {temp_image_dir}, please wait...")
327
- )
336
+ self.show_message("Training, please wait. See console for progress updates.")
337
+ self.show_message((f"Creating training images in {temp_image_dir}"))
328
338
  self.ui.message_area.repaint()
329
- QtWidgets.QApplication.processEvents()
339
+ QApplication.processEvents()
330
340
  print("Creating training images")
331
341
  failed_recordings = create_training_images(
332
342
  recordings=self.recordings,
@@ -349,11 +359,14 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
349
359
  )
350
360
 
351
361
  # train model
352
- self.show_message("Training model, please wait...")
362
+ self.show_message("Training model")
353
363
  self.ui.message_area.repaint()
354
- QtWidgets.QApplication.processEvents()
364
+ QApplication.processEvents()
355
365
  print("Training model")
356
- model = train_model(
366
+ from accusleepy.classification import train_ssann
367
+ from accusleepy.models import save_model
368
+
369
+ model = train_ssann(
357
370
  annotations_file=os.path.join(temp_image_dir, ANNOTATIONS_FILENAME),
358
371
  img_dir=temp_image_dir,
359
372
  mixture_weights=self.brain_state_set.mixture_weights,
@@ -374,11 +387,12 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
374
387
  if self.delete_training_images:
375
388
  shutil.rmtree(temp_image_dir)
376
389
 
377
- self.show_message(f"Training complete, saved model to {model_filename}")
390
+ self.show_message(f"Training complete. Saved model to {model_filename}")
391
+ print("Training complete.")
378
392
 
379
393
  def set_training_folder(self) -> None:
380
394
  """Select location in which to create a folder for training images"""
381
- training_folder_parent = QtWidgets.QFileDialog.getExistingDirectory(
395
+ training_folder_parent = QFileDialog.getExistingDirectory(
382
396
  self, "Select directory for training images"
383
397
  )
384
398
  if training_folder_parent:
@@ -421,7 +435,9 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
421
435
 
422
436
  self.ui.score_all_status.setText("running...")
423
437
  self.ui.score_all_status.repaint()
424
- QtWidgets.QApplication.processEvents()
438
+ QApplication.processEvents()
439
+
440
+ from accusleepy.classification import score_recording
425
441
 
426
442
  # check some inputs for each recording
427
443
  for recording_index in range(len(self.recordings)):
@@ -570,11 +586,13 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
570
586
 
571
587
  :param filename: model filename, if it's known
572
588
  """
589
+ from accusleepy.models import load_model
590
+
573
591
  if filename is None:
574
- file_dialog = QtWidgets.QFileDialog(self)
592
+ file_dialog = QFileDialog(self)
575
593
  file_dialog.setWindowTitle("Select classification model")
576
- file_dialog.setFileMode(QtWidgets.QFileDialog.FileMode.ExistingFile)
577
- file_dialog.setViewMode(QtWidgets.QFileDialog.ViewMode.Detail)
594
+ file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
595
+ file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
578
596
  file_dialog.setNameFilter("*" + MODEL_FILE_TYPE)
579
597
 
580
598
  if file_dialog.exec():
@@ -634,7 +652,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
634
652
  self.ui.model_label.setText(filename)
635
653
 
636
654
  def load_single_recording(
637
- self, status_widget: QtWidgets.QLabel
655
+ self, status_widget: QLabel
638
656
  ) -> (np.array, np.array, int | float, bool):
639
657
  """Load and preprocess one recording
640
658
 
@@ -721,7 +739,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
721
739
  return
722
740
 
723
741
  # get the name for the calibration file
724
- filename, _ = QtWidgets.QFileDialog.getSaveFileName(
742
+ filename, _ = QFileDialog.getSaveFileName(
725
743
  self,
726
744
  caption="Save calibration file as",
727
745
  filter="*" + CALIBRATION_FILE_TYPE,
@@ -730,6 +748,8 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
730
748
  return
731
749
  filename = os.path.normpath(filename)
732
750
 
751
+ from accusleepy.classification import create_calibration_file
752
+
733
753
  create_calibration_file(
734
754
  filename=filename,
735
755
  eeg=eeg,
@@ -799,7 +819,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
799
819
  # immediately display a status message
800
820
  self.ui.manual_scoring_status.setText("loading...")
801
821
  self.ui.manual_scoring_status.repaint()
802
- QtWidgets.QApplication.processEvents()
822
+ QApplication.processEvents()
803
823
 
804
824
  # load the recording
805
825
  eeg, emg, sampling_rate, success = self.load_single_recording(
@@ -889,7 +909,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
889
909
 
890
910
  def create_label_file(self) -> None:
891
911
  """Set the filename for a new label file"""
892
- filename, _ = QtWidgets.QFileDialog.getSaveFileName(
912
+ filename, _ = QFileDialog.getSaveFileName(
893
913
  self,
894
914
  caption="Set filename for label file (nothing will be overwritten yet)",
895
915
  filter="*" + LABEL_FILE_TYPE,
@@ -901,10 +921,10 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
901
921
 
902
922
  def select_label_file(self) -> None:
903
923
  """User can select an existing label file"""
904
- file_dialog = QtWidgets.QFileDialog(self)
924
+ file_dialog = QFileDialog(self)
905
925
  file_dialog.setWindowTitle("Select label file")
906
- file_dialog.setFileMode(QtWidgets.QFileDialog.FileMode.ExistingFile)
907
- file_dialog.setViewMode(QtWidgets.QFileDialog.ViewMode.Detail)
926
+ file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
927
+ file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
908
928
  file_dialog.setNameFilter("*" + LABEL_FILE_TYPE)
909
929
 
910
930
  if file_dialog.exec():
@@ -916,10 +936,10 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
916
936
 
917
937
  def select_calibration_file(self) -> None:
918
938
  """User can select a calibration file"""
919
- file_dialog = QtWidgets.QFileDialog(self)
939
+ file_dialog = QFileDialog(self)
920
940
  file_dialog.setWindowTitle("Select calibration file")
921
- file_dialog.setFileMode(QtWidgets.QFileDialog.FileMode.ExistingFile)
922
- file_dialog.setViewMode(QtWidgets.QFileDialog.ViewMode.Detail)
941
+ file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
942
+ file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
923
943
  file_dialog.setNameFilter("*" + CALIBRATION_FILE_TYPE)
924
944
 
925
945
  if file_dialog.exec():
@@ -931,10 +951,10 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
931
951
 
932
952
  def select_recording_file(self) -> None:
933
953
  """User can select a recording file"""
934
- file_dialog = QtWidgets.QFileDialog(self)
954
+ file_dialog = QFileDialog(self)
935
955
  file_dialog.setWindowTitle("Select recording file")
936
- file_dialog.setFileMode(QtWidgets.QFileDialog.FileMode.ExistingFile)
937
- file_dialog.setViewMode(QtWidgets.QFileDialog.ViewMode.Detail)
956
+ file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
957
+ file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
938
958
  file_dialog.setNameFilter(f"(*{' *'.join(RECORDING_FILE_TYPES)})")
939
959
 
940
960
  if file_dialog.exec():
@@ -1009,7 +1029,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
1009
1029
  Recording(
1010
1030
  name=new_name,
1011
1031
  sampling_rate=self.recordings[self.recording_index].sampling_rate,
1012
- widget=QtWidgets.QListWidgetItem(
1032
+ widget=QListWidgetItem(
1013
1033
  f"Recording {new_name}", self.ui.recording_list_widget
1014
1034
  ),
1015
1035
  )
@@ -1033,16 +1053,16 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
1033
1053
 
1034
1054
  def show_user_manual(self) -> None:
1035
1055
  """Show a popup window with the user manual"""
1036
- self.popup = QtWidgets.QWidget()
1037
- self.popup_vlayout = QtWidgets.QVBoxLayout(self.popup)
1038
- self.guide_textbox = QtWidgets.QTextBrowser(self.popup)
1056
+ self.popup = QWidget()
1057
+ self.popup_vlayout = QVBoxLayout(self.popup)
1058
+ self.guide_textbox = QTextBrowser(self.popup)
1039
1059
  self.popup_vlayout.addWidget(self.guide_textbox)
1040
1060
 
1041
- url = QtCore.QUrl.fromLocalFile(MAIN_GUIDE_FILE)
1061
+ url = QUrl.fromLocalFile(MAIN_GUIDE_FILE)
1042
1062
  self.guide_textbox.setSource(url)
1043
1063
  self.guide_textbox.setOpenLinks(False)
1044
1064
 
1045
- self.popup.setGeometry(QtCore.QRect(100, 100, 600, 600))
1065
+ self.popup.setGeometry(QRect(100, 100, 600, 600))
1046
1066
  self.popup.show()
1047
1067
 
1048
1068
  def initialize_settings_tab(self):
@@ -1389,7 +1409,7 @@ def check_config_consistency(
1389
1409
 
1390
1410
 
1391
1411
  def run_primary_window() -> None:
1392
- app = QtWidgets.QApplication(sys.argv)
1412
+ app = QApplication(sys.argv)
1393
1413
  AccuSleepWindow()
1394
1414
  sys.exit(app.exec())
1395
1415
 
@@ -12,7 +12,24 @@ from types import SimpleNamespace
12
12
 
13
13
  import matplotlib.pyplot as plt
14
14
  import numpy as np
15
- from PySide6 import QtCore, QtGui, QtWidgets
15
+ from PySide6.QtCore import (
16
+ QKeyCombination,
17
+ QRect,
18
+ Qt,
19
+ QUrl,
20
+ )
21
+ from PySide6.QtGui import (
22
+ QCloseEvent,
23
+ QKeySequence,
24
+ QShortcut,
25
+ )
26
+ from PySide6.QtWidgets import (
27
+ QDialog,
28
+ QMessageBox,
29
+ QTextBrowser,
30
+ QVBoxLayout,
31
+ QWidget,
32
+ )
16
33
 
17
34
  from accusleepy.constants import UNDEFINED_LABEL
18
35
  from accusleepy.fileio import load_config, save_labels
@@ -73,7 +90,7 @@ class StateChange:
73
90
  epoch: int # first epoch affected
74
91
 
75
92
 
76
- class ManualScoringWindow(QtWidgets.QDialog):
93
+ class ManualScoringWindow(QDialog):
77
94
  """AccuSleePy manual scoring GUI"""
78
95
 
79
96
  def __init__(
@@ -191,33 +208,25 @@ class ManualScoringWindow(QtWidgets.QDialog):
191
208
  self.update_lower_figure()
192
209
 
193
210
  # user input: keyboard shortcuts
194
- keypress_right = QtGui.QShortcut(
195
- QtGui.QKeySequence(QtCore.Qt.Key.Key_Right), self
196
- )
211
+ keypress_right = QShortcut(QKeySequence(Qt.Key.Key_Right), self)
197
212
  keypress_right.activated.connect(partial(self.shift_epoch, DIRECTION_RIGHT))
198
213
 
199
- keypress_left = QtGui.QShortcut(
200
- QtGui.QKeySequence(QtCore.Qt.Key.Key_Left), self
201
- )
214
+ keypress_left = QShortcut(QKeySequence(Qt.Key.Key_Left), self)
202
215
  keypress_left.activated.connect(partial(self.shift_epoch, DIRECTION_LEFT))
203
216
 
204
217
  keypress_zoom_in_x = list()
205
- for zoom_key in [QtCore.Qt.Key.Key_Plus, QtCore.Qt.Key.Key_Equal]:
206
- keypress_zoom_in_x.append(
207
- QtGui.QShortcut(QtGui.QKeySequence(zoom_key), self)
208
- )
218
+ for zoom_key in [Qt.Key.Key_Plus, Qt.Key.Key_Equal]:
219
+ keypress_zoom_in_x.append(QShortcut(QKeySequence(zoom_key), self))
209
220
  keypress_zoom_in_x[-1].activated.connect(partial(self.zoom_x, ZOOM_IN))
210
221
 
211
- keypress_zoom_out_x = QtGui.QShortcut(
212
- QtGui.QKeySequence(QtCore.Qt.Key.Key_Minus), self
213
- )
222
+ keypress_zoom_out_x = QShortcut(QKeySequence(Qt.Key.Key_Minus), self)
214
223
  keypress_zoom_out_x.activated.connect(partial(self.zoom_x, ZOOM_OUT))
215
224
 
216
225
  keypress_modify_label = list()
217
226
  for brain_state in self.brain_state_set.brain_states:
218
227
  keypress_modify_label.append(
219
- QtGui.QShortcut(
220
- QtGui.QKeySequence(QtCore.Qt.Key[f"Key_{brain_state.digit}"]),
228
+ QShortcut(
229
+ QKeySequence(Qt.Key[f"Key_{brain_state.digit}"]),
221
230
  self,
222
231
  )
223
232
  )
@@ -225,25 +234,19 @@ class ManualScoringWindow(QtWidgets.QDialog):
225
234
  partial(self.modify_current_epoch_label, brain_state.digit)
226
235
  )
227
236
 
228
- keypress_delete_label = QtGui.QShortcut(
229
- QtGui.QKeySequence(QtCore.Qt.Key.Key_Backspace), self
230
- )
237
+ keypress_delete_label = QShortcut(QKeySequence(Qt.Key.Key_Backspace), self)
231
238
  keypress_delete_label.activated.connect(
232
239
  partial(self.modify_current_epoch_label, UNDEFINED_LABEL)
233
240
  )
234
241
 
235
- keypress_quit = QtGui.QShortcut(
236
- QtGui.QKeySequence(
237
- QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_W)
238
- ),
242
+ keypress_quit = QShortcut(
243
+ QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_W)),
239
244
  self,
240
245
  )
241
246
  keypress_quit.activated.connect(self.close)
242
247
 
243
- keypress_save = QtGui.QShortcut(
244
- QtGui.QKeySequence(
245
- QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_S)
246
- ),
248
+ keypress_save = QShortcut(
249
+ QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_S)),
247
250
  self,
248
251
  )
249
252
  keypress_save.activated.connect(self.save)
@@ -251,11 +254,11 @@ class ManualScoringWindow(QtWidgets.QDialog):
251
254
  keypress_roi = list()
252
255
  for brain_state in self.brain_state_set.brain_states:
253
256
  keypress_roi.append(
254
- QtGui.QShortcut(
255
- QtGui.QKeySequence(
256
- QtCore.QKeyCombination(
257
- QtCore.Qt.Modifier.SHIFT,
258
- QtCore.Qt.Key[f"Key_{brain_state.digit}"],
257
+ QShortcut(
258
+ QKeySequence(
259
+ QKeyCombination(
260
+ Qt.Modifier.SHIFT,
261
+ Qt.Key[f"Key_{brain_state.digit}"],
259
262
  )
260
263
  ),
261
264
  self,
@@ -265,11 +268,11 @@ class ManualScoringWindow(QtWidgets.QDialog):
265
268
  partial(self.enter_label_roi_mode, brain_state.digit)
266
269
  )
267
270
  keypress_roi.append(
268
- QtGui.QShortcut(
269
- QtGui.QKeySequence(
270
- QtCore.QKeyCombination(
271
- QtCore.Qt.Modifier.SHIFT,
272
- QtCore.Qt.Key.Key_Backspace,
271
+ QShortcut(
272
+ QKeySequence(
273
+ QKeyCombination(
274
+ Qt.Modifier.SHIFT,
275
+ Qt.Key.Key_Backspace,
273
276
  )
274
277
  ),
275
278
  self,
@@ -279,22 +282,18 @@ class ManualScoringWindow(QtWidgets.QDialog):
279
282
  partial(self.enter_label_roi_mode, UNDEFINED_LABEL)
280
283
  )
281
284
 
282
- keypress_esc = QtGui.QShortcut(
283
- QtGui.QKeySequence(QtCore.Qt.Key.Key_Escape), self
284
- )
285
+ keypress_esc = QShortcut(QKeySequence(Qt.Key.Key_Escape), self)
285
286
  keypress_esc.activated.connect(self.exit_label_roi_mode)
286
287
 
287
- keypress_space = QtGui.QShortcut(
288
- QtGui.QKeySequence(QtCore.Qt.Key.Key_Space), self
289
- )
288
+ keypress_space = QShortcut(QKeySequence(Qt.Key.Key_Space), self)
290
289
  keypress_space.activated.connect(
291
290
  partial(self.jump_to_next_state, DIRECTION_RIGHT, DIFFERENT_STATE)
292
291
  )
293
- keypress_shift_right = QtGui.QShortcut(
294
- QtGui.QKeySequence(
295
- QtCore.QKeyCombination(
296
- QtCore.Qt.Modifier.SHIFT,
297
- QtCore.Qt.Key.Key_Right,
292
+ keypress_shift_right = QShortcut(
293
+ QKeySequence(
294
+ QKeyCombination(
295
+ Qt.Modifier.SHIFT,
296
+ Qt.Key.Key_Right,
298
297
  )
299
298
  ),
300
299
  self,
@@ -302,11 +301,11 @@ class ManualScoringWindow(QtWidgets.QDialog):
302
301
  keypress_shift_right.activated.connect(
303
302
  partial(self.jump_to_next_state, DIRECTION_RIGHT, DIFFERENT_STATE)
304
303
  )
305
- keypress_shift_left = QtGui.QShortcut(
306
- QtGui.QKeySequence(
307
- QtCore.QKeyCombination(
308
- QtCore.Qt.Modifier.SHIFT,
309
- QtCore.Qt.Key.Key_Left,
304
+ keypress_shift_left = QShortcut(
305
+ QKeySequence(
306
+ QKeyCombination(
307
+ Qt.Modifier.SHIFT,
308
+ Qt.Key.Key_Left,
310
309
  )
311
310
  ),
312
311
  self,
@@ -314,11 +313,11 @@ class ManualScoringWindow(QtWidgets.QDialog):
314
313
  keypress_shift_left.activated.connect(
315
314
  partial(self.jump_to_next_state, DIRECTION_LEFT, DIFFERENT_STATE)
316
315
  )
317
- keypress_ctrl_right = QtGui.QShortcut(
318
- QtGui.QKeySequence(
319
- QtCore.QKeyCombination(
320
- QtCore.Qt.Modifier.CTRL,
321
- QtCore.Qt.Key.Key_Right,
316
+ keypress_ctrl_right = QShortcut(
317
+ QKeySequence(
318
+ QKeyCombination(
319
+ Qt.Modifier.CTRL,
320
+ Qt.Key.Key_Right,
322
321
  )
323
322
  ),
324
323
  self,
@@ -326,11 +325,11 @@ class ManualScoringWindow(QtWidgets.QDialog):
326
325
  keypress_ctrl_right.activated.connect(
327
326
  partial(self.jump_to_next_state, DIRECTION_RIGHT, UNDEFINED_STATE)
328
327
  )
329
- keypress_ctrl_left = QtGui.QShortcut(
330
- QtGui.QKeySequence(
331
- QtCore.QKeyCombination(
332
- QtCore.Qt.Modifier.CTRL,
333
- QtCore.Qt.Key.Key_Left,
328
+ keypress_ctrl_left = QShortcut(
329
+ QKeySequence(
330
+ QKeyCombination(
331
+ Qt.Modifier.CTRL,
332
+ Qt.Key.Key_Left,
334
333
  )
335
334
  ),
336
335
  self,
@@ -339,17 +338,13 @@ class ManualScoringWindow(QtWidgets.QDialog):
339
338
  partial(self.jump_to_next_state, DIRECTION_LEFT, UNDEFINED_STATE)
340
339
  )
341
340
 
342
- keypress_undo = QtGui.QShortcut(
343
- QtGui.QKeySequence(
344
- QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_Z)
345
- ),
341
+ keypress_undo = QShortcut(
342
+ QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_Z)),
346
343
  self,
347
344
  )
348
345
  keypress_undo.activated.connect(self.undo)
349
- keypress_redo = QtGui.QShortcut(
350
- QtGui.QKeySequence(
351
- QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_Y)
352
- ),
346
+ keypress_redo = QShortcut(
347
+ QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_Y)),
353
348
  self,
354
349
  )
355
350
  keypress_redo.activated.connect(self.redo)
@@ -483,34 +478,34 @@ class ManualScoringWindow(QtWidgets.QDialog):
483
478
  )
484
479
  self.click_to_jump(simulated_click)
485
480
 
486
- def closeEvent(self, event: QtGui.QCloseEvent) -> None:
481
+ def closeEvent(self, event: QCloseEvent) -> None:
487
482
  """Check if there are unsaved changes before closing"""
488
483
  if not all(self.labels == self.last_saved_labels):
489
- result = QtWidgets.QMessageBox.question(
484
+ result = QMessageBox.question(
490
485
  self,
491
486
  "Unsaved changes",
492
487
  "You have unsaved changes. Really quit?",
493
- QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No,
488
+ QMessageBox.Yes | QMessageBox.No,
494
489
  )
495
- if result == QtWidgets.QMessageBox.Yes:
490
+ if result == QMessageBox.Yes:
496
491
  event.accept()
497
492
  else:
498
493
  event.ignore()
499
494
 
500
495
  def show_user_manual(self) -> None:
501
496
  """Show a popup window with the user manual"""
502
- self.popup = QtWidgets.QWidget()
503
- self.popup_vlayout = QtWidgets.QVBoxLayout(self.popup)
504
- self.guide_textbox = QtWidgets.QTextBrowser(self.popup)
497
+ self.popup = QWidget()
498
+ self.popup_vlayout = QVBoxLayout(self.popup)
499
+ self.guide_textbox = QTextBrowser(self.popup)
505
500
  self.popup_vlayout.addWidget(self.guide_textbox)
506
501
 
507
- url = QtCore.QUrl.fromLocalFile(
502
+ url = QUrl.fromLocalFile(
508
503
  os.path.join(os.path.dirname(os.path.abspath(__file__)), USER_MANUAL_FILE)
509
504
  )
510
505
  self.guide_textbox.setSource(url)
511
506
  self.guide_textbox.setOpenLinks(False)
512
507
 
513
- self.popup.setGeometry(QtCore.QRect(100, 100, 830, 600))
508
+ self.popup.setGeometry(QRect(100, 100, 830, 600))
514
509
  self.popup.show()
515
510
 
516
511
  def jump_to_next_state(self, direction: str, target: str) -> None:
@@ -339,18 +339,23 @@ def resample_x_ticks(x_ticks: np.array) -> np.array:
339
339
  """Choose a subset of x_ticks to display
340
340
 
341
341
  The x-axis can get crowded if there are too many timestamps shown.
342
- This function resamples the x-axis ticks by a factor of either
343
- MAX_LOWER_X_TICK_N or MAX_LOWER_X_TICK_N - 2, whichever is closer
344
- to being a factor of the number of ticks.
342
+ This function finds a subset of evenly spaced x-axis ticks that
343
+ includes the one at the beginning of the central epoch.
345
344
 
346
345
  :param x_ticks: full set of x_ticks
347
346
  :return: smaller subset of x_ticks
348
347
  """
349
- # add one since the tick at the rightmost edge isn't shown
350
- n_ticks = len(x_ticks) + 1
351
- if n_ticks < MAX_LOWER_X_TICK_N:
348
+ if len(x_ticks) <= MAX_LOWER_X_TICK_N:
352
349
  return x_ticks
353
- elif n_ticks % MAX_LOWER_X_TICK_N < n_ticks % (MAX_LOWER_X_TICK_N - 2):
354
- return x_ticks[:: int(n_ticks / MAX_LOWER_X_TICK_N)]
355
- else:
356
- return x_ticks[:: int(n_ticks / (MAX_LOWER_X_TICK_N - 2))]
350
+
351
+ # number of ticks to the left of the central epoch
352
+ # this will always be an integer
353
+ nl = round((len(x_ticks) - 1) / 2)
354
+
355
+ # search for even tick spacings that include the central epoch
356
+ # if necessary, skip the leftmost tick
357
+ for offset in [0, 1]:
358
+ if (nl - offset) % 3 == 0:
359
+ return x_ticks[offset :: round((nl - offset) / 3)]
360
+ elif (nl - offset) % 2 == 0:
361
+ return x_ticks[offset :: round((nl - offset) / 2)]
@@ -30,6 +30,7 @@ from PySide6.QtWidgets import (
30
30
  QVBoxLayout,
31
31
  QWidget,
32
32
  )
33
+
33
34
  import accusleepy.gui.resources_rc # noqa F401
34
35
 
35
36
 
accusleepy/models.py CHANGED
@@ -1,8 +1,9 @@
1
1
  import numpy as np
2
- import torch
3
- import torch.nn.functional as F
4
- from torch import nn
2
+ from torch import device, flatten, nn
3
+ from torch import load as torch_load
4
+ from torch import save as torch_save
5
5
 
6
+ from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainStateSet
6
7
  from accusleepy.constants import (
7
8
  DOWNSAMPLING_START_FREQ,
8
9
  EMG_COPIES,
@@ -41,8 +42,57 @@ class SSANN(nn.Module):
41
42
 
42
43
  def forward(self, x):
43
44
  x = x.float()
44
- x = self.pool(F.relu(self.conv1_bn(self.conv1(x))))
45
- x = self.pool(F.relu(self.conv2_bn(self.conv2(x))))
46
- x = self.pool(F.relu(self.conv3_bn(self.conv3(x))))
47
- x = torch.flatten(x, 1) # flatten all dimensions except batch
45
+ x = self.pool(nn.functional.relu(self.conv1_bn(self.conv1(x))))
46
+ x = self.pool(nn.functional.relu(self.conv2_bn(self.conv2(x))))
47
+ x = self.pool(nn.functional.relu(self.conv3_bn(self.conv3(x))))
48
+ x = flatten(x, 1) # flatten all dimensions except batch
48
49
  return self.fc1(x)
50
+
51
+
52
+ def save_model(
53
+ model: SSANN,
54
+ filename: str,
55
+ epoch_length: int | float,
56
+ epochs_per_img: int,
57
+ model_type: str,
58
+ brain_state_set: BrainStateSet,
59
+ ) -> None:
60
+ """Save classification model and its metadata
61
+
62
+ :param model: classification model
63
+ :param epoch_length: epoch length used when training the model
64
+ :param epochs_per_img: number of epochs in each model input
65
+ :param model_type: default or real-time
66
+ :param brain_state_set: set of brain state options
67
+ :param filename: filename
68
+ """
69
+ state_dict = model.state_dict()
70
+ state_dict.update({"epoch_length": epoch_length})
71
+ state_dict.update({"epochs_per_img": epochs_per_img})
72
+ state_dict.update({"model_type": model_type})
73
+ state_dict.update(
74
+ {BRAIN_STATES_KEY: brain_state_set.to_output_dict()[BRAIN_STATES_KEY]}
75
+ )
76
+
77
+ torch_save(state_dict, filename)
78
+
79
+
80
+ def load_model(filename: str) -> tuple[SSANN, int | float, int, str, dict]:
81
+ """Load classification model and its metadata
82
+
83
+ :param filename: filename
84
+ :return: model, epoch length used when training the model,
85
+ number of epochs in each model input, model type
86
+ (default or real-time), set of brain state options
87
+ used when training the model
88
+ """
89
+ state_dict = torch_load(filename, weights_only=True, map_location=device("cpu"))
90
+ epoch_length = state_dict.pop("epoch_length")
91
+ epochs_per_img = state_dict.pop("epochs_per_img")
92
+ model_type = state_dict.pop("model_type")
93
+ brain_states = state_dict.pop(BRAIN_STATES_KEY)
94
+ n_classes = len([b for b in brain_states if b["is_scored"]])
95
+
96
+ model = SSANN(n_classes=n_classes)
97
+ model.load_state_dict(state_dict)
98
+ return model, epoch_length, epochs_per_img, model_type, brain_states
accusleepy/multitaper.py CHANGED
@@ -15,8 +15,9 @@ import warnings
15
15
 
16
16
  import numpy as np
17
17
  from joblib import Parallel, cpu_count, delayed
18
- from scipy.signal import detrend
19
- from scipy.signal.windows import dpss
18
+
19
+ # from scipy.signal import detrend # unused by AccuSleePy
20
+ # from scipy.signal.windows import dpss # lazily loaded later
20
21
 
21
22
 
22
23
  # MULTITAPER SPECTROGRAM #
@@ -28,14 +29,14 @@ def spectrogram(
28
29
  num_tapers=None,
29
30
  window_params=None,
30
31
  min_nfft=0,
31
- detrend_opt="linear",
32
+ detrend_opt="off", # this functionality is disabled
32
33
  multiprocess=False,
33
34
  n_jobs=None,
34
35
  weighting="unity",
35
36
  plot_on=False,
36
37
  return_fig=False,
37
38
  clim_scale=True,
38
- verbose=True,
39
+ verbose=False,
39
40
  xyflip=False,
40
41
  ax=None,
41
42
  ):
@@ -121,6 +122,7 @@ def spectrogram(
121
122
 
122
123
  __________________________________________________________________________________________________________________
123
124
  """
125
+ from scipy.signal.windows import dpss
124
126
 
125
127
  # Process user input
126
128
  [
@@ -618,9 +620,9 @@ def calc_mts_segment(
618
620
  ret.fill(np.nan)
619
621
  return ret
620
622
 
621
- # Option to detrend data to remove low frequency DC component
622
- if detrend_opt != "off":
623
- data_segment = detrend(data_segment, type=detrend_opt)
623
+ # # Option to detrend data to remove low frequency DC component
624
+ # if detrend_opt != "off":
625
+ # data_segment = detrend(data_segment, type=detrend_opt)
624
626
 
625
627
  # Multiply data by dpss tapers (STEP 2)
626
628
  tapered_data = np.multiply(np.asmatrix(data_segment).T, np.asmatrix(dpss_tapers.T))
@@ -1,17 +1,14 @@
1
1
  import os
2
- import re
3
2
  import warnings
4
- from dataclasses import dataclass
5
- from operator import attrgetter
6
3
 
7
4
  import numpy as np
8
5
  import pandas as pd
9
6
  from PIL import Image
10
- from scipy.signal import butter, filtfilt
11
7
  from tqdm import trange
12
8
 
13
9
  from accusleepy.brain_state_set import BrainStateSet
14
10
  from accusleepy.constants import (
11
+ ANNOTATIONS_FILENAME,
15
12
  DEFAULT_MODEL_TYPE,
16
13
  DOWNSAMPLING_START_FREQ,
17
14
  EMG_COPIES,
@@ -23,13 +20,13 @@ from accusleepy.constants import (
23
20
  from accusleepy.fileio import Recording, load_labels, load_recording
24
21
  from accusleepy.multitaper import spectrogram
25
22
 
23
+ # note: scipy is lazily imported
24
+
26
25
  # clip mixture z-scores above and below this level
27
26
  # in the matlab implementation, I used 4.5
28
27
  ABS_MAX_Z_SCORE = 3.5
29
28
  # upper frequency limit when generating EEG spectrograms
30
29
  SPECTROGRAM_UPPER_FREQ = 64
31
- # filename used to store info about training image datasets
32
- ANNOTATIONS_FILENAME = "annotations.csv"
33
30
 
34
31
 
35
32
  def resample(
@@ -186,6 +183,8 @@ def get_emg_power(
186
183
  :param epoch_length: epoch length, in seconds
187
184
  :return: EMG "power" for each epoch
188
185
  """
186
+ from scipy.signal import butter, filtfilt
187
+
189
188
  # filter parameters
190
189
  order = 8
191
190
  bp_lower = 20
@@ -450,140 +449,3 @@ def create_training_images(
450
449
  )
451
450
 
452
451
  return failed_recordings
453
-
454
-
455
- @dataclass
456
- class Bout:
457
- """Stores information about a brain state bout"""
458
-
459
- length: int # length, in number of epochs
460
- start_index: int # index where bout starts
461
- end_index: int # index where bout ends
462
- surrounding_state: int # brain state on both sides of the bout
463
-
464
-
465
- def find_last_adjacent_bout(sorted_bouts: list[Bout], bout_index: int) -> int:
466
- """Find index of last consecutive same-length bout
467
-
468
- When running the post-processing step that enforces a minimum duration
469
- for brain state bouts, there is a special case when bouts below the
470
- duration threshold occur consecutively. This function performs a
471
- recursive search for the index of a bout at the end of such a sequence.
472
- When initially called, bout_index will always be 0. If, for example, the
473
- first three bouts in the list are consecutive, the function will return 2.
474
-
475
- :param sorted_bouts: list of brain state bouts, sorted by start time
476
- :param bout_index: index of the bout in question
477
- :return: index of the last consecutive same-length bout
478
- """
479
- # if we're at the end of the bout list, stop
480
- if bout_index == len(sorted_bouts) - 1:
481
- return bout_index
482
-
483
- # if there is an adjacent bout
484
- if sorted_bouts[bout_index].end_index == sorted_bouts[bout_index + 1].start_index:
485
- # look for more adjacent bouts using that one as a starting point
486
- return find_last_adjacent_bout(sorted_bouts, bout_index + 1)
487
- else:
488
- return bout_index
489
-
490
-
491
- def enforce_min_bout_length(
492
- labels: np.array, epoch_length: int | float, min_bout_length: int | float
493
- ) -> np.array:
494
- """Ensure brain state bouts meet the min length requirement
495
-
496
- As a post-processing step for sleep scoring, we can require that any
497
- bout (continuous period) of a brain state have a minimum duration.
498
- This function sets any bout shorter than the minimum duration to the
499
- surrounding brain state (if the states on the left and right sides
500
- are the same). In the case where there are consecutive short bouts,
501
- it either creates a transition at the midpoint or removes all short
502
- bouts, depending on whether the number is even or odd. For example:
503
- ...AAABABAAA... -> ...AAAAAAAAA...
504
- ...AAABABABBB... -> ...AAAAABBBBB...
505
-
506
- :param labels: brain state labels (digits in the 0-9 range)
507
- :param epoch_length: epoch length, in seconds
508
- :param min_bout_length: minimum bout length, in seconds
509
- :return: updated brain state labels
510
- """
511
- # if recording is very short, don't change anything
512
- if labels.size < 3:
513
- return labels
514
-
515
- if epoch_length == min_bout_length:
516
- return labels
517
-
518
- # get minimum number of epochs in a bout
519
- min_epochs = int(np.ceil(min_bout_length / epoch_length))
520
- # get set of states in the labels
521
- brain_states = set(labels.tolist())
522
-
523
- while True: # so true
524
- # convert labels to a string for regex search
525
- # There is probably a regex that can find all patterns like ab+a
526
- # without consuming each "a" but I haven't found it :(
527
- label_string = "".join(labels.astype(str))
528
-
529
- bouts = list()
530
-
531
- for state in brain_states:
532
- for other_state in brain_states:
533
- if state == other_state:
534
- continue
535
- # get start and end indices of each bout
536
- expression = (
537
- f"(?<={other_state}){state}{{1,{min_epochs - 1}}}(?={other_state})"
538
- )
539
- matches = re.finditer(expression, label_string)
540
- spans = [match.span() for match in matches]
541
-
542
- # if some bouts were found
543
- for span in spans:
544
- bouts.append(
545
- Bout(
546
- length=span[1] - span[0],
547
- start_index=span[0],
548
- end_index=span[1],
549
- surrounding_state=other_state,
550
- )
551
- )
552
-
553
- if len(bouts) == 0:
554
- break
555
-
556
- # only keep the shortest bouts
557
- min_length_in_list = np.min([bout.length for bout in bouts])
558
- bouts = [i for i in bouts if i.length == min_length_in_list]
559
- # sort by start index
560
- sorted_bouts = sorted(bouts, key=attrgetter("start_index"))
561
-
562
- while len(sorted_bouts) > 0:
563
- # get row index of latest adjacent bout (of same length)
564
- last_adjacent_bout_index = find_last_adjacent_bout(sorted_bouts, 0)
565
- # if there's an even number of adjacent bouts
566
- if (last_adjacent_bout_index + 1) % 2 == 0:
567
- midpoint = sorted_bouts[
568
- round((last_adjacent_bout_index + 1) / 2)
569
- ].start_index
570
- labels[sorted_bouts[0].start_index : midpoint] = sorted_bouts[
571
- 0
572
- ].surrounding_state
573
- labels[midpoint : sorted_bouts[last_adjacent_bout_index].end_index] = (
574
- sorted_bouts[last_adjacent_bout_index].surrounding_state
575
- )
576
- else:
577
- labels[
578
- sorted_bouts[0].start_index : sorted_bouts[
579
- last_adjacent_bout_index
580
- ].end_index
581
- ] = sorted_bouts[0].surrounding_state
582
-
583
- # delete the bouts we just fixed
584
- if last_adjacent_bout_index == len(sorted_bouts) - 1:
585
- sorted_bouts = []
586
- else:
587
- sorted_bouts = sorted_bouts[(last_adjacent_bout_index + 1) :]
588
-
589
- return labels
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: accusleepy
3
- Version: 0.4.5
3
+ Version: 0.5.0
4
4
  Summary: Python implementation of AccuSleep
5
5
  License: GPL-3.0-only
6
6
  Author: Zeke Barger
@@ -75,6 +75,7 @@ to the [config file](accusleepy/config.json).
75
75
 
76
76
  ## Changelog
77
77
 
78
+ - 0.5.0: Performance improvements
78
79
  - 0.4.5: Added support for python 3.13, **removed support for python 3.10.**
79
80
  - 0.4.4: Performance improvements
80
81
  - 0.4.3: Improved unit tests and user manuals
@@ -1,10 +1,11 @@
1
1
  accusleepy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  accusleepy/__main__.py,sha256=dKzl2N2Hg9lD264CWYNxThRyDKzWwyMwHRXmJxOmMis,104
3
+ accusleepy/bouts.py,sha256=F_y6DxnpKFfImYb7vCZluZ2eD5I_33gZXmRM8mvebsg,5679
3
4
  accusleepy/brain_state_set.py,sha256=fRkrArHLIbEKimub804yt_mUXoyfsjJEfiJnTjeCMkY,3233
4
- accusleepy/classification.py,sha256=xrmPyMHlzYh0QfNCID1PRIYEIyNkWduOi7g1Bdb6xfg,8573
5
+ accusleepy/classification.py,sha256=czRGcDYN28QK5Nsahy6y2C162OgbO1nhjJxWZT9AvWc,8579
5
6
  accusleepy/config.json,sha256=F76WRLarMEW38BBMPwFlQ_d7Dur-ptqYmW8BxqnQF4A,464
6
- accusleepy/constants.py,sha256=PnsPANggyIfMfd6OCR-kNztFOTybUEhMnPeibu5_eEU,1280
7
- accusleepy/fileio.py,sha256=S5pf_hE-btJPMbrTplKLaQTULSJQoOJ-56LBH79Uz3I,6383
7
+ accusleepy/constants.py,sha256=mb6Tjzat-tWOHdz2I1mqW7NtDzDKcy3rVjeqSdOQ2qE,1381
8
+ accusleepy/fileio.py,sha256=qJfnAnGou337z9_ngBpqsyhCawKazh2DpQdneFZMaMg,4547
8
9
  accusleepy/gui/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
10
  accusleepy/gui/icons/brightness_down.png,sha256=PLT1fb83RHIhSRuU7MMMx0G7oJAY7o9wUcnqM8veZfM,12432
10
11
  accusleepy/gui/icons/brightness_up.png,sha256=64GnUqgPvN5xZ6Um3wOzwqvUmdAWYZT6eFmWpBsHyks,12989
@@ -17,13 +18,13 @@ accusleepy/gui/icons/save.png,sha256=J3EA8iU1BqLYRSsrq_OdoZlqrv2yfL7oV54DklTy_DI
17
18
  accusleepy/gui/icons/up_arrow.png,sha256=V9yF9t1WgjPaUu-mF1YGe_DfaRHg2dUpR_sUVVcvVvY,3329
18
19
  accusleepy/gui/icons/zoom_in.png,sha256=MFWnKZp7Rvh4bLPq4Cqo4sB_jQYedUUtT8-ZO8tNYyc,13589
19
20
  accusleepy/gui/icons/zoom_out.png,sha256=IB8Jecb3i0U4qjWRR46ridjLpvLCSe7PozBaLqQqYSw,13055
20
- accusleepy/gui/images/primary_window.png,sha256=x_ppmv0fKySxXAzbQHCv5JFLdM2ETTxJqUHyPVo5xck,596421
21
+ accusleepy/gui/images/primary_window.png,sha256=-JHTb7bvRS-mUoSl9XRNmhBjQHHiwwuO80jwblp-IO8,598790
21
22
  accusleepy/gui/images/viewer_window.png,sha256=gKwIXkgsl1rTMfmMeMwNyjEAUL5I6FXk9-hpMR92qTI,970630
22
23
  accusleepy/gui/images/viewer_window_annotated.png,sha256=M5NmoWDHRLS334Rp8SsfOPUUXzPltH1p7aB0BrISgQU,261481
23
- accusleepy/gui/main.py,sha256=QLtxVvBz81CjDeBddJZNDCLd207j9WJga1n45nlsxWI,54384
24
- accusleepy/gui/manual_scoring.py,sha256=Sy4vwMmMLY_SreXUyxd0t-at2F-1pHkvvz1UGIbzVic,40496
25
- accusleepy/gui/mplwidget.py,sha256=f9O3u_96whQGUwpi3o_QGc7yjiETX5vE0oj3ePXTJWE,12279
26
- accusleepy/gui/primary_window.py,sha256=RXpDvcb7zy8Ea4Da1VhMG1T6GC54KW3vyGjqqQJN45k,104582
24
+ accusleepy/gui/main.py,sha256=VW4dYk4a2NPA3zl9omBkxaVRJDAvECTLYHsE0uIlnyo,54502
25
+ accusleepy/gui/manual_scoring.py,sha256=f4y_33kFzZ6krdtfkUDJqQ4LRhn7aC1Fwm6nwus4x1I,39853
26
+ accusleepy/gui/mplwidget.py,sha256=Jy3hdkTayo8KzKx1AJ7jS0n_w-CCLlAYqUK7uSEzIBY,12389
27
+ accusleepy/gui/primary_window.py,sha256=gk-IRcsjw4Sf7b1mQbf5RhZiPprv2BtvxGfYjAjrUmc,104583
27
28
  accusleepy/gui/primary_window.ui,sha256=09k4xFcjgOL9mlhFlg6mXCc_tgj4_FY9CZ3H03e3z3A,147074
28
29
  accusleepy/gui/resources.qrc,sha256=ByNEmJqr0YbKBqoZGvONZtjyNYr4ST4enO6TEdYSqWg,802
29
30
  accusleepy/gui/resources_rc.py,sha256=Z2e34h30U4snJjnYdZVV9B6yjATKxxfvgTRt5uXtQdo,329727
@@ -32,9 +33,9 @@ accusleepy/gui/text/main_guide.md,sha256=VS6A5_CzQOBwIotNgEA_X0KHKfMT4lEK43Ki_Dk
32
33
  accusleepy/gui/text/manual_scoring_guide.md,sha256=ow_RMSjFy05NupEDSCuJtu-V65-BPnIkrZqtssFoZCQ,999
33
34
  accusleepy/gui/viewer_window.py,sha256=5PkbuYMXUegH1CExCoqSGDZ9GeJqCCUz0-3WWkM8Vfc,24049
34
35
  accusleepy/gui/viewer_window.ui,sha256=D1LwUFR-kZ_GWGZFFtXvGJdFWghLrOWZTblQeLQt9kI,30525
35
- accusleepy/models.py,sha256=Muapsw088AUHqRIbW97Rkbv0oiwCtQvO9tEoBCC-MYg,1476
36
- accusleepy/multitaper.py,sha256=V6MJDk0OSWhg2MFhrnt9dvYrHiNsk2T7IxAA7paZVyE,25549
37
- accusleepy/signal_processing.py,sha256=-aXnywfp1LBsk3DcbIMmZlgv3f8j6sZ6js0bizZId0o,21718
38
- accusleepy-0.4.5.dist-info/METADATA,sha256=qKYYF4smvfZF2l92BjYNpDyT6mB6QTqg10kw4MAxtRU,3875
39
- accusleepy-0.4.5.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
40
- accusleepy-0.4.5.dist-info/RECORD,,
36
+ accusleepy/models.py,sha256=IJcMy102p8RndidqZ9HrL-eIGLOjQqc1hwebZHRvl5Q,3390
37
+ accusleepy/multitaper.py,sha256=D5-iglwkFBRciL5tKSNcunMtcq0rM3zHwRHUVPgem1U,25679
38
+ accusleepy/signal_processing.py,sha256=dxq0Nq9ae8ze5hX8vX7LXBqz1IJ3SoBsNQYmNSY5n2E,16023
39
+ accusleepy-0.5.0.dist-info/METADATA,sha256=F1AZtdxlgo70PPCJucnVhYyUa922FXsjMzS942wfqSs,3909
40
+ accusleepy-0.5.0.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
41
+ accusleepy-0.5.0.dist-info/RECORD,,