accusleepy 0.8.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
accusleepy/services.py ADDED
@@ -0,0 +1,581 @@
1
+ """Service classes for orchestrating AccuSleePy operations.
2
+
3
+ Isolating certain functionality here, without any interaction
4
+ with UI state, makes it more testable.
5
+ """
6
+
7
+ import datetime
8
+ import logging
9
+ import os
10
+ import shutil
11
+ from dataclasses import dataclass, field
12
+ from typing import Callable
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+
17
+ from accusleepy.bouts import enforce_min_bout_length
18
+ from accusleepy.brain_state_set import BrainStateSet
19
+ from accusleepy.constants import (
20
+ ANNOTATIONS_FILENAME,
21
+ CALIBRATION_ANNOTATION_FILENAME,
22
+ DEFAULT_MODEL_TYPE,
23
+ MIN_EPOCHS_PER_STATE,
24
+ UNDEFINED_LABEL,
25
+ MIXTURE_MEAN_COL,
26
+ MIXTURE_SD_COL,
27
+ )
28
+ from accusleepy.fileio import (
29
+ EMGFilter,
30
+ Hyperparameters,
31
+ Recording,
32
+ load_calibration_file,
33
+ load_labels,
34
+ load_recording,
35
+ save_labels,
36
+ )
37
+ from accusleepy.models import SSANN
38
+ from accusleepy.signal_processing import (
39
+ create_training_images,
40
+ resample_and_standardize,
41
+ create_eeg_emg_image,
42
+ get_mixture_values,
43
+ )
44
+ from accusleepy.validation import check_label_validity
45
+
46
# Module-level logger named after this module; used below for progress
# messages and (via logger.exception) for tracebacks of failed loads.
logger = logging.getLogger(__name__)
47
+
48
+
49
+ @dataclass
50
+ class ServiceResult:
51
+ """Result of a service operation."""
52
+
53
+ success: bool
54
+ messages: list[str] = field(default_factory=list)
55
+ warnings: list[str] = field(default_factory=list)
56
+ error: str | None = None
57
+
58
+ def report_to(self, callback: Callable[[str], None]) -> None:
59
+ """Report all warnings, messages, and errors through a callback.
60
+
61
+ :param callback: function to call with each message string
62
+ """
63
+ for warning in self.warnings:
64
+ callback(f"WARNING: {warning}")
65
+ for message in self.messages:
66
+ callback(message)
67
+ if not self.success and self.error:
68
+ callback(f"ERROR: {self.error}")
69
+
70
+
71
def check_single_file_inputs(
    recording: Recording,
    epoch_length: int | float,
) -> str | None:
    """Run basic sanity checks on one recording's settings.

    These cheap tests catch settings that would make it impossible to load
    or score the recording. The checks run in a fixed order, and the first
    one that fails determines the message returned.

    :param recording: the recording whose settings are checked
    :param epoch_length: epoch length in seconds
    :return: description of the first problem found, or None if every
        check passes
    """
    # Each entry pairs a lazily-evaluated failure condition with its message.
    # Lambdas keep later checks (e.g. the filesystem lookup) from running
    # unless all earlier checks have passed.
    failure_checks = (
        (lambda: epoch_length == 0, "epoch length can't be 0"),
        (lambda: recording.sampling_rate == 0, "sampling rate can't be 0"),
        (
            lambda: epoch_length > recording.sampling_rate,
            "invalid epoch length or sampling rate",
        ),
        (lambda: recording.recording_file == "", "no recording selected"),
        (
            lambda: not os.path.isfile(recording.recording_file),
            "recording file does not exist",
        ),
        (lambda: recording.label_file == "", "no label file selected"),
    )
    for is_invalid, problem in failure_checks:
        if is_invalid():
            return problem
    return None
