accusleepy 0.4.5__tar.gz → 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {accusleepy-0.4.5 → accusleepy-0.5.0}/PKG-INFO +2 -1
- {accusleepy-0.4.5 → accusleepy-0.5.0}/README.md +1 -0
- accusleepy-0.5.0/accusleepy/bouts.py +142 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/classification.py +2 -2
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/constants.py +2 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/fileio.py +0 -53
- accusleepy-0.5.0/accusleepy/gui/images/primary_window.png +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/main.py +84 -64
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/manual_scoring.py +76 -81
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/mplwidget.py +15 -10
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/primary_window.py +1 -0
- accusleepy-0.5.0/accusleepy/models.py +98 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/multitaper.py +9 -7
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/signal_processing.py +5 -143
- {accusleepy-0.4.5 → accusleepy-0.5.0}/pyproject.toml +1 -1
- accusleepy-0.4.5/accusleepy/gui/images/primary_window.png +0 -0
- accusleepy-0.4.5/accusleepy/models.py +0 -48
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/__init__.py +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/__main__.py +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/brain_state_set.py +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/config.json +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/__init__.py +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/icons/brightness_down.png +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/icons/brightness_up.png +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/icons/double_down_arrow.png +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/icons/double_up_arrow.png +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/icons/down_arrow.png +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/icons/home.png +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/icons/question.png +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/icons/save.png +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/icons/up_arrow.png +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/icons/zoom_in.png +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/icons/zoom_out.png +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/images/viewer_window.png +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/images/viewer_window_annotated.png +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/primary_window.ui +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/resources.qrc +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/resources_rc.py +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/text/config_guide.txt +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/text/main_guide.md +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/text/manual_scoring_guide.md +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/viewer_window.py +0 -0
- {accusleepy-0.4.5 → accusleepy-0.5.0}/accusleepy/gui/viewer_window.ui +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: accusleepy
|
|
3
|
-
Version: 0.4.5
|
|
3
|
+
Version: 0.5.0
|
|
4
4
|
Summary: Python implementation of AccuSleep
|
|
5
5
|
License: GPL-3.0-only
|
|
6
6
|
Author: Zeke Barger
|
|
@@ -75,6 +75,7 @@ to the [config file](accusleepy/config.json).
|
|
|
75
75
|
|
|
76
76
|
## Changelog
|
|
77
77
|
|
|
78
|
+
- 0.5.0: Performance improvements
|
|
78
79
|
- 0.4.5: Added support for python 3.13, **removed support for python 3.10.**
|
|
79
80
|
- 0.4.4: Performance improvements
|
|
80
81
|
- 0.4.3: Improved unit tests and user manuals
|
|
@@ -47,6 +47,7 @@ to the [config file](accusleepy/config.json).
|
|
|
47
47
|
|
|
48
48
|
## Changelog
|
|
49
49
|
|
|
50
|
+
- 0.5.0: Performance improvements
|
|
50
51
|
- 0.4.5: Added support for python 3.13, **removed support for python 3.10.**
|
|
51
52
|
- 0.4.4: Performance improvements
|
|
52
53
|
- 0.4.3: Improved unit tests and user manuals
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from operator import attrgetter
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass()
class Bout:
    """A run of consecutive epochs that share a single brain state.

    Instances describe a short bout found by the minimum-bout-length
    post-processing step: a stretch of one state flanked on both the left
    and the right by the same other state.
    """

    # number of epochs contained in the bout
    length: int
    # index of the first epoch belonging to the bout
    start_index: int
    # index one past the last epoch of the bout (exclusive end, like a slice)
    end_index: int
    # brain state found immediately before and immediately after the bout
    surrounding_state: int
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def find_last_adjacent_bout(sorted_bouts: list[Bout], bout_index: int) -> int:
    """Find index of last consecutive same-length bout

    When running the post-processing step that enforces a minimum duration
    for brain state bouts, there is a special case when bouts below the
    duration threshold occur consecutively. Starting at ``bout_index``, this
    function walks forward through ``sorted_bouts`` for as long as each bout
    ends exactly where the next one starts, and returns the index of the bout
    at the end of that chain. If, for example, the first three bouts in the
    list are consecutive and ``bout_index`` is 0, the function returns 2.

    The walk is iterative rather than recursive, so arbitrarily long chains
    of adjacent bouts cannot exceed Python's recursion limit.

    :param sorted_bouts: list of brain state bouts, sorted by start time
    :param bout_index: index of the bout in question
    :return: index of the last consecutive same-length bout
    """
    last_index = len(sorted_bouts) - 1
    # advance while the current bout is immediately followed by the next one
    # (bout end indices are exclusive, so adjacency means end == next start)
    while (
        bout_index < last_index
        and sorted_bouts[bout_index].end_index
        == sorted_bouts[bout_index + 1].start_index
    ):
        bout_index += 1
    return bout_index
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def enforce_min_bout_length(
|
|
45
|
+
labels: np.array, epoch_length: int | float, min_bout_length: int | float
|
|
46
|
+
) -> np.array:
|
|
47
|
+
"""Ensure brain state bouts meet the min length requirement
|
|
48
|
+
|
|
49
|
+
As a post-processing step for sleep scoring, we can require that any
|
|
50
|
+
bout (continuous period) of a brain state have a minimum duration.
|
|
51
|
+
This function sets any bout shorter than the minimum duration to the
|
|
52
|
+
surrounding brain state (if the states on the left and right sides
|
|
53
|
+
are the same). In the case where there are consecutive short bouts,
|
|
54
|
+
it either creates a transition at the midpoint or removes all short
|
|
55
|
+
bouts, depending on whether the number is even or odd. For example:
|
|
56
|
+
...AAABABAAA... -> ...AAAAAAAAA...
|
|
57
|
+
...AAABABABBB... -> ...AAAAABBBBB...
|
|
58
|
+
|
|
59
|
+
:param labels: brain state labels (digits in the 0-9 range)
|
|
60
|
+
:param epoch_length: epoch length, in seconds
|
|
61
|
+
:param min_bout_length: minimum bout length, in seconds
|
|
62
|
+
:return: updated brain state labels
|
|
63
|
+
"""
|
|
64
|
+
# if recording is very short, don't change anything
|
|
65
|
+
if labels.size < 3:
|
|
66
|
+
return labels
|
|
67
|
+
|
|
68
|
+
if epoch_length == min_bout_length:
|
|
69
|
+
return labels
|
|
70
|
+
|
|
71
|
+
# get minimum number of epochs in a bout
|
|
72
|
+
min_epochs = int(np.ceil(min_bout_length / epoch_length))
|
|
73
|
+
# get set of states in the labels
|
|
74
|
+
brain_states = set(labels.tolist())
|
|
75
|
+
|
|
76
|
+
while True: # so true
|
|
77
|
+
# convert labels to a string for regex search
|
|
78
|
+
# There is probably a regex that can find all patterns like ab+a
|
|
79
|
+
# without consuming each "a" but I haven't found it :(
|
|
80
|
+
label_string = "".join(labels.astype(str))
|
|
81
|
+
|
|
82
|
+
bouts = list()
|
|
83
|
+
|
|
84
|
+
for state in brain_states:
|
|
85
|
+
for other_state in brain_states:
|
|
86
|
+
if state == other_state:
|
|
87
|
+
continue
|
|
88
|
+
# get start and end indices of each bout
|
|
89
|
+
expression = (
|
|
90
|
+
f"(?<={other_state}){state}{{1,{min_epochs - 1}}}(?={other_state})"
|
|
91
|
+
)
|
|
92
|
+
matches = re.finditer(expression, label_string)
|
|
93
|
+
spans = [match.span() for match in matches]
|
|
94
|
+
|
|
95
|
+
# if some bouts were found
|
|
96
|
+
for span in spans:
|
|
97
|
+
bouts.append(
|
|
98
|
+
Bout(
|
|
99
|
+
length=span[1] - span[0],
|
|
100
|
+
start_index=span[0],
|
|
101
|
+
end_index=span[1],
|
|
102
|
+
surrounding_state=other_state,
|
|
103
|
+
)
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
if len(bouts) == 0:
|
|
107
|
+
break
|
|
108
|
+
|
|
109
|
+
# only keep the shortest bouts
|
|
110
|
+
min_length_in_list = np.min([bout.length for bout in bouts])
|
|
111
|
+
bouts = [i for i in bouts if i.length == min_length_in_list]
|
|
112
|
+
# sort by start index
|
|
113
|
+
sorted_bouts = sorted(bouts, key=attrgetter("start_index"))
|
|
114
|
+
|
|
115
|
+
while len(sorted_bouts) > 0:
|
|
116
|
+
# get row index of latest adjacent bout (of same length)
|
|
117
|
+
last_adjacent_bout_index = find_last_adjacent_bout(sorted_bouts, 0)
|
|
118
|
+
# if there's an even number of adjacent bouts
|
|
119
|
+
if (last_adjacent_bout_index + 1) % 2 == 0:
|
|
120
|
+
midpoint = sorted_bouts[
|
|
121
|
+
round((last_adjacent_bout_index + 1) / 2)
|
|
122
|
+
].start_index
|
|
123
|
+
labels[sorted_bouts[0].start_index : midpoint] = sorted_bouts[
|
|
124
|
+
0
|
|
125
|
+
].surrounding_state
|
|
126
|
+
labels[midpoint : sorted_bouts[last_adjacent_bout_index].end_index] = (
|
|
127
|
+
sorted_bouts[last_adjacent_bout_index].surrounding_state
|
|
128
|
+
)
|
|
129
|
+
else:
|
|
130
|
+
labels[
|
|
131
|
+
sorted_bouts[0].start_index : sorted_bouts[
|
|
132
|
+
last_adjacent_bout_index
|
|
133
|
+
].end_index
|
|
134
|
+
] = sorted_bouts[0].surrounding_state
|
|
135
|
+
|
|
136
|
+
# delete the bouts we just fixed
|
|
137
|
+
if last_adjacent_bout_index == len(sorted_bouts) - 1:
|
|
138
|
+
sorted_bouts = []
|
|
139
|
+
else:
|
|
140
|
+
sorted_bouts = sorted_bouts[(last_adjacent_bout_index + 1) :]
|
|
141
|
+
|
|
142
|
+
return labels
|
|
@@ -61,13 +61,13 @@ def get_device():
|
|
|
61
61
|
)
|
|
62
62
|
|
|
63
63
|
|
|
64
|
-
def
|
|
64
|
+
def train_ssann(
|
|
65
65
|
annotations_file: str,
|
|
66
66
|
img_dir: str,
|
|
67
67
|
mixture_weights: np.array,
|
|
68
68
|
n_classes: int,
|
|
69
69
|
) -> SSANN:
|
|
70
|
-
"""Train a classification model for sleep scoring
|
|
70
|
+
"""Train a SSANN classification model for sleep scoring
|
|
71
71
|
|
|
72
72
|
:param annotations_file: file with information on each training image
|
|
73
73
|
:param img_dir: training image location
|
|
@@ -37,3 +37,5 @@ RECORDING_LIST_NAME = "recording_list"
|
|
|
37
37
|
RECORDING_LIST_FILE_TYPE = ".json"
|
|
38
38
|
# key for default epoch length in config
|
|
39
39
|
DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
|
|
40
|
+
# filename used to store info about training image datasets
|
|
41
|
+
ANNOTATIONS_FILENAME = "annotations.csv"
|
|
@@ -4,7 +4,6 @@ from dataclasses import dataclass
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pandas as pd
|
|
7
|
-
import torch
|
|
8
7
|
from PySide6.QtWidgets import QListWidgetItem
|
|
9
8
|
|
|
10
9
|
from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
|
|
@@ -19,7 +18,6 @@ from accusleepy.constants import (
|
|
|
19
18
|
RECORDING_LIST_NAME,
|
|
20
19
|
UNDEFINED_LABEL,
|
|
21
20
|
)
|
|
22
|
-
from accusleepy.models import SSANN
|
|
23
21
|
|
|
24
22
|
|
|
25
23
|
@dataclass
|
|
@@ -46,57 +44,6 @@ def load_calibration_file(filename: str) -> (np.array, np.array):
|
|
|
46
44
|
return mixture_means, mixture_sds
|
|
47
45
|
|
|
48
46
|
|
|
49
|
-
def save_model(
|
|
50
|
-
model: SSANN,
|
|
51
|
-
filename: str,
|
|
52
|
-
epoch_length: int | float,
|
|
53
|
-
epochs_per_img: int,
|
|
54
|
-
model_type: str,
|
|
55
|
-
brain_state_set: BrainStateSet,
|
|
56
|
-
) -> None:
|
|
57
|
-
"""Save classification model and its metadata
|
|
58
|
-
|
|
59
|
-
:param model: classification model
|
|
60
|
-
:param epoch_length: epoch length used when training the model
|
|
61
|
-
:param epochs_per_img: number of epochs in each model input
|
|
62
|
-
:param model_type: default or real-time
|
|
63
|
-
:param brain_state_set: set of brain state options
|
|
64
|
-
:param filename: filename
|
|
65
|
-
"""
|
|
66
|
-
state_dict = model.state_dict()
|
|
67
|
-
state_dict.update({"epoch_length": epoch_length})
|
|
68
|
-
state_dict.update({"epochs_per_img": epochs_per_img})
|
|
69
|
-
state_dict.update({"model_type": model_type})
|
|
70
|
-
state_dict.update(
|
|
71
|
-
{BRAIN_STATES_KEY: brain_state_set.to_output_dict()[BRAIN_STATES_KEY]}
|
|
72
|
-
)
|
|
73
|
-
|
|
74
|
-
torch.save(state_dict, filename)
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def load_model(filename: str) -> tuple[SSANN, int | float, int, str, dict]:
|
|
78
|
-
"""Load classification model and its metadata
|
|
79
|
-
|
|
80
|
-
:param filename: filename
|
|
81
|
-
:return: model, epoch length used when training the model,
|
|
82
|
-
number of epochs in each model input, model type
|
|
83
|
-
(default or real-time), set of brain state options
|
|
84
|
-
used when training the model
|
|
85
|
-
"""
|
|
86
|
-
state_dict = torch.load(
|
|
87
|
-
filename, weights_only=True, map_location=torch.device("cpu")
|
|
88
|
-
)
|
|
89
|
-
epoch_length = state_dict.pop("epoch_length")
|
|
90
|
-
epochs_per_img = state_dict.pop("epochs_per_img")
|
|
91
|
-
model_type = state_dict.pop("model_type")
|
|
92
|
-
brain_states = state_dict.pop(BRAIN_STATES_KEY)
|
|
93
|
-
n_classes = len([b for b in brain_states if b["is_scored"]])
|
|
94
|
-
|
|
95
|
-
model = SSANN(n_classes=n_classes)
|
|
96
|
-
model.load_state_dict(state_dict)
|
|
97
|
-
return model, epoch_length, epochs_per_img, model_type, brain_states
|
|
98
|
-
|
|
99
|
-
|
|
100
47
|
def load_csv_or_parquet(filename: str) -> pd.DataFrame:
|
|
101
48
|
"""Load a csv or parquet file as a dataframe
|
|
102
49
|
|
|
Binary file
|
|
@@ -5,20 +5,37 @@ import datetime
|
|
|
5
5
|
import os
|
|
6
6
|
import shutil
|
|
7
7
|
import sys
|
|
8
|
-
import toml
|
|
9
8
|
from dataclasses import dataclass
|
|
10
9
|
from functools import partial
|
|
11
10
|
|
|
12
11
|
import numpy as np
|
|
13
|
-
|
|
12
|
+
import toml
|
|
13
|
+
from PySide6.QtCore import (
|
|
14
|
+
QEvent,
|
|
15
|
+
QKeyCombination,
|
|
16
|
+
QObject,
|
|
17
|
+
QRect,
|
|
18
|
+
Qt,
|
|
19
|
+
QUrl,
|
|
20
|
+
)
|
|
21
|
+
from PySide6.QtGui import QKeySequence, QShortcut
|
|
22
|
+
from PySide6.QtWidgets import (
|
|
23
|
+
QApplication,
|
|
24
|
+
QCheckBox,
|
|
25
|
+
QDoubleSpinBox,
|
|
26
|
+
QFileDialog,
|
|
27
|
+
QLabel,
|
|
28
|
+
QListWidgetItem,
|
|
29
|
+
QMainWindow,
|
|
30
|
+
QTextBrowser,
|
|
31
|
+
QVBoxLayout,
|
|
32
|
+
QWidget,
|
|
33
|
+
)
|
|
14
34
|
|
|
35
|
+
from accusleepy.bouts import enforce_min_bout_length
|
|
15
36
|
from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
|
|
16
|
-
from accusleepy.classification import (
|
|
17
|
-
create_calibration_file,
|
|
18
|
-
score_recording,
|
|
19
|
-
train_model,
|
|
20
|
-
)
|
|
21
37
|
from accusleepy.constants import (
|
|
38
|
+
ANNOTATIONS_FILENAME,
|
|
22
39
|
CALIBRATION_FILE_TYPE,
|
|
23
40
|
DEFAULT_MODEL_TYPE,
|
|
24
41
|
LABEL_FILE_TYPE,
|
|
@@ -33,23 +50,21 @@ from accusleepy.fileio import (
|
|
|
33
50
|
load_calibration_file,
|
|
34
51
|
load_config,
|
|
35
52
|
load_labels,
|
|
36
|
-
load_model,
|
|
37
53
|
load_recording,
|
|
38
54
|
load_recording_list,
|
|
39
55
|
save_config,
|
|
40
56
|
save_labels,
|
|
41
|
-
save_model,
|
|
42
57
|
save_recording_list,
|
|
43
58
|
)
|
|
44
59
|
from accusleepy.gui.manual_scoring import ManualScoringWindow
|
|
45
60
|
from accusleepy.gui.primary_window import Ui_PrimaryWindow
|
|
46
61
|
from accusleepy.signal_processing import (
|
|
47
|
-
ANNOTATIONS_FILENAME,
|
|
48
62
|
create_training_images,
|
|
49
|
-
enforce_min_bout_length,
|
|
50
63
|
resample_and_standardize,
|
|
51
64
|
)
|
|
52
65
|
|
|
66
|
+
# note: functions using torch or scipy are lazily imported
|
|
67
|
+
|
|
53
68
|
# max number of messages to display
|
|
54
69
|
MESSAGE_BOX_MAX_DEPTH = 200
|
|
55
70
|
LABEL_LENGTH_ERROR = "label file length does not match recording length"
|
|
@@ -63,13 +78,13 @@ class StateSettings:
|
|
|
63
78
|
"""Widgets for config settings for a brain state"""
|
|
64
79
|
|
|
65
80
|
digit: int
|
|
66
|
-
enabled_widget:
|
|
67
|
-
name_widget:
|
|
68
|
-
is_scored_widget:
|
|
69
|
-
frequency_widget:
|
|
81
|
+
enabled_widget: QCheckBox
|
|
82
|
+
name_widget: QLabel
|
|
83
|
+
is_scored_widget: QCheckBox
|
|
84
|
+
frequency_widget: QDoubleSpinBox
|
|
70
85
|
|
|
71
86
|
|
|
72
|
-
class AccuSleepWindow(
|
|
87
|
+
class AccuSleepWindow(QMainWindow):
|
|
73
88
|
"""AccuSleePy primary window"""
|
|
74
89
|
|
|
75
90
|
def __init__(self):
|
|
@@ -103,9 +118,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
103
118
|
|
|
104
119
|
# set up the list of recordings
|
|
105
120
|
first_recording = Recording(
|
|
106
|
-
widget=
|
|
107
|
-
"Recording 1", self.ui.recording_list_widget
|
|
108
|
-
),
|
|
121
|
+
widget=QListWidgetItem("Recording 1", self.ui.recording_list_widget),
|
|
109
122
|
)
|
|
110
123
|
self.ui.recording_list_widget.addItem(first_recording.widget)
|
|
111
124
|
self.ui.recording_list_widget.setCurrentRow(0)
|
|
@@ -132,10 +145,8 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
132
145
|
self.ui.version_label.setText(f"v{version}")
|
|
133
146
|
|
|
134
147
|
# user input: keyboard shortcuts
|
|
135
|
-
keypress_quit =
|
|
136
|
-
|
|
137
|
-
QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_W)
|
|
138
|
-
),
|
|
148
|
+
keypress_quit = QShortcut(
|
|
149
|
+
QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_W)),
|
|
139
150
|
self,
|
|
140
151
|
)
|
|
141
152
|
keypress_quit.activated.connect(self.close)
|
|
@@ -187,7 +198,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
187
198
|
def export_recording_list(self) -> None:
|
|
188
199
|
"""Save current list of recordings to file"""
|
|
189
200
|
# get the name for the recording list file
|
|
190
|
-
filename, _ =
|
|
201
|
+
filename, _ = QFileDialog.getSaveFileName(
|
|
191
202
|
self,
|
|
192
203
|
caption="Save list of recordings as",
|
|
193
204
|
filter="*" + RECORDING_LIST_FILE_TYPE,
|
|
@@ -200,10 +211,10 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
200
211
|
|
|
201
212
|
def import_recording_list(self):
|
|
202
213
|
"""Load list of recordings from file, overwriting current list"""
|
|
203
|
-
file_dialog =
|
|
214
|
+
file_dialog = QFileDialog(self)
|
|
204
215
|
file_dialog.setWindowTitle("Select list of recordings")
|
|
205
|
-
file_dialog.setFileMode(
|
|
206
|
-
file_dialog.setViewMode(
|
|
216
|
+
file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
|
|
217
|
+
file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
|
|
207
218
|
file_dialog.setNameFilter("*" + RECORDING_LIST_FILE_TYPE)
|
|
208
219
|
|
|
209
220
|
if file_dialog.exec():
|
|
@@ -219,7 +230,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
219
230
|
self.recordings = load_recording_list(filename)
|
|
220
231
|
|
|
221
232
|
for recording in self.recordings:
|
|
222
|
-
recording.widget =
|
|
233
|
+
recording.widget = QListWidgetItem(
|
|
223
234
|
f"Recording {recording.name}", self.ui.recording_list_widget
|
|
224
235
|
)
|
|
225
236
|
self.ui.recording_list_widget.addItem(self.recordings[-1].widget)
|
|
@@ -228,7 +239,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
228
239
|
self.ui.recording_list_widget.setCurrentRow(0)
|
|
229
240
|
self.show_message(f"Loaded list of recordings from {filename}")
|
|
230
241
|
|
|
231
|
-
def eventFilter(self, obj:
|
|
242
|
+
def eventFilter(self, obj: QObject, event: QEvent) -> bool:
|
|
232
243
|
"""Filter mouse events to detect when user drags/drops a file
|
|
233
244
|
|
|
234
245
|
:param obj: UI object receiving the event
|
|
@@ -243,7 +254,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
243
254
|
self.ui.model_label,
|
|
244
255
|
]:
|
|
245
256
|
event.accept()
|
|
246
|
-
if event.type() ==
|
|
257
|
+
if event.type() == QEvent.Drop:
|
|
247
258
|
urls = event.mimeData().urls()
|
|
248
259
|
if len(urls) == 1:
|
|
249
260
|
filename = os.path.normpath(urls[0].toLocalFile())
|
|
@@ -299,7 +310,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
299
310
|
return
|
|
300
311
|
|
|
301
312
|
# get filename for the new model
|
|
302
|
-
model_filename, _ =
|
|
313
|
+
model_filename, _ = QFileDialog.getSaveFileName(
|
|
303
314
|
self,
|
|
304
315
|
caption="Save classification model file as",
|
|
305
316
|
filter="*" + MODEL_FILE_TYPE,
|
|
@@ -322,11 +333,10 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
322
333
|
os.makedirs(temp_image_dir, exist_ok=True)
|
|
323
334
|
|
|
324
335
|
# create training images
|
|
325
|
-
self.show_message(
|
|
326
|
-
|
|
327
|
-
)
|
|
336
|
+
self.show_message("Training, please wait. See console for progress updates.")
|
|
337
|
+
self.show_message((f"Creating training images in {temp_image_dir}"))
|
|
328
338
|
self.ui.message_area.repaint()
|
|
329
|
-
|
|
339
|
+
QApplication.processEvents()
|
|
330
340
|
print("Creating training images")
|
|
331
341
|
failed_recordings = create_training_images(
|
|
332
342
|
recordings=self.recordings,
|
|
@@ -349,11 +359,14 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
349
359
|
)
|
|
350
360
|
|
|
351
361
|
# train model
|
|
352
|
-
self.show_message("Training model
|
|
362
|
+
self.show_message("Training model")
|
|
353
363
|
self.ui.message_area.repaint()
|
|
354
|
-
|
|
364
|
+
QApplication.processEvents()
|
|
355
365
|
print("Training model")
|
|
356
|
-
|
|
366
|
+
from accusleepy.classification import train_ssann
|
|
367
|
+
from accusleepy.models import save_model
|
|
368
|
+
|
|
369
|
+
model = train_ssann(
|
|
357
370
|
annotations_file=os.path.join(temp_image_dir, ANNOTATIONS_FILENAME),
|
|
358
371
|
img_dir=temp_image_dir,
|
|
359
372
|
mixture_weights=self.brain_state_set.mixture_weights,
|
|
@@ -374,11 +387,12 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
374
387
|
if self.delete_training_images:
|
|
375
388
|
shutil.rmtree(temp_image_dir)
|
|
376
389
|
|
|
377
|
-
self.show_message(f"Training complete
|
|
390
|
+
self.show_message(f"Training complete. Saved model to {model_filename}")
|
|
391
|
+
print("Training complete.")
|
|
378
392
|
|
|
379
393
|
def set_training_folder(self) -> None:
|
|
380
394
|
"""Select location in which to create a folder for training images"""
|
|
381
|
-
training_folder_parent =
|
|
395
|
+
training_folder_parent = QFileDialog.getExistingDirectory(
|
|
382
396
|
self, "Select directory for training images"
|
|
383
397
|
)
|
|
384
398
|
if training_folder_parent:
|
|
@@ -421,7 +435,9 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
421
435
|
|
|
422
436
|
self.ui.score_all_status.setText("running...")
|
|
423
437
|
self.ui.score_all_status.repaint()
|
|
424
|
-
|
|
438
|
+
QApplication.processEvents()
|
|
439
|
+
|
|
440
|
+
from accusleepy.classification import score_recording
|
|
425
441
|
|
|
426
442
|
# check some inputs for each recording
|
|
427
443
|
for recording_index in range(len(self.recordings)):
|
|
@@ -570,11 +586,13 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
570
586
|
|
|
571
587
|
:param filename: model filename, if it's known
|
|
572
588
|
"""
|
|
589
|
+
from accusleepy.models import load_model
|
|
590
|
+
|
|
573
591
|
if filename is None:
|
|
574
|
-
file_dialog =
|
|
592
|
+
file_dialog = QFileDialog(self)
|
|
575
593
|
file_dialog.setWindowTitle("Select classification model")
|
|
576
|
-
file_dialog.setFileMode(
|
|
577
|
-
file_dialog.setViewMode(
|
|
594
|
+
file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
|
|
595
|
+
file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
|
|
578
596
|
file_dialog.setNameFilter("*" + MODEL_FILE_TYPE)
|
|
579
597
|
|
|
580
598
|
if file_dialog.exec():
|
|
@@ -634,7 +652,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
634
652
|
self.ui.model_label.setText(filename)
|
|
635
653
|
|
|
636
654
|
def load_single_recording(
|
|
637
|
-
self, status_widget:
|
|
655
|
+
self, status_widget: QLabel
|
|
638
656
|
) -> (np.array, np.array, int | float, bool):
|
|
639
657
|
"""Load and preprocess one recording
|
|
640
658
|
|
|
@@ -721,7 +739,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
721
739
|
return
|
|
722
740
|
|
|
723
741
|
# get the name for the calibration file
|
|
724
|
-
filename, _ =
|
|
742
|
+
filename, _ = QFileDialog.getSaveFileName(
|
|
725
743
|
self,
|
|
726
744
|
caption="Save calibration file as",
|
|
727
745
|
filter="*" + CALIBRATION_FILE_TYPE,
|
|
@@ -730,6 +748,8 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
730
748
|
return
|
|
731
749
|
filename = os.path.normpath(filename)
|
|
732
750
|
|
|
751
|
+
from accusleepy.classification import create_calibration_file
|
|
752
|
+
|
|
733
753
|
create_calibration_file(
|
|
734
754
|
filename=filename,
|
|
735
755
|
eeg=eeg,
|
|
@@ -799,7 +819,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
799
819
|
# immediately display a status message
|
|
800
820
|
self.ui.manual_scoring_status.setText("loading...")
|
|
801
821
|
self.ui.manual_scoring_status.repaint()
|
|
802
|
-
|
|
822
|
+
QApplication.processEvents()
|
|
803
823
|
|
|
804
824
|
# load the recording
|
|
805
825
|
eeg, emg, sampling_rate, success = self.load_single_recording(
|
|
@@ -889,7 +909,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
889
909
|
|
|
890
910
|
def create_label_file(self) -> None:
|
|
891
911
|
"""Set the filename for a new label file"""
|
|
892
|
-
filename, _ =
|
|
912
|
+
filename, _ = QFileDialog.getSaveFileName(
|
|
893
913
|
self,
|
|
894
914
|
caption="Set filename for label file (nothing will be overwritten yet)",
|
|
895
915
|
filter="*" + LABEL_FILE_TYPE,
|
|
@@ -901,10 +921,10 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
901
921
|
|
|
902
922
|
def select_label_file(self) -> None:
|
|
903
923
|
"""User can select an existing label file"""
|
|
904
|
-
file_dialog =
|
|
924
|
+
file_dialog = QFileDialog(self)
|
|
905
925
|
file_dialog.setWindowTitle("Select label file")
|
|
906
|
-
file_dialog.setFileMode(
|
|
907
|
-
file_dialog.setViewMode(
|
|
926
|
+
file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
|
|
927
|
+
file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
|
|
908
928
|
file_dialog.setNameFilter("*" + LABEL_FILE_TYPE)
|
|
909
929
|
|
|
910
930
|
if file_dialog.exec():
|
|
@@ -916,10 +936,10 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
916
936
|
|
|
917
937
|
def select_calibration_file(self) -> None:
|
|
918
938
|
"""User can select a calibration file"""
|
|
919
|
-
file_dialog =
|
|
939
|
+
file_dialog = QFileDialog(self)
|
|
920
940
|
file_dialog.setWindowTitle("Select calibration file")
|
|
921
|
-
file_dialog.setFileMode(
|
|
922
|
-
file_dialog.setViewMode(
|
|
941
|
+
file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
|
|
942
|
+
file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
|
|
923
943
|
file_dialog.setNameFilter("*" + CALIBRATION_FILE_TYPE)
|
|
924
944
|
|
|
925
945
|
if file_dialog.exec():
|
|
@@ -931,10 +951,10 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
931
951
|
|
|
932
952
|
def select_recording_file(self) -> None:
|
|
933
953
|
"""User can select a recording file"""
|
|
934
|
-
file_dialog =
|
|
954
|
+
file_dialog = QFileDialog(self)
|
|
935
955
|
file_dialog.setWindowTitle("Select recording file")
|
|
936
|
-
file_dialog.setFileMode(
|
|
937
|
-
file_dialog.setViewMode(
|
|
956
|
+
file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
|
|
957
|
+
file_dialog.setViewMode(QFileDialog.ViewMode.Detail)
|
|
938
958
|
file_dialog.setNameFilter(f"(*{' *'.join(RECORDING_FILE_TYPES)})")
|
|
939
959
|
|
|
940
960
|
if file_dialog.exec():
|
|
@@ -1009,7 +1029,7 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
1009
1029
|
Recording(
|
|
1010
1030
|
name=new_name,
|
|
1011
1031
|
sampling_rate=self.recordings[self.recording_index].sampling_rate,
|
|
1012
|
-
widget=
|
|
1032
|
+
widget=QListWidgetItem(
|
|
1013
1033
|
f"Recording {new_name}", self.ui.recording_list_widget
|
|
1014
1034
|
),
|
|
1015
1035
|
)
|
|
@@ -1033,16 +1053,16 @@ class AccuSleepWindow(QtWidgets.QMainWindow):
|
|
|
1033
1053
|
|
|
1034
1054
|
def show_user_manual(self) -> None:
|
|
1035
1055
|
"""Show a popup window with the user manual"""
|
|
1036
|
-
self.popup =
|
|
1037
|
-
self.popup_vlayout =
|
|
1038
|
-
self.guide_textbox =
|
|
1056
|
+
self.popup = QWidget()
|
|
1057
|
+
self.popup_vlayout = QVBoxLayout(self.popup)
|
|
1058
|
+
self.guide_textbox = QTextBrowser(self.popup)
|
|
1039
1059
|
self.popup_vlayout.addWidget(self.guide_textbox)
|
|
1040
1060
|
|
|
1041
|
-
url =
|
|
1061
|
+
url = QUrl.fromLocalFile(MAIN_GUIDE_FILE)
|
|
1042
1062
|
self.guide_textbox.setSource(url)
|
|
1043
1063
|
self.guide_textbox.setOpenLinks(False)
|
|
1044
1064
|
|
|
1045
|
-
self.popup.setGeometry(
|
|
1065
|
+
self.popup.setGeometry(QRect(100, 100, 600, 600))
|
|
1046
1066
|
self.popup.show()
|
|
1047
1067
|
|
|
1048
1068
|
def initialize_settings_tab(self):
|
|
@@ -1389,7 +1409,7 @@ def check_config_consistency(
|
|
|
1389
1409
|
|
|
1390
1410
|
|
|
1391
1411
|
def run_primary_window() -> None:
|
|
1392
|
-
app =
|
|
1412
|
+
app = QApplication(sys.argv)
|
|
1393
1413
|
AccuSleepWindow()
|
|
1394
1414
|
sys.exit(app.exec())
|
|
1395
1415
|
|