senoquant 1.0.0b2__py3-none-any.whl → 1.0.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. senoquant/__init__.py +6 -2
  2. senoquant/_reader.py +1 -1
  3. senoquant/_widget.py +9 -1
  4. senoquant/reader/core.py +201 -18
  5. senoquant/tabs/__init__.py +2 -0
  6. senoquant/tabs/batch/backend.py +76 -27
  7. senoquant/tabs/batch/frontend.py +127 -25
  8. senoquant/tabs/quantification/features/marker/dialog.py +26 -6
  9. senoquant/tabs/quantification/features/marker/export.py +97 -24
  10. senoquant/tabs/quantification/features/marker/rows.py +2 -2
  11. senoquant/tabs/quantification/features/spots/dialog.py +41 -11
  12. senoquant/tabs/quantification/features/spots/export.py +163 -10
  13. senoquant/tabs/quantification/frontend.py +2 -2
  14. senoquant/tabs/segmentation/frontend.py +46 -9
  15. senoquant/tabs/segmentation/models/cpsam/model.py +1 -1
  16. senoquant/tabs/segmentation/models/default_2d/model.py +22 -77
  17. senoquant/tabs/segmentation/models/default_3d/model.py +8 -74
  18. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +0 -0
  19. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +13 -13
  20. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/stardist_libs.py +171 -0
  21. senoquant/tabs/spots/frontend.py +96 -5
  22. senoquant/tabs/spots/models/rmp/details.json +3 -9
  23. senoquant/tabs/spots/models/rmp/model.py +341 -266
  24. senoquant/tabs/spots/models/ufish/details.json +32 -0
  25. senoquant/tabs/spots/models/ufish/model.py +327 -0
  26. senoquant/tabs/spots/ufish_utils/__init__.py +13 -0
  27. senoquant/tabs/spots/ufish_utils/core.py +387 -0
  28. senoquant/tabs/visualization/__init__.py +1 -0
  29. senoquant/tabs/visualization/backend.py +306 -0
  30. senoquant/tabs/visualization/frontend.py +1113 -0
  31. senoquant/tabs/visualization/plots/__init__.py +80 -0
  32. senoquant/tabs/visualization/plots/base.py +152 -0
  33. senoquant/tabs/visualization/plots/double_expression.py +187 -0
  34. senoquant/tabs/visualization/plots/spatialplot.py +156 -0
  35. senoquant/tabs/visualization/plots/umap.py +140 -0
  36. senoquant/utils.py +1 -1
  37. senoquant-1.0.0b4.dist-info/METADATA +162 -0
  38. {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b4.dist-info}/RECORD +53 -30
  39. {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b4.dist-info}/top_level.txt +1 -0
  40. ufish/__init__.py +1 -0
  41. ufish/api.py +778 -0
  42. ufish/model/__init__.py +0 -0
  43. ufish/model/loss.py +62 -0
  44. ufish/model/network/__init__.py +0 -0
  45. ufish/model/network/spot_learn.py +50 -0
  46. ufish/model/network/ufish_net.py +204 -0
  47. ufish/model/train.py +175 -0
  48. ufish/utils/__init__.py +0 -0
  49. ufish/utils/img.py +418 -0
  50. ufish/utils/log.py +8 -0
  51. ufish/utils/spot_calling.py +115 -0
  52. senoquant/tabs/spots/models/udwt/details.json +0 -103
  53. senoquant/tabs/spots/models/udwt/model.py +0 -482
  54. senoquant-1.0.0b2.dist-info/METADATA +0 -193
  55. {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b4.dist-info}/WHEEL +0 -0
  56. {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b4.dist-info}/entry_points.txt +0 -0
  57. {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b4.dist-info}/licenses/LICENSE +0 -0
@@ -8,13 +8,16 @@ from typing import Iterable
8
8
 
9
9
  import numpy as np
10
10
  from scipy import ndimage as ndi
11
- from skimage.filters import threshold_otsu
12
- from skimage.measure import label
13
- from skimage.morphology import opening, rectangle
11
+ from skimage.filters import laplace, threshold_otsu
12
+ from skimage.morphology import local_maxima
14
13
  from skimage.segmentation import watershed
15
- from skimage.feature import peak_local_max
16
- from skimage.transform import rotate
17
- from skimage.util import img_as_ubyte
14
+
15
+ try:
16
+ import torch
17
+ import torch.nn.functional as F
18
+ except ImportError: # pragma: no cover - optional dependency
19
+ torch = None # type: ignore[assignment]
20
+ F = None # type: ignore[assignment]
18
21
 
19
22
  from ..base import SenoQuantSpotDetector
20
23
  from senoquant.utils import layer_data_asarray
