GANDLF: 0.1.3.dev20250202-py3-none-any.whl → 0.1.6.dev20251109-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of GANDLF has been flagged as a potentially problematic release.
- GANDLF/cli/deploy.py +2 -2
- GANDLF/cli/generate_metrics.py +35 -1
- GANDLF/cli/main_run.py +4 -10
- GANDLF/compute/__init__.py +0 -2
- GANDLF/compute/forward_pass.py +0 -1
- GANDLF/compute/generic.py +107 -2
- GANDLF/compute/inference_loop.py +4 -4
- GANDLF/compute/loss_and_metric.py +1 -2
- GANDLF/compute/training_loop.py +10 -10
- GANDLF/config_manager.py +33 -717
- GANDLF/configuration/__init__.py +0 -0
- GANDLF/configuration/default_config.py +73 -0
- GANDLF/configuration/differential_privacy_config.py +16 -0
- GANDLF/configuration/exclude_parameters.py +1 -0
- GANDLF/configuration/model_config.py +82 -0
- GANDLF/configuration/nested_training_config.py +25 -0
- GANDLF/configuration/optimizer_config.py +121 -0
- GANDLF/configuration/parameters_config.py +10 -0
- GANDLF/configuration/patch_sampler_config.py +11 -0
- GANDLF/configuration/post_processing_config.py +10 -0
- GANDLF/configuration/pre_processing_config.py +94 -0
- GANDLF/configuration/scheduler_config.py +92 -0
- GANDLF/configuration/user_defined_config.py +131 -0
- GANDLF/configuration/utils.py +96 -0
- GANDLF/configuration/validators.py +479 -0
- GANDLF/data/__init__.py +14 -16
- GANDLF/data/lightning_datamodule.py +119 -0
- GANDLF/entrypoints/run.py +36 -31
- GANDLF/inference_manager.py +69 -25
- GANDLF/losses/__init__.py +23 -1
- GANDLF/losses/loss_calculators.py +79 -0
- GANDLF/losses/segmentation.py +3 -2
- GANDLF/metrics/__init__.py +26 -0
- GANDLF/metrics/generic.py +1 -1
- GANDLF/metrics/metric_calculators.py +102 -0
- GANDLF/metrics/panoptica_config_brats.yaml +56 -0
- GANDLF/metrics/segmentation_panoptica.py +49 -0
- GANDLF/models/__init__.py +8 -3
- GANDLF/models/lightning_module.py +2102 -0
- GANDLF/optimizers/__init__.py +4 -8
- GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
- GANDLF/schedulers/__init__.py +11 -4
- GANDLF/schedulers/wrap_torch.py +15 -3
- GANDLF/training_manager.py +160 -50
- GANDLF/utils/__init__.py +5 -3
- GANDLF/utils/imaging.py +176 -35
- GANDLF/utils/modelio.py +12 -8
- GANDLF/utils/pred_target_processors.py +71 -0
- GANDLF/utils/tensor.py +2 -1
- GANDLF/utils/write_parse.py +1 -1
- GANDLF/version.py +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/METADATA +16 -11
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/RECORD +57 -34
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/WHEEL +1 -1
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/entry_points.txt +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info/licenses}/LICENSE +0 -0
- {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/top_level.txt +0 -0
GANDLF/configuration/default_config.py
@@ -0,0 +1,73 @@
+from pydantic import BaseModel, Field, AfterValidator
+from typing import Dict
+from typing_extensions import Literal, Optional, Annotated
+
+from GANDLF.configuration.validators import validate_postprocessing
+
+GRID_AGGREGATOR_OVERLAP_OPTIONS = Literal["crop", "average", "hann"]
+
+
+class DefaultParameters(BaseModel):
+    weighted_loss: bool = Field(
+        default=False, description="Whether weighted loss is to be used or not."
+    )
+    verbose: bool = Field(default=False, description="General application verbosity.")
+    q_verbose: bool = Field(default=False, description="Queue construction verbosity.")
+    medcam_enabled: bool = Field(
+        default=False, description="Enable interpretability via medcam."
+    )
+    save_training: bool = Field(
+        default=False, description="Save outputs during training."
+    )
+    save_output: bool = Field(
+        default=False, description="Save outputs during validation/testing."
+    )
+    in_memory: bool = Field(default=False, description="Pin data to CPU memory.")
+    pin_memory_dataloader: bool = Field(
+        default=False, description="Pin data to GPU memory."
+    )
+    scaling_factor: int = Field(
+        default=1, description="Scaling factor for regression problems."
+    )
+    q_max_length: int = Field(default=100, description="The max length of the queue.")
+    q_samples_per_volume: int = Field(
+        default=10, description="Number of samples per volume."
+    )
+    q_num_workers: int = Field(
+        default=0, description="Number of worker threads to use."
+    )
+    num_epochs: int = Field(default=100, description="Total number of epochs to train.")
+    patience: int = Field(
+        default=100, description="Number of epochs to wait for performance improvement."
+    )
+    batch_size: int = Field(default=1, description="Default batch size for training.")
+    learning_rate: float = Field(default=0.001, description="Default learning rate.")
+    clip_grad: Optional[float] = Field(
+        default=None, description="Gradient clipping value."
+    )
+    track_memory_usage: bool = Field(
+        default=False, description="Enable memory usage tracking."
+    )
+    memory_save_mode: bool = Field(
+        default=False,
+        description="Enable memory-saving mode. If enabled, resize/resample will save files to disk.",
+    )
+    print_rgb_label_warning: bool = Field(
+        default=True, description="Print a warning for RGB labels."
+    )
+    data_postprocessing: Annotated[
+        dict,
+        Field(description="Default data postprocessing configuration.", default={}),
+        AfterValidator(validate_postprocessing),
+    ]
+
+    grid_aggregator_overlap: GRID_AGGREGATOR_OVERLAP_OPTIONS = Field(
+        default="crop", description="Default grid aggregator overlap strategy."
+    )
+    determinism: bool = Field(
+        default=False, description="Enable deterministic computation."
+    )
+    previous_parameters: Optional[Dict] = Field(
+        default=None,
+        description="Previous parameters to be used for resuming training and performing sanity checks.",
+    )
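
The data_postprocessing field above relies on pydantic v2's Annotated + AfterValidator pattern rather than a plain typed default. A minimal, self-contained sketch of that mechanism (toy names only; check_postprocessing stands in for GANDLF's validate_postprocessing):

# Minimal sketch of the Annotated + AfterValidator pattern (pydantic v2).
# check_postprocessing is a toy stand-in for GANDLF's validate_postprocessing.
from typing_extensions import Annotated
from pydantic import AfterValidator, BaseModel, Field


def check_postprocessing(value: dict) -> dict:
    # Runs after pydantic has already coerced/checked the field's type.
    if "unknown_step" in value:
        raise ValueError("unsupported postprocessing step")
    return value


class Toy(BaseModel):
    data_postprocessing: Annotated[
        dict,
        Field(default={}),
        AfterValidator(check_postprocessing),
    ]


print(Toy().data_postprocessing)                      # {}
print(Toy(data_postprocessing={"fill_holes": None}))  # validated dict

The hook fires after pydantic's own type check of the supplied value, so it can normalize or reject the dict before the model is finalized.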
GANDLF/configuration/differential_privacy_config.py
@@ -0,0 +1,16 @@
+from typing_extensions import Literal
+
+from pydantic import BaseModel, Field, ConfigDict
+
+ACCOUNTANT_OPTIONS = Literal["rdp", "gdp", "prv"]
+
+
+class DifferentialPrivacyConfig(BaseModel):
+    model_config = ConfigDict(extra="allow")
+    noise_multiplier: float = Field(default=10.0)
+    max_grad_norm: float = Field(default=1.0)
+    accountant: ACCOUNTANT_OPTIONS = Field(default="rdp")
+    secure_mode: bool = Field(default=False)
+    allow_opacus_model_fix: bool = Field(default=True)
+    delta: float = Field(default=1e-5)
+    physical_batch_size: int = Field(validate_default=True)
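
Note that physical_batch_size is declared with validate_default=True but no default, which makes it a required key under pydantic v2 semantics. The ConfigDict(extra="allow") line is what lets user configs carry keys beyond those declared; a toy sketch of that behavior (not GANDLF code):

# Toy sketch of ConfigDict(extra="allow"): unknown keys are kept on the
# model instead of raising a ValidationError (pydantic v2).
from pydantic import BaseModel, ConfigDict, Field


class DPToy(BaseModel):
    model_config = ConfigDict(extra="allow")
    noise_multiplier: float = Field(default=10.0)


dp = DPToy(noise_multiplier=2.0, epsilon_budget=3.0)  # epsilon_budget is undeclared
print(dp.model_dump())  # {'noise_multiplier': 2.0, 'epsilon_budget': 3.0}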
GANDLF/configuration/exclude_parameters.py
@@ -0,0 +1 @@
+exclude_parameters = {"differential_privacy"}
GANDLF/configuration/model_config.py
@@ -0,0 +1,82 @@
+from pydantic import BaseModel, model_validator, Field, AliasChoices, ConfigDict
+from typing_extensions import Self, Literal, Optional
+from typing import Union
+from GANDLF.configuration.validators import validate_class_list, validate_norm_type
+from GANDLF.models import global_models_dict
+
+# Define model architecture options
+ARCHITECTURE_OPTIONS = Literal[tuple(global_models_dict.keys())]
+# Define model norm_type options
+NORM_TYPE_OPTIONS = Literal["batch", "instance", "none"]
+# Define model final_layer options
+FINAL_LAYER_OPTIONS = Literal[
+    "sigmoid",
+    "softmax",
+    "logsoftmax",
+    "tanh",
+    "identity",
+    "logits",
+    "regression",
+    "None",
+    "none",
+]
+TYPE_OPTIONS = Literal["torch", "openvino"]
+DIMENSIONS_OPTIONS = Literal[2, 3]
+
+
+# You can define new parameters for model here. Please read the pydantic documentation.
+# It allows extra fields in model dict.
+class ModelConfig(BaseModel):
+    model_config = ConfigDict(
+        extra="allow"
+    )  # it allows extra fields in the model dict
+    dimension: Optional[DIMENSIONS_OPTIONS] = Field(
+        description="model input dimension (2D or 3D)."
+    )
+    architecture: ARCHITECTURE_OPTIONS = Field(description="Architecture.")
+    final_layer: FINAL_LAYER_OPTIONS = Field(description="Final layer.")
+    norm_type: Optional[NORM_TYPE_OPTIONS] = Field(
+        description="Normalization type.", default="batch"
+    )  # TODO: check it again
+    base_filters: Optional[int] = Field(
+        description="Base filters.", default=None, validate_default=True
+    )  # default is 32
+    class_list: Union[list, str] = Field(default=[], description="Class list.")
+    num_channels: Optional[int] = Field(
+        description="Number of channels.",
+        validation_alias=AliasChoices(
+            "num_channels", "n_channels", "channels", "model_channels"
+        ),
+        default=3,
+    )  # TODO: check it
+    type: TYPE_OPTIONS = Field(description="Type of model.", default="torch")
+    data_type: str = Field(description="Data type.", default="FP32")
+    save_at_every_epoch: bool = Field(default=False, description="Save at every epoch.")
+    amp: bool = Field(default=False, description="Automatic mixed precision")
+    ignore_label_validation: Union[int, None] = Field(
+        default=None, description="Ignore label validation."
+    )  # TODO: To check it
+    print_summary: bool = Field(default=True, description="Print summary.")
+
+    @model_validator(mode="after")
+    def model_validate(self) -> Self:
+        # TODO: Change the print to logging.warnings
+        self.class_list = validate_class_list(
+            self.class_list
+        )  # init and validate the class_list parameter
+        self.norm_type = validate_norm_type(
+            self.norm_type, self.architecture
+        )  # init and validate the norm type
+        if self.amp is False:
+            print("NOT using Mixed Precision Training")
+
+        if self.save_at_every_epoch:
+            print(
+                "WARNING: 'save_at_every_epoch' will result in TREMENDOUS storage usage; use at your own risk."
+            )  # TODO: It is better to use logging.warning
+
+        if self.base_filters is None:
+            self.base_filters = 32
+            print("Using default 'base_filters' in 'model': ", self.base_filters)
+
+        return self
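
The num_channels field above accepts several spellings through AliasChoices, and ARCHITECTURE_OPTIONS is a Literal built at import time from global_models_dict, so the accepted architecture names track whatever models the package registers. A toy sketch of the alias mechanism (illustrative names, pydantic v2 assumed):

# Toy sketch of AliasChoices as used for num_channels (pydantic v2):
# any of the listed input keys populates the same field.
from pydantic import AliasChoices, BaseModel, Field


class ChannelsToy(BaseModel):
    num_channels: int = Field(
        default=3,
        validation_alias=AliasChoices("num_channels", "n_channels", "channels"),
    )


print(ChannelsToy(n_channels=1).num_channels)  # 1
print(ChannelsToy(channels=4).num_channels)    # 4
print(ChannelsToy().num_channels)              # 3 (default)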
GANDLF/configuration/nested_training_config.py
@@ -0,0 +1,25 @@
+from pydantic import BaseModel, Field, model_validator
+from typing_extensions import Self, Optional
+
+
+class NestedTraining(BaseModel):
+    stratified: bool = Field(
+        default=False,
+        description="this will perform stratified k-fold cross-validation but only with offline data splitting",
+    )
+    testing: int = Field(
+        default=-5,
+        description="this controls the number of testing data folds for final model evaluation; [NOT recommended] to disable this, use '1'",
+        le=10,
+    )
+    validation: int = Field(
+        default=-5,
+        description="this controls the number of validation data folds to be used for model *selection* during training (not used for back-propagation)",
+    )
+    proportional: Optional[bool] = Field(default=False)
+
+    @model_validator(mode="after")
+    def validate_nested_training(self) -> Self:
+        if self.proportional is not None:
+            self.stratified = self.proportional
+        return self
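
One behavioral detail worth noting: because proportional defaults to False rather than None, the after-validator above overwrites stratified on every construction, so proportional is effectively the source of truth. A toy copy demonstrating this:

# Toy copy of the stratified/proportional interplay above (pydantic v2).
from typing_extensions import Optional, Self
from pydantic import BaseModel, Field, model_validator


class NestedToy(BaseModel):
    stratified: bool = Field(default=False)
    proportional: Optional[bool] = Field(default=False)

    @model_validator(mode="after")
    def sync(self) -> Self:
        if self.proportional is not None:
            self.stratified = self.proportional
        return self


print(NestedToy(stratified=True).stratified)    # False: proportional (False) wins
print(NestedToy(proportional=True).stratified)  # True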
GANDLF/configuration/optimizer_config.py
@@ -0,0 +1,121 @@
+from typing import Tuple
+
+from pydantic import BaseModel, Field, ConfigDict
+from typing_extensions import Literal
+
+from GANDLF.optimizers import global_optimizer_dict
+
+# takes the keys from global optimizer
+OPTIMIZER_OPTIONS = Literal[tuple(global_optimizer_dict.keys())]
+
+
+class SgdConfig(BaseModel):
+    momentum: float = Field(default=0.99)
+    weight_decay: float = Field(default=3e-05)
+    dampening: float = Field(default=0)
+    nesterov: bool = Field(default=True)
+
+
+class AsgdConfig(BaseModel):
+    alpha: float = Field(default=0.75)
+    t0: float = Field(default=1e6)
+    lambd: float = Field(default=1e-4)
+    weight_decay: float = Field(default=3e-05)
+
+
+class AdamConfig(BaseModel):
+    betas: Tuple[float, float] = Field(default=(0.9, 0.999))
+    weight_decay: float = Field(default=0.00005)
+    eps: float = Field(default=1e-8)
+    amsgrad: bool = Field(default=False)
+
+
+class AdamaxConfig(BaseModel):
+    betas: Tuple[float, float] = Field(default=(0.9, 0.999))
+    weight_decay: float = Field(default=0.00005)
+    eps: float = Field(default=1e-8)
+
+
+class RpropConfig(BaseModel):
+    etas: Tuple[float, float] = Field(default=(0.5, 1.2))
+    step_sizes: Tuple[float, float] = Field(default=(1e-6, 50))
+
+
+class AdadeltaConfig(BaseModel):
+    rho: float = Field(default=0.9)
+    eps: float = Field(default=1e-6)
+    weight_decay: float = Field(default=3e-05)
+
+
+class AdagradConfig(BaseModel):
+    lr_decay: float = Field(default=0)
+    eps: float = Field(default=1e-6)
+    weight_decay: float = Field(default=3e-05)
+
+
+class RmspropConfig(BaseModel):
+    alpha: float = Field(default=0.99)
+    eps: float = Field(default=1e-8)
+    centered: bool = Field(default=False)
+    momentum: float = Field(default=0)
+    weight_decay: float = Field(default=3e-05)
+
+
+class RadamConfig(BaseModel):
+    betas: Tuple[float, float] = Field(default=(0.9, 0.999))
+    eps: float = Field(default=1e-8)
+    weight_decay: float = Field(default=3e-05)
+    foreach: bool = Field(default=None)
+
+
+class NadamConfig(BaseModel):
+    betas: Tuple[float, float] = Field(default=(0.9, 0.999))
+    eps: float = Field(default=1e-8)
+    weight_decay: float = Field(default=3e-05)
+    foreach: bool = Field(default=None)
+
+
+class NovogradConfig(BaseModel):
+    betas: Tuple[float, float] = Field(default=(0.9, 0.999))
+    eps: float = Field(default=1e-8)
+    weight_decay: float = Field(default=3e-05)
+    amsgrad: bool = Field(default=False)
+
+
+class AdemamixConfig(BaseModel):
+    pass
+
+
+class LionConfig(BaseModel):
+    betas: Tuple[float, float] = Field(default=(0.9, 0.999))
+    weight_decay: float = Field(default=0.0)
+    decoupled_weight_decay: bool = Field(default=False)
+
+
+class AdoptConfig(BaseModel):
+    pass
+
+
+class OptimizerConfig(BaseModel):
+    model_config = ConfigDict(extra="allow")
+    type: OPTIMIZER_OPTIONS = Field(description="Type of optimizer to use")
+
+
+optimizer_dict_config = {
+    "sgd": SgdConfig,
+    "asgd": AsgdConfig,
+    "adam": AdamConfig,
+    "adamw": AdamConfig,
+    "adamax": AdamaxConfig,
+    # "sparseadam": sparseadam,
+    "rprop": RpropConfig,
+    "adadelta": AdadeltaConfig,
+    "adagrad": AdagradConfig,
+    "rmsprop": RmspropConfig,
+    "radam": RadamConfig,
+    "novograd": NovogradConfig,
+    "nadam": NadamConfig,
+    "ademamix": AdemamixConfig,
+    "lion": LionConfig,
+    "adopt": AdoptConfig,
+}
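
OptimizerConfig itself only validates type; the per-optimizer models together with the optimizer_dict_config mapping suggest a two-stage scheme in which the extra keys are re-validated against the matching sub-config. A plausible consumption sketch (not GANDLF's actual validator; trimmed toy copies of two classes above):

# A plausible two-stage validation sketch (not GANDLF's actual code):
# re-validate the extra keys against the per-optimizer model so typed
# defaults are filled in for the chosen optimizer.
from typing import Tuple
from pydantic import BaseModel, ConfigDict, Field


class AdamToy(BaseModel):  # trimmed toy copy of AdamConfig above
    betas: Tuple[float, float] = Field(default=(0.9, 0.999))
    weight_decay: float = Field(default=0.00005)


class OptimizerToy(BaseModel):  # toy copy of OptimizerConfig above
    model_config = ConfigDict(extra="allow")
    type: str


dispatch = {"adam": AdamToy}

raw = OptimizerToy(type="adam", weight_decay=1e-4)
sub = dispatch[raw.type](**(raw.model_extra or {}))
print(sub.model_dump())  # {'betas': (0.9, 0.999), 'weight_decay': 0.0001}

This fills in typed defaults (betas here) for whichever optimizer the user named.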
GANDLF/configuration/parameters_config.py
@@ -0,0 +1,10 @@
+from pydantic import BaseModel, ConfigDict
+from GANDLF.configuration.user_defined_config import UserDefinedParameters
+
+
+class ParametersConfiguration(BaseModel):
+    model_config = ConfigDict(extra="allow")
+
+
+class Parameters(ParametersConfiguration, UserDefinedParameters):
+    pass
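
Parameters composes the permissive ParametersConfiguration mixin with the strongly typed UserDefinedParameters (defined further below); under pydantic v2, fields and model_config are merged across bases, so unknown top-level keys are tolerated while declared fields stay validated. A toy sketch of that merge:

# Toy sketch of the mixin merge (pydantic v2): fields and model_config are
# combined across bases, so declared fields stay validated while unknown
# top-level keys are tolerated.
from pydantic import BaseModel, ConfigDict


class Loose(BaseModel):
    model_config = ConfigDict(extra="allow")


class Strict(BaseModel):
    num_epochs: int = 100


class Combined(Loose, Strict):
    pass


c = Combined(num_epochs=5, custom_key="kept")
print(c.num_epochs, c.model_extra)  # 5 {'custom_key': 'kept'}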
GANDLF/configuration/patch_sampler_config.py
@@ -0,0 +1,11 @@
+from pydantic import BaseModel, Field
+from typing_extensions import Literal
+
+TYPE_OPTIONS = Literal["uniform", "label"]
+
+
+class PatchSamplerConfig(BaseModel):
+    type: TYPE_OPTIONS = Field(default="uniform")
+    enable_padding: bool = Field(default=False)
+    padding_mode: str = Field(default="symmetric")
+    biased_sampling: bool = Field(default=False)
GANDLF/configuration/post_processing_config.py
@@ -0,0 +1,10 @@
+from pydantic import BaseModel, ConfigDict, Field
+from typing_extensions import Any
+
+
+class PostProcessingConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid", exclude_none=True)
+    fill_holes: Any = Field(default=None)
+    mapping: dict = Field(default=None)
+    morphology: Any = Field(default=None)
+    cca: Any = Field(default=None)
GANDLF/configuration/pre_processing_config.py
@@ -0,0 +1,94 @@
+from pydantic import BaseModel, ConfigDict, Field, AliasChoices, model_validator
+from typing_extensions import Any, Literal, Self
+
+
+class ThresholdConfig(BaseModel):
+    min: int = Field()
+    max: int = Field()
+
+
+class ClipConfig(BaseModel):
+    min: int = Field()
+    max: int = Field()
+
+
+class RescaleConfig(BaseModel):
+    in_min_max: list[float] = Field(default=[15, 125])
+    out_min_max: list[float] = Field(default=[0, 1])
+    percentiles: list[float] = Field(default=[5, 95])
+
+
+class HistogramMatchingConfig(BaseModel):
+    num_hist_level: int = Field(default=1024)
+    num_match_points: int = Field(default=16)
+    target: Any = Field(default=None)
+
+
+class ResampleMinConfig(BaseModel):
+    resolution: list[float] = Field(default=None)
+
+
+class ResampleConfig(BaseModel):
+    resolution: list[float] = Field(default=None)
+
+
+class StainNormalizationConfig(BaseModel):
+    target: Any = Field()
+    extractor: Literal["vahadane", "ruifrok", "macenko"] = Field(default="ruifrok")
+
+
+class PreProcessingConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid", exclude_none=True)
+    to_canonical: Any = Field(default=None)
+    threshold: ThresholdConfig = Field(default=None)
+    clip: ClipConfig = Field(default=None)
+    clamp: ClipConfig = Field(default=None)
+    crop_external_zero_planes: Any = Field(default=None)
+    crop: list[int] = Field(default=None)
+    centercrop: list[int] = Field(default=None)
+    normalize_by_val: Any = Field(default=None)
+    normalize_imagenet: Any = Field(default=None)
+    normalize_standardize: Any = Field(default=None)
+    normalize_div_by_255: Any = Field(default=None)
+    normalize: Any = Field(default=None)
+    normalize_nonZero: Any = Field(
+        default=None,
+        validation_alias=AliasChoices("normalize_nonZero", "normalize_nonzero"),
+    )
+    normalize_nonZero_masked: Any = Field(
+        default=None,
+        validation_alias=AliasChoices(
+            "normalize_nonZero_masked", "normalize_nonzero_masked"
+        ),
+    )
+    rescale: RescaleConfig = Field(default=None)
+    rgba2rgb: Any = Field(
+        default=None,
+        validation_alias=AliasChoices("rgba2rgb", "rgbatorgb", "rgba_to_rgb"),
+    )
+    rgb2rgba: Any = Field(
+        default=None,
+        validation_alias=AliasChoices("rgb2rgba", "rgbtorgba", "rgb_to_rgba"),
+    )
+    histogram_matching: HistogramMatchingConfig = Field(default=None)
+    histogram_equalization: HistogramMatchingConfig = Field(default=None)
+    adaptive_histogram_equalization: Any = Field(default=None)
+    resample: ResampleConfig = Field(default=None)
+    resize_image: list[int] = Field(
+        default=None,
+        validation_alias=AliasChoices(
+            "resize_image", "resize", "resize_image", "resize_images"
+        ),
+    )
+    resize_patch: list[int] = Field(default=None)
+    stain_normalization: StainNormalizationConfig = Field(default=None)
+    resample_min: ResampleMinConfig = Field(
+        default=None, validation_alias=AliasChoices("resample_min", "resample_minimum")
+    )
+
+    @model_validator(mode="after")
+    def pre_processing_validate(self) -> Self:
+        if self.adaptive_histogram_equalization is not None:
+            self.histogram_matching = HistogramMatchingConfig(target="adaptive")
+            self.adaptive_histogram_equalization = None
+        return self
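
The after-validator above folds a request for adaptive_histogram_equalization into histogram_matching with target="adaptive", while extra="forbid" rejects unknown preprocessing keys outright. A toy copy of that rewrite:

# Toy copy of the adaptive-histogram rewrite above (pydantic v2).
from typing_extensions import Any, Self
from pydantic import BaseModel, ConfigDict, Field, model_validator


class HistToy(BaseModel):
    target: Any = Field(default=None)


class PreToy(BaseModel):
    model_config = ConfigDict(extra="forbid")  # unknown keys raise a ValidationError
    histogram_matching: HistToy = Field(default=None)
    adaptive_histogram_equalization: Any = Field(default=None)

    @model_validator(mode="after")
    def fold(self) -> Self:
        if self.adaptive_histogram_equalization is not None:
            self.histogram_matching = HistToy(target="adaptive")
            self.adaptive_histogram_equalization = None
        return self


print(PreToy(adaptive_histogram_equalization=True).histogram_matching.target)  # adaptive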
GANDLF/configuration/scheduler_config.py
@@ -0,0 +1,92 @@
+from pydantic import BaseModel, ConfigDict, Field
+from typing_extensions import Literal, Union
+from GANDLF.schedulers import global_schedulers_dict
+
+TYPE_OPTIONS = Literal[tuple(global_schedulers_dict.keys())]
+
+
+class BaseTriangleConfig(BaseModel):
+    min_lr: float = Field(default=(10**-3))
+    max_lr: float = Field(default=1)
+    step_size: float = Field(description="step_size", default=None)
+
+
+class TriangleModifiedConfig(BaseModel):
+    min_lr: float = Field(default=0.000001)
+    max_lr: float = Field(default=0.001)
+    max_lr_multiplier: float = Field(default=1.0)
+    step_size: float = Field(description="step_size", default=None)
+
+
+class CyclicLrBaseConfig(BaseModel):
+    # More details https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CyclicLR.html
+    min_lr: float = Field(
+        default=None
+    )  # The default value is calculated according the learning rate * 0.001
+    max_lr: float = Field(default=None)  # calculate in the validation stage
+    gamma: float = Field(default=0.1)
+    scale_mode: Literal["cycle", "iterations"] = Field(default="cycle")
+    cycle_momentum: bool = Field(default=False)
+    base_momentum: float = Field(default=0.8)
+    max_momentum: float = Field(default=0.9)
+    step_size: float = Field(description="step_size", default=None)
+
+
+class ExpConfig(BaseModel):
+    gamma: float = Field(default=0.1)
+
+
+class StepConfig(BaseModel):
+    gamma: float = Field(default=0.1)
+    step_size: float = Field(description="step_size", default=None)
+
+
+class CosineannealingConfig(BaseModel):
+    # More details https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html
+    T_0: int = Field(default=5)
+    T_mult: float = Field(default=1)
+    min_lr: float = Field(default=0.001)
+
+
+class ReduceOnPlateauConfig(BaseModel):
+    # More details https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html
+    min_lr: Union[float, list] = Field(default=None)
+    gamma: float = Field(default=0.1)
+    mode: Literal["min", "max"] = Field(default="min")
+    factor: float = Field(default=0.1)
+    patience: int = Field(default=10)
+    threshold: float = Field(default=0.0001)
+    cooldown: int = Field(default=0)
+    threshold_mode: Literal["rel", "abs"] = Field(default="rel")
+
+
+class WarmupcosinescheduleConfig(BaseModel):
+    # More details https://docs.monai.io/en/stable/optimizers.html#monai.optimizers.WarmupCosineSchedule
+    warmup_steps: int = Field(default=None)
+
+
+# It allows extra parameters
+class SchedulerConfig(BaseModel):
+    model_config = ConfigDict(extra="allow")
+    type: TYPE_OPTIONS = Field(description="scheduler type")
+
+
+# Define the type and the scheduler base model class
+schedulers_dict_config = {
+    "triangle": BaseTriangleConfig,
+    "triangle_modified": TriangleModifiedConfig,
+    "triangular": CyclicLrBaseConfig,
+    "exp_range": CyclicLrBaseConfig,
+    "exp": ExpConfig,
+    "exponential": ExpConfig,
+    "step": StepConfig,
+    "reduce_on_plateau": ReduceOnPlateauConfig,
+    "reduce-on-plateau": ReduceOnPlateauConfig,
+    "plateau": ReduceOnPlateauConfig,
+    "reduceonplateau": ReduceOnPlateauConfig,
+    "cosineannealingwarmrestarts": CosineannealingConfig,
+    "cosineannealing": CosineannealingConfig,
+    "cosineannealinglr": CosineannealingConfig,
+    "warmupcosineschedule": WarmupcosinescheduleConfig,
+    "wcs": WarmupcosinescheduleConfig,
+}
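
TYPE_OPTIONS here, like ARCHITECTURE_OPTIONS and OPTIMIZER_OPTIONS earlier, is a Literal constructed at import time from a dict's keys, so whatever spellings GANDLF/schedulers registers (presumably the same aliases enumerated in schedulers_dict_config, such as "plateau" and "reduce-on-plateau") are the legal values. A toy sketch of the dynamic-Literal trick:

# Toy sketch of the dynamic-Literal trick: allowed values are computed at
# import time from a dict's keys (toy dict here, not global_schedulers_dict).
from typing_extensions import Literal
from pydantic import BaseModel, ValidationError

toy_schedulers = {"step": None, "plateau": None}
TYPE_OPTIONS = Literal[tuple(toy_schedulers.keys())]


class SchedToy(BaseModel):
    type: TYPE_OPTIONS


print(SchedToy(type="step").type)  # step
try:
    SchedToy(type="warmup")        # not a key of the dict
except ValidationError as e:
    print(e.error_count(), "error")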
GANDLF/configuration/user_defined_config.py
@@ -0,0 +1,131 @@
+from typing import Union
+from pydantic import BaseModel, model_validator, Field, AfterValidator
+from GANDLF.configuration.default_config import DefaultParameters
+from GANDLF.configuration.differential_privacy_config import DifferentialPrivacyConfig
+from GANDLF.configuration.nested_training_config import NestedTraining
+from GANDLF.configuration.optimizer_config import OptimizerConfig
+from GANDLF.configuration.patch_sampler_config import PatchSamplerConfig
+from GANDLF.configuration.scheduler_config import SchedulerConfig
+from GANDLF.utils import version_check
+from importlib.metadata import version
+from typing_extensions import Self, Literal, Annotated
+from GANDLF.configuration.validators import (
+    validate_scheduler,
+    validate_optimizer,
+    validate_loss_function,
+    validate_metrics,
+    validate_data_preprocessing,
+    validate_patch_size,
+    validate_parallel_compute_command,
+    validate_patch_sampler,
+    validate_data_augmentation,
+    validate_data_postprocessing_after_reverse_one_hot_encoding,
+    validate_differential_privacy,
+)
+from GANDLF.configuration.model_config import ModelConfig
+
+
+class Version(BaseModel):
+    minimum: str
+    maximum: str
+
+    @model_validator(mode="after")
+    def validate_version(self) -> Self:
+        if version_check(self.model_dump(), version_to_check=version("GANDLF")):
+            return self
+
+
+class InferenceMechanismConfig(BaseModel):
+    grid_aggregator_overlap: Literal["crop", "average"] = Field(default="crop")
+    patch_overlap: int = Field(default=0)
+
+
+class UserDefinedParameters(DefaultParameters):
+    version: Version = Field(
+        default=Version(minimum=version("GANDLF"), maximum=version("GANDLF")),
+        description="GANDLF version",
+    )
+    patch_size: Union[list[Union[int, float]], int, float] = Field(
+        description="Patch size."
+    )
+    model: ModelConfig = Field(description="The model to use. ")
+    modality: Literal["rad", "histo", "path"] = Field(description="Modality.")
+    loss_function: Annotated[
+        Union[dict, str],
+        Field(description="Loss function."),
+        AfterValidator(validate_loss_function),
+    ]
+    metrics: Annotated[
+        Union[dict, list[Union[str, dict, set]]],
+        Field(description="Metrics."),
+        AfterValidator(validate_metrics),
+    ]
+    nested_training: NestedTraining = Field(description="Nested training.")
+    parallel_compute_command: str = Field(
+        default="", description="Parallel compute command."
+    )
+    scheduler: Union[str, SchedulerConfig] = Field(
+        description="Scheduler.", default=SchedulerConfig(type="triangle_modified")
+    )
+    optimizer: Union[str, OptimizerConfig] = Field(
+        description="Optimizer.", default=OptimizerConfig(type="adam")
+    )
+    patch_sampler: Union[str, PatchSamplerConfig] = Field(
+        description="Patch sampler.", default=PatchSamplerConfig()
+    )
+    inference_mechanism: InferenceMechanismConfig = Field(
+        description="Inference mechanism.", default=InferenceMechanismConfig()
+    )
+    data_postprocessing_after_reverse_one_hot_encoding: dict = Field(
+        description="data_postprocessing_after_reverse_one_hot_encoding.", default={}
+    )
+    differential_privacy: Union[bool, DifferentialPrivacyConfig] = Field(
+        description="Differential privacy.", default=None
+    )
+    clip_mode: Literal["norm", "value"] = Field(
+        description="Clip mode.", default="norm"
+    )
+    data_preprocessing: Annotated[
+        dict,
+        Field(description="Data preprocessing."),
+        AfterValidator(validate_data_preprocessing),
+    ] = {}
+    data_augmentation: Annotated[dict, Field(description="Data augmentation.")] = {}
+
+    # Validators
+    @model_validator(mode="after")
+    def validate(self) -> Self:
+        # validate the patch_size
+        self.patch_size, self.model.dimension = validate_patch_size(
+            self.patch_size, self.model.dimension
+        )
+        # validate the parallel_compute_command
+        self.parallel_compute_command = validate_parallel_compute_command(
+            self.parallel_compute_command
+        )
+        # validate scheduler
+        self.scheduler = validate_scheduler(
+            self.scheduler, self.learning_rate, self.num_epochs
+        )
+        # validate optimizer
+        self.optimizer = validate_optimizer(self.optimizer)
+        # validate patch_sampler
+        self.patch_sampler = validate_patch_sampler(self.patch_sampler)
+        # validate_data_augmentation
+        self.data_augmentation = validate_data_augmentation(
+            self.data_augmentation, self.patch_size
+        )
+        # validate data_postprocessing_after_reverse_one_hot_encoding
+        (
+            self.data_postprocessing_after_reverse_one_hot_encoding,
+            self.data_postprocessing,
+        ) = validate_data_postprocessing_after_reverse_one_hot_encoding(
+            self.data_postprocessing_after_reverse_one_hot_encoding,
+            self.data_postprocessing,
+        )
+        # validate differential_privacy
+        self.differential_privacy = validate_differential_privacy(
+            self.differential_privacy, self.batch_size
+        )
+
+        return self
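
Taken together, Parameters (in parameters_config.py) is the entry point for these models. An illustrative, untested construction, assuming this wheel and its dependencies are installed; field names follow the declarations above, while values like "unet", "dc", and "dice" are assumptions about what the GANDLF model/loss/metric registries accept, and the validators in GANDLF/configuration/validators.py may impose further requirements:

# Illustrative, untested end-to-end use of the new configuration package.
import yaml
from GANDLF.configuration.parameters_config import Parameters

config = yaml.safe_load(
    """
    modality: rad
    patch_size: [64, 64, 64]
    model:
      dimension: 3
      architecture: unet
      final_layer: softmax
      class_list: [0, 1]
    loss_function: dc
    metrics: [dice]
    nested_training:
      testing: -5
      validation: -5
    """
)
params = Parameters(**config)
print(params.optimizer)  # defaults to OptimizerConfig(type="adam")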