careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +16 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +31 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_example.py +89 -0
- careamics/config/configuration_factory.py +597 -0
- careamics/config/configuration_model.py +597 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +743 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +396 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc4.dist-info/METADATA +122 -0
- careamics-0.1.0rc4.dist-info/RECORD +110 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
careamics/models/activation.py
ADDED
@@ -0,0 +1,35 @@
+from typing import Callable, Union
+
+import torch.nn as nn
+
+from ..config.support import SupportedActivation
+
+
+def get_activation(activation: Union[SupportedActivation, str]) -> Callable:
+    """
+    Get activation function.
+
+    Parameters
+    ----------
+    activation : str
+        Activation function name.
+
+    Returns
+    -------
+    Callable
+        Activation function.
+    """
+    if activation == SupportedActivation.RELU:
+        return nn.ReLU()
+    elif activation == SupportedActivation.LEAKYRELU:
+        return nn.LeakyReLU()
+    elif activation == SupportedActivation.TANH:
+        return nn.Tanh()
+    elif activation == SupportedActivation.SIGMOID:
+        return nn.Sigmoid()
+    elif activation == SupportedActivation.SOFTMAX:
+        return nn.Softmax(dim=1)
+    elif activation == SupportedActivation.NONE:
+        return nn.Identity()
+    else:
+        raise ValueError(f"Activation {activation} not supported.")
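For orientation, a minimal usage sketch of the new helper (not part of the diff; it assumes careamics 0.1.0rc4 is installed and that SupportedActivation is a string-valued enum, which the plain-string comparisons above suggest):

    import torch

    from careamics.config.support import SupportedActivation
    from careamics.models.activation import get_activation

    # get_activation returns an instantiated torch.nn module, not a class
    relu = get_activation(SupportedActivation.RELU)
    print(relu(torch.tensor([-1.0, 2.0])))  # tensor([0., 2.])

Note that the factory instantiates the modules itself (e.g. nn.Softmax(dim=1)), so callers receive ready-to-use layers.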
careamics/models/layers.py
CHANGED
@@ -3,8 +3,11 @@ Layer module.
 
 This submodule contains layers used in the CAREamics models.
 """
+from typing import List, Optional, Tuple, Union
+
 import torch
 import torch.nn as nn
+from torch.nn import functional as F
 
 
 class Conv_Block(nn.Module):
@@ -150,3 +153,244 @@
         if self.dropout is not None:
             x = self.dropout(x)
         return x
+
+
+def _unpack_kernel_size(
+    kernel_size: Union[Tuple[int, ...], int], dim: int
+) -> Tuple[int, ...]:
+    """Unpack kernel_size to a tuple of ints.
+
+    Inspired by Kornia implementation. TODO: link
+    """
+    if isinstance(kernel_size, int):
+        kernel_dims = tuple([kernel_size for _ in range(dim)])
+    else:
+        kernel_dims = kernel_size
+    return kernel_dims
+
+
+def _compute_zero_padding(
+    kernel_size: Union[Tuple[int, ...], int], dim: int
+) -> Tuple[int, ...]:
+    """Utility function that computes zero padding tuple."""
+    kernel_dims = _unpack_kernel_size(kernel_size, dim)
+    return tuple([(kd - 1) // 2 for kd in kernel_dims])
+
+
+def get_pascal_kernel_1d(
+    kernel_size: int,
+    norm: bool = False,
+    *,
+    device: Optional[torch.device] = None,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    """Generate Yang Hui triangle (Pascal's triangle) for a given number.
+
+    Inspired by Kornia implementation. TODO link
+
+    Parameters
+    ----------
+    kernel_size: height and width of the kernel.
+    norm: if to normalize the kernel or not. Default: False.
+    device: tensor device
+    dtype: tensor dtype
+
+    Returns
+    -------
+    kernel shaped as :math:`(kernel_size,)`
+
+    Examples
+    --------
+    >>> get_pascal_kernel_1d(1)
+    tensor([1.])
+    >>> get_pascal_kernel_1d(2)
+    tensor([1., 1.])
+    >>> get_pascal_kernel_1d(3)
+    tensor([1., 2., 1.])
+    >>> get_pascal_kernel_1d(4)
+    tensor([1., 3., 3., 1.])
+    >>> get_pascal_kernel_1d(5)
+    tensor([1., 4., 6., 4., 1.])
+    >>> get_pascal_kernel_1d(6)
+    tensor([ 1.,  5., 10., 10.,  5.,  1.])
+    """
+    pre: List[float] = []
+    cur: List[float] = []
+    for i in range(kernel_size):
+        cur = [1.0] * (i + 1)
+
+        for j in range(1, i // 2 + 1):
+            value = pre[j - 1] + pre[j]
+            cur[j] = value
+            if i != 2 * j:
+                cur[-j - 1] = value
+        pre = cur
+
+    out = torch.tensor(cur, device=device, dtype=dtype)
+
+    if norm:
+        out = out / out.sum()
+
+    return out
+
+
+def _get_pascal_kernel_nd(
+    kernel_size: Union[Tuple[int, int], int],
+    norm: bool = True,
+    dim: int = 2,
+    *,
+    device: Optional[torch.device] = None,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    """Generate pascal filter kernel by kernel size.
+
+    Inspired by Kornia implementation.
+
+    Parameters
+    ----------
+    kernel_size: height and width of the kernel.
+    norm: if to normalize the kernel or not. Default: True.
+    device: tensor device
+    dtype: tensor dtype
+
+    Returns
+    -------
+    if kernel_size is an integer the kernel will be shaped as (kernel_size, kernel_size)
+    otherwise the kernel will be shaped as kernel_size
+
+    Examples
+    --------
+    >>> _get_pascal_kernel_nd(1)
+    tensor([[1.]])
+    >>> _get_pascal_kernel_nd(4)
+    tensor([[0.0156, 0.0469, 0.0469, 0.0156],
+            [0.0469, 0.1406, 0.1406, 0.0469],
+            [0.0469, 0.1406, 0.1406, 0.0469],
+            [0.0156, 0.0469, 0.0469, 0.0156]])
+    >>> _get_pascal_kernel_nd(4, norm=False)
+    tensor([[1., 3., 3., 1.],
+            [3., 9., 9., 3.],
+            [3., 9., 9., 3.],
+            [1., 3., 3., 1.]])
+    """
+    kernel_dims = _unpack_kernel_size(kernel_size, dim)
+
+    kernel = [
+        get_pascal_kernel_1d(kd, device=device, dtype=dtype) for kd in kernel_dims
+    ]
+
+    if dim == 2:
+        kernel = kernel[0][:, None] * kernel[1][None, :]
+    elif dim == 3:
+        kernel = (
+            kernel[0][:, None, None]
+            * kernel[1][None, :, None]
+            * kernel[2][None, None, :]
+        )
+    if norm:
+        kernel = kernel / torch.sum(kernel)
+    return kernel
+
+
+def _max_blur_pool_by_kernel2d(
+    x: torch.Tensor,
+    kernel: torch.Tensor,
+    stride: int,
+    max_pool_size: int,
+    ceil_mode: bool,
+) -> torch.Tensor:
+    """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxN` kernel.
+
+    Inspired by Kornia implementation.
+    """
+    # compute local maxima
+    x = F.max_pool2d(
+        x, kernel_size=max_pool_size, padding=0, stride=1, ceil_mode=ceil_mode
+    )
+    # blur and downsample
+    padding = _compute_zero_padding((kernel.shape[-2], kernel.shape[-1]), dim=2)
+    return F.conv2d(x, kernel, padding=padding, stride=stride, groups=x.size(1))
+
+
+def _max_blur_pool_by_kernel3d(
+    x: torch.Tensor,
+    kernel: torch.Tensor,
+    stride: int,
+    max_pool_size: int,
+    ceil_mode: bool,
+) -> torch.Tensor:
+    """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxNxN` kernel.
+
+    Inspired by Kornia implementation.
+    """
+    # compute local maxima
+    x = F.max_pool3d(
+        x, kernel_size=max_pool_size, padding=0, stride=1, ceil_mode=ceil_mode
+    )
+    # blur and downsample
+    padding = _compute_zero_padding(
+        (kernel.shape[-3], kernel.shape[-2], kernel.shape[-1]), dim=3
+    )
+    return F.conv3d(x, kernel, padding=padding, stride=stride, groups=x.size(1))
+
+
+class MaxBlurPool(nn.Module):
+    """Compute pools and blurs and downsample a given feature map.
+
+    Inspired by Kornia MaxBlurPool implementation. Equivalent to
+    ```nn.Sequential(nn.MaxPool2d(...), BlurPool2D(...))```
+
+    Parameters
+    ----------
+    dim: int
+        Toggles between 2D and 3D
+    kernel_size: Union[Tuple[int, int], int]
+        Kernel size for max pooling.
+    stride: int
+        Stride for pooling.
+    max_pool_size: int
+        Max kernel size for max pooling.
+    ceil_mode: bool
+        Should be true to match output size of conv2d with same kernel size.
+
+    Returns
+    -------
+    torch.Tensor
+        The pooled and blurred tensor.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        kernel_size: Union[Tuple[int, int], int],
+        stride: int = 2,
+        max_pool_size: int = 2,
+        ceil_mode: bool = False,
+    ) -> None:
+        super().__init__()
+        self.dim = dim
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.max_pool_size = max_pool_size
+        self.ceil_mode = ceil_mode
+        self.kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of the function."""
+        self.kernel = torch.as_tensor(self.kernel, device=x.device, dtype=x.dtype)
+        if self.dim == 2:
+            return _max_blur_pool_by_kernel2d(
+                x,
+                self.kernel.repeat((x.size(1), 1, 1, 1)),
+                self.stride,
+                self.max_pool_size,
+                self.ceil_mode,
+            )
+        else:
+            return _max_blur_pool_by_kernel3d(
+                x,
+                self.kernel.repeat((x.size(1), 1, 1, 1, 1)),
+                self.stride,
+                self.max_pool_size,
+                self.ceil_mode,
+            )
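The blur kernel built by _get_pascal_kernel_nd is the normalized outer product of 1D binomial rows (e.g. [1., 2., 1.] for kernel_size=3), so MaxBlurPool amounts to a max-pool followed by an anti-aliasing blur and a strided downsampling. A hedged usage sketch (not part of the diff; the module path comes from the file header above, default arguments from the code):

    import torch

    from careamics.models.layers import MaxBlurPool

    pool = MaxBlurPool(dim=2, kernel_size=3, stride=2, max_pool_size=2)
    x = torch.randn(1, 16, 64, 64)  # (batch, channels, H, W)
    y = pool(x)  # max_pool2d (stride 1) -> 3x3 Pascal blur conv, stride 2
    print(y.shape)  # torch.Size([1, 16, 32, 32])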
careamics/models/model_factory.py
CHANGED
@@ -3,31 +3,30 @@ Model factory.
 
 Model creation factory functions.
 """
-from
-from typing import Dict, Optional, Tuple, Union
+from typing import Union
 
 import torch
 
-from
-from
-from
-from careamics.utils.logging import get_logger
-
+from ..config.architectures import CustomModel, UNetModel, VAEModel, get_custom_model
+from ..config.support import SupportedArchitecture
+from ..utils import get_logger
 from .unet import UNet
 
 logger = get_logger(__name__)
 
 
-def
+def model_factory(
+    model_configuration: Union[UNetModel, VAEModel, CustomModel]
+) -> torch.nn.Module:
     """
-
+    Deep learning model factory.
 
-    Supported models are defined in config.
+    Supported models are defined in careamics.config.SupportedArchitecture.
 
     Parameters
     ----------
-
-
+    model_configuration : Union[UNetModel, VAEModel]
+        Model configuration
 
     Returns
     -------
@@ -37,215 +36,16 @@ def model_registry(model_name: str) -> torch.nn.Module:
     Raises
     ------
     NotImplementedError
-        If the requested
-    """
-    if model_name == Models.UNET:
-        return UNet
-    else:
-        raise NotImplementedError(f"Model {model_name} is not implemented")
-
-
-def create_model(
-    *,
-    model_path: Optional[Union[str, Path]] = None,
-    config: Optional[Configuration] = None,
-    device: Optional[torch.device] = None,
-) -> Tuple[
-    torch.nn.Module,
-    torch.optim.Optimizer,
-    Union[
-        torch.optim.lr_scheduler.LRScheduler,
-        torch.optim.lr_scheduler.ReduceLROnPlateau,  # not a subclass of LRScheduler
-    ],
-    torch.cuda.amp.GradScaler,
-    Configuration,
-]:
-    """
-    Instantiate a model from a model path or configuration.
-
-    If both path and configuration are provided, the model path is used. The model
-    path should point to either a checkpoint (created during training) or a model
-    exported to the bioimage.io format.
-
-    Parameters
-    ----------
-    model_path : Optional[Union[str, Path]], optional
-        Path to a checkpoint or bioimage.io archive, by default None.
-    config : Optional[Configuration], optional
-        Configuration, by default None.
-    device : Optional[torch.device], optional
-        Torch device, by default None.
-
-    Returns
-    -------
-    torch.nn.Module
-        Instantiated model.
-
-    Raises
-    ------
-    ValueError
-        If the checkpoint path is invalid.
-    ValueError
-        If the checkpoint is invalid.
-    ValueError
-        If neither checkpoint nor configuration are provided.
+        If the requested architecture is not implemented.
     """
-    if
-
-
-
-
-            f"Invalid model path: {model_path}. Current working dir: \
-                {Path.cwd()!s}"
-        )
-
-        if model_path.suffix == ".zip":
-            model_path = import_bioimage_model(model_path)
-
-        # Load checkpoint
-        checkpoint = torch.load(model_path, map_location=device)
-
-        # Load the configuration
-        if "config" in checkpoint:
-            config = Configuration(**checkpoint["config"])
-            algo_config = config.algorithm
-            model_config = algo_config.model_parameters
-            model_name = algo_config.model
-        else:
-            raise ValueError("Invalid checkpoint format, no configuration found.")
-
-        # Create model
-        model: torch.nn.Module = model_registry(model_name)(
-            depth=model_config.depth,
-            conv_dim=algo_config.get_conv_dim(),
-            num_channels_init=model_config.num_channels_init,
-        )
-        model.to(device)
-
-        # Load the model state dict
-        if "model_state_dict" in checkpoint:
-            model.load_state_dict(checkpoint["model_state_dict"])
-            logger.info("Loaded model state dict")
-        else:
-            raise ValueError("Invalid checkpoint format")
-
-        # Load the optimizer and scheduler
-        optimizer, scheduler = get_optimizer_and_scheduler(
-            config, model, state_dict=checkpoint
-        )
-        scaler = get_grad_scaler(config, state_dict=checkpoint)
-
-    elif config is not None:
-        # Create model from configuration
-        algo_config = config.algorithm
-        model_config = algo_config.model_parameters
-        model_name = algo_config.model
-
-        # Create model
-        model = model_registry(model_name)(
-            depth=model_config.depth,
-            conv_dim=algo_config.get_conv_dim(),
-            num_channels_init=model_config.num_channels_init,
-        )
-        model.to(device)
-        optimizer, scheduler = get_optimizer_and_scheduler(config, model)
-        scaler = get_grad_scaler(config)
-        logger.info("Engine initialized from configuration")
+    if model_configuration.architecture == SupportedArchitecture.UNET:
+        return UNet(**model_configuration.model_dump())
+    elif model_configuration.architecture == SupportedArchitecture.CUSTOM:
+        assert isinstance(model_configuration, CustomModel)
+        model = get_custom_model(model_configuration.name)
 
+        return model(**model_configuration.model_dump())
     else:
-        raise
-
-
-
-def get_optimizer_and_scheduler(
-    config: Configuration, model: torch.nn.Module, state_dict: Optional[Dict] = None
-) -> Tuple[
-    torch.optim.Optimizer,
-    Union[
-        torch.optim.lr_scheduler.LRScheduler,
-        torch.optim.lr_scheduler.ReduceLROnPlateau,  # not a subclass of LRScheduler
-    ],
-]:
-    """
-    Create optimizer and learning rate schedulers.
-
-    If a checkpoint state dictionary is provided, the optimizer and scheduler are
-    instantiated to the same state as the checkpoint's optimizer and scheduler.
-
-    Parameters
-    ----------
-    config : Configuration
-        Configuration.
-    model : torch.nn.Module
-        Model.
-    state_dict : Optional[Dict], optional
-        Checkpoint state dictionary, by default None.
-
-    Returns
-    -------
-    Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]
-        Optimizer and scheduler.
-    """
-    # retrieve optimizer name and parameters from config
-    optimizer_name = config.training.optimizer.name
-    optimizer_params = config.training.optimizer.parameters
-
-    # then instantiate it
-    optimizer_func = getattr(torch.optim, optimizer_name)
-    optimizer = optimizer_func(model.parameters(), **optimizer_params)
-
-    # same for learning rate scheduler
-    scheduler_name = config.training.lr_scheduler.name
-    scheduler_params = config.training.lr_scheduler.parameters
-    scheduler_func = getattr(torch.optim.lr_scheduler, scheduler_name)
-    scheduler = scheduler_func(optimizer, **scheduler_params)
-
-    # load state from ther checkpoint if available
-    if state_dict is not None:
-        if "optimizer_state_dict" in state_dict:
-            optimizer.load_state_dict(state_dict["optimizer_state_dict"])
-            logger.info("Loaded optimizer state dict")
-        else:
-            logger.warning(
-                "No optimizer state dict found in checkpoint. Optimizer not loaded."
-            )
-        if "scheduler_state_dict" in state_dict:
-            scheduler.load_state_dict(state_dict["scheduler_state_dict"])
-            logger.info("Loaded LR scheduler state dict")
-        else:
-            logger.warning(
-                "No LR scheduler state dict found in checkpoint. "
-                "LR scheduler not loaded."
-            )
-    return optimizer, scheduler
-
-
-def get_grad_scaler(
-    config: Configuration, state_dict: Optional[Dict] = None
-) -> torch.cuda.amp.GradScaler:
-    """
-    Instantiate gradscaler.
-
-    If a checkpoint state dictionary is provided, the scaler is instantiated to the
-    same state as the checkpoint's scaler.
-
-    Parameters
-    ----------
-    config : Configuration
-        Configuration.
-    state_dict : Optional[Dict], optional
-        Checkpoint state dictionary, by default None.
-
-    Returns
-    -------
-    torch.cuda.amp.GradScaler
-        Instantiated gradscaler.
-    """
-    use = config.training.amp.use
-    scaling = config.training.amp.init_scale
-    scaler = torch.cuda.amp.GradScaler(init_scale=scaling, enabled=use)
-    if state_dict is not None and "scaler_state_dict" in state_dict:
-        scaler.load_state_dict(state_dict["scaler_state_dict"])
-        logger.info("Loaded GradScaler state dict")
-    return scaler
+        raise NotImplementedError(
+            f"Model {model_configuration.architecture} is not implemented or unknown."
+        )
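A hedged sketch of the new factory API (not part of the diff). The UNetModel fields live in careamics/config/architectures/unet_model.py, whose body this diff does not show, so the constructor arguments below are assumptions:

    from careamics.config.architectures import UNetModel
    from careamics.models.model_factory import model_factory

    # The `architecture` discriminator selects the branch inside model_factory;
    # all other UNetModel fields are assumed to have defaults here.
    config = UNetModel(architecture="UNet")
    model = model_factory(config)  # returns a torch.nn.Module (a UNet instance)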