celldetective 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. celldetective/__init__.py +2 -0
  2. celldetective/__main__.py +432 -0
  3. celldetective/datasets/segmentation_annotations/blank +0 -0
  4. celldetective/datasets/signal_annotations/blank +0 -0
  5. celldetective/events.py +149 -0
  6. celldetective/extra_properties.py +100 -0
  7. celldetective/filters.py +89 -0
  8. celldetective/gui/__init__.py +20 -0
  9. celldetective/gui/about.py +44 -0
  10. celldetective/gui/analyze_block.py +563 -0
  11. celldetective/gui/btrack_options.py +898 -0
  12. celldetective/gui/classifier_widget.py +386 -0
  13. celldetective/gui/configure_new_exp.py +532 -0
  14. celldetective/gui/control_panel.py +438 -0
  15. celldetective/gui/gui_utils.py +495 -0
  16. celldetective/gui/json_readers.py +113 -0
  17. celldetective/gui/measurement_options.py +1425 -0
  18. celldetective/gui/neighborhood_options.py +452 -0
  19. celldetective/gui/plot_signals_ui.py +1042 -0
  20. celldetective/gui/process_block.py +1055 -0
  21. celldetective/gui/retrain_segmentation_model_options.py +706 -0
  22. celldetective/gui/retrain_signal_model_options.py +643 -0
  23. celldetective/gui/seg_model_loader.py +460 -0
  24. celldetective/gui/signal_annotator.py +2388 -0
  25. celldetective/gui/signal_annotator_options.py +340 -0
  26. celldetective/gui/styles.py +217 -0
  27. celldetective/gui/survival_ui.py +903 -0
  28. celldetective/gui/tableUI.py +608 -0
  29. celldetective/gui/thresholds_gui.py +1300 -0
  30. celldetective/icons/logo-large.png +0 -0
  31. celldetective/icons/logo.png +0 -0
  32. celldetective/icons/signals_icon.png +0 -0
  33. celldetective/icons/splash-test.png +0 -0
  34. celldetective/icons/splash.png +0 -0
  35. celldetective/icons/splash0.png +0 -0
  36. celldetective/icons/survival2.png +0 -0
  37. celldetective/icons/vignette_signals2.png +0 -0
  38. celldetective/icons/vignette_signals2.svg +114 -0
  39. celldetective/io.py +2050 -0
  40. celldetective/links/zenodo.json +561 -0
  41. celldetective/measure.py +1258 -0
  42. celldetective/models/segmentation_effectors/blank +0 -0
  43. celldetective/models/segmentation_generic/blank +0 -0
  44. celldetective/models/segmentation_targets/blank +0 -0
  45. celldetective/models/signal_detection/blank +0 -0
  46. celldetective/models/tracking_configs/mcf7.json +68 -0
  47. celldetective/models/tracking_configs/ricm.json +203 -0
  48. celldetective/models/tracking_configs/ricm2.json +203 -0
  49. celldetective/neighborhood.py +717 -0
  50. celldetective/scripts/analyze_signals.py +51 -0
  51. celldetective/scripts/measure_cells.py +275 -0
  52. celldetective/scripts/segment_cells.py +212 -0
  53. celldetective/scripts/segment_cells_thresholds.py +140 -0
  54. celldetective/scripts/track_cells.py +206 -0
  55. celldetective/scripts/train_segmentation_model.py +246 -0
  56. celldetective/scripts/train_signal_model.py +49 -0
  57. celldetective/segmentation.py +712 -0
  58. celldetective/signals.py +2826 -0
  59. celldetective/tracking.py +974 -0
  60. celldetective/utils.py +1681 -0
  61. celldetective-1.0.2.dist-info/LICENSE +674 -0
  62. celldetective-1.0.2.dist-info/METADATA +192 -0
  63. celldetective-1.0.2.dist-info/RECORD +66 -0
  64. celldetective-1.0.2.dist-info/WHEEL +5 -0
  65. celldetective-1.0.2.dist-info/entry_points.txt +2 -0
  66. celldetective-1.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,712 @@
