ocdkit-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocdkit/__init__.py +10 -0
- ocdkit/array/__init__.py +3 -0
- ocdkit/array/convert.py +121 -0
- ocdkit/array/filters.py +56 -0
- ocdkit/array/imports.py +4 -0
- ocdkit/array/index.py +242 -0
- ocdkit/array/morphology.py +194 -0
- ocdkit/array/normalize.py +425 -0
- ocdkit/array/ops.py +134 -0
- ocdkit/array/spatial.py +410 -0
- ocdkit/array/transform.py +261 -0
- ocdkit/array/union_find.py +52 -0
- ocdkit/array/warp.py +28 -0
- ocdkit/imports.py +8 -0
- ocdkit/io/__init__.py +3 -0
- ocdkit/io/files.py +141 -0
- ocdkit/io/image.py +138 -0
- ocdkit/io/imports.py +4 -0
- ocdkit/io/path.py +68 -0
- ocdkit/io/result.py +34 -0
- ocdkit/load/__init__.py +5 -0
- ocdkit/load/module.py +132 -0
- ocdkit/load/object.py +136 -0
- ocdkit/logging/__init__.py +3 -0
- ocdkit/logging/handler.py +206 -0
- ocdkit/measure/__init__.py +3 -0
- ocdkit/measure/bbox.py +188 -0
- ocdkit/measure/diameter.py +185 -0
- ocdkit/measure/imports.py +4 -0
- ocdkit/measure/medoid.py +181 -0
- ocdkit/measure/metrics.py +43 -0
- ocdkit/plot/__init__.py +5 -0
- ocdkit/plot/color.py +215 -0
- ocdkit/plot/contour.py +102 -0
- ocdkit/plot/defaults.py +147 -0
- ocdkit/plot/display.py +133 -0
- ocdkit/plot/export.py +108 -0
- ocdkit/plot/figure.py +24 -0
- ocdkit/plot/grid.py +306 -0
- ocdkit/plot/imports.py +9 -0
- ocdkit/plot/label.py +733 -0
- ocdkit/plot/ncolor.py +54 -0
- ocdkit/utils/__init__.py +3 -0
- ocdkit/utils/collections.py +97 -0
- ocdkit/utils/gpu.py +210 -0
- ocdkit/utils/kwargs.py +136 -0
- ocdkit-0.0.1.dist-info/METADATA +66 -0
- ocdkit-0.0.1.dist-info/RECORD +51 -0
- ocdkit-0.0.1.dist-info/WHEEL +5 -0
- ocdkit-0.0.1.dist-info/licenses/LICENSE +28 -0
- ocdkit-0.0.1.dist-info/top_level.txt +1 -0

ocdkit/array/normalize.py
ADDED

@@ -0,0 +1,425 @@
"""Normalization and contrast adjustment functions."""

from .imports import *
from .convert import get_module, safe_divide, rescale
from .ops import searchsorted

from ..utils.gpu import torch_GPU


def torch_norm(a, dim=0, keepdim=False):
    """Compute vector magnitude along *dim*.

    Works on numpy arrays and torch tensors. For torch, avoids
    intermediate allocations and supports autograd.
    """
    module = get_module(a)
    if module == np:
        return np.sqrt(np.sum(a**2, axis=dim, keepdims=keepdim))
    norm_sq = (a * a).sum(dim=dim, keepdim=keepdim)
    return norm_sq.sqrt_() if not norm_sq.requires_grad else norm_sq.sqrt()
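
# Usage sketch (annotation added for clarity, not shipped code): two 2-vectors
# stacked along dim 0, exercising the numpy branch.
# >>> v = np.array([[3.0, 0.0],
# ...               [4.0, 1.0]])
# >>> torch_norm(v, dim=0)
# array([5., 1.])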


def _auto_chunked_quantile(tensor, q):
    """Chunked quantile for large tensors that exceed torch.quantile limits."""
    import math
    max_elements = int(16e6 - 1)
    num_elements = tensor.nelement()
    chunk_size = max(1, math.ceil(num_elements / max_elements))
    chunks = torch.chunk(tensor, chunk_size)
    return torch.stack([torch.quantile(chunk, q) for chunk in chunks]).mean(dim=0)


def normalize99(Y, lower=0.01, upper=99.99, contrast_limits=None, dim=None, **kwargs):
    """Clip to percentile range and rescale to [0, 1].

    Works on numpy arrays and torch tensors.

    Parameters
    ----------
    Y : array or tensor
        Input data.
    lower : float
        Lower percentile (0-100).
    upper : float
        Upper percentile (0-100).
    contrast_limits : tuple of float, optional
        Explicit (low, high) values instead of computing from percentiles.
    dim : int, optional
        Normalize independently along this dimension.
    """
    module = get_module(Y)

    if contrast_limits is None:
        quantiles = np.array([lower, upper]) / 100
        if module == torch:
            quantiles = torch.tensor(quantiles, dtype=Y.dtype, device=Y.device)

        if dim is not None:
            Y_flattened = Y.reshape(Y.shape[dim], -1)
            lower_val, upper_val = module.quantile(Y_flattened, quantiles, axis=-1)
            if dim == 0:
                lower_val = lower_val.reshape(Y.shape[dim], *([1] * (len(Y.shape) - 1)))
                upper_val = upper_val.reshape(Y.shape[dim], *([1] * (len(Y.shape) - 1)))
            else:
                lower_val = lower_val.reshape(*Y.shape[:dim], *([1] * (len(Y.shape) - dim - 1)))
                upper_val = upper_val.reshape(*Y.shape[:dim], *([1] * (len(Y.shape) - dim - 1)))
        else:
            try:
                lower_val, upper_val = module.quantile(Y, quantiles)
            except RuntimeError:
                lower_val, upper_val = _auto_chunked_quantile(Y, quantiles)
    else:
        if module == np:
            contrast_limits = np.array(contrast_limits)
        elif module == torch:
            contrast_limits = torch.tensor(contrast_limits)
        lower_val, upper_val = contrast_limits

    return module.clip(safe_divide(Y - lower_val, upper_val - lower_val), 0, 1)
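
# Usage sketch (illustrative comment, not part of the published file): with
# explicit contrast_limits the function reduces to a clipped linear rescale,
# assuming safe_divide behaves like ordinary division for a nonzero denominator.
# >>> img = np.array([0.0, 1.0, 2.0, 3.0, 100.0])
# >>> normalize99(img, contrast_limits=(0, 10))
# array([0. , 0.1, 0.2, 0.3, 1. ])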


def normalize_field(mu, cutoff=0, **kwargs):
    """Normalize all nonzero field vectors to unit magnitude.

    Works on numpy arrays and torch tensors (auto-detected).

    Parameters
    ----------
    mu : array or tensor
        Vector field, shape ``(D, *spatial)``.
    cutoff : float
        Vectors with magnitude below this are left unchanged.
    """
    module = get_module(mu)
    if module == torch:
        mag = torch_norm(mu, dim=0)
        return torch.where(mag > cutoff, mu / mag, mu)
    mag = np.sqrt(np.nansum(mu**2, axis=0))
    valid = mag > cutoff
    return np.where(valid, mu / np.where(valid, mag, 1.0), mu)
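
# Usage sketch (illustrative, not from the wheel): one vector of magnitude 5
# and one zero vector; the former is scaled to unit length, the latter is
# left untouched.
# >>> mu = np.array([[3.0, 0.0],
# ...                [4.0, 0.0]])
# >>> normalize_field(mu)
# array([[0.6, 0. ],
#        [0.8, 0. ]])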


def quantile_rescale(Y, lower=0.0001, upper=0.9999, contrast_limits=None, bins=None):
    """Sort-based quantile rescale to [0, 1].

    Slower than ``normalize99`` for large arrays (mergesort cost), but the
    explicit sort lets the caller plug in a different quantile estimator
    later. Numpy only for now.
    """
    sorted_array = np.sort(Y.flatten(), kind="mergesort")
    lower_idx = int(lower * (len(sorted_array) - 1))
    upper_idx = int(upper * (len(sorted_array) - 1))
    lower_val, upper_val = sorted_array[lower_idx], sorted_array[upper_idx]
    r = safe_divide(Y - lower_val, upper_val - lower_val)
    r[r < 0] = 0
    r[r > 1] = 1
    return r


def normalize99_hist(Y, lower=0.01, upper=99.99, contrast_limits=None, bins=None):
    """Histogram-based percentile clip and rescale to [0, 1].

    An alternative to ``normalize99`` that estimates quantiles from a CDF
    over a histogram (cheaper than torch.quantile on huge tensors).
    Works on numpy and torch.
    """
    upper = upper / 100
    lower = lower / 100

    module = get_module(Y)
    if bins is None:
        num_elements = Y.size if module == np else Y.numel()
        bins = int(np.sqrt(num_elements))

    if contrast_limits is None:
        hist, bin_edges = module.histogram(Y, bins=bins)
        cdf = module.cumsum(hist, axis=0) / module.sum(hist)
        lower_val = bin_edges[searchsorted(cdf, lower)]
        upper_val = bin_edges[searchsorted(cdf, upper)]
    else:
        if module == np:
            contrast_limits = np.array(contrast_limits)
        elif module == torch:
            contrast_limits = torch.tensor(contrast_limits)
        lower_val, upper_val = contrast_limits

    r = safe_divide(Y - lower_val, upper_val - lower_val)
    r[r < 0] = 0
    r[r > 1] = 1
    return r
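
# Sketch of the CDF lookup step above in plain numpy (added annotation, not
# shipped code): searchsorted(cdf, q) counts CDF bins strictly below q, so the
# quantile estimate is just a bin-edge lookup.
# >>> hist, bin_edges = np.histogram(np.random.rand(10000), bins=100)
# >>> cdf = np.cumsum(hist) / hist.sum()
# >>> bin_edges[(cdf < 0.999).sum()]  # approximate 99.9th percentile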


def qnorm(
    Y,
    nbins=100,
    bw_method=2,
    density_cutoff=None,
    density_quantile=(0.001, 0.999),
    debug=False,
    dx=None,
    log=False,
    eps=1,
):
    """Density-based quantile normalization.

    Bins the histogram, fits a symmetric KDE to the density of the histogram
    counts, and clips to the range where density exceeds *density_cutoff*.
    Useful when the brightness distribution has a heavy tail and a simple
    percentile clip is too aggressive. Numpy only for now.
    """
    import fastremap
    from scipy.stats import gaussian_kde

    if dx is not None:
        X = Y[:, ::dx, ::dx]
    else:
        X = Y

    if X.dtype not in (np.uint8, np.uint16, np.uint32, np.uint64):
        X = (rescale(X) * (2 ** 16 - 1)).astype(np.uint16)

    # bin counts
    unique_values, counts = fastremap.unique(X, return_counts=True)
    bin_edges = np.linspace(unique_values.min(), unique_values.max(), nbins + 1)
    bin_indices = np.digitize(unique_values, bin_edges) - 1
    binned_counts = np.bincount(bin_indices, weights=counts, minlength=nbins)
    bin_start = bin_edges[:-1]
    binned_counts = binned_counts[:-1]

    sel = binned_counts > 0
    counts_sel = binned_counts[sel]
    unique_sel = bin_start[sel]
    x = np.arange(len(counts_sel))
    y = np.log(counts_sel + eps) if log else counts_sel

    # symmetric KDE density
    points = np.vstack([x, y])
    kde = gaussian_kde(points, bw_method=bw_method)
    density = kde(points)
    inverted_kde = gaussian_kde(np.vstack([-x, y]), bw_method=bw_method)
    inverted_density = inverted_kde(np.vstack([-x, y]))
    d = rescale((density + inverted_density) / 2)

    if not isinstance(density_quantile, (list, tuple)):
        density_quantile = (density_quantile, density_quantile)

    if density_cutoff is None:
        density_cutoff = np.quantile(d, density_quantile)  # pragma: no cover
    elif not isinstance(density_cutoff, (list, tuple)):
        density_cutoff = (density_cutoff, density_cutoff)

    imin = np.argwhere(d > density_cutoff[0])[0][0]
    imax = np.argwhere(d > density_cutoff[1])[-1][0]
    vmin, vmax = unique_sel[imin], unique_sel[imax]

    if vmax > vmin:
        scale_factor = np.float16(1.0 / (vmax - vmin))
        r = X * scale_factor
        r[r > 1] = 1
    else:
        r = X

    if debug:
        return r, x, y, d, imin, imax, vmin, vmax
    return r


def localnormalize(im, sigma1=2, sigma2=20):
    """Local mean/std normalization via Gaussian blurs.

    Works on numpy and torch (auto-dispatched). For torch, uses
    ``torchvision.transforms.functional.gaussian_blur``; for numpy uses
    ``scipy.ndimage.gaussian_filter``.
    """
    module = get_module(im)
    if module == torch:
        import torchvision.transforms.functional as TF
        im = normalize99(im)
        ks1 = round(sigma1 * 6)
        ks1 += ks1 % 2 == 0
        blur1 = TF.gaussian_blur(im, ks1, sigma1)
        num = im - blur1
        ks2 = round(sigma2 * 6)
        ks2 += ks2 % 2 == 0
        blur2 = TF.gaussian_blur(num * num, ks2, sigma2)
        den = torch.sqrt(blur2)
        return normalize99(num / den + 1e-8)

    from scipy.ndimage import gaussian_filter
    im = normalize99(im)
    blur1 = gaussian_filter(im, sigma=sigma1)
    num = im - blur1
    blur2 = gaussian_filter(num * num, sigma=sigma2)
    den = np.sqrt(blur2)
    return normalize99(num / den + 1e-8)
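
# Note on the kernel sizes above (added annotation, not shipped code): the
# torch path derives an odd kernel of roughly 6*sigma, since torchvision's
# gaussian_blur requires odd kernel sizes.
# >>> sigma1 = 2
# >>> ks1 = round(sigma1 * 6)   # 12
# >>> ks1 += ks1 % 2 == 0       # 13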


# Backward-compat alias
localnormalize_GPU = localnormalize


def pnormalize(Y, p_min=-1, p_max=10):
    """Power-mean normalization to [0, 1].

    Uses ``L^p`` norms with negative *p_min* (approximating min) and positive
    *p_max* (approximating max) as soft min/max estimators. Works on numpy
    and torch (auto-dispatched).
    """
    module = get_module(Y)
    lower_val = (module.abs(Y * 1.0) ** p_min).sum() ** (1.0 / p_min)
    upper_val = (module.abs(Y * 1.0) ** p_max).sum() ** (1.0 / p_max)
    return module.clip(safe_divide(Y - lower_val, upper_val - lower_val), 0, 1)
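
# Sketch of the soft min/max idea (illustrative, not from the wheel): a large
# positive exponent tracks the maximum, while a negative exponent gives a
# rough, low-biased minimum estimate.
# >>> Y = np.array([1.0, 2.0, 4.0])
# >>> (np.abs(Y) ** 10).sum() ** (1 / 10)   # ~4.0, close to max(Y)
# >>> (np.abs(Y) ** -1).sum() ** (1 / -1)   # ~0.57, a soft lower bound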


def normalize_image(
    im,
    mask,
    target=0.5,
    foreground=False,
    iterations=1,
    scale=1,
    channel_axis=0,
    per_channel=True,
):
    """Mask-aware gamma normalization to push masked region mean to *target*.

    Numpy implementation. Optionally erodes the mask before computing the
    target mean to avoid edge contamination.
    """
    from scipy.ndimage import binary_erosion
    try:
        import numexpr as ne
    except Exception:  # pragma: no cover
        ne = None

    im = im.astype("float32") * scale
    im_min = im.min()
    im_max = im.max()
    if ne is None:
        im = (im - im_min) / (im_max - im_min)
    else:
        ne.evaluate("(im - im_min) / (im_max - im_min)", out=im)

    if im.ndim > 2:
        im = np.moveaxis(im, channel_axis, -1)
    else:
        im = np.expand_dims(im, axis=-1)

    if not isinstance(mask, list):
        mask = np.expand_dims(mask, axis=-1)
        mask = np.broadcast_to(mask, im.shape)

    bin0 = mask > 0 if foreground else mask == 0
    if iterations > 0:
        structure = np.ones((3,) * (im.ndim - 1) + (1,))
        structure[1, ...] = 0
        bin0 = binary_erosion(bin0, structure=structure, iterations=iterations)

    masked_im = im.copy()
    masked_im[~bin0] = np.nan
    source_target = np.nanmean(masked_im, axis=(0, 1) if per_channel else None)
    source_target = source_target.astype("float32")
    target = np.array(target).astype("float32")
    if ne is None:
        im = im ** (np.log(target) / np.log(source_target))
    else:
        ne.evaluate("im ** (log(target) / log(source_target))", out=im)
    return np.moveaxis(im, -1, channel_axis).squeeze()
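
# Sketch of the gamma step above (added annotation, not shipped code): the
# exponent log(target)/log(source_mean) maps the value source_mean exactly
# onto target; the mean of the transformed image then lands near target.
# >>> source_mean, target = 0.25, 0.5
# >>> gamma = np.log(target) / np.log(source_mean)   # 0.5
# >>> source_mean ** gamma
# 0.5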


def adjust_contrast_masked(
    img,
    masks,
    r_target=1.10,
    plo=0.01,
    phi=99.99,
    clip_output=True,
):
    """Mask-aware percentile-clip + gamma to hit a target fg/bg ratio.

    Returns ``(adjusted, gamma, (lo, hi))``. Numpy only.
    """
    x = np.asarray(img, dtype=np.float32)
    m = np.asarray(masks).astype(bool)
    bg = ~m
    fg = m

    if fg.sum() == 0 or bg.sum() == 0:
        return x.copy(), 1.0, (float(np.min(x)), float(np.max(x)))

    a = np.percentile(x[bg], plo)
    b = np.percentile(x[fg], phi)
    if not np.isfinite(a) or not np.isfinite(b) or b <= a:
        a = float(np.min(x))
        b = float(np.max(x))

    if not np.isfinite(b - a) or b <= a:
        return x.copy(), 1.0, (float(a), float(b))

    j = (x - a) / (b - a)
    j = np.clip(j, 0.0, 1.0)

    m_fg = float(j[fg].mean())
    m_bg = float(j[bg].mean() + 1e-12)
    r = m_fg / m_bg

    if (r >= 1.0 and r_target < 1.0) or (r <= 1.0 and r_target > 1.0):  # pragma: no cover
        return j.copy(), 1.0, (a, b)

    if abs(np.log(max(r, 1e-12))) < 1e-8 or abs((r - r_target) / max(r_target, 1e-12)) < 1e-3:  # pragma: no cover
        y = j
        gamma = 1.0
    else:
        gamma = float(np.log(max(r_target, 1e-12)) / np.log(max(r, 1e-12)))
        gamma = float(np.clip(gamma, 0.2, 5.0))
        y = np.power(j, gamma)

    if clip_output:
        y = np.clip(y, 0.0, 1.0)

    return y.astype(np.float32), gamma, (float(a), float(b))


def gamma_normalize(
    im,
    mask,
    target=1.0,
    scale=1.0,
    foreground=True,
    iterations=0,
    per_channel=True,
    channel_axis=-1,
):
    """Torch (GPU-accelerated) variant of :func:`normalize_image`.

    Uses ``ocdkit.gpu.torch_GPU`` for the device.
    """
    from scipy.ndimage import binary_erosion

    device = torch_GPU
    im = rescale(im) * scale
    if im.ndim > 2:
        im = np.moveaxis(im, channel_axis, -1)
    else:
        im = np.expand_dims(im, axis=-1)

    if not isinstance(mask, list):
        mask = np.stack([mask] * im.shape[-1], axis=-1)

    im = torch.from_numpy(im).float().to(device)
    mask = torch.from_numpy(mask).float().to(device)

    bin0 = mask > 0 if foreground else mask == 0
    if iterations > 0:
        structure = torch.ones((3,) * (im.ndim - 1) + (1,)).to(device)
        structure[1, ...] = 0
        bin0 = torch.from_numpy(
            binary_erosion(bin0.cpu().numpy(), structure=structure.cpu().numpy(), iterations=iterations)
        ).to(device)

    masked_im = im.masked_fill(~bin0, float("nan"))
    source_target = torch.nanmean(masked_im, dim=(0, 1) if per_channel else None)
    im **= (torch.log(target) / torch.log(source_target))

    return im.permute(*[channel_axis] + [i for i in range(im.ndim) if i != channel_axis]).squeeze().cpu().numpy()
ocdkit/array/ops.py
ADDED

@@ -0,0 +1,134 @@
"""Miscellaneous array utilities — divergence, noise, search, metadata."""

from .imports import *
from scipy.ndimage import convolve1d, gaussian_filter

from .convert import get_module


def divergence(f):
    """Divergence of a vector field, dispatched on backend.

    Parameters
    ----------
    f : array or tensor
        Numpy: shape ``(D, *spatial)`` — unbatched D-vector field.
        Torch: shape ``(B, D, *spatial)`` — batched D-vector field.

    Returns
    -------
    div : array or tensor
        Numpy: shape ``(*spatial,)``.
        Torch: shape ``(B, *spatial)``.

    Notes
    -----
    Returns zeros if any spatial dimension has size < 2 (gradient undefined).
    On the CPU torch path this loops over components rather than calling
    ``torch.gradient`` on all dims at once, which would do unnecessary work.
    """
    module = get_module(f)
    if module == np:
        num_dims = len(f)
        if any(f.shape[1 + i] < 2 for i in range(num_dims)):
            return np.zeros_like(f[0])
        return np.add.reduce([np.gradient(f[i], axis=i) for i in range(num_dims)])

    # Torch path: batched (B, D, *spatial)
    B, D, *spatial = f.shape
    if any(s < 2 for s in spatial):
        return torch.zeros((B, *spatial), dtype=f.dtype, device=f.device)
    div = torch.zeros((B, *spatial), dtype=f.dtype, device=f.device)
    for d in range(D):
        div += torch.gradient(f[:, d], dim=d + 1)[0]
    return div
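
# Usage sketch (illustrative comment, not part of the published file): the
# identity field (component i equals coordinate i) has divergence 2 everywhere
# on the numpy path.
# >>> yy, xx = np.meshgrid(np.arange(4.0), np.arange(4.0), indexing="ij")
# >>> f = np.stack([yy, xx])
# >>> divergence(f)[0, 0]
# 2.0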


def searchsorted(tensor, value):
    """Find indices where *value* should be inserted in *tensor* to keep order.

    Backend-agnostic via ``get_module``: works on numpy arrays, torch tensors,
    and any input where ``(tensor < value).sum()`` is meaningful.
    """
    return (tensor < value).sum()
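
# Equivalence note (added annotation, not shipped code): on a sorted 1-D array
# counting elements strictly below the value matches numpy's left-side insert.
# >>> cdf = np.array([0.1, 0.4, 0.7, 1.0])
# >>> (cdf < 0.5).sum()                        # 2
# >>> np.searchsorted(cdf, 0.5, side="left")   # 2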


def enumerate_nested(*lists, parent_indices=None):
    """Traverse matching nested lists and yield indices with corresponding values.

    Parameters
    ----------
    lists : list(s)
        One or more nested lists to traverse. All must share the same structure.
    parent_indices : list, optional
        Indices accumulated by recursive calls (internal).

    Yields
    ------
    tuple
        ``(indices, values...)`` — index path plus one value from each input.
    """
    if parent_indices is None:
        parent_indices = []

    if all(isinstance(lst[0], list) for lst in lists):
        for i, sublists in enumerate(zip(*lists)):
            current_indices = parent_indices + [i]
            yield from enumerate_nested(*sublists, parent_indices=current_indices)
    else:
        for i, values in enumerate(zip(*lists)):
            current_indices = parent_indices + [i]
            yield current_indices, *values
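
# Usage sketch (illustrative, not from the wheel): two nested lists with the
# same shape, traversed in lockstep.
# >>> a = [[1, 2], [3]]
# >>> b = [["x", "y"], ["z"]]
# >>> list(enumerate_nested(a, b))
# [([0, 0], 1, 'x'), ([0, 1], 2, 'y'), ([1, 0], 3, 'z')]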


def unique_nonzero(arr):
    """Return sorted unique non-zero values of *arr*."""
    import fastremap
    u = fastremap.unique(arr)
    return u[u != 0]


def get_size(var, unit='GB'):
    """Return the in-memory size of *var* in *unit* ('B', 'KB', 'MB', 'GB').

    Works for any object with an ``nbytes`` attribute (numpy, dask, torch).
    """
    units = {'B': 0, 'KB': 1, 'MB': 2, 'GB': 3}
    return var.nbytes / (1024 ** units[unit])


def random_int(N, M=None, seed=None):
    """Generate random integers in ``[0, N)``.

    Parameters
    ----------
    N : int
        Upper bound (exclusive).
    M : int, optional
        Number of integers to generate. Scalar if None.
    seed : int, optional
        RNG seed. If None, a random seed is generated and printed to stdout.
    """
    if seed is None:
        seed = np.random.randint(0, 2 ** 32 - 1)
        print(f'Seed: {seed}')
    else:
        np.random.seed(seed)
    return np.random.randint(0, N, M)


def moving_average(x, w):
    """1-D moving average via convolution along axis 0."""
    return convolve1d(x, np.ones(w) / w, axis=0)


def add_poisson_noise(image):
    """Apply Poisson noise to an image and clip to [0, 1]."""
    noisy_image = np.random.poisson(image)
    return np.clip(noisy_image, 0, 1)


def correct_illumination(img, sigma=5):
    """Flatten illumination by subtracting a Gaussian-blurred background."""
    blurred = gaussian_filter(img, sigma=sigma)
    return (img - blurred) / np.std(blurred)