accusleepy 0.4.0__tar.gz → 0.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. {accusleepy-0.4.0 → accusleepy-0.6.0}/PKG-INFO +49 -9
  2. accusleepy-0.6.0/README.md +77 -0
  3. accusleepy-0.6.0/accusleepy/bouts.py +142 -0
  4. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/classification.py +27 -9
  5. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/config.json +2 -1
  6. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/constants.py +7 -0
  7. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/fileio.py +35 -65
  8. accusleepy-0.6.0/accusleepy/gui/images/primary_window.png +0 -0
  9. accusleepy-0.6.0/accusleepy/gui/images/viewer_window.png +0 -0
  10. accusleepy-0.6.0/accusleepy/gui/images/viewer_window_annotated.png +0 -0
  11. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/main.py +199 -101
  12. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/manual_scoring.py +112 -102
  13. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/mplwidget.py +69 -39
  14. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/primary_window.py +240 -158
  15. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/primary_window.ui +313 -155
  16. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/resources.qrc +1 -1
  17. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/text/config_guide.txt +0 -2
  18. accusleepy-0.6.0/accusleepy/gui/text/main_guide.md +167 -0
  19. accusleepy-0.6.0/accusleepy/gui/text/manual_scoring_guide.md +23 -0
  20. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/viewer_window.py +19 -7
  21. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/viewer_window.ui +34 -2
  22. accusleepy-0.6.0/accusleepy/models.py +108 -0
  23. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/multitaper.py +9 -7
  24. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/signal_processing.py +28 -148
  25. accusleepy-0.6.0/accusleepy/temperature_scaling.py +157 -0
  26. {accusleepy-0.4.0 → accusleepy-0.6.0}/pyproject.toml +3 -3
  27. accusleepy-0.4.0/README.md +0 -37
  28. accusleepy-0.4.0/accusleepy/gui/text/main_guide_text.py +0 -173
  29. accusleepy-0.4.0/accusleepy/gui/text/manual_scoring_guide.txt +0 -28
  30. accusleepy-0.4.0/accusleepy/models.py +0 -48
  31. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/__init__.py +0 -0
  32. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/__main__.py +0 -0
  33. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/brain_state_set.py +0 -0
  34. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/__init__.py +0 -0
  35. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/brightness_down.png +0 -0
  36. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/brightness_up.png +0 -0
  37. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/double_down_arrow.png +0 -0
  38. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/double_up_arrow.png +0 -0
  39. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/down_arrow.png +0 -0
  40. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/home.png +0 -0
  41. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/question.png +0 -0
  42. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/save.png +0 -0
  43. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/up_arrow.png +0 -0
  44. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/zoom_in.png +0 -0
  45. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/icons/zoom_out.png +0 -0
  46. {accusleepy-0.4.0 → accusleepy-0.6.0}/accusleepy/gui/resources_rc.py +0 -0
@@ -1,16 +1,16 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: accusleepy
3
- Version: 0.4.0
3
+ Version: 0.6.0
4
4
  Summary: Python implementation of AccuSleep
5
5
  License: GPL-3.0-only
6
6
  Author: Zeke Barger
7
7
  Author-email: zekebarger@gmail.com
8
- Requires-Python: >=3.10,<3.13
8
+ Requires-Python: >=3.11,<3.14
9
9
  Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
10
10
  Classifier: Programming Language :: Python :: 3
11
- Classifier: Programming Language :: Python :: 3.10
12
11
  Classifier: Programming Language :: Python :: 3.11
13
12
  Classifier: Programming Language :: Python :: 3.12
13
+ Classifier: Programming Language :: Python :: 3.13
14
14
  Requires-Dist: fastparquet (>=2024.11.0,<2025.0.0)
15
15
  Requires-Dist: joblib (>=1.4.2,<2.0.0)
16
16
  Requires-Dist: matplotlib (>=3.10.1,<4.0.0)
@@ -18,7 +18,7 @@ Requires-Dist: numpy (>=2.2.4,<3.0.0)
18
18
  Requires-Dist: pandas (>=2.2.3,<3.0.0)
19
19
  Requires-Dist: pillow (>=11.1.0,<12.0.0)
20
20
  Requires-Dist: pre-commit (>=4.2.0,<5.0.0)
