GANDLF 0.1.3.dev20250318__py3-none-any.whl → 0.1.4.dev20250502__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of GANDLF might be problematic. Click here for more details.

Files changed (55) hide show
  1. GANDLF/cli/deploy.py +2 -2
  2. GANDLF/cli/generate_metrics.py +21 -0
  3. GANDLF/cli/main_run.py +4 -12
  4. GANDLF/compute/__init__.py +0 -2
  5. GANDLF/compute/forward_pass.py +0 -1
  6. GANDLF/compute/generic.py +107 -2
  7. GANDLF/compute/inference_loop.py +4 -4
  8. GANDLF/compute/loss_and_metric.py +1 -2
  9. GANDLF/compute/training_loop.py +10 -10
  10. GANDLF/config_manager.py +26 -716
  11. GANDLF/configuration/__init__.py +0 -0
  12. GANDLF/configuration/default_config.py +73 -0
  13. GANDLF/configuration/differential_privacy_config.py +16 -0
  14. GANDLF/configuration/exclude_parameters.py +1 -0
  15. GANDLF/configuration/model_config.py +82 -0
  16. GANDLF/configuration/nested_training_config.py +25 -0
  17. GANDLF/configuration/optimizer_config.py +121 -0
  18. GANDLF/configuration/parameters_config.py +10 -0
  19. GANDLF/configuration/patch_sampler_config.py +11 -0
  20. GANDLF/configuration/post_processing_config.py +10 -0
  21. GANDLF/configuration/pre_processing_config.py +94 -0
  22. GANDLF/configuration/scheduler_config.py +90 -0
  23. GANDLF/configuration/user_defined_config.py +131 -0
  24. GANDLF/configuration/utils.py +96 -0
  25. GANDLF/configuration/validators.py +479 -0
  26. GANDLF/data/__init__.py +14 -16
  27. GANDLF/data/lightning_datamodule.py +119 -0
  28. GANDLF/entrypoints/run.py +29 -35
  29. GANDLF/inference_manager.py +69 -25
  30. GANDLF/losses/__init__.py +23 -1
  31. GANDLF/losses/loss_calculators.py +79 -0
  32. GANDLF/losses/segmentation.py +3 -2
  33. GANDLF/metrics/__init__.py +26 -0
  34. GANDLF/metrics/generic.py +1 -1
  35. GANDLF/metrics/metric_calculators.py +102 -0
  36. GANDLF/metrics/panoptica_config_brats.yaml +50 -0
  37. GANDLF/metrics/segmentation_panoptica.py +35 -0
  38. GANDLF/models/__init__.py +8 -3
  39. GANDLF/models/lightning_module.py +2102 -0
  40. GANDLF/optimizers/__init__.py +4 -8
  41. GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
  42. GANDLF/schedulers/__init__.py +6 -2
  43. GANDLF/training_manager.py +159 -69
  44. GANDLF/utils/__init__.py +4 -3
  45. GANDLF/utils/imaging.py +121 -2
  46. GANDLF/utils/modelio.py +9 -7
  47. GANDLF/utils/pred_target_processors.py +71 -0
  48. GANDLF/utils/write_parse.py +1 -1
  49. GANDLF/version.py +1 -1
  50. {gandlf-0.1.3.dev20250318.dist-info → gandlf-0.1.4.dev20250502.dist-info}/METADATA +14 -8
  51. {gandlf-0.1.3.dev20250318.dist-info → gandlf-0.1.4.dev20250502.dist-info}/RECORD +55 -32
  52. {gandlf-0.1.3.dev20250318.dist-info → gandlf-0.1.4.dev20250502.dist-info}/WHEEL +1 -1
  53. {gandlf-0.1.3.dev20250318.dist-info → gandlf-0.1.4.dev20250502.dist-info}/entry_points.txt +0 -0
  54. {gandlf-0.1.3.dev20250318.dist-info → gandlf-0.1.4.dev20250502.dist-info/licenses}/LICENSE +0 -0
  55. {gandlf-0.1.3.dev20250318.dist-info → gandlf-0.1.4.dev20250502.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,102 @@
1
+ import torch
2
+ from copy import deepcopy
3
+ from GANDLF.metrics import get_metrics
4
+ from abc import ABC, abstractmethod
5
+ from typing import Union
6
+
7
+
8
class AbstractMetricCalculator(ABC):
    """Base class for callables that evaluate every configured metric.

    The constructor keeps a private deep copy of ``params`` and asks
    ``get_metrics`` for the metric callables. Subclasses implement
    ``__call__`` to decide how predictions and targets are handed to them.
    """

    def __init__(self, params: dict):
        """Store a deep copy of ``params`` and build the metric callables.

        Args:
            params (dict): The parameters' dictionary.
        """
        super().__init__()
        # Copy so later changes to the caller's dictionary do not leak in here.
        self.params = deepcopy(params)
        self._initialize_metrics_dict()

    def _initialize_metrics_dict(self):
        """Populate ``self.metrics_calculators`` from ``get_metrics``."""
        self.metrics_calculators = get_metrics(self.params)

    def _process_metric_value(self, metric_value: Union[torch.Tensor, float]):
        """Convert a metric result into plain Python data.

        Args:
            metric_value (Union[torch.Tensor, float]): The raw metric result.

        Returns:
            The value unchanged when it is not a tensor, a Python scalar for
            a 0-dimensional tensor, or a (nested) list for any other tensor.
        """
        # Anything that is not a tensor (float, int, ...) is already plain
        # data; previously only ``float`` was let through and other scalars
        # failed on the ``.dim()`` call below.
        if not isinstance(metric_value, torch.Tensor):
            return metric_value
        if metric_value.dim() == 0:
            return metric_value.item()
        return metric_value.tolist()

    @staticmethod
    def _inject_kwargs_into_params(params, **kwargs):
        """Write every keyword argument into ``params`` and return it.

        ``params`` is modified in place; existing keys are overwritten.
        """
        params.update(kwargs)
        return params

    @abstractmethod
    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, **kwargs
    ) -> dict:
        """Compute all metrics for ``prediction`` against ``target``.

        Args:
            prediction (torch.Tensor): The model output.
            target (torch.Tensor): The reference to compare against.
            **kwargs: Extra entries for implementations to merge into the
                parameters handed to the metric callables.

        Returns:
            dict: Implementations return a mapping from metric name to the
            processed metric value.
        """
