GANDLF 0.1.3.dev20250202__py3-none-any.whl → 0.1.6.dev20251109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of GANDLF might be problematic. Click here for more details.

Files changed (57)
  1. GANDLF/cli/deploy.py +2 -2
  2. GANDLF/cli/generate_metrics.py +35 -1
  3. GANDLF/cli/main_run.py +4 -10
  4. GANDLF/compute/__init__.py +0 -2
  5. GANDLF/compute/forward_pass.py +0 -1
  6. GANDLF/compute/generic.py +107 -2
  7. GANDLF/compute/inference_loop.py +4 -4
  8. GANDLF/compute/loss_and_metric.py +1 -2
  9. GANDLF/compute/training_loop.py +10 -10
  10. GANDLF/config_manager.py +33 -717
  11. GANDLF/configuration/__init__.py +0 -0
  12. GANDLF/configuration/default_config.py +73 -0
  13. GANDLF/configuration/differential_privacy_config.py +16 -0
  14. GANDLF/configuration/exclude_parameters.py +1 -0
  15. GANDLF/configuration/model_config.py +82 -0
  16. GANDLF/configuration/nested_training_config.py +25 -0
  17. GANDLF/configuration/optimizer_config.py +121 -0
  18. GANDLF/configuration/parameters_config.py +10 -0
  19. GANDLF/configuration/patch_sampler_config.py +11 -0
  20. GANDLF/configuration/post_processing_config.py +10 -0
  21. GANDLF/configuration/pre_processing_config.py +94 -0
  22. GANDLF/configuration/scheduler_config.py +92 -0
  23. GANDLF/configuration/user_defined_config.py +131 -0
  24. GANDLF/configuration/utils.py +96 -0
  25. GANDLF/configuration/validators.py +479 -0
  26. GANDLF/data/__init__.py +14 -16
  27. GANDLF/data/lightning_datamodule.py +119 -0
  28. GANDLF/entrypoints/run.py +36 -31
  29. GANDLF/inference_manager.py +69 -25
  30. GANDLF/losses/__init__.py +23 -1
  31. GANDLF/losses/loss_calculators.py +79 -0
  32. GANDLF/losses/segmentation.py +3 -2
  33. GANDLF/metrics/__init__.py +26 -0
  34. GANDLF/metrics/generic.py +1 -1
  35. GANDLF/metrics/metric_calculators.py +102 -0
  36. GANDLF/metrics/panoptica_config_brats.yaml +56 -0
  37. GANDLF/metrics/segmentation_panoptica.py +49 -0
  38. GANDLF/models/__init__.py +8 -3
  39. GANDLF/models/lightning_module.py +2102 -0
  40. GANDLF/optimizers/__init__.py +4 -8
  41. GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
  42. GANDLF/schedulers/__init__.py +11 -4
  43. GANDLF/schedulers/wrap_torch.py +15 -3
  44. GANDLF/training_manager.py +160 -50
  45. GANDLF/utils/__init__.py +5 -3
  46. GANDLF/utils/imaging.py +176 -35
  47. GANDLF/utils/modelio.py +12 -8
  48. GANDLF/utils/pred_target_processors.py +71 -0
  49. GANDLF/utils/tensor.py +2 -1
  50. GANDLF/utils/write_parse.py +1 -1
  51. GANDLF/version.py +1 -1
  52. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/METADATA +16 -11
  53. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/RECORD +57 -34
  54. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/WHEEL +1 -1
  55. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/entry_points.txt +0 -0
  56. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info/licenses}/LICENSE +0 -0
  57. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,119 @@
