GANDLF 0.1.3.dev20250202__py3-none-any.whl → 0.1.6.dev20251109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of GANDLF might be problematic. Click here for more details.
- GANDLF/cli/deploy.py +2 -2
- GANDLF/cli/generate_metrics.py +35 -1
- GANDLF/cli/main_run.py +4 -10
- GANDLF/compute/__init__.py +0 -2
- GANDLF/compute/forward_pass.py +0 -1
- GANDLF/compute/generic.py +107 -2
- GANDLF/compute/inference_loop.py +4 -4
- GANDLF/compute/loss_and_metric.py +1 -2
- GANDLF/compute/training_loop.py +10 -10
- GANDLF/config_manager.py +33 -717
- GANDLF/configuration/__init__.py +0 -0
- GANDLF/configuration/default_config.py +73 -0
- GANDLF/configuration/differential_privacy_config.py +16 -0
- GANDLF/configuration/exclude_parameters.py +1 -0
- GANDLF/configuration/model_config.py +82 -0
- GANDLF/configuration/nested_training_config.py +25 -0
- GANDLF/configuration/optimizer_config.py +121 -0
- GANDLF/configuration/parameters_config.py +10 -0
- GANDLF/configuration/patch_sampler_config.py +11 -0
- GANDLF/configuration/post_processing_config.py +10 -0
- GANDLF/configuration/pre_processing_config.py +94 -0
- GANDLF/configuration/scheduler_config.py +92 -0
- GANDLF/configuration/user_defined_config.py +131 -0
- GANDLF/configuration/utils.py +96 -0
- GANDLF/configuration/validators.py +479 -0
- GANDLF/data/__init__.py +14 -16
- GANDLF/data/lightning_datamodule.py +119 -0
- GANDLF/entrypoints/run.py +36 -31
- GANDLF/inference_manager.py +69 -25
- GANDLF/losses/__init__.py +23 -1
- GANDLF/losses/loss_calculators.py +79 -0
- GANDLF/losses/segmentation.py +3 -2
- GANDLF/metrics/__init__.py +26 -0
- GANDLF/metrics/generic.py +1 -1
- GANDLF/metrics/metric_calculators.py +102 -0
- GANDLF/metrics/panoptica_config_brats.yaml +56 -0
- GANDLF/metrics/segmentation_panoptica.py +49 -0
- GANDLF/models/__init__.py +8 -3
- GANDLF/models/lightning_module.py +2102 -0
- GANDLF/optimizers/__init__.py +4 -8
- GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
- GANDLF/schedulers/__init__.py +11 -4
- GANDLF/schedulers/wrap_torch.py +15 -3
- GANDLF/training_manager.py +160 -50
- GANDLF/utils/__init__.py +5 -3
- GANDLF/utils/imaging.py +176 -35
- GANDLF/utils/modelio.py +12 -8
- GANDLF/utils/pred_target_processors.py +71 -0
- GANDLF/utils/tensor.py +2 -1
- GANDLF/utils/write_parse.py +1 -1
- GANDLF/version.py +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/METADATA +16 -11
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/RECORD +57 -34
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/WHEEL +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/entry_points.txt +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info/licenses}/LICENSE +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Optional, Union
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
from typing import Type
|
|
6
|
+
from pydantic import BaseModel, ValidationError, create_model
|
|
7
|
+
from pydantic_core import ErrorDetails
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def initialize_key(
    parameters: dict, key: str, value: Optional[Union[str, float, list, dict]] = None
) -> dict:
    """
    Set ``parameters[key]`` to ``value`` when the key is missing or maps to an empty dict.

    Args:
        parameters (dict): The parameter dictionary; ``None`` is replaced by a new, empty dict.
        key (str): The key to initialize.
        value (Optional[Union[str, float, list, dict]], optional): The value to store. Defaults to None.

    Returns:
        dict: The (possibly newly created) parameter dictionary.
    """
    if parameters is None:
        parameters = {}

    if key not in parameters:
        # key is absent
        parameters[key] = value
        return parameters

    existing = parameters[key]
    # key is present but not defined (an empty dict); every other existing
    # entry - including an explicit None - is left untouched
    if isinstance(existing, dict) and len(existing) == 0:
        parameters[key] = value

    return parameters
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Custom error messages, keyed by the pydantic error type they replace.
CUSTOM_MESSAGES = {
    "literal_error": "The input must be a valid option, please read the documentation",
    "missing": "This parameter is required. Please define it",
}


