octopi-1.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- octopi/__init__.py +7 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +83 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +458 -0
- octopi/datasets/io.py +200 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +252 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +119 -0
- octopi/entry_points/create_slurm_submission.py +251 -0
- octopi/entry_points/groups.py +152 -0
- octopi/entry_points/run_create_targets.py +234 -0
- octopi/entry_points/run_evaluate.py +99 -0
- octopi/entry_points/run_extract_mb_picks.py +191 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +176 -0
- octopi/entry_points/run_optuna.py +161 -0
- octopi/entry_points/run_segment.py +154 -0
- octopi/entry_points/run_train.py +189 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +217 -0
- octopi/extract/membranebound_extract.py +263 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/main.py +33 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +72 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +224 -0
- octopi/processing/downloader.py +138 -0
- octopi/processing/downsample.py +125 -0
- octopi/processing/evaluate.py +302 -0
- octopi/processing/importers.py +116 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +244 -0
- octopi/pytorch/model_search_submitter.py +291 -0
- octopi/pytorch/segmentation.py +363 -0
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +465 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +215 -0
- octopi/utils/losses.py +86 -0
- octopi/utils/parsers.py +162 -0
- octopi/utils/progress.py +78 -0
- octopi/utils/stopping_criteria.py +143 -0
- octopi/utils/submit_slurm.py +95 -0
- octopi/utils/visualization_tools.py +290 -0
- octopi/workflows.py +262 -0
- octopi-1.4.0.dist-info/METADATA +119 -0
- octopi-1.4.0.dist-info/RECORD +65 -0
- octopi-1.4.0.dist-info/WHEEL +4 -0
- octopi-1.4.0.dist-info/entry_points.txt +3 -0
- octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/pytorch/segmentation.py
@@ -0,0 +1,363 @@
from monai.inferers import sliding_window_inference
from typing import List, Optional, Union
from octopi.datasets import io as dataio
import torch, copick, gc, os, pprint
from copick_utils.io import writers
from octopi.models import common
from octopi.utils import io
from tqdm import tqdm
import numpy as np

from monai.transforms import (
    Compose,
    NormalizeIntensity,
    EnsureChannelFirst,
)

class Predictor:

    def __init__(self,
                 config: str,
                 model_config: Union[str, List[str]],
                 model_weights: Union[str, List[str]],
                 apply_tta: bool = True,
                 device: Optional[str] = None):

        # Open the Copick Project
        self.config = config
        self.root = copick.from_file(config)

        # Get the number of GPUs available
        num_gpus = torch.cuda.device_count()
        if num_gpus == 0:
            raise RuntimeError("No GPUs available.")

        # Set the device
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        print('Running Inference On: ', self.device)

        # Initialize TTA if enabled
        self.apply_tta = apply_tta
        self.create_tta_augmentations()

        # Determine if Model Soup is Enabled
        if isinstance(model_weights, str):
            model_weights = [model_weights]
        self.apply_modelsoup = len(model_weights) > 1
        self.model_weights = model_weights

        # Sliding Window Inference Parameters
        self.sw_bs = 4      # sliding window batch size
        self.overlap = 0.5  # overlap between windows
        self.sw = None

        # Handle Single Model Config or Multiple Model Configs
        if isinstance(model_config, str):
            model_config = [model_config] * len(model_weights)
        elif len(model_config) != len(model_weights):
            raise ValueError("Number of model configs must match number of model weights.")
        self.model_config = model_config

        # Load the model(s)
        self._load_models(model_config, model_weights)

    def _load_models(self, model_config: List[str], model_weights: List[str]):
        """Load a single model or multiple models for soup."""

        self.models = []
        for i, (config_path, weights_path) in enumerate(zip(model_config, model_weights)):

            # Load the Model Config and Model Builder
            current_modelconfig = io.load_yaml(config_path)
            model_builder = common.get_model(current_modelconfig['model']['architecture'])

            # Check if the weights file exists
            if not os.path.exists(weights_path):
                raise ValueError(f"Model weights file does not exist: {weights_path}")

            # Create model
            model_builder.build_model(current_modelconfig['model'])
            model = model_builder.model

            # Load weights
            state_dict = torch.load(weights_path, map_location=self.device, weights_only=True)
            model.load_state_dict(state_dict)
            model.to(self.device)
            model.eval()

            self.models.append(model)

        # For backward compatibility, also set self.model to the first model
        self.model = self.models[0]

        # Set the Number of Classes and Input Dimensions - Assume All Models are the Same
        self.Nclass = current_modelconfig['model']['num_classes']
        self.dim_in = current_modelconfig['model']['dim_in']
        self.input_dim = None

        # Print a message if Model Soup is Enabled
        if self.apply_modelsoup:
            print(f'Model Soup is Enabled : {len(self.models)} models loaded for ensemble inference')

    def predict(self, input_data):
        """Run Prediction from an Input Tomogram.
        Args:
            input_data (torch.Tensor or np.ndarray): Input tomogram of shape [Z, Y, X]
        Returns:
            Predicted segmentation mask of shape [Z, Y, X]
        """

        is_numpy = False
        if isinstance(input_data, np.ndarray):
            is_numpy = True
            input_data = torch.from_numpy(input_data)

        # Apply transforms directly to tensor (no dictionary needed)
        pre_transforms = Compose([
            EnsureChannelFirst(channel_dim="no_channel"),
            NormalizeIntensity(),
        ])

        input_data = pre_transforms(input_data)

        # Add batch dimension and move to device
        input_data = input_data.unsqueeze(0).to(self.device)

        # Run inference
        pred = self._run_inference(input_data)[0]

        if is_numpy:
            pred = pred.cpu().numpy()
        return pred

    def _run_single_model_inference(self, model, input_data):
        """Run sliding window inference on a single model."""
        with torch.cuda.amp.autocast():
            return sliding_window_inference(
                inputs=input_data,
                roi_size=(self.dim_in, self.dim_in, self.dim_in),
                sw_batch_size=self.sw_bs,
                predictor=model,
                overlap=self.overlap,
            )

    def _apply_tta_single_model(self, model, single_sample):
        """Apply TTA to a single model and single sample."""
        # Initialize probability accumulator
        acc_probs = torch.zeros(
            (1, self.Nclass, *single_sample.shape[2:]),
            dtype=torch.float32, device=self.device
        )

        # Process each augmentation
        with torch.inference_mode():
            for tta_transform, inverse_transform in zip(self.tta_transforms, self.inverse_tta_transforms):
                # Apply transform
                aug_sample = tta_transform(single_sample)

                # Run inference
                predictions = self._run_single_model_inference(model, aug_sample)

                # Get softmax probabilities
                probs = torch.softmax(predictions[0], dim=0)

                # Apply inverse transform
                inv_probs = inverse_transform(probs)

                # Accumulate probabilities
                acc_probs[0] += inv_probs

                # Clear memory
                del predictions, probs, inv_probs, aug_sample
                torch.cuda.empty_cache()

        # Average accumulated probabilities
        acc_probs = acc_probs / len(self.tta_transforms)

        return acc_probs[0]  # Return shape [Nclass, Z, Y, X]

    def _run_inference(self, input_data):
        """
        Main inference function that handles all combinations - Model Soup and/or TTA
        """
        # Overwrite sw_bs with sw if provided
        if self.sw is not None:
            self.sw_bs = self.sw

        # Get the batch size (# of tomograms)
        batch_size = input_data.shape[0]
        results = []

        # Process one sample at a time for memory efficiency
        for sample_idx in range(batch_size):
            single_sample = input_data[sample_idx:sample_idx+1]

            # Initialize probability accumulator for this sample
            acc_probs = torch.zeros(
                (self.Nclass, *single_sample.shape[2:]),
                dtype=torch.float32, device=self.device
            )

            # Process each model
            with torch.inference_mode():
                for model in self.models:
                    # Apply TTA with this model
                    if self.apply_tta:
                        model_probs = self._apply_tta_single_model(model, single_sample)
                    # Run inference without TTA
                    else:
                        predictions = self._run_single_model_inference(model, single_sample)
                        model_probs = torch.softmax(predictions[0], dim=0)
                        del predictions

                    # Accumulate probabilities from this model
                    acc_probs += model_probs
                    del model_probs
                    torch.cuda.empty_cache()

            # Average probabilities across models (and TTA augmentations if applied)
            acc_probs = acc_probs / len(self.models)

            # Convert to discrete prediction
            discrete_pred = torch.argmax(acc_probs, dim=0)
            results.append(discrete_pred)

            # Clear memory
            del acc_probs, discrete_pred
            torch.cuda.empty_cache()

        return results

    def predict_on_gpu(self,
                       runIDs: List[str],
                       voxel_spacing: float,
                       tomo_algorithm: str):

        # Load data for the current batch
        test_loader, test_dataset = dataio.create_predict_dataloader(
            self.root,
            voxel_spacing, tomo_algorithm,
            runIDs)

        # Determine Input Crop Size.
        if self.input_dim is None:
            self.input_dim = dataio.get_input_dimensions(test_dataset, self.dim_in)

        predictions = []
        with torch.inference_mode():
            for data in tqdm(test_loader):
                tomogram = data['image'].to(self.device)
                data['pred'] = self._run_inference(tomogram)

                for idx in range(len(data['image'])):
                    predictions.append(data['pred'][idx].squeeze(0).numpy(force=True))

        return predictions

    def batch_predict(self,
                      num_tomos_per_batch: int = 15,
                      runIDs: Optional[List[str]] = None,
                      voxel_spacing: float = 10,
                      tomo_algorithm: str = 'denoised',
                      segmentation_name: str = 'prediction',
                      segmentation_user_id: str = 'octopi',
                      segmentation_session_id: str = '0'):
        """Run inference on tomograms in batches."""

        # Print Save Inference Parameters
        self.save_parameters(tomo_algorithm, voxel_spacing, [segmentation_name, segmentation_user_id, segmentation_session_id])

        # If runIDs are not provided, load all runs
        if runIDs is None:
            runIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is not None]
            skippedRunIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is None]

            if skippedRunIDs:
                print(f"Warning: skipping runs with no voxel spacing {voxel_spacing}: {skippedRunIDs}")

        # Iterate over batches of runIDs
        for i in range(0, len(runIDs), num_tomos_per_batch):
            # Get a batch of runIDs
            batch_ids = runIDs[i:i + num_tomos_per_batch]
            print('Running Inference on the Following RunIDs: ', batch_ids)

            predictions = self.predict_on_gpu(batch_ids, voxel_spacing, tomo_algorithm)

            # Save Predictions to Corresponding RunID
            for ind in range(len(batch_ids)):
                run = self.root.get_run(batch_ids[ind])
                seg = predictions[ind]
                writers.segmentation(run, seg, segmentation_user_id, segmentation_name,
                                     segmentation_session_id, voxel_spacing)

            # After processing and saving predictions for a batch:
            del predictions           # Remove reference to the list holding prediction arrays
            torch.cuda.empty_cache()  # Clear unused GPU memory
            gc.collect()              # Trigger garbage collection for CPU memory

        print('✅ Predictions Complete!')

    def create_tta_augmentations(self):
        """Define TTA augmentations and inverse transforms."""
        # Rotate around the YZ plane (dims 3,4 for input, dims 2,3 for output)
        self.tta_transforms = [
            lambda x: x,                                 # Identity (no augmentation)
            lambda x: torch.rot90(x, k=1, dims=(3, 4)),  # 90° rotation
            lambda x: torch.rot90(x, k=2, dims=(3, 4)),  # 180° rotation
            lambda x: torch.rot90(x, k=3, dims=(3, 4)),  # 270° rotation
        ]

        # Define inverse transformations (flip back to original orientation)
        self.inverse_tta_transforms = [
            lambda x: x,                                  # Identity (no transformation needed)
            lambda x: torch.rot90(x, k=-1, dims=(2, 3)),  # Inverse of 90° (i.e. -90°)
            lambda x: torch.rot90(x, k=-2, dims=(2, 3)),  # Inverse of 180° (i.e. -180°)
            lambda x: torch.rot90(x, k=-3, dims=(2, 3)),  # Inverse of 270° (i.e. -270°)
        ]

    def save_parameters(self,
                        tomo_algorithm: str,
                        voxel_size: float,
                        seg_info: List[str]
                        ):
        """
        Save inference parameters to a YAML file for record-keeping and reproducibility.
        """

        # Load the model config
        model_config = io.load_yaml(self.model_config[0])

        # Create parameters dictionary
        params = {
            "inputs": {
                "config": self.config,
                "tomo_alg": tomo_algorithm,
                "voxel_size": voxel_size
            },
            'model': {
                'configs': self.model_config,
                'weights': self.model_weights
            },
            'labels': model_config['labels'],
            "outputs": {
                "seg_name": seg_info[0],
                "seg_user_id": seg_info[1],
                "seg_session_id": seg_info[2]
            }
        }

        # Print the parameters
        print(f"\nParameters for Inference (Segment Prediction):")
        pprint.pprint(params); print()

        # Save to YAML file
        overlay_root = io.remove_prefix(self.root.config.overlay_root)
        basepath = os.path.join(overlay_root, 'logs')
        os.makedirs(basepath, exist_ok=True)
        output_path = os.path.join(
            basepath,
            f'segment-{seg_info[1]}_{seg_info[2]}_{seg_info[0]}.yaml')
        io.save_parameters_yaml(params, output_path)
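
For orientation, here is a minimal usage sketch of the Predictor class above. It is not part of the package diff: the config and checkpoint paths are placeholders, and the keyword values simply restate the defaults visible in batch_predict.

from octopi.pytorch.segmentation import Predictor

# Placeholder paths for illustration only; point these at a real copick
# project config, a model config YAML, and a trained checkpoint.
predictor = Predictor(config='copick_config.json',
                      model_config='model_config.yaml',
                      model_weights='best_model.pth',
                      apply_tta=True)

# Segment every run that has a tomogram at the requested voxel spacing and
# write the results back to the copick project as segmentations.
predictor.batch_predict(num_tomos_per_batch=15,
                        voxel_spacing=10,
                        tomo_algorithm='denoised',
                        segmentation_name='prediction',
                        segmentation_user_id='octopi',
                        segmentation_session_id='0')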

octopi/pytorch/segmentation_multigpu.py
@@ -0,0 +1,162 @@
from concurrent.futures import ThreadPoolExecutor
from octopi.pytorch.segmentation import Predictor
from typing import List, Union, Optional
from copick_utils.io import writers
import queue, torch

class MultiGPUPredictor(Predictor):

    def __init__(self,
                 config: str,
                 model_config: Union[str, List[str]],
                 model_weights: Union[str, List[str]],
                 apply_tta: bool = True):

        # Initialize parent normally
        super().__init__(config, model_config, model_weights, apply_tta)

        self.num_gpus = torch.cuda.device_count()
        print(f"Available GPUs: {self.num_gpus}")

        # Only create GPU-specific models if we have multiple GPUs
        if self.num_gpus > 1:
            self._create_gpu_models()

    def _create_gpu_models(self):
        """Create separate model instances for each GPU."""
        self.gpu_models = {}

        for gpu_id in range(self.num_gpus):
            device = torch.device(f'cuda:{gpu_id}')
            gpu_models = []

            # Copy each model to this GPU
            for model in self.models:
                gpu_model = type(model)()
                gpu_model.load_state_dict(model.state_dict())
                gpu_model.to(device)
                gpu_model.eval()
                gpu_models.append(gpu_model)

            self.gpu_models[gpu_id] = gpu_models
            print(f"Models loaded on GPU {gpu_id}")

    def _run_on_gpu(self, gpu_id: int, batch_ids: List[str],
                    voxel_spacing: float, tomo_algorithm: str,
                    segmentation_name: str, segmentation_user_id: str,
                    segmentation_session_id: str):
        """Run inference on a specific GPU for a batch of runs."""
        device = torch.device(f'cuda:{gpu_id}')

        # Temporarily switch to this GPU's models and device
        original_device = self.device
        original_models = self.models

        self.device = device
        self.models = self.gpu_models[gpu_id]

        try:
            print(f"GPU {gpu_id} processing runs: {batch_ids}")

            # Run prediction using parent class method
            predictions = self.predict_on_gpu(batch_ids, voxel_spacing, tomo_algorithm)

            # Save predictions
            for idx, run_id in enumerate(batch_ids):
                run = self.root.get_run(run_id)
                seg = predictions[idx]
                writers.segmentation(run, seg, segmentation_user_id,
                                     segmentation_name, segmentation_session_id,
                                     voxel_spacing)

            # Clean up
            del predictions
            torch.cuda.empty_cache()

        finally:
            # Restore original settings
            self.device = original_device
            self.models = original_models

    def multigpu_batch_predict(self,
                               num_tomos_per_batch: int = 15,
                               runIDs: Optional[List[str]] = None,
                               voxel_spacing: float = 10,
                               tomo_algorithm: str = 'denoised',
                               segmentation_name: str = 'prediction',
                               segmentation_user_id: str = 'octopi',
                               segmentation_session_id: str = '0'):
        """Run inference across multiple GPUs using threading."""

        # Get runIDs if not provided
        if runIDs is None:
            runIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is not None]
            skippedRunIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is None]
            if skippedRunIDs:
                print(f"Warning: skipping runs with no voxel spacing {voxel_spacing}: {skippedRunIDs}")

        # Split runIDs into batches
        batches = [runIDs[i:i + num_tomos_per_batch]
                   for i in range(0, len(runIDs), num_tomos_per_batch)]

        print(f"Processing {len(batches)} batches across {self.num_gpus} GPUs")

        # Create work queue
        batch_queue = queue.Queue()
        for batch in batches:
            batch_queue.put(batch)

        def worker(gpu_id):
            while True:
                try:
                    batch_ids = batch_queue.get_nowait()
                    self._run_on_gpu(gpu_id, batch_ids, voxel_spacing, tomo_algorithm,
                                     segmentation_name, segmentation_user_id,
                                     segmentation_session_id)
                    batch_queue.task_done()
                except queue.Empty:
                    break
                except Exception as e:
                    print(f"Error on GPU {gpu_id}: {e}")
                    batch_queue.task_done()

        # Start worker threads for each GPU
        with ThreadPoolExecutor(max_workers=self.num_gpus) as executor:
            futures = [executor.submit(worker, gpu_id) for gpu_id in range(self.num_gpus)]
            for future in futures:
                future.result()

        print('Multi-GPU predictions complete!')

    def batch_predict(self,
                      num_tomos_per_batch: int = 15,
                      runIDs: Optional[List[str]] = None,
                      voxel_spacing: float = 10,
                      tomo_algorithm: str = 'denoised',
                      segmentation_name: str = 'prediction',
                      segmentation_user_id: str = 'octopi',
                      segmentation_session_id: str = '0'):
        """Smart batch predict: uses multi-GPU if available, otherwise single GPU."""

        if self.num_gpus > 1:
            print("Using multi-GPU inference")
            self.multigpu_batch_predict(
                num_tomos_per_batch=num_tomos_per_batch,
                runIDs=runIDs,
                voxel_spacing=voxel_spacing,
                tomo_algorithm=tomo_algorithm,
                segmentation_name=segmentation_name,
                segmentation_user_id=segmentation_user_id,
                segmentation_session_id=segmentation_session_id
            )
        else:
            print("Using single GPU inference")
            super().batch_predict(
                num_tomos_per_batch=num_tomos_per_batch,
                runIDs=runIDs,
                voxel_spacing=voxel_spacing,
                tomo_algorithm=tomo_algorithm,
                segmentation_name=segmentation_name,
                segmentation_user_id=segmentation_user_id,
                segmentation_session_id=segmentation_session_id
            )
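
A corresponding sketch for the multi-GPU variant, again with placeholder paths and not part of the package diff: passing several weight files enables the model-soup ensemble, and batch_predict falls back to the single-GPU path when only one device is visible.

from octopi.pytorch.segmentation_multigpu import MultiGPUPredictor

# Placeholder paths; a list of checkpoints enables model-soup ensembling.
predictor = MultiGPUPredictor(config='copick_config.json',
                              model_config='model_config.yaml',
                              model_weights=['model_fold0.pth', 'model_fold1.pth'],
                              apply_tta=True)

# Batches of runs are queued and dispatched to one worker thread per GPU;
# with a single GPU this simply calls the parent batch_predict.
predictor.batch_predict(num_tomos_per_batch=15,
                        voxel_spacing=10,
                        tomo_algorithm='denoised')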