GANDLF 0.1.3.dev20250202__py3-none-any.whl → 0.1.6.dev20251109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of GANDLF might be problematic. Click here for more details.
- GANDLF/cli/deploy.py +2 -2
- GANDLF/cli/generate_metrics.py +35 -1
- GANDLF/cli/main_run.py +4 -10
- GANDLF/compute/__init__.py +0 -2
- GANDLF/compute/forward_pass.py +0 -1
- GANDLF/compute/generic.py +107 -2
- GANDLF/compute/inference_loop.py +4 -4
- GANDLF/compute/loss_and_metric.py +1 -2
- GANDLF/compute/training_loop.py +10 -10
- GANDLF/config_manager.py +33 -717
- GANDLF/configuration/__init__.py +0 -0
- GANDLF/configuration/default_config.py +73 -0
- GANDLF/configuration/differential_privacy_config.py +16 -0
- GANDLF/configuration/exclude_parameters.py +1 -0
- GANDLF/configuration/model_config.py +82 -0
- GANDLF/configuration/nested_training_config.py +25 -0
- GANDLF/configuration/optimizer_config.py +121 -0
- GANDLF/configuration/parameters_config.py +10 -0
- GANDLF/configuration/patch_sampler_config.py +11 -0
- GANDLF/configuration/post_processing_config.py +10 -0
- GANDLF/configuration/pre_processing_config.py +94 -0
- GANDLF/configuration/scheduler_config.py +92 -0
- GANDLF/configuration/user_defined_config.py +131 -0
- GANDLF/configuration/utils.py +96 -0
- GANDLF/configuration/validators.py +479 -0
- GANDLF/data/__init__.py +14 -16
- GANDLF/data/lightning_datamodule.py +119 -0
- GANDLF/entrypoints/run.py +36 -31
- GANDLF/inference_manager.py +69 -25
- GANDLF/losses/__init__.py +23 -1
- GANDLF/losses/loss_calculators.py +79 -0
- GANDLF/losses/segmentation.py +3 -2
- GANDLF/metrics/__init__.py +26 -0
- GANDLF/metrics/generic.py +1 -1
- GANDLF/metrics/metric_calculators.py +102 -0
- GANDLF/metrics/panoptica_config_brats.yaml +56 -0
- GANDLF/metrics/segmentation_panoptica.py +49 -0
- GANDLF/models/__init__.py +8 -3
- GANDLF/models/lightning_module.py +2102 -0
- GANDLF/optimizers/__init__.py +4 -8
- GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
- GANDLF/schedulers/__init__.py +11 -4
- GANDLF/schedulers/wrap_torch.py +15 -3
- GANDLF/training_manager.py +160 -50
- GANDLF/utils/__init__.py +5 -3
- GANDLF/utils/imaging.py +176 -35
- GANDLF/utils/modelio.py +12 -8
- GANDLF/utils/pred_target_processors.py +71 -0
- GANDLF/utils/tensor.py +2 -1
- GANDLF/utils/write_parse.py +1 -1
- GANDLF/version.py +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/METADATA +16 -11
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/RECORD +57 -34
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/WHEEL +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/entry_points.txt +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info/licenses}/LICENSE +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,2102 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import time
|
|
4
|
+
import psutil
|
|
5
|
+
import torch
|
|
6
|
+
import torchio
|
|
7
|
+
import warnings
|
|
8
|
+
import openslide
|
|
9
|
+
import numpy as np
|
|
10
|
+
import SimpleITK as sitk
|
|
11
|
+
import lightning.pytorch as pl
|
|
12
|
+
import torch.nn.functional as F
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
from medcam import medcam
|
|
16
|
+
from copy import deepcopy
|
|
17
|
+
from statistics import mean
|
|
18
|
+
from multiprocessing import Lock
|
|
19
|
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
20
|
+
from lightning.pytorch.strategies import DDPStrategy
|
|
21
|
+
from lightning.pytorch.utilities import rank_zero_only
|
|
22
|
+
|
|
23
|
+
from GANDLF.logger import Logger
|
|
24
|
+
from GANDLF.models import get_model
|
|
25
|
+
from GANDLF.metrics import overall_stats
|
|
26
|
+
from GANDLF.optimizers import get_optimizer
|
|
27
|
+
from GANDLF.schedulers import get_scheduler
|
|
28
|
+
from GANDLF.data.post_process import global_postprocessing_dict
|
|
29
|
+
from GANDLF.losses.loss_calculators import LossCalculatorFactory
|
|
30
|
+
from GANDLF.metrics.metric_calculators import MetricCalculatorFactory
|
|
31
|
+
from GANDLF.data.preprocessing import get_transforms_for_preprocessing
|
|
32
|
+
from GANDLF.data.inference_dataloader_histopath import InferTumorSegDataset
|
|
33
|
+
from GANDLF.utils.pred_target_processors import PredictionTargetProcessorFactory
|
|
34
|
+
from GANDLF.privacy.opacus.opacus_anonymization_manager import (
|
|
35
|
+
OpacusAnonymizationManager,
|
|
36
|
+
)
|
|
37
|
+
from GANDLF.privacy.opacus import handle_dynamic_batch_size
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
from GANDLF.utils import (
|
|
41
|
+
optimize_and_save_model,
|
|
42
|
+
write_training_patches,
|
|
43
|
+
one_hot,
|
|
44
|
+
reverse_one_hot,
|
|
45
|
+
print_model_summary,
|
|
46
|
+
get_date_time,
|
|
47
|
+
save_model,
|
|
48
|
+
load_model,
|
|
49
|
+
version_check,
|
|
50
|
+
get_filename_extension_sanitized,
|
|
51
|
+
resample_image,
|
|
52
|
+
BEST_MODEL_PATH_END,
|
|
53
|
+
INITIAL_MODEL_PATH_END,
|
|
54
|
+
LATEST_MODEL_PATH_END,
|
|
55
|
+
MapSaver,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
from typing import Tuple, Union, Dict, List, Any
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class GandlfLightningModule(pl.LightningModule):
    """
    Lightning module wrapping a GaNDLF model.

    The constructor builds the model, the loss calculator, the metric
    calculators and the prediction/target processor from the GaNDLF
    ``params`` dictionary; the training hooks accumulate per-batch losses,
    metrics, predictions and labels and aggregate them at epoch end.
    """

    # CSV header for per-subject prediction rows (Epoch, SubjectID, PredictedValue).
    # Not used in the visible part of the class — presumably consumed by the
    # validation/inference code further down; verify.
    CLASSIFICATION_REGRESSION_RESULTS_HEADER = "Epoch,SubjectID,PredictedValue\n"
    # Histopathology variant of the header above (note: no trailing newline);
    # usage not visible here.
    CLASSIFICATION_REGRESSION_RESULTS_HEADER_HISTOPATH = "SubjectID,x_coords,y_coords"
    # Number of decimal places used by _round_value_to_precision for logged values.
    FLOAT_FORMATTING_PRECISION = 4
    # Class-level multiprocessing lock, created once at class-definition time
    # and shared by all instances in a process; guards the memory-usage CSV
    # writes. NOTE(review): whether it is shared across DDP worker processes
    # depends on how they are started — not established here.
    MULTIPROCESSING_LOCK = Lock()
|
|
66
|
+
|
|
67
|
+
def __init__(self, params: dict, output_dir: str):
    """
    Build the Lightning module from a GaNDLF configuration.

    Args:
        params (dict): The GaNDLF parameters. A deep copy is stored on the
            module, so later mutations of ``self.params`` do not affect the
            caller's dictionary.
        output_dir (str): Directory where checkpoints and logs are written.
    """
    super().__init__()
    self.output_dir = output_dir
    self.params = deepcopy(params)
    self.learning_rate = self.params["learning_rate"]

    problem_type = params["problem_type"]
    self._problem_type_is_regression = problem_type == "regression"
    self._problem_type_is_classification = problem_type == "classification"
    self._problem_type_is_segmentation = problem_type == "segmentation"

    # Order matters only in that every component is built from self.params.
    for build_component in (
        self._initialize_model,
        self._initialize_loss,
        self._initialize_metric_calculators,
        self._initialize_preds_target_processor,
        self._initialize_model_save_paths,
    ):
        build_component()
|
|
82
|
+
|
|
83
|
+
def _initialize_model(self):
    """
    Instantiate the network described by ``self.params`` via ``get_model``
    and store it on ``self.model``.
    """
    network = get_model(self.params)
    self.model = network
|
|
89
|
+
|
|
90
|
+
def _initialize_loss(self):
    """
    Build the loss calculator from the parameters and store it on ``self.loss``.

    Construction is delegated to ``LossCalculatorFactory``; the loss logic
    differs for some specific model architectures (see that factory).
    """
    loss_factory = LossCalculatorFactory(self.params)
    self.loss = loss_factory.get_loss_calculator()
|
|
98
|
+
|
|
99
|
+
def _initialize_metric_calculators(self):
    """
    Build the metric calculator(s) from the parameters and store them on
    ``self.metric_calculators``.

    Construction is delegated to ``MetricCalculatorFactory``; the metric logic
    differs for some specific model architectures (see that factory).
    """
    calculator_factory = MetricCalculatorFactory(self.params)
    self.metric_calculators = calculator_factory.get_metric_calculator()
|
|
109
|
+
|
|
110
|
+
def _initialize_preds_target_processor(self):
    """
    Build the prediction/target processor and store it on
    ``self.pred_target_processor``.

    The processor brings model outputs and targets into the format the loss
    and metrics expect; some architectures need a different layout, which is
    why construction goes through ``PredictionTargetProcessorFactory``.
    """
    processor_factory = PredictionTargetProcessorFactory(self.params)
    self.pred_target_processor = processor_factory.get_prediction_target_processor()
|
|
119
|
+
|
|
120
|
+
def _initialize_model_save_paths(self):
    """
    Compute the checkpoint file paths ("best", "initial", "latest").

    Each path is ``<output_dir>/<architecture><suffix>``, where the suffix is
    the matching ``*_MODEL_PATH_END`` constant from ``GANDLF.utils``.
    """
    architecture = self.params["model"]["architecture"]
    suffix_for_kind = {
        "best": BEST_MODEL_PATH_END,
        "initial": INITIAL_MODEL_PATH_END,
        "latest": LATEST_MODEL_PATH_END,
    }
    self.model_paths = {
        kind: os.path.join(self.output_dir, architecture + suffix)
        for kind, suffix in suffix_for_kind.items()
    }
|
|
139
|
+
|
|
140
|
+
@rank_zero_only
def _save_model(self, epoch: int, save_path: str, onnx_export: bool):
    """
    Save a GaNDLF-format checkpoint of the model (rank-zero process only).

    The checkpoint holds the epoch, the model and optimizer state dicts and
    ``self.current_best_loss`` (which, in the visible code, is initialized in
    ``on_train_start``).

    Args:
        epoch (int): The epoch number recorded in the checkpoint.
        save_path (str): Destination path of the checkpoint.
        onnx_export (bool): Whether to also export the model to ONNX format.
    """
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": self.model.state_dict(),
        # unwrap Lightning's optimizer wrapper to get the torch optimizer state
        "optimizer_state_dict": self.optimizers().optimizer.state_dict(),
        "loss": self.current_best_loss,
    }
    save_model(
        checkpoint,
        model=self.model,
        params=self.params,
        path=save_path,
        onnx_export=onnx_export,
    )
|
|
162
|
+
|
|
163
|
+
def _prepare_metrics_dict_for_progbar_logging(
    self, metric_results_dict: Dict[str, float]
):
    """
    Turn raw metric results into a dict suitable for Lightning's progress bar.

    Three steps are applied: keys are prefixed with the current trainer stage,
    per-class list values are split into one key per class, and every value is
    rounded to ``FLOAT_FORMATTING_PRECISION`` decimals.

    Args:
        metric_results_dict (Dict[str, float]): The metric results to format.

    Returns:
        Dict[str, float]: The formatted results.
    """
    stage = self._determine_trainer_stage_string()
    prefixed = self._add_stage_prefix_to_metric_results_dict(
        metric_results_dict, stage
    )
    flattened = self._convert_per_class_metric_results_to_separate_key_value_pairs(
        prefixed
    )
    return self._round_metric_values_in_dict(flattened)
|
|
186
|
+
|
|
187
|
+
@staticmethod
def _convert_per_class_metric_results_to_separate_key_value_pairs(
    metric_results_dict: Dict[str, Any]
) -> Dict[str, float]:
    """
    Split list-valued (per-class) metrics into one key per class.

    A metric ``name`` whose value is a list is replaced by the keys
    ``name_class_0``, ``name_class_1``, ... holding the individual entries;
    those new keys are appended after the existing ones. Non-list values are
    kept (deep-copied). The input dictionary is not modified.

    Args:
        metric_results_dict (Dict[str, Any]): The metric results.

    Returns:
        Dict[str, float]: The flattened results.
    """
    flattened = deepcopy(metric_results_dict)
    for name, value in metric_results_dict.items():
        if not isinstance(value, list):
            continue
        flattened.update(
            (f"{name}_class_{class_index}", class_value)
            for class_index, class_value in enumerate(value)
        )
        flattened.pop(name)
    return flattened
