ocdkit 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. ocdkit/__init__.py +10 -0
  2. ocdkit/array/__init__.py +3 -0
  3. ocdkit/array/convert.py +121 -0
  4. ocdkit/array/filters.py +56 -0
  5. ocdkit/array/imports.py +4 -0
  6. ocdkit/array/index.py +242 -0
  7. ocdkit/array/morphology.py +194 -0
  8. ocdkit/array/normalize.py +425 -0
  9. ocdkit/array/ops.py +134 -0
  10. ocdkit/array/spatial.py +410 -0
  11. ocdkit/array/transform.py +261 -0
  12. ocdkit/array/union_find.py +52 -0
  13. ocdkit/array/warp.py +28 -0
  14. ocdkit/imports.py +8 -0
  15. ocdkit/io/__init__.py +3 -0
  16. ocdkit/io/files.py +141 -0
  17. ocdkit/io/image.py +138 -0
  18. ocdkit/io/imports.py +4 -0
  19. ocdkit/io/path.py +68 -0
  20. ocdkit/io/result.py +34 -0
  21. ocdkit/load/__init__.py +5 -0
  22. ocdkit/load/module.py +132 -0
  23. ocdkit/load/object.py +136 -0
  24. ocdkit/logging/__init__.py +3 -0
  25. ocdkit/logging/handler.py +206 -0
  26. ocdkit/measure/__init__.py +3 -0
  27. ocdkit/measure/bbox.py +188 -0
  28. ocdkit/measure/diameter.py +185 -0
  29. ocdkit/measure/imports.py +4 -0
  30. ocdkit/measure/medoid.py +181 -0
  31. ocdkit/measure/metrics.py +43 -0
  32. ocdkit/plot/__init__.py +5 -0
  33. ocdkit/plot/color.py +215 -0
  34. ocdkit/plot/contour.py +102 -0
  35. ocdkit/plot/defaults.py +147 -0
  36. ocdkit/plot/display.py +133 -0
  37. ocdkit/plot/export.py +108 -0
  38. ocdkit/plot/figure.py +24 -0
  39. ocdkit/plot/grid.py +306 -0
  40. ocdkit/plot/imports.py +9 -0
  41. ocdkit/plot/label.py +733 -0
  42. ocdkit/plot/ncolor.py +54 -0
  43. ocdkit/utils/__init__.py +3 -0
  44. ocdkit/utils/collections.py +97 -0
  45. ocdkit/utils/gpu.py +210 -0
  46. ocdkit/utils/kwargs.py +136 -0
  47. ocdkit-0.0.1.dist-info/METADATA +66 -0
  48. ocdkit-0.0.1.dist-info/RECORD +51 -0
  49. ocdkit-0.0.1.dist-info/WHEEL +5 -0
  50. ocdkit-0.0.1.dist-info/licenses/LICENSE +28 -0
  51. ocdkit-0.0.1.dist-info/top_level.txt +1 -0
ocdkit/array/normalize.py ADDED
@@ -0,0 +1,425 @@
1
+ """Normalization and contrast adjustment functions."""
2
+
3
+ from .imports import *
4
+ from .convert import get_module, safe_divide, rescale
5
+ from .ops import searchsorted
6
+
7
+ from ..utils.gpu import torch_GPU
8
+
9
+
10
def torch_norm(a, dim=0, keepdim=False):
    """Return the Euclidean magnitude of *a* along axis *dim*.

    Accepts a numpy array or a torch tensor (backend picked by ``get_module``).

    Parameters
    ----------
    a : array or tensor
        Input values.
    dim : int
        Axis along which squared components are summed.
    keepdim : bool
        Keep the reduced axis with length 1.

    Returns
    -------
    array or tensor
        ``sqrt(sum(a**2, dim))`` on the same backend as *a*.
    """
    if get_module(a) == np:
        squared_sum = np.sum(a**2, axis=dim, keepdims=keepdim)
        return np.sqrt(squared_sum)

    # Torch: take the square root in place unless autograd tracks the
    # intermediate, in which case the out-of-place op keeps backward valid.
    squared_sum = (a * a).sum(dim=dim, keepdim=keepdim)
    if squared_sum.requires_grad:
        return squared_sum.sqrt()
    return squared_sum.sqrt_()
