accusleepy 0.4.0__tar.gz → 0.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {accusleepy-0.4.0 → accusleepy-0.6.0}/PKG-INFO +49 -9
- accusleepy-0.6.0/README.md +77 -0
- accusleepy-0.6.0/accusleepy/bouts.py +142 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/classification.py +27 -9
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/config.json +2 -1
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/constants.py +7 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/fileio.py +35 -65
- accusleepy-0.6.0/accusleepy/gui/images/primary_window.png +0 -0
- accusleepy-0.6.0/accusleepy/gui/images/viewer_window.png +0 -0
- accusleepy-0.6.0/accusleepy/gui/images/viewer_window_annotated.png +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/main.py +199 -101
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/manual_scoring.py +112 -102
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/mplwidget.py +69 -39
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/primary_window.py +240 -158
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/primary_window.ui +313 -155
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/resources.qrc +1 -1
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/text/config_guide.txt +0 -2
- accusleepy-0.6.0/accusleepy/gui/text/main_guide.md +167 -0
- accusleepy-0.6.0/accusleepy/gui/text/manual_scoring_guide.md +23 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/viewer_window.py +19 -7
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/viewer_window.ui +34 -2
- accusleepy-0.6.0/accusleepy/models.py +108 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/multitaper.py +9 -7
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/signal_processing.py +28 -148
- accusleepy-0.6.0/accusleepy/temperature_scaling.py +157 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/pyproject.toml +3 -3
- accusleepy-0.4.0/README.md +0 -37
- accusleepy-0.4.0/accusleepy/gui/text/main_guide_text.py +0 -173
- accusleepy-0.4.0/accusleepy/gui/text/manual_scoring_guide.txt +0 -28
- accusleepy-0.4.0/accusleepy/models.py +0 -48
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/__init__.py +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/__main__.py +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/brain_state_set.py +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/__init__.py +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/brightness_down.png +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/brightness_up.png +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/double_down_arrow.png +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/double_up_arrow.png +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/down_arrow.png +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/home.png +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/question.png +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/save.png +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/up_arrow.png +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/zoom_in.png +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/zoom_out.png +0 -0
- {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/resources_rc.py +0 -0
|
@@ -1,16 +1,16 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: accusleepy
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.6.0
|
|
4
4
|
Summary: Python implementation of AccuSleep
|
|
5
5
|
License: GPL-3.0-only
|
|
6
6
|
Author: Zeke Barger
|
|
7
7
|
Author-email: zekebarger@gmail.com
|
|
8
|
-
Requires-Python: >=3.
|
|
8
|
+
Requires-Python: >=3.11,<3.14
|
|
9
9
|
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
|
|
10
10
|
Classifier: Programming Language :: Python :: 3
|
|
11
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
12
11
|
Classifier: Programming Language :: Python :: 3.11
|
|
13
12
|
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
14
14
|
Requires-Dist: fastparquet (>=2024.11.0,<2025.0.0)
|
|
15
15
|
Requires-Dist: joblib (>=1.4.2,<2.0.0)
|
|
16
16
|
Requires-Dist: matplotlib (>=3.10.1,<4.0.0)
|
|
@@ -18,7 +18,7 @@ Requires-Dist: numpy (>=2.2.4,<3.0.0)
|
|
|
18
18
|
Requires-Dist: pandas (>=2.2.3,<3.0.0)
|
|
19
19
|
Requires-Dist: pillow (>=11.1.0,<12.0.0)
|
|
20
20
|
Requires-Dist: pre-commit (>=4.2.0,<5.0.0)
|
|
21
|
-
Requires-Dist: pyside6 (>=6.
|
|
21
|
+
Requires-Dist: pyside6 (>=6.9.0,<7.0.0)
|
|
22
22
|
Requires-Dist: scipy (>=1.15.2,<2.0.0)
|
|
23
23
|
Requires-Dist: toml (>=0.10.2,<0.11.0)
|
|
24
24
|
Requires-Dist: torch (>=2.6.0,<3.0.0)
|
|
@@ -30,9 +30,17 @@ Description-Content-Type: text/markdown
|
|
|
30
30
|
|
|
31
31
|
## Description
|
|
32
32
|
|
|
33
|
-
AccuSleePy is
|
|
34
|
-
|
|
35
|
-
|
|
33
|
+
AccuSleePy is a set of graphical user interfaces for scoring rodent sleep
|
|
34
|
+
using EEG and EMG recordings.
|
|
35
|
+
It offers the following improvements over the MATLAB version (AccuSleep):
|
|
36
|
+
|
|
37
|
+
- Up to 10 brain states can be configured through the user interface
|
|
38
|
+
- Classification models can be trained through the user interface
|
|
39
|
+
- Model files contain useful metadata (brain state configuration,
|
|
40
|
+
epoch length, number of epochs)
|
|
41
|
+
- Models optimized for real-time scoring can be trained
|
|
42
|
+
- Lists of recordings can be imported and exported for repeatable batch processing
|
|
43
|
+
- Undo/redo functionality in the manual scoring interface
|
|
36
44
|
|
|
37
45
|
If you use AccuSleep in your research, please cite our
|
|
38
46
|
[publication](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0224642):
|
|
@@ -43,24 +51,56 @@ The data and models associated with AccuSleep are available at https://osf.io/py
|
|
|
43
51
|
|
|
44
52
|
Please contact zekebarger (at) gmail (dot) com with any questions or comments about the software.
|
|
45
53
|
|
|
54
|
+
|
|
46
55
|
## Installation
|
|
47
56
|
|
|
48
57
|
- (recommended) create a new virtual environment (using
|
|
49
58
|
[venv](https://docs.python.org/3/library/venv.html),
|
|
50
59
|
[conda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html),
|
|
51
|
-
etc.)
|
|
60
|
+
etc.) with python >=3.11,<3.14
|
|
52
61
|
- (optional) if you have a CUDA device and want to speed up model training, [install PyTorch](https://pytorch.org/)
|
|
53
62
|
- `pip install accusleepy`
|
|
54
63
|
- (optional) download a classification model from https://osf.io/py5eb/ under /python_format/models/
|
|
55
64
|
|
|
65
|
+
Note that upgrading or reinstalling the package will overwrite any changes
|
|
66
|
+
to the [config file](accusleepy/config.json).
|
|
67
|
+
|
|
56
68
|
## Usage
|
|
57
69
|
|
|
58
70
|
`python -m accusleepy` will open the primary interface.
|
|
59
71
|
|
|
72
|
+
[Guide to the primary interface](accusleepy/gui/text/main_guide.md)
|
|
73
|
+
|
|
74
|
+
[Guide to the manual scoring interface](accusleepy/gui/text/manual_scoring_guide.md)
|
|
75
|
+
|
|
76
|
+
## Changelog
|
|
77
|
+
|
|
78
|
+
- 0.6.0: Confidence scores can now be displayed and saved. Retraining your models is recommended
|
|
79
|
+
since the new calibration feature will make the confidence scores more accurate.
|
|
80
|
+
- 0.5.0: Performance improvements
|
|
81
|
+
- 0.4.5: Added support for python 3.13, **removed support for python 3.10.**
|
|
82
|
+
- 0.4.4: Performance improvements
|
|
83
|
+
- 0.4.3: Improved unit tests and user manuals
|
|
84
|
+
- 0.4.0: Improved visuals and user manuals
|
|
85
|
+
- 0.1.0-0.3.1: Early development versions
|
|
86
|
+
|
|
87
|
+
## Screenshots
|
|
88
|
+
|
|
89
|
+
Primary interface
|
|
90
|
+

|
|
91
|
+
|
|
92
|
+
Manual scoring interface
|
|
93
|
+

|
|
94
|
+
|
|
60
95
|
## Acknowledgements
|
|
61
96
|
|
|
62
97
|
We would like to thank [Franz Weber](https://www.med.upenn.edu/weberlab/) for creating an
|
|
63
|
-
early version of the manual labeling interface.
|
|
98
|
+
early version of the manual labeling interface. The code that
|
|
99
|
+
creates spectrograms comes from the
|
|
100
|
+
[Prerau lab](https://github.com/preraulab/multitaper_toolbox/blob/master/python/multitaper_spectrogram_python.py)
|
|
101
|
+
with only minor modifications.
|
|
64
102
|
Jim Bohnslav's [deepethogram](https://github.com/jbohnslav/deepethogram) served as an
|
|
65
103
|
incredibly useful reference when reimplementing this project in python.
|
|
104
|
+
The model calibration code added in version 0.6.0 comes from Geoff Pleiss'
|
|
105
|
+
[temperature scaling repo](https://github.com/gpleiss/temperature_scaling).
|
|
66
106
|
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
# AccuSleePy
|
|
2
|
+
|
|
3
|
+
## Description
|
|
4
|
+
|
|
5
|
+
AccuSleePy is a set of graphical user interfaces for scoring rodent sleep
|
|
6
|
+
using EEG and EMG recordings.
|
|
7
|
+
It offers the following improvements over the MATLAB version (AccuSleep):
|
|
8
|
+
|
|
9
|
+
- Up to 10 brain states can be configured through the user interface
|
|
10
|
+
- Classification models can be trained through the user interface
|
|
11
|
+
- Model files contain useful metadata (brain state configuration,
|
|
12
|
+
epoch length, number of epochs)
|
|
13
|
+
- Models optimized for real-time scoring can be trained
|
|
14
|
+
- Lists of recordings can be imported and exported for repeatable batch processing
|
|
15
|
+
- Undo/redo functionality in the manual scoring interface
|
|
16
|
+
|
|
17
|
+
If you use AccuSleep in your research, please cite our
|
|
18
|
+
[publication](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0224642):
|
|
19
|
+
|
|
20
|
+
Barger, Z., Frye, C. G., Liu, D., Dan, Y., & Bouchard, K. E. (2019). Robust, automated sleep scoring by a compact neural network with distributional shift correction. *PLOS ONE, 14*(12), 1–18.
|
|
21
|
+
|
|
22
|
+
The data and models associated with AccuSleep are available at https://osf.io/py5eb/
|
|
23
|
+
|
|
24
|
+
Please contact zekebarger (at) gmail (dot) com with any questions or comments about the software.
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
## Installation
|
|
28
|
+
|
|
29
|
+
- (recommended) create a new virtual environment (using
|
|
30
|
+
[venv](https://docs.python.org/3/library/venv.html),
|
|
31
|
+
[conda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html),
|
|
32
|
+
etc.) with python >=3.11,<3.14
|
|
33
|
+
- (optional) if you have a CUDA device and want to speed up model training, [install PyTorch](https://pytorch.org/)
|
|
34
|
+
- `pip install accusleepy`
|
|
35
|
+
- (optional) download a classification model from https://osf.io/py5eb/ under /python_format/models/
|
|
36
|
+
|
|
37
|
+
Note that upgrading or reinstalling the package will overwrite any changes
|
|
38
|
+
to the [config file](accusleepy/config.json).
|
|
39
|
+
|
|
40
|
+
## Usage
|
|
41
|
+
|
|
42
|
+
`python -m accusleepy` will open the primary interface.
|
|
43
|
+
|
|
44
|
+
[Guide to the primary interface](accusleepy/gui/text/main_guide.md)
|
|
45
|
+
|
|
46
|
+
[Guide to the manual scoring interface](accusleepy/gui/text/manual_scoring_guide.md)
|
|
47
|
+
|
|
48
|
+
## Changelog
|
|
49
|
+
|
|
50
|
+
- 0.6.0: Confidence scores can now be displayed and saved. Retraining your models is recommended
|
|
51
|
+
since the new calibration feature will make the confidence scores more accurate.
|
|
52
|
+
- 0.5.0: Performance improvements
|
|
53
|
+
- 0.4.5: Added support for python 3.13, **removed support for python 3.10.**
|
|
54
|
+
- 0.4.4: Performance improvements
|
|
55
|
+
- 0.4.3: Improved unit tests and user manuals
|
|
56
|
+
- 0.4.0: Improved visuals and user manuals
|
|
57
|
+
- 0.1.0-0.3.1: Early development versions
|
|
58
|
+
|
|
59
|
+
## Screenshots
|
|
60
|
+
|
|
61
|
+
Primary interface
|
|
62
|
+

|
|
63
|
+
|
|
64
|
+
Manual scoring interface
|
|
65
|
+

|
|
66
|
+
|
|
67
|
+
## Acknowledgements
|
|
68
|
+
|
|
69
|
+
We would like to thank [Franz Weber](https://www.med.upenn.edu/weberlab/) for creating an
|
|
70
|
+
early version of the manual labeling interface. The code that
|
|
71
|
+
creates spectrograms comes from the
|
|
72
|
+
[Prerau lab](https://github.com/preraulab/multitaper_toolbox/blob/master/python/multitaper_spectrogram_python.py)
|
|
73
|
+
with only minor modifications.
|
|
74
|
+
Jim Bohnslav's [deepethogram](https://github.com/jbohnslav/deepethogram) served as an
|
|
75
|
+
incredibly useful reference when reimplementing this project in python.
|
|
76
|
+
The model calibration code added in version 0.6.0 comes from Geoff Pleiss'
|
|
77
|
+
[temperature scaling repo](https://github.com/gpleiss/temperature_scaling).
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from operator import attrgetter
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class Bout:
    """A run of a single brain state found inside a sequence of labels.

    Attributes:
        length: number of epochs in the bout
        start_index: index of the first epoch of the bout
        end_index: index where the bout ends (exclusive, so it can be used
            directly as the end of a slice)
        surrounding_state: brain state found on both sides of the bout
    """

    length: int
    start_index: int
    end_index: int
    surrounding_state: int
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def find_last_adjacent_bout(sorted_bouts: "list[Bout]", bout_index: int) -> int:
    """Find index of last consecutive same-length bout

    When running the post-processing step that enforces a minimum duration
    for brain state bouts, there is a special case when bouts below the
    duration threshold occur consecutively. Starting at bout_index, this
    function walks forward through the list for as long as each bout ends
    exactly where the next one starts. It returns the index of the bout at
    the end of such a sequence. If, for example, the first three bouts in
    the list are consecutive and the search starts at 0, the function
    will return 2.

    The search is iterative rather than recursive so that very long runs of
    adjacent bouts cannot exceed Python's recursion limit.

    :param sorted_bouts: list of brain state bouts, sorted by start time
    :param bout_index: index of the bout in question
    :return: index of the last consecutive same-length bout
    """
    last_index = len(sorted_bouts) - 1
    # keep stepping forward while the next bout starts where this one ends
    while (
        bout_index < last_index
        and sorted_bouts[bout_index].end_index
        == sorted_bouts[bout_index + 1].start_index
    ):
        bout_index += 1
    return bout_index
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def enforce_min_bout_length(
|
|
45
|
+
labels: np.array, epoch_length: int | float, min_bout_length: int | float
|
|
46
|
+
) -> np.array:
|
|
47
|
+
"""Ensure brain state bouts meet the min length requirement
|
|
48
|
+
|
|
49
|
+
As a post-processing step for sleep scoring, we can require that any
|
|
50
|
+
bout (continuous period) of a brain state have a minimum duration.
|
|
51
|
+
This function sets any bout shorter than the minimum duration to the
|
|
52
|
+
surrounding brain state (if the states on the left and right sides
|
|
53
|
+
are the same). In the case where there are consecutive short bouts,
|
|
54
|
+
it either creates a transition at the midpoint or removes all short
|
|
55
|
+
bouts, depending on whether the number is even or odd. For example:
|
|
56
|
+
...AAABABAAA... -> ...AAAAAAAAA...
|
|
57
|
+
...AAABABABBB... -> ...AAAAABBBBB...
|
|
58
|
+
|
|
59
|
+
:param labels: brain state labels (digits in the 0-9 range)
|
|
60
|
+
:param epoch_length: epoch length, in seconds
|
|
61
|
+
:param min_bout_length: minimum bout length, in seconds
|
|
62
|
+
:return: updated brain state labels
|
|
63
|
+
"""
|
|
64
|
+
# if recording is very short, don't change anything
|
|
65
|
+
if labels.size < 3:
|
|
66
|
+
return labels
|
|
67
|
+
|
|
68
|
+
if epoch_length == min_bout_length:
|
|
69
|
+
return labels
|
|
70
|
+
|
|
71
|
+
# get minimum number of epochs in a bout
|
|
72
|
+
min_epochs = int(np.ceil(min_bout_length / epoch_length))
|
|
73
|
+
# get set of states in the labels
|
|
74
|
+
brain_states = set(labels.tolist())
|
|
75
|
+
|
|
76
|
+
while True: # so true
|
|
77
|
+
# convert labels to a string for regex search
|
|
78
|
+
# There is probably a regex that can find all patterns like ab+a
|
|
79
|
+
# without consuming each "a" but I haven't found it :(
|
|
80
|
+
label_string = "".join(labels.astype(str))
|
|
81
|
+
|
|
82
|
+
bouts = list()
|
|
83
|
+
|
|
84
|
+
for state in brain_states:
|
|
85
|
+
for other_state in brain_states:
|
|
86
|
+
if state == other_state:
|
|
87
|
+
continue
|
|
88
|
+
# get start and end indices of each bout
|
|
89
|
+
expression = (
|
|
90
|
+
f"(?<={other_state}){state}{{1,{min_epochs - 1}}}(?={other_state})"
|
|
91
|
+
)
|
|
92
|
+
matches = re.finditer(expression, label_string)
|
|
93
|
+
spans = [match.span() for match in matches]
|
|
94
|
+
|
|
95
|
+
# if some bouts were found
|
|
96
|
+
for span in spans:
|
|
97
|
+
bouts.append(
|
|
98
|
+
Bout(
|
|
99
|
+
length=span[1] - span[0],
|
|
100
|
+
start_index=span[0],
|
|
101
|
+
end_index=span[1],
|
|
102
|
+
surrounding_state=other_state,
|
|
103
|
+
)
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
if len(bouts) == 0:
|
|
107
|
+
break
|
|
108
|
+
|
|
109
|
+
# only keep the shortest bouts
|
|
110
|
+
min_length_in_list = np.min([bout.length for bout in bouts])
|
|
111
|
+
bouts = [i for i in bouts if i.length == min_length_in_list]
|
|
112
|
+
# sort by start index
|
|
113
|
+
sorted_bouts = sorted(bouts, key=attrgetter("start_index"))
|
|
114
|
+
|
|
115
|
+
while len(sorted_bouts) > 0:
|
|
116
|
+
# get row index of latest adjacent bout (of same length)
|
|
117
|
+
last_adjacent_bout_index = find_last_adjacent_bout(sorted_bouts, 0)
|
|
118
|
+
# if there's an even number of adjacent bouts
|
|
119
|
+
if (last_adjacent_bout_index + 1) % 2 == 0:
|
|
120
|
+
midpoint = sorted_bouts[
|
|
121
|
+
round((last_adjacent_bout_index + 1) / 2)
|
|
122
|
+
].start_index
|
|
123
|
+
labels[sorted_bouts[0].start_index : midpoint] = sorted_bouts[
|
|
124
|
+
0
|
|
125
|
+
].surrounding_state
|
|
126
|
+
labels[midpoint : sorted_bouts[last_adjacent_bout_index].end_index] = (
|
|
127
|
+
sorted_bouts[last_adjacent_bout_index].surrounding_state
|
|
128
|
+
)
|
|
129
|
+
else:
|
|
130
|
+
labels[
|
|
131
|
+
sorted_bouts[0].start_index : sorted_bouts[
|
|
132
|
+
last_adjacent_bout_index
|
|
133
|
+
].end_index
|
|
134
|
+
] = sorted_bouts[0].surrounding_state
|
|
135
|
+
|
|
136
|
+
# delete the bouts we just fixed
|
|
137
|
+
if last_adjacent_bout_index == len(sorted_bouts) - 1:
|
|
138
|
+
sorted_bouts = []
|
|
139
|
+
else:
|
|
140
|
+
sorted_bouts = sorted_bouts[(last_adjacent_bout_index + 1) :]
|
|
141
|
+
|
|
142
|
+
return labels
|
|
@@ -61,13 +61,31 @@ def get_device():
|
|
|
61
61
|
)
|
|
62
62
|
|
|
63
63
|
|
|
64
|
-
def
|
|
64
|
+
def create_dataloader(
    annotations_file: str, img_dir: str, shuffle: bool = True
) -> DataLoader:
    """Wrap a dataset of training or calibration images in a DataLoader

    :param annotations_file: file describing each image in the dataset
    :param img_dir: directory containing the images
    :param shuffle: whether to reshuffle the data at every epoch
    :return: DataLoader that yields batches of BATCH_SIZE images
    """
    return DataLoader(
        AccuSleepImageDataset(annotations_file=annotations_file, img_dir=img_dir),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
    )
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def train_ssann(
|
|
65
83
|
annotations_file: str,
|
|
66
84
|
img_dir: str,
|
|
67
85
|
mixture_weights: np.array,
|
|
68
86
|
n_classes: int,
|
|
69
87
|
) -> SSANN:
|
|
70
|
-
"""Train a classification model for sleep scoring
|
|
88
|
+
"""Train a SSANN classification model for sleep scoring
|
|
71
89
|
|
|
72
90
|
:param annotations_file: file with information on each training image
|
|
73
91
|
:param img_dir: training image location
|
|
@@ -75,11 +93,9 @@ def train_model(
|
|
|
75
93
|
:param n_classes: number of classes the model will learn
|
|
76
94
|
:return: trained Sleep Scoring Artificial Neural Network model
|
|
77
95
|
"""
|
|
78
|
-
|
|
79
|
-
annotations_file=annotations_file,
|
|
80
|
-
img_dir=img_dir,
|
|
96
|
+
train_dataloader = create_dataloader(
|
|
97
|
+
annotations_file=annotations_file, img_dir=img_dir
|
|
81
98
|
)
|
|
82
|
-
train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
|
|
83
99
|
|
|
84
100
|
device = get_device()
|
|
85
101
|
model = SSANN(n_classes=n_classes)
|
|
@@ -130,7 +146,7 @@ def score_recording(
|
|
|
130
146
|
:param epoch_length: epoch length, in seconds
|
|
131
147
|
:param epochs_per_img: number of epochs for the model to consider
|
|
132
148
|
:param brain_state_set: set of brain state options
|
|
133
|
-
:return: brain state labels
|
|
149
|
+
:return: brain state labels, confidence scores
|
|
134
150
|
"""
|
|
135
151
|
# prepare model
|
|
136
152
|
device = get_device()
|
|
@@ -158,10 +174,12 @@ def score_recording(
|
|
|
158
174
|
# perform classification
|
|
159
175
|
with torch.no_grad():
|
|
160
176
|
outputs = model(images)
|
|
161
|
-
|
|
177
|
+
logits, predicted = torch.max(outputs, 1)
|
|
162
178
|
|
|
163
179
|
labels = brain_state_set.convert_class_to_digit(predicted.cpu().numpy())
|
|
164
|
-
|
|
180
|
+
confidence_scores = 1 / (1 + np.e ** (-logits.cpu().numpy()))
|
|
181
|
+
|
|
182
|
+
return labels, confidence_scores
|
|
165
183
|
|
|
166
184
|
|
|
167
185
|
def example_real_time_scoring_function(
|
|
@@ -8,6 +8,7 @@ EEG_COL = "eeg"
|
|
|
8
8
|
EMG_COL = "emg"
|
|
9
9
|
# label file columns
|
|
10
10
|
BRAIN_STATE_COL = "brain_state"
|
|
11
|
+
CONFIDENCE_SCORE_COL = "confidence_score"
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
# really don't change these
|
|
@@ -37,3 +38,9 @@ RECORDING_LIST_NAME = "recording_list"
|
|
|
37
38
|
RECORDING_LIST_FILE_TYPE = ".json"
|
|
38
39
|
# key for default epoch length in config
|
|
39
40
|
DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
|
|
41
|
+
# key used for default confidence score behavior in config
|
|
42
|
+
DEFAULT_CONFIDENCE_SETTING_KEY = "save_confidence_setting"
|
|
43
|
+
# filename used to store info about training image datasets
|
|
44
|
+
ANNOTATIONS_FILENAME = "annotations.csv"
|
|
45
|
+
# filename for annotation file for the calibration set
|
|
46
|
+
CALIBRATION_ANNOTATION_FILENAME = "calibration_set.csv"
|
|
@@ -4,13 +4,14 @@ from dataclasses import dataclass
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pandas as pd
|
|
7
|
-
import torch
|
|
8
7
|
from PySide6.QtWidgets import QListWidgetItem
|
|
9
8
|
|
|
10
9
|
from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
|
|
11
10
|
from accusleepy.constants import (
|
|
12
11
|
BRAIN_STATE_COL,
|
|
12
|
+
CONFIDENCE_SCORE_COL,
|
|
13
13
|
CONFIG_FILE,
|
|
14
|
+
DEFAULT_CONFIDENCE_SETTING_KEY,
|
|
14
15
|
DEFAULT_EPOCH_LENGTH_KEY,
|
|
15
16
|
EEG_COL,
|
|
16
17
|
EMG_COL,
|
|
@@ -19,7 +20,6 @@ from accusleepy.constants import (
|
|
|
19
20
|
RECORDING_LIST_NAME,
|
|
20
21
|
UNDEFINED_LABEL,
|
|
21
22
|
)
|
|
22
|
-
from accusleepy.models import SSANN
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
@dataclass
|
|
@@ -46,57 +46,6 @@ def load_calibration_file(filename: str) -> (np.array, np.array):
|
|
|
46
46
|
return mixture_means, mixture_sds
|
|
47
47
|
|
|
48
48
|
|
|
49
|
-
def save_model(
|
|
50
|
-
model: SSANN,
|
|
51
|
-
filename: str,
|
|
52
|
-
epoch_length: int | float,
|
|
53
|
-
epochs_per_img: int,
|
|
54
|
-
model_type: str,
|
|
55
|
-
brain_state_set: BrainStateSet,
|
|
56
|
-
) -> None:
|
|
57
|
-
"""Save classification model and its metadata
|
|
58
|
-
|
|
59
|
-
:param model: classification model
|
|
60
|
-
:param epoch_length: epoch length used when training the model
|
|
61
|
-
:param epochs_per_img: number of epochs in each model input
|
|
62
|
-
:param model_type: default or real-time
|
|
63
|
-
:param brain_state_set: set of brain state options
|
|
64
|
-
:param filename: filename
|
|
65
|
-
"""
|
|
66
|
-
state_dict = model.state_dict()
|
|
67
|
-
state_dict.update({"epoch_length": epoch_length})
|
|
68
|
-
state_dict.update({"epochs_per_img": epochs_per_img})
|
|
69
|
-
state_dict.update({"model_type": model_type})
|
|
70
|
-
state_dict.update(
|
|
71
|
-
{BRAIN_STATES_KEY: brain_state_set.to_output_dict()[BRAIN_STATES_KEY]}
|
|
72
|
-
)
|
|
73
|
-
|
|
74
|
-
torch.save(state_dict, filename)
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def load_model(filename: str) -> tuple[SSANN, int | float, int, str, dict]:
|
|
78
|
-
"""Load classification model and its metadata
|
|
79
|
-
|
|
80
|
-
:param filename: filename
|
|
81
|
-
:return: model, epoch length used when training the model,
|
|
82
|
-
number of epochs in each model input, model type
|
|
83
|
-
(default or real-time), set of brain state options
|
|
84
|
-
used when training the model
|
|
85
|
-
"""
|
|
86
|
-
state_dict = torch.load(
|
|
87
|
-
filename, weights_only=True, map_location=torch.device("cpu")
|
|
88
|
-
)
|
|
89
|
-
epoch_length = state_dict.pop("epoch_length")
|
|
90
|
-
epochs_per_img = state_dict.pop("epochs_per_img")
|
|
91
|
-
model_type = state_dict.pop("model_type")
|
|
92
|
-
brain_states = state_dict.pop(BRAIN_STATES_KEY)
|
|
93
|
-
n_classes = len([b for b in brain_states if b["is_scored"]])
|
|
94
|
-
|
|
95
|
-
model = SSANN(n_classes=n_classes)
|
|
96
|
-
model.load_state_dict(state_dict)
|
|
97
|
-
return model, epoch_length, epochs_per_img, model_type, brain_states
|
|
98
|
-
|
|
99
|
-
|
|
100
49
|
def load_csv_or_parquet(filename: str) -> pd.DataFrame:
|
|
101
50
|
"""Load a csv or parquet file as a dataframe
|
|
102
51
|
|
|
@@ -125,49 +74,70 @@ def load_recording(filename: str) -> (np.array, np.array):
|
|
|
125
74
|
return eeg, emg
|
|
126
75
|
|
|
127
76
|
|
|
128
|
-
def load_labels(filename: str) -> np.array:
|
|
129
|
-
"""Load file of brain state labels
|
|
77
|
+
def load_labels(filename: str) -> (np.array, np.array):
|
|
78
|
+
"""Load file of brain state labels and confidence scores
|
|
130
79
|
|
|
131
80
|
:param filename: filename
|
|
132
|
-
:return: array of brain state labels
|
|
81
|
+
:return: array of brain state labels and, optionally, array of confidence scores
|
|
133
82
|
"""
|
|
134
83
|
df = load_csv_or_parquet(filename)
|
|
135
|
-
|
|
84
|
+
if CONFIDENCE_SCORE_COL in df.columns:
|
|
85
|
+
return df[BRAIN_STATE_COL].values, df[CONFIDENCE_SCORE_COL].values
|
|
86
|
+
else:
|
|
87
|
+
return df[BRAIN_STATE_COL].values, None
|
|
136
88
|
|
|
137
89
|
|
|
138
|
-
def save_labels(
|
|
90
|
+
def save_labels(
|
|
91
|
+
labels: np.array, filename: str, confidence_scores: np.array = None
|
|
92
|
+
) -> None:
|
|
139
93
|
"""Save brain state labels to file
|
|
140
94
|
|
|
141
95
|
:param labels: brain state labels
|
|
142
96
|
:param filename: filename
|
|
97
|
+
:param confidence_scores: optional confidence scores
|
|
143
98
|
"""
|
|
144
|
-
|
|
99
|
+
if confidence_scores is not None:
|
|
100
|
+
pd.DataFrame(
|
|
101
|
+
{BRAIN_STATE_COL: labels, CONFIDENCE_SCORE_COL: confidence_scores}
|
|
102
|
+
).to_csv(filename, index=False)
|
|
103
|
+
else:
|
|
104
|
+
pd.DataFrame({BRAIN_STATE_COL: labels}).to_csv(filename, index=False)
|
|
145
105
|
|
|
146
106
|
|
|
147
|
-
def load_config() -> tuple[BrainStateSet, int | float]:
|
|
107
|
+
def load_config() -> tuple[BrainStateSet, int | float, bool]:
|
|
148
108
|
"""Load configuration file with brain state options
|
|
149
109
|
|
|
150
|
-
:return: set of brain state options
|
|
110
|
+
:return: set of brain state options, other settings
|
|
151
111
|
"""
|
|
152
112
|
with open(
|
|
153
113
|
os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "r"
|
|
154
114
|
) as f:
|
|
155
115
|
data = json.load(f)
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
116
|
+
|
|
117
|
+
return (
|
|
118
|
+
BrainStateSet(
|
|
119
|
+
[BrainState(**b) for b in data[BRAIN_STATES_KEY]], UNDEFINED_LABEL
|
|
120
|
+
),
|
|
121
|
+
data[DEFAULT_EPOCH_LENGTH_KEY],
|
|
122
|
+
data.get(DEFAULT_CONFIDENCE_SETTING_KEY, True),
|
|
123
|
+
)
|
|
159
124
|
|
|
160
125
|
|
|
161
126
|
def save_config(
|
|
162
|
-
brain_state_set: BrainStateSet,
|
|
127
|
+
brain_state_set: BrainStateSet,
|
|
128
|
+
default_epoch_length: int | float,
|
|
129
|
+
save_confidence_setting: bool,
|
|
163
130
|
) -> None:
|
|
164
131
|
"""Save configuration of brain state options to json file
|
|
165
132
|
|
|
166
133
|
:param brain_state_set: set of brain state options
|
|
167
134
|
:param default_epoch_length: epoch length to use when the GUI starts
|
|
135
|
+
:param save_confidence_setting: whether the option to save confidence
|
|
136
|
+
scores should be True by default
|
|
168
137
|
"""
|
|
169
138
|
output_dict = brain_state_set.to_output_dict()
|
|
170
139
|
output_dict.update({DEFAULT_EPOCH_LENGTH_KEY: default_epoch_length})
|
|
140
|
+
output_dict.update({DEFAULT_CONFIDENCE_SETTING_KEY: save_confidence_setting})
|
|
171
141
|
with open(
|
|
172
142
|
os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "w"
|
|
173
143
|
) as f:
|
|
Binary file
|
|
Binary file
|
|
Binary file
|