accusleepy 0.6.0__tar.gz → 0.7.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {accusleepy-0.6.0 → accusleepy-0.7.1}/PKG-INFO +4 -1
- {accusleepy-0.6.0 → accusleepy-0.7.1}/README.md +3 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/classification.py +29 -13
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/config.json +14 -1
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/constants.py +44 -6
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/fileio.py +87 -36
- accusleepy-0.7.1/accusleepy/gui/images/primary_window.png +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/main.py +133 -163
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/manual_scoring.py +45 -47
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/primary_window.py +760 -135
- accusleepy-0.7.1/accusleepy/gui/primary_window.ui +4643 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/text/main_guide.md +2 -1
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/models.py +1 -12
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/signal_processing.py +18 -17
- accusleepy-0.7.1/accusleepy/validation.py +128 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/pyproject.toml +1 -1
- accusleepy-0.6.0/accusleepy/gui/images/primary_window.png +0 -0
- accusleepy-0.6.0/accusleepy/gui/primary_window.ui +0 -3831
- accusleepy-0.6.0/accusleepy/gui/text/config_guide.txt +0 -27
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/__init__.py +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/__main__.py +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/bouts.py +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/brain_state_set.py +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/__init__.py +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/brightness_down.png +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/brightness_up.png +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/double_down_arrow.png +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/double_up_arrow.png +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/down_arrow.png +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/home.png +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/question.png +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/save.png +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/up_arrow.png +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/zoom_in.png +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/icons/zoom_out.png +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/images/viewer_window.png +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/images/viewer_window_annotated.png +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/mplwidget.py +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/resources.qrc +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/resources_rc.py +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/text/manual_scoring_guide.md +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/viewer_window.py +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/gui/viewer_window.ui +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/multitaper.py +0 -0
- {accusleepy-0.6.0 → accusleepy-0.7.1}/accusleepy/temperature_scaling.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: accusleepy
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.7.1
|
|
4
4
|
Summary: Python implementation of AccuSleep
|
|
5
5
|
License: GPL-3.0-only
|
|
6
6
|
Author: Zeke Barger
|
|
@@ -39,6 +39,7 @@ It offers the following improvements over the MATLAB version (AccuSleep):
|
|
|
39
39
|
- Model files contain useful metadata (brain state configuration,
|
|
40
40
|
epoch length, number of epochs)
|
|
41
41
|
- Models optimized for real-time scoring can be trained
|
|
42
|
+
- Confidence scores can be saved and visualized
|
|
42
43
|
- Lists of recordings can be imported and exported for repeatable batch processing
|
|
43
44
|
- Undo/redo functionality in the manual scoring interface
|
|
44
45
|
|
|
@@ -75,6 +76,8 @@ to the [config file](accusleepy/config.json).
|
|
|
75
76
|
|
|
76
77
|
## Changelog
|
|
77
78
|
|
|
79
|
+
- 0.7.1: Bugfixes, code cleanup
|
|
80
|
+
- 0.7.0: More settings can be configured in the UI
|
|
78
81
|
- 0.6.0: Confidence scores can now be displayed and saved. Retraining your models is recommended
|
|
79
82
|
since the new calibration feature will make the confidence scores more accurate.
|
|
80
83
|
- 0.5.0: Performance improvements
|
|
@@ -11,6 +11,7 @@ It offers the following improvements over the MATLAB version (AccuSleep):
|
|
|
11
11
|
- Model files contain useful metadata (brain state configuration,
|
|
12
12
|
epoch length, number of epochs)
|
|
13
13
|
- Models optimized for real-time scoring can be trained
|
|
14
|
+
- Confidence scores can be saved and visualized
|
|
14
15
|
- Lists of recordings can be imported and exported for repeatable batch processing
|
|
15
16
|
- Undo/redo functionality in the manual scoring interface
|
|
16
17
|
|
|
@@ -47,6 +48,8 @@ to the [config file](accusleepy/config.json).
|
|
|
47
48
|
|
|
48
49
|
## Changelog
|
|
49
50
|
|
|
51
|
+
- 0.7.1: Bugfixes, code cleanup
|
|
52
|
+
- 0.7.0: More settings can be configured in the UI
|
|
50
53
|
- 0.6.0: Confidence scores can now be displayed and saved. Retraining your models is recommended
|
|
51
54
|
since the new calibration feature will make the confidence scores more accurate.
|
|
52
55
|
- 0.5.0: Performance improvements
|
|
@@ -11,6 +11,7 @@ from tqdm import trange
|
|
|
11
11
|
|
|
12
12
|
import accusleepy.constants as c
|
|
13
13
|
from accusleepy.brain_state_set import BrainStateSet
|
|
14
|
+
from accusleepy.fileio import EMGFilter, Hyperparameters
|
|
14
15
|
from accusleepy.models import SSANN
|
|
15
16
|
from accusleepy.signal_processing import (
|
|
16
17
|
create_eeg_emg_image,
|
|
@@ -19,11 +20,6 @@ from accusleepy.signal_processing import (
|
|
|
19
20
|
mixture_z_score_img,
|
|
20
21
|
)
|
|
21
22
|
|
|
22
|
-
BATCH_SIZE = 64
|
|
23
|
-
LEARNING_RATE = 1e-3
|
|
24
|
-
MOMENTUM = 0.9
|
|
25
|
-
TRAINING_EPOCHS = 6
|
|
26
|
-
|
|
27
23
|
|
|
28
24
|
class AccuSleepImageDataset(Dataset):
|
|
29
25
|
"""Dataset for loading AccuSleep training images"""
|
|
@@ -62,12 +58,16 @@ def get_device():
|
|
|
62
58
|
|
|
63
59
|
|
|
64
60
|
def create_dataloader(
|
|
65
|
-
annotations_file: str,
|
|
61
|
+
annotations_file: str,
|
|
62
|
+
img_dir: str,
|
|
63
|
+
hyperparameters: Hyperparameters,
|
|
64
|
+
shuffle: bool = True,
|
|
66
65
|
) -> DataLoader:
|
|
67
66
|
"""Create DataLoader for a dataset of training or calibration images
|
|
68
67
|
|
|
69
68
|
:param annotations_file: file with information on each training image
|
|
70
69
|
:param img_dir: training image location
|
|
70
|
+
:param hyperparameters: model training hyperparameters
|
|
71
71
|
:param shuffle: reshuffle data for every epoch
|
|
72
72
|
:return: DataLoader for the data
|
|
73
73
|
|
|
@@ -76,7 +76,9 @@ def create_dataloader(
|
|
|
76
76
|
annotations_file=annotations_file,
|
|
77
77
|
img_dir=img_dir,
|
|
78
78
|
)
|
|
79
|
-
return DataLoader(
|
|
79
|
+
return DataLoader(
|
|
80
|
+
image_dataset, batch_size=hyperparameters.batch_size, shuffle=shuffle
|
|
81
|
+
)
|
|
80
82
|
|
|
81
83
|
|
|
82
84
|
def train_ssann(
|
|
@@ -84,6 +86,7 @@ def train_ssann(
|
|
|
84
86
|
img_dir: str,
|
|
85
87
|
mixture_weights: np.array,
|
|
86
88
|
n_classes: int,
|
|
89
|
+
hyperparameters: Hyperparameters,
|
|
87
90
|
) -> SSANN:
|
|
88
91
|
"""Train a SSANN classification model for sleep scoring
|
|
89
92
|
|
|
@@ -91,10 +94,13 @@ def train_ssann(
|
|
|
91
94
|
:param img_dir: training image location
|
|
92
95
|
:param mixture_weights: typical relative frequencies of brain states
|
|
93
96
|
:param n_classes: number of classes the model will learn
|
|
97
|
+
:param hyperparameters: model training hyperparameters
|
|
94
98
|
:return: trained Sleep Scoring Artificial Neural Network model
|
|
95
99
|
"""
|
|
96
100
|
train_dataloader = create_dataloader(
|
|
97
|
-
annotations_file=annotations_file,
|
|
101
|
+
annotations_file=annotations_file,
|
|
102
|
+
img_dir=img_dir,
|
|
103
|
+
hyperparameters=hyperparameters,
|
|
98
104
|
)
|
|
99
105
|
|
|
100
106
|
device = get_device()
|
|
@@ -106,9 +112,13 @@ def train_ssann(
|
|
|
106
112
|
weight = torch.tensor((mixture_weights**-1).astype("float32")).to(device)
|
|
107
113
|
|
|
108
114
|
criterion = nn.CrossEntropyLoss(weight=weight)
|
|
109
|
-
optimizer = optim.SGD(
|
|
115
|
+
optimizer = optim.SGD(
|
|
116
|
+
model.parameters(),
|
|
117
|
+
lr=hyperparameters.learning_rate,
|
|
118
|
+
momentum=hyperparameters.momentum,
|
|
119
|
+
)
|
|
110
120
|
|
|
111
|
-
for _ in trange(
|
|
121
|
+
for _ in trange(hyperparameters.training_epochs):
|
|
112
122
|
for data in train_dataloader:
|
|
113
123
|
inputs, labels = data
|
|
114
124
|
(inputs, labels) = (inputs.to(device), labels.to(device))
|
|
@@ -131,6 +141,7 @@ def score_recording(
|
|
|
131
141
|
epoch_length: int | float,
|
|
132
142
|
epochs_per_img: int,
|
|
133
143
|
brain_state_set: BrainStateSet,
|
|
144
|
+
emg_filter: EMGFilter,
|
|
134
145
|
) -> np.array:
|
|
135
146
|
"""Use classification model to get brain state labels for a recording
|
|
136
147
|
|
|
@@ -146,6 +157,7 @@ def score_recording(
|
|
|
146
157
|
:param epoch_length: epoch length, in seconds
|
|
147
158
|
:param epochs_per_img: number of epochs for the model to consider
|
|
148
159
|
:param brain_state_set: set of brain state options
|
|
160
|
+
:param emg_filter: EMG filter parameters
|
|
149
161
|
:return: brain state labels, confidence scores
|
|
150
162
|
"""
|
|
151
163
|
# prepare model
|
|
@@ -154,7 +166,7 @@ def score_recording(
|
|
|
154
166
|
model.eval()
|
|
155
167
|
|
|
156
168
|
# create and scale eeg+emg spectrogram
|
|
157
|
-
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
|
|
169
|
+
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
|
|
158
170
|
img = mixture_z_score_img(
|
|
159
171
|
img,
|
|
160
172
|
mixture_means=mixture_means,
|
|
@@ -192,6 +204,7 @@ def example_real_time_scoring_function(
|
|
|
192
204
|
epoch_length: int | float,
|
|
193
205
|
epochs_per_img: int,
|
|
194
206
|
brain_state_set: BrainStateSet,
|
|
207
|
+
emg_filter: EMGFilter,
|
|
195
208
|
) -> int:
|
|
196
209
|
"""Example function that could be used for real-time scoring
|
|
197
210
|
|
|
@@ -220,6 +233,7 @@ def example_real_time_scoring_function(
|
|
|
220
233
|
:param epoch_length: epoch length, in seconds
|
|
221
234
|
:param epochs_per_img: number of epochs shown to the model at once
|
|
222
235
|
:param brain_state_set: set of brain state options
|
|
236
|
+
:param emg_filter: EMG filter parameters
|
|
223
237
|
:return: brain state label
|
|
224
238
|
"""
|
|
225
239
|
# prepare model
|
|
@@ -229,7 +243,7 @@ def example_real_time_scoring_function(
|
|
|
229
243
|
model.eval()
|
|
230
244
|
|
|
231
245
|
# create and scale eeg+emg spectrogram
|
|
232
|
-
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
|
|
246
|
+
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
|
|
233
247
|
img = mixture_z_score_img(
|
|
234
248
|
img,
|
|
235
249
|
mixture_means=mixture_means,
|
|
@@ -260,6 +274,7 @@ def create_calibration_file(
|
|
|
260
274
|
sampling_rate: int | float,
|
|
261
275
|
epoch_length: int | float,
|
|
262
276
|
brain_state_set: BrainStateSet,
|
|
277
|
+
emg_filter: EMGFilter,
|
|
263
278
|
) -> None:
|
|
264
279
|
"""Create file of calibration data for a subject
|
|
265
280
|
|
|
@@ -273,8 +288,9 @@ def create_calibration_file(
|
|
|
273
288
|
:param sampling_rate: sampling rate, in Hz
|
|
274
289
|
:param epoch_length: epoch length, in seconds
|
|
275
290
|
:param brain_state_set: set of brain state options
|
|
291
|
+
:param emg_filter: EMG filter parameters
|
|
276
292
|
"""
|
|
277
|
-
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
|
|
293
|
+
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
|
|
278
294
|
mixture_means, mixture_sds = get_mixture_values(
|
|
279
295
|
img=img,
|
|
280
296
|
labels=brain_state_set.convert_digit_to_class(labels),
|
|
@@ -20,5 +20,18 @@
|
|
|
20
20
|
}
|
|
21
21
|
],
|
|
22
22
|
"default_epoch_length": 2.5,
|
|
23
|
-
"
|
|
23
|
+
"default_overwrite_setting": false,
|
|
24
|
+
"save_confidence_setting": true,
|
|
25
|
+
"default_min_bout_length": 5.0,
|
|
26
|
+
"emg_filter": {
|
|
27
|
+
"order": 8,
|
|
28
|
+
"bp_lower": 20.0,
|
|
29
|
+
"bp_upper": 50.0
|
|
30
|
+
},
|
|
31
|
+
"hyperparameters": {
|
|
32
|
+
"batch_size": 64,
|
|
33
|
+
"learning_rate": 0.001,
|
|
34
|
+
"momentum": 0.9,
|
|
35
|
+
"training_epochs": 6
|
|
36
|
+
}
|
|
24
37
|
}
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
1
3
|
# probably don't change these unless you really need to
|
|
2
4
|
UNDEFINED_LABEL = -1 # can't be the same as a brain state's digit, must be an integer
|
|
3
5
|
# calibration file columns
|
|
@@ -9,9 +11,16 @@ EMG_COL = "emg"
|
|
|
9
11
|
# label file columns
|
|
10
12
|
BRAIN_STATE_COL = "brain_state"
|
|
11
13
|
CONFIDENCE_SCORE_COL = "confidence_score"
|
|
14
|
+
# max number of messages to store in main window message box
|
|
15
|
+
MESSAGE_BOX_MAX_DEPTH = 200
|
|
16
|
+
# clip mixture z-scores above and below this level
|
|
17
|
+
# in the matlab implementation, 4.5 was used
|
|
18
|
+
ABS_MAX_Z_SCORE = 3.5
|
|
19
|
+
# upper frequency limit when generating EEG spectrograms
|
|
20
|
+
SPECTROGRAM_UPPER_FREQ = 64
|
|
12
21
|
|
|
13
22
|
|
|
14
|
-
#
|
|
23
|
+
# very unlikely you will want to change values from here onwards
|
|
15
24
|
# config file location
|
|
16
25
|
CONFIG_FILE = "config.json"
|
|
17
26
|
# number of times to include the EMG power in a training image
|
|
@@ -20,8 +29,15 @@ EMG_COPIES = 9
|
|
|
20
29
|
MIN_WINDOW_LEN = 5
|
|
21
30
|
# frequency above which to downsample EEG spectrograms
|
|
22
31
|
DOWNSAMPLING_START_FREQ = 20
|
|
23
|
-
#
|
|
32
|
+
# highest EEG frequency used as model input
|
|
24
33
|
UPPER_FREQ = 50
|
|
34
|
+
# height in pixels of each training image
|
|
35
|
+
IMAGE_HEIGHT = (
|
|
36
|
+
len(np.arange(0, DOWNSAMPLING_START_FREQ, 1 / MIN_WINDOW_LEN))
|
|
37
|
+
+ len(np.arange(DOWNSAMPLING_START_FREQ, UPPER_FREQ, 2 / MIN_WINDOW_LEN))
|
|
38
|
+
+ EMG_COPIES
|
|
39
|
+
)
|
|
40
|
+
|
|
25
41
|
# classification model types
|
|
26
42
|
DEFAULT_MODEL_TYPE = "default" # current epoch is centered
|
|
27
43
|
REAL_TIME_MODEL_TYPE = "real-time" # current epoch on the right
|
|
@@ -36,11 +52,33 @@ LABEL_COL = "label"
|
|
|
36
52
|
# recording list file header:
|
|
37
53
|
RECORDING_LIST_NAME = "recording_list"
|
|
38
54
|
RECORDING_LIST_FILE_TYPE = ".json"
|
|
39
|
-
# key for default epoch length in config
|
|
40
|
-
DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
|
|
41
|
-
# key used for default confidence score behavior in config
|
|
42
|
-
DEFAULT_CONFIDENCE_SETTING_KEY = "save_confidence_setting"
|
|
43
55
|
# filename used to store info about training image datasets
|
|
44
56
|
ANNOTATIONS_FILENAME = "annotations.csv"
|
|
45
57
|
# filename for annotation file for the calibration set
|
|
46
58
|
CALIBRATION_ANNOTATION_FILENAME = "calibration_set.csv"
|
|
59
|
+
|
|
60
|
+
# config file keys
|
|
61
|
+
# ui setting keys
|
|
62
|
+
DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
|
|
63
|
+
DEFAULT_CONFIDENCE_SETTING_KEY = "save_confidence_setting"
|
|
64
|
+
DEFAULT_MIN_BOUT_LENGTH_KEY = "default_min_bout_length"
|
|
65
|
+
DEFAULT_OVERWRITE_KEY = "default_overwrite_setting"
|
|
66
|
+
# EMG filter parameters key
|
|
67
|
+
EMG_FILTER_KEY = "emg_filter"
|
|
68
|
+
# model training hyperparameters key
|
|
69
|
+
HYPERPARAMETERS_KEY = "hyperparameters"
|
|
70
|
+
|
|
71
|
+
# default values
|
|
72
|
+
# default UI settings
|
|
73
|
+
DEFAULT_MIN_BOUT_LENGTH = 5.0
|
|
74
|
+
DEFAULT_CONFIDENCE_SETTING = True
|
|
75
|
+
DEFAULT_OVERWRITE_SETTING = False
|
|
76
|
+
# default EMG filter parameters (order, bandpass frequencies)
|
|
77
|
+
DEFAULT_EMG_FILTER_ORDER = 8
|
|
78
|
+
DEFAULT_EMG_BP_LOWER = 20
|
|
79
|
+
DEFAULT_EMG_BP_UPPER = 50
|
|
80
|
+
# default hyperparameters
|
|
81
|
+
DEFAULT_BATCH_SIZE = 64
|
|
82
|
+
DEFAULT_LEARNING_RATE = 1e-3
|
|
83
|
+
DEFAULT_MOMENTUM = 0.9
|
|
84
|
+
DEFAULT_TRAINING_EPOCHS = 6
|
|
@@ -7,19 +7,26 @@ import pandas as pd
|
|
|
7
7
|
from PySide6.QtWidgets import QListWidgetItem
|
|
8
8
|
|
|
9
9
|
from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
10
|
+
import accusleepy.constants as c
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
|
|
14
|
+
class EMGFilter:
|
|
15
|
+
"""Convenience class for a EMG filter parameters"""
|
|
16
|
+
|
|
17
|
+
order: int # filter order
|
|
18
|
+
bp_lower: int | float # lower bandpass frequency
|
|
19
|
+
bp_upper: int | float # upper bandpass frequency
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
|
|
23
|
+
class Hyperparameters:
|
|
24
|
+
"""Convenience class for model training hyperparameters"""
|
|
25
|
+
|
|
26
|
+
batch_size: int
|
|
27
|
+
learning_rate: float
|
|
28
|
+
momentum: float
|
|
29
|
+
training_epochs: int
|
|
23
30
|
|
|
24
31
|
|
|
25
32
|
@dataclass
|
|
@@ -41,8 +48,8 @@ def load_calibration_file(filename: str) -> (np.array, np.array):
|
|
|
41
48
|
:return: mixture means and SDs
|
|
42
49
|
"""
|
|
43
50
|
df = pd.read_csv(filename)
|
|
44
|
-
mixture_means = df[MIXTURE_MEAN_COL].values
|
|
45
|
-
mixture_sds = df[MIXTURE_SD_COL].values
|
|
51
|
+
mixture_means = df[c.MIXTURE_MEAN_COL].values
|
|
52
|
+
mixture_sds = df[c.MIXTURE_SD_COL].values
|
|
46
53
|
return mixture_means, mixture_sds
|
|
47
54
|
|
|
48
55
|
|
|
@@ -69,8 +76,8 @@ def load_recording(filename: str) -> (np.array, np.array):
|
|
|
69
76
|
:return: arrays of EEG and EMG data
|
|
70
77
|
"""
|
|
71
78
|
df = load_csv_or_parquet(filename)
|
|
72
|
-
eeg = df[EEG_COL].values
|
|
73
|
-
emg = df[EMG_COL].values
|
|
79
|
+
eeg = df[c.EEG_COL].values
|
|
80
|
+
emg = df[c.EMG_COL].values
|
|
74
81
|
return eeg, emg
|
|
75
82
|
|
|
76
83
|
|
|
@@ -81,10 +88,10 @@ def load_labels(filename: str) -> (np.array, np.array):
|
|
|
81
88
|
:return: array of brain state labels and, optionally, array of confidence scores
|
|
82
89
|
"""
|
|
83
90
|
df = load_csv_or_parquet(filename)
|
|
84
|
-
if CONFIDENCE_SCORE_COL in df.columns:
|
|
85
|
-
return df[BRAIN_STATE_COL].values, df[CONFIDENCE_SCORE_COL].values
|
|
91
|
+
if c.CONFIDENCE_SCORE_COL in df.columns:
|
|
92
|
+
return df[c.BRAIN_STATE_COL].values, df[c.CONFIDENCE_SCORE_COL].values
|
|
86
93
|
else:
|
|
87
|
-
return df[BRAIN_STATE_COL].values, None
|
|
94
|
+
return df[c.BRAIN_STATE_COL].values, None
|
|
88
95
|
|
|
89
96
|
|
|
90
97
|
def save_labels(
|
|
@@ -98,48 +105,92 @@ def save_labels(
|
|
|
98
105
|
"""
|
|
99
106
|
if confidence_scores is not None:
|
|
100
107
|
pd.DataFrame(
|
|
101
|
-
{BRAIN_STATE_COL: labels, CONFIDENCE_SCORE_COL: confidence_scores}
|
|
108
|
+
{c.BRAIN_STATE_COL: labels, c.CONFIDENCE_SCORE_COL: confidence_scores}
|
|
102
109
|
).to_csv(filename, index=False)
|
|
103
110
|
else:
|
|
104
|
-
pd.DataFrame({BRAIN_STATE_COL: labels}).to_csv(filename, index=False)
|
|
111
|
+
pd.DataFrame({c.BRAIN_STATE_COL: labels}).to_csv(filename, index=False)
|
|
105
112
|
|
|
106
113
|
|
|
107
|
-
def load_config() -> tuple[
|
|
114
|
+
def load_config() -> tuple[
|
|
115
|
+
BrainStateSet, int | float, bool, bool, int | float, EMGFilter, Hyperparameters
|
|
116
|
+
]:
|
|
108
117
|
"""Load configuration file with brain state options
|
|
109
118
|
|
|
110
|
-
:return: set of brain state options,
|
|
119
|
+
:return: set of brain state options,
|
|
120
|
+
default epoch length,
|
|
121
|
+
default overwrite setting,
|
|
122
|
+
default confidence score output setting,
|
|
123
|
+
default minimum bout length,
|
|
124
|
+
EMG filter parameters,
|
|
125
|
+
model training hyperparameters
|
|
111
126
|
"""
|
|
112
127
|
with open(
|
|
113
|
-
os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "r"
|
|
128
|
+
os.path.join(os.path.dirname(os.path.abspath(__file__)), c.CONFIG_FILE), "r"
|
|
114
129
|
) as f:
|
|
115
130
|
data = json.load(f)
|
|
116
131
|
|
|
117
132
|
return (
|
|
118
133
|
BrainStateSet(
|
|
119
|
-
[BrainState(**b) for b in data[BRAIN_STATES_KEY]], UNDEFINED_LABEL
|
|
134
|
+
[BrainState(**b) for b in data[BRAIN_STATES_KEY]], c.UNDEFINED_LABEL
|
|
135
|
+
),
|
|
136
|
+
data[c.DEFAULT_EPOCH_LENGTH_KEY],
|
|
137
|
+
data.get(c.DEFAULT_OVERWRITE_KEY, c.DEFAULT_OVERWRITE_SETTING),
|
|
138
|
+
data.get(c.DEFAULT_CONFIDENCE_SETTING_KEY, c.DEFAULT_CONFIDENCE_SETTING),
|
|
139
|
+
data.get(c.DEFAULT_MIN_BOUT_LENGTH_KEY, c.DEFAULT_MIN_BOUT_LENGTH),
|
|
140
|
+
EMGFilter(
|
|
141
|
+
**data.get(
|
|
142
|
+
c.EMG_FILTER_KEY,
|
|
143
|
+
{
|
|
144
|
+
"order": c.DEFAULT_EMG_FILTER_ORDER,
|
|
145
|
+
"bp_lower": c.DEFAULT_EMG_BP_LOWER,
|
|
146
|
+
"bp_upper": c.DEFAULT_EMG_BP_UPPER,
|
|
147
|
+
},
|
|
148
|
+
)
|
|
149
|
+
),
|
|
150
|
+
Hyperparameters(
|
|
151
|
+
**data.get(
|
|
152
|
+
c.HYPERPARAMETERS_KEY,
|
|
153
|
+
{
|
|
154
|
+
"batch_size": c.DEFAULT_BATCH_SIZE,
|
|
155
|
+
"learning_rate": c.DEFAULT_LEARNING_RATE,
|
|
156
|
+
"momentum": c.DEFAULT_MOMENTUM,
|
|
157
|
+
"training_epochs": c.DEFAULT_TRAINING_EPOCHS,
|
|
158
|
+
},
|
|
159
|
+
)
|
|
120
160
|
),
|
|
121
|
-
data[DEFAULT_EPOCH_LENGTH_KEY],
|
|
122
|
-
data.get(DEFAULT_CONFIDENCE_SETTING_KEY, True),
|
|
123
161
|
)
|
|
124
162
|
|
|
125
163
|
|
|
126
164
|
def save_config(
|
|
127
165
|
brain_state_set: BrainStateSet,
|
|
128
166
|
default_epoch_length: int | float,
|
|
167
|
+
overwrite_setting: bool,
|
|
129
168
|
save_confidence_setting: bool,
|
|
169
|
+
min_bout_length: int | float,
|
|
170
|
+
emg_filter: EMGFilter,
|
|
171
|
+
hyperparameters: Hyperparameters,
|
|
130
172
|
) -> None:
|
|
131
173
|
"""Save configuration of brain state options to json file
|
|
132
174
|
|
|
133
175
|
:param brain_state_set: set of brain state options
|
|
134
|
-
:param default_epoch_length: epoch length
|
|
135
|
-
:param save_confidence_setting:
|
|
136
|
-
|
|
176
|
+
:param default_epoch_length: default epoch length
|
|
177
|
+
:param save_confidence_setting: default setting for
|
|
178
|
+
saving confidence scores
|
|
179
|
+
:param emg_filter: EMG filter parameters
|
|
180
|
+
:param min_bout_length: default minimum bout length
|
|
181
|
+
:param overwrite_setting: default setting for overwriting
|
|
182
|
+
existing labels
|
|
183
|
+
:param hyperparameters: model training hyperparameters
|
|
137
184
|
"""
|
|
138
185
|
output_dict = brain_state_set.to_output_dict()
|
|
139
|
-
output_dict.update({DEFAULT_EPOCH_LENGTH_KEY: default_epoch_length})
|
|
140
|
-
output_dict.update({
|
|
186
|
+
output_dict.update({c.DEFAULT_EPOCH_LENGTH_KEY: default_epoch_length})
|
|
187
|
+
output_dict.update({c.DEFAULT_OVERWRITE_KEY: overwrite_setting})
|
|
188
|
+
output_dict.update({c.DEFAULT_CONFIDENCE_SETTING_KEY: save_confidence_setting})
|
|
189
|
+
output_dict.update({c.DEFAULT_MIN_BOUT_LENGTH_KEY: min_bout_length})
|
|
190
|
+
output_dict.update({c.EMG_FILTER_KEY: emg_filter.__dict__})
|
|
191
|
+
output_dict.update({c.HYPERPARAMETERS_KEY: hyperparameters.__dict__})
|
|
141
192
|
with open(
|
|
142
|
-
os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "w"
|
|
193
|
+
os.path.join(os.path.dirname(os.path.abspath(__file__)), c.CONFIG_FILE), "w"
|
|
143
194
|
) as f:
|
|
144
195
|
json.dump(output_dict, f, indent=4)
|
|
145
196
|
|
|
@@ -152,7 +203,7 @@ def load_recording_list(filename: str) -> list[Recording]:
|
|
|
152
203
|
"""
|
|
153
204
|
with open(filename, "r") as f:
|
|
154
205
|
data = json.load(f)
|
|
155
|
-
recording_list = [Recording(**r) for r in data[RECORDING_LIST_NAME]]
|
|
206
|
+
recording_list = [Recording(**r) for r in data[c.RECORDING_LIST_NAME]]
|
|
156
207
|
for i, r in enumerate(recording_list):
|
|
157
208
|
r.name = i + 1
|
|
158
209
|
return recording_list
|
|
@@ -165,7 +216,7 @@ def save_recording_list(filename: str, recordings: list[Recording]) -> None:
|
|
|
165
216
|
:param recordings: list of recordings to export
|
|
166
217
|
"""
|
|
167
218
|
recording_dict = {
|
|
168
|
-
RECORDING_LIST_NAME: [
|
|
219
|
+
c.RECORDING_LIST_NAME: [
|
|
169
220
|
{
|
|
170
221
|
"recording_file": r.recording_file,
|
|
171
222
|
"label_file": r.label_file,
|
|
Binary file
|