accusleepy 0.5.0__tar.gz → 0.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. {accusleepy-0.5.0 → accusleepy-0.7.0}/PKG-INFO +11 -2
  2. {accusleepy-0.5.0 → accusleepy-0.7.0}/README.md +10 -1
  3. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/classification.py +49 -15
  4. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/config.json +15 -1
  5. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/constants.py +29 -2
  6. accusleepy-0.7.0/accusleepy/fileio.py +230 -0
  7. accusleepy-0.7.0/accusleepy/gui/images/primary_window.png +0 -0
  8. accusleepy-0.7.0/accusleepy/gui/images/viewer_window.png +0 -0
  9. accusleepy-0.7.0/accusleepy/gui/images/viewer_window_annotated.png +0 -0
  10. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/main.py +220 -42
  11. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/manual_scoring.py +38 -8
  12. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/mplwidget.py +54 -29
  13. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/primary_window.py +937 -254
  14. accusleepy-0.7.0/accusleepy/gui/primary_window.ui +4628 -0
  15. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/resources.qrc +1 -1
  16. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/text/main_guide.md +18 -12
  17. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/viewer_window.py +19 -7
  18. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/viewer_window.ui +34 -2
  19. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/models.py +11 -1
  20. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/signal_processing.py +40 -17
  21. accusleepy-0.7.0/accusleepy/temperature_scaling.py +157 -0
  22. {accusleepy-0.5.0 → accusleepy-0.7.0}/pyproject.toml +1 -1
  23. accusleepy-0.5.0/accusleepy/fileio.py +0 -156
  24. accusleepy-0.5.0/accusleepy/gui/images/primary_window.png +0 -0
  25. accusleepy-0.5.0/accusleepy/gui/images/viewer_window.png +0 -0
  26. accusleepy-0.5.0/accusleepy/gui/images/viewer_window_annotated.png +0 -0
  27. accusleepy-0.5.0/accusleepy/gui/primary_window.ui +0 -3673
  28. accusleepy-0.5.0/accusleepy/gui/text/config_guide.txt +0 -29
  29. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/__init__.py +0 -0
  30. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/__main__.py +0 -0
  31. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/bouts.py +0 -0
  32. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/brain_state_set.py +0 -0
  33. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/__init__.py +0 -0
  34. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/icons/brightness_down.png +0 -0
  35. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/icons/brightness_up.png +0 -0
  36. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/icons/double_down_arrow.png +0 -0
  37. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/icons/double_up_arrow.png +0 -0
  38. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/icons/down_arrow.png +0 -0
  39. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/icons/home.png +0 -0
  40. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/icons/question.png +0 -0
  41. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/icons/save.png +0 -0
  42. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/icons/up_arrow.png +0 -0
  43. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/icons/zoom_in.png +0 -0
  44. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/icons/zoom_out.png +0 -0
  45. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/resources_rc.py +0 -0
  46. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/gui/text/manual_scoring_guide.md +0 -0
  47. {accusleepy-0.5.0 → accusleepy-0.7.0}/accusleepy/multitaper.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: accusleepy
3
- Version: 0.5.0
3
+ Version: 0.7.0
4
4
  Summary: Python implementation of AccuSleep
5
5
  License: GPL-3.0-only
6
6
  Author: Zeke Barger
@@ -39,6 +39,7 @@ It offers the following improvements over the MATLAB version (AccuSleep):
39
39
  - Model files contain useful metadata (brain state configuration,
40
40
  epoch length, number of epochs)
41
41
  - Models optimized for real-time scoring can be trained
42
+ - Confidence scores can be saved and visualized
42
43
  - Lists of recordings can be imported and exported for repeatable batch processing
43
44
  - Undo/redo functionality in the manual scoring interface
44
45
 
@@ -75,6 +76,9 @@ to the [config file](accusleepy/config.json).
75
76
 
76
77
  ## Changelog
77
78
 
79
+ - 0.7.0: More settings can be configured in the UI
80
+ - 0.6.0: Confidence scores can now be displayed and saved. Retraining your models is recommended
81
+ since the new calibration feature will make the confidence scores more accurate.
78
82
  - 0.5.0: Performance improvements
79
83
  - 0.4.5: Added support for python 3.13, **removed support for python 3.10.**
80
84
  - 0.4.4: Performance improvements
@@ -93,7 +97,12 @@ Manual scoring interface
93
97
  ## Acknowledgements
94
98
 