1
+ """
2
+ Segmentation module
3
+ """
4
+ import json
5
+ import os
6
+ from .io import locate_segmentation_model, get_stack_normalization_values, normalize_multichannel
7
+ from .utils import _estimate_scale_factor, _extract_channel_indices
8
+ from pathlib import Path
9
+ from tqdm import tqdm
10
+ import numpy as np
11
+ from stardist.models import StarDist2D
12
+ from cellpose.models import CellposeModel
13
+ from skimage.transform import resize
14
+ from celldetective.io import _view_on_napari, locate_labels, locate_stack, _view_on_napari
15
+ from celldetective.filters import * #rework this to give a name
16
+ from celldetective.utils import rename_intensity_column
17
+ import scipy.ndimage as ndi
18
+ from skimage.segmentation import watershed
19
+ from skimage.feature import peak_local_max
20
+ from skimage.measure import regionprops_table
21
+ from skimage.exposure import match_histograms
22
+ import pandas as pd
23
+ import subprocess
24
+
25
+
26
# Path of the ``celldetective`` folder, built as <parent of this module's directory> + 'celldetective'.
# The subprocess wrappers in this module append 'scripts/<name>.py' to it.
abs_path = os.sep.join([os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 'celldetective'])
27
+
28
def segment(stack, model_name, channels=None, spatial_calibration=None, view_on_napari=False,
            use_gpu=True, time_flat_normalization=False, time_flat_percentiles=(0.0, 99.99)):

    """
    Segment objects in a stack using a pre-trained segmentation model (StarDist or Cellpose).

    Parameters
    ----------
    stack : ndarray
        The input stack to be segmented, with shape (frames, height, width, channels).
    model_name : str
        Name of the pre-trained model. It is resolved with ``locate_segmentation_model`` and its folder must
        contain a ``config_input.json`` file describing the expected input.
    channels : list or None, optional
        Names of the channels of `stack`, one per entry of its last axis. They are matched against the
        channels required by the model (``input_config['channels']``) via ``_extract_channel_indices``.
        Default is None.
    spatial_calibration : float or None, optional
        Spatial calibration of the stack. Combined with the model calibration by ``_estimate_scale_factor``
        to decide how each frame is rescaled before prediction. Default is None.
    view_on_napari : bool, optional
        Whether to visualize the segmentation results using Napari. Default is False.
    use_gpu : bool, optional
        Whether to use the GPU. Sets the process-wide ``CUDA_VISIBLE_DEVICES`` environment variable
        ('0' if True, '-1' otherwise) and is forwarded to Cellpose. Default is True.
    time_flat_normalization : bool, optional
        If True (and the model expects normalized input), normalization values are computed once over the
        whole stack with ``get_stack_normalization_values`` instead of frame by frame. Default is False.
    time_flat_percentiles : tuple, optional
        Percentiles used for time-flat normalization. Default is (0.0, 99.99).

    Returns
    -------
    ndarray or None
        Integer label stack of shape (frames, height, width), or None when the model input configuration
        cannot be located.

    Raises
    ------
    ValueError
        If ``config_input.json`` declares a ``model_type`` other than 'stardist' or 'cellpose'.

    Notes
    -----
    Each frame goes through the same steps: channel selection, optional normalization and rescaling,
    prediction, and rescaling of the predicted labels back to the input frame size. For Cellpose models,
    ``diameter``, ``flow_threshold`` and ``cellprob_threshold`` are read from the model configuration and
    passed to ``CellposeModel.eval``.

    Examples
    --------
    >>> stack = np.random.rand(10, 256, 256, 3)
    >>> labels = segment(stack, 'model_name', channels=['channel_1', 'channel_2', 'channel_3'], spatial_calibration=0.5)

    """

    model_path = locate_segmentation_model(model_name)
    input_config_path = model_path + 'config_input.json'
    if not os.path.exists(input_config_path):
        print('Model input configuration could not be located...')
        return None
    with open(input_config_path) as config:
        print("Loading input configuration from 'config_input.json'.")
        input_config = json.load(config)

    # Process-wide setting: affects every library in this process that honours CUDA_VISIBLE_DEVICES.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0' if use_gpu else '-1'

    if channels is not None:
        assert len(channels) == stack.shape[-1], f'The channel names provided do not match with the expected number of channels in the stack: {stack.shape[-1]}.'

    required_channels = input_config['channels']
    channel_indices = _extract_channel_indices(channels, required_channels)

    required_spatial_calibration = input_config['spatial_calibration']
    model_type = input_config['model_type']
    normalize = input_config.get('normalize', True)

    if model_type == 'cellpose':
        diameter = input_config['diameter']
        if diameter != 30:
            # A custom diameter takes over the size adaptation: skip calibration-based rescaling.
            required_spatial_calibration = None
        cellprob_threshold = input_config['cellprob_threshold']
        flow_threshold = input_config['flow_threshold']

    scale = _estimate_scale_factor(spatial_calibration, required_spatial_calibration)

    if model_type == 'stardist':
        model = StarDist2D(None, name=model_name, basedir=Path(model_path).parent)
        print(f"StarDist model {model_name} successfully loaded")
    elif model_type == 'cellpose':
        model = CellposeModel(gpu=use_gpu, pretrained_model=model_path + model_path.split('/')[-2], diam_mean=30.0)
    else:
        raise ValueError(f"Unsupported model type '{model_type}' for segmentation model {model_name}.")

    if time_flat_normalization and normalize:
        normalization_values = get_stack_normalization_values(stack[:, :, :, channel_indices], percentiles=time_flat_percentiles)
    else:
        normalization_values = [None] * len(channel_indices)

    # Labels are always returned at the spatial size of the input frames.
    target_shape = stack.shape[1:3]
    channel_selector = np.asarray(channel_indices)

    labels = []
    for t in tqdm(range(len(stack)), desc="frame"):

        # Index the frame first so that the channel axis stays last: (height, width, n_channels).
        frame = stack[t][:, :, channel_selector]
        if normalize:
            frame = normalize_multichannel(frame, values=normalization_values)

        if scale is not None:
            frame = ndi.zoom(frame, [scale, scale, 1], order=3)

        if model_type == "stardist":
            Y_pred, _ = model.predict_instances(frame, n_tiles=model._guess_n_tiles(frame), show_tile_progress=False, verbose=False)
            Y_pred = Y_pred.astype(np.uint16)
        else:
            # NOTE(review): `stack` is indexed as 4-D above, so `stack.ndim == 3` looks unreachable;
            # confirm whether the number of selected channels was meant here.
            channels_cp = [[0, 0]] if stack.ndim == 3 else [[0, 1]]
            Y_pred, _, _ = model.eval([frame], diameter=diameter, flow_threshold=flow_threshold,
                                      cellprob_threshold=cellprob_threshold, channels=channels_cp,
                                      normalize=normalize)
            Y_pred = Y_pred[0].astype(np.uint16)

        if scale is not None:
            Y_pred = ndi.zoom(Y_pred, [1. / scale, 1. / scale], order=0)

        if Y_pred.shape != target_shape:
            # Nearest-neighbour resize that keeps the integer label values (no [0, 1] rescaling, no smoothing).
            Y_pred = resize(Y_pred, target_shape, order=0, preserve_range=True, anti_aliasing=False).astype(np.uint16)

        labels.append(Y_pred)

    labels = np.array(labels, dtype=int)

    if view_on_napari:
        _view_on_napari(tracks=None, stack=stack, labels=labels)

    return labels
