celldetective 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. celldetective/__init__.py +2 -0
  2. celldetective/__main__.py +432 -0
  3. celldetective/datasets/segmentation_annotations/blank +0 -0
  4. celldetective/datasets/signal_annotations/blank +0 -0
  5. celldetective/events.py +149 -0
  6. celldetective/extra_properties.py +100 -0
  7. celldetective/filters.py +89 -0
  8. celldetective/gui/__init__.py +20 -0
  9. celldetective/gui/about.py +44 -0
  10. celldetective/gui/analyze_block.py +563 -0
  11. celldetective/gui/btrack_options.py +898 -0
  12. celldetective/gui/classifier_widget.py +386 -0
  13. celldetective/gui/configure_new_exp.py +532 -0
  14. celldetective/gui/control_panel.py +438 -0
  15. celldetective/gui/gui_utils.py +495 -0
  16. celldetective/gui/json_readers.py +113 -0
  17. celldetective/gui/measurement_options.py +1425 -0
  18. celldetective/gui/neighborhood_options.py +452 -0
  19. celldetective/gui/plot_signals_ui.py +1042 -0
  20. celldetective/gui/process_block.py +1055 -0
  21. celldetective/gui/retrain_segmentation_model_options.py +706 -0
  22. celldetective/gui/retrain_signal_model_options.py +643 -0
  23. celldetective/gui/seg_model_loader.py +460 -0
  24. celldetective/gui/signal_annotator.py +2388 -0
  25. celldetective/gui/signal_annotator_options.py +340 -0
  26. celldetective/gui/styles.py +217 -0
  27. celldetective/gui/survival_ui.py +903 -0
  28. celldetective/gui/tableUI.py +608 -0
  29. celldetective/gui/thresholds_gui.py +1300 -0
  30. celldetective/icons/logo-large.png +0 -0
  31. celldetective/icons/logo.png +0 -0
  32. celldetective/icons/signals_icon.png +0 -0
  33. celldetective/icons/splash-test.png +0 -0
  34. celldetective/icons/splash.png +0 -0
  35. celldetective/icons/splash0.png +0 -0
  36. celldetective/icons/survival2.png +0 -0
  37. celldetective/icons/vignette_signals2.png +0 -0
  38. celldetective/icons/vignette_signals2.svg +114 -0
  39. celldetective/io.py +2050 -0
  40. celldetective/links/zenodo.json +561 -0
  41. celldetective/measure.py +1258 -0
  42. celldetective/models/segmentation_effectors/blank +0 -0
  43. celldetective/models/segmentation_generic/blank +0 -0
  44. celldetective/models/segmentation_targets/blank +0 -0
  45. celldetective/models/signal_detection/blank +0 -0
  46. celldetective/models/tracking_configs/mcf7.json +68 -0
  47. celldetective/models/tracking_configs/ricm.json +203 -0
  48. celldetective/models/tracking_configs/ricm2.json +203 -0
  49. celldetective/neighborhood.py +717 -0
  50. celldetective/scripts/analyze_signals.py +51 -0
  51. celldetective/scripts/measure_cells.py +275 -0
  52. celldetective/scripts/segment_cells.py +212 -0
  53. celldetective/scripts/segment_cells_thresholds.py +140 -0
  54. celldetective/scripts/track_cells.py +206 -0
  55. celldetective/scripts/train_segmentation_model.py +246 -0
  56. celldetective/scripts/train_signal_model.py +49 -0
  57. celldetective/segmentation.py +712 -0
  58. celldetective/signals.py +2826 -0
  59. celldetective/tracking.py +974 -0
  60. celldetective/utils.py +1681 -0
  61. celldetective-1.0.2.dist-info/LICENSE +674 -0
  62. celldetective-1.0.2.dist-info/METADATA +192 -0
  63. celldetective-1.0.2.dist-info/RECORD +66 -0
  64. celldetective-1.0.2.dist-info/WHEEL +5 -0
  65. celldetective-1.0.2.dist-info/entry_points.txt +2 -0
  66. celldetective-1.0.2.dist-info/top_level.txt +1 -0
celldetective/utils.py ADDED
@@ -0,0 +1,1681 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import re
5
+ import os
6
+ from scipy.ndimage import shift, zoom
7
+ os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '3'
8
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
9
+ from tensorflow.config import list_physical_devices
10
+ import configparser
11
+ from sklearn.utils.class_weight import compute_class_weight
12
+ from skimage.util import random_noise
13
+ from skimage.filters import gaussian
14
+ import random
15
+ from tifffile import imread
16
+ import json
17
+ from csbdeep.utils import normalize_mi_ma
18
+ from glob import glob
19
+ from urllib.request import urlopen
20
+ from urllib.parse import urlparse
21
+ import zipfile
22
+ from tqdm import tqdm
23
+ import shutil
24
+ import tempfile
25
+
26
def create_patch_mask(h, w, center=None, radius=None):
    """
    Create a boolean circular (or annular) patch mask of shape (h, w).

    Adapted from alkasm on
    https://stackoverflow.com/questions/44865023/how-can-i-create-a-circular-mask-for-a-numpy-array

    Parameters
    ----------
    h : int
        Height of the mask. Prefer an odd value.
    w : int
        Width of the mask. Prefer an odd value.
    center : tuple, optional
        (x, y) coordinates of the patch center. If not provided, the middle
        of the image is used.
    radius : int, float, numpy scalar, or list/tuple of two numbers, optional
        Radius of the circular patch. If a pair [inner, outer] is given, an
        annular mask between the two radii is produced. If not provided, the
        smallest distance between the center and the image walls is used.

    Returns
    -------
    numpy.ndarray or None
        Boolean mask, True for pixels inside the circular patch (or annulus),
        False outside. None if `radius` has an unsupported format.

    Examples
    --------
    >>> mask = create_patch_mask(100, 100, center=(50, 50), radius=30)
    >>> print(mask)

    """

    if center is None:
        # Use the middle of the image.
        center = (int(w/2), int(h/2))
    if radius is None:
        # Use the smallest distance between the center and image walls.
        radius = min(center[0], center[1], w-center[0], h-center[1])

    Y, X = np.ogrid[:h, :w]
    dist_from_center = np.sqrt((X - center[0])**2 + (Y-center[1])**2)

    # Accept NumPy scalar types too: np.integer is NOT a subclass of int,
    # so radii coming out of NumPy arithmetic were previously rejected.
    if isinstance(radius, (int, float, np.integer, np.floating)):
        mask = dist_from_center <= radius
    elif isinstance(radius, (list, tuple)):
        # Annular mask: keep pixels between the inner and outer radii.
        mask = (dist_from_center <= radius[1]) & (dist_from_center >= radius[0])
    else:
        print("Please provide a proper format for the radius")
        return None

    return mask
