octopi-1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of octopi might be problematic; see the registry's advisory page for more details.

Files changed (59)
  1. octopi/__init__.py +0 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +84 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +429 -0
  7. octopi/datasets/mixup.py +49 -0
  8. octopi/datasets/multi_config_generator.py +253 -0
  9. octopi/entry_points/__init__.py +0 -0
  10. octopi/entry_points/common.py +80 -0
  11. octopi/entry_points/create_slurm_submission.py +243 -0
  12. octopi/entry_points/run_create_targets.py +281 -0
  13. octopi/entry_points/run_evaluate.py +65 -0
  14. octopi/entry_points/run_extract_mb_picks.py +141 -0
  15. octopi/entry_points/run_extract_midpoint.py +143 -0
  16. octopi/entry_points/run_localize.py +222 -0
  17. octopi/entry_points/run_optuna.py +139 -0
  18. octopi/entry_points/run_segment_predict.py +166 -0
  19. octopi/entry_points/run_train.py +201 -0
  20. octopi/extract/__init__.py +0 -0
  21. octopi/extract/localize.py +254 -0
  22. octopi/extract/membranebound_extract.py +262 -0
  23. octopi/extract/midpoint_extract.py +193 -0
  24. octopi/io.py +457 -0
  25. octopi/losses.py +86 -0
  26. octopi/main.py +101 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +62 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +106 -0
  37. octopi/processing/downsample.py +129 -0
  38. octopi/processing/evaluate.py +289 -0
  39. octopi/processing/importers.py +213 -0
  40. octopi/processing/my_metrics.py +26 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/processing/writers.py +102 -0
  43. octopi/pytorch/__init__.py +0 -0
  44. octopi/pytorch/hyper_search.py +243 -0
  45. octopi/pytorch/model_search_submitter.py +290 -0
  46. octopi/pytorch/segmentation.py +317 -0
  47. octopi/pytorch/trainer.py +438 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/stopping_criteria.py +143 -0
  52. octopi/submit_slurm.py +95 -0
  53. octopi/utils.py +238 -0
  54. octopi/visualization_tools.py +201 -0
  55. octopi-1.0.dist-info/LICENSE +41 -0
  56. octopi-1.0.dist-info/METADATA +209 -0
  57. octopi-1.0.dist-info/RECORD +59 -0
  58. octopi-1.0.dist-info/WHEEL +4 -0
  59. octopi-1.0.dist-info/entry_points.txt +4 -0