167
+
168
+
169
def segment_from_thresholds(stack, target_channel=0, thresholds=None, view_on_napari=False, equalize_reference=None,
                            filters=None, marker_min_distance=30, marker_footprint_size=20, marker_footprint=None,
                            feature_queries=None, channel_names=None):

    """
    Segment every frame of a stack with intensity thresholds followed by a marker-based watershed.

    Each frame is processed independently by ``segment_frame_from_thresholds`` with the same set of
    parameters. The segmentation results can optionally be viewed in Napari.

    Parameters
    ----------
    stack : ndarray
        A 4D numpy array representing the image stack with dimensions (T, Y, X, C), where T is the
        time dimension and C the channel dimension.
    target_channel : int, optional
        Index of the channel that is thresholded (default is 0).
    thresholds : sequence of two numbers
        ``(lower_threshold, upper_threshold)`` applied to every frame (bounds are inclusive). Although the
        default is None, a value must be provided: no automatic thresholding is performed.
    view_on_napari : bool, optional
        If True, displays the original stack and segmentation results in Napari (default is False).
    equalize_reference : ndarray or None, optional
        Reference image passed to ``skimage.exposure.match_histograms`` to equalize the target channel of
        each frame. If None, equalization is not performed (default is None).
    filters : list or None, optional
        Filters applied before thresholding, as accepted by ``filter_image``: each entry is
        ``(filter_name, *args)`` (default is None).
    marker_min_distance : int, optional
        Minimum distance between watershed markers (default is 30).
    marker_footprint_size : int, optional
        Side length of the square footprint used for marker detection (default is 20).
    marker_footprint : ndarray or None, optional
        Explicit footprint for marker detection. Overrides `marker_footprint_size` if provided (default is None).
    feature_queries : str, list of str or None, optional
        Pandas query strings evaluated on object properties; objects matching a query are removed
        (default is None).
    channel_names : list of str or None, optional
        Channel names of `stack`, forwarded so that intensity features can be referenced by channel name
        in `feature_queries` (default is None).

    Returns
    -------
    ndarray
        A 3D numpy array (T, Y, X) of type int16 holding the object labels of each frame.

    Notes
    -----
    - The output dtype is int16, so label values above 32767 within a frame cannot be represented.

    """

    masks = []
    for t in tqdm(range(len(stack))):
        instance_seg = segment_frame_from_thresholds(stack[t], target_channel=target_channel, thresholds=thresholds,
                                                     equalize_reference=equalize_reference, filters=filters,
                                                     marker_min_distance=marker_min_distance,
                                                     marker_footprint_size=marker_footprint_size,
                                                     marker_footprint=marker_footprint,
                                                     feature_queries=feature_queries, channel_names=channel_names)
        masks.append(instance_seg)

    masks = np.array(masks, dtype=np.int16)
    if view_on_napari:
        _view_on_napari(tracks=None, stack=stack, labels=masks)
    return masks