81
+
82
def rename_intensity_column(df, channels):

    """
    Rename intensity columns in a DataFrame based on the provided channel names.

    Parameters
    ----------
    df : pandas DataFrame
        The DataFrame containing the intensity columns.
    channels : list
        A list of channel names corresponding to the intensity columns.

    Returns
    -------
    pandas DataFrame
        The DataFrame with renamed intensity columns.

    Notes
    -----
    Columns whose name starts with 'intensity' are renamed by substituting
    'intensity' with the matching channel name. With multiple channels, the
    channel is identified by the numeric token embedded in the column name
    (e.g. 'intensity_mean-1' -> channel index 1). Special post-processing is
    applied to 'centre'-of-mass and 'radial_gradient' measurements, where the
    trailing 0/1 index selects between displacement/orientation and
    gradient/intercept respectively.

    Examples
    --------
    >>> data = {'intensity_0': [1, 2, 3], 'intensity_1': [4, 5, 6]}
    >>> df = pd.DataFrame(data)
    >>> channels = ['channel1', 'channel2']
    >>> renamed_df = rename_intensity_column(df, channels)

    """

    channel_names = np.array(channels)
    channel_indices = np.arange(len(channel_names),dtype=int)

    # Only do work if at least one column looks like an intensity measurement.
    if np.any(['intensity' in c for c in df.columns]):

        intensity_indices = [s.startswith('intensity') for s in df.columns]
        intensity_columns = df.columns[intensity_indices]

        if len(channel_names) > 1:
            # Multi-channel case: the channel index is encoded as a digit
            # token inside the column name.
            to_rename = {}
            for k in range(len(intensity_columns)):

                # Split on both '-' and '_' to isolate the digit token(s).
                sections = np.array(re.split('-|_', intensity_columns[k]))
                test_digit = np.array([s.isdigit() for s in sections])
                # The last digit token is taken as the channel index.
                index = int(sections[np.where(test_digit)[0]][-1])

                channel_name = channel_names[np.where(channel_indices==index)[0]][0]
                # Drop all digit tokens, then substitute the channel name.
                new_name = np.delete(sections, np.where(test_digit)[0])
                new_name = '_'.join(list(new_name))
                new_name = new_name.replace('intensity', channel_name)
                to_rename.update({intensity_columns[k]: new_name.replace('-','_')})
                if 'centre' in intensity_columns[k]:
                    measure = np.array(re.split('-|_', new_name))
                    # sections[-2] is the sub-measure index: "0" = displacement,
                    # "1" = orientation.
                    if sections[-2] == "0":
                        new_name = np.delete(measure, -1)
                        new_name = '_'.join(list(new_name))
                        if 'edge' in intensity_columns[k]:
                            new_name = new_name.replace('centre_of_mass_displacement', "edge_centre_of_mass_displacement_in_px")
                        else:
                            new_name = new_name.replace('centre_of_mass', "centre_of_mass_displacement_in_px")
                        to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
                    elif sections[-2] == "1":
                        new_name = np.delete(measure, -1)
                        new_name = '_'.join(list(new_name))
                        if 'edge' in intensity_columns[k]:
                            new_name = new_name.replace('centre_of_mass_displacement', "edge_centre_of_mass_orientation")
                        else:
                            new_name = new_name.replace('centre_of_mass', "centre_of_mass_orientation")
                        to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
                if 'radial_gradient' in intensity_columns[k]:
                    measure = np.array(re.split('-|_', new_name))
                    # "0" = gradient (slope), "1" = intercept of the radial fit.
                    if sections[-2] == "0":
                        # NOTE(review): this np.delete result is discarded —
                        # the join below uses `measure`, not `new_name`, and the
                        # following replace maps 'radial_gradient' to itself.
                        # Confirm whether the trailing index was meant to be
                        # stripped here as in the 'centre' branch.
                        new_name = np.delete(measure, -1)
                        new_name = '_'.join(list(measure))
                        new_name = new_name.replace('radial_gradient', "radial_gradient")
                        to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
                    elif sections[-2] == "1":
                        # NOTE(review): same discarded-delete pattern as above.
                        new_name = np.delete(measure, -1)
                        new_name = '_'.join(list(measure))
                        new_name = new_name.replace('radial_gradient', "radial_intercept")
                        to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})

        else:
            # Single-channel case: no channel index to resolve; 'intensity'
            # is replaced by the one channel name.
            to_rename = {}
            for k in range(len(intensity_columns)):
                sections = np.array(re.split('_|-', intensity_columns[k]))
                channel_name = channel_names[0]
                test_digit = np.array([s.isdigit() for s in sections])
                new_name = np.delete(sections, np.where(test_digit)[0])
                new_name = '_'.join(list(new_name))
                new_name = new_name.replace('intensity', channel_name)
                to_rename.update({intensity_columns[k]: new_name.replace('-','_')})
                if 'centre' in intensity_columns[k]:
                    measure = np.array(re.split('-|_', new_name))
                    if sections[-2] == "0":
                        new_name = np.delete(measure, -1)
                        new_name = '_'.join(list(new_name))
                        if 'edge' in intensity_columns[k]:
                            new_name = new_name.replace('centre_of_mass_displacement', "edge_centre_of_mass_displacement_in_px")
                        else:
                            new_name = new_name.replace('centre_of_mass', "centre_of_mass_displacement_in_px")
                        to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
                    # NOTE(review): plain `if` here (vs `elif` in the
                    # multi-channel branch) — confirm both indices can never
                    # match at once.
                    if sections[-2] == "1":
                        new_name = np.delete(measure, -1)
                        new_name = '_'.join(list(new_name))
                        if 'edge' in intensity_columns[k]:
                            new_name = new_name.replace('centre_of_mass_displacement', "edge_centre_of_mass_orientation")
                        else:
                            new_name = new_name.replace('centre_of_mass', "centre_of_mass_orientation")
                        to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
                if 'radial_gradient' in intensity_columns[k]:
                    measure = np.array(re.split('-|_', new_name))
                    if sections[-2] == "0":
                        #new_name = np.delete(measure, -1)
                        new_name = '_'.join(list(measure))
                        new_name = new_name.replace('radial_gradient', "radial_gradient")
                        to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
                    elif sections[-2] == "1":
                        #new_name = np.delete(measure, -1)
                        new_name = '_'.join(list(measure))
                        new_name = new_name.replace('radial_gradient', "radial_intercept")
                        to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})

        df = df.rename(columns=to_rename)

    return df
220
+
221
+
222
def regression_plot(y_pred, y_true, savepath=None):

    """
    Scatter-plot predicted values against ground-truth values.

    A dashed identity line spanning the full data range is drawn as the
    ideal-prediction reference. The figure is displayed for two seconds,
    optionally saved to disk, then closed.

    Parameters
    ----------
    y_pred : array-like
        Predicted values.
    y_true : array-like
        Ground truth values.
    savepath : str or None, optional
        File path to save the plot (300 dpi, tight bounding box). If None,
        the plot is displayed but not saved. Default is None.

    Returns
    -------
    None

    Examples
    --------
    >>> regression_plot([1.5, 2.0, 3.2], [1.7, 2.1, 3.5])
    >>> regression_plot([1.5, 2.0], [1.7, 2.1], savepath="regression_plot.png")
    """

    figure, axis = plt.subplots(1, 1, figsize=(4, 3))
    axis.scatter(y_pred, y_true)
    axis.set_xlabel("prediction")
    axis.set_ylabel("ground truth")

    # Identity (y = x) reference line over the combined data range.
    lower = np.amin([y_pred, y_true])
    upper = np.amax([y_pred, y_true])
    identity = np.linspace(lower, upper, 1000)
    axis.plot(identity, identity, linestyle="--", c="k", alpha=0.7)

    plt.tight_layout()
    if savepath is not None:
        plt.savefig(savepath, bbox_inches="tight", dpi=300)
    plt.pause(2)
    plt.close()
272
+
273
def split_by_ratio(arr, *ratios):

    """
    Randomly split an array into chunks whose sizes follow the given ratios.

    Parameters
    ----------
    arr : array-like
        The input array to be split.
    *ratios : float
        Ratios specifying the proportion of each chunk. Their sum should be
        less than or equal to 1.

    Returns
    -------
    list
        One list per ratio, containing the elements of that chunk.

    Notes
    -----
    The input is first randomly permuted, then cut at the accumulated
    ratio boundaries. Exactly `len(ratios)` chunks are returned; any
    remainder beyond the accumulated ratios is dropped.

    Examples
    --------
    >>> splits = split_by_ratio(np.arange(10), 0.6, 0.2, 0.2)
    >>> len(splits)
    3
    >>> [len(s) for s in split_by_ratio(np.arange(100), 0.5, 0.25)]
    [50, 25]
    """

    shuffled = np.random.permutation(arr)
    # Cumulative ratio boundaries, scaled to element counts.
    cut_points = np.add.accumulate(np.array(ratios) * len(shuffled)).astype(int)
    pieces = np.split(shuffled, cut_points)
    # np.split yields len(cut_points)+1 pieces; keep one per requested ratio.
    return [piece.tolist() for piece in pieces][:len(ratios)]
