GANDLF 0.1.3.dev20250202__py3-none-any.whl → 0.1.6.dev20251109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of GANDLF might be problematic. Click here for more details.

Files changed (57)
  1. GANDLF/cli/deploy.py +2 -2
  2. GANDLF/cli/generate_metrics.py +35 -1
  3. GANDLF/cli/main_run.py +4 -10
  4. GANDLF/compute/__init__.py +0 -2
  5. GANDLF/compute/forward_pass.py +0 -1
  6. GANDLF/compute/generic.py +107 -2
  7. GANDLF/compute/inference_loop.py +4 -4
  8. GANDLF/compute/loss_and_metric.py +1 -2
  9. GANDLF/compute/training_loop.py +10 -10
  10. GANDLF/config_manager.py +33 -717
  11. GANDLF/configuration/__init__.py +0 -0
  12. GANDLF/configuration/default_config.py +73 -0
  13. GANDLF/configuration/differential_privacy_config.py +16 -0
  14. GANDLF/configuration/exclude_parameters.py +1 -0
  15. GANDLF/configuration/model_config.py +82 -0
  16. GANDLF/configuration/nested_training_config.py +25 -0
  17. GANDLF/configuration/optimizer_config.py +121 -0
  18. GANDLF/configuration/parameters_config.py +10 -0
  19. GANDLF/configuration/patch_sampler_config.py +11 -0
  20. GANDLF/configuration/post_processing_config.py +10 -0
  21. GANDLF/configuration/pre_processing_config.py +94 -0
  22. GANDLF/configuration/scheduler_config.py +92 -0
  23. GANDLF/configuration/user_defined_config.py +131 -0
  24. GANDLF/configuration/utils.py +96 -0
  25. GANDLF/configuration/validators.py +479 -0
  26. GANDLF/data/__init__.py +14 -16
  27. GANDLF/data/lightning_datamodule.py +119 -0
  28. GANDLF/entrypoints/run.py +36 -31
  29. GANDLF/inference_manager.py +69 -25
  30. GANDLF/losses/__init__.py +23 -1
  31. GANDLF/losses/loss_calculators.py +79 -0
  32. GANDLF/losses/segmentation.py +3 -2
  33. GANDLF/metrics/__init__.py +26 -0
  34. GANDLF/metrics/generic.py +1 -1
  35. GANDLF/metrics/metric_calculators.py +102 -0
  36. GANDLF/metrics/panoptica_config_brats.yaml +56 -0
  37. GANDLF/metrics/segmentation_panoptica.py +49 -0
  38. GANDLF/models/__init__.py +8 -3
  39. GANDLF/models/lightning_module.py +2102 -0
  40. GANDLF/optimizers/__init__.py +4 -8
  41. GANDLF/privacy/opacus/opacus_anonymization_manager.py +243 -0
  42. GANDLF/schedulers/__init__.py +11 -4
  43. GANDLF/schedulers/wrap_torch.py +15 -3
  44. GANDLF/training_manager.py +160 -50
  45. GANDLF/utils/__init__.py +5 -3
  46. GANDLF/utils/imaging.py +176 -35
  47. GANDLF/utils/modelio.py +12 -8
  48. GANDLF/utils/pred_target_processors.py +71 -0
  49. GANDLF/utils/tensor.py +2 -1
  50. GANDLF/utils/write_parse.py +1 -1
  51. GANDLF/version.py +1 -1
  52. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/METADATA +16 -11
  53. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/RECORD +57 -34
  54. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/WHEEL +1 -1
  55. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/entry_points.txt +0 -0
  56. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info/licenses}/LICENSE +0 -0
  57. {GANDLF-0.1.3.dev20250202.dist-info → gandlf-0.1.6.dev20251109.dist-info}/top_level.txt +0 -0
