accusleepy 0.8.0__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
accusleepy/bouts.py CHANGED
@@ -11,7 +11,7 @@ class Bout:
11
11
 
12
12
  length: int # length, in number of epochs
13
13
  start_index: int # index where bout starts
14
- end_index: int # index where bout ends
14
+ end_index: int # index where bout ends (non-inclusive)
15
15
  surrounding_state: int # brain state on both sides of the bout
16
16
 
17
17
 
@@ -41,9 +41,47 @@ def find_last_adjacent_bout(sorted_bouts: list[Bout], bout_index: int) -> int:
41
41
  return bout_index
42
42
 
43
43
 
44
+ def find_short_bouts(
45
+ labels: np.ndarray, min_epochs: int, brain_states: set[int]
46
+ ) -> list[Bout]:
47
+ """Locate all brain state bouts below a minimum length
48
+
49
+ :param labels: brain state labels (digits in the 0-9 range)
50
+ :param min_epochs: minimum number of epochs in a bout
51
+ :param brain_states: set of brain states in the labels
52
+ :return: list of Bout objects
53
+ """
54
+ # convert labels to a string for regex search
55
+ # There is probably a regex that can find all patterns like ab+a
56
+ # without consuming each "a" but I haven't found it :(
57
+ label_string = "".join(labels.astype(str))
58
+ bouts = list()
59
+ for state in brain_states:
60
+ for other_state in brain_states:
61
+ if state == other_state:
62
+ continue
63
+ # get start and end indices of each bout
64
+ expression = (
65
+ f"(?<={other_state}){state}{{1,{min_epochs - 1}}}(?={other_state})"
66
+ )
67
+ matches = re.finditer(expression, label_string)
68
+ spans = [match.span() for match in matches]
69
+
70
+ for span in spans:
71
+ bouts.append(
72
+ Bout(
73
+ length=span[1] - span[0],
74
+ start_index=span[0],
75
+ end_index=span[1],
76
+ surrounding_state=other_state,
77
+ )
78
+ )
79
+ return bouts
80
+
81
+
44
82
  def enforce_min_bout_length(
45
- labels: np.array, epoch_length: int | float, min_bout_length: int | float
46
- ) -> np.array:
83
+ labels: np.ndarray, epoch_length: int | float, min_bout_length: int | float
84
+ ) -> np.ndarray:
47
85
  """Ensure brain state bouts meet the min length requirement
48
86
 
49
87
  As a post-processing step for sleep scoring, we can require that any
@@ -61,11 +99,9 @@ def enforce_min_bout_length(
61
99
  :param min_bout_length: minimum bout length, in seconds
62
100
  :return: updated brain state labels
63
101
  """
64
- # if recording is very short, don't change anything
65
- if labels.size < 3:
66
- return labels
67
-
68
- if epoch_length == min_bout_length:
102
+ # if the recording is very short or the minimum bout length
103
+ # is one epoch long, don't change anything
104
+ if labels.size < 3 or epoch_length == min_bout_length:
69
105
  return labels
70
106
 
71
107
  # get minimum number of epochs in a bout