229
+
230
def segment_frame_from_thresholds(frame, target_channel=0, thresholds=None, equalize_reference=None,
                                  filters=None, marker_min_distance=30, marker_footprint_size=20, marker_footprint=None,
                                  feature_queries=None, channel_names=None):

    """
    Produce an instance segmentation of one multichannel frame from intensity thresholds.

    Pipeline: select the target channel, optionally match its histogram to a reference, apply the
    requested filters, threshold, detect markers on the distance transform, split objects with a
    watershed, and finally drop objects that match the feature queries.

    Parameters
    ----------
    frame : ndarray
        A 3D numpy array (Y, X, C).
    target_channel : int, optional
        Index of the channel that is thresholded (default is 0).
    thresholds : sequence of two numbers
        ``(lower_threshold, upper_threshold)``, inclusive. Required despite the None default: only its first
        two items are read, and None cannot be indexed.
    equalize_reference : ndarray or None, optional
        Reference image for ``skimage.exposure.match_histograms``. If None, equalization is skipped (default is None).
    filters : list or None, optional
        Filters applied to the target channel before thresholding, in the ``filter_image`` format (default is None).
    marker_min_distance : int, optional
        Minimum distance between watershed markers (default is 30).
    marker_footprint_size : int, optional
        Side length of the square footprint used for marker detection (default is 20).
    marker_footprint : ndarray or None, optional
        Explicit footprint for marker detection; overrides `marker_footprint_size` (default is None).
    feature_queries : str, list of str or None, optional
        Pandas queries on object properties; matching objects are removed (default is None).
    channel_names : list of str or None, optional
        Names of the channels of `frame`, used to name intensity features referenced by `feature_queries`
        (default is None).

    Returns
    -------
    ndarray
        A 2D label image where each non-zero value identifies one object.

    """

    # Intensity features are measured on an untouched copy of all channels, not on the filtered channel.
    measurement_image = frame.copy()

    channel_img = frame[:, :, target_channel]
    if equalize_reference is not None:
        channel_img = match_histograms(channel_img, equalize_reference)
    channel_img = filter_image(channel_img, filters=filters)

    lower, upper = thresholds[0], thresholds[1]
    foreground = threshold_image(channel_img, lower, upper)

    peaks, edt = identify_markers_from_binary(foreground, marker_min_distance,
                                              footprint_size=marker_footprint_size,
                                              footprint=marker_footprint, return_edt=True)
    segmentation = apply_watershed(foreground, peaks, edt)

    return filter_on_property(segmentation, intensity_image=measurement_image,
                              queries=feature_queries, channel_names=channel_names)
