GANDLF 0.1.3.dev20250202__py3-none-any.whl → 0.1.6.dev20251109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of GANDLF might be problematic. Click here for more details.

Files changed (57)
  1. GANDLF/cli/deploy.py +2 -2
  2. GANDLF/cli/generate_metrics.py +35 -1
  3. GANDLF/cli/main_run.py +4 -10
  4. GANDLF/compute/__init__.py +0 -2
  5. GANDLF/compute/forward_pass.py +0 -1
  6. GANDLF/compute/generic.py +107 -2
  7. GANDLF/compute/inference_loop.py +4 -4
  8. GANDLF/compute/loss_and_metric.py +1 -2
  9. GANDLF/compute/training_loop.py +10 -10
  10. GANDLF/config_manager.py +33 -717
  11. GANDLF/configuration/__init__.py +0 -0
  12. GANDLF/configuration/default_config.py +73 -0
  13. GANDLF/configuration/differential_privacy_config.py +16 -0
  14. GANDLF/configuration/exclude_parameters.py +1 -0
  15. GANDLF/configuration/model_config.py +82 -0
  16. GANDLF/configuration/nested_training_config.py +25 -0
  17. GANDLF/configuration/optimizer_config.py +121 -0
  18. GANDLF/configuration/parameters_config.py +10 -0
  19. GANDLF/configuration/patch_sampler_config.py +11 -0
  20. GANDLF/configuration/post_processing_config.py +10 -0
  21. GANDLF/configuration/pre_processing_config.py +94 -0
  22. GANDLF/configuration/scheduler_config.py +92 -0
  23. GANDLF/configuration/user_defined_config.py +131 -0
  24. GANDLF/configuration/utils.py +96 -0
  25. GANDLF/configuration/validators.py +479 -0
  26. GANDLF/data/__init__.py +14 -16
  27. GANDLF/data/lightning_datamodule.py +119 -0
  28. GANDLF/entrypoints/run.py +36 -31
  29. GANDLF/inference_manager.py +69 -25
  30. GANDLF/losses/__init__.py +23 -1
  31. GANDLF/losses/loss_calculators.py +79 -0
  32. GANDLF/losses/segmentation.py +3 -2
  33. GANDLF/metrics/__init__.py +26 -0
  34. GANDLF/metrics/generic.py +1 -1
  35. GANDLF/metrics/metric_calculators.py +102 -0
  36. GANDLF/metrics/panoptica_config_brats.yaml +56 -0
  37. GANDLF/metrics/segmentation_panoptica.py +49 -0
  38. GANDLF/models/__init__.py +8 -3
  39. GANDLF/models/lightning_module.py +2102 -0
  40. GANDLF/optimizers/__init__.py +4 -8
  41. GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
  42. GANDLF/schedulers/__init__.py +11 -4
  43. GANDLF/schedulers/wrap_torch.py +15 -3
  44. GANDLF/training_manager.py +160 -50
  45. GANDLF/utils/__init__.py +5 -3
  46. GANDLF/utils/imaging.py +176 -35
  47. GANDLF/utils/modelio.py +12 -8
  48. GANDLF/utils/pred_target_processors.py +71 -0
  49. GANDLF/utils/tensor.py +2 -1
  50. GANDLF/utils/write_parse.py +1 -1
  51. GANDLF/version.py +1 -1
  52. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/METADATA +16 -11
  53. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/RECORD +57 -34
  54. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/WHEEL +1 -1
  55. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/entry_points.txt +0 -0
  56. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info/licenses}/LICENSE +0 -0
  57. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,102 @@
