GANDLF 0.1.3.dev20250319__py3-none-any.whl → 0.1.4.dev20250503__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of GANDLF might be problematic. Click here for more details.
- GANDLF/cli/deploy.py +2 -2
- GANDLF/cli/generate_metrics.py +21 -0
- GANDLF/cli/main_run.py +4 -12
- GANDLF/compute/__init__.py +0 -2
- GANDLF/compute/forward_pass.py +0 -1
- GANDLF/compute/generic.py +107 -2
- GANDLF/compute/inference_loop.py +4 -4
- GANDLF/compute/loss_and_metric.py +1 -2
- GANDLF/compute/training_loop.py +10 -10
- GANDLF/config_manager.py +26 -716
- GANDLF/configuration/__init__.py +0 -0
- GANDLF/configuration/default_config.py +73 -0
- GANDLF/configuration/differential_privacy_config.py +16 -0
- GANDLF/configuration/exclude_parameters.py +1 -0
- GANDLF/configuration/model_config.py +82 -0
- GANDLF/configuration/nested_training_config.py +25 -0
- GANDLF/configuration/optimizer_config.py +121 -0
- GANDLF/configuration/parameters_config.py +10 -0
- GANDLF/configuration/patch_sampler_config.py +11 -0
- GANDLF/configuration/post_processing_config.py +10 -0
- GANDLF/configuration/pre_processing_config.py +94 -0
- GANDLF/configuration/scheduler_config.py +90 -0
- GANDLF/configuration/user_defined_config.py +131 -0
- GANDLF/configuration/utils.py +96 -0
- GANDLF/configuration/validators.py +479 -0
- GANDLF/data/__init__.py +14 -16
- GANDLF/data/lightning_datamodule.py +119 -0
- GANDLF/entrypoints/run.py +29 -35
- GANDLF/inference_manager.py +69 -25
- GANDLF/losses/__init__.py +23 -1
- GANDLF/losses/loss_calculators.py +79 -0
- GANDLF/losses/segmentation.py +3 -2
- GANDLF/metrics/__init__.py +26 -0
- GANDLF/metrics/generic.py +1 -1
- GANDLF/metrics/metric_calculators.py +102 -0
- GANDLF/metrics/panoptica_config_brats.yaml +50 -0
- GANDLF/metrics/segmentation_panoptica.py +35 -0
- GANDLF/models/__init__.py +8 -3
- GANDLF/models/lightning_module.py +2102 -0
- GANDLF/optimizers/__init__.py +4 -8
- GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
- GANDLF/schedulers/__init__.py +6 -2
- GANDLF/training_manager.py +159 -69
- GANDLF/utils/__init__.py +4 -3
- GANDLF/utils/imaging.py +121 -2
- GANDLF/utils/modelio.py +9 -7
- GANDLF/utils/pred_target_processors.py +71 -0
- GANDLF/utils/write_parse.py +1 -1
- GANDLF/version.py +1 -1
- {gandlf-0.1.3.dev20250319.dist-info → gandlf-0.1.4.dev20250503.dist-info}/METADATA +14 -8
- {gandlf-0.1.3.dev20250319.dist-info → gandlf-0.1.4.dev20250503.dist-info}/RECORD +55 -32
- {gandlf-0.1.3.dev20250319.dist-info → gandlf-0.1.4.dev20250503.dist-info}/WHEEL +1 -1
- {gandlf-0.1.3.dev20250319.dist-info → gandlf-0.1.4.dev20250503.dist-info}/entry_points.txt +0 -0
- {gandlf-0.1.3.dev20250319.dist-info → gandlf-0.1.4.dev20250503.dist-info/licenses}/LICENSE +0 -0
- {gandlf-0.1.3.dev20250319.dist-info → gandlf-0.1.4.dev20250503.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
import lightning.pytorch as pl
|
|
2
|
+
from GANDLF.compute.generic import (
|
|
3
|
+
TrainingSubsetDataParser,
|
|
4
|
+
ValidationSubsetDataParser,
|
|
5
|
+
TestSubsetDataParser,
|
|
6
|
+
InferenceSubsetDataParserRadiology,
|
|
7
|
+
)
|
|
8
|
+
from torch.utils.data import DataLoader as TorchDataLoader
|
|
9
|
+
from copy import deepcopy
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class GandlfTrainingDatamodule(pl.LightningDataModule):
    """Lightning data module serving GaNDLF training, validation and test subsets.

    The subset parsers are run once at construction time. Each parser may extend
    the parameters dictionary, and the final extended dictionary is exposed as
    ``updated_parameters_dict``.

    Args:
        data_dict_files (dict): Mapping with keys ``"training"`` and
            ``"validation"`` and an optional ``"testing"``, each passed to the
            corresponding subset parser.
        parameters_dict (dict): The overall parameters dictionary. It is deep-copied
            before parsing, so the caller's dictionary is not modified.
    """

    def __init__(self, data_dict_files: dict, parameters_dict: dict):
        super().__init__()

        # Batch size here and reinitialization of dataloader parsers is used
        # in automatic batch size tuning
        self.batch_size = parameters_dict["batch_size"]

        # This init procedure is extreme hack, but the only way to get around the
        # need to modify the parameters dict during the parsing procedure
        params = deepcopy(parameters_dict)

        train_subset_parser = TrainingSubsetDataParser(
            data_dict_files["training"], params
        )
        self.training_dataset = train_subset_parser.create_subset_dataset()
        params = train_subset_parser.get_params_extended_with_subset_data()

        val_subset_parser = ValidationSubsetDataParser(
            data_dict_files["validation"], params
        )
        self.validation_dataset = val_subset_parser.create_subset_dataset()
        params = val_subset_parser.get_params_extended_with_subset_data()

        testing_data = data_dict_files.get("testing", None)
        self.test_dataset = None
        if testing_data is not None:
            test_subset_parser = TestSubsetDataParser(testing_data, params)
            self.test_dataset = test_subset_parser.create_subset_dataset()
            params = test_subset_parser.get_params_extended_with_subset_data()

        self.parameters_dict = params

    def _get_dataloader(self, dataset, batch_size: int, shuffle: bool):
        """Build a torch DataLoader using the dataloader options from the parameters.

        Args:
            dataset: The dataset to wrap.
            batch_size (int): Batch size for the loader.
            shuffle (bool): Whether to shuffle samples each epoch.

        Returns:
            torch.utils.data.DataLoader: The configured loader.
        """
        params = self.updated_parameters_dict
        num_workers = params.get("num_workers_dataloader", 1)
        loader_kwargs = {
            "batch_size": batch_size,
            "shuffle": shuffle,
            "num_workers": num_workers,
            "pin_memory": params.get("pin_memory_dataloader", False),
        }
        # torch's DataLoader raises ValueError if prefetch_factor is given while
        # num_workers == 0 (single-process loading), so only forward it when
        # worker processes are actually used.
        if num_workers > 0:
            loader_kwargs["prefetch_factor"] = params.get(
                "prefetch_factor_dataloader", 2
            )
        return TorchDataLoader(dataset, **loader_kwargs)

    @property
    def updated_parameters_dict(self):
        """The parameters dictionary as extended by the subset parsers."""
        return self.parameters_dict

    def train_dataloader(self):
        """Return the shuffled training loader.

        The current ``self.batch_size`` is written back into the parameters so
        that a batch size changed by automatic batch-size tuning is honoured.
        """
        self.updated_parameters_dict["batch_size"] = self.batch_size
        return self._get_dataloader(
            self.training_dataset, self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        """Return the validation loader (batch size 1, no shuffling)."""
        return self._get_dataloader(
            self.validation_dataset, batch_size=1, shuffle=False
        )

    def test_dataloader(self):
        """Return the test loader (batch size 1), or None when no testing data was given."""
        if self.test_dataset is None:
            return None
        return self._get_dataloader(self.test_dataset, batch_size=1, shuffle=False)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class GandlfInferenceDatamodule(pl.LightningDataModule):
    """Lightning data module serving inference data for GaNDLF.

    Only the ``"rad"`` modality builds a torch dataset here. The ``"path"`` and
    ``"histo"`` modalities are served as a raw ``DataFrame.iterrows()`` iterator.

    Args:
        dataframe: The inference data table. It is passed to the radiology
            subset parser, or iterated row-wise for the path/histo modalities.
        parameters_dict (dict): The overall parameters dictionary. It is
            deep-copied, so the caller's dictionary is not modified.
    """

    def __init__(self, dataframe, parameters_dict):
        super().__init__()
        self.dataframe = dataframe
        params = deepcopy(parameters_dict)
        self.parameters_dict = params
        if self.parameters_dict["modality"] == "rad":
            rad_parser = InferenceSubsetDataParserRadiology(self.dataframe, params)
            self.inference_dataset = rad_parser.create_subset_dataset()
            # The parser may extend the parameters; keep its extended version.
            self.parameters_dict = rad_parser.get_params_extended_with_subset_data()

    @property
    def updated_parameters_dict(self):
        """The parameters dictionary, as extended by the subset parser (if one ran)."""
        return self.parameters_dict

    def predict_dataloader(self):
        """Return the prediction data source for the configured modality.

        Returns:
            A torch DataLoader (batch size 1, no shuffling) for ``"rad"``; a
            ``DataFrame.iterrows()`` iterator for ``"path"``/``"histo"``; and
            None for any other modality.
        """
        modality = self.parameters_dict["modality"]
        if modality == "rad":
            params = self.updated_parameters_dict
            num_workers = params.get("num_workers_dataloader", 1)
            loader_kwargs = {
                "batch_size": 1,
                "shuffle": False,
                "num_workers": num_workers,
                "pin_memory": params.get("pin_memory_dataloader", False),
            }
            # torch's DataLoader raises ValueError if prefetch_factor is given
            # while num_workers == 0, so only forward it when workers are used.
            if num_workers > 0:
                loader_kwargs["prefetch_factor"] = params.get(
                    "prefetch_factor_dataloader", 2
                )
            return TorchDataLoader(self.inference_dataset, **loader_kwargs)
        elif modality in ["path", "histo"]:
            return self.dataframe.iterrows()
        # NOTE(review): other modalities have no data source here; None is returned.
        return None
|
GANDLF/entrypoints/run.py
CHANGED
|
@@ -14,17 +14,19 @@ from GANDLF.cli import main_run, copyrightMessage
|
|
|
14
14
|
from GANDLF.entrypoints import append_copyright_to_help
|
|
15
15
|
from GANDLF.utils import logger_setup
|
|
16
16
|
|
|
17
|
+
from warnings import warn
|
|
18
|
+
|
|
17
19
|
|
|
18
20
|
def _run(
|
|
19
21
|
config: str,
|
|
20
22
|
input_data: str,
|
|
21
23
|
train_flag: bool,
|
|
22
24
|
model_dir: str,
|
|
23
|
-
device: str,
|
|
24
25
|
reset_flag: bool,
|
|
25
26
|
resume_flag: bool,
|
|
27
|
+
device: Optional[str],
|
|
26
28
|
output_path: Optional[str],
|
|
27
|
-
|
|
29
|
+
profile: Optional[bool] = False,
|
|
28
30
|
):
|
|
29
31
|
if model_dir is None and output_path:
|
|
30
32
|
model_dir = output_path
|
|
@@ -49,7 +51,11 @@ def _run(
|
|
|
49
51
|
"'reset' and 'resume' are mutually exclusive; 'resume' will be used."
|
|
50
52
|
)
|
|
51
53
|
reset_flag = False
|
|
52
|
-
|
|
54
|
+
if device is not None:
|
|
55
|
+
warn(
|
|
56
|
+
"Device parameter is deprecated and has no effect. See migration guide docs.",
|
|
57
|
+
DeprecationWarning,
|
|
58
|
+
)
|
|
53
59
|
# TODO: check that output_path is not passed in training mode;
|
|
54
60
|
# maybe user misconfigured the command
|
|
55
61
|
|
|
@@ -57,22 +63,20 @@ def _run(
|
|
|
57
63
|
logging.debug(f"{input_data=}")
|
|
58
64
|
logging.debug(f"{train_flag=}")
|
|
59
65
|
logging.debug(f"{model_dir=}")
|
|
60
|
-
logging.debug(f"{device=}")
|
|
61
66
|
logging.debug(f"{reset_flag=}")
|
|
62
67
|
logging.debug(f"{resume_flag=}")
|
|
63
68
|
logging.debug(f"{output_path=}")
|
|
64
|
-
logging.debug(f"{
|
|
69
|
+
logging.debug(f"{profile=}")
|
|
65
70
|
|
|
66
71
|
main_run(
|
|
67
72
|
data_csv=input_data,
|
|
68
73
|
config_file=config,
|
|
69
74
|
model_dir=model_dir,
|
|
70
75
|
train_mode=train_flag,
|
|
71
|
-
device=device,
|
|
72
76
|
resume=resume_flag,
|
|
73
77
|
reset=reset_flag,
|
|
74
78
|
output_dir=output_path,
|
|
75
|
-
|
|
79
|
+
profile=profile,
|
|
76
80
|
)
|
|
77
81
|
print("Finished.")
|
|
78
82
|
|
|
@@ -108,21 +112,6 @@ def _run(
|
|
|
108
112
|
help="Training: Output directory to save intermediate files and model weights; "
|
|
109
113
|
"inference: location of previous training session output",
|
|
110
114
|
)
|
|
111
|
-
@click.option(
|
|
112
|
-
"--device",
|
|
113
|
-
"-d",
|
|
114
|
-
# TODO: Not sure it's worth to restrict this list. What about other devices?
|
|
115
|
-
# GaNDLF guarantees to work properly with these two options, but
|
|
116
|
-
# other values may be partially working also.
|
|
117
|
-
# * GaNDLF code convert `-1` to `cpu` (i.e. it is expected somebody may pass -1)
|
|
118
|
-
# * `cuda:0` should work also, isn't it? Just would not be treated as `cuda`
|
|
119
|
-
# * Would `mps` work?
|
|
120
|
-
# * int values (like `1`) - are they supported? (legacy mode for cuda https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)
|
|
121
|
-
type=click.Choice(["cuda", "cpu"]),
|
|
122
|
-
required=True, # FIXME: either keep default value, or set required flag
|
|
123
|
-
help="Device to perform requested session on 'cpu' or 'cuda'; "
|
|
124
|
-
"for cuda, ensure CUDA_VISIBLE_DEVICES env var is set",
|
|
125
|
-
)
|
|
126
115
|
@click.option(
|
|
127
116
|
"--reset",
|
|
128
117
|
"-rt",
|
|
@@ -154,19 +143,25 @@ def _run(
|
|
|
154
143
|
is_flag=True,
|
|
155
144
|
help="Track the run time and memory consumption for each layer",
|
|
156
145
|
)
|
|
146
|
+
@click.option(
|
|
147
|
+
"--device",
|
|
148
|
+
"-d",
|
|
149
|
+
type=str,
|
|
150
|
+
help="DEPRECATED - has no effect, see migration guide docs. Device to run the model on.",
|
|
151
|
+
)
|
|
157
152
|
@append_copyright_to_help
|
|
158
153
|
def new_way(
|
|
159
154
|
config: str,
|
|
160
155
|
input_data: str,
|
|
161
156
|
train: bool,
|
|
162
157
|
model_dir: str,
|
|
163
|
-
device: str,
|
|
164
158
|
reset: bool,
|
|
165
159
|
resume: bool,
|
|
166
160
|
output_path: str,
|
|
167
161
|
raw_input: str,
|
|
168
162
|
profile: bool,
|
|
169
163
|
log_file: str,
|
|
164
|
+
device: str,
|
|
170
165
|
):
|
|
171
166
|
"""Semantic segmentation, regression, and classification for medical images using Deep Learning."""
|
|
172
167
|
|
|
@@ -176,11 +171,11 @@ def new_way(
|
|
|
176
171
|
input_data=input_data,
|
|
177
172
|
train_flag=train,
|
|
178
173
|
model_dir=model_dir,
|
|
179
|
-
device=device,
|
|
180
174
|
reset_flag=reset,
|
|
181
175
|
resume_flag=resume,
|
|
182
176
|
output_path=output_path,
|
|
183
|
-
|
|
177
|
+
profile=profile,
|
|
178
|
+
device=device,
|
|
184
179
|
)
|
|
185
180
|
|
|
186
181
|
|
|
@@ -238,15 +233,7 @@ def old_way():
|
|
|
238
233
|
type=str,
|
|
239
234
|
help="Training: Output directory to save intermediate files and model weights; inference: location of previous training session output",
|
|
240
235
|
)
|
|
241
|
-
|
|
242
|
-
"-d",
|
|
243
|
-
"--device",
|
|
244
|
-
default="cuda", # TODO: default value doesn't work as arg is required
|
|
245
|
-
metavar="",
|
|
246
|
-
type=str,
|
|
247
|
-
required=True,
|
|
248
|
-
help="Device to perform requested session on 'cpu' or 'cuda'; for cuda, ensure CUDA_VISIBLE_DEVICES env var is set",
|
|
249
|
-
)
|
|
236
|
+
|
|
250
237
|
parser.add_argument(
|
|
251
238
|
"-rt",
|
|
252
239
|
"--reset",
|
|
@@ -278,6 +265,13 @@ def old_way():
|
|
|
278
265
|
version="%(prog)s v{}".format(version) + "\n\n" + copyrightMessage,
|
|
279
266
|
help="Show program's version number and exit.",
|
|
280
267
|
)
|
|
268
|
+
parser.add_argument(
|
|
269
|
+
"-d",
|
|
270
|
+
"--device",
|
|
271
|
+
metavar="",
|
|
272
|
+
type=str,
|
|
273
|
+
help="DEPRECATED - has no effect, see migration guide docs. Device to run the model on.",
|
|
274
|
+
)
|
|
281
275
|
|
|
282
276
|
# This is a dummy argument that exists to trigger MLCube mounting requirements.
|
|
283
277
|
# Do not remove.
|
|
@@ -293,10 +287,10 @@ def old_way():
|
|
|
293
287
|
input_data=args.inputdata,
|
|
294
288
|
train_flag=args.train,
|
|
295
289
|
model_dir=args.modeldir,
|
|
296
|
-
device=args.device,
|
|
297
290
|
reset_flag=args.reset,
|
|
298
291
|
resume_flag=args.resume,
|
|
299
292
|
output_path=args.outputdir,
|
|
293
|
+
device=args.device,
|
|
300
294
|
)
|
|
301
295
|
|
|
302
296
|
|
GANDLF/inference_manager.py
CHANGED
|
@@ -4,16 +4,16 @@ from typing import Optional
|
|
|
4
4
|
|
|
5
5
|
import pandas as pd
|
|
6
6
|
import torch
|
|
7
|
-
import torch.nn.functional as F
|
|
8
|
-
from GANDLF.compute import inference_loop
|
|
9
7
|
from GANDLF.utils import get_unique_timestamp
|
|
8
|
+
import lightning.pytorch as pl
|
|
9
|
+
from GANDLF.models.lightning_module import GandlfLightningModule
|
|
10
|
+
from GANDLF.data.lightning_datamodule import GandlfInferenceDatamodule
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
def InferenceManager(
|
|
13
14
|
dataframe: pd.DataFrame,
|
|
14
15
|
modelDir: str,
|
|
15
16
|
parameters: dict,
|
|
16
|
-
device: str,
|
|
17
17
|
outputDir: Optional[str] = None,
|
|
18
18
|
) -> None:
|
|
19
19
|
"""
|
|
@@ -27,7 +27,7 @@ def InferenceManager(
|
|
|
27
27
|
outputDir (Optional[str], optional): The output directory for the inference results. Defaults to None.
|
|
28
28
|
"""
|
|
29
29
|
if outputDir is None:
|
|
30
|
-
outputDir = os.path.join(modelDir, get_unique_timestamp())
|
|
30
|
+
outputDir = os.path.join(modelDir, get_unique_timestamp(), "output_inference")
|
|
31
31
|
print(
|
|
32
32
|
"Output directory not provided, creating a new directory with a unique timestamp: ",
|
|
33
33
|
outputDir,
|
|
@@ -43,6 +43,35 @@ def InferenceManager(
|
|
|
43
43
|
n_folds = parameters["nested_training"]["validation"]
|
|
44
44
|
modelDir_split = modelDir.split(",") if "," in modelDir else [modelDir]
|
|
45
45
|
|
|
46
|
+
# This should be handled by config parser
|
|
47
|
+
accelerator = parameters.get("accelerator", "auto")
|
|
48
|
+
allowed_accelerators = ["cpu", "gpu", "auto"]
|
|
49
|
+
# codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
|
|
50
|
+
assert (
|
|
51
|
+
accelerator in allowed_accelerators
|
|
52
|
+
), f"Invalid accelerator selected: {accelerator}. Please select from {allowed_accelerators}"
|
|
53
|
+
strategy = parameters.get("strategy", "auto")
|
|
54
|
+
allowed_strategies = ["auto", "ddp"]
|
|
55
|
+
# codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
|
|
56
|
+
assert (
|
|
57
|
+
strategy in allowed_strategies
|
|
58
|
+
), f"Invalid strategy selected: {strategy}. Please select from {allowed_strategies}"
|
|
59
|
+
precision = parameters.get("precision", "32")
|
|
60
|
+
allowed_precisions = [
|
|
61
|
+
"64",
|
|
62
|
+
"64-true",
|
|
63
|
+
"32",
|
|
64
|
+
"32-true",
|
|
65
|
+
"16",
|
|
66
|
+
"16-mixed",
|
|
67
|
+
"bf16",
|
|
68
|
+
"bf16-mixed",
|
|
69
|
+
]
|
|
70
|
+
# codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
|
|
71
|
+
assert (
|
|
72
|
+
precision in allowed_precisions
|
|
73
|
+
), f"Invalid precision selected: {precision}. Please select from {allowed_precisions}"
|
|
74
|
+
|
|
46
75
|
averaged_probs_list = []
|
|
47
76
|
for current_modelDir in modelDir_split:
|
|
48
77
|
fold_dirs = (
|
|
@@ -55,7 +84,6 @@ def InferenceManager(
|
|
|
55
84
|
else [current_modelDir]
|
|
56
85
|
)
|
|
57
86
|
|
|
58
|
-
probs_list = []
|
|
59
87
|
is_classification = parameters["problem_type"] == "classification"
|
|
60
88
|
parameters["model"].setdefault("type", "torch")
|
|
61
89
|
class_list = (
|
|
@@ -63,27 +91,43 @@ def InferenceManager(
|
|
|
63
91
|
if is_classification
|
|
64
92
|
else None
|
|
65
93
|
)
|
|
66
|
-
|
|
94
|
+
probs_list = None
|
|
67
95
|
for fold_dir in fold_dirs:
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
96
|
+
trainer = pl.Trainer(
|
|
97
|
+
accelerator=accelerator,
|
|
98
|
+
strategy=strategy,
|
|
99
|
+
fast_dev_run=False,
|
|
100
|
+
devices=parameters.get("devices", "auto"),
|
|
101
|
+
num_nodes=parameters.get("num_nodes", 1),
|
|
102
|
+
precision=precision,
|
|
103
|
+
gradient_clip_algorithm=parameters["clip_mode"],
|
|
104
|
+
gradient_clip_val=parameters["clip_grad"],
|
|
105
|
+
max_epochs=parameters["num_epochs"],
|
|
106
|
+
sync_batchnorm=False,
|
|
107
|
+
enable_checkpointing=False,
|
|
108
|
+
logger=False,
|
|
109
|
+
num_sanity_val_steps=0,
|
|
75
110
|
)
|
|
111
|
+
datamodule = GandlfInferenceDatamodule(dataframe, parameters)
|
|
112
|
+
parameters = datamodule.updated_parameters_dict
|
|
113
|
+
lightning_module = GandlfLightningModule(parameters, output_dir=fold_dir)
|
|
76
114
|
|
|
115
|
+
if parameters.get("auto_batch_size_find", False):
|
|
116
|
+
print(
|
|
117
|
+
"Auto batch size find is not supported in inference. Dataloader batch size is always 1."
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
trainer.predict(lightning_module, datamodule=datamodule)
|
|
77
121
|
if is_classification:
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
122
|
+
prob_values_for_all_subjects_in_fold = list(
|
|
123
|
+
lightning_module.subject_classification_class_probabilities.values()
|
|
124
|
+
)
|
|
125
|
+
if prob_values_for_all_subjects_in_fold:
|
|
126
|
+
probs_list = torch.stack(
|
|
127
|
+
prob_values_for_all_subjects_in_fold, dim=1
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
if is_classification and probs_list is not None:
|
|
87
131
|
averaged_probs_list.append(torch.mean(probs_list, 0))
|
|
88
132
|
|
|
89
133
|
# this logic should be changed if we want to do multi-fold inference for histo images
|
|
@@ -93,9 +137,9 @@ def InferenceManager(
|
|
|
93
137
|
)
|
|
94
138
|
averaged_probs_df["SubjectID"] = dataframe.iloc[:, 0]
|
|
95
139
|
|
|
96
|
-
averaged_probs_across_models =
|
|
97
|
-
torch.stack(averaged_probs_list), 0
|
|
98
|
-
)
|
|
140
|
+
averaged_probs_across_models = (
|
|
141
|
+
torch.mean(torch.stack(averaged_probs_list), 0).cpu().numpy()
|
|
142
|
+
)
|
|
99
143
|
averaged_probs_df[class_list] = averaged_probs_across_models
|
|
100
144
|
averaged_probs_df["PredictedClass"] = [
|
|
101
145
|
class_list[idx] for idx in averaged_probs_across_models.argmax(axis=1)
|
GANDLF/losses/__init__.py
CHANGED
|
@@ -13,7 +13,6 @@ from .segmentation import (
|
|
|
13
13
|
from .regression import CE, CEL, MSE_loss, L1_loss
|
|
14
14
|
from .hybrid import DCCE, DCCE_Logits, DC_Focal
|
|
15
15
|
|
|
16
|
-
|
|
17
16
|
# global defines for the losses
|
|
18
17
|
global_losses_dict = {
|
|
19
18
|
"dc": MCD_loss,
|
|
@@ -38,3 +37,26 @@ global_losses_dict = {
|
|
|
38
37
|
"focal": FocalLoss,
|
|
39
38
|
"dc_focal": DC_Focal,
|
|
40
39
|
}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_loss(params: dict) -> object:
    """
    Look up the loss function requested in the parameters.

    ``params["loss_function"]`` may be either a loss name (string) or a dict
    whose first key is the loss name. The name is lower-cased before lookup.

    Args:
        params (dict): The parameters' dictionary.

    Returns:
        loss (object): The entry of ``global_losses_dict`` for the requested loss.

    Raises:
        AssertionError: If the requested loss name is not registered.
    """
    # TODO This check looks like legacy code, should we have it?
    requested = params["loss_function"]
    loss_name = list(requested.keys())[0] if isinstance(requested, dict) else requested
    loss_name = loss_name.lower()

    assert (
        loss_name in global_losses_dict
    ), f"Could not find the requested loss function '{params['loss_function']}'"

    return global_losses_dict[loss_name]
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
|
|
4
|
+
from GANDLF.losses import get_loss
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AbstractLossCalculator(ABC):
    """Base class for callables that turn a prediction and target into a loss.

    At construction time the loss function selected by ``params`` (via
    ``get_loss``) is stored on ``self.loss``. Subclasses decide how to apply it
    in ``__call__``.

    Args:
        params (dict): The overall parameters dictionary.
    """

    def __init__(self, params: dict):
        super().__init__()
        self.params = params
        self._initialize_loss()

    def _initialize_loss(self):
        """Resolve the configured loss function and cache it on ``self.loss``."""
        self.loss = get_loss(self.params)

    @abstractmethod
    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, *args
    ) -> torch.Tensor:
        """Compute the loss; the meaning of extra positional args is subclass-specific."""
        ...
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class LossCalculatorSDNet(AbstractLossCalculator):
    """Loss calculator for the SDNet architecture.

    When the model returns its sequence of (at least five) outputs, the loss is
    ``0.01 * kld + reco + 10 * seg + cycle``, where:

    * ``seg`` is the configured loss on ``prediction[0]`` vs. the squeezed target;
    * ``reco`` is the L1 loss between ``prediction[1]`` and the first image channel;
    * ``kld`` is the KL-divergence loss on ``prediction[2]`` and ``prediction[3]``;
    * ``cycle`` is the MSE loss between ``prediction[2]`` and ``prediction[4]``.

    For any other prediction (e.g. a single tensor), the configured loss is
    applied directly.
    """

    def __init__(self, params):
        super().__init__(params)
        # The auxiliary terms use fixed loss functions; the user-configured
        # ``loss_function`` (self.loss) is only used for the segmentation term.
        self.l1_loss = get_loss({"loss_function": "l1"})
        self.kld_loss = get_loss({"loss_function": "kld"})
        self.mse_loss = get_loss({"loss_function": "mse"})

    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, *args
    ) -> torch.Tensor:
        """Compute the SDNet loss.

        Args:
            prediction: Either the full sequence of SDNet outputs, or a single tensor.
            target: The ground-truth tensor.
            *args: ``args[0]`` must be the input image when ``prediction`` is
                the full output sequence.

        Returns:
            torch.Tensor: The scalar loss.
        """
        # Only the full output sequence carries the auxiliary terms; indices up
        # to 4 are read below, hence the length check.
        if isinstance(prediction, (list, tuple)) and len(prediction) >= 5:
            image: torch.Tensor = args[0]
            loss_seg = self.loss(prediction[0], target.squeeze(-1), self.params)
            loss_reco = self.l1_loss(prediction[1], image[:, :1, ...], None)
            loss_kld = self.kld_loss(prediction[2], prediction[3])
            loss_cycle = self.mse_loss(prediction[2], prediction[4], None)
            return 0.01 * loss_kld + loss_reco + 10 * loss_seg + loss_cycle
        return self.loss(prediction, target, self.params)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class LossCalculatorDeepSupervision(AbstractLossCalculator):
    """Weighted sum of the configured loss over deep-supervision outputs.

    Output ``i`` of ``prediction`` is compared with ``target[i]``, and the
    result is scaled by ``self.loss_weights[i]``.
    """

    def __init__(self, params):
        super().__init__(params)
        # Fixed per-level weights carried over from the earlier GaNDLF code.
        # NOTE(review): this hard-codes at most 4 supervision levels (not
        # classes); a prediction with more entries raises IndexError in __call__.
        self.loss_weights = [0.5, 0.25, 0.175, 0.075]

    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, *args
    ) -> torch.Tensor:
        """Return the weighted sum of per-level losses as a scalar tensor."""
        weighted_terms = [
            self.loss(level_prediction, target[level], self.params)
            * self.loss_weights[level]
            for level, level_prediction in enumerate(prediction)
        ]
        return torch.stack(weighted_terms).sum()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class LossCalculatorSimple(AbstractLossCalculator):
    """Apply the configured loss directly to the prediction and target."""

    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, *args
    ) -> torch.Tensor:
        # Extra positional arguments are accepted for interface compatibility
        # with the other calculators and are ignored here.
        loss_fn = self.loss
        return loss_fn(prediction, target, self.params)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class LossCalculatorFactory:
    """Choose a loss calculator based on ``params["model"]["architecture"]``.

    * ``"sdnet"`` (exact match) selects ``LossCalculatorSDNet``.
    * Any architecture whose lower-cased name contains ``"deep"`` selects
      ``LossCalculatorDeepSupervision``.
    * Everything else selects ``LossCalculatorSimple``.
    """

    def __init__(self, params: dict):
        self.params = params

    def get_loss_calculator(self) -> AbstractLossCalculator:
        """Instantiate and return the calculator for the configured architecture."""
        architecture = self.params["model"]["architecture"]
        if architecture == "sdnet":
            calculator_cls = LossCalculatorSDNet
        elif "deep" in architecture.lower():
            calculator_cls = LossCalculatorDeepSupervision
        else:
            calculator_cls = LossCalculatorSimple
        return calculator_cls(self.params)
|
GANDLF/losses/segmentation.py
CHANGED
|
@@ -87,7 +87,8 @@ def generic_loss_calculator(
|
|
|
87
87
|
accumulated_loss = 0
|
|
88
88
|
# default to a ridiculous value so that it is ignored by default
|
|
89
89
|
ignore_class = -1e10 if ignore_class is None else ignore_class
|
|
90
|
-
|
|
90
|
+
predicted = predicted.squeeze(-1)
|
|
91
|
+
target = target.squeeze(-1)
|
|
91
92
|
for class_index in range(num_class):
|
|
92
93
|
if class_index != ignore_class:
|
|
93
94
|
current_loss = loss_criteria(
|
|
@@ -334,7 +335,7 @@ def FocalLoss(
|
|
|
334
335
|
torch.Tensor: Computed focal loss for a single class.
|
|
335
336
|
"""
|
|
336
337
|
ce_loss = torch.nn.CrossEntropyLoss(reduce=False)
|
|
337
|
-
logpt = ce_loss(preds, target)
|
|
338
|
+
logpt = ce_loss(preds.squeeze(-1), target.squeeze(-1))
|
|
338
339
|
pt = torch.exp(-logpt)
|
|
339
340
|
loss = ((1 - pt) ** gamma) * logpt
|
|
340
341
|
return_loss = loss.sum()
|
GANDLF/metrics/__init__.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""
|
|
2
2
|
All the metrics are to be called from here
|
|
3
3
|
"""
|
|
4
|
+
from warnings import warn
|
|
4
5
|
from typing import Union
|
|
5
6
|
|
|
6
7
|
from GANDLF.losses.regression import MSE_loss, CEL
|
|
@@ -40,6 +41,7 @@ from .synthesis import (
|
|
|
40
41
|
)
|
|
41
42
|
import GANDLF.metrics.classification as classification
|
|
42
43
|
import GANDLF.metrics.regression as regression
|
|
44
|
+
from .segmentation_panoptica import generate_instance_segmentation
|
|
43
45
|
|
|
44
46
|
|
|
45
47
|
# global defines for the metrics
|
|
@@ -100,6 +102,30 @@ surface_distance_ids = [
|
|
|
100
102
|
]
|
|
101
103
|
|
|
102
104
|
|
|
105
|
+
def get_metrics(params: dict) -> dict:
    """
    Return a dictionary of the metric calculators requested in the parameters.

    Each name in ``params["metrics"]`` is lower-cased and looked up in
    ``global_metrics_dict``. Unknown names trigger a ``UserWarning`` and are skipped.

    Args:
        params (dict): A dictionary containing the overall training parameters.

    Returns:
        metric_calculators (dict): Maps each recognized (lower-cased) metric name
        to its calculator.
    """
    selected_calculators = {}
    for requested_name in params["metrics"]:
        key = requested_name.lower()
        if key in global_metrics_dict:
            selected_calculators[key] = global_metrics_dict[key]
        else:
            warn(
                f"Metric {key} not found in global metrics dictionary, it will not be used.",
                UserWarning,
            )
    return selected_calculators
|
|
127
|
+
|
|
128
|
+
|
|
103
129
|
def overall_stats(predictions, ground_truth, params) -> dict[str, Union[float, list]]:
|
|
104
130
|
"""
|
|
105
131
|
Generates a dictionary of metrics calculated on the overall predictions and ground truths.
|
GANDLF/metrics/generic.py
CHANGED
|
@@ -34,7 +34,7 @@ def generic_function_output_with_check(
|
|
|
34
34
|
print(
|
|
35
35
|
"WARNING: Negative values detected in prediction, cannot compute torchmetrics calculations."
|
|
36
36
|
)
|
|
37
|
-
return torch.
|
|
37
|
+
return torch.tensor(0, device=prediction.device)
|
|
38
38
|
else:
|
|
39
39
|
# I need to do this with try-except, otherwise for binary problems it will
|
|
40
40
|
# raise and error as the binary metrics do not have .num_classes
|