napari-tmidas 0.2.6__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,381 @@
1
+ # processing_functions/viscy_env_manager.py
2
+ """
3
+ This module manages a dedicated virtual environment for VisCy (Virtual Staining of Cells using deep learning).
4
+ """
5
+
6
+ import contextlib
7
+ import os
8
+ import subprocess
9
+ import tempfile
10
+ import urllib.request
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+
15
+ from napari_tmidas._env_manager import BaseEnvironmentManager
16
+
17
+ try:
18
+ import tifffile
19
+ except ImportError:
20
+ tifffile = None
21
+
22
+
23
# Script run with the VisCy environment's interpreter to confirm that the
# installed CUDA build of PyTorch can actually place a tensor on the GPU.
# Exit status 0 means the test passed; 1 means CUDA is unusable.
_CUDA_TEST_SCRIPT = """
import torch
try:
    if torch.cuda.is_available():
        test_tensor = torch.ones(1).cuda()
        print("CUDA compatibility test passed")
    else:
        print("CUDA not available in PyTorch")
        exit(1)
except Exception as e:
    print(f"CUDA compatibility test failed: {e}")
    exit(1)
"""


class ViscyEnvironmentManager(BaseEnvironmentManager):
    """Environment manager for VisCy.

    Creates a dedicated environment named ``"viscy"`` (through
    ``BaseEnvironmentManager``), installs a pinned PyTorch/torchvision stack
    (CUDA 11.8 wheels when a usable GPU is detected, CPU wheels otherwise),
    VisCy plus its I/O dependencies, and downloads the VSCyto3D checkpoint
    into ``<env_dir>/models``.
    """

    # Pinned PyTorch stack installed from either the cu118 or the cpu index.
    _TORCH_REQUIREMENTS = ("torch==2.0.1", "torchvision==0.15.2")
    _TORCH_CUDA_INDEX = "https://download.pytorch.org/whl/cu118"
    _TORCH_CPU_INDEX = "https://download.pytorch.org/whl/cpu"
    # File name under which the checkpoint is stored inside ``model_dir``.
    _CHECKPOINT_NAME = "VSCyto3D.ckpt"
    # Public download location of the pretrained VSCyto3D checkpoint.
    _CHECKPOINT_URL = (
        "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto3D/"
        "epoch=83-step=14532-loss=0.492.ckpt"
    )

    def __init__(self):
        super().__init__("viscy")
        # Model directory in the environment
        self.model_dir = os.path.join(self.env_dir, "models")

    @staticmethod
    def _pip(env_python: str, *args: str) -> None:
        """Run ``<env_python> -m pip <args>``; raise CalledProcessError on failure."""
        subprocess.check_call([env_python, "-m", "pip", *args])

    @staticmethod
    def _detect_cuda() -> bool:
        """Return True if a CUDA GPU appears usable from the current process.

        Uses PyTorch from the host (napari) environment when it is importable;
        otherwise falls back to probing ``nvidia-smi``.
        """
        try:
            import torch
        except ImportError:
            print("PyTorch not detected in main environment")
            try:
                result = subprocess.run(
                    ["nvidia-smi"], capture_output=True, text=True
                )
            except OSError:
                # FileNotFoundError (not installed) or PermissionError etc.
                print("nvidia-smi not found, assuming no CUDA support")
                return False
            if result.returncode == 0:
                print("NVIDIA GPU detected via nvidia-smi")
                return True
            print("No NVIDIA GPU detected")
            return False

        cuda_available = torch.cuda.is_available()
        if cuda_available:
            print("CUDA is available in main environment")
            if torch.cuda.device_count() > 0:
                gpu_name = torch.cuda.get_device_name(0)
                print(f"GPU detected: {gpu_name}")
        else:
            print("CUDA is not available in main environment")
        return cuda_available

    def _install_cuda_torch(self, env_python: str) -> bool:
        """Install the CUDA 11.8 PyTorch build and verify it on the GPU.

        Returns True when the CUDA build installed and passed the GPU test.
        On failure returns False, after removing a CUDA build that installed
        but failed the test, so the caller can install the CPU build.
        """
        print("Attempting PyTorch installation with CUDA support...")
        try:
            self._pip(
                env_python,
                "install",
                *self._TORCH_REQUIREMENTS,
                "--index-url",
                self._TORCH_CUDA_INDEX,
            )
        except subprocess.CalledProcessError as e:
            print(
                f"⚠ PyTorch CUDA installation failed: {e}, falling back to CPU-only"
            )
            return False
        print("✓ PyTorch with CUDA 11.8 installed successfully")

        result = subprocess.run(
            [env_python, "-c", _CUDA_TEST_SCRIPT],
            capture_output=True,
            text=True,
        )
        if result.returncode == 0:
            return True

        print("⚠ CUDA test failed, falling back to CPU-only installation")
        print(result.stdout)
        print(result.stderr)
        # The installed CUDA wheels (e.g. "2.0.1+cu118") already satisfy the
        # pinned "torch==2.0.1" requirement (PEP 440 ignores local labels), so
        # pip would skip the CPU wheels. Remove the CUDA build first.
        uninstall = subprocess.run(
            [env_python, "-m", "pip", "uninstall", "-y", "torch", "torchvision"]
        )
        if uninstall.returncode != 0:
            print("⚠ Could not uninstall the CUDA build of PyTorch")
        return False

    def _install_dependencies(self, env_python: str) -> None:
        """Install VisCy-specific dependencies into the environment.

        Parameters
        ----------
        env_python : str
            Path to the environment's Python interpreter.

        Raises
        ------
        subprocess.CalledProcessError
            If the CPU PyTorch install or the VisCy install fails. A failed
            CUDA install is not fatal; it falls back to the CPU build.
        """
        cuda_available = self._detect_cuda()
        if cuda_available:
            cuda_available = self._install_cuda_torch(env_python)

        if not cuda_available:
            print("Installing PyTorch (CPU-only version)...")
            self._pip(
                env_python,
                "install",
                *self._TORCH_REQUIREMENTS,
                "--index-url",
                self._TORCH_CPU_INDEX,
            )
            print("✓ PyTorch (CPU-only) installed successfully")

        # Install VisCy and dependencies
        print("Installing VisCy and dependencies...")
        self._pip(env_python, "install", "viscy", "iohub", "tifffile", "numpy")
        print("✓ VisCy and dependencies installed successfully")

        # Download the VSCyto3D model checkpoint
        print("Downloading VSCyto3D model checkpoint...")
        self._download_model()

    def _download_model(self) -> None:
        """Download the VSCyto3D checkpoint if not already present.

        The file is first downloaded to ``<checkpoint>.part`` and renamed only
        on success, so an interrupted download never leaves a truncated file
        that would later be mistaken for a valid checkpoint. Failures are
        reported but not raised (best effort).
        """
        os.makedirs(self.model_dir, exist_ok=True)
        checkpoint_path = self.get_model_path()

        if os.path.exists(checkpoint_path):
            print(f"✓ Model checkpoint already exists at {checkpoint_path}")
            return

        print("Downloading VSCyto3D model...")
        url = self._CHECKPOINT_URL
        partial_path = checkpoint_path + ".part"
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, checkpoint_path)
            print(f"✓ Model checkpoint downloaded to {checkpoint_path}")
        except Exception as e:
            # Remove any partial download so the next attempt starts clean.
            with contextlib.suppress(OSError):
                os.remove(partial_path)
            print(f"⚠ Failed to download model checkpoint: {e}")
            print(f"You can manually download it from {url} and place it in:")
            print(f" {checkpoint_path}")

    def is_package_installed(self) -> bool:
        """Check if VisCy is importable from the environment's interpreter."""
        if not self.is_env_created():
            return False

        env_python = self.get_env_python_path()
        try:
            subprocess.check_call(
                [env_python, "-c", "import viscy"],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
        except (subprocess.CalledProcessError, OSError):
            # OSError covers a missing or non-executable interpreter.
            return False
        return True

    def get_model_path(self) -> str:
        """Get the path to the VSCyto3D model checkpoint (may not exist yet)."""
        return os.path.join(self.model_dir, self._CHECKPOINT_NAME)
191
+
192
+
193
+ # Singleton instance
194
# Module-level manager shared by the convenience functions in this module.
# NOTE: it is constructed at import time, so ViscyEnvironmentManager.__init__
# (and BaseEnvironmentManager.__init__, not visible here) run on import.
_viscy_env_manager = ViscyEnvironmentManager()
195
+
196
+
197
def create_viscy_env() -> str:
    """Build the dedicated VisCy environment.

    Delegates to the shared module-level manager and returns the string its
    ``create_env`` reports (presumably the environment location — confirm in
    ``BaseEnvironmentManager``).
    """
    manager = _viscy_env_manager
    env_result = manager.create_env()
    return env_result
200
+
201
+
202
def is_env_created() -> bool:
    """Report whether the VisCy environment exists, via the shared manager."""
    manager = _viscy_env_manager
    exists = manager.is_env_created()
    return exists
205
+
206
+
207
def is_viscy_installed() -> bool:
    """Report whether the ``viscy`` package is importable in its environment."""
    manager = _viscy_env_manager
    installed = manager.is_package_installed()
    return installed
210
+
211
+
212
def get_model_path() -> str:
    """Return where the VSCyto3D checkpoint lives (the file may not exist yet)."""
    manager = _viscy_env_manager
    checkpoint = manager.get_model_path()
    return checkpoint
215
+
216
+
217
# Inference program executed by the VisCy environment's interpreter.
# Arguments: <model_path> <input_tif> <output_tif> <z_batch_size>.
# Paths are passed on the command line instead of being formatted into the
# source, so Windows paths (backslashes) and quotes cannot break the script.
_VISCY_INFERENCE_SCRIPT = """
import sys

import numpy as np
import torch
import tifffile
from viscy.translation.engine import VSUNet

model_path, input_path, output_path = sys.argv[1:4]
z_batch_size = int(sys.argv[4])

# Load the model
model = VSUNet.load_from_checkpoint(
    model_path,
    architecture="fcmae",
    model_config={
        "in_channels": 1,
        "out_channels": 2,
        "encoder_blocks": [3, 3, 9, 3],
        "dims": [96, 192, 384, 768],
        "decoder_conv_blocks": 2,
        "stem_kernel_size": [5, 4, 4],
        "in_stack_depth": 15,
        "pretraining": False
    }
)
model.eval()

if torch.cuda.is_available():
    model = model.cuda()

# Load input image
image = tifffile.imread(input_path)
n_z = image.shape[0]
n_batches = (n_z + z_batch_size - 1) // z_batch_size

all_predictions = []

for batch_idx in range(n_batches):
    start_z = batch_idx * z_batch_size
    end_z = min((batch_idx + 1) * z_batch_size, n_z)

    # Get batch
    batch_data = image[start_z:end_z]
    actual_size = batch_data.shape[0]

    # Pad the last batch to the full depth by repeating its edge slice
    if actual_size < z_batch_size:
        pad_size = z_batch_size - actual_size
        batch_data = np.pad(batch_data, ((0, pad_size), (0, 0), (0, 0)), mode="edge")

    # Normalize to [0, 1] using the 1st/99th percentiles of this batch
    p_low, p_high = np.percentile(batch_data, [1, 99])
    batch_data = np.clip((batch_data - p_low) / (p_high - p_low + 1e-8), 0, 1)

    # Convert to tensor: (Z, Y, X) -> (1, 1, Z, Y, X)
    batch_tensor = torch.from_numpy(batch_data.astype(np.float32))[None, None, :, :, :]
    if torch.cuda.is_available():
        batch_tensor = batch_tensor.cuda()

    # Run prediction
    with torch.no_grad():
        pred = model(batch_tensor)  # (1, 2, Z, Y, X)

    # Process output: (2, Z, Y, X) -> (Z, 2, Y, X), dropping padded slices
    pred_np = pred[0].cpu().numpy().transpose(1, 0, 2, 3)[:actual_size]
    all_predictions.append(pred_np)

    # Free memory
    del batch_data, batch_tensor, pred
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Concatenate all predictions: (Z, 2, Y, X)
full_prediction = np.concatenate(all_predictions, axis=0)

# Save output
tifffile.imwrite(output_path, full_prediction)
"""


def run_viscy_in_env(image: np.ndarray, z_batch_size: int = 15) -> np.ndarray:
    """
    Run VisCy virtual staining in the dedicated environment.

    The image and the ``_VISCY_INFERENCE_SCRIPT`` program are written to a
    private temporary directory, the program is run with the environment's
    Python interpreter, and the TIFF it writes is read back. The temporary
    directory is removed afterwards (best effort).

    Parameters:
    -----------
    image : np.ndarray
        Input image with shape (Z, Y, X)
    z_batch_size : int
        Number of Z slices to process at once (default: 15, required by VSCyto3D)

    Returns:
    --------
    np.ndarray
        Virtual stained output with shape (Z, 2, Y, X) where channels are:
        - Channel 0: Nuclei
        - Channel 1: Membrane

    Raises:
    -------
    RuntimeError
        If the environment, the VisCy package or the model checkpoint is
        missing, or if the inference subprocess exits with a non-zero status.
    ImportError
        If ``tifffile`` is not importable in the current environment.
    """
    import shutil

    if not is_env_created():
        raise RuntimeError(
            "VisCy environment not created. Please create it first."
        )

    if not is_viscy_installed():
        raise RuntimeError(
            "VisCy not installed in environment. Please create the environment first."
        )

    env_python = _viscy_env_manager.get_env_python_path()
    model_path = get_model_path()

    if not os.path.exists(model_path):
        raise RuntimeError(
            f"Model checkpoint not found at {model_path}. Please re-create the environment."
        )

    # Fail before creating any temporary files so nothing is left behind.
    if tifffile is None:
        raise ImportError("tifffile is required but not available")

    tmp_dir = tempfile.mkdtemp(prefix="viscy_")
    try:
        input_path = os.path.join(tmp_dir, "input.tif")
        output_path = os.path.join(tmp_dir, "output.tif")
        script_path = os.path.join(tmp_dir, "run_viscy.py")

        # Plain paths (not open NamedTemporaryFile handles) so the files can be
        # reopened by tifffile and by the subprocess on Windows as well.
        tifffile.imwrite(input_path, image)
        with open(script_path, "w", encoding="utf-8") as script_file:
            script_file.write(_VISCY_INFERENCE_SCRIPT)

        # Run the script in the environment
        result = subprocess.run(
            [
                env_python,
                script_path,
                model_path,
                input_path,
                output_path,
                str(int(z_batch_size)),
            ],
            capture_output=True,
            text=True,
        )

        if result.returncode != 0:
            raise RuntimeError(
                f"VisCy processing failed:\nSTDOUT: {result.stdout}\nSTDERR: {result.stderr}"
            )

        # Load the output
        return tifffile.imread(output_path)

    finally:
        # Clean up all temporary files; cleanup errors must not mask results.
        shutil.rmtree(tmp_dir, ignore_errors=True)