napari-tmidas 0.2.1__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- napari_tmidas/__init__.py +35 -5
- napari_tmidas/_crop_anything.py +1458 -499
- napari_tmidas/_env_manager.py +76 -0
- napari_tmidas/_file_conversion.py +1646 -1131
- napari_tmidas/_file_selector.py +1464 -223
- napari_tmidas/_label_inspection.py +83 -8
- napari_tmidas/_processing_worker.py +309 -0
- napari_tmidas/_reader.py +6 -10
- napari_tmidas/_registry.py +15 -14
- napari_tmidas/_roi_colocalization.py +1221 -84
- napari_tmidas/_tests/test_crop_anything.py +123 -0
- napari_tmidas/_tests/test_env_manager.py +89 -0
- napari_tmidas/_tests/test_file_selector.py +90 -0
- napari_tmidas/_tests/test_grid_view_overlay.py +193 -0
- napari_tmidas/_tests/test_init.py +98 -0
- napari_tmidas/_tests/test_intensity_label_filter.py +222 -0
- napari_tmidas/_tests/test_label_inspection.py +86 -0
- napari_tmidas/_tests/test_processing_basic.py +500 -0
- napari_tmidas/_tests/test_processing_worker.py +142 -0
- napari_tmidas/_tests/test_regionprops_analysis.py +547 -0
- napari_tmidas/_tests/test_registry.py +135 -0
- napari_tmidas/_tests/test_scipy_filters.py +168 -0
- napari_tmidas/_tests/test_skimage_filters.py +259 -0
- napari_tmidas/_tests/test_split_channels.py +217 -0
- napari_tmidas/_tests/test_spotiflow.py +87 -0
- napari_tmidas/_tests/test_tyx_display_fix.py +142 -0
- napari_tmidas/_tests/test_ui_utils.py +68 -0
- napari_tmidas/_tests/test_widget.py +30 -0
- napari_tmidas/_tests/test_windows_basic.py +66 -0
- napari_tmidas/_ui_utils.py +57 -0
- napari_tmidas/_version.py +16 -3
- napari_tmidas/_widget.py +41 -4
- napari_tmidas/processing_functions/basic.py +557 -20
- napari_tmidas/processing_functions/careamics_env_manager.py +72 -99
- napari_tmidas/processing_functions/cellpose_env_manager.py +415 -112
- napari_tmidas/processing_functions/cellpose_segmentation.py +132 -191
- napari_tmidas/processing_functions/colocalization.py +513 -56
- napari_tmidas/processing_functions/grid_view_overlay.py +703 -0
- napari_tmidas/processing_functions/intensity_label_filter.py +422 -0
- napari_tmidas/processing_functions/regionprops_analysis.py +1280 -0
- napari_tmidas/processing_functions/sam2_env_manager.py +53 -69
- napari_tmidas/processing_functions/sam2_mp4.py +274 -195
- napari_tmidas/processing_functions/scipy_filters.py +403 -8
- napari_tmidas/processing_functions/skimage_filters.py +424 -212
- napari_tmidas/processing_functions/spotiflow_detection.py +949 -0
- napari_tmidas/processing_functions/spotiflow_env_manager.py +591 -0
- napari_tmidas/processing_functions/timepoint_merger.py +334 -86
- napari_tmidas/processing_functions/trackastra_tracking.py +24 -5
- {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/METADATA +92 -39
- napari_tmidas-0.2.4.dist-info/RECORD +63 -0
- napari_tmidas/_tests/__init__.py +0 -0
- napari_tmidas-0.2.1.dist-info/RECORD +0 -38
- {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/WHEEL +0 -0
- {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/entry_points.txt +0 -0
- {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,591 @@
|
|
|
1
|
+
# processing_functions/spotiflow_env_manager.py
|
|
2
|
+
"""
|
|
3
|
+
This module manages a dedicated virtual environment for Spotiflow.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import contextlib
|
|
7
|
+
import os
|
|
8
|
+
import subprocess
|
|
9
|
+
import tempfile
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from napari_tmidas._env_manager import BaseEnvironmentManager
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
import tifffile
|
|
17
|
+
except ImportError:
|
|
18
|
+
tifffile = None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SpotiflowEnvironmentManager(BaseEnvironmentManager):
    """Environment manager for a dedicated Spotiflow virtual environment.

    Extends ``BaseEnvironmentManager`` (name ``"spotiflow"``) with the
    Spotiflow-specific install steps: PyTorch (CUDA build when a working
    GPU is detected, CPU build otherwise), the ``spotiflow`` package, and
    image-I/O helpers (``tifffile``, ``numpy``), followed by a subprocess
    verification of the resulting environment.
    """

    # Pinned PyTorch/torchvision versions used for every install path.
    # Hoisted here so the four pip invocations cannot drift apart.
    _TORCH_PACKAGES = ("torch==2.0.1", "torchvision==0.15.2")

    # Script run with the env's interpreter to confirm the CUDA build
    # actually works on the installed GPU (exits non-zero on failure).
    _CUDA_TEST_SCRIPT = """
import torch
try:
    if torch.cuda.is_available():
        test_tensor = torch.ones(1).cuda()
        print("CUDA compatibility test passed")
    else:
        print("CUDA not available in PyTorch")
        exit(1)
except Exception as e:
    print(f"CUDA compatibility test failed: {e}")
    exit(1)
"""

    def __init__(self):
        super().__init__("spotiflow")

    @staticmethod
    def _pip_install(env_python: str, *packages: str, index_url: str = None) -> None:
        """Run ``pip install`` for *packages* with the env's interpreter.

        Raises subprocess.CalledProcessError when pip fails.
        """
        cmd = [env_python, "-m", "pip", "install", *packages]
        if index_url:
            cmd += ["--index-url", index_url]
        subprocess.check_call(cmd)

    @staticmethod
    def _detect_cuda() -> bool:
        """Best-effort CUDA detection in the *main* environment.

        Prefers ``torch.cuda`` when PyTorch is importable; otherwise falls
        back to probing ``nvidia-smi``. Returns True when a usable NVIDIA
        GPU appears to be present.
        """
        try:
            import torch

            cuda_available = torch.cuda.is_available()
            if cuda_available:
                print("CUDA is available in main environment")
                # Try to get GPU info
                if torch.cuda.device_count() > 0:
                    gpu_name = torch.cuda.get_device_name(0)
                    print(f"GPU detected: {gpu_name}")
            else:
                print("CUDA is not available in main environment")
            return cuda_available
        except ImportError:
            print("PyTorch not detected in main environment")

        # Try to detect CUDA from nvidia-smi
        try:
            result = subprocess.run(
                ["nvidia-smi"], capture_output=True, text=True
            )
        except FileNotFoundError:
            print("nvidia-smi not found, assuming no CUDA support")
            return False
        if result.returncode == 0:
            print("NVIDIA GPU detected via nvidia-smi")
            return True
        print("No NVIDIA GPU detected")
        return False

    def _install_pytorch(self, env_python: str, cuda_available: bool) -> None:
        """Install pinned PyTorch, preferring CUDA with a CPU fallback.

        When CUDA was detected, installs the cu118 wheels (supports older
        architectures such as sm_61), then runs a compatibility test in the
        new environment; on test failure or install failure it reinstalls
        the CPU-only build instead.
        """
        if not cuda_available:
            # Install PyTorch without CUDA
            print("Installing PyTorch without CUDA support...")
            self._pip_install(env_python, *self._TORCH_PACKAGES)
            return

        # Try to install PyTorch with CUDA support, but with fallback to CPU-only
        print("Attempting PyTorch installation with CUDA support...")
        try:
            # CUDA 11.8 supports sm_61 (GTX 1080 Ti) and other older GPUs
            self._pip_install(
                env_python,
                *self._TORCH_PACKAGES,
                index_url="https://download.pytorch.org/whl/cu118",
            )
            print("✓ PyTorch with CUDA 11.8 installed successfully")

            # Test CUDA compatibility inside the new environment
            result = subprocess.run(
                [env_python, "-c", self._CUDA_TEST_SCRIPT],
                capture_output=True,
                text=True,
            )

            if result.returncode != 0:
                print(
                    "CUDA compatibility test failed, falling back to CPU-only PyTorch..."
                )
                # Uninstall CUDA version and install CPU version
                subprocess.check_call(
                    [
                        env_python,
                        "-m",
                        "pip",
                        "uninstall",
                        "-y",
                        "torch",
                        "torchvision",
                    ]
                )
                self._pip_install(env_python, *self._TORCH_PACKAGES)
                print(
                    "✓ Switched to CPU-only PyTorch due to CUDA incompatibility"
                )
            else:
                print("✓ CUDA compatibility test passed")

        except subprocess.CalledProcessError as e:
            print(f"CUDA PyTorch installation failed: {e}")
            print("Falling back to CPU-only PyTorch...")
            # Install PyTorch without CUDA
            self._pip_install(env_python, *self._TORCH_PACKAGES)
            print("✓ CPU-only PyTorch installed as fallback")

    def _install_dependencies(self, env_python: str) -> None:
        """Install Spotiflow-specific dependencies.

        Order matters: PyTorch first (Spotiflow depends on it), then
        Spotiflow itself, then image-handling helpers, then verification.
        """
        cuda_available = self._detect_cuda()
        self._install_pytorch(env_python, cuda_available)

        # Install Spotiflow with all dependencies
        print("Installing Spotiflow in the dedicated environment...")
        self._pip_install(env_python, "spotiflow")

        # Install additional dependencies for image handling
        self._pip_install(env_python, "tifffile", "numpy")

        # Check if installation was successful
        self._verify_installation(env_python)

    def _verify_installation(self, env_python: str) -> None:
        """Verify the Spotiflow installation by running a check script.

        Writes a small diagnostic script to a temp file, executes it with
        the environment's interpreter, and raises (CalledProcessError or
        RuntimeError) unless it reports SUCCESS. The temp file is removed
        in all cases.
        """
        check_script = """
import sys
try:
    import spotiflow
    print(f"Spotiflow version: {spotiflow.__version__}")
    from spotiflow.model import Spotiflow
    print("Spotiflow model imported successfully")
    import torch
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    print("SUCCESS: Spotiflow environment is working correctly")
except Exception as e:
    print(f"ERROR: {str(e)}")
    sys.exit(1)
"""
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False
        ) as temp:
            temp.write(check_script)
            temp_path = temp.name

        try:
            result = subprocess.run(
                [env_python, temp_path],
                check=True,
                capture_output=True,
                text=True,
            )
            print(result.stdout)
            if "SUCCESS" in result.stdout:
                print(
                    "Spotiflow environment created and verified successfully."
                )
            else:
                raise RuntimeError(
                    "Spotiflow environment verification failed."
                )
        except subprocess.CalledProcessError as e:
            print(f"Verification failed: {e.stderr}")
            raise
        finally:
            with contextlib.suppress(FileNotFoundError):
                os.unlink(temp_path)

    def is_package_installed(self) -> bool:
        """Check if spotiflow is installed in the current environment."""
        try:
            import importlib.util

            return importlib.util.find_spec("spotiflow") is not None
        except ImportError:
            return False
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
# Module-level singleton so the legacy function-style API below can delegate
# to one shared manager instance (kept for backward compatibility).
manager = SpotiflowEnvironmentManager()
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def is_spotiflow_installed():
    """Return True when the spotiflow package is importable from the
    current (main) Python environment, False otherwise."""
    return manager.is_package_installed()
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def is_env_created():
    """Return True when the dedicated Spotiflow environment already exists
    on disk, as reported by the shared manager."""
    return manager.is_env_created()
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def get_env_python_path():
    """Return the path of the Python interpreter inside the dedicated
    Spotiflow environment (delegates to the shared manager)."""
    return manager.get_env_python_path()
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def create_spotiflow_env():
    """Build the dedicated Spotiflow virtual environment via the shared
    manager and return whatever the manager's ``create_env`` returns."""
    return manager.create_env()
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def run_spotiflow_in_env(func_name, args_dict):
    """
    Run Spotiflow in a dedicated environment.

    Writes the input image to a temp .tif, generates a self-contained
    detection script (interpolating model/parameter choices from
    ``args_dict``), executes it with the dedicated environment's Python,
    and loads the detections back from a temp .npy file.

    Parameters:
    -----------
    func_name : str
        Name of the Spotiflow function to run
    args_dict : dict
        Dictionary of arguments for Spotiflow prediction

    Returns:
    --------
    numpy.ndarray or tuple
        Detection results (points coordinates and optionally heatmap/flow)

    Raises:
    -------
    subprocess.CalledProcessError
        If the detection script exits non-zero in the dedicated env.
    """
    # Ensure the environment exists
    if not is_env_created():
        create_spotiflow_env()

    # Prepare temporary files (delete=False: the subprocess reads/writes
    # them by path; they are unlinked explicitly at the end)
    with (
        tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as input_file,
        tempfile.NamedTemporaryFile(
            suffix=".npy", delete=False
        ) as output_file,
        tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False
        ) as script_file,
    ):

        # Save input image
        # NOTE(review): tifffile is imported under a try/except at module
        # level and may be None here — confirm callers guard against a
        # missing tifffile before reaching this point.
        tifffile.imwrite(input_file.name, args_dict["image"])

        # Prepare a temporary script to run Spotiflow.
        # Everything in single braces is interpolated NOW (paths, args);
        # double braces {{...}} survive as literal braces in the script.
        script = f"""
import numpy as np
import os
import sys
print("Starting Spotiflow detection script...")
print(f"Python version: {{sys.version}}")

try:
    from spotiflow.model import Spotiflow
    print("✓ Spotiflow model imported successfully")
except Exception as e:
    print(f"✗ Failed to import Spotiflow model: {{e}}")
    sys.exit(1)

try:
    import tifffile
    print("✓ tifffile imported successfully")
except Exception as e:
    print(f"✗ Failed to import tifffile: {{e}}")
    sys.exit(1)

try:
    # Load image
    print(f"Loading image from: {input_file.name}")
    image = tifffile.imread('{input_file.name}')
    print(f"✓ Image loaded successfully, shape: {{image.shape}}, dtype: {{image.dtype}}")
except Exception as e:
    print(f"✗ Failed to load image: {{e}}")
    sys.exit(1)

try:
    # Load the model
    if '{args_dict.get('model_path', '')}' and os.path.exists('{args_dict.get('model_path', '')}'):
        # Load custom model from folder
        print(f"Loading custom model from {args_dict.get('model_path', '')}")
        model = Spotiflow.from_folder('{args_dict.get('model_path', '')}')
    else:
        # Load pretrained model
        print(f"Loading pretrained model: {args_dict.get('pretrained_model', 'general')}")
        model = Spotiflow.from_pretrained('{args_dict.get('pretrained_model', 'general')}')
    print("✓ Model loaded successfully")

    # Handle device selection and force_cpu parameter
    import torch
    force_cpu = {args_dict.get('force_cpu', False)}

    if force_cpu:
        print("Forcing CPU execution as requested")
        device = torch.device("cpu")
        # Set environment variable to ensure CPU usage
        import os
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    else:
        # Use CUDA if available and compatible
        if torch.cuda.is_available():
            try:
                # Test CUDA compatibility by creating a small tensor
                test_tensor = torch.ones(1).cuda()
                device = torch.device("cuda")
                print("Using CUDA (GPU) for inference")
            except Exception as cuda_e:
                print(f"CUDA incompatible ({{cuda_e}}), falling back to CPU")
                device = torch.device("cpu")
                force_cpu = True
        else:
            print("CUDA not available, using CPU")
            device = torch.device("cpu")
            force_cpu = True

    # Move model to appropriate device
    try:
        model = model.to(device)
        print(f"Model moved to device: {{device}}")
    except Exception as device_e:
        if not force_cpu:
            print(f"Failed to move model to GPU ({{device_e}}), falling back to CPU")
            device = torch.device("cpu")
            model = model.to(device)
        else:
            raise

except Exception as e:
    print(f"✗ Failed to load model: {{e}}")
    sys.exit(1)

# Utility functions for input preparation
def _validate_axes(img, axes):
    if img.ndim != len(axes):
        raise ValueError(f"Image has {{img.ndim}} dimensions, but axes has {{len(axes)}} dimensions")

def _prepare_input(img, axes):
    _validate_axes(img, axes)
    if axes in {{"YX", "ZYX", "TYX", "TZYX"}}:
        return img[..., None]
    elif axes in {{"YXC", "ZYXC", "TYXC", "TZYXC"}}:
        return img
    elif axes == "CYX":
        return img.transpose(1, 2, 0)
    elif axes == "CZYX":
        return img.transpose(1, 2, 3, 0)
    elif axes == "ZCYX":
        return img.transpose(0, 2, 3, 1)
    elif axes == "TCYX":
        return img.transpose(0, 2, 3, 1)
    elif axes == "TZCYX":
        return img.transpose(0, 1, 3, 4, 2)
    elif axes == "TCZYX":
        return img.transpose(0, 2, 3, 4, 1)
    else:
        raise ValueError(f"Invalid axes: {{axes}}")

try:
    # Handle axes and input preparation
    axes = '{args_dict.get('axes', 'auto')}'
    if axes == 'auto':
        # Auto-infer axes
        ndim = image.ndim
        if ndim == 2:
            axes = "YX"
        elif ndim == 3:
            axes = "ZYX"
        elif ndim == 4:
            if image.shape[-1] <= 4:
                axes = "ZYXC"
            else:
                axes = "TZYX"
        elif ndim == 5:
            axes = "TZYXC"
        else:
            raise ValueError(f"Cannot infer axes for {{ndim}}D image")

    print(f"Using axes: {{axes}}")

    # Prepare input
    prepared_img = _prepare_input(image, axes)
    print(f"Prepared image shape: {{prepared_img.shape}}")

    # Check model compatibility
    is_3d_image = len(image.shape) == 3 and "Z" in axes
    if is_3d_image and not model.config.is_3d:
        print("Warning: Using a 2D model on 3D data. Consider using a 3D model.")

except Exception as e:
    print(f"✗ Failed to prepare input: {{e}}")
    # Fallback to original image
    prepared_img = image
    axes = "YX" if image.ndim == 2 else "ZYX"

try:
    # Parse string parameters
    def parse_param(param_str, default_val):
        if param_str == "auto":
            return default_val
        try:
            return eval(param_str) if param_str.startswith("(") else param_str
        except:
            return default_val

    n_tiles_parsed = parse_param('{args_dict.get('n_tiles', 'auto')}', None)
    scale_parsed = parse_param('{args_dict.get('scale', 'auto')}', None)

    # Handle normalization manually (similar to napari-spotiflow)
    normalizer_type = '{args_dict.get('normalizer', 'percentile')}'
    if normalizer_type == "percentile":
        normalizer_low = {args_dict.get('normalizer_low', 1.0)}
        normalizer_high = {args_dict.get('normalizer_high', 99.8)}
        print(f"Applying percentile normalization: {{normalizer_low}}% to {{normalizer_high}}%")
        p_low, p_high = np.percentile(prepared_img, [normalizer_low, normalizer_high])
        normalized_img = np.clip((prepared_img - p_low) / (p_high - p_low), 0, 1)
    elif normalizer_type == "minmax":
        print("Applying min-max normalization")
        img_min, img_max = prepared_img.min(), prepared_img.max()
        normalized_img = (prepared_img - img_min) / (img_max - img_min) if img_max > img_min else prepared_img
    else:
        normalized_img = prepared_img

    print(f"Normalized image range: {{normalized_img.min():.3f}} to {{normalized_img.max():.3f}}")

    # Prepare prediction parameters (following napari-spotiflow style)
    predict_kwargs = {{
        'subpix': {args_dict.get('subpixel', True)}, # Note: Spotiflow API uses 'subpix', not 'subpixel'
        'peak_mode': '{args_dict.get('peak_mode', 'fast')}',
        'normalizer': None, # We handle normalization manually
        'exclude_border': {args_dict.get('exclude_border', True)},
        'min_distance': {args_dict.get('min_distance', 2)},
        'verbose': True,
    }}

    # Set probability threshold - use automatic or provided value
    prob_thresh = {args_dict.get('prob_thresh', None)}
    if prob_thresh is not None and prob_thresh > 0.0:
        predict_kwargs['prob_thresh'] = prob_thresh
    # If prob_thresh is None or 0.0, don't set it - let spotiflow use automatic threshold

    if n_tiles_parsed is not None:
        predict_kwargs['n_tiles'] = n_tiles_parsed
    if scale_parsed is not None:
        predict_kwargs['scale'] = scale_parsed

    print(f"Prediction parameters: {{predict_kwargs}}")
except Exception as e:
    print(f"✗ Failed to prepare parameters: {{e}}")
    sys.exit(1)

try:
    # Perform spot detection
    print("Running Spotiflow prediction...")
    try:
        points, details = model.predict(normalized_img, **predict_kwargs)
    except (RuntimeError, Exception) as pred_e:
        if "CUDA" in str(pred_e) and not force_cpu:
            print(f"CUDA error during prediction ({{pred_e}}), retrying with CPU")
            # Move model to CPU and retry
            device = torch.device("cpu")
            model = model.to(device)
            # Set environment to force CPU
            import os
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
            points, details = model.predict(normalized_img, **predict_kwargs)
        else:
            raise

    print(f"✓ Initial detection: {{len(points)}} spots")

    # Only apply minimal additional filtering if we still have too many detections
    # This should rarely be needed now that we use proper automatic thresholding
    if len(points) > 500: # Only if we have an excessive number of spots
        print(f"Applying additional filtering for {{len(points)}} spots")

        # Check if we can apply probability filtering
        if hasattr(details, 'prob'):
            # Use a more stringent threshold
            auto_thresh = 0.7
            prob_mask = details.prob > auto_thresh
            points = points[prob_mask]
            print(f"After additional probability thresholding ({{auto_thresh}}): {{len(points)}} spots")

    print(f"Final detection: {{len(points)}} spots")

    if len(points) > 0:
        print(f"✓ Points shape: {{points.shape}}")
        print(f"✓ Points dtype: {{points.dtype}}")
        print(f"✓ First few points: {{points[:3]}}")

except Exception as e:
    print(f"✗ Failed during spot detection: {{e}}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

try:
    # Prepare output data
    output_data = {{
        'points': points,
    }}

    # Save results
    print(f"Saving results to: {output_file.name}")
    np.save('{output_file.name}', output_data)
    print(f"✓ Results saved successfully")
    print(f"Detected {{len(points)}} spots")
except Exception as e:
    print(f"✗ Failed to save results: {{e}}")
    sys.exit(1)
"""

        # Write script
        script_file.write(script)
        # flush so the subprocess sees the full script while the handle is
        # still open inside this with-block
        script_file.flush()

        # Execute the script in the dedicated environment
        env_python = get_env_python_path()
        result = subprocess.run(
            [env_python, script_file.name],
            capture_output=True,
            text=True,
        )

        # Check for errors
        if result.returncode != 0:
            print("Error in Spotiflow environment execution:")
            print(f"STDOUT: {result.stdout}")
            print(f"STDERR: {result.stderr}")
            raise subprocess.CalledProcessError(
                result.returncode, result.args, result.stdout, result.stderr
            )

        print(result.stdout)

        # Load and return results (dict was saved via np.save, hence
        # allow_pickle + .item() to unwrap the 0-d object array)
        output_data = np.load(output_file.name, allow_pickle=True).item()

        # Clean up temporary files
        # NOTE(review): a FileNotFoundError on the first unlink skips the
        # remaining two — verify this best-effort cleanup is intended.
        with contextlib.suppress(FileNotFoundError):
            os.unlink(input_file.name)
            os.unlink(output_file.name)
            os.unlink(script_file.name)

        return output_data
|