careamics 0.1.0rc1__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (132)
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +321 -131
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -13
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -202
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -13
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +89 -56
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -271
  111. careamics/config/algorithm.py +0 -231
  112. careamics/config/config.py +0 -296
  113. careamics/config/config_filter.py +0 -44
  114. careamics/config/data.py +0 -194
  115. careamics/config/torch_optim.py +0 -118
  116. careamics/config/training.py +0 -534
  117. careamics/dataset/dataset_utils.py +0 -115
  118. careamics/dataset/patching.py +0 -493
  119. careamics/dataset/prepare_dataset.py +0 -174
  120. careamics/dataset/tiff_dataset.py +0 -211
  121. careamics/engine.py +0 -954
  122. careamics/manipulation/__init__.py +0 -4
  123. careamics/manipulation/pixel_manipulation.py +0 -158
  124. careamics/prediction/prediction_utils.py +0 -102
  125. careamics/utils/ascii_logo.txt +0 -9
  126. careamics/utils/augment.py +0 -65
  127. careamics/utils/normalization.py +0 -55
  128. careamics/utils/validators.py +0 -156
  129. careamics/utils/wandb.py +0 -121
  130. careamics-0.1.0rc1.dist-info/METADATA +0 -80
  131. careamics-0.1.0rc1.dist-info/RECORD +0 -46
  132. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
