celldetective 1.4.2__py3-none-any.whl → 1.5.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152) hide show
  1. celldetective/__init__.py +25 -0
  2. celldetective/__main__.py +62 -43
  3. celldetective/_version.py +1 -1
  4. celldetective/extra_properties.py +477 -399
  5. celldetective/filters.py +192 -97
  6. celldetective/gui/InitWindow.py +541 -411
  7. celldetective/gui/__init__.py +0 -15
  8. celldetective/gui/about.py +44 -39
  9. celldetective/gui/analyze_block.py +120 -84
  10. celldetective/gui/base/__init__.py +0 -0
  11. celldetective/gui/base/channel_norm_generator.py +335 -0
  12. celldetective/gui/base/components.py +249 -0
  13. celldetective/gui/base/feature_choice.py +92 -0
  14. celldetective/gui/base/figure_canvas.py +52 -0
  15. celldetective/gui/base/list_widget.py +133 -0
  16. celldetective/gui/{styles.py → base/styles.py} +92 -36
  17. celldetective/gui/base/utils.py +33 -0
  18. celldetective/gui/base_annotator.py +900 -767
  19. celldetective/gui/classifier_widget.py +6 -22
  20. celldetective/gui/configure_new_exp.py +777 -671
  21. celldetective/gui/control_panel.py +635 -524
  22. celldetective/gui/dynamic_progress.py +449 -0
  23. celldetective/gui/event_annotator.py +2023 -1662
  24. celldetective/gui/generic_signal_plot.py +1292 -944
  25. celldetective/gui/gui_utils.py +899 -1289
  26. celldetective/gui/interactions_block.py +658 -0
  27. celldetective/gui/interactive_timeseries_viewer.py +447 -0
  28. celldetective/gui/json_readers.py +48 -15
  29. celldetective/gui/layouts/__init__.py +5 -0
  30. celldetective/gui/layouts/background_model_free_layout.py +537 -0
  31. celldetective/gui/layouts/channel_offset_layout.py +134 -0
  32. celldetective/gui/layouts/local_correction_layout.py +91 -0
  33. celldetective/gui/layouts/model_fit_layout.py +372 -0
  34. celldetective/gui/layouts/operation_layout.py +68 -0
  35. celldetective/gui/layouts/protocol_designer_layout.py +96 -0
  36. celldetective/gui/pair_event_annotator.py +3130 -2435
  37. celldetective/gui/plot_measurements.py +586 -267
  38. celldetective/gui/plot_signals_ui.py +724 -506
  39. celldetective/gui/preprocessing_block.py +395 -0
  40. celldetective/gui/process_block.py +1678 -1831
  41. celldetective/gui/seg_model_loader.py +580 -473
  42. celldetective/gui/settings/__init__.py +0 -7
  43. celldetective/gui/settings/_cellpose_model_params.py +181 -0
  44. celldetective/gui/settings/_event_detection_model_params.py +95 -0
  45. celldetective/gui/settings/_segmentation_model_params.py +159 -0
  46. celldetective/gui/settings/_settings_base.py +77 -65
  47. celldetective/gui/settings/_settings_event_model_training.py +752 -526
  48. celldetective/gui/settings/_settings_measurements.py +1133 -964
  49. celldetective/gui/settings/_settings_neighborhood.py +574 -488
  50. celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
  51. celldetective/gui/settings/_settings_signal_annotator.py +329 -305
  52. celldetective/gui/settings/_settings_tracking.py +1304 -1094
  53. celldetective/gui/settings/_stardist_model_params.py +98 -0
  54. celldetective/gui/survival_ui.py +422 -312
  55. celldetective/gui/tableUI.py +1665 -1701
  56. celldetective/gui/table_ops/_maths.py +295 -0
  57. celldetective/gui/table_ops/_merge_groups.py +140 -0
  58. celldetective/gui/table_ops/_merge_one_hot.py +95 -0
  59. celldetective/gui/table_ops/_query_table.py +43 -0
  60. celldetective/gui/table_ops/_rename_col.py +44 -0
  61. celldetective/gui/thresholds_gui.py +382 -179
  62. celldetective/gui/viewers/__init__.py +0 -0
  63. celldetective/gui/viewers/base_viewer.py +700 -0
  64. celldetective/gui/viewers/channel_offset_viewer.py +331 -0
  65. celldetective/gui/viewers/contour_viewer.py +394 -0
  66. celldetective/gui/viewers/size_viewer.py +153 -0
  67. celldetective/gui/viewers/spot_detection_viewer.py +341 -0
  68. celldetective/gui/viewers/threshold_viewer.py +309 -0
  69. celldetective/gui/workers.py +403 -126
  70. celldetective/log_manager.py +92 -0
  71. celldetective/measure.py +1895 -1478
  72. celldetective/napari/__init__.py +0 -0
  73. celldetective/napari/utils.py +1025 -0
  74. celldetective/neighborhood.py +1914 -1448
  75. celldetective/preprocessing.py +1620 -1220
  76. celldetective/processes/__init__.py +0 -0
  77. celldetective/processes/background_correction.py +271 -0
  78. celldetective/processes/compute_neighborhood.py +894 -0
  79. celldetective/processes/detect_events.py +246 -0
  80. celldetective/processes/downloader.py +137 -0
  81. celldetective/processes/measure_cells.py +565 -0
  82. celldetective/processes/segment_cells.py +760 -0
  83. celldetective/processes/track_cells.py +435 -0
  84. celldetective/processes/train_segmentation_model.py +694 -0
  85. celldetective/processes/train_signal_model.py +265 -0
  86. celldetective/processes/unified_process.py +292 -0
  87. celldetective/regionprops/_regionprops.py +358 -317
  88. celldetective/relative_measurements.py +987 -710
  89. celldetective/scripts/measure_cells.py +313 -212
  90. celldetective/scripts/measure_relative.py +90 -46
  91. celldetective/scripts/segment_cells.py +165 -104
  92. celldetective/scripts/segment_cells_thresholds.py +96 -68
  93. celldetective/scripts/track_cells.py +198 -149
  94. celldetective/scripts/train_segmentation_model.py +324 -201
  95. celldetective/scripts/train_signal_model.py +87 -45
  96. celldetective/segmentation.py +844 -749
  97. celldetective/signals.py +3514 -2861
  98. celldetective/tracking.py +30 -15
  99. celldetective/utils/__init__.py +0 -0
  100. celldetective/utils/cellpose_utils/__init__.py +133 -0
  101. celldetective/utils/color_mappings.py +42 -0
  102. celldetective/utils/data_cleaning.py +630 -0
  103. celldetective/utils/data_loaders.py +450 -0
  104. celldetective/utils/dataset_helpers.py +207 -0
  105. celldetective/utils/downloaders.py +235 -0
  106. celldetective/utils/event_detection/__init__.py +8 -0
  107. celldetective/utils/experiment.py +1782 -0
  108. celldetective/utils/image_augmenters.py +308 -0
  109. celldetective/utils/image_cleaning.py +74 -0
  110. celldetective/utils/image_loaders.py +926 -0
  111. celldetective/utils/image_transforms.py +335 -0
  112. celldetective/utils/io.py +62 -0
  113. celldetective/utils/mask_cleaning.py +348 -0
  114. celldetective/utils/mask_transforms.py +5 -0
  115. celldetective/utils/masks.py +184 -0
  116. celldetective/utils/maths.py +351 -0
  117. celldetective/utils/model_getters.py +325 -0
  118. celldetective/utils/model_loaders.py +296 -0
  119. celldetective/utils/normalization.py +380 -0
  120. celldetective/utils/parsing.py +465 -0
  121. celldetective/utils/plots/__init__.py +0 -0
  122. celldetective/utils/plots/regression.py +53 -0
  123. celldetective/utils/resources.py +34 -0
  124. celldetective/utils/stardist_utils/__init__.py +104 -0
  125. celldetective/utils/stats.py +90 -0
  126. celldetective/utils/types.py +21 -0
  127. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/METADATA +1 -1
  128. celldetective-1.5.0b1.dist-info/RECORD +187 -0
  129. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/WHEEL +1 -1
  130. tests/gui/test_new_project.py +129 -117
  131. tests/gui/test_project.py +127 -79
  132. tests/test_filters.py +39 -15
  133. tests/test_notebooks.py +8 -0
  134. tests/test_tracking.py +232 -13
  135. tests/test_utils.py +123 -77
  136. celldetective/gui/base_components.py +0 -23
  137. celldetective/gui/layouts.py +0 -1602
  138. celldetective/gui/processes/compute_neighborhood.py +0 -594
  139. celldetective/gui/processes/downloader.py +0 -111
  140. celldetective/gui/processes/measure_cells.py +0 -360
  141. celldetective/gui/processes/segment_cells.py +0 -499
  142. celldetective/gui/processes/track_cells.py +0 -303
  143. celldetective/gui/processes/train_segmentation_model.py +0 -270
  144. celldetective/gui/processes/train_signal_model.py +0 -108
  145. celldetective/gui/table_ops/merge_groups.py +0 -118
  146. celldetective/gui/viewers.py +0 -1354
  147. celldetective/io.py +0 -3663
  148. celldetective/utils.py +0 -3108
  149. celldetective-1.4.2.dist-info/RECORD +0 -123
  150. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/entry_points.txt +0 -0
  151. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
  152. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/top_level.txt +0 -0