@@ -30,81 +33,266 @@ except ImportError: # pragma: no cover - optional dependency
30
33
  Client = None # type: ignore[assignment]
31
34
  LocalCluster = None # type: ignore[assignment]
32
35
 
33
- try: # pragma: no cover - optional dependency
34
- from dask_cuda import LocalCUDACluster
35
- except ImportError: # pragma: no cover - optional dependency
36
- LocalCUDACluster = None # type: ignore[assignment]
37
36
 
38
- try: # pragma: no cover - optional dependency
39
- import cupy as cp
40
- from cucim.skimage.filters import threshold_otsu as gpu_threshold_otsu
41
- from cucim.skimage.morphology import opening as gpu_opening, rectangle as gpu_rectangle
42
- from cucim.skimage.transform import rotate as gpu_rotate
43
- except ImportError: # pragma: no cover - optional dependency
44
- cp = None # type: ignore[assignment]
45
- gpu_threshold_otsu = None # type: ignore[assignment]
46
- gpu_opening = None # type: ignore[assignment]
47
- gpu_rectangle = None # type: ignore[assignment]
48
- gpu_rotate = None # type: ignore[assignment]
37
+ Array2D = np.ndarray
38
+ KernelShape = tuple[int, int]
39
+ EPS = 1e-6
40
+ NOISE_FLOOR_SIGMA = 1.5
41
+ MIN_SCALE_SIGMA = 5.0
42
+ SIGNAL_SCALE_QUANTILE = 99.9
43
+ USE_LAPLACE_FOR_PEAKS = False
44
+
45
+
46
+ def _ensure_torch_available() -> None:
47
+ """Ensure torch is available for RMP processing."""
48
+ if torch is None or F is None: # pragma: no cover - import guard
49
+ raise ImportError("torch is required for the RMP detector.")
50
+
51
+
52
+ def _torch_device() -> "torch.device":
53
+ """Return the best available torch device (CUDA, MPS, then CPU)."""
54
+ _ensure_torch_available()
55
+ assert torch is not None
56
+ if torch.cuda.is_available():
57
+ return torch.device("cuda")
58
+ mps_backend = getattr(torch.backends, "mps", None)
59
+ if mps_backend is not None and mps_backend.is_available():
60
+ return torch.device("mps")
61
+ return torch.device("cpu")
62
+
63
+
64
+ def _to_image_tensor(image: np.ndarray, *, device: "torch.device") -> "torch.Tensor":
65
+ """Convert a 2D image array to a [1,1,H,W] torch tensor."""
66
+ _ensure_torch_available()
67
+ assert torch is not None
68
+ tensor = torch.as_tensor(image, dtype=torch.float32, device=device)
69
+ if tensor.ndim != 2:
70
+ raise ValueError("Expected a 2D image for tensor conversion.")
71
+ return tensor.unsqueeze(0).unsqueeze(0)
72
+
73
+
74
+ def _rotate_tensor(image: "torch.Tensor", angle: float) -> "torch.Tensor":
75
+ """Rotate a [1,1,H,W] tensor with reflection padding."""
76
+ _ensure_torch_available()
77
+ assert torch is not None
78
+ assert F is not None
79
+ if image.ndim != 4:
80
+ raise ValueError("Expected a [N,C,H,W] tensor for rotation.")
81
+
82
+ height = float(image.shape[-2])
83
+ width = float(image.shape[-1])
84
+ hw_ratio = height / width if width > 0 else 1.0
85
+ wh_ratio = width / height if height > 0 else 1.0
86
+
87
+ radians = np.deg2rad(float(angle))
88
+ cos_v = float(np.cos(radians))
89
+ sin_v = float(np.sin(radians))
90
+ # affine_grid operates in normalized coordinates; non-square images need
91
+ # aspect-ratio correction on the off-diagonal terms.
92
+ theta = torch.tensor(
93
+ [[[cos_v, -sin_v * hw_ratio, 0.0], [sin_v * wh_ratio, cos_v, 0.0]]],
94
+ dtype=image.dtype,
95
+ device=image.device,
96
+ )
97
+ grid = F.affine_grid(theta, tuple(image.shape), align_corners=False)
98
+ return F.grid_sample(
99
+ image,
100
+ grid,
101
+ mode="bilinear",
102
+ padding_mode="reflection",
103
+ align_corners=False,
104
+ )
49
105
 
50
106
 
