octopi 1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of octopi might be problematic.

Files changed (48)
  1. octopi/__init__.py +1 -0
  2. octopi/datasets/cached_datset.py +1 -1
  3. octopi/datasets/generators.py +1 -1
  4. octopi/datasets/io.py +200 -0
  5. octopi/datasets/multi_config_generator.py +1 -1
  6. octopi/entry_points/common.py +9 -9
  7. octopi/entry_points/create_slurm_submission.py +16 -8
  8. octopi/entry_points/run_create_targets.py +6 -6
  9. octopi/entry_points/run_evaluate.py +4 -3
  10. octopi/entry_points/run_extract_mb_picks.py +22 -45
  11. octopi/entry_points/run_localize.py +37 -54
  12. octopi/entry_points/run_optuna.py +7 -7
  13. octopi/entry_points/run_segment_predict.py +4 -4
  14. octopi/entry_points/run_train.py +7 -8
  15. octopi/extract/localize.py +19 -12
  16. octopi/extract/membranebound_extract.py +11 -10
  17. octopi/extract/midpoint_extract.py +3 -3
  18. octopi/main.py +1 -1
  19. octopi/models/common.py +1 -1
  20. octopi/processing/create_targets_from_picks.py +11 -5
  21. octopi/processing/downsample.py +6 -10
  22. octopi/processing/evaluate.py +24 -11
  23. octopi/processing/importers.py +4 -4
  24. octopi/pytorch/hyper_search.py +2 -3
  25. octopi/pytorch/model_search_submitter.py +15 -15
  26. octopi/pytorch/segmentation.py +147 -192
  27. octopi/pytorch/segmentation_multigpu.py +162 -0
  28. octopi/pytorch/trainer.py +9 -3
  29. octopi/utils/__init__.py +0 -0
  30. octopi/utils/config.py +57 -0
  31. octopi/utils/io.py +128 -0
  32. octopi/{utils.py → utils/parsers.py} +10 -84
  33. octopi/{stopping_criteria.py → utils/stopping_criteria.py} +3 -3
  34. octopi/{visualization_tools.py → utils/visualization_tools.py} +4 -4
  35. octopi/workflows.py +236 -0
  36. octopi-1.2.0.dist-info/METADATA +120 -0
  37. octopi-1.2.0.dist-info/RECORD +62 -0
  38. {octopi-1.0.dist-info → octopi-1.2.0.dist-info}/WHEEL +1 -1
  39. octopi-1.2.0.dist-info/entry_points.txt +3 -0
  40. {octopi-1.0.dist-info → octopi-1.2.0.dist-info/licenses}/LICENSE +3 -3
  41. octopi/io.py +0 -457
  42. octopi/processing/my_metrics.py +0 -26
  43. octopi/processing/writers.py +0 -102
  44. octopi-1.0.dist-info/METADATA +0 -209
  45. octopi-1.0.dist-info/RECORD +0 -59
  46. octopi-1.0.dist-info/entry_points.txt +0 -4
  47. /octopi/{losses.py → utils/losses.py} +0 -0
  48. /octopi/{submit_slurm.py → utils/submit_slurm.py} +0 -0
octopi/pytorch/segmentation.py CHANGED
@@ -5,11 +5,12 @@ from monai.data import MetaTensor
  from monai.transforms import (
      Compose, AsDiscrete, Activations
  )
- import octopi.processing.writers as write
+ from typing import List, Optional, Union
+ from octopi.datasets import io as dataio
+ from copick_utils.io import writers
  from octopi.models import common
- from typing import List, Optional
  import torch, copick, gc, os
- from octopi import io, utils
+ from octopi.utils import io
  from tqdm import tqdm
  import numpy as np

@@ -17,20 +18,14 @@ class Predictor:

      def __init__(self,
                   config: str,
-                  model_config: str,
-                  model_weights: str,
+                  model_config: Union[str, List[str]],
+                  model_weights: Union[str, List[str]],
                   apply_tta: bool = True,
                   device: Optional[str] = None):

+         # Open the Copick Project
          self.config = config
          self.root = copick.from_file(config)
-
-         # Load the model config
-         model_config = utils.load_yaml(model_config)
-
-         self.Nclass = model_config['model']['num_classes']
-         self.dim_in = model_config['model']['dim_in']
-         self.input_dim = None

          # Get the number of GPUs available
          num_gpus = torch.cuda.device_count()
@@ -44,109 +39,147 @@ class Predictor:
          self.device = device
          print('Running Inference On: ', self.device)

-         # Check to see if the model weights file exists
-         if not os.path.exists(model_weights):
-             raise ValueError(f"Model weights file does not exist: {model_weights}")
+         # Initialize TTA if enabled
+         self.apply_tta = apply_tta
+         self.create_tta_augmentations()

-         # Load the model weights
-         model_builder = common.get_model(model_config['model']['architecture'])
-         model_builder.build_model(model_config['model'])
-         self.model = model_builder.model
-         state_dict = torch.load(model_weights, map_location=self.device, weights_only=True)
-         self.model.load_state_dict(state_dict)
-         self.model.to(self.device)
-         self.model.eval()
+         # Determine if Model Soup is Enabled
+         if isinstance(model_weights, str):
+             model_weights = [model_weights]
+         self.apply_modelsoup = len(model_weights) > 1

-         # Initialize TTA if enabled
-         self.apply_tta = apply_tta
-         if self.apply_tta:
-             self.create_tta_augmentations()
-             # self.post_transforms = Compose([
-             #     Activations(softmax=True)  # Keep probability output
-             # ])
-             self.softmax_transform = Compose([
-                 Activations(softmax=True)  # Keep probability output
-             ])
+         # Handle Single Model Config or Multiple Model Configs
+         if isinstance(model_config, str):
+             model_config = [model_config] * len(model_weights)
+         elif len(model_config) != len(model_weights):
+             raise ValueError("Number of model configs must match number of model weights.")
+
+         # Load the model(s)
+         self._load_models(model_config, model_weights)
+
+     def _load_models(self, model_config: List[str], model_weights: List[str]):
+         """Load a single model or multiple models for soup."""
+
+         self.models = []
+         for i, (config_path, weights_path) in enumerate(zip(model_config, model_weights)):
+
+             # Load the Model Config and Model Builder
+             current_modelconfig = io.load_yaml(config_path)
+             model_builder = common.get_model(current_modelconfig['model']['architecture'])
+
+             # Check if the weights file exists
+             if not os.path.exists(weights_path):
+                 raise ValueError(f"Model weights file does not exist: {weights_path}")

-             # Create the final discretization transform
-             self.discretize_transform = AsDiscrete(argmax=True)
-         else:
-             # Define the post-processing transforms
-             self.post_transforms = Compose([
-                 Activations(softmax=True),
-                 AsDiscrete(argmax=True)
-             ])
+             # Create model
+             model_builder.build_model(current_modelconfig['model'])
+             model = model_builder.model

+             # Load weights
+             state_dict = torch.load(weights_path, map_location=self.device, weights_only=True)
+             model.load_state_dict(state_dict)
+             model.to(self.device)
+             model.eval()
+
+             self.models.append(model)

-     def _run_inference(self, input):
-         """Apply sliding window inference to the input."""
-         with torch.no_grad():
-             predictions = sliding_window_inference(
-                 inputs=input,
-                 roi_size=(self.dim_in, self.dim_in, self.dim_in),
-                 sw_batch_size=4,  # one window is proecessed at a time
-                 predictor=self.model,
-                 overlap=0.5,
-             )
-         return [self.post_transforms(i) for i in decollate_batch(predictions)]
+         # For backward compatibility, also set self.model to the first model
+         self.model = self.models[0]

-     def _run_inference_tta(self, input_data):
-         """Memory-efficient TTA implementation that returns proper discrete segmentation maps."""
+         # Set the Number of Classes and Input Dimensions - Assume All Models are the Same
+         self.Nclass = current_modelconfig['model']['num_classes']
+         self.dim_in = current_modelconfig['model']['dim_in']
+         self.input_dim = None
+
+         # Print a message if Model Soup is Enabled
+         if self.apply_modelsoup:
+             print(f'Model Soup is Enabled : {len(self.models)} models loaded for ensemble inference')
+
+     def _run_single_model_inference(self, model, input_data):
+         """Run sliding window inference on a single model."""
+         return sliding_window_inference(
+             inputs=input_data,
+             roi_size=(self.dim_in, self.dim_in, self.dim_in),
+             sw_batch_size=4,
+             predictor=model,
+             overlap=0.5,
+         )
+
+     def _apply_tta_single_model(self, model, single_sample):
+         """Apply TTA to a single model and single sample."""
+         # Initialize probability accumulator
+         acc_probs = torch.zeros(
+             (1, self.Nclass, *single_sample.shape[2:]),
+             dtype=torch.float32, device=self.device
+         )
+
+         # Process each augmentation
+         with torch.no_grad():
+             for tta_transform, inverse_transform in zip(self.tta_transforms, self.inverse_tta_transforms):
+                 # Apply transform
+                 aug_sample = tta_transform(single_sample)
+
+                 # Run inference
+                 predictions = self._run_single_model_inference(model, aug_sample)
+
+                 # Get softmax probabilities
+                 probs = torch.softmax(predictions[0], dim=0)
+
+                 # Apply inverse transform
+                 inv_probs = inverse_transform(probs)
+
+                 # Accumulate probabilities
+                 acc_probs[0] += inv_probs
+
+                 # Clear memory
+                 del predictions, probs, inv_probs, aug_sample
+                 torch.cuda.empty_cache()
+
+         # Average accumulated probabilities
+         acc_probs = acc_probs / len(self.tta_transforms)

+         return acc_probs[0]  # Return shape [Nclass, Z, Y, X]
+
+     def _run_inference(self, input_data):
+         """
+         Main inference function that handles all combinations - Model Soup and/or TTA
+         """
          batch_size = input_data.shape[0]
          results = []

-         # Process one sample at a time
+         # Process one sample at a time for memory efficiency
          for sample_idx in range(batch_size):
-             # Extract single sample
              single_sample = input_data[sample_idx:sample_idx+1]

              # Initialize probability accumulator for this sample
-             # Shape: [1, Nclass, Z, Y, X]
              acc_probs = torch.zeros(
-                 (1, self.Nclass, *single_sample.shape[2:]),
+                 (self.Nclass, *single_sample.shape[2:]),
                  dtype=torch.float32, device=self.device
              )

-             # Process each augmentation
+             # Process each model
              with torch.no_grad():
-                 for tta_transform, inverse_transform in zip(self.tta_transforms, self.inverse_tta_transforms):
-                     # Apply transform to single sample
-                     aug_sample = tta_transform(single_sample)
+                 for model in self.models:
+                     # Apply TTA with this model
+                     if self.apply_tta:
+                         model_probs = self._apply_tta_single_model(model, single_sample)
+                     # Run inference without TTA
+                     else:
+                         predictions = self._run_single_model_inference(model, single_sample)
+                         model_probs = torch.softmax(predictions[0], dim=0)
+                         del predictions

-                     # Free memory
+                     # Accumulate probabilities from this model
+                     acc_probs += model_probs
+                     del model_probs
                      torch.cuda.empty_cache()
-
-                     # Run inference (one sample at a time)
-                     predictions = sliding_window_inference(
-                         inputs=aug_sample,
-                         roi_size=(self.dim_in, self.dim_in, self.dim_in),
-                         sw_batch_size=4,  # Process one window at a time
-                         predictor=self.model,
-                         overlap=0.5,
-                     )
-
-                     # Get softmax probabilities
-                     probs = self.softmax_transform(predictions[0])  # Get first (only) item
-
-                     # Apply inverse transform with correct dimensions
-                     inv_probs = inverse_transform(probs)
-
-                     # Accumulate probabilities
-                     acc_probs[0] += inv_probs
-
-                     # Clear memory
-                     del predictions, probs, inv_probs, aug_sample

-             # Average accumulated probabilities
-             acc_probs = acc_probs / len(self.tta_transforms)
+             # Average probabilities across models (and TTA augmentations if applied)
+             acc_probs = acc_probs / len(self.models)

-             # Convert to discrete prediction - get argmax along class dimension
-             # This gives us a tensor of shape [1, Z, Y, X] with discrete class indices
-             discrete_pred = torch.argmax(acc_probs, dim=1)
-
-             # Add to results - keeping only the spatial dimensions [Z, Y, X]
-             results.append(discrete_pred[0])
+             # Convert to discrete prediction
+             discrete_pred = torch.argmax(acc_probs, dim=0)
+             results.append(discrete_pred)

              # Clear memory
              del acc_probs, discrete_pred
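
The rewritten _run_inference above averages per-voxel softmax probabilities across every model in the soup (and, via _apply_tta_single_model, across every TTA view) before taking an argmax over the class axis. A toy sketch of that averaging scheme, with illustrative shapes rather than octopi defaults:

import torch

# Two-model "soup": each model contributes per-voxel class probabilities.
Nclass, Z, Y, X = 3, 8, 16, 16
logits_per_model = [torch.randn(Nclass, Z, Y, X) for _ in range(2)]

acc = torch.zeros(Nclass, Z, Y, X)
for logits in logits_per_model:
    acc += torch.softmax(logits, dim=0)   # probabilities along the class axis
acc /= len(logits_per_model)              # average across models

labels = torch.argmax(acc, dim=0)         # discrete segmentation map [Z, Y, X]
assert labels.shape == (Z, Y, X)
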
@@ -157,50 +190,52 @@ class Predictor:
      def predict_on_gpu(self,
                         runIDs: List[str],
                         voxel_spacing: float,
-                        tomo_algorithm: str ):
+                        tomo_algorithm: str):

          # Load data for the current batch
-         test_loader, test_dataset = io.create_predict_dataloader(
+         test_loader, test_dataset = dataio.create_predict_dataloader(
              self.root,
              voxel_spacing, tomo_algorithm,
              runIDs)

          # Determine Input Crop Size.
          if self.input_dim is None:
-             self.input_dim = io.get_input_dimensions(test_dataset, self.dim_in)
+             self.input_dim = dataio.get_input_dimensions(test_dataset, self.dim_in)

          predictions = []
          with torch.no_grad():
              for data in tqdm(test_loader):
                  tomogram = data['image'].to(self.device)
-                 if self.apply_tta: data['pred'] = self._run_inference_tta(tomogram)
-                 else: data['pred'] = self._run_inference(tomogram)
+                 data['pred'] = self._run_inference(tomogram)
+
                  for idx in range(len(data['image'])):
                      predictions.append(data['pred'][idx].squeeze(0).numpy(force=True))

          return predictions

      def batch_predict(self,
-                       num_tomos_per_batch = 15,
-                       runIDs: Optional[str] = None,
+                       num_tomos_per_batch: int = 15,
+                       runIDs: Optional[List[str]] = None,
                        voxel_spacing: float = 10,
                        tomo_algorithm: str = 'denoised',
                        segmentation_name: str = 'prediction',
                        segmentation_user_id: str = 'octopi',
                        segmentation_session_id: str = '0'):
-
          """Run inference on tomograms in batches."""

          # If runIDs are not provided, load all runs
          if runIDs is None:
-             runIDs = [run.name for run in self.root.runs]
-
+             runIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is not None]
+             skippedRunIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is None]
+
+             if skippedRunIDs:
+                 print(f"Warning: skipping runs with no voxel spacing {voxel_spacing}: {skippedRunIDs}")
+
          # Iterate over batches of runIDs
          for i in range(0, len(runIDs), num_tomos_per_batch):
-
              # Get a batch of runIDs
              batch_ids = runIDs[i:i + num_tomos_per_batch]
-             print('Running Inference on the Follow RunIDs: ', batch_ids)
+             print('Running Inference on the Following RunIDs: ', batch_ids)

              predictions = self.predict_on_gpu(batch_ids, voxel_spacing, tomo_algorithm)

@@ -208,7 +243,7 @@ class Predictor:
          for ind in range(len(batch_ids)):
              run = self.root.get_run(batch_ids[ind])
              seg = predictions[ind]
-             write.segmentation(run, seg, segmentation_user_id, segmentation_name,
+             writers.segmentation(run, seg, segmentation_user_id, segmentation_name,
                                 segmentation_session_id, voxel_spacing)

          # After processing and saving predictions for a batch:
@@ -220,98 +255,18 @@ class Predictor:

      def create_tta_augmentations(self):
          """Define TTA augmentations and inverse transforms."""
-
-         # Instead of Flip lets rotate around the first axis 3 times (90,180,270)
+         # Rotate around the YZ plane (dims 3,4 for input, dims 2,3 for output)
          self.tta_transforms = [
-             lambda x: x,                                 # Identity (no augmentation)
-             lambda x: torch.rot90(x, k=1, dims=(3, 4)),  # 90° rotation
-             lambda x: torch.rot90(x, k=2, dims=(3, 4)),  # 180° rotation
-             lambda x: torch.rot90(x, k=3, dims=(3, 4)),  # 270° rotation
-             # Flip(spatial_axis=0),  # Flip along x-axis (depth)
-             # Flip(spatial_axis=1),  # Flip along y-axis (height)
-             # Flip(spatial_axis=2),  # Flip along z-axis (width)
+             lambda x: x,                                 # Identity (no augmentation)
+             lambda x: torch.rot90(x, k=1, dims=(3, 4)),  # 90° rotation
+             lambda x: torch.rot90(x, k=2, dims=(3, 4)),  # 180° rotation
+             lambda x: torch.rot90(x, k=3, dims=(3, 4)),  # 270° rotation
          ]

          # Define inverse transformations (flip back to original orientation)
          self.inverse_tta_transforms = [
-             lambda x: x,                                  # Identity (no transformation needed)
+             lambda x: x,                                  # Identity (no transformation needed)
              lambda x: torch.rot90(x, k=-1, dims=(2, 3)),  # Inverse of 90° (i.e. -90°)
              lambda x: torch.rot90(x, k=-2, dims=(2, 3)),  # Inverse of 180° (i.e. -180°)
              lambda x: torch.rot90(x, k=-3, dims=(2, 3)),  # Inverse of 270° (i.e. -270°)
-             # Flip(spatial_axis=0),  # Undo Flip along x-axis
-             # Flip(spatial_axis=1),  # Undo Flip along y-axis
-             # Flip(spatial_axis=2),  # Undo Flip along z-axis
-         ]
-
- ###################################################################################################################################################
-
- class MultiGPUPredictor(Predictor):
-
-     def __init__(self,
-                  config: str,
-                  model_config: str,
-                  model_weights: str):
-         super().__init__(config, model_config, model_weights)
-         self.num_gpus = torch.cuda.device_count()
-         if self.num_gpus < 2:
-             raise RuntimeError("MultiGPUPredictor requires at least 2 GPUs.")
-
-     def predict_on_gpu(self, gpu_id: int, batch_ids: List[str], voxel_spacing: float, tomo_algorithm: str) -> List[np.ndarray]:
-         """Helper function to run inference on a single GPU."""
-         device = torch.device(f'cuda:{gpu_id}')
-         self.model.to(device)
-
-         # Load data specific to the batch assigned to this GPU
-         test_loader = io.load_predict_data(self.root, batch_ids, voxel_spacing, tomo_algorithm)
-         predictions = []
-
-         with torch.no_grad():
-             for data in tqdm(test_loader, desc=f"GPU {gpu_id}"):
-                 tomogram = data['image'].to(device)
-                 data["prediction"] = self.run_inference(tomogram)
-                 data = [self.post_processing(i) for i in decollate_batch(data)]
-                 for b in data:
-                     predictions.append(b['prediction'].squeeze(0).cpu().numpy())
-
-         return predictions
-
-     def multi_gpu_inference(self,
-                             num_tomos_per_batch: int = 15,
-                             runIDs: Optional[List[str]] = None,
-                             voxel_spacing: float = 10,
-                             tomo_algorithm: str = 'denoised',
-                             save: bool = False,
-                             segmentation_name: str = 'prediction',
-                             segmentation_user_id: str = 'monai',
-                             segmentation_session_id: str = '0') -> Optional[List[np.ndarray]]:
-         """Run inference across multiple GPUs, optionally saving results or returning predictions."""
-
-         runIDs = runIDs or [run.name for run in self.root.runs]
-         all_predictions = []
-
-         # Divide runIDs into batches for each GPU
-         batches = [runIDs[i:i + num_tomos_per_batch] for i in range(0, len(run_ids), num_tomos_per_batch)]
-
-         # Run inference in parallel across GPUs
-         for i in range(0, len(batches), self.num_gpus):
-             gpu_batches = batches[i:i + self.num_gpus]
-             with Pool(processes=self.num_gpus) as pool:
-                 results = pool.starmap(
-                     self.predict_on_gpu,
-                     [(gpu_id, gpu_batches[gpu_id], voxel_spacing, tomo_algorithm) for gpu_id in range(len(gpu_batches))]
-                 )
-
-             # Collect and save results
-             for gpu_id, predictions in enumerate(results):
-                 if save:
-                     for idx, run_id in enumerate(gpu_batches[gpu_id]):
-                         run = self.root.get_run(run_id)
-                         segmentation = predictions[idx]
-                         write.segmentation(run, segmentation, segmentation_user_id, segmentation_name,
-                                            segmentation_session_id, voxel_spacing)
-                 else:
-                     all_predictions.extend(predictions)
-
-         print('Multi-GPU predictions complete.')
-
-         return None if save else all_predictions
+         ]
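
Note the deliberate dimension mismatch in the pair above: the forward transforms rotate batched 5D input of shape [B, C, Z, Y, X] (dims 3 and 4), while the inverses act on the per-sample 4D probability map of shape [Nclass, Z, Y, X] that _apply_tta_single_model extracts via predictions[0] (dims 2 and 3). A quick self-check under those shape assumptions:

import torch

x = torch.randn(1, 1, 8, 16, 16)                   # batched input [B, C, Z, Y, X]
aug = torch.rot90(x, k=1, dims=(3, 4))             # forward TTA view (5D)
probs = aug[0]                                     # per-sample map [C, Z, Y, X]
restored = torch.rot90(probs, k=-1, dims=(2, 3))   # inverse applied to the 4D tensor
assert torch.equal(restored, x[0])                 # back in the original frame
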
octopi/pytorch/segmentation_multigpu.py ADDED
@@ -0,0 +1,162 @@
+ from concurrent.futures import ThreadPoolExecutor
+ from octopi.pytorch.segmentation import Predictor
+ from typing import List, Union, Optional
+ from copick_utils.io import writers
+ import queue, torch
+
+ class MultiGPUPredictor(Predictor):
+
+     def __init__(self,
+                  config: str,
+                  model_config: Union[str, List[str]],
+                  model_weights: Union[str, List[str]],
+                  apply_tta: bool = True):
+
+         # Initialize parent normally
+         super().__init__(config, model_config, model_weights, apply_tta)
+
+         self.num_gpus = torch.cuda.device_count()
+         print(f"Available GPUs: {self.num_gpus}")
+
+         # Only create GPU-specific models if we have multiple GPUs
+         if self.num_gpus > 1:
+             self._create_gpu_models()
+
+     def _create_gpu_models(self):
+         """Create separate model instances for each GPU."""
+         self.gpu_models = {}
+
+         for gpu_id in range(self.num_gpus):
+             device = torch.device(f'cuda:{gpu_id}')
+             gpu_models = []
+
+             # Copy each model to this GPU
+             for model in self.models:
+                 gpu_model = type(model)()
+                 gpu_model.load_state_dict(model.state_dict())
+                 gpu_model.to(device)
+                 gpu_model.eval()
+                 gpu_models.append(gpu_model)
+
+             self.gpu_models[gpu_id] = gpu_models
+             print(f"Models loaded on GPU {gpu_id}")
+
+     def _run_on_gpu(self, gpu_id: int, batch_ids: List[str],
+                     voxel_spacing: float, tomo_algorithm: str,
+                     segmentation_name: str, segmentation_user_id: str,
+                     segmentation_session_id: str):
+         """Run inference on a specific GPU for a batch of runs."""
+         device = torch.device(f'cuda:{gpu_id}')
+
+         # Temporarily switch to this GPU's models and device
+         original_device = self.device
+         original_models = self.models
+
+         self.device = device
+         self.models = self.gpu_models[gpu_id]
+
+         try:
+             print(f"GPU {gpu_id} processing runs: {batch_ids}")
+
+             # Run prediction using parent class method
+             predictions = self.predict_on_gpu(batch_ids, voxel_spacing, tomo_algorithm)
+
+             # Save predictions
+             for idx, run_id in enumerate(batch_ids):
+                 run = self.root.get_run(run_id)
+                 seg = predictions[idx]
+                 writers.segmentation(run, seg, segmentation_user_id,
+                                      segmentation_name, segmentation_session_id,
+                                      voxel_spacing)
+
+             # Clean up
+             del predictions
+             torch.cuda.empty_cache()
+
+         finally:
+             # Restore original settings
+             self.device = original_device
+             self.models = original_models
+
+     def multigpu_batch_predict(self,
+                                num_tomos_per_batch: int = 15,
+                                runIDs: Optional[List[str]] = None,
+                                voxel_spacing: float = 10,
+                                tomo_algorithm: str = 'denoised',
+                                segmentation_name: str = 'prediction',
+                                segmentation_user_id: str = 'octopi',
+                                segmentation_session_id: str = '0'):
+         """Run inference across multiple GPUs using threading."""
+
+         # Get runIDs if not provided
+         if runIDs is None:
+             runIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is not None]
+             skippedRunIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is None]
+             if skippedRunIDs:
+                 print(f"Warning: skipping runs with no voxel spacing {voxel_spacing}: {skippedRunIDs}")
+
+         # Split runIDs into batches
+         batches = [runIDs[i:i + num_tomos_per_batch]
+                    for i in range(0, len(runIDs), num_tomos_per_batch)]
+
+         print(f"Processing {len(batches)} batches across {self.num_gpus} GPUs")
+
+         # Create work queue
+         batch_queue = queue.Queue()
+         for batch in batches:
+             batch_queue.put(batch)
+
+         def worker(gpu_id):
+             while True:
+                 try:
+                     batch_ids = batch_queue.get_nowait()
+                     self._run_on_gpu(gpu_id, batch_ids, voxel_spacing, tomo_algorithm,
+                                      segmentation_name, segmentation_user_id,
+                                      segmentation_session_id)
+                     batch_queue.task_done()
+                 except queue.Empty:
+                     break
+                 except Exception as e:
+                     print(f"Error on GPU {gpu_id}: {e}")
+                     batch_queue.task_done()
+
+         # Start worker threads for each GPU
+         with ThreadPoolExecutor(max_workers=self.num_gpus) as executor:
+             futures = [executor.submit(worker, gpu_id) for gpu_id in range(self.num_gpus)]
+             for future in futures:
+                 future.result()
+
+         print('Multi-GPU predictions complete!')
+
+     def batch_predict(self,
+                       num_tomos_per_batch: int = 15,
+                       runIDs: Optional[List[str]] = None,
+                       voxel_spacing: float = 10,
+                       tomo_algorithm: str = 'denoised',
+                       segmentation_name: str = 'prediction',
+                       segmentation_user_id: str = 'octopi',
+                       segmentation_session_id: str = '0'):
+         """Smart batch predict: uses multi-GPU if available, otherwise single GPU."""
+
+         if self.num_gpus > 1:
+             print("Using multi-GPU inference")
+             self.multigpu_batch_predict(
+                 num_tomos_per_batch=num_tomos_per_batch,
+                 runIDs=runIDs,
+                 voxel_spacing=voxel_spacing,
+                 tomo_algorithm=tomo_algorithm,
+                 segmentation_name=segmentation_name,
+                 segmentation_user_id=segmentation_user_id,
+                 segmentation_session_id=segmentation_session_id
+             )
+         else:
+             print("Using single GPU inference")
+             super().batch_predict(
+                 num_tomos_per_batch=num_tomos_per_batch,
+                 runIDs=runIDs,
+                 voxel_spacing=voxel_spacing,
+                 tomo_algorithm=tomo_algorithm,
+                 segmentation_name=segmentation_name,
+                 segmentation_user_id=segmentation_user_id,
+                 segmentation_session_id=segmentation_session_id
+             )
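
A minimal usage sketch for the new predictor, inferred only from the signatures above; the project and model paths are hypothetical placeholders:

from octopi.pytorch.segmentation_multigpu import MultiGPUPredictor

predictor = MultiGPUPredictor(
    config='copick_config.json',                  # hypothetical copick project file
    model_config=['unet_a.yaml', 'unet_b.yaml'],  # two configs -> model soup
    model_weights=['unet_a.pth', 'unet_b.pth'],   # paired weights, same order
    apply_tta=True,
)

# batch_predict() dispatches to multigpu_batch_predict() when more than one
# GPU is visible, and falls back to the single-GPU Predictor otherwise.
predictor.batch_predict(
    num_tomos_per_batch=15,
    voxel_spacing=10,
    tomo_algorithm='denoised',
    segmentation_name='prediction',
)
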
octopi/pytorch/trainer.py CHANGED
@@ -1,6 +1,6 @@
- from octopi import visualization_tools as viz
+ from octopi.utils import visualization_tools as viz
  from monai.inferers import sliding_window_inference
- from octopi import stopping_criteria
+ from octopi.utils import stopping_criteria
  from monai.transforms import AsDiscrete
  from monai.data import decollate_batch
  import torch, os, mlflow, re
@@ -101,6 +101,9 @@ class ModelTrainer:
                      device=self.device
                  )

+                 del val_inputs
+                 torch.cuda.empty_cache()
+
                  # Compute the loss for this batch
                  loss = self.loss_function(val_outputs, val_labels)  # Assuming self.loss_function is defined
                  val_loss += loss.item()  # Accumulate the loss
@@ -112,6 +115,9 @@ class ModelTrainer:
                  # Compute metrics
                  self.metrics_function(y_pred=metric_val_outputs, y=metric_val_labels)

+                 del val_labels, val_outputs, metric_val_outputs, metric_val_labels
+                 torch.cuda.empty_cache()
+
                  # # Contains recall, precision, and f1 for each class
                  metric_values = self.metrics_function.aggregate(reduction='mean_batch')

@@ -435,4 +441,4 @@ class ModelTrainer:
          best_metric = 'avg_f1'

          return best_metric
-
+
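
Both insertions in the validation loop apply the same idiom: drop references to large GPU tensors as soon as they have been consumed, then ask the caching allocator to release the blocks. A generic sketch of the pattern, independent of the octopi trainer:

import torch

def validation_step(model, inputs, labels, loss_fn, device):
    """Run one validation step, freeing GPU tensors as early as possible."""
    inputs, labels = inputs.to(device), labels.to(device)
    with torch.no_grad():
        outputs = model(inputs)
    del inputs                   # forward pass done; inputs no longer needed
    torch.cuda.empty_cache()     # return cached blocks to the allocator

    loss = loss_fn(outputs, labels).item()
    del outputs, labels          # free before the next batch arrives
    torch.cuda.empty_cache()
    return loss
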
octopi/{losses.py → utils/losses.py} RENAMED (file without changes)
octopi/{submit_slurm.py → utils/submit_slurm.py} RENAMED (file without changes)