21
+
22
+
23
def _auto_chunked_quantile(tensor, q, max_elements=int(16e6 - 1)):
    """Approximate ``torch.quantile`` for tensors too large for it.

    ``torch.quantile`` rejects inputs above a fixed element count. This helper
    flattens *tensor* and splits it into contiguous chunks of at most
    *max_elements* values. It computes the quantile(s) *q* of every chunk and
    returns their mean across chunks.

    The result is an approximation. The mean of per-chunk quantiles equals the
    global quantile only when every chunk has the same value distribution.

    Parameters
    ----------
    tensor : torch.Tensor
        Input values of any shape; flattened before chunking.
    q : float or torch.Tensor
        Quantile(s) in [0, 1], as accepted by ``torch.quantile``.
    max_elements : int
        Maximum number of elements per chunk.

    Returns
    -------
    torch.Tensor
        Same shape as ``torch.quantile(chunk, q)``.
    """
    import math

    # Flatten first. torch.chunk splits along dim 0, so chunking an
    # unflattened tensor can give at most ``tensor.shape[0]`` chunks, and
    # those chunks may still exceed the limit (e.g. shape (2, 10000, 10000)).
    flat = tensor.reshape(-1)
    num_chunks = max(1, math.ceil(flat.numel() / max_elements))
    chunks = torch.chunk(flat, num_chunks)
    return torch.stack([torch.quantile(chunk, q) for chunk in chunks]).mean(dim=0)
31
+
32
+
33
def normalize99(Y, lower=0.01, upper=99.99, contrast_limits=None, dim=None, **kwargs):
    """Clip to a percentile range and rescale to [0, 1].

    Works on numpy arrays and torch tensors (backend chosen by ``get_module``).

    Parameters
    ----------
    Y : array or tensor
        Input data.
    lower : float
        Lower percentile (0-100).
    upper : float
        Upper percentile (0-100).
    contrast_limits : tuple of float, optional
        Explicit ``(low, high)`` values to use instead of computed percentiles.
    dim : int, optional
        Normalize each slice along this axis independently. Negative axes
        are accepted. Ignored when *contrast_limits* is given.
    **kwargs
        Accepted for call-site compatibility; unused.

    Returns
    -------
    array or tensor
        ``(Y - low) / (high - low)`` (via ``safe_divide``), clipped to [0, 1].
    """
    module = get_module(Y)

    if contrast_limits is None:
        quantiles = np.array([lower, upper]) / 100
        if module == torch:
            quantiles = torch.tensor(quantiles, dtype=Y.dtype, device=Y.device)

        if dim is not None:
            ndim = Y.ndim
            dim = dim % ndim  # normalize negative axes
            # Move the normalization axis to the front before flattening.
            # Each row then holds exactly the values of one slice along
            # *dim*; a bare reshape would mix slices for any dim != 0.
            Y_flattened = module.moveaxis(Y, dim, 0).reshape(Y.shape[dim], -1)
            lower_val, upper_val = module.quantile(Y_flattened, quantiles, axis=-1)
            # Shape the per-slice limits as (1, ..., n, ..., 1) so they
            # broadcast against Y along *dim*.
            bshape = [1] * ndim
            bshape[dim] = Y.shape[dim]
            lower_val = lower_val.reshape(bshape)
            upper_val = upper_val.reshape(bshape)
        else:
            try:
                lower_val, upper_val = module.quantile(Y, quantiles)
            except RuntimeError:
                # torch.quantile rejects very large inputs; fall back to a
                # chunked (approximate) estimate.
                lower_val, upper_val = _auto_chunked_quantile(Y, quantiles)
    else:
        if module == np:
            contrast_limits = np.array(contrast_limits)
        elif module == torch:
            contrast_limits = torch.tensor(contrast_limits)
        lower_val, upper_val = contrast_limits

    return module.clip(safe_divide(Y - lower_val, upper_val - lower_val), 0, 1)
80
+
81
+
82
def normalize_field(mu, cutoff=0, **kwargs):
    """Scale each field vector whose magnitude exceeds *cutoff* to unit length.

    Works on numpy arrays and torch tensors (auto-detected).

    Parameters
    ----------
    mu : array or tensor
        Vector field, shape ``(D, *spatial)``; components along axis 0.
    cutoff : float
        Vectors with magnitude ``<= cutoff`` are returned unchanged.
    **kwargs
        Accepted for call-site compatibility; unused.

    Notes
    -----
    The numpy path computes the magnitude with ``nansum``, so NaN components
    are ignored there. The torch path uses :func:`torch_norm`.
    """
    if get_module(mu) != torch:
        magnitude = np.sqrt(np.nansum(mu**2, axis=0))
        keep = magnitude > cutoff
        # Put 1 in the denominator wherever the vector is left untouched,
        # so the division never sees a zero magnitude.
        safe_magnitude = np.where(keep, magnitude, 1.0)
        return np.where(keep, mu / safe_magnitude, mu)

    magnitude = torch_norm(mu, dim=0)
    return torch.where(magnitude > cutoff, mu / magnitude, mu)