95
99
  We would like to thank [Franz Weber](https://www.med.upenn.edu/weberlab/) for creating an
96
- early version of the manual labeling interface.
100
+ early version of the manual labeling interface. The code that
101
+ creates spectrograms comes from the
102
+ [Prerau lab](https://github.com/preraulab/multitaper_toolbox/blob/master/python/multitaper_spectrogram_python.py)
103
+ with only minor modifications.
97
104
  Jim Bohnslav's [deepethogram](https://github.com/jbohnslav/deepethogram) served as an
98
105
  incredibly useful reference when reimplementing this project in python.
106
+ The model calibration code added in version 0.6.0 comes from Geoff Pleiss'
107
+ [temperature scaling repo](https://github.com/gpleiss/temperature_scaling).
99
108
 
@@ -11,6 +11,7 @@ It offers the following improvements over the MATLAB version (AccuSleep):
11
11
  - Model files contain useful metadata (brain state configuration,
12
12
  epoch length, number of epochs)
13
13
  - Models optimized for real-time scoring can be trained
14
+ - Confidence scores can be saved and visualized
14
15
  - Lists of recordings can be imported and exported for repeatable batch processing
15
16
  - Undo/redo functionality in the manual scoring interface
16
17
 
@@ -47,6 +48,9 @@ to the [config file](accusleepy/config.json).
47
48
 
48
49
  ## Changelog
49
50
 
51
+ - 0.7.0: More settings can be configured in the UI
52
+ - 0.6.0: Confidence scores can now be displayed and saved. Retraining your models is recommended
53
+ since the new calibration feature will make the confidence scores more accurate.
50
54
  - 0.5.0: Performance improvements
51
55
  - 0.4.5: Added support for python 3.13, **removed support for python 3.10.**
52
56
  - 0.4.4: Performance improvements
@@ -65,6 +69,11 @@ Manual scoring interface
65
69
  ## Acknowledgements
66
70
 
67
71
  We would like to thank [Franz Weber](https://www.med.upenn.edu/weberlab/) for creating an
68
- early version of the manual labeling interface.
72
+ early version of the manual labeling interface. The code that
73
+ creates spectrograms comes from the
74
+ [Prerau lab](https://github.com/preraulab/multitaper_toolbox/blob/master/python/multitaper_spectrogram_python.py)
75
+ with only minor modifications.
69
76
  Jim Bohnslav's [deepethogram](https://github.com/jbohnslav/deepethogram) served as an
70
77
  incredibly useful reference when reimplementing this project in python.
78
+ The model calibration code added in version 0.6.0 comes from Geoff Pleiss'
79
+ [temperature scaling repo](https://github.com/gpleiss/temperature_scaling).
@@ -11,6 +11,7 @@ from tqdm import trange
11
11
 
12
12
  import accusleepy.constants as c
13
13
  from accusleepy.brain_state_set import BrainStateSet
14
+ from accusleepy.fileio import EMGFilter, Hyperparameters
14
15
  from accusleepy.models import SSANN
15
16
  from accusleepy.signal_processing import (
16
17
  create_eeg_emg_image,
@@ -19,11 +20,6 @@ from accusleepy.signal_processing import (
19
20
  mixture_z_score_img,
20
21
  )
21
22
 
22
- BATCH_SIZE = 64
23
- LEARNING_RATE = 1e-3
24
- MOMENTUM = 0.9
25
- TRAINING_EPOCHS = 6
26
-
27
23
 
28
24
  class AccuSleepImageDataset(Dataset):
29
25
  """Dataset for loading AccuSleep training images"""
@@ -61,11 +57,36 @@ def get_device():
61
57
  )
62
58
 
63
59
 
60
def create_dataloader(
    annotations_file: str,
    img_dir: str,
    hyperparameters: Hyperparameters,
    shuffle: bool = True,
) -> DataLoader:
    """Wrap a set of training or calibration images in a DataLoader

    :param annotations_file: file with information on each image
    :param img_dir: directory containing the images
    :param hyperparameters: model training hyperparameters; only the
        batch size is used here
    :param shuffle: whether to reshuffle the data at every epoch
    :return: DataLoader over the images
    """
    dataset = AccuSleepImageDataset(annotations_file=annotations_file, img_dir=img_dir)
    batch_size = hyperparameters.batch_size
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
82
+
83
+
64
84
  def train_ssann(
65
85
  annotations_file: str,
66
86
  img_dir: str,
67
87
  mixture_weights: np.array,
68
88
  n_classes: int,
89
+ hyperparameters: Hyperparameters,
69
90
  ) -> SSANN:
70
91
  """Train a SSANN classification model for sleep scoring
71
92
 
@@ -73,13 +94,14 @@ def train_ssann(
73
94
  :param img_dir: training image location
74
95
  :param mixture_weights: typical relative frequencies of brain states
75
96
  :param n_classes: number of classes the model will learn
97
+ :param hyperparameters: model training hyperparameters
76
98
  :return: trained Sleep Scoring Artificial Neural Network model
77
99
  """
78
- training_data = AccuSleepImageDataset(
100
+ train_dataloader = create_dataloader(
79
101
  annotations_file=annotations_file,
80
102
  img_dir=img_dir,
103
+ hyperparameters=hyperparameters,
81
104
  )
82
- train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
83
105
 
84
106
  device = get_device()
85
107
  model = SSANN(n_classes=n_classes)
@@ -90,9 +112,13 @@ def train_ssann(
90
112
  weight = torch.tensor((mixture_weights**-1).astype("float32")).to(device)
91
113
 
92
114
  criterion = nn.CrossEntropyLoss(weight=weight)
93
- optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
115
+ optimizer = optim.SGD(
116
+ model.parameters(),
117
+ lr=hyperparameters.learning_rate,
118
+ momentum=hyperparameters.momentum,
119
+ )
94
120
 
95
- for _ in trange(TRAINING_EPOCHS):
121
+ for _ in trange(hyperparameters.training_epochs):
96
122
  for data in train_dataloader:
97
123
  inputs, labels = data
98
124
  (inputs, labels) = (inputs.to(device), labels.to(device))
@@ -115,6 +141,7 @@ def score_recording(
115
141
  epoch_length: int | float,
116
142
  epochs_per_img: int,
117
143
  brain_state_set: BrainStateSet,
144
+ emg_filter: EMGFilter,
118
145
  ) -> np.array:
119
146
  """Use classification model to get brain state labels for a recording
120
147
 
@@ -130,7 +157,8 @@ def score_recording(
130
157
  :param epoch_length: epoch length, in seconds
131
158
  :param epochs_per_img: number of epochs for the model to consider
132
159
  :param brain_state_set: set of brain state options
133
- :return: brain state labels
160
+ :param emg_filter: EMG filter parameters
161
+ :return: brain state labels, confidence scores
134
162
  """
135
163
  # prepare model
136
164
  device = get_device()
@@ -138,7 +166,7 @@ def score_recording(
138
166
  model.eval()
139
167
 
140
168
  # create and scale eeg+emg spectrogram
141
- img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
169
+ img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
142
170
  img = mixture_z_score_img(
143
171
  img,
144
172
  mixture_means=mixture_means,
@@ -158,10 +186,12 @@ def score_recording(
158
186
  # perform classification
159
187
  with torch.no_grad():
160
188
  outputs = model(images)
161
- _, predicted = torch.max(outputs, 1)
189
+ logits, predicted = torch.max(outputs, 1)
162
190
 
163
191
  labels = brain_state_set.convert_class_to_digit(predicted.cpu().numpy())
164
- return labels
192
+ confidence_scores = 1 / (1 + np.e ** (-logits.cpu().numpy()))
193
+
194
+ return labels, confidence_scores
165
195
 
166
196
 
167
197
  def example_real_time_scoring_function(
@@ -174,6 +204,7 @@ def example_real_time_scoring_function(
174
204
  epoch_length: int | float,
175
205
  epochs_per_img: int,
176
206
  brain_state_set: BrainStateSet,
207
+ emg_filter: EMGFilter,
177
208
  ) -> int:
178
209
  """Example function that could be used for real-time scoring
179
210
 
@@ -202,6 +233,7 @@ def example_real_time_scoring_function(
202
233
  :param epoch_length: epoch length, in seconds
203
234
  :param epochs_per_img: number of epochs shown to the model at once
204
235
  :param brain_state_set: set of brain state options
236
+ :param emg_filter: EMG filter parameters
205
237
  :return: brain state label
206
238
  """
207
239
  # prepare model
@@ -211,7 +243,7 @@ def example_real_time_scoring_function(
211
243
  model.eval()
212
244
 
213
245
  # create and scale eeg+emg spectrogram
214
- img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
246
+ img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
215
247
  img = mixture_z_score_img(
216
248
  img,
217
249
  mixture_means=mixture_means,
@@ -242,6 +274,7 @@ def create_calibration_file(
242
274
  sampling_rate: int | float,
243
275
  epoch_length: int | float,
244
276
  brain_state_set: BrainStateSet,
277
+ emg_filter: EMGFilter,
245
278
  ) -> None:
246
279
  """Create file of calibration data for a subject
247
280
 
@@ -255,8 +288,9 @@ def create_calibration_file(
255
288
  :param sampling_rate: sampling rate, in Hz
256
289
  :param epoch_length: epoch length, in seconds
257
290
  :param brain_state_set: set of brain state options
291
+ :param emg_filter: EMG filter parameters
258
292
  """
259
- img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
293
+ img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
260
294
  mixture_means, mixture_sds = get_mixture_values(
261
295
  img=img,
262
296
  labels=brain_state_set.convert_digit_to_class(labels),
@@ -19,5 +19,19 @@
19
19
  "frequency": 0.55
20
20
  }
21
21
  ],
22
- "default_epoch_length": 2.5
22
+ "default_epoch_length": 2.5,
23
+ "default_overwrite_setting": false,
24
+ "save_confidence_setting": true,
25
+ "default_min_bout_length": 5.0,
26
+ "emg_filter": {
27
+ "order": 8,
28
+ "bp_lower": 20.0,
29
+ "bp_upper": 50.0
30
+ },
31
+ "hyperparameters": {
32
+ "batch_size": 64,
33
+ "learning_rate": 0.001,
34
+ "momentum": 0.9,
35
+ "training_epochs": 6
36
+ }
23
37
  }
@@ -8,6 +8,7 @@ EEG_COL = "eeg"
8
8
  EMG_COL = "emg"
9
9
  # label file columns
10
10
  BRAIN_STATE_COL = "brain_state"
11
+ CONFIDENCE_SCORE_COL = "confidence_score"
11
12
 
12
13
 
13
14
  # really don't change these
@@ -35,7 +36,33 @@ LABEL_COL = "label"
35
36
  # recording list file header:
36
37
  RECORDING_LIST_NAME = "recording_list"
37
38
  RECORDING_LIST_FILE_TYPE = ".json"
38
- # key for default epoch length in config
39
- DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
40
39
  # filename used to store info about training image datasets
41
40
  ANNOTATIONS_FILENAME = "annotations.csv"
41
+ # filename for annotation file for the calibration set
42
+ CALIBRATION_ANNOTATION_FILENAME = "calibration_set.csv"
43
+
44
+ # config file keys
45
+ # ui setting keys
46
+ DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
47
+ DEFAULT_CONFIDENCE_SETTING_KEY = "save_confidence_setting"
48
+ DEFAULT_MIN_BOUT_LENGTH_KEY = "default_min_bout_length"
49
+ DEFAULT_OVERWRITE_KEY = "default_overwrite_setting"
50
+ # EMG filter parameters key
51
+ EMG_FILTER_KEY = "emg_filter"
52
+ # model training hyperparameters key
53
+ HYPERPARAMETERS_KEY = "hyperparameters"
54
+
55
+ # default values
56
+ # default UI settings
57
+ DEFAULT_MIN_BOUT_LENGTH = 5.0
58
+ DEFAULT_CONFIDENCE_SETTING = True
59
+ DEFAULT_OVERWRITE_SETTING = False
60
+ # default EMG filter parameters (order, bandpass frequencies)
61
+ DEFAULT_EMG_FILTER_ORDER = 8
62
+ DEFAULT_EMG_BP_LOWER = 20
63
+ DEFAULT_EMG_BP_UPPER = 50
64
+ # default hyperparameters
65
+ DEFAULT_BATCH_SIZE = 64
66
+ DEFAULT_LEARNING_RATE = 1e-3
67
+ DEFAULT_MOMENTUM = 0.9
68
+ DEFAULT_TRAINING_EPOCHS = 6
@@ -0,0 +1,230 @@
1
+ import json
2
+ import os
3
+ from dataclasses import dataclass
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from PySide6.QtWidgets import QListWidgetItem
8
+
9
+ from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
10
+ import accusleepy.constants as c
11
+
12
+
13
+ @dataclass
14
+ class EMGFilter:
15
+ """Convenience class for a EMG filter parameters"""
16
+
17
+ order: int # filter order
18
+ bp_lower: int | float # lower bandpass frequency
19
+ bp_upper: int | float # upper bandpass frequency
20
+
21
+
22
@dataclass()
class Hyperparameters:
    """Bundle of the settings used when training a model

    Field names match the keys of the "hyperparameters" section of the
    config file.
    """

    # number of images per training batch
    batch_size: int
    # learning rate passed to the optimizer
    learning_rate: float
    # momentum passed to the optimizer
    momentum: float
    # number of passes over the training data
    training_epochs: int
30
+
31
+
32
@dataclass
class Recording:
    """Store information about a recording"""

    name: int = 1  # name to show in the GUI
    recording_file: str = ""  # path to recording file
    label_file: str = ""  # path to label file
    calibration_file: str = ""  # path to calibration file
    sampling_rate: int | float = 0.0  # sampling rate, in Hz
    # list item widget shown in the GUI; None by default, so the annotation
    # must allow None (quoted to avoid evaluating `|` on the Qt type)
    widget: "QListWidgetItem | None" = None
42
+
43
+
44
def load_calibration_file(filename: str) -> tuple[np.ndarray, np.ndarray]:
    """Load a calibration file

    :param filename: csv file containing mixture mean and mixture SD columns
    :return: mixture means and SDs
    """
    calibration = pd.read_csv(filename)
    mixture_means = calibration[c.MIXTURE_MEAN_COL].values
    mixture_sds = calibration[c.MIXTURE_SD_COL].values
    return mixture_means, mixture_sds
54
+
55
+
56
def load_csv_or_parquet(filename: str) -> pd.DataFrame:
    """Load a csv or parquet file as a dataframe

    The reader is chosen from the filename extension, ignoring case
    (e.g. ``.CSV`` is treated like ``.csv``).

    :param filename: filename
    :return: dataframe of file contents
    :raises ValueError: if the extension is neither .csv nor .parquet
    """
    extension = os.path.splitext(filename)[1].lower()
    if extension == ".csv":
        return pd.read_csv(filename)
    if extension == ".parquet":
        return pd.read_parquet(filename)
    # ValueError is a subclass of Exception, so existing broad handlers still work
    raise ValueError("file must be csv or parquet")
70
+
71
+
72
def load_recording(filename: str) -> tuple[np.ndarray, np.ndarray]:
    """Load recording of EEG and EMG time series data

    :param filename: csv or parquet file with EEG and EMG columns
    :return: arrays of EEG and EMG data
    """
    recording = load_csv_or_parquet(filename)
    eeg = recording[c.EEG_COL].values
    emg = recording[c.EMG_COL].values
    return eeg, emg
82
+
83
+
84
def load_labels(filename: str) -> tuple[np.ndarray, np.ndarray | None]:
    """Load file of brain state labels and confidence scores

    :param filename: csv or parquet file with a brain state column and,
        optionally, a confidence score column
    :return: array of brain state labels, and array of confidence scores
        (None when the file has no confidence score column)
    """
    df = load_csv_or_parquet(filename)
    labels = df[c.BRAIN_STATE_COL].values
    if c.CONFIDENCE_SCORE_COL not in df.columns:
        return labels, None
    return labels, df[c.CONFIDENCE_SCORE_COL].values
95
+
96
+
97
def save_labels(
    labels: np.ndarray, filename: str, confidence_scores: np.ndarray | None = None
) -> None:
    """Save brain state labels to file

    The output format follows the filename extension (case-insensitive):
    ``.parquet`` files are written as parquet, anything else as csv, so a
    label file saved with a .parquet extension really is in parquet format.

    :param labels: brain state labels
    :param filename: filename
    :param confidence_scores: optional confidence scores, stored as an
        extra column when provided
    """
    columns = {c.BRAIN_STATE_COL: labels}
    if confidence_scores is not None:
        columns[c.CONFIDENCE_SCORE_COL] = confidence_scores
    df = pd.DataFrame(columns)
    if os.path.splitext(filename)[1].lower() == ".parquet":
        df.to_parquet(filename, index=False)
    else:
        df.to_csv(filename, index=False)
112
+
113
+
114
def load_config() -> tuple[
    BrainStateSet, int | float, bool, bool, int | float, EMGFilter, Hyperparameters
]:
    """Load configuration file with brain state options

    The brain states and default epoch length must be present in the file.
    Other settings that are missing fall back to the defaults in the
    constants module. The EMG filter and hyperparameter sections are merged
    key-by-key with their defaults, so a config that specifies only some of
    their fields still loads.

    :return: set of brain state options,
        default epoch length,
        default overwrite setting,
        default confidence score output setting,
        default minimum bout length,
        EMG filter parameters,
        model training hyperparameters
    """
    config_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), c.CONFIG_FILE
    )
    with open(config_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    emg_filter_defaults = {
        "order": c.DEFAULT_EMG_FILTER_ORDER,
        "bp_lower": c.DEFAULT_EMG_BP_LOWER,
        "bp_upper": c.DEFAULT_EMG_BP_UPPER,
    }
    hyperparameter_defaults = {
        "batch_size": c.DEFAULT_BATCH_SIZE,
        "learning_rate": c.DEFAULT_LEARNING_RATE,
        "momentum": c.DEFAULT_MOMENTUM,
        "training_epochs": c.DEFAULT_TRAINING_EPOCHS,
    }

    return (
        BrainStateSet(
            [BrainState(**b) for b in data[BRAIN_STATES_KEY]], c.UNDEFINED_LABEL
        ),
        data[c.DEFAULT_EPOCH_LENGTH_KEY],
        data.get(c.DEFAULT_OVERWRITE_KEY, c.DEFAULT_OVERWRITE_SETTING),
        data.get(c.DEFAULT_CONFIDENCE_SETTING_KEY, c.DEFAULT_CONFIDENCE_SETTING),
        data.get(c.DEFAULT_MIN_BOUT_LENGTH_KEY, c.DEFAULT_MIN_BOUT_LENGTH),
        # values in the file override the defaults; missing keys keep defaults
        EMGFilter(**{**emg_filter_defaults, **data.get(c.EMG_FILTER_KEY, {})}),
        Hyperparameters(
            **{**hyperparameter_defaults, **data.get(c.HYPERPARAMETERS_KEY, {})}
        ),
    )
162
+
163
+
164
def save_config(
    brain_state_set: BrainStateSet,
    default_epoch_length: int | float,
    overwrite_setting: bool,
    save_confidence_setting: bool,
    min_bout_length: int | float,
    emg_filter: EMGFilter,
    hyperparameters: Hyperparameters,
) -> None:
    """Save configuration of brain state options to json file

    The configuration is serialized before the file is opened, so a value
    that cannot be written as json raises without truncating the existing
    config file.

    :param brain_state_set: set of brain state options
    :param default_epoch_length: default epoch length
    :param overwrite_setting: default setting for overwriting
        existing labels
    :param save_confidence_setting: default setting for
        saving confidence scores
    :param min_bout_length: default minimum bout length
    :param emg_filter: EMG filter parameters
    :param hyperparameters: model training hyperparameters
    """
    from dataclasses import asdict

    output_dict = brain_state_set.to_output_dict()
    output_dict.update(
        {
            c.DEFAULT_EPOCH_LENGTH_KEY: default_epoch_length,
            c.DEFAULT_OVERWRITE_KEY: overwrite_setting,
            c.DEFAULT_CONFIDENCE_SETTING_KEY: save_confidence_setting,
            c.DEFAULT_MIN_BOUT_LENGTH_KEY: min_bout_length,
            c.EMG_FILTER_KEY: asdict(emg_filter),
            c.HYPERPARAMETERS_KEY: asdict(hyperparameters),
        }
    )
    # serialize first: opening with "w" truncates the file immediately
    serialized = json.dumps(output_dict, indent=4)
    config_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), c.CONFIG_FILE
    )
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(serialized)
196
+
197
+
198
def load_recording_list(filename: str) -> list[Recording]:
    """Load list of recordings from file

    Recordings are numbered 1, 2, 3, ... in file order; this number replaces
    any name stored in the file.

    :param filename: json file containing a list of recordings
    :return: list of recordings
    """
    # explicit encoding so non-ASCII file paths decode the same on every OS
    with open(filename, "r", encoding="utf-8") as f:
        data = json.load(f)
    recording_list = [Recording(**r) for r in data[c.RECORDING_LIST_NAME]]
    for number, recording in enumerate(recording_list, start=1):
        recording.name = number
    return recording_list
210
+
211
+
212
def save_recording_list(filename: str, recordings: list[Recording]) -> None:
    """Write a list of recordings to a json file for later import

    Only the file paths and sampling rate of each recording are exported;
    the GUI name and list widget are not written.

    :param filename: where to save the list
    :param recordings: list of recordings to export
    """
    exported_fields = (
        "recording_file",
        "label_file",
        "calibration_file",
        "sampling_rate",
    )
    entries = []
    for recording in recordings:
        entries.append({field: getattr(recording, field) for field in exported_fields})
    with open(filename, "w") as f:
        json.dump({c.RECORDING_LIST_NAME: entries}, f, indent=4)