celldetective 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. celldetective/__init__.py +2 -0
  2. celldetective/__main__.py +432 -0
  3. celldetective/datasets/segmentation_annotations/blank +0 -0
  4. celldetective/datasets/signal_annotations/blank +0 -0
  5. celldetective/events.py +149 -0
  6. celldetective/extra_properties.py +100 -0
  7. celldetective/filters.py +89 -0
  8. celldetective/gui/__init__.py +20 -0
  9. celldetective/gui/about.py +44 -0
  10. celldetective/gui/analyze_block.py +563 -0
  11. celldetective/gui/btrack_options.py +898 -0
  12. celldetective/gui/classifier_widget.py +386 -0
  13. celldetective/gui/configure_new_exp.py +532 -0
  14. celldetective/gui/control_panel.py +438 -0
  15. celldetective/gui/gui_utils.py +495 -0
  16. celldetective/gui/json_readers.py +113 -0
  17. celldetective/gui/measurement_options.py +1425 -0
  18. celldetective/gui/neighborhood_options.py +452 -0
  19. celldetective/gui/plot_signals_ui.py +1042 -0
  20. celldetective/gui/process_block.py +1055 -0
  21. celldetective/gui/retrain_segmentation_model_options.py +706 -0
  22. celldetective/gui/retrain_signal_model_options.py +643 -0
  23. celldetective/gui/seg_model_loader.py +460 -0
  24. celldetective/gui/signal_annotator.py +2388 -0
  25. celldetective/gui/signal_annotator_options.py +340 -0
  26. celldetective/gui/styles.py +217 -0
  27. celldetective/gui/survival_ui.py +903 -0
  28. celldetective/gui/tableUI.py +608 -0
  29. celldetective/gui/thresholds_gui.py +1300 -0
  30. celldetective/icons/logo-large.png +0 -0
  31. celldetective/icons/logo.png +0 -0
  32. celldetective/icons/signals_icon.png +0 -0
  33. celldetective/icons/splash-test.png +0 -0
  34. celldetective/icons/splash.png +0 -0
  35. celldetective/icons/splash0.png +0 -0
  36. celldetective/icons/survival2.png +0 -0
  37. celldetective/icons/vignette_signals2.png +0 -0
  38. celldetective/icons/vignette_signals2.svg +114 -0
  39. celldetective/io.py +2050 -0
  40. celldetective/links/zenodo.json +561 -0
  41. celldetective/measure.py +1258 -0
  42. celldetective/models/segmentation_effectors/blank +0 -0
  43. celldetective/models/segmentation_generic/blank +0 -0
  44. celldetective/models/segmentation_targets/blank +0 -0
  45. celldetective/models/signal_detection/blank +0 -0
  46. celldetective/models/tracking_configs/mcf7.json +68 -0
  47. celldetective/models/tracking_configs/ricm.json +203 -0
  48. celldetective/models/tracking_configs/ricm2.json +203 -0
  49. celldetective/neighborhood.py +717 -0
  50. celldetective/scripts/analyze_signals.py +51 -0
  51. celldetective/scripts/measure_cells.py +275 -0
  52. celldetective/scripts/segment_cells.py +212 -0
  53. celldetective/scripts/segment_cells_thresholds.py +140 -0
  54. celldetective/scripts/track_cells.py +206 -0
  55. celldetective/scripts/train_segmentation_model.py +246 -0
  56. celldetective/scripts/train_signal_model.py +49 -0
  57. celldetective/segmentation.py +712 -0
  58. celldetective/signals.py +2826 -0
  59. celldetective/tracking.py +974 -0
  60. celldetective/utils.py +1681 -0
  61. celldetective-1.0.2.dist-info/LICENSE +674 -0
  62. celldetective-1.0.2.dist-info/METADATA +192 -0
  63. celldetective-1.0.2.dist-info/RECORD +66 -0
  64. celldetective-1.0.2.dist-info/WHEEL +5 -0
  65. celldetective-1.0.2.dist-info/entry_points.txt +2 -0
  66. celldetective-1.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2826 @@
import numpy as np
import os
import subprocess
import json

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau, CSVLogger, Callback
from tensorflow.keras.losses import CategoricalCrossentropy, MeanSquaredError, MeanAbsoluteError
from tensorflow.keras.metrics import Precision, Recall
from tensorflow.keras.models import load_model, clone_model
from tensorflow.config.experimental import list_physical_devices, set_memory_growth
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv1D, BatchNormalization, Dense, Activation, Add, MaxPooling1D, Dropout, GlobalAveragePooling1D, Concatenate, ZeroPadding1D, Flatten
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import jaccard_score, balanced_accuracy_score, precision_score, recall_score
from scipy.interpolate import interp1d
from scipy.ndimage import shift

from celldetective.io import get_signal_models_list, locate_signal_model
from celldetective.tracking import clean_trajectories
from celldetective.utils import regression_plot, train_test_split, compute_weights
import matplotlib.pyplot as plt
from natsort import natsorted
from glob import glob
import shutil
import random
from celldetective.utils import color_from_status, color_from_class
from math import floor, ceil
from scipy.optimize import curve_fit
import time
import math
import pandas as pd

abs_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], 'celldetective'])

class TimeHistory(Callback):

    """
    A custom Keras callback to log the duration of each epoch during training.

    This callback records the time taken for each epoch during the model training process, allowing for
    monitoring of training efficiency and performance over time. The times are stored in a list, with each
    element representing the duration of an epoch in seconds.

    Attributes
    ----------
    times : list
        A list of times (in seconds) taken for each epoch during the training. This list is populated as the
        training progresses.

    Methods
    -------
    on_train_begin(logs={})
        Initializes the list of times at the beginning of training.

    on_epoch_begin(epoch, logs={})
        Records the start time of the current epoch.

    on_epoch_end(epoch, logs={})
        Calculates and appends the duration of the current epoch to the `times` list.

    Notes
    -----
    - This callback is intended to be used with the `fit` method of Keras models.
    - The time measurements are made using the `time.time()` function, which provides wall-clock time.

    Examples
    --------
    >>> from keras.models import Sequential
    >>> from keras.layers import Dense
    >>> model = Sequential([Dense(10, activation='relu', input_shape=(20,)), Dense(1)])
    >>> time_callback = TimeHistory()
    >>> model.compile(optimizer='adam', loss='mean_squared_error')
    >>> model.fit(x_train, y_train, epochs=10, callbacks=[time_callback])
    >>> print(time_callback.times)
    # This will print the time taken for each epoch during the training.

    """

    def on_train_begin(self, logs={}):
        self.times = []

    def on_epoch_begin(self, epoch, logs={}):
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs={}):
        self.times.append(time.time() - self.epoch_time_start)

def analyze_signals(trajectories, model, interpolate_na=True,
                    selected_signals=None,
                    column_labels={'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'},
                    plot_outcome=False, output_dir=None):

    """
    Analyzes signals from trajectory data using a specified signal detection model and configuration.

    This function preprocesses trajectory data, selects specified signals, and applies a pretrained signal detection
    model to predict classes and times of interest for each trajectory. It supports custom column labeling, interpolation
    of missing values, and plotting of analysis outcomes.

    Parameters
    ----------
    trajectories : pandas.DataFrame
        DataFrame containing trajectory data with columns for track ID, frame, position, and signals.
    model : str
        The name of the signal detection model to be used for analysis.
    interpolate_na : bool, optional
        Whether to interpolate missing values in the trajectories (default is True).
    selected_signals : list of str, optional
        A list of column names from `trajectories` representing the signals to be analyzed. If None, signals will
        be automatically selected based on the model configuration (default is None).
    column_labels : dict, optional
        A dictionary mapping the default column names ('track', 'time', 'x', 'y') to the corresponding column names
        in `trajectories` (default is {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}).
    plot_outcome : bool, optional
        If True, generates and saves a plot of the signal analysis outcome (default is False).
    output_dir : str, optional
        The directory where the outcome plot will be saved. Required if `plot_outcome` is True (default is None).

    Returns
    -------
    pandas.DataFrame
        The input `trajectories` DataFrame with additional columns for predicted classes, times of interest, and
        corresponding colors based on status and class.

    Raises
    ------
    AssertionError
        If the model or its configuration file cannot be located.

    Notes
    -----
    - The function relies on an external model configuration file (`config_input.json`) located in the model's directory.
    - Signal selection and preprocessing are based on the requirements specified in the model's configuration.

    """

    model_path = locate_signal_model(model)
    complete_path = rf"{model_path}"
    model_config_path = rf"{os.sep.join([complete_path, 'config_input.json'])}"
    assert os.path.exists(complete_path), f'Model {model} could not be located in folder {model_path}... Abort.'
    assert os.path.exists(model_config_path), f'Model configuration could not be located in folder {model_path}... Abort.'

    available_signals = list(trajectories.columns)
    print('The available signals are: ', available_signals)

    with open(model_config_path) as f:
        config = json.load(f)
    required_signals = config["channels"]

    label = config.get('label')
    if label == '':
        label = None

    if selected_signals is None:
        selected_signals = []
        for s in required_signals:
            pattern_test = [s in a or s == a for a in available_signals]
            print(f'Pattern test for signal {s}: ', pattern_test)
            assert np.any(pattern_test), f'No signal matches with the requirements of the model {required_signals}. Please pass the signals manually with the argument selected_signals or add measurements. Abort.'
            valid_columns = np.array(available_signals)[np.array(pattern_test)]
            if len(valid_columns) == 1:
                selected_signals.append(valid_columns[0])
            else:
                print(f'Found several candidate signals: {valid_columns}')
                for vc in natsorted(valid_columns):
                    if 'circle' in vc:
                        selected_signals.append(vc)
                        break
                else:
                    selected_signals.append(valid_columns[0])
    else:
        assert len(selected_signals) == len(required_signals), f'Mismatch between the number of required signals {required_signals} and the provided signals {selected_signals}... Abort.'

    print(f'The following channels will be passed to the model: {selected_signals}')
    trajectories_clean = clean_trajectories(trajectories, interpolate_na=interpolate_na, interpolate_position_gaps=interpolate_na, column_labels=column_labels)

    max_signal_size = int(trajectories_clean[column_labels['time']].max()) + 2
    tracks = trajectories_clean[column_labels['track']].unique()
    signals = np.zeros((len(tracks), max_signal_size, len(selected_signals)))

    for i, (tid, group) in enumerate(trajectories_clean.groupby(column_labels['track'])):
        frames = group[column_labels['time']].to_numpy().astype(int)
        for j, col in enumerate(selected_signals):
            signals[i, frames, j] = group[col].to_numpy()

    model = SignalDetectionModel(pretrained=complete_path)
    print('signal shape: ', signals.shape)

    classes = model.predict_class(signals)
    times_recast = model.predict_time_of_interest(signals)

    if label is None:
        class_col = 'class'
        time_col = 't0'
        status_col = 'status'
    else:
        class_col = 'class_' + label
        time_col = 't_' + label
        status_col = 'status_' + label

    for i, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
        indices = group.index
        trajectories.loc[indices, class_col] = classes[i]
        trajectories.loc[indices, time_col] = times_recast[i]
    print('Done.')

    for tid, group in trajectories.groupby(column_labels['track']):

        indices = group.index
        t0 = group[time_col].to_numpy()[0]
        cclass = group[class_col].to_numpy()[0]
        timeline = group[column_labels['time']].to_numpy()
        status = np.zeros_like(timeline)
        if t0 > 0:
            status[timeline >= t0] = 1.
        if cclass == 2:
            status[:] = 2
        if cclass > 2:
            status[:] = 42
        status_color = [color_from_status(s) for s in status]
        class_color = [color_from_class(cclass) for _ in range(len(status))]

        trajectories.loc[indices, status_col] = status
        trajectories.loc[indices, 'status_color'] = status_color
        trajectories.loc[indices, 'class_color'] = class_color

    if plot_outcome:
        fig, ax = plt.subplots(1, len(selected_signals), figsize=(10, 5))
        for i, s in enumerate(selected_signals):
            for k, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
                cclass = group[class_col].to_numpy()[0]
                t0 = group[time_col].to_numpy()[0]
                timeline = group[column_labels['time']].to_numpy()
                if cclass == 0:
                    if len(selected_signals) > 1:
                        ax[i].plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
                    else:
                        ax.plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
        if len(selected_signals) > 1:
            for a, s in zip(ax, selected_signals):
                a.set_title(s)
                a.set_xlabel(r'time - t$_0$ [frame]')
                a.spines['top'].set_visible(False)
                a.spines['right'].set_visible(False)
        else:
            ax.set_title(s)
            ax.set_xlabel(r'time - t$_0$ [frame]')
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
        plt.tight_layout()
        if output_dir is not None:
            plt.savefig(output_dir + 'signal_collapse.png', bbox_inches='tight', dpi=300)
        plt.pause(3)
        plt.close()

    return trajectories

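# --- Usage sketch (editorial addition, not part of the released file). A minimal,
# hypothetical call to analyze_signals: the model name 'lysis_model' and the CSV
# path are placeholders, and the table is assumed to use the default
# TrackMate-style column labels shown above.
def _example_analyze_signals():
    trajectories = pd.read_csv('output/tables/trajectories_targets.csv')
    trajectories = analyze_signals(trajectories, 'lysis_model',
                                   interpolate_na=True,
                                   plot_outcome=False)
    # New columns such as 'class', 't0' and 'status' (or their '*_<label>'
    # variants) are appended in place and can be written back to disk.
    trajectories.to_csv('output/tables/trajectories_targets.csv', index=False)
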
def analyze_signals_at_position(pos, model, mode, use_gpu=True, return_table=False):

    """
    Analyzes signals for a given position directory using a specified model and mode, with an option to use GPU acceleration.

    This function executes an external Python script to analyze signals within the specified position directory, applying
    a predefined model in a specified mode. It supports GPU acceleration for faster processing. Optionally, the function
    can return the resulting analysis table as a pandas DataFrame.

    Parameters
    ----------
    pos : str
        The file path to the position directory containing the data to be analyzed. The path must be valid and accessible.
    model : str
        The name of the model to use for signal analysis.
    mode : str
        The operation mode specifying how the analysis should be conducted.
    use_gpu : bool, optional
        Specifies whether to use GPU acceleration for the analysis (default is True).
    return_table : bool, optional
        If True, the function returns a pandas DataFrame containing the analysis results (default is False).

    Returns
    -------
    pandas.DataFrame or None
        If `return_table` is True, returns a DataFrame containing the analysis results. Otherwise, returns None.

    Raises
    ------
    AssertionError
        If the specified position path does not exist.

    Notes
    -----
    - The analysis is performed by an external script (`analyze_signals.py`) located in a specific directory relative
      to this function.
    - The results of the analysis are expected to be saved in the "output/tables" subdirectory within the position
      directory, following a naming convention based on the analysis `mode`.

    """

    pos = pos.replace('\\', '/')
    assert os.path.exists(pos), f'Position {pos} is not a valid path.'
    if not pos.endswith('/'):
        pos += '/'

    script_path = os.sep.join([abs_path, 'scripts', 'analyze_signals.py'])
    cmd = f'python "{script_path}" --pos "{pos}" --model "{model}" --mode "{mode}" --use_gpu "{use_gpu}"'
    subprocess.call(cmd, shell=True)

    table = pos + os.sep.join(["output", "tables", f"trajectories_{mode}.csv"])
    if return_table:
        return pd.read_csv(table)
    return None

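# --- Usage sketch (editorial addition). Runs the headless analysis script on one
# position folder; the path below is hypothetical and must point at an existing
# celldetective position directory containing an output/tables/ subfolder.
def _example_analyze_signals_at_position():
    df = analyze_signals_at_position('/data/experiment/W1/100/', 'lysis_model',
                                     'targets', use_gpu=True, return_table=True)
    print(df.head())
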
class SignalDetectionModel(object):

    """
    A class for creating and managing signal detection models for analyzing biological signals.

    This class provides functionalities to load a pretrained signal detection model or create one from scratch,
    preprocess input signals, train the model, and make predictions on new data.

    Parameters
    ----------
    path : str, optional
        Path to the directory containing the model and its configuration. This is used when loading a pretrained model.
    pretrained : str, optional
        Path to the pretrained model to load. If specified, the model and its configuration are loaded from this path.
    channel_option : list of str, optional
        Specifies the channels to be used for signal analysis. Default is ["live_nuclei_channel"].
    model_signal_length : int, optional
        The length of the input signals that the model expects. Default is 128.
    n_channels : int, optional
        The number of channels in the input signals. Default is 1.
    n_conv : int, optional
        The number of convolutional layers in the model. Default is 2.
    n_classes : int, optional
        The number of classes for the classification task. Default is 3.
    dense_collection : int, optional
        The number of units in the dense layer of the model. Default is 512.
    dropout_rate : float, optional
        The dropout rate applied to the dense layer of the model. Default is 0.1.
    label : str, optional
        A label for the model, used in naming and organizing outputs. Default is ''.

    Attributes
    ----------
    model_class : keras Model
        The classification model for predicting the class of signals.
    model_reg : keras Model
        The regression model for predicting the time of interest for signals.

    Methods
    -------
    load_pretrained_model()
        Loads the model and its configuration from the pretrained path.
    create_models_from_scratch()
        Creates new models for classification and regression from scratch.
    prep_gpu()
        Prepares GPU devices for training, if available.
    fit_from_directory(ds_folders, ...)
        Trains the model using data from specified directories.
    fit(x_train, y_time_train, y_class_train, ...)
        Trains the model using provided datasets.
    predict_class(x, ...)
        Predicts the class of input signals.
    predict_time_of_interest(x, ...)
        Predicts the time of interest for input signals.
    plot_model_history(mode)
        Plots the training history for the specified mode (classifier or regressor).
    evaluate_regression_model()
        Evaluates the regression model on test and validation data.
    gather_callbacks(mode)
        Gathers and prepares callbacks for training based on the specified mode.
    generate_sets()
        Generates training, validation, and test sets from loaded data.
    augment_training_set()
        Augments the training set with additional generated data.
    load_and_normalize(subset)
        Loads and normalizes signals from a subset of data.

    Notes
    -----
    - This class is designed to work with biological signal data, such as time series from microscopy imaging.
    - The model architecture and training configurations can be customized through the class parameters and methods.

    """

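    # Usage sketch (editorial addition): the two typical entry points are loading a
    # pretrained model folder, e.g.
    #     model = SignalDetectionModel(pretrained='/path/to/model_folder')
    # or building one from scratch before training, e.g.
    #     model = SignalDetectionModel(n_channels=2, model_signal_length=128, n_classes=3)
    # Paths and parameter values here are placeholders.
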
    def __init__(self, path=None, pretrained=None, channel_option=["live_nuclei_channel"], model_signal_length=128, n_channels=1,
                 n_conv=2, n_classes=3, dense_collection=512, dropout_rate=0.1, label=''):

        # NB: `path` is currently unused; `pretrained` selects the model folder.
        self.prep_gpu()

        self.model_signal_length = model_signal_length
        self.channel_option = channel_option
        self.pretrained = pretrained
        self.n_channels = n_channels
        self.n_conv = n_conv
        self.n_classes = n_classes
        self.dense_collection = dense_collection
        self.dropout_rate = dropout_rate
        self.label = label

        if self.pretrained is not None:
            print(f"Load pretrained models from {self.pretrained}...")
            self.load_pretrained_model()
        else:
            print("Create models from scratch...")
            self.create_models_from_scratch()

    def load_pretrained_model(self):

        """
        Loads a pretrained model and its configuration from the specified path.

        This method attempts to load both the classification and regression models from the path specified during the
        class instantiation. It also loads the model configuration from a JSON file and updates the model attributes
        accordingly. If the models cannot be loaded, an error message is printed.

        Raises
        ------
        Exception
            If there is an error loading the model or the configuration file, an exception is raised with details.

        Notes
        -----
        - The models are expected to be saved in .h5 format with the filenames "classifier.h5" and "regressor.h5".
        - The configuration file is expected to be named "config_input.json" and located in the same directory as the models.
        """

        try:
            self.model_class = load_model(os.sep.join([self.pretrained, "classifier.h5"]), compile=False)
            self.model_class.load_weights(os.sep.join([self.pretrained, "classifier.h5"]))
            print("Classifier successfully loaded...")
        except Exception as e:
            print(f"Error {e}...")
            self.model_class = None
        try:
            self.model_reg = load_model(os.sep.join([self.pretrained, "regressor.h5"]), compile=False)
            self.model_reg.load_weights(os.sep.join([self.pretrained, "regressor.h5"]))
            print("Regressor successfully loaded...")
        except Exception as e:
            print(f"Error {e}...")
            self.model_reg = None

        # load config
        with open(os.sep.join([self.pretrained, "config_input.json"])) as config_file:
            model_config = json.load(config_file)

        req_channels = model_config["channels"]
        print(f"Required channels read from pretrained model: {req_channels}")
        self.channel_option = req_channels
        if 'normalize' in model_config:
            self.normalize = model_config['normalize']
        if 'normalization_percentile' in model_config:
            self.normalization_percentile = model_config['normalization_percentile']
        if 'normalization_values' in model_config:
            self.normalization_values = model_config['normalization_values']
        if 'normalization_clip' in model_config:
            self.normalization_clip = model_config['normalization_clip']
        if 'label' in model_config:
            self.label = model_config['label']

        self.n_channels = self.model_class.layers[0].input_shape[0][-1]
        self.model_signal_length = self.model_class.layers[0].input_shape[0][-2]
        self.n_classes = self.model_class.layers[-1].output_shape[-1]

        assert self.model_class.layers[0].input_shape[0] == self.model_reg.layers[0].input_shape[0], f"Mismatch between the input shape of the classification model ({self.model_class.layers[0].input_shape[0]}) and of the regression model ({self.model_reg.layers[0].input_shape[0]})... Error."

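    # Expected layout of a pretrained model folder, as read by the method above
    # (an editorial summary of the docstring, not code from the release):
    #     <model_folder>/classifier.h5       # classification network
    #     <model_folder>/regressor.h5        # event-time regression network
    #     <model_folder>/config_input.json   # channels, signal length, normalization, label
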
    def create_models_from_scratch(self):

        """
        Initializes new models for classification and regression based on the specified parameters.

        This method creates new ResNet models for both classification and regression tasks using the parameters specified
        during class instantiation. The models are configured but not compiled or trained.

        Notes
        -----
        - The models are created using a custom ResNet architecture defined elsewhere in the codebase.
        - The models are stored in the `model_class` and `model_reg` attributes of the class.
        """

        self.model_class = ResNetModelCurrent(n_channels=self.n_channels,
                                              n_slices=self.n_conv,
                                              n_classes=self.n_classes,
                                              dense_collection=self.dense_collection,
                                              dropout_rate=self.dropout_rate,
                                              header="classifier",
                                              model_signal_length=self.model_signal_length,
                                              )

        self.model_reg = ResNetModelCurrent(n_channels=self.n_channels,
                                            n_slices=self.n_conv,
                                            n_classes=self.n_classes,
                                            dense_collection=self.dense_collection,
                                            dropout_rate=self.dropout_rate,
                                            header="regressor",
                                            model_signal_length=self.model_signal_length,
                                            )

    def prep_gpu(self):

        """
        Prepares GPU devices for training by enabling memory growth.

        This method attempts to identify available GPU devices and configures TensorFlow to allow memory growth on each
        GPU. This prevents TensorFlow from allocating the total available memory on the GPU device upfront.

        Notes
        -----
        - This method should be called before any TensorFlow/Keras operations that might allocate GPU memory.
        - If no GPUs are detected, the method will pass silently.
        """

        try:
            physical_devices = list_physical_devices('GPU')
            for gpu in physical_devices:
                set_memory_growth(gpu, True)
        except Exception:
            pass

    def fit_from_directory(self, ds_folders, normalize=True, normalization_percentile=None, normalization_values=None,
                           normalization_clip=None, channel_option=["live_nuclei_channel"], model_name=None, target_directory=None,
                           augment=True, augmentation_factor=2, validation_split=0.20, test_split=0.0, batch_size=64, epochs=300,
                           recompile_pretrained=False, learning_rate=0.01, loss_reg="mse", loss_class=CategoricalCrossentropy(from_logits=False)):

        """
        Trains the model using data from specified directories.

        This method prepares the dataset for training by loading and preprocessing data from specified directories,
        then trains the classification and regression models.

        Parameters
        ----------
        ds_folders : list of str
            List of directories containing the dataset files for training.
        normalize : bool, optional
            Whether to normalize the input signals (default is True).
        normalization_percentile : list or None, optional
            Percentiles for signal normalization (default is None).
        normalization_values : list or None, optional
            Specific values for signal normalization (default is None).
        normalization_clip : bool, optional
            Whether to clip the normalized signals (default is None).
        channel_option : list of str, optional
            Specifies the channels to be used for signal analysis (default is ["live_nuclei_channel"]).
        model_name : str, optional
            Name of the model for saving purposes (default is None).
        target_directory : str, optional
            Directory where the trained model and outputs will be saved (default is None).
        augment : bool, optional
            Whether to augment the training data (default is True).
        augmentation_factor : int, optional
            Factor by which to augment the training data (default is 2).
        validation_split : float, optional
            Fraction of the data to be used as validation set (default is 0.20).
        test_split : float, optional
            Fraction of the data to be used as test set (default is 0.0).
        batch_size : int, optional
            Batch size for training (default is 64).
        epochs : int, optional
            Number of epochs to train for (default is 300).
        recompile_pretrained : bool, optional
            Whether to recompile a pretrained model (default is False).
        learning_rate : float, optional
            Learning rate for the optimizer (default is 0.01).
        loss_reg : str or keras.losses.Loss, optional
            Loss function for the regression model (default is "mse").
        loss_class : str or keras.losses.Loss, optional
            Loss function for the classification model (default is CategoricalCrossentropy(from_logits=False)).

        Notes
        -----
        - The method automatically splits the dataset into training, validation, and test sets according to the specified splits.
        """

        if not hasattr(self, 'normalization_percentile'):
            self.normalization_percentile = normalization_percentile
        if not hasattr(self, 'normalization_values'):
            self.normalization_values = normalization_values
        if not hasattr(self, 'normalization_clip'):
            self.normalization_clip = normalization_clip
        print('Actual clip option:', self.normalization_clip)

        self.normalize = normalize
        self.normalization_percentile, self.normalization_values, self.normalization_clip = _interpret_normalization_parameters(self.n_channels, self.normalization_percentile, self.normalization_values, self.normalization_clip)

        self.ds_folders = [rf'{d}' for d in ds_folders]
        self.batch_size = batch_size
        self.epochs = epochs
        self.validation_split = validation_split
        self.test_split = test_split
        self.augment = augment
        self.augmentation_factor = augmentation_factor
        self.model_name = rf'{model_name}'
        self.target_directory = rf'{target_directory}'
        self.model_folder = os.sep.join([self.target_directory, self.model_name])
        self.recompile_pretrained = recompile_pretrained
        self.learning_rate = learning_rate
        self.loss_reg = loss_reg
        self.loss_class = loss_class

        if not os.path.exists(self.model_folder):
            os.mkdir(self.model_folder)

        self.channel_option = channel_option
        assert self.n_channels == len(self.channel_option), 'Mismatch between the channel option and the number of channels of the model...'

        self.list_of_sets = []
        print(self.ds_folders)
        for f in self.ds_folders:
            self.list_of_sets.extend(glob(os.sep.join([f, "*.npy"])))
        print(f"Found {len(self.list_of_sets)} annotation files...")
        self.generate_sets()

        self.train_classifier()
        self.train_regressor()

        config_input = {"channels": self.channel_option, "model_signal_length": self.model_signal_length, 'label': self.label, 'normalize': self.normalize, 'normalization_percentile': self.normalization_percentile, 'normalization_values': self.normalization_values, 'normalization_clip': self.normalization_clip}
        json_string = json.dumps(config_input)
        with open(os.sep.join([self.model_folder, "config_input.json"]), 'w') as outfile:
            outfile.write(json_string)

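    # Usage sketch (editorial addition): a minimal training run from annotation
    # folders. The folder and model names are placeholders; each folder is expected
    # to contain .npy annotation files as described in generate_sets().
    #     model = SignalDetectionModel(n_channels=1, model_signal_length=128, n_classes=3)
    #     model.fit_from_directory(['/data/annotations/'],
    #                              channel_option=['live_nuclei_channel'],
    #                              model_name='my_signal_model',
    #                              target_directory='/data/models/',
    #                              epochs=300, batch_size=64)
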
    def fit(self, x_train, y_time_train, y_class_train, normalize=True, normalization_percentile=None, normalization_values=None, normalization_clip=None, pad=True, validation_data=None, test_data=None, channel_option=["live_nuclei_channel", "dead_nuclei_channel"], model_name=None,
            target_directory=None, augment=True, augmentation_factor=3, validation_split=0.25, batch_size=64, epochs=300,
            recompile_pretrained=False, learning_rate=0.001, loss_reg="mse", loss_class=CategoricalCrossentropy(from_logits=False)):

        """
        Trains the model using provided datasets.

        Parameters
        ----------
        Same as `fit_from_directory`, but instead of loading data from directories, this method accepts preloaded and
        optionally preprocessed datasets directly.

        Notes
        -----
        - This method provides an alternative way to train the model when data is already loaded into memory, offering
          flexibility for data preprocessing steps outside this class.
        """

        self.normalize = normalize
        if not hasattr(self, 'normalization_percentile'):
            self.normalization_percentile = normalization_percentile
        if not hasattr(self, 'normalization_values'):
            self.normalization_values = normalization_values
        if not hasattr(self, 'normalization_clip'):
            self.normalization_clip = normalization_clip
        self.normalization_percentile, self.normalization_values, self.normalization_clip = _interpret_normalization_parameters(self.n_channels, self.normalization_percentile, self.normalization_values, self.normalization_clip)

        self.x_train = x_train
        self.y_class_train = y_class_train
        self.y_time_train = y_time_train
        self.channel_option = channel_option

        assert self.n_channels == len(self.channel_option), 'Mismatch between the channel option and the number of channels of the model...'

        if pad:
            self.x_train = pad_to_model_length(self.x_train, self.model_signal_length)

        assert self.x_train.shape[1:] == (self.model_signal_length, self.n_channels), "Shape mismatch between the provided training fluorescence signals and the model..."

        # If y-class is not one-hot encoded, encode it
        if self.y_class_train.shape[-1] != self.n_classes:
            self.class_weights = compute_weights(self.y_class_train)
            self.y_class_train = to_categorical(self.y_class_train)

        if self.normalize:
            self.y_time_train = self.y_time_train.astype(np.float32) / self.model_signal_length
            self.x_train = normalize_signal_set(self.x_train, self.channel_option, normalization_percentile=self.normalization_percentile,
                                                normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
                                                )

        if validation_data is not None:
            try:
                self.x_val = validation_data[0]
                if pad:
                    self.x_val = pad_to_model_length(self.x_val, self.model_signal_length)
                self.y_class_val = validation_data[1]
                if self.y_class_val.shape[-1] != self.n_classes:
                    self.y_class_val = to_categorical(self.y_class_val)
                self.y_time_val = validation_data[2]
                if self.normalize:
                    self.y_time_val = self.y_time_val.astype(np.float32) / self.model_signal_length
                    self.x_val = normalize_signal_set(self.x_val, self.channel_option, normalization_percentile=self.normalization_percentile,
                                                      normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
                                                      )
            except Exception as e:
                print(f"Could not load validation data, error {e}...")
        else:
            self.validation_split = validation_split

        if test_data is not None:
            try:
                self.x_test = test_data[0]
                if pad:
                    self.x_test = pad_to_model_length(self.x_test, self.model_signal_length)
                self.y_class_test = test_data[1]
                if self.y_class_test.shape[-1] != self.n_classes:
                    self.y_class_test = to_categorical(self.y_class_test)
                self.y_time_test = test_data[2]
                if self.normalize:
                    self.y_time_test = self.y_time_test.astype(np.float32) / self.model_signal_length
                    self.x_test = normalize_signal_set(self.x_test, self.channel_option, normalization_percentile=self.normalization_percentile,
                                                       normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
                                                       )
            except Exception as e:
                print(f"Could not load test data, error {e}...")

        self.batch_size = batch_size
        self.epochs = epochs
        self.augment = augment
        self.augmentation_factor = augmentation_factor
        if self.augmentation_factor == 1:
            self.augment = False
        self.model_name = model_name
        self.target_directory = target_directory
        self.model_folder = os.sep.join([self.target_directory, self.model_name])
        self.recompile_pretrained = recompile_pretrained
        self.learning_rate = learning_rate
        self.loss_reg = loss_reg
        self.loss_class = loss_class

        if os.path.exists(self.model_folder):
            shutil.rmtree(self.model_folder)
        os.mkdir(self.model_folder)

        self.train_classifier()
        self.train_regressor()

    def predict_class(self, x, normalize=True, pad=True, return_one_hot=False, interpolate=True):

        """
        Predicts the class of input signals using the trained classification model.

        Parameters
        ----------
        x : ndarray
            The input signals for which to predict classes.
        normalize : bool, optional
            Whether to normalize the input signals (default is True).
        pad : bool, optional
            Whether to pad the input signals to match the model's expected signal length (default is True).
        return_one_hot : bool, optional
            Whether to return predictions in one-hot encoded format (default is False).
        interpolate : bool, optional
            Whether to interpolate the input signals (default is True).

        Returns
        -------
        ndarray
            The predicted classes for the input signals. If `return_one_hot` is True, predictions are returned in one-hot
            encoded format, otherwise as integer labels.

        Notes
        -----
        - The method processes the input signals according to the specified options to ensure compatibility with the model's
          input requirements.
        """

        self.x = np.copy(x)
        self.normalize = normalize
        self.pad = pad
        self.return_one_hot = return_one_hot

        if self.pad:
            self.x = pad_to_model_length(self.x, self.model_signal_length)

        if self.normalize:
            self.x = normalize_signal_set(self.x, self.channel_option, normalization_percentile=self.normalization_percentile,
                                          normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
                                          )

        # TODO: automatic interpolation of missing values could be applied here
        # (see interpolate_signals below); the `interpolate` flag is reserved for it.

        assert self.x.shape[-1] == self.model_class.layers[0].input_shape[0][-1], "Shape mismatch between the input shape and the model input shape..."
        assert self.x.shape[-2] == self.model_class.layers[0].input_shape[0][-2], "Shape mismatch between the input shape and the model input shape..."

        self.class_predictions_one_hot = self.model_class.predict(self.x)
        self.class_predictions = self.class_predictions_one_hot.argmax(axis=1)

        if self.return_one_hot:
            return self.class_predictions_one_hot
        else:
            return self.class_predictions

    def predict_time_of_interest(self, x, class_predictions=None, normalize=True, pad=True):

        """
        Predicts the time of interest for input signals using the trained regression model.

        Parameters
        ----------
        x : ndarray
            The input signals for which to predict times of interest.
        class_predictions : ndarray, optional
            The predicted classes for the input signals. If provided, time of interest predictions are only made for
            signals predicted to belong to a specific class (default is None).
        normalize : bool, optional
            Whether to normalize the input signals (default is True).
        pad : bool, optional
            Whether to pad the input signals to match the model's expected signal length (default is True).

        Returns
        -------
        ndarray
            The predicted times of interest for the input signals.

        Notes
        -----
        - The method processes the input signals according to the specified options and uses the regression model to
          predict times at which a particular event of interest occurs.
        """

        self.x = np.copy(x)
        self.normalize = normalize
        self.pad = pad

        if class_predictions is not None:
            self.class_predictions = class_predictions
        # otherwise, self.class_predictions is expected to have been set by a
        # previous call to predict_class()

        if self.pad:
            self.x = pad_to_model_length(self.x, self.model_signal_length)

        if self.normalize:
            self.x = normalize_signal_set(self.x, self.channel_option, normalization_percentile=self.normalization_percentile,
                                          normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
                                          )

        assert self.x.shape[-1] == self.model_reg.layers[0].input_shape[0][-1], "Shape mismatch between the input shape and the model input shape..."
        assert self.x.shape[-2] == self.model_reg.layers[0].input_shape[0][-2], "Shape mismatch between the input shape and the model input shape..."

        if np.any(self.class_predictions == 0):
            self.time_predictions = self.model_reg.predict(self.x[self.class_predictions == 0]) * self.model_signal_length
            self.time_predictions = self.time_predictions[:, 0]
            self.time_predictions_recast = np.zeros(len(self.x)) - 1.
            self.time_predictions_recast[self.class_predictions == 0] = self.time_predictions
        else:
            self.time_predictions_recast = np.zeros(len(self.x)) - 1.
        return self.time_predictions_recast

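    # Usage sketch (editorial addition): class prediction followed by event-time
    # regression, mirroring the chaining done in analyze_signals() above. The
    # signals array has shape (n_cells, signal_length, n_channels); cells not in
    # class 0 receive a time of -1.
    #     classes = model.predict_class(signals)
    #     times = model.predict_time_of_interest(signals, class_predictions=classes)
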
    def interpolate_signals(self, x_set):

        """
        Interpolates missing values in the input signal set.

        Parameters
        ----------
        x_set : ndarray
            The input signal set with potentially missing values.

        Returns
        -------
        ndarray
            The input signal set with missing values interpolated.

        Notes
        -----
        - This method is useful for preparing signals that have gaps or missing time points before further processing
          or model training.
        """

        for i in range(len(x_set)):
            for k in range(x_set.shape[-1]):
                x = x_set[i, :, k]
                not_nan = np.logical_not(np.isnan(x))
                indices = np.arange(len(x))
                interp = interp1d(indices[not_nan], x[not_nan], fill_value=(0., 0.), bounds_error=False)
                x_set[i, :, k] = interp(indices)
        return x_set

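    # Worked example (editorial addition): for a single-channel signal
    # [1.0, nan, 3.0, nan], the linear interpolation above yields
    # [1.0, 2.0, 3.0, 0.0] -- interior gaps are filled linearly, while
    # out-of-range points fall back to the fill_value of 0.
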
    def train_classifier(self):

        """
        Trains the classifier component of the model to predict event classes in signals.

        This method compiles the classifier model (if not pretrained or if recompilation is requested) and
        trains it on the prepared dataset. The training process includes validation and early stopping based
        on precision to prevent overfitting.

        Notes
        -----
        - The classifier model predicts the class of each signal, such as live, dead, or miscellaneous.
        - Training parameters such as epochs, batch size, and learning rate are specified during class instantiation.
        - Model performance metrics and training history are saved for analysis.
        """

        # if pretrained model
        if self.pretrained is not None:
            # if recompile
            if self.recompile_pretrained:
                print('Recompiling the pretrained classifier model... Warning, this action reinitializes all the weights; are you sure that this is what you intended?')
                self.model_class.set_weights(clone_model(self.model_class).get_weights())
                self.model_class.compile(optimizer=Adam(learning_rate=self.learning_rate),
                                         loss=self.loss_class,
                                         metrics=['accuracy', Precision(), Recall()])
            else:
                self.initial_model = clone_model(self.model_class)
                self.model_class.set_weights(self.initial_model.get_weights())
                # Recompile to avoid crash
                self.model_class.compile(optimizer=Adam(learning_rate=self.learning_rate),
                                         loss=self.loss_class,
                                         metrics=['accuracy', Precision(), Recall()])
                # Reset weights
                self.model_class.set_weights(self.initial_model.get_weights())
        else:
            print("Compiling the classifier...")
            self.model_class.compile(optimizer=Adam(learning_rate=self.learning_rate),
                                     loss=self.loss_class,
                                     metrics=['accuracy', Precision(), Recall()])

        self.gather_callbacks("classifier")

        if hasattr(self, 'x_val'):
            self.history_classifier = self.model_class.fit(x=self.x_train,
                                                           y=self.y_class_train,
                                                           batch_size=self.batch_size,
                                                           class_weight=self.class_weights,
                                                           epochs=self.epochs,
                                                           validation_data=(self.x_val, self.y_class_val),
                                                           callbacks=self.cb,
                                                           verbose=1)
        else:
            self.history_classifier = self.model_class.fit(x=self.x_train,
                                                           y=self.y_class_train,
                                                           batch_size=self.batch_size,
                                                           class_weight=self.class_weights,
                                                           epochs=self.epochs,
                                                           callbacks=self.cb,
                                                           validation_split=self.validation_split,
                                                           verbose=1)

        self.plot_model_history(mode="classifier")

        # Set current classification model as the best model
        self.model_class = load_model(os.sep.join([self.model_folder, "classifier.h5"]))
        self.model_class.load_weights(os.sep.join([self.model_folder, "classifier.h5"]))

        self.dico = {"history_classifier": self.history_classifier, "execution_time_classifier": self.cb[-1].times}

        if hasattr(self, 'x_test'):

            predictions = self.model_class.predict(self.x_test).argmax(axis=1)
            ground_truth = self.y_class_test.argmax(axis=1)
            assert predictions.shape == ground_truth.shape, "Mismatch in shape between the predictions and the ground truth..."

            title = "Test data"
            IoU_score = jaccard_score(ground_truth, predictions, average=None)
            balanced_accuracy = balanced_accuracy_score(ground_truth, predictions)
            precision = precision_score(ground_truth, predictions, average=None)
            recall = recall_score(ground_truth, predictions, average=None)

            print(f"Test IoU score: {IoU_score}")
            print(f"Test Balanced accuracy score: {balanced_accuracy}")
            print(f'Test Precision: {precision}')
            print(f'Test Recall: {recall}')

            # Confusion matrix on test set
            results = confusion_matrix(ground_truth, predictions)
            self.dico.update({"test_IoU": IoU_score, "test_balanced_accuracy": balanced_accuracy, "test_confusion": results, 'test_precision': precision, 'test_recall': recall})

            try:
                plot_confusion_matrix(results, ["dead", "alive", "miscellaneous"], output_dir=self.model_folder + os.sep, title=title)
            except Exception:
                pass
            print("Test set: ", classification_report(ground_truth, predictions))

        if hasattr(self, 'x_val'):
            predictions = self.model_class.predict(self.x_val).argmax(axis=1)
            ground_truth = self.y_class_val.argmax(axis=1)
            assert ground_truth.shape == predictions.shape, "Mismatch in shape between the predictions and the ground truth..."
            title = "Validation data"

            # Validation scores
            IoU_score = jaccard_score(ground_truth, predictions, average=None)
            balanced_accuracy = balanced_accuracy_score(ground_truth, predictions)
            precision = precision_score(ground_truth, predictions, average=None)
            recall = recall_score(ground_truth, predictions, average=None)

            print(f"Validation IoU score: {IoU_score}")
            print(f"Validation Balanced accuracy score: {balanced_accuracy}")
            print(f'Validation Precision: {precision}')
            print(f'Validation Recall: {recall}')

            # Confusion matrix on validation set
            results = confusion_matrix(ground_truth, predictions)
            self.dico.update({"val_IoU": IoU_score, "val_balanced_accuracy": balanced_accuracy, "val_confusion": results, 'val_precision': precision, 'val_recall': recall})

            try:
                plot_confusion_matrix(results, ["dead", "alive", "miscellaneous"], output_dir=self.model_folder + os.sep, title=title)
            except Exception:
                pass
            print("Validation set: ", classification_report(ground_truth, predictions))

    def train_regressor(self):

        """
        Trains the regressor component of the model to estimate the time of interest for events in signals.

        This method compiles the regressor model (if not pretrained or if recompilation is requested) and
        trains it on a subset of the prepared dataset containing signals with events. The training process
        includes validation and early stopping based on mean squared error to prevent overfitting.

        Notes
        -----
        - The regressor model estimates the time at which an event of interest occurs within each signal.
        - Only signals predicted to have an event by the classifier model are used for regressor training.
        - Model performance metrics and training history are saved for analysis.
        """

        # Compile model
        # if pretrained model
        if self.pretrained is not None:
            # if recompile
            if self.recompile_pretrained:
                print('Recompiling the pretrained regressor model... Warning, this action reinitializes all the weights; are you sure that this is what you intended?')
                self.model_reg.set_weights(clone_model(self.model_reg).get_weights())
                self.model_reg.compile(optimizer=Adam(learning_rate=self.learning_rate),
                                       loss=self.loss_reg,
                                       metrics=['mse', 'mae'])
            else:
                self.initial_model = clone_model(self.model_reg)
                self.model_reg.set_weights(self.initial_model.get_weights())
                self.model_reg.compile(optimizer=Adam(learning_rate=self.learning_rate),
                                       loss=self.loss_reg,
                                       metrics=['mse', 'mae'])
                self.model_reg.set_weights(self.initial_model.get_weights())
        else:
            print("Compiling the regressor...")
            self.model_reg.compile(optimizer=Adam(learning_rate=self.learning_rate),
                                   loss=self.loss_reg,
                                   metrics=['mse', 'mae'])

        self.gather_callbacks("regressor")

        # Train on the subset of data with an event (class 0)
        if hasattr(self, 'x_val'):
            self.history_regressor = self.model_reg.fit(x=self.x_train[np.argmax(self.y_class_train, axis=1) == 0],
                                                        y=self.y_time_train[np.argmax(self.y_class_train, axis=1) == 0],
                                                        batch_size=self.batch_size,
                                                        epochs=self.epochs * 2,
                                                        validation_data=(self.x_val[np.argmax(self.y_class_val, axis=1) == 0], self.y_time_val[np.argmax(self.y_class_val, axis=1) == 0]),
                                                        callbacks=self.cb,
                                                        verbose=1)
        else:
            self.history_regressor = self.model_reg.fit(x=self.x_train[np.argmax(self.y_class_train, axis=1) == 0],
                                                        y=self.y_time_train[np.argmax(self.y_class_train, axis=1) == 0],
                                                        batch_size=self.batch_size,
                                                        epochs=self.epochs * 2,
                                                        callbacks=self.cb,
                                                        validation_split=self.validation_split,
                                                        verbose=1)

        self.plot_model_history(mode="regressor")
        self.dico.update({"history_regressor": self.history_regressor, "execution_time_regressor": self.cb[-1].times})

        # Evaluate best model
        self.model_reg = load_model(os.sep.join([self.model_folder, "regressor.h5"]))
        self.model_reg.load_weights(os.sep.join([self.model_folder, "regressor.h5"]))
        self.evaluate_regression_model()

        try:
            np.save(os.sep.join([self.model_folder, "scores.npy"]), self.dico)
        except Exception as e:
            print(e)

    def plot_model_history(self, mode="regressor"):

        """
        Generates and saves plots of the training history for the classifier or regressor model.

        Parameters
        ----------
        mode : str, optional
            Specifies which model's training history to plot. Options are "classifier" or "regressor". Default is "regressor".

        Notes
        -----
        - Plots include loss and accuracy metrics over epochs for the classifier, and loss metrics for the regressor.
        - The plots are saved as image files in the model's output directory.
        """

        if mode == "regressor":
            try:
                plt.plot(self.history_regressor.history['loss'])
                plt.plot(self.history_regressor.history['val_loss'])
                plt.title('model loss')
                plt.ylabel('loss')
                plt.xlabel('epoch')
                plt.yscale('log')
                plt.legend(['train', 'val'], loc='upper left')
                plt.pause(3)
                plt.savefig(os.sep.join([self.model_folder, "regression_loss.png"]), bbox_inches="tight", dpi=300)
                plt.close()
            except Exception as e:
                print(f"Error {e}; could not generate plot...")
        elif mode == "classifier":
            try:
                plt.plot(self.history_classifier.history['precision'])
                plt.plot(self.history_classifier.history['val_precision'])
                plt.title('model precision')
                plt.ylabel('precision')
                plt.xlabel('epoch')
                plt.legend(['train', 'val'], loc='upper left')
                plt.pause(3)
                plt.savefig(os.sep.join([self.model_folder, "classification_loss.png"]), bbox_inches="tight", dpi=300)
                plt.close()
            except Exception as e:
                print(f"Error {e}; could not generate plot...")
        else:
            return None

1174
+ def evaluate_regression_model(self):
1175
+
1176
+ """
1177
+ Evaluates the performance of the trained regression model on test and validation datasets.
1178
+
1179
+ This method calculates and prints mean squared error and mean absolute error metrics for the regression model's
1180
+ predictions. It also generates regression plots comparing predicted times of interest to true values.
1181
+
1182
+ Notes
1183
+ -----
1184
+ - Evaluation is performed on both test and validation datasets, if available.
1185
+ - Regression plots and performance metrics are saved in the model's output directory.
1186
+ """
1187
+
1188
+
1189
+ mse = MeanSquaredError()
1190
+ mae = MeanAbsoluteError()
1191
+
1192
+ if hasattr(self, 'x_test'):
1193
+
1194
+ print("Evaluate on test set...")
1195
+ predictions = self.model_reg.predict(self.x_test[np.argmax(self.y_class_test,axis=1)==0], batch_size=self.batch_size)[:,0]
1196
+ ground_truth = self.y_time_test[np.argmax(self.y_class_test,axis=1)==0]
1197
+ assert predictions.shape==ground_truth.shape,"Shape mismatch between predictions and ground truths..."
1198
+
1199
+ test_mse = mse(ground_truth, predictions).numpy()
1200
+ test_mae = mae(ground_truth, predictions).numpy()
1201
+ print(f"MSE on test set: {test_mse}...")
1202
+ print(f"MAE on test set: {test_mae}...")
1203
+ regression_plot(predictions, ground_truth, savepath=os.sep.join([self.model_folder,"test_regression.png"]))
1204
+ self.dico.update({"test_mse": test_mse, "test_mae": test_mae})
1205
+
1206
+ if hasattr(self, 'x_val'):
1207
+ # Validation set
1208
+ predictions = self.model_reg.predict(self.x_val[np.argmax(self.y_class_val,axis=1)==0], batch_size=self.batch_size)[:,0]
1209
+ ground_truth = self.y_time_val[np.argmax(self.y_class_val,axis=1)==0]
1210
+ assert predictions.shape==ground_truth.shape,"Shape mismatch between predictions and ground truths..."
1211
+
1212
+ val_mse = mse(ground_truth, predictions).numpy()
1213
+ val_mae = mae(ground_truth, predictions).numpy()
1214
+
1215
+ regression_plot(predictions, ground_truth, savepath=os.sep.join([self.model_folder,"validation_regression.png"]))
1216
+ print(f"MSE on validation set: {val_mse}...")
1217
+ print(f"MAE on validation set: {val_mae}...")
1218
+
1219
+ self.dico.update({"val_mse": val_mse, "val_mae": val_mae})
1220
+
1221
+
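+ # Standalone sketch of the selection step above: only cells whose one-hot class is the
+ # event class (index 0) enter the regression metrics.
+ #
+ # >>> y_class = np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0]])
+ # >>> np.argmax(y_class, axis=1) == 0
+ # array([ True, False,  True])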
1222
+ def gather_callbacks(self, mode):
1223
+
1224
+ """
1225
+ Prepares a list of Keras callbacks for model training based on the specified mode.
1226
+
1227
+ Parameters
1228
+ ----------
1229
+ mode : str
1230
+ The training mode for which callbacks are being prepared. Options are "classifier" or "regressor".
1231
+
1232
+ Notes
1233
+ -----
1234
+ - Callbacks include learning rate reduction on plateau, early stopping, model checkpointing, and TensorBoard logging.
1235
+ - The list of callbacks is stored in the class attribute `cb` and used during model training.
1236
+ """
1237
+
1238
+ self.cb = []
1239
+
1240
+ if mode=="classifier":
1241
+
1242
+ reduce_lr = ReduceLROnPlateau(monitor='val_precision', factor=0.5, patience=30,
1243
+ cooldown=10, min_lr=5e-10, min_delta=1.0E-10,
1244
+ verbose=1,mode="max")
1245
+ self.cb.append(reduce_lr)
1246
+ csv_logger = CSVLogger(os.sep.join([self.model_folder,'log_classifier.csv']), append=True, separator=';')
1247
+ self.cb.append(csv_logger)
1248
+ checkpoint_path = os.sep.join([self.model_folder,"classifier.h5"])
1249
+ cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_precision",mode="max",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
1250
+ self.cb.append(cp_callback)
1251
+
1252
+ callback_stop = EarlyStopping(monitor='val_precision', patience=100)
1253
+ self.cb.append(callback_stop)
1254
+
1255
+ elif mode=="regressor":
1256
+
1257
+ reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=30,
1258
+ cooldown=10, min_lr=5e-10, min_delta=1.0E-10,
1259
+ verbose=1,mode="min")
1260
+ self.cb.append(reduce_lr)
1261
+
1262
+ csv_logger = CSVLogger(os.sep.join([self.model_folder,'log_regressor.csv']), append=True, separator=';')
1263
+ self.cb.append(csv_logger)
1264
+
1265
+ checkpoint_path = os.sep.join([self.model_folder,"regressor.h5"])
1266
+ cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_loss",mode="min",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
1267
+ self.cb.append(cp_callback)
1268
+
1269
+ callback_stop = EarlyStopping(monitor='val_loss', patience=200)
1270
+ self.cb.append(callback_stop)
1271
+
1272
+ log_dir = self.model_folder+os.sep
1273
+ cb_tb = TensorBoard(log_dir=log_dir, update_freq='batch')
1274
+ self.cb.append(cb_tb)
1275
+
1276
+ cb_time = TimeHistory()
1277
+ self.cb.append(cb_time)
1278
+
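+ # Hedged usage sketch: the assembled list is meant to be handed to Keras `fit`
+ # (attribute names below are illustrative; assumes the instance exposes a compiled
+ # classifier model and the training arrays):
+ #
+ # >>> self.gather_callbacks("classifier")
+ # >>> self.model_class.fit(self.x_train, self.y_class_train,
+ # ...                      validation_data=(self.x_val, self.y_class_val),
+ # ...                      epochs=self.epochs, batch_size=self.batch_size, callbacks=self.cb)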
1279
+
1280
+
1281
+ def generate_sets(self):
1282
+
1283
+ """
1284
+ Generates and preprocesses training, validation, and test sets from loaded annotations.
1285
+
1286
+ This method loads signal data from annotation files, normalizes and interpolates the signals, and splits
1287
+ the dataset into training, validation, and test sets according to specified proportions.
1288
+
1289
+ Notes
1290
+ -----
1291
+ - Signal annotations are expected to be stored in .npy format and contain required channels and event information.
1292
+ - The method applies specified normalization and interpolation options to prepare the signals for model training.
1293
+ """
1294
+
1295
+
1296
+ self.x_set = []
1297
+ self.y_time_set = []
1298
+ self.y_class_set = []
1299
+
1300
+ for s in self.list_of_sets:
1301
+ self.load_and_normalize(s)
1302
+
1303
+ self.x_set = np.array(self.x_set).astype(np.float32)
1304
+ self.x_set = self.interpolate_signals(self.x_set)
1305
+
1306
+ self.y_time_set = np.array(self.y_time_set).astype(np.float32)
1307
+ self.y_class_set = np.array(self.y_class_set).astype(np.float32)
1308
+
1309
+ class_test = np.isin(self.y_class_set, [0,1,2])
1310
+ self.x_set = self.x_set[class_test]
1311
+ self.y_time_set = self.y_time_set[class_test]
1312
+ self.y_class_set = self.y_class_set[class_test]
1313
+
1314
+ # Compute class weights and one-hot encode
1315
+ self.class_weights = compute_weights(self.y_class_set)
1316
+ self.nbr_classes = len(np.unique(self.y_class_set))
1317
+ self.y_class_set = to_categorical(self.y_class_set)
1318
+
1319
+ ds = train_test_split(self.x_set,
1320
+ self.y_time_set,
1321
+ self.y_class_set,
1322
+ validation_size=self.validation_split,
1323
+ test_size=self.test_split)
1324
+
1325
+ self.x_train = ds["x_train"]
1326
+ self.x_val = ds["x_val"]
1327
+ self.y_time_train = ds["y1_train"].astype(np.float32)
1328
+ print("y_time_train range: ", np.amin(self.y_time_train), "-", np.amax(self.y_time_train))
1329
+ self.y_time_val = ds["y1_val"].astype(np.float32)
1330
+ self.y_class_train = ds["y2_train"]
1331
+ self.y_class_val = ds["y2_val"]
1332
+
1333
+ if self.test_split>0:
1334
+ self.x_test = ds["x_test"]
1335
+ self.y_time_test = ds["y1_test"].astype(np.float32)
1336
+ self.y_class_test = ds["y2_test"]
1337
+
1338
+ if self.augment:
1339
+ self.augment_training_set()
1340
+
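+ # Standalone sketch of the one-hot encoding step above:
+ #
+ # >>> to_categorical(np.array([0., 1., 1., 2.]))
+ # array([[1., 0., 0.],
+ #        [0., 1., 0.],
+ #        [0., 1., 0.],
+ #        [0., 0., 1.]], dtype=float32)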
1341
+ def augment_training_set(self, time_shift=True):
1342
+
1343
+ """
1344
+ Augments the training dataset with artificially generated data to increase model robustness.
1345
+
1346
+ Parameters
1347
+ ----------
1348
+ time_shift : bool, optional
1349
+ Specifies whether to include time-shifted versions of signals in the augmented dataset. Default is True.
1350
+
1351
+ Notes
1352
+ -----
1353
+ - Augmentation strategies include random time shifting and signal modifications to simulate variations in real data.
1354
+ - The augmented dataset is used for training the classifier and regressor models to improve generalization.
1355
+ """
1356
+
1357
+
1358
+ nbr_augment = int(self.augmentation_factor*len(self.x_train)) # random.choices below expects an integer k
1359
+ randomize = np.arange(len(self.x_train))
1360
+ indices = random.choices(randomize,k=nbr_augment)
1361
+
1362
+ x_train_aug = []
1363
+ y_time_train_aug = []
1364
+ y_class_train_aug = []
1365
+
1366
+ for k in indices:
1367
+ aug = augmenter(self.x_train[k],
1368
+ self.y_time_train[k],
1369
+ self.y_class_train[k],
1370
+ self.model_signal_length,
1371
+ time_shift=time_shift)
1372
+ x_train_aug.append(aug[0])
1373
+ y_time_train_aug.append(aug[1])
1374
+ y_class_train_aug.append(aug[2])
1375
+
1376
+ # Save augmented training set
1377
+ self.x_train = np.array(x_train_aug)
1378
+ self.y_time_train = np.array(y_time_train_aug)
1379
+ self.y_class_train = np.array(y_class_train_aug)
1380
+
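+ # Note: with `augmentation_factor` N, the original training set is replaced (not extended) by
+ # N*len(x_train) draws sampled with replacement, each passed once through `augmenter`.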
1381
+
1382
+
1383
+ def load_and_normalize(self, subset):
1384
+
1385
+ """
1386
+ Loads a subset of signal data from an annotation file and applies normalization.
1387
+
1388
+ Parameters
1389
+ ----------
1390
+ subset : str
1391
+ The file path to the .npy annotation file containing signal data for a subset of observations.
1392
+
1393
+ Notes
1394
+ -----
1395
+ - The method extracts required signal channels from the annotation file and applies specified normalization
1396
+ and interpolation steps.
1397
+ - Preprocessed signals are added to the global dataset for model training.
1398
+ """
1399
+
1400
+ set_k = np.load(subset,allow_pickle=True)
1401
+ ### here do a mapping between channel option and existing signals
1402
+
1403
+ required_signals = self.channel_option
1404
+ available_signals = list(set_k[0].keys())
1405
+
1406
+ selected_signals = []
1407
+ for s in required_signals:
1408
+ pattern_test = [s in a for a in available_signals]
1409
+ if np.any(pattern_test):
1410
+ valid_columns = np.array(available_signals)[np.array(pattern_test)]
1411
+ if len(valid_columns)==1:
1412
+ selected_signals.append(valid_columns[0])
1413
+ else:
1414
+ print(f'Found several candidate signals: {valid_columns}')
1415
+ for vc in natsorted(valid_columns):
1416
+ if 'circle' in vc:
1417
+ selected_signals.append(vc)
1418
+ break
1419
+ else:
1420
+ selected_signals.append(valid_columns[0])
1421
+ else:
+ print(f"Required signal {s} could not be matched to any of the available signals {available_signals}; skipping {subset}...")
+ return None
1423
+
1424
+
1425
+ key_to_check = selected_signals[0] #self.channel_option[0]
1426
+ signal_lengths = [len(l[key_to_check]) for l in set_k]
1427
+ max_length = np.amax(signal_lengths)
1428
+
1429
+ fluo = np.zeros((len(set_k),max_length,self.n_channels))
1430
+ classes = np.zeros(len(set_k))
1431
+ times_of_interest = np.zeros(len(set_k))
1432
+
1433
+ for k in range(len(set_k)):
1434
+
1435
+ for i in range(self.n_channels):
1436
+ try:
1437
+ # take into account timeline for accurate time regression
1438
+ timeline = set_k[k]['FRAME'].astype(int)
1439
+ fluo[k,timeline,i] = set_k[k][selected_signals[i]]
1440
+ except (KeyError, IndexError):
+ print(f"Attribute {selected_signals[i]} matched to {self.channel_option[i]} not found in annotation...")
1443
+
1444
+ classes[k] = set_k[k]["class"]
1445
+ times_of_interest[k] = set_k[k]["time_of_interest"]
1446
+
1447
+ # Correct absurd times of interest
1448
+ times_of_interest[np.nonzero(classes)] = -1
1449
+ times_of_interest[(times_of_interest<=0.0)] = -1
1450
+
1451
+ # Attempt per-set normalization
1452
+ fluo = pad_to_model_length(fluo, self.model_signal_length)
1453
+ if self.normalize:
1454
+ fluo = normalize_signal_set(fluo, self.channel_option, normalization_percentile=self.normalization_percentile,
1455
+ normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
1456
+ )
1457
+
1458
+ # Trivial normalization for time of interest
1459
+ times_of_interest /= self.model_signal_length
1460
+
1461
+ # Add to global dataset
1462
+ self.x_set.extend(fluo)
1463
+ self.y_time_set.extend(times_of_interest)
1464
+ self.y_class_set.extend(classes)
1465
+
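+ # Standalone sketch of the substring matching used above (column names hypothetical):
+ #
+ # >>> available = ['dead_nuclei_channel_mean', 'dead_nuclei_channel_circle_mean']
+ # >>> [('dead_nuclei_channel' in a) for a in available]
+ # [True, True]
+ # # several candidates: the natsorted name containing 'circle' is preferred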
1466
+ def _interpret_normalization_parameters(n_channels, normalization_percentile, normalization_values, normalization_clip):
1467
+
1468
+ """
1469
+ Interprets and validates normalization parameters for each channel.
1470
+
1471
+ This function ensures the normalization parameters are correctly formatted and expanded to match
1472
+ the number of channels in the dataset. It provides default values and expands single values into
1473
+ lists to match the number of channels if necessary.
1474
+
1475
+ Parameters
1476
+ ----------
1477
+ n_channels : int
1478
+ The number of channels in the dataset.
1479
+ normalization_percentile : list of bool or bool, optional
1480
+ Specifies whether to normalize each channel based on percentile values. If a single bool is provided,
1481
+ it is expanded to a list matching the number of channels. Default is True for all channels.
1482
+ normalization_values : list of lists or list, optional
1483
+ Specifies the percentile values [lower, upper] for normalization for each channel. If a single pair
1484
+ is provided, it is expanded to match the number of channels. Default is [[0.1, 99.9]] for all channels.
1485
+ normalization_clip : list of bool or bool, optional
1486
+ Specifies whether to clip the normalized values for each channel to the range [0, 1]. If a single bool
1487
+ is provided, it is expanded to a list matching the number of channels. Default is False for all channels.
1488
+
1489
+ Returns
1490
+ -------
1491
+ tuple
1492
+ A tuple containing three lists: `normalization_percentile`, `normalization_values`, and `normalization_clip`,
1493
+ each of length `n_channels`, representing the interpreted and validated normalization parameters for each channel.
1494
+
1495
+ Raises
1496
+ ------
1497
+ AssertionError
1498
+ If the lengths of the provided lists do not match `n_channels`.
1499
+
1500
+ Examples
1501
+ --------
1502
+ >>> n_channels = 2
1503
+ >>> normalization_percentile = True
1504
+ >>> normalization_values = [0.1, 99.9]
1505
+ >>> normalization_clip = False
1506
+ >>> params = _interpret_normalization_parameters(n_channels, normalization_percentile, normalization_values, normalization_clip)
1507
+ >>> print(params)
1508
+ # ([True, True], [[0.1, 99.9], [0.1, 99.9]], [False, False])
1509
+ """
1510
+
1511
+
1512
+ if normalization_percentile is None:
1513
+ normalization_percentile = [True]*n_channels
1514
+ if normalization_values is None:
1515
+ normalization_values = [[0.1,99.9]]*n_channels
1516
+ if normalization_clip is None:
1517
+ normalization_clip = [False]*n_channels
1518
+
1519
+ if isinstance(normalization_percentile, bool):
1520
+ normalization_percentile = [normalization_percentile]*n_channels
1521
+ if isinstance(normalization_clip, bool):
1522
+ normalization_clip = [normalization_clip]*n_channels
1523
+ if len(normalization_values)==2 and not isinstance(normalization_values[0], list):
1524
+ normalization_values = [normalization_values]*n_channels
1525
+
1526
+ assert len(normalization_values)==n_channels
1527
+ assert len(normalization_clip)==n_channels
1528
+ assert len(normalization_percentile)==n_channels
1529
+
1530
+ return normalization_percentile, normalization_values, normalization_clip
1531
+
1532
+
1533
+ def normalize_signal_set(signal_set, channel_option, percentile_alive=[0.01,99.99], percentile_dead=[0.5,99.999], percentile_generic=[0.01,99.99], normalization_percentile=None, normalization_values=None, normalization_clip=None):
1534
+
1535
+ """
1536
+ Normalizes a set of single-cell signals across specified channels using given percentile values or specific normalization parameters.
1537
+
1538
+ This function applies normalization to each channel in the signal set based on the provided normalization parameters,
1539
+ which can be defined globally or per channel. The normalization process aims to scale the signal values to a standard
1540
+ range, improving the consistency and comparability of signal measurements across samples.
1541
+
1542
+ Parameters
1543
+ ----------
1544
+ signal_set : ndarray
1545
+ A 3D numpy array representing the set of signals to be normalized, with dimensions corresponding to (samples, time points, channels).
1546
+ channel_option : list of str
1547
+ A list specifying the channels included in the signal set and their corresponding normalization strategy based on channel names.
1548
+ percentile_alive : list of float, optional
1549
+ The percentile values [lower, upper] used for normalization of signals from channels labeled as 'alive'. Default is [0.01, 99.99].
1550
+ percentile_dead : list of float, optional
1551
+ The percentile values [lower, upper] used for normalization of signals from channels labeled as 'dead'. Default is [0.5, 99.999].
1552
+ percentile_generic : list of float, optional
1553
+ The percentile values [lower, upper] used for normalization of signals from channels not specifically labeled as 'alive' or 'dead'.
1554
+ Default is [0.01, 99.99].
1555
+ normalization_percentile : list of bool or None, optional
1556
+ Specifies whether to normalize each channel based on percentile values. If None, the default percentile strategy is applied
1557
+ based on `channel_option`. If a list, it should match the length of `channel_option`.
1558
+ normalization_values : list of lists or None, optional
1559
+ Specifies the percentile values [lower, upper] or fixed values [min, max] for normalization for each channel. Overrides
1560
+ `percentile_alive`, `percentile_dead`, and `percentile_generic` if provided.
1561
+ normalization_clip : list of bool or None, optional
1562
+ Specifies whether to clip the normalized values for each channel to the range [0, 1]. If None, clipping is disabled by default.
1563
+
1564
+ Returns
1565
+ -------
1566
+ ndarray
1567
+ The normalized signal set with the same shape as the input `signal_set`.
1568
+
1569
+ Notes
1570
+ -----
1571
+ - The function supports different normalization strategies for 'alive', 'dead', and generic signal channels, which can be customized
1572
+ via `channel_option` and the percentile parameters.
1573
+ - Normalization parameters (`normalization_percentile`, `normalization_values`, `normalization_clip`) are interpreted and validated
1574
+ by calling `_interpret_normalization_parameters`.
1575
+
1576
+ Examples
1577
+ --------
1578
+ >>> signal_set = np.random.rand(100, 128, 2) # 100 samples, 128 time points, 2 channels
1579
+ >>> channel_option = ['alive', 'dead']
1580
+ >>> normalized_signals = normalize_signal_set(signal_set, channel_option)
1581
+ # Normalizes the signal set based on the default percentile values for 'alive' and 'dead' channels.
1582
+ """
1583
+
1584
+ # Check normalization params are ok
1585
+ n_channels = len(channel_option)
1586
+ normalization_percentile, normalization_values, normalization_clip = _interpret_normalization_parameters(n_channels,
1587
+ normalization_percentile,
1588
+ normalization_values,
1589
+ normalization_clip)
1590
+ for k,channel in enumerate(channel_option):
1591
+
1592
+ zero_values = []
1593
+ for i in range(len(signal_set)):
1594
+ zeros_loc = np.where(signal_set[i,:,k]==0)
1595
+ zero_values.append(zeros_loc)
1596
+
1597
+ values = signal_set[:,:,k]
1598
+
1599
+ if normalization_percentile[k]:
1600
+ min_val = np.nanpercentile(values[values!=0.], normalization_values[k][0])
1601
+ max_val = np.nanpercentile(values[values!=0.], normalization_values[k][1])
1602
+ else:
1603
+ min_val = normalization_values[k][0]
1604
+ max_val = normalization_values[k][1]
1605
+
1606
+ signal_set[:,:,k] -= min_val
1607
+ signal_set[:,:,k] /= (max_val - min_val)
1608
+
1609
+ if normalization_clip[k]:
1610
+ to_clip_low = []
1611
+ to_clip_high = []
1612
+ for i in range(len(signal_set)):
1613
+ clip_low_loc = np.where(signal_set[i,:,k]<=0)
1614
+ clip_high_loc = np.where(signal_set[i,:,k]>=1.0)
1615
+ to_clip_low.append(clip_low_loc)
1616
+ to_clip_high.append(clip_high_loc)
1617
+
1618
+ for i,z in enumerate(to_clip_low):
1619
+ signal_set[i,z,k] = 0.
1620
+ for i,z in enumerate(to_clip_high):
1621
+ signal_set[i,z,k] = 1.
1622
+
1623
+ for i,z in enumerate(zero_values):
1624
+ signal_set[i,z,k] = 0.
1625
+
1626
+ return signal_set
1627
+
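+ # The per-channel transform above is plain min-max rescaling, with bounds taken either
+ # from percentiles of the nonzero values or passed in directly:
+ #
+ # >>> values = np.array([2., 4., 6.])
+ # >>> (values - 2.) / (6. - 2.)   # min_val=2, max_val=6
+ # array([0. , 0.5, 1. ])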
1628
+ def pad_to_model_length(signal_set, model_signal_length):
1629
+
1630
+ """
1631
+
1632
+ Pad the signal set to match the specified model signal length.
1633
+
1634
+ Parameters
1635
+ ----------
1636
+ signal_set : array-like
1637
+ The signal set to be padded.
1638
+ model_signal_length : int
1639
+ The desired length of the model signal.
1640
+
1641
+ Returns
1642
+ -------
1643
+ array-like
1644
+ The padded signal set.
1645
+
1646
+ Notes
1647
+ -----
1648
+ This function pads the signal set with zeros along the second dimension (axis 1) to match the specified model signal
1649
+ length. The padding is applied to the end of the signals, increasing their length.
1650
+
1651
+ Examples
1652
+ --------
1653
+ >>> signal_set = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
1654
+ >>> padded_signals = pad_to_model_length(signal_set, 5)
1655
+
1656
+ """
1657
+
1658
+ padded = np.pad(signal_set, [(0,0),(0,model_signal_length - signal_set.shape[1]),(0,0)])
1659
+
1660
+ return padded
1661
+
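+ # Standalone shape check for the padding above:
+ #
+ # >>> pad_to_model_length(np.ones((2, 2, 1)), 5).shape   # 2 cells, 2 frames, 1 channel
+ # (2, 5, 1)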
1662
+ def random_intensity_change(signal):
1663
+
1664
+ """
1665
+
1666
+ Randomly change the intensity of a signal.
1667
+
1668
+ Parameters
1669
+ ----------
1670
+ signal : array-like
1671
+ The input signal to be modified.
1672
+
1673
+ Returns
1674
+ -------
1675
+ array-like
1676
+ The modified signal with randomly changed intensity.
1677
+
1678
+ Notes
1679
+ -----
1680
+ This function applies a random intensity change to each channel of the input signal. The intensity change is
1681
+ performed by multiplying each channel with a random value drawn from a uniform distribution between 0.7 and 1.0.
1682
+
1683
+ Examples
1684
+ --------
1685
+ >>> signal = np.array([[1, 2, 3], [4, 5, 6]])
1686
+ >>> modified_signal = random_intensity_change(signal)
1687
+
1688
+ """
1689
+
1690
+ for k in range(signal.shape[1]):
1691
+ signal[:,k] = signal[:,k]*np.random.uniform(0.7,1.)
1692
+
1693
+ return signal
1694
+
1695
+ def gauss_noise(signal):
1696
+
1697
+ """
1698
+
1699
+ Add Gaussian noise to a signal.
1700
+
1701
+ Parameters
1702
+ ----------
1703
+ signal : array-like
1704
+ The input signal to which noise will be added.
1705
+
1706
+ Returns
1707
+ -------
1708
+ array-like
1709
+ The signal with Gaussian noise added.
1710
+
1711
+ Notes
1712
+ -----
1713
+ This function adds multiplicative Gaussian noise to the input signal. Standard-normal draws are scaled
+ by the local signal value and by a random amplitude drawn uniformly from [0, 0.08], and the scaled noise
+ is then added to the original signal.
1716
+
1717
+ Examples
1718
+ --------
1719
+ >>> signal = np.array([1, 2, 3, 4, 5])
1720
+ >>> noisy_signal = gauss_noise(signal)
1721
+
1722
+ """
1723
+
1724
+ sig = 0.08*np.random.uniform(0,1)
1725
+ signal = signal + sig*np.random.normal(0,1,signal.shape)*signal
1726
+ return signal
1727
+
1728
+ def random_time_shift(signal, time_of_interest, cclass, model_signal_length):
1729
+
1730
+ """
1731
+
1732
+ Randomly shift the signals to another time.
1733
+
1734
+ Parameters
1735
+ ----------
1736
+ signal : array-like
1737
+ The signal to be shifted.
1738
+ time_of_interest : int or float
+ The original time of interest for the signal. Use -1 if not applicable.
+ cclass : ndarray
+ A one-hot encoded array representing the class of the cell associated with the signal.
+ model_signal_length : int
+ The length of the model signal.
1742
+
1743
+ Returns
1744
+ -------
1745
+ array-like
+ The shifted signal.
+ int or float
+ The new time of interest if available; otherwise, the original time of interest.
+ ndarray
+ The (possibly updated) one-hot encoded class of the cell.
1749
+
1750
+ Notes
1751
+ -----
1752
+ This function randomly selects a target time within the specified model signal length and shifts the
1753
+ signal accordingly. The shift is performed along the first dimension (axis 0) of the signal. The function uses
1754
+ nearest-neighbor interpolation for shifting.
1755
+
1756
+ If the original time of interest (`time_of_interest`) is provided (not equal to -1), the function returns the
1757
+ shifted signal along with the new time of interest. Otherwise, it returns the shifted signal along with the
1758
+ original time of interest.
1759
+
1760
+ The `max_time` is set to the `model_signal_length` unless the original time of interest is provided. In that case,
1761
+ `max_time` is set to `model_signal_length - 3` to prevent shifting too close to the edge.
1762
+
1763
+ Examples
1764
+ --------
1765
+ >>> signal = np.array([[1., 2., 3.], [4., 5., 6.]])
+ >>> cclass = np.array([1., 0., 0.])
+ >>> shifted_signal, new_time, new_class = random_time_shift(signal, 1, cclass, 5)
1767
+
1768
+ """
1769
+
1770
+ max_time = model_signal_length
1771
+ return_target = False
1772
+ if time_of_interest != -1:
1773
+ return_target = True
1774
+ max_time = model_signal_length - 3 # to prevent approaching too much to the edge
1775
+
1776
+ times = np.linspace(-max_time,max_time,2000) # symmetrize to create left-censored events
1777
+ target_time = np.random.choice(times)
1778
+
1779
+ delta_t = target_time - time_of_interest
1780
+ signal = shift(signal, [delta_t,0], order=0, mode="nearest")
1781
+
1782
+ if target_time<=0 and np.argmax(cclass)==0:
1783
+ target_time = -1
1784
+ cclass = np.array([0.,0.,1.]).astype(np.float32)
1785
+
1786
+ if return_target:
1787
+ return signal,target_time, cclass
1788
+ else:
1789
+ return signal, time_of_interest, cclass
1790
+
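+ # Note on the relabeling above: an event cell (class 0) whose target time falls at or below
+ # zero has its event pushed out of the observation window (left-censored), so it is recast
+ # as class 2 with time of interest -1.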
1791
+ def augmenter(signal, time_of_interest, cclass, model_signal_length, time_shift=True, probability=0.8):
1792
+
1793
+ """
1794
+ Randomly augments single-cell signals to simulate variations in noise, intensity ratios, and event times.
1795
+
1796
+ This function applies random transformations to the input signal, including time shifts, intensity changes,
1797
+ and the addition of Gaussian noise, with the aim of increasing the diversity of the dataset for training robust models.
1798
+
1799
+ Parameters
1800
+ ----------
1801
+ signal : ndarray
1802
+ A 2D numpy array of shape (length, channels) representing the signals of a single cell to be augmented.
1803
+ time_of_interest : float
1804
+ The normalized time of interest (event time) for the signal, scaled to the range [0, 1].
1805
+ cclass : ndarray
1806
+ A one-hot encoded numpy array representing the class of the cell associated with the signal.
1807
+ model_signal_length : int
1808
+ The length of the signal expected by the model, used for scaling the time of interest.
1809
+ time_shift : bool, optional
1810
+ Specifies whether to apply random time shifts to the signal. Default is True.
1811
+ probability : float, optional
1812
+ The probability with which to apply the augmentation transformations. Default is 0.8.
1813
+
1814
+ Returns
1815
+ -------
1816
+ tuple
1817
+ A tuple containing the augmented signal, the normalized time of interest, and the class of the cell.
1818
+
1819
+ Raises
1820
+ ------
1821
+ AssertionError
1822
+ If the time of interest is provided but invalid for time shifting.
1823
+
1824
+ Notes
1825
+ -----
1826
+ - Time shifting is not applied to cells of the class labeled as 'miscellaneous' (typically encoded as the class '2').
1827
+ - The time of interest is rescaled based on the model's expected signal length before and after any time shift.
1828
+ - Augmentation is applied with the specified probability to simulate realistic variability while maintaining
1829
+ some original signals in the dataset.
1830
+
1831
+ """
1832
+
1833
+ if np.amax(time_of_interest)<=1.0:
1834
+ time_of_interest *= model_signal_length
1835
+
1836
+ # augment with a certain probability
1837
+ r = random.random()
1838
+ if r<= probability:
1839
+
1840
+ if time_shift:
1841
+ # do not time shift miscellaneous cells
1842
+ if cclass.argmax()!=2.:
1843
+ assert time_of_interest is not None, "Please provide valid lysis times"
1844
+ signal,time_of_interest,cclass = random_time_shift(signal, time_of_interest, cclass, model_signal_length)
1845
+
1846
+ #signal = random_intensity_change(signal) #maybe bad idea for non percentile-normalized signals
1847
+ signal = gauss_noise(signal)
1848
+
1849
+ return signal, time_of_interest/model_signal_length, cclass
1850
+
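+ # Hedged usage sketch (standalone, three-class one-hot labels assumed):
+ #
+ # >>> signal = np.random.rand(128, 2)      # (model_signal_length, n_channels)
+ # >>> cclass = np.array([1., 0., 0.])      # event cell
+ # >>> aug, t_norm, new_class = augmenter(signal, 0.4, cclass, 128)
+ # >>> aug.shape                            # shape is preserved
+ # (128, 2)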
1851
+
1852
+ def residual_block1D(x, number_of_filters, kernel_size=8, match_filter_size=True, connection='identity'):
1853
+
1854
+ """
1855
+
1856
+ Create a 1D residual block.
1857
+
1858
+ Parameters
1859
+ ----------
1860
+ x : Tensor
1861
+ Input tensor.
1862
+ number_of_filters : int
1863
+ Number of filters in the convolutional layers.
1864
+ match_filter_size : bool, optional
1865
+ Whether to match the filter size of the skip connection to the output. Default is True.
1866
+
1867
+ Returns
1868
+ -------
1869
+ Tensor
1870
+ Output tensor of the residual block.
1871
+
1872
+ Notes
1873
+ -----
1874
+ This function creates a 1D residual block by performing the original mapping followed by adding a skip connection
1875
+ and applying non-linear activation. The skip connection allows the gradient to flow directly to earlier layers and
1876
+ helps mitigate the vanishing gradient problem. The residual block consists of three convolutional layers with
1877
+ batch normalization and ReLU activation functions.
1878
+
1879
+ If `match_filter_size` is True, the skip connection is adjusted to have the same number of filters as the output.
1880
+ Otherwise, the skip connection is kept as is.
1881
+
1882
+ Examples
1883
+ --------
1884
+ >>> inputs = Input(shape=(10, 3))
1885
+ >>> x = residual_block1D(inputs, 64)
1886
+ # Create a 1D residual block with 64 filters and apply it to the input tensor.
1887
+
1888
+ """
1889
+
1890
+
1891
+ # Create skip connection
1892
+ x_skip = x
1893
+
1894
+ # Perform the original mapping
1895
+ if connection=='identity':
1896
+ x = Conv1D(number_of_filters, kernel_size=kernel_size, strides=1,padding="same")(x_skip)
1897
+ elif connection=='projection':
1898
+ x = ZeroPadding1D(padding=kernel_size//2)(x_skip)
1899
+ x = Conv1D(number_of_filters, kernel_size=kernel_size, strides=2,padding="valid")(x)
1900
+ x = BatchNormalization()(x)
1901
+ x = Activation("relu")(x)
1902
+
1903
+ x = Conv1D(number_of_filters, kernel_size=kernel_size, strides=1,padding="same")(x)
1904
+ x = BatchNormalization()(x)
1905
+
1906
+ if match_filter_size and connection=='identity':
1907
+ x_skip = Conv1D(number_of_filters, kernel_size=1, padding="same")(x_skip)
1908
+ elif match_filter_size and connection=='projection':
1909
+ x_skip = Conv1D(number_of_filters, kernel_size=1, strides=2, padding="valid")(x_skip)
1910
+
1911
+
1912
+ # Add the skip connection to the regular mapping
1913
+ x = Add()([x, x_skip])
1914
+
1915
+ # Nonlinearly activate the result
1916
+ x = Activation("relu")(x)
1917
+
1918
+ # Return the result
1919
+ return x
1920
+
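+ # Minimal sketch of the two connection types (assumes the Keras layers imported by this
+ # module; with 'projection', an odd kernel size keeps the main path aligned with the
+ # strided skip connection):
+ #
+ # >>> inp = Input(shape=(128, 2))
+ # >>> residual_block1D(inp, 64, kernel_size=8, connection='identity')   # -> (None, 128, 64)
+ # >>> residual_block1D(inp, 64, kernel_size=7, connection='projection') # -> (None, 64, 64), time axis halved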
1921
+
1922
+ def MultiscaleResNetModel(n_channels, n_classes = 3, dropout_rate=0, dense_collection=0, use_pooling=True,
1923
+ header="classifier", model_signal_length = 128):
1924
+
1925
+ """
1926
+
1927
+ Define a multiscale ResNet 1D encoder model with three parallel branches of residual blocks.
1928
+
1929
+ Parameters
1930
+ ----------
1931
+ n_channels : int
1932
+ Number of input channels.
1933
+ use_pooling : bool, optional
+ Accepted for signature compatibility; not used by this architecture. Default is True.
1935
+ n_classes : int, optional
1936
+ Number of output classes. Default is 3.
1937
+ dropout_rate : float, optional
1938
+ Dropout rate to be applied. Default is 0.
1939
+ dense_collection : int, optional
1940
+ Number of neurons in the dense layer. Default is 0.
1941
+ header : str, optional
1942
+ Type of the model header. "classifier" for classification, "regressor" for regression. Default is "classifier".
1943
+ model_signal_length : int, optional
1944
+ Length of the input signal. Default is 128.
1945
+
1946
+ Returns
1947
+ -------
1948
+ keras.models.Model
1949
+ ResNet 1D encoder model.
1950
+
1951
+ Notes
1952
+ -----
1953
+ This function defines a multiscale ResNet 1D encoder with three parallel branches of residual blocks
+ (kernel sizes 7, 5 and 3) whose pooled features are concatenated. The model architecture follows the
1955
+ ResNet principles with 1D convolutional layers and residual connections. The final activation and number of
1956
+ neurons in the output layer are determined based on the header type.
1957
+
1958
+ Examples
1959
+ --------
1960
+ >>> model = ResNetModel(n_channels=3, n_blocks=4, n_classes=2, dropout_rate=0.2)
1961
+ # Define a ResNet 1D encoder model with 3 input channels, 4 residual blocks, and 2 output classes.
1962
+
1963
+ """
1964
+
1965
+ if header=="classifier":
1966
+ final_activation = "softmax"
1967
+ neurons_final = n_classes
1968
+ elif header=="regressor":
1969
+ final_activation = "linear"
1970
+ neurons_final = 1
1971
+ else:
1972
+ return None
1973
+
1974
+ inputs = Input(shape=(model_signal_length,n_channels,))
1975
+ x = ZeroPadding1D(3)(inputs)
1976
+ x = Conv1D(64, kernel_size=7, strides=2, padding="valid", use_bias=False)(x)
1977
+ x = BatchNormalization()(x)
1978
+ x = ZeroPadding1D(1)(x)
1979
+ x_common = MaxPooling1D(pool_size=3, strides=2, padding='valid')(x)
1980
+
1981
+ # Block 1
1982
+ x1 = residual_block1D(x_common, 64, kernel_size=7,connection='projection')
1983
+ x1 = residual_block1D(x1, 128, kernel_size=7,connection='projection')
1984
+ x1 = residual_block1D(x1, 256, kernel_size=7,connection='projection')
1985
+ x1 = GlobalAveragePooling1D()(x1)
1986
+
1987
+ # Block 2
1988
+ x2 = residual_block1D(x_common, 64, kernel_size=5,connection='projection')
1989
+ x2 = residual_block1D(x2, 128, kernel_size=5,connection='projection')
1990
+ x2 = residual_block1D(x2, 256, kernel_size=5,connection='projection')
1991
+ x2 = GlobalAveragePooling1D()(x2)
1992
+
1993
+ # Block 3
1994
+ x3 = residual_block1D(x_common, 64, kernel_size=3,connection='projection')
1995
+ x3 = residual_block1D(x3, 128, kernel_size=3,connection='projection')
1996
+ x3 = residual_block1D(x3, 256, kernel_size=3,connection='projection')
1997
+ x3 = GlobalAveragePooling1D()(x3)
1998
+
1999
+ x_combined = Concatenate()([x1, x2, x3])
2000
+ x_combined = Flatten()(x_combined)
2001
+
2002
+ if dense_collection>0:
2003
+ x_combined = Dense(dense_collection)(x_combined)
2004
+ if dropout_rate>0:
2005
+ x_combined = Dropout(dropout_rate)(x_combined)
2006
+
2007
+ x_combined = Dense(neurons_final,activation=final_activation,name=header)(x_combined)
2008
+ model = Model(inputs, x_combined, name=header)
2009
+
2010
+ return model
2011
+
2012
+ def ResNetModelCurrent(n_channels, n_slices, depth=2, use_pooling=True, n_classes = 3, dropout_rate=0.1, dense_collection=512,
2013
+ header="classifier", model_signal_length = 128):
2014
+
2015
+ """
2016
+ Creates a ResNet-based model tailored for signal classification or regression tasks.
2017
+
2018
+ This function constructs a 1D ResNet architecture with specified parameters. The model can be configured
2019
+ for either classification or regression tasks, determined by the `header` parameter. It consists of
2020
+ configurable ResNet blocks, global average pooling, optional dense layers, and dropout for regularization.
2021
+
2022
+ Parameters
2023
+ ----------
2024
+ n_channels : int
2025
+ The number of channels in the input signal.
2026
+ n_slices : int
2027
+ The number of slices (or ResNet blocks) to use in the model.
2028
+ depth : int, optional
2029
+ The depth of the network, i.e., how many times the number of filters is doubled. Default is 2.
2030
+ use_pooling : bool, optional
2031
+ Whether to use MaxPooling between ResNet blocks. Default is True.
2032
+ n_classes : int, optional
2033
+ The number of classes for the classification task. Ignored for regression. Default is 3.
2034
+ dropout_rate : float, optional
2035
+ The dropout rate for regularization. Default is 0.1.
2036
+ dense_collection : int, optional
2037
+ The number of neurons in the dense layer following global pooling. If 0, the dense layer is omitted. Default is 512.
2038
+ header : str, optional
2039
+ Specifies the task type: "classifier" for classification or "regressor" for regression. Default is "classifier".
2040
+ model_signal_length : int, optional
2041
+ The length of the input signal. Default is 128.
2042
+
2043
+ Returns
2044
+ -------
2045
+ keras.Model
2046
+ The constructed Keras model ready for training or inference.
2047
+
2048
+ Notes
2049
+ -----
2050
+ - The model uses Conv1D layers for signal processing and applies global average pooling before the final classification
2051
+ or regression layer.
2052
+ - The choice of `final_activation` and `neurons_final` depends on the task: "softmax" and `n_classes` for classification,
2053
+ and "linear" and 1 for regression.
2054
+ - This function relies on a custom `residual_block1D` function for constructing ResNet blocks.
2055
+
2056
+ Examples
2057
+ --------
2058
+ >>> model = ResNetModelCurrent(n_channels=1, n_slices=2, depth=2, use_pooling=True, n_classes=3, dropout_rate=0.1, dense_collection=512, header="classifier", model_signal_length=128)
2059
+ # Creates a ResNet model configured for classification with 3 classes.
2060
+ """
2061
+
2062
+ if header=="classifier":
2063
+ final_activation = "softmax"
2064
+ neurons_final = n_classes
2065
+ elif header=="regressor":
2066
+ final_activation = "linear"
2067
+ neurons_final = 1
2068
+ else:
2069
+ return None
2070
+
2071
+ inputs = Input(shape=(model_signal_length,n_channels,))
2072
+ x2 = Conv1D(64, kernel_size=1,strides=1,padding='same')(inputs)
2073
+
2074
+ n_filters = 64
2075
+ for k in range(depth):
2076
+ for i in range(n_slices):
2077
+ x2 = residual_block1D(x2,n_filters,kernel_size=8)
2078
+ n_filters *= 2
2079
+ if use_pooling and k!=(depth-1):
2080
+ x2 = MaxPooling1D()(x2)
2081
+
2082
+ x2 = GlobalAveragePooling1D()(x2)
2083
+ if dense_collection>0:
2084
+ x2 = Dense(dense_collection)(x2)
2085
+ if dropout_rate>0:
2086
+ x2 = Dropout(dropout_rate)(x2)
2087
+
2088
+ x2 = Dense(neurons_final,activation=final_activation,name=header)(x2)
2089
+ model = Model(inputs, x2, name=header)
2090
+
2091
+ return model
2092
+
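+ # Shape walk-through for the defaults above (illustrative): with n_slices=2, depth=2 and
+ # use_pooling=True, a (128, n_channels) input passes two residual blocks at 64 filters,
+ # one MaxPooling1D (time axis halved), two blocks at 128 filters, then
+ # GlobalAveragePooling1D -> Dense(512) -> Dropout(0.1) -> the softmax/linear head.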
2093
+
2094
+ def train_signal_model(config):
2095
+
2096
+ """
2097
+ Initiates the training of a signal detection model using a specified configuration file.
2098
+
2099
+ This function triggers an external Python script to train a signal detection model. The training
2100
+ configuration, including data paths, model parameters, and training options, are specified in a JSON
2101
+ configuration file. The function asserts the existence of the configuration file before proceeding
2102
+ with the training process.
2103
+
2104
+ Parameters
2105
+ ----------
2106
+ config : str
2107
+ The file path to the JSON configuration file specifying training parameters. This path must be valid
2108
+ and the configuration file must be correctly formatted according to the expectations of the
2109
+ 'train_signal_model.py' script.
2110
+
2111
+ Raises
2112
+ ------
2113
+ AssertionError
2114
+ If the specified configuration file does not exist at the given path.
2115
+
2116
+ Notes
2117
+ -----
2118
+ - The external training script 'train_signal_model.py' is expected to be located in a predefined directory
2119
+ relative to this function and is responsible for the actual model training process.
2120
+ - The configuration file should include details such as data directories, model architecture specifications,
2121
+ training hyperparameters, and any preprocessing steps required.
2122
+
2123
+ Examples
2124
+ --------
2125
+ >>> config_path = '/path/to/training_config.json'
2126
+ >>> train_signal_model(config_path)
2127
+ # This will execute the 'train_signal_model.py' script using the parameters specified in 'training_config.json'.
2128
+ """
2129
+
2130
+ config = config.replace('\\','/')
2132
+ assert os.path.exists(config),f'Config {config} is not a valid path.'
2133
+
2134
+ script_path = os.sep.join([abs_path, 'scripts', 'train_signal_model.py'])
2135
+ cmd = f'python "{script_path}" --config "{config}"'
2136
+ subprocess.call(cmd, shell=True)
2137
+
2138
+ def derivative(x, timeline, window, mode='bi'):
2139
+
2140
+ """
2141
+ Compute the derivative of a given array of values with respect to time using a specified numerical differentiation method.
2142
+
2143
+ Parameters
2144
+ ----------
2145
+ x : array_like
2146
+ The input array of values.
2147
+ timeline : array_like
2148
+ The array representing the time points corresponding to the input values.
2149
+ window : int
2150
+ The size of the window used for numerical differentiation. Must be a positive odd integer.
2151
+ mode : {'bi', 'forward', 'backward'}, optional
2152
+ The numerical differentiation method to be used:
2153
+ - 'bi' (default): Bidirectional differentiation using a symmetric window.
2154
+ - 'forward': Forward differentiation using a one-sided window.
2155
+ - 'backward': Backward differentiation using a one-sided window.
2156
+
2157
+ Returns
2158
+ -------
2159
+ dxdt : ndarray
2160
+ The computed derivative values of the input array with respect to time.
2161
+
2162
+ Raises
2163
+ ------
2164
+ AssertionError
2165
+ If the window size is not an odd integer and mode is 'bi'.
2166
+
2167
+ Notes
2168
+ -----
2169
+ - For 'bi' mode, the window size must be an odd number.
2170
+ - For 'forward' mode, the derivative at the edge points may not be accurate due to the one-sided window.
2171
+ - For 'backward' mode, the derivative at the first few points may not be accurate due to the one-sided window.
2172
+
2173
+ Examples
2174
+ --------
2175
+ >>> import numpy as np
2176
+ >>> x = np.array([1, 2, 4, 7, 11])
2177
+ >>> timeline = np.array([0, 1, 2, 3, 4])
2178
+ >>> window = 3
2179
+ >>> derivative(x, timeline, window, mode='bi')
+ array([nan,  2.,  3., nan, nan])
+
+ >>> derivative(x, timeline, window, mode='forward')
+ array([ 2.,  3., nan, nan, nan])
+
+ >>> derivative(x, timeline, window, mode='backward')
+ array([nan, nan, nan,  2.,  3.])
2187
+ """
2188
+
2189
+ # modes = bi, forward, backward
2190
+ dxdt = np.zeros(len(x))
2191
+ dxdt[:] = np.nan
2192
+
2193
+ if mode=='bi':
2194
+ assert window%2==1,'Please set an odd window for the bidirectional mode'
2195
+ lower_bound = window//2
2196
+ upper_bound = len(x) - window//2 - 1
2197
+ elif mode=='forward':
2198
+ lower_bound = 0
2199
+ upper_bound = len(x) - window
2200
+ elif mode=='backward':
2201
+ lower_bound = window
2202
+ upper_bound = len(x)
2203
+
2204
+ for t in range(lower_bound,upper_bound):
2205
+ if mode=='bi':
2206
+ dxdt[t] = (x[t+window//2+1] - x[t-window//2]) / (timeline[t+window//2+1] - timeline[t-window//2])
2207
+ elif mode=='forward':
2208
+ dxdt[t] = (x[t+window] - x[t]) / (timeline[t+window] - timeline[t])
2209
+ elif mode=='backward':
2210
+ dxdt[t] = (x[t] - x[t-window]) / (timeline[t] - timeline[t-window])
2211
+ return dxdt
2212
+
2213
+ def velocity(x,y,timeline,window,mode='bi'):
2214
+
2215
+ """
2216
+ Compute the velocity vector of a given 2D trajectory represented by arrays of x and y coordinates
2217
+ with respect to time using a specified numerical differentiation method.
2218
+
2219
+ Parameters
2220
+ ----------
2221
+ x : array_like
2222
+ The array of x-coordinates of the trajectory.
2223
+ y : array_like
2224
+ The array of y-coordinates of the trajectory.
2225
+ timeline : array_like
2226
+ The array representing the time points corresponding to the x and y coordinates.
2227
+ window : int
2228
+ The size of the window used for numerical differentiation. Must be a positive odd integer.
2229
+ mode : {'bi', 'forward', 'backward'}, optional
2230
+ The numerical differentiation method to be used:
2231
+ - 'bi' (default): Bidirectional differentiation using a symmetric window.
2232
+ - 'forward': Forward differentiation using a one-sided window.
2233
+ - 'backward': Backward differentiation using a one-sided window.
2234
+
2235
+ Returns
2236
+ -------
2237
+ v : ndarray
2238
+ The computed velocity vector of the 2D trajectory with respect to time.
2239
+ The first column represents the x-component of velocity, and the second column represents the y-component.
2240
+
2241
+ Raises
2242
+ ------
2243
+ AssertionError
2244
+ If the window size is not an odd integer and mode is 'bi'.
2245
+
2246
+ Notes
2247
+ -----
2248
+ - For 'bi' mode, the window size must be an odd number.
2249
+ - For 'forward' mode, the velocity at the edge points may not be accurate due to the one-sided window.
2250
+ - For 'backward' mode, the velocity at the first few points may not be accurate due to the one-sided window.
2251
+
2252
+ Examples
2253
+ --------
2254
+ >>> import numpy as np
2255
+ >>> x = np.array([1, 2, 4, 7, 11])
2256
+ >>> y = np.array([0, 3, 5, 8, 10])
2257
+ >>> timeline = np.array([0, 1, 2, 3, 4])
2258
+ >>> window = 3
2259
+ >>> velocity(x, y, timeline, window, mode='bi')
+ array([[nan, nan],
+ [2.        , 2.66666667],
+ [3.        , 2.33333333],
+ [nan, nan],
+ [nan, nan]])
+
+ >>> velocity(x, y, timeline, window, mode='forward')
+ array([[2.        , 2.66666667],
+ [3.        , 2.33333333],
+ [nan, nan],
+ [nan, nan],
+ [nan, nan]])
+
+ >>> velocity(x, y, timeline, window, mode='backward')
+ array([[nan, nan],
+ [nan, nan],
+ [nan, nan],
+ [2.        , 2.66666667],
+ [3.        , 2.33333333]])
2270
+ """
2271
+
2272
+ v = np.zeros((len(x),2))
2273
+ v[:,:] = np.nan
2274
+
2275
+ v[:,0] = derivative(x, timeline, window, mode=mode)
2276
+ v[:,1] = derivative(y, timeline, window, mode=mode)
2277
+
2278
+ return v
2279
+
2280
+ def magnitude_velocity(v_matrix):
2281
+
2282
+ """
2283
+ Compute the magnitude of velocity vectors given a matrix representing 2D velocity vectors.
2284
+
2285
+ Parameters
2286
+ ----------
2287
+ v_matrix : array_like
2288
+ The matrix where each row represents a 2D velocity vector with the first column
2289
+ being the x-component and the second column being the y-component.
2290
+
2291
+ Returns
2292
+ -------
2293
+ magnitude : ndarray
2294
+ The computed magnitudes of the input velocity vectors.
2295
+
2296
+ Notes
2297
+ -----
2298
+ - If a velocity vector has NaN components, the corresponding magnitude will be NaN.
2299
+ - The function handles NaN values in the input matrix gracefully.
2300
+
2301
+ Examples
2302
+ --------
2303
+ >>> import numpy as np
2304
+ >>> v_matrix = np.array([[3, 4],
2305
+ ... [2, 2],
2306
+ ... [3, 3]])
2307
+ >>> magnitude_velocity(v_matrix)
2308
+ array([5., 2.82842712, 4.24264069])
2309
+
2310
+ >>> v_matrix_with_nan = np.array([[3, 4],
2311
+ ... [np.nan, 2],
2312
+ ... [3, np.nan]])
2313
+ >>> magnitude_velocity(v_matrix_with_nan)
2314
+ array([5., nan, nan])
2315
+ """
2316
+
2317
+ magnitude = np.zeros(len(v_matrix))
2318
+ magnitude[:] = np.nan
2319
+ for i in range(len(v_matrix)):
2320
+ if v_matrix[i,0]==v_matrix[i,0]: # NaN != NaN: skip rows with a NaN x-component
2321
+ magnitude[i] = np.sqrt(v_matrix[i,0]**2 + v_matrix[i,1]**2)
2322
+ return magnitude
2323
+
2324
+ def orientation(v_matrix):
2325
+
2326
+ """
2327
+ Compute the orientation angles (in radians) of 2D velocity vectors given a matrix representing velocity vectors.
2328
+
2329
+ Parameters
2330
+ ----------
2331
+ v_matrix : array_like
2332
+ The matrix where each row represents a 2D velocity vector with the first column
2333
+ being the x-component and the second column being the y-component.
2334
+
2335
+ Returns
2336
+ -------
2337
+ orientation_array : ndarray
2338
+ The computed orientation angles of the input velocity vectors in radians.
2339
+ If a velocity vector has NaN components, the corresponding orientation angle will be NaN.
2340
+
2341
+ Examples
2342
+ --------
2343
+ >>> import numpy as np
2344
+ >>> v_matrix = np.array([[3, 4],
2345
+ ... [2, 2],
2346
+ ... [-3, -3]])
2347
+ >>> orientation(v_matrix)
2348
+ array([0.92729522, 0.78539816, -2.35619449])
2349
+
2350
+ >>> v_matrix_with_nan = np.array([[3, 4],
2351
+ ... [np.nan, 2],
2352
+ ... [3, np.nan]])
2353
+ >>> orientation(v_matrix_with_nan)
2354
+ array([0.92729522, nan, nan])
2355
+ """
2356
+
2357
+ orientation_array = np.zeros(len(v_matrix))
+ orientation_array[:] = np.nan
+ for t in range(len(orientation_array)):
+ if v_matrix[t,0]==v_matrix[t,0]: # NaN != NaN: skip rows with a NaN x-component
+ orientation_array[t] = np.arctan2(v_matrix[t,1],v_matrix[t,0]) # arctan2(vy, vx), as in the examples
+ return orientation_array
2362
+
2363
+ def T_MSD(x,y,dt):
2364
+
2365
+ """
2366
+ Compute the Time-Averaged Mean Square Displacement (T-MSD) of a 2D trajectory.
2367
+
2368
+ Parameters
2369
+ ----------
2370
+ x : array_like
2371
+ The array of x-coordinates of the trajectory.
2372
+ y : array_like
2373
+ The array of y-coordinates of the trajectory.
2374
+ dt : float
2375
+ The time interval between successive data points in the trajectory.
2376
+
2377
+ Returns
2378
+ -------
2379
+ msd : list
2380
+ A list containing the Time-Averaged Mean Square Displacement values for different time lags.
2381
+ timelag : ndarray
2382
+ The array representing the time lags corresponding to the calculated MSD values.
2383
+
2384
+ Notes
2385
+ -----
2386
+ - T-MSD is a measure of the average spatial extent explored by a particle over a given time interval.
2387
+ - The input trajectories (x, y) are assumed to be in the same unit of length.
2388
+ - The time interval (dt) should be consistent with the time unit used in the data.
2389
+
2390
+ Examples
2391
+ --------
2392
+ >>> import numpy as np
2393
+ >>> x = np.array([1, 2, 4, 7, 11])
2394
+ >>> y = np.array([0, 3, 5, 8, 10])
2395
+ >>> dt = 1.0 # Time interval between data points
2396
+ >>> T_MSD(x, y, dt)
2397
+ ([14.0, 52.666666666666664, 115.0, 200.0],
2398
+ array([1., 2., 3., 4.]))
2399
+ """
2400
+
2401
+ msd = []
2402
+ N = len(x)
2403
+ for n in range(1,N):
2404
+ s = 0
2405
+ for i in range(0,N-n):
2406
+ s+=(x[n+i] - x[i])**2 + (y[n+i] - y[i])**2
2407
+ msd.append(1/(N-n)*s)
2408
+
2409
+ timelag = np.linspace(dt,(N-1)*dt,N-1)
2410
+ return msd,timelag
2411
+
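+ # The quantity computed above is the time-averaged MSD,
+ #   MSD(n*dt) = 1/(N-n) * sum_{i=0}^{N-n-1} [ (x_{i+n}-x_i)^2 + (y_{i+n}-y_i)^2 ],  n = 1..N-1.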
2412
+ def linear_msd(t, m):
2413
+
2414
+ """
2415
+ Function to compute Mean Square Displacement (MSD) with a linear scaling relationship.
2416
+
2417
+ Parameters
2418
+ ----------
2419
+ t : array_like
2420
+ Time lag values.
2421
+ m : float
2422
+ Linear scaling factor representing the slope of the MSD curve.
2423
+
2424
+ Returns
2425
+ -------
2426
+ msd : ndarray
2427
+ Computed MSD values based on the linear scaling relationship.
2428
+
2429
+ Examples
2430
+ --------
2431
+ >>> import numpy as np
2432
+ >>> t = np.array([1, 2, 3, 4])
2433
+ >>> m = 2.0
2434
+ >>> linear_msd(t, m)
2435
+ array([2., 4., 6., 8.])
2436
+ """
2437
+
2438
+ return m*t
2439
+
2440
+ def alpha_msd(t, m, alpha):
2441
+
2442
+ """
2443
+ Function to compute Mean Square Displacement (MSD) with a power-law scaling relationship.
2444
+
2445
+ Parameters
2446
+ ----------
2447
+ t : array_like
2448
+ Time lag values.
2449
+ m : float
2450
+ Scaling factor.
2451
+ alpha : float
2452
+ Exponent representing the scaling relationship between MSD and time.
2453
+
2454
+ Returns
2455
+ -------
2456
+ msd : ndarray
2457
+ Computed MSD values based on the power-law scaling relationship.
2458
+
2459
+ Examples
2460
+ --------
2461
+ >>> import numpy as np
2462
+ >>> t = np.array([1, 2, 3, 4])
2463
+ >>> m = 2.0
2464
+ >>> alpha = 0.5
2465
+ >>> alpha_msd(t, m, alpha)
2466
+ array([2.        , 2.82842712, 3.46410162, 4.        ])
2467
+ """
2468
+
2469
+ return m*t**alpha
2470
+
2471
+ def sliding_msd(x, y, timeline, window, mode='bi', n_points_migration=7, n_points_transport=7):
2472
+
2473
+ """
2474
+ Compute sliding mean square displacement (sMSD) and anomalous exponent (alpha) for a 2D trajectory using a sliding window approach.
2475
+
2476
+ Parameters
2477
+ ----------
2478
+ x : array_like
2479
+ The array of x-coordinates of the trajectory.
2480
+ y : array_like
2481
+ The array of y-coordinates of the trajectory.
2482
+ timeline : array_like
2483
+ The array representing the time points corresponding to the x and y coordinates.
2484
+ window : int
2485
+ The size of the sliding window used for computing local MSD and alpha values.
2486
+ mode : {'bi', 'forward', 'backward'}, optional
2487
+ The sliding window mode:
2488
+ - 'bi' (default): Bidirectional sliding window.
2489
+ - 'forward': Forward sliding window.
2490
+ - 'backward': Backward sliding window.
2491
+ n_points_migration : int, optional
2492
+ The number of points used for fitting the linear function in the MSD calculation.
2493
+ n_points_transport : int, optional
2494
+ The number of points used for fitting the alpha function in the anomalous exponent calculation.
2495
+
2496
+ Returns
2497
+ -------
2498
+ s_msd : ndarray
2499
+ Sliding estimate of the linear MSD slope (m in MSD = m*t), computed within each window.
2500
+ s_alpha : ndarray
2501
+ Sliding anomalous exponent (alpha) values calculated using the sliding window approach.
2502
+
2503
+ Raises
2504
+ ------
2505
+ AssertionError
2506
+ If the window size is not larger than the number of fit points.
2507
+
2508
+ Notes
2509
+ -----
2510
+ - The input trajectories (x, y) are assumed to be in the same unit of length.
2511
+ - The time unit used in the data should be consistent with the time intervals in the timeline array.
2512
+
2513
+ Examples
2514
+ --------
2515
+ >>> import numpy as np
2516
+ >>> x = np.array([1, 2, 4, 7, 11, 15, 20])
2517
+ >>> y = np.array([0, 3, 5, 8, 10, 14, 18])
2518
+ >>> timeline = np.array([0, 1, 2, 3, 4, 5, 6])
2519
+ >>> window = 3
2520
+ >>> s_msd, s_alpha = sliding_msd(x, y, timeline, window, n_points_migration=2, n_points_transport=3)
2521
+ """
2522
+
2523
+ assert window > n_points_migration,'Please set a window larger than the number of fit points...'
2524
+
2525
+ # modes = bi, forward, backward
2526
+ s_msd = np.zeros(len(x))
2527
+ s_msd[:] = np.nan
2528
+ s_alpha = np.zeros(len(x))
2529
+ s_alpha[:] = np.nan
2530
+ dt = timeline[1] - timeline[0]
2531
+
2532
+ if mode=='bi':
2533
+ assert window%2==1,'Please set an odd window for the bidirectional mode'
2534
+ lower_bound = window//2
2535
+ upper_bound = len(x) - window//2 - 1
2536
+ elif mode=='forward':
2537
+ lower_bound = 0
2538
+ upper_bound = len(x) - window
2539
+ elif mode=='backward':
2540
+ lower_bound = window
2541
+ upper_bound = len(x)
2542
+
2543
+ for t in range(lower_bound,upper_bound):
2544
+ if mode=='bi':
2545
+ x_sub = x[t-window//2:t+window//2+1]
2546
+ y_sub = y[t-window//2:t+window//2+1]
2547
+ msd,timelag = T_MSD(x_sub,y_sub,dt)
2548
2549
+ elif mode=='forward':
2550
+ x_sub = x[t:t+window]
2551
+ y_sub = y[t:t+window]
2552
+ msd,timelag = T_MSD(x_sub,y_sub,dt)
2553
2554
+ elif mode=='backward':
2555
+ x_sub = x[t-window:t]
2556
+ y_sub = y[t-window:t]
2557
+ msd,timelag = T_MSD(x_sub,y_sub,dt)
2558
2559
+ popt,pcov = curve_fit(linear_msd,timelag[:n_points_migration],msd[:n_points_migration])
2560
+ s_msd[t] = popt[0]
2561
+ popt_alpha,pcov_alpha = curve_fit(alpha_msd,timelag[:n_points_transport],msd[:n_points_transport])
2562
+ s_alpha[t] = popt_alpha[1]
2563
+
2564
+ return s_msd, s_alpha
2565
+
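+ # At each window position the local T-MSD curve is fitted twice: a linear model
+ # MSD = m*t over the first n_points_migration lags (s_msd stores the slope m, ~4D for
+ # 2D Brownian motion), and a power law MSD = m*t**alpha over the first n_points_transport
+ # lags (s_alpha stores the anomalous exponent alpha).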
2566
+ def drift_msd(t, d, v):
2567
+
2568
+ """
2569
+ Calculates the mean squared displacement (MSD) of a particle undergoing diffusion with drift.
2570
+
2571
+ The function computes the MSD for a particle that diffuses in a medium with a constant drift velocity.
2572
+ The MSD is given by the formula: MSD = 4Dt + V^2t^2, where D is the diffusion coefficient, V is the drift
2573
+ velocity, and t is the time.
2574
+
2575
+ Parameters
2576
+ ----------
2577
+ t : float or ndarray
2578
+ Time or an array of time points at which to calculate the MSD.
2579
+ d : float
2580
+ Diffusion coefficient of the particle.
2581
+ v : float
2582
+ Drift velocity of the particle.
2583
+
2584
+ Returns
2585
+ -------
2586
+ float or ndarray
2587
+ The mean squared displacement of the particle at time t. Returns a single float value if t is a float,
2588
+ or returns an array of MSD values if t is an ndarray.
2589
+
2590
+ Examples
2591
+ --------
2592
+ >>> drift_msd(t=5, d=1, v=2)
+ 120
+ >>> drift_msd(t=np.array([1, 2, 3]), d=1, v=2)
+ array([ 8, 24, 48])
2596
+
2597
+ Notes
2598
+ -----
2599
+ - This formula assumes that the particle undergoes normal diffusion with an additional constant drift component.
2600
+ - The function can be used to model the behavior of particles in systems where both diffusion and directed motion occur.
2601
+ """
2602
+
2603
+ return 4*d*t + v**2*t**2
2604
+
2605
+ def sliding_msd_drift(x, y, timeline, window, mode='bi', n_points_migration=7, n_points_transport=7, r2_threshold=0.75):
2606
+
2607
+ """
2608
+ Computes the sliding mean squared displacement (MSD) with drift for particle trajectories.
2609
+
2610
+ This function calculates the diffusion coefficient and drift velocity of particles based on their
2611
+ x and y positions over time. It uses a sliding window approach to estimate the MSD at each point in time,
2612
+ fitting the MSD to the equation MSD = 4Dt + V^2t^2 to extract the diffusion coefficient (D) and drift velocity (V).
2613
+
2614
+ Parameters
2615
+ ----------
2616
+ x : ndarray
2617
+ The x positions of the particle over time.
2618
+ y : ndarray
2619
+ The y positions of the particle over time.
2620
+ timeline : ndarray
2621
+ The time points corresponding to the x and y positions.
2622
+ window : int
2623
+ The size of the sliding window used to calculate the MSD at each point in time.
2624
+ mode : str, optional
2625
+ The mode of sliding window calculation. Options are 'bi' for bidirectional, 'forward', or 'backward'. Default is 'bi'.
2626
+ n_points_migration : int, optional
2627
+ The number of initial points from the calculated MSD to use for fitting the migration model. Default is 7.
2628
+ n_points_transport : int, optional
+ The number of initial points intended for fitting the transport model; accepted but not used in the current implementation. Default is 7.
+ r2_threshold : float, optional
+ Intended R-squared threshold for validating the fit; accepted but not applied in the current implementation. Default is 0.75.
2632
+
2633
+ Returns
2634
+ -------
2635
+ tuple
2636
+ A tuple containing two ndarrays: the estimated diffusion coefficients and drift velocities for each point in time.
2637
+
2638
+ Raises
2639
+ ------
2640
+ AssertionError
2641
+ If the window size is not larger than the number of fit points or if the window size is even when mode is 'bi'.
2642
+
2643
+ Notes
2644
+ -----
2645
+ - The function assumes a uniform time step between each point in the timeline.
2646
+ - The 'bi' mode requires an odd-sized window to symmetrically calculate the MSD around each point in time.
2647
+ - The curve fitting is performed using the `curve_fit` function from `scipy.optimize`, fitting to the `drift_msd` model.
2648
+
2649
+ Examples
2650
+ --------
2651
+ >>> x = np.random.rand(100)
2652
+ >>> y = np.random.rand(100)
2653
+ >>> timeline = np.arange(100)
2654
+ >>> window = 11
2655
+ >>> diffusion, velocity = sliding_msd_drift(x, y, timeline, window, mode='bi')
2656
+ # Calculates the diffusion coefficient and drift velocity using a bidirectional sliding window.
2657
+ """
2658
+
2659
+ assert window > n_points_migration,'Please set a window larger than the number of fit points...'
2660
+
2661
+ # modes = bi, forward, backward
2662
+ s_diffusion = np.zeros(len(x))
2663
+ s_diffusion[:] = np.nan
2664
+ s_velocity = np.zeros(len(x))
2665
+ s_velocity[:] = np.nan
2666
+ dt = timeline[1] - timeline[0]
2667
+
2668
+ if mode=='bi':
2669
+ assert window%2==1,'Please set an odd window for the bidirectional mode'
2670
+ lower_bound = window//2
2671
+ upper_bound = len(x) - window//2 - 1
2672
+ elif mode=='forward':
2673
+ lower_bound = 0
2674
+ upper_bound = len(x) - window
2675
+ elif mode=='backward':
2676
+ lower_bound = window
2677
+ upper_bound = len(x)
2678
+
2679
+ for t in range(lower_bound,upper_bound):
2680
+ if mode=='bi':
2681
+ x_sub = x[t-window//2:t+window//2+1]
2682
+ y_sub = y[t-window//2:t+window//2+1]
2683
+ msd,timelag = T_MSD(x_sub,y_sub,dt)
2684
+ # dxdt[t] = (x[t+window//2+1] - x[t-window//2]) / (timeline[t+window//2+1] - timeline[t-window//2])
2685
+ elif mode=='forward':
2686
+ x_sub = x[t:t+window]
2687
+ y_sub = y[t:t+window]
2688
+ msd,timelag = T_MSD(x_sub,y_sub,dt)
2689
+ # dxdt[t] = (x[t+window] - x[t]) / (timeline[t+window] - timeline[t])
2690
+ elif mode=='backward':
2691
+ x_sub = x[t-window:t]
2692
+ y_sub = y[t-window:t]
2693
+ msd,timelag = T_MSD(x_sub,y_sub,dt)
2694
+ # dxdt[t] = (x[t] - x[t-window]) / (timeline[t] - timeline[t-window])
2695
+
2696
+ popt,pcov = curve_fit(drift_msd,timelag[:n_points_migration],msd[:n_points_migration])
2697
+ #if not np.any([math.isinf(a) for a in pcov.flatten()]):
2698
+ s_diffusion[t] = popt[0]
2699
+ s_velocity[t] = popt[1]
2700
+
2701
+ return s_diffusion, s_velocity
2702
+
2703
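+ # Illustrative usage sketch (not part of the package): recover D and |V| from a
+ # synthetic 2D random walk with constant drift. Assumes the module-level imports
+ # (numpy as np, scipy.optimize.curve_fit) and `T_MSD` are available.
+ #
+ # rng = np.random.default_rng(0)
+ # dt, n = 1.0, 200
+ # timeline = np.arange(n) * dt
+ # d_true, vx, vy = 0.5, 0.3, 0.4                       # D and drift components
+ # steps = rng.normal(scale=np.sqrt(2*d_true*dt), size=(n, 2))
+ # xy = np.cumsum(steps, axis=0) + np.outer(timeline, [vx, vy])
+ # diff, vel = sliding_msd_drift(xy[:,0], xy[:,1], timeline, window=21)
+ # print(np.nanmedian(diff), np.nanmedian(vel))         # ~0.5 and ~0.5 expected
+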
+ def columnwise_mean(matrix, min_nbr_values=1):
+
+     """
+     Calculate the column-wise mean and standard deviation of the valid (non-zero, non-NaN) elements of the input matrix.
+
+     Parameters
+     ----------
+     matrix : numpy.ndarray
+         The input matrix for which the column-wise mean and standard deviation are calculated.
+     min_nbr_values : int, optional
+         A column must contain strictly more than `min_nbr_values` valid values for its statistics to be computed.
+         Default is 1.
+
+     Returns
+     -------
+     mean_line : numpy.ndarray
+         An array containing the column-wise mean of the valid elements. Columns with `min_nbr_values` or fewer
+         valid values are set to NaN.
+     mean_line_std : numpy.ndarray
+         An array containing the column-wise standard deviation of the valid elements. Columns with `min_nbr_values`
+         or fewer valid values are set to NaN.
+
+     Notes
+     -----
+     1. Zeros are treated as missing values and excluded from the calculation, along with NaNs.
+     2. Columns with `min_nbr_values` or fewer valid elements have NaN as their mean and standard deviation.
+     """
+
+     mean_line = np.zeros(matrix.shape[1])
+     mean_line[:] = np.nan
+     mean_line_std = np.zeros(matrix.shape[1])
+     mean_line_std[:] = np.nan
+
+     for k in range(matrix.shape[1]):
+         values = matrix[:,k]
+         values = values[values!=0]  # drop zeros, treated as missing
+         if len(values[values==values])>min_nbr_values:  # NaN-aware count (NaN != NaN)
+             mean_line[k] = np.nanmean(values)
+             mean_line_std[k] = np.nanstd(values)
+     return mean_line, mean_line_std
+
+
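+ # Minimal illustrative check (not part of the package): zeros and NaNs are
+ # treated as missing, and columns without strictly more than `min_nbr_values`
+ # valid entries come back as NaN.
+ #
+ # m = np.array([[1., 2., 0.],
+ #               [3., np.nan, 0.],
+ #               [5., 4., np.nan]])
+ # mean, std = columnwise_mean(m, min_nbr_values=1)
+ # # column 0: mean of [1, 3, 5] -> 3.0; column 1: [2, 4] -> 3.0; column 2: NaN
+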
+ def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], return_matrix=False, forced_max_duration=None, min_nbr_values=2):
+
+     """
+     Calculate the mean and standard deviation of a specified signal for tracks of a given class in the input DataFrame.
+
+     Parameters
+     ----------
+     df : pandas.DataFrame
+         Input DataFrame containing tracking data.
+     signal_name : str
+         Name of the signal (column) in the DataFrame for which the mean and standard deviation are calculated.
+     class_col : str
+         Name of the column in the DataFrame containing class labels.
+     time_col : str, optional
+         Name of the column in the DataFrame containing the reference event time used to align tracks of class 0. Default is None.
+     class_value : int or list of int, optional
+         Value(s) representing the class(es) of interest. Default is [0].
+     return_matrix : bool, optional
+         If True, also return the per-track, time-aligned signal matrix. Default is False.
+     forced_max_duration : int, optional
+         If set, overrides the maximum track duration used to size the signal matrix. Default is None.
+     min_nbr_values : int, optional
+         Minimum number of valid values per time point, passed to `columnwise_mean`. Default is 2.
+
+     Returns
+     -------
+     mean_signal : numpy.ndarray
+         An array containing the mean signal values for tracks of the specified class(es).
+     std_signal : numpy.ndarray
+         An array containing the standard deviation of the signal values for tracks of the specified class(es).
+     actual_timeline : numpy.ndarray
+         An array representing the time points, relative to the reference time, corresponding to the mean signal values.
+     signal_matrix : numpy.ndarray, optional
+         The per-track, time-aligned signal matrix. Only returned if `return_matrix` is True.
+
+     Notes
+     -----
+     1. This function calculates the mean and standard deviation of the specified signal across tracks of the given class(es).
+     2. Tracks whose class is not in `class_value` are excluded from the calculation; tracks of class 0 with a missing or NaN
+        reference time are skipped.
+     3. NaN signal values are ignored during calculation; note that exact zeros are also treated as missing by `columnwise_mean`.
+     4. Tracks are aligned on their 'FRAME' values, shifted by the reference time read from `time_col` (class 0) or by 0 (other classes).
+     """
+
+     assert signal_name in list(df.columns), "The signal you want to plot is not one of the measured features."
+     if isinstance(class_value,int):
+         class_value = [class_value]
+
+     if forced_max_duration is None:
+         max_duration = ceil(np.amax(df.groupby(['position','TRACK_ID']).size().values))
+     else:
+         max_duration = forced_max_duration
+     n_tracks = len(df.groupby(['position','TRACK_ID']))
+     signal_matrix = np.zeros((n_tracks,max_duration*2 + 1))
+     signal_matrix[:,:] = np.nan
+
+     trackid = 0
+     for track,track_group in df.loc[df[class_col].isin(class_value)].groupby(['position','TRACK_ID']):
+         track_group = track_group.sort_values(by='FRAME')
+         cclass = track_group[class_col].to_numpy()[0]
+         if cclass != 0:
+             # no event time for this class: align on absolute frames
+             ref_time = 0
+         else:
+             try:
+                 ref_time = floor(track_group[time_col].to_numpy()[0])
+             except Exception:
+                 # missing time column or NaN reference time: skip the track
+                 continue
+         signal = track_group[signal_name].to_numpy()
+         timeline = track_group['FRAME'].to_numpy().astype(int)
+         timeline_shifted = timeline - ref_time + max_duration
+         signal_matrix[trackid,timeline_shifted] = signal
+         trackid += 1
+
+     mean_signal, std_signal = columnwise_mean(signal_matrix, min_nbr_values=min_nbr_values)
+     actual_timeline = np.linspace(-max_duration, max_duration, 2*max_duration+1)
+     if return_matrix:
+         return mean_signal, std_signal, actual_timeline, signal_matrix
+     else:
+         return mean_signal, std_signal, actual_timeline
+
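+ # Illustrative usage sketch (not part of the package): build a toy tracking
+ # DataFrame and average a signal across event tracks, aligned on the event time.
+ # Column names follow the conventions used above ('position', 'TRACK_ID', 'FRAME');
+ # `import pandas as pd` is assumed.
+ #
+ # frames = np.arange(10)
+ # toy = pd.concat([
+ #     pd.DataFrame({'position': 'A', 'TRACK_ID': tid, 'FRAME': frames,
+ #                   'class': 0, 't_event': 4.0,
+ #                   'intensity': np.sin(0.5*(frames - 4)) + tid*0.01})
+ #     for tid in range(5)
+ # ])
+ # mean_s, std_s, tline = mean_signal(toy, 'intensity', 'class', time_col='t_event')
+ # # `tline` runs from -10 to 10; `mean_s` is NaN wherever fewer than 3 tracks overlap
+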
+ if __name__ == "__main__":
+
+     # model = MultiScaleResNetModel(3, n_classes = 3, dropout_rate=0, dense_collection=1024, header="classifier", model_signal_length = 128)
+     # print(model.summary())
+     model = ResNetModelCurrent(1, 2, depth=2, use_pooling=True, n_classes = 3, dropout_rate=0.1, dense_collection=512,
+                                header="classifier", model_signal_length = 128)
+     print(model.summary())
+     #plot_model(model, to_file='test.png', show_shapes=True)