101
+
102
+
103
def quantile_rescale(Y, lower=0.0001, upper=0.9999, contrast_limits=None, bins=None):
    """Sort-based quantile clip and rescale to [0, 1].

    Slower than ``normalize99`` for large arrays (mergesort cost). The
    explicit sort, however, lets the caller plug in a different quantile
    estimator later. Numpy only for now.

    Parameters
    ----------
    Y : ndarray
        Input data.
    lower, upper : float
        Quantiles as fractions in [0, 1] (not percentages). Each is taken as
        the sorted element at index ``int(q * (N - 1))``.
    contrast_limits : tuple of float, optional
        Explicit ``(low, high)`` values. When given, the sort is skipped;
        previously this argument was silently ignored.
    bins : ignored
        Accepted for signature parity with ``normalize99_hist``.

    Returns
    -------
    ndarray
        ``(Y - low) / (high - low)`` (via ``safe_divide``), clipped to [0, 1].
    """
    if contrast_limits is None:
        # axis=None sorts the flattened data without an extra flatten copy.
        sorted_array = np.sort(Y, axis=None, kind="mergesort")
        last = len(sorted_array) - 1
        lower_val = sorted_array[int(lower * last)]
        upper_val = sorted_array[int(upper * last)]
    else:
        lower_val, upper_val = contrast_limits

    r = safe_divide(Y - lower_val, upper_val - lower_val)
    r[r < 0] = 0
    r[r > 1] = 1
    return r
118
+
119
+
120
def normalize99_hist(Y, lower=0.01, upper=99.99, contrast_limits=None, bins=None):
    """Percentile clip and rescale to [0, 1] using histogram-estimated quantiles.

    An alternative to ``normalize99``. The quantiles come from the cumulative
    distribution of a histogram of *Y*, which is cheaper than torch.quantile
    on huge tensors. Works on numpy and torch.

    Parameters
    ----------
    Y : array or tensor
        Input data.
    lower, upper : float
        Percentiles (0-100).
    contrast_limits : tuple of float, optional
        Explicit ``(low, high)``; skips the histogram.
    bins : int, optional
        Histogram bin count; defaults to ``int(sqrt(number of elements))``.
    """
    lower_q = lower / 100
    upper_q = upper / 100
    module = get_module(Y)

    if bins is None:
        n = Y.size if module == np else Y.numel()
        bins = int(np.sqrt(n))

    if contrast_limits is not None:
        if module == np:
            contrast_limits = np.array(contrast_limits)
        elif module == torch:
            contrast_limits = torch.tensor(contrast_limits)
        lower_val, upper_val = contrast_limits
    else:
        counts, edges = module.histogram(Y, bins=bins)
        cdf = module.cumsum(counts, axis=0) / module.sum(counts)
        # Left edge of the first bin whose cumulative fraction reaches q.
        lower_val = edges[searchsorted(cdf, lower_q)]
        upper_val = edges[searchsorted(cdf, upper_q)]

    scaled = safe_divide(Y - lower_val, upper_val - lower_val)
    scaled[scaled < 0] = 0
    scaled[scaled > 1] = 1
    return scaled
