octopi 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. octopi/__init__.py +7 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +83 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +458 -0
  7. octopi/datasets/io.py +200 -0
  8. octopi/datasets/mixup.py +49 -0
  9. octopi/datasets/multi_config_generator.py +252 -0
  10. octopi/entry_points/__init__.py +0 -0
  11. octopi/entry_points/common.py +119 -0
  12. octopi/entry_points/create_slurm_submission.py +251 -0
  13. octopi/entry_points/groups.py +152 -0
  14. octopi/entry_points/run_create_targets.py +234 -0
  15. octopi/entry_points/run_evaluate.py +99 -0
  16. octopi/entry_points/run_extract_mb_picks.py +191 -0
  17. octopi/entry_points/run_extract_midpoint.py +143 -0
  18. octopi/entry_points/run_localize.py +176 -0
  19. octopi/entry_points/run_optuna.py +161 -0
  20. octopi/entry_points/run_segment.py +154 -0
  21. octopi/entry_points/run_train.py +189 -0
  22. octopi/extract/__init__.py +0 -0
  23. octopi/extract/localize.py +217 -0
  24. octopi/extract/membranebound_extract.py +263 -0
  25. octopi/extract/midpoint_extract.py +193 -0
  26. octopi/main.py +33 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +72 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +224 -0
  37. octopi/processing/downloader.py +138 -0
  38. octopi/processing/downsample.py +125 -0
  39. octopi/processing/evaluate.py +302 -0
  40. octopi/processing/importers.py +116 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/pytorch/__init__.py +0 -0
  43. octopi/pytorch/hyper_search.py +244 -0
  44. octopi/pytorch/model_search_submitter.py +291 -0
  45. octopi/pytorch/segmentation.py +363 -0
  46. octopi/pytorch/segmentation_multigpu.py +162 -0
  47. octopi/pytorch/trainer.py +465 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/utils/__init__.py +0 -0
  52. octopi/utils/config.py +57 -0
  53. octopi/utils/io.py +215 -0
  54. octopi/utils/losses.py +86 -0
  55. octopi/utils/parsers.py +162 -0
  56. octopi/utils/progress.py +78 -0
  57. octopi/utils/stopping_criteria.py +143 -0
  58. octopi/utils/submit_slurm.py +95 -0
  59. octopi/utils/visualization_tools.py +290 -0
  60. octopi/workflows.py +262 -0
  61. octopi-1.4.0.dist-info/METADATA +119 -0
  62. octopi-1.4.0.dist-info/RECORD +65 -0
  63. octopi-1.4.0.dist-info/WHEEL +4 -0
  64. octopi-1.4.0.dist-info/entry_points.txt +3 -0
  65. octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
