ocdkit-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. ocdkit/__init__.py +10 -0
  2. ocdkit/array/__init__.py +3 -0
  3. ocdkit/array/convert.py +121 -0
  4. ocdkit/array/filters.py +56 -0
  5. ocdkit/array/imports.py +4 -0
  6. ocdkit/array/index.py +242 -0
  7. ocdkit/array/morphology.py +194 -0
  8. ocdkit/array/normalize.py +425 -0
  9. ocdkit/array/ops.py +134 -0
  10. ocdkit/array/spatial.py +410 -0
  11. ocdkit/array/transform.py +261 -0
  12. ocdkit/array/union_find.py +52 -0
  13. ocdkit/array/warp.py +28 -0
  14. ocdkit/imports.py +8 -0
  15. ocdkit/io/__init__.py +3 -0
  16. ocdkit/io/files.py +141 -0
  17. ocdkit/io/image.py +138 -0
  18. ocdkit/io/imports.py +4 -0
  19. ocdkit/io/path.py +68 -0
  20. ocdkit/io/result.py +34 -0
  21. ocdkit/load/__init__.py +5 -0
  22. ocdkit/load/module.py +132 -0
  23. ocdkit/load/object.py +136 -0
  24. ocdkit/logging/__init__.py +3 -0
  25. ocdkit/logging/handler.py +206 -0
  26. ocdkit/measure/__init__.py +3 -0
  27. ocdkit/measure/bbox.py +188 -0
  28. ocdkit/measure/diameter.py +185 -0
  29. ocdkit/measure/imports.py +4 -0
  30. ocdkit/measure/medoid.py +181 -0
  31. ocdkit/measure/metrics.py +43 -0
  32. ocdkit/plot/__init__.py +5 -0
  33. ocdkit/plot/color.py +215 -0
  34. ocdkit/plot/contour.py +102 -0
  35. ocdkit/plot/defaults.py +147 -0
  36. ocdkit/plot/display.py +133 -0
  37. ocdkit/plot/export.py +108 -0
  38. ocdkit/plot/figure.py +24 -0
  39. ocdkit/plot/grid.py +306 -0
  40. ocdkit/plot/imports.py +9 -0
  41. ocdkit/plot/label.py +733 -0
  42. ocdkit/plot/ncolor.py +54 -0
  43. ocdkit/utils/__init__.py +3 -0
  44. ocdkit/utils/collections.py +97 -0
  45. ocdkit/utils/gpu.py +210 -0
  46. ocdkit/utils/kwargs.py +136 -0
  47. ocdkit-0.0.1.dist-info/METADATA +66 -0
  48. ocdkit-0.0.1.dist-info/RECORD +51 -0
  49. ocdkit-0.0.1.dist-info/WHEEL +5 -0
  50. ocdkit-0.0.1.dist-info/licenses/LICENSE +28 -0
  51. ocdkit-0.0.1.dist-info/top_level.txt +1 -0
ocdkit/__init__.py ADDED
@@ -0,0 +1,10 @@
1
"""
ocdkit — Obsessively precise utilities for scientific Python.

Array manipulation, GPU dispatch, image I/O, morphology, and plotting
tools shared across projects.
"""

from .load import enable_submodules

# Hand this package's name to the project helper ``enable_submodules``
# (defined in ocdkit/load, outside this view).
# NOTE(review): presumably this exposes subpackages such as ``ocdkit.array``
# on attribute access — confirm against ocdkit/load/module.py.
enable_submodules(__name__)
ocdkit/array/__init__.py ADDED
@@ -0,0 +1,3 @@
1
# Package initializer for the ``ocdkit.array`` subpackage.
from ..load import enable_submodules

# Same registration hook the top-level ``ocdkit`` package uses, applied to
# this subpackage. NOTE(review): exact semantics live in ocdkit/load — not
# visible here.
enable_submodules(__name__)
ocdkit/array/convert.py ADDED
@@ -0,0 +1,121 @@
1
+ """Type detection, conversion, and rescaling utilities."""
2
+
3
+ from .imports import *
4
+
5
+
6
def get_module(x):
    """Pick the array library that should operate on *x*.

    Returns ``np`` for dask arrays, NumPy arrays, tuples, Python ``int`` /
    ``float`` values and anything ``np.isscalar`` accepts; returns
    ``torch`` for torch tensors.

    Raises
    ------
    ValueError
        For any other input type.
    """
    numpy_handled = (da.Array, np.ndarray, tuple, int, float)
    if isinstance(x, numpy_handled) or np.isscalar(x):
        return np
    if torch.is_tensor(x):
        return torch
    raise ValueError(
        "Input must be a numpy array, a tuple, a torch tensor, "
        "an integer, or a float"
    )