@@ -1,778 +1,873 @@
1
1
  """
2
2
  Segmentation module
3
3
  """
4
+
4
5
  import json
5
6
  import os
6
- from .io import locate_segmentation_model, normalize_multichannel
7
- from .utils import _estimate_scale_factor, _extract_channel_indices
7
+ from typing import List, Optional, Union
8
+
9
+ from celldetective.utils.model_loaders import locate_segmentation_model
10
+ from celldetective.utils.normalization import normalize_multichannel
8
11
  from pathlib import Path
9
12
  from tqdm import tqdm
10
- import numpy as np
11
- from celldetective.io import _view_on_napari, locate_labels, locate_stack, _view_on_napari, _check_label_dims, auto_correct_masks
12
- from celldetective.filters import * #rework this to give a name
13
- from celldetective.utils import interpolate_nan_multichannel,_rearrange_multichannel_frame, _fix_no_contrast, zoom_multiframes, _rescale_labels, rename_intensity_column, mask_edges, _prep_stardist_model, _prep_cellpose_model, estimate_unreliable_edge,_get_normalize_kwargs_from_config, _segment_image_with_stardist_model, _segment_image_with_cellpose_model
14
- from stardist import fill_label_holes
15
- from stardist.matching import matching
13
+ from celldetective.utils.image_loaders import (
14
+ locate_stack,
15
+ locate_labels,
16
+ _rearrange_multichannel_frame,
17
+ zoom_multiframes,
18
+ _extract_channel_indices,
19
+ )
20
+ from celldetective.utils.mask_cleaning import _check_label_dims, auto_correct_masks
21
+ from celldetective.utils.image_cleaning import (
22
+ _fix_no_contrast,
23
+ interpolate_nan_multichannel,
24
+ )
25
+ from celldetective.napari.utils import _view_on_napari
26
+ from celldetective.filters import *
27
+ from celldetective.utils.stardist_utils import (
28
+ _prep_stardist_model,
29
+ _segment_image_with_stardist_model,
30
+ )
31
+ from celldetective.utils.cellpose_utils import (
32
+ _segment_image_with_cellpose_model,
33
+ _prep_cellpose_model,
34
+ )
35
+ from celldetective.utils.mask_transforms import _rescale_labels
36
+ from celldetective.utils.image_transforms import (
37
+ estimate_unreliable_edge,
38
+ _estimate_scale_factor,
39
+ threshold_image,
40
+ )
41
+ from celldetective.utils.data_cleaning import rename_intensity_column
42
+ from celldetective.utils.parsing import _get_normalize_kwargs_from_config
16
43
 
17
44
  import scipy.ndimage as ndi
18
45
  from skimage.segmentation import watershed
19
46
  from skimage.feature import peak_local_max
20
47
  from skimage.measure import regionprops_table
21
48
  from skimage.exposure import match_histograms
22
- from scipy.ndimage import zoom
23
- import pandas as pd
24
- import subprocess
25
-
26
49
 
