celldetective 1.4.2__py3-none-any.whl → 1.5.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151) hide show
  1. celldetective/__init__.py +25 -0
  2. celldetective/__main__.py +62 -43
  3. celldetective/_version.py +1 -1
  4. celldetective/extra_properties.py +477 -399
  5. celldetective/filters.py +192 -97
  6. celldetective/gui/InitWindow.py +541 -411
  7. celldetective/gui/__init__.py +0 -15
  8. celldetective/gui/about.py +44 -39
  9. celldetective/gui/analyze_block.py +120 -84
  10. celldetective/gui/base/__init__.py +0 -0
  11. celldetective/gui/base/channel_norm_generator.py +335 -0
  12. celldetective/gui/base/components.py +249 -0
  13. celldetective/gui/base/feature_choice.py +92 -0
  14. celldetective/gui/base/figure_canvas.py +52 -0
  15. celldetective/gui/base/list_widget.py +133 -0
  16. celldetective/gui/{styles.py → base/styles.py} +92 -36
  17. celldetective/gui/base/utils.py +33 -0
  18. celldetective/gui/base_annotator.py +900 -767
  19. celldetective/gui/classifier_widget.py +6 -22
  20. celldetective/gui/configure_new_exp.py +777 -671
  21. celldetective/gui/control_panel.py +635 -524
  22. celldetective/gui/dynamic_progress.py +449 -0
  23. celldetective/gui/event_annotator.py +2023 -1662
  24. celldetective/gui/generic_signal_plot.py +1292 -944
  25. celldetective/gui/gui_utils.py +899 -1289
  26. celldetective/gui/interactions_block.py +658 -0
  27. celldetective/gui/interactive_timeseries_viewer.py +447 -0
  28. celldetective/gui/json_readers.py +48 -15
  29. celldetective/gui/layouts/__init__.py +5 -0
  30. celldetective/gui/layouts/background_model_free_layout.py +537 -0
  31. celldetective/gui/layouts/channel_offset_layout.py +134 -0
  32. celldetective/gui/layouts/local_correction_layout.py +91 -0
  33. celldetective/gui/layouts/model_fit_layout.py +372 -0
  34. celldetective/gui/layouts/operation_layout.py +68 -0
  35. celldetective/gui/layouts/protocol_designer_layout.py +96 -0
  36. celldetective/gui/pair_event_annotator.py +3130 -2435
  37. celldetective/gui/plot_measurements.py +586 -267
  38. celldetective/gui/plot_signals_ui.py +724 -506
  39. celldetective/gui/preprocessing_block.py +395 -0
  40. celldetective/gui/process_block.py +1678 -1831
  41. celldetective/gui/seg_model_loader.py +580 -473
  42. celldetective/gui/settings/__init__.py +0 -7
  43. celldetective/gui/settings/_cellpose_model_params.py +181 -0
  44. celldetective/gui/settings/_event_detection_model_params.py +95 -0
  45. celldetective/gui/settings/_segmentation_model_params.py +159 -0
  46. celldetective/gui/settings/_settings_base.py +77 -65
  47. celldetective/gui/settings/_settings_event_model_training.py +752 -526
  48. celldetective/gui/settings/_settings_measurements.py +1133 -964
  49. celldetective/gui/settings/_settings_neighborhood.py +574 -488
  50. celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
  51. celldetective/gui/settings/_settings_signal_annotator.py +329 -305
  52. celldetective/gui/settings/_settings_tracking.py +1304 -1094
  53. celldetective/gui/settings/_stardist_model_params.py +98 -0
  54. celldetective/gui/survival_ui.py +422 -312
  55. celldetective/gui/tableUI.py +1665 -1701
  56. celldetective/gui/table_ops/_maths.py +295 -0
  57. celldetective/gui/table_ops/_merge_groups.py +140 -0
  58. celldetective/gui/table_ops/_merge_one_hot.py +95 -0
  59. celldetective/gui/table_ops/_query_table.py +43 -0
  60. celldetective/gui/table_ops/_rename_col.py +44 -0
  61. celldetective/gui/thresholds_gui.py +382 -179
  62. celldetective/gui/viewers/__init__.py +0 -0
  63. celldetective/gui/viewers/base_viewer.py +700 -0
  64. celldetective/gui/viewers/channel_offset_viewer.py +331 -0
  65. celldetective/gui/viewers/contour_viewer.py +394 -0
  66. celldetective/gui/viewers/size_viewer.py +153 -0
  67. celldetective/gui/viewers/spot_detection_viewer.py +341 -0
  68. celldetective/gui/viewers/threshold_viewer.py +309 -0
  69. celldetective/gui/workers.py +304 -126
  70. celldetective/log_manager.py +92 -0
  71. celldetective/measure.py +1895 -1478
  72. celldetective/napari/__init__.py +0 -0
  73. celldetective/napari/utils.py +1025 -0
  74. celldetective/neighborhood.py +1914 -1448
  75. celldetective/preprocessing.py +1620 -1220
  76. celldetective/processes/__init__.py +0 -0
  77. celldetective/processes/background_correction.py +271 -0
  78. celldetective/processes/compute_neighborhood.py +894 -0
  79. celldetective/processes/detect_events.py +246 -0
  80. celldetective/processes/measure_cells.py +565 -0
  81. celldetective/processes/segment_cells.py +760 -0
  82. celldetective/processes/track_cells.py +435 -0
  83. celldetective/processes/train_segmentation_model.py +694 -0
  84. celldetective/processes/train_signal_model.py +265 -0
  85. celldetective/processes/unified_process.py +292 -0
  86. celldetective/regionprops/_regionprops.py +358 -317
  87. celldetective/relative_measurements.py +987 -710
  88. celldetective/scripts/measure_cells.py +313 -212
  89. celldetective/scripts/measure_relative.py +90 -46
  90. celldetective/scripts/segment_cells.py +165 -104
  91. celldetective/scripts/segment_cells_thresholds.py +96 -68
  92. celldetective/scripts/track_cells.py +198 -149
  93. celldetective/scripts/train_segmentation_model.py +324 -201
  94. celldetective/scripts/train_signal_model.py +87 -45
  95. celldetective/segmentation.py +844 -749
  96. celldetective/signals.py +3514 -2861
  97. celldetective/tracking.py +30 -15
  98. celldetective/utils/__init__.py +0 -0
  99. celldetective/utils/cellpose_utils/__init__.py +133 -0
  100. celldetective/utils/color_mappings.py +42 -0
  101. celldetective/utils/data_cleaning.py +630 -0
  102. celldetective/utils/data_loaders.py +450 -0
  103. celldetective/utils/dataset_helpers.py +207 -0
  104. celldetective/utils/downloaders.py +197 -0
  105. celldetective/utils/event_detection/__init__.py +8 -0
  106. celldetective/utils/experiment.py +1782 -0
  107. celldetective/utils/image_augmenters.py +308 -0
  108. celldetective/utils/image_cleaning.py +74 -0
  109. celldetective/utils/image_loaders.py +926 -0
  110. celldetective/utils/image_transforms.py +335 -0
  111. celldetective/utils/io.py +62 -0
  112. celldetective/utils/mask_cleaning.py +348 -0
  113. celldetective/utils/mask_transforms.py +5 -0
  114. celldetective/utils/masks.py +184 -0
  115. celldetective/utils/maths.py +351 -0
  116. celldetective/utils/model_getters.py +325 -0
  117. celldetective/utils/model_loaders.py +296 -0
  118. celldetective/utils/normalization.py +380 -0
  119. celldetective/utils/parsing.py +465 -0
  120. celldetective/utils/plots/__init__.py +0 -0
  121. celldetective/utils/plots/regression.py +53 -0
  122. celldetective/utils/resources.py +34 -0
  123. celldetective/utils/stardist_utils/__init__.py +104 -0
  124. celldetective/utils/stats.py +90 -0
  125. celldetective/utils/types.py +21 -0
  126. {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/METADATA +1 -1
  127. celldetective-1.5.0b0.dist-info/RECORD +187 -0
  128. {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/WHEEL +1 -1
  129. tests/gui/test_new_project.py +129 -117
  130. tests/gui/test_project.py +127 -79
  131. tests/test_filters.py +39 -15
  132. tests/test_notebooks.py +8 -0
  133. tests/test_tracking.py +232 -13
  134. tests/test_utils.py +123 -77
  135. celldetective/gui/base_components.py +0 -23
  136. celldetective/gui/layouts.py +0 -1602
  137. celldetective/gui/processes/compute_neighborhood.py +0 -594
  138. celldetective/gui/processes/measure_cells.py +0 -360
  139. celldetective/gui/processes/segment_cells.py +0 -499
  140. celldetective/gui/processes/track_cells.py +0 -303
  141. celldetective/gui/processes/train_segmentation_model.py +0 -270
  142. celldetective/gui/processes/train_signal_model.py +0 -108
  143. celldetective/gui/table_ops/merge_groups.py +0 -118
  144. celldetective/gui/viewers.py +0 -1354
  145. celldetective/io.py +0 -3663
  146. celldetective/utils.py +0 -3108
  147. celldetective-1.4.2.dist-info/RECORD +0 -123
  148. /celldetective/{gui/processes → processes}/downloader.py +0 -0
  149. {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/entry_points.txt +0 -0
  150. {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/licenses/LICENSE +0 -0
  151. {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/top_level.txt +0 -0
celldetective/utils.py DELETED
@@ -1,3108 +0,0 @@
1
- import numpy as np
2
- import pandas as pd
3
- import matplotlib.pyplot as plt
4
- import os
5
- from scipy.ndimage import shift, zoom
6
- os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '3'
7
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
8
- from tensorflow.config import list_physical_devices
9
- import configparser
10
- from sklearn.utils.class_weight import compute_class_weight
11
- from skimage.util import random_noise
12
- from skimage.filters import gaussian
13
- import random
14
- from tifffile import imread
15
- import json
16
- from csbdeep.utils import normalize_mi_ma
17
- from glob import glob
18
- from urllib.request import urlopen
19
- import zipfile
20
- from tqdm import tqdm
21
- import shutil
22
- import tempfile
23
- from scipy.interpolate import griddata
24
- import re
25
- from scipy.ndimage.morphology import distance_transform_edt
26
- from scipy import ndimage
27
- from skimage.morphology import disk
28
- from scipy.stats import ks_2samp
29
- from cliffs_delta import cliffs_delta
30
- from stardist.models import StarDist2D
31
- from cellpose.models import CellposeModel
32
- from pathlib import PosixPath, PurePath, PurePosixPath, WindowsPath, Path
33
- from prettytable import PrettyTable
34
- from typing import List, Dict, Union, Optional
35
-
36
- def is_integer_array(arr: np.ndarray) -> bool:
37
-
38
- # Mask out NaNs
39
- non_nan_values = arr[arr==arr].flatten()
40
- test = np.all(np.mod(non_nan_values, 1) == 0)
41
-
42
- if test:
43
- return True
44
- else:
45
- return False
46
-
47
- def get_config(experiment: Union[str,Path]) -> str:
48
-
49
- """
50
- Retrieves the path to the configuration file for a given experiment.
51
-
52
- Parameters
53
- ----------
54
- experiment : str
55
- The file system path to the directory of the experiment project.
56
-
57
- Returns
58
- -------
59
- str
60
- The full path to the configuration file (`config.ini`) within the experiment directory.
61
-
62
- Raises
63
- ------
64
- AssertionError
65
- If the `config.ini` file does not exist in the specified experiment directory.
66
-
67
- Notes
68
- -----
69
- - The function ensures that the provided experiment path ends with the appropriate file separator (`os.sep`)
70
- before appending `config.ini` to locate the configuration file.
71
- - The configuration file is expected to be named `config.ini` and located at the root of the experiment directory.
72
-
73
- Example
74
- -------
75
- >>> experiment = "/path/to/experiment"
76
- >>> config_path = get_config(experiment)
77
- >>> print(config_path)
78
- '/path/to/experiment/config.ini'
79
-
80
- """
81
-
82
- if isinstance(experiment, (PosixPath, PurePosixPath, WindowsPath)):
83
- experiment = str(experiment)
84
-
85
- if not experiment.endswith(os.sep):
86
- experiment += os.sep
87
-
88
- config = experiment + 'config.ini'
89
- config = rf"{config}"
90
-
91
- assert os.path.exists(config), 'The experiment configuration could not be located...'
92
- return config
93
-
94
-
95
- def _remove_invalid_cols(df: pd.DataFrame) -> pd.DataFrame:
96
-
97
- """
98
- Removes invalid columns from a DataFrame.
99
-
100
- This function identifies and removes columns in the DataFrame whose names
101
- start with "Unnamed", or that contain only NaN values.
102
-
103
- Parameters
104
- ----------
105
- df : pandas.DataFrame
106
- The input DataFrame from which invalid columns will be removed.
107
-
108
- Returns
109
- -------
110
- pandas.DataFrame
111
- A new DataFrame with the invalid columns removed. If no invalid
112
- columns are found, the original DataFrame is returned unchanged.
113
- """
114
-
115
- invalid_cols = [c for c in list(df.columns) if c.startswith('Unnamed')]
116
- if len(invalid_cols)>0:
117
- df = df.drop(invalid_cols, axis=1)
118
- df = df.dropna(axis=1, how='all')
119
- return df
120
-
121
- def _extract_coordinates_from_features(df: pd.DataFrame, timepoint: int) -> pd.DataFrame:
122
-
123
- """
124
- Re-format coordinates from a regionprops table to tracking/measurement table format.
125
-
126
- Parameters
127
- ----------
128
- df : pandas.DataFrame
129
- A DataFrame containing feature data, including columns for centroids
130
- (`'centroid-1'` and `'centroid-0'`) and feature classes (`'class_id'`).
131
- timepoint : int
132
- The timepoint (frame) to assign to all features. This is used to populate
133
- the `'FRAME'` column in the output.
134
-
135
- Returns
136
- -------
137
- pandas.DataFrame
138
- A DataFrame containing the extracted coordinates and additional metadata,
139
- with the following columns:
140
- - `'POSITION_X'`: X-coordinate of the centroid.
141
- - `'POSITION_Y'`: Y-coordinate of the centroid.
142
- - `'class_id'`: The label associated to the cell mask.
143
- - `'ID'`: A unique identifier for each cell (index-based).
144
- - `'FRAME'`: The timepoint associated with the features.
145
-
146
- Notes
147
- -----
148
- - The function assumes that the input DataFrame contains columns `'centroid-1'`,
149
- `'centroid-0'`, and `'class_id'`. Missing columns will raise a KeyError.
150
- - The `'ID'` column is created based on the index of the input DataFrame.
151
- - This function renames `'centroid-1'` to `'POSITION_X'` and `'centroid-0'`
152
- to `'POSITION_Y'`.
153
- """
154
-
155
- coords = df[['centroid-1', 'centroid-0', 'class_id']].copy()
156
- coords['ID'] = np.arange(len(coords))
157
- coords.rename(columns={'centroid-1': 'POSITION_X', 'centroid-0': 'POSITION_Y'}, inplace=True)
158
- coords['FRAME'] = int(timepoint)
159
-
160
- return coords
161
-
162
- def _mask_intensity_measurements(df: pd.DataFrame, mask_channels: Optional[List[str]]):
163
-
164
- """
165
- Removes columns from a DataFrame that match specific channel name patterns.
166
-
167
- This function filters out intensity measurement columns in a DataFrame based on
168
- specified channel names. It identifies columns containing the channel
169
- names as substrings and drops them from the DataFrame.
170
-
171
- Parameters
172
- ----------
173
- df : pandas.DataFrame
174
- The input DataFrame containing intensity measurement data. Column names should
175
- include the mask channel names if they are to be filtered.
176
- mask_channels : list of str or None
177
- A list of channel names (as substrings) to use for identifying columns
178
- to remove. If `None`, no filtering is applied, and the original DataFrame is
179
- returned.
180
-
181
- Returns
182
- -------
183
- pandas.DataFrame
184
- The modified DataFrame with specified columns removed. If no columns match
185
- the mask channels, the original DataFrame is returned.
186
-
187
- Notes
188
- -----
189
- - The function searches for mask channel substrings in column names.
190
- Partial matches are sufficient to mark a column for removal.
191
- - If no mask channels are specified (`mask_channels` is `None`), the function
192
- does not modify the input DataFrame.
193
- """
194
-
195
- if isinstance(mask_channels, str):
196
- mask_channels = [mask_channels]
197
-
198
- if mask_channels is not None:
199
-
200
- cols_to_drop = []
201
- columns = list(df.columns)
202
-
203
- for mc in mask_channels:
204
- cols_to_remove = [c for c in columns if mc in c]
205
- cols_to_drop.extend(cols_to_remove)
206
-
207
- if len(cols_to_drop)>0:
208
- df = df.drop(cols_to_drop, axis=1)
209
- return df
210
-
211
- def _rearrange_multichannel_frame(frame: np.ndarray, n_channels: Optional[int] = None) -> np.ndarray:
212
-
213
- """
214
- Rearranges the axes of a multi-channel frame to ensure the channel axis is at the end.
215
-
216
- This function standardizes the input frame to ensure that the channel axis (if present)
217
- is moved to the last position. For 2D frames, it adds a singleton channel axis at the end.
218
-
219
- Parameters
220
- ----------
221
- frame : ndarray
222
- The input frame to be rearranged. Can be 2D or 3D.
223
- - If 3D, the function identifies the channel axis (assumed to be the axis with the smallest size)
224
- and moves it to the last position.
225
- - If 2D, the function adds a singleton channel axis to make it compatible with 3D processing.
226
-
227
- Returns
228
- -------
229
- ndarray
230
- The rearranged frame with the channel axis at the end.
231
- - For 3D frames, the output shape will have the channel axis as the last dimension.
232
- - For 2D frames, the output will have shape `(H, W, 1)` where `H` and `W` are the height and width of the frame.
233
-
234
- Notes
235
- -----
236
- - This function assumes that in a 3D input, the channel axis is the one with the smallest size.
237
- - For 2D frames, this function ensures compatibility with multi-channel processing pipelines by
238
- adding a singleton dimension for the channel axis.
239
-
240
- Examples
241
- --------
242
- Rearranging a 3D multi-channel frame:
243
- >>> frame = np.zeros((10, 10, 3)) # Already channel-last
244
- >>> _rearrange_multichannel_frame(frame).shape
245
- (10, 10, 3)
246
-
247
- Rearranging a 3D frame with channel axis not at the end:
248
- >>> frame = np.zeros((3, 10, 10)) # Channel-first
249
- >>> _rearrange_multichannel_frame(frame).shape
250
- (10, 10, 3)
251
-
252
- Converting a 2D frame to have a channel axis:
253
- >>> frame = np.zeros((10, 10)) # Grayscale image
254
- >>> _rearrange_multichannel_frame(frame).shape
255
- (10, 10, 1)
256
- """
257
-
258
-
259
- if frame.ndim == 3:
260
- # Systematically move channel axis to the end
261
- if n_channels is not None and n_channels in list(frame.shape):
262
- channel_axis = list(frame.shape).index(n_channels)
263
- else:
264
- channel_axis = np.argmin(frame.shape)
265
- frame = np.moveaxis(frame, channel_axis, -1)
266
-
267
- if frame.ndim==2:
268
- frame = frame[:,:,np.newaxis]
269
-
270
- return frame
271
-
272
- def _fix_no_contrast(frames: np.ndarray, value: Union[float,int] = 1):
273
-
274
- """
275
- Ensures that frames with no contrast (i.e., containing only a single unique value) are adjusted.
276
-
277
- This function modifies frames that lack contrast by adding a small value to the first pixel in
278
- the affected frame. This prevents downstream issues in image processing pipelines that require
279
- a minimum level of contrast.
280
-
281
- Parameters
282
- ----------
283
- frames : ndarray
284
- A 3D array of shape `(H, W, N)`, where:
285
- - `H` is the height of the frame,
286
- - `W` is the width of the frame,
287
- - `N` is the number of frames or channels.
288
- Each frame (or channel) is independently checked for contrast.
289
- value : int or float, optional
290
- The value to add to the first pixel (`frames[0, 0, k]`) of any frame that lacks contrast.
291
- Default is `1`.
292
-
293
- Returns
294
- -------
295
- ndarray
296
- The modified `frames` array, where frames with no contrast have been adjusted.
297
-
298
- Notes
299
- -----
300
- - A frame is determined to have "no contrast" if all its pixel values are identical.
301
- - Only the first pixel (`[0, 0, k]`) of a no-contrast frame is modified, leaving the rest
302
- of the frame unchanged.
303
- """
304
-
305
- for k in range(frames.shape[2]):
306
- unique_values = np.unique(frames[:,:,k])
307
- if len(unique_values)==1:
308
- frames[0,0,k] += value
309
- return frames
310
-
311
- def zoom_multiframes(frames: np.ndarray, zoom_factor: float) -> np.ndarray:
312
-
313
- """
314
- Applies zooming to each frame (channel) in a multi-frame image.
315
-
316
- This function resizes each channel of a multi-frame image independently using a specified zoom factor.
317
- The zoom is applied using spline interpolation of the specified order, and the channels are combined
318
- back into the original format.
319
-
320
- Parameters
321
- ----------
322
- frames : ndarray
323
- A multi-frame image with dimensions `(height, width, channels)`. The last axis represents different
324
- channels.
325
- zoom_factor : float
326
- The zoom factor to apply to each channel. Values greater than 1 increase the size, and values
327
- between 0 and 1 decrease the size.
328
-
329
- Returns
330
- -------
331
- ndarray
332
- A new multi-frame image with the same number of channels as the input, but with the height and width
333
- scaled by the zoom factor.
334
-
335
- Notes
336
- -----
337
- - The function uses spline interpolation (order 3) for resizing, which provides smooth results.
338
- - `prefilter=False` is used to prevent additional filtering during the zoom operation.
339
- - The function assumes that the input is in `height x width x channels` format, with channels along the
340
- last axis.
341
- """
342
-
343
- frames = [zoom(frames[:,:,c].copy(), [zoom_factor,zoom_factor], order=3, prefilter=False) for c in range(frames.shape[-1])]
344
- frames = np.moveaxis(frames,0,-1)
345
- return frames
346
-
347
- def _prep_stardist_model(model_name, path, use_gpu=False, scale=1):
348
-
349
- """
350
- Prepares and loads a StarDist2D model for segmentation tasks.
351
-
352
- This function initializes a StarDist2D model with the specified parameters, sets GPU usage if desired,
353
- and allows scaling to adapt the model for specific applications.
354
-
355
- Parameters
356
- ----------
357
- model_name : str
358
- The name of the StarDist2D model to load. This name should match the model saved in the specified path.
359
- path : str
360
- The directory where the model is stored.
361
- use_gpu : bool, optional
362
- If `True`, the model will be configured to use GPU acceleration for computations. Default is `False`.
363
- scale : int or float, optional
364
- A scaling factor for the model. This can be used to adapt the model for specific image resolutions.
365
- Default is `1`.
366
-
367
- Returns
368
- -------
369
- tuple
370
- - model : StarDist2D
371
- The loaded StarDist2D model configured with the specified parameters.
372
- - scale_model : int or float
373
- The scaling factor passed to the function.
374
-
375
- Notes
376
- -----
377
- - Ensure the StarDist2D package is installed and the model files are correctly stored in the provided path.
378
- - GPU support depends on the availability of compatible hardware and software setup.
379
- """
380
-
381
- model = StarDist2D(None, name=model_name, basedir=path)
382
- model.config.use_gpu = use_gpu
383
- model.use_gpu = use_gpu
384
-
385
- scale_model = scale
386
-
387
-
388
- print(f"StarDist model {model_name} successfully loaded...")
389
- return model, scale_model
390
-
391
- def _prep_cellpose_model(model_name, path, use_gpu=False, n_channels=2, scale=None):
392
-
393
- """
394
- Prepares and loads a Cellpose model for segmentation tasks.
395
-
396
- This function initializes a Cellpose model with the specified parameters, configures GPU usage if available,
397
- and calculates or applies a scaling factor for the model based on image resolution.
398
-
399
- Parameters
400
- ----------
401
- model_name : str
402
- The name of the pretrained Cellpose model to load.
403
- path : str
404
- The directory where the model is stored.
405
- use_gpu : bool, optional
406
- If `True`, the model will use GPU acceleration for computations. Default is `False`.
407
- n_channels : int, optional
408
- The number of input channels expected by the model. Default is `2`.
409
- scale : float, optional
410
- A scaling factor to adjust the model's output to match the image resolution. If not provided, the scale is
411
- automatically calculated based on the model's diameter parameters.
412
-
413
- Returns
414
- -------
415
- tuple
416
- - model : CellposeModel
417
- The loaded Cellpose model configured with the specified parameters.
418
- - scale_model : float
419
- The scaling factor applied to the model, calculated or provided.
420
-
421
- Notes
422
- -----
423
- - Ensure the Cellpose package is installed and the model files are correctly stored in the provided path.
424
- - GPU support depends on the availability of compatible hardware and software setup.
425
- - The scale is calculated as `(diam_mean / diam_labels)` if `scale` is not provided, where `diam_mean` and
426
- `diam_labels` are attributes of the model.
427
- """
428
-
429
- import torch
430
- if not use_gpu:
431
- device = torch.device("cpu")
432
- else:
433
- device = torch.device("cuda")
434
-
435
- model = CellposeModel(gpu=use_gpu, device=device, pretrained_model=path+model_name, model_type=None, nchan=n_channels) #diam_mean=30.0,
436
- if scale is None:
437
- scale_model = model.diam_mean / model.diam_labels
438
- else:
439
- scale_model = scale * model.diam_mean / model.diam_labels
440
-
441
- print(f'Cell size in model: {model.diam_mean} pixels...')
442
- print(f'Cell size in training set: {model.diam_labels} pixels...')
443
- print(f"Rescaling factor to apply: {scale_model}...")
444
-
445
- print(f'Cellpose model {model_name} successfully loaded...')
446
- return model, scale_model
447
-
448
-
449
- def _get_normalize_kwargs_from_config(config):
450
-
451
- if isinstance(config, str):
452
- if os.path.exists(config):
453
- with open(config) as cfg:
454
- config = json.load(cfg)
455
- else:
456
- print('Configuration could not be loaded...')
457
- os.abort()
458
-
459
- normalization_percentile = config['normalization_percentile']
460
- normalization_clip = config['normalization_clip']
461
- normalization_values = config['normalization_values']
462
- normalize_kwargs = _get_normalize_kwargs(normalization_percentile, normalization_values, normalization_clip)
463
-
464
- return normalize_kwargs
465
-
466
- def _get_normalize_kwargs(normalization_percentile, normalization_values, normalization_clip):
467
-
468
- values = []
469
- percentiles = []
470
- for k in range(len(normalization_percentile)):
471
- if normalization_percentile[k]:
472
- percentiles.append(normalization_values[k])
473
- values.append(None)
474
- else:
475
- percentiles.append(None)
476
- values.append(normalization_values[k])
477
-
478
- return {"percentiles": percentiles, 'values': values, 'clip': normalization_clip}
479
-
480
- def _segment_image_with_cellpose_model(img, model=None, diameter=None, cellprob_threshold=None, flow_threshold=None, channel_axis=-1):
481
-
482
- """
483
- Segments an input image using a Cellpose model.
484
-
485
- This function applies a preloaded Cellpose model to segment an input image and returns the resulting labeled mask.
486
- The image is rearranged into the format expected by the Cellpose model, with the specified channel axis moved to the first dimension.
487
-
488
- Parameters
489
- ----------
490
- img : ndarray
491
- The input image to be segmented. It is expected to have a channel axis specified by `channel_axis`.
492
- model : CellposeModel, optional
493
- A preloaded Cellpose model instance used for segmentation.
494
- diameter : float, optional
495
- The diameter of objects to segment. If `None`, the model's default diameter is used.
496
- cellprob_threshold : float, optional
497
- The threshold for the probability of cells used during segmentation. If `None`, the default threshold is used.
498
- flow_threshold : float, optional
499
- The threshold for flow error during segmentation. If `None`, the default threshold is used.
500
- channel_axis : int, optional
501
- The axis of the input image that represents the channels. Default is `-1` (channel-last format).
502
-
503
- Returns
504
- -------
505
- ndarray
506
- A labeled mask of the same spatial dimensions as the input image, with segmented regions assigned unique
507
- integer labels. The dtype of the mask is `uint16`.
508
-
509
- Notes
510
- -----
511
- - The `img` array is internally rearranged to move the specified `channel_axis` to the first dimension to comply
512
- with the Cellpose model's input requirements.
513
- - Ensure the provided `model` is a properly initialized Cellpose model instance.
514
- - Parameters `diameter`, `cellprob_threshold`, and `flow_threshold` allow fine-tuning of the segmentation process.
515
- """
516
-
517
- img = np.moveaxis(img, channel_axis, 0)
518
- lbl, _, _ = model.eval(img, diameter = diameter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, channels=None, normalize=False)
519
-
520
- return lbl.astype(np.uint16)
521
-
522
- def _segment_image_with_stardist_model(img, model=None, return_details=False, channel_axis=-1):
523
-
524
- """
525
- Segments an input image using a StarDist model.
526
-
527
- This function applies a preloaded StarDist model to segment an input image and returns the resulting labeled mask.
528
- Optionally, additional details about the segmentation can also be returned.
529
-
530
- Parameters
531
- ----------
532
- img : ndarray
533
- The input image to be segmented. It is expected to have a channel axis specified by `channel_axis`.
534
- model : StarDist2D, optional
535
- A preloaded StarDist model instance used for segmentation.
536
- return_details : bool, optional
537
- Whether to return additional details from the model alongside the labeled mask. Default is `False`.
538
- channel_axis : int, optional
539
- The axis of the input image that represents the channels. Default is `-1` (channel-last format).
540
-
541
- Returns
542
- -------
543
- ndarray
544
- A labeled mask of the same spatial dimensions as the input image, with segmented regions assigned unique
545
- integer labels. The dtype of the mask is `uint16`.
546
- tuple of (ndarray, dict), optional
547
- If `return_details` is `True`, returns a tuple where the first element is the labeled mask and the second
548
- element is a dictionary containing additional details about the segmentation.
549
-
550
- Notes
551
- -----
552
- - The `img` array is internally rearranged to move the specified `channel_axis` to the last dimension to comply
553
- with the StarDist model's input requirements.
554
- - Ensure the provided `model` is a properly initialized StarDist model instance.
555
- - The model automatically determines the number of tiles (`n_tiles`) required for processing large images.
556
- """
557
-
558
- if channel_axis!=-1:
559
- img = np.moveaxis(img, channel_axis, -1)
560
-
561
- lbl, details = model.predict_instances(img, n_tiles=model._guess_n_tiles(img), show_tile_progress=False, verbose=False)
562
- if not return_details:
563
- return lbl.astype(np.uint16)
564
- else:
565
- return lbl.astype(np.uint16), details
566
-
567
- def _rescale_labels(lbl, scale_model=1):
568
- return zoom(lbl, [1./scale_model, 1./scale_model], order=0)
569
-
570
def extract_cols_from_table_list(tables, nrows=1):

    """
    Extracts a unique list of column names from a list of CSV tables.

    Parameters
    ----------
    tables : list of str
        A list of file paths to the CSV tables from which to extract column names.
    nrows : int, optional
        The number of rows to read from each table to identify the columns.
        Default is 1.

    Returns
    -------
    numpy.ndarray
        A sorted array of unique column names found across all the tables.

    Notes
    -----
    - Only the first `nrows` rows of each table are read, to stay fast on large files.
    - Column names are consolidated with `numpy.unique`.

    Examples
    --------
    >>> tables = ["table1.csv", "table2.csv"]
    >>> extract_cols_from_table_list(tables)
    array(['Column1', 'Column2', 'Column3'], dtype='<U8')
    """

    all_columns = []
    for tab in tables:
        # BUGFIX: honour the `nrows` parameter (was hard-coded to 1).
        cols = pd.read_csv(tab, nrows=nrows).columns.tolist()
        all_columns.extend(cols)
    return np.unique(all_columns)
606
-
607
def safe_log(array):

    """
    Safely computes the base-10 logarithm for numeric inputs, handling invalid or non-positive values.

    Parameters
    ----------
    array : int, float, list, or numpy.ndarray
        The input value or array for which to compute the logarithm.

    Returns
    -------
    float or numpy.ndarray
        - For a single numeric value, the base-10 logarithm as a float, or `np.nan` if the value is non-positive.
        - For a list or numpy array, a numpy array with the base-10 logarithm of each element;
          invalid or non-positive values are replaced with `np.nan`.

    Notes
    -----
    - Non-positive values (`<= 0`) are considered invalid and yield `np.nan`.
    - NaN values in the input are preserved in the output.

    Examples
    --------
    >>> safe_log(10)
    1.0

    >>> safe_log(-5)
    nan

    >>> safe_log([10, 0, -5, 100])
    array([ 1., nan, nan,  2.])
    """

    array = np.asarray(array, dtype=float)
    # Suppress the warnings np.log10 would emit on the non-positive entries
    # that np.where discards anyway.
    with np.errstate(invalid='ignore', divide='ignore'):
        result = np.where(array > 0, np.log10(array), np.nan)

    # BUGFIX: `np.isscalar` is always False after `np.asarray`, so scalar
    # inputs used to come back as 0-d arrays; test dimensionality instead.
    return result.item() if result.ndim == 0 else result
651
-
652
def contour_of_instance_segmentation(label, distance):

    """
    Generate an instance mask containing the contour of the segmented objects.

    Parameters
    ----------
    label : ndarray
        The instance segmentation labels.
    distance : int, float, list, or tuple
        The distance or range of distances from the edge of each instance to include in the contour.
        If a single non-negative value is provided, it represents the maximum distance (contour taken
        inside the objects). If a tuple or list is provided, it represents the minimum and maximum
        distances. A single negative value selects a ring *outside* the objects, obtained by
        grey dilation.

    Returns
    -------
    border_label : ndarray
        An instance mask containing the contour of the segmented objects.

    Notes
    -----
    This function generates an instance mask representing the contour of the segmented instances in the label image.
    It uses the distance_transform_edt function from the scipy.ndimage module to compute the Euclidean distance transform.
    The contour is defined based on the specified distance(s) from the edge of each instance.
    The resulting mask, `border_label`, contains the contour regions, while the interior regions are set to zero.

    Examples
    --------
    >>> border_label = contour_of_instance_segmentation(label, distance=3)
    # Generate a mask containing the contour of the segmented instances with a maximum distance of 3 pixels.
    """
    # Non-negative scalar or (min, max) pair: keep a band measured inward from the edge.
    if isinstance(distance,(list,tuple)) or distance >= 0 :

        # Distance (in px) of each foreground pixel to the nearest background pixel.
        edt = distance_transform_edt(label)

        if isinstance(distance, list) or isinstance(distance, tuple):
            min_distance = distance[0]; max_distance = distance[1]

        elif isinstance(distance, (int, float)):
            min_distance = 0
            max_distance = distance

        # Boolean band: strictly beyond min_distance, up to and including max_distance.
        thresholded = (edt <= max_distance) * (edt > min_distance)
        border_label = np.copy(label)
        border_label[np.where(thresholded == 0)] = 0

    else:
        # Negative distance: build an *outer* ring by dilating each instance
        # with a disk footprint, then removing the original object area.
        size = (2*abs(int(distance))+1, 2*abs(int(distance))+1)
        dilated_image = ndimage.grey_dilation(label, footprint=disk(int(abs(distance)))) #size=size,
        border_label=np.copy(dilated_image)
        # Pixels whose label is unchanged by dilation belong to the object interior.
        matching_cells = np.logical_and(dilated_image != 0, label == dilated_image)
        border_label[np.where(matching_cells == True)] = 0
        # Also clear any remaining pixel covered by the original objects.
        border_label[label!=0] = 0.

    return border_label
709
-
710
def extract_identity_col(trajectories):

    """
    Determine which column of a trajectory table identifies objects.

    'TRACK_ID' is preferred; if it is absent or entirely null, 'ID' is
    tried next. When neither column is usable, a message is printed and
    `None` is returned.

    Parameters
    ----------
    trajectories : pandas.DataFrame
        Table possibly containing a 'TRACK_ID' or 'ID' column.

    Returns
    -------
    str or None
        The usable identity column name, or `None` if there is none.
    """

    for candidate in ('TRACK_ID', 'ID'):
        if candidate not in trajectories.columns:
            continue
        if trajectories[candidate].isnull().all():
            continue
        return candidate

    print('ID or TRACK_ID column could not be found in the table...')
    return None
741
-
742
def derivative(x, timeline, window, mode='bi'):

    """
    Compute the derivative of a signal with respect to time using a windowed finite difference.

    Parameters
    ----------
    x : array_like
        The input array of values.
    timeline : array_like
        The time points corresponding to the input values.
    window : int
        The size of the differentiation window. Must be a positive odd integer in 'bi' mode.
    mode : {'bi', 'forward', 'backward'}, optional
        - 'bi' (default): symmetric difference over a centered window.
        - 'forward': one-sided difference looking ahead.
        - 'backward': one-sided difference looking back.

    Returns
    -------
    dxdt : ndarray
        The derivative values; positions where the window does not fit are NaN.

    Raises
    ------
    AssertionError
        If the window size is not odd and mode is 'bi'.
    ValueError
        If `mode` is not one of 'bi', 'forward', 'backward'.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.array([1., 2., 4., 7., 11.])
    >>> timeline = np.array([0., 1., 2., 3., 4.])
    >>> derivative(x, timeline, 3, mode='bi')
    array([nan, 1.5, 2.5, 3.5, nan])
    >>> derivative(x, timeline, 3, mode='forward')
    array([ 2.,  3., nan, nan, nan])
    >>> derivative(x, timeline, 3, mode='backward')
    array([nan, nan, nan,  2.,  3.])
    """

    dxdt = np.full(len(x), np.nan)

    if mode == 'bi':
        assert window % 2 == 1, 'Please set an odd window for the bidirectional mode'
        lower_bound = window // 2
        upper_bound = len(x) - window // 2
    elif mode == 'forward':
        lower_bound = 0
        upper_bound = len(x) - window
    elif mode == 'backward':
        lower_bound = window
        upper_bound = len(x)
    else:
        # BUGFIX: an unknown mode used to fall through and raise NameError below.
        raise ValueError(f"Unknown mode '{mode}'; expected 'bi', 'forward' or 'backward'.")

    for t in range(lower_bound, upper_bound):
        if mode == 'bi':
            dxdt[t] = (x[t + window // 2] - x[t - window // 2]) / (timeline[t + window // 2] - timeline[t - window // 2])
        elif mode == 'forward':
            dxdt[t] = (x[t + window] - x[t]) / (timeline[t + window] - timeline[t])
        else:  # backward
            dxdt[t] = (x[t] - x[t - window]) / (timeline[t] - timeline[t - window])

    return dxdt
816
-
817
def differentiate_per_track(tracks, measurement, window_size=3, mode='bi'):

    """
    Differentiate `measurement` with respect to FRAME for each track and
    store the result in a new 'd/dt.<measurement>' column.
    """

    group_keys = ['TRACK_ID']
    if 'position' in list(tracks.columns):
        group_keys = ['position'] + group_keys

    tracks = tracks.sort_values(by=group_keys + ['FRAME'], ignore_index=True)
    tracks = tracks.reset_index(drop=True)

    target_col = 'd/dt.' + measurement
    for _, track in tracks.groupby(group_keys):
        frames = track['FRAME'].values
        values = track[measurement].values
        tracks.loc[track.index, target_col] = derivative(values, frames, window_size, mode=mode)

    return tracks
832
-
833
def velocity_per_track(tracks, window_size=3, mode='bi'):

    """
    Compute the instantaneous speed of each track from its POSITION_X /
    POSITION_Y coordinates and store it in a 'velocity' column.
    """

    group_keys = ['TRACK_ID']
    if 'position' in list(tracks.columns):
        group_keys = ['position'] + group_keys

    tracks = tracks.sort_values(by=group_keys + ['FRAME'], ignore_index=True)
    tracks = tracks.reset_index(drop=True)

    for _, track in tracks.groupby(group_keys):
        frames = track['FRAME'].values
        vec = velocity(track['POSITION_X'].values, track['POSITION_Y'].values,
                       frames, window=window_size, mode=mode)
        tracks.loc[track.index, 'velocity'] = magnitude_velocity(vec)

    return tracks
850
-
851
def velocity(x,y,timeline,window,mode='bi'):

    """
    Compute the 2D velocity vector of a trajectory by differentiating its
    coordinates with respect to time.

    Parameters
    ----------
    x : array_like
        x-coordinates of the trajectory.
    y : array_like
        y-coordinates of the trajectory.
    timeline : array_like
        Time points corresponding to the coordinates.
    window : int
        Size of the differentiation window. Must be odd in 'bi' mode.
    mode : {'bi', 'forward', 'backward'}, optional
        Differentiation scheme, forwarded to `derivative`.

    Returns
    -------
    v : ndarray of shape (len(x), 2)
        Column 0 holds the x-component and column 1 the y-component of the
        velocity; positions where the window does not fit are NaN.

    Raises
    ------
    AssertionError
        If the window size is not odd and mode is 'bi'.
    """

    components = [derivative(coord, timeline, window, mode=mode) for coord in (x, y)]
    return np.stack(components, axis=1)
917
-
918
def magnitude_velocity(v_matrix):

    """
    Compute the magnitude of 2D velocity vectors.

    Parameters
    ----------
    v_matrix : array_like
        Matrix where each row is a 2D velocity vector: column 0 is the
        x-component, column 1 the y-component.

    Returns
    -------
    magnitude : ndarray
        The Euclidean norm of each row. Rows containing NaN components
        yield NaN.

    Examples
    --------
    >>> import numpy as np
    >>> magnitude_velocity(np.array([[3., 4.], [2., 2.]]))
    array([5.        , 2.82842712])

    >>> magnitude_velocity(np.array([[3., 4.], [np.nan, 2.]]))
    array([ 5., nan])
    """

    v = np.asarray(v_matrix, dtype=float)
    if len(v) == 0:
        # Preserve the empty-input behavior of the original element loop.
        return np.full(0, np.nan)
    # Vectorized replacement of the per-row loop; NaN components propagate
    # to NaN magnitudes exactly as before.
    return np.sqrt(v[:, 0] ** 2 + v[:, 1] ** 2)
961
-
962
def orientation(v_matrix):

    """
    Compute the orientation angles (in radians) of 2D velocity vectors.

    Parameters
    ----------
    v_matrix : array_like
        Matrix where each row is a 2D velocity vector: column 0 is the
        x-component, column 1 the y-component.

    Returns
    -------
    orientation_array : ndarray
        The angle `np.arctan2(vx, vy)` for each row. Rows whose
        x-component is NaN yield NaN.

    Notes
    -----
    The angle is computed as ``arctan2(vx, vy)`` (x-component first),
    i.e. measured from the y-axis — note this differs from the usual
    ``arctan2(vy, vx)`` convention.

    Examples
    --------
    >>> import numpy as np
    >>> orientation(np.array([[3., 4.], [2., 2.], [-3., -3.]]))
    array([ 0.64350111,  0.78539816, -2.35619449])

    >>> orientation(np.array([[3., 4.], [np.nan, 2.]]))
    array([0.64350111,        nan])
    """

    # BUGFIX: initialize with NaN so rows with NaN velocity come out NaN
    # (as documented and as `magnitude_velocity` does), instead of 0.0.
    orientation_array = np.full(len(v_matrix), np.nan)
    for t in range(len(orientation_array)):
        if v_matrix[t, 0] == v_matrix[t, 0]:  # NaN check: NaN != NaN
            orientation_array[t] = np.arctan2(v_matrix[t, 0], v_matrix[t, 1])
    return orientation_array
1000
-
1001
-
1002
def estimate_unreliable_edge(activation_protocol=None):

    """
    Safely estimate the distance to the edge of an image in which the filtered image values can be artefactual.

    Parameters
    ----------
    activation_protocol : list of list, optional
        A list of lists, where each sublist contains a string naming the filter function, followed by its
        arguments (usually a kernel size). Defaults to [['gauss', 2], ['std', 4]] when omitted or None.

    Returns
    -------
    int or None
        The sum of the integer kernel sizes in the activation protocol (filters named 'invert' are
        skipped). Returns None if the activation protocol is empty.

    Notes
    -----
    This function assumes that the second element of each sublist in the
    activation protocol is a kernel size.

    Examples
    --------
    >>> estimate_unreliable_edge([['gauss', 2], ['std', 4]])
    6
    >>> estimate_unreliable_edge([])
    None
    """

    # FIX: avoid a mutable default argument; None stands in for the default protocol.
    if activation_protocol is None:
        activation_protocol = [['gauss', 2], ['std', 4]]

    if activation_protocol == []:
        return None

    edge = 0
    for fct in activation_protocol:
        # Only integer kernel sizes contribute; 'invert' has no spatial extent.
        if isinstance(fct[1], (int, np.int_)) and not fct[0] == 'invert':
            edge += fct[1]
    return edge
1040
-
1041
def unpad(img, pad):

    """
    Remove padding from an image.

    This function removes the specified amount of padding from the borders
    of an image. The padding is assumed to be the same on all sides.

    Parameters
    ----------
    img : ndarray
        The input 2D image from which the padding will be removed.
    pad : int
        The amount of padding to remove from each side of the image.

    Returns
    -------
    ndarray
        The image with the padding removed. If `pad` is 0, the image is
        returned unchanged.

    See Also
    --------
    numpy.pad : Pads an array.

    Examples
    --------
    >>> import numpy as np
    >>> img = np.array([[0, 0, 0, 0, 0],
    ...                 [0, 1, 1, 1, 0],
    ...                 [0, 1, 1, 1, 0],
    ...                 [0, 1, 1, 1, 0],
    ...                 [0, 0, 0, 0, 0]])
    >>> unpad(img, 1)
    array([[1, 1, 1],
           [1, 1, 1],
           [1, 1, 1]])
    """

    # BUGFIX: with pad == 0, img[0:-0, 0:-0] is an EMPTY slice; return as-is.
    if pad == 0:
        return img
    return img[pad:-pad, pad:-pad]
1090
-
1091
def mask_edges(binary_mask, border_size):

    """
    Set the border of a binary mask to False.

    A frame of width `border_size` is cleared on all four sides of the
    mask; the interior is left untouched.

    Parameters
    ----------
    binary_mask : ndarray
        A 2D binary mask. It is cast to bool before the borders are cleared.
    border_size : int
        Width of the frame to clear on each side.

    Returns
    -------
    ndarray
        A boolean mask whose edges are masked out.

    Examples
    --------
    >>> import numpy as np
    >>> mask_edges(np.ones((5, 5)), 1)
    array([[False, False, False, False, False],
           [False,  True,  True,  True, False],
           [False,  True,  True,  True, False],
           [False,  True,  True,  True, False],
           [False, False, False, False, False]])
    """

    masked = binary_mask.astype(bool)
    n_rows, n_cols = masked.shape[0], masked.shape[1]

    # Explicit end-offsets (rather than negative slices) so border_size == 0
    # leaves the mask untouched.
    masked[:border_size, :] = False
    masked[(n_rows - border_size):, :] = False
    masked[:, :border_size] = False
    masked[:, (n_cols - border_size):] = False

    return masked
1145
-
1146
def demangle_column_name(name):
    # Undo the mangling pandas applies to backtick-quoted column names in
    # query strings; names without the marker prefix pass through unchanged.
    prefix = "BACKTICK_QUOTED_STRING_"
    if not name.startswith(prefix):
        return name
    replacements = (
        ("_DOT_", "."),
        ("_SLASH_", "/"),
        ("_MINUS_", "-"),
        ("_PLUS_", "+"),
        ("_PERCENT_", "%"),
        ("_STAR_", "*"),
        ("_LPAR_", "("),
        ("_RPAR_", ")"),
        ("_AMPER_", "&"),
    )
    demangled = name[len(prefix):]
    for token, char in replacements:
        demangled = demangled.replace(token, char)
    return demangled
1151
-
1152
def extract_cols_from_query(query: str):

    """
    Extract candidate DataFrame column names from a pandas query string.

    Backtick-quoted names are taken verbatim; remaining bare identifiers are
    kept unless they look like Python keywords, builtins, or pandas
    attributes. All names are passed through `demangle_column_name`.

    Parameters
    ----------
    query : str
        A pandas `DataFrame.query`-style expression.

    Returns
    -------
    list of str
        The candidate column names (order not guaranteed).
    """

    import builtins

    # 1. Column names protected by backticks may contain arbitrary characters.
    backtick_pattern = r'`([^`]+)`'
    backticked = set(re.findall(backtick_pattern, query))

    # 2. Remove backtick sections so they don't get double-counted.
    cleaned_query = re.sub(backtick_pattern, "", query)

    # 3. Extract bare identifiers from the remaining string.
    identifier_pattern = r'\b([A-Za-z_]\w*)\b'
    bare = set(re.findall(identifier_pattern, cleaned_query))

    # 4. Remove Python keywords, operators, and pandas builtins.
    # BUGFIX: inside an imported module, `__builtins__` is a dict, so
    # `dir(__builtins__)` listed dict methods rather than builtin names;
    # use the `builtins` module instead.
    blacklist = set(dir(pd)) | set(dir(builtins)) | {
        "and", "or", "not", "in", "True", "False"
    }
    bare = {c for c in bare if c not in blacklist}
    cols = backticked | bare

    return list([demangle_column_name(c) for c in cols])
1172
-
1173
def create_patch_mask(h, w, center=None, radius=None):

    """
    Create a circular (or annular) boolean patch mask.

    Adapted from alkasm on
    https://stackoverflow.com/questions/44865023/how-can-i-create-a-circular-mask-for-a-numpy-array

    Parameters
    ----------
    h : int
        Height of the mask. Prefer odd value.
    w : int
        Width of the mask. Prefer odd value.
    center : tuple, optional
        (x, y) coordinates of the patch center. Defaults to the image middle.
    radius : int or float or list, optional
        Radius of the circular patch. Defaults to the smallest distance
        between the center and the image walls. A two-element list selects
        an annulus [inner, outer].

    Returns
    -------
    numpy.ndarray or None
        Boolean mask that is True inside the patch (or annulus); None if
        `radius` has an unsupported type.

    Examples
    --------
    >>> mask = create_patch_mask(100, 100, center=(50, 50), radius=30)
    """

    if center is None:
        # Default to the middle of the image.
        center = (int(w / 2), int(h / 2))
    if radius is None:
        # Largest radius that still fits inside the image.
        radius = min(center[0], center[1], w - center[0], h - center[1])

    Y, X = np.ogrid[:h, :w]
    dist_from_center = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)

    if isinstance(radius, (int, float)):
        return dist_from_center <= radius
    if isinstance(radius, list):
        return (dist_from_center <= radius[1]) * (dist_from_center >= radius[0])

    print("Please provide a proper format for the radius")
    return None
1228
-
1229
def rename_intensity_column(df, channels):

    """
    Rename intensity columns in a DataFrame based on the provided channel names.

    Parameters
    ----------
    df : pandas DataFrame
        The DataFrame containing the intensity columns.
    channels : list
        A list of channel names corresponding to the intensity columns.

    Returns
    -------
    pandas DataFrame
        The DataFrame with renamed intensity columns.

    Notes
    -----
    This function renames the intensity columns in a DataFrame based on the provided channel names.
    It searches for columns containing the substring 'intensity' in their names and replaces it with
    the respective channel name. The renaming is performed according to the order of the channels
    provided in the `channels` list. If multiple channels are provided, the function assumes that the
    intensity columns have a naming pattern that includes a numerical index indicating the channel.
    Columns containing 'centre' or 'radial_gradient' additionally have their trailing component
    index translated to a descriptive suffix.

    Examples
    --------
    >>> data = {'intensity_0': [1, 2, 3], 'intensity_1': [4, 5, 6]}
    >>> df = pd.DataFrame(data)
    >>> channels = ['channel1', 'channel2']
    >>> renamed_df = rename_intensity_column(df, channels)
    # Rename the intensity columns in the DataFrame based on the provided channel names.
    """

    channel_names = np.array(channels)
    channel_indices = np.arange(len(channel_names),dtype=int)
    intensity_cols = [s for s in list(df.columns) if 'intensity' in s]

    # Mapping old column name -> new column name, applied in one rename at the end.
    to_rename = {}

    for k in range(len(intensity_cols)):

        # identify if digit in section
        # Split the column name on '-' or '_' and look for a numeric token
        # that is a valid channel index.
        sections = np.array(re.split('-|_', intensity_cols[k]))
        test_digit = np.array([False for s in sections])
        for j,s in enumerate(sections):
            if str(s).isdigit():
                if int(s)<len(channel_names):
                    test_digit[j] = True

        if np.any(test_digit):
            # Use the last matching numeric token as the channel index.
            index = int(sections[np.where(test_digit)[0]][-1])
        else:
            print(f'No valid channel index found for {intensity_cols[k]}... Skipping the renaming for {intensity_cols[k]}...')
            continue

        channel_name = channel_names[np.where(channel_indices==index)[0]][0]
        # Drop the numeric token(s) and substitute the channel name for 'intensity'.
        new_name = np.delete(sections, np.where(test_digit)[0]) #np.where(test_digit)[0]
        new_name = '_'.join(list(new_name))
        new_name = new_name.replace('intensity', channel_name)
        new_name = new_name.replace('-','_')
        new_name = new_name.replace('_nanmean','_mean')

        to_rename.update({intensity_cols[k]: new_name})

        # Centre-of-mass measurements encode their component as the second-to-last
        # token: 0=displacement, 1=orientation, 2=x, 3=y.
        if 'centre' in intensity_cols[k]:

            measure = np.array(re.split('-|_', new_name))

            if sections[-2] == "0":
                new_name = np.delete(measure, -1)
                new_name = '_'.join(list(new_name))
                if 'edge' in intensity_cols[k]:
                    new_name = new_name.replace('center_of_mass_displacement', "edge_center_of_mass_displacement_in_px")
                else:
                    new_name = new_name.replace('center_of_mass', "center_of_mass_displacement_in_px")
                to_rename.update({intensity_cols[k]: new_name.replace('-', '_')})

            elif sections[-2] == "1":
                new_name = np.delete(measure, -1)
                new_name = '_'.join(list(new_name))
                if 'edge' in intensity_cols[k]:
                    new_name = new_name.replace('center_of_mass_displacement', "edge_center_of_mass_orientation")
                else:
                    new_name = new_name.replace('center_of_mass', "center_of_mass_orientation")
                to_rename.update({intensity_cols[k]: new_name.replace('-', '_')})

            elif sections[-2] == "2":
                new_name = np.delete(measure, -1)
                new_name = '_'.join(list(new_name))
                if 'edge' in intensity_cols[k]:
                    new_name = new_name.replace('center_of_mass_displacement', "edge_center_of_mass_x")
                else:
                    new_name = new_name.replace('center_of_mass', "center_of_mass_x")
                to_rename.update({intensity_cols[k]: new_name.replace('-', '_')})

            elif sections[-2] == "3":
                new_name = np.delete(measure, -1)
                new_name = '_'.join(list(new_name))
                if 'edge' in intensity_cols[k]:
                    new_name = new_name.replace('center_of_mass_displacement', "edge_center_of_mass_y")
                else:
                    new_name = new_name.replace('center_of_mass', "center_of_mass_y")
                to_rename.update({intensity_cols[k]: new_name.replace('-', '_')})

        # Radial-gradient measurements encode their component as the second-to-last
        # token: 0=gradient, 1=intercept, 2=r2 score.
        if 'radial_gradient' in intensity_cols[k]:
            # sections = np.array(re.split('-|_', intensity_columns[k]))
            measure = np.array(re.split('-|_', new_name))

            if sections[-2] == "0":
                new_name = np.delete(measure, -1)
                new_name = '_'.join(list(measure))
                new_name = new_name.replace('radial_gradient', "radial_gradient")
                to_rename.update({intensity_cols[k]: new_name.replace('-', '_')})

            elif sections[-2] == "1":
                new_name = np.delete(measure, -1)
                new_name = '_'.join(list(measure))
                new_name = new_name.replace('radial_gradient', "radial_intercept")
                to_rename.update({intensity_cols[k]: new_name.replace('-', '_')})

            elif sections[-2] == "2":
                new_name = np.delete(measure, -1)
                new_name = '_'.join(list(measure))
                new_name = new_name.replace('radial_gradient', "radial_gradient_r2_score")
                to_rename.update({intensity_cols[k]: new_name.replace('-', '_')})

    df = df.rename(columns=to_rename)

    return df
1362
-
1363
-
1364
def regression_plot(y_pred, y_true, savepath=None):

    """
    Create a regression plot to compare predicted and ground truth values.

    Parameters
    ----------
    y_pred : array-like
        Predicted values.
    y_true : array-like
        Ground truth values.
    savepath : str or None, optional
        File path to save the plot. If None, the plot is displayed but not saved. Default is None.

    Returns
    -------
    None

    Notes
    -----
    This function creates a scatter plot comparing the predicted values (`y_pred`) to the ground truth values (`y_true`)
    for regression analysis. The plot also includes a diagonal reference line to visualize the ideal prediction scenario.
    The figure is shown for two seconds (via `plt.pause`) and then closed.

    Examples
    --------
    >>> y_pred = [1.5, 2.0, 3.2, 4.1]
    >>> y_true = [1.7, 2.1, 3.5, 4.2]
    >>> regression_plot(y_pred, y_true)
    # Create a scatter plot comparing the predicted values to the ground truth values.

    >>> regression_plot(y_pred, y_true, savepath="regression_plot.png")
    # Create a scatter plot and save it as "regression_plot.png".
    """

    fig,ax = plt.subplots(1,1,figsize=(4,3))
    ax.scatter(y_pred, y_true)
    ax.set_xlabel("prediction")
    ax.set_ylabel("ground truth")
    # Diagonal y = x reference line spanning the common range of both series.
    line = np.linspace(np.amin([y_pred,y_true]),np.amax([y_pred,y_true]),1000)
    ax.plot(line,line,linestyle="--",c="k",alpha=0.7)
    plt.tight_layout()
    if savepath is not None:
        plt.savefig(savepath,bbox_inches="tight",dpi=300)
    # Briefly display the figure, then release it.
    plt.pause(2)
    plt.close()
1414
-
1415
def split_by_ratio(arr, *ratios):

    """
    Randomly split an array into chunks of the given relative sizes.

    Parameters
    ----------
    arr : array-like
        The input array to be split.
    *ratios : float
        Proportions of each chunk; their sum should be <= 1.

    Returns
    -------
    list
        One list per ratio, containing a random, disjoint subset of `arr`.

    Notes
    -----
    The input is randomly permuted first, so the chunks are random
    partitions. Any leftover elements beyond the accumulated ratios are
    discarded.

    Examples
    --------
    >>> arr = np.arange(100)
    >>> splits = split_by_ratio(arr, 0.5, 0.25)
    >>> [len(s) for s in splits]
    [50, 25]
    """

    shuffled = np.random.permutation(arr)
    # Cumulative split indices, e.g. ratios (0.5, 0.25) on 100 -> [50, 75].
    cut_points = np.add.accumulate(np.array(ratios) * len(shuffled)).astype(int)
    chunks = np.split(shuffled, cut_points)
    # Keep exactly one chunk per ratio; the remainder (if any) is dropped.
    return [chunk.tolist() for chunk in chunks][:len(ratios)]
1461
-
1462
def compute_weights(y):

    """
    Compute balanced class weights for a label array.

    Parameters
    ----------
    y : array-like
        Array of labels.

    Returns
    -------
    dict
        Mapping from each unique class in `y` to its "balanced" weight
        (inversely proportional to the class frequency), as computed by
        `compute_class_weight`.

    Examples
    --------
    >>> labels = np.array([0, 1, 0, 1, 1])
    >>> compute_weights(labels)
    {0: 1.25, 1: 0.8333333333333334}
    """

    classes = np.unique(y)
    weights = compute_class_weight(
        class_weight="balanced",
        classes=classes,
        y=y,
    )
    return dict(zip(classes, weights))
1511
-
1512
def train_test_split(data_x, data_y1, data_class=None, validation_size=0.25, test_size=0, n_iterations=10):

    """
    Split the dataset into training, validation, and (optionally) test sets.

    Parameters
    ----------
    data_x : array-like
        Input features or independent variables.
    data_y1 : array-like
        Primary target variable.
    data_class : array-like or None, optional
        One-hot encoded class labels. When provided, the shuffle is retried
        (up to `n_iterations` times) until the train and validation sets
        contain the same set of classes. Default is None.
    validation_size : float, optional
        Proportion of the dataset to include in the validation set. Default is 0.25.
    test_size : float, optional
        Proportion of the dataset to include in the test set. Default is 0.
    n_iterations : int, optional
        Maximum number of reshuffle attempts when `data_class` is given.
        Default is 10.

    Returns
    -------
    dict
        Keys "x_train", "x_val", "y1_train", "y1_val"; additionally
        "y2_train"/"y2_val" when `data_class` is provided, and
        "x_test"/"y1_test" (plus "y2_test") when `test_size` > 0.

    Raises
    ------
    Exception
        If, after `n_iterations` shuffles, some class is still missing from
        the train or validation set.

    Notes
    -----
    This function shuffles the sample indices and splits them according to
    `validation_size` and `test_size`; the remainder becomes the train set.
    Note that it shadows the sklearn function of the same name and has a
    different signature.
    """

    if data_class is not None:
        # classes are one-hot encoded; argmax over unique rows recovers the class ids
        print(f"Unique classes: {np.sort(np.argmax(np.unique(data_class,axis=0),axis=1))}")

    for i in range(n_iterations):

        n_values = len(data_x)
        randomize = np.arange(n_values)
        np.random.shuffle(randomize)

        train_percentage = 1 - validation_size - test_size

        # split the shuffled indices into train / validation / test chunks
        chunks = split_by_ratio(randomize, train_percentage, validation_size, test_size)

        x_train = data_x[chunks[0]]
        y1_train = data_y1[chunks[0]]
        if data_class is not None:
            y2_train = data_class[chunks[0]]

        x_val = data_x[chunks[1]]
        y1_val = data_y1[chunks[1]]
        if data_class is not None:
            y2_val = data_class[chunks[1]]

        if data_class is not None:
            print(f"classes in train set: {np.sort(np.argmax(np.unique(y2_train,axis=0),axis=1))}; classes in validation set: {np.sort(np.argmax(np.unique(y2_val,axis=0),axis=1))}")
            # accept the split only when train and validation hold the same classes
            same_class_test = np.array_equal(np.sort(np.argmax(np.unique(y2_train,axis=0),axis=1)), np.sort(np.argmax(np.unique(y2_val,axis=0),axis=1)))
            print(f"Check that classes are found in all sets: {same_class_test}...")
        else:
            same_class_test = True

        if same_class_test:

            ds = {"x_train": x_train, "x_val": x_val,
                  "y1_train": y1_train, "y1_val": y1_val}
            if data_class is not None:
                ds.update({"y2_train": y2_train, "y2_val": y2_val})

            if test_size>0:
                x_test = data_x[chunks[2]]
                y1_test = data_y1[chunks[2]]
                ds.update({"x_test": x_test, "y1_test": y1_test})
                if data_class is not None:
                    y2_test = data_class[chunks[2]]
                    ds.update({"y2_test": y2_test})
            return ds
        else:
            # some class is missing from a set: reshuffle and try again
            continue

    raise Exception("Some classes are missing from the train or validation set... Abort.")
1601
-
1602
-
1603
def remove_redundant_features(features, reference_features, channel_names=None):
    """Remove from `features` any entry already covered by `reference_features`.

    A feature is dropped when it appears verbatim in `reference_features`.
    Additionally, when `channel_names` is given, an intensity feature is
    dropped when a reference feature combines any channel name with the same
    statistic suffix (e.g. reference 'ch1_mean' removes '..._intensity_mean').

    Parameters
    ----------
    features : list
        The list of features to be filtered (not modified in place).
    reference_features : list
        The reference list of features.
    channel_names : list or None, optional
        Channel names used to detect redundant intensity features. Default is None.

    Returns
    -------
    list
        A new list of features without the redundant entries.

    Examples
    --------
    >>> remove_redundant_features(
    ...     ['area', 'intensity_mean', 'intensity_max', 'eccentricity'],
    ...     ['area', 'eccentricity'])
    ['intensity_mean', 'intensity_max']

    >>> remove_redundant_features(
    ...     ['area', 'channel1_intensity_mean'], ['channel1_mean'],
    ...     channel_names=['channel1'])
    ['area']
    """

    new_features = features[:]

    for f in features:

        if f in reference_features and f in new_features:
            new_features.remove(f)

        if ('intensity' in f) and (channel_names is not None):
            # the statistic suffix of the intensity feature ('mean', 'max', ...)
            mode = f.split('_')[-1]
            patterns = [channel + '_' + mode for channel in channel_names]
            # membership guard replaces the original bare try/except around remove()
            if f in new_features and any(p in reference_features for p in patterns):
                new_features.remove(f)

    return new_features
1665
-
1666
- def _estimate_scale_factor(spatial_calibration, required_spatial_calibration):
1667
-
1668
- """
1669
- Estimates the scale factor needed to adjust spatial calibration to a required value.
1670
-
1671
- This function calculates the scale factor by which spatial dimensions (e.g., in microscopy images)
1672
- should be adjusted to align with a specified calibration standard. This is particularly useful when
1673
- preparing data for analysis with models trained on data of a specific spatial calibration.
1674
-
1675
- Parameters
1676
- ----------
1677
- spatial_calibration : float or None
1678
- The current spatial calibration factor of the data, expressed as units per pixel (e.g., micrometers per pixel).
1679
- If None, indicates that the current spatial calibration is unknown or unspecified.
1680
- required_spatial_calibration : float or None
1681
- The spatial calibration factor required for compatibility with the model or analysis standard, expressed
1682
- in the same units as `spatial_calibration`. If None, indicates no adjustment is required.
1683
-
1684
- Returns
1685
- -------
1686
- float or None
1687
- The scale factor by which the current data should be rescaled to match the required spatial calibration,
1688
- or None if no scaling is necessary or if insufficient information is provided.
1689
-
1690
- Notes
1691
- -----
1692
- - A scale factor close to 1 (within a tolerance defined by `epsilon`) indicates that no significant rescaling
1693
- is needed, and the function returns None.
1694
- - The function issues a warning if a significant rescaling is necessary, indicating the scale factor to be applied.
1695
-
1696
- Examples
1697
- --------
1698
- >>> scale_factor = _estimate_scale_factor(spatial_calibration=0.5, required_spatial_calibration=0.25)
1699
- # Each frame will be rescaled by a factor 2.0 to match with the model training data...
1700
-
1701
- >>> scale_factor = _estimate_scale_factor(spatial_calibration=None, required_spatial_calibration=0.25)
1702
- # Returns None due to insufficient information about current spatial calibration.
1703
- """
1704
-
1705
- if (required_spatial_calibration is not None)*(spatial_calibration is not None):
1706
- scale = spatial_calibration / required_spatial_calibration
1707
- else:
1708
- scale = None
1709
-
1710
- epsilon = 0.05
1711
- if scale is not None:
1712
- if not np.all([scale >= (1-epsilon), scale <= (1+epsilon)]):
1713
- print(f"Each frame will be rescaled by a factor {scale} to match with the model training data...")
1714
- else:
1715
- scale = None
1716
- return scale
1717
-
1718
def auto_find_gpu():
    """Detect whether at least one GPU device is available.

    Queries TensorFlow's physical device list for 'GPU' entries, which lets
    callers adapt their computation strategy to the hardware at hand.

    Returns
    -------
    bool
        True if one or more GPU devices are detected, False otherwise.

    Examples
    --------
    >>> has_gpu = auto_find_gpu()
    >>> print(f"GPU available: {has_gpu}")
    """

    available_gpus = list_physical_devices('GPU')
    return len(available_gpus) > 0
1752
-
1753
- def _extract_channel_indices(channels, required_channels):
1754
-
1755
- """
1756
- Extracts the indices of required channels from a list of available channels.
1757
-
1758
- This function is designed to match the channels required by a model or analysis process with the channels
1759
- present in the dataset. It returns the indices of the required channels within the list of available channels.
1760
- If the required channels are not found among the available channels, the function prints an error message and
1761
- returns None.
1762
-
1763
- Parameters
1764
- ----------
1765
- channels : list of str or None
1766
- A list containing the names of the channels available in the dataset. If None, it is assumed that the
1767
- dataset channels are in the same order as the required channels.
1768
- required_channels : list of str
1769
- A list containing the names of the channels required by the model or analysis process.
1770
-
1771
- Returns
1772
- -------
1773
- ndarray or None
1774
- An array of indices indicating the positions of the required channels within the list of available
1775
- channels. Returns None if there is a mismatch between required and available channels.
1776
-
1777
- Notes
1778
- -----
1779
- - The function is useful for preprocessing steps where specific channels of multi-channel data are needed
1780
- for further analysis or model input.
1781
- - In cases where `channels` is None, indicating that the dataset does not specify channel names, the function
1782
- assumes that the dataset's channel order matches the order of `required_channels` and returns an array of
1783
- indices based on this assumption.
1784
-
1785
- Examples
1786
- --------
1787
- >>> available_channels = ['DAPI', 'GFP', 'RFP']
1788
- >>> required_channels = ['GFP', 'RFP']
1789
- >>> indices = _extract_channel_indices(available_channels, required_channels)
1790
- >>> print(indices)
1791
- # [1, 2]
1792
-
1793
- >>> indices = _extract_channel_indices(None, required_channels)
1794
- >>> print(indices)
1795
- # [0, 1]
1796
- """
1797
-
1798
- channel_indices = []
1799
- for c in required_channels:
1800
- if c!='None' and c is not None:
1801
- try:
1802
- ch_idx = channels.index(c)
1803
- channel_indices.append(ch_idx)
1804
- except Exception as e:
1805
- channel_indices.append(None)
1806
- else:
1807
- channel_indices.append(None)
1808
-
1809
- return channel_indices
1810
-
1811
def config_section_to_dict(path: Union[str, PurePath, Path], section: str) -> Union[Dict, None]:
    """Parse one section of a config (.ini) file into a dictionary.

    Follows https://wiki.python.org/moin/ConfigParserExamples.

    Parameters
    ----------
    path : str or Path
        Path to the config.ini file.
    section : str
        Name of the section that contains the parameters.

    Returns
    -------
    dict or None
        Mapping of option names to their raw string values, or None when the
        section does not exist. An option whose value cannot be read maps to
        None.

    Examples
    --------
    >>> channel_dictionary = config_section_to_dict("path/to/config_file.ini", "Channels")
    >>> print(channel_dictionary)
    # {'brightfield_channel': '0', 'live_nuclei_channel': 'nan', ...}
    """

    parser = configparser.ConfigParser(interpolation=None)
    parser.read(path)
    result = {}
    try:
        options = parser.options(section)
    except configparser.NoSectionError:
        return None
    for option in options:
        try:
            # ConfigParser.get always returns a string; the original's
            # `== -1` skip-check could never trigger and was removed.
            result[option] = parser.get(section, option)
        except configparser.Error:
            print("exception on %s!" % option)
            result[option] = None
    return result
1864
-
1865
def _extract_channel_indices_from_config(config, channels_to_extract):
    """Resolve channel names to indices via the "Channels" section of a config.

    Parameters
    ----------
    config : str
        Path to the experiment configuration file.
    channels_to_extract : str or list of str
        Channel name(s) whose indices should be looked up.

    Returns
    -------
    list or None
        One index (or None) per requested channel, in request order; None
        altogether when no requested channel could be resolved.
    """

    if isinstance(channels_to_extract, str):
        channels_to_extract = [channels_to_extract]

    indices = []
    for name in channels_to_extract:
        try:
            indices.append(int(config_section_to_dict(config, "Channels")[name]))
        except Exception as e:
            print(f"Warning: The channel {name} required by the model is not available in your data...")
            indices.append(None)

    if all(idx is None for idx in indices):
        return None
    return indices
1922
-
1923
def _extract_nbr_channels_from_config(config, return_names=False):
    """Count the channels declared in an experiment configuration file.

    Channels are first read from the "Channels" section (current format).
    If none are found there, the legacy per-channel keys of the
    "MovieSettings" section are probed instead, in their historical order.

    Parameters
    ----------
    config : str
        Path to the experiment configuration file.
    return_names : bool, optional
        When True, also return the list of channel names found. Default is False.

    Returns
    -------
    int or tuple of (int, list)
        Number of channels, optionally followed by their names.

    Examples
    --------
    >>> nbr_channels = _extract_nbr_channels_from_config("path/to/config_file.ini")
    >>> print(nbr_channels)
    # 4
    """

    channels = []

    # V2 format: every option of the "Channels" section holding an integer index
    try:
        fields = config_section_to_dict(config, "Channels")
        for name, value in fields.items():
            try:
                int(value)  # keep only options with a valid channel index
                channels.append(name)
            except (TypeError, ValueError):
                pass
    except Exception:
        pass

    if not channels:
        # LEGACY format: one dedicated key per channel in "MovieSettings".
        # The original repeated the same try/except block seven times.
        legacy_keys = (
            'brightfield_channel', 'live_nuclei_channel', 'dead_nuclei_channel',
            'effector_fluo_channel', 'adhesion_channel', 'fluo_channel_1',
            'fluo_channel_2',
        )
        movie_settings = config_section_to_dict(config, "MovieSettings")
        for key in legacy_keys:
            try:
                int(movie_settings[key])
                channels.append(key)
            except Exception:
                pass

    nbr_channels = len(channels)
    if return_names:
        return nbr_channels, channels
    return nbr_channels
2008
-
2009
- def _get_img_num_per_channel(channels_indices, len_movie, nbr_channels):
2010
-
2011
- """
2012
- Calculates the image frame numbers for each specified channel in a multi-channel movie.
2013
-
2014
- Given the indices of channels of interest, the total length of the movie, and the number of channels,
2015
- this function computes the frame numbers corresponding to each channel throughout the movie. If a
2016
- channel index is specified as None, it assigns a placeholder value to indicate no frames for that channel.
2017
-
2018
- Parameters
2019
- ----------
2020
- channels_indices : list of int or None
2021
- A list containing the indices of channels for which to calculate frame numbers. If an index is None,
2022
- it is interpreted as a channel with no frames to be processed.
2023
- len_movie : int
2024
- The total number of frames in the movie across all channels.
2025
- nbr_channels : int
2026
- The total number of channels in the movie.
2027
-
2028
- Returns
2029
- -------
2030
- ndarray
2031
- A 2D numpy array where each row corresponds to a channel specified in `channels_indices` and contains
2032
- the frame numbers for that channel throughout the movie. If a channel index is None, the corresponding
2033
- row contains placeholder values (-1).
2034
-
2035
- Notes
2036
- -----
2037
- - The function assumes that frames in the movie are interleaved by channel, with frames for each channel
2038
- appearing in a regular sequence throughout the movie.
2039
- - This utility is particularly useful for multi-channel time-lapse movies where analysis or processing
2040
- needs to be performed on a per-channel basis.
2041
-
2042
- Examples
2043
- --------
2044
- >>> channels_indices = [0] # Indices for channels 1, 3, and a non-existing channel
2045
- >>> len_movie = 10 # Total frames for each channel
2046
- >>> nbr_channels = 3 # Total channels in the movie
2047
- >>> img_num_per_channel = _get_img_num_per_channel(channels_indices, len_movie, nbr_channels)
2048
- >>> print(img_num_per_channel)
2049
- # array([[ 0, 3, 6, 9, 12, 15, 18, 21, 24, 27]])
2050
-
2051
- >>> channels_indices = [1,2] # Indices for channels 1, 3, and a non-existing channel
2052
- >>> len_movie = 10 # Total frames for each channel
2053
- >>> nbr_channels = 3 # Total channels in the movie
2054
- >>> img_num_per_channel = _get_img_num_per_channel(channels_indices, len_movie, nbr_channels)
2055
- >>> print(img_num_per_channel)
2056
- # array([[ 1, 4, 7, 10, 13, 16, 19, 22, 25, 28],
2057
- # [ 2, 5, 8, 11, 14, 17, 20, 23, 26, 29]])
2058
-
2059
- """
2060
-
2061
- if isinstance(channels_indices, (int, np.int_)):
2062
- channels_indices = [channels_indices]
2063
-
2064
- len_movie = int(len_movie)
2065
- nbr_channels = int(nbr_channels)
2066
-
2067
- img_num_all_channels = []
2068
- for c in channels_indices:
2069
- if c is not None:
2070
- indices = np.arange(len_movie*nbr_channels)[c::nbr_channels]
2071
- else:
2072
- indices = [-1]*len_movie
2073
- img_num_all_channels.append(indices)
2074
- img_num_all_channels = np.array(img_num_all_channels, dtype=int)
2075
-
2076
- return img_num_all_channels
2077
-
2078
def _extract_labels_from_config(config, number_of_wells):
    """Build one biological-condition label per well from the "Labels" section.

    Parameters
    ----------
    config : str
        Path to the configuration file.
    number_of_wells : int
        Total number of wells in the experiment (used for fallback labels).

    Returns
    -------
    labels : list of str or ndarray of str
        Condition label for each well; plain well numbers when the "Labels"
        section cannot be read.
    """

    # Deprecated, need to read metadata to extract concentration units and
    # discard non essential fields

    try:
        labels_section = config_section_to_dict(config, "Labels")
        concentrations = labels_section["concentrations"].split(",")
        cell_types = labels_section["cell_types"].split(",")
        antibodies = labels_section["antibodies"].split(",")
        pharmaceutical_agents = labels_section["pharmaceutical_agents"].split(",")
        index = np.arange(len(concentrations)).astype(int) + 1
        # BUGFIX: the original tested `np.all(pharmaceutical_agents=="None")`,
        # which compares the *list object* to a string (always False), so the
        # agent-free label format was unreachable. Compare element-wise instead.
        if not all(agent == "None" for agent in pharmaceutical_agents):
            labels = [f"W{idx}: [CT] " + a + "; [Ab] " + b + " @ " + c + " pM " + d
                      for idx, a, b, c, d in zip(index, cell_types, antibodies, concentrations, pharmaceutical_agents)]
        else:
            labels = [f"W{idx}: [CT] " + a + "; [Ab] " + b + " @ " + c + " pM "
                      for idx, a, b, c in zip(index, cell_types, antibodies, concentrations)]

    except Exception as e:
        print(f"{e}: the well labels cannot be read from the concentration and cell_type fields")
        labels = np.linspace(0, number_of_wells - 1, number_of_wells, dtype=str)

    return (labels)
2120
-
2121
-
2122
def _extract_channels_from_config(config):
    """Extract channel names and indices from an experiment configuration.

    Parameters
    ----------
    config : str
        Path to the experiment config (.ini) file.

    Returns
    -------
    tuple
        Two numpy arrays `(channel_names, channel_indices)`, both ordered by
        channel index. Empty arrays when the "Channels" section is absent.

    Examples
    --------
    >>> channels, indices = _extract_channels_from_config("path/to/config_file.ini")
    >>> print(channels)
    # array(['brightfield_channel', 'adhesion_channel', 'fitc_channel',
    #        'cy5_channel'], dtype='<U19')
    >>> print(indices)
    # array([0, 1, 2, 3])
    """

    channel_names = []
    channel_indices = []
    try:
        # parse the section once (the original re-read the file per option)
        fields = config_section_to_dict(config, "Channels")
        for name, value in fields.items():
            try:
                idx = int(value)  # skip options without a valid channel index
            except (TypeError, ValueError):
                continue
            channel_names.append(name)
            channel_indices.append(idx)
    except Exception:
        pass

    channel_indices = np.array(channel_indices)
    channel_names = np.array(channel_names)
    reorder = np.argsort(channel_indices)

    return channel_names[reorder], channel_indices[reorder]
2171
-
2172
-
2173
def extract_experiment_channels(experiment):
    """Return the channel names and indices of an experiment project.

    Parameters
    ----------
    experiment : str
        Path to the directory of the experiment project.

    Returns
    -------
    tuple
        Two numpy arrays `(channel_names, channel_indices)`, ordered by
        channel index, as read from the project's configuration file.

    Examples
    --------
    >>> channels, indices = extract_experiment_channels("path/to/my_experiment")
    >>> print(channels)
    # array(['brightfield_channel', 'adhesion_channel', 'fitc_channel',
    #        'cy5_channel'], dtype='<U19')
    >>> print(indices)
    # array([0, 1, 2, 3])
    """

    return _extract_channels_from_config(get_config(experiment))
2203
-
2204
-
2205
def get_software_location() -> str:
    """Get the installation folder of celldetective.

    Returns
    -------
    str
        Path to the celldetective installation folder, i.e. the parent of the
        directory containing this module.
    """

    # os.path.split(p)[0] is exactly os.path.dirname(p); the original's
    # rf"" wrapper around the result added nothing and was dropped.
    return os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
2217
-
2218
def remove_trajectory_measurements(trajectories, column_labels={'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}):
    """Clear a measurement table, while keeping the tracking information.

    Parameters
    ----------
    trajectories : pandas.DataFrame
        Measurement table where each row is a cell at a timepoint and each
        column a tracking feature or measurement.
    column_labels : dict, optional
        Names of the track/time/x/y columns. Read-only (never mutated), so
        the shared default dict is safe.

    Returns
    -------
    pandas.DataFrame
        A copy of `trajectories` restricted to the known tracking columns.

    Examples
    --------
    >>> trajectories_df = pd.DataFrame({
    ...     'TRACK_ID': [1, 1, 2],
    ...     'FRAME': [0, 1, 0],
    ...     'POSITION_X': [100, 105, 200],
    ...     'POSITION_Y': [150, 155, 250],
    ...     'area': [10, 100, 100],  # measurement column, dropped
    ... })
    >>> remove_trajectory_measurements(trajectories_df).columns.tolist()
    ['TRACK_ID', 'FRAME', 'POSITION_X', 'POSITION_Y']
    """

    tracks = trajectories.copy()

    columns_to_keep = [column_labels['track'], column_labels['time'],
                       column_labels['x'], column_labels['y'],
                       column_labels['x'] + '_um', column_labels['y'] + '_um',
                       'class_id', 't', 'state', 'generation', 'root', 'parent',
                       'ID', 't0', 'class', 'status', 'class_color', 'status_color',
                       'class_firstdetection', 't_firstdetection',
                       'status_firstdetection', 'velocity']

    # The original pre-filtered with list.remove() inside a loop over the same
    # list — a modify-while-iterating bug that skips entries. It was harmless
    # only because this comprehension redoes the filtering correctly, so the
    # buggy loop is removed and the comprehension kept.
    keep = [c for c in columns_to_keep if c in tracks.columns]
    tracks = tracks[keep]

    return tracks
2268
-
2269
-
2270
def color_from_status(status, recently_modified=False):
    """Map a status code to a matplotlib color name.

    Recently modified entries use a brighter variant of each color; any
    unrecognized status maps to black ('k').
    """

    base_palette = {0: 'tab:blue', 1: 'tab:red', 2: 'yellow'}
    modified_palette = {0: 'tab:cyan', 1: 'tab:orange', 2: 'tab:olive'}

    palette = modified_palette if recently_modified else base_palette
    return palette.get(status, 'k')
2290
-
2291
def color_from_class(cclass, recently_modified=False):
    """Map a class code to a matplotlib color name.

    Note the mapping is the inverse of `color_from_status` for codes 0 and 1.
    Recently modified entries use a brighter variant; unrecognized codes map
    to black ('k').
    """

    base_palette = {0: 'tab:red', 1: 'tab:blue', 2: 'yellow'}
    modified_palette = {0: 'tab:orange', 1: 'tab:cyan', 2: 'tab:olive'}

    palette = modified_palette if recently_modified else base_palette
    return palette.get(cclass, 'k')
2311
-
2312
def random_fliprot(img, mask):
    """Apply one random axis permutation and random flips to an image/mask pair.

    The same transpose and flips are applied to both arrays so they stay
    registered; any trailing axes of `img` beyond `mask.ndim` (e.g. channels
    in YXC layout) keep their positions.

    Parameters
    ----------
    img : ndarray
        Input image; its leading axes correspond to those of `mask`.
    mask : ndarray
        Mask transformed identically to `img`; must not have more dimensions
        than `img`.

    Returns
    -------
    tuple of ndarray
        The transformed (img, mask).

    Raises
    ------
    AssertionError
        If `mask` has more dimensions than `img`.
    """

    assert img.ndim >= mask.ndim
    spatial_axes = tuple(range(mask.ndim))
    # draw one random ordering of the shared (spatial) axes
    order = tuple(np.random.permutation(spatial_axes))
    img = img.transpose(order + tuple(range(mask.ndim, img.ndim)))
    mask = mask.transpose(order)
    # flip each spatial axis independently with probability 1/2
    for axis in spatial_axes:
        if np.random.rand() > 0.5:
            img = np.flip(img, axis=axis)
            mask = np.flip(mask, axis=axis)
    return img, mask
2351
-
2352
- # def random_intensity_change(img):
2353
- # img[img!=0.] = img[img!=0.]*np.random.uniform(0.3,2)
2354
- # img[img!=0.] += np.random.uniform(-0.2,0.2)
2355
- # return img
2356
-
2357
def random_shift(image, mask, max_shift_amplitude=0.1):
    """Translate an image and its mask by the same random X/Y offset.

    Offsets are drawn independently for each axis, up to
    ``max_shift_amplitude`` times the image's first dimension, with a
    random sign. The image is shifted with cubic interpolation and the
    mask with nearest-neighbour (order 0); both use constant zero padding,
    so zero borders appear on the trailing edges.

    Parameters
    ----------
    image : ndarray
        Image in YXC format (height, width, channels).
    mask : ndarray
        Mask shifted with exactly the same offset as `image`.
    max_shift_amplitude : float, optional
        Maximum shift as a fraction of the image size. Default is 0.1.

    Returns
    -------
    tuple of ndarray
        The shifted ``(image, mask)`` pair.
    """

    side = image.shape[0]
    max_shift = side * max_shift_amplitude

    dx = random.choice(np.arange(max_shift))
    if np.random.random() > 0.5:
        dx = -dx

    dy = random.choice(np.arange(max_shift))
    if np.random.random() > 0.5:
        dy = -dy

    image = shift(image, [dx, dy, 0], output=np.float32, order=3, mode="constant", cval=0.0)
    mask = shift(mask, [dx, dy], order=0, mode="constant", cval=0.0)

    return image, mask
2406
-
2407
-
2408
def blur(x, max_sigma=4.0):
    """Apply a Gaussian blur of random strength to a multichannel image.

    The blur sigma is drawn uniformly in ``[0, max_sigma)``. Pixels that
    were exactly zero before blurring are reset to zero afterwards, so
    masked/background regions survive the smoothing unchanged.

    Parameters
    ----------
    x : ndarray
        Input image in YXC format (channels last).
    max_sigma : float, optional
        Upper bound on the Gaussian standard deviation. Default is 4.0.

    Returns
    -------
    ndarray
        Blurred image with the same shape as `x`.
    """

    sigma = np.random.random() * max_sigma
    zi, zj, zc = np.where(x == 0.)
    blurred = gaussian(x, sigma, channel_axis=-1, preserve_range=True)
    # restore exact zeros clobbered by the convolution
    blurred[zi, zj, zc] = 0.

    return blurred
2444
-
2445
def noise(x, apply_probability=0.5, clip_option=False):
    """Apply random noise models to each channel of a multichannel image.

    For every channel, the noise modes 'gaussian', 'localvar', 'poisson'
    and 'speckle' are tried in a random order, each applied independently
    with probability ``apply_probability``. Zero-valued pixels of the input
    are restored to zero at the end, preserving background/masked regions.

    Parameters
    ----------
    x : ndarray
        Input image with channels last (height, width, channels).
    apply_probability : float, optional
        Probability of applying each noise mode to each channel. Default 0.5.
    clip_option : bool, optional
        Passed to ``random_noise`` to clip the corrupted data into the valid
        range. Default is False.

    Returns
    -------
    ndarray
        Float copy of `x` with noise applied; same shape as the input.
    """

    x_noise = x.astype(float).copy()
    loc_i, loc_j, loc_c = np.where(x_noise == 0.)
    options = ['gaussian', 'localvar', 'poisson', 'speckle']

    for k in range(x_noise.shape[-1]):
        mode_order = random.sample(options, len(options))
        for m in mode_order:
            p = np.random.random()
            if p <= apply_probability:
                # Best-effort: some modes fail on particular value ranges
                # (e.g. negative intensities); skip those silently.
                # Fixed: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit.
                try:
                    x_noise[:, :, k] = random_noise(x_noise[:, :, k], mode=m, clip=clip_option)
                except Exception:
                    pass

    # restore exact zeros so background stays background
    x_noise[loc_i, loc_j, loc_c] = 0.

    return x_noise
2505
-
2506
-
2507
-
2508
def augmenter(x, y, flip=True, gauss_blur=True, noise_option=True, shift=True,
              channel_extinction=True, extinction_probability=0.1, clip=False, max_sigma_blur=4,
              apply_noise_probability=0.5, augment_probability=0.9):
    """Randomly augment an image/mask pair for segmentation training.

    With probability ``augment_probability``, applies (in order) random
    flip/rotation, Gaussian blur, additive noise, random shift and random
    channel extinction. Each individual transform can be toggled with its
    flag. When nothing is applied, the inputs are returned untouched.

    Parameters
    ----------
    x : ndarray
        Image in YXC format (height, width, channels).
    y : ndarray
        Corresponding mask, same spatial dimensions as `x`.
    flip : bool, optional
        Apply random flips/rotations (default True).
    gauss_blur : bool, optional
        Apply a random Gaussian blur (default True).
    noise_option : bool, optional
        Add random noise (default True).
    shift : bool, optional
        Apply a random translation (default True).
    channel_extinction : bool, optional
        Randomly zero out whole channels, except the first (default True).
    extinction_probability : float, optional
        Per-channel probability of extinction (default 0.1).
    clip : bool, optional
        Clip noised values into the valid range (default False).
    max_sigma_blur : float, optional
        Maximum Gaussian blur sigma (default 4).
    apply_noise_probability : float, optional
        Per-mode probability used by :func:`noise` (default 0.5).
    augment_probability : float, optional
        Overall probability of augmenting at all (default 0.9).

    Returns
    -------
    tuple
        The (possibly augmented) ``(x, y)`` pair.

    Raises
    ------
    AssertionError
        If `extinction_probability` exceeds 1.
    """

    if random.random() <= augment_probability:

        if flip:
            x, y = random_fliprot(x, y)

        if gauss_blur:
            x = blur(x, max_sigma=max_sigma_blur)

        if noise_option:
            x = noise(x, apply_probability=apply_noise_probability, clip_option=clip)

        if shift:
            x, y = random_shift(x, y)

        if channel_extinction:
            assert extinction_probability <= 1.,'The extinction probability must be a number between 0 and 1.'
            off = [np.random.random() < extinction_probability for _ in range(x.shape[-1])]
            off[0] = False  # never extinguish the primary channel
            x[:, :, np.array(off, dtype=bool)] = 0.

    return x, y
2596
-
2597
def normalize_per_channel(X, normalization_percentile_mode=True, normalization_values=[0.1,99.99], normalization_clipping=False):
    """Normalize each channel of every image in a list, independently.

    Per channel, the normalization bounds are either percentiles of the
    non-zero pixel values (percentile mode) or fixed min/max values. Scalar
    options are broadcast to all channels. All-zero channels are passed
    through as zeros, and pixels that were exactly zero in the input are
    reset to zero after normalization.

    Parameters
    ----------
    X : list of ndarray
        Multichannel images with shape (height, width, channels).
    normalization_percentile_mode : bool or list of bool, optional
        Per-channel flag: interpret `normalization_values` as percentiles
        (True) or as fixed min/max values (False). Default True.
    normalization_values : list of two floats or list of such lists, optional
        Per-channel ``[lower, upper]`` bounds. Default ``[0.1, 99.99]``.
    normalization_clipping : bool or list of bool, optional
        Per-channel flag to clip normalized values. Default False.

    Returns
    -------
    list of ndarray
        The normalized images, as float32 arrays.

    Raises
    ------
    AssertionError
        If the images have no channel axis, or if the per-channel option
        lists do not match the number of channels.
    """

    assert X[0].ndim==3,'Channel axis does not exist. Abort.'
    n_channels = X[0].shape[-1]

    # broadcast scalar options to one entry per channel
    if isinstance(normalization_percentile_mode, bool):
        normalization_percentile_mode = [normalization_percentile_mode] * n_channels
    if isinstance(normalization_clipping, bool):
        normalization_clipping = [normalization_clipping] * n_channels
    if len(normalization_values) == 2 and not isinstance(normalization_values[0], list):
        normalization_values = [normalization_values] * n_channels

    assert len(normalization_values) == n_channels
    assert len(normalization_clipping) == n_channels
    assert len(normalization_percentile_mode) == n_channels

    normalized = []
    for sample in X:
        frame = sample.copy()
        zi, zj, zc = np.where(frame == 0.)
        out = np.zeros_like(frame, dtype=np.float32)

        for k in range(frame.shape[-1]):
            chan = frame[:, :, k].copy()
            if np.all(chan.flatten() == 0):
                # empty channel: keep as zeros
                out[:, :, k] = 0.
                continue
            if normalization_percentile_mode[k]:
                nonzero = chan[chan != 0.].flatten()
                min_val = np.nanpercentile(nonzero, normalization_values[k][0])
                max_val = np.nanpercentile(nonzero, normalization_values[k][1])
            else:
                min_val = normalization_values[k][0]
                max_val = normalization_values[k][1]
            out[:, :, k] = normalize_mi_ma(chan.astype(np.float32).copy(), min_val, max_val,
                                           clip=normalization_clipping[k], eps=1e-20, dtype=np.float32)

        # restore exact zeros of the original frame
        out[zi, zj, zc] = 0.
        normalized.append(out.copy())

    return normalized
2682
-
2683
def load_image_dataset(datasets, channels, train_spatial_calibration=None, mask_suffix='labelled'):

    """
    Load image/mask pairs from annotation datasets, rescaled to a common calibration.

    For each dataset folder, every ``*.tif`` image with a matching
    ``*_{mask_suffix}.tif`` mask and a ``*.json`` config (listing the image's
    channels and spatial calibration) is loaded. Requested channels missing
    from an image are passed as black (all-zero) frames; images sharing no
    channel with `channels` are skipped. If the image calibration differs
    from `train_spatial_calibration`, image and mask are zoomed to match it.

    Parameters
    ----------
    datasets : list of str
        Paths to the dataset folders containing images and masks.
    channels : str or list of str
        Channel name(s) to load, in the order expected downstream.
    train_spatial_calibration : float, optional
        Target spatial calibration (e.g. micrometers per pixel). Images with
        a different calibration are rescaled to match it. Default is None.
    mask_suffix : str, optional
        Suffix identifying mask files. Default is 'labelled'.

    Returns
    -------
    X, Y, files : tuple of lists
        Images (YXC float-compatible arrays), corresponding masks (YX) and
        the source image paths.

    Raises
    ------
    AssertionError
        If `channels` is not a list, if a loaded image does not end up 3D,
        or if the number of images does not match the number of masks.
    """

    if isinstance(channels, str):
        channels = [channels]

    assert isinstance(channels, list),'Please provide a list of channels. Abort.'

    X = []; Y = []; files = [];

    for ds in datasets:
        print(f'Loading data from dataset {ds}...')
        if not ds.endswith(os.sep):
            ds+=os.sep
        # candidate images = all tifs that are not themselves mask files
        img_paths = list(set(glob(ds+'*.tif')) - set(glob(ds+f'*_{mask_suffix}.tif')))
        for im in img_paths:
            print(f'{im=}')
            mask_path = os.sep.join([os.path.split(im)[0],os.path.split(im)[-1].replace('.tif', f'_{mask_suffix}.tif')])
            if os.path.exists(mask_path):
                # load image and mask
                image = imread(im)
                if image.ndim==2:
                    # single-channel image -> add a leading channel axis (CYX)
                    image = image[np.newaxis]
                if image.ndim>3:
                    print('Invalid image shape, skipping')
                    continue
                mask = imread(mask_path)
                config_path = im.replace('.tif','.json')
                if os.path.exists(config_path):
                    # Load config
                    with open(config_path, 'r') as f:
                        config = json.load(f)

                    existing_channels = config['channels']
                    intersection = list(set(list(channels)) & set(list(existing_channels)))
                    print(f'{existing_channels=} {intersection=}')
                    if len(intersection)==0:
                        print('Channels could not be found in the config... Skipping image.')
                        continue
                    else:
                        # map each requested channel to its index in this image;
                        # NaN marks a channel absent from the image
                        ch_idx = []
                        for c in channels:
                            if c in existing_channels:
                                idx = existing_channels.index(c)
                                ch_idx.append(idx)
                            else:
                                # For None or missing channel pass black frame
                                ch_idx.append(np.nan)
                        im_calib = config['spatial_calibration']

                        # NaN-safe indexing: missing channels temporarily point
                        # to channel 0 (NaN != NaN identifies them)...
                        ch_idx = np.array(ch_idx)
                        ch_idx_safe = np.copy(ch_idx)
                        ch_idx_safe[ch_idx_safe!=ch_idx_safe] = 0
                        ch_idx_safe = ch_idx_safe.astype(int)

                        image = image[ch_idx_safe]
                        # ...and are blanked out here to become black frames
                        image[np.where(ch_idx!=ch_idx)[0],:,:] = 0

                        # CYX -> YXC
                        image = np.moveaxis(image,0,-1)
                        assert image.ndim==3,'The image has a wrong number of dimensions. Abort.'

                        if im_calib != train_spatial_calibration:
                            # rescale to the training calibration: cubic for
                            # intensities, nearest-neighbour for the label mask
                            factor = im_calib / train_spatial_calibration
                            image = np.moveaxis([zoom(image[:,:,c].astype(float).copy(), [factor,factor], order=3, prefilter=False) for c in range(image.shape[-1])],0,-1)
                            mask = zoom(mask, [factor,factor], order=0)

                        X.append(image)
                        Y.append(mask)

                        files.append(im)

    assert len(X)==len(Y),'The number of images does not match with the number of masks... Abort.'
    return X,Y,files
2811
-
2812
-
2813
def download_url_to_file(url, dst, progress=True):
    r"""Download object at the given URL to a local path.
    Thanks to torch, slightly modified, from Cellpose
    Args:
        url (string): URL of the object to download
        dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
        progress (bool, optional): whether or not to display a progress bar to stderr
            Default: True
    """
    file_size = None
    import ssl
    # NOTE(review): this disables HTTPS certificate verification for the whole
    # process, not just this download — presumably needed for the model
    # mirror; confirm before reusing elsewhere.
    ssl._create_default_https_context = ssl._create_unverified_context
    u = urlopen(url)
    meta = u.info()
    # Content-Length lookup compatible with both old and new http message APIs
    if hasattr(meta, 'getheaders'):
        content_length = meta.getheaders("Content-Length")
    else:
        content_length = meta.get_all("Content-Length")
    if content_length is not None and len(content_length) > 0:
        file_size = int(content_length[0])
    # We deliberately save it in a temp file and move it after to avoid
    # leaving a partial file at `dst` on failure
    dst = os.path.expanduser(dst)
    dst_dir = os.path.dirname(dst)
    # temp file in the destination directory so the final move is atomic-ish
    # (same filesystem)
    f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
    try:
        with tqdm(total=file_size, disable=not progress,
                unit='B', unit_scale=True, unit_divisor=1024) as pbar:
            while True:
                buffer = u.read(8192)
                if len(buffer) == 0:
                    break
                f.write(buffer)
                pbar.update(len(buffer))
        f.close()
        shutil.move(f.name, dst)
    finally:
        # on success the temp file has been moved away; this cleanup only
        # removes leftovers from a failed download
        f.close()
        remove_file_if_exists(f.name)
2851
-
2852
def get_zenodo_files(cat=None):
    """List the archives available on the celldetective Zenodo record.

    Reads the bundled ``links/zenodo.json`` record description and infers a
    destination category (models/datasets/demos sub-folder) for each archive
    from its name prefix.

    Parameters
    ----------
    cat : str, optional
        If given, return only the file names belonging to that category;
        an unknown category yields an empty list. If None (default), return
        the tuple ``(all_files_short, categories)``.

    Returns
    -------
    list or tuple of lists
        Filtered file names, or all names with their categories.
    """

    links_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], "celldetective", "links", "zenodo.json"])
    with open(links_path, "r") as f:
        record = json.load(f)

    archive_names = list(record['files']['entries'].keys())
    short_names = [name.replace(".zip", "") for name in archive_names]

    def _category(name):
        # Infer the target sub-folder from the archive name prefix.
        if name.startswith(('CP', 'SD')):
            return os.sep.join(['models', 'segmentation_generic'])
        if name.startswith(('MCF7', 'mcf7')):
            return os.sep.join(['models', 'segmentation_targets'])
        if name.startswith(('primNK', 'lymphocytes')):
            return os.sep.join(['models', 'segmentation_effectors'])
        if name.startswith('demo'):
            return 'demos'
        if name.startswith('db-si'):
            return os.sep.join(['datasets', 'signal_annotations'])
        if name.startswith('db'):
            return os.sep.join(['datasets', 'segmentation_annotations'])
        return os.sep.join(['models', 'signal_detection'])

    categories = [_category(name) for name in short_names]

    if cat is None:
        return short_names, categories

    known_categories = [
        os.sep.join(['models', 'segmentation_generic']),
        os.sep.join(['models', 'segmentation_targets']),
        os.sep.join(['models', 'segmentation_effectors']),
        'demos',
        os.sep.join(['datasets', 'signal_annotations']),
        os.sep.join(['datasets', 'segmentation_annotations']),
        os.sep.join(['models', 'signal_detection']),
    ]
    if cat not in known_categories:
        return []

    category_arr = np.array(categories)
    name_arr = np.array(short_names)
    return list(name_arr[np.where(category_arr == cat)[0]])
2889
-
2890
def download_zenodo_file(file, output_dir):
    """Download a named archive from the celldetective Zenodo record and
    extract it into ``output_dir``.

    Parameters
    ----------
    file : str
        Archive name without the ``.zip`` extension, as listed in the
        bundled ``links/zenodo.json``.
    output_dir : str
        Destination directory for the extracted content.
    """

    # Resolve the bundled JSON description of the Zenodo record.
    zenodo_json = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0],"celldetective", "links", "zenodo.json"])
    with open(zenodo_json,"r") as f:
        zenodo_json = json.load(f)
    all_files = list(zenodo_json['files']['entries'].keys())
    all_files_short = [f.replace(".zip","") for f in all_files]
    # Build the direct (non-API) download URL for the requested archive.
    zenodo_url = zenodo_json['links']['files'].replace('api/','')
    full_links = ["/".join([zenodo_url, f]) for f in all_files]
    index = all_files_short.index(file)
    zip_url = full_links[index]

    # Download to a temporary zip, extract, then clean it up.
    path_to_zip_file = os.sep.join([output_dir, 'temp.zip'])
    download_url_to_file(fr"{zip_url}",path_to_zip_file)
    with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
        zip_ref.extractall(output_dir)

    # Rename the extracted payload file to match the archive/folder name
    # (skipped for demo archives).
    # NOTE(review): the glob character classes ([!.json]...) match single
    # characters, not suffixes, so this pattern only approximates
    # "extension-less file" — confirm against the actual archive layouts.
    file_to_rename = glob(os.sep.join([output_dir,file,"*[!.json][!.png][!.h5][!.csv][!.npy][!.tif][!.ini]"]))
    if len(file_to_rename)>0 and not file_to_rename[0].endswith(os.sep) and not file.startswith('demo'):
        os.rename(file_to_rename[0], os.sep.join([output_dir,file,file]))

    os.remove(path_to_zip_file)
2912
-
2913
def interpolate_nan(img, method='nearest'):
    """Fill NaN pixels of a single-channel 2D array by interpolation.

    All-zero inputs and NaN-free inputs are returned unchanged (same
    object). Otherwise the valid pixels are used as scattered data points
    and the whole grid is re-interpolated with ``scipy.interpolate.griddata``.

    Parameters
    ----------
    img : ndarray
        2D single-channel array, possibly containing NaNs.
    method : str, optional
        Interpolation method passed to ``griddata`` ('nearest' by default).

    Returns
    -------
    ndarray
        The input itself when nothing needs fixing, or a new interpolated
        grid of the same shape.
    """

    if np.all(img == 0):
        return img

    if not np.any(np.isnan(img)):
        # nothing to interpolate
        return img

    valid = ~np.isnan(img)
    x_grid, y_grid = np.meshgrid(np.arange(img.shape[1]), np.arange(img.shape[0]))
    points = np.stack([x_grid[valid].ravel(), y_grid[valid].ravel()], axis=1)
    values = img[valid].ravel()
    return griddata(points, values, (x_grid, y_grid), method=method)
2934
-
2935
-
2936
def interpolate_nan_multichannel(frames):
    """Interpolate NaNs independently in each channel of a YXC stack."""
    filled = [interpolate_nan(frames[:, :, c].copy()) for c in range(frames.shape[-1])]
    return np.stack(filled, axis=-1)
2939
-
2940
def collapse_trajectories_by_status(df, status=None, projection='mean', population='effectors', groupby_columns=['position','TRACK_ID']):
    """Collapse per-frame trajectories into one row per (track, status value).

    For each distinct value of the `status` column, the rows of every group
    defined by `groupby_columns` are reduced with the `projection` GroupBy
    method (e.g. 'mean'). Known metadata columns are carried over using
    their first unique value, the number of frames spent in each state is
    stored in ``duration_in_state``, and identifying columns are moved to
    the front of the output table.

    Parameters
    ----------
    df : pandas.DataFrame
        Timeseries table; must contain `status`, 'FRAME' and the
        `groupby_columns`.
    status : str
        Name of the status column to collapse on. Returns None (with a
        printed message) when missing or None.
    projection : str, optional
        Name of the pandas GroupBy reduction to apply. Default 'mean'.
    population : str, optional
        'pairs' reorders pair-specific columns to the front; anything else
        uses the track layout. Default 'effectors'.
    groupby_columns : list of str, optional
        Columns defining one trajectory. Default ['position', 'TRACK_ID'].

    Returns
    -------
    pandas.DataFrame or None
        The collapsed, sorted table, or None for an invalid status.
    """

    static_columns = ['well_index', 'well_name', 'pos_name', 'position', 'well', 'status', 't0', 'class','cell_type','concentration', 'antibody', 'pharmaceutical_agent','TRACK_ID','position', 'neighbor_population', 'reference_population', 'NEIGHBOR_ID', 'REFERENCE_ID', 'FRAME']

    if status is None or status not in list(df.columns):
        print('invalid status selection...')
        return None

    df = df.dropna(subset=status, ignore_index=True)
    unique_statuses = np.unique(df[status].to_numpy())

    df_sections = []
    for s in unique_statuses:
        subtab = df.loc[df[status] == s, :]
        grouped = subtab.groupby(groupby_columns)
        op = getattr(grouped, projection)
        # BUG FIX: the projection was previously invoked as
        # op(subtab.groupby(...)), passing the GroupBy object as the first
        # positional parameter (numeric_only for mean/median). Call the bound
        # method properly; prefer numeric_only=True (matching the evident
        # intent of projecting numeric features) and fall back for
        # reductions that do not accept it (e.g. size).
        try:
            subtab_projected = op(numeric_only=True)
        except TypeError:
            subtab_projected = op()
        frame_duration = grouped.size().to_numpy()
        for c in static_columns:
            # carry over metadata columns by their first unique value;
            # columns absent from this table are simply skipped
            try:
                subtab_projected[c] = grouped[c].apply(lambda x: x.unique()[0])
            except Exception as e:
                print(e)
        subtab_projected['duration_in_state'] = frame_duration
        df_sections.append(subtab_projected)

    group_table = pd.concat(df_sections, axis=0, ignore_index=True)

    # move identifying columns to the front of the table
    if population == 'pairs':
        front_columns = ['duration_in_state', status, 'neighbor_population', 'reference_population', 'NEIGHBOR_ID', 'REFERENCE_ID']
    else:
        front_columns = ['duration_in_state', status, 'TRACK_ID']
    for col in front_columns:
        first_column = group_table.pop(col)
        group_table.insert(0, col, first_column)

    group_table.pop('FRAME')
    group_table = group_table.sort_values(by=groupby_columns + [status], ignore_index=True)
    group_table = group_table.reset_index(drop=True)

    return group_table
2981
-
2982
def step_function(t: Union[np.ndarray, List], t_shift: float, dt: float) -> np.ndarray:

    """
    Computes a smooth step using the logistic sigmoid function.

    .. math::
        f(t) = \\frac{1}{1 + \\exp{\\left( -\\frac{t - t_{shift}}{dt} \\right)}}

    where `t_shift` is the transition point and `dt` controls the steepness
    of the transition.

    Parameters
    ----------
    t : array_like
        Input values for which the step function is computed. Fixed: plain
        Python lists are now accepted (previously ``t - t_shift`` raised a
        TypeError for lists despite the declared type hint).
    t_shift : float
        The point in the `t` domain where the transition occurs.
    dt : float
        Steepness of the transition: smaller values give a sharper step,
        larger values a smoother one.

    Returns
    -------
    ndarray
        The sigmoid values, same shape as `t`.

    Examples
    --------
    >>> import numpy as np
    >>> t = np.array([0, 1, 2, 3, 4, 5])
    >>> step_function(t, 2, 1)
    array([0.26894142, 0.37754067, 0.5       , 0.62245933, 0.73105858, 0.81757448])
    """

    # np.asarray generalizes the input to lists/tuples/scalars while leaving
    # ndarray inputs untouched (no copy)
    t = np.asarray(t)
    return 1 / (1 + np.exp(-(t - t_shift) / dt))
3022
-
3023
-
3024
- def test_2samp_generic(data: pd.DataFrame, feature: Optional[str] = None, groupby_cols: Optional[Union[str,List[str]]] = None, method="ks_2samp", *args, **kwargs) -> pd.DataFrame:
3025
-
3026
- """
3027
- Performs pairwise statistical tests between groups of data, comparing a specified feature using a chosen method.
3028
-
3029
- The function applies two-sample statistical tests, such as the Kolmogorov-Smirnov (KS) test or Cliff's Delta,
3030
- to compare distributions of a given feature across groups defined by `groupby_cols`. It returns the test results
3031
- in a pivot table format with each group's pairwise comparison.
3032
-
3033
- Parameters
3034
- ----------
3035
- data : pandas.DataFrame
3036
- The input dataset containing the feature to be tested.
3037
- feature : str
3038
- The name of the column representing the feature to compare between groups.
3039
- groupby_cols : list or str
3040
- The column(s) used to group the data. These columns define the groups that will be compared pairwise.
3041
- method : str, optional, default="ks_2samp"
3042
- The statistical test to use. Options:
3043
- - "ks_2samp": Two-sample Kolmogorov-Smirnov test (default).
3044
- - "cliffs_delta": Cliff's Delta for effect size between two distributions.
3045
- *args, **kwargs :
3046
- Additional arguments and keyword arguments for the selected test method.
3047
-
3048
- Returns
3049
- -------
3050
- pivot : pandas.DataFrame
3051
- A pivot table containing the pairwise test results (p-values or effect sizes).
3052
- The rows and columns represent the unique groups defined by `groupby_cols`,
3053
- and the values represent the test result (e.g., p-values or effect sizes) between each group.
3054
-
3055
- Notes
3056
- -----
3057
- - The function compares all unique pairwise combinations of the groups based on `groupby_cols`.
3058
- - For the "ks_2samp" method, the test compares the distributions using the Kolmogorov-Smirnov test.
3059
- - For the "cliffs_delta" method, the function calculates the effect size between two distributions.
3060
- - The results are returned in a symmetric pivot table where each cell represents the test result for the corresponding group pair.
3061
-
3062
- """
3063
-
3064
-
3065
- assert groupby_cols is not None,"Please set a valid groupby_cols..."
3066
- assert feature is not None,"Please set a feature to test..."
3067
-
3068
- results = []
3069
-
3070
- for lbl1,group1 in data.dropna(subset=feature).groupby(groupby_cols):
3071
- for lbl2,group2 in data.dropna(subset=feature).groupby(groupby_cols):
3072
-
3073
- dist1 = group1[feature].values
3074
- dist2 = group2[feature].values
3075
- if method=="ks_2samp":
3076
- test = ks_2samp(list(dist1),list(dist2), alternative='less', mode='auto', *args, **kwargs)
3077
- val = test.pvalue
3078
- elif method=="cliffs_delta":
3079
- test = cliffs_delta(list(dist1),list(dist2), *args, **kwargs)
3080
- val = test[0]
3081
-
3082
- results.append({"cdt1": lbl1, "cdt2": lbl2, "value": val})
3083
-
3084
- results = pd.DataFrame(results)
3085
- results['cdt1'] = results['cdt1'].astype(str)
3086
- results['cdt2'] = results['cdt2'].astype(str)
3087
-
3088
- pivot = results.pivot(index='cdt1', columns='cdt2', values='value')
3089
- pivot.reset_index(inplace=True)
3090
- pivot.columns.name = None
3091
- pivot.set_index("cdt1",drop=True, inplace=True)
3092
- pivot.index.name = None
3093
-
3094
- return pivot
3095
-
3096
def pretty_table(dct: dict):
    """Print a dict as a one-row table: keys as headers, values as the row."""
    keys = list(dct.keys())
    table = PrettyTable()
    for key in keys:
        table.add_column(str(key), [])
    table.add_row([dct.get(key, "") for key in keys])
    print(table)
3102
-
3103
def remove_file_if_exists(file: Union[str, Path]):
    """Delete ``file`` if it exists; print (do not raise) any removal error."""
    if not os.path.exists(file):
        return
    try:
        os.remove(file)
    except Exception as e:
        # best-effort cleanup: report the failure and continue
        print(e)