ocdkit 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocdkit/__init__.py +10 -0
- ocdkit/array/__init__.py +3 -0
- ocdkit/array/convert.py +121 -0
- ocdkit/array/filters.py +56 -0
- ocdkit/array/imports.py +4 -0
- ocdkit/array/index.py +242 -0
- ocdkit/array/morphology.py +194 -0
- ocdkit/array/normalize.py +425 -0
- ocdkit/array/ops.py +134 -0
- ocdkit/array/spatial.py +410 -0
- ocdkit/array/transform.py +261 -0
- ocdkit/array/union_find.py +52 -0
- ocdkit/array/warp.py +28 -0
- ocdkit/imports.py +8 -0
- ocdkit/io/__init__.py +3 -0
- ocdkit/io/files.py +141 -0
- ocdkit/io/image.py +138 -0
- ocdkit/io/imports.py +4 -0
- ocdkit/io/path.py +68 -0
- ocdkit/io/result.py +34 -0
- ocdkit/load/__init__.py +5 -0
- ocdkit/load/module.py +132 -0
- ocdkit/load/object.py +136 -0
- ocdkit/logging/__init__.py +3 -0
- ocdkit/logging/handler.py +206 -0
- ocdkit/measure/__init__.py +3 -0
- ocdkit/measure/bbox.py +188 -0
- ocdkit/measure/diameter.py +185 -0
- ocdkit/measure/imports.py +4 -0
- ocdkit/measure/medoid.py +181 -0
- ocdkit/measure/metrics.py +43 -0
- ocdkit/plot/__init__.py +5 -0
- ocdkit/plot/color.py +215 -0
- ocdkit/plot/contour.py +102 -0
- ocdkit/plot/defaults.py +147 -0
- ocdkit/plot/display.py +133 -0
- ocdkit/plot/export.py +108 -0
- ocdkit/plot/figure.py +24 -0
- ocdkit/plot/grid.py +306 -0
- ocdkit/plot/imports.py +9 -0
- ocdkit/plot/label.py +733 -0
- ocdkit/plot/ncolor.py +54 -0
- ocdkit/utils/__init__.py +3 -0
- ocdkit/utils/collections.py +97 -0
- ocdkit/utils/gpu.py +210 -0
- ocdkit/utils/kwargs.py +136 -0
- ocdkit-0.0.1.dist-info/METADATA +66 -0
- ocdkit-0.0.1.dist-info/RECORD +51 -0
- ocdkit-0.0.1.dist-info/WHEEL +5 -0
- ocdkit-0.0.1.dist-info/licenses/LICENSE +28 -0
- ocdkit-0.0.1.dist-info/top_level.txt +1 -0
ocdkit/__init__.py
ADDED
ocdkit/array/__init__.py
ADDED
ocdkit/array/convert.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""Type detection, conversion, and rescaling utilities."""
|
|
2
|
+
|
|
3
|
+
from .imports import *
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_module(x):
    """Return the array module (``np`` or ``torch``) matching the type of *x*."""
    if isinstance(x, (np.ndarray, tuple, int, float)) or np.isscalar(x):
        return np
    if torch.is_tensor(x):
        return torch
    if isinstance(x, da.Array):
        # Dask arrays are operated on through the NumPy API.
        return np
    raise ValueError(
        "Input must be a numpy array, a tuple, a torch tensor, "
        "an integer, or a float"
    )
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def safe_divide(num, den, cutoff=0):
    """Division ignoring zeros and NaNs in the denominator.

    Entries where ``den <= cutoff`` or ``den`` is non-finite produce 0 in
    the output instead of inf/NaN. Note that with the default ``cutoff=0``
    this also zeroes *negative* denominators, not only exact zeros.

    Parameters
    ----------
    num : ndarray, dask array, or torch tensor
        Numerator; its type selects the backend via :func:`get_module`.
    den : array-like broadcastable against *num*
        Denominator.
    cutoff : float
        Denominator values at or below this threshold are treated as invalid.

    Returns
    -------
    Array on the same backend as the inputs; float-typed on the numpy and
    torch paths.
    """
    module = get_module(num)
    valid_den = (den > cutoff) & module.isfinite(den)

    if isinstance(num, da.Array) or isinstance(den, da.Array):
        # Dask path: lazy elementwise select; invalid entries become 0.
        return da.where(valid_den, num / den, 0)
    elif module == np:
        # NOTE(review): this path assumes `num` has .astype (i.e. is an
        # ndarray); plain Python ints/floats from get_module would fail here.
        r = num.astype(np.float32, copy=False)
        # `where=` skips invalid entries, leaving the zero-initialized output.
        r = np.divide(r, den, out=np.zeros_like(r), where=valid_den)
    elif module == torch:
        r = num.float()
        den = den.float()
        # Substitute 1 for invalid denominators so the division itself never
        # produces inf/NaN (which would poison autograd), then mask to zero.
        safe_den = torch.where(valid_den, den, torch.ones_like(den))
        r = torch.where(valid_den, torch.div(r, safe_den), torch.zeros_like(r))
    else:
        raise TypeError("num must be a numpy array or a PyTorch tensor")

    return r
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def rescale(T, floor=None, ceiling=None, exclude_dims=None):
    """Min-max rescale to [0, 1].

    Works on numpy arrays and torch tensors. When *exclude_dims* is
    given, normalization is applied independently along those axes.

    Parameters
    ----------
    T : ndarray or torch tensor
        Data to rescale.
    floor : scalar or array, optional
        Value mapped to 0; computed as the min over the reduced axes if omitted.
    ceiling : scalar or array, optional
        Value mapped to 1; computed as the max over the reduced axes if omitted.
    exclude_dims : int or tuple of int, optional
        Axes that are NOT reduced; each index along them is normalized
        independently.

    Returns
    -------
    Rescaled array on the same backend as *T*, float-typed via
    :func:`safe_divide` (constant regions map to 0, not NaN).
    """
    module = get_module(T)
    if exclude_dims is not None:
        if isinstance(exclude_dims, int):
            exclude_dims = (exclude_dims,)
        # Reduce over every axis except the excluded ones; reduced extrema
        # are reshaped below so they broadcast back against T.
        axes = tuple(i for i in range(T.ndim) if i not in exclude_dims)
        newshape = [T.shape[i] if i in exclude_dims else 1 for i in range(T.ndim)]
    else:
        axes = None
        newshape = T.shape

    if ceiling is None:
        # NOTE(review): for the torch path this relies on torch accepting the
        # numpy-style `axis` keyword for amax/amin — confirm.
        ceiling = module.amax(T, axis=axes)
        if exclude_dims is not None:
            ceiling = ceiling.reshape(*newshape)
    if floor is None:
        floor = module.amin(T, axis=axes)
        if exclude_dims is not None:
            floor = floor.reshape(*newshape)

    T = safe_divide(T - floor, ceiling - floor)

    return T
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def to_16_bit(im):
    """Min-max rescale *im* to the full uint16 range [0, 65535] and cast."""
    scaled = rescale(im) * (2 ** 16 - 1)
    return np.uint16(scaled)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def to_8_bit(im):
    """Min-max rescale *im* to the full uint8 range [0, 255] and cast."""
    scaled = rescale(im) * (2 ** 8 - 1)
    return np.uint8(scaled)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def is_integer(var):
    """Check whether *var* is an integer scalar or an integer-typed array/tensor.

    Recognizes Python ints, numpy integer scalars, integer-dtyped numpy
    arrays/memmaps, integer-dtyped dask arrays, and integer torch tensors.

    Parameters
    ----------
    var : object
        Value to inspect.

    Returns
    -------
    bool
        True when *var* carries an integer type, False otherwise.
    """
    if isinstance(var, (int, np.integer)):
        # Python bools are ints and therefore report True here (unchanged
        # from the original behavior).
        return True
    if isinstance(var, (np.ndarray, np.memmap)):
        return bool(np.issubdtype(var.dtype, np.integer))
    if torch.is_tensor(var):
        # Bug fix: "not floating" alone misreported bool and complex tensors
        # as integer; exclude them for consistency with the numpy branch.
        return (
            not var.is_floating_point()
            and not var.is_complex()
            and var.dtype != torch.bool
        )
    if isinstance(var, da.Array):
        return bool(np.issubdtype(var.dtype, np.integer))
    return False
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def move_axis(img, axis=-1, pos="last"):
    """Move one axis of *img* to a new position, preserving the order of the rest.

    Parameters
    ----------
    img : ndarray
        Input array.
    axis : int
        Axis to move. Negative values count from the end; values beyond the
        last axis are clamped to the last axis.
    pos : {"first", "last"} or int
        Destination position. ``"first"``/0 puts the axis first,
        ``"last"``/-1 puts it last; any other int is used directly as the
        insertion index.

    Returns
    -------
    ndarray
        ``np.transpose`` view of *img* with the axis moved.
    """
    if axis < 0:
        # Normalize every negative axis (the original only special-cased -1,
        # leaving other negatives to leak into the permutation).
        axis = img.ndim + axis
    axis = min(img.ndim - 1, axis)
    if pos in ("first", 0):
        pos = 0
    elif pos in ("last", -1):
        pos = img.ndim - 1
    perm = list(range(img.ndim))
    perm.pop(axis)
    perm.insert(pos, axis)
    return np.transpose(img, perm)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def move_min_dim(img, force=False):
    """Move the smallest dimension to the last position (channel-last).

    Only acts on arrays with more than two dimensions whose smallest
    dimension is below 10 (assumed to be a channel axis), unless *force*
    is True. 1D/2D inputs are returned unchanged.
    """
    if len(img.shape) <= 2:
        return img
    smallest = min(img.shape)
    if smallest < 10 or force:
        if img.shape[-1] == smallest:
            src = -1
        else:
            src = (img.shape).index(smallest)
        img = move_axis(img, axis=src, pos="last")
    return img
|
ocdkit/array/filters.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Spatial filtering utilities for labeled arrays."""
|
|
2
|
+
|
|
3
|
+
from .imports import *
|
|
4
|
+
from numba import njit
|
|
5
|
+
|
|
6
|
+
from .spatial import kernel_setup, get_neighbors
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@njit
def _most_frequent(neighbor_masks):
    """Column-wise mode: for each column, return the most common value.

    *neighbor_masks* is a 2D integer array of shape
    ``(n_neighbors, n_pixels)``; iterating over ``.T`` visits one pixel's
    neighborhood at a time. ``bincount(...).argmax()`` picks the most
    frequent value (smallest value wins ties); entries must be
    non-negative. JIT-compiled with numba, so only numba-supported numpy
    constructs may be used here.
    """
    return np.array([np.bincount(row).argmax() for row in neighbor_masks.T])
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def mode_filter(masks):
    """Replace each nonzero pixel with the most frequent label in its neighborhood.

    Uses ocdkit's spatial neighbor primitives and a numba-JIT inner loop
    for the per-pixel bincount. Background (0) results are replaced with
    the original label to avoid erosion.

    Parameters
    ----------
    masks : ndarray
        Integer label array (2D or ND).

    Returns
    -------
    ndarray
        Filtered label array, same shape as input.
    """
    # Pad by one so every foreground pixel has a complete neighborhood and
    # neighbor coordinates never leave the array.
    pad = 1
    masks = np.pad(masks, pad).astype(int)
    d = masks.ndim
    shape = masks.shape
    coords = np.nonzero(masks)

    if coords[0].size == 0:
        # No foreground: return the (unpadded) input unchanged.
        unpad = tuple([slice(pad, -pad)] * d)
        return masks[unpad]

    # NOTE(review): kernel_setup/get_neighbors come from .spatial; assumed to
    # yield, per foreground pixel, the coordinates of its full neighborhood
    # (all connectivity rings concatenated) — confirm against spatial.py.
    steps, inds, idx, fact, sign = kernel_setup(d)
    subinds = np.concatenate(inds)
    substeps = steps[subinds]
    neighbors = get_neighbors(coords, substeps, d, shape)

    # Shape (n_neighbors, n_foreground_pixels): labels around each pixel.
    neighbor_masks = masks[tuple(neighbors)]

    mask_filt = np.zeros_like(masks)
    most_f = _most_frequent(neighbor_masks)
    # Where the neighborhood mode is background, keep the original label so
    # the filter never erodes thin objects to zero.
    z = most_f == 0
    most_f[z] = masks[coords][z]
    mask_filt[coords] = most_f

    unpad = tuple([slice(pad, -pad)] * d)
    return mask_filt[unpad]
