GANDLF 0.1.3.dev20250202__py3-none-any.whl → 0.1.6.dev20251109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of GANDLF might be problematic. Click here for more details.
- GANDLF/cli/deploy.py +2 -2
- GANDLF/cli/generate_metrics.py +35 -1
- GANDLF/cli/main_run.py +4 -10
- GANDLF/compute/__init__.py +0 -2
- GANDLF/compute/forward_pass.py +0 -1
- GANDLF/compute/generic.py +107 -2
- GANDLF/compute/inference_loop.py +4 -4
- GANDLF/compute/loss_and_metric.py +1 -2
- GANDLF/compute/training_loop.py +10 -10
- GANDLF/config_manager.py +33 -717
- GANDLF/configuration/__init__.py +0 -0
- GANDLF/configuration/default_config.py +73 -0
- GANDLF/configuration/differential_privacy_config.py +16 -0
- GANDLF/configuration/exclude_parameters.py +1 -0
- GANDLF/configuration/model_config.py +82 -0
- GANDLF/configuration/nested_training_config.py +25 -0
- GANDLF/configuration/optimizer_config.py +121 -0
- GANDLF/configuration/parameters_config.py +10 -0
- GANDLF/configuration/patch_sampler_config.py +11 -0
- GANDLF/configuration/post_processing_config.py +10 -0
- GANDLF/configuration/pre_processing_config.py +94 -0
- GANDLF/configuration/scheduler_config.py +92 -0
- GANDLF/configuration/user_defined_config.py +131 -0
- GANDLF/configuration/utils.py +96 -0
- GANDLF/configuration/validators.py +479 -0
- GANDLF/data/__init__.py +14 -16
- GANDLF/data/lightning_datamodule.py +119 -0
- GANDLF/entrypoints/run.py +36 -31
- GANDLF/inference_manager.py +69 -25
- GANDLF/losses/__init__.py +23 -1
- GANDLF/losses/loss_calculators.py +79 -0
- GANDLF/losses/segmentation.py +3 -2
- GANDLF/metrics/__init__.py +26 -0
- GANDLF/metrics/generic.py +1 -1
- GANDLF/metrics/metric_calculators.py +102 -0
- GANDLF/metrics/panoptica_config_brats.yaml +56 -0
- GANDLF/metrics/segmentation_panoptica.py +49 -0
- GANDLF/models/__init__.py +8 -3
- GANDLF/models/lightning_module.py +2102 -0
- GANDLF/optimizers/__init__.py +4 -8
- GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
- GANDLF/schedulers/__init__.py +11 -4
- GANDLF/schedulers/wrap_torch.py +15 -3
- GANDLF/training_manager.py +160 -50
- GANDLF/utils/__init__.py +5 -3
- GANDLF/utils/imaging.py +176 -35
- GANDLF/utils/modelio.py +12 -8
- GANDLF/utils/pred_target_processors.py +71 -0
- GANDLF/utils/tensor.py +2 -1
- GANDLF/utils/write_parse.py +1 -1
- GANDLF/version.py +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/METADATA +16 -11
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/RECORD +57 -34
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/WHEEL +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/entry_points.txt +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info/licenses}/LICENSE +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
import lightning.pytorch as pl
|
|
2
|
+
from GANDLF.compute.generic import (
|
|
3
|
+
TrainingSubsetDataParser,
|
|
4
|
+
ValidationSubsetDataParser,
|
|
5
|
+
TestSubsetDataParser,
|
|
6
|
+
InferenceSubsetDataParserRadiology,
|
|
7
|
+
)
|
|
8
|
+
from torch.utils.data import DataLoader as TorchDataLoader
|
|
9
|
+
from copy import deepcopy
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class GandlfTrainingDatamodule(pl.LightningDataModule):
    """Lightning datamodule serving the training, validation and optional testing subsets.

    The datasets are built once in ``__init__`` by the subset data parsers.
    Each parser returns an extended copy of the parameters, which is handed
    to the next parser; the final result is available through
    ``updated_parameters_dict``.
    """

    def __init__(self, data_dict_files: dict, parameters_dict: dict):
        """
        Args:
            data_dict_files (dict): Subset data keyed by "training" and
                "validation" (both required) and "testing" (optional).
            parameters_dict (dict): The overall parameters; must contain
                "batch_size". It is deep-copied, the caller's dict is not
                modified here.
        """
        super().__init__()

        # Batch size here and reinitialization of dataloader parsers is used
        # in automatic batch size tuning
        self.batch_size = parameters_dict["batch_size"]

        # This init procedure is an extreme hack, but the only way to get
        # around the need to modify the parameters dict during the parsing
        # procedure: every parser receives the params produced by the
        # previous one.
        params = deepcopy(parameters_dict)

        train_subset_parser = TrainingSubsetDataParser(
            data_dict_files["training"], params
        )
        self.training_dataset = train_subset_parser.create_subset_dataset()
        params = train_subset_parser.get_params_extended_with_subset_data()

        val_subset_parser = ValidationSubsetDataParser(
            data_dict_files["validation"], params
        )
        self.validation_dataset = val_subset_parser.create_subset_dataset()
        params = val_subset_parser.get_params_extended_with_subset_data()

        # The testing subset is optional; without it test_dataloader() returns None.
        testing_data = data_dict_files.get("testing", None)
        self.test_dataset = None
        if testing_data is not None:
            test_subset_parser = TestSubsetDataParser(testing_data, params)
            self.test_dataset = test_subset_parser.create_subset_dataset()
            params = test_subset_parser.get_params_extended_with_subset_data()

        self.parameters_dict = params

    def _get_dataloader(self, dataset, batch_size: int, shuffle: bool):
        """Wrap ``dataset`` in a torch DataLoader configured from the parameters.

        Args:
            dataset: The dataset to iterate over.
            batch_size (int): Number of samples per batch.
            shuffle (bool): Whether to reshuffle the data every epoch.

        Returns:
            TorchDataLoader: The configured dataloader.
        """
        params = self.updated_parameters_dict
        num_workers = params.get("num_workers_dataloader", 1)
        loader_kwargs = {
            "batch_size": batch_size,
            "shuffle": shuffle,
            "num_workers": num_workers,
            "pin_memory": params.get("pin_memory_dataloader", False),
        }
        # torch raises ValueError when prefetch_factor is given while loading
        # happens in the main process (num_workers == 0), so only forward it
        # when worker processes are actually used.
        if num_workers > 0:
            loader_kwargs["prefetch_factor"] = params.get(
                "prefetch_factor_dataloader", 2
            )
        return TorchDataLoader(dataset, **loader_kwargs)

    @property
    def updated_parameters_dict(self):
        """The parameters as extended by the subset data parsers."""
        return self.parameters_dict

    def train_dataloader(self):
        """Return the shuffled training dataloader using the current ``batch_size``."""
        # self.batch_size may have been changed from outside (see the note in
        # __init__ about batch size tuning), so mirror it into the parameters.
        self.updated_parameters_dict["batch_size"] = self.batch_size
        return self._get_dataloader(
            self.training_dataset, self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        """Return the validation dataloader (batch size 1, no shuffling)."""
        return self._get_dataloader(
            self.validation_dataset, batch_size=1, shuffle=False
        )

    def test_dataloader(self):
        """Return the testing dataloader, or None when no testing data was given."""
        if self.test_dataset is None:
            return None
        return self._get_dataloader(self.test_dataset, batch_size=1, shuffle=False)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class GandlfInferenceDatamodule(pl.LightningDataModule):
    """Lightning datamodule serving the data to run inference on.

    For the "rad" modality a dataset is built through
    ``InferenceSubsetDataParserRadiology``; for "path" / "histo" the rows of
    the dataframe are served directly.
    """

    def __init__(self, dataframe, parameters_dict):
        """
        Args:
            dataframe: The data to run inference on (``iterrows()`` is called
                on it for the "path" / "histo" modalities).
            parameters_dict (dict): The overall parameters; must contain
                "modality". It is deep-copied before being handed to the parser.
        """
        super().__init__()
        self.dataframe = dataframe
        params = deepcopy(parameters_dict)
        self.parameters_dict = params
        # Always define the attribute so that non-"rad" modalities do not leave
        # the instance without it.
        self.inference_dataset = None
        if self.parameters_dict["modality"] == "rad":
            radiology_parser = InferenceSubsetDataParserRadiology(
                self.dataframe, params
            )
            self.inference_dataset = radiology_parser.create_subset_dataset()
            self.parameters_dict = (
                radiology_parser.get_params_extended_with_subset_data()
            )

    @property
    def updated_parameters_dict(self):
        """The parameters, extended by the radiology parser when one was used."""
        return self.parameters_dict

    def predict_dataloader(self):
        """Return the iterable used for prediction.

        Returns:
            A torch DataLoader (batch size 1, no shuffling) for "rad", the
            ``iterrows()`` iterator of the dataframe for "path" / "histo",
            and None for any other modality.
        """
        modality = self.parameters_dict["modality"]
        if modality == "rad":
            params = self.updated_parameters_dict
            num_workers = params.get("num_workers_dataloader", 1)
            loader_kwargs = {
                "batch_size": 1,
                "shuffle": False,
                "num_workers": num_workers,
                "pin_memory": params.get("pin_memory_dataloader", False),
            }
            # torch raises ValueError when prefetch_factor is given while
            # loading happens in the main process (num_workers == 0).
            if num_workers > 0:
                loader_kwargs["prefetch_factor"] = params.get(
                    "prefetch_factor_dataloader", 2
                )
            return TorchDataLoader(self.inference_dataset, **loader_kwargs)
        elif modality in ["path", "histo"]:
            return self.dataframe.iterrows()
        return None
|
GANDLF/entrypoints/run.py
CHANGED
|
@@ -14,16 +14,19 @@ from GANDLF.cli import main_run, copyrightMessage
|
|
|
14
14
|
from GANDLF.entrypoints import append_copyright_to_help
|
|
15
15
|
from GANDLF.utils import logger_setup
|
|
16
16
|
|
|
17
|
+
from warnings import warn
|
|
18
|
+
|
|
17
19
|
|
|
18
20
|
def _run(
|
|
19
21
|
config: str,
|
|
20
22
|
input_data: str,
|
|
21
23
|
train_flag: bool,
|
|
22
24
|
model_dir: str,
|
|
23
|
-
device: str,
|
|
24
25
|
reset_flag: bool,
|
|
25
26
|
resume_flag: bool,
|
|
27
|
+
device: Optional[str],
|
|
26
28
|
output_path: Optional[str],
|
|
29
|
+
profile: Optional[bool] = False,
|
|
27
30
|
):
|
|
28
31
|
if model_dir is None and output_path:
|
|
29
32
|
model_dir = output_path
|
|
@@ -48,7 +51,11 @@ def _run(
|
|
|
48
51
|
"'reset' and 'resume' are mutually exclusive; 'resume' will be used."
|
|
49
52
|
)
|
|
50
53
|
reset_flag = False
|
|
51
|
-
|
|
54
|
+
if device is not None:
|
|
55
|
+
warn(
|
|
56
|
+
"Device parameter is deprecated and has no effect. See migration guide docs.",
|
|
57
|
+
DeprecationWarning,
|
|
58
|
+
)
|
|
52
59
|
# TODO: check that output_path is not passed in training mode;
|
|
53
60
|
# maybe user misconfigured the command
|
|
54
61
|
|
|
@@ -56,20 +63,20 @@ def _run(
|
|
|
56
63
|
logging.debug(f"{input_data=}")
|
|
57
64
|
logging.debug(f"{train_flag=}")
|
|
58
65
|
logging.debug(f"{model_dir=}")
|
|
59
|
-
logging.debug(f"{device=}")
|
|
60
66
|
logging.debug(f"{reset_flag=}")
|
|
61
67
|
logging.debug(f"{resume_flag=}")
|
|
62
68
|
logging.debug(f"{output_path=}")
|
|
69
|
+
logging.debug(f"{profile=}")
|
|
63
70
|
|
|
64
71
|
main_run(
|
|
65
72
|
data_csv=input_data,
|
|
66
73
|
config_file=config,
|
|
67
74
|
model_dir=model_dir,
|
|
68
75
|
train_mode=train_flag,
|
|
69
|
-
device=device,
|
|
70
76
|
resume=resume_flag,
|
|
71
77
|
reset=reset_flag,
|
|
72
78
|
output_dir=output_path,
|
|
79
|
+
profile=profile,
|
|
73
80
|
)
|
|
74
81
|
print("Finished.")
|
|
75
82
|
|
|
@@ -105,21 +112,6 @@ def _run(
|
|
|
105
112
|
help="Training: Output directory to save intermediate files and model weights; "
|
|
106
113
|
"inference: location of previous training session output",
|
|
107
114
|
)
|
|
108
|
-
@click.option(
|
|
109
|
-
"--device",
|
|
110
|
-
"-d",
|
|
111
|
-
# TODO: Not sure it's worth to restrict this list. What about other devices?
|
|
112
|
-
# GaNDLF guarantees to work properly with these two options, but
|
|
113
|
-
# other values may be partially working also.
|
|
114
|
-
# * GaNDLF code convert `-1` to `cpu` (i.e. it is expected somebody may pass -1)
|
|
115
|
-
# * `cuda:0` should work also, isn't it? Just would not be treated as `cuda`
|
|
116
|
-
# * Would `mps` work?
|
|
117
|
-
# * int values (like `1`) - are they supported? (legacy mode for cuda https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)
|
|
118
|
-
type=click.Choice(["cuda", "cpu"]),
|
|
119
|
-
required=True, # FIXME: either keep default value, or set required flag
|
|
120
|
-
help="Device to perform requested session on 'cpu' or 'cuda'; "
|
|
121
|
-
"for cuda, ensure CUDA_VISIBLE_DEVICES env var is set",
|
|
122
|
-
)
|
|
123
115
|
@click.option(
|
|
124
116
|
"--reset",
|
|
125
117
|
"-rt",
|
|
@@ -145,18 +137,31 @@ def _run(
|
|
|
145
137
|
default=None,
|
|
146
138
|
help="Output file which will contain the logs.",
|
|
147
139
|
)
|
|
140
|
+
@click.option(
|
|
141
|
+
"--profile",
|
|
142
|
+
"-pf",
|
|
143
|
+
is_flag=True,
|
|
144
|
+
help="Track the run time and memory consumption for each layer",
|
|
145
|
+
)
|
|
146
|
+
@click.option(
|
|
147
|
+
"--device",
|
|
148
|
+
"-d",
|
|
149
|
+
type=str,
|
|
150
|
+
help="DEPRECATED - has no effect, see migration guide docs. Device to run the model on.",
|
|
151
|
+
)
|
|
148
152
|
@append_copyright_to_help
|
|
149
153
|
def new_way(
|
|
150
154
|
config: str,
|
|
151
155
|
input_data: str,
|
|
152
156
|
train: bool,
|
|
153
157
|
model_dir: str,
|
|
154
|
-
device: str,
|
|
155
158
|
reset: bool,
|
|
156
159
|
resume: bool,
|
|
157
160
|
output_path: str,
|
|
158
161
|
raw_input: str,
|
|
162
|
+
profile: bool,
|
|
159
163
|
log_file: str,
|
|
164
|
+
device: str,
|
|
160
165
|
):
|
|
161
166
|
"""Semantic segmentation, regression, and classification for medical images using Deep Learning."""
|
|
162
167
|
|
|
@@ -166,10 +171,11 @@ def new_way(
|
|
|
166
171
|
input_data=input_data,
|
|
167
172
|
train_flag=train,
|
|
168
173
|
model_dir=model_dir,
|
|
169
|
-
device=device,
|
|
170
174
|
reset_flag=reset,
|
|
171
175
|
resume_flag=resume,
|
|
172
176
|
output_path=output_path,
|
|
177
|
+
profile=profile,
|
|
178
|
+
device=device,
|
|
173
179
|
)
|
|
174
180
|
|
|
175
181
|
|
|
@@ -227,15 +233,7 @@ def old_way():
|
|
|
227
233
|
type=str,
|
|
228
234
|
help="Training: Output directory to save intermediate files and model weights; inference: location of previous training session output",
|
|
229
235
|
)
|
|
230
|
-
|
|
231
|
-
"-d",
|
|
232
|
-
"--device",
|
|
233
|
-
default="cuda", # TODO: default value doesn't work as arg is required
|
|
234
|
-
metavar="",
|
|
235
|
-
type=str,
|
|
236
|
-
required=True,
|
|
237
|
-
help="Device to perform requested session on 'cpu' or 'cuda'; for cuda, ensure CUDA_VISIBLE_DEVICES env var is set",
|
|
238
|
-
)
|
|
236
|
+
|
|
239
237
|
parser.add_argument(
|
|
240
238
|
"-rt",
|
|
241
239
|
"--reset",
|
|
@@ -267,6 +265,13 @@ def old_way():
|
|
|
267
265
|
version="%(prog)s v{}".format(version) + "\n\n" + copyrightMessage,
|
|
268
266
|
help="Show program's version number and exit.",
|
|
269
267
|
)
|
|
268
|
+
parser.add_argument(
|
|
269
|
+
"-d",
|
|
270
|
+
"--device",
|
|
271
|
+
metavar="",
|
|
272
|
+
type=str,
|
|
273
|
+
help="DEPRECATED - has no effect, see migration guide docs. Device to run the model on.",
|
|
274
|
+
)
|
|
270
275
|
|
|
271
276
|
# This is a dummy argument that exists to trigger MLCube mounting requirements.
|
|
272
277
|
# Do not remove.
|
|
@@ -282,10 +287,10 @@ def old_way():
|
|
|
282
287
|
input_data=args.inputdata,
|
|
283
288
|
train_flag=args.train,
|
|
284
289
|
model_dir=args.modeldir,
|
|
285
|
-
device=args.device,
|
|
286
290
|
reset_flag=args.reset,
|
|
287
291
|
resume_flag=args.resume,
|
|
288
292
|
output_path=args.outputdir,
|
|
293
|
+
device=args.device,
|
|
289
294
|
)
|
|
290
295
|
|
|
291
296
|
|
GANDLF/inference_manager.py
CHANGED
|
@@ -4,16 +4,16 @@ from typing import Optional
|
|
|
4
4
|
|
|
5
5
|
import pandas as pd
|
|
6
6
|
import torch
|
|
7
|
-
import torch.nn.functional as F
|
|
8
|
-
from GANDLF.compute import inference_loop
|
|
9
7
|
from GANDLF.utils import get_unique_timestamp
|
|
8
|
+
import lightning.pytorch as pl
|
|
9
|
+
from GANDLF.models.lightning_module import GandlfLightningModule
|
|
10
|
+
from GANDLF.data.lightning_datamodule import GandlfInferenceDatamodule
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
def InferenceManager(
|
|
13
14
|
dataframe: pd.DataFrame,
|
|
14
15
|
modelDir: str,
|
|
15
16
|
parameters: dict,
|
|
16
|
-
device: str,
|
|
17
17
|
outputDir: Optional[str] = None,
|
|
18
18
|
) -> None:
|
|
19
19
|
"""
|
|
@@ -27,7 +27,7 @@ def InferenceManager(
|
|
|
27
27
|
outputDir (Optional[str], optional): The output directory for the inference results. Defaults to None.
|
|
28
28
|
"""
|
|
29
29
|
if outputDir is None:
|
|
30
|
-
outputDir = os.path.join(modelDir, get_unique_timestamp())
|
|
30
|
+
outputDir = os.path.join(modelDir, get_unique_timestamp(), "output_inference")
|
|
31
31
|
print(
|
|
32
32
|
"Output directory not provided, creating a new directory with a unique timestamp: ",
|
|
33
33
|
outputDir,
|
|
@@ -43,6 +43,35 @@ def InferenceManager(
|
|
|
43
43
|
n_folds = parameters["nested_training"]["validation"]
|
|
44
44
|
modelDir_split = modelDir.split(",") if "," in modelDir else [modelDir]
|
|
45
45
|
|
|
46
|
+
# This should be handled by config parser
|
|
47
|
+
accelerator = parameters.get("accelerator", "auto")
|
|
48
|
+
allowed_accelerators = ["cpu", "gpu", "auto"]
|
|
49
|
+
# codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
|
|
50
|
+
assert (
|
|
51
|
+
accelerator in allowed_accelerators
|
|
52
|
+
), f"Invalid accelerator selected: {accelerator}. Please select from {allowed_accelerators}"
|
|
53
|
+
strategy = parameters.get("strategy", "auto")
|
|
54
|
+
allowed_strategies = ["auto", "ddp"]
|
|
55
|
+
# codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
|
|
56
|
+
assert (
|
|
57
|
+
strategy in allowed_strategies
|
|
58
|
+
), f"Invalid strategy selected: {strategy}. Please select from {allowed_strategies}"
|
|
59
|
+
precision = parameters.get("precision", "32")
|
|
60
|
+
allowed_precisions = [
|
|
61
|
+
"64",
|
|
62
|
+
"64-true",
|
|
63
|
+
"32",
|
|
64
|
+
"32-true",
|
|
65
|
+
"16",
|
|
66
|
+
"16-mixed",
|
|
67
|
+
"bf16",
|
|
68
|
+
"bf16-mixed",
|
|
69
|
+
]
|
|
70
|
+
# codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
|
|
71
|
+
assert (
|
|
72
|
+
precision in allowed_precisions
|
|
73
|
+
), f"Invalid precision selected: {precision}. Please select from {allowed_precisions}"
|
|
74
|
+
|
|
46
75
|
averaged_probs_list = []
|
|
47
76
|
for current_modelDir in modelDir_split:
|
|
48
77
|
fold_dirs = (
|
|
@@ -55,7 +84,6 @@ def InferenceManager(
|
|
|
55
84
|
else [current_modelDir]
|
|
56
85
|
)
|
|
57
86
|
|
|
58
|
-
probs_list = []
|
|
59
87
|
is_classification = parameters["problem_type"] == "classification"
|
|
60
88
|
parameters["model"].setdefault("type", "torch")
|
|
61
89
|
class_list = (
|
|
@@ -63,27 +91,43 @@ def InferenceManager(
|
|
|
63
91
|
if is_classification
|
|
64
92
|
else None
|
|
65
93
|
)
|
|
66
|
-
|
|
94
|
+
probs_list = None
|
|
67
95
|
for fold_dir in fold_dirs:
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
96
|
+
trainer = pl.Trainer(
|
|
97
|
+
accelerator=accelerator,
|
|
98
|
+
strategy=strategy,
|
|
99
|
+
fast_dev_run=False,
|
|
100
|
+
devices=parameters.get("devices", "auto"),
|
|
101
|
+
num_nodes=parameters.get("num_nodes", 1),
|
|
102
|
+
precision=precision,
|
|
103
|
+
gradient_clip_algorithm=parameters["clip_mode"],
|
|
104
|
+
gradient_clip_val=parameters["clip_grad"],
|
|
105
|
+
max_epochs=parameters["num_epochs"],
|
|
106
|
+
sync_batchnorm=False,
|
|
107
|
+
enable_checkpointing=False,
|
|
108
|
+
logger=False,
|
|
109
|
+
num_sanity_val_steps=0,
|
|
75
110
|
)
|
|
111
|
+
datamodule = GandlfInferenceDatamodule(dataframe, parameters)
|
|
112
|
+
parameters = datamodule.updated_parameters_dict
|
|
113
|
+
lightning_module = GandlfLightningModule(parameters, output_dir=fold_dir)
|
|
76
114
|
|
|
115
|
+
if parameters.get("auto_batch_size_find", False):
|
|
116
|
+
print(
|
|
117
|
+
"Auto batch size find is not supported in inference. Dataloader batch size is always 1."
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
trainer.predict(lightning_module, datamodule=datamodule)
|
|
77
121
|
if is_classification:
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
122
|
+
prob_values_for_all_subjects_in_fold = list(
|
|
123
|
+
lightning_module.subject_classification_class_probabilities.values()
|
|
124
|
+
)
|
|
125
|
+
if prob_values_for_all_subjects_in_fold:
|
|
126
|
+
probs_list = torch.stack(
|
|
127
|
+
prob_values_for_all_subjects_in_fold, dim=1
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
if is_classification and probs_list is not None:
|
|
87
131
|
averaged_probs_list.append(torch.mean(probs_list, 0))
|
|
88
132
|
|
|
89
133
|
# this logic should be changed if we want to do multi-fold inference for histo images
|
|
@@ -93,9 +137,9 @@ def InferenceManager(
|
|
|
93
137
|
)
|
|
94
138
|
averaged_probs_df["SubjectID"] = dataframe.iloc[:, 0]
|
|
95
139
|
|
|
96
|
-
averaged_probs_across_models =
|
|
97
|
-
torch.stack(averaged_probs_list), 0
|
|
98
|
-
)
|
|
140
|
+
averaged_probs_across_models = (
|
|
141
|
+
torch.mean(torch.stack(averaged_probs_list), 0).cpu().numpy()
|
|
142
|
+
)
|
|
99
143
|
averaged_probs_df[class_list] = averaged_probs_across_models
|
|
100
144
|
averaged_probs_df["PredictedClass"] = [
|
|
101
145
|
class_list[idx] for idx in averaged_probs_across_models.argmax(axis=1)
|
GANDLF/losses/__init__.py
CHANGED
|
@@ -13,7 +13,6 @@ from .segmentation import (
|
|
|
13
13
|
from .regression import CE, CEL, MSE_loss, L1_loss
|
|
14
14
|
from .hybrid import DCCE, DCCE_Logits, DC_Focal
|
|
15
15
|
|
|
16
|
-
|
|
17
16
|
# global defines for the losses
|
|
18
17
|
global_losses_dict = {
|
|
19
18
|
"dc": MCD_loss,
|
|
@@ -38,3 +37,26 @@ global_losses_dict = {
|
|
|
38
37
|
"focal": FocalLoss,
|
|
39
38
|
"dc_focal": DC_Focal,
|
|
40
39
|
}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_loss(params: dict) -> object:
    """
    Resolve the configured loss function from the global loss registry.

    Args:
        params (dict): The parameters' dictionary; "loss_function" is either
            the loss name or a dict whose first key is the loss name.

    Returns:
        loss (object): The entry of ``global_losses_dict`` for that name.
    """
    requested_loss = params["loss_function"]
    # TODO This check looks like legacy code, should we have it?
    if isinstance(requested_loss, dict):
        # Dict form: only the first key names the loss.
        loss_name = list(requested_loss.keys())[0]
    else:
        loss_name = requested_loss
    chosen_loss = loss_name.lower()

    assert (
        chosen_loss in global_losses_dict
    ), f"Could not find the requested loss function '{params['loss_function']}'"
    return global_losses_dict[chosen_loss]
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
|
|
4
|
+
from GANDLF.losses import get_loss
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AbstractLossCalculator(ABC):
    """Base class for callables that turn a prediction and a target into a loss.

    Construction stores the parameters and resolves the configured loss into
    ``self.loss``; subclasses implement ``__call__``.
    """

    def __init__(self, params: dict):
        """
        Args:
            params (dict): The parameters' dictionary, kept as ``self.params``.
        """
        super().__init__()
        self.params = params
        self._initialize_loss()

    def _initialize_loss(self):
        """Resolve the loss selected in ``self.params`` and store it as ``self.loss``."""
        resolved_loss = get_loss(self.params)
        self.loss = resolved_loss

    @abstractmethod
    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, *args
    ) -> torch.Tensor:
        """Compute the loss for ``prediction`` against ``target``."""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class LossCalculatorSDNet(AbstractLossCalculator):
    """Loss calculator for the "sdnet" architecture.

    When the model yields its multi-part output, the loss is a weighted sum of
    a segmentation term (the configured loss), an L1 reconstruction term, a
    KLD term and an MSE cycle term. A single-tensor prediction is scored with
    the configured loss only.
    """

    def __init__(self, params):
        """
        Args:
            params (dict): The parameters' dictionary.
        """
        super().__init__(params)
        # The auxiliary terms are fixed parts of the objective and must not
        # follow the user-configured loss: they are invoked with different
        # arities below (kld with two arguments, l1/mse with three).
        # NOTE(review): assumes "l1", "kld" and "mse" are registered in
        # global_losses_dict — get_loss asserts at construction otherwise.
        self.l1_loss = get_loss({"loss_function": "l1"})
        self.kld_loss = get_loss({"loss_function": "kld"})
        self.mse_loss = get_loss({"loss_function": "mse"})

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, *args):
        """
        Args:
            prediction: Either a single tensor, or a sequence whose elements
                0-4 feed the segmentation, reconstruction, KLD (elements 2 and
                3) and cycle (elements 2 and 4) terms.
            target (torch.Tensor): The ground truth for the segmentation term.
            *args: ``args[0]`` must be the input image when ``prediction`` is
                the multi-part sequence; its first channel is the
                reconstruction target.

        Returns:
            torch.Tensor: The computed loss.
        """
        # The composite branch indexes prediction[1]..prediction[4], so it can
        # only apply to a multi-part (non-tensor) output. The previous
        # `len(prediction) < 2` guard selected it exactly when that indexing
        # had to fail.
        is_multi_part = not isinstance(prediction, torch.Tensor) and len(prediction) > 1
        if is_multi_part:
            image: torch.Tensor = args[0]
            loss_seg = self.loss(prediction[0], target.squeeze(-1), self.params)
            loss_reco = self.l1_loss(prediction[1], image[:, :1, ...], None)
            loss_kld = self.kld_loss(prediction[2], prediction[3])
            loss_cycle = self.mse_loss(prediction[2], prediction[4], None)
            return 0.01 * loss_kld + loss_reco + 10 * loss_seg + loss_cycle
        else:
            return self.loss(prediction, target, self.params)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class LossCalculatorDeepSupervision(AbstractLossCalculator):
    """Loss calculator that sums a weighted loss over several prediction outputs."""

    def __init__(self, params):
        """
        Args:
            params (dict): The parameters' dictionary.
        """
        super().__init__(params)
        # Carried over from the existing GaNDLF code. The list is fixed at
        # four weights, so at most four outputs can be weighted; it is unclear
        # whether this should stay hard-coded.
        self.loss_weights = [0.5, 0.25, 0.175, 0.075]

    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, *args
    ) -> torch.Tensor:
        """
        Args:
            prediction: Iterable of outputs; output ``i`` is compared against
                ``target[i]`` and scaled by ``loss_weights[i]``.
            target: Indexable ground truths, one per output.

        Returns:
            torch.Tensor: The sum of the weighted per-output losses.
        """
        weighted_terms = [
            self.loss(level_prediction, target[level], self.params)
            * self.loss_weights[level]
            for level, level_prediction in enumerate(prediction)
        ]
        return torch.stack(weighted_terms).sum()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class LossCalculatorSimple(AbstractLossCalculator):
    """Loss calculator that applies the configured loss directly."""

    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, *args
    ) -> torch.Tensor:
        """Return the configured loss for ``prediction`` against ``target``."""
        loss_value = self.loss(prediction, target, self.params)
        return loss_value
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class LossCalculatorFactory:
    """Chooses the loss calculator matching the configured model architecture."""

    def __init__(self, params: dict):
        """
        Args:
            params (dict): The parameters' dictionary; ``params["model"]["architecture"]``
                is read when the calculator is requested.
        """
        self.params = params

    def get_loss_calculator(self) -> AbstractLossCalculator:
        """
        Returns:
            AbstractLossCalculator: The SDNet calculator for "sdnet", the deep
            supervision calculator when the architecture name contains "deep"
            (case-insensitive), and the simple calculator otherwise.
        """
        architecture = self.params["model"]["architecture"]
        # The "sdnet" comparison is case-sensitive and is checked first.
        if architecture == "sdnet":
            return LossCalculatorSDNet(self.params)
        if "deep" in architecture.lower():
            return LossCalculatorDeepSupervision(self.params)
        return LossCalculatorSimple(self.params)
|
GANDLF/losses/segmentation.py
CHANGED
|
@@ -87,7 +87,8 @@ def generic_loss_calculator(
|
|
|
87
87
|
accumulated_loss = 0
|
|
88
88
|
# default to a ridiculous value so that it is ignored by default
|
|
89
89
|
ignore_class = -1e10 if ignore_class is None else ignore_class
|
|
90
|
-
|
|
90
|
+
predicted = predicted.squeeze(-1)
|
|
91
|
+
target = target.squeeze(-1)
|
|
91
92
|
for class_index in range(num_class):
|
|
92
93
|
if class_index != ignore_class:
|
|
93
94
|
current_loss = loss_criteria(
|
|
@@ -334,7 +335,7 @@ def FocalLoss(
|
|
|
334
335
|
torch.Tensor: Computed focal loss for a single class.
|
|
335
336
|
"""
|
|
336
337
|
ce_loss = torch.nn.CrossEntropyLoss(reduce=False)
|
|
337
|
-
logpt = ce_loss(preds, target)
|
|
338
|
+
logpt = ce_loss(preds.squeeze(-1), target.squeeze(-1))
|
|
338
339
|
pt = torch.exp(-logpt)
|
|
339
340
|
loss = ((1 - pt) ** gamma) * logpt
|
|
340
341
|
return_loss = loss.sum()
|
GANDLF/metrics/__init__.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""
|
|
2
2
|
All the metrics are to be called from here
|
|
3
3
|
"""
|
|
4
|
+
from warnings import warn
|
|
4
5
|
from typing import Union
|
|
5
6
|
|
|
6
7
|
from GANDLF.losses.regression import MSE_loss, CEL
|
|
@@ -40,6 +41,7 @@ from .synthesis import (
|
|
|
40
41
|
)
|
|
41
42
|
import GANDLF.metrics.classification as classification
|
|
42
43
|
import GANDLF.metrics.regression as regression
|
|
44
|
+
from .segmentation_panoptica import generate_instance_segmentation
|
|
43
45
|
|
|
44
46
|
|
|
45
47
|
# global defines for the metrics
|
|
@@ -100,6 +102,30 @@ surface_distance_ids = [
|
|
|
100
102
|
]
|
|
101
103
|
|
|
102
104
|
|
|
105
|
+
def get_metrics(params: dict) -> dict:
    """
    Returns a dictionary containing the calculators of the requested metrics.

    Args:
        params (dict): A dictionary containing the overall training parameters;
            the entries of "metrics" are the requested metric names.

    Returns:
        metric_calculators (dict): Maps each lower-cased metric name found in
            the global metrics dictionary to its calculator. Unknown names
            are skipped with a UserWarning.
    """
    metric_calculators = {}
    for requested_name in params["metrics"]:
        # Lookup is case-insensitive; the lower-cased name is also the result key.
        metric_name = requested_name.lower()
        if metric_name in global_metrics_dict:
            metric_calculators[metric_name] = global_metrics_dict[metric_name]
            continue
        warn(
            f"Metric {metric_name} not found in global metrics dictionary, it will not be used.",
            UserWarning,
        )
    return metric_calculators
|
|
127
|
+
|
|
128
|
+
|
|
103
129
|
def overall_stats(predictions, ground_truth, params) -> dict[str, Union[float, list]]:
|
|
104
130
|
"""
|
|
105
131
|
Generates a dictionary of metrics calculated on the overall predictions and ground truths.
|
GANDLF/metrics/generic.py
CHANGED
|
@@ -34,7 +34,7 @@ def generic_function_output_with_check(
|
|
|
34
34
|
print(
|
|
35
35
|
"WARNING: Negative values detected in prediction, cannot compute torchmetrics calculations."
|
|
36
36
|
)
|
|
37
|
-
return torch.
|
|
37
|
+
return torch.tensor(0, device=prediction.device)
|
|
38
38
|
else:
|
|
39
39
|
# I need to do this with try-except, otherwise for binary problems it will
|
|
40
40
|
# raise and error as the binary metrics do not have .num_classes
|