accusleepy 0.8.1__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accusleepy/bouts.py +3 -3
- accusleepy/brain_state_set.py +6 -4
- accusleepy/classification.py +14 -50
- accusleepy/constants.py +3 -0
- accusleepy/fileio.py +24 -5
- accusleepy/gui/dialogs.py +40 -0
- accusleepy/gui/images/primary_window.png +0 -0
- accusleepy/gui/main.py +212 -1025
- accusleepy/gui/manual_scoring.py +1 -1
- accusleepy/gui/primary_window.py +7 -9
- accusleepy/gui/primary_window.ui +6 -8
- accusleepy/gui/recording_manager.py +110 -0
- accusleepy/gui/settings_widget.py +409 -0
- accusleepy/gui/text/main_guide.md +1 -1
- accusleepy/models.py +1 -1
- accusleepy/services.py +581 -0
- accusleepy/signal_processing.py +110 -38
- accusleepy/temperature_scaling.py +14 -8
- accusleepy/validation.py +67 -2
- {accusleepy-0.8.1.dist-info → accusleepy-0.9.2.dist-info}/METADATA +2 -2
- {accusleepy-0.8.1.dist-info → accusleepy-0.9.2.dist-info}/RECORD +22 -18
- {accusleepy-0.8.1.dist-info → accusleepy-0.9.2.dist-info}/WHEEL +1 -1
accusleepy/services.py
ADDED
|
@@ -0,0 +1,581 @@
|
|
|
1
|
+
"""Service classes for orchestrating AccuSleePy operations.
|
|
2
|
+
|
|
3
|
+
Isolating certain functionality here, without any interaction
|
|
4
|
+
with UI state, makes it more testable.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import datetime
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
import shutil
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from typing import Callable
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pandas as pd
|
|
16
|
+
|
|
17
|
+
from accusleepy.bouts import enforce_min_bout_length
|
|
18
|
+
from accusleepy.brain_state_set import BrainStateSet
|
|
19
|
+
from accusleepy.constants import (
|
|
20
|
+
ANNOTATIONS_FILENAME,
|
|
21
|
+
CALIBRATION_ANNOTATION_FILENAME,
|
|
22
|
+
DEFAULT_MODEL_TYPE,
|
|
23
|
+
MIN_EPOCHS_PER_STATE,
|
|
24
|
+
UNDEFINED_LABEL,
|
|
25
|
+
MIXTURE_MEAN_COL,
|
|
26
|
+
MIXTURE_SD_COL,
|
|
27
|
+
)
|
|
28
|
+
from accusleepy.fileio import (
|
|
29
|
+
EMGFilter,
|
|
30
|
+
Hyperparameters,
|
|
31
|
+
Recording,
|
|
32
|
+
load_calibration_file,
|
|
33
|
+
load_labels,
|
|
34
|
+
load_recording,
|
|
35
|
+
save_labels,
|
|
36
|
+
)
|
|
37
|
+
from accusleepy.models import SSANN
|
|
38
|
+
from accusleepy.signal_processing import (
|
|
39
|
+
create_training_images,
|
|
40
|
+
resample_and_standardize,
|
|
41
|
+
create_eeg_emg_image,
|
|
42
|
+
get_mixture_values,
|
|
43
|
+
)
|
|
44
|
+
from accusleepy.validation import check_label_validity
|
|
45
|
+
|
|
46
|
+
# Module-level logger named after this module; used below to log progress
# and, via logger.exception, the tracebacks of files that fail to load.
logger = logging.getLogger(__name__)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
|
|
50
|
+
class ServiceResult:
|
|
51
|
+
"""Result of a service operation."""
|
|
52
|
+
|
|
53
|
+
success: bool
|
|
54
|
+
messages: list[str] = field(default_factory=list)
|
|
55
|
+
warnings: list[str] = field(default_factory=list)
|
|
56
|
+
error: str | None = None
|
|
57
|
+
|
|
58
|
+
def report_to(self, callback: Callable[[str], None]) -> None:
|
|
59
|
+
"""Report all warnings, messages, and errors through a callback.
|
|
60
|
+
|
|
61
|
+
:param callback: function to call with each message string
|
|
62
|
+
"""
|
|
63
|
+
for warning in self.warnings:
|
|
64
|
+
callback(f"WARNING: {warning}")
|
|
65
|
+
for message in self.messages:
|
|
66
|
+
callback(message)
|
|
67
|
+
if not self.success and self.error:
|
|
68
|
+
callback(f"ERROR: {self.error}")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def check_single_file_inputs(
    recording: Recording,
    epoch_length: int | float,
) -> str | None:
    """Run basic sanity checks on a recording before it is loaded or scored.

    Checks are evaluated in a fixed order and the first failing one decides
    the result: epoch length must be nonzero, sampling rate must be
    nonzero, epoch length must not exceed the sampling rate, a recording
    file must be selected and exist on disk, and a label file must be
    selected.

    :param recording: the recording to validate
    :param epoch_length: epoch length in seconds
    :return: description of the first failed check, or None if all pass
    """
    # each predicate is a lambda so later checks are only evaluated
    # once every earlier check has passed
    ordered_checks = (
        (lambda: epoch_length == 0, "epoch length can't be 0"),
        (lambda: recording.sampling_rate == 0, "sampling rate can't be 0"),
        (
            lambda: epoch_length > recording.sampling_rate,
            "invalid epoch length or sampling rate",
        ),
        (lambda: recording.recording_file == "", "no recording selected"),
        (
            lambda: not os.path.isfile(recording.recording_file),
            "recording file does not exist",
        ),
        (lambda: recording.label_file == "", "no label file selected"),
    )
    for has_failed, problem in ordered_checks:
        if has_failed():
            return problem
    return None
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class TrainingService:
    """Service for training classification models.

    Runs the training pipeline: input validation, creation of a folder of
    training images, model training, optional temperature calibration, and
    saving the model to disk.
    """

    def __init__(self, progress_callback: Callable[[str], None] | None = None) -> None:
        """Initialize the training service.

        :param progress_callback: optional callback for progress messages
        """
        self.progress_callback = progress_callback

    def _report_progress(self, message: str) -> None:
        """Report progress if callback is available."""
        if self.progress_callback:
            self.progress_callback(message)

    @staticmethod
    def _remove_image_dir(temp_image_dir: str, warnings: list[str]) -> None:
        """Delete the training image folder, appending a warning if that fails."""
        if not os.path.exists(temp_image_dir):
            return
        logger.info("Cleaning up training image folder")
        # ignore_errors so a cleanup problem never masks a training exception
        # or turns an already-saved model into a failure
        shutil.rmtree(temp_image_dir, ignore_errors=True)
        if os.path.exists(temp_image_dir):
            warnings.append(f"Could not delete training image folder {temp_image_dir}")

    def train_model(
        self,
        recordings: list[Recording],
        epoch_length: int | float,
        epochs_per_img: int,
        model_type: str,
        calibrate: bool,
        calibration_fraction: float,
        brain_state_set: BrainStateSet,
        emg_filter: EMGFilter,
        hyperparameters: Hyperparameters,
        model_filename: str,
        temp_image_dir: str | None = None,
        delete_images: bool = True,
    ) -> ServiceResult:
        """Train a classification model.

        When ``delete_images`` is True the training image folder is removed
        once training finishes, including when a step raises. In that case
        the exception still propagates to the caller.

        :param recordings: list of recordings to use for training
        :param epoch_length: epoch length in seconds
        :param epochs_per_img: number of epochs per training image
        :param model_type: type of model ('default' or 'real-time')
        :param calibrate: whether to calibrate the model
        :param calibration_fraction: fraction of data to use for calibration
        :param brain_state_set: set of brain state options
        :param emg_filter: EMG filter parameters
        :param hyperparameters: model training hyperparameters
        :param model_filename: path to save the trained model
        :param temp_image_dir: directory for training images (auto-generated if None)
        :param delete_images: whether to delete training images after training
        :return: ServiceResult with status and messages
        """
        result = ServiceResult(success=True)

        # Validate epochs_per_img for default model type
        if model_type == DEFAULT_MODEL_TYPE and epochs_per_img % 2 == 0:
            return ServiceResult(
                success=False,
                error=(
                    "For the default model type, number of epochs "
                    "per image must be an odd number."
                ),
            )

        # Nothing to train on
        if not recordings:
            return ServiceResult(
                success=False,
                error="No recordings selected for training",
            )

        # Validate each recording
        for recording in recordings:
            error_message = check_single_file_inputs(recording, epoch_length)
            if error_message:
                return ServiceResult(
                    success=False,
                    error=f"Recording {recording.name}: {error_message}",
                )

        # Create temp image directory if not provided
        if temp_image_dir is None:
            temp_image_dir = os.path.join(
                os.path.dirname(model_filename),
                "images_" + datetime.datetime.now().strftime("%Y%m%d%H%M"),
            )

        if os.path.exists(temp_image_dir):
            result.warnings.append("Training image folder exists, will be overwritten")
        os.makedirs(temp_image_dir, exist_ok=True)

        # Create training images
        self._report_progress("Creating training images")
        if not delete_images:
            result.messages.append(f"Creating training images in {temp_image_dir}")
        else:
            result.messages.append(
                f"Creating temporary folder of training images: {temp_image_dir}"
            )

        try:
            logger.info("Creating training images")
            failed_recordings, training_class_balance, had_zero_variance = (
                create_training_images(
                    recordings=recordings,
                    output_path=temp_image_dir,
                    epoch_length=epoch_length,
                    epochs_per_img=epochs_per_img,
                    brain_state_set=brain_state_set,
                    model_type=model_type,
                    calibration_fraction=calibration_fraction,
                    emg_filter=emg_filter,
                )
            )

            if had_zero_variance:
                result.warnings.append(
                    "Some recordings contain features with zero variance. "
                    "The EEG or EMG signal might be empty. If this is unexpected, "
                    "please make sure the recording files are correctly formatted."
                )

            if len(failed_recordings) > 0:
                if len(failed_recordings) == len(recordings):
                    # The failure result shares result.warnings, so any
                    # cleanup warning added in `finally` reaches it too.
                    return ServiceResult(
                        success=False,
                        error="No recordings were valid!",
                        warnings=result.warnings,
                    )
                result.warnings.append(
                    "The following recordings could not be loaded and will not "
                    f"be used for training: {', '.join([str(r) for r in failed_recordings])}. "
                    "More information might be available in the terminal."
                )

            # Train model
            self._report_progress("Training model")
            logger.info("Training model")

            from accusleepy.classification import create_dataloader, train_ssann
            from accusleepy.models import save_model
            from accusleepy.temperature_scaling import ModelWithTemperature

            model = train_ssann(
                annotations_file=os.path.join(temp_image_dir, ANNOTATIONS_FILENAME),
                img_dir=temp_image_dir,
                training_class_balance=training_class_balance,
                n_classes=brain_state_set.n_classes,
                hyperparameters=hyperparameters,
            )

            # Calibrate the model if requested
            if calibrate:
                calibration_annotation_file = os.path.join(
                    temp_image_dir, CALIBRATION_ANNOTATION_FILENAME
                )
                calibration_dataloader = create_dataloader(
                    annotations_file=calibration_annotation_file,
                    img_dir=temp_image_dir,
                    hyperparameters=hyperparameters,
                )
                model = ModelWithTemperature(model)
                logger.info("Calibrating model")
                model.set_temperature(calibration_dataloader)

            # Save model
            save_model(
                model=model,
                filename=model_filename,
                epoch_length=epoch_length,
                epochs_per_img=epochs_per_img,
                model_type=model_type,
                brain_state_set=brain_state_set,
                is_calibrated=calibrate,
            )
        finally:
            # Runs on success, early return, and exceptions alike so the
            # temporary images never outlive the call.
            if delete_images:
                self._remove_image_dir(temp_image_dir, result.warnings)

        result.messages.append(f"Training complete. Saved model to {model_filename}")
        logger.info("Training complete")

        return result
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
@dataclass
class LoadedModel:
    """A classification model together with the settings it was trained with.

    Every field defaults to None; a None ``model`` means nothing is loaded.
    """

    # the trained network, or None when no model is loaded
    model: SSANN | None = field(default=None)
    # epoch length (in seconds) the model was trained with
    epoch_length: int | float | None = field(default=None)
    # number of epochs per image passed to the model when scoring
    epochs_per_img: int | None = field(default=None)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def score_recording_list(
    recordings: list[Recording],
    loaded_model: LoadedModel,
    epoch_length: int | float,
    only_overwrite_undefined: bool,
    save_confidence_scores: bool,
    min_bout_length: int | float,
    brain_state_set: BrainStateSet,
    emg_filter: EMGFilter,
) -> ServiceResult:
    """Score all recordings using a classification model.

    All inputs are validated first, and any validation failure aborts the
    batch. During scoring, a recording is skipped with a warning when any of
    these happen:

    - its data, existing labels, or calibration file cannot be loaded;
    - its existing labels do not match the recording length;
    - scoring or saving it raises.

    The remaining recordings are still processed.

    :param recordings: list of recordings to score
    :param loaded_model: loaded classification model and metadata
    :param epoch_length: epoch length in seconds
    :param only_overwrite_undefined: only overwrite epochs labeled as undefined
    :param save_confidence_scores: whether to save confidence scores
    :param min_bout_length: minimum bout length in seconds
    :param brain_state_set: set of brain state options
    :param emg_filter: EMG filter parameters
    :return: ServiceResult with status and messages
    """
    result = ServiceResult(success=True)

    # Validate model is loaded
    if loaded_model.model is None:
        return ServiceResult(
            success=False,
            error="No classification model file selected",
        )

    # Validate min_bout_length
    if min_bout_length < epoch_length:
        return ServiceResult(
            success=False,
            error="Minimum bout length must be >= epoch length",
        )

    # Validate model epoch length matches current
    if epoch_length != loaded_model.epoch_length:
        return ServiceResult(
            success=False,
            error=(
                f"Model was trained with an epoch length of "
                f"{loaded_model.epoch_length} seconds, but the current "
                f"epoch length setting is {epoch_length} seconds."
            ),
        )

    # Validate each recording
    for recording in recordings:
        error_message = check_single_file_inputs(recording, epoch_length)
        if error_message:
            return ServiceResult(
                success=False,
                error=f"Recording {recording.name}: {error_message}",
            )
        if recording.calibration_file == "":
            return ServiceResult(
                success=False,
                error=f"Recording {recording.name}: no calibration file selected",
            )

    from accusleepy.classification import score_recording

    any_zero_variance = False

    # Score each recording
    for recording in recordings:
        # Load EEG, EMG
        try:
            eeg, emg = load_recording(recording.recording_file)
            sampling_rate = recording.sampling_rate

            eeg, emg, sampling_rate = resample_and_standardize(
                eeg=eeg,
                emg=emg,
                sampling_rate=sampling_rate,
                epoch_length=epoch_length,
            )
        except Exception:
            logger.exception("Failed to load %s", recording.recording_file)
            result.warnings.append(
                f"Could not load recording {recording.name}. "
                "This recording will be skipped."
            )
            continue

        # Load labels
        label_file = recording.label_file
        if os.path.isfile(label_file):
            try:
                existing_labels, _ = load_labels(label_file)
            except Exception:
                logger.exception("Failed to load %s", label_file)
                result.warnings.append(
                    f"Could not load existing labels for recording "
                    f"{recording.name}. This recording will be skipped."
                )
                continue
            # Only check the length
            samples_per_epoch = sampling_rate * epoch_length
            epochs_in_recording = round(eeg.size / samples_per_epoch)
            if epochs_in_recording != existing_labels.size:
                result.warnings.append(
                    f"Existing labels for recording {recording.name} "
                    "do not match the recording length. "
                    "This recording will be skipped."
                )
                continue
        else:
            existing_labels = None

        # Load calibration data
        if not os.path.isfile(recording.calibration_file):
            result.warnings.append(
                f"Calibration file does not exist for recording "
                f"{recording.name}. This recording will be skipped."
            )
            continue
        try:
            mixture_means, mixture_sds = load_calibration_file(
                recording.calibration_file
            )
        except Exception:
            logger.exception("Failed to load %s", recording.calibration_file)
            result.warnings.append(
                f"Could not load calibration file for recording "
                f"{recording.name}. This recording will be skipped."
            )
            continue

        # Check if calibration data contains any 0-variance features
        if np.any(mixture_sds == 0):
            any_zero_variance = True

        # A failure for one recording (e.g. a calibration file that does not
        # fit the model) should not abort the whole batch or discard the
        # messages for recordings that were already saved.
        try:
            labels, confidence_scores = score_recording(
                model=loaded_model.model,
                eeg=eeg,
                emg=emg,
                mixture_means=mixture_means,
                mixture_sds=mixture_sds,
                sampling_rate=sampling_rate,
                epoch_length=epoch_length,
                epochs_per_img=loaded_model.epochs_per_img,
                brain_state_set=brain_state_set,
                emg_filter=emg_filter,
            )
        except Exception:
            logger.exception("Failed to score %s", recording.recording_file)
            result.warnings.append(
                f"Could not score recording {recording.name}. "
                "This recording will be skipped."
            )
            continue

        # Overwrite as needed: keep every epoch that already has a label,
        # so only undefined epochs take the model's prediction
        if existing_labels is not None and only_overwrite_undefined:
            already_labeled = existing_labels != UNDEFINED_LABEL
            labels[already_labeled] = existing_labels[already_labeled]

        # Enforce minimum bout length
        # NOTE(review): this runs after merging, so it may also alter epochs
        # preserved from existing labels -- confirm that is intended.
        labels = enforce_min_bout_length(
            labels=labels,
            epoch_length=epoch_length,
            min_bout_length=min_bout_length,
        )

        # Ignore confidence scores if desired
        if not save_confidence_scores:
            confidence_scores = None

        # Save results
        try:
            save_labels(
                labels=labels, filename=label_file, confidence_scores=confidence_scores
            )
        except Exception:
            logger.exception("Failed to save %s", label_file)
            result.warnings.append(
                f"Could not save labels for recording {recording.name} "
                f"to {label_file}. This recording will be skipped."
            )
            continue
        result.messages.append(
            f"Saved labels for recording {recording.name} to {label_file}"
        )

    if any_zero_variance:
        result.warnings.append(
            "One or more calibration files has 0 variance "
            "for some features. This could indicate that the EEG or "
            "EMG signal is empty in the recording used for calibration."
        )

    return result
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
def create_calibration(
    recording: Recording,
    epoch_length: int | float,
    brain_state_set: BrainStateSet,
    emg_filter: EMGFilter,
    output_filename: str,
) -> ServiceResult:
    """Create a calibration file for a recording.

    The calibration file is a CSV with one mean column and one standard
    deviation column, computed from the recording's EEG/EMG image and its
    labels. Every scored brain state must have at least
    MIN_EPOCHS_PER_STATE labeled epochs.

    Failures to load the recording (including resampling), load or
    validate the labels, or write the output file are reported as a failed
    ServiceResult rather than raised.

    :param recording: the recording to create calibration for
    :param epoch_length: epoch length in seconds
    :param brain_state_set: set of brain state options
    :param emg_filter: EMG filter parameters
    :param output_filename: path to save the calibration file
    :return: ServiceResult with status and messages
    """
    result = ServiceResult(success=True)

    # Validate recording inputs
    error_message = check_single_file_inputs(recording, epoch_length)
    if error_message:
        return ServiceResult(success=False, error=error_message)

    # Load and preprocess the recording. Resampling stays inside the try,
    # matching score_recording_list, so malformed data yields an error
    # result instead of an exception.
    try:
        eeg, emg = load_recording(recording.recording_file)
        eeg, emg, sampling_rate = resample_and_standardize(
            eeg=eeg,
            emg=emg,
            sampling_rate=recording.sampling_rate,
            epoch_length=epoch_length,
        )
    except Exception:
        logger.exception("Failed to load %s", recording.recording_file)
        return ServiceResult(
            success=False,
            error=(
                "Could not load recording. "
                "Check user manual for formatting instructions."
            ),
        )

    # Load and validate labels
    label_file = recording.label_file
    if not os.path.isfile(label_file):
        return ServiceResult(
            success=False,
            error="Label file does not exist",
        )

    try:
        labels, _ = load_labels(label_file)
    except Exception:
        logger.exception("Failed to load %s", label_file)
        return ServiceResult(
            success=False,
            error=(
                "Could not load labels. Check user manual for formatting instructions."
            ),
        )

    label_error_message = check_label_validity(
        labels=labels,
        confidence_scores=None,
        samples_in_recording=eeg.size,
        sampling_rate=sampling_rate,
        epoch_length=epoch_length,
        brain_state_set=brain_state_set,
    )
    if label_error_message:
        return ServiceResult(success=False, error=label_error_message)

    # Check that each scored brain state has sufficient observations
    for brain_state in brain_state_set.brain_states:
        if brain_state.is_scored:
            count = np.sum(labels == brain_state.digit)
            if count < MIN_EPOCHS_PER_STATE:
                return ServiceResult(
                    success=False,
                    error=(
                        f"At least {MIN_EPOCHS_PER_STATE} labeled epochs "
                        f"per brain state are required for calibration. Only "
                        f"{count} '{brain_state.name}' epoch(s) found."
                    ),
                )

    # Create calibration file
    img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
    mixture_means, mixture_sds = get_mixture_values(
        img=img,
        labels=brain_state_set.convert_digit_to_class(labels),
        brain_state_set=brain_state_set,
    )
    try:
        pd.DataFrame(
            {MIXTURE_MEAN_COL: mixture_means, MIXTURE_SD_COL: mixture_sds}
        ).to_csv(output_filename, index=False)
    except OSError:
        # e.g. missing directory or no write permission
        logger.exception("Failed to save %s", output_filename)
        return ServiceResult(
            success=False,
            error=f"Could not save calibration file to {output_filename}",
        )

    result.messages.append(
        f"Created calibration file using recording {recording.name} "
        f"at {output_filename}"
    )

    if np.any(mixture_sds == 0):
        result.warnings.append(
            "One or more features derived from the data have "
            "zero variance. This could indicate that the EEG or "
            "EMG signal is empty."
        )

    return result
|