File without changes
@@ -0,0 +1,73 @@
1
+ from pydantic import BaseModel, Field, AfterValidator
2
+ from typing import Dict
3
+ from typing_extensions import Literal, Optional, Annotated
4
+
5
+ from GANDLF.configuration.validators import validate_postprocessing
6
+
7
# Overlap strategies accepted for grid-based patch aggregation.
GRID_AGGREGATOR_OVERLAP_OPTIONS = Literal[
    "crop",
    "average",
    "hann",
]
8
+
9
+
10
# Application-wide default parameters. UserDefinedParameters subclasses this
# model, so every field below has a default and may be omitted by the user.
class DefaultParameters(BaseModel):
    weighted_loss: Annotated[
        bool, Field(description="Whether weighted loss is to be used or not.")
    ] = False
    verbose: Annotated[bool, Field(description="General application verbosity.")] = False
    q_verbose: Annotated[bool, Field(description="Queue construction verbosity.")] = False
    medcam_enabled: Annotated[
        bool, Field(description="Enable interpretability via medcam.")
    ] = False
    save_training: Annotated[
        bool, Field(description="Save outputs during training.")
    ] = False
    save_output: Annotated[
        bool, Field(description="Save outputs during validation/testing.")
    ] = False
    in_memory: Annotated[bool, Field(description="Pin data to CPU memory.")] = False
    # NOTE(review): DataLoader `pin_memory` normally pins host (page-locked) memory
    # to speed up GPU transfer; the "GPU memory" wording below may be imprecise.
    pin_memory_dataloader: Annotated[
        bool, Field(description="Pin data to GPU memory.")
    ] = False
    scaling_factor: Annotated[
        int, Field(description="Scaling factor for regression problems.")
    ] = 1
    # Queue settings.
    q_max_length: Annotated[int, Field(description="The max length of the queue.")] = 100
    q_samples_per_volume: Annotated[
        int, Field(description="Number of samples per volume.")
    ] = 10
    q_num_workers: Annotated[
        int, Field(description="Number of worker threads to use.")
    ] = 0
    # Training loop settings.
    num_epochs: Annotated[int, Field(description="Total number of epochs to train.")] = 100
    patience: Annotated[
        int, Field(description="Number of epochs to wait for performance improvement.")
    ] = 100
    batch_size: Annotated[int, Field(description="Default batch size for training.")] = 1
    learning_rate: Annotated[float, Field(description="Default learning rate.")] = 0.001
    clip_grad: Annotated[
        Optional[float], Field(description="Gradient clipping value.")
    ] = None
    track_memory_usage: Annotated[
        bool, Field(description="Enable memory usage tracking.")
    ] = False
    memory_save_mode: Annotated[
        bool,
        Field(
            description="Enable memory-saving mode. If enabled, resize/resample will save files to disk."
        ),
    ] = False
    print_rgb_label_warning: Annotated[
        bool, Field(description="Print a warning for RGB labels.")
    ] = True
    # The AfterValidator runs only on user-supplied values; the empty default is
    # not validated (validate_default is not set).
    data_postprocessing: Annotated[
        dict,
        Field(description="Default data postprocessing configuration."),
        AfterValidator(validate_postprocessing),
    ] = {}

    grid_aggregator_overlap: Annotated[
        GRID_AGGREGATOR_OVERLAP_OPTIONS,
        Field(description="Default grid aggregator overlap strategy."),
    ] = "crop"
    determinism: Annotated[
        bool, Field(description="Enable deterministic computation.")
    ] = False
    previous_parameters: Annotated[
        Optional[Dict],
        Field(
            description="Previous parameters to be used for resuming training and performing sanity checks."
        ),
    ] = None
@@ -0,0 +1,16 @@
1
+ from typing_extensions import Literal
2
+
3
+ from pydantic import BaseModel, Field, ConfigDict
4
+
5
# Privacy accountants that may be selected for differential privacy.
ACCOUNTANT_OPTIONS = Literal[
    "rdp",
    "gdp",
    "prv",
]
6
+
7
+
8
# Differential-privacy options. Unknown keys are kept (extra="allow").
class DifferentialPrivacyConfig(BaseModel):
    model_config = ConfigDict(extra="allow")
    noise_multiplier: float = 10.0
    max_grad_norm: float = 1.0
    accountant: ACCOUNTANT_OPTIONS = "rdp"
    secure_mode: bool = False
    allow_opacus_model_fix: bool = True
    delta: float = 1e-5
    # No default is given, so this field is required; `validate_default` has
    # nothing to act on here.
    physical_batch_size: int = Field(validate_default=True)