151
+
152
+
153
def qnorm(
    Y,
    nbins=100,
    bw_method=2,
    density_cutoff=None,
    density_quantile=(0.001, 0.999),
    debug=False,
    dx=None,
    log=False,
    eps=1,
):
    """Density-based quantile normalization.

    Steps:

    1. Bin the intensities into *nbins* equal-width bins over the observed
       range.
    2. Fit a Gaussian KDE to the 2-D points ``(bin rank, count)`` of the
       non-empty bins. Counts are ``log(count + eps)`` when *log* is True.
    3. Average that density with the KDE of the mirrored points
       ``(-rank, count)`` and rescale the result.
    4. Take ``vmin`` from the first bin whose density exceeds
       ``density_cutoff[0]`` and ``vmax`` from the last bin above
       ``density_cutoff[1]``. Both are bin left edges.
    5. Multiply the image by ``1 / (vmax - vmin)`` and clip it above at 1.

    Useful when the brightness distribution has a heavy tail and a simple
    percentile clip is too aggressive. Numpy only for now.

    Parameters
    ----------
    Y : ndarray
        Input image. Non-unsigned dtypes are passed through ``rescale`` and
        scaled to uint16 first.
    nbins : int
        Number of histogram bins.
    bw_method : float or str
        Passed to ``scipy.stats.gaussian_kde``.
    density_cutoff : float or pair of float, optional
        Density thresholds for the low and high ends. Defaults to the
        *density_quantile* quantiles of the rescaled density.
    density_quantile : float or pair of float
        Quantiles used when *density_cutoff* is None.
    debug : bool
        Also return intermediate values.
    dx : int, optional
        Subsampling step; ``Y[:, ::dx, ::dx]`` is used (assumes a 3-D input;
        TODO confirm). The returned image is the subsampled array.
    log : bool
        Fit the KDE on log-counts.
    eps : float
        Offset added to counts before the log.

    Returns
    -------
    r : ndarray
        The normalized image. When ``vmax <= vmin`` it is the unscaled
        (possibly uint16-converted) input. With *debug*, returns
        ``(r, x, y, d, imin, imax, vmin, vmax)`` instead.
    """
    import fastremap
    from scipy.stats import gaussian_kde

    X = Y[:, ::dx, ::dx] if dx is not None else Y

    if X.dtype not in (np.uint8, np.uint16, np.uint32, np.uint64):
        X = (rescale(X) * (2 ** 16 - 1)).astype(np.uint16)

    # Weighted histogram of the distinct values. np.histogram closes its last
    # bin on the right, so the maximum value is counted. The previous
    # digitize/bincount version put it in an overflow bin and then dropped it.
    unique_values, counts = fastremap.unique(X, return_counts=True)
    binned_counts, bin_edges = np.histogram(unique_values, bins=nbins, weights=counts)
    binned_counts = binned_counts.astype(np.float64)
    bin_start = bin_edges[:-1]

    sel = binned_counts > 0
    counts_sel = binned_counts[sel]
    unique_sel = bin_start[sel]
    x = np.arange(len(counts_sel))
    y = np.log(counts_sel + eps) if log else counts_sel

    # "Symmetric" KDE density.
    # NOTE(review): a Gaussian KDE fit on mirrored points and evaluated at the
    # same mirrored points gives the same values as the unmirrored KDE. The
    # average below should therefore equal ``density`` up to rounding;
    # kept for parity with the original implementation.
    points = np.vstack([x, y])
    density = gaussian_kde(points, bw_method=bw_method)(points)
    mirrored = np.vstack([-x, y])
    inverted_density = gaussian_kde(mirrored, bw_method=bw_method)(mirrored)
    d = rescale((density + inverted_density) / 2)

    if not isinstance(density_quantile, (list, tuple)):
        density_quantile = (density_quantile, density_quantile)

    if density_cutoff is None:
        density_cutoff = np.quantile(d, density_quantile)  # pragma: no cover
    elif not isinstance(density_cutoff, (list, tuple)):
        density_cutoff = (density_cutoff, density_cutoff)

    imin = np.argwhere(d > density_cutoff[0])[0][0]
    imax = np.argwhere(d > density_cutoff[1])[-1][0]
    vmin, vmax = unique_sel[imin], unique_sel[imax]

    if vmax > vmin:
        # float32 rather than float16: 1/(vmax - vmin) falls into float16's
        # subnormal range for ranges above 16384, losing precision.
        scale_factor = np.float32(1.0 / (vmax - vmin))
        # NOTE(review): X is not shifted by vmin, so values below vmin are not
        # mapped to 0. Confirm whether ``(X - vmin) * scale_factor`` was intended.
        r = X * scale_factor
        r[r > 1] = 1
    else:
        r = X

    if debug:
        return r, x, y, d, imin, imax, vmin, vmax
    return r
