napari-tmidas 0.2.1__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- napari_tmidas/__init__.py +35 -5
- napari_tmidas/_crop_anything.py +1458 -499
- napari_tmidas/_env_manager.py +76 -0
- napari_tmidas/_file_conversion.py +1646 -1131
- napari_tmidas/_file_selector.py +1464 -223
- napari_tmidas/_label_inspection.py +83 -8
- napari_tmidas/_processing_worker.py +309 -0
- napari_tmidas/_reader.py +6 -10
- napari_tmidas/_registry.py +15 -14
- napari_tmidas/_roi_colocalization.py +1221 -84
- napari_tmidas/_tests/test_crop_anything.py +123 -0
- napari_tmidas/_tests/test_env_manager.py +89 -0
- napari_tmidas/_tests/test_file_selector.py +90 -0
- napari_tmidas/_tests/test_grid_view_overlay.py +193 -0
- napari_tmidas/_tests/test_init.py +98 -0
- napari_tmidas/_tests/test_intensity_label_filter.py +222 -0
- napari_tmidas/_tests/test_label_inspection.py +86 -0
- napari_tmidas/_tests/test_processing_basic.py +500 -0
- napari_tmidas/_tests/test_processing_worker.py +142 -0
- napari_tmidas/_tests/test_regionprops_analysis.py +547 -0
- napari_tmidas/_tests/test_registry.py +135 -0
- napari_tmidas/_tests/test_scipy_filters.py +168 -0
- napari_tmidas/_tests/test_skimage_filters.py +259 -0
- napari_tmidas/_tests/test_split_channels.py +217 -0
- napari_tmidas/_tests/test_spotiflow.py +87 -0
- napari_tmidas/_tests/test_tyx_display_fix.py +142 -0
- napari_tmidas/_tests/test_ui_utils.py +68 -0
- napari_tmidas/_tests/test_widget.py +30 -0
- napari_tmidas/_tests/test_windows_basic.py +66 -0
- napari_tmidas/_ui_utils.py +57 -0
- napari_tmidas/_version.py +16 -3
- napari_tmidas/_widget.py +41 -4
- napari_tmidas/processing_functions/basic.py +557 -20
- napari_tmidas/processing_functions/careamics_env_manager.py +72 -99
- napari_tmidas/processing_functions/cellpose_env_manager.py +415 -112
- napari_tmidas/processing_functions/cellpose_segmentation.py +132 -191
- napari_tmidas/processing_functions/colocalization.py +513 -56
- napari_tmidas/processing_functions/grid_view_overlay.py +703 -0
- napari_tmidas/processing_functions/intensity_label_filter.py +422 -0
- napari_tmidas/processing_functions/regionprops_analysis.py +1280 -0
- napari_tmidas/processing_functions/sam2_env_manager.py +53 -69
- napari_tmidas/processing_functions/sam2_mp4.py +274 -195
- napari_tmidas/processing_functions/scipy_filters.py +403 -8
- napari_tmidas/processing_functions/skimage_filters.py +424 -212
- napari_tmidas/processing_functions/spotiflow_detection.py +949 -0
- napari_tmidas/processing_functions/spotiflow_env_manager.py +591 -0
- napari_tmidas/processing_functions/timepoint_merger.py +334 -86
- napari_tmidas/processing_functions/trackastra_tracking.py +24 -5
- {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/METADATA +92 -39
- napari_tmidas-0.2.4.dist-info/RECORD +63 -0
- napari_tmidas/_tests/__init__.py +0 -0
- napari_tmidas-0.2.1.dist-info/RECORD +0 -38
- {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/WHEEL +0 -0
- {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/entry_points.txt +0 -0
- {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,949 @@

# processing_functions/spotiflow_detection.py
"""
Processing functions for spot detection using Spotiflow.

This module provides functionality to detect spots in fluorescence microscopy images
using Spotiflow models. It supports both 2D and 3D data with various pretrained models.

The functions will automatically create and manage a dedicated environment for Spotiflow
if it's not already installed in the main environment.
"""
import os

import numpy as np

from napari_tmidas._registry import BatchProcessingRegistry

# Import the environment manager for Spotiflow
from napari_tmidas.processing_functions.spotiflow_env_manager import (
    run_spotiflow_in_env,
)


# Utility functions for axes and input preparation (from napari-spotiflow)
def _validate_axes(img: np.ndarray, axes: str) -> None:
    """Validate that the number of dimensions in the image matches the given axes string."""
    if img.ndim != len(axes):
        raise ValueError(
            f"Image has {img.ndim} dimensions, but axes has {len(axes)} dimensions"
        )


def _prepare_input(img: np.ndarray, axes: str) -> np.ndarray:
    """Reshape input for Spotiflow's API compatibility based on axes notation."""
    _validate_axes(img, axes)

    if axes in {"YX", "ZYX", "TYX", "TZYX"}:
        return img[..., None]
    elif axes in {"YXC", "ZYXC", "TYXC", "TZYXC"}:
        return img
    elif axes == "CYX":
        return img.transpose(1, 2, 0)
    elif axes == "CZYX":
        return img.transpose(1, 2, 3, 0)
    elif axes in {"ZCYX", "TCYX"}:
        return img.transpose(0, 2, 3, 1)
    elif axes == "TZCYX":
        return img.transpose(0, 1, 3, 4, 2)
    elif axes == "TCZYX":
        return img.transpose(0, 2, 3, 4, 1)
    else:
        raise ValueError(f"Invalid axes: {axes}")


def _infer_axes(img: np.ndarray) -> str:
    """Infer the most likely axes order for the image."""
    ndim = img.ndim
    if ndim == 2:
        return "YX"
    elif ndim == 3:
        # For 3D, we need to make an educated guess
        # Most common is ZYX for 3D microscopy
        return "ZYX"
    elif ndim == 4:
        # Could be TZYX or ZYXC, let's check the last dimension
        if img.shape[-1] <= 4:  # Likely channels
            return "ZYXC"
        else:
            return "TZYX"
    elif ndim == 5:
        return "TZYXC"
    else:
        raise ValueError(f"Cannot infer axes for {ndim}D image")
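

# Minimal sketch of the axes helpers above (the shapes are illustrative
# assumptions): "ZYX" data gains a trailing channel axis, while channels-first
# "CYX" data is transposed to channels-last.
def _example_axes_handling():
    stack = np.zeros((5, 64, 64), dtype=np.float32)  # small synthetic Z-stack
    assert _infer_axes(stack) == "ZYX"  # 3D arrays default to ZYX
    assert _prepare_input(stack, "ZYX").shape == (5, 64, 64, 1)
    two_channel = np.zeros((2, 64, 64), dtype=np.float32)
    assert _prepare_input(two_channel, "CYX").shape == (64, 64, 2)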


# Check if Spotiflow is directly available in current environment
try:
    import importlib.util

    spec = importlib.util.find_spec("spotiflow.model")
    if spec is not None:
        SPOTIFLOW_AVAILABLE = True
        USE_DEDICATED_ENV = False
        print("Spotiflow found in current environment, using direct import")
    else:
        raise ImportError("Spotiflow not found")
except ImportError:
    SPOTIFLOW_AVAILABLE = False
    USE_DEDICATED_ENV = True
    print(
        "Spotiflow not found in current environment, will use dedicated environment"
    )


def _convert_points_to_labels_with_heatmap(
    image: np.ndarray,
    points: np.ndarray,
    spot_radius: int,
    pretrained_model: str,
    model_path: str,
    prob_thresh: float,
    force_cpu: bool,
) -> np.ndarray:
    """
    Convert points to label masks using Spotiflow's probability heatmap for better segmentation.
    """
    try:
        import torch
        from scipy.ndimage import label
        from skimage.segmentation import watershed
        from spotiflow.model import Spotiflow

        # Set device
        if force_cpu:
            device = torch.device("cpu")
        else:
            device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )

        # Load the model (reuse existing model loading logic)
        if model_path and os.path.exists(model_path):
            model = Spotiflow.from_folder(model_path)
        else:
            model = Spotiflow.from_pretrained(pretrained_model)

        model = model.to(device)

        # Prepare input (reuse existing logic)
        axes = _infer_axes(image)
        prepared_img = _prepare_input(image, axes)

        # Normalize (simple percentile normalization)
        p_low, p_high = np.percentile(prepared_img, [1.0, 99.8])
        normalized_img = np.clip(
            (prepared_img - p_low) / (p_high - p_low), 0, 1
        )

        # Get prediction with details
        points_new, details = model.predict(
            normalized_img,
            prob_thresh=prob_thresh,
            device=device,
            verbose=False,
        )

        # Use probability heatmap for segmentation
        if hasattr(details, "heatmap") and details.heatmap is not None:
            prob_map = details.heatmap

            # Apply threshold to create binary mask
            threshold = prob_thresh if prob_thresh is not None else 0.4
            binary_mask = prob_map > threshold

            # Use detected points as seeds for watershed segmentation
            if len(points) > 0:
                # Create marker image from detected points
                markers = np.zeros(prob_map.shape, dtype=np.int32)
                for i, point in enumerate(points):
                    if len(point) >= 2:
                        y, x = int(point[0]), int(point[1])
                        if (
                            0 <= y < markers.shape[0]
                            and 0 <= x < markers.shape[1]
                        ):
                            markers[y, x] = i + 1

                # Apply watershed segmentation using probability map and markers
                labels = watershed(-prob_map, markers, mask=binary_mask)
            else:
                # No points detected, just label connected components
                labels, _ = label(binary_mask)

            return labels.astype(np.uint16)
        else:
            # Fallback to point-based method
            return _points_to_label_mask(points, image.shape[:2], spot_radius)

    except (ImportError, RuntimeError, ValueError, AttributeError) as e:
        print(f"Error in heatmap-based conversion: {e}")
        # Fallback to point-based method
        return _points_to_label_mask(points, image.shape[:2], spot_radius)
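

# Illustrative sketch of the seeded-watershed step used above: markers placed
# at the detected points flood a thresholded probability map, so touching
# spots are split along probability valleys. Purely synthetic 2D data.
def _example_heatmap_watershed():
    from skimage.segmentation import watershed

    yy, xx = np.mgrid[0:40, 0:40]
    # Two Gaussian "spots" with overlapping tails
    prob = np.exp(-((yy - 15) ** 2 + (xx - 15) ** 2) / 20.0) + np.exp(
        -((yy - 15) ** 2 + (xx - 25) ** 2) / 20.0
    )
    markers = np.zeros(prob.shape, dtype=np.int32)
    markers[15, 15], markers[15, 25] = 1, 2  # seeds at the detected points
    labels = watershed(-prob, markers, mask=prob > 0.3)
    assert labels.max() == 2  # the two spots remain separate objects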


@BatchProcessingRegistry.register(
    name="Spotiflow Spot Detection",
    suffix="_spot_labels",
    description="Detect spots in fluorescence microscopy images using Spotiflow and return as label masks",
    parameters={
        "pretrained_model": {
            "type": str,
            "default": "general",
            "description": "Pretrained model to use (general, hybiss, synth_complex, synth_3d, smfish_3d)",
            "choices": [
                "general",
                "hybiss",
                "synth_complex",
                "synth_3d",
                "smfish_3d",
            ],
        },
        "model_path": {
            "type": str,
            "default": "",
            "description": "Path to custom trained model folder (leave empty to use pretrained model)",
        },
        "subpixel": {
            "type": bool,
            "default": True,
            "description": "Enable subpixel localization for more accurate spot coordinates",
        },
        "peak_mode": {
            "type": str,
            "default": "fast",
            "description": "Peak detection mode",
            "choices": ["fast", "skimage"],
        },
        "normalizer": {
            "type": str,
            "default": "percentile",
            "description": "Image normalization method",
            "choices": ["percentile", "minmax"],
        },
        "normalizer_low": {
            "type": float,
            "default": 1.0,
            "min": 0.0,
            "max": 50.0,
            "description": "Lower percentile for normalization",
        },
        "normalizer_high": {
            "type": float,
            "default": 99.8,
            "min": 50.0,
            "max": 100.0,
            "description": "Upper percentile for normalization",
        },
        "prob_thresh": {
            "type": float,
            "default": None,
            "min": 0.0,
            "max": 1.0,
            "description": "Probability threshold (leave empty or 0.0 for automatic)",
        },
        "n_tiles": {
            "type": str,
            "default": "auto",
            "description": "Number of tiles for prediction (e.g., '(2,2)' or 'auto')",
        },
        "exclude_border": {
            "type": bool,
            "default": True,
            "description": "Exclude spots near image borders",
        },
        "scale": {
            "type": str,
            "default": "auto",
            "description": "Scaling factor (e.g., '(1,1)' or 'auto')",
        },
        "min_distance": {
            "type": int,
            "default": 2,
            "min": 1,
            "max": 10,
            "description": "Minimum distance between detected spots",
        },
        "spot_radius": {
            "type": int,
            "default": 3,
            "min": 1,
            "max": 20,
            "description": "Radius of spots in the label mask (in pixels, used for fallback method)",
        },
        "axes": {
            "type": str,
            "default": "auto",
            "description": "Axes order (e.g., 'ZYX', 'YX', or 'auto' for automatic detection)",
        },
        "output_csv": {
            "type": bool,
            "default": True,
            "description": "Save spot coordinates as CSV file alongside the mask",
        },
        "force_dedicated_env": {
            "type": bool,
            "default": False,
            "description": "Force using dedicated environment even if Spotiflow is available",
        },
        "force_cpu": {
            "type": bool,
            "default": False,
            "description": "Force CPU execution (disable GPU) to avoid CUDA compatibility issues",
        },
    },
)
def spotiflow_detect_spots(
    image: np.ndarray,
    pretrained_model: str = "general",
    model_path: str = "",
    subpixel: bool = True,
    peak_mode: str = "fast",
    normalizer: str = "percentile",
    normalizer_low: float = 1.0,
    normalizer_high: float = 99.8,
    prob_thresh: float = None,
    n_tiles: str = "auto",
    exclude_border: bool = True,
    scale: str = "auto",
    min_distance: int = 2,
    spot_radius: int = 3,
    axes: str = "auto",
    output_csv: bool = True,
    force_dedicated_env: bool = False,
    force_cpu: bool = False,
    # For internal use by the processing system
    input_file_path: str = None,
) -> np.ndarray:
    """
    Detect spots in fluorescence microscopy images using Spotiflow and return label masks.

    Spotiflow is a deep learning-based spot detection method that provides
    threshold-agnostic, subpixel-accurate detection of spots in 2D and 3D
    fluorescence microscopy images. The output is a label mask suitable for
    napari Labels layers.

    Parameters:
    -----------
    image : np.ndarray
        Input image (2D or 3D)
    pretrained_model : str
        Pretrained model to use ('general', 'hybiss', 'synth_complex', 'synth_3d', 'smfish_3d')
    model_path : str
        Path to custom trained model folder (overrides pretrained_model if provided)
    subpixel : bool
        Enable subpixel localization
    peak_mode : str
        Peak detection mode ('fast' or 'skimage')
    normalizer : str
        Image normalization method ('percentile' or 'minmax')
    normalizer_low : float
        Lower percentile for normalization
    normalizer_high : float
        Upper percentile for normalization
    prob_thresh : float or None
        Probability threshold (None for automatic)
    n_tiles : str
        Number of tiles for prediction (e.g., '(2,2)' or 'auto')
    exclude_border : bool
        Exclude spots near image borders
    scale : str
        Scaling factor (e.g., '(1,1)' or 'auto')
    min_distance : int
        Minimum distance between detected spots
    spot_radius : int
        Radius of spots in the label mask (in pixels, used for fallback method)
    axes : str
        Axes order (e.g., 'ZYX', 'YX', or 'auto' for automatic detection)
    output_csv : bool
        Save spot coordinates as CSV file alongside the mask
    force_dedicated_env : bool
        Force using dedicated environment
    force_cpu : bool
        Force CPU execution (disable GPU) to avoid CUDA compatibility issues
    input_file_path : str
        Path to input file (used for saving CSV output)

    Returns:
    --------
    np.ndarray
        Label mask with detected spots (uint16) for napari Labels layer
    """
    print("Detecting spots using Spotiflow...")
    print(f"Image shape: {image.shape}")
    print(f"Image dtype: {image.dtype}")

    # Infer axes if auto
    if axes == "auto":
        axes = _infer_axes(image)
        print(f"Inferred axes: {axes}")
    else:
        print(f"Using provided axes: {axes}")

    # Decide whether to use the dedicated environment
    use_env = USE_DEDICATED_ENV or force_dedicated_env

    if not use_env and SPOTIFLOW_AVAILABLE:
        # Use direct import
        points = _detect_spots_direct(
            image,
            axes,
            pretrained_model,
            model_path,
            subpixel,
            peak_mode,
            normalizer,
            normalizer_low,
            normalizer_high,
            prob_thresh,
            n_tiles,
            exclude_border,
            scale,
            min_distance,
            force_cpu,
        )
    else:
        # Use dedicated environment
        points = _detect_spots_env(
            image,
            axes,
            pretrained_model,
            model_path,
            subpixel,
            peak_mode,
            normalizer,
            normalizer_low,
            normalizer_high,
            prob_thresh,
            n_tiles,
            exclude_border,
            scale,
            min_distance,
            force_cpu,
        )

    # Save CSV if requested (skipped when no input file path is available)
    if output_csv:
        if input_file_path:
            _save_coords_csv(points, input_file_path, use_env)
        else:
            print(
                "No input file path provided, skipping CSV export of spot coordinates."
            )

    # Convert points to label masks (the simple point-based method is used
    # here; the heatmap-based variant above is kept as an alternative)
    print(f"Detected {len(points)} spots, converting to label masks...")

    label_mask = _points_to_label_mask(points, image.shape, spot_radius)

    print(
        f"Created label mask with {len(np.unique(label_mask)) - 1} labeled objects"
    )
    return label_mask
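

# Minimal usage sketch (illustrative: the file name and keyword choices below
# are assumptions, not fixed values):
#
#     import tifffile
#     img = tifffile.imread("spots.tif")  # single-channel YX image
#     mask = spotiflow_detect_spots(
#         img,
#         pretrained_model="general",
#         prob_thresh=None,             # let Spotiflow pick the threshold
#         force_cpu=True,               # sidestep CUDA issues if needed
#         input_file_path="spots.tif",  # enables the *_spots.csv export
#     )
#     # mask is uint16; view it in napari with viewer.add_labels(mask)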


def _points_to_label_mask(
    points: np.ndarray, image_shape: tuple, spot_radius: int
) -> np.ndarray:
    """Convert detected points to a label mask for napari."""
    from scipy import ndimage
    from skimage import draw

    # Create empty label mask with the same shape as input image
    label_mask = np.zeros(image_shape, dtype=np.uint16)

    # Handle different dimensionalities - focus on spatial dimensions
    spatial_dims = len(image_shape)
    if spatial_dims >= 4:  # TZYX, TZYXC, etc.
        if image_shape[-1] <= 4:  # Last dim is channels
            spatial_shape = image_shape[-4:-1]  # Take ZYX (skip channels)
        else:
            spatial_shape = image_shape[-3:]  # Take last 3 dims (ZYX)
    elif spatial_dims == 3:  # ZYX or YXC
        # Check if last dimension is small (likely channels)
        if image_shape[-1] <= 4:
            spatial_shape = image_shape[:2]  # YX (with channels)
        else:
            spatial_shape = image_shape  # ZYX
    else:  # 2D: YX
        spatial_shape = image_shape  # YX

    if len(points) == 0:
        return label_mask

    # Check coordinate format and swap if necessary
    if points.shape[1] == 2:  # 2D points (y, x)
        coords = points.astype(int)
    elif points.shape[1] == 3:  # 3D points - need to figure out the format
        # Try to determine the correct coordinate mapping based on spatial shape
        if len(spatial_shape) == 2:  # Working with 2D spatial data
            # If dim1 and dim2 fit in image bounds, assume (z, y, x)
            if (
                points[:, 1].max() < spatial_shape[0]
                and points[:, 2].max() < spatial_shape[1]
            ):
                coords = points[:, 1:3].astype(int)  # Take y, x (skip z)
            # If dim0 and dim2 fit in image bounds, assume (y, z, x)
            elif (
                points[:, 0].max() < spatial_shape[0]
                and points[:, 2].max() < spatial_shape[1]
            ):
                coords = points[:, [0, 2]].astype(int)  # Take y, x (skip z)
            # If dim0 and dim1 fit in image bounds, assume (y, x, z)
            elif (
                points[:, 0].max() < spatial_shape[0]
                and points[:, 1].max() < spatial_shape[1]
            ):
                coords = points[:, 0:2].astype(int)  # Take y, x (skip z)
            else:
                # Try swapping coordinates - maybe it's (x, y) instead of (y, x)
                coords = points[:, [1, 0]].astype(int)
        else:  # Working with 3D spatial data
            coords = points.astype(int)  # Use all 3 coordinates
    else:
        raise ValueError(f"Unexpected points shape: {points.shape}")

    # Create spots based on spatial dimensions
    valid_spots = 0

    if len(spatial_shape) == 2:  # 2D spatial
        for i, (y, x) in enumerate(coords):
            if 0 <= y < spatial_shape[0] and 0 <= x < spatial_shape[1]:
                try:
                    rr, cc = draw.disk(
                        (y, x), spot_radius, shape=spatial_shape
                    )
                    # Handle different label mask shapes
                    if len(image_shape) == 2:  # Pure 2D
                        label_mask[rr, cc] = i + 1
                    elif len(image_shape) == 3:  # 2D with channels or 3D
                        if image_shape[-1] <= 4:  # Likely channels
                            label_mask[rr, cc, :] = (
                                i + 1
                            )  # Apply to all channels
                        else:  # 3D data - apply to all Z slices
                            label_mask[:, rr, cc] = i + 1
                    elif len(image_shape) == 4:  # TZYX or similar
                        label_mask[:, :, rr, cc] = (
                            i + 1
                        )  # Apply to all T and Z
                    elif len(image_shape) == 5:  # TZYXC
                        label_mask[:, :, rr, cc, :] = (
                            i + 1
                        )  # Apply to all T, Z, and C

                    valid_spots += 1
                except (ValueError, IndexError, TypeError) as e:
                    print(f"Error drawing spot {i} at ({y}, {x}): {e}")

    elif len(spatial_shape) == 3:  # 3D spatial
        # For 3D spatial, we need 3D coordinates
        if coords.shape[1] == 2:
            # We have 2D points but need 3D - place them in the middle Z slice
            middle_z = spatial_shape[0] // 2
            coords_3d = np.column_stack(
                [np.full(len(coords), middle_z), coords]
            )
        else:
            coords_3d = coords

        for i, (z, y, x) in enumerate(coords_3d):
            if (
                0 <= z < spatial_shape[0]
                and 0 <= y < spatial_shape[1]
                and 0 <= x < spatial_shape[2]
            ):
                try:
                    # Create a small sphere
                    ball = ndimage.generate_binary_structure(3, 1)
                    ball = ndimage.iterate_structure(ball, spot_radius)

                    # Get sphere coordinates
                    ball_coords = np.array(np.where(ball)).T - spot_radius
                    z_coords = ball_coords[:, 0] + z
                    y_coords = ball_coords[:, 1] + y
                    x_coords = ball_coords[:, 2] + x

                    # Filter valid coordinates
                    valid = (
                        (z_coords >= 0)
                        & (z_coords < spatial_shape[0])
                        & (y_coords >= 0)
                        & (y_coords < spatial_shape[1])
                        & (x_coords >= 0)
                        & (x_coords < spatial_shape[2])
                    )

                    # Handle different label mask shapes
                    if len(image_shape) == 3:  # Pure 3D
                        label_mask[
                            z_coords[valid], y_coords[valid], x_coords[valid]
                        ] = (i + 1)
                    elif len(image_shape) == 4:  # TZYX or ZYXC
                        if image_shape[-1] <= 4:  # ZYXC
                            label_mask[
                                z_coords[valid],
                                y_coords[valid],
                                x_coords[valid],
                                :,
                            ] = (i + 1)
                        else:  # TZYX
                            label_mask[
                                :,
                                z_coords[valid],
                                y_coords[valid],
                                x_coords[valid],
                            ] = (i + 1)
                    elif len(image_shape) == 5:  # TZYXC
                        label_mask[
                            :,
                            z_coords[valid],
                            y_coords[valid],
                            x_coords[valid],
                            :,
                        ] = (i + 1)

                    valid_spots += 1
                except (ValueError, IndexError, TypeError) as e:
                    print(f"Error drawing 3D spot {i} at ({z}, {y}, {x}): {e}")

    print(
        f"Successfully created {valid_spots} spots in label mask with shape {label_mask.shape}"
    )
    return label_mask
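

# Sketch of the fallback conversion above: three (y, x) points on a 64x64
# image become three labeled disks (all sizes here are assumptions).
def _example_points_to_mask():
    pts = np.array([[10, 10], [30, 30], [50, 12]], dtype=float)
    mask = _points_to_label_mask(pts, (64, 64), spot_radius=3)
    assert mask.shape == (64, 64)
    assert set(np.unique(mask)) == {0, 1, 2, 3}  # background + 3 spots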


def _detect_spots_direct(
    image,
    axes,
    pretrained_model,
    model_path,
    subpixel,
    peak_mode,
    normalizer,
    normalizer_low,
    normalizer_high,
    prob_thresh,
    n_tiles,
    exclude_border,
    scale,
    min_distance,
    force_cpu,
):
    """Direct implementation using imported Spotiflow."""
    import ast

    import torch
    from spotiflow.model import Spotiflow

    # Set device based on force_cpu parameter
    if force_cpu:
        print("Forcing CPU execution as requested")
        device = torch.device("cpu")
        # Set environment variable to ensure CPU usage
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    else:
        # Use CUDA if available and compatible
        if torch.cuda.is_available():
            try:
                # Test CUDA compatibility by creating a small tensor
                torch.ones(1).cuda()
                device = torch.device("cuda")
                print("Using CUDA (GPU) for inference")
            except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
                print(f"CUDA incompatible ({e}), falling back to CPU")
                device = torch.device("cpu")
                force_cpu = True
        else:
            print("CUDA not available, using CPU")
            device = torch.device("cpu")
            force_cpu = True

    # Load the model
    if model_path and os.path.exists(model_path):
        print(f"Loading custom model from {model_path}")
        model = Spotiflow.from_folder(model_path)
    else:
        print(f"Loading pretrained model: {pretrained_model}")
        model = Spotiflow.from_pretrained(pretrained_model)

    # Move model to the appropriate device
    try:
        model = model.to(device)
        print(f"Model moved to device: {device}")
    except Exception as e:
        if not force_cpu:
            print(f"Failed to move model to GPU ({e}), falling back to CPU")
            device = torch.device("cpu")
            model = model.to(device)
        else:
            raise

    # Check model compatibility with image dimensionality
    is_3d_image = len(image.shape) == 3 and "Z" in axes
    if is_3d_image and not model.config.is_3d:
        print(
            "Warning: Using a 2D model on 3D data. Consider using a 3D model like 'synth_3d' or 'smfish_3d'."
        )

    # Prepare input using the same method as napari-spotiflow
    print(f"Preparing input with axes: {axes}")
    try:
        prepared_img = _prepare_input(image, axes)
        print(f"Prepared image shape: {prepared_img.shape}")
    except ValueError as e:
        print(f"Error preparing input: {e}")
        # Fall back to the original image
        prepared_img = image

    # Parse string parameters such as "(2,2)"
    # (ast.literal_eval only accepts literals, so arbitrary code in the
    # parameter string cannot be executed)
    def parse_param(param_str, default_val):
        if param_str == "auto":
            return default_val
        try:
            return (
                ast.literal_eval(param_str)
                if param_str.startswith("(")
                else param_str
            )
        except (ValueError, SyntaxError):
            return default_val

    n_tiles_parsed = parse_param(n_tiles, None)
    scale_parsed = parse_param(scale, None)

    # Prepare prediction parameters (following napari-spotiflow style)
    predict_kwargs = {
        "subpix": subpixel,  # Note: Spotiflow API uses 'subpix', not 'subpixel'
        "peak_mode": peak_mode,
        "normalizer": None,  # We'll handle normalization manually
        "exclude_border": exclude_border,
        "min_distance": min_distance,
        "verbose": True,
    }

    # Set the probability threshold only when one is provided; for None or
    # 0.0, leave it unset so Spotiflow uses its automatic, optimized threshold
    # (as napari-spotiflow does)
    if prob_thresh is not None and prob_thresh > 0.0:
        predict_kwargs["prob_thresh"] = prob_thresh

    if n_tiles_parsed is not None:
        predict_kwargs["n_tiles"] = n_tiles_parsed
    if scale_parsed is not None:
        predict_kwargs["scale"] = scale_parsed

    # Handle normalization manually (similar to napari-spotiflow)
    if normalizer == "percentile":
        print(
            f"Applying percentile normalization: {normalizer_low}% to {normalizer_high}%"
        )
        p_low, p_high = np.percentile(
            prepared_img, [normalizer_low, normalizer_high]
        )
        normalized_img = np.clip(
            (prepared_img - p_low) / (p_high - p_low), 0, 1
        )
    elif normalizer == "minmax":
        print("Applying min-max normalization")
        img_min, img_max = prepared_img.min(), prepared_img.max()
        normalized_img = (
            (prepared_img - img_min) / (img_max - img_min)
            if img_max > img_min
            else prepared_img
        )
    else:
        normalized_img = prepared_img

    print(
        f"Normalized image range: {normalized_img.min():.3f} to {normalized_img.max():.3f}"
    )

    # Perform spot detection
    print("Running Spotiflow prediction...")
    try:
        points, details = model.predict(normalized_img, **predict_kwargs)
    except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
        if "CUDA" in str(e) and not force_cpu:
            print(f"CUDA error during prediction ({e}), retrying with CPU")
            # Move model to CPU and retry
            device = torch.device("cpu")
            model = model.to(device)
            # Set environment to force CPU
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
            points, details = model.predict(normalized_img, **predict_kwargs)
        else:
            raise

    print(f"Initial detection: {len(points)} spots")

    # Only apply minimal additional filtering if we still have too many
    # detections; this should rarely be needed now that automatic
    # thresholding is used
    if len(points) > 500:  # Only if we have an excessive number of spots
        print(f"Applying additional filtering for {len(points)} spots")

        # Check if we can apply probability filtering
        if hasattr(details, "prob"):
            # Use a more stringent threshold
            auto_thresh = 0.7
            prob_mask = details.prob > auto_thresh
            points = points[prob_mask]
            print(
                f"After additional probability thresholding ({auto_thresh}): {len(points)} spots"
            )

    print(f"Final detection: {len(points)} spots")
    return points
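

# Sketch of the manual percentile normalization used in _detect_spots_direct:
# the low percentile maps to 0, the high percentile to 1, and the result is
# clipped (the percentiles below are the defaults; any range works).
def _example_percentile_normalization():
    rng = np.random.default_rng(0)
    img = rng.normal(100.0, 10.0, size=(64, 64)).astype(np.float32)
    p_low, p_high = np.percentile(img, [1.0, 99.8])
    norm = np.clip((img - p_low) / (p_high - p_low), 0, 1)
    assert 0.0 <= norm.min() and norm.max() <= 1.0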


def _detect_spots_env(
    image,
    axes,
    pretrained_model,
    model_path,
    subpixel,
    peak_mode,
    normalizer,
    normalizer_low,
    normalizer_high,
    prob_thresh,
    n_tiles,
    exclude_border,
    scale,
    min_distance,
    force_cpu,
):
    """Implementation using the dedicated environment."""
    # Prepare arguments for environment execution
    args_dict = {
        "image": image,
        "axes": axes,
        "pretrained_model": pretrained_model,
        "model_path": model_path,
        "subpixel": subpixel,
        "peak_mode": peak_mode,
        "normalizer": normalizer,
        "normalizer_low": normalizer_low,
        "normalizer_high": normalizer_high,
        "prob_thresh": prob_thresh,
        "n_tiles": n_tiles,
        "exclude_border": exclude_border,
        "scale": scale,
        "min_distance": min_distance,
        "force_cpu": force_cpu,
    }

    # Run in dedicated environment
    result = run_spotiflow_in_env("detect_spots", args_dict)

    print(f"Detected {len(result['points'])} spots")
    return result["points"]


def _save_coords_csv(
    points: np.ndarray, input_file_path: str, use_env: bool = False
):
    """Save coordinates to CSV using Spotiflow's write_coords_csv function."""
    if not input_file_path:
        return

    # Generate CSV filename based on the input file
    from pathlib import Path

    input_path = Path(input_file_path)
    csv_path = input_path.parent / (input_path.stem + "_spots.csv")

    if use_env:
        # Use dedicated environment
        _save_coords_csv_env(points, str(csv_path))
    else:
        # Use direct import
        _save_coords_csv_direct(points, str(csv_path))


def _save_coords_csv_direct(points: np.ndarray, csv_path: str):
    """Save coordinates directly using Spotiflow utils."""
    try:
        from spotiflow.utils import write_coords_csv

        write_coords_csv(points, csv_path)
        print(f"Saved {len(points)} spot coordinates to {csv_path}")
    except ImportError:
        # Fallback to basic CSV writing
        import pandas as pd

        columns = ["y", "x"] if points.shape[1] == 2 else ["z", "y", "x"]
        df = pd.DataFrame(points, columns=columns)
        df.to_csv(csv_path, index=False)
        print(
            f"Saved {len(points)} spot coordinates to {csv_path} (fallback method)"
        )


def _save_coords_csv_env(points: np.ndarray, csv_path: str):
    """Save coordinates using the dedicated environment."""
    import contextlib
    import subprocess
    import tempfile

    from napari_tmidas.processing_functions.spotiflow_env_manager import (
        get_env_python_path,
    )

    # Save points to a temporary numpy file
    with tempfile.NamedTemporaryFile(
        suffix=".npy", delete=False
    ) as temp_points:
        np.save(temp_points.name, points)

    # Create a script to save the CSV
    script = f"""
import numpy as np
from spotiflow.utils import write_coords_csv

# Load points
points = np.load('{temp_points.name}')

# Save CSV
write_coords_csv(points, '{csv_path}')
print(f"Saved {{len(points)}} spot coordinates to {csv_path}")
"""

    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False
    ) as script_file:
        script_file.write(script)
        script_file.flush()

    # Execute the script with the dedicated environment's Python
    env_python = get_env_python_path()
    result = subprocess.run(
        [env_python, script_file.name],
        check=True,
        capture_output=True,
        text=True,
    )

    print(result.stdout)

    # Clean up
    with contextlib.suppress(FileNotFoundError):
        os.unlink(temp_points.name)
        os.unlink(script_file.name)


# Alias for convenience
spotiflow_spot_detection = spotiflow_detect_spots
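

# Hypothetical round-trip of the CSV export (the directory is illustrative):
# three 2D points produce an "img_spots.csv" next to the input file, written
# by Spotiflow's writer if available, otherwise by the pandas fallback.
def _example_csv_fallback(out_dir: str = "/tmp"):
    pts = np.array([[10.5, 12.25], [30.0, 31.0], [50.0, 12.0]])
    _save_coords_csv(pts, os.path.join(out_dir, "img.tif"), use_env=False)
    # -> writes /tmp/img_spots.csv with columns y,x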