1
+ import lightning.pytorch as pl
2
+ from GANDLF.compute.generic import (
3
+ TrainingSubsetDataParser,
4
+ ValidationSubsetDataParser,
5
+ TestSubsetDataParser,
6
+ InferenceSubsetDataParserRadiology,
7
+ )
8
+ from torch.utils.data import DataLoader as TorchDataLoader
9
+ from copy import deepcopy
10
+
11
+
12
class GandlfTrainingDatamodule(pl.LightningDataModule):
    """
    Lightning data module serving the training, validation and (optional)
    testing subsets.

    Args:
        data_dict_files (dict): Subset data keyed by "training", "validation"
            and, optionally, "testing".
        parameters_dict (dict): The overall parameters dictionary. It is
            deep-copied; the copy, as extended by the subset parsers, is
            exposed through ``updated_parameters_dict``.
    """

    def __init__(self, data_dict_files: dict, parameters_dict: dict):
        super().__init__()

        # Batch size here and reinitialization of dataloader parsers is used
        # in automatic batch size tuning
        self.batch_size = parameters_dict["batch_size"]

        # This init procedure is extreme hack, but the only way to get around the
        # need to modify the parameters dict during the parsing procedure:
        # each parser returns the parameters extended with information about
        # its subset, and that result is fed into the next parser.
        params = deepcopy(parameters_dict)

        train_subset_parser = TrainingSubsetDataParser(
            data_dict_files["training"], params
        )
        self.training_dataset = train_subset_parser.create_subset_dataset()
        params = train_subset_parser.get_params_extended_with_subset_data()

        val_subset_parser = ValidationSubsetDataParser(
            data_dict_files["validation"], params
        )
        self.validation_dataset = val_subset_parser.create_subset_dataset()
        params = val_subset_parser.get_params_extended_with_subset_data()

        testing_data = data_dict_files.get("testing", None)
        self.test_dataset = None
        if testing_data is not None:
            test_subset_parser = TestSubsetDataParser(testing_data, params)
            self.test_dataset = test_subset_parser.create_subset_dataset()
            params = test_subset_parser.get_params_extended_with_subset_data()

        self.parameters_dict = params

    def _get_dataloader(self, dataset, batch_size: int, shuffle: bool):
        """
        Build a DataLoader for ``dataset`` using the dataloader options stored
        in the parameters dictionary.

        ``prefetch_factor`` is only forwarded when worker processes are used:
        PyTorch raises a ValueError if it is given together with
        ``num_workers=0``.
        """
        params = self.updated_parameters_dict
        num_workers = params.get("num_workers_dataloader", 1)
        extra_kwargs = {}
        if num_workers > 0:
            extra_kwargs["prefetch_factor"] = params.get(
                "prefetch_factor_dataloader", 2
            )
        return TorchDataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=params.get("pin_memory_dataloader", False),
            **extra_kwargs,
        )

    @property
    def updated_parameters_dict(self):
        """The parameters dictionary as extended by the subset parsers."""
        return self.parameters_dict

    def train_dataloader(self):
        # Re-sync the batch size, which batch size tuning may have changed.
        self.updated_parameters_dict["batch_size"] = self.batch_size
        return self._get_dataloader(
            self.training_dataset, self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        return self._get_dataloader(
            self.validation_dataset, batch_size=1, shuffle=False
        )

    def test_dataloader(self):
        # No testing subset was provided.
        if self.test_dataset is None:
            return None
        return self._get_dataloader(self.test_dataset, batch_size=1, shuffle=False)
78
+
79
+
80
class GandlfInferenceDatamodule(pl.LightningDataModule):
    """
    Lightning data module serving the inference data.

    For the "rad" modality the dataframe is parsed into a dataset. For the
    "path"/"histo" modalities, the dataframe rows are yielded directly.

    Args:
        dataframe: The inference data. It is iterated with ``iterrows`` for
            histology, so presumably a pandas DataFrame (verify against
            callers).
        parameters_dict (dict): The overall parameters dictionary. It is
            deep-copied and, for "rad", extended by the subset parser.
    """

    def __init__(self, dataframe, parameters_dict):
        super().__init__()
        self.dataframe = dataframe
        params = deepcopy(parameters_dict)
        self.parameters_dict = params
        if self.parameters_dict["modality"] == "rad":
            inference_subset_data_parser_radiology = InferenceSubsetDataParserRadiology(
                self.dataframe, params
            )
            self.inference_dataset = (
                inference_subset_data_parser_radiology.create_subset_dataset()
            )

            self.parameters_dict = (
                inference_subset_data_parser_radiology.get_params_extended_with_subset_data()
            )

    @property
    def updated_parameters_dict(self):
        """The parameters dictionary (extended by the parser for "rad")."""
        return self.parameters_dict

    def predict_dataloader(self):
        """
        Return the prediction data source for the configured modality.

        Raises:
            ValueError: If the modality is not "rad", "path" or "histo".
        """
        modality = self.parameters_dict["modality"]
        if modality == "rad":
            params = self.updated_parameters_dict
            num_workers = params.get("num_workers_dataloader", 1)
            extra_kwargs = {}
            # PyTorch rejects prefetch_factor when num_workers == 0.
            if num_workers > 0:
                extra_kwargs["prefetch_factor"] = params.get(
                    "prefetch_factor_dataloader", 2
                )
            return TorchDataLoader(
                self.inference_dataset,
                batch_size=1,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=params.get("pin_memory_dataloader", False),
                **extra_kwargs,
            )
        elif modality in ["path", "histo"]:
            return self.dataframe.iterrows()
        raise ValueError(f"Unsupported modality for inference: {modality}")
GANDLF/entrypoints/run.py CHANGED
@@ -14,16 +14,19 @@ from GANDLF.cli import main_run, copyrightMessage
14
14
  from GANDLF.entrypoints import append_copyright_to_help
15
15
  from GANDLF.utils import logger_setup
16
16
 
17
+ from warnings import warn
18
+
17
19
 
18
20
  def _run(
19
21
  config: str,
20
22
  input_data: str,
21
23
  train_flag: bool,
22
24
  model_dir: str,
23
- device: str,
24
25
  reset_flag: bool,
25
26
  resume_flag: bool,
27
+ device: Optional[str],
26
28
  output_path: Optional[str],
29
+ profile: Optional[bool] = False,
27
30
  ):
28
31
  if model_dir is None and output_path:
29
32
  model_dir = output_path
@@ -48,7 +51,11 @@ def _run(
48
51
  "'reset' and 'resume' are mutually exclusive; 'resume' will be used."
49
52
  )
50
53
  reset_flag = False
51
-
54
+ if device is not None:
55
+ warn(
56
+ "Device parameter is deprecated and has no effect. See migration guide docs.",
57
+ DeprecationWarning,
58
+ )
52
59
  # TODO: check that output_path is not passed in training mode;
53
60
  # maybe user misconfigured the command
54
61
 
@@ -56,20 +63,20 @@ def _run(
56
63
  logging.debug(f"{input_data=}")
57
64
  logging.debug(f"{train_flag=}")
58
65
  logging.debug(f"{model_dir=}")
59
- logging.debug(f"{device=}")
60
66
  logging.debug(f"{reset_flag=}")
61
67
  logging.debug(f"{resume_flag=}")
62
68
  logging.debug(f"{output_path=}")
69
+ logging.debug(f"{profile=}")
63
70
 
64
71
  main_run(
65
72
  data_csv=input_data,
66
73
  config_file=config,
67
74
  model_dir=model_dir,
68
75
  train_mode=train_flag,
69
- device=device,
70
76
  resume=resume_flag,
71
77
  reset=reset_flag,
72
78
  output_dir=output_path,
79
+ profile=profile,
73
80
  )
74
81
  print("Finished.")
75
82
 
@@ -105,21 +112,6 @@ def _run(
105
112
  help="Training: Output directory to save intermediate files and model weights; "
106
113
  "inference: location of previous training session output",
107
114
  )
108
- @click.option(
109
- "--device",
110
- "-d",
111
- # TODO: Not sure it's worth to restrict this list. What about other devices?
112
- # GaNDLF guarantees to work properly with these two options, but
113
- # other values may be partially working also.
114
- # * GaNDLF code convert `-1` to `cpu` (i.e. it is expected somebody may pass -1)
115
- # * `cuda:0` should work also, isn't it? Just would not be treated as `cuda`
116
- # * Would `mps` work?
117
- # * int values (like `1`) - are they supported? (legacy mode for cuda https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)
118
- type=click.Choice(["cuda", "cpu"]),
119
- required=True, # FIXME: either keep default value, or set required flag
120
- help="Device to perform requested session on 'cpu' or 'cuda'; "
121
- "for cuda, ensure CUDA_VISIBLE_DEVICES env var is set",
122
- )
123
115
  @click.option(
124
116
  "--reset",
125
117
  "-rt",
@@ -145,18 +137,31 @@ def _run(
145
137
  default=None,
146
138
  help="Output file which will contain the logs.",
147
139
  )
140
+ @click.option(
141
+ "--profile",
142
+ "-pf",
143
+ is_flag=True,
144
+ help="Track the run time and memory consumption for each layer",
145
+ )
146
+ @click.option(
147
+ "--device",
148
+ "-d",
149
+ type=str,
150
+ help="DEPRECATED - has no effect, see migration guide docs. Device to run the model on.",
151
+ )
148
152
  @append_copyright_to_help
149
153
  def new_way(
150
154
  config: str,
151
155
  input_data: str,
152
156
  train: bool,
153
157
  model_dir: str,
154
- device: str,
155
158
  reset: bool,
156
159
  resume: bool,
157
160
  output_path: str,
158
161
  raw_input: str,
162
+ profile: bool,
159
163
  log_file: str,
164
+ device: str,
160
165
  ):
161
166
  """Semantic segmentation, regression, and classification for medical images using Deep Learning."""
162
167
 
@@ -166,10 +171,11 @@ def new_way(
166
171
  input_data=input_data,
167
172
  train_flag=train,
168
173
  model_dir=model_dir,
169
- device=device,
170
174
  reset_flag=reset,
171
175
  resume_flag=resume,
172
176
  output_path=output_path,
177
+ profile=profile,
178
+ device=device,
173
179
  )
174
180
 
175
181
 
@@ -227,15 +233,7 @@ def old_way():
227
233
  type=str,
228
234
  help="Training: Output directory to save intermediate files and model weights; inference: location of previous training session output",
229
235
  )
230
- parser.add_argument(
231
- "-d",
232
- "--device",
233
- default="cuda", # TODO: default value doesn't work as arg is required
234
- metavar="",
235
- type=str,
236
- required=True,
237
- help="Device to perform requested session on 'cpu' or 'cuda'; for cuda, ensure CUDA_VISIBLE_DEVICES env var is set",
238
- )
236
+
239
237
  parser.add_argument(
240
238
  "-rt",
241
239
  "--reset",
@@ -267,6 +265,13 @@ def old_way():
267
265
  version="%(prog)s v{}".format(version) + "\n\n" + copyrightMessage,
268
266
  help="Show program's version number and exit.",
269
267
  )
268
+ parser.add_argument(
269
+ "-d",
270
+ "--device",
271
+ metavar="",
272
+ type=str,
273
+ help="DEPRECATED - has no effect, see migration guide docs. Device to run the model on.",
274
+ )
270
275
 
271
276
  # This is a dummy argument that exists to trigger MLCube mounting requirements.
272
277
  # Do not remove.
@@ -282,10 +287,10 @@ def old_way():
282
287
  input_data=args.inputdata,
283
288
  train_flag=args.train,
284
289
  model_dir=args.modeldir,
285
- device=args.device,
286
290
  reset_flag=args.reset,
287
291
  resume_flag=args.resume,
288
292
  output_path=args.outputdir,
293
+ device=args.device,
289
294
  )
290
295
 
291
296
 
@@ -4,16 +4,16 @@ from typing import Optional
4
4
 
5
5
  import pandas as pd
6
6
  import torch
7
- import torch.nn.functional as F
8
- from GANDLF.compute import inference_loop
9
7
  from GANDLF.utils import get_unique_timestamp
8
+ import lightning.pytorch as pl
9
+ from GANDLF.models.lightning_module import GandlfLightningModule
10
+ from GANDLF.data.lightning_datamodule import GandlfInferenceDatamodule
10
11
 
11
12
 
12
13
  def InferenceManager(
13
14
  dataframe: pd.DataFrame,
14
15
  modelDir: str,
15
16
  parameters: dict,
16
- device: str,
17
17
  outputDir: Optional[str] = None,
18
18
  ) -> None:
19
19
  """
@@ -27,7 +27,7 @@ def InferenceManager(
27
27
  outputDir (Optional[str], optional): The output directory for the inference results. Defaults to None.
28
28
  """
29
29
  if outputDir is None:
30
- outputDir = os.path.join(modelDir, get_unique_timestamp())
30
+ outputDir = os.path.join(modelDir, get_unique_timestamp(), "output_inference")
31
31
  print(
32
32
  "Output directory not provided, creating a new directory with a unique timestamp: ",
33
33
  outputDir,
@@ -43,6 +43,35 @@ def InferenceManager(
43
43
  n_folds = parameters["nested_training"]["validation"]
44
44
  modelDir_split = modelDir.split(",") if "," in modelDir else [modelDir]
45
45
 
46
+ # This should be handled by config parser
47
+ accelerator = parameters.get("accelerator", "auto")
48
+ allowed_accelerators = ["cpu", "gpu", "auto"]
49
+ # codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
50
+ assert (
51
+ accelerator in allowed_accelerators
52
+ ), f"Invalid accelerator selected: {accelerator}. Please select from {allowed_accelerators}"
53
+ strategy = parameters.get("strategy", "auto")
54
+ allowed_strategies = ["auto", "ddp"]
55
+ # codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
56
+ assert (
57
+ strategy in allowed_strategies
58
+ ), f"Invalid strategy selected: {strategy}. Please select from {allowed_strategies}"
59
+ precision = parameters.get("precision", "32")
60
+ allowed_precisions = [
61
+ "64",
62
+ "64-true",
63
+ "32",
64
+ "32-true",
65
+ "16",
66
+ "16-mixed",
67
+ "bf16",
68
+ "bf16-mixed",
69
+ ]
70
+ # codacy ignore Generic/ReDoS: This is not a SQL query, it's an error message.
71
+ assert (
72
+ precision in allowed_precisions
73
+ ), f"Invalid precision selected: {precision}. Please select from {allowed_precisions}"
74
+
46
75
  averaged_probs_list = []
47
76
  for current_modelDir in modelDir_split:
48
77
  fold_dirs = (
@@ -55,7 +84,6 @@ def InferenceManager(
55
84
  else [current_modelDir]
56
85
  )
57
86
 
58
- probs_list = []
59
87
  is_classification = parameters["problem_type"] == "classification"
60
88
  parameters["model"].setdefault("type", "torch")
61
89
  class_list = (
@@ -63,27 +91,43 @@ def InferenceManager(
63
91
  if is_classification
64
92
  else None
65
93
  )
66
-
94
+ probs_list = None
67
95
  for fold_dir in fold_dirs:
68
- parameters["current_fold_dir"] = fold_dir
69
- inference_loop(
70
- inferenceDataFromPickle=dataframe,
71
- modelDir=fold_dir,
72
- device=device,
73
- parameters=parameters,
74
- outputDir=outputDir,
96
+ trainer = pl.Trainer(
97
+ accelerator=accelerator,
98
+ strategy=strategy,
99
+ fast_dev_run=False,
100
+ devices=parameters.get("devices", "auto"),
101
+ num_nodes=parameters.get("num_nodes", 1),
102
+ precision=precision,
103
+ gradient_clip_algorithm=parameters["clip_mode"],
104
+ gradient_clip_val=parameters["clip_grad"],
105
+ max_epochs=parameters["num_epochs"],
106
+ sync_batchnorm=False,
107
+ enable_checkpointing=False,
108
+ logger=False,
109
+ num_sanity_val_steps=0,
75
110
  )
111
+ datamodule = GandlfInferenceDatamodule(dataframe, parameters)
112
+ parameters = datamodule.updated_parameters_dict
113
+ lightning_module = GandlfLightningModule(parameters, output_dir=fold_dir)
76
114
 
115
+ if parameters.get("auto_batch_size_find", False):
116
+ print(
117
+ "Auto batch size find is not supported in inference. Dataloader batch size is always 1."
118
+ )
119
+
120
+ trainer.predict(lightning_module, datamodule=datamodule)
77
121
  if is_classification:
78
- logits_path = os.path.join(fold_dir, "logits.csv")
79
- if os.path.isfile(logits_path):
80
- fold_logits = pd.read_csv(logits_path)[class_list].values
81
- fold_logits = torch.from_numpy(fold_logits)
82
- fold_probs = F.softmax(fold_logits, dim=1)
83
- probs_list.append(fold_probs)
84
-
85
- if is_classification and probs_list:
86
- probs_list = torch.stack(probs_list)
122
+ prob_values_for_all_subjects_in_fold = list(
123
+ lightning_module.subject_classification_class_probabilities.values()
124
+ )
125
+ if prob_values_for_all_subjects_in_fold:
126
+ probs_list = torch.stack(
127
+ prob_values_for_all_subjects_in_fold, dim=1
128
+ )
129
+
130
+ if is_classification and probs_list is not None:
87
131
  averaged_probs_list.append(torch.mean(probs_list, 0))
88
132
 
89
133
  # this logic should be changed if we want to do multi-fold inference for histo images
@@ -93,9 +137,9 @@ def InferenceManager(
93
137
  )
94
138
  averaged_probs_df["SubjectID"] = dataframe.iloc[:, 0]
95
139
 
96
- averaged_probs_across_models = torch.mean(
97
- torch.stack(averaged_probs_list), 0
98
- ).numpy()
140
+ averaged_probs_across_models = (
141
+ torch.mean(torch.stack(averaged_probs_list), 0).cpu().numpy()
142
+ )
99
143
  averaged_probs_df[class_list] = averaged_probs_across_models
100
144
  averaged_probs_df["PredictedClass"] = [
101
145
  class_list[idx] for idx in averaged_probs_across_models.argmax(axis=1)
GANDLF/losses/__init__.py CHANGED
@@ -13,7 +13,6 @@ from .segmentation import (
13
13
  from .regression import CE, CEL, MSE_loss, L1_loss
14
14
  from .hybrid import DCCE, DCCE_Logits, DC_Focal
15
15
 
16
-
17
16
  # global defines for the losses
18
17
  global_losses_dict = {
19
18
  "dc": MCD_loss,
@@ -38,3 +37,26 @@ global_losses_dict = {
38
37
  "focal": FocalLoss,
39
38
  "dc_focal": DC_Focal,
40
39
  }
40
+
41
+
42
def get_loss(params: dict) -> object:
    """
    Look up the loss function selected in the configuration.

    Args:
        params (dict): The parameters' dictionary. ``params["loss_function"]``
            is either the loss name or a dict whose first key is the loss name;
            the name is matched case-insensitively.

    Returns:
        loss (object): The entry of ``global_losses_dict`` for that name.
    """
    # TODO This check looks like legacy code, should we have it?
    loss_config = params["loss_function"]
    loss_name = list(loss_config)[0] if isinstance(loss_config, dict) else loss_config
    chosen_loss = loss_name.lower()

    assert (
        chosen_loss in global_losses_dict
    ), f"Could not find the requested loss function '{params['loss_function']}'"

    return global_losses_dict[chosen_loss]
@@ -0,0 +1,79 @@
1
+ import torch
2
+ from abc import ABC, abstractmethod
3
+
4
+ from GANDLF.losses import get_loss
5
+
6
+
7
class AbstractLossCalculator(ABC):
    """
    Base class for callables that compute a loss value.

    The configured loss function is resolved once, at construction, via
    ``get_loss`` and stored as ``self.loss``; subclasses implement
    ``__call__``.

    Args:
        params (dict): The overall parameters dictionary.
    """

    def __init__(self, params: dict):
        super().__init__()
        self.params = params
        self._initialize_loss()

    def _initialize_loss(self):
        """Resolve the configured loss function into ``self.loss``."""
        self.loss = get_loss(self.params)

    @abstractmethod
    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, *args
    ) -> torch.Tensor:
        """Compute the loss of ``prediction`` against ``target``."""
21
+
22
+
23
class LossCalculatorSDNet(AbstractLossCalculator):
    """
    Loss calculator for SDNet-style architectures.

    When the model returns multiple outputs, the loss is
    ``0.01 * kld + reconstruction (L1) + 10 * segmentation + cycle (MSE)``.
    Otherwise the configured loss is applied directly to the prediction.
    """

    def __init__(self, params):
        super().__init__(params)
        # The auxiliary terms must use fixed L1 / KLD / MSE losses, not the
        # configured segmentation loss that ``get_loss(params)`` returns.
        from GANDLF.losses import global_losses_dict

        self.l1_loss = global_losses_dict["l1"]
        self.kld_loss = global_losses_dict["kld"]
        self.mse_loss = global_losses_dict["mse"]

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, *args):
        """
        Args:
            prediction: Either a tuple/list of SDNet outputs or a single
                prediction tensor. For the tuple case this code reads indices
                0..4 as: segmentation, reconstruction, and three latent
                tensors (2/3 feed the KLD term, 2/4 the cycle term) --
                TODO confirm against the SDNet model.
            target: Ground truth; a trailing singleton dim is squeezed for the
                segmentation term.
            *args: ``args[0]`` must be the input image when ``prediction``
                holds multiple outputs.

        Raises:
            ValueError: If multiple outputs are given but no image is passed.
        """
        # Previously guarded by ``len(prediction) < 2``, which made the
        # indexing below (up to prediction[4]) always fail.
        if isinstance(prediction, (list, tuple)) and len(prediction) > 1:
            if not args:
                raise ValueError(
                    "The input image must be passed to compute the SDNet loss."
                )
            image: torch.Tensor = args[0]
            loss_seg = self.loss(prediction[0], target.squeeze(-1), self.params)
            loss_reco = self.l1_loss(prediction[1], image[:, :1, ...], None)
            loss_kld = self.kld_loss(prediction[2], prediction[3])
            loss_cycle = self.mse_loss(prediction[2], prediction[4], None)
            return 0.01 * loss_kld + loss_reco + 10 * loss_seg + loss_cycle
        return self.loss(prediction, target, self.params)
40
+
41
+
42
class LossCalculatorDeepSupervision(AbstractLossCalculator):
    """
    Weighted sum of per-output losses for deep-supervision architectures.

    Output ``i`` of ``prediction`` is compared with ``target[i]``. The result
    is scaled by ``self.loss_weights[i]``.
    """

    def __init__(self, params):
        super().__init__(params)
        # One weight per supervised output (position in ``prediction``), not
        # per class. Carried over from the previous GaNDLF loss code; it caps
        # the number of supervised outputs at 4.
        self.loss_weights = [0.5, 0.25, 0.175, 0.075]

    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, *args
    ) -> torch.Tensor:
        """
        Raises:
            ValueError: If ``prediction`` is empty or has more outputs than
                there are loss weights.
        """
        num_outputs = len(prediction)
        if num_outputs == 0:
            raise ValueError("Deep supervision loss received no predictions.")
        if num_outputs > len(self.loss_weights):
            raise ValueError(
                f"Deep supervision loss supports at most {len(self.loss_weights)} "
                f"outputs, got {num_outputs}."
            )
        loss_values = [
            self.loss(pred, target[i], self.params) * weight
            for i, (pred, weight) in enumerate(zip(prediction, self.loss_weights))
        ]
        return torch.stack(loss_values).sum()
60
+
61
+
62
class LossCalculatorSimple(AbstractLossCalculator):
    """Apply the configured loss directly to one prediction/target pair."""

    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, *args
    ) -> torch.Tensor:
        # Extra positional arguments are accepted for interface compatibility
        # with the other calculators but are ignored here.
        loss_fn = self.loss
        return loss_fn(prediction, target, self.params)
67
+
68
+
69
class LossCalculatorFactory:
    """
    Choose the loss calculator matching the configured model architecture.

    Args:
        params (dict): The overall parameters dictionary;
            ``params["model"]["architecture"]`` selects the calculator.
    """

    def __init__(self, params: dict):
        self.params = params

    def get_loss_calculator(self) -> AbstractLossCalculator:
        """
        Return the SDNet calculator for "sdnet", the deep-supervision
        calculator for architectures containing "deep", and the simple
        calculator otherwise. Matching is case-insensitive for both checks.
        """
        architecture = self.params["model"]["architecture"].lower()
        if architecture == "sdnet":
            return LossCalculatorSDNet(self.params)
        if "deep" in architecture:
            return LossCalculatorDeepSupervision(self.params)
        return LossCalculatorSimple(self.params)
@@ -87,7 +87,8 @@ def generic_loss_calculator(
87
87
  accumulated_loss = 0
88
88
  # default to a ridiculous value so that it is ignored by default
89
89
  ignore_class = -1e10 if ignore_class is None else ignore_class
90
-
90
+ predicted = predicted.squeeze(-1)
91
+ target = target.squeeze(-1)
91
92
  for class_index in range(num_class):
92
93
  if class_index != ignore_class:
93
94
  current_loss = loss_criteria(
@@ -334,7 +335,7 @@ def FocalLoss(
334
335
  torch.Tensor: Computed focal loss for a single class.
335
336
  """
336
337
  ce_loss = torch.nn.CrossEntropyLoss(reduce=False)
337
- logpt = ce_loss(preds, target)
338
+ logpt = ce_loss(preds.squeeze(-1), target.squeeze(-1))
338
339
  pt = torch.exp(-logpt)
339
340
  loss = ((1 - pt) ** gamma) * logpt
340
341
  return_loss = loss.sum()
@@ -1,6 +1,7 @@
1
1
  """
2
2
  All the metrics are to be called from here
3
3
  """
4
+ from warnings import warn
4
5
  from typing import Union
5
6
 
6
7
  from GANDLF.losses.regression import MSE_loss, CEL
@@ -40,6 +41,7 @@ from .synthesis import (
40
41
  )
41
42
  import GANDLF.metrics.classification as classification
42
43
  import GANDLF.metrics.regression as regression
44
+ from .segmentation_panoptica import generate_instance_segmentation
43
45
 
44
46
 
45
47
  # global defines for the metrics
@@ -100,6 +102,30 @@ surface_distance_ids = [
100
102
  ]
101
103
 
102
104
 
105
def get_metrics(params: dict) -> dict:
    """
    Collect the metric functions requested in the configuration.

    Names in ``params["metrics"]`` are lower-cased and looked up in
    ``global_metrics_dict``. Unknown names trigger a ``UserWarning`` and are
    skipped.

    Args:
        params (dict): A dictionary containing the overall training parameters.

    Returns:
        metric_calculators (dict): Mapping from lower-cased metric name to the
            corresponding entry of ``global_metrics_dict``.
    """
    metric_calculators = {}
    for requested_name in params["metrics"]:
        metric_name = requested_name.lower()
        if metric_name not in global_metrics_dict:
            warn(
                f"Metric {metric_name} not found in global metrics dictionary, it will not be used.",
                UserWarning,
            )
            continue
        metric_calculators[metric_name] = global_metrics_dict[metric_name]
    return metric_calculators
127
+
128
+
103
129
  def overall_stats(predictions, ground_truth, params) -> dict[str, Union[float, list]]:
104
130
  """
105
131
  Generates a dictionary of metrics calculated on the overall predictions and ground truths.
GANDLF/metrics/generic.py CHANGED
@@ -34,7 +34,7 @@ def generic_function_output_with_check(
34
34
  print(
35
35
  "WARNING: Negative values detected in prediction, cannot compute torchmetrics calculations."
36
36
  )
37
- return torch.zeros((1), device=prediction.device)
37
+ return torch.tensor(0, device=prediction.device)
38
38
  else:
39
39
  # I need to do this with try-except, otherwise for binary problems it will
40
40
  # raise and error as the binary metrics do not have .num_classes