@@ -0,0 +1 @@
1
# Parameter names collected for exclusion; how callers apply this set is not
# visible in this module.
exclude_parameters = {
    "differential_privacy",
}
@@ -0,0 +1,82 @@
1
+ from pydantic import BaseModel, model_validator, Field, AliasChoices, ConfigDict
2
+ from typing_extensions import Self, Literal, Optional
3
+ from typing import Union
4
+ from GANDLF.configuration.validators import validate_class_list, validate_norm_type
5
+ from GANDLF.models import global_models_dict
6
+
7
# Allowed architectures: every name registered in the global model registry
# (iterating a dict yields its keys in insertion order).
ARCHITECTURE_OPTIONS = Literal[tuple(global_models_dict)]
9
# Normalization types a model may request.
NORM_TYPE_OPTIONS = Literal[("batch", "instance", "none")]
# Accepted final-layer names; both "None" and "none" spellings are allowed.
FINAL_LAYER_OPTIONS = Literal[
    (
        "sigmoid",
        "softmax",
        "logsoftmax",
        "tanh",
        "identity",
        "logits",
        "regression",
        "None",
        "none",
    )
]
# Model backend types.
TYPE_OPTIONS = Literal[("torch", "openvino")]
# Supported spatial dimensions of the model input.
DIMENSIONS_OPTIONS = Literal[(2, 3)]
25
+
26
+
27
# You can define new parameters for model here. Please read the pydantic documentation.
# It allows extra fields in model dict.
class ModelConfig(BaseModel):
    model_config = ConfigDict(
        extra="allow"
    )  # it allows extra fields in the model dict
    # Defaults to None so the key may be omitted. UserDefinedParameters.validate
    # passes it through `validate_patch_size`, which returns the dimension to use
    # (presumably inferred from the patch size when None - confirm).
    dimension: Optional[DIMENSIONS_OPTIONS] = Field(
        default=None, description="model input dimension (2D or 3D)."
    )
    architecture: ARCHITECTURE_OPTIONS = Field(description="Architecture.")
    final_layer: FINAL_LAYER_OPTIONS = Field(description="Final layer.")
    norm_type: Optional[NORM_TYPE_OPTIONS] = Field(
        description="Normalization type.", default="batch"
    )  # TODO: check it again
    # None means "not provided"; the validator below substitutes 32.
    base_filters: Optional[int] = Field(
        description="Base filters.", default=None, validate_default=True
    )
    class_list: Union[list, str] = Field(default=[], description="Class list.")
    num_channels: Optional[int] = Field(
        description="Number of channels.",
        validation_alias=AliasChoices(
            "num_channels", "n_channels", "channels", "model_channels"
        ),
        default=3,
    )  # TODO: check it
    type: TYPE_OPTIONS = Field(description="Type of model.", default="torch")
    data_type: str = Field(description="Data type.", default="FP32")
    save_at_every_epoch: bool = Field(default=False, description="Save at every epoch.")
    amp: bool = Field(default=False, description="Automatic mixed precision")
    ignore_label_validation: Union[int, None] = Field(
        default=None, description="Ignore label validation."
    )  # TODO: To check it
    print_summary: bool = Field(default=True, description="Print summary.")

    # Deliberately NOT named `model_validate`: pydantic replaces a decorated
    # validator with its plain function on the class, so that name would shadow
    # the `BaseModel.model_validate` classmethod.
    @model_validator(mode="after")
    def validate_model_config(self) -> Self:
        """Normalize class_list/norm_type, print notices, and fill base_filters.

        Returns:
            The validated model (pydantic "after" validators must return self).
        """
        # TODO: Change the print to logging.warnings
        # init and validate the class_list parameter
        self.class_list = validate_class_list(self.class_list)
        # init and validate the norm type (may depend on the architecture)
        self.norm_type = validate_norm_type(self.norm_type, self.architecture)
        if not self.amp:
            print("NOT using Mixed Precision Training")

        if self.save_at_every_epoch:
            print(
                "WARNING: 'save_at_every_epoch' will result in TREMENDOUS storage usage; use at your own risk."
            )  # TODO: It is better to use logging.warning

        if self.base_filters is None:
            self.base_filters = 32
            print("Using default 'base_filters' in 'model': ", self.base_filters)

        return self
@@ -0,0 +1,25 @@
1
+ from pydantic import BaseModel, Field, model_validator
2
+ from typing_extensions import Self, Optional
3
+
4
+
5
# Cross-validation (nested training) settings.
class NestedTraining(BaseModel):
    stratified: bool = Field(
        default=False,
        description="this will perform stratified k-fold cross-validation but only with offline data splitting",
    )
    testing: int = Field(
        default=-5,
        description="this controls the number of testing data folds for final model evaluation; [NOT recommended] to disable this, use '1'",
        le=10,
    )
    validation: int = Field(
        default=-5,
        description="this controls the number of validation data folds to be used for model *selection* during training (not used for back-propagation)",
    )
    # When set, this overrides `stratified`. It defaults to None rather than
    # False: with a False default the validator below always overwrote
    # `stratified`, silently discarding a user's `stratified: true`.
    proportional: Optional[bool] = Field(default=None)

    @model_validator(mode="after")
    def validate_nested_training(self) -> Self:
        """Copy `proportional` into `stratified` only when it was provided."""
        if self.proportional is not None:
            self.stratified = self.proportional
        return self