36
+
37
+
38
class MetricCalculatorSDNet(AbstractMetricCalculator):
    """Metric calculator selected for the ``sdnet`` architecture.

    Only the first element of ``prediction`` is scored, against ``target``
    with its last dimension squeezed.
    """

    def __init__(self, params):
        super().__init__(params)

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, **kwargs):
        """Evaluate every configured metric on ``prediction[0]``.

        Args:
            prediction (torch.Tensor): Indexable model output; only element 0
                is used (presumably the segmentation output - TODO confirm).
            target (torch.Tensor): The reference; ``squeeze(-1)`` is applied
                before it is passed to each metric.
            **kwargs: Extra entries merged into a per-call copy of the
                parameters.

        Returns:
            dict: Metric name mapped to its processed value.
        """
        # Work on a fresh copy so per-call kwargs never reach self.params.
        call_params = self._inject_kwargs_into_params(deepcopy(self.params), **kwargs)

        results = {}
        for name, metric_fn in self.metrics_calculators.items():
            raw_value = metric_fn(prediction[0], target.squeeze(-1), call_params)
            on_cpu = raw_value.detach().cpu()
            results[name] = self._process_metric_value(on_cpu)
        return results
56
+
57
+
58
class MetricCalculatorDeepSupervision(AbstractMetricCalculator):
    """Metric calculator selected for architectures whose name contains "deep".

    Each metric is evaluated once per element of ``prediction`` against the
    element of ``target`` at the same index, and the values are summed.
    """

    def __init__(self, params):
        super().__init__(params)

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, **kwargs):
        """Sum every configured metric over all elements of ``prediction``.

        Args:
            prediction (torch.Tensor): Indexable collection of outputs.
            target (torch.Tensor): Indexable collection of references,
                indexed in step with ``prediction``.
            **kwargs: Extra entries merged into a per-call copy of the
                parameters.

        Returns:
            dict: Metric name mapped to the sum of its per-element values.
        """
        # Work on a fresh copy so per-call kwargs never reach self.params.
        call_params = self._inject_kwargs_into_params(deepcopy(self.params), **kwargs)

        totals = {}
        for name, metric_fn in self.metrics_calculators.items():
            totals[name] = 0.0
            for level, _ in enumerate(prediction):
                level_value = metric_fn(prediction[level], target[level], call_params)
                # NOTE(review): a metric that yields a non-scalar tensor comes
                # back as a list here, which cannot be added to the float
                # total - confirm only scalar metrics are used on this path.
                totals[name] += self._process_metric_value(
                    level_value.detach().cpu()
                )
        return totals
