careamics 0.0.1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (48) hide show
  1. careamics/__init__.py +7 -1
  2. careamics/bioimage/__init__.py +15 -0
  3. careamics/bioimage/docs/Noise2Void.md +5 -0
  4. careamics/bioimage/docs/__init__.py +1 -0
  5. careamics/bioimage/io.py +182 -0
  6. careamics/bioimage/rdf.py +105 -0
  7. careamics/config/__init__.py +11 -0
  8. careamics/config/algorithm.py +231 -0
  9. careamics/config/config.py +297 -0
  10. careamics/config/config_filter.py +44 -0
  11. careamics/config/data.py +194 -0
  12. careamics/config/torch_optim.py +118 -0
  13. careamics/config/training.py +534 -0
  14. careamics/dataset/__init__.py +1 -0
  15. careamics/dataset/dataset_utils.py +111 -0
  16. careamics/dataset/extraction_strategy.py +21 -0
  17. careamics/dataset/in_memory_dataset.py +202 -0
  18. careamics/dataset/patching.py +492 -0
  19. careamics/dataset/prepare_dataset.py +175 -0
  20. careamics/dataset/tiff_dataset.py +212 -0
  21. careamics/engine.py +1014 -0
  22. careamics/losses/__init__.py +4 -0
  23. careamics/losses/loss_factory.py +38 -0
  24. careamics/losses/losses.py +34 -0
  25. careamics/manipulation/__init__.py +4 -0
  26. careamics/manipulation/pixel_manipulation.py +158 -0
  27. careamics/models/__init__.py +4 -0
  28. careamics/models/layers.py +152 -0
  29. careamics/models/model_factory.py +251 -0
  30. careamics/models/unet.py +322 -0
  31. careamics/prediction/__init__.py +9 -0
  32. careamics/prediction/prediction_utils.py +106 -0
  33. careamics/utils/__init__.py +20 -0
  34. careamics/utils/ascii_logo.txt +9 -0
  35. careamics/utils/augment.py +65 -0
  36. careamics/utils/context.py +45 -0
  37. careamics/utils/logging.py +321 -0
  38. careamics/utils/metrics.py +160 -0
  39. careamics/utils/normalization.py +55 -0
  40. careamics/utils/torch_utils.py +89 -0
  41. careamics/utils/validators.py +170 -0
  42. careamics/utils/wandb.py +121 -0
  43. careamics-0.1.0rc2.dist-info/METADATA +81 -0
  44. careamics-0.1.0rc2.dist-info/RECORD +47 -0
  45. {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/WHEEL +1 -1
  46. {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/licenses/LICENSE +1 -1
  47. careamics-0.0.1.dist-info/METADATA +0 -46
  48. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,4 @@
1
+ """Losses module."""
2
+
3
+
4
+ from .loss_factory import create_loss_function as create_loss_function
@@ -0,0 +1,38 @@
1
+ """
2
+ Loss factory module.
3
+
4
+ This module contains a factory function for creating loss functions.
5
+ """
6
+ from typing import Callable
7
+
8
+ from careamics.config import Configuration
9
+ from careamics.config.algorithm import Loss
10
+
11
+ from .losses import n2v_loss
12
+
13
+
14
def create_loss_function(config: Configuration) -> Callable:
    """
    Return the loss callable selected by the configuration.

    Parameters
    ----------
    config : Configuration
        Configuration; its ``algorithm.loss`` entry names the loss to use.

    Returns
    -------
    Callable
        Loss function.

    Raises
    ------
    NotImplementedError
        If the configured loss is not supported.
    """
    requested_loss = config.algorithm.loss

    # N2V is the only loss wired in here; everything else is rejected below.
    if requested_loss == Loss.N2V:
        return n2v_loss

    raise NotImplementedError(f"Loss {requested_loss} is not yet supported.")
@@ -0,0 +1,34 @@
1
+ """
2
+ Loss submodule.
3
+
4
+ This submodule contains the various losses used in CAREamics.
5
+ """
6
+ import torch
7
+
8
+
9
def n2v_loss(
    samples: torch.Tensor, labels: torch.Tensor, masks: torch.Tensor, device: str
) -> torch.Tensor:
    """
    N2V Loss function (see Eq.7 in Krull et al).

    The loss is the squared error between ``labels`` and ``samples``, summed over
    the positions selected by ``masks`` and divided by the sum of ``masks``.

    Parameters
    ----------
    samples : torch.Tensor
        Patches with manipulated pixels.
    labels : torch.Tensor
        Noisy patches.
    masks : torch.Tensor
        Array containing masked pixel locations.
    device : str
        Device to use. Not read by this function; kept so the signature stays
        compatible with existing callers.

    Returns
    -------
    torch.Tensor
        Loss value. If ``masks`` sums to zero the loss is 0 rather than NaN.
    """
    squared_errors = (labels - samples) ** 2
    mask_total = torch.sum(masks)

    # An all-zero mask would otherwise produce 0 / 0 = NaN and poison the
    # gradients of the whole training step; divide by 1 in that case so the
    # (already zero) numerator yields a zero loss.
    safe_total = torch.where(mask_total == 0, torch.ones_like(mask_total), mask_total)

    # Average over masked pixels and batch
    return torch.sum(squared_errors * masks) / safe_total
@@ -0,0 +1,4 @@
1
+ """Pixel manipulation functions for N2V."""
2
+
3
+
4
+ from .pixel_manipulation import default_manipulate as default_manipulate
@@ -0,0 +1,158 @@
1
+ """
2
+ Pixel manipulation methods.
3
+
4
+ Pixel manipulation is used in N2V and similar algorithms to replace the value of
5
+ masked pixels.
6
+ """
7
+ from typing import Callable, Optional, Tuple
8
+
9
+ import numpy as np
10
+
11
+
12
def _odd_jitter_func(step: float, rng: np.random.Generator) -> float:
    """
    Randomly round a grid step down or up to a whole number.

    This is done to account for cases where the step size is not an integer:
    a whole-valued step is returned as is, otherwise the floor or the ceiling
    is chosen with equal probability.

    Parameters
    ----------
    step : float
        Step size of the grid, output of np.linspace.
    rng : np.random.Generator
        Random number generator.

    Returns
    -------
    float
        Whole-valued step (a NumPy floating scalar), either ``floor(step)`` or
        ``ceil(step)``.
    """
    # The coin is always drawn, so the generator advances by the same amount
    # whether or not the step is already a whole number.
    coin = rng.integers(0, 2)
    lower = np.floor(step)

    if lower == step or coin == 0:
        return lower
    return np.ceil(step)


def get_stratified_coords(
    mask_pixel_perc: float,
    shape: Tuple[int, ...],
) -> np.ndarray:
    """
    Generate coordinates of the pixels to mask.

    Randomly selects the coordinates of the pixels to mask in a stratified way, i.e.
    the distance between masked pixels is approximately the same.

    Parameters
    ----------
    mask_pixel_perc : float
        Actual (quasi) percentage of masked pixels across the whole image. Used in
        calculating the distance between masked pixels across each axis.
    shape : Tuple[int, ...]
        Shape of the input patch.

    Returns
    -------
    np.ndarray
        Array of coordinates of the masked pixels, one row per pixel and one
        column per axis of ``shape``.

    Raises
    ------
    ValueError
        If ``mask_pixel_perc`` is not strictly positive, or if ``shape`` is empty
        or contains an axis smaller than 1.
    """
    if mask_pixel_perc <= 0:
        raise ValueError(
            f"mask_pixel_perc must be strictly positive (got {mask_pixel_perc})."
        )
    if len(shape) == 0 or any(axis_size < 1 for axis_size in shape):
        raise ValueError(f"shape must be non-empty with positive axes (got {shape}).")

    rng = np.random.default_rng()
    n_dims = len(shape)

    # Approximate distance between masked pixels. Clamped to 1 so that very
    # large percentages cannot round the distance down to 0 and trigger a
    # division by zero below.
    mask_pixel_distance = max(
        1, int(np.round((100 / mask_pixel_perc) ** (1 / n_dims)))
    )

    # Regular grid of anchor coordinates for each axis, together with the
    # largest random offset allowed along that axis. The offset bound is
    # derived from the step of *that* axis, so patches whose axes have
    # different grid steps are jittered correctly on every axis.
    pixel_coords = []
    max_offsets = []
    for axis_size in shape:
        num_pixels = int(np.ceil(axis_size / mask_pixel_distance))
        axis_pixel_coords, step = np.linspace(
            0, axis_size, num_pixels, dtype=np.int32, endpoint=False, retstep=True
        )
        pixel_coords.append(axis_pixel_coords)
        max_offsets.append(int(_odd_jitter_func(float(step), rng)) - 1)

    # All combinations of the per-axis anchors; column k holds axis k values
    coordinate_grid = np.array(np.meshgrid(*pixel_coords)).reshape(n_dims, -1).T

    # Draw an offset in [0, max_offset] (inclusive) for every coordinate
    upper_bounds = np.tile(
        np.asarray(max_offsets, dtype=np.int64), (coordinate_grid.shape[0], 1)
    )
    grid_random_increment = rng.integers(
        upper_bounds,
        size=coordinate_grid.shape,
        endpoint=True,
    )

    coordinate_grid = coordinate_grid + grid_random_increment
    return np.clip(coordinate_grid, 0, np.array(shape) - 1)
92
+
93
+
94
def default_manipulate(
    patch: np.ndarray,
    mask_pixel_percentage: float,
    roi_size: int = 11,
    augmentations: Optional[Callable] = None,
) -> Tuple[np.ndarray, ...]:
    """
    Manipulate pixel in a patch, i.e. replace the masked value.

    Each selected pixel is overwritten with the value of another pixel picked at
    a random non-zero offset (per axis) within a window of ``roi_size`` around it.
    The input array itself is not modified.

    Parameters
    ----------
    patch : np.ndarray
        Image patch, 2D or 3D, shape (y, x) or (z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    roi_size : int
        Size of the ROI the new pixel value is sampled from, by default 11.
    augmentations : Callable, optional
        Augmentations to apply, by default None. Called with the manipulated
        patch, the original patch and the mask, and expected to return the
        same three arrays.

    Returns
    -------
    Tuple[np.ndarray]
        Tuple containing the manipulated patch, the original patch and the mask,
        each with a new leading axis of size 1.
    """
    original_patch = patch.copy()
    # Work on a private copy: writing into `patch` directly would silently
    # modify the caller's array (and whatever storage it is a view of).
    manipulated = original_patch.copy()

    # Get the coordinates of the pixels to be replaced
    roi_centers = get_stratified_coords(mask_pixel_percentage, patch.shape)
    rng = np.random.default_rng()

    # Offsets spanning the ROI along one axis, without the zero offset
    roi_span_full = np.arange(-np.floor(roi_size / 2), np.ceil(roi_size / 2)).astype(
        np.int32
    )
    roi_span_wo_center = roi_span_full[roi_span_full != 0]

    # Randomly select an offset per center and per axis
    random_increment = rng.choice(roi_span_wo_center, size=roi_centers.shape)

    # Clip the coordinates to the patch size
    replacement_coords = np.clip(
        roi_centers + random_increment,
        0,
        np.array(patch.shape) - 1,
    )

    # Gather all replacement values from the untouched data, then write them
    # at the center positions in a single step
    replacement_pixels = original_patch[tuple(replacement_coords.T)]
    manipulated[tuple(roi_centers.T)] = replacement_pixels

    # The mask marks pixels whose value actually changed; a center whose
    # replacement value equals its original value is therefore not flagged.
    mask = np.where(manipulated != original_patch, 1, 0).astype(np.uint8)

    if augmentations is not None:
        manipulated, original_patch, mask = augmentations(
            manipulated, original_patch, mask
        )

    return (
        np.expand_dims(manipulated, 0),
        np.expand_dims(original_patch, 0),
        np.expand_dims(mask, 0),
    )
@@ -0,0 +1,4 @@
1
+ """Models package."""
2
+
3
+ from .model_factory import create_model as create_model
4
+ from .unet import UNet as UNet
@@ -0,0 +1,152 @@
1
+ """
2
+ Layer module.
3
+
4
+ This submodule contains layers used in the CAREamics models.
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
class Conv_Block(nn.Module):
    """
    Convolution block used in UNets.

    The block chains two 3x3 convolutions, each followed by an optional batch
    norm and the activation function, and ends with an optional dropout.

    The parameters are directly mapped to PyTorch Conv2D and Conv3d parameters, see
    PyTorch torch.nn.Conv2d and torch.nn.Conv3d for more information.

    Parameters
    ----------
    conv_dim : int
        Number of dimension of the convolutions, 2 or 3.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    intermediate_channel_multiplier : int, optional
        Multiplier applied to ``out_channels`` to get the number of channels
        between the two convolutions, by default 1.
    stride : int, optional
        Stride of the convolutions, by default 1.
    padding : int, optional
        Padding of the convolutions, by default 1.
    bias : bool, optional
        Bias of the convolutions, by default True.
    groups : int, optional
        Controls the connections between inputs and outputs, by default 1.
    activation : str, optional
        Name of the torch.nn activation class, by default "ReLU".
    dropout_perc : float, optional
        Dropout percentage, by default 0.
    use_batch_norm : bool, optional
        Use batch norm, by default False.
    """

    def __init__(
        self,
        conv_dim: int,
        in_channels: int,
        out_channels: int,
        intermediate_channel_multiplier: int = 1,
        stride: int = 1,
        padding: int = 1,
        bias: bool = True,
        groups: int = 1,
        activation: str = "ReLU",
        dropout_perc: float = 0,
        use_batch_norm: bool = False,
    ) -> None:
        """
        Constructor.

        Parameters
        ----------
        conv_dim : int
            Number of dimension of the convolutions, 2 or 3.
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of output channels.
        intermediate_channel_multiplier : int, optional
            Multiplier applied to ``out_channels`` to get the number of channels
            between the two convolutions, by default 1.
        stride : int, optional
            Stride of the convolutions, by default 1.
        padding : int, optional
            Padding of the convolutions, by default 1.
        bias : bool, optional
            Bias of the convolutions, by default True.
        groups : int, optional
            Controls the connections between inputs and outputs, by default 1.
        activation : str, optional
            Name of the torch.nn activation class, by default "ReLU".
        dropout_perc : float, optional
            Dropout percentage, by default 0.
        use_batch_norm : bool, optional
            Use batch norm, by default False.
        """
        super().__init__()
        self.use_batch_norm = use_batch_norm

        # Resolve the dimension-specific layer classes once
        conv_cls = getattr(nn, f"Conv{conv_dim}d")
        norm_cls = getattr(nn, f"BatchNorm{conv_dim}d")
        hidden_channels = out_channels * intermediate_channel_multiplier

        self.conv1 = conv_cls(
            in_channels,
            hidden_channels,
            kernel_size=3,
            stride=stride,
            padding=padding,
            bias=bias,
            groups=groups,
        )
        self.conv2 = conv_cls(
            hidden_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=padding,
            bias=bias,
            groups=groups,
        )

        # The norm layers are always instantiated; `use_batch_norm` only
        # decides whether forward() applies them.
        self.batch_norm1 = norm_cls(hidden_channels)
        self.batch_norm2 = norm_cls(out_channels)

        if dropout_perc > 0:
            self.dropout = getattr(nn, f"Dropout{conv_dim}d")(dropout_perc)
        else:
            self.dropout = None

        if activation is not None:
            self.activation = getattr(nn, f"{activation}")()
        else:
            self.activation = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        stages = (
            (self.conv1, self.batch_norm1),
            (self.conv2, self.batch_norm2),
        )
        for conv, norm in stages:
            x = conv(x)
            if self.use_batch_norm:
                x = norm(x)
            x = self.activation(x)

        if self.dropout is not None:
            x = self.dropout(x)
        return x
@@ -0,0 +1,251 @@
1
+ """
2
+ Model factory.
3
+
4
+ Model creation factory functions.
5
+ """
6
+ from pathlib import Path
7
+ from typing import Dict, Optional, Tuple, Union
8
+
9
+ import torch
10
+
11
+ from careamics.bioimage import import_bioimage_model
12
+ from careamics.config import Configuration
13
+ from careamics.config.algorithm import Models
14
+ from careamics.utils.logging import get_logger
15
+
16
+ from .unet import UNet
17
+
18
+ logger = get_logger(__name__)
19
+
20
+
21
def model_registry(model_name: str) -> torch.nn.Module:
    """
    Look up the model class registered under a name.

    Supported models are defined in config.algorithm.Models.

    Parameters
    ----------
    model_name : str
        Name of the model.

    Returns
    -------
    torch.nn.Module
        Model class (not an instance); the caller instantiates it.

    Raises
    ------
    NotImplementedError
        If the requested model is not implemented.
    """
    # UNet is the only architecture handled here
    if model_name == Models.UNET:
        return UNet

    raise NotImplementedError(f"Model {model_name} is not implemented")
46
+
47
+
48
def create_model(
    *,
    model_path: Optional[Union[str, Path]] = None,
    config: Optional[Configuration] = None,
    device: Optional[torch.device] = None,
) -> Tuple[
    torch.nn.Module,
    torch.optim.Optimizer,
    Union[
        torch.optim.lr_scheduler.LRScheduler,
        torch.optim.lr_scheduler.ReduceLROnPlateau,  # not a subclass of LRScheduler
    ],
    torch.cuda.amp.GradScaler,
    Configuration,
]:
    """
    Instantiate a model from a model path or configuration.

    If both path and configuration are provided, the model path is used. The model
    path should point to either a checkpoint (created during training) or a model
    exported to the bioimage.io format.

    Parameters
    ----------
    model_path : Optional[Union[str, Path]], optional
        Path to a checkpoint or bioimage.io archive, by default None.
    config : Optional[Configuration], optional
        Configuration, by default None.
    device : Optional[torch.device], optional
        Torch device, by default None.

    Returns
    -------
    Tuple
        The instantiated model, its optimizer, the learning rate scheduler, the
        gradient scaler and the configuration that was used (the one stored in
        the checkpoint when ``model_path`` is given).

    Raises
    ------
    ValueError
        If the checkpoint path is invalid.
    ValueError
        If the checkpoint is invalid.
    ValueError
        If neither checkpoint nor configuration are provided.
    """
    checkpoint: Optional[Dict] = None

    if model_path is not None:
        # Resolve the checkpoint; its stored configuration replaces `config`
        model_path = Path(model_path)
        if not model_path.exists() or model_path.suffix not in (".pth", ".zip"):
            # Adjacent literals instead of a backslash continuation, which
            # embedded the next line's indentation in the message.
            raise ValueError(
                f"Invalid model path: {model_path}. "
                f"Current working dir: {Path.cwd()!s}"
            )

        if model_path.suffix == ".zip":
            model_path = import_bioimage_model(model_path)

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(model_path, map_location=device)

        if "config" not in checkpoint:
            raise ValueError("Invalid checkpoint format, no configuration found.")
        config = Configuration(**checkpoint["config"])

    elif config is None:
        raise ValueError("Either config or model_path must be provided")

    # Build the network described by the configuration
    algo_config = config.algorithm
    model_config = algo_config.model_parameters
    model: torch.nn.Module = model_registry(algo_config.model)(
        depth=model_config.depth,
        conv_dim=algo_config.get_conv_dim(),
        num_channels_init=model_config.num_channels_init,
    )
    model.to(device)

    # Restore the weights when starting from a checkpoint
    if checkpoint is not None:
        if "model_state_dict" not in checkpoint:
            raise ValueError("Invalid checkpoint format")
        model.load_state_dict(checkpoint["model_state_dict"])
        logger.info("Loaded model state dict")

    # With `checkpoint` set to None both helpers create fresh objects
    optimizer, scheduler = get_optimizer_and_scheduler(
        config, model, state_dict=checkpoint
    )
    scaler = get_grad_scaler(config, state_dict=checkpoint)

    if checkpoint is None:
        logger.info("Engine initialized from configuration")

    return model, optimizer, scheduler, scaler, config
159
+
160
+
161
def get_optimizer_and_scheduler(
    config: Configuration, model: torch.nn.Module, state_dict: Optional[Dict] = None
) -> Tuple[
    torch.optim.Optimizer,
    Union[
        torch.optim.lr_scheduler.LRScheduler,
        torch.optim.lr_scheduler.ReduceLROnPlateau,  # not a subclass of LRScheduler
    ],
]:
    """
    Build the optimizer and learning rate scheduler named in the configuration.

    When a checkpoint state dictionary is given, the states saved under
    ``"optimizer_state_dict"`` and ``"scheduler_state_dict"`` are loaded into the
    new objects; a missing entry only triggers a warning.

    Parameters
    ----------
    config : Configuration
        Configuration.
    model : torch.nn.Module
        Model.
    state_dict : Optional[Dict], optional
        Checkpoint state dictionary, by default None.

    Returns
    -------
    Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]
        Optimizer and scheduler.
    """
    # Optimizer: class looked up by name in torch.optim
    optimizer_settings = config.training.optimizer
    optimizer_cls = getattr(torch.optim, optimizer_settings.name)
    optimizer = optimizer_cls(model.parameters(), **optimizer_settings.parameters)

    # Scheduler: class looked up by name in torch.optim.lr_scheduler
    scheduler_settings = config.training.lr_scheduler
    scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_settings.name)
    scheduler = scheduler_cls(optimizer, **scheduler_settings.parameters)

    # Nothing to restore without a checkpoint
    if state_dict is None:
        return optimizer, scheduler

    restore_plan = (
        (
            "optimizer_state_dict",
            optimizer,
            "Loaded optimizer state dict",
            "No optimizer state dict found in checkpoint. Optimizer not loaded.",
        ),
        (
            "scheduler_state_dict",
            scheduler,
            "Loaded LR scheduler state dict",
            "No LR scheduler state dict found in checkpoint. "
            "LR scheduler not loaded.",
        ),
    )
    for key, target, loaded_message, missing_message in restore_plan:
        if key in state_dict:
            target.load_state_dict(state_dict[key])
            logger.info(loaded_message)
        else:
            logger.warning(missing_message)

    return optimizer, scheduler
222
+
223
+
224
def get_grad_scaler(
    config: Configuration, state_dict: Optional[Dict] = None
) -> torch.cuda.amp.GradScaler:
    """
    Create the gradient scaler described by the AMP settings.

    When a checkpoint state dictionary containing ``"scaler_state_dict"`` is
    given, that state is loaded into the new scaler.

    Parameters
    ----------
    config : Configuration
        Configuration.
    state_dict : Optional[Dict], optional
        Checkpoint state dictionary, by default None.

    Returns
    -------
    torch.cuda.amp.GradScaler
        Instantiated gradscaler.
    """
    amp_settings = config.training.amp
    enabled = amp_settings.use
    init_scale = amp_settings.init_scale
    scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, enabled=enabled)

    # No checkpoint, or a checkpoint without scaler state: return as created
    if state_dict is None or "scaler_state_dict" not in state_dict:
        return scaler

    scaler.load_state_dict(state_dict["scaler_state_dict"])
    logger.info("Loaded GradScaler state dict")
    return scaler