@@ -0,0 +1,121 @@
1
+ from typing import Tuple
2
+
3
+ from pydantic import BaseModel, Field, ConfigDict
4
+ from typing_extensions import Literal
5
+
6
+ from GANDLF.optimizers import global_optimizer_dict
7
+
8
# Every optimizer name registered in the global optimizer dictionary
# (iterating the dict yields its keys in insertion order).
OPTIMIZER_OPTIONS = Literal[tuple(global_optimizer_dict)]
10
+
11
+
12
# Default keyword arguments for the "sgd" optimizer entry.
class SgdConfig(BaseModel):
    momentum: float = 0.99
    weight_decay: float = 3e-05
    dampening: float = 0
    nesterov: bool = True
17
+
18
+
19
# Default keyword arguments for the "asgd" optimizer entry.
class AsgdConfig(BaseModel):
    alpha: float = 0.75
    t0: float = 1e6
    lambd: float = 1e-4
    weight_decay: float = 3e-05
24
+
25
+
26
# Default keyword arguments shared by the "adam" and "adamw" optimizer entries.
class AdamConfig(BaseModel):
    betas: Tuple[float, float] = (0.9, 0.999)
    weight_decay: float = 0.00005
    eps: float = 1e-8
    amsgrad: bool = False
31
+
32
+
33
# Default keyword arguments for the "adamax" optimizer entry.
class AdamaxConfig(BaseModel):
    betas: Tuple[float, float] = (0.9, 0.999)
    weight_decay: float = 0.00005
    eps: float = 1e-8
37
+
38
+
39
# Default keyword arguments for the "rprop" optimizer entry.
class RpropConfig(BaseModel):
    etas: Tuple[float, float] = (0.5, 1.2)
    step_sizes: Tuple[float, float] = (1e-6, 50)
42
+
43
+
44
# Default keyword arguments for the "adadelta" optimizer entry.
class AdadeltaConfig(BaseModel):
    rho: float = 0.9
    eps: float = 1e-6
    weight_decay: float = 3e-05
48
+
49
+
50
# Default keyword arguments for the "adagrad" optimizer entry.
class AdagradConfig(BaseModel):
    lr_decay: float = 0
    eps: float = 1e-6
    weight_decay: float = 3e-05
54
+
55
+
56
# Default keyword arguments for the "rmsprop" optimizer entry.
class RmspropConfig(BaseModel):
    alpha: float = 0.99
    eps: float = 1e-8
    centered: bool = False
    momentum: float = 0
    weight_decay: float = 3e-05
62
+
63
+
64
# Default keyword arguments for the "radam" optimizer entry.
class RadamConfig(BaseModel):
    betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 3e-05
    # Default None is not validated against `bool`; an explicit user value must be a bool.
    foreach: bool = None
69
+
70
+
71
# Default keyword arguments for the "nadam" optimizer entry.
class NadamConfig(BaseModel):
    betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 3e-05
    # Default None is not validated against `bool`; an explicit user value must be a bool.
    foreach: bool = None
76
+
77
+
78
# Default keyword arguments for the "novograd" optimizer entry.
class NovogradConfig(BaseModel):
    betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 3e-05
    amsgrad: bool = False
83
+
84
+
85
# The "ademamix" optimizer entry declares no default keyword arguments.
class AdemamixConfig(BaseModel):
    pass
87
+
88
+
89
# Default keyword arguments for the "lion" optimizer entry.
class LionConfig(BaseModel):
    betas: Tuple[float, float] = (0.9, 0.999)
    weight_decay: float = 0.0
    decoupled_weight_decay: bool = False
93
+
94
+
95
# The "adopt" optimizer entry declares no default keyword arguments.
class AdoptConfig(BaseModel):
    pass
97
+
98
+
99
# Optimizer section: `type` must be a registered optimizer name. Any other
# keys are retained (extra="allow").
class OptimizerConfig(BaseModel):
    model_config = ConfigDict(extra="allow")
    type: OPTIMIZER_OPTIONS = Field(description="Type of optimizer to use")
