napari-tmidas 0.2.5__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
- napari_tmidas/_tests/test_intensity_label_filter.py +9 -11
- napari_tmidas/_tests/test_registry.py +6 -0
- napari_tmidas/_tests/test_viscy_virtual_staining.py +138 -0
- napari_tmidas/_version.py +2 -2
- napari_tmidas/processing_functions/__init__.py +14 -3
- napari_tmidas/processing_functions/intensity_label_filter.py +5 -2
- napari_tmidas/processing_functions/skimage_filters.py +71 -8
- napari_tmidas/processing_functions/viscy_env_manager.py +381 -0
- napari_tmidas/processing_functions/viscy_virtual_staining.py +393 -0
- napari_tmidas-0.3.0.dist-info/METADATA +249 -0
- {napari_tmidas-0.2.5.dist-info → napari_tmidas-0.3.0.dist-info}/RECORD +15 -12
- {napari_tmidas-0.2.5.dist-info → napari_tmidas-0.3.0.dist-info}/WHEEL +1 -1
- napari_tmidas-0.2.5.dist-info/METADATA +0 -282
- {napari_tmidas-0.2.5.dist-info → napari_tmidas-0.3.0.dist-info}/entry_points.txt +0 -0
- {napari_tmidas-0.2.5.dist-info → napari_tmidas-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {napari_tmidas-0.2.5.dist-info → napari_tmidas-0.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,381 @@
+# processing_functions/viscy_env_manager.py
+"""
+This module manages a dedicated virtual environment for VisCy (Virtual Staining of Cells using deep learning).
+"""
+
+import contextlib
+import os
+import subprocess
+import tempfile
+import urllib.request
+from pathlib import Path
+
+import numpy as np
+
+from napari_tmidas._env_manager import BaseEnvironmentManager
+
+try:
+    import tifffile
+except ImportError:
+    tifffile = None
+
+
+class ViscyEnvironmentManager(BaseEnvironmentManager):
+    """Environment manager for VisCy."""
+
+    def __init__(self):
+        super().__init__("viscy")
+        # Model directory in the environment
+        self.model_dir = os.path.join(self.env_dir, "models")
+
+    def _install_dependencies(self, env_python: str) -> None:
+        """Install VisCy-specific dependencies."""
+        # Try to detect if CUDA is available
+        cuda_available = False
+        try:
+            import torch
+
+            cuda_available = torch.cuda.is_available()
+            if cuda_available:
+                print("CUDA is available in main environment")
+                if torch.cuda.device_count() > 0:
+                    gpu_name = torch.cuda.get_device_name(0)
+                    print(f"GPU detected: {gpu_name}")
+            else:
+                print("CUDA is not available in main environment")
+        except ImportError:
+            print("PyTorch not detected in main environment")
+            # Try to detect CUDA from nvidia-smi
+            try:
+                result = subprocess.run(
+                    ["nvidia-smi"], capture_output=True, text=True
+                )
+                if result.returncode == 0:
+                    cuda_available = True
+                    print("NVIDIA GPU detected via nvidia-smi")
+                else:
+                    cuda_available = False
+                    print("No NVIDIA GPU detected")
+            except FileNotFoundError:
+                cuda_available = False
+                print("nvidia-smi not found, assuming no CUDA support")
+
+        if cuda_available:
+            print("Attempting PyTorch installation with CUDA support...")
+            try:
+                subprocess.check_call(
+                    [
+                        env_python,
+                        "-m",
+                        "pip",
+                        "install",
+                        "torch==2.0.1",
+                        "torchvision==0.15.2",
+                        "--index-url",
+                        "https://download.pytorch.org/whl/cu118",
+                    ]
+                )
+                print("✓ PyTorch with CUDA 11.8 installed successfully")
+
+                # Test CUDA compatibility
+                test_script = """
+import torch
+try:
+    if torch.cuda.is_available():
+        test_tensor = torch.ones(1).cuda()
+        print("CUDA compatibility test passed")
+    else:
+        print("CUDA not available in PyTorch")
+        exit(1)
+except Exception as e:
+    print(f"CUDA compatibility test failed: {e}")
+    exit(1)
+"""
+                result = subprocess.run(
+                    [env_python, "-c", test_script],
+                    capture_output=True,
+                    text=True,
+                )
+
+                if result.returncode != 0:
+                    print(
+                        "⚠ CUDA test failed, falling back to CPU-only installation"
+                    )
+                    print(result.stdout)
+                    print(result.stderr)
+                    cuda_available = False
+
+            except subprocess.CalledProcessError as e:
+                print(
+                    f"⚠ PyTorch CUDA installation failed: {e}, falling back to CPU-only"
+                )
+                cuda_available = False
+
+        if not cuda_available:
+            print("Installing PyTorch (CPU-only version)...")
+            subprocess.check_call(
+                [
+                    env_python,
+                    "-m",
+                    "pip",
+                    "install",
+                    "torch==2.0.1",
+                    "torchvision==0.15.2",
+                    "--index-url",
+                    "https://download.pytorch.org/whl/cpu",
+                ]
+            )
+            print("✓ PyTorch (CPU-only) installed successfully")
+
+        # Install VisCy and dependencies
+        print("Installing VisCy and dependencies...")
+        subprocess.check_call(
+            [
+                env_python,
+                "-m",
+                "pip",
+                "install",
+                "viscy",
+                "iohub",
+                "tifffile",
+                "numpy",
+            ]
+        )
+
+        print("✓ VisCy and dependencies installed successfully")
+
+        # Download the VSCyto3D model checkpoint
+        print("Downloading VSCyto3D model checkpoint...")
+        self._download_model()
+
+    def _download_model(self):
+        """Download the VSCyto3D checkpoint if not already present."""
+        os.makedirs(self.model_dir, exist_ok=True)
+
+        checkpoint_path = os.path.join(self.model_dir, "VSCyto3D.ckpt")
+
+        if not os.path.exists(checkpoint_path):
+            print("Downloading VSCyto3D model...")
+            url = "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto3D/epoch=83-step=14532-loss=0.492.ckpt"
+            try:
+                urllib.request.urlretrieve(url, checkpoint_path)
+                print(f"✓ Model checkpoint downloaded to {checkpoint_path}")
+            except Exception as e:
+                print(f"⚠ Failed to download model checkpoint: {e}")
+                print(
+                    "You can manually download it from the URL above and place it in:"
+                )
+                print(f" {checkpoint_path}")
+        else:
+            print(f"✓ Model checkpoint already exists at {checkpoint_path}")
+
+    def is_package_installed(self) -> bool:
+        """Check if VisCy is installed in the environment."""
+        if not self.is_env_created():
+            return False
+
+        env_python = self.get_env_python_path()
+        try:
+            subprocess.check_call(
+                [env_python, "-c", "import viscy"],
+                stdout=subprocess.DEVNULL,
+                stderr=subprocess.DEVNULL,
+            )
+            return True
+        except subprocess.CalledProcessError:
+            return False
+
+    def get_model_path(self) -> str:
+        """Get the path to the VSCyto3D model checkpoint."""
+        return os.path.join(self.model_dir, "VSCyto3D.ckpt")
+
+
+# Singleton instance
+_viscy_env_manager = ViscyEnvironmentManager()
+
+
+def create_viscy_env() -> str:
+    """Create the VisCy environment."""
+    return _viscy_env_manager.create_env()
+
+
+def is_env_created() -> bool:
+    """Check if the VisCy environment exists."""
+    return _viscy_env_manager.is_env_created()
+
+
+def is_viscy_installed() -> bool:
+    """Check if VisCy is installed."""
+    return _viscy_env_manager.is_package_installed()
+
+
+def get_model_path() -> str:
+    """Get the path to the VSCyto3D model checkpoint."""
+    return _viscy_env_manager.get_model_path()
+
+
+def run_viscy_in_env(image: np.ndarray, z_batch_size: int = 15) -> np.ndarray:
+    """
+    Run VisCy virtual staining in the dedicated environment.
+
+    Parameters:
+    -----------
+    image : np.ndarray
+        Input image with shape (Z, Y, X)
+    z_batch_size : int
+        Number of Z slices to process at once (default: 15, required by VSCyto3D)
+
+    Returns:
+    --------
+    np.ndarray
+        Virtual stained output with shape (Z, 2, Y, X) where channels are:
+        - Channel 0: Nuclei
+        - Channel 1: Membrane
+    """
+    if not is_env_created():
+        raise RuntimeError(
+            "VisCy environment not created. Please create it first."
+        )
+
+    if not is_viscy_installed():
+        raise RuntimeError(
+            "VisCy not installed in environment. Please create the environment first."
+        )
+
+    env_python = _viscy_env_manager.get_env_python_path()
+    model_path = get_model_path()
+
+    if not os.path.exists(model_path):
+        raise RuntimeError(
+            f"Model checkpoint not found at {model_path}. Please re-create the environment."
+        )
+
+    # Create a temporary file for input
+    with tempfile.NamedTemporaryFile(
+        suffix=".tif", delete=False
+    ) as input_file:
+        input_path = input_file.name
+        if tifffile is not None:
+            tifffile.imwrite(input_path, image)
+        else:
+            raise ImportError("tifffile is required but not available")
+
+    # Create a temporary file for output
+    with tempfile.NamedTemporaryFile(
+        suffix=".tif", delete=False
+    ) as output_file:
+        output_path = output_file.name
+
+    try:
+        # Create Python script to run in the environment
+        script = f"""
+import numpy as np
+import torch
+import tifffile
+from viscy.translation.engine import VSUNet
+
+# Load the model
+model = VSUNet.load_from_checkpoint(
+    "{model_path}",
+    architecture="fcmae",
+    model_config={{
+        "in_channels": 1,
+        "out_channels": 2,
+        "encoder_blocks": [3, 3, 9, 3],
+        "dims": [96, 192, 384, 768],
+        "decoder_conv_blocks": 2,
+        "stem_kernel_size": [5, 4, 4],
+        "in_stack_depth": 15,
+        "pretraining": False
+    }}
+)
+model.eval()
+
+if torch.cuda.is_available():
+    model = model.cuda()
+
+# Load input image
+image = tifffile.imread("{input_path}")
+n_z = image.shape[0]
+z_batch_size = {z_batch_size}
+n_batches = (n_z + z_batch_size - 1) // z_batch_size
+
+all_predictions = []
+
+for batch_idx in range(n_batches):
+    start_z = batch_idx * z_batch_size
+    end_z = min((batch_idx + 1) * z_batch_size, n_z)
+
+    # Get batch
+    batch_data = image[start_z:end_z]
+    actual_size = batch_data.shape[0]
+
+    # Pad if necessary
+    if actual_size < z_batch_size:
+        pad_size = z_batch_size - actual_size
+        batch_data = np.pad(batch_data, ((0, pad_size), (0, 0), (0, 0)), mode='edge')
+
+    # Normalize
+    p_low, p_high = np.percentile(batch_data, [1, 99])
+    batch_data = np.clip((batch_data - p_low) / (p_high - p_low + 1e-8), 0, 1)
+
+    # Convert to tensor: (Z, Y, X) -> (1, 1, Z, Y, X)
+    batch_tensor = torch.from_numpy(batch_data.astype(np.float32))[None, None, :, :, :]
+    if torch.cuda.is_available():
+        batch_tensor = batch_tensor.cuda()
+
+    # Run prediction
+    with torch.no_grad():
+        pred = model(batch_tensor)  # (1, 2, Z, Y, X)
+
+    # Process output: (2, Z, Y, X) -> (Z, 2, Y, X)
+    pred_np = pred[0].cpu().numpy().transpose(1, 0, 2, 3)[:actual_size]
+    all_predictions.append(pred_np)
+
+    # Free memory
+    del batch_data, batch_tensor, pred
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+# Concatenate all predictions: (Z, 2, Y, X)
+full_prediction = np.concatenate(all_predictions, axis=0)
+
+# Save output
+tifffile.imwrite("{output_path}", full_prediction)
+"""
+
+        # Write script to temporary file
+        with tempfile.NamedTemporaryFile(
+            mode="w", suffix=".py", delete=False
+        ) as script_file:
+            script_path = script_file.name
+            script_file.write(script)
+
+        # Run the script in the environment
+        result = subprocess.run(
+            [env_python, script_path],
+            capture_output=True,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(
+                f"VisCy processing failed:\nSTDOUT: {result.stdout}\nSTDERR: {result.stderr}"
+            )
+
+        # Load the output
+        if tifffile is not None:
+            output_image = tifffile.imread(output_path)
+        else:
+            raise ImportError("tifffile is required but not available")
+
+        return output_image
+
+    finally:
+        # Clean up temporary files
+        with contextlib.suppress(Exception):
+            os.unlink(input_path)
+        with contextlib.suppress(Exception):
+            os.unlink(output_path)
+        with contextlib.suppress(Exception):
+            os.unlink(script_path)
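
For orientation, the sketch below shows how the public helpers added by viscy_env_manager.py might be called from user code. It is illustrative only and not part of the package diff; it assumes napari-tmidas 0.3.0 and tifffile are installed in the main environment, and the input/output file names are placeholders.

# Illustrative usage sketch -- not part of the diff above.
# Assumes napari-tmidas 0.3.0 and tifffile are installed in the main
# environment; "brightfield_stack.tif" is a hypothetical input path.
import tifffile

from napari_tmidas.processing_functions.viscy_env_manager import (
    create_viscy_env,
    is_env_created,
    is_viscy_installed,
    run_viscy_in_env,
)

# One-time setup: create the dedicated VisCy environment. Per the module
# above, PyTorch/VisCy installation and the VSCyto3D checkpoint download
# happen during environment creation (presumably driven by
# BaseEnvironmentManager.create_env() calling _install_dependencies).
if not (is_env_created() and is_viscy_installed()):
    create_viscy_env()

# Input is a single-channel 3D stack with shape (Z, Y, X).
stack = tifffile.imread("brightfield_stack.tif")

# Output has shape (Z, 2, Y, X): channel 0 = nuclei, channel 1 = membrane.
prediction = run_viscy_in_env(stack, z_batch_size=15)
tifffile.imwrite("virtual_staining_prediction.tif", prediction)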