celldetective 1.5.0b7__py3-none-any.whl → 1.5.0b8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
celldetective/signals.py CHANGED
@@ -3,63 +3,17 @@ import os
 import subprocess
 import json
 
-from tensorflow.keras.optimizers import Adam
-from tensorflow.keras.callbacks import (
-    EarlyStopping,
-    ModelCheckpoint,
-    TensorBoard,
-    ReduceLROnPlateau,
-    CSVLogger,
-)
-from tensorflow.keras.losses import (
-    CategoricalCrossentropy,
-    MeanSquaredError,
-    MeanAbsoluteError,
-)
-from tensorflow.keras.metrics import Precision, Recall, MeanIoU
-from tensorflow.keras.models import load_model, clone_model
-from tensorflow.config.experimental import list_physical_devices, set_memory_growth
-from tensorflow.keras.utils import to_categorical
-from tensorflow.keras import Input, Model
-from tensorflow.keras.layers import (
-    Conv1D,
-    BatchNormalization,
-    Dense,
-    Activation,
-    Add,
-    MaxPooling1D,
-    Dropout,
-    GlobalAveragePooling1D,
-    Concatenate,
-    ZeroPadding1D,
-    Flatten,
-)
-from tensorflow.keras.callbacks import Callback
-from sklearn.metrics import confusion_matrix, classification_report
-from sklearn.metrics import (
-    jaccard_score,
-    balanced_accuracy_score,
-    precision_score,
-    recall_score,
-)
-from scipy.interpolate import interp1d
-from scipy.ndimage import shift
-from sklearn.metrics import ConfusionMatrixDisplay
+# TensorFlow imports are lazy-loaded in functions that need them to avoid
+# slow import times for modules that don't require TensorFlow.
 
 from celldetective.utils.model_loaders import locate_signal_model
 from celldetective.utils.data_loaders import get_position_table, get_position_pickle
 from celldetective.tracking import clean_trajectories, interpolate_nan_properties
-from celldetective.utils.dataset_helpers import compute_weights, train_test_split
-from celldetective.utils.plots.regression import regression_plot
 import matplotlib.pyplot as plt
 from natsort import natsorted
-from glob import glob
-import random
 from celldetective.utils.color_mappings import color_from_status, color_from_class
 from math import floor
 from scipy.optimize import curve_fit
-import time
-import math
 import pandas as pd
 from pandas.api.types import is_numeric_dtype
 from scipy.stats import median_abs_deviation
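The hunk above removes every module-level TensorFlow, scikit-learn, and SciPy import and replaces them with a comment: the heavy dependencies are now imported inside the functions that need them, so `import celldetective.signals` no longer pays the cost of loading TensorFlow. A minimal sketch of this deferred-import pattern (illustrative only, not verbatim celldetective code):

    def predict(x):
        # TensorFlow is imported on first call rather than at module import time,
        # so code paths that never call this function never load it.
        from tensorflow.keras.models import load_model

        model = load_model("classifier.h5", compile=False)
        return model.predict(x)

The trade-off is that the first call to such a function absorbs the import latency, and import errors surface at call time rather than at module load.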
@@ -69,59 +23,6 @@ abs_path = os.sep.join(
 )
 
 
-class TimeHistory(Callback):
-    """
-    A custom Keras callback to log the duration of each epoch during training.
-
-    This callback records the time taken for each epoch during the model training process, allowing for
-    monitoring of training efficiency and performance over time. The times are stored in a list, with each
-    element representing the duration of an epoch in seconds.
-
-    Attributes
-    ----------
-    times : list
-        A list of times (in seconds) taken for each epoch during the training. This list is populated as the
-        training progresses.
-
-    Methods
-    -------
-    on_train_begin(logs={})
-        Initializes the list of times at the beginning of training.
-
-    on_epoch_begin(epoch, logs={})
-        Records the start time of the current epoch.
-
-    on_epoch_end(epoch, logs={})
-        Calculates and appends the duration of the current epoch to the `times` list.
-
-    Notes
-    -----
-    - This callback is intended to be used with the `fit` method of Keras models.
-    - The time measurements are made using the `time.time()` function, which provides wall-clock time.
-
-    Examples
-    --------
-    >>> from keras.models import Sequential
-    >>> from keras.layers import Dense
-    >>> model = Sequential([Dense(10, activation='relu', input_shape=(20,)), Dense(1)])
-    >>> time_callback = TimeHistory()
-    >>> model.compile(optimizer='adam', loss='mean_squared_error')
-    >>> model.fit(x_train, y_train, epochs=10, callbacks=[time_callback])
-    >>> print(time_callback.times)
-    # This will print the time taken for each epoch during the training.
-
-    """
-
-    def on_train_begin(self, logs={}):
-        self.times = []
-
-    def on_epoch_begin(self, epoch, logs={}):
-        self.epoch_time_start = time.time()
-
-    def on_epoch_end(self, epoch, logs={}):
-        self.times.append(time.time() - self.epoch_time_start)
-
-
 def analyze_signals(
     trajectories,
     model,
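The deleted `TimeHistory` callback is still used by `SignalDetectionModel.gather_callbacks` (removed further down in this diff), so the two presumably moved together into `celldetective.event_detection_models`, the module targeted by the new function-local imports added below.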
@@ -180,6 +81,7 @@ def analyze_signals(
     - Signal selection and preprocessing are based on the requirements specified in the model's configuration.
 
     """
+    from celldetective.event_detection_models import SignalDetectionModel
 
     model_path = locate_signal_model(model, path=model_path)
     complete_path = model_path  # +model
@@ -476,6 +378,7 @@ def analyze_pair_signals(
         "y": "POSITION_Y",
     },
 ):
+    from celldetective.event_detection_models import SignalDetectionModel
 
     model_path = locate_signal_model(model, path=model_path, pairs=True)
     print(f"Looking for model in {model_path}...")
@@ -681,2331 +584,6 @@ def analyze_pair_signals(
     return trajectories_pairs
 
 
-class SignalDetectionModel(object):
-    """
-    A class for creating and managing signal detection models for analyzing biological signals.
-
-    This class provides functionalities to load a pretrained signal detection model or create one from scratch,
-    preprocess input signals, train the model, and make predictions on new data.
-
-    Parameters
-    ----------
-    path : str, optional
-        Path to the directory containing the model and its configuration. This is used when loading a pretrained model.
-    pretrained : str, optional
-        Path to the pretrained model to load. If specified, the model and its configuration are loaded from this path.
-    channel_option : list of str, optional
-        Specifies the channels to be used for signal analysis. Default is ["live_nuclei_channel"].
-    model_signal_length : int, optional
-        The length of the input signals that the model expects. Default is 128.
-    n_channels : int, optional
-        The number of channels in the input signals. Default is 1.
-    n_conv : int, optional
-        The number of convolutional layers in the model. Default is 2.
-    n_classes : int, optional
-        The number of classes for the classification task. Default is 3.
-    dense_collection : int, optional
-        The number of units in the dense layer of the model. Default is 512.
-    dropout_rate : float, optional
-        The dropout rate applied to the dense layer of the model. Default is 0.1.
-    label : str, optional
-        A label for the model, used in naming and organizing outputs. Default is ''.
-
-    Attributes
-    ----------
-    model_class : keras Model
-        The classification model for predicting the class of signals.
-    model_reg : keras Model
-        The regression model for predicting the time of interest for signals.
-
-    Methods
-    -------
-    load_pretrained_model()
-        Loads the model and its configuration from the pretrained path.
-    create_models_from_scratch()
-        Creates new models for classification and regression from scratch.
-    prep_gpu()
-        Prepares GPU devices for training, if available.
-    fit_from_directory(ds_folders, ...)
-        Trains the model using data from specified directories.
-    fit(x_train, y_time_train, y_class_train, ...)
-        Trains the model using provided datasets.
-    predict_class(x, ...)
-        Predicts the class of input signals.
-    predict_time_of_interest(x, ...)
-        Predicts the time of interest for input signals.
-    plot_model_history(mode)
-        Plots the training history for the specified mode (classifier or regressor).
-    evaluate_regression_model()
-        Evaluates the regression model on test and validation data.
-    gather_callbacks(mode)
-        Gathers and prepares callbacks for training based on the specified mode.
-    generate_sets()
-        Generates training, validation, and test sets from loaded data.
-    augment_training_set()
-        Augments the training set with additional generated data.
-    load_and_normalize(subset)
-        Loads and normalizes signals from a subset of data.
-
-    Notes
-    -----
-    - This class is designed to work with biological signal data, such as time series from microscopy imaging.
-    - The model architecture and training configurations can be customized through the class parameters and methods.
-
-    """
-
-    def __init__(
-        self,
-        path=None,
-        pretrained=None,
-        channel_option=["live_nuclei_channel"],
-        model_signal_length=128,
-        n_channels=1,
-        n_conv=2,
-        n_classes=3,
-        dense_collection=512,
-        dropout_rate=0.1,
-        label="",
-    ):
-
-        self.prep_gpu()
-
-        self.model_signal_length = model_signal_length
-        self.channel_option = channel_option
-        self.pretrained = pretrained
-        self.n_channels = n_channels
-        self.n_conv = n_conv
-        self.n_classes = n_classes
-        self.dense_collection = dense_collection
-        self.dropout_rate = dropout_rate
-        self.label = label
-        self.show_plots = True
-
-        if self.pretrained is not None:
-            print(f"Load pretrained models from {pretrained}...")
-            test = self.load_pretrained_model()
-            if test is None:
-                self.pretrained = None
-                print(
-                    "Pretrained model could not be loaded. Check the log for error. Abort..."
-                )
-                return None
-        else:
-            print("Create models from scratch...")
-            self.create_models_from_scratch()
-            print("Models successfully created.")
-
-    def load_pretrained_model(self):
-        """
-        Loads a pretrained model and its configuration from the specified path.
-
-        This method attempts to load both the classification and regression models from the path specified during the
-        class instantiation. It also loads the model configuration from a JSON file and updates the model attributes
-        accordingly. If the models cannot be loaded, an error message is printed.
-
-        Raises
-        ------
-        Exception
-            If there is an error loading the model or the configuration file, an exception is raised with details.
-
-        Notes
-        -----
-        - The models are expected to be saved in .h5 format with the filenames "classifier.h5" and "regressor.h5".
-        - The configuration file is expected to be named "config_input.json" and located in the same directory as the models.
-
-        """
-
-        if self.pretrained.endswith(os.sep):
-            self.pretrained = os.sep.join(self.pretrained.split(os.sep)[:-1])
-
-        try:
-            self.model_class = load_model(
-                os.sep.join([self.pretrained, "classifier.h5"]),
-                compile=False,
-                custom_objects={"mse": MeanSquaredError()},
-            )
-            self.model_class.load_weights(
-                os.sep.join([self.pretrained, "classifier.h5"])
-            )
-            self.model_class = self.freeze_encoder(self.model_class, 5)
-            print("Classifier successfully loaded...")
-        except Exception as e:
-            print(f"Error {e}...")
-            self.model_class = None
-        try:
-            self.model_reg = load_model(
-                os.sep.join([self.pretrained, "regressor.h5"]),
-                compile=False,
-                custom_objects={"mse": MeanSquaredError()},
-            )
-            self.model_reg.load_weights(os.sep.join([self.pretrained, "regressor.h5"]))
-            self.model_reg = self.freeze_encoder(self.model_reg, 5)
-            print("Regressor successfully loaded...")
-        except Exception as e:
-            print(f"Error {e}...")
-            self.model_reg = None
-
-        if self.model_class is None and self.model_reg is None:
-            return None
-
-        # load config
-        with open(os.sep.join([self.pretrained, "config_input.json"])) as config_file:
-            model_config = json.load(config_file)
-            self.config = model_config
-
-        req_channels = model_config["channels"]
-        print(f"Required channels read from pretrained model: {req_channels}")
-        self.channel_option = req_channels
-        if "normalize" in model_config:
-            self.normalize = model_config["normalize"]
-        if "normalization_percentile" in model_config:
-            self.normalization_percentile = model_config["normalization_percentile"]
-        if "normalization_values" in model_config:
-            self.normalization_values = model_config["normalization_values"]
-        if "normalization_percentile" in model_config:
-            self.normalization_clip = model_config["normalization_clip"]
-        if "label" in model_config:
-            self.label = model_config["label"]
-
-        try:
-            self.n_channels = self.model_class.layers[0].input_shape[0][-1]
-            self.model_signal_length = self.model_class.layers[0].input_shape[0][-2]
-            self.n_classes = self.model_class.layers[-1].output_shape[-1]
-            model_class_input_shape = self.model_class.layers[0].input_shape[0]
-            model_reg_input_shape = self.model_reg.layers[0].input_shape[0]
-        except AttributeError:
-            self.n_channels = self.model_class.input_shape[
-                -1
-            ]  # self.model_class.layers[0].input.shape[0][-1]
-            self.model_signal_length = self.model_class.input_shape[
-                -2
-            ]  # self.model_class.layers[0].input[0].shape[0][-2]
-            self.n_classes = self.model_class.output_shape[
-                -1
-            ]  # self.model_class.layers[-1].output[0].shape[-1]
-            model_class_input_shape = self.model_class.input_shape
-            model_reg_input_shape = self.model_reg.input_shape
-        except Exception as e:
-            print(e)
-
-        assert (
-            model_class_input_shape == model_reg_input_shape
-        ), f"mismatch between input shape of classification: {self.model_class.layers[0].input_shape[0]} and regression {self.model_reg.layers[0].input_shape[0]} models... Error."
-
-        return True
-
-    def freeze_encoder(self, model, n_trainable_layers: int = 3):
-        for layer in model.layers[
-            : -min(n_trainable_layers, len(model.layers))
-        ]:  # freeze everything except final Dense layer
-            layer.trainable = False
-        return model
-
-    def create_models_from_scratch(self):
-        """
-        Initializes new models for classification and regression based on the specified parameters.
-
-        This method creates new ResNet models for both classification and regression tasks using the parameters specified
-        during class instantiation. The models are configured but not compiled or trained.
-
-        Notes
-        -----
-        - The models are created using a custom ResNet architecture defined elsewhere in the codebase.
-        - The models are stored in the `model_class` and `model_reg` attributes of the class.
-
-        """
-
-        self.model_class = ResNetModelCurrent(
-            n_channels=self.n_channels,
-            n_slices=self.n_conv,
-            n_classes=3,
-            dense_collection=self.dense_collection,
-            dropout_rate=self.dropout_rate,
-            header="classifier",
-            model_signal_length=self.model_signal_length,
-        )
-
-        self.model_reg = ResNetModelCurrent(
-            n_channels=self.n_channels,
-            n_slices=self.n_conv,
-            n_classes=self.n_classes,
-            dense_collection=self.dense_collection,
-            dropout_rate=self.dropout_rate,
-            header="regressor",
-            model_signal_length=self.model_signal_length,
-        )
-
-    def prep_gpu(self):
-        """
-        Prepares GPU devices for training by enabling memory growth.
-
-        This method attempts to identify available GPU devices and configures TensorFlow to allow memory growth on each
-        GPU. This prevents TensorFlow from allocating the total available memory on the GPU device upfront.
-
-        Notes
-        -----
-        - This method should be called before any TensorFlow/Keras operations that might allocate GPU memory.
-        - If no GPUs are detected, the method will pass silently.
-
-        """
-
-        try:
-            physical_devices = list_physical_devices("GPU")
-            for gpu in physical_devices:
-                set_memory_growth(gpu, True)
-        except:
-            pass
-
-    def fit_from_directory(
-        self,
-        datasets,
-        normalize=True,
-        normalization_percentile=None,
-        normalization_values=None,
-        normalization_clip=None,
-        channel_option=["live_nuclei_channel"],
-        model_name=None,
-        target_directory=None,
-        augment=True,
-        augmentation_factor=2,
-        validation_split=0.20,
-        test_split=0.0,
-        batch_size=64,
-        epochs=300,
-        recompile_pretrained=False,
-        learning_rate=0.01,
-        loss_reg="mse",
-        loss_class=CategoricalCrossentropy(from_logits=False),
-        show_plots=True,
-        callbacks=None,
-    ):
-        """
-        Trains the model using data from specified directories.
-
-        This method prepares the dataset for training by loading and preprocessing data from specified directories,
-        then trains the classification and regression models.
-
-        Parameters
-        ----------
-        ds_folders : list of str
-            List of directories containing the dataset files for training.
-        callbacks : list, optional
-            List of Keras callbacks to apply during training.
-        normalize : bool, optional
-            Whether to normalize the input signals (default is True).
-        normalization_percentile : list or None, optional
-            Percentiles for signal normalization (default is None).
-        normalization_values : list or None, optional
-            Specific values for signal normalization (default is None).
-        normalization_clip : bool, optional
-            Whether to clip the normalized signals (default is None).
-        channel_option : list of str, optional
-            Specifies the channels to be used for signal analysis (default is ["live_nuclei_channel"]).
-        model_name : str, optional
-            Name of the model for saving purposes (default is None).
-        target_directory : str, optional
-            Directory where the trained model and outputs will be saved (default is None).
-        augment : bool, optional
-            Whether to augment the training data (default is True).
-        augmentation_factor : int, optional
-            Factor by which to augment the training data (default is 2).
-        validation_split : float, optional
-            Fraction of the data to be used as validation set (default is 0.20).
-        test_split : float, optional
-            Fraction of the data to be used as test set (default is 0.0).
-        batch_size : int, optional
-            Batch size for training (default is 64).
-        epochs : int, optional
-            Number of epochs to train for (default is 300).
-        recompile_pretrained : bool, optional
-            Whether to recompile a pretrained model (default is False).
-        learning_rate : float, optional
-            Learning rate for the optimizer (default is 0.01).
-        loss_reg : str or keras.losses.Loss, optional
-            Loss function for the regression model (default is "mse").
-        loss_class : str or keras.losses.Loss, optional
-            Loss function for the classification model (default is CategoricalCrossentropy(from_logits=False)).
-
-        Notes
-        -----
-        - The method automatically splits the dataset into training, validation, and test sets according to the specified splits.
-
-        """
-
-        if not hasattr(self, "normalization_percentile"):
-            self.normalization_percentile = normalization_percentile
-        if not hasattr(self, "normalization_values"):
-            self.normalization_values = normalization_values
-        if not hasattr(self, "normalization_clip"):
-            self.normalization_clip = normalization_clip
-
-        self.callbacks = callbacks
-        self.normalize = normalize
-        (
-            self.normalization_percentile,
-            self.normalization_values,
-            self.normalization_clip,
-        ) = _interpret_normalization_parameters(
-            self.n_channels,
-            self.normalization_percentile,
-            self.normalization_values,
-            self.normalization_clip,
-        )
-
-        self.datasets = [rf"{d}" if isinstance(d, str) else d for d in datasets]
-        self.batch_size = batch_size
-        self.epochs = epochs
-        self.validation_split = validation_split
-        self.test_split = test_split
-        self.augment = augment
-        self.augmentation_factor = augmentation_factor
-        self.model_name = rf"{model_name}"
-        self.target_directory = rf"{target_directory}"
-        self.model_folder = os.sep.join([self.target_directory, self.model_name])
-        self.recompile_pretrained = recompile_pretrained
-        self.learning_rate = learning_rate
-        self.loss_reg = loss_reg
-        self.loss_class = loss_class
-        self.show_plots = show_plots
-        self.channel_option = channel_option
-
-        assert self.n_channels == len(
-            self.channel_option
-        ), f"Mismatch between the channel option and the number of channels of the model..."
-
-        if isinstance(self.datasets[0], dict):
-            self.datasets = [self.datasets]
-
-        self.list_of_sets = []
-        for ds in self.datasets:
-            if isinstance(ds, str):
-                self.list_of_sets.extend(glob(os.sep.join([ds, "*.npy"])))
-            else:
-                self.list_of_sets.append(ds)
-
-        print(f"Found {len(self.list_of_sets)} datasets...")
-
-        self.prepare_sets()
-        self.train_generic()
-
-    def fit(
-        self,
-        x_train,
-        y_time_train,
-        y_class_train,
-        normalize=True,
-        normalization_percentile=None,
-        normalization_values=None,
-        normalization_clip=None,
-        pad=True,
-        validation_data=None,
-        test_data=None,
-        channel_option=["live_nuclei_channel", "dead_nuclei_channel"],
-        model_name=None,
-        target_directory=None,
-        augment=True,
-        augmentation_factor=3,
-        validation_split=0.25,
-        batch_size=64,
-        epochs=300,
-        recompile_pretrained=False,
-        learning_rate=0.001,
-        loss_reg="mse",
-        loss_class=CategoricalCrossentropy(from_logits=False),
-    ):
-        """
-        Trains the model using provided datasets.
-
-        Parameters
-        ----------
-        Same as `fit_from_directory`, but instead of loading data from directories, this method accepts preloaded and
-        optionally preprocessed datasets directly.
-
-        Notes
-        -----
-        - This method provides an alternative way to train the model when data is already loaded into memory, offering
-          flexibility for data preprocessing steps outside this class.
-
-        """
-
-        self.normalize = normalize
-        if not hasattr(self, "normalization_percentile"):
-            self.normalization_percentile = normalization_percentile
-        if not hasattr(self, "normalization_values"):
-            self.normalization_values = normalization_values
-        if not hasattr(self, "normalization_clip"):
-            self.normalization_clip = normalization_clip
-        (
-            self.normalization_percentile,
-            self.normalization_values,
-            self.normalization_clip,
-        ) = _interpret_normalization_parameters(
-            self.n_channels,
-            self.normalization_percentile,
-            self.normalization_values,
-            self.normalization_clip,
-        )
-
-        self.x_train = x_train
-        self.y_class_train = y_class_train
-        self.y_time_train = y_time_train
-        self.channel_option = channel_option
-
-        assert self.n_channels == len(
-            self.channel_option
-        ), f"Mismatch between the channel option and the number of channels of the model..."
-
-        if pad:
-            self.x_train = pad_to_model_length(self.x_train, self.model_signal_length)
-
-        assert self.x_train.shape[1:] == (
-            self.model_signal_length,
-            self.n_channels,
-        ), f"Shape mismatch between the provided training fluorescence signals and the model..."
-
-        # If y-class is not one-hot encoded, encode it
-        if self.y_class_train.shape[-1] != self.n_classes:
-            self.class_weights = compute_weights(
-                y=self.y_class_train,
-                class_weight="balanced",
-                classes=np.unique(self.y_class_train),
-            )
-            self.y_class_train = to_categorical(self.y_class_train, num_classes=3)
-
-        if self.normalize:
-            self.y_time_train = (
-                self.y_time_train.astype(np.float32) / self.model_signal_length
-            )
-            self.x_train = normalize_signal_set(
-                self.x_train,
-                self.channel_option,
-                normalization_percentile=self.normalization_percentile,
-                normalization_values=self.normalization_values,
-                normalization_clip=self.normalization_clip,
-            )
-
-        if validation_data is not None:
-            try:
-                self.x_val = validation_data[0]
-                if pad:
-                    self.x_val = pad_to_model_length(
-                        self.x_val, self.model_signal_length
-                    )
-                self.y_class_val = validation_data[1]
-                if self.y_class_val.shape[-1] != self.n_classes:
-                    self.y_class_val = to_categorical(self.y_class_val, num_classes=3)
-                self.y_time_val = validation_data[2]
-                if self.normalize:
-                    self.y_time_val = (
-                        self.y_time_val.astype(np.float32) / self.model_signal_length
-                    )
-                    self.x_val = normalize_signal_set(
-                        self.x_val,
-                        self.channel_option,
-                        normalization_percentile=self.normalization_percentile,
-                        normalization_values=self.normalization_values,
-                        normalization_clip=self.normalization_clip,
-                    )
-
-            except Exception as e:
-                print("Could not load validation data, error {e}...")
-        else:
-            self.validation_split = validation_split
-
-        if test_data is not None:
-            try:
-                self.x_test = test_data[0]
-                if pad:
-                    self.x_test = pad_to_model_length(
-                        self.x_test, self.model_signal_length
-                    )
-                self.y_class_test = test_data[1]
-                if self.y_class_test.shape[-1] != self.n_classes:
-                    self.y_class_test = to_categorical(self.y_class_test, num_classes=3)
-                self.y_time_test = test_data[2]
-                if self.normalize:
-                    self.y_time_test = (
-                        self.y_time_test.astype(np.float32) / self.model_signal_length
-                    )
-                    self.x_test = normalize_signal_set(
-                        self.x_test,
-                        self.channel_option,
-                        normalization_percentile=self.normalization_percentile,
-                        normalization_values=self.normalization_values,
-                        normalization_clip=self.normalization_clip,
-                    )
-            except Exception as e:
-                print("Could not load test data, error {e}...")
-
-        self.batch_size = batch_size
-        self.epochs = epochs
-        self.augment = augment
-        self.augmentation_factor = augmentation_factor
-        if self.augmentation_factor == 1:
-            self.augment = False
-        self.model_name = model_name
-        self.target_directory = target_directory
-        self.model_folder = os.sep.join([self.target_directory, self.model_name])
-        self.recompile_pretrained = recompile_pretrained
-        self.learning_rate = learning_rate
-        self.loss_reg = loss_reg
-        self.loss_class = loss_class
-
-        self.train_generic()
-
-    def train_generic(self):
-
-        if not os.path.exists(self.model_folder):
-            os.mkdir(self.model_folder)
-
-        self.train_classifier()
-        self.train_regressor()
-
-        config_input = {
-            "channels": self.channel_option,
-            "model_signal_length": self.model_signal_length,
-            "label": self.label,
-            "normalize": self.normalize,
-            "normalization_percentile": self.normalization_percentile,
-            "normalization_values": self.normalization_values,
-            "normalization_clip": self.normalization_clip,
-        }
-        json_string = json.dumps(config_input)
-        with open(
-            os.sep.join([self.model_folder, "config_input.json"]), "w"
-        ) as outfile:
-            outfile.write(json_string)
-
-    def predict_class(
-        self, x, normalize=True, pad=True, return_one_hot=False, interpolate=True
-    ):
-        """
-        Predicts the class of input signals using the trained classification model.
-
-        Parameters
-        ----------
-        x : ndarray
-            The input signals for which to predict classes.
-        normalize : bool, optional
-            Whether to normalize the input signals (default is True).
-        pad : bool, optional
-            Whether to pad the input signals to match the model's expected signal length (default is True).
-        return_one_hot : bool, optional
-            Whether to return predictions in one-hot encoded format (default is False).
-        interpolate : bool, optional
-            Whether to interpolate the input signals (default is True).
-
-        Returns
-        -------
-        ndarray
-            The predicted classes for the input signals. If `return_one_hot` is True, predictions are returned in one-hot
-            encoded format, otherwise as integer labels.
-
-        Notes
-        -----
-        - The method processes the input signals according to the specified options to ensure compatibility with the model's
-          input requirements.
-
-        """
-
-        self.x = np.copy(x)
-        self.normalize = normalize
-        self.pad = pad
-        self.return_one_hot = return_one_hot
-        # self.max_relevant_time = np.shape(self.x)[1]
-        # print(f'Max relevant time: {self.max_relevant_time}')
-
-        if self.pad:
-            self.x = pad_to_model_length(self.x, self.model_signal_length)
-
-        if self.normalize:
-            self.x = normalize_signal_set(
-                self.x,
-                self.channel_option,
-                normalization_percentile=self.normalization_percentile,
-                normalization_values=self.normalization_values,
-                normalization_clip=self.normalization_clip,
-            )
-
-        # implement auto interpolation here!!
-        # self.x = self.interpolate_signals(self.x)
-
-        # for i in range(5):
-        # plt.plot(self.x[i,:,0])
-        # plt.show()
-
-        try:
-            n_channels = self.model_class.layers[0].input_shape[0][-1]
-            model_signal_length = self.model_class.layers[0].input_shape[0][-2]
-        except AttributeError:
-            n_channels = self.model_class.input_shape[-1]
-            model_signal_length = self.model_class.input_shape[-2]
-
-        assert (
-            self.x.shape[-1] == n_channels
-        ), f"Shape mismatch between the input shape and the model input shape..."
-        assert (
-            self.x.shape[-2] == model_signal_length
-        ), f"Shape mismatch between the input shape and the model input shape..."
-
-        self.class_predictions_one_hot = self.model_class.predict(self.x)
-        self.class_predictions = self.class_predictions_one_hot.argmax(axis=1)
-
-        if self.return_one_hot:
-            return self.class_predictions_one_hot
-        else:
-            return self.class_predictions
-
-    def predict_time_of_interest(
-        self, x, class_predictions=None, normalize=True, pad=True
-    ):
-        """
-        Predicts the time of interest for input signals using the trained regression model.
-
-        Parameters
-        ----------
-        x : ndarray
-            The input signals for which to predict times of interest.
-        class_predictions : ndarray, optional
-            The predicted classes for the input signals. If provided, time of interest predictions are only made for
-            signals predicted to belong to a specific class (default is None).
-        normalize : bool, optional
-            Whether to normalize the input signals (default is True).
-        pad : bool, optional
-            Whether to pad the input signals to match the model's expected signal length (default is True).
-
-        Returns
-        -------
-        ndarray
-            The predicted times of interest for the input signals.
-
-        Notes
-        -----
-        - The method processes the input signals according to the specified options and uses the regression model to
-          predict times at which a particular event of interest occurs.
-
-        """
-
-        self.x = np.copy(x)
-        self.normalize = normalize
-        self.pad = pad
-        # self.max_relevant_time = np.shape(self.x)[1]
-        # print(f'Max relevant time: {self.max_relevant_time}')
-
-        if class_predictions is not None:
-            self.class_predictions = class_predictions
-
-        if self.pad:
-            self.x = pad_to_model_length(self.x, self.model_signal_length)
-
-        if self.normalize:
-            self.x = normalize_signal_set(
-                self.x,
-                self.channel_option,
-                normalization_percentile=self.normalization_percentile,
-                normalization_values=self.normalization_values,
-                normalization_clip=self.normalization_clip,
-            )
-
-        try:
-            n_channels = self.model_reg.layers[0].input_shape[0][-1]
-            model_signal_length = self.model_reg.layers[0].input_shape[0][-2]
-        except AttributeError:
-            n_channels = self.model_reg.input_shape[-1]
-            model_signal_length = self.model_reg.input_shape[-2]
-
-        assert (
-            self.x.shape[-1] == n_channels
-        ), f"Shape mismatch between the input shape and the model input shape..."
-        assert (
-            self.x.shape[-2] == model_signal_length
-        ), f"Shape mismatch between the input shape and the model input shape..."
-
-        if np.any(self.class_predictions == 0):
-            self.time_predictions = (
-                self.model_reg.predict(self.x[self.class_predictions == 0])
-                * self.model_signal_length
-            )
-            self.time_predictions = self.time_predictions[:, 0]
-            self.time_predictions_recast = np.zeros(len(self.x)) - 1.0
-            self.time_predictions_recast[self.class_predictions == 0] = (
-                self.time_predictions
-            )
-        else:
-            self.time_predictions_recast = np.zeros(len(self.x)) - 1.0
-        return self.time_predictions_recast
-
-    def interpolate_signals(self, x_set):
-        """
-        Interpolates missing values in the input signal set.
-
-        Parameters
-        ----------
-        x_set : ndarray
-            The input signal set with potentially missing values.
-
-        Returns
-        -------
-        ndarray
-            The input signal set with missing values interpolated.
-
-        Notes
-        -----
-        - This method is useful for preparing signals that have gaps or missing time points before further processing
-          or model training.
-
-        """
-
-        for i in range(len(x_set)):
-            for k in range(x_set.shape[-1]):
-                x = x_set[i, :, k]
-                not_nan = np.logical_not(np.isnan(x))
-                indices = np.arange(len(x))
-                interp = interp1d(
-                    indices[not_nan],
-                    x[not_nan],
-                    fill_value=(0.0, 0.0),
-                    bounds_error=False,
-                )
-                x_set[i, :, k] = interp(indices)
-        return x_set
-
-    def train_classifier(self):
-        """
-        Trains the classifier component of the model to predict event classes in signals.
-
-        This method compiles the classifier model (if not pretrained or if recompilation is requested) and
-        trains it on the prepared dataset. The training process includes validation and early stopping based
-        on precision to prevent overfitting.
-
-        Notes
-        -----
-        - The classifier model predicts the class of each signal, such as live, dead, or miscellaneous.
-        - Training parameters such as epochs, batch size, and learning rate are specified during class instantiation.
-        - Model performance metrics and training history are saved for analysis.
-
-        """
-
-        # if pretrained model
-        self.n_classes = 3
-
-        if self.pretrained is not None:
-            # if recompile
-            if self.recompile_pretrained:
-                print(
-                    "Recompiling the pretrained classifier model... Warning, this action reinitializes all the weights; are you sure that this is what you intended?"
-                )
-                self.model_class.set_weights(
-                    clone_model(self.model_class).get_weights()
-                )
-                self.model_class.compile(
-                    optimizer=Adam(learning_rate=self.learning_rate),
-                    loss=self.loss_class,
-                    metrics=[
-                        "accuracy",
-                        Precision(),
-                        Recall(),
-                        MeanIoU(
-                            num_classes=self.n_classes,
-                            name="iou",
-                            dtype=float,
-                            sparse_y_true=False,
-                            sparse_y_pred=False,
-                        ),
-                    ],
-                )
-            else:
-                self.initial_model = clone_model(self.model_class)
-                self.model_class.set_weights(self.initial_model.get_weights())
-                # Recompile to avoid crash
-                self.model_class.compile(
-                    optimizer=Adam(learning_rate=self.learning_rate),
-                    loss=self.loss_class,
-                    metrics=[
-                        "accuracy",
-                        Precision(),
-                        Recall(),
-                        MeanIoU(
-                            num_classes=self.n_classes,
-                            name="iou",
-                            dtype=float,
-                            sparse_y_true=False,
-                            sparse_y_pred=False,
-                        ),
-                    ],
-                )
-                # Reset weights
-                self.model_class.set_weights(self.initial_model.get_weights())
-        else:
-            print("Compiling the classifier...")
-            self.model_class.compile(
-                optimizer=Adam(learning_rate=self.learning_rate),
-                loss=self.loss_class,
-                metrics=[
-                    "accuracy",
-                    Precision(),
-                    Recall(),
-                    MeanIoU(
-                        num_classes=self.n_classes,
-                        name="iou",
-                        dtype=float,
-                        sparse_y_true=False,
-                        sparse_y_pred=False,
-                    ),
-                ],
-            )
-
-        self.gather_callbacks("classifier")
-
-        # for i in range(30):
-        # for j in range(self.x_train.shape[-1]):
-        # plt.plot(self.x_train[i,:,j])
-        # plt.show()
-
-        if hasattr(self, "x_val"):
-
-            self.history_classifier = self.model_class.fit(
-                x=self.x_train,
-                y=self.y_class_train,
-                batch_size=self.batch_size,
-                class_weight=self.class_weights,
-                epochs=self.epochs,
-                validation_data=(self.x_val, self.y_class_val),
-                callbacks=self.cb,
-                verbose=1,
-            )
-        else:
-            self.history_classifier = self.model_class.fit(
-                x=self.x_train,
-                y=self.y_class_train,
-                batch_size=self.batch_size,
-                class_weight=self.class_weights,
-                epochs=self.epochs,
-                callbacks=self.cb,
-                validation_split=self.validation_split,
-                verbose=1,
-            )
-
-        if self.show_plots:
-            self.plot_model_history(mode="classifier")
-
-        # Set current classification model as the best model
-        self.model_class = load_model(
-            os.sep.join([self.model_folder, "classifier.h5"]),
-            custom_objects={"mse": MeanSquaredError()},
-        )
-        self.model_class.load_weights(os.sep.join([self.model_folder, "classifier.h5"]))
-
-        time_callback = next(
-            (cb for cb in self.cb if isinstance(cb, TimeHistory)), None
-        )
-        self.dico = {
-            "history_classifier": self.history_classifier,
-            "execution_time_classifier": time_callback.times if time_callback else [],
-        }
-
-        if hasattr(self, "x_test"):
-
-            predictions = self.model_class.predict(self.x_test).argmax(axis=1)
-            ground_truth = self.y_class_test.argmax(axis=1)
-            assert (
-                predictions.shape == ground_truth.shape
-            ), "Mismatch in shape between the predictions and the ground truth..."
-
-            title = "Test data"
-            IoU_score = jaccard_score(ground_truth, predictions, average=None)
-            balanced_accuracy = balanced_accuracy_score(ground_truth, predictions)
-            precision = precision_score(ground_truth, predictions, average=None)
-            recall = recall_score(ground_truth, predictions, average=None)
-
-            print(f"Test IoU score: {IoU_score}")
-            print(f"Test Balanced accuracy score: {balanced_accuracy}")
-            print(f"Test Precision: {precision}")
-            print(f"Test Recall: {recall}")
-
-            # Confusion matrix on test set
-            results = confusion_matrix(ground_truth, predictions)
-            self.dico.update(
-                {
-                    "test_IoU": IoU_score,
-                    "test_balanced_accuracy": balanced_accuracy,
-                    "test_confusion": results,
-                    "test_precision": precision,
-                    "test_recall": recall,
-                }
-            )
-
-            if self.show_plots:
-                try:
-                    ConfusionMatrixDisplay.from_predictions(
-                        ground_truth,
-                        predictions,
-                        cmap="Blues",
-                        normalize="pred",
-                        display_labels=["event", "no event", "left censored"],
-                    )
-                    plt.savefig(
-                        os.sep.join([self.model_folder, "test_confusion_matrix.png"]),
-                        bbox_inches="tight",
-                        dpi=300,
-                    )
-                    # plt.pause(3)
-                    plt.close()
-                except Exception as e:
-                    print(e)
-                    pass
-            print("Test set: ", classification_report(ground_truth, predictions))
-
-        if hasattr(self, "x_val"):
-            predictions = self.model_class.predict(self.x_val).argmax(axis=1)
-            ground_truth = self.y_class_val.argmax(axis=1)
-            assert (
-                ground_truth.shape == predictions.shape
-            ), "Mismatch in shape between the predictions and the ground truth..."
-            title = "Validation data"
-
-            # Validation scores
-            IoU_score = jaccard_score(ground_truth, predictions, average=None)
-            balanced_accuracy = balanced_accuracy_score(ground_truth, predictions)
-            precision = precision_score(ground_truth, predictions, average=None)
-            recall = recall_score(ground_truth, predictions, average=None)
-
-            print(f"Validation IoU score: {IoU_score}")
-            print(f"Validation Balanced accuracy score: {balanced_accuracy}")
-            print(f"Validation Precision: {precision}")
-            print(f"Validation Recall: {recall}")
-
-            # Confusion matrix on validation set
-            results = confusion_matrix(ground_truth, predictions)
-            self.dico.update(
-                {
-                    "val_IoU": IoU_score,
-                    "val_balanced_accuracy": balanced_accuracy,
-                    "val_confusion": results,
-                    "val_precision": precision,
-                    "val_recall": recall,
-                }
-            )
-
-            if self.show_plots:
-                try:
-                    ConfusionMatrixDisplay.from_predictions(
-                        ground_truth,
-                        predictions,
-                        cmap="Blues",
-                        normalize="pred",
-                        display_labels=["event", "no event", "left censored"],
-                    )
-                    plt.savefig(
-                        os.sep.join(
-                            [self.model_folder, "validation_confusion_matrix.png"]
-                        ),
-                        bbox_inches="tight",
-                        dpi=300,
-                    )
-                    # plt.pause(3)
-                    plt.close()
-                except Exception as e:
-                    print(e)
-                    pass
-            print("Validation set: ", classification_report(ground_truth, predictions))
-
-        # Send result to GUI and wait
-        for cb in self.cb:
-            if hasattr(cb, "on_training_result"):
-                cb.on_training_result(self.dico)
-                time.sleep(3)
-
-    def train_regressor(self):
-        """
-        Trains the regressor component of the model to estimate the time of interest for events in signals.
-
-        This method compiles the regressor model (if not pretrained or if recompilation is requested) and
-        trains it on a subset of the prepared dataset containing signals with events. The training process
-        includes validation and early stopping based on mean squared error to prevent overfitting.
-
-        Notes
-        -----
-        - The regressor model estimates the time at which an event of interest occurs within each signal.
-        - Only signals predicted to have an event by the classifier model are used for regressor training.
-        - Model performance metrics and training history are saved for analysis.
-
-        """
-
-        # Compile model
-        # if pretrained model
-        if self.pretrained is not None:
-            # if recompile
-            if self.recompile_pretrained:
-                print(
-                    "Recompiling the pretrained regressor model... Warning, this action reinitializes all the weights; are you sure that this is what you intended?"
-                )
-                self.model_reg.set_weights(clone_model(self.model_reg).get_weights())
-                self.model_reg.compile(
-                    optimizer=Adam(learning_rate=self.learning_rate),
-                    loss=self.loss_reg,
-                    metrics=["mse", "mae"],
-                )
-            else:
-                self.initial_model = clone_model(self.model_reg)
-                self.model_reg.set_weights(self.initial_model.get_weights())
-                self.model_reg.compile(
-                    optimizer=Adam(learning_rate=self.learning_rate),
-                    loss=self.loss_reg,
-                    metrics=["mse", "mae"],
-                )
-                self.model_reg.set_weights(self.initial_model.get_weights())
-        else:
-            print("Compiling the regressor...")
-            self.model_reg.compile(
-                optimizer=Adam(learning_rate=self.learning_rate),
-                loss=self.loss_reg,
-                metrics=["mse", "mae"],
-            )
-
-        self.gather_callbacks("regressor")
-
-        # Train on subset of data with event
-
-        subset = self.x_train[np.argmax(self.y_class_train, axis=1) == 0]
-        # for i in range(30):
-        # plt.plot(subset[i,:,0],c="tab:red")
-        # plt.plot(subset[i,:,1],c="tab:blue")
-        # plt.show()
-
-        if hasattr(self, "x_val"):
-            self.history_regressor = self.model_reg.fit(
-                x=self.x_train[np.argmax(self.y_class_train, axis=1) == 0],
-                y=self.y_time_train[np.argmax(self.y_class_train, axis=1) == 0],
-                batch_size=self.batch_size,
-                epochs=self.epochs * 2,
-                validation_data=(
-                    self.x_val[np.argmax(self.y_class_val, axis=1) == 0],
-                    self.y_time_val[np.argmax(self.y_class_val, axis=1) == 0],
-                ),
-                callbacks=self.cb,
-                verbose=1,
-            )
-        else:
-            self.history_regressor = self.model_reg.fit(
-                x=self.x_train[np.argmax(self.y_class_train, axis=1) == 0],
-                y=self.y_time_train[np.argmax(self.y_class_train, axis=1) == 0],
-                batch_size=self.batch_size,
-                epochs=self.epochs * 2,
-                callbacks=self.cb,
-                validation_split=self.validation_split,
-                verbose=1,
-            )
-
-        if self.show_plots:
-            self.plot_model_history(mode="regressor")
-        time_callback = next(
-            (cb for cb in self.cb if isinstance(cb, TimeHistory)), None
-        )
-        self.dico.update(
-            {
-                "history_regressor": self.history_regressor,
-                "execution_time_regressor": (
-                    time_callback.times if time_callback else []
-                ),
-            }
-        )
-
-        # Evaluate best model
-        self.model_reg = load_model(
-            os.sep.join([self.model_folder, "regressor.h5"]),
-            custom_objects={"mse": MeanSquaredError()},
-        )
-        self.model_reg.load_weights(os.sep.join([self.model_folder, "regressor.h5"]))
-        self.evaluate_regression_model()
-
-        try:
-            np.save(os.sep.join([self.model_folder, "scores.npy"]), self.dico)
-        except Exception as e:
-            print(e)
-
-    def plot_model_history(self, mode="regressor"):
-        """
-        Generates and saves plots of the training history for the classifier or regressor model.
-
-        Parameters
-        ----------
-        mode : str, optional
-            Specifies which model's training history to plot. Options are "classifier" or "regressor". Default is "regressor".
-
-        Notes
-        -----
-        - Plots include loss and accuracy metrics over epochs for the classifier, and loss metrics for the regressor.
-        - The plots are saved as image files in the model's output directory.
-
-        """
-
-        if mode == "regressor":
-            try:
-                plt.plot(self.history_regressor.history["loss"])
-                plt.plot(self.history_regressor.history["val_loss"])
-                plt.title("model loss")
-                plt.ylabel("loss")
-                plt.xlabel("epoch")
-                plt.yscale("log")
-                plt.legend(["train", "val"], loc="upper left")
-                # plt.pause(3)
-                plt.savefig(
-                    os.sep.join([self.model_folder, "regression_loss.png"]),
-                    bbox_inches="tight",
-                    dpi=300,
-                )
-                plt.close()
-            except Exception as e:
-                print(f"Error {e}; could not generate plot...")
-        elif mode == "classifier":
-            try:
-                plt.plot(self.history_classifier.history["precision"])
-                plt.plot(self.history_classifier.history["val_precision"])
-                plt.title("model precision")
-                plt.ylabel("precision")
-                plt.xlabel("epoch")
-                plt.legend(["train", "val"], loc="upper left")
-                # plt.pause(3)
-                plt.savefig(
-                    os.sep.join([self.model_folder, "classification_loss.png"]),
-                    bbox_inches="tight",
-                    dpi=300,
-                )
-                plt.close()
-            except Exception as e:
-                print(f"Error {e}; could not generate plot...")
-        else:
-            return None
-
-    def evaluate_regression_model(self):
-        """
-        Evaluates the performance of the trained regression model on test and validation datasets.
-
-        This method calculates and prints mean squared error and mean absolute error metrics for the regression model's
-        predictions. It also generates regression plots comparing predicted times of interest to true values.
-
-        Notes
-        -----
-        - Evaluation is performed on both test and validation datasets, if available.
-        - Regression plots and performance metrics are saved in the model's output directory.
-
-        """
-
-        mse = MeanSquaredError()
-        mae = MeanAbsoluteError()
-
-        if hasattr(self, "x_test"):
-
-            print("Evaluate on test set...")
-            predictions = self.model_reg.predict(
-                self.x_test[np.argmax(self.y_class_test, axis=1) == 0],
-                batch_size=self.batch_size,
-            )[:, 0]
-            ground_truth = self.y_time_test[np.argmax(self.y_class_test, axis=1) == 0]
-            assert (
-                predictions.shape == ground_truth.shape
-            ), "Shape mismatch between predictions and ground truths..."
-
-            test_mse = mse(ground_truth, predictions).numpy()
-            test_mae = mae(ground_truth, predictions).numpy()
-            print(f"MSE on test set: {test_mse}...")
-            print(f"MAE on test set: {test_mae}...")
-            if self.show_plots:
-                regression_plot(
-                    predictions,
-                    ground_truth,
-                    savepath=os.sep.join([self.model_folder, "test_regression.png"]),
-                )
-            self.dico.update({"test_mse": test_mse, "test_mae": test_mae})
-
-        if hasattr(self, "x_val"):
-            # Validation set
-            predictions = self.model_reg.predict(
-                self.x_val[np.argmax(self.y_class_val, axis=1) == 0],
-                batch_size=self.batch_size,
-            )[:, 0]
-            ground_truth = self.y_time_val[np.argmax(self.y_class_val, axis=1) == 0]
-            assert (
-                predictions.shape == ground_truth.shape
-            ), "Shape mismatch between predictions and ground truths..."
-
-            val_mse = mse(ground_truth, predictions).numpy()
-            val_mae = mae(ground_truth, predictions).numpy()
-
-            if self.show_plots:
-                regression_plot(
-                    predictions,
-                    ground_truth,
-                    savepath=os.sep.join(
-                        [self.model_folder, "validation_regression.png"]
-                    ),
-                )
-            print(f"MSE on validation set: {val_mse}...")
-            print(f"MAE on validation set: {val_mae}...")
-
-            self.dico.update(
-                {
-                    "val_mse": val_mse,
-                    "val_mae": val_mae,
-                    "val_predictions": predictions,
-                    "val_ground_truth": ground_truth,
-                }
-            )
-
-        # Send result to GUI and wait
-        for cb in self.cb:
-            if hasattr(cb, "on_training_result"):
-                cb.on_training_result(self.dico)
-                time.sleep(3)
-
-    def gather_callbacks(self, mode):
-        """
-        Prepares a list of Keras callbacks for model training based on the specified mode.
-
-        Parameters
-        ----------
-        mode : str
-            The training mode for which callbacks are being prepared. Options are "classifier" or "regressor".
-
-        Notes
-        -----
-        - Callbacks include learning rate reduction on plateau, early stopping, model checkpointing, and TensorBoard logging.
-        - The list of callbacks is stored in the class attribute `cb` and used during model training.
-
-        """
-
-        self.cb = []
-
-        if mode == "classifier":
-
-            reduce_lr = ReduceLROnPlateau(
-                monitor="val_iou",
-                factor=0.5,
-                patience=30,
-                cooldown=10,
-                min_lr=5e-10,
-                min_delta=1.0e-10,
-                verbose=1,
-                mode="max",
-            )
-            self.cb.append(reduce_lr)
-            csv_logger = CSVLogger(
-                os.sep.join([self.model_folder, "log_classifier.csv"]),
-                append=True,
-                separator=";",
-            )
-            self.cb.append(csv_logger)
-            checkpoint_path = os.sep.join([self.model_folder, "classifier.h5"])
-            cp_callback = ModelCheckpoint(
-                checkpoint_path,
-                monitor="val_iou",
-                mode="max",
-                verbose=1,
-                save_best_only=True,
-                save_weights_only=False,
-                save_freq="epoch",
-            )
-            self.cb.append(cp_callback)
-
-            callback_stop = EarlyStopping(monitor="val_iou", mode="max", patience=100)
-            self.cb.append(callback_stop)
-
-        elif mode == "regressor":
-
-            reduce_lr = ReduceLROnPlateau(
-                monitor="val_loss",
-                factor=0.5,
-                patience=30,
-                cooldown=10,
-                min_lr=5e-10,
-                min_delta=1.0e-10,
-                verbose=1,
-                mode="min",
-            )
-            self.cb.append(reduce_lr)
-
-            csv_logger = CSVLogger(
-                os.sep.join([self.model_folder, "log_regressor.csv"]),
-                append=True,
-                separator=";",
-            )
-            self.cb.append(csv_logger)
-
-            checkpoint_path = os.sep.join([self.model_folder, "regressor.h5"])
-            cp_callback = ModelCheckpoint(
-                checkpoint_path,
-                monitor="val_loss",
-                mode="min",
-                verbose=1,
-                save_best_only=True,
-                save_weights_only=False,
-                save_freq="epoch",
-            )
-            self.cb.append(cp_callback)
-
-            callback_stop = EarlyStopping(monitor="val_loss", mode="min", patience=200)
-            self.cb.append(callback_stop)
-
-        log_dir = self.model_folder + os.sep
-        cb_tb = TensorBoard(log_dir=log_dir, update_freq="batch")
-        self.cb.append(cb_tb)
-
-        cb_time = TimeHistory()
-        self.cb.append(cb_time)
-
-        if hasattr(self, "callbacks") and self.callbacks is not None:
-            self.cb.extend(self.callbacks)
-
-    def prepare_sets(self):
-        """
-        Generates and preprocesses training, validation, and test sets from loaded annotations.
-
-        This method loads signal data from annotation files, normalizes and interpolates the signals, and splits
-        the dataset into training, validation, and test sets according to specified proportions.
-
-        Notes
-        -----
-        - Signal annotations are expected to be stored in .npy format and contain required channels and event information.
-        - The method applies specified normalization and interpolation options to prepare the signals for model training.
-
-        """
-
-        self.x_set = []
-        self.y_time_set = []
-        self.y_class_set = []
-
-        if isinstance(self.list_of_sets[0], str):
-            # Case 1: a list of npy files to be loaded
-            for s in self.list_of_sets:
-
-                signal_dataset = self.load_set(s)
-                selected_signals, max_length = self.find_best_signal_match(
-                    signal_dataset
-                )
-                signals_recast, classes, times_of_interest = (
-                    self.cast_signals_into_training_data(
-                        signal_dataset, selected_signals, max_length
-                    )
-                )
-                signals_recast, times_of_interest = self.normalize_signals(
-                    signals_recast, times_of_interest
-                )
-
-                self.x_set.extend(signals_recast)
-                self.y_time_set.extend(times_of_interest)
-                self.y_class_set.extend(classes)
-
-        elif isinstance(self.list_of_sets[0], list):
-            # Case 2: a list of sets (already loaded)
-            for signal_dataset in self.list_of_sets:
-
-                selected_signals, max_length = self.find_best_signal_match(
-                    signal_dataset
-                )
-                signals_recast, classes, times_of_interest = (
-                    self.cast_signals_into_training_data(
-                        signal_dataset, selected_signals, max_length
-                    )
-                )
-                signals_recast, times_of_interest = self.normalize_signals(
-                    signals_recast, times_of_interest
-                )
-
-                self.x_set.extend(signals_recast)
-                self.y_time_set.extend(times_of_interest)
-                self.y_class_set.extend(classes)
-
-        self.x_set = np.array(self.x_set).astype(np.float32)
-        self.x_set = self.interpolate_signals(self.x_set)
-
-        self.y_time_set = np.array(self.y_time_set).astype(np.float32)
-        self.y_class_set = np.array(self.y_class_set).astype(np.float32)
-
-        class_test = np.isin(self.y_class_set, [0, 1, 2])
-        self.x_set = self.x_set[class_test]
-        self.y_time_set = self.y_time_set[class_test]
-        self.y_class_set = self.y_class_set[class_test]
-
-        # Compute class weights and one-hot encode
-        self.class_weights = compute_weights(self.y_class_set)
-        self.nbr_classes = 3  # len(np.unique(self.y_class_set))
-        self.y_class_set = to_categorical(self.y_class_set, num_classes=3)
-
-        ds = train_test_split(
-            self.x_set,
-            self.y_time_set,
-            self.y_class_set,
-            validation_size=self.validation_split,
-            test_size=self.test_split,
-        )
-
-        self.x_train = ds["x_train"]
-        self.x_val = ds["x_val"]
-        self.y_time_train = ds["y1_train"].astype(np.float32)
-        self.y_time_val = ds["y1_val"].astype(np.float32)
-        self.y_class_train = ds["y2_train"]
-        self.y_class_val = ds["y2_val"]
-
-        if self.test_split > 0:
-            self.x_test = ds["x_test"]
-            self.y_time_test = ds["y1_test"].astype(np.float32)
-            self.y_class_test = ds["y2_test"]
-
-        if self.augment:
-            self.augment_training_set()
-
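The class-filtering and one-hot steps of `prepare_sets` can be checked in isolation. A minimal sketch with toy data (the arrays are made up; `compute_weights` and the dict-returning `train_test_split` are celldetective helpers and are not reproduced here):

    import numpy as np
    from tensorflow.keras.utils import to_categorical

    x_set = np.random.rand(6, 128, 2).astype(np.float32)      # 6 cells, 128 frames, 2 channels
    y_class = np.array([0, 1, 2, 3, 1, 0], dtype=np.float32)  # one label per cell

    keep = np.isin(y_class, [0, 1, 2])   # discard anything outside the 3 supported classes
    x_set, y_class = x_set[keep], y_class[keep]

    y_onehot = to_categorical(y_class, num_classes=3)
    print(x_set.shape, y_onehot.shape)   # (5, 128, 2) (5, 3)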
-    def augment_training_set(self, time_shift=True):
-        """
-        Augments the training dataset with artificially generated data to increase model robustness.
-
-        Parameters
-        ----------
-        time_shift : bool, optional
-            Specifies whether to include time-shifted versions of signals in the augmented dataset. Default is True.
-
-        Notes
-        -----
-        - Augmentation strategies include random time shifting and signal modifications to simulate variations in real data.
-        - The augmented dataset is used for training the classifier and regressor models to improve generalization.
-
-        """
-
-        nbr_augment = int(self.augmentation_factor * len(self.x_train))
-        randomize = np.arange(len(self.x_train))
-
-        unique, counts = np.unique(
-            self.y_class_train.argmax(axis=1), return_counts=True
-        )
-        frac = counts / sum(counts)
-        weights = [frac[0] / f for f in frac]
-        weights[0] = weights[0] * 3
-
-        self.pre_augment_weights = weights / sum(weights)
-        weights_array = [
-            self.pre_augment_weights[a.argmax()] for a in self.y_class_train
-        ]
-
-        indices = random.choices(randomize, k=nbr_augment, weights=weights_array)
-
-        x_train_aug = []
-        y_time_train_aug = []
-        y_class_train_aug = []
-
-        counts = [0.0, 0.0, 0.0]
-        # warning: augmentation can create class 2 samples even when that class is absent from the data; this still needs to be addressed
-        for k in indices:
-            counts[self.y_class_train[k].argmax()] += 1
-            aug = augmenter(
-                self.x_train[k],
-                self.y_time_train[k],
-                self.y_class_train[k],
-                self.model_signal_length,
-                time_shift=time_shift,
-            )
-            x_train_aug.append(aug[0])
-            y_time_train_aug.append(aug[1])
-            y_class_train_aug.append(aug[2])
-
-        # Save the augmented training set
-        self.x_train = np.array(x_train_aug)
-        self.y_time_train = np.array(y_time_train_aug)
-        self.y_class_train = np.array(y_class_train_aug)
-
-        self.class_weights = compute_weights(self.y_class_train.argmax(axis=1))
-        print(f"New class weights: {self.class_weights}...")
-
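The index sampling above is plain inverse-frequency resampling: each sample is weighted by the inverse of its class frequency (with class 0 further boosted by a factor of 3 in the code above), so rare classes are drawn more often. A self-contained sketch of the core idea, with made-up labels:

    import random
    import numpy as np

    labels = np.array([0, 0, 0, 0, 1, 2])        # imbalanced toy labels
    _, counts = np.unique(labels, return_counts=True)
    frac = counts / counts.sum()
    class_w = frac[0] / frac                      # inverse frequency, relative to class 0
    sample_w = [class_w[lab] for lab in labels]   # one weight per sample

    indices = random.choices(range(len(labels)), weights=sample_w, k=12)
    print(np.bincount(labels[indices]))           # rare classes are over-represented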
-    def load_set(self, signal_dataset):
-        return np.load(signal_dataset, allow_pickle=True)
-
-    def find_best_signal_match(self, signal_dataset):
-
-        required_signals = self.channel_option
-        available_signals = list(signal_dataset[0].keys())
-
-        selected_signals = []
-        for s in required_signals:
-            pattern_test = [s in a for a in available_signals]
-            if np.any(pattern_test):
-                valid_columns = np.array(available_signals)[np.array(pattern_test)]
-                if len(valid_columns) == 1:
-                    selected_signals.append(valid_columns[0])
-                else:
-                    print(f"Found several candidate signals: {valid_columns}")
-                    for vc in natsorted(valid_columns):
-                        if "circle" in vc:
-                            selected_signals.append(vc)
-                            break
-                    else:
-                        selected_signals.append(valid_columns[0])
-            else:
-                # NB: the callers unpack a 2-tuple, so a missing required signal surfaces as a TypeError
-                return None
-
-        key_to_check = selected_signals[0]  # self.channel_option[0]
-        signal_lengths = [len(l[key_to_check]) for l in signal_dataset]
-        max_length = np.amax(signal_lengths)
-
-        return selected_signals, max_length
-
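The column matching above boils down to a substring search with a preference for 'circle' measurements when several columns match. A standalone sketch (the column names are invented):

    from natsort import natsorted

    available = ["dead_nuclei_channel_mean", "intensity_circle_dead_nuclei_channel", "area"]
    requested = "dead_nuclei_channel"

    candidates = [a for a in available if requested in a]
    if len(candidates) == 1:
        choice = candidates[0]
    else:
        # prefer a 'circle' measurement when several columns match
        choice = next((c for c in natsorted(candidates) if "circle" in c), candidates[0])
    print(choice)  # intensity_circle_dead_nuclei_channel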
-    def cast_signals_into_training_data(
-        self, signal_dataset, selected_signals, max_length
-    ):
-
-        signals_recast = np.zeros((len(signal_dataset), max_length, self.n_channels))
-        classes = np.zeros(len(signal_dataset))
-        times_of_interest = np.zeros(len(signal_dataset))
-
-        for k in range(len(signal_dataset)):
-
-            for i in range(self.n_channels):
-                try:
-                    # take the timeline into account for accurate time regression
-                    if selected_signals[i].startswith("pair_"):
-                        timeline = signal_dataset[k]["pair_FRAME"].astype(int)
-                    elif selected_signals[i].startswith("reference_"):
-                        timeline = signal_dataset[k]["reference_FRAME"].astype(int)
-                    elif selected_signals[i].startswith("neighbor_"):
-                        timeline = signal_dataset[k]["neighbor_FRAME"].astype(int)
-                    else:
-                        timeline = signal_dataset[k]["FRAME"].astype(int)
-                    signals_recast[k, timeline, i] = signal_dataset[k][
-                        selected_signals[i]
-                    ]
-                except KeyError:
-                    print(
-                        f"Attribute {selected_signals[i]} matched to {self.channel_option[i]} not found in annotation..."
-                    )
-
-            classes[k] = signal_dataset[k]["class"]
-            times_of_interest[k] = signal_dataset[k]["time_of_interest"]
-
-        # Correct absurd times of interest
-        times_of_interest[np.nonzero(classes)] = -1
-        times_of_interest[(times_of_interest <= 0.0)] = -1
-
-        return signals_recast, classes, times_of_interest
-
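The recasting step scatters each measured value at its frame index inside a fixed-size, zero-initialized array, so gaps in a track stay at zero. A minimal sketch with a made-up track:

    import numpy as np

    max_length, n_channels = 8, 1
    track = {"FRAME": np.array([0, 1, 3, 4]),          # frame 2 is missing
             "intensity": np.array([5.0, 6.0, 8.0, 9.0])}

    recast = np.zeros((1, max_length, n_channels))
    timeline = track["FRAME"].astype(int)
    recast[0, timeline, 0] = track["intensity"]
    print(recast[0, :, 0])                             # [5. 6. 0. 8. 9. 0. 0. 0.]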
-    def normalize_signals(self, signals_recast, times_of_interest):
-
-        signals_recast = pad_to_model_length(signals_recast, self.model_signal_length)
-        if self.normalize:
-            signals_recast = normalize_signal_set(
-                signals_recast,
-                self.channel_option,
-                normalization_percentile=self.normalization_percentile,
-                normalization_values=self.normalization_values,
-                normalization_clip=self.normalization_clip,
-            )
-
-        # Trivial normalization for time of interest
-        times_of_interest /= self.model_signal_length
-
-        return signals_recast, times_of_interest
-
-
-def _interpret_normalization_parameters(
-    n_channels, normalization_percentile, normalization_values, normalization_clip
-):
-    """
-    Interprets and validates normalization parameters for each channel.
-
-    This function ensures the normalization parameters are correctly formatted and expanded to match
-    the number of channels in the dataset. It provides default values and expands single values into
-    lists to match the number of channels if necessary.
-
-    Parameters
-    ----------
-    n_channels : int
-        The number of channels in the dataset.
-    normalization_percentile : list of bool or bool, optional
-        Specifies whether to normalize each channel based on percentile values. If a single bool is provided,
-        it is expanded to a list matching the number of channels. Default is True for all channels.
-    normalization_values : list of lists or list, optional
-        Specifies the percentile values [lower, upper] for normalization for each channel. If a single pair
-        is provided, it is expanded to match the number of channels. Default is [[0.1, 99.9]] for all channels.
-    normalization_clip : list of bool or bool, optional
-        Specifies whether to clip the normalized values for each channel to the range [0, 1]. If a single bool
-        is provided, it is expanded to a list matching the number of channels. Default is False for all channels.
-
-    Returns
-    -------
-    tuple
-        A tuple containing three lists: `normalization_percentile`, `normalization_values`, and `normalization_clip`,
-        each of length `n_channels`, representing the interpreted and validated normalization parameters for each channel.
-
-    Raises
-    ------
-    AssertionError
-        If the lengths of the provided lists do not match `n_channels`.
-
-    Examples
-    --------
-    >>> n_channels = 2
-    >>> normalization_percentile = True
-    >>> normalization_values = [0.1, 99.9]
-    >>> normalization_clip = False
-    >>> params = _interpret_normalization_parameters(n_channels, normalization_percentile, normalization_values, normalization_clip)
-    >>> print(params)
-    # ([True, True], [[0.1, 99.9], [0.1, 99.9]], [False, False])
-
-    """
-
-    if normalization_percentile is None:
-        normalization_percentile = [True] * n_channels
-    if normalization_values is None:
-        normalization_values = [[0.1, 99.9]] * n_channels
-    if normalization_clip is None:
-        normalization_clip = [False] * n_channels
-
-    if isinstance(normalization_percentile, bool):
-        normalization_percentile = [normalization_percentile] * n_channels
-    if isinstance(normalization_clip, bool):
-        normalization_clip = [normalization_clip] * n_channels
-    if len(normalization_values) == 2 and not isinstance(normalization_values[0], list):
-        normalization_values = [normalization_values] * n_channels
-
-    assert len(normalization_values) == n_channels
-    assert len(normalization_clip) == n_channels
-    assert len(normalization_percentile) == n_channels
-
-    return normalization_percentile, normalization_values, normalization_clip
-
-
-def normalize_signal_set(
-    signal_set,
-    channel_option,
-    percentile_alive=[0.01, 99.99],
-    percentile_dead=[0.5, 99.999],
-    percentile_generic=[0.01, 99.99],
-    normalization_percentile=None,
-    normalization_values=None,
-    normalization_clip=None,
-):
-    """
-    Normalizes a set of single-cell signals across specified channels using given percentile values or specific normalization parameters.
-
-    This function applies normalization to each channel in the signal set based on the provided normalization parameters,
-    which can be defined globally or per channel. The normalization process aims to scale the signal values to a standard
-    range, improving the consistency and comparability of signal measurements across samples.
-
-    Parameters
-    ----------
-    signal_set : ndarray
-        A 3D numpy array representing the set of signals to be normalized, with dimensions corresponding to (samples, time points, channels).
-    channel_option : list of str
-        A list specifying the channels included in the signal set and their corresponding normalization strategy based on channel names.
-    percentile_alive : list of float, optional
-        The percentile values [lower, upper] used for normalization of signals from channels labeled as 'alive'. Default is [0.01, 99.99].
-    percentile_dead : list of float, optional
-        The percentile values [lower, upper] used for normalization of signals from channels labeled as 'dead'. Default is [0.5, 99.999].
-    percentile_generic : list of float, optional
-        The percentile values [lower, upper] used for normalization of signals from channels not specifically labeled as 'alive' or 'dead'.
-        Default is [0.01, 99.99].
-    normalization_percentile : list of bool or None, optional
-        Specifies whether to normalize each channel based on percentile values. If None, the default percentile strategy is applied.
-        If a list, it should match the length of `channel_option`.
-    normalization_values : list of lists or None, optional
-        Specifies the percentile values [lower, upper] or fixed values [min, max] for normalization for each channel. Overrides
-        `percentile_alive`, `percentile_dead`, and `percentile_generic` if provided.
-    normalization_clip : list of bool or None, optional
-        Specifies whether to clip the normalized values for each channel to the range [0, 1]. If None, clipping is disabled by default.
-
-    Returns
-    -------
-    ndarray
-        The normalized signal set with the same shape as the input `signal_set`.
-
-    Notes
-    -----
-    - The `percentile_alive`, `percentile_dead`, and `percentile_generic` arguments are legacy defaults; the current
-      implementation derives the per-channel bounds solely from `normalization_percentile` and `normalization_values`.
-    - Zero entries (padding) are excluded from the percentile computation and reset to zero after normalization.
-    - Normalization parameters (`normalization_percentile`, `normalization_values`, `normalization_clip`) are interpreted and validated
-      by calling `_interpret_normalization_parameters`.
-
-    Examples
-    --------
-    >>> signal_set = np.random.rand(100, 128, 2)  # 100 samples, 128 time points, 2 channels
-    >>> channel_option = ['alive', 'dead']
-    >>> normalized_signals = normalize_signal_set(signal_set, channel_option)
-    # Normalizes the signal set based on the default percentile values.
-
-    """
-
-    # Check that the normalization parameters are valid
-    n_channels = len(channel_option)
-    normalization_percentile, normalization_values, normalization_clip = (
-        _interpret_normalization_parameters(
-            n_channels,
-            normalization_percentile,
-            normalization_values,
-            normalization_clip,
-        )
-    )
-    for k, channel in enumerate(channel_option):
-
-        zero_values = []
-        for i in range(len(signal_set)):
-            zeros_loc = np.where(signal_set[i, :, k] == 0)
-            zero_values.append(zeros_loc)
-
-        values = signal_set[:, :, k]
-
-        if normalization_percentile[k]:
-            min_val = np.nanpercentile(
-                values[values != 0.0], normalization_values[k][0]
-            )
-            max_val = np.nanpercentile(
-                values[values != 0.0], normalization_values[k][1]
-            )
-        else:
-            min_val = normalization_values[k][0]
-            max_val = normalization_values[k][1]
-
-        signal_set[:, :, k] -= min_val
-        signal_set[:, :, k] /= max_val - min_val
-
-        if normalization_clip[k]:
-            to_clip_low = []
-            to_clip_high = []
-            for i in range(len(signal_set)):
-                clip_low_loc = np.where(signal_set[i, :, k] <= 0)
-                clip_high_loc = np.where(signal_set[i, :, k] >= 1.0)
-                to_clip_low.append(clip_low_loc)
-                to_clip_high.append(clip_high_loc)
-
-            for i, z in enumerate(to_clip_low):
-                signal_set[i, z, k] = 0.0
-            for i, z in enumerate(to_clip_high):
-                signal_set[i, z, k] = 1.0
-
-        for i, z in enumerate(zero_values):
-            signal_set[i, z, k] = 0.0
-
-    return signal_set
-
-
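The per-channel normalization reduces to a percentile-based min-max rescaling that ignores the zero padding. A minimal sketch on one channel (toy values, no clipping):

    import numpy as np

    values = np.array([0.0, 10.0, 12.0, 15.0, 100.0, 0.0])   # zeros mark padded frames
    nonzero = values[values != 0.0]
    min_val = np.nanpercentile(nonzero, 1)                   # robust lower bound
    max_val = np.nanpercentile(nonzero, 99)                  # robust upper bound

    normalized = (values - min_val) / (max_val - min_val)
    normalized[values == 0.0] = 0.0                          # restore padding
    print(normalized.round(3))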
-def pad_to_model_length(signal_set, model_signal_length):
-    """
-    Pad the signal set to match the specified model signal length.
-
-    Parameters
-    ----------
-    signal_set : array-like
-        The signal set to be padded.
-    model_signal_length : int
-        The desired length of the model signal.
-
-    Returns
-    -------
-    array-like
-        The padded signal set.
-
-    Notes
-    -----
-    This function pads the signal set along the second dimension (axis 1) to match the specified model signal
-    length. The padding is applied to the end of the signals with `mode="edge"`, i.e. the last value of each
-    signal is repeated rather than filled with zeros.
-
-    Examples
-    --------
-    >>> signal_set = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
-    >>> padded_signals = pad_to_model_length(signal_set, 5)
-
-    """
-
-    padded = np.pad(
-        signal_set,
-        [(0, 0), (0, model_signal_length - signal_set.shape[1]), (0, 0)],
-        mode="edge",
-    )
-
-    return padded
-
-
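Edge padding matters here: repeating the last observed value avoids the artificial step at the end of short tracks that a zero fill would introduce. A quick comparison:

    import numpy as np

    signal_set = np.array([[[1.0], [2.0], [3.0]]])   # 1 cell, 3 frames, 1 channel
    pad_spec = [(0, 0), (0, 2), (0, 0)]              # extend the time axis from 3 to 5

    print(np.pad(signal_set, pad_spec, mode="edge")[0, :, 0])      # [1. 2. 3. 3. 3.]
    print(np.pad(signal_set, pad_spec, mode="constant")[0, :, 0])  # [1. 2. 3. 0. 0.]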
-def random_intensity_change(signal):
-    """
-    Randomly change the intensity of a signal.
-
-    Parameters
-    ----------
-    signal : array-like
-        The input signal to be modified.
-
-    Returns
-    -------
-    array-like
-        The modified signal with randomly changed intensity.
-
-    Notes
-    -----
-    This function applies a random intensity change to each channel of the input signal. The intensity change is
-    performed by multiplying each channel with a random value drawn from a uniform distribution between 0.7 and 1.0.
-
-    Examples
-    --------
-    >>> signal = np.array([[1, 2, 3], [4, 5, 6]])
-    >>> modified_signal = random_intensity_change(signal)
-
-    """
-
-    for k in range(signal.shape[1]):
-        signal[:, k] = signal[:, k] * np.random.uniform(0.7, 1.0)
-
-    return signal
-
-
-def gauss_noise(signal):
-    """
-    Add Gaussian noise to a signal.
-
-    Parameters
-    ----------
-    signal : array-like
-        The input signal to which noise will be added.
-
-    Returns
-    -------
-    array-like
-        The signal with Gaussian noise added.
-
-    Notes
-    -----
-    This function adds multiplicative Gaussian noise to the input signal. A relative noise amplitude is first
-    drawn uniformly from [0, 0.08]; standard normal values are then scaled by this amplitude and by the signal
-    itself before being added, so the perturbation is proportional to the local signal value.
-
-    Examples
-    --------
-    >>> signal = np.array([1, 2, 3, 4, 5])
-    >>> noisy_signal = gauss_noise(signal)
-
-    """
-
-    sig = 0.08 * np.random.uniform(0, 1)
-    signal = signal + sig * np.random.normal(0, 1, signal.shape) * signal
-    return signal
-
-
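Both augmentations above are one-liners over NumPy arrays. A small standalone demo of the same two operations on a toy one-channel signal (the operations are re-stated inline rather than calling the removed functions):

    import numpy as np

    signal = np.linspace(1.0, 2.0, 10).reshape(10, 1)   # 10 frames, 1 channel

    # per-channel random gain in [0.7, 1.0]
    dimmed = signal * np.random.uniform(0.7, 1.0, size=(1, signal.shape[1]))

    # multiplicative Gaussian noise with relative amplitude up to 8%
    amp = 0.08 * np.random.uniform(0, 1)
    noisy = signal + amp * np.random.normal(0, 1, signal.shape) * signal

    print(dimmed.ravel().round(2), noisy.ravel().round(2))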
-def random_time_shift(signal, time_of_interest, cclass, model_signal_length):
-    """
-    Randomly shift a signal to another time.
-
-    Parameters
-    ----------
-    signal : array-like
-        The signal to be shifted.
-    time_of_interest : int or float
-        The original time of interest for the signal. Use -1 if not applicable.
-    cclass : array-like
-        The one-hot encoded class of the cell, which may be updated when the shift censors the event.
-    model_signal_length : int
-        The length of the model signal.
-
-    Returns
-    -------
-    array-like
-        The shifted signal.
-    int or float
-        The new time of interest if available; otherwise, the original time of interest.
-    array-like
-        The (possibly updated) one-hot encoded class.
-
-    Notes
-    -----
-    This function randomly selects a target time and shifts the signal accordingly. The shift is performed along
-    the first dimension (axis 0) of the signal, using nearest-neighbor interpolation.
-
-    If the original time of interest (`time_of_interest`) is provided (not equal to -1), the target time is drawn
-    from [-model_signal_length / 3, model_signal_length + model_signal_length / 3], so that roughly a third of
-    event-class samples end up censored on either side, and the function returns the shifted signal along with the
-    new time of interest. Otherwise, the target time is drawn from [3, model_signal_length] and the original time
-    of interest is returned unchanged.
-
-    When the shift pushes an event out of the observation window, the class is updated accordingly: a target time
-    <= 0 turns an event-class sample into class 2, and a target time >= model_signal_length turns it into class 1
-    (no event).
-
-    Examples
-    --------
-    >>> signal = np.array([[1, 2, 3], [4, 5, 6]])
-    >>> cclass = np.array([1.0, 0.0, 0.0])
-    >>> shifted_signal, new_time, new_class = random_time_shift(signal, 1, cclass, 5)
-
-    """
-
-    min_time = 3
-    max_time = model_signal_length
-
-    return_target = False
-    if time_of_interest != -1:
-        return_target = True
-        max_time = (
-            model_signal_length + 1 / 3 * model_signal_length
-        )  # bias to have about a third of the event class become 'no event'
-        min_time = -model_signal_length * 1 / 3
-
-    times = np.linspace(
-        min_time, max_time, 2000
-    )  # symmetrize to create left-censored events
-    target_time = np.random.choice(times)
-
-    delta_t = target_time - time_of_interest
-    signal = shift(signal, [delta_t, 0], order=0, mode="nearest")
-
-    if target_time <= 0 and np.argmax(cclass) == 0:
-        target_time = -1
-        cclass = np.array([0.0, 0.0, 1.0]).astype(np.float32)
-    if target_time >= model_signal_length and np.argmax(cclass) == 0:
-        target_time = -1
-        cclass = np.array([0.0, 1.0, 0.0]).astype(np.float32)
-
-    if return_target:
-        return signal, target_time, cclass
-    else:
-        return signal, time_of_interest, cclass
-
-
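`scipy.ndimage.shift` with `order=0` and `mode="nearest"` moves the whole trace along the time axis without interpolation, repeating the boundary value in the gap it leaves behind. A minimal demo:

    import numpy as np
    from scipy.ndimage import shift

    signal = np.arange(6, dtype=float).reshape(6, 1)   # one channel, values 0..5
    shifted = shift(signal, [2, 0], order=0, mode="nearest")
    print(shifted.ravel())                             # [0. 0. 0. 1. 2. 3.]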
-def augmenter(
-    signal,
-    time_of_interest,
-    cclass,
-    model_signal_length,
-    time_shift=True,
-    probability=0.95,
-):
-    """
-    Randomly augments single-cell signals to simulate variations in noise, intensity ratios, and event times.
-
-    This function applies random transformations to the input signal, including time shifts, intensity changes,
-    and the addition of Gaussian noise, with the aim of increasing the diversity of the dataset for training robust models.
-
-    Parameters
-    ----------
-    signal : ndarray
-        A 2D numpy array (time points, channels) representing the signal of a single cell to be augmented.
-    time_of_interest : float
-        The normalized time of interest (event time) for the signal, scaled to the range [0, 1].
-    cclass : ndarray
-        A one-hot encoded numpy array representing the class of the cell associated with the signal.
-    model_signal_length : int
-        The length of the signal expected by the model, used for scaling the time of interest.
-    time_shift : bool, optional
-        Specifies whether to apply random time shifts to the signal. Default is True.
-    probability : float, optional
-        The probability with which to apply the augmentation transformations. Default is 0.95.
-
-    Returns
-    -------
-    tuple
-        A tuple containing the augmented signal, the normalized time of interest, and the class of the cell.
-
-    Raises
-    ------
-    AssertionError
-        If the time of interest is provided but invalid for time shifting.
-
-    Notes
-    -----
-    - Time shifting is not applied to cells of the class labeled as 'miscellaneous' (typically encoded as the class '2').
-    - The time of interest is rescaled based on the model's expected signal length before and after any time shift.
-    - Augmentation is applied with the specified probability to simulate realistic variability while maintaining
-      some original signals in the dataset.
-
-    """
-
-    if np.amax(time_of_interest) <= 1.0:
-        time_of_interest *= model_signal_length
-
-    # augment with a certain probability
-    r = random.random()
-    if r <= probability:
-
-        if time_shift:
-            # do not time shift miscellaneous cells
-            assert time_of_interest is not None, "Please provide valid lysis times"
-            signal, time_of_interest, cclass = random_time_shift(
-                signal, time_of_interest, cclass, model_signal_length
-            )
-
-        # signal = random_intensity_change(signal)  # maybe a bad idea for non-percentile-normalized signals
-        signal = gauss_noise(signal)
-
-    return signal, time_of_interest / model_signal_length, cclass
-
-
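For completeness, a plausible call pattern for `augmenter` on one annotated cell, assuming a normalized event time (the toy values are made up):

    import numpy as np

    model_signal_length = 128
    signal = np.random.rand(model_signal_length, 2)         # (time points, channels)
    cclass = np.array([1.0, 0.0, 0.0], dtype=np.float32)    # event class
    time_of_interest = 40.0 / model_signal_length           # normalized event time

    aug_signal, aug_time, aug_class = augmenter(
        signal, time_of_interest, cclass, model_signal_length, time_shift=True
    )
    print(aug_signal.shape, aug_time, aug_class)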
-def residual_block1D(
-    x, number_of_filters, kernel_size=8, match_filter_size=True, connection="identity"
-):
-    """
-    Create a 1D residual block.
-
-    Parameters
-    ----------
-    x : Tensor
-        Input tensor.
-    number_of_filters : int
-        Number of filters in the convolutional layers.
-    kernel_size : int, optional
-        Kernel size of the convolutional layers. Default is 8.
-    match_filter_size : bool, optional
-        Whether to match the filter size of the skip connection to the output. Default is True.
-    connection : str, optional
-        Type of skip connection: "identity" preserves the temporal resolution, "projection" downsamples by a
-        stride of 2. Default is "identity".
-
-    Returns
-    -------
-    Tensor
-        Output tensor of the residual block.
-
-    Notes
-    -----
-    This function creates a 1D residual block by performing the original mapping followed by adding a skip connection
-    and applying a non-linear activation. The skip connection allows the gradient to flow directly to earlier layers and
-    helps mitigate the vanishing gradient problem. The residual block consists of two convolutional layers with
-    batch normalization and ReLU activations, plus an optional 1x1 convolution on the skip connection.
-
-    If `match_filter_size` is True, the skip connection is adjusted to have the same number of filters as the output.
-    Otherwise, the skip connection is kept as is.
-
-    Examples
-    --------
-    >>> inputs = Input(shape=(10, 3))
-    >>> x = residual_block1D(inputs, 64)
-    # Create a 1D residual block with 64 filters and apply it to the input tensor.
-
-    """
-
-    # Create the skip connection
-    x_skip = x
-
-    # Perform the original mapping
-    if connection == "identity":
-        x = Conv1D(
-            number_of_filters, kernel_size=kernel_size, strides=1, padding="same"
-        )(x_skip)
-    elif connection == "projection":
-        x = ZeroPadding1D(padding=kernel_size // 2)(x_skip)
-        x = Conv1D(
-            number_of_filters, kernel_size=kernel_size, strides=2, padding="valid"
-        )(x)
-    x = BatchNormalization()(x)
-    x = Activation("relu")(x)
-
-    x = Conv1D(number_of_filters, kernel_size=kernel_size, strides=1, padding="same")(x)
-    x = BatchNormalization()(x)
-
-    if match_filter_size and connection == "identity":
-        x_skip = Conv1D(number_of_filters, kernel_size=1, padding="same")(x_skip)
-    elif match_filter_size and connection == "projection":
-        x_skip = Conv1D(number_of_filters, kernel_size=1, strides=2, padding="valid")(
-            x_skip
-        )
-
-    # Add the skip connection to the regular mapping
-    x = Add()([x, x_skip])
-
-    # Nonlinearly activate the result
-    x = Activation("relu")(x)
-
-    # Return the result
-    return x
-
-
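The two connection modes differ in temporal resolution: "identity" preserves it, "projection" halves it. A quick shape check, assuming the Keras layers imported by the old module are in scope; note the odd kernel size in the projection case, which keeps the main and skip paths the same length after the stride-2 convolutions:

    from tensorflow.keras import Input

    inputs = Input(shape=(128, 2))
    y_id = residual_block1D(inputs, 64, kernel_size=8, connection="identity")
    y_proj = residual_block1D(inputs, 64, kernel_size=7, connection="projection")
    print(y_id.shape, y_proj.shape)   # (None, 128, 64) (None, 64, 64)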
-def MultiscaleResNetModel(
-    n_channels,
-    n_classes=3,
-    dropout_rate=0,
-    dense_collection=0,
-    use_pooling=True,
-    header="classifier",
-    model_signal_length=128,
-):
-    """
-    Define a multiscale ResNet 1D encoder model.
-
-    Parameters
-    ----------
-    n_channels : int
-        Number of input channels.
-    n_classes : int, optional
-        Number of output classes. Default is 3.
-    dropout_rate : float, optional
-        Dropout rate to be applied. Default is 0.
-    dense_collection : int, optional
-        Number of neurons in the dense layer. Default is 0.
-    use_pooling : bool, optional
-        Present for interface compatibility; not used by this architecture. Default is True.
-    header : str, optional
-        Type of the model header. "classifier" for classification, "regressor" for regression. Default is "classifier".
-    model_signal_length : int, optional
-        Length of the input signal. Default is 128.
-
-    Returns
-    -------
-    keras.models.Model
-        Multiscale ResNet 1D encoder model.
-
-    Notes
-    -----
-    This function defines a ResNet 1D encoder with three parallel stacks of residual blocks (kernel sizes 7, 5,
-    and 3) that share a common stem and are concatenated after global average pooling. The final activation and
-    number of neurons in the output layer are determined by the header type.
-
-    Examples
-    --------
-    >>> model = MultiscaleResNetModel(n_channels=3, n_classes=2, dropout_rate=0.2)
-    # Define a multiscale ResNet 1D encoder model with 3 input channels and 2 output classes.
-
-    """
-
-    if header == "classifier":
-        final_activation = "softmax"
-        neurons_final = n_classes
-    elif header == "regressor":
-        final_activation = "linear"
-        neurons_final = 1
-    else:
-        return None
-
-    inputs = Input(
-        shape=(
-            model_signal_length,
-            n_channels,
-        )
-    )
-    x = ZeroPadding1D(3)(inputs)
-    x = Conv1D(64, kernel_size=7, strides=2, padding="valid", use_bias=False)(x)
-    x = BatchNormalization()(x)
-    x = ZeroPadding1D(1)(x)
-    x_common = MaxPooling1D(pool_size=3, strides=2, padding="valid")(x)
-
-    # Block 1
-    x1 = residual_block1D(x_common, 64, kernel_size=7, connection="projection")
-    x1 = residual_block1D(x1, 128, kernel_size=7, connection="projection")
-    x1 = residual_block1D(x1, 256, kernel_size=7, connection="projection")
-    x1 = GlobalAveragePooling1D()(x1)
-
-    # Block 2
-    x2 = residual_block1D(x_common, 64, kernel_size=5, connection="projection")
-    x2 = residual_block1D(x2, 128, kernel_size=5, connection="projection")
-    x2 = residual_block1D(x2, 256, kernel_size=5, connection="projection")
-    x2 = GlobalAveragePooling1D()(x2)
-
-    # Block 3
-    x3 = residual_block1D(x_common, 64, kernel_size=3, connection="projection")
-    x3 = residual_block1D(x3, 128, kernel_size=3, connection="projection")
-    x3 = residual_block1D(x3, 256, kernel_size=3, connection="projection")
-    x3 = GlobalAveragePooling1D()(x3)
-
-    x_combined = Concatenate()([x1, x2, x3])
-    x_combined = Flatten()(x_combined)
-
-    if dense_collection > 0:
-        x_combined = Dense(dense_collection)(x_combined)
-    if dropout_rate > 0:
-        x_combined = Dropout(dropout_rate)(x_combined)
-
-    x_combined = Dense(neurons_final, activation=final_activation, name=header)(
-        x_combined
-    )
-    model = Model(inputs, x_combined, name=header)
-
-    return model
-
-
-def ResNetModelCurrent(
-    n_channels,
-    n_slices,
-    depth=2,
-    use_pooling=True,
-    n_classes=3,
-    dropout_rate=0.1,
-    dense_collection=512,
-    header="classifier",
-    model_signal_length=128,
-):
-    """
-    Creates a ResNet-based model tailored for signal classification or regression tasks.
-
-    This function constructs a 1D ResNet architecture with specified parameters. The model can be configured
-    for either classification or regression tasks, determined by the `header` parameter. It consists of
-    configurable ResNet blocks, global average pooling, optional dense layers, and dropout for regularization.
-
-    Parameters
-    ----------
-    n_channels : int
-        The number of channels in the input signal.
-    n_slices : int
-        The number of residual blocks per depth level.
-    depth : int, optional
-        The depth of the network, i.e., how many times the number of filters is doubled. Default is 2.
-    use_pooling : bool, optional
-        Whether to use MaxPooling between ResNet blocks. Default is True.
-    n_classes : int, optional
-        The number of classes for the classification task. Ignored for regression. Default is 3.
-    dropout_rate : float, optional
-        The dropout rate for regularization. Default is 0.1.
-    dense_collection : int, optional
-        The number of neurons in the dense layer following global pooling. If 0, the dense layer is omitted. Default is 512.
-    header : str, optional
-        Specifies the task type: "classifier" for classification or "regressor" for regression. Default is "classifier".
-    model_signal_length : int, optional
-        The length of the input signal. Default is 128.
-
-    Returns
-    -------
-    keras.Model
-        The constructed Keras model ready for training or inference.
-
-    Notes
-    -----
-    - The model uses Conv1D layers for signal processing and applies global average pooling before the final classification
-      or regression layer.
-    - The choice of `final_activation` and `neurons_final` depends on the task: "softmax" and `n_classes` for classification,
-      and "linear" and 1 for regression.
-    - This function relies on a custom `residual_block1D` function for constructing ResNet blocks.
-
-    Examples
-    --------
-    >>> model = ResNetModelCurrent(n_channels=1, n_slices=2, depth=2, use_pooling=True, n_classes=3, dropout_rate=0.1, dense_collection=512, header="classifier", model_signal_length=128)
-    # Creates a ResNet model configured for classification with 3 classes.
-
-    """
-
-    if header == "classifier":
-        final_activation = "softmax"
-        neurons_final = n_classes
-    elif header == "regressor":
-        final_activation = "linear"
-        neurons_final = 1
-    else:
-        return None
-
-    inputs = Input(
-        shape=(
-            model_signal_length,
-            n_channels,
-        )
-    )
-    x2 = Conv1D(64, kernel_size=1, strides=1, padding="same")(inputs)
-
-    n_filters = 64
-    for k in range(depth):
-        for i in range(n_slices):
-            x2 = residual_block1D(x2, n_filters, kernel_size=8)
-        n_filters *= 2
-        if use_pooling and k != (depth - 1):
-            x2 = MaxPooling1D()(x2)
-
-    x2 = GlobalAveragePooling1D()(x2)
-    if dense_collection > 0:
-        x2 = Dense(dense_collection)(x2)
-    if dropout_rate > 0:
-        x2 = Dropout(dropout_rate)(x2)
-
-    x2 = Dense(neurons_final, activation=final_activation, name=header)(x2)
-    model = Model(inputs, x2, name=header)
-
-    return model
-
-
  def train_signal_model(config):
      """
      Initiates the training of a signal detection model using a specified configuration file.