celldetective 1.4.2__py3-none-any.whl → 1.5.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. celldetective/__init__.py +25 -0
  2. celldetective/__main__.py +62 -43
  3. celldetective/_version.py +1 -1
  4. celldetective/extra_properties.py +477 -399
  5. celldetective/filters.py +192 -97
  6. celldetective/gui/InitWindow.py +541 -411
  7. celldetective/gui/__init__.py +0 -15
  8. celldetective/gui/about.py +44 -39
  9. celldetective/gui/analyze_block.py +120 -84
  10. celldetective/gui/base/__init__.py +0 -0
  11. celldetective/gui/base/channel_norm_generator.py +335 -0
  12. celldetective/gui/base/components.py +249 -0
  13. celldetective/gui/base/feature_choice.py +92 -0
  14. celldetective/gui/base/figure_canvas.py +52 -0
  15. celldetective/gui/base/list_widget.py +133 -0
  16. celldetective/gui/{styles.py → base/styles.py} +92 -36
  17. celldetective/gui/base/utils.py +33 -0
  18. celldetective/gui/base_annotator.py +900 -767
  19. celldetective/gui/classifier_widget.py +6 -22
  20. celldetective/gui/configure_new_exp.py +777 -671
  21. celldetective/gui/control_panel.py +635 -524
  22. celldetective/gui/dynamic_progress.py +449 -0
  23. celldetective/gui/event_annotator.py +2023 -1662
  24. celldetective/gui/generic_signal_plot.py +1292 -944
  25. celldetective/gui/gui_utils.py +899 -1289
  26. celldetective/gui/interactions_block.py +658 -0
  27. celldetective/gui/interactive_timeseries_viewer.py +447 -0
  28. celldetective/gui/json_readers.py +48 -15
  29. celldetective/gui/layouts/__init__.py +5 -0
  30. celldetective/gui/layouts/background_model_free_layout.py +537 -0
  31. celldetective/gui/layouts/channel_offset_layout.py +134 -0
  32. celldetective/gui/layouts/local_correction_layout.py +91 -0
  33. celldetective/gui/layouts/model_fit_layout.py +372 -0
  34. celldetective/gui/layouts/operation_layout.py +68 -0
  35. celldetective/gui/layouts/protocol_designer_layout.py +96 -0
  36. celldetective/gui/pair_event_annotator.py +3130 -2435
  37. celldetective/gui/plot_measurements.py +586 -267
  38. celldetective/gui/plot_signals_ui.py +724 -506
  39. celldetective/gui/preprocessing_block.py +395 -0
  40. celldetective/gui/process_block.py +1678 -1831
  41. celldetective/gui/seg_model_loader.py +580 -473
  42. celldetective/gui/settings/__init__.py +0 -7
  43. celldetective/gui/settings/_cellpose_model_params.py +181 -0
  44. celldetective/gui/settings/_event_detection_model_params.py +95 -0
  45. celldetective/gui/settings/_segmentation_model_params.py +159 -0
  46. celldetective/gui/settings/_settings_base.py +77 -65
  47. celldetective/gui/settings/_settings_event_model_training.py +752 -526
  48. celldetective/gui/settings/_settings_measurements.py +1133 -964
  49. celldetective/gui/settings/_settings_neighborhood.py +574 -488
  50. celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
  51. celldetective/gui/settings/_settings_signal_annotator.py +329 -305
  52. celldetective/gui/settings/_settings_tracking.py +1304 -1094
  53. celldetective/gui/settings/_stardist_model_params.py +98 -0
  54. celldetective/gui/survival_ui.py +422 -312
  55. celldetective/gui/tableUI.py +1665 -1701
  56. celldetective/gui/table_ops/_maths.py +295 -0
  57. celldetective/gui/table_ops/_merge_groups.py +140 -0
  58. celldetective/gui/table_ops/_merge_one_hot.py +95 -0
  59. celldetective/gui/table_ops/_query_table.py +43 -0
  60. celldetective/gui/table_ops/_rename_col.py +44 -0
  61. celldetective/gui/thresholds_gui.py +382 -179
  62. celldetective/gui/viewers/__init__.py +0 -0
  63. celldetective/gui/viewers/base_viewer.py +700 -0
  64. celldetective/gui/viewers/channel_offset_viewer.py +331 -0
  65. celldetective/gui/viewers/contour_viewer.py +394 -0
  66. celldetective/gui/viewers/size_viewer.py +153 -0
  67. celldetective/gui/viewers/spot_detection_viewer.py +341 -0
  68. celldetective/gui/viewers/threshold_viewer.py +309 -0
  69. celldetective/gui/workers.py +403 -126
  70. celldetective/log_manager.py +92 -0
  71. celldetective/measure.py +1895 -1478
  72. celldetective/napari/__init__.py +0 -0
  73. celldetective/napari/utils.py +1025 -0
  74. celldetective/neighborhood.py +1914 -1448
  75. celldetective/preprocessing.py +1620 -1220
  76. celldetective/processes/__init__.py +0 -0
  77. celldetective/processes/background_correction.py +271 -0
  78. celldetective/processes/compute_neighborhood.py +894 -0
  79. celldetective/processes/detect_events.py +246 -0
  80. celldetective/processes/downloader.py +137 -0
  81. celldetective/processes/measure_cells.py +565 -0
  82. celldetective/processes/segment_cells.py +760 -0
  83. celldetective/processes/track_cells.py +435 -0
  84. celldetective/processes/train_segmentation_model.py +694 -0
  85. celldetective/processes/train_signal_model.py +265 -0
  86. celldetective/processes/unified_process.py +292 -0
  87. celldetective/regionprops/_regionprops.py +358 -317
  88. celldetective/relative_measurements.py +987 -710
  89. celldetective/scripts/measure_cells.py +313 -212
  90. celldetective/scripts/measure_relative.py +90 -46
  91. celldetective/scripts/segment_cells.py +165 -104
  92. celldetective/scripts/segment_cells_thresholds.py +96 -68
  93. celldetective/scripts/track_cells.py +198 -149
  94. celldetective/scripts/train_segmentation_model.py +324 -201
  95. celldetective/scripts/train_signal_model.py +87 -45
  96. celldetective/segmentation.py +844 -749
  97. celldetective/signals.py +3514 -2861
  98. celldetective/tracking.py +30 -15
  99. celldetective/utils/__init__.py +0 -0
  100. celldetective/utils/cellpose_utils/__init__.py +133 -0
  101. celldetective/utils/color_mappings.py +42 -0
  102. celldetective/utils/data_cleaning.py +630 -0
  103. celldetective/utils/data_loaders.py +450 -0
  104. celldetective/utils/dataset_helpers.py +207 -0
  105. celldetective/utils/downloaders.py +235 -0
  106. celldetective/utils/event_detection/__init__.py +8 -0
  107. celldetective/utils/experiment.py +1782 -0
  108. celldetective/utils/image_augmenters.py +308 -0
  109. celldetective/utils/image_cleaning.py +74 -0
  110. celldetective/utils/image_loaders.py +926 -0
  111. celldetective/utils/image_transforms.py +335 -0
  112. celldetective/utils/io.py +62 -0
  113. celldetective/utils/mask_cleaning.py +348 -0
  114. celldetective/utils/mask_transforms.py +5 -0
  115. celldetective/utils/masks.py +184 -0
  116. celldetective/utils/maths.py +351 -0
  117. celldetective/utils/model_getters.py +325 -0
  118. celldetective/utils/model_loaders.py +296 -0
  119. celldetective/utils/normalization.py +380 -0
  120. celldetective/utils/parsing.py +465 -0
  121. celldetective/utils/plots/__init__.py +0 -0
  122. celldetective/utils/plots/regression.py +53 -0
  123. celldetective/utils/resources.py +34 -0
  124. celldetective/utils/stardist_utils/__init__.py +104 -0
  125. celldetective/utils/stats.py +90 -0
  126. celldetective/utils/types.py +21 -0
  127. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/METADATA +1 -1
  128. celldetective-1.5.0b1.dist-info/RECORD +187 -0
  129. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/WHEEL +1 -1
  130. tests/gui/test_new_project.py +129 -117
  131. tests/gui/test_project.py +127 -79
  132. tests/test_filters.py +39 -15
  133. tests/test_notebooks.py +8 -0
  134. tests/test_tracking.py +232 -13
  135. tests/test_utils.py +123 -77
  136. celldetective/gui/base_components.py +0 -23
  137. celldetective/gui/layouts.py +0 -1602
  138. celldetective/gui/processes/compute_neighborhood.py +0 -594
  139. celldetective/gui/processes/downloader.py +0 -111
  140. celldetective/gui/processes/measure_cells.py +0 -360
  141. celldetective/gui/processes/segment_cells.py +0 -499
  142. celldetective/gui/processes/track_cells.py +0 -303
  143. celldetective/gui/processes/train_segmentation_model.py +0 -270
  144. celldetective/gui/processes/train_signal_model.py +0 -108
  145. celldetective/gui/table_ops/merge_groups.py +0 -118
  146. celldetective/gui/viewers.py +0 -1354
  147. celldetective/io.py +0 -3663
  148. celldetective/utils.py +0 -3108
  149. celldetective-1.4.2.dist-info/RECORD +0 -123
  150. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/entry_points.txt +0 -0
  151. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
  152. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,348 @@