282
+
283
+
284
def filter_on_property(labels, intensity_image=None, queries=None, channel_names=None):

    """
    Remove labelled objects whose measured properties match one or more queries.

    Every object of `labels` is measured with ``skimage.measure.regionprops_table``. Each non-empty query
    is applied as ``properties.query(f'not ({query})')``: objects for which the query is True are
    discarded, and their pixels are set to 0. Queries that cannot be evaluated are reported with a
    printed message and skipped.

    Parameters
    ----------
    labels : ndarray
        A 2D integer label image; 0 is background, each other value is one object. It is modified in
        place when objects are removed.
    intensity_image : ndarray, optional
        Intensity image aligned with `labels` (2D, or with channels on a trailing axis). When given,
        'intensity_mean', 'intensity_max' and 'intensity_min' are also measured (default is None).
    queries : str or list of str, optional
        Pandas query string(s) describing objects to REMOVE. Empty strings are ignored (default is None).
    channel_names : list of str or None, optional
        Channel names passed to ``rename_intensity_column`` to rename the intensity columns (default is None).

    Returns
    -------
    ndarray
        `labels` itself, with the pixels of removed objects set to 0.

    Notes
    -----
    - Available features: label, area, area_bbox, area_convex, area_filled, axis_major_length,
      axis_minor_length, eccentricity, equivalent_diameter_area, euler_number, feret_diameter_max,
      orientation, perimeter, perimeter_crofton, solidity, centroid-0 (row), centroid-1 (column),
      the intensity features when `intensity_image` is given, and 'radial_distance', the distance of the
      centroid to the image centre in pixels.
    - Example: ``queries='area < 50'`` removes every object smaller than 50 pixels.

    """

    if queries is None:
        return labels
    if isinstance(queries, str):
        queries = [queries]

    # Empty strings act as "no filter"; nothing to measure if no query remains or no object exists.
    queries = [query for query in queries if query != '']
    if not queries or not np.any(labels):
        return labels

    props = ['label', 'area', 'area_bbox', 'area_convex', 'area_filled', 'axis_major_length',
             'axis_minor_length', 'eccentricity', 'equivalent_diameter_area',
             'euler_number', 'feret_diameter_max', 'orientation', 'perimeter',
             'perimeter_crofton', 'solidity', 'centroid']
    if intensity_image is not None:
        props.extend(['intensity_mean', 'intensity_max', 'intensity_min'])

    properties = pd.DataFrame(regionprops_table(labels, intensity_image=intensity_image, properties=props))
    if channel_names is not None:
        properties = rename_intensity_column(properties, channel_names)

    # centroid-0 is the row coordinate (compared to the height), centroid-1 the column (compared to the width).
    properties['radial_distance'] = np.sqrt((properties['centroid-0'] - labels.shape[0] / 2) ** 2
                                            + (properties['centroid-1'] - labels.shape[1] / 2) ** 2)

    for query in queries:
        try:
            properties = properties.query(f'not ({query})')
        except Exception as e:
            print(f'Query {query} could not be applied. Ensure that the feature exists. {e}')

    # Every non-background label that did not survive the queries is erased in a single vectorized pass.
    kept_labels = properties['label'].to_numpy()
    labels[(labels != 0) & ~np.isin(labels, kept_labels)] = 0

    return labels
