octopi-1.1-py3-none-any.whl → octopi-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of octopi might be problematic.

Files changed (45)
  1. octopi/__init__.py +1 -0
  2. octopi/datasets/cached_datset.py +1 -1
  3. octopi/datasets/generators.py +1 -1
  4. octopi/datasets/io.py +200 -0
  5. octopi/datasets/multi_config_generator.py +1 -1
  6. octopi/entry_points/common.py +5 -5
  7. octopi/entry_points/create_slurm_submission.py +1 -1
  8. octopi/entry_points/run_create_targets.py +6 -6
  9. octopi/entry_points/run_evaluate.py +4 -3
  10. octopi/entry_points/run_extract_mb_picks.py +5 -5
  11. octopi/entry_points/run_localize.py +8 -9
  12. octopi/entry_points/run_optuna.py +7 -7
  13. octopi/entry_points/run_segment_predict.py +4 -4
  14. octopi/entry_points/run_train.py +7 -8
  15. octopi/extract/localize.py +11 -19
  16. octopi/extract/membranebound_extract.py +11 -10
  17. octopi/extract/midpoint_extract.py +3 -3
  18. octopi/models/common.py +1 -1
  19. octopi/processing/create_targets_from_picks.py +3 -4
  20. octopi/processing/evaluate.py +24 -11
  21. octopi/processing/importers.py +4 -4
  22. octopi/pytorch/hyper_search.py +2 -3
  23. octopi/pytorch/model_search_submitter.py +4 -4
  24. octopi/pytorch/segmentation.py +141 -190
  25. octopi/pytorch/segmentation_multigpu.py +162 -0
  26. octopi/pytorch/trainer.py +2 -2
  27. octopi/utils/__init__.py +0 -0
  28. octopi/utils/config.py +57 -0
  29. octopi/utils/io.py +128 -0
  30. octopi/{utils.py → utils/parsers.py} +10 -84
  31. octopi/{stopping_criteria.py → utils/stopping_criteria.py} +3 -3
  32. octopi/{visualization_tools.py → utils/visualization_tools.py} +4 -4
  33. octopi/workflows.py +236 -0
  34. {octopi-1.1.dist-info → octopi-1.2.0.dist-info}/METADATA +41 -29
  35. octopi-1.2.0.dist-info/RECORD +62 -0
  36. {octopi-1.1.dist-info → octopi-1.2.0.dist-info}/WHEEL +1 -1
  37. octopi-1.2.0.dist-info/entry_points.txt +3 -0
  38. {octopi-1.1.dist-info → octopi-1.2.0.dist-info/licenses}/LICENSE +3 -3
  39. octopi/io.py +0 -457
  40. octopi/processing/my_metrics.py +0 -26
  41. octopi/processing/writers.py +0 -102
  42. octopi-1.1.dist-info/RECORD +0 -59
  43. octopi-1.1.dist-info/entry_points.txt +0 -4
  44. /octopi/{losses.py → utils/losses.py} +0 -0
  45. /octopi/{submit_slurm.py → utils/submit_slurm.py} +0 -0
octopi/pytorch/segmentation.py CHANGED
@@ -5,11 +5,12 @@ from monai.data import MetaTensor
 from monai.transforms import (
     Compose, AsDiscrete, Activations
 )
-import octopi.processing.writers as write
+from typing import List, Optional, Union
+from octopi.datasets import io as dataio
+from copick_utils.io import writers
 from octopi.models import common
-from typing import List, Optional
 import torch, copick, gc, os
-from octopi import io, utils
+from octopi.utils import io
 from tqdm import tqdm
 import numpy as np
 
@@ -17,20 +18,14 @@ class Predictor:
 
     def __init__(self,
                  config: str,
-                 model_config: str,
-                 model_weights: str,
+                 model_config: Union[str, List[str]],
+                 model_weights: Union[str, List[str]],
                  apply_tta: bool = True,
                  device: Optional[str] = None):
 
+        # Open the Copick Project
         self.config = config
         self.root = copick.from_file(config)
-
-        # Load the model config
-        model_config = utils.load_yaml(model_config)
-
-        self.Nclass = model_config['model']['num_classes']
-        self.dim_in = model_config['model']['dim_in']
-        self.input_dim = None
 
         # Get the number of GPUs available
         num_gpus = torch.cuda.device_count()
@@ -44,109 +39,147 @@ class Predictor:
         self.device = device
         print('Running Inference On: ', self.device)
 
-        # Check to see if the model weights file exists
-        if not os.path.exists(model_weights):
-            raise ValueError(f"Model weights file does not exist: {model_weights}")
+        # Initialize TTA if enabled
+        self.apply_tta = apply_tta
+        self.create_tta_augmentations()
 
-        # Load the model weights
-        model_builder = common.get_model(model_config['model']['architecture'])
-        model_builder.build_model(model_config['model'])
-        self.model = model_builder.model
-        state_dict = torch.load(model_weights, map_location=self.device, weights_only=True)
-        self.model.load_state_dict(state_dict)
-        self.model.to(self.device)
-        self.model.eval()
+        # Determine if Model Soup is Enabled
+        if isinstance(model_weights, str):
+            model_weights = [model_weights]
+        self.apply_modelsoup = len(model_weights) > 1
 
-        # Initialize TTA if enabled
-        self.apply_tta = apply_tta
-        if self.apply_tta:
-            self.create_tta_augmentations()
-            # self.post_transforms = Compose([
-            #     Activations(softmax=True) # Keep probability output
-            # ])
-            self.softmax_transform = Compose([
-                Activations(softmax=True) # Keep probability output
-            ])
+        # Handle Single Model Config or Multiple Model Configs
+        if isinstance(model_config, str):
+            model_config = [model_config] * len(model_weights)
+        elif len(model_config) != len(model_weights):
+            raise ValueError("Number of model configs must match number of model weights.")
+
+        # Load the model(s)
+        self._load_models(model_config, model_weights)
+
+    def _load_models(self, model_config: List[str], model_weights: List[str]):
+        """Load a single model or multiple models for soup."""
+
+        self.models = []
+        for i, (config_path, weights_path) in enumerate(zip(model_config, model_weights)):
+
+            # Load the Model Config and Model Builder
+            current_modelconfig = io.load_yaml(config_path)
+            model_builder = common.get_model(current_modelconfig['model']['architecture'])
+
+            # Check if the weights file exists
+            if not os.path.exists(weights_path):
+                raise ValueError(f"Model weights file does not exist: {weights_path}")
 
-            # Create the final discretization transform
-            self.discretize_transform = AsDiscrete(argmax=True)
-        else:
-            # Define the post-processing transforms
-            self.post_transforms = Compose([
-                Activations(softmax=True),
-                AsDiscrete(argmax=True)
-            ])
+            # Create model
+            model_builder.build_model(current_modelconfig['model'])
+            model = model_builder.model
+
+            # Load weights
+            state_dict = torch.load(weights_path, map_location=self.device, weights_only=True)
+            model.load_state_dict(state_dict)
+            model.to(self.device)
+            model.eval()
 
+            self.models.append(model)
 
-    def _run_inference(self, input):
-        """Apply sliding window inference to the input."""
-        with torch.no_grad():
-            predictions = sliding_window_inference(
-                inputs=input,
-                roi_size=(self.dim_in, self.dim_in, self.dim_in),
-                sw_batch_size=4, # one window is proecessed at a time
-                predictor=self.model,
-                overlap=0.5,
-            )
-        return [self.post_transforms(i) for i in decollate_batch(predictions)]
+        # For backward compatibility, also set self.model to the first model
+        self.model = self.models[0]
+
+        # Set the Number of Classes and Input Dimensions - Assume All Models are the Same
+        self.Nclass = current_modelconfig['model']['num_classes']
+        self.dim_in = current_modelconfig['model']['dim_in']
+        self.input_dim = None
 
-    def _run_inference_tta(self, input_data):
-        """Memory-efficient TTA implementation that returns proper discrete segmentation maps."""
+        # Print a message if Model Soup is Enabled
+        if self.apply_modelsoup:
+            print(f'Model Soup is Enabled : {len(self.models)} models loaded for ensemble inference')
+
+    def _run_single_model_inference(self, model, input_data):
+        """Run sliding window inference on a single model."""
+        return sliding_window_inference(
+            inputs=input_data,
+            roi_size=(self.dim_in, self.dim_in, self.dim_in),
+            sw_batch_size=4,
+            predictor=model,
+            overlap=0.5,
+        )
+
+    def _apply_tta_single_model(self, model, single_sample):
+        """Apply TTA to a single model and single sample."""
+        # Initialize probability accumulator
+        acc_probs = torch.zeros(
+            (1, self.Nclass, *single_sample.shape[2:]),
+            dtype=torch.float32, device=self.device
+        )
+
+        # Process each augmentation
+        with torch.no_grad():
+            for tta_transform, inverse_transform in zip(self.tta_transforms, self.inverse_tta_transforms):
+                # Apply transform
+                aug_sample = tta_transform(single_sample)
+
+                # Run inference
+                predictions = self._run_single_model_inference(model, aug_sample)
+
+                # Get softmax probabilities
+                probs = torch.softmax(predictions[0], dim=0)
+
+                # Apply inverse transform
+                inv_probs = inverse_transform(probs)
+
+                # Accumulate probabilities
+                acc_probs[0] += inv_probs
+
+                # Clear memory
+                del predictions, probs, inv_probs, aug_sample
+                torch.cuda.empty_cache()
+
+        # Average accumulated probabilities
+        acc_probs = acc_probs / len(self.tta_transforms)
 
+        return acc_probs[0]  # Return shape [Nclass, Z, Y, X]
+
+    def _run_inference(self, input_data):
+        """
+        Main inference function that handles all combinations - Model Soup and/or TTA
+        """
         batch_size = input_data.shape[0]
         results = []
 
-        # Process one sample at a time
+        # Process one sample at a time for memory efficiency
        for sample_idx in range(batch_size):
-            # Extract single sample
             single_sample = input_data[sample_idx:sample_idx+1]
 
             # Initialize probability accumulator for this sample
-            # Shape: [1, Nclass, Z, Y, X]
             acc_probs = torch.zeros(
-                (1, self.Nclass, *single_sample.shape[2:]),
+                (self.Nclass, *single_sample.shape[2:]),
                 dtype=torch.float32, device=self.device
             )
 
-            # Process each augmentation
+            # Process each model
             with torch.no_grad():
-                for tta_transform, inverse_transform in zip(self.tta_transforms, self.inverse_tta_transforms):
-                    # Apply transform to single sample
-                    aug_sample = tta_transform(single_sample)
+                for model in self.models:
+                    # Apply TTA with this model
+                    if self.apply_tta:
+                        model_probs = self._apply_tta_single_model(model, single_sample)
+                    # Run inference without TTA
+                    else:
+                        predictions = self._run_single_model_inference(model, single_sample)
+                        model_probs = torch.softmax(predictions[0], dim=0)
+                        del predictions
 
-                    # Free memory
+                    # Accumulate probabilities from this model
+                    acc_probs += model_probs
+                    del model_probs
                     torch.cuda.empty_cache()
-
-                    # Run inference (one sample at a time)
-                    predictions = sliding_window_inference(
-                        inputs=aug_sample,
-                        roi_size=(self.dim_in, self.dim_in, self.dim_in),
-                        sw_batch_size=4, # Process one window at a time
-                        predictor=self.model,
-                        overlap=0.5,
-                    )
-
-                    # Get softmax probabilities
-                    probs = self.softmax_transform(predictions[0]) # Get first (only) item
-
-                    # Apply inverse transform with correct dimensions
-                    inv_probs = inverse_transform(probs)
-
-                    # Accumulate probabilities
-                    acc_probs[0] += inv_probs
-
-                    # Clear memory
-                    del predictions, probs, inv_probs, aug_sample
 
-            # Average accumulated probabilities
-            acc_probs = acc_probs / len(self.tta_transforms)
+            # Average probabilities across models (and TTA augmentations if applied)
+            acc_probs = acc_probs / len(self.models)
 
-            # Convert to discrete prediction - get argmax along class dimension
-            # This gives us a tensor of shape [1, Z, Y, X] with discrete class indices
-            discrete_pred = torch.argmax(acc_probs, dim=1)
-
-            # Add to results - keeping only the spatial dimensions [Z, Y, X]
-            results.append(discrete_pred[0])
+            # Convert to discrete prediction
+            discrete_pred = torch.argmax(acc_probs, dim=0)
+            results.append(discrete_pred)
 
             # Clear memory
             del acc_probs, discrete_pred
@@ -157,38 +190,37 @@ class Predictor:
     def predict_on_gpu(self,
                        runIDs: List[str],
                        voxel_spacing: float,
-                       tomo_algorithm: str ):
+                       tomo_algorithm: str):
 
         # Load data for the current batch
-        test_loader, test_dataset = io.create_predict_dataloader(
+        test_loader, test_dataset = dataio.create_predict_dataloader(
            self.root,
            voxel_spacing, tomo_algorithm,
            runIDs)
 
        # Determine Input Crop Size.
        if self.input_dim is None:
-            self.input_dim = io.get_input_dimensions(test_dataset, self.dim_in)
+            self.input_dim = dataio.get_input_dimensions(test_dataset, self.dim_in)
 
        predictions = []
        with torch.no_grad():
            for data in tqdm(test_loader):
                tomogram = data['image'].to(self.device)
-                if self.apply_tta: data['pred'] = self._run_inference_tta(tomogram)
-                else: data['pred'] = self._run_inference(tomogram)
+                data['pred'] = self._run_inference(tomogram)
+
                for idx in range(len(data['image'])):
                    predictions.append(data['pred'][idx].squeeze(0).numpy(force=True))
 
        return predictions
 
    def batch_predict(self,
-                      num_tomos_per_batch = 15,
-                      runIDs: Optional[str] = None,
+                      num_tomos_per_batch: int = 15,
+                      runIDs: Optional[List[str]] = None,
                       voxel_spacing: float = 10,
                       tomo_algorithm: str = 'denoised',
                       segmentation_name: str = 'prediction',
                       segmentation_user_id: str = 'octopi',
                       segmentation_session_id: str = '0'):
-
        """Run inference on tomograms in batches."""
 
        # If runIDs are not provided, load all runs
@@ -201,10 +233,9 @@ class Predictor:
 
        # Iterate over batches of runIDs
        for i in range(0, len(runIDs), num_tomos_per_batch):
-
            # Get a batch of runIDs
            batch_ids = runIDs[i:i + num_tomos_per_batch]
-            print('Running Inference on the Follow RunIDs: ', batch_ids)
+            print('Running Inference on the Following RunIDs: ', batch_ids)
 
            predictions = self.predict_on_gpu(batch_ids, voxel_spacing, tomo_algorithm)
 
@@ -212,7 +243,7 @@ class Predictor:
            for ind in range(len(batch_ids)):
                run = self.root.get_run(batch_ids[ind])
                seg = predictions[ind]
-                write.segmentation(run, seg, segmentation_user_id, segmentation_name,
+                writers.segmentation(run, seg, segmentation_user_id, segmentation_name,
                                     segmentation_session_id, voxel_spacing)
 
            # After processing and saving predictions for a batch:
@@ -224,98 +255,18 @@
 
    def create_tta_augmentations(self):
        """Define TTA augmentations and inverse transforms."""
-
-        # Instead of Flip lets rotate around the first axis 3 times (90,180,270)
+        # Rotate around the YZ plane (dims 3,4 for input, dims 2,3 for output)
        self.tta_transforms = [
-            lambda x: x, # Identity (no augmentation)
-            lambda x: torch.rot90(x, k=1, dims=(3, 4)), # 90° rotation
-            lambda x: torch.rot90(x, k=2, dims=(3, 4)), # 180° rotation
-            lambda x: torch.rot90(x, k=3, dims=(3, 4)), # 270° rotation
-            # lambda x: torch.flip(x, dims=(3,)), # Flip along height (spatial_axis=1)
-            # lambda x: torch.flip(x, dims=(4,)), # Flip along width (spatial_axis=2)
-            # lambda x: torch.flip(x, dims=(3, 4)), # Flip along both height and width
+            lambda x: x, # Identity (no augmentation)
+            lambda x: torch.rot90(x, k=1, dims=(3, 4)), # 90° rotation
+            lambda x: torch.rot90(x, k=2, dims=(3, 4)), # 180° rotation
+            lambda x: torch.rot90(x, k=3, dims=(3, 4)), # 270° rotation
        ]
 
        # Define inverse transformations (flip back to original orientation)
        self.inverse_tta_transforms = [
-            lambda x: x, # Identity (no transformation needed)
+            lambda x: x, # Identity (no transformation needed)
            lambda x: torch.rot90(x, k=-1, dims=(2, 3)), # Inverse of 90° (i.e. -90°)
            lambda x: torch.rot90(x, k=-2, dims=(2, 3)), # Inverse of 180° (i.e. -180°)
            lambda x: torch.rot90(x, k=-3, dims=(2, 3)), # Inverse of 270° (i.e. -270°)
-            # lambda x: torch.flip(x, dims=(2,)), # Same as forward
-            # lambda x: torch.flip(x, dims=(3,)), # Same as forward
-            # lambda x: torch.flip(x, dims=(2, 3)), # Same as forward
-        ]
-
-###################################################################################################################################################
-
-class MultiGPUPredictor(Predictor):
-
-    def __init__(self,
-                 config: str,
-                 model_config: str,
-                 model_weights: str):
-        super().__init__(config, model_config, model_weights)
-        self.num_gpus = torch.cuda.device_count()
-        if self.num_gpus < 2:
-            raise RuntimeError("MultiGPUPredictor requires at least 2 GPUs.")
-
-    def predict_on_gpu(self, gpu_id: int, batch_ids: List[str], voxel_spacing: float, tomo_algorithm: str) -> List[np.ndarray]:
-        """Helper function to run inference on a single GPU."""
-        device = torch.device(f'cuda:{gpu_id}')
-        self.model.to(device)
-
-        # Load data specific to the batch assigned to this GPU
-        test_loader = io.load_predict_data(self.root, batch_ids, voxel_spacing, tomo_algorithm)
-        predictions = []
-
-        with torch.no_grad():
-            for data in tqdm(test_loader, desc=f"GPU {gpu_id}"):
-                tomogram = data['image'].to(device)
-                data["prediction"] = self.run_inference(tomogram)
-                data = [self.post_processing(i) for i in decollate_batch(data)]
-                for b in data:
-                    predictions.append(b['prediction'].squeeze(0).cpu().numpy())
-
-        return predictions
-
-    def multi_gpu_inference(self,
-                            num_tomos_per_batch: int = 15,
-                            runIDs: Optional[List[str]] = None,
-                            voxel_spacing: float = 10,
-                            tomo_algorithm: str = 'denoised',
-                            save: bool = False,
-                            segmentation_name: str = 'prediction',
-                            segmentation_user_id: str = 'monai',
-                            segmentation_session_id: str = '0') -> Optional[List[np.ndarray]]:
-        """Run inference across multiple GPUs, optionally saving results or returning predictions."""
-
-        runIDs = runIDs or [run.name for run in self.root.runs]
-        all_predictions = []
-
-        # Divide runIDs into batches for each GPU
-        batches = [runIDs[i:i + num_tomos_per_batch] for i in range(0, len(run_ids), num_tomos_per_batch)]
-
-        # Run inference in parallel across GPUs
-        for i in range(0, len(batches), self.num_gpus):
-            gpu_batches = batches[i:i + self.num_gpus]
-            with Pool(processes=self.num_gpus) as pool:
-                results = pool.starmap(
-                    self.predict_on_gpu,
-                    [(gpu_id, gpu_batches[gpu_id], voxel_spacing, tomo_algorithm) for gpu_id in range(len(gpu_batches))]
-                )
-
-            # Collect and save results
-            for gpu_id, predictions in enumerate(results):
-                if save:
-                    for idx, run_id in enumerate(gpu_batches[gpu_id]):
-                        run = self.root.get_run(run_id)
-                        segmentation = predictions[idx]
-                        write.segmentation(run, segmentation, segmentation_user_id, segmentation_name,
-                                           segmentation_session_id, voxel_spacing)
-                else:
-                    all_predictions.extend(predictions)
-
-        print('Multi-GPU predictions complete.')
-
-        return None if save else all_predictions
+        ]
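
The reworked Predictor accepts either a single checkpoint or a list of checkpoints (the "Model Soup" averaging path) and always builds the rotation-based TTA transforms. A minimal usage sketch based only on the signatures visible in this diff; the copick config, YAML configs, and weight paths below are hypothetical placeholders:

    from octopi.pytorch.segmentation import Predictor

    # Hypothetical file paths - substitute your own copick project and checkpoints.
    predictor = Predictor(
        config='copick_config.json',
        model_config=['unet_a.yaml', 'unet_b.yaml'],   # one YAML per checkpoint, or a single str reused for all
        model_weights=['unet_a.pth', 'unet_b.pth'],    # more than one checkpoint enables Model Soup averaging
        apply_tta=True,                                # average identity plus 90/180/270-degree rotations
    )

    # Segment tomograms in batches and write results back to the copick project.
    predictor.batch_predict(
        num_tomos_per_batch=15,
        voxel_spacing=10,
        tomo_algorithm='denoised',
        segmentation_name='prediction',
        segmentation_user_id='octopi',
        segmentation_session_id='0',
    )
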
octopi/pytorch/segmentation_multigpu.py ADDED
@@ -0,0 +1,162 @@
+from concurrent.futures import ThreadPoolExecutor
+from octopi.pytorch.segmentation import Predictor
+from typing import List, Union, Optional
+from copick_utils.io import writers
+import queue, torch
+
+class MultiGPUPredictor(Predictor):
+
+    def __init__(self,
+                 config: str,
+                 model_config: Union[str, List[str]],
+                 model_weights: Union[str, List[str]],
+                 apply_tta: bool = True):
+
+        # Initialize parent normally
+        super().__init__(config, model_config, model_weights, apply_tta)
+
+        self.num_gpus = torch.cuda.device_count()
+        print(f"Available GPUs: {self.num_gpus}")
+
+        # Only create GPU-specific models if we have multiple GPUs
+        if self.num_gpus > 1:
+            self._create_gpu_models()
+
+    def _create_gpu_models(self):
+        """Create separate model instances for each GPU."""
+        self.gpu_models = {}
+
+        for gpu_id in range(self.num_gpus):
+            device = torch.device(f'cuda:{gpu_id}')
+            gpu_models = []
+
+            # Copy each model to this GPU
+            for model in self.models:
+                gpu_model = type(model)()
+                gpu_model.load_state_dict(model.state_dict())
+                gpu_model.to(device)
+                gpu_model.eval()
+                gpu_models.append(gpu_model)
+
+            self.gpu_models[gpu_id] = gpu_models
+            print(f"Models loaded on GPU {gpu_id}")
+
+    def _run_on_gpu(self, gpu_id: int, batch_ids: List[str],
+                    voxel_spacing: float, tomo_algorithm: str,
+                    segmentation_name: str, segmentation_user_id: str,
+                    segmentation_session_id: str):
+        """Run inference on a specific GPU for a batch of runs."""
+        device = torch.device(f'cuda:{gpu_id}')
+
+        # Temporarily switch to this GPU's models and device
+        original_device = self.device
+        original_models = self.models
+
+        self.device = device
+        self.models = self.gpu_models[gpu_id]
+
+        try:
+            print(f"GPU {gpu_id} processing runs: {batch_ids}")
+
+            # Run prediction using parent class method
+            predictions = self.predict_on_gpu(batch_ids, voxel_spacing, tomo_algorithm)
+
+            # Save predictions
+            for idx, run_id in enumerate(batch_ids):
+                run = self.root.get_run(run_id)
+                seg = predictions[idx]
+                writers.segmentation(run, seg, segmentation_user_id,
+                                     segmentation_name, segmentation_session_id,
+                                     voxel_spacing)
+
+            # Clean up
+            del predictions
+            torch.cuda.empty_cache()
+
+        finally:
+            # Restore original settings
+            self.device = original_device
+            self.models = original_models
+
+    def multigpu_batch_predict(self,
+                               num_tomos_per_batch: int = 15,
+                               runIDs: Optional[List[str]] = None,
+                               voxel_spacing: float = 10,
+                               tomo_algorithm: str = 'denoised',
+                               segmentation_name: str = 'prediction',
+                               segmentation_user_id: str = 'octopi',
+                               segmentation_session_id: str = '0'):
+        """Run inference across multiple GPUs using threading."""
+
+        # Get runIDs if not provided
+        if runIDs is None:
+            runIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is not None]
+            skippedRunIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is None]
+            if skippedRunIDs:
+                print(f"Warning: skipping runs with no voxel spacing {voxel_spacing}: {skippedRunIDs}")
+
+        # Split runIDs into batches
+        batches = [runIDs[i:i + num_tomos_per_batch]
+                   for i in range(0, len(runIDs), num_tomos_per_batch)]
+
+        print(f"Processing {len(batches)} batches across {self.num_gpus} GPUs")
+
+        # Create work queue
+        batch_queue = queue.Queue()
+        for batch in batches:
+            batch_queue.put(batch)
+
+        def worker(gpu_id):
+            while True:
+                try:
+                    batch_ids = batch_queue.get_nowait()
+                    self._run_on_gpu(gpu_id, batch_ids, voxel_spacing, tomo_algorithm,
+                                     segmentation_name, segmentation_user_id,
+                                     segmentation_session_id)
+                    batch_queue.task_done()
+                except queue.Empty:
+                    break
+                except Exception as e:
+                    print(f"Error on GPU {gpu_id}: {e}")
+                    batch_queue.task_done()
+
+        # Start worker threads for each GPU
+        with ThreadPoolExecutor(max_workers=self.num_gpus) as executor:
+            futures = [executor.submit(worker, gpu_id) for gpu_id in range(self.num_gpus)]
+            for future in futures:
+                future.result()
+
+        print('Multi-GPU predictions complete!')
+
+    def batch_predict(self,
+                      num_tomos_per_batch: int = 15,
+                      runIDs: Optional[List[str]] = None,
+                      voxel_spacing: float = 10,
+                      tomo_algorithm: str = 'denoised',
+                      segmentation_name: str = 'prediction',
+                      segmentation_user_id: str = 'octopi',
+                      segmentation_session_id: str = '0'):
+        """Smart batch predict: uses multi-GPU if available, otherwise single GPU."""
+
+        if self.num_gpus > 1:
+            print("Using multi-GPU inference")
+            self.multigpu_batch_predict(
+                num_tomos_per_batch=num_tomos_per_batch,
+                runIDs=runIDs,
+                voxel_spacing=voxel_spacing,
+                tomo_algorithm=tomo_algorithm,
+                segmentation_name=segmentation_name,
+                segmentation_user_id=segmentation_user_id,
+                segmentation_session_id=segmentation_session_id
+            )
+        else:
+            print("Using single GPU inference")
+            super().batch_predict(
+                num_tomos_per_batch=num_tomos_per_batch,
+                runIDs=runIDs,
+                voxel_spacing=voxel_spacing,
+                tomo_algorithm=tomo_algorithm,
+                segmentation_name=segmentation_name,
+                segmentation_user_id=segmentation_user_id,
+                segmentation_session_id=segmentation_session_id
+            )
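
The new MultiGPUPredictor subclasses Predictor, keeps one copy of each model per GPU, and dispatches batches of runs to worker threads. A usage sketch inferred from the code above (paths are hypothetical placeholders); per the code, batch_predict falls back to the single-GPU parent implementation when fewer than two GPUs are visible:

    from octopi.pytorch.segmentation_multigpu import MultiGPUPredictor

    # Hypothetical inputs - same constructor arguments as the base Predictor, minus device.
    predictor = MultiGPUPredictor(
        config='copick_config.json',
        model_config='unet.yaml',
        model_weights='unet.pth',
        apply_tta=True,
    )

    # Distributes batches of runs across all visible GPUs via worker threads,
    # or transparently falls back to the parent single-GPU batch_predict.
    predictor.batch_predict(
        num_tomos_per_batch=15,
        voxel_spacing=10,
        tomo_algorithm='denoised',
    )
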
octopi/pytorch/trainer.py CHANGED
@@ -1,6 +1,6 @@
-from octopi import visualization_tools as viz
+from octopi.utils import visualization_tools as viz
 from monai.inferers import sliding_window_inference
-from octopi import stopping_criteria
+from octopi.utils import stopping_criteria
 from monai.transforms import AsDiscrete
 from monai.data import decollate_batch
 import torch, os, mlflow, re
octopi/utils/__init__.py ADDED
File without changes
octopi/utils/config.py ADDED
@@ -0,0 +1,57 @@
+"""
+Configuration utilities for MLflow setup and reproducibility.
+"""
+
+from dotenv import load_dotenv
+import torch, numpy as np
+import os, random
+import octopi
+
+
+def mlflow_setup():
+    """
+    Set up MLflow configuration from environment variables.
+    """
+    module_root = os.path.dirname(octopi.__file__)
+    dotenv_path = module_root.replace('src/octopi','') + '.env'
+    load_dotenv(dotenv_path=dotenv_path)
+
+    # MLflow setup
+    username = os.getenv('MLFLOW_TRACKING_USERNAME')
+    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
+    if not password or not username:
+        print("Password not found in environment, loading from .env file...")
+        load_dotenv()  # Loads environment variables from a .env file
+        username = os.getenv('MLFLOW_TRACKING_USERNAME')
+        password = os.getenv('MLFLOW_TRACKING_PASSWORD')
+
+    # Check again after loading .env file
+    if not password:
+        raise ValueError("Password is not set in environment variables or .env file!")
+    else:
+        print("Password loaded successfully")
+    os.environ['MLFLOW_TRACKING_USERNAME'] = username
+    os.environ['MLFLOW_TRACKING_PASSWORD'] = password
+
+    return os.getenv('MLFLOW_TRACKING_URI')
+
+
+def set_seed(seed):
+    """
+    Set random seeds for reproducibility across Python, NumPy, and PyTorch.
+    """
+    # Set the seed for Python's random module
+    random.seed(seed)
+
+    # Set the seed for NumPy
+    np.random.seed(seed)
+
+    # Set the seed for PyTorch (both CPU and GPU)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)  # If using multi-GPU
+
+    # Ensure reproducibility of operations by disabling certain optimizations
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
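
A short sketch of how the new utilities might be called, based on the functions added above; the seed value and MLflow environment variables are placeholders:

    from octopi.utils import config

    # Seed Python's random module, NumPy, and PyTorch (CPU and CUDA),
    # and force deterministic cuDNN behavior.
    config.set_seed(42)

    # Reads MLFLOW_TRACKING_USERNAME / MLFLOW_TRACKING_PASSWORD from the environment
    # (or a .env file) and returns MLFLOW_TRACKING_URI; raises if no password is found.
    tracking_uri = config.mlflow_setup()
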