319
+
320
def compute_weights(y):

    """
    Compute balanced class weights for an array of labels.

    Parameters
    ----------
    y : array-like
        Array of labels.

    Returns
    -------
    dict
        Mapping from each unique class in `y` to its balanced weight
        (inversely proportional to the class frequency).

    Examples
    --------
    >>> compute_weights(np.array([0, 1, 0, 1, 1]))
    {0: 1.25, 1: 0.8333333333333334}
    """

    classes = np.unique(y)
    weights = compute_class_weight(
        class_weight="balanced",
        classes=classes,
        y=y,
    )
    return dict(zip(classes, weights))
369
+
370
def train_test_split(data_x, data_y1, data_y2=None, validation_size=0.25, test_size=0):

    """
    Split a dataset into training, validation and (optionally) test sets.

    Parameters
    ----------
    data_x : array-like
        Input features or independent variables.
    data_y1 : array-like
        Target variable 1.
    data_y2 : array-like, optional
        Target variable 2. If None, the "y2_*" keys are omitted.
    validation_size : float, optional
        Proportion of the dataset to include in the validation set. Default is 0.25.
    test_size : float, optional
        Proportion of the dataset to include in the test set. Default is 0.

    Returns
    -------
    dict
        Keys "x_train", "x_val", "y1_train", "y1_val" (plus "y2_train",
        "y2_val" when `data_y2` is given). When `test_size` > 0, also
        "x_test", "y1_test" (and "y2_test" when applicable).

    Notes
    -----
    Sample indices are shuffled once, then partitioned by `split_by_ratio`
    so that x and y splits stay aligned.
    """

    sample_count = len(data_x)
    order = np.arange(sample_count)
    np.random.shuffle(order)

    train_fraction = 1 - validation_size - test_size
    # Shared index chunks keep features and targets aligned.
    chunks = split_by_ratio(order, train_fraction, validation_size, test_size)

    ds = {
        "x_train": data_x[chunks[0]],
        "x_val": data_x[chunks[1]],
        "y1_train": data_y1[chunks[0]],
        "y1_val": data_y1[chunks[1]],
    }
    if data_y2 is not None:
        ds.update({"y2_train": data_y2[chunks[0]], "y2_val": data_y2[chunks[1]]})

    if test_size > 0:
        ds.update({"x_test": data_x[chunks[2]], "y1_test": data_y1[chunks[2]]})
        if data_y2 is not None:
            ds.update({"y2_test": data_y2[chunks[2]]})
    return ds
441
+
442
def remove_redundant_features(features, reference_features, channel_names=None):

    """
    Remove redundant features from a list, based on a reference feature list.

    Parameters
    ----------
    features : list
        The list of features to be filtered.
    reference_features : list
        The reference list of features.
    channel_names : list or None, optional
        The list of channel names. If provided, intensity features whose
        channel-resolved equivalent (e.g. 'intensity_mean' for channel
        'brightfield' -> 'brightfield_mean') appears in the reference list
        are also removed. Default is None.

    Returns
    -------
    list
        The filtered list of features without redundant entries.

    Examples
    --------
    >>> remove_redundant_features(['area', 'intensity_mean'], ['area'])
    ['intensity_mean']
    >>> remove_redundant_features(['intensity_mean', 'area'], ['ch1_mean'], channel_names=['ch1'])
    ['area']
    """

    filtered = features.copy()

    for feature in features:

        if feature in reference_features:
            filtered.remove(feature)

        if ('intensity' in feature) and (channel_names is not None):

            # An intensity feature is redundant when the reference list
            # already contains the same statistic resolved to a channel name.
            mode = feature.split('_')[-1]
            patterns = [channel + '_' + mode for channel in channel_names]

            for pattern in patterns:
                if pattern in reference_features:
                    # Narrow except: the feature may already have been
                    # removed by an earlier match (bare `except:` previously
                    # hid any other failure too).
                    try:
                        filtered.remove(feature)
                    except ValueError:
                        pass
    return filtered
504
+
505
+ def _estimate_scale_factor(spatial_calibration, required_spatial_calibration):
506
+
507
+ """
508
+ Estimates the scale factor needed to adjust spatial calibration to a required value.
509
+
510
+ This function calculates the scale factor by which spatial dimensions (e.g., in microscopy images)
511
+ should be adjusted to align with a specified calibration standard. This is particularly useful when
512
+ preparing data for analysis with models trained on data of a specific spatial calibration.
513
+
514
+ Parameters
515
+ ----------
516
+ spatial_calibration : float or None
517
+ The current spatial calibration factor of the data, expressed as units per pixel (e.g., micrometers per pixel).
518
+ If None, indicates that the current spatial calibration is unknown or unspecified.
519
+ required_spatial_calibration : float or None
520
+ The spatial calibration factor required for compatibility with the model or analysis standard, expressed
521
+ in the same units as `spatial_calibration`. If None, indicates no adjustment is required.
522
+
523
+ Returns
524
+ -------
525
+ float or None
526
+ The scale factor by which the current data should be rescaled to match the required spatial calibration,
527
+ or None if no scaling is necessary or if insufficient information is provided.
528
+
529
+ Notes
530
+ -----
531
+ - A scale factor close to 1 (within a tolerance defined by `epsilon`) indicates that no significant rescaling
532
+ is needed, and the function returns None.
533
+ - The function issues a warning if a significant rescaling is necessary, indicating the scale factor to be applied.
534
+
535
+ Examples
536
+ --------
537
+ >>> scale_factor = _estimate_scale_factor(spatial_calibration=0.5, required_spatial_calibration=0.25)
538
+ # Each frame will be rescaled by a factor 2.0 to match with the model training data...
539
+
540
+ >>> scale_factor = _estimate_scale_factor(spatial_calibration=None, required_spatial_calibration=0.25)
541
+ # Returns None due to insufficient information about current spatial calibration.
542
+ """
543
+
544
+ if (required_spatial_calibration is not None)*(spatial_calibration is not None):
545
+ scale = spatial_calibration / required_spatial_calibration
546
+ else:
547
+ scale = None
548
+
549
+ epsilon = 0.05
550
+ if scale is not None:
551
+ if not np.all([scale >= (1-epsilon), scale <= (1+epsilon)]):
552
+ print(f"Each frame will be rescaled by a factor {scale} to match with the model training data...")
553
+ else:
554
+ scale = None
555
+ return scale
556
+
557
def auto_find_gpu():

    """
    Report whether at least one GPU device is visible to TensorFlow.

    Queries the system's physical devices via TensorFlow's
    `list_physical_devices` so callers can adapt their computation strategy
    to the available hardware.

    Returns
    -------
    bool
        True if one or more 'GPU' devices are detected, False otherwise.

    Examples
    --------
    >>> has_gpu = auto_find_gpu()
    >>> print(f"GPU available: {has_gpu}")
    """

    detected_gpus = list_physical_devices('GPU')
    return len(detected_gpus) > 0