226
+
227
+
228
def localnormalize(im, sigma1=2, sigma2=20):
    """Local mean/std normalization via Gaussian blurs.

    Steps:

    1. Scale the image with :func:`normalize99`.
    2. Subtract its *sigma1* Gaussian blur (local mean removal).
    3. Divide by the square root of the *sigma2*-blurred squared residual
       (local standard deviation).
    4. Pass the quotient through :func:`normalize99` again.

    Works on numpy and torch (auto-dispatched). Torch inputs use
    ``torchvision.transforms.functional.gaussian_blur``; numpy inputs use
    ``scipy.ndimage.gaussian_filter``.

    Parameters
    ----------
    im : array or tensor
        Input image.
    sigma1 : float
        Blur width for the local mean.
    sigma2 : float
        Blur width for the local standard deviation.
    """
    # Added to the denominator (not to the quotient, as before) so flat
    # regions with zero local std do not divide by zero.
    eps = 1e-8

    module = get_module(im)
    if module == torch:
        import torchvision.transforms.functional as TF

        def _odd_kernel(sigma):
            # About 6 sigma wide, bumped to odd (torchvision requires odd sizes).
            size = round(sigma * 6)
            return size + 1 if size % 2 == 0 else size

        im = normalize99(im)
        num = im - TF.gaussian_blur(im, _odd_kernel(sigma1), sigma1)
        den = torch.sqrt(TF.gaussian_blur(num * num, _odd_kernel(sigma2), sigma2))
        return normalize99(num / (den + eps))

    from scipy.ndimage import gaussian_filter

    im = normalize99(im)
    num = im - gaussian_filter(im, sigma=sigma1)
    den = np.sqrt(gaussian_filter(num * num, sigma=sigma2))
    return normalize99(num / (den + eps))


# Backward-compat alias
localnormalize_GPU = localnormalize
260
+
261
+
262
def pnormalize(Y, p_min=-1, p_max=10):
    """Rescale *Y* to [0, 1] using power-sum soft minimum/maximum estimates.

    The lower bound is ``(sum |Y|**p_min) ** (1 / p_min)`` and the upper bound
    is ``(sum |Y|**p_max) ** (1 / p_max)``. A negative *p_min* pulls the first
    toward the smallest magnitude; a large positive *p_max* pulls the second
    toward the largest. The rescaled values are clipped to [0, 1]. Works on
    numpy and torch (auto-dispatched).
    """
    module = get_module(Y)
    magnitude = module.abs(Y * 1.0)

    def _power_sum_root(p):
        return (magnitude ** p).sum() ** (1.0 / p)

    low = _power_sum_root(p_min)
    high = _power_sum_root(p_max)
    return module.clip(safe_divide(Y - low, high - low), 0, 1)
273
+
274
+
275
def normalize_image(
    im,
    mask,
    target=0.5,
    foreground=False,
    iterations=1,
    scale=1,
    channel_axis=0,
    per_channel=True,
):
    """Mask-aware gamma normalization that pushes the masked-region mean to *target*.

    Steps:

    1. Multiply the image by *scale* and min-max normalize it to [0, 1].
    2. Move it to channel-last layout.
    3. Raise it to the power ``log(target) / log(m)``, where ``m`` is the
       mean of the selected mask region (one mean per channel when
       *per_channel*).

    Numpy implementation; uses numexpr when it is importable.

    Parameters
    ----------
    im : ndarray
        Image. A 2-D array is treated as single-channel.
    mask : ndarray
        Mask matching the spatial shape of *im*; broadcast across channels.
        List masks skip the broadcast (TODO confirm expected list format).
    target : float
        Desired mean of the selected region after normalization.
    foreground : bool
        Select ``mask > 0`` when True, otherwise ``mask == 0``.
    iterations : int
        Binary-erosion iterations applied to the selection before averaging,
        to avoid edge contamination.
    scale : float
        Multiplier applied before min-max normalization. A positive scale
        cancels out in that step.
    channel_axis : int
        Channel axis of *im* when ``im.ndim > 2``; restored on output.
    per_channel : bool
        Compute one mean per channel over all spatial axes, instead of a
        single global mean.

    Returns
    -------
    ndarray
        Normalized image in the input layout, squeezed.
    """
    from scipy.ndimage import binary_erosion
    try:
        import numexpr as ne
    except Exception:  # pragma: no cover
        ne = None

    im = im.astype("float32") * scale
    im_min = im.min()
    im_max = im.max()
    if ne is None:
        im = (im - im_min) / (im_max - im_min)
    else:
        ne.evaluate("(im - im_min) / (im_max - im_min)", out=im)

    # Work in channel-last layout; 2-D inputs get a singleton channel axis.
    if im.ndim > 2:
        im = np.moveaxis(im, channel_axis, -1)
    else:
        im = np.expand_dims(im, axis=-1)

    if not isinstance(mask, list):
        mask = np.expand_dims(mask, axis=-1)
        mask = np.broadcast_to(mask, im.shape)

    bin0 = mask > 0 if foreground else mask == 0
    if iterations > 0:
        structure = np.ones((3,) * (im.ndim - 1) + (1,))
        # NOTE(review): this zeroes the whole middle plane along the first
        # spatial axis, including the centre element. Confirm this is intended.
        structure[1, ...] = 0
        bin0 = binary_erosion(bin0, structure=structure, iterations=iterations)

    masked_im = im.copy()
    masked_im[~bin0] = np.nan
    # Reduce over every spatial axis (all but the trailing channel axis).
    # The former axis=(0, 1) gave 3-D stacks one mean per
    # (last spatial index, channel) instead of one per channel.
    spatial_axes = tuple(range(im.ndim - 1))
    source_target = np.nanmean(masked_im, axis=spatial_axes if per_channel else None)
    source_target = source_target.astype("float32")
    target = np.array(target).astype("float32")
    if ne is None:
        im = im ** (np.log(target) / np.log(source_target))
    else:
        ne.evaluate("im ** (log(target) / log(source_target))", out=im)
    return np.moveaxis(im, -1, channel_axis).squeeze()