27
- abs_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0],'celldetective'])
28
-
29
- def segment(stack, model_name, channels=None, spatial_calibration=None, view_on_napari=False,
30
- use_gpu=True, channel_axis=-1, cellprob_threshold=None, flow_threshold=None):
31
-
32
- """
33
-
34
- Segment objects in a stack using a pre-trained segmentation model.
35
-
36
- Parameters
37
- ----------
38
- stack : ndarray
39
- The input stack to be segmented, with shape (frames, height, width, channels).
40
- model_name : str
41
- The name of the pre-trained segmentation model to use.
42
- channels : list or None, optional
43
- The names of the channels in the stack. If None, assumes the channels are indexed from 0 to `stack.shape[-1] - 1`.
44
- Default is None.
45
- spatial_calibration : float or None, optional
46
- The spatial calibration factor of the stack. If None, the calibration factor from the model configuration will be used.
47
- Default is None.
48
- view_on_napari : bool, optional
49
- Whether to visualize the segmentation results using Napari. Default is False.
50
- use_gpu : bool, optional
51
- Whether to use GPU acceleration if available. Default is True.
52
-
53
- Returns
54
- -------
55
- ndarray
56
- The segmented labels with shape (frames, height, width).
57
-
58
- Notes
59
- -----
60
- This function applies object segmentation to a stack of images using a pre-trained segmentation model. The stack is first
61
- preprocessed by normalizing the intensity values, rescaling the spatial dimensions, and applying the segmentation model.
62
- The resulting labels are returned as an ndarray with the same number of frames as the input stack.
63
-
64
- Examples
65
- --------
66
- >>> stack = np.random.rand(10, 256, 256, 3)
67
- >>> labels = segment(stack, 'model_name', channels=['channel_1', 'channel_2', 'channel_3'], spatial_calibration=0.5)
68
-
69
- """
70
-
71
- model_path = locate_segmentation_model(model_name)
72
- input_config = model_path+'config_input.json'
73
- if os.path.exists(input_config):
74
- with open(input_config) as config:
75
- print("Loading input configuration from 'config_input.json'.")
76
- input_config = json.load(config)
77
- else:
78
- print('Model input configuration could not be located...')
79
- return None
80
-
81
- if not use_gpu:
82
- os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
83
- else:
84
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
85
-
86
- if channel_axis != -1:
87
- stack = np.moveaxis(stack, channel_axis, -1)
88
-
89
- if channels is not None:
90
- assert len(channels)==stack.shape[-1],f'The channel names provided do not match with the expected number of channels in the stack: {stack.shape[-1]}.'
91
-
92
- required_channels = input_config['channels']
93
- channel_intersection = [ch for ch in channels if ch in required_channels]
94
- assert len(channel_intersection)>0,'None of the channels required by the model can be found in the images to segment... Abort.'
95
-
96
- channel_indices = _extract_channel_indices(channels, required_channels)
97
-
98
- required_spatial_calibration = input_config['spatial_calibration']
99
- model_type = input_config['model_type']
100
-
101
- normalize_kwargs = _get_normalize_kwargs_from_config(input_config)
102
-
103
- if model_type=='cellpose':
104
- diameter = input_config['diameter']
105
- # if diameter!=30:
106
- # required_spatial_calibration = None
107
- if cellprob_threshold is None:
108
- cellprob_threshold = input_config['cellprob_threshold']
109
- if flow_threshold is None:
110
- flow_threshold = input_config['flow_threshold']
111
-
112
- scale = _estimate_scale_factor(spatial_calibration, required_spatial_calibration)
113
- print(f"{spatial_calibration=} {required_spatial_calibration=} Scale = {scale}...")
114
-
115
- if model_type=='stardist':
116
- model, scale_model = _prep_stardist_model(model_name, Path(model_path).parent, use_gpu=use_gpu, scale=scale)
117
-
118
- elif model_type=='cellpose':
119
- model, scale_model = _prep_cellpose_model(model_path.split('/')[-2], model_path, use_gpu=use_gpu, n_channels=len(required_channels), scale=scale)
120
-
121
- labels = []
122
-
123
- for t in tqdm(range(len(stack)),desc="frame"):
124
-
125
- # normalize
126
- channel_indices = np.array(channel_indices)
127
- none_channel_indices = np.where(channel_indices==None)[0]
128
- channel_indices[channel_indices==None] = 0
129
-
130
- frame = stack[t]
131
- frame = _rearrange_multichannel_frame(frame).astype(float)
132
-
133
- frame_to_segment = np.zeros((frame.shape[0], frame.shape[1], len(required_channels))).astype(float)
134
- for ch in channel_intersection:
135
- idx = required_channels.index(ch)
136
- frame_to_segment[:,:,idx] = frame[:,:,channels.index(ch)]
137
- frame = frame_to_segment
138
- template = frame.copy()
139
-
140
- frame = normalize_multichannel(frame, **normalize_kwargs)
141
-
142
- if scale_model is not None:
143
- frame = zoom_multiframes(frame, scale_model)
144
-
145
- frame = _fix_no_contrast(frame)
146
- frame = interpolate_nan_multichannel(frame)
147
- frame[:,:,none_channel_indices] = 0.
148
-
149
- if model_type=="stardist":
150
- Y_pred = _segment_image_with_stardist_model(frame, model=model, return_details=False)
151
-
152
- elif model_type=="cellpose":
153
- Y_pred = _segment_image_with_cellpose_model(frame, model=model, diameter=diameter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold)
154
-
155
- if Y_pred.shape != stack[0].shape[:2]:
156
- Y_pred = _rescale_labels(Y_pred, scale_model)
157
-
158
- Y_pred = _check_label_dims(Y_pred, template=template)
159
-
160
- labels.append(Y_pred)
161
-
162
- labels = np.array(labels,dtype=int)
163
-
164
- if view_on_napari:
165
- _view_on_napari(tracks=None, stack=stack, labels=labels)
166
-
167
- return labels
168
-
169
-
170
- def segment_from_thresholds(stack, target_channel=0, thresholds=None, view_on_napari=False, equalize_reference=None,
171
- filters=None, marker_min_distance=30, marker_footprint_size=20, marker_footprint=None, feature_queries=None):
172
-
173
- """
174
- Segments objects from a stack of images based on provided thresholds and optional image processing steps.
175
-
176
- This function applies instance segmentation to each frame in a stack of images. Segmentation is based on intensity
177
- thresholds, optionally preceded by image equalization and filtering. Identified objects can
178
- be distinguished by applying distance-based marker detection. The segmentation results can be optionally viewed in Napari.
179
-
180
- Parameters
181
- ----------
182
- stack : ndarray
183
- A 4D numpy array representing the image stack with dimensions (T, Y, X, C) where T is the
184
- time dimension and C the channel dimension.
185
- target_channel : int, optional
186
- The channel index to be used for segmentation (default is 0).
187
- thresholds : list of tuples, optional
188
- A list of tuples specifying intensity thresholds for segmentation. Each tuple corresponds to a frame in the stack,
189
- with values (lower_threshold, upper_threshold). If None, global thresholds are determined automatically (default is None).
190
- view_on_napari : bool, optional
191
- If True, displays the original stack and segmentation results in Napari (default is False).
192
- equalize_reference : int or None, optional
193
- The index of a reference frame used for histogram equalization. If None, equalization is not performed (default is None).
194
- filters : list of dict, optional
195
- A list of dictionaries specifying filters to be applied pre-segmentation. Each dictionary should
196
- contain filter parameters (default is None).
197
- marker_min_distance : int, optional
198
- The minimum distance between markers used for distinguishing separate objects (default is 30).
199
- marker_footprint_size : int, optional
200
- The size of the footprint used for local maxima detection when generating markers (default is 20).
201
- marker_footprint : ndarray or None, optional
202
- An array specifying the footprint used for local maxima detection. Overrides `marker_footprint_size` if provided
203
- (default is None).
204
- feature_queries : list of str or None, optional
205
- A list of query strings used to select features of interest from the segmented objects (default is None).
206
-
207
- Returns
208
- -------
209
- ndarray
210
- A 3D numpy array (T, Y, X) of type int16, where each element represents the segmented object label at each pixel.
211
-
212
- Notes
213
- -----
214
- - The segmentation process can be customized extensively via the parameters, allowing for complex segmentation tasks.
215
-
216
- """
217
-
218
-
219
- masks = []
220
- for t in tqdm(range(len(stack))):
221
- instance_seg = segment_frame_from_thresholds(stack[t], target_channel=target_channel, thresholds=thresholds, equalize_reference=equalize_reference,
222
- filters=filters, marker_min_distance=marker_min_distance, marker_footprint_size=marker_footprint_size,
223
- marker_footprint=marker_footprint, feature_queries=feature_queries)
224
- masks.append(instance_seg)
225
-
226
- masks = np.array(masks, dtype=np.int16)
227
- if view_on_napari:
228
- _view_on_napari(tracks=None, stack=stack, labels=masks)
229
- return masks
230
-
231
- def segment_frame_from_thresholds(frame, target_channel=0, thresholds=None, equalize_reference=None,
232
- filters=None, marker_min_distance=30, marker_footprint_size=20, marker_footprint=None, feature_queries=None, channel_names=None, do_watershed=True, edge_exclusion=True, fill_holes=True):
233
-
234
- """
235
- Segments objects within a single frame based on intensity thresholds and optional image processing steps.
236
-
237
- This function performs instance segmentation on a single frame using intensity thresholds, with optional steps
238
- including histogram equalization, filtering, and marker-based watershed segmentation. The segmented
239
- objects can be further filtered based on specified features.
240
-
241
- Parameters
242
- ----------
243
- frame : ndarray
244
- A 3D numpy array representing a single frame with dimensions (Y, X, C).
245
- target_channel : int, optional
246
- The channel index to be used for segmentation (default is 0).
247
- thresholds : tuple of int, optional
248
- A tuple specifying the intensity thresholds for segmentation, in the form (lower_threshold, upper_threshold).
249
- equalize_reference : ndarray or None, optional
250
- A 2D numpy array used as a reference for histogram equalization. If None, equalization is not performed (default is None).
251
- filters : list of dict, optional
252
- A list of dictionaries specifying filters to be applied to the image before segmentation. Each dictionary
253
- should contain filter parameters (default is None).
254
- marker_min_distance : int, optional
255
- The minimum distance between markers used for distinguishing separate objects during watershed segmentation (default is 30).
256
- marker_footprint_size : int, optional
257
- The size of the footprint used for local maxima detection when generating markers for watershed segmentation (default is 20).
258
- marker_footprint : ndarray or None, optional
259
- An array specifying the footprint used for local maxima detection. Overrides `marker_footprint_size` if provided (default is None).
260
- feature_queries : list of str or None, optional
261
- A list of query strings used to select features of interest from the segmented objects for further filtering (default is None).
262
- channel_names : list of str or None, optional
263
- A list of channel names corresponding to the dimensions in `frame`, used in conjunction with `feature_queries` for feature selection (default is None).
264
-
265
- Returns
266
- -------
267
- ndarray
268
- A 2D numpy array of type int, where each element represents the segmented object label at each pixel.
269
-
270
- """
271
-
272
- if frame.ndim==2:
273
- frame = frame[:,:,np.newaxis]
274
- img = frame[:,:,target_channel]
275
-
276
- if np.any(img!=img):
277
- img = interpolate_nan(img)
278
-
279
- if equalize_reference is not None:
280
- img = match_histograms(img, equalize_reference)
281
-
282
- img_mc = frame.copy()
283
- img = filter_image(img, filters=filters)
284
- if edge_exclusion:
285
- edge = estimate_unreliable_edge(filters)
286
- else:
287
- edge = None
288
-
289
- binary_image = threshold_image(img, thresholds[0], thresholds[1], fill_holes=fill_holes, edge_exclusion=edge)
290
-
291
- if do_watershed:
292
- coords,distance = identify_markers_from_binary(binary_image, marker_min_distance, footprint_size=marker_footprint_size, footprint=marker_footprint, return_edt=True)
293
- instance_seg = apply_watershed(binary_image, coords, distance)
294
- else:
295
- instance_seg, _ = ndi.label(binary_image.astype(int).copy())
296
-
297
- instance_seg = filter_on_property(instance_seg, intensity_image=img_mc, queries=feature_queries, channel_names=channel_names)
298
-
299
- return instance_seg
50
+ import subprocess
51
+ from celldetective.log_manager import get_logger
52
+
53
+ logger = get_logger(__name__)
54
+
55
+ abs_path = os.sep.join(
56
+ [os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], "celldetective"]
57
+ )
58
+
59
+
60
+ def segment(
61
+ stack: Union[np.ndarray, List],
62
+ model_name: str,
63
+ channels: Optional[List[str]] = None,
64
+ spatial_calibration: Optional[float] = None,
65
+ view_on_napari: bool = False,
66
+ use_gpu: bool = True,
67
+ channel_axis: int = -1,
68
+ cellprob_threshold: float = None,
69
+ flow_threshold: float = None,
70
+ ):
71
+ """
72
+
73
+ Segment objects in a stack using a pre-trained segmentation model.
74
+
75
+ Parameters
76
+ ----------
77
+ stack : ndarray
78
+ The input stack to be segmented, with shape (frames, height, width, channels).
79
+ model_name : str
80
+ The name of the pre-trained segmentation model to use.
81
+ channels : list or None, optional
82
+ The names of the channels in the stack. If None, assumes the channels are indexed from 0 to `stack.shape[-1] - 1`.
83
+ Default is None.
84
+ spatial_calibration : float or None, optional
85
+ The spatial calibration factor of the stack. If None, the calibration factor from the model configuration will be used.
86
+ Default is None.
87
+ view_on_napari : bool, optional
88
+ Whether to visualize the segmentation results using Napari. Default is False.
89
+ use_gpu : bool, optional
90
+ Whether to use GPU acceleration if available. Default is True.
91
+ channel_axis : int, optional
92
+ Channel axis in the input array. Default is the last (-1).
93
+ cellprob_threshold : float, optional
94
+ Cell probability threshold for Cellpose mask computation. Default is None.
95
+ flow_threshold : float, optional
96
+ Flow threshold for Cellpose mask computation. Default is None.
97
+
98
+ Returns
99
+ -------
100
+ ndarray
101
+ The segmented labels with shape (frames, height, width).
102
+
103
+ Notes
104
+ -----
105
+ This function applies object segmentation to a stack of images using a pre-trained segmentation model. The stack is first
106
+ preprocessed by normalizing the intensity values, rescaling the spatial dimensions, and applying the segmentation model.
107
+ The resulting labels are returned as an ndarray with the same number of frames as the input stack.
108
+
109
+ Examples
110
+ --------
111
+ >>> stack = np.random.rand(10, 256, 256, 3)
112
+ >>> labels = segment(stack, 'model_name', channels=['channel_1', 'channel_2', 'channel_3'], spatial_calibration=0.5)
113
+
114
+ """
115
+
116
+ model_path = locate_segmentation_model(model_name)
117
+ input_config = model_path + "config_input.json"
118
+ if os.path.exists(input_config):
119
+ with open(input_config) as config:
120
+ logger.info("Loading input configuration from 'config_input.json'.")
121
+ input_config = json.load(config)
122
+ else:
123
+ logger.error("Model input configuration could not be located...")
124
+ return None
125
+
126
+ if not use_gpu:
127
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
128
+ else:
129
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
130
+
131
+ if channel_axis != -1:
132
+ stack = np.moveaxis(stack, channel_axis, -1)
133
+
134
+ if channels is not None:
135
+ assert (
136
+ len(channels) == stack.shape[-1]
137
+ ), f"The channel names provided do not match with the expected number of channels in the stack: {stack.shape[-1]}."
138
+
139
+ required_channels = input_config["channels"]
140
+ channel_intersection = [ch for ch in channels if ch in required_channels]
141
+ assert (
142
+ len(channel_intersection) > 0
143
+ ), "None of the channels required by the model can be found in the images to segment... Abort."
144
+
145
+ channel_indices = _extract_channel_indices(channels, required_channels)
146
+
147
+ required_spatial_calibration = input_config["spatial_calibration"]
148
+ model_type = input_config["model_type"]
149
+
150
+ normalize_kwargs = _get_normalize_kwargs_from_config(input_config)
151
+
152
+ if model_type == "cellpose":
153
+ diameter = input_config["diameter"]
154
+ # if diameter!=30:
155
+ # required_spatial_calibration = None
156
+ if cellprob_threshold is None:
157
+ cellprob_threshold = input_config["cellprob_threshold"]
158
+ if flow_threshold is None:
159
+ flow_threshold = input_config["flow_threshold"]
160
+
161
+ scale = _estimate_scale_factor(spatial_calibration, required_spatial_calibration)
162
+ logger.info(
163
+ f"{spatial_calibration=} {required_spatial_calibration=} Scale = {scale}..."
164
+ )
165
+
166
+ if model_type == "stardist":
167
+ model, scale_model = _prep_stardist_model(
168
+ model_name, Path(model_path).parent, use_gpu=use_gpu, scale=scale
169
+ )
170
+
171
+ elif model_type == "cellpose":
172
+ model, scale_model = _prep_cellpose_model(
173
+ model_path.split("/")[-2],
174
+ model_path,
175
+ use_gpu=use_gpu,
176
+ n_channels=len(required_channels),
177
+ scale=scale,
178
+ )
179
+
180
+
181
+ if model is None:
182
+ logger.error(f"Could not load model {model_name}. Aborting segmentation.")
183
+ return None
184
+
185
+ labels = []
186
+
187
+ for t in tqdm(range(len(stack)), desc="frame"):
188
+
189
+ # normalize
190
+ channel_indices = np.array(channel_indices)
191
+ none_channel_indices = np.where(channel_indices == None)[0]
192
+ channel_indices[channel_indices == None] = 0
193
+
194
+ frame = stack[t]
195
+ frame = _rearrange_multichannel_frame(frame).astype(float)
196
+
197
+ frame_to_segment = np.zeros(
198
+ (frame.shape[0], frame.shape[1], len(required_channels))
199
+ ).astype(float)
200
+ for ch in channel_intersection:
201
+ idx = required_channels.index(ch)
202
+ frame_to_segment[:, :, idx] = frame[:, :, channels.index(ch)]
203
+ frame = frame_to_segment
204
+ template = frame.copy()
205
+
206
+ frame = normalize_multichannel(frame, **normalize_kwargs)
207
+
208
+ if scale_model is not None:
209
+ frame = zoom_multiframes(frame, scale_model)
210
+
211
+ frame = _fix_no_contrast(frame)
212
+ frame = interpolate_nan_multichannel(frame)
213
+ frame[:, :, none_channel_indices] = 0.0
214
+
215
+ if model_type == "stardist":
216
+ Y_pred = _segment_image_with_stardist_model(
217
+ frame, model=model, return_details=False
218
+ )
219
+
220
+ elif model_type == "cellpose":
221
+ Y_pred = _segment_image_with_cellpose_model(
222
+ frame,
223
+ model=model,
224
+ diameter=diameter,
225
+ cellprob_threshold=cellprob_threshold,
226
+ flow_threshold=flow_threshold,
227
+ )
228
+
229
+ if Y_pred.shape != stack[0].shape[:2]:
230
+ Y_pred = _rescale_labels(Y_pred, scale_model)
231
+
232
+ Y_pred = _check_label_dims(Y_pred, template=template)
233
+
234
+ labels.append(Y_pred)
235
+
236
+ labels = np.array(labels, dtype=int)
237
+
238
+ if view_on_napari:
239
+ _view_on_napari(tracks=None, stack=stack, labels=labels)
240
+
241
+ return labels
242
+
243
+
244
def segment_from_thresholds(
    stack,
    target_channel=0,
    thresholds=None,
    view_on_napari=False,
    equalize_reference=None,
    filters=None,
    marker_min_distance=30,
    marker_footprint_size=20,
    marker_footprint=None,
    feature_queries=None,
    fill_holes=True,
):
    """
    Threshold-based instance segmentation of every frame in a stack.

    Each frame of `stack` is segmented independently with
    `segment_frame_from_thresholds`, applying the same thresholds, filters
    and marker parameters throughout; the per-frame label images are then
    stacked into a single int16 array. The result can optionally be shown
    in Napari.

    Parameters
    ----------
    stack : ndarray
        4D image stack of shape (T, Y, X, C).
    target_channel : int, optional
        Channel index used for segmentation (default 0).
    thresholds : tuple or None, optional
        (lower_threshold, upper_threshold) intensity bounds forwarded to the
        per-frame segmentation (default None).
    view_on_napari : bool, optional
        If True, display the stack and the label result in Napari (default False).
    equalize_reference : ndarray or None, optional
        Reference image for histogram equalization; skipped when None (default None).
    filters : list or None, optional
        Pre-segmentation filters forwarded to the per-frame segmentation (default None).
    marker_min_distance : int, optional
        Minimum distance between watershed markers (default 30).
    marker_footprint_size : int, optional
        Footprint size for local-maxima detection (default 20).
    marker_footprint : ndarray or None, optional
        Explicit footprint; overrides `marker_footprint_size` (default None).
    feature_queries : list of str or None, optional
        pandas-style queries used to discard objects post-segmentation (default None).
    fill_holes : bool, optional
        Whether holes in the binary mask are filled (default True).

    Returns
    -------
    ndarray
        (T, Y, X) int16 array of instance labels.
    """

    frame_labels = [
        segment_frame_from_thresholds(
            frame,
            target_channel=target_channel,
            thresholds=thresholds,
            equalize_reference=equalize_reference,
            filters=filters,
            marker_min_distance=marker_min_distance,
            marker_footprint_size=marker_footprint_size,
            marker_footprint=marker_footprint,
            feature_queries=feature_queries,
            fill_holes=fill_holes,
        )
        for frame in tqdm(stack)
    ]

    masks = np.array(frame_labels, dtype=np.int16)
    if view_on_napari:
        _view_on_napari(tracks=None, stack=stack, labels=masks)
    return masks