358
+
359
+
360
def apply_watershed(binary_image, coords, distance):

    """
    Split the foreground of a binary image into objects with a marker-based watershed.

    One marker is placed at each coordinate of `coords`. Markers are labelled with
    ``scipy.ndimage.label`` and flooded over the inverted distance map (``-distance``), restricted to the
    foreground of `binary_image`.

    Parameters
    ----------
    binary_image : ndarray
        Image whose truthy pixels are the foreground to segment; used as the watershed mask.
    coords : array-like
        Marker coordinates of shape (N, binary_image.ndim), e.g. the (row, column) output of
        ``skimage.feature.peak_local_max``. An empty sequence yields no markers.
    distance : ndarray
        Distance map with the same shape as `binary_image`, typically its Euclidean distance transform.

    Returns
    -------
    ndarray
        Integer label image produced by ``skimage.segmentation.watershed``.

    Notes
    -----
    - `coords` are usually peaks of `distance`, so that each object is flooded from its centre.
    - Adjacent markers receive the same label, because markers are labelled by connected components.

    Examples
    --------
    >>> import numpy as np
    >>> import scipy.ndimage as ndi
    >>> from skimage.feature import peak_local_max
    >>> binary_image = np.array([[0, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0]], dtype=bool)
    >>> distance = ndi.distance_transform_edt(binary_image)
    >>> coords = peak_local_max(distance, labels=binary_image.astype(int))
    >>> labels = apply_watershed(binary_image, coords, distance)

    """

    # Normalize to an (N, ndim) integer array so that lists and empty inputs index correctly
    # (an empty 1-D array would otherwise select the whole mask).
    coords = np.asarray(coords, dtype=int).reshape(-1, binary_image.ndim)

    marker_mask = np.zeros(binary_image.shape, dtype=bool)
    marker_mask[tuple(coords.T)] = True
    markers, _ = ndi.label(marker_mask)
    labels = watershed(-distance, markers, mask=binary_image)

    return labels
409
+
410
def identify_markers_from_binary(binary_image, min_distance, footprint_size=20, footprint=None, return_edt=False):

    """
    Locate watershed markers as local maxima of the distance transform of a binary image.

    Parameters
    ----------
    binary_image : ndarray
        Foreground mask; it is cast to float for the distance transform and to int as peak labels.
    min_distance : int
        Minimum separation between detected peaks, forwarded to ``peak_local_max``.
    footprint_size : int, optional
        Side length of the square all-ones footprint used when `footprint` is None. Default is 20.
    footprint : ndarray, optional
        Explicit footprint for ``peak_local_max``; takes precedence over `footprint_size`. Default is None.
    return_edt : bool, optional
        If True, also return the Euclidean distance transform. Default is False.

    Returns
    -------
    ndarray or tuple
        Peak coordinates as returned by ``peak_local_max`` (one row per marker); or
        ``(coordinates, distance_transform)`` when `return_edt` is True.

    Notes
    -----
    The distance transform gives each foreground pixel its Euclidean distance to the nearest background
    pixel, so its maxima sit near object centres. Passing the mask as `labels` restricts peaks to the
    foreground.

    """

    edt = ndi.distance_transform_edt(binary_image.astype(float))
    search_window = footprint if footprint is not None else np.ones((footprint_size, footprint_size))
    peaks = peak_local_max(edt, footprint=search_window,
                           labels=binary_image.astype(int), min_distance=min_distance)
    return (peaks, edt) if return_edt else peaks