329
+
330
+
331
def adjust_contrast_masked(
    img,
    masks,
    r_target=1.10,
    plo=0.01,
    phi=99.99,
    clip_output=True,
):
    """Mask-aware percentile clip plus gamma to approach a target fg/bg ratio.

    The intensity window ``[lo, hi]`` runs from the *plo* percentile of the
    background (``masks == 0``) to the *phi* percentile of the foreground
    (nonzero ``masks``). If that window is invalid, it falls back to the image
    min/max. The windowed image ``j``, clipped to [0, 1], is then raised to
    the gamma ``log(r_target) / log(r)``, clamped to [0.2, 5.0], where ``r``
    is the current foreground/background mean ratio. Numpy only.

    Parameters
    ----------
    img : array_like
        Input image; converted to float32.
    masks : array_like
        Mask/label image with the same shape as *img*; nonzero marks foreground.
    r_target : float
        Desired foreground/background mean ratio.
    plo, phi : float
        Percentiles (0-100) for the background low and foreground high limits.
    clip_output : bool
        Clip the gamma-corrected result to [0, 1].

    Returns
    -------
    adjusted : ndarray of float32
        Adjusted image. If either region is empty or the window is
        degenerate, this is an unscaled copy of the float32 input.
    gamma : float
        Applied exponent (1.0 when no gamma was applied).
    (lo, hi) : tuple of float
        Intensity window that was used.
    """
    x = np.asarray(img, dtype=np.float32)
    fg = np.asarray(masks).astype(bool)
    bg = ~fg

    if fg.sum() == 0 or bg.sum() == 0:
        return x.copy(), 1.0, (float(np.min(x)), float(np.max(x)))

    a = np.percentile(x[bg], plo)
    b = np.percentile(x[fg], phi)
    if not np.isfinite(a) or not np.isfinite(b) or b <= a:
        a = float(np.min(x))
        b = float(np.max(x))

    if not np.isfinite(b - a) or b <= a:
        return x.copy(), 1.0, (float(a), float(b))

    j = np.clip((x - a) / (b - a), 0.0, 1.0)

    m_fg = float(j[fg].mean())
    m_bg = float(j[bg].mean() + 1e-12)  # epsilon guards an all-zero background
    r = m_fg / m_bg

    # The current ratio and the target lie on opposite sides of 1: return the
    # windowed image without gamma. Cast the image and limits the same way as
    # the other return paths; this branch used to return raw numpy scalars.
    if (r >= 1.0 and r_target < 1.0) or (r <= 1.0 and r_target > 1.0):  # pragma: no cover
        return j.astype(np.float32), 1.0, (float(a), float(b))

    if abs(np.log(max(r, 1e-12))) < 1e-8 or abs((r - r_target) / max(r_target, 1e-12)) < 1e-3:  # pragma: no cover
        y = j
        gamma = 1.0
    else:
        gamma = float(np.log(max(r_target, 1e-12)) / np.log(max(r, 1e-12)))
        gamma = float(np.clip(gamma, 0.2, 5.0))
        y = np.power(j, gamma)

    if clip_output:
        y = np.clip(y, 0.0, 1.0)

    return y.astype(np.float32), gamma, (float(a), float(b))