98
+
99
+
100
class TrainingService:
    """Service for training classification models.

    Short progress strings are forwarded to an optional callback as work
    proceeds. All other feedback (messages, warnings, errors) is collected
    in the returned :class:`ServiceResult`.
    """

    def __init__(self, progress_callback: Callable[[str], None] | None = None) -> None:
        """Initialize the training service.

        :param progress_callback: optional callback for progress messages
        """
        self.progress_callback = progress_callback

    def _report_progress(self, message: str) -> None:
        """Report progress if callback is available."""
        if self.progress_callback:
            self.progress_callback(message)

    @staticmethod
    def _remove_image_dir(image_dir: str, result: ServiceResult) -> None:
        """Delete the training image folder, turning failures into a warning.

        :param image_dir: folder to delete; nothing happens if it is missing
        :param result: result object that receives a warning if deletion fails
        """
        if not os.path.exists(image_dir):
            return
        logger.info("Cleaning up training image folder")
        try:
            shutil.rmtree(image_dir)
        except OSError:
            # This runs from a ``finally`` clause. Never let a cleanup
            # problem mask the real outcome (a saved model or an in-flight
            # exception from training).
            logger.exception("Failed to delete %s", image_dir)
            result.warnings.append(
                f"Could not delete training image folder {image_dir}"
            )

    def train_model(
        self,
        recordings: list[Recording],
        epoch_length: int | float,
        epochs_per_img: int,
        model_type: str,
        calibrate: bool,
        calibration_fraction: float,
        brain_state_set: BrainStateSet,
        emg_filter: EMGFilter,
        hyperparameters: Hyperparameters,
        model_filename: str,
        temp_image_dir: str | None = None,
        delete_images: bool = True,
    ) -> ServiceResult:
        """Train a classification model.

        Training images are written to ``temp_image_dir``. When
        ``delete_images`` is True, that folder (including any content it
        already had) is removed afterwards. Removal happens whether training
        succeeds, fails validation after image creation, or raises.

        :param recordings: list of recordings to use for training
        :param epoch_length: epoch length in seconds
        :param epochs_per_img: number of epochs per training image
        :param model_type: type of model ('default' or 'real-time')
        :param calibrate: whether to calibrate the model
        :param calibration_fraction: fraction of data to use for calibration
        :param brain_state_set: set of brain state options
        :param emg_filter: EMG filter parameters
        :param hyperparameters: model training hyperparameters
        :param model_filename: path to save the trained model
        :param temp_image_dir: directory for training images (auto-generated if None)
        :param delete_images: whether to delete training images after training
        :return: ServiceResult with status and messages
        """
        result = ServiceResult(success=True)

        # Validate epochs_per_img for default model type
        if model_type == DEFAULT_MODEL_TYPE and epochs_per_img % 2 == 0:
            return ServiceResult(
                success=False,
                error=(
                    "For the default model type, number of epochs "
                    "per image must be an odd number."
                ),
            )

        # Without this check an empty list would skip every other test and
        # go straight to training on an empty image set.
        if len(recordings) == 0:
            return ServiceResult(
                success=False,
                error="No recordings were selected for training.",
            )

        # Validate each recording
        for recording in recordings:
            error_message = check_single_file_inputs(recording, epoch_length)
            if error_message:
                return ServiceResult(
                    success=False,
                    error=f"Recording {recording.name}: {error_message}",
                )

        # Create temp image directory if not provided
        if temp_image_dir is None:
            temp_image_dir = os.path.join(
                os.path.dirname(model_filename),
                "images_" + datetime.datetime.now().strftime("%Y%m%d%H%M"),
            )

        if os.path.exists(temp_image_dir):
            result.warnings.append("Training image folder exists, will be overwritten")
        os.makedirs(temp_image_dir, exist_ok=True)

        # Create training images
        self._report_progress("Creating training images")
        if not delete_images:
            result.messages.append(f"Creating training images in {temp_image_dir}")
        else:
            result.messages.append(
                f"Creating temporary folder of training images: {temp_image_dir}"
            )

        try:
            logger.info("Creating training images")
            failed_recordings, training_class_balance, had_zero_variance = (
                create_training_images(
                    recordings=recordings,
                    output_path=temp_image_dir,
                    epoch_length=epoch_length,
                    epochs_per_img=epochs_per_img,
                    brain_state_set=brain_state_set,
                    model_type=model_type,
                    calibration_fraction=calibration_fraction,
                    emg_filter=emg_filter,
                )
            )

            if had_zero_variance:
                result.warnings.append(
                    "Some recordings contain features with zero variance. "
                    "The EEG or EMG signal might be empty. If this is unexpected, "
                    "please make sure the recording files are correctly formatted."
                )

            if len(failed_recordings) > 0:
                if len(failed_recordings) == len(recordings):
                    # The image folder is cleaned up by the ``finally`` below.
                    return ServiceResult(
                        success=False,
                        error="No recordings were valid!",
                        warnings=result.warnings,
                    )
                result.warnings.append(
                    "The following recordings could not be loaded and will not "
                    f"be used for training: {', '.join([str(r) for r in failed_recordings])}. "
                    "More information might be available in the terminal."
                )

            # Train model
            self._report_progress("Training model")
            logger.info("Training model")

            # Imported here rather than at module level; presumably to defer
            # heavy dependencies until training is requested -- confirm.
            from accusleepy.classification import create_dataloader, train_ssann
            from accusleepy.models import save_model
            from accusleepy.temperature_scaling import ModelWithTemperature

            model = train_ssann(
                annotations_file=os.path.join(temp_image_dir, ANNOTATIONS_FILENAME),
                img_dir=temp_image_dir,
                training_class_balance=training_class_balance,
                n_classes=brain_state_set.n_classes,
                hyperparameters=hyperparameters,
            )

            # Calibrate the model if requested
            if calibrate:
                calibration_annotation_file = os.path.join(
                    temp_image_dir, CALIBRATION_ANNOTATION_FILENAME
                )
                calibration_dataloader = create_dataloader(
                    annotations_file=calibration_annotation_file,
                    img_dir=temp_image_dir,
                    hyperparameters=hyperparameters,
                )
                model = ModelWithTemperature(model)
                logger.info("Calibrating model")
                model.set_temperature(calibration_dataloader)

            # Save model
            save_model(
                model=model,
                filename=model_filename,
                epoch_length=epoch_length,
                epochs_per_img=epochs_per_img,
                model_type=model_type,
                brain_state_set=brain_state_set,
                is_calibrated=calibrate,
            )
        finally:
            # Runs on success, on the early "no valid recordings" return, and
            # when any step above raises, so temporary images never leak.
            if delete_images:
                self._remove_image_dir(temp_image_dir, result)

        result.messages.append(f"Training complete. Saved model to {model_filename}")
        logger.info("Training complete")

        return result
