celldetective 1.5.0b6__py3-none-any.whl → 1.5.0b8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. celldetective/_version.py +1 -1
  2. celldetective/event_detection_models.py +2463 -0
  3. celldetective/gui/base/channel_norm_generator.py +19 -3
  4. celldetective/gui/base/figure_canvas.py +1 -1
  5. celldetective/gui/base_annotator.py +2 -5
  6. celldetective/gui/event_annotator.py +248 -138
  7. celldetective/gui/pair_event_annotator.py +146 -20
  8. celldetective/gui/process_block.py +2 -2
  9. celldetective/gui/seg_model_loader.py +4 -4
  10. celldetective/gui/settings/_settings_event_model_training.py +32 -14
  11. celldetective/gui/settings/_settings_segmentation_model_training.py +5 -5
  12. celldetective/gui/settings/_settings_signal_annotator.py +0 -19
  13. celldetective/gui/viewers/base_viewer.py +17 -20
  14. celldetective/processes/train_signal_model.py +1 -1
  15. celldetective/processes/unified_process.py +16 -2
  16. celldetective/scripts/train_signal_model.py +1 -1
  17. celldetective/signals.py +4 -2426
  18. celldetective/utils/cellpose_utils/__init__.py +2 -2
  19. celldetective/utils/event_detection/__init__.py +1 -1
  20. celldetective/utils/stardist_utils/__init__.py +1 -2
  21. {celldetective-1.5.0b6.dist-info → celldetective-1.5.0b8.dist-info}/METADATA +1 -5
  22. {celldetective-1.5.0b6.dist-info → celldetective-1.5.0b8.dist-info}/RECORD +27 -26
  23. tests/test_signals.py +4 -4
  24. {celldetective-1.5.0b6.dist-info → celldetective-1.5.0b8.dist-info}/WHEEL +0 -0
  25. {celldetective-1.5.0b6.dist-info → celldetective-1.5.0b8.dist-info}/entry_points.txt +0 -0
  26. {celldetective-1.5.0b6.dist-info → celldetective-1.5.0b8.dist-info}/licenses/LICENSE +0 -0
  27. {celldetective-1.5.0b6.dist-info → celldetective-1.5.0b8.dist-info}/top_level.txt +0 -0
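The substantive change is a refactor: the SignalDetectionModel event-detection code moves out of celldetective/signals.py (+4 -2426) into the new celldetective/event_detection_models.py module (+2463 -0), whose diff follows. As a minimal usage sketch of the relocated class, assuming only the import path and the public methods visible in the diff below (the model directory, input shapes, and random signals are illustrative placeholders, not part of the release):

import numpy as np

from celldetective.event_detection_models import SignalDetectionModel

# Point to a pretrained model directory containing classifier.h5,
# regressor.h5 and config_input.json (placeholder path).
model = SignalDetectionModel(pretrained="/path/to/model_dir")

# Single-cell signals of shape (n_cells, timepoints, n_channels).
x = np.random.rand(10, 128, 1)

# Integer labels follow the confusion-matrix labels used in the module:
# 0 = event, 1 = no event, 2 = left censored.
classes = model.predict_class(x)

# Event times for cells classified as "event"; all other cells keep the
# sentinel value -1.0 in the returned vector.
times = model.predict_time_of_interest(x, class_predictions=classes)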
celldetective/event_detection_models.py
@@ -0,0 +1,2463 @@
+ import json
+ import os
+ import random
+ import time
+ from glob import glob
+
+ import keras
+ import numpy as np
+ from scipy.ndimage import shift
+ from tensorflow.keras.callbacks import Callback
+
+ from matplotlib import pyplot as plt
+ from natsort import natsorted
+ from scipy.interpolate import interp1d
+ from sklearn.metrics import (
+     jaccard_score,
+     balanced_accuracy_score,
+     precision_score,
+     recall_score,
+     confusion_matrix,
+     ConfusionMatrixDisplay,
+     classification_report,
+ )
+ from tensorflow.keras.utils import to_categorical
+ from tensorflow.keras.losses import MeanAbsoluteError
+ from tensorflow.keras.callbacks import (
+     ReduceLROnPlateau,
+     CSVLogger,
+     ModelCheckpoint,
+     EarlyStopping,
+     TensorBoard,
+ )
+ from tensorflow.keras.layers import (
+     Activation,
+     Add,
+ )
+ from tensorflow.keras.layers import (
+     Input,
+     ZeroPadding1D,
+     Conv1D,
+     BatchNormalization,
+     MaxPooling1D,
+     GlobalAveragePooling1D,
+     Concatenate,
+     Flatten,
+     Dense,
+     Dropout,
+ )
+ from tensorflow.keras.models import Model
+
+ from tensorflow.keras.losses import CategoricalCrossentropy
+ from tensorflow.config.experimental import (
+     list_physical_devices,
+     set_memory_growth,
+ )
+ from tensorflow.keras.optimizers import Adam
+ from tensorflow.keras.metrics import Precision, Recall, MeanIoU
+ from tensorflow.keras.models import clone_model, load_model
+ from tensorflow.keras.losses import MeanSquaredError
+
+ from celldetective.utils.dataset_helpers import compute_weights, train_test_split
+ from celldetective.utils.plots.regression import regression_plot
+
+
+ def TimeHistory():
+     """Create a TimeHistory callback instance."""
+     cls = _get_time_history_class()
+     return cls()
+
+
+ class SignalDetectionModel(object):
+     """
+     A class for creating and managing signal detection models for analyzing biological signals.
+
+     This class provides functionalities to load a pretrained signal detection model or create one from scratch,
+     preprocess input signals, train the model, and make predictions on new data.
+
+     Parameters
+     ----------
+     path : str, optional
+         Path to the directory containing the model and its configuration. This is used when loading a pretrained model.
+     pretrained : str, optional
+         Path to the pretrained model to load. If specified, the model and its configuration are loaded from this path.
+     channel_option : list of str, optional
+         Specifies the channels to be used for signal analysis. Default is ["live_nuclei_channel"].
+     model_signal_length : int, optional
+         The length of the input signals that the model expects. Default is 128.
+     n_channels : int, optional
+         The number of channels in the input signals. Default is 1.
+     n_conv : int, optional
+         The number of convolutional layers in the model. Default is 2.
+     n_classes : int, optional
+         The number of classes for the classification task. Default is 3.
+     dense_collection : int, optional
+         The number of units in the dense layer of the model. Default is 512.
+     dropout_rate : float, optional
+         The dropout rate applied to the dense layer of the model. Default is 0.1.
+     label : str, optional
+         A label for the model, used in naming and organizing outputs. Default is ''.
+
+     Attributes
+     ----------
+     model_class : keras Model
+         The classification model for predicting the class of signals.
+     model_reg : keras Model
+         The regression model for predicting the time of interest for signals.
+
+     Methods
+     -------
+     load_pretrained_model()
+         Loads the model and its configuration from the pretrained path.
+     create_models_from_scratch()
+         Creates new models for classification and regression from scratch.
+     prep_gpu()
+         Prepares GPU devices for training, if available.
+     fit_from_directory(datasets, ...)
+         Trains the model using data from specified directories.
+     fit(x_train, y_time_train, y_class_train, ...)
+         Trains the model using provided datasets.
+     predict_class(x, ...)
+         Predicts the class of input signals.
+     predict_time_of_interest(x, ...)
+         Predicts the time of interest for input signals.
+     plot_model_history(mode)
+         Plots the training history for the specified mode (classifier or regressor).
+     evaluate_regression_model()
+         Evaluates the regression model on test and validation data.
+     gather_callbacks(mode)
+         Gathers and prepares callbacks for training based on the specified mode.
+     prepare_sets()
+         Generates training, validation, and test sets from loaded data.
+     augment_training_set()
+         Augments the training set with additional generated data.
+     load_and_normalize(subset)
+         Loads and normalizes signals from a subset of data.
+
+     Notes
+     -----
+     - This class is designed to work with biological signal data, such as time series from microscopy imaging.
+     - The model architecture and training configurations can be customized through the class parameters and methods.
+
+     """
+
+     def __init__(
+         self,
+         path=None,
+         pretrained=None,
+         channel_option=["live_nuclei_channel"],
+         model_signal_length=128,
+         n_channels=1,
+         n_conv=2,
+         n_classes=3,
+         dense_collection=512,
+         dropout_rate=0.1,
+         label="",
+     ):
+
+         self.prep_gpu()
+
+         self.model_signal_length = model_signal_length
+         self.channel_option = channel_option
+         self.pretrained = pretrained
+         self.n_channels = n_channels
+         self.n_conv = n_conv
+         self.n_classes = n_classes
+         self.dense_collection = dense_collection
+         self.dropout_rate = dropout_rate
+         self.label = label
+         self.show_plots = True
+
+         if self.pretrained is not None:
+             print(f"Load pretrained models from {pretrained}...")
+             test = self.load_pretrained_model()
+             if test is None:
+                 self.pretrained = None
+                 print(
+                     "Pretrained model could not be loaded. Check the log for error. Abort..."
+                 )
+                 return None
+         else:
+             print("Create models from scratch...")
+             self.create_models_from_scratch()
+             print("Models successfully created.")
+
+     def load_pretrained_model(self):
+         """
+         Loads a pretrained model and its configuration from the specified path.
+
+         This method attempts to load both the classification and regression models from the path specified during the
+         class instantiation. It also loads the model configuration from a JSON file and updates the model attributes
+         accordingly. If the models cannot be loaded, an error message is printed.
+
+         Raises
+         ------
+         Exception
+             If there is an error loading the model or the configuration file, an exception is raised with details.
+
+         Notes
+         -----
+         - The models are expected to be saved in .h5 format with the filenames "classifier.h5" and "regressor.h5".
+         - The configuration file is expected to be named "config_input.json" and located in the same directory as the models.
+
+         """
+
+         if self.pretrained.endswith(os.sep):
+             self.pretrained = os.sep.join(self.pretrained.split(os.sep)[:-1])
+
+         try:
+             self.model_class = load_model(
+                 os.sep.join([self.pretrained, "classifier.h5"]),
+                 compile=False,
+                 custom_objects={"mse": MeanSquaredError()},
+             )
+             self.model_class.load_weights(
+                 os.sep.join([self.pretrained, "classifier.h5"])
+             )
+             self.model_class = self.freeze_encoder(self.model_class, 5)
+             print("Classifier successfully loaded...")
+         except Exception as e:
+             print(f"Error {e}...")
+             self.model_class = None
+         try:
+             self.model_reg = load_model(
+                 os.sep.join([self.pretrained, "regressor.h5"]),
+                 compile=False,
+                 custom_objects={"mse": MeanSquaredError()},
+             )
+             self.model_reg.load_weights(os.sep.join([self.pretrained, "regressor.h5"]))
+             self.model_reg = self.freeze_encoder(self.model_reg, 5)
+             print("Regressor successfully loaded...")
+         except Exception as e:
+             print(f"Error {e}...")
+             self.model_reg = None
+
+         if self.model_class is None and self.model_reg is None:
+             return None
+
+         # load config
+         with open(os.sep.join([self.pretrained, "config_input.json"])) as config_file:
+             model_config = json.load(config_file)
+             self.config = model_config
+
+         req_channels = model_config["channels"]
+         print(f"Required channels read from pretrained model: {req_channels}")
+         self.channel_option = req_channels
+         if "normalize" in model_config:
+             self.normalize = model_config["normalize"]
+         if "normalization_percentile" in model_config:
+             self.normalization_percentile = model_config["normalization_percentile"]
+         if "normalization_values" in model_config:
+             self.normalization_values = model_config["normalization_values"]
+         if "normalization_clip" in model_config:
+             self.normalization_clip = model_config["normalization_clip"]
+         if "label" in model_config:
+             self.label = model_config["label"]
+
+         try:
+             self.n_channels = self.model_class.layers[0].input_shape[0][-1]
+             self.model_signal_length = self.model_class.layers[0].input_shape[0][-2]
+             self.n_classes = self.model_class.layers[-1].output_shape[-1]
+             model_class_input_shape = self.model_class.layers[0].input_shape[0]
+             model_reg_input_shape = self.model_reg.layers[0].input_shape[0]
+         except AttributeError:
+             self.n_channels = self.model_class.input_shape[
+                 -1
+             ]  # self.model_class.layers[0].input.shape[0][-1]
+             self.model_signal_length = self.model_class.input_shape[
+                 -2
+             ]  # self.model_class.layers[0].input[0].shape[0][-2]
+             self.n_classes = self.model_class.output_shape[
+                 -1
+             ]  # self.model_class.layers[-1].output[0].shape[-1]
+             model_class_input_shape = self.model_class.input_shape
+             model_reg_input_shape = self.model_reg.input_shape
+         except Exception as e:
+             print(e)
+
+         assert (
+             model_class_input_shape == model_reg_input_shape
+         ), f"mismatch between input shape of classification: {self.model_class.layers[0].input_shape[0]} and regression {self.model_reg.layers[0].input_shape[0]} models... Error."
+
+         return True
+
+     def freeze_encoder(self, model, n_trainable_layers: int = 3):
+         for layer in model.layers[
+             : -min(n_trainable_layers, len(model.layers))
+         ]:  # freeze everything except final Dense layer
+             layer.trainable = False
+         return model
+
+     def create_models_from_scratch(self):
+         """
+         Initializes new models for classification and regression based on the specified parameters.
+
+         This method creates new ResNet models for both classification and regression tasks using the parameters specified
+         during class instantiation. The models are configured but not compiled or trained.
+
+         Notes
+         -----
+         - The models are created using a custom ResNet architecture defined elsewhere in the codebase.
+         - The models are stored in the `model_class` and `model_reg` attributes of the class.
+
+         """
+
+         self.model_class = ResNetModelCurrent(
+             n_channels=self.n_channels,
+             n_slices=self.n_conv,
+             n_classes=3,
+             dense_collection=self.dense_collection,
+             dropout_rate=self.dropout_rate,
+             header="classifier",
+             model_signal_length=self.model_signal_length,
+         )
+
+         self.model_reg = ResNetModelCurrent(
+             n_channels=self.n_channels,
+             n_slices=self.n_conv,
+             n_classes=self.n_classes,
+             dense_collection=self.dense_collection,
+             dropout_rate=self.dropout_rate,
+             header="regressor",
+             model_signal_length=self.model_signal_length,
+         )
+
+     def prep_gpu(self):
+         """
+         Prepares GPU devices for training by enabling memory growth.
+
+         This method attempts to identify available GPU devices and configures TensorFlow to allow memory growth on each
+         GPU. This prevents TensorFlow from allocating the total available memory on the GPU device upfront.
+
+         Notes
+         -----
+         - This method should be called before any TensorFlow/Keras operations that might allocate GPU memory.
+         - If no GPUs are detected, the method will pass silently.
+
+         """
+
+         try:
+
+             physical_devices = list_physical_devices("GPU")
+             for gpu in physical_devices:
+                 set_memory_growth(gpu, True)
+         except Exception:
+             pass
+
+     def fit_from_directory(
+         self,
+         datasets,
+         normalize=True,
+         normalization_percentile=None,
+         normalization_values=None,
+         normalization_clip=None,
+         channel_option=["live_nuclei_channel"],
+         model_name=None,
+         target_directory=None,
+         augment=True,
+         augmentation_factor=2,
+         validation_split=0.20,
+         test_split=0.0,
+         batch_size=64,
+         epochs=300,
+         recompile_pretrained=False,
+         learning_rate=0.01,
+         loss_reg="mse",
+         loss_class=None,
+         show_plots=True,
+         callbacks=None,
+     ):
+         """
+         Trains the model using data from specified directories.
+
+         This method prepares the dataset for training by loading and preprocessing data from specified directories,
+         then trains the classification and regression models.
+
+         Parameters
+         ----------
+         datasets : list of str
+             List of directories containing the dataset files for training.
+         callbacks : list, optional
+             List of Keras callbacks to apply during training.
+         normalize : bool, optional
+             Whether to normalize the input signals (default is True).
+         normalization_percentile : list or None, optional
+             Percentiles for signal normalization (default is None).
+         normalization_values : list or None, optional
+             Specific values for signal normalization (default is None).
+         normalization_clip : bool, optional
+             Whether to clip the normalized signals (default is None).
+         channel_option : list of str, optional
+             Specifies the channels to be used for signal analysis (default is ["live_nuclei_channel"]).
+         model_name : str, optional
+             Name of the model for saving purposes (default is None).
+         target_directory : str, optional
+             Directory where the trained model and outputs will be saved (default is None).
+         augment : bool, optional
+             Whether to augment the training data (default is True).
+         augmentation_factor : int, optional
+             Factor by which to augment the training data (default is 2).
+         validation_split : float, optional
+             Fraction of the data to be used as validation set (default is 0.20).
+         test_split : float, optional
+             Fraction of the data to be used as test set (default is 0.0).
+         batch_size : int, optional
+             Batch size for training (default is 64).
+         epochs : int, optional
+             Number of epochs to train for (default is 300).
+         recompile_pretrained : bool, optional
+             Whether to recompile a pretrained model (default is False).
+         learning_rate : float, optional
+             Learning rate for the optimizer (default is 0.01).
+         loss_reg : str or keras.losses.Loss, optional
+             Loss function for the regression model (default is "mse").
+         loss_class : str or keras.losses.Loss, optional
+             Loss function for the classification model (default is CategoricalCrossentropy(from_logits=False)).
+
+         Notes
+         -----
+         - The method automatically splits the dataset into training, validation, and test sets according to the specified splits.
+
+         """
+
+         # Lazy import for TensorFlow loss class
+         if loss_class is None:
+             loss_class = CategoricalCrossentropy(from_logits=False)
+
+         if not hasattr(self, "normalization_percentile"):
+             self.normalization_percentile = normalization_percentile
+         if not hasattr(self, "normalization_values"):
+             self.normalization_values = normalization_values
+         if not hasattr(self, "normalization_clip"):
+             self.normalization_clip = normalization_clip
+
+         self.callbacks = callbacks
+         self.normalize = normalize
+         (
+             self.normalization_percentile,
+             self.normalization_values,
+             self.normalization_clip,
+         ) = _interpret_normalization_parameters(
+             self.n_channels,
+             self.normalization_percentile,
+             self.normalization_values,
+             self.normalization_clip,
+         )
+
+         self.datasets = [rf"{d}" if isinstance(d, str) else d for d in datasets]
+         self.batch_size = batch_size
+         self.epochs = epochs
+         self.validation_split = validation_split
+         self.test_split = test_split
+         self.augment = augment
+         self.augmentation_factor = augmentation_factor
+         self.model_name = rf"{model_name}"
+         self.target_directory = rf"{target_directory}"
+         self.model_folder = os.sep.join([self.target_directory, self.model_name])
+         self.recompile_pretrained = recompile_pretrained
+         self.learning_rate = learning_rate
+         self.loss_reg = loss_reg
+         self.loss_class = loss_class
+         self.show_plots = show_plots
+         self.channel_option = channel_option
+
+         assert self.n_channels == len(
+             self.channel_option
+         ), f"Mismatch between the channel option and the number of channels of the model..."
+
+         if isinstance(self.datasets[0], dict):
+             self.datasets = [self.datasets]
+
+         self.list_of_sets = []
+         for ds in self.datasets:
+             if isinstance(ds, str):
+                 self.list_of_sets.extend(glob(os.sep.join([ds, "*.npy"])))
+             else:
+                 self.list_of_sets.append(ds)
+
+         print(f"Found {len(self.list_of_sets)} datasets...")
+
+         self.prepare_sets()
+         self.train_generic()
+
+     def fit(
+         self,
+         x_train,
+         y_time_train,
+         y_class_train,
+         normalize=True,
+         normalization_percentile=None,
+         normalization_values=None,
+         normalization_clip=None,
+         pad=True,
+         validation_data=None,
+         test_data=None,
+         channel_option=["live_nuclei_channel", "dead_nuclei_channel"],
+         model_name=None,
+         target_directory=None,
+         augment=True,
+         augmentation_factor=3,
+         validation_split=0.25,
+         batch_size=64,
+         epochs=300,
+         recompile_pretrained=False,
+         learning_rate=0.001,
+         loss_reg="mse",
+         loss_class=None,
+     ):
+         """
+         Trains the model using provided datasets.
+
+         Parameters
+         ----------
+         Same as `fit_from_directory`, but instead of loading data from directories, this method accepts preloaded and
+         optionally preprocessed datasets directly.
+
+         Notes
+         -----
+         - This method provides an alternative way to train the model when data is already loaded into memory, offering
+           flexibility for data preprocessing steps outside this class.
+
+         """
+
+         # Lazy import for TensorFlow loss class
+         if loss_class is None:
+             loss_class = CategoricalCrossentropy(from_logits=False)
+
+         self.normalize = normalize
+         if not hasattr(self, "normalization_percentile"):
+             self.normalization_percentile = normalization_percentile
+         if not hasattr(self, "normalization_values"):
+             self.normalization_values = normalization_values
+         if not hasattr(self, "normalization_clip"):
+             self.normalization_clip = normalization_clip
+         (
+             self.normalization_percentile,
+             self.normalization_values,
+             self.normalization_clip,
+         ) = _interpret_normalization_parameters(
+             self.n_channels,
+             self.normalization_percentile,
+             self.normalization_values,
+             self.normalization_clip,
+         )
+
+         self.x_train = x_train
+         self.y_class_train = y_class_train
+         self.y_time_train = y_time_train
+         self.channel_option = channel_option
+
+         assert self.n_channels == len(
+             self.channel_option
+         ), f"Mismatch between the channel option and the number of channels of the model..."
+
+         if pad:
+             self.x_train = pad_to_model_length(self.x_train, self.model_signal_length)
+
+         assert self.x_train.shape[1:] == (
+             self.model_signal_length,
+             self.n_channels,
+         ), f"Shape mismatch between the provided training fluorescence signals and the model..."
+
+         # If y-class is not one-hot encoded, encode it
+         if self.y_class_train.shape[-1] != self.n_classes:
+             self.class_weights = compute_weights(
+                 y=self.y_class_train,
+                 class_weight="balanced",
+                 classes=np.unique(self.y_class_train),
+             )
+             self.y_class_train = to_categorical(self.y_class_train, num_classes=3)
+
+         if self.normalize:
+             self.y_time_train = (
+                 self.y_time_train.astype(np.float32) / self.model_signal_length
+             )
+             self.x_train = normalize_signal_set(
+                 self.x_train,
+                 self.channel_option,
+                 normalization_percentile=self.normalization_percentile,
+                 normalization_values=self.normalization_values,
+                 normalization_clip=self.normalization_clip,
+             )
+
+         if validation_data is not None:
+             try:
+                 self.x_val = validation_data[0]
+                 if pad:
+                     self.x_val = pad_to_model_length(
+                         self.x_val, self.model_signal_length
+                     )
+                 self.y_class_val = validation_data[1]
+                 if self.y_class_val.shape[-1] != self.n_classes:
+                     self.y_class_val = to_categorical(self.y_class_val, num_classes=3)
+                 self.y_time_val = validation_data[2]
+                 if self.normalize:
+                     self.y_time_val = (
+                         self.y_time_val.astype(np.float32) / self.model_signal_length
+                     )
+                     self.x_val = normalize_signal_set(
+                         self.x_val,
+                         self.channel_option,
+                         normalization_percentile=self.normalization_percentile,
+                         normalization_values=self.normalization_values,
+                         normalization_clip=self.normalization_clip,
+                     )
+
+             except Exception as e:
+                 print(f"Could not load validation data, error {e}...")
+         else:
+             self.validation_split = validation_split
+
+         if test_data is not None:
+             try:
+                 self.x_test = test_data[0]
+                 if pad:
+                     self.x_test = pad_to_model_length(
+                         self.x_test, self.model_signal_length
+                     )
+                 self.y_class_test = test_data[1]
+                 if self.y_class_test.shape[-1] != self.n_classes:
+                     self.y_class_test = to_categorical(self.y_class_test, num_classes=3)
+                 self.y_time_test = test_data[2]
+                 if self.normalize:
+                     self.y_time_test = (
+                         self.y_time_test.astype(np.float32) / self.model_signal_length
+                     )
+                     self.x_test = normalize_signal_set(
+                         self.x_test,
+                         self.channel_option,
+                         normalization_percentile=self.normalization_percentile,
+                         normalization_values=self.normalization_values,
+                         normalization_clip=self.normalization_clip,
+                     )
+             except Exception as e:
+                 print(f"Could not load test data, error {e}...")
+
+         self.batch_size = batch_size
+         self.epochs = epochs
+         self.augment = augment
+         self.augmentation_factor = augmentation_factor
+         if self.augmentation_factor == 1:
+             self.augment = False
+         self.model_name = model_name
+         self.target_directory = target_directory
+         self.model_folder = os.sep.join([self.target_directory, self.model_name])
+         self.recompile_pretrained = recompile_pretrained
+         self.learning_rate = learning_rate
+         self.loss_reg = loss_reg
+         self.loss_class = loss_class
+
+         self.train_generic()
+
+     def train_generic(self):
+
+         if not os.path.exists(self.model_folder):
+             os.mkdir(self.model_folder)
+
+         self.train_classifier()
+         self.train_regressor()
+
+         config_input = {
+             "channels": self.channel_option,
+             "model_signal_length": self.model_signal_length,
+             "label": self.label,
+             "normalize": self.normalize,
+             "normalization_percentile": self.normalization_percentile,
+             "normalization_values": self.normalization_values,
+             "normalization_clip": self.normalization_clip,
+         }
+         json_string = json.dumps(config_input)
+         with open(
+             os.sep.join([self.model_folder, "config_input.json"]), "w"
+         ) as outfile:
+             outfile.write(json_string)
+
+         # Free memory by clearing large training arrays
+         import gc
+
+         for attr in [
+             "x_train",
+             "x_val",
+             "x_test",
+             "x_set",
+             "y_time_train",
+             "y_time_val",
+             "y_time_test",
+             "y_time_set",
+             "y_class_train",
+             "y_class_val",
+             "y_class_test",
+             "y_class_set",
+         ]:
+             if hasattr(self, attr):
+                 delattr(self, attr)
+         gc.collect()
+
+     def predict_class(
+         self, x, normalize=True, pad=True, return_one_hot=False, interpolate=True
+     ):
+         """
+         Predicts the class of input signals using the trained classification model.
+
+         Parameters
+         ----------
+         x : ndarray
+             The input signals for which to predict classes.
+         normalize : bool, optional
+             Whether to normalize the input signals (default is True).
+         pad : bool, optional
+             Whether to pad the input signals to match the model's expected signal length (default is True).
+         return_one_hot : bool, optional
+             Whether to return predictions in one-hot encoded format (default is False).
+         interpolate : bool, optional
+             Whether to interpolate the input signals (default is True).
+
+         Returns
+         -------
+         ndarray
+             The predicted classes for the input signals. If `return_one_hot` is True, predictions are returned in one-hot
+             encoded format, otherwise as integer labels.
+
+         Notes
+         -----
+         - The method processes the input signals according to the specified options to ensure compatibility with the model's
+           input requirements.
+
+         """
+
+         self.x = np.copy(x)
+         self.normalize = normalize
+         self.pad = pad
+         self.return_one_hot = return_one_hot
+         # self.max_relevant_time = np.shape(self.x)[1]
+         # print(f'Max relevant time: {self.max_relevant_time}')
+
+         if self.pad:
+             self.x = pad_to_model_length(self.x, self.model_signal_length)
+
+         if self.normalize:
+             self.x = normalize_signal_set(
+                 self.x,
+                 self.channel_option,
+                 normalization_percentile=self.normalization_percentile,
+                 normalization_values=self.normalization_values,
+                 normalization_clip=self.normalization_clip,
+             )
+
+         # implement auto interpolation here!!
+         # self.x = self.interpolate_signals(self.x)
+
+         # for i in range(5):
+         #     plt.plot(self.x[i,:,0])
+         #     plt.show()
+
+         try:
+             n_channels = self.model_class.layers[0].input_shape[0][-1]
+             model_signal_length = self.model_class.layers[0].input_shape[0][-2]
+         except AttributeError:
+             n_channels = self.model_class.input_shape[-1]
+             model_signal_length = self.model_class.input_shape[-2]
+
+         assert (
+             self.x.shape[-1] == n_channels
+         ), f"Shape mismatch between the input shape and the model input shape..."
+         assert (
+             self.x.shape[-2] == model_signal_length
+         ), f"Shape mismatch between the input shape and the model input shape..."
+
+         self.class_predictions_one_hot = self.model_class.predict(self.x)
+         self.class_predictions = self.class_predictions_one_hot.argmax(axis=1)
+
+         if self.return_one_hot:
+             return self.class_predictions_one_hot
+         else:
+             return self.class_predictions
+
+     def predict_time_of_interest(
+         self, x, class_predictions=None, normalize=True, pad=True
+     ):
+         """
+         Predicts the time of interest for input signals using the trained regression model.
+
+         Parameters
+         ----------
+         x : ndarray
+             The input signals for which to predict times of interest.
+         class_predictions : ndarray, optional
+             The predicted classes for the input signals. If provided, time of interest predictions are only made for
+             signals predicted to belong to a specific class (default is None).
+         normalize : bool, optional
+             Whether to normalize the input signals (default is True).
+         pad : bool, optional
+             Whether to pad the input signals to match the model's expected signal length (default is True).
+
+         Returns
+         -------
+         ndarray
+             The predicted times of interest for the input signals.
+
+         Notes
+         -----
+         - The method processes the input signals according to the specified options and uses the regression model to
+           predict times at which a particular event of interest occurs.
+
+         """
+
+         self.x = np.copy(x)
+         self.normalize = normalize
+         self.pad = pad
+         # self.max_relevant_time = np.shape(self.x)[1]
+         # print(f'Max relevant time: {self.max_relevant_time}')
+
+         if class_predictions is not None:
+             self.class_predictions = class_predictions
+
+         if self.pad:
+             self.x = pad_to_model_length(self.x, self.model_signal_length)
+
+         if self.normalize:
+             self.x = normalize_signal_set(
+                 self.x,
+                 self.channel_option,
+                 normalization_percentile=self.normalization_percentile,
+                 normalization_values=self.normalization_values,
+                 normalization_clip=self.normalization_clip,
+             )
+
+         try:
+             n_channels = self.model_reg.layers[0].input_shape[0][-1]
+             model_signal_length = self.model_reg.layers[0].input_shape[0][-2]
+         except AttributeError:
+             n_channels = self.model_reg.input_shape[-1]
+             model_signal_length = self.model_reg.input_shape[-2]
+
+         assert (
+             self.x.shape[-1] == n_channels
+         ), f"Shape mismatch between the input shape and the model input shape..."
+         assert (
+             self.x.shape[-2] == model_signal_length
+         ), f"Shape mismatch between the input shape and the model input shape..."
+
+         if np.any(self.class_predictions == 0):
+             self.time_predictions = (
+                 self.model_reg.predict(self.x[self.class_predictions == 0])
+                 * self.model_signal_length
+             )
+             self.time_predictions = self.time_predictions[:, 0]
+             self.time_predictions_recast = np.zeros(len(self.x)) - 1.0
+             self.time_predictions_recast[self.class_predictions == 0] = (
+                 self.time_predictions
+             )
+         else:
+             self.time_predictions_recast = np.zeros(len(self.x)) - 1.0
+         return self.time_predictions_recast
+
+     def interpolate_signals(self, x_set):
+         """
+         Interpolates missing values in the input signal set.
+
+         Parameters
+         ----------
+         x_set : ndarray
+             The input signal set with potentially missing values.
+
+         Returns
+         -------
+         ndarray
+             The input signal set with missing values interpolated.
+
+         Notes
+         -----
+         - This method is useful for preparing signals that have gaps or missing time points before further processing
+           or model training.
+
+         """
+
+         for i in range(len(x_set)):
+             for k in range(x_set.shape[-1]):
+                 x = x_set[i, :, k]
+                 not_nan = np.logical_not(np.isnan(x))
+                 indices = np.arange(len(x))
+                 interp = interp1d(
+                     indices[not_nan],
+                     x[not_nan],
+                     fill_value=(0.0, 0.0),
+                     bounds_error=False,
+                 )
+                 x_set[i, :, k] = interp(indices)
+         return x_set
+
+     def train_classifier(self):
+         """
+         Trains the classifier component of the model to predict event classes in signals.
+
+         This method compiles the classifier model (if not pretrained or if recompilation is requested) and
+         trains it on the prepared dataset. The training process includes validation and early stopping based
+         on precision to prevent overfitting.
+
+         Notes
+         -----
+         - The classifier model predicts the class of each signal, such as live, dead, or miscellaneous.
+         - Training parameters such as epochs, batch size, and learning rate are specified during class instantiation.
+         - Model performance metrics and training history are saved for analysis.
+
+         """
+
+         # if pretrained model
+         self.n_classes = 3
+
+         if self.pretrained is not None:
+             # if recompile
+             if self.recompile_pretrained:
+                 print(
+                     "Recompiling the pretrained classifier model... Warning, this action reinitializes all the weights; are you sure that this is what you intended?"
+                 )
+                 self.model_class.set_weights(
+                     clone_model(self.model_class).get_weights()
+                 )
+                 self.model_class.compile(
+                     optimizer=Adam(learning_rate=self.learning_rate),
+                     loss=self.loss_class,
+                     metrics=[
+                         "accuracy",
+                         Precision(),
+                         Recall(),
+                         MeanIoU(
+                             num_classes=self.n_classes,
+                             name="iou",
+                             dtype=float,
+                             sparse_y_true=False,
+                             sparse_y_pred=False,
+                         ),
+                     ],
+                 )
+             else:
+                 # Recompile to avoid crash
+                 self.model_class.compile(
+                     optimizer=Adam(learning_rate=self.learning_rate),
+                     loss=self.loss_class,
+                     metrics=[
+                         "accuracy",
+                         Precision(),
+                         Recall(),
+                         MeanIoU(
+                             num_classes=self.n_classes,
+                             name="iou",
+                             dtype=float,
+                             sparse_y_true=False,
+                             sparse_y_pred=False,
+                         ),
+                     ],
+                 )
+
+         else:
+             print("Compiling the classifier...")
+             self.model_class.compile(
+                 optimizer=Adam(learning_rate=self.learning_rate),
+                 loss=self.loss_class,
+                 metrics=[
+                     "accuracy",
+                     Precision(),
+                     Recall(),
+                     MeanIoU(
+                         num_classes=self.n_classes,
+                         name="iou",
+                         dtype=float,
+                         sparse_y_true=False,
+                         sparse_y_pred=False,
+                     ),
+                 ],
+             )
+
+         self.gather_callbacks("classifier")
+
+         # for i in range(30):
+         #     for j in range(self.x_train.shape[-1]):
+         #         plt.plot(self.x_train[i,:,j])
+         #     plt.show()
+
+         if hasattr(self, "x_val"):
+
+             self.history_classifier = self.model_class.fit(
+                 x=self.x_train,
+                 y=self.y_class_train,
+                 batch_size=self.batch_size,
+                 class_weight=self.class_weights,
+                 epochs=self.epochs,
+                 validation_data=(self.x_val, self.y_class_val),
+                 callbacks=self.cb,
+                 verbose=1,
+             )
+         else:
+             self.history_classifier = self.model_class.fit(
+                 x=self.x_train,
+                 y=self.y_class_train,
+                 batch_size=self.batch_size,
+                 class_weight=self.class_weights,
+                 epochs=self.epochs,
+                 callbacks=self.cb,
+                 validation_split=self.validation_split,
+                 verbose=1,
+             )
+
+         if self.show_plots:
+             self.plot_model_history(mode="classifier")
+
+         # Set current classification model as the best model
+         self.model_class = load_model(
+             os.sep.join([self.model_folder, "classifier.h5"]),
+             custom_objects={"mse": MeanSquaredError()},
+         )
+         self.model_class.load_weights(os.sep.join([self.model_folder, "classifier.h5"]))
+
+         time_callback = next(
+             (cb for cb in self.cb if type(cb).__name__ == "TimeHistory"), None
+         )
+         self.dico = {
+             "history_classifier": self.history_classifier,
+             "execution_time_classifier": time_callback.times if time_callback else [],
+         }
+
+         if hasattr(self, "x_test"):
+
+             predictions = self.model_class.predict(self.x_test).argmax(axis=1)
+             ground_truth = self.y_class_test.argmax(axis=1)
+             assert (
+                 predictions.shape == ground_truth.shape
+             ), "Mismatch in shape between the predictions and the ground truth..."
+
+             title = "Test data"
+             IoU_score = jaccard_score(ground_truth, predictions, average=None)
+             balanced_accuracy = balanced_accuracy_score(ground_truth, predictions)
+             precision = precision_score(ground_truth, predictions, average=None)
+             recall = recall_score(ground_truth, predictions, average=None)
+
+             print(f"Test IoU score: {IoU_score}")
+             print(f"Test Balanced accuracy score: {balanced_accuracy}")
+             print(f"Test Precision: {precision}")
+             print(f"Test Recall: {recall}")
+
+             # Confusion matrix on test set
+             results = confusion_matrix(ground_truth, predictions)
+             self.dico.update(
+                 {
+                     "test_IoU": IoU_score,
+                     "test_balanced_accuracy": balanced_accuracy,
+                     "test_confusion": results,
+                     "test_precision": precision,
+                     "test_recall": recall,
+                 }
+             )
+
+             if self.show_plots:
+                 try:
+                     ConfusionMatrixDisplay.from_predictions(
+                         ground_truth,
+                         predictions,
+                         cmap="Blues",
+                         normalize="pred",
+                         display_labels=["event", "no event", "left censored"],
+                     )
+                     plt.savefig(
+                         os.sep.join([self.model_folder, "test_confusion_matrix.png"]),
+                         bbox_inches="tight",
+                         dpi=300,
+                     )
+                     # plt.pause(3)
+                     plt.close()
+                 except Exception as e:
+                     print(e)
+                     pass
+             print("Test set: ", classification_report(ground_truth, predictions))
+
+         if hasattr(self, "x_val"):
+             predictions = self.model_class.predict(self.x_val).argmax(axis=1)
+             ground_truth = self.y_class_val.argmax(axis=1)
+             assert (
+                 ground_truth.shape == predictions.shape
+             ), "Mismatch in shape between the predictions and the ground truth..."
+             title = "Validation data"
+
+             # Validation scores
+             IoU_score = jaccard_score(ground_truth, predictions, average=None)
+             balanced_accuracy = balanced_accuracy_score(ground_truth, predictions)
+             precision = precision_score(ground_truth, predictions, average=None)
+             recall = recall_score(ground_truth, predictions, average=None)
+
+             print(f"Validation IoU score: {IoU_score}")
+             print(f"Validation Balanced accuracy score: {balanced_accuracy}")
+             print(f"Validation Precision: {precision}")
+             print(f"Validation Recall: {recall}")
+
+             # Confusion matrix on validation set
+             results = confusion_matrix(ground_truth, predictions)
+             self.dico.update(
+                 {
+                     "val_IoU": IoU_score,
+                     "val_balanced_accuracy": balanced_accuracy,
+                     "val_confusion": results,
+                     "val_precision": precision,
+                     "val_recall": recall,
+                 }
+             )
+
+             if self.show_plots:
+                 try:
+                     ConfusionMatrixDisplay.from_predictions(
+                         ground_truth,
+                         predictions,
+                         cmap="Blues",
+                         normalize="pred",
+                         display_labels=["event", "no event", "left censored"],
+                     )
+                     plt.savefig(
+                         os.sep.join(
+                             [self.model_folder, "validation_confusion_matrix.png"]
+                         ),
+                         bbox_inches="tight",
+                         dpi=300,
+                     )
+                     # plt.pause(3)
+                     plt.close()
+                 except Exception as e:
+                     print(e)
+                     pass
+             print("Validation set: ", classification_report(ground_truth, predictions))
+
+         # Send result to GUI and wait
+         for cb in self.cb:
+             if hasattr(cb, "on_training_result"):
+                 cb.on_training_result(self.dico)
+                 time.sleep(3)
+
+     def train_regressor(self):
+         """
+         Trains the regressor component of the model to estimate the time of interest for events in signals.
+
+         This method compiles the regressor model (if not pretrained or if recompilation is requested) and
+         trains it on a subset of the prepared dataset containing signals with events. The training process
+         includes validation and early stopping based on mean squared error to prevent overfitting.
+
+         Notes
+         -----
+         - The regressor model estimates the time at which an event of interest occurs within each signal.
+         - Only signals predicted to have an event by the classifier model are used for regressor training.
+         - Model performance metrics and training history are saved for analysis.
+
+         """
+
+         # Compile model
+         # if pretrained model
+         if self.pretrained is not None:
+             # if recompile
+             if self.recompile_pretrained:
+                 print(
+                     "Recompiling the pretrained regressor model... Warning, this action reinitializes all the weights; are you sure that this is what you intended?"
+                 )
+                 self.model_reg.set_weights(clone_model(self.model_reg).get_weights())
+                 self.model_reg.compile(
+                     optimizer=Adam(learning_rate=self.learning_rate),
+                     loss=self.loss_reg,
+                     metrics=["mse", "mae"],
+                 )
+             else:
+                 self.model_reg.compile(
+                     optimizer=Adam(learning_rate=self.learning_rate),
+                     loss=self.loss_reg,
+                     metrics=["mse", "mae"],
+                 )
+
+         else:
+             print("Compiling the regressor...")
+             self.model_reg.compile(
+                 optimizer=Adam(learning_rate=self.learning_rate),
+                 loss=self.loss_reg,
+                 metrics=["mse", "mae"],
+             )
+
+         self.gather_callbacks("regressor")
+
+         # Train on subset of data with event
+
+         subset = self.x_train[np.argmax(self.y_class_train, axis=1) == 0]
+         # for i in range(30):
+         #     plt.plot(subset[i,:,0],c="tab:red")
+         #     plt.plot(subset[i,:,1],c="tab:blue")
+         #     plt.show()
+
+         if hasattr(self, "x_val"):
+             self.history_regressor = self.model_reg.fit(
+                 x=self.x_train[np.argmax(self.y_class_train, axis=1) == 0],
+                 y=self.y_time_train[np.argmax(self.y_class_train, axis=1) == 0],
+                 batch_size=self.batch_size,
+                 epochs=self.epochs * 2,
+                 validation_data=(
+                     self.x_val[np.argmax(self.y_class_val, axis=1) == 0],
+                     self.y_time_val[np.argmax(self.y_class_val, axis=1) == 0],
+                 ),
+                 callbacks=self.cb,
+                 verbose=1,
+             )
+         else:
+             self.history_regressor = self.model_reg.fit(
+                 x=self.x_train[np.argmax(self.y_class_train, axis=1) == 0],
+                 y=self.y_time_train[np.argmax(self.y_class_train, axis=1) == 0],
+                 batch_size=self.batch_size,
+                 epochs=self.epochs * 2,
+                 callbacks=self.cb,
+                 validation_split=self.validation_split,
+                 verbose=1,
+             )
+
+         if self.show_plots:
+             self.plot_model_history(mode="regressor")
+         time_callback = next(
+             (cb for cb in self.cb if type(cb).__name__ == "TimeHistory"), None
+         )
+         self.dico.update(
+             {
+                 "history_regressor": self.history_regressor,
+                 "execution_time_regressor": (
+                     time_callback.times if time_callback else []
+                 ),
+             }
+         )
+
+         # Evaluate best model
+         self.model_reg = load_model(
+             os.sep.join([self.model_folder, "regressor.h5"]),
+             custom_objects={"mse": MeanSquaredError()},
+         )
+         self.model_reg.load_weights(os.sep.join([self.model_folder, "regressor.h5"]))
+         self.evaluate_regression_model()
+
+         try:
+             np.save(os.sep.join([self.model_folder, "scores.npy"]), self.dico)
+         except Exception as e:
+             print(e)
+
+     def plot_model_history(self, mode="regressor"):
+         """
+         Generates and saves plots of the training history for the classifier or regressor model.
+
+         Parameters
+         ----------
+         mode : str, optional
+             Specifies which model's training history to plot. Options are "classifier" or "regressor". Default is "regressor".
+
+         Notes
+         -----
+         - Plots include loss and accuracy metrics over epochs for the classifier, and loss metrics for the regressor.
+         - The plots are saved as image files in the model's output directory.
+
+         """
+
+         if mode == "regressor":
+             try:
+                 plt.plot(self.history_regressor.history["loss"])
+                 plt.plot(self.history_regressor.history["val_loss"])
+                 plt.title("model loss")
+                 plt.ylabel("loss")
+                 plt.xlabel("epoch")
+                 plt.yscale("log")
+                 plt.legend(["train", "val"], loc="upper left")
+                 # plt.pause(3)
+                 plt.savefig(
+                     os.sep.join([self.model_folder, "regression_loss.png"]),
+                     bbox_inches="tight",
+                     dpi=300,
+                 )
+                 plt.close()
+             except Exception as e:
+                 print(f"Error {e}; could not generate plot...")
+         elif mode == "classifier":
+             try:
+                 plt.plot(self.history_classifier.history["precision"])
+                 plt.plot(self.history_classifier.history["val_precision"])
+                 plt.title("model precision")
+                 plt.ylabel("precision")
+                 plt.xlabel("epoch")
+                 plt.legend(["train", "val"], loc="upper left")
+                 # plt.pause(3)
+                 plt.savefig(
+                     os.sep.join([self.model_folder, "classification_loss.png"]),
+                     bbox_inches="tight",
+                     dpi=300,
+                 )
+                 plt.close()
+             except Exception as e:
+                 print(f"Error {e}; could not generate plot...")
+         else:
+             return None
+
+     def evaluate_regression_model(self):
+         """
+         Evaluates the performance of the trained regression model on test and validation datasets.
+
+         This method calculates and prints mean squared error and mean absolute error metrics for the regression model's
+         predictions. It also generates regression plots comparing predicted times of interest to true values.
+
+         Notes
+         -----
+         - Evaluation is performed on both test and validation datasets, if available.
+         - Regression plots and performance metrics are saved in the model's output directory.
+
+         """
+         mse = MeanSquaredError()
+         mae = MeanAbsoluteError()
+
+         if hasattr(self, "x_test"):
+
+             print("Evaluate on test set...")
+             predictions = self.model_reg.predict(
+                 self.x_test[np.argmax(self.y_class_test, axis=1) == 0],
+                 batch_size=self.batch_size,
+             )[:, 0]
+             ground_truth = self.y_time_test[np.argmax(self.y_class_test, axis=1) == 0]
+             assert (
+                 predictions.shape == ground_truth.shape
+             ), "Shape mismatch between predictions and ground truths..."
+
+             test_mse = mse(ground_truth, predictions).numpy()
+             test_mae = mae(ground_truth, predictions).numpy()
+             print(f"MSE on test set: {test_mse}...")
+             print(f"MAE on test set: {test_mae}...")
+             if self.show_plots:
+                 regression_plot(
+                     predictions,
+                     ground_truth,
+                     savepath=os.sep.join([self.model_folder, "test_regression.png"]),
+                 )
+             self.dico.update({"test_mse": test_mse, "test_mae": test_mae})
+
+         if hasattr(self, "x_val"):
+             # Validation set
+             predictions = self.model_reg.predict(
+                 self.x_val[np.argmax(self.y_class_val, axis=1) == 0],
+                 batch_size=self.batch_size,
+             )[:, 0]
+             ground_truth = self.y_time_val[np.argmax(self.y_class_val, axis=1) == 0]
+             assert (
+                 predictions.shape == ground_truth.shape
+             ), "Shape mismatch between predictions and ground truths..."
+
+             val_mse = mse(ground_truth, predictions).numpy()
+             val_mae = mae(ground_truth, predictions).numpy()
+
+             if self.show_plots:
+                 regression_plot(
+                     predictions,
+                     ground_truth,
+                     savepath=os.sep.join(
+                         [self.model_folder, "validation_regression.png"]
+                     ),
+                 )
+             print(f"MSE on validation set: {val_mse}...")
+             print(f"MAE on validation set: {val_mae}...")
+
+             self.dico.update(
+                 {
+                     "val_mse": val_mse,
+                     "val_mae": val_mae,
+                     "val_predictions": predictions,
+                     "val_ground_truth": ground_truth,
+                 }
+             )
+
+         # Send result to GUI and wait
+         for cb in self.cb:
+             if hasattr(cb, "on_training_result"):
+                 cb.on_training_result(self.dico)
+                 time.sleep(3)
+
+     def gather_callbacks(self, mode):
+         """
+         Prepares a list of Keras callbacks for model training based on the specified mode.
+
+         Parameters
+         ----------
+         mode : str
+             The training mode for which callbacks are being prepared. Options are "classifier" or "regressor".
+
+         Notes
+         -----
+         - Callbacks include learning rate reduction on plateau, early stopping, model checkpointing, and TensorBoard logging.
+         - The list of callbacks is stored in the class attribute `cb` and used during model training.
+
+         """
+
+         self.cb = []
+
+         if mode == "classifier":
+
+             reduce_lr = ReduceLROnPlateau(
+                 monitor="val_iou",
+                 factor=0.5,
+                 patience=30,
+                 cooldown=10,
+                 min_lr=5e-10,
+                 min_delta=1.0e-10,
+                 verbose=1,
+                 mode="max",
+             )
+             self.cb.append(reduce_lr)
+             csv_logger = CSVLogger(
+                 os.sep.join([self.model_folder, "log_classifier.csv"]),
+                 append=True,
+                 separator=";",
+             )
+             self.cb.append(csv_logger)
+             checkpoint_path = os.sep.join([self.model_folder, "classifier.h5"])
+             cp_callback = ModelCheckpoint(
+                 checkpoint_path,
+                 monitor="val_iou",
+                 mode="max",
+                 verbose=1,
+                 save_best_only=True,
+                 save_weights_only=False,
+                 save_freq="epoch",
+             )
+             self.cb.append(cp_callback)
+
+             callback_stop = EarlyStopping(monitor="val_iou", mode="max", patience=100)
+             self.cb.append(callback_stop)
+
+         elif mode == "regressor":
+
+             reduce_lr = ReduceLROnPlateau(
+                 monitor="val_loss",
+                 factor=0.5,
+                 patience=30,
+                 cooldown=10,
+                 min_lr=5e-10,
+                 min_delta=1.0e-10,
+                 verbose=1,
+                 mode="min",
+             )
+             self.cb.append(reduce_lr)
+
+             csv_logger = CSVLogger(
+                 os.sep.join([self.model_folder, "log_regressor.csv"]),
+                 append=True,
+                 separator=";",
+             )
+             self.cb.append(csv_logger)
+
+             checkpoint_path = os.sep.join([self.model_folder, "regressor.h5"])
+             cp_callback = ModelCheckpoint(
+                 checkpoint_path,
+                 monitor="val_loss",
+                 mode="min",
+                 verbose=1,
+                 save_best_only=True,
+                 save_weights_only=False,
+                 save_freq="epoch",
+             )
+             self.cb.append(cp_callback)
+
+             callback_stop = EarlyStopping(monitor="val_loss", mode="min", patience=200)
+             self.cb.append(callback_stop)
+
+         log_dir = self.model_folder + os.sep
+         cb_tb = TensorBoard(log_dir=log_dir, update_freq="batch")
+         self.cb.append(cb_tb)
+
+         cb_time = TimeHistory()
+         self.cb.append(cb_time)
+
+         if hasattr(self, "callbacks") and self.callbacks is not None:
+             self.cb.extend(self.callbacks)
+
+     def prepare_sets(self):
+         """
+         Generates and preprocesses training, validation, and test sets from loaded annotations.
+
+         This method loads signal data from annotation files, normalizes and interpolates the signals, and splits
+         the dataset into training, validation, and test sets according to specified proportions.
+
+         Notes
+         -----
+         - Signal annotations are expected to be stored in .npy format and contain required channels and event information.
+         - The method applies specified normalization and interpolation options to prepare the signals for model training.
+
+         """
+
+         self.x_set = []
+         self.y_time_set = []
+         self.y_class_set = []
+
+         if isinstance(self.list_of_sets[0], str):
+             # Case 1: a list of npy files to be loaded
+             for s in self.list_of_sets:
+
+                 signal_dataset = self.load_set(s)
+                 selected_signals, max_length = self.find_best_signal_match(
+                     signal_dataset
+                 )
+                 signals_recast, classes, times_of_interest = (
+                     self.cast_signals_into_training_data(
+                         signal_dataset, selected_signals, max_length
+                     )
+                 )
+                 signals_recast, times_of_interest = self.normalize_signals(
+                     signals_recast, times_of_interest
+                 )
+
+                 self.x_set.extend(signals_recast)
+                 self.y_time_set.extend(times_of_interest)
+                 self.y_class_set.extend(classes)
+
+         elif isinstance(self.list_of_sets[0], list):
+             # Case 2: a list of sets (already loaded)
+             for signal_dataset in self.list_of_sets:
+
+                 selected_signals, max_length = self.find_best_signal_match(
+                     signal_dataset
+                 )
+                 signals_recast, classes, times_of_interest = (
+                     self.cast_signals_into_training_data(
+                         signal_dataset, selected_signals, max_length
+                     )
+                 )
+                 signals_recast, times_of_interest = self.normalize_signals(
+                     signals_recast, times_of_interest
+                 )
+
+                 self.x_set.extend(signals_recast)
+                 self.y_time_set.extend(times_of_interest)
+                 self.y_class_set.extend(classes)
+
+         self.x_set = np.array(self.x_set).astype(np.float32)
+         self.x_set = self.interpolate_signals(self.x_set)
+
+         self.y_time_set = np.array(self.y_time_set).astype(np.float32)
+         self.y_class_set = np.array(self.y_class_set).astype(np.float32)
+
+         class_test = np.isin(self.y_class_set, [0, 1, 2])
+         self.x_set = self.x_set[class_test]
+         self.y_time_set = self.y_time_set[class_test]
+         self.y_class_set = self.y_class_set[class_test]
+
+         # Compute class weights and one-hot encode
+         self.class_weights = compute_weights(self.y_class_set)
+         self.nbr_classes = 3  # len(np.unique(self.y_class_set))
+         self.y_class_set = to_categorical(self.y_class_set, num_classes=3)
+
+         ds = train_test_split(
+             self.x_set,
+             self.y_time_set,
+             self.y_class_set,
+             validation_size=self.validation_split,
+             test_size=self.test_split,
+         )
+
+         self.x_train = ds["x_train"]
+         self.x_val = ds["x_val"]
+         self.y_time_train = ds["y1_train"].astype(np.float32)
+         self.y_time_val = ds["y1_val"].astype(np.float32)
+         self.y_class_train = ds["y2_train"]
+         self.y_class_val = ds["y2_val"]
+
+         if self.test_split > 0:
+             self.x_test = ds["x_test"]
+             self.y_time_test = ds["y1_test"].astype(np.float32)
+             self.y_class_test = ds["y2_test"]
+
+         if self.augment:
+             self.augment_training_set()
+
1569
+ def augment_training_set(self, time_shift=True):
1570
+ """
1571
+ Augments the training dataset with artificially generated data to increase model robustness.
1572
+
1573
+ Parameters
1574
+ ----------
1575
+ time_shift : bool, optional
1576
+ Specifies whether to include time-shifted versions of signals in the augmented dataset. Default is True.
1577
+
1578
+ Notes
1579
+ -----
1580
+ - Augmentation strategies include random time shifting and signal modifications to simulate variations in real data.
1581
+ - The augmented dataset is used for training the classifier and regressor models to improve generalization.
1582
+
1583
+ """
1584
+
1585
+ nbr_augment = int(self.augmentation_factor * len(self.x_train))  # random.choices requires an integer count
1586
+ randomize = np.arange(len(self.x_train))
1587
+
1588
+ unique, counts = np.unique(
1589
+ self.y_class_train.argmax(axis=1), return_counts=True
1590
+ )
1591
+ frac = counts / sum(counts)
1592
+ weights = [frac[0] / f for f in frac]
1593
+ weights[0] = weights[0] * 3  # oversample the event class (class 0)
1594
+
1595
+ self.pre_augment_weights = weights / sum(weights)
1596
+ weights_array = [
1597
+ self.pre_augment_weights[a.argmax()] for a in self.y_class_train
1598
+ ]
1599
+
1600
+ indices = random.choices(randomize, k=nbr_augment, weights=weights_array)
1601
+
1602
+ x_train_aug = []
1603
+ y_time_train_aug = []
1604
+ y_class_train_aug = []
1605
+
1606
+ counts = [0.0, 0.0, 0.0]
1607
+ # Warning: augmentation can create class-2 samples even when that class is absent from the data; this still needs to be addressed.
1608
+ for k in indices:
1609
+ counts[self.y_class_train[k].argmax()] += 1
1610
+ aug = augmenter(
1611
+ self.x_train[k],
1612
+ self.y_time_train[k],
1613
+ self.y_class_train[k],
1614
+ self.model_signal_length,
1615
+ time_shift=time_shift,
1616
+ )
1617
+ x_train_aug.append(aug[0])
1618
+ y_time_train_aug.append(aug[1])
1619
+ y_class_train_aug.append(aug[2])
1620
+
1621
+ # Save augmented training set
1622
+ self.x_train = np.array(x_train_aug)
1623
+ self.y_time_train = np.array(y_time_train_aug)
1624
+ self.y_class_train = np.array(y_class_train_aug)
1625
+
1626
+ self.class_weights = compute_weights(self.y_class_train.argmax(axis=1))
1627
+ print(f"New class weights: {self.class_weights}...")
1628
+
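The `compute_weights` helper called above is defined elsewhere in the package; a plausible stand-in following the usual "balanced" convention (inverse class frequency), shown only as a sketch of the assumed behaviour:

import numpy as np

def compute_weights_sketch(y):
    # Balanced class weights: n_samples / (n_classes * count(class)).
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}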
1629
+ def load_set(self, signal_dataset):
1630
+ return np.load(signal_dataset, allow_pickle=True)
1631
+
1632
+ def find_best_signal_match(self, signal_dataset):
1633
+
1634
+ required_signals = self.channel_option
1635
+ available_signals = list(signal_dataset[0].keys())
1636
+
1637
+ selected_signals = []
1638
+ for s in required_signals:
1639
+ pattern_test = [s in a for a in available_signals]
1640
+ if np.any(pattern_test):
1641
+ valid_columns = np.array(available_signals)[np.array(pattern_test)]
1642
+ if len(valid_columns) == 1:
1643
+ selected_signals.append(valid_columns[0])
1644
+ else:
1645
+ print(f"Found several candidate signals: {valid_columns}")
1646
+ for vc in natsorted(valid_columns):
1647
+ if "circle" in vc:
1648
+ selected_signals.append(vc)
1649
+ break
1650
+ else:
1651
+ selected_signals.append(valid_columns[0])
1652
+ else:
1653
+ raise ValueError(f"Required signal {s} could not be matched to any available signal in the set...")
1654
+
1655
+ key_to_check = selected_signals[0] # self.channel_option[0]
1656
+ signal_lengths = [len(l[key_to_check]) for l in signal_dataset]
1657
+ max_length = np.amax(signal_lengths)
1658
+
1659
+ return selected_signals, max_length
1660
+
1661
+ def cast_signals_into_training_data(
1662
+ self, signal_dataset, selected_signals, max_length
1663
+ ):
1664
+
1665
+ signals_recast = np.zeros((len(signal_dataset), max_length, self.n_channels))
1666
+ classes = np.zeros(len(signal_dataset))
1667
+ times_of_interest = np.zeros(len(signal_dataset))
1668
+
1669
+ for k in range(len(signal_dataset)):
1670
+
1671
+ for i in range(self.n_channels):
1672
+ try:
1673
+ # take into account timeline for accurate time regression
1674
+
1675
+ if selected_signals[i].startswith("pair_"):
1676
+ timeline = signal_dataset[k]["pair_FRAME"].astype(int)
1677
+ elif selected_signals[i].startswith("reference_"):
1678
+ timeline = signal_dataset[k]["reference_FRAME"].astype(int)
1679
+ elif selected_signals[i].startswith("neighbor_"):
1680
+ timeline = signal_dataset[k]["neighbor_FRAME"].astype(int)
1681
+ else:
1682
+ timeline = signal_dataset[k]["FRAME"].astype(int)
1683
+ signals_recast[k, timeline, i] = signal_dataset[k][
1684
+ selected_signals[i]
1685
+ ]
1686
+ except Exception:  # missing or malformed attribute in the annotation
1687
+ print(
1688
+ f"Attribute {selected_signals[i]} matched to {self.channel_option[i]} not found in annotation..."
1689
+ )
1690
+ pass
1691
+
1692
+ classes[k] = signal_dataset[k]["class"]
1693
+ times_of_interest[k] = signal_dataset[k]["time_of_interest"]
1694
+
1695
+ # Correct absurd times of interest
1696
+ times_of_interest[np.nonzero(classes)] = -1
1697
+ times_of_interest[(times_of_interest <= 0.0)] = -1
1698
+
1699
+ return signals_recast, classes, times_of_interest
1700
+
1701
+ def normalize_signals(self, signals_recast, times_of_interest):
1702
+
1703
+ signals_recast = pad_to_model_length(signals_recast, self.model_signal_length)
1704
+ if self.normalize:
1705
+ signals_recast = normalize_signal_set(
1706
+ signals_recast,
1707
+ self.channel_option,
1708
+ normalization_percentile=self.normalization_percentile,
1709
+ normalization_values=self.normalization_values,
1710
+ normalization_clip=self.normalization_clip,
1711
+ )
1712
+
1713
+ # Trivial normalization for time of interest
1714
+ times_of_interest /= self.model_signal_length
1715
+
1716
+ return signals_recast, times_of_interest
1717
+
1718
+
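To make the preprocessing concrete, here is a small sketch of what this method does to a dummy signal set, using the module-level helpers defined further below (shapes and channel names are invented for illustration):

import numpy as np

signals = np.random.rand(4, 100, 2).astype(np.float32)   # 4 cells, 100 frames, 2 channels
times = np.array([10.0, 25.0, -1.0, 40.0])               # event times in frames (-1: no event)

padded = pad_to_model_length(signals, 128)                # (4, 128, 2), padded by repeating edge values
normed = normalize_signal_set(padded, ["dead_nuclei_channel", "live_nuclei_channel"])
times_scaled = times / 128.0                              # "trivial" normalization of the regression target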
1719
+ def residual_block1D(
1720
+ x, number_of_filters, kernel_size=8, match_filter_size=True, connection="identity"
1721
+ ):
1722
+ """
1723
+
1724
+ Create a 1D residual block.
1725
+
1726
+ Parameters
1727
+ ----------
1728
+ x : Tensor
1729
+ Input tensor.
1730
+ number_of_filters : int
1731
+ Number of filters in the convolutional layers.
1732
+ kernel_size : int, optional
+ Kernel size of the convolutional layers. Default is 8.
+ match_filter_size : bool, optional
1733
+ Whether to match the filter size of the skip connection to the output. Default is True.
+ connection : {"identity", "projection"}, optional
+ The type of skip connection: "identity" preserves the temporal resolution, while "projection" downsamples by a stride of 2. Default is "identity".
1734
+
1735
+ Returns
1736
+ -------
1737
+ Tensor
1738
+ Output tensor of the residual block.
1739
+
1740
+ Notes
1741
+ -----
1742
+ This function creates a 1D residual block by performing the original mapping followed by adding a skip connection
1743
+ and applying non-linear activation. The skip connection allows the gradient to flow directly to earlier layers and
1744
+ helps mitigate the vanishing gradient problem. The residual block consists of two convolutional layers with
1745
+ batch normalization and ReLU activations, plus an optional 1x1 convolution that adapts the skip connection.
1746
+
1747
+ If `match_filter_size` is True, the skip connection is adjusted to have the same number of filters as the output.
1748
+ Otherwise, the skip connection is kept as is.
1749
+
1750
+ Examples
1751
+ --------
1752
+ >>> inputs = Input(shape=(10, 3))
1753
+ >>> x = residual_block1D(inputs, 64)
1754
+ # Create a 1D residual block with 64 filters and apply it to the input tensor.
1755
+
1756
+ """
1757
+
1758
+ # Create skip connection
1759
+ x_skip = x
1760
+
1761
+ # Perform the original mapping
1762
+ if connection == "identity":
1763
+ x = Conv1D(
1764
+ number_of_filters, kernel_size=kernel_size, strides=1, padding="same"
1765
+ )(x_skip)
1766
+ elif connection == "projection":
1767
+ x = ZeroPadding1D(padding=kernel_size // 2)(x_skip)
1768
+ x = Conv1D(
1769
+ number_of_filters, kernel_size=kernel_size, strides=2, padding="valid"
1770
+ )(x)
1771
+ x = BatchNormalization()(x)
1772
+ x = Activation("relu")(x)
1773
+
1774
+ x = Conv1D(number_of_filters, kernel_size=kernel_size, strides=1, padding="same")(x)
1775
+ x = BatchNormalization()(x)
1776
+
1777
+ if match_filter_size and connection == "identity":
1778
+ x_skip = Conv1D(number_of_filters, kernel_size=1, padding="same")(x_skip)
1779
+ elif match_filter_size and connection == "projection":
1780
+ x_skip = Conv1D(number_of_filters, kernel_size=1, strides=2, padding="valid")(
1781
+ x_skip
1782
+ )
1783
+
1784
+ # Add the skip connection to the regular mapping
1785
+ x = Add()([x, x_skip])
1786
+
1787
+ # Nonlinearly activate the result
1788
+ x = Activation("relu")(x)
1789
+
1790
+ # Return the result
1791
+ return x
1792
+
1793
+
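A quick shape check of the two connection modes (a sketch, not part of the module). Note that projection blocks are used with odd kernel sizes below; with the default kernel_size=8 and an even input length, the stride-2 main path and the 1x1 stride-2 skip would produce sequence lengths differing by one, and the Add would fail:

from tensorflow.keras.layers import Input

inp = Input(shape=(128, 2))
h = residual_block1D(inp, 64, kernel_size=7, connection="identity")   # -> (None, 128, 64)
h = residual_block1D(h, 128, kernel_size=7, connection="projection")  # -> (None, 64, 128): stride-2 downsampling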
1794
+ def MultiscaleResNetModel(
1795
+ n_channels,
1796
+ n_classes=3,
1797
+ dropout_rate=0,
1798
+ dense_collection=0,
1799
+ use_pooling=True,
1800
+ header="classifier",
1801
+ model_signal_length=128,
1802
+ ):
1803
+ """
1804
+
1805
+ Define a multiscale ResNet 1D encoder model with three parallel branches (kernel sizes 7, 5 and 3).
1806
+
1807
+ Parameters
1808
+ ----------
1809
+ n_channels : int
1810
+ Number of input channels.
1811
+ use_pooling : bool, optional
1812
+ Whether to use max pooling (kept for signature compatibility; unused by this multiscale architecture). Default is True.
1813
+ n_classes : int, optional
1814
+ Number of output classes. Default is 3.
1815
+ dropout_rate : float, optional
1816
+ Dropout rate to be applied. Default is 0.
1817
+ dense_collection : int, optional
1818
+ Number of neurons in the dense layer. Default is 0.
1819
+ header : str, optional
1820
+ Type of the model header. "classifier" for classification, "regressor" for regression. Default is "classifier".
1821
+ model_signal_length : int, optional
1822
+ Length of the input signal. Default is 128.
1823
+
1824
+ Returns
1825
+ -------
1826
+ keras.models.Model
1827
+ ResNet 1D encoder model.
1828
+
1829
+ Notes
1830
+ -----
1831
+ This function defines a generic ResNet 1D encoder model with the specified number of input channels, residual
1832
+ blocks, output classes, dropout rate, dense collection, and model header. The model architecture follows the
1833
+ ResNet principles with 1D convolutional layers and residual connections. The final activation and number of
1834
+ neurons in the output layer are determined based on the header type.
1835
+
1836
+ Examples
1837
+ --------
1838
+ >>> model = MultiscaleResNetModel(n_channels=3, n_classes=2, dropout_rate=0.2)
1839
+ # Define a multiscale ResNet 1D encoder model with 3 input channels and 2 output classes.
1840
+
1841
+ """
1842
+ if header == "classifier":
1843
+ final_activation = "softmax"
1844
+ neurons_final = n_classes
1845
+ elif header == "regressor":
1846
+ final_activation = "linear"
1847
+ neurons_final = 1
1848
+ else:
1849
+ return None
1850
+
1851
+ inputs = Input(
1852
+ shape=(
1853
+ model_signal_length,
1854
+ n_channels,
1855
+ )
1856
+ )
1857
+ x = ZeroPadding1D(3)(inputs)
1858
+ x = Conv1D(64, kernel_size=7, strides=2, padding="valid", use_bias=False)(x)
1859
+ x = BatchNormalization()(x)
1860
+ x = ZeroPadding1D(1)(x)
1861
+ x_common = MaxPooling1D(pool_size=3, strides=2, padding="valid")(x)
1862
+
1863
+ # Block 1
1864
+ x1 = residual_block1D(x_common, 64, kernel_size=7, connection="projection")
1865
+ x1 = residual_block1D(x1, 128, kernel_size=7, connection="projection")
1866
+ x1 = residual_block1D(x1, 256, kernel_size=7, connection="projection")
1867
+ x1 = GlobalAveragePooling1D()(x1)
1868
+
1869
+ # Block 2
1870
+ x2 = residual_block1D(x_common, 64, kernel_size=5, connection="projection")
1871
+ x2 = residual_block1D(x2, 128, kernel_size=5, connection="projection")
1872
+ x2 = residual_block1D(x2, 256, kernel_size=5, connection="projection")
1873
+ x2 = GlobalAveragePooling1D()(x2)
1874
+
1875
+ # Block 3
1876
+ x3 = residual_block1D(x_common, 64, kernel_size=3, connection="projection")
1877
+ x3 = residual_block1D(x3, 128, kernel_size=3, connection="projection")
1878
+ x3 = residual_block1D(x3, 256, kernel_size=3, connection="projection")
1879
+ x3 = GlobalAveragePooling1D()(x3)
1880
+
1881
+ x_combined = Concatenate()([x1, x2, x3])
1882
+ x_combined = Flatten()(x_combined)
1883
+
1884
+ if dense_collection > 0:
1885
+ x_combined = Dense(dense_collection)(x_combined)
1886
+ if dropout_rate > 0:
1887
+ x_combined = Dropout(dropout_rate)(x_combined)
1888
+
1889
+ x_combined = Dense(neurons_final, activation=final_activation, name=header)(
1890
+ x_combined
1891
+ )
1892
+ model = Model(inputs, x_combined, name=header)
1893
+
1894
+ return model
1895
+
1896
+
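For illustration, instantiating the multiscale encoder as a 3-class classifier (parameter values are arbitrary):

model = MultiscaleResNetModel(
    n_channels=2,
    n_classes=3,
    dropout_rate=0.2,
    dense_collection=128,
    header="classifier",
    model_signal_length=128,
)
model.summary()  # three parallel branches (kernel sizes 7, 5, 3) concatenated before the softmax head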
1897
+ def ResNetModelCurrent(
1898
+ n_channels,
1899
+ n_slices,
1900
+ depth=2,
1901
+ use_pooling=True,
1902
+ n_classes=3,
1903
+ dropout_rate=0.1,
1904
+ dense_collection=512,
1905
+ header="classifier",
1906
+ model_signal_length=128,
1907
+ ):
1908
+ """
1909
+ Creates a ResNet-based model tailored for signal classification or regression tasks.
1910
+
1911
+ This function constructs a 1D ResNet architecture with specified parameters. The model can be configured
1912
+ for either classification or regression tasks, determined by the `header` parameter. It consists of
1913
+ configurable ResNet blocks, global average pooling, optional dense layers, and dropout for regularization.
1914
+
1915
+ Parameters
1916
+ ----------
1917
+ n_channels : int
1918
+ The number of channels in the input signal.
1919
+ n_slices : int
1920
+ The number of slices (or ResNet blocks) to use in the model.
1921
+ depth : int, optional
1922
+ The depth of the network, i.e., how many times the number of filters is doubled. Default is 2.
1923
+ use_pooling : bool, optional
1924
+ Whether to use MaxPooling between ResNet blocks. Default is True.
1925
+ n_classes : int, optional
1926
+ The number of classes for the classification task. Ignored for regression. Default is 3.
1927
+ dropout_rate : float, optional
1928
+ The dropout rate for regularization. Default is 0.1.
1929
+ dense_collection : int, optional
1930
+ The number of neurons in the dense layer following global pooling. If 0, the dense layer is omitted. Default is 512.
1931
+ header : str, optional
1932
+ Specifies the task type: "classifier" for classification or "regressor" for regression. Default is "classifier".
1933
+ model_signal_length : int, optional
1934
+ The length of the input signal. Default is 128.
1935
+
1936
+ Returns
1937
+ -------
1938
+ keras.Model
1939
+ The constructed Keras model ready for training or inference.
1940
+
1941
+ Notes
1942
+ -----
1943
+ - The model uses Conv1D layers for signal processing and applies global average pooling before the final classification
1944
+ or regression layer.
1945
+ - The choice of `final_activation` and `neurons_final` depends on the task: "softmax" and `n_classes` for classification,
1946
+ and "linear" and 1 for regression.
1947
+ - This function relies on a custom `residual_block1D` function for constructing ResNet blocks.
1948
+
1949
+ Examples
1950
+ --------
1951
+ >>> model = ResNetModelCurrent(n_channels=1, n_slices=2, depth=2, use_pooling=True, n_classes=3, dropout_rate=0.1, dense_collection=512, header="classifier", model_signal_length=128)
1952
+ # Creates a ResNet model configured for classification with 3 classes.
1953
+
1954
+ """
1955
+
1956
+ if header == "classifier":
1957
+ final_activation = "softmax"
1958
+ neurons_final = n_classes
1959
+ elif header == "regressor":
1960
+ final_activation = "linear"
1961
+ neurons_final = 1
1962
+ else:
1963
+ return None
1964
+
1965
+ inputs = Input(
1966
+ shape=(
1967
+ model_signal_length,
1968
+ n_channels,
1969
+ )
1970
+ )
1971
+ x2 = Conv1D(64, kernel_size=1, strides=1, padding="same")(inputs)
1972
+
1973
+ n_filters = 64
1974
+ for k in range(depth):
1975
+ for i in range(n_slices):
1976
+ x2 = residual_block1D(x2, n_filters, kernel_size=8)
1977
+ n_filters *= 2
1978
+ if use_pooling and k != (depth - 1):
1979
+ x2 = MaxPooling1D()(x2)
1980
+
1981
+ x2 = GlobalAveragePooling1D()(x2)
1982
+ if dense_collection > 0:
1983
+ x2 = Dense(dense_collection)(x2)
1984
+ if dropout_rate > 0:
1985
+ x2 = Dropout(dropout_rate)(x2)
1986
+
1987
+ x2 = Dense(neurons_final, activation=final_activation, name=header)(x2)
1988
+ model = Model(inputs, x2, name=header)
1989
+
1990
+ return model
1991
+
1992
+
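Given the imports at the top of this module, a plausible way to compile the two heads might look as follows; the actual training configuration lives in the trainer code, so treat this purely as a sketch:

clf = ResNetModelCurrent(n_channels=2, n_slices=2, header="classifier")
reg = ResNetModelCurrent(n_channels=2, n_slices=2, header="regressor")

clf.compile(optimizer=Adam(learning_rate=1e-4), loss=CategoricalCrossentropy(), metrics=[Precision(), Recall()])
reg.compile(optimizer=Adam(learning_rate=1e-4), loss=MeanAbsoluteError())  # regresses the normalized event time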
1993
+ def _get_time_history_class():
1994
+ """Factory function to get TimeHistory class with lazy TensorFlow import."""
1995
+
1996
+ class TimeHistory(Callback):
1997
+ """
1998
+ A custom Keras callback to log the duration of each epoch during training.
1999
+
2000
+ This callback records the time taken for each epoch during the model training process, allowing for
2001
+ monitoring of training efficiency and performance over time. The times are stored in a list, with each
2002
+ element representing the duration of an epoch in seconds.
2003
+
2004
+ Attributes
2005
+ ----------
2006
+ times : list
2007
+ A list of times (in seconds) taken for each epoch during the training. This list is populated as the
2008
+ training progresses.
2009
+
2010
+ Methods
2011
+ -------
2012
+ on_train_begin(logs={})
2013
+ Initializes the list of times at the beginning of training.
2014
+
2015
+ on_epoch_begin(epoch, logs={})
2016
+ Records the start time of the current epoch.
2017
+
2018
+ on_epoch_end(epoch, logs={})
2019
+ Calculates and appends the duration of the current epoch to the `times` list.
2020
+
2021
+ Notes
2022
+ -----
2023
+ - This callback is intended to be used with the `fit` method of Keras models.
2024
+ - The time measurements are made using the `time.time()` function, which provides wall-clock time.
2025
+
2026
+ Examples
2027
+ --------
2028
+ >>> from keras.models import Sequential
2029
+ >>> from keras.layers import Dense
2030
+ >>> model = Sequential([Dense(10, activation='relu', input_shape=(20,)), Dense(1)])
2031
+ >>> time_callback = TimeHistory()
2032
+ >>> model.compile(optimizer='adam', loss='mean_squared_error')
2033
+ >>> model.fit(x_train, y_train, epochs=10, callbacks=[time_callback])
2034
+ >>> print(time_callback.times)
2035
+ # This will print the time taken for each epoch during the training.
2036
+
2037
+ """
2038
+
2039
+ def on_train_begin(self, logs={}):
2040
+ self.times = []
2041
+
2042
+ def on_epoch_begin(self, epoch, logs={}):
2043
+ self.epoch_time_start = time.time()
2044
+
2045
+ def on_epoch_end(self, epoch, logs={}):
2046
+ self.times.append(time.time() - self.epoch_time_start)
2047
+
2048
+ return TimeHistory
2049
+
2050
+
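Usage of the factory, following the class docstring:

TimeHistory = _get_time_history_class()
time_cb = TimeHistory()
# model.fit(x_train, y_train, epochs=10, callbacks=[time_cb])
# time_cb.times then holds the per-epoch durations in seconds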
2051
+ def _interpret_normalization_parameters(
2052
+ n_channels, normalization_percentile, normalization_values, normalization_clip
2053
+ ):
2054
+ """
2055
+ Interprets and validates normalization parameters for each channel.
2056
+
2057
+ This function ensures the normalization parameters are correctly formatted and expanded to match
2058
+ the number of channels in the dataset. It provides default values and expands single values into
2059
+ lists to match the number of channels if necessary.
2060
+
2061
+ Parameters
2062
+ ----------
2063
+ n_channels : int
2064
+ The number of channels in the dataset.
2065
+ normalization_percentile : list of bool or bool, optional
2066
+ Specifies whether to normalize each channel based on percentile values. If a single bool is provided,
2067
+ it is expanded to a list matching the number of channels. Default is True for all channels.
2068
+ normalization_values : list of lists or list, optional
2069
+ Specifies the percentile values [lower, upper] for normalization for each channel. If a single pair
2070
+ is provided, it is expanded to match the number of channels. Default is [[0.1, 99.9]] for all channels.
2071
+ normalization_clip : list of bool or bool, optional
2072
+ Specifies whether to clip the normalized values for each channel to the range [0, 1]. If a single bool
2073
+ is provided, it is expanded to a list matching the number of channels. Default is False for all channels.
2074
+
2075
+ Returns
2076
+ -------
2077
+ tuple
2078
+ A tuple containing three lists: `normalization_percentile`, `normalization_values`, and `normalization_clip`,
2079
+ each of length `n_channels`, representing the interpreted and validated normalization parameters for each channel.
2080
+
2081
+ Raises
2082
+ ------
2083
+ AssertionError
2084
+ If the lengths of the provided lists do not match `n_channels`.
2085
+
2086
+ Examples
2087
+ --------
2088
+ >>> n_channels = 2
2089
+ >>> normalization_percentile = True
2090
+ >>> normalization_values = [0.1, 99.9]
2091
+ >>> normalization_clip = False
2092
+ >>> params = _interpret_normalization_parameters(n_channels, normalization_percentile, normalization_values, normalization_clip)
2093
+ >>> print(params)
2094
+ # ([True, True], [[0.1, 99.9], [0.1, 99.9]], [False, False])
2095
+
2096
+ """
2097
+
2098
+ if normalization_percentile is None:
2099
+ normalization_percentile = [True] * n_channels
2100
+ if normalization_values is None:
2101
+ normalization_values = [[0.1, 99.9]] * n_channels
2102
+ if normalization_clip is None:
2103
+ normalization_clip = [False] * n_channels
2104
+
2105
+ if isinstance(normalization_percentile, bool):
2106
+ normalization_percentile = [normalization_percentile] * n_channels
2107
+ if isinstance(normalization_clip, bool):
2108
+ normalization_clip = [normalization_clip] * n_channels
2109
+ if len(normalization_values) == 2 and not isinstance(normalization_values[0], list):
2110
+ normalization_values = [normalization_values] * n_channels
2111
+
2112
+ assert len(normalization_values) == n_channels
2113
+ assert len(normalization_clip) == n_channels
2114
+ assert len(normalization_percentile) == n_channels
2115
+
2116
+ return normalization_percentile, normalization_values, normalization_clip
2117
+
2118
+
2119
+ def normalize_signal_set(
2120
+ signal_set,
2121
+ channel_option,
2122
+ percentile_alive=[0.01, 99.99],
2123
+ percentile_dead=[0.5, 99.999],
2124
+ percentile_generic=[0.01, 99.99],
2125
+ normalization_percentile=None,
2126
+ normalization_values=None,
2127
+ normalization_clip=None,
2128
+ ):
2129
+ """
2130
+ Normalizes a set of single-cell signals across specified channels using given percentile values or specific normalization parameters.
2131
+
2132
+ This function applies normalization to each channel in the signal set based on the provided normalization parameters,
2133
+ which can be defined globally or per channel. The normalization process aims to scale the signal values to a standard
2134
+ range, improving the consistency and comparability of signal measurements across samples.
2135
+
2136
+ Parameters
2137
+ ----------
2138
+ signal_set : ndarray
2139
+ A 3D numpy array representing the set of signals to be normalized, with dimensions corresponding to (samples, time points, channels).
2140
+ channel_option : list of str
2141
+ A list specifying the channels included in the signal set and their corresponding normalization strategy based on channel names.
2142
+ percentile_alive : list of float, optional
2143
+ The percentile values [lower, upper] used for normalization of signals from channels labeled as 'alive'. Default is [0.01, 99.99].
2144
+ percentile_dead : list of float, optional
2145
+ The percentile values [lower, upper] used for normalization of signals from channels labeled as 'dead'. Default is [0.5, 99.999].
2146
+ percentile_generic : list of float, optional
2147
+ The percentile values [lower, upper] used for normalization of signals from channels not specifically labeled as 'alive' or 'dead'.
2148
+ Default is [0.01, 99.99].
2149
+ normalization_percentile : list of bool or None, optional
2150
+ Specifies whether to normalize each channel based on percentile values. If None, the default percentile strategy is applied
2151
+ based on `channel_option`. If a list, it should match the length of `channel_option`.
2152
+ normalization_values : list of lists or None, optional
2153
+ Specifies the percentile values [lower, upper] or fixed values [min, max] for normalization for each channel. Overrides
2154
+ `percentile_alive`, `percentile_dead`, and `percentile_generic` if provided.
2155
+ normalization_clip : list of bool or None, optional
2156
+ Specifies whether to clip the normalized values for each channel to the range [0, 1]. If None, clipping is disabled by default.
2157
+
2158
+ Returns
2159
+ -------
2160
+ ndarray
2161
+ The normalized signal set with the same shape as the input `signal_set`.
2162
+
2163
+ Notes
2164
+ -----
2165
+ - The `percentile_alive`, `percentile_dead` and `percentile_generic` parameters are legacy options that the current
2166
+ implementation does not use; normalization is controlled by `normalization_percentile`, `normalization_values` and `normalization_clip`.
2167
+ - Normalization parameters (`normalization_percentile`, `normalization_values`, `normalization_clip`) are interpreted and validated
2168
+ by calling `_interpret_normalization_parameters`.
2169
+
2170
+ Examples
2171
+ --------
2172
+ >>> signal_set = np.random.rand(100, 128, 2) # 100 samples, 128 time points, 2 channels
2173
+ >>> channel_option = ['alive', 'dead']
2174
+ >>> normalized_signals = normalize_signal_set(signal_set, channel_option)
2175
+ # Normalizes the signal set using the default percentile bounds (0.1-99.9) for each channel.
2176
+
2177
+ """
2178
+
2179
+ # Check normalization params are ok
2180
+ n_channels = len(channel_option)
2181
+ normalization_percentile, normalization_values, normalization_clip = (
2182
+ _interpret_normalization_parameters(
2183
+ n_channels,
2184
+ normalization_percentile,
2185
+ normalization_values,
2186
+ normalization_clip,
2187
+ )
2188
+ )
2189
+ for k, channel in enumerate(channel_option):
2190
+
2191
+ zero_values = []
2192
+ for i in range(len(signal_set)):
2193
+ zeros_loc = np.where(signal_set[i, :, k] == 0)
2194
+ zero_values.append(zeros_loc)
2195
+
2196
+ values = signal_set[:, :, k]
2197
+
2198
+ if normalization_percentile[k]:
2199
+ min_val = np.nanpercentile(
2200
+ values[values != 0.0], normalization_values[k][0]
2201
+ )
2202
+ max_val = np.nanpercentile(
2203
+ values[values != 0.0], normalization_values[k][1]
2204
+ )
2205
+ else:
2206
+ min_val = normalization_values[k][0]
2207
+ max_val = normalization_values[k][1]
2208
+
2209
+ signal_set[:, :, k] -= min_val
2210
+ signal_set[:, :, k] /= max_val - min_val
2211
+
2212
+ if normalization_clip[k]:
2213
+ signal_set[:, :, k] = np.clip(signal_set[:, :, k], 0.0, 1.0)
2214
+
2215
+ for i, z in enumerate(zero_values):
2216
+ signal_set[i, z, k] = 0.0
2217
+
2218
+ return signal_set
2219
+
2220
+
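A small self-contained example of the fixed-bounds mode, which also shows that the zeros used as padding markers survive normalization (all values are arbitrary):

import numpy as np

signals = (np.random.rand(10, 128, 1) * 1000.0).astype(np.float32)
signals[:, 100:, 0] = 0.0  # trailing zeros mark padded frames

out = normalize_signal_set(
    signals.copy(),
    ["intensity"],
    normalization_percentile=[False],
    normalization_values=[[0.0, 1000.0]],  # fixed min/max instead of percentiles
    normalization_clip=[True],
)
assert np.all(out[:, 100:, 0] == 0.0)  # padding zeros are restored after scaling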
2221
+ def pad_to_model_length(signal_set, model_signal_length):
2222
+ """
2223
+
2224
+ Pad the signal set to match the specified model signal length.
2225
+
2226
+ Parameters
2227
+ ----------
2228
+ signal_set : array-like
2229
+ The signal set to be padded.
2230
+ model_signal_length : int
2231
+ The desired length of the model signal.
2232
+
2233
+ Returns
2234
+ -------
2235
+ array-like
2236
+ The padded signal set.
2237
+
2238
+ Notes
2239
+ -----
2240
+ This function pads the signal set along the second dimension (axis 1) to match the specified model signal length,
2241
+ repeating the last value of each signal (edge padding). The padding is applied to the end of the signals, increasing their length.
2242
+
2243
+ Examples
2244
+ --------
2245
+ >>> signal_set = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
2246
+ >>> padded_signals = pad_to_model_length(signal_set, 5)
2247
+
2248
+ """
2249
+
2250
+ padded = np.pad(
2251
+ signal_set,
2252
+ [(0, 0), (0, model_signal_length - signal_set.shape[1]), (0, 0)],
2253
+ mode="edge",
2254
+ )
2255
+
2256
+ return padded
2257
+
2258
+
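For example, with edge padding the last value of each signal is repeated up to the model length:

import numpy as np

s = np.array([[[1.0], [2.0], [3.0]]])  # (1 sample, 3 frames, 1 channel)
p = pad_to_model_length(s, 6)
print(p[0, :, 0])  # [1. 2. 3. 3. 3. 3.]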
2259
+ def augmenter(
2260
+ signal,
2261
+ time_of_interest,
2262
+ cclass,
2263
+ model_signal_length,
2264
+ time_shift=True,
2265
+ probability=0.95,
2266
+ ):
2267
+ """
2268
+ Randomly augments single-cell signals to simulate variations in noise, intensity ratios, and event times.
2269
+
2270
+ This function applies random transformations to the input signal, including time shifts, intensity changes,
2271
+ and the addition of Gaussian noise, with the aim of increasing the diversity of the dataset for training robust models.
2272
+
2273
+ Parameters
2274
+ ----------
2275
+ signal : ndarray
2276
+ A 2D numpy array of shape (time points, channels) representing the signal of a single cell to be augmented.
2277
+ time_of_interest : float
2278
+ The normalized time of interest (event time) for the signal, scaled to the range [0, 1].
2279
+ cclass : ndarray
2280
+ A one-hot encoded numpy array representing the class of the cell associated with the signal.
2281
+ model_signal_length : int
2282
+ The length of the signal expected by the model, used for scaling the time of interest.
2283
+ time_shift : bool, optional
2284
+ Specifies whether to apply random time shifts to the signal. Default is True.
2285
+ probability : float, optional
2286
+ The probability with which to apply the augmentation transformations. Default is 0.95.
2287
+
2288
+ Returns
2289
+ -------
2290
+ tuple
2291
+ A tuple containing the augmented signal, the normalized time of interest, and the class of the cell.
2292
+
2293
+ Raises
2294
+ ------
2295
+ AssertionError
2296
+ If the time of interest is provided but invalid for time shifting.
2297
+
2298
+ Notes
2299
+ -----
2300
+ - Cells without a valid event time (e.g. the 'miscellaneous' class, typically encoded as '2') are still shifted, but their labels (time of interest, class) are left unchanged.
2301
+ - The time of interest is rescaled based on the model's expected signal length before and after any time shift.
2302
+ - Augmentation is applied with the specified probability to simulate realistic variability while maintaining
2303
+ some original signals in the dataset.
2304
+
2305
+ """
2306
+
2307
+ if np.amax(time_of_interest) <= 1.0:
2308
+ time_of_interest *= model_signal_length
2309
+
2310
+ # augment with a certain probability
2311
+ r = random.random()
2312
+ if r <= probability:
2313
+
2314
+ if time_shift:
2315
+ # cells without a valid event time keep their original labels (handled in random_time_shift)
2316
+ assert time_of_interest is not None, "Please provide a valid time of interest"
2317
+ signal, time_of_interest, cclass = random_time_shift(
2318
+ signal, time_of_interest, cclass, model_signal_length
2319
+ )
2320
+
2321
+ # signal = random_intensity_change(signal) #maybe bad idea for non percentile-normalized signals
2322
+ signal = gauss_noise(signal)
2323
+
2324
+ return signal, time_of_interest / model_signal_length, cclass
2325
+
2326
+
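A sketch of augmenting a single training sample (shapes follow the conventions above; the values are invented):

import numpy as np

sig = np.random.rand(128, 2).astype(np.float32)    # one padded, normalized signal
toi = 42.0 / 128.0                                 # event time, normalized by the model length
cls = np.array([1.0, 0.0, 0.0], dtype=np.float32)  # one-hot event class

aug_sig, aug_toi, aug_cls = augmenter(sig, toi, cls, model_signal_length=128)
# aug_toi is returned renormalized by the model length; it becomes negative (and
# aug_cls flips to a censored class) when the event is shifted out of the window.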
2327
+ def random_intensity_change(signal):
2328
+ """
2329
+
2330
+ Randomly change the intensity of a signal.
2331
+
2332
+ Parameters
2333
+ ----------
2334
+ signal : array-like
2335
+ The input signal to be modified.
2336
+
2337
+ Returns
2338
+ -------
2339
+ array-like
2340
+ The modified signal with randomly changed intensity.
2341
+
2342
+ Notes
2343
+ -----
2344
+ This function applies a random intensity change to each channel of the input signal. The intensity change is
2345
+ performed by multiplying each channel with a random value drawn from a uniform distribution between 0.7 and 1.0.
2346
+
2347
+ Examples
2348
+ --------
2349
+ >>> signal = np.array([[1, 2, 3], [4, 5, 6]])
2350
+ >>> modified_signal = random_intensity_change(signal)
2351
+
2352
+ """
2353
+
2354
+ for k in range(signal.shape[1]):
2355
+ signal[:, k] = signal[:, k] * np.random.uniform(0.7, 1.0)
2356
+
2357
+ return signal
2358
+
2359
+
2360
+ def gauss_noise(signal):
2361
+ """
2362
+
2363
+ Add Gaussian noise to a signal.
2364
+
2365
+ Parameters
2366
+ ----------
2367
+ signal : array-like
2368
+ The input signal to which noise will be added.
2369
+
2370
+ Returns
2371
+ -------
2372
+ array-like
2373
+ The signal with Gaussian noise added.
2374
+
2375
+ Notes
2376
+ -----
2377
+ This function adds multiplicative Gaussian noise to the input signal. A noise amplitude is first drawn uniformly
2378
+ from [0, 0.08]; values from a standard normal distribution are then scaled by this amplitude and by the signal
2379
+ itself, so the perturbation added to each point is proportional to the local signal value.
2380
+
2381
+ Examples
2382
+ --------
2383
+ >>> signal = np.array([1, 2, 3, 4, 5])
2384
+ >>> noisy_signal = gauss_noise(signal)
2385
+
2386
+ """
2387
+
2388
+ sig = 0.08 * np.random.uniform(0, 1)
2389
+ signal = signal + sig * np.random.normal(0, 1, signal.shape) * signal
2390
+ return signal
2391
+
2392
+
2393
+ def random_time_shift(signal, time_of_interest, cclass, model_signal_length):
2394
+ """
2395
+
2396
+ Randomly shift the signals to another time.
2397
+
2398
+ Parameters
2399
+ ----------
2400
+ signal : array-like
2401
+ The signal to be shifted.
2402
+ time_of_interest : int or float
2403
+ The original time of interest for the signal. Use -1 if not applicable.
+ cclass : ndarray
+ The one-hot encoded class of the cell; it may be relabeled if the shift moves the event outside the observation window.
2404
+ model_signal_length : int
2405
+ The length of the model signal.
2406
+
2407
+ Returns
2408
+ -------
2409
+ array-like
2410
+ The shifted signal.
2411
+ int or float
2412
+ The new time of interest if available; otherwise, the original time of interest.
+ ndarray
+ The one-hot encoded class, possibly relabeled if the event was shifted outside the observation window.
2413
+
2414
+ Notes
2415
+ -----
2416
+ This function randomly selects a target time within the specified model signal length and shifts the
2417
+ signal accordingly. The shift is performed along the first dimension (axis 0) of the signal. The function uses
2418
+ nearest-neighbor interpolation for shifting.
2419
+
2420
+ If the original time of interest (`time_of_interest`) is provided (not equal to -1), the function returns the
2421
+ shifted signal along with the new time of interest. Otherwise, it returns the shifted signal along with the
2422
+ original time of interest.
2423
+
2424
+ By default, target times are drawn between 3 and `model_signal_length`. When a valid time of interest is provided,
2425
+ the sampling range is widened to [-model_signal_length / 3, model_signal_length + model_signal_length / 3], so that roughly a third of event-class samples land outside the observation window and are relabeled as censored.
2426
+
2427
+ Examples
2428
+ --------
2429
+ >>> signal = np.array([[1, 2, 3], [4, 5, 6]])
2430
+ >>> shifted_signal, new_time, new_class = random_time_shift(signal, 1, np.array([1.0, 0.0, 0.0]), 5)
2431
+
2432
+ """
2433
+
2434
+ min_time = 3
2435
+ max_time = model_signal_length
2436
+
2437
+ return_target = False
2438
+ if time_of_interest != -1:
2439
+ return_target = True
2440
+ max_time = (
2441
+ model_signal_length + 1 / 3 * model_signal_length
2442
+ ) # bias to have a third of event class becoming no event
2443
+ min_time = -model_signal_length * 1 / 3
2444
+
2445
+ times = np.linspace(
2446
+ min_time, max_time, 2000
2447
+ ) # symmetrize to create left-censored events
2448
+ target_time = np.random.choice(times)
2449
+
2450
+ delta_t = target_time - time_of_interest
2451
+ signal = shift(signal, [delta_t, 0], order=0, mode="nearest")
2452
+
2453
+ if target_time <= 0 and np.argmax(cclass) == 0:
2454
+ target_time = -1
2455
+ cclass = np.array([0.0, 0.0, 1.0]).astype(np.float32)
2456
+ if target_time >= model_signal_length and np.argmax(cclass) == 0:
2457
+ target_time = -1
2458
+ cclass = np.array([0.0, 1.0, 0.0]).astype(np.float32)
2459
+
2460
+ if return_target:
2461
+ return signal, target_time, cclass
2462
+ else:
2463
+ return signal, time_of_interest, cclass
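Finally, a short sketch of the relabeling behaviour implemented above: shifting an event past the end of the window flips the one-hot class to index 1, shifting it before the start flips it to index 2, and the returned time becomes -1 in both cases:

import numpy as np

sig = np.random.rand(128, 1).astype(np.float32)
event = np.array([1.0, 0.0, 0.0], dtype=np.float32)  # argmax == 0: event class

shifted, t_new, c_new = random_time_shift(sig, 42.0, event.copy(), 128)
if t_new == -1:
    print("event shifted out of the window, relabeled as:", c_new)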