celldetective 1.0.2.post1__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. celldetective/__main__.py +7 -21
  2. celldetective/events.py +2 -44
  3. celldetective/extra_properties.py +62 -52
  4. celldetective/filters.py +4 -5
  5. celldetective/gui/__init__.py +1 -1
  6. celldetective/gui/analyze_block.py +37 -10
  7. celldetective/gui/btrack_options.py +24 -23
  8. celldetective/gui/classifier_widget.py +62 -19
  9. celldetective/gui/configure_new_exp.py +32 -35
  10. celldetective/gui/control_panel.py +120 -81
  11. celldetective/gui/gui_utils.py +674 -396
  12. celldetective/gui/json_readers.py +7 -6
  13. celldetective/gui/layouts.py +756 -0
  14. celldetective/gui/measurement_options.py +98 -513
  15. celldetective/gui/neighborhood_options.py +322 -270
  16. celldetective/gui/plot_measurements.py +1114 -0
  17. celldetective/gui/plot_signals_ui.py +21 -20
  18. celldetective/gui/process_block.py +449 -169
  19. celldetective/gui/retrain_segmentation_model_options.py +27 -26
  20. celldetective/gui/retrain_signal_model_options.py +25 -24
  21. celldetective/gui/seg_model_loader.py +31 -27
  22. celldetective/gui/signal_annotator.py +2326 -2295
  23. celldetective/gui/signal_annotator_options.py +18 -16
  24. celldetective/gui/styles.py +16 -1
  25. celldetective/gui/survival_ui.py +67 -39
  26. celldetective/gui/tableUI.py +337 -48
  27. celldetective/gui/thresholds_gui.py +75 -71
  28. celldetective/gui/viewers.py +743 -0
  29. celldetective/io.py +247 -27
  30. celldetective/measure.py +43 -263
  31. celldetective/models/segmentation_effectors/primNK_cfse/config_input.json +29 -0
  32. celldetective/models/segmentation_effectors/primNK_cfse/cp-cfse-transfer +0 -0
  33. celldetective/models/segmentation_effectors/primNK_cfse/training_instructions.json +37 -0
  34. celldetective/neighborhood.py +498 -27
  35. celldetective/preprocessing.py +1023 -0
  36. celldetective/scripts/analyze_signals.py +7 -0
  37. celldetective/scripts/measure_cells.py +12 -0
  38. celldetective/scripts/segment_cells.py +20 -4
  39. celldetective/scripts/track_cells.py +11 -0
  40. celldetective/scripts/train_segmentation_model.py +35 -34
  41. celldetective/segmentation.py +14 -9
  42. celldetective/signals.py +234 -329
  43. celldetective/tracking.py +2 -2
  44. celldetective/utils.py +602 -49
  45. celldetective-1.1.1.dist-info/METADATA +305 -0
  46. celldetective-1.1.1.dist-info/RECORD +84 -0
  47. {celldetective-1.0.2.post1.dist-info → celldetective-1.1.1.dist-info}/top_level.txt +1 -0
  48. tests/__init__.py +0 -0
  49. tests/test_events.py +28 -0
  50. tests/test_filters.py +24 -0
  51. tests/test_io.py +70 -0
  52. tests/test_measure.py +141 -0
  53. tests/test_neighborhood.py +70 -0
  54. tests/test_preprocessing.py +37 -0
  55. tests/test_segmentation.py +93 -0
  56. tests/test_signals.py +135 -0
  57. tests/test_tracking.py +164 -0
  58. tests/test_utils.py +118 -0
  59. celldetective-1.0.2.post1.dist-info/METADATA +0 -221
  60. celldetective-1.0.2.post1.dist-info/RECORD +0 -66
  61. {celldetective-1.0.2.post1.dist-info → celldetective-1.1.1.dist-info}/LICENSE +0 -0
  62. {celldetective-1.0.2.post1.dist-info → celldetective-1.1.1.dist-info}/WHEEL +0 -0
  63. {celldetective-1.0.2.post1.dist-info → celldetective-1.1.1.dist-info}/entry_points.txt +0 -0
celldetective/utils.py CHANGED
@@ -22,6 +22,413 @@ import zipfile
22
22
  from tqdm import tqdm
23
23
  import shutil
24
24
  import tempfile