102
+
103
+
104
# Maps each optimizer name to the model holding its default keyword arguments.
# "adam" and "adamw" share AdamConfig; "sparseadam" is intentionally absent.
optimizer_dict_config = dict(
    sgd=SgdConfig,
    asgd=AsgdConfig,
    adam=AdamConfig,
    adamw=AdamConfig,
    adamax=AdamaxConfig,
    rprop=RpropConfig,
    adadelta=AdadeltaConfig,
    adagrad=AdagradConfig,
    rmsprop=RmspropConfig,
    radam=RadamConfig,
    novograd=NovogradConfig,
    nadam=NadamConfig,
    ademamix=AdemamixConfig,
    lion=LionConfig,
    adopt=AdoptConfig,
)
@@ -0,0 +1,10 @@
1
+ from pydantic import BaseModel, ConfigDict
2
+ from GANDLF.configuration.user_defined_config import UserDefinedParameters
3
+
4
+
5
# Base model whose only job is to keep undeclared top-level keys (extra="allow").
class ParametersConfiguration(BaseModel):
    model_config = ConfigDict(extra="allow")
7
+
8
+
9
# Full parameter model: the typed user-defined fields plus the extra="allow"
# behaviour inherited from ParametersConfiguration.
class Parameters(ParametersConfiguration, UserDefinedParameters):
    pass
@@ -0,0 +1,11 @@
1
+ from pydantic import BaseModel, Field
2
+ from typing_extensions import Literal
3
+
4
# Patch sampler kinds.
TYPE_OPTIONS = Literal[
    "uniform",
    "label",
]
5
+
6
+
7
# Patch sampling options.
class PatchSamplerConfig(BaseModel):
    type: TYPE_OPTIONS = "uniform"
    enable_padding: bool = False
    padding_mode: str = "symmetric"
    biased_sampling: bool = False
@@ -0,0 +1,10 @@
1
+ from pydantic import BaseModel, ConfigDict, Field
2
+ from typing_extensions import Any
3
+
4
+
5
# Allowed data post-processing steps; unknown keys are rejected.
class PostProcessingConfig(BaseModel):
    # `exclude_none` is not a pydantic ConfigDict option (it is an argument of
    # `model_dump`), so it is not set here. Callers that want None-valued steps
    # dropped must call `model_dump(exclude_none=True)`.
    model_config = ConfigDict(extra="forbid")
    fill_holes: Any = Field(default=None)
    mapping: dict = Field(default=None)
    morphology: Any = Field(default=None)
    cca: Any = Field(default=None)
@@ -0,0 +1,94 @@
1
+ from pydantic import BaseModel, ConfigDict, Field, AliasChoices, model_validator
2
+ from typing_extensions import Any, Literal, Self
3
+
4
+
5
# Threshold bounds for the "threshold" pre-processing step; both are required.
class ThresholdConfig(BaseModel):
    min: int
    max: int
8
+
9
+
10
# Bounds used by both the "clip" and "clamp" pre-processing steps; both are required.
class ClipConfig(BaseModel):
    min: int
    max: int
13
+
14
+
15
# Options for the "rescale" pre-processing step.
class RescaleConfig(BaseModel):
    in_min_max: list[float] = [15, 125]
    out_min_max: list[float] = [0, 1]
    percentiles: list[float] = [5, 95]
19
+
20
+
21
# Options for histogram matching/equalization. PreProcessingConfig sets
# target="adaptive" when adaptive equalization is requested.
class HistogramMatchingConfig(BaseModel):
    num_hist_level: int = 1024
    num_match_points: int = 16
    target: Any = None
25
+
26
+
27
# Options for the "resample_min" pre-processing step.
class ResampleMinConfig(BaseModel):
    resolution: list[float] = None
29
+
30
+
31
# Options for the "resample" pre-processing step.
class ResampleConfig(BaseModel):
    resolution: list[float] = None
33
+
34
+
35
# Stain normalization options; `target` is required.
class StainNormalizationConfig(BaseModel):
    target: Any
    extractor: Literal["vahadane", "ruifrok", "macenko"] = "ruifrok"