1
+ import concurrent.futures
2
+ import threading
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ from skimage.measure import regionprops_table, label
7
+ from skimage.transform import resize
8
+ from tqdm import tqdm
9
+
10
+ from celldetective.utils.image_loaders import load_frames
11
+ from scipy.ndimage import binary_fill_holes
12
+ from scipy.ndimage import find_objects
13
+
14
+
15
def fill_label_holes(lbl_img, **kwargs):
    """Fill holes inside every labelled object of a label image.

    Adapted from https://github.com/stardist/stardist/blob/main/stardist/utils.py

    Each positive label value is processed on its own crop. The crop is the
    object's bounding box, grown by one pixel on every side that does not
    already touch the image border. The crop is hole-filled with
    ``scipy.ndimage.binary_fill_holes``, and the filled pixels are written back
    with the original label value.

    Negative labels are ignored by ``find_objects``. They are handled by a
    recursive call on their negation and then restored with their original sign.

    Parameters
    ----------
    lbl_img : np.ndarray
        Integer label image (0 is background).
    **kwargs
        Extra keyword arguments forwarded to ``binary_fill_holes`` (e.g.
        ``structure``). They apply to positive and negative labels alike.

    Returns
    -------
    np.ndarray
        A new label image with the same shape and dtype as ``lbl_img``.
    """

    # TODO: refactor 'fill_label_holes' and 'edt_prob' to share code
    def grow(sl, interior):
        # Extend each slice by one pixel on sides flagged as "interior" (not on the image border).
        return tuple(
            slice(s.start - int(w[0]), s.stop + int(w[1])) for s, w in zip(sl, interior)
        )

    def shrink(interior):
        # Inverse of `grow`, expressed relative to the grown crop.
        return tuple(slice(int(w[0]), (-1 if w[1] else None)) for w in interior)

    objects = find_objects(lbl_img)
    lbl_img_filled = np.zeros_like(lbl_img)
    for i, sl in enumerate(objects, 1):
        if sl is None:  # label value `i` does not occur in the image
            continue
        interior = [(s.start > 0, s.stop < sz) for s, sz in zip(sl, lbl_img.shape)]
        shrink_slice = shrink(interior)
        grown_mask = lbl_img[grow(sl, interior)] == i
        mask_filled = binary_fill_holes(grown_mask, **kwargs)[shrink_slice]
        lbl_img_filled[sl][mask_filled] = i
    if lbl_img.size and lbl_img.min() < 0:
        # Preserve (and fill holes in) negative labels, which 'find_objects' ignores.
        # Forward kwargs so negative labels use the same fill settings as positive ones.
        lbl_neg_filled = -fill_label_holes(-np.minimum(lbl_img, 0), **kwargs)
        mask = lbl_neg_filled < 0
        lbl_img_filled[mask] = lbl_neg_filled[mask]
    return lbl_img_filled