382
+
383
+
384
def gamma_normalize(
    im,
    mask,
    target=1.0,
    scale=1.0,
    foreground=True,
    iterations=0,
    per_channel=True,
    channel_axis=-1,
):
    """Torch (GPU-accelerated) variant of :func:`normalize_image`.

    The image is passed through ``rescale`` and multiplied by *scale*, then
    moved to channel-last layout and transferred to
    ``ocdkit.utils.gpu.torch_GPU``. It is raised to the power
    ``log(target) / log(m)``, where ``m`` is the NaN-ignoring mean of the
    selected mask region (one mean per channel when *per_channel*).

    Parameters
    ----------
    im : ndarray
        Image. A 2-D array is treated as single-channel.
    mask : ndarray
        Mask matching the spatial shape of *im*; stacked across channels.
    target : float or array_like
        Desired region mean. With the default 1.0 the exponent is 0, so every
        value becomes 1 (NOTE(review): confirm the intended default).
    scale : float
        Multiplier applied after ``rescale``.
    foreground : bool
        Select ``mask > 0`` when True, otherwise ``mask == 0``.
    iterations : int
        Binary-erosion iterations applied to the selection (on the CPU).
    per_channel : bool
        One mean per channel over all spatial axes, instead of a global mean.
    channel_axis : int
        Channel axis of *im*; restored on output.

    Returns
    -------
    ndarray
        Normalized image in the input layout, squeezed, on the CPU.
    """
    from scipy.ndimage import binary_erosion

    device = torch_GPU
    im = rescale(im) * scale
    if im.ndim > 2:
        im = np.moveaxis(im, channel_axis, -1)
    else:
        im = np.expand_dims(im, axis=-1)

    if not isinstance(mask, list):
        mask = np.stack([mask] * im.shape[-1], axis=-1)

    im = torch.from_numpy(im).float().to(device)
    mask = torch.from_numpy(mask).float().to(device)

    bin0 = mask > 0 if foreground else mask == 0
    if iterations > 0:
        # Erosion runs in scipy, so the structuring element is built in numpy.
        structure = np.ones((3,) * (im.ndim - 1) + (1,))
        structure[1, ...] = 0
        bin0 = torch.from_numpy(
            binary_erosion(bin0.cpu().numpy(), structure=structure, iterations=iterations)
        ).to(device)

    masked_im = im.masked_fill(~bin0, float("nan"))
    # Reduce over all spatial dims (everything except the trailing channel).
    spatial_dims = tuple(range(im.ndim - 1))
    source_target = torch.nanmean(masked_im, dim=spatial_dims if per_channel else None)
    # torch.log only accepts tensors; the default target is a Python float.
    target_t = torch.as_tensor(target, dtype=im.dtype, device=im.device)
    im **= torch.log(target_t) / torch.log(source_target)

    # Undo the channel-last move, mirroring np.moveaxis in normalize_image.
    # The former permute duplicated a dim for channel_axis=-1 and never moved
    # the channel axis back.
    return torch.movedim(im, -1, channel_axis).squeeze().cpu().numpy()
ocdkit/array/ops.py ADDED
@@ -0,0 +1,134 @@
1
+ """Miscellaneous array utilities — divergence, noise, search, metadata."""
2
+
3
+ from .imports import *
4
+ from scipy.ndimage import convolve1d, gaussian_filter
5
+
6
+ from .convert import get_module
7
+
8
+
9
def divergence(f):
    """Divergence of a vector field, dispatched on backend.

    Parameters
    ----------
    f : array or tensor
        Numpy: shape ``(D, *spatial)``, an unbatched D-vector field.
        Torch: shape ``(B, D, *spatial)``, a batched D-vector field.

    Returns
    -------
    div : array or tensor
        Numpy: shape ``(*spatial,)``.
        Torch: shape ``(B, *spatial)``.

    Notes
    -----
    Returns zeros when a spatial dimension has size < 2 (gradient undefined).
    Each component is differentiated only along its own axis with one
    ``np.gradient`` / ``torch.gradient`` call, rather than along every axis.
    """
    if get_module(f) == np:
        n_comp = len(f)
        too_small = any(f.shape[1 + axis] < 2 for axis in range(n_comp))
        if too_small:
            return np.zeros_like(f[0])
        partials = [np.gradient(f[axis], axis=axis) for axis in range(n_comp)]
        return np.add.reduce(partials)

    # Torch path: batched (B, D, *spatial); component c varies along dim c + 1.
    batch, n_comp, *spatial = f.shape
    out = torch.zeros((batch, *spatial), dtype=f.dtype, device=f.device)
    if any(size < 2 for size in spatial):
        return out
    for comp in range(n_comp):
        out += torch.gradient(f[:, comp], dim=comp + 1)[0]
    return out