325
+
326
+
327
def segment_frame_from_thresholds(
    frame,
    target_channel=0,
    thresholds=None,
    equalize_reference=None,
    filters=None,
    marker_min_distance=30,
    marker_footprint_size=20,
    marker_footprint=None,
    feature_queries=None,
    channel_names=None,
    do_watershed=True,
    edge_exclusion=True,
    fill_holes=True,
):
    """
    Segments objects within a single frame based on intensity thresholds and optional image processing steps.

    This function performs instance segmentation on a single frame using intensity thresholds, with optional steps
    including histogram equalization, filtering, and marker-based watershed segmentation. The segmented
    objects can be further filtered based on specified features.

    Parameters
    ----------
    frame : ndarray
        A 3D numpy array representing a single frame with dimensions (Y, X, C).
        A 2D array is accepted and treated as single-channel.
    target_channel : int, optional
        The channel index to be used for segmentation (default is 0).
    thresholds : tuple of int
        The intensity thresholds for segmentation, in the form (lower_threshold, upper_threshold).
        Required; a ValueError is raised when not provided.
    equalize_reference : ndarray or None, optional
        A 2D numpy array used as a reference for histogram equalization. If None, equalization is not performed (default is None).
    filters : list of dict, optional
        A list of filters to be applied to the image before segmentation (default is None).
    marker_min_distance : int, optional
        The minimum distance between markers used for distinguishing separate objects during watershed segmentation (default is 30).
    marker_footprint_size : int, optional
        The size of the footprint used for local maxima detection when generating markers for watershed segmentation (default is 20).
    marker_footprint : ndarray or None, optional
        An array specifying the footprint used for local maxima detection. Overrides `marker_footprint_size` if provided (default is None).
    feature_queries : list of str or None, optional
        A list of query strings used to select features of interest from the segmented objects for further filtering (default is None).
    channel_names : list of str or None, optional
        A list of channel names corresponding to the dimensions in `frame`, used in conjunction with `feature_queries` (default is None).
    do_watershed : bool, optional
        If True, split touching objects with marker-controlled watershed; if False, use plain connected-component labelling (default is True).
    edge_exclusion : bool, optional
        If True, estimate the image border made unreliable by the filters and exclude it from the binary mask (default is True).
    fill_holes : bool, optional
        Whether to fill holes in the binary mask and in the final labels (default is True).

    Returns
    -------
    ndarray
        A 2D numpy array of type int, where each element represents the segmented object label at each pixel.

    Raises
    ------
    ValueError
        If `thresholds` is None or does not contain exactly two values.
    """

    # Fail early with a clear message instead of an opaque TypeError on
    # `thresholds[0]` when the default None slips through.
    if thresholds is None or len(thresholds) != 2:
        raise ValueError(
            f"thresholds must be a (lower_threshold, upper_threshold) pair; got {thresholds}."
        )

    if frame.ndim == 2:
        frame = frame[:, :, np.newaxis]
    img = frame[:, :, target_channel]

    # `x != x` is only True for NaN and is dtype-safe (np.isnan rejects ints)
    if np.any(img != img):
        img = interpolate_nan(img)

    if equalize_reference is not None:
        img = match_histograms(img, equalize_reference)

    # keep the unfiltered multichannel frame for intensity measurements
    img_mc = frame.copy()
    img = filter_image(img, filters=filters)
    edge = estimate_unreliable_edge(filters) if edge_exclusion else None

    binary_image = threshold_image(
        img, thresholds[0], thresholds[1], fill_holes=fill_holes, edge_exclusion=edge
    )

    if do_watershed:
        coords, distance = identify_markers_from_binary(
            binary_image,
            marker_min_distance,
            footprint_size=marker_footprint_size,
            footprint=marker_footprint,
            return_edt=True,
        )
        instance_seg = apply_watershed(
            binary_image, coords, distance, fill_holes=fill_holes
        )
    else:
        instance_seg, _ = ndi.label(binary_image.astype(int).copy())

    instance_seg = filter_on_property(
        instance_seg,
        intensity_image=img_mc,
        queries=feature_queries,
        channel_names=channel_names,
    )

    return instance_seg