|
|
211
|
+
|
|
212
|
+
@staticmethod
def _add_stage_prefix_to_metric_results_dict(
    metric_results_dict: Dict[str, float], stage: str
):
    """
    Return a new dict whose keys are the metric names prefixed with the stage.

    Args:
        metric_results_dict (Dict[str, float]): The metric results.
        stage (str): Stage label (e.g. the trainer stage string); each key
            becomes ``f"{stage}_{name}"``.

    Returns:
        Dict[str, float]: The results with prefixed keys; values are unchanged.
    """
    prefix = f"{stage}_"
    return {f"{prefix}{name}": value for name, value in metric_results_dict.items()}
|
|
224
|
+
|
|
225
|
+
def _round_metric_values_in_dict(self, metric_results_dict: Dict[str, float]):
    """
    Round every value of the metric results via ``_round_value_to_precision``.

    Args:
        metric_results_dict (Dict[str, float]): The metric results.

    Returns:
        Dict[str, float]: A new dict with the same keys and rounded values.
    """
    rounded_results = {}
    for metric_name, metric_value in metric_results_dict.items():
        rounded_results[metric_name] = self._round_value_to_precision(metric_value)
    return rounded_results
|
|
239
|
+
|
|
240
|
+
def _round_value_to_precision(self, value: float):
    """
    Round ``value`` to ``FLOAT_FORMATTING_PRECISION`` decimal places.

    Args:
        value (float): The value to round.

    Returns:
        float: The rounded value (the result of the built-in ``round``).
    """
    digits = self.FLOAT_FORMATTING_PRECISION
    return round(value, digits)