|
ocdkit/array/imports.py
ADDED
ocdkit/array/index.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
"""Indexing, slicing, partitioning, and coordinate generation."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterable
|
|
4
|
+
|
|
5
|
+
from .imports import *
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def ravel_index(b, shp):
    """Row-major flat index from multi-dim indices *b* and array shape *shp*.

    *b* has shape ``(ndim, npts)``; returns a 1D array of length ``npts``.
    """
    # Build row-major strides from the innermost axis outward:
    # stride[k] = prod(shp[k+1:]), with stride[-1] == 1.
    strides = [1]
    for dim in shp[:0:-1]:
        strides.append(strides[-1] * dim)
    return np.asarray(strides[::-1]).dot(b)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def unravel_index(index, shape):
    """Multi-dim indices (row-major) from a flat *index* and array *shape*.

    Returns a tuple of per-axis indices (arrays or scalars, matching *index*).
    """
    coords = []
    remaining = index
    # Peel off the innermost axis first; divmod yields (quotient, remainder).
    for size in reversed(shape):
        remaining, within = divmod(remaining, size)
        coords.append(within)
    return tuple(reversed(coords))
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def border_indices(tyx):
    """Flat indices of the border pixels of an ND array with shape *tyx*.

    Use as ``A.flat[border_indices(A.shape)]``. Indices are collected per
    face (low/high edge of each axis), so corner pixels appear once per
    face they belong to.
    """
    grids = np.meshgrid(*(np.arange(n) for n in tyx), indexing='ij')
    flat_coords = [g.ravel() for g in grids]

    parts = []
    for ax, size in enumerate(tyx):
        for edge in (0, size - 1):
            parts.append(np.flatnonzero(flat_coords[ax] == edge))
    return np.concatenate(parts)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def split_array(array, parts, axes=None):
    """Split an ndarray into *parts* along specified *axes*.

    Parameters
    ----------
    array : ndarray
        The array to split.
    parts : int or tuple of int
        Number of parts to split along each axis. If an integer, applies to
        all specified axes.
    axes : int, tuple of int, or None
        The axes to split. If None, splits all axes.

    Returns
    -------
    nested list of ndarray
        Nested sub-arrays after splitting. The nesting depth equals
        ``len(axes)``.
    """
    # Resolve axes BEFORE expanding an integer `parts`: the original expanded
    # `parts` to ndim entries first, so e.g. split_array(a, 2, axes=1)
    # incorrectly failed the length check below.
    if axes is None:
        axes = tuple(range(array.ndim))
    elif isinstance(axes, int):
        axes = (axes,)

    if isinstance(parts, int):
        parts = (parts,) * len(axes)

    if len(parts) != len(axes):
        raise ValueError("Length of 'parts' must match the number of axes specified.")

    splits = []
    warnings = []

    for ax, num_parts in zip(axes, parts):
        dim_size = array.shape[ax]
        # Distribute the remainder over the leading chunks so sizes differ
        # by at most one.
        chunk_sizes = [
            dim_size // num_parts + (1 if i < dim_size % num_parts else 0)
            for i in range(num_parts)
        ]
        if dim_size % num_parts != 0:
            warnings.append(f"Axis {ax} ({dim_size}) is not evenly divisible by {num_parts}.")
        split_indices = np.cumsum(chunk_sizes[:-1])
        splits.append(np.split(np.arange(dim_size), split_indices))

    for warning in warnings:
        print("Warning:", warning)

    def _recursive_split(array, splits, axes):
        # Slice along the first remaining axis, then recurse on the rest.
        if not splits:
            return array
        ax = axes[0]
        subarrays = []
        for idxs in splits[0]:
            sliced = np.take(array, idxs, axis=ax)
            subarrays.append(_recursive_split(sliced, splits[1:], axes[1:]))
        return subarrays

    return _recursive_split(array, splits, axes)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def reconstruct_array(nested_list, axes=None):
    """Reconstruct an ndarray from a nested list produced by :func:`split_array`.

    Parameters
    ----------
    nested_list : list of ndarray
        Nested sub-arrays to concatenate back together.
    axes : int, tuple of int, or None
        The axes used for splitting. If None, infers from the outer list's
        nesting depth.
    """
    if isinstance(axes, int):
        axes = (axes,)
    elif axes is None:
        # Infer how many levels of nesting to undo.
        if isinstance(nested_list[0], np.ndarray):
            depth = len(nested_list[0].shape)
        else:
            depth = len(nested_list)
        axes = tuple(range(depth))

    def _rebuild(node, level):
        if isinstance(node, np.ndarray):
            return node
        if level == len(axes):
            return np.array(node)
        pieces = [_rebuild(child, level + 1) for child in node]
        return np.concatenate(pieces, axis=axes[level])

    return _rebuild(nested_list, 0)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def meshgrid(shape):
    """Per-dimension coordinate grids (``ij`` indexing) for an ND *shape*.

    Parameters
    ----------
    shape : tuple of int
        Shape of the ND array (e.g., ``(Y, X)`` for 2D, ``(Z, Y, X)`` for 3D).

    Returns
    -------
    list of ndarray
        One coordinate array per dimension, each of shape ``shape``.
    """
    return np.meshgrid(*(np.arange(n) for n in shape), indexing='ij')
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def generate_flat_coordinates(shape):
    """Generate flat coordinate arrays for an ND array.

    Parameters
    ----------
    shape : tuple of int
        Shape of the array (e.g., ``(Y, X)`` for 2D).

    Returns
    -------
    tuple of ndarray
        ``N`` 1D arrays holding the coordinates of every element of an
        array with ``shape``, in ``ij`` order.
    """
    axes = [np.arange(n) for n in shape]
    grids = np.meshgrid(*axes, indexing='ij')
    return tuple(g.ravel() for g in grids)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def get_slice_tuple(start, stop, shape, axis=None):
    """Build a tuple of slice objects for ND array indexing.

    Parameters
    ----------
    start : int or iterable of int
        Starting index(es).
    stop : int or iterable of int
        Stopping index(es).
    shape : tuple of int
        Shape of the array to slice.
    axis : int, iterable of int, or None
        Axis or axes to apply slices to. Default: all axes if *start*/*stop*
        are iterable, else axis ``0``.

    Returns
    -------
    tuple of slice
        Length ``len(shape)``. Axes not addressed by *axis* get ``slice(None)``.
    """
    ndim = len(shape)
    result = [slice(None)] * ndim

    # Scalar start/stop: slice a single axis (axis 0 by default).
    if not (isinstance(start, Iterable) and isinstance(stop, Iterable)):
        target = 0 if axis is None else axis
        result[target] = slice(start, stop, None)
        return tuple(result)

    if axis is None:
        axis = list(range(ndim))

    if len(start) != len(stop):
        raise ValueError("start and stop must be the same length")

    if isinstance(axis, Iterable):
        if len(axis) != len(start):
            raise ValueError("axis must be the same length as start and stop")
    else:
        axis = [axis] * len(start)

    for ax, lo, hi in zip(axis, start, stop):
        result[ax] = slice(lo, hi, None)

    return tuple(result)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def intersect_slices(s1, s2):
    """Return a slice that is the intersection of *s1* and *s2*.

    ``None`` boundaries are treated as unbounded (``-inf`` / ``+inf``).
    Steps are unified: the first non-None step wins, defaulting to 1.
    An empty intersection collapses to a zero-length slice.
    """
    import math

    def _bounds(s):
        lo = -math.inf if s.start is None else s.start
        hi = math.inf if s.stop is None else s.stop
        return lo, hi

    lo1, hi1 = _bounds(s1)
    lo2, hi2 = _bounds(s2)

    lo = max(lo1, lo2)
    hi = min(hi1, hi2)
    if hi < lo:
        hi = lo

    step = s1.step if s1.step is not None else s2.step
    if step is None:
        step = 1

    start = None if math.isinf(lo) else int(lo)
    stop = None if math.isinf(hi) else int(hi)

    return slice(start, stop, step)
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""Morphology utilities — boundaries, skeletonization, outlines, thresholding."""
|
|
2
|
+
|
|
3
|
+
from .imports import *
|
|
4
|
+
import skimage.morphology
|
|
5
|
+
|
|
6
|
+
from .spatial import kernel_setup
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def find_boundaries(labels, connectivity=1, use_symmetry=False):
    """Compute boundaries of labeled instances in an N-dimensional array.

    Replicates ``skimage.segmentation.find_boundaries`` with
    ``mode='inner'``, but is much faster.

    Parameters
    ----------
    labels : ndarray
        Integer label array; 0 is background.
    connectivity : int
        Neighborhood order to test (1 = axis-aligned neighbors, higher
        values add diagonal rings, up to ``labels.ndim``).
    use_symmetry : bool
        If True, only one shift of each +/- pair is tested and the result
        is mirrored onto the shifted side, halving the comparisons.

    Returns
    -------
    ndarray of uint8
        1 at foreground pixels adjacent to a differently-labeled pixel.
    """
    boundaries = np.zeros_like(labels, dtype=bool)
    ndim = labels.ndim
    shape = labels.shape

    steps, inds, idx, fact, sign = kernel_setup(ndim)

    if use_symmetry:
        # Each connectivity ring lists shifts in +/- pairs; keep only the
        # first half — the mirrored comparison below covers the rest.
        allowed_inds = np.concatenate(
            [inds[i][: len(inds[i]) // 2] for i in range(1, 1 + connectivity)]
        )
    else:
        allowed_inds = np.concatenate(inds[1 : 1 + connectivity])

    shifts = steps[allowed_inds]

    # The two original branches duplicated the slice-pair construction
    # verbatim; it is factored out into _shift_slices below.
    for shift in shifts:
        slices_main, slices_shifted = _shift_slices(shift, shape)
        diff = labels[slices_main] != labels[slices_shifted]
        boundaries[slices_main] |= diff & (labels[slices_main] != 0)
        if use_symmetry:
            # Mirror: the shifted pixel is also a boundary if it is foreground.
            boundaries[slices_shifted] |= diff & (labels[slices_shifted] != 0)

    return boundaries.astype(np.uint8)


def _shift_slices(shift, shape):
    """Slice pair aligning an array with itself offset by *shift*.

    Returns ``(main, shifted)`` such that ``a[main]`` and ``a[shifted]``
    have equal shapes and correspond to pixel pairs separated by *shift*.
    """
    main = tuple(
        slice(max(-s, 0), min(shape[d] - s, shape[d]))
        for d, s in enumerate(shift)
    )
    shifted = tuple(
        slice(max(s, 0), min(shape[d] + s, shape[d]))
        for d, s in enumerate(shift)
    )
    return main, shifted
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def skeletonize(labels, dt_thresh=1, dt=None, method='zhang'):
    """Skeletonize labeled instances, preserving label identity.

    When *dt* is provided, pixels with ``dt > dt_thresh`` are used as the
    interior mask (fast path). Otherwise, boundaries are removed first and
    missing labels are re-attached after thinning.

    Parameters
    ----------
    labels : ndarray
        Integer label matrix.
    dt_thresh : float
        Distance-transform threshold for the fast path.
    dt : ndarray, optional
        Pre-computed distance transform.
    method : str
        Skeletonization method (``'zhang'`` or ``'lee'``).

    Returns
    -------
    ndarray
        Label matrix containing skeleton pixels only; on the slow path,
        labels whose skeleton vanished are restored whole.
    """
    if dt is not None:
        # Fast path: thin the thresholded interior; multiplying the boolean
        # skeleton by `labels` restores per-instance identity.
        inner = dt > dt_thresh
        skel = skimage.morphology.skeletonize(inner, method=method)
        return skel * labels

    # Slow path: strip boundary pixels first so touching instances do not
    # merge during thinning.
    bd = find_boundaries(labels, connectivity=2)
    inner = np.logical_xor(labels > 0, bd)
    skel = skimage.morphology.skeletonize(inner, method=method)
    skeleton = skel * labels

    # fastremap.unique: fast np.unique equivalent; drop background (0).
    original_labels = fastremap.unique(labels)
    original_labels = original_labels[original_labels != 0]

    skeleton_labels = fastremap.unique(skeleton)
    skeleton_labels = skeleton_labels[skeleton_labels != 0]

    # Thinning can erase very small instances entirely; paste those labels
    # back unmodified so no instance disappears from the output.
    missing_labels = np.setdiff1d(original_labels, skeleton_labels)
    missing_labels_mask = fastremap.mask_except(labels, list(missing_labels))
    skeleton += missing_labels_mask

    return skeleton
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def masks_to_outlines(masks, omni=False, mode="inner", connectivity=None):
    """Return a 0/1 outline mask for labeled instances.

    Parameters
    ----------
    masks : ndarray
        2D or 3D label matrix. 3D inputs are processed slice-by-slice.
    omni : bool
        If True, use the fast native :func:`find_boundaries` (treats the
        label matrix directly). If False, use the legacy per-label
        erosion path (slower, iterates over each label individually).
    mode : str
        NOTE(review): currently unused on both paths — the omni path does
        not forward it to :func:`find_boundaries`, and the legacy path
        always behaves like ``"inner"``. Kept for interface compatibility.
    connectivity : int, optional
        Forwarded to :func:`find_boundaries` when ``omni=True``. Defaults
        to ``masks.ndim``.

    Returns
    -------
    ndarray of bool
        True at outline pixels, same shape as *masks*.
    """
    if masks.ndim > 3 or masks.ndim < 2:
        raise ValueError(
            f"masks_to_outlines takes 2D or 3D array, not {masks.ndim}D array"
        )

    outlines = np.zeros(masks.shape, bool)

    if masks.ndim == 3:
        # Recurse on each z-slice with identical options.
        for i in range(masks.shape[0]):
            outlines[i] = masks_to_outlines(
                masks[i], omni=omni, mode=mode, connectivity=connectivity,
            )
        return outlines

    if omni:
        if connectivity is None:
            connectivity = masks.ndim
        return find_boundaries(masks, connectivity=connectivity).astype(bool)

    # Legacy per-label erosion path.
    from scipy.ndimage import binary_erosion, find_objects
    slices = find_objects(masks.astype(int))
    for i, si in enumerate(slices):
        if si is not None:
            # si is a (row, col) slice pair; masks is guaranteed 2D here.
            sr, sc = si
            mask = (masks[sr, sc] == (i + 1))
            # Inner boundary: pixels removed by one binary erosion step.
            boundary = mask & ~binary_erosion(mask)
            pvr, pvc = np.nonzero(boundary)
            # Map bounding-box-local coordinates back into the full image.
            vr, vc = pvr + sr.start, pvc + sc.start
            outlines[vr, vc] = 1
    return outlines
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def hysteresis_threshold(image, low, high):
    """PyTorch implementation of ``skimage.filters.apply_hysteresis_threshold``.

    Pixels above *high* act as seeds; pixels above *low* are kept only if
    connected (via the full 3x3 / 3x3x3 neighborhood) to a seed. Supports
    2D and 3D spatial inputs and expects batch+channel dims, i.e.
    ``(B, C, *spatial)`` with a single channel. Minor discrepancies vs.
    skimage can occur for very high thresholds on thin objects.

    Parameters
    ----------
    image : tensor or array-like
        Input intensities, shape ``(B, C, *spatial)``.
    low, high : float
        Lower and upper hysteresis thresholds (``low < high``).

    Returns
    -------
    torch.Tensor (bool)
        Boolean mask, same shape as *image*.
    """
    import torch

    if not isinstance(image, torch.Tensor):
        image = torch.tensor(image)

    mask_low = image > low
    mask_high = image > high
    # Bug fix: seed the iteration from the *high* mask and grow it into the
    # low mask. The previous code seeded from mask_low; because the dilation
    # kernel includes the center pixel, mask_low is already a fixed point of
    # the update, so the function degenerated to returning `image > low`.
    thresholded = mask_high.clone()

    spatial_dims = len(image.shape) - 2
    kernel_size = [3] * spatial_dims
    # float32 to match `thresholded.float()` below regardless of image dtype
    # (a dtype mismatch previously made conv fail on float64 input).
    hysteresis_kernel = torch.ones([1, 1] + kernel_size, device=image.device, dtype=torch.float32)

    thresholded_old = torch.zeros_like(thresholded)
    # Iteratively dilate the seeds and clip to mask_low until a fixed point.
    while (thresholded_old != thresholded).any():
        if spatial_dims == 2:
            hysteresis_magnitude = torch.nn.functional.conv2d(thresholded.float(), hysteresis_kernel, padding=1)
        elif spatial_dims == 3:
            hysteresis_magnitude = torch.nn.functional.conv3d(thresholded.float(), hysteresis_kernel, padding=1)
        else:
            raise ValueError(f'Unsupported number of spatial dimensions: {spatial_dims}')
        thresholded_old.copy_(thresholded)
        thresholded = ((hysteresis_magnitude > 0) & mask_low) | mask_high

    return thresholded.bool()
|