75
+
76
+
77
class MetricCalculatorSimple(AbstractMetricCalculator):
    """Metric calculator that scores ``prediction`` against ``target`` as given."""

    def __init__(self, params):
        super().__init__(params)

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, **kwargs):
        """Evaluate every configured metric on the unmodified inputs.

        Args:
            prediction (torch.Tensor): The model output.
            target (torch.Tensor): The reference to compare against.
            **kwargs: Extra entries merged into a per-call copy of the
                parameters.

        Returns:
            dict: Metric name mapped to its processed value.
        """
        # Work on a fresh copy so per-call kwargs never reach self.params.
        call_params = self._inject_kwargs_into_params(deepcopy(self.params), **kwargs)

        results = {}
        for name, metric_fn in self.metrics_calculators.items():
            raw_value = metric_fn(prediction, target, call_params)
            results[name] = self._process_metric_value(raw_value.detach().cpu())
        return results
90
+
91
+
92
class MetricCalculatorFactory:
    """Pick the metric calculator that matches the configured architecture."""

    def __init__(self, params: dict):
        """Keep a reference to the parameters' dictionary.

        Args:
            params (dict): The parameters' dictionary; must provide
                ``params["model"]["architecture"]`` as a string.
        """
        self.params = params

    def get_metric_calculator(self) -> AbstractMetricCalculator:
        """Return the calculator for ``params["model"]["architecture"]``.

        Returns:
            AbstractMetricCalculator: ``MetricCalculatorSDNet`` for "sdnet",
            ``MetricCalculatorDeepSupervision`` when the name contains
            "deep", and ``MetricCalculatorSimple`` otherwise.
        """
        # Normalise once so both checks are case-insensitive; previously only
        # the "deep" test lowercased the name, so e.g. "SDNet" silently fell
        # through to the simple calculator.
        architecture = self.params["model"]["architecture"].lower()
        if architecture == "sdnet":
            return MetricCalculatorSDNet(self.params)
        if "deep" in architecture:
            return MetricCalculatorDeepSupervision(self.params)
        return MetricCalculatorSimple(self.params)