@@ -73,36 +109,8 @@ def enforce_min_bout_length(
73
109
  # get set of states in the labels
74
110
  brain_states = set(labels.tolist())
75
111
 
76
- while True: # so true
77
- # convert labels to a string for regex search
78
- # There is probably a regex that can find all patterns like ab+a
79
- # without consuming each "a" but I haven't found it :(
80
- label_string = "".join(labels.astype(str))
81
-
82
- bouts = list()
83
-
84
- for state in brain_states:
85
- for other_state in brain_states:
86
- if state == other_state:
87
- continue
88
- # get start and end indices of each bout
89
- expression = (
90
- f"(?<={other_state}){state}{{1,{min_epochs - 1}}}(?={other_state})"
91
- )
92
- matches = re.finditer(expression, label_string)
93
- spans = [match.span() for match in matches]
94
-
95
- # if some bouts were found
96
- for span in spans:
97
- bouts.append(
98
- Bout(
99
- length=span[1] - span[0],
100
- start_index=span[0],
101
- end_index=span[1],
102
- surrounding_state=other_state,
103
- )
104
- )
105
-
112
+ while True:
113
+ bouts = find_short_bouts(labels, min_epochs, brain_states)
106
114
  if len(bouts) == 0:
107
115
  break
108
116
 
@@ -41,7 +41,7 @@ class BrainStateSet:
41
41
  i = 0
42
42
  for brain_state in self.brain_states:
43
43
  if brain_state.digit == undefined_label:
44
- raise Exception(
44
+ raise ValueError(
45
45
  f"Digit for {brain_state.name} matches 'undefined' label"
46
46
  )
47
47
  if brain_state.is_scored:
@@ -56,9 +56,11 @@ class BrainStateSet:
56
56
 
57
57
  self.mixture_weights = np.array(self.mixture_weights)
58
58
  if np.sum(self.mixture_weights) != 1:
59
- raise Exception("Typical frequencies for scored brain states must sum to 1")
59
+ raise ValueError(
60
+ "Typical frequencies for scored brain states must sum to 1"
61
+ )
60
62
 
61
- def convert_digit_to_class(self, digits: np.array) -> np.array:
63
+ def convert_digit_to_class(self, digits: np.ndarray) -> np.ndarray:
62
64
  """Convert array of digits to their corresponding classes
63
65
 
64
66
  :param digits: array of digits
@@ -66,7 +68,7 @@ class BrainStateSet:
66
68
  """
67
69
  return np.array([self.digit_to_class[i] for i in digits])
68
70
 
69
- def convert_class_to_digit(self, classes: np.array) -> np.array:
71
+ def convert_class_to_digit(self, classes: np.ndarray) -> np.ndarray:
70
72
  """Convert array of classes to their corresponding digits
71
73
 
72
74
  :param classes: array of classes
@@ -16,7 +16,6 @@ from accusleepy.models import SSANN
16
16
  from accusleepy.signal_processing import (
17
17
  create_eeg_emg_image,
18
18
  format_img,
19
- get_mixture_values,
20
19
  mixture_z_score_img,
21
20
  )
22
21
 
@@ -84,7 +83,7 @@ def create_dataloader(
84
83
  def train_ssann(
85
84
  annotations_file: str,
86
85
  img_dir: str,
87
- mixture_weights: np.array,
86
+ training_class_balance: np.ndarray,
88
87
  n_classes: int,
89
88
  hyperparameters: Hyperparameters,
90
89
  ) -> SSANN:
@@ -92,7 +91,7 @@ def train_ssann(
92
91
 
93
92
  :param annotations_file: file with information on each training image
94
93
  :param img_dir: training image location
95
- :param mixture_weights: typical relative frequencies of brain states
94
+ :param training_class_balance: proportion of each class in the training set
96
95
  :param n_classes: number of classes the model will learn
97
96
  :param hyperparameters: model training hyperparameters
98
97
  :return: trained Sleep Scoring Artificial Neural Network model
@@ -109,7 +108,7 @@ def train_ssann(
109
108
  model.train()
110
109
 
111
110
  # correct for class imbalance
112
- weight = torch.tensor((mixture_weights**-1).astype("float32")).to(device)
111
+ weight = torch.tensor((training_class_balance**-1).astype("float32")).to(device)
113
112
 
114
113
  criterion = nn.CrossEntropyLoss(weight=weight)
115
114
  optimizer = optim.SGD(
@@ -133,16 +132,16 @@ def train_ssann(
133
132
 
134
133
  def score_recording(
135
134
  model: SSANN,
136
- eeg: np.array,
137
- emg: np.array,
138
- mixture_means: np.array,
139
- mixture_sds: np.array,
135
+ eeg: np.ndarray,
136
+ emg: np.ndarray,
137
+ mixture_means: np.ndarray,
138
+ mixture_sds: np.ndarray,
140
139
  sampling_rate: int | float,
141
140
  epoch_length: int | float,
142
141
  epochs_per_img: int,
143
142
  brain_state_set: BrainStateSet,
144
143
  emg_filter: EMGFilter,
145
- ) -> np.array:
144
+ ) -> tuple[np.ndarray, np.ndarray]:
146
145
  """Use classification model to get brain state labels for a recording
147
146
 
148
147
  This assumes signals have been preprocessed to contain an integer
@@ -167,7 +166,7 @@ def score_recording(
167
166
 
168
167
  # create and scale eeg+emg spectrogram
169
168
  img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
170
- img = mixture_z_score_img(
169
+ img, _ = mixture_z_score_img(
171
170
  img,
172
171
  mixture_means=mixture_means,
173
172
  mixture_sds=mixture_sds,
@@ -196,10 +195,10 @@ def score_recording(
196
195
 
197
196
  def example_real_time_scoring_function(
198
197
  model: SSANN,
199
- eeg: np.array,
200
- emg: np.array,
201
- mixture_means: np.array,
202
- mixture_sds: np.array,
198
+ eeg: np.ndarray,
199
+ emg: np.ndarray,
200
+ mixture_means: np.ndarray,
201
+ mixture_sds: np.ndarray,
203
202
  sampling_rate: int | float,
204
203
  epoch_length: int | float,
205
204
  epochs_per_img: int,
@@ -244,7 +243,7 @@ def example_real_time_scoring_function(
244
243
 
245
244
  # create and scale eeg+emg spectrogram
246
245
  img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
247
- img = mixture_z_score_img(
246
+ img, _ = mixture_z_score_img(
248
247
  img,
249
248
  mixture_means=mixture_means,
250
249
  mixture_sds=mixture_sds,
@@ -264,38 +263,3 @@ def example_real_time_scoring_function(
264
263
 
265
264
  label = int(brain_state_set.convert_class_to_digit(predicted.cpu().numpy())[0])
266
265
  return label
267
-
268
-
269
- def create_calibration_file(
270
- filename: str,
271
- eeg: np.array,
272
- emg: np.array,
273
- labels: np.array,
274
- sampling_rate: int | float,
275
- epoch_length: int | float,
276
- brain_state_set: BrainStateSet,
277
- emg_filter: EMGFilter,
278
- ) -> None:
279
- """Create file of calibration data for a subject
280
-
281
- This assumes signals have been preprocessed to contain an integer
282
- number of epochs.
283
-
284
- :param filename: filename for the calibration file
285
- :param eeg: EEG signal
286
- :param emg: EMG signal
287
- :param labels: brain state labels, as digits
288
- :param sampling_rate: sampling rate, in Hz
289
- :param epoch_length: epoch length, in seconds
290
- :param brain_state_set: set of brain state options
291
- :param emg_filter: EMG filter parameters
292
- """
293
- img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
294
- mixture_means, mixture_sds = get_mixture_values(
295
- img=img,
296
- labels=brain_state_set.convert_digit_to_class(labels),
297
- brain_state_set=brain_state_set,
298
- )
299
- pd.DataFrame(
300
- {c.MIXTURE_MEAN_COL: mixture_means, c.MIXTURE_SD_COL: mixture_sds}
301
- ).to_csv(filename, index=False)
accusleepy/constants.py CHANGED
@@ -18,6 +18,9 @@ MESSAGE_BOX_MAX_DEPTH = 200
18
18
  ABS_MAX_Z_SCORE = 3.5
19
19
  # upper frequency limit when generating EEG spectrograms
20
20
  SPECTROGRAM_UPPER_FREQ = 64
21
+ # minimum number of epochs per brain state needed to create
22
+ # a calibration file or use a recording for model training
23
+ MIN_EPOCHS_PER_STATE = 3
21
24
 
22
25
 
23
26
  # very unlikely you will want to change values from here onwards
accusleepy/fileio.py CHANGED
@@ -5,6 +5,7 @@ from dataclasses import dataclass
5
5
  import numpy as np
6
6
  import pandas as pd
7
7
  from PySide6.QtWidgets import QListWidgetItem
8
+ import toml
8
9
 
9
10
  from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
10
11
  import accusleepy.constants as c
@@ -29,6 +30,21 @@ class Hyperparameters:
29
30
  training_epochs: int
30
31
 
31
32
 
33
+ @dataclass
34
+ class AccuSleePyConfig:
35
+ """AccuSleePy configuration settings"""
36
+
37
+ brain_state_set: BrainStateSet
38
+ default_epoch_length: int | float
39
+ overwrite_setting: bool
40
+ save_confidence_setting: bool
41
+ min_bout_length: int | float
42
+ emg_filter: EMGFilter
43
+ hyperparameters: Hyperparameters
44
+ epochs_to_show: int
45
+ autoscroll_state: bool
46
+
47
+
32
48
  @dataclass
33
49
  class Recording:
34
50
  """Store information about a recording"""
@@ -41,7 +57,7 @@ class Recording:
41
57
  widget: QListWidgetItem = None # list item widget shown in the GUI
42
58
 
43
59
 
44
- def load_calibration_file(filename: str) -> (np.array, np.array):
60
+ def load_calibration_file(filename: str) -> tuple[np.ndarray, np.ndarray]:
45
61
  """Load a calibration file
46
62
 
47
63
  :param filename: filename
@@ -65,11 +81,11 @@ def load_csv_or_parquet(filename: str) -> pd.DataFrame:
65
81
  elif extension == ".parquet":
66
82
  df = pd.read_parquet(filename)
67
83
  else:
68
- raise Exception("file must be csv or parquet")
84
+ raise ValueError("file must be csv or parquet")
69
85
  return df
70
86
 
71
87
 
72
- def load_recording(filename: str) -> (np.array, np.array):
88
+ def load_recording(filename: str) -> tuple[np.ndarray, np.ndarray]:
73
89
  """Load recording of EEG and EMG time series data
74
90
 
75
91
  :param filename: filename
@@ -81,7 +97,7 @@ def load_recording(filename: str) -> (np.array, np.array):
81
97
  return eeg, emg
82
98
 
83
99
 
84
- def load_labels(filename: str) -> (np.array, np.array):
100
+ def load_labels(filename: str) -> tuple[np.ndarray, np.ndarray | None]:
85
101
  """Load file of brain state labels and confidence scores
86
102
 
87
103
  :param filename: filename
@@ -95,7 +111,7 @@ def load_labels(filename: str) -> (np.array, np.array):
95
111
 
96
112
 
97
113
  def save_labels(
98
- labels: np.array, filename: str, confidence_scores: np.array = None
114
+ labels: np.ndarray, filename: str, confidence_scores: np.ndarray | None = None
99
115
  ) -> None:
100
116
  """Save brain state labels to file
101
117
 
@@ -111,20 +127,11 @@ def save_labels(
111
127
  pd.DataFrame({c.BRAIN_STATE_COL: labels}).to_csv(filename, index=False)
112
128
 
113
129
 
114
- def load_config() -> tuple[
115
- BrainStateSet,
116
- int | float,
117
- bool,
118
- bool,
119
- int | float,
120
- EMGFilter,
121
- Hyperparameters,
122
- int,
123
- bool,
124
- ]:
130
+ def load_config() -> AccuSleePyConfig:
125
131
  """Load configuration file with brain state options
126
132
 
127
- :return: set of brain state options,
133
+ :return: AccuSleePyConfig containing the following:
134
+ set of brain state options,
128
135
  default epoch length,
129
136
  default overwrite setting,
130
137
  default confidence score output setting,
@@ -139,15 +146,21 @@ def load_config() -> tuple[
139
146
  ) as f:
140
147
  data = json.load(f)
141
148
 
142
- return (
143
- BrainStateSet(
149
+ return AccuSleePyConfig(
150
+ brain_state_set=BrainStateSet(
144
151
  [BrainState(**b) for b in data[BRAIN_STATES_KEY]], c.UNDEFINED_LABEL
145
152
  ),
146
- data[c.DEFAULT_EPOCH_LENGTH_KEY],
147
- data.get(c.DEFAULT_OVERWRITE_KEY, c.DEFAULT_OVERWRITE_SETTING),
148
- data.get(c.DEFAULT_CONFIDENCE_SETTING_KEY, c.DEFAULT_CONFIDENCE_SETTING),
149
- data.get(c.DEFAULT_MIN_BOUT_LENGTH_KEY, c.DEFAULT_MIN_BOUT_LENGTH),
150
- EMGFilter(
153
+ default_epoch_length=data[c.DEFAULT_EPOCH_LENGTH_KEY],
154
+ overwrite_setting=data.get(
155
+ c.DEFAULT_OVERWRITE_KEY, c.DEFAULT_OVERWRITE_SETTING
156
+ ),
157
+ save_confidence_setting=data.get(
158
+ c.DEFAULT_CONFIDENCE_SETTING_KEY, c.DEFAULT_CONFIDENCE_SETTING
159
+ ),
160
+ min_bout_length=data.get(
161
+ c.DEFAULT_MIN_BOUT_LENGTH_KEY, c.DEFAULT_MIN_BOUT_LENGTH
162
+ ),
163
+ emg_filter=EMGFilter(
151
164
  **data.get(
152
165
  c.EMG_FILTER_KEY,
153
166
  {
@@ -157,7 +170,7 @@ def load_config() -> tuple[
157
170
  },
158
171
  )
159
172
  ),
160
- Hyperparameters(
173
+ hyperparameters=Hyperparameters(
161
174
  **data.get(
162
175
  c.HYPERPARAMETERS_KEY,
163
176
  {
@@ -168,8 +181,8 @@ def load_config() -> tuple[
168
181
  },
169
182
  )
170
183
  ),
171
- data.get(c.EPOCHS_TO_SHOW_KEY, c.DEFAULT_EPOCHS_TO_SHOW),
172
- data.get(c.AUTOSCROLL_KEY, c.DEFAULT_AUTOSCROLL_STATE),
184
+ epochs_to_show=data.get(c.EPOCHS_TO_SHOW_KEY, c.DEFAULT_EPOCHS_TO_SHOW),
185
+ autoscroll_state=data.get(c.AUTOSCROLL_KEY, c.DEFAULT_AUTOSCROLL_STATE),
173
186
  )
174
187
 
175
188
 
@@ -211,6 +224,7 @@ def save_config(
211
224
  os.path.join(os.path.dirname(os.path.abspath(__file__)), c.CONFIG_FILE), "w"
212
225
  ) as f:
213
226
  json.dump(output_dict, f, indent=4)
227
+ f.write("\n")
214
228
 
215
229
 
216
230
  def load_recording_list(filename: str) -> list[Recording]:
@@ -246,3 +260,20 @@ def save_recording_list(filename: str, recordings: list[Recording]) -> None:
246
260
  }
247
261
  with open(filename, "w") as f:
248
262
  json.dump(recording_dict, f, indent=4)
263
+
264
+
265
+ def get_version() -> str:
266
+ """Get AccuSleePy package version
267
+
268
+ :return: AccuSleePy package version
269
+ """
270
+ version = ""
271
+ toml_file = os.path.join(
272
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
273
+ "pyproject.toml",
274
+ )
275
+ if os.path.isfile(toml_file):
276
+ toml_data = toml.load(toml_file)
277
+ if "project" in toml_data and "version" in toml_data["project"]:
278
+ version = toml_data["project"]["version"]
279
+ return version
@@ -0,0 +1,40 @@
1
+ """File dialog helpers"""
2
+
3
+ import os
4
+
5
+ from PySide6.QtWidgets import QFileDialog, QWidget
6
+
7
+
8
+ def select_existing_file(parent: QWidget, title: str, file_filter: str) -> str | None:
9
+ """Show dialog to select an existing file.
10
+
11
+ :param parent: parent widget
12
+ :param title: dialog window title
13
+ :param file_filter: file type filter (e.g., "*.csv")
14
+ :return: normalized path or None if cancelled
15
+ """
16
+ dialog = QFileDialog(parent)
17
+ dialog.setWindowTitle(title)
18
+ dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
19
+ dialog.setViewMode(QFileDialog.ViewMode.Detail)
20
+ dialog.setNameFilter(file_filter)
21
+
22
+ if dialog.exec():
23
+ return os.path.normpath(dialog.selectedFiles()[0])
24
+ return None
25
+
26
+
27
+ def select_save_location(parent: QWidget, caption: str, file_filter: str) -> str | None:
28
+ """Show dialog to choose save location.
29
+
30
+ :param parent: parent widget
31
+ :param caption: dialog window caption
32
+ :param file_filter: file type filter (e.g., "*.csv")
33
+ :return: normalized path or None if cancelled
34
+ """
35
+ filename, _ = QFileDialog.getSaveFileName(
36
+ parent, caption=caption, filter=file_filter
37
+ )
38
+ if filename:
39
+ return os.path.normpath(filename)
40
+ return None
Binary file