300
423
 
301
424
 
302
425
def filter_on_property(labels, intensity_image=None, queries=None, channel_names=None):
    """
    Filters segmented objects in a label image based on specified properties and queries.

    Each labelled object is measured (morphology and, when an intensity image
    is provided, intensity statistics); objects matching any of the
    pandas-style `queries` are set to 0 in the label image. This allows
    excluding objects by size, shape, intensity, or radial position.

    Parameters
    ----------
    labels : ndarray
        A 2D numpy array where each unique non-zero integer represents a segmented object (label).
    intensity_image : ndarray, optional
        An intensity image aligned with `labels`, used to compute intensity
        properties of the segmented objects (default is None).
    queries : str or list of str, optional
        One or more pandas query strings; objects satisfying a query are
        removed. If None, `labels` is returned unchanged (default is None).
    channel_names : list of str or None, optional
        Channel names used to rename intensity property columns (default is None).

    Returns
    -------
    ndarray
        A 2D numpy array of the same shape as `labels`, with objects matching the query criteria removed.

    Notes
    -----
    - Queries follow pandas DataFrame query syntax and can reference any computed property,
      including the derived `radial_distance` (distance of the centroid to the image center).
    """

    if queries is None:
        return labels

    if isinstance(queries, str):
        queries = [queries]

    props = [
        "label",
        "area",
        "area_bbox",
        "area_convex",
        "area_filled",
        "axis_major_length",
        "axis_minor_length",
        "eccentricity",
        "equivalent_diameter_area",
        "euler_number",
        "feret_diameter_max",
        "orientation",
        "perimeter",
        "perimeter_crofton",
        "solidity",
        "centroid",
    ]

    # BUGFIX: the intensity properties were previously appended twice,
    # duplicating entries in the property list passed to regionprops_table.
    if intensity_image is not None:
        props.extend(["intensity_mean", "intensity_max", "intensity_min"])

    import pandas as pd

    properties = pd.DataFrame(
        regionprops_table(labels, intensity_image=intensity_image, properties=props)
    )

    if channel_names is not None:
        properties = rename_intensity_column(properties, channel_names)

    # BUGFIX: centroid-0 is the row coordinate (paired with shape[0]) and
    # centroid-1 the column coordinate (paired with shape[1]); the axes were
    # previously swapped, which was wrong for non-square images.
    properties["radial_distance"] = np.sqrt(
        (properties["centroid-1"] - labels.shape[1] / 2) ** 2
        + (properties["centroid-0"] - labels.shape[0] / 2) ** 2
    )

    for query in queries:
        if query != "":
            try:
                # keep the objects that do NOT match the removal query
                properties = properties.query(f"not ({query})")
            except Exception as e:
                logger.error(
                    f"Query {query} could not be applied. Ensure that the feature exists. {e}"
                )

    cell_ids = list(np.unique(labels)[1:])
    leftover_cells = list(properties["label"].unique())
    to_remove = [value for value in cell_ids if value not in leftover_cells]

    for c in to_remove:
        # integer zero (labels is an integer array)
        labels[labels == c] = 0

    return labels
376
525
 
377
526
 
378
527
def apply_watershed(binary_image, coords, distance, fill_holes=True):
    """
    Run marker-controlled watershed segmentation on a binary mask.

    Seeds are placed at the supplied local-maxima coordinates and labelled;
    the watershed is then flooded over the negated distance map, restricted
    to the foreground of `binary_image`, so that touching objects are split
    at the ridge between their seeds.

    Parameters
    ----------
    binary_image : ndarray
        A 2D boolean array where True marks the foreground to segment.
    coords : ndarray
        (N, 2) array of (row, column) seed coordinates, typically local
        maxima of the distance transform.
    distance : ndarray
        Distance transform of `binary_image`, guiding the flooding.
    fill_holes : bool, optional
        If True, fill holes in each resulting label using the optional
        stardist-based helper; skipped with a warning when that dependency
        is unavailable (default is True).

    Returns
    -------
    ndarray
        Integer label image where each non-zero value is one segmented object.
    """

    seeds = np.zeros(binary_image.shape, dtype=bool)
    seeds[tuple(coords.T)] = True
    markers, _ = ndi.label(seeds)

    labels = watershed(-distance, markers, mask=binary_image)

    if fill_holes:
        try:
            # optional dependency: only present when stardist is installed
            from celldetective.utils.mask_cleaning import fill_label_holes

            labels = fill_label_holes(labels)
        except ImportError as ie:
            logger.warning(f"Stardist not found, cannot fill holes... {ie}")

    return labels