458
+
459
+
460
def threshold_image(img, min_threshold, max_threshold, foreground_value=255., fill_holes=True):

    """
    Build a foreground mask from pixels whose value lies in ``[min_threshold, max_threshold]``.

    Parameters
    ----------
    img : ndarray
        The image to threshold.
    min_threshold : float
        Inclusive lower bound.
    max_threshold : float
        Inclusive upper bound.
    foreground_value : float, optional
        Value given to in-range pixels before hole filling. Default is 255.
    fill_holes : bool, optional
        If True, holes in the mask are filled with ``scipy.ndimage.binary_fill_holes``. Default is True.

    Returns
    -------
    ndarray
        With ``fill_holes=False``: a float array equal to `foreground_value` inside the range and 0 elsewhere.
        With ``fill_holes=True``: the boolean array returned by ``binary_fill_holes`` (so `foreground_value`
        only matters through being non-zero).

    Examples
    --------
    >>> image = np.random.rand(256, 256)
    >>> binary_mask = threshold_image(image, 0.2, 0.8, foreground_value=1., fill_holes=True)

    """

    in_range = np.logical_and(img >= min_threshold, img <= max_threshold)
    mask = in_range * foreground_value
    if not fill_holes:
        return mask
    return ndi.binary_fill_holes(mask)
504
+
505
def filter_image(img, filters=None):

    """
    Apply a sequence of named image filters to an image.

    Parameters
    ----------
    img : ndarray
        The input image. A 3D input is squeezed (singleton axes removed) before filtering.
    filters : list or None, optional
        Filters applied in order. Each entry is a tuple or list ``(name, *args)``; the function named
        ``f"{name}_filter"`` in this module's namespace (e.g. the functions star-imported from
        ``celldetective.filters``) is called as ``func(img, *args)``. If None, `img` is returned unchanged.
        Default is None.

    Returns
    -------
    ndarray
        The filtered image.

    Raises
    ------
    NameError
        If no callable named ``f"{name}_filter"`` exists in the module namespace.

    Notes
    -----
    Filter names typically come from configuration files, so they are looked up by name instead of
    being passed to ``eval``; this prevents arbitrary expressions from being executed.

    Examples
    --------
    >>> image = np.random.rand(256, 256)
    >>> filtered_image = filter_image(image, filters=[('gaussian', 3), ('median', 5)])

    """

    if filters is None:
        return img

    if img.ndim == 3:
        img = np.squeeze(img)

    namespace = globals()
    for f in filters:
        func_name = f"{f[0]}_filter"
        func = namespace.get(func_name)
        if not callable(func):
            raise NameError(f"Unknown filter '{f[0]}': no function named '{func_name}' is available.")
        img = func(img, *f[1:])
    return img
548
+
549
+
550
def segment_at_position(pos, mode, model_name, stack_prefix=None, use_gpu=True, return_labels=False, view_on_napari=False, threads=1):

    """
    Run the ``segment_cells.py`` script on a position folder, then optionally load or display the labels.

    Parameters
    ----------
    pos : str
        Path to the position directory containing the images to segment. Backslashes are converted to '/'.
    mode : str
        Population to segment (e.g. 'target' or 'effector'); forwarded to the script and used to locate labels.
    model_name : str
        Name of the pre-trained segmentation model.
    stack_prefix : str or None, optional
        Prefix of the stack file name, used only when `view_on_napari` is True. None means ''. Defaults to None.
    use_gpu : bool, optional
        Forwarded to the script as ``--use_gpu``. Defaults to True.
    return_labels : bool, optional
        If True, the labels written by the script are loaded and returned. Defaults to False.
    view_on_napari : bool, optional
        If True, the stack and labels are displayed in a Napari viewer. Defaults to False.
    threads : int, optional
        Forwarded to the script as ``--threads``. Defaults to 1.

    Returns
    -------
    numpy.ndarray or None
        The labels located with ``locate_labels`` if `return_labels` is True, otherwise None. The script
        itself writes the labels in the position directory.

    Raises
    ------
    AssertionError
        If `pos` does not exist.

    Notes
    -----
    The script runs in a child process with the current interpreter (``sys.executable``); its exit status
    is not checked.

    Examples
    --------
    >>> labels = segment_at_position('ExperimentFolder/W1/100/', 'effector', 'mice_t_cell_RICM', return_labels=True)

    """

    import sys

    pos = pos.replace('\\', '/')
    assert os.path.exists(pos), f'Position {pos} is not a valid path.'

    # The returned path is not used here. NOTE(review): the call is kept in case locate_segmentation_model
    # has side effects (e.g. fetching a missing model) before the script runs — confirm.
    locate_segmentation_model(model_name)

    script_path = os.sep.join([abs_path, 'scripts', 'segment_cells.py'])
    # Argument list without a shell: paths with spaces, quotes or shell metacharacters are passed verbatim.
    cmd = [sys.executable, script_path, '--pos', pos, '--model', str(model_name), '--mode', str(mode),
           '--use_gpu', str(use_gpu), '--threads', str(threads)]
    subprocess.call(cmd)

    if not (return_labels or view_on_napari):
        return None

    labels = locate_labels(pos, population=mode)
    if view_on_napari:
        prefix = '' if stack_prefix is None else stack_prefix
        stack = locate_stack(pos, prefix=prefix)
        _view_on_napari(tracks=None, stack=stack, labels=labels)

    return labels if return_labels else None
