napari-tmidas 0.2.1__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56):
  1. napari_tmidas/__init__.py +35 -5
  2. napari_tmidas/_crop_anything.py +1458 -499
  3. napari_tmidas/_env_manager.py +76 -0
  4. napari_tmidas/_file_conversion.py +1646 -1131
  5. napari_tmidas/_file_selector.py +1464 -223
  6. napari_tmidas/_label_inspection.py +83 -8
  7. napari_tmidas/_processing_worker.py +309 -0
  8. napari_tmidas/_reader.py +6 -10
  9. napari_tmidas/_registry.py +15 -14
  10. napari_tmidas/_roi_colocalization.py +1221 -84
  11. napari_tmidas/_tests/test_crop_anything.py +123 -0
  12. napari_tmidas/_tests/test_env_manager.py +89 -0
  13. napari_tmidas/_tests/test_file_selector.py +90 -0
  14. napari_tmidas/_tests/test_grid_view_overlay.py +193 -0
  15. napari_tmidas/_tests/test_init.py +98 -0
  16. napari_tmidas/_tests/test_intensity_label_filter.py +222 -0
  17. napari_tmidas/_tests/test_label_inspection.py +86 -0
  18. napari_tmidas/_tests/test_processing_basic.py +500 -0
  19. napari_tmidas/_tests/test_processing_worker.py +142 -0
  20. napari_tmidas/_tests/test_regionprops_analysis.py +547 -0
  21. napari_tmidas/_tests/test_registry.py +135 -0
  22. napari_tmidas/_tests/test_scipy_filters.py +168 -0
  23. napari_tmidas/_tests/test_skimage_filters.py +259 -0
  24. napari_tmidas/_tests/test_split_channels.py +217 -0
  25. napari_tmidas/_tests/test_spotiflow.py +87 -0
  26. napari_tmidas/_tests/test_tyx_display_fix.py +142 -0
  27. napari_tmidas/_tests/test_ui_utils.py +68 -0
  28. napari_tmidas/_tests/test_widget.py +30 -0
  29. napari_tmidas/_tests/test_windows_basic.py +66 -0
  30. napari_tmidas/_ui_utils.py +57 -0
  31. napari_tmidas/_version.py +16 -3
  32. napari_tmidas/_widget.py +41 -4
  33. napari_tmidas/processing_functions/basic.py +557 -20
  34. napari_tmidas/processing_functions/careamics_env_manager.py +72 -99
  35. napari_tmidas/processing_functions/cellpose_env_manager.py +415 -112
  36. napari_tmidas/processing_functions/cellpose_segmentation.py +132 -191
  37. napari_tmidas/processing_functions/colocalization.py +513 -56
  38. napari_tmidas/processing_functions/grid_view_overlay.py +703 -0
  39. napari_tmidas/processing_functions/intensity_label_filter.py +422 -0
  40. napari_tmidas/processing_functions/regionprops_analysis.py +1280 -0
  41. napari_tmidas/processing_functions/sam2_env_manager.py +53 -69
  42. napari_tmidas/processing_functions/sam2_mp4.py +274 -195
  43. napari_tmidas/processing_functions/scipy_filters.py +403 -8
  44. napari_tmidas/processing_functions/skimage_filters.py +424 -212
  45. napari_tmidas/processing_functions/spotiflow_detection.py +949 -0
  46. napari_tmidas/processing_functions/spotiflow_env_manager.py +591 -0
  47. napari_tmidas/processing_functions/timepoint_merger.py +334 -86
  48. napari_tmidas/processing_functions/trackastra_tracking.py +24 -5
  49. {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/METADATA +92 -39
  50. napari_tmidas-0.2.4.dist-info/RECORD +63 -0
  51. napari_tmidas/_tests/__init__.py +0 -0
  52. napari_tmidas-0.2.1.dist-info/RECORD +0 -38
  53. {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/WHEEL +0 -0
  54. {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/entry_points.txt +0 -0
  55. {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/licenses/LICENSE +0 -0
  56. {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,949 @@
1
+ # processing_functions/spotiflow_detection.py
2
+ """
3
+ Processing functions for spot detection using Spotiflow.
4
+
5
+ This module provides functionality to detect spots in fluorescence microscopy images
6
+ using Spotiflow models. It supports both 2D and 3D data with various pretrained models.
7
+
8
+ The functions will automatically create and manage a dedicated environment for Spotiflow
9
+ if it's not already installed in the main environment.
10
+ """
11
+ import os
12
+
13
+ import numpy as np
14
+
15
+ from napari_tmidas._registry import BatchProcessingRegistry
16
+
17
+ # Import the environment manager for Spotiflow
18
+ from napari_tmidas.processing_functions.spotiflow_env_manager import (
19
+ run_spotiflow_in_env,
20
+ )
21
+
22
+
23
# Utility functions for axes handling and input preparation (adapted from napari-spotiflow)
def _validate_axes(img: np.ndarray, axes: str) -> None:
    """Ensure ``axes`` names exactly one axis per dimension of ``img``.

    Raises
    ------
    ValueError
        If ``img.ndim`` differs from ``len(axes)``.
    """
    n_axes = len(axes)
    if img.ndim == n_axes:
        return
    raise ValueError(
        f"Image has {img.ndim} dimensions, but axes has {n_axes} dimensions"
    )
30
+
31
+
32
def _prepare_input(img: np.ndarray, axes: str) -> np.ndarray:
    """Rearrange ``img`` into the channel-last layout Spotiflow expects.

    Single-channel axes strings (``YX``, ``ZYX``, ``TYX``, ``TZYX``) get a
    trailing singleton channel axis; channel-last strings are returned as-is;
    strings with ``C`` elsewhere are transposed so ``C`` becomes last.

    Raises
    ------
    ValueError
        If ``axes`` does not match ``img.ndim`` (via ``_validate_axes``) or is
        not one of the supported layouts.
    """
    _validate_axes(img, axes)

    if axes in ("YX", "ZYX", "TYX", "TZYX"):
        return img[..., None]
    if axes in ("YXC", "ZYXC", "TYXC", "TZYXC"):
        return img

    # Permutations that move the channel axis to the end.
    channel_last_order = {
        "CYX": (1, 2, 0),
        "CZYX": (1, 2, 3, 0),
        "ZCYX": (0, 2, 3, 1),
        "TCYX": (0, 2, 3, 1),
        "TZCYX": (0, 1, 3, 4, 2),
        "TCZYX": (0, 2, 3, 4, 1),
    }
    order = channel_last_order.get(axes)
    if order is None:
        raise ValueError(f"Invalid axes: {axes}")
    return img.transpose(order)
52
+
53
+
54
def _infer_axes(img: np.ndarray) -> str:
    """Guess the axes string of ``img`` from its rank and trailing axis size.

    A trailing axis with at most 4 entries is read as channels, which is the
    same rule ``_points_to_label_mask`` applies when laying out label masks:

    - 2D -> ``"YX"``
    - 3D -> ``"YXC"`` if the last axis has <= 4 entries (e.g. RGB), else ``"ZYX"``
    - 4D -> ``"ZYXC"`` if the last axis has <= 4 entries, else ``"TZYX"``
    - 5D -> ``"TZYXC"``

    Raises
    ------
    ValueError
        For any other number of dimensions.
    """
    ndim = img.ndim
    if ndim == 2:
        return "YX"
    if ndim == 3:
        # Previously always "ZYX", which turned RGB/2-channel 2D images into
        # a bogus Z-stack with X <= 4.
        return "YXC" if img.shape[-1] <= 4 else "ZYX"
    if ndim == 4:
        return "ZYXC" if img.shape[-1] <= 4 else "TZYX"
    if ndim == 5:
        return "TZYXC"
    raise ValueError(f"Cannot infer axes for {ndim}D image")
73
+
74
+
75
# Probe whether Spotiflow is importable in the current environment. find_spec
# on "spotiflow.model" imports the parent "spotiflow" package but not the
# submodule itself; a missing parent surfaces as ModuleNotFoundError.
try:
    import importlib.util

    spec = importlib.util.find_spec("spotiflow.model")
except ImportError:
    spec = None

SPOTIFLOW_AVAILABLE = spec is not None
# Fall back to the dedicated environment whenever a direct import is impossible.
USE_DEDICATED_ENV = not SPOTIFLOW_AVAILABLE

if SPOTIFLOW_AVAILABLE:
    print("Spotiflow found in current environment, using direct import")
else:
    print(
        "Spotiflow not found in current environment, will use dedicated environment"
    )
92
+
93
+
94
def _convert_points_to_labels_with_heatmap(
    image: np.ndarray,
    points: np.ndarray,
    spot_radius: int,
    pretrained_model: str,
    model_path: str,
    prob_thresh: float,
    force_cpu: bool,
) -> np.ndarray:
    """
    Convert points to a label mask using Spotiflow's probability heatmap.

    Spotiflow is run again on ``image`` to obtain ``details.heatmap``. The
    heatmap is thresholded into a foreground mask, which is then split with a
    watershed on the negated heatmap seeded by ``points`` (row ``i`` seeds
    label ``i + 1``). When ``points`` is empty, the connected components of the
    mask are labelled instead.

    Falls back to the point-based ``_points_to_label_mask`` (with the full
    ``image.shape``) when the heatmap is unavailable, or when loading/running
    the model raises ImportError, RuntimeError, ValueError or AttributeError.

    Note: ``spotiflow_detect_spots`` builds its mask with
    ``_points_to_label_mask`` and does not call this helper.

    Parameters
    ----------
    image : np.ndarray
        Raw input image; axes are inferred with ``_infer_axes``.
    points : array-like
        Seed coordinates, one row per spot, in the heatmap's axis order.
    spot_radius : int
        Spot radius used by the point-based fallback.
    pretrained_model : str
        Name passed to ``Spotiflow.from_pretrained`` when ``model_path`` is
        empty or does not exist.
    model_path : str
        Folder passed to ``Spotiflow.from_folder`` when it exists.
    prob_thresh : float or None
        Threshold for both prediction and the foreground mask. ``None`` or a
        value <= 0 lets Spotiflow choose automatically; the mask then uses 0.4.
    force_cpu : bool
        Run on CPU even when CUDA is available.

    Returns
    -------
    np.ndarray
        uint16 label image.
    """
    try:
        import torch
        from scipy.ndimage import label
        from skimage.segmentation import watershed
        from spotiflow.model import Spotiflow

        if force_cpu or not torch.cuda.is_available():
            device = torch.device("cpu")
        else:
            device = torch.device("cuda")

        if model_path and os.path.exists(model_path):
            model = Spotiflow.from_folder(model_path)
        else:
            model = Spotiflow.from_pretrained(pretrained_model)
        model = model.to(device)

        prepared_img = _prepare_input(image, _infer_axes(image))

        # Percentile normalisation. Collapsed percentiles (flat or very sparse
        # images) would divide by zero, so fall back to min/max, then zeros.
        p_low, p_high = np.percentile(prepared_img, [1.0, 99.8])
        if p_high <= p_low:
            p_low = float(prepared_img.min())
            p_high = float(prepared_img.max())
        if p_high > p_low:
            normalized_img = np.clip(
                (prepared_img - p_low) / (p_high - p_low), 0, 1
            )
        else:
            normalized_img = np.zeros(prepared_img.shape, dtype=np.float32)

        # None and 0.0 both mean "automatic", as in spotiflow_detect_spots.
        effective_thresh = (
            prob_thresh
            if prob_thresh is not None and prob_thresh > 0.0
            else None
        )

        _, details = model.predict(
            normalized_img,
            prob_thresh=effective_thresh,
            device=device,
            verbose=False,
        )

        heatmap = getattr(details, "heatmap", None)
        if heatmap is None:
            return _points_to_label_mask(points, image.shape, spot_radius)

        prob_map = np.asarray(heatmap)
        threshold = effective_thresh if effective_thresh is not None else 0.4
        binary_mask = prob_map > threshold

        seeds = np.asarray(points)
        if len(seeds) > 0:
            # One marker per point, using as many coordinates as the heatmap
            # has dimensions (previously always 2, which broke 3D heatmaps).
            markers = np.zeros(prob_map.shape, dtype=np.int32)
            for i, point in enumerate(seeds):
                if len(point) < prob_map.ndim:
                    continue
                idx = tuple(int(c) for c in point[: prob_map.ndim])
                if all(0 <= c < s for c, s in zip(idx, prob_map.shape)):
                    markers[idx] = i + 1
            labels = watershed(-prob_map, markers, mask=binary_mask)
        else:
            labels, _ = label(binary_mask)

        return labels.astype(np.uint16)

    except (ImportError, RuntimeError, ValueError, AttributeError) as e:
        print(f"Error in heatmap-based conversion: {e}")
        # Use the full image shape: image.shape[:2] gave (Z, Y) for 3D data.
        return _points_to_label_mask(points, image.shape, spot_radius)
182
+
183
+
184
@BatchProcessingRegistry.register(
    name="Spotiflow Spot Detection",
    suffix="_spot_labels",
    description="Detect spots in fluorescence microscopy images using Spotiflow and return as label masks",
    parameters={
        "pretrained_model": {
            "type": str,
            "default": "general",
            "description": "Pretrained model to use (general, hybiss, synth_complex, synth_3d, smfish_3d)",
            "choices": [
                "general",
                "hybiss",
                "synth_complex",
                "synth_3d",
                "smfish_3d",
            ],
        },
        "model_path": {
            "type": str,
            "default": "",
            "description": "Path to custom trained model folder (leave empty to use pretrained model)",
        },
        "subpixel": {
            "type": bool,
            "default": True,
            "description": "Enable subpixel localization for more accurate spot coordinates",
        },
        "peak_mode": {
            "type": str,
            "default": "fast",
            "description": "Peak detection mode",
            "choices": ["fast", "skimage"],
        },
        "normalizer": {
            "type": str,
            "default": "percentile",
            "description": "Image normalization method",
            "choices": ["percentile", "minmax"],
        },
        "normalizer_low": {
            "type": float,
            "default": 1.0,
            "min": 0.0,
            "max": 50.0,
            "description": "Lower percentile for normalization",
        },
        "normalizer_high": {
            "type": float,
            "default": 99.8,
            "min": 50.0,
            "max": 100.0,
            "description": "Upper percentile for normalization",
        },
        "prob_thresh": {
            "type": float,
            "default": None,
            "min": 0.0,
            "max": 1.0,
            "description": "Probability threshold (leave empty or 0.0 for automatic)",
        },
        "n_tiles": {
            "type": str,
            "default": "auto",
            "description": "Number of tiles for prediction (e.g., '(2,2)' or 'auto')",
        },
        "exclude_border": {
            "type": bool,
            "default": True,
            "description": "Exclude spots near image borders",
        },
        "scale": {
            "type": str,
            "default": "auto",
            "description": "Scaling factor (e.g., '(1,1)' or 'auto')",
        },
        "min_distance": {
            "type": int,
            "default": 2,
            "min": 1,
            "max": 10,
            "description": "Minimum distance between detected spots",
        },
        "spot_radius": {
            "type": int,
            "default": 3,
            "min": 1,
            "max": 20,
            "description": "Radius of spots in the label mask (in pixels, used for fallback method)",
        },
        "axes": {
            "type": str,
            "default": "auto",
            "description": "Axes order (e.g., 'ZYX', 'YX', or 'auto' for automatic detection)",
        },
        "output_csv": {
            "type": bool,
            "default": True,
            "description": "Save spot coordinates as CSV file alongside the mask",
        },
        "force_dedicated_env": {
            "type": bool,
            "default": False,
            "description": "Force using dedicated environment even if Spotiflow is available",
        },
        "force_cpu": {
            "type": bool,
            "default": False,
            "description": "Force CPU execution (disable GPU) to avoid CUDA compatibility issues",
        },
    },
)
def spotiflow_detect_spots(
    image: np.ndarray,
    pretrained_model: str = "general",
    model_path: str = "",
    subpixel: bool = True,
    peak_mode: str = "fast",
    normalizer: str = "percentile",
    normalizer_low: float = 1.0,
    normalizer_high: float = 99.8,
    prob_thresh: float = None,
    n_tiles: str = "auto",
    exclude_border: bool = True,
    scale: str = "auto",
    min_distance: int = 2,
    spot_radius: int = 3,
    axes: str = "auto",
    output_csv: bool = True,
    force_dedicated_env: bool = False,
    force_cpu: bool = False,
    # For internal use by processing system
    input_file_path: str = None,
) -> np.ndarray:
    """
    Detect spots in fluorescence microscopy images using Spotiflow and return label masks.

    Spotiflow is a deep learning-based spot detector for 2D and 3D
    fluorescence microscopy images. Detection runs in-process when Spotiflow
    is importable here, or otherwise in the dedicated Spotiflow environment.
    The detected coordinates are then rendered by ``_points_to_label_mask``
    into a label mask with the same shape as ``image``. Each spot is painted
    as a small disk (2D) or 3D neighbourhood of radius ``spot_radius``, with
    label ``i + 1`` for the i-th detected point.

    Parameters:
    -----------
    image : np.ndarray
        Input image (2D to 5D; see ``axes``)
    pretrained_model : str
        Pretrained model to use ('general', 'hybiss', 'synth_complex', 'synth_3d', 'smfish_3d')
    model_path : str
        Path to custom trained model folder (overrides pretrained_model if it exists)
    subpixel : bool
        Enable subpixel localization
    peak_mode : str
        Peak detection mode ('fast' or 'skimage')
    normalizer : str
        Image normalization method ('percentile' or 'minmax')
    normalizer_low : float
        Lower percentile for normalization
    normalizer_high : float
        Upper percentile for normalization
    prob_thresh : float or None
        Probability threshold; None or 0.0 lets Spotiflow choose automatically
    n_tiles : str
        Number of tiles for prediction (e.g., '(2,2)' or 'auto')
    exclude_border : bool
        Exclude spots near image borders
    scale : str
        Scaling factor (e.g., '(1,1)' or 'auto')
    min_distance : int
        Minimum distance between detected spots
    spot_radius : int
        Radius (pixels) of each spot painted into the label mask
    axes : str
        Axes order (e.g., 'ZYX', 'YX'; case-insensitive) or 'auto' to infer it
    output_csv : bool
        Save spot coordinates as ``<input stem>_spots.csv`` next to the input
        file (requires ``input_file_path``)
    force_dedicated_env : bool
        Force using the dedicated environment
    force_cpu : bool
        Force CPU execution (disable GPU) to avoid CUDA compatibility issues
    input_file_path : str
        Path to input file (used for saving CSV output)

    Returns:
    --------
    np.ndarray
        uint16 label mask with the same shape as ``image``, for a napari Labels layer
    """
    print("Detecting spots using Spotiflow...")
    print(f"Image shape: {image.shape}")
    print(f"Image dtype: {image.dtype}")

    if not axes or axes.strip().lower() == "auto":
        axes = _infer_axes(image)
        print(f"Inferred axes: {axes}")
    else:
        # _prepare_input only knows upper-case axis letters.
        axes = axes.strip().upper()
        print(f"Using provided axes: {axes}")

    # The dedicated environment is used when requested, when the module-level
    # probe chose it, or when Spotiflow cannot be imported here. The CSV
    # export below follows the same choice.
    use_env = (
        USE_DEDICATED_ENV or force_dedicated_env or not SPOTIFLOW_AVAILABLE
    )
    detect = _detect_spots_env if use_env else _detect_spots_direct
    points = detect(
        image,
        axes,
        pretrained_model,
        model_path,
        subpixel,
        peak_mode,
        normalizer,
        normalizer_low,
        normalizer_high,
        prob_thresh,
        n_tiles,
        exclude_border,
        scale,
        min_distance,
        force_cpu,
    )

    if output_csv:
        if input_file_path:
            _save_coords_csv(points, input_file_path, use_env)
        else:
            print(
                "No input file path provided, skipping CSV export of spot coordinates."
            )

    print(f"Detected {len(points)} spots, converting to label masks...")
    label_mask = _points_to_label_mask(points, image.shape, spot_radius)

    # Count non-zero labels explicitly; "unique - 1" undercounts when no
    # background pixel remains.
    n_labels = int(np.count_nonzero(np.unique(label_mask)))
    print(f"Created label mask with {n_labels} labeled objects")
    return label_mask
444
+
445
+
446
def _points_to_label_mask(
    points: np.ndarray, image_shape: tuple, spot_radius: int
) -> np.ndarray:
    """
    Rasterise detected spot coordinates into a label mask for napari.

    The returned mask is uint16 and has exactly ``image_shape``. Row ``i`` of
    ``points`` is painted with label ``i + 1``. Later spots overwrite earlier
    ones where they overlap, and coordinates are truncated to integers.

    Spatial axes are located with a size heuristic. When the image has 3 or
    more dimensions and its last axis has <= 4 entries, that axis is treated
    as channels. The spatial block is then:

    - 2D images: all axes (Y, X)
    - 3D images: (Y, X) if channel-last, otherwise all three axes (Z, Y, X)
    - >= 4D images: the three axes before the channel axis, or the last three
      axes when there is no channel axis

    Every axis outside the spatial block (leading axes such as T, and the
    trailing channel axis) receives the same labels, so each spot is painted
    through all of them.

    In 2D a spot is ``skimage.draw.disk`` with radius ``spot_radius``. In 3D
    it is the Euclidean ball of integer offsets with
    ``dz**2 + dy**2 + dx**2 <= spot_radius**2``, clipped to the volume.
    Spots whose centre lies outside the spatial bounds are skipped.

    Parameters
    ----------
    points : array-like, shape (N, 2) or (N, 3)
        Spot coordinates. 2-column points are (y, x); for a 3D spatial block
        they are placed in the middle Z plane. For a 2D spatial block,
        3-column points are reduced to two columns by checking which column
        pair fits inside the image bounds.
    image_shape : tuple of int
        Shape of the image the mask must match.
    spot_radius : int
        Spot radius in pixels.

    Returns
    -------
    np.ndarray
        uint16 label mask of shape ``image_shape``.

    Raises
    ------
    ValueError
        If ``points`` has neither 2 nor 3 columns.
    """
    label_mask = np.zeros(image_shape, dtype=np.uint16)

    # --- Locate the spatial block inside image_shape -----------------------
    ndim = len(image_shape)
    channel_last = ndim >= 3 and image_shape[-1] <= 4
    if ndim >= 4:
        n_spatial = 3
    elif ndim == 3:
        n_spatial = 2 if channel_last else 3
    else:
        n_spatial = ndim
    stop = ndim - 1 if channel_last else ndim
    start = stop - n_spatial
    spatial_shape = tuple(image_shape[start:stop])
    # Full slices for the non-spatial axes before/after the spatial block, so
    # one indexing expression works for YX, YXC, ZYX, ZYXC, TZYX, TCZYX, ...
    lead = (slice(None),) * start
    trail = (slice(None),) * (ndim - stop)

    # Accept lists as well as arrays (callers may pass deserialised results).
    points = np.asarray(points)
    if len(points) == 0:
        return label_mask
    if points.ndim == 1:
        points = points.reshape(1, -1)

    # --- Map point columns onto spatial axes -------------------------------
    n_cols = points.shape[1]
    if n_cols == 2:  # (y, x)
        coords = points.astype(int)
    elif n_cols == 3:
        if len(spatial_shape) == 2:
            rows, cols = spatial_shape
            if points[:, 1].max() < rows and points[:, 2].max() < cols:
                coords = points[:, 1:3].astype(int)  # (z, y, x) -> (y, x)
            elif points[:, 0].max() < rows and points[:, 2].max() < cols:
                coords = points[:, [0, 2]].astype(int)  # (y, z, x) -> (y, x)
            elif points[:, 0].max() < rows and points[:, 1].max() < cols:
                coords = points[:, 0:2].astype(int)  # (y, x, z) -> (y, x)
            else:
                # Last resort: maybe the first two columns are (x, y).
                coords = points[:, [1, 0]].astype(int)
        else:
            coords = points.astype(int)  # (z, y, x)
    else:
        raise ValueError(f"Unexpected points shape: {points.shape}")

    valid_spots = 0

    if len(spatial_shape) == 2:
        from skimage import draw

        for i, (y, x) in enumerate(coords):
            if not (0 <= y < spatial_shape[0] and 0 <= x < spatial_shape[1]):
                continue
            try:
                rr, cc = draw.disk((y, x), spot_radius, shape=spatial_shape)
                label_mask[lead + (rr, cc) + trail] = i + 1
                valid_spots += 1
            except (ValueError, IndexError, TypeError) as e:
                print(f"Error drawing spot {i} at ({y}, {x}): {e}")

    elif len(spatial_shape) == 3:
        if coords.shape[1] == 2:
            # 2D points in a volume: place them in the middle Z slice.
            middle_z = spatial_shape[0] // 2
            coords = np.column_stack(
                [np.full(len(coords), middle_z), coords]
            )

        # Euclidean ball offsets, built once for all spots.
        r = int(spot_radius)
        span = np.arange(-r, r + 1)
        dz, dy, dx = np.meshgrid(span, span, span, indexing="ij")
        inside = dz**2 + dy**2 + dx**2 <= r**2
        offsets = np.stack([dz[inside], dy[inside], dx[inside]], axis=1)
        bounds = np.asarray(spatial_shape)

        for i, (z, y, x) in enumerate(coords):
            if not (
                0 <= z < spatial_shape[0]
                and 0 <= y < spatial_shape[1]
                and 0 <= x < spatial_shape[2]
            ):
                continue
            try:
                voxels = offsets + np.array([z, y, x])
                keep = np.all((voxels >= 0) & (voxels < bounds), axis=1)
                zz, yy, xx = voxels[keep].T
                label_mask[lead + (zz, yy, xx) + trail] = i + 1
                valid_spots += 1
            except (ValueError, IndexError, TypeError) as e:
                print(f"Error drawing 3D spot {i} at ({z}, {y}, {x}): {e}")

    print(
        f"Successfully created {valid_spots} spots in label mask with shape {label_mask.shape}"
    )
    return label_mask
624
+
625
+
626
def _parse_tuple_param(param, default_val):
    """Parse a user-supplied ``n_tiles``/``scale`` value.

    ``None``, empty strings and ``"auto"`` (any case) map to ``default_val``.
    Other strings are parsed with ``ast.literal_eval``; ``eval`` is never used,
    because these values come from the UI / batch configuration. Lists become
    tuples. Strings that are not Python literals fall back to ``default_val``
    with a printed warning. Non-string values are returned unchanged (lists
    converted to tuples).
    """
    import ast

    if param is None:
        return default_val
    if not isinstance(param, str):
        return tuple(param) if isinstance(param, list) else param

    text = param.strip()
    if not text or text.lower() == "auto":
        return default_val
    try:
        value = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        print(f"Could not parse parameter value {param!r}; using default")
        return default_val
    return tuple(value) if isinstance(value, list) else value


def _normalize_image(img, normalizer, normalizer_low, normalizer_high):
    """Normalise ``img`` to [0, 1] for Spotiflow.

    ``"percentile"`` clips to the [low, high] percentiles. If those collapse
    (e.g. sparse spots on a zero background), min/max is used instead.
    ``"minmax"`` rescales by the global min/max. A flat image yields float32
    zeros rather than NaN/inf. Any other ``normalizer`` returns ``img``
    unchanged.
    """
    if normalizer == "percentile":
        print(
            f"Applying percentile normalization: {normalizer_low}% to {normalizer_high}%"
        )
        p_low, p_high = np.percentile(img, [normalizer_low, normalizer_high])
        if p_high <= p_low:
            p_low, p_high = float(img.min()), float(img.max())
        if p_high > p_low:
            return np.clip((img - p_low) / (p_high - p_low), 0, 1)
        return np.zeros(img.shape, dtype=np.float32)

    if normalizer == "minmax":
        print("Applying min-max normalization")
        img_min, img_max = float(img.min()), float(img.max())
        if img_max > img_min:
            return (img - img_min) / (img_max - img_min)
        return np.zeros(img.shape, dtype=np.float32)

    return img


def _detect_spots_direct(
    image,
    axes,
    pretrained_model,
    model_path,
    subpixel,
    peak_mode,
    normalizer,
    normalizer_low,
    normalizer_high,
    prob_thresh,
    n_tiles,
    exclude_border,
    scale,
    min_distance,
    force_cpu,
):
    """
    Detect spots in-process with an importable Spotiflow installation.

    Parameters mirror ``spotiflow_detect_spots``. The image is prepared with
    ``_prepare_input`` for ``axes`` and normalised here, so Spotiflow's own
    normaliser is disabled. ``n_tiles``/``scale`` strings are parsed
    safely with ``_parse_tuple_param``. ``prob_thresh`` is forwarded only when
    it is > 0; otherwise Spotiflow chooses the threshold.

    GPU handling: CUDA is used when available and a test tensor can be
    allocated. Otherwise, or on a CUDA error while moving the model or
    predicting, execution falls back to CPU. Forcing CPU also sets
    ``CUDA_VISIBLE_DEVICES=""`` for the whole process.

    When the threshold is automatic and more than 500 spots are found, only
    spots with ``details.prob > 0.7`` are kept. An explicit ``prob_thresh``
    is never overridden.

    Returns
    -------
    np.ndarray
        Spot coordinates as returned by ``model.predict`` (possibly filtered).
    """
    import torch
    from spotiflow.model import Spotiflow

    # --- Device selection ---------------------------------------------------
    if force_cpu:
        print("Forcing CPU execution as requested")
        device = torch.device("cpu")
        # Process-wide side effect: hides GPUs from later CUDA initialisation
        # and from child processes.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif torch.cuda.is_available():
        try:
            # A tiny allocation surfaces driver/arch incompatibilities early.
            torch.ones(1).cuda()
            device = torch.device("cuda")
            print("Using CUDA (GPU) for inference")
        except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
            print(f"CUDA incompatible ({e}), falling back to CPU")
            device = torch.device("cpu")
            force_cpu = True
    else:
        print("CUDA not available, using CPU")
        device = torch.device("cpu")
        force_cpu = True

    # --- Model loading --------------------------------------------------------
    if model_path and os.path.exists(model_path):
        print(f"Loading custom model from {model_path}")
        model = Spotiflow.from_folder(model_path)
    else:
        if model_path:
            print(
                f"Model path {model_path} not found, using pretrained model instead"
            )
        print(f"Loading pretrained model: {pretrained_model}")
        model = Spotiflow.from_pretrained(pretrained_model)

    try:
        model = model.to(device)
        print(f"Model moved to device: {device}")
    except Exception as e:  # any failure here means the GPU is unusable
        if force_cpu:
            raise
        print(f"Failed to move model to GPU ({e}), falling back to CPU")
        device = torch.device("cpu")
        model = model.to(device)
        force_cpu = True

    # Any Z axis (ZYX, TZYX, ...) means volumetric data.
    if "Z" in axes and not model.config.is_3d:
        print(
            "Warning: Using a 2D model on 3D data. Consider using a 3D model like 'synth_3d' or 'smfish_3d'."
        )

    # --- Input preparation ----------------------------------------------------
    print(f"Preparing input with axes: {axes}")
    try:
        prepared_img = _prepare_input(image, axes)
        print(f"Prepared image shape: {prepared_img.shape}")
    except ValueError as e:
        print(f"Error preparing input: {e}")
        prepared_img = image  # best effort: hand the raw image to Spotiflow

    n_tiles_parsed = _parse_tuple_param(n_tiles, None)
    scale_parsed = _parse_tuple_param(scale, None)

    predict_kwargs = {
        "subpix": subpixel,  # Spotiflow API uses 'subpix', not 'subpixel'
        "peak_mode": peak_mode,
        "normalizer": None,  # normalisation is done manually below
        "exclude_border": exclude_border,
        "min_distance": min_distance,
        "verbose": True,
    }
    auto_threshold = not (prob_thresh is not None and prob_thresh > 0.0)
    if not auto_threshold:
        predict_kwargs["prob_thresh"] = prob_thresh
    if n_tiles_parsed is not None:
        predict_kwargs["n_tiles"] = n_tiles_parsed
    if scale_parsed is not None:
        predict_kwargs["scale"] = scale_parsed

    normalized_img = _normalize_image(
        prepared_img, normalizer, normalizer_low, normalizer_high
    )
    print(
        f"Normalized image range: {normalized_img.min():.3f} to {normalized_img.max():.3f}"
    )

    # --- Prediction (with one CPU retry on CUDA errors) -------------------------
    print("Running Spotiflow prediction...")
    try:
        points, details = model.predict(normalized_img, **predict_kwargs)
    except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
        if "CUDA" not in str(e) or force_cpu:
            raise
        print(f"CUDA error during prediction ({e}), retrying with CPU")
        model = model.to(torch.device("cpu"))
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        points, details = model.predict(normalized_img, **predict_kwargs)

    print(f"Initial detection: {len(points)} spots")

    # Safety net for runaway automatic thresholds only: an explicit user
    # threshold must not be silently replaced by 0.7.
    if auto_threshold and len(points) > 500:
        print(f"Applying additional filtering for {len(points)} spots")
        probs = getattr(details, "prob", None)
        if probs is not None and len(probs) == len(points):
            auto_thresh = 0.7
            points = points[np.asarray(probs) > auto_thresh]
            print(
                f"After additional probability thresholding ({auto_thresh}): {len(points)} spots"
            )

    print(f"Final detection: {len(points)} spots")
    return points
808
+
809
+
810
def _detect_spots_env(
    image,
    axes,
    pretrained_model,
    model_path,
    subpixel,
    peak_mode,
    normalizer,
    normalizer_low,
    normalizer_high,
    prob_thresh,
    n_tiles,
    exclude_border,
    scale,
    min_distance,
    force_cpu,
):
    """
    Detect spots by delegating to the dedicated Spotiflow environment.

    All arguments are forwarded, keyed by name, to
    ``run_spotiflow_in_env("detect_spots", ...)``. They have the same meaning
    as in ``spotiflow_detect_spots``. The ``"points"`` entry of the result is
    coerced to an ndarray, because the label-mask and CSV helpers in this
    module index ``points.shape``.

    Returns
    -------
    np.ndarray
        Detected spot coordinates, one row per spot.
    """
    args_dict = {
        "image": image,
        "axes": axes,
        "pretrained_model": pretrained_model,
        "model_path": model_path,
        "subpixel": subpixel,
        "peak_mode": peak_mode,
        "normalizer": normalizer,
        "normalizer_low": normalizer_low,
        "normalizer_high": normalizer_high,
        "prob_thresh": prob_thresh,
        "n_tiles": n_tiles,
        "exclude_border": exclude_border,
        "scale": scale,
        "min_distance": min_distance,
        "force_cpu": force_cpu,
    }

    result = run_spotiflow_in_env("detect_spots", args_dict)

    # NOTE(review): the container type of result["points"] is decided by
    # spotiflow_env_manager (not visible here). np.asarray is a no-op for
    # arrays and makes lists usable by the shape-indexing helpers.
    points = np.asarray(result["points"])
    if points.ndim == 1 and points.size == 0:
        # An empty list has no column axis; give it one so .shape[1] works.
        points = points.reshape(0, 3 if "Z" in axes else 2)

    print(f"Detected {len(points)} spots")
    return points
852
+
853
+
854
def _save_coords_csv(
    points: np.ndarray, input_file_path: str, use_env: bool = False
):
    """Write ``points`` to ``<input stem>_spots.csv`` beside the input file.

    Does nothing when ``input_file_path`` is empty or None. The actual writing
    is done by ``_save_coords_csv_env`` when ``use_env`` is true, otherwise by
    ``_save_coords_csv_direct``.
    """
    if not input_file_path:
        return

    from pathlib import Path

    source = Path(input_file_path)
    target = source.parent / f"{source.stem}_spots.csv"

    writer = _save_coords_csv_env if use_env else _save_coords_csv_direct
    writer(points, str(target))
873
+
874
+
875
def _save_coords_csv_direct(points: np.ndarray, csv_path: str):
    """Write spot coordinates to ``csv_path`` in the current environment.

    Spotiflow's ``write_coords_csv`` is preferred. If it cannot be imported,
    pandas writes the file with columns ``y, x`` (2 columns) or ``z, y, x``.
    """
    try:
        from spotiflow.utils import write_coords_csv

        write_coords_csv(points, csv_path)
    except ImportError:
        import pandas as pd

        header = ["z", "y", "x"] if points.shape[1] != 2 else ["y", "x"]
        pd.DataFrame(points, columns=header).to_csv(csv_path, index=False)
        print(
            f"Saved {len(points)} spot coordinates to {csv_path} (fallback method)"
        )
    else:
        print(f"Saved {len(points)} spot coordinates to {csv_path}")
892
+
893
+
894
# Script run by the dedicated environment's interpreter. Paths are passed as
# command-line arguments instead of being interpolated into the source, so
# Windows backslashes, quotes or braces in paths cannot break it.
_SAVE_CSV_SCRIPT = """\
import sys

import numpy as np
from spotiflow.utils import write_coords_csv

points_path, csv_path = sys.argv[1], sys.argv[2]
points = np.load(points_path)
write_coords_csv(points, csv_path)
print(f"Saved {len(points)} spot coordinates to {csv_path}")
"""


def _save_coords_csv_env(points: np.ndarray, csv_path: str):
    """Write spot coordinates to ``csv_path`` using the dedicated environment.

    ``points`` is stored in a temporary ``.npy`` file, and the script above
    runs under ``get_env_python_path()``. Temporary files are always removed,
    including when the subprocess fails. On failure, the subprocess's stderr
    is printed and ``subprocess.CalledProcessError`` is re-raised.
    """
    import contextlib
    import os
    import subprocess
    import tempfile

    from napari_tmidas.processing_functions.spotiflow_env_manager import (
        get_env_python_path,
    )

    temp_paths = []
    try:
        # Write through the open handle and close it before the child process
        # reads it (reopening an open temp file is not portable on Windows).
        with tempfile.NamedTemporaryFile(
            suffix=".npy", delete=False
        ) as points_file:
            temp_paths.append(points_file.name)
            np.save(points_file, points)

        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as script_file:
            temp_paths.append(script_file.name)
            script_file.write(_SAVE_CSV_SCRIPT)

        env_python = get_env_python_path()
        try:
            result = subprocess.run(
                [env_python, script_file.name, points_file.name, str(csv_path)],
                check=True,
                capture_output=True,
                text=True,
            )
        except subprocess.CalledProcessError as err:
            print(
                f"Failed to save spot coordinates in dedicated environment:\n{err.stderr}"
            )
            raise

        print(result.stdout)
    finally:
        # Unlink each file independently so one missing file cannot leak the other.
        for path in temp_paths:
            with contextlib.suppress(FileNotFoundError):
                os.unlink(path)
946
+
947
+
948
# Convenience alias: ``spotiflow_spot_detection`` refers to the very same
# object that is bound to ``spotiflow_detect_spots``.
spotiflow_spot_detection = spotiflow_detect_spots