@@ -0,0 +1,465 @@
1
+ from octopi.utils import visualization_tools as viz
2
+ from monai.inferers import sliding_window_inference
3
+ from octopi.utils import stopping_criteria
4
+ from monai.transforms import AsDiscrete
5
+ from monai.data import decollate_batch
6
+ import torch, os, mlflow, re, optuna
7
+ import torch_ema as ema
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+
11
# Not ideal, but necessary if a class is missing from the dataset.
# NOTE: this filter is installed at import time and silences every
# UserWarning (and subclass) process-wide, not only the ones raised when a
# class is missing, so unrelated library warnings are hidden too.
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
14
+
15
class ModelTrainer:
    """Supervised trainer for segmentation models.

    Runs the epoch loop with optional EMA weight averaging, learning-rate
    warm-up and scheduling, periodic sliding-window validation, best-model
    checkpointing, early stopping, MLflow metric logging and Optuna pruning.
    """

    def __init__(self,
                 model,
                 device,
                 loss_function,
                 metrics_function,
                 optimizer,
                 use_ema: bool = True):
        """
        Args:
            model: ``torch.nn.Module`` to train.
            device: Device that input/label batches are moved to.
            loss_function: Callable ``loss_function(outputs, labels)`` returning a
                scalar tensor.
            metrics_function: Cumulative metric object. It is called with
                ``y_pred=``/``y=`` for every validation batch; ``aggregate`` and
                ``reset`` are then used once per validation round.
            optimizer: Optimizer over ``model``'s parameters.
            use_ema: If True, keep an exponential moving average (decay 0.995) of
                the parameters and use it for validation and checkpointing.
        """
        self.model = model
        self.device = device
        self.loss_function = loss_function
        self.metrics_function = metrics_function
        self.optimizer = optimizer

        # Optional external MLflow client/run used instead of the global
        # mlflow module (see set_parallel_mlflow).
        self.parallel_mlflow = False
        self.client = None
        self.trial_run_id = None

        # Default F-beta value (resolve_best_metric may override it), plus the
        # overlap and batch size used by sliding-window validation.
        self.beta = 2
        self.overlap = 0.5
        self.sw_bs = 4

        # Initialize the EMA handler for the model parameters
        self.ema_experiment = use_ema
        if self.ema_experiment:
            self.ema_handler = ema.ExponentialMovingAverage(self.model.parameters(), decay=0.995)

        # Figure and axes reused across calls to the training-history plot
        self.fig = None
        self.axs = None

    def set_parallel_mlflow(self,
                            client,
                            trial_run_id):
        """Route metric/param logging to ``client`` under run ``trial_run_id``."""
        self.parallel_mlflow = True
        self.client = client
        self.trial_run_id = trial_run_id

    def _ema_weights(self):
        """Return a context manager that swaps in the EMA weights when EMA is
        enabled, and a no-op context otherwise."""
        import contextlib

        if self.ema_experiment:
            return self.ema_handler.average_parameters()
        return contextlib.nullcontext()

    def train_update(self):
        """Run one training epoch over ``self.train_loader``.

        Returns:
            float: Mean training loss over the batches of the epoch.
        """
        step = 0
        epoch_loss = 0
        self.model.train()
        for batch_data in self.train_loader:
            step += 1
            inputs = batch_data["image"].to(self.device)  # Shape: [B, C, H, W, D]
            labels = batch_data["label"].to(self.device)  # Shape: [B, C, H, W, D]
            self.optimizer.zero_grad()
            outputs = self.model(inputs)  # Output shape: [B, num_classes, H, W, D]
            loss = self.loss_function(outputs, labels)
            loss.backward()
            self.optimizer.step()

            # Update EMA weights after every optimizer step
            if self.ema_experiment:
                self.ema_handler.update()

            epoch_loss += loss.item()

        # Average over the number of batches seen this epoch
        epoch_loss /= step
        return epoch_loss

    def validate_update(self):
        """Run sliding-window validation and accumulate metrics.

        Returns:
            list: The result of ``metrics_function.aggregate(reduction='mean_batch')``
            with the mean validation loss appended. ``my_log_metrics`` expects
            this list to hold recall, precision, and f1 (per class) followed by
            the loss.
        """
        self.model.eval()
        val_loss = 0
        # ROI for sliding-window inference: at least 128, or the crop size if larger.
        roi = max(128, self.crop_size)
        with torch.no_grad():
            for val_data in self.val_loader:
                val_inputs = val_data["image"].to(self.device)
                val_labels = val_data["label"].to(self.device)  # labels live on self.device too

                val_outputs = sliding_window_inference(
                    inputs=val_inputs,
                    roi_size=(roi, roi, roi),
                    sw_batch_size=self.sw_bs,
                    predictor=self.model,
                    overlap=self.overlap,
                    sw_device=self.device,
                    device=self.device
                )

                # Free the input volume as early as possible
                del val_inputs
                torch.cuda.empty_cache()

                loss = self.loss_function(val_outputs, val_labels)
                val_loss += loss.item()

                # Discretize predictions (argmax + one-hot) and one-hot the labels
                metric_val_outputs = [self.post_pred(i) for i in decollate_batch(val_outputs)]
                metric_val_labels = [self.post_label(i) for i in decollate_batch(val_labels)]

                # Accumulate metrics; aggregated after the loop
                self.metrics_function(y_pred=metric_val_outputs, y=metric_val_labels)

                del val_labels, val_outputs, metric_val_outputs, metric_val_labels
                torch.cuda.empty_cache()

        # Contains recall, precision, and f1 for each class
        metric_values = self.metrics_function.aggregate(reduction='mean_batch')

        # Average validation loss over the batches is appended as the last entry
        val_loss /= len(self.val_loader)
        metric_values.append(val_loss)

        return metric_values

    def train(
        self,
        data_load_gen,
        model_save_path: str = 'results',
        my_num_samples: int = 15,
        crop_size: int = 96,
        max_epochs: int = 100,
        val_interval: int = 15,
        lr_scheduler_type: str = 'cosine',
        best_metric: str = 'avg_f1',
        use_mlflow: bool = False,
        verbose: bool = False,
        # String annotation: evaluated lazily, so defining the class does not
        # need optuna at def time.
        trial: "optuna.trial.Trial" = None
    ):
        """Train the model and return the metric history.

        Args:
            data_load_gen: Data generator exposing ``Nclasses``, ``input_dim``,
                ``class_names``, ``reload_frequency`` and
                ``create_train_dataloaders(crop_size=..., num_samples=...)``.
            model_save_path: Directory for ``best_model.pth`` and
                ``net_train_history.png``; created if missing. ``None`` disables
                writing files.
            my_num_samples: Samples per volume passed to the dataloader factory.
            crop_size: Training crop size; also lower-bounds the validation ROI.
            max_epochs: Number of training epochs.
            val_interval: Validate every ``val_interval`` epochs (and on the last).
            lr_scheduler_type: One of 'cosine', 'onecycle', 'reduce', 'exponential'.
            best_metric: Metric used for checkpointing/early stopping/pruning
                (see ``resolve_best_metric``).
            use_mlflow: Log metrics to MLflow (and skip the local plot).
            verbose: Print per-validation summaries.
            trial: Optional Optuna trial used for reporting and pruning.

        Returns:
            dict: ``self.results`` (metric name -> list of ``(epoch, value)``,
            plus ``best_metric`` and ``best_metric_epoch``).

        Raises:
            optuna.TrialPruned: If ``trial`` requests pruning after validation.
        """
        # Best lr scheduler options are cosine or reduce
        self.warmup_epochs = 5
        self.warmup_lr_factor = 0.1
        self.min_lr = 1e-6

        self.max_epochs = max_epochs
        self.crop_size = crop_size
        self.num_samples = my_num_samples
        self.val_interval = val_interval
        self.use_mlflow = use_mlflow

        # Create save folder if it doesn't exist
        if model_save_path is not None:
            os.makedirs(model_save_path, exist_ok=True)

        Nclass = data_load_gen.Nclasses
        self.create_results_dictionary(Nclass)

        best_metric = self.resolve_best_metric(best_metric)

        self.stopping_criteria = stopping_criteria.EarlyStoppingChecker(
            monitor_metric=best_metric, val_interval=val_interval)

        self.post_pred = AsDiscrete(argmax=True, to_onehot=Nclass)
        self.post_label = AsDiscrete(to_onehot=Nclass)

        # Produce dataloaders for the first training iteration
        self.train_loader, self.val_loader = data_load_gen.create_train_dataloaders(
            crop_size=crop_size, num_samples=my_num_samples
        )
        self.input_dim = data_load_gen.input_dim

        # Save the original learning rate; warm-up ramps towards it
        original_lr = self.optimizer.param_groups[0]['lr']
        self.base_lr = original_lr
        self.load_learning_rate_scheduler(lr_scheduler_type)

        # Start epoch 0 at the bottom of the warm-up ramp. run_scheduler only
        # adjusts the LR *after* an epoch, so without this the first epoch would
        # run at the full LR. Done after building the scheduler so schedulers
        # that record the initial LR still see original_lr.
        if self.warmup_epochs > 0:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = original_lr * self.warmup_lr_factor

        for epoch in tqdm(range(max_epochs), desc=f"Training on GPU: {self.device}", unit="epoch"):
            curr_step = epoch + 1

            # Reload dataloaders periodically
            if data_load_gen.reload_frequency > 0 and curr_step % data_load_gen.reload_frequency == 0:
                self.train_loader, self.val_loader = data_load_gen.create_train_dataloaders(
                    crop_size=crop_size, num_samples=my_num_samples
                )
                # Lower the learning rate for the warm-up period
                # NOTE(review): after warm-up, chainable schedulers continue from this
                # reduced LR rather than returning to original_lr — confirm intended.
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = original_lr * self.warmup_lr_factor

            epoch_loss = self.train_update()

            # Check for NaN in the loss
            if self.stopping_criteria.should_stop_training(epoch_loss):
                tqdm.write(f"Training stopped early due to {self.stopping_criteria.get_stopped_reason()}")
                break

            current_lr = self.optimizer.param_groups[0]['lr']
            self.my_log_metrics( metrics_dict={"loss": epoch_loss}, curr_step=curr_step )
            self.my_log_metrics( metrics_dict={"learning_rate": current_lr}, curr_step=curr_step )

            # Validation and metric logging
            if curr_step % val_interval == 0 or curr_step == max_epochs:
                if verbose:
                    tqdm.write(f"Epoch {epoch + 1}/{max_epochs}, avg_train_loss: {epoch_loss:.4f}")

                # Validate with the EMA weights when EMA is enabled
                with self._ema_weights():
                    metric_values = self.validate_update()

                self.my_log_metrics( metrics_dict=metric_values, curr_step=curr_step )

                if verbose:
                    avg_f1 = self.results['avg_f1'][-1][1]
                    avg_recall = self.results['avg_recall'][-1][1]
                    avg_precision = self.results['avg_precision'][-1][1]
                    tqdm.write(f"Epoch {epoch + 1}/{max_epochs}, avg_f1_score: {avg_f1:.4f}, avg_recall: {avg_recall:.4f}, avg_precision: {avg_precision:.4f}")

                self.metrics_function.reset()

                # Save the best model (higher value of the monitored metric is better)
                latest_value = self.results[best_metric][-1][1]
                if latest_value > self.results["best_metric"]:
                    self.results["best_metric"] = latest_value
                    self.results["best_metric_epoch"] = curr_step

                    # Checkpoint the EMA weights when EMA is enabled
                    with self._ema_weights():
                        self.save_model(model_save_path)

                # Save plot if local training call (needs a directory to write to)
                if not self.use_mlflow and model_save_path is not None:
                    self.fig, self.axs = viz.plot_training_results(
                        self.results,
                        data_load_gen.class_names,
                        save_plot=os.path.join(model_save_path, "net_train_history.png"),
                        fig=self.fig,
                        axs=self.axs)

                # Report/prune right after a validation step
                objective_val = self.results[best_metric][-1][1]  # e.g., avg_f1 or avg_fbeta
                if trial is not None:
                    trial.report(objective_val, step=curr_step)
                    if trial.should_prune():
                        raise optuna.TrialPruned()

                # After validation metrics are logged, check for early stopping
                if self.stopping_criteria.should_stop_training(epoch_loss, results=self.results, check_metrics=True):
                    tqdm.write(f"Training stopped early due to {self.stopping_criteria.get_stopped_reason()}")
                    break

            # Run the learning rate scheduler (warm-up or regular step)
            early_stop = self.run_scheduler(data_load_gen, original_lr, epoch, val_interval, lr_scheduler_type)
            if early_stop:
                break

        return self.results

    def load_learning_rate_scheduler(self, type: str = 'cosine'):
        """Create ``self.lr_scheduler`` for the given scheduler type.

        Args:
            type: 'cosine', 'onecycle' (the legacy spelling 'onecyle' is also
                accepted), 'reduce' or 'exponential'.

        Raises:
            ValueError: For an unsupported scheduler type.
        """
        if type == "cosine":
            self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=self.max_epochs, eta_min=self.min_lr )
        elif type in ("onecycle", "onecyle"):
            max_lr = 1e-3
            # run_scheduler steps once per epoch, not once per batch, so the
            # cycle must be laid out over epochs (steps_per_epoch=1).
            self.lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
                self.optimizer, max_lr=max_lr, epochs=self.max_epochs, steps_per_epoch=1 )
        elif type == "reduce":
            mode = "min"
            patience = 3
            factor = 0.5
            self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode=mode, patience=patience, factor=factor )
        elif type == 'exponential':
            gamma = 0.9
            self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
                self.optimizer, gamma=gamma)
        else:
            raise ValueError(f"Unsupported scheduler type: {type}")

    def run_scheduler(
        self,
        data_load_gen,
        original_lr: float,
        epoch: int,
        val_interval: int,
        type: str
    ):
        """Adjust the learning rate after ``epoch`` (0-based).

        During the first ``warmup_epochs`` epochs the LR ramps linearly from
        ``warmup_lr_factor * base_lr`` to ``base_lr``. After that the scheduler
        is stepped; ReduceLROnPlateau only steps on validation epochs, using the
        latest ``val_loss``.

        ``data_load_gen``, ``original_lr`` and ``type`` are kept for interface
        compatibility.

        Returns:
            bool: True if training should stop because the LR fell below
            ``min_lr`` (never for OneCycleLR, whose schedule ends near zero).
        """
        # Warm-up: warmup_lr_factor -> 100% of base_lr over warmup_epochs
        if (epoch + 1) <= self.warmup_epochs:
            scale = (epoch + 1) / self.warmup_epochs
            factor = self.warmup_lr_factor + (1.0 - self.warmup_lr_factor) * scale
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.base_lr * factor
            return False

        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # ReduceLROnPlateau.step requires a metric; only a validation epoch
            # has a fresh val_loss, so skip stepping on all other epochs.
            if (epoch + 1) % val_interval == 0:
                metric_value = self.results['val_loss'][-1][1]
                self.lr_scheduler.step(metric_value)
        else:
            self.lr_scheduler.step()

        # Check learning rate for early stopping
        current_lr = self.optimizer.param_groups[0]['lr']
        is_one_cycle = isinstance(self.lr_scheduler, torch.optim.lr_scheduler.OneCycleLR)
        if current_lr < self.min_lr and not is_one_cycle:
            print(f"Early stopping triggered at epoch {epoch + 1} as learning rate fell below {self.min_lr}.")
            return True

        return False

    def save_model(self, model_save_path: str):
        """Snapshot the current weights into ``self.model_weights`` and, if
        ``model_save_path`` is not None, write them to
        ``<model_save_path>/best_model.pth``."""
        import copy

        # state_dict() tensors share storage with the live parameters. Copy them
        # so the snapshot is not overwritten by later optimizer steps or by the
        # EMA context restoring the training weights on exit.
        self.model_weights = copy.deepcopy(self.model.state_dict())

        if model_save_path is not None:
            torch.save(self.model_weights, os.path.join(model_save_path, "best_model.pth"))

    def create_results_dictionary(self, Nclass: int):
        """Initialize ``self.results`` with empty histories for the average
        metrics and per-class metrics of classes 1..Nclass-1."""
        self.results = {
            'loss': [],
            'val_loss': [],
            'avg_f1': [],
            'avg_recall': [],
            'avg_precision': [],
            'avg_fbeta': [],
            'best_metric': -1,  # sentinel below any non-negative score
            'best_metric_epoch': -1
        }

        for i in range(Nclass-1):
            self.results[f'fbeta_class{i+1}'] = []
            self.results[f'f1_class{i+1}'] = []
            self.results[f'recall_class{i+1}'] = []
            self.results[f'precision_class{i+1}'] = []

        # Live keys view: also reflects metrics added later by my_log_metrics
        self.metric_names = self.results.keys()

    def my_log_metrics(
        self,
        metrics_dict: dict,
        curr_step: int,
    ):
        """Append metrics to ``self.results`` and forward them to MLflow.

        Args:
            metrics_dict: Either a single-entry ``{name: value}`` dict, or the list
                returned by ``validate_update`` (recall, precision, f1 per class,
                then the mean validation loss). The list form is expanded into
                per-class and average recall/precision/f1/F-beta plus
                ``val_loss``.
            curr_step: Step (epoch number) attached to every logged value.
        """
        # More than one element means the validate_update list
        if len(metrics_dict) > 1:

            # (assume metrics_dict contains recall, precision, f1, val_loss in sequence)
            recall, precision, f1s, val_loss = metrics_dict[0], metrics_dict[1], metrics_dict[2], metrics_dict[3]

            metrics_to_log = {}
            for i, (rec, prec, f1) in enumerate(zip(recall, precision, f1s)):
                metrics_to_log[f"recall_class{i+1}"] = rec.item()
                metrics_to_log[f"precision_class{i+1}"] = prec.item()
                metrics_to_log[f"f1_class{i+1}"] = f1.item()
                metrics_to_log[f"fbeta_class{i+1}"] = self.fbeta(prec, rec).item()

            metrics_to_log["avg_recall"] = recall.mean().cpu().item()
            metrics_to_log["avg_precision"] = precision.mean().cpu().item()
            metrics_to_log["avg_f1"] = f1s.mean().cpu().item()
            metrics_to_log["avg_fbeta"] = self.fbeta(precision, recall).mean().cpu().item()
            metrics_to_log['val_loss'] = val_loss

            metrics_dict = metrics_to_log

        # Record every metric locally
        for metric_name, value in metrics_dict.items():
            if metric_name not in self.results:
                self.results[metric_name] = []
            self.results[metric_name].append((curr_step, value))

        # Log to an explicit MLflow client/run, or to the active MLflow run
        if self.client is not None and self.trial_run_id is not None:
            for metric_name, value in metrics_dict.items():
                self.client.log_metric(
                    run_id=self.trial_run_id,
                    key=metric_name,
                    value=value,
                    step=curr_step,
                )
        elif self.use_mlflow:
            for metric_name, value in metrics_dict.items():
                mlflow.log_metric(metric_name, value, step=curr_step)

    def fbeta(self, precision, recall):
        """Element-wise F-beta score with ``self.beta``; 0 where precision and
        recall are both zero (denominator 0)."""
        numerator = (1 + self.beta**2) * (precision * recall)
        denominator = (self.beta**2 * precision) + recall

        # torch.where selects 0 wherever the division would be 0/0
        result = torch.where(
            denominator > 0,
            numerator / denominator,
            torch.zeros_like(precision)
        )
        return result

    def my_log_params(
        self,
        params_dict: dict,
    ):
        """Log parameters to the explicit MLflow client/run if set, else to the
        active MLflow run."""
        if self.client is not None and self.trial_run_id is not None:
            for key, value in params_dict.items():
                self.client.log_param(run_id=self.trial_run_id, key=key, value=value)
        else:
            mlflow.log_params(params_dict)

    def resolve_best_metric(self, best_metric):
        """Map a user-facing metric name to a key of ``self.results``.

        ``'fBetaX'`` maps to ``'avg_fbeta'`` and ``'fBetaX_classY'`` maps to
        ``'fbeta_classY'``; both set ``self.beta = X``. Existing metric-history
        names are returned unchanged. Anything else, including the bookkeeping
        keys ``best_metric``/``best_metric_epoch``, falls back to ``'avg_f1'``.

        Example input: 'fBeta2_class3', 'fBeta1' or 'f1_class2'.
        """
        fbeta_pattern = r"^fBeta(\d+)(?:_class(\d+))?$"  # Matches fBetaX or fBetaX_classY
        match = re.match(fbeta_pattern, best_metric)

        if match:
            self.beta = int(match.group(1))  # Extract beta value
            class_part = match.group(2)
            if class_part:
                best_metric = f'fbeta_class{class_part}'  # fBeta2_class3 → fbeta_class3
            else:
                best_metric = 'avg_fbeta'  # fBeta2 → avg_fbeta

        elif best_metric in self.metric_names and best_metric not in ('best_metric', 'best_metric_epoch'):
            pass  # Already a valid metric history in the results dict

        else:
            print(f"'{best_metric}' is not a valid metric. Defaulting to 'avg_f1'.\n")
            best_metric = 'avg_f1'

        return best_metric
File without changes