accusleepy 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accusleepy/bouts.py +142 -0
- accusleepy/classification.py +2 -2
- accusleepy/constants.py +2 -0
- accusleepy/fileio.py +0 -53
- accusleepy/gui/images/primary_window.png +0 -0
- accusleepy/gui/main.py +84 -64
- accusleepy/gui/manual_scoring.py +76 -81
- accusleepy/gui/mplwidget.py +15 -10
- accusleepy/gui/primary_window.py +1 -0
- accusleepy/models.py +57 -7
- accusleepy/multitaper.py +9 -7
- accusleepy/signal_processing.py +5 -143
- {accusleepy-0.4.4.dist-info → accusleepy-0.5.0.dist-info}/METADATA +29 -19
- {accusleepy-0.4.4.dist-info → accusleepy-0.5.0.dist-info}/RECORD +15 -14
- {accusleepy-0.4.4.dist-info → accusleepy-0.5.0.dist-info}/WHEEL +0 -0
accusleepy/gui/manual_scoring.py
CHANGED
|
@@ -12,7 +12,24 @@ from types import SimpleNamespace
|
|
|
12
12
|
|
|
13
13
|
import matplotlib.pyplot as plt
|
|
14
14
|
import numpy as np
|
|
15
|
-
from PySide6 import
|
|
15
|
+
from PySide6.QtCore import (
|
|
16
|
+
QKeyCombination,
|
|
17
|
+
QRect,
|
|
18
|
+
Qt,
|
|
19
|
+
QUrl,
|
|
20
|
+
)
|
|
21
|
+
from PySide6.QtGui import (
|
|
22
|
+
QCloseEvent,
|
|
23
|
+
QKeySequence,
|
|
24
|
+
QShortcut,
|
|
25
|
+
)
|
|
26
|
+
from PySide6.QtWidgets import (
|
|
27
|
+
QDialog,
|
|
28
|
+
QMessageBox,
|
|
29
|
+
QTextBrowser,
|
|
30
|
+
QVBoxLayout,
|
|
31
|
+
QWidget,
|
|
32
|
+
)
|
|
16
33
|
|
|
17
34
|
from accusleepy.constants import UNDEFINED_LABEL
|
|
18
35
|
from accusleepy.fileio import load_config, save_labels
|
|
@@ -73,7 +90,7 @@ class StateChange:
|
|
|
73
90
|
epoch: int # first epoch affected
|
|
74
91
|
|
|
75
92
|
|
|
76
|
-
class ManualScoringWindow(
|
|
93
|
+
class ManualScoringWindow(QDialog):
|
|
77
94
|
"""AccuSleePy manual scoring GUI"""
|
|
78
95
|
|
|
79
96
|
def __init__(
|
|
@@ -191,33 +208,25 @@ class ManualScoringWindow(QtWidgets.QDialog):
|
|
|
191
208
|
self.update_lower_figure()
|
|
192
209
|
|
|
193
210
|
# user input: keyboard shortcuts
|
|
194
|
-
keypress_right =
|
|
195
|
-
QtGui.QKeySequence(QtCore.Qt.Key.Key_Right), self
|
|
196
|
-
)
|
|
211
|
+
keypress_right = QShortcut(QKeySequence(Qt.Key.Key_Right), self)
|
|
197
212
|
keypress_right.activated.connect(partial(self.shift_epoch, DIRECTION_RIGHT))
|
|
198
213
|
|
|
199
|
-
keypress_left =
|
|
200
|
-
QtGui.QKeySequence(QtCore.Qt.Key.Key_Left), self
|
|
201
|
-
)
|
|
214
|
+
keypress_left = QShortcut(QKeySequence(Qt.Key.Key_Left), self)
|
|
202
215
|
keypress_left.activated.connect(partial(self.shift_epoch, DIRECTION_LEFT))
|
|
203
216
|
|
|
204
217
|
keypress_zoom_in_x = list()
|
|
205
|
-
for zoom_key in [
|
|
206
|
-
keypress_zoom_in_x.append(
|
|
207
|
-
QtGui.QShortcut(QtGui.QKeySequence(zoom_key), self)
|
|
208
|
-
)
|
|
218
|
+
for zoom_key in [Qt.Key.Key_Plus, Qt.Key.Key_Equal]:
|
|
219
|
+
keypress_zoom_in_x.append(QShortcut(QKeySequence(zoom_key), self))
|
|
209
220
|
keypress_zoom_in_x[-1].activated.connect(partial(self.zoom_x, ZOOM_IN))
|
|
210
221
|
|
|
211
|
-
keypress_zoom_out_x =
|
|
212
|
-
QtGui.QKeySequence(QtCore.Qt.Key.Key_Minus), self
|
|
213
|
-
)
|
|
222
|
+
keypress_zoom_out_x = QShortcut(QKeySequence(Qt.Key.Key_Minus), self)
|
|
214
223
|
keypress_zoom_out_x.activated.connect(partial(self.zoom_x, ZOOM_OUT))
|
|
215
224
|
|
|
216
225
|
keypress_modify_label = list()
|
|
217
226
|
for brain_state in self.brain_state_set.brain_states:
|
|
218
227
|
keypress_modify_label.append(
|
|
219
|
-
|
|
220
|
-
|
|
228
|
+
QShortcut(
|
|
229
|
+
QKeySequence(Qt.Key[f"Key_{brain_state.digit}"]),
|
|
221
230
|
self,
|
|
222
231
|
)
|
|
223
232
|
)
|
|
@@ -225,25 +234,19 @@ class ManualScoringWindow(QtWidgets.QDialog):
|
|
|
225
234
|
partial(self.modify_current_epoch_label, brain_state.digit)
|
|
226
235
|
)
|
|
227
236
|
|
|
228
|
-
keypress_delete_label =
|
|
229
|
-
QtGui.QKeySequence(QtCore.Qt.Key.Key_Backspace), self
|
|
230
|
-
)
|
|
237
|
+
keypress_delete_label = QShortcut(QKeySequence(Qt.Key.Key_Backspace), self)
|
|
231
238
|
keypress_delete_label.activated.connect(
|
|
232
239
|
partial(self.modify_current_epoch_label, UNDEFINED_LABEL)
|
|
233
240
|
)
|
|
234
241
|
|
|
235
|
-
keypress_quit =
|
|
236
|
-
|
|
237
|
-
QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_W)
|
|
238
|
-
),
|
|
242
|
+
keypress_quit = QShortcut(
|
|
243
|
+
QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_W)),
|
|
239
244
|
self,
|
|
240
245
|
)
|
|
241
246
|
keypress_quit.activated.connect(self.close)
|
|
242
247
|
|
|
243
|
-
keypress_save =
|
|
244
|
-
|
|
245
|
-
QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_S)
|
|
246
|
-
),
|
|
248
|
+
keypress_save = QShortcut(
|
|
249
|
+
QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_S)),
|
|
247
250
|
self,
|
|
248
251
|
)
|
|
249
252
|
keypress_save.activated.connect(self.save)
|
|
@@ -251,11 +254,11 @@ class ManualScoringWindow(QtWidgets.QDialog):
|
|
|
251
254
|
keypress_roi = list()
|
|
252
255
|
for brain_state in self.brain_state_set.brain_states:
|
|
253
256
|
keypress_roi.append(
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
257
|
+
QShortcut(
|
|
258
|
+
QKeySequence(
|
|
259
|
+
QKeyCombination(
|
|
260
|
+
Qt.Modifier.SHIFT,
|
|
261
|
+
Qt.Key[f"Key_{brain_state.digit}"],
|
|
259
262
|
)
|
|
260
263
|
),
|
|
261
264
|
self,
|
|
@@ -265,11 +268,11 @@ class ManualScoringWindow(QtWidgets.QDialog):
|
|
|
265
268
|
partial(self.enter_label_roi_mode, brain_state.digit)
|
|
266
269
|
)
|
|
267
270
|
keypress_roi.append(
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
271
|
+
QShortcut(
|
|
272
|
+
QKeySequence(
|
|
273
|
+
QKeyCombination(
|
|
274
|
+
Qt.Modifier.SHIFT,
|
|
275
|
+
Qt.Key.Key_Backspace,
|
|
273
276
|
)
|
|
274
277
|
),
|
|
275
278
|
self,
|
|
@@ -279,22 +282,18 @@ class ManualScoringWindow(QtWidgets.QDialog):
|
|
|
279
282
|
partial(self.enter_label_roi_mode, UNDEFINED_LABEL)
|
|
280
283
|
)
|
|
281
284
|
|
|
282
|
-
keypress_esc =
|
|
283
|
-
QtGui.QKeySequence(QtCore.Qt.Key.Key_Escape), self
|
|
284
|
-
)
|
|
285
|
+
keypress_esc = QShortcut(QKeySequence(Qt.Key.Key_Escape), self)
|
|
285
286
|
keypress_esc.activated.connect(self.exit_label_roi_mode)
|
|
286
287
|
|
|
287
|
-
keypress_space =
|
|
288
|
-
QtGui.QKeySequence(QtCore.Qt.Key.Key_Space), self
|
|
289
|
-
)
|
|
288
|
+
keypress_space = QShortcut(QKeySequence(Qt.Key.Key_Space), self)
|
|
290
289
|
keypress_space.activated.connect(
|
|
291
290
|
partial(self.jump_to_next_state, DIRECTION_RIGHT, DIFFERENT_STATE)
|
|
292
291
|
)
|
|
293
|
-
keypress_shift_right =
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
292
|
+
keypress_shift_right = QShortcut(
|
|
293
|
+
QKeySequence(
|
|
294
|
+
QKeyCombination(
|
|
295
|
+
Qt.Modifier.SHIFT,
|
|
296
|
+
Qt.Key.Key_Right,
|
|
298
297
|
)
|
|
299
298
|
),
|
|
300
299
|
self,
|
|
@@ -302,11 +301,11 @@ class ManualScoringWindow(QtWidgets.QDialog):
|
|
|
302
301
|
keypress_shift_right.activated.connect(
|
|
303
302
|
partial(self.jump_to_next_state, DIRECTION_RIGHT, DIFFERENT_STATE)
|
|
304
303
|
)
|
|
305
|
-
keypress_shift_left =
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
304
|
+
keypress_shift_left = QShortcut(
|
|
305
|
+
QKeySequence(
|
|
306
|
+
QKeyCombination(
|
|
307
|
+
Qt.Modifier.SHIFT,
|
|
308
|
+
Qt.Key.Key_Left,
|
|
310
309
|
)
|
|
311
310
|
),
|
|
312
311
|
self,
|
|
@@ -314,11 +313,11 @@ class ManualScoringWindow(QtWidgets.QDialog):
|
|
|
314
313
|
keypress_shift_left.activated.connect(
|
|
315
314
|
partial(self.jump_to_next_state, DIRECTION_LEFT, DIFFERENT_STATE)
|
|
316
315
|
)
|
|
317
|
-
keypress_ctrl_right =
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
316
|
+
keypress_ctrl_right = QShortcut(
|
|
317
|
+
QKeySequence(
|
|
318
|
+
QKeyCombination(
|
|
319
|
+
Qt.Modifier.CTRL,
|
|
320
|
+
Qt.Key.Key_Right,
|
|
322
321
|
)
|
|
323
322
|
),
|
|
324
323
|
self,
|
|
@@ -326,11 +325,11 @@ class ManualScoringWindow(QtWidgets.QDialog):
|
|
|
326
325
|
keypress_ctrl_right.activated.connect(
|
|
327
326
|
partial(self.jump_to_next_state, DIRECTION_RIGHT, UNDEFINED_STATE)
|
|
328
327
|
)
|
|
329
|
-
keypress_ctrl_left =
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
328
|
+
keypress_ctrl_left = QShortcut(
|
|
329
|
+
QKeySequence(
|
|
330
|
+
QKeyCombination(
|
|
331
|
+
Qt.Modifier.CTRL,
|
|
332
|
+
Qt.Key.Key_Left,
|
|
334
333
|
)
|
|
335
334
|
),
|
|
336
335
|
self,
|
|
@@ -339,17 +338,13 @@ class ManualScoringWindow(QtWidgets.QDialog):
|
|
|
339
338
|
partial(self.jump_to_next_state, DIRECTION_LEFT, UNDEFINED_STATE)
|
|
340
339
|
)
|
|
341
340
|
|
|
342
|
-
keypress_undo =
|
|
343
|
-
|
|
344
|
-
QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_Z)
|
|
345
|
-
),
|
|
341
|
+
keypress_undo = QShortcut(
|
|
342
|
+
QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_Z)),
|
|
346
343
|
self,
|
|
347
344
|
)
|
|
348
345
|
keypress_undo.activated.connect(self.undo)
|
|
349
|
-
keypress_redo =
|
|
350
|
-
|
|
351
|
-
QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_Y)
|
|
352
|
-
),
|
|
346
|
+
keypress_redo = QShortcut(
|
|
347
|
+
QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_Y)),
|
|
353
348
|
self,
|
|
354
349
|
)
|
|
355
350
|
keypress_redo.activated.connect(self.redo)
|
|
@@ -483,34 +478,34 @@ class ManualScoringWindow(QtWidgets.QDialog):
|
|
|
483
478
|
)
|
|
484
479
|
self.click_to_jump(simulated_click)
|
|
485
480
|
|
|
486
|
-
def closeEvent(self, event:
|
|
481
|
+
def closeEvent(self, event: QCloseEvent) -> None:
|
|
487
482
|
"""Check if there are unsaved changes before closing"""
|
|
488
483
|
if not all(self.labels == self.last_saved_labels):
|
|
489
|
-
result =
|
|
484
|
+
result = QMessageBox.question(
|
|
490
485
|
self,
|
|
491
486
|
"Unsaved changes",
|
|
492
487
|
"You have unsaved changes. Really quit?",
|
|
493
|
-
|
|
488
|
+
QMessageBox.Yes | QMessageBox.No,
|
|
494
489
|
)
|
|
495
|
-
if result ==
|
|
490
|
+
if result == QMessageBox.Yes:
|
|
496
491
|
event.accept()
|
|
497
492
|
else:
|
|
498
493
|
event.ignore()
|
|
499
494
|
|
|
500
495
|
def show_user_manual(self) -> None:
|
|
501
496
|
"""Show a popup window with the user manual"""
|
|
502
|
-
self.popup =
|
|
503
|
-
self.popup_vlayout =
|
|
504
|
-
self.guide_textbox =
|
|
497
|
+
self.popup = QWidget()
|
|
498
|
+
self.popup_vlayout = QVBoxLayout(self.popup)
|
|
499
|
+
self.guide_textbox = QTextBrowser(self.popup)
|
|
505
500
|
self.popup_vlayout.addWidget(self.guide_textbox)
|
|
506
501
|
|
|
507
|
-
url =
|
|
502
|
+
url = QUrl.fromLocalFile(
|
|
508
503
|
os.path.join(os.path.dirname(os.path.abspath(__file__)), USER_MANUAL_FILE)
|
|
509
504
|
)
|
|
510
505
|
self.guide_textbox.setSource(url)
|
|
511
506
|
self.guide_textbox.setOpenLinks(False)
|
|
512
507
|
|
|
513
|
-
self.popup.setGeometry(
|
|
508
|
+
self.popup.setGeometry(QRect(100, 100, 830, 600))
|
|
514
509
|
self.popup.show()
|
|
515
510
|
|
|
516
511
|
def jump_to_next_state(self, direction: str, target: str) -> None:
|
accusleepy/gui/mplwidget.py
CHANGED
|
@@ -339,18 +339,23 @@ def resample_x_ticks(x_ticks: np.array) -> np.array:
|
|
|
339
339
|
"""Choose a subset of x_ticks to display
|
|
340
340
|
|
|
341
341
|
The x-axis can get crowded if there are too many timestamps shown.
|
|
342
|
-
This function
|
|
343
|
-
|
|
344
|
-
to being a factor of the number of ticks.
|
|
342
|
+
This function finds a subset of evenly spaced x-axis ticks that
|
|
343
|
+
includes the one at the beginning of the central epoch.
|
|
345
344
|
|
|
346
345
|
:param x_ticks: full set of x_ticks
|
|
347
346
|
:return: smaller subset of x_ticks
|
|
348
347
|
"""
|
|
349
|
-
|
|
350
|
-
n_ticks = len(x_ticks) + 1
|
|
351
|
-
if n_ticks < MAX_LOWER_X_TICK_N:
|
|
348
|
+
if len(x_ticks) <= MAX_LOWER_X_TICK_N:
|
|
352
349
|
return x_ticks
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
350
|
+
|
|
351
|
+
# number of ticks to the left of the central epoch
|
|
352
|
+
# this will always be an integer
|
|
353
|
+
nl = round((len(x_ticks) - 1) / 2)
|
|
354
|
+
|
|
355
|
+
# search for even tick spacings that include the central epoch
|
|
356
|
+
# if necessary, skip the leftmost tick
|
|
357
|
+
for offset in [0, 1]:
|
|
358
|
+
if (nl - offset) % 3 == 0:
|
|
359
|
+
return x_ticks[offset :: round((nl - offset) / 3)]
|
|
360
|
+
elif (nl - offset) % 2 == 0:
|
|
361
|
+
return x_ticks[offset :: round((nl - offset) / 2)]
|
accusleepy/gui/primary_window.py
CHANGED
accusleepy/models.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
import numpy as np
|
|
2
|
-
import
|
|
3
|
-
|
|
4
|
-
from torch import
|
|
2
|
+
from torch import device, flatten, nn
|
|
3
|
+
from torch import load as torch_load
|
|
4
|
+
from torch import save as torch_save
|
|
5
5
|
|
|
6
|
+
from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainStateSet
|
|
6
7
|
from accusleepy.constants import (
|
|
7
8
|
DOWNSAMPLING_START_FREQ,
|
|
8
9
|
EMG_COPIES,
|
|
@@ -41,8 +42,57 @@ class SSANN(nn.Module):
|
|
|
41
42
|
|
|
42
43
|
def forward(self, x):
|
|
43
44
|
x = x.float()
|
|
44
|
-
x = self.pool(
|
|
45
|
-
x = self.pool(
|
|
46
|
-
x = self.pool(
|
|
47
|
-
x =
|
|
45
|
+
x = self.pool(nn.functional.relu(self.conv1_bn(self.conv1(x))))
|
|
46
|
+
x = self.pool(nn.functional.relu(self.conv2_bn(self.conv2(x))))
|
|
47
|
+
x = self.pool(nn.functional.relu(self.conv3_bn(self.conv3(x))))
|
|
48
|
+
x = flatten(x, 1) # flatten all dimensions except batch
|
|
48
49
|
return self.fc1(x)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def save_model(
|
|
53
|
+
model: SSANN,
|
|
54
|
+
filename: str,
|
|
55
|
+
epoch_length: int | float,
|
|
56
|
+
epochs_per_img: int,
|
|
57
|
+
model_type: str,
|
|
58
|
+
brain_state_set: BrainStateSet,
|
|
59
|
+
) -> None:
|
|
60
|
+
"""Save classification model and its metadata
|
|
61
|
+
|
|
62
|
+
:param model: classification model
|
|
63
|
+
:param epoch_length: epoch length used when training the model
|
|
64
|
+
:param epochs_per_img: number of epochs in each model input
|
|
65
|
+
:param model_type: default or real-time
|
|
66
|
+
:param brain_state_set: set of brain state options
|
|
67
|
+
:param filename: filename
|
|
68
|
+
"""
|
|
69
|
+
state_dict = model.state_dict()
|
|
70
|
+
state_dict.update({"epoch_length": epoch_length})
|
|
71
|
+
state_dict.update({"epochs_per_img": epochs_per_img})
|
|
72
|
+
state_dict.update({"model_type": model_type})
|
|
73
|
+
state_dict.update(
|
|
74
|
+
{BRAIN_STATES_KEY: brain_state_set.to_output_dict()[BRAIN_STATES_KEY]}
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
torch_save(state_dict, filename)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def load_model(filename: str) -> tuple[SSANN, int | float, int, str, dict]:
|
|
81
|
+
"""Load classification model and its metadata
|
|
82
|
+
|
|
83
|
+
:param filename: filename
|
|
84
|
+
:return: model, epoch length used when training the model,
|
|
85
|
+
number of epochs in each model input, model type
|
|
86
|
+
(default or real-time), set of brain state options
|
|
87
|
+
used when training the model
|
|
88
|
+
"""
|
|
89
|
+
state_dict = torch_load(filename, weights_only=True, map_location=device("cpu"))
|
|
90
|
+
epoch_length = state_dict.pop("epoch_length")
|
|
91
|
+
epochs_per_img = state_dict.pop("epochs_per_img")
|
|
92
|
+
model_type = state_dict.pop("model_type")
|
|
93
|
+
brain_states = state_dict.pop(BRAIN_STATES_KEY)
|
|
94
|
+
n_classes = len([b for b in brain_states if b["is_scored"]])
|
|
95
|
+
|
|
96
|
+
model = SSANN(n_classes=n_classes)
|
|
97
|
+
model.load_state_dict(state_dict)
|
|
98
|
+
return model, epoch_length, epochs_per_img, model_type, brain_states
|
accusleepy/multitaper.py
CHANGED
|
@@ -15,8 +15,9 @@ import warnings
|
|
|
15
15
|
|
|
16
16
|
import numpy as np
|
|
17
17
|
from joblib import Parallel, cpu_count, delayed
|
|
18
|
-
|
|
19
|
-
from scipy.signal
|
|
18
|
+
|
|
19
|
+
# from scipy.signal import detrend # unused by AccuSleePy
|
|
20
|
+
# from scipy.signal.windows import dpss # lazily loaded later
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
# MULTITAPER SPECTROGRAM #
|
|
@@ -28,14 +29,14 @@ def spectrogram(
|
|
|
28
29
|
num_tapers=None,
|
|
29
30
|
window_params=None,
|
|
30
31
|
min_nfft=0,
|
|
31
|
-
detrend_opt="
|
|
32
|
+
detrend_opt="off", # this functionality is disabled
|
|
32
33
|
multiprocess=False,
|
|
33
34
|
n_jobs=None,
|
|
34
35
|
weighting="unity",
|
|
35
36
|
plot_on=False,
|
|
36
37
|
return_fig=False,
|
|
37
38
|
clim_scale=True,
|
|
38
|
-
verbose=
|
|
39
|
+
verbose=False,
|
|
39
40
|
xyflip=False,
|
|
40
41
|
ax=None,
|
|
41
42
|
):
|
|
@@ -121,6 +122,7 @@ def spectrogram(
|
|
|
121
122
|
|
|
122
123
|
__________________________________________________________________________________________________________________
|
|
123
124
|
"""
|
|
125
|
+
from scipy.signal.windows import dpss
|
|
124
126
|
|
|
125
127
|
# Process user input
|
|
126
128
|
[
|
|
@@ -618,9 +620,9 @@ def calc_mts_segment(
|
|
|
618
620
|
ret.fill(np.nan)
|
|
619
621
|
return ret
|
|
620
622
|
|
|
621
|
-
# Option to detrend data to remove low frequency DC component
|
|
622
|
-
if detrend_opt != "off":
|
|
623
|
-
|
|
623
|
+
# # Option to detrend data to remove low frequency DC component
|
|
624
|
+
# if detrend_opt != "off":
|
|
625
|
+
# data_segment = detrend(data_segment, type=detrend_opt)
|
|
624
626
|
|
|
625
627
|
# Multiply data by dpss tapers (STEP 2)
|
|
626
628
|
tapered_data = np.multiply(np.asmatrix(data_segment).T, np.asmatrix(dpss_tapers.T))
|
accusleepy/signal_processing.py
CHANGED
|
@@ -1,17 +1,14 @@
|
|
|
1
1
|
import os
|
|
2
|
-
import re
|
|
3
2
|
import warnings
|
|
4
|
-
from dataclasses import dataclass
|
|
5
|
-
from operator import attrgetter
|
|
6
3
|
|
|
7
4
|
import numpy as np
|
|
8
5
|
import pandas as pd
|
|
9
6
|
from PIL import Image
|
|
10
|
-
from scipy.signal import butter, filtfilt
|
|
11
7
|
from tqdm import trange
|
|
12
8
|
|
|
13
9
|
from accusleepy.brain_state_set import BrainStateSet
|
|
14
10
|
from accusleepy.constants import (
|
|
11
|
+
ANNOTATIONS_FILENAME,
|
|
15
12
|
DEFAULT_MODEL_TYPE,
|
|
16
13
|
DOWNSAMPLING_START_FREQ,
|
|
17
14
|
EMG_COPIES,
|
|
@@ -23,13 +20,13 @@ from accusleepy.constants import (
|
|
|
23
20
|
from accusleepy.fileio import Recording, load_labels, load_recording
|
|
24
21
|
from accusleepy.multitaper import spectrogram
|
|
25
22
|
|
|
23
|
+
# note: scipy is lazily imported
|
|
24
|
+
|
|
26
25
|
# clip mixture z-scores above and below this level
|
|
27
26
|
# in the matlab implementation, I used 4.5
|
|
28
27
|
ABS_MAX_Z_SCORE = 3.5
|
|
29
28
|
# upper frequency limit when generating EEG spectrograms
|
|
30
29
|
SPECTROGRAM_UPPER_FREQ = 64
|
|
31
|
-
# filename used to store info about training image datasets
|
|
32
|
-
ANNOTATIONS_FILENAME = "annotations.csv"
|
|
33
30
|
|
|
34
31
|
|
|
35
32
|
def resample(
|
|
@@ -186,6 +183,8 @@ def get_emg_power(
|
|
|
186
183
|
:param epoch_length: epoch length, in seconds
|
|
187
184
|
:return: EMG "power" for each epoch
|
|
188
185
|
"""
|
|
186
|
+
from scipy.signal import butter, filtfilt
|
|
187
|
+
|
|
189
188
|
# filter parameters
|
|
190
189
|
order = 8
|
|
191
190
|
bp_lower = 20
|
|
@@ -450,140 +449,3 @@ def create_training_images(
|
|
|
450
449
|
)
|
|
451
450
|
|
|
452
451
|
return failed_recordings
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
@dataclass
|
|
456
|
-
class Bout:
|
|
457
|
-
"""Stores information about a brain state bout"""
|
|
458
|
-
|
|
459
|
-
length: int # length, in number of epochs
|
|
460
|
-
start_index: int # index where bout starts
|
|
461
|
-
end_index: int # index where bout ends
|
|
462
|
-
surrounding_state: int # brain state on both sides of the bout
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
def find_last_adjacent_bout(sorted_bouts: list[Bout], bout_index: int) -> int:
|
|
466
|
-
"""Find index of last consecutive same-length bout
|
|
467
|
-
|
|
468
|
-
When running the post-processing step that enforces a minimum duration
|
|
469
|
-
for brain state bouts, there is a special case when bouts below the
|
|
470
|
-
duration threshold occur consecutively. This function performs a
|
|
471
|
-
recursive search for the index of a bout at the end of such a sequence.
|
|
472
|
-
When initially called, bout_index will always be 0. If, for example, the
|
|
473
|
-
first three bouts in the list are consecutive, the function will return 2.
|
|
474
|
-
|
|
475
|
-
:param sorted_bouts: list of brain state bouts, sorted by start time
|
|
476
|
-
:param bout_index: index of the bout in question
|
|
477
|
-
:return: index of the last consecutive same-length bout
|
|
478
|
-
"""
|
|
479
|
-
# if we're at the end of the bout list, stop
|
|
480
|
-
if bout_index == len(sorted_bouts) - 1:
|
|
481
|
-
return bout_index
|
|
482
|
-
|
|
483
|
-
# if there is an adjacent bout
|
|
484
|
-
if sorted_bouts[bout_index].end_index == sorted_bouts[bout_index + 1].start_index:
|
|
485
|
-
# look for more adjacent bouts using that one as a starting point
|
|
486
|
-
return find_last_adjacent_bout(sorted_bouts, bout_index + 1)
|
|
487
|
-
else:
|
|
488
|
-
return bout_index
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
def enforce_min_bout_length(
|
|
492
|
-
labels: np.array, epoch_length: int | float, min_bout_length: int | float
|
|
493
|
-
) -> np.array:
|
|
494
|
-
"""Ensure brain state bouts meet the min length requirement
|
|
495
|
-
|
|
496
|
-
As a post-processing step for sleep scoring, we can require that any
|
|
497
|
-
bout (continuous period) of a brain state have a minimum duration.
|
|
498
|
-
This function sets any bout shorter than the minimum duration to the
|
|
499
|
-
surrounding brain state (if the states on the left and right sides
|
|
500
|
-
are the same). In the case where there are consecutive short bouts,
|
|
501
|
-
it either creates a transition at the midpoint or removes all short
|
|
502
|
-
bouts, depending on whether the number is even or odd. For example:
|
|
503
|
-
...AAABABAAA... -> ...AAAAAAAAA...
|
|
504
|
-
...AAABABABBB... -> ...AAAAABBBBB...
|
|
505
|
-
|
|
506
|
-
:param labels: brain state labels (digits in the 0-9 range)
|
|
507
|
-
:param epoch_length: epoch length, in seconds
|
|
508
|
-
:param min_bout_length: minimum bout length, in seconds
|
|
509
|
-
:return: updated brain state labels
|
|
510
|
-
"""
|
|
511
|
-
# if recording is very short, don't change anything
|
|
512
|
-
if labels.size < 3:
|
|
513
|
-
return labels
|
|
514
|
-
|
|
515
|
-
if epoch_length == min_bout_length:
|
|
516
|
-
return labels
|
|
517
|
-
|
|
518
|
-
# get minimum number of epochs in a bout
|
|
519
|
-
min_epochs = int(np.ceil(min_bout_length / epoch_length))
|
|
520
|
-
# get set of states in the labels
|
|
521
|
-
brain_states = set(labels.tolist())
|
|
522
|
-
|
|
523
|
-
while True: # so true
|
|
524
|
-
# convert labels to a string for regex search
|
|
525
|
-
# There is probably a regex that can find all patterns like ab+a
|
|
526
|
-
# without consuming each "a" but I haven't found it :(
|
|
527
|
-
label_string = "".join(labels.astype(str))
|
|
528
|
-
|
|
529
|
-
bouts = list()
|
|
530
|
-
|
|
531
|
-
for state in brain_states:
|
|
532
|
-
for other_state in brain_states:
|
|
533
|
-
if state == other_state:
|
|
534
|
-
continue
|
|
535
|
-
# get start and end indices of each bout
|
|
536
|
-
expression = (
|
|
537
|
-
f"(?<={other_state}){state}{{1,{min_epochs - 1}}}(?={other_state})"
|
|
538
|
-
)
|
|
539
|
-
matches = re.finditer(expression, label_string)
|
|
540
|
-
spans = [match.span() for match in matches]
|
|
541
|
-
|
|
542
|
-
# if some bouts were found
|
|
543
|
-
for span in spans:
|
|
544
|
-
bouts.append(
|
|
545
|
-
Bout(
|
|
546
|
-
length=span[1] - span[0],
|
|
547
|
-
start_index=span[0],
|
|
548
|
-
end_index=span[1],
|
|
549
|
-
surrounding_state=other_state,
|
|
550
|
-
)
|
|
551
|
-
)
|
|
552
|
-
|
|
553
|
-
if len(bouts) == 0:
|
|
554
|
-
break
|
|
555
|
-
|
|
556
|
-
# only keep the shortest bouts
|
|
557
|
-
min_length_in_list = np.min([bout.length for bout in bouts])
|
|
558
|
-
bouts = [i for i in bouts if i.length == min_length_in_list]
|
|
559
|
-
# sort by start index
|
|
560
|
-
sorted_bouts = sorted(bouts, key=attrgetter("start_index"))
|
|
561
|
-
|
|
562
|
-
while len(sorted_bouts) > 0:
|
|
563
|
-
# get row index of latest adjacent bout (of same length)
|
|
564
|
-
last_adjacent_bout_index = find_last_adjacent_bout(sorted_bouts, 0)
|
|
565
|
-
# if there's an even number of adjacent bouts
|
|
566
|
-
if (last_adjacent_bout_index + 1) % 2 == 0:
|
|
567
|
-
midpoint = sorted_bouts[
|
|
568
|
-
round((last_adjacent_bout_index + 1) / 2)
|
|
569
|
-
].start_index
|
|
570
|
-
labels[sorted_bouts[0].start_index : midpoint] = sorted_bouts[
|
|
571
|
-
0
|
|
572
|
-
].surrounding_state
|
|
573
|
-
labels[midpoint : sorted_bouts[last_adjacent_bout_index].end_index] = (
|
|
574
|
-
sorted_bouts[last_adjacent_bout_index].surrounding_state
|
|
575
|
-
)
|
|
576
|
-
else:
|
|
577
|
-
labels[
|
|
578
|
-
sorted_bouts[0].start_index : sorted_bouts[
|
|
579
|
-
last_adjacent_bout_index
|
|
580
|
-
].end_index
|
|
581
|
-
] = sorted_bouts[0].surrounding_state
|
|
582
|
-
|
|
583
|
-
# delete the bouts we just fixed
|
|
584
|
-
if last_adjacent_bout_index == len(sorted_bouts) - 1:
|
|
585
|
-
sorted_bouts = []
|
|
586
|
-
else:
|
|
587
|
-
sorted_bouts = sorted_bouts[(last_adjacent_bout_index + 1) :]
|
|
588
|
-
|
|
589
|
-
return labels
|