21
- Requires-Dist: pyside6 (>=6.7.1,<6.8.0)
21
+ Requires-Dist: pyside6 (>=6.9.0,<7.0.0)
22
22
  Requires-Dist: scipy (>=1.15.2,<2.0.0)
23
23
  Requires-Dist: toml (>=0.10.2,<0.11.0)
24
24
  Requires-Dist: torch (>=2.6.0,<3.0.0)
@@ -30,9 +30,17 @@ Description-Content-Type: text/markdown
30
30
 
31
31
  ## Description
32
32
 
33
- AccuSleePy is a python implementation of AccuSleep--a set of graphical user interfaces for scoring rodent
34
- sleep using EEG and EMG recordings. It offers several improvements over the original MATLAB version
35
- and is the only version that will be actively maintained.
33
+ AccuSleePy is a set of graphical user interfaces for scoring rodent sleep
34
+ using EEG and EMG recordings.
35
+ It offers the following improvements over the MATLAB version (AccuSleep):
36
+
37
+ - Up to 10 brain states can be configured through the user interface
38
+ - Classification models can be trained through the user interface
39
+ - Model files contain useful metadata (brain state configuration,
40
+ epoch length, number of epochs)
41
+ - Models optimized for real-time scoring can be trained
42
+ - Lists of recordings can be imported and exported for repeatable batch processing
43
+ - Undo/redo functionality in the manual scoring interface
36
44
 
37
45
  If you use AccuSleep in your research, please cite our
38
46
  [publication](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0224642):
@@ -43,24 +51,56 @@ The data and models associated with AccuSleep are available at https://osf.io/py
43
51
 
44
52
  Please contact zekebarger (at) gmail (dot) com with any questions or comments about the software.
45
53
 
54
+
46
55
  ## Installation
47
56
 
