octopi-1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of octopi might be problematic.
- octopi/__init__.py +0 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +84 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +429 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +253 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +80 -0
- octopi/entry_points/create_slurm_submission.py +243 -0
- octopi/entry_points/run_create_targets.py +281 -0
- octopi/entry_points/run_evaluate.py +65 -0
- octopi/entry_points/run_extract_mb_picks.py +141 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +222 -0
- octopi/entry_points/run_optuna.py +139 -0
- octopi/entry_points/run_segment_predict.py +166 -0
- octopi/entry_points/run_train.py +201 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +254 -0
- octopi/extract/membranebound_extract.py +262 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/io.py +457 -0
- octopi/losses.py +86 -0
- octopi/main.py +101 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +62 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +106 -0
- octopi/processing/downsample.py +129 -0
- octopi/processing/evaluate.py +289 -0
- octopi/processing/importers.py +213 -0
- octopi/processing/my_metrics.py +26 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/processing/writers.py +102 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +243 -0
- octopi/pytorch/model_search_submitter.py +290 -0
- octopi/pytorch/segmentation.py +317 -0
- octopi/pytorch/trainer.py +438 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/stopping_criteria.py +143 -0
- octopi/submit_slurm.py +95 -0
- octopi/utils.py +238 -0
- octopi/visualization_tools.py +201 -0
- octopi-1.0.dist-info/LICENSE +41 -0
- octopi-1.0.dist-info/METADATA +209 -0
- octopi-1.0.dist-info/RECORD +59 -0
- octopi-1.0.dist-info/WHEEL +4 -0
- octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/pytorch/model_search_submitter.py
@@ -0,0 +1,290 @@
+from octopi.datasets import generators, multi_config_generator
+from octopi.pytorch import hyper_search
+import torch, mlflow, optuna
+from octopi import utils
+from typing import List
+import pandas as pd
+
+class ModelSearchSubmit:
+    def __init__(
+        self,
+        copick_config: str,
+        target_name: str,
+        target_user_id: str,
+        target_session_id: str,
+        tomo_algorithm: str,
+        voxel_size: float,
+        Nclass: int,
+        model_type: str,
+        mlflow_experiment_name: str,
+        random_seed: int,
+        num_epochs: int,
+        num_trials: int,
+        tomo_batch_size: int,
+        best_metric: str,
+        val_interval: int,
+        trainRunIDs: List[str],
+        validateRunIDs: List[str],
+        data_split: str
+    ):
+        """
+        Initialize the ModelSearchSubmit class for architecture search with Optuna.
+
+        Args:
+            copick_config (str or dict): Path to the CoPick configuration file, or a dictionary of configs for multi-config training.
+            target_name (str): Name of the segmentation target.
+            target_user_id (str): Optional user ID for tracking.
+            target_session_id (str): Optional session ID for tracking.
+            tomo_algorithm (str): Tomogram reconstruction algorithm to use.
+            voxel_size (float): Voxel size of the tomograms.
+            Nclass (int): Number of prediction classes.
+            model_type (str): Type of model to search over.
+            mlflow_experiment_name (str): MLflow experiment name.
+            random_seed (int): Seed for reproducibility.
+            num_epochs (int): Number of epochs per trial.
+            num_trials (int): Number of trials for hyperparameter optimization.
+            tomo_batch_size (int): Batch size for tomogram loading.
+            best_metric (str): Metric to optimize.
+            val_interval (int): Validation interval.
+            trainRunIDs (List[str]): List of training run IDs.
+            validateRunIDs (List[str]): List of validation run IDs.
+            data_split (str): Data split ratios.
+        """
+
+        # Input parameters
+        self.copick_config = copick_config
+        self.target_name = target_name
+        self.target_user_id = target_user_id
+        self.target_session_id = target_session_id
+        self.tomo_algorithm = tomo_algorithm
+        self.voxel_size = voxel_size
+        self.Nclass = Nclass
+        self.model_type = model_type
+        self.mlflow_experiment_name = mlflow_experiment_name
+        self.random_seed = random_seed
+        self.num_epochs = num_epochs
+        self.num_trials = num_trials
+        self.tomo_batch_size = tomo_batch_size
+        self.best_metric = best_metric
+        self.val_interval = val_interval
+        self.trainRunIDs = trainRunIDs
+        self.validateRunIDs = validateRunIDs
+        self.data_split = data_split
+
+        # Data generator - will be initialized in _initialize_data_generator()
+        self.data_generator = None
+
+        # Set random seed for reproducibility
+        utils.set_seed(self.random_seed)
+
+        # Initialize dataset generator
+        self._initialize_data_generator()
+
+    def _initialize_data_generator(self):
+        """Initializes the data generator for the training and validation datasets."""
+        self._print_input_configs()
+
+        if isinstance(self.copick_config, dict):
+            self.data_generator = multi_config_generator.MultiConfigTrainLoaderManager(
+                self.copick_config,
+                self.target_name,
+                target_session_id=self.target_session_id,
+                target_user_id=self.target_user_id,
+                tomo_algorithm=self.tomo_algorithm,
+                voxel_size=self.voxel_size,
+                Nclasses=self.Nclass,
+                tomo_batch_size=self.tomo_batch_size
+            )
+        else:
+            self.data_generator = generators.TrainLoaderManager(
+                self.copick_config,
+                self.target_name,
+                target_session_id=self.target_session_id,
+                target_user_id=self.target_user_id,
+                tomo_algorithm=self.tomo_algorithm,
+                voxel_size=self.voxel_size,
+                Nclasses=self.Nclass,
+                tomo_batch_size=self.tomo_batch_size
+            )
+
+        # Split the datasets into training and validation
+        ratios = utils.parse_data_split(self.data_split)
+        self.data_generator.get_data_splits(
+            trainRunIDs=self.trainRunIDs,
+            validateRunIDs=self.validateRunIDs,
+            train_ratio=ratios[0], val_ratio=ratios[1], test_ratio=ratios[2],
+            create_test_dataset=False
+        )
+
+        # Determine how often tomograms need to be reloaded
+        self.data_generator.get_reload_frequency(self.num_epochs)
+
+    def _print_input_configs(self):
+        """Prints the training configuration for debugging purposes."""
+        print('\nTraining with:')
+        if isinstance(self.copick_config, dict):
+            for session, config in self.copick_config.items():
+                print(f'  {session}: {config}')
+        else:
+            print(f'  {self.copick_config}')
+        print()
+
+    def run_model_search(self):
+        """Performs model architecture search using Optuna and MLflow."""
+
+        # Set up MLflow tracking
+        try:
+            tracking_uri = utils.mlflow_setup()
+            mlflow.set_tracking_uri(tracking_uri)
+        except Exception as e:
+            print(f'Failed to set up MLflow tracking: {e}')
+
+        mlflow.set_experiment(self.mlflow_experiment_name)
+
+        # Create a storage object with heartbeat configuration
+        storage_url = f"sqlite:///explore_results_{self.model_type}/trials.db"
+        self.storage = optuna.storages.RDBStorage(
+            url=storage_url,
+            heartbeat_interval=60,   # Record a heartbeat every minute
+            grace_period=600,        # 10-minute grace period
+            failed_trial_callback=optuna.storages.RetryFailedTrialCallback(max_retry=1)
+        )
+
+        # Detect GPU availability
+        gpu_count = torch.cuda.device_count()
+        print(f'Running Architecture Search Over {gpu_count} GPUs\n')
+
+        # Initialize the model search object
+        if gpu_count > 1:
+            self._multi_gpu_optuna(gpu_count)
+        else:
+            model_search = hyper_search.BayesianModelSearch(self.data_generator, self.model_type)
+            self._single_gpu_optuna(model_search)
+
+    def _single_gpu_optuna(self, model_search):
+        """Runs Optuna optimization on a single GPU."""
+
+        with mlflow.start_run(nested=False) as parent_run:
+
+            model_search.parent_run_id = parent_run.info.run_id
+            model_search.parent_run_name = parent_run.info.run_name
+
+            # Log the experiment parameters
+            mlflow.log_params({"random_seed": self.random_seed})
+            mlflow.log_params(self.data_generator.get_dataloader_parameters())
+            mlflow.log_params({"parent_run_name": parent_run.info.run_name})
+
+            # Determine the device
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+            # Create the study and run the optimization
+            study = self.get_optuna_study()
+            study.optimize(
+                lambda trial: model_search.objective(
+                    trial, self.num_epochs, device,
+                    val_interval=self.val_interval,
+                    best_metric=self.best_metric,
+                ),
+                n_trials=self.num_trials
+            )
+
+            # Save the contour plot
+            self.save_contour_plot_as_png(study)
+
+            print(f"Best trial: {study.best_trial.value}")
+            print(f"Best params: {study.best_params}")
+
+    def _multi_gpu_optuna(self, gpu_count):
+        """Runs Optuna optimization on multiple GPUs."""
+        with mlflow.start_run() as parent_run:
+
+            # Log the experiment parameters
+            mlflow.log_params({"random_seed": self.random_seed})
+            mlflow.log_params(self.data_generator.get_dataloader_parameters())
+            mlflow.log_params({"parent_run_name": parent_run.info.run_name})
+
+            # Run multi-GPU optimization
+            study = self.get_optuna_study()
+            study.optimize(
+                lambda trial: hyper_search.BayesianModelSearch(self.data_generator, self.model_type).multi_gpu_objective(
+                    parent_run, trial,
+                    self.num_epochs,
+                    best_metric=self.best_metric,
+                    val_interval=self.val_interval,
+                    gpu_count=gpu_count
+                ),
+                n_trials=self.num_trials,
+                n_jobs=gpu_count
+            )
+
+            # Save the contour plot
+            self.save_contour_plot_as_png(study)
+
+            print(f"Best trial: {study.best_trial.value}")
+            print(f"Best params: {study.best_params}")
+
+    def _get_optuna_sampler(self):
+        """Returns Optuna's TPE sampler with default settings."""
+        return optuna.samplers.TPESampler(
+            n_startup_trials=10,
+            n_ei_candidates=24,
+            multivariate=True
+        )
+        # return optuna.samplers.BoTorchSampler(
+        #     n_startup_trials=10,
+        #     multivariate=True
+        # )
+
+    def get_optuna_study(self):
+        """Creates (or loads) and returns the Optuna study object."""
+        return optuna.create_study(
+            storage=self.storage,
+            direction="maximize",
+            sampler=self._get_optuna_sampler(),
+            load_if_exists=True,
+            pruner=self._get_optuna_pruner()
+        )
+
+    def _get_optuna_pruner(self):
+        """Returns Optuna's pruning strategy."""
+        return optuna.pruners.MedianPruner()
+
+    def save_contour_plot_as_png(self, study):
+        """
+        Save the contour plot of hyperparameter interactions as a PNG,
+        automatically extracting the parameter names from the study object.
+
+        Args:
+            study: The Optuna study object.
+        """
+        # Extract all parameter names from the study trials
+        all_params = set()
+        for trial in study.trials:
+            all_params.update(trial.params.keys())
+        all_params = list(all_params)  # Convert to a list for plotting
+
+        # Generate the contour plot
+        fig = optuna.visualization.plot_contour(study, params=all_params)
+
+        # Adjust the figure size and font size
+        fig.update_layout(
+            width=6000, height=6000,  # Large figure size
+            font=dict(size=40)        # Increase font size for better readability
+        )
+
+        # Save the plot as a PNG file
+        fig.write_image(f'explore_results_{self.model_type}/contour_plot.png', scale=1)
+
+        # Extract trial data
+        trials = [
+            {**trial.params, 'objective_value': trial.value}
+            for trial in study.trials if trial.state == optuna.trial.TrialState.COMPLETE
+        ]
+
+        # Convert to a DataFrame and save to CSV
+        df = pd.DataFrame(trials)
+        df.to_csv(f"explore_results_{self.model_type}/optuna_results.csv", index=False)
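The hunk above (290 added lines, matching octopi/pytorch/model_search_submitter.py in the file list) defines the ModelSearchSubmit driver for Optuna-based architecture search. Below is a rough usage sketch only, not documentation shipped with the wheel: the import path follows the file layout above, every argument value is a placeholder, and the best_metric name and data_split format are guesses rather than confirmed options.

    # Hypothetical usage sketch for ModelSearchSubmit; all values are placeholders.
    from octopi.pytorch.model_search_submitter import ModelSearchSubmit

    searcher = ModelSearchSubmit(
        copick_config='copick_config.json',   # path to a CoPick project config (placeholder)
        target_name='organelle',              # placeholder segmentation target
        target_user_id='octopi',
        target_session_id='1',
        tomo_algorithm='denoised',
        voxel_size=10.0,
        Nclass=3,
        model_type='Unet',                    # one of the architectures under octopi/models/
        mlflow_experiment_name='octopi-model-search',
        random_seed=42,
        num_epochs=50,
        num_trials=25,
        tomo_batch_size=15,
        best_metric='fBeta',                  # metric name is a guess; accepted names live in hyper_search
        val_interval=5,
        trainRunIDs=None,                     # let get_data_splits derive the split from data_split
        validateRunIDs=None,
        data_split='0.8,0.1,0.1',             # format assumed from utils.parse_data_split
    )
    searcher.run_model_search()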

octopi/pytorch/segmentation.py
@@ -0,0 +1,317 @@
+from monai.inferers import sliding_window_inference
+from monai.data import decollate_batch
+from torch.multiprocessing import Pool
+from monai.data import MetaTensor
+from monai.transforms import (
+    Compose, AsDiscrete, Activations
+)
+import octopi.processing.writers as write
+from octopi.models import common
+from typing import List, Optional
+import torch, copick, gc, os
+from octopi import io, utils
+from tqdm import tqdm
+import numpy as np
+
+class Predictor:
+
+    def __init__(self,
+                 config: str,
+                 model_config: str,
+                 model_weights: str,
+                 apply_tta: bool = True,
+                 device: Optional[str] = None):
+
+        self.config = config
+        self.root = copick.from_file(config)
+
+        # Load the model config
+        model_config = utils.load_yaml(model_config)
+
+        self.Nclass = model_config['model']['num_classes']
+        self.dim_in = model_config['model']['dim_in']
+        self.input_dim = None
+
+        # Get the number of GPUs available
+        num_gpus = torch.cuda.device_count()
+        if num_gpus == 0:
+            raise RuntimeError("No GPUs available.")
+
+        # Set the device
+        if device is None:
+            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        else:
+            self.device = device
+        print('Running Inference On: ', self.device)
+
+        # Check that the model weights file exists
+        if not os.path.exists(model_weights):
+            raise ValueError(f"Model weights file does not exist: {model_weights}")
+
+        # Load the model weights
+        model_builder = common.get_model(model_config['model']['architecture'])
+        model_builder.build_model(model_config['model'])
+        self.model = model_builder.model
+        state_dict = torch.load(model_weights, map_location=self.device, weights_only=True)
+        self.model.load_state_dict(state_dict)
+        self.model.to(self.device)
+        self.model.eval()
+
+        # Initialize TTA if enabled
+        self.apply_tta = apply_tta
+        if self.apply_tta:
+            self.create_tta_augmentations()
+            # self.post_transforms = Compose([
+            #     Activations(softmax=True)  # Keep probability output
+            # ])
+            self.softmax_transform = Compose([
+                Activations(softmax=True)  # Keep probability output
+            ])
+
+            # Create the final discretization transform
+            self.discretize_transform = AsDiscrete(argmax=True)
+        else:
+            # Define the post-processing transforms
+            self.post_transforms = Compose([
+                Activations(softmax=True),
+                AsDiscrete(argmax=True)
+            ])
+
+    def _run_inference(self, input):
+        """Apply sliding-window inference to the input."""
+        with torch.no_grad():
+            predictions = sliding_window_inference(
+                inputs=input,
+                roi_size=(self.dim_in, self.dim_in, self.dim_in),
+                sw_batch_size=4,   # number of sliding windows processed at a time
+                predictor=self.model,
+                overlap=0.5,
+            )
+        return [self.post_transforms(i) for i in decollate_batch(predictions)]
+
+    def _run_inference_tta(self, input_data):
+        """Memory-efficient TTA implementation that returns discrete segmentation maps."""
+
+        batch_size = input_data.shape[0]
+        results = []
+
+        # Process one sample at a time
+        for sample_idx in range(batch_size):
+            # Extract a single sample
+            single_sample = input_data[sample_idx:sample_idx+1]
+
+            # Initialize the probability accumulator for this sample
+            # Shape: [1, Nclass, Z, Y, X]
+            acc_probs = torch.zeros(
+                (1, self.Nclass, *single_sample.shape[2:]),
+                dtype=torch.float32, device=self.device
+            )
+
+            # Process each augmentation
+            with torch.no_grad():
+                for tta_transform, inverse_transform in zip(self.tta_transforms, self.inverse_tta_transforms):
+                    # Apply the transform to the single sample
+                    aug_sample = tta_transform(single_sample)
+
+                    # Free memory
+                    torch.cuda.empty_cache()
+
+                    # Run inference (one sample at a time)
+                    predictions = sliding_window_inference(
+                        inputs=aug_sample,
+                        roi_size=(self.dim_in, self.dim_in, self.dim_in),
+                        sw_batch_size=4,   # number of sliding windows processed at a time
+                        predictor=self.model,
+                        overlap=0.5,
+                    )
+
+                    # Get softmax probabilities
+                    probs = self.softmax_transform(predictions[0])  # Get the first (only) item
+
+                    # Apply the inverse transform to map back to the original orientation
+                    inv_probs = inverse_transform(probs)
+
+                    # Accumulate probabilities
+                    acc_probs[0] += inv_probs
+
+                    # Clear memory
+                    del predictions, probs, inv_probs, aug_sample
+
+            # Average the accumulated probabilities
+            acc_probs = acc_probs / len(self.tta_transforms)
+
+            # Convert to a discrete prediction: argmax along the class dimension
+            # yields a tensor of shape [1, Z, Y, X] with discrete class indices
+            discrete_pred = torch.argmax(acc_probs, dim=1)
+
+            # Add to results - keeping only the spatial dimensions [Z, Y, X]
+            results.append(discrete_pred[0])
+
+            # Clear memory
+            del acc_probs, discrete_pred
+            torch.cuda.empty_cache()
+
+        return results
+
+    def predict_on_gpu(self,
+                       runIDs: List[str],
+                       voxel_spacing: float,
+                       tomo_algorithm: str):
+
+        # Load data for the current batch
+        test_loader, test_dataset = io.create_predict_dataloader(
+            self.root,
+            voxel_spacing, tomo_algorithm,
+            runIDs)
+
+        # Determine the input crop size
+        if self.input_dim is None:
+            self.input_dim = io.get_input_dimensions(test_dataset, self.dim_in)
+
+        predictions = []
+        with torch.no_grad():
+            for data in tqdm(test_loader):
+                tomogram = data['image'].to(self.device)
+                if self.apply_tta:
+                    data['pred'] = self._run_inference_tta(tomogram)
+                else:
+                    data['pred'] = self._run_inference(tomogram)
+                for idx in range(len(data['image'])):
+                    predictions.append(data['pred'][idx].squeeze(0).numpy(force=True))
+
+        return predictions
+
+    def batch_predict(self,
+                      num_tomos_per_batch: int = 15,
+                      runIDs: Optional[List[str]] = None,
+                      voxel_spacing: float = 10,
+                      tomo_algorithm: str = 'denoised',
+                      segmentation_name: str = 'prediction',
+                      segmentation_user_id: str = 'octopi',
+                      segmentation_session_id: str = '0'):
+
+        """Run inference on tomograms in batches."""
+
+        # If runIDs are not provided, load all runs
+        if runIDs is None:
+            runIDs = [run.name for run in self.root.runs]
+
+        # Iterate over batches of runIDs
+        for i in range(0, len(runIDs), num_tomos_per_batch):
+
+            # Get a batch of runIDs
+            batch_ids = runIDs[i:i + num_tomos_per_batch]
+            print('Running Inference on the Following RunIDs: ', batch_ids)
+
+            predictions = self.predict_on_gpu(batch_ids, voxel_spacing, tomo_algorithm)
+
+            # Save predictions to the corresponding runIDs
+            for ind in range(len(batch_ids)):
+                run = self.root.get_run(batch_ids[ind])
+                seg = predictions[ind]
+                write.segmentation(run, seg, segmentation_user_id, segmentation_name,
+                                   segmentation_session_id, voxel_spacing)
+
+            # After processing and saving predictions for a batch:
+            del predictions            # Remove the reference to the list holding prediction arrays
+            torch.cuda.empty_cache()   # Clear unused GPU memory
+            gc.collect()               # Trigger garbage collection for CPU memory
+
+        print('Predictions Complete!')
+
+    def create_tta_augmentations(self):
+        """Define TTA augmentations and their inverse transforms."""
+
+        # Instead of flips, rotate three times (90, 180, 270 degrees) in the last two spatial axes.
+        # The forward transforms act on 5-D [N, C, Z, Y, X] batches (dims 3, 4); the inverses act
+        # on 4-D [C, Z, Y, X] probability maps after decollation (dims 2, 3).
+        self.tta_transforms = [
+            lambda x: x,                                 # Identity (no augmentation)
+            lambda x: torch.rot90(x, k=1, dims=(3, 4)),  # 90-degree rotation
+            lambda x: torch.rot90(x, k=2, dims=(3, 4)),  # 180-degree rotation
+            lambda x: torch.rot90(x, k=3, dims=(3, 4)),  # 270-degree rotation
+            # Flip(spatial_axis=0),  # Flip along the z-axis (depth)
+            # Flip(spatial_axis=1),  # Flip along the y-axis (height)
+            # Flip(spatial_axis=2),  # Flip along the x-axis (width)
+        ]
+
+        # Define the inverse transformations (rotate back to the original orientation)
+        self.inverse_tta_transforms = [
+            lambda x: x,                                  # Identity (no transformation needed)
+            lambda x: torch.rot90(x, k=-1, dims=(2, 3)),  # Inverse of 90 degrees
+            lambda x: torch.rot90(x, k=-2, dims=(2, 3)),  # Inverse of 180 degrees
+            lambda x: torch.rot90(x, k=-3, dims=(2, 3)),  # Inverse of 270 degrees
+            # Flip(spatial_axis=0),  # Undo flip along the z-axis
+            # Flip(spatial_axis=1),  # Undo flip along the y-axis
+            # Flip(spatial_axis=2),  # Undo flip along the x-axis
+        ]
+
+###################################################################################################################################################
+
+class MultiGPUPredictor(Predictor):
+
+    def __init__(self,
+                 config: str,
+                 model_config: str,
+                 model_weights: str):
+        super().__init__(config, model_config, model_weights)
+        self.num_gpus = torch.cuda.device_count()
+        if self.num_gpus < 2:
+            raise RuntimeError("MultiGPUPredictor requires at least 2 GPUs.")
+
+    def predict_on_gpu(self, gpu_id: int, batch_ids: List[str], voxel_spacing: float, tomo_algorithm: str) -> List[np.ndarray]:
+        """Helper function to run inference on a single GPU."""
+        device = torch.device(f'cuda:{gpu_id}')
+        self.model.to(device)
+
+        # Load the data specific to the batch assigned to this GPU
+        test_loader = io.load_predict_data(self.root, batch_ids, voxel_spacing, tomo_algorithm)
+        predictions = []
+
+        with torch.no_grad():
+            for data in tqdm(test_loader, desc=f"GPU {gpu_id}"):
+                tomogram = data['image'].to(device)
+                # _run_inference returns a list of post-processed, per-sample predictions
+                data['pred'] = self._run_inference(tomogram)
+                for idx in range(len(data['image'])):
+                    predictions.append(data['pred'][idx].squeeze(0).cpu().numpy())
+
+        return predictions
+
+    def multi_gpu_inference(self,
+                            num_tomos_per_batch: int = 15,
+                            runIDs: Optional[List[str]] = None,
+                            voxel_spacing: float = 10,
+                            tomo_algorithm: str = 'denoised',
+                            save: bool = False,
+                            segmentation_name: str = 'prediction',
+                            segmentation_user_id: str = 'monai',
+                            segmentation_session_id: str = '0') -> Optional[List[np.ndarray]]:
+        """Run inference across multiple GPUs, optionally saving results or returning predictions."""
+
+        runIDs = runIDs or [run.name for run in self.root.runs]
+        all_predictions = []
+
+        # Divide the runIDs into batches, one batch per GPU
+        batches = [runIDs[i:i + num_tomos_per_batch] for i in range(0, len(runIDs), num_tomos_per_batch)]
+
+        # Run inference in parallel across GPUs
+        for i in range(0, len(batches), self.num_gpus):
+            gpu_batches = batches[i:i + self.num_gpus]
+            with Pool(processes=self.num_gpus) as pool:
+                results = pool.starmap(
+                    self.predict_on_gpu,
+                    [(gpu_id, gpu_batches[gpu_id], voxel_spacing, tomo_algorithm) for gpu_id in range(len(gpu_batches))]
+                )
+
+            # Collect and save the results
+            for gpu_id, predictions in enumerate(results):
+                if save:
+                    for idx, run_id in enumerate(gpu_batches[gpu_id]):
+                        run = self.root.get_run(run_id)
+                        segmentation = predictions[idx]
+                        write.segmentation(run, segmentation, segmentation_user_id, segmentation_name,
+                                           segmentation_session_id, voxel_spacing)
+                else:
+                    all_predictions.extend(predictions)
+
+        print('Multi-GPU predictions complete.')
+
+        return None if save else all_predictions
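The 317-line hunk above (matching octopi/pytorch/segmentation.py in the file list) provides the Predictor and MultiGPUPredictor inference classes. Below is a minimal single-process sketch, again with placeholder paths and IDs, using only the parameters visible in the signatures above:

    # Hypothetical usage sketch for Predictor.batch_predict; paths and IDs are placeholders.
    from octopi.pytorch.segmentation import Predictor

    predictor = Predictor(
        config='copick_config.json',       # CoPick project config (placeholder path)
        model_config='model_config.yaml',  # must define model.architecture, num_classes, dim_in
        model_weights='best_model.pth',    # trained weights (placeholder path)
        apply_tta=True,                    # enable the rotation-based test-time augmentation
    )

    # Segment every run in batches of 15 tomograms and write each result back to the project.
    predictor.batch_predict(
        num_tomos_per_batch=15,
        voxel_spacing=10,
        tomo_algorithm='denoised',
        segmentation_name='prediction',
        segmentation_user_id='octopi',
        segmentation_session_id='0',
    )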