605
+
606
def segment_from_threshold_at_position(pos, mode, config, threads=1):

    """
    Run the ``segment_cells_thresholds.py`` script on a position folder with a threshold configuration.

    Parameters
    ----------
    pos : str
        Path to the position directory holding the images to segment. Backslashes are converted to '/'.
        The path must exist.
    mode : str
        Forwarded to the script as ``--mode`` (the population to segment).
    config : str
        Path to the JSON configuration of the threshold-based segmentation. Backslashes are converted to '/'.
        The path must exist.
    threads : int, optional
        Forwarded to the script as ``--threads`` (default is 1).

    Raises
    ------
    AssertionError
        If either the `pos` or `config` paths do not exist.

    Notes
    -----
    - The script is looked up as ``<abs_path>/scripts/segment_cells_thresholds.py``.
    - It runs in a child process with the current interpreter (``sys.executable``); its exit status is not
      checked and the function returns None.

    Examples
    --------
    >>> pos = '/path/to/position'
    >>> mode = 'default'
    >>> config = '/path/to/config.json'
    >>> segment_from_threshold_at_position(pos, mode, config, threads=2)

    """

    import sys

    pos = pos.replace('\\', '/')
    assert os.path.exists(pos), f'Position {pos} is not a valid path.'

    config = config.replace('\\', '/')
    assert os.path.exists(config), f'Config {config} is not a valid path.'

    script_path = os.sep.join([abs_path, 'scripts', 'segment_cells_thresholds.py'])
    # Argument list without a shell: paths are passed verbatim, no quoting or injection issues.
    cmd = [sys.executable, script_path, '--pos', pos, '--config', config, '--mode', str(mode),
           '--threads', str(threads)]
    subprocess.call(cmd)
664
+
665
+
666
def train_segmentation_model(config, use_gpu=True):

    """
    Train a segmentation model by running the ``train_segmentation_model.py`` script on a configuration file.

    Parameters
    ----------
    config : str
        Path to the JSON training configuration (backslashes are converted to '/'). The path must exist.
    use_gpu : bool, optional
        Forwarded to the script as ``--use_gpu`` (default is True).

    Raises
    ------
    AssertionError
        If the `config` path does not exist.

    Notes
    -----
    - The script is looked up as ``<abs_path>/scripts/train_segmentation_model.py``; what is trained and how
      is defined by that script and the configuration file.
    - It runs in a child process with the current interpreter (``sys.executable``); its exit status is not
      checked and the function returns None.

    Examples
    --------
    >>> config = '/path/to/training_config.json'
    >>> train_segmentation_model(config)

    """

    import sys

    config = config.replace('\\', '/')
    assert os.path.exists(config), f'Config {config} is not a valid path.'

    script_path = os.sep.join([abs_path, 'scripts', 'train_segmentation_model.py'])
    # Argument list without a shell: the config path is passed verbatim, no quoting or injection issues.
    cmd = [sys.executable, script_path, '--config', config, '--use_gpu', str(use_gpu)]
    subprocess.call(cmd)
710
+
711
def _run_smoke_check():
    """Call ``segment`` with no stack and the model name 'test'; manual debugging aid for direct execution."""
    return segment(None, 'test')


if __name__ == "__main__":
    print(_run_smoke_check())