accusleepy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accusleepy/__init__.py +0 -0
- accusleepy/__main__.py +4 -0
- accusleepy/brain_state_set.py +89 -0
- accusleepy/classification.py +267 -0
- accusleepy/config.json +22 -0
- accusleepy/constants.py +37 -0
- accusleepy/fileio.py +201 -0
- accusleepy/gui/__init__.py +0 -0
- accusleepy/gui/icons/brightness_down.png +0 -0
- accusleepy/gui/icons/brightness_up.png +0 -0
- accusleepy/gui/icons/double_down_arrow.png +0 -0
- accusleepy/gui/icons/double_up_arrow.png +0 -0
- accusleepy/gui/icons/down_arrow.png +0 -0
- accusleepy/gui/icons/home.png +0 -0
- accusleepy/gui/icons/question.png +0 -0
- accusleepy/gui/icons/save.png +0 -0
- accusleepy/gui/icons/up_arrow.png +0 -0
- accusleepy/gui/icons/zoom_in.png +0 -0
- accusleepy/gui/icons/zoom_out.png +0 -0
- accusleepy/gui/main.py +1372 -0
- accusleepy/gui/manual_scoring.py +1086 -0
- accusleepy/gui/mplwidget.py +356 -0
- accusleepy/gui/primary_window.py +2330 -0
- accusleepy/gui/primary_window.ui +3432 -0
- accusleepy/gui/resources.qrc +16 -0
- accusleepy/gui/resources_rc.py +6710 -0
- accusleepy/gui/text/config_guide.txt +24 -0
- accusleepy/gui/text/main_guide.txt +142 -0
- accusleepy/gui/text/manual_scoring_guide.txt +28 -0
- accusleepy/gui/viewer_window.py +598 -0
- accusleepy/gui/viewer_window.ui +894 -0
- accusleepy/models.py +48 -0
- accusleepy/multitaper.py +659 -0
- accusleepy/signal_processing.py +589 -0
- accusleepy-0.1.0.dist-info/METADATA +57 -0
- accusleepy-0.1.0.dist-info/RECORD +37 -0
- accusleepy-0.1.0.dist-info/WHEEL +4 -0
accusleepy/__init__.py
ADDED
|
File without changes
|
accusleepy/brain_state_set.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
BRAIN_STATES_KEY = "brain_states"
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
|
|
9
|
+
class BrainState:
|
|
10
|
+
"""Convenience class for a brain state and its attributes"""
|
|
11
|
+
|
|
12
|
+
name: str # friendly name
|
|
13
|
+
digit: int # number 0-9 - used as keyboard shortcut and in label files
|
|
14
|
+
is_scored: bool # whether a classification model should score this state
|
|
15
|
+
frequency: int | float # typical relative frequency, between 0 and 1
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BrainStateSet:
|
|
19
|
+
def __init__(self, brain_states: list[BrainState], undefined_label: int):
|
|
20
|
+
"""Initialize set of brain states
|
|
21
|
+
|
|
22
|
+
:param brain_states: list of BrainState objects
|
|
23
|
+
:param undefined_label: label for undefined epochs
|
|
24
|
+
"""
|
|
25
|
+
self.brain_states = brain_states
|
|
26
|
+
|
|
27
|
+
# The user can choose any subset of the digits 0-9 to represent
|
|
28
|
+
# brain states, but not all of them are necessarily intended to be
|
|
29
|
+
# scored by a classifier, and pytorch requires that all input
|
|
30
|
+
# labels are in the 0-n range for training and inference.
|
|
31
|
+
# So, we have to have a distinction between "brain states" (as
|
|
32
|
+
# represented in label files and keyboard inputs) and "classes"
|
|
33
|
+
# (AccuSleep's internal representation).
|
|
34
|
+
|
|
35
|
+
# map digits to classes, and vice versa
|
|
36
|
+
self.digit_to_class = {undefined_label: None}
|
|
37
|
+
self.class_to_digit = dict()
|
|
38
|
+
# relative frequencies of each class
|
|
39
|
+
self.mixture_weights = list()
|
|
40
|
+
|
|
41
|
+
i = 0
|
|
42
|
+
for brain_state in self.brain_states:
|
|
43
|
+
if brain_state.digit == undefined_label:
|
|
44
|
+
raise Exception(
|
|
45
|
+
f"Digit for {brain_state.name} matches 'undefined' label"
|
|
46
|
+
)
|
|
47
|
+
if brain_state.is_scored:
|
|
48
|
+
self.digit_to_class[brain_state.digit] = i
|
|
49
|
+
self.class_to_digit[i] = brain_state.digit
|
|
50
|
+
self.mixture_weights.append(brain_state.frequency)
|
|
51
|
+
i += 1
|
|
52
|
+
else:
|
|
53
|
+
self.digit_to_class[brain_state.digit] = None
|
|
54
|
+
|
|
55
|
+
self.n_classes = i
|
|
56
|
+
|
|
57
|
+
self.mixture_weights = np.array(self.mixture_weights)
|
|
58
|
+
if np.sum(self.mixture_weights) != 1:
|
|
59
|
+
raise Exception("Typical frequencies for scored brain states must sum to 1")
|
|
60
|
+
|
|
61
|
+
def convert_digit_to_class(self, digits: np.array) -> np.array:
|
|
62
|
+
"""Convert array of digits to their corresponding classes
|
|
63
|
+
|
|
64
|
+
:param digits: array of digits
|
|
65
|
+
:return: array of classes
|
|
66
|
+
"""
|
|
67
|
+
return np.array([self.digit_to_class[i] for i in digits])
|
|
68
|
+
|
|
69
|
+
def convert_class_to_digit(self, classes: np.array) -> np.array:
|
|
70
|
+
"""Convert array of classes to their corresponding digits
|
|
71
|
+
|
|
72
|
+
:param classes: array of classes
|
|
73
|
+
:return: array of digits
|
|
74
|
+
"""
|
|
75
|
+
return np.array([self.class_to_digit[i] for i in classes])
|
|
76
|
+
|
|
77
|
+
def to_output_dict(self) -> dict:
|
|
78
|
+
"""Return dictionary of brain states"""
|
|
79
|
+
return {
|
|
80
|
+
BRAIN_STATES_KEY: [
|
|
81
|
+
{
|
|
82
|
+
"name": b.name,
|
|
83
|
+
"digit": b.digit,
|
|
84
|
+
"is_scored": b.is_scored,
|
|
85
|
+
"frequency": b.frequency,
|
|
86
|
+
}
|
|
87
|
+
for b in self.brain_states
|
|
88
|
+
]
|
|
89
|
+
}
|
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import torch
|
|
6
|
+
import torch.optim as optim
|
|
7
|
+
from torch import nn
|
|
8
|
+
from torch.utils.data import DataLoader, Dataset
|
|
9
|
+
from torchvision.io import read_image
|
|
10
|
+
from tqdm import trange
|
|
11
|
+
|
|
12
|
+
import accusleepy.constants as c
|
|
13
|
+
from accusleepy.brain_state_set import BrainStateSet
|
|
14
|
+
from accusleepy.models import SSANN
|
|
15
|
+
from accusleepy.signal_processing import (
|
|
16
|
+
create_eeg_emg_image,
|
|
17
|
+
format_img,
|
|
18
|
+
get_mixture_values,
|
|
19
|
+
mixture_z_score_img,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
# hyperparameters used by train_model()
BATCH_SIZE = 64  # minibatch size for the training DataLoader
LEARNING_RATE = 1e-3  # SGD learning rate
MOMENTUM = 0.9  # SGD momentum
TRAINING_EPOCHS = 6  # number of passes over the training set
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class AccuSleepImageDataset(Dataset):
    """Dataset for loading AccuSleep training images"""

    def __init__(
        self, annotations_file, img_dir, transform=None, target_transform=None
    ):
        """Initialize the dataset.

        :param annotations_file: CSV with one row per image; must contain
            the filename and label columns (see accusleepy.constants)
        :param img_dir: directory containing the training images
        :param transform: optional callable applied to each image
        :param target_transform: optional callable applied to each label
        """
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load one (image, label) pair by annotation row index."""
        filename = self.img_labels.at[idx, c.FILENAME_COL]
        image = read_image(str(os.path.join(self.img_dir, filename)))
        label = self.img_labels.at[idx, c.LABEL_COL]
        # apply any user-supplied transforms
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def get_device() -> str:
    """Return the best available torch device type as a string.

    Uses the current accelerator type (e.g. "cuda") when one is available
    and falls back to "cpu" otherwise. The ``torch.accelerator`` API only
    exists in recent torch versions, so we also fall back to "cpu" when
    the attribute is absent rather than raising AttributeError.

    :return: device type string, e.g. "cuda" or "cpu"
    """
    accelerator = getattr(torch, "accelerator", None)
    if accelerator is not None and accelerator.is_available():
        return accelerator.current_accelerator().type
    return "cpu"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def train_model(
    annotations_file: str,
    img_dir: str,
    mixture_weights: np.array,
    n_classes: int,
) -> SSANN:
    """Train a classification model for sleep scoring

    :param annotations_file: file with information on each training image
    :param img_dir: training image location
    :param mixture_weights: typical relative frequencies of brain states
    :param n_classes: number of classes the model will learn
    :return: trained Sleep Scoring Artificial Neural Network model
    """
    dataloader = DataLoader(
        AccuSleepImageDataset(annotations_file=annotations_file, img_dir=img_dir),
        batch_size=BATCH_SIZE,
        shuffle=True,
    )

    device = get_device()
    model = SSANN(n_classes=n_classes)
    model.to(device)
    model.train()

    # weight each class by the inverse of its typical frequency to
    # correct for class imbalance in the training data
    class_weights = torch.tensor((mixture_weights**-1).astype("float32")).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

    # standard SGD training loop
    for _ in trange(TRAINING_EPOCHS):
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

    return model
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def score_recording(
    model: SSANN,
    eeg: np.array,
    emg: np.array,
    mixture_means: np.array,
    mixture_sds: np.array,
    sampling_rate: int | float,
    epoch_length: int | float,
    epochs_per_img: int,
    brain_state_set: BrainStateSet,
) -> np.array:
    """Use classification model to get brain state labels for a recording

    This assumes signals have been preprocessed to contain an integer
    number of epochs.

    :param model: classification model
    :param eeg: EEG signal
    :param emg: EMG signal
    :param mixture_means: mixture means, for calibration
    :param mixture_sds: mixture standard deviations, for calibration
    :param sampling_rate: sampling rate, in Hz
    :param epoch_length: epoch length, in seconds
    :param epochs_per_img: number of epochs for the model to consider
    :param brain_state_set: set of brain state options
    :return: brain state labels
    """
    # put the model on the best available device, in inference mode
    device = get_device()
    model = model.to(device)
    model.eval()

    # build the calibrated EEG+EMG spectrogram image
    img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
    img = mixture_z_score_img(
        img,
        mixture_means=mixture_means,
        mixture_sds=mixture_sds,
        brain_state_set=brain_state_set,
    )
    img = format_img(img=img, epochs_per_img=epochs_per_img, add_padding=True)

    # one model input per epoch, taken as a sliding window over the image
    n_windows = img.shape[1] - epochs_per_img + 1
    windows = np.array(
        [
            img[:, start : (start + epochs_per_img)].astype("float32")
            for start in range(n_windows)
        ]
    )
    batch = torch.from_numpy(windows)[:, None, :, :]  # add channel axis
    batch = batch.to(device)

    # classify every epoch in one forward pass
    with torch.no_grad():
        _, predicted = torch.max(model(batch), 1)

    # translate model classes back to user-facing digits
    return brain_state_set.convert_class_to_digit(predicted.cpu().numpy())
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def example_real_time_scoring_function(
    model: SSANN,
    eeg: np.array,
    emg: np.array,
    mixture_means: np.array,
    mixture_sds: np.array,
    sampling_rate: int | float,
    epoch_length: int | float,
    epochs_per_img: int,
    brain_state_set: BrainStateSet,
) -> int:
    """Example function that could be used for real-time scoring

    This function demonstrates how a model trained in "real-time" mode
    (current epoch on the right side of each image) could score incoming
    data. Passing a segment of EEG/EMG data into this function scores
    the most recent epoch. For example, if the model expects 9 epochs
    worth of data and the epoch length is 5 seconds, you would pass in
    45 seconds of data and obtain the brain state of the most recent
    5 seconds.

    Note:
    - The EEG and EMG signals must have length equal to
      sampling_rate * epoch_length * <number of epochs per image>.
    - The number of samples per epoch must be an integer.
    - This is just a demonstration; you should customize it for your
      application, and there are probably ways to make it run faster.

    :param model: classification model
    :param eeg: EEG signal
    :param emg: EMG signal
    :param mixture_means: mixture means, for calibration
    :param mixture_sds: mixture standard deviations, for calibration
    :param sampling_rate: sampling rate, in Hz
    :param epoch_length: epoch length, in seconds
    :param epochs_per_img: number of epochs shown to the model at once
    :param brain_state_set: set of brain state options
    :return: brain state label
    """
    # put the model on the best available device, in inference mode
    # (this could be done once, outside the function)
    device = get_device()
    model = model.to(device)
    model.eval()

    # build the calibrated EEG+EMG spectrogram image
    img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
    img = mixture_z_score_img(
        img,
        mixture_means=mixture_means,
        mixture_sds=mixture_sds,
        brain_state_set=brain_state_set,
    )
    img = format_img(img=img, epochs_per_img=epochs_per_img, add_padding=False)

    # single-image batch with a channel axis, as the model expects
    batch = torch.from_numpy(np.array([img.astype("float32")]))[:, None, :, :]
    batch = batch.to(device)

    # classify the most recent epoch
    with torch.no_grad():
        _, predicted = torch.max(model(batch), 1)

    return int(brain_state_set.convert_class_to_digit(predicted.cpu().numpy())[0])
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def create_calibration_file(
    filename: str,
    eeg: np.array,
    emg: np.array,
    labels: np.array,
    sampling_rate: int | float,
    epoch_length: int | float,
    brain_state_set: BrainStateSet,
) -> None:
    """Create file of calibration data for a subject

    This assumes signals have been preprocessed to contain an integer
    number of epochs.

    :param filename: filename for the calibration file
    :param eeg: EEG signal
    :param emg: EMG signal
    :param labels: brain state labels, as digits
    :param sampling_rate: sampling rate, in Hz
    :param epoch_length: epoch length, in seconds
    :param brain_state_set: set of brain state options
    """
    spectrogram = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length)
    # labels arrive as user-facing digits; the mixture computation
    # expects the internal class representation
    classes = brain_state_set.convert_digit_to_class(labels)
    mixture_means, mixture_sds = get_mixture_values(
        img=spectrogram,
        labels=classes,
        brain_state_set=brain_state_set,
    )
    calibration = pd.DataFrame(
        {c.MIXTURE_MEAN_COL: mixture_means, c.MIXTURE_SD_COL: mixture_sds}
    )
    calibration.to_csv(filename, index=False)
|
accusleepy/config.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
{
|
|
2
|
+
"brain_states": [
|
|
3
|
+
{
|
|
4
|
+
"name": "REM",
|
|
5
|
+
"digit": 1,
|
|
6
|
+
"is_scored": true,
|
|
7
|
+
"frequency": 0.1
|
|
8
|
+
},
|
|
9
|
+
{
|
|
10
|
+
"name": "Wake",
|
|
11
|
+
"digit": 2,
|
|
12
|
+
"is_scored": true,
|
|
13
|
+
"frequency": 0.35
|
|
14
|
+
},
|
|
15
|
+
{
|
|
16
|
+
"name": "NREM",
|
|
17
|
+
"digit": 3,
|
|
18
|
+
"is_scored": true,
|
|
19
|
+
"frequency": 0.55
|
|
20
|
+
}
|
|
21
|
+
]
|
|
22
|
+
}
|
accusleepy/constants.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# Shared constants: file formats, column names, and model metadata.

# probably don't change these unless you really need to
UNDEFINED_LABEL = -1  # can't be the same as a brain state's digit, must be an integer
# calibration file columns
MIXTURE_MEAN_COL = "mixture_mean"
MIXTURE_SD_COL = "mixture_sd"
# recording file columns
EEG_COL = "eeg"
EMG_COL = "emg"
# label file columns
BRAIN_STATE_COL = "brain_state"


# really don't change these
# config file location (relative to the accusleepy package directory)
CONFIG_FILE = "config.json"
# number of times to include the EMG power in a training image
EMG_COPIES = 9
# minimum spectrogram window length, in seconds
MIN_WINDOW_LEN = 5
# frequency above which to downsample EEG spectrograms
DOWNSAMPLING_START_FREQ = 20
# upper frequency cutoff for EEG spectrograms
UPPER_FREQ = 50
# classification model types
DEFAULT_MODEL_TYPE = "default"  # current epoch is centered
REAL_TIME_MODEL_TYPE = "real-time"  # current epoch on the right
# valid filetypes
RECORDING_FILE_TYPES = [".parquet", ".csv"]
LABEL_FILE_TYPE = ".csv"
CALIBRATION_FILE_TYPE = ".csv"
MODEL_FILE_TYPE = ".pth"
# annotation file columns
FILENAME_COL = "filename"
LABEL_COL = "label"
# recording list file header:
RECORDING_LIST_NAME = "recording_list"
RECORDING_LIST_FILE_TYPE = ".json"
|
accusleepy/fileio.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import torch
|
|
8
|
+
from PySide6.QtWidgets import QListWidgetItem
|
|
9
|
+
|
|
10
|
+
from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
|
|
11
|
+
from accusleepy.constants import (
|
|
12
|
+
BRAIN_STATE_COL,
|
|
13
|
+
CONFIG_FILE,
|
|
14
|
+
EEG_COL,
|
|
15
|
+
EMG_COL,
|
|
16
|
+
MIXTURE_MEAN_COL,
|
|
17
|
+
MIXTURE_SD_COL,
|
|
18
|
+
RECORDING_LIST_NAME,
|
|
19
|
+
UNDEFINED_LABEL,
|
|
20
|
+
)
|
|
21
|
+
from accusleepy.models import SSANN
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class Recording:
    """Store information about a recording

    Instances are serialized to / loaded from a recording list json file
    (see save_recording_list / load_recording_list); the `widget` field
    is GUI-only and is not serialized.
    """

    # name to show in the GUI; reassigned to the recording's 1-based
    # list position when a recording list is loaded
    name: int = 1
    recording_file: str = ""  # path to recording file
    label_file: str = ""  # path to label file
    calibration_file: str = ""  # path to calibration file
    sampling_rate: int | float = 0.0  # sampling rate, in Hz
    widget: QListWidgetItem = None  # list item widget shown in the GUI
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def load_calibration_file(filename: str) -> (np.array, np.array):
    """Load a calibration file

    :param filename: filename
    :return: mixture means and SDs
    """
    calibration = pd.read_csv(filename)
    return (
        calibration[MIXTURE_MEAN_COL].values,
        calibration[MIXTURE_SD_COL].values,
    )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def save_model(
    model: SSANN,
    filename: str,
    epoch_length: int | float,
    epochs_per_img: int,
    model_type: str,
    brain_state_set: BrainStateSet,
) -> None:
    """Save classification model and its metadata

    :param model: classification model
    :param filename: filename
    :param epoch_length: epoch length used when training the model
    :param epochs_per_img: number of epochs in each model input
    :param model_type: default or real-time
    :param brain_state_set: set of brain state options
    """
    # bundle training metadata alongside the model weights so that
    # load_model can restore everything from a single file
    state_dict = model.state_dict()
    state_dict["epoch_length"] = epoch_length
    state_dict["epochs_per_img"] = epochs_per_img
    state_dict["model_type"] = model_type
    state_dict[BRAIN_STATES_KEY] = brain_state_set.to_output_dict()[BRAIN_STATES_KEY]

    torch.save(state_dict, filename)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def load_model(filename: str) -> tuple[SSANN, int | float, int, str, dict]:
    """Load classification model and its metadata

    :param filename: filename
    :return: model, epoch length used when training the model,
        number of epochs in each model input, model type
        (default or real-time), set of brain state options
        used when training the model
    """
    state_dict = torch.load(filename, weights_only=True)

    # metadata entries were added by save_model and must be removed
    # before loading the remaining weights into the network
    epoch_length = state_dict.pop("epoch_length")
    epochs_per_img = state_dict.pop("epochs_per_img")
    model_type = state_dict.pop("model_type")
    brain_states = state_dict.pop(BRAIN_STATES_KEY)

    # the network's output size is the number of scored states
    n_classes = len([b for b in brain_states if b["is_scored"]])
    model = SSANN(n_classes=n_classes)
    model.load_state_dict(state_dict)
    return model, epoch_length, epochs_per_img, model_type, brain_states
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def load_csv_or_parquet(filename: str) -> pd.DataFrame:
    """Load a csv or parquet file as a dataframe

    :param filename: filename
    :return: dataframe of file contents
    :raises ValueError: if the file extension is not .csv or .parquet
    """
    # compare extensions case-insensitively so e.g. ".CSV" also loads
    extension = os.path.splitext(filename)[1].lower()
    if extension == ".csv":
        return pd.read_csv(filename)
    if extension == ".parquet":
        return pd.read_parquet(filename)
    raise ValueError("file must be csv or parquet")
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def load_recording(filename: str) -> (np.array, np.array):
    """Load recording of EEG and EMG time series data

    :param filename: filename
    :return: arrays of EEG and EMG data
    """
    recording = load_csv_or_parquet(filename)
    return recording[EEG_COL].values, recording[EMG_COL].values
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def load_labels(filename: str) -> np.array:
    """Load file of brain state labels

    :param filename: filename
    :return: array of brain state labels
    """
    labels = load_csv_or_parquet(filename)[BRAIN_STATE_COL]
    return labels.values
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def save_labels(labels: np.array, filename: str) -> None:
    """Save brain state labels to file

    :param labels: brain state labels
    :param filename: filename
    """
    output = pd.DataFrame({BRAIN_STATE_COL: labels})
    output.to_csv(filename, index=False)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def load_config() -> BrainStateSet:
    """Load configuration file with brain state options

    :return: set of brain state options
    """
    # the config file lives alongside this module in the package
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE)
    with open(config_path, "r") as f:
        data = json.load(f)
    brain_states = [BrainState(**b) for b in data[BRAIN_STATES_KEY]]
    return BrainStateSet(brain_states, UNDEFINED_LABEL)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def save_config(brain_state_set: BrainStateSet) -> None:
    """Save configuration of brain state options to json file

    :param brain_state_set: set of brain state options
    """
    # the config file lives alongside this module in the package
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE)
    with open(config_path, "w") as f:
        json.dump(brain_state_set.to_output_dict(), f, indent=4)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def load_recording_list(filename: str) -> list[Recording]:
    """Load list of recordings from file

    :param filename: filename of list of recordings
    :return: list of recordings
    """
    with open(filename, "r") as f:
        data = json.load(f)
    recordings = [Recording(**entry) for entry in data[RECORDING_LIST_NAME]]
    # display names are the recordings' 1-based positions in the list
    for position, recording in enumerate(recordings, start=1):
        recording.name = position
    return recordings
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def save_recording_list(filename: str, recordings: list[Recording]) -> None:
    """Save list of recordings to file

    :param filename: where to save the list
    :param recordings: list of recordings to export
    """
    # the GUI-only fields (name, widget) are intentionally not exported
    entries = [
        {
            "recording_file": r.recording_file,
            "label_file": r.label_file,
            "calibration_file": r.calibration_file,
            "sampling_rate": r.sampling_rate,
        }
        for r in recordings
    ]
    with open(filename, "w") as f:
        json.dump({RECORDING_LIST_NAME: entries}, f, indent=4)
|
|
File without changes
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|