45
+
46
+
47
+ def _check_label_dims(lbl, file=None, template=None):
48
+
49
+ if file is not None:
50
+ template = load_frames(0, file, scale=1, normalize_input=False)
51
+ elif template is not None:
52
+ template = template
53
+ else:
54
+ return lbl
55
+
56
+ if lbl.shape != template.shape[:2]:
57
+ lbl = resize(lbl, template.shape[:2], order=0)
58
+ return lbl
59
+
60
+
61
def auto_correct_masks(
    masks, bbox_factor: float = 1.75, min_area: int = 9, fill_labels: bool = False
):
    """
    Correct segmentation masks to ensure consistency and remove anomalies.

    This function processes a labeled mask image to correct anomalies and reassign labels.
    It performs the following operations:

    1. Corrects negative mask values by taking their absolute values.
    2. Identifies and corrects segmented objects with a bounding box area that is disproportionately
       larger than the actual object area. This indicates potential segmentation errors where separate objects
       share the same label.
    3. Removes small objects that are considered noise (default threshold is an area of less than 9 pixels).
    4. Reorders the labels so they are consecutive from 1 up to the number of remaining objects (to avoid encoding errors).

    Parameters
    ----------
    masks : np.ndarray
        A 2D array representing the segmented mask image with labeled regions. Each unique value
        in the array represents a different object or cell. The array is not modified.
    bbox_factor : float, optional
        A factor on cell area that is compared directly to the bounding box area of the cell, to detect remote cells
        sharing a same label value. The default is `1.75`.
    min_area : int, optional
        Discard cells that have an area smaller than this minimum area (px²). The default is `9` (3x3 pixels).
    fill_labels : bool, optional
        Fill holes within cell masks automatically. The default is `False`.

    Returns
    -------
    clean_labels : np.ndarray
        A corrected version of the input mask (integer dtype), with anomalies corrected, small objects removed,
        and labels reordered to be consecutive integers.

    Notes
    -----
    - This function is useful for post-processing segmentation outputs to ensure high-quality
      object detection, particularly in applications such as cell segmentation in microscopy images.
    - The function assumes that the input masks contain integer labels and that the background
      is represented by 0. Images without any background pixel are also handled: every remaining
      object still receives a distinct label in 1..N.
    - Anomalous objects are split into 8-connected components (``connectivity=2``).

    Examples
    --------
    >>> masks = np.array([[0, 0, 1, 1], [0, 2, 2, 1], [0, 2, 0, 0]])
    >>> corrected_masks = auto_correct_masks(masks, min_area=1)
    >>> corrected_masks
    array([[0, 0, 1, 1],
           [0, 2, 2, 1],
           [0, 2, 0, 0]])
    """

    assert masks.ndim == 2, "`masks` should be a 2D numpy array..."

    # Avoid negative mask values. np.abs returns a new array, so the caller's input is left untouched.
    masks = np.abs(masks)

    props = pd.DataFrame(
        regionprops_table(masks, properties=("label", "area", "area_bbox"))
    )
    max_lbl = props["label"].max()
    corrected_lbl = masks.copy()

    # 1. A bounding box much larger than the object suggests several remote pieces
    #    sharing one label: split them into connected components and give each
    #    component a fresh label above the current maximum.
    is_anomalous = props["area_bbox"] > bbox_factor * props["area"]
    for cell in props.loc[is_anomalous, "label"]:
        cell_mask = masks == cell
        components = label(cell_mask.astype(int), connectivity=2)
        corrected_lbl[cell_mask] = components[cell_mask] + max_lbl
        max_lbl = np.amax(corrected_lbl)

    # 2. Eliminate objects that are too small.
    props2 = pd.DataFrame(
        regionprops_table(corrected_lbl, properties=("label", "area"))
    )
    too_small = props2.loc[props2["area"] < min_area, "label"].to_numpy()
    corrected_lbl[np.isin(corrected_lbl, too_small)] = 0

    # 3. Reorder labels to 1..N while preserving their relative order.
    #    np.unique sorts ascending, so background 0 (when present) maps to index 0;
    #    without any background pixel, every index is shifted by one so no object gets 0.
    unique_ids, inverse = np.unique(corrected_lbl, return_inverse=True)
    inverse = inverse.reshape(corrected_lbl.shape)
    offset = 0 if unique_ids.size and unique_ids[0] == 0 else 1
    clean_labels = (inverse + offset).astype(int)

    if fill_labels:
        clean_labels = fill_label_holes(clean_labels)

    return clean_labels