581
+
582
+
583
def identify_markers_from_binary(
    binary_image, min_distance, footprint_size=20, footprint=None, return_edt=False
):
    """
    Detect watershed markers in a binary mask via its distance transform.

    The Euclidean distance transform of the mask is computed, and its local
    maxima — each assumed to mark the center of one object — are returned as
    marker coordinates. `min_distance` prevents clustered detections.

    Parameters
    ----------
    binary_image : ndarray
        The binary image from which to identify markers.
    min_distance : int
        Minimum separation (in pixels) between two detected markers.
    footprint_size : int, optional
        Side of the square footprint used for peak detection when no
        explicit `footprint` is given (default is 20).
    footprint : ndarray, optional
        Structuring element for peak detection; overrides `footprint_size`
        when provided (default is None).
    return_edt : bool, optional
        If True, also return the Euclidean distance transform image
        (default is False).

    Returns
    -------
    ndarray or tuple
        Marker coordinates of shape (N, 2); when `return_edt` is True, a
        tuple of (coordinates, distance transform).
    """

    edt = ndi.distance_transform_edt(binary_image.astype(float))

    if footprint is None:
        footprint = np.ones((footprint_size, footprint_size))

    peaks = peak_local_max(
        edt,
        footprint=footprint,
        labels=binary_image.astype(int),
        min_distance=min_distance,
    )

    return (peaks, edt) if return_edt else peaks
636
+
637
+
638
def segment_at_position(
    pos,
    mode,
    model_name,
    stack_prefix=None,
    use_gpu=True,
    return_labels=False,
    view_on_napari=False,
    threads=1,
):
    """
    Perform image segmentation at the specified position using a pre-trained model.

    Parameters
    ----------
    pos : str
        The path to the position directory containing the input images to be segmented.
    mode : str
        The segmentation mode. This determines the type of objects to be segmented ('target' or 'effector').
    model_name : str
        The name of the pre-trained segmentation model to be used.
    stack_prefix : str or None, optional
        The prefix of the stack file name. Defaults to None.
    use_gpu : bool, optional
        Whether to use the GPU for segmentation if available. Defaults to True.
    return_labels : bool, optional
        If True, the function returns the segmentation labels as an output. Defaults to False.
    view_on_napari : bool, optional
        If True, the segmented labels are displayed in a Napari viewer. Defaults to False.
    threads : int, optional
        Number of threads forwarded to the segmentation script. Defaults to 1.

    Returns
    -------
    numpy.ndarray or None
        If `return_labels` is True, the function returns the segmentation labels as a NumPy array. Otherwise, it returns None. The subprocess writes the
        segmentation labels in the position directory.

    Examples
    --------
    >>> labels = segment_at_position('ExperimentFolder/W1/100/', 'effector', 'mice_t_cell_RICM', return_labels=True)

    """

    import sys

    pos = pos.replace("\\", "/")
    pos = rf"{pos}"
    assert os.path.exists(pos), f"Position {pos} is not a valid path."

    # Locate the model before launching the worker; the path itself is not
    # needed here (the script resolves the model by name).
    locate_segmentation_model(model_name)

    script_path = os.sep.join([abs_path, "scripts", "segment_cells.py"])

    # Pass an argument vector instead of a shell-interpolated string: this is
    # robust to spaces/quotes in paths, avoids shell injection, and
    # sys.executable guarantees the same interpreter as the parent process.
    subprocess.call(
        [
            sys.executable,
            script_path,
            "--pos",
            pos,
            "--model",
            model_name,
            "--mode",
            mode,
            "--use_gpu",
            str(use_gpu),
            "--threads",
            str(threads),
        ]
    )

    if return_labels or view_on_napari:
        labels = locate_labels(pos, population=mode)
        if view_on_napari:
            if stack_prefix is None:
                stack_prefix = ""
            stack = locate_stack(pos, prefix=stack_prefix)
            _view_on_napari(tracks=None, stack=stack, labels=labels)
        if return_labels:
            return labels
    return None
379
701
 
380
- """
381
- Applies the watershed algorithm to segment objects in a binary image using given markers and distance map.
382
-
383
- This function uses the watershed segmentation algorithm to delineate objects in a binary image. Markers for watershed
384
- are determined by the coordinates of local maxima, and the segmentation is guided by a distance map to separate objects
385
- that are close to each other.
386
-
387
- Parameters
388
- ----------
389
- binary_image : ndarray
390
- A 2D numpy array of type bool, where True represents the foreground objects to be segmented and False represents the background.
391
- coords : ndarray
392
- An array of shape (N, 2) containing the (row, column) coordinates of local maxima points that will be used as markers for the
393
- watershed algorithm. N is the number of local maxima.
394
- distance : ndarray
395
- A 2D numpy array of the same shape as `binary_image`, containing the distance transform of the binary image. This map is used
396
- to guide the watershed segmentation.
397
-
398
- Returns
399
- -------
400
- ndarray
401
- A 2D numpy array of type int, where each unique non-zero integer represents a segmented object (label).
402
-
403
- Notes
404
- -----
405
- - The function assumes that `coords` are derived from the distance map of `binary_image`, typically obtained using
406
- peak local max detection on the distance transform.
407
- - The watershed algorithm treats each local maximum as a separate object and segments the image by "flooding" from these points.
408
- - This implementation uses the `skimage.morphology.watershed` function under the hood.
409
-
410
- Examples
411
- --------
412
- >>> from skimage import measure, morphology
413
- >>> binary_image = np.array([[0, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0]], dtype=bool)
414
- >>> distance = morphology.distance_transform_edt(binary_image)
415
- >>> coords = measure.peak_local_max(distance, indices=True)
416
- >>> labels = apply_watershed(binary_image, coords, distance)
417
- # Segments the objects in `binary_image` using the watershed algorithm.
418
-
419
- """
420
-
421
- mask = np.zeros(binary_image.shape, dtype=bool)
422
- mask[tuple(coords.T)] = True
423
- markers, _ = ndi.label(mask)
424
- labels = watershed(-distance, markers, mask=binary_image)
425
- if fill_holes:
426
- labels = fill_label_holes(labels)
427
- return labels
428
-
429
- def identify_markers_from_binary(binary_image, min_distance, footprint_size=20, footprint=None, return_edt=False):
430
-
431
- """
432
-
433
- Identify markers from a binary image using distance transform and peak detection.
434
-
435
- Parameters
436
- ----------
437
- binary_image : ndarray
438
- The binary image from which to identify markers.
439
- min_distance : int
440
- The minimum distance between markers. Only the markers with a minimum distance greater than or equal to
441
- `min_distance` will be identified.
442
- footprint_size : int, optional
443
- The size of the footprint or structuring element used for peak detection. Default is 20.
444
- footprint : ndarray, optional
445
- The footprint or structuring element used for peak detection. If None, a square footprint of size
446
- `footprint_size` will be used. Default is None.
447
- return_edt : bool, optional
448
- Whether to return the Euclidean distance transform image along with the identified marker coordinates.
449
- If True, the function will return the marker coordinates and the distance transform image as a tuple.
450
- If False, only the marker coordinates will be returned. Default is False.
451
-
452
- Returns
453
- -------
454
- ndarray or tuple
455
- If `return_edt` is False, returns the identified marker coordinates as an ndarray of shape (N, 2), where N is
456
- the number of identified markers. If `return_edt` is True, returns a tuple containing the marker coordinates
457
- and the distance transform image.
458
-
459
- Notes
460
- -----
461
- This function uses the distance transform of the binary image to identify markers by detecting local maxima. The
462
- distance transform assigns each pixel a value representing the Euclidean distance to the nearest background pixel.
463
- By finding peaks in the distance transform, we can identify the markers in the original binary image. The `min_distance`
464
- parameter controls the minimum distance between markers to avoid clustering.
465
-
466
- """
467
-
468
- distance = ndi.distance_transform_edt(binary_image.astype(float))
469
- if footprint is None:
470
- footprint = np.ones((footprint_size, footprint_size))
471
- coords = peak_local_max(distance, footprint=footprint,
472
- labels=binary_image.astype(int), min_distance=min_distance)
473
- if return_edt:
474
- return coords, distance
475
- else:
476
- return coords
477
-
478
-
479
def threshold_image(img, min_threshold, max_threshold, foreground_value=255., fill_holes=True, edge_exclusion=None):

    """
    Threshold the input image into a binary foreground mask.

    Pixels whose value lies in [min_threshold, max_threshold] become
    foreground; everything else (including NaN pixels) is background.

    Parameters
    ----------
    img : ndarray
        The input image to be thresholded. May contain NaN values, which are
        always treated as background.
    min_threshold : float
        The minimum threshold value (inclusive).
    max_threshold : float
        The maximum threshold value (inclusive).
    foreground_value : float, optional
        Value multiplied into the in-range mask before it is stored in the
        boolean output; since the mask is boolean, any non-zero value yields
        True while 0 yields an all-background mask. Default is 255.
    fill_holes : bool, optional
        If True, fill holes in the binary mask with a morphological fill.
        Default is True.
    edge_exclusion : int or None, optional
        When an integer, mask out a border of that width before hole filling.
        Default is None.

    Returns
    -------
    ndarray
        The boolean binary mask after thresholding.

    Examples
    --------
    >>> image = np.random.rand(256, 256)
    >>> binary_mask = threshold_image(image, 0.2, 0.8, foreground_value=1., fill_holes=True)

    """

    # `x == x` is False only for NaN, so `valid` selects the usable pixels.
    valid = img == img
    binary = np.zeros(img.shape, dtype=bool)
    in_range = (img[valid] >= min_threshold) & (img[valid] <= max_threshold)
    binary[valid] = in_range * foreground_value

    if isinstance(edge_exclusion, (int, np.int_)):
        binary = mask_edges(binary, edge_exclusion)
    if fill_holes:
        binary = ndi.binary_fill_holes(binary.astype(int))
    return binary
