celldetective 1.5.0b7-py3-none-any.whl → 1.5.0b8-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. The information is provided for informational purposes only.
- celldetective/_version.py +1 -1
- celldetective/event_detection_models.py +2463 -0
- celldetective/gui/base/channel_norm_generator.py +19 -3
- celldetective/gui/base/figure_canvas.py +1 -1
- celldetective/gui/base_annotator.py +2 -5
- celldetective/gui/event_annotator.py +248 -138
- celldetective/gui/pair_event_annotator.py +146 -20
- celldetective/gui/process_block.py +2 -2
- celldetective/gui/seg_model_loader.py +4 -4
- celldetective/gui/settings/_settings_event_model_training.py +32 -14
- celldetective/gui/settings/_settings_segmentation_model_training.py +5 -5
- celldetective/gui/settings/_settings_signal_annotator.py +0 -19
- celldetective/gui/viewers/base_viewer.py +17 -20
- celldetective/processes/train_signal_model.py +1 -1
- celldetective/scripts/train_signal_model.py +1 -1
- celldetective/signals.py +4 -2426
- celldetective/utils/event_detection/__init__.py +1 -1
- {celldetective-1.5.0b7.dist-info → celldetective-1.5.0b8.dist-info}/METADATA +1 -1
- {celldetective-1.5.0b7.dist-info → celldetective-1.5.0b8.dist-info}/RECORD +24 -23
- tests/test_signals.py +4 -4
- {celldetective-1.5.0b7.dist-info → celldetective-1.5.0b8.dist-info}/WHEEL +0 -0
- {celldetective-1.5.0b7.dist-info → celldetective-1.5.0b8.dist-info}/entry_points.txt +0 -0
- {celldetective-1.5.0b7.dist-info → celldetective-1.5.0b8.dist-info}/licenses/LICENSE +0 -0
- {celldetective-1.5.0b7.dist-info → celldetective-1.5.0b8.dist-info}/top_level.txt +0 -0
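
Taken together, the counts above describe one large refactor: celldetective/signals.py loses roughly 2,400 lines while the new celldetective/event_detection_models.py gains roughly the same amount, i.e. the SignalDetectionModel class (and its TimeHistory callback) moved to a dedicated module. For downstream code, the migration sketched below should apply; the new path is taken from the lazy imports added in the diff, but since release notes are not part of this view, treat the example as an assumption rather than documented API.

# Old location (pre-1.5.0b8): the class was defined in celldetective/signals.py.
# from celldetective.signals import SignalDetectionModel

# New location, matching the `from celldetective.event_detection_models
# import SignalDetectionModel` lines added inside signals.py below:
from celldetective.event_detection_models import SignalDetectionModel

# Constructor arguments as shown in the removed class definition:
model = SignalDetectionModel(pretrained=None, n_channels=1, model_signal_length=128)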
celldetective/signals.py
CHANGED
|
@@ -3,63 +3,17 @@ import os
|
|
|
3
3
|
import subprocess
|
|
4
4
|
import json
|
|
5
5
|
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
EarlyStopping,
|
|
9
|
-
ModelCheckpoint,
|
|
10
|
-
TensorBoard,
|
|
11
|
-
ReduceLROnPlateau,
|
|
12
|
-
CSVLogger,
|
|
13
|
-
)
|
|
14
|
-
from tensorflow.keras.losses import (
|
|
15
|
-
CategoricalCrossentropy,
|
|
16
|
-
MeanSquaredError,
|
|
17
|
-
MeanAbsoluteError,
|
|
18
|
-
)
|
|
19
|
-
from tensorflow.keras.metrics import Precision, Recall, MeanIoU
|
|
20
|
-
from tensorflow.keras.models import load_model, clone_model
|
|
21
|
-
from tensorflow.config.experimental import list_physical_devices, set_memory_growth
|
|
22
|
-
from tensorflow.keras.utils import to_categorical
|
|
23
|
-
from tensorflow.keras import Input, Model
|
|
24
|
-
from tensorflow.keras.layers import (
|
|
25
|
-
Conv1D,
|
|
26
|
-
BatchNormalization,
|
|
27
|
-
Dense,
|
|
28
|
-
Activation,
|
|
29
|
-
Add,
|
|
30
|
-
MaxPooling1D,
|
|
31
|
-
Dropout,
|
|
32
|
-
GlobalAveragePooling1D,
|
|
33
|
-
Concatenate,
|
|
34
|
-
ZeroPadding1D,
|
|
35
|
-
Flatten,
|
|
36
|
-
)
|
|
37
|
-
from tensorflow.keras.callbacks import Callback
|
|
38
|
-
from sklearn.metrics import confusion_matrix, classification_report
|
|
39
|
-
from sklearn.metrics import (
|
|
40
|
-
jaccard_score,
|
|
41
|
-
balanced_accuracy_score,
|
|
42
|
-
precision_score,
|
|
43
|
-
recall_score,
|
|
44
|
-
)
|
|
45
|
-
from scipy.interpolate import interp1d
|
|
46
|
-
from scipy.ndimage import shift
|
|
47
|
-
from sklearn.metrics import ConfusionMatrixDisplay
|
|
6
|
+
# TensorFlow imports are lazy-loaded in functions that need them to avoid
|
|
7
|
+
# slow import times for modules that don't require TensorFlow.
|
|
48
8
|
|
|
49
9
|
from celldetective.utils.model_loaders import locate_signal_model
|
|
50
10
|
from celldetective.utils.data_loaders import get_position_table, get_position_pickle
|
|
51
11
|
from celldetective.tracking import clean_trajectories, interpolate_nan_properties
|
|
52
|
-
from celldetective.utils.dataset_helpers import compute_weights, train_test_split
|
|
53
|
-
from celldetective.utils.plots.regression import regression_plot
|
|
54
12
|
import matplotlib.pyplot as plt
|
|
55
13
|
from natsort import natsorted
|
|
56
|
-
from glob import glob
|
|
57
|
-
import random
|
|
58
14
|
from celldetective.utils.color_mappings import color_from_status, color_from_class
|
|
59
15
|
from math import floor
|
|
60
16
|
from scipy.optimize import curve_fit
|
|
61
|
-
import time
|
|
62
|
-
import math
|
|
63
17
|
import pandas as pd
|
|
64
18
|
from pandas.api.types import is_numeric_dtype
|
|
65
19
|
from scipy.stats import median_abs_deviation
|
|
@@ -69,59 +23,6 @@ abs_path = os.sep.join(
|
|
|
69
23
|
)
|
|
70
24
|
|
|
71
25
|
|
|
72
|
-
class TimeHistory(Callback):
|
|
73
|
-
"""
|
|
74
|
-
A custom Keras callback to log the duration of each epoch during training.
|
|
75
|
-
|
|
76
|
-
This callback records the time taken for each epoch during the model training process, allowing for
|
|
77
|
-
monitoring of training efficiency and performance over time. The times are stored in a list, with each
|
|
78
|
-
element representing the duration of an epoch in seconds.
|
|
79
|
-
|
|
80
|
-
Attributes
|
|
81
|
-
----------
|
|
82
|
-
times : list
|
|
83
|
-
A list of times (in seconds) taken for each epoch during the training. This list is populated as the
|
|
84
|
-
training progresses.
|
|
85
|
-
|
|
86
|
-
Methods
|
|
87
|
-
-------
|
|
88
|
-
on_train_begin(logs={})
|
|
89
|
-
Initializes the list of times at the beginning of training.
|
|
90
|
-
|
|
91
|
-
on_epoch_begin(epoch, logs={})
|
|
92
|
-
Records the start time of the current epoch.
|
|
93
|
-
|
|
94
|
-
on_epoch_end(epoch, logs={})
|
|
95
|
-
Calculates and appends the duration of the current epoch to the `times` list.
|
|
96
|
-
|
|
97
|
-
Notes
|
|
98
|
-
-----
|
|
99
|
-
- This callback is intended to be used with the `fit` method of Keras models.
|
|
100
|
-
- The time measurements are made using the `time.time()` function, which provides wall-clock time.
|
|
101
|
-
|
|
102
|
-
Examples
|
|
103
|
-
--------
|
|
104
|
-
>>> from keras.models import Sequential
|
|
105
|
-
>>> from keras.layers import Dense
|
|
106
|
-
>>> model = Sequential([Dense(10, activation='relu', input_shape=(20,)), Dense(1)])
|
|
107
|
-
>>> time_callback = TimeHistory()
|
|
108
|
-
>>> model.compile(optimizer='adam', loss='mean_squared_error')
|
|
109
|
-
>>> model.fit(x_train, y_train, epochs=10, callbacks=[time_callback])
|
|
110
|
-
>>> print(time_callback.times)
|
|
111
|
-
# This will print the time taken for each epoch during the training.
|
|
112
|
-
|
|
113
|
-
"""
|
|
114
|
-
|
|
115
|
-
def on_train_begin(self, logs={}):
|
|
116
|
-
self.times = []
|
|
117
|
-
|
|
118
|
-
def on_epoch_begin(self, epoch, logs={}):
|
|
119
|
-
self.epoch_time_start = time.time()
|
|
120
|
-
|
|
121
|
-
def on_epoch_end(self, epoch, logs={}):
|
|
122
|
-
self.times.append(time.time() - self.epoch_time_start)
|
|
123
|
-
|
|
124
|
-
|
|
125
26
|
def analyze_signals(
|
|
126
27
|
trajectories,
|
|
127
28
|
model,
|
|
@@ -180,6 +81,7 @@ def analyze_signals(
|
|
|
180
81
|
- Signal selection and preprocessing are based on the requirements specified in the model's configuration.
|
|
181
82
|
|
|
182
83
|
"""
|
|
84
|
+
from celldetective.event_detection_models import SignalDetectionModel
|
|
183
85
|
|
|
184
86
|
model_path = locate_signal_model(model, path=model_path)
|
|
185
87
|
complete_path = model_path # +model
|
|
@@ -476,6 +378,7 @@ def analyze_pair_signals(
|
|
|
476
378
|
"y": "POSITION_Y",
|
|
477
379
|
},
|
|
478
380
|
):
|
|
381
|
+
from celldetective.event_detection_models import SignalDetectionModel
|
|
479
382
|
|
|
480
383
|
model_path = locate_signal_model(model, path=model_path, pairs=True)
|
|
481
384
|
print(f"Looking for model in {model_path}...")
|
|
@@ -681,2331 +584,6 @@ def analyze_pair_signals(
|
|
|
681
584
|
return trajectories_pairs
|
|
682
585
|
|
|
683
586
|
|
|
684
|
-
class SignalDetectionModel(object):
|
|
685
|
-
"""
|
|
686
|
-
A class for creating and managing signal detection models for analyzing biological signals.
|
|
687
|
-
|
|
688
|
-
This class provides functionalities to load a pretrained signal detection model or create one from scratch,
|
|
689
|
-
preprocess input signals, train the model, and make predictions on new data.
|
|
690
|
-
|
|
691
|
-
Parameters
|
|
692
|
-
----------
|
|
693
|
-
path : str, optional
|
|
694
|
-
Path to the directory containing the model and its configuration. This is used when loading a pretrained model.
|
|
695
|
-
pretrained : str, optional
|
|
696
|
-
Path to the pretrained model to load. If specified, the model and its configuration are loaded from this path.
|
|
697
|
-
channel_option : list of str, optional
|
|
698
|
-
Specifies the channels to be used for signal analysis. Default is ["live_nuclei_channel"].
|
|
699
|
-
model_signal_length : int, optional
|
|
700
|
-
The length of the input signals that the model expects. Default is 128.
|
|
701
|
-
n_channels : int, optional
|
|
702
|
-
The number of channels in the input signals. Default is 1.
|
|
703
|
-
n_conv : int, optional
|
|
704
|
-
The number of convolutional layers in the model. Default is 2.
|
|
705
|
-
n_classes : int, optional
|
|
706
|
-
The number of classes for the classification task. Default is 3.
|
|
707
|
-
dense_collection : int, optional
|
|
708
|
-
The number of units in the dense layer of the model. Default is 512.
|
|
709
|
-
dropout_rate : float, optional
|
|
710
|
-
The dropout rate applied to the dense layer of the model. Default is 0.1.
|
|
711
|
-
label : str, optional
|
|
712
|
-
A label for the model, used in naming and organizing outputs. Default is ''.
|
|
713
|
-
|
|
714
|
-
Attributes
|
|
715
|
-
----------
|
|
716
|
-
model_class : keras Model
|
|
717
|
-
The classification model for predicting the class of signals.
|
|
718
|
-
model_reg : keras Model
|
|
719
|
-
The regression model for predicting the time of interest for signals.
|
|
720
|
-
|
|
721
|
-
Methods
|
|
722
|
-
-------
|
|
723
|
-
load_pretrained_model()
|
|
724
|
-
Loads the model and its configuration from the pretrained path.
|
|
725
|
-
create_models_from_scratch()
|
|
726
|
-
Creates new models for classification and regression from scratch.
|
|
727
|
-
prep_gpu()
|
|
728
|
-
Prepares GPU devices for training, if available.
|
|
729
|
-
fit_from_directory(ds_folders, ...)
|
|
730
|
-
Trains the model using data from specified directories.
|
|
731
|
-
fit(x_train, y_time_train, y_class_train, ...)
|
|
732
|
-
Trains the model using provided datasets.
|
|
733
|
-
predict_class(x, ...)
|
|
734
|
-
Predicts the class of input signals.
|
|
735
|
-
predict_time_of_interest(x, ...)
|
|
736
|
-
Predicts the time of interest for input signals.
|
|
737
|
-
plot_model_history(mode)
|
|
738
|
-
Plots the training history for the specified mode (classifier or regressor).
|
|
739
|
-
evaluate_regression_model()
|
|
740
|
-
Evaluates the regression model on test and validation data.
|
|
741
|
-
gather_callbacks(mode)
|
|
742
|
-
Gathers and prepares callbacks for training based on the specified mode.
|
|
743
|
-
generate_sets()
|
|
744
|
-
Generates training, validation, and test sets from loaded data.
|
|
745
|
-
augment_training_set()
|
|
746
|
-
Augments the training set with additional generated data.
|
|
747
|
-
load_and_normalize(subset)
|
|
748
|
-
Loads and normalizes signals from a subset of data.
|
|
749
|
-
|
|
750
|
-
Notes
|
|
751
|
-
-----
|
|
752
|
-
- This class is designed to work with biological signal data, such as time series from microscopy imaging.
|
|
753
|
-
- The model architecture and training configurations can be customized through the class parameters and methods.
|
|
754
|
-
|
|
755
|
-
"""
|
|
756
|
-
|
|
757
|
-
def __init__(
|
|
758
|
-
self,
|
|
759
|
-
path=None,
|
|
760
|
-
pretrained=None,
|
|
761
|
-
channel_option=["live_nuclei_channel"],
|
|
762
|
-
model_signal_length=128,
|
|
763
|
-
n_channels=1,
|
|
764
|
-
n_conv=2,
|
|
765
|
-
n_classes=3,
|
|
766
|
-
dense_collection=512,
|
|
767
|
-
dropout_rate=0.1,
|
|
768
|
-
label="",
|
|
769
|
-
):
|
|
770
|
-
|
|
771
|
-
self.prep_gpu()
|
|
772
|
-
|
|
773
|
-
self.model_signal_length = model_signal_length
|
|
774
|
-
self.channel_option = channel_option
|
|
775
|
-
self.pretrained = pretrained
|
|
776
|
-
self.n_channels = n_channels
|
|
777
|
-
self.n_conv = n_conv
|
|
778
|
-
self.n_classes = n_classes
|
|
779
|
-
self.dense_collection = dense_collection
|
|
780
|
-
self.dropout_rate = dropout_rate
|
|
781
|
-
self.label = label
|
|
782
|
-
self.show_plots = True
|
|
783
|
-
|
|
784
|
-
if self.pretrained is not None:
|
|
785
|
-
print(f"Load pretrained models from {pretrained}...")
|
|
786
|
-
test = self.load_pretrained_model()
|
|
787
|
-
if test is None:
|
|
788
|
-
self.pretrained = None
|
|
789
|
-
print(
|
|
790
|
-
"Pretrained model could not be loaded. Check the log for error. Abort..."
|
|
791
|
-
)
|
|
792
|
-
return None
|
|
793
|
-
else:
|
|
794
|
-
print("Create models from scratch...")
|
|
795
|
-
self.create_models_from_scratch()
|
|
796
|
-
print("Models successfully created.")
|
|
797
|
-
|
|
798
|
-
def load_pretrained_model(self):
|
|
799
|
-
"""
|
|
800
|
-
Loads a pretrained model and its configuration from the specified path.
|
|
801
|
-
|
|
802
|
-
This method attempts to load both the classification and regression models from the path specified during the
|
|
803
|
-
class instantiation. It also loads the model configuration from a JSON file and updates the model attributes
|
|
804
|
-
accordingly. If the models cannot be loaded, an error message is printed.
|
|
805
|
-
|
|
806
|
-
Raises
|
|
807
|
-
------
|
|
808
|
-
Exception
|
|
809
|
-
If there is an error loading the model or the configuration file, an exception is raised with details.
|
|
810
|
-
|
|
811
|
-
Notes
|
|
812
|
-
-----
|
|
813
|
-
- The models are expected to be saved in .h5 format with the filenames "classifier.h5" and "regressor.h5".
|
|
814
|
-
- The configuration file is expected to be named "config_input.json" and located in the same directory as the models.
|
|
815
|
-
|
|
816
|
-
"""
|
|
817
|
-
|
|
818
|
-
if self.pretrained.endswith(os.sep):
|
|
819
|
-
self.pretrained = os.sep.join(self.pretrained.split(os.sep)[:-1])
|
|
820
|
-
|
|
821
|
-
try:
|
|
822
|
-
self.model_class = load_model(
|
|
823
|
-
os.sep.join([self.pretrained, "classifier.h5"]),
|
|
824
|
-
compile=False,
|
|
825
|
-
custom_objects={"mse": MeanSquaredError()},
|
|
826
|
-
)
|
|
827
|
-
self.model_class.load_weights(
|
|
828
|
-
os.sep.join([self.pretrained, "classifier.h5"])
|
|
829
|
-
)
|
|
830
|
-
self.model_class = self.freeze_encoder(self.model_class, 5)
|
|
831
|
-
print("Classifier successfully loaded...")
|
|
832
|
-
except Exception as e:
|
|
833
|
-
print(f"Error {e}...")
|
|
834
|
-
self.model_class = None
|
|
835
|
-
try:
|
|
836
|
-
self.model_reg = load_model(
|
|
837
|
-
os.sep.join([self.pretrained, "regressor.h5"]),
|
|
838
|
-
compile=False,
|
|
839
|
-
custom_objects={"mse": MeanSquaredError()},
|
|
840
|
-
)
|
|
841
|
-
self.model_reg.load_weights(os.sep.join([self.pretrained, "regressor.h5"]))
|
|
842
|
-
self.model_reg = self.freeze_encoder(self.model_reg, 5)
|
|
843
|
-
print("Regressor successfully loaded...")
|
|
844
|
-
except Exception as e:
|
|
845
|
-
print(f"Error {e}...")
|
|
846
|
-
self.model_reg = None
|
|
847
|
-
|
|
848
|
-
if self.model_class is None and self.model_reg is None:
|
|
849
|
-
return None
|
|
850
|
-
|
|
851
|
-
# load config
|
|
852
|
-
with open(os.sep.join([self.pretrained, "config_input.json"])) as config_file:
|
|
853
|
-
model_config = json.load(config_file)
|
|
854
|
-
self.config = model_config
|
|
855
|
-
|
|
856
|
-
req_channels = model_config["channels"]
|
|
857
|
-
print(f"Required channels read from pretrained model: {req_channels}")
|
|
858
|
-
self.channel_option = req_channels
|
|
859
|
-
if "normalize" in model_config:
|
|
860
|
-
self.normalize = model_config["normalize"]
|
|
861
|
-
if "normalization_percentile" in model_config:
|
|
862
|
-
self.normalization_percentile = model_config["normalization_percentile"]
|
|
863
|
-
if "normalization_values" in model_config:
|
|
864
|
-
self.normalization_values = model_config["normalization_values"]
|
|
865
|
-
if "normalization_percentile" in model_config:
|
|
866
|
-
self.normalization_clip = model_config["normalization_clip"]
|
|
867
|
-
if "label" in model_config:
|
|
868
|
-
self.label = model_config["label"]
|
|
869
|
-
|
|
870
|
-
try:
|
|
871
|
-
self.n_channels = self.model_class.layers[0].input_shape[0][-1]
|
|
872
|
-
self.model_signal_length = self.model_class.layers[0].input_shape[0][-2]
|
|
873
|
-
self.n_classes = self.model_class.layers[-1].output_shape[-1]
|
|
874
|
-
model_class_input_shape = self.model_class.layers[0].input_shape[0]
|
|
875
|
-
model_reg_input_shape = self.model_reg.layers[0].input_shape[0]
|
|
876
|
-
except AttributeError:
|
|
877
|
-
self.n_channels = self.model_class.input_shape[
|
|
878
|
-
-1
|
|
879
|
-
] # self.model_class.layers[0].input.shape[0][-1]
|
|
880
|
-
self.model_signal_length = self.model_class.input_shape[
|
|
881
|
-
-2
|
|
882
|
-
] # self.model_class.layers[0].input[0].shape[0][-2]
|
|
883
|
-
self.n_classes = self.model_class.output_shape[
|
|
884
|
-
-1
|
|
885
|
-
] # self.model_class.layers[-1].output[0].shape[-1]
|
|
886
|
-
model_class_input_shape = self.model_class.input_shape
|
|
887
|
-
model_reg_input_shape = self.model_reg.input_shape
|
|
888
|
-
except Exception as e:
|
|
889
|
-
print(e)
|
|
890
|
-
|
|
891
|
-
assert (
|
|
892
|
-
model_class_input_shape == model_reg_input_shape
|
|
893
|
-
), f"mismatch between input shape of classification: {self.model_class.layers[0].input_shape[0]} and regression {self.model_reg.layers[0].input_shape[0]} models... Error."
|
|
894
|
-
|
|
895
|
-
return True
|
|
896
|
-
|
|
897
|
-
def freeze_encoder(self, model, n_trainable_layers: int = 3):
|
|
898
|
-
for layer in model.layers[
|
|
899
|
-
: -min(n_trainable_layers, len(model.layers))
|
|
900
|
-
]: # freeze everything except final Dense layer
|
|
901
|
-
layer.trainable = False
|
|
902
|
-
return model
|
|
903
|
-
|
|
904
|
-
def create_models_from_scratch(self):
|
|
905
|
-
"""
|
|
906
|
-
Initializes new models for classification and regression based on the specified parameters.
|
|
907
|
-
|
|
908
|
-
This method creates new ResNet models for both classification and regression tasks using the parameters specified
|
|
909
|
-
during class instantiation. The models are configured but not compiled or trained.
|
|
910
|
-
|
|
911
|
-
Notes
|
|
912
|
-
-----
|
|
913
|
-
- The models are created using a custom ResNet architecture defined elsewhere in the codebase.
|
|
914
|
-
- The models are stored in the `model_class` and `model_reg` attributes of the class.
|
|
915
|
-
|
|
916
|
-
"""
|
|
917
|
-
|
|
918
|
-
self.model_class = ResNetModelCurrent(
|
|
919
|
-
n_channels=self.n_channels,
|
|
920
|
-
n_slices=self.n_conv,
|
|
921
|
-
n_classes=3,
|
|
922
|
-
dense_collection=self.dense_collection,
|
|
923
|
-
dropout_rate=self.dropout_rate,
|
|
924
|
-
header="classifier",
|
|
925
|
-
model_signal_length=self.model_signal_length,
|
|
926
|
-
)
|
|
927
|
-
|
|
928
|
-
self.model_reg = ResNetModelCurrent(
|
|
929
|
-
n_channels=self.n_channels,
|
|
930
|
-
n_slices=self.n_conv,
|
|
931
|
-
n_classes=self.n_classes,
|
|
932
|
-
dense_collection=self.dense_collection,
|
|
933
|
-
dropout_rate=self.dropout_rate,
|
|
934
|
-
header="regressor",
|
|
935
|
-
model_signal_length=self.model_signal_length,
|
|
936
|
-
)
|
|
937
|
-
|
|
938
|
-
def prep_gpu(self):
|
|
939
|
-
"""
|
|
940
|
-
Prepares GPU devices for training by enabling memory growth.
|
|
941
|
-
|
|
942
|
-
This method attempts to identify available GPU devices and configures TensorFlow to allow memory growth on each
|
|
943
|
-
GPU. This prevents TensorFlow from allocating the total available memory on the GPU device upfront.
|
|
944
|
-
|
|
945
|
-
Notes
|
|
946
|
-
-----
|
|
947
|
-
- This method should be called before any TensorFlow/Keras operations that might allocate GPU memory.
|
|
948
|
-
- If no GPUs are detected, the method will pass silently.
|
|
949
|
-
|
|
950
|
-
"""
|
|
951
|
-
|
|
952
|
-
try:
|
|
953
|
-
physical_devices = list_physical_devices("GPU")
|
|
954
|
-
for gpu in physical_devices:
|
|
955
|
-
set_memory_growth(gpu, True)
|
|
956
|
-
except:
|
|
957
|
-
pass
|
|
958
|
-
|
|
959
|
-
def fit_from_directory(
|
|
960
|
-
self,
|
|
961
|
-
datasets,
|
|
962
|
-
normalize=True,
|
|
963
|
-
normalization_percentile=None,
|
|
964
|
-
normalization_values=None,
|
|
965
|
-
normalization_clip=None,
|
|
966
|
-
channel_option=["live_nuclei_channel"],
|
|
967
|
-
model_name=None,
|
|
968
|
-
target_directory=None,
|
|
969
|
-
augment=True,
|
|
970
|
-
augmentation_factor=2,
|
|
971
|
-
validation_split=0.20,
|
|
972
|
-
test_split=0.0,
|
|
973
|
-
batch_size=64,
|
|
974
|
-
epochs=300,
|
|
975
|
-
recompile_pretrained=False,
|
|
976
|
-
learning_rate=0.01,
|
|
977
|
-
loss_reg="mse",
|
|
978
|
-
loss_class=CategoricalCrossentropy(from_logits=False),
|
|
979
|
-
show_plots=True,
|
|
980
|
-
callbacks=None,
|
|
981
|
-
):
|
|
982
|
-
"""
|
|
983
|
-
Trains the model using data from specified directories.
|
|
984
|
-
|
|
985
|
-
This method prepares the dataset for training by loading and preprocessing data from specified directories,
|
|
986
|
-
then trains the classification and regression models.
|
|
987
|
-
|
|
988
|
-
Parameters
|
|
989
|
-
----------
|
|
990
|
-
ds_folders : list of str
|
|
991
|
-
List of directories containing the dataset files for training.
|
|
992
|
-
callbacks : list, optional
|
|
993
|
-
List of Keras callbacks to apply during training.
|
|
994
|
-
normalize : bool, optional
|
|
995
|
-
Whether to normalize the input signals (default is True).
|
|
996
|
-
normalization_percentile : list or None, optional
|
|
997
|
-
Percentiles for signal normalization (default is None).
|
|
998
|
-
normalization_values : list or None, optional
|
|
999
|
-
Specific values for signal normalization (default is None).
|
|
1000
|
-
normalization_clip : bool, optional
|
|
1001
|
-
Whether to clip the normalized signals (default is None).
|
|
1002
|
-
channel_option : list of str, optional
|
|
1003
|
-
Specifies the channels to be used for signal analysis (default is ["live_nuclei_channel"]).
|
|
1004
|
-
model_name : str, optional
|
|
1005
|
-
Name of the model for saving purposes (default is None).
|
|
1006
|
-
target_directory : str, optional
|
|
1007
|
-
Directory where the trained model and outputs will be saved (default is None).
|
|
1008
|
-
augment : bool, optional
|
|
1009
|
-
Whether to augment the training data (default is True).
|
|
1010
|
-
augmentation_factor : int, optional
|
|
1011
|
-
Factor by which to augment the training data (default is 2).
|
|
1012
|
-
validation_split : float, optional
|
|
1013
|
-
Fraction of the data to be used as validation set (default is 0.20).
|
|
1014
|
-
test_split : float, optional
|
|
1015
|
-
Fraction of the data to be used as test set (default is 0.0).
|
|
1016
|
-
batch_size : int, optional
|
|
1017
|
-
Batch size for training (default is 64).
|
|
1018
|
-
epochs : int, optional
|
|
1019
|
-
Number of epochs to train for (default is 300).
|
|
1020
|
-
recompile_pretrained : bool, optional
|
|
1021
|
-
Whether to recompile a pretrained model (default is False).
|
|
1022
|
-
learning_rate : float, optional
|
|
1023
|
-
Learning rate for the optimizer (default is 0.01).
|
|
1024
|
-
loss_reg : str or keras.losses.Loss, optional
|
|
1025
|
-
Loss function for the regression model (default is "mse").
|
|
1026
|
-
loss_class : str or keras.losses.Loss, optional
|
|
1027
|
-
Loss function for the classification model (default is CategoricalCrossentropy(from_logits=False)).
|
|
1028
|
-
|
|
1029
|
-
Notes
|
|
1030
|
-
-----
|
|
1031
|
-
- The method automatically splits the dataset into training, validation, and test sets according to the specified splits.
|
|
1032
|
-
|
|
1033
|
-
"""
|
|
1034
|
-
|
|
1035
|
-
if not hasattr(self, "normalization_percentile"):
|
|
1036
|
-
self.normalization_percentile = normalization_percentile
|
|
1037
|
-
if not hasattr(self, "normalization_values"):
|
|
1038
|
-
self.normalization_values = normalization_values
|
|
1039
|
-
if not hasattr(self, "normalization_clip"):
|
|
1040
|
-
self.normalization_clip = normalization_clip
|
|
1041
|
-
|
|
1042
|
-
self.callbacks = callbacks
|
|
1043
|
-
self.normalize = normalize
|
|
1044
|
-
(
|
|
1045
|
-
self.normalization_percentile,
|
|
1046
|
-
self.normalization_values,
|
|
1047
|
-
self.normalization_clip,
|
|
1048
|
-
) = _interpret_normalization_parameters(
|
|
1049
|
-
self.n_channels,
|
|
1050
|
-
self.normalization_percentile,
|
|
1051
|
-
self.normalization_values,
|
|
1052
|
-
self.normalization_clip,
|
|
1053
|
-
)
|
|
1054
|
-
|
|
1055
|
-
self.datasets = [rf"{d}" if isinstance(d, str) else d for d in datasets]
|
|
1056
|
-
self.batch_size = batch_size
|
|
1057
|
-
self.epochs = epochs
|
|
1058
|
-
self.validation_split = validation_split
|
|
1059
|
-
self.test_split = test_split
|
|
1060
|
-
self.augment = augment
|
|
1061
|
-
self.augmentation_factor = augmentation_factor
|
|
1062
|
-
self.model_name = rf"{model_name}"
|
|
1063
|
-
self.target_directory = rf"{target_directory}"
|
|
1064
|
-
self.model_folder = os.sep.join([self.target_directory, self.model_name])
|
|
1065
|
-
self.recompile_pretrained = recompile_pretrained
|
|
1066
|
-
self.learning_rate = learning_rate
|
|
1067
|
-
self.loss_reg = loss_reg
|
|
1068
|
-
self.loss_class = loss_class
|
|
1069
|
-
self.show_plots = show_plots
|
|
1070
|
-
self.channel_option = channel_option
|
|
1071
|
-
|
|
1072
|
-
assert self.n_channels == len(
|
|
1073
|
-
self.channel_option
|
|
1074
|
-
), f"Mismatch between the channel option and the number of channels of the model..."
|
|
1075
|
-
|
|
1076
|
-
if isinstance(self.datasets[0], dict):
|
|
1077
|
-
self.datasets = [self.datasets]
|
|
1078
|
-
|
|
1079
|
-
self.list_of_sets = []
|
|
1080
|
-
for ds in self.datasets:
|
|
1081
|
-
if isinstance(ds, str):
|
|
1082
|
-
self.list_of_sets.extend(glob(os.sep.join([ds, "*.npy"])))
|
|
1083
|
-
else:
|
|
1084
|
-
self.list_of_sets.append(ds)
|
|
1085
|
-
|
|
1086
|
-
print(f"Found {len(self.list_of_sets)} datasets...")
|
|
1087
|
-
|
|
1088
|
-
self.prepare_sets()
|
|
1089
|
-
self.train_generic()
|
|
1090
|
-
|
|
1091
|
-
def fit(
|
|
1092
|
-
self,
|
|
1093
|
-
x_train,
|
|
1094
|
-
y_time_train,
|
|
1095
|
-
y_class_train,
|
|
1096
|
-
normalize=True,
|
|
1097
|
-
normalization_percentile=None,
|
|
1098
|
-
normalization_values=None,
|
|
1099
|
-
normalization_clip=None,
|
|
1100
|
-
pad=True,
|
|
1101
|
-
validation_data=None,
|
|
1102
|
-
test_data=None,
|
|
1103
|
-
channel_option=["live_nuclei_channel", "dead_nuclei_channel"],
|
|
1104
|
-
model_name=None,
|
|
1105
|
-
target_directory=None,
|
|
1106
|
-
augment=True,
|
|
1107
|
-
augmentation_factor=3,
|
|
1108
|
-
validation_split=0.25,
|
|
1109
|
-
batch_size=64,
|
|
1110
|
-
epochs=300,
|
|
1111
|
-
recompile_pretrained=False,
|
|
1112
|
-
learning_rate=0.001,
|
|
1113
|
-
loss_reg="mse",
|
|
1114
|
-
loss_class=CategoricalCrossentropy(from_logits=False),
|
|
1115
|
-
):
|
|
1116
|
-
"""
|
|
1117
|
-
Trains the model using provided datasets.
|
|
1118
|
-
|
|
1119
|
-
Parameters
|
|
1120
|
-
----------
|
|
1121
|
-
Same as `fit_from_directory`, but instead of loading data from directories, this method accepts preloaded and
|
|
1122
|
-
optionally preprocessed datasets directly.
|
|
1123
|
-
|
|
1124
|
-
Notes
|
|
1125
|
-
-----
|
|
1126
|
-
- This method provides an alternative way to train the model when data is already loaded into memory, offering
|
|
1127
|
-
flexibility for data preprocessing steps outside this class.
|
|
1128
|
-
|
|
1129
|
-
"""
|
|
1130
|
-
|
|
1131
|
-
self.normalize = normalize
|
|
1132
|
-
if not hasattr(self, "normalization_percentile"):
|
|
1133
|
-
self.normalization_percentile = normalization_percentile
|
|
1134
|
-
if not hasattr(self, "normalization_values"):
|
|
1135
|
-
self.normalization_values = normalization_values
|
|
1136
|
-
if not hasattr(self, "normalization_clip"):
|
|
1137
|
-
self.normalization_clip = normalization_clip
|
|
1138
|
-
(
|
|
1139
|
-
self.normalization_percentile,
|
|
1140
|
-
self.normalization_values,
|
|
1141
|
-
self.normalization_clip,
|
|
1142
|
-
) = _interpret_normalization_parameters(
|
|
1143
|
-
self.n_channels,
|
|
1144
|
-
self.normalization_percentile,
|
|
1145
|
-
self.normalization_values,
|
|
1146
|
-
self.normalization_clip,
|
|
1147
|
-
)
|
|
1148
|
-
|
|
1149
|
-
self.x_train = x_train
|
|
1150
|
-
self.y_class_train = y_class_train
|
|
1151
|
-
self.y_time_train = y_time_train
|
|
1152
|
-
self.channel_option = channel_option
|
|
1153
|
-
|
|
1154
|
-
assert self.n_channels == len(
|
|
1155
|
-
self.channel_option
|
|
1156
|
-
), f"Mismatch between the channel option and the number of channels of the model..."
|
|
1157
|
-
|
|
1158
|
-
if pad:
|
|
1159
|
-
self.x_train = pad_to_model_length(self.x_train, self.model_signal_length)
|
|
1160
|
-
|
|
1161
|
-
assert self.x_train.shape[1:] == (
|
|
1162
|
-
self.model_signal_length,
|
|
1163
|
-
self.n_channels,
|
|
1164
|
-
), f"Shape mismatch between the provided training fluorescence signals and the model..."
|
|
1165
|
-
|
|
1166
|
-
# If y-class is not one-hot encoded, encode it
|
|
1167
|
-
if self.y_class_train.shape[-1] != self.n_classes:
|
|
1168
|
-
self.class_weights = compute_weights(
|
|
1169
|
-
y=self.y_class_train,
|
|
1170
|
-
class_weight="balanced",
|
|
1171
|
-
classes=np.unique(self.y_class_train),
|
|
1172
|
-
)
|
|
1173
|
-
self.y_class_train = to_categorical(self.y_class_train, num_classes=3)
|
|
1174
|
-
|
|
1175
|
-
if self.normalize:
|
|
1176
|
-
self.y_time_train = (
|
|
1177
|
-
self.y_time_train.astype(np.float32) / self.model_signal_length
|
|
1178
|
-
)
|
|
1179
|
-
self.x_train = normalize_signal_set(
|
|
1180
|
-
self.x_train,
|
|
1181
|
-
self.channel_option,
|
|
1182
|
-
normalization_percentile=self.normalization_percentile,
|
|
1183
|
-
normalization_values=self.normalization_values,
|
|
1184
|
-
normalization_clip=self.normalization_clip,
|
|
1185
|
-
)
|
|
1186
|
-
|
|
1187
|
-
if validation_data is not None:
|
|
1188
|
-
try:
|
|
1189
|
-
self.x_val = validation_data[0]
|
|
1190
|
-
if pad:
|
|
1191
|
-
self.x_val = pad_to_model_length(
|
|
1192
|
-
self.x_val, self.model_signal_length
|
|
1193
|
-
)
|
|
1194
|
-
self.y_class_val = validation_data[1]
|
|
1195
|
-
if self.y_class_val.shape[-1] != self.n_classes:
|
|
1196
|
-
self.y_class_val = to_categorical(self.y_class_val, num_classes=3)
|
|
1197
|
-
self.y_time_val = validation_data[2]
|
|
1198
|
-
if self.normalize:
|
|
1199
|
-
self.y_time_val = (
|
|
1200
|
-
self.y_time_val.astype(np.float32) / self.model_signal_length
|
|
1201
|
-
)
|
|
1202
|
-
self.x_val = normalize_signal_set(
|
|
1203
|
-
self.x_val,
|
|
1204
|
-
self.channel_option,
|
|
1205
|
-
normalization_percentile=self.normalization_percentile,
|
|
1206
|
-
normalization_values=self.normalization_values,
|
|
1207
|
-
normalization_clip=self.normalization_clip,
|
|
1208
|
-
)
|
|
1209
|
-
|
|
1210
|
-
except Exception as e:
|
|
1211
|
-
print("Could not load validation data, error {e}...")
|
|
1212
|
-
else:
|
|
1213
|
-
self.validation_split = validation_split
|
|
1214
|
-
|
|
1215
|
-
if test_data is not None:
|
|
1216
|
-
try:
|
|
1217
|
-
self.x_test = test_data[0]
|
|
1218
|
-
if pad:
|
|
1219
|
-
self.x_test = pad_to_model_length(
|
|
1220
|
-
self.x_test, self.model_signal_length
|
|
1221
|
-
)
|
|
1222
|
-
self.y_class_test = test_data[1]
|
|
1223
|
-
if self.y_class_test.shape[-1] != self.n_classes:
|
|
1224
|
-
self.y_class_test = to_categorical(self.y_class_test, num_classes=3)
|
|
1225
|
-
self.y_time_test = test_data[2]
|
|
1226
|
-
if self.normalize:
|
|
1227
|
-
self.y_time_test = (
|
|
1228
|
-
self.y_time_test.astype(np.float32) / self.model_signal_length
|
|
1229
|
-
)
|
|
1230
|
-
self.x_test = normalize_signal_set(
|
|
1231
|
-
self.x_test,
|
|
1232
|
-
self.channel_option,
|
|
1233
|
-
normalization_percentile=self.normalization_percentile,
|
|
1234
|
-
normalization_values=self.normalization_values,
|
|
1235
|
-
normalization_clip=self.normalization_clip,
|
|
1236
|
-
)
|
|
1237
|
-
except Exception as e:
|
|
1238
|
-
print("Could not load test data, error {e}...")
|
|
1239
|
-
|
|
1240
|
-
self.batch_size = batch_size
|
|
1241
|
-
self.epochs = epochs
|
|
1242
|
-
self.augment = augment
|
|
1243
|
-
self.augmentation_factor = augmentation_factor
|
|
1244
|
-
if self.augmentation_factor == 1:
|
|
1245
|
-
self.augment = False
|
|
1246
|
-
self.model_name = model_name
|
|
1247
|
-
self.target_directory = target_directory
|
|
1248
|
-
self.model_folder = os.sep.join([self.target_directory, self.model_name])
|
|
1249
|
-
self.recompile_pretrained = recompile_pretrained
|
|
1250
|
-
self.learning_rate = learning_rate
|
|
1251
|
-
self.loss_reg = loss_reg
|
|
1252
|
-
self.loss_class = loss_class
|
|
1253
|
-
|
|
1254
|
-
self.train_generic()
|
|
1255
|
-
|
|
1256
|
-
def train_generic(self):
|
|
1257
|
-
|
|
1258
|
-
if not os.path.exists(self.model_folder):
|
|
1259
|
-
os.mkdir(self.model_folder)
|
|
1260
|
-
|
|
1261
|
-
self.train_classifier()
|
|
1262
|
-
self.train_regressor()
|
|
1263
|
-
|
|
1264
|
-
config_input = {
|
|
1265
|
-
"channels": self.channel_option,
|
|
1266
|
-
"model_signal_length": self.model_signal_length,
|
|
1267
|
-
"label": self.label,
|
|
1268
|
-
"normalize": self.normalize,
|
|
1269
|
-
"normalization_percentile": self.normalization_percentile,
|
|
1270
|
-
"normalization_values": self.normalization_values,
|
|
1271
|
-
"normalization_clip": self.normalization_clip,
|
|
1272
|
-
}
|
|
1273
|
-
json_string = json.dumps(config_input)
|
|
1274
|
-
with open(
|
|
1275
|
-
os.sep.join([self.model_folder, "config_input.json"]), "w"
|
|
1276
|
-
) as outfile:
|
|
1277
|
-
outfile.write(json_string)
|
|
1278
|
-
|
|
1279
|
-
def predict_class(
|
|
1280
|
-
self, x, normalize=True, pad=True, return_one_hot=False, interpolate=True
|
|
1281
|
-
):
|
|
1282
|
-
"""
|
|
1283
|
-
Predicts the class of input signals using the trained classification model.
|
|
1284
|
-
|
|
1285
|
-
Parameters
|
|
1286
|
-
----------
|
|
1287
|
-
x : ndarray
|
|
1288
|
-
The input signals for which to predict classes.
|
|
1289
|
-
normalize : bool, optional
|
|
1290
|
-
Whether to normalize the input signals (default is True).
|
|
1291
|
-
pad : bool, optional
|
|
1292
|
-
Whether to pad the input signals to match the model's expected signal length (default is True).
|
|
1293
|
-
return_one_hot : bool, optional
|
|
1294
|
-
Whether to return predictions in one-hot encoded format (default is False).
|
|
1295
|
-
interpolate : bool, optional
|
|
1296
|
-
Whether to interpolate the input signals (default is True).
|
|
1297
|
-
|
|
1298
|
-
Returns
|
|
1299
|
-
-------
|
|
1300
|
-
ndarray
|
|
1301
|
-
The predicted classes for the input signals. If `return_one_hot` is True, predictions are returned in one-hot
|
|
1302
|
-
encoded format, otherwise as integer labels.
|
|
1303
|
-
|
|
1304
|
-
Notes
|
|
1305
|
-
-----
|
|
1306
|
-
- The method processes the input signals according to the specified options to ensure compatibility with the model's
|
|
1307
|
-
input requirements.
|
|
1308
|
-
|
|
1309
|
-
"""
|
|
1310
|
-
|
|
1311
|
-
self.x = np.copy(x)
|
|
1312
|
-
self.normalize = normalize
|
|
1313
|
-
self.pad = pad
|
|
1314
|
-
self.return_one_hot = return_one_hot
|
|
1315
|
-
# self.max_relevant_time = np.shape(self.x)[1]
|
|
1316
|
-
# print(f'Max relevant time: {self.max_relevant_time}')
|
|
1317
|
-
|
|
1318
|
-
if self.pad:
|
|
1319
|
-
self.x = pad_to_model_length(self.x, self.model_signal_length)
|
|
1320
|
-
|
|
1321
|
-
if self.normalize:
|
|
1322
|
-
self.x = normalize_signal_set(
|
|
1323
|
-
self.x,
|
|
1324
|
-
self.channel_option,
|
|
1325
|
-
normalization_percentile=self.normalization_percentile,
|
|
1326
|
-
normalization_values=self.normalization_values,
|
|
1327
|
-
normalization_clip=self.normalization_clip,
|
|
1328
|
-
)
|
|
1329
|
-
|
|
1330
|
-
# implement auto interpolation here!!
|
|
1331
|
-
# self.x = self.interpolate_signals(self.x)
|
|
1332
|
-
|
|
1333
|
-
# for i in range(5):
|
|
1334
|
-
# plt.plot(self.x[i,:,0])
|
|
1335
|
-
# plt.show()
|
|
1336
|
-
|
|
1337
|
-
try:
|
|
1338
|
-
n_channels = self.model_class.layers[0].input_shape[0][-1]
|
|
1339
|
-
model_signal_length = self.model_class.layers[0].input_shape[0][-2]
|
|
1340
|
-
except AttributeError:
|
|
1341
|
-
n_channels = self.model_class.input_shape[-1]
|
|
1342
|
-
model_signal_length = self.model_class.input_shape[-2]
|
|
1343
|
-
|
|
1344
|
-
assert (
|
|
1345
|
-
self.x.shape[-1] == n_channels
|
|
1346
|
-
), f"Shape mismatch between the input shape and the model input shape..."
|
|
1347
|
-
assert (
|
|
1348
|
-
self.x.shape[-2] == model_signal_length
|
|
1349
|
-
), f"Shape mismatch between the input shape and the model input shape..."
|
|
1350
|
-
|
|
1351
|
-
self.class_predictions_one_hot = self.model_class.predict(self.x)
|
|
1352
|
-
self.class_predictions = self.class_predictions_one_hot.argmax(axis=1)
|
|
1353
|
-
|
|
1354
|
-
if self.return_one_hot:
|
|
1355
|
-
return self.class_predictions_one_hot
|
|
1356
|
-
else:
|
|
1357
|
-
return self.class_predictions
|
|
1358
|
-
|
|
1359
|
-
def predict_time_of_interest(
|
|
1360
|
-
self, x, class_predictions=None, normalize=True, pad=True
|
|
1361
|
-
):
|
|
1362
|
-
"""
|
|
1363
|
-
Predicts the time of interest for input signals using the trained regression model.
|
|
1364
|
-
|
|
1365
|
-
Parameters
|
|
1366
|
-
----------
|
|
1367
|
-
x : ndarray
|
|
1368
|
-
The input signals for which to predict times of interest.
|
|
1369
|
-
class_predictions : ndarray, optional
|
|
1370
|
-
The predicted classes for the input signals. If provided, time of interest predictions are only made for
|
|
1371
|
-
signals predicted to belong to a specific class (default is None).
|
|
1372
|
-
normalize : bool, optional
|
|
1373
|
-
Whether to normalize the input signals (default is True).
|
|
1374
|
-
pad : bool, optional
|
|
1375
|
-
Whether to pad the input signals to match the model's expected signal length (default is True).
|
|
1376
|
-
|
|
1377
|
-
Returns
|
|
1378
|
-
-------
|
|
1379
|
-
ndarray
|
|
1380
|
-
The predicted times of interest for the input signals.
|
|
1381
|
-
|
|
1382
|
-
Notes
|
|
1383
|
-
-----
|
|
1384
|
-
- The method processes the input signals according to the specified options and uses the regression model to
|
|
1385
|
-
predict times at which a particular event of interest occurs.
|
|
1386
|
-
|
|
1387
|
-
"""
|
|
1388
|
-
|
|
1389
|
-
self.x = np.copy(x)
|
|
1390
|
-
self.normalize = normalize
|
|
1391
|
-
self.pad = pad
|
|
1392
|
-
# self.max_relevant_time = np.shape(self.x)[1]
|
|
1393
|
-
# print(f'Max relevant time: {self.max_relevant_time}')
|
|
1394
|
-
|
|
1395
|
-
if class_predictions is not None:
|
|
1396
|
-
self.class_predictions = class_predictions
|
|
1397
|
-
|
|
1398
|
-
if self.pad:
|
|
1399
|
-
self.x = pad_to_model_length(self.x, self.model_signal_length)
|
|
1400
|
-
|
|
1401
|
-
if self.normalize:
|
|
1402
|
-
self.x = normalize_signal_set(
|
|
1403
|
-
self.x,
|
|
1404
|
-
self.channel_option,
|
|
1405
|
-
normalization_percentile=self.normalization_percentile,
|
|
1406
|
-
normalization_values=self.normalization_values,
|
|
1407
|
-
normalization_clip=self.normalization_clip,
|
|
1408
|
-
)
|
|
1409
|
-
|
|
1410
|
-
try:
|
|
1411
|
-
n_channels = self.model_reg.layers[0].input_shape[0][-1]
|
|
1412
|
-
model_signal_length = self.model_reg.layers[0].input_shape[0][-2]
|
|
1413
|
-
except AttributeError:
|
|
1414
|
-
n_channels = self.model_reg.input_shape[-1]
|
|
1415
|
-
model_signal_length = self.model_reg.input_shape[-2]
|
|
1416
|
-
|
|
1417
|
-
assert (
|
|
1418
|
-
self.x.shape[-1] == n_channels
|
|
1419
|
-
), f"Shape mismatch between the input shape and the model input shape..."
|
|
1420
|
-
assert (
|
|
1421
|
-
self.x.shape[-2] == model_signal_length
|
|
1422
|
-
), f"Shape mismatch between the input shape and the model input shape..."
|
|
1423
|
-
|
|
1424
|
-
if np.any(self.class_predictions == 0):
|
|
1425
|
-
self.time_predictions = (
|
|
1426
|
-
self.model_reg.predict(self.x[self.class_predictions == 0])
|
|
1427
|
-
* self.model_signal_length
|
|
1428
|
-
)
|
|
1429
|
-
self.time_predictions = self.time_predictions[:, 0]
|
|
1430
|
-
self.time_predictions_recast = np.zeros(len(self.x)) - 1.0
|
|
1431
|
-
self.time_predictions_recast[self.class_predictions == 0] = (
|
|
1432
|
-
self.time_predictions
|
|
1433
|
-
)
|
|
1434
|
-
else:
|
|
1435
|
-
self.time_predictions_recast = np.zeros(len(self.x)) - 1.0
|
|
1436
|
-
return self.time_predictions_recast
|
|
1437
|
-
|
|
1438
|
-
def interpolate_signals(self, x_set):
|
|
1439
|
-
"""
|
|
1440
|
-
Interpolates missing values in the input signal set.
|
|
1441
|
-
|
|
1442
|
-
Parameters
|
|
1443
|
-
----------
|
|
1444
|
-
x_set : ndarray
|
|
1445
|
-
The input signal set with potentially missing values.
|
|
1446
|
-
|
|
1447
|
-
Returns
|
|
1448
|
-
-------
|
|
1449
|
-
ndarray
|
|
1450
|
-
The input signal set with missing values interpolated.
|
|
1451
|
-
|
|
1452
|
-
Notes
|
|
1453
|
-
-----
|
|
1454
|
-
- This method is useful for preparing signals that have gaps or missing time points before further processing
|
|
1455
|
-
or model training.
|
|
1456
|
-
|
|
1457
|
-
"""
|
|
1458
|
-
|
|
1459
|
-
for i in range(len(x_set)):
|
|
1460
|
-
for k in range(x_set.shape[-1]):
|
|
1461
|
-
x = x_set[i, :, k]
|
|
1462
|
-
not_nan = np.logical_not(np.isnan(x))
|
|
1463
|
-
indices = np.arange(len(x))
|
|
1464
|
-
interp = interp1d(
|
|
1465
|
-
indices[not_nan],
|
|
1466
|
-
x[not_nan],
|
|
1467
|
-
fill_value=(0.0, 0.0),
|
|
1468
|
-
bounds_error=False,
|
|
1469
|
-
)
|
|
1470
|
-
x_set[i, :, k] = interp(indices)
|
|
1471
|
-
return x_set
|
|
1472
|
-
|
|
1473
|
-
def train_classifier(self):
|
|
1474
|
-
"""
|
|
1475
|
-
Trains the classifier component of the model to predict event classes in signals.
|
|
1476
|
-
|
|
1477
|
-
This method compiles the classifier model (if not pretrained or if recompilation is requested) and
|
|
1478
|
-
trains it on the prepared dataset. The training process includes validation and early stopping based
|
|
1479
|
-
on precision to prevent overfitting.
|
|
1480
|
-
|
|
1481
|
-
Notes
|
|
1482
|
-
-----
|
|
1483
|
-
- The classifier model predicts the class of each signal, such as live, dead, or miscellaneous.
|
|
1484
|
-
- Training parameters such as epochs, batch size, and learning rate are specified during class instantiation.
|
|
1485
|
-
- Model performance metrics and training history are saved for analysis.
|
|
1486
|
-
|
|
1487
|
-
"""
|
|
1488
|
-
|
|
1489
|
-
# if pretrained model
|
|
1490
|
-
self.n_classes = 3
|
|
1491
|
-
|
|
1492
|
-
if self.pretrained is not None:
|
|
1493
|
-
# if recompile
|
|
1494
|
-
if self.recompile_pretrained:
|
|
1495
|
-
print(
|
|
1496
|
-
"Recompiling the pretrained classifier model... Warning, this action reinitializes all the weights; are you sure that this is what you intended?"
|
|
1497
|
-
)
|
|
1498
|
-
self.model_class.set_weights(
|
|
1499
|
-
clone_model(self.model_class).get_weights()
|
|
1500
|
-
)
|
|
1501
|
-
self.model_class.compile(
|
|
1502
|
-
optimizer=Adam(learning_rate=self.learning_rate),
|
|
1503
|
-
loss=self.loss_class,
|
|
1504
|
-
metrics=[
|
|
1505
|
-
"accuracy",
|
|
1506
|
-
Precision(),
|
|
1507
|
-
Recall(),
|
|
1508
|
-
MeanIoU(
|
|
1509
|
-
num_classes=self.n_classes,
|
|
1510
|
-
name="iou",
|
|
1511
|
-
dtype=float,
|
|
1512
|
-
sparse_y_true=False,
|
|
1513
|
-
sparse_y_pred=False,
|
|
1514
|
-
),
|
|
1515
|
-
],
|
|
1516
|
-
)
|
|
1517
|
-
else:
|
|
1518
|
-
self.initial_model = clone_model(self.model_class)
|
|
1519
|
-
self.model_class.set_weights(self.initial_model.get_weights())
|
|
1520
|
-
# Recompile to avoid crash
|
|
1521
|
-
self.model_class.compile(
|
|
1522
|
-
optimizer=Adam(learning_rate=self.learning_rate),
|
|
1523
|
-
loss=self.loss_class,
|
|
1524
|
-
metrics=[
|
|
1525
|
-
"accuracy",
|
|
1526
|
-
Precision(),
|
|
1527
|
-
Recall(),
|
|
1528
|
-
MeanIoU(
|
|
1529
|
-
num_classes=self.n_classes,
|
|
1530
|
-
name="iou",
|
|
1531
|
-
dtype=float,
|
|
1532
|
-
sparse_y_true=False,
|
|
1533
|
-
sparse_y_pred=False,
|
|
1534
|
-
),
|
|
1535
|
-
],
|
|
1536
|
-
)
|
|
1537
|
-
# Reset weights
|
|
1538
|
-
self.model_class.set_weights(self.initial_model.get_weights())
|
|
1539
|
-
else:
|
|
1540
|
-
print("Compiling the classifier...")
|
|
1541
|
-
self.model_class.compile(
|
|
1542
|
-
optimizer=Adam(learning_rate=self.learning_rate),
|
|
1543
|
-
loss=self.loss_class,
|
|
1544
|
-
metrics=[
|
|
1545
|
-
"accuracy",
|
|
1546
|
-
Precision(),
|
|
1547
|
-
Recall(),
|
|
1548
|
-
MeanIoU(
|
|
1549
|
-
num_classes=self.n_classes,
|
|
1550
|
-
name="iou",
|
|
1551
|
-
dtype=float,
|
|
1552
|
-
sparse_y_true=False,
|
|
1553
|
-
sparse_y_pred=False,
|
|
1554
|
-
),
|
|
1555
|
-
],
|
|
1556
|
-
)
|
|
1557
|
-
|
|
1558
|
-
self.gather_callbacks("classifier")
|
|
1559
|
-
|
|
1560
|
-
# for i in range(30):
|
|
1561
|
-
# for j in range(self.x_train.shape[-1]):
|
|
1562
|
-
# plt.plot(self.x_train[i,:,j])
|
|
1563
|
-
# plt.show()
|
|
1564
|
-
|
|
1565
|
-
if hasattr(self, "x_val"):
|
|
1566
|
-
|
|
1567
|
-
self.history_classifier = self.model_class.fit(
|
|
1568
|
-
x=self.x_train,
|
|
1569
|
-
y=self.y_class_train,
|
|
1570
|
-
batch_size=self.batch_size,
|
|
1571
|
-
class_weight=self.class_weights,
|
|
1572
|
-
epochs=self.epochs,
|
|
1573
|
-
validation_data=(self.x_val, self.y_class_val),
|
|
1574
|
-
callbacks=self.cb,
|
|
1575
|
-
verbose=1,
|
|
1576
|
-
)
|
|
1577
|
-
else:
|
|
1578
|
-
self.history_classifier = self.model_class.fit(
|
|
1579
|
-
x=self.x_train,
|
|
1580
|
-
y=self.y_class_train,
|
|
1581
|
-
batch_size=self.batch_size,
|
|
1582
|
-
class_weight=self.class_weights,
|
|
1583
|
-
epochs=self.epochs,
|
|
1584
|
-
callbacks=self.cb,
|
|
1585
|
-
validation_split=self.validation_split,
|
|
1586
|
-
verbose=1,
|
|
1587
|
-
)
|
|
1588
|
-
|
|
1589
|
-
if self.show_plots:
|
|
1590
|
-
self.plot_model_history(mode="classifier")
|
|
1591
|
-
|
|
1592
|
-
# Set current classification model as the best model
|
|
1593
|
-
self.model_class = load_model(
|
|
1594
|
-
os.sep.join([self.model_folder, "classifier.h5"]),
|
|
1595
|
-
custom_objects={"mse": MeanSquaredError()},
|
|
1596
|
-
)
|
|
1597
|
-
self.model_class.load_weights(os.sep.join([self.model_folder, "classifier.h5"]))
|
|
1598
|
-
|
|
1599
|
-
time_callback = next(
|
|
1600
|
-
(cb for cb in self.cb if isinstance(cb, TimeHistory)), None
|
|
1601
|
-
)
|
|
1602
|
-
self.dico = {
|
|
1603
|
-
"history_classifier": self.history_classifier,
|
|
1604
|
-
"execution_time_classifier": time_callback.times if time_callback else [],
|
|
1605
|
-
}
|
|
1606
|
-
|
|
1607
|
-
if hasattr(self, "x_test"):
|
|
1608
|
-
|
|
1609
|
-
predictions = self.model_class.predict(self.x_test).argmax(axis=1)
|
|
1610
|
-
ground_truth = self.y_class_test.argmax(axis=1)
|
|
1611
|
-
assert (
|
|
1612
|
-
predictions.shape == ground_truth.shape
|
|
1613
|
-
), "Mismatch in shape between the predictions and the ground truth..."
|
|
1614
|
-
|
|
1615
|
-
title = "Test data"
|
|
1616
|
-
IoU_score = jaccard_score(ground_truth, predictions, average=None)
|
|
1617
|
-
balanced_accuracy = balanced_accuracy_score(ground_truth, predictions)
|
|
1618
|
-
precision = precision_score(ground_truth, predictions, average=None)
|
|
1619
|
-
recall = recall_score(ground_truth, predictions, average=None)
|
|
1620
|
-
|
|
1621
|
-
print(f"Test IoU score: {IoU_score}")
|
|
1622
|
-
print(f"Test Balanced accuracy score: {balanced_accuracy}")
|
|
1623
|
-
print(f"Test Precision: {precision}")
|
|
1624
|
-
print(f"Test Recall: {recall}")
|
|
1625
|
-
|
|
1626
|
-
# Confusion matrix on test set
|
|
1627
|
-
results = confusion_matrix(ground_truth, predictions)
|
|
1628
|
-
self.dico.update(
|
|
1629
|
-
{
|
|
1630
|
-
"test_IoU": IoU_score,
|
|
1631
|
-
"test_balanced_accuracy": balanced_accuracy,
|
|
1632
|
-
"test_confusion": results,
|
|
1633
|
-
"test_precision": precision,
|
|
1634
|
-
"test_recall": recall,
|
|
1635
|
-
}
|
|
1636
|
-
)
|
|
1637
|
-
|
|
1638
|
-
if self.show_plots:
|
|
1639
|
-
try:
|
|
1640
|
-
ConfusionMatrixDisplay.from_predictions(
|
|
1641
|
-
ground_truth,
|
|
1642
|
-
predictions,
|
|
1643
|
-
cmap="Blues",
|
|
1644
|
-
normalize="pred",
|
|
1645
|
-
display_labels=["event", "no event", "left censored"],
|
|
1646
|
-
)
|
|
1647
|
-
plt.savefig(
|
|
1648
|
-
os.sep.join([self.model_folder, "test_confusion_matrix.png"]),
|
|
1649
|
-
bbox_inches="tight",
|
|
1650
|
-
dpi=300,
|
|
1651
|
-
)
|
|
1652
|
-
# plt.pause(3)
|
|
1653
|
-
plt.close()
|
|
1654
|
-
except Exception as e:
|
|
1655
|
-
print(e)
|
|
1656
|
-
pass
|
|
1657
|
-
print("Test set: ", classification_report(ground_truth, predictions))
|
|
1658
|
-
|
|
1659
|
-
if hasattr(self, "x_val"):
|
|
1660
|
-
predictions = self.model_class.predict(self.x_val).argmax(axis=1)
|
|
1661
|
-
ground_truth = self.y_class_val.argmax(axis=1)
|
|
1662
|
-
assert (
|
|
1663
|
-
ground_truth.shape == predictions.shape
|
|
1664
|
-
), "Mismatch in shape between the predictions and the ground truth..."
|
|
1665
|
-
title = "Validation data"
|
|
1666
|
-
|
|
1667
|
-
# Validation scores
|
|
1668
|
-
IoU_score = jaccard_score(ground_truth, predictions, average=None)
|
|
1669
|
-
balanced_accuracy = balanced_accuracy_score(ground_truth, predictions)
|
|
1670
|
-
precision = precision_score(ground_truth, predictions, average=None)
|
|
1671
|
-
recall = recall_score(ground_truth, predictions, average=None)
|
|
1672
|
-
|
|
1673
|
-
print(f"Validation IoU score: {IoU_score}")
|
|
1674
|
-
print(f"Validation Balanced accuracy score: {balanced_accuracy}")
|
|
1675
|
-
print(f"Validation Precision: {precision}")
|
|
1676
|
-
print(f"Validation Recall: {recall}")
|
|
1677
|
-
|
|
1678
|
-
# Confusion matrix on validation set
|
|
1679
|
-
results = confusion_matrix(ground_truth, predictions)
|
|
1680
|
-
self.dico.update(
|
|
1681
|
-
{
|
|
1682
|
-
"val_IoU": IoU_score,
|
|
1683
|
-
"val_balanced_accuracy": balanced_accuracy,
|
|
1684
|
-
"val_confusion": results,
|
|
1685
|
-
"val_precision": precision,
|
|
1686
|
-
"val_recall": recall,
|
|
1687
|
-
}
|
|
1688
|
-
)
|
|
1689
|
-
|
|
1690
|
-
if self.show_plots:
|
|
1691
|
-
try:
|
|
1692
|
-
ConfusionMatrixDisplay.from_predictions(
|
|
1693
|
-
ground_truth,
|
|
1694
|
-
predictions,
|
|
1695
|
-
cmap="Blues",
|
|
1696
|
-
normalize="pred",
|
|
1697
|
-
display_labels=["event", "no event", "left censored"],
|
|
1698
|
-
)
|
|
1699
|
-
plt.savefig(
|
|
1700
|
-
os.sep.join(
|
|
1701
|
-
[self.model_folder, "validation_confusion_matrix.png"]
|
|
1702
|
-
),
|
|
1703
|
-
bbox_inches="tight",
|
|
1704
|
-
dpi=300,
|
|
1705
|
-
)
|
|
1706
|
-
# plt.pause(3)
|
|
1707
|
-
plt.close()
|
|
1708
|
-
except Exception as e:
|
|
1709
|
-
print(e)
|
|
1710
|
-
pass
|
|
1711
|
-
print("Validation set: ", classification_report(ground_truth, predictions))
|
|
1712
|
-
|
|
1713
|
-
# Send result to GUI and wait
|
|
1714
|
-
for cb in self.cb:
|
|
1715
|
-
if hasattr(cb, "on_training_result"):
|
|
1716
|
-
cb.on_training_result(self.dico)
|
|
1717
|
-
time.sleep(3)
|
|
1718
|
-
|
|
1719
|
-
def train_regressor(self):
|
|
1720
|
-
"""
|
|
1721
|
-
Trains the regressor component of the model to estimate the time of interest for events in signals.
|
|
1722
|
-
|
|
1723
|
-
This method compiles the regressor model (if not pretrained or if recompilation is requested) and
|
|
1724
|
-
trains it on a subset of the prepared dataset containing signals with events. The training process
|
|
1725
|
-
includes validation and early stopping based on mean squared error to prevent overfitting.
|
|
1726
|
-
|
|
1727
|
-
Notes
|
|
1728
|
-
-----
|
|
1729
|
-
- The regressor model estimates the time at which an event of interest occurs within each signal.
|
|
1730
|
-
- Only signals predicted to have an event by the classifier model are used for regressor training.
|
|
1731
|
-
- Model performance metrics and training history are saved for analysis.
|
|
1732
|
-
|
|
1733
|
-
"""
|
|
1734
|
-
|
|
1735
|
-
# Compile model
|
|
1736
|
-
# if pretrained model
|
|
1737
|
-
if self.pretrained is not None:
|
|
1738
|
-
# if recompile
|
|
1739
|
-
if self.recompile_pretrained:
|
|
1740
|
-
print(
|
|
1741
|
-
"Recompiling the pretrained regressor model... Warning, this action reinitializes all the weights; are you sure that this is what you intended?"
|
|
1742
|
-
)
|
|
1743
|
-
self.model_reg.set_weights(clone_model(self.model_reg).get_weights())
|
|
1744
|
-
self.model_reg.compile(
|
|
1745
|
-
optimizer=Adam(learning_rate=self.learning_rate),
|
|
1746
|
-
loss=self.loss_reg,
|
|
1747
|
-
metrics=["mse", "mae"],
|
|
1748
|
-
)
|
|
1749
|
-
else:
|
|
1750
|
-
self.initial_model = clone_model(self.model_reg)
|
|
1751
|
-
self.model_reg.set_weights(self.initial_model.get_weights())
|
|
1752
|
-
self.model_reg.compile(
|
|
1753
|
-
optimizer=Adam(learning_rate=self.learning_rate),
|
|
1754
|
-
loss=self.loss_reg,
|
|
1755
|
-
metrics=["mse", "mae"],
|
|
1756
|
-
)
|
|
1757
|
-
self.model_reg.set_weights(self.initial_model.get_weights())
|
|
1758
|
-
else:
|
|
1759
|
-
print("Compiling the regressor...")
|
|
1760
|
-
self.model_reg.compile(
|
|
1761
|
-
optimizer=Adam(learning_rate=self.learning_rate),
|
|
1762
|
-
loss=self.loss_reg,
|
|
1763
|
-
metrics=["mse", "mae"],
|
|
1764
|
-
)
|
|
1765
|
-
|
|
1766
|
-
self.gather_callbacks("regressor")
|
|
1767
|
-
|
|
1768
|
-
# Train on subset of data with event
|
|
1769
|
-
|
|
1770
|
-
subset = self.x_train[np.argmax(self.y_class_train, axis=1) == 0]
|
|
1771
|
-
# for i in range(30):
|
|
1772
|
-
# plt.plot(subset[i,:,0],c="tab:red")
|
|
1773
|
-
# plt.plot(subset[i,:,1],c="tab:blue")
|
|
1774
|
-
# plt.show()
|
|
1775
|
-
|
|
1776
|
-
if hasattr(self, "x_val"):
|
|
1777
|
-
self.history_regressor = self.model_reg.fit(
|
|
1778
|
-
x=self.x_train[np.argmax(self.y_class_train, axis=1) == 0],
|
|
1779
|
-
y=self.y_time_train[np.argmax(self.y_class_train, axis=1) == 0],
|
|
1780
|
-
batch_size=self.batch_size,
|
|
1781
|
-
epochs=self.epochs * 2,
|
|
1782
|
-
validation_data=(
|
|
1783
|
-
self.x_val[np.argmax(self.y_class_val, axis=1) == 0],
|
|
1784
|
-
self.y_time_val[np.argmax(self.y_class_val, axis=1) == 0],
|
|
1785
|
-
),
|
|
1786
|
-
callbacks=self.cb,
|
|
1787
|
-
verbose=1,
|
|
1788
|
-
)
|
|
1789
|
-
else:
|
|
1790
|
-
self.history_regressor = self.model_reg.fit(
|
|
1791
|
-
x=self.x_train[np.argmax(self.y_class_train, axis=1) == 0],
|
|
1792
|
-
y=self.y_time_train[np.argmax(self.y_class_train, axis=1) == 0],
|
|
1793
|
-
batch_size=self.batch_size,
|
|
1794
|
-
epochs=self.epochs * 2,
|
|
1795
|
-
callbacks=self.cb,
|
|
1796
|
-
validation_split=self.validation_split,
|
|
1797
|
-
verbose=1,
|
|
1798
|
-
)
|
|
1799
|
-
|
|
1800
|
-
if self.show_plots:
|
|
1801
|
-
self.plot_model_history(mode="regressor")
|
|
1802
|
-
time_callback = next(
|
|
1803
|
-
(cb for cb in self.cb if isinstance(cb, TimeHistory)), None
|
|
1804
|
-
)
|
|
1805
|
-
self.dico.update(
|
|
1806
|
-
{
|
|
1807
|
-
"history_regressor": self.history_regressor,
|
|
1808
|
-
"execution_time_regressor": (
|
|
1809
|
-
time_callback.times if time_callback else []
|
|
1810
|
-
),
|
|
1811
|
-
}
|
|
1812
|
-
)
|
|
1813
|
-
|
|
1814
|
-
# Evaluate best model
|
|
1815
|
-
self.model_reg = load_model(
|
|
1816
|
-
os.sep.join([self.model_folder, "regressor.h5"]),
|
|
1817
|
-
custom_objects={"mse": MeanSquaredError()},
|
|
1818
|
-
)
|
|
1819
|
-
self.model_reg.load_weights(os.sep.join([self.model_folder, "regressor.h5"]))
|
|
1820
|
-
self.evaluate_regression_model()
|
|
1821
|
-
|
|
1822
|
-
try:
|
|
1823
|
-
np.save(os.sep.join([self.model_folder, "scores.npy"]), self.dico)
|
|
1824
|
-
except Exception as e:
|
|
1825
|
-
print(e)
|
|
1826
|
-
|
|
1827
|
-
    def plot_model_history(self, mode="regressor"):
        """
        Generates and saves plots of the training history for the classifier or regressor model.

        Parameters
        ----------
        mode : str, optional
            Specifies which model's training history to plot. Options are "classifier" or "regressor". Default is "regressor".

        Notes
        -----
        - Plots include precision metrics over epochs for the classifier, and loss metrics for the regressor.
        - The plots are saved as image files in the model's output directory.

        """

        if mode == "regressor":
            try:
                plt.plot(self.history_regressor.history["loss"])
                plt.plot(self.history_regressor.history["val_loss"])
                plt.title("model loss")
                plt.ylabel("loss")
                plt.xlabel("epoch")
                plt.yscale("log")
                plt.legend(["train", "val"], loc="upper left")
                # plt.pause(3)
                plt.savefig(
                    os.sep.join([self.model_folder, "regression_loss.png"]),
                    bbox_inches="tight",
                    dpi=300,
                )
                plt.close()
            except Exception as e:
                print(f"Error {e}; could not generate plot...")
        elif mode == "classifier":
            try:
                plt.plot(self.history_classifier.history["precision"])
                plt.plot(self.history_classifier.history["val_precision"])
                plt.title("model precision")
                plt.ylabel("precision")
                plt.xlabel("epoch")
                plt.legend(["train", "val"], loc="upper left")
                # plt.pause(3)
                plt.savefig(
                    os.sep.join([self.model_folder, "classification_loss.png"]),
                    bbox_inches="tight",
                    dpi=300,
                )
                plt.close()
            except Exception as e:
                print(f"Error {e}; could not generate plot...")
        else:
            return None

    def evaluate_regression_model(self):
        """
        Evaluates the performance of the trained regression model on test and validation datasets.

        This method calculates and prints mean squared error and mean absolute error metrics for the regression model's
        predictions. It also generates regression plots comparing predicted times of interest to true values.

        Notes
        -----
        - Evaluation is performed on both test and validation datasets, if available.
        - Regression plots and performance metrics are saved in the model's output directory.

        """

        mse = MeanSquaredError()
        mae = MeanAbsoluteError()

        if hasattr(self, "x_test"):

            print("Evaluate on test set...")
            predictions = self.model_reg.predict(
                self.x_test[np.argmax(self.y_class_test, axis=1) == 0],
                batch_size=self.batch_size,
            )[:, 0]
            ground_truth = self.y_time_test[np.argmax(self.y_class_test, axis=1) == 0]
            assert (
                predictions.shape == ground_truth.shape
            ), "Shape mismatch between predictions and ground truths..."

            test_mse = mse(ground_truth, predictions).numpy()
            test_mae = mae(ground_truth, predictions).numpy()
            print(f"MSE on test set: {test_mse}...")
            print(f"MAE on test set: {test_mae}...")
            if self.show_plots:
                regression_plot(
                    predictions,
                    ground_truth,
                    savepath=os.sep.join([self.model_folder, "test_regression.png"]),
                )
            self.dico.update({"test_mse": test_mse, "test_mae": test_mae})

        if hasattr(self, "x_val"):
            # Validation set
            predictions = self.model_reg.predict(
                self.x_val[np.argmax(self.y_class_val, axis=1) == 0],
                batch_size=self.batch_size,
            )[:, 0]
            ground_truth = self.y_time_val[np.argmax(self.y_class_val, axis=1) == 0]
            assert (
                predictions.shape == ground_truth.shape
            ), "Shape mismatch between predictions and ground truths..."

            val_mse = mse(ground_truth, predictions).numpy()
            val_mae = mae(ground_truth, predictions).numpy()

            if self.show_plots:
                regression_plot(
                    predictions,
                    ground_truth,
                    savepath=os.sep.join(
                        [self.model_folder, "validation_regression.png"]
                    ),
                )
            print(f"MSE on validation set: {val_mse}...")
            print(f"MAE on validation set: {val_mae}...")

            self.dico.update(
                {
                    "val_mse": val_mse,
                    "val_mae": val_mae,
                    "val_predictions": predictions,
                    "val_ground_truth": ground_truth,
                }
            )

        # Send result to GUI and wait
        for cb in self.cb:
            if hasattr(cb, "on_training_result"):
                cb.on_training_result(self.dico)
                time.sleep(3)

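The evaluation above relies on the stateless Keras loss objects; a minimal sketch of that computation (toy values):

    import numpy as np
    from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError

    y_true = np.array([0.20, 0.35, 0.50], dtype=np.float32)  # normalized event times
    y_pred = np.array([0.25, 0.30, 0.55], dtype=np.float32)

    print(MeanSquaredError()(y_true, y_pred).numpy())   # ~0.0025
    print(MeanAbsoluteError()(y_true, y_pred).numpy())  # ~0.05

Because event times are normalized by the signal length before training, both metrics are reported in units of the signal length, not frames.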
    def gather_callbacks(self, mode):
        """
        Prepares a list of Keras callbacks for model training based on the specified mode.

        Parameters
        ----------
        mode : str
            The training mode for which callbacks are being prepared. Options are "classifier" or "regressor".

        Notes
        -----
        - Callbacks include learning rate reduction on plateau, early stopping, model checkpointing, and TensorBoard logging.
        - The list of callbacks is stored in the class attribute `cb` and used during model training.

        """

        self.cb = []

        if mode == "classifier":

            reduce_lr = ReduceLROnPlateau(
                monitor="val_iou",
                factor=0.5,
                patience=30,
                cooldown=10,
                min_lr=5e-10,
                min_delta=1.0e-10,
                verbose=1,
                mode="max",
            )
            self.cb.append(reduce_lr)
            csv_logger = CSVLogger(
                os.sep.join([self.model_folder, "log_classifier.csv"]),
                append=True,
                separator=";",
            )
            self.cb.append(csv_logger)
            checkpoint_path = os.sep.join([self.model_folder, "classifier.h5"])
            cp_callback = ModelCheckpoint(
                checkpoint_path,
                monitor="val_iou",
                mode="max",
                verbose=1,
                save_best_only=True,
                save_weights_only=False,
                save_freq="epoch",
            )
            self.cb.append(cp_callback)

            callback_stop = EarlyStopping(monitor="val_iou", mode="max", patience=100)
            self.cb.append(callback_stop)

        elif mode == "regressor":

            reduce_lr = ReduceLROnPlateau(
                monitor="val_loss",
                factor=0.5,
                patience=30,
                cooldown=10,
                min_lr=5e-10,
                min_delta=1.0e-10,
                verbose=1,
                mode="min",
            )
            self.cb.append(reduce_lr)

            csv_logger = CSVLogger(
                os.sep.join([self.model_folder, "log_regressor.csv"]),
                append=True,
                separator=";",
            )
            self.cb.append(csv_logger)

            checkpoint_path = os.sep.join([self.model_folder, "regressor.h5"])
            cp_callback = ModelCheckpoint(
                checkpoint_path,
                monitor="val_loss",
                mode="min",
                verbose=1,
                save_best_only=True,
                save_weights_only=False,
                save_freq="epoch",
            )
            self.cb.append(cp_callback)

            callback_stop = EarlyStopping(monitor="val_loss", mode="min", patience=200)
            self.cb.append(callback_stop)

        log_dir = self.model_folder + os.sep
        cb_tb = TensorBoard(log_dir=log_dir, update_freq="batch")
        self.cb.append(cb_tb)

        cb_time = TimeHistory()
        self.cb.append(cb_time)

        if hasattr(self, "callbacks") and self.callbacks is not None:
            self.cb.extend(self.callbacks)

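A minimal sketch of the monitored-metric wiring used above: the classifier tracks a custom val_iou (maximized) while the regressor tracks val_loss (minimized). The checkpoint path here is illustrative:

    from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping

    monitor, mode = "val_loss", "min"  # regressor; the classifier uses "val_iou" / "max"
    cb = [
        ReduceLROnPlateau(monitor=monitor, factor=0.5, patience=30, cooldown=10,
                          min_lr=5e-10, min_delta=1e-10, verbose=1, mode=mode),
        ModelCheckpoint("regressor.h5", monitor=monitor, mode=mode, verbose=1,
                        save_best_only=True, save_weights_only=False, save_freq="epoch"),
        EarlyStopping(monitor=monitor, mode=mode, patience=200),
    ]

Keeping the monitor/mode pair consistent across the three callbacks is what makes "best model" mean the same thing for checkpointing, LR scheduling, and early stopping.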
    def prepare_sets(self):
        """
        Generates and preprocesses training, validation, and test sets from loaded annotations.

        This method loads signal data from annotation files, normalizes and interpolates the signals, and splits
        the dataset into training, validation, and test sets according to specified proportions.

        Notes
        -----
        - Signal annotations are expected to be stored in .npy format and contain required channels and event information.
        - The method applies specified normalization and interpolation options to prepare the signals for model training.

        """

        self.x_set = []
        self.y_time_set = []
        self.y_class_set = []

        if isinstance(self.list_of_sets[0], str):
            # Case 1: a list of npy files to be loaded
            for s in self.list_of_sets:

                signal_dataset = self.load_set(s)
                selected_signals, max_length = self.find_best_signal_match(
                    signal_dataset
                )
                signals_recast, classes, times_of_interest = (
                    self.cast_signals_into_training_data(
                        signal_dataset, selected_signals, max_length
                    )
                )
                signals_recast, times_of_interest = self.normalize_signals(
                    signals_recast, times_of_interest
                )

                self.x_set.extend(signals_recast)
                self.y_time_set.extend(times_of_interest)
                self.y_class_set.extend(classes)

        elif isinstance(self.list_of_sets[0], list):
            # Case 2: a list of sets (already loaded)
            for signal_dataset in self.list_of_sets:

                selected_signals, max_length = self.find_best_signal_match(
                    signal_dataset
                )
                signals_recast, classes, times_of_interest = (
                    self.cast_signals_into_training_data(
                        signal_dataset, selected_signals, max_length
                    )
                )
                signals_recast, times_of_interest = self.normalize_signals(
                    signals_recast, times_of_interest
                )

                self.x_set.extend(signals_recast)
                self.y_time_set.extend(times_of_interest)
                self.y_class_set.extend(classes)

        self.x_set = np.array(self.x_set).astype(np.float32)
        self.x_set = self.interpolate_signals(self.x_set)

        self.y_time_set = np.array(self.y_time_set).astype(np.float32)
        self.y_class_set = np.array(self.y_class_set).astype(np.float32)

        class_test = np.isin(self.y_class_set, [0, 1, 2])
        self.x_set = self.x_set[class_test]
        self.y_time_set = self.y_time_set[class_test]
        self.y_class_set = self.y_class_set[class_test]

        # Compute class weights and one-hot encode
        self.class_weights = compute_weights(self.y_class_set)
        self.nbr_classes = 3  # len(np.unique(self.y_class_set))
        self.y_class_set = to_categorical(self.y_class_set, num_classes=3)

        ds = train_test_split(
            self.x_set,
            self.y_time_set,
            self.y_class_set,
            validation_size=self.validation_split,
            test_size=self.test_split,
        )

        self.x_train = ds["x_train"]
        self.x_val = ds["x_val"]
        self.y_time_train = ds["y1_train"].astype(np.float32)
        self.y_time_val = ds["y1_val"].astype(np.float32)
        self.y_class_train = ds["y2_train"]
        self.y_class_val = ds["y2_val"]

        if self.test_split > 0:
            self.x_test = ds["x_test"]
            self.y_time_test = ds["y1_test"].astype(np.float32)
            self.y_class_test = ds["y2_test"]

        if self.augment:
            self.augment_training_set()

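A standalone sketch of the class filtering and one-hot encoding step above (toy labels):

    import numpy as np
    from tensorflow.keras.utils import to_categorical

    y = np.array([0, 1, 2, 4, 1])    # raw annotations; 4 is an unknown class
    keep = np.isin(y, [0, 1, 2])     # drop anything outside the 3 known classes
    y_onehot = to_categorical(y[keep], num_classes=3)
    # y_onehot.shape == (4, 3)

Filtering before encoding matters: to_categorical would otherwise allocate a column for every label value up to the maximum.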
    def augment_training_set(self, time_shift=True):
        """
        Augments the training dataset with artificially generated data to increase model robustness.

        Parameters
        ----------
        time_shift : bool, optional
            Specifies whether to include time-shifted versions of signals in the augmented dataset. Default is True.

        Notes
        -----
        - Augmentation strategies include random time shifting and signal modifications to simulate variations in real data.
        - The augmented dataset is used for training the classifier and regressor models to improve generalization.

        """

        nbr_augment = self.augmentation_factor * len(self.x_train)
        randomize = np.arange(len(self.x_train))

        unique, counts = np.unique(
            self.y_class_train.argmax(axis=1), return_counts=True
        )
        frac = counts / sum(counts)
        weights = [frac[0] / f for f in frac]
        weights[0] = weights[0] * 3

        self.pre_augment_weights = weights / sum(weights)
        weights_array = [
            self.pre_augment_weights[a.argmax()] for a in self.y_class_train
        ]

        indices = random.choices(randomize, k=nbr_augment, weights=weights_array)

        x_train_aug = []
        y_time_train_aug = []
        y_class_train_aug = []

        counts = [0.0, 0.0, 0.0]
        # warning: augmentation creates class 2 even if it does not exist in the data; this needs to be addressed
        for k in indices:
            counts[self.y_class_train[k].argmax()] += 1
            aug = augmenter(
                self.x_train[k],
                self.y_time_train[k],
                self.y_class_train[k],
                self.model_signal_length,
                time_shift=time_shift,
            )
            x_train_aug.append(aug[0])
            y_time_train_aug.append(aug[1])
            y_class_train_aug.append(aug[2])

        # Save augmented training set
        self.x_train = np.array(x_train_aug)
        self.y_time_train = np.array(y_time_train_aug)
        self.y_class_train = np.array(y_class_train_aug)

        self.class_weights = compute_weights(self.y_class_train.argmax(axis=1))
        print(f"New class weights: {self.class_weights}...")

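A standalone sketch of the weighted oversampling above, which draws augmentation indices so that rare classes are resampled more often (toy labels):

    import random
    import numpy as np

    classes = np.array([0, 0, 0, 0, 1, 2])   # imbalanced toy labels
    frac = np.bincount(classes) / len(classes)
    weights = frac[0] / frac                 # rarer classes get larger weights
    weights[0] *= 3                          # extra emphasis on the event class, as above
    per_sample = [weights[c] for c in classes]

    indices = random.choices(range(len(classes)), weights=per_sample, k=12)

random.choices samples with replacement, so the same source signal can be drawn several times and receive different random transformations.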
    def load_set(self, signal_dataset):
        return np.load(signal_dataset, allow_pickle=True)

    def find_best_signal_match(self, signal_dataset):

        required_signals = self.channel_option
        available_signals = list(signal_dataset[0].keys())

        selected_signals = []
        for s in required_signals:
            pattern_test = [s in a for a in available_signals]
            if np.any(pattern_test):
                valid_columns = np.array(available_signals)[np.array(pattern_test)]
                if len(valid_columns) == 1:
                    selected_signals.append(valid_columns[0])
                else:
                    print(f"Found several candidate signals: {valid_columns}")
                    for vc in natsorted(valid_columns):
                        if "circle" in vc:
                            selected_signals.append(vc)
                            break
                    else:
                        selected_signals.append(valid_columns[0])
            else:
                return None

        key_to_check = selected_signals[0]  # self.channel_option[0]
        signal_lengths = [len(l[key_to_check]) for l in signal_dataset]
        max_length = np.amax(signal_lengths)

        return selected_signals, max_length

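A simplified sketch of the substring matching above (plain sorted stands in for natsorted here; the column names are made up):

    available = ["dead_nuclei_channel_mean",
                 "live_nuclei_channel_circle_mean",
                 "live_nuclei_channel_mean"]
    required = "live_nuclei_channel"

    candidates = [a for a in available if required in a]
    # several candidates: prefer the one containing "circle", as above
    selected = next((c for c in sorted(candidates) if "circle" in c), candidates[0])
    # selected == "live_nuclei_channel_circle_mean"

The for/else construct in the method expresses the same fallback: the else branch only runs when no "circle" candidate triggered the break.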
    def cast_signals_into_training_data(
        self, signal_dataset, selected_signals, max_length
    ):

        signals_recast = np.zeros((len(signal_dataset), max_length, self.n_channels))
        classes = np.zeros(len(signal_dataset))
        times_of_interest = np.zeros(len(signal_dataset))

        for k in range(len(signal_dataset)):

            for i in range(self.n_channels):
                try:
                    # take into account timeline for accurate time regression
                    if selected_signals[i].startswith("pair_"):
                        timeline = signal_dataset[k]["pair_FRAME"].astype(int)
                    elif selected_signals[i].startswith("reference_"):
                        timeline = signal_dataset[k]["reference_FRAME"].astype(int)
                    elif selected_signals[i].startswith("neighbor_"):
                        timeline = signal_dataset[k]["neighbor_FRAME"].astype(int)
                    else:
                        timeline = signal_dataset[k]["FRAME"].astype(int)
                    signals_recast[k, timeline, i] = signal_dataset[k][
                        selected_signals[i]
                    ]
                except:
                    print(
                        f"Attribute {selected_signals[i]} matched to {self.channel_option[i]} not found in annotation..."
                    )
                    pass

            classes[k] = signal_dataset[k]["class"]
            times_of_interest[k] = signal_dataset[k]["time_of_interest"]

        # Correct absurd times of interest
        times_of_interest[np.nonzero(classes)] = -1
        times_of_interest[(times_of_interest <= 0.0)] = -1

        return signals_recast, classes, times_of_interest

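A standalone sketch of the timeline scattering above, which places measured values at their FRAME indices and leaves missing frames at zero until interpolation (toy values):

    import numpy as np

    frames = np.array([0, 1, 3, 4])          # FRAME column: frame 2 is missing
    values = np.array([5.0, 6.0, 7.0, 8.0])  # measurements at those frames

    recast = np.zeros(6, dtype=np.float32)   # max_length = 6
    recast[frames] = values                  # gaps stay 0 until interpolation
    # recast == [5., 6., 0., 7., 8., 0.]

Using the FRAME column as the index keeps the event time aligned with the absolute frame axis even when a track has gaps.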
    def normalize_signals(self, signals_recast, times_of_interest):

        signals_recast = pad_to_model_length(signals_recast, self.model_signal_length)
        if self.normalize:
            signals_recast = normalize_signal_set(
                signals_recast,
                self.channel_option,
                normalization_percentile=self.normalization_percentile,
                normalization_values=self.normalization_values,
                normalization_clip=self.normalization_clip,
            )

        # Trivial normalization for time of interest
        times_of_interest /= self.model_signal_length

        return signals_recast, times_of_interest

def _interpret_normalization_parameters(
    n_channels, normalization_percentile, normalization_values, normalization_clip
):
    """
    Interprets and validates normalization parameters for each channel.

    This function ensures the normalization parameters are correctly formatted and expanded to match
    the number of channels in the dataset. It provides default values and expands single values into
    lists to match the number of channels if necessary.

    Parameters
    ----------
    n_channels : int
        The number of channels in the dataset.
    normalization_percentile : list of bool or bool, optional
        Specifies whether to normalize each channel based on percentile values. If a single bool is provided,
        it is expanded to a list matching the number of channels. Default is True for all channels.
    normalization_values : list of lists or list, optional
        Specifies the percentile values [lower, upper] for normalization for each channel. If a single pair
        is provided, it is expanded to match the number of channels. Default is [[0.1, 99.9]] for all channels.
    normalization_clip : list of bool or bool, optional
        Specifies whether to clip the normalized values for each channel to the range [0, 1]. If a single bool
        is provided, it is expanded to a list matching the number of channels. Default is False for all channels.

    Returns
    -------
    tuple
        A tuple containing three lists: `normalization_percentile`, `normalization_values`, and `normalization_clip`,
        each of length `n_channels`, representing the interpreted and validated normalization parameters for each channel.

    Raises
    ------
    AssertionError
        If the lengths of the provided lists do not match `n_channels`.

    Examples
    --------
    >>> n_channels = 2
    >>> normalization_percentile = True
    >>> normalization_values = [0.1, 99.9]
    >>> normalization_clip = False
    >>> params = _interpret_normalization_parameters(n_channels, normalization_percentile, normalization_values, normalization_clip)
    >>> print(params)
    # ([True, True], [[0.1, 99.9], [0.1, 99.9]], [False, False])

    """

    if normalization_percentile is None:
        normalization_percentile = [True] * n_channels
    if normalization_values is None:
        normalization_values = [[0.1, 99.9]] * n_channels
    if normalization_clip is None:
        normalization_clip = [False] * n_channels

    if isinstance(normalization_percentile, bool):
        normalization_percentile = [normalization_percentile] * n_channels
    if isinstance(normalization_clip, bool):
        normalization_clip = [normalization_clip] * n_channels
    if len(normalization_values) == 2 and not isinstance(normalization_values[0], list):
        normalization_values = [normalization_values] * n_channels

    assert len(normalization_values) == n_channels
    assert len(normalization_clip) == n_channels
    assert len(normalization_percentile) == n_channels

    return normalization_percentile, normalization_values, normalization_clip


def normalize_signal_set(
    signal_set,
    channel_option,
    percentile_alive=[0.01, 99.99],
    percentile_dead=[0.5, 99.999],
    percentile_generic=[0.01, 99.99],
    normalization_percentile=None,
    normalization_values=None,
    normalization_clip=None,
):
    """
    Normalizes a set of single-cell signals across specified channels using given percentile values or specific normalization parameters.

    This function applies normalization to each channel in the signal set based on the provided normalization parameters,
    which can be defined globally or per channel. The normalization process aims to scale the signal values to a standard
    range, improving the consistency and comparability of signal measurements across samples.

    Parameters
    ----------
    signal_set : ndarray
        A 3D numpy array representing the set of signals to be normalized, with dimensions corresponding to (samples, time points, channels).
    channel_option : list of str
        A list specifying the channels included in the signal set and their corresponding normalization strategy based on channel names.
    percentile_alive : list of float, optional
        The percentile values [lower, upper] used for normalization of signals from channels labeled as 'alive'. Default is [0.01, 99.99].
    percentile_dead : list of float, optional
        The percentile values [lower, upper] used for normalization of signals from channels labeled as 'dead'. Default is [0.5, 99.999].
    percentile_generic : list of float, optional
        The percentile values [lower, upper] used for normalization of signals from channels not specifically labeled as 'alive' or 'dead'.
        Default is [0.01, 99.99].
    normalization_percentile : list of bool or None, optional
        Specifies whether to normalize each channel based on percentile values. If None, the default percentile strategy is applied
        based on `channel_option`. If a list, it should match the length of `channel_option`.
    normalization_values : list of lists or None, optional
        Specifies the percentile values [lower, upper] or fixed values [min, max] for normalization for each channel. Overrides
        `percentile_alive`, `percentile_dead`, and `percentile_generic` if provided.
    normalization_clip : list of bool or None, optional
        Specifies whether to clip the normalized values for each channel to the range [0, 1]. If None, clipping is disabled by default.

    Returns
    -------
    ndarray
        The normalized signal set with the same shape as the input `signal_set`.

    Notes
    -----
    - The function supports different normalization strategies for 'alive', 'dead', and generic signal channels, which can be customized
      via `channel_option` and the percentile parameters.
    - Normalization parameters (`normalization_percentile`, `normalization_values`, `normalization_clip`) are interpreted and validated
      by calling `_interpret_normalization_parameters`.

    Examples
    --------
    >>> signal_set = np.random.rand(100, 128, 2)  # 100 samples, 128 time points, 2 channels
    >>> channel_option = ['alive', 'dead']
    >>> normalized_signals = normalize_signal_set(signal_set, channel_option)
    # Normalizes the signal set based on the default percentile values for 'alive' and 'dead' channels.

    """

    # Check normalization params are ok
    n_channels = len(channel_option)
    normalization_percentile, normalization_values, normalization_clip = (
        _interpret_normalization_parameters(
            n_channels,
            normalization_percentile,
            normalization_values,
            normalization_clip,
        )
    )
    for k, channel in enumerate(channel_option):

        zero_values = []
        for i in range(len(signal_set)):
            zeros_loc = np.where(signal_set[i, :, k] == 0)
            zero_values.append(zeros_loc)

        values = signal_set[:, :, k]

        if normalization_percentile[k]:
            min_val = np.nanpercentile(
                values[values != 0.0], normalization_values[k][0]
            )
            max_val = np.nanpercentile(
                values[values != 0.0], normalization_values[k][1]
            )
        else:
            min_val = normalization_values[k][0]
            max_val = normalization_values[k][1]

        signal_set[:, :, k] -= min_val
        signal_set[:, :, k] /= max_val - min_val

        if normalization_clip[k]:
            to_clip_low = []
            to_clip_high = []
            for i in range(len(signal_set)):
                clip_low_loc = np.where(signal_set[i, :, k] <= 0)
                clip_high_loc = np.where(signal_set[i, :, k] >= 1.0)
                to_clip_low.append(clip_low_loc)
                to_clip_high.append(clip_high_loc)

            for i, z in enumerate(to_clip_low):
                signal_set[i, z, k] = 0.0
            for i, z in enumerate(to_clip_high):
                signal_set[i, z, k] = 1.0

        for i, z in enumerate(zero_values):
            signal_set[i, z, k] = 0.0

    return signal_set


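A standalone sketch of the percentile-based rescaling at the core of this helper (random toy data); note that zeros are excluded from the percentile estimate because they mark missing frames, not real measurements:

    import numpy as np

    values = np.random.rand(100, 128) * 200.0              # one channel, raw intensities
    lo, hi = np.nanpercentile(values[values != 0.0], [0.1, 99.9])
    normalized = (values - lo) / (hi - lo)                 # most values land in [0, 1]

Percentiles rather than the raw min/max make the scale robust to a handful of hot pixels or segmentation outliers.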
def pad_to_model_length(signal_set, model_signal_length):
    """
    Pad the signal set to match the specified model signal length.

    Parameters
    ----------
    signal_set : array-like
        The signal set to be padded.
    model_signal_length : int
        The desired length of the model signal.

    Returns
    -------
    array-like
        The padded signal set.

    Notes
    -----
    This function pads the signal set along the second dimension (axis 1) to match the specified model signal
    length. The padding is applied to the end of the signals using edge values (the last time point of each
    signal is repeated), increasing their length.

    Examples
    --------
    >>> signal_set = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    >>> padded_signals = pad_to_model_length(signal_set, 5)

    """

    padded = np.pad(
        signal_set,
        [(0, 0), (0, model_signal_length - signal_set.shape[1]), (0, 0)],
        mode="edge",
    )

    return padded


def random_intensity_change(signal):
    """
    Randomly change the intensity of a signal.

    Parameters
    ----------
    signal : array-like
        The input signal to be modified.

    Returns
    -------
    array-like
        The modified signal with randomly changed intensity.

    Notes
    -----
    This function applies a random intensity change to each channel of the input signal. The intensity change is
    performed by multiplying each channel with a random value drawn from a uniform distribution between 0.7 and 1.0.

    Examples
    --------
    >>> signal = np.array([[1, 2, 3], [4, 5, 6]])
    >>> modified_signal = random_intensity_change(signal)

    """

    for k in range(signal.shape[1]):
        signal[:, k] = signal[:, k] * np.random.uniform(0.7, 1.0)

    return signal


def gauss_noise(signal):
    """
    Add Gaussian noise to a signal.

    Parameters
    ----------
    signal : array-like
        The input signal to which noise will be added.

    Returns
    -------
    array-like
        The signal with Gaussian noise added.

    Notes
    -----
    This function adds multiplicative Gaussian noise to the input signal. A noise amplitude is first drawn
    uniformly between 0 and 0.08; standard normal samples are then scaled by this amplitude and by the local
    signal value before being added to the original signal.

    Examples
    --------
    >>> signal = np.array([1, 2, 3, 4, 5])
    >>> noisy_signal = gauss_noise(signal)

    """

    sig = 0.08 * np.random.uniform(0, 1)
    signal = signal + sig * np.random.normal(0, 1, signal.shape) * signal
    return signal


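The same noise model, written out on a toy signal; because the noise is multiplicative, low-intensity stretches stay comparatively clean:

    import numpy as np

    signal = np.linspace(1.0, 2.0, 128)
    sig = 0.08 * np.random.uniform(0, 1)                          # random noise amplitude
    noisy = signal + sig * np.random.normal(0, 1, signal.shape) * signal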
def random_time_shift(signal, time_of_interest, cclass, model_signal_length):
    """
    Randomly shift the signals to another time.

    Parameters
    ----------
    signal : array-like
        The signal to be shifted.
    time_of_interest : int or float
        The original time of interest for the signal. Use -1 if not applicable.
    cclass : ndarray
        The one-hot encoded class of the cell associated with the signal.
    model_signal_length : int
        The length of the model signal.

    Returns
    -------
    array-like
        The shifted signal.
    int or float
        The new time of interest if available; otherwise, the original time of interest.
    ndarray
        The (possibly relabeled) one-hot encoded class.

    Notes
    -----
    This function randomly selects a target time and shifts the signal accordingly. The shift is performed along
    the first dimension (axis 0) of the signal, using nearest-neighbor interpolation.

    If the original time of interest (`time_of_interest`) is provided (not equal to -1), the target time is drawn
    from an extended range [-model_signal_length / 3, model_signal_length + model_signal_length / 3], so that
    roughly a third of event-class signals are shifted out of the observation window and relabeled as censored.
    Event-class signals whose shifted time falls at or below 0 are relabeled as class 2, and those shifted at or
    beyond `model_signal_length` as class 1, with the time of interest reset to -1 in both cases. The function
    then returns the shifted signal along with the new time of interest; otherwise, it returns the shifted signal
    along with the original time of interest.

    Examples
    --------
    >>> signal = np.array([[1, 2], [3, 4], [5, 6]], dtype=float)
    >>> cclass = np.array([1.0, 0.0, 0.0])
    >>> shifted_signal, new_time, new_class = random_time_shift(signal, 1, cclass, 3)

    """

    min_time = 3
    max_time = model_signal_length

    return_target = False
    if time_of_interest != -1:
        return_target = True
        max_time = (
            model_signal_length + 1 / 3 * model_signal_length
        )  # bias to have a third of event class becoming no event
        min_time = -model_signal_length * 1 / 3

    times = np.linspace(
        min_time, max_time, 2000
    )  # symmetrize to create left-censored events
    target_time = np.random.choice(times)

    delta_t = target_time - time_of_interest
    signal = shift(signal, [delta_t, 0], order=0, mode="nearest")

    if target_time <= 0 and np.argmax(cclass) == 0:
        target_time = -1
        cclass = np.array([0.0, 0.0, 1.0]).astype(np.float32)
    if target_time >= model_signal_length and np.argmax(cclass) == 0:
        target_time = -1
        cclass = np.array([0.0, 1.0, 0.0]).astype(np.float32)

    if return_target:
        return signal, target_time, cclass
    else:
        return signal, time_of_interest, cclass


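A standalone sketch of the nearest-neighbor shifting used above (scipy.ndimage.shift along the time axis only; toy data):

    import numpy as np
    from scipy.ndimage import shift

    signal = np.vstack([np.arange(8.0), np.arange(8.0)]).T    # (time, channels)
    shifted = shift(signal, [2, 0], order=0, mode="nearest")  # move the event 2 frames later
    # the first two rows repeat the original first row ("nearest" edge handling)

The shift vector [delta_t, 0] translates frames without touching the channel axis, and order=0 avoids inventing intermediate intensity values.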
def augmenter(
    signal,
    time_of_interest,
    cclass,
    model_signal_length,
    time_shift=True,
    probability=0.95,
):
    """
    Randomly augments single-cell signals to simulate variations in noise, intensity ratios, and event times.

    This function applies random transformations to the input signal, including time shifts, intensity changes,
    and the addition of Gaussian noise, with the aim of increasing the diversity of the dataset for training robust models.

    Parameters
    ----------
    signal : ndarray
        A numpy array of shape (time points, channels) representing the signal of a single cell to be augmented.
    time_of_interest : float
        The normalized time of interest (event time) for the signal, scaled to the range [0, 1].
    cclass : ndarray
        A one-hot encoded numpy array representing the class of the cell associated with the signal.
    model_signal_length : int
        The length of the signal expected by the model, used for scaling the time of interest.
    time_shift : bool, optional
        Specifies whether to apply random time shifts to the signal. Default is True.
    probability : float, optional
        The probability with which to apply the augmentation transformations. Default is 0.95.

    Returns
    -------
    tuple
        A tuple containing the augmented signal, the normalized time of interest, and the class of the cell.

    Raises
    ------
    AssertionError
        If the time of interest is provided but invalid for time shifting.

    Notes
    -----
    - Time shifting is not applied to cells of the class labeled as 'miscellaneous' (typically encoded as the class '2').
    - The time of interest is rescaled based on the model's expected signal length before and after any time shift.
    - Augmentation is applied with the specified probability to simulate realistic variability while maintaining
      some original signals in the dataset.

    """

    if np.amax(time_of_interest) <= 1.0:
        time_of_interest *= model_signal_length

    # augment with a certain probability
    r = random.random()
    if r <= probability:

        if time_shift:
            # do not time shift miscellaneous cells
            assert time_of_interest is not None, "Please provide valid lysis times"
            signal, time_of_interest, cclass = random_time_shift(
                signal, time_of_interest, cclass, model_signal_length
            )

        # signal = random_intensity_change(signal)  # maybe a bad idea for non percentile-normalized signals
        signal = gauss_noise(signal)

    return signal, time_of_interest / model_signal_length, cclass


def residual_block1D(
    x, number_of_filters, kernel_size=8, match_filter_size=True, connection="identity"
):
    """
    Create a 1D residual block.

    Parameters
    ----------
    x : Tensor
        Input tensor.
    number_of_filters : int
        Number of filters in the convolutional layers.
    kernel_size : int, optional
        Kernel size of the convolutional layers. Default is 8.
    match_filter_size : bool, optional
        Whether to match the filter size of the skip connection to the output. Default is True.
    connection : str, optional
        Type of skip connection: "identity" (stride 1) or "projection" (stride-2 downsampling). Default is "identity".

    Returns
    -------
    Tensor
        Output tensor of the residual block.

    Notes
    -----
    This function creates a 1D residual block by performing the original mapping followed by adding a skip connection
    and applying non-linear activation. The skip connection allows the gradient to flow directly to earlier layers and
    helps mitigate the vanishing gradient problem. The residual block consists of two convolutional layers with
    batch normalization and ReLU activation functions, plus an optional 1x1 convolution on the skip connection.

    If `match_filter_size` is True, the skip connection is adjusted to have the same number of filters as the output.
    Otherwise, the skip connection is kept as is.

    Examples
    --------
    >>> inputs = Input(shape=(10, 3))
    >>> x = residual_block1D(inputs, 64)
    # Create a 1D residual block with 64 filters and apply it to the input tensor.

    """

    # Create skip connection
    x_skip = x

    # Perform the original mapping
    if connection == "identity":
        x = Conv1D(
            number_of_filters, kernel_size=kernel_size, strides=1, padding="same"
        )(x_skip)
    elif connection == "projection":
        x = ZeroPadding1D(padding=kernel_size // 2)(x_skip)
        x = Conv1D(
            number_of_filters, kernel_size=kernel_size, strides=2, padding="valid"
        )(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv1D(number_of_filters, kernel_size=kernel_size, strides=1, padding="same")(x)
    x = BatchNormalization()(x)

    if match_filter_size and connection == "identity":
        x_skip = Conv1D(number_of_filters, kernel_size=1, padding="same")(x_skip)
    elif match_filter_size and connection == "projection":
        x_skip = Conv1D(number_of_filters, kernel_size=1, strides=2, padding="valid")(
            x_skip
        )

    # Add the skip connection to the regular mapping
    x = Add()([x, x_skip])

    # Nonlinearly activate the result
    x = Activation("relu")(x)

    # Return the result
    return x


def MultiscaleResNetModel(
    n_channels,
    n_classes=3,
    dropout_rate=0,
    dense_collection=0,
    use_pooling=True,
    header="classifier",
    model_signal_length=128,
):
    """
    Define a generic multiscale ResNet 1D encoder model.

    Parameters
    ----------
    n_channels : int
        Number of input channels.
    n_classes : int, optional
        Number of output classes. Default is 3.
    dropout_rate : float, optional
        Dropout rate to be applied. Default is 0.
    dense_collection : int, optional
        Number of neurons in the dense layer. Default is 0.
    use_pooling : bool, optional
        Unused in this multiscale variant. Default is True.
    header : str, optional
        Type of the model header. "classifier" for classification, "regressor" for regression. Default is "classifier".
    model_signal_length : int, optional
        Length of the input signal. Default is 128.

    Returns
    -------
    keras.models.Model
        ResNet 1D encoder model.

    Notes
    -----
    This function defines a multiscale ResNet 1D encoder model with the specified number of input channels,
    output classes, dropout rate, dense collection, and model header. A shared convolutional stem feeds three
    parallel branches of projection residual blocks with kernel sizes 7, 5 and 3, whose pooled outputs are
    concatenated before the final layer. The final activation and number of neurons in the output layer are
    determined based on the header type.

    Examples
    --------
    >>> model = MultiscaleResNetModel(n_channels=3, n_classes=2, dropout_rate=0.2)
    # Define a multiscale ResNet 1D encoder with 3 input channels and 2 output classes.

    """

    if header == "classifier":
        final_activation = "softmax"
        neurons_final = n_classes
    elif header == "regressor":
        final_activation = "linear"
        neurons_final = 1
    else:
        return None

    inputs = Input(
        shape=(
            model_signal_length,
            n_channels,
        )
    )
    x = ZeroPadding1D(3)(inputs)
    x = Conv1D(64, kernel_size=7, strides=2, padding="valid", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ZeroPadding1D(1)(x)
    x_common = MaxPooling1D(pool_size=3, strides=2, padding="valid")(x)

    # Block 1
    x1 = residual_block1D(x_common, 64, kernel_size=7, connection="projection")
    x1 = residual_block1D(x1, 128, kernel_size=7, connection="projection")
    x1 = residual_block1D(x1, 256, kernel_size=7, connection="projection")
    x1 = GlobalAveragePooling1D()(x1)

    # Block 2
    x2 = residual_block1D(x_common, 64, kernel_size=5, connection="projection")
    x2 = residual_block1D(x2, 128, kernel_size=5, connection="projection")
    x2 = residual_block1D(x2, 256, kernel_size=5, connection="projection")
    x2 = GlobalAveragePooling1D()(x2)

    # Block 3
    x3 = residual_block1D(x_common, 64, kernel_size=3, connection="projection")
    x3 = residual_block1D(x3, 128, kernel_size=3, connection="projection")
    x3 = residual_block1D(x3, 256, kernel_size=3, connection="projection")
    x3 = GlobalAveragePooling1D()(x3)

    x_combined = Concatenate()([x1, x2, x3])
    x_combined = Flatten()(x_combined)

    if dense_collection > 0:
        x_combined = Dense(dense_collection)(x_combined)
    if dropout_rate > 0:
        x_combined = Dropout(dropout_rate)(x_combined)

    x_combined = Dense(neurons_final, activation=final_activation, name=header)(
        x_combined
    )
    model = Model(inputs, x_combined, name=header)

    return model


def ResNetModelCurrent(
    n_channels,
    n_slices,
    depth=2,
    use_pooling=True,
    n_classes=3,
    dropout_rate=0.1,
    dense_collection=512,
    header="classifier",
    model_signal_length=128,
):
    """
    Creates a ResNet-based model tailored for signal classification or regression tasks.

    This function constructs a 1D ResNet architecture with specified parameters. The model can be configured
    for either classification or regression tasks, determined by the `header` parameter. It consists of
    configurable ResNet blocks, global average pooling, optional dense layers, and dropout for regularization.

    Parameters
    ----------
    n_channels : int
        The number of channels in the input signal.
    n_slices : int
        The number of slices (or ResNet blocks) to use in the model.
    depth : int, optional
        The depth of the network, i.e., how many times the number of filters is doubled. Default is 2.
    use_pooling : bool, optional
        Whether to use MaxPooling between ResNet blocks. Default is True.
    n_classes : int, optional
        The number of classes for the classification task. Ignored for regression. Default is 3.
    dropout_rate : float, optional
        The dropout rate for regularization. Default is 0.1.
    dense_collection : int, optional
        The number of neurons in the dense layer following global pooling. If 0, the dense layer is omitted. Default is 512.
    header : str, optional
        Specifies the task type: "classifier" for classification or "regressor" for regression. Default is "classifier".
    model_signal_length : int, optional
        The length of the input signal. Default is 128.

    Returns
    -------
    keras.Model
        The constructed Keras model ready for training or inference.

    Notes
    -----
    - The model uses Conv1D layers for signal processing and applies global average pooling before the final classification
      or regression layer.
    - The choice of `final_activation` and `neurons_final` depends on the task: "softmax" and `n_classes` for classification,
      and "linear" and 1 for regression.
    - This function relies on a custom `residual_block1D` function for constructing ResNet blocks.

    Examples
    --------
    >>> model = ResNetModelCurrent(n_channels=1, n_slices=2, depth=2, use_pooling=True, n_classes=3, dropout_rate=0.1, dense_collection=512, header="classifier", model_signal_length=128)
    # Creates a ResNet model configured for classification with 3 classes.

    """

    if header == "classifier":
        final_activation = "softmax"
        neurons_final = n_classes
    elif header == "regressor":
        final_activation = "linear"
        neurons_final = 1
    else:
        return None

    inputs = Input(
        shape=(
            model_signal_length,
            n_channels,
        )
    )
    x2 = Conv1D(64, kernel_size=1, strides=1, padding="same")(inputs)

    n_filters = 64
    for k in range(depth):
        for i in range(n_slices):
            x2 = residual_block1D(x2, n_filters, kernel_size=8)
        n_filters *= 2
        if use_pooling and k != (depth - 1):
            x2 = MaxPooling1D()(x2)

    x2 = GlobalAveragePooling1D()(x2)
    if dense_collection > 0:
        x2 = Dense(dense_collection)(x2)
    if dropout_rate > 0:
        x2 = Dropout(dropout_rate)(x2)

    x2 = Dense(neurons_final, activation=final_activation, name=header)(x2)
    model = Model(inputs, x2, name=header)

    return model


def train_signal_model(config):
    """
|