591
+
592
+ def _extract_channel_indices(channels, required_channels):
593
+
594
+ """
595
+ Extracts the indices of required channels from a list of available channels.
596
+
597
+ This function is designed to match the channels required by a model or analysis process with the channels
598
+ present in the dataset. It returns the indices of the required channels within the list of available channels.
599
+ If the required channels are not found among the available channels, the function prints an error message and
600
+ returns None.
601
+
602
+ Parameters
603
+ ----------
604
+ channels : list of str or None
605
+ A list containing the names of the channels available in the dataset. If None, it is assumed that the
606
+ dataset channels are in the same order as the required channels.
607
+ required_channels : list of str
608
+ A list containing the names of the channels required by the model or analysis process.
609
+
610
+ Returns
611
+ -------
612
+ ndarray or None
613
+ An array of indices indicating the positions of the required channels within the list of available
614
+ channels. Returns None if there is a mismatch between required and available channels.
615
+
616
+ Notes
617
+ -----
618
+ - The function is useful for preprocessing steps where specific channels of multi-channel data are needed
619
+ for further analysis or model input.
620
+ - In cases where `channels` is None, indicating that the dataset does not specify channel names, the function
621
+ assumes that the dataset's channel order matches the order of `required_channels` and returns an array of
622
+ indices based on this assumption.
623
+
624
+ Examples
625
+ --------
626
+ >>> available_channels = ['DAPI', 'GFP', 'RFP']
627
+ >>> required_channels = ['GFP', 'RFP']
628
+ >>> indices = _extract_channel_indices(available_channels, required_channels)
629
+ >>> print(indices)
630
+ # [1, 2]
631
+
632
+ >>> indices = _extract_channel_indices(None, required_channels)
633
+ >>> print(indices)
634
+ # [0, 1]
635
+ """
636
+
637
+ if channels is not None:
638
+ channel_indices = []
639
+ for ch in required_channels:
640
+
641
+ try:
642
+ idx = channels.index(ch)
643
+ except ValueError:
644
+ print('Mismatch between the channels required by the model and the provided channels.')
645
+ return None
646
+
647
+ channel_indices.append(idx)
648
+ channel_indices = np.array(channel_indices)
649
+ else:
650
+ channel_indices = np.arange(len(required_channels))
651
+
652
+ return channel_indices
653
+
654
def ConfigSectionMap(path, section):

    """
    Parse one section of a config (.ini) file into a dictionary.

    Following https://wiki.python.org/moin/ConfigParserExamples

    Parameters
    ----------
    path : str
        Path to the config.ini file.
    section : str
        Name of the section that contains the parameters.

    Returns
    -------
    dict
        Mapping from option name to its string value; None for options whose
        value could not be read (a message is printed in that case).
    """

    config = configparser.ConfigParser()
    config.read(path)
    section_map = {}
    for option in config.options(section):
        try:
            section_map[option] = config.get(section, option)
            # Removed dead branch `if dict1[option] == -1: DebugPrint(...)`:
            # Config.get returns a str so the comparison could never be true,
            # and DebugPrint was undefined (latent NameError).
        except Exception:
            # Keep the option key but flag the unreadable value.
            print("exception on %s!" % option)
            section_map[option] = None
    return section_map
689
+
690
def _extract_channel_indices_from_config(config, channels_to_extract):

    """
    Extract the indices of specified channels from a configuration file.

    Maps required channel names to their indices as declared in the config.
    Two layouts are supported: the current one with a "Channels" section
    (V2), and a legacy layout where indices live under "MovieSettings".
    If the V2 lookup fails, the legacy lookup is attempted.

    Parameters
    ----------
    config : str
        Path to the configuration file; it is passed to `ConfigSectionMap`,
        which reads it with configparser. (The name suggests a parser
        object, but the visible call chain expects a path.)
    channels_to_extract : list of str
        Channel names for which indices are to be extracted. Entries equal
        to 'None' (string) or None are mapped to None in the output.

    Returns
    -------
    list of int/None, or None
        Indices of the specified channels as integers (None placeholders for
        'None' entries). Returns None if a required channel cannot be found
        in either section; an error message is printed in that case.

    Examples
    --------
    >>> channel_indices = _extract_channel_indices_from_config('config.ini', ['GFP', 'RFP'])
    """

    # V2 layout: indices are stored in the "Channels" section.
    channels = []
    for c in channels_to_extract:
        if c!='None' and c is not None:
            try:
                c1 = int(ConfigSectionMap(config,"Channels")[c])
                channels.append(c1)
            except Exception as e:
                # Missing section/key or non-integer value: abandon the V2
                # lookup entirely and fall through to the legacy layout.
                print(f"Error {e}. The channel required by the model is not available in your data... Check the configuration file.")
                channels = None
                break
        else:
            # Placeholder channel: keep alignment with channels_to_extract.
            channels.append(None)

    # LEGACY layout: indices are stored in the "MovieSettings" section.
    # NOTE(review): unlike the V2 loop above, this loop has no special-case
    # for 'None'/None entries — confirm legacy configs never contain them.
    if channels is None:
        channels = []
        for c in channels_to_extract:
            try:
                c1 = int(ConfigSectionMap(config,"MovieSettings")[c])
                channels.append(c1)
            except Exception as e:
                print(f"Error {e}. The channel required by the model is not available in your data... Check the configuration file.")
                channels = None
                break
    return channels