38
+
39
+
40
# Allowed data pre-processing steps. Unknown keys are rejected (extra="forbid").
# Several steps accept alternate spellings via `validation_alias`.
class PreProcessingConfig(BaseModel):
    # `exclude_none` is not a pydantic ConfigDict option (it is an argument of
    # `model_dump`), so it is not set here. Callers that want unset steps
    # dropped must call `model_dump(exclude_none=True)`.
    model_config = ConfigDict(extra="forbid")
    to_canonical: Any = None
    threshold: ThresholdConfig = None
    clip: ClipConfig = None
    clamp: ClipConfig = None
    crop_external_zero_planes: Any = None
    crop: list[int] = None
    centercrop: list[int] = None
    normalize_by_val: Any = None
    normalize_imagenet: Any = None
    normalize_standardize: Any = None
    normalize_div_by_255: Any = None
    normalize: Any = None
    normalize_nonZero: Any = Field(
        default=None,
        validation_alias=AliasChoices("normalize_nonZero", "normalize_nonzero"),
    )
    normalize_nonZero_masked: Any = Field(
        default=None,
        validation_alias=AliasChoices(
            "normalize_nonZero_masked", "normalize_nonzero_masked"
        ),
    )
    rescale: RescaleConfig = None
    rgba2rgb: Any = Field(
        default=None,
        validation_alias=AliasChoices("rgba2rgb", "rgbatorgb", "rgba_to_rgb"),
    )
    rgb2rgba: Any = Field(
        default=None,
        validation_alias=AliasChoices("rgb2rgba", "rgbtorgba", "rgb_to_rgba"),
    )
    histogram_matching: HistogramMatchingConfig = None
    histogram_equalization: HistogramMatchingConfig = None
    adaptive_histogram_equalization: Any = None
    resample: ResampleConfig = None
    # The duplicate "resize_image" alias choice was dropped; alias resolution is
    # unchanged because the first matching choice wins.
    resize_image: list[int] = Field(
        default=None,
        validation_alias=AliasChoices("resize_image", "resize", "resize_images"),
    )
    resize_patch: list[int] = None
    stain_normalization: StainNormalizationConfig = None
    resample_min: ResampleMinConfig = Field(
        default=None, validation_alias=AliasChoices("resample_min", "resample_minimum")
    )

    @model_validator(mode="after")
    def pre_processing_validate(self) -> Self:
        """Translate `adaptive_histogram_equalization` into histogram matching.

        Any user-provided `histogram_matching` is replaced when adaptive
        equalization is requested.
        """
        if self.adaptive_histogram_equalization is not None:
            self.histogram_matching = HistogramMatchingConfig(target="adaptive")
            self.adaptive_histogram_equalization = None
        return self
@@ -0,0 +1,92 @@
1
+ from pydantic import BaseModel, ConfigDict, Field
2
+ from typing_extensions import Literal, Union
3
+ from GANDLF.schedulers import global_schedulers_dict
4
+
5
# Every scheduler name registered in the global scheduler dictionary
# (iterating the dict yields its keys in insertion order).
TYPE_OPTIONS = Literal[tuple(global_schedulers_dict)]
6
+
7
+
8
# Defaults for the "triangle" scheduler.
class BaseTriangleConfig(BaseModel):
    min_lr: float = 10**-3
    max_lr: float = 1
    step_size: float = Field(default=None, description="step_size")
12
+
13
+
14
# Defaults for the "triangle_modified" scheduler.
class TriangleModifiedConfig(BaseModel):
    min_lr: float = 0.000001
    max_lr: float = 0.001
    max_lr_multiplier: float = 1.0
    step_size: float = Field(default=None, description="step_size")
19
+
20
+
21
# Defaults shared by the "triangular" and "exp_range" cyclic schedulers.
class CyclicLrBaseConfig(BaseModel):
    # More details https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CyclicLR.html
    # Both learning-rate bounds start as None. Per the original notes, min_lr is
    # later derived from learning rate * 0.001 and max_lr is computed in the
    # validation stage (not visible here).
    min_lr: float = None
    max_lr: float = None
    gamma: float = 0.1
    scale_mode: Literal["cycle", "iterations"] = "cycle"
    cycle_momentum: bool = False
    base_momentum: float = 0.8
    max_momentum: float = 0.9
    step_size: float = Field(default=None, description="step_size")
33
+
34
+
35
# Defaults for the "exp"/"exponential" schedulers.
class ExpConfig(BaseModel):
    gamma: float = 0.1
37
+
38
+
39
# Defaults for the "step" scheduler.
class StepConfig(BaseModel):
    gamma: float = 0.1
    step_size: float = Field(default=None, description="step_size")
42
+
43
+
44
# Defaults for the cosine-annealing scheduler family.
class CosineannealingConfig(BaseModel):
    # More details https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html
    T_0: int = Field(default=5)
    # CosineAnnealingWarmRestarts requires an integer T_mult >= 1. A `float`
    # annotation coerced user input such as `2` into `2.0`, which torch rejects.
    # Integral floats like `2.0` are still accepted and converted to `2`.
    T_mult: int = Field(default=1)
    min_lr: float = Field(default=0.001)
