GANDLF 0.1.3.dev20250202__py3-none-any.whl → 0.1.6.dev20251109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of GANDLF might be problematic. Click here for more details.
- GANDLF/cli/deploy.py +2 -2
- GANDLF/cli/generate_metrics.py +35 -1
- GANDLF/cli/main_run.py +4 -10
- GANDLF/compute/__init__.py +0 -2
- GANDLF/compute/forward_pass.py +0 -1
- GANDLF/compute/generic.py +107 -2
- GANDLF/compute/inference_loop.py +4 -4
- GANDLF/compute/loss_and_metric.py +1 -2
- GANDLF/compute/training_loop.py +10 -10
- GANDLF/config_manager.py +33 -717
- GANDLF/configuration/__init__.py +0 -0
- GANDLF/configuration/default_config.py +73 -0
- GANDLF/configuration/differential_privacy_config.py +16 -0
- GANDLF/configuration/exclude_parameters.py +1 -0
- GANDLF/configuration/model_config.py +82 -0
- GANDLF/configuration/nested_training_config.py +25 -0
- GANDLF/configuration/optimizer_config.py +121 -0
- GANDLF/configuration/parameters_config.py +10 -0
- GANDLF/configuration/patch_sampler_config.py +11 -0
- GANDLF/configuration/post_processing_config.py +10 -0
- GANDLF/configuration/pre_processing_config.py +94 -0
- GANDLF/configuration/scheduler_config.py +92 -0
- GANDLF/configuration/user_defined_config.py +131 -0
- GANDLF/configuration/utils.py +96 -0
- GANDLF/configuration/validators.py +479 -0
- GANDLF/data/__init__.py +14 -16
- GANDLF/data/lightning_datamodule.py +119 -0
- GANDLF/entrypoints/run.py +36 -31
- GANDLF/inference_manager.py +69 -25
- GANDLF/losses/__init__.py +23 -1
- GANDLF/losses/loss_calculators.py +79 -0
- GANDLF/losses/segmentation.py +3 -2
- GANDLF/metrics/__init__.py +26 -0
- GANDLF/metrics/generic.py +1 -1
- GANDLF/metrics/metric_calculators.py +102 -0
- GANDLF/metrics/panoptica_config_brats.yaml +56 -0
- GANDLF/metrics/segmentation_panoptica.py +49 -0
- GANDLF/models/__init__.py +8 -3
- GANDLF/models/lightning_module.py +2102 -0
- GANDLF/optimizers/__init__.py +4 -8
- GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
- GANDLF/schedulers/__init__.py +11 -4
- GANDLF/schedulers/wrap_torch.py +15 -3
- GANDLF/training_manager.py +160 -50
- GANDLF/utils/__init__.py +5 -3
- GANDLF/utils/imaging.py +176 -35
- GANDLF/utils/modelio.py +12 -8
- GANDLF/utils/pred_target_processors.py +71 -0
- GANDLF/utils/tensor.py +2 -1
- GANDLF/utils/write_parse.py +1 -1
- GANDLF/version.py +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/METADATA +16 -11
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/RECORD +57 -34
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/WHEEL +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/entry_points.txt +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info/licenses}/LICENSE +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from copy import deepcopy
|
|
3
|
+
from GANDLF.metrics import get_metrics
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from typing import Union
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class AbstractMetricCalculator(ABC):
    """
    Base class for callables that evaluate every configured metric on a
    (prediction, target) pair and return a dict of plain Python results.

    The metric functions are looked up once, at construction time, through
    ``get_metrics(self.params)`` and stored in ``self.metrics_calculators``
    as a mapping of metric name to metric function.
    """

    def __init__(self, params: dict):
        """
        Args:
            params (dict): The GaNDLF parameters. They are deep-copied, so later
                changes made by the caller do not leak into this calculator.
        """
        super().__init__()
        self.params = deepcopy(params)
        self._initialize_metrics_dict()

    def _initialize_metrics_dict(self):
        # metric name -> metric function, resolved from the copied params
        self.metrics_calculators = get_metrics(self.params)

    def _process_metric_value(
        self, metric_value: Union[torch.Tensor, float, int]
    ) -> Union[float, int, list]:
        """
        Convert a metric result into plain Python value(s).

        Args:
            metric_value (Union[torch.Tensor, float, int]): The raw metric output.

        Returns:
            Union[float, int, list]: Python numbers are returned unchanged,
            0-dim tensors become a Python scalar via ``.item()``, and any other
            tensor becomes a (nested) list via ``.tolist()``.
        """
        # plain Python numbers (including int, which has no ``.dim()``) pass through
        if isinstance(metric_value, (int, float)):
            return metric_value
        if metric_value.dim() == 0:
            return metric_value.item()
        return metric_value.tolist()

    @staticmethod
    def _inject_kwargs_into_params(params, **kwargs):
        """
        Overwrite/add every keyword argument as a key of ``params``.

        Note that ``params`` is modified in place and also returned; callers in
        this module pass a fresh deep copy so ``self.params`` is never touched.
        """
        params.update(kwargs)
        return params

    @abstractmethod
    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, **kwargs
    ) -> dict:
        """
        Evaluate every configured metric.

        Args:
            prediction (torch.Tensor): The model output.
            target (torch.Tensor): The ground truth.
            **kwargs: Extra entries injected into a copy of the parameters
                before the metric functions are called.

        Returns:
            dict: Mapping of metric name to its processed (plain Python) value.
        """
        pass
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class MetricCalculatorSDNet(AbstractMetricCalculator):
    """
    Metric calculator used for the ``sdnet`` architecture.

    Only ``prediction[0]`` is scored, against ``target`` with its last
    dimension squeezed.
    """

    def __init__(self, params):
        super().__init__(params)

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, **kwargs):
        # work on a private copy so the injected kwargs never reach self.params
        call_params = self._inject_kwargs_into_params(
            deepcopy(self.params), **kwargs
        )

        return {
            name: self._process_metric_value(
                fn(prediction[0], target.squeeze(-1), call_params).detach().cpu()
            )
            for name, fn in self.metrics_calculators.items()
        }
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class MetricCalculatorDeepSupervision(AbstractMetricCalculator):
    """
    Metric calculator used for deep-supervision architectures.

    ``prediction`` and ``target`` are indexed in lockstep; every metric is
    evaluated on each ``(prediction[i], target[i])`` pair and the per-level
    results are summed. With no levels, each metric reports ``0.0``.
    """

    def __init__(self, params):
        super().__init__(params)

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, **kwargs):
        params = deepcopy(self.params)
        params = self._inject_kwargs_into_params(params, **kwargs)
        metric_results = {}

        for metric_name, metric_calculator in self.metrics_calculators.items():
            total = None
            for i, _ in enumerate(prediction):
                # Sum as float64 tensors: for scalar metrics this gives the same
                # value as adding Python floats, and vector-valued metrics are
                # summed element-wise instead of failing on ``float += list``.
                metric_value = (
                    metric_calculator(prediction[i], target[i], params)
                    .detach()
                    .cpu()
                    .double()
                )
                total = metric_value if total is None else total + metric_value
            # keep the historical 0.0 result when there are no levels
            metric_results[metric_name] = (
                0.0 if total is None else self._process_metric_value(total)
            )
        return metric_results
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class MetricCalculatorSimple(AbstractMetricCalculator):
    """
    Default metric calculator: every metric is evaluated directly on the
    full ``prediction`` and ``target``.
    """

    def __init__(self, params):
        super().__init__(params)

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, **kwargs):
        # per-call parameter copy carrying the injected kwargs
        call_params = self._inject_kwargs_into_params(
            deepcopy(self.params), **kwargs
        )

        results = {}
        for name, fn in self.metrics_calculators.items():
            raw_value = fn(prediction, target, call_params)
            results[name] = self._process_metric_value(raw_value.detach().cpu())
        return results
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class MetricCalculatorFactory:
    """
    Choose the metric calculator that matches ``params["model"]["architecture"]``.
    """

    def __init__(self, params: dict):
        """
        Args:
            params (dict): The GaNDLF parameters; passed unchanged to the
                selected calculator.
        """
        self.params = params

    def get_metric_calculator(self) -> AbstractMetricCalculator:
        """
        Build the metric calculator for the configured architecture.

        The architecture name is compared case-insensitively, matching how the
        model itself is looked up.

        Returns:
            AbstractMetricCalculator: ``MetricCalculatorSDNet`` for ``sdnet``,
            ``MetricCalculatorDeepSupervision`` for any architecture whose name
            contains ``deep``, otherwise ``MetricCalculatorSimple``.
        """
        architecture = self.params["model"]["architecture"].lower()
        if architecture == "sdnet":
            return MetricCalculatorSDNet(self.params)
        if "deep" in architecture:
            return MetricCalculatorDeepSupervision(self.params)
        return MetricCalculatorSimple(self.params)
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
!Panoptica_Evaluator
|
|
2
|
+
decision_metric: null
|
|
3
|
+
decision_threshold: null
|
|
4
|
+
edge_case_handler: !EdgeCaseHandler
|
|
5
|
+
empty_list_std: !EdgeCaseResult NAN
|
|
6
|
+
listmetric_zeroTP_handling:
|
|
7
|
+
!Metric DSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
|
|
8
|
+
empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
|
|
9
|
+
normal: !EdgeCaseResult ZERO}
|
|
10
|
+
!Metric clDSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
|
|
11
|
+
empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
|
|
12
|
+
normal: !EdgeCaseResult ZERO}
|
|
13
|
+
!Metric IOU: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
|
|
14
|
+
empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
|
|
15
|
+
normal: !EdgeCaseResult ZERO}
|
|
16
|
+
!Metric NSD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult INF,
|
|
17
|
+
empty_reference_result: !EdgeCaseResult INF, no_instances_result: !EdgeCaseResult NAN,
|
|
18
|
+
normal: !EdgeCaseResult INF}
|
|
19
|
+
!Metric HD95: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult INF,
|
|
20
|
+
empty_reference_result: !EdgeCaseResult INF, no_instances_result: !EdgeCaseResult NAN,
|
|
21
|
+
normal: !EdgeCaseResult INF}
|
|
22
|
+
!Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN,
|
|
23
|
+
empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN,
|
|
24
|
+
normal: !EdgeCaseResult NAN}
|
|
25
|
+
!Metric RVAE: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN,
|
|
26
|
+
empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN,
|
|
27
|
+
normal: !EdgeCaseResult NAN}
|
|
28
|
+
expected_input: !InputType SEMANTIC
|
|
29
|
+
global_metrics: [!Metric DSC]
|
|
30
|
+
instance_approximator: !ConnectedComponentsInstanceApproximator {cca_backend: null}
|
|
31
|
+
instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU,
|
|
32
|
+
matching_threshold: 0.5}
|
|
33
|
+
instance_metrics: [!Metric DSC, !Metric IOU, !Metric RVD, !Metric NSD, !Metric HD95]
|
|
34
|
+
log_times: false
|
|
35
|
+
save_group_times: false
|
|
36
|
+
segmentation_class_groups: !SegmentationClassGroups
|
|
37
|
+
groups:
|
|
38
|
+
snfh: !LabelGroup
|
|
39
|
+
single_instance: false
|
|
40
|
+
value_labels: [2]
|
|
41
|
+
et: !LabelGroup
|
|
42
|
+
single_instance: false
|
|
43
|
+
value_labels: [3]
|
|
44
|
+
netc: !LabelGroup
|
|
45
|
+
single_instance: false
|
|
46
|
+
value_labels: [1]
|
|
47
|
+
rc: !LabelGroup
|
|
48
|
+
single_instance: false
|
|
49
|
+
value_labels: [4]
|
|
50
|
+
tc: !LabelMergeGroup
|
|
51
|
+
single_instance: false
|
|
52
|
+
value_labels: [1, 3, 4]
|
|
53
|
+
wt: !LabelMergeGroup
|
|
54
|
+
single_instance: false
|
|
55
|
+
value_labels: [1, 2, 3, 4]
|
|
56
|
+
verbose: false
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from typing import Optional
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from panoptica import Panoptica_Evaluator
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def generate_instance_segmentation(
    prediction: np.ndarray,
    target: np.ndarray,
    parameters: dict,
    panoptica_config_path: Optional[str] = None,
) -> dict:
    """
    Evaluate a single exam using Panoptica.

    The evaluator is chosen in this order:
      1. ``parameters["panoptica_config"]`` if present. This may be an
         already-constructed evaluator (anything with an ``evaluate`` method)
         or a path (``str``/``os.PathLike``) to a Panoptica configuration,
         which is then loaded.
      2. ``panoptica_config_path`` if given.
      3. The ``panoptica_config_brats.yaml`` file next to this module.

    Args:
        prediction (np.ndarray): The input prediction containing objects.
        target (np.ndarray): The input target (reference) containing objects.
        parameters (dict): The GaNDLF parameters from which the panoptica config
            is to be extracted; ``None`` is treated as an empty dict.
        panoptica_config_path (str, optional): The path to the Panoptica
            configuration file; only used when ``parameters`` has no
            ``panoptica_config`` entry.

    Returns:
        dict: Mapping of each group name returned by the evaluator to the
        ``to_dict()`` of that group's result.
    """
    # presumably silences panoptica's citation reminder; note this mutates the
    # environment of the whole process, not just this call
    os.environ["PANOPTICA_CITATION_REMINDER"] = "False"

    # the parameters dict takes precedence over the panoptica_config_path
    evaluator = (parameters or {}).get("panoptica_config", None)
    if evaluator is None:
        if panoptica_config_path is None:
            module_dir = Path(__file__).parent.absolute()
            panoptica_config_path = str(module_dir / "panoptica_config_brats.yaml")
        evaluator = Panoptica_Evaluator.load_from_config(panoptica_config_path)
    elif isinstance(evaluator, (str, os.PathLike)):
        # a config path was supplied instead of an evaluator object
        evaluator = Panoptica_Evaluator.load_from_config(str(evaluator))

    assert evaluator is not None, "Panoptica evaluator could not be initialized."

    group2result = evaluator.evaluate(prediction_arr=prediction, reference_arr=target)

    return {group: result.to_dict() for group, result in group2result.items()}
