careamics 0.1.0rc1__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; consult the registry's advisory page for more details.

Files changed (132)
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +321 -131
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -13
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -202
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -13
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +89 -56
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -271
  111. careamics/config/algorithm.py +0 -231
  112. careamics/config/config.py +0 -296
  113. careamics/config/config_filter.py +0 -44
  114. careamics/config/data.py +0 -194
  115. careamics/config/torch_optim.py +0 -118
  116. careamics/config/training.py +0 -534
  117. careamics/dataset/dataset_utils.py +0 -115
  118. careamics/dataset/patching.py +0 -493
  119. careamics/dataset/prepare_dataset.py +0 -174
  120. careamics/dataset/tiff_dataset.py +0 -211
  121. careamics/engine.py +0 -954
  122. careamics/manipulation/__init__.py +0 -4
  123. careamics/manipulation/pixel_manipulation.py +0 -158
  124. careamics/prediction/prediction_utils.py +0 -102
  125. careamics/utils/ascii_logo.txt +0 -9
  126. careamics/utils/augment.py +0 -65
  127. careamics/utils/normalization.py +0 -55
  128. careamics/utils/validators.py +0 -156
  129. careamics/utils/wandb.py +0 -121
  130. careamics-0.1.0rc1.dist-info/METADATA +0 -80
  131. careamics-0.1.0rc1.dist-info/RECORD +0 -46
  132. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,35 @@
1
+ from typing import Callable, Union
2
+
3
+ import torch.nn as nn
4
+
5
+ from ..config.support import SupportedActivation
6
+
7
+
8
def get_activation(activation: Union[SupportedActivation, str]) -> Callable:
    """
    Get activation function.

    Parameters
    ----------
    activation : Union[SupportedActivation, str]
        Activation function name.

    Returns
    -------
    Callable
        Activation function.

    Raises
    ------
    ValueError
        If the activation is not supported.
    """
    # Pairs of (supported name, module factory). Factories are called lazily
    # so that only the requested activation module is instantiated.
    factories = (
        (SupportedActivation.RELU, nn.ReLU),
        (SupportedActivation.LEAKYRELU, nn.LeakyReLU),
        (SupportedActivation.TANH, nn.Tanh),
        (SupportedActivation.SIGMOID, nn.Sigmoid),
        (SupportedActivation.SOFTMAX, lambda: nn.Softmax(dim=1)),
        (SupportedActivation.NONE, nn.Identity),
    )
    for supported, factory in factories:
        if activation == supported:
            return factory()
    raise ValueError(f"Activation {activation} not supported.")
@@ -3,8 +3,11 @@ Layer module.
3
3
 
4
4
  This submodule contains layers used in the CAREamics models.
5
5
  """
6
+ from typing import List, Optional, Tuple, Union
7
+
6
8
  import torch
7
9
  import torch.nn as nn
10
+ from torch.nn import functional as F
8
11
 
9
12
 
10
13
  class Conv_Block(nn.Module):
@@ -150,3 +153,244 @@ class Conv_Block(nn.Module):
150
153
  if self.dropout is not None:
151
154
  x = self.dropout(x)
152
155
  return x
156
+
157
+
158
+ def _unpack_kernel_size(
159
+ kernel_size: Union[Tuple[int, ...], int], dim: int
160
+ ) -> Tuple[int, ...]:
161
+ """Unpack kernel_size to a tuple of ints.
162
+
163
+ Inspired by Kornia implementation. TODO: link
164
+ """
165
+ if isinstance(kernel_size, int):
166
+ kernel_dims = tuple([kernel_size for _ in range(dim)])
167
+ else:
168
+ kernel_dims = kernel_size
169
+ return kernel_dims
170
+
171
+
172
def _compute_zero_padding(
    kernel_size: Union[Tuple[int, ...], int], dim: int
) -> Tuple[int, ...]:
    """Compute the zero-padding tuple for a kernel.

    Returns ``(k - 1) // 2`` for every kernel dimension.
    """
    return tuple((k - 1) // 2 for k in _unpack_kernel_size(kernel_size, dim))
178
+
179
+
180
def get_pascal_kernel_1d(
    kernel_size: int,
    norm: bool = False,
    *,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Generate Yang Hui triangle (Pascal's triangle) for a given number.

    Inspired by the Kornia implementation.

    Parameters
    ----------
    kernel_size: height and width of the kernel.
    norm: if to normalize the kernel or not. Default: False.
    device: tensor device
    dtype: tensor dtype

    Returns
    -------
    kernel shaped as :math:`(kernel_size,)`

    Examples
    --------
    >>> get_pascal_kernel_1d(1)
    tensor([1.])
    >>> get_pascal_kernel_1d(2)
    tensor([1., 1.])
    >>> get_pascal_kernel_1d(3)
    tensor([1., 2., 1.])
    >>> get_pascal_kernel_1d(4)
    tensor([1., 3., 3., 1.])
    >>> get_pascal_kernel_1d(5)
    tensor([1., 4., 6., 4., 1.])
    >>> get_pascal_kernel_1d(6)
    tensor([ 1.,  5., 10., 10.,  5.,  1.])
    """
    # Build the requested row of Pascal's triangle iteratively: each row
    # starts and ends with 1, and every interior entry is the sum of the two
    # adjacent entries of the previous row.
    row: List[float] = []
    for size in range(1, kernel_size + 1):
        if size == 1:
            row = [1.0]
        else:
            row = [1.0] + [row[k] + row[k + 1] for k in range(size - 2)] + [1.0]

    kernel = torch.tensor(row, device=device, dtype=dtype)

    if norm:
        kernel = kernel / kernel.sum()

    return kernel
