octopi-1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (59)
  1. octopi/__init__.py +0 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +84 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +429 -0
  7. octopi/datasets/mixup.py +49 -0
  8. octopi/datasets/multi_config_generator.py +253 -0
  9. octopi/entry_points/__init__.py +0 -0
  10. octopi/entry_points/common.py +80 -0
  11. octopi/entry_points/create_slurm_submission.py +243 -0
  12. octopi/entry_points/run_create_targets.py +281 -0
  13. octopi/entry_points/run_evaluate.py +65 -0
  14. octopi/entry_points/run_extract_mb_picks.py +141 -0
  15. octopi/entry_points/run_extract_midpoint.py +143 -0
  16. octopi/entry_points/run_localize.py +222 -0
  17. octopi/entry_points/run_optuna.py +139 -0
  18. octopi/entry_points/run_segment_predict.py +166 -0
  19. octopi/entry_points/run_train.py +201 -0
  20. octopi/extract/__init__.py +0 -0
  21. octopi/extract/localize.py +254 -0
  22. octopi/extract/membranebound_extract.py +262 -0
  23. octopi/extract/midpoint_extract.py +193 -0
  24. octopi/io.py +457 -0
  25. octopi/losses.py +86 -0
  26. octopi/main.py +101 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +62 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +106 -0
  37. octopi/processing/downsample.py +129 -0
  38. octopi/processing/evaluate.py +289 -0
  39. octopi/processing/importers.py +213 -0
  40. octopi/processing/my_metrics.py +26 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/processing/writers.py +102 -0
  43. octopi/pytorch/__init__.py +0 -0
  44. octopi/pytorch/hyper_search.py +243 -0
  45. octopi/pytorch/model_search_submitter.py +290 -0
  46. octopi/pytorch/segmentation.py +317 -0
  47. octopi/pytorch/trainer.py +438 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/stopping_criteria.py +143 -0
  52. octopi/submit_slurm.py +95 -0
  53. octopi/utils.py +238 -0
  54. octopi/visualization_tools.py +201 -0
  55. octopi-1.0.dist-info/LICENSE +41 -0
  56. octopi-1.0.dist-info/METADATA +209 -0
  57. octopi-1.0.dist-info/RECORD +59 -0
  58. octopi-1.0.dist-info/WHEEL +4 -0
  59. octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/pytorch/model_search_submitter.py
@@ -0,0 +1,290 @@
+ from octopi.datasets import generators, multi_config_generator
+ from octopi.pytorch import hyper_search
+ import torch, mlflow, optuna
+ from octopi import utils
+ from typing import List, Union
+ import pandas as pd
+
+ class ModelSearchSubmit:
+     def __init__(
+         self,
+         copick_config: Union[str, dict],
+         target_name: str,
+         target_user_id: str,
+         target_session_id: str,
+         tomo_algorithm: str,
+         voxel_size: float,
+         Nclass: int,
+         model_type: str,
+         mlflow_experiment_name: str,
+         random_seed: int,
+         num_epochs: int,
+         num_trials: int,
+         tomo_batch_size: int,
+         best_metric: str,
+         val_interval: int,
+         trainRunIDs: List[str],
+         validateRunIDs: List[str],
+         data_split: str
+     ):
+         """
+         Initialize the ModelSearchSubmit class for architecture search with Optuna.
+
+         Args:
+             copick_config (str or dict): Path to the CoPick configuration file, or a dictionary of configs for multi-config training.
+             target_name (str): Name of the target for segmentation.
+             target_user_id (str): Optional user ID for tracking.
+             target_session_id (str): Optional session ID for tracking.
+             tomo_algorithm (str): Tomogram reconstruction algorithm to use.
+             voxel_size (float): Voxel size of the tomograms.
+             Nclass (int): Number of prediction classes.
+             model_type (str): Type of model architecture to search over.
+             mlflow_experiment_name (str): MLflow experiment name.
+             random_seed (int): Seed for reproducibility.
+             num_epochs (int): Number of epochs per trial.
+             num_trials (int): Number of trials for hyperparameter optimization.
+             tomo_batch_size (int): Batch size for tomogram loading.
+             best_metric (str): Metric to optimize.
+             val_interval (int): Validation interval (in epochs).
+             trainRunIDs (List[str]): List of training run IDs.
+             validateRunIDs (List[str]): List of validation run IDs.
+             data_split (str): Train/validation/test split ratios.
+         """
+
+         # Input parameters
+         self.copick_config = copick_config
+         self.target_name = target_name
+         self.target_user_id = target_user_id
+         self.target_session_id = target_session_id
+         self.tomo_algorithm = tomo_algorithm
+         self.voxel_size = voxel_size
+         self.Nclass = Nclass
+         self.model_type = model_type
+         self.mlflow_experiment_name = mlflow_experiment_name
+         self.random_seed = random_seed
+         self.num_epochs = num_epochs
+         self.num_trials = num_trials
+         self.tomo_batch_size = tomo_batch_size
+         self.best_metric = best_metric
+         self.val_interval = val_interval
+         self.trainRunIDs = trainRunIDs
+         self.validateRunIDs = validateRunIDs
+         self.data_split = data_split
+
+         # Data generator - will be initialized in _initialize_data_generator()
+         self.data_generator = None
+
+         # Set the random seed for reproducibility
+         utils.set_seed(self.random_seed)
+
+         # Initialize the dataset generator
+         self._initialize_data_generator()
+
+     def _initialize_data_generator(self):
+         """Initializes the data generator for the training and validation datasets."""
+         self._print_input_configs()
+
+         if isinstance(self.copick_config, dict):
+             self.data_generator = multi_config_generator.MultiConfigTrainLoaderManager(
+                 self.copick_config,
+                 self.target_name,
+                 target_session_id=self.target_session_id,
+                 target_user_id=self.target_user_id,
+                 tomo_algorithm=self.tomo_algorithm,
+                 voxel_size=self.voxel_size,
+                 Nclasses=self.Nclass,
+                 tomo_batch_size=self.tomo_batch_size
+             )
+         else:
+             self.data_generator = generators.TrainLoaderManager(
+                 self.copick_config,
+                 self.target_name,
+                 target_session_id=self.target_session_id,
+                 target_user_id=self.target_user_id,
+                 tomo_algorithm=self.tomo_algorithm,
+                 voxel_size=self.voxel_size,
+                 Nclasses=self.Nclass,
+                 tomo_batch_size=self.tomo_batch_size
+             )
+
+         # Split the datasets into training and validation
+         ratios = utils.parse_data_split(self.data_split)
+         self.data_generator.get_data_splits(
+             trainRunIDs=self.trainRunIDs,
+             validateRunIDs=self.validateRunIDs,
+             train_ratio=ratios[0], val_ratio=ratios[1], test_ratio=ratios[2],
+             create_test_dataset=False
+         )
+
+         # Get the reload frequency
+         self.data_generator.get_reload_frequency(self.num_epochs)
+
+     def _print_input_configs(self):
+         """Prints the training configuration for debugging purposes."""
+         print('\nTraining with:')
+         if isinstance(self.copick_config, dict):
+             for session, config in self.copick_config.items():
+                 print(f'  {session}: {config}')
+         else:
+             print(f'  {self.copick_config}')
+         print()
+
+     def run_model_search(self):
+         """Performs the model architecture search using Optuna and MLflow."""
+
+         # Set up MLflow tracking
+         try:
+             tracking_uri = utils.mlflow_setup()
+             mlflow.set_tracking_uri(tracking_uri)
+         except Exception as e:
+             print(f'Failed to set up MLflow tracking: {e}')
+
+         mlflow.set_experiment(self.mlflow_experiment_name)
+
+         # Create a storage object with heartbeat configuration
+         storage_url = f"sqlite:///explore_results_{self.model_type}/trials.db"
+         self.storage = optuna.storages.RDBStorage(
+             url=storage_url,
+             heartbeat_interval=60,   # Record a heartbeat every minute
+             grace_period=600,        # 10-minute grace period
+             failed_trial_callback=optuna.storages.RetryFailedTrialCallback(max_retry=1)
+         )
+
+         # Detect GPU availability
+         gpu_count = torch.cuda.device_count()
+         print(f'Running Architecture Search Over {gpu_count} GPUs\n')
+
+         # Initialize the model search object
+         if gpu_count > 1:
+             self._multi_gpu_optuna(gpu_count)
+         else:
+             model_search = hyper_search.BayesianModelSearch(self.data_generator, self.model_type)
+             self._single_gpu_optuna(model_search)
+
+     def _single_gpu_optuna(self, model_search):
+         """Runs the Optuna optimization on a single GPU."""
+
+         with mlflow.start_run(nested=False) as parent_run:
+
+             model_search.parent_run_id = parent_run.info.run_id
+             model_search.parent_run_name = parent_run.info.run_name
+
+             # Log the experiment parameters
+             mlflow.log_params({"random_seed": self.random_seed})
+             mlflow.log_params(self.data_generator.get_dataloader_parameters())
+             mlflow.log_params({"parent_run_name": parent_run.info.run_name})
+
+             # Determine the device
+             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+             # Create the study and run the optimization
+             study = self.get_optuna_study()
+             study.optimize(
+                 lambda trial: model_search.objective(
+                     trial, self.num_epochs, device,
+                     val_interval=self.val_interval,
+                     best_metric=self.best_metric,
+                 ),
+                 n_trials=self.num_trials
+             )
+
+             # Save the contour plot
+             self.save_contour_plot_as_png(study)
+
+             print(f"Best trial: {study.best_trial.value}")
+             print(f"Best params: {study.best_params}")
+
+     def _multi_gpu_optuna(self, gpu_count):
+         """Runs the Optuna optimization on multiple GPUs."""
+         with mlflow.start_run() as parent_run:
+
+             # Log the experiment parameters
+             mlflow.log_params({"random_seed": self.random_seed})
+             mlflow.log_params(self.data_generator.get_dataloader_parameters())
+             mlflow.log_params({"parent_run_name": parent_run.info.run_name})
+
+             # Run the multi-GPU optimization
+             study = self.get_optuna_study()
+             study.optimize(
+                 lambda trial: hyper_search.BayesianModelSearch(self.data_generator, self.model_type).multi_gpu_objective(
+                     parent_run, trial,
+                     self.num_epochs,
+                     best_metric=self.best_metric,
+                     val_interval=self.val_interval,
+                     gpu_count=gpu_count
+                 ),
+                 n_trials=self.num_trials,
+                 n_jobs=gpu_count
+             )
+
+             # Save the contour plot
+             self.save_contour_plot_as_png(study)
+
+             print(f"Best trial: {study.best_trial.value}")
+             print(f"Best params: {study.best_params}")
+
227
+ def _get_optuna_sampler(self):
228
+ """Returns Optuna's TPE sampler with default settings."""
229
+ return optuna.samplers.TPESampler(
230
+ n_startup_trials=10,
231
+ n_ei_candidates=24,
232
+ multivariate=True
233
+ )
234
+ # return optuna.samplers.BoTorchSampler(
235
+ # n_startup_trials=10,
236
+ # multivariate=True
237
+ # )
238
+
239
+ def get_optuna_study(self):
240
+ """Returns the Optuna study object."""
241
+ return optuna.create_study(
242
+ storage=self.storage,
243
+ direction="maximize",
244
+ sampler=self._get_optuna_sampler(),
245
+ load_if_exists=True,
246
+ pruner=self._get_optuna_pruner()
247
+ )
248
+
249
+ def _get_optuna_pruner(self):
250
+ """Returns Optuna's pruning strategy."""
251
+ return optuna.pruners.MedianPruner()
252
+
253
+ def save_contour_plot_as_png(self, study):
254
+ """
255
+ Save the contour plot of hyperparameter interactions as a PNG,
256
+ automatically extracting parameter names from the study object.
257
+
258
+ Args:
259
+ study: The Optuna study object.
260
+ output_path: Path to save the PNG file.
261
+ """
262
+ # Extract all parameter names from the study trials
263
+ all_params = set()
264
+ for trial in study.trials:
265
+ all_params.update(trial.params.keys())
266
+ all_params = list(all_params) # Convert to a sorted list for consistency
267
+
268
+ # Generate the contour plot
269
+ fig = optuna.visualization.plot_contour(study, params=all_params)
270
+
271
+ # Adjust figure size and font size
272
+ fig.update_layout(
273
+ width=6000, height=6000, # Large figure size
274
+ font=dict(size=40) # Increase font size for better readability
275
+ )
276
+
277
+ # Save the plot as a PNG file
278
+ fig.write_image(f'explore_results_{self.model_type}/contour_plot.png', scale=1)
279
+
280
+ # Extract trial data
281
+ trials = [
282
+ {**trial.params, 'objective_value': trial.value}
283
+ for trial in study.trials if trial.state == optuna.trial.TrialState.COMPLETE
284
+ ]
285
+
286
+ # Convert to DataFrame
287
+ df = pd.DataFrame(trials)
288
+
289
+ # Save to CSV
290
+ df.to_csv(f"explore_results_{self.model_type}/optuna_results.csv", index=False)
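
For orientation, a minimal driver for ModelSearchSubmit might look like the sketch below. Every path, run ID, and option value here is a hypothetical placeholder; the accepted model_type and best_metric strings, and the exact data_split format understood by utils.parse_data_split, are defined elsewhere in the package (octopi/entry_points/run_optuna.py wraps this class for the CLI).

    from octopi.pytorch.model_search_submitter import ModelSearchSubmit

    searcher = ModelSearchSubmit(
        copick_config='copick_config.json',    # hypothetical CoPick config path
        target_name='targets',                 # hypothetical segmentation target
        target_user_id='octopi',
        target_session_id='1',
        tomo_algorithm='denoised',
        voxel_size=10.0,
        Nclass=3,
        model_type='Unet',                     # assumed to name a model under octopi/models/
        mlflow_experiment_name='octopi-search',
        random_seed=42,
        num_epochs=100,
        num_trials=25,
        tomo_batch_size=15,
        best_metric='fBeta',                   # placeholder; use a metric name the trainer reports
        val_interval=10,
        trainRunIDs=None,                      # or explicit lists of run names
        validateRunIDs=None,
        data_split='0.8,0.1,0.1',              # placeholder; format is whatever parse_data_split expects
    )
    searcher.run_model_search()
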
octopi/pytorch/segmentation.py
@@ -0,0 +1,317 @@
+ from monai.inferers import sliding_window_inference
+ from monai.data import decollate_batch
+ from torch.multiprocessing import Pool
+ from monai.data import MetaTensor
+ from monai.transforms import (
+     Compose, AsDiscrete, Activations
+ )
+ import octopi.processing.writers as write
+ from octopi.models import common
+ from typing import List, Optional
+ import torch, copick, gc, os
+ from octopi import io, utils
+ from tqdm import tqdm
+ import numpy as np
+
+ class Predictor:
+
+     def __init__(self,
+                  config: str,
+                  model_config: str,
+                  model_weights: str,
+                  apply_tta: bool = True,
+                  device: Optional[str] = None):
+
+         self.config = config
+         self.root = copick.from_file(config)
+
+         # Load the model config
+         model_config = utils.load_yaml(model_config)
+
+         self.Nclass = model_config['model']['num_classes']
+         self.dim_in = model_config['model']['dim_in']
+         self.input_dim = None
+
+         # Get the number of GPUs available
+         num_gpus = torch.cuda.device_count()
+         if num_gpus == 0:
+             raise RuntimeError("No GPUs available.")
+
+         # Set the device
+         if device is None:
+             self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         else:
+             self.device = device
+         print('Running Inference On: ', self.device)
+
+         # Check that the model weights file exists
+         if not os.path.exists(model_weights):
+             raise ValueError(f"Model weights file does not exist: {model_weights}")
+
+         # Load the model weights
+         model_builder = common.get_model(model_config['model']['architecture'])
+         model_builder.build_model(model_config['model'])
+         self.model = model_builder.model
+         state_dict = torch.load(model_weights, map_location=self.device, weights_only=True)
+         self.model.load_state_dict(state_dict)
+         self.model.to(self.device)
+         self.model.eval()
+
+         # Initialize TTA if enabled
+         self.apply_tta = apply_tta
+         if self.apply_tta:
+             self.create_tta_augmentations()
+             # self.post_transforms = Compose([
+             #     Activations(softmax=True)  # Keep probability output
+             # ])
+             self.softmax_transform = Compose([
+                 Activations(softmax=True)  # Keep probability output
+             ])
+
+             # Create the final discretization transform
+             self.discretize_transform = AsDiscrete(argmax=True)
+         else:
+             # Define the post-processing transforms
+             self.post_transforms = Compose([
+                 Activations(softmax=True),
+                 AsDiscrete(argmax=True)
+             ])
+
+
+     def _run_inference(self, input_data):
+         """Apply sliding-window inference to the input."""
+         with torch.no_grad():
+             predictions = sliding_window_inference(
+                 inputs=input_data,
+                 roi_size=(self.dim_in, self.dim_in, self.dim_in),
+                 sw_batch_size=4,   # up to four sliding windows are processed per forward pass
+                 predictor=self.model,
+                 overlap=0.5,
+             )
+         return [self.post_transforms(i) for i in decollate_batch(predictions)]
+
+     def _run_inference_tta(self, input_data):
+         """Memory-efficient TTA implementation that returns proper discrete segmentation maps."""
+
+         batch_size = input_data.shape[0]
+         results = []
+
+         # Process one sample at a time
+         for sample_idx in range(batch_size):
+             # Extract a single sample
+             single_sample = input_data[sample_idx:sample_idx+1]
+
+             # Initialize the probability accumulator for this sample
+             # Shape: [1, Nclass, Z, Y, X]
+             acc_probs = torch.zeros(
+                 (1, self.Nclass, *single_sample.shape[2:]),
+                 dtype=torch.float32, device=self.device
+             )
+
+             # Process each augmentation
+             with torch.no_grad():
+                 for tta_transform, inverse_transform in zip(self.tta_transforms, self.inverse_tta_transforms):
+                     # Apply the transform to the single sample
+                     aug_sample = tta_transform(single_sample)
+
+                     # Free memory
+                     torch.cuda.empty_cache()
+
+                     # Run inference (one sample at a time)
+                     predictions = sliding_window_inference(
+                         inputs=aug_sample,
+                         roi_size=(self.dim_in, self.dim_in, self.dim_in),
+                         sw_batch_size=4,   # up to four sliding windows per forward pass
+                         predictor=self.model,
+                         overlap=0.5,
+                     )
+
+                     # Get the softmax probabilities
+                     probs = self.softmax_transform(predictions[0])  # Get the first (only) item
+
+                     # Apply the inverse transform with the correct dimensions
+                     inv_probs = inverse_transform(probs)
+
+                     # Accumulate the probabilities
+                     acc_probs[0] += inv_probs
+
+                     # Clear memory
+                     del predictions, probs, inv_probs, aug_sample
+
+             # Average the accumulated probabilities
+             acc_probs = acc_probs / len(self.tta_transforms)
+
+             # Convert to a discrete prediction - argmax along the class dimension.
+             # This gives a tensor of shape [1, Z, Y, X] with discrete class indices.
+             discrete_pred = torch.argmax(acc_probs, dim=1)
+
+             # Add to the results - keeping only the spatial dimensions [Z, Y, X]
+             results.append(discrete_pred[0])
+
+             # Clear memory
+             del acc_probs, discrete_pred
+             torch.cuda.empty_cache()
+
+         return results
+
+     def predict_on_gpu(self,
+                        runIDs: List[str],
+                        voxel_spacing: float,
+                        tomo_algorithm: str):
+
+         # Load the data for the current batch
+         test_loader, test_dataset = io.create_predict_dataloader(
+             self.root,
+             voxel_spacing, tomo_algorithm,
+             runIDs)
+
+         # Determine the input crop size
+         if self.input_dim is None:
+             self.input_dim = io.get_input_dimensions(test_dataset, self.dim_in)
+
+         predictions = []
+         with torch.no_grad():
+             for data in tqdm(test_loader):
+                 tomogram = data['image'].to(self.device)
+                 if self.apply_tta:
+                     data['pred'] = self._run_inference_tta(tomogram)
+                 else:
+                     data['pred'] = self._run_inference(tomogram)
+                 for idx in range(len(data['image'])):
+                     predictions.append(data['pred'][idx].squeeze(0).numpy(force=True))
+
+         return predictions
+
+     def batch_predict(self,
+                       num_tomos_per_batch: int = 15,
+                       runIDs: Optional[List[str]] = None,
+                       voxel_spacing: float = 10,
+                       tomo_algorithm: str = 'denoised',
+                       segmentation_name: str = 'prediction',
+                       segmentation_user_id: str = 'octopi',
+                       segmentation_session_id: str = '0'):
+         """Run inference on tomograms in batches."""
+
+         # If runIDs are not provided, load all runs
+         if runIDs is None:
+             runIDs = [run.name for run in self.root.runs]
+
+         # Iterate over batches of runIDs
+         for i in range(0, len(runIDs), num_tomos_per_batch):
+
+             # Get a batch of runIDs
+             batch_ids = runIDs[i:i + num_tomos_per_batch]
+             print('Running Inference on the Following RunIDs: ', batch_ids)
+
+             predictions = self.predict_on_gpu(batch_ids, voxel_spacing, tomo_algorithm)
+
+             # Save the predictions to the corresponding runID
+             for ind in range(len(batch_ids)):
+                 run = self.root.get_run(batch_ids[ind])
+                 seg = predictions[ind]
+                 write.segmentation(run, seg, segmentation_user_id, segmentation_name,
+                                    segmentation_session_id, voxel_spacing)
+
+             # After processing and saving the predictions for a batch:
+             del predictions            # Remove the reference to the list holding the prediction arrays
+             torch.cuda.empty_cache()   # Clear unused GPU memory
+             gc.collect()               # Trigger garbage collection for CPU memory
+
+         print('Predictions Complete!')
+
+     def create_tta_augmentations(self):
+         """Define TTA augmentations and inverse transforms."""
+
+         # Instead of flips, rotate around the first axis three times (90°, 180°, 270°).
+         # The forward transforms act on the 5D input batch [B, C, Z, Y, X], so the
+         # (Y, X) plane is dims (3, 4).
+         self.tta_transforms = [
+             lambda x: x,                                  # Identity (no augmentation)
+             lambda x: torch.rot90(x, k=1, dims=(3, 4)),   # 90° rotation
+             lambda x: torch.rot90(x, k=2, dims=(3, 4)),   # 180° rotation
+             lambda x: torch.rot90(x, k=3, dims=(3, 4)),   # 270° rotation
+             # Flip(spatial_axis=0),  # Flip along x-axis (depth)
+             # Flip(spatial_axis=1),  # Flip along y-axis (height)
+             # Flip(spatial_axis=2),  # Flip along z-axis (width)
+         ]
+
+         # Define the inverse transformations (rotate back to the original orientation).
+         # The inverses act on per-sample 4D probability maps [C, Z, Y, X], so the
+         # same (Y, X) plane is dims (2, 3).
+         self.inverse_tta_transforms = [
+             lambda x: x,                                  # Identity (no transformation needed)
+             lambda x: torch.rot90(x, k=-1, dims=(2, 3)),  # Inverse of 90° (i.e. -90°)
+             lambda x: torch.rot90(x, k=-2, dims=(2, 3)),  # Inverse of 180° (i.e. -180°)
+             lambda x: torch.rot90(x, k=-3, dims=(2, 3)),  # Inverse of 270° (i.e. -270°)
+             # Flip(spatial_axis=0),  # Undo flip along x-axis
+             # Flip(spatial_axis=1),  # Undo flip along y-axis
+             # Flip(spatial_axis=2),  # Undo flip along z-axis
+         ]
+
+ ###################################################################################################################################################
+
+ class MultiGPUPredictor(Predictor):
+
+     def __init__(self,
+                  config: str,
+                  model_config: str,
+                  model_weights: str):
+         super().__init__(config, model_config, model_weights)
+         self.num_gpus = torch.cuda.device_count()
+         if self.num_gpus < 2:
+             raise RuntimeError("MultiGPUPredictor requires at least 2 GPUs.")
+
+     def predict_on_gpu(self, gpu_id: int, batch_ids: List[str], voxel_spacing: float, tomo_algorithm: str) -> List[np.ndarray]:
+         """Helper function to run inference on a single GPU."""
+         device = torch.device(f'cuda:{gpu_id}')
+         self.device = device   # keep TTA buffers on this worker's GPU
+         self.model.to(device)
+
+         # Load the data specific to the batch assigned to this GPU,
+         # reusing the same dataloader as the single-GPU path
+         test_loader, _ = io.create_predict_dataloader(
+             self.root, voxel_spacing, tomo_algorithm, batch_ids)
+         predictions = []
+
+         with torch.no_grad():
+             for data in tqdm(test_loader, desc=f"GPU {gpu_id}"):
+                 tomogram = data['image'].to(device)
+                 if self.apply_tta:
+                     data['pred'] = self._run_inference_tta(tomogram)
+                 else:
+                     data['pred'] = self._run_inference(tomogram)
+                 for idx in range(len(data['image'])):
+                     predictions.append(data['pred'][idx].squeeze(0).cpu().numpy())
+
+         return predictions
+
+     def multi_gpu_inference(self,
+                             num_tomos_per_batch: int = 15,
+                             runIDs: Optional[List[str]] = None,
+                             voxel_spacing: float = 10,
+                             tomo_algorithm: str = 'denoised',
+                             save: bool = False,
+                             segmentation_name: str = 'prediction',
+                             segmentation_user_id: str = 'monai',
+                             segmentation_session_id: str = '0') -> Optional[List[np.ndarray]]:
+         """Run inference across multiple GPUs, optionally saving results or returning predictions."""
+
+         runIDs = runIDs or [run.name for run in self.root.runs]
+         all_predictions = []
+
+         # Divide the runIDs into batches for each GPU
+         batches = [runIDs[i:i + num_tomos_per_batch] for i in range(0, len(runIDs), num_tomos_per_batch)]
+
+         # Run inference in parallel across the GPUs
+         for i in range(0, len(batches), self.num_gpus):
+             gpu_batches = batches[i:i + self.num_gpus]
+             with Pool(processes=self.num_gpus) as pool:
+                 results = pool.starmap(
+                     self.predict_on_gpu,
+                     [(gpu_id, gpu_batches[gpu_id], voxel_spacing, tomo_algorithm) for gpu_id in range(len(gpu_batches))]
+                 )
+
+             # Collect and save the results
+             for gpu_id, predictions in enumerate(results):
+                 if save:
+                     for idx, run_id in enumerate(gpu_batches[gpu_id]):
+                         run = self.root.get_run(run_id)
+                         segmentation = predictions[idx]
+                         write.segmentation(run, segmentation, segmentation_user_id, segmentation_name,
+                                            segmentation_session_id, voxel_spacing)
+                 else:
+                     all_predictions.extend(predictions)
+
+         print('Multi-GPU predictions complete.')
+
+         return None if save else all_predictions
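
Correspondingly, a minimal single-GPU prediction driver for the Predictor class might look like the sketch below. The file names are hypothetical placeholders; the keyword arguments and their defaults mirror the signatures shown above.

    from octopi.pytorch.segmentation import Predictor

    predictor = Predictor(
        config='copick_config.json',    # hypothetical CoPick project config
        model_config='model.yaml',      # hypothetical YAML containing the 'model' section read above
        model_weights='best_model.pth', # hypothetical checkpoint path
        apply_tta=True)

    # Segment every run in the project, 15 tomograms at a time, and write the
    # results back as copick segmentations.
    predictor.batch_predict(
        num_tomos_per_batch=15,
        voxel_spacing=10,
        tomo_algorithm='denoised',
        segmentation_name='prediction',
        segmentation_user_id='octopi',
        segmentation_session_id='0')
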