accusleepy 0.8.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
accusleepy/bouts.py CHANGED
@@ -42,7 +42,7 @@ def find_last_adjacent_bout(sorted_bouts: list[Bout], bout_index: int) -> int:
42
42
 
43
43
 
44
44
  def find_short_bouts(
45
- labels: np.array, min_epochs: int, brain_states: set[int]
45
+ labels: np.ndarray, min_epochs: int, brain_states: set[int]
46
46
  ) -> list[Bout]:
47
47
  """Locate all brain state bouts below a minimum length
48
48
 
@@ -80,8 +80,8 @@ def find_short_bouts(
80
80
 
81
81
 
82
82
  def enforce_min_bout_length(
83
- labels: np.array, epoch_length: int | float, min_bout_length: int | float
84
- ) -> np.array:
83
+ labels: np.ndarray, epoch_length: int | float, min_bout_length: int | float
84
+ ) -> np.ndarray:
85
85
  """Ensure brain state bouts meet the min length requirement
86
86
 
87
87
  As a post-processing step for sleep scoring, we can require that any
@@ -41,7 +41,7 @@ class BrainStateSet:
41
41
  i = 0
42
42
  for brain_state in self.brain_states:
43
43
  if brain_state.digit == undefined_label:
44
- raise Exception(
44
+ raise ValueError(
45
45
  f"Digit for {brain_state.name} matches 'undefined' label"
46
46
  )
47
47
  if brain_state.is_scored:
@@ -56,9 +56,11 @@ class BrainStateSet:
56
56
 
57
57
  self.mixture_weights = np.array(self.mixture_weights)
58
58
  if np.sum(self.mixture_weights) != 1:
59
- raise Exception("Typical frequencies for scored brain states must sum to 1")
59
+ raise ValueError(
60
+ "Typical frequencies for scored brain states must sum to 1"
61
+ )
60
62
 
61
- def convert_digit_to_class(self, digits: np.array) -> np.array:
63
+ def convert_digit_to_class(self, digits: np.ndarray) -> np.ndarray:
62
64
  """Convert array of digits to their corresponding classes
63
65
 
64
66
  :param digits: array of digits
@@ -66,7 +68,7 @@ class BrainStateSet:
66
68
  """
67
69
  return np.array([self.digit_to_class[i] for i in digits])
68
70
 
69
- def convert_class_to_digit(self, classes: np.array) -> np.array:
71
+ def convert_class_to_digit(self, classes: np.ndarray) -> np.ndarray:
70
72
  """Convert array of classes to their corresponding digits
71
73
 
72
74
  :param classes: array of classes
@@ -16,7 +16,6 @@ from accusleepy.models import SSANN
16
16
  from accusleepy.signal_processing import (
17
17
  create_eeg_emg_image,
18
18
  format_img,
19
- get_mixture_values,
20
19
  mixture_z_score_img,
21
20
  )
22
21
 
@@ -84,7 +83,7 @@ def create_dataloader(
84
83
  def train_ssann(
85
84
  annotations_file: str,
86
85
  img_dir: str,
87
- mixture_weights: np.array,
86
+ training_class_balance: np.ndarray,
88
87
  n_classes: int,
89
88
  hyperparameters: Hyperparameters,
90
89
  ) -> SSANN:
@@ -92,7 +91,7 @@ def train_ssann(
92
91
 
93
92
  :param annotations_file: file with information on each training image
94
93
  :param img_dir: training image location
95
- :param mixture_weights: typical relative frequencies of brain states
94
+ :param training_class_balance: proportion of each class in the training set
96
95
  :param n_classes: number of classes the model will learn
97
96
  :param hyperparameters: model training hyperparameters
98
97
  :return: trained Sleep Scoring Artificial Neural Network model
@@ -109,7 +108,7 @@ def train_ssann(
109
108
  model.train()
110
109
 
111
110
  # correct for class imbalance
112
- weight = torch.tensor((mixture_weights**-1).astype("float32")).to(device)
111
+ weight = torch.tensor((training_class_balance**-1).astype("float32")).to(device)
113
112
 
114
113
  criterion = nn.CrossEntropyLoss(weight=weight)
115
114
  optimizer = optim.SGD(
@@ -133,16 +132,16 @@ def train_ssann(
133
132
 
134
133
  def score_recording(
135
134
  model: SSANN,
136
- eeg: np.array,
137
- emg: np.array,
138
- mixture_means: np.array,
139
- mixture_sds: np.array,
135
+ eeg: np.ndarray,
136
+ emg: np.ndarray,
137
+ mixture_means: np.ndarray,
138
+ mixture_sds: np.ndarray,
140
139
  sampling_rate: int | float,
141
140
  epoch_length: int | float,
142
141
  epochs_per_img: int,
143
142
  brain_state_set: BrainStateSet,
144
143
  emg_filter: EMGFilter,
145
- ) -> np.array:
144
+ ) -> tuple[np.ndarray, np.ndarray]:
146
145
  """Use classification model to get brain state labels for a recording
147
146
 
148
147
  This assumes signals have been preprocessed to contain an integer
@@ -167,7 +166,7 @@ def score_recording(
167
166
 
168
167
  # create and scale eeg+emg spectrogram
169
168
  img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
170
- img = mixture_z_score_img(
169
+ img, _ = mixture_z_score_img(
171
170
  img,
172
171
  mixture_means=mixture_means,
173
172
  mixture_sds=mixture_sds,
@@ -196,10 +195,10 @@ def score_recording(
196
195
 
197
196
  def example_real_time_scoring_function(
198
197
  model: SSANN,
199
- eeg: np.array,
200
- emg: np.array,
201
- mixture_means: np.array,
202
- mixture_sds: np.array,
198
+ eeg: np.ndarray,
199
+ emg: np.ndarray,
200
+ mixture_means: np.ndarray,
201
+ mixture_sds: np.ndarray,
203
202
  sampling_rate: int | float,
204
203
  epoch_length: int | float,
205
204
  epochs_per_img: int,
@@ -244,7 +243,7 @@ def example_real_time_scoring_function(
244
243
 
245
244
  # create and scale eeg+emg spectrogram
246
245
  img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
247
- img = mixture_z_score_img(
246
+ img, _ = mixture_z_score_img(
248
247
  img,
249
248
  mixture_means=mixture_means,
250
249
  mixture_sds=mixture_sds,
@@ -264,38 +263,3 @@ def example_real_time_scoring_function(
264
263
 
265
264
  label = int(brain_state_set.convert_class_to_digit(predicted.cpu().numpy())[0])
266
265
  return label
267
-
268
-
269
- def create_calibration_file(
270
- filename: str,
271
- eeg: np.array,
272
- emg: np.array,
273
- labels: np.array,
274
- sampling_rate: int | float,
275
- epoch_length: int | float,
276
- brain_state_set: BrainStateSet,
277
- emg_filter: EMGFilter,
278
- ) -> None:
279
- """Create file of calibration data for a subject
280
-
281
- This assumes signals have been preprocessed to contain an integer
282
- number of epochs.
283
-
284
- :param filename: filename for the calibration file
285
- :param eeg: EEG signal
286
- :param emg: EMG signal
287
- :param labels: brain state labels, as digits
288
- :param sampling_rate: sampling rate, in Hz
289
- :param epoch_length: epoch length, in seconds
290
- :param brain_state_set: set of brain state options
291
- :param emg_filter: EMG filter parameters
292
- """
293
- img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
294
- mixture_means, mixture_sds = get_mixture_values(
295
- img=img,
296
- labels=brain_state_set.convert_digit_to_class(labels),
297
- brain_state_set=brain_state_set,
298
- )
299
- pd.DataFrame(
300
- {c.MIXTURE_MEAN_COL: mixture_means, c.MIXTURE_SD_COL: mixture_sds}
301
- ).to_csv(filename, index=False)
accusleepy/constants.py CHANGED
@@ -18,6 +18,9 @@ MESSAGE_BOX_MAX_DEPTH = 200
18
18
  ABS_MAX_Z_SCORE = 3.5
19
19
  # upper frequency limit when generating EEG spectrograms
20
20
  SPECTROGRAM_UPPER_FREQ = 64
21
+ # minimum number of epochs per brain state needed to create
22
+ # a calibration file or use a recording for model training
23
+ MIN_EPOCHS_PER_STATE = 3
21
24
 
22
25
 
23
26
  # very unlikely you will want to change values from here onwards
accusleepy/fileio.py CHANGED
@@ -5,6 +5,7 @@ from dataclasses import dataclass
5
5
  import numpy as np
6
6
  import pandas as pd
7
7
  from PySide6.QtWidgets import QListWidgetItem
8
+ import toml
8
9
 
9
10
  from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
10
11
  import accusleepy.constants as c
@@ -56,7 +57,7 @@ class Recording:
56
57
  widget: QListWidgetItem = None # list item widget shown in the GUI
57
58
 
58
59
 
59
- def load_calibration_file(filename: str) -> (np.array, np.array):
60
+ def load_calibration_file(filename: str) -> tuple[np.ndarray, np.ndarray]:
60
61
  """Load a calibration file
61
62
 
62
63
  :param filename: filename
@@ -80,11 +81,11 @@ def load_csv_or_parquet(filename: str) -> pd.DataFrame:
80
81
  elif extension == ".parquet":
81
82
  df = pd.read_parquet(filename)
82
83
  else:
83
- raise Exception("file must be csv or parquet")
84
+ raise ValueError("file must be csv or parquet")
84
85
  return df
85
86
 
86
87
 
87
- def load_recording(filename: str) -> (np.array, np.array):
88
+ def load_recording(filename: str) -> tuple[np.ndarray, np.ndarray]:
88
89
  """Load recording of EEG and EMG time series data
89
90
 
90
91
  :param filename: filename
@@ -96,7 +97,7 @@ def load_recording(filename: str) -> (np.array, np.array):
96
97
  return eeg, emg
97
98
 
98
99
 
99
- def load_labels(filename: str) -> (np.array, np.array):
100
+ def load_labels(filename: str) -> tuple[np.ndarray, np.ndarray | None]:
100
101
  """Load file of brain state labels and confidence scores
101
102
 
102
103
  :param filename: filename
@@ -110,7 +111,7 @@ def load_labels(filename: str) -> (np.array, np.array):
110
111
 
111
112
 
112
113
  def save_labels(
113
- labels: np.array, filename: str, confidence_scores: np.array = None
114
+ labels: np.ndarray, filename: str, confidence_scores: np.ndarray | None = None
114
115
  ) -> None:
115
116
  """Save brain state labels to file
116
117
 
@@ -223,6 +224,7 @@ def save_config(
223
224
  os.path.join(os.path.dirname(os.path.abspath(__file__)), c.CONFIG_FILE), "w"
224
225
  ) as f:
225
226
  json.dump(output_dict, f, indent=4)
227
+ f.write("\n")
226
228
 
227
229
 
228
230
  def load_recording_list(filename: str) -> list[Recording]:
@@ -258,3 +260,20 @@ def save_recording_list(filename: str, recordings: list[Recording]) -> None:
258
260
  }
259
261
  with open(filename, "w") as f:
260
262
  json.dump(recording_dict, f, indent=4)
263
+
264
+
265
def get_version() -> str:
    """Get AccuSleePy package version

    The version is looked up in two places, in order:

    1. the ``[project]`` table of a ``pyproject.toml`` file located one
       directory above this package (typically only present when
       running from a source checkout)
    2. the metadata of the installed ``accusleepy`` distribution

    :return: AccuSleePy package version, or an empty string if it
        could not be determined from either location
    """
    # imported locally so this fallback does not depend on any change
    # to the module-level imports
    from importlib import metadata

    toml_file = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        "pyproject.toml",
    )
    if os.path.isfile(toml_file):
        toml_data = toml.load(toml_file)
        if "project" in toml_data and "version" in toml_data["project"]:
            return toml_data["project"]["version"]

    # A wheel install normally has no pyproject.toml next to the package
    # directory, so fall back to the installed distribution's metadata
    # instead of reporting an empty version.
    try:
        return metadata.version("accusleepy")
    except metadata.PackageNotFoundError:
        return ""
@@ -0,0 +1,40 @@
1
+ """File dialog helpers"""
2
+
3
+ import os
4
+
5
+ from PySide6.QtWidgets import QFileDialog, QWidget
6
+
7
+
8
def select_existing_file(parent: QWidget, title: str, file_filter: str) -> str | None:
    """Prompt the user to pick a single file that already exists.

    :param parent: widget that owns the dialog
    :param title: text shown as the dialog window title
    :param file_filter: name filter applied to the dialog (e.g., "*.csv")
    :return: first selected path passed through os.path.normpath, or
        None when the dialog's exec() result is falsy
    """
    picker = QFileDialog(parent)
    picker.setWindowTitle(title)
    picker.setFileMode(QFileDialog.FileMode.ExistingFile)
    picker.setViewMode(QFileDialog.ViewMode.Detail)
    picker.setNameFilter(file_filter)

    # guard clause: nothing to return unless the dialog was accepted
    if not picker.exec():
        return None

    chosen_files = picker.selectedFiles()
    return os.path.normpath(chosen_files[0])
25
+
26
+
27
def select_save_location(parent: QWidget, caption: str, file_filter: str) -> str | None:
    """Ask the user where a file should be saved.

    :param parent: widget that owns the dialog
    :param caption: text shown as the dialog window caption
    :param file_filter: file type filter (e.g., "*.csv")
    :return: chosen path passed through os.path.normpath, or None when
        the dialog yields an empty filename
    """
    # getSaveFileName returns a pair; only the filename is used here
    chosen_path, _selected_filter = QFileDialog.getSaveFileName(
        parent, caption=caption, filter=file_filter
    )
    return os.path.normpath(chosen_path) if chosen_path else None
Binary file