18
+
19
+
20
def safe_divide(num, den, cutoff=0):
    """Divide *num* by *den*, writing 0 wherever the denominator is invalid.

    A denominator entry is valid only when it is finite and strictly greater
    than *cutoff*. With the default ``cutoff=0`` this means zero, negative,
    NaN and infinite denominators all produce 0 in the result.

    Parameters
    ----------
    num : ndarray, dask array, torch.Tensor, tuple, or scalar
        Numerator; its type selects the backend via :func:`get_module`.
    den : array, tensor, or scalar
        Denominator; must broadcast against *num*.
    cutoff : float, optional
        Denominators ``<= cutoff`` are treated as invalid.

    Returns
    -------
    ndarray, dask array, or torch.Tensor
        The quotient. NumPy and torch results are ``float32``; the dask
        path returns ``da.where(valid, num / den, 0)``.

    Raises
    ------
    ValueError
        From :func:`get_module` when *num* has an unsupported type.
    """
    module = get_module(num)

    if module is torch and not torch.is_tensor(den):
        # torch.isfinite and Tensor.float need a tensor; promote scalars or
        # arrays so the torch path accepts the same denominators as NumPy.
        den = torch.as_tensor(den, device=num.device)

    valid_den = (den > cutoff) & module.isfinite(den)

    if isinstance(num, da.Array) or isinstance(den, da.Array):
        return da.where(valid_den, num / den, 0)
    elif module is np:
        # asanyarray lets scalars/tuples (accepted by get_module) through;
        # calling .astype on them directly raised AttributeError.
        r = np.asanyarray(num).astype(np.float32, copy=False)
        # Size the output for the broadcast result, not just for *num*.
        out_shape = np.broadcast_shapes(r.shape, np.shape(den))
        if out_shape == r.shape:
            out = np.zeros_like(r)
        else:
            out = np.zeros(out_shape, dtype=np.float32)
        r = np.divide(r, den, out=out, where=valid_den)
    elif module is torch:
        r = num.float()
        den = den.float()
        # Substitute 1 for invalid denominators so the division itself never
        # produces inf/NaN; those positions are then overwritten with 0.
        safe_den = torch.where(valid_den, den, torch.ones_like(den))
        r = torch.where(valid_den, torch.div(r, safe_den), torch.zeros_like(r))
    else:
        raise TypeError("num must be a numpy array or a PyTorch tensor")

    return r
39
+
40
+
41
def rescale(T, floor=None, ceiling=None, exclude_dims=None):
    """Min-max rescale *T* to [0, 1].

    Works on numpy arrays and torch tensors. When *exclude_dims* is given,
    the min/max are computed separately for every index along those axes
    (the reduction runs over all *other* axes).

    Parameters
    ----------
    T : ndarray or torch.Tensor
        Data to rescale.
    floor, ceiling : scalar or array, optional
        Values mapped to 0 and 1. Computed from *T* when omitted.
    exclude_dims : int or sequence of int, optional
        Axes normalized independently. Negative values count from the end,
        as in NumPy.

    Returns
    -------
    Result of :func:`safe_divide` applied to ``T - floor`` and
    ``ceiling - floor``; entries whose range is not positive and finite
    become 0.
    """
    module = get_module(T)
    if exclude_dims is not None:
        if isinstance(exclude_dims, (int, np.integer)):
            exclude_dims = (exclude_dims,)
        # Normalize negative axes: previously ``exclude_dims=-1`` matched no
        # axis at all and silently produced a global normalization.
        exclude_dims = tuple(d + T.ndim if d < 0 else d for d in exclude_dims)
        axes = tuple(i for i in range(T.ndim) if i not in exclude_dims)
        newshape = [T.shape[i] if i in exclude_dims else 1 for i in range(T.ndim)]
    else:
        axes = None
        newshape = None

    def _reduce(fn):
        # For a global reduction, omit the axis argument entirely: both
        # np.amax and torch.amax reduce over everything by default, whereas
        # ``axis=None`` is a NumPy-only convention.
        return fn(T) if axes is None else fn(T, axis=axes)

    if ceiling is None:
        ceiling = _reduce(module.amax)
        if exclude_dims is not None:
            ceiling = ceiling.reshape(*newshape)
    if floor is None:
        floor = _reduce(module.amin)
        if exclude_dims is not None:
            floor = floor.reshape(*newshape)

    return safe_divide(T - floor, ceiling - floor)