525
-
526
def filter_image(img, filters=None):

    """
    Apply one or more image filters to the input image.

    Parameters
    ----------
    img : ndarray
        The input image to be filtered.
    filters : list or None, optional
        A list of filters to be applied to the image. Each filter is represented as a tuple or list with the first element being
        the filter function name (minus the '_filter' extension, as listed in software.filters) and the subsequent elements being
        the arguments for that filter function. If None, the original image is returned without any filtering applied. Default is None.

    Returns
    -------
    ndarray
        The filtered image.

    Raises
    ------
    ValueError
        If a requested filter name does not resolve to a known
        `<name>_filter` function.

    Notes
    -----
    Filters are applied sequentially in the given order. Filter names are
    resolved with an explicit module-globals lookup instead of `eval`, so a
    malformed or malicious filter name cannot execute arbitrary code.

    Examples
    --------
    >>> image = np.random.rand(256, 256)
    >>> filtered_image = filter_image(image, filters=[('gaussian', 3), ('median', 5)])

    """

    if filters is None:
        return img

    if img.ndim == 3:
        # drop singleton dimensions (e.g. a trailing channel axis of size 1)
        img = np.squeeze(img)

    for f in filters:
        func_name = f"{f[0]}_filter"
        # explicit lookup replaces eval(): same resolution for module-level
        # filter functions, but no arbitrary-code execution path
        func = globals().get(func_name)
        if func is None:
            raise ValueError(
                f"Unknown filter '{f[0]}': no function '{func_name}' is available."
            )
        img = func(img, *f[1:])
    return img
569
-
570
-
571
- def segment_at_position(pos, mode, model_name, stack_prefix=None, use_gpu=True, return_labels=False, view_on_napari=False, threads=1):
572
-
573
- """
574
- Perform image segmentation at the specified position using a pre-trained model.
575
-
576
- Parameters
577
- ----------
578
- pos : str
579
- The path to the position directory containing the input images to be segmented.
580
- mode : str
581
- The segmentation mode. This determines the type of objects to be segmented ('target' or 'effector').
582
- model_name : str
583
- The name of the pre-trained segmentation model to be used.
584
- stack_prefix : str or None, optional
585
- The prefix of the stack file name. Defaults to None.
586
- use_gpu : bool, optional
587
- Whether to use the GPU for segmentation if available. Defaults to True.
588
- return_labels : bool, optional
589
- If True, the function returns the segmentation labels as an output. Defaults to False.
590
- view_on_napari : bool, optional
591
- If True, the segmented labels are displayed in a Napari viewer. Defaults to False.
592
-
593
- Returns
594
- -------
595
- numpy.ndarray or None
596
- If `return_labels` is True, the function returns the segmentation labels as a NumPy array. Otherwise, it returns None. The subprocess writes the
597
- segmentation labels in the position directory.
598
-
599
- Examples
600
- --------
601
- >>> labels = segment_at_position('ExperimentFolder/W1/100/', 'effector', 'mice_t_cell_RICM', return_labels=True)
602
-
603
- """
604
-
605
- pos = pos.replace('\\','/')
606
- pos = rf'{pos}'
607
- assert os.path.exists(pos),f'Position {pos} is not a valid path.'
608
-
609
- name_path = locate_segmentation_model(model_name)
610
-
611
- script_path = os.sep.join([abs_path, 'scripts', 'segment_cells.py'])
612
- cmd = f'python "{script_path}" --pos "{pos}" --model "{model_name}" --mode "{mode}" --use_gpu "{use_gpu}" --threads "{threads}"'
613
- subprocess.call(cmd, shell=True)
614
-
615
- if return_labels or view_on_napari:
616
- labels = locate_labels(pos, population=mode)
617
- if view_on_napari:
618
- if stack_prefix is None:
619
- stack_prefix = ''
620
- stack = locate_stack(pos, prefix=stack_prefix)
621
- _view_on_napari(tracks=None, stack=stack, labels=labels)
622
- if return_labels:
623
- return labels
624
- else:
625
- return None
626
702
 
627
703
  def segment_from_threshold_at_position(pos, mode, config, threads=1):
628
-
629
- """
630
- Executes a segmentation script on a specified position directory using a given configuration and mode.
631
-
632
- This function calls an external Python script designed to segment images at a specified position directory.
633
- The segmentation is configured through a JSON file and can operate in different modes specified by the user.
634
- The function can leverage multiple threads to potentially speed up the processing.
635
-
636
- Parameters
637
- ----------
638
- pos : str
639
- The file path to the position directory where images to be segmented are stored. The path must be valid.
640
- mode : str
641
- The operation mode for the segmentation script. The mode determines how the segmentation is performed and
642
- which algorithm or parameters are used.
643
- config : str
644
- The file path to the JSON configuration file that specifies parameters for the segmentation process. The
645
- path must be valid.
646
- threads : int, optional
647
- The number of threads to use for processing. Using more than one thread can speed up segmentation on
648
- systems with multiple CPU cores (default is 1).
649
-
650
- Raises
651
- ------
652
- AssertionError
653
- If either the `pos` or `config` paths do not exist.
654
-
655
- Notes
656
- -----
657
- - The external segmentation script (`segment_cells_thresholds.py`) is expected to be located in a specific
658
- directory relative to this function.
659
- - The segmentation process and its parameters, including modes and thread usage, are defined by the external
660
- script and the configuration file.
661
-
662
- Examples
663
- --------
664
- >>> pos = '/path/to/position'
665
- >>> mode = 'default'
666
- >>> config = '/path/to/config.json'
667
- >>> segment_from_threshold_at_position(pos, mode, config, threads=2)
668
- # This will execute the segmentation script on the specified position directory with the given mode and
669
- # configuration, utilizing 2 threads.
670
-
671
- """
672
-
673
-
674
- pos = pos.replace('\\','/')
675
- pos = rf"{pos}"
676
- assert os.path.exists(pos),f'Position {pos} is not a valid path.'
677
-
678
- config = config.replace('\\','/')
679
- config = rf"{config}"
680
- assert os.path.exists(config),f'Config {config} is not a valid path.'
681
-
682
- script_path = os.sep.join([abs_path, 'scripts', 'segment_cells_thresholds.py'])
683
- cmd = f'python "{script_path}" --pos "{pos}" --config "{config}" --mode "{mode}" --threads "{threads}"'
684
- subprocess.call(cmd, shell=True)
704
+ """
705
+ Executes a segmentation script on a specified position directory using a given configuration and mode.
706
+
707
+ This function calls an external Python script designed to segment images at a specified position directory.
708
+ The segmentation is configured through a JSON file and can operate in different modes specified by the user.
709
+ The function can leverage multiple threads to potentially speed up the processing.
710
+
711
+ Parameters
712
+ ----------
713
+ pos : str
714
+ The file path to the position directory where images to be segmented are stored. The path must be valid.
715
+ mode : str
716
+ The operation mode for the segmentation script. The mode determines how the segmentation is performed and
717
+ which algorithm or parameters are used.
718
+ config : str
719
+ The file path to the JSON configuration file that specifies parameters for the segmentation process. The
720
+ path must be valid.
721
+ threads : int, optional
722
+ The number of threads to use for processing. Using more than one thread can speed up segmentation on
723
+ systems with multiple CPU cores (default is 1).
724
+
725
+ Raises
726
+ ------
727
+ AssertionError
728
+ If either the `pos` or `config` paths do not exist.
729
+
730
+ Notes
731
+ -----
732
+ - The external segmentation script (`segment_cells_thresholds.py`) is expected to be located in a specific
733
+ directory relative to this function.
734
+ - The segmentation process and its parameters, including modes and thread usage, are defined by the external
735
+ script and the configuration file.
736
+
737
+ Examples
738
+ --------
739
+ >>> pos = '/path/to/position'
740
+ >>> mode = 'default'
741
+ >>> config = '/path/to/config.json'
742
+ >>> segment_from_threshold_at_position(pos, mode, config, threads=2)
743
+ # This will execute the segmentation script on the specified position directory with the given mode and
744
+ # configuration, utilizing 2 threads.
745
+
746
+ """
747
+
748
+ pos = pos.replace("\\", "/")
749
+ pos = rf"{pos}"
750
+ assert os.path.exists(pos), f"Position {pos} is not a valid path."
751
+
752
+ config = config.replace("\\", "/")
753
+ config = rf"{config}"
754
+ assert os.path.exists(config), f"Config {config} is not a valid path."
755
+
756
+ script_path = os.sep.join([abs_path, "scripts", "segment_cells_thresholds.py"])
757
+ cmd = f'python "{script_path}" --pos "{pos}" --config "{config}" --mode "{mode}" --threads "{threads}"'
758
+ subprocess.call(cmd, shell=True)
685
759
 