25
+ from scipy.interpolate import griddata
26
+
27
+
28
def derivative(x, timeline, window, mode='bi'):

    """
    Compute the numerical derivative of ``x`` with respect to ``timeline``,
    using a finite difference spanning ``window`` time intervals.

    Parameters
    ----------
    x : array_like
        The input array of values.
    timeline : array_like
        The array of time points corresponding to the input values
        (same length as ``x``).
    window : int
        The number of time intervals spanned by each finite difference.
        Must be a positive odd integer when ``mode`` is 'bi'.
    mode : {'bi', 'forward', 'backward'}, optional
        The numerical differentiation method to be used:
        - 'bi' (default): bidirectional difference over a symmetric window.
        - 'forward': forward difference using a one-sided window.
        - 'backward': backward difference using a one-sided window.

    Returns
    -------
    dxdt : ndarray
        Array of the same length as ``x``. Positions where the finite
        difference cannot be evaluated (near the edges) are NaN.

    Raises
    ------
    AssertionError
        If the window size is not odd and mode is 'bi'.
    ValueError
        If ``mode`` is not one of 'bi', 'forward', 'backward'.

    Notes
    -----
    - For 'bi' mode, the difference reaches ``window//2`` points backward
      and ``window//2 + 1`` points forward, so it spans exactly ``window``
      time intervals, like the one-sided modes.
    - Edge points that fall outside the admissible index range remain NaN.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.array([1, 2, 4, 7, 11])
    >>> timeline = np.array([0, 1, 2, 3, 4])
    >>> derivative(x, timeline, 3, mode='bi')
    array([nan,  2.,  3., nan, nan])

    >>> derivative(x, timeline, 3, mode='forward')
    array([ 2.,  3., nan, nan, nan])

    >>> derivative(x, timeline, 3, mode='backward')
    array([nan, nan, nan,  2.,  3.])
    """

    dxdt = np.full(len(x), np.nan)

    if mode == 'bi':
        assert window % 2 == 1, 'Please set an odd window for the bidirectional mode'
        lower_bound = window // 2
        upper_bound = len(x) - window // 2 - 1
    elif mode == 'forward':
        lower_bound = 0
        upper_bound = len(x) - window
    elif mode == 'backward':
        lower_bound = window
        upper_bound = len(x)
    else:
        # previously an unknown mode crashed later with a NameError on the
        # undefined bounds; fail early with an explicit message instead
        raise ValueError("mode must be one of 'bi', 'forward' or 'backward'")

    for t in range(lower_bound, upper_bound):
        if mode == 'bi':
            dxdt[t] = (x[t + window // 2 + 1] - x[t - window // 2]) / (timeline[t + window // 2 + 1] - timeline[t - window // 2])
        elif mode == 'forward':
            dxdt[t] = (x[t + window] - x[t]) / (timeline[t + window] - timeline[t])
        elif mode == 'backward':
            dxdt[t] = (x[t] - x[t - window]) / (timeline[t] - timeline[t - window])
    return dxdt
102
+
103
def differentiate_per_track(tracks, measurement, window_size=3, mode='bi'):

    """
    Differentiate a measurement over time, independently for each track.

    The table is sorted by track (and by position, when a 'position' column
    exists) and by frame, then the derivative of ``measurement`` is computed
    track by track and stored in a new column named ``'d/dt.' + measurement``.

    Parameters
    ----------
    tracks : pandas.DataFrame
        Trajectory table containing at least 'TRACK_ID', 'FRAME' and the
        ``measurement`` column.
    measurement : str
        Name of the column to differentiate.
    window_size : int, optional
        Window passed to ``derivative``. Default is 3.
    mode : {'bi', 'forward', 'backward'}, optional
        Differentiation mode passed to ``derivative``. Default is 'bi'.

    Returns
    -------
    pandas.DataFrame
        The sorted table with the derivative column added.
    """

    group_cols = ['TRACK_ID']
    if 'position' in tracks.columns:
        group_cols.insert(0, 'position')

    tracks = tracks.sort_values(by=group_cols + ['FRAME'], ignore_index=True)
    tracks = tracks.reset_index(drop=True)

    target_col = 'd/dt.' + measurement
    for _, trajectory in tracks.groupby(group_cols):
        frames = trajectory['FRAME'].values
        values = trajectory[measurement].values
        tracks.loc[trajectory.index, target_col] = derivative(values, frames, window_size, mode=mode)
    return tracks
118
+
119
def velocity_per_track(tracks, window_size=3, mode='bi'):

    """
    Compute the instantaneous speed of each track and store it in a
    'velocity' column.

    The table is sorted by track (and by position, when a 'position' column
    exists) and by frame; for each track, the velocity vector is derived
    from 'POSITION_X'/'POSITION_Y' and its magnitude is written back.

    Parameters
    ----------
    tracks : pandas.DataFrame
        Trajectory table containing at least 'TRACK_ID', 'FRAME',
        'POSITION_X' and 'POSITION_Y'.
    window_size : int, optional
        Window passed to ``velocity``. Default is 3.
    mode : {'bi', 'forward', 'backward'}, optional
        Differentiation mode passed to ``velocity``. Default is 'bi'.

    Returns
    -------
    pandas.DataFrame
        The sorted table with the 'velocity' column added.
    """

    group_cols = ['TRACK_ID']
    if 'position' in tracks.columns:
        group_cols.insert(0, 'position')

    tracks = tracks.sort_values(by=group_cols + ['FRAME'], ignore_index=True)
    tracks = tracks.reset_index(drop=True)

    for _, trajectory in tracks.groupby(group_cols):
        frames = trajectory['FRAME'].values
        v_vec = velocity(trajectory['POSITION_X'].values,
                         trajectory['POSITION_Y'].values,
                         frames, window=window_size, mode=mode)
        tracks.loc[trajectory.index, 'velocity'] = magnitude_velocity(v_vec)
    return tracks
136
+
137
def velocity(x, y, timeline, window, mode='bi'):

    """
    Compute the velocity vector of a 2D trajectory by differentiating its
    x and y coordinates with respect to time.

    Parameters
    ----------
    x : array_like
        The array of x-coordinates of the trajectory.
    y : array_like
        The array of y-coordinates of the trajectory.
    timeline : array_like
        The array of time points corresponding to the x and y coordinates.
    window : int
        The window used for numerical differentiation (see ``derivative``).
        Must be a positive odd integer when ``mode`` is 'bi'.
    mode : {'bi', 'forward', 'backward'}, optional
        The numerical differentiation method to be used:
        - 'bi' (default): bidirectional difference over a symmetric window.
        - 'forward': forward difference using a one-sided window.
        - 'backward': backward difference using a one-sided window.

    Returns
    -------
    v : ndarray of shape (len(x), 2)
        Column 0 holds dx/dt and column 1 holds dy/dt. Rows where the
        finite difference cannot be evaluated (near the trajectory edges)
        are NaN.

    Raises
    ------
    AssertionError
        If the window size is not odd and mode is 'bi'.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.array([1, 2, 4, 7, 11])
    >>> y = np.array([0, 3, 6, 9, 12])
    >>> timeline = np.array([0, 1, 2, 3, 4])
    >>> velocity(x, y, timeline, 3, mode='bi')
    array([[nan, nan],
           [ 2.,  3.],
           [ 3.,  3.],
           [nan, nan],
           [nan, nan]])
    """

    # one derivative per coordinate; the NaN edge padding comes from `derivative`
    v = np.full((len(x), 2), np.nan)
    v[:, 0] = derivative(x, timeline, window, mode=mode)
    v[:, 1] = derivative(y, timeline, window, mode=mode)

    return v
203
+
204
def magnitude_velocity(v_matrix):

    """
    Compute the magnitude of 2D velocity vectors.

    Parameters
    ----------
    v_matrix : array_like
        Matrix of shape (N, 2) where each row is a velocity vector, the
        first column being the x-component and the second the y-component.

    Returns
    -------
    magnitude : ndarray
        The magnitudes of the input velocity vectors, length N.

    Notes
    -----
    - If either component of a vector is NaN, the corresponding magnitude
      is NaN.
    - The computation is vectorized with ``np.hypot``, which propagates
      NaN exactly like the element-wise formula sqrt(vx**2 + vy**2).

    Examples
    --------
    >>> import numpy as np
    >>> v_matrix = np.array([[3, 4],
    ...                      [2, 2],
    ...                      [3, 3]])
    >>> magnitude_velocity(v_matrix)
    array([5., 2.82842712, 4.24264069])

    >>> v_matrix_with_nan = np.array([[3, 4],
    ...                               [np.nan, 2],
    ...                               [3, np.nan]])
    >>> magnitude_velocity(v_matrix_with_nan)
    array([5., nan, nan])
    """

    v_matrix = np.asarray(v_matrix, dtype=float)
    if len(v_matrix) == 0:
        # preserve the original behavior of returning an empty float array
        return np.full(0, np.nan)
    return np.hypot(v_matrix[:, 0], v_matrix[:, 1])
247
+
248
def orientation(v_matrix):

    """
    Compute the orientation angle (in radians) of 2D velocity vectors.

    The angle follows the standard convention ``theta = arctan2(vy, vx)``,
    i.e. it is measured counter-clockwise from the positive x-axis.

    Parameters
    ----------
    v_matrix : array_like
        Matrix of shape (N, 2) where each row is a velocity vector, the
        first column being the x-component and the second the y-component.

    Returns
    -------
    orientation_array : ndarray
        The orientation angles of the input velocity vectors, in radians.
        If a velocity vector has a NaN component, the corresponding angle
        is NaN.

    Examples
    --------
    >>> import numpy as np
    >>> v_matrix = np.array([[3, 4],
    ...                      [2, 2],
    ...                      [-3, -3]])
    >>> orientation(v_matrix)
    array([0.92729522, 0.78539816, -2.35619449])

    >>> v_matrix_with_nan = np.array([[3, 4],
    ...                               [np.nan, 2],
    ...                               [3, np.nan]])
    >>> orientation(v_matrix_with_nan)
    array([0.92729522, nan, nan])
    """

    # initialize with NaN (not zeros) so vectors with a NaN x-component
    # yield NaN, as documented, instead of silently reporting angle 0
    orientation_array = np.full(len(v_matrix), np.nan)
    for t in range(len(orientation_array)):
        # NaN != NaN: skip rows whose x-component is NaN
        if v_matrix[t, 0] == v_matrix[t, 0]:
            # arctan2 takes (y, x); the previous argument order (vx, vy)
            # contradicted the documented convention and examples
            orientation_array[t] = np.arctan2(v_matrix[t, 1], v_matrix[t, 0])
    return orientation_array
286
+
287
+
288
def estimate_unreliable_edge(activation_protocol=[['gauss', 2], ['std', 4]]):

    """
    Estimate the distance from the image border within which filtered
    values may be artefactual.

    Parameters
    ----------
    activation_protocol : list of list, optional
        Each sublist names a filter function followed by its arguments
        (usually a kernel size). Default is [['gauss', 2], ['std', 4]].
        The default list is never mutated by this function.

    Returns
    -------
    int or None
        The sum of the integer kernel sizes found in the protocol, or None
        when the protocol is an empty list.

    Notes
    -----
    Only the second element of each sublist is considered, and only when
    it is an integer; non-integer arguments are ignored.

    Examples
    --------
    >>> estimate_unreliable_edge([['gauss', 2], ['std', 4]])
    6
    >>> estimate_unreliable_edge([])
    None
    """

    if activation_protocol == []:
        return None
    # accumulate the integer kernel sizes, skipping non-integer arguments
    return sum(step[1] for step in activation_protocol if isinstance(step[1], (int, np.int_)))
326
+
327
def unpad(img, pad):

    """
    Remove padding from an image.

    This function removes the specified amount of padding from the borders
    of an image. The padding is assumed to be the same on all sides.

    Parameters
    ----------
    img : ndarray
        The input 2D image from which the padding will be removed.
    pad : int
        The amount of padding to remove from each side of the image.

    Returns
    -------
    ndarray
        The image with the padding removed. When ``pad`` is 0 the input
        image is returned unchanged.

    See Also
    --------
    numpy.pad : Pads an array.

    Notes
    -----
    This function assumes that the input image is a 2D array.

    Examples
    --------
    >>> import numpy as np
    >>> img = np.array([[0, 0, 0, 0, 0],
    ...                 [0, 1, 1, 1, 0],
    ...                 [0, 1, 1, 1, 0],
    ...                 [0, 1, 1, 1, 0],
    ...                 [0, 0, 0, 0, 0]])
    >>> unpad(img, 1)
    array([[1, 1, 1],
           [1, 1, 1],
           [1, 1, 1]])
    """

    if pad == 0:
        # guard: img[0:-0] would be an EMPTY slice, not the full image
        return img
    return img[pad:-pad, pad:-pad]
376
+
377
def mask_edges(binary_mask, border_size):

    """
    Mask the edges of a binary mask.

    Sets a border of width ``border_size`` to False on all four sides of
    the mask, leaving only the interior region True where it was truthy.

    Parameters
    ----------
    binary_mask : ndarray
        A 2D binary mask array whose edges will be masked.
    border_size : int
        The width of the border to set to False on all sides.

    Returns
    -------
    ndarray
        A boolean copy of the mask with the edges masked out; the input
        array is left untouched (``astype(bool)`` returns a copy).

    Examples
    --------
    >>> import numpy as np
    >>> binary_mask = np.ones((5, 5), dtype=int)
    >>> mask_edges(binary_mask, 1)
    array([[False, False, False, False, False],
           [False,  True,  True,  True, False],
           [False,  True,  True,  True, False],
           [False,  True,  True,  True, False],
           [False, False, False, False, False]])
    """

    masked = binary_mask.astype(bool)
    n_rows, n_cols = masked.shape[0], masked.shape[1]

    # zero out the four border bands
    masked[:border_size, :] = False
    masked[n_rows - border_size:, :] = False
    masked[:, :border_size] = False
    masked[:, n_cols - border_size:] = False

    return masked
431
+
25
432
 
26
433
  def create_patch_mask(h, w, center=None, radius=None):
27
434
 
@@ -170,8 +577,6 @@ def rename_intensity_column(df, channels):
170
577
  new_name = '_'.join(list(measure))
171
578
  new_name = new_name.replace('radial_gradient', "radial_intercept")
172
579
  to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
173
-
174
-
175
580
  else:
176
581
  to_rename = {}
177
582
  for k in range(len(intensity_columns)):
@@ -367,7 +772,7 @@ def compute_weights(y):
367
772
 
368
773
  return class_weights
369
774
 
370
- def train_test_split(data_x, data_y1, data_y2=None, validation_size=0.25, test_size=0):
775
+ def train_test_split(data_x, data_y1, data_class=None, validation_size=0.25, test_size=0, n_iterations=10):
371
776
 
372
777
  """
373
778
 
@@ -407,37 +812,56 @@ def train_test_split(data_x, data_y1, data_y2=None, validation_size=0.25, test_s
407
812
 
408
813
  """
409
814
 
410
- n_values = len(data_x)
411
- randomize = np.arange(n_values)
412
- np.random.shuffle(randomize)
815
+ if data_class is not None:
816
+ print(f"Unique classes: {np.sort(np.argmax(np.unique(data_class,axis=0),axis=1))}")
817
+
818
+ for i in range(n_iterations):
413
819
 
414
- train_percentage = 1- validation_size - test_size
415
- chunks = split_by_ratio(randomize, train_percentage, validation_size, test_size)
820
+ n_values = len(data_x)
821
+ randomize = np.arange(n_values)
822
+ np.random.shuffle(randomize)
416
823
 
417
- x_train = data_x[chunks[0]]
418
- y1_train = data_y1[chunks[0]]
419
- if data_y2 is not None:
420
- y2_train = data_y2[chunks[0]]
824
+ train_percentage = 1 - validation_size - test_size
421
825
 
826
+ chunks = split_by_ratio(randomize, train_percentage, validation_size, test_size)
422
827
 
423
- x_val = data_x[chunks[1]]
424
- y1_val = data_y1[chunks[1]]
425
- if data_y2 is not None:
426
- y2_val = data_y2[chunks[1]]
828
+ x_train = data_x[chunks[0]]
829
+ y1_train = data_y1[chunks[0]]
830
+ if data_class is not None:
831
+ y2_train = data_class[chunks[0]]
427
832
 
428
- ds = {"x_train": x_train, "x_val": x_val,
429
- "y1_train": y1_train, "y1_val": y1_val}
430
- if data_y2 is not None:
431
- ds.update({"y2_train": y2_train, "y2_val": y2_val})
833
+ x_val = data_x[chunks[1]]
834
+ y1_val = data_y1[chunks[1]]
835
+ if data_class is not None:
836
+ y2_val = data_class[chunks[1]]
837
+
838
+ if data_class is not None:
839
+ print(f"classes in train set: {np.sort(np.argmax(np.unique(y2_train,axis=0),axis=1))}; classes in validation set: {np.sort(np.argmax(np.unique(y2_val,axis=0),axis=1))}")
840
+ same_class_test = np.array_equal(np.sort(np.argmax(np.unique(y2_train,axis=0),axis=1)), np.sort(np.argmax(np.unique(y2_val,axis=0),axis=1)))
841
+ print(f"Check that classes are found in all sets: {same_class_test}...")
842
+ else:
843
+ same_class_test = True
844
+
845
+ if same_class_test:
846
+
847
+ ds = {"x_train": x_train, "x_val": x_val,
848
+ "y1_train": y1_train, "y1_val": y1_val}
849
+ if data_class is not None:
850
+ ds.update({"y2_train": y2_train, "y2_val": y2_val})
851
+
852
+ if test_size>0:
853
+ x_test = data_x[chunks[2]]
854
+ y1_test = data_y1[chunks[2]]
855
+ ds.update({"x_test": x_test, "y1_test": y1_test})
856
+ if data_class is not None:
857
+ y2_test = data_class[chunks[2]]
858
+ ds.update({"y2_test": y2_test})
859
+ return ds
860
+ else:
861
+ continue
862
+
863
+ raise Exception("Some classes are missing from the train or validation set... Abort.")
432
864
 
433
- if test_size>0:
434
- x_test = data_x[chunks[2]]
435
- y1_test = data_y1[chunks[2]]
436
- ds.update({"x_test": x_test, "y1_test": y1_test})
437
- if data_y2 is not None:
438
- y2_test = data_y2[chunks[2]]
439
- ds.update({"y2_test": y2_test})
440
- return ds
441
865
 
442
866
  def remove_redundant_features(features, reference_features, channel_names=None):
443
867
 
@@ -482,7 +906,7 @@ def remove_redundant_features(features, reference_features, channel_names=None):
482
906
 
483
907
  """
484
908
 
485
- new_features = features.copy()
909
+ new_features = features[:]
486
910
 
487
911
  for f in features:
488
912
 
@@ -1144,7 +1568,7 @@ def remove_trajectory_measurements(trajectories, column_labels):
1144
1568
  tracks = trajectories.copy()
1145
1569
 
1146
1570
  columns_to_keep = [column_labels['track'], column_labels['time'], column_labels['x'], column_labels['y'],column_labels['x']+'_um', column_labels['y']+'_um', 'class_id',
1147
- 't', 'state', 'generation', 'root', 'parent', 'ID', 't0', 'class', 'status', 'class_color', 'status_color', 'class_firstdetection', 't_firstdetection']
1571
+ 't', 'state', 'generation', 'root', 'parent', 'ID', 't0', 'class', 'status', 'class_color', 'status_color', 'class_firstdetection', 't_firstdetection', 'velocity']
1148
1572
  cols = tracks.columns
1149
1573
  for c in columns_to_keep:
1150
1574
  if c not in cols:
@@ -1201,11 +1625,32 @@ def color_from_class(cclass, recently_modified=False):
1201
1625
  def random_fliprot(img, mask):
1202
1626
 
1203
1627
  """
1628
+ Randomly flips and rotates an image and its corresponding mask.
1629
+
1630
+ This function applies a series of random flips and permutations (rotations) to both the input image and its
1631
+ associated mask, ensuring that any transformations applied to the image are also exactly applied to the mask.
1632
+ The function is designed to handle multi-dimensional images (e.g., multi-channel images in YXC format where
1633
+ channels are last).
1634
+
1635
+ Parameters
1636
+ ----------
1637
+ img : ndarray
1638
+ The input image to be transformed. This array is expected to have dimensions where the channel axis is last.
1639
+ mask : ndarray
1640
+ The mask corresponding to `img`, to be transformed in the same way as the image.
1641
+
1642
+ Returns
1643
+ -------
1644
+ tuple of ndarray
1645
+ A tuple containing the transformed image and mask.
1204
1646
 
1205
- Perform random flipping of the image and the associated mask.
1206
- Needs YXC (channel last).
1647
+ Raises
1648
+ ------
1649
+ AssertionError
1650
+ If the number of dimensions of the mask exceeds that of the image, indicating incompatible shapes.
1207
1651
 
1208
1652
  """
1653
+
1209
1654
  assert img.ndim >= mask.ndim
1210
1655
  axes = tuple(range(mask.ndim))
1211
1656
  perm = tuple(np.random.permutation(axes))
@@ -1225,12 +1670,37 @@ def random_fliprot(img, mask):
1225
1670
  def random_shift(image,mask, max_shift_amplitude=0.1):
1226
1671
 
1227
1672
  """
1673
+ Randomly shifts an image and its corresponding mask along the X and Y axes.
1228
1674
 
1229
- Perform random shift of the image in X and or Y.
1230
- Needs YXC (channel last).
1675
+ This function shifts both the image and the mask by a randomly chosen distance up to a maximum
1676
+ percentage of the image's dimensions, specified by `max_shift_amplitude`. The shifts are applied
1677
+ independently in both the X and Y directions. This type of augmentation can help improve the robustness
1678
+ of models to positional variations in images.
1679
+
1680
+ Parameters
1681
+ ----------
1682
+ image : ndarray
1683
+ The input image to be shifted. Must be in YXC format (height, width, channels).
1684
+ mask : ndarray
1685
+ The mask corresponding to `image`, to be shifted in the same way as the image.
1686
+ max_shift_amplitude : float, optional
1687
+ The maximum shift as a fraction of the image's dimension. Default is 0.1 (10% of the image's size).
1688
+
1689
+ Returns
1690
+ -------
1691
+ tuple of ndarray
1692
+ A tuple containing the shifted image and mask.
1693
+
1694
+ Notes
1695
+ -----
1696
+ - The shift values are chosen randomly within the range defined by the maximum amplitude.
1697
+ - Shifting is performed using the 'constant' mode where missing values are filled with zeros (cval=0.0),
1698
+ which may introduce areas of zero-padding along the edges of the shifted images and masks.
1699
+ - This function is designed to support data augmentation for machine learning and image processing tasks,
1700
+ particularly in contexts where spatial invariance is beneficial.
1701
+
1702
+ """
1231
1703
 
1232
- """
1233
-
1234
1704
  input_shape = image.shape[0]
1235
1705
  max_shift = input_shape*max_shift_amplitude
1236
1706
 
@@ -1249,9 +1719,35 @@ def random_shift(image,mask, max_shift_amplitude=0.1):
1249
1719
 
1250
1720
 
1251
1721
  def blur(x,max_sigma=4.0):
1722
+
1252
1723
  """
1253
- Random image blur
1724
+ Applies a random Gaussian blur to an image.
1725
+
1726
+ This function blurs an image by applying a Gaussian filter with a randomly chosen sigma value. The sigma
1727
+ represents the standard deviation for the Gaussian kernel and is selected randomly up to a specified maximum.
1728
+ The blurring is applied while preserving the range of the image's intensity values and maintaining any
1729
+ zero-valued pixels as they are.
1730
+
1731
+ Parameters
1732
+ ----------
1733
+ x : ndarray
1734
+ The input image to be blurred. The image can have any number of channels, but must be in a format
1735
+ where the channels are the last dimension (YXC format).
1736
+ max_sigma : float, optional
1737
+ The maximum value for the standard deviation of the Gaussian blur. Default is 4.0.
1738
+
1739
+ Returns
1740
+ -------
1741
+ ndarray
1742
+ The blurred image. The output will have the same shape and type as the input image.
1743
+
1744
+ Notes
1745
+ -----
1746
+ - The function ensures that zero-valued pixels in the input image remain unchanged after the blurring,
1747
+ which can be important for maintaining masks or other specific regions within the image.
1748
+ - Gaussian blurring is commonly used in image processing to reduce image noise and detail by smoothing.
1254
1749
  """
1750
+
1255
1751
  sigma = np.random.random()*max_sigma
1256
1752
  loc_i,loc_j,loc_c = np.where(x==0.)
1257
1753
  x = gaussian(x, sigma, channel_axis=-1, preserve_range=True)
@@ -1262,8 +1758,44 @@ def blur(x,max_sigma=4.0):
1262
1758
  def noise(x, apply_probability=0.5, clip_option=False):
1263
1759
 
1264
1760
  """
1265
- Apply random noise to a multichannel image
1761
+ Applies random noise to each channel of a multichannel image based on a specified probability.
1762
+
1763
+ This function introduces various types of random noise to an image. Each channel of the image can be
1764
+ modified independently with different noise models chosen randomly from a predefined list. The application
1765
+ of noise to any given channel is determined by a specified probability, allowing for selective noise
1766
+ addition.
1767
+
1768
+ Parameters
1769
+ ----------
1770
+ x : ndarray
1771
+ The input multichannel image to which noise will be added. The image should be in format with channels
1772
+ as the last dimension (e.g., height x width x channels).
1773
+ apply_probability : float, optional
1774
+ The probability with which noise is applied to each channel of the image. Default is 0.5.
1775
+ clip_option : bool, optional
1776
+ Specifies whether to clip the corrupted data to stay within the valid range after noise addition.
1777
+ If True, the output array will be clipped to the range [0, 1] or [0, 255] depending on the input
1778
+ data type. Default is False.
1779
+
1780
+ Returns
1781
+ -------
1782
+ ndarray
1783
+ The noised image. This output has the same shape as the input but potentially altered intensity values
1784
+ due to noise addition.
1785
+
1786
+ Notes
1787
+ -----
1788
+ - The types of noise that can be applied include 'gaussian', 'localvar', 'poisson', and 'speckle'.
1789
+ - The choice of noise type for each channel is randomized and the noise is only applied if a randomly
1790
+ generated number is less than or equal to `apply_probability`.
1791
+ - Zero-valued pixels in the input image remain zero in the output to preserve background or masked areas.
1266
1792
 
1793
+ Examples
1794
+ --------
1795
+ >>> import numpy as np
1796
+ >>> x = np.random.rand(256, 256, 3) # Example 3-channel image
1797
+ >>> noised_image = noise(x)
1798
+ # The image 'x' may have different types of noise applied to each of its channels with a 50% probability.
1267
1799
  """
1268
1800
 
1269
1801
  x_noise = x.astype(float).copy()
@@ -1436,26 +1968,29 @@ def normalize_per_channel(X, normalization_percentile_mode=True, normalization_v
1436
1968
  assert len(normalization_clipping)==n_channels
1437
1969
  assert len(normalization_percentile_mode)==n_channels
1438
1970
 
1971
+ X_normalized = []
1439
1972
  for i in range(len(X)):
1440
- x = X[i]
1973
+ x = X[i].copy()
1441
1974
  loc_i,loc_j,loc_c = np.where(x==0.)
1442
1975
  norm_x = np.zeros_like(x, dtype=np.float32)
1443
1976
  for k in range(x.shape[-1]):
1444
- chan = x[:,:,k]
1977
+ chan = x[:,:,k].copy()
1445
1978
  if not np.all(chan.flatten()==0):
1446
1979
  if normalization_percentile_mode[k]:
1447
- min_val = np.percentile(chan[chan!=0.].flatten(), normalization_values[k][0])
1448
- max_val = np.percentile(chan[chan!=0.].flatten(), normalization_values[k][1])
1980
+ min_val = np.nanpercentile(chan[chan!=0.].flatten(), normalization_values[k][0])
1981
+ max_val = np.nanpercentile(chan[chan!=0.].flatten(), normalization_values[k][1])
1449
1982
  else:
1450
1983
  min_val = normalization_values[k][0]
1451
1984
  max_val = normalization_values[k][1]
1452
1985
 
1453
1986
  clip_option = normalization_clipping[k]
1454
- norm_x[:,:,k] = normalize_mi_ma(chan.astype(np.float32), min_val, max_val, clip=clip_option, eps=1e-20, dtype=np.float32)
1455
-
1456
- X[i] = norm_x
1987
+ norm_x[:,:,k] = normalize_mi_ma(chan.astype(np.float32).copy(), min_val, max_val, clip=clip_option, eps=1e-20, dtype=np.float32)
1988
+ else:
1989
+ norm_x[:,:,k] = 0.
1990
+ norm_x[loc_i,loc_j,loc_c] = 0.
1991
+ X_normalized.append(norm_x.copy())
1457
1992
 
1458
- return X
1993
+ return X_normalized
1459
1994
 
1460
1995
  def load_image_dataset(datasets, channels, train_spatial_calibration=None, mask_suffix='labelled'):
1461
1996
 
@@ -1561,8 +2096,7 @@ def load_image_dataset(datasets, channels, train_spatial_calibration=None, mask_
1561
2096
 
1562
2097
  if im_calib != train_spatial_calibration:
1563
2098
  factor = im_calib / train_spatial_calibration
1564
- print(f'{im_calib=}, {train_spatial_calibration=}, {factor=}')
1565
- image = zoom(image, [factor,factor,1], order=3)
2099
+ image = np.moveaxis([zoom(image[:,:,c].astype(float).copy(), [factor,factor], order=3, prefilter=False) for c in range(image.shape[-1])],0,-1) #zoom(image, [factor,factor,1], order=3)
1566
2100
  mask = zoom(mask, [factor,factor], order=0)
1567
2101
 
1568
2102
  X.append(image)
@@ -1678,4 +2212,23 @@ def download_zenodo_file(file, output_dir):
1678
2212
  if file=='db-si-NucCondensation':
1679
2213
  os.rename(os.sep.join([output_dir,'db1-NucCondensation']), os.sep.join([output_dir,file]))
1680
2214
 
1681
- os.remove(path_to_zip_file)
2215
+ os.remove(path_to_zip_file)
2216
+
2217
def interpolate_nan(img, method='nearest'):

    """
    Fill the NaN pixels of a single-channel 2D image by interpolation.

    Parameters
    ----------
    img : ndarray
        2D array that may contain NaN values.
    method : str, optional
        Interpolation method forwarded to ``scipy.interpolate.griddata``
        ('nearest', 'linear' or 'cubic'). Default is 'nearest'.

    Returns
    -------
    ndarray
        A new array with NaN pixels replaced by interpolated values; when
        the input contains no NaN, the input array itself is returned.
    """

    valid = ~np.isnan(img)
    if valid.all():
        # nothing to interpolate
        return img

    # pixel-center coordinates: columns are x, rows are y
    col_grid, row_grid = np.meshgrid(np.arange(img.shape[1]), np.arange(img.shape[0]))
    known_points = np.stack([col_grid[valid], row_grid[valid]], axis=-1)
    known_values = img[valid]
    return griddata(known_points, known_values, (col_grid, row_grid), method=method)