|
|
252
|
+
|
|
253
|
+
def forward(
    self, images: torch.Tensor
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
    """
    Run the wrapped model on a batch of images.

    When ``params["medcam_enabled"]`` is truthy, the model is expected to
    return ``(output, attention_map)``; for 2D models a trailing singleton
    axis is appended to the attention map. Otherwise the model output is
    returned together with ``None``.

    Args:
        images (torch.Tensor): The input batch.

    Returns:
        Tuple[torch.Tensor, Optional[torch.Tensor]]: The model output and the
        attention map (or ``None`` when medcam is disabled).
    """
    if not self.params.get("medcam_enabled", False):
        return self.model(images), None

    output, attention_map = self.model(images)
    if self.params["model"]["dimension"] == 2:
        # presumably restores the singleton depth axis used by the 3D layout
        attention_map = torch.unsqueeze(attention_map, -1)
    return output, attention_map
|
|
268
|
+
|
|
269
|
+
def on_train_start(self):
    """
    Lightning hook run once at the start of fitting.

    Prints setup information, records the training start time, creates the
    training CSV logger and the per-epoch containers, and resets the
    early-stopping counter and the best-loss tracker. If configured, it also
    injects the medcam module (kept disabled) and enables differential privacy.
    """
    self._print_training_initialization_info()
    self._set_training_start_time()
    self._print_channels_info()
    self._initialize_train_logger()
    self._initialize_training_epoch_containers()

    self.wait_count_before_early_stopping = 0
    # start from the largest float so the first measured loss becomes the best
    self.current_best_loss = sys.float_info.max

    params = self.params
    params["current_epoch"] = self.current_epoch
    # TODO check out if the disabled by default medcam is indeed what we
    # meant - it was taken from original code
    if params.get("medcam"):
        self._inject_medcam_module()
        params["medcam_enabled"] = False
    if params.get("differential_privacy"):
        self._initialize_training_differential_privacy()
|
|
286
|
+
|
|
287
|
+
def _try_to_load_model_training_start(self):
    """
    Try to resume from a previous checkpoint at the start of training.

    The "best" checkpoint is tried first, then the "latest" one; the first
    successful load is reported. If neither loads, a message says training
    starts from scratch.
    """
    for checkpoint_kind in ("best", "latest"):
        checkpoint_path = self.model_paths[checkpoint_kind]
        if self._try_to_load_model(checkpoint_path):
            print(f"Previous {checkpoint_kind} model loaded from {checkpoint_path}.")
            return
    print("Could not load any previous model, training from scratch.", flush=True)
|
|
299
|
+
|
|
300
|
+
def _try_to_load_model(self, load_path: str):
    """
    Attempts to load the model from the specified path.

    If the file exists, the checkpoint is loaded with ``load_model`` and its
    version is checked against ``params["version"]``. The model state dict is
    then restored. While the trainer is fitting (``self.trainer.training``),
    this also restores:
        - the optimizer state;
        - the fit loop's completed-epoch counter;
        - ``callback_metrics["val_loss"]``, from the checkpoint's stored loss.

    Any exception during loading is turned into a warning, and the method
    returns False.

    Args:
        load_path (str): The path to the model to load.

    Returns:
        bool: Whether the model was successfully loaded.
    """
    if os.path.exists(load_path):
        try:
            checkpoint_dict = load_model(load_path, self.device)
            version_check(
                self.params["version"], version_to_check=checkpoint_dict["version"]
            )
            # I am purposefully omitting the line below, as "previous_parameters" are not used anywhere
            # params["previous_parameters"] = main_dict.get("parameters", None)
            state_dict = checkpoint_dict["model_state_dict"]
            if self.params.get("differential_privacy"):
                # this is required for torch==1.11 and for DP inference
                # Presumably strips the "_module." infix added when the model was
                # wrapped by Opacus. NOTE(review): str.replace removes every
                # occurrence in the key, not only a leading prefix.
                new_state_dict = {}
                for key, val in state_dict.items():
                    new_key = key.replace("_module.", "")
                    new_state_dict[new_key] = val  # remove `module.`
                state_dict = new_state_dict

            self.model.load_state_dict(state_dict)
            # NOTE(review): if anything below raises, the model weights have
            # already been replaced although the method returns False.
            if self.trainer.training:
                # optimizers(False) == optimizers(use_pl_optimizer=False): the raw torch optimizer
                self.optimizers(False).load_state_dict(
                    checkpoint_dict["optimizer_state_dict"]
                )
                # resume epoch counting from the checkpointed epoch
                self.trainer.fit_loop.epoch_progress.current.completed = (
                    checkpoint_dict["epoch"]
                )
                # The stored "loss" is self.current_best_loss (see _save_model).
                # NOTE(review): it may be a plain float while callback_metrics
                # normally holds tensors — confirm consumers accept it.
                self.trainer.callback_metrics["val_loss"] = checkpoint_dict["loss"]
            return True
        except Exception as e:
            # deliberate best-effort: a broken checkpoint must not abort training
            warnings.warn(
                f"Model found under path {load_path}, but error occurred during loading: {e}"
            )
    return False
|
|
342
|
+
|
|
343
|
+
@rank_zero_only
def _try_to_save_initial_model(self):
    """
    Save the untrained ("initial") model unless a file already exists at that
    path (rank-zero process only).
    """
    initial_path = self.model_paths["initial"]
    if os.path.exists(initial_path):
        print(f"Initial model already exists at {initial_path}; Skipping saving")
        return
    self._save_model(self.current_epoch, initial_path, False)
    print(f"Initial model saved at {initial_path}")
|
|
355
|
+
|
|
356
|
+
def _inject_medcam_module(self):
    """
    Wrap ``self.model`` with medcam so it can return attention maps.

    The backend and layer come from ``params["medcam"]``. Maps are not saved
    to disk (``save_maps=False``), and the injection starts disabled
    (``enabled=False``). The output directory is set to
    ``<output_dir>/attention_maps/<backend>``.
    """
    medcam_config = self.params["medcam"]
    backend = medcam_config["backend"]
    maps_dir = os.path.join(self.output_dir, "attention_maps", backend)
    self.model = medcam.inject(
        self.model,
        output_dir=maps_dir,
        backend=backend,
        layer=medcam_config["layer"],
        save_maps=False,
        return_attention=True,
        enabled=False,
    )
|
|
371
|
+
|
|
372
|
+
def _get_metrics_names_for_loggers(self):
    """
    Return the metric names for the CSV loggers.

    Starts from ``params["metrics"]``. For classification and regression it
    then appends the names of the overall statistics that ``overall_stats``
    reports. Those names are discovered by calling ``overall_stats`` on small
    placeholder tensors; only the keys of the result are used.

    Returns:
        list: Configured metric names followed by any new overall-metric names.
    """
    names = list(self.params["metrics"])
    if self._problem_type_is_regression:
        overall = overall_stats(torch.tensor([1]), torch.tensor([1]), self.params)
    elif self._problem_type_is_classification:
        placeholder = torch.randint(
            0, self.params["model"]["num_classes"], (5,), dtype=torch.int32
        )
        overall = overall_stats(placeholder, placeholder, self.params)
    else:
        overall = {}

    for overall_name in overall.keys():
        if overall_name in names:
            continue
        names.append(overall_name)
    return names
|
|
392
|
+
|
|
393
|
+
@rank_zero_only
def _initialize_train_logger(self):
    """
    Create the training CSV logger at ``<output_dir>/logs_training.csv``
    (rank-zero process only; other ranks have no ``train_logger``).
    """
    csv_path = os.path.join(self.output_dir, "logs_training.csv")
    logged_metrics = self._get_metrics_names_for_loggers()
    self.train_logger = Logger(
        logger_csv_filename=csv_path,
        metrics=logged_metrics,
        mode="train",
    )
|
|
400
|
+
|
|
401
|
+
@rank_zero_only
def _set_training_start_time(self):
    """Record the wall-clock start of training (rank-zero process only)."""
    started_at = time.time()
    self.training_start_time = started_at
|
|
404
|
+
|
|
405
|
+
@rank_zero_only
def _print_training_initialization_info(self):
    """
    Print basic information at the start of training (rank-zero process only).

    It always prints the host information. With ``params["verbose"]`` it
    also prints the start timestamp. With
    ``params["model"]["print_summary"]`` it also prints the model summary.
    """
    self._print_host_info()
    verbose = self.params["verbose"]
    if verbose:
        print("Initializing training at :", get_date_time(), flush=True)
    wants_summary = self.params["model"]["print_summary"]
    if wants_summary:
        self._print_model_summary()
|
|
415
|
+
|
|
416
|
+
def _print_host_info(self):
    """Print the ``HOSTNAME`` environment variable when it is set and non-empty."""
    hostname = os.environ.get("HOSTNAME")
    if not hostname:
        return
    print("Hostname :", hostname, flush=True)
|
|
419
|
+
|
|
420
|
+
def _print_model_summary(self):
    """
    Print a summary of ``self.model``.

    The configured batch size, number of input channels and patch size are
    passed to ``print_model_summary``.
    """
    params = self.params
    print_model_summary(
        self.model,
        params["batch_size"],
        params["model"]["num_channels"],
        params["patch_size"],
    )
|
|
427
|
+
|
|
428
|
+
def _initialize_training_epoch_containers(self):
    """
    Initializes the containers for storing the training epoch data.

    They are used for accumulating the losses, metrics, predictions and labels
    for each epoch, so final calculations can be made at the end of the epoch.

    This must run on every process, not only on rank zero. ``training_step``
    executes on every rank and appends to these containers unconditionally,
    so skipping the initialization on non-zero ranks makes them fail with an
    AttributeError in distributed training.
    """
    self.train_losses: List[torch.Tensor] = []
    self.training_metric_values: List[Dict[str, float]] = []
    if self._problem_type_is_regression or self._problem_type_is_classification:
        # predictions/labels are only needed for the epoch-level overall_stats
        # computed for classification/regression
        self.train_predictions: List[torch.Tensor] = []
        self.train_labels: List[torch.Tensor] = []
|
|
441
|
+
|
|
442
|
+
@rank_zero_only
def _print_channels_info(self):
    """Print the number of model input channels (rank-zero process only)."""
    channel_count = self.params["model"]["num_channels"]
    print("Number of channels : ", channel_count)
|
|
445
|
+
|
|
446
|
+
def training_step(self, subject, batch_idx):
    """
    Single training optimization step.

    Builds the image and label tensors from the batch and runs the forward
    pass, then computes the loss and metrics. Detached copies of the results
    are stored in the epoch containers for aggregation at epoch end.

    Args:
        subject: The batch from the training dataloader. It is indexed by the
            channel/label keys; ``subject.get("spacing")`` is forwarded to
            the metric calculators.
        batch_idx (int): Index of the batch within the epoch (unused).

    Returns:
        torch.Tensor: The loss that Lightning back-propagates.
    """
    if self.params.get("save_training"):
        write_training_patches(subject, self.params)

    if self.params.get("differential_privacy"):
        # The helper returns the adjusted batch; keep it, otherwise the
        # dynamic-batch-size handling is lost whenever the helper does not
        # work purely in place.
        subject = self._handle_dynamic_batch_size_in_differential_privacy_mode(
            subject
        )

    images = self._prepare_images_batch_from_subject_data(subject)
    labels = self._prepare_labels_batch_from_subject_data(subject)

    images = self._ensure_proper_images_tensor_dimensions(images)
    labels = self._process_labels(labels)
    model_output, _ = self.forward(images)
    model_output, labels = self.pred_target_processor(model_output, labels)

    loss = self.loss(model_output, labels, images)
    metric_results = self.metric_calculators(
        model_output, labels, subject_spacing=subject.get("spacing", None)
    )

    if self._problem_type_is_regression or self._problem_type_is_classification:
        self.train_labels.append(labels.detach().cpu())
        # NOTE(review): argmax is also applied for regression outputs, as in
        # the original GaNDLF loop — confirm this is intended for regression.
        self.train_predictions.append(
            torch.argmax(model_output, dim=1).detach().cpu()
        )

    self.train_losses.append(loss.detach().cpu())
    self.training_metric_values.append(metric_results)

    return loss
|
|
479
|
+
|
|
480
|
+
def _prepare_images_batch_from_subject_data(self, subject_data: torchio.Subject):
    """
    Stack the image channels of a batch into one tensor.

    The ``torchio.DATA`` tensor of every key in ``params["channel_keys"]`` is
    concatenated along dim 1 (the channel axis), in key order.

    Args:
        subject_data (torchio.Subject): The batch of subjects or extracted
            patches.

    Returns:
        torch.Tensor: The concatenated images, shape (B, C, H, W, D).
    """
    channel_tensors = [
        subject_data[channel_key][torchio.DATA]
        for channel_key in self.params["channel_keys"]
    ]
    return torch.cat(channel_tensors, dim=1)
|
|
498
|
+
|
|
499
|
+
def _prepare_labels_batch_from_subject_data(self, subject: torchio.Subject):
    """
    Creates the label tensor from the subject data.

    - Classification/regression: the values stored under each key of
      ``params["value_keys"]`` are concatenated along dim 0. The result is
      reshaped to ``(min(batch_size, len(label)), len(value_keys))``.
    - Segmentation: ``subject["label"][torchio.DATA]`` is returned unchanged.

    Args:
        subject (torchio.Subject): The torchio.Subject object containing the label.

    Returns:
        label (torch.Tensor): The label tensor of shape (B, C, H, W, D) for segmentation,
        or a 2-D tensor of shape (min(batch_size, len(label)), len(value_keys))
        for classification/regression.
    """

    if self._problem_type_is_regression or self._problem_type_is_classification:
        label = torch.cat(
            [subject[key] for key in self.params["value_keys"]], dim=0
        )
        # TODO this for sure needs some further investigation
        # min is needed because for certain cases, batch size becomes smaller than the total remaining labels
        # NOTE(review): assuming each subject[key] is a 1-D tensor of length B
        # (verify), two problems arise with k > 1 value keys:
        # - The concatenation is key-major (all of key 1, then all of key 2),
        #   so reshape(B, k) mixes keys within rows instead of giving one row
        #   per sample.
        # - When B < batch_size, min(batch_size, B * k) can exceed B and make
        #   the reshape fail.
        # With a single value key both are harmless.
        label = label.reshape(
            min(self.params["batch_size"], len(label)),
            len(self.params["value_keys"]),
        )
    else:
        label = subject["label"][torchio.DATA]

    return label
|
|
525
|
+
|
|
526
|
+
def _ensure_proper_images_tensor_dimensions(self, images: torch.Tensor):
    """
    Drop the trailing singleton depth axis that torchio adds to 2D images.

    For 2D models the last dimension is squeezed (a no-op if it is not of
    size 1); for other models the tensor is returned unchanged.

    Args:
        images (torch.Tensor): The input images tensor.

    Returns:
        torch.Tensor: The possibly squeezed images tensor.
    """
    is_2d_model = self.params["model"]["dimension"] == 2
    return images.squeeze(-1) if is_2d_model else images
|
|
542
|
+
|
|
543
|
+
def _process_labels(self, labels: torch.Tensor):
    """
    Bring the label tensor into the layout expected for the problem type.

    Steps, in order:
    - Segmentation with a 3-channel (RGB) label: keep only the first channel
      and emit a warning.
    - Any label tensor that is not 1-D: squeeze the last axis. This removes
      the depth axis of 2D segmentation labels, or flattens
      classification/regression labels. 1-D labels are left alone so that a
      batch of size 1 does not collapse to a 0-d tensor.
    - Segmentation: one-hot encode with ``params["model"]["class_list"]``.

    Args:
        labels (torch.Tensor): The raw label batch.

    Returns:
        torch.Tensor: The processed labels.
    """
    segmentation = self._problem_type_is_segmentation

    if segmentation and labels.shape[1] == 3:
        labels = labels[:, 0, ...].unsqueeze(1)
        warnings.warn(
            "The label image is an RGB image, only the first channel will be used."
        )

    # TODO: the 1-D guard is a crutch - labels may arrive as (B,), and for
    # batch size 1 squeezing would produce a 0-d tensor that breaks later logic
    if labels.dim() != 1:
        labels = labels.squeeze(-1)

    if segmentation:
        labels = one_hot(labels, self.params["model"]["class_list"])

    return labels
|
|
566
|
+
|
|
567
|
+
def _handle_dynamic_batch_size_in_differential_privacy_mode(self, subject):
    """
    Adapt a training batch for differential-privacy mode using the Opacus
    helper ``handle_dynamic_batch_size``.

    Args:
        subject: The training batch.

    Returns:
        The batch returned by the helper; its second return value is discarded.
    """
    adjusted_subject, _discarded = handle_dynamic_batch_size(subject, self.params)
    return adjusted_subject
|
|
570
|
+
|
|
571
|
+
def _initialize_training_differential_privacy(self):
    """
    Enable Opacus differential privacy for training.

    First checks that the trainer's strategy supports it. Then
    ``OpacusAnonymizationManager.apply_privacy`` wraps the model, the current
    optimizer and the training dataloader. The returned objects are swapped
    into this module and the trainer, and the privacy engine is kept on
    ``self._dp_engine``.

    Raises:
        NotImplementedError: propagated from ``_check_if_opacus_is_applicable``
            when the trainer uses a DDP strategy.
    """
    self._check_if_opacus_is_applicable()
    opacus_manager = OpacusAnonymizationManager(self.params)

    (
        model,
        dp_optimizer,
        train_dataloader,
        privacy_engine,
    ) = opacus_manager.apply_privacy(
        self.model, self.optimizers().optimizer, self.trainer.train_dataloader
    )
    self.model = model
    # NOTE(review): `fit_loop._data_source` is a private Lightning attribute;
    # this replacement may break across Lightning versions — re-check on upgrade.
    self.trainer.fit_loop._data_source.instance = train_dataloader
    self.trainer.optimizers = [dp_optimizer]
    # TODO should we reinit the scheduler too?
    # NOTE(review): a scheduler configured for this module was presumably
    # built against the original optimizer, not dp_optimizer.
    self._dp_engine = privacy_engine
|
|
588
|
+
|
|
589
|
+
def _check_if_opacus_is_applicable(self):
    """
    Reject configurations in which differential privacy cannot be used.

    Raises:
        NotImplementedError: if the trainer runs with a DDP strategy.
    """
    uses_ddp = isinstance(self.trainer.strategy, DDPStrategy)
    if not uses_ddp:
        return
    raise NotImplementedError(
        "Differential privacy is not supported with DDP strategy. Please use single GPU."
    )
|
|
594
|
+
|
|
595
|
+
def on_train_epoch_start(self):
    """
    Lightning hook run at the start of every training epoch.

    Records the epoch start time. Optionally, it also writes the memory usage
    (``params["track_memory_usage"]``) and prints the start timestamp
    (``params["verbose"]``).
    """
    self._set_epoch_start_time()
    params = self.params
    if params["track_memory_usage"]:
        self._write_epoch_start_process_resource_usage(self.current_epoch)
    if params["verbose"]:
        self._print_epoch_start_time()
|
|
601
|
+
|
|
602
|
+
def _write_epoch_start_process_resource_usage(self, epoch):
    """
    Append this process's memory usage at the start of an epoch to a CSV file.

    Runs separately on each process in distributed training. Each process
    writes its own file,
    ``<output_dir>/memory_stats/memory_usage_local_rank_<l>_global_rank_<g>.csv``.
    A row holds the epoch and host memory statistics from
    ``psutil.virtual_memory()``. On CUDA devices it also holds the
    ``active.all.*`` counters from ``torch.cuda.memory_stats``. The header
    line is written only when the file is created.

    Args:
        epoch (int): The current epoch number.
    """
    filename = f"memory_usage_local_rank_{self.local_rank}_global_rank_{self.global_rank}.csv"
    # Every rank needs this directory. Create it here instead of through a
    # rank-zero-only helper, which returns None on the other ranks.
    memory_stats_dir = os.path.join(self.output_dir, "memory_stats")
    os.makedirs(memory_stats_dir, exist_ok=True)
    full_filepath = os.path.join(memory_stats_dir, filename)
    using_cuda = "cuda" in self.device.type

    header = "Epoch,Memory_Total,Memory_Available,Memory_Percent_Free,Memory_Usage,"
    if using_cuda:
        header += "CUDA_active.all.peak,CUDA_active.all.current,CUDA_active.all.allocated"
    header += "\n"

    host_memory_stats = psutil.virtual_memory()
    # NOTE(review): psutil's `percent` is the *used* percentage although the
    # column is named Memory_Percent_Free; the name is kept for CSV compatibility.
    row_values = [
        epoch,
        host_memory_stats.total,
        host_memory_stats.available,
        host_memory_stats.percent,
        host_memory_stats.used,
    ]
    if using_cuda:
        # query the device this module runs on, not whatever device is current
        cuda_memory_stats = torch.cuda.memory_stats(self.device)
        row_values += [
            cuda_memory_stats["active.all.peak"],
            cuda_memory_stats["active.all.current"],
            cuda_memory_stats["active.all.allocated"],
        ]
    row = ",".join(str(value) for value in row_values) + ",\n"

    # TODO evaluate if this indeed works properly in distributed setting
    # `with` guarantees the lock is released even if the write fails.
    with self.MULTIPROCESSING_LOCK:
        write_header = not os.path.exists(full_filepath)
        with open(full_filepath, "a") as file_mem:
            if write_header:
                file_mem.write(header)
            file_mem.write(row)
|
|
652
|
+
|
|
653
|
+
def _prepare_memory_stats_save_dir(self):
    """
    Create (if needed) and return ``<output_dir>/memory_stats``.

    Not restricted to rank zero. Memory statistics are written by every
    process, and a rank-zero-only version returns None on the other ranks,
    which breaks the path construction there. ``exist_ok=True`` makes
    concurrent creation by several processes safe.

    Returns:
        str: The memory-statistics directory path.
    """
    memory_stats_dir = os.path.join(self.output_dir, "memory_stats")
    os.makedirs(memory_stats_dir, exist_ok=True)
    return memory_stats_dir
|
|
658
|
+
|
|
659
|
+
@rank_zero_only
def _print_epoch_start_time(self):
    """Print the epoch start timestamp (rank-zero process only)."""
    timestamp = get_date_time()
    print("Epoch start time : ", timestamp, flush=True)
|
|
662
|
+
|
|
663
|
+
@rank_zero_only
def _set_epoch_start_time(self):
    """Record the wall-clock start of the current epoch (rank-zero process only)."""
    epoch_started_at = time.time()
    self.epoch_start_time = epoch_started_at
|
|
666
|
+
|
|
667
|
+
def on_train_epoch_end(self):
    """
    Aggregate the training results of the epoch and log them.

    The hook runs on every process:
    - The per-batch containers are filled by ``training_step`` on every rank,
      so they are aggregated and cleared on every rank.
    - ``self.log_dict(..., sync_dist=True)`` is a collective operation that
      all ranks must reach. Calling it from rank zero alone would block in
      multi-process runs.

    Work restricted to the global-zero process, where ``train_logger`` exists
    (its initializer is rank-zero-only):
    - writing the CSV log;
    - printing;
    - checkpointing.
    """
    epoch_metrics = {}
    metric_names = self.training_metric_values[0].keys()
    for metric_name in metric_names:
        metric_values = [x[metric_name] for x in self.training_metric_values]
        epoch_metrics[
            metric_name
        ] = self._compute_metric_mean_across_values_from_batches(metric_values)

    if self._problem_type_is_regression or self._problem_type_is_classification:
        training_epoch_average_metrics_overall = overall_stats(
            torch.cat(self.train_predictions),
            torch.cat(self.train_labels),
            self.params,
        )
        epoch_metrics.update(training_epoch_average_metrics_overall)
    mean_loss = self._round_value_to_precision(
        torch.mean(torch.stack(self.train_losses)).item()
    )

    # Must run on every rank so the containers do not grow across epochs.
    # NOTE(review): confirm _clear_training_epoch_containers is not itself
    # rank-zero-only.
    self._clear_training_epoch_containers()

    is_global_zero = self.trainer.is_global_zero
    if is_global_zero:
        self.train_logger.write(
            self.current_epoch,
            mean_loss,
            self._ensure_proper_metric_formatting_for_logging(epoch_metrics),
        )
    self.log("train_loss", mean_loss, on_epoch=True, prog_bar=True)
    # collective when sync_dist=True: every rank has to reach this call
    self.log_dict(
        self._prepare_metrics_dict_for_progbar_logging(epoch_metrics),
        on_epoch=True,
        prog_bar=True,
        sync_dist=True,
    )

    if not is_global_zero:
        return

    if self.params["verbose"]:
        self._print_epoch_end_time()
    if self.params["model"]["save_at_every_epoch"]:
        self._save_epoch_end_checkpoint()
    if os.path.exists(self.model_paths["latest"]):
        os.remove(self.model_paths["latest"])
    self._save_model(self.current_epoch, self.model_paths["latest"], False)

    print("Latest model saved")
|
|
713
|
+
|
|
714
|
+
def _compute_metric_mean_across_values_from_batches(
    self, metric_values: List[Union[float, List[float]]]
) -> Union[float, List[float]]:
    """
    Average a metric over all batches of an epoch.

    Scalar metrics are averaged and rounded with ``_round_value_to_precision``.
    Per-class metrics (one list per batch) are averaged position-by-position,
    using the length of the first batch's list; these are not rounded.

    Args:
        metric_values (List[Union[float, List[float]]]): One metric value per batch.

    Returns:
        Union[float, List[float]]: The rounded scalar mean, or the per-position means.
    """
    first_batch = metric_values[0]
    if not isinstance(first_batch, list):
        return self._round_value_to_precision(mean(metric_values))
    # NOTE(review): the per-class branch is not rounded, unlike the scalar branch.
    per_position_means = []
    for position in range(len(first_batch)):
        column = [batch_values[position] for batch_values in metric_values]
        per_position_means.append(mean(column))
    return per_position_means
|
|
733
|
+
|
|
734
|
+
@staticmethod
def _ensure_proper_metric_formatting_for_logging(metrics_dict: dict) -> dict:
    """
    Return a deep copy of ``metrics_dict`` in which every list value is flattened
    into an underscore-separated string, for GANDLF's CSV logging system.

    Brackets and spaces are removed from the list's ``str()`` form, and commas are
    replaced by underscores; e.g. ``[0.5, 0.25]`` becomes ``"0.5_0.25"``.
    Non-list values are copied unchanged. The input dict is not modified.

    Args:
        metrics_dict (dict): The dictionary containing the metric values.

    Returns:
        dict: The dictionary containing the formatted metric values.
    """
    formatted = deepcopy(metrics_dict)
    strip_brackets_and_spaces = str.maketrans("", "", "[] ")
    for name, value in metrics_dict.items():
        if not isinstance(value, list):
            continue
        formatted[name] = (
            str(value).translate(strip_brackets_and_spaces).replace(",", "_")
        )
    return formatted
|
|
758
|
+
|
|
759
|
+
@rank_zero_only
def _save_epoch_end_checkpoint(self):
    """
    Save the model to ``<output_dir>/<architecture>_epoch_<epoch>.pth.tar``.
    """
    epoch = self.current_epoch
    checkpoint_name = "".join(
        [self.params["model"]["architecture"], "_epoch_", str(epoch), ".pth.tar"]
    )
    checkpoint_path = os.path.join(self.output_dir, checkpoint_name)
    self._save_model(epoch, checkpoint_path, False)
    print("Epoch model saved.")
|
|
773
|
+
|
|
774
|
+
@rank_zero_only
def _print_epoch_end_time(self):
    """Print the minutes elapsed since ``self.epoch_start_time``."""
    elapsed_minutes = (time.time() - self.epoch_start_time) / 60
    print("Time taken for epoch : ", elapsed_minutes, " mins", flush=True)
|
|
782
|
+
|
|
783
|
+
@rank_zero_only
def _clear_training_epoch_containers(self):
    """
    Reset the per-epoch training accumulators.

    Losses and metric values are always reset. Predictions and labels are reset
    only for regression/classification problems.
    """
    self.train_losses = []
    self.training_metric_values = []
    if not (self._problem_type_is_regression or self._problem_type_is_classification):
        return
    self.train_predictions = []
    self.train_labels = []
|
|
790
|
+
|
|
791
|
+
@rank_zero_only
def on_train_end(self):
    """
    If a "best" checkpoint exists, re-save it via ``optimize_and_save_model`` with
    ONNX export disabled; then print the total training time.
    """
    best_model_path = self.model_paths["best"]
    if os.path.exists(best_model_path):
        # NOTE(review): upstream asks why the full save_model path is not used here,
        # and a (truncated) TODO says ONNX export seems to modify the model in place,
        # presumably the reason onnx_export=False.
        optimize_and_save_model(
            self.model, self.params, best_model_path, onnx_export=False
        )
    self._print_total_training_time()
|
|
800
|
+
|
|
801
|
+
@rank_zero_only
def _print_total_training_time(self):
    """Print the minutes elapsed since ``self.training_start_time``."""
    total_minutes = (time.time() - self.training_start_time) / 60
    print("Total time taken for training : ", total_minutes, " mins", flush=True)
|
|
809
|
+
|
|
810
|
+
@rank_zero_only
def on_validation_start(self):
    """Create fresh validation accumulators, then the validation CSV logger."""
    # Order kept: containers first, logger second.
    self._initialize_validation_epoch_containers()
    self._initialize_validation_logger()
|
|
814
|
+
|
|
815
|
+
@rank_zero_only
def _initialize_validation_epoch_containers(self):
    """
    Create the empty accumulators that ``validation_step`` fills.

    Losses and metric values are always created. Predictions/labels are created
    only for regression/classification. ``rows_to_write`` (CSV output rows) is
    created only when ``params["save_output"]`` is set.
    """
    self.val_losses: List[torch.Tensor] = []
    self.validation_metric_values: List[Dict[str, float]] = []

    tracks_scalar_predictions = (
        self._problem_type_is_regression or self._problem_type_is_classification
    )
    if tracks_scalar_predictions:
        self.val_predictions: List[float] = []
        self.val_labels: List[float] = []
    if self.params["save_output"]:
        self.rows_to_write: List[str] = []
|
|
824
|
+
|
|
825
|
+
@rank_zero_only
def _initialize_validation_logger(self):
    """
    Create ``self.val_logger``, writing to ``<output_dir>/logs_validation.csv``.

    An epsilon column is requested when differential privacy is configured.
    """
    csv_path = os.path.join(self.output_dir, "logs_validation.csv")
    metric_names = self._get_metrics_names_for_loggers()
    uses_differential_privacy = bool(self.params.get("differential_privacy"))
    self.val_logger = Logger(
        logger_csv_filename=csv_path,
        metrics=metric_names,
        mode="val",
        add_epsilon=uses_differential_privacy,
    )
|
|
833
|
+
|
|
834
|
+
@rank_zero_only
def on_validation_epoch_start(self):
    """
    Optionally enable medcam, then create
    ``<output_dir>/output_validation/epoch_<N>`` and store it as
    ``self._current_validation_epoch_save_dir``.
    """
    # TODO (carried over): this branch is believed to be dead code, both here and in
    # the original loops, because medcam is injected at training time with
    # params["medcam_enabled"] set to False.
    if self.params["medcam_enabled"]:
        self.model.enable_medcam()
        self.params["medcam_enabled"] = True

    epoch_folder = f"epoch_{self.current_epoch}"
    self._current_validation_epoch_save_dir = os.path.join(
        self.output_dir, "output_validation", epoch_folder
    )
    self._ensure_path_exists(self._current_validation_epoch_save_dir)
|
|
846
|
+
|
|
847
|
+
def validation_step(self, subject, batch_idx):
    """
    Predict on one validation subject and accumulate its loss, metrics and outputs.

    Regression/classification subjects that carry a label map are sampled with the
    label sampler. Every other case uses the grid sampler. Ground truth comes from
    the label map (segmentation) or from ``params["value_keys"]``
    (regression/classification).

    When ``params["save_output"]`` is set, the way predictions are saved depends
    on the problem type, not on the sampler used:
    - regression/classification: a CSV row is appended to ``rows_to_write``;
    - everything else: the prediction is saved as a segmentation mask.

    Args:
        subject: The (collated) subject batch yielded by the validation dataloader.
        batch_idx (int): Index of the batch; unused.
    """
    if self.params["verbose"]:
        self._print_currently_processed_subject(subject)

    # Evaluated once and combined with explicit parentheses. Writing
    # `regression or classification and x` parses as `regression or (classification and x)`.
    is_regression_or_classification = (
        self._problem_type_is_regression or self._problem_type_is_classification
    )
    subject_dict = self._initialize_subject_dict_nontraining_mode(subject)
    label_present = subject["label"] != ["NA"]
    value_keys_present = "value_keys" in self.params
    label = None
    if label_present:
        subject_dict = self._extend_nontraining_subject_dict_with_label(
            subject, subject_dict
        )

    if is_regression_or_classification and label_present:
        (
            model_output,
            last_input_batch,
        ) = self._get_predictions_on_subject_using_label_sampler(subject_dict)
        label = self._initialize_nontraining_label_ground_truth_classification_or_regression(
            subject
        )
    else:
        (
            model_output,
            last_input_batch,
        ) = self._get_predictions_on_subject_using_grid_sampler(subject_dict)
        if self._problem_type_is_segmentation and label_present:
            label = self._initialize_nontraining_label_ground_truth_segmentation(
                subject
            )
        elif is_regression_or_classification and value_keys_present:
            label = self._initialize_nontraining_label_ground_truth_classification_or_regression(
                subject
            )

    if self.params["save_output"]:
        # Saved before label processing, i.e. from the raw model output.
        if is_regression_or_classification:
            processed_logit = self._process_prediction_logit_for_row_writing(
                model_output, self.params["scaling_factor"]
            )
            self.rows_to_write.append(
                self._prepare_row_for_output_csv(
                    subject["subject_id"][0], processed_logit, self.current_epoch
                )
            )
        else:
            self._save_predictions_for_segmentation_subject(model_output, subject)

    if label is not None:
        label = self._process_labels(label)
        model_output, label = self.pred_target_processor(model_output, label)
        loss = self.loss(model_output, label, last_input_batch)
        metric_results = self.metric_calculators(
            model_output, label, subject_spacing=subject.get("spacing", None)
        )

        self.val_losses.append(loss)
        self.validation_metric_values.append(metric_results)

        # Test on `label is not None`, never on tensor truthiness: a 0-valued label
        # is falsy, and a multi-element tensor cannot be converted to bool.
        if is_regression_or_classification:
            # NOTE(review): argmax only for classification; regression keeps the raw
            # output (an upstream TODO questioned this).
            model_prediction = (
                torch.argmax(model_output[0], 0)
                if self._problem_type_is_classification
                else model_output[0]
            )
            self.val_predictions.append(model_prediction.item())
            self.val_labels.append(label.item())
|
|
928
|
+
|
|
929
|
+
@staticmethod
def _prepare_row_for_output_csv(
    subject_id: str, prediction_logit: float, epoch: int
):
    """
    Build one newline-terminated output CSV row.

    Args:
        subject_id (str): The subject ID.
        prediction_logit (float): The prediction logit.
        epoch (int): The epoch number.

    Returns:
        row (str): ``"<epoch>,<subject_id>,<prediction_logit>\\n"``.
    """
    fields = (epoch, subject_id, prediction_logit)
    return ",".join(f"{field}" for field in fields) + "\n"
|
|
946
|
+
|
|
947
|
+
@staticmethod
def _prepare_row_for_output_csv_histopathology_inference(
    subject_name, x_coord, y_coord, output_matrix
):
    """
    Build one newline-terminated output CSV row for histopathology inference.

    Args:
        subject_name (str): The subject name.
        x_coord (int): The x coordinate.
        y_coord (int): The y coordinate.
        output_matrix (np.array): Model output; one column is written per item
            obtained by iterating over it (per class).

    Returns:
        row (str): ``"<name>,<x>,<y>,<out_0>,<out_1>,...\\n"``.
    """
    fields = [subject_name, x_coord, y_coord, *output_matrix]
    return ",".join(f"{field}" for field in fields) + "\n"
|
|
968
|
+
|
|
969
|
+
@staticmethod
def _process_prediction_logit_for_row_writing(
    prediction_logit: torch.Tensor, scaling_factor: float = 1.0
):
    """
    Reduce prediction logits to a single float for the output CSV.

    The maximum element (computed on CPU) is divided by ``scaling_factor``.

    Args:
        prediction_logit (torch.Tensor): The prediction logits.
        scaling_factor (float): Divisor applied to the maximum logit. Default is 1
            (no scaling).

    Returns:
        prediction_logit (float): The processed prediction logit.
    """
    largest_logit = torch.max(prediction_logit.cpu()).item()
    return largest_logit / scaling_factor
|
|
985
|
+
|
|
986
|
+
def _print_currently_processed_subject(self, subject):
    """
    Print the ID of the subject being processed.

    Args:
        subject: Either a dataframe row given as an ``(index, row)`` tuple
            (histology inference corner case), or a mapping with a ``"subject_id"``
            entry. ``torchio.Subject`` is a ``dict`` subclass, so it falls in the
            second case, as does a plain dict (presumably what a collated
            DataLoader batch looks like). Previously a plain dict left
            ``subject_id`` unbound.
    """
    if isinstance(subject, tuple):
        # Ugly corner case for histology inference: the incoming batch is a row
        # from a dataframe, not a torchio.Subject. This should be solved via some
        # kind of polymorphism in the future.
        subject_data = subject[1]
        subject_id = subject_data[self.params["headers"]["subjectIDHeader"]]
    else:
        subject_id = subject["subject_id"]
    print("== Current subject:", subject_id, flush=True)
|
|
996
|
+
|
|
997
|
+
def _initialize_subject_dict_nontraining_mode(self, subject: torchio.Subject):
    """
    Create a dictionary containing the subject data for the non-training modes
    (validation, testing, inference).

    Each channel in ``params["channel_keys"]`` is rebuilt as a
    ``torchio.ScalarImage``, with the leading (batch) dimension squeezed from its
    data and affine. For regression/classification with ``params["value_keys"]``,
    each value is copied under the key ``"value_<key>"``.

    Args:
        subject (torchio.Subject): The subject data.

    Returns:
        subject_dict (Dict[str, torchio.Image]): The dictionary containing the subject data.
    """
    subject_dict = {}

    for channel_key in self.params["channel_keys"]:
        subject_dict[channel_key] = torchio.ScalarImage(
            path=subject[channel_key]["path"],
            tensor=subject[channel_key]["data"].squeeze(0),
            affine=subject[channel_key]["affine"].squeeze(0),
        )

    # Parentheses are required: `a or b and c` parses as `a or (b and c)`, which
    # made regression runs without "value_keys" raise KeyError below.
    value_keys_present = "value_keys" in self.params
    if (
        self._problem_type_is_regression or self._problem_type_is_classification
    ) and value_keys_present:
        for key in self.params["value_keys"]:
            subject_dict["value_" + key] = subject[key]

    return subject_dict
|
|
1026
|
+
|
|
1027
|
+
def _extend_nontraining_subject_dict_with_label(
    self, subject: torchio.Subject, subject_dict: dict
) -> dict:
    """
    Add the subject's label map, as a ``torchio.LabelMap``, to ``subject_dict``
    under ``"label"`` for the non-training modes.

    The leading (batch) dimension is squeezed from the label data and affine.
    ``subject_dict`` is modified in place and also returned.

    Args:
        subject (torchio.Subject): The subject data.
        subject_dict (dict): The dictionary containing the subject data.

    Returns:
        subject_dict (dict): The same dictionary, now with the ``"label"`` entry.
    """
    label_entry = subject["label"]
    subject_dict["label"] = torchio.LabelMap(
        path=label_entry["path"],
        tensor=label_entry["data"].squeeze(0),
        affine=label_entry["affine"].squeeze(0),
    )
    return subject_dict
|
|
1047
|
+
|
|
1048
|
+
def _initialize_nontraining_label_ground_truth_classification_or_regression(
    self, subject: torchio.Subject
):
    """
    Build the classification/regression ground truth for the non-training modes
    (validation, testing, inference).

    Args:
        subject (torchio.Subject): The subject data. It must contain every key listed
            in ``params["value_keys"]``.

    Returns:
        label (torch.Tensor): The values of those keys, concatenated along dim 0.
    """
    value_tensors = [subject[value_key] for value_key in self.params["value_keys"]]
    return torch.cat(value_tensors, dim=0)
|
|
1062
|
+
|
|
1063
|
+
def _initialize_nontraining_label_ground_truth_segmentation(
    self, subject: torchio.Subject
):
    """
    Fetch the segmentation ground truth for the non-training modes
    (validation, testing, inference).

    Args:
        subject (torchio.Subject): The subject data.

    Returns:
        label (torch.Tensor): ``subject["label"]["data"]``, returned as-is (not squeezed).
    """
    label_entry = subject["label"]
    return label_entry["data"]
|
|
1078
|
+
|
|
1079
|
+
# TODO this whole logic can be packed into something separate, as it is only used
# in validation of regression and classification problems
def _get_predictions_on_subject_using_label_sampler(
    self, subject_dict: dict
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Predict on a subject from ``params["q_samples_per_volume"]`` patches drawn by a
    ``torchio`` label sampler. Used for regression and classification problems.

    Args:
        subject_dict (dict): The dictionary containing the subject data.

    Returns:
        total_logits_for_all_patches (torch.Tensor): The model outputs of all patches,
            summed along dim 0 and divided by ``q_samples_per_volume``.
        last_batch_of_input_images (torch.Tensor): The last batch of input images. It is
            used mostly for special cases like deep_resunet or deep_unet, where it is
            needed for the loss calculation.
    """
    sampler = torchio.data.LabelSampler(self.params["patch_size"])
    patch_loader = sampler(
        torchio.Subject(subject_dict), num_patches=self.params["q_samples_per_volume"]
    )

    per_patch_outputs: List[torch.Tensor] = []
    for patch in patch_loader:
        # Label-sampled patches are concatenated along dim 0 and then given a batch
        # axis. This differs from the concatenation used elsewhere. A trailing
        # singleton axis (2D data) is dropped.
        channel_stack = torch.cat(
            [patch[channel][torchio.DATA] for channel in self.params["channel_keys"]],
            dim=0,
        ).unsqueeze(0)
        if channel_stack.shape[-1] == 1:
            channel_stack = torch.squeeze(channel_stack, -1)
        images_from_patches = self._ensure_proper_images_tensor_dimensions(
            channel_stack
        )
        patch_output, _ = self.forward(images_from_patches)
        per_patch_outputs.append(patch_output)

    summed_logits = torch.cat(per_patch_outputs).sum(dim=0, keepdim=True)
    averaged_logits = summed_logits / self.params["q_samples_per_volume"]
    return averaged_logits, images_from_patches
|
|
1147
|
+
|
|
1148
|
+
@rank_zero_only
def _determine_trainer_stage_string(self):
    """
    Return a short label for the trainer's current stage.

    ``self.trainer.validating``, ``.testing`` and ``.predicting`` are checked in
    that order; the result is ``"val"``, ``"test"`` or ``"inference"`` for the
    first flag that is set, otherwise ``"train"``. The string is only returned,
    not stored on the module.

    NOTE(review): because of ``rank_zero_only``, this presumably returns ``None``
    on non-zero ranks; confirm that callers only run on rank 0.
    """
    stage_by_flag = (
        ("validating", "val"),
        ("testing", "test"),
        ("predicting", "inference"),
    )
    for flag_name, stage in stage_by_flag:
        if getattr(self.trainer, flag_name):
            return stage
    return "train"
|
|
1161
|
+
|
|
1162
|
+
def _determine_save_path_to_use(self):
    """
    Return the output directory for the trainer's current stage.

    The validation, test or inference save directory is returned for the first of
    ``self.trainer.validating`` / ``testing`` / ``predicting`` that is set.

    Raises:
        RuntimeError: If none of those flags is set (i.e. during training).
    """
    dir_attribute_by_flag = (
        ("validating", "_current_validation_epoch_save_dir"),
        ("testing", "_current_test_epoch_save_dir"),
        ("predicting", "_current_inference_save_dir"),
    )
    for flag_name, dir_attribute in dir_attribute_by_flag:
        if getattr(self.trainer, flag_name):
            return getattr(self, dir_attribute)
    raise RuntimeError("Output save path cannot be determined for training")
|
|
1174
|
+
|
|
1175
|
+
def _get_predictions_on_subject_using_grid_sampler(
    self, subject_dict: dict
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Predict on a whole subject by tiling it with a ``torchio`` grid sampler.

    The grid sampler is used for segmentation in validation and testing, and for
    all problem types in inference, where no ground truth is available.
    - Segmentation: patch outputs are stitched back together with a
      ``GridAggregator``.
    - Regression/classification: patch outputs are summed and divided by the
      number of patches.
    - When ``params["medcam_enabled"]`` is set: attention maps are aggregated the
      same way and passed to ``self.model.save_attention_map``.

    Args:
        subject_dict (dict): The dictionary containing the subject data.

    Returns:
        aggregated_predictions (torch.Tensor): The stitched segmentation mask (moved to
            ``self.device``) or the patch-averaged regression/classification logits.
        last_batch_of_input_images (torch.Tensor): The last batch of input images. It is
            used mostly for special cases like deep_resunet or deep_unet, where it is
            needed for the loss calculation.
    """

    def _first_tensor(
        raw_output: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
    ):
        # Special architectures (SDnet, DeepResunet etc.) return a tuple of tensors;
        # only the first one is kept.
        if isinstance(raw_output, torch.Tensor):
            return raw_output
        warnings.warn(
            f"Model output is not a Tensor: {type(raw_output)}. Say, `deep_resunet` and `deep_unet` may return "
            f"list of tensors on different scales instead of just one prediction Tensor. However due to "
            f"GaNDLF architecture it is expected that models return only one tensor. For deep_* models "
            f"only the biggest scale is processed. Use these models with caution till fix is implemented."
        )
        return raw_output[0]

    def _with_torchio_depth_axis(output_tensor: torch.Tensor):
        # torchio aggregates 4D tensors, so 2D segmentation needs a trailing axis.
        needs_depth_axis = (
            self.params["model"]["dimension"] == 2
            and self._problem_type_is_segmentation
        )
        return output_tensor.unsqueeze(-1) if needs_depth_axis else output_tensor

    grid_sampler = self._prepare_grid_sampler(subject_dict)
    patch_loader = self._prepare_dataloader_from_grid_sampler(grid_sampler)

    overlap_mode = self.params["inference_mechanism"]["grid_aggregator_overlap"]
    prediction_aggregator = torchio.inference.GridAggregator(
        grid_sampler, overlap_mode=overlap_mode
    )
    if self.params["medcam_enabled"]:
        medcam_attention_map_aggregator = torchio.inference.GridAggregator(
            grid_sampler, overlap_mode=overlap_mode
        )
    if self._problem_type_is_regression or self._problem_type_is_classification:
        model_outputs_list: List[torch.Tensor] = []

    for patches_batch in patch_loader:
        images_from_patches = self._ensure_proper_images_tensor_dimensions(
            self._prepare_images_batch_from_subject_data(patches_batch)
        )
        raw_output, attention_map = self.forward(images_from_patches)
        patch_output = _with_torchio_depth_axis(_first_tensor(raw_output))

        if self.params["medcam_enabled"]:
            medcam_attention_map_aggregator.add_batch(
                attention_map, patches_batch[torchio.LOCATION]  # type: ignore
            )
        if self._problem_type_is_segmentation:
            prediction_aggregator.add_batch(
                patch_output, patches_batch[torchio.LOCATION]
            )
        else:
            model_outputs_list.append(patch_output)

    if self.params["medcam_enabled"]:
        aggregated_attention = medcam_attention_map_aggregator.get_output_tensor()
        for sample_idx, sample_attention in enumerate(aggregated_attention):
            # NOTE(review): raw inputs are taken from the last patch batch only.
            self.model.save_attention_map(
                sample_attention.squeeze(),
                raw_input=images_from_patches[sample_idx].squeeze(-1),
            )

    if self._problem_type_is_regression or self._problem_type_is_classification:
        averaged_logits = torch.cat(model_outputs_list).sum(
            dim=0, keepdim=True
        ) / len(patch_loader)
        return averaged_logits, images_from_patches

    stitched_prediction = (
        prediction_aggregator.get_output_tensor().unsqueeze(0).to(self.device)
    )
    return stitched_prediction, images_from_patches
|
|
1291
|
+
|
|
1292
|
+
def _prepare_grid_sampler(self, subject_dict: dict):
    """
    Create a ``torchio`` grid sampler over the subject, using
    ``params["patch_size"]`` and ``params["inference_mechanism"]["patch_overlap"]``.

    Args:
        subject_dict (dict): The dictionary containing the subject data.

    Returns:
        grid_sampler (torchio.inference.GridSampler): The grid sampler.
    """
    subject = torchio.Subject(subject_dict)
    patch_size = self.params["patch_size"]
    overlap = self.params["inference_mechanism"]["patch_overlap"]
    return torchio.inference.GridSampler(subject, patch_size, patch_overlap=overlap)
|
|
1308
|
+
|
|
1309
|
+
def _prepare_dataloader_from_grid_sampler(
    self, grid_sampler: torchio.inference.GridSampler
):
    """
    Wrap the grid sampler in a ``DataLoader`` that yields one patch per batch.

    Args:
        grid_sampler (torchio.inference.GridSampler): The grid sampler.

    Returns:
        patch_loader (torch.utils.data.DataLoader): The patch loader.
    """
    patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=1)  # type: ignore
    return patch_loader
|
|
1323
|
+
|
|
1324
|
+
# TODO check if it indeed works properly and runs only on rank 0
@rank_zero_only
def on_validation_epoch_end(self):
    """
    Aggregate the validation epoch and report it.

    1. Average the per-batch metrics.
    2. For regression/classification, add overall statistics when predictions and
       labels were collected.
    3. Compute the mean loss.
    4. Write the CSV log and log to Lightning without cross-rank sync.
    5. Run the early-stopping check.
    6. Optionally save the prediction CSV and print differential-privacy info.
    7. Reset the epoch containers.

    NOTE(review): assumes at least one labelled subject was seen; the ``[0]``
    index and ``torch.stack`` raise on empty containers.
    """
    batch_metrics = self.validation_metric_values
    epoch_metrics = {
        name: self._compute_metric_mean_across_values_from_batches(
            [per_batch[name] for per_batch in batch_metrics]
        )
        for name in batch_metrics[0].keys()
    }

    if self._problem_type_is_regression or self._problem_type_is_classification:
        # Workaround: the prediction/label lists are sometimes empty.
        have_predictions_and_labels = (
            len(self.val_predictions) > 0 and len(self.val_labels) > 0
        )
        if have_predictions_and_labels:
            epoch_metrics.update(
                overall_stats(
                    torch.tensor(self.val_predictions),
                    torch.tensor(self.val_labels),
                    self.params,
                )
            )
    mean_loss = self._round_value_to_precision(
        torch.mean(torch.stack(self.val_losses)).item()
    )

    self.val_logger.write(
        self.current_epoch,
        mean_loss,
        self._ensure_proper_metric_formatting_for_logging(epoch_metrics),
    )
    self.log("val_loss", mean_loss, on_epoch=True, prog_bar=True)
    self.log_dict(
        self._prepare_metrics_dict_for_progbar_logging(epoch_metrics),
        on_epoch=True,
        prog_bar=True,
        sync_dist=False,
    )

    self._check_if_early_stopping(mean_loss)

    writes_prediction_csv = self.params["save_output"] and (
        self._problem_type_is_regression or self._problem_type_is_classification
    )
    if writes_prediction_csv:
        self._save_predictions_csv_for_regression_or_classification(
            self.rows_to_write, self._determine_save_path_to_use()
        )
    if self.params.get("differential_privacy"):
        self._print_differential_privacy_info()
    self._clear_validation_epoch_containers()
|
|
1382
|
+
|
|
1383
|
+
@rank_zero_only
def _clear_validation_epoch_containers(self):
    """
    Reset the per-epoch validation accumulators.

    Losses and metrics are always reset. Predictions/labels are reset for
    regression/classification. CSV rows are reset when ``params["save_output"]``
    is set.
    """
    self.val_losses, self.validation_metric_values = [], []
    if self._problem_type_is_regression or self._problem_type_is_classification:
        self.val_predictions, self.val_labels = [], []
    if self.params["save_output"]:
        self.rows_to_write = []
|
|
1392
|
+
|
|
1393
|
+
@rank_zero_only
def _ensure_path_exists(self, path):
    """
    Create the directory ``path``, including missing parents, if it does not exist.

    Uses ``os.makedirs(..., exist_ok=True)`` rather than a separate existence
    check. A directory created by someone else between the check and the
    creation therefore cannot raise ``FileExistsError``. This is the same
    approach ``_prepare_memory_stats_save_dir`` uses.

    Args:
        path (str): Directory to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
|
|
1397
|
+
|
|
1398
|
+
@rank_zero_only
def _print_differential_privacy_info(self):
    """
    Print the current privacy budget and log it to Lightning.

    Epsilon is obtained from ``self._dp_engine.get_epsilon`` for the configured
    ``differential_privacy.delta``. Both values are then printed and logged.
    """
    dp_delta = self.params["differential_privacy"]["delta"]
    dp_epsilon = self._dp_engine.get_epsilon(dp_delta)
    print(f"Epoch {self.current_epoch} Privacy: ε = {dp_epsilon:.2f}, δ = {dp_delta}")
    for metric_name, metric_value in (("epsilon", dp_epsilon), ("delta", dp_delta)):
        self.log(metric_name, metric_value, on_epoch=True, prog_bar=True)
|
|
1405
|
+
|
|
1406
|
+
# TODO called at the validation step, NOT at the end of the epoch - we want to avoid
# saving all predictions for all subjects for the end of the epoch
def _save_predictions_for_segmentation_subject(
    self, predicted_segmentation_mask: torch.Tensor, subject: torchio.Subject
):
    """
    Post-process a predicted segmentation mask and write it to disk for one subject.

    Processing order:

    1. Each ``data_postprocessing`` step is applied per class to the raw mask.
    2. The mask is decoded with ``reverse_one_hot`` using ``model.class_list``.
    3. Axes 0 and 2 are swapped.
    4. The ``data_postprocessing_after_reverse_one_hot_encoding`` steps are applied.
    5. A singleton leading (or else trailing) axis is squeezed away.

    The result is written with SimpleITK, copying geometry from the subject's first
    channel image (key ``"1"``). It is resampled with nearest-neighbour interpolation
    when ``resample`` is configured.

    Args:
        predicted_segmentation_mask (torch.Tensor): The predicted segmentation mask, extracted
            from the grid aggregator when all validation patches for this subject have
            been processed. It is indexed as ``[0, class, ...]``, i.e. a leading batch
            axis of size 1 is assumed.
        subject (torchio.Subject): The subject for which the segmentation mask was predicted;
            its path, affine and subject id determine the output format, geometry and location.
    """
    params = self.params
    mask_array = predicted_segmentation_mask.cpu().numpy()

    # per-class postprocessing on the raw (not yet decoded) mask, modified in place
    for postprocessor_name in params["data_postprocessing"]:
        for class_index in range(0, params["model"]["num_classes"]):
            mask_array[0, class_index, ...] = global_postprocessing_dict[
                postprocessor_name
            ](mask_array[0, class_index, ...], params)

    # batch size is 1, so element 0 is the only sample; reverse_one_hot wants it unbatched
    decoded_mask = reverse_one_hot(mask_array[0], params["model"]["class_list"])
    # axes 0 and 2 swapped for SimpleITK save-format compatibility
    decoded_mask = np.swapaxes(decoded_mask, 0, 2)
    for postprocessor_name in params[
        "data_postprocessing_after_reverse_one_hot_encoding"
    ]:
        decoded_mask = global_postprocessing_dict[postprocessor_name](
            decoded_mask, params
        )

    # drop a singleton leading axis, otherwise a singleton trailing axis
    if decoded_mask.shape[0] == 1:
        decoded_mask = decoded_mask.squeeze(0)
    elif decoded_mask.shape[-1] == 1:
        decoded_mask = decoded_mask.squeeze(-1)

    image_save_format = get_filename_extension_sanitized(subject["1"]["path"][0])
    if image_save_format in [".jpg", ".jpeg", ".png"]:
        # 8-bit raster formats need an unsigned byte array
        decoded_mask = decoded_mask.astype(np.uint8)

    first_channel = subject["1"]
    reference_image = torchio.ScalarImage(
        tensor=first_channel["data"].squeeze(0).cpu(),
        affine=first_channel["affine"].squeeze(0).cpu(),
    ).as_sitk()
    result_sitk_image = sitk.GetImageFromArray(decoded_mask)
    result_sitk_image.CopyInformation(reference_image)

    if "resample" in params["data_preprocessing"]:
        result_sitk_image = resample_image(
            result_sitk_image,
            reference_image.GetSpacing(),
            interpolator=sitk.sitkNearestNeighbor,
        )

    subject_id = subject["subject_id"][0]
    segmentation_mask_save_path = os.path.join(
        self._determine_save_path_to_use(),
        subject_id,
        f"{subject['subject_id'][0]}_seg_process_rank_{self.global_rank}{image_save_format}",
    )
    self._ensure_path_exists(os.path.dirname(segmentation_mask_save_path))
    sitk.WriteImage(result_sitk_image, segmentation_mask_save_path)
|
|
1504
|
+
|
|
1505
|
+
@rank_zero_only
def _save_predictions_csv_for_regression_or_classification(
    self, rows_to_write: List[str], save_path: str
):
    """
    Saves the predictions for regression or classification problems to a CSV file.

    The file is ``<save_path>/output_predictions.csv``.

    The header depends on the stage:

    - Histopathology prediction, regression: the histopath header plus an ``output``
      column.
    - Histopathology prediction, classification: the histopath header plus one
      ``probability_<k>`` column per class.
    - Anything else: the default classification/regression header.

    Args:
        rows_to_write (List[str]): The rows to write to the CSV file. Each element of
            the list is a row.
        save_path (str): The save path for the CSV file.
    """

    def _header_for_current_stage():
        is_histopath_prediction = self.trainer.predicting and self.params[
            "modality"
        ] in ["histo", "path"]
        if not is_histopath_prediction:
            return self.CLASSIFICATION_REGRESSION_RESULTS_HEADER
        header = self.CLASSIFICATION_REGRESSION_RESULTS_HEADER_HISTOPATH
        if self._problem_type_is_regression:
            return header + ",output\n"
        if self._problem_type_is_classification:
            probability_columns = "".join(
                f",probability_{class_num}"
                for class_num in range(self.params["model"]["num_classes"])
            )
            return header + probability_columns + "\n"
        return self.CLASSIFICATION_REGRESSION_RESULTS_HEADER

    csv_save_path = os.path.join(save_path, "output_predictions.csv")
    merged_output = _header_for_current_stage() + "".join(rows_to_write)
    with open(csv_save_path, "w") as output_file:
        output_file.write(merged_output)
|
|
1536
|
+
|
|
1537
|
+
# TODO separate it into checking and saving functions, perhaps even separate class
@rank_zero_only
def _check_if_early_stopping(self, val_loss: float):
    """
    Checks if early stopping should be triggered based on the validation loss.

    If the loss improves (strictly lower than ``current_best_loss``), the best
    model is saved and the wait counter is reset. Otherwise the counter is
    incremented. Once it exceeds ``params["patience"]``, ``trainer.should_stop``
    is set.

    Args:
        val_loss (float): The validation loss of the current epoch.
    """
    previous_best_loss = deepcopy(self.current_best_loss)
    if not val_loss < self.current_best_loss:
        self.wait_count_before_early_stopping += 1
        print(
            f"Validation loss did not improve. Waiting count before early stopping: {self.wait_count_before_early_stopping} / {self.params['patience']}",
            flush=True,
        )
        if self.wait_count_before_early_stopping > self.params["patience"]:
            self.trainer.should_stop = True
            print(
                f"Early stopping triggered at epoch {self.current_epoch}, validation loss did not improve for {self.params['patience']} epochs, with the best loss value being {self.current_best_loss}. Stopping training.",
                flush=True,
            )
        return

    self.current_best_loss = val_loss
    self._save_model(self.current_epoch, self.model_paths["best"], False)
    print(
        f"Loss value improved. Previous best loss :{previous_best_loss}, new best loss: {val_loss} Saving best model from epoch {self.current_epoch}",
        flush=True,
    )
    self.wait_count_before_early_stopping = 0
|
|
1566
|
+
|
|
1567
|
+
def on_test_start(self):
    """Lightning hook: prepare the test accumulators, then the test CSV logger."""
    for initializer in (
        self._initialize_test_epoch_containers,
        self._initialize_test_logger,
    ):
        initializer()
|
|
1570
|
+
|
|
1571
|
+
@rank_zero_only
def _initialize_test_logger(self):
    """Create the CSV logger for test results at ``<output_dir>/logs_test.csv``."""
    test_log_path = os.path.join(self.output_dir, "logs_test.csv")
    self.test_logger = Logger(
        logger_csv_filename=test_log_path,
        metrics=self._get_metrics_names_for_loggers(),
        mode="test",
    )
|
|
1578
|
+
|
|
1579
|
+
@rank_zero_only
def _initialize_test_epoch_containers(self):
    """
    Create empty accumulators for one test run.

    ``test_losses`` and ``test_metric_values`` are always reset. For regression
    and classification with ``save_output`` enabled, ``rows_to_write`` is reset
    too, mirroring the validation and inference containers. ``test_step`` appends
    prediction rows to it, and without this a standalone test run would fail on a
    missing attribute or write rows left over from a previous stage.
    """
    self.test_losses: List[torch.Tensor] = []
    self.test_metric_values: List[Dict[str, float]] = []
    if self._problem_type_is_regression or self._problem_type_is_classification:
        if self.params["save_output"]:
            self.rows_to_write = []
|
|
1583
|
+
|
|
1584
|
+
def on_test_epoch_start(self):
    """
    Lightning hook run at the start of the test epoch.

    Enables MedCAM on the model when requested, then creates the per-epoch
    output directory ``<output_dir>/output_test/epoch_<current_epoch>``.
    """
    if self.params["medcam_enabled"]:
        self.model.enable_medcam()
        self.params["medcam_enabled"] = True

    epoch_dir_name = f"epoch_{self.current_epoch}"
    self._current_test_epoch_save_dir = os.path.join(
        self.output_dir, "output_test", epoch_dir_name
    )
    self._ensure_path_exists(self._current_test_epoch_save_dir)
|
|
1593
|
+
|
|
1594
|
+
def test_step(self, subject, batch_idx):
|
|
1595
|
+
if self.params["verbose"]:
|
|
1596
|
+
self._print_currently_processed_subject(subject)
|
|
1597
|
+
|
|
1598
|
+
subject_dict = self._initialize_subject_dict_nontraining_mode(subject)
|
|
1599
|
+
label_present = subject["label"] != ["NA"]
|
|
1600
|
+
value_keys_present = "value_keys" in self.params
|
|
1601
|
+
label = None
|
|
1602
|
+
if label_present:
|
|
1603
|
+
subject_dict = self._extend_nontraining_subject_dict_with_label(
|
|
1604
|
+
subject, subject_dict
|
|
1605
|
+
)
|
|
1606
|
+
if (
|
|
1607
|
+
self._problem_type_is_regression
|
|
1608
|
+
or self._problem_type_is_classification
|
|
1609
|
+
and label_present
|
|
1610
|
+
):
|
|
1611
|
+
(
|
|
1612
|
+
model_output,
|
|
1613
|
+
last_input_batch,
|
|
1614
|
+
) = self._get_predictions_on_subject_using_label_sampler(subject_dict)
|
|
1615
|
+
|
|
1616
|
+
if self.params["save_output"]:
|
|
1617
|
+
processed_logit = self._process_prediction_logit_for_row_writing(
|
|
1618
|
+
model_output, self.params["scaling_factor"]
|
|
1619
|
+
)
|
|
1620
|
+
self.rows_to_write.append(
|
|
1621
|
+
self._prepare_row_for_output_csv(
|
|
1622
|
+
subject["subject_id"][0], processed_logit, self.current_epoch
|
|
1623
|
+
)
|
|
1624
|
+
)
|
|
1625
|
+
|
|
1626
|
+
label = self._initialize_nontraining_label_ground_truth_classification_or_regression(
|
|
1627
|
+
subject
|
|
1628
|
+
)
|
|
1629
|
+
else:
|
|
1630
|
+
(
|
|
1631
|
+
model_output,
|
|
1632
|
+
last_input_batch,
|
|
1633
|
+
) = self._get_predictions_on_subject_using_grid_sampler(subject_dict)
|
|
1634
|
+
if self.params["save_output"]:
|
|
1635
|
+
self._save_predictions_for_segmentation_subject(model_output, subject)
|
|
1636
|
+
if self._problem_type_is_segmentation and label_present:
|
|
1637
|
+
label = self._initialize_nontraining_label_ground_truth_segmentation(
|
|
1638
|
+
subject
|
|
1639
|
+
)
|
|
1640
|
+
elif (
|
|
1641
|
+
self._problem_type_is_classification
|
|
1642
|
+
or self._problem_type_is_regression
|
|
1643
|
+
and value_keys_present
|
|
1644
|
+
):
|
|
1645
|
+
label = self._initialize_nontraining_label_ground_truth_classification_or_regression(
|
|
1646
|
+
subject
|
|
1647
|
+
)
|
|
1648
|
+
if label is not None:
|
|
1649
|
+
label = self._process_labels(label)
|
|
1650
|
+
model_output, label = self.pred_target_processor(model_output, label)
|
|
1651
|
+
|
|
1652
|
+
loss = self.loss(model_output, label, last_input_batch)
|
|
1653
|
+
metric_results = self.metric_calculators(
|
|
1654
|
+
model_output, label, subject_spacing=subject.get("spacing", None)
|
|
1655
|
+
)
|
|
1656
|
+
|
|
1657
|
+
self.test_losses.append(loss)
|
|
1658
|
+
self.test_metric_values.append(metric_results)
|
|
1659
|
+
|
|
1660
|
+
@rank_zero_only
def on_test_epoch_end(self):
    """
    Aggregate, log and persist the results collected by ``test_step``.

    When at least one subject produced a loss/metric:

    - every metric is averaged across subjects;
    - the mean loss and metrics are written to the test CSV logger and to Lightning's logger.

    The guard matters because ``test_step`` only records values for subjects with
    ground truth. With no such subjects, indexing the first metric entry would raise
    ``IndexError``.

    Prediction rows are saved for regression/classification when ``save_output`` is
    set, and the test containers are always cleared.
    """
    if self.test_metric_values:
        test_epoch_average_metrics = {}
        for metric_name in self.test_metric_values[0].keys():
            metric_values = [x[metric_name] for x in self.test_metric_values]
            test_epoch_average_metrics[
                metric_name
            ] = self._compute_metric_mean_across_values_from_batches(metric_values)

        mean_loss = self._round_value_to_precision(
            torch.mean(torch.stack(self.test_losses)).item()
        )

        self.test_logger.write(
            self.current_epoch,
            mean_loss,
            self._ensure_proper_metric_formatting_for_logging(
                test_epoch_average_metrics
            ),
        )

        self.log("test_loss", mean_loss, on_epoch=True, prog_bar=True)
        self.log_dict(
            self._prepare_metrics_dict_for_progbar_logging(test_epoch_average_metrics),
            on_epoch=True,
            prog_bar=True,
            sync_dist=False,
        )

    if self.params["save_output"] and (
        self._problem_type_is_regression or self._problem_type_is_classification
    ):
        self._save_predictions_csv_for_regression_or_classification(
            self.rows_to_write, self._determine_save_path_to_use()
        )
    self._clear_test_epoch_containers()
|
|
1696
|
+
|
|
1697
|
+
@rank_zero_only
def _clear_test_epoch_containers(self):
    """
    Reset the test accumulators after the test epoch has been processed.

    Like ``_clear_validation_epoch_containers``, this also empties the buffered
    prediction rows for regression/classification when ``save_output`` is set.
    Test rows therefore do not leak into a later stage's CSV.
    """
    self.test_losses = []
    self.test_metric_values = []
    if self._problem_type_is_regression or self._problem_type_is_classification:
        if self.params["save_output"]:
            self.rows_to_write = []
|
|
1701
|
+
|
|
1702
|
+
def on_predict_start(self):
    """
    Lightning hook run before prediction.

    Prepares the inference containers, loads the best (or latest) checkpoint and,
    when differential privacy is configured, sets up its inference-time state.
    """
    self._initialize_inference_containers()
    self._try_to_load_model_inference_start()

    dp_config = self.params.get("differential_privacy")
    if dp_config:
        self._initialize_inference_differential_privacy()
|
|
1708
|
+
|
|
1709
|
+
def _try_to_load_model_inference_start(self):
|
|
1710
|
+
if self._try_to_load_model(self.model_paths["best"]):
|
|
1711
|
+
print(f"Previous best model loaded from {self.model_paths['best']}.")
|
|
1712
|
+
elif self._try_to_load_model(self.model_paths["latest"]):
|
|
1713
|
+
print(f"Previous latest model loaded from {self.model_paths['latest']}.")
|
|
1714
|
+
else:
|
|
1715
|
+
raise RuntimeError(
|
|
1716
|
+
f"Best/latest models not found to load: {self.model_paths}"
|
|
1717
|
+
)
|
|
1718
|
+
|
|
1719
|
+
@rank_zero_only
def _initialize_inference_containers(self):
    """
    Create the inference output directory and empty inference accumulators.

    Outputs go to ``<output_dir>/output_inference``. For regression and
    classification, the CSV rows buffer and the per-subject class-probability
    dict are also created.
    """
    # TODO here we need some mechanism for separate outputs for nested inference
    inference_dir = os.path.join(self.output_dir, "output_inference")
    self._current_inference_save_dir = inference_dir
    self._ensure_path_exists(inference_dir)
    self.inference_losses, self.inference_metric_values = [], []
    if self._problem_type_is_regression or self._problem_type_is_classification:
        self.rows_to_write = []
        self.subject_classification_class_probabilities: Dict[
            str, torch.Tensor
        ] = {}
|
|
1732
|
+
|
|
1733
|
+
@rank_zero_only
def _print_inference_initialization_info(self):
    """Print the model configuration and host info, plus a model summary if requested."""
    model_config = self.params["model"]
    print("Current model type : ", model_config["type"])
    print("Number of dims : ", model_config["dimension"])
    if "num_channels" in model_config:
        print("Number of channels : ", model_config["num_channels"])
    print("Number of classes : ", len(model_config["class_list"]))
    self._print_host_info()
    if model_config["print_summary"]:
        self._print_model_summary()
|
|
1743
|
+
|
|
1744
|
+
def predict_step(self, batch, batch_idx):
    """
    Dispatch one prediction batch to the modality-specific inference step.

    ``modality == "rad"`` uses the radiology step; anything else uses the
    histopathology step.
    """
    if self.params["verbose"]:
        self._print_currently_processed_subject(batch)
    # TODO both of those below should return values to complete the logic
    # of calculating metrics for classification case that is currently handled
    # by saving/reading logits.csv file
    inference_step = (
        self._radiology_inference_step
        if self.params["modality"] == "rad"
        else self._histopathology_inference_step
    )
    return inference_step(batch)
|
|
1754
|
+
|
|
1755
|
+
def _radiology_inference_step(self, subject: torchio.Subject):
    """
    Inference step for radiology subjects.

    With a ground-truth label (``subject["label"] != ["NA"]``):

    - regression and classification subjects use the label sampler and buffer a CSV row;
    - other subjects use the grid sampler and have their segmentation saved;
    - loss and metrics are then collected in ``inference_losses`` / ``inference_metric_values``.

    Without a label, only the grid sampler is used. Regression/classification outputs
    become CSV rows; anything else is saved as a segmentation.

    For classification, softmax class probabilities are stored per subject id.

    Args:
        subject (torchio.Subject): The (batched) subject to run inference on.
    """
    label_present = subject["label"] != ["NA"]
    subject_dict = self._initialize_subject_dict_nontraining_mode(subject)

    if not label_present:
        (
            model_output,
            last_input_batch,
        ) = self._get_predictions_on_subject_using_grid_sampler(subject_dict)
        if self._problem_type_is_classification or self._problem_type_is_regression:
            # NOTE(review): unlike the labelled branch, no scaling_factor is passed here — confirm intended
            processed_logit = self._process_prediction_logit_for_row_writing(
                model_output
            )
            csv_row = self._prepare_row_for_output_csv(
                subject["subject_id"][0], processed_logit, self.current_epoch
            )
            self.rows_to_write.append(csv_row)
        else:
            self._save_predictions_for_segmentation_subject(model_output, subject)
    else:
        subject_dict = self._extend_nontraining_subject_dict_with_label(
            subject, subject_dict
        )
        # `and` binds tighter than `or`; the parentheses only make that grouping explicit
        use_label_sampler = self._problem_type_is_regression or (
            self._problem_type_is_classification and label_present
        )
        if use_label_sampler:
            (
                model_output,
                last_input_batch,
            ) = self._get_predictions_on_subject_using_label_sampler(subject_dict)
            processed_logit = self._process_prediction_logit_for_row_writing(
                model_output, self.params["scaling_factor"]
            )
            csv_row = self._prepare_row_for_output_csv(
                subject["subject_id"][0], processed_logit, self.current_epoch
            )
            self.rows_to_write.append(csv_row)
            label = self._initialize_nontraining_label_ground_truth_classification_or_regression(
                subject
            )
        else:
            (
                model_output,
                last_input_batch,
            ) = self._get_predictions_on_subject_using_grid_sampler(subject_dict)
            self._save_predictions_for_segmentation_subject(model_output, subject)
            label = self._initialize_nontraining_label_ground_truth_segmentation(
                subject
            )

        label = self._process_labels(label)
        model_output, label = self.pred_target_processor(model_output, label)
        loss = self.loss(model_output, label, last_input_batch)
        metric_results = self.metric_calculators(
            model_output, label, subject_spacing=subject.get("spacing", None)
        )
        self.inference_losses.append(loss)
        self.inference_metric_values.append(metric_results)

    if self._problem_type_is_classification:
        self.subject_classification_class_probabilities[
            subject["subject_id"][0]
        ] = F.softmax(model_output, dim=1)
|
|
1824
|
+
|
|
1825
|
+
# TODO this has to be somehow handled in different way, we
# are mixing too much logic in this single module
def _histopathology_inference_step(self, row_index_tuple):
    """
    Inference step for the histopathology modality. This function is called with an assumption that the highest
    level dataloader is an iterator over the rows of the dataframe. The function is called for each row of the
    dataframe.

    How a slide is processed:

    - It is tiled with ``InferTumorSegDataset`` at the requested ``slide_level``, clamped
      to ``[0, level_count - 1]`` of the slide.
    - The model is run on every patch.
    - Count, probability and segmentation maps are saved under
      ``<output_inference>/histopathology/<subject>``.
    - For regression/classification, only the CSV rows produced for *this* slide are
      written to that directory's ``output_predictions.csv``.

    The OpenSlide handle is closed once the maps are saved.

    Args:
        row_index_tuple (tuple): ``(index, row)`` pair; only ``row`` (a ``pd.Series`` with
            the information about the slide to be processed) is used.
    """
    row = row_index_tuple[1]
    subject_name = row[self.params["headers"]["subjectIDHeader"]]
    inference_results_save_dir_for_subject = os.path.join(
        self._current_inference_save_dir, "histopathology", str(subject_name)
    )
    self._ensure_path_exists(inference_results_save_dir_for_subject)
    self._prepare_histopath_default_inference_params()
    writes_csv_rows = (
        self._problem_type_is_classification or self._problem_type_is_regression
    )
    slide_path = row[self.params["headers"]["channelHeaders"]].values[0]
    openslide_image = openslide.open_slide(slide_path)
    try:
        max_defined_slide_level = openslide_image.level_count - 1
        # clamp the requested level into the valid range [0, max_defined_slide_level];
        # the lower bound must be max(), min(..., 0) would force level 0 for every slide
        row_slide_level = min(self.params["slide_level"], max_defined_slide_level)
        row_slide_level = max(row_slide_level, 0)
        level_width, level_height = openslide_image.level_dimensions[row_slide_level]
        patch_size = self._ensure_patch_size_is_2D(self.params["patch_size"])
        count_map = self._initialize_count_map(level_width, level_height)
        probabilities_map = self._initialize_probability_map(
            self.params["model"]["num_classes"], level_width, level_height
        )

        # TODO this should be done by other object or method
        transform_requested = get_transforms_for_preprocessing(
            self.params, [], False, False
        )
        patient_dataset = InferTumorSegDataset(
            slide_path,
            patch_size=patch_size,
            stride_size=self.params["stride_size"],
            selected_level=row_slide_level,
            mask_level=self.params["mask_level"],
            transform=transform_requested,
        )
        histopathology_dataloader = torch.utils.data.DataLoader(
            patient_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=self.params["q_num_workers"],
        )
        patch_size_updated_after_transforms = patient_dataset.get_patch_size()
        if self.params["model"]["print_summary"]:
            print_model_summary(
                self.model,
                self.params["batch_size"],
                self.params["model"]["num_channels"],
                patch_size_updated_after_transforms,
            )

        # rows_to_write accumulates across slides; remember where this slide's rows start
        first_row_of_subject = len(self.rows_to_write) if writes_csv_rows else 0
        count_map, probabilities_map = self._iterate_over_histopathology_loader(
            histopathology_dataloader,
            count_map,
            probabilities_map,
            patch_size_updated_after_transforms,
            self.params["model"]["num_classes"],
            subject_name,
        )

        map_saver = MapSaver(
            num_classes=self.params["model"]["num_classes"],
            slide_level=row_slide_level,
            blending_alpha=self.params["blending_alpha"],
            level_height=level_height,
            level_width=level_width,
        )
        map_saver.save_count_map(
            count_map, save_dir=inference_results_save_dir_for_subject
        )
        map_saver.save_probability_and_segmentation_maps(
            probabilities_map,
            openslide_image,
            save_dir=inference_results_save_dir_for_subject,
        )
    finally:
        openslide_image.close()

    if writes_csv_rows:
        self._save_predictions_csv_for_regression_or_classification(
            self.rows_to_write[first_row_of_subject:],
            inference_results_save_dir_for_subject,
        )
|
|
1912
|
+
|
|
1913
|
+
def _iterate_over_histopathology_loader(
|
|
1914
|
+
self,
|
|
1915
|
+
histopathology_dataloader,
|
|
1916
|
+
count_map,
|
|
1917
|
+
probability_map,
|
|
1918
|
+
patch_size,
|
|
1919
|
+
num_classes,
|
|
1920
|
+
subject_name,
|
|
1921
|
+
):
|
|
1922
|
+
for image_patches, (x_coord, y_coord) in histopathology_dataloader:
|
|
1923
|
+
x_coord, y_coord = (
|
|
1924
|
+
x_coord.numpy(),
|
|
1925
|
+
y_coord.numpy(),
|
|
1926
|
+
) # TODO the dataset should do that when fetching
|
|
1927
|
+
image_patches = image_patches.to(self.device)
|
|
1928
|
+
output, _ = self.forward(image_patches)
|
|
1929
|
+
output = output.cpu().detach().numpy()
|
|
1930
|
+
for i in range(output.shape[0]):
|
|
1931
|
+
self._increment_value_of_count_map_at_given_position(
|
|
1932
|
+
count_map, x_coord[i], y_coord[i], patch_size
|
|
1933
|
+
)
|
|
1934
|
+
for class_index in range(num_classes):
|
|
1935
|
+
self._add_value_to_probability_map_at_given_position(
|
|
1936
|
+
probability_map,
|
|
1937
|
+
x_coord[i],
|
|
1938
|
+
y_coord[i],
|
|
1939
|
+
patch_size,
|
|
1940
|
+
output[i][class_index],
|
|
1941
|
+
class_index,
|
|
1942
|
+
)
|
|
1943
|
+
if (
|
|
1944
|
+
self._problem_type_is_regression
|
|
1945
|
+
or self._problem_type_is_classification
|
|
1946
|
+
):
|
|
1947
|
+
row_for_csv_saving = (
|
|
1948
|
+
self._prepare_row_for_output_csv_histopathology_inference(
|
|
1949
|
+
subject_name, x_coord[i], y_coord[i], output[i]
|
|
1950
|
+
)
|
|
1951
|
+
)
|
|
1952
|
+
self.rows_to_write.append(row_for_csv_saving)
|
|
1953
|
+
probability_map = np.divide(probability_map, count_map)
|
|
1954
|
+
return count_map, probability_map
|
|
1955
|
+
|
|
1956
|
+
@staticmethod
|
|
1957
|
+
def _increment_value_of_count_map_at_given_position(
|
|
1958
|
+
count_map, x_coord, y_coord, patch_size
|
|
1959
|
+
):
|
|
1960
|
+
count_map[
|
|
1961
|
+
y_coord : y_coord + patch_size[1], x_coord : x_coord + patch_size[0]
|
|
1962
|
+
] += 1
|
|
1963
|
+
|
|
1964
|
+
@staticmethod
|
|
1965
|
+
def _add_value_to_probability_map_at_given_position(
|
|
1966
|
+
prob_map, x_coord, y_coord, patch_size, value, class_index
|
|
1967
|
+
):
|
|
1968
|
+
prob_map[
|
|
1969
|
+
class_index,
|
|
1970
|
+
y_coord : y_coord + patch_size[1],
|
|
1971
|
+
x_coord : x_coord + patch_size[0],
|
|
1972
|
+
] += value
|
|
1973
|
+
|
|
1974
|
+
# TODO this should be handled by the config parser
|
|
1975
|
+
@rank_zero_only
|
|
1976
|
+
def _prepare_histopath_default_inference_params(self):
|
|
1977
|
+
"""
|
|
1978
|
+
Sets the parameters necessary for histopath inference.
|
|
1979
|
+
"""
|
|
1980
|
+
self.params["stride_size"] = self.params.get("stride_size", None)
|
|
1981
|
+
self.params["slide_level"] = self.params.get("slide_level", 0)
|
|
1982
|
+
self.params["mask_level"] = self.params.get(
|
|
1983
|
+
"mask_level", self.params["slide_level"]
|
|
1984
|
+
)
|
|
1985
|
+
self.params["blending_alpha"] = float(self.params.get("blending_alpha", 0.5))
|
|
1986
|
+
|
|
1987
|
+
@staticmethod
|
|
1988
|
+
def _initialize_count_map(level_width: int, level_height: int):
|
|
1989
|
+
"""
|
|
1990
|
+
Initializes the count maps for the histopathology inference.
|
|
1991
|
+
|
|
1992
|
+
Args:
|
|
1993
|
+
level_width (int): The width of the level.
|
|
1994
|
+
level_height (int): The height of the level.
|
|
1995
|
+
|
|
1996
|
+
Returns:
|
|
1997
|
+
count_map (np.ndarray): The count map.
|
|
1998
|
+
"""
|
|
1999
|
+
return np.zeros((level_height, level_width), dtype=np.uint8)
|
|
2000
|
+
|
|
2001
|
+
@staticmethod
|
|
2002
|
+
def _initialize_probability_map(
|
|
2003
|
+
num_classes: int, level_width: int, level_height: int
|
|
2004
|
+
):
|
|
2005
|
+
"""
|
|
2006
|
+
Initializes the probability maps for the histopathology inference.
|
|
2007
|
+
Called for classification and segmentation problems.
|
|
2008
|
+
|
|
2009
|
+
Args:
|
|
2010
|
+
num_classes (int): The number of classes.
|
|
2011
|
+
level_width (int): The width of the level.
|
|
2012
|
+
level_height (int): The height of the level.
|
|
2013
|
+
|
|
2014
|
+
Returns:
|
|
2015
|
+
probs_map (np.ndarray): The probability map.
|
|
2016
|
+
"""
|
|
2017
|
+
return np.zeros((num_classes, level_height, level_width), dtype=np.float16)
|
|
2018
|
+
|
|
2019
|
+
@staticmethod
|
|
2020
|
+
def _ensure_patch_size_is_2D(patch_size: List[int]):
|
|
2021
|
+
"""
|
|
2022
|
+
Ensures that the patch size is 2D.
|
|
2023
|
+
|
|
2024
|
+
Args:
|
|
2025
|
+
patch_size (List[int]): The patch size.
|
|
2026
|
+
|
|
2027
|
+
Returns:
|
|
2028
|
+
patch_size (List[int]): The 2D patch size.
|
|
2029
|
+
"""
|
|
2030
|
+
if len(patch_size) == 3:
|
|
2031
|
+
return patch_size[:-1]
|
|
2032
|
+
return patch_size
|
|
2033
|
+
|
|
2034
|
+
@rank_zero_only
def on_predict_end(self):
    """
    Lightning hook run after prediction.

    If any subject produced metrics (i.e. had ground truth), prints the mean
    loss and per-metric means. Always clears the inference containers.
    """
    if self.inference_metric_values:
        metric_names = self.inference_metric_values[0].keys()
        inference_epoch_average_metrics = {
            metric_name: self._compute_metric_mean_across_values_from_batches(
                [result[metric_name] for result in self.inference_metric_values]
            )
            for metric_name in metric_names
        }
        stacked_losses = torch.stack(self.inference_losses)
        mean_loss = self._round_value_to_precision(torch.mean(stacked_losses).item())

        print("Inference results:")
        print(f"Loss: {mean_loss}")
        print(f"Metrics: {inference_epoch_average_metrics}")

    self._clear_inference_containers()
|
|
2054
|
+
|
|
2055
|
+
@rank_zero_only
def _clear_inference_containers(self):
    """Empty the inference loss/metric lists and, for regression/classification, the CSV rows."""
    self.inference_losses, self.inference_metric_values = [], []
    if self._problem_type_is_regression or self._problem_type_is_classification:
        self.rows_to_write = []
|
|
2061
|
+
|
|
2062
|
+
def configure_optimizers(self):
    """
    Build the optimizer (and optional LR scheduler) in the format Lightning expects.

    ``get_optimizer`` receives a deep copy of ``self.params`` extended with the model
    parameters and the current learning rate. ``get_scheduler`` additionally receives
    the optimizer.

    Returns:
        dict: ``{"optimizer": optimizer}``, plus an ``"lr_scheduler"`` config when a
        ``scheduler`` is configured. Lightning only recognises the ``"lr_scheduler"`` key;
        a top-level ``"scheduler"`` key is ignored with a warning, leaving the scheduler
        never stepped. ``ReduceLROnPlateau`` additionally needs a ``"monitor"`` metric.
    """
    # deepcopy before attaching parameters, so the model parameters are not copied
    params = deepcopy(self.params)
    params["model_parameters"] = self.model.parameters()
    params["learning_rate"] = self.learning_rate
    optimizer = get_optimizer(params)
    if "scheduler" not in self.params:
        return {"optimizer": optimizer}

    params["optimizer_object"] = optimizer
    scheduler = get_scheduler(params)
    lr_scheduler_config = {"scheduler": scheduler}
    if isinstance(scheduler, ReduceLROnPlateau):
        lr_scheduler_config["monitor"] = "val_loss"
    return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
|
|
2075
|
+
|
|
2076
|
+
def transfer_batch_to_device(self, batch, device, dataloader_idx):
    """
    A method called by Lightning to transfer the batch to the device.
    In case of GANDLF, we need custom logic to transfer the data to the device.

    Histopathology prediction batches are returned untouched, because the
    histopathology step moves its own patches to the device. Any other batch gets
    its image channels and its labels/values moved.
    """
    is_histopath_prediction = self.trainer.predicting and self.params[
        "modality"
    ] in ["path", "histo"]
    if is_histopath_prediction:
        return batch
    batch = self._move_image_data_to_device(batch, device)
    return self._move_labels_or_values_to_device(batch, device)
|
|
2087
|
+
|
|
2088
|
+
def _move_image_data_to_device(self, subject, device):
    """
    Move the tensor of every configured image channel to ``device``.

    The subject is modified in place and returned.
    """
    for channel_key in self.params["channel_keys"]:
        channel_entry = subject[channel_key]
        channel_entry[torchio.DATA] = channel_entry[torchio.DATA].to(device)
    return subject
|
|
2094
|
+
|
|
2095
|
+
def _move_labels_or_values_to_device(self, subject, device):
|
|
2096
|
+
if "value_keys" in self.params:
|
|
2097
|
+
for value_key in self.params["value_keys"]:
|
|
2098
|
+
subject[value_key] = subject[value_key].to(device)
|
|
2099
|
+
elif subject["label"] != ["NA"]:
|
|
2100
|
+
subject["label"][torchio.DATA] = subject["label"][torchio.DATA].to(device)
|
|
2101
|
+
|
|
2102
|
+
return subject
|