accusleepy 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accusleepy/__init__.py +0 -0
- accusleepy/__main__.py +4 -0
- accusleepy/bouts.py +142 -0
- accusleepy/brain_state_set.py +89 -0
- accusleepy/classification.py +285 -0
- accusleepy/config.json +24 -0
- accusleepy/constants.py +46 -0
- accusleepy/fileio.py +179 -0
- accusleepy/gui/__init__.py +0 -0
- accusleepy/gui/icons/brightness_down.png +0 -0
- accusleepy/gui/icons/brightness_up.png +0 -0
- accusleepy/gui/icons/double_down_arrow.png +0 -0
- accusleepy/gui/icons/double_up_arrow.png +0 -0
- accusleepy/gui/icons/down_arrow.png +0 -0
- accusleepy/gui/icons/home.png +0 -0
- accusleepy/gui/icons/question.png +0 -0
- accusleepy/gui/icons/save.png +0 -0
- accusleepy/gui/icons/up_arrow.png +0 -0
- accusleepy/gui/icons/zoom_in.png +0 -0
- accusleepy/gui/icons/zoom_out.png +0 -0
- accusleepy/gui/images/primary_window.png +0 -0
- accusleepy/gui/images/viewer_window.png +0 -0
- accusleepy/gui/images/viewer_window_annotated.png +0 -0
- accusleepy/gui/main.py +1494 -0
- accusleepy/gui/manual_scoring.py +1096 -0
- accusleepy/gui/mplwidget.py +386 -0
- accusleepy/gui/primary_window.py +2577 -0
- accusleepy/gui/primary_window.ui +3831 -0
- accusleepy/gui/resources.qrc +16 -0
- accusleepy/gui/resources_rc.py +6710 -0
- accusleepy/gui/text/config_guide.txt +27 -0
- accusleepy/gui/text/main_guide.md +167 -0
- accusleepy/gui/text/manual_scoring_guide.md +23 -0
- accusleepy/gui/viewer_window.py +610 -0
- accusleepy/gui/viewer_window.ui +926 -0
- accusleepy/models.py +108 -0
- accusleepy/multitaper.py +661 -0
- accusleepy/signal_processing.py +469 -0
- accusleepy/temperature_scaling.py +157 -0
- accusleepy-0.6.0.dist-info/METADATA +106 -0
- accusleepy-0.6.0.dist-info/RECORD +42 -0
- accusleepy-0.6.0.dist-info/WHEEL +4 -0
accusleepy/models.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from torch import device, flatten, nn
|
|
3
|
+
from torch import load as torch_load
|
|
4
|
+
from torch import save as torch_save
|
|
5
|
+
|
|
6
|
+
from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainStateSet
|
|
7
|
+
from accusleepy.constants import (
|
|
8
|
+
DOWNSAMPLING_START_FREQ,
|
|
9
|
+
EMG_COPIES,
|
|
10
|
+
MIN_WINDOW_LEN,
|
|
11
|
+
UPPER_FREQ,
|
|
12
|
+
)
|
|
13
|
+
from accusleepy.temperature_scaling import ModelWithTemperature
|
|
14
|
+
|
|
15
|
+
# Height in pixels of each training image, built from three contributions:
#   * one row per value of np.arange(0, DOWNSAMPLING_START_FREQ, 1 / MIN_WINDOW_LEN)
#   * one row per value of
#     np.arange(DOWNSAMPLING_START_FREQ, UPPER_FREQ, 2 / MIN_WINDOW_LEN)
#     (presumably frequency bins, the upper range at half resolution -- confirm
#     against the code that builds the images)
#   * EMG_COPIES extra rows
# NOTE(review): np.arange with a non-integer step can produce a length that is
# off by one due to floating-point rounding; the image-building code must use
# the exact same calls so that the two heights agree.
IMAGE_HEIGHT = sum(
    (
        np.arange(0, DOWNSAMPLING_START_FREQ, 1 / MIN_WINDOW_LEN).size,
        np.arange(DOWNSAMPLING_START_FREQ, UPPER_FREQ, 2 / MIN_WINDOW_LEN).size,
        EMG_COPIES,
    )
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SSANN(nn.Module):
    """Small CNN for classifying images

    Architecture: three stages of conv(3x3, "same") -> batch norm -> ReLU ->
    2x2 max pool (channels 1 -> 8 -> 16 -> 32), then one linear layer that
    produces one unnormalized score per class.

    Input is a batched tensor of shape (batch, 1, IMAGE_HEIGHT, width). Each
    pooling stage floor-divides the spatial dimensions by 2, so the final
    feature map is (batch, 32, IMAGE_HEIGHT // 8, width // 8). The linear
    layer is sized for width // 8 == 1, i.e. 8 <= width < 16.
    """

    def __init__(self, n_classes: int):
        """Build the network

        :param n_classes: number of output classes
        """
        super().__init__()

        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=8, kernel_size=3, padding="same"
        )
        self.conv2 = nn.Conv2d(
            in_channels=8, out_channels=16, kernel_size=3, padding="same"
        )
        self.conv3 = nn.Conv2d(
            in_channels=16, out_channels=32, kernel_size=3, padding="same"
        )
        self.conv1_bn = nn.BatchNorm2d(8)
        self.conv2_bn = nn.BatchNorm2d(16)
        self.conv3_bn = nn.BatchNorm2d(32)
        # Three 2x2 pools floor-divide the height three times, leaving
        # IMAGE_HEIGHT // 8 rows of 32 channels (width collapses to 1).
        # Integer division keeps this in sync with the real flattened size
        # even when IMAGE_HEIGHT is not a multiple of 8; for multiples of 8 it
        # equals the former 32 * IMAGE_HEIGHT / 8, so saved weights still load.
        self.fc1 = nn.Linear(32 * (IMAGE_HEIGHT // 8), n_classes)

    def forward(self, x):
        """Run a forward pass

        :param x: batched input images, shape (batch, 1, height, width);
            cast to float32 before the first convolution
        :return: unnormalized class scores, shape (batch, n_classes)
        """
        x = x.float()
        x = self.pool(nn.functional.relu(self.conv1_bn(self.conv1(x))))
        x = self.pool(nn.functional.relu(self.conv2_bn(self.conv2(x))))
        x = self.pool(nn.functional.relu(self.conv3_bn(self.conv3(x))))
        x = flatten(x, 1)  # flatten all dimensions except batch
        return self.fc1(x)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def save_model(
    model: SSANN,
    filename: str,
    epoch_length: int | float,
    epochs_per_img: int,
    model_type: str,
    brain_state_set: BrainStateSet,
    is_calibrated: bool,
) -> None:
    """Write the model weights and training metadata to a single checkpoint

    The metadata is stored as extra top-level entries in the model's state
    dict; ``load_model`` pops those entries back out before restoring the
    weights.

    :param model: classification model
    :param filename: destination file
    :param epoch_length: epoch length used when training the model
    :param epochs_per_img: number of epochs in each model input
    :param model_type: default or real-time
    :param brain_state_set: set of brain state options
    :param is_calibrated: whether the model has been calibrated
    """
    checkpoint = model.state_dict()
    checkpoint.update(
        epoch_length=epoch_length,
        epochs_per_img=epochs_per_img,
        model_type=model_type,
        is_calibrated=is_calibrated,
    )
    # only the brain-state entries of the serialized set are stored
    serialized_states = brain_state_set.to_output_dict()
    checkpoint[BRAIN_STATES_KEY] = serialized_states[BRAIN_STATES_KEY]

    torch_save(checkpoint, filename)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def load_model(
    filename: str,
) -> tuple[SSANN | ModelWithTemperature, int | float, int, str, list[dict]]:
    """Load classification model and its metadata

    Reads a checkpoint in the format written by ``save_model``: the model's
    state dict with extra metadata entries stored alongside the weights. The
    metadata entries are removed before the remaining weights are loaded.

    :param filename: filename
    :return: model (wrapped in ModelWithTemperature if the checkpoint is
        marked as calibrated), epoch length used when training the model,
        number of epochs in each model input, model type
        (default or real-time), and the list of brain-state entries (dicts
        with at least an "is_scored" key) used when training the model
    """
    # weights_only=True restricts unpickling to tensors and primitive
    # containers; map_location places every tensor on the CPU.
    state_dict = torch_load(filename, weights_only=True, map_location=device("cpu"))
    epoch_length = state_dict.pop("epoch_length")
    epochs_per_img = state_dict.pop("epochs_per_img")
    model_type = state_dict.pop("model_type")
    # checkpoints without this key are treated as uncalibrated
    is_calibrated = state_dict.pop("is_calibrated", False)
    brain_states = state_dict.pop(BRAIN_STATES_KEY)
    # the network outputs one score per scored brain state
    n_classes = sum(1 for b in brain_states if b["is_scored"])

    model = SSANN(n_classes=n_classes)
    if is_calibrated:
        # calibrated checkpoints presumably come from the temperature-scaling
        # wrapper, so wrap first for the state-dict keys to match
        model = ModelWithTemperature(model)
    model.load_state_dict(state_dict)
    return model, epoch_length, epochs_per_img, model_type, brain_states
|