def convert_errors(e: ValidationError, custom_messages=None) -> list[ErrorDetails]:
    """
    Replace the messages of known error types with friendlier custom text.

    Args:
        e (ValidationError): The validation error whose ``errors()`` are converted.
        custom_messages (dict, optional): Mapping of error type to replacement message.
            Defaults to None, in which case ``CUSTOM_MESSAGES`` is used.

    Returns:
        list[ErrorDetails]: All error entries, in their original order; entries whose
            type has a replacement get their ``"msg"`` overwritten in place.
    """
    replacements = CUSTOM_MESSAGES if custom_messages is None else custom_messages

    converted: list[ErrorDetails] = []
    for details in e.errors():
        replacement = replacements.get(details["type"])
        if replacement:
            context = details.get("ctx")
            # the error context, when present, fills the placeholders of the template
            if context:
                replacement = replacement.format(**context)
            details["msg"] = replacement
        converted.append(details)
    return converted
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def extract_messages(errors: list[ErrorDetails]) -> list[str]:
    """
    Render error details as human-readable configuration error messages.

    The complete ``loc`` path of every error is reported. One- and two-element
    locations produce ``(a)`` and ``(a, b)``; an empty location produces ``()``
    instead of raising ``IndexError``, and deeper locations are no longer
    truncated to their first element.

    Args:
        errors (list[ErrorDetails]): Error entries providing ``"loc"`` and ``"msg"``.

    Returns:
        list[str]: One message per error, in the same order.
    """
    error_messages: list[str] = []
    for error in errors:
        # ``loc`` can be empty or hold more than two parts, so join whatever is there
        location = error.get("loc") or ()
        path = ", ".join(str(part) for part in location)
        error_messages.append(
            f"Configuration Error: Parameter: ({path}) - {error['msg']}"
        )
    return error_messages
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def handle_configuration_errors(e: ValidationError):
    """
    Log every error of a validation failure as a configuration error message.

    Args:
        e (ValidationError): The validation error to report; its entries are run
            through ``convert_errors`` and ``extract_messages`` before logging.
    """
    converted_errors = convert_errors(e)
    for text in extract_messages(converted_errors):
        logging.error(text)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def combine_models(base_model: Type[BaseModel], extra_model: Type[BaseModel]):
    """
    Combine base model with an extra model dynamically.

    Args:
        base_model (Type[BaseModel]): Model whose fields are collected first and whose
            name is given to the combined model.
        extra_model (Type[BaseModel]): Model whose fields are added afterwards; a field
            sharing a name with a base field replaces the base definition.

    Returns:
        The model class produced by ``create_model`` from the collected
        ``(annotation, default)`` pairs.
    """
    merged_fields = {}
    # Base fields first, then the extra model's, so later definitions win
    for source_model in (base_model, extra_model):
        for name, info in source_model.model_fields.items():
            # the default is forwarded unchanged (an Ellipsis default stays Ellipsis)
            merged_fields[name] = (info.annotation, info.default)

    # Return the new dynamically combined model
    return create_model(base_model.__name__, **merged_fields)