49
+
50
+
51
# Defaults for the reduce-on-plateau scheduler aliases.
class ReduceOnPlateauConfig(BaseModel):
    # More details https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html
    # NOTE(review): both `gamma` and `factor` are declared; confirm which one the
    # scheduler wrapper actually reads.
    min_lr: Union[float, list] = None
    gamma: float = 0.1
    mode: Literal["min", "max"] = "min"
    factor: float = 0.1
    patience: int = 10
    threshold: float = 0.0001
    cooldown: int = 0
    threshold_mode: Literal["rel", "abs"] = "rel"
61
+
62
+
63
# Defaults for the "warmupcosineschedule"/"wcs" scheduler.
class WarmupcosinescheduleConfig(BaseModel):
    # More details https://docs.monai.io/en/stable/optimizers.html#monai.optimizers.WarmupCosineSchedule
    warmup_steps: int = None
66
+
67
+
68
# Scheduler section: `type` must be a registered scheduler name. Any other
# scheduler-specific keys are retained (extra="allow").
class SchedulerConfig(BaseModel):
    model_config = ConfigDict(extra="allow")
    type: TYPE_OPTIONS = Field(description="scheduler type")
72
+
73
+
74
# Maps every accepted scheduler name (including aliases) to the model holding its
# defaults. Groups are listed in the original key order, so the resulting dict's
# insertion order is unchanged.
schedulers_dict_config = {
    scheduler_name: config_model
    for config_model, scheduler_names in (
        (BaseTriangleConfig, ("triangle",)),
        (TriangleModifiedConfig, ("triangle_modified",)),
        (CyclicLrBaseConfig, ("triangular", "exp_range")),
        (ExpConfig, ("exp", "exponential")),
        (StepConfig, ("step",)),
        (
            ReduceOnPlateauConfig,
            ("reduce_on_plateau", "reduce-on-plateau", "plateau", "reduceonplateau"),
        ),
        (
            CosineannealingConfig,
            ("cosineannealingwarmrestarts", "cosineannealing", "cosineannealinglr"),
        ),
        (WarmupcosinescheduleConfig, ("warmupcosineschedule", "wcs")),
    )
    for scheduler_name in scheduler_names
}
@@ -0,0 +1,131 @@
1
+ from typing import Union
2
+ from pydantic import BaseModel, model_validator, Field, AfterValidator
3
+ from GANDLF.configuration.default_config import DefaultParameters
4
+ from GANDLF.configuration.differential_privacy_config import DifferentialPrivacyConfig
5
+ from GANDLF.configuration.nested_training_config import NestedTraining
6
+ from GANDLF.configuration.optimizer_config import OptimizerConfig
7
+ from GANDLF.configuration.patch_sampler_config import PatchSamplerConfig
8
+ from GANDLF.configuration.scheduler_config import SchedulerConfig
9
+ from GANDLF.utils import version_check
10
+ from importlib.metadata import version
11
+ from typing_extensions import Self, Literal, Annotated
12
+ from GANDLF.configuration.validators import (
13
+ validate_scheduler,
14
+ validate_optimizer,
15
+ validate_loss_function,
16
+ validate_metrics,
17
+ validate_data_preprocessing,
18
+ validate_patch_size,
19
+ validate_parallel_compute_command,
20
+ validate_patch_sampler,
21
+ validate_data_augmentation,
22
+ validate_data_postprocessing_after_reverse_one_hot_encoding,
23
+ validate_differential_privacy,
24
+ )
25
+ from GANDLF.configuration.model_config import ModelConfig
26
+
27
+
28
# Range of GANDLF versions a configuration declares itself compatible with.
class Version(BaseModel):
    minimum: str
    maximum: str

    @model_validator(mode="after")
    def validate_version(self) -> Self:
        """Check the installed GANDLF version against [minimum, maximum].

        Returns:
            self, as required of a pydantic "after" model validator.

        Raises:
            ValueError: if `version_check` reports the installed version as
                incompatible. Previously the method fell through and returned
                None in that case instead of failing.
        """
        installed_version = version("GANDLF")
        if not version_check(self.model_dump(), version_to_check=installed_version):
            raise ValueError(
                f"Installed GANDLF version {installed_version} is outside the "
                f"configured range [{self.minimum}, {self.maximum}]."
            )
        return self