275
+
276
+
277
@dataclass
class LoadedModel:
    """A classification model bundled with the settings it was trained with.

    Every field defaults to None. A ``model`` of None is treated by
    ``score_recording_list`` as "no model file selected".
    """

    # The network itself; None until a model has been loaded.
    model: SSANN | None = field(default=None)
    # Epoch length (seconds) the model was trained with.
    epoch_length: int | float | None = field(default=None)
    # Number of epochs per image the model expects when scoring.
    epochs_per_img: int | None = field(default=None)
284
+
285
+
286
def score_recording_list(
    recordings: list[Recording],
    loaded_model: LoadedModel,
    epoch_length: int | float,
    only_overwrite_undefined: bool,
    save_confidence_scores: bool,
    min_bout_length: int | float,
    brain_state_set: BrainStateSet,
    emg_filter: EMGFilter,
) -> ServiceResult:
    """Score all recordings using a classification model.

    All inputs are validated up front, and any validation failure aborts the
    whole batch with an error. After that, each recording is processed
    independently. A recording whose data, existing labels, or calibration
    file cannot be loaded, or that fails during scoring or saving, is
    skipped with a warning, and the remaining recordings are still scored.

    :param recordings: list of recordings to score
    :param loaded_model: loaded classification model and metadata
    :param epoch_length: epoch length in seconds
    :param only_overwrite_undefined: only overwrite epochs labeled as undefined
    :param save_confidence_scores: whether to save confidence scores
    :param min_bout_length: minimum bout length in seconds
    :param brain_state_set: set of brain state options
    :param emg_filter: EMG filter parameters
    :return: ServiceResult with status and messages; ``success`` is False
        only when up-front validation fails, while skipped recordings are
        reported as warnings
    """
    result = ServiceResult(success=True)

    # Validate model is loaded
    if loaded_model.model is None:
        return ServiceResult(
            success=False,
            error="No classification model file selected",
        )

    # Validate min_bout_length
    if min_bout_length < epoch_length:
        return ServiceResult(
            success=False,
            error="Minimum bout length must be >= epoch length",
        )

    # Validate model epoch length matches current
    if epoch_length != loaded_model.epoch_length:
        return ServiceResult(
            success=False,
            error=(
                f"Model was trained with an epoch length of "
                f"{loaded_model.epoch_length} seconds, but the current "
                f"epoch length setting is {epoch_length} seconds."
            ),
        )

    # Validate each recording before scoring any of them
    for recording in recordings:
        error_message = check_single_file_inputs(recording, epoch_length)
        if error_message:
            return ServiceResult(
                success=False,
                error=f"Recording {recording.name}: {error_message}",
            )
        if recording.calibration_file == "":
            return ServiceResult(
                success=False,
                error=f"Recording {recording.name}: no calibration file selected",
            )

    # Imported here rather than at module level; presumably to defer a heavy
    # dependency until scoring is requested -- confirm.
    from accusleepy.classification import score_recording

    any_zero_variance = False

    # Score each recording
    for recording in recordings:
        # Load EEG, EMG
        try:
            eeg, emg = load_recording(recording.recording_file)
            eeg, emg, sampling_rate = resample_and_standardize(
                eeg=eeg,
                emg=emg,
                sampling_rate=recording.sampling_rate,
                epoch_length=epoch_length,
            )
        except Exception:
            logger.exception("Failed to load %s", recording.recording_file)
            result.warnings.append(
                f"Could not load recording {recording.name}. "
                "This recording will be skipped."
            )
            continue

        # Existing labels are optional; if present they must load and match
        # the recording length (only the length is checked here).
        label_file = recording.label_file
        existing_labels = None
        if os.path.isfile(label_file):
            try:
                existing_labels, _ = load_labels(label_file)
            except Exception:
                logger.exception("Failed to load %s", label_file)
                result.warnings.append(
                    f"Could not load existing labels for recording "
                    f"{recording.name}. This recording will be skipped."
                )
                continue
            samples_per_epoch = sampling_rate * epoch_length
            epochs_in_recording = round(eeg.size / samples_per_epoch)
            if epochs_in_recording != existing_labels.size:
                result.warnings.append(
                    f"Existing labels for recording {recording.name} "
                    "do not match the recording length. "
                    "This recording will be skipped."
                )
                continue

        # Load calibration data
        if not os.path.isfile(recording.calibration_file):
            result.warnings.append(
                f"Calibration file does not exist for recording "
                f"{recording.name}. This recording will be skipped."
            )
            continue
        try:
            mixture_means, mixture_sds = load_calibration_file(
                recording.calibration_file
            )
        except Exception:
            logger.exception("Failed to load %s", recording.calibration_file)
            result.warnings.append(
                f"Could not load calibration file for recording "
                f"{recording.name}. This recording will be skipped."
            )
            continue

        # Check if calibration data contains any 0-variance features
        if np.any(mixture_sds == 0):
            any_zero_variance = True

        # Scoring and post-processing failures skip only this recording,
        # matching how load failures are handled above, so earlier results
        # in the batch are still reported.
        try:
            labels, confidence_scores = score_recording(
                model=loaded_model.model,
                eeg=eeg,
                emg=emg,
                mixture_means=mixture_means,
                mixture_sds=mixture_sds,
                sampling_rate=sampling_rate,
                epoch_length=epoch_length,
                epochs_per_img=loaded_model.epochs_per_img,
                brain_state_set=brain_state_set,
                emg_filter=emg_filter,
            )

            # Keep every existing label that is not UNDEFINED_LABEL
            if existing_labels is not None and only_overwrite_undefined:
                keep_existing = existing_labels != UNDEFINED_LABEL
                labels[keep_existing] = existing_labels[keep_existing]

            # Enforce minimum bout length
            labels = enforce_min_bout_length(
                labels=labels,
                epoch_length=epoch_length,
                min_bout_length=min_bout_length,
            )
        except Exception:
            logger.exception("Failed to score %s", recording.recording_file)
            result.warnings.append(
                f"Could not score recording {recording.name}. "
                "This recording will be skipped."
            )
            continue

        # Ignore confidence scores if desired
        if not save_confidence_scores:
            confidence_scores = None

        # Save results
        try:
            save_labels(
                labels=labels, filename=label_file, confidence_scores=confidence_scores
            )
        except Exception:
            logger.exception("Failed to save %s", label_file)
            result.warnings.append(
                f"Could not save labels for recording {recording.name} "
                f"to {label_file}."
            )
            continue
        result.messages.append(
            f"Saved labels for recording {recording.name} to {label_file}"
        )

    if any_zero_variance:
        result.warnings.append(
            "One or more calibration files has 0 variance "
            "for some features. This could indicate that the EEG or "
            "EMG signal is empty in the recording used for calibration."
        )

    return result