69
+
70
+
71
def to_16_bit(im):
    """Min-max rescale *im* and convert it to ``uint16``.

    The [0, 1] output of :func:`rescale` is stretched to ``0 .. 65535``;
    the final cast truncates fractional values rather than rounding them.
    """
    full_scale = 2 ** 16 - 1
    stretched = rescale(im) * full_scale
    return np.uint16(stretched)
74
+
75
+
76
def to_8_bit(im):
    """Min-max rescale *im* and convert it to ``uint8``.

    The [0, 1] output of :func:`rescale` is stretched to ``0 .. 255``;
    the final cast truncates fractional values rather than rounding them.
    """
    full_scale = 2 ** 8 - 1
    stretched = rescale(im) * full_scale
    return np.uint8(stretched)
79
+
80
+
81
def is_integer(var):
    """Check whether *var* is an integer or an integer-typed array/tensor.

    Recognized as integer:

    * Python ``int`` (``bool`` is an ``int`` subclass, so it counts too);
    * NumPy integer scalars;
    * NumPy arrays (including ``np.memmap``) and dask arrays whose dtype is
      an integer type;
    * torch tensors whose dtype is neither floating point nor complex.

    NOTE(review): torch ``bool`` tensors count as integer here while NumPy
    ``bool`` arrays do not; kept as-is for backward compatibility.

    Returns
    -------
    bool
    """
    if isinstance(var, (int, np.integer)):
        return True
    if isinstance(var, np.ndarray):  # np.memmap is an ndarray subclass
        return bool(np.issubdtype(var.dtype, np.integer))
    if isinstance(var, torch.Tensor):
        # is_floating_point() alone let complex tensors through as "integer".
        return not (var.is_floating_point() or var.is_complex())
    if isinstance(var, da.Array):
        return bool(np.issubdtype(var.dtype, np.integer))
    return False
94
+
95
+
96
def move_axis(img, axis=-1, pos="last"):
    """Move one axis of *img* to a new position, keeping the others in order.

    Parameters
    ----------
    img : ndarray
        Input array.
    axis : int
        Axis to move. Negative values count from the end. Positive values
        past the last axis are clamped to the last axis.
    pos : {"first", "last"} or int
        Destination index. ``"first"``/``0`` and ``"last"``/``-1`` are the
        common cases; other negative values count from the end, matching
        :func:`numpy.moveaxis`. Positive values past the end place the axis
        last.

    Returns
    -------
    ndarray
        ``np.transpose`` of *img* with the permuted axis order.

    Raises
    ------
    ValueError
        If *pos* is a string other than ``"first"`` or ``"last"``.
    """
    ndim = img.ndim
    if axis < 0:
        axis += ndim
    axis = min(ndim - 1, axis)

    if isinstance(pos, str):
        if pos == "first":
            pos = 0
        elif pos == "last":
            pos = ndim - 1
        else:
            raise ValueError(f"pos must be 'first', 'last', or an int, not {pos!r}")
    elif pos < 0:
        # Previously only -1 was mapped; e.g. -2 landed one slot too early.
        pos += ndim

    perm = list(range(ndim))
    perm.pop(axis)
    # perm now has ndim - 1 entries, so inserting at index ``pos`` leaves
    # the moved axis at final position ``pos``.
    perm.insert(pos, axis)
    return np.transpose(img, perm)
109
+
110
+
111
def move_min_dim(img, force=False):
    """Treat the smallest dimension as channels and move it to the end.

    Only arrays with more than two dimensions are touched, and only when the
    smallest dimension has fewer than 10 entries (or *force* is set). When
    several axes share the smallest size, nothing moves if the last axis is
    one of them; otherwise the first such axis is moved last.
    """
    shape = img.shape
    if len(shape) <= 2:
        return img
    smallest = min(shape)
    if not (smallest < 10 or force):
        return img
    channel_axis = -1 if shape[-1] == smallest else shape.index(smallest)
    return move_axis(img, axis=channel_axis, pos="last")