686
760
 
687
761
  def train_segmentation_model(config, use_gpu=True):
688
-
689
- """
690
- Trains a segmentation model based on a specified configuration file.
691
-
692
- This function initiates the training of a segmentation model by calling an external Python script,
693
- which reads the training parameters and dataset information from a given JSON configuration file.
694
- The training process, including model architecture, training data, and hyperparameters, is defined
695
- by the contents of the configuration file.
696
-
697
- Parameters
698
- ----------
699
- config : str
700
- The file path to the JSON configuration file that specifies training parameters and dataset
701
- information for the segmentation model. The path must be valid.
702
-
703
- Raises
704
- ------
705
- AssertionError
706
- If the `config` path does not exist.
707
-
708
- Notes
709
- -----
710
- - The external training script (`train_segmentation_model.py`) is assumed to be located in a specific
711
- directory relative to this function.
712
- - The segmentation model and training process are highly dependent on the details specified in the
713
- configuration file, including the model architecture, loss functions, optimizer settings, and
714
- training/validation data paths.
715
-
716
- Examples
717
- --------
718
- >>> config = '/path/to/training_config.json'
719
- >>> train_segmentation_model(config)
720
- # Initiates the training of a segmentation model using the parameters specified in the given configuration file.
721
-
722
- """
723
-
724
- config = config.replace('\\','/')
725
- config = rf"{config}"
726
- assert os.path.exists(config),f'Config {config} is not a valid path.'
727
-
728
- script_path = os.sep.join([abs_path, 'scripts', 'train_segmentation_model.py'])
729
- cmd = f'python "{script_path}" --config "{config}" --use_gpu "{use_gpu}"'
730
- subprocess.call(cmd, shell=True)
731
-
732
-
733
- def merge_instance_segmentation(labels, iou_matching_threshold=0.05, mode='OR'):
734
-
735
- label_reference = labels[0]
736
- for i in range(1,len(labels)):
737
-
738
- label_to_merge = labels[i]
739
- pairs = matching(label_reference,label_to_merge, thresh=0.5, criterion='iou', report_matches=True).matched_pairs
740
- scores = matching(label_reference,label_to_merge, thresh=0.5, criterion='iou', report_matches=True).matched_scores
741
-
742
- accepted_pairs = []
743
- for k,p in enumerate(pairs):
744
- s = scores[k]
745
- if s > iou_matching_threshold:
746
- accepted_pairs.append(p)
747
-
748
- merge = np.copy(label_reference)
749
-
750
- for p in accepted_pairs:
751
- merge[np.where(merge==p[0])] = 0.
752
- cdt1 = label_reference==p[0]
753
- cdt2 = label_to_merge==p[1]
754
- if mode=='OR':
755
- cdt = np.logical_or(cdt1, cdt2)
756
- elif mode=='AND':
757
- cdt = np.logical_and(cdt1, cdt2)
758
- elif mode=='XOR':
759
- cdt = np.logical_xor(cdt1,cdt2)
760
- loc_i, loc_j = np.where(cdt)
761
- merge[loc_i, loc_j] = p[0]
762
-
763
- cells_to_ignore = [p[1] for p in accepted_pairs]
764
- for c in cells_to_ignore:
765
- label_to_merge[label_to_merge==c] = 0
766
-
767
- label_to_merge[label_to_merge!=0] = label_to_merge[label_to_merge!=0] + int(np.amax(label_reference))
768
- merge[label_to_merge!=0] = label_to_merge[label_to_merge!=0]
769
-
770
- label_reference = merge
771
-
772
- merge = auto_correct_masks(merge)
773
-
774
- return merge
762
+ """
763
+ Trains a segmentation model based on a specified configuration file.
764
+
765
+ This function initiates the training of a segmentation model by calling an external Python script,
766
+ which reads the training parameters and dataset information from a given JSON configuration file.
767
+ The training process, including model architecture, training data, and hyperparameters, is defined
768
+ by the contents of the configuration file.
769
+
770
+ Parameters
771
+ ----------
772
+ config : str
773
+ The file path to the JSON configuration file that specifies training parameters and dataset
774
+ information for the segmentation model. The path must be valid.
775
+
776
+ Raises
777
+ ------
778
+ AssertionError
779
+ If the `config` path does not exist.
780
+
781
+ Notes
782
+ -----
783
+ - The external training script (`train_segmentation_model.py`) is assumed to be located in a specific
784
+ directory relative to this function.
785
+ - The segmentation model and training process are highly dependent on the details specified in the
786
+ configuration file, including the model architecture, loss functions, optimizer settings, and
787
+ training/validation data paths.
788
+
789
+ Examples
790
+ --------
791
+ >>> config = '/path/to/training_config.json'
792
+ >>> train_segmentation_model(config)
793
+ # Initiates the training of a segmentation model using the parameters specified in the given configuration file.
794
+
795
+ """
796
+
797
+ config = config.replace("\\", "/")
798
+ config = rf"{config}"
799
+ assert os.path.exists(config), f"Config {config} is not a valid path."
800
+
801
+ script_path = os.sep.join([abs_path, "scripts", "train_segmentation_model.py"])
802
+ cmd = f'python "{script_path}" --config "{config}" --use_gpu "{use_gpu}"'
803
+ subprocess.call(cmd, shell=True)
804
+
805
+
806
+ def merge_instance_segmentation(labels, iou_matching_threshold=0.05, mode="OR"):
807
+
808
+ label_reference = labels[0]
809
+ try:
810
+ from stardist.matching import matching
811
+ except ImportError:
812
+ logger.warning(
813
+ "StarDist not installed. Cannot perform instance matching/merging..."
814
+ )
815
+ return label_reference
816
+
817
+ for i in range(1, len(labels)):
818
+
819
+ label_to_merge = labels[i]
820
+ pairs = matching(
821
+ label_reference,
822
+ label_to_merge,
823
+ thresh=0.5,
824
+ criterion="iou",
825
+ report_matches=True,
826
+ ).matched_pairs
827
+ scores = matching(
828
+ label_reference,
829
+ label_to_merge,
830
+ thresh=0.5,
831
+ criterion="iou",
832
+ report_matches=True,
833
+ ).matched_scores
834
+
835
+ accepted_pairs = []
836
+ for k, p in enumerate(pairs):
837
+ s = scores[k]
838
+ if s > iou_matching_threshold:
839
+ accepted_pairs.append(p)
840
+
841
+ merge = np.copy(label_reference)
842
+
843
+ for p in accepted_pairs:
844
+ merge[np.where(merge == p[0])] = 0.0
845
+ cdt1 = label_reference == p[0]
846
+ cdt2 = label_to_merge == p[1]
847
+ if mode == "OR":
848
+ cdt = np.logical_or(cdt1, cdt2)
849
+ elif mode == "AND":
850
+ cdt = np.logical_and(cdt1, cdt2)
851
+ elif mode == "XOR":
852
+ cdt = np.logical_xor(cdt1, cdt2)
853
+ loc_i, loc_j = np.where(cdt)
854
+ merge[loc_i, loc_j] = p[0]
855
+
856
+ cells_to_ignore = [p[1] for p in accepted_pairs]
857
+ for c in cells_to_ignore:
858
+ label_to_merge[label_to_merge == c] = 0
859
+
860
+ label_to_merge[label_to_merge != 0] = label_to_merge[label_to_merge != 0] + int(
861
+ np.amax(label_reference)
862
+ )
863
+ merge[label_to_merge != 0] = label_to_merge[label_to_merge != 0]
864
+
865
+ label_reference = merge
866
+
867
+ merge = auto_correct_masks(merge)
868
+
869
+ return merge
775
870
 
776
871
 
777
872
  if __name__ == "__main__":
778
- print(segment(None,'test'))
873
+ print(segment(None, "test"))