1
+ import torch
2
+ from copy import deepcopy
3
+ from GANDLF.metrics import get_metrics
4
+ from abc import ABC, abstractmethod
5
+ from typing import Union
6
+
7
+
8
class AbstractMetricCalculator(ABC):
    """
    Base class for metric calculators.

    Concrete calculators are called with a prediction and a target and return
    a dictionary mapping each configured metric name to a plain Python value
    (a number for scalar metric outputs, a list for non-scalar ones).

    Args:
        params (dict): The GaNDLF parameters. A deep copy is stored so that
            later mutations by the caller do not leak into this calculator.
    """

    def __init__(self, params: dict):
        super().__init__()
        self.params = deepcopy(params)
        self._initialize_metrics_dict()

    def _initialize_metrics_dict(self):
        """Build the mapping of metric name -> metric callable from the parameters."""
        self.metrics_calculators = get_metrics(self.params)

    def _process_metric_value(
        self, metric_value: Union[torch.Tensor, float]
    ) -> Union[float, int, list]:
        """
        Convert a metric output into a plain Python value.

        Args:
            metric_value (Union[torch.Tensor, float]): A tensor or Python number.

        Returns:
            Union[float, int, list]: Python numbers unchanged, ``item()`` for
                0-dim tensors, ``tolist()`` for everything else.
        """
        # Python scalars (ints included) have no ``.dim()``; pass them through.
        if isinstance(metric_value, (int, float)):
            return metric_value
        if metric_value.dim() == 0:
            return metric_value.item()
        return metric_value.tolist()

    @staticmethod
    def _inject_kwargs_into_params(params, **kwargs):
        """
        Overwrite/add ``kwargs`` entries in ``params`` in place.

        Args:
            params (dict): The dictionary to update (mutated in place).
            **kwargs: Entries to inject.

        Returns:
            dict: The same ``params`` object, for call chaining.
        """
        params.update(kwargs)
        return params

    @abstractmethod
    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, **kwargs
    ) -> dict:
        """
        Compute all configured metrics for one prediction/target pair.

        Args:
            prediction (torch.Tensor): The model output.
            target (torch.Tensor): The ground truth.
            **kwargs: Extra entries injected into a copy of the parameters
                before the metric callables are invoked.

        Returns:
            dict: Metric name -> processed metric value.
        """
36
+
37
+
38
class MetricCalculatorSDNet(AbstractMetricCalculator):
    """
    Metric calculator used for the ``sdnet`` architecture.

    Each metric is evaluated on ``prediction[0]`` against the target with its
    last axis squeezed; results are converted to plain Python values.
    """

    def __init__(self, params):
        super().__init__(params)

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, **kwargs):
        # Work on a private copy so per-call kwargs never touch self.params.
        call_params = self._inject_kwargs_into_params(deepcopy(self.params), **kwargs)

        return {
            name: self._process_metric_value(
                calculator(prediction[0], target.squeeze(-1), call_params)
                .detach()
                .cpu()
            )
            for name, calculator in self.metrics_calculators.items()
        }
56
+
57
+
58
class MetricCalculatorDeepSupervision(AbstractMetricCalculator):
    """
    Metric calculator for deep-supervision architectures.

    ``prediction`` and ``target`` are indexed in parallel (one entry per
    output level); each metric is evaluated on every pair and the per-level
    results are summed.
    """

    def __init__(self, params):
        super().__init__(params)

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, **kwargs):
        # Work on a private copy so per-call kwargs never touch self.params.
        params = self._inject_kwargs_into_params(deepcopy(self.params), **kwargs)
        metric_results = {}

        for metric_name, metric_calculator in self.metrics_calculators.items():
            # Accumulate tensors rather than converted Python values: a
            # non-scalar metric converts to a list and ``0.0 + list`` raises
            # TypeError. float64 keeps scalar sums identical to summing the
            # ``.item()`` values as Python floats.
            # NOTE(review): values are summed, not averaged, across levels —
            # confirm this is intended.
            total = 0.0
            for i, _ in enumerate(prediction):
                metric_value = (
                    metric_calculator(prediction[i], target[i], params).detach().cpu()
                )
                total = total + metric_value.to(torch.float64)
            metric_results[metric_name] = self._process_metric_value(total)
        return metric_results
