careamics 0.0.1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of careamics has been flagged as possibly problematic.
- careamics/__init__.py +7 -1
- careamics/bioimage/__init__.py +15 -0
- careamics/bioimage/docs/Noise2Void.md +5 -0
- careamics/bioimage/docs/__init__.py +1 -0
- careamics/bioimage/io.py +182 -0
- careamics/bioimage/rdf.py +105 -0
- careamics/config/__init__.py +11 -0
- careamics/config/algorithm.py +231 -0
- careamics/config/config.py +297 -0
- careamics/config/config_filter.py +44 -0
- careamics/config/data.py +194 -0
- careamics/config/torch_optim.py +118 -0
- careamics/config/training.py +534 -0
- careamics/dataset/__init__.py +1 -0
- careamics/dataset/dataset_utils.py +111 -0
- careamics/dataset/extraction_strategy.py +21 -0
- careamics/dataset/in_memory_dataset.py +202 -0
- careamics/dataset/patching.py +492 -0
- careamics/dataset/prepare_dataset.py +175 -0
- careamics/dataset/tiff_dataset.py +212 -0
- careamics/engine.py +1014 -0
- careamics/losses/__init__.py +4 -0
- careamics/losses/loss_factory.py +38 -0
- careamics/losses/losses.py +34 -0
- careamics/manipulation/__init__.py +4 -0
- careamics/manipulation/pixel_manipulation.py +158 -0
- careamics/models/__init__.py +4 -0
- careamics/models/layers.py +152 -0
- careamics/models/model_factory.py +251 -0
- careamics/models/unet.py +322 -0
- careamics/prediction/__init__.py +9 -0
- careamics/prediction/prediction_utils.py +106 -0
- careamics/utils/__init__.py +20 -0
- careamics/utils/ascii_logo.txt +9 -0
- careamics/utils/augment.py +65 -0
- careamics/utils/context.py +45 -0
- careamics/utils/logging.py +321 -0
- careamics/utils/metrics.py +160 -0
- careamics/utils/normalization.py +55 -0
- careamics/utils/torch_utils.py +89 -0
- careamics/utils/validators.py +170 -0
- careamics/utils/wandb.py +121 -0
- careamics-0.1.0rc2.dist-info/METADATA +81 -0
- careamics-0.1.0rc2.dist-info/RECORD +47 -0
- {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
careamics/losses/loss_factory.py
@@ -0,0 +1,38 @@
+"""
+Loss factory module.
+
+This module contains a factory function for creating loss functions.
+"""
+from typing import Callable
+
+from careamics.config import Configuration
+from careamics.config.algorithm import Loss
+
+from .losses import n2v_loss
+
+
+def create_loss_function(config: Configuration) -> Callable:
+    """
+    Create loss function based on Configuration.
+
+    Parameters
+    ----------
+    config : Configuration
+        Configuration.
+
+    Returns
+    -------
+    Callable
+        Loss function.
+
+    Raises
+    ------
+    NotImplementedError
+        If the loss is unknown.
+    """
+    loss_type = config.algorithm.loss
+
+    if loss_type == Loss.N2V:
+        return n2v_loss
+    else:
+        raise NotImplementedError(f"Loss {loss_type} is not yet supported.")
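
For reference, a minimal usage sketch of the new factory (assuming `config` is an already-validated `Configuration` whose algorithm loss is set to N2V, the only loss wired up in this release):

from careamics.losses.loss_factory import create_loss_function

# `config` is assumed to be a valid careamics Configuration with
# config.algorithm.loss == Loss.N2V; any other value raises
# NotImplementedError in this release.
loss_fn = create_loss_function(config)

# The returned callable is n2v_loss, applied to batches as:
# loss = loss_fn(samples, labels, masks, device)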
careamics/losses/losses.py
@@ -0,0 +1,34 @@
+"""
+Loss submodule.
+
+This submodule contains the various losses used in CAREamics.
+"""
+import torch
+
+
+def n2v_loss(
+    samples: torch.Tensor, labels: torch.Tensor, masks: torch.Tensor, device: str
+) -> torch.Tensor:
+    """
+    N2V loss function (see Eq. 7 in Krull et al.).
+
+    Parameters
+    ----------
+    samples : torch.Tensor
+        Patches with manipulated pixels.
+    labels : torch.Tensor
+        Noisy patches.
+    masks : torch.Tensor
+        Array containing masked pixel locations.
+    device : str
+        Device to use.
+
+    Returns
+    -------
+    torch.Tensor
+        Loss value.
+    """
+    errors = (labels - samples) ** 2
+    # Average over masked pixels and the batch
+    loss = torch.sum(errors * masks) / torch.sum(masks)
+    return loss
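
A small worked example of the masked MSE above (a sketch; tensors and values are illustrative only — only the two masked pixels contribute to the loss):

import torch

from careamics.losses.losses import n2v_loss

samples = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])  # patches with manipulated pixels
labels = torch.tensor([[[1.5, 2.0], [3.0, 5.0]]])   # noisy patches
masks = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])    # masked pixel locations

loss = n2v_loss(samples, labels, masks, device="cpu")
# ((1.5 - 1.0)**2 + (5.0 - 4.0)**2) / 2 = 0.625
print(loss)  # tensor(0.6250)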
careamics/manipulation/pixel_manipulation.py
@@ -0,0 +1,158 @@
+"""
+Pixel manipulation methods.
+
+Pixel manipulation is used in N2V and similar algorithms to replace the value of
+masked pixels.
+"""
+from typing import Callable, Optional, Tuple
+
+import numpy as np
+
+
+def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
+    """
+    Randomly sample a jitter to be applied to the masking grid.
+
+    This is done to account for cases where the step size is not an integer.
+
+    Parameters
+    ----------
+    step : float
+        Step size of the grid, output of np.linspace.
+    rng : np.random.Generator
+        Random number generator.
+
+    Returns
+    -------
+    np.ndarray
+        Array of random jitter to be added to the grid.
+    """
+    # Define the random jitter to be added to the grid
+    odd_jitter = np.where(np.floor(step) == step, 0, rng.integers(0, 2))
+
+    # Round the step size to the nearest integer depending on the jitter
+    return np.floor(step) if odd_jitter == 0 else np.ceil(step)
+
+
+def get_stratified_coords(
+    mask_pixel_perc: float,
+    shape: Tuple[int, ...],
+) -> np.ndarray:
+    """
+    Generate coordinates of the pixels to mask.
+
+    Randomly selects the coordinates of the pixels to mask in a stratified way, i.e.
+    the distance between masked pixels is approximately the same.
+
+    Parameters
+    ----------
+    mask_pixel_perc : float
+        Actual (quasi) percentage of masked pixels across the whole image. Used in
+        calculating the distance between masked pixels across each axis.
+    shape : Tuple[int, ...]
+        Shape of the input patch.
+
+    Returns
+    -------
+    np.ndarray
+        Array of coordinates of the masked pixels.
+    """
+    rng = np.random.default_rng()
+
+    # Define the approximate distance between masked pixels
+    mask_pixel_distance = np.round((100 / mask_pixel_perc) ** (1 / len(shape))).astype(
+        np.int32
+    )
+
+    # Define a grid of coordinates for each axis in the input patch and the step size
+    pixel_coords = []
+    for axis_size in shape:
+        # Number of masked pixels along the axis
+        num_pixels = int(np.ceil(axis_size / mask_pixel_distance))
+        axis_pixel_coords, step = np.linspace(
+            0, axis_size, num_pixels, dtype=np.int32, endpoint=False, retstep=True
+        )
+        # Record the evenly spaced coordinates along this axis
+        pixel_coords.append(axis_pixel_coords.T)
+
+    # Create a meshgrid of coordinates for each axis in the input patch
+    coordinate_grid_list = np.meshgrid(*pixel_coords)
+    coordinate_grid = np.array(coordinate_grid_list).reshape(len(shape), -1).T
+
+    grid_random_increment = rng.integers(
+        _odd_jitter_func(float(step), rng)
+        * np.ones_like(coordinate_grid).astype(np.int32)
+        - 1,
+        size=coordinate_grid.shape,
+        endpoint=True,
+    )
+    coordinate_grid += grid_random_increment
+    coordinate_grid = np.clip(coordinate_grid, 0, np.array(shape) - 1)
+    return coordinate_grid
+
+
+def default_manipulate(
+    patch: np.ndarray,
+    mask_pixel_percentage: float,
+    roi_size: int = 11,
+    augmentations: Optional[Callable] = None,
+) -> Tuple[np.ndarray, ...]:
+    """
+    Manipulate pixels in a patch, i.e. replace the masked values.
+
+    Parameters
+    ----------
+    patch : np.ndarray
+        Image patch, 2D or 3D, shape (y, x) or (z, y, x).
+    mask_pixel_percentage : float
+        Approximate percentage of pixels to be masked.
+    roi_size : int
+        Size of the ROI the new pixel value is sampled from, by default 11.
+    augmentations : Callable, optional
+        Augmentations to apply, by default None.
+
+    Returns
+    -------
+    Tuple[np.ndarray, ...]
+        Tuple containing the manipulated patch, the original patch and the mask.
+    """
+    original_patch = patch.copy()
+
+    # Get the coordinates of the pixels to be replaced
+    roi_centers = get_stratified_coords(mask_pixel_percentage, patch.shape)
+    rng = np.random.default_rng()
+
+    # Generate coordinate grid for the ROI
+    roi_span_full = np.arange(-np.floor(roi_size / 2), np.ceil(roi_size / 2)).astype(
+        np.int32
+    )
+    # Remove the center pixel from the grid
+    roi_span_wo_center = roi_span_full[roi_span_full != 0]
+
+    # Randomly select coordinates from the grid
+    random_increment = rng.choice(roi_span_wo_center, size=roi_centers.shape)
+
+    # Clip the coordinates to the patch size
+    replacement_coords = np.clip(
+        roi_centers + random_increment,
+        0,
+        [patch.shape[i] - 1 for i in range(len(patch.shape))],
+    )
+    # Get the replacement pixels from all ROIs
+    replacement_pixels = patch[tuple(replacement_coords.T.tolist())]
+
+    # Replace the original pixels with the replacement pixels
+    patch[tuple(roi_centers.T.tolist())] = replacement_pixels
+    mask = np.where(patch != original_patch, 1, 0).astype(np.uint8)
+
+    patch, original_patch, mask = (
+        (patch, original_patch, mask)
+        if augmentations is None
+        else augmentations(patch, original_patch, mask)
+    )
+
+    return (
+        np.expand_dims(patch, 0),
+        np.expand_dims(original_patch, 0),
+        np.expand_dims(mask, 0),
+    )
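
A short sketch of how the masking entry point might be exercised (the patch, seed and percentage are illustrative; note that `default_manipulate` mutates its input in place):

import numpy as np

from careamics.manipulation.pixel_manipulation import default_manipulate

rng = np.random.default_rng(seed=0)
patch = rng.normal(size=(64, 64)).astype(np.float32)  # 2D (y, x) patch

# Mask roughly 0.2% of the pixels, replacing each center with a random
# neighbor drawn from an 11x11 ROI around it.
masked, original, mask = default_manipulate(patch, mask_pixel_percentage=0.2)

print(masked.shape, original.shape, mask.shape)  # (1, 64, 64) each
print(mask.sum())  # pixels whose value actually changed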
careamics/models/layers.py
@@ -0,0 +1,152 @@
+"""
+Layer module.
+
+This submodule contains layers used in the CAREamics models.
+"""
+import torch
+import torch.nn as nn
+
+
+class Conv_Block(nn.Module):
+    """
+    Convolution block used in UNets.
+
+    The convolution block consists of two convolution layers with optional batch norm,
+    dropout, and a final activation function.
+
+    The parameters are directly mapped to PyTorch Conv2d and Conv3d parameters, see
+    PyTorch torch.nn.Conv2d and torch.nn.Conv3d for more information.
+
+    Parameters
+    ----------
+    conv_dim : int
+        Number of dimensions of the convolutions, 2 or 3.
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    intermediate_channel_multiplier : int, optional
+        Multiplier for the number of output channels, by default 1.
+    stride : int, optional
+        Stride of the convolutions, by default 1.
+    padding : int, optional
+        Padding of the convolutions, by default 1.
+    bias : bool, optional
+        Bias of the convolutions, by default True.
+    groups : int, optional
+        Controls the connections between inputs and outputs, by default 1.
+    activation : str, optional
+        Activation function, by default "ReLU".
+    dropout_perc : float, optional
+        Dropout percentage, by default 0.
+    use_batch_norm : bool, optional
+        Use batch norm, by default False.
+    """
+
+    def __init__(
+        self,
+        conv_dim: int,
+        in_channels: int,
+        out_channels: int,
+        intermediate_channel_multiplier: int = 1,
+        stride: int = 1,
+        padding: int = 1,
+        bias: bool = True,
+        groups: int = 1,
+        activation: str = "ReLU",
+        dropout_perc: float = 0,
+        use_batch_norm: bool = False,
+    ) -> None:
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        conv_dim : int
+            Number of dimensions of the convolutions, 2 or 3.
+        in_channels : int
+            Number of input channels.
+        out_channels : int
+            Number of output channels.
+        intermediate_channel_multiplier : int, optional
+            Multiplier for the number of output channels, by default 1.
+        stride : int, optional
+            Stride of the convolutions, by default 1.
+        padding : int, optional
+            Padding of the convolutions, by default 1.
+        bias : bool, optional
+            Bias of the convolutions, by default True.
+        groups : int, optional
+            Controls the connections between inputs and outputs, by default 1.
+        activation : str, optional
+            Activation function, by default "ReLU".
+        dropout_perc : float, optional
+            Dropout percentage, by default 0.
+        use_batch_norm : bool, optional
+            Use batch norm, by default False.
+        """
+        super().__init__()
+        self.use_batch_norm = use_batch_norm
+        self.conv1 = getattr(nn, f"Conv{conv_dim}d")(
+            in_channels,
+            out_channels * intermediate_channel_multiplier,
+            kernel_size=3,
+            stride=stride,
+            padding=padding,
+            bias=bias,
+            groups=groups,
+        )
+
+        self.conv2 = getattr(nn, f"Conv{conv_dim}d")(
+            out_channels * intermediate_channel_multiplier,
+            out_channels,
+            kernel_size=3,
+            stride=stride,
+            padding=padding,
+            bias=bias,
+            groups=groups,
+        )
+
+        self.batch_norm1 = getattr(nn, f"BatchNorm{conv_dim}d")(
+            out_channels * intermediate_channel_multiplier
+        )
+        self.batch_norm2 = getattr(nn, f"BatchNorm{conv_dim}d")(out_channels)
+
+        self.dropout = (
+            getattr(nn, f"Dropout{conv_dim}d")(dropout_perc)
+            if dropout_perc > 0
+            else None
+        )
+        self.activation = (
+            getattr(nn, f"{activation}")() if activation is not None else nn.Identity()
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor.
+        """
+        if self.use_batch_norm:
+            x = self.conv1(x)
+            x = self.batch_norm1(x)
+            x = self.activation(x)
+            x = self.conv2(x)
+            x = self.batch_norm2(x)
+            x = self.activation(x)
+        else:
+            x = self.conv1(x)
+            x = self.activation(x)
+            x = self.conv2(x)
+            x = self.activation(x)
+        if self.dropout is not None:
+            x = self.dropout(x)
+        return x
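
As a quick smoke test of the block (a sketch; shapes and channel counts are illustrative):

import torch

from careamics.models.layers import Conv_Block

# 2D block mapping 1 channel to 16 channels, with batch norm enabled.
block = Conv_Block(conv_dim=2, in_channels=1, out_channels=16, use_batch_norm=True)

x = torch.rand(8, 1, 64, 64)  # (batch, channels, y, x)
y = block(x)
print(y.shape)  # torch.Size([8, 16, 64, 64]) -- kernel 3 with padding 1 preserves size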
careamics/models/model_factory.py
@@ -0,0 +1,251 @@
+"""
+Model factory.
+
+Model creation factory functions.
+"""
+from pathlib import Path
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+
+from careamics.bioimage import import_bioimage_model
+from careamics.config import Configuration
+from careamics.config.algorithm import Models
+from careamics.utils.logging import get_logger
+
+from .unet import UNet
+
+logger = get_logger(__name__)
+
+
+def model_registry(model_name: str) -> torch.nn.Module:
+    """
+    Model factory.
+
+    Supported models are defined in config.algorithm.Models.
+
+    Parameters
+    ----------
+    model_name : str
+        Name of the model.
+
+    Returns
+    -------
+    torch.nn.Module
+        Model class.
+
+    Raises
+    ------
+    NotImplementedError
+        If the requested model is not implemented.
+    """
+    if model_name == Models.UNET:
+        return UNet
+    else:
+        raise NotImplementedError(f"Model {model_name} is not implemented")
+
+
+def create_model(
+    *,
+    model_path: Optional[Union[str, Path]] = None,
+    config: Optional[Configuration] = None,
+    device: Optional[torch.device] = None,
+) -> Tuple[
+    torch.nn.Module,
+    torch.optim.Optimizer,
+    Union[
+        torch.optim.lr_scheduler.LRScheduler,
+        torch.optim.lr_scheduler.ReduceLROnPlateau,  # not a subclass of LRScheduler
+    ],
+    torch.cuda.amp.GradScaler,
+    Configuration,
+]:
+    """
+    Instantiate a model from a model path or configuration.
+
+    If both path and configuration are provided, the model path is used. The model
+    path should point to either a checkpoint (created during training) or a model
+    exported to the bioimage.io format.
+
+    Parameters
+    ----------
+    model_path : Optional[Union[str, Path]], optional
+        Path to a checkpoint or bioimage.io archive, by default None.
+    config : Optional[Configuration], optional
+        Configuration, by default None.
+    device : Optional[torch.device], optional
+        Torch device, by default None.
+
+    Returns
+    -------
+    Tuple[torch.nn.Module, ...]
+        Instantiated model, optimizer, scheduler, grad scaler and configuration.
+
+    Raises
+    ------
+    ValueError
+        If the checkpoint path is invalid.
+    ValueError
+        If the checkpoint is invalid.
+    ValueError
+        If neither checkpoint nor configuration is provided.
+    """
+    if model_path is not None:
+        # Create model from checkpoint
+        model_path = Path(model_path)
+        if not model_path.exists() or model_path.suffix not in [".pth", ".zip"]:
+            raise ValueError(
+                f"Invalid model path: {model_path}. Current working dir: \
+                {Path.cwd()!s}"
+            )
+
+        if model_path.suffix == ".zip":
+            model_path = import_bioimage_model(model_path)
+
+        # Load checkpoint
+        checkpoint = torch.load(model_path, map_location=device)
+
+        # Load the configuration
+        if "config" in checkpoint:
+            config = Configuration(**checkpoint["config"])
+            algo_config = config.algorithm
+            model_config = algo_config.model_parameters
+            model_name = algo_config.model
+        else:
+            raise ValueError("Invalid checkpoint format, no configuration found.")
+
+        # Create model
+        model: torch.nn.Module = model_registry(model_name)(
+            depth=model_config.depth,
+            conv_dim=algo_config.get_conv_dim(),
+            num_channels_init=model_config.num_channels_init,
+        )
+        model.to(device)
+
+        # Load the model state dict
+        if "model_state_dict" in checkpoint:
+            model.load_state_dict(checkpoint["model_state_dict"])
+            logger.info("Loaded model state dict")
+        else:
+            raise ValueError("Invalid checkpoint format")
+
+        # Load the optimizer and scheduler
+        optimizer, scheduler = get_optimizer_and_scheduler(
+            config, model, state_dict=checkpoint
+        )
+        scaler = get_grad_scaler(config, state_dict=checkpoint)
+
+    elif config is not None:
+        # Create model from configuration
+        algo_config = config.algorithm
+        model_config = algo_config.model_parameters
+        model_name = algo_config.model
+
+        # Create model
+        model = model_registry(model_name)(
+            depth=model_config.depth,
+            conv_dim=algo_config.get_conv_dim(),
+            num_channels_init=model_config.num_channels_init,
+        )
+        model.to(device)
+        optimizer, scheduler = get_optimizer_and_scheduler(config, model)
+        scaler = get_grad_scaler(config)
+        logger.info("Engine initialized from configuration")
+
+    else:
+        raise ValueError("Either config or model_path must be provided")
+
+    return model, optimizer, scheduler, scaler, config
+
+
+def get_optimizer_and_scheduler(
+    config: Configuration, model: torch.nn.Module, state_dict: Optional[Dict] = None
+) -> Tuple[
+    torch.optim.Optimizer,
+    Union[
+        torch.optim.lr_scheduler.LRScheduler,
+        torch.optim.lr_scheduler.ReduceLROnPlateau,  # not a subclass of LRScheduler
+    ],
+]:
+    """
+    Create optimizer and learning rate schedulers.
+
+    If a checkpoint state dictionary is provided, the optimizer and scheduler are
+    instantiated to the same state as the checkpoint's optimizer and scheduler.
+
+    Parameters
+    ----------
+    config : Configuration
+        Configuration.
+    model : torch.nn.Module
+        Model.
+    state_dict : Optional[Dict], optional
+        Checkpoint state dictionary, by default None.
+
+    Returns
+    -------
+    Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]
+        Optimizer and scheduler.
+    """
+    # retrieve optimizer name and parameters from config
+    optimizer_name = config.training.optimizer.name
+    optimizer_params = config.training.optimizer.parameters
+
+    # then instantiate it
+    optimizer_func = getattr(torch.optim, optimizer_name)
+    optimizer = optimizer_func(model.parameters(), **optimizer_params)
+
+    # same for learning rate scheduler
+    scheduler_name = config.training.lr_scheduler.name
+    scheduler_params = config.training.lr_scheduler.parameters
+    scheduler_func = getattr(torch.optim.lr_scheduler, scheduler_name)
+    scheduler = scheduler_func(optimizer, **scheduler_params)
+
+    # load state from the checkpoint if available
+    if state_dict is not None:
+        if "optimizer_state_dict" in state_dict:
+            optimizer.load_state_dict(state_dict["optimizer_state_dict"])
+            logger.info("Loaded optimizer state dict")
+        else:
+            logger.warning(
+                "No optimizer state dict found in checkpoint. Optimizer not loaded."
+            )
+        if "scheduler_state_dict" in state_dict:
+            scheduler.load_state_dict(state_dict["scheduler_state_dict"])
+            logger.info("Loaded LR scheduler state dict")
+        else:
+            logger.warning(
+                "No LR scheduler state dict found in checkpoint. "
+                "LR scheduler not loaded."
+            )
+    return optimizer, scheduler
+
+
+def get_grad_scaler(
+    config: Configuration, state_dict: Optional[Dict] = None
+) -> torch.cuda.amp.GradScaler:
+    """
+    Instantiate a GradScaler.
+
+    If a checkpoint state dictionary is provided, the scaler is instantiated to the
+    same state as the checkpoint's scaler.
+
+    Parameters
+    ----------
+    config : Configuration
+        Configuration.
+    state_dict : Optional[Dict], optional
+        Checkpoint state dictionary, by default None.
+
+    Returns
+    -------
+    torch.cuda.amp.GradScaler
+        Instantiated GradScaler.
+    """
+    use = config.training.amp.use
+    scaling = config.training.amp.init_scale
+    scaler = torch.cuda.amp.GradScaler(init_scale=scaling, enabled=use)
+    if state_dict is not None and "scaler_state_dict" in state_dict:
+        scaler.load_state_dict(state_dict["scaler_state_dict"])
+        logger.info("Loaded GradScaler state dict")
+    return scaler
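
Finally, a sketch of the two ways `create_model` can be invoked (the checkpoint path is a placeholder, and a valid `Configuration` is assumed for the second call):

import torch

from careamics.models.model_factory import create_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# From a training checkpoint (.pth) or a bioimage.io archive (.zip);
# "checkpoint.pth" is a placeholder path.
model, optimizer, scheduler, scaler, config = create_model(
    model_path="checkpoint.pth", device=device
)

# Or from a configuration alone (config assumed to be a valid Configuration):
# model, optimizer, scheduler, scaler, config = create_model(config=config, device=device)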