36
+
37
+
38
# Patch-aggregation settings used at inference time.
class InferenceMechanismConfig(BaseModel):
    # "hann" is accepted as well, matching the options offered for the top-level
    # `grid_aggregator_overlap` (GRID_AGGREGATOR_OVERLAP_OPTIONS).
    grid_aggregator_overlap: Literal["crop", "average", "hann"] = Field(default="crop")
    patch_overlap: int = Field(default=0)
41
+
42
+
43
# Parameters read from the user's configuration, layered on top of
# DefaultParameters. The fields without a default are required: patch_size,
# model, modality, loss_function, metrics, nested_training.
class UserDefinedParameters(DefaultParameters):
    version: Version = Field(
        default=Version(minimum=version("GANDLF"), maximum=version("GANDLF")),
        description="GANDLF version",
    )
    patch_size: Union[list[Union[int, float]], int, float] = Field(
        description="Patch size."
    )
    model: ModelConfig = Field(description="The model to use. ")
    modality: Literal["rad", "histo", "path"] = Field(description="Modality.")
    loss_function: Annotated[
        Union[dict, str],
        Field(description="Loss function."),
        AfterValidator(validate_loss_function),
    ]
    metrics: Annotated[
        Union[dict, list[Union[str, dict, set]]],
        Field(description="Metrics."),
        AfterValidator(validate_metrics),
    ]
    nested_training: NestedTraining = Field(description="Nested training.")
    parallel_compute_command: str = Field(
        default="", description="Parallel compute command."
    )
    scheduler: Union[str, SchedulerConfig] = Field(
        description="Scheduler.", default=SchedulerConfig(type="triangle_modified")
    )
    optimizer: Union[str, OptimizerConfig] = Field(
        description="Optimizer.", default=OptimizerConfig(type="adam")
    )
    patch_sampler: Union[str, PatchSamplerConfig] = Field(
        description="Patch sampler.", default=PatchSamplerConfig()
    )
    inference_mechanism: InferenceMechanismConfig = Field(
        description="Inference mechanism.", default=InferenceMechanismConfig()
    )
    data_postprocessing_after_reverse_one_hot_encoding: dict = Field(
        description="data_postprocessing_after_reverse_one_hot_encoding.", default={}
    )
    # None is part of the type so an explicit `differential_privacy: null` is
    # accepted like an omitted key. In both cases validate_differential_privacy
    # receives None.
    differential_privacy: Union[bool, DifferentialPrivacyConfig, None] = Field(
        description="Differential privacy.", default=None
    )
    clip_mode: Literal["norm", "value"] = Field(
        description="Clip mode.", default="norm"
    )
    data_preprocessing: Annotated[
        dict,
        Field(description="Data preprocessing."),
        AfterValidator(validate_data_preprocessing),
    ] = {}
    data_augmentation: Annotated[dict, Field(description="Data augmentation.")] = {}

    # Validators
    # NOTE(review): the name `validate` shadows pydantic's deprecated
    # `BaseModel.validate` classmethod on this class and its subclasses.
    @model_validator(mode="after")
    def validate(self) -> Self:
        """Run the cross-field validators once all fields are parsed.

        Order matters: patch_size (and model.dimension) are normalized first
        because data_augmentation validation consumes the normalized patch_size.
        """
        # validate the patch_size
        self.patch_size, self.model.dimension = validate_patch_size(
            self.patch_size, self.model.dimension
        )
        # validate the parallel_compute_command
        self.parallel_compute_command = validate_parallel_compute_command(
            self.parallel_compute_command
        )
        # validate scheduler
        self.scheduler = validate_scheduler(
            self.scheduler, self.learning_rate, self.num_epochs
        )
        # validate optimizer
        self.optimizer = validate_optimizer(self.optimizer)
        # validate patch_sampler
        self.patch_sampler = validate_patch_sampler(self.patch_sampler)
        # validate_data_augmentation
        self.data_augmentation = validate_data_augmentation(
            self.data_augmentation, self.patch_size
        )
        # validate data_postprocessing_after_reverse_one_hot_encoding
        (
            self.data_postprocessing_after_reverse_one_hot_encoding,
            self.data_postprocessing,
        ) = validate_data_postprocessing_after_reverse_one_hot_encoding(
            self.data_postprocessing_after_reverse_one_hot_encoding,
            self.data_postprocessing,
        )
        # validate differential_privacy
        self.differential_privacy = validate_differential_privacy(
            self.differential_privacy, self.batch_size
        )

        return self