accusleepy 0.6.0__tar.gz → 0.7.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. {accusleepy-0.6.0 → accusleepy-0.7.1}/PKG-INFO +4 -1
  2. {accusleepy-0.6.0 → accusleepy-0.7.1}/README.md +3 -0
  3. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/classification.py +29 -13
  4. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/config.json +14 -1
  5. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/constants.py +44 -6
  6. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/fileio.py +87 -36
  7. accusleepy-0.7.1/accusleepy/gui/images/primary_window.png +0 -0
  8. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/main.py +133 -163
  9. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/manual_scoring.py +45 -47
  10. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/primary_window.py +760 -135
  11. accusleepy-0.7.1/accusleepy/gui/primary_window.ui +4643 -0
  12. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/text/main_guide.md +2 -1
  13. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/models.py +1 -12
  14. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/signal_processing.py +18 -17
  15. accusleepy-0.7.1/accusleepy/validation.py +128 -0
  16. {accusleepy-0.6.0 → accusleepy-0.7.1}/pyproject.toml +1 -1
  17. accusleepy-0.6.0/accusleepy/gui/images/primary_window.png +0 -0
  18. accusleepy-0.6.0/accusleepy/gui/primary_window.ui +0 -3831
  19. accusleepy-0.6.0/accusleepy/gui/text/config_guide.txt +0 -27
  20. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/__init__.py +0 -0
  21. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/__main__.py +0 -0
  22. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/bouts.py +0 -0
  23. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/brain_state_set.py +0 -0
  24. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/__init__.py +0 -0
  25. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/brightness_down.png +0 -0
  26. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/brightness_up.png +0 -0
  27. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/double_down_arrow.png +0 -0
  28. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/double_up_arrow.png +0 -0
  29. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/down_arrow.png +0 -0
  30. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/home.png +0 -0
  31. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/question.png +0 -0
  32. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/save.png +0 -0
  33. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/up_arrow.png +0 -0
  34. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/zoom_in.png +0 -0
  35. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/zoom_out.png +0 -0
  36. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/images/viewer_window.png +0 -0
  37. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/images/viewer_window_annotated.png +0 -0
  38. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/mplwidget.py +0 -0
  39. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/resources.qrc +0 -0
  40. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/resources_rc.py +0 -0
  41. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/text/manual_scoring_guide.md +0 -0
  42. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/viewer_window.py +0 -0
  43. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/viewer_window.ui +0 -0
  44. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/multitaper.py +0 -0
  45. {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/temperature_scaling.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: accusleepy
3
- Version: 0.6.0
3
+ Version: 0.7.1
4
4
  Summary: Python implementation of AccuSleep
5
5
  License: GPL-3.0-only
6
6
  Author: Zeke Barger
@@ -39,6 +39,7 @@ It offers the following improvements over the MATLAB version (AccuSleep):
39
39
  - Model files contain useful metadata (brain state configuration,
40
40
  epoch length, number of epochs)
41
41
  - Models optimized for real-time scoring can be trained
42
+ - Confidence scores can be saved and visualized
42
43
  - Lists of recordings can be imported and exported for repeatable batch processing
43
44
  - Undo/redo functionality in the manual scoring interface
44
45
 
@@ -75,6 +76,8 @@ to the [config file](accusleepy/config.json).
75
76
 
76
77
  ## Changelog
77
78
 
79
+ - 0.7.1: Bugfixes, code cleanup
80
+ - 0.7.0: More settings can be configured in the UI
78
81
  - 0.6.0: Confidence scores can now be displayed and saved. Retraining your models is recommended
79
82
  since the new calibration feature will make the confidence scores more accurate.
80
83
  - 0.5.0: Performance improvements
@@ -11,6 +11,7 @@ It offers the following improvements over the MATLAB version (AccuSleep):
11
11
  - Model files contain useful metadata (brain state configuration,
12
12
  epoch length, number of epochs)
13
13
  - Models optimized for real-time scoring can be trained
14
+ - Confidence scores can be saved and visualized
14
15
  - Lists of recordings can be imported and exported for repeatable batch processing
15
16
  - Undo/redo functionality in the manual scoring interface
16
17
 
@@ -47,6 +48,8 @@ to the [config file](accusleepy/config.json).
47
48
 
48
49
  ## Changelog
49
50
 
51
+ - 0.7.1: Bugfixes, code cleanup
52
+ - 0.7.0: More settings can be configured in the UI
50
53
  - 0.6.0: Confidence scores can now be displayed and saved. Retraining your models is recommended
51
54
  since the new calibration feature will make the confidence scores more accurate.
52
55
  - 0.5.0: Performance improvements
@@ -11,6 +11,7 @@ from tqdm import trange
11
11
 
12
12
  import accusleepy.constants as c
13
13
  from accusleepy.brain_state_set import BrainStateSet
14
+ from accusleepy.fileio import EMGFilter, Hyperparameters
14
15
  from accusleepy.models import SSANN
15
16
  from accusleepy.signal_processing import (
16
17
  create_eeg_emg_image,
@@ -19,11 +20,6 @@ from accusleepy.signal_processing import (
19
20
  mixture_z_score_img,
20
21
  )
21
22
 
22
- BATCH_SIZE = 64
23
- LEARNING_RATE = 1e-3
24
- MOMENTUM = 0.9
25
- TRAINING_EPOCHS = 6
26
-
27
23
 
28
24
  class AccuSleepImageDataset(Dataset):
29
25
  """Dataset for loading AccuSleep training images"""
@@ -62,12 +58,16 @@ def get_device():
62
58
 
63
59
 
64
60
  def create_dataloader(
65
- annotations_file: str, img_dir: str, shuffle: bool = True
61
+ annotations_file: str,
62
+ img_dir: str,
63
+ hyperparameters: Hyperparameters,
64
+ shuffle: bool = True,
66
65
  ) -> DataLoader:
67
66
  """Create DataLoader for a dataset of training or calibration images
68
67
 
69
68
  :param annotations_file: file with information on each training image
70
69
  :param img_dir: training image location
70
+ :param hyperparameters: model training hyperparameters
71
71
  :param shuffle: reshuffle data for every epoch
72
72
  :return: DataLoader for the data
73
73
 
@@ -76,7 +76,9 @@ def create_dataloader(
76
76
  annotations_file=annotations_file,
77
77
  img_dir=img_dir,
78
78
  )
79
- return DataLoader(image_dataset, batch_size=BATCH_SIZE, shuffle=shuffle)
79
+ return DataLoader(
80
+ image_dataset, batch_size=hyperparameters.batch_size, shuffle=shuffle
81
+ )
80
82
 
81
83
 
82
84
  def train_ssann(
@@ -84,6 +86,7 @@ def train_ssann(
84
86
  img_dir: str,
85
87
  mixture_weights: np.array,
86
88
  n_classes: int,
89
+ hyperparameters: Hyperparameters,
87
90
  ) -> SSANN:
88
91
  """Train a SSANN classification model for sleep scoring
89
92
 
@@ -91,10 +94,13 @@ def train_ssann(
91
94
  :param img_dir: training image location
92
95
  :param mixture_weights: typical relative frequencies of brain states
93
96
  :param n_classes: number of classes the model will learn
97
+ :param hyperparameters: model training hyperparameters
94
98
  :return: trained Sleep Scoring Artificial Neural Network model
95
99
  """
96
100
  train_dataloader = create_dataloader(
97
- annotations_file=annotations_file, img_dir=img_dir
101
+ annotations_file=annotations_file,
102
+ img_dir=img_dir,
103
+ hyperparameters=hyperparameters,
98
104
  )
99
105
 
100
106
  device = get_device()
@@ -106,9 +112,13 @@ def train_ssann(
106
112
  weight = torch.tensor((mixture_weights**-1).astype("float32")).to(device)
107
113
 
108
114
  criterion = nn.CrossEntropyLoss(weight=weight)
109
- optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
115
+ optimizer = optim.SGD(
116
+ model.parameters(),
117
+ lr=hyperparameters.learning_rate,
118
+ momentum=hyperparameters.momentum,
119
+ )
110
120
 
111
- for _ in trange(TRAINING_EPOCHS):
121
+ for _ in trange(hyperparameters.training_epochs):
112
122
  for data in train_dataloader:
113
123
  inputs, labels = data
114
124
  (inputs, labels) = (inputs.to(device), labels.to(device))
@@ -131,6 +141,7 @@ def score_recording(
131
141
  epoch_length: int | float,
132
142
  epochs_per_img: int,
133
143
  brain_state_set: BrainStateSet,
144
+ emg_filter: EMGFilter,
134
145
  ) -> np.array:
135
146
  """Use classification model to get brain state labels for a recording
136
147
 
@@ -146,6 +157,7 @@ def score_recording(
146
157
  :param epoch_length: epoch length, in seconds
147
158
  :param epochs_per_img: number of epochs for the model to consider
148
159
  :param brain_state_set: set of brain state options
160
+ :param emg_filter: EMG filter parameters
149
161
  :return: brain state labels, confidence scores
150
162
  """
151
163
  # prepare model
@@ -154,7 +166,7 @@ def score_recording(
154
166
  model.eval()
155
167
 
156
168
  # create and scale eeg+emg spectrogram
157
- img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
169
+ img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
158
170
  img = mixture_z_score_img(
159
171
  img,
160
172
  mixture_means=mixture_means,
@@ -192,6 +204,7 @@ def example_real_time_scoring_function(
192
204
  epoch_length: int | float,
193
205
  epochs_per_img: int,
194
206
  brain_state_set: BrainStateSet,
207
+ emg_filter: EMGFilter,
195
208
  ) -> int:
196
209
  """Example function that could be used for real-time scoring
197
210
 
@@ -220,6 +233,7 @@ def example_real_time_scoring_function(
220
233
  :param epoch_length: epoch length, in seconds
221
234
  :param epochs_per_img: number of epochs shown to the model at once
222
235
  :param brain_state_set: set of brain state options
236
+ :param emg_filter: EMG filter parameters
223
237
  :return: brain state label
224
238
  """
225
239
  # prepare model
@@ -229,7 +243,7 @@ def example_real_time_scoring_function(
229
243
  model.eval()
230
244
 
231
245
  # create and scale eeg+emg spectrogram
232
- img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
246
+ img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
233
247
  img = mixture_z_score_img(
234
248
  img,
235
249
  mixture_means=mixture_means,
@@ -260,6 +274,7 @@ def create_calibration_file(
260
274
  sampling_rate: int | float,
261
275
  epoch_length: int | float,
262
276
  brain_state_set: BrainStateSet,
277
+ emg_filter: EMGFilter,
263
278
  ) -> None:
264
279
  """Create file of calibration data for a subject
265
280
 
@@ -273,8 +288,9 @@ def create_calibration_file(
273
288
  :param sampling_rate: sampling rate, in Hz
274
289
  :param epoch_length: epoch length, in seconds
275
290
  :param brain_state_set: set of brain state options
291
+ :param emg_filter: EMG filter parameters
276
292
  """
277
- img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
293
+ img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
278
294
  mixture_means, mixture_sds = get_mixture_values(
279
295
  img=img,
280
296
  labels=brain_state_set.convert_digit_to_class(labels),
@@ -20,5 +20,18 @@
20
20
  }
21
21
  ],
22
22
  "default_epoch_length": 2.5,
23
- "save_confidence_setting": true
23
+ "default_overwrite_setting": false,
24
+ "save_confidence_setting": true,
25
+ "default_min_bout_length": 5.0,
26
+ "emg_filter": {
27
+ "order": 8,
28
+ "bp_lower": 20.0,
29
+ "bp_upper": 50.0
30
+ },
31
+ "hyperparameters": {
32
+ "batch_size": 64,
33
+ "learning_rate": 0.001,
34
+ "momentum": 0.9,
35
+ "training_epochs": 6
36
+ }
24
37
  }
@@ -1,3 +1,5 @@
1
+ import numpy as np
2
+
1
3
  # probably don't change these unless you really need to
2
4
  UNDEFINED_LABEL = -1 # can't be the same as a brain state's digit, must be an integer
3
5
  # calibration file columns
@@ -9,9 +11,16 @@ EMG_COL = "emg"
9
11
  # label file columns
10
12
  BRAIN_STATE_COL = "brain_state"
11
13
  CONFIDENCE_SCORE_COL = "confidence_score"
14
+ # max number of messages to store in main window message box
15
+ MESSAGE_BOX_MAX_DEPTH = 200
16
+ # clip mixture z-scores above and below this level
17
+ # in the matlab implementation, 4.5 was used
18
+ ABS_MAX_Z_SCORE = 3.5
19
+ # upper frequency limit when generating EEG spectrograms
20
+ SPECTROGRAM_UPPER_FREQ = 64
12
21
 
13
22
 
14
- # really don't change these
23
+ # very unlikely you will want to change values from here onwards
15
24
  # config file location
16
25
  CONFIG_FILE = "config.json"
17
26
  # number of times to include the EMG power in a training image
@@ -20,8 +29,15 @@ EMG_COPIES = 9
20
29
  MIN_WINDOW_LEN = 5
21
30
  # frequency above which to downsample EEG spectrograms
22
31
  DOWNSAMPLING_START_FREQ = 20
23
- # upper frequency cutoff for EEG spectrograms
32
+ # highest EEG frequency used as model input
24
33
  UPPER_FREQ = 50
34
+ # height in pixels of each training image
35
+ IMAGE_HEIGHT = (
36
+ len(np.arange(0, DOWNSAMPLING_START_FREQ, 1 / MIN_WINDOW_LEN))
37
+ + len(np.arange(DOWNSAMPLING_START_FREQ, UPPER_FREQ, 2 / MIN_WINDOW_LEN))
38
+ + EMG_COPIES
39
+ )
40
+
25
41
  # classification model types
26
42
  DEFAULT_MODEL_TYPE = "default" # current epoch is centered
27
43
  REAL_TIME_MODEL_TYPE = "real-time" # current epoch on the right
@@ -36,11 +52,33 @@ LABEL_COL = "label"
36
52
  # recording list file header:
37
53
  RECORDING_LIST_NAME = "recording_list"
38
54
  RECORDING_LIST_FILE_TYPE = ".json"
39
- # key for default epoch length in config
40
- DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
41
- # key used for default confidence score behavior in config
42
- DEFAULT_CONFIDENCE_SETTING_KEY = "save_confidence_setting"
43
55
  # filename used to store info about training image datasets
44
56
  ANNOTATIONS_FILENAME = "annotations.csv"
45
57
  # filename for annotation file for the calibration set
46
58
  CALIBRATION_ANNOTATION_FILENAME = "calibration_set.csv"
59
+
60
+ # config file keys
61
+ # ui setting keys
62
+ DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
63
+ DEFAULT_CONFIDENCE_SETTING_KEY = "save_confidence_setting"
64
+ DEFAULT_MIN_BOUT_LENGTH_KEY = "default_min_bout_length"
65
+ DEFAULT_OVERWRITE_KEY = "default_overwrite_setting"
66
+ # EMG filter parameters key
67
+ EMG_FILTER_KEY = "emg_filter"
68
+ # model training hyperparameters key
69
+ HYPERPARAMETERS_KEY = "hyperparameters"
70
+
71
+ # default values
72
+ # default UI settings
73
+ DEFAULT_MIN_BOUT_LENGTH = 5.0
74
+ DEFAULT_CONFIDENCE_SETTING = True
75
+ DEFAULT_OVERWRITE_SETTING = False
76
+ # default EMG filter parameters (order, bandpass frequencies)
77
+ DEFAULT_EMG_FILTER_ORDER = 8
78
+ DEFAULT_EMG_BP_LOWER = 20
79
+ DEFAULT_EMG_BP_UPPER = 50
80
+ # default hyperparameters
81
+ DEFAULT_BATCH_SIZE = 64
82
+ DEFAULT_LEARNING_RATE = 1e-3
83
+ DEFAULT_MOMENTUM = 0.9
84
+ DEFAULT_TRAINING_EPOCHS = 6
@@ -7,19 +7,26 @@ import pandas as pd
7
7
  from PySide6.QtWidgets import QListWidgetItem
8
8
 
9
9
  from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
10
- from accusleepy.constants import (
11
- BRAIN_STATE_COL,
12
- CONFIDENCE_SCORE_COL,
13
- CONFIG_FILE,
14
- DEFAULT_CONFIDENCE_SETTING_KEY,
15
- DEFAULT_EPOCH_LENGTH_KEY,
16
- EEG_COL,
17
- EMG_COL,
18
- MIXTURE_MEAN_COL,
19
- MIXTURE_SD_COL,
20
- RECORDING_LIST_NAME,
21
- UNDEFINED_LABEL,
22
- )
10
+ import accusleepy.constants as c
11
+
12
+
13
+ @dataclass
14
+ class EMGFilter:
15
+ """Convenience class for a EMG filter parameters"""
16
+
17
+ order: int # filter order
18
+ bp_lower: int | float # lower bandpass frequency
19
+ bp_upper: int | float # upper bandpass frequency
20
+
21
+
22
+ @dataclass
23
+ class Hyperparameters:
24
+ """Convenience class for model training hyperparameters"""
25
+
26
+ batch_size: int
27
+ learning_rate: float
28
+ momentum: float
29
+ training_epochs: int
23
30
 
24
31
 
25
32
  @dataclass
@@ -41,8 +48,8 @@ def load_calibration_file(filename: str) -> (np.array, np.array):
41
48
  :return: mixture means and SDs
42
49
  """
43
50
  df = pd.read_csv(filename)
44
- mixture_means = df[MIXTURE_MEAN_COL].values
45
- mixture_sds = df[MIXTURE_SD_COL].values
51
+ mixture_means = df[c.MIXTURE_MEAN_COL].values
52
+ mixture_sds = df[c.MIXTURE_SD_COL].values
46
53
  return mixture_means, mixture_sds
47
54
 
48
55
 
@@ -69,8 +76,8 @@ def load_recording(filename: str) -> (np.array, np.array):
69
76
  :return: arrays of EEG and EMG data
70
77
  """
71
78
  df = load_csv_or_parquet(filename)
72
- eeg = df[EEG_COL].values
73
- emg = df[EMG_COL].values
79
+ eeg = df[c.EEG_COL].values
80
+ emg = df[c.EMG_COL].values
74
81
  return eeg, emg
75
82
 
76
83
 
@@ -81,10 +88,10 @@ def load_labels(filename: str) -> (np.array, np.array):
81
88
  :return: array of brain state labels and, optionally, array of confidence scores
82
89
  """
83
90
  df = load_csv_or_parquet(filename)
84
- if CONFIDENCE_SCORE_COL in df.columns:
85
- return df[BRAIN_STATE_COL].values, df[CONFIDENCE_SCORE_COL].values
91
+ if c.CONFIDENCE_SCORE_COL in df.columns:
92
+ return df[c.BRAIN_STATE_COL].values, df[c.CONFIDENCE_SCORE_COL].values
86
93
  else:
87
- return df[BRAIN_STATE_COL].values, None
94
+ return df[c.BRAIN_STATE_COL].values, None
88
95
 
89
96
 
90
97
  def save_labels(
@@ -98,48 +105,92 @@ def save_labels(
98
105
  """
99
106
  if confidence_scores is not None:
100
107
  pd.DataFrame(
101
- {BRAIN_STATE_COL: labels, CONFIDENCE_SCORE_COL: confidence_scores}
108
+ {c.BRAIN_STATE_COL: labels, c.CONFIDENCE_SCORE_COL: confidence_scores}
102
109
  ).to_csv(filename, index=False)
103
110
  else:
104
- pd.DataFrame({BRAIN_STATE_COL: labels}).to_csv(filename, index=False)
111
+ pd.DataFrame({c.BRAIN_STATE_COL: labels}).to_csv(filename, index=False)
105
112
 
106
113
 
107
- def load_config() -> tuple[BrainStateSet, int | float, bool]:
114
+ def load_config() -> tuple[
115
+ BrainStateSet, int | float, bool, bool, int | float, EMGFilter, Hyperparameters
116
+ ]:
108
117
  """Load configuration file with brain state options
109
118
 
110
- :return: set of brain state options, other settings
119
+ :return: set of brain state options,
120
+ default epoch length,
121
+ default overwrite setting,
122
+ default confidence score output setting,
123
+ default minimum bout length,
124
+ EMG filter parameters,
125
+ model training hyperparameters
111
126
  """
112
127
  with open(
113
- os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "r"
128
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), c.CONFIG_FILE), "r"
114
129
  ) as f:
115
130
  data = json.load(f)
116
131
 
117
132
  return (
118
133
  BrainStateSet(
119
- [BrainState(**b) for b in data[BRAIN_STATES_KEY]], UNDEFINED_LABEL
134
+ [BrainState(**b) for b in data[BRAIN_STATES_KEY]], c.UNDEFINED_LABEL
135
+ ),
136
+ data[c.DEFAULT_EPOCH_LENGTH_KEY],
137
+ data.get(c.DEFAULT_OVERWRITE_KEY, c.DEFAULT_OVERWRITE_SETTING),
138
+ data.get(c.DEFAULT_CONFIDENCE_SETTING_KEY, c.DEFAULT_CONFIDENCE_SETTING),
139
+ data.get(c.DEFAULT_MIN_BOUT_LENGTH_KEY, c.DEFAULT_MIN_BOUT_LENGTH),
140
+ EMGFilter(
141
+ **data.get(
142
+ c.EMG_FILTER_KEY,
143
+ {
144
+ "order": c.DEFAULT_EMG_FILTER_ORDER,
145
+ "bp_lower": c.DEFAULT_EMG_BP_LOWER,
146
+ "bp_upper": c.DEFAULT_EMG_BP_UPPER,
147
+ },
148
+ )
149
+ ),
150
+ Hyperparameters(
151
+ **data.get(
152
+ c.HYPERPARAMETERS_KEY,
153
+ {
154
+ "batch_size": c.DEFAULT_BATCH_SIZE,
155
+ "learning_rate": c.DEFAULT_LEARNING_RATE,
156
+ "momentum": c.DEFAULT_MOMENTUM,
157
+ "training_epochs": c.DEFAULT_TRAINING_EPOCHS,
158
+ },
159
+ )
120
160
  ),
121
- data[DEFAULT_EPOCH_LENGTH_KEY],
122
- data.get(DEFAULT_CONFIDENCE_SETTING_KEY, True),
123
161
  )
124
162
 
125
163
 
126
164
  def save_config(
127
165
  brain_state_set: BrainStateSet,
128
166
  default_epoch_length: int | float,
167
+ overwrite_setting: bool,
129
168
  save_confidence_setting: bool,
169
+ min_bout_length: int | float,
170
+ emg_filter: EMGFilter,
171
+ hyperparameters: Hyperparameters,
130
172
  ) -> None:
131
173
  """Save configuration of brain state options to json file
132
174
 
133
175
  :param brain_state_set: set of brain state options
134
- :param default_epoch_length: epoch length to use when the GUI starts
135
- :param save_confidence_setting: whether the option to save confidence
136
- scores should be True by default
176
+ :param default_epoch_length: default epoch length
177
+ :param save_confidence_setting: default setting for
178
+ saving confidence scores
179
+ :param emg_filter: EMG filter parameters
180
+ :param min_bout_length: default minimum bout length
181
+ :param overwrite_setting: default setting for overwriting
182
+ existing labels
183
+ :param hyperparameters: model training hyperparameters
137
184
  """
138
185
  output_dict = brain_state_set.to_output_dict()
139
- output_dict.update({DEFAULT_EPOCH_LENGTH_KEY: default_epoch_length})
140
- output_dict.update({DEFAULT_CONFIDENCE_SETTING_KEY: save_confidence_setting})
186
+ output_dict.update({c.DEFAULT_EPOCH_LENGTH_KEY: default_epoch_length})
187
+ output_dict.update({c.DEFAULT_OVERWRITE_KEY: overwrite_setting})
188
+ output_dict.update({c.DEFAULT_CONFIDENCE_SETTING_KEY: save_confidence_setting})
189
+ output_dict.update({c.DEFAULT_MIN_BOUT_LENGTH_KEY: min_bout_length})
190
+ output_dict.update({c.EMG_FILTER_KEY: emg_filter.__dict__})
191
+ output_dict.update({c.HYPERPARAMETERS_KEY: hyperparameters.__dict__})
141
192
  with open(
142
- os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "w"
193
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), c.CONFIG_FILE), "w"
143
194
  ) as f:
144
195
  json.dump(output_dict, f, indent=4)
145
196
 
@@ -152,7 +203,7 @@ def load_recording_list(filename: str) -> list[Recording]:
152
203
  """
153
204
  with open(filename, "r") as f:
154
205
  data = json.load(f)
155
- recording_list = [Recording(**r) for r in data[RECORDING_LIST_NAME]]
206
+ recording_list = [Recording(**r) for r in data[c.RECORDING_LIST_NAME]]
156
207
  for i, r in enumerate(recording_list):
157
208
  r.name = i + 1
158
209
  return recording_list
@@ -165,7 +216,7 @@ def save_recording_list(filename: str, recordings: list[Recording]) -> None:
165
216
  :param recordings: list of recordings to export
166
217
  """
167
218
  recording_dict = {
168
- RECORDING_LIST_NAME: [
219
+ c.RECORDING_LIST_NAME: [
169
220
  {
170
221
  "recording_file": r.recording_file,
171
222
  "label_file": r.label_file,