759
+
760
+ def _extract_nbr_channels_from_config(config, return_names=False):
761
+
762
+ """
763
+ Extracts the indices of specified channels from a configuration object.
764
+
765
+ This function attempts to map required channel names to their respective indices as specified in a
766
+ configuration file. It supports two versions of configuration parsing: a primary method (V2) and a
767
+ fallback legacy method. If the required channels are not found using the primary method, the function
768
+ attempts to find them using the legacy configuration settings.
769
+
770
+ Parameters
771
+ ----------
772
+ config : ConfigParser object
773
+ The configuration object parsed from a .ini or similar configuration file that includes channel settings.
774
+ channels_to_extract : list of str
775
+ A list of channel names for which indices are to be extracted from the configuration settings.
776
+
777
+ Returns
778
+ -------
779
+ list of int or None
780
+ A list containing the indices of the specified channels as found in the configuration settings.
781
+ If a channel cannot be found, None is appended in its place. If an error occurs during the extraction
782
+ process, the function returns None.
783
+
784
+ Notes
785
+ -----
786
+ - This function is designed to be flexible, accommodating changes in configuration file structure by
787
+ checking multiple sections for the required information.
788
+ - The configuration file is expected to contain either "Channels" or "MovieSettings" sections with mappings
789
+ from channel names to indices.
790
+ - An error message is printed if a required channel cannot be found, advising the user to check the
791
+ configuration file.
792
+
793
+ Examples
794
+ --------
795
+ >>> config = ConfigParser()
796
+ >>> config.read('example_config.ini')
797
+ >>> channels_to_extract = ['GFP', 'RFP']
798
+ >>> channel_indices = _extract_channel_indices_from_config(config, channels_to_extract)
799
+ >>> print(channel_indices)
800
+ # [1, 2] or None if an error occurs or the channels are not found.
801
+ """
802
+
803
+ # V2
804
+ nbr_channels = 0
805
+ channels = []
806
+ try:
807
+ fields = ConfigSectionMap(config,"Channels")
808
+ for c in fields:
809
+ try:
810
+ channel = int(ConfigSectionMap(config, "Channels")[c])
811
+ nbr_channels += 1
812
+ channels.append(c)
813
+ except:
814
+ pass
815
+ except:
816
+ pass
817
+
818
+ if nbr_channels==0:
819
+
820
+ # Read channels LEGACY
821
+ nbr_channels = 0
822
+ channels = []
823
+ try:
824
+ brightfield_channel = int(ConfigSectionMap(config,"MovieSettings")["brightfield_channel"])
825
+ nbr_channels += 1
826
+ channels.append('brightfield_channel')
827
+ except:
828
+ brightfield_channel = None
829
+
830
+ try:
831
+ live_nuclei_channel = int(ConfigSectionMap(config,"MovieSettings")["live_nuclei_channel"])
832
+ nbr_channels += 1
833
+ channels.append('live_nuclei_channel')
834
+ except:
835
+ live_nuclei_channel = None
836
+
837
+ try:
838
+ dead_nuclei_channel = int(ConfigSectionMap(config,"MovieSettings")["dead_nuclei_channel"])
839
+ nbr_channels +=1
840
+ channels.append('dead_nuclei_channel')
841
+ except:
842
+ dead_nuclei_channel = None
843
+
844
+ try:
845
+ effector_fluo_channel = int(ConfigSectionMap(config,"MovieSettings")["effector_fluo_channel"])
846
+ nbr_channels +=1
847
+ channels.append('effector_fluo_channel')
848
+ except:
849
+ effector_fluo_channel = None
850
+
851
+ try:
852
+ adhesion_channel = int(ConfigSectionMap(config,"MovieSettings")["adhesion_channel"])
853
+ nbr_channels += 1
854
+ channels.append('adhesion_channel')
855
+ except:
856
+ adhesion_channel = None
857
+
858
+ try:
859
+ fluo_channel_1 = int(ConfigSectionMap(config,"MovieSettings")["fluo_channel_1"])
860
+ nbr_channels += 1
861
+ channels.append('fluo_channel_1')
862
+ except:
863
+ fluo_channel_1 = None
864
+
865
+ try:
866
+ fluo_channel_2 = int(ConfigSectionMap(config,"MovieSettings")["fluo_channel_2"])
867
+ nbr_channels += 1
868
+ channels.append('fluo_channel_2')
869
+ except:
870
+ fluo_channel_2 = None
871
+
872
+ if return_names:
873
+ return nbr_channels,channels
874
+ else:
875
+ return nbr_channels
876
+
877
+ def _get_img_num_per_channel(channels_indices, len_movie, nbr_channels):
878
+
879
+ """
880
+ Calculates the image frame numbers for each specified channel in a multi-channel movie.
881
+
882
+ Given the indices of channels of interest, the total length of the movie, and the number of channels,
883
+ this function computes the frame numbers corresponding to each channel throughout the movie. If a
884
+ channel index is specified as None, it assigns a placeholder value to indicate no frames for that channel.
885
+
886
+ Parameters
887
+ ----------
888
+ channels_indices : list of int or None
889
+ A list containing the indices of channels for which to calculate frame numbers. If an index is None,
890
+ it is interpreted as a channel with no frames to be processed.
891
+ len_movie : int
892
+ The total number of frames in the movie across all channels.
893
+ nbr_channels : int
894
+ The total number of channels in the movie.
895
+
896
+ Returns
897
+ -------
898
+ ndarray
899
+ A 2D numpy array where each row corresponds to a channel specified in `channels_indices` and contains
900
+ the frame numbers for that channel throughout the movie. If a channel index is None, the corresponding
901
+ row contains placeholder values (-1).
902
+
903
+ Notes
904
+ -----
905
+ - The function assumes that frames in the movie are interleaved by channel, with frames for each channel
906
+ appearing in a regular sequence throughout the movie.
907
+ - This utility is particularly useful for multi-channel time-lapse movies where analysis or processing
908
+ needs to be performed on a per-channel basis.
909
+
910
+ Examples
911
+ --------
912
+ >>> channels_indices = [0, 2, None] # Indices for channels 1, 3, and a non-existing channel
913
+ >>> len_movie = 10 # Total frames for each channel
914
+ >>> nbr_channels = 3 # Total channels in the movie
915
+ >>> img_num_per_channel = _get_img_num_per_channel(channels_indices, len_movie, nbr_channels)
916
+ >>> print(img_num_per_channel)
917
+ # [[ 0 3 6 9 12 15 18 21 24 27]
918
+ # [ 2 5 8 11 14 17 20 23 26 29]
919
+ # [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1]]
920
+ """
921
+
922
+ len_movie = int(len_movie)
923
+ nbr_channels = int(nbr_channels)
924
+
925
+ img_num_all_channels = []
926
+ for c in channels_indices:
927
+ if c is not None:
928
+ indices = np.arange(len_movie*nbr_channels)[c::nbr_channels]
929
+ else:
930
+ indices = [-1]*len_movie
931
+ img_num_all_channels.append(indices)
932
+ img_num_all_channels = np.array(img_num_all_channels, dtype=int)
933
+ return img_num_all_channels
934
+
935
+ def _extract_labels_from_config(config,number_of_wells):
936
+
937
+ """
938
+
939
+ Extract each well's biological condition from the configuration file
940
+
941
+ Parameters
942
+ ----------
943
+
944
+ config: str,
945
+ path to the configuration file
946
+
947
+ number_of_wells: int,
948
+ total number of wells in the experiment
949
+
950
+ Returns
951
+ -------
952
+
953
+ labels: string of the biological condition for each well
954
+
955
+ """
956
+
957
+ try:
958
+ concentrations = ConfigSectionMap(config,"Labels")["concentrations"].split(",")
959
+ cell_types = ConfigSectionMap(config,"Labels")["cell_types"].split(",")
960
+ antibodies = ConfigSectionMap(config,"Labels")["antibodies"].split(",")
961
+ pharmaceutical_agents = ConfigSectionMap(config,"Labels")["pharmaceutical_agents"].split(",")
962
+ index = np.arange(len(concentrations)).astype(int) + 1
963
+ if not np.all(pharmaceutical_agents=="None"):
964
+ labels = [f"W{idx}: [CT] "+a+"; [Ab] "+b+" @ "+c+" pM "+d for idx,a,b,c,d in zip(index,cell_types,antibodies,concentrations,pharmaceutical_agents)]
965
+ else:
966
+ labels = [f"W{idx}: [CT] "+a+"; [Ab] "+b+" @ "+c+" pM " for idx,a,b,c in zip(index,cell_types,antibodies,concentrations)]
967
+
968
+
969
+ except Exception as e:
970
+ print(f"{e}: the well labels cannot be read from the concentration and cell_type fields")
971
+ labels = np.linspace(0,number_of_wells-1,number_of_wells,dtype=str)
972
+
973
+ return(labels)
974
+
975
def extract_experiment_channels(config):

    """
    Extracts channel names and their indices from an experiment configuration.

    This function attempts to parse channel information from a given configuration object, supporting
    both a newer (V2) and a legacy format. It first tries to extract channel names and indices according
    to the V2 format from the "Channels" section. If no channels are found or if the section does not
    exist, it falls back to extracting specific channel information from the "MovieSettings" section
    based on predefined channel names.

    Parameters
    ----------
    config : ConfigParser object
        The configuration object parsed from an experiment's .ini or similar configuration file.

    Returns
    -------
    tuple
        A tuple containing two numpy arrays: `channel_names` and `channel_indices`. `channel_names` includes
        the names of the channels as specified in the configuration, and `channel_indices` includes their
        corresponding indices. Both arrays are ordered according to the channel indices.

    Notes
    -----
    - Supported legacy channel names cover brightfield, live/dead nuclei,
      effector fluorescence, adhesion and generic fluorescence channels.
    - If channel information cannot be parsed, both returned arrays are empty.

    Examples
    --------
    >>> config = ConfigParser()
    >>> config.read('experiment_config.ini')
    >>> channel_names, channel_indices = extract_experiment_channels(config)
    # Extracts and sorts channel information based on indices from the experiment configuration.
    """

    # V2 layout: every integer-valued option of the "Channels" section is a channel.
    channel_names = []
    channel_indices = []
    try:
        fields = ConfigSectionMap(config, "Channels")
        for c in fields:
            try:
                idx = int(ConfigSectionMap(config, "Channels")[c])
            except Exception:
                continue
            channel_names.append(c)
            channel_indices.append(idx)
    except Exception:
        pass

    if not channel_names:
        # LEGACY layout: probe the fixed set of known channel names in "MovieSettings".
        legacy_names = [
            "brightfield_channel",
            "live_nuclei_channel",
            "dead_nuclei_channel",
            "effector_fluo_channel",
            "adhesion_channel",
            "fluo_channel_1",
            "fluo_channel_2",
        ]
        channel_names = []
        channel_indices = []
        for name in legacy_names:
            try:
                idx = int(ConfigSectionMap(config, "MovieSettings")[name])
            except Exception:
                continue
            channel_names.append(name)
            channel_indices.append(idx)

    # Sort the names by channel index so they follow the acquisition order.
    channel_indices = np.array(channel_indices)
    channel_names = np.array(channel_names)
    reorder = np.argsort(channel_indices)
    channel_indices = channel_indices[reorder]
    channel_names = channel_names[reorder]

    return channel_names, channel_indices