ocdkit/array/filters.py ADDED
@@ -0,0 +1,56 @@
1
+ """Spatial filtering utilities for labeled arrays."""
2
+
3
+ from .imports import *
4
+ from numba import njit
5
+
6
+ from .spatial import kernel_setup, get_neighbors
7
+
8
+
9
@njit
def _most_frequent(neighbor_masks):
    """Column-wise mode of a 2D non-negative integer array.

    For each column, returns the value that occurs most often. Ties go to
    the smallest value because ``argmax`` returns the first maximum.
    Negative entries make ``np.bincount`` raise.
    """
    n_columns = neighbor_masks.shape[1]
    modes = np.empty(n_columns, dtype=np.intp)
    for col in range(n_columns):
        modes[col] = np.bincount(neighbor_masks[:, col]).argmax()
    return modes
13
+
14
+
15
def mode_filter(masks):
    """Replace each nonzero pixel with the most frequent label in its neighborhood.

    The neighborhood comes from ocdkit's ``kernel_setup``/``get_neighbors``
    primitives, and the per-pixel vote is done by the numba-compiled
    :func:`_most_frequent`. When the winning label is background (0), the
    pixel keeps its original label so that objects are not eroded.

    Parameters
    ----------
    masks : ndarray
        Integer label array (2D or ND).

    Returns
    -------
    ndarray
        Filtered label array with the same shape as the input. The dtype is
        the platform ``int``, because the input is cast with ``astype(int)``.
    """
    pad = 1
    padded = np.pad(masks, pad).astype(int)
    ndim = padded.ndim
    interior = tuple(slice(pad, -pad) for _ in range(ndim))

    foreground = np.nonzero(padded)
    if foreground[0].size == 0:
        return padded[interior]

    steps, inds, _idx, _fact, _sign = kernel_setup(ndim)
    neighbor_steps = steps[np.concatenate(inds)]
    neighbors = get_neighbors(foreground, neighbor_steps, ndim, padded.shape)
    neighbor_labels = padded[tuple(neighbors)]

    winners = _most_frequent(neighbor_labels)
    # Background wins are reverted to the pixel's own label (no erosion).
    lost = winners == 0
    winners[lost] = padded[foreground][lost]

    filtered = np.zeros_like(padded)
    filtered[foreground] = winners
    return filtered[interior]
ocdkit/array/imports.py ADDED
@@ -0,0 +1,4 @@
1
"""Common imports for ocdkit.array subpackage."""
from ..imports import *

# Sibling modules in this subpackage pull these names in with
# ``from .imports import *``; ``__all__`` limits that star-import to exactly
# these four names.
__all__ = ['np', 'torch', 'da', 'fastremap']
ocdkit/array/index.py ADDED
@@ -0,0 +1,242 @@
1
+ """Indexing, slicing, partitioning, and coordinate generation."""
2
+
3
+ from collections.abc import Iterable
4
+
5
+ from .imports import *
6
+
7
+
8
def ravel_index(b, shp):
    """Row-major flat index from multi-dim indices *b* and array shape *shp*.

    *b* has shape ``(ndim, npts)``; returns a 1D integer array of length
    ``npts`` (for integer *b*). Unlike :func:`numpy.ravel_multi_index`,
    no bounds checking or wrapping is performed.
    """
    # Row-major strides in elements: the product of all trailing dimensions,
    # with 1 for the last axis. The explicit integer dtype matters for 1-D
    # shapes, where ``shp[1:]`` is empty and np.asarray(()) would be float64,
    # which previously turned the returned indices into floats.
    trailing = np.asarray(shp[1:], dtype=np.intp)
    strides = np.append(trailing[::-1].cumprod()[::-1], 1)
    return strides.dot(b)
14
+
15
+
16
def unravel_index(index, shape):
    """Convert a row-major flat *index* into per-axis indices for *shape*.

    Works element-wise when *index* is an array. The outermost index is
    reduced modulo ``shape[0]``, so out-of-range flat indices wrap instead
    of raising.

    Returns
    -------
    tuple
        One index (or index array) per axis, outermost axis first.
    """
    remaining = index
    coords_innermost_first = []
    for size in reversed(shape):
        remaining, coord = remaining // size, remaining % size
        coords_innermost_first.append(coord)
    return tuple(coords_innermost_first[::-1])