@@ -1,296 +0,0 @@
1
- """Pydantic CAREamics configuration."""
2
- from __future__ import annotations
3
-
4
- import re
5
- from pathlib import Path
6
- from typing import Dict, List, Union
7
-
8
- import yaml
9
- from pydantic import (
10
- BaseModel,
11
- ConfigDict,
12
- field_validator,
13
- model_validator,
14
- )
15
-
16
- from .algorithm import Algorithm
17
- from .config_filter import paths_to_str
18
- from .data import Data
19
- from .training import Training
20
-
21
-
22
- class Configuration(BaseModel):
23
- """
24
- CAREamics configuration.
25
-
26
- To change the configuration from 2D to 3D, we recommend using the following method:
27
- >>> set_3D(is_3D, axes)
28
-
29
- Attributes
30
- ----------
31
- experiment_name : str
32
- Name of the experiment.
33
- working_directory : Union[str, Path]
34
- Path to the working directory.
35
- algorithm : Algorithm
36
- Algorithm configuration.
37
- training : Training
38
- Training configuration.
39
- """
40
-
41
- model_config = ConfigDict(validate_assignment=True)
42
-
43
- # required parameters
44
- experiment_name: str
45
- working_directory: Path
46
-
47
- # Sub-configurations
48
- algorithm: Algorithm
49
- data: Data
50
- training: Training
51
-
52
- def set_3D(self, is_3D: bool, axes: str) -> None:
53
- """
54
- Set 3D flag and axes.
55
-
56
- Parameters
57
- ----------
58
- is_3D : bool
59
- Whether the algorithm is 3D or not.
60
- axes : str
61
- Axes of the data.
62
- """
63
- # set the flag and axes (this will not trigger validation at the config level)
64
- self.algorithm.is_3D = is_3D
65
- self.data.axes = axes
66
-
67
- # cheap hack: trigger validation
68
- self.algorithm = self.algorithm
69
-
70
- @field_validator("experiment_name")
71
- def no_symbol(cls, name: str) -> str:
72
- """
73
- Validate experiment name.
74
-
75
- A valid experiment name is a non-empty string with only contains letters,
76
- numbers, underscores, dashes and spaces.
77
-
78
- Parameters
79
- ----------
80
- name : str
81
- Name to validate.
82
-
83
- Returns
84
- -------
85
- str
86
- Validated name.
87
-
88
- Raises
89
- ------
90
- ValueError
91
- If the name is empty or contains invalid characters.
92
- """
93
- if len(name) == 0 or name.isspace():
94
- raise ValueError("Experiment name is empty.")
95
-
96
- # Validate using a regex that it contains only letters, numbers, underscores,
97
- # dashes and spaces
98
- if not re.match(r"^[a-zA-Z0-9_\- ]*$", name):
99
- raise ValueError(
100
- f"Experiment name contains invalid characters (got {name}). "
101
- f"Only letters, numbers, underscores, dashes and spaces are allowed."
102
- )
103
-
104
- return name
105
-
106
- @field_validator("working_directory")
107
- def parent_directory_exists(cls, workdir: Union[str, Path]) -> Path:
108
- """
109
- Validate working directory.
110
-
111
- A valid working directory is a directory whose parent directory exists. If the
112
- working directory does not exist itself, it is then created.
113
-
114
- Parameters
115
- ----------
116
- workdir : Union[str, Path]
117
- Working directory to validate.
118
-
119
- Returns
120
- -------
121
- Path
122
- Validated working directory.
123
-
124
- Raises
125
- ------
126
- ValueError
127
- If the working directory is not a directory, or if the parent directory does
128
- not exist.
129
- """
130
- path = Path(workdir)
131
-
132
- # check if it is a directory
133
- if path.exists() and not path.is_dir():
134
- raise ValueError(f"Working directory is not a directory (got {workdir}).")
135
-
136
- # check if parent directory exists
137
- if not path.parent.exists():
138
- raise ValueError(
139
- f"Parent directory of working directory does not exist (got {workdir})."
140
- )
141
-
142
- # create directory if it does not exist already
143
- path.mkdir(exist_ok=True)
144
-
145
- return path
146
-
147
- @model_validator(mode="after")
148
- def validate_3D(cls, config: Configuration) -> Configuration:
149
- """
150
- Check 3D flag validity.
151
-
152
- Check that the algorithm is_3D flag is compatible with the axes in the
153
- data configuration.
154
-
155
- Parameters
156
- ----------
157
- config : Configuration
158
- Configuration to validate.
159
-
160
- Returns
161
- -------
162
- Configuration
163
- Validated configuration.
164
-
165
- Raises
166
- ------
167
- ValueError
168
- If the algorithm is 3D but the data axes are not, or if the algorithm is
169
- not 3D but the data axes are.
170
- """
171
- # check that is_3D and axes are compatible
172
- if config.algorithm.is_3D and "Z" not in config.data.axes:
173
- raise ValueError(
174
- f"Algorithm is 3D but data axes are not (got axes {config.data.axes})."
175
- )
176
- elif not config.algorithm.is_3D and "Z" in config.data.axes:
177
- raise ValueError(
178
- f"Algorithm is not 3D but data axes are (got axes {config.data.axes})."
179
- )
180
-
181
- return config
182
-
183
- def model_dump(
184
- self, exclude_optionals: bool = True, *args: List, **kwargs: Dict
185
- ) -> Dict:
186
- """
187
- Override model_dump method.
188
-
189
- The purpose is to ensure export smooth import to yaml. It includes:
190
- - remove entries with None value.
191
- - remove optional values if they have the default value.
192
-
193
- Parameters
194
- ----------
195
- exclude_optionals : bool, optional
196
- Whether to exclude optional fields with default values or not, by default
197
- True.
198
- *args : List
199
- Positional arguments, unused.
200
- **kwargs : Dict
201
- Keyword arguments, unused.
202
-
203
- Returns
204
- -------
205
- dict
206
- Dictionary containing the model parameters.
207
- """
208
- dictionary = super().model_dump(exclude_none=True)
209
-
210
- # remove paths
211
- dictionary = paths_to_str(dictionary)
212
-
213
- dictionary["algorithm"] = self.algorithm.model_dump(
214
- exclude_optionals=exclude_optionals
215
- )
216
- dictionary["data"] = self.data.model_dump()
217
-
218
- dictionary["training"] = self.training.model_dump(
219
- exclude_optionals=exclude_optionals
220
- )
221
-
222
- return dictionary
223
-
224
-
225
- def load_configuration(path: Union[str, Path]) -> Configuration:
226
- """
227
- Load configuration from a yaml file.
228
-
229
- Parameters
230
- ----------
231
- path : Union[str, Path]
232
- Path to the configuration.
233
-
234
- Returns
235
- -------
236
- Configuration
237
- Configuration.
238
-
239
- Raises
240
- ------
241
- FileNotFoundError
242
- If the configuration file does not exist.
243
- """
244
- # load dictionary from yaml
245
- if not Path(path).exists():
246
- raise FileNotFoundError(
247
- f"Configuration file {path} does not exist in " f" {Path.cwd()!s}"
248
- )
249
-
250
- dictionary = yaml.load(Path(path).open("r"), Loader=yaml.SafeLoader)
251
-
252
- return Configuration(**dictionary)
253
-
254
-
255
- def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
256
- """
257
- Save configuration to path.
258
-
259
- Parameters
260
- ----------
261
- config : Configuration
262
- Configuration to save.
263
- path : Union[str, Path]
264
- Path to a existing folder in which to save the configuration or to an existing
265
- configuration file.
266
-
267
- Returns
268
- -------
269
- Path
270
- Path object representing the configuration.
271
-
272
- Raises
273
- ------
274
- ValueError
275
- If the path does not point to an existing directory or .yml file.
276
- """
277
- # make sure path is a Path object
278
- config_path = Path(path)
279
-
280
- # check if path is pointing to an existing directory or .yml file
281
- if config_path.exists():
282
- if config_path.is_dir():
283
- config_path = Path(config_path, "config.yml")
284
- elif config_path.suffix != ".yml":
285
- raise ValueError(
286
- f"Path must be a directory or .yml file (got {config_path})."
287
- )
288
- else:
289
- if config_path.suffix != ".yml":
290
- raise ValueError(f"Path must be a .yml file (got {config_path}).")
291
-
292
- # save configuration as dictionary to yaml
293
- with open(config_path, "w") as f:
294
- yaml.dump(config.model_dump(), f, default_flow_style=False)
295
-
296
- return config_path
@@ -1,44 +0,0 @@
1
- """Convenience functions to filter dictionaries resulting from a Pydantic export."""
2
- from pathlib import Path
3
- from typing import Dict
4
-
5
-
6
- def paths_to_str(dictionary: dict) -> dict:
7
- """
8
- Replace Path objects in a dictionary by str.
9
-
10
- Parameters
11
- ----------
12
- dictionary : dict
13
- Dictionary to modify.
14
-
15
- Returns
16
- -------
17
- dict
18
- Modified dictionary.
19
- """
20
- for k in dictionary.keys():
21
- if isinstance(dictionary[k], Path):
22
- dictionary[k] = str(dictionary[k])
23
-
24
- return dictionary
25
-
26
-
27
- def remove_default_optionals(dictionary: Dict, default: Dict) -> None:
28
- """
29
- Remove default arguments from a dictionary.
30
-
31
- The method removes arguments if they are equal to the provided default ones.
32
-
33
- Parameters
34
- ----------
35
- dictionary : dict
36
- Dictionary to modify.
37
- default : dict
38
- Dictionary containing the default values.
39
- """
40
- dict_copy = dictionary.copy()
41
- for k in dict_copy.keys():
42
- if k in default.keys():
43
- if dict_copy[k] == default[k]:
44
- del dictionary[k]
careamics/config/data.py DELETED
@@ -1,194 +0,0 @@
1
- """Data configuration."""
2
- from __future__ import annotations
3
-
4
- from enum import Enum
5
- from typing import Dict, List, Optional
6
-
7
- from pydantic import (
8
- BaseModel,
9
- ConfigDict,
10
- Field,
11
- field_validator,
12
- model_validator,
13
- )
14
-
15
- from ..utils import check_axes_validity
16
-
17
-
18
- class SupportedExtension(str, Enum):
19
- """
20
- Supported extensions for input data.
21
-
22
- Currently supported:
23
- - tif/tiff: .tiff files.
24
- """
25
-
26
- TIFF = "tiff"
27
- TIF = "tif"
28
-
29
- @classmethod
30
- def _missing_(cls, value: object) -> str:
31
- """
32
- Override default behaviour for missing values.
33
-
34
- This method is called when `value` is not found in the enum values. It converts
35
- `value` to lowercase, removes "." if it is the first character and tries to
36
- match it with enum values.
37
-
38
- Parameters
39
- ----------
40
- value : object
41
- Value to be matched with enum values.
42
-
43
- Returns
44
- -------
45
- str
46
- Matched enum value.
47
- """
48
- if isinstance(value, str):
49
- lower_value = value.lower()
50
-
51
- if lower_value.startswith("."):
52
- lower_value = lower_value[1:]
53
-
54
- # attempt to match lowercase value with enum values
55
- for member in cls:
56
- if member.value == lower_value:
57
- return member
58
-
59
- # still missing
60
- return super()._missing_(value)
61
-
62
-
63
- class Data(BaseModel):
64
- """
65
- Data configuration.
66
-
67
- If std is specified, mean must be specified as well. Note that setting the std first
68
- and then the mean (if they were both `None` before) will raise a validation error.
69
- Prefer instead the following:
70
- >>> set_mean_and_std(mean, std)
71
-
72
- Attributes
73
- ----------
74
- in_memory : bool
75
- Whether to load the data in memory or not.
76
- data_format : SupportedExtension
77
- Extension of the data, without period.
78
- axes : str
79
- Axes of the data.
80
- mean: Optional[float]
81
- Expected data mean.
82
- std: Optional[float]
83
- Expected data standard deviation.
84
- """
85
-
86
- # Pydantic class configuration
87
- model_config = ConfigDict(
88
- use_enum_values=True,
89
- validate_assignment=True,
90
- )
91
-
92
- # Mandatory fields
93
- in_memory: bool
94
- data_format: SupportedExtension
95
- axes: str
96
-
97
- # Optional fields
98
- mean: Optional[float] = Field(default=None, ge=0)
99
- std: Optional[float] = Field(default=None, gt=0)
100
-
101
- def set_mean_and_std(self, mean: float, std: float) -> None:
102
- """
103
- Set mean and standard deviation of the data.
104
-
105
- This method is preferred to setting the fields directly, as it ensures that the
106
- mean is set first, then the std; thus avoiding a validation error to be thrown.
107
-
108
- Parameters
109
- ----------
110
- mean : float
111
- Mean of the data.
112
- std : float
113
- Standard deviation of the data.
114
- """
115
- self.mean = mean
116
- self.std = std
117
-
118
- @field_validator("axes")
119
- def valid_axes(cls, axes: str) -> str:
120
- """
121
- Validate axes.
122
-
123
- Axes must be a subset of STZYX, must contain YX, be in the right order
124
- and not contain both S and T.
125
-
126
- Parameters
127
- ----------
128
- axes : str
129
- Axes of the training data.
130
-
131
- Returns
132
- -------
133
- str
134
- Validated axes of the training data.
135
-
136
- Raises
137
- ------
138
- ValueError
139
- If axes are not valid.
140
- """
141
- # validate axes
142
- check_axes_validity(axes)
143
-
144
- return axes
145
-
146
- @model_validator(mode="after")
147
- def std_only_with_mean(cls, data_model: Data) -> Data:
148
- """
149
- Check that mean and std are either both None, or both specified.
150
-
151
- If we enforce both None or both specified, we cannot set the values one by one
152
- due to the ConfDict enforcing the validation on assignment. Therefore, we check
153
- only when the std is not None and the mean is None.
154
-
155
- Parameters
156
- ----------
157
- data_model : Data
158
- Data model.
159
-
160
- Returns
161
- -------
162
- Data
163
- Validated data model.
164
-
165
- Raises
166
- ------
167
- ValueError
168
- If std is not None and mean is None.
169
- """
170
- if data_model.std is not None and data_model.mean is None:
171
- raise ValueError("Cannot have std non None if mean is None.")
172
-
173
- return data_model
174
-
175
- def model_dump(self, *args: List, **kwargs: Dict) -> dict:
176
- """
177
- Override model_dump method.
178
-
179
- The purpose is to ensure export smooth import to yaml. It includes:
180
- - remove entries with None value.
181
-
182
- Parameters
183
- ----------
184
- *args : List
185
- Positional arguments, unused.
186
- **kwargs : Dict
187
- Keyword arguments, unused.
188
-
189
- Returns
190
- -------
191
- dict
192
- Dictionary containing the model parameters.
193
- """
194
- return super().model_dump(exclude_none=True)
@@ -1,118 +0,0 @@
1
- """Convenience functions to instantiate torch.optim optimizers and schedulers."""
2
- import inspect
3
- from enum import Enum
4
- from typing import Dict
5
-
6
- from torch import optim
7
-
8
-
9
- class TorchOptimizer(str, Enum):
10
- """
11
- Supported optimizers.
12
-
13
- Currently only supports Adam and SGD.
14
- """
15
-
16
- # ASGD = "ASGD"
17
- # Adadelta = "Adadelta"
18
- # Adagrad = "Adagrad"
19
- Adam = "Adam"
20
- # AdamW = "AdamW"
21
- # Adamax = "Adamax"
22
- # LBFGS = "LBFGS"
23
- # NAdam = "NAdam"
24
- # RAdam = "RAdam"
25
- # RMSprop = "RMSprop"
26
- # Rprop = "Rprop"
27
- SGD = "SGD"
28
- # SparseAdam = "SparseAdam"
29
-
30
-
31
- # TODO: Test which schedulers are compatible and if not, how to make them compatible
32
- # (if we want to support them)
33
- class TorchLRScheduler(str, Enum):
34
- """
35
- Supported learning rate schedulers.
36
-
37
- Currently only supports ReduceLROnPlateau and StepLR.
38
- """
39
-
40
- # ChainedScheduler = "ChainedScheduler"
41
- # ConstantLR = "ConstantLR"
42
- # CosineAnnealingLR = "CosineAnnealingLR"
43
- # CosineAnnealingWarmRestarts = "CosineAnnealingWarmRestarts"
44
- # CyclicLR = "CyclicLR"
45
- # ExponentialLR = "ExponentialLR"
46
- # LambdaLR = "LambdaLR"
47
- # LinearLR = "LinearLR"
48
- # MultiStepLR = "MultiStepLR"
49
- # MultiplicativeLR = "MultiplicativeLR"
50
- # OneCycleLR = "OneCycleLR"
51
- # PolynomialLR = "PolynomialLR"
52
- ReduceLROnPlateau = "ReduceLROnPlateau"
53
- # SequentialLR = "SequentialLR"
54
- StepLR = "StepLR"
55
-
56
-
57
- def get_parameters(
58
- func: type,
59
- user_params: dict,
60
- ) -> dict:
61
- """
62
- Filter parameters according to the function signature.
63
-
64
- Parameters
65
- ----------
66
- func : type
67
- Class object.
68
- user_params : Dict
69
- User provided parameters.
70
-
71
- Returns
72
- -------
73
- Dict
74
- Parameters matching `func`'s signature.
75
- """
76
- # Get the list of all default parameters
77
- default_params = list(inspect.signature(func).parameters.keys())
78
-
79
- # Filter matching parameters
80
- params_to_be_used = set(user_params.keys()) & set(default_params)
81
-
82
- return {key: user_params[key] for key in params_to_be_used}
83
-
84
-
85
- def get_optimizers() -> Dict[str, str]:
86
- """
87
- Return the list of all optimizers available in torch.optim.
88
-
89
- Returns
90
- -------
91
- Dict
92
- Optimizers available in torch.optim.
93
- """
94
- optims = {}
95
- for name, obj in inspect.getmembers(optim):
96
- if inspect.isclass(obj) and issubclass(obj, optim.Optimizer):
97
- if name != "Optimizer":
98
- optims[name] = name
99
- return optims
100
-
101
-
102
- def get_schedulers() -> Dict[str, str]:
103
- """
104
- Return the list of all schedulers available in torch.optim.lr_scheduler.
105
-
106
- Returns
107
- -------
108
- Dict
109
- Schedulers available in torch.optim.lr_scheduler.
110
- """
111
- schedulers = {}
112
- for name, obj in inspect.getmembers(optim.lr_scheduler):
113
- if inspect.isclass(obj) and issubclass(obj, optim.lr_scheduler.LRScheduler):
114
- if "LRScheduler" not in name:
115
- schedulers[name] = name
116
- elif name == "ReduceLROnPlateau": # somewhat not a subclass of LRScheduler
117
- schedulers[name] = name
118
- return schedulers