1093
+
1094
def get_software_location():
    """Return the parent directory of this package, i.e. the software install location."""
    package_dir = os.path.dirname(os.path.realpath(__file__))
    return os.path.split(package_dir)[0]
1096
+
1097
def remove_trajectory_measurements(trajectories, column_labels):

    """
    Filters a DataFrame of trajectory measurements to retain only essential tracking and classification columns.

    Given a DataFrame containing detailed trajectory measurements and metadata for tracked objects, this
    function reduces the DataFrame to include only a predefined set of essential columns: basic tracking
    information, spatial coordinates (pixel and physical units), classification results, state/lineage
    metadata and visualization attributes.

    Parameters
    ----------
    trajectories : pandas.DataFrame
        The DataFrame containing trajectory measurements and metadata for each tracked object.
    column_labels : dict
        A dictionary mapping standard column names to their corresponding column names in the `trajectories` DataFrame.
        Expected keys include 'track', 'time', 'x', 'y'.

    Returns
    -------
    pandas.DataFrame
        A copy of `trajectories` restricted to the essential columns that are actually present.

    Examples
    --------
    >>> column_labels = {'track': 'TRACK_ID', 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}
    >>> filtered_df = remove_trajectory_measurements(trajectories_df, column_labels)
    # `filtered_df` contains only the essential columns present in `trajectories_df`.
    """

    tracks = trajectories.copy()

    columns_to_keep = [column_labels['track'], column_labels['time'],
                       column_labels['x'], column_labels['y'],
                       column_labels['x']+'_um', column_labels['y']+'_um', 'class_id',
                       't', 'state', 'generation', 'root', 'parent', 'ID', 't0',
                       'class', 'status', 'class_color', 'status_color',
                       'class_firstdetection', 't_firstdetection']

    # Keep only the essential columns actually present in the table.
    # (The previous implementation also removed entries from columns_to_keep
    # while iterating over it — a list-mutation-during-iteration bug that
    # silently skipped elements; this single filter replaces it.)
    keep = [c for c in columns_to_keep if c in tracks.columns]
    tracks = tracks[keep]

    return tracks
1157
+
1158
+
1159
def color_from_status(status, recently_modified=False):
    """Map a status code (0/1/2) to a matplotlib color name.

    An alternate palette is used when the cell was recently modified;
    unknown statuses fall back to black ('k').
    """
    standard = {0: 'tab:blue', 1: 'tab:red', 2: 'yellow'}
    modified = {0: 'tab:cyan', 1: 'tab:orange', 2: 'tab:olive'}
    palette = modified if recently_modified else standard
    return palette.get(status, 'k')
1179
+
1180
def color_from_class(cclass, recently_modified=False):
    """Map a class code (0/1/2) to a matplotlib color name.

    An alternate palette is used when the cell was recently modified;
    unknown classes fall back to black ('k').
    """
    standard = {0: 'tab:red', 1: 'tab:blue', 2: 'yellow'}
    modified = {0: 'tab:orange', 1: 'tab:cyan', 2: 'tab:olive'}
    palette = modified if recently_modified else standard
    return palette.get(cclass, 'k')
1200
+
1201
def random_fliprot(img, mask):

    """
    Apply the same random axis permutation and random flips to an image and
    its mask.

    Expects channel-last (YXC) images; the mask carries only the spatial axes.
    """

    assert img.ndim >= mask.ndim
    spatial_axes = tuple(range(mask.ndim))
    # Random permutation of the spatial axes, shared by image and mask;
    # the image keeps its trailing (channel) axes in place.
    order = tuple(np.random.permutation(spatial_axes))
    trailing = tuple(range(mask.ndim, img.ndim))
    img = img.transpose(order + trailing)
    mask = mask.transpose(order)
    # Independently flip each spatial axis with probability 0.5.
    for axis in spatial_axes:
        if np.random.rand() > 0.5:
            img = np.flip(img, axis=axis)
            mask = np.flip(mask, axis=axis)
    return img, mask
1219
+
1220
+ # def random_intensity_change(img):
1221
+ # img[img!=0.] = img[img!=0.]*np.random.uniform(0.3,2)
1222
+ # img[img!=0.] += np.random.uniform(-0.2,0.2)
1223
+ # return img
1224
+
1225
def random_shift(image, mask, max_shift_amplitude=0.1):

    """
    Translate an image and its mask by the same random XY offset.

    Expects channel-last (YXC) images. The shift amplitude is bounded by
    `max_shift_amplitude` times the first spatial dimension; regions shifted
    into view are zero-filled.
    """

    side = image.shape[0]
    max_shift = side * max_shift_amplitude

    # Draw an integer-valued magnitude for each axis and a random sign.
    dx = random.choice(np.arange(max_shift))
    if np.random.random() > 0.5:
        dx *= -1

    dy = random.choice(np.arange(max_shift))
    if np.random.random() > 0.5:
        dy *= -1

    # Cubic interpolation for the image, nearest-neighbour for the labels.
    shifted_image = shift(image, [dx, dy, 0], output=np.float32, order=3, mode="constant", cval=0.0)
    shifted_mask = shift(mask, [dx, dy], order=0, mode="constant", cval=0.0)

    return shifted_image, shifted_mask
1249
+
1250
+
1251
def blur(x, max_sigma=4.0):
    """
    Apply a Gaussian blur of random strength (sigma uniform in [0, max_sigma))
    to a channel-last image, restoring zero-valued (padding) pixels afterwards.
    """
    sigma = np.random.random() * max_sigma
    # Remember the zero (padding) pixels so blurring cannot bleed into them.
    zero_px = np.where(x == 0.)
    blurred = gaussian(x, sigma, channel_axis=-1, preserve_range=True)
    blurred[zero_px] = 0.

    return blurred
1261
+
1262
def noise(x, apply_probability=0.5, clip_option=False):

    """
    Apply a random sequence of noise models to each channel of an image.

    Each of the four noise modes is independently applied with probability
    `apply_probability`, in a random order per channel. Zero-valued (padding)
    pixels are restored to zero afterwards.
    """

    noisy = x.astype(float).copy()
    # Remember the zero (padding) pixels so noise cannot corrupt them.
    zero_px = np.where(noisy == 0.)
    modes = ['gaussian', 'localvar', 'poisson', 'speckle']

    for channel in range(noisy.shape[-1]):
        for mode in random.sample(modes, len(modes)):
            if np.random.random() <= apply_probability:
                try:
                    noisy[:, :, channel] = random_noise(noisy[:, :, channel], mode=mode, clip=clip_option)
                except:
                    pass

    noisy[zero_px] = 0.

    return noisy
