accusleepy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accusleepy/classification.py +29 -13
- accusleepy/config.json +14 -1
- accusleepy/constants.py +26 -4
- accusleepy/fileio.py +87 -36
- accusleepy/gui/images/primary_window.png +0 -0
- accusleepy/gui/main.py +123 -21
- accusleepy/gui/manual_scoring.py +13 -5
- accusleepy/gui/primary_window.py +730 -128
- accusleepy/gui/primary_window.ui +2916 -2119
- accusleepy/gui/text/main_guide.md +2 -1
- accusleepy/signal_processing.py +16 -11
- {accusleepy-0.6.0.dist-info → accusleepy-0.7.0.dist-info}/METADATA +3 -1
- {accusleepy-0.6.0.dist-info → accusleepy-0.7.0.dist-info}/RECORD +14 -15
- accusleepy/gui/text/config_guide.txt +0 -27
- {accusleepy-0.6.0.dist-info → accusleepy-0.7.0.dist-info}/WHEEL +0 -0
accusleepy/classification.py
CHANGED
|
@@ -11,6 +11,7 @@ from tqdm import trange
|
|
|
11
11
|
|
|
12
12
|
import accusleepy.constants as c
|
|
13
13
|
from accusleepy.brain_state_set import BrainStateSet
|
|
14
|
+
from accusleepy.fileio import EMGFilter, Hyperparameters
|
|
14
15
|
from accusleepy.models import SSANN
|
|
15
16
|
from accusleepy.signal_processing import (
|
|
16
17
|
create_eeg_emg_image,
|
|
@@ -19,11 +20,6 @@ from accusleepy.signal_processing import (
|
|
|
19
20
|
mixture_z_score_img,
|
|
20
21
|
)
|
|
21
22
|
|
|
22
|
-
BATCH_SIZE = 64
|
|
23
|
-
LEARNING_RATE = 1e-3
|
|
24
|
-
MOMENTUM = 0.9
|
|
25
|
-
TRAINING_EPOCHS = 6
|
|
26
|
-
|
|
27
23
|
|
|
28
24
|
class AccuSleepImageDataset(Dataset):
|
|
29
25
|
"""Dataset for loading AccuSleep training images"""
|
|
@@ -62,12 +58,16 @@ def get_device():
|
|
|
62
58
|
|
|
63
59
|
|
|
64
60
|
def create_dataloader(
|
|
65
|
-
annotations_file: str,
|
|
61
|
+
annotations_file: str,
|
|
62
|
+
img_dir: str,
|
|
63
|
+
hyperparameters: Hyperparameters,
|
|
64
|
+
shuffle: bool = True,
|
|
66
65
|
) -> DataLoader:
|
|
67
66
|
"""Create DataLoader for a dataset of training or calibration images
|
|
68
67
|
|
|
69
68
|
:param annotations_file: file with information on each training image
|
|
70
69
|
:param img_dir: training image location
|
|
70
|
+
:param hyperparameters: model training hyperparameters
|
|
71
71
|
:param shuffle: reshuffle data for every epoch
|
|
72
72
|
:return: DataLoader for the data
|
|
73
73
|
|
|
@@ -76,7 +76,9 @@ def create_dataloader(
|
|
|
76
76
|
annotations_file=annotations_file,
|
|
77
77
|
img_dir=img_dir,
|
|
78
78
|
)
|
|
79
|
-
return DataLoader(
|
|
79
|
+
return DataLoader(
|
|
80
|
+
image_dataset, batch_size=hyperparameters.batch_size, shuffle=shuffle
|
|
81
|
+
)
|
|
80
82
|
|
|
81
83
|
|
|
82
84
|
def train_ssann(
|
|
@@ -84,6 +86,7 @@ def train_ssann(
|
|
|
84
86
|
img_dir: str,
|
|
85
87
|
mixture_weights: np.array,
|
|
86
88
|
n_classes: int,
|
|
89
|
+
hyperparameters: Hyperparameters,
|
|
87
90
|
) -> SSANN:
|
|
88
91
|
"""Train a SSANN classification model for sleep scoring
|
|
89
92
|
|
|
@@ -91,10 +94,13 @@ def train_ssann(
|
|
|
91
94
|
:param img_dir: training image location
|
|
92
95
|
:param mixture_weights: typical relative frequencies of brain states
|
|
93
96
|
:param n_classes: number of classes the model will learn
|
|
97
|
+
:param hyperparameters: model training hyperparameters
|
|
94
98
|
:return: trained Sleep Scoring Artificial Neural Network model
|
|
95
99
|
"""
|
|
96
100
|
train_dataloader = create_dataloader(
|
|
97
|
-
annotations_file=annotations_file,
|
|
101
|
+
annotations_file=annotations_file,
|
|
102
|
+
img_dir=img_dir,
|
|
103
|
+
hyperparameters=hyperparameters,
|
|
98
104
|
)
|
|
99
105
|
|
|
100
106
|
device = get_device()
|
|
@@ -106,9 +112,13 @@ def train_ssann(
|
|
|
106
112
|
weight = torch.tensor((mixture_weights**-1).astype("float32")).to(device)
|
|
107
113
|
|
|
108
114
|
criterion = nn.CrossEntropyLoss(weight=weight)
|
|
109
|
-
optimizer = optim.SGD(
|
|
115
|
+
optimizer = optim.SGD(
|
|
116
|
+
model.parameters(),
|
|
117
|
+
lr=hyperparameters.learning_rate,
|
|
118
|
+
momentum=hyperparameters.momentum,
|
|
119
|
+
)
|
|
110
120
|
|
|
111
|
-
for _ in trange(
|
|
121
|
+
for _ in trange(hyperparameters.training_epochs):
|
|
112
122
|
for data in train_dataloader:
|
|
113
123
|
inputs, labels = data
|
|
114
124
|
(inputs, labels) = (inputs.to(device), labels.to(device))
|
|
@@ -131,6 +141,7 @@ def score_recording(
|
|
|
131
141
|
epoch_length: int | float,
|
|
132
142
|
epochs_per_img: int,
|
|
133
143
|
brain_state_set: BrainStateSet,
|
|
144
|
+
emg_filter: EMGFilter,
|
|
134
145
|
) -> np.array:
|
|
135
146
|
"""Use classification model to get brain state labels for a recording
|
|
136
147
|
|
|
@@ -146,6 +157,7 @@ def score_recording(
|
|
|
146
157
|
:param epoch_length: epoch length, in seconds
|
|
147
158
|
:param epochs_per_img: number of epochs for the model to consider
|
|
148
159
|
:param brain_state_set: set of brain state options
|
|
160
|
+
:param emg_filter: EMG filter parameters
|
|
149
161
|
:return: brain state labels, confidence scores
|
|
150
162
|
"""
|
|
151
163
|
# prepare model
|
|
@@ -154,7 +166,7 @@ def score_recording(
|
|
|
154
166
|
model.eval()
|
|
155
167
|
|
|
156
168
|
# create and scale eeg+emg spectrogram
|
|
157
|
-
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
|
|
169
|
+
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
|
|
158
170
|
img = mixture_z_score_img(
|
|
159
171
|
img,
|
|
160
172
|
mixture_means=mixture_means,
|
|
@@ -192,6 +204,7 @@ def example_real_time_scoring_function(
|
|
|
192
204
|
epoch_length: int | float,
|
|
193
205
|
epochs_per_img: int,
|
|
194
206
|
brain_state_set: BrainStateSet,
|
|
207
|
+
emg_filter: EMGFilter,
|
|
195
208
|
) -> int:
|
|
196
209
|
"""Example function that could be used for real-time scoring
|
|
197
210
|
|
|
@@ -220,6 +233,7 @@ def example_real_time_scoring_function(
|
|
|
220
233
|
:param epoch_length: epoch length, in seconds
|
|
221
234
|
:param epochs_per_img: number of epochs shown to the model at once
|
|
222
235
|
:param brain_state_set: set of brain state options
|
|
236
|
+
:param emg_filter: EMG filter parameters
|
|
223
237
|
:return: brain state label
|
|
224
238
|
"""
|
|
225
239
|
# prepare model
|
|
@@ -229,7 +243,7 @@ def example_real_time_scoring_function(
|
|
|
229
243
|
model.eval()
|
|
230
244
|
|
|
231
245
|
# create and scale eeg+emg spectrogram
|
|
232
|
-
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
|
|
246
|
+
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
|
|
233
247
|
img = mixture_z_score_img(
|
|
234
248
|
img,
|
|
235
249
|
mixture_means=mixture_means,
|
|
@@ -260,6 +274,7 @@ def create_calibration_file(
|
|
|
260
274
|
sampling_rate: int | float,
|
|
261
275
|
epoch_length: int | float,
|
|
262
276
|
brain_state_set: BrainStateSet,
|
|
277
|
+
emg_filter: EMGFilter,
|
|
263
278
|
) -> None:
|
|
264
279
|
"""Create file of calibration data for a subject
|
|
265
280
|
|
|
@@ -273,8 +288,9 @@ def create_calibration_file(
|
|
|
273
288
|
:param sampling_rate: sampling rate, in Hz
|
|
274
289
|
:param epoch_length: epoch length, in seconds
|
|
275
290
|
:param brain_state_set: set of brain state options
|
|
291
|
+
:param emg_filter: EMG filter parameters
|
|
276
292
|
"""
|
|
277
|
-
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
|
|
293
|
+
img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
|
|
278
294
|
mixture_means, mixture_sds = get_mixture_values(
|
|
279
295
|
img=img,
|
|
280
296
|
labels=brain_state_set.convert_digit_to_class(labels),
|
accusleepy/config.json
CHANGED
|
@@ -20,5 +20,18 @@
|
|
|
20
20
|
}
|
|
21
21
|
],
|
|
22
22
|
"default_epoch_length": 2.5,
|
|
23
|
-
"
|
|
23
|
+
"default_overwrite_setting": false,
|
|
24
|
+
"save_confidence_setting": true,
|
|
25
|
+
"default_min_bout_length": 5.0,
|
|
26
|
+
"emg_filter": {
|
|
27
|
+
"order": 8,
|
|
28
|
+
"bp_lower": 20.0,
|
|
29
|
+
"bp_upper": 50.0
|
|
30
|
+
},
|
|
31
|
+
"hyperparameters": {
|
|
32
|
+
"batch_size": 64,
|
|
33
|
+
"learning_rate": 0.001,
|
|
34
|
+
"momentum": 0.9,
|
|
35
|
+
"training_epochs": 6
|
|
36
|
+
}
|
|
24
37
|
}
|
accusleepy/constants.py
CHANGED
|
@@ -36,11 +36,33 @@ LABEL_COL = "label"
|
|
|
36
36
|
# recording list file header:
|
|
37
37
|
RECORDING_LIST_NAME = "recording_list"
|
|
38
38
|
RECORDING_LIST_FILE_TYPE = ".json"
|
|
39
|
-
# key for default epoch length in config
|
|
40
|
-
DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
|
|
41
|
-
# key used for default confidence score behavior in config
|
|
42
|
-
DEFAULT_CONFIDENCE_SETTING_KEY = "save_confidence_setting"
|
|
43
39
|
# filename used to store info about training image datasets
|
|
44
40
|
ANNOTATIONS_FILENAME = "annotations.csv"
|
|
45
41
|
# filename for annotation file for the calibration set
|
|
46
42
|
CALIBRATION_ANNOTATION_FILENAME = "calibration_set.csv"
|
|
43
|
+
|
|
44
|
+
# config file keys
|
|
45
|
+
# ui setting keys
|
|
46
|
+
DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
|
|
47
|
+
DEFAULT_CONFIDENCE_SETTING_KEY = "save_confidence_setting"
|
|
48
|
+
DEFAULT_MIN_BOUT_LENGTH_KEY = "default_min_bout_length"
|
|
49
|
+
DEFAULT_OVERWRITE_KEY = "default_overwrite_setting"
|
|
50
|
+
# EMG filter parameters key
|
|
51
|
+
EMG_FILTER_KEY = "emg_filter"
|
|
52
|
+
# model training hyperparameters key
|
|
53
|
+
HYPERPARAMETERS_KEY = "hyperparameters"
|
|
54
|
+
|
|
55
|
+
# default values
|
|
56
|
+
# default UI settings
|
|
57
|
+
DEFAULT_MIN_BOUT_LENGTH = 5.0
|
|
58
|
+
DEFAULT_CONFIDENCE_SETTING = True
|
|
59
|
+
DEFAULT_OVERWRITE_SETTING = False
|
|
60
|
+
# default EMG filter parameters (order, bandpass frequencies)
|
|
61
|
+
DEFAULT_EMG_FILTER_ORDER = 8
|
|
62
|
+
DEFAULT_EMG_BP_LOWER = 20
|
|
63
|
+
DEFAULT_EMG_BP_UPPER = 50
|
|
64
|
+
# default hyperparameters
|
|
65
|
+
DEFAULT_BATCH_SIZE = 64
|
|
66
|
+
DEFAULT_LEARNING_RATE = 1e-3
|
|
67
|
+
DEFAULT_MOMENTUM = 0.9
|
|
68
|
+
DEFAULT_TRAINING_EPOCHS = 6
|
accusleepy/fileio.py
CHANGED
|
@@ -7,19 +7,26 @@ import pandas as pd
|
|
|
7
7
|
from PySide6.QtWidgets import QListWidgetItem
|
|
8
8
|
|
|
9
9
|
from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
10
|
+
import accusleepy.constants as c
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
|
|
14
|
+
class EMGFilter:
|
|
15
|
+
"""Convenience class for a EMG filter parameters"""
|
|
16
|
+
|
|
17
|
+
order: int # filter order
|
|
18
|
+
bp_lower: int | float # lower bandpass frequency
|
|
19
|
+
bp_upper: int | float # upper bandpass frequency
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
|
|
23
|
+
class Hyperparameters:
|
|
24
|
+
"""Convenience class for model training hyperparameters"""
|
|
25
|
+
|
|
26
|
+
batch_size: int
|
|
27
|
+
learning_rate: float
|
|
28
|
+
momentum: float
|
|
29
|
+
training_epochs: int
|
|
23
30
|
|
|
24
31
|
|
|
25
32
|
@dataclass
|
|
@@ -41,8 +48,8 @@ def load_calibration_file(filename: str) -> (np.array, np.array):
|
|
|
41
48
|
:return: mixture means and SDs
|
|
42
49
|
"""
|
|
43
50
|
df = pd.read_csv(filename)
|
|
44
|
-
mixture_means = df[MIXTURE_MEAN_COL].values
|
|
45
|
-
mixture_sds = df[MIXTURE_SD_COL].values
|
|
51
|
+
mixture_means = df[c.MIXTURE_MEAN_COL].values
|
|
52
|
+
mixture_sds = df[c.MIXTURE_SD_COL].values
|
|
46
53
|
return mixture_means, mixture_sds
|
|
47
54
|
|
|
48
55
|
|
|
@@ -69,8 +76,8 @@ def load_recording(filename: str) -> (np.array, np.array):
|
|
|
69
76
|
:return: arrays of EEG and EMG data
|
|
70
77
|
"""
|
|
71
78
|
df = load_csv_or_parquet(filename)
|
|
72
|
-
eeg = df[EEG_COL].values
|
|
73
|
-
emg = df[EMG_COL].values
|
|
79
|
+
eeg = df[c.EEG_COL].values
|
|
80
|
+
emg = df[c.EMG_COL].values
|
|
74
81
|
return eeg, emg
|
|
75
82
|
|
|
76
83
|
|
|
@@ -81,10 +88,10 @@ def load_labels(filename: str) -> (np.array, np.array):
|
|
|
81
88
|
:return: array of brain state labels and, optionally, array of confidence scores
|
|
82
89
|
"""
|
|
83
90
|
df = load_csv_or_parquet(filename)
|
|
84
|
-
if CONFIDENCE_SCORE_COL in df.columns:
|
|
85
|
-
return df[BRAIN_STATE_COL].values, df[CONFIDENCE_SCORE_COL].values
|
|
91
|
+
if c.CONFIDENCE_SCORE_COL in df.columns:
|
|
92
|
+
return df[c.BRAIN_STATE_COL].values, df[c.CONFIDENCE_SCORE_COL].values
|
|
86
93
|
else:
|
|
87
|
-
return df[BRAIN_STATE_COL].values, None
|
|
94
|
+
return df[c.BRAIN_STATE_COL].values, None
|
|
88
95
|
|
|
89
96
|
|
|
90
97
|
def save_labels(
|
|
@@ -98,48 +105,92 @@ def save_labels(
|
|
|
98
105
|
"""
|
|
99
106
|
if confidence_scores is not None:
|
|
100
107
|
pd.DataFrame(
|
|
101
|
-
{BRAIN_STATE_COL: labels, CONFIDENCE_SCORE_COL: confidence_scores}
|
|
108
|
+
{c.BRAIN_STATE_COL: labels, c.CONFIDENCE_SCORE_COL: confidence_scores}
|
|
102
109
|
).to_csv(filename, index=False)
|
|
103
110
|
else:
|
|
104
|
-
pd.DataFrame({BRAIN_STATE_COL: labels}).to_csv(filename, index=False)
|
|
111
|
+
pd.DataFrame({c.BRAIN_STATE_COL: labels}).to_csv(filename, index=False)
|
|
105
112
|
|
|
106
113
|
|
|
107
|
-
def load_config() -> tuple[
|
|
114
|
+
def load_config() -> tuple[
|
|
115
|
+
BrainStateSet, int | float, bool, bool, int | float, EMGFilter, Hyperparameters
|
|
116
|
+
]:
|
|
108
117
|
"""Load configuration file with brain state options
|
|
109
118
|
|
|
110
|
-
:return: set of brain state options,
|
|
119
|
+
:return: set of brain state options,
|
|
120
|
+
default epoch length,
|
|
121
|
+
default overwrite setting,
|
|
122
|
+
default confidence score output setting,
|
|
123
|
+
default minimum bout length,
|
|
124
|
+
EMG filter parameters,
|
|
125
|
+
model training hyperparameters
|
|
111
126
|
"""
|
|
112
127
|
with open(
|
|
113
|
-
os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "r"
|
|
128
|
+
os.path.join(os.path.dirname(os.path.abspath(__file__)), c.CONFIG_FILE), "r"
|
|
114
129
|
) as f:
|
|
115
130
|
data = json.load(f)
|
|
116
131
|
|
|
117
132
|
return (
|
|
118
133
|
BrainStateSet(
|
|
119
|
-
[BrainState(**b) for b in data[BRAIN_STATES_KEY]], UNDEFINED_LABEL
|
|
134
|
+
[BrainState(**b) for b in data[BRAIN_STATES_KEY]], c.UNDEFINED_LABEL
|
|
135
|
+
),
|
|
136
|
+
data[c.DEFAULT_EPOCH_LENGTH_KEY],
|
|
137
|
+
data.get(c.DEFAULT_OVERWRITE_KEY, c.DEFAULT_OVERWRITE_SETTING),
|
|
138
|
+
data.get(c.DEFAULT_CONFIDENCE_SETTING_KEY, c.DEFAULT_CONFIDENCE_SETTING),
|
|
139
|
+
data.get(c.DEFAULT_MIN_BOUT_LENGTH_KEY, c.DEFAULT_MIN_BOUT_LENGTH),
|
|
140
|
+
EMGFilter(
|
|
141
|
+
**data.get(
|
|
142
|
+
c.EMG_FILTER_KEY,
|
|
143
|
+
{
|
|
144
|
+
"order": c.DEFAULT_EMG_FILTER_ORDER,
|
|
145
|
+
"bp_lower": c.DEFAULT_EMG_BP_LOWER,
|
|
146
|
+
"bp_upper": c.DEFAULT_EMG_BP_UPPER,
|
|
147
|
+
},
|
|
148
|
+
)
|
|
149
|
+
),
|
|
150
|
+
Hyperparameters(
|
|
151
|
+
**data.get(
|
|
152
|
+
c.HYPERPARAMETERS_KEY,
|
|
153
|
+
{
|
|
154
|
+
"batch_size": c.DEFAULT_BATCH_SIZE,
|
|
155
|
+
"learning_rate": c.DEFAULT_LEARNING_RATE,
|
|
156
|
+
"momentum": c.DEFAULT_MOMENTUM,
|
|
157
|
+
"training_epochs": c.DEFAULT_TRAINING_EPOCHS,
|
|
158
|
+
},
|
|
159
|
+
)
|
|
120
160
|
),
|
|
121
|
-
data[DEFAULT_EPOCH_LENGTH_KEY],
|
|
122
|
-
data.get(DEFAULT_CONFIDENCE_SETTING_KEY, True),
|
|
123
161
|
)
|
|
124
162
|
|
|
125
163
|
|
|
126
164
|
def save_config(
|
|
127
165
|
brain_state_set: BrainStateSet,
|
|
128
166
|
default_epoch_length: int | float,
|
|
167
|
+
overwrite_setting: bool,
|
|
129
168
|
save_confidence_setting: bool,
|
|
169
|
+
min_bout_length: int | float,
|
|
170
|
+
emg_filter: EMGFilter,
|
|
171
|
+
hyperparameters: Hyperparameters,
|
|
130
172
|
) -> None:
|
|
131
173
|
"""Save configuration of brain state options to json file
|
|
132
174
|
|
|
133
175
|
:param brain_state_set: set of brain state options
|
|
134
|
-
:param default_epoch_length: epoch length
|
|
135
|
-
:param save_confidence_setting:
|
|
136
|
-
|
|
176
|
+
:param default_epoch_length: default epoch length
|
|
177
|
+
:param save_confidence_setting: default setting for
|
|
178
|
+
saving confidence scores
|
|
179
|
+
:param emg_filter: EMG filter parameters
|
|
180
|
+
:param min_bout_length: default minimum bout length
|
|
181
|
+
:param overwrite_setting: default setting for overwriting
|
|
182
|
+
existing labels
|
|
183
|
+
:param hyperparameters: model training hyperparameters
|
|
137
184
|
"""
|
|
138
185
|
output_dict = brain_state_set.to_output_dict()
|
|
139
|
-
output_dict.update({DEFAULT_EPOCH_LENGTH_KEY: default_epoch_length})
|
|
140
|
-
output_dict.update({
|
|
186
|
+
output_dict.update({c.DEFAULT_EPOCH_LENGTH_KEY: default_epoch_length})
|
|
187
|
+
output_dict.update({c.DEFAULT_OVERWRITE_KEY: overwrite_setting})
|
|
188
|
+
output_dict.update({c.DEFAULT_CONFIDENCE_SETTING_KEY: save_confidence_setting})
|
|
189
|
+
output_dict.update({c.DEFAULT_MIN_BOUT_LENGTH_KEY: min_bout_length})
|
|
190
|
+
output_dict.update({c.EMG_FILTER_KEY: emg_filter.__dict__})
|
|
191
|
+
output_dict.update({c.HYPERPARAMETERS_KEY: hyperparameters.__dict__})
|
|
141
192
|
with open(
|
|
142
|
-
os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "w"
|
|
193
|
+
os.path.join(os.path.dirname(os.path.abspath(__file__)), c.CONFIG_FILE), "w"
|
|
143
194
|
) as f:
|
|
144
195
|
json.dump(output_dict, f, indent=4)
|
|
145
196
|
|
|
@@ -152,7 +203,7 @@ def load_recording_list(filename: str) -> list[Recording]:
|
|
|
152
203
|
"""
|
|
153
204
|
with open(filename, "r") as f:
|
|
154
205
|
data = json.load(f)
|
|
155
|
-
recording_list = [Recording(**r) for r in data[RECORDING_LIST_NAME]]
|
|
206
|
+
recording_list = [Recording(**r) for r in data[c.RECORDING_LIST_NAME]]
|
|
156
207
|
for i, r in enumerate(recording_list):
|
|
157
208
|
r.name = i + 1
|
|
158
209
|
return recording_list
|
|
@@ -165,7 +216,7 @@ def save_recording_list(filename: str, recordings: list[Recording]) -> None:
|
|
|
165
216
|
:param recordings: list of recordings to export
|
|
166
217
|
"""
|
|
167
218
|
recording_dict = {
|
|
168
|
-
RECORDING_LIST_NAME: [
|
|
219
|
+
c.RECORDING_LIST_NAME: [
|
|
169
220
|
{
|
|
170
221
|
"recording_file": r.recording_file,
|
|
171
222
|
"label_file": r.label_file,
|
|
Binary file
|
accusleepy/gui/main.py
CHANGED
|
@@ -39,6 +39,13 @@ from accusleepy.constants import (
|
|
|
39
39
|
CALIBRATION_ANNOTATION_FILENAME,
|
|
40
40
|
CALIBRATION_FILE_TYPE,
|
|
41
41
|
DEFAULT_MODEL_TYPE,
|
|
42
|
+
DEFAULT_EMG_FILTER_ORDER,
|
|
43
|
+
DEFAULT_EMG_BP_LOWER,
|
|
44
|
+
DEFAULT_EMG_BP_UPPER,
|
|
45
|
+
DEFAULT_BATCH_SIZE,
|
|
46
|
+
DEFAULT_LEARNING_RATE,
|
|
47
|
+
DEFAULT_MOMENTUM,
|
|
48
|
+
DEFAULT_TRAINING_EPOCHS,
|
|
42
49
|
LABEL_FILE_TYPE,
|
|
43
50
|
MODEL_FILE_TYPE,
|
|
44
51
|
REAL_TIME_MODEL_TYPE,
|
|
@@ -56,6 +63,8 @@ from accusleepy.fileio import (
|
|
|
56
63
|
save_config,
|
|
57
64
|
save_labels,
|
|
58
65
|
save_recording_list,
|
|
66
|
+
EMGFilter,
|
|
67
|
+
Hyperparameters,
|
|
59
68
|
)
|
|
60
69
|
from accusleepy.gui.manual_scoring import ManualScoringWindow
|
|
61
70
|
from accusleepy.gui.primary_window import Ui_PrimaryWindow
|
|
@@ -97,19 +106,25 @@ class AccuSleepWindow(QMainWindow):
|
|
|
97
106
|
self.setWindowTitle("AccuSleePy")
|
|
98
107
|
|
|
99
108
|
# fill in settings tab
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
109
|
+
(
|
|
110
|
+
self.brain_state_set,
|
|
111
|
+
self.epoch_length,
|
|
112
|
+
self.only_overwrite_undefined,
|
|
113
|
+
self.save_confidence_scores,
|
|
114
|
+
self.min_bout_length,
|
|
115
|
+
self.emg_filter,
|
|
116
|
+
self.hyperparameters,
|
|
117
|
+
) = load_config()
|
|
118
|
+
|
|
103
119
|
self.settings_widgets = None
|
|
104
120
|
self.initialize_settings_tab()
|
|
105
121
|
|
|
106
122
|
# initialize info about the recordings, classification data / settings
|
|
107
123
|
self.ui.epoch_length_input.setValue(self.epoch_length)
|
|
108
|
-
self.ui.
|
|
124
|
+
self.ui.overwritecheckbox.setChecked(self.only_overwrite_undefined)
|
|
125
|
+
self.ui.save_confidence_checkbox.setChecked(self.save_confidence_scores)
|
|
126
|
+
self.ui.bout_length_input.setValue(self.min_bout_length)
|
|
109
127
|
self.model = None
|
|
110
|
-
self.only_overwrite_undefined = False
|
|
111
|
-
self.save_confidence_scores = self.save_confidence_setting
|
|
112
|
-
self.min_bout_length = 5
|
|
113
128
|
|
|
114
129
|
# initialize model training variables
|
|
115
130
|
self.training_epochs_per_img = 9
|
|
@@ -186,6 +201,10 @@ class AccuSleepWindow(QMainWindow):
|
|
|
186
201
|
self.ui.export_button.clicked.connect(self.export_recording_list)
|
|
187
202
|
self.ui.import_button.clicked.connect(self.import_recording_list)
|
|
188
203
|
self.ui.default_type_button.toggled.connect(self.model_type_radio_buttons)
|
|
204
|
+
self.ui.reset_emg_params_button.clicked.connect(self.reset_emg_filter_settings)
|
|
205
|
+
self.ui.reset_hyperparams_button.clicked.connect(
|
|
206
|
+
self.reset_hyperparams_settings
|
|
207
|
+
)
|
|
189
208
|
|
|
190
209
|
# user input: drag and drop
|
|
191
210
|
self.ui.recording_file_label.installEventFilter(self)
|
|
@@ -363,6 +382,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
363
382
|
brain_state_set=self.brain_state_set,
|
|
364
383
|
model_type=self.model_type,
|
|
365
384
|
calibration_fraction=calibration_fraction,
|
|
385
|
+
emg_filter=self.emg_filter,
|
|
366
386
|
)
|
|
367
387
|
if len(failed_recordings) > 0:
|
|
368
388
|
if len(failed_recordings) == len(self.recordings):
|
|
@@ -391,6 +411,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
391
411
|
img_dir=temp_image_dir,
|
|
392
412
|
mixture_weights=self.brain_state_set.mixture_weights,
|
|
393
413
|
n_classes=self.brain_state_set.n_classes,
|
|
414
|
+
hyperparameters=self.hyperparameters,
|
|
394
415
|
)
|
|
395
416
|
|
|
396
417
|
# calibrate the model
|
|
@@ -399,7 +420,9 @@ class AccuSleepWindow(QMainWindow):
|
|
|
399
420
|
temp_image_dir, CALIBRATION_ANNOTATION_FILENAME
|
|
400
421
|
)
|
|
401
422
|
calibration_dataloader = create_dataloader(
|
|
402
|
-
annotations_file=calibration_annotation_file,
|
|
423
|
+
annotations_file=calibration_annotation_file,
|
|
424
|
+
img_dir=temp_image_dir,
|
|
425
|
+
hyperparameters=self.hyperparameters,
|
|
403
426
|
)
|
|
404
427
|
model = ModelWithTemperature(model)
|
|
405
428
|
print("Calibrating model")
|
|
@@ -584,6 +607,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
584
607
|
epoch_length=self.epoch_length,
|
|
585
608
|
epochs_per_img=self.model_epochs_per_img,
|
|
586
609
|
brain_state_set=self.brain_state_set,
|
|
610
|
+
emg_filter=self.emg_filter,
|
|
587
611
|
)
|
|
588
612
|
|
|
589
613
|
# overwrite as needed
|
|
@@ -801,6 +825,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
801
825
|
sampling_rate=sampling_rate,
|
|
802
826
|
epoch_length=self.epoch_length,
|
|
803
827
|
brain_state_set=self.brain_state_set,
|
|
828
|
+
emg_filter=self.emg_filter,
|
|
804
829
|
)
|
|
805
830
|
|
|
806
831
|
self.ui.calibration_status.setText("")
|
|
@@ -965,6 +990,7 @@ class AccuSleepWindow(QMainWindow):
|
|
|
965
990
|
confidence_scores=confidence_scores,
|
|
966
991
|
sampling_rate=sampling_rate,
|
|
967
992
|
epoch_length=self.epoch_length,
|
|
993
|
+
emg_filter=self.emg_filter,
|
|
968
994
|
)
|
|
969
995
|
manual_scoring_window.setWindowTitle(f"AccuSleePy viewer: {label_file}")
|
|
970
996
|
manual_scoring_window.exec()
|
|
@@ -1130,15 +1156,6 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1130
1156
|
|
|
1131
1157
|
def initialize_settings_tab(self):
|
|
1132
1158
|
"""Populate settings tab and assign its callbacks"""
|
|
1133
|
-
# show information about the settings tab
|
|
1134
|
-
config_guide_file = open(
|
|
1135
|
-
os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_GUIDE_FILE),
|
|
1136
|
-
"r",
|
|
1137
|
-
)
|
|
1138
|
-
config_guide_text = config_guide_file.read()
|
|
1139
|
-
config_guide_file.close()
|
|
1140
|
-
self.ui.settings_text.setText(config_guide_text)
|
|
1141
|
-
|
|
1142
1159
|
# store dictionary that maps digits to rows of widgets
|
|
1143
1160
|
# in the settings tab
|
|
1144
1161
|
self.settings_widgets = {
|
|
@@ -1215,8 +1232,21 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1215
1232
|
}
|
|
1216
1233
|
|
|
1217
1234
|
# update widget state to display current config
|
|
1235
|
+
# UI defaults
|
|
1218
1236
|
self.ui.default_epoch_input.setValue(self.epoch_length)
|
|
1219
|
-
self.ui.
|
|
1237
|
+
self.ui.overwrite_default_checkbox.setChecked(self.only_overwrite_undefined)
|
|
1238
|
+
self.ui.confidence_setting_checkbox.setChecked(self.save_confidence_scores)
|
|
1239
|
+
self.ui.default_min_bout_length_spinbox.setValue(self.min_bout_length)
|
|
1240
|
+
# EMG filter
|
|
1241
|
+
self.ui.emg_order_spinbox.setValue(self.emg_filter.order)
|
|
1242
|
+
self.ui.bp_lower_spinbox.setValue(self.emg_filter.bp_lower)
|
|
1243
|
+
self.ui.bp_upper_spinbox.setValue(self.emg_filter.bp_upper)
|
|
1244
|
+
# model training hyperparameters
|
|
1245
|
+
self.ui.batch_size_spinbox.setValue(self.hyperparameters.batch_size)
|
|
1246
|
+
self.ui.learning_rate_spinbox.setValue(self.hyperparameters.learning_rate)
|
|
1247
|
+
self.ui.momentum_spinbox.setValue(self.hyperparameters.momentum)
|
|
1248
|
+
self.ui.training_epochs_spinbox.setValue(self.hyperparameters.training_epochs)
|
|
1249
|
+
# brain states
|
|
1220
1250
|
states = {b.digit: b for b in self.brain_state_set.brain_states}
|
|
1221
1251
|
for digit in range(10):
|
|
1222
1252
|
if digit in states.keys():
|
|
@@ -1235,6 +1265,15 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1235
1265
|
self.settings_widgets[digit].frequency_widget.setEnabled(False)
|
|
1236
1266
|
|
|
1237
1267
|
# set callbacks
|
|
1268
|
+
self.ui.emg_order_spinbox.valueChanged.connect(self.emg_filter_order_changed)
|
|
1269
|
+
self.ui.bp_lower_spinbox.valueChanged.connect(self.emg_filter_bp_lower_changed)
|
|
1270
|
+
self.ui.bp_upper_spinbox.valueChanged.connect(self.emg_filter_bp_upper_changed)
|
|
1271
|
+
self.ui.batch_size_spinbox.valueChanged.connect(self.hyperparameters_changed)
|
|
1272
|
+
self.ui.learning_rate_spinbox.valueChanged.connect(self.hyperparameters_changed)
|
|
1273
|
+
self.ui.momentum_spinbox.valueChanged.connect(self.hyperparameters_changed)
|
|
1274
|
+
self.ui.training_epochs_spinbox.valueChanged.connect(
|
|
1275
|
+
self.hyperparameters_changed
|
|
1276
|
+
)
|
|
1238
1277
|
for digit in range(10):
|
|
1239
1278
|
state = self.settings_widgets[digit]
|
|
1240
1279
|
state.enabled_widget.stateChanged.connect(
|
|
@@ -1297,6 +1336,41 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1297
1336
|
# check that configuration is valid
|
|
1298
1337
|
_ = self.check_config_validity()
|
|
1299
1338
|
|
|
1339
|
+
def emg_filter_order_changed(self, new_value: int) -> None:
|
|
1340
|
+
"""Called when user modifies EMG filter order
|
|
1341
|
+
|
|
1342
|
+
:param new_value: new EMG filter order
|
|
1343
|
+
"""
|
|
1344
|
+
self.emg_filter.order = new_value
|
|
1345
|
+
|
|
1346
|
+
def emg_filter_bp_lower_changed(self, new_value: int | float) -> None:
|
|
1347
|
+
"""Called when user modifies EMG filter lower cutoff
|
|
1348
|
+
|
|
1349
|
+
:param new_value: new lower bandpass cutoff frequency
|
|
1350
|
+
"""
|
|
1351
|
+
self.emg_filter.bp_lower = new_value
|
|
1352
|
+
_ = self.check_config_validity()
|
|
1353
|
+
|
|
1354
|
+
def emg_filter_bp_upper_changed(self, new_value: int | float) -> None:
|
|
1355
|
+
"""Called when user modifies EMG filter upper cutoff
|
|
1356
|
+
|
|
1357
|
+
:param new_value: new upper bandpass cutoff frequency
|
|
1358
|
+
"""
|
|
1359
|
+
self.emg_filter.bp_upper = new_value
|
|
1360
|
+
_ = self.check_config_validity()
|
|
1361
|
+
|
|
1362
|
+
def hyperparameters_changed(self, new_value) -> None:
|
|
1363
|
+
"""Called when user modifies model training hyperparameters
|
|
1364
|
+
|
|
1365
|
+
:param new_value: unused
|
|
1366
|
+
"""
|
|
1367
|
+
self.hyperparameters = Hyperparameters(
|
|
1368
|
+
batch_size=self.ui.batch_size_spinbox.value(),
|
|
1369
|
+
learning_rate=self.ui.learning_rate_spinbox.value(),
|
|
1370
|
+
momentum=self.ui.momentum_spinbox.value(),
|
|
1371
|
+
training_epochs=self.ui.training_epochs_spinbox.value(),
|
|
1372
|
+
)
|
|
1373
|
+
|
|
1300
1374
|
def check_config_validity(self) -> str:
|
|
1301
1375
|
"""Check if brain state configuration on screen is valid"""
|
|
1302
1376
|
# error message, if we get one
|
|
@@ -1323,6 +1397,10 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1323
1397
|
if sum(frequencies) != 1:
|
|
1324
1398
|
message = "Error: sum(frequencies) != 1"
|
|
1325
1399
|
|
|
1400
|
+
# check validity of EMG filter settings
|
|
1401
|
+
if self.emg_filter.bp_lower >= self.emg_filter.bp_upper:
|
|
1402
|
+
message = "Error: EMG filter cutoff frequencies are invalid"
|
|
1403
|
+
|
|
1326
1404
|
if message is not None:
|
|
1327
1405
|
self.ui.save_config_status.setText(message)
|
|
1328
1406
|
self.ui.save_config_button.setEnabled(False)
|
|
@@ -1355,12 +1433,36 @@ class AccuSleepWindow(QMainWindow):
|
|
|
1355
1433
|
|
|
1356
1434
|
# save to file
|
|
1357
1435
|
save_config(
|
|
1358
|
-
self.brain_state_set,
|
|
1359
|
-
self.ui.default_epoch_input.value(),
|
|
1360
|
-
self.ui.
|
|
1436
|
+
brain_state_set=self.brain_state_set,
|
|
1437
|
+
default_epoch_length=self.ui.default_epoch_input.value(),
|
|
1438
|
+
overwrite_setting=self.ui.overwrite_default_checkbox.isChecked(),
|
|
1439
|
+
save_confidence_setting=self.ui.confidence_setting_checkbox.isChecked(),
|
|
1440
|
+
min_bout_length=self.ui.default_min_bout_length_spinbox.value(),
|
|
1441
|
+
emg_filter=EMGFilter(
|
|
1442
|
+
order=self.emg_filter.order,
|
|
1443
|
+
bp_lower=self.emg_filter.bp_lower,
|
|
1444
|
+
bp_upper=self.emg_filter.bp_upper,
|
|
1445
|
+
),
|
|
1446
|
+
hyperparameters=Hyperparameters(
|
|
1447
|
+
batch_size=self.hyperparameters.batch_size,
|
|
1448
|
+
learning_rate=self.hyperparameters.learning_rate,
|
|
1449
|
+
momentum=self.hyperparameters.momentum,
|
|
1450
|
+
training_epochs=self.hyperparameters.training_epochs,
|
|
1451
|
+
),
|
|
1361
1452
|
)
|
|
1362
1453
|
self.ui.save_config_status.setText("configuration saved")
|
|
1363
1454
|
|
|
1455
|
+
def reset_emg_filter_settings(self) -> None:
|
|
1456
|
+
self.ui.emg_order_spinbox.setValue(DEFAULT_EMG_FILTER_ORDER)
|
|
1457
|
+
self.ui.bp_lower_spinbox.setValue(DEFAULT_EMG_BP_LOWER)
|
|
1458
|
+
self.ui.bp_upper_spinbox.setValue(DEFAULT_EMG_BP_UPPER)
|
|
1459
|
+
|
|
1460
|
+
def reset_hyperparams_settings(self):
|
|
1461
|
+
self.ui.batch_size_spinbox.setValue(DEFAULT_BATCH_SIZE)
|
|
1462
|
+
self.ui.learning_rate_spinbox.setValue(DEFAULT_LEARNING_RATE)
|
|
1463
|
+
self.ui.momentum_spinbox.setValue(DEFAULT_MOMENTUM)
|
|
1464
|
+
self.ui.training_epochs_spinbox.setValue(DEFAULT_TRAINING_EPOCHS)
|
|
1465
|
+
|
|
1364
1466
|
|
|
1365
1467
|
def check_label_validity(
|
|
1366
1468
|
labels: np.array,
|