51
- Array2D = np.ndarray
107
+ def _grayscale_opening_tensor(
108
+ image: "torch.Tensor",
109
+ kernel_shape: KernelShape,
110
+ ) -> "torch.Tensor":
111
+ """Apply grayscale opening (erosion then dilation) with a rectangular kernel."""
112
+ _ensure_torch_available()
113
+ assert F is not None
114
+ img_h = int(image.shape[-2])
115
+ img_w = int(image.shape[-1])
116
+ ky = min(max(1, int(kernel_shape[0])), max(1, img_h))
117
+ kx = min(max(1, int(kernel_shape[1])), max(1, img_w))
118
+ pad_y = ky // 2
119
+ pad_x = kx // 2
120
+ pad = (pad_x, pad_x, pad_y, pad_y)
121
+
122
+ # Erosion via pooling uses the min-over-window identity:
123
+ # min(x) == -max(-x). Missing the inner negation flips morphology behavior.
124
+ eroded = -F.max_pool2d(
125
+ F.pad(-image, pad, mode="reflect"),
126
+ kernel_size=(ky, kx),
127
+ stride=1,
128
+ )
129
+ opened = F.max_pool2d(
130
+ F.pad(eroded, pad, mode="reflect"),
131
+ kernel_size=(ky, kx),
132
+ stride=1,
133
+ )
134
+ return opened
135
+
136
+
137
+ def _kernel_shape(footprint: KernelShape | np.ndarray) -> KernelShape:
138
+ """Return kernel shape from either a tuple footprint or array."""
139
+ if isinstance(footprint, tuple):
140
+ return max(1, int(footprint[0])), max(1, int(footprint[1]))
141
+ arr = np.asarray(footprint)
142
+ if arr.ndim != 2:
143
+ raise ValueError("Structuring element must be 2D.")
144
+ return max(1, int(arr.shape[0])), max(1, int(arr.shape[1]))
52
145
 
53
146
 
54
147
  def _normalize_image(image: np.ndarray) -> np.ndarray:
55
148
  """Normalize an image to float32 in [0, 1]."""
149
+ device = _torch_device()
56
150
  data = np.asarray(image, dtype=np.float32)
57
- min_val = float(data.min())
58
- max_val = float(data.max())
59
- if max_val <= min_val:
151
+ _ensure_torch_available()
152
+ assert torch is not None
153
+ tensor = torch.as_tensor(data, dtype=torch.float32, device=device)
154
+ min_val = tensor.amin()
155
+ max_val = tensor.amax()
156
+ if bool(max_val <= min_val):
60
157
  return np.zeros_like(data, dtype=np.float32)
61
- data = (data - min_val) / (max_val - min_val)
62
- return np.clip(data, 0.0, 1.0)
158
+ normalized = (tensor - min_val) / (max_val - min_val)
159
+ normalized = normalized.clamp(0.0, 1.0)
160
+ return normalized.detach().cpu().numpy().astype(np.float32, copy=False)
63
161
 
64
162
 
65
- def _pad_for_rotation(image: Array2D) -> tuple[Array2D, tuple[int, int]]:
66
- """Pad image to preserve content after rotations."""
67
- nrows, ncols = image.shape[:2]
68
- diagonal = int(np.ceil(np.sqrt(nrows**2 + ncols**2)))
163
+ def _clamp_threshold(value: float) -> float:
164
+ """Clamp threshold to the inclusive [0.0, 1.0] range."""
165
+ return float(np.clip(value, 0.0, 1.0))
166
+
167
+
168
+ def _normalize_top_hat_unit(image: np.ndarray) -> np.ndarray:
169
+ """Robust normalization for top-hat output."""
170
+ data = np.asarray(image, dtype=np.float32)
171
+ finite_mask = np.isfinite(data)
172
+ if not np.any(finite_mask):
173
+ return np.zeros_like(data, dtype=np.float32)
174
+
175
+ valid = data[finite_mask]
176
+ background = float(np.nanmedian(valid))
177
+ sigma = 1.4826 * float(np.nanmedian(np.abs(valid - background)))
69
178
 