26
+
27
+
28
def border_indices(tyx):
    """Flat (C-order) indices of every element on a face of an ND array of shape *tyx*.

    Use as ``A.flat[border_indices(A.shape)]``. Results are grouped per axis
    (low face, then high face), each group in ascending order. Elements on
    more than one face, such as corners or any element along a length-1
    axis, appear once per face they touch.
    """
    shape = tuple(tyx)
    flat_ids = np.arange(int(np.prod(shape)), dtype=np.intp).reshape(shape)
    faces = []
    for axis, size in enumerate(shape):
        for edge in (0, size - 1):
            if size == 0:
                # An empty axis has no face elements to select.
                faces.append(np.empty(0, dtype=np.intp))
            else:
                faces.append(np.take(flat_ids, edge, axis=axis).ravel())
    return np.concatenate(faces)
43
+
44
+
45
def split_array(array, parts, axes=None):
    """Split an ndarray into *parts* along specified *axes*.

    Parameters
    ----------
    array : ndarray
        The array to split.
    parts : int or tuple of int
        Number of parts to split along each axis. If an integer, the same
        count is used for every axis in *axes*.
    axes : int, tuple of int, or None
        The axes to split. If None, splits all axes.

    Returns
    -------
    nested list of ndarray
        Nested sub-arrays after splitting. The nesting depth equals
        ``len(axes)``. When an axis length is not divisible by its part
        count, the leading chunks get one extra element each and a
        ``Warning:`` line is printed.

    Raises
    ------
    ValueError
        If the number of part counts does not match the number of axes.
    """
    # Resolve *axes* first so that a scalar *parts* is broadcast over exactly
    # those axes. It used to be expanded to ``array.ndim`` up front, which
    # made e.g. ``split_array(a2d, 2, axes=0)`` fail the length check.
    if axes is None:
        axes = tuple(range(array.ndim))
    elif isinstance(axes, (int, np.integer)):
        axes = (int(axes),)
    else:
        axes = tuple(axes)

    if isinstance(parts, (int, np.integer)):
        parts = (int(parts),) * len(axes)
    else:
        parts = tuple(parts)

    if len(parts) != len(axes):
        raise ValueError("Length of 'parts' must match the number of axes specified.")

    splits = []
    messages = []  # not named ``warnings`` to avoid shadowing the stdlib module
    for ax, num_parts in zip(axes, parts):
        dim_size = array.shape[ax]
        # Spread the remainder over the leading chunks, one element each.
        chunk_sizes = [
            dim_size // num_parts + (1 if i < dim_size % num_parts else 0)
            for i in range(num_parts)
        ]
        if dim_size % num_parts != 0:
            messages.append(f"Axis {ax} ({dim_size}) is not evenly divisible by {num_parts}.")
        split_indices = np.cumsum(chunk_sizes[:-1])
        splits.append(np.split(np.arange(dim_size), split_indices))

    for message in messages:
        print("Warning:", message)

    def _recursive_split(sub, remaining_splits, remaining_axes):
        # Split along the first remaining axis, then recurse into each piece.
        if not remaining_splits:
            return sub
        ax = remaining_axes[0]
        return [
            _recursive_split(np.take(sub, idxs, axis=ax), remaining_splits[1:], remaining_axes[1:])
            for idxs in remaining_splits[0]
        ]

    return _recursive_split(array, splits, axes)
103
+
104
+
105
def reconstruct_array(nested_list, axes=None):
    """Reconstruct an ndarray from a nested list produced by :func:`split_array`.

    Parameters
    ----------
    nested_list : nested list of ndarray
        Nested sub-arrays to concatenate back together.
    axes : int, tuple of int, or None
        The axes used for splitting, outermost nesting level first. If None,
        ``(0, 1, ..., depth - 1)`` is assumed, where ``depth`` is the list
        nesting depth. That is correct when the leading axes were split in
        order.

    Returns
    -------
    ndarray
        The concatenated array.
    """
    if axes is None:
        # Measure the nesting depth by walking down the first branch to a
        # leaf. The old heuristic used the number of top-level parts, which
        # e.g. turned ``[[left, right]]`` into a stacked 3D array.
        depth = 0
        node = nested_list
        while isinstance(node, (list, tuple)) and len(node) > 0:
            depth += 1
            node = node[0]
        axes = tuple(range(depth))
    elif isinstance(axes, (int, np.integer)):
        axes = (int(axes),)

    def _recursive_reconstruct(nested, level):
        # Leaves are returned as-is; each list level is concatenated along
        # the axis recorded for that level.
        if isinstance(nested, np.ndarray):
            return nested
        if level == len(axes):
            return np.array(nested)
        return np.concatenate(
            [_recursive_reconstruct(sub, level + 1) for sub in nested],
            axis=axes[level],
        )

    return _recursive_reconstruct(nested_list, 0)
