octopi-1.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. octopi/__init__.py +7 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +83 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +458 -0
  7. octopi/datasets/io.py +200 -0
  8. octopi/datasets/mixup.py +49 -0
  9. octopi/datasets/multi_config_generator.py +252 -0
  10. octopi/entry_points/__init__.py +0 -0
  11. octopi/entry_points/common.py +119 -0
  12. octopi/entry_points/create_slurm_submission.py +251 -0
  13. octopi/entry_points/groups.py +152 -0
  14. octopi/entry_points/run_create_targets.py +234 -0
  15. octopi/entry_points/run_evaluate.py +99 -0
  16. octopi/entry_points/run_extract_mb_picks.py +191 -0
  17. octopi/entry_points/run_extract_midpoint.py +143 -0
  18. octopi/entry_points/run_localize.py +176 -0
  19. octopi/entry_points/run_optuna.py +161 -0
  20. octopi/entry_points/run_segment.py +154 -0
  21. octopi/entry_points/run_train.py +189 -0
  22. octopi/extract/__init__.py +0 -0
  23. octopi/extract/localize.py +217 -0
  24. octopi/extract/membranebound_extract.py +263 -0
  25. octopi/extract/midpoint_extract.py +193 -0
  26. octopi/main.py +33 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +72 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +224 -0
  37. octopi/processing/downloader.py +138 -0
  38. octopi/processing/downsample.py +125 -0
  39. octopi/processing/evaluate.py +302 -0
  40. octopi/processing/importers.py +116 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/pytorch/__init__.py +0 -0
  43. octopi/pytorch/hyper_search.py +244 -0
  44. octopi/pytorch/model_search_submitter.py +291 -0
  45. octopi/pytorch/segmentation.py +363 -0
  46. octopi/pytorch/segmentation_multigpu.py +162 -0
  47. octopi/pytorch/trainer.py +465 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/utils/__init__.py +0 -0
  52. octopi/utils/config.py +57 -0
  53. octopi/utils/io.py +215 -0
  54. octopi/utils/losses.py +86 -0
  55. octopi/utils/parsers.py +162 -0
  56. octopi/utils/progress.py +78 -0
  57. octopi/utils/stopping_criteria.py +143 -0
  58. octopi/utils/submit_slurm.py +95 -0
  59. octopi/utils/visualization_tools.py +290 -0
  60. octopi/workflows.py +262 -0
  61. octopi-1.4.0.dist-info/METADATA +119 -0
  62. octopi-1.4.0.dist-info/RECORD +65 -0
  63. octopi-1.4.0.dist-info/WHEEL +4 -0
  64. octopi-1.4.0.dist-info/entry_points.txt +3 -0
  65. octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/pytorch/segmentation.py
@@ -0,0 +1,363 @@
+ from monai.inferers import sliding_window_inference
+ from typing import List, Optional, Union
+ from octopi.datasets import io as dataio
+ import torch, copick, gc, os, pprint
+ from copick_utils.io import writers
+ from octopi.models import common
+ from octopi.utils import io
+ from tqdm import tqdm
+ import numpy as np
+
+ from monai.transforms import (
+     Compose,
+     NormalizeIntensity,
+     EnsureChannelFirst,
+ )
+
+ class Predictor:
+
+     def __init__(self,
+                  config: str,
+                  model_config: Union[str, List[str]],
+                  model_weights: Union[str, List[str]],
+                  apply_tta: bool = True,
+                  device: Optional[str] = None):
+
+         # Open the Copick Project
+         self.config = config
+         self.root = copick.from_file(config)
+
+         # Get the number of GPUs available
+         num_gpus = torch.cuda.device_count()
+         if num_gpus == 0:
+             raise RuntimeError("No GPUs available.")
+
+         # Set the device
+         if device is None:
+             self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         else:
+             self.device = device
+         print('Running Inference On: ', self.device)
+
+         # Initialize TTA if enabled
+         self.apply_tta = apply_tta
+         self.create_tta_augmentations()
+
+         # Determine if Model Soup is Enabled
+         if isinstance(model_weights, str):
+             model_weights = [model_weights]
+         self.apply_modelsoup = len(model_weights) > 1
+         self.model_weights = model_weights
+
+         # Sliding Window Inference Parameters
+         self.sw_bs = 4      # sliding window batch size
+         self.overlap = 0.5  # overlap between windows
+         self.sw = None
+
+         # Handle Single Model Config or Multiple Model Configs
+         if isinstance(model_config, str):
+             model_config = [model_config] * len(model_weights)
+         elif len(model_config) != len(model_weights):
+             raise ValueError("Number of model configs must match number of model weights.")
+         self.model_config = model_config
+
+         # Load the model(s)
+         self._load_models(model_config, model_weights)
+
+     def _load_models(self, model_config: List[str], model_weights: List[str]):
+         """Load a single model or multiple models for soup."""
+
+         self.models = []
+         for i, (config_path, weights_path) in enumerate(zip(model_config, model_weights)):
+
+             # Load the Model Config and Model Builder
+             current_modelconfig = io.load_yaml(config_path)
+             model_builder = common.get_model(current_modelconfig['model']['architecture'])
+
+             # Check if the weights file exists
+             if not os.path.exists(weights_path):
+                 raise ValueError(f"Model weights file does not exist: {weights_path}")
+
+             # Create model
+             model_builder.build_model(current_modelconfig['model'])
+             model = model_builder.model
+
+             # Load weights
+             state_dict = torch.load(weights_path, map_location=self.device, weights_only=True)
+             model.load_state_dict(state_dict)
+             model.to(self.device)
+             model.eval()
+
+             self.models.append(model)
+
+         # For backward compatibility, also set self.model to the first model
+         self.model = self.models[0]
+
+         # Set the Number of Classes and Input Dimensions - Assume All Models are the Same
+         self.Nclass = current_modelconfig['model']['num_classes']
+         self.dim_in = current_modelconfig['model']['dim_in']
+         self.input_dim = None
+
+         # Print a message if Model Soup is Enabled
+         if self.apply_modelsoup:
+             print(f'Model Soup is Enabled : {len(self.models)} models loaded for ensemble inference')
+
+     def predict(self, input_data):
+         """Run Prediction from an Input Tomogram.
+         Args:
+             input_data (torch.Tensor or np.ndarray): Input tomogram of shape [Z, Y, X]
+         Returns:
+             Predicted segmentation mask of shape [Z, Y, X]
+         """
+
+         is_numpy = False
+         if isinstance(input_data, np.ndarray):
+             is_numpy = True
+             input_data = torch.from_numpy(input_data)
+
+         # Apply transforms directly to tensor (no dictionary needed)
+         pre_transforms = Compose([
+             EnsureChannelFirst(channel_dim="no_channel"),
+             NormalizeIntensity(),
+         ])
+
+         input_data = pre_transforms(input_data)
+
+         # Add batch dimension and move to device
+         input_data = input_data.unsqueeze(0).to(self.device)
+
+         # Run inference
+         pred = self._run_inference(input_data)[0]
+
+         if is_numpy:
+             pred = pred.cpu().numpy()
+         return pred
+
+     def _run_single_model_inference(self, model, input_data):
+         """Run sliding window inference on a single model."""
+         with torch.cuda.amp.autocast():
+             return sliding_window_inference(
+                 inputs=input_data,
+                 roi_size=(self.dim_in, self.dim_in, self.dim_in),
+                 sw_batch_size=self.sw_bs,
+                 predictor=model,
+                 overlap=self.overlap,
+             )
+
+     def _apply_tta_single_model(self, model, single_sample):
+         """Apply TTA to a single model and single sample."""
+         # Initialize probability accumulator
+         acc_probs = torch.zeros(
+             (1, self.Nclass, *single_sample.shape[2:]),
+             dtype=torch.float32, device=self.device
+         )
+
+         # Process each augmentation
+         with torch.inference_mode():
+             for tta_transform, inverse_transform in zip(self.tta_transforms, self.inverse_tta_transforms):
+                 # Apply transform
+                 aug_sample = tta_transform(single_sample)
+
+                 # Run inference
+                 predictions = self._run_single_model_inference(model, aug_sample)
+
+                 # Get softmax probabilities
+                 probs = torch.softmax(predictions[0], dim=0)
+
+                 # Apply inverse transform
+                 inv_probs = inverse_transform(probs)
+
+                 # Accumulate probabilities
+                 acc_probs[0] += inv_probs
+
+                 # Clear memory
+                 del predictions, probs, inv_probs, aug_sample
+                 torch.cuda.empty_cache()
+
+         # Average accumulated probabilities
+         acc_probs = acc_probs / len(self.tta_transforms)
+
+         return acc_probs[0]  # Return shape [Nclass, Z, Y, X]
+
+     def _run_inference(self, input_data):
+         """
+         Main inference function that handles all combinations - Model Soup and/or TTA
+         """
+         # Overwrite sw_bs with sw if provided
+         if self.sw is not None:
+             self.sw_bs = self.sw
+
+         # Get the batch size (# of tomograms)
+         batch_size = input_data.shape[0]
+         results = []
+
+         # Process one sample at a time for memory efficiency
+         for sample_idx in range(batch_size):
+             single_sample = input_data[sample_idx:sample_idx+1]
+
+             # Initialize probability accumulator for this sample
+             acc_probs = torch.zeros(
+                 (self.Nclass, *single_sample.shape[2:]),
+                 dtype=torch.float32, device=self.device
+             )
+
+             # Process each model
+             with torch.inference_mode():
+                 for model in self.models:
+                     # Apply TTA with this model
+                     if self.apply_tta:
+                         model_probs = self._apply_tta_single_model(model, single_sample)
+                     # Run inference without TTA
+                     else:
+                         predictions = self._run_single_model_inference(model, single_sample)
+                         model_probs = torch.softmax(predictions[0], dim=0)
+                         del predictions
+
+                     # Accumulate probabilities from this model
+                     acc_probs += model_probs
+                     del model_probs
+                     torch.cuda.empty_cache()
+
+             # Average probabilities across models (and TTA augmentations if applied)
+             acc_probs = acc_probs / len(self.models)
+
+             # Convert to discrete prediction
+             discrete_pred = torch.argmax(acc_probs, dim=0)
+             results.append(discrete_pred)
+
+             # Clear memory
+             del acc_probs, discrete_pred
+             torch.cuda.empty_cache()
+
+         return results
+
+     def predict_on_gpu(self,
+                        runIDs: List[str],
+                        voxel_spacing: float,
+                        tomo_algorithm: str):
+
+         # Load data for the current batch
+         test_loader, test_dataset = dataio.create_predict_dataloader(
+             self.root,
+             voxel_spacing, tomo_algorithm,
+             runIDs)
+
+         # Determine Input Crop Size.
+         if self.input_dim is None:
+             self.input_dim = dataio.get_input_dimensions(test_dataset, self.dim_in)
+
+         predictions = []
+         with torch.inference_mode():
+             for data in tqdm(test_loader):
+                 tomogram = data['image'].to(self.device)
+                 data['pred'] = self._run_inference(tomogram)
+
+                 for idx in range(len(data['image'])):
+                     predictions.append(data['pred'][idx].squeeze(0).numpy(force=True))
+
+         return predictions
+
+     def batch_predict(self,
+                       num_tomos_per_batch: int = 15,
+                       runIDs: Optional[List[str]] = None,
+                       voxel_spacing: float = 10,
+                       tomo_algorithm: str = 'denoised',
+                       segmentation_name: str = 'prediction',
+                       segmentation_user_id: str = 'octopi',
+                       segmentation_session_id: str = '0'):
+         """Run inference on tomograms in batches."""
+
+         # Print and Save Inference Parameters
+         self.save_parameters(tomo_algorithm, voxel_spacing, [segmentation_name, segmentation_user_id, segmentation_session_id])
+
+         # If runIDs are not provided, load all runs
+         if runIDs is None:
+             runIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is not None]
+             skippedRunIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is None]
+
+             if skippedRunIDs:
+                 print(f"Warning: skipping runs with no voxel spacing {voxel_spacing}: {skippedRunIDs}")
+
+         # Iterate over batches of runIDs
+         for i in range(0, len(runIDs), num_tomos_per_batch):
+             # Get a batch of runIDs
+             batch_ids = runIDs[i:i + num_tomos_per_batch]
+             print('Running Inference on the Following RunIDs: ', batch_ids)
+
+             predictions = self.predict_on_gpu(batch_ids, voxel_spacing, tomo_algorithm)
+
+             # Save Predictions to Corresponding RunID
+             for ind in range(len(batch_ids)):
+                 run = self.root.get_run(batch_ids[ind])
+                 seg = predictions[ind]
+                 writers.segmentation(run, seg, segmentation_user_id, segmentation_name,
+                                      segmentation_session_id, voxel_spacing)
+
+             # After processing and saving predictions for a batch:
+             del predictions            # Remove reference to the list holding prediction arrays
+             torch.cuda.empty_cache()   # Clear unused GPU memory
+             gc.collect()               # Trigger garbage collection for CPU memory
+
+         print('✅ Predictions Complete!')
+
+     def create_tta_augmentations(self):
+         """Define TTA augmentations and inverse transforms."""
+         # Rotate around the YZ plane (dims 3,4 for input, dims 2,3 for output)
+         self.tta_transforms = [
+             lambda x: x,                                  # Identity (no augmentation)
+             lambda x: torch.rot90(x, k=1, dims=(3, 4)),   # 90° rotation
+             lambda x: torch.rot90(x, k=2, dims=(3, 4)),   # 180° rotation
+             lambda x: torch.rot90(x, k=3, dims=(3, 4)),   # 270° rotation
+         ]
+
+         # Define inverse transformations (flip back to original orientation)
+         self.inverse_tta_transforms = [
+             lambda x: x,                                  # Identity (no transformation needed)
+             lambda x: torch.rot90(x, k=-1, dims=(2, 3)),  # Inverse of 90° (i.e. -90°)
+             lambda x: torch.rot90(x, k=-2, dims=(2, 3)),  # Inverse of 180° (i.e. -180°)
+             lambda x: torch.rot90(x, k=-3, dims=(2, 3)),  # Inverse of 270° (i.e. -270°)
+         ]
+
+     def save_parameters(self,
+                         tomo_algorithm: str,
+                         voxel_size: float,
+                         seg_info: List[str]
+                         ):
+         """
+         Save inference parameters to a YAML file for record-keeping and reproducibility.
+         """
+
+         # Load the model config
+         model_config = io.load_yaml(self.model_config[0])
+
+         # Create parameters dictionary
+         params = {
+             "inputs": {
+                 "config": self.config,
+                 "tomo_alg": tomo_algorithm,
+                 "voxel_size": voxel_size
+             },
+             'model': {
+                 'configs': self.model_config,
+                 'weights': self.model_weights
+             },
+             'labels': model_config['labels'],
+             "outputs": {
+                 "seg_name": seg_info[0],
+                 "seg_user_id": seg_info[1],
+                 "seg_session_id": seg_info[2]
+             }
+         }
+
+         # Print the parameters
+         print(f"\nParameters for Inference (Segment Prediction):")
+         pprint.pprint(params); print()
+
+         # Save to YAML file
+         overlay_root = io.remove_prefix(self.root.config.overlay_root)
+         basepath = os.path.join(overlay_root, 'logs')
+         os.makedirs(basepath, exist_ok=True)
+         output_path = os.path.join(
+             basepath,
+             f'segment-{seg_info[1]}_{seg_info[2]}_{seg_info[0]}.yaml')
+         io.save_parameters_yaml(params, output_path)
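
For orientation, here is a minimal usage sketch of the `Predictor` class above. The copick config, model config YAML, and checkpoint paths are hypothetical placeholders, and the keyword values simply mirror the defaults of `batch_predict`; passing lists of configs/weights instead of single paths enables the model-soup ensembling handled in `_load_models`.

    import numpy as np
    from octopi.pytorch.segmentation import Predictor

    # Placeholder paths - substitute your own copick project, model config, and weights.
    predictor = Predictor(
        config='copick_config.json',
        model_config='model_config.yaml',
        model_weights='best_model.pth',
        apply_tta=True,   # average softmax probabilities over four 90° rotations
    )

    # Segment an in-memory tomogram of shape [Z, Y, X]; returns a label volume of the same shape.
    volume = np.random.rand(128, 256, 256).astype(np.float32)
    mask = predictor.predict(volume)

    # Or segment whole copick runs in batches and write the results back to the project.
    predictor.batch_predict(
        num_tomos_per_batch=15,
        voxel_spacing=10,
        tomo_algorithm='denoised',
        segmentation_name='prediction',
        segmentation_user_id='octopi',
        segmentation_session_id='0',
    )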
octopi/pytorch/segmentation_multigpu.py
@@ -0,0 +1,162 @@
+ from concurrent.futures import ThreadPoolExecutor
+ from octopi.pytorch.segmentation import Predictor
+ from typing import List, Union, Optional
+ from copick_utils.io import writers
+ import queue, torch
+
+ class MultiGPUPredictor(Predictor):
+
+     def __init__(self,
+                  config: str,
+                  model_config: Union[str, List[str]],
+                  model_weights: Union[str, List[str]],
+                  apply_tta: bool = True):
+
+         # Initialize parent normally
+         super().__init__(config, model_config, model_weights, apply_tta)
+
+         self.num_gpus = torch.cuda.device_count()
+         print(f"Available GPUs: {self.num_gpus}")
+
+         # Only create GPU-specific models if we have multiple GPUs
+         if self.num_gpus > 1:
+             self._create_gpu_models()
+
+     def _create_gpu_models(self):
+         """Create separate model instances for each GPU."""
+         self.gpu_models = {}
+
+         for gpu_id in range(self.num_gpus):
+             device = torch.device(f'cuda:{gpu_id}')
+             gpu_models = []
+
+             # Copy each model to this GPU
+             for model in self.models:
+                 gpu_model = type(model)()
+                 gpu_model.load_state_dict(model.state_dict())
+                 gpu_model.to(device)
+                 gpu_model.eval()
+                 gpu_models.append(gpu_model)
+
+             self.gpu_models[gpu_id] = gpu_models
+             print(f"Models loaded on GPU {gpu_id}")
+
+     def _run_on_gpu(self, gpu_id: int, batch_ids: List[str],
+                     voxel_spacing: float, tomo_algorithm: str,
+                     segmentation_name: str, segmentation_user_id: str,
+                     segmentation_session_id: str):
+         """Run inference on a specific GPU for a batch of runs."""
+         device = torch.device(f'cuda:{gpu_id}')
+
+         # Temporarily switch to this GPU's models and device
+         original_device = self.device
+         original_models = self.models
+
+         self.device = device
+         self.models = self.gpu_models[gpu_id]
+
+         try:
+             print(f"GPU {gpu_id} processing runs: {batch_ids}")
+
+             # Run prediction using parent class method
+             predictions = self.predict_on_gpu(batch_ids, voxel_spacing, tomo_algorithm)
+
+             # Save predictions
+             for idx, run_id in enumerate(batch_ids):
+                 run = self.root.get_run(run_id)
+                 seg = predictions[idx]
+                 writers.segmentation(run, seg, segmentation_user_id,
+                                      segmentation_name, segmentation_session_id,
+                                      voxel_spacing)
+
+             # Clean up
+             del predictions
+             torch.cuda.empty_cache()
+
+         finally:
+             # Restore original settings
+             self.device = original_device
+             self.models = original_models
+
+     def multigpu_batch_predict(self,
+                                num_tomos_per_batch: int = 15,
+                                runIDs: Optional[List[str]] = None,
+                                voxel_spacing: float = 10,
+                                tomo_algorithm: str = 'denoised',
+                                segmentation_name: str = 'prediction',
+                                segmentation_user_id: str = 'octopi',
+                                segmentation_session_id: str = '0'):
+         """Run inference across multiple GPUs using threading."""
+
+         # Get runIDs if not provided
+         if runIDs is None:
+             runIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is not None]
+             skippedRunIDs = [run.name for run in self.root.runs if run.get_voxel_spacing(voxel_spacing) is None]
+             if skippedRunIDs:
+                 print(f"Warning: skipping runs with no voxel spacing {voxel_spacing}: {skippedRunIDs}")
+
+         # Split runIDs into batches
+         batches = [runIDs[i:i + num_tomos_per_batch]
+                    for i in range(0, len(runIDs), num_tomos_per_batch)]
+
+         print(f"Processing {len(batches)} batches across {self.num_gpus} GPUs")
+
+         # Create work queue
+         batch_queue = queue.Queue()
+         for batch in batches:
+             batch_queue.put(batch)
+
+         def worker(gpu_id):
+             while True:
+                 try:
+                     batch_ids = batch_queue.get_nowait()
+                     self._run_on_gpu(gpu_id, batch_ids, voxel_spacing, tomo_algorithm,
+                                      segmentation_name, segmentation_user_id,
+                                      segmentation_session_id)
+                     batch_queue.task_done()
+                 except queue.Empty:
+                     break
+                 except Exception as e:
+                     print(f"Error on GPU {gpu_id}: {e}")
+                     batch_queue.task_done()
+
+         # Start worker threads for each GPU
+         with ThreadPoolExecutor(max_workers=self.num_gpus) as executor:
+             futures = [executor.submit(worker, gpu_id) for gpu_id in range(self.num_gpus)]
+             for future in futures:
+                 future.result()
+
+         print('Multi-GPU predictions complete!')
+
+     def batch_predict(self,
+                       num_tomos_per_batch: int = 15,
+                       runIDs: Optional[List[str]] = None,
+                       voxel_spacing: float = 10,
+                       tomo_algorithm: str = 'denoised',
+                       segmentation_name: str = 'prediction',
+                       segmentation_user_id: str = 'octopi',
+                       segmentation_session_id: str = '0'):
+         """Smart batch predict: uses multi-GPU if available, otherwise single GPU."""
+
+         if self.num_gpus > 1:
+             print("Using multi-GPU inference")
+             self.multigpu_batch_predict(
+                 num_tomos_per_batch=num_tomos_per_batch,
+                 runIDs=runIDs,
+                 voxel_spacing=voxel_spacing,
+                 tomo_algorithm=tomo_algorithm,
+                 segmentation_name=segmentation_name,
+                 segmentation_user_id=segmentation_user_id,
+                 segmentation_session_id=segmentation_session_id
+             )
+         else:
+             print("Using single GPU inference")
+             super().batch_predict(
+                 num_tomos_per_batch=num_tomos_per_batch,
+                 runIDs=runIDs,
+                 voxel_spacing=voxel_spacing,
+                 tomo_algorithm=tomo_algorithm,
+                 segmentation_name=segmentation_name,
+                 segmentation_user_id=segmentation_user_id,
+                 segmentation_session_id=segmentation_session_id
+             )
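
A corresponding sketch for `MultiGPUPredictor`, again with placeholder paths. Its `batch_predict` dispatches to `multigpu_batch_predict` when more than one GPU is visible and otherwise falls back to the single-GPU path of the parent class, so the call site stays the same either way.

    from octopi.pytorch.segmentation_multigpu import MultiGPUPredictor

    # Placeholder paths; two checkpoints trigger the model-soup ensembling of the Predictor base class.
    predictor = MultiGPUPredictor(
        config='copick_config.json',
        model_config=['model_a.yaml', 'model_b.yaml'],
        model_weights=['model_a.pth', 'model_b.pth'],
        apply_tta=False,
    )

    # Batches of runs are pulled from a shared queue by one worker thread per visible GPU.
    predictor.batch_predict(
        num_tomos_per_batch=15,
        voxel_spacing=10,
        tomo_algorithm='denoised',
    )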