GANDLF 0.1.3.dev20250202__py3-none-any.whl → 0.1.6.dev20251109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of GANDLF might be problematic. Click here for more details.

Files changed (57) hide show
  1. GANDLF/cli/deploy.py +2 -2
  2. GANDLF/cli/generate_metrics.py +35 -1
  3. GANDLF/cli/main_run.py +4 -10
  4. GANDLF/compute/__init__.py +0 -2
  5. GANDLF/compute/forward_pass.py +0 -1
  6. GANDLF/compute/generic.py +107 -2
  7. GANDLF/compute/inference_loop.py +4 -4
  8. GANDLF/compute/loss_and_metric.py +1 -2
  9. GANDLF/compute/training_loop.py +10 -10
  10. GANDLF/config_manager.py +33 -717
  11. GANDLF/configuration/__init__.py +0 -0
  12. GANDLF/configuration/default_config.py +73 -0
  13. GANDLF/configuration/differential_privacy_config.py +16 -0
  14. GANDLF/configuration/exclude_parameters.py +1 -0
  15. GANDLF/configuration/model_config.py +82 -0
  16. GANDLF/configuration/nested_training_config.py +25 -0
  17. GANDLF/configuration/optimizer_config.py +121 -0
  18. GANDLF/configuration/parameters_config.py +10 -0
  19. GANDLF/configuration/patch_sampler_config.py +11 -0
  20. GANDLF/configuration/post_processing_config.py +10 -0
  21. GANDLF/configuration/pre_processing_config.py +94 -0
  22. GANDLF/configuration/scheduler_config.py +92 -0
  23. GANDLF/configuration/user_defined_config.py +131 -0
  24. GANDLF/configuration/utils.py +96 -0
  25. GANDLF/configuration/validators.py +479 -0
  26. GANDLF/data/__init__.py +14 -16
  27. GANDLF/data/lightning_datamodule.py +119 -0
  28. GANDLF/entrypoints/run.py +36 -31
  29. GANDLF/inference_manager.py +69 -25
  30. GANDLF/losses/__init__.py +23 -1
  31. GANDLF/losses/loss_calculators.py +79 -0
  32. GANDLF/losses/segmentation.py +3 -2
  33. GANDLF/metrics/__init__.py +26 -0
  34. GANDLF/metrics/generic.py +1 -1
  35. GANDLF/metrics/metric_calculators.py +102 -0
  36. GANDLF/metrics/panoptica_config_brats.yaml +56 -0
  37. GANDLF/metrics/segmentation_panoptica.py +49 -0
  38. GANDLF/models/__init__.py +8 -3
  39. GANDLF/models/lightning_module.py +2102 -0
  40. GANDLF/optimizers/__init__.py +4 -8
  41. GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
  42. GANDLF/schedulers/__init__.py +11 -4
  43. GANDLF/schedulers/wrap_torch.py +15 -3
  44. GANDLF/training_manager.py +160 -50
  45. GANDLF/utils/__init__.py +5 -3
  46. GANDLF/utils/imaging.py +176 -35
  47. GANDLF/utils/modelio.py +12 -8
  48. GANDLF/utils/pred_target_processors.py +71 -0
  49. GANDLF/utils/tensor.py +2 -1
  50. GANDLF/utils/write_parse.py +1 -1
  51. GANDLF/version.py +1 -1
  52. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/METADATA +16 -11
  53. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/RECORD +57 -34
  54. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/WHEEL +1 -1
  55. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/entry_points.txt +0 -0
  56. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info/licenses}/LICENSE +0 -0
  57. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2102 @@
1
+ import os
2
+ import sys
3
+ import time
4
+ import psutil
5
+ import torch
6
+ import torchio
7
+ import warnings
8
+ import openslide
9
+ import numpy as np
10
+ import SimpleITK as sitk
11
+ import lightning.pytorch as pl
12
+ import torch.nn.functional as F
13
+
14
+
15
+ from medcam import medcam
16
+ from copy import deepcopy
17
+ from statistics import mean
18
+ from multiprocessing import Lock
19
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
20
+ from lightning.pytorch.strategies import DDPStrategy
21
+ from lightning.pytorch.utilities import rank_zero_only
22
+
23
+ from GANDLF.logger import Logger
24
+ from GANDLF.models import get_model
25
+ from GANDLF.metrics import overall_stats
26
+ from GANDLF.optimizers import get_optimizer
27
+ from GANDLF.schedulers import get_scheduler
28
+ from GANDLF.data.post_process import global_postprocessing_dict
29
+ from GANDLF.losses.loss_calculators import LossCalculatorFactory
30
+ from GANDLF.metrics.metric_calculators import MetricCalculatorFactory
31
+ from GANDLF.data.preprocessing import get_transforms_for_preprocessing
32
+ from GANDLF.data.inference_dataloader_histopath import InferTumorSegDataset
33
+ from GANDLF.utils.pred_target_processors import PredictionTargetProcessorFactory
34
+ from GANDLF.privacy.opacus.opacus_anonymization_manager import (
35
+ OpacusAnonymizationManager,
36
+ )
37
+ from GANDLF.privacy.opacus import handle_dynamic_batch_size
38
+
39
+
40
+ from GANDLF.utils import (
41
+ optimize_and_save_model,
42
+ write_training_patches,
43
+ one_hot,
44
+ reverse_one_hot,
45
+ print_model_summary,
46
+ get_date_time,
47
+ save_model,
48
+ load_model,
49
+ version_check,
50
+ get_filename_extension_sanitized,
51
+ resample_image,
52
+ BEST_MODEL_PATH_END,
53
+ INITIAL_MODEL_PATH_END,
54
+ LATEST_MODEL_PATH_END,
55
+ MapSaver,
56
+ )
57
+
58
+ from typing import Tuple, Union, Dict, List, Any
59
+
60
+
61
+ class GandlfLightningModule(pl.LightningModule):
62
+ CLASSIFICATION_REGRESSION_RESULTS_HEADER = "Epoch,SubjectID,PredictedValue\n"
63
+ CLASSIFICATION_REGRESSION_RESULTS_HEADER_HISTOPATH = "SubjectID,x_coords,y_coords"
64
+ FLOAT_FORMATTING_PRECISION = 4
65
+ MULTIPROCESSING_LOCK = Lock()
66
+
67
def __init__(self, params: dict, output_dir: str):
    """
    Sets up the module from a parameter dictionary.

    Args:
        params (dict): The parameters; a deep copy is stored as `self.params`.
        output_dir (str): The directory stored as `self.output_dir` and used
            when building the checkpoint paths.
    """
    super().__init__()
    self.output_dir = output_dir
    self.params = deepcopy(params)
    self.learning_rate = self.params["learning_rate"]

    # Problem-type flags are derived from the caller's dictionary, one per type.
    self._problem_type_is_regression = params["problem_type"] == "regression"
    self._problem_type_is_classification = (
        params["problem_type"] == "classification"
    )
    self._problem_type_is_segmentation = params["problem_type"] == "segmentation"

    # The initializers run in this fixed order.
    for initialize_component in (
        self._initialize_model,
        self._initialize_loss,
        self._initialize_metric_calculators,
        self._initialize_preds_target_processor,
        self._initialize_model_save_paths,
    ):
        initialize_component()
82
+
83
def _initialize_model(self):
    """
    Builds the model with `get_model` from the stored parameters and keeps
    it as `self.model`.
    """
    built_model = get_model(self.params)
    self.model = built_model
89
+
90
def _initialize_loss(self):
    """
    Obtains the loss calculator from `LossCalculatorFactory` for the stored
    parameters and keeps it as `self.loss`. Which calculator is returned is
    decided by the factory.
    """
    loss_factory = LossCalculatorFactory(self.params)
    self.loss = loss_factory.get_loss_calculator()
98
+
99
def _initialize_metric_calculators(self):
    """
    Obtains the metric calculator from `MetricCalculatorFactory` for the
    stored parameters and keeps it as `self.metric_calculators`. Which
    calculator is returned is decided by the factory.
    """
    metric_factory = MetricCalculatorFactory(self.params)
    self.metric_calculators = metric_factory.get_metric_calculator()
109
+
110
def _initialize_preds_target_processor(self):
    """
    Obtains the prediction/target processor from
    `PredictionTargetProcessorFactory` and keeps it as
    `self.pred_target_processor`. The processor is later called with the
    model output and the labels and returns both back.
    """
    processor_factory = PredictionTargetProcessorFactory(self.params)
    self.pred_target_processor = processor_factory.get_prediction_target_processor()
119
+
120
def _initialize_model_save_paths(self):
    """
    Builds `self.model_paths`, a dictionary with the keys "best", "initial"
    and "latest". Each path is the output directory joined with the
    architecture name followed by the matching path-end constant.
    """
    path_end_by_kind = {
        "best": BEST_MODEL_PATH_END,
        "initial": INITIAL_MODEL_PATH_END,
        "latest": LATEST_MODEL_PATH_END,
    }
    self.model_paths = {
        kind: os.path.join(
            self.output_dir, self.params["model"]["architecture"] + path_end
        )
        for kind, path_end in path_end_by_kind.items()
    }
139
+
140
@rank_zero_only
def _save_model(self, epoch: int, save_path: str, onnx_export: bool):
    """
    Writes a checkpoint through GANDLF's `save_model`.

    The checkpoint dictionary holds the epoch, the model and optimizer state
    dictionaries, and `self.current_best_loss` under the key "loss".

    Args:
        epoch (int): The epoch number stored in the checkpoint.
        save_path (str): The path to save the model to.
        onnx_export (bool): Forwarded to `save_model` as `onnx_export`.
    """
    checkpoint_contents = {
        "epoch": epoch,
        "model_state_dict": self.model.state_dict(),
        "optimizer_state_dict": self.optimizers().optimizer.state_dict(),
        "loss": self.current_best_loss,
    }
    save_model(
        checkpoint_contents,
        model=self.model,
        params=self.params,
        path=save_path,
        onnx_export=onnx_export,
    )