48
57
  - (recommended) create a new virtual environment (using
49
58
  [venv](https://docs.python.org/3/library/venv.html),
50
59
  [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html),
51
- etc.) using python >=3.10,<3.13
60
+ etc.) with python >=3.11,<3.14
52
61
  - (optional) if you have a CUDA device and want to speed up model training, [install PyTorch](https://pytorch.org/)
53
62
  - `pip install accusleepy`
54
63
  - (optional) download a classification model from https://osf.io/py5eb/ under /python_format/models/
55
64
 
65
+ Note that upgrading or reinstalling the package will overwrite any changes
66
+ to the [config file](accusleepy/config.json).
67
+
56
68
  ## Usage
57
69
 
58
70
  `python -m accusleepy` will open the primary interface.
59
71
 
72
+ [Guide to the primary interface](accusleepy/gui/text/main_guide.md)
73
+
74
+ [Guide to the manual scoring interface](accusleepy/gui/text/manual_scoring_guide.md)
75
+
76
+ ## Changelog
77
+
78
+ - 0.6.0: Confidence scores can now be displayed and saved. Retraining your models is recommended
79
+ since the new calibration feature will make the confidence scores more accurate.
80
+ - 0.5.0: Performance improvements
81
+ - 0.4.5: Added support for python 3.13, **removed support for python 3.10.**
82
+ - 0.4.4: Performance improvements
83
+ - 0.4.3: Improved unit tests and user manuals
84
+ - 0.4.0: Improved visuals and user manuals
85
+ - 0.1.0-0.3.1: Early development versions
86
+
87
+ ## Screenshots
88
+
89
+ Primary interface
90
+ ![AccuSleePy primary interface](accusleepy/gui/images/primary_window.png)
91
+
92
+ Manual scoring interface
93
+ ![AccuSleePy manual scoring interface](accusleepy/gui/images/viewer_window.png)
94
+
60
95
  ## Acknowledgements
61
96
 
62
97
  We would like to thank [Franz Weber](https://www.med.upenn.edu/weberlab/) for creating an
63
- early version of the manual labeling interface.
98
+ early version of the manual labeling interface. The code that
99
+ creates spectrograms comes from the
100
+ [Prerau lab](https://github.com/preraulab/multitaper_toolbox/blob/master/python/multitaper_spectrogram_python.py)
101
+ with only minor modifications.
64
102
  Jim Bohnslav's [deepethogram](https://github.com/jbohnslav/deepethogram) served as an
65
103
  incredibly useful reference when reimplementing this project in python.
104
+ The model calibration code added in version 0.6.0 comes from Geoff Pleiss'
105
+ [temperature scaling repo](https://github.com/gpleiss/temperature_scaling).
66
106
 
@@ -0,0 +1,77 @@
1
+ # AccuSleePy
2
+
3
+ ## Description
4
+
5
+ AccuSleePy is a set of graphical user interfaces for scoring rodent sleep
6
+ using EEG and EMG recordings.
7
+ It offers the following improvements over the MATLAB version (AccuSleep):
8
+
9
+ - Up to 10 brain states can be configured through the user interface
10
+ - Classification models can be trained through the user interface
11
+ - Model files contain useful metadata (brain state configuration,
12
+ epoch length, number of epochs)
13
+ - Models optimized for real-time scoring can be trained
14
+ - Lists of recordings can be imported and exported for repeatable batch processing
15
+ - Undo/redo functionality in the manual scoring interface
16
+
17
+ If you use AccuSleep in your research, please cite our
18
+ [publication](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0224642):
19
+
20
+ Barger, Z., Frye, C. G., Liu, D., Dan, Y., & Bouchard, K. E. (2019). Robust, automated sleep scoring by a compact neural network with distributional shift correction. *PLOS ONE, 14*(12), 1–18.
21
+
22
+ The data and models associated with AccuSleep are available at https://osf.io/py5eb/
23
+
24
+ Please contact zekebarger (at) gmail (dot) com with any questions or comments about the software.
25
+
26
+
27
+ ## Installation
28
+
29
+ - (recommended) create a new virtual environment (using
30
+ [venv](https://docs.python.org/3/library/venv.html),
31
+ [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html),
32
+ etc.) with python >=3.11,<3.14
33
+ - (optional) if you have a CUDA device and want to speed up model training, [install PyTorch](https://pytorch.org/)
34
+ - `pip install accusleepy`
35
+ - (optional) download a classification model from https://osf.io/py5eb/ under /python_format/models/
36
+
37
+ Note that upgrading or reinstalling the package will overwrite any changes
38
+ to the [config file](accusleepy/config.json).
39
+
40
+ ## Usage
41
+
42
+ `python -m accusleepy` will open the primary interface.
43
+
44
+ [Guide to the primary interface](accusleepy/gui/text/main_guide.md)
45
+
46
+ [Guide to the manual scoring interface](accusleepy/gui/text/manual_scoring_guide.md)
47
+
48
+ ## Changelog
49
+
50
+ - 0.6.0: Confidence scores can now be displayed and saved. Retraining your models is recommended
51
+ since the new calibration feature will make the confidence scores more accurate.
52
+ - 0.5.0: Performance improvements
53
+ - 0.4.5: Added support for python 3.13, **removed support for python 3.10.**
54
+ - 0.4.4: Performance improvements
55
+ - 0.4.3: Improved unit tests and user manuals
56
+ - 0.4.0: Improved visuals and user manuals
57
+ - 0.1.0-0.3.1: Early development versions
58
+
59
+ ## Screenshots
60
+
61
+ Primary interface
62
+ ![AccuSleePy primary interface](accusleepy/gui/images/primary_window.png)
63
+
64
+ Manual scoring interface
65
+ ![AccuSleePy manual scoring interface](accusleepy/gui/images/viewer_window.png)
66
+
67
+ ## Acknowledgements
68
+
69
+ We would like to thank [Franz Weber](https://www.med.upenn.edu/weberlab/) for creating an
70
+ early version of the manual labeling interface. The code that
71
+ creates spectrograms comes from the
72
+ [Prerau lab](https://github.com/preraulab/multitaper_toolbox/blob/master/python/multitaper_spectrogram_python.py)
73
+ with only minor modifications.
74
+ Jim Bohnslav's [deepethogram](https://github.com/jbohnslav/deepethogram) served as an
75
+ incredibly useful reference when reimplementing this project in python.
76
+ The model calibration code added in version 0.6.0 comes from Geoff Pleiss'
77
+ [temperature scaling repo](https://github.com/gpleiss/temperature_scaling).
@@ -0,0 +1,142 @@
1
+ import re
2
+ from dataclasses import dataclass
3
+ from operator import attrgetter
4
+
5
+ import numpy as np
6
+
7
+
8
+ @dataclass
9
+ class Bout:
10
+ """Stores information about a brain state bout"""
11
+
12
+ length: int # length, in number of epochs
13
+ start_index: int # index where bout starts
14
+ end_index: int # index where bout ends
15
+ surrounding_state: int # brain state on both sides of the bout
16
+
17
+
18
+ def find_last_adjacent_bout(sorted_bouts: list[Bout], bout_index: int) -> int:
19
+ """Find index of last consecutive same-length bout
20
+
21
+ When running the post-processing step that enforces a minimum duration
22
+ for brain state bouts, there is a special case when bouts below the
23
+ duration threshold occur consecutively. This function performs a
24
+ recursive search for the index of a bout at the end of such a sequence.
25
+ When initially called, bout_index will always be 0. If, for example, the
26
+ first three bouts in the list are consecutive, the function will return 2.
27
+
28
+ :param sorted_bouts: list of brain state bouts, sorted by start time
29
+ :param bout_index: index of the bout in question
30
+ :return: index of the last consecutive same-length bout
31
+ """
32
+ # if we're at the end of the bout list, stop
33
+ if bout_index == len(sorted_bouts) - 1:
34
+ return bout_index
35
+
36
+ # if there is an adjacent bout
37
+ if sorted_bouts[bout_index].end_index == sorted_bouts[bout_index + 1].start_index:
38
+ # look for more adjacent bouts using that one as a starting point
39
+ return find_last_adjacent_bout(sorted_bouts, bout_index + 1)
40
+ else:
41
+ return bout_index
42
+
43
+
44
+ def enforce_min_bout_length(
45
+ labels: np.array, epoch_length: int | float, min_bout_length: int | float
46
+ ) -> np.array:
47
+ """Ensure brain state bouts meet the min length requirement
48
+
49
+ As a post-processing step for sleep scoring, we can require that any
50
+ bout (continuous period) of a brain state have a minimum duration.
51
+ This function sets any bout shorter than the minimum duration to the
52
+ surrounding brain state (if the states on the left and right sides
53
+ are the same). In the case where there are consecutive short bouts,
54
+ it either creates a transition at the midpoint or removes all short
55
+ bouts, depending on whether the number is even or odd. For example:
56
+ ...AAABABAAA... -> ...AAAAAAAAA...
57
+ ...AAABABABBB... -> ...AAAAABBBBB...
58
+
59
+ :param labels: brain state labels (digits in the 0-9 range)
60
+ :param epoch_length: epoch length, in seconds
61
+ :param min_bout_length: minimum bout length, in seconds
62
+ :return: updated brain state labels
63
+ """
64
+ # if recording is very short, don't change anything
65
+ if labels.size < 3:
66
+ return labels
67
+
68
+ if epoch_length == min_bout_length:
69
+ return labels
70
+
71
+ # get minimum number of epochs in a bout
72
+ min_epochs = int(np.ceil(min_bout_length / epoch_length))
73
+ # get set of states in the labels
74
+ brain_states = set(labels.tolist())
75
+
76
+ while True: # so true
77
+ # convert labels to a string for regex search
78
+ # There is probably a regex that can find all patterns like ab+a
79
+ # without consuming each "a" but I haven't found it :(
80
+ label_string = "".join(labels.astype(str))
81
+
82
+ bouts = list()
83
+
84
+ for state in brain_states:
85
+ for other_state in brain_states:
86
+ if state == other_state:
87
+ continue
88
+ # get start and end indices of each bout
89
+ expression = (
90
+ f"(?<={other_state}){state}{{1,{min_epochs - 1}}}(?={other_state})"
91
+ )
92
+ matches = re.finditer(expression, label_string)
93
+ spans = [match.span() for match in matches]
94
+
95
+ # if some bouts were found
96
+ for span in spans:
97
+ bouts.append(
98
+ Bout(
99
+ length=span[1] - span[0],
100
+ start_index=span[0],
101
+ end_index=span[1],
102
+ surrounding_state=other_state,
103
+ )
104
+ )
105
+
106
+ if len(bouts) == 0:
107
+ break
108
+
109
+ # only keep the shortest bouts
110
+ min_length_in_list = np.min([bout.length for bout in bouts])
111
+ bouts = [i for i in bouts if i.length == min_length_in_list]
112
+ # sort by start index
113
+ sorted_bouts = sorted(bouts, key=attrgetter("start_index"))
114
+
115
+ while len(sorted_bouts) > 0:
116
+ # get row index of latest adjacent bout (of same length)
117
+ last_adjacent_bout_index = find_last_adjacent_bout(sorted_bouts, 0)
118
+ # if there's an even number of adjacent bouts
119
+ if (last_adjacent_bout_index + 1) % 2 == 0:
120
+ midpoint = sorted_bouts[
121
+ round((last_adjacent_bout_index + 1) / 2)
122
+ ].start_index
123
+ labels[sorted_bouts[0].start_index : midpoint] = sorted_bouts[
124
+ 0
125
+ ].surrounding_state
126
+ labels[midpoint : sorted_bouts[last_adjacent_bout_index].end_index] = (
127
+ sorted_bouts[last_adjacent_bout_index].surrounding_state
128
+ )
129
+ else:
130
+ labels[
131
+ sorted_bouts[0].start_index : sorted_bouts[
132
+ last_adjacent_bout_index
133
+ ].end_index
134
+ ] = sorted_bouts[0].surrounding_state
135
+
136
+ # delete the bouts we just fixed
137
+ if last_adjacent_bout_index == len(sorted_bouts) - 1:
138
+ sorted_bouts = []
139
+ else:
140
+ sorted_bouts = sorted_bouts[(last_adjacent_bout_index + 1) :]
141
+
142
+ return labels
@@ -61,13 +61,31 @@ def get_device():
61
61
  )
62
62
 
63
63
 
64
- def train_model(
64
def create_dataloader(
    annotations_file: str, img_dir: str, shuffle: bool = True
) -> DataLoader:
    """Build a DataLoader over a dataset of training or calibration images

    :param annotations_file: file with information on each image
    :param img_dir: directory containing the images
    :param shuffle: passed to the DataLoader; reshuffle data for every epoch
    :return: DataLoader that yields batches of BATCH_SIZE images
    """
    return DataLoader(
        AccuSleepImageDataset(
            annotations_file=annotations_file,
            img_dir=img_dir,
        ),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
    )
80
+
81
+
82
+ def train_ssann(
65
83
  annotations_file: str,
66
84
  img_dir: str,
67
85
  mixture_weights: np.array,
68
86
  n_classes: int,
69
87
  ) -> SSANN:
70
- """Train a classification model for sleep scoring
88
+ """Train a SSANN classification model for sleep scoring
71
89
 
72
90
  :param annotations_file: file with information on each training image
73
91
  :param img_dir: training image location
@@ -75,11 +93,9 @@ def train_model(
75
93
  :param n_classes: number of classes the model will learn
76
94
  :return: trained Sleep Scoring Artificial Neural Network model
77
95
  """
78
- training_data = AccuSleepImageDataset(
79
- annotations_file=annotations_file,
80
- img_dir=img_dir,
96
+ train_dataloader = create_dataloader(
97
+ annotations_file=annotations_file, img_dir=img_dir
81
98
  )
82
- train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
83
99
 
84
100
  device = get_device()
85
101
  model = SSANN(n_classes=n_classes)
@@ -130,7 +146,7 @@ def score_recording(
130
146
  :param epoch_length: epoch length, in seconds
131
147
  :param epochs_per_img: number of epochs for the model to consider
132
148
  :param brain_state_set: set of brain state options
133
- :return: brain state labels
149
+ :return: brain state labels, confidence scores
134
150
  """
135
151
  # prepare model
136
152
  device = get_device()
@@ -158,10 +174,12 @@ def score_recording(
158
174
  # perform classification
159
175
  with torch.no_grad():
160
176
  outputs = model(images)
161
- _, predicted = torch.max(outputs, 1)
177
+ logits, predicted = torch.max(outputs, 1)
162
178
 
163
179
  labels = brain_state_set.convert_class_to_digit(predicted.cpu().numpy())
164
- return labels
180
+ confidence_scores = 1 / (1 + np.e ** (-logits.cpu().numpy()))
181
+
182
+ return labels, confidence_scores
165
183
 
166
184
 
167
185
  def example_real_time_scoring_function(
@@ -19,5 +19,6 @@
19
19
  "frequency": 0.55
20
20
  }
21
21
  ],
22
- "default_epoch_length": 2.5
22
+ "default_epoch_length": 2.5,
23
+ "save_confidence_setting": true
23
24
  }
@@ -8,6 +8,7 @@ EEG_COL = "eeg"
8
8
  EMG_COL = "emg"
9
9
  # label file columns
10
10
  BRAIN_STATE_COL = "brain_state"
11
+ CONFIDENCE_SCORE_COL = "confidence_score"
11
12
 
12
13
 
13
14
  # really don't change these
@@ -37,3 +38,9 @@ RECORDING_LIST_NAME = "recording_list"
37
38
  RECORDING_LIST_FILE_TYPE = ".json"
38
39
  # key for default epoch length in config
39
40
  DEFAULT_EPOCH_LENGTH_KEY = "default_epoch_length"
41
+ # key used for default confidence score behavior in config
42
+ DEFAULT_CONFIDENCE_SETTING_KEY = "save_confidence_setting"
43
+ # filename used to store info about training image datasets
44
+ ANNOTATIONS_FILENAME = "annotations.csv"
45
+ # filename for annotation file for the calibration set
46
+ CALIBRATION_ANNOTATION_FILENAME = "calibration_set.csv"
@@ -4,13 +4,14 @@ from dataclasses import dataclass
4
4
 
5
5
  import numpy as np
6
6
  import pandas as pd
7
- import torch
8
7
  from PySide6.QtWidgets import QListWidgetItem
9
8
 
10
9
  from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainState, BrainStateSet
11
10
  from accusleepy.constants import (
12
11
  BRAIN_STATE_COL,
12
+ CONFIDENCE_SCORE_COL,
13
13
  CONFIG_FILE,
14
+ DEFAULT_CONFIDENCE_SETTING_KEY,
14
15
  DEFAULT_EPOCH_LENGTH_KEY,
15
16
  EEG_COL,
16
17
  EMG_COL,
@@ -19,7 +20,6 @@ from accusleepy.constants import (
19
20
  RECORDING_LIST_NAME,
20
21
  UNDEFINED_LABEL,
21
22
  )
22
- from accusleepy.models import SSANN
23
23
 
24
24
 
25
25
  @dataclass
@@ -46,57 +46,6 @@ def load_calibration_file(filename: str) -> (np.array, np.array):
46
46
  return mixture_means, mixture_sds
47
47
 
48
48
 
49
- def save_model(
50
- model: SSANN,
51
- filename: str,
52
- epoch_length: int | float,
53
- epochs_per_img: int,
54
- model_type: str,
55
- brain_state_set: BrainStateSet,
56
- ) -> None:
57
- """Save classification model and its metadata
58
-
59
- :param model: classification model
60
- :param epoch_length: epoch length used when training the model
61
- :param epochs_per_img: number of epochs in each model input
62
- :param model_type: default or real-time
63
- :param brain_state_set: set of brain state options
64
- :param filename: filename
65
- """
66
- state_dict = model.state_dict()
67
- state_dict.update({"epoch_length": epoch_length})
68
- state_dict.update({"epochs_per_img": epochs_per_img})
69
- state_dict.update({"model_type": model_type})
70
- state_dict.update(
71
- {BRAIN_STATES_KEY: brain_state_set.to_output_dict()[BRAIN_STATES_KEY]}
72
- )
73
-
74
- torch.save(state_dict, filename)
75
-
76
-
77
- def load_model(filename: str) -> tuple[SSANN, int | float, int, str, dict]:
78
- """Load classification model and its metadata
79
-
80
- :param filename: filename
81
- :return: model, epoch length used when training the model,
82
- number of epochs in each model input, model type
83
- (default or real-time), set of brain state options
84
- used when training the model
85
- """
86
- state_dict = torch.load(
87
- filename, weights_only=True, map_location=torch.device("cpu")
88
- )
89
- epoch_length = state_dict.pop("epoch_length")
90
- epochs_per_img = state_dict.pop("epochs_per_img")
91
- model_type = state_dict.pop("model_type")
92
- brain_states = state_dict.pop(BRAIN_STATES_KEY)
93
- n_classes = len([b for b in brain_states if b["is_scored"]])
94
-
95
- model = SSANN(n_classes=n_classes)
96
- model.load_state_dict(state_dict)
97
- return model, epoch_length, epochs_per_img, model_type, brain_states
98
-
99
-
100
49
  def load_csv_or_parquet(filename: str) -> pd.DataFrame:
101
50
  """Load a csv or parquet file as a dataframe
102
51
 
@@ -125,49 +74,70 @@ def load_recording(filename: str) -> (np.array, np.array):
125
74
  return eeg, emg
126
75
 
127
76
 
128
- def load_labels(filename: str) -> np.array:
129
- """Load file of brain state labels
77
+ def load_labels(filename: str) -> (np.array, np.array):
78
+ """Load file of brain state labels and confidence scores
130
79
 
131
80
  :param filename: filename
132
- :return: array of brain state labels
81
+ :return: array of brain state labels and, optionally, array of confidence scores
133
82
  """
134
83
  df = load_csv_or_parquet(filename)
135
- return df[BRAIN_STATE_COL].values
84
+ if CONFIDENCE_SCORE_COL in df.columns:
85
+ return df[BRAIN_STATE_COL].values, df[CONFIDENCE_SCORE_COL].values
86
+ else:
87
+ return df[BRAIN_STATE_COL].values, None
136
88
 
137
89
 
138
- def save_labels(labels: np.array, filename: str) -> None:
90
+ def save_labels(
91
+ labels: np.array, filename: str, confidence_scores: np.array = None
92
+ ) -> None:
139
93
  """Save brain state labels to file
140
94
 
141
95
  :param labels: brain state labels
142
96
  :param filename: filename
97
+ :param confidence_scores: optional confidence scores
143
98
  """
144
- pd.DataFrame({BRAIN_STATE_COL: labels}).to_csv(filename, index=False)
99
+ if confidence_scores is not None:
100
+ pd.DataFrame(
101
+ {BRAIN_STATE_COL: labels, CONFIDENCE_SCORE_COL: confidence_scores}
102
+ ).to_csv(filename, index=False)
103
+ else:
104
+ pd.DataFrame({BRAIN_STATE_COL: labels}).to_csv(filename, index=False)
145
105
 
146
106
 
147
- def load_config() -> tuple[BrainStateSet, int | float]:
107
+ def load_config() -> tuple[BrainStateSet, int | float, bool]:
148
108
  """Load configuration file with brain state options
149
109
 
150
- :return: set of brain state options and default epoch length
110
+ :return: set of brain state options, other settings
151
111
  """
152
112
  with open(
153
113
  os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "r"
154
114
  ) as f:
155
115
  data = json.load(f)
156
- return BrainStateSet(
157
- [BrainState(**b) for b in data[BRAIN_STATES_KEY]], UNDEFINED_LABEL
158
- ), data[DEFAULT_EPOCH_LENGTH_KEY]
116
+
117
+ return (
118
+ BrainStateSet(
119
+ [BrainState(**b) for b in data[BRAIN_STATES_KEY]], UNDEFINED_LABEL
120
+ ),
121
+ data[DEFAULT_EPOCH_LENGTH_KEY],
122
+ data.get(DEFAULT_CONFIDENCE_SETTING_KEY, True),
123
+ )
159
124
 
160
125
 
161
126
  def save_config(
162
- brain_state_set: BrainStateSet, default_epoch_length: int | float
127
+ brain_state_set: BrainStateSet,
128
+ default_epoch_length: int | float,
129
+ save_confidence_setting: bool,
163
130
  ) -> None:
164
131
  """Save configuration of brain state options to json file
165
132
 
166
133
  :param brain_state_set: set of brain state options
167
134
  :param default_epoch_length: epoch length to use when the GUI starts
135
+ :param save_confidence_setting: whether the option to save confidence
136
+ scores should be True by default
168
137
  """
169
138
  output_dict = brain_state_set.to_output_dict()
170
139
  output_dict.update({DEFAULT_EPOCH_LENGTH_KEY: default_epoch_length})
140
+ output_dict.update({DEFAULT_CONFIDENCE_SETTING_KEY: save_confidence_setting})
171
141
  with open(
172
142
  os.path.join(os.path.dirname(os.path.abspath(__file__)), CONFIG_FILE), "w"
173
143
  ) as f: