careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (133)
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +323 -134
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -14
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -221
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -12
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +112 -75
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -182
  111. careamics/bioimage/rdf.py +0 -105
  112. careamics/config/algorithm.py +0 -231
  113. careamics/config/config.py +0 -297
  114. careamics/config/config_filter.py +0 -44
  115. careamics/config/data.py +0 -194
  116. careamics/config/torch_optim.py +0 -118
  117. careamics/config/training.py +0 -534
  118. careamics/dataset/dataset_utils.py +0 -111
  119. careamics/dataset/patching.py +0 -492
  120. careamics/dataset/prepare_dataset.py +0 -175
  121. careamics/dataset/tiff_dataset.py +0 -212
  122. careamics/engine.py +0 -1014
  123. careamics/manipulation/__init__.py +0 -4
  124. careamics/manipulation/pixel_manipulation.py +0 -158
  125. careamics/prediction/prediction_utils.py +0 -106
  126. careamics/utils/ascii_logo.txt +0 -9
  127. careamics/utils/augment.py +0 -65
  128. careamics/utils/normalization.py +0 -55
  129. careamics/utils/validators.py +0 -170
  130. careamics/utils/wandb.py +0 -121
  131. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  132. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  133. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,35 @@
1
+ from typing import Callable, Union
2
+
3
+ import torch.nn as nn
4
+
5
+ from ..config.support import SupportedActivation
6
+
7
+
8
def get_activation(activation: Union[SupportedActivation, str]) -> Callable:
    """
    Instantiate the activation layer matching ``activation``.

    Parameters
    ----------
    activation : Union[SupportedActivation, str]
        Activation identifier, compared with ``==`` against the members of
        ``SupportedActivation`` in a fixed order.

    Returns
    -------
    Callable
        A freshly constructed ``torch.nn`` module: ``ReLU``, ``LeakyReLU``,
        ``Tanh``, ``Sigmoid``, ``Softmax`` (over dimension 1) or ``Identity``
        (for ``SupportedActivation.NONE``).

    Raises
    ------
    ValueError
        If ``activation`` matches none of the supported activations.
    """
    # Ordered (identifier, factory) pairs. Matching uses ``==`` rather than a
    # dict lookup so that plain strings compare exactly as they did before.
    candidates = (
        (SupportedActivation.RELU, nn.ReLU),
        (SupportedActivation.LEAKYRELU, nn.LeakyReLU),
        (SupportedActivation.TANH, nn.Tanh),
        (SupportedActivation.SIGMOID, nn.Sigmoid),
        (SupportedActivation.SOFTMAX, lambda: nn.Softmax(dim=1)),
        (SupportedActivation.NONE, nn.Identity),
    )
    for supported, build in candidates:
        if activation == supported:
            return build()
    raise ValueError(f"Activation {activation} not supported.")
@@ -3,8 +3,11 @@ Layer module.
3
3
 
4
4
  This submodule contains layers used in the CAREamics models.
5
5
  """
6
+ from typing import List, Optional, Tuple, Union
7
+
6
8
  import torch
7
9
  import torch.nn as nn
10
+ from torch.nn import functional as F
8
11
 
9
12
 
10
13
  class Conv_Block(nn.Module):
@@ -150,3 +153,244 @@ class Conv_Block(nn.Module):
150
153
  if self.dropout is not None:
151
154
  x = self.dropout(x)
152
155
  return x
156
+
157
+
158
def _unpack_kernel_size(
    kernel_size: Union[Tuple[int, ...], int], dim: int
) -> Tuple[int, ...]:
    """Expand ``kernel_size`` into one entry per spatial dimension.

    Adapted from the Kornia implementation (``kornia.filters``).

    Parameters
    ----------
    kernel_size : Union[Tuple[int, ...], int]
        Either a single int, which is repeated ``dim`` times, or a
        per-dimension sequence, which is returned unchanged.
    dim : int
        Number of spatial dimensions.

    Returns
    -------
    Tuple[int, ...]
        Kernel size for each dimension.
    """
    if not isinstance(kernel_size, int):
        return kernel_size
    return (kernel_size,) * dim
170
+
171
+
172
def _compute_zero_padding(
    kernel_size: Union[Tuple[int, ...], int], dim: int
) -> Tuple[int, ...]:
    """Return the per-dimension zero padding ``(k - 1) // 2`` for a kernel.

    For odd kernel sizes this keeps a stride-1 convolution "same"-sized.

    Parameters
    ----------
    kernel_size : Union[Tuple[int, ...], int]
        Kernel size, either one int for all dimensions or one per dimension.
    dim : int
        Number of spatial dimensions.

    Returns
    -------
    Tuple[int, ...]
        Padding for each dimension.
    """
    sizes = _unpack_kernel_size(kernel_size, dim)
    return tuple((size - 1) // 2 for size in sizes)
178
+
179
+
180
def get_pascal_kernel_1d(
    kernel_size: int,
    norm: bool = False,
    *,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Build a 1D binomial kernel, i.e. one row of Pascal's triangle.

    Adapted from the Kornia implementation (``kornia.filters``).

    Parameters
    ----------
    kernel_size : int
        Number of taps. The result holds the binomial coefficients
        ``C(kernel_size - 1, k)`` for ``k = 0 .. kernel_size - 1``. A value
        ``<= 0`` yields an empty tensor.
    norm : bool, optional
        If True, divide by the sum so the taps add up to 1, by default False.
    device : Optional[torch.device], optional
        Device of the returned tensor, by default None.
    dtype : Optional[torch.dtype], optional
        Dtype of the returned tensor, by default None (PyTorch's default
        floating point dtype).

    Returns
    -------
    torch.Tensor
        Kernel of shape ``(kernel_size,)``.

    Examples
    --------
    >>> get_pascal_kernel_1d(1)
    tensor([1.])
    >>> get_pascal_kernel_1d(2)
    tensor([1., 1.])
    >>> get_pascal_kernel_1d(3)
    tensor([1., 2., 1.])
    >>> get_pascal_kernel_1d(4)
    tensor([1., 3., 3., 1.])
    >>> get_pascal_kernel_1d(5)
    tensor([1., 4., 6., 4., 1.])
    >>> get_pascal_kernel_1d(6)
    tensor([ 1.,  5., 10., 10.,  5.,  1.])
    >>> get_pascal_kernel_1d(3, norm=True)
    tensor([0.2500, 0.5000, 0.2500])
    """
    row: List[float] = []
    for _ in range(kernel_size):
        # Each new row is the previous row added to itself shifted by one
        # position (zero-padded at both ends).
        row = (
            [left + right for left, right in zip([0.0] + row, row + [0.0])]
            if row
            else [1.0]
        )

    kernel = torch.tensor(row, device=device, dtype=dtype)
    return kernel / kernel.sum() if norm else kernel
