GANDLF 0.1.3.dev20250202__py3-none-any.whl → 0.1.6.dev20251109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of GANDLF might be problematic.
- GANDLF/cli/deploy.py +2 -2
- GANDLF/cli/generate_metrics.py +35 -1
- GANDLF/cli/main_run.py +4 -10
- GANDLF/compute/__init__.py +0 -2
- GANDLF/compute/forward_pass.py +0 -1
- GANDLF/compute/generic.py +107 -2
- GANDLF/compute/inference_loop.py +4 -4
- GANDLF/compute/loss_and_metric.py +1 -2
- GANDLF/compute/training_loop.py +10 -10
- GANDLF/config_manager.py +33 -717
- GANDLF/configuration/__init__.py +0 -0
- GANDLF/configuration/default_config.py +73 -0
- GANDLF/configuration/differential_privacy_config.py +16 -0
- GANDLF/configuration/exclude_parameters.py +1 -0
- GANDLF/configuration/model_config.py +82 -0
- GANDLF/configuration/nested_training_config.py +25 -0
- GANDLF/configuration/optimizer_config.py +121 -0
- GANDLF/configuration/parameters_config.py +10 -0
- GANDLF/configuration/patch_sampler_config.py +11 -0
- GANDLF/configuration/post_processing_config.py +10 -0
- GANDLF/configuration/pre_processing_config.py +94 -0
- GANDLF/configuration/scheduler_config.py +92 -0
- GANDLF/configuration/user_defined_config.py +131 -0
- GANDLF/configuration/utils.py +96 -0
- GANDLF/configuration/validators.py +479 -0
- GANDLF/data/__init__.py +14 -16
- GANDLF/data/lightning_datamodule.py +119 -0
- GANDLF/entrypoints/run.py +36 -31
- GANDLF/inference_manager.py +69 -25
- GANDLF/losses/__init__.py +23 -1
- GANDLF/losses/loss_calculators.py +79 -0
- GANDLF/losses/segmentation.py +3 -2
- GANDLF/metrics/__init__.py +26 -0
- GANDLF/metrics/generic.py +1 -1
- GANDLF/metrics/metric_calculators.py +102 -0
- GANDLF/metrics/panoptica_config_brats.yaml +56 -0
- GANDLF/metrics/segmentation_panoptica.py +49 -0
- GANDLF/models/__init__.py +8 -3
- GANDLF/models/lightning_module.py +2102 -0
- GANDLF/optimizers/__init__.py +4 -8
- GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
- GANDLF/schedulers/__init__.py +11 -4
- GANDLF/schedulers/wrap_torch.py +15 -3
- GANDLF/training_manager.py +160 -50
- GANDLF/utils/__init__.py +5 -3
- GANDLF/utils/imaging.py +176 -35
- GANDLF/utils/modelio.py +12 -8
- GANDLF/utils/pred_target_processors.py +71 -0
- GANDLF/utils/tensor.py +2 -1
- GANDLF/utils/write_parse.py +1 -1
- GANDLF/version.py +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/METADATA +16 -11
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/RECORD +57 -34
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/WHEEL +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/entry_points.txt +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info/licenses}/LICENSE +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/top_level.txt +0 -0
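Most of the churn comes from moving configuration handling out of config_manager.py (over 700 lines removed) into the new GANDLF/configuration package of Pydantic models and validators, alongside a new Lightning-based training path (models/lightning_module.py, data/lightning_datamodule.py). As a rough, hypothetical sketch of the Pydantic pattern these new modules appear to adopt (the class and field names below are illustrative only, not the actual GANDLF definitions):

    # Hypothetical illustration of a Pydantic-style config model; ExampleModelConfig
    # and ExampleParameters are made-up names, not GANDLF's actual classes.
    from typing import List, Optional
    from pydantic import BaseModel, Field, ValidationError


    class ExampleModelConfig(BaseModel):
        architecture: str
        dimension: int = Field(ge=2, le=3)  # validated instead of assert-checked
        final_layer: str
        base_filters: int = 32  # defaults replace manual initialize_key() calls


    class ExampleParameters(BaseModel):
        model: ExampleModelConfig
        modality: str = "rad"
        patch_size: Optional[List[int]] = None


    try:
        cfg = ExampleParameters(
            model={"architecture": "unet", "dimension": 3, "final_layer": "softmax"}
        )
        print(cfg.model_dump(exclude_none=True))
    except ValidationError as err:
        print(err)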
GANDLF/config_manager.py
CHANGED
@@ -1,113 +1,11 @@
-# import logging
 import traceback
-from typing import
-import
-
-from copy import deepcopy
+from typing import Union
+import yaml
+from pydantic import ValidationError
 
-from .
-from GANDLF.
-from GANDLF.
-
-from GANDLF.metrics import surface_distance_ids
-from importlib.metadata import version
-
-## dictionary to define defaults for appropriate options, which are evaluated
-parameter_defaults = {
-    "weighted_loss": False,  # whether weighted loss is to be used or not
-    "verbose": False,  # general application verbosity
-    "q_verbose": False,  # queue construction verbosity
-    "medcam_enabled": False,  # interpretability via medcam
-    "save_training": False,  # save outputs during training
-    "save_output": False,  # save outputs during validation/testing
-    "in_memory": False,  # pin data to cpu memory
-    "pin_memory_dataloader": False,  # pin data to gpu memory
-    "scaling_factor": 1,  # scaling factor for regression problems
-    "q_max_length": 100,  # the max length of queue
-    "q_samples_per_volume": 10,  # number of samples per volume
-    "q_num_workers": 4,  # number of worker threads to use
-    "num_epochs": 100,  # total number of epochs to train
-    "patience": 100,  # number of epochs to wait for performance improvement
-    "batch_size": 1,  # default batch size of training
-    "learning_rate": 0.001,  # default learning rate
-    "clip_grad": None,  # clip_gradient value
-    "track_memory_usage": False,  # default memory tracking
-    "memory_save_mode": False,  # default memory saving, if enabled, resize/resample will save files to disk
-    "print_rgb_label_warning": True,  # print rgb label warning
-    "data_postprocessing": {},  # default data postprocessing
-    "grid_aggregator_overlap": "crop",  # default grid aggregator overlap strategy
-    "determinism": False,  # using deterministic version of computation
-    "previous_parameters": None,  # previous parameters to be used for resuming training and perform sanity checking
-}
-
-## dictionary to define string defaults for appropriate options
-parameter_defaults_string = {
-    "optimizer": "adam",  # the optimizer
-    "scheduler": "triangle_modified",  # the default scheduler
-    "clip_mode": None,  # default clip mode
-}
-
-
-def initialize_parameter(
-    params: dict,
-    parameter_to_initialize: str,
-    value: Optional[Union[str, list, int, dict]] = None,
-    evaluate: Optional[bool] = True,
-) -> dict:
-    """
-    This function will initialize the parameter in the parameters dict to the value if it is absent.
-
-    Args:
-        params (dict): The parameter dictionary.
-        parameter_to_initialize (str): The parameter to initialize.
-        value (Optional[Union[str, list, int, dict]], optional): The value to initialize. Defaults to None.
-        evaluate (Optional[bool], optional): Whether to evaluate the value. Defaults to True.
-
-    Returns:
-        dict: The parameter dictionary.
-    """
-    if parameter_to_initialize in params:
-        if evaluate:
-            if isinstance(params[parameter_to_initialize], str):
-                if params[parameter_to_initialize].lower() == "none":
-                    params[parameter_to_initialize] = ast.literal_eval(
-                        params[parameter_to_initialize]
-                    )
-    else:
-        print(
-            "WARNING: Initializing '" + parameter_to_initialize + "' as " + str(value)
-        )
-        params[parameter_to_initialize] = value
-
-    return params
-
-
-def initialize_key(
-    parameters: dict, key: str, value: Optional[Union[str, float, list, dict]] = None
-) -> dict:
-    """
-    This function initializes a key in the parameters dictionary to a value if it is absent.
-
-    Args:
-        parameters (dict): The parameter dictionary.
-        key (str): The key to initialize.
-        value (Optional[Union[str, float, list, dict]], optional): The value to initialize. Defaults to None.
-
-    Returns:
-        dict: The parameter dictionary.
-    """
-    if parameters is None:
-        parameters = {}
-    if key in parameters:
-        if parameters[key] is not None:
-            if isinstance(parameters[key], dict):
-                # if key is present but not defined
-                if len(parameters[key]) == 0:
-                    parameters[key] = value
-    else:
-        parameters[key] = value  # if key is absent
-
-    return parameters
+from GANDLF.configuration.parameters_config import Parameters
+from GANDLF.configuration.exclude_parameters import exclude_parameters
+from GANDLF.configuration.utils import handle_configuration_errors
 
 
 def _parseConfig(
@@ -124,618 +22,21 @@ def _parseConfig(
         dict: The parameter dictionary.
     """
     params = config_file_path
-
-
-
-
-
-
-        ), "The 'version' key needs to be defined in config with 'minimum' and 'maximum' fields to determine the compatibility of configuration with code base"
-        version_check(params["version"], version_to_check=version("GANDLF"))
-
-    if "patch_size" in params:
-        # duplicate patch size if it is an int or float
-        if isinstance(params["patch_size"], int) or isinstance(
-            params["patch_size"], float
-        ):
-            params["patch_size"] = [params["patch_size"]]
-        # in case someone decides to pass a single value list
-        if len(params["patch_size"]) == 1:
-            actual_patch_size = []
-            for _ in range(params["model"]["dimension"]):
-                actual_patch_size.append(params["patch_size"][0])
-            params["patch_size"] = actual_patch_size
-
-        # parse patch size as needed for computations
-        if len(params["patch_size"]) == 2:  # 2d check
-            # ensuring same size during torchio processing
-            params["patch_size"].append(1)
-            if "dimension" not in params["model"]:
-                params["model"]["dimension"] = 2
-        elif len(params["patch_size"]) == 3:  # 2d check
-            if "dimension" not in params["model"]:
-                params["model"]["dimension"] = 3
-    assert "patch_size" in params, "Patch size needs to be defined in the config file"
-
-    if "resize" in params:
-        print(
-            "WARNING: 'resize' should be defined under 'data_processing', this will be skipped",
-            file=sys.stderr,
-        )
-
-    assert "modality" in params, "'modality' needs to be defined in the config file"
-    params["modality"] = params["modality"].lower()
-    assert params["modality"] in [
-        "rad",
-        "histo",
-        "path",
-    ], "Modality should be either 'rad' or 'path'"
-
-    assert (
-        "loss_function" in params
-    ), "'loss_function' needs to be defined in the config file"
-    if "loss_function" in params:
-        # check if user has passed a dict
-        if isinstance(params["loss_function"], dict):  # if this is a dict
-            if len(params["loss_function"]) > 0:  # only proceed if something is defined
-                for key in params["loss_function"]:  # iterate through all keys
-                    if key == "mse":
-                        if (params["loss_function"][key] is None) or not (
-                            "reduction" in params["loss_function"][key]
-                        ):
-                            params["loss_function"][key] = {}
-                            params["loss_function"][key]["reduction"] = "mean"
-                    else:
-                        # use simple string for other functions - can be extended with parameters, if needed
-                        params["loss_function"] = key
-        else:
-            # check if user has passed a single string
-            if params["loss_function"] == "mse":
-                params["loss_function"] = {}
-                params["loss_function"]["mse"] = {}
-                params["loss_function"]["mse"]["reduction"] = "mean"
-            elif params["loss_function"] == "focal":
-                params["loss_function"] = {}
-                params["loss_function"]["focal"] = {}
-                params["loss_function"]["focal"]["gamma"] = 2.0
-                params["loss_function"]["focal"]["size_average"] = True
-
-    assert "metrics" in params, "'metrics' needs to be defined in the config file"
-    if "metrics" in params:
-        if not isinstance(params["metrics"], dict):
-            temp_dict = {}
-        else:
-            temp_dict = params["metrics"]
-
-        # initialize metrics dict
-        for metric in params["metrics"]:
-            # assigning a new variable because some metrics can be dicts, and we want to get the first key
-            comparison_string = metric
-            if isinstance(metric, dict):
-                comparison_string = list(metric.keys())[0]
-            # these metrics always need to be dicts
-            if comparison_string in [
-                "accuracy",
-                "f1",
-                "precision",
-                "recall",
-                "specificity",
-                "iou",
-            ]:
-                if not isinstance(metric, dict):
-                    temp_dict[metric] = {}
-                else:
-                    temp_dict[comparison_string] = metric
-            elif not isinstance(metric, dict):
-                temp_dict[metric] = None
-
-            # special case for accuracy, precision, recall, and specificity; which could be dicts
-            ## need to find a better way to do this
-            if any(
-                _ in comparison_string
-                for _ in ["precision", "recall", "specificity", "accuracy", "f1"]
-            ):
-                if comparison_string != "classification_accuracy":
-                    temp_dict[comparison_string] = initialize_key(
-                        temp_dict[comparison_string], "average", "weighted"
-                    )
-                    temp_dict[comparison_string] = initialize_key(
-                        temp_dict[comparison_string], "multi_class", True
-                    )
-                    temp_dict[comparison_string] = initialize_key(
-                        temp_dict[comparison_string], "mdmc_average", "samplewise"
-                    )
-                    temp_dict[comparison_string] = initialize_key(
-                        temp_dict[comparison_string], "threshold", 0.5
-                    )
-                    if comparison_string == "accuracy":
-                        temp_dict[comparison_string] = initialize_key(
-                            temp_dict[comparison_string], "subset_accuracy", False
-                        )
-            elif "iou" in comparison_string:
-                temp_dict["iou"] = initialize_key(
-                    temp_dict["iou"], "reduction", "elementwise_mean"
-                )
-                temp_dict["iou"] = initialize_key(temp_dict["iou"], "threshold", 0.5)
-            elif comparison_string in surface_distance_ids:
-                temp_dict[comparison_string] = initialize_key(
-                    temp_dict[comparison_string], "connectivity", 1
-                )
-                temp_dict[comparison_string] = initialize_key(
-                    temp_dict[comparison_string], "threshold", None
-                )
-
-        params["metrics"] = temp_dict
-
-    # this is NOT a required parameter - a user should be able to train with NO augmentations
-    params = initialize_key(params, "data_augmentation", {})
-    # for all others, ensure probability is present
-    params["data_augmentation"]["default_probability"] = params[
-        "data_augmentation"
-    ].get("default_probability", 0.5)
-
-    if not (params["data_augmentation"] is None):
-        if len(params["data_augmentation"]) > 0:  # only when augmentations are defined
-            # special case for random swapping and elastic transformations - which takes a patch size for computation
-            for key in ["swap", "elastic"]:
-                if key in params["data_augmentation"]:
-                    params["data_augmentation"][key] = initialize_key(
-                        params["data_augmentation"][key],
-                        "patch_size",
-                        np.round(np.array(params["patch_size"]) / 10)
-                        .astype("int")
-                        .tolist(),
-                    )
-
-            # special case for swap default initialization
-            if "swap" in params["data_augmentation"]:
-                params["data_augmentation"]["swap"] = initialize_key(
-                    params["data_augmentation"]["swap"], "num_iterations", 100
-                )
-
-            # special case for affine default initialization
-            if "affine" in params["data_augmentation"]:
-                params["data_augmentation"]["affine"] = initialize_key(
-                    params["data_augmentation"]["affine"], "scales", 0.1
-                )
-                params["data_augmentation"]["affine"] = initialize_key(
-                    params["data_augmentation"]["affine"], "degrees", 15
-                )
-                params["data_augmentation"]["affine"] = initialize_key(
-                    params["data_augmentation"]["affine"], "translation", 2
-                )
-
-            if "motion" in params["data_augmentation"]:
-                params["data_augmentation"]["motion"] = initialize_key(
-                    params["data_augmentation"]["motion"], "num_transforms", 2
-                )
-                params["data_augmentation"]["motion"] = initialize_key(
-                    params["data_augmentation"]["motion"], "degrees", 15
-                )
-                params["data_augmentation"]["motion"] = initialize_key(
-                    params["data_augmentation"]["motion"], "translation", 2
-                )
-                params["data_augmentation"]["motion"] = initialize_key(
-                    params["data_augmentation"]["motion"], "interpolation", "linear"
-                )
-
-            # special case for random blur/noise - which takes a std-dev range
-            for std_aug in ["blur", "noise_var"]:
-                if std_aug in params["data_augmentation"]:
-                    params["data_augmentation"][std_aug] = initialize_key(
-                        params["data_augmentation"][std_aug], "std", None
-                    )
-            for std_aug in ["noise"]:
-                if std_aug in params["data_augmentation"]:
-                    params["data_augmentation"][std_aug] = initialize_key(
-                        params["data_augmentation"][std_aug], "std", [0, 1]
-                    )
-
-            # special case for random noise - which takes a mean range
-            for mean_aug in ["noise", "noise_var"]:
-                if mean_aug in params["data_augmentation"]:
-                    params["data_augmentation"][mean_aug] = initialize_key(
-                        params["data_augmentation"][mean_aug], "mean", 0
-                    )
-
-            # special case for augmentations that need axis defined
-            for axis_aug in ["flip", "anisotropic", "rotate_90", "rotate_180"]:
-                if axis_aug in params["data_augmentation"]:
-                    params["data_augmentation"][axis_aug] = initialize_key(
-                        params["data_augmentation"][axis_aug], "axis", [0, 1, 2]
-                    )
-
-            # special case for colorjitter
-            if "colorjitter" in params["data_augmentation"]:
-                params["data_augmentation"] = initialize_key(
-                    params["data_augmentation"], "colorjitter", {}
-                )
-                for key in ["brightness", "contrast", "saturation"]:
-                    params["data_augmentation"]["colorjitter"] = initialize_key(
-                        params["data_augmentation"]["colorjitter"], key, [0, 1]
-                    )
-                params["data_augmentation"]["colorjitter"] = initialize_key(
-                    params["data_augmentation"]["colorjitter"], "hue", [-0.5, 0.5]
-                )
-
-            # Added HED augmentation in gandlf
-            hed_augmentation_types = [
-                "hed_transform",
-                # "hed_transform_light",
-                # "hed_transform_heavy",
-            ]
-            for augmentation_type in hed_augmentation_types:
-                if augmentation_type in params["data_augmentation"]:
-                    params["data_augmentation"] = initialize_key(
-                        params["data_augmentation"], "hed_transform", {}
-                    )
-                    ranges = [
-                        "haematoxylin_bias_range",
-                        "eosin_bias_range",
-                        "dab_bias_range",
-                        "haematoxylin_sigma_range",
-                        "eosin_sigma_range",
-                        "dab_sigma_range",
-                    ]
-
-                    default_range = (
-                        [-0.1, 0.1]
-                        if augmentation_type == "hed_transform"
-                        else (
-                            [-0.03, 0.03]
-                            if augmentation_type == "hed_transform_light"
-                            else [-0.95, 0.95]
-                        )
-                    )
-
-                    for key in ranges:
-                        params["data_augmentation"]["hed_transform"] = initialize_key(
-                            params["data_augmentation"]["hed_transform"],
-                            key,
-                            default_range,
-                        )
-
-                    params["data_augmentation"]["hed_transform"] = initialize_key(
-                        params["data_augmentation"]["hed_transform"],
-                        "cutoff_range",
-                        [0, 1],
-                    )
-
-            # special case for anisotropic
-            if "anisotropic" in params["data_augmentation"]:
-                if not ("downsampling" in params["data_augmentation"]["anisotropic"]):
-                    default_downsampling = 1.5
-                else:
-                    default_downsampling = params["data_augmentation"]["anisotropic"][
-                        "downsampling"
-                    ]
-
-                initialize_downsampling = False
-                if isinstance(default_downsampling, list):
-                    if len(default_downsampling) != 2:
-                        initialize_downsampling = True
-                        print(
-                            "WARNING: 'anisotropic' augmentation needs to be either a single number of a list of 2 numbers: https://torchio.readthedocs.io/transforms/augmentation.html?highlight=randomswap#torchio.transforms.RandomAnisotropy.",
-                            file=sys.stderr,
-                        )
-                        default_downsampling = default_downsampling[0]  # only
-                else:
-                    initialize_downsampling = True
-
-                if initialize_downsampling:
-                    if default_downsampling < 1:
-                        print(
-                            "WARNING: 'anisotropic' augmentation needs the 'downsampling' parameter to be greater than 1, defaulting to 1.5.",
-                            file=sys.stderr,
-                        )
-                        # default
-                        params["data_augmentation"]["anisotropic"]["downsampling"] = 1.5
-
-            for key in params["data_augmentation"]:
-                if key != "default_probability":
-                    params["data_augmentation"][key] = initialize_key(
-                        params["data_augmentation"][key],
-                        "probability",
-                        params["data_augmentation"]["default_probability"],
-                    )
-
-    # this is NOT a required parameter - a user should be able to train with NO built-in pre-processing
-    params = initialize_key(params, "data_preprocessing", {})
-    if not (params["data_preprocessing"] is None):
-        # perform this only when pre-processing is defined
-        if len(params["data_preprocessing"]) > 0:
-            thresholdOrClip = False
-            # this can be extended, as required
-            thresholdOrClipDict = ["threshold", "clip", "clamp"]
-
-            resize_requested = False
-            temp_dict = deepcopy(params["data_preprocessing"])
-            for key in params["data_preprocessing"]:
-                if key in ["resize", "resize_image", "resize_images", "resize_patch"]:
-                    resize_requested = True
-
-                if key in ["resample_min", "resample_minimum"]:
-                    if "resolution" in params["data_preprocessing"][key]:
-                        resize_requested = True
-                        resolution_temp = np.array(
-                            params["data_preprocessing"][key]["resolution"]
-                        )
-                        if resolution_temp.size == 1:
-                            temp_dict[key]["resolution"] = np.array(
-                                [resolution_temp, resolution_temp]
-                            ).tolist()
-                    else:
-                        temp_dict.pop(key)
-
-            params["data_preprocessing"] = temp_dict
-
-            if resize_requested and "resample" in params["data_preprocessing"]:
-                for key in ["resize", "resize_image", "resize_images", "resize_patch"]:
-                    if key in params["data_preprocessing"]:
-                        params["data_preprocessing"].pop(key)
-
-                print(
-                    "WARNING: Different 'resize' operations are ignored as 'resample' is defined under 'data_processing'",
-                    file=sys.stderr,
-                )
-
-            # iterate through all keys
-            for key in params["data_preprocessing"]:  # iterate through all keys
-                if key in thresholdOrClipDict:
-                    # we only allow one of threshold or clip to occur and not both
-                    assert not (
-                        thresholdOrClip
-                    ), "Use only `threshold` or `clip`, not both"
-                    thresholdOrClip = True
-                    # initialize if nothing is present
-                    if not (isinstance(params["data_preprocessing"][key], dict)):
-                        params["data_preprocessing"][key] = {}
-
-                    # if one of the required parameters is not present, initialize with lowest/highest possible values
-                    # this ensures the absence of a field doesn't affect processing
-                    # for threshold or clip, ensure min and max are defined
-                    if not "min" in params["data_preprocessing"][key]:
-                        params["data_preprocessing"][key]["min"] = sys.float_info.min
-                    if not "max" in params["data_preprocessing"][key]:
-                        params["data_preprocessing"][key]["max"] = sys.float_info.max
-
-                if key == "histogram_matching":
-                    if params["data_preprocessing"][key] is not False:
-                        if not (isinstance(params["data_preprocessing"][key], dict)):
-                            params["data_preprocessing"][key] = {}
-
-                if key == "histogram_equalization":
-                    if params["data_preprocessing"][key] is not False:
-                        # if histogram equalization is enabled, call histogram_matching
-                        params["data_preprocessing"]["histogram_matching"] = {}
-
-                if key == "adaptive_histogram_equalization":
-                    if params["data_preprocessing"][key] is not False:
-                        # if histogram equalization is enabled, call histogram_matching
-                        params["data_preprocessing"]["histogram_matching"] = {
-                            "target": "adaptive"
-                        }
-
-    # this is NOT a required parameter - a user should be able to train with NO built-in post-processing
-    params = initialize_key(params, "data_postprocessing", {})
-    params = initialize_key(
-        params, "data_postprocessing_after_reverse_one_hot_encoding", {}
-    )
-    temp_dict = deepcopy(params["data_postprocessing"])
-    for key in temp_dict:
-        if key in postprocessing_after_reverse_one_hot_encoding:
-            params["data_postprocessing_after_reverse_one_hot_encoding"][key] = params[
-                "data_postprocessing"
-            ][key]
-            params["data_postprocessing"].pop(key)
-
-    if "model" in params:
-        assert isinstance(
-            params["model"], dict
-        ), "The 'model' parameter needs to be populated as a dictionary"
-        assert (
-            len(params["model"]) > 0
-        ), "The 'model' parameter needs to be populated as a dictionary and should have all properties present"
-        assert (
-            "architecture" in params["model"]
-        ), "The 'model' parameter needs 'architecture' to be defined"
-        assert (
-            "final_layer" in params["model"]
-        ), "The 'model' parameter needs 'final_layer' to be defined"
-        assert (
-            "dimension" in params["model"]
-        ), "The 'model' parameter needs 'dimension' to be defined"
-
-    if "amp" in params["model"]:
-        pass
-    else:
-        print("NOT using Mixed Precision Training")
-        params["model"]["amp"] = False
-
-    if "norm_type" in params["model"]:
-        if (
-            params["model"]["norm_type"] == None
-            or params["model"]["norm_type"].lower() == "none"
-        ):
-            if not ("vgg" in params["model"]["architecture"]):
-                raise ValueError(
-                    "Normalization type cannot be 'None' for non-VGG architectures"
-                )
-    else:
-        print("WARNING: Initializing 'norm_type' as 'batch'", flush=True)
-        params["model"]["norm_type"] = "batch"
-
-    if not ("base_filters" in params["model"]):
-        base_filters = 32
-        params["model"]["base_filters"] = base_filters
-        print("Using default 'base_filters' in 'model': ", base_filters)
-    if not ("class_list" in params["model"]):
-        params["model"]["class_list"] = []  # ensure that this is initialized
-    if not ("ignore_label_validation" in params["model"]):
-        params["model"]["ignore_label_validation"] = None
-    if "batch_norm" in params["model"]:
-        print(
-            "WARNING: 'batch_norm' is no longer supported, please use 'norm_type' in 'model' instead",
-            flush=True,
-        )
-    params["model"]["print_summary"] = params["model"].get("print_summary", True)
-
-    channel_keys_to_check = ["n_channels", "channels", "model_channels"]
-    for key in channel_keys_to_check:
-        if key in params["model"]:
-            params["model"]["num_channels"] = params["model"][key]
-            break
-
-    # initialize model type for processing: if not defined, default to torch
-    if not ("type" in params["model"]):
-        params["model"]["type"] = "torch"
-
-    # initialize openvino model data type for processing: if not defined, default to FP32
-    if not ("data_type" in params["model"]):
-        params["model"]["data_type"] = "FP32"
-
-    # set default save strategy for model
-    if not ("save_at_every_epoch" in params["model"]):
-        params["model"]["save_at_every_epoch"] = False
-
-    if params["model"]["save_at_every_epoch"]:
-        print(
-            "WARNING: 'save_at_every_epoch' will result in TREMENDOUS storage usage; use at your own risk."
-        )
-
-    if isinstance(params["model"]["class_list"], str):
-        if ("||" in params["model"]["class_list"]) or (
-            "&&" in params["model"]["class_list"]
-        ):
-            # special case for multi-class computation - this needs to be handled during one-hot encoding mask construction
-            print(
-                "WARNING: This is a special case for multi-class computation, where different labels are processed together, `reverse_one_hot` will need mapping information to work correctly"
-            )
-            temp_classList = params["model"]["class_list"]
-            # we don't need the brackets
-            temp_classList = temp_classList.replace("[", "")
-            temp_classList = temp_classList.replace("]", "")
-            params["model"]["class_list"] = temp_classList.split(",")
-        else:
-            try:
-                params["model"]["class_list"] = eval(params["model"]["class_list"])
-            except Exception as e:
-                ## todo: ensure logging captures assertion errors
-                assert (
-                    False
-                ), f"Could not evaluate the `class_list` in `model`, Exception: {str(e)}, {traceback.format_exc()}"
-                # logging.error(
-                #     f"Could not evaluate the `class_list` in `model`, Exception: {str(e)}, {traceback.format_exc()}"
-                # )
-
-    assert (
-        "nested_training" in params
-    ), "The parameter 'nested_training' needs to be defined"
-    # initialize defaults for nested training
-    params["nested_training"]["stratified"] = params["nested_training"].get(
-        "stratified", False
-    )
-    params["nested_training"]["stratified"] = params["nested_training"].get(
-        "proportional", params["nested_training"]["stratified"]
-    )
-    params["nested_training"]["testing"] = params["nested_training"].get("testing", -5)
-    params["nested_training"]["validation"] = params["nested_training"].get(
-        "validation", -5
-    )
-
-    parallel_compute_command = ""
-    if "parallel_compute_command" in params:
-        parallel_compute_command = params["parallel_compute_command"]
-        parallel_compute_command = parallel_compute_command.replace("'", "")
-        parallel_compute_command = parallel_compute_command.replace('"', "")
-    params["parallel_compute_command"] = parallel_compute_command
-
-    if "opt" in params:
-        print("DeprecationWarning: 'opt' has been superseded by 'optimizer'")
-        params["optimizer"] = params["opt"]
-
-    # initialize defaults for patch sampler
-    temp_patch_sampler_dict = {
-        "type": "uniform",
-        "enable_padding": False,
-        "padding_mode": "symmetric",
-        "biased_sampling": False,
-    }
-    # check if patch_sampler is defined in the config
-    if "patch_sampler" in params:
-        # if "patch_sampler" is a string, then it is the type of sampler
-        if isinstance(params["patch_sampler"], str):
-            print(
-                "WARNING: Defining 'patch_sampler' as a string will be deprecated in a future release, please use a dictionary instead"
-            )
-            temp_patch_sampler_dict["type"] = params["patch_sampler"].lower()
-        elif isinstance(params["patch_sampler"], dict):
-            # dict requires special handling
-            for key in params["patch_sampler"]:
-                temp_patch_sampler_dict[key] = params["patch_sampler"][key]
-
-    # now assign the dict back to the params
-    params["patch_sampler"] = temp_patch_sampler_dict
-    del temp_patch_sampler_dict
-
-    # define defaults
-    for current_parameter in parameter_defaults:
-        params = initialize_parameter(
-            params, current_parameter, parameter_defaults[current_parameter], True
-        )
-
-    for current_parameter in parameter_defaults_string:
-        params = initialize_parameter(
-            params,
-            current_parameter,
-            parameter_defaults_string[current_parameter],
-            False,
-        )
-
-    # ensure that the scheduler and optimizer are dicts
-    if isinstance(params["scheduler"], str):
-        temp_dict = {}
-        temp_dict["type"] = params["scheduler"]
-        params["scheduler"] = temp_dict
-
-    if not ("step_size" in params["scheduler"]):
-        params["scheduler"]["step_size"] = params["learning_rate"] / 5.0
-        print(
-            "WARNING: Setting default step_size to:", params["scheduler"]["step_size"]
-        )
-
-    # initialize default optimizer
-    params["optimizer"] = params.get("optimizer", {})
-    if isinstance(params["optimizer"], str):
-        temp_dict = {}
-        temp_dict["type"] = params["optimizer"]
-        params["optimizer"] = temp_dict
-
-    # initialize defaults for DP
-    if params.get("differential_privacy"):
-        params = parse_opacus_params(params, initialize_key)
-
-    # initialize defaults for inference mechanism
-    inference_mechanism = {"grid_aggregator_overlap": "crop", "patch_overlap": 0}
-    initialize_inference_mechanism = False
-    if not ("inference_mechanism" in params):
-        initialize_inference_mechanism = True
-    elif not (isinstance(params["inference_mechanism"], dict)):
-        initialize_inference_mechanism = True
-    else:
-        for key in inference_mechanism:
-            if not (key in params["inference_mechanism"]):
-                params["inference_mechanism"][key] = inference_mechanism[key]
+    try:
+        if not isinstance(config_file_path, dict):
+            params = yaml.safe_load(open(config_file_path, "r"))
+    except yaml.YAMLError as e:
+        # this is a special case for config files with panoptica parameters
+        from panoptica.utils.config import _load_yaml
 
-
-        params["inference_mechanism"] = inference_mechanism
+        params = _load_yaml(config_file_path)
 
     return params
 
 
 def ConfigManager(
     config_file_path: Union[str, dict], version_check_flag: bool = True
-) ->
+) -> dict:
     """
     This function parses the configuration file and returns a dictionary of parameters.
 
@@ -747,12 +48,27 @@ def ConfigManager(
         dict: The parameter dictionary.
     """
     try:
-
+        parameters_config = Parameters(
+            **_parseConfig(config_file_path, version_check_flag)
+        )
+        parameters = parameters_config.model_dump(
+            exclude={
+                field
+                for field in exclude_parameters
+                if getattr(parameters_config, field) is None
+            }
+        )
+        return parameters
+
     except Exception as e:
+        if isinstance(e, ValidationError):
+            handle_configuration_errors(e)
+            raise
         ## todo: ensure logging captures assertion errors
-
-
-
+        else:
+            assert (
+                False
+            ), f"Config parsing failed: {config_file_path=}, {version_check_flag=}, Exception: {str(e)}, {traceback.format_exc()}"
         # logging.error(
         #     f"gandlf config parsing failed: {config_file_path=}, {version_check_flag=}, Exception: {str(e)}, {traceback.format_exc()}"
         # )