|
GANDLF/models/__init__.py
CHANGED
|
@@ -36,6 +36,7 @@ from .MSDNet import MSDNet
|
|
|
36
36
|
from .brain_age import brainage
|
|
37
37
|
from .unetr import unetr
|
|
38
38
|
from .transunet import transunet
|
|
39
|
+
from .modelBase import ModelBase
|
|
39
40
|
|
|
40
41
|
# Define a dictionary of model architectures and corresponding functions
|
|
41
42
|
global_models_dict = {
|
|
@@ -110,7 +111,7 @@ global_models_dict = {
|
|
|
110
111
|
}
|
|
111
112
|
|
|
112
113
|
|
|
113
|
-
def get_model(params):
|
|
114
|
+
def get_model(params: dict) -> ModelBase:
|
|
114
115
|
"""
|
|
115
116
|
Function to get the model definition.
|
|
116
117
|
|
|
@@ -118,6 +119,10 @@ def get_model(params):
|
|
|
118
119
|
params (dict): The parameters' dictionary.
|
|
119
120
|
|
|
120
121
|
Returns:
|
|
121
|
-
model (
|
|
122
|
+
model (ModelBase): The model definition.
|
|
122
123
|
"""
|
|
123
|
-
|
|
124
|
+
chosen_model = params["model"]["architecture"].lower()
|
|
125
|
+
assert (
|
|
126
|
+
chosen_model in global_models_dict
|
|
127
|
+
), f"Could not find the requested model '{params['model']['architecture']}'"
|
|
128
|
+
return global_models_dict[chosen_model](parameters=params)
|