235
+
236
+
237
def _get_pascal_kernel_nd(
    kernel_size: Union[Tuple[int, ...], int],
    norm: bool = True,
    dim: int = 2,
    *,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Generate an N-dimensional binomial (Pascal) filter kernel.

    The kernel is separable: it is the outer product of one 1D Pascal kernel
    per dimension (see ``get_pascal_kernel_1d``). Adapted from the Kornia
    implementation.

    Parameters
    ----------
    kernel_size : Union[Tuple[int, ...], int]
        Kernel size, either one int used for every dimension or one entry per
        dimension.
    norm : bool, optional
        If True, normalize the kernel so that it sums to 1, by default True.
    dim : int, optional
        Number of kernel dimensions (must be >= 1), by default 2.
    device : Optional[torch.device], optional
        Device of the returned tensor, by default None.
    dtype : Optional[torch.dtype], optional
        Dtype of the returned tensor, by default None.

    Returns
    -------
    torch.Tensor
        Kernel of shape ``(kernel_size,) * dim`` if ``kernel_size`` is an int,
        otherwise of shape ``tuple(kernel_size)``.

    Raises
    ------
    ValueError
        If ``dim`` is smaller than 1, or if ``kernel_size`` is a sequence
        whose length differs from ``dim``.

    Examples
    --------
    >>> _get_pascal_kernel_nd(1)
    tensor([[1.]])
    >>> _get_pascal_kernel_nd(4)
    tensor([[0.0156, 0.0469, 0.0469, 0.0156],
            [0.0469, 0.1406, 0.1406, 0.0469],
            [0.0469, 0.1406, 0.1406, 0.0469],
            [0.0156, 0.0469, 0.0469, 0.0156]])
    >>> _get_pascal_kernel_nd(4, norm=False)
    tensor([[1., 3., 3., 1.],
            [3., 9., 9., 3.],
            [3., 9., 9., 3.],
            [1., 3., 3., 1.]])
    >>> _get_pascal_kernel_nd(3, norm=False, dim=3).shape
    torch.Size([3, 3, 3])
    >>> _get_pascal_kernel_nd(3, norm=False, dim=1)
    tensor([1., 2., 1.])
    """
    if dim < 1:
        raise ValueError(f"Kernel dimension must be a positive integer, got {dim}.")

    kernel_dims = tuple(_unpack_kernel_size(kernel_size, dim))
    if len(kernel_dims) != dim:
        # Previously a mismatched tuple was silently truncated or raised an
        # IndexError deep inside the outer product.
        raise ValueError(
            f"Expected {dim} kernel sizes, got {len(kernel_dims)} ({kernel_dims})."
        )

    # Outer product of the 1D kernels, built one axis at a time. For dim 2 and 3
    # this performs exactly the same multiplications, in the same order, as
    # k0[:, None] * k1[None, :] and k0[:, None, None] * k1[None, :, None] *
    # k2[None, None, :].
    kernel = get_pascal_kernel_1d(kernel_dims[0], device=device, dtype=dtype)
    for size in kernel_dims[1:]:
        kernel = kernel.unsqueeze(-1) * get_pascal_kernel_1d(
            size, device=device, dtype=dtype
        )

    if norm:
        kernel = kernel / torch.sum(kernel)
    return kernel
293
+
294
+
295
def _max_blur_pool_by_kernel2d(
    x: torch.Tensor,
    kernel: torch.Tensor,
    stride: int,
    max_pool_size: int,
    ceil_mode: bool,
) -> torch.Tensor:
    """Apply 2D max-blur-pooling with a precomputed depthwise blur kernel.

    Adapted from the Kornia implementation. The input is first max-pooled with
    stride 1 (dense local maxima). It is then blurred and downsampled in a
    single grouped convolution with one group per channel.

    Parameters
    ----------
    x : torch.Tensor
        Input feature map. ``x.size(1)`` is used as the channel count
        (presumably ``(B, C, H, W)``).
    kernel : torch.Tensor
        Depthwise blur kernel for ``F.conv2d`` with ``groups=C``. It is
        expected to be shaped ``(C, 1, kH, kW)``, as built in
        ``MaxBlurPool.forward``.
    stride : int
        Stride of the blur convolution, i.e. the downsampling factor.
    max_pool_size : int
        Window size of the stride-1 max pooling.
    ceil_mode : bool
        Forwarded to ``F.max_pool2d``.

    Returns
    -------
    torch.Tensor
        Pooled, blurred and downsampled tensor.
    """
    local_max = F.max_pool2d(
        x, kernel_size=max_pool_size, stride=1, padding=0, ceil_mode=ceil_mode
    )
    kernel_h, kernel_w = kernel.shape[-2], kernel.shape[-1]
    blur_padding = _compute_zero_padding((kernel_h, kernel_w), dim=2)
    return F.conv2d(
        local_max,
        kernel,
        stride=stride,
        padding=blur_padding,
        groups=local_max.size(1),
    )
313
+
314
+
315
def _max_blur_pool_by_kernel3d(
    x: torch.Tensor,
    kernel: torch.Tensor,
    stride: int,
    max_pool_size: int,
    ceil_mode: bool,
) -> torch.Tensor:
    """Apply 3D max-blur-pooling with a precomputed depthwise blur kernel.

    Adapted from the Kornia implementation. The input is first max-pooled with
    stride 1 (dense local maxima). It is then blurred and downsampled in a
    single grouped convolution with one group per channel.

    Parameters
    ----------
    x : torch.Tensor
        Input feature map. ``x.size(1)`` is used as the channel count
        (presumably ``(B, C, Z, H, W)``).
    kernel : torch.Tensor
        Depthwise blur kernel for ``F.conv3d`` with ``groups=C``. It is
        expected to be shaped ``(C, 1, kD, kH, kW)``, as built in
        ``MaxBlurPool.forward``.
    stride : int
        Stride of the blur convolution, i.e. the downsampling factor.
    max_pool_size : int
        Window size of the stride-1 max pooling.
    ceil_mode : bool
        Forwarded to ``F.max_pool3d``.

    Returns
    -------
    torch.Tensor
        Pooled, blurred and downsampled tensor.
    """
    local_max = F.max_pool3d(
        x, kernel_size=max_pool_size, stride=1, padding=0, ceil_mode=ceil_mode
    )
    kernel_d, kernel_h, kernel_w = kernel.shape[-3], kernel.shape[-2], kernel.shape[-1]
    blur_padding = _compute_zero_padding((kernel_d, kernel_h, kernel_w), dim=3)
    return F.conv3d(
        local_max,
        kernel,
        stride=stride,
        padding=blur_padding,
        groups=local_max.size(1),
    )
335
+
336
+
337
class MaxBlurPool(nn.Module):
    """Max-pool, blur and downsample a feature map (anti-aliased max pooling).

    Adapted from Kornia's MaxBlurPool implementation. It is roughly equivalent
    to ``nn.Sequential(nn.MaxPool2d(max_pool_size, stride=1), BlurPool(...))``:
    a dense (stride 1) max pooling followed by a normalized binomial (Pascal)
    blur applied with stride ``stride``.

    Parameters
    ----------
    dim : int
        Number of spatial dimensions, either 2 or 3.
    kernel_size : Union[Tuple[int, ...], int]
        Size of the binomial blur kernel.
    stride : int, optional
        Stride of the blur convolution (downsampling factor), by default 2.
    max_pool_size : int, optional
        Window size of the stride-1 max pooling, by default 2.
    ceil_mode : bool, optional
        Forwarded to the max pooling. It should be True to match the output
        size of a convolution with the same kernel size. By default False.

    Raises
    ------
    ValueError
        If ``dim`` is not 2 or 3.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: Union[Tuple[int, ...], int],
        stride: int = 2,
        max_pool_size: int = 2,
        ceil_mode: bool = False,
    ) -> None:
        super().__init__()
        if dim not in (2, 3):
            # Previously any non-2 value silently took the 3D code path.
            raise ValueError(f"MaxBlurPool only supports dim 2 or 3, got {dim}.")
        self.dim = dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.max_pool_size = max_pool_size
        self.ceil_mode = ceil_mode
        # Non-persistent buffer: it follows .to()/.cuda()/.half() together with
        # the module but is not written to state_dict. Checkpoints therefore
        # keep the same keys as when the kernel was a plain attribute.
        self.register_buffer(
            "kernel",
            _get_pascal_kernel_nd(kernel_size, norm=True, dim=dim),
            persistent=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply max-blur-pooling to ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input feature map. ``x.size(1)`` is taken as the channel count
            (presumably ``(B, C, H, W)`` in 2D or ``(B, C, Z, H, W)`` in 3D).

        Returns
        -------
        torch.Tensor
            The pooled and blurred tensor.
        """
        # Match the input's device/dtype locally instead of overwriting
        # module state on every call.
        kernel = self.kernel.to(device=x.device, dtype=x.dtype)
        channels = x.size(1)
        if self.dim == 2:
            # (kH, kW) -> (C, 1, kH, kW): one depthwise blur filter per channel.
            return _max_blur_pool_by_kernel2d(
                x,
                kernel.repeat((channels, 1, 1, 1)),
                self.stride,
                self.max_pool_size,
                self.ceil_mode,
            )
        # (kD, kH, kW) -> (C, 1, kD, kH, kW)
        return _max_blur_pool_by_kernel3d(
            x,
            kernel.repeat((channels, 1, 1, 1, 1)),
            self.stride,
            self.max_pool_size,
            self.ceil_mode,
        )
@@ -3,31 +3,30 @@ Model factory.
3
3
 
4
4
  Model creation factory functions.
5
5
  """
6
- from pathlib import Path
7
- from typing import Dict, Optional, Tuple, Union
6
+ from typing import Union
8
7
 
9
8
  import torch
10
9
 
11
- from careamics.bioimage import import_bioimage_model
12
- from careamics.config import Configuration
13
- from careamics.config.algorithm import Models
14
- from careamics.utils.logging import get_logger
15
-
10
+ from ..config.architectures import CustomModel, UNetModel, VAEModel, get_custom_model
11
+ from ..config.support import SupportedArchitecture
12
+ from ..utils import get_logger
16
13
  from .unet import UNet
17
14
 
18
15
  logger = get_logger(__name__)
19
16
 
20
17
 
21
- def model_registry(model_name: str) -> torch.nn.Module:
18
+ def model_factory(
19
+ model_configuration: Union[UNetModel, VAEModel, CustomModel]
20
+ ) -> torch.nn.Module:
22
21
  """
23
- Model factory.
22
+ Deep learning model factory.
24
23
 
25
- Supported models are defined in config.algorithm.Models.
24
+ Supported models are defined in careamics.config.SupportedArchitecture.
26
25
 
27
26
  Parameters
28
27
  ----------
29
- model_name : str
30
- Name of the model.
28
+ model_configuration : Union[UNetModel, VAEModel]
29
+ Model configuration
31
30
 
32
31
  Returns
33
32
  -------
@@ -37,215 +36,16 @@ def model_registry(model_name: str) -> torch.nn.Module:
37
36
  Raises
38
37
  ------
39
38
  NotImplementedError
40
- If the requested model is not implemented.
41
- """
42
- if model_name == Models.UNET:
43
- return UNet
44
- else:
45
- raise NotImplementedError(f"Model {model_name} is not implemented")
46
-
47
-
48
- def create_model(
49
- *,
50
- model_path: Optional[Union[str, Path]] = None,
51
- config: Optional[Configuration] = None,
52
- device: Optional[torch.device] = None,
53
- ) -> Tuple[
54
- torch.nn.Module,
55
- torch.optim.Optimizer,
56
- Union[
57
- torch.optim.lr_scheduler.LRScheduler,
58
- torch.optim.lr_scheduler.ReduceLROnPlateau, # not a subclass of LRScheduler
59
- ],
60
- torch.cuda.amp.GradScaler,
61
- Configuration,
62
- ]:
63
- """
64
- Instantiate a model from a model path or configuration.
65
-
66
- If both path and configuration are provided, the model path is used. The model
67
- path should point to either a checkpoint (created during training) or a model
68
- exported to the bioimage.io format.
69
-
70
- Parameters
71
- ----------
72
- model_path : Optional[Union[str, Path]], optional
73
- Path to a checkpoint or bioimage.io archive, by default None.
74
- config : Optional[Configuration], optional
75
- Configuration, by default None.
76
- device : Optional[torch.device], optional
77
- Torch device, by default None.
78
-
79
- Returns
80
- -------
81
- torch.nn.Module
82
- Instantiated model.
83
-
84
- Raises
85
- ------
86
- ValueError
87
- If the checkpoint path is invalid.
88
- ValueError
89
- If the checkpoint is invalid.
90
- ValueError
91
- If neither checkpoint nor configuration are provided.
39
+ If the requested architecture is not implemented.
92
40
  """
93
- if model_path is not None:
94
- # Create model from checkpoint
95
- model_path = Path(model_path)
96
- if not model_path.exists() or model_path.suffix not in [".pth", ".zip"]:
97
- raise ValueError(
98
- f"Invalid model path: {model_path}. Current working dir: \
99
- {Path.cwd()!s}"
100
- )
101
-
102
- if model_path.suffix == ".zip":
103
- model_path = import_bioimage_model(model_path)
104
-
105
- # Load checkpoint
106
- checkpoint = torch.load(model_path, map_location=device)
107
-
108
- # Load the configuration
109
- if "config" in checkpoint:
110
- config = Configuration(**checkpoint["config"])
111
- algo_config = config.algorithm
112
- model_config = algo_config.model_parameters
113
- model_name = algo_config.model
114
- else:
115
- raise ValueError("Invalid checkpoint format, no configuration found.")
116
-
117
- # Create model
118
- model: torch.nn.Module = model_registry(model_name)(
119
- depth=model_config.depth,
120
- conv_dim=algo_config.get_conv_dim(),
121
- num_channels_init=model_config.num_channels_init,
122
- )
123
- model.to(device)
124
-
125
- # Load the model state dict
126
- if "model_state_dict" in checkpoint:
127
- model.load_state_dict(checkpoint["model_state_dict"])
128
- logger.info("Loaded model state dict")
129
- else:
130
- raise ValueError("Invalid checkpoint format")
131
-
132
- # Load the optimizer and scheduler
133
- optimizer, scheduler = get_optimizer_and_scheduler(
134
- config, model, state_dict=checkpoint
135
- )
136
- scaler = get_grad_scaler(config, state_dict=checkpoint)
137
-
138
- elif config is not None:
139
- # Create model from configuration
140
- algo_config = config.algorithm
141
- model_config = algo_config.model_parameters
142
- model_name = algo_config.model
143
-
144
- # Create model
145
- model = model_registry(model_name)(
146
- depth=model_config.depth,
147
- conv_dim=algo_config.get_conv_dim(),
148
- num_channels_init=model_config.num_channels_init,
149
- )
150
- model.to(device)
151
- optimizer, scheduler = get_optimizer_and_scheduler(config, model)
152
- scaler = get_grad_scaler(config)
153
- logger.info("Engine initialized from configuration")
41
+ if model_configuration.architecture == SupportedArchitecture.UNET:
42
+ return UNet(**model_configuration.model_dump())
43
+ elif model_configuration.architecture == SupportedArchitecture.CUSTOM:
44
+ assert isinstance(model_configuration, CustomModel)
45
+ model = get_custom_model(model_configuration.name)
154
46
 
47
+ return model(**model_configuration.model_dump())
155
48
  else:
156
- raise ValueError("Either config or model_path must be provided")
157
-
158
- return model, optimizer, scheduler, scaler, config
159
-
160
-
161
- def get_optimizer_and_scheduler(
162
- config: Configuration, model: torch.nn.Module, state_dict: Optional[Dict] = None
163
- ) -> Tuple[
164
- torch.optim.Optimizer,
165
- Union[
166
- torch.optim.lr_scheduler.LRScheduler,
167
- torch.optim.lr_scheduler.ReduceLROnPlateau, # not a subclass of LRScheduler
168
- ],
169
- ]:
170
- """
171
- Create optimizer and learning rate schedulers.
172
-
173
- If a checkpoint state dictionary is provided, the optimizer and scheduler are
174
- instantiated to the same state as the checkpoint's optimizer and scheduler.
175
-
176
- Parameters
177
- ----------
178
- config : Configuration
179
- Configuration.
180
- model : torch.nn.Module
181
- Model.
182
- state_dict : Optional[Dict], optional
183
- Checkpoint state dictionary, by default None.
184
-
185
- Returns
186
- -------
187
- Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]
188
- Optimizer and scheduler.
189
- """
190
- # retrieve optimizer name and parameters from config
191
- optimizer_name = config.training.optimizer.name
192
- optimizer_params = config.training.optimizer.parameters
193
-
194
- # then instantiate it
195
- optimizer_func = getattr(torch.optim, optimizer_name)
196
- optimizer = optimizer_func(model.parameters(), **optimizer_params)
197
-
198
- # same for learning rate scheduler
199
- scheduler_name = config.training.lr_scheduler.name
200
- scheduler_params = config.training.lr_scheduler.parameters
201
- scheduler_func = getattr(torch.optim.lr_scheduler, scheduler_name)
202
- scheduler = scheduler_func(optimizer, **scheduler_params)
203
-
204
- # load state from ther checkpoint if available
205
- if state_dict is not None:
206
- if "optimizer_state_dict" in state_dict:
207
- optimizer.load_state_dict(state_dict["optimizer_state_dict"])
208
- logger.info("Loaded optimizer state dict")
209
- else:
210
- logger.warning(
211
- "No optimizer state dict found in checkpoint. Optimizer not loaded."
212
- )
213
- if "scheduler_state_dict" in state_dict:
214
- scheduler.load_state_dict(state_dict["scheduler_state_dict"])
215
- logger.info("Loaded LR scheduler state dict")
216
- else:
217
- logger.warning(
218
- "No LR scheduler state dict found in checkpoint. "
219
- "LR scheduler not loaded."
220
- )
221
- return optimizer, scheduler
222
-
223
-
224
- def get_grad_scaler(
225
- config: Configuration, state_dict: Optional[Dict] = None
226
- ) -> torch.cuda.amp.GradScaler:
227
- """
228
- Instantiate gradscaler.
229
-
230
- If a checkpoint state dictionary is provided, the scaler is instantiated to the
231
- same state as the checkpoint's scaler.
232
-
233
- Parameters
234
- ----------
235
- config : Configuration
236
- Configuration.
237
- state_dict : Optional[Dict], optional
238
- Checkpoint state dictionary, by default None.
239
-
240
- Returns
241
- -------
242
- torch.cuda.amp.GradScaler
243
- Instantiated gradscaler.
244
- """
245
- use = config.training.amp.use
246
- scaling = config.training.amp.init_scale
247
- scaler = torch.cuda.amp.GradScaler(init_scale=scaling, enabled=use)
248
- if state_dict is not None and "scaler_state_dict" in state_dict:
249
- scaler.load_state_dict(state_dict["scaler_state_dict"])
250
- logger.info("Loaded GradScaler state dict")
251
- return scaler
49
+ raise NotImplementedError(
50
+ f"Model {model_configuration.architecture} is not implemented or unknown."
51
+ )