@@ -0,0 +1,50 @@
1
+ !Panoptica_Evaluator
2
+ decision_metric: null
3
+ decision_threshold: null
4
+ edge_case_handler: !EdgeCaseHandler
5
+ empty_list_std: !EdgeCaseResult NAN
6
+ listmetric_zeroTP_handling:
7
+ !Metric DSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
8
+ empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
9
+ normal: !EdgeCaseResult ZERO}
10
+ !Metric clDSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
11
+ empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
12
+ normal: !EdgeCaseResult ZERO}
13
+ !Metric IOU: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
14
+ empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
15
+ normal: !EdgeCaseResult ZERO}
16
+ !Metric ASSD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult INF,
17
+ empty_reference_result: !EdgeCaseResult INF, no_instances_result: !EdgeCaseResult NAN,
18
+ normal: !EdgeCaseResult INF}
19
+ !Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN,
20
+ empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN,
21
+ normal: !EdgeCaseResult NAN}
22
+ !Metric RVAE: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN,
23
+ empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN,
24
+ normal: !EdgeCaseResult NAN}
25
+ expected_input: !InputType SEMANTIC
26
+ global_metrics: [!Metric DSC]
27
+ instance_approximator: !ConnectedComponentsInstanceApproximator {cca_backend: null}
28
+ instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU,
29
+ matching_threshold: 0.5}
30
+ instance_metrics: [!Metric DSC, !Metric IOU, !Metric ASSD, !Metric RVD]
31
+ log_times: false
32
+ save_group_times: false
33
+ segmentation_class_groups: !SegmentationClassGroups
34
+ groups:
35
+ ed: !LabelGroup
36
+ single_instance: false
37
+ value_labels: [2]
38
+ et: !LabelGroup
39
+ single_instance: false
40
+ value_labels: [3]
41
+ net: !LabelGroup
42
+ single_instance: false
43
+ value_labels: [1]
44
+ tc: !LabelMergeGroup
45
+ single_instance: false
46
+ value_labels: [1, 3]
47
+ wt: !LabelMergeGroup
48
+ single_instance: false
49
+ value_labels: [1, 2, 3]
50
+ verbose: false
@@ -0,0 +1,35 @@
1
+ from pathlib import Path
2
+
3
+ import numpy as np
4
+
5
+ from panoptica import Panoptica_Evaluator
6
+
7
+
8
def generate_instance_segmentation(
    prediction: np.ndarray, target: np.ndarray, panoptica_config_path: str = None
) -> dict:
    """
    Evaluate a single exam using Panoptica.

    Args:
        prediction (np.ndarray): The input prediction containing objects.
        target (np.ndarray): The reference label array.
        panoptica_config_path (str): The path to the Panoptica configuration file.
            When None, the ``panoptica_config_brats.yaml`` file located next to
            this module is used.

    Returns:
        dict: The evaluation results, one entry per key returned by the
        evaluator, each converted with ``to_dict()``.
    """
    if panoptica_config_path is None:
        # Fall back to the configuration shipped alongside this module. The
        # previous default pointed at "panoptica_config_path.yaml", which is
        # not the name of the packaged file.
        module_dir = Path(__file__).parent.absolute()
        panoptica_config_path = module_dir / "panoptica_config_brats.yaml"

    evaluator = Panoptica_Evaluator.load_from_config(panoptica_config_path)

    # call evaluate
    group2result = evaluator.evaluate(prediction_arr=prediction, reference_arr=target)

    return {group: result.to_dict() for group, result in group2result.items()}
GANDLF/models/__init__.py CHANGED
@@ -36,6 +36,7 @@ from .MSDNet import MSDNet
36
36
  from .brain_age import brainage
37
37
  from .unetr import unetr
38
38
  from .transunet import transunet
39
+ from .modelBase import ModelBase
39
40
 
40
41
  # Define a dictionary of model architectures and corresponding functions
41
42
  global_models_dict = {
@@ -110,7 +111,7 @@ global_models_dict = {
110
111
  }
111
112
 
112
113
 
113
- def get_model(params):
114
+ def get_model(params: dict) -> ModelBase:
114
115
  """
115
116
  Function to get the model definition.
116
117
 
@@ -118,6 +119,10 @@ def get_model(params):
118
119
  params (dict): The parameters' dictionary.
119
120
 
120
121
  Returns:
121
- model (torch.nn.Module): The model definition.
122
+ model (ModelBase): The model definition.
122
123
  """
123
- return global_models_dict[params["model"]["architecture"]](parameters=params)
124
+ chosen_model = params["model"]["architecture"].lower()
125
+ assert (
126
+ chosen_model in global_models_dict
127
+ ), f"Could not find the requested model '{params['model']['architecture']}'"
128
+ return global_models_dict[chosen_model](parameters=params)