octopi: 1.1-py3-none-any.whl → 1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of octopi has been flagged as possibly problematic.
- octopi/__init__.py +1 -0
- octopi/datasets/cached_datset.py +1 -1
- octopi/datasets/generators.py +1 -1
- octopi/datasets/io.py +200 -0
- octopi/datasets/multi_config_generator.py +1 -1
- octopi/entry_points/common.py +5 -5
- octopi/entry_points/create_slurm_submission.py +1 -1
- octopi/entry_points/run_create_targets.py +6 -6
- octopi/entry_points/run_evaluate.py +4 -3
- octopi/entry_points/run_extract_mb_picks.py +5 -5
- octopi/entry_points/run_localize.py +8 -9
- octopi/entry_points/run_optuna.py +7 -7
- octopi/entry_points/run_segment_predict.py +4 -4
- octopi/entry_points/run_train.py +7 -8
- octopi/extract/localize.py +11 -19
- octopi/extract/membranebound_extract.py +11 -10
- octopi/extract/midpoint_extract.py +3 -3
- octopi/models/common.py +1 -1
- octopi/processing/create_targets_from_picks.py +3 -4
- octopi/processing/evaluate.py +24 -11
- octopi/processing/importers.py +4 -4
- octopi/pytorch/hyper_search.py +2 -3
- octopi/pytorch/model_search_submitter.py +4 -4
- octopi/pytorch/segmentation.py +141 -190
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +2 -2
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +128 -0
- octopi/{utils.py → utils/parsers.py} +10 -84
- octopi/{stopping_criteria.py → utils/stopping_criteria.py} +3 -3
- octopi/{visualization_tools.py → utils/visualization_tools.py} +4 -4
- octopi/workflows.py +236 -0
- {octopi-1.1.dist-info → octopi-1.2.0.dist-info}/METADATA +41 -29
- octopi-1.2.0.dist-info/RECORD +62 -0
- {octopi-1.1.dist-info → octopi-1.2.0.dist-info}/WHEEL +1 -1
- octopi-1.2.0.dist-info/entry_points.txt +3 -0
- {octopi-1.1.dist-info → octopi-1.2.0.dist-info/licenses}/LICENSE +3 -3
- octopi/io.py +0 -457
- octopi/processing/my_metrics.py +0 -26
- octopi/processing/writers.py +0 -102
- octopi-1.1.dist-info/RECORD +0 -59
- octopi-1.1.dist-info/entry_points.txt +0 -4
- /octopi/{losses.py → utils/losses.py} +0 -0
- /octopi/{submit_slurm.py → utils/submit_slurm.py} +0 -0
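The renames above (utils.py → utils/parsers.py, the removal of octopi/io.py in favor of octopi/utils/io.py and octopi/datasets/io.py, stopping_criteria.py and visualization_tools.py moving under utils/) imply new import paths for downstream code. The sketch below is illustrative only, inferred from the import changes visible in the diffs that follow; it is not an exhaustive migration guide.

    # octopi 1.1 (old layout)
    from octopi import io
    from octopi import stopping_criteria
    from octopi import visualization_tools as viz

    # octopi 1.2.0 (new layout, per the diffs below)
    from octopi.utils import io
    from octopi.utils import stopping_criteria
    from octopi.utils import visualization_tools as viz
    from octopi.datasets import io as dataio   # new dataset-loading helpers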
octopi/pytorch/segmentation.py
CHANGED
@@ -5,11 +5,12 @@ from monai.data import MetaTensor
 from monai.transforms import (
     Compose, AsDiscrete, Activations
 )
-import
+from typing import List, Optional, Union
+from octopi.datasets import io as dataio
+from copick_utils.io import writers
 from octopi.models import common
-from typing import List, Optional
 import torch, copick, gc, os
-from octopi import io
+from octopi.utils import io
 from tqdm import tqdm
 import numpy as np
 
@@ -17,20 +18,14 @@ class Predictor:
 
     def __init__(self,
                  config: str,
-                 model_config: str,
-                 model_weights: str,
+                 model_config: Union[str, List[str]],
+                 model_weights: Union[str, List[str]],
                  apply_tta: bool = True,
                  device: Optional[str] = None):
 
+        # Open the Copick Project
         self.config = config
         self.root = copick.from_file(config)
-
-        # Load the model config
-        model_config = utils.load_yaml(model_config)
-
-        self.Nclass = model_config['model']['num_classes']
-        self.dim_in = model_config['model']['dim_in']
-        self.input_dim = None
 
         # Get the number of GPUs available
         num_gpus = torch.cuda.device_count()
@@ -44,109 +39,147 @@ class Predictor:
         self.device = device
         print('Running Inference On: ', self.device)
 
-        #
-
-
+        # Initialize TTA if enabled
+        self.apply_tta = apply_tta
+        self.create_tta_augmentations()
 
-        #
-
-
-        self.
-        state_dict = torch.load(model_weights, map_location=self.device, weights_only=True)
-        self.model.load_state_dict(state_dict)
-        self.model.to(self.device)
-        self.model.eval()
+        # Determine if Model Soup is Enabled
+        if isinstance(model_weights, str):
+            model_weights = [model_weights]
+        self.apply_modelsoup = len(model_weights) > 1
 
-
-
-
-
-
-
-
-
-
-
+        # Handle Single Model Config or Multiple Model Configs
+        if isinstance(model_config, str):
+            model_config = [model_config] * len(model_weights)
+        elif len(model_config) != len(model_weights):
+            raise ValueError("Number of model configs must match number of model weights.")
+
+        # Load the model(s)
+        self._load_models(model_config, model_weights)
+
+    def _load_models(self, model_config: List[str], model_weights: List[str]):
+        """Load a single model or multiple models for soup."""
+
+        self.models = []
+        for i, (config_path, weights_path) in enumerate(zip(model_config, model_weights)):
+
+            # Load the Model Config and Model Builder
+            current_modelconfig = io.load_yaml(config_path)
+            model_builder = common.get_model(current_modelconfig['model']['architecture'])
+
+            # Check if the weights file exists
+            if not os.path.exists(weights_path):
+                raise ValueError(f"Model weights file does not exist: {weights_path}")
 
-            # Create
-
-
-
-
-
-
-
+            # Create model
+            model_builder.build_model(current_modelconfig['model'])
+            model = model_builder.model
+
+            # Load weights
+            state_dict = torch.load(weights_path, map_location=self.device, weights_only=True)
+            model.load_state_dict(state_dict)
+            model.to(self.device)
+            model.eval()
 
+            self.models.append(model)
 
-
-
-
-
-
-
-
-            predictor=self.model,
-            overlap=0.5,
-        )
-        return [self.post_transforms(i) for i in decollate_batch(predictions)]
+        # For backward compatibility, also set self.model to the first model
+        self.model = self.models[0]
+
+        # Set the Number of Classes and Input Dimensions - Assume All Models are the Same
+        self.Nclass = current_modelconfig['model']['num_classes']
+        self.dim_in = current_modelconfig['model']['dim_in']
+        self.input_dim = None
 
-
-
+        # Print a message if Model Soup is Enabled
+        if self.apply_modelsoup:
+            print(f'Model Soup is Enabled : {len(self.models)} models loaded for ensemble inference')
+
+    def _run_single_model_inference(self, model, input_data):
+        """Run sliding window inference on a single model."""
+        return sliding_window_inference(
+            inputs=input_data,
+            roi_size=(self.dim_in, self.dim_in, self.dim_in),
+            sw_batch_size=4,
+            predictor=model,
+            overlap=0.5,
+        )
+
+    def _apply_tta_single_model(self, model, single_sample):
+        """Apply TTA to a single model and single sample."""
+        # Initialize probability accumulator
+        acc_probs = torch.zeros(
+            (1, self.Nclass, *single_sample.shape[2:]),
+            dtype=torch.float32, device=self.device
+        )
+
+        # Process each augmentation
+        with torch.no_grad():
+            for tta_transform, inverse_transform in zip(self.tta_transforms, self.inverse_tta_transforms):
+                # Apply transform
+                aug_sample = tta_transform(single_sample)
+
+                # Run inference
+                predictions = self._run_single_model_inference(model, aug_sample)
+
+                # Get softmax probabilities
+                probs = torch.softmax(predictions[0], dim=0)
+
+                # Apply inverse transform
+                inv_probs = inverse_transform(probs)
+
+                # Accumulate probabilities
+                acc_probs[0] += inv_probs
+
+                # Clear memory
+                del predictions, probs, inv_probs, aug_sample
+                torch.cuda.empty_cache()
+
+        # Average accumulated probabilities
+        acc_probs = acc_probs / len(self.tta_transforms)
 
+        return acc_probs[0]  # Return shape [Nclass, Z, Y, X]
+
+    def _run_inference(self, input_data):
+        """
+        Main inference function that handles all combinations - Model Soup and/or TTA
+        """
         batch_size = input_data.shape[0]
         results = []
 
-        # Process one sample at a time
+        # Process one sample at a time for memory efficiency
        for sample_idx in range(batch_size):
-            # Extract single sample
             single_sample = input_data[sample_idx:sample_idx+1]
 
             # Initialize probability accumulator for this sample
-            # Shape: [1, Nclass, Z, Y, X]
             acc_probs = torch.zeros(
-                (
+                (self.Nclass, *single_sample.shape[2:]),
                 dtype=torch.float32, device=self.device
             )
 
-            # Process each
+            # Process each model
             with torch.no_grad():
-                for
-                    # Apply
-
+                for model in self.models:
+                    # Apply TTA with this model
+                    if self.apply_tta:
+                        model_probs = self._apply_tta_single_model(model, single_sample)
+                    # Run inference without TTA
+                    else:
+                        predictions = self._run_single_model_inference(model, single_sample)
+                        model_probs = torch.softmax(predictions[0], dim=0)
+                        del predictions
 
-                    #
+                    # Accumulate probabilities from this model
+                    acc_probs += model_probs
+                    del model_probs
                     torch.cuda.empty_cache()
-
-                    # Run inference (one sample at a time)
-                    predictions = sliding_window_inference(
-                        inputs=aug_sample,
-                        roi_size=(self.dim_in, self.dim_in, self.dim_in),
-                        sw_batch_size=4,  # Process one window at a time
-                        predictor=self.model,
-                        overlap=0.5,
-                    )
-
-                    # Get softmax probabilities
-                    probs = self.softmax_transform(predictions[0])  # Get first (only) item
-
-                    # Apply inverse transform with correct dimensions
-                    inv_probs = inverse_transform(probs)
-
-                    # Accumulate probabilities
-                    acc_probs[0] += inv_probs
-
-                    # Clear memory
-                    del predictions, probs, inv_probs, aug_sample
 
-            # Average
-            acc_probs = acc_probs / len(self.
+            # Average probabilities across models (and TTA augmentations if applied)
+            acc_probs = acc_probs / len(self.models)
 
-            # Convert to discrete prediction
-
-
-
-            # Add to results - keeping only the spatial dimensions [Z, Y, X]
-            results.append(discrete_pred[0])
+            # Convert to discrete prediction
+            discrete_pred = torch.argmax(acc_probs, dim=0)
+            results.append(discrete_pred)
 
             # Clear memory
             del acc_probs, discrete_pred
@@ -157,38 +190,37 @@ class Predictor:
     def predict_on_gpu(self,
                        runIDs: List[str],
                        voxel_spacing: float,
-                       tomo_algorithm: str
+                       tomo_algorithm: str):
 
         # Load data for the current batch
-        test_loader, test_dataset =
+        test_loader, test_dataset = dataio.create_predict_dataloader(
            self.root,
            voxel_spacing, tomo_algorithm,
            runIDs)
 
        # Determine Input Crop Size.
        if self.input_dim is None:
-            self.input_dim =
+            self.input_dim = dataio.get_input_dimensions(test_dataset, self.dim_in)
 
        predictions = []
        with torch.no_grad():
            for data in tqdm(test_loader):
                tomogram = data['image'].to(self.device)
-
-
+                data['pred'] = self._run_inference(tomogram)
+
                for idx in range(len(data['image'])):
                    predictions.append(data['pred'][idx].squeeze(0).numpy(force=True))
 
        return predictions
 
    def batch_predict(self,
-                      num_tomos_per_batch = 15,
-                      runIDs: Optional[str] = None,
+                      num_tomos_per_batch: int = 15,
+                      runIDs: Optional[List[str]] = None,
                      voxel_spacing: float = 10,
                      tomo_algorithm: str = 'denoised',
                      segmentation_name: str = 'prediction',
                      segmentation_user_id: str = 'octopi',
                      segmentation_session_id: str = '0'):
-
        """Run inference on tomograms in batches."""
 
        # If runIDs are not provided, load all runs
@@ -201,10 +233,9 @@ class Predictor:
 
        # Iterate over batches of runIDs
        for i in range(0, len(runIDs), num_tomos_per_batch):
-
            # Get a batch of runIDs
            batch_ids = runIDs[i:i + num_tomos_per_batch]
-            print('Running Inference on the
+            print('Running Inference on the Following RunIDs: ', batch_ids)
 
            predictions = self.predict_on_gpu(batch_ids, voxel_spacing, tomo_algorithm)
 
@@ -212,7 +243,7 @@ class Predictor:
            for ind in range(len(batch_ids)):
                run = self.root.get_run(batch_ids[ind])
                seg = predictions[ind]
-
+                writers.segmentation(run, seg, segmentation_user_id, segmentation_name,
                                     segmentation_session_id, voxel_spacing)
 
            # After processing and saving predictions for a batch:
@@ -224,98 +255,18 @@ class Predictor:
 
    def create_tta_augmentations(self):
        """Define TTA augmentations and inverse transforms."""
-
-        # Instead of Flip lets rotate around the first axis 3 times (90,180,270)
+        # Rotate around the YZ plane (dims 3,4 for input, dims 2,3 for output)
        self.tta_transforms = [
-            lambda x: x,
-            lambda x: torch.rot90(x, k=1, dims=(3, 4)),
-            lambda x: torch.rot90(x, k=2, dims=(3, 4)),
-            lambda x: torch.rot90(x, k=3, dims=(3, 4)),
-            # lambda x: torch.flip(x, dims=(3,)),  # Flip along height (spatial_axis=1)
-            # lambda x: torch.flip(x, dims=(4,)),  # Flip along width (spatial_axis=2)
-            # lambda x: torch.flip(x, dims=(3, 4)),  # Flip along both height and width
+            lambda x: x,  # Identity (no augmentation)
+            lambda x: torch.rot90(x, k=1, dims=(3, 4)),  # 90° rotation
+            lambda x: torch.rot90(x, k=2, dims=(3, 4)),  # 180° rotation
+            lambda x: torch.rot90(x, k=3, dims=(3, 4)),  # 270° rotation
        ]
 
        # Define inverse transformations (flip back to original orientation)
        self.inverse_tta_transforms = [
-            lambda x: x,
+            lambda x: x,  # Identity (no transformation needed)
            lambda x: torch.rot90(x, k=-1, dims=(2, 3)),  # Inverse of 90° (i.e. -90°)
            lambda x: torch.rot90(x, k=-2, dims=(2, 3)),  # Inverse of 180° (i.e. -180°)
            lambda x: torch.rot90(x, k=-3, dims=(2, 3)),  # Inverse of 270° (i.e. -270°)
-
-            # lambda x: torch.flip(x, dims=(3,)),  # Same as forward
-            # lambda x: torch.flip(x, dims=(2, 3)),  # Same as forward
-        ]
-
-###################################################################################################################################################
-
-class MultiGPUPredictor(Predictor):
-
-    def __init__(self,
-                 config: str,
-                 model_config: str,
-                 model_weights: str):
-        super().__init__(config, model_config, model_weights)
-        self.num_gpus = torch.cuda.device_count()
-        if self.num_gpus < 2:
-            raise RuntimeError("MultiGPUPredictor requires at least 2 GPUs.")
-
-    def predict_on_gpu(self, gpu_id: int, batch_ids: List[str], voxel_spacing: float, tomo_algorithm: str) -> List[np.ndarray]:
-        """Helper function to run inference on a single GPU."""
-        device = torch.device(f'cuda:{gpu_id}')
-        self.model.to(device)
-
-        # Load data specific to the batch assigned to this GPU
-        test_loader = io.load_predict_data(self.root, batch_ids, voxel_spacing, tomo_algorithm)
-        predictions = []
-
-        with torch.no_grad():
-            for data in tqdm(test_loader, desc=f"GPU {gpu_id}"):
-                tomogram = data['image'].to(device)
-                data["prediction"] = self.run_inference(tomogram)
-                data = [self.post_processing(i) for i in decollate_batch(data)]
-                for b in data:
-                    predictions.append(b['prediction'].squeeze(0).cpu().numpy())
-
-        return predictions
-
-    def multi_gpu_inference(self,
-                            num_tomos_per_batch: int = 15,
-                            runIDs: Optional[List[str]] = None,
-                            voxel_spacing: float = 10,
-                            tomo_algorithm: str = 'denoised',
-                            save: bool = False,
-                            segmentation_name: str = 'prediction',
-                            segmentation_user_id: str = 'monai',
-                            segmentation_session_id: str = '0') -> Optional[List[np.ndarray]]:
-        """Run inference across multiple GPUs, optionally saving results or returning predictions."""
-
-        runIDs = runIDs or [run.name for run in self.root.runs]
-        all_predictions = []
-
-        # Divide runIDs into batches for each GPU
-        batches = [runIDs[i:i + num_tomos_per_batch] for i in range(0, len(run_ids), num_tomos_per_batch)]
-
-        # Run inference in parallel across GPUs
-        for i in range(0, len(batches), self.num_gpus):
-            gpu_batches = batches[i:i + self.num_gpus]
-            with Pool(processes=self.num_gpus) as pool:
-                results = pool.starmap(
-                    self.predict_on_gpu,
-                    [(gpu_id, gpu_batches[gpu_id], voxel_spacing, tomo_algorithm) for gpu_id in range(len(gpu_batches))]
-                )
-
-            # Collect and save results
-            for gpu_id, predictions in enumerate(results):
-                if save:
-                    for idx, run_id in enumerate(gpu_batches[gpu_id]):
-                        run = self.root.get_run(run_id)
-                        segmentation = predictions[idx]
-                        write.segmentation(run, segmentation, segmentation_user_id, segmentation_name,
-                                           segmentation_session_id, voxel_spacing)
-                else:
-                    all_predictions.extend(predictions)
-
-        print('Multi-GPU predictions complete.')
-
-        return None if save else all_predictions
+        ]
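The reworked Predictor accepts either a single checkpoint or a list of checkpoints (model soup / ensemble averaging) and optionally averages softmax probabilities over rotation-based TTA. A minimal usage sketch based on the signatures visible in the diff above; the file paths are placeholders, and the keyword defaults shown are those in the diff:

    from octopi.pytorch.segmentation import Predictor

    # Paths below are placeholders for illustration.
    predictor = Predictor(
        config='copick_config.json',                      # Copick project config
        model_config=['model_a.yaml', 'model_b.yaml'],    # one YAML per checkpoint
        model_weights=['model_a.pth', 'model_b.pth'],     # >1 checkpoint enables model soup
        apply_tta=True,                                   # average over 90/180/270 degree rotations
    )

    # Segment tomograms in batches and write results back to the Copick project.
    predictor.batch_predict(
        num_tomos_per_batch=15,
        voxel_spacing=10,
        tomo_algorithm='denoised',
        segmentation_name='prediction',
        segmentation_user_id='octopi',
        segmentation_session_id='0',
    )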
octopi/pytorch/segmentation_multigpu.py
ADDED

@@ -0,0 +1,162 @@
+from concurrent.futures import ThreadPoolExecutor
+from octopi.pytorch.segmentation import Predictor
+from typing import List, Union, Optional
+from copick_utils.io import writers
+import queue, torch
+
+class MultiGPUPredictor(Predictor):
+
+    def __init__(self,
+                 config: str,
+                 model_config: Union[str, List[str]],
+                 model_weights: Union[str, List[str]],
+                 apply_tta: bool = True):
+
+        # Initialize parent normally
+        super().__init__(config, model_config, model_weights, apply_tta)
+
+        self.num_gpus = torch.cuda.device_count()
+        print(f"Available GPUs: {self.num_gpus}")
+
+        # Only create GPU-specific models if we have multiple GPUs
+        if self.num_gpus > 1:
+            self._create_gpu_models()
+
+    def _create_gpu_models(self):
+        """Create separate model instances for each GPU."""
+        self.gpu_models = {}
+
+        for gpu_id in range(self.num_gpus):
+            device = torch.device(f'cuda:{gpu_id}')
+            gpu_models = []
+
+            # Copy each model to this GPU
+            for model in self.models:
+                gpu_model = type(model)()
+                gpu_model.load_state_dict(model.state_dict())
+                gpu_model.to(device)
+                gpu_model.eval()
+                gpu_models.append(gpu_model)
+
+            self.gpu_models[gpu_id] = gpu_models
+            print(f"Models loaded on GPU {gpu_id}")
+
+    def _run_on_gpu(self, gpu_id: int, batch_ids: List[str],
+                    voxel_spacing: float, tomo_algorithm: str,
+                    segmentation_name: str, segmentation_user_id: str,
+                    segmentation_session_id: str):
+        """Run inference on a specific GPU for a batch of runs."""
+        device = torch.device(f'cuda:{gpu_id}')
+
+        # Temporarily switch to this GPU's models and device
+        original_device = self.device
+        original_models = self.models
+
+        self.device = device
+        self.models = self.gpu_models[gpu_id]
+
+        try:
+            print(f"GPU {gpu_id} processing runs: {batch_ids}")
+
+            # Run prediction using parent class method
+            predictions = self.predict_on_gpu(batch_ids, voxel_spacing, tomo_algorithm)
+
+            # Save predictions
+            for idx, run_id in enumerate(batch_ids):
+                run = self.root.get_run(run_id)
+                seg = predictions[idx]
+                writers.segmentation(run, seg, segmentation_user_id,
+                                     segmentation_name, segmentation_session_id,
+                                     voxel_spacing)
+
+            # Clean up
+            del predictions
+            torch.cuda.empty_cache()
+
+        finally:
+            # Restore original settings
+            self.device = original_device
+            self.models = original_models
+
+    def multigpu_batch_predict(self,
+                               num_tomos_per_batch: int = 15,
+                               runIDs: Optional[List[str]] = None,
+                               voxel_spacing: float = 10,
+                               tomo_algorithm: str = 'denoised',
+                               segmentation_name: str = 'prediction',
+                               segmentation_user_id: str = 'octopi',
+                               segmentation_session_id: str = '0'):
+        """Run inference across multiple GPUs using threading."""
+
+        # Get runIDs if not provided
+        if runIDs is None:
+            runIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is not None]
+            skippedRunIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is None]
+            if skippedRunIDs:
+                print(f"Warning: skipping runs with no voxel spacing {voxel_spacing}: {skippedRunIDs}")
+
+        # Split runIDs into batches
+        batches = [runIDs[i:i + num_tomos_per_batch]
+                   for i in range(0, len(runIDs), num_tomos_per_batch)]
+
+        print(f"Processing {len(batches)} batches across {self.num_gpus} GPUs")
+
+        # Create work queue
+        batch_queue = queue.Queue()
+        for batch in batches:
+            batch_queue.put(batch)
+
+        def worker(gpu_id):
+            while True:
+                try:
+                    batch_ids = batch_queue.get_nowait()
+                    self._run_on_gpu(gpu_id, batch_ids, voxel_spacing, tomo_algorithm,
+                                     segmentation_name, segmentation_user_id,
+                                     segmentation_session_id)
+                    batch_queue.task_done()
+                except queue.Empty:
+                    break
+                except Exception as e:
+                    print(f"Error on GPU {gpu_id}: {e}")
+                    batch_queue.task_done()
+
+        # Start worker threads for each GPU
+        with ThreadPoolExecutor(max_workers=self.num_gpus) as executor:
+            futures = [executor.submit(worker, gpu_id) for gpu_id in range(self.num_gpus)]
+            for future in futures:
+                future.result()
+
+        print('Multi-GPU predictions complete!')
+
+    def batch_predict(self,
+                      num_tomos_per_batch: int = 15,
+                      runIDs: Optional[List[str]] = None,
+                      voxel_spacing: float = 10,
+                      tomo_algorithm: str = 'denoised',
+                      segmentation_name: str = 'prediction',
+                      segmentation_user_id: str = 'octopi',
+                      segmentation_session_id: str = '0'):
+        """Smart batch predict: uses multi-GPU if available, otherwise single GPU."""
+
+        if self.num_gpus > 1:
+            print("Using multi-GPU inference")
+            self.multigpu_batch_predict(
+                num_tomos_per_batch=num_tomos_per_batch,
+                runIDs=runIDs,
+                voxel_spacing=voxel_spacing,
+                tomo_algorithm=tomo_algorithm,
+                segmentation_name=segmentation_name,
+                segmentation_user_id=segmentation_user_id,
+                segmentation_session_id=segmentation_session_id
+            )
+        else:
+            print("Using single GPU inference")
+            super().batch_predict(
+                num_tomos_per_batch=num_tomos_per_batch,
+                runIDs=runIDs,
+                voxel_spacing=voxel_spacing,
+                tomo_algorithm=tomo_algorithm,
+                segmentation_name=segmentation_name,
+                segmentation_user_id=segmentation_user_id,
+                segmentation_session_id=segmentation_session_id
+            )
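The new MultiGPUPredictor subclasses Predictor, replicates the loaded model(s) onto every visible GPU, and drains a shared queue of run-ID batches from one worker thread per GPU; its batch_predict falls back to the single-GPU path when only one device is present. A minimal sketch using the constructor and method shown above (paths are placeholders):

    from octopi.pytorch.segmentation_multigpu import MultiGPUPredictor

    predictor = MultiGPUPredictor(
        config='copick_config.json',   # placeholder path
        model_config='model.yaml',
        model_weights='model.pth',
        apply_tta=True,
    )

    # Dispatches batches of 15 runs across all available GPUs;
    # with a single GPU this delegates to Predictor.batch_predict.
    predictor.batch_predict(num_tomos_per_batch=15,
                            voxel_spacing=10,
                            tomo_algorithm='denoised')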
octopi/pytorch/trainer.py
CHANGED
@@ -1,6 +1,6 @@
-from octopi import visualization_tools as viz
+from octopi.utils import visualization_tools as viz
 from monai.inferers import sliding_window_inference
-from octopi import stopping_criteria
+from octopi.utils import stopping_criteria
 from monai.transforms import AsDiscrete
 from monai.data import decollate_batch
 import torch, os, mlflow, re
octopi/utils/__init__.py
ADDED
File without changes
octopi/utils/config.py
ADDED
@@ -0,0 +1,57 @@
+"""
+Configuration utilities for MLflow setup and reproducibility.
+"""
+
+from dotenv import load_dotenv
+import torch, numpy as np
+import os, random
+import octopi
+
+
+def mlflow_setup():
+    """
+    Set up MLflow configuration from environment variables.
+    """
+    module_root = os.path.dirname(octopi.__file__)
+    dotenv_path = module_root.replace('src/octopi','') + '.env'
+    load_dotenv(dotenv_path=dotenv_path)
+
+    # MLflow setup
+    username = os.getenv('MLFLOW_TRACKING_USERNAME')
+    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
+    if not password or not username:
+        print("Password not found in environment, loading from .env file...")
+        load_dotenv()  # Loads environment variables from a .env file
+        username = os.getenv('MLFLOW_TRACKING_USERNAME')
+        password = os.getenv('MLFLOW_TRACKING_PASSWORD')
+
+        # Check again after loading .env file
+        if not password:
+            raise ValueError("Password is not set in environment variables or .env file!")
+        else:
+            print("Password loaded successfully")
+    os.environ['MLFLOW_TRACKING_USERNAME'] = username
+    os.environ['MLFLOW_TRACKING_PASSWORD'] = password
+
+    return os.getenv('MLFLOW_TRACKING_URI')
+
+
+def set_seed(seed):
+    """
+    Set random seeds for reproducibility across Python, NumPy, and PyTorch.
+    """
+    # Set the seed for Python's random module
+    random.seed(seed)
+
+    # Set the seed for NumPy
+    np.random.seed(seed)
+
+    # Set the seed for PyTorch (both CPU and GPU)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)  # If using multi-GPU
+
+    # Ensure reproducibility of operations by disabling certain optimizations
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
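The new octopi.utils.config module centralizes MLflow credential loading and random seeding. A short sketch of how these helpers might be called at the start of a training script, assuming the MLFLOW_TRACKING_* variables are supplied via the environment or a .env file:

    import mlflow
    from octopi.utils import config

    # Make runs reproducible across random, NumPy and PyTorch (CPU + GPU).
    config.set_seed(42)

    # Reads MLFLOW_TRACKING_USERNAME/PASSWORD (from the environment or .env)
    # and returns the tracking URI, if one is configured.
    tracking_uri = config.mlflow_setup()
    if tracking_uri:
        mlflow.set_tracking_uri(tracking_uri)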