@@ -0,0 +1,438 @@
1
+ from octopi import visualization_tools as viz
2
+ from monai.inferers import sliding_window_inference
3
+ from octopi import stopping_criteria
4
+ from monai.transforms import AsDiscrete
5
+ from monai.data import decollate_batch
6
+ import torch, os, mlflow, re
7
+ import torch_ema as ema
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+
11
# Silence UserWarnings for the whole process at import time.
# Not ideal: this is needed when a class is missing from the dataset, but it
# also hides every other UserWarning raised anywhere after this module loads.
import warnings
warnings.filterwarnings(action="ignore", category=UserWarning)
14
+
15
class ModelTrainer:
    """Train, validate and checkpoint a segmentation model.

    The trainer owns the epoch loop (``train``), one optimisation pass over the
    training loader (``train_update``), sliding-window validation
    (``validate_update``), learning-rate scheduling, and metric bookkeeping in
    ``self.results``. When ``use_ema`` is true, an exponential moving average of
    the parameters is maintained and used for validation and for the saved
    checkpoint.
    """

    # Accepted spellings for the one-cycle scheduler. 'onecyle' is the historical
    # misspelling that was the only accepted value; it is kept so existing
    # configurations keep working.
    _ONECYCLE_NAMES = ("onecycle", "onecyle")

    def __init__(self,
                 model,
                 device,
                 loss_function,
                 metrics_function,
                 optimizer,
                 use_ema: bool = True):
        """
        Args:
            model: The ``torch.nn.Module`` to train.
            device: Device each batch is moved to before the forward pass.
            loss_function: Callable ``(outputs, labels) -> loss tensor``.
            metrics_function: Metric object called as ``(y_pred=..., y=...)``
                with ``aggregate(reduction=...)`` and ``reset()`` methods
                (presumably a MONAI metric -- confirm against callers).
            optimizer: Optimizer over ``model``'s parameters.
            use_ema: Keep an EMA (decay 0.99) of the parameters for validation
                and checkpointing.
        """
        self.model = model
        self.device = device
        self.loss_function = loss_function
        self.metrics_function = metrics_function
        self.optimizer = optimizer

        # Optional per-trial MLflow client; configured via set_parallel_mlflow.
        self.parallel_mlflow = False
        self.client = None
        self.trial_run_id = None

        # Default F-Beta value; resolve_best_metric overrides it for 'fBetaN' metrics.
        self.beta = 2

        # Exponential moving average of the model parameters
        self.ema_experiment = use_ema
        if self.ema_experiment:
            self.ema_handler = ema.ExponentialMovingAverage(self.model.parameters(), decay=0.99)

        # Figure and axes reused across calls to the training-history plot
        self.fig = None
        self.axs = None

    def set_parallel_mlflow(self,
                            client,
                            trial_run_id):
        """Route metric/param logging through ``client`` for run ``trial_run_id``."""
        self.parallel_mlflow = True
        self.client = client
        self.trial_run_id = trial_run_id

    def train_update(self):
        """Run one training epoch over ``self.train_loader``.

        Each batch is a mapping with ``"image"`` and ``"label"`` tensors. The
        EMA (if enabled) is updated after every optimizer step.

        Returns:
            float: Mean of the per-batch loss values.

        Raises:
            ValueError: If ``self.train_loader`` yields no batches.
        """
        step = 0
        epoch_loss = 0
        self.model.train()
        for batch_data in self.train_loader:
            step += 1
            inputs = batch_data["image"].to(self.device)  # Shape: [B, C, H, W, D]
            labels = batch_data["label"].to(self.device)  # Shape: [B, C, H, W, D]
            self.optimizer.zero_grad()
            outputs = self.model(inputs)  # Output shape: [B, num_classes, H, W, D]
            loss = self.loss_function(outputs, labels)
            loss.backward()
            self.optimizer.step()

            # Update EMA weights
            if self.ema_experiment:
                self.ema_handler.update()

            # Update running epoch loss
            epoch_loss += loss.item()

        if step == 0:
            raise ValueError("train_loader yielded no batches; cannot compute the epoch loss.")

        return epoch_loss / step

    def validate_update(self):
        """Run sliding-window validation over ``self.val_loader``.

        Metric state is accumulated in ``self.metrics_function``; the caller is
        responsible for calling ``reset()`` afterwards.

        Returns:
            The value of ``metrics_function.aggregate(reduction='mean_batch')``
            with the mean validation loss appended as the last element.
            ``my_log_metrics`` expects this to be ``[recall, precision, f1, val_loss]``.
        """
        self.model.eval()
        val_loss = 0
        with torch.no_grad():
            for val_data in self.val_loader:
                val_inputs = val_data["image"].to(self.device)
                val_labels = val_data["label"].to(self.device)

                # NOTE: ROI is fixed at 144^3; self.input_dim (set in train) is
                # an alternative the original author considered (128/144/160).
                val_outputs = sliding_window_inference(
                    inputs=val_inputs,
                    roi_size=(144, 144, 144),
                    sw_batch_size=4,
                    predictor=self.model,
                    overlap=0.5,
                    device=self.device
                )

                # Loss is computed on the raw (pre-argmax) sliding-window outputs
                loss = self.loss_function(val_outputs, val_labels)
                val_loss += loss.item()

                # Post-process per sample (argmax/one-hot) before computing metrics
                metric_val_outputs = [self.post_pred(i) for i in decollate_batch(val_outputs)]
                metric_val_labels = [self.post_label(i) for i in decollate_batch(val_labels)]

                # Accumulate metric state for this batch
                self.metrics_function(y_pred=metric_val_outputs, y=metric_val_labels)

            # Contains recall, precision, and f1 for each class
            metric_values = self.metrics_function.aggregate(reduction='mean_batch')

        # Average validation loss appended after the aggregated metrics
        val_loss /= len(self.val_loader)
        metric_values.append(val_loss)

        return metric_values

    def train(
        self,
        data_load_gen,
        model_save_path: str = 'results',
        my_num_samples: int = 15,
        crop_size: int = 96,
        max_epochs: int = 100,
        val_interval: int = 15,
        lr_scheduler_type: str = 'cosine',
        best_metric: str = 'avg_f1',
        use_mlflow: bool = False,
        verbose: bool = False
    ):
        """Run the full training loop and return the accumulated results.

        Args:
            data_load_gen: Data generator. This method reads ``Nclasses``,
                ``input_dim`` and ``reload_frequency`` from it and calls
                ``create_train_dataloaders(crop_size=..., num_samples=...)``.
            model_save_path: Directory for ``best_model.pth`` and the training
                plot; created if missing. ``None`` disables writing to disk.
            my_num_samples: ``num_samples`` forwarded to the dataloader factory.
            crop_size: ``crop_size`` forwarded to the dataloader factory
                (on the initial call and on every periodic reload).
            max_epochs: Number of training epochs.
            val_interval: Validate every ``val_interval`` epochs and on the last epoch.
            lr_scheduler_type: 'cosine', 'onecycle', 'reduce' or 'exponential'.
            best_metric: Metric used to select the checkpoint; see ``resolve_best_metric``.
            use_mlflow: Log metrics to MLflow and skip the local plot.
            verbose: Print per-validation summaries via ``tqdm.write``.

        Returns:
            dict: ``self.results`` -- per-metric lists of ``(epoch, value)`` plus
            the ``best_metric`` / ``best_metric_epoch`` scalars.
        """
        # Warm-up / scheduler constants (best lr scheduler options are cosine or reduce)
        self.warmup_epochs = 5
        self.warmup_lr_factor = 0.1
        self.min_lr = 1e-6

        self.max_epochs = max_epochs
        self.crop_size = crop_size
        self.num_samples = my_num_samples
        self.val_interval = val_interval
        self.use_mlflow = use_mlflow

        # Create Save Folder if It Doesn't Exist
        if model_save_path is not None:
            os.makedirs(model_save_path, exist_ok=True)

        Nclass = data_load_gen.Nclasses
        self.create_results_dictionary(Nclass)

        # Map the user-facing metric name onto a key of self.results
        best_metric = self.resolve_best_metric(best_metric)

        # Stopping Criteria
        self.stopping_criteria = stopping_criteria.EarlyStoppingChecker(
            monitor_metric=best_metric, val_interval=val_interval)

        # Metric post-processing: argmax + one-hot for predictions, one-hot for labels
        self.post_pred = AsDiscrete(argmax=True, to_onehot=Nclass)
        self.post_label = AsDiscrete(to_onehot=Nclass)

        # Produce Dataloaders for the First Training Iteration
        self.train_loader, self.val_loader = data_load_gen.create_train_dataloaders(
            crop_size=crop_size, num_samples=my_num_samples)
        self.input_dim = data_load_gen.input_dim

        # Save the original learning rate; warm-up factors are applied relative to it
        original_lr = self.optimizer.param_groups[0]['lr']
        self.load_learning_rate_scheduler(lr_scheduler_type)

        for epoch in tqdm(range(max_epochs), desc="Training Progress", unit="epoch"):
            curr_epoch = epoch + 1

            # Reload dataloaders periodically
            if data_load_gen.reload_frequency > 0 and curr_epoch % data_load_gen.reload_frequency == 0:
                # Pass crop_size again so resampled crops match the configured size
                self.train_loader, self.val_loader = data_load_gen.create_train_dataloaders(
                    crop_size=crop_size, num_samples=my_num_samples)
                # Lower the learning rate for the warm-up period
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = original_lr * self.warmup_lr_factor

            epoch_loss = self.train_update()

            # Loss-only stopping check (e.g. NaN loss)
            if self.stopping_criteria.should_stop_training(epoch_loss):
                tqdm.write(f"Training stopped early due to {self.stopping_criteria.get_stopped_reason()}")
                break

            current_lr = self.optimizer.param_groups[0]['lr']
            self.my_log_metrics(metrics_dict={"loss": epoch_loss}, curr_step=curr_epoch)
            self.my_log_metrics(metrics_dict={"learning_rate": current_lr}, curr_step=curr_epoch)

            # Validation and metric logging
            if curr_epoch % val_interval == 0 or curr_epoch == max_epochs:
                if verbose:
                    tqdm.write(f"Epoch {epoch + 1}/{max_epochs}, avg_train_loss: {epoch_loss:.4f}")

                # Validate the Model with or without EMA weights swapped in
                if self.ema_experiment:
                    with self.ema_handler.average_parameters():
                        metric_values = self.validate_update()
                else:
                    metric_values = self.validate_update()

                self.my_log_metrics(metrics_dict=metric_values, curr_step=curr_epoch)

                if verbose:
                    avg_f1 = self.results['avg_f1'][-1][1]
                    avg_recall = self.results['avg_recall'][-1][1]
                    avg_precision = self.results['avg_precision'][-1][1]
                    tqdm.write(f"Epoch {epoch + 1}/{max_epochs}, avg_f1_score: {avg_f1:.4f}, avg_recall: {avg_recall:.4f}, avg_precision: {avg_precision:.4f}")

                # Reset metrics function for the next validation round
                self.metrics_function.reset()

                # Save the best model. NOTE(review): a higher value is treated as
                # better, so loss-type metrics would select the worst checkpoint.
                latest_value = self.results[best_metric][-1][1]
                if latest_value > self.results["best_metric"]:
                    self.results["best_metric"] = latest_value
                    self.results["best_metric_epoch"] = curr_epoch

                    # Read Model Weights (EMA if enabled) and Save
                    if self.ema_experiment:
                        with self.ema_handler.average_parameters():
                            self.save_model(model_save_path)
                    else:
                        self.save_model(model_save_path)

                # Save plot if local training call and there is a directory to save into
                if not self.use_mlflow and model_save_path is not None:
                    self.fig, self.axs = viz.plot_training_results(
                        self.results,
                        save_plot=os.path.join(model_save_path, "net_train_history.png"),
                        fig=self.fig,
                        axs=self.axs)

                # After Validation Metrics are Logged, Check for Early Stopping
                if self.stopping_criteria.should_stop_training(epoch_loss, results=self.results, check_metrics=True):
                    tqdm.write(f"Training stopped early due to {self.stopping_criteria.get_stopped_reason()}")
                    break

            # Run the learning rate scheduler (every epoch)
            early_stop = self.run_scheduler(data_load_gen, original_lr, epoch, val_interval, lr_scheduler_type)
            if early_stop:
                break

        return self.results

    def load_learning_rate_scheduler(self, type: str = 'cosine'):
        """Create ``self.lr_scheduler`` for the given scheduler name.

        Reads ``self.max_epochs`` and ``self.min_lr``. The one-cycle scheduler
        also reads ``len(self.train_loader)``.

        Args:
            type: 'cosine', 'onecycle' (or the legacy 'onecyle'), 'reduce' or
                'exponential'.

        Raises:
            ValueError: For an unsupported scheduler name.
        """
        if type == "cosine":
            self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=self.max_epochs, eta_min=self.min_lr)
        elif type in self._ONECYCLE_NAMES:
            # NOTE(review): OneCycleLR is sized for epochs * steps_per_epoch
            # steps, but run_scheduler steps it once per epoch, so only the start
            # of the cycle is traversed -- confirm intended.
            max_lr = 1e-3
            steps_per_epoch = len(self.train_loader)
            self.lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
                self.optimizer, max_lr=max_lr, epochs=self.max_epochs, steps_per_epoch=steps_per_epoch)
        elif type == "reduce":
            mode = "min"
            patience = 3
            factor = 0.5
            self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode=mode, patience=patience, factor=factor)
        elif type == 'exponential':
            gamma = 0.9
            self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
                self.optimizer, gamma=gamma)
        else:
            raise ValueError(f"Unsupported scheduler type: {type}")

    def run_scheduler(
        self,
        data_load_gen,
        original_lr: float,
        epoch: int,
        val_interval: int,
        type: str
    ):
        """Apply warm-up or step the LR scheduler at the end of an epoch.

        During the first ``self.warmup_epochs`` epochs, the learning rate is
        pinned to ``original_lr * self.warmup_lr_factor`` and the scheduler is
        not stepped. Afterwards, the scheduler is stepped once. A
        ``ReduceLROnPlateau`` scheduler is stepped only on validation epochs,
        using the latest ``val_loss``.

        Args:
            data_load_gen: Unused; kept for interface compatibility.
            original_lr: Learning rate the optimizer started with.
            epoch: Zero-based epoch index.
            val_interval: Validation period in epochs.
            type: Scheduler name passed to ``load_learning_rate_scheduler``.

        Returns:
            bool: True if training should stop because the learning rate fell
            below ``self.min_lr`` (never for the one-cycle scheduler).
        """
        curr_epoch = epoch + 1

        # Apply warm-up logic.
        # NOTE(review): the warm-up LR is never explicitly restored; schedulers
        # that derive the next LR from the current one continue from the reduced
        # value -- confirm this is intended.
        if curr_epoch <= self.warmup_epochs:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = original_lr * self.warmup_lr_factor
            return False  # Continue training

        scheduler = self.lr_scheduler
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # Plateau scheduling needs a monitored value, which only exists on
            # validation epochs; ReduceLROnPlateau.step() requires it.
            if curr_epoch % val_interval == 0:
                scheduler.step(self.results['val_loss'][-1][1])
        else:
            scheduler.step()

        # Check learning rate for early stopping
        current_lr = self.optimizer.param_groups[0]['lr']
        if current_lr < self.min_lr and type not in self._ONECYCLE_NAMES:
            print(f"Early stopping triggered at epoch {epoch + 1} as learning rate fell below {self.min_lr}.")
            return True  # Indicate early stopping

        return False  # Continue training

    def save_model(self, model_save_path: str):
        """Snapshot the current weights into ``self.model_weights`` and save them.

        The snapshot holds detached copies. ``state_dict()`` returns references
        to the live parameters, which keep changing as training continues. When
        called inside an EMA ``average_parameters()`` context, they are also
        swapped back afterwards. An uncopied dict would therefore not hold the
        saved weights. ``best_model.pth`` is written under ``model_save_path``
        unless it is None.
        """
        self.model_weights = {
            name: value.detach().clone() if isinstance(value, torch.Tensor) else value
            for name, value in self.model.state_dict().items()
        }

        # Save Model Weights to *.pth file
        if model_save_path is not None:
            torch.save(self.model_weights, os.path.join(model_save_path, "best_model.pth"))

    def create_results_dictionary(self, Nclass: int):
        """Initialise ``self.results`` with empty metric histories.

        Per-class entries are created for classes ``1 .. Nclass-1``; class 0
        presumably is background and excluded by the metric -- confirm.
        """
        self.results = {
            'loss': [],
            'val_loss': [],
            'avg_f1': [],
            'avg_recall': [],
            'avg_precision': [],
            'avg_fbeta': [],
            'best_metric': -1,  # Best value seen so far of the monitored metric
            'best_metric_epoch': -1
        }

        for i in range(Nclass - 1):
            self.results[f'fbeta_class{i+1}'] = []
            self.results[f'f1_class{i+1}'] = []
            self.results[f'recall_class{i+1}'] = []
            self.results[f'precision_class{i+1}'] = []

        # Live view of the result keys, used by resolve_best_metric
        self.metric_names = self.results.keys()

    def my_log_metrics(
        self,
        metrics_dict: dict,
        curr_step: int,
    ):
        """Record metrics in ``self.results`` and forward them to MLflow.

        Args:
            metrics_dict: Either a mapping of metric name to scalar, or the
                sequence returned by ``validate_update``
                (``[recall, precision, f1, val_loss]``, with per-class tensors
                for the first three). The sequence is expanded into per-class
                and averaged recall/precision/f1/fbeta entries plus ``val_loss``.
            curr_step: Step (epoch) recorded with every value.
        """
        if not isinstance(metrics_dict, dict):
            # Sequence from validate_update: recall, precision, f1, val_loss in order
            recall, precision, f1s, val_loss = metrics_dict[0], metrics_dict[1], metrics_dict[2], metrics_dict[3]

            metrics_to_log = {}
            for i, (rec, prec, f1) in enumerate(zip(recall, precision, f1s)):
                metrics_to_log[f"recall_class{i+1}"] = rec.item()
                metrics_to_log[f"precision_class{i+1}"] = prec.item()
                metrics_to_log[f"f1_class{i+1}"] = f1.item()
                metrics_to_log[f"fbeta_class{i+1}"] = self.fbeta(prec, rec).item()

            # Averages over classes
            metrics_to_log["avg_recall"] = recall.mean().cpu().item()
            metrics_to_log["avg_precision"] = precision.mean().cpu().item()
            metrics_to_log["avg_f1"] = f1s.mean().cpu().item()
            metrics_to_log["avg_fbeta"] = self.fbeta(precision, recall).mean().cpu().item()
            metrics_to_log['val_loss'] = val_loss

            metrics_dict = metrics_to_log

        # Append every metric to its history, creating new histories on demand
        for metric_name, value in metrics_dict.items():
            if metric_name not in self.results:
                self.results[metric_name] = []
            self.results[metric_name].append((curr_step, value))

        # Log to the per-trial MLflow client if configured, else to the active MLflow run
        if self.client is not None and self.trial_run_id is not None:
            for metric_name, value in metrics_dict.items():
                self.client.log_metric(
                    run_id=self.trial_run_id,
                    key=metric_name,
                    value=value,
                    step=curr_step,
                )
        elif self.use_mlflow:
            for metric_name, value in metrics_dict.items():
                mlflow.log_metric(metric_name, value, step=curr_step)

    def fbeta(self, precision, recall):
        """Element-wise F-beta score using ``self.beta``; 0 where the denominator is 0."""
        numerator = (1 + self.beta**2) * (precision * recall)
        denominator = (self.beta**2 * precision) + recall

        # torch.where selects 0 where the denominator vanishes (avoids NaN results)
        result = torch.where(
            denominator > 0,
            numerator / denominator,
            torch.zeros_like(precision)
        )
        return result

    def my_log_params(
        self,
        params_dict: dict,
    ):
        """Log parameters to the per-trial MLflow client if set, else via ``mlflow.log_params``."""
        if self.client is not None and self.trial_run_id is not None:
            for key, value in params_dict.items():
                self.client.log_param(run_id=self.trial_run_id, key=key, value=value)
        else:
            mlflow.log_params(params_dict)

    def resolve_best_metric(self, best_metric):
        """Map a user-facing metric name onto a metric-history key of ``self.results``.

        Examples: 'fBeta2_class3' -> 'fbeta_class3' (sets ``self.beta = 2``),
        'fBeta1' -> 'avg_fbeta' (sets ``self.beta = 1``), 'f1_class2' -> 'f1_class2'.
        Any name that does not resolve to a metric history is reported and
        replaced by 'avg_f1'. That covers unknown names, class indices that were
        not created, and the scalar ``best_metric`` / ``best_metric_epoch`` entries.
        Must be called after ``create_results_dictionary``.
        """
        requested = best_metric
        fbeta_pattern = r"^fBeta(\d+)(?:_class(\d+))?$"  # Matches fBetaX or fBetaX_classY
        match = re.match(fbeta_pattern, best_metric)

        if match:
            self.beta = int(match.group(1))  # Extract beta value
            class_part = match.group(2)
            best_metric = f'fbeta_class{class_part}' if class_part else 'avg_fbeta'

        if best_metric in ('best_metric', 'best_metric_epoch') or best_metric not in self.metric_names:
            print(f"'{requested}' is not a valid metric. Defaulting to 'avg_f1'.\n")
            best_metric = 'avg_f1'

        return best_metric
438
+
File without changes