167
+
168
+
169
def relabel_segmentation(
    labels,
    df,
    exclude_nans=True,
    column_labels=None,
    threads=1,
    progress_callback=None,
):
    """
    Relabel the segmentation labels with the tracking IDs from the tracks.

    The function reassigns the mask value for each cell with the associated `TRACK_ID`, if it exists
    in the trajectory table (`df`). If no track uses the cell mask, a new track with a single point
    is generated on the fly. Its ID is (max of `TRACK_ID` values) + 1 + i, for i=0 to N-1 such cells,
    or starts at 1 when `df` holds no track ID, so generated IDs never collide with an existing track.
    It supports multithreaded processing for faster execution on large datasets.

    Parameters
    ----------
    labels : np.ndarray
        A (TYX) array where each frame contains a 2D segmentation mask. Each unique
        non-zero integer represents a labeled object.
    df : pandas.DataFrame
        A DataFrame containing tracking information with columns
        specified in `column_labels`. Must include at least frame, track ID, and object ID.
    exclude_nans : bool, optional
        Whether to exclude rows in `df` with NaN values in the column specified by
        `column_labels['label']`. Default is `True`.
    column_labels : dict, optional
        A dictionary specifying the column names in `df`. When `None` (default), uses:
            - `'track'`: Track ID column name (`"TRACK_ID"`)
            - `'frame'`: Frame column name (`"FRAME"`)
            - `'y'`: Y-coordinate column name (`"POSITION_Y"`)
            - `'x'`: X-coordinate column name (`"POSITION_X"`)
            - `'label'`: Object ID column name (`"class_id"`)
        Only the `'track'`, `'frame'` and `'label'` entries are read.
    threads : int, optional
        Number of threads to use for multithreaded processing. Default is `1`.
    progress_callback : callable, optional
        Called once per frame, before that frame is relabeled, with an integer
        percentage of processed frames. If it returns a falsy value, relabeling is
        cancelled. When given, the console tqdm progress bar is not shown.

    Returns
    -------
    np.ndarray or None
        A new (TYX) array with the same shape as `labels`, where objects are relabeled
        according to their tracking identity in `df`. `None` if the run was cancelled
        through `progress_callback`.

    Notes
    -----
    - Only frames that appear in `df` are processed. Objects in any other frame stay 0
      (background) in the output.
    - Within a processed frame, labeled objects that do not match any entry of `df`
      receive new track IDs.
    - A single shared lock guards the counter used to generate new track IDs, so
      concurrent threads never hand out the same ID twice.
    - An exception raised in a worker thread is printed, not re-raised; the array
      relabeled so far is then returned.

    Examples
    --------
    Relabel segmentation using tracking data:

    >>> labels = np.random.randint(0, 5, (10, 100, 100))
    >>> df = pd.DataFrame({
    ...     "TRACK_ID": [1, 2, 1, 2],
    ...     "FRAME": [0, 0, 1, 1],
    ...     "class_id": [1, 2, 1, 2],
    ... })
    >>> new_labels = relabel_segmentation(labels, df, threads=2)
    Done.

    Use custom column labels and exclude rows with NaNs:

    >>> column_labels = {
    ...     'track': "track_id",
    ...     'frame': "time",
    ...     'label': "object_id"
    ... }
    >>> df2 = df.rename(columns={"TRACK_ID": "track_id", "FRAME": "time", "class_id": "object_id"})
    >>> new_labels = relabel_segmentation(labels, df2, column_labels=column_labels, exclude_nans=True)
    Done.

    """
    if column_labels is None:
        column_labels = {
            "track": "TRACK_ID",
            "frame": "FRAME",
            "y": "POSITION_Y",
            "x": "POSITION_X",
            "label": "class_id",
        }
    track_col = column_labels["track"]
    frame_col = column_labels["frame"]
    label_col = column_labels["label"]

    df = df.sort_values(by=[track_col, frame_col])
    if exclude_nans:
        df = df.dropna(subset=label_col)

    new_labels = np.zeros_like(labels)

    # New IDs start right after the largest existing track ID (computed once, not per thread).
    existing_track_ids = df[track_col].dropna().unique()
    next_track_id = {
        "value": (existing_track_ids.max() + 1) if len(existing_track_ids) else 1
    }
    id_lock = threading.Lock()

    # Progress / cancellation state shared by all worker threads.
    progress = {"val": 0, "cancelled": False}
    progress_lock = threading.Lock()

    frame_values = list(df[frame_col].dropna().unique())
    total_frames = len(frame_values)

    def _is_cancelled():
        with progress_lock:
            return progress["cancelled"]

    def _report_progress():
        """Advance the frame counter and notify the callback; return False on cancellation."""
        with progress_lock:
            progress["val"] += 1
            percent = int((progress["val"] / total_frames) * 100)
        if progress_callback(percent):
            return True
        with progress_lock:
            progress["cancelled"] = True
        return False

    def _new_track_id():
        """Atomically reserve the next unused track ID."""
        with id_lock:
            track_id = next_track_id["value"]
            next_track_id["value"] += 1
        return track_id

    def rewrite_labels(frame_indices):
        # Stop immediately if another worker already received a cancel request.
        if progress_callback is not None and _is_cancelled():
            return

        # With a GUI callback the console bar would be disabled anyway: iterate directly.
        frame_iter = (
            frame_indices if progress_callback is not None else tqdm(frame_indices)
        )
        for t in frame_iter:
            if progress_callback is not None and (
                _is_cancelled() or not _report_progress()
            ):
                return

            f = int(t)
            cells = df.loc[df[frame_col] == f, [track_col, label_col]].to_numpy()
            tracks_at_t = list(cells[:, 0])
            identities = list(cells[:, 1])

            # Objects present in the mask but absent from the table get new track IDs.
            known_identities = set(identities)
            for lbl in np.unique(labels[f]):
                if lbl != 0 and lbl not in known_identities:
                    tracks_at_t.append(_new_track_id())
                    identities.append(lbl)

            tracks_at_t = np.asarray(tracks_at_t)
            identities = np.asarray(identities)

            # Exclude NaN identities (NaN != NaN).
            keep = identities == identities
            for track_id, identity in zip(tracks_at_t[keep], identities[keep]):
                if track_id == track_id:  # skip NaN track IDs
                    new_labels[f][labels[f] == identity] = round(track_id)

    # Multithreading: each thread handles a disjoint chunk of frames.
    chunks = np.array_split(frame_values, threads)

    with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
        results = executor.map(rewrite_labels, chunks)
        try:
            for _ in results:
                pass
        except Exception as e:
            print("Exception: ", e)

    if progress["cancelled"]:
        print("Relabeling cancelled.")
        return None

    print("\nDone.")

    return new_labels
