GANDLF 0.1.3.dev20250202__py3-none-any.whl → 0.1.6.dev20251109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of GANDLF has been flagged as possibly problematic by the registry.

Files changed (57)
  1. GANDLF/cli/deploy.py +2 -2
  2. GANDLF/cli/generate_metrics.py +35 -1
  3. GANDLF/cli/main_run.py +4 -10
  4. GANDLF/compute/__init__.py +0 -2
  5. GANDLF/compute/forward_pass.py +0 -1
  6. GANDLF/compute/generic.py +107 -2
  7. GANDLF/compute/inference_loop.py +4 -4
  8. GANDLF/compute/loss_and_metric.py +1 -2
  9. GANDLF/compute/training_loop.py +10 -10
  10. GANDLF/config_manager.py +33 -717
  11. GANDLF/configuration/__init__.py +0 -0
  12. GANDLF/configuration/default_config.py +73 -0
  13. GANDLF/configuration/differential_privacy_config.py +16 -0
  14. GANDLF/configuration/exclude_parameters.py +1 -0
  15. GANDLF/configuration/model_config.py +82 -0
  16. GANDLF/configuration/nested_training_config.py +25 -0
  17. GANDLF/configuration/optimizer_config.py +121 -0
  18. GANDLF/configuration/parameters_config.py +10 -0
  19. GANDLF/configuration/patch_sampler_config.py +11 -0
  20. GANDLF/configuration/post_processing_config.py +10 -0
  21. GANDLF/configuration/pre_processing_config.py +94 -0
  22. GANDLF/configuration/scheduler_config.py +92 -0
  23. GANDLF/configuration/user_defined_config.py +131 -0
  24. GANDLF/configuration/utils.py +96 -0
  25. GANDLF/configuration/validators.py +479 -0
  26. GANDLF/data/__init__.py +14 -16
  27. GANDLF/data/lightning_datamodule.py +119 -0
  28. GANDLF/entrypoints/run.py +36 -31
  29. GANDLF/inference_manager.py +69 -25
  30. GANDLF/losses/__init__.py +23 -1
  31. GANDLF/losses/loss_calculators.py +79 -0
  32. GANDLF/losses/segmentation.py +3 -2
  33. GANDLF/metrics/__init__.py +26 -0
  34. GANDLF/metrics/generic.py +1 -1
  35. GANDLF/metrics/metric_calculators.py +102 -0
  36. GANDLF/metrics/panoptica_config_brats.yaml +56 -0
  37. GANDLF/metrics/segmentation_panoptica.py +49 -0
  38. GANDLF/models/__init__.py +8 -3
  39. GANDLF/models/lightning_module.py +2102 -0
  40. GANDLF/optimizers/__init__.py +4 -8
  41. GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
  42. GANDLF/schedulers/__init__.py +11 -4
  43. GANDLF/schedulers/wrap_torch.py +15 -3
  44. GANDLF/training_manager.py +160 -50
  45. GANDLF/utils/__init__.py +5 -3
  46. GANDLF/utils/imaging.py +176 -35
  47. GANDLF/utils/modelio.py +12 -8
  48. GANDLF/utils/pred_target_processors.py +71 -0
  49. GANDLF/utils/tensor.py +2 -1
  50. GANDLF/utils/write_parse.py +1 -1
  51. GANDLF/version.py +1 -1
  52. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/METADATA +16 -11
  53. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/RECORD +57 -34
  54. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/WHEEL +1 -1
  55. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/entry_points.txt +0 -0
  56. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info/licenses}/LICENSE +0 -0
  57. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/top_level.txt +0 -0
GANDLF/configuration/utils.py ADDED
@@ -0,0 +1,96 @@
+ import logging
+ from typing import Optional, Union
+
+
+ from typing import Type
+ from pydantic import BaseModel, ValidationError, create_model
+ from pydantic_core import ErrorDetails
+
+
+ def initialize_key(
+     parameters: dict, key: str, value: Optional[Union[str, float, list, dict]] = None
+ ) -> dict:
+     """
+     This function initializes a key in the parameters dictionary to a value if it is absent.
+
+     Args:
+         parameters (dict): The parameter dictionary.
+         key (str): The key to initialize.
+         value (Optional[Union[str, float, list, dict]], optional): The value to initialize. Defaults to None.
+
+     Returns:
+         dict: The parameter dictionary.
+     """
+     if parameters is None:
+         parameters = {}
+     if key in parameters:
+         if parameters[key] is not None:
+             if isinstance(parameters[key], dict):
+                 # if key is present but not defined
+                 if len(parameters[key]) == 0:
+                     parameters[key] = value
+     else:
+         parameters[key] = value  # if key is absent
+
+     return parameters
+
+
+ # Define custom error messages. The key must be a pydantic type error.
+ CUSTOM_MESSAGES = {
+     "literal_error": "The input must be a valid option, please read the documentation",
+     "missing": "This parameter is required. Please define it",
+ }
+
+
+ def convert_errors(e: ValidationError, custom_messages=None) -> list[ErrorDetails]:
+     if custom_messages is None:
+         custom_messages = CUSTOM_MESSAGES
+     new_errors: list[ErrorDetails] = []
+     for error in e.errors():
+         custom_message = custom_messages.get(error["type"])
+         if custom_message:
+             ctx = error.get("ctx")
+             error["msg"] = custom_message.format(**ctx) if ctx else custom_message
+         new_errors.append(error)
+     return new_errors
+
+
+ def extract_messages(errors: list[ErrorDetails]) -> list[str]:
+     error_messages: list[str] = []
+     for error in errors:
+         location = error.get("loc")
+         if len(location) == 2:
+             message = f"Configuration Error: Parameter: ({location[0]}, {location[1]}) - {error['msg']}"
+         else:
+             message = (
+                 f"Configuration Error: Parameter: ({location[0]}) - {error['msg']}"
+             )
+         error_messages.append(message)
+     return error_messages
+
+
+ def handle_configuration_errors(e: ValidationError):
+     messages = extract_messages(convert_errors(e))
+     for message in messages:
+         logging.error(message)
+
+
+ def combine_models(base_model: Type[BaseModel], extra_model: Type[BaseModel]):
+     """Combine base model with an extra model dynamically."""
+     fields = {}
+     # Collect base model fields
+     for field_name, field_info in base_model.model_fields.items():
+         fields[field_name] = (
+             field_info.annotation,
+             field_info.default if field_info.default is not Ellipsis else ...,
+         )
+
+     # Add fields from the extra model
+     for field_name, field_info in extra_model.model_fields.items():
+         fields[field_name] = (
+             field_info.annotation,
+             field_info.default if field_info.default is not Ellipsis else ...,
+         )
+
+     # Return the new dynamically combined model
+     return create_model(base_model.__name__, **fields)
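The new `GANDLF/configuration/utils.py` helpers above are generic: `combine_models` merges the fields of two pydantic models into one dynamically created model, `initialize_key` back-fills a missing or empty dictionary entry, and `handle_configuration_errors` logs readable messages for a `ValidationError`. A minimal usage sketch follows; the `SchedulerBase` and `TriangleExtra` models are made up for illustration and are not part of the wheel.

from pydantic import BaseModel, ValidationError
from GANDLF.configuration.utils import (
    combine_models,
    handle_configuration_errors,
    initialize_key,
)

# Hypothetical models, for illustration only.
class SchedulerBase(BaseModel):
    type: str = "triangle"

class TriangleExtra(BaseModel):
    min_lr: float = 1e-5

# combine_models returns a new model carrying the fields of both inputs.
Combined = combine_models(SchedulerBase, TriangleExtra)
print(Combined().model_dump())  # {'type': 'triangle', 'min_lr': 1e-05}

# initialize_key fills a missing or empty key with the supplied default.
print(initialize_key({"swap": {}}, "swap", {"num_iterations": 100}))

# handle_configuration_errors turns a ValidationError into logged messages.
try:
    Combined(type=123)  # an int is rejected for a str field in pydantic v2
except ValidationError as err:
    handle_configuration_errors(err)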
GANDLF/configuration/validators.py ADDED
@@ -0,0 +1,479 @@
+ import ast
+ import traceback
+ from copy import deepcopy
+ from GANDLF.configuration.differential_privacy_config import DifferentialPrivacyConfig
+ from GANDLF.configuration.post_processing_config import PostProcessingConfig
+ from GANDLF.configuration.pre_processing_config import (
+     HistogramMatchingConfig,
+     PreProcessingConfig,
+ )
+ from GANDLF.data.post_process import postprocessing_after_reverse_one_hot_encoding
+ import numpy as np
+ import sys
+ from GANDLF.configuration.optimizer_config import OptimizerConfig, optimizer_dict_config
+ from GANDLF.configuration.patch_sampler_config import PatchSamplerConfig
+ from GANDLF.configuration.scheduler_config import (
+     SchedulerConfig,
+     schedulers_dict_config,
+ )
+ from GANDLF.configuration.utils import initialize_key, combine_models
+ from GANDLF.metrics import surface_distance_ids
+
+
+ def validate_loss_function(value) -> dict:
+     if isinstance(value, dict):  # if this is a dict
+         if len(value) > 0:  # only proceed if something is defined
+             for key in value:  # iterate through all keys
+                 if key == "mse":
+                     if (value[key] is None) or not ("reduction" in value[key]):
+                         value[key] = {}
+                         value[key]["reduction"] = "mean"
+                 else:
+                     # use simple string for other functions - can be extended with parameters, if needed
+                     value = key
+     else:
+         if value == "focal":
+             value = {"focal": {}}
+             value["focal"]["gamma"] = 2.0
+             value["focal"]["size_average"] = True
+         elif value == "mse":
+             value = {"mse": {}}
+             value["mse"]["reduction"] = "mean"
+
+     return value
+
+
+ def validate_metrics(value) -> dict:
+     if not isinstance(value, dict):
+         temp_dict = {}
+     else:
+         temp_dict = value
+
+     # initialize metrics dict
+     for metric in value:
+         # assigning a new variable because some metrics can be dicts, and we want to get the first key
+         comparison_string = metric
+         if isinstance(metric, dict):
+             comparison_string = list(metric.keys())[0]
+         # these metrics always need to be dicts
+         if comparison_string in [
+             "accuracy",
+             "f1",
+             "precision",
+             "recall",
+             "specificity",
+             "iou",
+         ]:
+             if not isinstance(metric, dict):
+                 temp_dict[metric] = {}
+             else:
+                 temp_dict[comparison_string] = metric
+         elif not isinstance(metric, dict):
+             temp_dict[metric] = None
+
+         # special case for accuracy, precision, recall, and specificity; which could be dicts
+         ## need to find a better way to do this
+         if any(
+             _ in comparison_string
+             for _ in ["precision", "recall", "specificity", "accuracy", "f1"]
+         ):
+             if comparison_string != "classification_accuracy":
+                 temp_dict[comparison_string] = initialize_key(
+                     temp_dict[comparison_string], "average", "weighted"
+                 )
+                 temp_dict[comparison_string] = initialize_key(
+                     temp_dict[comparison_string], "multi_class", True
+                 )
+                 temp_dict[comparison_string] = initialize_key(
+                     temp_dict[comparison_string], "mdmc_average", "samplewise"
+                 )
+                 temp_dict[comparison_string] = initialize_key(
+                     temp_dict[comparison_string], "threshold", 0.5
+                 )
+                 if comparison_string == "accuracy":
+                     temp_dict[comparison_string] = initialize_key(
+                         temp_dict[comparison_string], "subset_accuracy", False
+                     )
+         elif "iou" in comparison_string:
+             temp_dict["iou"] = initialize_key(
+                 temp_dict["iou"], "reduction", "elementwise_mean"
+             )
+             temp_dict["iou"] = initialize_key(temp_dict["iou"], "threshold", 0.5)
+         elif comparison_string in surface_distance_ids:
+             temp_dict[comparison_string] = initialize_key(
+                 temp_dict[comparison_string], "connectivity", 1
+             )
+             temp_dict[comparison_string] = initialize_key(
+                 temp_dict[comparison_string], "threshold", None
+             )
+
+     value = temp_dict
+     return value
+
+
+ def validate_class_list(value):
+     if isinstance(value, str):
+         if ("||" in value) or ("&&" in value):
+             # special case for multi-class computation - this needs to be handled during one-hot encoding mask construction
+             print(
+                 "WARNING: This is a special case for multi-class computation, where different labels are processed together, `reverse_one_hot` will need mapping information to work correctly"
+             )
+             temp_class_list = value
+             # we don't need the brackets
+             temp_class_list = temp_class_list.replace("[", "")
+             temp_class_list = temp_class_list.replace("]", "")
+             value = temp_class_list.split(",")
+         else:
+             try:
+                 value = ast.literal_eval(value)
+                 return value
+             except Exception as e:
+                 assert (
+                     False
+                 ), f"Could not evaluate the `class_list` in `model`, Exception: {str(e)}, {traceback.format_exc()}"
+                 # logging.error(
+                 #     f"Could not evaluate the `class_list` in `model`, Exception: {str(e)}, {traceback.format_exc()}"
+                 # )
+     return value
+
+
+ def validate_patch_size(patch_size, dimension) -> list:
+     if isinstance(patch_size, int) or isinstance(patch_size, float):
+         patch_size = [patch_size]
+     if len(patch_size) == 1 and dimension is not None:
+         actual_patch_size = []
+         for _ in range(dimension):
+             actual_patch_size.append(patch_size[0])
+         patch_size = actual_patch_size
+     if len(patch_size) == 2:  # 2d check
+         # ensuring same size during torchio processing
+         patch_size.append(1)
+         if dimension is None:
+             dimension = 2
+     elif len(patch_size) == 3:  # 2d check
+         if dimension is None:
+             dimension = 3
+     return [patch_size, dimension]
+
+
+ def validate_norm_type(norm_type, architecture):
+     if norm_type is None or norm_type.lower() == "none":
+         if not ("vgg" in architecture):
+             raise ValueError(
+                 "Normalization type cannot be 'None' for non-VGG architectures"
+             )
+     return norm_type
+
+
+ def validate_parallel_compute_command(value):
+     parallel_compute_command = value
+     parallel_compute_command = parallel_compute_command.replace("'", "")
+     parallel_compute_command = parallel_compute_command.replace('"', "")
+     value = parallel_compute_command
+     return value
+
+
+ def validate_scheduler(value, learning_rate, num_epochs):
+     if isinstance(value, str):
+         value = SchedulerConfig(type=value)
+     # Find the scheduler_config class based on the type
+     combine_scheduler_class = schedulers_dict_config[value.type]
+     # Combine it with the SchedulerConfig class
+     schedulerConfigCombine = combine_models(SchedulerConfig, combine_scheduler_class)
+     combineScheduler = schedulerConfigCombine(**value.model_dump())
+     value = SchedulerConfig(**combineScheduler.model_dump())
+
+     if value.type == "triangular":
+         if value.max_lr is None:
+             value.max_lr = learning_rate
+
+     if value.type in [
+         "reduce_on_plateau",
+         "reduce-on-plateau",
+         "plateau",
+         "exp_range",
+         "triangular",
+     ]:
+         if value.min_lr is None:
+             value.min_lr = learning_rate * 0.001
+
+     if value.type in ["warmupcosineschedule", "wcs"]:
+         value.warmup_steps = num_epochs * 0.1
+
+     if hasattr(value, "step_size") and value.step_size is None:
+         value.step_size = learning_rate / 5.0
+
+     return value
+
+
+ def validate_optimizer(value):
+     if isinstance(value, str):
+         value = OptimizerConfig(type=value)
+
+     combine_optimizer_class = optimizer_dict_config[value.type]
+     # Combine it with the OptimizerConfig class
+     optimizerConfigCombine = combine_models(OptimizerConfig, combine_optimizer_class)
+     combineOptimizer = optimizerConfigCombine(**value.model_dump())
+     value = OptimizerConfig(**combineOptimizer.model_dump())
+
+     return value
+
+
+ def validate_data_preprocessing(value) -> dict:
+     if not (value is None):
+         # perform this only when pre-processing is defined
+         if len(value) > 0:
+             thresholdOrClip = False
+             # this can be extended, as required
+             thresholdOrClipDict = ["threshold", "clip", "clamp"]
+
+             resize_requested = False
+             temp_dict = deepcopy(value)
+             for key in value:
+                 if key in ["resize", "resize_image", "resize_images", "resize_patch"]:
+                     resize_requested = True
+
+                 if key in ["resample_min", "resample_minimum"]:
+                     if "resolution" in value[key]:
+                         resize_requested = True
+                         resolution_temp = np.array(value[key]["resolution"])
+                         if resolution_temp.size == 1:
+                             temp_dict[key]["resolution"] = np.array(
+                                 [resolution_temp, resolution_temp]
+                             ).tolist()
+                     else:
+                         temp_dict.pop(key)
+
+             value = temp_dict
+
+             if resize_requested and "resample" in value:
+                 for key in ["resize", "resize_image", "resize_images", "resize_patch"]:
+                     if key in value:
+                         value.pop(key)
+
+                 print(
+                     "WARNING: Different 'resize' operations are ignored as 'resample' is defined under 'data_processing'",
+                     file=sys.stderr,
+                 )
+
+             # iterate through all keys
+             for key in value:  # iterate through all keys
+                 if key in thresholdOrClipDict:
+                     # we only allow one of threshold or clip to occur and not both
+                     assert not (
+                         thresholdOrClip
+                     ), "Use only `threshold` or `clip`, not both"
+                     thresholdOrClip = True
+                     # initialize if nothing is present
+                     if not (isinstance(value[key], dict)):
+                         value[key] = {}
+
+                     # if one of the required parameters is not present, initialize with lowest/highest possible values
+                     # this ensures the absence of a field doesn't affect processing
+                     # for threshold or clip, ensure min and max are defined
+                     if not "min" in value[key]:
+                         value[key]["min"] = sys.float_info.min
+                     if not "max" in value[key]:
+                         value[key]["max"] = sys.float_info.max
+
+             key = "histogram_matching"
+             if key in value:
+                 if value["histogram_matching"] is not False:
+                     if not (isinstance(value["histogram_matching"], dict)):
+                         value["histogram_matching"] = HistogramMatchingConfig()
+
+             key = "histogram_equalization"
+             if key in value:
+                 if value[key] is not False:
+                     # if histogram equalization is enabled, call histogram_matching
+                     value["histogram_matching"] = HistogramMatchingConfig()
+             key = "adaptive_histogram_equalization"
+             if key in value:
+                 if value[key] is not False:
+                     # if histogram equalization is enabled, call histogram_matching
+                     value["histogram_matching"] = HistogramMatchingConfig(
+                         target="adaptive"
+                     )
+
+     pre_processing = PreProcessingConfig(**value)
+     return pre_processing.model_dump(include={field for field in value.keys()})
+
+
+ def validate_data_postprocessing_after_reverse_one_hot_encoding(
+     value, data_postprocessing
+ ) -> list:
+     temp_dict = deepcopy(value)
+     for key in temp_dict:
+         if key in postprocessing_after_reverse_one_hot_encoding:
+             value[key] = data_postprocessing[key]
+             data_postprocessing.pop(key)
+     return [value, data_postprocessing]
+
+
+ def validate_patch_sampler(value):
+     if isinstance(value, str):
+         value = PatchSamplerConfig(type=value.lower())
+     return value
+
+
+ def validate_data_augmentation(value, patch_size) -> dict:
+     value["default_probability"] = value.get("default_probability", 0.5)
+     if not (value is None):
+         if len(value) > 0:  # only when augmentations are defined
+             # special case for random swapping and elastic transformations - which takes a patch size for computation
+             for key in ["swap", "elastic"]:
+                 if key in value:
+                     value[key] = initialize_key(
+                         value[key],
+                         "patch_size",
+                         np.round(np.array(patch_size) / 10).astype("int").tolist(),
+                     )
+
+             # special case for swap default initialization
+             if "swap" in value:
+                 value["swap"] = initialize_key(value["swap"], "num_iterations", 100)
+
+             # special case for affine default initialization
+             if "affine" in value:
+                 value["affine"] = initialize_key(value["affine"], "scales", 0.1)
+                 value["affine"] = initialize_key(value["affine"], "degrees", 15)
+                 value["affine"] = initialize_key(value["affine"], "translation", 2)
+
+             if "motion" in value:
+                 value["motion"] = initialize_key(value["motion"], "num_transforms", 2)
+                 value["motion"] = initialize_key(value["motion"], "degrees", 15)
+                 value["motion"] = initialize_key(value["motion"], "translation", 2)
+                 value["motion"] = initialize_key(
+                     value["motion"], "interpolation", "linear"
+                 )
+
+             # special case for random blur/noise - which takes a std-dev range
+             for std_aug in ["blur", "noise_var"]:
+                 if std_aug in value:
+                     value[std_aug] = initialize_key(value[std_aug], "std", None)
+             for std_aug in ["noise"]:
+                 if std_aug in value:
+                     value[std_aug] = initialize_key(value[std_aug], "std", [0, 1])
+
+             # special case for random noise - which takes a mean range
+             for mean_aug in ["noise", "noise_var"]:
+                 if mean_aug in value:
+                     value[mean_aug] = initialize_key(value[mean_aug], "mean", 0)
+
+             # special case for augmentations that need axis defined
+             for axis_aug in ["flip", "anisotropic", "rotate_90", "rotate_180"]:
+                 if axis_aug in value:
+                     value[axis_aug] = initialize_key(value[axis_aug], "axis", [0, 1, 2])
+
+             # special case for colorjitter
+             if "colorjitter" in value:
+                 value = initialize_key(value, "colorjitter", {})
+                 for key in ["brightness", "contrast", "saturation"]:
+                     value["colorjitter"] = initialize_key(
+                         value["colorjitter"], key, [0, 1]
+                     )
+                 value["colorjitter"] = initialize_key(
+                     value["colorjitter"], "hue", [-0.5, 0.5]
+                 )
+
+             # Added HED augmentation in gandlf
+             hed_augmentation_types = [
+                 "hed_transform",
+                 # "hed_transform_light",
+                 # "hed_transform_heavy",
+             ]
+             for augmentation_type in hed_augmentation_types:
+                 if augmentation_type in value:
+                     value = initialize_key(value, "hed_transform", {})
+                     ranges = [
+                         "haematoxylin_bias_range",
+                         "eosin_bias_range",
+                         "dab_bias_range",
+                         "haematoxylin_sigma_range",
+                         "eosin_sigma_range",
+                         "dab_sigma_range",
+                     ]
+
+                     default_range = (
+                         [-0.1, 0.1]
+                         if augmentation_type == "hed_transform"
+                         else (
+                             [-0.03, 0.03]
+                             if augmentation_type == "hed_transform_light"
+                             else [-0.95, 0.95]
+                         )
+                     )
+
+                     for key in ranges:
+                         value["hed_transform"] = initialize_key(
+                             value["hed_transform"], key, default_range
+                         )
+
+                     value["hed_transform"] = initialize_key(
+                         value["hed_transform"], "cutoff_range", [0, 1]
+                     )
+
+             # special case for anisotropic
+             if "anisotropic" in value:
+                 if not ("downsampling" in value["anisotropic"]):
+                     default_downsampling = 1.5
+                 else:
+                     default_downsampling = value["anisotropic"]["downsampling"]
+
+                 initialize_downsampling = False
+                 if isinstance(default_downsampling, list):
+                     if len(default_downsampling) != 2:
+                         initialize_downsampling = True
+                         print(
+                             "WARNING: 'anisotropic' augmentation needs to be either a single number of a list of 2 numbers: https://torchio.readthedocs.io/transforms/augmentation.html?highlight=randomswap#torchio.transforms.RandomAnisotropy.",
+                             file=sys.stderr,
+                         )
+                         default_downsampling = default_downsampling[0]  # only
+                 else:
+                     initialize_downsampling = True
+
+                 if initialize_downsampling:
+                     if default_downsampling < 1:
+                         print(
+                             "WARNING: 'anisotropic' augmentation needs the 'downsampling' parameter to be greater than 1, defaulting to 1.5.",
+                             file=sys.stderr,
+                         )
+                         # default
+                         value["anisotropic"]["downsampling"] = 1.5
+
+             for key in value:
+                 if key != "default_probability":
+                     value[key] = initialize_key(
+                         value[key], "probability", value["default_probability"]
+                     )
+     return value
+
+
+ def validate_postprocessing(value):
+     post_processing = PostProcessingConfig(**value)
+     return post_processing.model_dump(include={field for field in value.keys()})
+
+
+ def validate_differential_privacy(value, batch_size):
+     if value is None:
+         return value
+     if not isinstance(value, dict):
+         print(
+             "WARNING: Non dictionary value for the key: 'differential_privacy' was used, replacing with default valued dictionary."
+         )
+         value = DifferentialPrivacyConfig(physical_batch_size=batch_size).model_dump()
+         # these are some defaults
+
+     if value["physical_batch_size"] > batch_size:
+         print(
+             f"WARNING: The physical batch size {value['physical_batch_size']} is greater"
+             f"than the batch size {batch_size}, setting the physical batch size to the batch size."
+         )
+         value["physical_batch_size"] = batch_size
+
+     # these keys need to be parsed as floats, not strings
+     for key in ["noise_multiplier", "max_grad_norm", "delta", "epsilon"]:
+         if key in value:
+             value[key] = float(value[key])
+
+     return DifferentialPrivacyConfig(**value)
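Most of the validators above normalize loose user input (strings, scalars, partial dicts) into the structured form the new pydantic configuration classes expect. Two of the dependency-free ones can be exercised directly; the values below are made up purely for illustration.

from GANDLF.configuration.validators import (
    validate_parallel_compute_command,
    validate_patch_size,
)

# A scalar patch size is broadcast across the model dimension.
print(validate_patch_size(64, 3))             # [[64, 64, 64], 3]

# A 2D patch gets a trailing 1 for torchio, and the dimension is inferred.
print(validate_patch_size([128, 128], None))  # [[128, 128, 1], 2]

# Stray quotes are stripped from the parallel compute command.
print(validate_parallel_compute_command("'squeue'"))  # squeue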
GANDLF/data/__init__.py CHANGED
@@ -66,19 +66,17 @@ def get_testing_loader(params):
     Returns:
         torch.utils.data.DataLoader: The testing loader.
     """
-    if params["testing_data"] is None:
-        return None
-    else:
-        queue_from_dataframe = ImagesFromDataFrame(
-            get_dataframe(params["testing_data"]),
-            params,
-            train=False,
-            loader_type="testing",
-        )
-        if not ("channel_keys" in params):
-            params = populate_channel_keys_in_params(queue_from_dataframe, params)
-        return DataLoader(
-            queue_from_dataframe,
-            batch_size=1,
-            pin_memory=False,  # params["pin_memory_dataloader"], # this is going OOM if True - needs investigation
-        )
+
+    queue_from_dataframe = ImagesFromDataFrame(
+        get_dataframe(params["testing_data"]),
+        params,
+        train=False,
+        loader_type="testing",
+    )
+    if not ("channel_keys" in params):
+        params = populate_channel_keys_in_params(queue_from_dataframe, params)
+    return DataLoader(
+        queue_from_dataframe,
+        batch_size=1,
+        pin_memory=False,  # params["pin_memory_dataloader"], # this is going OOM if True - needs investigation
+    )