1286
+
1287
+
1288
+
1289
def augmenter(x, y, flip=True, gauss_blur=True, noise_option=True, shift=True,
              channel_extinction=False, extinction_probability=0.1, clip=False, max_sigma_blur=4,
              apply_noise_probability=0.5, augment_probability=0.9):

    """
    Applies a series of augmentation techniques to images and their corresponding masks for deep learning training.

    This function randomly applies a set of transformations including flipping, rotation, Gaussian blur,
    additive noise, shifting, and channel extinction to input images (x) and their masks (y) based on specified
    probabilities.

    Parameters
    ----------
    x : ndarray
        The input image to be augmented, with dimensions (height, width, channels).
    y : ndarray
        The corresponding mask or label image for `x`, with the same spatial dimensions.
    flip : bool, optional
        Whether to randomly flip and rotate the images. Default is True.
    gauss_blur : bool, optional
        Whether to apply Gaussian blur to the images. Default is True.
    noise_option : bool, optional
        Whether to add random noise to the images. Default is True.
    shift : bool, optional
        Whether to randomly shift the images. Default is True.
    channel_extinction : bool, optional
        Whether to randomly set entire channels of the image to zero. Default is False.
    extinction_probability : float, optional
        The probability of an entire channel being set to zero. Default is 0.1.
    clip : bool, optional
        Whether to clip the noise-added images to stay within valid intensity values. Default is False.
    max_sigma_blur : int, optional
        The maximum sigma value for Gaussian blur. Default is 4.
    apply_noise_probability : float, optional
        The probability of applying noise to the image. Default is 0.5.
    augment_probability : float, optional
        The overall probability of applying any augmentation to the image. Default is 0.9.

    Returns
    -------
    tuple
        A tuple containing the augmented image and mask `(x, y)`.

    Raises
    ------
    AssertionError
        If `extinction_probability` is not within the range [0, 1].

    Examples
    --------
    >>> x = np.random.rand(128, 128, 3)
    >>> y = np.random.randint(2, size=(128, 128))
    >>> x_aug, y_aug = augmenter(x, y)
    """

    r = random.random()
    if r <= augment_probability:

        if flip:
            x, y = random_fliprot(x, y)

        if gauss_blur:
            x = blur(x, max_sigma=max_sigma_blur)

        if noise_option:
            x = noise(x, apply_probability=apply_noise_probability, clip_option=clip)

        if shift:
            x, y = random_shift(x, y)

        if channel_extinction:
            # Validate both bounds: the previous check only enforced the upper
            # bound although the message promised the full [0, 1] range.
            assert 0. <= extinction_probability <= 1., 'The extinction probability must be a number between 0 and 1.'
            for i in range(x.shape[-1]):
                if np.random.random() > (1 - extinction_probability):
                    x[:, :, i] = 0.

    return x, y
1376
+
1377
def normalize_per_channel(X, normalization_percentile_mode=True, normalization_values=[0.1,99.99], normalization_clipping=False):

    """
    Normalize every channel of each image in a list, in place.

    Each channel is rescaled either between percentile bounds of its non-zero
    values (percentile mode) or between fixed min/max values. Scalar options
    are broadcast to all channels.

    Parameters
    ----------
    X : list of ndarray
        Multi-channel images of shape (height, width, channels).
    normalization_percentile_mode : bool or list of bool, optional
        Per-channel flag: True to interpret `normalization_values` as
        percentiles of the non-zero pixels, False to use them as fixed
        min/max bounds. Default is True.
    normalization_values : list of two floats or list of such lists, optional
        Percentiles [low, high] or fixed bounds [min, max] per channel.
        Default is [0.1, 99.99].
    normalization_clipping : bool or list of bool, optional
        Per-channel flag to clip normalized values to [0, 1]. Default is False.

    Returns
    -------
    list of ndarray
        The same list `X`, with every image replaced by its float32
        normalized version.

    Raises
    ------
    AssertionError
        If images lack a channel axis or option lengths do not match the
        number of channels.
    """

    assert X[0].ndim == 3, 'Channel axis does not exist. Abort.'
    n_channels = X[0].shape[-1]

    # Broadcast scalar options to one entry per channel.
    if isinstance(normalization_percentile_mode, bool):
        normalization_percentile_mode = [normalization_percentile_mode] * n_channels
    if isinstance(normalization_clipping, bool):
        normalization_clipping = [normalization_clipping] * n_channels
    if len(normalization_values) == 2 and not isinstance(normalization_values[0], list):
        normalization_values = [normalization_values] * n_channels

    assert len(normalization_values) == n_channels
    assert len(normalization_clipping) == n_channels
    assert len(normalization_percentile_mode) == n_channels

    for pos, img in enumerate(X):
        normalized = np.zeros_like(img, dtype=np.float32)
        for k in range(img.shape[-1]):
            chan = img[:, :, k]
            if np.all(chan == 0):
                # A fully-empty channel stays zero.
                continue
            if normalization_percentile_mode[k]:
                nonzero = chan[chan != 0.].flatten()
                low = np.percentile(nonzero, normalization_values[k][0])
                high = np.percentile(nonzero, normalization_values[k][1])
            else:
                low = normalization_values[k][0]
                high = normalization_values[k][1]

            normalized[:, :, k] = normalize_mi_ma(chan.astype(np.float32), low, high,
                                                  clip=normalization_clipping[k], eps=1e-20, dtype=np.float32)

        X[pos] = normalized

    return X
1459
+
1460
def load_image_dataset(datasets, channels, train_spatial_calibration=None, mask_suffix='labelled'):

    """
    Loads image and corresponding mask datasets, optionally applying spatial calibration adjustments.

    This function iterates over specified datasets, loading image and mask pairs based on provided channels
    and adjusting images according to a specified spatial calibration factor. It supports loading images with
    multiple channels and applies necessary transformations to match the training spatial calibration.

    Parameters
    ----------
    datasets : list of str
        A list of paths to the datasets containing the images and masks.
    channels : str or list of str
        The channel(s) to be loaded from the images. If a string is provided, it is converted into a list.
    train_spatial_calibration : float, optional
        The spatial calibration (e.g., micrometers per pixel) used during model training. If provided, images
        will be rescaled to match this calibration. Default is None, indicating no rescaling is applied.
    mask_suffix : str, optional
        The suffix used to identify mask files corresponding to the images. Default is 'labelled'.

    Returns
    -------
    tuple of lists
        A tuple containing two lists: `X` for images and `Y` for corresponding masks. Both lists contain
        numpy arrays of loaded and optionally transformed images and masks.

    Raises
    ------
    AssertionError
        If the provided `channels` argument is not a list or if the number of loaded images does not match
        the number of loaded masks.

    Notes
    -----
    - Mask filenames are derived from image filenames by appending `mask_suffix` before the file extension.
    - Only images with a corresponding mask and a valid configuration file specifying channel indices and
      spatial calibration are loaded.
    """

    if isinstance(channels, str):
        channels = [channels]

    assert isinstance(channels, list),'Please provide a list of channels. Abort.'

    X = []; Y = [];

    for ds in datasets:
        print(f'Loading data from dataset {ds}...')
        if not ds.endswith(os.sep):
            ds+=os.sep
        # Candidate images are all .tif files except those carrying the mask suffix.
        img_paths = list(set(glob(ds+'*.tif')) - set(glob(ds+f'*_{mask_suffix}.tif')))
        for im in img_paths:
            print(f'{im=}')
            # Mask lives next to the image, same stem plus the mask suffix.
            mask_path = os.sep.join([os.path.split(im)[0],os.path.split(im)[-1].replace('.tif', f'_{mask_suffix}.tif')])
            if os.path.exists(mask_path):
                # load image and mask
                image = imread(im)
                if image.ndim==2:
                    # Promote a single-channel YX image to CYX with one channel.
                    image = image[np.newaxis]
                if image.ndim>3:
                    print('Invalid image shape, skipping')
                    continue
                mask = imread(mask_path)
                # Per-image JSON config holds the channel names and calibration.
                config_path = im.replace('.tif','.json')
                if os.path.exists(config_path):
                    # Load config
                    with open(config_path, 'r') as f:
                        config = json.load(f)
                    try:
                        ch_idx = []
                        for c in channels:
                            if c!='None':
                                idx = config['channels'].index(c)
                                ch_idx.append(idx)
                            else:
                                # Placeholder channel: marked with NaN, blanked below.
                                ch_idx.append(np.nan)
                        im_calib = config['spatial_calibration']
                    except Exception as e:
                        print(e,' channels and/or spatial calibration could not be found in the config... Skipping image.')
                        continue

                    ch_idx = np.array(ch_idx)
                    # NaN != NaN: replace placeholder slots with index 0 so the fancy
                    # indexing below stays valid, then zero those planes out after.
                    ch_idx_safe = np.copy(ch_idx)
                    ch_idx_safe[ch_idx_safe!=ch_idx_safe] = 0
                    ch_idx_safe = ch_idx_safe.astype(int)
                    print(ch_idx_safe)
                    image = image[ch_idx_safe]
                    image[np.where(ch_idx!=ch_idx)[0],:,:] = 0

                    # CYX -> YXC (channel last), as expected downstream.
                    image = np.moveaxis(image,0,-1)
                    assert image.ndim==3,'The image has a wrong number of dimensions. Abort.'

                    if im_calib != train_spatial_calibration:
                        # Rescale to match the calibration used at training time:
                        # cubic interpolation for the image, nearest for the labels.
                        factor = im_calib / train_spatial_calibration
                        print(f'{im_calib=}, {train_spatial_calibration=}, {factor=}')
                        image = zoom(image, [factor,factor,1], order=3)
                        mask = zoom(mask, [factor,factor], order=0)

                    X.append(image)
                    Y.append(mask)

    assert len(X)==len(Y),'The number of images does not match with the number of masks... Abort.'
    return X,Y