235
+
236
+
237
def _get_pascal_kernel_nd(
    kernel_size: Union[Tuple[int, int], int],
    norm: bool = True,
    dim: int = 2,
    *,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Generate pascal filter kernel by kernel size.

    Inspired by the Kornia implementation. The n-dimensional kernel is the
    outer product of 1D Pascal kernels, one per dimension.

    Parameters
    ----------
    kernel_size: height and width of the kernel.
    norm: if to normalize the kernel or not. Default: True.
    dim: number of spatial dimensions (1, 2 or 3). Default: 2.
    device: tensor device
    dtype: tensor dtype

    Returns
    -------
    torch.Tensor
        If kernel_size is an integer the kernel will be shaped as
        (kernel_size,) * dim, otherwise the kernel will be shaped as
        kernel_size.

    Raises
    ------
    NotImplementedError
        If ``dim`` is not 1, 2 or 3.

    Examples
    --------
    >>> _get_pascal_kernel_nd(1)
    tensor([[1.]])
    >>> _get_pascal_kernel_nd(4, norm=False)
    tensor([[1., 3., 3., 1.],
            [3., 9., 9., 3.],
            [3., 9., 9., 3.],
            [1., 3., 3., 1.]])
    """
    kernel_dims = _unpack_kernel_size(kernel_size, dim)

    kernels_1d = [
        get_pascal_kernel_1d(kd, device=device, dtype=dtype) for kd in kernel_dims
    ]

    # Outer product of the per-dimension 1D kernels. The original code left
    # `kernel` as a Python list for dim == 1 (crashing in the norm branch),
    # so the 1D case is now handled explicitly and other dims raise.
    if dim == 1:
        kernel = kernels_1d[0]
    elif dim == 2:
        kernel = kernels_1d[0][:, None] * kernels_1d[1][None, :]
    elif dim == 3:
        kernel = (
            kernels_1d[0][:, None, None]
            * kernels_1d[1][None, :, None]
            * kernels_1d[2][None, None, :]
        )
    else:
        raise NotImplementedError(f"Only dim 1, 2 or 3 are supported, got {dim}.")

    if norm:
        kernel = kernel / torch.sum(kernel)
    return kernel
293
+
294
+
295
def _max_blur_pool_by_kernel2d(
    x: torch.Tensor,
    kernel: torch.Tensor,
    stride: int,
    max_pool_size: int,
    ceil_mode: bool,
) -> torch.Tensor:
    """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxN` kernel.

    Inspired by Kornia implementation.
    """
    # local maxima first; stride 1 keeps the spatial size unchanged
    pooled = F.max_pool2d(
        x, kernel_size=max_pool_size, stride=1, padding=0, ceil_mode=ceil_mode
    )
    # then anti-aliasing blur + downsampling as a depthwise (grouped) conv
    pad = _compute_zero_padding((kernel.shape[-2], kernel.shape[-1]), dim=2)
    return F.conv2d(pooled, kernel, padding=pad, stride=stride, groups=pooled.size(1))
313
+
314
+
315
def _max_blur_pool_by_kernel3d(
    x: torch.Tensor,
    kernel: torch.Tensor,
    stride: int,
    max_pool_size: int,
    ceil_mode: bool,
) -> torch.Tensor:
    """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxNxN` kernel.

    Inspired by Kornia implementation.
    """
    # local maxima first; stride 1 keeps the spatial size unchanged
    pooled = F.max_pool3d(
        x, kernel_size=max_pool_size, stride=1, padding=0, ceil_mode=ceil_mode
    )
    # then anti-aliasing blur + downsampling as a depthwise (grouped) conv
    pad = _compute_zero_padding(
        (kernel.shape[-3], kernel.shape[-2], kernel.shape[-1]), dim=3
    )
    return F.conv3d(pooled, kernel, padding=pad, stride=stride, groups=pooled.size(1))
335
+
336
+
337
class MaxBlurPool(nn.Module):
    """Compute pools and blurs and downsample a given feature map.

    Inspired by Kornia MaxBlurPool implementation. Equivalent to
    ```nn.Sequential(nn.MaxPool2d(...), BlurPool2D(...))```

    Parameters
    ----------
    dim: int
        Toggles between 2D and 3D
    kernel_size: Union[Tuple[int, int], int]
        Kernel size for max pooling.
    stride: int
        Stride for pooling.
    max_pool_size: int
        Max kernel size for max pooling.
    ceil_mode: bool
        Should be true to match output size of conv2d with same kernel size.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: Union[Tuple[int, int], int],
        stride: int = 2,
        max_pool_size: int = 2,
        ceil_mode: bool = False,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.max_pool_size = max_pool_size
        self.ceil_mode = ceil_mode
        # The blur kernel is built once here; it is moved to the input's
        # device/dtype lazily in forward().
        self.kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the function."""
        self.kernel = torch.as_tensor(self.kernel, device=x.device, dtype=x.dtype)
        channels = x.size(1)
        if self.dim == 2:
            weight = self.kernel.repeat((channels, 1, 1, 1))
            return _max_blur_pool_by_kernel2d(
                x, weight, self.stride, self.max_pool_size, self.ceil_mode
            )
        # any non-2D configuration falls through to the 3D implementation,
        # matching the original if/else behavior
        weight = self.kernel.repeat((channels, 1, 1, 1, 1))
        return _max_blur_pool_by_kernel3d(
            x, weight, self.stride, self.max_pool_size, self.ceil_mode
        )
@@ -3,30 +3,30 @@ Model factory.
3
3
 
4
4
  Model creation factory functions.
5
5
  """
6
- from pathlib import Path
7
- from typing import Dict, Optional, Tuple, Union
6
+ from typing import Union
8
7
 
9
8
  import torch
10
9
 
11
- from ..bioimage import import_bioimage_model
12
- from ..config import Configuration
13
- from ..config.algorithm import Models
14
- from ..utils.logging import get_logger
10
+ from ..config.architectures import CustomModel, UNetModel, VAEModel, get_custom_model
11
+ from ..config.support import SupportedArchitecture
12
+ from ..utils import get_logger
15
13
  from .unet import UNet
16
14
 
17
15
  logger = get_logger(__name__)
18
16
 
19
17
 
20
- def model_registry(model_name: str) -> torch.nn.Module:
18
def model_factory(
    model_configuration: Union[UNetModel, VAEModel, CustomModel]
) -> torch.nn.Module:
    """
    Deep learning model factory.

    Supported models are defined in careamics.config.SupportedArchitecture.

    Parameters
    ----------
    model_configuration : Union[UNetModel, VAEModel, CustomModel]
        Model configuration.

    Returns
    -------
    torch.nn.Module
        Instantiated model.

    Raises
    ------
    ValueError
        If a custom architecture is requested with a configuration that is
        not a CustomModel.
    NotImplementedError
        If the requested architecture is not implemented.
    """
    architecture = model_configuration.architecture

    if architecture == SupportedArchitecture.UNET:
        return UNet(**model_configuration.model_dump())
    elif architecture == SupportedArchitecture.CUSTOM:
        # explicit validation instead of `assert`, which is stripped when
        # Python runs with -O and would let a wrong config type through
        if not isinstance(model_configuration, CustomModel):
            raise ValueError(
                f"Expected a CustomModel configuration for the custom "
                f"architecture, got {type(model_configuration)}."
            )
        model = get_custom_model(model_configuration.name)

        return model(**model_configuration.model_dump())
    else:
        raise NotImplementedError(
            f"Model {model_configuration.architecture} is not implemented or unknown."
        )