135
+
136
+
137
def meshgrid(shape):
    """Return ``ij``-indexed coordinate grids covering an array of *shape*.

    Parameters
    ----------
    shape : tuple of int
        Array shape, e.g. ``(Y, X)`` or ``(Z, Y, X)``.

    Returns
    -------
    sequence of ndarray
        One grid per dimension, each of shape *shape*, as produced by
        ``np.meshgrid(..., indexing='ij')``; grid ``k`` holds the index along
        axis ``k``.
    """
    axis_ranges = (np.arange(extent) for extent in shape)
    return np.meshgrid(*axis_ranges, indexing='ij')
152
+
153
+
154
def generate_flat_coordinates(shape):
    """Return the coordinates of every element of an array of *shape*.

    Parameters
    ----------
    shape : tuple of int
        Array shape, e.g. ``(Y, X)``.

    Returns
    -------
    tuple of ndarray
        One 1D array per dimension, built by flattening the ``ij`` grids
        from :func:`meshgrid`; entry ``k`` of every array together gives the
        coordinate of the ``k``-th element in C order.
    """
    grids = meshgrid(shape)
    return tuple(map(np.ravel, grids))
169
+
170
+
171
def get_slice_tuple(start, stop, shape, axis=None):
    """Build a tuple of slices for indexing an ND array of *shape*.

    Parameters
    ----------
    start, stop : int or iterable of int
        Slice bounds. When both are iterable they are paired element-wise.
    shape : tuple of int
        Shape of the array to slice; only its length is used.
    axis : int, iterable of int, or None
        Target axis/axes. With iterable bounds the default is every axis,
        so *start* must then have ``len(shape)`` entries; a single int
        applies every pair to that one axis, so the last pair wins. With
        scalar bounds the default is axis ``0``.

    Returns
    -------
    tuple of slice
        Length ``len(shape)``; untouched axes get ``slice(None)``.

    Raises
    ------
    ValueError
        If *start* and *stop* differ in length, or an iterable *axis* does
        not match their length.
    """
    ndim = len(shape)
    slices = [slice(None) for _ in range(ndim)]

    paired = isinstance(start, Iterable) and isinstance(stop, Iterable)
    if not paired:
        slices[0 if axis is None else axis] = slice(start, stop, None)
        return tuple(slices)

    if axis is None:
        axis = list(range(ndim))
    if len(start) != len(stop):
        raise ValueError("start and stop must be the same length")
    if not isinstance(axis, Iterable):
        axis = [axis] * len(start)
    elif len(axis) != len(start):
        raise ValueError("axis must be the same length as start and stop")

    for ax, lo, hi in zip(axis, start, stop):
        slices[ax] = slice(lo, hi, None)
    return tuple(slices)
215
+
216
+
217
def intersect_slices(s1, s2):
    """Return the overlap of slices *s1* and *s2* as a new slice.

    A ``None`` start or stop is treated as unbounded. The result starts at
    the larger start and stops at the smaller stop. If the ranges do not
    overlap, the stop is raised to the start, giving an empty slice. The
    step comes from *s1* if set, otherwise from *s2*, otherwise 1.

    NOTE(review): bounds are compared as plain numbers, so negative
    (end-relative) indices and differing steps/phases are not reconciled.
    """
    import math

    def _start(s):
        return -math.inf if s.start is None else s.start

    def _stop(s):
        return math.inf if s.stop is None else s.stop

    lower = max(_start(s1), _start(s2))
    upper = min(_stop(s1), _stop(s2))
    if upper < lower:
        upper = lower  # disjoint ranges: collapse to an empty slice

    step = next((st for st in (s1.step, s2.step) if st is not None), 1)

    return slice(
        None if math.isinf(lower) else int(lower),
        None if math.isinf(upper) else int(upper),
        step,
    )