@@ -0,0 +1,5 @@
1
+ from scipy.ndimage import zoom
2
+
3
+
4
+ def _rescale_labels(lbl, scale_model=1):
5
+ return zoom(lbl, [1.0 / scale_model, 1.0 / scale_model], order=0)
@@ -0,0 +1,184 @@
1
+ from typing import Union, List, Tuple
2
+
3
+ import numpy as np
4
+ from scipy import ndimage
5
+ from scipy.ndimage.morphology import distance_transform_edt
6
+ from skimage.morphology import disk
7
+ from celldetective.log_manager import get_logger
8
+
9
+ logger = get_logger(__name__)
10
+
11
+
12
def contour_of_instance_segmentation(label, distance, sdf=None, voronoi_map=None):
    """

    Generate an instance mask containing the contour of the segmented objects.

    This updated version uses a Signed Distance Field (SDF) and a Voronoi tessellation approach.
    It is generic enough to handle inner contours, outer contours, and arbitrary "stripes" (annuli).

    Parameters
    ----------
    label : ndarray
        The 2D instance segmentation labels (0 is background).
    distance : int, float, list, tuple, or str
        The distance specification.
            - Scalar >= 0 (Python or NumPy number): Inner contour (Erosion) from boundary to 'distance' pixels inside. Range [0, distance].
            - Scalar < 0: Outer contour (Dilation) from 'distance' pixels outside to boundary. Range [distance, 0].
            - Tuple/List (a, b): Explicit range in SDF space; the bounds may be given in either order.
                - Positive values are inside the object.
                - Negative values are outside the object.
                - Example: (5, 10) -> Inner ring 5 to 10px deep.
                - Example: (-10, -5) -> Outer ring 5 to 10px away.
            - String: "a" is parsed as a scalar and "(a, b)" as a tuple.
              "rad1-rad2" is the interpretation used for Batch Measurements (Outer Rings):
                - Interpreted as an annular region OUTSIDE the object.
                - "5-10" -> Range [-10, -5] in SDF space (5 to 10px away).
        Unsupported types and unparsable strings yield an all-zero mask.
    sdf : ndarray, optional
        Pre-computed Signed Distance Field (dist_in - dist_out).
        Only used together with `voronoi_map`; if either is missing, both are recomputed.
    voronoi_map : ndarray, optional
        Pre-computed Voronoi map of instance labels.
        Required if sdf is provided and outer contours are needed.

    Returns
    -------
    border_label : ndarray
        An instance mask containing the contour of the segmented objects.
        Outer contours preserve instance identity via Voronoi propagation.

    """
    import numbers

    from scipy.ndimage import distance_transform_edt

    # Parse `distance` into an SDF interval [min_r, max_r].
    if isinstance(distance, str):
        try:
            distance = distance.strip()
            if distance.startswith("(") and distance.endswith(")"):
                # Stringified tuple "(a, b)"
                import ast

                val_tuple = ast.literal_eval(distance)
                if isinstance(val_tuple, (list, tuple)) and len(val_tuple) == 2:
                    min_r = val_tuple[0]
                    max_r = val_tuple[1]
                else:
                    raise ValueError("Tuple string must have 2 elements")
            else:
                try:
                    # Scalar string like "5" or "-5"
                    val = float(distance)
                    if val >= 0:
                        min_r, max_r = 0, val
                    else:
                        min_r, max_r = val, 0
                except ValueError:
                    # Range string "A-B": A, B positive radii of an OUTER annulus.
                    parts = distance.split("-")
                    r1 = float(parts[0])
                    r2 = float(parts[1])
                    min_r = -max(r1, r2)
                    max_r = -min(r1, r2)

        except Exception:
            logger.warning(
                f"Could not parse contour string '{distance}'. returning empty."
            )
            return np.zeros_like(label)

    elif isinstance(distance, (list, tuple)):
        min_r = distance[0]
        max_r = distance[1]

    elif isinstance(distance, numbers.Real):
        # numbers.Real also covers NumPy integer/float scalars, not only int/float.
        if distance >= 0:
            min_r, max_r = 0, distance
        else:
            min_r, max_r = distance, 0
    else:
        return np.zeros_like(label)

    # Accept bounds in either order, as the "rad1-rad2" string form already does.
    if min_r > max_r:
        min_r, max_r = max_r, min_r

    if sdf is None or voronoi_map is None:
        # SDF = dist_in - dist_out: > 0 inside objects, < 0 outside.
        dist_in = distance_transform_edt(label > 0)

        # Distance to the nearest object pixel, plus that pixel's coordinates.
        dist_out, indices = distance_transform_edt(label == 0, return_indices=True)

        # Voronoi map: each pixel takes the label of its nearest object pixel.
        voronoi_map = label[indices[0], indices[1]]

        sdf = dist_in - dist_out

    mask = (sdf >= min_r) & (sdf <= max_r)

    border_label = voronoi_map * mask

    return border_label
