senoquant 1.0.0b3__py3-none-any.whl → 1.0.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,574 @@
+"""RMP spot detector implementation."""
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Iterable
+
+import numpy as np
+from scipy import ndimage as ndi
+from skimage.filters import laplace, threshold_otsu
+from skimage.morphology import local_maxima
+from skimage.segmentation import watershed
+
+try:
+    import torch
+    import torch.nn.functional as F
+except ImportError:  # pragma: no cover - optional dependency
+    torch = None  # type: ignore[assignment]
+    F = None  # type: ignore[assignment]
+
+from ..base import SenoQuantSpotDetector
+from senoquant.utils import layer_data_asarray
+
+try:
+    import dask.array as da
+except ImportError:  # pragma: no cover - optional dependency
+    da = None  # type: ignore[assignment]
+
+try:  # pragma: no cover - optional dependency
+    from dask.distributed import Client, LocalCluster
+except ImportError:  # pragma: no cover - optional dependency
+    Client = None  # type: ignore[assignment]
+    LocalCluster = None  # type: ignore[assignment]
+
+
+Array2D = np.ndarray
+KernelShape = tuple[int, int]
+EPS = 1e-6
+NOISE_FLOOR_SIGMA = 1.5
+MIN_SCALE_SIGMA = 5.0
+SIGNAL_SCALE_QUANTILE = 99.9
+USE_LAPLACE_FOR_PEAKS = False
+
+
+def _ensure_torch_available() -> None:
+    """Ensure torch is available for RMP processing."""
+    if torch is None or F is None:  # pragma: no cover - import guard
+        raise ImportError("torch is required for the RMP detector.")
+
+
+def _torch_device() -> "torch.device":
+    """Return the best available torch device (CUDA, MPS, then CPU)."""
+    _ensure_torch_available()
+    assert torch is not None
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    mps_backend = getattr(torch.backends, "mps", None)
+    if mps_backend is not None and mps_backend.is_available():
+        return torch.device("mps")
+    return torch.device("cpu")
+
+
+def _to_image_tensor(image: np.ndarray, *, device: "torch.device") -> "torch.Tensor":
+    """Convert a 2D image array to a [1,1,H,W] torch tensor."""
+    _ensure_torch_available()
+    assert torch is not None
+    tensor = torch.as_tensor(image, dtype=torch.float32, device=device)
+    if tensor.ndim != 2:
+        raise ValueError("Expected a 2D image for tensor conversion.")
+    return tensor.unsqueeze(0).unsqueeze(0)
+
+
+def _rotate_tensor(image: "torch.Tensor", angle: float) -> "torch.Tensor":
+    """Rotate a [1,1,H,W] tensor with reflection padding."""
+    _ensure_torch_available()
+    assert torch is not None
+    assert F is not None
+    if image.ndim != 4:
+        raise ValueError("Expected a [N,C,H,W] tensor for rotation.")
+
+    height = float(image.shape[-2])
+    width = float(image.shape[-1])
+    hw_ratio = height / width if width > 0 else 1.0
+    wh_ratio = width / height if height > 0 else 1.0
+
+    radians = np.deg2rad(float(angle))
+    cos_v = float(np.cos(radians))
+    sin_v = float(np.sin(radians))
+    # affine_grid operates in normalized coordinates; non-square images need
+    # aspect-ratio correction on the off-diagonal terms.
+    theta = torch.tensor(
+        [[[cos_v, -sin_v * hw_ratio, 0.0], [sin_v * wh_ratio, cos_v, 0.0]]],
+        dtype=image.dtype,
+        device=image.device,
+    )
+    grid = F.affine_grid(theta, tuple(image.shape), align_corners=False)
+    return F.grid_sample(
+        image,
+        grid,
+        mode="bilinear",
+        padding_mode="reflection",
+        align_corners=False,
+    )
+
+
+def _grayscale_opening_tensor(
+    image: "torch.Tensor",
+    kernel_shape: KernelShape,
+) -> "torch.Tensor":
+    """Apply grayscale opening (erosion then dilation) with a rectangular kernel."""
+    _ensure_torch_available()
+    assert F is not None
+    img_h = int(image.shape[-2])
+    img_w = int(image.shape[-1])
+    ky = min(max(1, int(kernel_shape[0])), max(1, img_h))
+    kx = min(max(1, int(kernel_shape[1])), max(1, img_w))
+    pad_y = ky // 2
+    pad_x = kx // 2
+    pad = (pad_x, pad_x, pad_y, pad_y)
+
+    # Erosion via pooling uses the min-over-window identity:
+    # min(x) == -max(-x). Missing the inner negation flips morphology behavior.
+    eroded = -F.max_pool2d(
+        F.pad(-image, pad, mode="reflect"),
+        kernel_size=(ky, kx),
+        stride=1,
+    )
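+    # Dilation of the eroded image is a plain max-pool over the same window,
+    # which completes the grayscale opening.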
+    opened = F.max_pool2d(
+        F.pad(eroded, pad, mode="reflect"),
+        kernel_size=(ky, kx),
+        stride=1,
+    )
+    return opened
+
+
+def _kernel_shape(footprint: KernelShape | np.ndarray) -> KernelShape:
+    """Return kernel shape from either a tuple footprint or array."""
+    if isinstance(footprint, tuple):
+        return max(1, int(footprint[0])), max(1, int(footprint[1]))
+    arr = np.asarray(footprint)
+    if arr.ndim != 2:
+        raise ValueError("Structuring element must be 2D.")
+    return max(1, int(arr.shape[0])), max(1, int(arr.shape[1]))
+
+
+def _normalize_image(image: np.ndarray) -> np.ndarray:
+    """Normalize an image to float32 in [0, 1]."""
+    device = _torch_device()
+    data = np.asarray(image, dtype=np.float32)
+    _ensure_torch_available()
+    assert torch is not None
+    tensor = torch.as_tensor(data, dtype=torch.float32, device=device)
+    min_val = tensor.amin()
+    max_val = tensor.amax()
+    if bool(max_val <= min_val):
+        return np.zeros_like(data, dtype=np.float32)
+    normalized = (tensor - min_val) / (max_val - min_val)
+    normalized = normalized.clamp(0.0, 1.0)
+    return normalized.detach().cpu().numpy().astype(np.float32, copy=False)
+
+
+def _clamp_threshold(value: float) -> float:
+    """Clamp threshold to the inclusive [0.0, 1.0] range."""
+    return float(np.clip(value, 0.0, 1.0))
+
+
+def _normalize_top_hat_unit(image: np.ndarray) -> np.ndarray:
+    """Robust normalization for top-hat output."""
+    data = np.asarray(image, dtype=np.float32)
+    finite_mask = np.isfinite(data)
+    if not np.any(finite_mask):
+        return np.zeros_like(data, dtype=np.float32)
+
+    valid = data[finite_mask]
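+    # Robust statistics: the median estimates the background level and the
+    # median absolute deviation (scaled by 1.4826) approximates the noise
+    # standard deviation under a Gaussian model.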
+    background = float(np.nanmedian(valid))
+    sigma = 1.4826 * float(np.nanmedian(np.abs(valid - background)))
+
+    if (not np.isfinite(sigma)) or sigma <= EPS:
+        sigma = float(np.nanstd(valid))
+    if (not np.isfinite(sigma)) or sigma <= EPS:
+        return np.zeros_like(data, dtype=np.float32)
+
+    noise_floor = background + (NOISE_FLOOR_SIGMA * sigma)
+    residual = np.clip(data - noise_floor, 0.0, None)
+    residual = np.where(finite_mask, residual, 0.0)
+
+    positive = residual[residual > 0.0]
+    if positive.size == 0:
+        return np.zeros_like(data, dtype=np.float32)
+    high = float(np.nanpercentile(positive, SIGNAL_SCALE_QUANTILE))
+    if (not np.isfinite(high)) or high <= EPS:
+        high = float(np.nanmax(positive))
+    if (not np.isfinite(high)) or high <= EPS:
+        return np.zeros_like(data, dtype=np.float32)
+
+    scale = max(high, MIN_SCALE_SIGMA * sigma, EPS)
+    normalized = np.clip(residual / scale, 0.0, 1.0)
+    return normalized.astype(np.float32, copy=False)
+
+
+def _markers_from_local_maxima(
+    enhanced: np.ndarray,
+    threshold: float,
+    use_laplace: bool = USE_LAPLACE_FOR_PEAKS,
+) -> np.ndarray:
+    """Build marker labels from local maxima and thresholding."""
+    connectivity = max(1, min(2, enhanced.ndim))
+    response = (
+        laplace(enhanced.astype(np.float32, copy=False))
+        if use_laplace
+        else np.asarray(enhanced, dtype=np.float32)
+    )
+    mask = local_maxima(response, connectivity=connectivity)
+    mask = mask & (response > threshold)
+
+    markers = np.zeros(enhanced.shape, dtype=np.int32)
+    coords = np.argwhere(mask)
+    if coords.size == 0:
+        return markers
+
+    max_indices = np.asarray(enhanced.shape) - 1
+    coords = np.clip(coords, 0, max_indices)
+    markers[tuple(coords.T)] = 1
+
+    structure = ndi.generate_binary_structure(enhanced.ndim, 1)
+    marker_labels, _num = ndi.label(markers > 0, structure=structure)
+    return marker_labels.astype(np.int32, copy=False)
+
+
+def _segment_from_markers(
+    enhanced: np.ndarray,
+    markers: np.ndarray,
+    threshold: float,
+) -> np.ndarray:
+    """Run watershed from local-maxima markers inside threshold foreground."""
+    foreground = enhanced > threshold
+    if not np.any(foreground):
+        return np.zeros_like(enhanced, dtype=np.int32)
+
+    seeded_markers = markers * foreground.astype(np.int32, copy=False)
+    if not np.any(seeded_markers > 0):
+        return np.zeros_like(enhanced, dtype=np.int32)
+
+    labels = watershed(
+        -enhanced.astype(np.float32, copy=False),
+        markers=seeded_markers,
+        mask=foreground,
+    )
+    return labels.astype(np.int32, copy=False)
+
+def _pad_tensor_for_rotation(
+    image: "torch.Tensor",
+) -> tuple["torch.Tensor", tuple[int, int]]:
+    """Pad a [1,1,H,W] tensor to preserve content after rotations."""
+    nrows = int(image.shape[-2])
+    ncols = int(image.shape[-1])
+    diagonal = int(np.ceil(np.sqrt(nrows**2 + ncols**2)))
+    rows_to_pad = int(np.ceil((diagonal - nrows) / 2))
+    cols_to_pad = int(np.ceil((diagonal - ncols) / 2))
+    assert F is not None
+    padded = F.pad(
+        image,
+        (cols_to_pad, cols_to_pad, rows_to_pad, rows_to_pad),
+        mode="reflect",
+    )
+    return padded, (rows_to_pad, cols_to_pad)
+
+def _rmp_opening(
+    input_image: Array2D,
+    structuring_element: KernelShape | Array2D,
+    rotation_angles: Iterable[int],
+) -> Array2D:
+    """Perform the RMP opening on an image."""
+    device = _torch_device()
+    tensor = _to_image_tensor(np.asarray(input_image, dtype=np.float32), device=device)
+    padded, (newy, newx) = _pad_tensor_for_rotation(tensor)
+    kernel_shape = _kernel_shape(structuring_element)
+
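+    # Open with the line-shaped kernel at every rotation, rotate each result
+    # back, and keep the pixel-wise maximum: this union of openings preserves
+    # elongated structures at any orientation while suppressing compact spots,
+    # which the subsequent top-hat then isolates.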
+    rotated_images = [_rotate_tensor(padded, angle) for angle in rotation_angles]
+    opened_images = [
+        _grayscale_opening_tensor(image, kernel_shape) for image in rotated_images
+    ]
+    rotated_back = [
+        _rotate_tensor(image, -angle)
+        for image, angle in zip(opened_images, rotation_angles)
+    ]
+    stacked = torch.stack(rotated_back, dim=0)
+    union_image = stacked.max(dim=0).values
+    cropped = union_image[
+        ...,
+        newy : newy + input_image.shape[0],
+        newx : newx + input_image.shape[1],
+    ]
+    return cropped.squeeze(0).squeeze(0).detach().cpu().numpy().astype(np.float32, copy=False)
+
+
+def _rmp_top_hat(
+    input_image: Array2D,
+    structuring_element: KernelShape | Array2D,
+    rotation_angles: Iterable[int],
+) -> Array2D:
+    """Return the top-hat (background subtracted) image."""
+    opened_image = _rmp_opening(input_image, structuring_element, rotation_angles)
+    return input_image - opened_image
+
+
+def _compute_top_hat(input_image: Array2D, config: "RMPSettings") -> Array2D:
+    """Compute the RMP top-hat response for a 2D image."""
+    denoising_se: KernelShape = (1, config.denoising_se_length)
+    extraction_se: KernelShape = (1, config.extraction_se_length)
+    rotation_angles = tuple(range(0, 180, config.angle_spacing))
+
+    working = (
+        _rmp_opening(input_image, denoising_se, rotation_angles)
+        if config.enable_denoising
+        else input_image
+    )
+    return _rmp_top_hat(working, extraction_se, rotation_angles)
+
+
+def _ensure_dask_available() -> None:
+    """Ensure dask is installed for tiled execution."""
+    if da is None:  # pragma: no cover - import guard
+        raise ImportError("dask is required for distributed spot detection.")
+
+
+def _ensure_distributed_available() -> None:
+    """Ensure dask.distributed is installed for distributed execution."""
+    if Client is None or LocalCluster is None:  # pragma: no cover - import guard
+        raise ImportError("dask.distributed is required for distributed execution.")
+
+
+def _dask_available() -> bool:
+    """Return True when dask is available."""
+    return da is not None
+
+
+def _distributed_available() -> bool:
+    """Return True when dask.distributed is available."""
+    return Client is not None and LocalCluster is not None and da is not None
+
+
+def _recommended_overlap(config: "RMPSettings") -> int:
+    """Derive a suitable overlap from structuring-element sizes."""
+    lengths = [config.extraction_se_length]
+    if config.enable_denoising:
+        lengths.append(config.denoising_se_length)
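+    # Twice the longest structuring element so each tile carries enough halo
+    # context for the line openings near its borders.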
+    return max(1, max(lengths) * 2)
+
+
+@contextmanager
+def _cluster_client():
+    """Yield a connected Dask client backed by a local cluster."""
+    _ensure_distributed_available()
+    with LocalCluster() as cluster:
+        with Client(cluster) as client:
+            yield client
+
+
+def _rmp_top_hat_block(block: np.ndarray, config: "RMPSettings") -> np.ndarray:
+    """Return background-subtracted tile via the RMP top-hat pipeline."""
+    denoising_se: KernelShape = (1, config.denoising_se_length)
+    extraction_se: KernelShape = (1, config.extraction_se_length)
+    rotation_angles = tuple(range(0, 180, config.angle_spacing))
+
+    working = (
+        _rmp_opening(block, denoising_se, rotation_angles)
+        if config.enable_denoising
+        else block
+    )
+    top_hat = working - _rmp_opening(working, extraction_se, rotation_angles)
+    return np.asarray(top_hat, dtype=np.float32)
+
+
+def _compute_top_hat_2d(
+    image_2d: np.ndarray,
+    config: "RMPSettings",
+    *,
+    use_tiled: bool,
+    distributed: bool,
+    client: "Client | None" = None,
+) -> np.ndarray:
+    """Compute a top-hat image for one 2D plane."""
+    if use_tiled:
+        return _rmp_top_hat_tiled(
+            image_2d,
+            config=config,
+            distributed=distributed,
+            client=client,
+        )
+    return _compute_top_hat(image_2d, config)
+
+
+def _compute_top_hat_nd(
+    image: np.ndarray,
+    config: "RMPSettings",
+    *,
+    use_tiled: bool,
+    use_distributed: bool,
+) -> np.ndarray:
+    """Compute top-hat for 2D images or slice-wise for 3D stacks."""
+    if image.ndim == 2:
+        return _compute_top_hat_2d(
+            image,
+            config,
+            use_tiled=use_tiled,
+            distributed=use_distributed,
+        )
+
+    top_hat_stack = np.zeros_like(image, dtype=np.float32)
+    if use_tiled and use_distributed:
+        with _cluster_client() as client:
+            for z in range(image.shape[0]):
+                top_hat_stack[z] = _compute_top_hat_2d(
+                    image[z],
+                    config,
+                    use_tiled=True,
+                    distributed=True,
+                    client=client,
+                )
+        return top_hat_stack
+
+    for z in range(image.shape[0]):
+        top_hat_stack[z] = _compute_top_hat_2d(
+            image[z],
+            config,
+            use_tiled=use_tiled,
+            distributed=False,
+        )
+    return top_hat_stack
+
+
+def _postprocess_top_hat(
+    top_hat: np.ndarray,
+    config: "RMPSettings",
+) -> tuple[np.ndarray, np.ndarray]:
+    """Apply normalization, thresholding, marker extraction, and watershed."""
+    top_hat_normalized = _normalize_top_hat_unit(top_hat)
+    threshold = (
+        _clamp_threshold(float(threshold_otsu(top_hat_normalized)))
+        if config.auto_threshold
+        else config.manual_threshold
+    )
+    markers = _markers_from_local_maxima(
+        top_hat_normalized,
+        threshold,
+        use_laplace=USE_LAPLACE_FOR_PEAKS,
+    )
+    labels = _segment_from_markers(
+        top_hat_normalized,
+        markers,
+        threshold,
+    )
+    return labels, top_hat_normalized
+
+
+def _rmp_top_hat_tiled(
+    image: np.ndarray,
+    config: "RMPSettings",
+    chunk_size: tuple[int, int] = (1024, 1024),
+    overlap: int | None = None,
+    distributed: bool = False,
+    client: "Client | None" = None,
+) -> np.ndarray:
+    """Return the RMP top-hat image using tiled execution."""
+    _ensure_dask_available()
+
+    effective_overlap = _recommended_overlap(config) if overlap is None else overlap
+
+    def block_fn(block, block_info=None):
+        return _rmp_top_hat_block(block, config)
+
+    arr = da.from_array(image.astype(np.float32, copy=False), chunks=chunk_size)
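+    # map_overlap adds `effective_overlap` pixels of reflected context around
+    # each chunk before calling block_fn and trims it again afterwards,
+    # avoiding seams at tile boundaries.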
+    result = arr.map_overlap(
+        block_fn,
+        depth=(effective_overlap, effective_overlap),
+        boundary="reflect",
+        dtype=np.float32,
+        trim=True,
+    )
+
+    if distributed:
+        _ensure_distributed_available()
+        if client is None:
+            with _cluster_client() as temp_client:
+                return temp_client.compute(result).result()
+        return client.compute(result).result()
+
+    return result.compute()
+
+
+@dataclass(slots=True)
+class RMPSettings:
+    """Configuration for the RMP detector."""
+
+    denoising_se_length: int = 2
+    extraction_se_length: int = 10
+    angle_spacing: int = 5
+    auto_threshold: bool = True
+    manual_threshold: float = 0.05
+    enable_denoising: bool = True
+
+
+class RMPDetector(SenoQuantSpotDetector):
+    """RMP spot detector implementation."""
+
+    def __init__(self, models_root=None) -> None:
+        super().__init__("rmp", models_root=models_root)
+
+    def run(self, **kwargs) -> dict:
+        """Run the RMP detector and return instance labels.
+
+        Parameters
+        ----------
+        **kwargs
+            layer : napari.layers.Image or None
+                Image layer used for spot detection.
+            settings : dict
+                Detector settings keyed by the details.json schema.
+
+        Returns
+        -------
+        dict
+            Dictionary with ``mask`` key containing instance labels.
+        """
+        layer = kwargs.get("layer")
+        if layer is None:
+            return {"mask": None, "points": None}
+        if getattr(layer, "rgb", False):
+            raise ValueError("RMP requires single-channel images.")
+
+        settings = kwargs.get("settings", {})
+        manual_threshold = _clamp_threshold(
+            float(settings.get("manual_threshold", 0.5))
+        )
+        config = RMPSettings(
+            denoising_se_length=int(settings.get("denoising_kernel_length", 2)),
+            extraction_se_length=int(settings.get("extraction_kernel_length", 10)),
+            angle_spacing=int(settings.get("angle_spacing", 5)),
+            auto_threshold=bool(settings.get("auto_threshold", True)),
+            manual_threshold=manual_threshold,
+            enable_denoising=bool(settings.get("enable_denoising", True)),
+        )
+
+        if config.angle_spacing <= 0:
+            raise ValueError("Angle spacing must be positive.")
+        if config.denoising_se_length <= 0 or config.extraction_se_length <= 0:
+            raise ValueError("Structuring element lengths must be positive.")
+
+        data = layer_data_asarray(layer)
+        if data.ndim not in (2, 3):
+            raise ValueError("RMP expects 2D images or 3D stacks.")
+
+        normalized = _normalize_image(data)
+
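+        # Prefer tiled (dask) execution when dask is importable and distributed
+        # execution when dask.distributed is available; otherwise fall back to
+        # in-memory processing.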
+        use_distributed = _distributed_available()
+        use_tiled = _dask_available()
+        top_hat = _compute_top_hat_nd(
+            normalized,
+            config,
+            use_tiled=use_tiled,
+            use_distributed=use_distributed,
+        )
+        labels, _top_hat_normalized = _postprocess_top_hat(top_hat, config)
+        return {
+            "mask": labels,
+            # "debug_images": {
+            #     "debug_top_hat_before_threshold": _top_hat_normalized.astype(
+            #         np.float32,
+            #         copy=False,
+            #     ),
+            # },
+        }
@@ -2,8 +2,23 @@
   "name": "ufish",
   "description": "U-FISH local-maxima seeded watershed detector",
   "version": "0.1.0",
-  "order": 0,
+  "order": 1,
   "settings": [
+    {
+      "key": "denoise_enabled",
+      "label": "Denoise input",
+      "type": "bool",
+      "default": true
+    },
+    {
+      "key": "spot_size",
+      "label": "Spot size",
+      "type": "float",
+      "decimals": 2,
+      "min": 0.25,
+      "max": 4.0,
+      "default": 1.0
+    },
     {
       "key": "threshold",
       "label": "Threshold",