celldetective-1.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celldetective/__init__.py +2 -0
- celldetective/__main__.py +432 -0
- celldetective/datasets/segmentation_annotations/blank +0 -0
- celldetective/datasets/signal_annotations/blank +0 -0
- celldetective/events.py +149 -0
- celldetective/extra_properties.py +100 -0
- celldetective/filters.py +89 -0
- celldetective/gui/__init__.py +20 -0
- celldetective/gui/about.py +44 -0
- celldetective/gui/analyze_block.py +563 -0
- celldetective/gui/btrack_options.py +898 -0
- celldetective/gui/classifier_widget.py +386 -0
- celldetective/gui/configure_new_exp.py +532 -0
- celldetective/gui/control_panel.py +438 -0
- celldetective/gui/gui_utils.py +495 -0
- celldetective/gui/json_readers.py +113 -0
- celldetective/gui/measurement_options.py +1425 -0
- celldetective/gui/neighborhood_options.py +452 -0
- celldetective/gui/plot_signals_ui.py +1042 -0
- celldetective/gui/process_block.py +1055 -0
- celldetective/gui/retrain_segmentation_model_options.py +706 -0
- celldetective/gui/retrain_signal_model_options.py +643 -0
- celldetective/gui/seg_model_loader.py +460 -0
- celldetective/gui/signal_annotator.py +2388 -0
- celldetective/gui/signal_annotator_options.py +340 -0
- celldetective/gui/styles.py +217 -0
- celldetective/gui/survival_ui.py +903 -0
- celldetective/gui/tableUI.py +608 -0
- celldetective/gui/thresholds_gui.py +1300 -0
- celldetective/icons/logo-large.png +0 -0
- celldetective/icons/logo.png +0 -0
- celldetective/icons/signals_icon.png +0 -0
- celldetective/icons/splash-test.png +0 -0
- celldetective/icons/splash.png +0 -0
- celldetective/icons/splash0.png +0 -0
- celldetective/icons/survival2.png +0 -0
- celldetective/icons/vignette_signals2.png +0 -0
- celldetective/icons/vignette_signals2.svg +114 -0
- celldetective/io.py +2050 -0
- celldetective/links/zenodo.json +561 -0
- celldetective/measure.py +1258 -0
- celldetective/models/segmentation_effectors/blank +0 -0
- celldetective/models/segmentation_generic/blank +0 -0
- celldetective/models/segmentation_targets/blank +0 -0
- celldetective/models/signal_detection/blank +0 -0
- celldetective/models/tracking_configs/mcf7.json +68 -0
- celldetective/models/tracking_configs/ricm.json +203 -0
- celldetective/models/tracking_configs/ricm2.json +203 -0
- celldetective/neighborhood.py +717 -0
- celldetective/scripts/analyze_signals.py +51 -0
- celldetective/scripts/measure_cells.py +275 -0
- celldetective/scripts/segment_cells.py +212 -0
- celldetective/scripts/segment_cells_thresholds.py +140 -0
- celldetective/scripts/track_cells.py +206 -0
- celldetective/scripts/train_segmentation_model.py +246 -0
- celldetective/scripts/train_signal_model.py +49 -0
- celldetective/segmentation.py +712 -0
- celldetective/signals.py +2826 -0
- celldetective/tracking.py +974 -0
- celldetective/utils.py +1681 -0
- celldetective-1.0.2.dist-info/LICENSE +674 -0
- celldetective-1.0.2.dist-info/METADATA +192 -0
- celldetective-1.0.2.dist-info/RECORD +66 -0
- celldetective-1.0.2.dist-info/WHEEL +5 -0
- celldetective-1.0.2.dist-info/entry_points.txt +2 -0
- celldetective-1.0.2.dist-info/top_level.txt +1 -0
celldetective/signals.py
ADDED
@@ -0,0 +1,2826 @@
import numpy as np
import os
import subprocess
import json

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau, CSVLogger
from tensorflow.keras.losses import CategoricalCrossentropy, MeanSquaredError, MeanAbsoluteError
from tensorflow.keras.metrics import Precision, Recall
from tensorflow.keras.models import load_model, clone_model
from tensorflow.config.experimental import list_physical_devices, set_memory_growth
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv1D, BatchNormalization, Dense, Activation, Add, MaxPooling1D, Dropout, GlobalAveragePooling1D, Concatenate, ZeroPadding1D, Flatten
from tensorflow.keras.callbacks import Callback
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import jaccard_score, balanced_accuracy_score, precision_score, recall_score
from scipy.interpolate import interp1d
from scipy.ndimage import shift

from celldetective.io import get_signal_models_list, locate_signal_model
from celldetective.tracking import clean_trajectories
from celldetective.utils import regression_plot, train_test_split, compute_weights, color_from_status, color_from_class
import matplotlib.pyplot as plt
from natsort import natsorted
from glob import glob
import shutil
import random
from math import floor, ceil
from scipy.optimize import curve_fit
import time
import math
import pandas as pd

abs_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], 'celldetective'])


class TimeHistory(Callback):

    """
    A custom Keras callback to log the duration of each epoch during training.

    This callback records the time taken by each epoch during the model training process, allowing
    monitoring of training efficiency and performance over time. The times are stored in a list, with each
    element representing the duration of an epoch in seconds.

    Attributes
    ----------
    times : list
        A list of times (in seconds) taken by each epoch during the training. This list is populated as the
        training progresses.

    Methods
    -------
    on_train_begin(logs={})
        Initializes the list of times at the beginning of training.

    on_epoch_begin(epoch, logs={})
        Records the start time of the current epoch.

    on_epoch_end(epoch, logs={})
        Calculates and appends the duration of the current epoch to the `times` list.

    Notes
    -----
    - This callback is intended to be used with the `fit` method of Keras models.
    - The time measurements are made with `time.time()`, which provides wall-clock time.

    Examples
    --------
    >>> from keras.models import Sequential
    >>> from keras.layers import Dense
    >>> model = Sequential([Dense(10, activation='relu', input_shape=(20,)), Dense(1)])
    >>> time_callback = TimeHistory()
    >>> model.compile(optimizer='adam', loss='mean_squared_error')
    >>> model.fit(x_train, y_train, epochs=10, callbacks=[time_callback])
    >>> print(time_callback.times)
    # Prints the time taken by each epoch during the training.

    """

    def on_train_begin(self, logs={}):
        self.times = []

    def on_epoch_begin(self, epoch, logs={}):
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs={}):
        self.times.append(time.time() - self.epoch_time_start)


def analyze_signals(trajectories, model, interpolate_na=True,
                    selected_signals=None,
                    column_labels={'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'},
                    plot_outcome=False, output_dir=None):

    """
    Analyzes signals from trajectory data using a specified signal detection model and configuration.

    This function preprocesses trajectory data, selects the specified signals, and applies a pretrained signal
    detection model to predict classes and times of interest for each trajectory. It supports custom column
    labeling, interpolation of missing values, and plotting of the analysis outcome.

    Parameters
    ----------
    trajectories : pandas.DataFrame
        DataFrame containing trajectory data with columns for track ID, frame, position, and signals.
    model : str
        The name of the signal detection model to be used for analysis.
    interpolate_na : bool, optional
        Whether to interpolate missing values in the trajectories (default is True).
    selected_signals : list of str, optional
        A list of column names from `trajectories` representing the signals to be analyzed. If None, signals will
        be automatically selected based on the model configuration (default is None).
    column_labels : dict, optional
        A dictionary mapping the default column names ('track', 'time', 'x', 'y') to the corresponding column names
        in `trajectories` (default is {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}).
    plot_outcome : bool, optional
        If True, generates and saves a plot of the signal analysis outcome (default is False).
    output_dir : str, optional
        The directory where the outcome plot will be saved. Required if `plot_outcome` is True (default is None).

    Returns
    -------
    pandas.DataFrame
        The input `trajectories` DataFrame with additional columns for the predicted classes, times of interest, and
        the corresponding colors based on status and class.

    Raises
    ------
    AssertionError
        If the model or its configuration file cannot be located.

    Notes
    -----
    - The function relies on an external model configuration file (`config_input.json`) located in the model's directory.
    - Signal selection and preprocessing are based on the requirements specified in the model's configuration.

    """

    model_path = locate_signal_model(model)
    complete_path = rf"{model_path}"
    model_config_path = rf"{os.sep.join([complete_path, 'config_input.json'])}"
    assert os.path.exists(complete_path), f'Model {model} could not be located in folder {model_path}... Abort.'
    assert os.path.exists(model_config_path), f'Model configuration could not be located in folder {model_path}... Abort.'

    available_signals = list(trajectories.columns)
    print('The available signals are: ', available_signals)

    with open(model_config_path) as f:
        config = json.load(f)
    required_signals = config["channels"]

    label = config.get('label')
    if label == '':
        label = None

    if selected_signals is None:
        selected_signals = []
        for s in required_signals:
            pattern_test = [s in a or s == a for a in available_signals]
            print(f'Pattern test for signal {s}: ', pattern_test)
            assert np.any(pattern_test), f'No signal matches the requirements of the model {required_signals}. Please pass the signals manually with the argument selected_signals or add measurements. Abort.'
            valid_columns = np.array(available_signals)[np.array(pattern_test)]
            if len(valid_columns) == 1:
                selected_signals.append(valid_columns[0])
            else:
                print(f'Found several candidate signals: {valid_columns}')
                # Prefer a 'circle' measurement when several columns match;
                # otherwise fall back to the first candidate (for-else runs when no break occurred).
                for vc in natsorted(valid_columns):
                    if 'circle' in vc:
                        selected_signals.append(vc)
                        break
                else:
                    selected_signals.append(valid_columns[0])
    else:
        assert len(selected_signals) == len(required_signals), f'Mismatch between the number of required signals {required_signals} and the provided signals {selected_signals}... Abort.'

    print(f'The following channels will be passed to the model: {selected_signals}')
    trajectories_clean = clean_trajectories(trajectories, interpolate_na=interpolate_na, interpolate_position_gaps=interpolate_na, column_labels=column_labels)

    max_signal_size = int(trajectories_clean[column_labels['time']].max()) + 2
    tracks = trajectories_clean[column_labels['track']].unique()
    signals = np.zeros((len(tracks), max_signal_size, len(selected_signals)))

    for i, (tid, group) in enumerate(trajectories_clean.groupby(column_labels['track'])):
        frames = group[column_labels['time']].to_numpy().astype(int)
        for j, col in enumerate(selected_signals):
            signals[i, frames, j] = group[col].to_numpy()

    model = SignalDetectionModel(pretrained=complete_path)
    print('signal shape: ', signals.shape)

    classes = model.predict_class(signals)
    times_recast = model.predict_time_of_interest(signals)

    if label is None:
        class_col = 'class'
        time_col = 't0'
        status_col = 'status'
    else:
        class_col = 'class_' + label
        time_col = 't_' + label
        status_col = 'status_' + label

    for i, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
        indices = group.index
        trajectories.loc[indices, class_col] = classes[i]
        trajectories.loc[indices, time_col] = times_recast[i]
    print('Done.')

    for tid, group in trajectories.groupby(column_labels['track']):

        indices = group.index
        t0 = group[time_col].to_numpy()[0]
        cclass = group[class_col].to_numpy()[0]
        timeline = group[column_labels['time']].to_numpy()
        # Per-frame status: 0 before the event, 1 from the event time onward;
        # whole tracks of class 2 get status 2, any unexpected class gets the sentinel 42.
        status = np.zeros_like(timeline)
        if t0 > 0:
            status[timeline >= t0] = 1.
        if cclass == 2:
            status[:] = 2
        if cclass > 2:
            status[:] = 42
        status_color = [color_from_status(s) for s in status]
        class_color = [color_from_class(cclass) for i in range(len(status))]

        trajectories.loc[indices, status_col] = status
        trajectories.loc[indices, 'status_color'] = status_color
        trajectories.loc[indices, 'class_color'] = class_color

    if plot_outcome:
        fig, ax = plt.subplots(1, len(selected_signals), figsize=(10, 5))
        for i, s in enumerate(selected_signals):
            for k, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
                cclass = group[class_col].to_numpy()[0]
                t0 = group[time_col].to_numpy()[0]
                timeline = group[column_labels['time']].to_numpy()
                if cclass == 0:
                    if len(selected_signals) > 1:
                        ax[i].plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
                    else:
                        ax.plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
        if len(selected_signals) > 1:
            for a, s in zip(ax, selected_signals):
                a.set_title(s)
                a.set_xlabel(r'time - t$_0$ [frame]')
                a.spines['top'].set_visible(False)
                a.spines['right'].set_visible(False)
        else:
            ax.set_title(s)
            ax.set_xlabel(r'time - t$_0$ [frame]')
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
        plt.tight_layout()
        if output_dir is not None:
            plt.savefig(output_dir + 'signal_collapse.png', bbox_inches='tight', dpi=300)
        plt.pause(3)
        plt.close()

    return trajectories

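# A minimal usage sketch for analyze_signals (illustrative only: the CSV path and
# the model name 'lysis_model' are hypothetical, not shipped defaults):
#
#     import pandas as pd
#     from celldetective.signals import analyze_signals
#
#     df = pd.read_csv('trajectories_targets.csv')  # one row per cell per frame
#     df = analyze_signals(df, 'lysis_model')
#     # adds 'class', 't0' and per-frame 'status' columns
#     # (or 'class_<label>', 't_<label>', 'status_<label>' if the model defines a label)
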
def analyze_signals_at_position(pos, model, mode, use_gpu=True, return_table=False):

    """
    Analyzes signals for a given position directory using a specified model and mode, with an option to use GPU acceleration.

    This function executes an external Python script to analyze signals within the specified position directory,
    applying a predefined model in the specified mode. It supports GPU acceleration for faster processing.
    Optionally, the function can return the resulting analysis table as a pandas DataFrame.

    Parameters
    ----------
    pos : str
        The file path to the position directory containing the data to be analyzed. The path must be valid and accessible.
    model : str
        The name of the model to use for signal analysis.
    mode : str
        The operation mode specifying how the analysis should be conducted.
    use_gpu : bool, optional
        Whether to use GPU acceleration for the analysis (default is True).
    return_table : bool, optional
        If True, the function returns a pandas DataFrame containing the analysis results (default is False).

    Returns
    -------
    pandas.DataFrame or None
        If `return_table` is True, returns a DataFrame containing the analysis results. Otherwise, returns None.

    Raises
    ------
    AssertionError
        If the specified position path does not exist.

    Notes
    -----
    - The analysis is performed by an external script (`analyze_signals.py`) located in a specific directory relative
      to this function.
    - The results of the analysis are expected to be saved in the "output/tables" subdirectory within the position
      directory, following a naming convention based on the analysis `mode`.

    """

    pos = pos.replace('\\', '/')
    pos = rf"{pos}"
    assert os.path.exists(pos), f'Position {pos} is not a valid path.'
    if not pos.endswith('/'):
        pos += '/'

    script_path = os.sep.join([abs_path, 'scripts', 'analyze_signals.py'])
    cmd = f'python "{script_path}" --pos "{pos}" --model "{model}" --mode "{mode}" --use_gpu "{use_gpu}"'
    subprocess.call(cmd, shell=True)

    table = pos + os.sep.join(["output", "tables", f"trajectories_{mode}.csv"])
    if return_table:
        return pd.read_csv(table)
    else:
        return None

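# Illustrative call (a sketch; the position path and model name are hypothetical,
# and analyze_signals.py runs in a separate process):
#
#     from celldetective.signals import analyze_signals_at_position
#
#     df = analyze_signals_at_position('/data/experiment/W1/100', 'lysis_model',
#                                      'targets', return_table=True)
#     # the external script writes output/tables/trajectories_targets.csv inside
#     # the position folder; return_table=True loads it back as a DataFrame
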
class SignalDetectionModel(object):

    """
    A class for creating and managing signal detection models for analyzing biological signals.

    This class provides functionalities to load a pretrained signal detection model or create one from scratch,
    preprocess input signals, train the model, and make predictions on new data.

    Parameters
    ----------
    path : str, optional
        Path to the directory containing the model and its configuration. This is used when loading a pretrained model.
    pretrained : str, optional
        Path to the pretrained model to load. If specified, the model and its configuration are loaded from this path.
    channel_option : list of str, optional
        Specifies the channels to be used for signal analysis. Default is ["live_nuclei_channel"].
    model_signal_length : int, optional
        The length of the input signals that the model expects. Default is 128.
    n_channels : int, optional
        The number of channels in the input signals. Default is 1.
    n_conv : int, optional
        The number of convolutional layers in the model. Default is 2.
    n_classes : int, optional
        The number of classes for the classification task. Default is 3.
    dense_collection : int, optional
        The number of units in the dense layer of the model. Default is 512.
    dropout_rate : float, optional
        The dropout rate applied to the dense layer of the model. Default is 0.1.
    label : str, optional
        A label for the model, used in naming and organizing outputs. Default is ''.

    Attributes
    ----------
    model_class : keras Model
        The classification model for predicting the class of signals.
    model_reg : keras Model
        The regression model for predicting the time of interest for signals.

    Methods
    -------
    load_pretrained_model()
        Loads the model and its configuration from the pretrained path.
    create_models_from_scratch()
        Creates new models for classification and regression from scratch.
    prep_gpu()
        Prepares GPU devices for training, if available.
    fit_from_directory(ds_folders, ...)
        Trains the model using data from specified directories.
    fit(x_train, y_time_train, y_class_train, ...)
        Trains the model using provided datasets.
    predict_class(x, ...)
        Predicts the class of input signals.
    predict_time_of_interest(x, ...)
        Predicts the time of interest for input signals.
    plot_model_history(mode)
        Plots the training history for the specified mode (classifier or regressor).
    evaluate_regression_model()
        Evaluates the regression model on test and validation data.
    gather_callbacks(mode)
        Gathers and prepares callbacks for training based on the specified mode.
    generate_sets()
        Generates training, validation, and test sets from loaded data.
    augment_training_set()
        Augments the training set with additional generated data.
    load_and_normalize(subset)
        Loads and normalizes signals from a subset of data.

    Notes
    -----
    - This class is designed to work with biological signal data, such as time series from microscopy imaging.
    - The model architecture and training configuration can be customized through the class parameters and methods.

    """

    def __init__(self, path=None, pretrained=None, channel_option=["live_nuclei_channel"], model_signal_length=128, n_channels=1,
                 n_conv=2, n_classes=3, dense_collection=512, dropout_rate=0.1, label=''):

        self.prep_gpu()

        self.model_signal_length = model_signal_length
        self.channel_option = channel_option
        self.pretrained = pretrained
        self.n_channels = n_channels
        self.n_conv = n_conv
        self.n_classes = n_classes
        self.dense_collection = dense_collection
        self.dropout_rate = dropout_rate
        self.label = label

        if self.pretrained is not None:
            print(f"Load pretrained models from {self.pretrained}...")
            self.load_pretrained_model()
        else:
            print("Create models from scratch...")
            self.create_models_from_scratch()

    def load_pretrained_model(self):

        """
        Loads a pretrained model and its configuration from the specified path.

        This method attempts to load both the classification and regression models from the path specified during
        class instantiation. It also loads the model configuration from a JSON file and updates the model attributes
        accordingly. If the models cannot be loaded, an error message is printed.

        Raises
        ------
        Exception
            If there is an error loading the model or the configuration file, an exception is raised with details.

        Notes
        -----
        - The models are expected to be saved in .h5 format with the filenames "classifier.h5" and "regressor.h5".
        - The configuration file is expected to be named "config_input.json" and located in the same directory as the models.
        """

        try:
            self.model_class = load_model(os.sep.join([self.pretrained, "classifier.h5"]), compile=False)
            self.model_class.load_weights(os.sep.join([self.pretrained, "classifier.h5"]))
            print("Classifier successfully loaded...")
        except Exception as e:
            print(f"Error {e}...")
            self.model_class = None
        try:
            self.model_reg = load_model(os.sep.join([self.pretrained, "regressor.h5"]), compile=False)
            self.model_reg.load_weights(os.sep.join([self.pretrained, "regressor.h5"]))
            print("Regressor successfully loaded...")
        except Exception as e:
            print(f"Error {e}...")
            self.model_reg = None

        # load config
        with open(os.sep.join([self.pretrained, "config_input.json"])) as config_file:
            model_config = json.load(config_file)

        req_channels = model_config["channels"]
        print(f"Required channels read from pretrained model: {req_channels}")
        self.channel_option = req_channels
        if 'normalize' in model_config:
            self.normalize = model_config['normalize']
        if 'normalization_percentile' in model_config:
            self.normalization_percentile = model_config['normalization_percentile']
        if 'normalization_values' in model_config:
            self.normalization_values = model_config['normalization_values']
        if 'normalization_clip' in model_config:
            self.normalization_clip = model_config['normalization_clip']
        if 'label' in model_config:
            self.label = model_config['label']

        self.n_channels = self.model_class.layers[0].input_shape[0][-1]
        self.model_signal_length = self.model_class.layers[0].input_shape[0][-2]
        self.n_classes = self.model_class.layers[-1].output_shape[-1]

        assert self.model_class.layers[0].input_shape[0] == self.model_reg.layers[0].input_shape[0], f"Mismatch between the input shape of the classification model ({self.model_class.layers[0].input_shape[0]}) and that of the regression model ({self.model_reg.layers[0].input_shape[0]})... Error."

    def create_models_from_scratch(self):

        """
        Initializes new models for classification and regression based on the specified parameters.

        This method creates new ResNet models for both classification and regression tasks using the parameters
        specified during class instantiation. The models are configured but not compiled or trained.

        Notes
        -----
        - The models are created using a custom ResNet architecture defined elsewhere in the codebase.
        - The models are stored in the `model_class` and `model_reg` attributes of the class.
        """

        self.model_class = ResNetModelCurrent(n_channels=self.n_channels,
                                              n_slices=self.n_conv,
                                              n_classes=self.n_classes,
                                              dense_collection=self.dense_collection,
                                              dropout_rate=self.dropout_rate,
                                              header="classifier",
                                              model_signal_length=self.model_signal_length,
                                              )

        self.model_reg = ResNetModelCurrent(n_channels=self.n_channels,
                                            n_slices=self.n_conv,
                                            n_classes=self.n_classes,
                                            dense_collection=self.dense_collection,
                                            dropout_rate=self.dropout_rate,
                                            header="regressor",
                                            model_signal_length=self.model_signal_length,
                                            )

    def prep_gpu(self):

        """
        Prepares GPU devices for training by enabling memory growth.

        This method identifies the available GPU devices and configures TensorFlow to allow memory growth on each
        GPU. This prevents TensorFlow from allocating the total available memory on the GPU device upfront.

        Notes
        -----
        - This method should be called before any TensorFlow/Keras operations that might allocate GPU memory.
        - If no GPUs are detected, the method passes silently.
        """

        try:
            physical_devices = list_physical_devices('GPU')
            for gpu in physical_devices:
                set_memory_growth(gpu, True)
        except Exception:
            pass

    def fit_from_directory(self, ds_folders, normalize=True, normalization_percentile=None, normalization_values=None,
                           normalization_clip=None, channel_option=["live_nuclei_channel"], model_name=None, target_directory=None,
                           augment=True, augmentation_factor=2, validation_split=0.20, test_split=0.0, batch_size=64, epochs=300,
                           recompile_pretrained=False, learning_rate=0.01, loss_reg="mse", loss_class=CategoricalCrossentropy(from_logits=False)):

        """
        Trains the model using data from specified directories.

        This method prepares the dataset for training by loading and preprocessing data from the specified
        directories, then trains the classification and regression models.

        Parameters
        ----------
        ds_folders : list of str
            List of directories containing the dataset files for training.
        normalize : bool, optional
            Whether to normalize the input signals (default is True).
        normalization_percentile : list or None, optional
            Percentiles for signal normalization (default is None).
        normalization_values : list or None, optional
            Specific values for signal normalization (default is None).
        normalization_clip : bool, optional
            Whether to clip the normalized signals (default is None).
        channel_option : list of str, optional
            Specifies the channels to be used for signal analysis (default is ["live_nuclei_channel"]).
        model_name : str, optional
            Name of the model for saving purposes (default is None).
        target_directory : str, optional
            Directory where the trained model and outputs will be saved (default is None).
        augment : bool, optional
            Whether to augment the training data (default is True).
        augmentation_factor : int, optional
            Factor by which to augment the training data (default is 2).
        validation_split : float, optional
            Fraction of the data to be used as the validation set (default is 0.20).
        test_split : float, optional
            Fraction of the data to be used as the test set (default is 0.0).
        batch_size : int, optional
            Batch size for training (default is 64).
        epochs : int, optional
            Number of epochs to train for (default is 300).
        recompile_pretrained : bool, optional
            Whether to recompile a pretrained model (default is False).
        learning_rate : float, optional
            Learning rate for the optimizer (default is 0.01).
        loss_reg : str or keras.losses.Loss, optional
            Loss function for the regression model (default is "mse").
        loss_class : str or keras.losses.Loss, optional
            Loss function for the classification model (default is CategoricalCrossentropy(from_logits=False)).

        Notes
        -----
        - The method automatically splits the dataset into training, validation, and test sets according to the specified splits.
        """

        if not hasattr(self, 'normalization_percentile'):
            self.normalization_percentile = normalization_percentile
        if not hasattr(self, 'normalization_values'):
            self.normalization_values = normalization_values
        if not hasattr(self, 'normalization_clip'):
            self.normalization_clip = normalization_clip
        print('Actual clip option:', self.normalization_clip)

        self.normalize = normalize
        self.normalization_percentile, self.normalization_values, self.normalization_clip = _interpret_normalization_parameters(self.n_channels, self.normalization_percentile, self.normalization_values, self.normalization_clip)

        self.ds_folders = [rf'{d}' for d in ds_folders]
        self.batch_size = batch_size
        self.epochs = epochs
        self.validation_split = validation_split
        self.test_split = test_split
        self.augment = augment
        self.augmentation_factor = augmentation_factor
        self.model_name = rf'{model_name}'
        self.target_directory = rf'{target_directory}'
        self.model_folder = os.sep.join([self.target_directory, self.model_name])
        self.recompile_pretrained = recompile_pretrained
        self.learning_rate = learning_rate
        self.loss_reg = loss_reg
        self.loss_class = loss_class

        if not os.path.exists(self.model_folder):
            os.mkdir(self.model_folder)

        self.channel_option = channel_option
        assert self.n_channels == len(self.channel_option), 'Mismatch between the channel option and the number of channels of the model...'

        self.list_of_sets = []
        print(self.ds_folders)
        for f in self.ds_folders:
            self.list_of_sets.extend(glob(os.sep.join([f, "*.npy"])))
        print(f"Found {len(self.list_of_sets)} annotation files...")
        self.generate_sets()

        self.train_classifier()
        self.train_regressor()

        config_input = {"channels": self.channel_option, "model_signal_length": self.model_signal_length, 'label': self.label, 'normalize': self.normalize, 'normalization_percentile': self.normalization_percentile, 'normalization_values': self.normalization_values, 'normalization_clip': self.normalization_clip}
        json_string = json.dumps(config_input)
        with open(os.sep.join([self.model_folder, "config_input.json"]), 'w') as outfile:
            outfile.write(json_string)

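    # Training sketch (illustrative; the annotation folder, model name and target
    # directory are hypothetical, and the .npy files are celldetective annotation exports):
    #
    #     model = SignalDetectionModel(channel_option=['live_nuclei_channel'],
    #                                  model_signal_length=128, n_channels=1)
    #     model.fit_from_directory(['/data/annotations/exp1/'],
    #                              model_name='my_signal_model',
    #                              target_directory='/models/signal_detection/')
    #     # writes classifier.h5, regressor.h5 and config_input.json into
    #     # /models/signal_detection/my_signal_model/
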
    def fit(self, x_train, y_time_train, y_class_train, normalize=True, normalization_percentile=None, normalization_values=None,
            normalization_clip=None, pad=True, validation_data=None, test_data=None,
            channel_option=["live_nuclei_channel", "dead_nuclei_channel"], model_name=None,
            target_directory=None, augment=True, augmentation_factor=3, validation_split=0.25, batch_size=64, epochs=300,
            recompile_pretrained=False, learning_rate=0.001, loss_reg="mse", loss_class=CategoricalCrossentropy(from_logits=False)):

        """
        Trains the model using provided datasets.

        Parameters
        ----------
        Same as `fit_from_directory`, but instead of loading data from directories, this method accepts preloaded and
        optionally preprocessed datasets directly.

        Notes
        -----
        - This method provides an alternative way to train the model when data is already loaded into memory, offering
          flexibility for data preprocessing steps outside this class.
        """

        self.normalize = normalize
        if not hasattr(self, 'normalization_percentile'):
            self.normalization_percentile = normalization_percentile
        if not hasattr(self, 'normalization_values'):
            self.normalization_values = normalization_values
        if not hasattr(self, 'normalization_clip'):
            self.normalization_clip = normalization_clip
        self.normalization_percentile, self.normalization_values, self.normalization_clip = _interpret_normalization_parameters(self.n_channels, self.normalization_percentile, self.normalization_values, self.normalization_clip)

        self.x_train = x_train
        self.y_class_train = y_class_train
        self.y_time_train = y_time_train
        self.channel_option = channel_option

        assert self.n_channels == len(self.channel_option), 'Mismatch between the channel option and the number of channels of the model...'

        if pad:
            self.x_train = pad_to_model_length(self.x_train, self.model_signal_length)

        assert self.x_train.shape[1:] == (self.model_signal_length, self.n_channels), "Shape mismatch between the provided training fluorescence signals and the model..."

        # If y_class is not one-hot encoded, encode it and derive class weights
        # (default to no class weighting when the labels already arrive one-hot encoded)
        self.class_weights = None
        if self.y_class_train.shape[-1] != self.n_classes:
            self.class_weights = compute_weights(self.y_class_train)
            self.y_class_train = to_categorical(self.y_class_train)

        if self.normalize:
            self.y_time_train = self.y_time_train.astype(np.float32) / self.model_signal_length
            self.x_train = normalize_signal_set(self.x_train, self.channel_option, normalization_percentile=self.normalization_percentile,
                                                normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
                                                )

        if validation_data is not None:
            try:
                self.x_val = validation_data[0]
                if pad:
                    self.x_val = pad_to_model_length(self.x_val, self.model_signal_length)
                self.y_class_val = validation_data[1]
                if self.y_class_val.shape[-1] != self.n_classes:
                    self.y_class_val = to_categorical(self.y_class_val)
                self.y_time_val = validation_data[2]
                if self.normalize:
                    self.y_time_val = self.y_time_val.astype(np.float32) / self.model_signal_length
                    self.x_val = normalize_signal_set(self.x_val, self.channel_option, normalization_percentile=self.normalization_percentile,
                                                      normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
                                                      )
            except Exception as e:
                print(f"Could not load validation data, error {e}...")
        else:
            self.validation_split = validation_split

        if test_data is not None:
            try:
                self.x_test = test_data[0]
                if pad:
                    self.x_test = pad_to_model_length(self.x_test, self.model_signal_length)
                self.y_class_test = test_data[1]
                if self.y_class_test.shape[-1] != self.n_classes:
                    self.y_class_test = to_categorical(self.y_class_test)
                self.y_time_test = test_data[2]
                if self.normalize:
                    self.y_time_test = self.y_time_test.astype(np.float32) / self.model_signal_length
                    self.x_test = normalize_signal_set(self.x_test, self.channel_option, normalization_percentile=self.normalization_percentile,
                                                       normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
                                                       )
            except Exception as e:
                print(f"Could not load test data, error {e}...")

        self.batch_size = batch_size
        self.epochs = epochs
        self.augment = augment
        self.augmentation_factor = augmentation_factor
        if self.augmentation_factor == 1:
            self.augment = False
        self.model_name = model_name
        self.target_directory = target_directory
        self.model_folder = os.sep.join([self.target_directory, self.model_name])
        self.recompile_pretrained = recompile_pretrained
        self.learning_rate = learning_rate
        self.loss_reg = loss_reg
        self.loss_class = loss_class

        if os.path.exists(self.model_folder):
            shutil.rmtree(self.model_folder)
        os.mkdir(self.model_folder)

        self.train_classifier()
        self.train_regressor()

    def predict_class(self, x, normalize=True, pad=True, return_one_hot=False, interpolate=True):

        """
        Predicts the class of input signals using the trained classification model.

        Parameters
        ----------
        x : ndarray
            The input signals for which to predict classes.
        normalize : bool, optional
            Whether to normalize the input signals (default is True).
        pad : bool, optional
            Whether to pad the input signals to match the model's expected signal length (default is True).
        return_one_hot : bool, optional
            Whether to return predictions in one-hot encoded format (default is False).
        interpolate : bool, optional
            Whether to interpolate the input signals (default is True).

        Returns
        -------
        ndarray
            The predicted classes for the input signals. If `return_one_hot` is True, predictions are returned in
            one-hot encoded format, otherwise as integer labels.

        Notes
        -----
        - The method processes the input signals according to the specified options to ensure compatibility with the
          model's input requirements.
        """

        self.x = np.copy(x)
        self.normalize = normalize
        self.pad = pad
        self.return_one_hot = return_one_hot

        if self.pad:
            self.x = pad_to_model_length(self.x, self.model_signal_length)

        if self.normalize:
            self.x = normalize_signal_set(self.x, self.channel_option, normalization_percentile=self.normalization_percentile,
                                          normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
                                          )

        # TODO: implement auto interpolation here
        # self.x = self.interpolate_signals(self.x)

        assert self.x.shape[-1] == self.model_class.layers[0].input_shape[0][-1], "Shape mismatch between the input and the model input shape..."
        assert self.x.shape[-2] == self.model_class.layers[0].input_shape[0][-2], "Shape mismatch between the input and the model input shape..."

        self.class_predictions_one_hot = self.model_class.predict(self.x)
        self.class_predictions = self.class_predictions_one_hot.argmax(axis=1)

        if self.return_one_hot:
            return self.class_predictions_one_hot
        else:
            return self.class_predictions

    def predict_time_of_interest(self, x, class_predictions=None, normalize=True, pad=True):

        """
        Predicts the time of interest for input signals using the trained regression model.

        Parameters
        ----------
        x : ndarray
            The input signals for which to predict times of interest.
        class_predictions : ndarray, optional
            The predicted classes for the input signals. If provided, time of interest predictions are only made for
            signals predicted to belong to a specific class (default is None).
        normalize : bool, optional
            Whether to normalize the input signals (default is True).
        pad : bool, optional
            Whether to pad the input signals to match the model's expected signal length (default is True).

        Returns
        -------
        ndarray
            The predicted times of interest for the input signals.

        Notes
        -----
        - The method processes the input signals according to the specified options and uses the regression model to
          predict the times at which the event of interest occurs.
        """

        self.x = np.copy(x)
        self.normalize = normalize
        self.pad = pad

        if class_predictions is not None:
            self.class_predictions = class_predictions

        if self.pad:
            self.x = pad_to_model_length(self.x, self.model_signal_length)

        if self.normalize:
            self.x = normalize_signal_set(self.x, self.channel_option, normalization_percentile=self.normalization_percentile,
                                          normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
                                          )

        assert self.x.shape[-1] == self.model_reg.layers[0].input_shape[0][-1], "Shape mismatch between the input and the model input shape..."
        assert self.x.shape[-2] == self.model_reg.layers[0].input_shape[0][-2], "Shape mismatch between the input and the model input shape..."

        # The regressor only sees signals classified as having an event (class 0);
        # all other tracks receive the sentinel value -1.
        if np.any(self.class_predictions == 0):
            self.time_predictions = self.model_reg.predict(self.x[self.class_predictions == 0]) * self.model_signal_length
            self.time_predictions = self.time_predictions[:, 0]
            self.time_predictions_recast = np.zeros(len(self.x)) - 1.
            self.time_predictions_recast[self.class_predictions == 0] = self.time_predictions
        else:
            self.time_predictions_recast = np.zeros(len(self.x)) - 1.
        return self.time_predictions_recast

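    # Prediction sketch (illustrative; `signals` is a hypothetical array of shape
    # (n_tracks, length, n_channels) matching the model's channel layout):
    #
    #     model = SignalDetectionModel(pretrained='/models/signal_detection/my_signal_model')
    #     classes = model.predict_class(signals)           # integer class per track
    #     times = model.predict_time_of_interest(signals)  # event time per track, -1 if no event
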
    def interpolate_signals(self, x_set):

        """
        Interpolates missing values in the input signal set.

        Parameters
        ----------
        x_set : ndarray
            The input signal set with potentially missing values.

        Returns
        -------
        ndarray
            The input signal set with missing values interpolated.

        Notes
        -----
        - This method is useful for preparing signals that have gaps or missing time points before further
          processing or model training.
        """

        for i in range(len(x_set)):
            for k in range(x_set.shape[-1]):
                x = x_set[i, :, k]
                not_nan = np.logical_not(np.isnan(x))
                indices = np.arange(len(x))
                interp = interp1d(indices[not_nan], x[not_nan], fill_value=(0., 0.), bounds_error=False)
                x_set[i, :, k] = interp(indices)
        return x_set

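    # Worked example for interpolate_signals (sketch): a single-channel set with
    # one missing time point; interp1d fills the gap linearly and out-of-range
    # values with 0 (fill_value=(0., 0.), bounds_error=False).
    #
    #     x = np.array([[[1.], [np.nan], [3.], [4.]]])  # shape (1, 4, 1)
    #     x = model.interpolate_signals(x)
    #     # x[0, :, 0] -> [1., 2., 3., 4.]
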
def train_classifier(self):
|
|
916
|
+
|
|
917
|
+
"""
|
|
918
|
+
Trains the classifier component of the model to predict event classes in signals.
|
|
919
|
+
|
|
920
|
+
This method compiles the classifier model (if not pretrained or if recompilation is requested) and
|
|
921
|
+
trains it on the prepared dataset. The training process includes validation and early stopping based
|
|
922
|
+
on precision to prevent overfitting.
|
|
923
|
+
|
|
924
|
+
Notes
|
|
925
|
+
-----
|
|
926
|
+
- The classifier model predicts the class of each signal, such as live, dead, or miscellaneous.
|
|
927
|
+
- Training parameters such as epochs, batch size, and learning rate are specified during class instantiation.
|
|
928
|
+
- Model performance metrics and training history are saved for analysis.
|
|
929
|
+
"""
|
|
930
|
+
|
|
931
|
+
# if pretrained model
|
|
932
|
+
if self.pretrained is not None:
|
|
933
|
+
# if recompile
|
|
934
|
+
if self.recompile_pretrained:
|
|
935
|
+
print('Recompiling the pretrained classifier model... Warning, this action reinitializes all the weights; are you sure that this is what you intended?')
|
|
936
|
+
self.model_class.set_weights(clone_model(self.model_class).get_weights())
|
|
937
|
+
self.model_class.compile(optimizer=Adam(learning_rate=self.learning_rate),
|
|
938
|
+
loss=self.loss_class,
|
|
939
|
+
metrics=['accuracy', Precision(), Recall()])
|
|
940
|
+
else:
|
|
941
|
+
self.initial_model = clone_model(self.model_class)
|
|
942
|
+
self.model_class.set_weights(self.initial_model.get_weights())
|
|
943
|
+
# Recompile to avoid crash
|
|
944
|
+
self.model_class.compile(optimizer=Adam(learning_rate=self.learning_rate),
|
|
945
|
+
loss=self.loss_class,
|
|
946
|
+
metrics=['accuracy', Precision(), Recall()])
|
|
947
|
+
# Reset weights
|
|
948
|
+
self.model_class.set_weights(self.initial_model.get_weights())
|
|
949
|
+
else:
|
|
950
|
+
print("Compiling the classifier...")
|
|
951
|
+
self.model_class.compile(optimizer=Adam(learning_rate=self.learning_rate),
|
|
952
|
+
loss=self.loss_class,
|
|
953
|
+
metrics=['accuracy', Precision(), Recall()])
|
|
954
|
+
|
|
955
|
+
self.gather_callbacks("classifier")
|
|
956
|
+
|
|
957
|
+
|
|
958
|
+
# for i in range(30):
|
|
959
|
+
# for j in range(self.x_train.shape[-1]):
|
|
960
|
+
# plt.plot(self.x_train[i,:,j])
|
|
961
|
+
# plt.show()
|
|
962
|
+
|
|
963
|
+
if hasattr(self, 'x_val'):
|
|
964
|
+
self.history_classifier = self.model_class.fit(x=self.x_train,
|
|
965
|
+
y=self.y_class_train,
|
|
966
|
+
batch_size=self.batch_size,
|
|
967
|
+
class_weight=self.class_weights,
|
|
968
|
+
epochs=self.epochs,
|
|
969
|
+
validation_data=(self.x_val,self.y_class_val),
|
|
970
|
+
callbacks=self.cb,
|
|
971
|
+
verbose=1)
|
|
972
|
+
else:
|
|
973
|
+
self.history_classifier = self.model_class.fit(x=self.x_train,
|
|
974
|
+
y=self.y_class_train,
|
|
975
|
+
batch_size=self.batch_size,
|
|
976
|
+
class_weight=self.class_weights,
|
|
977
|
+
epochs=self.epochs,
|
|
978
|
+
callbacks=self.cb,
|
|
979
|
+
validation_split = self.validation_split,
|
|
980
|
+
verbose=1)
|
|
981
|
+
|
|
982
|
+
self.plot_model_history(mode="classifier")
|
|
983
|
+
|
|
984
|
+
# Set current classification model as the best model
|
|
985
|
+
self.model_class = load_model(os.sep.join([self.model_folder,"classifier.h5"]))
|
|
986
|
+
self.model_class.load_weights(os.sep.join([self.model_folder,"classifier.h5"]))
|
|
987
|
+
|
|
988
|
+
self.dico = {"history_classifier": self.history_classifier, "execution_time_classifier": self.cb[-1].times}
|
|
989
|
+
|
|
990
|
+
if hasattr(self, 'x_test'):
|
|
991
|
+
|
|
992
|
+
predictions = self.model_class.predict(self.x_test).argmax(axis=1)
|
|
993
|
+
ground_truth = self.y_class_test.argmax(axis=1)
|
|
994
|
+
assert predictions.shape==ground_truth.shape,"Mismatch in shape between the predictions and the ground truth..."
|
|
995
|
+
|
|
996
|
+
title="Test data"
|
|
997
|
+
IoU_score = jaccard_score(ground_truth, predictions, average=None)
|
|
998
|
+
balanced_accuracy = balanced_accuracy_score(ground_truth, predictions)
|
|
999
|
+
precision = precision_score(ground_truth, predictions, average=None)
|
|
1000
|
+
recall = recall_score(ground_truth, predictions, average=None)
|
|
1001
|
+
|
|
1002
|
+
print(f"Test IoU score: {IoU_score}")
|
|
1003
|
+
print(f"Test Balanced accuracy score: {balanced_accuracy}")
|
|
1004
|
+
print(f'Test Precision: {precision}')
|
|
1005
|
+
print(f'Test Recall: {recall}')
|
|
1006
|
+
|
|
1007
|
+
# Confusion matrix on test set
|
|
1008
|
+
results = confusion_matrix(ground_truth,predictions)
|
|
1009
|
+
self.dico.update({"test_IoU": IoU_score, "test_balanced_accuracy": balanced_accuracy, "test_confusion": results, 'test_precision': precision, 'test_recall': recall})
|
|
1010
|
+
|
|
1011
|
+
try:
|
|
1012
|
+
plot_confusion_matrix(results, ["dead","alive","miscellaneous"], output_dir=self.model_folder+os.sep, title=title)
|
|
1013
|
+
except:
|
|
1014
|
+
pass
|
|
1015
|
+
print("Test set: ",classification_report(ground_truth,predictions))
|
|
1016
|
+
|
|
1017
|
+
if hasattr(self, 'x_val'):
|
|
1018
|
+
predictions = self.model_class.predict(self.x_val).argmax(axis=1)
|
|
1019
|
+
ground_truth = self.y_class_val.argmax(axis=1)
|
|
1020
|
+
assert ground_truth.shape==predictions.shape,"Mismatch in shape between the predictions and the ground truth..."
|
|
1021
|
+
title="Validation data"
|
|
1022
|
+
|
|
1023
|
+
# Validation scores
|
|
1024
|
+
IoU_score = jaccard_score(ground_truth, predictions, average=None)
|
|
1025
|
+
balanced_accuracy = balanced_accuracy_score(ground_truth, predictions)
|
|
1026
|
+
precision = precision_score(ground_truth, predictions, average=None)
|
|
1027
|
+
recall = recall_score(ground_truth, predictions, average=None)
|
|
1028
|
+
|
|
1029
|
+
print(f"Validation IoU score: {IoU_score}")
|
|
1030
|
+
print(f"Validation Balanced accuracy score: {balanced_accuracy}")
|
|
1031
|
+
print(f'Validation Precision: {precision}')
|
|
1032
|
+
print(f'Validation Recall: {recall}')
|
|
1033
|
+
|
|
1034
|
+
# Confusion matrix on validation set
|
|
1035
|
+
results = confusion_matrix(ground_truth,predictions)
|
|
1036
|
+
self.dico.update({"val_IoU": IoU_score, "val_balanced_accuracy": balanced_accuracy, "val_confusion": results, 'val_precision': precision, 'val_recall': recall})
|
|
1037
|
+
|
|
1038
|
+
try:
|
|
1039
|
+
plot_confusion_matrix(results, ["dead","alive","miscellaneous"], output_dir=self.model_folder+os.sep, title=title)
|
|
1040
|
+
except:
|
|
1041
|
+
pass
|
|
1042
|
+
print("Validation set: ",classification_report(ground_truth,predictions))
|
|
1043
|
+
|
|
1044
|
+
|
|
1045
|
+
def train_regressor(self):
|
|
1046
|
+
|
|
1047
|
+
"""
|
|
1048
|
+
Trains the regressor component of the model to estimate the time of interest for events in signals.
|
|
1049
|
+
|
|
1050
|
+
This method compiles the regressor model (if not pretrained or if recompilation is requested) and
|
|
1051
|
+
trains it on a subset of the prepared dataset containing signals with events. The training process
|
|
1052
|
+
includes validation and early stopping based on mean squared error to prevent overfitting.
|
|
1053
|
+
|
|
1054
|
+
Notes
|
|
1055
|
+
-----
|
|
1056
|
+
- The regressor model estimates the time at which an event of interest occurs within each signal.
|
|
1057
|
+
- Only signals predicted to have an event by the classifier model are used for regressor training.
|
|
1058
|
+
- Model performance metrics and training history are saved for analysis.
|
|
1059
|
+
"""
|
|
1060
|
+
|
|
1061
|
+
|
|
1062
|
+
# Compile model
|
|
1063
|
+
# if pretrained model
|
|
1064
|
+
if self.pretrained is not None:
|
|
1065
|
+
# if recompile
|
|
1066
|
+
if self.recompile_pretrained:
|
|
1067
|
+
print('Recompiling the pretrained regressor model... Warning, this action reinitializes all the weights; are you sure that this is what you intended?')
|
|
1068
|
+
self.model_reg.set_weights(clone_model(self.model_reg).get_weights())
|
|
1069
|
+
self.model_reg.compile(optimizer=Adam(learning_rate=self.learning_rate),
|
|
1070
|
+
loss=self.loss_reg,
|
|
1071
|
+
metrics=['mse','mae'])
|
|
1072
|
+
else:
|
|
1073
|
+
self.initial_model = clone_model(self.model_reg)
|
|
1074
|
+
self.model_reg.set_weights(self.initial_model.get_weights())
|
|
1075
|
+
self.model_reg.compile(optimizer=Adam(learning_rate=self.learning_rate),
|
|
1076
|
+
loss=self.loss_reg,
|
|
1077
|
+
metrics=['mse','mae'])
|
|
1078
|
+
self.model_reg.set_weights(self.initial_model.get_weights())
|
|
1079
|
+
else:
|
|
1080
|
+
print("Compiling the regressor...")
|
|
1081
|
+
self.model_reg.compile(optimizer=Adam(learning_rate=self.learning_rate),
|
|
1082
|
+
loss=self.loss_reg,
|
|
1083
|
+
metrics=['mse','mae'])
|
|
1084
|
+
|
|
1085
|
+
|
|
1086
|
+
self.gather_callbacks("regressor")
|
|
1087
|
+
|
|
1088
|
+
# Train on subset of data with event
|
|
1089
|
+
|
|
1090
|
+
subset = self.x_train[np.argmax(self.y_class_train,axis=1)==0] # keep only signals annotated with an event (class 0)
|
|
1095
|
+
|
|
1096
|
+
if hasattr(self, 'x_val'):
|
|
1097
|
+
self.history_regressor = self.model_reg.fit(x=self.x_train[np.argmax(self.y_class_train,axis=1)==0],
|
|
1098
|
+
y=self.y_time_train[np.argmax(self.y_class_train,axis=1)==0],
|
|
1099
|
+
batch_size=self.batch_size,
|
|
1100
|
+
epochs=self.epochs*2,
|
|
1101
|
+
validation_data=(self.x_val[np.argmax(self.y_class_val,axis=1)==0],self.y_time_val[np.argmax(self.y_class_val,axis=1)==0]),
|
|
1102
|
+
callbacks=self.cb,
|
|
1103
|
+
verbose=1)
|
|
1104
|
+
else:
|
|
1105
|
+
self.history_regressor = self.model_reg.fit(x=self.x_train[np.argmax(self.y_class_train,axis=1)==0],
|
|
1106
|
+
y=self.y_time_train[np.argmax(self.y_class_train,axis=1)==0],
|
|
1107
|
+
batch_size=self.batch_size,
|
|
1108
|
+
epochs=self.epochs*2,
|
|
1109
|
+
callbacks=self.cb,
|
|
1110
|
+
validation_split = self.validation_split,
|
|
1111
|
+
verbose=1)
|
|
1112
|
+
|
|
1113
|
+
self.plot_model_history(mode="regressor")
|
|
1114
|
+
self.dico.update({"history_regressor": self.history_regressor, "execution_time_regressor": self.cb[-1].times})
|
|
1115
|
+
|
|
1116
|
+
|
|
1117
|
+
# Evaluate best model
|
|
1118
|
+
self.model_reg = load_model(os.sep.join([self.model_folder,"regressor.h5"]))
|
|
1119
|
+
self.model_reg.load_weights(os.sep.join([self.model_folder,"regressor.h5"]))
|
|
1120
|
+
self.evaluate_regression_model()
|
|
1121
|
+
|
|
1122
|
+
try:
|
|
1123
|
+
np.save(os.sep.join([self.model_folder,"scores.npy"]), self.dico)
|
|
1124
|
+
except Exception as e:
|
|
1125
|
+
print(e)
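# Illustrative sketch (not package code): the event mask used throughout this method
# keeps only signals whose one-hot class is 0 ("event"); the regressor never sees
# event-free cells.
# >>> y_class = np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0]])  # one-hot classes
# >>> np.argmax(y_class, axis=1) == 0
# array([ True, False,  True])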
|
|
1126
|
+
|
|
1127
|
+
|
|
1128
|
+
def plot_model_history(self, mode="regressor"):
|
|
1129
|
+
|
|
1130
|
+
"""
|
|
1131
|
+
Generates and saves plots of the training history for the classifier or regressor model.
|
|
1132
|
+
|
|
1133
|
+
Parameters
|
|
1134
|
+
----------
|
|
1135
|
+
mode : str, optional
|
|
1136
|
+
Specifies which model's training history to plot. Options are "classifier" or "regressor". Default is "regressor".
|
|
1137
|
+
|
|
1138
|
+
Notes
|
|
1139
|
+
-----
|
|
1140
|
+
- Plots include loss and accuracy metrics over epochs for the classifier, and loss metrics for the regressor.
|
|
1141
|
+
- The plots are saved as image files in the model's output directory.
|
|
1142
|
+
"""
|
|
1143
|
+
|
|
1144
|
+
if mode=="regressor":
|
|
1145
|
+
try:
|
|
1146
|
+
plt.plot(self.history_regressor.history['loss'])
|
|
1147
|
+
plt.plot(self.history_regressor.history['val_loss'])
|
|
1148
|
+
plt.title('model loss')
|
|
1149
|
+
plt.ylabel('loss')
|
|
1150
|
+
plt.xlabel('epoch')
|
|
1151
|
+
plt.yscale('log')
|
|
1152
|
+
plt.legend(['train', 'val'], loc='upper left')
|
|
1153
|
+
plt.pause(3)
|
|
1154
|
+
plt.savefig(os.sep.join([self.model_folder,"regression_loss.png"]),bbox_inches="tight",dpi=300)
|
|
1155
|
+
plt.close()
|
|
1156
|
+
except Exception as e:
|
|
1157
|
+
print(f"Error {e}; could not generate plot...")
|
|
1158
|
+
elif mode=="classifier":
|
|
1159
|
+
try:
|
|
1160
|
+
plt.plot(self.history_classifier.history['precision'])
|
|
1161
|
+
plt.plot(self.history_classifier.history['val_precision'])
|
|
1162
|
+
plt.title('model precision')
|
|
1163
|
+
plt.ylabel('precision')
|
|
1164
|
+
plt.xlabel('epoch')
|
|
1165
|
+
plt.legend(['train', 'val'], loc='upper left')
|
|
1166
|
+
plt.pause(3)
|
|
1167
|
+
plt.savefig(os.sep.join([self.model_folder,"classification_loss.png"]),bbox_inches="tight",dpi=300)
|
|
1168
|
+
plt.close()
|
|
1169
|
+
except Exception as e:
|
|
1170
|
+
print(f"Error {e}; could not generate plot...")
|
|
1171
|
+
else:
|
|
1172
|
+
return None
|
|
1173
|
+
|
|
1174
|
+
def evaluate_regression_model(self):
|
|
1175
|
+
|
|
1176
|
+
"""
|
|
1177
|
+
Evaluates the performance of the trained regression model on test and validation datasets.
|
|
1178
|
+
|
|
1179
|
+
This method calculates and prints mean squared error and mean absolute error metrics for the regression model's
|
|
1180
|
+
predictions. It also generates regression plots comparing predicted times of interest to true values.
|
|
1181
|
+
|
|
1182
|
+
Notes
|
|
1183
|
+
-----
|
|
1184
|
+
- Evaluation is performed on both test and validation datasets, if available.
|
|
1185
|
+
- Regression plots and performance metrics are saved in the model's output directory.
|
|
1186
|
+
"""
|
|
1187
|
+
|
|
1188
|
+
|
|
1189
|
+
mse = MeanSquaredError()
|
|
1190
|
+
mae = MeanAbsoluteError()
|
|
1191
|
+
|
|
1192
|
+
if hasattr(self, 'x_test'):
|
|
1193
|
+
|
|
1194
|
+
print("Evaluate on test set...")
|
|
1195
|
+
predictions = self.model_reg.predict(self.x_test[np.argmax(self.y_class_test,axis=1)==0], batch_size=self.batch_size)[:,0]
|
|
1196
|
+
ground_truth = self.y_time_test[np.argmax(self.y_class_test,axis=1)==0]
|
|
1197
|
+
assert predictions.shape==ground_truth.shape,"Shape mismatch between predictions and ground truths..."
|
|
1198
|
+
|
|
1199
|
+
test_mse = mse(ground_truth, predictions).numpy()
|
|
1200
|
+
test_mae = mae(ground_truth, predictions).numpy()
|
|
1201
|
+
print(f"MSE on test set: {test_mse}...")
|
|
1202
|
+
print(f"MAE on test set: {test_mae}...")
|
|
1203
|
+
regression_plot(predictions, ground_truth, savepath=os.sep.join([self.model_folder,"test_regression.png"]))
|
|
1204
|
+
self.dico.update({"test_mse": test_mse, "test_mae": test_mae})
|
|
1205
|
+
|
|
1206
|
+
if hasattr(self, 'x_val'):
|
|
1207
|
+
# Validation set
|
|
1208
|
+
predictions = self.model_reg.predict(self.x_val[np.argmax(self.y_class_val,axis=1)==0], batch_size=self.batch_size)[:,0]
|
|
1209
|
+
ground_truth = self.y_time_val[np.argmax(self.y_class_val,axis=1)==0]
|
|
1210
|
+
assert predictions.shape==ground_truth.shape,"Shape mismatch between predictions and ground truths..."
|
|
1211
|
+
|
|
1212
|
+
val_mse = mse(ground_truth, predictions).numpy()
|
|
1213
|
+
val_mae = mae(ground_truth, predictions).numpy()
|
|
1214
|
+
|
|
1215
|
+
regression_plot(predictions, ground_truth, savepath=os.sep.join([self.model_folder,"validation_regression.png"]))
|
|
1216
|
+
print(f"MSE on validation set: {val_mse}...")
|
|
1217
|
+
print(f"MAE on validation set: {val_mae}...")
|
|
1218
|
+
|
|
1219
|
+
self.dico.update({"val_mse": val_mse, "val_mae": val_mae})
|
|
1220
|
+
|
|
1221
|
+
|
|
1222
|
+
def gather_callbacks(self, mode):
|
|
1223
|
+
|
|
1224
|
+
"""
|
|
1225
|
+
Prepares a list of Keras callbacks for model training based on the specified mode.
|
|
1226
|
+
|
|
1227
|
+
Parameters
|
|
1228
|
+
----------
|
|
1229
|
+
mode : str
|
|
1230
|
+
The training mode for which callbacks are being prepared. Options are "classifier" or "regressor".
|
|
1231
|
+
|
|
1232
|
+
Notes
|
|
1233
|
+
-----
|
|
1234
|
+
- Callbacks include learning rate reduction on plateau, early stopping, model checkpointing, and TensorBoard logging.
|
|
1235
|
+
- The list of callbacks is stored in the class attribute `cb` and used during model training.
|
|
1236
|
+
"""
|
|
1237
|
+
|
|
1238
|
+
self.cb = []
|
|
1239
|
+
|
|
1240
|
+
if mode=="classifier":
|
|
1241
|
+
|
|
1242
|
+
reduce_lr = ReduceLROnPlateau(monitor='val_precision', factor=0.5, patience=30,
|
|
1243
|
+
cooldown=10, min_lr=5e-10, min_delta=1.0E-10,
|
|
1244
|
+
verbose=1,mode="max")
|
|
1245
|
+
self.cb.append(reduce_lr)
|
|
1246
|
+
csv_logger = CSVLogger(os.sep.join([self.model_folder,'log_classifier.csv']), append=True, separator=';')
|
|
1247
|
+
self.cb.append(csv_logger)
|
|
1248
|
+
checkpoint_path = os.sep.join([self.model_folder,"classifier.h5"])
|
|
1249
|
+
cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_precision",mode="max",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
|
|
1250
|
+
self.cb.append(cp_callback)
|
|
1251
|
+
|
|
1252
|
+
callback_stop = EarlyStopping(monitor='val_precision', patience=100)
|
|
1253
|
+
self.cb.append(callback_stop)
|
|
1254
|
+
|
|
1255
|
+
elif mode=="regressor":
|
|
1256
|
+
|
|
1257
|
+
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=30,
|
|
1258
|
+
cooldown=10, min_lr=5e-10, min_delta=1.0E-10,
|
|
1259
|
+
verbose=1,mode="min")
|
|
1260
|
+
self.cb.append(reduce_lr)
|
|
1261
|
+
|
|
1262
|
+
csv_logger = CSVLogger(os.sep.join([self.model_folder,'log_regressor.csv']), append=True, separator=';')
|
|
1263
|
+
self.cb.append(csv_logger)
|
|
1264
|
+
|
|
1265
|
+
checkpoint_path = os.sep.join([self.model_folder,"regressor.h5"])
|
|
1266
|
+
cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_loss",mode="min",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
|
|
1267
|
+
self.cb.append(cp_callback)
|
|
1268
|
+
|
|
1269
|
+
callback_stop = EarlyStopping(monitor='val_loss', patience=200)
|
|
1270
|
+
self.cb.append(callback_stop)
|
|
1271
|
+
|
|
1272
|
+
log_dir = self.model_folder+os.sep
|
|
1273
|
+
cb_tb = TensorBoard(log_dir=log_dir, update_freq='batch')
|
|
1274
|
+
self.cb.append(cb_tb)
|
|
1275
|
+
|
|
1276
|
+
cb_time = TimeHistory()
|
|
1277
|
+
self.cb.append(cb_time)
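# Assumption: TimeHistory is a custom callback defined elsewhere in celldetective.
# A minimal sketch consistent with the `self.cb[-1].times` access in train_regressor:
# import time
# from tensorflow.keras.callbacks import Callback
# class TimeHistory(Callback):
#     def on_train_begin(self, logs=None):
#         self.times = []
#     def on_epoch_begin(self, epoch, logs=None):
#         self._t0 = time.time()
#     def on_epoch_end(self, epoch, logs=None):
#         self.times.append(time.time() - self._t0)  # seconds per epoch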
|
|
1278
|
+
|
|
1279
|
+
|
|
1280
|
+
|
|
1281
|
+
def generate_sets(self):
|
|
1282
|
+
|
|
1283
|
+
"""
|
|
1284
|
+
Generates and preprocesses training, validation, and test sets from loaded annotations.
|
|
1285
|
+
|
|
1286
|
+
This method loads signal data from annotation files, normalizes and interpolates the signals, and splits
|
|
1287
|
+
the dataset into training, validation, and test sets according to specified proportions.
|
|
1288
|
+
|
|
1289
|
+
Notes
|
|
1290
|
+
-----
|
|
1291
|
+
- Signal annotations are expected to be stored in .npy format and contain required channels and event information.
|
|
1292
|
+
- The method applies specified normalization and interpolation options to prepare the signals for model training.
|
|
1293
|
+
"""
|
|
1294
|
+
|
|
1295
|
+
|
|
1296
|
+
self.x_set = []
|
|
1297
|
+
self.y_time_set = []
|
|
1298
|
+
self.y_class_set = []
|
|
1299
|
+
|
|
1300
|
+
for s in self.list_of_sets:
|
|
1301
|
+
self.load_and_normalize(s)
|
|
1302
|
+
|
|
1303
|
+
self.x_set = np.array(self.x_set).astype(np.float32)
|
|
1304
|
+
self.x_set = self.interpolate_signals(self.x_set)
|
|
1305
|
+
|
|
1306
|
+
self.y_time_set = np.array(self.y_time_set).astype(np.float32)
|
|
1307
|
+
self.y_class_set = np.array(self.y_class_set).astype(np.float32)
|
|
1308
|
+
|
|
1309
|
+
class_test = np.isin(self.y_class_set, [0,1,2])
|
|
1310
|
+
self.x_set = self.x_set[class_test]
|
|
1311
|
+
self.y_time_set = self.y_time_set[class_test]
|
|
1312
|
+
self.y_class_set = self.y_class_set[class_test]
|
|
1313
|
+
|
|
1314
|
+
# Compute class weights and one-hot encode
|
|
1315
|
+
self.class_weights = compute_weights(self.y_class_set)
|
|
1316
|
+
self.nbr_classes = len(np.unique(self.y_class_set))
|
|
1317
|
+
self.y_class_set = to_categorical(self.y_class_set)
|
|
1318
|
+
|
|
1319
|
+
ds = train_test_split(self.x_set,
|
|
1320
|
+
self.y_time_set,
|
|
1321
|
+
self.y_class_set,
|
|
1322
|
+
validation_size=self.validation_split,
|
|
1323
|
+
test_size=self.test_split)
|
|
1324
|
+
|
|
1325
|
+
self.x_train = ds["x_train"]
|
|
1326
|
+
self.x_val = ds["x_val"]
|
|
1327
|
+
self.y_time_train = ds["y1_train"].astype(np.float32)
|
|
1328
|
+
print(f"Time of interest range in the training set: [{np.amin(self.y_time_train)}, {np.amax(self.y_time_train)}]")
|
|
1329
|
+
self.y_time_val = ds["y1_val"].astype(np.float32)
|
|
1330
|
+
self.y_class_train = ds["y2_train"]
|
|
1331
|
+
self.y_class_val = ds["y2_val"]
|
|
1332
|
+
|
|
1333
|
+
if self.test_split>0:
|
|
1334
|
+
self.x_test = ds["x_test"]
|
|
1335
|
+
self.y_time_test = ds["y1_test"].astype(np.float32)
|
|
1336
|
+
self.y_class_test = ds["y2_test"]
|
|
1337
|
+
|
|
1338
|
+
if self.augment:
|
|
1339
|
+
self.augment_training_set()
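# Assumption: `train_test_split` here is a celldetective helper (not sklearn's); a
# minimal sketch of the joint split and the dictionary layout consumed above:
# def train_test_split_sketch(x, y1, y2, validation_size=0.2, test_size=0.1):
#     idx = np.random.permutation(len(x))
#     n_test, n_val = int(test_size * len(x)), int(validation_size * len(x))
#     test, val, train = idx[:n_test], idx[n_test:n_test + n_val], idx[n_test + n_val:]
#     out = {}
#     for name, arr in zip(["x", "y1", "y2"], [x, y1, y2]):
#         out.update({f"{name}_train": arr[train], f"{name}_val": arr[val], f"{name}_test": arr[test]})
#     return out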
|
|
1340
|
+
|
|
1341
|
+
def augment_training_set(self, time_shift=True):
|
|
1342
|
+
|
|
1343
|
+
"""
|
|
1344
|
+
Augments the training dataset with artificially generated data to increase model robustness.
|
|
1345
|
+
|
|
1346
|
+
Parameters
|
|
1347
|
+
----------
|
|
1348
|
+
time_shift : bool, optional
|
|
1349
|
+
Specifies whether to include time-shifted versions of signals in the augmented dataset. Default is True.
|
|
1350
|
+
|
|
1351
|
+
Notes
|
|
1352
|
+
-----
|
|
1353
|
+
- Augmentation strategies include random time shifting and signal modifications to simulate variations in real data.
|
|
1354
|
+
- The augmented dataset is used for training the classifier and regressor models to improve generalization.
|
|
1355
|
+
"""
|
|
1356
|
+
|
|
1357
|
+
|
|
1358
|
+
nbr_augment = self.augmentation_factor*len(self.x_train)
|
|
1359
|
+
randomize = np.arange(len(self.x_train))
|
|
1360
|
+
indices = random.choices(randomize,k=nbr_augment)
|
|
1361
|
+
|
|
1362
|
+
x_train_aug = []
|
|
1363
|
+
y_time_train_aug = []
|
|
1364
|
+
y_class_train_aug = []
|
|
1365
|
+
|
|
1366
|
+
for k in indices:
|
|
1367
|
+
aug = augmenter(self.x_train[k],
|
|
1368
|
+
self.y_time_train[k],
|
|
1369
|
+
self.y_class_train[k],
|
|
1370
|
+
self.model_signal_length,
|
|
1371
|
+
time_shift=time_shift)
|
|
1372
|
+
x_train_aug.append(aug[0])
|
|
1373
|
+
y_time_train_aug.append(aug[1])
|
|
1374
|
+
y_class_train_aug.append(aug[2])
|
|
1375
|
+
|
|
1376
|
+
# Save augmented training set
|
|
1377
|
+
self.x_train = np.array(x_train_aug)
|
|
1378
|
+
self.y_time_train = np.array(y_time_train_aug)
|
|
1379
|
+
self.y_class_train = np.array(y_class_train_aug)
|
|
1380
|
+
|
|
1381
|
+
|
|
1382
|
+
|
|
1383
|
+
def load_and_normalize(self, subset):
|
|
1384
|
+
|
|
1385
|
+
"""
|
|
1386
|
+
Loads a subset of signal data from an annotation file and applies normalization.
|
|
1387
|
+
|
|
1388
|
+
Parameters
|
|
1389
|
+
----------
|
|
1390
|
+
subset : str
|
|
1391
|
+
The file path to the .npy annotation file containing signal data for a subset of observations.
|
|
1392
|
+
|
|
1393
|
+
Notes
|
|
1394
|
+
-----
|
|
1395
|
+
- The method extracts required signal channels from the annotation file and applies specified normalization
|
|
1396
|
+
and interpolation steps.
|
|
1397
|
+
- Preprocessed signals are added to the global dataset for model training.
|
|
1398
|
+
"""
|
|
1399
|
+
|
|
1400
|
+
set_k = np.load(subset,allow_pickle=True)
|
|
1401
|
+
# Map each requested channel option onto the signals actually available in the annotation
|
|
1402
|
+
|
|
1403
|
+
required_signals = self.channel_option
|
|
1404
|
+
available_signals = list(set_k[0].keys())
|
|
1405
|
+
|
|
1406
|
+
selected_signals = []
|
|
1407
|
+
for s in required_signals:
|
|
1408
|
+
pattern_test = [s in a for a in available_signals]
|
|
1409
|
+
if np.any(pattern_test):
|
|
1410
|
+
valid_columns = np.array(available_signals)[np.array(pattern_test)]
|
|
1411
|
+
if len(valid_columns)==1:
|
|
1412
|
+
selected_signals.append(valid_columns[0])
|
|
1413
|
+
else:
|
|
1414
|
+
print(f'Found several candidate signals: {valid_columns}')
|
|
1415
|
+
for vc in natsorted(valid_columns):
|
|
1416
|
+
if 'circle' in vc:
|
|
1417
|
+
selected_signals.append(vc)
|
|
1418
|
+
break
|
|
1419
|
+
else:
|
|
1420
|
+
selected_signals.append(valid_columns[0])
|
|
1421
|
+
else:
|
|
1422
|
+
print(f"Required signal {s} could not be matched to any column of the annotation {subset}; skipping this set...")
return None
|
|
1423
|
+
|
|
1424
|
+
|
|
1425
|
+
key_to_check = selected_signals[0] #self.channel_option[0]
|
|
1426
|
+
signal_lengths = [len(l[key_to_check]) for l in set_k]
|
|
1427
|
+
max_length = np.amax(signal_lengths)
|
|
1428
|
+
|
|
1429
|
+
fluo = np.zeros((len(set_k),max_length,self.n_channels))
|
|
1430
|
+
classes = np.zeros(len(set_k))
|
|
1431
|
+
times_of_interest = np.zeros(len(set_k))
|
|
1432
|
+
|
|
1433
|
+
for k in range(len(set_k)):
|
|
1434
|
+
|
|
1435
|
+
for i in range(self.n_channels):
|
|
1436
|
+
try:
|
|
1437
|
+
# take into account timeline for accurate time regression
|
|
1438
|
+
timeline = set_k[k]['FRAME'].astype(int)
|
|
1439
|
+
fluo[k,timeline,i] = set_k[k][selected_signals[i]]
|
|
1440
|
+
except Exception:
|
|
1441
|
+
print(f"Attribute {selected_signals[i]} matched to {self.channel_option[i]} not found in annotation...")
|
|
1442
|
+
pass
|
|
1443
|
+
|
|
1444
|
+
classes[k] = set_k[k]["class"]
|
|
1445
|
+
times_of_interest[k] = set_k[k]["time_of_interest"]
|
|
1446
|
+
|
|
1447
|
+
# Correct absurd times of interest
|
|
1448
|
+
times_of_interest[np.nonzero(classes)] = -1
|
|
1449
|
+
times_of_interest[(times_of_interest<=0.0)] = -1
|
|
1450
|
+
|
|
1451
|
+
# Attempt per-set normalization
|
|
1452
|
+
fluo = pad_to_model_length(fluo, self.model_signal_length)
|
|
1453
|
+
if self.normalize:
|
|
1454
|
+
fluo = normalize_signal_set(fluo, self.channel_option, normalization_percentile=self.normalization_percentile,
|
|
1455
|
+
normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
|
|
1456
|
+
)
|
|
1457
|
+
|
|
1458
|
+
# Trivial normalization for time of interest
|
|
1459
|
+
times_of_interest /= self.model_signal_length
|
|
1460
|
+
|
|
1461
|
+
# Add to global dataset
|
|
1462
|
+
self.x_set.extend(fluo)
|
|
1463
|
+
self.y_time_set.extend(times_of_interest)
|
|
1464
|
+
self.y_class_set.extend(classes)
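# Assumption about the annotation layout, inferred from the accesses above: each .npy
# file holds a list of per-cell dictionaries with one entry per measured signal plus
# 'FRAME', 'class' and 'time_of_interest' keys, e.g. (the signal name is hypothetical):
# >>> cell = {"FRAME": np.arange(60), "dead_nuclei_channel_mean": np.random.rand(60),
# ...         "class": 0, "time_of_interest": 23.0}
# >>> np.save("annotation.npy", [cell], allow_pickle=True)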
|
|
1465
|
+
|
|
1466
|
+
def _interpret_normalization_parameters(n_channels, normalization_percentile, normalization_values, normalization_clip):
|
|
1467
|
+
|
|
1468
|
+
"""
|
|
1469
|
+
Interprets and validates normalization parameters for each channel.
|
|
1470
|
+
|
|
1471
|
+
This function ensures the normalization parameters are correctly formatted and expanded to match
|
|
1472
|
+
the number of channels in the dataset. It provides default values and expands single values into
|
|
1473
|
+
lists to match the number of channels if necessary.
|
|
1474
|
+
|
|
1475
|
+
Parameters
|
|
1476
|
+
----------
|
|
1477
|
+
n_channels : int
|
|
1478
|
+
The number of channels in the dataset.
|
|
1479
|
+
normalization_percentile : list of bool or bool, optional
|
|
1480
|
+
Specifies whether to normalize each channel based on percentile values. If a single bool is provided,
|
|
1481
|
+
it is expanded to a list matching the number of channels. Default is True for all channels.
|
|
1482
|
+
normalization_values : list of lists or list, optional
|
|
1483
|
+
Specifies the percentile values [lower, upper] for normalization for each channel. If a single pair
|
|
1484
|
+
is provided, it is expanded to match the number of channels. Default is [[0.1, 99.9]] for all channels.
|
|
1485
|
+
normalization_clip : list of bool or bool, optional
|
|
1486
|
+
Specifies whether to clip the normalized values for each channel to the range [0, 1]. If a single bool
|
|
1487
|
+
is provided, it is expanded to a list matching the number of channels. Default is False for all channels.
|
|
1488
|
+
|
|
1489
|
+
Returns
|
|
1490
|
+
-------
|
|
1491
|
+
tuple
|
|
1492
|
+
A tuple containing three lists: `normalization_percentile`, `normalization_values`, and `normalization_clip`,
|
|
1493
|
+
each of length `n_channels`, representing the interpreted and validated normalization parameters for each channel.
|
|
1494
|
+
|
|
1495
|
+
Raises
|
|
1496
|
+
------
|
|
1497
|
+
AssertionError
|
|
1498
|
+
If the lengths of the provided lists do not match `n_channels`.
|
|
1499
|
+
|
|
1500
|
+
Examples
|
|
1501
|
+
--------
|
|
1502
|
+
>>> n_channels = 2
|
|
1503
|
+
>>> normalization_percentile = True
|
|
1504
|
+
>>> normalization_values = [0.1, 99.9]
|
|
1505
|
+
>>> normalization_clip = False
|
|
1506
|
+
>>> params = _interpret_normalization_parameters(n_channels, normalization_percentile, normalization_values, normalization_clip)
|
|
1507
|
+
>>> print(params)
|
|
1508
|
+
# ([True, True], [[0.1, 99.9], [0.1, 99.9]], [False, False])
|
|
1509
|
+
"""
|
|
1510
|
+
|
|
1511
|
+
|
|
1512
|
+
if normalization_percentile is None:
|
|
1513
|
+
normalization_percentile = [True]*n_channels
|
|
1514
|
+
if normalization_values is None:
|
|
1515
|
+
normalization_values = [[0.1,99.9]]*n_channels
|
|
1516
|
+
if normalization_clip is None:
|
|
1517
|
+
normalization_clip = [False]*n_channels
|
|
1518
|
+
|
|
1519
|
+
if isinstance(normalization_percentile, bool):
|
|
1520
|
+
normalization_percentile = [normalization_percentile]*n_channels
|
|
1521
|
+
if isinstance(normalization_clip, bool):
|
|
1522
|
+
normalization_clip = [normalization_clip]*n_channels
|
|
1523
|
+
if len(normalization_values)==2 and not isinstance(normalization_values[0], list):
|
|
1524
|
+
normalization_values = [normalization_values]*n_channels
|
|
1525
|
+
|
|
1526
|
+
assert len(normalization_values)==n_channels
|
|
1527
|
+
assert len(normalization_clip)==n_channels
|
|
1528
|
+
assert len(normalization_percentile)==n_channels
|
|
1529
|
+
|
|
1530
|
+
return normalization_percentile, normalization_values, normalization_clip
|
|
1531
|
+
|
|
1532
|
+
|
|
1533
|
+
def normalize_signal_set(signal_set, channel_option, percentile_alive=[0.01,99.99], percentile_dead=[0.5,99.999], percentile_generic=[0.01,99.99], normalization_percentile=None, normalization_values=None, normalization_clip=None):
|
|
1534
|
+
|
|
1535
|
+
"""
|
|
1536
|
+
Normalizes a set of single-cell signals across specified channels using given percentile values or specific normalization parameters.
|
|
1537
|
+
|
|
1538
|
+
This function applies normalization to each channel in the signal set based on the provided normalization parameters,
|
|
1539
|
+
which can be defined globally or per channel. The normalization process aims to scale the signal values to a standard
|
|
1540
|
+
range, improving the consistency and comparability of signal measurements across samples.
|
|
1541
|
+
|
|
1542
|
+
Parameters
|
|
1543
|
+
----------
|
|
1544
|
+
signal_set : ndarray
|
|
1545
|
+
A 3D numpy array representing the set of signals to be normalized, with dimensions corresponding to (samples, time points, channels).
|
|
1546
|
+
channel_option : list of str
|
|
1547
|
+
A list specifying the channels included in the signal set and their corresponding normalization strategy based on channel names.
|
|
1548
|
+
percentile_alive : list of float, optional
|
|
1549
|
+
The percentile values [lower, upper] used for normalization of signals from channels labeled as 'alive'. Default is [0.01, 99.99].
|
|
1550
|
+
percentile_dead : list of float, optional
|
|
1551
|
+
The percentile values [lower, upper] used for normalization of signals from channels labeled as 'dead'. Default is [0.5, 99.999].
|
|
1552
|
+
percentile_generic : list of float, optional
|
|
1553
|
+
The percentile values [lower, upper] used for normalization of signals from channels not specifically labeled as 'alive' or 'dead'.
|
|
1554
|
+
Default is [0.01, 99.99].
|
|
1555
|
+
normalization_percentile : list of bool or None, optional
|
|
1556
|
+
Specifies whether to normalize each channel based on percentile values. If None, the default percentile strategy is applied
|
|
1557
|
+
based on `channel_option`. If a list, it should match the length of `channel_option`.
|
|
1558
|
+
normalization_values : list of lists or None, optional
|
|
1559
|
+
Specifies the percentile values [lower, upper] or fixed values [min, max] for normalization for each channel. Overrides
|
|
1560
|
+
`percentile_alive`, `percentile_dead`, and `percentile_generic` if provided.
|
|
1561
|
+
normalization_clip : list of bool or None, optional
|
|
1562
|
+
Specifies whether to clip the normalized values for each channel to the range [0, 1]. If None, clipping is disabled by default.
|
|
1563
|
+
|
|
1564
|
+
Returns
|
|
1565
|
+
-------
|
|
1566
|
+
ndarray
|
|
1567
|
+
The normalized signal set with the same shape as the input `signal_set`.
|
|
1568
|
+
|
|
1569
|
+
Notes
|
|
1570
|
+
-----
|
|
1571
|
+
- The function supports different normalization strategies for 'alive', 'dead', and generic signal channels, which can be customized
|
|
1572
|
+
via `channel_option` and the percentile parameters.
|
|
1573
|
+
- Normalization parameters (`normalization_percentile`, `normalization_values`, `normalization_clip`) are interpreted and validated
|
|
1574
|
+
by calling `_interpret_normalization_parameters`.
|
|
1575
|
+
|
|
1576
|
+
Examples
|
|
1577
|
+
--------
|
|
1578
|
+
>>> signal_set = np.random.rand(100, 128, 2) # 100 samples, 128 time points, 2 channels
|
|
1579
|
+
>>> channel_option = ['alive', 'dead']
|
|
1580
|
+
>>> normalized_signals = normalize_signal_set(signal_set, channel_option)
|
|
1581
|
+
# Normalizes the signal set based on the default percentile values for 'alive' and 'dead' channels.
|
|
1582
|
+
"""
|
|
1583
|
+
|
|
1584
|
+
# Check normalization params are ok
|
|
1585
|
+
n_channels = len(channel_option)
|
|
1586
|
+
normalization_percentile, normalization_values, normalization_clip = _interpret_normalization_parameters(n_channels,
|
|
1587
|
+
normalization_percentile,
|
|
1588
|
+
normalization_values,
|
|
1589
|
+
normalization_clip)
|
|
1590
|
+
for k,channel in enumerate(channel_option):
|
|
1591
|
+
|
|
1592
|
+
zero_values = []
|
|
1593
|
+
for i in range(len(signal_set)):
|
|
1594
|
+
zeros_loc = np.where(signal_set[i,:,k]==0)
|
|
1595
|
+
zero_values.append(zeros_loc)
|
|
1596
|
+
|
|
1597
|
+
values = signal_set[:,:,k]
|
|
1598
|
+
|
|
1599
|
+
if normalization_percentile[k]:
|
|
1600
|
+
min_val = np.nanpercentile(values[values!=0.], normalization_values[k][0])
|
|
1601
|
+
max_val = np.nanpercentile(values[values!=0.], normalization_values[k][1])
|
|
1602
|
+
else:
|
|
1603
|
+
min_val = normalization_values[k][0]
|
|
1604
|
+
max_val = normalization_values[k][1]
|
|
1605
|
+
|
|
1606
|
+
signal_set[:,:,k] -= min_val
|
|
1607
|
+
signal_set[:,:,k] /= (max_val - min_val)
|
|
1608
|
+
|
|
1609
|
+
if normalization_clip[k]:
|
|
1610
|
+
to_clip_low = []
|
|
1611
|
+
to_clip_high = []
|
|
1612
|
+
for i in range(len(signal_set)):
|
|
1613
|
+
clip_low_loc = np.where(signal_set[i,:,k]<=0)
|
|
1614
|
+
clip_high_loc = np.where(signal_set[i,:,k]>=1.0)
|
|
1615
|
+
to_clip_low.append(clip_low_loc)
|
|
1616
|
+
to_clip_high.append(clip_high_loc)
|
|
1617
|
+
|
|
1618
|
+
for i,z in enumerate(to_clip_low):
|
|
1619
|
+
signal_set[i,z,k] = 0.
|
|
1620
|
+
for i,z in enumerate(to_clip_high):
|
|
1621
|
+
signal_set[i,z,k] = 1.
|
|
1622
|
+
|
|
1623
|
+
for i,z in enumerate(zero_values):
|
|
1624
|
+
signal_set[i,z,k] = 0.
|
|
1625
|
+
|
|
1626
|
+
return signal_set
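# The per-channel rescaling above is a standard min-max mapping,
#     s_norm = (s - min_val) / (max_val - min_val),
# where min_val/max_val are either percentiles of the non-zero samples or fixed bounds,
# depending on `normalization_percentile`. Zero samples (padding) are recorded
# beforehand and restored afterwards so that padded frames stay exactly at zero.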
|
|
1627
|
+
|
|
1628
|
+
def pad_to_model_length(signal_set, model_signal_length):
|
|
1629
|
+
|
|
1630
|
+
"""
|
|
1631
|
+
|
|
1632
|
+
Pad the signal set to match the specified model signal length.
|
|
1633
|
+
|
|
1634
|
+
Parameters
|
|
1635
|
+
----------
|
|
1636
|
+
signal_set : array-like
|
|
1637
|
+
The signal set to be padded.
|
|
1638
|
+
model_signal_length : int
|
|
1639
|
+
The desired length of the model signal.
|
|
1640
|
+
|
|
1641
|
+
Returns
|
|
1642
|
+
-------
|
|
1643
|
+
array-like
|
|
1644
|
+
The padded signal set.
|
|
1645
|
+
|
|
1646
|
+
Notes
|
|
1647
|
+
-----
|
|
1648
|
+
This function pads the signal set with zeros along the second dimension (axis 1) to match the specified model signal
|
|
1649
|
+
length. The padding is applied to the end of the signals, increasing their length.
|
|
1650
|
+
|
|
1651
|
+
Examples
|
|
1652
|
+
--------
|
|
1653
|
+
>>> signal_set = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
|
|
1654
|
+
>>> padded_signals = pad_to_model_length(signal_set, 5)
|
|
1655
|
+
|
|
1656
|
+
"""
|
|
1657
|
+
|
|
1658
|
+
padded = np.pad(signal_set, [(0,0),(0,model_signal_length - signal_set.shape[1]),(0,0)])
|
|
1659
|
+
|
|
1660
|
+
return padded
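# Continuing the docstring example above: padding extends axis 1 (time) with zeros.
# >>> signal_set.shape, padded_signals.shape
# ((2, 2, 2), (2, 5, 2))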
|
|
1661
|
+
|
|
1662
|
+
def random_intensity_change(signal):
|
|
1663
|
+
|
|
1664
|
+
"""
|
|
1665
|
+
|
|
1666
|
+
Randomly change the intensity of a signal.
|
|
1667
|
+
|
|
1668
|
+
Parameters
|
|
1669
|
+
----------
|
|
1670
|
+
signal : array-like
|
|
1671
|
+
The input signal to be modified.
|
|
1672
|
+
|
|
1673
|
+
Returns
|
|
1674
|
+
-------
|
|
1675
|
+
array-like
|
|
1676
|
+
The modified signal with randomly changed intensity.
|
|
1677
|
+
|
|
1678
|
+
Notes
|
|
1679
|
+
-----
|
|
1680
|
+
This function applies a random intensity change to each channel of the input signal. The intensity change is
|
|
1681
|
+
performed by multiplying each channel with a random value drawn from a uniform distribution between 0.7 and 1.0.
|
|
1682
|
+
|
|
1683
|
+
Examples
|
|
1684
|
+
--------
|
|
1685
|
+
>>> signal = np.array([[1, 2, 3], [4, 5, 6]])
|
|
1686
|
+
>>> modified_signal = random_intensity_change(signal)
|
|
1687
|
+
|
|
1688
|
+
"""
|
|
1689
|
+
|
|
1690
|
+
for k in range(signal.shape[1]):
|
|
1691
|
+
signal[:,k] = signal[:,k]*np.random.uniform(0.7,1.)
|
|
1692
|
+
|
|
1693
|
+
return signal
|
|
1694
|
+
|
|
1695
|
+
def gauss_noise(signal):
|
|
1696
|
+
|
|
1697
|
+
"""
|
|
1698
|
+
|
|
1699
|
+
Add Gaussian noise to a signal.
|
|
1700
|
+
|
|
1701
|
+
Parameters
|
|
1702
|
+
----------
|
|
1703
|
+
signal : array-like
|
|
1704
|
+
The input signal to which noise will be added.
|
|
1705
|
+
|
|
1706
|
+
Returns
|
|
1707
|
+
-------
|
|
1708
|
+
array-like
|
|
1709
|
+
The signal with Gaussian noise added.
|
|
1710
|
+
|
|
1711
|
+
Notes
|
|
1712
|
+
-----
|
|
1713
|
+
This function adds multiplicative Gaussian noise to the input signal. A noise amplitude is first drawn
uniformly from [0, 0.08]; values from a standard normal distribution are then scaled by this amplitude
and by the signal itself before being added to the original signal, so the noise level tracks the local signal intensity.
|
|
1716
|
+
|
|
1717
|
+
Examples
|
|
1718
|
+
--------
|
|
1719
|
+
>>> signal = np.array([1, 2, 3, 4, 5])
|
|
1720
|
+
>>> noisy_signal = gauss_noise(signal)
|
|
1721
|
+
|
|
1722
|
+
"""
|
|
1723
|
+
|
|
1724
|
+
sig = 0.08*np.random.uniform(0,1)
|
|
1725
|
+
signal = signal + sig*np.random.normal(0,1,signal.shape)*signal
|
|
1726
|
+
return signal
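# Illustrative check (not package code): the noise is multiplicative, so zero-valued
# (padded) samples are left untouched.
# >>> np.all(gauss_noise(np.zeros(5)) == 0.)
# True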
|
|
1727
|
+
|
|
1728
|
+
def random_time_shift(signal, time_of_interest, cclass, model_signal_length):
|
|
1729
|
+
|
|
1730
|
+
"""
|
|
1731
|
+
|
|
1732
|
+
Randomly shift the signals to another time.
|
|
1733
|
+
|
|
1734
|
+
Parameters
|
|
1735
|
+
----------
|
|
1736
|
+
signal : array-like
|
|
1737
|
+
The signal to be shifted.
|
|
1738
|
+
time_of_interest : int or float
|
|
1739
|
+
The original time of interest for the signal. Use -1 if not applicable.
cclass : ndarray
The one-hot encoded class of the cell, which may be updated if the shift moves the event before the start of the signal.
|
|
1740
|
+
model_signal_length : int
|
|
1741
|
+
The length of the model signal.
|
|
1742
|
+
|
|
1743
|
+
Returns
|
|
1744
|
+
-------
|
|
1745
|
+
array-like
|
|
1746
|
+
The shifted fluorescence signal.
|
|
1747
|
+
int or float
|
|
1748
|
+
The new time of interest if available; otherwise, the original time of interest.
ndarray
The (possibly updated) one-hot encoded class of the cell.
|
|
1749
|
+
|
|
1750
|
+
Notes
|
|
1751
|
+
-----
|
|
1752
|
+
This function randomly selects a target time within the specified model signal length and shifts the
|
|
1753
|
+
signal accordingly. The shift is performed along the first dimension (axis 0) of the signal. The function uses
|
|
1754
|
+
nearest-neighbor interpolation for shifting.
|
|
1755
|
+
|
|
1756
|
+
If the original time of interest (`time_of_interest`) is provided (not equal to -1), the function returns the
|
|
1757
|
+
shifted signal along with the new time of interest. Otherwise, it returns the shifted signal along with the
|
|
1758
|
+
original time of interest.
|
|
1759
|
+
|
|
1760
|
+
The `max_time` is set to the `model_signal_length` unless the original time of interest is provided. In that case,
|
|
1761
|
+
`max_time` is set to `model_signal_length - 3` to prevent shifting too close to the edge.
|
|
1762
|
+
|
|
1763
|
+
Examples
|
|
1764
|
+
--------
|
|
1765
|
+
>>> signal = np.array([[1, 2, 3], [4, 5, 6]])
|
|
1766
|
+
>>> shifted_signal, new_time, new_class = random_time_shift(signal, 1, np.array([1., 0., 0.]), 5)
|
|
1767
|
+
|
|
1768
|
+
"""
|
|
1769
|
+
|
|
1770
|
+
max_time = model_signal_length
|
|
1771
|
+
return_target = False
|
|
1772
|
+
if time_of_interest != -1:
|
|
1773
|
+
return_target = True
|
|
1774
|
+
max_time = model_signal_length - 3 # to prevent approaching too much to the edge
|
|
1775
|
+
|
|
1776
|
+
times = np.linspace(-max_time,max_time,2000) # symmetrize to create left-censored events
|
|
1777
|
+
target_time = np.random.choice(times)
|
|
1778
|
+
|
|
1779
|
+
delta_t = target_time - time_of_interest
|
|
1780
|
+
signal = shift(signal, [delta_t,0], order=0, mode="nearest")
|
|
1781
|
+
|
|
1782
|
+
if target_time<=0 and np.argmax(cclass)==0:
|
|
1783
|
+
target_time = -1
|
|
1784
|
+
cclass = np.array([0.,0.,1.]).astype(np.float32)
|
|
1785
|
+
|
|
1786
|
+
if return_target:
|
|
1787
|
+
return signal,target_time, cclass
|
|
1788
|
+
else:
|
|
1789
|
+
return signal, time_of_interest, cclass
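# Sketch of the left-censoring logic above: when the drawn target time falls at or
# before frame 0 while the class was "event" (index 0), the cell is relabeled
# "miscellaneous" ([0, 0, 1]) and its event time is reset to -1.
# >>> sig = np.random.rand(128, 2)
# >>> _, t, c = random_time_shift(sig, 40.0, np.array([1., 0., 0.]), 128)
# >>> (t == -1) == (np.argmax(c) == 2)  # the relabeling and the time reset go together
# True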
|
|
1790
|
+
|
|
1791
|
+
def augmenter(signal, time_of_interest, cclass, model_signal_length, time_shift=True, probability=0.8):
|
|
1792
|
+
|
|
1793
|
+
"""
|
|
1794
|
+
Randomly augments single-cell signals to simulate variations in noise, intensity ratios, and event times.
|
|
1795
|
+
|
|
1796
|
+
This function applies random transformations to the input signal, including time shifts, intensity changes,
|
|
1797
|
+
and the addition of Gaussian noise, with the aim of increasing the diversity of the dataset for training robust models.
|
|
1798
|
+
|
|
1799
|
+
Parameters
|
|
1800
|
+
----------
|
|
1801
|
+
signal : ndarray
|
|
1802
|
+
A 2D numpy array of shape (time points, channels) representing the signals of a single cell to be augmented.
|
|
1803
|
+
time_of_interest : float
|
|
1804
|
+
The normalized time of interest (event time) for the signal, scaled to the range [0, 1].
|
|
1805
|
+
cclass : ndarray
|
|
1806
|
+
A one-hot encoded numpy array representing the class of the cell associated with the signal.
|
|
1807
|
+
model_signal_length : int
|
|
1808
|
+
The length of the signal expected by the model, used for scaling the time of interest.
|
|
1809
|
+
time_shift : bool, optional
|
|
1810
|
+
Specifies whether to apply random time shifts to the signal. Default is True.
|
|
1811
|
+
probability : float, optional
|
|
1812
|
+
The probability with which to apply the augmentation transformations. Default is 0.8.
|
|
1813
|
+
|
|
1814
|
+
Returns
|
|
1815
|
+
-------
|
|
1816
|
+
tuple
|
|
1817
|
+
A tuple containing the augmented signal, the normalized time of interest, and the class of the cell.
|
|
1818
|
+
|
|
1819
|
+
Raises
|
|
1820
|
+
------
|
|
1821
|
+
AssertionError
|
|
1822
|
+
If the time of interest is provided but invalid for time shifting.
|
|
1823
|
+
|
|
1824
|
+
Notes
|
|
1825
|
+
-----
|
|
1826
|
+
- Time shifting is not applied to cells of the class labeled as 'miscellaneous' (typically encoded as the class '2').
|
|
1827
|
+
- The time of interest is rescaled based on the model's expected signal length before and after any time shift.
|
|
1828
|
+
- Augmentation is applied with the specified probability to simulate realistic variability while maintaining
|
|
1829
|
+
some original signals in the dataset.
|
|
1830
|
+
|
|
1831
|
+
"""
|
|
1832
|
+
|
|
1833
|
+
if np.amax(time_of_interest)<=1.0:
|
|
1834
|
+
time_of_interest *= model_signal_length
|
|
1835
|
+
|
|
1836
|
+
# augment with a certain probability
|
|
1837
|
+
r = random.random()
|
|
1838
|
+
if r<= probability:
|
|
1839
|
+
|
|
1840
|
+
if time_shift:
|
|
1841
|
+
# do not time shift miscellaneous cells
|
|
1842
|
+
if cclass.argmax()!=2.:
|
|
1843
|
+
assert time_of_interest is not None, "Please provide valid lysis times"
|
|
1844
|
+
signal,time_of_interest,cclass = random_time_shift(signal, time_of_interest, cclass, model_signal_length)
|
|
1845
|
+
|
|
1846
|
+
#signal = random_intensity_change(signal) #maybe bad idea for non percentile-normalized signals
|
|
1847
|
+
signal = gauss_noise(signal)
|
|
1848
|
+
|
|
1849
|
+
return signal, time_of_interest/model_signal_length, cclass
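# Typical usage during training-set augmentation (cf. augment_training_set above):
# >>> sig = np.random.rand(128, 2).astype(np.float32)
# >>> aug_sig, aug_t, aug_c = augmenter(sig, 0.3, np.array([1., 0., 0.]), 128)
# >>> aug_sig.shape
# (128, 2)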
|
|
1850
|
+
|
|
1851
|
+
|
|
1852
|
+
def residual_block1D(x, number_of_filters, kernel_size=8, match_filter_size=True, connection='identity'):
|
|
1853
|
+
|
|
1854
|
+
"""
|
|
1855
|
+
|
|
1856
|
+
Create a 1D residual block.
|
|
1857
|
+
|
|
1858
|
+
Parameters
|
|
1859
|
+
----------
|
|
1860
|
+
x : Tensor
|
|
1861
|
+
Input tensor.
|
|
1862
|
+
number_of_filters : int
|
|
1863
|
+
Number of filters in the convolutional layers.
|
|
1864
|
+
match_filter_size : bool, optional
|
|
1865
|
+
Whether to match the filter size of the skip connection to the output. Default is True.
kernel_size : int, optional
Kernel size of the convolutional layers. Default is 8.
connection : {'identity', 'projection'}, optional
Type of skip connection: 'identity' preserves the temporal resolution, while 'projection' downsamples by a factor of two (use odd kernel sizes with 'projection' so both branches keep matching lengths). Default is 'identity'.
|
|
1866
|
+
|
|
1867
|
+
Returns
|
|
1868
|
+
-------
|
|
1869
|
+
Tensor
|
|
1870
|
+
Output tensor of the residual block.
|
|
1871
|
+
|
|
1872
|
+
Notes
|
|
1873
|
+
-----
|
|
1874
|
+
This function creates a 1D residual block by performing the original mapping followed by adding a skip connection
|
|
1875
|
+
and applying non-linear activation. The skip connection allows the gradient to flow directly to earlier layers and
|
|
1876
|
+
helps mitigate the vanishing gradient problem. The residual block consists of three convolutional layers with
|
|
1877
|
+
batch normalization and ReLU activation functions.
|
|
1878
|
+
|
|
1879
|
+
If `match_filter_size` is True, the skip connection is adjusted to have the same number of filters as the output.
|
|
1880
|
+
Otherwise, the skip connection is kept as is.
|
|
1881
|
+
|
|
1882
|
+
Examples
|
|
1883
|
+
--------
|
|
1884
|
+
>>> inputs = Input(shape=(10, 3))
|
|
1885
|
+
>>> x = residual_block1D(inputs, 64)
|
|
1886
|
+
# Create a 1D residual block with 64 filters and apply it to the input tensor.
|
|
1887
|
+
|
|
1888
|
+
"""
|
|
1889
|
+
|
|
1890
|
+
|
|
1891
|
+
# Create skip connection
|
|
1892
|
+
x_skip = x
|
|
1893
|
+
|
|
1894
|
+
# Perform the original mapping
|
|
1895
|
+
if connection=='identity':
|
|
1896
|
+
x = Conv1D(number_of_filters, kernel_size=kernel_size, strides=1,padding="same")(x_skip)
|
|
1897
|
+
elif connection=='projection':
|
|
1898
|
+
x = ZeroPadding1D(padding=kernel_size//2)(x_skip)
|
|
1899
|
+
x = Conv1D(number_of_filters, kernel_size=kernel_size, strides=2,padding="valid")(x)
|
|
1900
|
+
x = BatchNormalization()(x)
|
|
1901
|
+
x = Activation("relu")(x)
|
|
1902
|
+
|
|
1903
|
+
x = Conv1D(number_of_filters, kernel_size=kernel_size, strides=1,padding="same")(x)
|
|
1904
|
+
x = BatchNormalization()(x)
|
|
1905
|
+
|
|
1906
|
+
if match_filter_size and connection=='identity':
|
|
1907
|
+
x_skip = Conv1D(number_of_filters, kernel_size=1, padding="same")(x_skip)
|
|
1908
|
+
elif match_filter_size and connection=='projection':
|
|
1909
|
+
x_skip = Conv1D(number_of_filters, kernel_size=1, strides=2, padding="valid")(x_skip)
|
|
1910
|
+
|
|
1911
|
+
|
|
1912
|
+
# Add the skip connection to the regular mapping
|
|
1913
|
+
x = Add()([x, x_skip])
|
|
1914
|
+
|
|
1915
|
+
# Nonlinearly activate the result
|
|
1916
|
+
x = Activation("relu")(x)
|
|
1917
|
+
|
|
1918
|
+
# Return the result
|
|
1919
|
+
return x
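# Minimal usage sketch with the Keras functional API imports used in this module; note
# that 'projection' halves the temporal length, and odd kernel sizes keep both branches
# aligned (shapes shown as (length, filters), batch dimension omitted).
# >>> inputs = Input(shape=(128, 2))
# >>> x = residual_block1D(inputs, 64, kernel_size=8, connection='identity')   # (128, 64)
# >>> x = residual_block1D(x, 128, kernel_size=7, connection='projection')     # (64, 128)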
|
|
1920
|
+
|
|
1921
|
+
|
|
1922
|
+
def MultiscaleResNetModel(n_channels, n_classes = 3, dropout_rate=0, dense_collection=0, use_pooling=True,
|
|
1923
|
+
header="classifier", model_signal_length = 128):
|
|
1924
|
+
|
|
1925
|
+
"""
|
|
1926
|
+
|
|
1927
|
+
Define a multiscale ResNet 1D encoder model.
|
|
1928
|
+
|
|
1929
|
+
Parameters
|
|
1930
|
+
----------
|
|
1931
|
+
n_channels : int
|
|
1932
|
+
Number of input channels.
|
|
1933
|
+
use_pooling : bool, optional
Whether to use pooling. Default is True.
|
|
1935
|
+
n_classes : int, optional
|
|
1936
|
+
Number of output classes. Default is 3.
|
|
1937
|
+
dropout_rate : float, optional
|
|
1938
|
+
Dropout rate to be applied. Default is 0.
|
|
1939
|
+
dense_collection : int, optional
|
|
1940
|
+
Number of neurons in the dense layer. Default is 0.
|
|
1941
|
+
header : str, optional
|
|
1942
|
+
Type of the model header. "classifier" for classification, "regressor" for regression. Default is "classifier".
|
|
1943
|
+
model_signal_length : int, optional
|
|
1944
|
+
Length of the input signal. Default is 128.
|
|
1945
|
+
|
|
1946
|
+
Returns
|
|
1947
|
+
-------
|
|
1948
|
+
keras.models.Model
|
|
1949
|
+
ResNet 1D encoder model.
|
|
1950
|
+
|
|
1951
|
+
Notes
|
|
1952
|
+
-----
|
|
1953
|
+
This function defines a generic ResNet 1D encoder model with the specified number of input channels, residual
|
|
1954
|
+
blocks, output classes, dropout rate, dense collection, and model header. The model architecture follows the
|
|
1955
|
+
ResNet principles with 1D convolutional layers and residual connections. The final activation and number of
|
|
1956
|
+
neurons in the output layer are determined based on the header type.
|
|
1957
|
+
|
|
1958
|
+
Examples
|
|
1959
|
+
--------
|
|
1960
|
+
>>> model = MultiscaleResNetModel(n_channels=3, n_classes=2, dropout_rate=0.2)
# Define a multiscale ResNet 1D encoder model with 3 input channels and 2 output classes.
|
|
1962
|
+
|
|
1963
|
+
"""
|
|
1964
|
+
|
|
1965
|
+
if header=="classifier":
|
|
1966
|
+
final_activation = "softmax"
|
|
1967
|
+
neurons_final = n_classes
|
|
1968
|
+
elif header=="regressor":
|
|
1969
|
+
final_activation = "linear"
|
|
1970
|
+
neurons_final = 1
|
|
1971
|
+
else:
|
|
1972
|
+
return None
|
|
1973
|
+
|
|
1974
|
+
inputs = Input(shape=(model_signal_length,n_channels,))
|
|
1975
|
+
x = ZeroPadding1D(3)(inputs)
|
|
1976
|
+
x = Conv1D(64, kernel_size=7, strides=2, padding="valid", use_bias=False)(x)
|
|
1977
|
+
x = BatchNormalization()(x)
|
|
1978
|
+
x = ZeroPadding1D(1)(x)
|
|
1979
|
+
x_common = MaxPooling1D(pool_size=3, strides=2, padding='valid')(x)
|
|
1980
|
+
|
|
1981
|
+
# Block 1
|
|
1982
|
+
x1 = residual_block1D(x_common, 64, kernel_size=7,connection='projection')
|
|
1983
|
+
x1 = residual_block1D(x1, 128, kernel_size=7,connection='projection')
|
|
1984
|
+
x1 = residual_block1D(x1, 256, kernel_size=7,connection='projection')
|
|
1985
|
+
x1 = GlobalAveragePooling1D()(x1)
|
|
1986
|
+
|
|
1987
|
+
# Block 2
|
|
1988
|
+
x2 = residual_block1D(x_common, 64, kernel_size=5,connection='projection')
|
|
1989
|
+
x2 = residual_block1D(x2, 128, kernel_size=5,connection='projection')
|
|
1990
|
+
x2 = residual_block1D(x2, 256, kernel_size=5,connection='projection')
|
|
1991
|
+
x2 = GlobalAveragePooling1D()(x2)
|
|
1992
|
+
|
|
1993
|
+
# Block 3
|
|
1994
|
+
x3 = residual_block1D(x_common, 64, kernel_size=3,connection='projection')
|
|
1995
|
+
x3 = residual_block1D(x3, 128, kernel_size=3,connection='projection')
|
|
1996
|
+
x3 = residual_block1D(x3, 256, kernel_size=3,connection='projection')
|
|
1997
|
+
x3 = GlobalAveragePooling1D()(x3)
|
|
1998
|
+
|
|
1999
|
+
x_combined = Concatenate()([x1, x2, x3])
|
|
2000
|
+
x_combined = Flatten()(x_combined)
|
|
2001
|
+
|
|
2002
|
+
if dense_collection>0:
|
|
2003
|
+
x_combined = Dense(dense_collection)(x_combined)
|
|
2004
|
+
if dropout_rate>0:
|
|
2005
|
+
x_combined = Dropout(dropout_rate)(x_combined)
|
|
2006
|
+
|
|
2007
|
+
x_combined = Dense(neurons_final,activation=final_activation,name=header)(x_combined)
|
|
2008
|
+
model = Model(inputs, x_combined, name=header)
|
|
2009
|
+
|
|
2010
|
+
return model
|
|
2011
|
+
|
|
2012
|
+
def ResNetModelCurrent(n_channels, n_slices, depth=2, use_pooling=True, n_classes = 3, dropout_rate=0.1, dense_collection=512,
|
|
2013
|
+
header="classifier", model_signal_length = 128):
|
|
2014
|
+
|
|
2015
|
+
"""
|
|
2016
|
+
Creates a ResNet-based model tailored for signal classification or regression tasks.
|
|
2017
|
+
|
|
2018
|
+
This function constructs a 1D ResNet architecture with specified parameters. The model can be configured
|
|
2019
|
+
for either classification or regression tasks, determined by the `header` parameter. It consists of
|
|
2020
|
+
configurable ResNet blocks, global average pooling, optional dense layers, and dropout for regularization.
|
|
2021
|
+
|
|
2022
|
+
Parameters
|
|
2023
|
+
----------
|
|
2024
|
+
n_channels : int
|
|
2025
|
+
The number of channels in the input signal.
|
|
2026
|
+
n_slices : int
|
|
2027
|
+
The number of slices (or ResNet blocks) to use in the model.
|
|
2028
|
+
depth : int, optional
|
|
2029
|
+
The depth of the network, i.e., how many times the number of filters is doubled. Default is 2.
|
|
2030
|
+
use_pooling : bool, optional
|
|
2031
|
+
Whether to use MaxPooling between ResNet blocks. Default is True.
|
|
2032
|
+
n_classes : int, optional
|
|
2033
|
+
The number of classes for the classification task. Ignored for regression. Default is 3.
|
|
2034
|
+
dropout_rate : float, optional
|
|
2035
|
+
The dropout rate for regularization. Default is 0.1.
|
|
2036
|
+
dense_collection : int, optional
|
|
2037
|
+
The number of neurons in the dense layer following global pooling. If 0, the dense layer is omitted. Default is 512.
|
|
2038
|
+
header : str, optional
|
|
2039
|
+
Specifies the task type: "classifier" for classification or "regressor" for regression. Default is "classifier".
|
|
2040
|
+
model_signal_length : int, optional
|
|
2041
|
+
The length of the input signal. Default is 128.
|
|
2042
|
+
|
|
2043
|
+
Returns
|
|
2044
|
+
-------
|
|
2045
|
+
keras.Model
|
|
2046
|
+
The constructed Keras model ready for training or inference.
|
|
2047
|
+
|
|
2048
|
+
Notes
|
|
2049
|
+
-----
|
|
2050
|
+
- The model uses Conv1D layers for signal processing and applies global average pooling before the final classification
|
|
2051
|
+
or regression layer.
|
|
2052
|
+
- The choice of `final_activation` and `neurons_final` depends on the task: "softmax" and `n_classes` for classification,
|
|
2053
|
+
and "linear" and 1 for regression.
|
|
2054
|
+
- This function relies on a custom `residual_block1D` function for constructing ResNet blocks.
|
|
2055
|
+
|
|
2056
|
+
Examples
|
|
2057
|
+
--------
|
|
2058
|
+
>>> model = ResNetModelCurrent(n_channels=1, n_slices=2, depth=2, use_pooling=True, n_classes=3, dropout_rate=0.1, dense_collection=512, header="classifier", model_signal_length=128)
|
|
2059
|
+
# Creates a ResNet model configured for classification with 3 classes.
|
|
2060
|
+
"""
|
|
2061
|
+
|
|
2062
|
+
if header=="classifier":
|
|
2063
|
+
final_activation = "softmax"
|
|
2064
|
+
neurons_final = n_classes
|
|
2065
|
+
elif header=="regressor":
|
|
2066
|
+
final_activation = "linear"
|
|
2067
|
+
neurons_final = 1
|
|
2068
|
+
else:
|
|
2069
|
+
return None
|
|
2070
|
+
|
|
2071
|
+
inputs = Input(shape=(model_signal_length,n_channels,))
|
|
2072
|
+
x2 = Conv1D(64, kernel_size=1,strides=1,padding='same')(inputs)
|
|
2073
|
+
|
|
2074
|
+
n_filters = 64
|
|
2075
|
+
for k in range(depth):
|
|
2076
|
+
for i in range(n_slices):
|
|
2077
|
+
x2 = residual_block1D(x2,n_filters,kernel_size=8)
|
|
2078
|
+
n_filters *= 2
|
|
2079
|
+
if use_pooling and k!=(depth-1):
|
|
2080
|
+
x2 = MaxPooling1D()(x2)
|
|
2081
|
+
|
|
2082
|
+
x2 = GlobalAveragePooling1D()(x2)
|
|
2083
|
+
if dense_collection>0:
|
|
2084
|
+
x2 = Dense(dense_collection)(x2)
|
|
2085
|
+
if dropout_rate>0:
|
|
2086
|
+
x2 = Dropout(dropout_rate)(x2)
|
|
2087
|
+
|
|
2088
|
+
x2 = Dense(neurons_final,activation=final_activation,name=header)(x2)
|
|
2089
|
+
model = Model(inputs, x2, name=header)
|
|
2090
|
+
|
|
2091
|
+
return model
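# Minimal usage sketch: build and compile the classifier variant (the optimizer and
# loss shown are illustrative choices, not prescribed by this module).
# >>> model = ResNetModelCurrent(n_channels=2, n_slices=2, depth=2, header="classifier", model_signal_length=128)
# >>> model.compile(optimizer=Adam(learning_rate=1e-4), loss="categorical_crossentropy")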
|
|
2092
|
+
|
|
2093
|
+
|
|
2094
|
+
def train_signal_model(config):
|
|
2095
|
+
|
|
2096
|
+
"""
|
|
2097
|
+
Initiates the training of a signal detection model using a specified configuration file.
|
|
2098
|
+
|
|
2099
|
+
This function triggers an external Python script to train a signal detection model. The training
|
|
2100
|
+
configuration, including data paths, model parameters, and training options, are specified in a JSON
|
|
2101
|
+
configuration file. The function asserts the existence of the configuration file before proceeding
|
|
2102
|
+
with the training process.
|
|
2103
|
+
|
|
2104
|
+
Parameters
|
|
2105
|
+
----------
|
|
2106
|
+
config : str
|
|
2107
|
+
The file path to the JSON configuration file specifying training parameters. This path must be valid
|
|
2108
|
+
and the configuration file must be correctly formatted according to the expectations of the
|
|
2109
|
+
'train_signal_model.py' script.
|
|
2110
|
+
|
|
2111
|
+
Raises
|
|
2112
|
+
------
|
|
2113
|
+
AssertionError
|
|
2114
|
+
If the specified configuration file does not exist at the given path.
|
|
2115
|
+
|
|
2116
|
+
Notes
|
|
2117
|
+
-----
|
|
2118
|
+
- The external training script 'train_signal_model.py' is expected to be located in a predefined directory
|
|
2119
|
+
relative to this function and is responsible for the actual model training process.
|
|
2120
|
+
- The configuration file should include details such as data directories, model architecture specifications,
|
|
2121
|
+
training hyperparameters, and any preprocessing steps required.
|
|
2122
|
+
|
|
2123
|
+
Examples
|
|
2124
|
+
--------
|
|
2125
|
+
>>> config_path = '/path/to/training_config.json'
|
|
2126
|
+
>>> train_signal_model(config_path)
|
|
2127
|
+
# This will execute the 'train_signal_model.py' script using the parameters specified in 'training_config.json'.
|
|
2128
|
+
"""
|
|
2129
|
+
|
|
2130
|
+
config = config.replace('\\','/')
|
|
2131
|
+
|
|
2132
|
+
assert os.path.exists(config),f'Config {config} is not a valid path.'
|
|
2133
|
+
|
|
2134
|
+
script_path = os.sep.join([abs_path, 'scripts', 'train_signal_model.py'])
|
|
2135
|
+
cmd = f'python "{script_path}" --config "{config}"'
|
|
2136
|
+
subprocess.call(cmd, shell=True)
|
|
2137
|
+
|
|
2138
|
+
def derivative(x, timeline, window, mode='bi'):
|
|
2139
|
+
|
|
2140
|
+
"""
|
|
2141
|
+
Compute the derivative of a given array of values with respect to time using a specified numerical differentiation method.
|
|
2142
|
+
|
|
2143
|
+
Parameters
|
|
2144
|
+
----------
|
|
2145
|
+
x : array_like
|
|
2146
|
+
The input array of values.
|
|
2147
|
+
timeline : array_like
|
|
2148
|
+
The array representing the time points corresponding to the input values.
|
|
2149
|
+
window : int
|
|
2150
|
+
The size of the window used for numerical differentiation. Must be a positive odd integer.
|
|
2151
|
+
mode : {'bi', 'forward', 'backward'}, optional
|
|
2152
|
+
The numerical differentiation method to be used:
|
|
2153
|
+
- 'bi' (default): Bidirectional differentiation using a symmetric window.
|
|
2154
|
+
- 'forward': Forward differentiation using a one-sided window.
|
|
2155
|
+
- 'backward': Backward differentiation using a one-sided window.
|
|
2156
|
+
|
|
2157
|
+
Returns
|
|
2158
|
+
-------
|
|
2159
|
+
dxdt : ndarray
|
|
2160
|
+
The computed derivative values of the input array with respect to time; entries outside the valid window range are left as NaN.
|
|
2161
|
+
|
|
2162
|
+
Raises
|
|
2163
|
+
------
|
|
2164
|
+
AssertionError
|
|
2165
|
+
If the window size is not an odd integer and mode is 'bi'.
|
|
2166
|
+
|
|
2167
|
+
Notes
|
|
2168
|
+
-----
|
|
2169
|
+
- For 'bi' mode, the window size must be an odd number.
|
|
2170
|
+
- For 'forward' mode, the derivative at the edge points may not be accurate due to the one-sided window.
|
|
2171
|
+
- For 'backward' mode, the derivative at the first few points may not be accurate due to the one-sided window.
|
|
2172
|
+
|
|
2173
|
+
Examples
|
|
2174
|
+
--------
|
|
2175
|
+
>>> import numpy as np
|
|
2176
|
+
>>> x = np.array([1, 2, 4, 7, 11])
|
|
2177
|
+
>>> timeline = np.array([0, 1, 2, 3, 4])
|
|
2178
|
+
>>> window = 3
|
|
2179
|
+
>>> derivative(x, timeline, window, mode='bi')
|
|
2180
|
+
array([nan,  2.,  3., nan, nan])
|
|
2181
|
+
|
|
2182
|
+
>>> derivative(x, timeline, window, mode='forward')
|
|
2183
|
+
array([ 2.,  3., nan, nan, nan])
|
|
2184
|
+
|
|
2185
|
+
>>> derivative(x, timeline, window, mode='backward')
|
|
2186
|
+
array([nan, nan, nan,  2.,  3.])
|
|
2187
|
+
"""
|
|
2188
|
+
|
|
2189
|
+
# modes = bi, forward, backward
|
|
2190
|
+
dxdt = np.zeros(len(x))
|
|
2191
|
+
dxdt[:] = np.nan
|
|
2192
|
+
|
|
2193
|
+
if mode=='bi':
|
|
2194
|
+
assert window%2==1,'Please set an odd window for the bidirectional mode'
|
|
2195
|
+
lower_bound = window//2
|
|
2196
|
+
upper_bound = len(x) - window//2 - 1
|
|
2197
|
+
elif mode=='forward':
|
|
2198
|
+
lower_bound = 0
|
|
2199
|
+
upper_bound = len(x) - window
|
|
2200
|
+
elif mode=='backward':
|
|
2201
|
+
lower_bound = window
|
|
2202
|
+
upper_bound = len(x)
|
|
2203
|
+
|
|
2204
|
+
for t in range(lower_bound,upper_bound):
|
|
2205
|
+
if mode=='bi':
|
|
2206
|
+
dxdt[t] = (x[t+window//2+1] - x[t-window//2]) / (timeline[t+window//2+1] - timeline[t-window//2])
|
|
2207
|
+
elif mode=='forward':
|
|
2208
|
+
dxdt[t] = (x[t+window] - x[t]) / (timeline[t+window] - timeline[t])
|
|
2209
|
+
elif mode=='backward':
|
|
2210
|
+
dxdt[t] = (x[t] - x[t-window]) / (timeline[t] - timeline[t-window])
|
|
2211
|
+
return dxdt
|
|
2212
|
+
|
|
2213
|
+
def velocity(x,y,timeline,window,mode='bi'):
|
|
2214
|
+
|
|
2215
|
+
"""
|
|
2216
|
+
Compute the velocity vector of a given 2D trajectory represented by arrays of x and y coordinates
|
|
2217
|
+
with respect to time using a specified numerical differentiation method.
|
|
2218
|
+
|
|
2219
|
+
Parameters
|
|
2220
|
+
----------
|
|
2221
|
+
x : array_like
|
|
2222
|
+
The array of x-coordinates of the trajectory.
|
|
2223
|
+
y : array_like
|
|
2224
|
+
The array of y-coordinates of the trajectory.
|
|
2225
|
+
timeline : array_like
|
|
2226
|
+
The array representing the time points corresponding to the x and y coordinates.
|
|
2227
|
+
window : int
|
|
2228
|
+
The size of the window used for numerical differentiation. Must be a positive odd integer.
|
|
2229
|
+
mode : {'bi', 'forward', 'backward'}, optional
|
|
2230
|
+
The numerical differentiation method to be used:
|
|
2231
|
+
- 'bi' (default): Bidirectional differentiation using a symmetric window.
|
|
2232
|
+
- 'forward': Forward differentiation using a one-sided window.
|
|
2233
|
+
- 'backward': Backward differentiation using a one-sided window.
|
|
2234
|
+
|
|
2235
|
+
Returns
|
|
2236
|
+
-------
|
|
2237
|
+
v : ndarray
|
|
2238
|
+
The computed velocity vector of the 2D trajectory with respect to time.
|
|
2239
|
+
The first column represents the x-component of velocity, and the second column represents the y-component. Rows outside the valid window range are left as NaN.
|
|
2240
|
+
|
|
2241
|
+
Raises
|
|
2242
|
+
------
|
|
2243
|
+
AssertionError
|
|
2244
|
+
If the window size is not an odd integer and mode is 'bi'.
|
|
2245
|
+
|
|
2246
|
+
Notes
|
|
2247
|
+
-----
|
|
2248
|
+
- For 'bi' mode, the window size must be an odd number.
|
|
2249
|
+
- For 'forward' mode, the velocity at the edge points may not be accurate due to the one-sided window.
|
|
2250
|
+
- For 'backward' mode, the velocity at the first few points may not be accurate due to the one-sided window.
|
|
2251
|
+
|
|
2252
|
+
Examples
|
|
2253
|
+
--------
|
|
2254
|
+
>>> import numpy as np
|
|
2255
|
+
>>> x = np.array([1, 2, 4, 7, 11])
|
|
2256
|
+
>>> y = np.array([0, 3, 5, 8, 10])
|
|
2257
|
+
>>> timeline = np.array([0, 1, 2, 3, 4])
|
|
2258
|
+
>>> window = 3
|
|
2259
|
+
>>> velocity(x, y, timeline, window, mode='bi')
|
|
2260
|
+
array([[       nan,        nan],
       [2.        , 2.66666667],
       [3.        , 2.33333333],
       [       nan,        nan],
       [       nan,        nan]])
|
|
2262
|
+
|
|
2263
|
+
>>> velocity(x, y, timeline, window, mode='forward')
|
|
2264
|
+
array([[2.        , 2.66666667],
       [3.        , 2.33333333],
       [       nan,        nan],
       [       nan,        nan],
       [       nan,        nan]])
|
|
2266
|
+
|
|
2267
|
+
>>> velocity(x, y, timeline, window, mode='backward')
|
|
2268
|
+
array([[       nan,        nan],
       [       nan,        nan],
       [       nan,        nan],
       [2.        , 2.66666667],
       [3.        , 2.33333333]])
|
|
2270
|
+
"""
|
|
2271
|
+
|
|
2272
|
+
v = np.zeros((len(x),2))
|
|
2273
|
+
v[:,:] = np.nan
|
|
2274
|
+
|
|
2275
|
+
v[:,0] = derivative(x, timeline, window, mode=mode)
|
|
2276
|
+
v[:,1] = derivative(y, timeline, window, mode=mode)
|
|
2277
|
+
|
|
2278
|
+
return v

def magnitude_velocity(v_matrix):

    """
    Compute the magnitude of 2D velocity vectors given a matrix of velocity vectors.

    Parameters
    ----------
    v_matrix : array_like
        The matrix where each row represents a 2D velocity vector, with the first column
        being the x-component and the second column being the y-component.

    Returns
    -------
    magnitude : ndarray
        The computed magnitudes of the input velocity vectors.

    Notes
    -----
    - If a velocity vector has NaN components, the corresponding magnitude will be NaN.
    - Rows whose x-component is NaN are skipped; a NaN y-component propagates to a NaN magnitude.

    Examples
    --------
    >>> import numpy as np
    >>> v_matrix = np.array([[3, 4],
    ...                      [2, 2],
    ...                      [3, 3]])
    >>> magnitude_velocity(v_matrix)
    array([5.        , 2.82842712, 4.24264069])

    >>> v_matrix_with_nan = np.array([[3, 4],
    ...                               [np.nan, 2],
    ...                               [3, np.nan]])
    >>> magnitude_velocity(v_matrix_with_nan)
    array([ 5., nan, nan])
    """

    magnitude = np.zeros(len(v_matrix))
    magnitude[:] = np.nan
    for i in range(len(v_matrix)):
        if v_matrix[i, 0] == v_matrix[i, 0]:  # NaN-safe test: NaN != NaN
            magnitude[i] = np.sqrt(v_matrix[i, 0]**2 + v_matrix[i, 1]**2)
    return magnitude

def orientation(v_matrix):

    """
    Compute the orientation angles (in radians) of 2D velocity vectors given a matrix of velocity vectors.

    Parameters
    ----------
    v_matrix : array_like
        The matrix where each row represents a 2D velocity vector, with the first column
        being the x-component and the second column being the y-component.

    Returns
    -------
    orientation_array : ndarray
        The computed orientation angles of the input velocity vectors in radians.
        If a velocity vector has NaN components, the corresponding orientation angle will be NaN.

    Examples
    --------
    >>> import numpy as np
    >>> v_matrix = np.array([[3, 4],
    ...                      [2, 2],
    ...                      [-3, -3]])
    >>> orientation(v_matrix)
    array([ 0.92729522,  0.78539816, -2.35619449])

    >>> v_matrix_with_nan = np.array([[3, 4],
    ...                               [np.nan, 2],
    ...                               [3, np.nan]])
    >>> orientation(v_matrix_with_nan)
    array([0.92729522,        nan,        nan])
    """

    orientation_array = np.zeros(len(v_matrix))
    orientation_array[:] = np.nan  # default to NaN so skipped rows are not reported as angle 0
    for t in range(len(orientation_array)):
        if v_matrix[t, 0] == v_matrix[t, 0]:  # NaN-safe test: NaN != NaN
            # angle from the x-axis: arctan2 takes (y-component, x-component)
            orientation_array[t] = np.arctan2(v_matrix[t, 1], v_matrix[t, 0])
    return orientation_array
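
# Illustrative usage sketch, not part of the original module: converting a
# velocity matrix into speed and heading traces. The input values are taken
# from the docstring examples above.
def _example_speed_and_heading():
    v = np.array([[3., 4.], [2., 2.], [-3., -3.]])
    speed = magnitude_velocity(v)  # array([5.        , 2.82842712, 4.24264069])
    heading = orientation(v)       # radians in (-pi, pi], measured from the x-axis
    return speed, heading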

def T_MSD(x, y, dt):

    """
    Compute the Time-Averaged Mean Square Displacement (T-MSD) of a 2D trajectory.

    Parameters
    ----------
    x : array_like
        The array of x-coordinates of the trajectory.
    y : array_like
        The array of y-coordinates of the trajectory.
    dt : float
        The time interval between successive data points in the trajectory.

    Returns
    -------
    msd : list
        A list containing the Time-Averaged Mean Square Displacement values for each time lag.
    timelag : ndarray
        The array of time lags corresponding to the calculated MSD values.

    Notes
    -----
    - T-MSD is a measure of the average spatial extent explored by a particle over a given time interval.
    - The input trajectories (x, y) are assumed to be in the same unit of length.
    - The time interval (dt) should be consistent with the time unit used in the data.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.array([1, 2, 4, 7, 11])
    >>> y = np.array([0, 3, 5, 8, 10])
    >>> dt = 1.0  # time interval between data points
    >>> T_MSD(x, y, dt)
    ([14.0, 52.666666666666664, 115.0, 200.0],
     array([1., 2., 3., 4.]))
    """

    msd = []
    N = len(x)
    for n in range(1, N):
        s = 0
        for i in range(0, N - n):
            s += (x[n + i] - x[i])**2 + (y[n + i] - y[i])**2
        msd.append(1 / (N - n) * s)

    timelag = np.linspace(dt, (N - 1) * dt, N - 1)
    return msd, timelag
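
# Illustrative usage sketch, not part of the original module: T-MSD of the
# short track used in the docstring example above.
def _example_t_msd():
    x = np.array([1, 2, 4, 7, 11])
    y = np.array([0, 3, 5, 8, 10])
    msd, timelag = T_MSD(x, y, 1.0)
    # msd -> [14.0, 52.666..., 115.0, 200.0]; timelag -> array([1., 2., 3., 4.])
    return msd, timelag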

def linear_msd(t, m):

    """
    Compute a Mean Square Displacement (MSD) curve with a linear scaling relationship.

    Parameters
    ----------
    t : array_like
        Time lag values.
    m : float
        Linear scaling factor representing the slope of the MSD curve.

    Returns
    -------
    msd : ndarray
        Computed MSD values based on the linear scaling relationship.

    Examples
    --------
    >>> import numpy as np
    >>> t = np.array([1, 2, 3, 4])
    >>> m = 2.0
    >>> linear_msd(t, m)
    array([2., 4., 6., 8.])
    """

    return m * t

def alpha_msd(t, m, alpha):

    """
    Compute a Mean Square Displacement (MSD) curve with a power-law scaling relationship.

    Parameters
    ----------
    t : array_like
        Time lag values.
    m : float
        Scaling factor.
    alpha : float
        Exponent representing the scaling relationship between MSD and time.

    Returns
    -------
    msd : ndarray
        Computed MSD values based on the power-law scaling relationship.

    Examples
    --------
    >>> import numpy as np
    >>> t = np.array([1, 2, 3, 4])
    >>> m = 2.0
    >>> alpha = 0.5
    >>> alpha_msd(t, m, alpha)
    array([2.        , 2.82842712, 3.46410162, 4.        ])
    """

    return m * t**alpha
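
# Illustrative usage sketch, not part of the original module: fitting the two
# MSD models above to a synthetic curve with scipy's curve_fit, mirroring what
# sliding_msd does below on each window.
def _example_fit_msd_models():
    from scipy.optimize import curve_fit
    timelag = np.array([1., 2., 3., 4., 5.])
    msd = 2.0 * timelag**0.5  # synthetic sub-diffusive MSD
    popt_lin, _ = curve_fit(linear_msd, timelag, msd)  # popt_lin[0]: best linear slope
    popt_pow, _ = curve_fit(alpha_msd, timelag, msd)   # popt_pow: approximately [2.0, 0.5]
    return popt_lin, popt_pow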

def sliding_msd(x, y, timeline, window, mode='bi', n_points_migration=7, n_points_transport=7):

    """
    Compute the sliding mean square displacement (sMSD) and anomalous exponent (alpha) of a 2D trajectory using a sliding window.

    Parameters
    ----------
    x : array_like
        The array of x-coordinates of the trajectory.
    y : array_like
        The array of y-coordinates of the trajectory.
    timeline : array_like
        The array representing the time points corresponding to the x and y coordinates.
    window : int
        The size of the sliding window used for computing local MSD and alpha values.
    mode : {'bi', 'forward', 'backward'}, optional
        The sliding window mode:
        - 'bi' (default): bidirectional sliding window.
        - 'forward': forward sliding window.
        - 'backward': backward sliding window.
    n_points_migration : int, optional
        The number of points used for fitting the linear function in the MSD calculation.
    n_points_transport : int, optional
        The number of points used for fitting the alpha function in the anomalous exponent calculation.

    Returns
    -------
    s_msd : ndarray
        Sliding mean square displacement values calculated using the sliding window approach.
    s_alpha : ndarray
        Sliding anomalous exponent (alpha) values calculated using the sliding window approach.

    Raises
    ------
    AssertionError
        If the window size is not larger than the number of fit points.

    Notes
    -----
    - The input trajectories (x, y) are assumed to be in the same unit of length.
    - The time unit used in the data should be consistent with the time intervals in the timeline array.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.array([1, 2, 4, 7, 11, 15, 20])
    >>> y = np.array([0, 3, 5, 8, 10, 14, 18])
    >>> timeline = np.array([0, 1, 2, 3, 4, 5, 6])
    >>> window = 3
    >>> s_msd, s_alpha = sliding_msd(x, y, timeline, window, n_points_migration=2, n_points_transport=3)
    """

    assert window > n_points_migration, 'Please set a window larger than the number of fit points...'

    # modes = bi, forward, backward
    s_msd = np.zeros(len(x))
    s_msd[:] = np.nan
    s_alpha = np.zeros(len(x))
    s_alpha[:] = np.nan
    dt = timeline[1] - timeline[0]

    if mode == 'bi':
        assert window % 2 == 1, 'Please set an odd window for the bidirectional mode'
        lower_bound = window // 2
        upper_bound = len(x) - window // 2 - 1
    elif mode == 'forward':
        lower_bound = 0
        upper_bound = len(x) - window
    elif mode == 'backward':
        lower_bound = window
        upper_bound = len(x)

    for t in range(lower_bound, upper_bound):
        # extract the sub-track covered by the window and compute its T-MSD
        if mode == 'bi':
            x_sub = x[t - window // 2:t + window // 2 + 1]
            y_sub = y[t - window // 2:t + window // 2 + 1]
            msd, timelag = T_MSD(x_sub, y_sub, dt)
        elif mode == 'forward':
            x_sub = x[t:t + window]
            y_sub = y[t:t + window]
            msd, timelag = T_MSD(x_sub, y_sub, dt)
        elif mode == 'backward':
            x_sub = x[t - window:t]
            y_sub = y[t - window:t]
            msd, timelag = T_MSD(x_sub, y_sub, dt)
        popt, pcov = curve_fit(linear_msd, timelag[:n_points_migration], msd[:n_points_migration])
        s_msd[t] = popt[0]
        popt_alpha, pcov_alpha = curve_fit(alpha_msd, timelag[:n_points_transport], msd[:n_points_transport])
        s_alpha[t] = popt_alpha[1]

    return s_msd, s_alpha
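
# Illustrative usage sketch, not part of the original module: sliding MSD and
# anomalous exponent along a synthetic random walk. In 'bi' mode the window
# must be odd, and it must exceed the number of fit points.
def _example_sliding_msd():
    rng = np.random.default_rng(0)
    steps = rng.normal(size=(100, 2))
    x = np.cumsum(steps[:, 0])
    y = np.cumsum(steps[:, 1])
    timeline = np.arange(100, dtype=float)
    s_msd, s_alpha = sliding_msd(x, y, timeline, window=15)
    return s_msd, s_alpha  # NaN wherever the window does not fit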

def drift_msd(t, d, v):

    """
    Calculate the mean squared displacement (MSD) of a particle undergoing diffusion with drift.

    The function computes the MSD for a particle that diffuses in a medium with a constant drift velocity.
    The MSD is given by the formula MSD = 4*D*t + V^2*t^2, where D is the diffusion coefficient, V is the drift
    velocity, and t is the time.

    Parameters
    ----------
    t : float or ndarray
        Time or an array of time points at which to calculate the MSD.
    d : float
        Diffusion coefficient of the particle.
    v : float
        Drift velocity of the particle.

    Returns
    -------
    float or ndarray
        The mean squared displacement of the particle at time t. Returns a single float value if t is a float,
        or an array of MSD values if t is an ndarray.

    Examples
    --------
    >>> drift_msd(t=5, d=1, v=2)
    120
    >>> drift_msd(t=np.array([1, 2, 3]), d=1, v=2)
    array([ 8, 24, 48])

    Notes
    -----
    - This formula assumes that the particle undergoes normal diffusion with an additional constant drift component.
    - The function can be used to model the behavior of particles in systems where both diffusion and directed motion occur.
    """

    return 4 * d * t + v**2 * t**2

def sliding_msd_drift(x, y, timeline, window, mode='bi', n_points_migration=7, n_points_transport=7, r2_threshold=0.75):

    """
    Compute the sliding mean squared displacement (MSD) with drift for particle trajectories.

    This function estimates the diffusion coefficient and drift velocity of particles from their
    x and y positions over time. It uses a sliding window approach to estimate the MSD at each point in time,
    fitting the MSD to the equation MSD = 4*D*t + V^2*t^2 to extract the diffusion coefficient (D) and drift velocity (V).

    Parameters
    ----------
    x : ndarray
        The x positions of the particle over time.
    y : ndarray
        The y positions of the particle over time.
    timeline : ndarray
        The time points corresponding to the x and y positions.
    window : int
        The size of the sliding window used to calculate the MSD at each point in time.
    mode : str, optional
        The mode of sliding window calculation. Options are 'bi' for bidirectional, 'forward', or 'backward'. Default is 'bi'.
    n_points_migration : int, optional
        The number of initial points from the calculated MSD to use for fitting the drift model. Default is 7.
    n_points_transport : int, optional
        The number of initial points intended for fitting the transport model. Not currently used by this function. Default is 7.
    r2_threshold : float, optional
        The R-squared threshold intended to validate the fit. Not currently applied in the function body. Default is 0.75.

    Returns
    -------
    tuple
        A tuple containing two ndarrays: the estimated diffusion coefficients and drift velocities for each point in time.

    Raises
    ------
    AssertionError
        If the window size is not larger than the number of fit points, or if the window size is even when mode is 'bi'.

    Notes
    -----
    - The function assumes a uniform time step between each point in the timeline.
    - The 'bi' mode requires an odd-sized window to symmetrically calculate the MSD around each point in time.
    - The curve fitting is performed using the `curve_fit` function from `scipy.optimize`, fitting to the `drift_msd` model.

    Examples
    --------
    >>> x = np.random.rand(100)
    >>> y = np.random.rand(100)
    >>> timeline = np.arange(100)
    >>> window = 11
    >>> diffusion, drift = sliding_msd_drift(x, y, timeline, window, mode='bi')
    # Estimates the diffusion coefficient and drift velocity using a bidirectional sliding window.
    """

    assert window > n_points_migration, 'Please set a window larger than the number of fit points...'

    # modes = bi, forward, backward
    s_diffusion = np.zeros(len(x))
    s_diffusion[:] = np.nan
    s_velocity = np.zeros(len(x))
    s_velocity[:] = np.nan
    dt = timeline[1] - timeline[0]

    if mode == 'bi':
        assert window % 2 == 1, 'Please set an odd window for the bidirectional mode'
        lower_bound = window // 2
        upper_bound = len(x) - window // 2 - 1
    elif mode == 'forward':
        lower_bound = 0
        upper_bound = len(x) - window
    elif mode == 'backward':
        lower_bound = window
        upper_bound = len(x)

    for t in range(lower_bound, upper_bound):
        # extract the sub-track covered by the window and compute its T-MSD
        if mode == 'bi':
            x_sub = x[t - window // 2:t + window // 2 + 1]
            y_sub = y[t - window // 2:t + window // 2 + 1]
            msd, timelag = T_MSD(x_sub, y_sub, dt)
        elif mode == 'forward':
            x_sub = x[t:t + window]
            y_sub = y[t:t + window]
            msd, timelag = T_MSD(x_sub, y_sub, dt)
        elif mode == 'backward':
            x_sub = x[t - window:t]
            y_sub = y[t - window:t]
            msd, timelag = T_MSD(x_sub, y_sub, dt)

        popt, pcov = curve_fit(drift_msd, timelag[:n_points_migration], msd[:n_points_migration])
        # if not np.any([math.isinf(a) for a in pcov.flatten()]):
        s_diffusion[t] = popt[0]
        s_velocity[t] = popt[1]

    return s_diffusion, s_velocity
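
# Illustrative sketch, not part of the original module: one way the
# `r2_threshold` argument of sliding_msd_drift could be applied, rejecting
# windows whose drift_msd fit has a low coefficient of determination.
def _example_r2_gate(msd, timelag, n_points=7, r2_threshold=0.75):
    from scipy.optimize import curve_fit
    msd = np.asarray(msd)
    popt, _ = curve_fit(drift_msd, timelag[:n_points], msd[:n_points])
    residuals = msd[:n_points] - drift_msd(timelag[:n_points], *popt)
    ss_res = np.sum(residuals**2)
    ss_tot = np.sum((msd[:n_points] - np.mean(msd[:n_points]))**2)
    r2 = 1.0 - ss_res / ss_tot
    return popt if r2 >= r2_threshold else np.array([np.nan, np.nan])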

def columnwise_mean(matrix, min_nbr_values=1):

    """
    Calculate the column-wise mean and standard deviation of the non-NaN elements of the input matrix.

    Parameters
    ----------
    matrix : numpy.ndarray
        The input matrix for which column-wise mean and standard deviation are calculated.
    min_nbr_values : int, optional
        The minimum number of non-NaN values a column must exceed for its mean and standard deviation
        to be calculated. Default is 1.

    Returns
    -------
    mean_line : numpy.ndarray
        An array containing the column-wise mean of non-NaN elements. Columns with `min_nbr_values` or fewer
        non-NaN values are set to NaN.
    mean_line_std : numpy.ndarray
        An array containing the column-wise standard deviation of non-NaN elements. Columns with `min_nbr_values`
        or fewer non-NaN values are set to NaN.

    Notes
    -----
    1. This function calculates the mean and standard deviation of the non-NaN elements in each column of the input matrix.
    2. Zeros are treated as missing values and excluded from the calculation.
    3. NaN values in the input matrix are ignored during calculation.
    """

    mean_line = np.zeros(matrix.shape[1])
    mean_line[:] = np.nan
    mean_line_std = np.zeros(matrix.shape[1])
    mean_line_std[:] = np.nan

    for k in range(matrix.shape[1]):
        values = matrix[:, k]
        values = values[values != 0]  # drop zeros, treated as missing values
        if len(values[values == values]) > min_nbr_values:  # count the non-NaN values
            mean_line[k] = np.nanmean(values)
            mean_line_std[k] = np.nanstd(values)
    return mean_line, mean_line_std
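
# Illustrative usage sketch, not part of the original module: column-wise
# statistics of a matrix with NaN gaps, as used by mean_signal below. Note
# that zeros are treated as missing values.
def _example_columnwise_mean():
    m = np.array([[1., 2., np.nan],
                  [3., np.nan, np.nan],
                  [5., 4., 7.]])
    mean_line, mean_line_std = columnwise_mean(m, min_nbr_values=1)
    # column 0 -> mean 3.0; column 2 has a single value, which does not
    # exceed min_nbr_values, so it stays NaN
    return mean_line, mean_line_std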


def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], return_matrix=False, forced_max_duration=None, min_nbr_values=2):

    """
    Calculate the mean and standard deviation of a specified signal for tracks of a given class in the input DataFrame.

    Parameters
    ----------
    df : pandas.DataFrame
        Input DataFrame containing tracking data.
    signal_name : str
        Name of the signal (column) in the DataFrame for which mean and standard deviation are calculated.
    class_col : str
        Name of the column in the DataFrame containing class labels.
    time_col : str, optional
        Name of the column in the DataFrame containing the event time used to align the tracks. Default is None.
    class_value : int or list of int, optional
        Value(s) representing the class of interest. Default is [0].
    return_matrix : bool, optional
        If True, also return the aligned per-track signal matrix. Default is False.
    forced_max_duration : int, optional
        If set, use this value instead of the longest track duration to size the aligned signal matrix. Default is None.
    min_nbr_values : int, optional
        Minimum number of values per time point, passed to `columnwise_mean`. Default is 2.

    Returns
    -------
    mean_signal : numpy.ndarray
        An array containing the mean signal values for tracks of the specified class. Tracks whose class is not in
        `class_value` are excluded from the calculation.
    std_signal : numpy.ndarray
        An array containing the standard deviation of signal values for tracks of the specified class. Tracks whose
        class is not in `class_value` are excluded from the calculation.
    actual_timeline : numpy.ndarray
        An array representing the time points corresponding to the mean signal values.
    signal_matrix : numpy.ndarray, optional
        The aligned per-track signal matrix, returned only if `return_matrix` is True.

    Notes
    -----
    1. This function calculates the mean and standard deviation of the specified signal for tracks of a given class.
    2. Tracks whose class is not in `class_value` are excluded from the calculation.
    3. Tracks with missing or NaN values in the specified signal are ignored during calculation.
    4. Tracks are aligned based on their 'FRAME' values and the specified `time_col` (if provided).
    """

    assert signal_name in list(df.columns), "The signal you want to plot is not one of the measured features."
    if isinstance(class_value, int):
        class_value = [class_value]

    if forced_max_duration is None:
        max_duration = ceil(np.amax(df.groupby(['position', 'TRACK_ID']).size().values))
    else:
        max_duration = forced_max_duration
    n_tracks = len(df.groupby(['position', 'TRACK_ID']))
    signal_matrix = np.zeros((n_tracks, max_duration * 2 + 1))
    signal_matrix[:, :] = np.nan

    trackid = 0
    for track, track_group in df.loc[df[class_col].isin(class_value)].groupby(['position', 'TRACK_ID']):
        track_group = track_group.sort_values(by='FRAME')
        cclass = track_group[class_col].to_numpy()[0]
        if cclass != 0:
            ref_time = 0
        else:
            try:
                ref_time = floor(track_group[time_col].to_numpy()[0])
            except:
                # no valid event time for this track; skip it
                continue
        signal = track_group[signal_name].to_numpy()
        timeline = track_group['FRAME'].to_numpy().astype(int)
        timeline_shifted = timeline - ref_time + max_duration
        signal_matrix[trackid, timeline_shifted] = signal
        trackid += 1

    mean_signal, std_signal = columnwise_mean(signal_matrix, min_nbr_values=min_nbr_values)
    actual_timeline = np.linspace(-max_duration, max_duration, 2 * max_duration + 1)
    if return_matrix:
        return mean_signal, std_signal, actual_timeline, signal_matrix
    else:
        return mean_signal, std_signal, actual_timeline
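
# Illustrative usage sketch, not part of the original module: averaging a
# measured signal across event-class tracks, aligned on each track's event
# time. The column names below are placeholders; any measured feature, class
# and event-time columns from a celldetective track table can be used.
def _example_mean_signal(df):
    return mean_signal(df, 'intensity_mean', 'class_custom',
                       time_col='t_custom', class_value=[0])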

if __name__ == "__main__":

    # model = MultiScaleResNetModel(3, n_classes=3, dropout_rate=0, dense_collection=1024, header="classifier", model_signal_length=128)
    # print(model.summary())
    model = ResNetModelCurrent(1, 2, depth=2, use_pooling=True, n_classes=3, dropout_rate=0.1, dense_collection=512,
                               header="classifier", model_signal_length=128)
    print(model.summary())
    # plot_model(model, to_file='test.png', show_shapes=True)