ocdkit/array/morphology.py ADDED
@@ -0,0 +1,194 @@
1
+ """Morphology utilities — boundaries, skeletonization, outlines, thresholding."""
2
+
3
+ from .imports import *
4
+ import skimage.morphology
5
+
6
+ from .spatial import kernel_setup
7
+
8
+
9
def find_boundaries(labels, connectivity=1, use_symmetry=False):
    """Mark the inner boundary pixels of labeled instances in an ND array.

    Intended to match ``skimage.segmentation.find_boundaries(...,
    mode='inner')`` while being much faster. A nonzero pixel is a boundary
    pixel when at least one of its selected neighbors carries a different
    label (including background).

    Parameters
    ----------
    labels : ndarray
        Integer label array of any dimensionality.
    connectivity : int
        Number of neighbor groups from ``kernel_setup(ndim)``'s ``inds`` to
        use, starting at group 1. NOTE(review): presumably group ``k`` holds
        offsets with ``k`` nonzero components — confirm in spatial.py.
    use_symmetry : bool
        Visit only the first half of each neighbor group and mark both
        pixels of every compared pair.

    Returns
    -------
    ndarray of uint8
        Same shape as *labels*; 1 on boundaries, 0 elsewhere.
    """
    shape = labels.shape
    boundaries = np.zeros_like(labels, dtype=bool)
    steps, inds, _idx, _fact, _sign = kernel_setup(labels.ndim)

    def _paired_slices(shift):
        # Views pairing each pixel p (first tuple) with its neighbor p + shift
        # (second tuple), cropped to the part of the array where both exist.
        here = tuple(
            slice(max(-s, 0), min(shape[d] - s, shape[d]))
            for d, s in enumerate(shift)
        )
        there = tuple(
            slice(max(s, 0), min(shape[d] + s, shape[d]))
            for d, s in enumerate(shift)
        )
        return here, there

    if use_symmetry:
        # NOTE(review): assumes the second half of each group mirrors the
        # first, so marking both sides of a pair covers the mirrored offsets.
        chosen = np.concatenate(
            [inds[i][:len(inds[i]) // 2] for i in range(1, 1 + connectivity)]
        )
    else:
        chosen = np.concatenate(inds[1:1 + connectivity])

    for shift in steps[chosen]:
        here, there = _paired_slices(shift)
        own = labels[here]
        other = labels[there]
        differs = own != other
        boundaries[here] |= differs & (own != 0)
        if use_symmetry:
            boundaries[there] |= differs & (other != 0)

    return boundaries.astype(np.uint8)
68
+
69
+
70
def skeletonize(labels, dt_thresh=1, dt=None, method='zhang'):
    """Skeletonize labeled instances, keeping each skeleton pixel's label.

    With a distance transform *dt*, the region ``dt > dt_thresh`` is thinned
    directly and the result is returned as-is. Without one, instance
    boundaries are first removed to separate touching objects. Any label
    that vanishes during thinning is then restored in full from *labels*.

    Parameters
    ----------
    labels : ndarray
        Integer label matrix.
    dt_thresh : float
        Threshold on *dt* for the fast path.
    dt : ndarray, optional
        Pre-computed distance transform.
    method : str
        Passed to ``skimage.morphology.skeletonize`` (``'zhang'`` or ``'lee'``).

    Returns
    -------
    ndarray
        Label-valued skeleton with the same shape as *labels*.
    """
    if dt is not None:
        interior = dt > dt_thresh
        return skimage.morphology.skeletonize(interior, method=method) * labels

    edges = find_boundaries(labels, connectivity=2)
    interior = np.logical_xor(labels > 0, edges)
    skeleton = skimage.morphology.skeletonize(interior, method=method) * labels

    present = fastremap.unique(labels)
    present = present[present != 0]
    surviving = fastremap.unique(skeleton)
    surviving = surviving[surviving != 0]

    # Labels whose skeleton vanished (e.g. objects thinner than the removed
    # boundary) are added back as their full masks.
    vanished = np.setdiff1d(present, surviving)
    skeleton += fastremap.mask_except(labels, list(vanished))
    return skeleton
109
+
110
+
111
def masks_to_outlines(masks, omni=False, mode="inner", connectivity=None):
    """Return a boolean outline mask for labeled instances.

    Parameters
    ----------
    masks : ndarray
        2D or 3D label matrix. 3D inputs are processed slice-by-slice along
        the first axis.
    omni : bool
        If True, compute boundaries on the whole label matrix at once. If
        False, use the legacy per-label erosion path, which is slower and
        visits each label's bounding box.
    mode : str
        Boundary mode, used only when ``omni=True``. ``"inner"`` (default)
        uses the fast native :func:`find_boundaries`. Any other value (e.g.
        ``"outer"`` or ``"thick"``) is passed to
        ``skimage.segmentation.find_boundaries``; previously it was silently
        ignored. The legacy path always yields inner outlines.
    connectivity : int, optional
        Neighborhood connectivity used when ``omni=True``. Defaults to
        ``masks.ndim`` of the 2D slice, i.e. 2.

    Returns
    -------
    ndarray of bool
        Same shape as *masks* for every mode except skimage's
        ``"subpixel"``, which is only usable on 2D input.

    Raises
    ------
    ValueError
        If *masks* is not 2D or 3D.
    """
    if masks.ndim > 3 or masks.ndim < 2:
        raise ValueError(
            f"masks_to_outlines takes 2D or 3D array, not {masks.ndim}D array"
        )

    if masks.ndim == 3:
        outlines = np.zeros(masks.shape, bool)
        for i in range(masks.shape[0]):
            outlines[i] = masks_to_outlines(
                masks[i], omni=omni, mode=mode, connectivity=connectivity,
            )
        return outlines

    if omni:
        if connectivity is None:
            connectivity = masks.ndim
        if mode == "inner":
            return find_boundaries(masks, connectivity=connectivity).astype(bool)
        # The native implementation only covers "inner"; delegate other modes
        # instead of ignoring the caller's request.
        import skimage.segmentation
        return skimage.segmentation.find_boundaries(
            masks, connectivity=connectivity, mode=mode,
        ).astype(bool)

    # Legacy per-label erosion path: within each label's bounding box, the
    # outline is the label mask minus its binary erosion. Erosion treats the
    # outside of the box as background, so pixels on the box edge count.
    from scipy.ndimage import binary_erosion, find_objects
    outlines = np.zeros(masks.shape, bool)
    slices = find_objects(masks.astype(int))
    for i, si in enumerate(slices):
        if si is not None:
            sr, sc = si
            mask = (masks[sr, sc] == (i + 1))
            boundary = mask & ~binary_erosion(mask)
            pvr, pvc = np.nonzero(boundary)
            vr, vc = pvr + sr.start, pvc + sc.start
            outlines[vr, vc] = 1
    return outlines
161
+
162
+
163
def hysteresis_threshold(image, low, high):
    """PyTorch implementation of ``skimage.filters.apply_hysteresis_threshold``.

    Pixels above *high* are kept. Pixels above *low* are kept only if they
    are connected, through other pixels above *low*, to a pixel above
    *high* in the same batch element and channel.

    Parameters
    ----------
    image : torch.Tensor or array-like
        Input of shape ``(B, C, *spatial)`` with 2 or 3 spatial dims.
        Non-tensors are converted with ``torch.tensor``.
    low, high : float
        Lower and upper thresholds (strict ``>`` comparisons).

    Returns
    -------
    torch.Tensor of bool
        Same shape as *image*.

    Raises
    ------
    ValueError
        If *image* does not have 2 or 3 spatial dimensions.

    Notes
    -----
    Growth uses a full 3x3 (3x3x3) neighborhood, so diagonal neighbors are
    connected. skimage labels ``image > low`` with SciPy's default
    face-only connectivity, so results can differ where regions touch only
    diagonally.
    """
    import torch

    if not isinstance(image, torch.Tensor):
        image = torch.tensor(image)

    spatial_dims = len(image.shape) - 2
    if spatial_dims == 2:
        conv = torch.nn.functional.conv2d
    elif spatial_dims == 3:
        conv = torch.nn.functional.conv3d
    else:
        raise ValueError(f'Unsupported number of spatial dimensions: {spatial_dims}')

    mask_low = image > low
    mask_high = image > high

    # Depthwise (groups=C) all-ones kernel: a response > 0 means some pixel in
    # the neighborhood (same channel) is already selected. The kernel is
    # float32 to match the ``.float()`` input; it previously used
    # image.dtype, which crashed for float64 or half-precision images.
    channels = image.shape[1]
    kernel = torch.ones(
        [channels, 1] + [3] * spatial_dims, device=image.device, dtype=torch.float32
    )

    # Seed from the high mask and grow inside the low mask until stable.
    # Seeding from mask_low (as before) converged immediately to plain
    # low-thresholding, dropping the hysteresis behavior entirely.
    thresholded = mask_high.clone()
    while True:
        touched = conv(thresholded.float(), kernel, padding=1, groups=channels) > 0
        grown = (touched & mask_low) | mask_high
        if torch.equal(grown, thresholded):
            return grown.bool()
        thresholded = grown