1573
+
1574
+
1575
def download_url_to_file(url, dst, progress=True):
    r"""Download object at the given URL to a local path.

    Thanks to torch, slightly modified, from Cellpose.

    Args:
        url (string): URL of the object to download
        dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
        progress (bool, optional): whether or not to display a progress bar to stderr
            Default: True
    """
    file_size = None
    # NOTE(review): this disables HTTPS certificate verification for the
    # whole process (not just this request) — confirm this is intentional.
    import ssl
    ssl._create_default_https_context = ssl._create_unverified_context
    u = urlopen(url)
    meta = u.info()
    # 'getheaders' is the legacy (Python 2 HTTPMessage) accessor; 'get_all'
    # is the Python 3 email.message API. Both return a list of header values.
    if hasattr(meta, 'getheaders'):
        content_length = meta.getheaders("Content-Length")
    else:
        content_length = meta.get_all("Content-Length")
    if content_length is not None and len(content_length) > 0:
        file_size = int(content_length[0])
    # We deliberately save it in a temp file and move it after: the temp file
    # is created in the destination directory so the final move stays on the
    # same filesystem.
    dst = os.path.expanduser(dst)
    dst_dir = os.path.dirname(dst)
    f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
    try:
        # file_size may be None if the server sent no Content-Length; tqdm
        # then shows an unbounded progress bar.
        with tqdm(total=file_size, disable=not progress,
                  unit='B', unit_scale=True, unit_divisor=1024) as pbar:
            while True:
                buffer = u.read(8192)  # stream in 8 KiB chunks
                if len(buffer) == 0:
                    break
                f.write(buffer)
                pbar.update(len(buffer))
        f.close()
        shutil.move(f.name, dst)
    finally:
        # Closing twice is a harmless no-op on the happy path; the exists()
        # check removes the temp file only when the move above did not run
        # (i.e. the download failed part-way).
        f.close()
        if os.path.exists(f.name):
            os.remove(f.name)
1614
+
1615
def get_zenodo_files(cat=None):
    """List the Zenodo assets declared in the packaged ``links/zenodo.json``.

    Parameters
    ----------
    cat : str, optional
        A category path such as ``models/segmentation_generic`` (built with
        ``os.sep``). When given, only asset names of that category are
        returned.

    Returns
    -------
    list or tuple of lists
        With ``cat``: the matching asset names. Without: two parallel lists,
        all asset names (``.zip`` stripped) and their categories.
    """
    package_root = os.path.split(os.path.dirname(os.path.realpath(__file__)))[0]
    registry_path = os.sep.join([package_root, "celldetective", "links", "zenodo.json"])
    with open(registry_path, "r") as handle:
        registry = json.load(handle)

    names = [entry.replace(".zip", "") for entry in registry['files']['entries'].keys()]

    # Ordered prefix -> category dispatch table; 'db-si' must be tested
    # before 'db', so order matters.
    prefix_rules = [
        (('CP', 'SD'), os.sep.join(['models', 'segmentation_generic'])),
        (('MCF7',), os.sep.join(['models', 'segmentation_targets'])),
        (('primNK',), os.sep.join(['models', 'segmentation_effectors'])),
        (('demo',), 'demos'),
        (('db-si',), os.sep.join(['datasets', 'signal_annotations'])),
        (('db',), os.sep.join(['datasets', 'segmentation_annotations'])),
    ]
    fallback_category = os.sep.join(['models', 'signal_detection'])

    categories = []
    for name in names:
        for prefixes, category in prefix_rules:
            if name.startswith(prefixes):
                categories.append(category)
                break
        else:
            categories.append(fallback_category)

    if cat is None:
        return names, categories

    known_categories = [rule[1] for rule in prefix_rules] + [fallback_category]
    assert cat in known_categories
    name_arr = np.array(names)
    cat_arr = np.array(categories)
    return list(name_arr[cat_arr == cat])
1650
+
1651
def download_zenodo_file(file, output_dir):
    """Download a named Zenodo asset and unpack it into ``output_dir``.

    Parameters
    ----------
    file : str
        Asset name without the ``.zip`` extension, as listed in the packaged
        ``links/zenodo.json`` registry.
    output_dir : str
        Directory where the archive is extracted. A temporary ``temp.zip``
        is written there and removed after extraction.

    Raises
    ------
    ValueError
        If ``file`` is not listed in the registry (from ``list.index``).
    """
    # Resolve the packaged registry of Zenodo assets.
    zenodo_json = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0],"celldetective", "links", "zenodo.json"])
    with open(zenodo_json,"r") as f:
        zenodo_json = json.load(f)
    all_files = list(zenodo_json['files']['entries'].keys())
    all_files_short = [f.replace(".zip","") for f in all_files]
    # The registry stores the API URL; dropping 'api/' yields the direct
    # download base for each archive.
    zenodo_url = zenodo_json['links']['files'].replace('api/','')
    full_links = ["/".join([zenodo_url, f]) for f in all_files]
    index = all_files_short.index(file)  # raises ValueError for unknown assets
    zip_url = full_links[index]

    path_to_zip_file = os.sep.join([output_dir, 'temp.zip'])
    download_url_to_file(fr"{zip_url}",path_to_zip_file)
    with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
        zip_ref.extractall(output_dir)

    # Rename the extracted payload to match the asset name.
    # NOTE(review): each '[!.json]'-style bracket is a single-character glob
    # class (excludes one trailing char), NOT an extension filter — this
    # pattern does not reliably skip files by extension; confirm intent.
    file_to_rename = glob(os.sep.join([output_dir,file,"*[!.json][!.png][!.h5][!.csv][!.npy][!.tif][!.ini]"]))
    if len(file_to_rename)>0 and not file_to_rename[0].endswith(os.sep) and not file.startswith('demo'):
        os.rename(file_to_rename[0], os.sep.join([output_dir,file,file]))

    # Some archives extract under a legacy folder name; rename them to the
    # registry name so downstream look-ups find them.
    if file=="db_mcf7_nuclei_w_primary_NK":
        os.rename(os.sep.join([output_dir,file.replace('db_','')]), os.sep.join([output_dir,file]))
    if file=="db_primary_NK_w_mcf7":
        os.rename(os.sep.join([output_dir,file.replace('db_','')]), os.sep.join([output_dir,file]))
    if file=='db-si-NucPI':
        os.rename(os.sep.join([output_dir,'db2-NucPI']), os.sep.join([output_dir,file]))
    if file=='db-si-NucCondensation':
        os.rename(os.sep.join([output_dir,'db1-NucCondensation']), os.sep.join([output_dir,file]))

    os.remove(path_to_zip_file)