179
+ if (not np.isfinite(sigma)) or sigma <= EPS:
180
+ sigma = float(np.nanstd(valid))
181
+ if (not np.isfinite(sigma)) or sigma <= EPS:
182
+ return np.zeros_like(data, dtype=np.float32)
183
+
184
+ noise_floor = background + (NOISE_FLOOR_SIGMA * sigma)
185
+ residual = np.clip(data - noise_floor, 0.0, None)
186
+ residual = np.where(finite_mask, residual, 0.0)
187
+
188
+ positive = residual[residual > 0.0]
189
+ if positive.size == 0:
190
+ return np.zeros_like(data, dtype=np.float32)
191
+ high = float(np.nanpercentile(positive, SIGNAL_SCALE_QUANTILE))
192
+ if (not np.isfinite(high)) or high <= EPS:
193
+ high = float(np.nanmax(positive))
194
+ if (not np.isfinite(high)) or high <= EPS:
195
+ return np.zeros_like(data, dtype=np.float32)
196
+
197
+ scale = max(high, MIN_SCALE_SIGMA * sigma, EPS)
198
+ normalized = np.clip(residual / scale, 0.0, 1.0)
199
+ return normalized.astype(np.float32, copy=False)
200
+
201
+
202
+ def _markers_from_local_maxima(
203
+ enhanced: np.ndarray,
204
+ threshold: float,
205
+ use_laplace: bool = USE_LAPLACE_FOR_PEAKS,
206
+ ) -> np.ndarray:
207
+ """Build marker labels from local maxima and thresholding."""
208
+ connectivity = max(1, min(2, enhanced.ndim))
209
+ response = (
210
+ laplace(enhanced.astype(np.float32, copy=False))
211
+ if use_laplace
212
+ else np.asarray(enhanced, dtype=np.float32)
213
+ )
214
+ mask = local_maxima(response, connectivity=connectivity)
215
+ mask = mask & (response > threshold)
216
+
217
+ markers = np.zeros(enhanced.shape, dtype=np.int32)
218
+ coords = np.argwhere(mask)
219
+ if coords.size == 0:
220
+ return markers
221
+
222
+ max_indices = np.asarray(enhanced.shape) - 1
223
+ coords = np.clip(coords, 0, max_indices)
224
+ markers[tuple(coords.T)] = 1
225
+
226
+ structure = ndi.generate_binary_structure(enhanced.ndim, 1)
227
+ marker_labels, _num = ndi.label(markers > 0, structure=structure)
228
+ return marker_labels.astype(np.int32, copy=False)
229
+
230
+
231
+ def _segment_from_markers(
232
+ enhanced: np.ndarray,
233
+ markers: np.ndarray,
234
+ threshold: float,
235
+ ) -> np.ndarray:
236
+ """Run watershed from local-maxima markers inside threshold foreground."""
237
+ foreground = enhanced > threshold
238
+ if not np.any(foreground):
239
+ return np.zeros_like(enhanced, dtype=np.int32)
240
+
241
+ seeded_markers = markers * foreground.astype(np.int32, copy=False)
242
+ if not np.any(seeded_markers > 0):
243
+ return np.zeros_like(enhanced, dtype=np.int32)
244
+
245
+ labels = watershed(
246
+ -enhanced.astype(np.float32, copy=False),
247
+ markers=seeded_markers,
248
+ mask=foreground,
249
+ )
250
+ return labels.astype(np.int32, copy=False)
251
+
252
+ def _pad_tensor_for_rotation(
253
+ image: "torch.Tensor",
254
+ ) -> tuple["torch.Tensor", tuple[int, int]]:
255
+ """Pad a [1,1,H,W] tensor to preserve content after rotations."""
256
+ nrows = int(image.shape[-2])
257
+ ncols = int(image.shape[-1])
258
+ diagonal = int(np.ceil(np.sqrt(nrows**2 + ncols**2)))
70
259
  rows_to_pad = int(np.ceil((diagonal - nrows) / 2))
71
260
  cols_to_pad = int(np.ceil((diagonal - ncols) / 2))
