accusleepy 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. accusleepy/__init__.py +0 -0
  2. accusleepy/__main__.py +4 -0
  3. accusleepy/bouts.py +142 -0
  4. accusleepy/brain_state_set.py +89 -0
  5. accusleepy/classification.py +285 -0
  6. accusleepy/config.json +24 -0
  7. accusleepy/constants.py +46 -0
  8. accusleepy/fileio.py +179 -0
  9. accusleepy/gui/__init__.py +0 -0
  10. accusleepy/gui/icons/brightness_down.png +0 -0
  11. accusleepy/gui/icons/brightness_up.png +0 -0
  12. accusleepy/gui/icons/double_down_arrow.png +0 -0
  13. accusleepy/gui/icons/double_up_arrow.png +0 -0
  14. accusleepy/gui/icons/down_arrow.png +0 -0
  15. accusleepy/gui/icons/home.png +0 -0
  16. accusleepy/gui/icons/question.png +0 -0
  17. accusleepy/gui/icons/save.png +0 -0
  18. accusleepy/gui/icons/up_arrow.png +0 -0
  19. accusleepy/gui/icons/zoom_in.png +0 -0
  20. accusleepy/gui/icons/zoom_out.png +0 -0
  21. accusleepy/gui/images/primary_window.png +0 -0
  22. accusleepy/gui/images/viewer_window.png +0 -0
  23. accusleepy/gui/images/viewer_window_annotated.png +0 -0
  24. accusleepy/gui/main.py +1494 -0
  25. accusleepy/gui/manual_scoring.py +1096 -0
  26. accusleepy/gui/mplwidget.py +386 -0
  27. accusleepy/gui/primary_window.py +2577 -0
  28. accusleepy/gui/primary_window.ui +3831 -0
  29. accusleepy/gui/resources.qrc +16 -0
  30. accusleepy/gui/resources_rc.py +6710 -0
  31. accusleepy/gui/text/config_guide.txt +27 -0
  32. accusleepy/gui/text/main_guide.md +167 -0
  33. accusleepy/gui/text/manual_scoring_guide.md +23 -0
  34. accusleepy/gui/viewer_window.py +610 -0
  35. accusleepy/gui/viewer_window.ui +926 -0
  36. accusleepy/models.py +108 -0
  37. accusleepy/multitaper.py +661 -0
  38. accusleepy/signal_processing.py +469 -0
  39. accusleepy/temperature_scaling.py +157 -0
  40. accusleepy-0.6.0.dist-info/METADATA +106 -0
  41. accusleepy-0.6.0.dist-info/RECORD +42 -0
  42. accusleepy-0.6.0.dist-info/WHEEL +4 -0
accusleepy/models.py ADDED
@@ -0,0 +1,108 @@
+ import numpy as np
+ from torch import device, flatten, nn
+ from torch import load as torch_load
+ from torch import save as torch_save
+
+ from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainStateSet
+ from accusleepy.constants import (
+     DOWNSAMPLING_START_FREQ,
+     EMG_COPIES,
+     MIN_WINDOW_LEN,
+     UPPER_FREQ,
+ )
+ from accusleepy.temperature_scaling import ModelWithTemperature
+
+ # height in pixels of each training image
+ IMAGE_HEIGHT = (
+     len(np.arange(0, DOWNSAMPLING_START_FREQ, 1 / MIN_WINDOW_LEN))
+     + len(np.arange(DOWNSAMPLING_START_FREQ, UPPER_FREQ, 2 / MIN_WINDOW_LEN))
+     + EMG_COPIES
+ )
+
+
+ class SSANN(nn.Module):
+     """Small CNN for classifying images"""
+
+     def __init__(self, n_classes: int):
+         super().__init__()
+
+         self.pool = nn.MaxPool2d(2, 2)
+         self.conv1 = nn.Conv2d(
+             in_channels=1, out_channels=8, kernel_size=3, padding="same"
+         )
+         self.conv2 = nn.Conv2d(
+             in_channels=8, out_channels=16, kernel_size=3, padding="same"
+         )
+         self.conv3 = nn.Conv2d(
+             in_channels=16, out_channels=32, kernel_size=3, padding="same"
+         )
+         self.conv1_bn = nn.BatchNorm2d(8)
+         self.conv2_bn = nn.BatchNorm2d(16)
+         self.conv3_bn = nn.BatchNorm2d(32)
+         self.fc1 = nn.Linear(int(32 * IMAGE_HEIGHT / 8), n_classes)
+
+     def forward(self, x):
+         x = x.float()
+         x = self.pool(nn.functional.relu(self.conv1_bn(self.conv1(x))))
+         x = self.pool(nn.functional.relu(self.conv2_bn(self.conv2(x))))
+         x = self.pool(nn.functional.relu(self.conv3_bn(self.conv3(x))))
+         x = flatten(x, 1)  # flatten all dimensions except batch
+         return self.fc1(x)
+
+
+ def save_model(
+     model: SSANN,
+     filename: str,
+     epoch_length: int | float,
+     epochs_per_img: int,
+     model_type: str,
+     brain_state_set: BrainStateSet,
+     is_calibrated: bool,
+ ) -> None:
+     """Save classification model and its metadata
+
+     :param model: classification model
+     :param filename: filename
+     :param epoch_length: epoch length used when training the model
+     :param epochs_per_img: number of epochs in each model input
+     :param model_type: default or real-time
+     :param brain_state_set: set of brain state options
+     :param is_calibrated: whether the model has been calibrated
+     """
+     state_dict = model.state_dict()
+     state_dict.update({"epoch_length": epoch_length})
+     state_dict.update({"epochs_per_img": epochs_per_img})
+     state_dict.update({"model_type": model_type})
+     state_dict.update({"is_calibrated": is_calibrated})
+     state_dict.update(
+         {BRAIN_STATES_KEY: brain_state_set.to_output_dict()[BRAIN_STATES_KEY]}
+     )
+
+     torch_save(state_dict, filename)
+
+
+ def load_model(filename: str) -> tuple[SSANN, int | float, int, str, dict]:
+     """Load classification model and its metadata
+
+     :param filename: filename
+     :return: model, epoch length used when training the model,
+         number of epochs in each model input, model type
+         (default or real-time), set of brain state options
+         used when training the model
+     """
+     state_dict = torch_load(filename, weights_only=True, map_location=device("cpu"))
+     epoch_length = state_dict.pop("epoch_length")
+     epochs_per_img = state_dict.pop("epochs_per_img")
+     model_type = state_dict.pop("model_type")
+     if "is_calibrated" in state_dict:
+         is_calibrated = state_dict.pop("is_calibrated")
+     else:
+         is_calibrated = False
+     brain_states = state_dict.pop(BRAIN_STATES_KEY)
+     n_classes = len([b for b in brain_states if b["is_scored"]])
+
+     model = SSANN(n_classes=n_classes)
+     if is_calibrated:
+         model = ModelWithTemperature(model)
+     model.load_state_dict(state_dict)
+     return model, epoch_length, epochs_per_img, model_type, brain_states
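Usage note (not part of the package diff): a minimal sketch of the save/load round trip defined in this file. The filename, class count, epoch length, and epochs-per-image values are illustrative only, and `brain_state_set` is assumed to be a BrainStateSet instance built elsewhere in accusleepy (its constructor lives in accusleepy/brain_state_set.py and is not shown here); only the "default"/"real-time" model_type options come from the docstring above.

from accusleepy.models import SSANN, load_model, save_model

# hypothetical configuration for illustration
model = SSANN(n_classes=3)
save_model(
    model=model,
    filename="example_model.pth",
    epoch_length=2.5,
    epochs_per_img=9,
    model_type="default",
    brain_state_set=brain_state_set,  # assumed existing BrainStateSet instance
    is_calibrated=False,
)

# load_model rebuilds the SSANN (wrapped in ModelWithTemperature if calibrated)
# and returns the stored metadata alongside it
model, epoch_length, epochs_per_img, model_type, brain_states = load_model(
    "example_model.pth"
)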
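Because save_model stores its metadata as extra keys in the same state dict as the weights, the saved file can also be inspected with plain torch.load, as load_model itself does. The filename below is the hypothetical one from the sketch above.

import torch

checkpoint = torch.load("example_model.pth", weights_only=True, map_location="cpu")
# metadata keys written by save_model, stored next to the model weights
print(checkpoint["epoch_length"], checkpoint["epochs_per_img"], checkpoint["model_type"])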