45
+
46
+
47
def searchsorted(tensor, value):
    """Return how many elements of *tensor* are strictly less than *value*.

    For a sorted 1-D *tensor* this is the left insertion index that keeps the
    order. It works for any input where ``tensor < value`` produces a
    summable boolean result (e.g. numpy arrays and torch tensors). No
    backend dispatch or sortedness check is performed.
    """
    below = tensor < value
    return below.sum()
54
+
55
+
56
def enumerate_nested(*lists, parent_indices=None):
    """Traverse matching nested lists and yield indices with corresponding values.

    Parameters
    ----------
    lists : list(s)
        One or more nested lists to traverse. All must share the same structure.
        A level is descended into when the first element of every list is
        itself a list.
    parent_indices : list, optional
        Indices accumulated by recursive calls (internal).

    Yields
    ------
    tuple
        ``(indices, values...)``: the index path plus one value from each
        input. Empty lists (at any level) yield nothing.
    """
    if parent_indices is None:
        parent_indices = []

    # Check the length before peeking at lst[0]. An empty (sub)list used to
    # raise IndexError; it is now a leaf level where zip() yields nothing.
    if all(len(lst) > 0 and isinstance(lst[0], list) for lst in lists):
        for i, sublists in enumerate(zip(*lists)):
            yield from enumerate_nested(*sublists, parent_indices=parent_indices + [i])
    else:
        for i, values in enumerate(zip(*lists)):
            yield parent_indices + [i], *values
82
+
83
+
84
def unique_nonzero(arr):
    """Return the sorted unique values of *arr*, excluding 0.

    The unique pass uses ``fastremap.unique``, imported lazily on each call.
    """
    import fastremap

    values = fastremap.unique(arr)
    nonzero_mask = values != 0
    return values[nonzero_mask]
89
+
90
+
91
def get_size(var, unit='GB'):
    """Return ``var.nbytes`` expressed in *unit* ('B', 'KB', 'MB', 'GB').

    Units are powers of 1024. Any object with an ``nbytes`` attribute works
    (numpy, dask, torch). An unknown *unit* raises KeyError.
    """
    exponent = {'B': 0, 'KB': 1, 'MB': 2, 'GB': 3}[unit]
    divisor = 1024 ** exponent
    return var.nbytes / divisor
98
+
99
+
100
def random_int(N, M=None, seed=None):
    """Generate random integers in ``[0, N)`` using numpy's global RNG.

    The global RNG is reseeded on every call.

    Parameters
    ----------
    N : int
        Upper bound (exclusive).
    M : int, optional
        Number of integers to generate. Returns a scalar if None.
    seed : int, optional
        RNG seed. If None, a random seed is drawn, printed to stdout, and
        then applied, so the printed value reproduces the output. Previously
        the drawn seed was printed but never applied.
    """
    if seed is None:
        # int64 keeps the 2**32 - 1 bound valid where the default int is 32-bit.
        seed = int(np.random.randint(0, 2 ** 32 - 1, dtype=np.int64))
        print(f'Seed: {seed}')
    np.random.seed(seed)
    return np.random.randint(0, N, M)
118
+
119
+
120
def moving_average(x, w):
    """Moving average of width *w* along axis 0.

    Implemented as ``scipy.ndimage.convolve1d`` with a uniform kernel. The
    output therefore has the same length as *x*, and borders use
    convolve1d's default ``mode='reflect'``.

    Integer and boolean inputs are promoted to float64 first. convolve1d
    writes into an array of the input dtype, which would otherwise quantize
    the averages.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(np.float64)
    return convolve1d(x, np.ones(w) / w, axis=0)
123
+
124
+
125
def add_poisson_noise(image):
    """Sample Poisson counts with mean *image* and clip them to [0, 1].

    Draws from numpy's global RNG. The samples are non-negative integer
    counts, so after clipping the result contains only the values 0 and 1.
    """
    samples = np.random.poisson(image)
    clipped = np.clip(samples, 0, 1)
    return clipped
129
+
130
+
131
def correct_illumination(img, sigma=5):
    """Flatten illumination by subtracting a Gaussian-blurred background.

    The background estimate is ``gaussian_filter(img, sigma)``. The result is
    ``(img - background) / std(background)``.

    Integer and boolean inputs are promoted to float64 first. gaussian_filter
    returns the input dtype, so an unsigned-integer image would otherwise get
    a quantized background and wrap around on subtraction.
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.inexact):
        img = img.astype(np.float64)
    background = gaussian_filter(img, sigma=sigma)
    return (img - background) / np.std(background)