accusleepy 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accusleepy/classification.py +49 -15
- accusleepy/config.json +15 -1
- accusleepy/constants.py +29 -2
- accusleepy/fileio.py +107 -33
- accusleepy/gui/images/primary_window.png +0 -0
- accusleepy/gui/images/viewer_window.png +0 -0
- accusleepy/gui/images/viewer_window_annotated.png +0 -0
- accusleepy/gui/main.py +220 -42
- accusleepy/gui/manual_scoring.py +38 -8
- accusleepy/gui/mplwidget.py +54 -29
- accusleepy/gui/primary_window.py +937 -254
- accusleepy/gui/primary_window.ui +3182 -2227
- accusleepy/gui/resources.qrc +1 -1
- accusleepy/gui/text/main_guide.md +18 -12
- accusleepy/gui/viewer_window.py +19 -7
- accusleepy/gui/viewer_window.ui +34 -2
- accusleepy/models.py +11 -1
- accusleepy/signal_processing.py +40 -17
- accusleepy/temperature_scaling.py +157 -0
- {accusleepy-0.5.0.dist-info → accusleepy-0.7.0.dist-info}/METADATA +11 -2
- accusleepy-0.7.0.dist-info/RECORD +41 -0
- {accusleepy-0.5.0.dist-info → accusleepy-0.7.0.dist-info}/WHEEL +1 -1
- accusleepy/gui/text/config_guide.txt +0 -29
- accusleepy-0.5.0.dist-info/RECORD +0 -41
accusleepy/classification.py
CHANGED
|
@@ -11,6 +11,7 @@ from tqdm import trange
|
|
|
11
11
|
|
|
12
12
|
import accusleepy.constants as c
|
|
13
13
|
from accusleepy.brain_state_set import BrainStateSet
|
|
14
|
+
from accusleepy.fileio import EMGFilter, Hyperparameters
|
|
14
15
|
from accusleepy.models import SSANN
|
|
15
16
|
from accusleepy.signal_processing import (
|
|
16
17
|
create_eeg_emg_image,
|
|
@@ -19,11 +20,6 @@ from accusleepy.signal_processing import (
|
|
|
19
20
|
mixture_z_score_img,
|
|
20
21
|
)
|
|
21
22
|
|
|
22
|
-
BATCH_SIZE = 64
|
|
23
|
-
LEARNING_RATE = 1e-3
|
|
24
|
-
MOMENTUM = 0.9
|
|
25
|
-
TRAINING_EPOCHS = 6
|
|
26
|
-
|
|
27
23
|
|
|
28
24
|
class AccuSleepImageDataset(Dataset):
|
|
29
25
|
"""Dataset for loading AccuSleep training images"""
|
|
@@ -61,11 +57,36 @@ def get_device():
|
|
|
61
57
|
)
|
|
62
58
|
|
|
63
59
|
|
|
60
|
+
def create_dataloader(
|
|
61
|
+
annotations_file: str,
|
|
62
|
+
img_dir: str,
|
|
63
|
+
hyperparameters: Hyperparameters,
|
|
64
|
+
shuffle: bool = True,
|
|
65
|
+
) -> DataLoader:
|
|
66
|
+
"""Create DataLoader for a dataset of training or calibration images
|
|
67
|
+
|
|
68
|
+
:param annotations_file: file with information on each training image
|
|
69
|
+
:param img_dir: training image location
|
|
70
|
+
:param hyperparameters: model training hyperparameters
|
|
71
|
+
:param shuffle: reshuffle data for every epoch
|
|
72
|
+
:return: DataLoader for the data
|
|
73
|
+
|
|
74
|
+
"""
|
|
75
|
+
image_dataset = AccuSleepImageDataset(
|
|
76
|
+
annotations_file=annotations_file,
|
|
77
|
+
img_dir=img_dir,
|
|
78
|
+
)
|
|
79
|
+
return DataLoader(
|
|
80
|
+
image_dataset, batch_size=hyperparameters.batch_size, shuffle=shuffle
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
|
|
64
84
|
def train_ssann(
|
|
65
85
|
annotations_file: str,
|
|
66
86
|
img_dir: str,
|
|
67
87
|
mixture_weights: np.array,
|
|
68
88
|
n_classes: int,
|
|
89
|
+
hyperparameters: Hyperparameters,
|
|
69
90
|
) -> SSANN:
|
|
70
91
|
"""Train a SSANN classification model for sleep scoring
|
|
71
92
|
|
|
@@ -73,13 +94,14 @@ def train_ssann(
|
|
|
73
94
|
:param img_dir: training image location
|
|
74
95
|
:param mixture_weights: typical relative frequencies of brain states
|
|
75
96
|
:param n_classes: number of classes the model will learn
|
|
97
|
+
:param hyperparameters: model training hyperparameters
|
|
76
98
|
:return: trained Sleep Scoring Artificial Neural Network model
|
|
77
99
|
"""
|
|
78
|
-
|
|
100
|
+
train_dataloader = create_dataloader(
|
|
79
101
|
annotations_file=annotations_file,
|
|
80
102
|
img_dir=img_dir,
|
|
103
|
+
hyperparameters=hyperparameters,
|
|
81
104
|
)
|
|
82
|
-
train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
|
|
83
105
|
|
|
84
106
|
device = get_device()
|
|
85
107
|
model = SSANN(n_classes=n_classes)
|
|
@@ -90,9 +112,13 @@ def train_ssann(
|
|
|
90
112
|
weight = torch.tensor((mixture_weights**-1).astype("float32")).to(device)
|
|
91
113
|
|
|
92
114
|
criterion = nn.CrossEntropyLoss(weight=weight)
|
|
93
|
-
optimizer = optim.SGD(
|
|
115
|
+
optimizer = optim.SGD(
|
|
116
|
+
model.parameters(),
|
|
117
|
+
lr=hyperparameters.learning_rate,
|
|
118
|
+
momentum=hyperparameters.momentum,
|
|
119
|
+
)
|
|
94
120
|
|
|
95
|
-
for _ in trange(
|
|
121
|
+
for _ in trange(hyperparameters.training_epochs):
|
|
96
122
|
for data in train_dataloader:
|
|
97
123
|
inputs, labels = data
|
|
98
124
|
(inputs, labels) = (inputs.to(device), labels.to(device))
|
|
@@ -115,6 +141,7 @@ def score_recording(
|
|
|
115
141
|
epoch_length: int | float,
|
|
116
142
|
epochs_per_img: int,
|
|
117
143
|
brain_state_set: BrainStateSet,
|
|
144
|
+
emg_filter: EMGFilter,
|
|
118
145
|
) -> np.array:
|
|
119
146
|
"""Use classification model to get brain state labels for a recording
|
|
120
147
|
|
|
@@ -130,7 +157,8 @@ def score_recording(
|
|
|
130
157
|
:param epoch_length: epoch length, in seconds
|
|
131
158
|
:param epochs_per_img: number of epochs for the model to consider
|
|
132
159
|
:param brain_state_set: set of brain state options
|
|
133
|
-
:
|
|
160
|
+
:param emg_filter: EMG filter parameters
|
|
161
|
+
:return: brain state labels, confidence scores
|
|
134
162
|
"""
|
|
135
163
|
# prepare model
|
|
136
164
|
device = get_device()
|
|
@@ -138,7 +166,7 @@ def score_recording(
|
|
|
138
166
|
model.eval()
|
|
139
167
|
|
|
140
168
|
# create and scale eeg+emg spectrogram
|
|
141
|
-
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
|
|
169
|
+
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
|
|
142
170
|
img = mixture_z_score_img(
|
|
143
171
|
img,
|
|
144
172
|
mixture_means=mixture_means,
|
|
@@ -158,10 +186,12 @@ def score_recording(
|
|
|
158
186
|
# perform classification
|
|
159
187
|
with torch.no_grad():
|
|
160
188
|
outputs = model(images)
|
|
161
|
-
|
|
189
|
+
logits, predicted = torch.max(outputs, 1)
|
|
162
190
|
|
|
163
191
|
labels = brain_state_set.convert_class_to_digit(predicted.cpu().numpy())
|
|
164
|
-
|
|
192
|
+
confidence_scores = 1 / (1 + np.e ** (-logits.cpu().numpy()))
|
|
193
|
+
|
|
194
|
+
return labels, confidence_scores
|
|
165
195
|
|
|
166
196
|
|
|
167
197
|
def example_real_time_scoring_function(
|
|
@@ -174,6 +204,7 @@ def example_real_time_scoring_function(
|
|
|
174
204
|
epoch_length: int | float,
|
|
175
205
|
epochs_per_img: int,
|
|
176
206
|
brain_state_set: BrainStateSet,
|
|
207
|
+
emg_filter: EMGFilter,
|
|
177
208
|
) -> int:
|
|
178
209
|
"""Example function that could be used for real-time scoring
|
|
179
210
|
|
|
@@ -202,6 +233,7 @@ def example_real_time_scoring_function(
|
|
|
202
233
|
:param epoch_length: epoch length, in seconds
|
|
203
234
|
:param epochs_per_img: number of epochs shown to the model at once
|
|
204
235
|
:param brain_state_set: set of brain state options
|
|
236
|
+
:param emg_filter: EMG filter parameters
|
|
205
237
|
:return: brain state label
|
|
206
238
|
"""
|
|
207
239
|
# prepare model
|
|
@@ -211,7 +243,7 @@ def example_real_time_scoring_function(
|
|
|
211
243
|
model.eval()
|
|
212
244
|
|
|
213
245
|
# create and scale eeg+emg spectrogram
|
|
214
|
-
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
|
|
246
|
+
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
|
|
215
247
|
img = mixture_z_score_img(
|
|
216
248
|
img,
|
|
217
249
|
mixture_means=mixture_means,
|
|
@@ -242,6 +274,7 @@ def create_calibration_file(
|
|
|
242
274
|
sampling_rate: int | float,
|
|
243
275
|
epoch_length: int | float,
|
|
244
276
|
brain_state_set: BrainStateSet,
|
|
277
|
+
emg_filter: EMGFilter,
|
|
245
278
|
) -> None:
|
|
246
279
|
"""Create file of calibration data for a subject
|
|
247
280
|
|
|
@@ -255,8 +288,9 @@ def create_calibration_file(
|
|
|
255
288
|
:param sampling_rate: sampling rate, in Hz
|
|
256
289
|
:param epoch_length: epoch length, in seconds
|
|
257
290
|
:param brain_state_set: set of brain state options
|
|
291
|
+
:param emg_filter: EMG filter parameters
|
|
258
292
|
"""
|
|
259
|
-
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
|
|
293
|
+
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
|
|
260
294
|
mixture_means, mixture_sds = get_mixture_values(
|
|
261
295
|
img=img,
|
|
262
296
|
labels=brain_state_set.convert_digit_to_class(labels),
|
accusleepy/config.json
CHANGED
|
@@ -19,5 +19,19 @@
|
|
|
19
19
|
"frequency": 0.55
|
|
20
20
|
}
|
|
21
21
|
],
|
|
22
|
-
"default_epoch_length": 2.5
|
|
22
|
+
"default_epoch_length": 2.5,
|
|
23
|
+
"default_overwrite_setting": false,
|
|
24
|
+
"save_confidence_setting": true,
|
|
25
|
+
"default_min_bout_length": 5.0,
|
|
26
|
+
"emg_filter": {
|
|
27
|
+
"order": 8,
|
|
28
|
+
"bp_lower": 20.0,
|
|
29
|
+
"bp_upper": 50.0
|
|
30
|
+
},
|
|
31
|
+
"hyperparameters": {
|
|
32
|
+
"batch_size": 64,
|
|
33
|
+
"learning_rate": 0.001,
|
|
34
|
+
"momentum": 0.9,
|
|
35
|
+
"training_epochs": 6
|
|
36
|
+
}
|
|
23
37
|
}
|
accusleepy/constants.py
CHANGED
|
@@ -8,6 +8,7 @@ EEG_COL = "eeg"
|
|
|
8
8
|
EMG_COL = "emg"
|
|
9
9
|
# label file columns
|
|
10
10
|
BRAIN_STATE_COL = "brain_state"
|
|
11
|
+
CONFIDENCE_SCORE_COL = "confidence_score"
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
# really don't change these
|
|
@@ -35,7 +36,33 @@ LABEL_COL = "label"
|
|
|
35
36
|
# recording list file header:
|
|
36
37
|
RECORDING_LIST_NAME = "recording_list"
|
|
37
38
|
RECORDING_LIST_FILE_TYPE = ".json"
|
|
38
|
-
# key for default epoch length in config
|
|
39
|
-
DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
|
|
40
39
|
# filename used to store info about training image datasets
|
|
41
40
|
ANNOTATIONS_FILENAME = "annotations.csv"
|
|
41
|
+
# filename for annotation file for the calibration set
|
|
42
|
+
CALIBRATION_ANNOTATION_FILENAME = "calibration_set.csv"
|
|
43
|
+
|
|
44
|
+
# config file keys
|
|
45
|
+
# ui setting keys
|
|
46
|
+
DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
|
|
47
|
+
DEFAULT_CONFIDENCE_SETTING_KEY = "save_confidence_setting"
|
|
48
|
+
DEFAULT_MIN_BOUT_LENGTH_KEY = "default_min_bout_length"
|
|
49
|
+
DEFAULT_OVERWRITE_KEY = "default_overwrite_setting"
|
|
50
|
+
# EMG filter parameters key
|
|
51
|
+
EMG_FILTER_KEY = "emg_filter"
|
|
52
|
+
# model training hyperparameters key
|
|
53
|
+
HYPERPARAMETERS_KEY = "hyperparameters"
|
|
54
|
+
|
|
55
|
+
# default values
|
|
56
|
+
# default UI settings
|
|
57
|
+
DEFAULT_MIN_BOUT_LENGTH = 5.0
|
|
58
|
+
DEFAULT_CONFIDENCE_SETTING = True
|
|
59
|
+
DEFAULT_OVERWRITE_SETTING = False
|
|
60
|
+
# default EMG filter parameters (order, bandpass frequencies)
|
|
61
|
+
DEFAULT_EMG_FILTER_ORDER = 8
|
|
62
|
+
DEFAULT_EMG_BP_LOWER = 20
|
|
63
|
+
DEFAULT_EMG_BP_UPPER = 50
|
|
64
|
+
# default hyperparameters
|
|
65
|
+
DEFAULT_BATCH_SIZE = 64
|
|
66
|
+
DEFAULT_LEARNING_RATE = 1e-3
|
|
67
|
+
DEFAULT_MOMENTUM = 0.9
|
|
68
|
+
DEFAULT_TRAINING_EPOCHS = 6
|
accusleepy/fileio.py
CHANGED
|
@@ -7,17 +7,26 @@ import pandas as pd
|
|
|
7
7
|
from PySide6.QtWidgets import QListWidgetItem
|
|
8
8
|
|
|
9
9
|
from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
10
|
+
import accusleepy.constants as c
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
|
|
14
|
+
class EMGFilter:
|
|
15
|
+
"""Convenience class for a EMG filter parameters"""
|
|
16
|
+
|
|
17
|
+
order: int # filter order
|
|
18
|
+
bp_lower: int | float # lower bandpass frequency
|
|
19
|
+
bp_upper: int | float # upper bandpass frequency
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
|
|
23
|
+
class Hyperparameters:
|
|
24
|
+
"""Convenience class for model training hyperparameters"""
|
|
25
|
+
|
|
26
|
+
batch_size: int
|
|
27
|
+
learning_rate: float
|
|
28
|
+
momentum: float
|
|
29
|
+
training_epochs: int
|
|
21
30
|
|
|
22
31
|
|
|
23
32
|
@dataclass
|
|
@@ -39,8 +48,8 @@ def load_calibration_file(filename: str) -> (np.array, np.array):
|
|
|
39
48
|
:return: mixture means and SDs
|
|
40
49
|
"""
|
|
41
50
|
df = pd.read_csv(filename)
|
|
42
|
-
mixture_means = df[MIXTURE_MEAN_COL].values
|
|
43
|
-
mixture_sds = df[MIXTURE_SD_COL].values
|
|
51
|
+
mixture_means = df[c.MIXTURE_MEAN_COL].values
|
|
52
|
+
mixture_sds = df[c.MIXTURE_SD_COL].values
|
|
44
53
|
return mixture_means, mixture_sds
|
|
45
54
|
|
|
46
55
|
|
|
@@ -67,56 +76,121 @@ def load_recording(filename: str) -> (np.array, np.array):
|
|
|
67
76
|
:return: arrays of EEG and EMG data
|
|
68
77
|
"""
|
|
69
78
|
df = load_csv_or_parquet(filename)
|
|
70
|
-
eeg = df[EEG_COL].values
|
|
71
|
-
emg = df[EMG_COL].values
|
|
79
|
+
eeg = df[c.EEG_COL].values
|
|
80
|
+
emg = df[c.EMG_COL].values
|
|
72
81
|
return eeg, emg
|
|
73
82
|
|
|
74
83
|
|
|
75
|
-
def load_labels(filename: str) -> np.array:
|
|
76
|
-
"""Load file of brain state labels
|
|
84
|
+
def load_labels(filename: str) -> (np.array, np.array):
|
|
85
|
+
"""Load file of brain state labels and confidence scores
|
|
77
86
|
|
|
78
87
|
:param filename: filename
|
|
79
|
-
:return: array of brain state labels
|
|
88
|
+
:return: array of brain state labels and, optionally, array of confidence scores
|
|
80
89
|
"""
|
|
81
90
|
df = load_csv_or_parquet(filename)
|
|
82
|
-
|
|
91
|
+
if c.CONFIDENCE_SCORE_COL in df.columns:
|
|
92
|
+
return df[c.BRAIN_STATE_COL].values, df[c.CONFIDENCE_SCORE_COL].values
|
|
93
|
+
else:
|
|
94
|
+
return df[c.BRAIN_STATE_COL].values, None
|
|
83
95
|
|
|
84
96
|
|
|
85
|
-
def save_labels(
|
|
97
|
+
def save_labels(
|
|
98
|
+
labels: np.array, filename: str, confidence_scores: np.array = None
|
|
99
|
+
) -> None:
|
|
86
100
|
"""Save brain state labels to file
|
|
87
101
|
|
|
88
102
|
:param labels: brain state labels
|
|
89
103
|
:param filename: filename
|
|
104
|
+
:param confidence_scores: optional confidence scores
|
|
90
105
|
"""
|
|
91
|
-
|
|
106
|
+
if confidence_scores is not None:
|
|
107
|
+
pd.DataFrame(
|
|
108
|
+
{c.BRAIN_STATE_COL: labels, c.CONFIDENCE_SCORE_COL: confidence_scores}
|
|
109
|
+
).to_csv(filename, index=False)
|
|
110
|
+
else:
|
|
111
|
+
pd.DataFrame({c.BRAIN_STATE_COL: labels}).to_csv(filename, index=False)
|
|
92
112
|
|
|
93
113
|
|
|
94
|
-
def load_config() -> tuple[
|
|
114
|
+
def load_config() -> tuple[
|
|
115
|
+
BrainStateSet, int | float, bool, bool, int | float, EMGFilter, Hyperparameters
|
|
116
|
+
]:
|
|
95
117
|
"""Load configuration file with brain state options
|
|
96
118
|
|
|
97
|
-
:return: set of brain state options
|
|
119
|
+
:return: set of brain state options,
|
|
120
|
+
default epoch length,
|
|
121
|
+
default overwrite setting,
|
|
122
|
+
default confidence score output setting,
|
|
123
|
+
default minimum bout length,
|
|
124
|
+
EMG filter parameters,
|
|
125
|
+
model training hyperparameters
|
|
98
126
|
"""
|
|
99
127
|
with open(
|
|
100
|
-
os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "r"
|
|
128
|
+
os.path.join(os.path.dirname(os.path.abspath(__file__)), c.CONFIG_FILE), "r"
|
|
101
129
|
) as f:
|
|
102
130
|
data = json.load(f)
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
131
|
+
|
|
132
|
+
return (
|
|
133
|
+
BrainStateSet(
|
|
134
|
+
[BrainState(**b) for b in data[BRAIN_STATES_KEY]], c.UNDEFINED_LABEL
|
|
135
|
+
),
|
|
136
|
+
data[c.DEFAULT_EPOCH_LENGTH_KEY],
|
|
137
|
+
data.get(c.DEFAULT_OVERWRITE_KEY, c.DEFAULT_OVERWRITE_SETTING),
|
|
138
|
+
data.get(c.DEFAULT_CONFIDENCE_SETTING_KEY, c.DEFAULT_CONFIDENCE_SETTING),
|
|
139
|
+
data.get(c.DEFAULT_MIN_BOUT_LENGTH_KEY, c.DEFAULT_MIN_BOUT_LENGTH),
|
|
140
|
+
EMGFilter(
|
|
141
|
+
**data.get(
|
|
142
|
+
c.EMG_FILTER_KEY,
|
|
143
|
+
{
|
|
144
|
+
"order": c.DEFAULT_EMG_FILTER_ORDER,
|
|
145
|
+
"bp_lower": c.DEFAULT_EMG_BP_LOWER,
|
|
146
|
+
"bp_upper": c.DEFAULT_EMG_BP_UPPER,
|
|
147
|
+
},
|
|
148
|
+
)
|
|
149
|
+
),
|
|
150
|
+
Hyperparameters(
|
|
151
|
+
**data.get(
|
|
152
|
+
c.HYPERPARAMETERS_KEY,
|
|
153
|
+
{
|
|
154
|
+
"batch_size": c.DEFAULT_BATCH_SIZE,
|
|
155
|
+
"learning_rate": c.DEFAULT_LEARNING_RATE,
|
|
156
|
+
"momentum": c.DEFAULT_MOMENTUM,
|
|
157
|
+
"training_epochs": c.DEFAULT_TRAINING_EPOCHS,
|
|
158
|
+
},
|
|
159
|
+
)
|
|
160
|
+
),
|
|
161
|
+
)
|
|
106
162
|
|
|
107
163
|
|
|
108
164
|
def save_config(
|
|
109
|
-
brain_state_set: BrainStateSet,
|
|
165
|
+
brain_state_set: BrainStateSet,
|
|
166
|
+
default_epoch_length: int | float,
|
|
167
|
+
overwrite_setting: bool,
|
|
168
|
+
save_confidence_setting: bool,
|
|
169
|
+
min_bout_length: int | float,
|
|
170
|
+
emg_filter: EMGFilter,
|
|
171
|
+
hyperparameters: Hyperparameters,
|
|
110
172
|
) -> None:
|
|
111
173
|
"""Save configuration of brain state options to json file
|
|
112
174
|
|
|
113
175
|
:param brain_state_set: set of brain state options
|
|
114
|
-
:param default_epoch_length: epoch length
|
|
176
|
+
:param default_epoch_length: default epoch length
|
|
177
|
+
:param save_confidence_setting: default setting for
|
|
178
|
+
saving confidence scores
|
|
179
|
+
:param emg_filter: EMG filter parameters
|
|
180
|
+
:param min_bout_length: default minimum bout length
|
|
181
|
+
:param overwrite_setting: default setting for overwriting
|
|
182
|
+
existing labels
|
|
183
|
+
:param hyperparameters: model training hyperparameters
|
|
115
184
|
"""
|
|
116
185
|
output_dict = brain_state_set.to_output_dict()
|
|
117
|
-
output_dict.update({DEFAULT_EPOCH_LENGTH_KEY: default_epoch_length})
|
|
186
|
+
output_dict.update({c.DEFAULT_EPOCH_LENGTH_KEY: default_epoch_length})
|
|
187
|
+
output_dict.update({c.DEFAULT_OVERWRITE_KEY: overwrite_setting})
|
|
188
|
+
output_dict.update({c.DEFAULT_CONFIDENCE_SETTING_KEY: save_confidence_setting})
|
|
189
|
+
output_dict.update({c.DEFAULT_MIN_BOUT_LENGTH_KEY: min_bout_length})
|
|
190
|
+
output_dict.update({c.EMG_FILTER_KEY: emg_filter.__dict__})
|
|
191
|
+
output_dict.update({c.HYPERPARAMETERS_KEY: hyperparameters.__dict__})
|
|
118
192
|
with open(
|
|
119
|
-
os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "w"
|
|
193
|
+
os.path.join(os.path.dirname(os.path.abspath(__file__)), c.CONFIG_FILE), "w"
|
|
120
194
|
) as f:
|
|
121
195
|
json.dump(output_dict, f, indent=4)
|
|
122
196
|
|
|
@@ -129,7 +203,7 @@ def load_recording_list(filename: str) -> list[Recording]:
|
|
|
129
203
|
"""
|
|
130
204
|
with open(filename, "r") as f:
|
|
131
205
|
data = json.load(f)
|
|
132
|
-
recording_list = [Recording(**r) for r in data[RECORDING_LIST_NAME]]
|
|
206
|
+
recording_list = [Recording(**r) for r in data[c.RECORDING_LIST_NAME]]
|
|
133
207
|
for i, r in enumerate(recording_list):
|
|
134
208
|
r.name = i + 1
|
|
135
209
|
return recording_list
|
|
@@ -142,7 +216,7 @@ def save_recording_list(filename: str, recordings: list[Recording]) -> None:
|
|
|
142
216
|
:param recordings: list of recordings to export
|
|
143
217
|
"""
|
|
144
218
|
recording_dict = {
|
|
145
|
-
RECORDING_LIST_NAME: [
|
|
219
|
+
c.RECORDING_LIST_NAME: [
|
|
146
220
|
{
|
|
147
221
|
"recording_file": r.recording_file,
|
|
148
222
|
"label_file": r.label_file,
|
|
Binary file
|
|
Binary file
|
|
Binary file
|