accusleepy 0.4.4__tar.gz → 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {accusleepy-0.4.4 → accusleepy-0.5.0}/PKG-INFO +29 -19
- {accusleepy-0.4.4 → accusleepy-0.5.0}/README.md +25 -15
- accusleepy-0.5.0/accusleepy/bouts.py +142 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/classification.py +2 -2
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/constants.py +2 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/fileio.py +0 -53
- accusleepy-0.5.0/accusleepy/gui/images/primary_window.png +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/main.py +84 -64
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/manual_scoring.py +76 -81
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/mplwidget.py +15 -10
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/primary_window.py +1 -0
- accusleepy-0.5.0/accusleepy/models.py +98 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/multitaper.py +9 -7
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/signal_processing.py +5 -143
- {accusleepy-0.4.4 → accusleepy-0.5.0}/pyproject.toml +3 -3
- accusleepy-0.4.4/accusleepy/gui/images/primary_window.png +0 -0
- accusleepy-0.4.4/accusleepy/models.py +0 -48
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/__init__.py +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/__main__.py +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/brain_state_set.py +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/config.json +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/__init__.py +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/icons/brightness_down.png +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/icons/brightness_up.png +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/icons/double_down_arrow.png +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/icons/double_up_arrow.png +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/icons/down_arrow.png +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/icons/home.png +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/icons/question.png +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/icons/save.png +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/icons/up_arrow.png +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/icons/zoom_in.png +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/icons/zoom_out.png +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/images/viewer_window.png +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/images/viewer_window_annotated.png +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/primary_window.ui +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/resources.qrc +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/resources_rc.py +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/text/config_guide.txt +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/text/main_guide.md +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/text/manual_scoring_guide.md +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/viewer_window.py +0 -0
- {accusleepy-0.4.4 → accusleepy-0.5.0}/accusleepy/gui/viewer_window.ui +0 -0
|
@@ -1,16 +1,16 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: accusleepy
|
|
3
|
-
Version: 0.4.4
|
|
3
|
+
Version: 0.5.0
|
|
4
4
|
Summary: Python implementation of AccuSleep
|
|
5
5
|
License: GPL-3.0-only
|
|
6
6
|
Author: Zeke Barger
|
|
7
7
|
Author-email: zekebarger@gmail.com
|
|
8
|
-
Requires-Python: >=3.10,<3.13
|
|
8
|
+
Requires-Python: >=3.11,<3.14
|
|
9
9
|
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
|
|
10
10
|
Classifier: Programming Language :: Python :: 3
|
|
11
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
12
11
|
Classifier: Programming Language :: Python :: 3.11
|
|
13
12
|
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
14
14
|
Requires-Dist: fastparquet (>=2024.11.0,<2025.0.0)
|
|
15
15
|
Requires-Dist: joblib (>=1.4.2,<2.0.0)
|
|
16
16
|
Requires-Dist: matplotlib (>=3.10.1,<4.0.0)
|
|
@@ -18,7 +18,7 @@ Requires-Dist: numpy (>=2.2.4,<3.0.0)
|
|
|
18
18
|
Requires-Dist: pandas (>=2.2.3,<3.0.0)
|
|
19
19
|
Requires-Dist: pillow (>=11.1.0,<12.0.0)
|
|
20
20
|
Requires-Dist: pre-commit (>=4.2.0,<5.0.0)
|
|
21
|
-
Requires-Dist: pyside6 (>=6.
|
|
21
|
+
Requires-Dist: pyside6 (>=6.9.0,<7.0.0)
|
|
22
22
|
Requires-Dist: scipy (>=1.15.2,<2.0.0)
|
|
23
23
|
Requires-Dist: toml (>=0.10.2,<0.11.0)
|
|
24
24
|
Requires-Dist: torch (>=2.6.0,<3.0.0)
|
|
@@ -30,9 +30,17 @@ Description-Content-Type: text/markdown
|
|
|
30
30
|
|
|
31
31
|
## Description
|
|
32
32
|
|
|
33
|
-
AccuSleePy is
|
|
34
|
-
|
|
35
|
-
|
|
33
|
+
AccuSleePy is a set of graphical user interfaces for scoring rodent sleep
|
|
34
|
+
using EEG and EMG recordings.
|
|
35
|
+
It offers the following improvements over the MATLAB version (AccuSleep):
|
|
36
|
+
|
|
37
|
+
- Up to 10 brain states can be configured through the user interface
|
|
38
|
+
- Classification models can be trained through the user interface
|
|
39
|
+
- Model files contain useful metadata (brain state configuration,
|
|
40
|
+
epoch length, number of epochs)
|
|
41
|
+
- Models optimized for real-time scoring can be trained
|
|
42
|
+
- Lists of recordings can be imported and exported for repeatable batch processing
|
|
43
|
+
- Undo/redo functionality in the manual scoring interface
|
|
36
44
|
|
|
37
45
|
If you use AccuSleep in your research, please cite our
|
|
38
46
|
[publication](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0224642):
|
|
@@ -43,28 +51,20 @@ The data and models associated with AccuSleep are available at https://osf.io/py
|
|
|
43
51
|
|
|
44
52
|
Please contact zekebarger (at) gmail (dot) com with any questions or comments about the software.
|
|
45
53
|
|
|
46
|
-
## What's new
|
|
47
|
-
|
|
48
|
-
AccuSleePy offers the following improvements over the MATLAB version:
|
|
49
|
-
|
|
50
|
-
- Up to 10 brain states can be configured through the user interface
|
|
51
|
-
- Models can be trained through the user interface
|
|
52
|
-
- Model files contain useful metadata (brain state configuration,
|
|
53
|
-
epoch length, number of epochs)
|
|
54
|
-
- Models optimized for real-time scoring can be trained
|
|
55
|
-
- Lists of recordings can be imported and exported for repeatable batch processing
|
|
56
|
-
- Undo/redo functionality in the manual scoring interface
|
|
57
54
|
|
|
58
55
|
## Installation
|
|
59
56
|
|
|
60
57
|
- (recommended) create a new virtual environment (using
|
|
61
58
|
[venv](https://docs.python.org/3/library/venv.html),
|
|
62
59
|
[conda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html),
|
|
63
|
-
etc.) using python >=3.10,<3.13
|
|
60
|
+
etc.) with python >=3.11,<3.14
|
|
64
61
|
- (optional) if you have a CUDA device and want to speed up model training, [install PyTorch](https://pytorch.org/)
|
|
65
62
|
- `pip install accusleepy`
|
|
66
63
|
- (optional) download a classification model from https://osf.io/py5eb/ under /python_format/models/
|
|
67
64
|
|
|
65
|
+
Note that upgrading or reinstalling the package will overwrite any changes
|
|
66
|
+
to the [config file](accusleepy/config.json).
|
|
67
|
+
|
|
68
68
|
## Usage
|
|
69
69
|
|
|
70
70
|
`python -m accusleepy` will open the primary interface.
|
|
@@ -73,7 +73,17 @@ etc.) using python >=3.10,<3.13
|
|
|
73
73
|
|
|
74
74
|
[Guide to the manual scoring interface](accusleepy/gui/text/manual_scoring_guide.md)
|
|
75
75
|
|
|
76
|
+
## Changelog
|
|
77
|
+
|
|
78
|
+
- 0.5.0: Performance improvements
|
|
79
|
+
- 0.4.5: Added support for python 3.13, **removed support for python 3.10.**
|
|
80
|
+
- 0.4.4: Performance improvements
|
|
81
|
+
- 0.4.3: Improved unit tests and user manuals
|
|
82
|
+
- 0.4.0: Improved visuals and user manuals
|
|
83
|
+
- 0.1.0-0.3.1: Early development versions
|
|
84
|
+
|
|
76
85
|
## Screenshots
|
|
86
|
+
|
|
77
87
|
Primary interface
|
|
78
88
|

|
|
79
89
|
|
|
@@ -2,9 +2,17 @@
|
|
|
2
2
|
|
|
3
3
|
## Description
|
|
4
4
|
|
|
5
|
-
AccuSleePy is
|
|
6
|
-
|
|
7
|
-
|
|
5
|
+
AccuSleePy is a set of graphical user interfaces for scoring rodent sleep
|
|
6
|
+
using EEG and EMG recordings.
|
|
7
|
+
It offers the following improvements over the MATLAB version (AccuSleep):
|
|
8
|
+
|
|
9
|
+
- Up to 10 brain states can be configured through the user interface
|
|
10
|
+
- Classification models can be trained through the user interface
|
|
11
|
+
- Model files contain useful metadata (brain state configuration,
|
|
12
|
+
epoch length, number of epochs)
|
|
13
|
+
- Models optimized for real-time scoring can be trained
|
|
14
|
+
- Lists of recordings can be imported and exported for repeatable batch processing
|
|
15
|
+
- Undo/redo functionality in the manual scoring interface
|
|
8
16
|
|
|
9
17
|
If you use AccuSleep in your research, please cite our
|
|
10
18
|
[publication](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0224642):
|
|
@@ -15,28 +23,20 @@ The data and models associated with AccuSleep are available at https://osf.io/py
|
|
|
15
23
|
|
|
16
24
|
Please contact zekebarger (at) gmail (dot) com with any questions or comments about the software.
|
|
17
25
|
|
|
18
|
-
## What's new
|
|
19
|
-
|
|
20
|
-
AccuSleePy offers the following improvements over the MATLAB version:
|
|
21
|
-
|
|
22
|
-
- Up to 10 brain states can be configured through the user interface
|
|
23
|
-
- Models can be trained through the user interface
|
|
24
|
-
- Model files contain useful metadata (brain state configuration,
|
|
25
|
-
epoch length, number of epochs)
|
|
26
|
-
- Models optimized for real-time scoring can be trained
|
|
27
|
-
- Lists of recordings can be imported and exported for repeatable batch processing
|
|
28
|
-
- Undo/redo functionality in the manual scoring interface
|
|
29
26
|
|
|
30
27
|
## Installation
|
|
31
28
|
|
|
32
29
|
- (recommended) create a new virtual environment (using
|
|
33
30
|
[venv](https://docs.python.org/3/library/venv.html),
|
|
34
31
|
[conda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html),
|
|
35
|
-
etc.) using python >=3.10,<3.13
|
|
32
|
+
etc.) with python >=3.11,<3.14
|
|
36
33
|
- (optional) if you have a CUDA device and want to speed up model training, [install PyTorch](https://pytorch.org/)
|
|
37
34
|
- `pip install accusleepy`
|
|
38
35
|
- (optional) download a classification model from https://osf.io/py5eb/ under /python_format/models/
|
|
39
36
|
|
|
37
|
+
Note that upgrading or reinstalling the package will overwrite any changes
|
|
38
|
+
to the [config file](accusleepy/config.json).
|
|
39
|
+
|
|
40
40
|
## Usage
|
|
41
41
|
|
|
42
42
|
`python -m accusleepy` will open the primary interface.
|
|
@@ -45,7 +45,17 @@ etc.) using python >=3.10,<3.13
|
|
|
45
45
|
|
|
46
46
|
[Guide to the manual scoring interface](accusleepy/gui/text/manual_scoring_guide.md)
|
|
47
47
|
|
|
48
|
+
## Changelog
|
|
49
|
+
|
|
50
|
+
- 0.5.0: Performance improvements
|
|
51
|
+
- 0.4.5: Added support for python 3.13, **removed support for python 3.10.**
|
|
52
|
+
- 0.4.4: Performance improvements
|
|
53
|
+
- 0.4.3: Improved unit tests and user manuals
|
|
54
|
+
- 0.4.0: Improved visuals and user manuals
|
|
55
|
+
- 0.1.0-0.3.1: Early development versions
|
|
56
|
+
|
|
48
57
|
## Screenshots
|
|
58
|
+
|
|
49
59
|
Primary interface
|
|
50
60
|

|
|
51
61
|
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from operator import attrgetter
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class Bout:
    """A short run of one brain state flanked on both sides by another state.

    Positions follow Python slice conventions, so the bout covers
    ``labels[start_index:end_index]``.
    """

    # number of epochs in the bout
    length: int
    # position of the first epoch of the bout
    start_index: int
    # position one past the final epoch of the bout (exclusive end)
    end_index: int
    # brain state immediately before and immediately after the bout
    surrounding_state: int
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def find_last_adjacent_bout(sorted_bouts: list["Bout"], bout_index: int) -> int:
    """Find index of last consecutive same-length bout

    When running the post-processing step that enforces a minimum duration
    for brain state bouts, there is a special case when bouts below the
    duration threshold occur consecutively. This function walks forward from
    ``bout_index`` for as long as each bout ends exactly where the next one
    starts, and returns the index of the bout at the end of that run.
    When initially called, bout_index will always be 0. If, for example, the
    first three bouts in the list are consecutive, the function will return 2.

    The walk is iterative rather than recursive, so a long run of adjacent
    bouts (e.g. an alternating ABABAB... stretch of labels) cannot exceed
    Python's recursion limit.

    :param sorted_bouts: list of brain state bouts, sorted by start time
    :param bout_index: index of the bout in question
    :return: index of the last consecutive same-length bout
    """
    last_index = len(sorted_bouts) - 1
    # keep advancing while there is a next bout that begins where this one ends
    while (
        bout_index < last_index
        and sorted_bouts[bout_index].end_index
        == sorted_bouts[bout_index + 1].start_index
    ):
        bout_index += 1
    return bout_index
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def enforce_min_bout_length(
    labels: np.array, epoch_length: int | float, min_bout_length: int | float
) -> np.array:
    """Ensure brain state bouts meet the min length requirement

    As a post-processing step for sleep scoring, we can require that any
    bout (continuous period) of a brain state have a minimum duration.
    This function sets any bout shorter than the minimum duration to the
    surrounding brain state (if the states on the left and right sides
    are the same). In the case where there are consecutive short bouts,
    it either creates a transition at the midpoint or removes all short
    bouts, depending on whether the number is even or odd. For example:
    ...AAABABAAA... -> ...AAAAAAAAA...
    ...AAABABABBB... -> ...AAAAABBBBB...

    Short bouts are fixed shortest-first. Each pass handles only the bouts of
    the current minimum length, then the labels are re-scanned. This repeats
    until no flanked bout shorter than the minimum remains.

    Note that ``labels`` is modified in place; the same array is returned.

    :param labels: brain state labels (digits in the 0-9 range)
    :param epoch_length: epoch length, in seconds
    :param min_bout_length: minimum bout length, in seconds
    :return: updated brain state labels
    """
    # if recording is very short, don't change anything
    if labels.size < 3:
        return labels

    if epoch_length == min_bout_length:
        return labels

    # get minimum number of epochs in a bout. The ratio is rounded first so
    # floating-point noise (e.g. 0.9 / 0.3 == 3.0000000000000004) does not
    # raise the requirement by a whole epoch.
    min_epochs = int(np.ceil(np.round(min_bout_length / epoch_length, 9)))
    # every bout lasts at least one epoch, so nothing can be too short; this
    # also avoids building an invalid "{1,0}" regex quantifier below
    if min_epochs <= 1:
        return labels

    # get set of states in the labels
    brain_states = set(labels.tolist())

    # The search patterns depend only on the state pair and min_epochs, so
    # compile them once. Each one finds runs of `state` lasting 1 to
    # (min_epochs - 1) epochs with `other_state` on both sides; the
    # lookbehind/lookahead leave the flanking epochs unconsumed.
    patterns = [
        (
            re.compile(
                f"(?<={other_state}){state}{{1,{min_epochs - 1}}}(?={other_state})"
            ),
            other_state,
        )
        for state in brain_states
        for other_state in brain_states
        if state != other_state
    ]

    while True:
        # one character per epoch, so string positions equal label indices
        label_string = "".join(labels.astype(str))

        bouts = [
            Bout(
                length=match.end() - match.start(),
                start_index=match.start(),
                end_index=match.end(),
                surrounding_state=other_state,
            )
            for pattern, other_state in patterns
            for match in pattern.finditer(label_string)
        ]

        if not bouts:
            break

        # only keep the shortest bouts, sorted by start index
        min_length_in_list = min(bout.length for bout in bouts)
        sorted_bouts = sorted(
            (bout for bout in bouts if bout.length == min_length_in_list),
            key=attrgetter("start_index"),
        )

        while sorted_bouts:
            # index of the last bout in the run of adjacent same-length bouts
            last_adjacent_bout_index = find_last_adjacent_bout(sorted_bouts, 0)
            n_adjacent = last_adjacent_bout_index + 1
            first_bout = sorted_bouts[0]
            last_bout = sorted_bouts[last_adjacent_bout_index]

            if n_adjacent % 2 == 0:
                # even number of adjacent bouts: split the run at its midpoint
                midpoint = sorted_bouts[n_adjacent // 2].start_index
                labels[first_bout.start_index : midpoint] = first_bout.surrounding_state
                labels[midpoint : last_bout.end_index] = last_bout.surrounding_state
            else:
                # odd number: the whole run takes the state flanking the first bout
                labels[first_bout.start_index : last_bout.end_index] = (
                    first_bout.surrounding_state
                )

            # drop the bouts we just fixed
            sorted_bouts = sorted_bouts[n_adjacent:]

    return labels
|
|
@@ -61,13 +61,13 @@ def get_device():
|
|
|
61
61
|
)
|
|
62
62
|
|
|
63
63
|
|
|
64
|
-
def
|
|
64
|
+
def train_ssann(
|
|
65
65
|
annotations_file: str,
|
|
66
66
|
img_dir: str,
|
|
67
67
|
mixture_weights: np.array,
|
|
68
68
|
n_classes: int,
|
|
69
69
|
) -> SSANN:
|
|
70
|
-
"""Train a classification model for sleep scoring
|
|
70
|
+
"""Train a SSANN classification model for sleep scoring
|
|
71
71
|
|
|
72
72
|
:param annotations_file: file with information on each training image
|
|
73
73
|
:param img_dir: training image location
|
|
@@ -37,3 +37,5 @@ RECORDING_LIST_NAME = "recording_list"
|
|
|
37
37
|
RECORDING_LIST_FILE_TYPE = ".json"
|
|
38
38
|
# key for default epoch length in config
|
|
39
39
|
DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
|
|
40
|
+
# filename used to store info about training image datasets
|
|
41
|
+
ANNOTATIONS_FILENAME = "annotations.csv"
|
|
@@ -4,7 +4,6 @@ from dataclasses import dataclass
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pandas as pd
|
|
7
|
-
import torch
|
|
8
7
|
from PySide6.QtWidgets import QListWidgetItem
|
|
9
8
|
|
|
10
9
|
from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
|
|
@@ -19,7 +18,6 @@ from accusleepy.constants import (
|
|
|
19
18
|
RECORDING_LIST_NAME,
|
|
20
19
|
UNDEFINED_LABEL,
|
|
21
20
|
)
|
|
22
|
-
from accusleepy.models import SSANN
|
|
23
21
|
|
|
24
22
|
|
|
25
23
|
@dataclass
|
|
@@ -46,57 +44,6 @@ def load_calibration_file(filename: str) -> (np.array, np.array):
|
|
|
46
44
|
return mixture_means, mixture_sds
|
|
47
45
|
|
|
48
46
|
|
|
49
|
-
def save_model(
|
|
50
|
-
model: SSANN,
|
|
51
|
-
filename: str,
|
|
52
|
-
epoch_length: int | float,
|
|
53
|
-
epochs_per_img: int,
|
|
54
|
-
model_type: str,
|
|
55
|
-
brain_state_set: BrainStateSet,
|
|
56
|
-
) -> None:
|
|
57
|
-
"""Save classification model and its metadata
|
|
58
|
-
|
|
59
|
-
:param model: classification model
|
|
60
|
-
:param epoch_length: epoch length used when training the model
|
|
61
|
-
:param epochs_per_img: number of epochs in each model input
|
|
62
|
-
:param model_type: default or real-time
|
|
63
|
-
:param brain_state_set: set of brain state options
|
|
64
|
-
:param filename: filename
|
|
65
|
-
"""
|
|
66
|
-
state_dict = model.state_dict()
|
|
67
|
-
state_dict.update({"epoch_length": epoch_length})
|
|
68
|
-
state_dict.update({"epochs_per_img": epochs_per_img})
|
|
69
|
-
state_dict.update({"model_type": model_type})
|
|
70
|
-
state_dict.update(
|
|
71
|
-
{BRAIN_STATES_KEY: brain_state_set.to_output_dict()[BRAIN_STATES_KEY]}
|
|
72
|
-
)
|
|
73
|
-
|
|
74
|
-
torch.save(state_dict, filename)
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def load_model(filename: str) -> tuple[SSANN, int | float, int, str, dict]:
|
|
78
|
-
"""Load classification model and its metadata
|
|
79
|
-
|
|
80
|
-
:param filename: filename
|
|
81
|
-
:return: model, epoch length used when training the model,
|
|
82
|
-
number of epochs in each model input, model type
|
|
83
|
-
(default or real-time), set of brain state options
|
|
84
|
-
used when training the model
|
|
85
|
-
"""
|
|
86
|
-
state_dict = torch.load(
|
|
87
|
-
filename, weights_only=True, map_location=torch.device("cpu")
|
|
88
|
-
)
|
|
89
|
-
epoch_length = state_dict.pop("epoch_length")
|
|
90
|
-
epochs_per_img = state_dict.pop("epochs_per_img")
|
|
91
|
-
model_type = state_dict.pop("model_type")
|
|
92
|
-
brain_states = state_dict.pop(BRAIN_STATES_KEY)
|
|
93
|
-
n_classes = len([b for b in brain_states if b["is_scored"]])
|
|
94
|
-
|
|
95
|
-
model = SSANN(n_classes=n_classes)
|
|
96
|
-
model.load_state_dict(state_dict)
|
|
97
|
-
return model, epoch_length, epochs_per_img, model_type, brain_states
|
|
98
|
-
|
|
99
|
-
|
|
100
47
|
def load_csv_or_parquet(filename: str) -> pd.DataFrame:
|
|
101
48
|
"""Load a csv or parquet file as a dataframe
|
|
102
49
|
|
|
Binary file
|