|
|
@@ -0,0 +1,479 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import traceback
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from GANDLF.configuration.differential_privacy_config import DifferentialPrivacyConfig
|
|
5
|
+
from GANDLF.configuration.post_processing_config import PostProcessingConfig
|
|
6
|
+
from GANDLF.configuration.pre_processing_config import (
|
|
7
|
+
HistogramMatchingConfig,
|
|
8
|
+
PreProcessingConfig,
|
|
9
|
+
)
|
|
10
|
+
from GANDLF.data.post_process import postprocessing_after_reverse_one_hot_encoding
|
|
11
|
+
import numpy as np
|
|
12
|
+
import sys
|
|
13
|
+
from GANDLF.configuration.optimizer_config import OptimizerConfig, optimizer_dict_config
|
|
14
|
+
from GANDLF.configuration.patch_sampler_config import PatchSamplerConfig
|
|
15
|
+
from GANDLF.configuration.scheduler_config import (
|
|
16
|
+
SchedulerConfig,
|
|
17
|
+
schedulers_dict_config,
|
|
18
|
+
)
|
|
19
|
+
from GANDLF.configuration.utils import initialize_key, combine_models
|
|
20
|
+
from GANDLF.metrics import surface_distance_ids
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def validate_loss_function(value) -> dict:
    """
    Normalise the loss-function setting.

    String input: ``"focal"`` and ``"mse"`` are expanded to a dict carrying their
    default parameters; any other value is returned unchanged.

    Dict input: an ``"mse"`` entry that is ``None`` or lacks ``"reduction"`` is reset
    to ``{"reduction": "mean"}``. Any other key collapses the result to that key as a
    plain string (the last such key wins), so despite the annotation a ``str`` may be
    returned.

    The result is tracked separately from the dict being iterated. Previously the
    loop variable's container was overwritten with a string, so a later ``"mse"`` key
    (e.g. ``{"dice": ..., "mse": ...}``) tried to index into that string and raised
    ``TypeError``.

    Args:
        value: The loss-function configuration (string or dict).

    Returns:
        dict: The normalised configuration (or a ``str``, see above).
    """
    if not isinstance(value, dict):
        if value == "focal":
            return {"focal": {"gamma": 2.0, "size_average": True}}
        if value == "mse":
            return {"mse": {"reduction": "mean"}}
        return value

    result = value
    for key in value:  # an empty dict is returned as-is
        if key == "mse":
            if (value[key] is None) or ("reduction" not in value[key]):
                value[key] = {"reduction": "mean"}
        else:
            # use simple string for other functions - can be extended with parameters, if needed
            result = key

    return result
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def validate_metrics(value) -> dict:
    """
    Build the metrics dictionary and fill in per-metric default options.

    A dict input is updated in place; any other iterable of metric entries is
    collected into a new dict. Entries may be plain names or single-key dicts (the
    first key is used as the metric name).

    Args:
        value: The configured metrics (dict, or iterable of names / dicts).

    Returns:
        dict: Mapping of metric name to its options (``None`` for metrics without options).
    """
    always_dict_metrics = ["accuracy", "f1", "precision", "recall", "specificity", "iou"]
    averaged_metric_tokens = ["precision", "recall", "specificity", "accuracy", "f1"]

    metrics = value if isinstance(value, dict) else {}

    # initialize metrics dict
    for entry in value:
        entry_is_dict = isinstance(entry, dict)
        # some metrics can be dicts, in which case the first key names the metric
        name = list(entry.keys())[0] if entry_is_dict else entry

        # these metrics always need to be dicts
        if name in always_dict_metrics:
            if entry_is_dict:
                metrics[name] = entry
            else:
                metrics[entry] = {}
        elif not entry_is_dict:
            metrics[entry] = None

        # special case for accuracy, precision, recall, and specificity; which could be dicts
        if any(token in name for token in averaged_metric_tokens):
            if name != "classification_accuracy":
                option_defaults = [
                    ("average", "weighted"),
                    ("multi_class", True),
                    ("mdmc_average", "samplewise"),
                    ("threshold", 0.5),
                ]
                if name == "accuracy":
                    option_defaults.append(("subset_accuracy", False))
                for option, default in option_defaults:
                    metrics[name] = initialize_key(metrics[name], option, default)
        elif "iou" in name:
            # note: options are always stored under the literal "iou" key
            for option, default in (("reduction", "elementwise_mean"), ("threshold", 0.5)):
                metrics["iou"] = initialize_key(metrics["iou"], option, default)
        elif name in surface_distance_ids:
            for option, default in (("connectivity", 1), ("threshold", None)):
                metrics[name] = initialize_key(metrics[name], option, default)

    return metrics
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def validate_class_list(value):
    """
    Parse the ``class_list`` of the model configuration.

    Non-string input is returned unchanged. A string containing ``||`` or ``&&``
    (labels processed together) has its brackets removed and is split on commas,
    giving a list of strings. Any other string is evaluated with
    ``ast.literal_eval``.

    Args:
        value: The configured class list.

    Returns:
        The parsed class list.

    Raises:
        AssertionError: If the string cannot be evaluated. It is raised explicitly
            instead of through ``assert False`` so the check still fires when Python
            runs with ``-O`` (which removes assert statements and would otherwise let
            the unparsed string through).
    """
    if not isinstance(value, str):
        return value

    if ("||" in value) or ("&&" in value):
        # special case for multi-class computation - this needs to be handled during one-hot encoding mask construction
        print(
            "WARNING: This is a special case for multi-class computation, where different labels are processed together, `reverse_one_hot` will need mapping information to work correctly"
        )
        # we don't need the brackets
        return value.replace("[", "").replace("]", "").split(",")

    try:
        return ast.literal_eval(value)
    except Exception as e:
        raise AssertionError(
            f"Could not evaluate the `class_list` in `model`, Exception: {str(e)}, {traceback.format_exc()}"
        ) from e
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def validate_patch_size(patch_size, dimension) -> list:
    """
    Expand the patch size to a full list and infer the dimension when it is unset.

    A scalar (or single-element sequence) is repeated ``dimension`` times when a
    dimension is given. A two-element patch size gets a trailing ``1`` appended; the
    dimension defaults to 2 for two elements and to 3 for three elements.

    The caller's sequence is never modified: the function works on its own list, so
    tuples are accepted as well.

    Args:
        patch_size: A number or a sequence of numbers.
        dimension: The configured dimension, or ``None`` to infer it.

    Returns:
        list: ``[patch_size, dimension]`` with ``patch_size`` as a list.
    """
    if isinstance(patch_size, (int, float)):
        patch_size = [patch_size]
    else:
        # private copy: appending below must not leak into the caller's object
        patch_size = list(patch_size)

    if len(patch_size) == 1 and dimension is not None:
        patch_size = [patch_size[0]] * dimension

    if len(patch_size) == 2:  # 2d check
        # ensuring same size during torchio processing
        patch_size.append(1)
        if dimension is None:
            dimension = 2
    elif len(patch_size) == 3 and dimension is None:  # 3d check
        dimension = 3

    return [patch_size, dimension]
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def validate_norm_type(norm_type, architecture):
    """
    Reject a missing normalization type unless the architecture name contains "vgg".

    Args:
        norm_type: The normalization type; ``None`` or any casing of ``"none"`` means
            no normalization.
        architecture: The architecture name, searched for the substring ``"vgg"``.

    Returns:
        The unchanged ``norm_type``.

    Raises:
        ValueError: If no normalization is requested for a non-VGG architecture.
    """
    normalization_disabled = norm_type is None or norm_type.lower() == "none"
    if normalization_disabled and "vgg" not in architecture:
        raise ValueError(
            "Normalization type cannot be 'None' for non-VGG architectures"
        )
    return norm_type
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def validate_parallel_compute_command(value):
    """
    Strip every single and double quote character from the command string.

    Args:
        value (str): The configured parallel compute command.

    Returns:
        str: The command without ``'`` and ``"`` characters.
    """
    return value.replace("'", "").replace('"', "")
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def validate_scheduler(value, learning_rate, num_epochs):
    """
    Validate the scheduler configuration and fill in defaults derived from the training setup.

    Args:
        value: A scheduler type name or a ``SchedulerConfig`` instance.
        learning_rate: The configured learning rate, used to derive ``max_lr``,
            ``min_lr`` and ``step_size`` defaults.
        num_epochs: The number of epochs, used to derive ``warmup_steps``.

    Returns:
        SchedulerConfig: The validated scheduler configuration.
    """
    if isinstance(value, str):
        value = SchedulerConfig(type=value)

    # Look up the type-specific config class and combine it with SchedulerConfig,
    # then round-trip the data through the combined model and back
    type_specific_model = schedulers_dict_config[value.type]
    merged_model = combine_models(SchedulerConfig, type_specific_model)
    merged_instance = merged_model(**value.model_dump())
    value = SchedulerConfig(**merged_instance.model_dump())

    if value.type == "triangular" and value.max_lr is None:
        value.max_lr = learning_rate

    types_with_min_lr = [
        "reduce_on_plateau",
        "reduce-on-plateau",
        "plateau",
        "exp_range",
        "triangular",
    ]
    if value.type in types_with_min_lr and value.min_lr is None:
        value.min_lr = learning_rate * 0.001

    if value.type in ["warmupcosineschedule", "wcs"]:
        # always derived from the epoch count (any configured value is replaced)
        value.warmup_steps = num_epochs * 0.1

    if hasattr(value, "step_size") and value.step_size is None:
        value.step_size = learning_rate / 5.0

    return value
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def validate_optimizer(value):
    """
    Validate the optimizer configuration against its type-specific model.

    Args:
        value: An optimizer type name or an ``OptimizerConfig`` instance.

    Returns:
        OptimizerConfig: The configuration after a round trip through the model that
            combines ``OptimizerConfig`` with the class registered for its type.
    """
    if isinstance(value, str):
        value = OptimizerConfig(type=value)

    # Combine the type-specific config class with the OptimizerConfig class
    type_specific_model = optimizer_dict_config[value.type]
    merged_model = combine_models(OptimizerConfig, type_specific_model)
    merged_instance = merged_model(**value.model_dump())

    return OptimizerConfig(**merged_instance.model_dump())
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def validate_data_preprocessing(value) -> dict:
    """
    Normalise the pre-processing configuration and validate it with ``PreProcessingConfig``.

    Steps, applied to a deep copy of the input:
      * a scalar ``resolution`` under ``resample_min`` / ``resample_minimum`` is
        duplicated into a pair; such an entry without ``resolution`` is dropped;
      * when ``resample`` is combined with any resize request, the resize entries
        are removed and a warning is printed to stderr;
      * only one of ``threshold`` / ``clip`` / ``clamp`` is allowed, and its missing
        bounds are filled so that the absent bound has no effect;
      * the histogram options are mapped onto ``histogram_matching``.

    The default lower bound is ``-sys.float_info.max``. ``sys.float_info.min`` is the
    smallest *positive* float, so using it as the lower bound would cut off every
    non-positive intensity, contradicting the intent that an absent bound must not
    affect processing.

    NOTE(review): the nesting of this function was reconstructed from a rendering
    that lost indentation. ``None`` and empty input are returned unchanged here -
    please confirm against the intended behaviour.

    Args:
        value: The pre-processing configuration, or ``None``.

    Returns:
        dict: The validated configuration restricted to the keys that were provided.

    Raises:
        AssertionError: If more than one of ``threshold`` / ``clip`` / ``clamp`` is given.
    """
    # perform this only when pre-processing is defined
    if value is None or len(value) == 0:
        return value

    resize_keys = ["resize", "resize_image", "resize_images", "resize_patch"]
    # this can be extended, as required
    threshold_or_clip_keys = ["threshold", "clip", "clamp"]

    resize_requested = False
    # edit a copy so entries can be dropped while iterating over the original
    normalized = deepcopy(value)
    for key in value:
        if key in resize_keys:
            resize_requested = True

        if key in ["resample_min", "resample_minimum"]:
            if "resolution" in value[key]:
                resize_requested = True
                resolution = np.array(value[key]["resolution"])
                if resolution.size == 1:
                    normalized[key]["resolution"] = np.array(
                        [resolution, resolution]
                    ).tolist()
            else:
                normalized.pop(key)

    value = normalized

    if resize_requested and "resample" in value:
        for key in resize_keys:
            value.pop(key, None)

        print(
            "WARNING: Different 'resize' operations are ignored as 'resample' is defined under 'data_processing'",
            file=sys.stderr,
        )

    threshold_or_clip_seen = False
    for key in value:
        if key in threshold_or_clip_keys:
            # we only allow one of threshold or clip to occur and not both;
            # raised explicitly so the check survives ``python -O``
            if threshold_or_clip_seen:
                raise AssertionError("Use only `threshold` or `clip`, not both")
            threshold_or_clip_seen = True
            # initialize if nothing is present
            if not isinstance(value[key], dict):
                value[key] = {}

            # a missing bound gets the lowest/highest representable value so
            # that its absence doesn't affect processing
            value[key].setdefault("min", -sys.float_info.max)
            value[key].setdefault("max", sys.float_info.max)

    if "histogram_matching" in value:
        requested = value["histogram_matching"]
        if requested is not False and not isinstance(requested, dict):
            value["histogram_matching"] = HistogramMatchingConfig()

    # if histogram equalization is enabled, call histogram_matching
    if (
        "histogram_equalization" in value
        and value["histogram_equalization"] is not False
    ):
        value["histogram_matching"] = HistogramMatchingConfig()

    if (
        "adaptive_histogram_equalization" in value
        and value["adaptive_histogram_equalization"] is not False
    ):
        value["histogram_matching"] = HistogramMatchingConfig(target="adaptive")

    pre_processing = PreProcessingConfig(**value)
    # only report the fields that are actually configured
    return pre_processing.model_dump(include=set(value))
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def validate_data_postprocessing_after_reverse_one_hot_encoding(
    value, data_postprocessing
) -> list:
    """
    Move selected post-processing entries from ``data_postprocessing`` into ``value``.

    For every key of ``value`` that is listed in
    ``postprocessing_after_reverse_one_hot_encoding``, the entry of
    ``data_postprocessing`` with that key is copied into ``value`` and removed from
    ``data_postprocessing``. Both dictionaries are modified in place.

    Only a snapshot of the keys is needed to iterate safely, so the former deep copy
    of the whole dictionary has been dropped.

    NOTE(review): the keys come from ``value`` (the target), not from
    ``data_postprocessing`` (the source). Entries present only in the source are
    therefore never moved, and a key missing from the source raises ``KeyError`` -
    confirm this is intended.

    Args:
        value (dict): Operations to run after reverse one-hot encoding.
        data_postprocessing (dict): The general post-processing operations.

    Returns:
        list: ``[value, data_postprocessing]``.
    """
    for key in list(value):
        if key in postprocessing_after_reverse_one_hot_encoding:
            value[key] = data_postprocessing[key]
            data_postprocessing.pop(key)
    return [value, data_postprocessing]
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def validate_patch_sampler(value):
    """
    Wrap a patch sampler given by name in a ``PatchSamplerConfig``.

    Args:
        value: The sampler type as a string, or an already structured value.

    Returns:
        A ``PatchSamplerConfig`` with the lower-cased type for string input;
        any other input is returned unchanged.
    """
    if not isinstance(value, str):
        return value
    return PatchSamplerConfig(type=value.lower())
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def validate_data_augmentation(value, patch_size) -> dict:
    """
    Fill in the default options of every configured augmentation.

    Changes compared with the previous implementation:
      * ``None`` is returned unchanged. ``value.get`` used to be called before the
        ``None`` check, so the check could never take effect.
      * For ``anisotropic``, the ``downsampling`` derived from the user's input is
        now stored. Before, every scalar (and every list not of length 2) was
        replaced by 1.5 even when valid, although the warning only announces the
        1.5 fallback for values below 1.
      * Each HED range key receives its own default list instead of six keys
        sharing one list object.

    NOTE(review): the nesting of this function was reconstructed from a rendering
    that lost indentation; please confirm the anisotropic handling in particular.

    Args:
        value (dict): The augmentation configuration; modified in place.
        patch_size: The patch size, used to derive the default ``patch_size`` of the
            ``swap`` and ``elastic`` augmentations (one tenth, rounded).

    Returns:
        dict: The completed configuration, including ``default_probability``.
    """
    if value is None:
        return value

    value["default_probability"] = value.get("default_probability", 0.5)

    # special case for random swapping and elastic transformations - which takes a patch size for computation
    for key in ["swap", "elastic"]:
        if key in value:
            value[key] = initialize_key(
                value[key],
                "patch_size",
                np.round(np.array(patch_size) / 10).astype("int").tolist(),
            )

    # special case for swap default initialization
    if "swap" in value:
        value["swap"] = initialize_key(value["swap"], "num_iterations", 100)

    # special case for affine default initialization
    if "affine" in value:
        value["affine"] = initialize_key(value["affine"], "scales", 0.1)
        value["affine"] = initialize_key(value["affine"], "degrees", 15)
        value["affine"] = initialize_key(value["affine"], "translation", 2)

    if "motion" in value:
        value["motion"] = initialize_key(value["motion"], "num_transforms", 2)
        value["motion"] = initialize_key(value["motion"], "degrees", 15)
        value["motion"] = initialize_key(value["motion"], "translation", 2)
        value["motion"] = initialize_key(value["motion"], "interpolation", "linear")

    # special case for random blur/noise - which takes a std-dev range
    for std_aug in ["blur", "noise_var"]:
        if std_aug in value:
            value[std_aug] = initialize_key(value[std_aug], "std", None)
    for std_aug in ["noise"]:
        if std_aug in value:
            value[std_aug] = initialize_key(value[std_aug], "std", [0, 1])

    # special case for random noise - which takes a mean range
    for mean_aug in ["noise", "noise_var"]:
        if mean_aug in value:
            value[mean_aug] = initialize_key(value[mean_aug], "mean", 0)

    # special case for augmentations that need axis defined
    for axis_aug in ["flip", "anisotropic", "rotate_90", "rotate_180"]:
        if axis_aug in value:
            value[axis_aug] = initialize_key(value[axis_aug], "axis", [0, 1, 2])

    # special case for colorjitter
    if "colorjitter" in value:
        value = initialize_key(value, "colorjitter", {})
        for key in ["brightness", "contrast", "saturation"]:
            value["colorjitter"] = initialize_key(value["colorjitter"], key, [0, 1])
        value["colorjitter"] = initialize_key(
            value["colorjitter"], "hue", [-0.5, 0.5]
        )

    # Added HED augmentation in gandlf
    hed_augmentation_types = [
        "hed_transform",
        # "hed_transform_light",
        # "hed_transform_heavy",
    ]
    for augmentation_type in hed_augmentation_types:
        if augmentation_type in value:
            value = initialize_key(value, "hed_transform", {})
            ranges = [
                "haematoxylin_bias_range",
                "eosin_bias_range",
                "dab_bias_range",
                "haematoxylin_sigma_range",
                "eosin_sigma_range",
                "dab_sigma_range",
            ]

            if augmentation_type == "hed_transform":
                default_range = [-0.1, 0.1]
            elif augmentation_type == "hed_transform_light":
                default_range = [-0.03, 0.03]
            else:
                default_range = [-0.95, 0.95]

            for key in ranges:
                # a fresh list per key, so the defaults never alias each other
                value["hed_transform"] = initialize_key(
                    value["hed_transform"], key, list(default_range)
                )

            value["hed_transform"] = initialize_key(
                value["hed_transform"], "cutoff_range", [0, 1]
            )

    # special case for anisotropic
    if "anisotropic" in value:
        if "downsampling" in value["anisotropic"]:
            downsampling = value["anisotropic"]["downsampling"]
        else:
            downsampling = 1.5

        # a list of exactly two numbers is used as given; everything else is
        # reduced to a single number and written back
        store_downsampling = False
        if isinstance(downsampling, list):
            if len(downsampling) != 2:
                store_downsampling = True
                print(
                    "WARNING: 'anisotropic' augmentation needs to be either a single number of a list of 2 numbers: https://torchio.readthedocs.io/transforms/augmentation.html?highlight=randomswap#torchio.transforms.RandomAnisotropy.",
                    file=sys.stderr,
                )
                # only the first entry is considered (1.5 for an empty list)
                downsampling = downsampling[0] if downsampling else 1.5
        else:
            store_downsampling = True

        if store_downsampling:
            if downsampling < 1:
                print(
                    "WARNING: 'anisotropic' augmentation needs the 'downsampling' parameter to be greater than 1, defaulting to 1.5.",
                    file=sys.stderr,
                )
                # default
                downsampling = 1.5
            value["anisotropic"]["downsampling"] = downsampling

    # every augmentation gets a probability, falling back to the default one
    for key in value:
        if key != "default_probability":
            value[key] = initialize_key(
                value[key], "probability", value["default_probability"]
            )

    return value
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
def validate_postprocessing(value):
    """
    Validate the post-processing configuration with ``PostProcessingConfig``.

    Args:
        value (dict): The configured post-processing operations.

    Returns:
        dict: The validated model dumped back to a dict, restricted to the keys
            that were present in ``value``.
    """
    provided_fields = set(value.keys())
    validated = PostProcessingConfig(**value)
    return validated.model_dump(include=provided_fields)
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def validate_differential_privacy(value, batch_size):
    """
    Validate the differential-privacy configuration.

    ``None`` is returned unchanged. A non-dict value is replaced, with a warning, by
    the defaults of ``DifferentialPrivacyConfig`` using ``batch_size`` as the
    physical batch size. The physical batch size is capped at ``batch_size``; when a
    dict omits it (or sets it to ``None``) it now falls back to ``batch_size`` -
    the same default as the non-dict path - instead of raising ``KeyError``.

    Args:
        value: The configured differential-privacy settings.
        batch_size: The training batch size.

    Returns:
        DifferentialPrivacyConfig: The validated configuration, or ``None``.
    """
    if value is None:
        return value

    if not isinstance(value, dict):
        print(
            "WARNING: Non dictionary value for the key: 'differential_privacy' was used, replacing with default valued dictionary."
        )
        value = DifferentialPrivacyConfig(physical_batch_size=batch_size).model_dump()

    physical_batch_size = value.get("physical_batch_size")
    if physical_batch_size is None:
        value["physical_batch_size"] = batch_size
    elif physical_batch_size > batch_size:
        # the first fragment ends with a space so the sentence reads
        # "is greater than" rather than "is greaterthan"
        print(
            f"WARNING: The physical batch size {physical_batch_size} is greater "
            f"than the batch size {batch_size}, setting the physical batch size to the batch size."
        )
        value["physical_batch_size"] = batch_size

    # these keys need to be parsed as floats, not strings; unset (None) entries
    # are left for the config model to handle
    for key in ["noise_multiplier", "max_grad_norm", "delta", "epsilon"]:
        if value.get(key) is not None:
            value[key] = float(value[key])

    return DifferentialPrivacyConfig(**value)
|
GANDLF/data/__init__.py
CHANGED
|
@@ -66,19 +66,17 @@ def get_testing_loader(params):
|
|
|
66
66
|
Returns:
|
|
67
67
|
torch.utils.data.DataLoader: The testing loader.
|
|
68
68
|
"""
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
)
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
pin_memory=False, # params["pin_memory_dataloader"], # this is going OOM if True - needs investigation
|
|
84
|
-
)
|
|
69
|
+
|
|
70
|
+
queue_from_dataframe = ImagesFromDataFrame(
|
|
71
|
+
get_dataframe(params["testing_data"]),
|
|
72
|
+
params,
|
|
73
|
+
train=False,
|
|
74
|
+
loader_type="testing",
|
|
75
|
+
)
|
|
76
|
+
if not ("channel_keys" in params):
|
|
77
|
+
params = populate_channel_keys_in_params(queue_from_dataframe, params)
|
|
78
|
+
return DataLoader(
|
|
79
|
+
queue_from_dataframe,
|
|
80
|
+
batch_size=1,
|
|
81
|
+
pin_memory=False, # params["pin_memory_dataloader"], # this is going OOM if True - needs investigation
|
|
82
|
+
)
|