GANDLF 0.1.3.dev20250202__py3-none-any.whl → 0.1.6.dev20251109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of GANDLF might be problematic.

Files changed (57)
  1. GANDLF/cli/deploy.py +2 -2
  2. GANDLF/cli/generate_metrics.py +35 -1
  3. GANDLF/cli/main_run.py +4 -10
  4. GANDLF/compute/__init__.py +0 -2
  5. GANDLF/compute/forward_pass.py +0 -1
  6. GANDLF/compute/generic.py +107 -2
  7. GANDLF/compute/inference_loop.py +4 -4
  8. GANDLF/compute/loss_and_metric.py +1 -2
  9. GANDLF/compute/training_loop.py +10 -10
  10. GANDLF/config_manager.py +33 -717
  11. GANDLF/configuration/__init__.py +0 -0
  12. GANDLF/configuration/default_config.py +73 -0
  13. GANDLF/configuration/differential_privacy_config.py +16 -0
  14. GANDLF/configuration/exclude_parameters.py +1 -0
  15. GANDLF/configuration/model_config.py +82 -0
  16. GANDLF/configuration/nested_training_config.py +25 -0
  17. GANDLF/configuration/optimizer_config.py +121 -0
  18. GANDLF/configuration/parameters_config.py +10 -0
  19. GANDLF/configuration/patch_sampler_config.py +11 -0
  20. GANDLF/configuration/post_processing_config.py +10 -0
  21. GANDLF/configuration/pre_processing_config.py +94 -0
  22. GANDLF/configuration/scheduler_config.py +92 -0
  23. GANDLF/configuration/user_defined_config.py +131 -0
  24. GANDLF/configuration/utils.py +96 -0
  25. GANDLF/configuration/validators.py +479 -0
  26. GANDLF/data/__init__.py +14 -16
  27. GANDLF/data/lightning_datamodule.py +119 -0
  28. GANDLF/entrypoints/run.py +36 -31
  29. GANDLF/inference_manager.py +69 -25
  30. GANDLF/losses/__init__.py +23 -1
  31. GANDLF/losses/loss_calculators.py +79 -0
  32. GANDLF/losses/segmentation.py +3 -2
  33. GANDLF/metrics/__init__.py +26 -0
  34. GANDLF/metrics/generic.py +1 -1
  35. GANDLF/metrics/metric_calculators.py +102 -0
  36. GANDLF/metrics/panoptica_config_brats.yaml +56 -0
  37. GANDLF/metrics/segmentation_panoptica.py +49 -0
  38. GANDLF/models/__init__.py +8 -3
  39. GANDLF/models/lightning_module.py +2102 -0
  40. GANDLF/optimizers/__init__.py +4 -8
  41. GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
  42. GANDLF/schedulers/__init__.py +11 -4
  43. GANDLF/schedulers/wrap_torch.py +15 -3
  44. GANDLF/training_manager.py +160 -50
  45. GANDLF/utils/__init__.py +5 -3
  46. GANDLF/utils/imaging.py +176 -35
  47. GANDLF/utils/modelio.py +12 -8
  48. GANDLF/utils/pred_target_processors.py +71 -0
  49. GANDLF/utils/tensor.py +2 -1
  50. GANDLF/utils/write_parse.py +1 -1
  51. GANDLF/version.py +1 -1
  52. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/METADATA +16 -11
  53. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/RECORD +57 -34
  54. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/WHEEL +1 -1
  55. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/entry_points.txt +0 -0
  56. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info/licenses}/LICENSE +0 -0
  57. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/top_level.txt +0 -0
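
The new GANDLF/configuration package listed above (items 11-25) replaces the hand-maintained defaults and the initialize_parameter/initialize_key helpers that the diff below removes from GANDLF/config_manager.py with Pydantic models. As a rough illustration of the pattern, here is a minimal, hypothetical sketch; the field names are borrowed from the old parameter_defaults dictionary, and the class is not the actual content of parameters_config.py or default_config.py.

# Hypothetical sketch of a Pydantic-based configuration model; field names mirror
# the old parameter_defaults dictionary, not the actual GANDLF.configuration models.
from pydantic import BaseModel, Field, ValidationError


class SketchUserConfig(BaseModel):
    # defaults that config_manager.py previously injected by hand
    weighted_loss: bool = False
    batch_size: int = Field(default=1, ge=1)
    learning_rate: float = Field(default=0.001, gt=0)
    num_epochs: int = Field(default=100, ge=1)
    patience: int = 100
    optimizer: str = "adam"
    scheduler: str = "triangle_modified"


try:
    # ill-typed values now fail validation instead of being silently patched over
    SketchUserConfig(**{"batch_size": "four"})
except ValidationError as err:
    print(err)

Centralizing defaults in typed models like this is what lets the several hundred lines of manual default handling in _parseConfig be dropped, as the following diff shows.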
GANDLF/config_manager.py CHANGED
@@ -1,113 +1,11 @@
- # import logging
  import traceback
- from typing import Optional, Union
- import sys, yaml, ast
- import numpy as np
- from copy import deepcopy
+ from typing import Union
+ import yaml
+ from pydantic import ValidationError

- from .utils import version_check
- from GANDLF.data.post_process import postprocessing_after_reverse_one_hot_encoding
- from GANDLF.privacy.opacus import parse_opacus_params
-
- from GANDLF.metrics import surface_distance_ids
- from importlib.metadata import version
-
- ## dictionary to define defaults for appropriate options, which are evaluated
- parameter_defaults = {
- "weighted_loss": False, # whether weighted loss is to be used or not
- "verbose": False, # general application verbosity
- "q_verbose": False, # queue construction verbosity
- "medcam_enabled": False, # interpretability via medcam
- "save_training": False, # save outputs during training
- "save_output": False, # save outputs during validation/testing
- "in_memory": False, # pin data to cpu memory
- "pin_memory_dataloader": False, # pin data to gpu memory
- "scaling_factor": 1, # scaling factor for regression problems
- "q_max_length": 100, # the max length of queue
- "q_samples_per_volume": 10, # number of samples per volume
- "q_num_workers": 4, # number of worker threads to use
- "num_epochs": 100, # total number of epochs to train
- "patience": 100, # number of epochs to wait for performance improvement
- "batch_size": 1, # default batch size of training
- "learning_rate": 0.001, # default learning rate
- "clip_grad": None, # clip_gradient value
- "track_memory_usage": False, # default memory tracking
- "memory_save_mode": False, # default memory saving, if enabled, resize/resample will save files to disk
- "print_rgb_label_warning": True, # print rgb label warning
- "data_postprocessing": {}, # default data postprocessing
- "grid_aggregator_overlap": "crop", # default grid aggregator overlap strategy
- "determinism": False, # using deterministic version of computation
- "previous_parameters": None, # previous parameters to be used for resuming training and perform sanity checking
- }
-
- ## dictionary to define string defaults for appropriate options
- parameter_defaults_string = {
- "optimizer": "adam", # the optimizer
- "scheduler": "triangle_modified", # the default scheduler
- "clip_mode": None, # default clip mode
- }
-
-
- def initialize_parameter(
- params: dict,
- parameter_to_initialize: str,
- value: Optional[Union[str, list, int, dict]] = None,
- evaluate: Optional[bool] = True,
- ) -> dict:
- """
- This function will initialize the parameter in the parameters dict to the value if it is absent.
-
- Args:
- params (dict): The parameter dictionary.
- parameter_to_initialize (str): The parameter to initialize.
- value (Optional[Union[str, list, int, dict]], optional): The value to initialize. Defaults to None.
- evaluate (Optional[bool], optional): Whether to evaluate the value. Defaults to True.
-
- Returns:
- dict: The parameter dictionary.
- """
- if parameter_to_initialize in params:
- if evaluate:
- if isinstance(params[parameter_to_initialize], str):
- if params[parameter_to_initialize].lower() == "none":
- params[parameter_to_initialize] = ast.literal_eval(
- params[parameter_to_initialize]
- )
- else:
- print(
- "WARNING: Initializing '" + parameter_to_initialize + "' as " + str(value)
- )
- params[parameter_to_initialize] = value
-
- return params
-
-
- def initialize_key(
- parameters: dict, key: str, value: Optional[Union[str, float, list, dict]] = None
- ) -> dict:
- """
- This function initializes a key in the parameters dictionary to a value if it is absent.
-
- Args:
- parameters (dict): The parameter dictionary.
- key (str): The key to initialize.
- value (Optional[Union[str, float, list, dict]], optional): The value to initialize. Defaults to None.
-
- Returns:
- dict: The parameter dictionary.
- """
- if parameters is None:
- parameters = {}
- if key in parameters:
- if parameters[key] is not None:
- if isinstance(parameters[key], dict):
- # if key is present but not defined
- if len(parameters[key]) == 0:
- parameters[key] = value
- else:
- parameters[key] = value # if key is absent
-
- return parameters
+ from GANDLF.configuration.parameters_config import Parameters
+ from GANDLF.configuration.exclude_parameters import exclude_parameters
+ from GANDLF.configuration.utils import handle_configuration_errors


  def _parseConfig(
@@ -124,618 +22,21 @@ def _parseConfig(
      dict: The parameter dictionary.
  """
  params = config_file_path
- if not isinstance(config_file_path, dict):
- params = yaml.safe_load(open(config_file_path, "r"))
-
- if version_check_flag: # this is only to be used for testing
- assert (
- "version" in params
- ), "The 'version' key needs to be defined in config with 'minimum' and 'maximum' fields to determine the compatibility of configuration with code base"
- version_check(params["version"], version_to_check=version("GANDLF"))
-
- if "patch_size" in params:
- # duplicate patch size if it is an int or float
- if isinstance(params["patch_size"], int) or isinstance(
- params["patch_size"], float
- ):
- params["patch_size"] = [params["patch_size"]]
- # in case someone decides to pass a single value list
- if len(params["patch_size"]) == 1:
- actual_patch_size = []
- for _ in range(params["model"]["dimension"]):
- actual_patch_size.append(params["patch_size"][0])
- params["patch_size"] = actual_patch_size
-
- # parse patch size as needed for computations
- if len(params["patch_size"]) == 2: # 2d check
- # ensuring same size during torchio processing
- params["patch_size"].append(1)
- if "dimension" not in params["model"]:
- params["model"]["dimension"] = 2
- elif len(params["patch_size"]) == 3: # 2d check
- if "dimension" not in params["model"]:
- params["model"]["dimension"] = 3
- assert "patch_size" in params, "Patch size needs to be defined in the config file"
-
- if "resize" in params:
- print(
- "WARNING: 'resize' should be defined under 'data_processing', this will be skipped",
- file=sys.stderr,
- )
-
- assert "modality" in params, "'modality' needs to be defined in the config file"
- params["modality"] = params["modality"].lower()
- assert params["modality"] in [
- "rad",
- "histo",
- "path",
- ], "Modality should be either 'rad' or 'path'"
-
- assert (
- "loss_function" in params
- ), "'loss_function' needs to be defined in the config file"
- if "loss_function" in params:
- # check if user has passed a dict
- if isinstance(params["loss_function"], dict): # if this is a dict
- if len(params["loss_function"]) > 0: # only proceed if something is defined
- for key in params["loss_function"]: # iterate through all keys
- if key == "mse":
- if (params["loss_function"][key] is None) or not (
- "reduction" in params["loss_function"][key]
- ):
- params["loss_function"][key] = {}
- params["loss_function"][key]["reduction"] = "mean"
- else:
- # use simple string for other functions - can be extended with parameters, if needed
- params["loss_function"] = key
- else:
- # check if user has passed a single string
- if params["loss_function"] == "mse":
- params["loss_function"] = {}
- params["loss_function"]["mse"] = {}
- params["loss_function"]["mse"]["reduction"] = "mean"
- elif params["loss_function"] == "focal":
- params["loss_function"] = {}
- params["loss_function"]["focal"] = {}
- params["loss_function"]["focal"]["gamma"] = 2.0
- params["loss_function"]["focal"]["size_average"] = True
-
- assert "metrics" in params, "'metrics' needs to be defined in the config file"
- if "metrics" in params:
- if not isinstance(params["metrics"], dict):
- temp_dict = {}
- else:
- temp_dict = params["metrics"]
-
- # initialize metrics dict
- for metric in params["metrics"]:
- # assigning a new variable because some metrics can be dicts, and we want to get the first key
- comparison_string = metric
- if isinstance(metric, dict):
- comparison_string = list(metric.keys())[0]
- # these metrics always need to be dicts
- if comparison_string in [
- "accuracy",
- "f1",
- "precision",
- "recall",
- "specificity",
- "iou",
- ]:
- if not isinstance(metric, dict):
- temp_dict[metric] = {}
- else:
- temp_dict[comparison_string] = metric
- elif not isinstance(metric, dict):
- temp_dict[metric] = None
-
- # special case for accuracy, precision, recall, and specificity; which could be dicts
- ## need to find a better way to do this
- if any(
- _ in comparison_string
- for _ in ["precision", "recall", "specificity", "accuracy", "f1"]
- ):
- if comparison_string != "classification_accuracy":
- temp_dict[comparison_string] = initialize_key(
- temp_dict[comparison_string], "average", "weighted"
- )
- temp_dict[comparison_string] = initialize_key(
- temp_dict[comparison_string], "multi_class", True
- )
- temp_dict[comparison_string] = initialize_key(
- temp_dict[comparison_string], "mdmc_average", "samplewise"
- )
- temp_dict[comparison_string] = initialize_key(
- temp_dict[comparison_string], "threshold", 0.5
- )
- if comparison_string == "accuracy":
- temp_dict[comparison_string] = initialize_key(
- temp_dict[comparison_string], "subset_accuracy", False
- )
- elif "iou" in comparison_string:
- temp_dict["iou"] = initialize_key(
- temp_dict["iou"], "reduction", "elementwise_mean"
- )
- temp_dict["iou"] = initialize_key(temp_dict["iou"], "threshold", 0.5)
- elif comparison_string in surface_distance_ids:
- temp_dict[comparison_string] = initialize_key(
- temp_dict[comparison_string], "connectivity", 1
- )
- temp_dict[comparison_string] = initialize_key(
- temp_dict[comparison_string], "threshold", None
- )
-
- params["metrics"] = temp_dict
-
- # this is NOT a required parameter - a user should be able to train with NO augmentations
- params = initialize_key(params, "data_augmentation", {})
- # for all others, ensure probability is present
- params["data_augmentation"]["default_probability"] = params[
- "data_augmentation"
- ].get("default_probability", 0.5)
-
- if not (params["data_augmentation"] is None):
- if len(params["data_augmentation"]) > 0: # only when augmentations are defined
- # special case for random swapping and elastic transformations - which takes a patch size for computation
- for key in ["swap", "elastic"]:
- if key in params["data_augmentation"]:
- params["data_augmentation"][key] = initialize_key(
- params["data_augmentation"][key],
- "patch_size",
- np.round(np.array(params["patch_size"]) / 10)
- .astype("int")
- .tolist(),
- )
-
- # special case for swap default initialization
- if "swap" in params["data_augmentation"]:
- params["data_augmentation"]["swap"] = initialize_key(
- params["data_augmentation"]["swap"], "num_iterations", 100
- )
-
- # special case for affine default initialization
- if "affine" in params["data_augmentation"]:
- params["data_augmentation"]["affine"] = initialize_key(
- params["data_augmentation"]["affine"], "scales", 0.1
- )
- params["data_augmentation"]["affine"] = initialize_key(
- params["data_augmentation"]["affine"], "degrees", 15
- )
- params["data_augmentation"]["affine"] = initialize_key(
- params["data_augmentation"]["affine"], "translation", 2
- )
-
- if "motion" in params["data_augmentation"]:
- params["data_augmentation"]["motion"] = initialize_key(
- params["data_augmentation"]["motion"], "num_transforms", 2
- )
- params["data_augmentation"]["motion"] = initialize_key(
- params["data_augmentation"]["motion"], "degrees", 15
- )
- params["data_augmentation"]["motion"] = initialize_key(
- params["data_augmentation"]["motion"], "translation", 2
- )
- params["data_augmentation"]["motion"] = initialize_key(
- params["data_augmentation"]["motion"], "interpolation", "linear"
- )
-
- # special case for random blur/noise - which takes a std-dev range
- for std_aug in ["blur", "noise_var"]:
- if std_aug in params["data_augmentation"]:
- params["data_augmentation"][std_aug] = initialize_key(
- params["data_augmentation"][std_aug], "std", None
- )
- for std_aug in ["noise"]:
- if std_aug in params["data_augmentation"]:
- params["data_augmentation"][std_aug] = initialize_key(
- params["data_augmentation"][std_aug], "std", [0, 1]
- )
-
- # special case for random noise - which takes a mean range
- for mean_aug in ["noise", "noise_var"]:
- if mean_aug in params["data_augmentation"]:
- params["data_augmentation"][mean_aug] = initialize_key(
- params["data_augmentation"][mean_aug], "mean", 0
- )
-
- # special case for augmentations that need axis defined
- for axis_aug in ["flip", "anisotropic", "rotate_90", "rotate_180"]:
- if axis_aug in params["data_augmentation"]:
- params["data_augmentation"][axis_aug] = initialize_key(
- params["data_augmentation"][axis_aug], "axis", [0, 1, 2]
- )
-
- # special case for colorjitter
- if "colorjitter" in params["data_augmentation"]:
- params["data_augmentation"] = initialize_key(
- params["data_augmentation"], "colorjitter", {}
- )
- for key in ["brightness", "contrast", "saturation"]:
- params["data_augmentation"]["colorjitter"] = initialize_key(
- params["data_augmentation"]["colorjitter"], key, [0, 1]
- )
- params["data_augmentation"]["colorjitter"] = initialize_key(
- params["data_augmentation"]["colorjitter"], "hue", [-0.5, 0.5]
- )
-
- # Added HED augmentation in gandlf
- hed_augmentation_types = [
- "hed_transform",
- # "hed_transform_light",
- # "hed_transform_heavy",
- ]
- for augmentation_type in hed_augmentation_types:
- if augmentation_type in params["data_augmentation"]:
- params["data_augmentation"] = initialize_key(
- params["data_augmentation"], "hed_transform", {}
- )
- ranges = [
- "haematoxylin_bias_range",
- "eosin_bias_range",
- "dab_bias_range",
- "haematoxylin_sigma_range",
- "eosin_sigma_range",
- "dab_sigma_range",
- ]
-
- default_range = (
- [-0.1, 0.1]
- if augmentation_type == "hed_transform"
- else (
- [-0.03, 0.03]
- if augmentation_type == "hed_transform_light"
- else [-0.95, 0.95]
- )
- )
-
- for key in ranges:
- params["data_augmentation"]["hed_transform"] = initialize_key(
- params["data_augmentation"]["hed_transform"],
- key,
- default_range,
- )
-
- params["data_augmentation"]["hed_transform"] = initialize_key(
- params["data_augmentation"]["hed_transform"],
- "cutoff_range",
- [0, 1],
- )
-
- # special case for anisotropic
- if "anisotropic" in params["data_augmentation"]:
- if not ("downsampling" in params["data_augmentation"]["anisotropic"]):
- default_downsampling = 1.5
- else:
- default_downsampling = params["data_augmentation"]["anisotropic"][
- "downsampling"
- ]
-
- initialize_downsampling = False
- if isinstance(default_downsampling, list):
- if len(default_downsampling) != 2:
- initialize_downsampling = True
- print(
- "WARNING: 'anisotropic' augmentation needs to be either a single number of a list of 2 numbers: https://torchio.readthedocs.io/transforms/augmentation.html?highlight=randomswap#torchio.transforms.RandomAnisotropy.",
- file=sys.stderr,
- )
- default_downsampling = default_downsampling[0] # only
- else:
- initialize_downsampling = True
-
- if initialize_downsampling:
- if default_downsampling < 1:
- print(
- "WARNING: 'anisotropic' augmentation needs the 'downsampling' parameter to be greater than 1, defaulting to 1.5.",
- file=sys.stderr,
- )
- # default
- params["data_augmentation"]["anisotropic"]["downsampling"] = 1.5
-
- for key in params["data_augmentation"]:
- if key != "default_probability":
- params["data_augmentation"][key] = initialize_key(
- params["data_augmentation"][key],
- "probability",
- params["data_augmentation"]["default_probability"],
- )
-
- # this is NOT a required parameter - a user should be able to train with NO built-in pre-processing
- params = initialize_key(params, "data_preprocessing", {})
- if not (params["data_preprocessing"] is None):
- # perform this only when pre-processing is defined
- if len(params["data_preprocessing"]) > 0:
- thresholdOrClip = False
- # this can be extended, as required
- thresholdOrClipDict = ["threshold", "clip", "clamp"]
-
- resize_requested = False
- temp_dict = deepcopy(params["data_preprocessing"])
- for key in params["data_preprocessing"]:
- if key in ["resize", "resize_image", "resize_images", "resize_patch"]:
- resize_requested = True
-
- if key in ["resample_min", "resample_minimum"]:
- if "resolution" in params["data_preprocessing"][key]:
- resize_requested = True
- resolution_temp = np.array(
- params["data_preprocessing"][key]["resolution"]
- )
- if resolution_temp.size == 1:
- temp_dict[key]["resolution"] = np.array(
- [resolution_temp, resolution_temp]
- ).tolist()
- else:
- temp_dict.pop(key)
-
- params["data_preprocessing"] = temp_dict
-
- if resize_requested and "resample" in params["data_preprocessing"]:
- for key in ["resize", "resize_image", "resize_images", "resize_patch"]:
- if key in params["data_preprocessing"]:
- params["data_preprocessing"].pop(key)
-
- print(
- "WARNING: Different 'resize' operations are ignored as 'resample' is defined under 'data_processing'",
- file=sys.stderr,
- )
-
- # iterate through all keys
- for key in params["data_preprocessing"]: # iterate through all keys
- if key in thresholdOrClipDict:
- # we only allow one of threshold or clip to occur and not both
- assert not (
- thresholdOrClip
- ), "Use only `threshold` or `clip`, not both"
- thresholdOrClip = True
- # initialize if nothing is present
- if not (isinstance(params["data_preprocessing"][key], dict)):
- params["data_preprocessing"][key] = {}
-
- # if one of the required parameters is not present, initialize with lowest/highest possible values
- # this ensures the absence of a field doesn't affect processing
- # for threshold or clip, ensure min and max are defined
- if not "min" in params["data_preprocessing"][key]:
- params["data_preprocessing"][key]["min"] = sys.float_info.min
- if not "max" in params["data_preprocessing"][key]:
- params["data_preprocessing"][key]["max"] = sys.float_info.max
-
- if key == "histogram_matching":
- if params["data_preprocessing"][key] is not False:
- if not (isinstance(params["data_preprocessing"][key], dict)):
- params["data_preprocessing"][key] = {}
-
- if key == "histogram_equalization":
- if params["data_preprocessing"][key] is not False:
- # if histogram equalization is enabled, call histogram_matching
- params["data_preprocessing"]["histogram_matching"] = {}
-
- if key == "adaptive_histogram_equalization":
- if params["data_preprocessing"][key] is not False:
- # if histogram equalization is enabled, call histogram_matching
- params["data_preprocessing"]["histogram_matching"] = {
- "target": "adaptive"
- }
-
- # this is NOT a required parameter - a user should be able to train with NO built-in post-processing
- params = initialize_key(params, "data_postprocessing", {})
- params = initialize_key(
- params, "data_postprocessing_after_reverse_one_hot_encoding", {}
- )
- temp_dict = deepcopy(params["data_postprocessing"])
- for key in temp_dict:
- if key in postprocessing_after_reverse_one_hot_encoding:
- params["data_postprocessing_after_reverse_one_hot_encoding"][key] = params[
- "data_postprocessing"
- ][key]
- params["data_postprocessing"].pop(key)
-
- if "model" in params:
- assert isinstance(
- params["model"], dict
- ), "The 'model' parameter needs to be populated as a dictionary"
- assert (
- len(params["model"]) > 0
- ), "The 'model' parameter needs to be populated as a dictionary and should have all properties present"
- assert (
- "architecture" in params["model"]
- ), "The 'model' parameter needs 'architecture' to be defined"
- assert (
- "final_layer" in params["model"]
- ), "The 'model' parameter needs 'final_layer' to be defined"
- assert (
- "dimension" in params["model"]
- ), "The 'model' parameter needs 'dimension' to be defined"
-
- if "amp" in params["model"]:
- pass
- else:
- print("NOT using Mixed Precision Training")
- params["model"]["amp"] = False
-
- if "norm_type" in params["model"]:
- if (
- params["model"]["norm_type"] == None
- or params["model"]["norm_type"].lower() == "none"
- ):
- if not ("vgg" in params["model"]["architecture"]):
- raise ValueError(
- "Normalization type cannot be 'None' for non-VGG architectures"
- )
- else:
- print("WARNING: Initializing 'norm_type' as 'batch'", flush=True)
- params["model"]["norm_type"] = "batch"
-
- if not ("base_filters" in params["model"]):
- base_filters = 32
- params["model"]["base_filters"] = base_filters
- print("Using default 'base_filters' in 'model': ", base_filters)
- if not ("class_list" in params["model"]):
- params["model"]["class_list"] = [] # ensure that this is initialized
- if not ("ignore_label_validation" in params["model"]):
- params["model"]["ignore_label_validation"] = None
- if "batch_norm" in params["model"]:
- print(
- "WARNING: 'batch_norm' is no longer supported, please use 'norm_type' in 'model' instead",
- flush=True,
- )
- params["model"]["print_summary"] = params["model"].get("print_summary", True)
-
- channel_keys_to_check = ["n_channels", "channels", "model_channels"]
- for key in channel_keys_to_check:
- if key in params["model"]:
- params["model"]["num_channels"] = params["model"][key]
- break
-
- # initialize model type for processing: if not defined, default to torch
- if not ("type" in params["model"]):
- params["model"]["type"] = "torch"
-
- # initialize openvino model data type for processing: if not defined, default to FP32
- if not ("data_type" in params["model"]):
- params["model"]["data_type"] = "FP32"
-
- # set default save strategy for model
- if not ("save_at_every_epoch" in params["model"]):
- params["model"]["save_at_every_epoch"] = False
-
- if params["model"]["save_at_every_epoch"]:
- print(
- "WARNING: 'save_at_every_epoch' will result in TREMENDOUS storage usage; use at your own risk."
- )
-
- if isinstance(params["model"]["class_list"], str):
- if ("||" in params["model"]["class_list"]) or (
- "&&" in params["model"]["class_list"]
- ):
- # special case for multi-class computation - this needs to be handled during one-hot encoding mask construction
- print(
- "WARNING: This is a special case for multi-class computation, where different labels are processed together, `reverse_one_hot` will need mapping information to work correctly"
- )
- temp_classList = params["model"]["class_list"]
- # we don't need the brackets
- temp_classList = temp_classList.replace("[", "")
- temp_classList = temp_classList.replace("]", "")
- params["model"]["class_list"] = temp_classList.split(",")
- else:
- try:
- params["model"]["class_list"] = eval(params["model"]["class_list"])
- except Exception as e:
- ## todo: ensure logging captures assertion errors
- assert (
- False
- ), f"Could not evaluate the `class_list` in `model`, Exception: {str(e)}, {traceback.format_exc()}"
- # logging.error(
- # f"Could not evaluate the `class_list` in `model`, Exception: {str(e)}, {traceback.format_exc()}"
- # )
-
- assert (
- "nested_training" in params
- ), "The parameter 'nested_training' needs to be defined"
- # initialize defaults for nested training
- params["nested_training"]["stratified"] = params["nested_training"].get(
- "stratified", False
- )
- params["nested_training"]["stratified"] = params["nested_training"].get(
- "proportional", params["nested_training"]["stratified"]
- )
- params["nested_training"]["testing"] = params["nested_training"].get("testing", -5)
- params["nested_training"]["validation"] = params["nested_training"].get(
- "validation", -5
- )
-
- parallel_compute_command = ""
- if "parallel_compute_command" in params:
- parallel_compute_command = params["parallel_compute_command"]
- parallel_compute_command = parallel_compute_command.replace("'", "")
- parallel_compute_command = parallel_compute_command.replace('"', "")
- params["parallel_compute_command"] = parallel_compute_command
-
- if "opt" in params:
- print("DeprecationWarning: 'opt' has been superseded by 'optimizer'")
- params["optimizer"] = params["opt"]
-
- # initialize defaults for patch sampler
- temp_patch_sampler_dict = {
- "type": "uniform",
- "enable_padding": False,
- "padding_mode": "symmetric",
- "biased_sampling": False,
- }
- # check if patch_sampler is defined in the config
- if "patch_sampler" in params:
- # if "patch_sampler" is a string, then it is the type of sampler
- if isinstance(params["patch_sampler"], str):
- print(
- "WARNING: Defining 'patch_sampler' as a string will be deprecated in a future release, please use a dictionary instead"
- )
- temp_patch_sampler_dict["type"] = params["patch_sampler"].lower()
- elif isinstance(params["patch_sampler"], dict):
- # dict requires special handling
- for key in params["patch_sampler"]:
- temp_patch_sampler_dict[key] = params["patch_sampler"][key]
-
- # now assign the dict back to the params
- params["patch_sampler"] = temp_patch_sampler_dict
- del temp_patch_sampler_dict
-
- # define defaults
- for current_parameter in parameter_defaults:
- params = initialize_parameter(
- params, current_parameter, parameter_defaults[current_parameter], True
- )
-
- for current_parameter in parameter_defaults_string:
- params = initialize_parameter(
- params,
- current_parameter,
- parameter_defaults_string[current_parameter],
- False,
- )
-
- # ensure that the scheduler and optimizer are dicts
- if isinstance(params["scheduler"], str):
- temp_dict = {}
- temp_dict["type"] = params["scheduler"]
- params["scheduler"] = temp_dict
-
- if not ("step_size" in params["scheduler"]):
- params["scheduler"]["step_size"] = params["learning_rate"] / 5.0
- print(
- "WARNING: Setting default step_size to:", params["scheduler"]["step_size"]
- )
-
- # initialize default optimizer
- params["optimizer"] = params.get("optimizer", {})
- if isinstance(params["optimizer"], str):
- temp_dict = {}
- temp_dict["type"] = params["optimizer"]
- params["optimizer"] = temp_dict
-
- # initialize defaults for DP
- if params.get("differential_privacy"):
- params = parse_opacus_params(params, initialize_key)
-
- # initialize defaults for inference mechanism
- inference_mechanism = {"grid_aggregator_overlap": "crop", "patch_overlap": 0}
- initialize_inference_mechanism = False
- if not ("inference_mechanism" in params):
- initialize_inference_mechanism = True
- elif not (isinstance(params["inference_mechanism"], dict)):
- initialize_inference_mechanism = True
- else:
- for key in inference_mechanism:
- if not (key in params["inference_mechanism"]):
- params["inference_mechanism"][key] = inference_mechanism[key]
+ try:
+ if not isinstance(config_file_path, dict):
+ params = yaml.safe_load(open(config_file_path, "r"))
+ except yaml.YAMLError as e:
+ # this is a special case for config files with panoptica parameters
+ from panoptica.utils.config import _load_yaml

- if initialize_inference_mechanism:
- params["inference_mechanism"] = inference_mechanism
+ params = _load_yaml(config_file_path)

  return params


  def ConfigManager(
  config_file_path: Union[str, dict], version_check_flag: bool = True
- ) -> None:
+ ) -> dict:
  """
  This function parses the configuration file and returns a dictionary of parameters.

@@ -747,12 +48,27 @@ def ConfigManager(
      dict: The parameter dictionary.
  """
  try:
- return _parseConfig(config_file_path, version_check_flag)
+ parameters_config = Parameters(
+ **_parseConfig(config_file_path, version_check_flag)
+ )
+ parameters = parameters_config.model_dump(
+ exclude={
+ field
+ for field in exclude_parameters
+ if getattr(parameters_config, field) is None
+ }
+ )
+ return parameters
+
  except Exception as e:
+ if isinstance(e, ValidationError):
+ handle_configuration_errors(e)
+ raise
  ## todo: ensure logging captures assertion errors
- assert (
- False
- ), f"Config parsing failed: {config_file_path=}, {version_check_flag=}, Exception: {str(e)}, {traceback.format_exc()}"
+ else:
+ assert (
+ False
+ ), f"Config parsing failed: {config_file_path=}, {version_check_flag=}, Exception: {str(e)}, {traceback.format_exc()}"
  # logging.error(
  # f"gandlf config parsing failed: {config_file_path=}, {version_check_flag=}, Exception: {str(e)}, {traceback.format_exc()}"
  # )
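
To make the reworked control flow shown above concrete, here is a hedged usage sketch of the new ConfigManager: the parsed config is validated through the pydantic Parameters model, fields listed in exclude_parameters that remain None are dropped from the model_dump, and ValidationErrors are reported via handle_configuration_errors before being re-raised. The configuration keys below are illustrative only and are not guaranteed to form a valid GaNDLF config.

# Illustrative only: toy configuration values, not a complete or validated GaNDLF config.
from GANDLF.config_manager import ConfigManager

toy_config = {
    "model": {"architecture": "unet", "dimension": 3, "final_layer": "softmax"},
    "modality": "rad",
    "loss_function": "dc",
    "metrics": ["dice"],
    "nested_training": {"testing": -5, "validation": -5},
    "patch_size": [64, 64, 64],
    "version": {"minimum": "0.1.6", "maximum": "0.1.6"},
}

try:
    # ConfigManager builds the pydantic Parameters model from the parsed config
    # and returns its model_dump() as a plain dictionary.
    params = ConfigManager(toy_config, version_check_flag=False)
    print(sorted(params)[:10])
except Exception as err:
    # pydantic ValidationErrors are routed through handle_configuration_errors and re-raised
    print(f"configuration rejected: {err}")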