129
+
130
+
131
def create_patch_mask(h, w, center=None, radius=None):
    """

    Create a circular patch mask of given dimensions.
    Adapted from alkasm on https://stackoverflow.com/questions/44865023/how-can-i-create-a-circular-mask-for-a-numpy-array

    Parameters
    ----------
    h : int
        Height of the mask. Prefer odd value.
    w : int
        Width of the mask. Prefer odd value.
    center : tuple, optional
        ``(x, y)`` coordinates of the center of the patch (x along the width, y along the height).
        If not provided, ``(int(w / 2), int(h / 2))`` is used.
    radius : int or float or list or tuple, optional
        Radius of the circular patch (Python or NumPy real number). If not provided, the smallest distance between the center and image walls is used.
        If a list or tuple is provided, it should contain two elements representing the inner and outer radii of a circular annular patch.

    Returns
    -------
    numpy.ndarray or None
        Boolean mask of shape ``(h, w)`` where True values represent pixels within the circular patch or annular patch (bounds inclusive),
        and False values represent pixels outside. ``None`` (after printing a message) if ``radius`` has an unsupported format.

    Notes
    -----
    The function creates a circular patch mask of the given dimensions by determining which pixels fall within the circular patch or annular patch.
    The circular patch or annular patch is centered at the specified coordinates or at the middle of the image if coordinates are not provided.
    The radius of the circular patch or annular patch is determined by the provided radius parameter or by the minimum distance between the center and image walls.
    If an annular patch is desired, the radius parameter should be a list or tuple containing the inner and outer radii respectively.

    Examples
    --------
    >>> mask = create_patch_mask(100, 100, center=(50, 50), radius=30)
    >>> print(mask)

    """
    import numbers

    if center is None:  # use the middle of the image
        center = (int(w / 2), int(h / 2))
    if radius is None:  # use the smallest distance between the center and image walls
        radius = min(center[0], center[1], w - center[0], h - center[1])

    Y, X = np.ogrid[:h, :w]
    dist_from_center = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)

    if isinstance(radius, numbers.Real):
        # numbers.Real also accepts NumPy scalars such as np.int64 / np.float32.
        mask = dist_from_center <= radius
    elif isinstance(radius, (list, tuple)):
        inner, outer = radius[0], radius[1]
        mask = (dist_from_center <= outer) & (dist_from_center >= inner)
    else:
        print("Please provide a proper format for the radius")
        return None

    return mask