162
+
163
+ def _prepare_metrics_dict_for_progbar_logging(
164
+ self, metric_results_dict: Dict[str, float]
165
+ ):
166
+ """
167
+ Formats the metric results dictionary into format suitable for
168
+ logging with Lightning's progress bar.
169
+
170
+ Args:
171
+ metric_results_dict (Dict[str, float]): The dictionary containing the metric results.
172
+ """
173
+ metric_results_dict_with_updated_suffix = (
174
+ self._add_stage_prefix_to_metric_results_dict(
175
+ metric_results_dict, self._determine_trainer_stage_string()
176
+ )
177
+ )
178
+ metric_results_dict_with_values_formatted = (
179
+ self._convert_per_class_metric_results_to_separate_key_value_pairs(
180
+ metric_results_dict_with_updated_suffix
181
+ )
182
+ )
183
+ return self._round_metric_values_in_dict(
184
+ metric_results_dict_with_values_formatted
185
+ )
186
+
187
+ @staticmethod
188
+ def _convert_per_class_metric_results_to_separate_key_value_pairs(
189
+ metric_results_dict: Dict[str, Any]
190
+ ) -> Dict[str, float]:
191
+ """
192
+ In case the metric results dictionary contains per-class values, this function
193
+ takes the values and creates separate key-value pairs for each class in the
194
+ results dictionary.
195
+
196
+ Args:
197
+ metric_results_dict (Dict[str, Any]): The dictionary containing the metric results.
198
+
199
+ Returns:
200
+ parsed_results_dict (Dict[str, float]): The dictionary containing the parsed results.
201
+ """
202
+ parsed_results_dict = deepcopy(metric_results_dict)
203
+ for metric_name, metric_value in metric_results_dict.items():
204
+ if isinstance(metric_value, list):
205
+ for n, metric_value_for_given_class in enumerate(metric_value):
206
+ parsed_results_dict[
207
+ metric_name + f"_class_{n}"
208
+ ] = metric_value_for_given_class
209
+ del parsed_results_dict[metric_name]
210
+ return parsed_results_dict
211
+
212
+ @staticmethod
213
+ def _add_stage_prefix_to_metric_results_dict(
214
+ metric_results_dict: Dict[str, float], stage: str
215
+ ):
216
+ """
217
+ Ensures that metric names in the results dictionary are prefixed with the stage
218
+ """
219
+ metric_results_dict_with_updated_suffix = {
220
+ f"{stage}_{metric_name}": metric_value
221
+ for metric_name, metric_value in metric_results_dict.items()
222
+ }
223
+ return metric_results_dict_with_updated_suffix
224
+
225
+ def _round_metric_values_in_dict(self, metric_results_dict: Dict[str, float]):
226
+ """
227
+ Performs rounding of the metric values in the results dictionary.
228
+
229
+ Args:
230
+ metric_results_dict (Dict[str, float]): The dictionary containing the metric results.
231
+
232
+ Returns:
233
+ rounded_metric_results_dict (Dict[str, float]): The dictionary containing the rounded metric results.
234
+ """
235
+
236
+ return {
237
+ k: self._round_value_to_precision(v) for k, v in metric_results_dict.items()
238
+ }
239
+
240
+ def _round_value_to_precision(self, value: float):
241
+ """
242
+ Rounds the value to the specified precision, defined as module constant.
243
+
244
+ Args:
245
+ value (float): The value to round.
246
+
247
+ Returns:
248
+ rounded_value (float): The rounded value.
249
+ """
250
+
251
+ return round(value, self.FLOAT_FORMATTING_PRECISION)
252
+
253
def forward(
    self, images: torch.Tensor
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
    """
    Runs the model on a batch of images.

    When `params["medcam_enabled"]` is truthy the model is expected to return
    an (output, attention map) pair; for 2D models a trailing dimension is
    added to the attention map. Otherwise the attention map is None.

    Args:
        images (torch.Tensor): The input batch passed to the model.

    Returns:
        Tuple of the model output and the attention map (or None).
    """
    if not self.params.get("medcam_enabled", False):
        return self.model(images), None

    output, attention_map = self.model(images)
    if self.params["model"]["dimension"] == 2:
        attention_map = torch.unsqueeze(attention_map, -1)
    return output, attention_map
268
+
269
def on_train_start(self):
    """
    Hook that prepares the training state: prints start-up information,
    records the start time, creates the training logger and epoch containers,
    resets the early-stopping counter and best loss, and optionally sets up
    medcam and differential privacy.
    """
    startup_steps = (
        self._print_training_initialization_info,
        self._set_training_start_time,
        self._print_channels_info,
        self._initialize_train_logger,
        self._initialize_training_epoch_containers,
    )
    for run_step in startup_steps:
        run_step()

    self.wait_count_before_early_stopping = 0
    # Start from the largest float so that any real loss compares as smaller.
    self.current_best_loss = sys.float_info.max
    self.params["current_epoch"] = self.current_epoch

    # TODO check out if the disabled by default medcam is indeed what we
    # meant - it was taken from original code
    medcam_configured = self.params.get("medcam")
    if medcam_configured:
        self._inject_medcam_module()
        self.params["medcam_enabled"] = False

    differential_privacy_configured = self.params.get("differential_privacy")
    if differential_privacy_configured:
        self._initialize_training_differential_privacy()
286
+
287
+ def _try_to_load_model_training_start(self):
288
+ """
289
+ Attempts to load the model at the start of the training.
290
+ """
291
+ if self._try_to_load_model(self.model_paths["best"]):
292
+ print(f"Previous best model loaded from {self.model_paths['best']}.")
293
+ elif self._try_to_load_model(self.model_paths["latest"]):
294
+ print(f"Previous latest model loaded from {self.model_paths['latest']}.")
295
+ else:
296
+ print(
297
+ "Could not load any previous model, training from scratch.", flush=True
298
+ )
299
+
300
def _try_to_load_model(self, load_path: str) -> bool:
    """
    Attempts to load the model from the specified path.

    The checkpoint is read with `load_model`, its "version" entry is passed to
    `version_check` together with `self.params["version"]`, and the weights
    are loaded into `self.model`. While the trainer is in training mode the
    optimizer state, the completed-epoch counter of the fit loop and the
    "val_loss" callback metric are restored from the checkpoint as well.

    Any exception raised along the way is turned into a warning and the
    method returns False.
    NOTE(review): a failure after `load_state_dict` would leave the new
    weights in place while still reporting False - confirm callers are fine
    with that.

    Args:
        load_path (str): The path to the model to load.

    Returns:
        bool: Whether the model was successfully loaded.
    """
    if os.path.exists(load_path):
        try:
            checkpoint_dict = load_model(load_path, self.device)
            version_check(
                self.params["version"], version_to_check=checkpoint_dict["version"]
            )
            # I am purposefully omitting the line below, as "previous_parameters" are not used anywhere
            # params["previous_parameters"] = main_dict.get("parameters", None)
            state_dict = checkpoint_dict["model_state_dict"]
            if self.params.get("differential_privacy"):
                # this is required for torch==1.11 and for DP inference
                new_state_dict = {}
                for key, val in state_dict.items():
                    # strip every "_module." occurrence from the parameter name
                    new_key = key.replace("_module.", "")
                    new_state_dict[new_key] = val
                state_dict = new_state_dict

            self.model.load_state_dict(state_dict)
            if self.trainer.training:
                # optimizers(False) is used so load_state_dict is called on the
                # object returned without Lightning's wrapper
                self.optimizers(False).load_state_dict(
                    checkpoint_dict["optimizer_state_dict"]
                )
                # resume epoch counting from the epoch stored in the checkpoint
                self.trainer.fit_loop.epoch_progress.current.completed = (
                    checkpoint_dict["epoch"]
                )
                self.trainer.callback_metrics["val_loss"] = checkpoint_dict["loss"]
            return True
        except Exception as e:
            # best-effort load: report the problem and fall through to False
            warnings.warn(
                f"Model found under path {load_path}, but error occurred during loading: {e}"
            )
    return False
342
+
343
@rank_zero_only
def _try_to_save_initial_model(self):
    """
    Writes the initial checkpoint unless a file already exists at the
    "initial" model path; prints which of the two happened.
    """
    initial_checkpoint_exists = os.path.exists(self.model_paths["initial"])
    if initial_checkpoint_exists:
        print(
            f"Initial model already exists at {self.model_paths['initial']}; Skipping saving"
        )
        return

    self._save_model(self.current_epoch, self.model_paths["initial"], False)
    print(f"Initial model saved at {self.model_paths['initial']}")
355
+
356
def _inject_medcam_module(self):
    """
    Replaces `self.model` with the result of `medcam.inject`, configured from
    `params["medcam"]` (backend and layer). The injection is created disabled,
    with attention returned and map saving switched off.
    """
    attention_maps_dir = os.path.join(
        self.output_dir, "attention_maps", self.params["medcam"]["backend"]
    )
    injection_options = dict(
        output_dir=attention_maps_dir,
        backend=self.params["medcam"]["backend"],
        layer=self.params["medcam"]["layer"],
        save_maps=False,
        return_attention=True,
        enabled=False,
    )
    self.model = medcam.inject(self.model, **injection_options)
371
+
372
+ def _get_metrics_names_for_loggers(self):
373
+ """
374
+ Returns the names of the overall metrics to be logged if the problem type is classification or regression.
375
+ """
376
+ metric_names = list(self.params["metrics"])
377
+ overall_metrics = {}
378
+ if self._problem_type_is_regression:
379
+ overall_metrics = overall_stats(
380
+ torch.tensor([1]), torch.tensor([1]), self.params
381
+ )
382
+ elif self._problem_type_is_classification:
383
+ temp_tensor = torch.randint(
384
+ 0, self.params["model"]["num_classes"], (5,), dtype=torch.int32
385
+ )
386
+ overall_metrics = overall_stats(temp_tensor, temp_tensor, self.params)
387
+ for overall_metric_key in overall_metrics.keys():
388
+ if overall_metric_key not in metric_names:
389
+ metric_names.append(overall_metric_key)
390
+
391
+ return metric_names
392
+
393
@rank_zero_only
def _initialize_train_logger(self):
    """
    Creates `self.train_logger`, a `Logger` in "train" mode writing to
    `logs_training.csv` inside the output directory.
    """
    training_log_csv = os.path.join(self.output_dir, "logs_training.csv")
    logged_metric_names = self._get_metrics_names_for_loggers()
    self.train_logger = Logger(
        logger_csv_filename=training_log_csv,
        metrics=logged_metric_names,
        mode="train",
    )
400
+
401
@rank_zero_only
def _set_training_start_time(self):
    """Stores the current `time.time()` value as `self.training_start_time`."""
    started_at = time.time()
    self.training_start_time = started_at
404
+
405
@rank_zero_only
def _print_training_initialization_info(self):
    """
    Prints start-of-training information: the host name, then (when verbose)
    the current date and time, then (when requested in the model parameters)
    the model summary.
    """
    self._print_host_info()

    verbose = self.params["verbose"]
    if verbose:
        print("Initializing training at :", get_date_time(), flush=True)

    summary_requested = self.params["model"]["print_summary"]
    if summary_requested:
        self._print_model_summary()
415
+
416
+ def _print_host_info(self):
417
+ if os.environ.get("HOSTNAME"):
418
+ print("Hostname :", os.environ.get("HOSTNAME"), flush=True)
419
+
420
def _print_model_summary(self):
    """
    Calls `print_model_summary` with the model, the batch size, the number of
    channels and the patch size taken from the parameters.
    """
    batch_size = self.params["batch_size"]
    num_channels = self.params["model"]["num_channels"]
    patch_size = self.params["patch_size"]
    print_model_summary(self.model, batch_size, num_channels, patch_size)
427
+
428
+ @rank_zero_only
429
+ def _initialize_training_epoch_containers(self):
430
+ """
431
+ Initializes the containers for storing the training epoch data.
432
+ They are used for accumulating the losses, metrics, predictions and labels
433
+ for each epoch, so final calculations can be made at the end of the epoch.
434
+ """
435
+
436
+ self.train_losses: List[torch.Tensor] = []
437
+ self.training_metric_values: List[Dict[str, float]] = []
438
+ if self._problem_type_is_regression or self._problem_type_is_classification:
439
+ self.train_predictions: List[torch.Tensor] = []
440
+ self.train_labels: List[torch.Tensor] = []
441
+
442
@rank_zero_only
def _print_channels_info(self):
    """Prints the number of model input channels taken from the parameters."""
    num_channels = self.params["model"]["num_channels"]
    print("Number of channels : ", num_channels)
445
+
446
def training_step(self, subject, batch_idx):
    """
    Single training optimization step.

    Builds the image and label batches from the subject, runs the forward
    pass, passes output and labels through the prediction/target processor,
    computes loss and metrics, and records them in the epoch containers.

    Args:
        subject: The batch as delivered by the training dataloader.
        batch_idx: The index of the batch (unused here).

    Returns:
        The loss returned by the loss calculator.
    """
    if self.params.get("save_training"):
        write_training_patches(subject, self.params)

    if self.params.get("differential_privacy"):
        # Use the subject handed back by the handler; previously the return
        # value was dropped, so any replacement batch it produced was ignored.
        subject = self._handle_dynamic_batch_size_in_differential_privacy_mode(
            subject
        )

    images = self._prepare_images_batch_from_subject_data(subject)
    labels = self._prepare_labels_batch_from_subject_data(subject)

    images = self._ensure_proper_images_tensor_dimensions(images)
    labels = self._process_labels(labels)
    model_output, _ = self.forward(images)
    model_output, labels = self.pred_target_processor(model_output, labels)

    loss = self.loss(model_output, labels, images)
    metric_results = self.metric_calculators(
        model_output, labels, subject_spacing=subject.get("spacing", None)
    )

    if self._problem_type_is_regression or self._problem_type_is_classification:
        self.train_labels.append(labels.detach().cpu())
        # NOTE(review): argmax over dim 1 is also applied for regression,
        # which looks suited to class scores only - confirm intended.
        self.train_predictions.append(
            torch.argmax(model_output, dim=1).detach().cpu()
        )

    self.train_losses.append(loss.detach().cpu())
    self.training_metric_values.append(metric_results)

    return loss
479
+
480
def _prepare_images_batch_from_subject_data(self, subject_data: torchio.Subject):
    """
    Builds the image batch by concatenating the data of every configured
    channel along dimension 1.

    Args:
        subject_data (torchio.Subject): The subject data (or a set of already
            extracted patches) indexed by the keys in `params["channel_keys"]`.

    Returns:
        images_batch (torch.Tensor): The concatenated images; presumably of
            shape (B, C, H, W, D) - confirm against the dataloader.
    """
    per_channel_data = []
    for channel_key in self.params["channel_keys"]:
        per_channel_data.append(subject_data[channel_key][torchio.DATA])
    return torch.cat(per_channel_data, dim=1)
498
+
499
def _prepare_labels_batch_from_subject_data(self, subject: torchio.Subject):
    """
    Creates the label tensor from the subject data.

    For regression and classification the entries named in
    `params["value_keys"]` are concatenated and reshaped to
    (rows, number of value keys); otherwise the data of `subject["label"]`
    is returned as is.

    Args:
        subject (torchio.Subject): The subject holding the label or the values.

    Returns:
        label (torch.Tensor): The label tensor.
    """
    is_value_problem = (
        self._problem_type_is_regression or self._problem_type_is_classification
    )
    if not is_value_problem:
        return subject["label"][torchio.DATA]

    concatenated_values = torch.cat(
        [subject[key] for key in self.params["value_keys"]], dim=0
    )
    # TODO this for sure needs some further investigation
    # min is needed because for certain cases, batch size becomes smaller than the total remaining labels
    row_count = min(self.params["batch_size"], len(concatenated_values))
    column_count = len(self.params["value_keys"])
    return concatenated_values.reshape(row_count, column_count)
525
+
526
+ def _ensure_proper_images_tensor_dimensions(self, images: torch.Tensor):
527
+ """
528
+ Modify the input images by removing the singular depth dimension added
529
+ by torchio for 2D images.
530
+
531
+ Args:
532
+ images (torch.Tensor): The input images tensor.
533
+
534
+ Returns:
535
+ images (torch.Tensor): The modified images tensor.
536
+ """
537
+
538
+ if self.params["model"]["dimension"] == 2:
539
+ images = images.squeeze(-1)
540
+
541
+ return images
542
+
543
+ def _process_labels(self, labels: torch.Tensor):
544
+ """
545
+ Modifies the labels tensor based on the problem type.
546
+ """
547
+
548
+ if self._problem_type_is_segmentation:
549
+ if labels.shape[1] == 3:
550
+ labels = labels[:, 0, ...].unsqueeze(1)
551
+ warnings.warn(
552
+ "The label image is an RGB image, only the first channel will be used."
553
+ )
554
+
555
+ # for segmentation remove the depth dimension from the label.
556
+ # for classification / regression, flattens class / reg label from list (possible in multilabel) to scalar
557
+ # TODO: second condition is crutch - in some cases label is passed as 1-d Tensor (B,) and if Batch size is 1,
558
+ # it is squeezed to scalar tensor (0-d) and the future logic fails
559
+ if len(labels.shape) != 1:
560
+ labels = labels.squeeze(-1)
561
+
562
+ if self._problem_type_is_segmentation:
563
+ labels = one_hot(labels, self.params["model"]["class_list"])
564
+
565
+ return labels
566
+
567
def _handle_dynamic_batch_size_in_differential_privacy_mode(self, subject):
    """
    Passes the subject through `handle_dynamic_batch_size` and returns the
    first element of the pair it gives back; the second element is ignored.
    """
    adjusted_subject, _unused = handle_dynamic_batch_size(subject, self.params)
    return adjusted_subject
570
+
571
def _initialize_training_differential_privacy(self):
    """
    Switches model, optimizer and training dataloader to the versions
    returned by `OpacusAnonymizationManager.apply_privacy`.

    Raises:
        NotImplementedError: Propagated from `_check_if_opacus_is_applicable`
            when the trainer uses the DDP strategy.
    """
    self._check_if_opacus_is_applicable()
    opacus_manager = OpacusAnonymizationManager(self.params)

    (
        model,
        dp_optimizer,
        train_dataloader,
        privacy_engine,
    ) = opacus_manager.apply_privacy(
        self.model, self.optimizers().optimizer, self.trainer.train_dataloader
    )
    self.model = model
    # NOTE(review): this writes to a private attribute of the fit loop to swap
    # the dataloader - verify against the installed Lightning version
    self.trainer.fit_loop._data_source.instance = train_dataloader
    self.trainer.optimizers = [dp_optimizer]
    # TODO should we reinit the scheduler too?
    # the privacy engine is kept on the module for later use
    self._dp_engine = privacy_engine
588
+
589
def _check_if_opacus_is_applicable(self):
    """
    Rejects differential privacy when the trainer runs with the DDP strategy.

    Raises:
        NotImplementedError: If the trainer strategy is a `DDPStrategy`.
    """
    strategy_is_ddp = isinstance(self.trainer.strategy, DDPStrategy)
    if not strategy_is_ddp:
        return
    raise NotImplementedError(
        "Differential privacy is not supported with DDP strategy. Please use single GPU."
    )
594
+
595
def on_train_epoch_start(self):
    """
    Hook run at the start of each training epoch: records the epoch start
    time, optionally writes the memory usage for the current epoch and,
    when verbose, prints the epoch start time.
    """
    self._set_epoch_start_time()

    track_memory = self.params["track_memory_usage"]
    if track_memory:
        self._write_epoch_start_process_resource_usage(self.current_epoch)

    verbose = self.params["verbose"]
    if verbose:
        self._print_epoch_start_time()
601
+
602
def _write_epoch_start_process_resource_usage(self, epoch):
    """
    Writes the memory usage to a file at the start of the epoch.
    Ran separately on each process in case of distributed training; each
    process writes to its own file named after its local and global rank.

    The header line is written only when the file is created, so that
    appending one row per epoch yields a single-header CSV.

    Args:
        epoch (int): The current epoch number.
    """
    filename = f"memory_usage_local_rank_{self.local_rank}_global_rank_{self.global_rank}.csv"
    # The directory is resolved here on every process: a rank-zero-only helper
    # would hand back None on the other ranks and break the path join below.
    memory_stats_dir = os.path.join(self.output_dir, "memory_stats")
    os.makedirs(memory_stats_dir, exist_ok=True)
    full_filepath = os.path.join(memory_stats_dir, filename)
    using_cuda = "cuda" in self.device.type

    header = "Epoch,Memory_Total,Memory_Available,Memory_Percent_Free,Memory_Usage,"
    if using_cuda:
        header += (
            "CUDA_active.all.peak,CUDA_active.all.current,CUDA_active.all.allocated"
        )
    header += "\n"

    # The first four fields of the psutil result fill the four host columns.
    host_memory_stats = psutil.virtual_memory()
    row_values = [epoch, *host_memory_stats[:4]]
    if using_cuda:
        cuda_memory_stats = torch.cuda.memory_stats()
        row_values.extend(
            cuda_memory_stats[stat_name]
            for stat_name in (
                "active.all.peak",
                "active.all.current",
                "active.all.allocated",
            )
        )
    # The trailing comma of each row is kept as in the existing file format.
    row = ",".join(str(value) for value in row_values) + ",\n"

    # TODO evaluate if this indeed works properly in distributed setting
    # `with` guarantees the lock is released even if the write fails.
    with self.MULTIPROCESSING_LOCK:
        file_is_new = not os.path.exists(full_filepath)
        with open(full_filepath, "a") as file_mem:
            if file_is_new:
                file_mem.write(header)
            file_mem.write(row)
652
+
653
+ @rank_zero_only
654
+ def _prepare_memory_stats_save_dir(self):
655
+ memory_stats_dir = os.path.join(self.output_dir, "memory_stats")
656
+ os.makedirs(memory_stats_dir, exist_ok=True)
657
+ return memory_stats_dir
658
+
659
@rank_zero_only
def _print_epoch_start_time(self):
    """Prints the current date and time as the epoch start time."""
    current_date_time = get_date_time()
    print("Epoch start time : ", current_date_time, flush=True)
662
+
663
@rank_zero_only
def _set_epoch_start_time(self):
    """Stores the current `time.time()` value as `self.epoch_start_time`."""
    started_at = time.time()
    self.epoch_start_time = started_at
666
+
667
# TODO check if it indeed work properly and run only on rank 0
@rank_zero_only
def on_train_epoch_end(self):
    """
    Hook that aggregates the per-batch training results of the epoch.

    Per-batch metric values are averaged per metric, overall statistics are
    added for regression and classification, the mean loss is computed, the
    epoch containers are cleared, and the results are written to the training
    logger and logged through Lightning. Afterwards the optional per-epoch
    checkpoint and the "latest" checkpoint are saved.

    NOTE(review): `training_metric_values[0]` assumes at least one training
    batch was processed in the epoch - confirm this always holds.
    NOTE(review): `log_dict` is called with `sync_dist=True` from a hook that
    is decorated with `rank_zero_only` - verify this does not block in
    multi-process runs.
    """
    epoch_metrics = {}
    # the metric names are taken from the first recorded batch
    metric_names = self.training_metric_values[0].keys()
    for metric_name in metric_names:
        metric_values = [x[metric_name] for x in self.training_metric_values]
        epoch_metrics[
            metric_name
        ] = self._compute_metric_mean_across_values_from_batches(metric_values)

    if self._problem_type_is_regression or self._problem_type_is_classification:
        # overall statistics computed from all predictions/labels of the epoch
        training_epoch_average_metrics_overall = overall_stats(
            torch.cat(self.train_predictions),
            torch.cat(self.train_labels),
            self.params,
        )
        epoch_metrics.update(training_epoch_average_metrics_overall)
    mean_loss = self._round_value_to_precision(
        torch.mean(torch.stack(self.train_losses)).item()
    )

    # everything needed from the containers has been read at this point
    self._clear_training_epoch_containers()

    self.train_logger.write(
        self.current_epoch,
        mean_loss,
        self._ensure_proper_metric_formatting_for_logging(epoch_metrics),
    )
    self.log("train_loss", mean_loss, on_epoch=True, prog_bar=True)
    self.log_dict(
        self._prepare_metrics_dict_for_progbar_logging(epoch_metrics),
        on_epoch=True,
        prog_bar=True,
        sync_dist=True,
    )

    if self.params["verbose"]:
        self._print_epoch_end_time()
    if self.params["model"]["save_at_every_epoch"]:
        self._save_epoch_end_checkpoint()
    # the previous "latest" checkpoint is removed before the new one is written
    if os.path.exists(self.model_paths["latest"]):
        os.remove(self.model_paths["latest"])
    self._save_model(self.current_epoch, self.model_paths["latest"], False)

    print("Latest model saved")
713
+
714
+ def _compute_metric_mean_across_values_from_batches(
715
+ self, metric_values: List[Union[float, List[float]]]
716
+ ) -> Union[float, List[float]]:
717
+ """
718
+ Given a list of metrics calculated for each batch, computes the mean across all batches.
719
+ Takes into account case where metric is a list of values (e.g. for each class).
720
+
721
+ Args:
722
+ metric_values (List[Union[float, List[float]]]): The list of metric values for each batch.
723
+
724
+ Returns:
725
+ Union[float, List[float]]: The mean value of the metric across all batches.
726
+ """
727
+ if isinstance(metric_values[0], list):
728
+ return [
729
+ mean([batch_metrics[i] for batch_metrics in metric_values])
730
+ for i in range(len(metric_values[0]))
731
+ ]
732
+ return self._round_value_to_precision(mean(metric_values))
733
+
734
+ @staticmethod
735
+ def _ensure_proper_metric_formatting_for_logging(metrics_dict: dict) -> dict:
736
+ """
737
+ Helper function to ensure that all metric values are in the correct format for
738
+ GANDLF's logging system.
739
+
740
+ Args:
741
+ metrics_dict (dict): The dictionary containing the metric values.
742
+
743
+ Returns:
744
+ output_metrics_dict (dict): The dictionary containing the formatted metric values.
745
+ """
746
+ output_metrics_dict = deepcopy(metrics_dict)
747
+ for metric in metrics_dict.keys():
748
+ if isinstance(metrics_dict[metric], list):
749
+ output_metrics_dict[metric] = ("_").join(
750
+ str(metrics_dict[metric])
751
+ .replace("[", "")
752
+ .replace("]", "")
753
+ .replace(" ", "")
754
+ .split(",")
755
+ )
756
+
757
+ return output_metrics_dict
758
+
759
@rank_zero_only
def _save_epoch_end_checkpoint(self):
    """
    Saves a checkpoint for the current epoch in the output directory. The
    file name is the architecture name followed by `_epoch_<epoch>.pth.tar`.
    """
    checkpoint_filename = (
        self.params["model"]["architecture"]
        + "_epoch_"
        + str(self.current_epoch)
        + ".pth.tar"
    )
    epoch_save_path = os.path.join(self.output_dir, checkpoint_filename)
    self._save_model(self.current_epoch, epoch_save_path, False)
    print("Epoch model saved.")
773
+
774
@rank_zero_only
def _print_epoch_end_time(self):
    """Prints the minutes elapsed since `self.epoch_start_time`."""
    elapsed_minutes = (time.time() - self.epoch_start_time) / 60
    print("Time taken for epoch : ", elapsed_minutes, " mins", flush=True)
782
+
783
@rank_zero_only
def _clear_training_epoch_containers(self):
    """
    Replaces the training epoch containers with fresh empty lists; the
    prediction and label containers are only reset for regression and
    classification.
    """
    self.train_losses, self.training_metric_values = [], []

    tracks_predictions = (
        self._problem_type_is_regression or self._problem_type_is_classification
    )
    if tracks_predictions:
        self.train_predictions, self.train_labels = [], []
790
+
791
@rank_zero_only
def on_train_end(self):
    """
    Hook run when training finishes: if a "best" checkpoint file exists,
    `optimize_and_save_model` is called for it (without ONNX export); the
    total training time is printed afterwards.
    """
    best_model_path = self.model_paths["best"]
    best_checkpoint_exists = os.path.exists(best_model_path)
    if best_checkpoint_exists:
        # Why don't we handle it here with the full save_model function?
        # TODO Onnx export seems to modify model INPLACE, so when doing cuda
        optimize_and_save_model(
            self.model, self.params, self.model_paths["best"], onnx_export=False
        )
    self._print_total_training_time()
800
+
801
@rank_zero_only
def _print_total_training_time(self):
    """Prints the minutes elapsed since `self.training_start_time`."""
    elapsed_minutes = (time.time() - self.training_start_time) / 60
    print("Total time taken for training : ", elapsed_minutes, " mins", flush=True)
809
+
810
@rank_zero_only
def on_validation_start(self):
    """
    Hook executed when validation starts: creates the per-epoch validation
    containers and then the validation CSV logger.
    """
    self._initialize_validation_epoch_containers()
    self._initialize_validation_logger()
814
+
815
@rank_zero_only
def _initialize_validation_epoch_containers(self):
    """
    Create the empty per-epoch validation accumulators.

    Losses and metric values are always created. Prediction/label lists are
    created for regression and classification problems, and the list of CSV
    rows is additionally created when ``params["save_output"]`` is set.
    """
    # NOTE(review): nesting reconstructed from a flattened diff - the
    # ``save_output`` check is assumed to sit inside the regression /
    # classification branch; confirm against the upstream file.
    self.val_losses: List[torch.Tensor] = []
    self.validation_metric_values: List[Dict[str, float]] = []
    tracks_predictions = (
        self._problem_type_is_regression or self._problem_type_is_classification
    )
    if not tracks_predictions:
        return
    self.val_predictions: List[float] = []
    self.val_labels: List[float] = []
    if self.params["save_output"]:
        self.rows_to_write: List[str] = []
824
+
825
@rank_zero_only
def _initialize_validation_logger(self):
    """
    Create ``self.val_logger``, which writes to ``logs_validation.csv`` inside
    ``self.output_dir``. The epsilon column is requested whenever a truthy
    ``differential_privacy`` entry is present in the parameters.
    """
    validation_log_path = os.path.join(self.output_dir, "logs_validation.csv")
    metric_names = self._get_metrics_names_for_loggers()
    uses_differential_privacy = bool(self.params.get("differential_privacy"))
    self.val_logger = Logger(
        logger_csv_filename=validation_log_path,
        metrics=metric_names,
        mode="val",
        add_epsilon=uses_differential_privacy,
    )
833
+
834
@rank_zero_only
def on_validation_epoch_start(self):
    """
    Hook executed at the start of each validation epoch.

    Optionally enables medcam on the model and prepares the directory
    ``<output_dir>/output_validation/epoch_<current_epoch>`` in which the
    outputs of this validation epoch are stored.
    """
    # TODO this is dead code both here and in original loops
    # by default medcam is injected at the training and ["medcam_enabled"] is set to False
    # so this block is never executed
    if self.params["medcam_enabled"]:
        self.model.enable_medcam()
        self.params["medcam_enabled"] = True
    epoch_directory_name = f"epoch_{self.current_epoch}"
    self._current_validation_epoch_save_dir = os.path.join(
        self.output_dir, "output_validation", epoch_directory_name
    )
    self._ensure_path_exists(self._current_validation_epoch_save_dir)
846
+
847
def validation_step(self, subject, batch_idx):
    """
    Run the model on one validation subject and record its loss and metrics.

    Regression/classification subjects that carry a label map are evaluated
    with the label sampler; every other case goes through the grid sampler.
    When a ground truth is available, the loss and metrics are appended to the
    per-epoch containers, and for regression/classification the scalar
    prediction and label are stored as well.

    Args:
        subject: The batch (of size one) holding the subject data.
        batch_idx: Index of the batch; not used by this method.
    """
    if self.params["verbose"]:
        self._print_currently_processed_subject(subject)

    subject_dict = self._initialize_subject_dict_nontraining_mode(subject)
    label_present = subject["label"] != ["NA"]
    value_keys_present = "value_keys" in self.params
    # Evaluated once and parenthesised wherever it is combined with `and`:
    # `a or b and c` parses as `a or (b and c)`, which used to route regression
    # subjects without a label map into the label sampler.
    is_regression_or_classification = (
        self._problem_type_is_regression or self._problem_type_is_classification
    )
    label = None
    if label_present:
        subject_dict = self._extend_nontraining_subject_dict_with_label(
            subject, subject_dict
        )

    if is_regression_or_classification and label_present:
        (
            model_output,
            last_input_batch,
        ) = self._get_predictions_on_subject_using_label_sampler(subject_dict)

        if self.params["save_output"]:
            processed_logit = self._process_prediction_logit_for_row_writing(
                model_output, self.params["scaling_factor"]
            )
            self.rows_to_write.append(
                self._prepare_row_for_output_csv(
                    subject["subject_id"][0], processed_logit, self.current_epoch
                )
            )

        label = self._initialize_nontraining_label_ground_truth_classification_or_regression(
            subject
        )
    else:
        (
            model_output,
            last_input_batch,
        ) = self._get_predictions_on_subject_using_grid_sampler(subject_dict)

        if self.params["save_output"]:
            self._save_predictions_for_segmentation_subject(model_output, subject)

        if self._problem_type_is_segmentation and label_present:
            label = self._initialize_nontraining_label_ground_truth_segmentation(
                subject
            )
        elif is_regression_or_classification and value_keys_present:
            label = self._initialize_nontraining_label_ground_truth_classification_or_regression(
                subject
            )

    if label is not None:
        label = self._process_labels(label)
        model_output, label = self.pred_target_processor(model_output, label)
        loss = self.loss(model_output, label, last_input_batch)
        metric_results = self.metric_calculators(
            model_output, label, subject_spacing=subject.get("spacing", None)
        )

        self.val_losses.append(loss)
        self.validation_metric_values.append(metric_results)

        # The label is known to exist here. Its truthiness must not be tested:
        # a tensor holding class 0 is falsy and would silently drop the sample.
        if is_regression_or_classification:
            model_prediction = (
                torch.argmax(model_output[0], 0)
                if self._problem_type_is_classification
                else model_output[0]
            )  # TODO am I right here? For regression, we should not take argmax
            self.val_predictions.append(model_prediction.item())
            self.val_labels.append(label.item())
928
+
929
+ @staticmethod
930
+ def _prepare_row_for_output_csv(
931
+ subject_id: str, prediction_logit: float, epoch: int
932
+ ):
933
+ """
934
+ Helper function to prepare the row for the output CSV file.
935
+
936
+ Args:
937
+ subject_id (str): The subject ID.
938
+ prediction_logit (float): The prediction logit.
939
+ epoch (int): The epoch number.
940
+
941
+ Returns:
942
+ row (str): The row to write to the output CSV file.
943
+ """
944
+
945
+ return f"{epoch},{subject_id},{prediction_logit}\n"
946
+
947
+ @staticmethod
948
+ def _prepare_row_for_output_csv_histopathology_inference(
949
+ subject_name, x_coord, y_coord, output_matrix
950
+ ):
951
+ """
952
+ Helper function to prepare the row for the output CSV file in histopathology inference.
953
+
954
+ Args:
955
+ subject_name (str): The subject name.
956
+ x_coord (int): The x coordinate.
957
+ y_coord (int): The y coordinate.
958
+ output_matrix (np.array) : output matrix of the model, a set of
959
+ predicted 2D matrices for each class
960
+
961
+ Returns:
962
+ row (str): The row to write to the output CSV file.
963
+ """
964
+ base_string = f"{subject_name},{x_coord},{y_coord}"
965
+ for output_for_class in output_matrix:
966
+ base_string += f",{output_for_class}"
967
+ return base_string + "\n"
968
+
969
+ @staticmethod
970
+ def _process_prediction_logit_for_row_writing(
971
+ prediction_logit: torch.Tensor, scaling_factor: float = 1.0
972
+ ):
973
+ """
974
+ Processes the prediction logits for writing to the output CSV file.
975
+
976
+ Args:
977
+ prediction_logit (torch.Tensor): The prediction logits.
978
+ scaling_factor (float): The scaling factor modifying the prediction logit.
979
+ Default is 1 (no scaling).
980
+
981
+ Returns:
982
+ prediction_logit (float): The processed prediction logit.
983
+ """
984
+ return prediction_logit.cpu().max().item() / scaling_factor
985
+
986
+ def _print_currently_processed_subject(self, subject):
987
+ if isinstance(subject, torchio.Subject):
988
+ subject_id = subject["subject_id"]
989
+ elif isinstance(subject, tuple):
990
+ # ugly corner histology inference handling, when incoming batch is
991
+ # a row from dataframe, not a torchio.Subject. This should be solved
992
+ # via some kind of polymorphism in the future
993
+ subject_data = subject[1]
994
+ subject_id = subject_data[self.params["headers"]["subjectIDHeader"]]
995
+ print("== Current subject:", subject_id, flush=True)
996
+
997
def _initialize_subject_dict_nontraining_mode(self, subject: torchio.Subject):
    """
    Create a dictionary containing the subject data for the non-training mode
    (validation, testing, inference).

    Every key listed in ``params["channel_keys"]`` becomes a
    ``torchio.ScalarImage`` built from the path, data and affine found in
    ``subject`` (the leading dimension of data and affine is squeezed away).
    For regression/classification problems with ``value_keys`` configured, the
    values are copied under ``"value_<key>"``.

    Args:
        subject (torchio.Subject): The subject data.

    Returns:
        subject_dict (Dict[str, torchio.Image]): The dictionary containing the subject data.
    """
    subject_dict = {}

    for channel_key in self.params["channel_keys"]:
        channel_entry = subject[channel_key]
        subject_dict[channel_key] = torchio.ScalarImage(
            path=channel_entry["path"],
            tensor=channel_entry["data"].squeeze(0),
            affine=channel_entry["affine"].squeeze(0),
        )

    value_keys_present = "value_keys" in self.params
    # Parenthesised on purpose: `a or b and c` parses as `a or (b and c)`, which
    # made regression problems index params["value_keys"] even when absent.
    if (
        self._problem_type_is_regression or self._problem_type_is_classification
    ) and value_keys_present:
        for key in self.params["value_keys"]:
            subject_dict["value_" + key] = subject[key]

    return subject_dict
1026
+
1027
def _extend_nontraining_subject_dict_with_label(
    self, subject: torchio.Subject, subject_dict: dict
) -> dict:
    """
    Add the label map of ``subject`` to ``subject_dict`` (non-training mode).

    Args:
        subject (torchio.Subject): The subject data.
        subject_dict (dict): The dictionary containing the subject data; it is
            modified in place.

    Returns:
        subject_dict (dict): The same dictionary, now holding a
            ``torchio.LabelMap`` under the ``"label"`` key.
    """
    label_entry = subject["label"]
    subject_dict["label"] = torchio.LabelMap(
        path=label_entry["path"],
        tensor=label_entry["data"].squeeze(0),
        affine=label_entry["affine"].squeeze(0),
    )
    return subject_dict
1047
+
1048
def _initialize_nontraining_label_ground_truth_classification_or_regression(
    self, subject: torchio.Subject
):
    """
    Build the ground truth label for classification or regression problems
    in the non-training mode (validation, testing, inference).

    Args:
        subject (torchio.Subject): The subject data; it must hold an entry for
            every key listed in ``params["value_keys"]``.

    Returns:
        label (torch.Tensor): The entries concatenated along dimension 0.
    """
    value_tensors = [subject[key] for key in self.params["value_keys"]]
    return torch.cat(value_tensors, dim=0)
1062
+
1063
def _initialize_nontraining_label_ground_truth_segmentation(
    self, subject: torchio.Subject
):
    """
    Fetch the ground truth label for segmentation problems in the non-training
    mode (validation, testing, inference).

    Args:
        subject (torchio.Subject): The subject data.

    Returns:
        label (torch.Tensor): The ``"data"`` entry of the subject's label.
    """
    label_entry = subject["label"]
    return label_entry["data"]
1078
+
1079
+ # TODO this whole logic can be packed into something separate, as it is only used
1080
+ # in validation of regression and classification problems
1081
def _get_predictions_on_subject_using_label_sampler(
    self, subject_dict: dict
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Predict on a subject by sampling patches with torchio's label sampler.
    Used for regression and classification problems.

    Args:
        subject_dict (dict): The dictionary containing the subject data.

    Returns:
        total_logits_for_all_patches (torch.Tensor): The model outputs of all
            sampled patches summed along dimension 0 and divided by
            ``params["q_samples_per_volume"]``.
        last_batch_of_input_images (torch.Tensor): The last batch of input images. Used
            mostly for special cases like deep_resunet, deep_unet, etc. when it is needed for
            loss calculation.
    """

    def _stack_patch_channels(patches_batch: torchio.Subject):
        """
        Concatenate the data of all channels of a sampled patch along
        dimension 0, add a leading dimension and drop a trailing dimension of
        size one. The concatenation dimension differs from the one used in the
        other steps, hence the dedicated helper.

        Args:
            patches_batch (torchio.Subject): The sampled patch.

        Returns:
            images_batch_from_patches (torch.Tensor): The images batch from the patch.
        """
        channel_tensors = [
            patches_batch[key][torchio.DATA] for key in self.params["channel_keys"]
        ]
        stacked = torch.cat(channel_tensors, dim=0).unsqueeze(0)
        if stacked.shape[-1] == 1:
            stacked = torch.squeeze(stacked, -1)
        return stacked

    patch_sampler = torchio.data.LabelSampler(self.params["patch_size"])
    patch_generator = patch_sampler(
        torchio.Subject(subject_dict),
        num_patches=self.params["q_samples_per_volume"],
    )

    per_patch_outputs: List[torch.Tensor] = []
    for patches_batch in patch_generator:
        images_from_patches = self._ensure_proper_images_tensor_dimensions(
            _stack_patch_channels(patches_batch)
        )
        per_patch_output, _ = self.forward(images_from_patches)
        per_patch_outputs.append(per_patch_output)

    summed_logits = torch.cat(per_patch_outputs).sum(dim=0, keepdim=True)
    averaged_logits = summed_logits / self.params["q_samples_per_volume"]
    # `images_from_patches` deliberately leaks out of the loop: it is the last
    # batch that was fed to the model.
    return averaged_logits, images_from_patches
1147
+
1148
@rank_zero_only
def _determine_trainer_stage_string(self):
    """
    Helper function translating the current trainer stage into a short string.

    Returns:
        ``"val"``, ``"test"`` or ``"inference"`` when the trainer is validating,
        testing or predicting respectively; ``"train"`` otherwise.
    """
    trainer = self.trainer
    if trainer.validating:
        return "val"
    if trainer.testing:
        return "test"
    if trainer.predicting:
        return "inference"
    return "train"
1161
+
1162
+ def _determine_save_path_to_use(self):
1163
+ """
1164
+ Helper function to determine the output save path based on the trainer stage.
1165
+ """
1166
+ if self.trainer.validating:
1167
+ return self._current_validation_epoch_save_dir
1168
+ elif self.trainer.testing:
1169
+ return self._current_test_epoch_save_dir
1170
+ elif self.trainer.predicting:
1171
+ return self._current_inference_save_dir
1172
+ else:
1173
+ raise RuntimeError("Output save path cannot be determined for training")
1174
+
1175
def _get_predictions_on_subject_using_grid_sampler(
    self, subject_dict: dict
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Make predictions on the subject using the grid sampler. This is used in segmentation
    problems in validation and testing and for all problems in inference
    (as no ground truth is available in inference).

    Args:
        subject_dict (dict): The dictionary containing the subject data.

    Returns:
        aggregated_predictions (torch.Tensor): For segmentation, the output of the grid
            aggregator with a leading dimension added, moved to ``self.device``. For
            regression/classification, the per-patch outputs summed along dimension 0
            and divided by the number of batches in the patch loader.
        last_batch_of_input_images (torch.Tensor): The last batch of input images. Used
            mostly for special cases like deep_resunet, deep_unet, etc. when it is needed for
            loss calculation.
    """
    # NOTE(review): block nesting was reconstructed from a flattened diff view -
    # confirm against the upstream file before relying on exact structure.

    def _ensure_output_is_tensor_for_special_architectures(
        model_output: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
    ):
        """
        Helper function to ensure that the output is a tensor for special architectures
        that return a tuple of tensors (SDnet, DeepResunet etc)

        Args:
            model_output (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]): The model output.

        Returns:
            The output unchanged when it is already a tensor; otherwise its first
            element (a warning is emitted in that case).
        """

        if not isinstance(model_output, torch.Tensor):
            warnings.warn(
                f"Model output is not a Tensor: {type(model_output)}. Say, `deep_resunet` and `deep_unet` may return "
                f"list of tensors on different scales instead of just one prediction Tensor. However due to "
                f"GaNDLF architecture it is expected that models return only one tensor. For deep_* models "
                f"only the biggest scale is processed. Use these models with caution till fix is implemented."
            )
            model_output = model_output[0]

        return model_output

    def _ensure_output_shape_compatibility_with_torchio(model_output: torch.Tensor):
        """
        Helper function to ensure that the output shape is compatible with torchio (4D for 2D segmentation).

        Args:
            model_output (torch.Tensor): The model output tensor.

        Returns:
            model_output (torch.Tensor): The model output tensor with the correct shape.
        """
        # Only 2D segmentation outputs receive the extra trailing dimension.
        if (
            self.params["model"]["dimension"] == 2
            and self._problem_type_is_segmentation
        ):
            model_output = model_output.unsqueeze(-1)
        return model_output

    grid_sampler = self._prepare_grid_sampler(subject_dict)
    patch_loader = self._prepare_dataloader_from_grid_sampler(grid_sampler)

    # The prediction aggregator is always created, but it is only fed (and read)
    # for segmentation problems.
    prediction_aggregator = torchio.inference.GridAggregator(
        grid_sampler,
        overlap_mode=self.params["inference_mechanism"]["grid_aggregator_overlap"],
    )
    if self.params["medcam_enabled"]:
        # Separate aggregator collecting the attention maps returned by forward().
        medcam_attention_map_aggregator = torchio.inference.GridAggregator(
            grid_sampler,
            overlap_mode=self.params["inference_mechanism"][
                "grid_aggregator_overlap"
            ],
        )
    if self._problem_type_is_regression or self._problem_type_is_classification:
        # Regression/classification outputs are collected per patch instead of
        # being aggregated spatially.
        model_outputs_list: List[torch.Tensor] = []

    for patches_batch in patch_loader:
        images_from_patches = self._prepare_images_batch_from_subject_data(
            patches_batch
        )
        images_from_patches = self._ensure_proper_images_tensor_dimensions(
            images_from_patches
        )
        model_output, attention_map = self.forward(images_from_patches)

        model_output = _ensure_output_is_tensor_for_special_architectures(
            model_output
        )
        model_output = _ensure_output_shape_compatibility_with_torchio(model_output)
        if self.params["medcam_enabled"]:
            medcam_attention_map_aggregator.add_batch(
                attention_map, patches_batch[torchio.LOCATION]  # type: ignore
            )
        if self._problem_type_is_segmentation:
            prediction_aggregator.add_batch(
                model_output, patches_batch[torchio.LOCATION]
            )
        else:
            model_outputs_list.append(model_output)

    if self.params["medcam_enabled"]:
        attention_map = medcam_attention_map_aggregator.get_output_tensor()
        # `images_from_patches` is the last batch seen in the loop above.
        for i, n in enumerate(attention_map):
            self.model.save_attention_map(
                n.squeeze(), raw_input=images_from_patches[i].squeeze(-1)
            )

    if self._problem_type_is_regression or self._problem_type_is_classification:
        # Average of the per-patch outputs over the number of loader batches.
        return (
            torch.cat(model_outputs_list).sum(dim=0, keepdim=True)
            / len(patch_loader),
            images_from_patches,
        )

    return (
        prediction_aggregator.get_output_tensor().unsqueeze(0).to(self.device),
        images_from_patches,
    )
1291
+
1292
def _prepare_grid_sampler(self, subject_dict: dict):
    """
    Build the grid sampler that later feeds the grid aggregator.

    Args:
        subject_dict (dict): The dictionary containing the subject data.

    Returns:
        grid_sampler (torchio.inference.GridSampler): Sampler over the subject
            using ``params["patch_size"]`` and the patch overlap configured
            under ``params["inference_mechanism"]``.
    """
    tio_subject = torchio.Subject(subject_dict)
    patch_size = self.params["patch_size"]
    patch_overlap = self.params["inference_mechanism"]["patch_overlap"]
    return torchio.inference.GridSampler(
        tio_subject, patch_size, patch_overlap=patch_overlap
    )
1308
+
1309
def _prepare_dataloader_from_grid_sampler(
    self, grid_sampler: torchio.inference.GridSampler
):
    """
    Wrap the grid sampler in a data loader that yields one patch per batch.

    Args:
        grid_sampler (torchio.inference.GridSampler): The grid sampler.

    Returns:
        patch_loader (torch.utils.data.DataLoader): The patch loader.
    """
    patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=1)  # type: ignore
    return patch_loader
1323
+
1324
+ # TODO check if it indeed work properly and run only on rank 0
1325
@rank_zero_only
def on_validation_epoch_end(self):
    """
    Hook executed at the end of each validation epoch.

    Averages the per-subject metrics, adds the overall statistics for
    regression/classification, writes and logs the mean loss and metrics,
    runs the early-stopping check, optionally saves the predictions CSV and
    the differential privacy information, and finally clears the containers.
    """
    validation_epoch_average_metrics = {
        metric_name: self._compute_metric_mean_across_values_from_batches(
            [entry[metric_name] for entry in self.validation_metric_values]
        )
        for metric_name in self.validation_metric_values[0].keys()
    }

    is_regression_or_classification = (
        self._problem_type_is_regression or self._problem_type_is_classification
    )
    if is_regression_or_classification:
        # This is a workaround - sometimes the lists are empty
        if len(self.val_predictions) != 0 and len(self.val_labels) != 0:
            validation_epoch_average_metrics.update(
                overall_stats(
                    torch.tensor(self.val_predictions),
                    torch.tensor(self.val_labels),
                    self.params,
                )
            )

    stacked_losses = torch.stack(self.val_losses)
    mean_loss = self._round_value_to_precision(torch.mean(stacked_losses).item())

    metrics_for_csv = self._ensure_proper_metric_formatting_for_logging(
        validation_epoch_average_metrics
    )
    self.val_logger.write(self.current_epoch, mean_loss, metrics_for_csv)

    self.log("val_loss", mean_loss, on_epoch=True, prog_bar=True)
    metrics_for_progbar = self._prepare_metrics_dict_for_progbar_logging(
        validation_epoch_average_metrics
    )
    self.log_dict(
        metrics_for_progbar, on_epoch=True, prog_bar=True, sync_dist=False
    )

    self._check_if_early_stopping(mean_loss)

    if self.params["save_output"] and (
        self._problem_type_is_regression or self._problem_type_is_classification
    ):
        self._save_predictions_csv_for_regression_or_classification(
            self.rows_to_write, self._determine_save_path_to_use()
        )
    if self.params.get("differential_privacy"):
        self._print_differential_privacy_info()
    self._clear_validation_epoch_containers()
1382
+
1383
@rank_zero_only
def _clear_validation_epoch_containers(self):
    """
    Reset the per-epoch validation accumulators.

    Losses and metric values are always reset. Prediction/label lists are
    reset for regression and classification problems, and the list of CSV
    rows is additionally reset when ``params["save_output"]`` is set.
    """
    # NOTE(review): nesting reconstructed from a flattened diff - the
    # ``save_output`` check is assumed to sit inside the regression /
    # classification branch; confirm against the upstream file.
    self.val_losses = []
    self.validation_metric_values = []
    if not (
        self._problem_type_is_regression or self._problem_type_is_classification
    ):
        return
    self.val_predictions = []
    self.val_labels = []
    if self.params["save_output"]:
        self.rows_to_write = []
1392
+
1393
@rank_zero_only
def _ensure_path_exists(self, path):
    """
    Create the directory ``path`` (including missing parents) if needed.

    Args:
        path: The directory that must exist after the call.
    """
    # exist_ok=True removes the window between a separate existence check and
    # the creation, in which another process could create the directory and
    # make os.makedirs fail.
    os.makedirs(path, exist_ok=True)
1397
+
1398
@rank_zero_only
def _print_differential_privacy_info(self):
    """
    Print and log the privacy budget: delta is read from the
    ``differential_privacy`` parameters and epsilon is obtained from the
    differential privacy engine for that delta.
    """
    dp_delta = self.params["differential_privacy"]["delta"]
    dp_epsilon = self._dp_engine.get_epsilon(dp_delta)
    print(f"Epoch {self.current_epoch} Privacy: ε = {dp_epsilon:.2f}, δ = {dp_delta}")
    for log_name, log_value in (("epsilon", dp_epsilon), ("delta", dp_delta)):
        self.log(log_name, log_value, on_epoch=True, prog_bar=True)
1405
+
1406
+ # TODO called at the validation step, NOT at the end of the epoch - we want to avoid
1407
+ # saving all predictions for all subjects for the end of the epoch
1408
def _save_predictions_for_segmentation_subject(
    self, predicted_segmentation_mask: torch.Tensor, subject: torchio.Subject
):
    """
    Saves the predicted segmentation mask for a given subject, performing the necessary postprocessing
    steps.

    The mask is post-processed per class, decoded with ``reverse_one_hot``,
    post-processed again, converted to a SimpleITK image that copies the
    information of the subject's ``"1"`` image, optionally resampled, and
    written under the save directory of the current stage.

    Args:
        predicted_segmentation_mask (torch.Tensor): The predicted segmentation mask, extracted
            from the grid aggregator when all validation patches for this subject have
            been processed.
        subject (torchio.Subject): The subject for which the segmentation mask was predicted, used
            to extract the metadata.
    """
    # NOTE(review): block nesting was reconstructed from a flattened diff view -
    # confirm against the upstream file before relying on exact structure.

    def _convert_subject_to_sitk_format(subject: torchio.Subject):
        # Only the image stored under the hard-coded key "1" is converted.
        return torchio.ScalarImage(
            tensor=subject["1"]["data"].squeeze(0).cpu(),
            affine=subject["1"]["affine"].squeeze(0).cpu(),
        ).as_sitk()

    def _postprocess_raw_segmentation_mask(
        segmentation_mask: np.ndarray, params: dict
    ):
        # Applies every configured postprocessor to each class channel of the
        # first batch element; the array is modified in place.
        for postprocessor in params["data_postprocessing"]:
            for _class in range(0, params["model"]["num_classes"]):
                segmentation_mask[0, _class, ...] = global_postprocessing_dict[
                    postprocessor
                ](segmentation_mask[0, _class, ...], params)

        return segmentation_mask

    def _swap_mask_axes_for_sitk_save_format_compatibility(
        segmentation_mask: np.ndarray,
    ):
        # Swaps the first and third axes before handing the array to SimpleITK.
        return np.swapaxes(segmentation_mask, 0, 2)

    def _postprocess_one_hot_reversed_segmentation_mask(
        segmentation_mask: np.ndarray, params: dict
    ):
        # Postprocessors that operate on the decoded (non one-hot) mask.
        for postprocessor in params[
            "data_postprocessing_after_reverse_one_hot_encoding"
        ]:
            segmentation_mask = global_postprocessing_dict[postprocessor](
                segmentation_mask, params
            )

        return segmentation_mask

    def _determine_final_prediction_mask_shape(segmentation_mask: np.ndarray):
        # Drops a singleton first axis, else a singleton last axis, else
        # leaves the mask untouched.
        if segmentation_mask.shape[0] == 1:
            return segmentation_mask.squeeze(0)
        elif segmentation_mask.shape[-1] == 1:
            return segmentation_mask.squeeze(-1)
        else:
            return segmentation_mask

    predicted_segmentation_mask_numpy = predicted_segmentation_mask.cpu().numpy()
    predicted_segmentation_mask_numpy = _postprocess_raw_segmentation_mask(
        predicted_segmentation_mask_numpy, self.params
    )
    # taking 0-th element as the batch size is 1, and this is required by reverse_one_hot function
    decoded_segmentation_mask = reverse_one_hot(
        predicted_segmentation_mask_numpy[0], self.params["model"]["class_list"]
    )
    decoded_segmentation_mask = _swap_mask_axes_for_sitk_save_format_compatibility(
        decoded_segmentation_mask
    )
    decoded_segmentation_mask = _postprocess_one_hot_reversed_segmentation_mask(
        decoded_segmentation_mask, self.params
    )
    decoded_segmentation_mask = _determine_final_prediction_mask_shape(
        decoded_segmentation_mask
    )

    # The output format follows the extension of the subject's "1" image.
    image_save_format = get_filename_extension_sanitized(subject["1"]["path"][0])
    if image_save_format in [".jpg", ".jpeg", ".png"]:
        decoded_segmentation_mask = decoded_segmentation_mask.astype(np.uint8)

    subject_converted_to_sitk_format = _convert_subject_to_sitk_format(subject)
    result_sitk_image = sitk.GetImageFromArray(decoded_segmentation_mask)
    result_sitk_image.CopyInformation(subject_converted_to_sitk_format)

    if "resample" in self.params["data_preprocessing"]:
        # Resample the mask to the spacing of the subject image using
        # nearest-neighbour interpolation.
        result_sitk_image = resample_image(
            result_sitk_image,
            subject_converted_to_sitk_format.GetSpacing(),
            interpolator=sitk.sitkNearestNeighbor,
        )
    # The file name carries the global rank of the writing process.
    segmentation_mask_save_path = os.path.join(
        self._determine_save_path_to_use(),
        subject["subject_id"][0],
        f"{subject['subject_id'][0]}_seg_process_rank_{self.global_rank}{image_save_format}",
    )
    self._ensure_path_exists(os.path.dirname(segmentation_mask_save_path))
    sitk.WriteImage(result_sitk_image, segmentation_mask_save_path)
1504
+
1505
@rank_zero_only
def _save_predictions_csv_for_regression_or_classification(
    self, rows_to_write: List[str], save_path: str
):
    """
    Write the regression/classification predictions to ``output_predictions.csv``
    inside ``save_path``.

    Args:
        rows_to_write (List[str]): The rows to write to the CSV file. Each element of
            the list is a row.
        save_path (str): The directory in which the CSV file is created.
    """

    def _determine_header_to_use():
        # The histopathology header is only used while predicting on the
        # "histo"/"path" modalities; every other case uses the default header.
        # NOTE(review): nesting reconstructed from a flattened diff - confirm.
        if self.trainer.predicting and self.params["modality"] in ["histo", "path"]:
            header = self.CLASSIFICATION_REGRESSION_RESULTS_HEADER_HISTOPATH
            if self._problem_type_is_regression:
                return header + ",output\n"
            if self._problem_type_is_classification:
                for class_num in range(self.params["model"]["num_classes"]):
                    header += f",probability_{class_num}"
                return header + "\n"
        return self.CLASSIFICATION_REGRESSION_RESULTS_HEADER

    csv_save_path = os.path.join(save_path, "output_predictions.csv")
    merged_output = _determine_header_to_use() + "".join(rows_to_write)
    with open(csv_save_path, "w") as file:
        file.write(merged_output)
1536
+
1537
+ # TODO separate it into checking and saving functions, perhaps even separate class
1538
@rank_zero_only
def _check_if_early_stopping(self, val_loss: float):
    """
    Update the early-stopping state from the latest validation loss.

    A loss strictly below the best one seen so far is stored, the best model
    is saved and the waiting counter is reset. Otherwise the counter is
    incremented and, once it exceeds ``params["patience"]``, the trainer is
    asked to stop.
    """
    # NOTE(review): nesting reconstructed from a flattened diff - the patience
    # check is assumed to belong to the "no improvement" branch; confirm.
    previous_best_loss = deepcopy(self.current_best_loss)
    if not val_loss < self.current_best_loss:
        self.wait_count_before_early_stopping += 1
        print(
            f"Validation loss did not improve. Waiting count before early stopping: {self.wait_count_before_early_stopping} / {self.params['patience']}",
            flush=True,
        )
        if self.wait_count_before_early_stopping > self.params["patience"]:
            self.trainer.should_stop = True
            print(
                f"Early stopping triggered at epoch {self.current_epoch}, validation loss did not improve for {self.params['patience']} epochs, with the best loss value being {self.current_best_loss}. Stopping training.",
                flush=True,
            )
        return

    self.current_best_loss = val_loss
    self._save_model(self.current_epoch, self.model_paths["best"], False)
    print(
        f"Loss value improved. Previous best loss :{previous_best_loss}, new best loss: {val_loss} Saving best model from epoch {self.current_epoch}",
        flush=True,
    )
    self.wait_count_before_early_stopping = 0
1566
+
1567
def on_test_start(self):
    """
    Hook executed when testing starts: creates the per-epoch test containers
    and then the test CSV logger.
    """
    self._initialize_test_epoch_containers()
    self._initialize_test_logger()
1570
+
1571
@rank_zero_only
def _initialize_test_logger(self):
    """
    Create ``self.test_logger``, which writes to ``logs_test.csv`` inside
    ``self.output_dir``.
    """
    test_log_path = os.path.join(self.output_dir, "logs_test.csv")
    metric_names = self._get_metrics_names_for_loggers()
    self.test_logger = Logger(
        logger_csv_filename=test_log_path, metrics=metric_names, mode="test"
    )
1578
+
1579
@rank_zero_only
def _initialize_test_epoch_containers(self):
    """
    Create the empty per-epoch test accumulators.

    Losses and metric values are always created. For regression and
    classification problems with ``params["save_output"]`` set, the list of
    CSV rows is created as well.
    """
    self.test_losses: List[torch.Tensor] = []
    self.test_metric_values: List[Dict[str, float]] = []
    # test_step appends to, and on_test_epoch_end writes out, `rows_to_write`;
    # without this the attribute only exists if a validation run created it.
    if (
        self._problem_type_is_regression or self._problem_type_is_classification
    ) and self.params["save_output"]:
        self.rows_to_write: List[str] = []
1583
+
1584
def on_test_epoch_start(self):
    """
    Hook executed at the start of each test epoch.

    Optionally enables medcam on the model and prepares the directory
    ``<output_dir>/output_test/epoch_<current_epoch>`` in which the outputs
    of this test epoch are stored.
    """
    if self.params["medcam_enabled"]:
        self.model.enable_medcam()
        self.params["medcam_enabled"] = True

    epoch_directory_name = f"epoch_{self.current_epoch}"
    self._current_test_epoch_save_dir = os.path.join(
        self.output_dir, "output_test", epoch_directory_name
    )
    self._ensure_path_exists(self._current_test_epoch_save_dir)
1593
+
1594
+ def test_step(self, subject, batch_idx):
1595
+ if self.params["verbose"]:
1596
+ self._print_currently_processed_subject(subject)
1597
+
1598
+ subject_dict = self._initialize_subject_dict_nontraining_mode(subject)
1599
+ label_present = subject["label"] != ["NA"]
1600
+ value_keys_present = "value_keys" in self.params
1601
+ label = None
1602
+ if label_present:
1603
+ subject_dict = self._extend_nontraining_subject_dict_with_label(
1604
+ subject, subject_dict
1605
+ )
1606
+ if (
1607
+ self._problem_type_is_regression
1608
+ or self._problem_type_is_classification
1609
+ and label_present
1610
+ ):
1611
+ (
1612
+ model_output,
1613
+ last_input_batch,
1614
+ ) = self._get_predictions_on_subject_using_label_sampler(subject_dict)
1615
+
1616
+ if self.params["save_output"]:
1617
+ processed_logit = self._process_prediction_logit_for_row_writing(
1618
+ model_output, self.params["scaling_factor"]
1619
+ )
1620
+ self.rows_to_write.append(
1621
+ self._prepare_row_for_output_csv(
1622
+ subject["subject_id"][0], processed_logit, self.current_epoch
1623
+ )
1624
+ )
1625
+
1626
+ label = self._initialize_nontraining_label_ground_truth_classification_or_regression(
1627
+ subject
1628
+ )
1629
+ else:
1630
+ (
1631
+ model_output,
1632
+ last_input_batch,
1633
+ ) = self._get_predictions_on_subject_using_grid_sampler(subject_dict)
1634
+ if self.params["save_output"]:
1635
+ self._save_predictions_for_segmentation_subject(model_output, subject)
1636
+ if self._problem_type_is_segmentation and label_present:
1637
+ label = self._initialize_nontraining_label_ground_truth_segmentation(
1638
+ subject
1639
+ )
1640
+ elif (
1641
+ self._problem_type_is_classification
1642
+ or self._problem_type_is_regression
1643
+ and value_keys_present
1644
+ ):
1645
+ label = self._initialize_nontraining_label_ground_truth_classification_or_regression(
1646
+ subject
1647
+ )
1648
+ if label is not None:
1649
+ label = self._process_labels(label)
1650
+ model_output, label = self.pred_target_processor(model_output, label)
1651
+
1652
+ loss = self.loss(model_output, label, last_input_batch)
1653
+ metric_results = self.metric_calculators(
1654
+ model_output, label, subject_spacing=subject.get("spacing", None)
1655
+ )
1656
+
1657
+ self.test_losses.append(loss)
1658
+ self.test_metric_values.append(metric_results)
1659
+
1660
@rank_zero_only
def on_test_epoch_end(self):
    """
    Aggregate the per-subject test results of the finished epoch.

    Averages every collected metric and the loss, writes them to the test
    logger and to the Lightning logs, saves the prediction CSV for
    regression/classification when ``save_output`` is set, and finally
    clears the per-epoch containers.
    """
    # ``test_step`` only records a loss/metrics entry for subjects with a
    # ground truth, so the containers can be empty. Skip the aggregation in
    # that case instead of failing on ``[0]`` / ``torch.stack([])``.
    if self.test_metric_values:
        test_epoch_average_metrics = {}
        metric_names = self.test_metric_values[0].keys()
        for metric_name in metric_names:
            metric_values = [x[metric_name] for x in self.test_metric_values]
            test_epoch_average_metrics[
                metric_name
            ] = self._compute_metric_mean_across_values_from_batches(metric_values)

        mean_loss = self._round_value_to_precision(
            torch.mean(torch.stack(self.test_losses)).item()
        )

        self.test_logger.write(
            self.current_epoch,
            mean_loss,
            self._ensure_proper_metric_formatting_for_logging(
                test_epoch_average_metrics
            ),
        )

        self.log("test_loss", mean_loss, on_epoch=True, prog_bar=True)
        self.log_dict(
            self._prepare_metrics_dict_for_progbar_logging(test_epoch_average_metrics),
            on_epoch=True,
            prog_bar=True,
            sync_dist=False,
        )

    if self.params["save_output"] and (
        self._problem_type_is_regression or self._problem_type_is_classification
    ):
        self._save_predictions_csv_for_regression_or_classification(
            self.rows_to_write, self._determine_save_path_to_use()
        )
    self._clear_test_epoch_containers()
1696
+
1697
@rank_zero_only
def _clear_test_epoch_containers(self):
    """Reset the per-epoch test loss and metric containers to empty lists."""
    self.test_losses, self.test_metric_values = [], []
1701
+
1702
def on_predict_start(self):
    """
    Prepare the module for prediction: set up the inference containers,
    load the model weights and, when a truthy ``differential_privacy``
    entry is configured, run the differential-privacy initialisation.
    """
    self._initialize_inference_containers()
    self._try_to_load_model_inference_start()

    differential_privacy_config = self.params.get("differential_privacy")
    if differential_privacy_config:
        self._initialize_inference_differential_privacy()
1708
+
1709
+ def _try_to_load_model_inference_start(self):
1710
+ if self._try_to_load_model(self.model_paths["best"]):
1711
+ print(f"Previous best model loaded from {self.model_paths['best']}.")
1712
+ elif self._try_to_load_model(self.model_paths["latest"]):
1713
+ print(f"Previous latest model loaded from {self.model_paths['latest']}.")
1714
+ else:
1715
+ raise RuntimeError(
1716
+ f"Best/latest models not found to load: {self.model_paths}"
1717
+ )
1718
+
1719
@rank_zero_only
def _initialize_inference_containers(self):
    """
    Create the inference output directory and the empty containers that
    collect losses, metrics and (for regression/classification) the CSV
    rows and per-subject class probabilities.
    """
    # TODO here we need some mechanism for separate outputs for nested inference
    inference_save_dir = os.path.join(self.output_dir, "output_inference")
    self._current_inference_save_dir = inference_save_dir
    self._ensure_path_exists(inference_save_dir)

    self.inference_losses, self.inference_metric_values = [], []

    regression_or_classification = (
        self._problem_type_is_regression or self._problem_type_is_classification
    )
    if regression_or_classification:
        self.rows_to_write = []
        self.subject_classification_class_probabilities: Dict[
            str, torch.Tensor
        ] = {}
1732
+
1733
@rank_zero_only
def _print_inference_initialization_info(self):
    """
    Print the model configuration used for inference, the host information
    and, when ``print_summary`` is enabled in the model config, the model
    summary.
    """
    model_config = self.params["model"]
    print("Current model type : ", model_config["type"])
    print("Number of dims : ", model_config["dimension"])
    if "num_channels" in model_config:
        print("Number of channels : ", model_config["num_channels"])
    print("Number of classes : ", len(model_config["class_list"]))
    self._print_host_info()
    if model_config["print_summary"]:
        self._print_model_summary()
1743
+
1744
def predict_step(self, batch, batch_idx):
    """
    Dispatch one prediction batch to the radiology or histopathology
    inference routine, depending on ``params["modality"]``.

    Args:
        batch: The item yielded by the prediction dataloader.
        batch_idx: Index of the batch; not used here.

    Returns:
        Whatever the selected inference routine returns.
    """
    if self.params["verbose"]:
        self._print_currently_processed_subject(batch)
    # TODO both of those below should return values to complete the logic
    # of calculating metrics for classification case that is currently handled
    # by saving/reading logits.csv file
    inference_step = (
        self._radiology_inference_step
        if self.params["modality"] == "rad"
        else self._histopathology_inference_step
    )
    return inference_step(batch)
1754
+
1755
def _radiology_inference_step(self, subject: torchio.Subject):
    """
    Run inference on a single radiology subject.

    Unlabelled subjects are predicted with the grid sampler and only have
    their predictions stored. Labelled subjects additionally get a loss and
    metrics computed, which are appended to the inference containers. For
    classification problems the softmax of the model output is stored per
    subject id.

    Args:
        subject (torchio.Subject): The subject yielded by the dataloader.
    """
    has_label = subject["label"] != ["NA"]
    subject_dict = self._initialize_subject_dict_nontraining_mode(subject)
    regression_or_classification = (
        self._problem_type_is_regression or self._problem_type_is_classification
    )

    if not has_label:
        (
            model_output,
            last_input_batch,
        ) = self._get_predictions_on_subject_using_grid_sampler(subject_dict)
        if regression_or_classification:
            logit_for_row = self._process_prediction_logit_for_row_writing(
                model_output
            )
            self.rows_to_write.append(
                self._prepare_row_for_output_csv(
                    subject["subject_id"][0], logit_for_row, self.current_epoch
                )
            )
        else:
            self._save_predictions_for_segmentation_subject(model_output, subject)
    else:
        subject_dict = self._extend_nontraining_subject_dict_with_label(
            subject, subject_dict
        )
        # The label is known to be present in this branch, so only the
        # problem type decides which sampler is used.
        if regression_or_classification:
            (
                model_output,
                last_input_batch,
            ) = self._get_predictions_on_subject_using_label_sampler(subject_dict)

            logit_for_row = self._process_prediction_logit_for_row_writing(
                model_output, self.params["scaling_factor"]
            )
            self.rows_to_write.append(
                self._prepare_row_for_output_csv(
                    subject["subject_id"][0], logit_for_row, self.current_epoch
                )
            )

            ground_truth = self._initialize_nontraining_label_ground_truth_classification_or_regression(
                subject
            )
        else:
            (
                model_output,
                last_input_batch,
            ) = self._get_predictions_on_subject_using_grid_sampler(subject_dict)
            self._save_predictions_for_segmentation_subject(model_output, subject)
            ground_truth = (
                self._initialize_nontraining_label_ground_truth_segmentation(
                    subject
                )
            )

        ground_truth = self._process_labels(ground_truth)
        model_output, ground_truth = self.pred_target_processor(
            model_output, ground_truth
        )

        subject_loss = self.loss(model_output, ground_truth, last_input_batch)
        subject_metrics = self.metric_calculators(
            model_output, ground_truth, subject_spacing=subject.get("spacing", None)
        )

        self.inference_losses.append(subject_loss)
        self.inference_metric_values.append(subject_metrics)

    if self._problem_type_is_classification:
        subject_id = subject["subject_id"][0]
        self.subject_classification_class_probabilities[subject_id] = F.softmax(
            model_output, dim=1
        )
1824
+
1825
+ # TODO this has to be somehow handled in different way, we
1826
+ # are mixing too much logic in this single module
1827
def _histopathology_inference_step(self, row_index_tuple):
    """
    Inference step for the histopathology modality. This function is called with an assumption that the highest
    level dataloader is an iterator over the rows of the dataframe. The function is called for each row of the
    dataframe.

    Patches are read from the slide, pushed through the model, and the
    accumulated count/probability maps are saved under
    ``<inference dir>/histopathology/<subject id>``. For regression and
    classification the collected CSV rows are saved there as well.

    Args:
        row_index_tuple (tuple): A pair whose second element is the dataframe
            row describing the slide to process (the first element is not
            used; presumably the row index - confirm against the dataloader).
    """
    row = row_index_tuple[1]
    subject_name = row[self.params["headers"]["subjectIDHeader"]]
    inference_results_save_dir_for_subject = os.path.join(
        self._current_inference_save_dir, "histopathology", str(subject_name)
    )
    self._ensure_path_exists(inference_results_save_dir_for_subject)
    self._prepare_histopath_default_inference_params()
    openslide_image = openslide.open_slide(
        row[self.params["headers"]["channelHeaders"]].values[0]
    )
    # Clamp the requested level into [0, highest level of this slide].
    # The lower bound must use ``max``: ``min(level, 0)`` would force every
    # positive level down to 0 and let negative levels through unchanged.
    max_defined_slide_level = openslide_image.level_count - 1
    row_slide_level = min(self.params["slide_level"], max_defined_slide_level)
    row_slide_level = max(row_slide_level, 0)
    level_width, level_height = openslide_image.level_dimensions[row_slide_level]
    patch_size = self._ensure_patch_size_is_2D(self.params["patch_size"])
    count_map = self._initialize_count_map(level_width, level_height)
    probabilities_map = self._initialize_probability_map(
        self.params["model"]["num_classes"], level_width, level_height
    )

    # TODO this should be done by other object or method
    transform_requested = get_transforms_for_preprocessing(
        self.params, [], False, False
    )
    patient_dataset = InferTumorSegDataset(
        row[self.params["headers"]["channelHeaders"]].values[0],
        patch_size=patch_size,
        stride_size=self.params["stride_size"],
        selected_level=row_slide_level,
        mask_level=self.params["mask_level"],
        transform=transform_requested,
    )
    histopathology_dataloader = torch.utils.data.DataLoader(
        patient_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=self.params["q_num_workers"],
    )
    # The dataset reports the patch size to use from here on (it may differ
    # from the configured one once the transforms are applied).
    patch_size_updated_after_transforms = patient_dataset.get_patch_size()
    if self.params["model"]["print_summary"]:
        print_model_summary(
            self.model,
            self.params["batch_size"],
            self.params["model"]["num_channels"],
            patch_size_updated_after_transforms,
        )
    count_map, probabilities_map = self._iterate_over_histopathology_loader(
        histopathology_dataloader,
        count_map,
        probabilities_map,
        patch_size_updated_after_transforms,
        self.params["model"]["num_classes"],
        subject_name,
    )

    map_saver = MapSaver(
        num_classes=self.params["model"]["num_classes"],
        slide_level=row_slide_level,
        blending_alpha=self.params["blending_alpha"],
        level_height=level_height,
        level_width=level_width,
    )
    map_saver.save_count_map(
        count_map, save_dir=inference_results_save_dir_for_subject
    )

    map_saver.save_probability_and_segmentation_maps(
        probabilities_map,
        openslide_image,
        save_dir=inference_results_save_dir_for_subject,
    )
    if self._problem_type_is_classification or self._problem_type_is_regression:
        self._save_predictions_csv_for_regression_or_classification(
            self.rows_to_write, inference_results_save_dir_for_subject
        )
1912
+
1913
+ def _iterate_over_histopathology_loader(
1914
+ self,
1915
+ histopathology_dataloader,
1916
+ count_map,
1917
+ probability_map,
1918
+ patch_size,
1919
+ num_classes,
1920
+ subject_name,
1921
+ ):
1922
+ for image_patches, (x_coord, y_coord) in histopathology_dataloader:
1923
+ x_coord, y_coord = (
1924
+ x_coord.numpy(),
1925
+ y_coord.numpy(),
1926
+ ) # TODO the dataset should do that when fetching
1927
+ image_patches = image_patches.to(self.device)
1928
+ output, _ = self.forward(image_patches)
1929
+ output = output.cpu().detach().numpy()
1930
+ for i in range(output.shape[0]):
1931
+ self._increment_value_of_count_map_at_given_position(
1932
+ count_map, x_coord[i], y_coord[i], patch_size
1933
+ )
1934
+ for class_index in range(num_classes):
1935
+ self._add_value_to_probability_map_at_given_position(
1936
+ probability_map,
1937
+ x_coord[i],
1938
+ y_coord[i],
1939
+ patch_size,
1940
+ output[i][class_index],
1941
+ class_index,
1942
+ )
1943
+ if (
1944
+ self._problem_type_is_regression
1945
+ or self._problem_type_is_classification
1946
+ ):
1947
+ row_for_csv_saving = (
1948
+ self._prepare_row_for_output_csv_histopathology_inference(
1949
+ subject_name, x_coord[i], y_coord[i], output[i]
1950
+ )
1951
+ )
1952
+ self.rows_to_write.append(row_for_csv_saving)
1953
+ probability_map = np.divide(probability_map, count_map)
1954
+ return count_map, probability_map
1955
+
1956
+ @staticmethod
1957
+ def _increment_value_of_count_map_at_given_position(
1958
+ count_map, x_coord, y_coord, patch_size
1959
+ ):
1960
+ count_map[
1961
+ y_coord : y_coord + patch_size[1], x_coord : x_coord + patch_size[0]
1962
+ ] += 1
1963
+
1964
+ @staticmethod
1965
+ def _add_value_to_probability_map_at_given_position(
1966
+ prob_map, x_coord, y_coord, patch_size, value, class_index
1967
+ ):
1968
+ prob_map[
1969
+ class_index,
1970
+ y_coord : y_coord + patch_size[1],
1971
+ x_coord : x_coord + patch_size[0],
1972
+ ] += value
1973
+
1974
+ # TODO this should be handled by the config parser
1975
@rank_zero_only
def _prepare_histopath_default_inference_params(self):
    """
    Fill in the histopathology inference parameters that are missing from
    ``self.params``: ``stride_size`` (None), ``slide_level`` (0),
    ``mask_level`` (the slide level) and ``blending_alpha`` (0.5, always
    stored as a float).
    """
    settings = self.params
    for key, fallback in (("stride_size", None), ("slide_level", 0)):
        settings[key] = settings.get(key, fallback)
    # The mask level defaults to the slide level resolved just above.
    settings["mask_level"] = settings.get("mask_level", settings["slide_level"])
    settings["blending_alpha"] = float(settings.get("blending_alpha", 0.5))
1986
+
1987
+ @staticmethod
1988
+ def _initialize_count_map(level_width: int, level_height: int):
1989
+ """
1990
+ Initializes the count maps for the histopathology inference.
1991
+
1992
+ Args:
1993
+ level_width (int): The width of the level.
1994
+ level_height (int): The height of the level.
1995
+
1996
+ Returns:
1997
+ count_map (np.ndarray): The count map.
1998
+ """
1999
+ return np.zeros((level_height, level_width), dtype=np.uint8)
2000
+
2001
+ @staticmethod
2002
+ def _initialize_probability_map(
2003
+ num_classes: int, level_width: int, level_height: int
2004
+ ):
2005
+ """
2006
+ Initializes the probability maps for the histopathology inference.
2007
+ Called for classification and segmentation problems.
2008
+
2009
+ Args:
2010
+ num_classes (int): The number of classes.
2011
+ level_width (int): The width of the level.
2012
+ level_height (int): The height of the level.
2013
+
2014
+ Returns:
2015
+ probs_map (np.ndarray): The probability map.
2016
+ """
2017
+ return np.zeros((num_classes, level_height, level_width), dtype=np.float16)
2018
+
2019
+ @staticmethod
2020
+ def _ensure_patch_size_is_2D(patch_size: List[int]):
2021
+ """
2022
+ Ensures that the patch size is 2D.
2023
+
2024
+ Args:
2025
+ patch_size (List[int]): The patch size.
2026
+
2027
+ Returns:
2028
+ patch_size (List[int]): The 2D patch size.
2029
+ """
2030
+ if len(patch_size) == 3:
2031
+ return patch_size[:-1]
2032
+ return patch_size
2033
+
2034
@rank_zero_only
def on_predict_end(self):
    """
    Print the averaged loss and metrics gathered during prediction (only
    when at least one metrics entry was recorded) and clear the inference
    containers.
    """
    collected_metrics = self.inference_metric_values
    if collected_metrics:
        inference_epoch_average_metrics = {
            metric_name: self._compute_metric_mean_across_values_from_batches(
                [entry[metric_name] for entry in collected_metrics]
            )
            for metric_name in collected_metrics[0].keys()
        }

        stacked_losses = torch.stack(self.inference_losses)
        mean_loss = self._round_value_to_precision(
            torch.mean(stacked_losses).item()
        )

        print("Inference results:")
        print(f"Loss: {mean_loss}")
        print(f"Metrics: {inference_epoch_average_metrics}")

    self._clear_inference_containers()
2054
+
2055
@rank_zero_only
def _clear_inference_containers(self):
    """
    Reset the inference loss and metric containers and, for
    regression/classification problems, the pending CSV rows.
    """
    self.inference_losses, self.inference_metric_values = [], []
    regression_or_classification = (
        self._problem_type_is_regression or self._problem_type_is_classification
    )
    if regression_or_classification:
        self.rows_to_write = []
2061
+
2062
def configure_optimizers(self):
    """
    Build the optimizer and, when a scheduler is configured, the learning
    rate scheduler, in the dictionary format expected by Lightning.

    Returns:
        dict: ``{"optimizer": ...}``, extended with ``"lr_scheduler"`` when
        ``"scheduler"`` is present in the parameters, and with
        ``"monitor": "val_loss"`` when that scheduler is a
        ``ReduceLROnPlateau``.
    """
    # Work on a copy so the helper-specific entries added below do not leak
    # into the module's own parameters.
    params = deepcopy(self.params)
    params["model_parameters"] = self.model.parameters()
    params["learning_rate"] = self.learning_rate
    optimizer = get_optimizer(params)
    if "scheduler" not in self.params:
        return {"optimizer": optimizer}

    params["optimizer_object"] = optimizer
    scheduler = get_scheduler(params)
    # Lightning reads the scheduler from the "lr_scheduler" key; a top-level
    # "scheduler" key is reported as unsupported and the scheduler is ignored.
    optimizer_dict = {"optimizer": optimizer, "lr_scheduler": scheduler}
    if isinstance(scheduler, ReduceLROnPlateau):
        # ReduceLROnPlateau needs a metric to watch.
        optimizer_dict["monitor"] = "val_loss"
    return optimizer_dict
2075
+
2076
def transfer_batch_to_device(self, batch, device, dataloader_idx):
    """
    A method called by Lightning to transfer the batch to the device.
    In case of GANDLF, we need custom logic to transfer the data to the device.

    The batch is returned untouched while predicting with the "path" or
    "histo" modality; otherwise the image data and then the labels/values
    are moved to ``device``.
    """
    predicting_on_histopathology = self.trainer.predicting and self.params[
        "modality"
    ] in ["path", "histo"]
    if predicting_on_histopathology:
        return batch

    batch = self._move_image_data_to_device(batch, device)
    return self._move_labels_or_values_to_device(batch, device)
2087
+
2088
+ def _move_image_data_to_device(self, subject, device):
2089
+ for channel_key in self.params["channel_keys"]:
2090
+ subject[channel_key][torchio.DATA] = subject[channel_key][torchio.DATA].to(
2091
+ device
2092
+ )
2093
+ return subject
2094
+
2095
+ def _move_labels_or_values_to_device(self, subject, device):
2096
+ if "value_keys" in self.params:
2097
+ for value_key in self.params["value_keys"]:
2098
+ subject[value_key] = subject[value_key].to(device)
2099
+ elif subject["label"] != ["NA"]:
2100
+ subject["label"][torchio.DATA] = subject["label"][torchio.DATA].to(device)
2101
+
2102
+ return subject