75
+
76
+
77
class MetricCalculatorSimple(AbstractMetricCalculator):
    """
    Default metric calculator: every metric is evaluated directly on the
    full prediction and target, and results are converted to plain Python
    values.
    """

    def __init__(self, params):
        super().__init__(params)

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, **kwargs):
        # Work on a private copy so per-call kwargs never touch self.params.
        call_params = self._inject_kwargs_into_params(deepcopy(self.params), **kwargs)

        return {
            name: self._process_metric_value(
                calculator(prediction, target, call_params).detach().cpu()
            )
            for name, calculator in self.metrics_calculators.items()
        }
90
+
91
+
92
class MetricCalculatorFactory:
    """
    Choose the metric calculator matching the configured model architecture.

    Args:
        params (dict): The GaNDLF parameters; ``params["model"]["architecture"]``
            drives the choice and the whole dict is handed to the calculator.
    """

    def __init__(self, params: dict):
        self.params = params

    def get_metric_calculator(self) -> AbstractMetricCalculator:
        """
        Build the calculator for the configured architecture.

        Returns:
            AbstractMetricCalculator: ``MetricCalculatorSDNet`` for sdnet,
                ``MetricCalculatorDeepSupervision`` when the architecture name
                contains "deep", otherwise ``MetricCalculatorSimple``.
        """
        # Compare case-insensitively, like the "deep" check and like get_model,
        # which also lower-cases the architecture name before lookup.
        architecture = self.params["model"]["architecture"].lower()
        if architecture == "sdnet":
            return MetricCalculatorSDNet(self.params)
        if "deep" in architecture:
            return MetricCalculatorDeepSupervision(self.params)
        return MetricCalculatorSimple(self.params)