468
+
469
+
470
def create_calibration(
    recording: Recording,
    epoch_length: int | float,
    brain_state_set: BrainStateSet,
    emg_filter: EMGFilter,
    output_filename: str,
) -> ServiceResult:
    """Create a calibration file for a recording.

    The recording and its labels are loaded and validated. Per-feature
    mixture means and standard deviations are then computed from the
    labeled data and written to ``output_filename`` as CSV. Failures to
    load or resample the data, or to write the output file, are returned
    as an error result rather than raised.

    :param recording: the recording to create calibration for
    :param epoch_length: epoch length in seconds
    :param brain_state_set: set of brain state options
    :param emg_filter: EMG filter parameters
    :param output_filename: path to save the calibration file
    :return: ServiceResult with status and messages
    """
    result = ServiceResult(success=True)

    # Validate recording inputs
    error_message = check_single_file_inputs(recording, epoch_length)
    if error_message:
        return ServiceResult(success=False, error=error_message)

    # Load and resample the recording. Resampling shares the loading
    # ``try`` (as in score_recording_list) so malformed data is reported
    # instead of raising out of the service.
    try:
        eeg, emg = load_recording(recording.recording_file)
        eeg, emg, sampling_rate = resample_and_standardize(
            eeg=eeg,
            emg=emg,
            sampling_rate=recording.sampling_rate,
            epoch_length=epoch_length,
        )
    except Exception:
        logger.exception("Failed to load %s", recording.recording_file)
        return ServiceResult(
            success=False,
            error=(
                "Could not load recording. "
                "Check user manual for formatting instructions."
            ),
        )

    # Load and validate labels
    label_file = recording.label_file
    if not os.path.isfile(label_file):
        return ServiceResult(
            success=False,
            error="Label file does not exist",
        )

    try:
        labels, _ = load_labels(label_file)
    except Exception:
        logger.exception("Failed to load %s", label_file)
        return ServiceResult(
            success=False,
            error=(
                "Could not load labels. Check user manual for formatting instructions."
            ),
        )

    label_error_message = check_label_validity(
        labels=labels,
        confidence_scores=None,
        samples_in_recording=eeg.size,
        sampling_rate=sampling_rate,
        epoch_length=epoch_length,
        brain_state_set=brain_state_set,
    )
    if label_error_message:
        return ServiceResult(success=False, error=label_error_message)

    # Check that each scored brain state has sufficient observations
    for brain_state in brain_state_set.brain_states:
        if brain_state.is_scored:
            count = np.sum(labels == brain_state.digit)
            if count < MIN_EPOCHS_PER_STATE:
                return ServiceResult(
                    success=False,
                    error=(
                        f"At least {MIN_EPOCHS_PER_STATE} labeled epochs "
                        f"per brain state are required for calibration. Only "
                        f"{count} '{brain_state.name}' epoch(s) found."
                    ),
                )

    # Compute per-feature mixture statistics from the labeled image
    img = create_eeg_emg_image(eeg, emg, sampling_rate, epoch_length, emg_filter)
    mixture_means, mixture_sds = get_mixture_values(
        img=img,
        labels=brain_state_set.convert_digit_to_class(labels),
        brain_state_set=brain_state_set,
    )

    # Write the calibration file; an unwritable path becomes an error result
    try:
        pd.DataFrame(
            {MIXTURE_MEAN_COL: mixture_means, MIXTURE_SD_COL: mixture_sds}
        ).to_csv(output_filename, index=False)
    except OSError:
        logger.exception("Failed to write %s", output_filename)
        return ServiceResult(
            success=False,
            error=f"Could not save calibration file to {output_filename}",
        )

    result.messages.append(
        f"Created calibration file using recording {recording.name} "
        f"at {output_filename}"
    )

    if np.any(mixture_sds == 0):
        result.warnings.append(
            "One or more features derived from the data have "
            "zero variance. This could indicate that the EEG or "
            "EMG signal is empty."
        )

    return result