GANDLF 0.1.3.dev20250319__py3-none-any.whl → 0.1.4.dev20250503__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release.
This version of GANDLF might be problematic.
- GANDLF/cli/deploy.py +2 -2
- GANDLF/cli/generate_metrics.py +21 -0
- GANDLF/cli/main_run.py +4 -12
- GANDLF/compute/__init__.py +0 -2
- GANDLF/compute/forward_pass.py +0 -1
- GANDLF/compute/generic.py +107 -2
- GANDLF/compute/inference_loop.py +4 -4
- GANDLF/compute/loss_and_metric.py +1 -2
- GANDLF/compute/training_loop.py +10 -10
- GANDLF/config_manager.py +26 -716
- GANDLF/configuration/__init__.py +0 -0
- GANDLF/configuration/default_config.py +73 -0
- GANDLF/configuration/differential_privacy_config.py +16 -0
- GANDLF/configuration/exclude_parameters.py +1 -0
- GANDLF/configuration/model_config.py +82 -0
- GANDLF/configuration/nested_training_config.py +25 -0
- GANDLF/configuration/optimizer_config.py +121 -0
- GANDLF/configuration/parameters_config.py +10 -0
- GANDLF/configuration/patch_sampler_config.py +11 -0
- GANDLF/configuration/post_processing_config.py +10 -0
- GANDLF/configuration/pre_processing_config.py +94 -0
- GANDLF/configuration/scheduler_config.py +90 -0
- GANDLF/configuration/user_defined_config.py +131 -0
- GANDLF/configuration/utils.py +96 -0
- GANDLF/configuration/validators.py +479 -0
- GANDLF/data/__init__.py +14 -16
- GANDLF/data/lightning_datamodule.py +119 -0
- GANDLF/entrypoints/run.py +29 -35
- GANDLF/inference_manager.py +69 -25
- GANDLF/losses/__init__.py +23 -1
- GANDLF/losses/loss_calculators.py +79 -0
- GANDLF/losses/segmentation.py +3 -2
- GANDLF/metrics/__init__.py +26 -0
- GANDLF/metrics/generic.py +1 -1
- GANDLF/metrics/metric_calculators.py +102 -0
- GANDLF/metrics/panoptica_config_brats.yaml +50 -0
- GANDLF/metrics/segmentation_panoptica.py +35 -0
- GANDLF/models/__init__.py +8 -3
- GANDLF/models/lightning_module.py +2102 -0
- GANDLF/optimizers/__init__.py +4 -8
- GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
- GANDLF/schedulers/__init__.py +6 -2
- GANDLF/training_manager.py +159 -69
- GANDLF/utils/__init__.py +4 -3
- GANDLF/utils/imaging.py +121 -2
- GANDLF/utils/modelio.py +9 -7
- GANDLF/utils/pred_target_processors.py +71 -0
- GANDLF/utils/write_parse.py +1 -1
- GANDLF/version.py +1 -1
- {gandlf-0.1.3.dev20250319.dist-info → gandlf-0.1.4.dev20250503.dist-info}/METADATA +14 -8
- {gandlf-0.1.3.dev20250319.dist-info → gandlf-0.1.4.dev20250503.dist-info}/RECORD +55 -32
- {gandlf-0.1.3.dev20250319.dist-info → gandlf-0.1.4.dev20250503.dist-info}/WHEEL +1 -1
- {gandlf-0.1.3.dev20250319.dist-info → gandlf-0.1.4.dev20250503.dist-info}/entry_points.txt +0 -0
- {gandlf-0.1.3.dev20250319.dist-info → gandlf-0.1.4.dev20250503.dist-info/licenses}/LICENSE +0 -0
- {gandlf-0.1.3.dev20250319.dist-info → gandlf-0.1.4.dev20250503.dist-info}/top_level.txt +0 -0
GANDLF/metrics/metric_calculators.py
ADDED
@@ -0,0 +1,102 @@
+import torch
+from copy import deepcopy
+from GANDLF.metrics import get_metrics
+from abc import ABC, abstractmethod
+from typing import Union
+
+
+class AbstractMetricCalculator(ABC):
+    def __init__(self, params: dict):
+        super().__init__()
+        self.params = deepcopy(params)
+        self._initialize_metrics_dict()
+
+    def _initialize_metrics_dict(self):
+        self.metrics_calculators = get_metrics(self.params)
+
+    def _process_metric_value(self, metric_value: Union[torch.Tensor, float]):
+        if isinstance(metric_value, float):
+            return metric_value
+        if metric_value.dim() == 0:
+            return metric_value.item()
+        else:
+            return metric_value.tolist()
+
+    @staticmethod
+    def _inject_kwargs_into_params(params, **kwargs):
+        for key, value in kwargs.items():
+            params[key] = value
+        return params
+
+    @abstractmethod
+    def __call__(
+        self, prediction: torch.Tensor, target: torch.Tensor, **kwargs
+    ) -> torch.Tensor:
+        pass
+
+
+class MetricCalculatorSDNet(AbstractMetricCalculator):
+    def __init__(self, params):
+        super().__init__(params)
+
+    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, **kwargs):
+        params = deepcopy(self.params)
+        params = self._inject_kwargs_into_params(params, **kwargs)
+
+        metric_results = {}
+
+        for metric_name, metric_calculator in self.metrics_calculators.items():
+            metric_value = (
+                metric_calculator(prediction[0], target.squeeze(-1), params)
+                .detach()
+                .cpu()
+            )
+            metric_results[metric_name] = self._process_metric_value(metric_value)
+        return metric_results
+
+
+class MetricCalculatorDeepSupervision(AbstractMetricCalculator):
+    def __init__(self, params):
+        super().__init__(params)
+
+    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, **kwargs):
+        params = deepcopy(self.params)
+        params = self._inject_kwargs_into_params(params, **kwargs)
+        metric_results = {}
+
+        for metric_name, metric_calculator in self.metrics_calculators.items():
+            metric_results[metric_name] = 0.0
+            for i, _ in enumerate(prediction):
+                metric_value = (
+                    metric_calculator(prediction[i], target[i], params).detach().cpu()
+                )
+                metric_results[metric_name] += self._process_metric_value(metric_value)
+        return metric_results
+
+
+class MetricCalculatorSimple(AbstractMetricCalculator):
+    def __init__(self, params):
+        super().__init__(params)
+
+    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, **kwargs):
+        params = deepcopy(self.params)
+        params = self._inject_kwargs_into_params(params, **kwargs)
+        metric_results = {}
+
+        for metric_name, metric_calculator in self.metrics_calculators.items():
+            metric_value = metric_calculator(prediction, target, params).detach().cpu()
+            metric_results[metric_name] = self._process_metric_value(metric_value)
+        return metric_results
+
+
+class MetricCalculatorFactory:
+    def __init__(self, params: dict):
+        self.params = params
+
+    def get_metric_calculator(self) -> AbstractMetricCalculator:
+        if self.params["model"]["architecture"] == "sdnet":
+            return MetricCalculatorSDNet(self.params)
+        elif "deep" in self.params["model"]["architecture"].lower():
+            return MetricCalculatorDeepSupervision(self.params)
+        else:
+            return MetricCalculatorSimple(self.params)
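All three calculators share one calling convention, so downstream training and inference code only needs the factory. A hedged usage sketch follows; the `parameters` dictionary shown is a deliberately trimmed, hypothetical subset, while a real one is the full parsed GANDLF configuration (metrics, problem type, etc.) that `get_metrics` relies on.

```python
# Illustrative sketch only: in practice the full parsed GANDLF parameter
# dictionary is required so that get_metrics() inside the calculator can
# resolve the configured metric functions.
import torch
from GANDLF.metrics.metric_calculators import MetricCalculatorFactory

parameters = {
    "model": {"architecture": "unet"},  # neither "sdnet" nor "deep*" -> MetricCalculatorSimple
    # ... remaining GANDLF parameters ...
}

metric_calculator = MetricCalculatorFactory(parameters).get_metric_calculator()

prediction = torch.rand(1, 1, 64, 64)  # placeholder tensors
target = torch.rand(1, 1, 64, 64)

# Returns {metric_name: float or list of per-class floats}.
results = metric_calculator(prediction, target)
```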
GANDLF/metrics/panoptica_config_brats.yaml
ADDED
@@ -0,0 +1,50 @@
+!Panoptica_Evaluator
+decision_metric: null
+decision_threshold: null
+edge_case_handler: !EdgeCaseHandler
+  empty_list_std: !EdgeCaseResult NAN
+  listmetric_zeroTP_handling:
+    !Metric DSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
+      empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
+      normal: !EdgeCaseResult ZERO}
+    !Metric clDSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
+      empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
+      normal: !EdgeCaseResult ZERO}
+    !Metric IOU: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
+      empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
+      normal: !EdgeCaseResult ZERO}
+    !Metric ASSD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult INF,
+      empty_reference_result: !EdgeCaseResult INF, no_instances_result: !EdgeCaseResult NAN,
+      normal: !EdgeCaseResult INF}
+    !Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN,
+      empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN,
+      normal: !EdgeCaseResult NAN}
+    !Metric RVAE: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN,
+      empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN,
+      normal: !EdgeCaseResult NAN}
+expected_input: !InputType SEMANTIC
+global_metrics: [!Metric DSC]
+instance_approximator: !ConnectedComponentsInstanceApproximator {cca_backend: null}
+instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU,
+  matching_threshold: 0.5}
+instance_metrics: [!Metric DSC, !Metric IOU, !Metric ASSD, !Metric RVD]
+log_times: false
+save_group_times: false
+segmentation_class_groups: !SegmentationClassGroups
+  groups:
+    ed: !LabelGroup
+      single_instance: false
+      value_labels: [2]
+    et: !LabelGroup
+      single_instance: false
+      value_labels: [3]
+    net: !LabelGroup
+      single_instance: false
+      value_labels: [1]
+    tc: !LabelMergeGroup
+      single_instance: false
+      value_labels: [1, 3]
+    wt: !LabelMergeGroup
+      single_instance: false
+      value_labels: [1, 2, 3]
+verbose: false
GANDLF/metrics/segmentation_panoptica.py
ADDED
@@ -0,0 +1,35 @@
+from pathlib import Path
+
+import numpy as np
+
+from panoptica import Panoptica_Evaluator
+
+
+def generate_instance_segmentation(
+    prediction: np.ndarray, target: np.ndarray, panoptica_config_path: str = None
+) -> dict:
+    """
+    Evaluate a single exam using Panoptica.
+
+    Args:
+        prediction (np.ndarray): The input prediction containing objects.
+        label_path (str): The path to the reference label.
+        panoptica_config_path (str): The path to the Panoptica configuration file.
+
+    Returns:
+        dict: The evaluation results.
+    """
+
+    cwd = Path(__file__).parent.absolute()
+    panoptica_config_path = (
+        cwd / "panoptica_config_path.yaml"
+        if panoptica_config_path is None
+        else panoptica_config_path
+    )
+    evaluator = Panoptica_Evaluator.load_from_config(panoptica_config_path)
+
+    # call evaluate
+    group2result = evaluator.evaluate(prediction_arr=prediction, reference_arr=target)
+
+    results = {k: r.to_dict() for k, r in group2result.items()}
+    return results
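Note that the default config path in this function points at panoptica_config_path.yaml, while the file shipped in this release is panoptica_config_brats.yaml, so passing the config path explicitly is the safer call pattern. A hedged usage sketch with synthetic BraTS-style label maps (real inputs would come from loaded segmentation volumes):

```python
# Illustrative only: random label maps stand in for real segmentations; labels
# follow the convention from panoptica_config_brats.yaml (1 = NET, 2 = ED, 3 = ET).
import numpy as np
from GANDLF.metrics.segmentation_panoptica import generate_instance_segmentation

prediction = np.random.randint(0, 4, size=(96, 96, 96), dtype=np.uint8)
target = np.random.randint(0, 4, size=(96, 96, 96), dtype=np.uint8)

results = generate_instance_segmentation(
    prediction,
    target,
    panoptica_config_path="GANDLF/metrics/panoptica_config_brats.yaml",
)

# One entry per label group defined in the config (ed, et, net, tc, wt),
# each holding metrics such as DSC, IOU, ASSD and RVD.
print(sorted(results.keys()))
```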
GANDLF/models/__init__.py
CHANGED
@@ -36,6 +36,7 @@ from .MSDNet import MSDNet
 from .brain_age import brainage
 from .unetr import unetr
 from .transunet import transunet
+from .modelBase import ModelBase

 # Define a dictionary of model architectures and corresponding functions
 global_models_dict = {
@@ -110,7 +111,7 @@ global_models_dict = {
 }


-def get_model(params):
+def get_model(params: dict) -> ModelBase:
     """
     Function to get the model definition.

@@ -118,6 +119,10 @@ def get_model(params):
         params (dict): The parameters' dictionary.

     Returns:
-        model (
+        model (ModelBase): The model definition.
     """
-
+    chosen_model = params["model"]["architecture"].lower()
+    assert (
+        chosen_model in global_models_dict
+    ), f"Could not find the requested model '{params['model']['architecture']}'"
+    return global_models_dict[chosen_model](parameters=params)
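With this change, get_model normalises the requested architecture name to lower case and asserts that it exists in global_models_dict, so an unknown name fails with a readable message instead of a raw KeyError. A hedged sketch of the call; the params dictionary is intentionally abbreviated, since a real GANDLF model block carries many more keys (input dimension, channel and class counts, and so on).

```python
# Hypothetical sketch: a real GANDLF "model" block needs more configuration
# keys than shown here; only the architecture lookup behaviour is illustrated.
from GANDLF.models import get_model

params = {
    "model": {
        "architecture": "UNet",  # lookup is case-insensitive after this change
        # ... remaining model configuration ...
    },
    # ... remaining GANDLF parameters ...
}

try:
    model = get_model(params)  # returns a ModelBase instance
except AssertionError as err:
    # Unknown architectures now fail fast with a clear message.
    print(err)
```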