@@ -0,0 +1,56 @@
1
+ !Panoptica_Evaluator
2
+ decision_metric: null
3
+ decision_threshold: null
4
+ edge_case_handler: !EdgeCaseHandler
5
+ empty_list_std: !EdgeCaseResult NAN
6
+ listmetric_zeroTP_handling:
7
+ !Metric DSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
8
+ empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
9
+ normal: !EdgeCaseResult ZERO}
10
+ !Metric clDSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
11
+ empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
12
+ normal: !EdgeCaseResult ZERO}
13
+ !Metric IOU: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
14
+ empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
15
+ normal: !EdgeCaseResult ZERO}
16
+ !Metric NSD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult INF,
17
+ empty_reference_result: !EdgeCaseResult INF, no_instances_result: !EdgeCaseResult NAN,
18
+ normal: !EdgeCaseResult INF}
19
+ !Metric HD95: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult INF,
20
+ empty_reference_result: !EdgeCaseResult INF, no_instances_result: !EdgeCaseResult NAN,
21
+ normal: !EdgeCaseResult INF}
22
+ !Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN,
23
+ empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN,
24
+ normal: !EdgeCaseResult NAN}
25
+ !Metric RVAE: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN,
26
+ empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN,
27
+ normal: !EdgeCaseResult NAN}
28
+ expected_input: !InputType SEMANTIC
29
+ global_metrics: [!Metric DSC]
30
+ instance_approximator: !ConnectedComponentsInstanceApproximator {cca_backend: null}
31
+ instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU,
32
+ matching_threshold: 0.5}
33
+ instance_metrics: [!Metric DSC, !Metric IOU, !Metric RVD, !Metric NSD, !Metric HD95]
34
+ log_times: false
35
+ save_group_times: false
36
+ segmentation_class_groups: !SegmentationClassGroups
37
+ groups:
38
+ snfh: !LabelGroup
39
+ single_instance: false
40
+ value_labels: [2]
41
+ et: !LabelGroup
42
+ single_instance: false
43
+ value_labels: [3]
44
+ netc: !LabelGroup
45
+ single_instance: false
46
+ value_labels: [1]
47
+ rc: !LabelGroup
48
+ single_instance: false
49
+ value_labels: [4]
50
+ tc: !LabelMergeGroup
51
+ single_instance: false
52
+ value_labels: [1, 3, 4]
53
+ wt: !LabelMergeGroup
54
+ single_instance: false
55
+ value_labels: [1, 2, 3, 4]
56
+ verbose: false
@@ -0,0 +1,49 @@
1
+ from pathlib import Path
2
+ import os
3
+
4
+ from typing import Optional
5
+ import numpy as np
6
+
7
+ from panoptica import Panoptica_Evaluator
8
+
9
+
10
def generate_instance_segmentation(
    prediction: np.ndarray,
    target: np.ndarray,
    parameters: dict,
    panoptica_config_path: Optional[str] = None,
) -> dict:
    """
    Evaluate a single exam using Panoptica.

    Args:
        prediction (np.ndarray): The input prediction containing objects.
        target (np.ndarray): The input target (reference) containing objects.
        parameters (dict): The GaNDLF parameters. If ``parameters["panoptica_config"]``
            is set it takes precedence over ``panoptica_config_path``; it may be
            an already-constructed evaluator or a path to a Panoptica config file.
        panoptica_config_path (str, optional): The path to the Panoptica
            configuration file. Defaults to ``panoptica_config_brats.yaml``
            located next to this module.

    Returns:
        dict: The evaluation results, keyed as returned by
            ``evaluator.evaluate`` (presumably one entry per label group),
            each converted with ``to_dict()``.

    Raises:
        AssertionError: If no evaluator could be obtained (only when
            assertions are enabled).
    """
    # Presumably silences Panoptica's citation reminder output.
    os.environ["PANOPTICA_CITATION_REMINDER"] = "False"

    # the parameters dict takes precedence over the panoptica_config_path
    evaluator = parameters.get("panoptica_config", None)
    if isinstance(evaluator, (str, os.PathLike)):
        # A config *path* given through the parameters is loaded the same way
        # as panoptica_config_path instead of failing later on ``.evaluate``.
        evaluator = Panoptica_Evaluator.load_from_config(os.fspath(evaluator))
    elif evaluator is None:
        module_dir = Path(__file__).parent.absolute()
        if panoptica_config_path is None:
            panoptica_config_path = str(module_dir / "panoptica_config_brats.yaml")
        evaluator = Panoptica_Evaluator.load_from_config(panoptica_config_path)

    assert evaluator is not None, "Panoptica evaluator could not be initialized."

    group2result = evaluator.evaluate(prediction_arr=prediction, reference_arr=target)
    return {group: result.to_dict() for group, result in group2result.items()}
GANDLF/models/__init__.py CHANGED
@@ -36,6 +36,7 @@ from .MSDNet import MSDNet
36
36
  from .brain_age import brainage
37
37
  from .unetr import unetr
38
38
  from .transunet import transunet
39
+ from .modelBase import ModelBase
39
40
 
40
41
  # Define a dictionary of model architectures and corresponding functions
41
42
  global_models_dict = {
@@ -110,7 +111,7 @@ global_models_dict = {
110
111
  }
111
112
 
112
113
 
113
- def get_model(params):
114
+ def get_model(params: dict) -> ModelBase:
114
115
  """
115
116
  Function to get the model definition.
116
117
 
@@ -118,6 +119,10 @@ def get_model(params):
118
119
  params (dict): The parameters' dictionary.
119
120
 
120
121
  Returns:
121
- model (torch.nn.Module): The model definition.
122
+ model (ModelBase): The model definition.
122
123
  """
123
- return global_models_dict[params["model"]["architecture"]](parameters=params)
124
+ chosen_model = params["model"]["architecture"].lower()
125
+ assert (
126
+ chosen_model in global_models_dict
127
+ ), f"Could not find the requested model '{params['model']['architecture']}'"
128
+ return global_models_dict[chosen_model](parameters=params)