72
-
73
- padded_image = np.pad(
261
+ assert F is not None
262
+ padded = F.pad(
74
263
  image,
75
- ((rows_to_pad, rows_to_pad), (cols_to_pad, cols_to_pad)),
264
+ (cols_to_pad, cols_to_pad, rows_to_pad, rows_to_pad),
76
265
  mode="reflect",
77
266
  )
78
-
79
- return padded_image, (rows_to_pad, cols_to_pad)
80
-
267
+ return padded, (rows_to_pad, cols_to_pad)
81
268
 
82
269
  def _rmp_opening(
83
270
  input_image: Array2D,
84
- structuring_element: Array2D,
271
+ structuring_element: KernelShape | Array2D,
85
272
  rotation_angles: Iterable[int],
86
273
  ) -> Array2D:
87
274
  """Perform the RMP opening on an image."""
88
- padded_image, (newy, newx) = _pad_for_rotation(input_image)
89
- rotated_images = [
90
- rotate(padded_image, angle, mode="reflect") for angle in rotation_angles
91
- ]
275
+ device = _torch_device()
276
+ tensor = _to_image_tensor(np.asarray(input_image, dtype=np.float32), device=device)
277
+ padded, (newy, newx) = _pad_tensor_for_rotation(tensor)
278
+ kernel_shape = _kernel_shape(structuring_element)
279
+
280
+ rotated_images = [_rotate_tensor(padded, angle) for angle in rotation_angles]
92
281
  opened_images = [
93
- opening(image, footprint=structuring_element, mode="reflect")
94
- for image in rotated_images
282
+ _grayscale_opening_tensor(image, kernel_shape) for image in rotated_images
95
283
  ]
96
284
  rotated_back = [
97
- rotate(image, -angle, mode="reflect")
285
+ _rotate_tensor(image, -angle)
98
286
  for image, angle in zip(opened_images, rotation_angles)
99
287
  ]
100
-
101
- stacked_images = np.stack(rotated_back, axis=0)
102
- union_image = np.max(stacked_images, axis=0)
288
+ stacked = torch.stack(rotated_back, dim=0)
289
+ union_image = stacked.max(dim=0).values
103
290
  cropped = union_image[
291
+ ...,
104
292
  newy : newy + input_image.shape[0],
105
293
  newx : newx + input_image.shape[1],
106
294
  ]
107
- return cropped
295
+ return cropped.squeeze(0).squeeze(0).detach().cpu().numpy().astype(np.float32, copy=False)
108
296
 
109
297
 
110
298
  def _rmp_top_hat(
@@ -119,8 +307,8 @@ def _rmp_top_hat(
119
307
 
120
308
  def _compute_top_hat(input_image: Array2D, config: "RMPSettings") -> Array2D:
121
309
  """Compute the RMP top-hat response for a 2D image."""
122
- denoising_se = rectangle(1, config.denoising_se_length)
123
- extraction_se = rectangle(1, config.extraction_se_length)
310
+ denoising_se: KernelShape = (1, config.denoising_se_length)
311
+ extraction_se: KernelShape = (1, config.extraction_se_length)
124
312
  rotation_angles = tuple(range(0, 180, config.angle_spacing))
125
313
 
126
314
  working = (
@@ -131,64 +319,6 @@ def _compute_top_hat(input_image: Array2D, config: "RMPSettings") -> Array2D:
131
319
  return _rmp_top_hat(working, extraction_se, rotation_angles)
132
320
 
133
321
 
134
- def _binary_to_instances(mask: np.ndarray, start_label: int = 1) -> tuple[np.ndarray, int]:
135
- """Convert a binary mask to instance labels.
136
-
137
- Parameters
138
- ----------
139
- mask : numpy.ndarray
140
- Binary mask where foreground pixels are non-zero.
141
- start_label : int, optional
142
- Starting label index for the output. Defaults to 1.
143
-
144
- Returns
145
- -------
146
- numpy.ndarray
147
- Labeled instance mask.
148
- int
149
- Next label value after the labeled mask.
150
- """
151
- labeled = label(mask > 0)
152
- if start_label > 1 and labeled.max() > 0:
153
- labeled = labeled + (start_label - 1)
154
- next_label = int(labeled.max()) + 1
155
- return labeled.astype(np.int32, copy=False), next_label
156
-
157
-
158
- def _watershed_instances(
159
- image: np.ndarray,
160
- binary: np.ndarray,
161
- min_distance: int,
162
- ) -> np.ndarray:
163
- """Split touching spots using watershed segmentation."""
164
- if not np.any(binary):
165
- return np.zeros_like(binary, dtype=np.int32)
166
- if not np.any(~binary):
167
- labeled, _ = _binary_to_instances(binary)
168
- return labeled
169
-
170
- distance = ndi.distance_transform_edt(binary)
171
- coordinates = peak_local_max(
172
- distance,
173
- labels=binary.astype(np.uint8),
174
- min_distance=max(1, int(min_distance)),
175
- exclude_border=False,
176
- )
177
- if coordinates.size == 0:
178
- labeled, _ = _binary_to_instances(binary)
179
- return labeled
180
-
181
- peaks = np.zeros(binary.shape, dtype=bool)
182
- peaks[tuple(coordinates.T)] = True
183
- markers = label(peaks).astype(np.int32, copy=False)
184
- if markers.max() == 0:
185
- labeled, _ = _binary_to_instances(binary)
186
- return labeled
187
-
188
- labels = watershed(-distance, markers, mask=binary)
189
- return labels.astype(np.int32, copy=False)
190
-
191
-
192
322
  def _ensure_dask_available() -> None:
193
323
  """Ensure dask is installed for tiled execution."""
194
324
  if da is None: # pragma: no cover - import guard
@@ -201,18 +331,6 @@ def _ensure_distributed_available() -> None:
201
331
  raise ImportError("dask.distributed is required for distributed execution.")
202
332
 
203
333
 
204
- def _ensure_cupy_available() -> None:
205
- """Ensure CuPy and cuCIM are installed for GPU execution."""
206
- if (
207
- cp is None
208
- or gpu_threshold_otsu is None
209
- or gpu_opening is None
210
- or gpu_rectangle is None
211
- or gpu_rotate is None
212
- ): # pragma: no cover - import guard
213
- raise ImportError("cupy + cucim are required for GPU execution.")
214
-
215
-
216
334
  def _dask_available() -> bool:
217
335
  """Return True when dask is available."""
218
336
  return da is not None
@@ -223,18 +341,6 @@ def _distributed_available() -> bool:
223
341
  return Client is not None and LocalCluster is not None and da is not None
224
342
 
225
343
 
226
- def _gpu_available() -> bool:
227
- """Return True when CuPy/cuCIM are available for GPU execution."""
228
- return (
229
- cp is not None
230
- and gpu_threshold_otsu is not None
231
- and gpu_opening is not None
232
- and gpu_rectangle is not None
233
- and gpu_rotate is not None
234
- and da is not None
235
- )
236
-
237
-
238
344
  def _recommended_overlap(config: "RMPSettings") -> int:
239
345
  """Derive a suitable overlap from structuring-element sizes."""
240
346
  lengths = [config.extraction_se_length]
@@ -244,21 +350,18 @@ def _recommended_overlap(config: "RMPSettings") -> int:
244
350
 
245
351
 
246
352
  @contextmanager
247
- def _cluster_client(use_gpu: bool):
353
+ def _cluster_client():
248
354
  """Yield a connected Dask client backed by a local cluster."""
249
355
  _ensure_distributed_available()
250
-
251
- use_cuda_cluster = bool(use_gpu and cp is not None and LocalCUDACluster is not None)
252
- cluster_cls = LocalCUDACluster if use_cuda_cluster else LocalCluster
253
- with cluster_cls() as cluster: # type: ignore[call-arg]
356
+ with LocalCluster() as cluster:
254
357
  with Client(cluster) as client:
255
358
  yield client
256
359
 
257
360
 
258
- def _cpu_top_hat_block(block: np.ndarray, config: "RMPSettings") -> np.ndarray:
361
+ def _rmp_top_hat_block(block: np.ndarray, config: "RMPSettings") -> np.ndarray:
259
362
  """Return background-subtracted tile via the RMP top-hat pipeline."""
260
- denoising_se = rectangle(1, config.denoising_se_length)
261
- extraction_se = rectangle(1, config.extraction_se_length)
363
+ denoising_se: KernelShape = (1, config.denoising_se_length)
364
+ extraction_se: KernelShape = (1, config.extraction_se_length)
262
365
  rotation_angles = tuple(range(0, 180, config.angle_spacing))
263
366
 
264
367
  working = (
@@ -270,56 +373,86 @@ def _cpu_top_hat_block(block: np.ndarray, config: "RMPSettings") -> np.ndarray:
270
373
  return np.asarray(top_hat, dtype=np.float32)
271
374
 
272
375
 
273
- def _gpu_pad_for_rotation(image: "cp.ndarray") -> tuple["cp.ndarray", tuple[int, int]]:
274
- nrows, ncols = image.shape[:2]
275
- diagonal = int(cp.ceil(cp.sqrt(nrows**2 + ncols**2)).item())
276
- rows_to_pad = int(cp.ceil((diagonal - nrows) / 2).item())
277
- cols_to_pad = int(cp.ceil((diagonal - ncols) / 2).item())
278
- padded = cp.pad(
279
- image,
280
- ((rows_to_pad, rows_to_pad), (cols_to_pad, cols_to_pad)),
281
- mode="reflect",
282
- )
283
- return padded, (rows_to_pad, cols_to_pad)
284
-
285
-
286
- def _gpu_rmp_opening(
287
- image: "cp.ndarray",
288
- structuring_element: "cp.ndarray",
289
- rotation_angles: Iterable[int],
290
- ) -> "cp.ndarray":
291
- padded, (newy, newx) = _gpu_pad_for_rotation(image)
292
- rotated = [gpu_rotate(padded, angle, mode="reflect") for angle in rotation_angles]
293
- opened = [
294
- gpu_opening(img, footprint=structuring_element, mode="reflect")
295
- for img in rotated
296
- ]
297
- rotated_back = [
298
- gpu_rotate(img, -angle, mode="reflect")
299
- for img, angle in zip(opened, rotation_angles)
300
- ]
376
+ def _compute_top_hat_2d(
377
+ image_2d: np.ndarray,
378
+ config: "RMPSettings",
379
+ *,
380
+ use_tiled: bool,
381
+ distributed: bool,
382
+ client: "Client | None" = None,
383
+ ) -> np.ndarray:
384
+ """Compute a top-hat image for one 2D plane."""
385
+ if use_tiled:
386
+ return _rmp_top_hat_tiled(
387
+ image_2d,
388
+ config=config,
389
+ distributed=distributed,
390
+ client=client,
391
+ )
392
+ return _compute_top_hat(image_2d, config)
301
393
 
302
- stacked = cp.stack(rotated_back, axis=0)
303
- union = cp.max(stacked, axis=0)
304
- return union[newy : newy + image.shape[0], newx : newx + image.shape[1]]
305
394
 
395
+ def _compute_top_hat_nd(
396
+ image: np.ndarray,
397
+ config: "RMPSettings",
398
+ *,
399
+ use_tiled: bool,
400
+ use_distributed: bool,
401
+ ) -> np.ndarray:
402
+ """Compute top-hat for 2D images or slice-wise for 3D stacks."""
403
+ if image.ndim == 2:
404
+ return _compute_top_hat_2d(
405
+ image,
406
+ config,
407
+ use_tiled=use_tiled,
408
+ distributed=use_distributed,
409
+ )
306
410
 
307
- def _gpu_top_hat(block: np.ndarray, config: "RMPSettings") -> np.ndarray:
308
- """CuPy-backed RMP top-hat for a single tile."""
309
- _ensure_cupy_available()
411
+ top_hat_stack = np.zeros_like(image, dtype=np.float32)
412
+ if use_tiled and use_distributed:
413
+ with _cluster_client() as client:
414
+ for z in range(image.shape[0]):
415
+ top_hat_stack[z] = _compute_top_hat_2d(
416
+ image[z],
417
+ config,
418
+ use_tiled=True,
419
+ distributed=True,
420
+ client=client,
421
+ )
422
+ return top_hat_stack
423
+
424
+ for z in range(image.shape[0]):
425
+ top_hat_stack[z] = _compute_top_hat_2d(
426
+ image[z],
427
+ config,
428
+ use_tiled=use_tiled,
429
+ distributed=False,
430
+ )
431
+ return top_hat_stack
310
432
 
311
- gpu_block = cp.asarray(block, dtype=cp.float32)
312
- denoising_se = gpu_rectangle(1, config.denoising_se_length)
313
- extraction_se = gpu_rectangle(1, config.extraction_se_length)
314
- rotation_angles = tuple(range(0, 180, config.angle_spacing))
315
433
 
316
- working = (
317
- _gpu_rmp_opening(gpu_block, denoising_se, rotation_angles)
318
- if config.enable_denoising
319
- else gpu_block
434
+ def _postprocess_top_hat(
435
+ top_hat: np.ndarray,
436
+ config: "RMPSettings",
437
+ ) -> tuple[np.ndarray, np.ndarray]:
438
+ """Apply normalization, thresholding, marker extraction, and watershed."""
439
+ top_hat_normalized = _normalize_top_hat_unit(top_hat)
440
+ threshold = (
441
+ _clamp_threshold(float(threshold_otsu(top_hat_normalized)))
442
+ if config.auto_threshold
443
+ else config.manual_threshold
444
+ )
445
+ markers = _markers_from_local_maxima(
446
+ top_hat_normalized,
447
+ threshold,
448
+ use_laplace=USE_LAPLACE_FOR_PEAKS,
320
449
  )
321
- top_hat = working - _gpu_rmp_opening(working, extraction_se, rotation_angles)
322
- return cp.asnumpy(top_hat).astype(np.float32, copy=False)
450
+ labels = _segment_from_markers(
451
+ top_hat_normalized,
452
+ markers,
453
+ threshold,
454
+ )
455
+ return labels, top_hat_normalized
323
456
 
324
457
 
325
458
  def _rmp_top_hat_tiled(
@@ -327,26 +460,16 @@ def _rmp_top_hat_tiled(
327
460
  config: "RMPSettings",
328
461
  chunk_size: tuple[int, int] = (1024, 1024),
329
462
  overlap: int | None = None,
330
- use_gpu: bool = False,
331
463
  distributed: bool = False,
332
464
  client: "Client | None" = None,
333
465
  ) -> np.ndarray:
334
466
  """Return the RMP top-hat image using tiled execution."""
335
467
  _ensure_dask_available()
336
- if use_gpu:
337
- _ensure_cupy_available()
338
468
 
339
469
  effective_overlap = _recommended_overlap(config) if overlap is None else overlap
340
470
 
341
- if use_gpu:
342
-
343
- def block_fn(block, block_info=None):
344
- return _gpu_top_hat(block, config)
345
-
346
- else:
347
-
348
- def block_fn(block, block_info=None):
349
- return _cpu_top_hat_block(block, config)
471
+ def block_fn(block, block_info=None):
472
+ return _rmp_top_hat_block(block, config)
350
473
 
351
474
  arr = da.from_array(image.astype(np.float32, copy=False), chunks=chunk_size)
352
475
  result = arr.map_overlap(
@@ -360,7 +483,7 @@ def _rmp_top_hat_tiled(
360
483
  if distributed:
361
484
  _ensure_distributed_available()
362
485
  if client is None:
363
- with _cluster_client(use_gpu) as temp_client:
486
+ with _cluster_client() as temp_client:
364
487
  return temp_client.compute(result).result()
365
488
  return client.compute(result).result()
366
489
 
@@ -377,7 +500,7 @@ class RMPSettings:
377
500
  auto_threshold: bool = True
378
501
  manual_threshold: float = 0.05
379
502
  enable_denoising: bool = True
380
- use_3d: bool = False
503
+
381
504
 
382
505
  class RMPDetector(SenoQuantSpotDetector):
383
506
  """RMP spot detector implementation."""
@@ -408,8 +531,9 @@ class RMPDetector(SenoQuantSpotDetector):
408
531
  raise ValueError("RMP requires single-channel images.")
409
532
 
410
533
  settings = kwargs.get("settings", {})
411
- manual_threshold = float(settings.get("manual_threshold", 0.5))
412
- manual_threshold = max(0.0, min(1.0, manual_threshold))
534
+ manual_threshold = _clamp_threshold(
535
+ float(settings.get("manual_threshold", 0.5))
536
+ )
413
537
  config = RMPSettings(
414
538
  denoising_se_length=int(settings.get("denoising_kernel_length", 2)),
415
539
  extraction_se_length=int(settings.get("extraction_kernel_length", 10)),
@@ -417,7 +541,6 @@ class RMPDetector(SenoQuantSpotDetector):
417
541
  auto_threshold=bool(settings.get("auto_threshold", True)),
418
542
  manual_threshold=manual_threshold,
419
543
  enable_denoising=bool(settings.get("enable_denoising", True)),
420
- use_3d=bool(settings.get("use_3d", False)),
421
544
  )
422
545
 
423
546
  if config.angle_spacing <= 0:
@@ -430,70 +553,22 @@ class RMPDetector(SenoQuantSpotDetector):
430
553
  raise ValueError("RMP expects 2D images or 3D stacks.")
431
554
 
432
555
  normalized = _normalize_image(data)
433
- if normalized.ndim == 3 and not config.use_3d:
434
- raise ValueError("Enable 3D to process stacks.")
435
556
 
436
557
  use_distributed = _distributed_available()
437
- use_gpu = _gpu_available()
438
- use_tiled = _dask_available() and (use_distributed or use_gpu)
439
-
440
- if normalized.ndim == 2:
441
- image_2d = normalized
442
- if use_tiled:
443
- top_hat = _rmp_top_hat_tiled(
444
- image_2d,
445
- config=config,
446
- use_gpu=use_gpu,
447
- distributed=use_distributed,
448
- )
449
- else:
450
- top_hat = _compute_top_hat(image_2d, config)
451
-
452
- threshold = (
453
- threshold_otsu(top_hat)
454
- if config.auto_threshold
455
- else config.manual_threshold
456
- )
457
- binary = img_as_ubyte(top_hat > threshold)
458
- labels = _watershed_instances(
459
- top_hat,
460
- binary > 0,
461
- min_distance=max(1, config.extraction_se_length // 2),
462
- )
463
- return {"mask": labels}
464
-
465
- top_hat_stack = np.zeros_like(normalized, dtype=np.float32)
466
- if use_tiled and use_distributed:
467
- with _cluster_client(use_gpu) as client:
468
- for z in range(normalized.shape[0]):
469
- top_hat_stack[z] = _rmp_top_hat_tiled(
470
- normalized[z],
471
- config=config,
472
- use_gpu=use_gpu,
473
- distributed=True,
474
- client=client,
475
- )
476
- elif use_tiled:
477
- for z in range(normalized.shape[0]):
478
- top_hat_stack[z] = _rmp_top_hat_tiled(
479
- normalized[z],
480
- config=config,
481
- use_gpu=use_gpu,
482
- distributed=False,
483
- )
484
- else:
485
- for z in range(normalized.shape[0]):
486
- top_hat_stack[z] = _compute_top_hat(normalized[z], config)
487
-
488
- threshold = (
489
- threshold_otsu(top_hat_stack)
490
- if config.auto_threshold
491
- else config.manual_threshold
492
- )
493
- binary_stack = img_as_ubyte(top_hat_stack > threshold)
494
- labels = _watershed_instances(
495
- top_hat_stack,
496
- binary_stack > 0,
497
- min_distance=max(1, config.extraction_se_length // 2),
558
+ use_tiled = _dask_available()
559
+ top_hat = _compute_top_hat_nd(
560
+ normalized,
561
+ config,
562
+ use_tiled=use_tiled,
563
+ use_distributed=use_distributed,
498
564
  )
499
- return {"mask": labels}
565
+ labels, _top_hat_normalized = _postprocess_top_hat(top_hat, config)
566
+ return {
567
+ "mask": labels,
568
+ # "debug_images": {
569
+ # "debug_top_hat_before_threshold": _top_hat_normalized.astype(
570
+ # np.float32,
571
+ # copy=False,
572
+ # ),
573
+ # },
574
+ }